blob: 00df2fec46304cfe48b79772faa3d95ae47cfee4 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
|
from train.train import train_model
from model.encoder import EncoderModel
# Task configuration shared by the encoder and the training loop.
# n_tokens is used both as the encoder's input_dim and as train_model's
# n_tokens; max_count sizes the encoder's output (max_count + 1 classes).
# NOTE(review): presumably the output classes are counts 0..max_count of the
# synthetic sequences generated inside train_model — confirm there.
n_tokens = 4
max_count = 3

# Encoder hyperparameters, passed to EncoderModel as keyword arguments.
model_config = dict(
    input_dim=n_tokens,
    hidden_dim=128,
    d_ff=256,
    output_dim=max_count + 1,
    n_layers=6,
    num_heads=8,
)
# Ablation switches that were previously toggled here and are currently left
# at EncoderModel's defaults (not passed): use_positional=False,
# use_feedforward=False.
model = EncoderModel(**model_config)

# Optimisation / data settings, passed to train_model as keyword arguments.
train_config = dict(
    lr=1e-5,
    num_steps=1000,
    batch_size=10,
    n_tokens=n_tokens,
    seqlen=16,
    max_count=max_count,
)
train_model(model, **train_config)
|