diff options
-rw-r--r-- | model/attention.py | 7 | ||||
-rw-r--r-- | model/encoder.py | 11 | ||||
-rw-r--r-- | model/encoder_layer.py | 4 | ||||
-rw-r--r-- | train/train.py | 3 | ||||
-rw-r--r-- | util/util.py | 2 |
5 files changed, 19 insertions, 8 deletions
diff --git a/model/attention.py b/model/attention.py
index ffc07d3..75ff5a0 100644
--- a/model/attention.py
+++ b/model/attention.py
@@ -31,8 +31,9 @@ class Head(nn.Module):
         return value, weights
 
 class Attention(nn.Module):
-    def __init__(self, hidden_dim, num_heads):
+    def __init__(self, hidden_dim, num_heads, device):
         super(Attention, self).__init__()
+        self._device = device
         self._num_heads = num_heads
         self._head_output_dim = hidden_dim // num_heads
         # ensure hidden_dim is divisible by num_heads
@@ -45,9 +46,9 @@ class Attention(nn.Module):
 
     def forward(self, x):
         # x shape: (seqlen, batch, hiddendim)
-        result = torch.zeros(x.shape)
+        result = torch.zeros(x.shape).to(self._device)
         # attentions are (heads, seqlen, batch, seqlen)
-        attentions = torch.zeros(self._num_heads, x.shape[0], x.shape[1], x.shape[0])
+        attentions = torch.zeros(self._num_heads, x.shape[0], x.shape[1], x.shape[0]).to(self._device)
         for i in range(self._num_heads):
             from_index = i * self._head_output_dim
             to_index = from_index + self._head_output_dim
diff --git a/model/encoder.py b/model/encoder.py
index 85b3141..d6527dd 100644
--- a/model/encoder.py
+++ b/model/encoder.py
@@ -17,14 +17,21 @@ class EncoderModel(nn.Module):
         self._use_positional = use_positional
         self.embedding_layer = nn.Embedding(input_dim, hidden_dim)
         self.layers = nn.ModuleList([
-            EncoderLayer(hidden_dim, d_ff, num_heads, use_attention,
-                         use_feedforward) for i in range(n_layers)
+            EncoderLayer(
+                hidden_dim,
+                d_ff,
+                num_heads,
+                use_attention,
+                use_feedforward,
+                device=device
+            ) for i in range(n_layers)
         ])
         self.output_layer = nn.Linear(hidden_dim, output_dim)
 
     def forward(self, x, return_att_weights=False, verbose=False):
         log(f'Handling {x}', verbose)
         # x shape: (seqlen, batch)
+        x = x.to(self._device)
         hidden = self.embedding_layer(x)
         # hidden shape: (seqlen, batch, hiddendim)
 
diff --git a/model/encoder_layer.py b/model/encoder_layer.py
index 71a7d8f..311c39c 100644
--- a/model/encoder_layer.py
+++ b/model/encoder_layer.py
@@ -5,12 +5,12 @@ from model.forward import FeedForward
 
 class EncoderLayer(nn.Module):
     def __init__(self, hidden_dim, d_ff, num_heads, use_attention=True,
-                 use_feedforward=True):
+                 use_feedforward=True, device='cpu'):
         super(EncoderLayer, self).__init__()
         self._use_attention = use_attention
         self._use_feedforward = use_feedforward
         if use_attention:
-            self.attention = Attention(hidden_dim, num_heads)
+            self.attention = Attention(hidden_dim, num_heads, device)
         if use_feedforward:
             self.feedforward = FeedForward(hidden_dim, d_ff)
 
diff --git a/train/train.py b/train/train.py
index 1d8f26e..e72fb3f 100644
--- a/train/train.py
+++ b/train/train.py
@@ -1,4 +1,5 @@
 from time import time
+import sys
 
 import torch
 from torch import nn
@@ -50,6 +51,7 @@ def train_model(model, lr, num_steps, batch_size, n_tokens, seqlen, max_count, d
             # Printing a summary of the current state of training every 1% of steps.
             model.eval()
             predicted_logits = model.forward(test_X).reshape(-1, max_count + 1)
+            predicted_logits = predicted_logits.to('cpu')
             test_acc = (
                 torch.sum(torch.argmax(predicted_logits, dim=-1) == test_Y.reshape(-1)) /
                 test_Y.reshape(-1).shape[0])
@@ -58,6 +60,7 @@ def train_model(model, lr, num_steps, batch_size, n_tokens, seqlen, max_count, d
             print('accuracy test', float(test_acc))
             do_verbose_test(model, n_tokens, seqlen, max_count)
             print()
+            sys.stdout.flush()
             accs.append(test_acc)
     print('\nTRAINING TIME:', time()-start_time)
     model.eval()
diff --git a/util/util.py b/util/util.py
index 65f1838..bb9e0c3 100644
--- a/util/util.py
+++ b/util/util.py
@@ -14,4 +14,4 @@ def get_positional_encoding(n_positions, n_dimensions, device='cpu'):
     positional_encoding[:, ::2] = torch.sin(positional_encoding[:, ::2])
     positional_encoding[:, 1::2] = torch.cos(positional_encoding[:, 1::2])
     # output shape: (seqlen, hiddendim)
-    return positional_encoding
+    return positional_encoding.to(device)