upload
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025. Samsung Electronics Co., Ltd. All Rights Reserved.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
@@ -0,0 +1,147 @@
# Less is More: Recursive Reasoning with Tiny Networks

This is the codebase for the paper "Less is More: Recursive Reasoning with Tiny Networks", where we present a recursive reasoning approach that reaches 45% on ARC-AGI-1 and 8% on ARC-AGI-2 with a tiny 7M-parameter neural network.

[Paper](https://arxiv.org/abs/2510.04871)

### How TRM works

The Tiny Recursion Model (TRM) recursively improves its predicted answer y with a tiny network. It starts from the embedded input question x, an initial embedded answer y, and a latent z. For up to K improvement steps, it tries to improve its answer y by i) recursively updating its latent z n times given the question x, the current answer y, and the current latent z (recursive reasoning), and then ii) updating its answer y given the current answer y and the current latent z. This recursive process lets the model progressively refine its answer (potentially correcting errors in its previous answer) in an extremely parameter-efficient manner while minimizing overfitting. A minimal sketch of the loop is given after the figure.

<p align="center">
<img src="{{ site.baseurl }}/assets/images/TRM_fig.png" alt="TRM-Figure" style="width:50%">
</p>
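
The following is a minimal sketch of that loop, not the repo's training code: `net`, `K`, and `n` are illustrative names, both updates share one tiny network, inputs are combined by simple addition, and deep supervision, ACT halting, and EMA are omitted.

```python
import torch
from torch import nn

def trm_improve(net: nn.Module, x, y, z, K: int = 3, n: int = 6):
    """Refine answer y for K improvement steps: update latent z n times
    from (x, y, z), then update y from (y, z), all with the same tiny net."""
    for _ in range(K):
        for _ in range(n):
            z = net(x + y + z)  # recursive reasoning on the latent
        y = net(y + z)          # answer update from the refined latent
    return y

# Toy usage: a 2-layer MLP stands in for the tiny network.
net = nn.Sequential(nn.Linear(64, 64), nn.SiLU(), nn.Linear(64, 64))
x, y, z = torch.randn(3, 1, 16, 64).unbind(0)
y = trm_improve(net, x, y, z)
```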

### Requirements

- Python 3.10 (or similar)
- CUDA 12.6.0 (or similar)

```bash
pip install --upgrade pip wheel setuptools
pip install --pre --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126  # install torch based on your CUDA version
pip install -r requirements.txt  # install requirements
pip install --no-cache-dir --no-build-isolation adam-atan2
wandb login YOUR-LOGIN  # log in if you want the logger to sync results to Weights & Biases (https://wandb.ai/)
```

### Dataset Preparation

```bash
# ARC-AGI-1
python -m dataset.build_arc_dataset \
    --input-file-prefix kaggle/combined/arc-agi \
    --output-dir data/arc1concept-aug-1000 \
    --subsets training evaluation concept \
    --test-set-name evaluation

# ARC-AGI-2
python -m dataset.build_arc_dataset \
    --input-file-prefix kaggle/combined/arc-agi \
    --output-dir data/arc2concept-aug-1000 \
    --subsets training2 evaluation2 concept \
    --test-set-name evaluation2

# Note: you cannot train on both ARC-AGI-1 and ARC-AGI-2 and evaluate on both,
# because the ARC-AGI-2 training data contains some ARC-AGI-1 eval data.

# Sudoku-Extreme
python dataset/build_sudoku_dataset.py --output-dir data/sudoku-extreme-1k-aug-1000 --subsample-size 1000 --num-aug 1000  # 1000 examples, 1000 augments

# Maze-Hard
python dataset/build_maze_dataset.py  # 1000 examples, 8 augments
```
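
Each builder writes, for every split, a `dataset.json` metadata file plus NumPy arrays (`inputs`, `labels`, `puzzle_identifiers`, `puzzle_indices`, `group_indices`), and a top-level `identifiers.json` (the ARC builder also writes `test_puzzles.json`). The layout below is inferred from the builder code in this commit (Sudoku shown; the others follow the same shape):

```
data/sudoku-extreme-1k-aug-1000/
├── identifiers.json
├── train/
│   ├── dataset.json
│   └── all__{inputs,labels,puzzle_identifiers,puzzle_indices,group_indices}.npy
└── test/
    ├── dataset.json
    └── all__{inputs,labels,puzzle_identifiers,puzzle_indices,group_indices}.npy
```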

## Experiments

### ARC-AGI (assuming 4 H100 GPUs):

```bash
run_name="pretrain_att_arc12concept_4"
torchrun --nproc-per-node 4 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 --nnodes=1 pretrain.py \
    arch=trm \
    data_paths="[data/arc12concept-aug-1000]" \
    arch.L_layers=2 \
    arch.H_cycles=3 arch.L_cycles=4 \
    +run_name=${run_name} ema=True
```
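
These commands use Hydra's override syntax: `arch=trm` selects the `trm` architecture config through the `defaults` list of the training config, dotted keys such as `arch.L_layers=2` override individual fields, and a `+` prefix (as in `+run_name=...`) adds a key not present in the base config. The same pattern applies to the Sudoku and Maze commands below.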

*Runtime:* ~3 days

### Sudoku-Extreme (assuming 1 L40S GPU):

```bash
run_name="pretrain_mlp_t_sudoku"
python pretrain.py \
    arch=trm \
    data_paths="[data/sudoku-extreme-1k-aug-1000]" \
    evaluators="[]" \
    epochs=50000 eval_interval=5000 \
    lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0 \
    arch.mlp_t=True arch.pos_encodings=none \
    arch.L_layers=2 \
    arch.H_cycles=3 arch.L_cycles=6 \
    +run_name=${run_name} ema=True

run_name="pretrain_att_sudoku"
python pretrain.py \
    arch=trm \
    data_paths="[data/sudoku-extreme-1k-aug-1000]" \
    evaluators="[]" \
    epochs=50000 eval_interval=5000 \
    lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0 \
    arch.L_layers=2 \
    arch.H_cycles=3 arch.L_cycles=6 \
    +run_name=${run_name} ema=True
```

*Runtime:* < 36 hours

### Maze-Hard (assuming 4 L40S GPUs):

```bash
run_name="pretrain_att_maze30x30"
torchrun --nproc-per-node 4 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 --nnodes=1 pretrain.py \
    arch=trm \
    data_paths="[data/maze-30x30-hard-1k]" \
    evaluators="[]" \
    epochs=50000 eval_interval=5000 \
    lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0 \
    arch.L_layers=2 \
    arch.H_cycles=3 arch.L_cycles=4 \
    +run_name=${run_name} ema=True
```

*Runtime:* < 24 hours

## Reference

If you find our work useful, please consider citing:

```bibtex
@misc{jolicoeurmartineau2025tinyrecursionmodel,
      title={Less is More: Recursive Reasoning with Tiny Networks},
      author={Alexia Jolicoeur-Martineau},
      year={2025},
      eprint={2510.04871},
      archivePrefix={arXiv},
      primaryClass={cs.AI},
      url={https://arxiv.org/abs/2510.04871},
}
```

and the Hierarchical Reasoning Model (HRM):

```bibtex
@misc{wang2025hierarchicalreasoningmodel,
      title={Hierarchical Reasoning Model},
      author={Guan Wang and Jin Li and Yuhao Sun and Xing Chen and Changling Liu and Yue Wu and Meng Lu and Sen Song and Yasin Abbasi Yadkori},
      year={2025},
      eprint={2506.21734},
      archivePrefix={arXiv},
      primaryClass={cs.AI},
      url={https://arxiv.org/abs/2506.21734},
}
```

This code is based on the Hierarchical Reasoning Model [code](https://github.com/sapientinc/HRM) and the Hierarchical Reasoning Model Analysis [code](https://github.com/arcprize/hierarchical-reasoning-model-analysis).
(Two binary image files added, 346 KiB and 261 KiB; diffs not shown.)
@@ -0,0 +1,24 @@
name: recursive_reasoning.hrm@HierarchicalReasoningModel_ACTV1
loss:
  name: losses@ACTLossHead
  loss_type: stablemax_cross_entropy

halt_exploration_prob: 0.1
halt_max_steps: 16

H_cycles: 2
L_cycles: 2

H_layers: 4
L_layers: 4

hidden_size: 512
num_heads: 8  # max(2, hidden_size // 64)
expansion: 4

puzzle_emb_ndim: ${.hidden_size}

pos_encodings: rope
forward_dtype: bfloat16

mlp_t: False  # use an MLP on L instead of a transformer
@@ -0,0 +1,18 @@
name: recursive_reasoning.transformers_baseline@Model_ACTV2
loss:
  name: losses@ACTLossHead
  loss_type: stablemax_cross_entropy

halt_exploration_prob: 0.1
halt_max_steps: 16

H_cycles: 1  # kept for compatibility
H_layers: 8

hidden_size: 512
num_heads: 12
expansion: 4

puzzle_emb_ndim: ${.hidden_size}

pos_encodings: rope
@@ -0,0 +1,26 @@
name: recursive_reasoning.trm@TinyRecursiveReasoningModel_ACTV1
loss:
  name: losses@ACTLossHead
  loss_type: stablemax_cross_entropy

halt_exploration_prob: 0.1
halt_max_steps: 16

H_cycles: 3
L_cycles: 6

H_layers: 0
L_layers: 2

hidden_size: 512
num_heads: 8  # max(2, hidden_size // 64)
expansion: 4

puzzle_emb_ndim: ${.hidden_size}

pos_encodings: rope
forward_dtype: bfloat16

mlp_t: False  # use an MLP on L instead of a transformer
puzzle_emb_len: 16  # if non-zero, the puzzle embedding length is fixed to this value
no_ACT_continue: True  # no "continue" ACT loss; use only the sigmoid of the halt logit, which makes much more sense
@@ -0,0 +1,26 @@
name: recursive_reasoning.trm_hier6@TinyRecursiveReasoningModel_ACTV1
loss:
  name: losses@ACTLossHead
  loss_type: stablemax_cross_entropy

halt_exploration_prob: 0.1
halt_max_steps: 16

H_cycles: 3
L_cycles: 6

H_layers: 0
L_layers: 2

hidden_size: 512
num_heads: 8  # max(2, hidden_size // 64)
expansion: 4

puzzle_emb_ndim: ${.hidden_size}

pos_encodings: rope
forward_dtype: bfloat16

mlp_t: False  # use an MLP on L instead of a transformer
puzzle_emb_len: 16  # if non-zero, the puzzle embedding length is fixed to this value
no_ACT_continue: True  # no "continue" ACT loss; use only the sigmoid of the halt logit, which makes much more sense
@@ -0,0 +1,26 @@
name: recursive_reasoning.trm_singlez@TinyRecursiveReasoningModel_ACTV1
loss:
  name: losses@ACTLossHead
  loss_type: stablemax_cross_entropy

halt_exploration_prob: 0.1
halt_max_steps: 16

H_cycles: 3
L_cycles: 6

H_layers: 0
L_layers: 2

hidden_size: 512
num_heads: 8  # max(2, hidden_size // 64)
expansion: 4

puzzle_emb_ndim: ${.hidden_size}

pos_encodings: rope
forward_dtype: bfloat16

mlp_t: False  # use an MLP on L instead of a transformer
puzzle_emb_len: 16  # if non-zero, the puzzle embedding length is fixed to this value
no_ACT_continue: True  # no "continue" ACT loss; use only the sigmoid of the halt logit, which makes much more sense
@@ -0,0 +1,42 @@
# ARC training config

defaults:
  - arch: trm
  - _self_

hydra:
  output_subdir: null

# Data path
data_paths: ['data/arc-aug-1000']
data_paths_test: []

evaluators:
  - name: arc@ARC

# Hyperparams - Training
global_batch_size: 768

epochs: 100000
eval_interval: 10000
checkpoint_every_eval: True

lr: 1e-4
lr_min_ratio: 1.0
lr_warmup_steps: 2000

# Standard hyperparameter settings for LMs, as used in Llama
beta1: 0.9
beta2: 0.95
weight_decay: 0.1
puzzle_emb_weight_decay: 0.1

# Hyperparams - Puzzle embedding training
puzzle_emb_lr: 1e-2

seed: 0
min_eval_interval: 0  # epoch at which evaluation starts

ema: False  # use an Exponential Moving Average of the weights
ema_rate: 0.999  # EMA rate
freeze_weights: False  # if True, freeze the weights and only learn the embeddings
@@ -0,0 +1,341 @@
from typing import List, Tuple, Dict
from dataclasses import dataclass
import os
import json
import hashlib
import numpy as np

from argdantic import ArgParser
from pydantic import BaseModel

from dataset.common import PuzzleDatasetMetadata, dihedral_transform, inverse_dihedral_transform


cli = ArgParser()


class DataProcessConfig(BaseModel):
    input_file_prefix: str
    output_dir: str
    subsets: List[str]
    test_set_name: str
    test_set_name2: str = "your_test_set"
    seed: int = 42
    num_aug: int = 1000
    puzzle_identifiers_start: int = 1  # start > 1 to handle multiple datasets

ARCMaxGridSize = 30
ARCAugmentRetriesFactor = 5

PuzzleIdSeparator = "|||"


@dataclass
class ARCPuzzle:
    id: str
    examples: List[Tuple[np.ndarray, np.ndarray]]


def arc_grid_to_np(grid: List[List[int]]):
    arr = np.array(grid)

    # Shape check
    assert arr.ndim == 2
    assert arr.shape[0] <= ARCMaxGridSize and arr.shape[1] <= ARCMaxGridSize
    # Element check
    assert np.all((arr >= 0) & (arr <= 9))
    return arr.astype(np.uint8)


def np_grid_to_seq_translational_augment(inp: np.ndarray, out: np.ndarray, do_translation: bool):
    # Token vocabulary: PAD = 0, <eos> = 1, digits = 2 ... 11
    # Compute a random top-left padding offset
    if do_translation:
        pad_r = np.random.randint(0, ARCMaxGridSize - max(inp.shape[0], out.shape[0]) + 1)
        pad_c = np.random.randint(0, ARCMaxGridSize - max(inp.shape[1], out.shape[1]) + 1)
    else:
        pad_r = pad_c = 0

    # Pad grid
    result = []
    for grid in [inp, out]:
        nrow, ncol = grid.shape
        grid = np.pad(grid + 2, ((pad_r, ARCMaxGridSize - pad_r - nrow), (pad_c, ARCMaxGridSize - pad_c - ncol)), constant_values=0)

        # Mark the grid boundary with <eos> tokens
        eos_row, eos_col = pad_r + nrow, pad_c + ncol
        if eos_row < ARCMaxGridSize:
            grid[eos_row, pad_c:eos_col] = 1
        if eos_col < ARCMaxGridSize:
            grid[pad_r:eos_row, eos_col] = 1

        result.append(grid.flatten())

    return result


def grid_hash(grid: np.ndarray):
    assert grid.ndim == 2
    assert grid.dtype == np.uint8

    buffer = [x.to_bytes(1, byteorder='big') for x in grid.shape]
    buffer.append(grid.tobytes())

    return hashlib.sha256(b"".join(buffer)).hexdigest()


def puzzle_hash(puzzle: dict):
    # Hash the puzzle for checking equivalence
    hashes = []
    for example_type, example in puzzle.items():
        for input, label in example.examples:
            hashes.append(f"{grid_hash(input)}|{grid_hash(label)}")

    hashes.sort()
    return hashlib.sha256("|".join(hashes).encode()).hexdigest()


def aug(name: str):
    # Augmentation plan: a random dihedral transform plus a random color
    # permutation that keeps "0" (black) fixed
    trans_id = np.random.randint(0, 8)
    mapping = np.concatenate([np.arange(0, 1, dtype=np.uint8), np.random.permutation(np.arange(1, 10, dtype=np.uint8))])

    name_with_aug_repr = f"{name}{PuzzleIdSeparator}t{trans_id}{PuzzleIdSeparator}{''.join(str(x) for x in mapping)}"

    def _map_grid(grid: np.ndarray):
        return dihedral_transform(mapping[grid], trans_id)

    return name_with_aug_repr, _map_grid


def inverse_aug(name: str):
    # Invert the "aug" function
    if PuzzleIdSeparator not in name:
        return name, lambda x: x

    trans_id, perm = name.split(PuzzleIdSeparator)[-2:]
    trans_id = int(trans_id[1:])  # Remove the "t" letter
    inv_perm = np.argsort(list(perm)).astype(np.uint8)

    def _map_grid(grid: np.ndarray):
        return inv_perm[inverse_dihedral_transform(grid, trans_id)]

    return name.split(PuzzleIdSeparator)[0], _map_grid
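
# Illustrative example (names hypothetical): aug("puzzle") may return
# ("puzzle|||t3|||0321654987", f); inverse_aug on that name recovers
# ("puzzle", f_inv) with f_inv(f(grid)) == grid for any valid grid,
# since both the dihedral transform and the color permutation are inverted.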


def convert_single_arc_puzzle(results: dict, name: str, puzzle: dict, aug_count: int, dest_mapping: Dict[str, Tuple[str, str]]):
    # Convert
    dests = set(dest_mapping.values())
    converted = {dest: ARCPuzzle(name, []) for dest in dests}
    for example_type, examples in puzzle.items():
        # Map to the target split
        dest = dest_mapping[example_type]
        converted[dest].examples.extend([(arc_grid_to_np(example["input"]), arc_grid_to_np(example["output"])) for example in examples])

    group = [converted]

    # Augment
    if aug_count > 0:
        hashes = {puzzle_hash(converted)}

        for _trial in range(ARCAugmentRetriesFactor * aug_count):
            aug_name, _map_grid = aug(name)

            # Check for duplicates
            augmented = {dest: ARCPuzzle(aug_name, [(_map_grid(input), _map_grid(label)) for (input, label) in puzzle.examples]) for dest, puzzle in converted.items()}
            h = puzzle_hash(augmented)
            if h not in hashes:
                hashes.add(h)
                group.append(augmented)

            if len(group) >= aug_count + 1:
                break

        if len(group) < aug_count + 1:
            print(f"[Puzzle {name}] augmentation not full, only {len(group)}")

    # Append
    for dest in dests:
        # Convert the examples
        dest_split, dest_set = dest

        results.setdefault(dest_split, {})
        results[dest_split].setdefault(dest_set, [])
        results[dest_split][dest_set].append([converted[dest] for converted in group])


def load_puzzles_arcagi(config: DataProcessConfig):
    train_examples_dest = ("train", "all")
    test_examples_map = {
        config.test_set_name: [(1.0, ("test", "all"))],
        config.test_set_name2: [(1.0, ("test", "all"))],

        "_default": [(1.0, ("train", "all"))]
    }

    test_puzzles = {}
    results = {}

    total_puzzles = 0
    for subset_name in config.subsets:
        # Load all puzzles in this subset
        with open(f"{config.input_file_prefix}_{subset_name}_challenges.json", "r") as f:
            puzzles = json.load(f)

        sols_filename = f"{config.input_file_prefix}_{subset_name}_solutions.json"
        if os.path.isfile(sols_filename):
            with open(sols_filename, "r") as f:
                sols = json.load(f)

            for puzzle_id in puzzles.keys():
                for idx, sol_grid in enumerate(sols[puzzle_id]):
                    puzzles[puzzle_id]["test"][idx]["output"] = sol_grid
        else:
            # Fill with dummy outputs
            print(f"{subset_name} solutions not found, filling with dummy")

            for puzzle_id, puzzle in puzzles.items():
                for example in puzzle["test"]:
                    example.setdefault("output", [[0]])

        # Shuffle puzzles
        puzzles = list(puzzles.items())
        np.random.shuffle(puzzles)

        # Assign by fraction
        for idx, (name, puzzle) in enumerate(puzzles):
            fraction = idx / len(puzzles)
            test_examples_dest = None
            for f, dest in test_examples_map.get(subset_name, test_examples_map["_default"]):
                if fraction < f:
                    test_examples_dest = dest
                    break

            assert test_examples_dest is not None

            if test_examples_dest[0] == "test":
                test_puzzles[name] = puzzle

            convert_single_arc_puzzle(results, name, puzzle, config.num_aug, {"train": train_examples_dest, "test": test_examples_dest})
            total_puzzles += 1

    print(f"Total puzzles: {total_puzzles}")
    return results, test_puzzles


def convert_dataset(config: DataProcessConfig):
    np.random.seed(config.seed)

    # Read dataset
    data, test_puzzles = load_puzzles_arcagi(config)

    # Map global puzzle identifiers
    num_identifiers = config.puzzle_identifiers_start  # 0 is blank, so start at 1
    identifier_map = {}
    for split_name, split in data.items():
        for subset_name, subset in split.items():
            for group in subset:
                for puzzle in group:
                    if puzzle.id not in identifier_map:
                        identifier_map[puzzle.id] = num_identifiers
                        num_identifiers += 1
    print(f"Total puzzle IDs (including <blank>): {num_identifiers}")

    # Save
    for split_name, split in data.items():
        os.makedirs(os.path.join(config.output_dir, split_name), exist_ok=True)

        # Translational augmentations (train split only)
        enable_translational_augment = split_name == "train"

        # Statistics
        total_examples = 0
        total_puzzles = 0
        total_groups = 0

        for subset_name, subset in split.items():  # "all" is the only subset
            # Construct subset
            results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
            results["puzzle_indices"].append(0)
            results["group_indices"].append(0)

            example_id = 0
            puzzle_id = 0

            for group in subset:
                for puzzle in group:
                    # Push puzzle; keep one randomly chosen example untranslated
                    no_aug_id = np.random.randint(0, len(puzzle.examples))
                    for _idx_ex, (inp, out) in enumerate(puzzle.examples):
                        inp, out = np_grid_to_seq_translational_augment(inp, out, do_translation=enable_translational_augment and _idx_ex != no_aug_id)

                        results["inputs"].append(inp)
                        results["labels"].append(out)
                        example_id += 1

                        total_examples += 1

                    results["puzzle_indices"].append(example_id)
                    results["puzzle_identifiers"].append(identifier_map[puzzle.id])

                    puzzle_id += 1
                    total_puzzles += 1

                # Push group
                results["group_indices"].append(puzzle_id)
                total_groups += 1

            for k, v in results.items():
                if k in {"inputs", "labels"}:
                    v = np.stack(v, 0)
                else:
                    v = np.array(v, dtype=np.int32)

                np.save(os.path.join(config.output_dir, split_name, f"{subset_name}__{k}.npy"), v)

        # Metadata
        metadata = PuzzleDatasetMetadata(
            seq_len=ARCMaxGridSize * ARCMaxGridSize,
            vocab_size=10 + 2,  # PAD + EOS + "0" ... "9"
            pad_id=0,
            ignore_label_id=0,
            blank_identifier_id=0,
            num_puzzle_identifiers=num_identifiers,
            total_groups=total_groups,
            mean_puzzle_examples=total_examples / total_puzzles,
            total_puzzles=total_puzzles,
            sets=list(split.keys())
        )

        # Save metadata as JSON.
        with open(os.path.join(config.output_dir, split_name, "dataset.json"), "w") as f:
            json.dump(metadata.model_dump(), f)

    # Save the IDs mapping
    with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
        ids_mapping = {v: k for k, v in identifier_map.items()}
        json.dump([ids_mapping.get(i, "<blank>") for i in range(num_identifiers)], f)

    # Save the test puzzles
    with open(os.path.join(config.output_dir, "test_puzzles.json"), "w") as f:
        json.dump(test_puzzles, f)


@cli.command(singleton=True)
def main(config: DataProcessConfig):
    convert_dataset(config)


if __name__ == "__main__":
    cli()
@@ -0,0 +1,140 @@
from typing import Optional
import math
import os
import csv
import json
import numpy as np

from argdantic import ArgParser
from pydantic import BaseModel
from tqdm import tqdm
from huggingface_hub import hf_hub_download

from common import PuzzleDatasetMetadata, dihedral_transform


CHARSET = "# SGo"


cli = ArgParser()


class DataProcessConfig(BaseModel):
    source_repo: str = "sapientinc/maze-30x30-hard-1k"
    output_dir: str = "data/maze-30x30-hard-1k"

    subsample_size: Optional[int] = None
    aug: bool = False


def convert_subset(set_name: str, config: DataProcessConfig):
    # Read CSV
    all_chars = set()
    grid_size = None
    inputs = []
    labels = []

    with open(hf_hub_download(config.source_repo, f"{set_name}.csv", repo_type="dataset"), newline="") as csvfile:  # type: ignore
        reader = csv.reader(csvfile)
        next(reader)  # Skip header
        for source, q, a, rating in reader:
            all_chars.update(q)
            all_chars.update(a)

            if grid_size is None:
                n = int(len(q) ** 0.5)
                grid_size = (n, n)

            inputs.append(np.frombuffer(q.encode(), dtype=np.uint8).reshape(grid_size))
            labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(grid_size))

    # If subsample_size is specified for the training set,
    # randomly sample the desired number of examples.
    if set_name == "train" and config.subsample_size is not None:
        total_samples = len(inputs)
        if config.subsample_size < total_samples:
            indices = np.random.choice(total_samples, size=config.subsample_size, replace=False)
            inputs = [inputs[i] for i in indices]
            labels = [labels[i] for i in indices]

    # Generate dataset
    results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
    puzzle_id = 0
    example_id = 0

    results["puzzle_indices"].append(0)
    results["group_indices"].append(0)

    for inp, out in zip(tqdm(inputs), labels):
        # Dihedral transformations for augmentation
        for aug_idx in range(8 if (set_name == "train" and config.aug) else 1):
            results["inputs"].append(dihedral_transform(inp, aug_idx))
            results["labels"].append(dihedral_transform(out, aug_idx))
            example_id += 1
            puzzle_id += 1

            results["puzzle_indices"].append(example_id)
            results["puzzle_identifiers"].append(0)

        # Push group
        results["group_indices"].append(puzzle_id)

    # Char mappings
    assert len(all_chars - set(CHARSET)) == 0

    char2id = np.zeros(256, np.uint8)
    char2id[np.array(list(map(ord, CHARSET)))] = np.arange(len(CHARSET)) + 1

    # To NumPy
    def _seq_to_numpy(seq):
        arr = np.vstack([char2id[s.reshape(-1)] for s in seq])

        return arr

    results = {
        "inputs": _seq_to_numpy(results["inputs"]),
        "labels": _seq_to_numpy(results["labels"]),

        "group_indices": np.array(results["group_indices"], dtype=np.int32),
        "puzzle_indices": np.array(results["puzzle_indices"], dtype=np.int32),
        "puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32),
    }

    # Metadata
    metadata = PuzzleDatasetMetadata(
        seq_len=int(math.prod(grid_size)),  # type: ignore
        vocab_size=len(CHARSET) + 1,  # PAD + charset
        pad_id=0,
        ignore_label_id=0,
        blank_identifier_id=0,
        num_puzzle_identifiers=1,
        total_groups=len(results["group_indices"]) - 1,
        mean_puzzle_examples=1,
        total_puzzles=len(results["group_indices"]) - 1,
        sets=["all"]
    )

    # Save metadata as JSON.
    save_dir = os.path.join(config.output_dir, set_name)
    os.makedirs(save_dir, exist_ok=True)

    with open(os.path.join(save_dir, "dataset.json"), "w") as f:
        json.dump(metadata.model_dump(), f)

    # Save data
    for k, v in results.items():
        np.save(os.path.join(save_dir, f"all__{k}.npy"), v)

    # Save the IDs mapping (for visualization only)
    with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
        json.dump(["<blank>"], f)


@cli.command(singleton=True)
def preprocess_data(config: DataProcessConfig):
    convert_subset("train", config)
    convert_subset("test", config)


if __name__ == "__main__":
    cli()
@@ -0,0 +1,167 @@
from typing import Optional
import os
import csv
import json
import numpy as np

from argdantic import ArgParser
from pydantic import BaseModel
from tqdm import tqdm
from huggingface_hub import hf_hub_download

from common import PuzzleDatasetMetadata


cli = ArgParser()


class DataProcessConfig(BaseModel):
    source_repo: str = "sapientinc/sudoku-extreme"
    output_dir: str = "data/sudoku-extreme-full"

    subsample_size: Optional[int] = None
    min_difficulty: Optional[int] = None
    num_aug: int = 0


def shuffle_sudoku(board: np.ndarray, solution: np.ndarray):
    # Create a random digit mapping: a permutation of 1..9, with zero (blank) unchanged
    digit_map = np.pad(np.random.permutation(np.arange(1, 10)), (1, 0))

    # Randomly decide whether to transpose.
    transpose_flag = np.random.rand() < 0.5

    # Generate a valid row permutation:
    # shuffle the 3 bands (each band = 3 rows) and, within each band, shuffle its 3 rows.
    bands = np.random.permutation(3)
    row_perm = np.concatenate([b * 3 + np.random.permutation(3) for b in bands])

    # Similarly for columns (stacks).
    stacks = np.random.permutation(3)
    col_perm = np.concatenate([s * 3 + np.random.permutation(3) for s in stacks])

    # Build an 81 -> 81 mapping. For each new cell at flat index i
    # (row = i // 9, col = i % 9), its value comes from
    # old row = row_perm[i // 9] and old col = col_perm[i % 9].
    mapping = np.array([row_perm[i // 9] * 9 + col_perm[i % 9] for i in range(81)])

    def apply_transformation(x: np.ndarray) -> np.ndarray:
        # Apply the transpose flag
        if transpose_flag:
            x = x.T
        # Apply the position mapping.
        new_board = x.flatten()[mapping].reshape(9, 9).copy()
        # Apply the digit mapping
        return digit_map[new_board]

    return apply_transformation(board), apply_transformation(solution)
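
# Note: every operation above (digit relabeling, transposition, and the
# band/stack-respecting row/column permutations) maps a valid Sudoku to a
# valid Sudoku, so the augmented (board, solution) pair stays consistent.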


def convert_subset(set_name: str, config: DataProcessConfig):
    # Read CSV
    inputs = []
    labels = []

    with open(hf_hub_download(config.source_repo, f"{set_name}.csv", repo_type="dataset"), newline="") as csvfile:
        reader = csv.reader(csvfile)
        next(reader)  # Skip header
        for source, q, a, rating in reader:
            if (config.min_difficulty is None) or (int(rating) >= config.min_difficulty):
                assert len(q) == 81 and len(a) == 81

                inputs.append(np.frombuffer(q.replace('.', '0').encode(), dtype=np.uint8).reshape(9, 9) - ord('0'))
                labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(9, 9) - ord('0'))

    # If subsample_size is specified for the training set,
    # randomly sample the desired number of examples.
    if set_name == "train" and config.subsample_size is not None:
        total_samples = len(inputs)
        if config.subsample_size < total_samples:
            indices = np.random.choice(total_samples, size=config.subsample_size, replace=False)
            inputs = [inputs[i] for i in indices]
            labels = [labels[i] for i in indices]

    # Generate dataset
    num_augments = config.num_aug if set_name == "train" else 0

    results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
    puzzle_id = 0
    example_id = 0

    results["puzzle_indices"].append(0)
    results["group_indices"].append(0)

    for orig_inp, orig_out in zip(tqdm(inputs), labels):
        for aug_idx in range(1 + num_augments):
            # The first index is not augmented
            if aug_idx == 0:
                inp, out = orig_inp, orig_out
            else:
                inp, out = shuffle_sudoku(orig_inp, orig_out)

            # Push puzzle (a single example each)
            results["inputs"].append(inp)
            results["labels"].append(out)
            example_id += 1
            puzzle_id += 1

            results["puzzle_indices"].append(example_id)
            results["puzzle_identifiers"].append(0)

        # Push group
        results["group_indices"].append(puzzle_id)

    # To NumPy
    def _seq_to_numpy(seq):
        arr = np.concatenate(seq).reshape(len(seq), -1)

        assert np.all((arr >= 0) & (arr <= 9))
        return arr + 1

    results = {
        "inputs": _seq_to_numpy(results["inputs"]),
        "labels": _seq_to_numpy(results["labels"]),

        "group_indices": np.array(results["group_indices"], dtype=np.int32),
        "puzzle_indices": np.array(results["puzzle_indices"], dtype=np.int32),
        "puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32),
    }

    # Metadata
    metadata = PuzzleDatasetMetadata(
        seq_len=81,
        vocab_size=10 + 1,  # PAD + "0" ... "9"
        pad_id=0,
        ignore_label_id=0,
        blank_identifier_id=0,
        num_puzzle_identifiers=1,
        total_groups=len(results["group_indices"]) - 1,
        mean_puzzle_examples=1,
        total_puzzles=len(results["group_indices"]) - 1,
        sets=["all"]
    )

    # Save metadata as JSON.
    save_dir = os.path.join(config.output_dir, set_name)
    os.makedirs(save_dir, exist_ok=True)

    with open(os.path.join(save_dir, "dataset.json"), "w") as f:
        json.dump(metadata.model_dump(), f)

    # Save data
    for k, v in results.items():
        np.save(os.path.join(save_dir, f"all__{k}.npy"), v)

    # Save the IDs mapping (for visualization only)
    with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
        json.dump(["<blank>"], f)


@cli.command(singleton=True)
def preprocess_data(config: DataProcessConfig):
    convert_subset("train", config)
    convert_subset("test", config)


if __name__ == "__main__":
    cli()
@@ -0,0 +1,49 @@
from typing import List, Optional

import pydantic
import numpy as np


# Global list mapping each dihedral transform id to its inverse.
# The index is the original tid; the value is its inverse.
DIHEDRAL_INVERSE = [0, 3, 2, 1, 4, 5, 6, 7]


class PuzzleDatasetMetadata(pydantic.BaseModel):
    pad_id: int
    ignore_label_id: Optional[int]
    blank_identifier_id: int
    vocab_size: int
    seq_len: int
    num_puzzle_identifiers: int
    total_groups: int
    mean_puzzle_examples: float
    total_puzzles: int
    sets: List[str]


def dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:
    """The 8 dihedral symmetries: rotations, flips, and mirrors."""

    if tid == 0:
        return arr  # identity
    elif tid == 1:
        return np.rot90(arr, k=1)
    elif tid == 2:
        return np.rot90(arr, k=2)
    elif tid == 3:
        return np.rot90(arr, k=3)
    elif tid == 4:
        return np.fliplr(arr)  # horizontal flip
    elif tid == 5:
        return np.flipud(arr)  # vertical flip
    elif tid == 6:
        return arr.T  # transpose (reflection along the main diagonal)
    elif tid == 7:
        return np.fliplr(np.rot90(arr, k=1))  # anti-diagonal reflection
    else:
        return arr


def inverse_dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:
    return dihedral_transform(arr, DIHEDRAL_INVERSE[tid])
@@ -0,0 +1,177 @@
from typing import Dict, Sequence, Optional
import os
import json

import torch
import numpy as np
from numba import njit
import torch.distributed as dist

from dataset.build_arc_dataset import inverse_aug, grid_hash, arc_grid_to_np
from dataset.common import PuzzleDatasetMetadata

@njit
def _crop(grid: np.ndarray):
    """Find the maximum-area rectangle without any non-digit (PAD/EOS) token inside."""
    grid = grid.reshape(30, 30)

    max_area = 0
    max_size = (0, 0)
    nr, nc = grid.shape

    num_c = nc
    for num_r in range(1, nr + 1):
        # Scan for the maximum number of columns at this row count
        for c in range(1, num_c + 1):
            x = grid[num_r - 1, c - 1]
            if (x < 2) | (x > 11):  # digits are tokens 2..11
                num_c = c - 1
                break

        area = num_r * num_c
        if area > max_area:
            max_area = area
            max_size = (num_r, num_c)

    return (grid[:max_size[0], :max_size[1]] - 2).astype(np.uint8)  # back to digits 0..9


class ARC:
    required_outputs = {"inputs", "puzzle_identifiers", "q_halt_logits", "preds"}

    def __init__(self, data_path: str,
                 eval_metadata: PuzzleDatasetMetadata,
                 submission_K: int = 2,
                 pass_Ks: Sequence[int] = (1, 2, 5, 10, 100, 1000),
                 aggregated_voting: bool = True):
        super().__init__()
        self.pass_Ks = pass_Ks
        self.submission_K = submission_K
        self.aggregated_voting = aggregated_voting
        self.blank_identifier_id = eval_metadata.blank_identifier_id

        # Load identifiers and test puzzles
        with open(os.path.join(data_path, "identifiers.json"), "r") as f:
            self.identifier_map = json.load(f)
        with open(os.path.join(data_path, "test_puzzles.json"), "r") as f:
            self.test_puzzles = json.load(f)

        # States
        self._local_hmap = {}
        self._local_preds = {}

    def begin_eval(self):
        if not self.aggregated_voting:
            # Clear previous predictions
            self._local_hmap = {}
            self._local_preds = {}

    def update_batch(self, batch: Dict[str, torch.Tensor], preds: Dict[str, torch.Tensor]):
        # Collect the required outputs to CPU
        outputs = {}
        q_values = None

        for collection in (batch, preds):
            for k, v in collection.items():
                if k in self.required_outputs:
                    if k == "q_halt_logits":
                        q_values = v.to(torch.float64).sigmoid().cpu()
                    else:
                        outputs[k] = v.cpu()

        assert q_values is not None

        # Remove padding from the outputs
        mask = outputs["puzzle_identifiers"] != self.blank_identifier_id
        outputs = {k: v[mask] for k, v in outputs.items()}

        # Get predictions
        for identifier, input, pred, q in zip(outputs["puzzle_identifiers"].numpy(), outputs["inputs"].numpy(), outputs["preds"].numpy(), q_values.numpy()):
            name = self.identifier_map[identifier]
            orig_name, _inverse_fn = inverse_aug(name)

            input_hash = grid_hash(_inverse_fn(_crop(input)))

            pred = _inverse_fn(_crop(pred))
            assert np.all((pred >= 0) & (pred <= 9)), f"Puzzle {name}'s prediction out of 0-9 range."  # Sanity check

            # Store into the local state
            pred_hash = grid_hash(pred)

            self._local_hmap[pred_hash] = pred

            self._local_preds.setdefault(orig_name, {})
            self._local_preds[orig_name].setdefault(input_hash, [])
            self._local_preds[orig_name][input_hash].append((pred_hash, float(q)))

    def result(self, save_path: Optional[str], rank: int, world_size: int, group: Optional[torch.distributed.ProcessGroup] = None) -> Optional[Dict[str, float]]:
        # Gather predictions to rank 0 for voting
        global_hmap_preds = [None for _ in range(world_size)] if rank == 0 else None
        dist.gather_object((self._local_hmap, self._local_preds), global_hmap_preds, dst=0, group=group)

        # Rank 0 logic
        if rank != 0:
            return

        submission = {}
        correct = [0.0 for _ in range(len(self.pass_Ks))]

        for name, puzzle in self.test_puzzles.items():
            # Process the test examples in this puzzle
            submission[name] = []
            num_test_correct = [0 for _ in range(len(self.pass_Ks))]
            for pair in puzzle["test"]:
                input_hash = grid_hash(arc_grid_to_np(pair["input"]))
                label_hash = grid_hash(arc_grid_to_np(pair["output"]))

                # Tally votes: per distinct predicted grid, count and summed q
                p_map = {}
                for hmap, preds in global_hmap_preds:  # type: ignore
                    for h, q in preds.get(name, {}).get(input_hash, []):
                        p_map.setdefault(h, [0, 0])
                        p_map[h][0] += 1
                        p_map[h][1] += q

                if not len(p_map):
                    print(f"Puzzle {name} has no predictions.")
                    continue

                # Turn the summed q into a mean q
                for h, stats in p_map.items():
                    stats[1] /= stats[0]

                # Rank by vote count, then by mean halt probability
                p_map = sorted(p_map.items(), key=lambda kv: kv[1], reverse=True)

                # Vote for the different Ks
                for i, k in enumerate(self.pass_Ks):
                    ok = False
                    for h, stats in p_map[:k]:
                        ok |= h == label_hash

                    num_test_correct[i] += ok

                # Query grids
                pred_grids = []
                for h, stats in p_map[:self.submission_K]:
                    for hmap, preds in global_hmap_preds:  # type: ignore
                        if h in hmap:
                            pred_grids.append(hmap[h])
                            break

                # Pad to K attempts
                while len(pred_grids) < self.submission_K:
                    pred_grids.append(pred_grids[0])

                submission[name].append({f"attempt_{i + 1}": grid.tolist() for i, grid in enumerate(pred_grids)})

            # Total correctness
            for i in range(len(self.pass_Ks)):
                correct[i] += num_test_correct[i] / len(puzzle["test"])

        # Save the submission
        if save_path is not None:
            with open(os.path.join(save_path, "submission.json"), "w") as f:
                json.dump(submission, f)

        # Final result
        all_results = {f"ARC/pass@{k}": correct[i] / len(self.test_puzzles) for i, k in enumerate(self.pass_Ks)}

        return all_results
(Two file diffs suppressed because they are too large; eight file diffs suppressed because one or more lines are too long.)
@@ -0,0 +1,32 @@
import math

import torch
from torch import nn


def trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, lower: float = -2.0, upper: float = 2.0):
    # NOTE: PyTorch's nn.init.trunc_normal_ is not mathematically correct: the requested std
    # is not actually the std of the initialized tensor.
    # This function is a PyTorch port of the JAX truncated-normal init (the default init in flax):
    # https://github.com/jax-ml/jax/blob/main/jax/_src/random.py#L807-L848
    # https://github.com/jax-ml/jax/blob/main/jax/_src/nn/initializers.py#L162-L199

    with torch.no_grad():
        if std == 0:
            tensor.zero_()
        else:
            sqrt2 = math.sqrt(2)
            a = math.erf(lower / sqrt2)
            b = math.erf(upper / sqrt2)
            z = (b - a) / 2

            c = (2 * math.pi) ** -0.5
            pdf_u = c * math.exp(-0.5 * lower ** 2)
            pdf_l = c * math.exp(-0.5 * upper ** 2)
            # Scale so that the *truncated* distribution has the requested std
            comp_std = std / math.sqrt(1 - (upper * pdf_u - lower * pdf_l) / z - ((pdf_u - pdf_l) / z) ** 2)

            # Sample via inverse CDF: uniform in [erf(lower/sqrt(2)), erf(upper/sqrt(2))], then erfinv
            tensor.uniform_(a, b)
            tensor.erfinv_()
            tensor.mul_(sqrt2 * comp_std)
            tensor.clip_(lower * comp_std, upper * comp_std)

    return tensor
@@ -0,0 +1,40 @@
import copy
import torch.nn as nn

class EMAHelper(object):
    def __init__(self, mu=0.999):
        self.mu = mu
        self.shadow = {}

    def register(self, module):
        if isinstance(module, nn.DataParallel):
            module = module.module
        for name, param in module.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self, module):
        if isinstance(module, nn.DataParallel):
            module = module.module
        for name, param in module.named_parameters():
            if param.requires_grad:
                self.shadow[name].data = (1. - self.mu) * param.data + self.mu * self.shadow[name].data

    def ema(self, module):
        if isinstance(module, nn.DataParallel):
            module = module.module
        for name, param in module.named_parameters():
            if param.requires_grad:
                param.data.copy_(self.shadow[name].data)

    def ema_copy(self, module):
        module_copy = copy.deepcopy(module)
        self.ema(module_copy)
        return module_copy

    def state_dict(self):
        return self.shadow

    def load_state_dict(self, state_dict):
        self.shadow = state_dict
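
# Illustrative usage (a sketch; not taken from this repo's training loop):
#   ema = EMAHelper(mu=0.999)
#   ema.register(model)               # snapshot the initial weights
#   ...
#   ema.update(model)                 # call after each optimizer step
#   eval_model = ema.ema_copy(model)  # averaged weights for evaluation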
@@ -0,0 +1,169 @@
|
|||||||
|
from typing import Tuple
|
||||||
|
import einops
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
#try:
|
||||||
|
# from flash_attn_interface import flash_attn_func # type: ignore[import]
|
||||||
|
#except ImportError:
|
||||||
|
# # Fallback to FlashAttention 2
|
||||||
|
# from flash_attn import flash_attn_func # type: ignore[import]
|
||||||
|
from torch.nn.functional import scaled_dot_product_attention
|
||||||
|
|
||||||
|
from models.common import trunc_normal_init_
|
||||||
|
|
||||||
|
|
||||||
|
CosSin = Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
|
|
||||||
|
def _find_multiple(a, b):
|
||||||
|
return (-(a // -b)) * b
|
||||||
|
|
||||||
|
|
||||||
|
def rotate_half(x: torch.Tensor):
|
||||||
|
"""Rotates half the hidden dims of the input."""
|
||||||
|
x1 = x[..., : x.shape[-1] // 2]
|
||||||
|
x2 = x[..., x.shape[-1] // 2 :]
|
||||||
|
return torch.cat((-x2, x1), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
|
||||||
|
# q, k: [bs, seq_len, num_heads, head_dim]
|
||||||
|
# cos, sin: [seq_len, head_dim]
|
||||||
|
orig_dtype = q.dtype
|
||||||
|
q = q.to(cos.dtype)
|
||||||
|
k = k.to(cos.dtype)
|
||||||
|
|
||||||
|
q_embed = (q * cos.unsqueeze(-2)) + (rotate_half(q) * sin.unsqueeze(-2))
|
||||||
|
k_embed = (k * cos.unsqueeze(-2)) + (rotate_half(k) * sin.unsqueeze(-2))
|
||||||
|
|
||||||
|
return q_embed.to(orig_dtype), k_embed.to(orig_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class CastedLinear(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
in_features: int,
|
||||||
|
out_features: int,
|
||||||
|
bias: bool):
|
||||||
|
super().__init__()
|
||||||
|
# Truncated LeCun normal init
|
||||||
|
self.weight = nn.Parameter(
|
||||||
|
trunc_normal_init_(torch.empty((out_features, in_features)), std=1.0 / (in_features ** 0.5))
|
||||||
|
)
|
||||||
|
self.bias = None
|
||||||
|
if bias:
|
||||||
|
# Zero init bias
|
||||||
|
self.bias = nn.Parameter(torch.zeros((out_features, )))
|
||||||
|
|
||||||
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
|
return F.linear(input, self.weight.to(input.dtype), bias=self.bias.to(input.dtype) if self.bias is not None else None)
|
||||||
|
|
||||||
|
|
||||||
|
class CastedEmbedding(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
num_embeddings: int,
|
||||||
|
embedding_dim: int,
|
||||||
|
init_std: float,
|
||||||
|
cast_to: torch.dtype):
|
||||||
|
super().__init__()
|
||||||
|
self.cast_to = cast_to
|
||||||
|
|
||||||
|
# Truncated LeCun normal init
|
||||||
|
self.embedding_weight = nn.Parameter(
|
||||||
|
trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
|
return F.embedding(input, self.embedding_weight.to(self.cast_to))
|
||||||
|
|
||||||
|
|
||||||
|
class RotaryEmbedding(nn.Module):
|
||||||
|
def __init__(self, dim, max_position_embeddings, base, device=None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# RoPE
|
||||||
|
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
|
||||||
|
t = torch.arange(max_position_embeddings, dtype=torch.float32, device=device)
|
||||||
|
freqs = torch.outer(t, inv_freq)
|
||||||
|
|
||||||
|
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||||
|
emb = torch.cat((freqs, freqs), dim=-1)
|
||||||
|
self.cos_cached = nn.Buffer(emb.cos(), persistent=False)
|
||||||
|
self.sin_cached = nn.Buffer(emb.sin(), persistent=False)
|
||||||
|
|
||||||
|
def forward(self):
|
||||||
|
return self.cos_cached, self.sin_cached
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(self, hidden_size, head_dim, num_heads, num_key_value_heads, causal=False):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.output_size = head_dim * num_heads
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.num_key_value_heads = num_key_value_heads
|
||||||
|
self.causal = causal
|
||||||
|
|
||||||
|
self.qkv_proj = CastedLinear(self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False)
|
||||||
|
self.o_proj = CastedLinear(self.output_size, self.hidden_size, bias=False)
|
||||||
|
|
||||||
|
def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
batch_size, seq_len, _ = hidden_states.shape
|
||||||
|
|
||||||
|
# hidden_states: [bs, seq_len, num_heads, head_dim]
|
||||||
|
qkv = self.qkv_proj(hidden_states)
|
||||||
|
|
||||||
|
# Split head
|
||||||
|
qkv = qkv.view(batch_size, seq_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
|
||||||
|
query = qkv[:, :, :self.num_heads]
|
||||||
|
key = qkv[:, :, self.num_heads: self.num_heads + self.num_key_value_heads]
|
||||||
|
value = qkv[:, :, self.num_heads + self.num_key_value_heads:]
|
||||||
|
|
||||||
|
# RoPE
|
||||||
|
if cos_sin is not None:
|
||||||
|
cos, sin = cos_sin
|
||||||
|
query, key = apply_rotary_pos_emb(query, key, cos, sin)
|
||||||
|
|
||||||
|
# flash attn
|
||||||
|
query, key, value = map(lambda t: einops.rearrange(t, 'B S H D -> B H S D'), (query, key, value)) # needed for scaled_dot_product_attention but not flash_attn_func
|
||||||
|
attn_output = scaled_dot_product_attention(query=query, key=key, value=value, is_causal=self.causal)
|
||||||
|
attn_output = einops.rearrange(attn_output, 'B H S D -> B S H D')
|
||||||
|
attn_output = attn_output.view(batch_size, seq_len, self.output_size) # type: ignore
|
||||||
|
return self.o_proj(attn_output)
|
||||||
|
|
||||||
|
class LinearSwish(nn.Module):
    def __init__(self, hidden_size: int, reverse=False):
        super().__init__()

        self.linear = CastedLinear(hidden_size, hidden_size, bias=False)
        self.reverse = reverse

    def forward(self, x):
        if self.reverse:
            return F.silu(self.linear(x))
        else:
            return self.linear(F.silu(x))


class SwiGLU(nn.Module):
    def __init__(self, hidden_size: int, expansion: float):
        super().__init__()
        inter = _find_multiple(round(expansion * hidden_size * 2 / 3), 256)

        self.gate_up_proj = CastedLinear(hidden_size, inter * 2, bias=False)
        self.down_proj = CastedLinear(inter, hidden_size, bias=False)

    def forward(self, x):
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        return self.down_proj(F.silu(gate) * up)

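# Minimal sizing sketch (assumed sizes, not from the repo's configs): with
# hidden_size=512 and expansion=4.0, round(4.0 * 512 * 2 / 3) = 1365, which
# _find_multiple rounds up to the next multiple of 256, i.e. inter = 1536.
def _swiglu_example():
    mlp = SwiGLU(hidden_size=512, expansion=4.0)
    x = torch.randn(2, 10, 512)
    return mlp(x)  # same shape as the input: [2, 10, 512]
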
def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float) -> torch.Tensor:
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)

    variance = hidden_states.square().mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
    return hidden_states.to(input_dtype)
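# Quick numeric check (illustrative): rms_norm divides by the root mean square
# of the last dimension, so a constant vector maps to (nearly) unit entries.
def _rms_norm_example():
    x = torch.full((1, 4), 3.0)
    return rms_norm(x, variance_epsilon=1e-5)  # approximately [[1., 1., 1., 1.]]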
@@ -0,0 +1,103 @@
from typing import Any, Tuple, Dict, Sequence, Optional

import torch
import torch.nn.functional as F
from torch import nn
import math

IGNORE_LABEL_ID = -100


def s(x, epsilon=1e-30):
    return torch.where(
        x < 0,
        1 / (1 - x + epsilon),
        x + 1
    )


def log_stablemax(x, dim=-1):
    s_x = s(x)
    return torch.log(s_x / torch.sum(s_x, dim=dim, keepdim=True))

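# Illustrative comparison (not part of the original file): stablemax replaces
# softmax's exp with the piecewise map s(x) above, which grows linearly for
# x >= 0 and decays as 1/(1 - x) for x < 0, avoiding exp overflow/underflow.
# For logits [-2, 0, 3]: s = [1/3, 1, 4], so the probabilities are
# [1/16, 3/16, 12/16].
def _stablemax_example():
    logits = torch.tensor([[-2.0, 0.0, 3.0]])
    return log_stablemax(logits), F.log_softmax(logits, dim=-1)
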
def stablemax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None):
    logprobs = log_stablemax(logits.to(torch.float64), dim=-1)

    if valid_mask is None:
        valid_mask = (labels != ignore_index)
    transformed_labels = torch.where(valid_mask, labels, 0)
    prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)

    return -torch.where(valid_mask, prediction_logprobs, 0)

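# Usage sketch (assumed shapes): returns one loss value per token, with
# positions labeled ignore_index contributing exactly zero.
def _stablemax_ce_example():
    logits = torch.randn(2, 5, 11)        # [batch, seq, vocab]
    labels = torch.randint(0, 11, (2, 5))
    labels[0, 0] = -100                   # ignored position
    return stablemax_cross_entropy(logits, labels)  # [2, 5], zero at [0, 0]
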
def softmax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None):
    # valid_mask is accepted so ACTLossHead can call either loss function with
    # the same keyword arguments; masking here is handled through ignore_index.
    # Cast logits to f32
    # Flatten logits
    return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape)

class ACTLossHead(nn.Module):
    def __init__(self, model: nn.Module, loss_type: str):
        super().__init__()
        self.model = model
        self.loss_fn = globals()[loss_type]

    def initial_carry(self, *args, **kwargs):
        return self.model.initial_carry(*args, **kwargs)  # type: ignore

    def forward(
        self,
        return_keys: Sequence[str],
        # Model args
        **model_kwargs,
    ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
        # Model logits
        # B x SeqLen x D
        new_carry, outputs = self.model(**model_kwargs)
        labels = new_carry.current_data["labels"]

        with torch.no_grad():
            # Preds
            outputs["preds"] = torch.argmax(outputs["logits"], dim=-1)

            # Correctness
            mask = (labels != IGNORE_LABEL_ID)
            loss_counts = mask.sum(-1)
            loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1)  # Avoid NaNs in division

            is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
            seq_is_correct = is_correct.sum(-1) == loss_counts

            # Metrics (halted)
            valid_metrics = new_carry.halted & (loss_counts > 0)
            metrics = {
                "count": valid_metrics.sum(),
                "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
                "exact_accuracy": (valid_metrics & seq_is_correct).sum(),
                "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
                "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
            }

        # Losses
        lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID, valid_mask=mask) / loss_divisor).sum()
        q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")
        metrics.update({
            "lm_loss": lm_loss.detach(),
            "q_halt_loss": q_halt_loss.detach(),
        })

        # Q continue (bootstrapping target loss); Alexia: This fits Q-learning, but seems totally unnecessary
        q_continue_loss = 0
        if "target_q_continue" in outputs:
            q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")
            metrics["q_continue_loss"] = q_continue_loss.detach()

        # Filter outputs for return
        detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}

        return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
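# Training-loop sketch (illustrative; this is not the repo's actual train
# script). It assumes batch contains "inputs", "labels", and
# "puzzle_identifiers", and that model is one of the ACT wrappers defined in
# models/: the loss head is stepped until every sequence in the batch halts.
def _act_training_step_example(model, batch, optimizer):
    loss_head = ACTLossHead(model, loss_type="stablemax_cross_entropy")
    carry = loss_head.initial_carry(batch)
    while True:
        carry, loss, metrics, outputs, all_halted = loss_head(
            return_keys=["logits"], carry=carry, batch=batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if all_halted:
            return metrics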
@@ -0,0 +1,294 @@
from typing import Tuple, List, Dict, Optional
from dataclasses import dataclass
import math
import torch
import torch.nn.functional as F
from torch import nn
from pydantic import BaseModel

from models.common import trunc_normal_init_
from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
from models.sparse_embedding import CastedSparseEmbedding


@dataclass
class HierarchicalReasoningModel_ACTV1InnerCarry:
    z_H: torch.Tensor
    z_L: torch.Tensor


@dataclass
class HierarchicalReasoningModel_ACTV1Carry:
    inner_carry: HierarchicalReasoningModel_ACTV1InnerCarry

    steps: torch.Tensor
    halted: torch.Tensor

    current_data: Dict[str, torch.Tensor]


class HierarchicalReasoningModel_ACTV1Config(BaseModel):
    batch_size: int
    seq_len: int
    puzzle_emb_ndim: int = 0
    num_puzzle_identifiers: int
    vocab_size: int

    H_cycles: int
    L_cycles: int

    H_layers: int
    L_layers: int

    # Transformer config
    hidden_size: int
    expansion: float
    num_heads: int
    pos_encodings: str

    rms_norm_eps: float = 1e-5
    rope_theta: float = 10000.0

    # Halting Q-learning config
    halt_max_steps: int
    halt_exploration_prob: float

    forward_dtype: str = "bfloat16"

    # Alexia: added
    mlp_t: bool = False  # use an MLP over the sequence dimension instead of a transformer block


class HierarchicalReasoningModel_ACTV1Block(nn.Module):
    def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
        super().__init__()

        self.config = config
        if self.config.mlp_t:
            self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)
            self.mlp_t = SwiGLU(
                hidden_size=self.config.seq_len + self.puzzle_emb_len,  # L
                expansion=config.expansion,
            )
        else:
            self.self_attn = Attention(
                hidden_size=config.hidden_size,
                head_dim=config.hidden_size // config.num_heads,
                num_heads=config.num_heads,
                num_key_value_heads=config.num_heads,
                causal=False
            )
            self.mlp = SwiGLU(
                hidden_size=config.hidden_size,
                expansion=config.expansion,
            )
        self.norm_eps = config.rms_norm_eps

    def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
        # B, L, D = hidden_states.shape
        # Post Norm
        if self.config.mlp_t:
            hidden_states = hidden_states.transpose(1, 2)
            out = self.mlp_t(hidden_states)
            hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
            hidden_states = hidden_states.transpose(1, 2)
        else:
            # Self Attention
            hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
            # Fully Connected
            out = self.mlp(hidden_states)
            hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
        return hidden_states

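# Descriptive note on the mlp_t path above: transposing to [B, D, L] makes the
# SwiGLU (whose width is seq_len + puzzle_emb_len) mix information across
# sequence positions rather than across channels, so it stands in for
# self-attention, not for the per-token feed-forward.
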
class HierarchicalReasoningModel_ACTV1ReasoningModule(nn.Module):
    def __init__(self, layers: List[HierarchicalReasoningModel_ACTV1Block]):
        super().__init__()

        self.layers = torch.nn.ModuleList(layers)

    def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
        # Input injection (add)
        hidden_states = hidden_states + input_injection
        # Layers
        for layer in self.layers:
            hidden_states = layer(hidden_states=hidden_states, **kwargs)

        return hidden_states


class HierarchicalReasoningModel_ACTV1_Inner(nn.Module):
    def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
        super().__init__()
        self.config = config
        self.forward_dtype = getattr(torch, self.config.forward_dtype)

        # I/O
        self.embed_scale = math.sqrt(self.config.hidden_size)
        embed_init_std = 1.0 / self.embed_scale

        self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
        self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
        self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)

        self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)  # ceil div
        if self.config.puzzle_emb_ndim > 0:
            # Zero init puzzle embeddings
            self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
                                                    batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)

        # LM Blocks
        if self.config.pos_encodings == "rope":
            self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
                                              max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
                                              base=self.config.rope_theta)
        elif self.config.pos_encodings == "learned":
            self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
        else:
            pass

        # Reasoning Layers
        self.H_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.H_layers)])
        self.L_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])

        # Initial states
        self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
        self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)

        # Q head special init
        # Init Q to (almost) zero for faster learning during bootstrapping
        with torch.no_grad():
            self.q_head.weight.zero_()
            self.q_head.bias.fill_(-5)  # type: ignore

    def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
        # Token embedding
        embedding = self.embed_tokens(input.to(torch.int32))

        # Puzzle embeddings
        if self.config.puzzle_emb_ndim > 0:
            puzzle_embedding = self.puzzle_emb(puzzle_identifiers)

            pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
            if pad_count > 0:
                puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))

            embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)

        # Position embeddings
        if self.config.pos_encodings == "learned":
            # scale by 1/sqrt(2) to maintain forward variance
            embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))

        # Scale
        return self.embed_scale * embedding

    def empty_carry(self, batch_size: int):
        return HierarchicalReasoningModel_ACTV1InnerCarry(
            z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
            z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
        )

    def reset_carry(self, reset_flag: torch.Tensor, carry: HierarchicalReasoningModel_ACTV1InnerCarry):
        return HierarchicalReasoningModel_ACTV1InnerCarry(
            z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
            z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
        )

    def forward(self, carry: HierarchicalReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        seq_info = dict(
            cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
        )

        # Input encoding
        input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])

        # Forward iterations
        with torch.no_grad():
            z_H, z_L = carry.z_H, carry.z_L
            for _H_step in range(self.config.H_cycles):
                for _L_step in range(self.config.L_cycles):
                    if not ((_H_step == self.config.H_cycles - 1) and (_L_step == self.config.L_cycles - 1)):
                        z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
                if not (_H_step == self.config.H_cycles - 1):
                    z_H = self.H_level(z_H, z_L, **seq_info)
        assert not z_H.requires_grad and not z_L.requires_grad

        # 1-step grad
        z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
        z_H = self.H_level(z_H, z_L, **seq_info)

        # LM Outputs
        new_carry = HierarchicalReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())  # New carry no grad
        output = self.lm_head(z_H)[:, self.puzzle_emb_len:]

        # Q head
        q_logits = self.q_head(z_H[:, 0]).to(torch.float32)

        return new_carry, output, (q_logits[..., 0], q_logits[..., 1])

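# Schedule sketch for the loops above (illustrative, H_cycles=2, L_cycles=2):
#   no_grad: z_L <- L(z_L, z_H + x); z_L <- L(z_L, z_H + x); z_H <- H(z_H, z_L);
#            z_L <- L(z_L, z_H + x)   (the very last (L, H) pair is skipped)
#   grad:    z_L <- L(z_L, z_H + x); z_H <- H(z_H, z_L)
# so only the final L and H updates carry gradients ("1-step grad").
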
class HierarchicalReasoningModel_ACTV1(nn.Module):
    """ACT wrapper."""

    def __init__(self, config_dict: dict):
        super().__init__()
        self.config = HierarchicalReasoningModel_ACTV1Config(**config_dict)
        self.inner = HierarchicalReasoningModel_ACTV1_Inner(self.config)

    @property
    def puzzle_emb(self):
        return self.inner.puzzle_emb

    def initial_carry(self, batch: Dict[str, torch.Tensor]):
        batch_size = batch["inputs"].shape[0]

        return HierarchicalReasoningModel_ACTV1Carry(
            inner_carry=self.inner.empty_carry(batch_size),  # Empty is expected, it will be reset in the first pass as all sequences start halted.

            steps=torch.zeros((batch_size, ), dtype=torch.int32),
            halted=torch.ones((batch_size, ), dtype=torch.bool),  # Default to halted

            current_data={k: torch.empty_like(v) for k, v in batch.items()}
        )

    def forward(self, carry: HierarchicalReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
        # Update data, carry (removing halted sequences)
        new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)

        new_steps = torch.where(carry.halted, 0, carry.steps)

        new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}

        # Forward inner model
        new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)

        outputs = {
            "logits": logits,
            "q_halt_logits": q_halt_logits,
            "q_continue_logits": q_continue_logits
        }

        with torch.no_grad():
            # Step
            new_steps = new_steps + 1
            is_last_step = new_steps >= self.config.halt_max_steps

            halted = is_last_step

            # If training, and ACT is enabled
            if self.training and (self.config.halt_max_steps > 1):
                # Halt signal
                # NOTE: During evaluation, always use max steps; this guarantees the same number of halting steps inside a batch, for batching purposes
                halted = halted | (q_halt_logits > q_continue_logits)

                # Exploration
                min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)

                halted = halted & (new_steps >= min_halt_steps)

                # Compute target Q
                # NOTE: No replay buffer and target networks for computing target Q-value.
                # As batch_size is large, there are many parallel envs.
                # Similar concept as PQN https://arxiv.org/abs/2407.04811
                next_q_halt_logits, next_q_continue_logits = self.inner(new_inner_carry, new_current_data)[-1]

                outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))

        return HierarchicalReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
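# Halting-logic note (descriptive): during training a sequence halts once it
# reaches halt_max_steps or its halt Q-logit exceeds its continue Q-logit;
# with probability halt_exploration_prob a random minimum step count in
# [2, halt_max_steps] is enforced so longer computations keep being explored.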
@@ -0,0 +1,342 @@
"""
HRM ACT V2: Transformer Baseline for Architecture Ablation

This is an architecture ablation of the Hierarchical Reasoning Model (HRM).
Key changes from V1:
1. REMOVED hierarchical split (no separate H and L levels)
2. REMOVED inner cycles (no H_cycles/L_cycles loops within reasoning)
3. KEPT ACT outer loop structure intact
4. KEPT all data preprocessing, embeddings, and evaluation infrastructure

Architecture: Single-level transformer that processes the full 30x30 grid as a
900-token sequence, with the same positional encodings and sparse embeddings as V1.
"""

from typing import Tuple, List, Dict, Optional
from dataclasses import dataclass
import math

import torch
import torch.nn.functional as F
from torch import nn
from pydantic import BaseModel

from models.common import trunc_normal_init_
from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
from models.sparse_embedding import CastedSparseEmbedding


@dataclass
class Model_ACTV2InnerCarry:
    z_H: torch.Tensor


@dataclass
class Model_ACTV2Carry:
    inner_carry: Model_ACTV2InnerCarry

    steps: torch.Tensor
    halted: torch.Tensor

    current_data: Dict[str, torch.Tensor]


class Model_ACTV2Config(BaseModel):
    batch_size: int
    seq_len: int
    puzzle_emb_ndim: int = 0
    num_puzzle_identifiers: int
    vocab_size: int

    H_cycles: int

    H_layers: int

    # Transformer config
    hidden_size: int
    expansion: float
    num_heads: int
    pos_encodings: str

    rms_norm_eps: float = 1e-5
    rope_theta: float = 10000.0

    # Halting Q-learning config
    halt_max_steps: int
    halt_exploration_prob: float
    act_enabled: bool = True  # If False, always run halt_max_steps (no early stopping during training)
    act_inference: bool = False  # If True, use adaptive computation during inference

    forward_dtype: str = "bfloat16"

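# Example configuration (illustrative sketch; the values are assumptions, not
# taken from the repo's config files). Instantiate with
# Model_ACTV2(_EXAMPLE_V2_CONFIG): a single-level transformer over the 30x30
# grid flattened to a 900-token sequence, as described in the module docstring.
_EXAMPLE_V2_CONFIG = dict(
    batch_size=32, seq_len=900, puzzle_emb_ndim=512, num_puzzle_identifiers=1000,
    vocab_size=12, H_cycles=1, H_layers=4, hidden_size=512, expansion=4.0,
    num_heads=8, pos_encodings="rope", halt_max_steps=8, halt_exploration_prob=0.1,
)
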
class Model_ACTV2Block(nn.Module):
    def __init__(self, config: Model_ACTV2Config) -> None:
        super().__init__()

        self.self_attn = Attention(
            hidden_size=config.hidden_size,
            head_dim=config.hidden_size // config.num_heads,
            num_heads=config.num_heads,
            num_key_value_heads=config.num_heads,
            causal=False,
        )
        self.mlp = SwiGLU(
            hidden_size=config.hidden_size,
            expansion=config.expansion,
        )
        self.norm_eps = config.rms_norm_eps

    def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
        # Post Norm
        # Self Attention
        hidden_states = rms_norm(
            hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states),
            variance_epsilon=self.norm_eps,
        )
        # Fully Connected
        hidden_states = rms_norm(hidden_states + self.mlp(hidden_states), variance_epsilon=self.norm_eps)
        return hidden_states


class Model_ACTV2ReasoningModule(nn.Module):
    def __init__(self, layers: List[Model_ACTV2Block]):
        super().__init__()

        self.layers = torch.nn.ModuleList(layers)

    def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
        # Input injection (add)
        hidden_states = hidden_states + input_injection
        # Layers
        for layer in self.layers:
            hidden_states = layer(hidden_states=hidden_states, **kwargs)

        return hidden_states


class Model_ACTV2_Inner(nn.Module):
    def __init__(self, config: Model_ACTV2Config) -> None:
        super().__init__()
        self.config = config
        self.forward_dtype = getattr(torch, self.config.forward_dtype)

        # I/O
        self.embed_scale = math.sqrt(self.config.hidden_size)
        embed_init_std = 1.0 / self.embed_scale

        self.embed_tokens = CastedEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
            init_std=embed_init_std,
            cast_to=self.forward_dtype,
        )
        self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
        self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)

        self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)  # ceil div
        if self.config.puzzle_emb_ndim > 0:
            # Zero init puzzle embeddings
            self.puzzle_emb = CastedSparseEmbedding(
                self.config.num_puzzle_identifiers,
                self.config.puzzle_emb_ndim,
                batch_size=self.config.batch_size,
                init_std=0,
                cast_to=self.forward_dtype,
            )

        # LM Blocks
        if self.config.pos_encodings == "rope":
            self.rotary_emb = RotaryEmbedding(
                dim=self.config.hidden_size // self.config.num_heads,
                max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
                base=self.config.rope_theta,
            )
        elif self.config.pos_encodings == "learned":
            self.embed_pos = CastedEmbedding(
                self.config.seq_len + self.puzzle_emb_len,
                self.config.hidden_size,
                init_std=embed_init_std,
                cast_to=self.forward_dtype,
            )
        else:
            raise NotImplementedError()

        # Reasoning Layers
        self.H_level = Model_ACTV2ReasoningModule(
            layers=[Model_ACTV2Block(self.config) for _i in range(self.config.H_layers)]
        )

        # Initial states
        self.H_init = nn.Buffer(
            trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1),
            persistent=True,
        )

        # Q head special init
        # Init Q to (almost) zero for faster learning during bootstrapping
        with torch.no_grad():
            self.q_head.weight.zero_()
            self.q_head.bias.fill_(-5)  # type: ignore

    def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
        # Token embedding
        embedding = self.embed_tokens(input.to(torch.int32))

        # Puzzle embeddings
        if self.config.puzzle_emb_ndim > 0:
            puzzle_embedding = self.puzzle_emb(puzzle_identifiers)

            pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
            if pad_count > 0:
                puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))

            embedding = torch.cat(
                (puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2
            )

        # Position embeddings
        if self.config.pos_encodings == "learned":
            # scale by 1/sqrt(2) to maintain forward variance
            embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))

        # Scale
        return self.embed_scale * embedding

    def empty_carry(self, batch_size: int):
        return Model_ACTV2InnerCarry(
            z_H=torch.empty(
                batch_size,
                self.config.seq_len + self.puzzle_emb_len,
                self.config.hidden_size,
                dtype=self.forward_dtype,
            ),
        )

    def reset_carry(self, reset_flag: torch.Tensor, carry: Model_ACTV2InnerCarry):
        return Model_ACTV2InnerCarry(
            z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
        )

    def forward(
        self, carry: Model_ACTV2InnerCarry, batch: Dict[str, torch.Tensor]
    ) -> Tuple[Model_ACTV2InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        seq_info = dict(
            cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
        )

        # Input encoding
        input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])

        # 1-step grad
        z_H = self.H_level(carry.z_H, input_embeddings, **seq_info)

        # LM Outputs
        new_carry = Model_ACTV2InnerCarry(
            z_H=z_H.detach(),
        )  # New carry no grad
        output = self.lm_head(z_H)[:, self.puzzle_emb_len:]

        # Q head
        q_logits = self.q_head(z_H[:, 0]).to(torch.float32)

        return new_carry, output, (q_logits[..., 0], q_logits[..., 1])

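# Descriptive note: unlike V1 above, each ACT step here is a single H_level
# pass with gradients (no nested H/L cycles), so iterative refinement happens
# only through the ACT outer loop re-feeding the detached carry.
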
class Model_ACTV2(nn.Module):
    """ACT wrapper."""

    def __init__(self, config_dict: dict):
        super().__init__()
        self.config = Model_ACTV2Config(**config_dict)
        self.inner = Model_ACTV2_Inner(self.config)

    @property
    def puzzle_emb(self):
        return self.inner.puzzle_emb

    def initial_carry(self, batch: Dict[str, torch.Tensor]):
        batch_size = batch["inputs"].shape[0]

        return Model_ACTV2Carry(
            inner_carry=self.inner.empty_carry(
                batch_size
            ),  # Empty is expected, it will be reset in the first pass as all sequences start halted.
            steps=torch.zeros((batch_size,), dtype=torch.int32),
            halted=torch.ones((batch_size,), dtype=torch.bool),  # Default to halted
            current_data={k: torch.empty_like(v) for k, v in batch.items()},
        )

    def forward(
        self,
        carry: Model_ACTV2Carry,
        batch: Dict[str, torch.Tensor],
        compute_target_q: bool = False,
    ) -> Tuple[Model_ACTV2Carry, Dict[str, torch.Tensor]]:
        # Update data, carry (removing halted sequences)
        new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)

        new_steps = torch.where(carry.halted, 0, carry.steps)

        new_current_data = {
            k: torch.where(carry.halted.view((-1,) + (1,) * (batch[k].ndim - 1)), batch[k], v)
            for k, v in carry.current_data.items()
        }

        # Forward inner model
        new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(
            new_inner_carry, new_current_data
        )

        outputs = {"logits": logits, "q_halt_logits": q_halt_logits, "q_continue_logits": q_continue_logits}

        with torch.no_grad():
            # Step
            new_steps = new_steps + 1
            is_last_step = new_steps >= self.config.halt_max_steps

            halted = is_last_step

            # Check if adaptive computation should be used
            use_adaptive = (self.config.halt_max_steps > 1) and (
                (self.training and self.config.act_enabled)
                or (not self.training and self.config.act_inference)
            )

            if use_adaptive:
                # Halt signal based on Q-values (but always halt at max steps)
                q_halt_signal = q_halt_logits > q_continue_logits
                halted = halted | q_halt_signal

                # Store actual steps used for logging (only during inference)
                if not self.training:
                    outputs["actual_steps"] = new_steps.float()

                # Exploration (only during training)
                if self.training:
                    min_halt_steps = (
                        torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob
                    ) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
                    halted = halted & (new_steps >= min_halt_steps)

            # Compute target Q (only during training)
            # NOTE: No replay buffer and target networks for computing target Q-value.
            # As batch_size is large, there are many parallel envs.
            # Similar concept as PQN https://arxiv.org/abs/2407.04811
            if self.training and compute_target_q:
                next_q_halt_logits, next_q_continue_logits = self.inner(
                    new_inner_carry, new_current_data
                )[-1]

                outputs["target_q_continue"] = torch.sigmoid(
                    torch.where(
                        is_last_step,
                        next_q_halt_logits,
                        torch.maximum(next_q_halt_logits, next_q_continue_logits),
                    )
                )

        return Model_ACTV2Carry(
            new_inner_carry, new_steps, halted, new_current_data
        ), outputs
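# ACT-flag note (descriptive): act_enabled gates Q-based early halting during
# training and act_inference enables it at eval time; both require
# halt_max_steps > 1. With adaptive computation off, the model always runs
# exactly halt_max_steps, keeping the step count uniform across the batch.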
@@ -0,0 +1,297 @@
from typing import Tuple, List, Dict, Optional
from dataclasses import dataclass
import math
import torch
import copy
import torch.nn.functional as F
from torch import nn
from pydantic import BaseModel
import random

from models.common import trunc_normal_init_
from models.layers import rms_norm, LinearSwish, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
from models.sparse_embedding import CastedSparseEmbedding

IGNORE_LABEL_ID = -100


@dataclass
class TinyRecursiveReasoningModel_ACTV1InnerCarry:
    z_H: torch.Tensor
    z_L: torch.Tensor


@dataclass
class TinyRecursiveReasoningModel_ACTV1Carry:
    inner_carry: TinyRecursiveReasoningModel_ACTV1InnerCarry

    steps: torch.Tensor
    halted: torch.Tensor

    current_data: Dict[str, torch.Tensor]


class TinyRecursiveReasoningModel_ACTV1Config(BaseModel):
    batch_size: int
    seq_len: int
    puzzle_emb_ndim: int = 0
    num_puzzle_identifiers: int
    vocab_size: int

    H_cycles: int
    L_cycles: int

    H_layers: int  # ignored
    L_layers: int

    # Transformer config
    hidden_size: int
    expansion: float
    num_heads: int
    pos_encodings: str

    rms_norm_eps: float = 1e-5
    rope_theta: float = 10000.0

    # Halting Q-learning config
    halt_max_steps: int
    halt_exploration_prob: float

    forward_dtype: str = "bfloat16"

    # Alexia: added
    mlp_t: bool = False  # use an MLP over the sequence dimension instead of a transformer block
    puzzle_emb_len: int = 16  # if non-zero, the puzzle embedding length is set to this value
    no_ACT_continue: bool = True  # no continue ACT loss; use only the sigmoid of the halt logit, which makes much more sense
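# Descriptive note on the added flags: mlp_t swaps self-attention for a
# sequence-mixing MLP, puzzle_emb_len pins the number of prepended puzzle
# tokens (instead of deriving it from puzzle_emb_ndim), and no_ACT_continue
# drops the bootstrapped continue target so halting is learned from the halt
# logit alone.
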
class TinyRecursiveReasoningModel_ACTV1Block(nn.Module):
    def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
        super().__init__()

        self.config = config
        if self.config.mlp_t:
            self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len
            self.mlp_t = SwiGLU(
                hidden_size=self.config.seq_len + self.puzzle_emb_len,  # L
                expansion=config.expansion,
            )
        else:
            self.self_attn = Attention(
                hidden_size=config.hidden_size,
                head_dim=config.hidden_size // config.num_heads,
                num_heads=config.num_heads,
                num_key_value_heads=config.num_heads,
                causal=False
            )
            self.mlp = SwiGLU(
                hidden_size=config.hidden_size,
                expansion=config.expansion,
            )
        self.norm_eps = config.rms_norm_eps

    def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
        # B, L, D = hidden_states.shape
        # Post Norm
        if self.config.mlp_t:
            hidden_states = hidden_states.transpose(1, 2)
            out = self.mlp_t(hidden_states)
            hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
            hidden_states = hidden_states.transpose(1, 2)
        else:
            # Self Attention
            hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
            # Fully Connected
            out = self.mlp(hidden_states)
            hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
        return hidden_states


class TinyRecursiveReasoningModel_ACTV1ReasoningModule(nn.Module):
    def __init__(self, layers: List[TinyRecursiveReasoningModel_ACTV1Block]):
        super().__init__()
        self.layers = torch.nn.ModuleList(layers)

    def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
        hidden_states = hidden_states + input_injection
        for layer in self.layers:
            hidden_states = layer(hidden_states=hidden_states, **kwargs)
        return hidden_states


class TinyRecursiveReasoningModel_ACTV1_Inner(nn.Module):
    def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
        super().__init__()
        self.config = config
        self.forward_dtype = getattr(torch, self.config.forward_dtype)

        # I/O
        self.embed_scale = math.sqrt(self.config.hidden_size)
        embed_init_std = 1.0 / self.embed_scale

        self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
        self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
        self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)

        self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len  # ceil div
        if self.config.puzzle_emb_ndim > 0:
            # Zero init puzzle embeddings
            self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
                                                    batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)

        # LM Blocks
        if self.config.pos_encodings == "rope":
            self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
                                              max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
                                              base=self.config.rope_theta)
        elif self.config.pos_encodings == "learned":
            self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
        else:
            pass

        # Reasoning Layers
        self.L_level = TinyRecursiveReasoningModel_ACTV1ReasoningModule(layers=[TinyRecursiveReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])

        # Initial states
        self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
        self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)

        # Q head special init
        # Init Q to (almost) zero for faster learning during bootstrapping
        with torch.no_grad():
            self.q_head.weight.zero_()
            self.q_head.bias.fill_(-5)  # type: ignore

    def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
        # Token embedding
        embedding = self.embed_tokens(input.to(torch.int32))

        # Puzzle embeddings
        if self.config.puzzle_emb_ndim > 0:
            puzzle_embedding = self.puzzle_emb(puzzle_identifiers)

            pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
            if pad_count > 0:
                puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))

            embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)

        # Position embeddings
        if self.config.pos_encodings == "learned":
            # scale by 1/sqrt(2) to maintain forward variance
            embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))

        # Scale
        return self.embed_scale * embedding

    def empty_carry(self, batch_size: int):
        return TinyRecursiveReasoningModel_ACTV1InnerCarry(
            z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
            z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
        )

    def reset_carry(self, reset_flag: torch.Tensor, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry):
        return TinyRecursiveReasoningModel_ACTV1InnerCarry(
            z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
            z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
        )

    def forward(self, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        seq_info = dict(
            cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
        )

        # Input encoding
        input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])

        # Forward iterations
        z_H, z_L = carry.z_H, carry.z_L
        # H_cycles-1 without grad; note that both the z_L and z_H updates share
        # the single tiny network (self.L_level)
        with torch.no_grad():
            for _H_step in range(self.config.H_cycles - 1):
                for _L_step in range(self.config.L_cycles):
                    z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
                z_H = self.L_level(z_H, z_L, **seq_info)
        # 1 with grad
        for _L_step in range(self.config.L_cycles):
            z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
        z_H = self.L_level(z_H, z_L, **seq_info)

        # LM Outputs
        new_carry = TinyRecursiveReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())  # New carry no grad
        output = self.lm_head(z_H)[:, self.puzzle_emb_len:]
        q_logits = self.q_head(z_H[:, 0]).to(torch.float32)  # Q-head; uses the first puzzle_emb position
        return new_carry, output, (q_logits[..., 0], q_logits[..., 1])

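# Recursion sketch (descriptive): TRM uses one tiny network (self.L_level) for
# every update. Each ACT step runs H_cycles outer improvements; in each, the
# latent z_L is refined L_cycles times from (z_L, z_H + x) and the answer
# embedding z_H is then refined once from (z_H, z_L). Only the last outer
# cycle carries gradients.
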
class TinyRecursiveReasoningModel_ACTV1(nn.Module):
    """ACT wrapper."""

    def __init__(self, config_dict: dict):
        super().__init__()
        self.config = TinyRecursiveReasoningModel_ACTV1Config(**config_dict)
        self.inner = TinyRecursiveReasoningModel_ACTV1_Inner(self.config)

    @property
    def puzzle_emb(self):
        return self.inner.puzzle_emb

    def initial_carry(self, batch: Dict[str, torch.Tensor]):
        batch_size = batch["inputs"].shape[0]

        return TinyRecursiveReasoningModel_ACTV1Carry(
            inner_carry=self.inner.empty_carry(batch_size),  # Empty is expected, it will be reset in the first pass as all sequences start halted.

            steps=torch.zeros((batch_size, ), dtype=torch.int32),
            halted=torch.ones((batch_size, ), dtype=torch.bool),  # Default to halted

            current_data={k: torch.empty_like(v) for k, v in batch.items()}
        )

    def forward(self, carry: TinyRecursiveReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
        # Update data, carry (removing halted sequences)
        new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)

        new_steps = torch.where(carry.halted, 0, carry.steps)

        new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}

        # Forward inner model
        new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)

        outputs = {
            "logits": logits,
            "q_halt_logits": q_halt_logits,
            "q_continue_logits": q_continue_logits
        }

        with torch.no_grad():
            # Step
            new_steps = new_steps + 1
            is_last_step = new_steps >= self.config.halt_max_steps

            halted = is_last_step

            # If training, and ACT is enabled
            if self.training and (self.config.halt_max_steps > 1):
                # Halt signal
                # NOTE: During evaluation, always use max steps; this guarantees the same number of halting steps inside a batch, for batching purposes
                if self.config.no_ACT_continue:
                    halted = halted | (q_halt_logits > 0)
                else:
                    halted = halted | (q_halt_logits > q_continue_logits)

                # Exploration
                min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
                halted = halted & (new_steps >= min_halt_steps)

                if not self.config.no_ACT_continue:
                    # Compute target Q
                    # NOTE: No replay buffer and target networks for computing target Q-value.
                    # As batch_size is large, there are many parallel envs.
                    # Similar concept as PQN https://arxiv.org/abs/2407.04811
                    # The inner model returns a 3-tuple; take the Q-logits pair from it.
                    _, _, (next_q_halt_logits, next_q_continue_logits) = self.inner(new_inner_carry, new_current_data)
                    outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))

        return TinyRecursiveReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
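# Descriptive note: with no_ACT_continue=True (the default), a sequence halts
# when q_halt_logits > 0, i.e. when sigmoid(q_halt) > 0.5, and the extra
# forward pass that computes the bootstrapped continue target is skipped.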
@@ -0,0 +1,323 @@
|
|||||||
|
from typing import Tuple, List, Dict, Optional
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import copy
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import nn
|
||||||
|
from pydantic import BaseModel
|
||||||
|
import random
|
||||||
|
from models.common import trunc_normal_init_
|
||||||
|
from models.layers import rms_norm, LinearSwish, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
|
||||||
|
from models.sparse_embedding import CastedSparseEmbedding
|
||||||
|
|
||||||
|
IGNORE_LABEL_ID = -100
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TinyRecursiveReasoningModel_ACTV1InnerCarry:
|
||||||
|
z_H: torch.Tensor
|
||||||
|
z_L1: torch.Tensor
|
||||||
|
z_L2: torch.Tensor
|
||||||
|
z_L3: torch.Tensor
|
||||||
|
z_L4: torch.Tensor
|
||||||
|
z_L5: torch.Tensor
|
||||||
|
z_L6: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TinyRecursiveReasoningModel_ACTV1Carry:
|
||||||
|
inner_carry: TinyRecursiveReasoningModel_ACTV1InnerCarry
|
||||||
|
|
||||||
|
steps: torch.Tensor
|
||||||
|
halted: torch.Tensor
|
||||||
|
|
||||||
|
current_data: Dict[str, torch.Tensor]
|
||||||
|
|
||||||
|
|
||||||
|
class TinyRecursiveReasoningModel_ACTV1Config(BaseModel):
|
||||||
|
batch_size: int
|
||||||
|
seq_len: int
|
||||||
|
puzzle_emb_ndim: int = 0
|
||||||
|
num_puzzle_identifiers: int
|
||||||
|
vocab_size: int
|
||||||
|
|
||||||
|
H_cycles: int
|
||||||
|
L_cycles: int
|
||||||
|
|
||||||
|
H_layers: int # ignored
|
||||||
|
L_layers: int
|
||||||
|
|
||||||
|
# Transformer config
|
||||||
|
hidden_size: int
|
||||||
|
expansion: float
|
||||||
|
num_heads: int
|
||||||
|
pos_encodings: str
|
||||||
|
|
||||||
|
rms_norm_eps: float = 1e-5
|
||||||
|
rope_theta: float = 10000.0
|
||||||
|
|
||||||
|
# Halting Q-learning config
|
||||||
|
halt_max_steps: int
|
||||||
|
halt_exploration_prob: float
|
||||||
|
|
||||||
|
forward_dtype: str = "bfloat16"
|
||||||
|
|
||||||
|
# Alexia: added
|
||||||
|
mlp_t: bool = False # use mlp on L instead of transformer
|
||||||
|
puzzle_emb_len: int = 16 # if non-zero, its specified to this value
|
||||||
|
no_ACT_continue: bool = True # No continue ACT loss, only use the sigmoid of the halt which makes much more sense
|
||||||
|
|
||||||
|
class TinyRecursiveReasoningModel_ACTV1Block(nn.Module):
|
||||||
|
def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
if self.config.mlp_t:
|
||||||
|
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len
|
||||||
|
self.mlp_t = SwiGLU(
|
||||||
|
hidden_size=self.config.seq_len + self.puzzle_emb_len, # L
|
||||||
|
expansion=config.expansion,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.self_attn = Attention(
|
||||||
|
hidden_size=config.hidden_size,
|
||||||
|
head_dim=config.hidden_size // config.num_heads,
|
||||||
|
num_heads=config.num_heads,
|
||||||
|
num_key_value_heads=config.num_heads,
|
||||||
|
causal=False
|
||||||
|
)
|
||||||
|
self.mlp = SwiGLU(
|
||||||
|
hidden_size=config.hidden_size,
|
||||||
|
expansion=config.expansion,
|
||||||
|
)
|
||||||
|
self.norm_eps = config.rms_norm_eps
|
||||||
|
|
||||||
|
def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
# B, L, D = hidden_states.shape
|
||||||
|
# Post Norm
|
||||||
|
if self.config.mlp_t:
|
||||||
|
hidden_states = hidden_states.transpose(1,2)
|
||||||
|
out = self.mlp_t(hidden_states)
|
||||||
|
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
|
||||||
|
hidden_states = hidden_states.transpose(1,2)
|
||||||
|
else:
|
||||||
|
# Self Attention
|
||||||
|
hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
|
||||||
|
# Fully Connected
|
||||||
|
out = self.mlp(hidden_states)
|
||||||
|
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
class TinyRecursiveReasoningModel_ACTV1ReasoningModule(nn.Module):
|
||||||
|
def __init__(self, layers: List[TinyRecursiveReasoningModel_ACTV1Block]):
|
||||||
|
super().__init__()
|
||||||
|
self.layers = torch.nn.ModuleList(layers)
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||||
|
hidden_states = hidden_states + input_injection
|
||||||
|
for layer in self.layers:
|
||||||
|
hidden_states = layer(hidden_states=hidden_states, **kwargs)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class TinyRecursiveReasoningModel_ACTV1_Inner(nn.Module):
|
||||||
|
def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.forward_dtype = getattr(torch, self.config.forward_dtype)
|
||||||
|
|
||||||
|
# I/O
|
||||||
|
|
||||||
|
self.embed_scale = math.sqrt(self.config.hidden_size)
|
||||||
|
embed_init_std = 1.0 / self.embed_scale
|
||||||
|
|
||||||
|
self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
|
||||||
|
self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
|
||||||
|
self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
|
||||||
|
|
||||||
|
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len # ceil div
|
||||||
|
if self.config.puzzle_emb_ndim > 0:
|
||||||
|
# Zero init puzzle embeddings
|
||||||
|
self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
|
||||||
|
batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
|
||||||
|
|
||||||
|
# LM Blocks
|
||||||
|
if self.config.pos_encodings == "rope":
|
||||||
|
self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
|
||||||
|
max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
|
||||||
|
base=self.config.rope_theta)
|
||||||
|
elif self.config.pos_encodings == "learned":
            self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
        else:
            pass

        # Reasoning Layers
        self.L_level = TinyRecursiveReasoningModel_ACTV1ReasoningModule(layers=[TinyRecursiveReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])

        # Initial states
        self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
        self.L1_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
        self.L2_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
        self.L3_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
        self.L4_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
        self.L5_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
        self.L6_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)

        # Q head special init
        # Init Q to (almost) zero for faster learning during bootstrapping
        with torch.no_grad():
            self.q_head.weight.zero_()
            self.q_head.bias.fill_(-5)  # type: ignore

    def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
        # Token embedding
        embedding = self.embed_tokens(input.to(torch.int32))

        # Puzzle embeddings
        if self.config.puzzle_emb_ndim > 0:
            puzzle_embedding = self.puzzle_emb(puzzle_identifiers)

            pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
            if pad_count > 0:
                puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))

            embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)

        # Position embeddings
        if self.config.pos_encodings == "learned":
            # scale by 1/sqrt(2) to maintain forward variance
            embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))

        # Scale
        return self.embed_scale * embedding

    def empty_carry(self, batch_size: int):
        return TinyRecursiveReasoningModel_ACTV1InnerCarry(
            z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
            z_L1=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
            z_L2=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
            z_L3=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
            z_L4=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
            z_L5=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
            z_L6=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
        )

    def reset_carry(self, reset_flag: torch.Tensor, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry):
        return TinyRecursiveReasoningModel_ACTV1InnerCarry(
            z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
            z_L1=torch.where(reset_flag.view(-1, 1, 1), self.L1_init, carry.z_L1),
            z_L2=torch.where(reset_flag.view(-1, 1, 1), self.L2_init, carry.z_L2),
            z_L3=torch.where(reset_flag.view(-1, 1, 1), self.L3_init, carry.z_L3),
            z_L4=torch.where(reset_flag.view(-1, 1, 1), self.L4_init, carry.z_L4),
            z_L5=torch.where(reset_flag.view(-1, 1, 1), self.L5_init, carry.z_L5),
            z_L6=torch.where(reset_flag.view(-1, 1, 1), self.L6_init, carry.z_L6),
        )

    def forward(self, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        seq_info = dict(
            cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
        )

        # Input encoding
        input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])

        # Forward iterations
        z_H, z_L = carry.z_H, [carry.z_L1, carry.z_L2, carry.z_L3, carry.z_L4, carry.z_L5, carry.z_L6]
        # H_cycles-1 without grad
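        # All but the last H cycle run under no_grad: only the final improvement step
        # is backpropagated, which keeps memory constant in the depth of the recursion.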
        with torch.no_grad():
            for _H_step in range(self.config.H_cycles - 1):
                for _L_step in range(self.config.L_cycles):
                    z_L_ = z_L[0] + z_L[1] + z_L[2] + z_L[3] + z_L[4] + z_L[5]
                    z_L[_L_step] = self.L_level(z_L_, z_H + input_embeddings, **seq_info)
                z_L_ = z_L[0] + z_L[1] + z_L[2] + z_L[3] + z_L[4] + z_L[5]
                z_H = self.L_level(z_H, z_L_, **seq_info)
        # 1 with grad
        for _L_step in range(self.config.L_cycles):
            z_L_ = z_L[0] + z_L[1] + z_L[2] + z_L[3] + z_L[4] + z_L[5]
            z_L[_L_step] = self.L_level(z_L_, z_H + input_embeddings, **seq_info)
        z_L_ = z_L[0] + z_L[1] + z_L[2] + z_L[3] + z_L[4] + z_L[5]
        z_H = self.L_level(z_H, z_L_, **seq_info)

        # LM Outputs
        new_carry = TinyRecursiveReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L1=z_L[0].detach(), z_L2=z_L[1].detach(), z_L3=z_L[2].detach(), z_L4=z_L[3].detach(), z_L5=z_L[4].detach(), z_L6=z_L[5].detach())  # New carry, no grad
        output = self.lm_head(z_H)[:, self.puzzle_emb_len:]
        q_logits = self.q_head(z_H[:, 0]).to(torch.float32)  # Q-head; uses the first puzzle_emb position
        return new_carry, output, (q_logits[..., 0], q_logits[..., 1])


class TinyRecursiveReasoningModel_ACTV1(nn.Module):
    """ACT wrapper."""

    def __init__(self, config_dict: dict):
        super().__init__()
        self.config = TinyRecursiveReasoningModel_ACTV1Config(**config_dict)
        self.inner = TinyRecursiveReasoningModel_ACTV1_Inner(self.config)

    @property
    def puzzle_emb(self):
        return self.inner.puzzle_emb

    def initial_carry(self, batch: Dict[str, torch.Tensor]):
        batch_size = batch["inputs"].shape[0]

        return TinyRecursiveReasoningModel_ACTV1Carry(
            inner_carry=self.inner.empty_carry(batch_size),  # Empty is expected; it is reset on the first pass since all sequences start halted.

            steps=torch.zeros((batch_size, ), dtype=torch.int32),
            halted=torch.ones((batch_size, ), dtype=torch.bool),  # Default to halted

            current_data={k: torch.empty_like(v) for k, v in batch.items()}
        )

    def forward(self, carry: TinyRecursiveReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
        # Update data and carry (replacing halted sequences with fresh ones)
        new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)

        new_steps = torch.where(carry.halted, 0, carry.steps)

        new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}

        # Forward inner model
        new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)

        outputs = {
            "logits": logits,
            "q_halt_logits": q_halt_logits,
            "q_continue_logits": q_continue_logits
        }

        with torch.no_grad():
            # Step
            new_steps = new_steps + 1
            is_last_step = new_steps >= self.config.halt_max_steps

            halted = is_last_step

            # If training and ACT is enabled
            if self.training and (self.config.halt_max_steps > 1):
                # Halt signal
                # NOTE: During evaluation, always run the maximum number of steps to guarantee the same halting step for every sequence in a batch.
                if self.config.no_ACT_continue:
                    halted = halted | (q_halt_logits > 0)
                else:
                    halted = halted | (q_halt_logits > q_continue_logits)

                # Exploration
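                # With probability halt_exploration_prob, force a random minimum number of
                # steps, so the Q-head also sees trajectories that halt later than it would choose.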
                min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
                halted = halted & (new_steps >= min_halt_steps)

                if not self.config.no_ACT_continue:
                    # Compute target Q
                    # NOTE: No replay buffer and no target network for computing the target Q-value.
                    # As batch_size is large, there are many parallel environments.
                    # Similar concept to PQN: https://arxiv.org/abs/2407.04811
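                    # Bootstrapped target: the continue-value target is the best achievable
                    # Q at the next step (forced to the halt value on the final step).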
                    _, _, (next_q_halt_logits, next_q_continue_logits) = self.inner(new_inner_carry, new_current_data)
                    outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))

        return TinyRecursiveReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs

@@ -0,0 +1,294 @@
from typing import Tuple, List, Dict, Optional
from dataclasses import dataclass
import math
import torch
import copy
import torch.nn.functional as F
from torch import nn
from pydantic import BaseModel
import random

from models.common import trunc_normal_init_
from models.layers import rms_norm, LinearSwish, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
from models.sparse_embedding import CastedSparseEmbedding

IGNORE_LABEL_ID = -100


@dataclass
class TinyRecursiveReasoningModel_ACTV1InnerCarry:
    z_L: torch.Tensor


@dataclass
class TinyRecursiveReasoningModel_ACTV1Carry:
    inner_carry: TinyRecursiveReasoningModel_ACTV1InnerCarry

    steps: torch.Tensor
    halted: torch.Tensor

    current_data: Dict[str, torch.Tensor]


class TinyRecursiveReasoningModel_ACTV1Config(BaseModel):
    batch_size: int
    seq_len: int
    puzzle_emb_ndim: int = 0
    num_puzzle_identifiers: int
    vocab_size: int

    H_cycles: int
    L_cycles: int

    H_layers: int  # ignored
    L_layers: int

    # Transformer config
    hidden_size: int
    expansion: float
    num_heads: int
    pos_encodings: str

    rms_norm_eps: float = 1e-5
    rope_theta: float = 10000.0

    # Halting Q-learning config
    halt_max_steps: int
    halt_exploration_prob: float

    forward_dtype: str = "bfloat16"

    # Alexia: added
    mlp_t: bool = False  # use an MLP over the sequence dimension instead of a transformer block
    puzzle_emb_len: int = 16  # if non-zero, overrides the puzzle-embedding length computed from puzzle_emb_ndim
    no_ACT_continue: bool = True  # drop the ACT continue loss and use only the sigmoid of the halt logit, which makes much more sense


class TinyRecursiveReasoningModel_ACTV1Block(nn.Module):
    def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
        super().__init__()

        self.config = config
        if self.config.mlp_t:
            self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len
            self.mlp_t = SwiGLU(
                hidden_size=self.config.seq_len + self.puzzle_emb_len,  # L
                expansion=config.expansion,
            )
        else:
            self.self_attn = Attention(
                hidden_size=config.hidden_size,
                head_dim=config.hidden_size // config.num_heads,
                num_heads=config.num_heads,
                num_key_value_heads=config.num_heads,
                causal=False
            )
            self.mlp = SwiGLU(
                hidden_size=config.hidden_size,
                expansion=config.expansion,
            )
        self.norm_eps = config.rms_norm_eps

    def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
        # B, L, D = hidden_states.shape
        # Post Norm
        if self.config.mlp_t:
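            # MLP-Mixer-style token mixing: transpose to (B, D, L) so the SwiGLU
            # mixes information across the sequence dimension instead of attention.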
            hidden_states = hidden_states.transpose(1, 2)
            out = self.mlp_t(hidden_states)
            hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
            hidden_states = hidden_states.transpose(1, 2)
        else:
            # Self Attention
            hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
            # Fully Connected
            out = self.mlp(hidden_states)
            hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
        return hidden_states


class TinyRecursiveReasoningModel_ACTV1ReasoningModule(nn.Module):
    def __init__(self, layers: List[TinyRecursiveReasoningModel_ACTV1Block]):
        super().__init__()
        self.layers = torch.nn.ModuleList(layers)

    def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        for layer in self.layers:
            hidden_states = layer(hidden_states=hidden_states, **kwargs)
        return hidden_states


class TinyRecursiveReasoningModel_ACTV1_Inner(nn.Module):
    def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
        super().__init__()
        self.config = config
        self.forward_dtype = getattr(torch, self.config.forward_dtype)

        # I/O
        self.embed_scale = math.sqrt(self.config.hidden_size)
        embed_init_std = 1.0 / self.embed_scale

        self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
        self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
        self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)

        self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len  # ceil div
        if self.config.puzzle_emb_ndim > 0:
            # Zero init puzzle embeddings
            self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
                                                    batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)

        # LM Blocks
        if self.config.pos_encodings == "rope":
            self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
                                              max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
                                              base=self.config.rope_theta)
        elif self.config.pos_encodings == "learned":
            self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
        else:
            pass

        # Reasoning Layers
        self.L_level = TinyRecursiveReasoningModel_ACTV1ReasoningModule(layers=[TinyRecursiveReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])

        # Initial states
        self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)

        # Q head special init
        # Init Q to (almost) zero for faster learning during bootstrapping
        with torch.no_grad():
            self.q_head.weight.zero_()
            self.q_head.bias.fill_(-5)  # type: ignore

    def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
        # Token embedding
        embedding = self.embed_tokens(input.to(torch.int32))

        # Puzzle embeddings
        if self.config.puzzle_emb_ndim > 0:
            puzzle_embedding = self.puzzle_emb(puzzle_identifiers)

            pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
            if pad_count > 0:
                puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))

            embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)

        # Position embeddings
        if self.config.pos_encodings == "learned":
            # scale by 1/sqrt(2) to maintain forward variance
            embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))

        # Scale
        return self.embed_scale * embedding

    def empty_carry(self, batch_size: int):
        return TinyRecursiveReasoningModel_ACTV1InnerCarry(
            z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
        )

    def reset_carry(self, reset_flag: torch.Tensor, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry):
        return TinyRecursiveReasoningModel_ACTV1InnerCarry(
            z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
        )

    def forward(self, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        seq_info = dict(
            cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
        )

        # Input encoding
        input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])

        # Forward iterations
        z_L = carry.z_L
        # H_cycles-1 without grad
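        # Single-latent variant: per H cycle, z_L is refined L_cycles times with the
        # input injected, then once without it; all but the final cycle run without grad.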
        with torch.no_grad():
            for _H_step in range(self.config.H_cycles - 1):
                for _L_step in range(self.config.L_cycles):
                    z_L = self.L_level(z_L + input_embeddings, **seq_info)
                z_L = self.L_level(z_L, **seq_info)
        # 1 with grad
        for _L_step in range(self.config.L_cycles):
            z_L = self.L_level(z_L + input_embeddings, **seq_info)
        z_L = self.L_level(z_L, **seq_info)
        z_out = z_L

        # LM Outputs
        new_carry = TinyRecursiveReasoningModel_ACTV1InnerCarry(z_L=z_L.detach())  # New carry, no grad
        output = self.lm_head(z_out)[:, self.puzzle_emb_len:]
        q_logits = self.q_head(z_out[:, 0]).to(torch.float32)  # Q-head; uses the first puzzle_emb position
        return new_carry, output, (q_logits[..., 0], q_logits[..., 1])


class TinyRecursiveReasoningModel_ACTV1(nn.Module):
    """ACT wrapper."""

    def __init__(self, config_dict: dict):
        super().__init__()
        self.config = TinyRecursiveReasoningModel_ACTV1Config(**config_dict)
        self.inner = TinyRecursiveReasoningModel_ACTV1_Inner(self.config)

    @property
    def puzzle_emb(self):
        return self.inner.puzzle_emb

    def initial_carry(self, batch: Dict[str, torch.Tensor]):
        batch_size = batch["inputs"].shape[0]

        return TinyRecursiveReasoningModel_ACTV1Carry(
            inner_carry=self.inner.empty_carry(batch_size),  # Empty is expected; it is reset on the first pass since all sequences start halted.

            steps=torch.zeros((batch_size, ), dtype=torch.int32),
            halted=torch.ones((batch_size, ), dtype=torch.bool),  # Default to halted

            current_data={k: torch.empty_like(v) for k, v in batch.items()}
        )

    def forward(self, carry: TinyRecursiveReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
        # Update data and carry (replacing halted sequences with fresh ones)
        new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)

        new_steps = torch.where(carry.halted, 0, carry.steps)

        new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}

        # Forward inner model
        new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)

        outputs = {
            "logits": logits,
            "q_halt_logits": q_halt_logits,
            "q_continue_logits": q_continue_logits
        }

        with torch.no_grad():
            # Step
            new_steps = new_steps + 1
            is_last_step = new_steps >= self.config.halt_max_steps

            halted = is_last_step

            # If training and ACT is enabled
            if self.training and (self.config.halt_max_steps > 1):
                # Halt signal
                # NOTE: During evaluation, always run the maximum number of steps to guarantee the same halting step for every sequence in a batch.
                if self.config.no_ACT_continue:
                    halted = halted | (q_halt_logits > 0)
                else:
                    halted = halted | (q_halt_logits > q_continue_logits)

                # Exploration
                min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
                halted = halted & (new_steps >= min_halt_steps)

                if not self.config.no_ACT_continue:
                    # Compute target Q
                    # NOTE: No replay buffer and no target network for computing the target Q-value.
                    # As batch_size is large, there are many parallel environments.
                    # Similar concept to PQN: https://arxiv.org/abs/2407.04811
                    _, _, (next_q_halt_logits, next_q_continue_logits) = self.inner(new_inner_carry, new_current_data)
                    outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))

        return TinyRecursiveReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs

@@ -0,0 +1,132 @@
from typing import Union

import torch
from torch import nn
import torch.distributed as dist
from torch.optim.optimizer import Optimizer, ParamsT

from models.common import trunc_normal_init_


class CastedSparseEmbedding(nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int, batch_size: int, init_std: float, cast_to: torch.dtype):
        super().__init__()
        self.cast_to = cast_to

        # Real weights
        # Truncated LeCun normal init
        self.weights = nn.Buffer(
            trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std), persistent=True
        )

        # Local weights and IDs
        # Local embeddings, with gradient, not persistent
        self.local_weights = nn.Buffer(torch.zeros(batch_size, embedding_dim, requires_grad=True), persistent=False)
        # Local embedding IDs, not persistent
        self.local_ids = nn.Buffer(torch.zeros(batch_size, dtype=torch.int32), persistent=False)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        if not self.training:
            # Test mode, no gradient
            return self.weights[inputs].to(self.cast_to)

        # Training mode, fill puzzle embedding from weights
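        # local_weights is a differentiable staging buffer: the rows used by this batch
        # are copied in here, and the sparse optimizer scatters its gradient back into
        # the full table after the backward pass.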
        with torch.no_grad():
            self.local_weights.copy_(self.weights[inputs])
            self.local_ids.copy_(inputs)

        return self.local_weights.to(self.cast_to)


class CastedSparseEmbeddingSignSGD_Distributed(Optimizer):
    def __init__(
        self,
        params: ParamsT,

        world_size: int,
        lr: Union[float, torch.Tensor] = 1e-3,
        weight_decay: float = 1e-2,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            weight_decay=weight_decay,
            world_size=world_size
        )
        super().__init__(params, defaults)

    @torch.no_grad
    def step(self, closure=None):  # type: ignore
        for group in self.param_groups:
            # Find the sparse embedding weights
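            # Convention: each group holds exactly the three buffers of a CastedSparseEmbedding,
            # identified by shape: local_weights (has grad), local_ids (1-D), weights (2-D).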
            local_weights_grad = None
            local_ids = None
            weights = None

            assert len(group["params"]) == 3
            for p in group["params"]:
                if p.requires_grad:
                    local_weights_grad = p.grad
                elif p.ndim == 1:
                    local_ids = p
                elif p.ndim == 2:
                    weights = p
                else:
                    assert False

            assert local_ids is not None
            assert weights is not None

            # Apply SignSGD
            # Adam ≈ SignSGD if the gradient is very sparse
            if local_weights_grad is not None:
                _sparse_emb_signsgd_dist(
                    local_weights_grad,
                    local_ids,
                    weights,

                    lr=group["lr"],
                    weight_decay=group["weight_decay"],
                    world_size=group["world_size"]
                )


def _sparse_emb_signsgd_dist(
    local_weights_grad: torch.Tensor,
    local_ids: torch.Tensor,
    weights: torch.Tensor,

    lr: float,
    weight_decay: float,
    world_size: int
) -> None:
    N, D = local_weights_grad.shape

    # All-gather
    all_weights_grad = local_weights_grad
    all_ids = local_ids

    if world_size > 1:
        all_weights_grad = torch.empty((world_size * N, D), dtype=local_weights_grad.dtype, device=local_weights_grad.device)
        all_ids = torch.empty(world_size * N, dtype=local_ids.dtype, device=local_ids.device)

        dist.all_gather_into_tensor(all_weights_grad, local_weights_grad)
        dist.all_gather_into_tensor(all_ids, local_ids)

    # Unique
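    # Deduplicate puzzle IDs gathered across ranks and sum the gradients of duplicate
    # rows, so each unique embedding row receives one combined update.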
    grad_ids, inv = all_ids.unique(return_inverse=True)

    grad = torch.zeros((grad_ids.shape[0], D), dtype=all_weights_grad.dtype, device=all_weights_grad.device)
    grad.scatter_add_(0, inv.unsqueeze(-1).expand(-1, D), all_weights_grad)

    # SignSGD with decoupled weight decay
    p = weights[grad_ids]

    p.mul_(1.0 - lr * weight_decay).add_(torch.sign(grad), alpha=-lr)

    # Write updated slices back
    weights[grad_ids] = p
@@ -0,0 +1,654 @@
from typing import Optional, Any, Sequence, List
from dataclasses import dataclass
import os
import math
import yaml
import shutil
import copy

import torch
import torch.distributed as dist
from torch import nn
from torch.utils.data import DataLoader

import tqdm
import wandb
import coolname
import hydra
import pydantic
from omegaconf import DictConfig
from adam_atan2 import AdamATan2

from puzzle_dataset import PuzzleDataset, PuzzleDatasetConfig, PuzzleDatasetMetadata
from utils.functions import load_model_class, get_model_source_path
from models.sparse_embedding import CastedSparseEmbeddingSignSGD_Distributed
from models.ema import EMAHelper


class LossConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra='allow')
    name: str


class ArchConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra='allow')
    name: str
    loss: LossConfig


class EvaluatorConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra="allow")
    name: str


class PretrainConfig(pydantic.BaseModel):
    # Config
    arch: ArchConfig
    # Data
    data_paths: List[str]
    data_paths_test: List[str] = []
    # Evaluators
    evaluators: List[EvaluatorConfig] = []

    # Hyperparams
    global_batch_size: int
    epochs: int

    lr: float
    lr_min_ratio: float
    lr_warmup_steps: int

    weight_decay: float
    beta1: float
    beta2: float

    # Puzzle embedding
    puzzle_emb_lr: float
    puzzle_emb_weight_decay: float

    # Names
    project_name: Optional[str] = None
    run_name: Optional[str] = None
    load_checkpoint: Optional[str] = None
    checkpoint_path: Optional[str] = None

    # Extras
    seed: int = 0
    checkpoint_every_eval: bool = False
    eval_interval: Optional[int] = None
    min_eval_interval: Optional[int] = 0  # iteration index at which evaluation starts
    eval_save_outputs: List[str] = []

    ema: bool = False  # use an exponential moving average (EMA) of the weights
    ema_rate: float = 0.999  # EMA rate
    freeze_weights: bool = False  # if True, freeze the weights and only learn the puzzle embeddings


@dataclass
class TrainState:
    model: nn.Module
    optimizers: Sequence[torch.optim.Optimizer]
    optimizer_lrs: Sequence[float]
    carry: Any

    step: int
    total_steps: int


def create_dataloader(config: PretrainConfig, split: str, rank: int, world_size: int, **kwargs):
    dataset = PuzzleDataset(PuzzleDatasetConfig(
        seed=config.seed,
        dataset_paths=config.data_paths_test if len(config.data_paths_test) > 0 and split == "test" else config.data_paths,
        rank=rank,
        num_replicas=world_size,
        **kwargs
    ), split=split)
    dataloader = DataLoader(
        dataset,
        batch_size=None,
        num_workers=1,
        prefetch_factor=8,
        pin_memory=True,
        persistent_workers=True
    )
    return dataloader, dataset.metadata


def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, rank: int, world_size: int):
    model_cfg = dict(
        **config.arch.__pydantic_extra__,  # type: ignore
        batch_size=config.global_batch_size // world_size,
        vocab_size=train_metadata.vocab_size,
        seq_len=train_metadata.seq_len,
        num_puzzle_identifiers=train_metadata.num_puzzle_identifiers,
        causal=False  # Non-autoregressive
    )

    # Instantiate model with loss head
    model_cls = load_model_class(config.arch.name)
    loss_head_cls = load_model_class(config.arch.loss.name)

    with torch.device("cuda"):
        model: nn.Module = model_cls(model_cfg)
        print(model)
        model = loss_head_cls(model, **config.arch.loss.__pydantic_extra__)  # type: ignore
        if "DISABLE_COMPILE" not in os.environ:
            model = torch.compile(model)  # type: ignore

    # Load checkpoint
    if rank == 0:
        load_checkpoint(model, config)

    # Broadcast parameters from rank 0
    if world_size > 1:
        with torch.no_grad():
            for param in list(model.parameters()) + list(model.buffers()):
                dist.broadcast(param, src=0)

    # Optimizers and lr
    if config.arch.puzzle_emb_ndim == 0:
        optimizers = [
            AdamATan2(
                model.parameters(),
                lr=0,  # Needs to be set by the scheduler
                weight_decay=config.weight_decay,
                betas=(config.beta1, config.beta2)
            )
        ]
        optimizer_lrs = [
            config.lr
        ]
    elif config.freeze_weights:
        optimizers = [
            CastedSparseEmbeddingSignSGD_Distributed(
                model.model.puzzle_emb.buffers(),  # type: ignore
                lr=0,  # Needs to be set by the scheduler
                weight_decay=config.puzzle_emb_weight_decay,
                world_size=world_size
            )
        ]
        optimizer_lrs = [
            config.puzzle_emb_lr
        ]
    else:
        optimizers = [
            CastedSparseEmbeddingSignSGD_Distributed(
                model.model.puzzle_emb.buffers(),  # type: ignore
                lr=0,  # Needs to be set by the scheduler
                weight_decay=config.puzzle_emb_weight_decay,
                world_size=world_size
            ),
            AdamATan2(
                model.parameters(),
                lr=0,  # Needs to be set by the scheduler
                weight_decay=config.weight_decay,
                betas=(config.beta1, config.beta2)
            )
        ]
        optimizer_lrs = [
            config.puzzle_emb_lr,
            config.lr
        ]

    return model, optimizers, optimizer_lrs


def mix_weights_direct(device, alpha, net, nets):
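    """Load into `net` the elementwise weighted combination of state dicts: sum_i alpha[i] * nets[i]."""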
    sd = []
    for i in range(len(nets)):
        sd += [nets[i].state_dict()]
    sd_alpha = {}
    for k in sd[0].keys():
        comb_net = alpha[0] * sd[0][k].to(device)
        for i in range(1, len(nets)):
            comb_net += alpha[i] * sd[i][k].to(device)
        sd_alpha[k] = comb_net
    net.load_state_dict(sd_alpha)
    return net


def cosine_schedule_with_warmup_lr_lambda(
    current_step: int, *, base_lr: float, num_warmup_steps: int, num_training_steps: int, min_ratio: float = 0.0, num_cycles: float = 0.5
):
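    # Linear warmup from 0 to base_lr over num_warmup_steps, then cosine decay
    # from base_lr down to min_ratio * base_lr over the remaining training steps.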
    if current_step < num_warmup_steps:
        return base_lr * float(current_step) / float(max(1, num_warmup_steps))

    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
    return base_lr * (min_ratio + max(0.0, (1 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))))


def init_train_state(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, rank: int, world_size: int):
    # Estimated total training steps
    total_steps = int(config.epochs * train_metadata.total_groups * train_metadata.mean_puzzle_examples / config.global_batch_size)

    # Model
    model, optimizers, optimizer_lrs = create_model(config, train_metadata, rank=rank, world_size=world_size)

    return TrainState(
        step=0,
        total_steps=total_steps,

        model=model,
        optimizers=optimizers,
        optimizer_lrs=optimizer_lrs,
        carry=None
    )


def save_train_state(config: PretrainConfig, train_state: TrainState):
    # FIXME: Only the model weights are saved; optimizer state is not.
    if config.checkpoint_path is None:
        return

    os.makedirs(config.checkpoint_path, exist_ok=True)
    torch.save(train_state.model.state_dict(), os.path.join(config.checkpoint_path, f"step_{train_state.step}"))


def load_checkpoint(model: nn.Module, config: PretrainConfig):
    if config.load_checkpoint is not None:
        print(f"Loading checkpoint {config.load_checkpoint}")

        # Load state dict
        state_dict = torch.load(config.load_checkpoint, map_location="cuda")

        # Resize and reset the puzzle embedding if needed
        puzzle_emb_name = "_orig_mod.model.inner.puzzle_emb.weights"
        expected_shape: torch.Size = model.model.puzzle_emb.weights.shape  # type: ignore
        if puzzle_emb_name in state_dict:
            puzzle_emb = state_dict[puzzle_emb_name]
            if puzzle_emb.shape != expected_shape:
                print(f"Resetting puzzle embedding as shape is different. Found {puzzle_emb.shape}, expected {expected_shape}")
                # Re-initialize using the mean
                state_dict[puzzle_emb_name] = (
                    torch.mean(puzzle_emb, dim=0, keepdim=True).expand(expected_shape).contiguous()
                )
        model.load_state_dict(state_dict, assign=True)


def compute_lr(base_lr: float, config: PretrainConfig, train_state: TrainState):
    return cosine_schedule_with_warmup_lr_lambda(
        current_step=train_state.step,
        base_lr=base_lr,
        num_warmup_steps=round(config.lr_warmup_steps),
        num_training_steps=train_state.total_steps,
        min_ratio=config.lr_min_ratio
    )


def create_evaluators(config: PretrainConfig, eval_metadata: PuzzleDatasetMetadata) -> List[Any]:
    data_paths = config.data_paths_test if len(config.data_paths_test) > 0 else config.data_paths
    # Initialize evaluators
    evaluators = []
    for cfg in config.evaluators:
        for data_path in data_paths:
            evaluator = load_model_class(cfg.name, "evaluators.")(
                data_path=data_path, eval_metadata=eval_metadata, **cfg.__pydantic_extra__
            )  # type: ignore
            evaluators.append(evaluator)

    return evaluators


def train_batch(config: PretrainConfig, train_state: TrainState, batch: Any, global_batch_size: int, rank: int, world_size: int):
    train_state.step += 1
    if train_state.step > train_state.total_steps:  # Train for at most total_steps
        return

    # To device
    batch = {k: v.cuda() for k, v in batch.items()}

    # Init carry if it is None
    if train_state.carry is None:
        with torch.device("cuda"):
            train_state.carry = train_state.model.initial_carry(batch)  # type: ignore

    # Forward
    train_state.carry, loss, metrics, _, _ = train_state.model(carry=train_state.carry, batch=batch, return_keys=[])
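    # Scale by 1 / global_batch_size before backward (assuming the loss head returns a
    # summed loss), so the gradients summed by the all-reduce below form a global mean.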
    ((1 / global_batch_size) * loss).backward()

    # Allreduce
    if world_size > 1:
        for param in train_state.model.parameters():
            if param.grad is not None:
                dist.all_reduce(param.grad)

    # Apply optimizer
    lr_this_step = None
    for optim, base_lr in zip(train_state.optimizers, train_state.optimizer_lrs):
        lr_this_step = compute_lr(base_lr, config, train_state)

        for param_group in optim.param_groups:
            param_group['lr'] = lr_this_step

        optim.step()
        optim.zero_grad()

    # Reduce metrics
    if len(metrics):
        assert not any(v.requires_grad for v in metrics.values())

        metric_keys = list(sorted(metrics.keys()))  # Sort keys to guarantee all processes use the same order.
        # Reduce and reconstruct
        metric_values = torch.stack([metrics[k] for k in metric_keys])
        if world_size > 1:
            dist.reduce(metric_values, dst=0)

        if rank == 0:
            metric_values = metric_values.cpu().numpy()
            reduced_metrics = {k: metric_values[i] for i, k in enumerate(metric_keys)}

            # Postprocess
            count = max(reduced_metrics["count"], 1)  # Avoid NaNs
            reduced_metrics = {f"train/{k}": v / (global_batch_size if k.endswith("loss") else count) for k, v in reduced_metrics.items()}

            reduced_metrics["train/lr"] = lr_this_step
            return reduced_metrics


def evaluate(
    config: PretrainConfig,
    train_state: TrainState,
    eval_loader: torch.utils.data.DataLoader,
    eval_metadata: PuzzleDatasetMetadata,
    evaluators: List[Any],
    rank: int,
    world_size: int,
    cpu_group: Optional[dist.ProcessGroup],
):
    reduced_metrics = None

    with torch.inference_mode():
        return_keys = set(config.eval_save_outputs)
        for evaluator in evaluators:
            evaluator.begin_eval()
            return_keys.update(evaluator.required_outputs)

        # Run evaluation
        set_ids = {k: idx for idx, k in enumerate(eval_metadata.sets)}

        save_preds = {}

        metric_keys = []
        metric_values = None

        carry = None
        processed_batches = 0

        for set_name, batch, global_batch_size in eval_loader:
            processed_batches += 1
            if rank == 0:
                print(f"Processing batch {processed_batches}: {set_name}")

            # To device
            batch = {k: v.cuda() for k, v in batch.items()}
            with torch.device("cuda"):
                carry = train_state.model.initial_carry(batch)  # type: ignore

            # Forward
            inference_steps = 0
            while True:
                carry, loss, metrics, preds, all_finish = train_state.model(
                    carry=carry, batch=batch, return_keys=return_keys
                )
                inference_steps += 1

                if all_finish:
                    break

            if rank == 0:
                print(f"  Completed inference in {inference_steps} steps")

            for collection in (batch, preds):
                for k, v in collection.items():
                    if k in config.eval_save_outputs:
                        save_preds.setdefault(k, [])
                        save_preds[k].append(v.cpu())  # Move to CPU to save GPU memory

            for evaluator in evaluators:
                evaluator.update_batch(batch, preds)

            del carry, loss, preds, batch, all_finish

            # Aggregate metrics
            set_id = set_ids[set_name]

            if metric_values is None:
                metric_keys = list(
                    sorted(metrics.keys())
                )  # Sort keys to guarantee all processes use the same order.
                metric_values = torch.zeros(
                    (len(set_ids), len(metrics.values())), dtype=torch.float32, device="cuda"
                )

            metric_values[set_id] += torch.stack([metrics[k] for k in metric_keys])

            del metrics

        # Concatenate saved predictions
        save_preds = {k: torch.cat(v, dim=0) for k, v in save_preds.items()}

        # Save predictions
        if config.checkpoint_path is not None and len(save_preds):
            # Each rank saves its predictions independently
            os.makedirs(config.checkpoint_path, exist_ok=True)
            torch.save(
                save_preds, os.path.join(config.checkpoint_path, f"step_{train_state.step}_all_preds.{rank}")
            )

        del save_preds

        # Reduce to rank 0
        if metric_values is not None:
            if world_size > 1:
                dist.reduce(metric_values, dst=0)

            if rank == 0:
                reduced_metrics = metric_values.cpu().numpy()
                reduced_metrics = {
                    set_name: {
                        metric_name: reduced_metrics[set_id, metric_id]
                        for metric_id, metric_name in enumerate(metric_keys)
                    }
                    for set_name, set_id in set_ids.items()
                }

                # Postprocess
                for set_name, m in reduced_metrics.items():
                    count = m.pop("count")
                    reduced_metrics[set_name] = {k: v / count for k, v in m.items()}

        # Run evaluators
        if rank == 0:
            print(f"\nRunning {len(evaluators)} evaluator(s)...")

        for i, evaluator in enumerate(evaluators):
            if rank == 0:
                print(f"Running evaluator {i+1}/{len(evaluators)}: {evaluator.__class__.__name__}")

            # Path for saving
            evaluator_save_path = None
            if config.checkpoint_path is not None:
                evaluator_save_path = os.path.join(
                    config.checkpoint_path,
                    f"evaluator_{evaluator.__class__.__name__}_step_{train_state.step}",
                )
                os.makedirs(evaluator_save_path, exist_ok=True)

            # Run and log
            metrics = evaluator.result(evaluator_save_path, rank=rank, world_size=world_size, group=cpu_group)
            if rank == 0 and metrics is not None:
                if reduced_metrics is None:
                    reduced_metrics = {}

                reduced_metrics.update(metrics)
                print(f"  Completed {evaluator.__class__.__name__}")

        if rank == 0:
            print("All evaluators completed!")

    return reduced_metrics


def save_code_and_config(config: PretrainConfig):
    if config.checkpoint_path is None or wandb.run is None:
        return

    os.makedirs(config.checkpoint_path, exist_ok=True)

    # Copy code
    code_list = [
        get_model_source_path(config.arch.name),
        get_model_source_path(config.arch.loss.name)
    ]
    for code_file in code_list:
        if code_file is not None:
            code_name = os.path.basename(code_file)

            shutil.copy(code_file, os.path.join(config.checkpoint_path, code_name))

    # Dump config as yaml
    config_file = os.path.join(config.checkpoint_path, "all_config.yaml")
    with open(config_file, "wt") as f:
        yaml.dump(config.model_dump(), f)

    # Log code
    wandb.run.log_code(config.checkpoint_path)


def load_synced_config(hydra_config: DictConfig, rank: int, world_size: int) -> PretrainConfig:
    objects = [None]
    if rank == 0:
        config = PretrainConfig(**hydra_config)  # type: ignore

        # Naming
        if config.project_name is None:
            config.project_name = f"{os.path.basename(config.data_paths[0]).capitalize()}-ACT-torch"
        if config.run_name is None:
            config.run_name = f"{config.arch.name.split('@')[-1]} {coolname.generate_slug(2)}"
        if config.checkpoint_path is None:
            config.checkpoint_path = os.path.join("checkpoints", config.project_name, config.run_name)

        objects = [config]

    if world_size > 1:
        dist.broadcast_object_list(objects, src=0)

    return objects[0]  # type: ignore


@hydra.main(config_path="config", config_name="cfg_pretrain", version_base=None)
def launch(hydra_config: DictConfig):
    RANK = 0
    WORLD_SIZE = 1
    CPU_PROCESS_GROUP = None

    # Initialize distributed training if in a distributed environment (e.g. torchrun)
    if "LOCAL_RANK" in os.environ:
        # Initialize distributed, default device and dtype
        dist.init_process_group(backend="nccl")

        RANK = dist.get_rank()
        WORLD_SIZE = dist.get_world_size()

        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

        # CPU GLOO process group
        CPU_PROCESS_GROUP = dist.new_group(backend="gloo")
        assert (
            dist.get_rank(CPU_PROCESS_GROUP) == RANK and dist.get_world_size(CPU_PROCESS_GROUP) == WORLD_SIZE
        )

    # Load synced config
    config = load_synced_config(hydra_config, rank=RANK, world_size=WORLD_SIZE)

    # Seed RNGs to ensure consistency
    torch.random.manual_seed(config.seed + RANK)

    # Dataset
    train_epochs_per_iter = config.eval_interval if config.eval_interval is not None else config.epochs
    total_iters = config.epochs // train_epochs_per_iter

    assert config.epochs % train_epochs_per_iter == 0, "Eval interval must be a divisor of total epochs."

    train_loader, train_metadata = create_dataloader(config, "train", test_set_mode=False, epochs_per_iter=train_epochs_per_iter, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)
    try:
        eval_loader, eval_metadata = create_dataloader(config, "test", test_set_mode=True, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)
    except Exception:
        print("NO EVAL DATA FOUND")
        eval_loader = eval_metadata = None

    try:
        evaluators = create_evaluators(config, eval_metadata)
    except Exception:
        print("No evaluator found")
        evaluators = []

    # Train state
    train_state = init_train_state(config, train_metadata, rank=RANK, world_size=WORLD_SIZE)

    # Progress bar and logger
    progress_bar = None
    ema_helper = None
    if RANK == 0:
        progress_bar = tqdm.tqdm(total=train_state.total_steps)
        wandb.init(project=config.project_name, name=config.run_name, config=config.model_dump(), settings=wandb.Settings(_disable_stats=True))  # type: ignore
        wandb.log({"num_params": sum(x.numel() for x in train_state.model.parameters())}, step=0)
        save_code_and_config(config)
    if config.ema:
        print('Setup EMA')
        ema_helper = EMAHelper(mu=config.ema_rate)
        ema_helper.register(train_state.model)

    # Training loop
    for _iter_id in range(total_iters):
        print(f"[Rank {RANK}, World Size {WORLD_SIZE}]: Epoch {_iter_id * train_epochs_per_iter}")

        ############ Train Iter
        if RANK == 0:
            print("TRAIN")
        train_state.model.train()
        for set_name, batch, global_batch_size in train_loader:
            metrics = train_batch(config, train_state, batch, global_batch_size, rank=RANK, world_size=WORLD_SIZE)

            if RANK == 0 and metrics is not None:
                wandb.log(metrics, step=train_state.step)
                progress_bar.update(train_state.step - progress_bar.n)  # type: ignore
            if config.ema:
                ema_helper.update(train_state.model)

        if _iter_id >= config.min_eval_interval:
            ############ Evaluation
            if RANK == 0:
                print("EVALUATE")
            if config.ema:
                print("SWITCH TO EMA")
                train_state_eval = copy.deepcopy(train_state)
                train_state_eval.model = ema_helper.ema_copy(train_state_eval.model)
            else:
                train_state_eval = train_state
            train_state_eval.model.eval()
            metrics = evaluate(config,
                               train_state_eval,
                               eval_loader,
                               eval_metadata,
                               evaluators,
                               rank=RANK,
                               world_size=WORLD_SIZE,
                               cpu_group=CPU_PROCESS_GROUP)

            if RANK == 0 and metrics is not None:
                wandb.log(metrics, step=train_state.step)

            ############ Checkpointing
            if RANK == 0:
                print("SAVE CHECKPOINT")
            if RANK == 0 and (config.checkpoint_every_eval or (_iter_id == total_iters - 1)):
                save_train_state(config, train_state_eval)

            if config.ema:
                del train_state_eval

    # Finalize
    if dist.is_initialized():
        dist.destroy_process_group()
    wandb.finish()


if __name__ == "__main__":
    launch()
@@ -0,0 +1,250 @@
import os
import json
from typing import Tuple, List, Dict, Optional
import numpy as np
import pydantic

import torch
from torch.utils.data import IterableDataset, get_worker_info

from models.losses import IGNORE_LABEL_ID
from dataset.common import PuzzleDatasetMetadata

from argdantic import ArgParser
from pydantic import BaseModel


def _sample_batch(rng: np.random.Generator, group_order: np.ndarray, puzzle_indices: np.ndarray, group_indices: np.ndarray, start_index: int, global_batch_size: int):
    # Pack examples into a full batch
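    # Hierarchy: group_indices maps groups to puzzle ranges, puzzle_indices maps
    # puzzles to example ranges. One puzzle is drawn per group, then examples are
    # sampled from it without replacement until the batch is full.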
|
||||||
|
batch = []
|
||||||
|
batch_puzzle_indices = []
|
||||||
|
current_size = 0
|
||||||
|
|
||||||
|
while (start_index < group_order.size) and (current_size < global_batch_size):
|
||||||
|
# Pick a group and a puzzle from that group
|
||||||
|
group_id = group_order[start_index]
|
||||||
|
puzzle_id = rng.integers(group_indices[group_id], group_indices[group_id + 1])
|
||||||
|
start_index += 1
|
||||||
|
|
||||||
|
# Get range of the puzzle
|
||||||
|
puzzle_start = puzzle_indices[puzzle_id]
|
||||||
|
puzzle_size = int(puzzle_indices[puzzle_id + 1] - puzzle_start)
|
||||||
|
|
||||||
|
append_size = min(puzzle_size, global_batch_size - current_size)
|
||||||
|
|
||||||
|
# Put into batch
|
||||||
|
batch_puzzle_indices.append(np.full(append_size, puzzle_id, dtype=np.int32))
|
||||||
|
batch.append(puzzle_start + np.random.choice(puzzle_size, append_size, replace=False))
|
||||||
|
|
||||||
|
current_size += append_size
|
||||||
|
|
||||||
|
return start_index, np.concatenate(batch), np.concatenate(batch_puzzle_indices)
|
||||||
|
|
||||||
|
|
||||||
|
class PuzzleDatasetConfig(pydantic.BaseModel):
|
||||||
|
seed: int
|
||||||
|
dataset_paths: List[str]
|
||||||
|
global_batch_size: int
|
||||||
|
test_set_mode: bool
|
||||||
|
epochs_per_iter: int # Batch X epochs in an iteration to reduce overhead.
|
||||||
|
rank: int
|
||||||
|
num_replicas: int
|
||||||
|
|
||||||
|
class PuzzleDataset(IterableDataset):
|
||||||
|
def __init__(self, config: PuzzleDatasetConfig, split: str = "train"):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.split = split
|
||||||
|
|
||||||
|
# Merge multiple metadata
|
||||||
|
prev_seq_len = None
|
||||||
|
prev_vocab_size = None
|
||||||
|
prev_pad_id = None
|
||||||
|
prev_ignore_label_id = None
|
||||||
|
prev_blank_identifier_id = None
|
||||||
|
prev_sets = None
|
||||||
|
prev_num_identifiers = None
|
||||||
|
mean_puzzle_examples = 0
|
||||||
|
total_puzzles = 0
|
||||||
|
total_groups = 0
|
||||||
|
num_identifiers = 0
|
||||||
|
for dataset_path in config.dataset_paths:
|
||||||
|
current_metadata = self._load_metadata(dataset_path)
|
||||||
|
if prev_seq_len is None:
|
||||||
|
prev_seq_len = current_metadata.seq_len
|
||||||
|
prev_vocab_size = current_metadata.vocab_size
|
||||||
|
prev_pad_id = current_metadata.pad_id
|
||||||
|
prev_ignore_label_id = current_metadata.ignore_label_id
|
||||||
|
prev_blank_identifier_id = current_metadata.blank_identifier_id
|
||||||
|
prev_sets = current_metadata.sets
|
||||||
|
prev_num_identifiers = current_metadata.num_puzzle_identifiers
|
||||||
|
else:
|
||||||
|
assert prev_seq_len == current_metadata.seq_len
|
||||||
|
assert prev_vocab_size == current_metadata.vocab_size
|
||||||
|
assert prev_pad_id == current_metadata.pad_id
|
||||||
|
assert prev_ignore_label_id == current_metadata.ignore_label_id
|
||||||
|
assert prev_blank_identifier_id == current_metadata.blank_identifier_id
|
||||||
|
assert prev_sets == current_metadata.sets
|
||||||
|
assert prev_num_identifiers == current_metadata.num_puzzle_identifiers
|
||||||
|
mean_puzzle_examples += current_metadata.mean_puzzle_examples*current_metadata.total_puzzles
|
||||||
|
total_puzzles += current_metadata.total_puzzles
|
||||||
|
total_groups += current_metadata.total_groups
|
||||||
|
num_identifiers += current_metadata.num_puzzle_identifiers
|
||||||
|
mean_puzzle_examples = mean_puzzle_examples / total_puzzles
|
||||||
|
|
||||||
|
self.metadata = PuzzleDatasetMetadata(
|
||||||
|
seq_len=prev_seq_len,
|
||||||
|
vocab_size=prev_vocab_size,
|
||||||
|
pad_id=prev_pad_id,
|
||||||
|
ignore_label_id=prev_ignore_label_id,
|
||||||
|
blank_identifier_id=prev_blank_identifier_id,
|
||||||
|
num_puzzle_identifiers=num_identifiers,
|
||||||
|
total_groups=total_groups,
|
||||||
|
mean_puzzle_examples=mean_puzzle_examples,
|
||||||
|
total_puzzles=total_puzzles,
|
||||||
|
sets=prev_sets
|
||||||
|
)
|
||||||
|
|
||||||
|
# Checks
|
||||||
|
assert self.config.global_batch_size % self.config.num_replicas == 0, f"Global batch size {self.config.global_batch_size} must be multiples of nodes {self.config.num_replicas}."
|
||||||
|
self.local_batch_size = self.config.global_batch_size // self.config.num_replicas
|
||||||
|
|
||||||
|
# State
|
||||||
|
self._data = None
|
||||||
|
self._iters = 0
|
||||||
|
|
||||||
|

    def _load_metadata(self, dataset_path) -> PuzzleDatasetMetadata:
        with open(os.path.join(dataset_path, self.split, "dataset.json"), "r") as f:
            return PuzzleDatasetMetadata(**json.load(f))

    def _lazy_load_dataset(self):
        if self._data is not None:
            return

        field_mmap_modes = {
            "inputs": "r",
            "labels": "r",

            # Keep indices in memory
            "puzzle_identifiers": None,
            "puzzle_indices": None,
            "group_indices": None
        }

        # Load data; sets from the second dataset path onward are suffixed with their index
        self._data = {}
        for set_name in self.metadata.sets:  # Load subset
            for i, dataset_path in enumerate(self.config.dataset_paths):
                if i > 0:
                    set_name_ = set_name + str(i)
                else:
                    set_name_ = set_name
                self._data[set_name_] = {
                    field_name: np.load(os.path.join(dataset_path, self.split, f"{set_name}__{field_name}.npy"), mmap_mode=mmap_mode)
                    for field_name, mmap_mode in field_mmap_modes.items()
                }

    def _collate_batch(self, batch):
        # Convert dtype
        batch = {k: v.astype(np.int32) for k, v in batch.items()}

        # Convert ignore label IDs
        if self.metadata.ignore_label_id is not None:
            batch["labels"][batch["labels"] == self.metadata.ignore_label_id] = IGNORE_LABEL_ID

        # Pad
        if batch["puzzle_identifiers"].size < self.local_batch_size:
            pad_size = self.local_batch_size - batch["puzzle_identifiers"].size
            pad_values = {
                "inputs": self.metadata.pad_id,
                "labels": IGNORE_LABEL_ID,
                "puzzle_identifiers": self.metadata.blank_identifier_id
            }
            batch = {k: np.pad(v, ((0, pad_size), ) + ((0, 0), ) * (v.ndim - 1), constant_values=pad_values[k]) for k, v in batch.items()}

        # To tensor
        return {k: torch.from_numpy(v) for k, v in batch.items()}

    def _iter_test(self):
        for set_i, (set_name, dataset) in enumerate(self._data.items()):  # type: ignore
            total_examples = len(dataset["inputs"])

            # Load examples one by one
            start_index = 0
            while start_index < total_examples:
                # Compute indices
                end_index = min(total_examples, start_index + self.config.global_batch_size)

                local_start = start_index + self.config.rank * self.local_batch_size
                local_end = min(start_index + (self.config.rank + 1) * self.local_batch_size, end_index)

                # Get batch of examples, and also puzzle IDs
                puzzle_indices = []
                puzzle_index = np.searchsorted(dataset["puzzle_indices"], local_start, side="right") - 1
                for i in range(local_start, local_end):
                    while puzzle_index + 1 < len(dataset["puzzle_indices"]) and i >= dataset["puzzle_indices"][puzzle_index + 1]:
                        puzzle_index += 1

                    puzzle_indices.append(puzzle_index)

                batch = self._collate_batch({
                    "inputs": dataset["inputs"][local_start: local_end],
                    "labels": dataset["labels"][local_start: local_end],
                    "puzzle_identifiers": dataset["puzzle_identifiers"][puzzle_indices]
                })

                yield set_name, batch, end_index - start_index

                # Advance to next batch
                start_index += self.config.global_batch_size
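
    # Worked example for the index mapping in _iter_test (illustrative values
    # only, not from any real dataset): if dataset["puzzle_indices"] == [0, 3, 5],
    # examples 0-2 belong to puzzle 0 and examples 3-4 to puzzle 1, so
    # searchsorted plus the advance loop assigns each example the identifier of
    # the puzzle that contains it.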

    def _iter_train(self):
        for set_name, dataset in self._data.items():  # type: ignore
            # Increase iteration count (each iteration spans epochs_per_iter epochs)
            self._iters += 1

            # Randomly shuffle groups
            rng = np.random.Generator(np.random.Philox(seed=self.config.seed + self._iters))

            group_order = np.concatenate([rng.permutation(dataset["group_indices"].size - 1) for _i in range(self.config.epochs_per_iter)])
            start_index = 0

            while start_index < group_order.size:
                start_index, batch_indices, batch_puzzle_indices = _sample_batch(
                    rng,
                    group_order=group_order,
                    puzzle_indices=dataset["puzzle_indices"],
                    group_indices=dataset["group_indices"],
                    start_index=start_index,
                    global_batch_size=self.config.global_batch_size,
                )

                # Select current rank and collate
                global_effective_batch_size = batch_puzzle_indices.size  # Global effective batch size, excluding pads

                # Drop last (incomplete) batch
                if global_effective_batch_size < self.config.global_batch_size:
                    break

                batch_indices = batch_indices[self.config.rank * self.local_batch_size: (self.config.rank + 1) * self.local_batch_size]
                batch_puzzle_indices = batch_puzzle_indices[self.config.rank * self.local_batch_size: (self.config.rank + 1) * self.local_batch_size]
                batch = self._collate_batch({
                    "inputs": dataset["inputs"][batch_indices],
                    "labels": dataset["labels"][batch_indices],
                    "puzzle_identifiers": dataset["puzzle_identifiers"][batch_puzzle_indices]
                })

                yield set_name, batch, global_effective_batch_size

    def __iter__(self):
        worker_info = get_worker_info()
        assert worker_info is None or worker_info.num_workers == 1, "Multithreaded data loading is not currently supported."

        self._lazy_load_dataset()

        # Iterate using specified mode
        if self.config.test_set_mode:
            yield from self._iter_test()
        else:
            yield from self._iter_train()
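
For orientation, here is a minimal usage sketch of the dataset class above. The constructor signature and the config fields are assumed from what the methods read (`dataset_paths`, `seed`, `global_batch_size`, `num_replicas`, `rank`, `epochs_per_iter`, `test_set_mode`, plus a `split` name), not verified against the rest of the repository:

```python
from torch.utils.data import DataLoader

# Hypothetical construction; `config` must expose the fields listed above.
dataset = PuzzleDataset(config, split="train")

# The dataset batches and shards internally, so automatic batching is disabled
# (batch_size=None), and __iter__ asserts at most one worker.
loader = DataLoader(dataset, batch_size=None, num_workers=1)

for set_name, batch, effective_batch_size in loader:
    inputs = batch["inputs"]           # int32, padded with metadata.pad_id
    labels = batch["labels"]           # padded positions hold IGNORE_LABEL_ID
    ids = batch["puzzle_identifiers"]  # one identifier per example
```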
@@ -0,0 +1,20 @@
torch
adam-atan2
einops
tqdm
coolname
pydantic
argdantic
wandb
omegaconf
hydra-core
huggingface_hub
packaging
ninja
wheel
setuptools
setuptools-scm
pydantic-core
huggingface_hub
numba
triton
@@ -0,0 +1,19 @@
import importlib
import inspect


def load_model_class(identifier: str, prefix: str = "models."):
    module_path, class_name = identifier.split('@')

    # Import the module
    module = importlib.import_module(prefix + module_path)
    cls = getattr(module, class_name)

    return cls


def get_model_source_path(identifier: str, prefix: str = "models."):
    module_path, _class_name = identifier.split('@')

    module = importlib.import_module(prefix + module_path)
    return inspect.getsourcefile(module)
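
A short, hypothetical usage example of the `@` identifier convention; the module and class names below are illustrative, not taken from this repository:

```python
# Resolves models.my_module.MyModel from the identifier "my_module@MyModel".
ModelClass = load_model_class("my_module@MyModel")
model = ModelClass()  # construct with whatever arguments the class expects

# Path of the source file defining the module, e.g. for logging/reproducibility.
source_path = get_model_source_path("my_module@MyModel")
```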