1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
|
import os
import shutil
import subprocess
import time

import pandas as pd
import torch

from model.encoder import EncoderModel
from train.train import train_model
from util.parse_config import parse_file
class Experiment:
    """A single training run whose artifacts live in a fresh, timestamped directory.

    Constructing an instance creates ``outputs/[<prefix>-]<YYYYmmddHHMMSS>``
    (with a ``-N`` suffix appended if that name is already taken) and copies
    the config file into it as ``config.yaml``. :meth:`run` then builds and
    trains the model described by the config and writes ``net.pt`` and
    ``metrics.csv`` into the same directory.
    """

    def __init__(self, file, prefix, device):
        """Create the run directory and snapshot the config file into it.

        Args:
            file: Path to the experiment config; parsed later by ``parse_file``.
            prefix: Optional label prepended to the directory name. An empty
                (or otherwise falsy) prefix yields a bare timestamp.
            device: Passed through to ``EncoderModel``, ``.to()`` and
                ``train_model``.

        Raises:
            OSError: If the run directory cannot be created or the config file
                cannot be copied (e.g. ``FileNotFoundError`` for a missing file).
        """
        self.file = file
        self.device = device
        self.make_dir(prefix)
        self.copy_config(file)

    def run(self):
        """Train the model described by the config file and save its outputs.

        ``parse_file`` returns a ``(model_config, train_config)`` pair whose
        entries are passed as keyword arguments to ``EncoderModel`` and
        ``train_model`` respectively. ``train_model`` returns the train losses,
        test losses and accuracies, which are written to ``metrics.csv``; the
        trained weights go to ``net.pt``.
        """
        model_config, train_config = parse_file(self.file)
        model = EncoderModel(device=self.device, **model_config).to(self.device)
        train_losses, test_losses, accuracies = train_model(
            model, device=self.device, **train_config)
        self.save_model(model)
        self.save_metrics(train_losses, test_losses, accuracies)

    def save_model(self, model):
        """Save ``model.state_dict()`` (weights only) to ``net.pt`` in the run directory."""
        torch.save(model.state_dict(), self.dir_path('net.pt'))

    def dir_path(self, file):
        """Return the path of ``file`` inside this run's directory."""
        return os.path.join(self.dirname, file)

    def save_metrics(self, train_losses, test_losses, accuracies):
        """Write the training curves to ``metrics.csv``, one row per entry.

        The three sequences must have equal length (``pd.DataFrame`` raises
        ``ValueError`` otherwise); presumably one entry per epoch -- TODO
        confirm against ``train_model``. The DataFrame index is written as an
        unnamed first column.
        """
        data_frame = pd.DataFrame({
            'train_loss': train_losses,
            'test_loss': test_losses,
            'accuracy': accuracies,
        })
        data_frame.to_csv(self.dir_path('metrics.csv'))

    def make_dir(self, prefix):
        """Create a new run directory under ``outputs/`` and store it in ``self.dirname``.

        The name is ``[<prefix>-]<YYYYmmddHHMMSS>`` using local time. If that
        directory already exists (two runs started in the same second with the
        same prefix), ``-1``, ``-2``, ... is appended until a free name is
        found, so concurrent launches never clobber or crash on each other.
        """
        time_string = time.strftime('%Y%m%d%H%M%S')
        label = '{}-'.format(prefix) if prefix else ''
        # Create the shared parent on first use; a bare mkdir of the run
        # directory fails when 'outputs/' is missing.
        os.makedirs('outputs', exist_ok=True)
        base = os.path.join('outputs', label + time_string)
        dirname = base
        attempt = 0
        while True:
            try:
                # mkdir (not makedirs) so an existing directory is detected
                # atomically rather than silently reused.
                os.mkdir(dirname)
                break
            except FileExistsError:
                attempt += 1
                dirname = '{}-{}'.format(base, attempt)
        self.dirname = dirname

    def copy_config(self, file):
        """Copy ``file`` into the run directory as ``config.yaml``.

        Uses ``shutil.copy`` (content plus permission bits, like ``cp``)
        instead of spawning ``cp``, so a failed copy raises instead of being
        silently ignored and the code works where ``cp`` is unavailable.
        NOTE(review): the copy is always named ``config.yaml`` regardless of
        the source file's extension.

        Raises:
            OSError: If ``file`` cannot be read or the destination written.
        """
        shutil.copy(file, self.dir_path('config.yaml'))
|