-rw-r--r-- | model/encoder.py | 90
-rw-r--r-- | train/train.py   | 97
2 files changed, 104 insertions, 83 deletions
diff --git a/model/encoder.py b/model/encoder.py
index 63a5149..85b3141 100644
--- a/model/encoder.py
+++ b/model/encoder.py
@@ -4,45 +4,53 @@ from torch import nn
 from util.util import get_positional_encoding
 from model.encoder_layer import EncoderLayer
 
+def log(string, verbose):
+    if verbose:
+        print(string)
+
 class EncoderModel(nn.Module):
-    def __init__(self, input_dim, hidden_dim, d_ff, output_dim, n_layers,
-                 num_heads, use_attention=True, use_feedforward=True,
-                 use_positional=True, device='cpu'):
-        super(EncoderModel, self).__init__()
-        self._device = device
-        self._use_positional = use_positional
-        self.embedding_layer = nn.Embedding(input_dim, hidden_dim)
-        self.layers = nn.ModuleList([
-            EncoderLayer(hidden_dim, d_ff, num_heads, use_attention,
-                         use_feedforward) for i in range(n_layers)])
-        self.output_layer = nn.Linear(hidden_dim, output_dim)
-
-    def forward(self, x, return_att_weights=False):
-        # x shape: (seqlen, batch)
-        hidden = self.embedding_layer(x)
-        # hidden shape: (seqlen, batch, hiddendim)
-
-        if self._use_positional:
-            positional_encoding = get_positional_encoding(
-                n_positions=hidden.shape[0],
-                n_dimensions=hidden.shape[-1],
-                device=self._device
-            )
-            # reshaping to (seqlen, 1, hiddendim)
-            positional_encoding = torch.reshape(
-                positional_encoding,
-                (hidden.shape[0], 1, hidden.shape[-1])
-            )
-            hidden = hidden + positional_encoding
-
-        list_att_weights = []
-        for layer in self.layers:
-            hidden, att_weights = layer(hidden)
-            list_att_weights.append(att_weights)
-
-        result = self.output_layer(hidden)
-
-        if return_att_weights:
-            return result, list_att_weights
-        else:
-            return result
+    def __init__(self, input_dim, hidden_dim, d_ff, output_dim, n_layers,
+                 num_heads, use_attention=True, use_feedforward=True,
+                 use_positional=True, device='cpu'):
+        super(EncoderModel, self).__init__()
+        self._device = device
+        self._use_positional = use_positional
+        self.embedding_layer = nn.Embedding(input_dim, hidden_dim)
+        self.layers = nn.ModuleList([
+            EncoderLayer(hidden_dim, d_ff, num_heads, use_attention,
+                         use_feedforward) for i in range(n_layers)
+        ])
+        self.output_layer = nn.Linear(hidden_dim, output_dim)
+
+    def forward(self, x, return_att_weights=False, verbose=False):
+        log(f'Handling {x}', verbose)
+        # x shape: (seqlen, batch)
+        hidden = self.embedding_layer(x)
+        # hidden shape: (seqlen, batch, hiddendim)
+
+        if self._use_positional:
+            positional_encoding = get_positional_encoding(
+                n_positions=hidden.shape[0],
+                n_dimensions=hidden.shape[-1],
+                device=self._device
+            )
+            # reshaping to (seqlen, 1, hiddendim)
+            positional_encoding = torch.reshape(
+                positional_encoding,
+                (hidden.shape[0], 1, hidden.shape[-1])
+            )
+            hidden = hidden + positional_encoding
+
+        list_att_weights = []
+        for layer in self.layers:
+            hidden, att_weights = layer(hidden)
+            list_att_weights.append(att_weights)
+
+        result = self.output_layer(hidden)
+
+        log(f'Result: {result}', verbose)
+
+        if return_att_weights:
+            return result, list_att_weights
+        else:
+            return result
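Note: get_positional_encoding is imported from util.util and is not part of this diff. As context only, here is a minimal sketch of a standard sinusoidal implementation consistent with the call site above (returns an (n_positions, n_dimensions) tensor on the given device); this is an assumption about the helper, not the repository's actual code:

import math
import torch

def get_positional_encoding(n_positions, n_dimensions, device='cpu'):
    # Sinusoidal encoding from "Attention Is All You Need":
    # sin on even dimensions, cos on odd dimensions.
    # Assumes n_dimensions is even.
    positions = torch.arange(n_positions, dtype=torch.float,
                             device=device).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, n_dimensions, 2, dtype=torch.float, device=device)
        * (-math.log(10000.0) / n_dimensions))
    encoding = torch.zeros(n_positions, n_dimensions, device=device)
    encoding[:, 0::2] = torch.sin(positions * div_term)
    encoding[:, 1::2] = torch.cos(positions * div_term)
    return encoding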
diff --git a/train/train.py b/train/train.py
index a88129a..1d8f26e 100644
--- a/train/train.py
+++ b/train/train.py
@@ -7,45 +7,58 @@ from torch import optim
 from data.generate import get_single_example
 from data.testset import get_testset
 
-def train_model(model, lr, num_steps, batch_size, device='cpu'):
-    model.to(device)
-
-    start_time = time()
-    accs = []
-
-    loss_function = nn.CrossEntropyLoss()
-    optimizer = optim.Adam(model.parameters(), lr=lr)
-
-    test_X, test_Y = get_testset()
-
-    for step in range(num_steps):
-        batch_examples = [get_single_example() for i in range(batch_size)]
-
-        batch_X = torch.tensor([x[0] for x in batch_examples],
-                               device=device
-                               ).transpose(0, 1)
-        batch_Y = torch.tensor([x[1] for x in batch_examples],
-                               device=device).transpose(0, 1)
-
-        model.train()
-        model.zero_grad()
-        logits = model(batch_X)
-        loss = loss_function(logits.reshape(-1, 10), batch_Y.reshape(-1))
-        loss.backward()
-        optimizer.step()
-
-        if step % (num_steps//100) == 0 or step == num_steps - 1:
-            # Printing a summary of the current state of training every 1% of steps.
-            model.eval()
-            predicted_logits = model.forward(test_X).reshape(-1, 10)
-            test_acc = (
-                torch.sum(torch.argmax(predicted_logits, dim=-1) == test_Y.reshape(-1))
-                / test_Y.reshape(-1).shape[0])
-            print('step', step, 'out of', num_steps)
-            print('loss train', float(loss))
-            print('accuracy test', float(test_acc))
-            print()
-            accs.append(test_acc)
-    print('\nTRAINING TIME:', time()-start_time)
-    model.eval()
-    return accs
+def do_verbose_test(model, n_tokens, seqlen, max_count):
+    print('verbose test:')
+    x, y = get_single_example(n_tokens, seqlen, max_count)
+    x = torch.tensor([x]).transpose(0, 1)
+    print('in          :', x.squeeze())
+    print('expected out:', torch.tensor(y))
+    print('model out   :', torch.argmax(model(x), dim=2).squeeze())
+
+def train_model(model, lr, num_steps, batch_size, n_tokens, seqlen, max_count, device='cpu'):
+    torch.autograd.set_detect_anomaly(True)
+    model.to(device)
+
+    start_time = time()
+    accs = []
+
+    loss_function = nn.CrossEntropyLoss(
+        # weight=torch.log(2 + torch.tensor(range(max_count+1), dtype=torch.float))
+    )
+    optimizer = optim.Adam(model.parameters(), lr=lr)
+
+    test_X, test_Y = get_testset(n_tokens, seqlen, max_count)
+    print('test size', test_X.shape)
+
+    for step in range(num_steps):
+        batch_examples = [get_single_example(n_tokens, seqlen, max_count) for i in range(batch_size)]
+
+        batch_X = torch.tensor([x[0] for x in batch_examples],
+                               device=device
+                               ).transpose(0, 1)
+        batch_Y = torch.tensor([x[1] for x in batch_examples],
+                               device=device).transpose(0, 1)
+
+        model.train()
+        model.zero_grad()
+        logits = model(batch_X)
+        loss = loss_function(logits.reshape(-1, max_count + 1), batch_Y.reshape(-1))
+        loss.backward()
+        optimizer.step()
+
+        if step % (num_steps//100) == 0 or step == num_steps - 1:
+            # Printing a summary of the current state of training every 1% of steps.
+            model.eval()
+            predicted_logits = model.forward(test_X).reshape(-1, max_count + 1)
+            test_acc = (
+                torch.sum(torch.argmax(predicted_logits, dim=-1) == test_Y.reshape(-1))
+                / test_Y.reshape(-1).shape[0])
+            print('step', step, 'out of', num_steps)
+            print('loss train', float(loss))
+            print('accuracy test', float(test_acc))
+            do_verbose_test(model, n_tokens, seqlen, max_count)
+            print()
+            accs.append(test_acc)
+    print('\nTRAINING TIME:', time()-start_time)
+    model.eval()
+    return accs
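Since train_model now takes n_tokens, seqlen, and max_count, existing call sites need updating. A sketch of a driver script under assumed, purely illustrative hyperparameters, wiring EncoderModel to the new signature:

from model.encoder import EncoderModel
from train.train import train_model

# Illustrative values only; the real task configuration lives elsewhere in the repo.
n_tokens, seqlen, max_count = 16, 64, 9

model = EncoderModel(
    input_dim=n_tokens,        # vocabulary size for the embedding
    hidden_dim=128,
    d_ff=256,
    output_dim=max_count + 1,  # one logit per count class, matching the loss reshape
    n_layers=2,
    num_heads=4,
)

accs = train_model(model, lr=1e-3, num_steps=10000, batch_size=32,
                   n_tokens=n_tokens, seqlen=seqlen, max_count=max_count)

Note that torch.autograd.set_detect_anomaly(True) makes backward passes substantially slower; it is typically enabled only while hunting NaNs and removed once training is stable.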