m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/model/encoder_layer.py
blob: 71a7d8f5ebf573cdea73d530b55d8dd90ff7a59a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from torch import nn

from model.attention import Attention
from model.forward import FeedForward

class EncoderLayer(nn.Module):
    """A single encoder block made of two optional residual sublayers.

    When enabled, the attention sublayer runs first and its output is added
    back onto its input. The feed-forward sublayer then does the same on the
    result.

    Args:
        hidden_dim: width passed to both ``Attention`` and ``FeedForward``.
        d_ff: inner width passed to ``FeedForward``.
        num_heads: head count passed to ``Attention``.
        use_attention: whether to build and apply the attention sublayer.
        use_feedforward: whether to build and apply the feed-forward sublayer.
    """

    def __init__(self, hidden_dim, d_ff, num_heads, use_attention=True,
                 use_feedforward=True):
        super().__init__()
        self._use_attention = use_attention
        self._use_feedforward = use_feedforward

        # Submodules are only assigned (and therefore only registered with
        # nn.Module) when enabled. A disabled sublayer has no attribute and
        # contributes no parameters.
        if self._use_attention:
            self.attention = Attention(hidden_dim, num_heads)
        if self._use_feedforward:
            self.feedforward = FeedForward(hidden_dim, d_ff)

    def forward(self, x):
        """Apply the enabled sublayers to ``x`` with residual connections.

        NOTE(review): ``x`` is presumably (batch, seq, hidden_dim); confirm
        against ``Attention``/``FeedForward``.

        Returns:
            ``(output, attention_weights)``. ``attention_weights`` is the
            second value returned by ``self.attention``, or ``None`` when the
            attention sublayer is disabled.
        """
        attention_weights = None

        # Out-of-place additions (not ``+=``) so the caller's tensor is never
        # modified and autograd does not see an in-place write.
        if self._use_attention:
            attended, attention_weights = self.attention(x)
            x = x + attended

        if self._use_feedforward:
            x = x + self.feedforward(x)

        return x, attention_weights