# Less is More: Recursive Reasoning with Tiny Networks
This is the codebase for the paper "Less is More: Recursive Reasoning with Tiny Networks", where we present a recursive reasoning approach that achieves 45% on ARC-AGI-1 and 8% on ARC-AGI-2 with a tiny 7M-parameter neural network.

[Paper](https://arxiv.org/abs/2510.04871)
### How TRM works
Tiny Recursion Model (TRM) recursively improves its predicted answer y with a tiny network. It starts from the embedded input question x, an initial embedded answer y, and a latent z. For up to K improvement steps, it tries to improve its answer y by i) recursively updating its latent z n times given the question x, the current answer y, and the current latent z (recursive reasoning), and then ii) updating its answer y given the current answer y and the current latent z. This recursive process lets the model progressively improve its answer (potentially correcting errors from its previous answer) in an extremely parameter-efficient manner while minimizing overfitting.
<p align="center">
<img src="assets/images/TRM_fig.png" alt="TRM-Figure" style="width:50%">
</p>
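For illustration, here is a minimal PyTorch sketch of this loop (it is not the model code in this repository): the two MLP update modules, the concatenated inputs, and the tensor shapes are placeholder assumptions chosen only to show the control flow of the K outer improvement steps and the n inner latent updates.

```python
import torch
import torch.nn as nn


class TRMSketch(nn.Module):
    """Illustrative sketch of TRM-style recursive answer refinement (not the actual architecture)."""

    def __init__(self, dim: int, n: int = 6, K: int = 3):
        super().__init__()
        self.n, self.K = n, K
        # Placeholder update networks; the real model uses a tiny network defined in this repo.
        self.update_z = nn.Sequential(nn.Linear(3 * dim, dim), nn.GELU(), nn.Linear(dim, dim))
        self.update_y = nn.Sequential(nn.Linear(2 * dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
        # x: embedded question, y: current embedded answer, z: latent (all of shape [batch, dim]).
        for _ in range(self.K):      # up to K answer-improvement steps
            for _ in range(self.n):  # i) recursive reasoning: refine the latent z n times
                z = self.update_z(torch.cat([x, y, z], dim=-1))
            y = self.update_y(torch.cat([y, z], dim=-1))  # ii) update the answer y from (y, z)
        return y, z


# Toy usage with random embeddings.
model = TRMSketch(dim=128)
x, y, z = torch.randn(3, 4, 128)  # question, initial answer, initial latent for a batch of 4
y_refined, z_final = model(x, y, z)
print(y_refined.shape)  # torch.Size([4, 128])
```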
### Requirements
- Python 3.10 (or similar)
- CUDA 12.6.0 (or similar)
```bash
pip install --upgrade pip wheel setuptools
pip install --pre --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126 # install torch based on your CUDA version
pip install -r requirements.txt # install requirements
pip install --no-cache-dir --no-build-isolation adam-atan2
wandb login YOUR-LOGIN # log in if you want the logger to sync results to Weights & Biases (https://wandb.ai/)
```
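Optionally, a quick sanity check (not part of the repository, just standard PyTorch calls) confirms that the installed build can see your CUDA devices before moving on to dataset preparation:

```python
import torch

print(torch.__version__)          # nightly cu126 wheels typically report a "+cu126" suffix
print(torch.cuda.is_available())  # True if the CUDA driver and runtime are set up correctly
print(torch.cuda.device_count())  # number of visible GPUs (e.g., 4 for the multi-GPU runs below)
```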
### Dataset Preparation
```bash
# ARC-AGI-1
python -m dataset.build_arc_dataset \
  --input-file-prefix kaggle/combined/arc-agi \
  --output-dir data/arc1concept-aug-1000 \
  --subsets training evaluation concept \
  --test-set-name evaluation

# ARC-AGI-2
python -m dataset.build_arc_dataset \
  --input-file-prefix kaggle/combined/arc-agi \
  --output-dir data/arc2concept-aug-1000 \
  --subsets training2 evaluation2 concept \
  --test-set-name evaluation2

# Note: you cannot train on both ARC-AGI-1 and ARC-AGI-2 and then evaluate on both,
# because the ARC-AGI-2 training data contains some of the ARC-AGI-1 evaluation data.

# Sudoku-Extreme
python dataset/build_sudoku_dataset.py --output-dir data/sudoku-extreme-1k-aug-1000 --subsample-size 1000 --num-aug 1000 # 1000 examples, 1000 augments

# Maze-Hard
python dataset/build_maze_dataset.py # 1000 examples, 8 augments
```
## Experiments
### ARC-AGI (assuming 4 H100 GPUs):
```bash
run_name="pretrain_att_arc12concept_4"
torchrun --nproc-per-node 4 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 --nnodes=1 pretrain.py \
  arch=trm \
  data_paths="[data/arc12concept-aug-1000]" \
  arch.L_layers=2 \
  arch.H_cycles=3 arch.L_cycles=4 \
  +run_name=${run_name} ema=True
```
*Runtime:* ~3 days
### Sudoku-Extreme (assuming 1 L40S GPU):
```bash
run_name="pretrain_mlp_t_sudoku"
python pretrain.py \
  arch=trm \
  data_paths="[data/sudoku-extreme-1k-aug-1000]" \
  evaluators="[]" \
  epochs=50000 eval_interval=5000 \
  lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0 \
  arch.mlp_t=True arch.pos_encodings=none \
  arch.L_layers=2 \
  arch.H_cycles=3 arch.L_cycles=6 \
  +run_name=${run_name} ema=True

run_name="pretrain_att_sudoku"
python pretrain.py \
  arch=trm \
  data_paths="[data/sudoku-extreme-1k-aug-1000]" \
  evaluators="[]" \
  epochs=50000 eval_interval=5000 \
  lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0 \
  arch.L_layers=2 \
  arch.H_cycles=3 arch.L_cycles=6 \
  +run_name=${run_name} ema=True
```
*Runtime:* < 36 hours
### Maze-Hard (assuming 4 L40S GPUs):
```bash
run_name="pretrain_att_maze30x30"
torchrun --nproc-per-node 4 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 --nnodes=1 pretrain.py \
  arch=trm \
  data_paths="[data/maze-30x30-hard-1k]" \
  evaluators="[]" \
  epochs=50000 eval_interval=5000 \
  lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0 \
  arch.L_layers=2 \
  arch.H_cycles=3 arch.L_cycles=4 \
  +run_name=${run_name} ema=True
```
*Runtime:* < 24 hours
## Reference
If you find our work useful, please consider citing:
```bibtex
@misc{jolicoeurmartineau2025tinyrecursionmodel,
  title={Less is More: Recursive Reasoning with Tiny Networks},
  author={Alexia Jolicoeur-Martineau},
  year={2025},
  eprint={2510.04871},
  archivePrefix={arXiv},
  primaryClass={cs.AI},
  url={https://arxiv.org/abs/2510.04871},
}
```
and the Hierarchical Reasoning Model (HRM):
```bibtex
@misc{wang2025hierarchicalreasoningmodel,
  title={Hierarchical Reasoning Model},
  author={Guan Wang and Jin Li and Yuhao Sun and Xing Chen and Changling Liu and Yue Wu and Meng Lu and Sen Song and Yasin Abbasi Yadkori},
  year={2025},
  eprint={2506.21734},
  archivePrefix={arXiv},
  primaryClass={cs.AI},
  url={https://arxiv.org/abs/2506.21734},
}
```
This code is based on the Hierarchical Reasoning Model [code](https://github.com/sapientinc/HRM) and the Hierarchical Reasoning Model Analysis [code](https://github.com/arcprize/hierarchical-reasoning-model-analysis).