]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/caffe-jacinto.git/blobdiff - common_plot.py
doc update
[jacinto-ai/caffe-jacinto.git] / common_plot.py
index 48d17adf0264225d431071a7a9f5f574a7d30b25..28479e258f3170f47a6c5c2e1df59df768d623ec 100644 (file)
@@ -38,6 +38,23 @@ def get_train_loss(log):
     loss = [float(i) for i in loss]
     return iteration, loss
 
+def get_epochs(log):
+    gpus = re.findall(r' GPU (\d*):', log)
+    num_gpus = len(gpus)
+    #print num_gpus
+    max_iter = re.findall(r'max_iter: (\d*)', log)
+    iter_size = re.findall(r'iter_size: (\d*)', log)
+    batch_size = re.findall(r'batch_size: (\d*)',log)
+    max_iter = int(max_iter[0])
+    if len(iter_size) >0:
+        iter_size=int(iter_size[0])
+    else:
+        iter_size=1
+
+    batch_size = int(batch_size[0])
+   # print max_iter, iter_size, batch_size
+    num_epochs = int(round( (max_iter * iter_size * batch_size*num_gpus) /  1281167. +0.5))
+    return max_iter, num_epochs
 
 def get_net_name(log):
     return re.findall(r"Solving (.*)\n", log)[0]
@@ -61,15 +78,23 @@ def parse_files(files, top_k=1, separate=False):
                 data[net_name]["train_loss"]["loss"] = []
                 data[net_name]["train_loss"]["iteration"] = []
 
+            max_iter, epochs = get_epochs(log)
+            #print epochs
+            scale = float(epochs) / max_iter
+
             iteration, accuracy = get_test_accuracy(log, top_k)
+            iteration = [k*scale for k in iteration]
             data[net_name]["accuracy"]["iteration"].extend(iteration)
             data[net_name]["accuracy"]["accuracy"].extend(accuracy)
 
             iteration, loss = get_test_loss(log)
+            iteration = [k*scale for k in iteration]
             data[net_name]["loss"]["iteration"].extend(iteration)
             data[net_name]["loss"]["loss"].extend(loss)
 
+
             iteration, loss = get_train_loss(log)
+            iteration = [k*scale for k in iteration]
             data[net_name]["train_loss"]["iteration"].extend(iteration)
             data[net_name]["train_loss"]["loss"].extend(loss)
 
@@ -167,7 +192,7 @@ def plot_accuracy(top_k, data, value_at_hover=False):
 
     plt.legend(nets, loc='lower right')
     plt.title("Top {}".format(top_k))
-    plt.xlabel("Iteration")
+    plt.xlabel("Epochs")
     plt.ylabel("Accuracy [%]")
     plt.ylim(0,100)
     plt.grid()