m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r-- model/encoder.py 90
-rw-r--r-- train/train.py 97
2 files changed, 104 insertions, 83 deletions
diff --git a/model/encoder.py b/model/encoder.py
index 63a5149..85b3141 100644
--- a/model/encoder.py
+++ b/model/encoder.py
@@ -4,45 +4,53 @@ from torch import nn
from util.util import get_positional_encoding
from model.encoder_layer import EncoderLayer
+def log(string, verbose):
+    """Print ``string`` to stdout when ``verbose`` is truthy; otherwise do nothing."""
+    if not verbose:
+        return
+    print(string)
+
class EncoderModel(nn.Module):
- def __init__(self, input_dim, hidden_dim, d_ff, output_dim, n_layers,
- num_heads, use_attention=True, use_feedforward=True,
- use_positional=True, device='cpu'):
- super(EncoderModel, self).__init__()
- self._device = device
- self._use_positional = use_positional
- self.embedding_layer = nn.Embedding(input_dim, hidden_dim)
- self.layers = nn.ModuleList([
- EncoderLayer(hidden_dim, d_ff, num_heads, use_attention,
- use_feedforward) for i in range(n_layers)])
- self.output_layer = nn.Linear(hidden_dim, output_dim)
-
- def forward(self, x, return_att_weights=False):
- # x shape: (seqlen, batch)
- hidden = self.embedding_layer(x)
- # hidden shape: (seqlen, batch, hiddendim)
-
- if self._use_positional:
- positional_encoding = get_positional_encoding(
- n_positions=hidden.shape[0],
- n_dimensions=hidden.shape[-1],
- device=self._device
- )
- # reshaping to (seqlen, 1, hiddendim)
- positional_encoding = torch.reshape(
- positional_encoding,
- (hidden.shape[0], 1, hidden.shape[-1])
- )
- hidden = hidden + positional_encoding
-
- list_att_weights = []
- for layer in self.layers:
- hidden, att_weights = layer(hidden)
- list_att_weights.append(att_weights)
-
- result = self.output_layer(hidden)
-
- if return_att_weights:
- return result, list_att_weights
- else:
- return result
+    def __init__(self, input_dim, hidden_dim, d_ff, output_dim, n_layers,
+                 num_heads, use_attention=True, use_feedforward=True,
+                 use_positional=True, device='cpu'):
+        """Assemble the embedding, the encoder stack and the output projection.
+
+        ``input_dim`` and ``hidden_dim`` size the ``nn.Embedding`` table;
+        ``n_layers`` instances of ``EncoderLayer`` are built from
+        ``hidden_dim``, ``d_ff``, ``num_heads``, ``use_attention`` and
+        ``use_feedforward``; a final ``nn.Linear`` maps ``hidden_dim`` to
+        ``output_dim``.  ``use_positional`` and ``device`` are only stored
+        here and are read later by ``forward``.
+        """
+        super(EncoderModel, self).__init__()
+        self._use_positional = use_positional
+        self._device = device
+
+        # Sub-modules are created in the order embedding -> layers -> output.
+        self.embedding_layer = nn.Embedding(input_dim, hidden_dim)
+        encoder_stack = []
+        for _ in range(n_layers):
+            encoder_stack.append(
+                EncoderLayer(hidden_dim, d_ff, num_heads, use_attention,
+                             use_feedforward))
+        self.layers = nn.ModuleList(encoder_stack)
+        self.output_layer = nn.Linear(hidden_dim, output_dim)
+
+    def forward(self, x, return_att_weights=False, verbose=False):
+        """Embed ``x``, run it through every encoder layer and project the result.
+
+        Args:
+            x: index tensor fed to the embedding layer; per the shape
+                comments below it is laid out as (seqlen, batch).
+            return_att_weights: when true, also return the attention weights
+                collected from each layer, in layer order.
+            verbose: when true, print the input and the output via ``log``.
+
+        Returns:
+            The output-layer result, or ``(result, list_att_weights)`` when
+            ``return_att_weights`` is true.
+        """
+        # Build the message only when it will be printed, so a non-verbose
+        # call does not pay for formatting the whole tensor.
+        if verbose:
+            log(f'Handling {x}', verbose)
+        # x shape: (seqlen, batch)
+        hidden = self.embedding_layer(x)
+        # hidden shape: (seqlen, batch, hiddendim)
+
+        if self._use_positional:
+            # NOTE(review): the encoding is created on the device given to
+            # __init__, not on hidden.device -- confirm the two always agree
+            # when the model is moved with .to(...).
+            positional_encoding = get_positional_encoding(
+                n_positions=hidden.shape[0],
+                n_dimensions=hidden.shape[-1],
+                device=self._device
+            )
+            # reshaping to (seqlen, 1, hiddendim) so it broadcasts over batch
+            positional_encoding = torch.reshape(
+                positional_encoding,
+                (hidden.shape[0], 1, hidden.shape[-1])
+            )
+            hidden = hidden + positional_encoding
+
+        list_att_weights = []
+        for layer in self.layers:
+            hidden, att_weights = layer(hidden)
+            list_att_weights.append(att_weights)
+
+        result = self.output_layer(hidden)
+
+        if verbose:
+            # Must be an f-string: a plain literal prints the text '{result}'.
+            log(f'Result: {result}', verbose)
+
+        if return_att_weights:
+            return result, list_att_weights
+        return result
diff --git a/train/train.py b/train/train.py
index a88129a..1d8f26e 100644
--- a/train/train.py
+++ b/train/train.py
@@ -7,45 +7,58 @@ from torch import optim
from data.generate import get_single_example
from data.testset import get_testset
-def train_model(model, lr, num_steps, batch_size, device='cpu'):
- model.to(device)
-
- start_time = time()
- accs = []
-
- loss_function = nn.CrossEntropyLoss()
- optimizer = optim.Adam(model.parameters(), lr=lr)
-
- test_X, test_Y = get_testset()
-
- for step in range(num_steps):
- batch_examples = [get_single_example() for i in range(batch_size)]
-
- batch_X = torch.tensor([x[0] for x in batch_examples],
- device=device
- ).transpose(0, 1)
- batch_Y = torch.tensor([x[1] for x in batch_examples],
- device=device).transpose(0, 1)
-
- model.train()
- model.zero_grad()
- logits = model(batch_X)
- loss = loss_function(logits.reshape(-1, 10), batch_Y.reshape(-1))
- loss.backward()
- optimizer.step()
-
- if step % (num_steps//100) == 0 or step == num_steps - 1:
- # Printing a summary of the current state of training every 1% of steps.
- model.eval()
- predicted_logits = model.forward(test_X).reshape(-1, 10)
- test_acc = (
- torch.sum(torch.argmax(predicted_logits, dim=-1) == test_Y.reshape(-1))
- / test_Y.reshape(-1).shape[0])
- print('step', step, 'out of', num_steps)
- print('loss train', float(loss))
- print('accuracy test', float(test_acc))
- print()
- accs.append(test_acc)
- print('\nTRAINING TIME:', time()-start_time)
- model.eval()
- return accs
+def do_verbose_test(model, n_tokens, seqlen, max_count):
+    """Print one freshly generated example next to the model's prediction.
+
+    Args:
+        model: module to query; ``model(x)`` must return logits whose dim 2
+            indexes the classes.
+        n_tokens, seqlen, max_count: forwarded unchanged to
+            ``get_single_example``.
+    """
+    print('verbose test:')
+    x, y = get_single_example(n_tokens, seqlen, max_count)
+    # Wrap the example as a batch of one and swap the first two axes, the
+    # same layout train_model feeds to the model.
+    x = torch.tensor([x]).transpose(0, 1)
+    print('in :', x.squeeze())
+    print('expected out:', torch.tensor(y))
+
+    # The model may have been moved off the CPU; run the example on whatever
+    # device its parameters live on.
+    first_param = next(model.parameters(), None)
+    model_input = x if first_param is None else x.to(first_param.device)
+
+    # Inference only: skip autograd bookkeeping.
+    with torch.no_grad():
+        prediction = torch.argmax(model(model_input), dim=2).squeeze().cpu()
+    print('model out :', prediction)
+
+def train_model(model, lr, num_steps, batch_size, n_tokens, seqlen, max_count,
+                device='cpu'):
+    """Train ``model`` on freshly generated batches and track test accuracy.
+
+    Args:
+        model: module to train; moved to ``device`` in place.
+        lr: learning rate passed to ``optim.Adam``.
+        num_steps: number of optimisation steps.
+        batch_size: number of generated examples per step.
+        n_tokens, seqlen, max_count: forwarded to ``get_single_example`` and
+            ``get_testset``; the logits are reshaped to ``max_count + 1``
+            classes.
+        device: device used for the model and for every batch.
+
+    Returns:
+        List with one test accuracy (0-dim tensor) per progress report.  The
+        model is left in eval mode.
+    """
+    # Debugging aid: anomaly detection is switched on globally and stays on
+    # after this function returns; it slows down the backward pass.
+    torch.autograd.set_detect_anomaly(True)
+    model.to(device)
+
+    start_time = time()
+    accs = []
+    n_classes = max_count + 1
+
+    loss_function = nn.CrossEntropyLoss(
+        # weight=torch.log(2 + torch.tensor(range(max_count+1), dtype=torch.float))
+    )
+    optimizer = optim.Adam(model.parameters(), lr=lr)
+
+    test_X, test_Y = get_testset(n_tokens, seqlen, max_count)
+    # The model lives on `device`, so the evaluation tensors must too.
+    test_X = test_X.to(device)
+    test_Y = test_Y.to(device)
+    print('test size', test_X.shape)
+
+    # Report roughly every 1% of steps; clamp to 1 so that runs shorter than
+    # 100 steps do not divide by zero in the modulo below.
+    report_every = max(1, num_steps // 100)
+
+    for step in range(num_steps):
+        batch_examples = [get_single_example(n_tokens, seqlen, max_count)
+                          for _ in range(batch_size)]
+
+        # Examples are stacked batch-first, then transposed so the batch
+        # axis comes second.
+        batch_X = torch.tensor([x[0] for x in batch_examples],
+                               device=device).transpose(0, 1)
+        batch_Y = torch.tensor([x[1] for x in batch_examples],
+                               device=device).transpose(0, 1)
+
+        model.train()
+        model.zero_grad()
+        logits = model(batch_X)
+        loss = loss_function(logits.reshape(-1, n_classes),
+                             batch_Y.reshape(-1))
+        loss.backward()
+        optimizer.step()
+
+        if step % report_every == 0 or step == num_steps - 1:
+            # Printing a summary of the current state of training.
+            model.eval()
+            # Evaluation needs no gradients; avoid building a graph over the
+            # whole test set.
+            with torch.no_grad():
+                predicted_logits = model(test_X).reshape(-1, n_classes)
+                flat_targets = test_Y.reshape(-1)
+                test_acc = (
+                    torch.sum(torch.argmax(predicted_logits, dim=-1) == flat_targets)
+                    / flat_targets.shape[0])
+            print('step', step, 'out of', num_steps)
+            print('loss train', float(loss))
+            print('accuracy test', float(test_acc))
+            do_verbose_test(model, n_tokens, seqlen, max_count)
+            print()
+            accs.append(test_acc)
+    print('\nTRAINING TIME:', time()-start_time)
+    model.eval()
+    return accs