blob: 311c39cab95fec803b87d2497b2212eb2bb19d66 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
|
from torch import nn
from model.attention import Attention
from model.forward import FeedForward
class EncoderLayer(nn.Module):
    """One transformer encoder layer with optional sublayers.

    Applies (when enabled) a self-attention sublayer and a position-wise
    feed-forward sublayer, each wrapped in a residual connection
    ``x = x + sublayer(x)``.

    NOTE(review): no layer normalization is applied here — presumably
    handled inside the sublayers or elsewhere in the model; confirm.
    """

    def __init__(self, hidden_dim, d_ff, num_heads, use_attention=True,
                 use_feedforward=True, device='cpu'):
        """
        Args:
            hidden_dim: model/embedding dimension of the input.
            d_ff: inner dimension of the feed-forward sublayer.
            num_heads: number of attention heads.
            use_attention: if False, the attention sublayer is skipped and
                never constructed (contributes no parameters).
            use_feedforward: if False, the feed-forward sublayer is skipped
                and never constructed.
            device: forwarded to the Attention module.
        """
        # Zero-argument super() is the Python 3 idiom; the explicit
        # (EncoderLayer, self) form is a Python 2 holdover.
        super().__init__()
        self._use_attention = use_attention
        self._use_feedforward = use_feedforward
        # Only build the submodules that are enabled, so a disabled
        # sublayer adds no parameters to the module.
        if use_attention:
            self.attention = Attention(hidden_dim, num_heads, device)
        if use_feedforward:
            self.feedforward = FeedForward(hidden_dim, d_ff)

    def forward(self, x):
        """Run the enabled sublayers with residual connections.

        Args:
            x: input tensor (shape determined by the sublayers —
               presumably (batch, seq, hidden_dim); confirm against Attention).

        Returns:
            Tuple of (output tensor, attention weights). The weights are
            ``None`` when the attention sublayer is disabled.
        """
        weights = None
        if self._use_attention:
            y, weights = self.attention(x)
            x = x + y  # residual connection around attention
        if self._use_feedforward:
            x = x + self.feedforward(x)  # residual connection around FFN
        return x, weights
|