blob: 56a3a0c82ca3b81fdb894195eec46a1c2943d54b (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
|
from torch import nn
class EncoderLayer(nn.Module):
    """A single Transformer encoder layer (post-LayerNorm variant).

    Each enabled sub-layer is wrapped as ``LayerNorm(x + sublayer(x))``:
    first multi-head self-attention, then a position-wise feed-forward
    network ``Linear(hidden_dim, d_ff) -> ReLU -> Linear(d_ff, hidden_dim)``.
    A disabled sub-layer is skipped entirely, so with both flags off the
    layer is the identity and returns no attention weights.

    Args:
        hidden_dim: Model width; must be divisible by ``num_heads``
            (``nn.MultiheadAttention`` enforces this).
        d_ff: Inner width of the feed-forward network.
        num_heads: Number of attention heads.
        use_attention: If False, the self-attention sub-layer is omitted.
        use_feedforward: If False, the feed-forward sub-layer is omitted.
    """

    def __init__(self, hidden_dim, d_ff, num_heads, use_attention=True,
                 use_feedforward=True):
        super().__init__()
        self.use_attention = use_attention
        self.use_feedforward = use_feedforward

        if use_attention:
            # Default batch_first=False matches the (seqlen, batch, hidden_dim)
            # layout that forward() receives.
            self.attention = nn.MultiheadAttention(hidden_dim, num_heads)
            self.attention_norm = nn.LayerNorm(hidden_dim)

        if use_feedforward:
            self.feedforward = nn.Sequential(
                nn.Linear(hidden_dim, d_ff),
                nn.ReLU(),
                nn.Linear(d_ff, hidden_dim),
            )
            self.feedforward_norm = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        """Apply the encoder layer.

        Args:
            x: Tensor of shape ``(seqlen, batch, hidden_dim)``.

        Returns:
            ``(result, att_weights)``: ``result`` has the same shape as ``x``;
            ``att_weights`` is the head-averaged self-attention weight tensor
            of shape ``(batch, seqlen, seqlen)``, or ``None`` when attention
            is disabled.
        """
        result, att_weights = x, None
        if self.use_attention:
            # Self-attention: the sequence attends to itself (q = k = v).
            attended, att_weights = self.attention(result, result, result,
                                                   need_weights=True)
            result = self.attention_norm(result + attended)
        if self.use_feedforward:
            result = self.feedforward_norm(result + self.feedforward(result))
        return result, att_weights
|