m-chrzan.xyz
aboutsummaryrefslogtreecommitdiff
path: root/train/__main__.py
diff options
context:
space:
mode:
authorMarcin Chrzanowski <m@m-chrzan.xyz>2021-05-27 19:53:20 +0200
committerMarcin Chrzanowski <m@m-chrzan.xyz>2021-05-27 19:53:20 +0200
commit8b5f31f9aeeefe647e00a2c581ca69378210f12b (patch)
tree4d880999fee217a798ec887e44dcb85ee66fbea6 /train/__main__.py
parentfdf4bff14c1d694dbfae4e82c05226a7c9135ca5 (diff)
Better parametrize train script
Diffstat (limited to 'train/__main__.py')
-rw-r--r--train/__main__.py24
1 files changed, 22 insertions, 2 deletions
diff --git a/train/__main__.py b/train/__main__.py
index ec544f7..00df2fe 100644
--- a/train/__main__.py
+++ b/train/__main__.py
@@ -1,5 +1,25 @@
from train.train import train_model
from model.encoder import EncoderModel
-model = EncoderModel(16, 64, 128, 10, 4, 4)
-train_model(model, 0.1, 100, 10)
+n_tokens = 4
+max_count = 3
+
+model = EncoderModel(
+ input_dim=n_tokens,
+ hidden_dim=128,
+ d_ff=256,
+ output_dim=max_count + 1,
+ n_layers=6,
+ num_heads=8,
+ # use_positional=False,
+ # use_feedforward=False
+)
+train_model(
+ model,
+ lr=0.00001,
+ num_steps=1000,
+ batch_size=10,
+ n_tokens=n_tokens,
+ seqlen=16,
+ max_count=max_count
+)