diff options
Diffstat (limited to 'train')
-rw-r--r-- | train/__main__.py | 24 |
1 files changed, 22 insertions, 2 deletions
diff --git a/train/__main__.py b/train/__main__.py index ec544f7..00df2fe 100644 --- a/train/__main__.py +++ b/train/__main__.py @@ -1,5 +1,25 @@ from train.train import train_model from model.encoder import EncoderModel -model = EncoderModel(16, 64, 128, 10, 4, 4) -train_model(model, 0.1, 100, 10) +n_tokens = 4 +max_count = 3 + +model = EncoderModel( + input_dim=n_tokens, + hidden_dim=128, + d_ff=256, + output_dim=max_count + 1, + n_layers=6, + num_heads=8, + # use_positional=False, + # use_feedforward=False +) +train_model( + model, + lr=0.00001, + num_steps=1000, + batch_size=10, + n_tokens=n_tokens, + seqlen=16, + max_count=max_count +) |