From 8b5f31f9aeeefe647e00a2c581ca69378210f12b Mon Sep 17 00:00:00 2001 From: Marcin Chrzanowski Date: Thu, 27 May 2021 19:53:20 +0200 Subject: Better parametrize train script --- train/__main__.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) (limited to 'train/__main__.py') diff --git a/train/__main__.py b/train/__main__.py index ec544f7..00df2fe 100644 --- a/train/__main__.py +++ b/train/__main__.py @@ -1,5 +1,25 @@ from train.train import train_model from model.encoder import EncoderModel -model = EncoderModel(16, 64, 128, 10, 4, 4) -train_model(model, 0.1, 100, 10) +n_tokens = 4 +max_count = 3 + +model = EncoderModel( + input_dim=n_tokens, + hidden_dim=128, + d_ff=256, + output_dim=max_count + 1, + n_layers=6, + num_heads=8, + # use_positional=False, + # use_feedforward=False +) +train_model( + model, + lr=0.00001, + num_steps=1000, + batch_size=10, + n_tokens=n_tokens, + seqlen=16, + max_count=max_count +) -- cgit v1.2.3