m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Chrzanowski <m@m-chrzan.xyz>2021-05-02 20:35:37 +0200
committerMarcin Chrzanowski <m@m-chrzan.xyz>2021-05-02 20:35:37 +0200
commitc0954ee9ec700e0a0ebf4036a2bfccbcfaa7bce6 (patch)
tree1048c2f969308a992a7940165f278e10455ac0cb
parent3bae8f62fd5879b8448c0edb0758e2483b09b197 (diff)
Save trained model
-rw-r--r--src/experiment.py3
1 files changed, 3 insertions, 0 deletions
diff --git a/src/experiment.py b/src/experiment.py
index b508438..f8a219a 100644
--- a/src/experiment.py
+++ b/src/experiment.py
@@ -4,6 +4,7 @@ import time
import pandas as pd
import matplotlib.pyplot as plt
+import torch
from runner import Runner
@@ -35,6 +36,8 @@ class Experiment:
plt.legend()
plt.savefig(self.dir_path('accuracies.png'))
+ torch.save(self.runner.net.state_dict(), self.dir_path('net.pt'))
+
def dir_path(self, file):
return '{}/{}'.format(self.dirname, file)