m-chrzan.xyz
author     Marcin Chrzanowski <m@m-chrzan.xyz>  2021-05-23 13:18:52 +0200
committer  Marcin Chrzanowski <m@m-chrzan.xyz>  2021-05-23 13:18:52 +0200
commit     49f264647b6073c304936a95fea1704a8c0965dc (patch)
tree       44b19e5f881076a6b15fa1a755d2dc1e6d9307e9 /model
parent     8ff8739b236a00169339b0b78e1f39357fdfff17 (diff)
Add model skeleton
Diffstat (limited to 'model')
-rw-r--r--  model/__init__.py       0
-rw-r--r--  model/attention.py     15
-rw-r--r--  model/encoder.py       48
-rw-r--r--  model/encoder_layer.py 15
-rw-r--r--  model/forward.py       14
5 files changed, 92 insertions, 0 deletions
diff --git a/model/__init__.py b/model/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/model/__init__.py
diff --git a/model/attention.py b/model/attention.py
new file mode 100644
index 0000000..d82ca17
--- /dev/null
+++ b/model/attention.py
@@ -0,0 +1,15 @@
+from torch import nn
+
+class Attention(nn.Module):
+    def __init__(self, hidden_dim, num_heads):
+        super(Attention, self).__init__()
+        # TODO: implement Attention
+        pass
+
+    def forward(self, x):
+        # TODO: implement Attention; return both the result of the attention
+        # mechanism and the attention weights (for visualization).
+        # x shape: (seqlen, batch, hiddendim)
+        result, att_weights = x, None  # placeholder
+        pass
+        return result, att_weights
diff --git a/model/encoder.py b/model/encoder.py
new file mode 100644
index 0000000..63a5149
--- /dev/null
+++ b/model/encoder.py
@@ -0,0 +1,48 @@
+import torch
+from torch import nn
+
+from util.util import get_positional_encoding
+from model.encoder_layer import EncoderLayer
+
+class EncoderModel(nn.Module):
+    def __init__(self, input_dim, hidden_dim, d_ff, output_dim, n_layers,
+                 num_heads, use_attention=True, use_feedforward=True,
+                 use_positional=True, device='cpu'):
+        super(EncoderModel, self).__init__()
+        self._device = device
+        self._use_positional = use_positional
+        self.embedding_layer = nn.Embedding(input_dim, hidden_dim)
+        self.layers = nn.ModuleList([
+            EncoderLayer(hidden_dim, d_ff, num_heads, use_attention,
+                         use_feedforward) for i in range(n_layers)])
+        self.output_layer = nn.Linear(hidden_dim, output_dim)
+
+    def forward(self, x, return_att_weights=False):
+        # x shape: (seqlen, batch)
+        hidden = self.embedding_layer(x)
+        # hidden shape: (seqlen, batch, hiddendim)
+
+        if self._use_positional:
+            positional_encoding = get_positional_encoding(
+                n_positions=hidden.shape[0],
+                n_dimensions=hidden.shape[-1],
+                device=self._device
+            )
+            # reshaping to (seqlen, 1, hiddendim)
+            positional_encoding = torch.reshape(
+                positional_encoding,
+                (hidden.shape[0], 1, hidden.shape[-1])
+            )
+            hidden = hidden + positional_encoding
+
+        list_att_weights = []
+        for layer in self.layers:
+            hidden, att_weights = layer(hidden)
+            list_att_weights.append(att_weights)
+
+        result = self.output_layer(hidden)
+
+        if return_att_weights:
+            return result, list_att_weights
+        else:
+            return result
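EncoderModel is complete as committed: it embeds token ids, optionally adds a positional encoding broadcast over the batch dimension, runs the stack of encoder layers, and projects to output_dim, optionally returning the per-layer attention weights. A brief usage sketch; the vocabulary size and hyperparameters below are invented purely for illustration:

import torch

from model.encoder import EncoderModel

# Hypothetical hyperparameters, chosen only for illustration.
model = EncoderModel(input_dim=1000, hidden_dim=64, d_ff=256, output_dim=10,
                     n_layers=2, num_heads=4)

# Token ids shaped (seqlen, batch), as expected by forward().
x = torch.randint(0, 1000, (16, 8))
logits, att_weights = model(x, return_att_weights=True)
# logits: (seqlen, batch, output_dim); att_weights: one entry per layer.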
diff --git a/model/encoder_layer.py b/model/encoder_layer.py
new file mode 100644
index 0000000..56a3a0c
--- /dev/null
+++ b/model/encoder_layer.py
@@ -0,0 +1,15 @@
+from torch import nn
+
+class EncoderLayer(nn.Module):
+    def __init__(self, hidden_dim, d_ff, num_heads, use_attention=True,
+                 use_feedforward=True):
+        super(EncoderLayer, self).__init__()
+        # TODO: implement a single encoder layer, using Attention and FeedForward.
+        pass
+
+    def forward(self, x):
+        # x shape: (seqlen, batch, hiddendim)
+        # TODO: implement a single encoder layer, using Attention and FeedForward.
+        result, att_weights = x, None  # placeholder
+        pass
+        return result, att_weights
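One plausible way to complete this layer (again, not part of the commit) is the standard residual-plus-LayerNorm arrangement around the Attention and FeedForward modules introduced in this patch; the post-norm ordering and the exact behavior of the use_attention/use_feedforward flags here are assumptions:

from torch import nn

from model.attention import Attention
from model.forward import FeedForward

class EncoderLayer(nn.Module):
    def __init__(self, hidden_dim, d_ff, num_heads, use_attention=True,
                 use_feedforward=True):
        super(EncoderLayer, self).__init__()
        self._use_attention = use_attention
        self._use_feedforward = use_feedforward
        if use_attention:
            self.attention = Attention(hidden_dim, num_heads)
            self.norm1 = nn.LayerNorm(hidden_dim)
        if use_feedforward:
            self.feedforward = FeedForward(hidden_dim, d_ff)
            self.norm2 = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        # x shape: (seqlen, batch, hiddendim)
        att_weights = None
        if self._use_attention:
            att_out, att_weights = self.attention(x)
            x = self.norm1(x + att_out)  # residual connection, then post-norm
        if self._use_feedforward:
            x = self.norm2(x + self.feedforward(x))
        return x, att_weights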
diff --git a/model/forward.py b/model/forward.py
new file mode 100644
index 0000000..63ec6ea
--- /dev/null
+++ b/model/forward.py
@@ -0,0 +1,14 @@
+from torch import nn
+
+class FeedForward(nn.Module):
+    def __init__(self, hidden_dim, d_ff):
+        super(FeedForward, self).__init__()
+        # TODO: implement FeedForward layer
+        pass
+
+    def forward(self, x):
+        # TODO: implement FeedForward
+        # x shape: (seqlen, batch, hiddendim)
+        result = x  # placeholder
+        pass
+        return result
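The FeedForward TODO maps naturally onto the usual position-wise two-layer MLP. A minimal sketch, assuming a ReLU nonlinearity (the skeleton does not fix one):

from torch import nn

class FeedForward(nn.Module):
    def __init__(self, hidden_dim, d_ff):
        super(FeedForward, self).__init__()
        # Expand to d_ff, apply a nonlinearity, project back to hidden_dim.
        self.net = nn.Sequential(
            nn.Linear(hidden_dim, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, hidden_dim),
        )

    def forward(self, x):
        # x shape: (seqlen, batch, hiddendim); nn.Linear acts on the last dim.
        return self.net(x)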