m-chrzan.xyz
-rw-r--r--  model/attention.py      7
-rw-r--r--  model/encoder.py        11
-rw-r--r--  model/encoder_layer.py  4
-rw-r--r--  train/train.py          3
-rw-r--r--  util/util.py            2
5 files changed, 19 insertions, 8 deletions
diff --git a/model/attention.py b/model/attention.py
index ffc07d3..75ff5a0 100644
--- a/model/attention.py
+++ b/model/attention.py
@@ -31,8 +31,9 @@ class Head(nn.Module):
return value, weights
class Attention(nn.Module):
- def __init__(self, hidden_dim, num_heads):
+ def __init__(self, hidden_dim, num_heads, device):
super(Attention, self).__init__()
+ self._device = device
self._num_heads = num_heads
self._head_output_dim = hidden_dim // num_heads
# ensure hidden_dim is divisible by num_heads
@@ -45,9 +46,9 @@ class Attention(nn.Module):
def forward(self, x):
# x shape: (seqlen, batch, hiddendim)
- result = torch.zeros(x.shape)
+ result = torch.zeros(x.shape).to(self._device)
# attentions are (heads, seqlen, batch, seqlen)
- attentions = torch.zeros(self._num_heads, x.shape[0], x.shape[1], x.shape[0])
+ attentions = torch.zeros(self._num_heads, x.shape[0], x.shape[1], x.shape[0]).to(self._device)
for i in range(self._num_heads):
from_index = i * self._head_output_dim
to_index = from_index + self._head_output_dim
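Note: the patch allocates the result and attention buffers on the CPU and then copies them with .to(self._device). An equivalent variant (not part of this patch; make_buffers is only an illustrative helper) passes device= to torch.zeros directly and skips the intermediate CPU allocation:

    import torch

    def make_buffers(x, num_heads, device):
        # result has the same shape as the input activations: (seqlen, batch, hiddendim)
        result = torch.zeros(x.shape, device=device)
        # per-head attention weights: (heads, seqlen, batch, seqlen)
        attentions = torch.zeros(num_heads, x.shape[0], x.shape[1], x.shape[0],
                                 device=device)
        return result, attentions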
diff --git a/model/encoder.py b/model/encoder.py
index 85b3141..d6527dd 100644
--- a/model/encoder.py
+++ b/model/encoder.py
@@ -17,14 +17,21 @@ class EncoderModel(nn.Module):
self._use_positional = use_positional
self.embedding_layer = nn.Embedding(input_dim, hidden_dim)
self.layers = nn.ModuleList([
- EncoderLayer(hidden_dim, d_ff, num_heads, use_attention,
- use_feedforward) for i in range(n_layers)
+ EncoderLayer(
+ hidden_dim,
+ d_ff,
+ num_heads,
+ use_attention,
+ use_feedforward,
+ device=device
+ ) for i in range(n_layers)
])
self.output_layer = nn.Linear(hidden_dim, output_dim)
def forward(self, x, return_att_weights=False, verbose=False):
log(f'Handling {x}', verbose)
# x shape: (seqlen, batch)
+ x = x.to(self._device)
hidden = self.embedding_layer(x)
# hidden shape: (seqlen, batch, hiddendim)
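Note: with this change the caller can keep its data pipeline on the CPU, since forward() moves the token indices to the model's device before the embedding lookup. A minimal usage sketch (hyperparameter values are illustrative, and the constructor is assumed to accept the keyword arguments visible in this hunk; its exact signature is not shown in the diff):

    import torch
    from model.encoder import EncoderModel

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = EncoderModel(input_dim=16, hidden_dim=64, output_dim=11, d_ff=128,
                         num_heads=4, n_layers=2, use_attention=True,
                         use_feedforward=True, use_positional=True,
                         device=device).to(device)  # .to(device) moves the parameters as well
    x = torch.randint(0, 16, (20, 8))  # (seqlen, batch) token ids, still on the CPU
    logits = model(x)                  # forward() transfers x to the model's device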
diff --git a/model/encoder_layer.py b/model/encoder_layer.py
index 71a7d8f..311c39c 100644
--- a/model/encoder_layer.py
+++ b/model/encoder_layer.py
@@ -5,12 +5,12 @@ from model.forward import FeedForward
class EncoderLayer(nn.Module):
def __init__(self, hidden_dim, d_ff, num_heads, use_attention=True,
- use_feedforward=True):
+ use_feedforward=True, device='cpu'):
super(EncoderLayer, self).__init__()
self._use_attention = use_attention
self._use_feedforward = use_feedforward
if use_attention:
- self.attention = Attention(hidden_dim, num_heads)
+ self.attention = Attention(hidden_dim, num_heads, device)
if use_feedforward:
self.feedforward = FeedForward(hidden_dim, d_ff)
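Note: defaulting device to 'cpu' keeps existing call sites working unchanged; only callers that want GPU execution pass a device explicitly. A sketch with illustrative layer sizes (the second line requires a CUDA-capable GPU):

    from model.encoder_layer import EncoderLayer

    cpu_layer = EncoderLayer(hidden_dim=64, d_ff=128, num_heads=4)                 # old behaviour
    gpu_layer = EncoderLayer(hidden_dim=64, d_ff=128, num_heads=4, device='cuda')  # explicit device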
diff --git a/train/train.py b/train/train.py
index 1d8f26e..e72fb3f 100644
--- a/train/train.py
+++ b/train/train.py
@@ -1,4 +1,5 @@
from time import time
+import sys
import torch
from torch import nn
@@ -50,6 +51,7 @@ def train_model(model, lr, num_steps, batch_size, n_tokens, seqlen, max_count, d
# Printing a summary of the current state of training every 1% of steps.
model.eval()
predicted_logits = model.forward(test_X).reshape(-1, max_count + 1)
+ predicted_logits = predicted_logits.to('cpu')
test_acc = (
torch.sum(torch.argmax(predicted_logits, dim=-1) == test_Y.reshape(-1))
/ test_Y.reshape(-1).shape[0])
@@ -58,6 +60,7 @@ def train_model(model, lr, num_steps, batch_size, n_tokens, seqlen, max_count, d
print('accuracy test', float(test_acc))
do_verbose_test(model, n_tokens, seqlen, max_count)
print()
+ sys.stdout.flush()
accs.append(test_acc)
print('\nTRAINING TIME:', time()-start_time)
model.eval()
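Note: two small fixes here: the test logits are moved back to the CPU before being compared against test_Y (which apparently stays on the CPU), and stdout is flushed so progress lines show up immediately when output is piped or redirected. The same evaluation pattern as a stand-alone sketch (test_accuracy is a hypothetical helper, not a function from this repository):

    import torch

    def test_accuracy(model, test_X, test_Y, max_count):
        # Run the (possibly GPU-resident) model, then bring the logits back to
        # the CPU, where test_Y lives, before comparing.
        with torch.no_grad():
            logits = model(test_X).reshape(-1, max_count + 1).to('cpu')
        predictions = torch.argmax(logits, dim=-1)
        return (predictions == test_Y.reshape(-1)).float().mean().item()

Since Python 3.3, print(..., flush=True) has the same effect as the explicit sys.stdout.flush() call, without the extra import.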
diff --git a/util/util.py b/util/util.py
index 65f1838..bb9e0c3 100644
--- a/util/util.py
+++ b/util/util.py
@@ -14,4 +14,4 @@ def get_positional_encoding(n_positions, n_dimensions, device='cpu'):
positional_encoding[:, ::2] = torch.sin(positional_encoding[:, ::2])
positional_encoding[:, 1::2] = torch.cos(positional_encoding[:, 1::2])
# output shape: (seqlen, hiddendim)
- return positional_encoding
+ return positional_encoding.to(device)
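Note: get_positional_encoding now returns its table on the requested device, so it can be added to GPU-resident hidden states without an implicit transfer. One way to apply it, assuming the (seqlen, batch, hiddendim) layout used in the encoder above (how the encoder actually consumes the table is not shown in this diff):

    import torch
    from util.util import get_positional_encoding

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    hidden = torch.zeros(20, 8, 64, device=device)       # (seqlen, batch, hiddendim)
    pe = get_positional_encoding(20, 64, device=device)  # (seqlen, hiddendim)
    hidden = hidden + pe.unsqueeze(1)                     # broadcast over the batch dimension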