m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/model/encoder.py
blob: 85b3141007af4713a40a2694ab876d632ecd01a6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
from torch import nn

from util.util import get_positional_encoding
from model.encoder_layer import EncoderLayer

def log(string, verbose):
    """Print ``string`` to stdout, but only when ``verbose`` is truthy.

    Args:
        string: Message to print.
        verbose: Any value; the message is printed iff it is truthy.
    """
    if not verbose:
        return
    print(string)

class EncoderModel(nn.Module):
    """Encoder: token embedding -> optional positional encoding ->
    stack of ``EncoderLayer`` modules -> linear output projection.

    Args:
        input_dim: Number of embeddings (vocabulary size) of ``nn.Embedding``.
        hidden_dim: Embedding size; also the width given to every
            ``EncoderLayer`` and the input size of the output projection.
        d_ff: Passed through to ``EncoderLayer`` (presumably the inner
            feed-forward size -- confirm against ``model.encoder_layer``).
        output_dim: Output features of the final ``nn.Linear``.
        n_layers: Number of ``EncoderLayer`` modules in ``self.layers``.
        num_heads: Passed through to ``EncoderLayer``.
        use_attention: Passed through to ``EncoderLayer``.
        use_feedforward: Passed through to ``EncoderLayer``.
        use_positional: If true, ``forward`` adds the output of
            ``get_positional_encoding`` to the embeddings.
        device: Device handed to ``get_positional_encoding``. NOTE(review):
            stored once at construction, so it is not updated by
            ``module.to(...)``; keep it in sync with the model's device.
    """

    def __init__(self, input_dim, hidden_dim, d_ff, output_dim, n_layers,
                 num_heads, use_attention=True, use_feedforward=True,
                 use_positional=True, device='cpu'):
        super().__init__()
        self._device = device
        self._use_positional = use_positional
        self.embedding_layer = nn.Embedding(input_dim, hidden_dim)
        self.layers = nn.ModuleList([
            EncoderLayer(hidden_dim, d_ff, num_heads, use_attention,
                         use_feedforward)
            for _ in range(n_layers)
        ])
        self.output_layer = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, return_att_weights=False, verbose=False):
        """Run the encoder on a batch of token indices.

        Args:
            x: Integer token-index tensor, sequence-first: (seqlen, batch).
            return_att_weights: If true, also return the list of second
                outputs of each ``EncoderLayer`` (one entry per layer).
            verbose: If true, print the input and the result to stdout.

        Returns:
            ``result`` -- the output projection of the last hidden state,
            presumably (seqlen, batch, output_dim) assuming ``EncoderLayer``
            preserves the hidden shape -- or ``(result, list_att_weights)``
            when ``return_att_weights`` is true.
        """
        # f-strings are evaluated eagerly, so only format the (possibly
        # large / on-device) tensor when it will actually be printed.
        if verbose:
            print(f'Handling {x}')
        # x shape: (seqlen, batch)
        hidden = self.embedding_layer(x)
        # hidden shape: (seqlen, batch, hiddendim)

        if self._use_positional:
            seq_len, hidden_dim = hidden.shape[0], hidden.shape[-1]
            positional_encoding = get_positional_encoding(
                n_positions=seq_len,
                n_dimensions=hidden_dim,
                device=self._device,
            )
            # Reshape to (seqlen, 1, hiddendim) so it broadcasts over batch.
            positional_encoding = torch.reshape(
                positional_encoding, (seq_len, 1, hidden_dim)
            )
            hidden = hidden + positional_encoding

        list_att_weights = []
        for layer in self.layers:
            # Each layer returns (new_hidden, attention_weights).
            hidden, att_weights = layer(hidden)
            list_att_weights.append(att_weights)

        result = self.output_layer(hidden)

        # Previously missing the f-prefix, which printed a literal '{result}'.
        if verbose:
            print(f'Result: {result}')

        if return_att_weights:
            return result, list_att_weights
        return result