diff options
author | Marcin Chrzanowski <m@m-chrzan.xyz> | 2021-05-27 19:53:20 +0200 |
---|---|---|
committer | Marcin Chrzanowski <m@m-chrzan.xyz> | 2021-05-27 19:53:20 +0200 |
commit | 8b5f31f9aeeefe647e00a2c581ca69378210f12b (patch) | |
tree | 4d880999fee217a798ec887e44dcb85ee66fbea6 /train/__main__.py | |
parent | fdf4bff14c1d694dbfae4e82c05226a7c9135ca5 (diff) |
Better parametrize train script
Diffstat (limited to 'train/__main__.py')
-rw-r--r-- | train/__main__.py | 24 |
1 files changed, 22 insertions, 2 deletions
diff --git a/train/__main__.py b/train/__main__.py index ec544f7..00df2fe 100644 --- a/train/__main__.py +++ b/train/__main__.py @@ -1,5 +1,25 @@ from train.train import train_model from model.encoder import EncoderModel -model = EncoderModel(16, 64, 128, 10, 4, 4) -train_model(model, 0.1, 100, 10) +n_tokens = 4 +max_count = 3 + +model = EncoderModel( + input_dim=n_tokens, + hidden_dim=128, + d_ff=256, + output_dim=max_count + 1, + n_layers=6, + num_heads=8, + # use_positional=False, + # use_feedforward=False +) +train_model( + model, + lr=0.00001, + num_steps=1000, + batch_size=10, + n_tokens=n_tokens, + seqlen=16, + max_count=max_count +) |