diff --git a/README.md b/README.md
index c0c0fd8..decab5f 100644
--- a/README.md
+++ b/README.md
@@ -20,6 +20,8 @@ Tiny Recursion Model (TRM) recursively improves its predicted answer y with a ti
 
 ### Requirements
 
+Installation should take a few minutes. For the smallest experiment on Sudoku-Extreme (pretrain_mlp_t_sudoku), you need a single GPU with enough memory; on 1 L40S (48GB), training takes around 18 hours to finish. If you run into issues due to library versions, the exact versions used are listed in [specific_requirements.txt](https://github.com/SamsungSAILMontreal/TinyRecursiveModels/blob/main/specific_requirements.txt) and can be installed with `pip install -r specific_requirements.txt`.
+
 - Python 3.10 (or similar)
 - Cuda 12.6.0 (or similar)
 
@@ -59,6 +61,76 @@ python dataset/build_maze_dataset.py # 1000 examples, 8 augments
 
 ## Experiments
 
+### Sudoku-Extreme (assuming 1 L40S GPU):
+
+```bash
+run_name="pretrain_mlp_t_sudoku"
+python pretrain.py \
+arch=trm \
+data_paths="[data/sudoku-extreme-1k-aug-1000]" \
+evaluators="[]" \
+epochs=50000 eval_interval=5000 \
+lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0 \
+arch.mlp_t=True arch.pos_encodings=none \
+arch.L_layers=2 \
+arch.H_cycles=3 arch.L_cycles=6 \
++run_name=${run_name} ema=True
+```
+
+*Expected:* around 87% exact accuracy (±2%)
+
+```bash
+run_name="pretrain_att_sudoku"
+python pretrain.py \
+arch=trm \
+data_paths="[data/sudoku-extreme-1k-aug-1000]" \
+evaluators="[]" \
+epochs=50000 eval_interval=5000 \
+lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0 \
+arch.L_layers=2 \
+arch.H_cycles=3 arch.L_cycles=6 \
++run_name=${run_name} ema=True
+```
+
+*Expected:* around 75% exact accuracy (±2%)
+
+*Runtime:* < 20 hours
+
+### Maze-Hard (assuming 4 L40S GPUs):
+
+```bash
+run_name="pretrain_att_maze30x30"
+torchrun --nproc-per-node 4 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 --nnodes=1 pretrain.py \
+arch=trm \
+data_paths="[data/maze-30x30-hard-1k]" \
+evaluators="[]" \
+epochs=50000 eval_interval=5000 \
+lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0 \
+arch.L_layers=2 \
+arch.H_cycles=3 arch.L_cycles=4 \
++run_name=${run_name} ema=True
+```
+
+*Runtime:* < 24 hours
+
+You can also run Maze-Hard on a single L40S GPU by reducing the batch size, with no noticeable loss in performance:
+
+```bash
+run_name="pretrain_att_maze30x30_1gpu"
+python pretrain.py \
+arch=trm \
+data_paths="[data/maze-30x30-hard-1k]" \
+evaluators="[]" \
+epochs=50000 eval_interval=5000 \
+lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0 global_batch_size=128 \
+arch.L_layers=2 \
+arch.H_cycles=3 arch.L_cycles=4 \
++run_name=${run_name} ema=True
+```
+
+*Runtime:* < 24 hours
+
+
 ### ARC-AGI-1 (assuming 4 H-100 GPUs):
 
 ```bash
@@ -89,51 +161,19 @@ arch.H_cycles=3 arch.L_cycles=4 \
 
 *Runtime:* ~3 days
 
-### Sudoku-Extreme (assuming 1 L40S GPU):
-
-```bash
-run_name="pretrain_mlp_t_sudoku"
-python pretrain.py \
-arch=trm \
-data_paths="[data/sudoku-extreme-1k-aug-1000]" \
-evaluators="[]" \
-epochs=50000 eval_interval=5000 \
-lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0 \
-arch.mlp_t=True arch.pos_encodings=none \
-arch.L_layers=2 \
-arch.H_cycles=3 arch.L_cycles=6 \
-+run_name=${run_name} ema=True
-
-run_name="pretrain_att_sudoku"
-python pretrain.py \
-arch=trm \
-data_paths="[data/sudoku-extreme-1k-aug-1000]" \
-evaluators="[]" \
-epochs=50000 eval_interval=5000 \
-lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0 \
-arch.L_layers=2 \
-arch.H_cycles=3 arch.L_cycles=6 \
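+*Note on flags:* the training commands above use Hydra-style override grammar: `key=value` overrides an existing config entry, dotted keys such as `arch.L_layers=2` set fields inside the nested `arch` config, and a leading `+` (as in `+run_name=...`) adds a key that is absent from the base config. As a minimal sketch (hypothetical run name, illustratively short schedule, all other flags left at their defaults), a quick smoke test could look like:
+
+```bash
+# Hypothetical smoke test: same override grammar as the experiments above,
+# but with an illustratively short schedule; not a reported configuration.
+run_name="smoke_test_sudoku"
+python pretrain.py \
+arch=trm \
+data_paths="[data/sudoku-extreme-1k-aug-1000]" \
+evaluators="[]" \
+epochs=100 eval_interval=50 \
++run_name=${run_name} ema=True
+```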
-+run_name=${run_name} ema=True
-```
-
-*Runtime:* < 36 hours
-
-### Maze-Hard (assuming 4 L40S GPUs):
-
-```bash
-run_name="pretrain_att_maze30x30"
-torchrun --nproc-per-node 4 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 --nnodes=1 pretrain.py \
-arch=trm \
-data_paths="[data/maze-30x30-hard-1k]" \
-evaluators="[]" \
-epochs=50000 eval_interval=5000 \
-lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0 \
-arch.L_layers=2 \
-arch.H_cycles=3 arch.L_cycles=4 \
-+run_name=${run_name} ema=True
-```
-
-*Runtime:* < 24 hours
 
 
 ## Reference