# Less is More: Recursive Reasoning with Tiny Networks
This is the codebase for the paper "Less is More: Recursive Reasoning with Tiny Networks", where we present a recursive reasoning approach that reaches 45% test accuracy on ARC-AGI-1 and 8% on ARC-AGI-2 with a tiny 7M-parameter neural network.
[Paper](https://arxiv.org/abs/2510.04871)
### How TRM works
Tiny Recursion Model (TRM) recursively improves its predicted answer y with a tiny network. It starts from the embedded input question x, an initial embedded answer y, and a latent z. For up to K improvement steps, it tries to improve its answer y by (i) recursively updating its latent z n times given the question x, the current answer y, and the current latent z (recursive reasoning), and then (ii) updating its answer y given the current answer y and the current latent z. This recursive process lets the model progressively improve its answer (potentially correcting errors from its previous answer) in an extremely parameter-efficient manner while minimizing overfitting.
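The loop structure is easy to state in code. Below is a minimal PyTorch sketch of the recursion only, under illustrative assumptions: `TinyNet`, `trm_forward`, the dimensions, and the zeroed question slot in the answer update are our choices, not the repo's actual architecture or API (training details such as where gradients flow are described in the paper).

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    """Illustrative stand-in for the tiny network; not the repo's architecture."""
    def __init__(self, d):
        super().__init__()
        self.f = nn.Sequential(nn.Linear(3 * d, d), nn.ReLU(), nn.Linear(d, d))

    def forward(self, x, y, z):
        # The single network sees (question, answer, latent) concatenated.
        return self.f(torch.cat([x, y, z], dim=-1))

def trm_forward(net, x, y, z, K=3, n=6):
    """Recursion sketch: K improvement steps, each with n latent updates."""
    for _ in range(K):
        for _ in range(n):                  # (i) recursive reasoning on the latent
            z = net(x, y, z)                # update z from question, answer, latent
        y = net(torch.zeros_like(x), y, z)  # (ii) update y from y and z (zeroing the
                                            # question slot is a sketch choice so one
                                            # net can serve both updates)
    return y

# Toy usage with random embeddings.
d = 64
net = TinyNet(d)
x, y, z = torch.randn(1, d), torch.zeros(1, d), torch.zeros(1, d)
y = trm_forward(net, x, y, z)
```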
### Requirements
- Python 3.10 (or similar)
- CUDA 12.6.0 (or similar)
```bash
pip install --upgrade pip wheel setuptools
pip install --pre --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126 # install torch based on your CUDA version
pip install -r requirements.txt # install requirements
pip install --no-cache-dir --no-build-isolation adam-atan2
wandb login YOUR-LOGIN # login if you want the logger to sync results to your Weights & Biases (https://wandb.ai/)
```
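Before launching a long run, it can be worth checking that the nightly wheel actually matches your CUDA setup; a quick, optional sanity check:

```python
import torch

# A cu126 nightly wheel typically reports a version suffix like "+cu126".
print(torch.__version__)
# Should print True if PyTorch can see your GPU.
print(torch.cuda.is_available())
```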
### Dataset Preparation
```bash
# ARC-AGI-1
python -m dataset.build_arc_dataset \
  --input-file-prefix kaggle/combined/arc-agi \
  --output-dir data/arc1concept-aug-1000 \
  --subsets training evaluation concept \
  --test-set-name evaluation

# ARC-AGI-2
python -m dataset.build_arc_dataset \
  --input-file-prefix kaggle/combined/arc-agi \
  --output-dir data/arc2concept-aug-1000 \
  --subsets training2 evaluation2 concept \
  --test-set-name evaluation2

# Note: you cannot train on both ARC-AGI-1 and ARC-AGI-2 and evaluate on both,
# because the ARC-AGI-2 training data contains some ARC-AGI-1 evaluation data.

# Sudoku-Extreme
python dataset/build_sudoku_dataset.py --output-dir data/sudoku-extreme-1k-aug-1000 --subsample-size 1000 --num-aug 1000 # 1000 examples, 1000 augments

# Maze-Hard
python dataset/build_maze_dataset.py # 1000 examples, 8 augments
```
## Experiments
### ARC-AGI (assuming 4 H-100 GPUs):
```bash
run_name="pretrain_att_arc12concept_4"
torchrun --nproc-per-node 4 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 --nnodes=1 pretrain.py \
  arch=trm \
  data_paths="[data/arc12concept-aug-1000]" \
  arch.L_layers=2 \
  arch.H_cycles=3 arch.L_cycles=4 \
  +run_name=${run_name} ema=True
```
*Runtime:* ~3 days
### Sudoku-Extreme (assuming 1 L40S GPU):
```bash
run_name="pretrain_mlp_t_sudoku"
python pretrain.py \
  arch=trm \
  data_paths="[data/sudoku-extreme-1k-aug-1000]" \
  evaluators="[]" \
  epochs=50000 eval_interval=5000 \
  lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0 \
  arch.mlp_t=True arch.pos_encodings=none \
  arch.L_layers=2 \
  arch.H_cycles=3 arch.L_cycles=6 \
  +run_name=${run_name} ema=True

run_name="pretrain_att_sudoku"
python pretrain.py \
  arch=trm \
  data_paths="[data/sudoku-extreme-1k-aug-1000]" \
  evaluators="[]" \
  epochs=50000 eval_interval=5000 \
  lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0 \
  arch.L_layers=2 \
  arch.H_cycles=3 arch.L_cycles=6 \
  +run_name=${run_name} ema=True
```
*Runtime:* < 36 hours
### Maze-Hard (assuming 4 L40S GPUs):
```bash
run_name="pretrain_att_maze30x30"
torchrun --nproc-per-node 4 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 --nnodes=1 pretrain.py \
  arch=trm \
  data_paths="[data/maze-30x30-hard-1k]" \
  evaluators="[]" \
  epochs=50000 eval_interval=5000 \
  lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0 \
  arch.L_layers=2 \
  arch.H_cycles=3 arch.L_cycles=4 \
  +run_name=${run_name} ema=True
```
*Runtime:* < 24 hours
## Reference
If you find our work useful, please consider citing:
```bibtex
@misc{jolicoeurmartineau2025morerecursivereasoningtiny,
      title={Less is More: Recursive Reasoning with Tiny Networks},
      author={Alexia Jolicoeur-Martineau},
      year={2025},
      eprint={2510.04871},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2510.04871},
}
```
and the Hierarchical Reasoning Model (HRM):
```bibtex
@misc{wang2025hierarchicalreasoningmodel,
      title={Hierarchical Reasoning Model},
      author={Guan Wang and Jin Li and Yuhao Sun and Xing Chen and Changling Liu and Yue Wu and Meng Lu and Sen Song and Yasin Abbasi Yadkori},
      year={2025},
      eprint={2506.21734},
      archivePrefix={arXiv},
      primaryClass={cs.AI},
      url={https://arxiv.org/abs/2506.21734},
}
```
This code is based on the Hierarchical Reasoning Model [code](https://github.com/sapientinc/HRM) and the Hierarchical Reasoning Model Analysis [code](https://github.com/arcprize/hierarchical-reasoning-model-analysis).