m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/train/__main__.py
blob: 00df2fec46304cfe48b79772faa3d95ae47cfee4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from train.train import train_model
from model.encoder import EncoderModel

# Entry script: build an EncoderModel and pass it to train_model.
#
# n_tokens is used as the model's input_dim and is also forwarded to
# train_model; max_count determines output_dim (max_count + 1) and is
# forwarded to train_model as well.
n_tokens = 4
max_count = 3

# Keyword arguments for the encoder constructor.
# NOTE(review): the script also carried two disabled keyword toggles,
# use_positional=False and use_feedforward=False; they are not passed here.
# Presumably EncoderModel accepts them -- confirm against model.encoder.
encoder_kwargs = {
    "input_dim": n_tokens,
    "hidden_dim": 128,
    "d_ff": 256,
    "output_dim": max_count + 1,
    "n_layers": 6,
    "num_heads": 8,
}
model = EncoderModel(**encoder_kwargs)

# Keyword arguments for the training run; the model itself is passed
# positionally as the first argument.
training_kwargs = {
    "lr": 1e-5,
    "num_steps": 1000,
    "batch_size": 10,
    "n_tokens": n_tokens,
    "seqlen": 16,
    "max_count": max_count,
}
train_model(model, **training_kwargs)