Alexia Jolicoeur-Martineau
2025-10-07 09:26:04 -04:00
commit 8120f2bdf7
39 changed files with 27428 additions and 0 deletions

21
LICENSE Normal file

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2025. Samsung Electronics Co., Ltd. All Rights Reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

147
README.md Normal file

@@ -0,0 +1,147 @@
# Less is More: Recursive Reasoning with Tiny Networks
This is the codebase for the paper "Less is More: Recursive Reasoning with Tiny Networks", where we present a recursive reasoning approach that achieves amazing scores of 45% on ARC-AGI-1 and 8% on ARC-AGI-2 using a tiny 7M-parameter neural network.
[Paper](https://arxiv.org/abs/2510.04871)
### How TRM works
Tiny Recursion Model (TRM) recursively improves its predicted answer y with a tiny network. It starts with the embedded input question x, an initial embedded answer y, and a latent z. For up to K improvement steps, it tries to improve its answer y. It does so by i) recursively updating its latent z n times given the question x, current answer y, and current latent z (recursive reasoning), and then ii) updating its answer y given the current answer y and current latent z. This recursive process allows the model to progressively improve its answer (potentially correcting errors from its previous answer) in an extremely parameter-efficient manner while minimizing overfitting.
<p align="center">
<img src="assets/TRM_fig.png" alt="TRM-Figure" style="width:50%">
</p>
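The refinement loop described above can be sketched in a few lines of Python. This is a toy illustration under stated assumptions, not the code in this repository: `make_toy_net` and `trm_refine` are hypothetical names, vectors are plain lists, the "network" is a single random tanh layer, and reusing the same toy network for both updates is an assumption of the sketch. Only the loop structure (n latent updates, then one answer update, repeated K times) follows the description.

```python
import math
import random


def make_toy_net(dim, seed=0):
    """Hypothetical stand-in for the tiny network: one random tanh layer."""
    rng = random.Random(seed)
    weights = [[rng.uniform(-0.5, 0.5) for _ in range(dim)] for _ in range(dim)]

    def net(*inputs):
        # Sum the input vectors element-wise, then apply the layer.
        summed = [sum(values) for values in zip(*inputs)]
        return [math.tanh(sum(w * s for w, s in zip(row, summed))) for row in weights]

    return net


def trm_refine(x, y, z, net, K=3, n=6):
    """Refine answer y for K improvement steps, with n latent updates per step."""
    for _ in range(K):
        for _ in range(n):
            z = net(x, y, z)  # (i) update the latent from question, answer, latent
        y = net(y, z)         # (ii) update the answer from answer and latent
    return y, z


dim = 4
x = [1.0, 0.0, -1.0, 0.5]  # embedded question (toy values)
y0 = [0.0] * dim           # initial embedded answer
z0 = [0.0] * dim           # initial latent
y, z = trm_refine(x, y0, z0, make_toy_net(dim))
print(len(y), len(z))
```

The network is called K * (n + 1) times in total, which is where the effective depth comes from without adding parameters.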
### Requirements
- Python 3.10 (or similar)
- CUDA 12.6.0 (or similar)
```bash
pip install --upgrade pip wheel setuptools
pip install --pre --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126 # install torch based on your cuda version
pip install -r requirements.txt # install requirements
pip install --no-cache-dir --no-build-isolation adam-atan2
wandb login YOUR-LOGIN # login if you want the logger to sync results to your Weights & Biases (https://wandb.ai/)
```
### Dataset Preparation
```bash
# ARC-AGI-1
python -m dataset.build_arc_dataset \
--input-file-prefix kaggle/combined/arc-agi \
--output-dir data/arc1concept-aug-1000 \
--subsets training evaluation concept \
--test-set-name evaluation
# ARC-AGI-2
python -m dataset.build_arc_dataset \
--input-file-prefix kaggle/combined/arc-agi \
--output-dir data/arc2concept-aug-1000 \
--subsets training2 evaluation2 concept \
--test-set-name evaluation2
## Note: You cannot train on both ARC-AGI-1 and ARC-AGI-2 and evaluate them both because ARC-AGI-2 training data contains some ARC-AGI-1 eval data
# Sudoku-Extreme
python dataset/build_sudoku_dataset.py --output-dir data/sudoku-extreme-1k-aug-1000 --subsample-size 1000 --num-aug 1000 # 1000 examples, 1000 augments
# Maze-Hard
python dataset/build_maze_dataset.py # 1000 examples, 8 augments
```
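For orientation, the dataset builders in this commit write one folder per split containing `all__<field>.npy` arrays and a `dataset.json` metadata file, plus `identifiers.json` at the top level; the ARC builder additionally writes `test_puzzles.json`. The sketch below only lists those paths so a build can be sanity-checked. `expected_dataset_files` is a hypothetical helper name, not part of the codebase.

```python
import os

# Array names saved by the builders (one .npy file per field and split).
FIELDS = ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]


def expected_dataset_files(output_dir, splits=("train", "test"), is_arc=False):
    """Return the file paths a finished build is expected to contain."""
    paths = []
    for split in splits:
        for field in FIELDS:
            paths.append(os.path.join(output_dir, split, f"all__{field}.npy"))
        paths.append(os.path.join(output_dir, split, "dataset.json"))
    paths.append(os.path.join(output_dir, "identifiers.json"))
    if is_arc:
        paths.append(os.path.join(output_dir, "test_puzzles.json"))
    return paths


for path in expected_dataset_files("data/sudoku-extreme-1k-aug-1000"):
    print(path, "ok" if os.path.isfile(path) else "missing")
```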
## Experiments
### ARC-AGI (assuming 4 H100 GPUs):
```bash
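# data_paths must match an --output-dir from Dataset Preparation
# (data/arc1concept-aug-1000 for ARC-AGI-1, data/arc2concept-aug-1000 for ARC-AGI-2)
run_name="pretrain_att_arc12concept_4"
torchrun --nproc-per-node 4 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 --nnodes=1 pretrain.py \
arch=trm \
data_paths="[data/arc1concept-aug-1000]" \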
arch.L_layers=2 \
arch.H_cycles=3 arch.L_cycles=4 \
+run_name=${run_name} ema=True
```
*Runtime:* ~3 days
### Sudoku-Extreme (assuming 1 L40S GPU):
```bash
run_name="pretrain_mlp_t_sudoku"
python pretrain.py \
arch=trm \
data_paths="[data/sudoku-extreme-1k-aug-1000]" \
evaluators="[]" \
epochs=50000 eval_interval=5000 \
lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0 \
arch.mlp_t=True arch.pos_encodings=none \
arch.L_layers=2 \
arch.H_cycles=3 arch.L_cycles=6 \
+run_name=${run_name} ema=True
run_name="pretrain_att_sudoku"
python pretrain.py \
arch=trm \
data_paths="[data/sudoku-extreme-1k-aug-1000]" \
evaluators="[]" \
epochs=50000 eval_interval=5000 \
lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0 \
arch.L_layers=2 \
arch.H_cycles=3 arch.L_cycles=6 \
+run_name=${run_name} ema=True
```
*Runtime:* < 36 hours
### Maze-Hard (assuming 4 L40S GPUs):
```bash
run_name="pretrain_att_maze30x30"
torchrun --nproc-per-node 4 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 --nnodes=1 pretrain.py \
arch=trm \
data_paths="[data/maze-30x30-hard-1k]" \
evaluators="[]" \
epochs=50000 eval_interval=5000 \
lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0 \
arch.L_layers=2 \
arch.H_cycles=3 arch.L_cycles=4 \
+run_name=${run_name} ema=True
```
*Runtime:* < 24 hours
## Reference
If you find our work useful, please consider citing:
```bibtex
@misc{jolicoeurmartineau2025tinyrecursionmodel,
title={Less is More: Recursive Reasoning with Tiny Networks},
author={Alexia Jolicoeur-Martineau},
year={2025},
eprint={2510.04871},
archivePrefix={arXiv},
primaryClass={cs.AI},
url={https://arxiv.org/abs/2510.04871},
}
```
and the Hierarchical Reasoning Model (HRM):
```bibtex
@misc{wang2025hierarchicalreasoningmodel,
title={Hierarchical Reasoning Model},
author={Guan Wang and Jin Li and Yuhao Sun and Xing Chen and Changling Liu and Yue Wu and Meng Lu and Sen Song and Yasin Abbasi Yadkori},
year={2025},
eprint={2506.21734},
archivePrefix={arXiv},
primaryClass={cs.AI},
url={https://arxiv.org/abs/2506.21734},
}
```
This code is based on the Hierarchical Reasoning Model [code](https://github.com/sapientinc/HRM) and the Hierarchical Reasoning Model Analysis [code](https://github.com/arcprize/hierarchical-reasoning-model-analysis).

BIN
assets/TRM_fig.png Normal file

Binary file not shown. Size: 346 KiB

BIN
assets/TRM_pseudocode.png Normal file

Binary file not shown. Size: 261 KiB

24
config/arch/hrm.yaml Normal file

@@ -0,0 +1,24 @@
name: recursive_reasoning.hrm@HierarchicalReasoningModel_ACTV1
loss:
  name: losses@ACTLossHead
  loss_type: stablemax_cross_entropy
halt_exploration_prob: 0.1
halt_max_steps: 16
H_cycles: 2
L_cycles: 2
H_layers: 4
L_layers: 4
hidden_size: 512
num_heads: 8 # max(2, hidden_size // 64)
expansion: 4
puzzle_emb_ndim: ${.hidden_size}
pos_encodings: rope
forward_dtype: bfloat16
mlp_t: False # use mlp on L instead of transformer


@@ -0,0 +1,18 @@
name: recursive_reasoning.transformers_baseline@Model_ACTV2
loss:
name: losses@ACTLossHead
loss_type: stablemax_cross_entropy
halt_exploration_prob: 0.1
halt_max_steps: 16
H_cycles: 1 # kept for compatibility
H_layers: 8
hidden_size: 512
num_heads: 12
expansion: 4
puzzle_emb_ndim: ${.hidden_size}
pos_encodings: rope

26
config/arch/trm.yaml Normal file

@@ -0,0 +1,26 @@
name: recursive_reasoning.trm@TinyRecursiveReasoningModel_ACTV1
loss:
  name: losses@ACTLossHead
  loss_type: stablemax_cross_entropy
halt_exploration_prob: 0.1
halt_max_steps: 16
H_cycles: 3
L_cycles: 6
H_layers: 0
L_layers: 2
hidden_size: 512
num_heads: 8 # max(2, hidden_size // 64)
expansion: 4
puzzle_emb_ndim: ${.hidden_size}
pos_encodings: rope
forward_dtype: bfloat16
mlp_t: False # use mlp on L instead of transformer
puzzle_emb_len: 16 # if non-zero, the puzzle embedding length is set to this value
no_ACT_continue: True # No ACT continue loss; use only the sigmoid of the halt logit, which makes much more sense


@@ -0,0 +1,26 @@
name: recursive_reasoning.trm_hier6@TinyRecursiveReasoningModel_ACTV1
loss:
  name: losses@ACTLossHead
  loss_type: stablemax_cross_entropy
halt_exploration_prob: 0.1
halt_max_steps: 16
H_cycles: 3
L_cycles: 6
H_layers: 0
L_layers: 2
hidden_size: 512
num_heads: 8 # max(2, hidden_size // 64)
expansion: 4
puzzle_emb_ndim: ${.hidden_size}
pos_encodings: rope
forward_dtype: bfloat16
mlp_t: False # use mlp on L instead of transformer
puzzle_emb_len: 16 # if non-zero, the puzzle embedding length is set to this value
no_ACT_continue: True # No ACT continue loss; use only the sigmoid of the halt logit, which makes much more sense


@@ -0,0 +1,26 @@
name: recursive_reasoning.trm_singlez@TinyRecursiveReasoningModel_ACTV1
loss:
  name: losses@ACTLossHead
  loss_type: stablemax_cross_entropy
halt_exploration_prob: 0.1
halt_max_steps: 16
H_cycles: 3
L_cycles: 6
H_layers: 0
L_layers: 2
hidden_size: 512
num_heads: 8 # max(2, hidden_size // 64)
expansion: 4
puzzle_emb_ndim: ${.hidden_size}
pos_encodings: rope
forward_dtype: bfloat16
mlp_t: False # use mlp on L instead of transformer
puzzle_emb_len: 16 # if non-zero, the puzzle embedding length is set to this value
no_ACT_continue: True # No ACT continue loss; use only the sigmoid of the halt logit, which makes much more sense

42
config/cfg_pretrain.yaml Normal file

@@ -0,0 +1,42 @@
# ARC training config
defaults:
- arch: trm
- _self_
hydra:
  output_subdir: null
# Data path
data_paths: ['data/arc-aug-1000']
data_paths_test: []
evaluators:
- name: arc@ARC
# Hyperparams - Training
global_batch_size: 768
epochs: 100000
eval_interval: 10000
checkpoint_every_eval: True
lr: 1e-4
lr_min_ratio: 1.0
lr_warmup_steps: 2000
# Standard hyperparameter settings for LM, as used in Llama
beta1: 0.9
beta2: 0.95
weight_decay: 0.1
puzzle_emb_weight_decay: 0.1
# Hyperparams - Puzzle embeddings training
puzzle_emb_lr: 1e-2
seed: 0
min_eval_interval: 0 # when to start the eval
ema: False # use Exponential-Moving-Average
ema_rate: 0.999 # EMA-rate
freeze_weights: False # If True, freeze weights and only learn the embeddings


@@ -0,0 +1,341 @@
from typing import List, Tuple, Dict
from dataclasses import dataclass
import os
import json
import hashlib

import numpy as np
from argdantic import ArgParser
from pydantic import BaseModel

from dataset.common import PuzzleDatasetMetadata, dihedral_transform, inverse_dihedral_transform


cli = ArgParser()


class DataProcessConfig(BaseModel):
    input_file_prefix: str
    output_dir: str
    subsets: List[str]
    test_set_name: str
    test_set_name2: str = "your_test_set"
    seed: int = 42
    num_aug: int = 1000
    puzzle_identifiers_start: int = 1  # start > 1 to handle multiple datasets


ARCMaxGridSize = 30
ARCAugmentRetriesFactor = 5
PuzzleIdSeparator = "|||"


@dataclass
class ARCPuzzle:
    id: str
    examples: List[Tuple[np.ndarray, np.ndarray]]


def arc_grid_to_np(grid: List[List[int]]):
    arr = np.array(grid)

    # Shape check
    assert arr.ndim == 2
    assert arr.shape[0] <= ARCMaxGridSize and arr.shape[1] <= ARCMaxGridSize
    # Element check
    assert np.all((arr >= 0) & (arr <= 9))
    return arr.astype(np.uint8)


def np_grid_to_seq_translational_augment(inp: np.ndarray, out: np.ndarray, do_translation: bool):
    # PAD: 0, <eos>: 1, digits: 2 ... 11
    # Compute random top-left pad
    if do_translation:
        pad_r = np.random.randint(0, ARCMaxGridSize - max(inp.shape[0], out.shape[0]) + 1)
        pad_c = np.random.randint(0, ARCMaxGridSize - max(inp.shape[1], out.shape[1]) + 1)
    else:
        pad_r = pad_c = 0

    # Pad grid
    result = []
    for grid in [inp, out]:
        nrow, ncol = grid.shape
        grid = np.pad(grid + 2, ((pad_r, ARCMaxGridSize - pad_r - nrow), (pad_c, ARCMaxGridSize - pad_c - ncol)), constant_values=0)

        # Add <eos>
        eos_row, eos_col = pad_r + nrow, pad_c + ncol
        if eos_row < ARCMaxGridSize:
            grid[eos_row, pad_c:eos_col] = 1
        if eos_col < ARCMaxGridSize:
            grid[pad_r:eos_row, eos_col] = 1

        result.append(grid.flatten())

    return result


def grid_hash(grid: np.ndarray):
    assert grid.ndim == 2
    assert grid.dtype == np.uint8

    buffer = [x.to_bytes(1, byteorder='big') for x in grid.shape]
    buffer.append(grid.tobytes())

    return hashlib.sha256(b"".join(buffer)).hexdigest()


def puzzle_hash(puzzle: dict):
    # Hash the puzzle for checking equivalence
    hashes = []
    for example_type, example in puzzle.items():
        for input, label in example.examples:
            hashes.append(f"{grid_hash(input)}|{grid_hash(label)}")

    hashes.sort()
    return hashlib.sha256("|".join(hashes).encode()).hexdigest()


def aug(name: str):
    # Augment plan
    trans_id = np.random.randint(0, 8)
    mapping = np.concatenate([np.arange(0, 1, dtype=np.uint8), np.random.permutation(np.arange(1, 10, dtype=np.uint8))])  # Permute colors, Excluding "0" (black)

    name_with_aug_repr = f"{name}{PuzzleIdSeparator}t{trans_id}{PuzzleIdSeparator}{''.join(str(x) for x in mapping)}"

    def _map_grid(grid: np.ndarray):
        return dihedral_transform(mapping[grid], trans_id)

    return name_with_aug_repr, _map_grid


def inverse_aug(name: str):
    # Inverse the "aug" function
    if PuzzleIdSeparator not in name:
        return name, lambda x: x

    trans_id, perm = name.split(PuzzleIdSeparator)[-2:]
    trans_id = int(trans_id[1:])  # Remove "t" letter
    inv_perm = np.argsort(list(perm)).astype(np.uint8)

    def _map_grid(grid: np.ndarray):
        return inv_perm[inverse_dihedral_transform(grid, trans_id)]

    return name.split(PuzzleIdSeparator)[0], _map_grid


def convert_single_arc_puzzle(results: dict, name: str, puzzle: dict, aug_count: int, dest_mapping: Dict[str, Tuple[str, str]]):
    # Convert
    dests = set(dest_mapping.values())
    converted = {dest: ARCPuzzle(name, []) for dest in dests}
    for example_type, examples in puzzle.items():
        # Map to target split
        dest = dest_mapping[example_type]
        converted[dest].examples.extend([(arc_grid_to_np(example["input"]), arc_grid_to_np(example["output"])) for example in examples])

    group = [converted]

    # Augment
    if aug_count > 0:
        hashes = {puzzle_hash(converted)}

        for _trial in range(ARCAugmentRetriesFactor * aug_count):
            aug_name, _map_grid = aug(name)

            # Check duplicate
            augmented = {dest: ARCPuzzle(aug_name, [(_map_grid(input), _map_grid(label)) for (input, label) in puzzle.examples]) for dest, puzzle in converted.items()}
            h = puzzle_hash(augmented)
            if h not in hashes:
                hashes.add(h)
                group.append(augmented)

            if len(group) >= aug_count + 1:
                break

        if len(group) < aug_count + 1:
            print (f"[Puzzle {name}] augmentation not full, only {len(group)}")

    # Append
    for dest in dests:
        # Convert the examples
        dest_split, dest_set = dest

        results.setdefault(dest_split, {})
        results[dest_split].setdefault(dest_set, [])
        results[dest_split][dest_set].append([converted[dest] for converted in group])


def load_puzzles_arcagi(config: DataProcessConfig):
    train_examples_dest = ("train", "all")
    test_examples_map = {
        config.test_set_name: [(1.0, ("test", "all"))],
        config.test_set_name2: [(1.0, ("test", "all"))],
        "_default": [(1.0, ("train", "all"))]
    }

    test_puzzles = {}
    results = {}

    total_puzzles = 0
    for subset_name in config.subsets:
        # Load all puzzles in this subset
        with open(f"{config.input_file_prefix}_{subset_name}_challenges.json", "r") as f:
            puzzles = json.load(f)

        sols_filename = f"{config.input_file_prefix}_{subset_name}_solutions.json"
        if os.path.isfile(sols_filename):
            with open(sols_filename, "r") as f:
                sols = json.load(f)

            for puzzle_id in puzzles.keys():
                for idx, sol_grid in enumerate(sols[puzzle_id]):
                    puzzles[puzzle_id]["test"][idx]["output"] = sol_grid
        else:
            # Fill with dummy
            print (f"{subset_name} solutions not found, filling with dummy")

            for puzzle_id, puzzle in puzzles.items():
                for example in puzzle["test"]:
                    example.setdefault("output", [[0]])

        # Shuffle puzzles
        puzzles = list(puzzles.items())
        np.random.shuffle(puzzles)

        # Assign by fraction
        for idx, (name, puzzle) in enumerate(puzzles):
            fraction = idx / len(puzzles)
            test_examples_dest = None
            for f, dest in test_examples_map.get(subset_name, test_examples_map["_default"]):
                if fraction < f:
                    test_examples_dest = dest
                    break

            assert test_examples_dest is not None

            if test_examples_dest[0] == "test":
                test_puzzles[name] = puzzle

            convert_single_arc_puzzle(results, name, puzzle, config.num_aug, {"train": train_examples_dest, "test": test_examples_dest})
            total_puzzles += 1

    print (f"Total puzzles: {total_puzzles}")
    return results, test_puzzles


def convert_dataset(config: DataProcessConfig):
    np.random.seed(config.seed)

    # Read dataset
    data, test_puzzles = load_puzzles_arcagi(config)

    # Map global puzzle identifiers
    num_identifiers = config.puzzle_identifiers_start  # 0 is blank, start at 1
    identifier_map = {}
    for split_name, split in data.items():
        for subset_name, subset in split.items():
            for group in subset:
                for puzzle in group:
                    if puzzle.id not in identifier_map:
                        identifier_map[puzzle.id] = num_identifiers
                        num_identifiers += 1
    print (f"Total puzzle IDs (including <blank>): {num_identifiers}")

    # Save
    for split_name, split in data.items():
        os.makedirs(os.path.join(config.output_dir, split_name), exist_ok=True)

        # Translational augmentations
        enable_translational_augment = split_name == "train"

        # Statistics
        total_examples = 0
        total_puzzles = 0
        total_groups = 0

        for subset_name, subset in split.items():  # "all" is the only subset
            # Construct subset
            results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
            results["puzzle_indices"].append(0)
            results["group_indices"].append(0)

            example_id = 0
            puzzle_id = 0

            for group in subset:
                for puzzle in group:
                    # Push puzzle
                    no_aug_id = np.random.randint(0, len(puzzle.examples))
                    for _idx_ex, (inp, out) in enumerate(puzzle.examples):
                        inp, out = np_grid_to_seq_translational_augment(inp, out, do_translation=enable_translational_augment and _idx_ex != no_aug_id)

                        results["inputs"].append(inp)
                        results["labels"].append(out)
                        example_id += 1

                        total_examples += 1

                    results["puzzle_indices"].append(example_id)
                    results["puzzle_identifiers"].append(identifier_map[puzzle.id])

                    puzzle_id += 1
                    total_puzzles += 1

                # Push group
                results["group_indices"].append(puzzle_id)
                total_groups += 1

            for k, v in results.items():
                if k in {"inputs", "labels"}:
                    v = np.stack(v, 0)
                else:
                    v = np.array(v, dtype=np.int32)

                np.save(os.path.join(config.output_dir, split_name, f"{subset_name}__{k}.npy"), v)

        # Metadata
        metadata = PuzzleDatasetMetadata(
            seq_len=ARCMaxGridSize * ARCMaxGridSize,
            vocab_size=10 + 2,  # PAD + EOS + "0" ... "9"
            pad_id=0,
            ignore_label_id=0,
            blank_identifier_id=0,
            num_puzzle_identifiers=num_identifiers,
            total_groups=total_groups,
            mean_puzzle_examples=total_examples / total_puzzles,
            total_puzzles=total_puzzles,
            sets=list(split.keys())
        )

        # Save metadata as JSON.
        with open(os.path.join(config.output_dir, split_name, "dataset.json"), "w") as f:
            json.dump(metadata.model_dump(), f)

    # Save IDs mapping
    with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
        ids_mapping = {v: k for k, v in identifier_map.items()}
        json.dump([ids_mapping.get(i, "<blank>") for i in range(num_identifiers)], f)

    # Save Test Puzzles
    with open(os.path.join(config.output_dir, "test_puzzles.json"), "w") as f:
        json.dump(test_puzzles, f)


@cli.command(singleton=True)
def main(config: DataProcessConfig):
    convert_dataset(config)


if __name__ == "__main__":
    cli()
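A worked example may help with the token layout that `np_grid_to_seq_translational_augment` produces (PAD = 0, `<eos>` = 1, colors 0..9 shifted to tokens 2..11). The sketch below is a pure-Python re-implementation for a small canvas; `encode_grid` and its `max_size` parameter are hypothetical names introduced for illustration, whereas the builder itself uses NumPy and a fixed 30x30 canvas.

```python
PAD, EOS = 0, 1


def encode_grid(grid, max_size, pad_r=0, pad_c=0):
    """Place `grid` at (pad_r, pad_c) on a max_size x max_size canvas and flatten it."""
    nrow, ncol = len(grid), len(grid[0])
    assert pad_r + nrow <= max_size and pad_c + ncol <= max_size

    canvas = [[PAD] * max_size for _ in range(max_size)]
    for r in range(nrow):
        for c in range(ncol):
            canvas[pad_r + r][pad_c + c] = grid[r][c] + 2  # colors 0..9 -> tokens 2..11

    # Mark the row below and the column to the right of the grid with <eos>,
    # when they still fit on the canvas.
    eos_row, eos_col = pad_r + nrow, pad_c + ncol
    if eos_row < max_size:
        for c in range(pad_c, eos_col):
            canvas[eos_row][c] = EOS
    if eos_col < max_size:
        for r in range(pad_r, eos_row):
            canvas[r][eos_col] = EOS

    return [token for row in canvas for token in row]


tokens = encode_grid([[0, 3], [5, 9]], max_size=4)
for r in range(4):
    print(tokens[4 * r: 4 * r + 4])
# [2, 5, 1, 0]
# [7, 11, 1, 0]
# [1, 1, 0, 0]
# [0, 0, 0, 0]
```

The `<eos>` border is what lets the evaluator crop a predicted 30x30 canvas back to the grid's true size.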


@@ -0,0 +1,140 @@
from typing import Optional
import math
import os
import csv
import json

import numpy as np
from argdantic import ArgParser
from pydantic import BaseModel
from tqdm import tqdm
from huggingface_hub import hf_hub_download

from common import PuzzleDatasetMetadata, dihedral_transform


CHARSET = "# SGo"

cli = ArgParser()


class DataProcessConfig(BaseModel):
    source_repo: str = "sapientinc/maze-30x30-hard-1k"
    output_dir: str = "data/maze-30x30-hard-1k"
    subsample_size: Optional[int] = None
    aug: bool = False


def convert_subset(set_name: str, config: DataProcessConfig):
    # Read CSV
    all_chars = set()
    grid_size = None
    inputs = []
    labels = []

    with open(hf_hub_download(config.source_repo, f"{set_name}.csv", repo_type="dataset"), newline="") as csvfile:  # type: ignore
        reader = csv.reader(csvfile)
        next(reader)  # Skip header
        for source, q, a, rating in reader:
            all_chars.update(q)
            all_chars.update(a)

            if grid_size is None:
                n = int(len(q) ** 0.5)
                grid_size = (n, n)

            inputs.append(np.frombuffer(q.encode(), dtype=np.uint8).reshape(grid_size))
            labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(grid_size))

    # If subsample_size is specified for the training set,
    # randomly sample the desired number of examples.
    if set_name == "train" and config.subsample_size is not None:
        total_samples = len(inputs)
        if config.subsample_size < total_samples:
            indices = np.random.choice(total_samples, size=config.subsample_size, replace=False)
            inputs = [inputs[i] for i in indices]
            labels = [labels[i] for i in indices]

    # Generate dataset
    results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
    puzzle_id = 0
    example_id = 0

    results["puzzle_indices"].append(0)
    results["group_indices"].append(0)

    for inp, out in zip(tqdm(inputs), labels):
        # Dihedral transformations for augmentation
        for aug_idx in range(8 if (set_name == "train" and config.aug) else 1):
            results["inputs"].append(dihedral_transform(inp, aug_idx))
            results["labels"].append(dihedral_transform(out, aug_idx))
            example_id += 1
            puzzle_id += 1

            results["puzzle_indices"].append(example_id)
            results["puzzle_identifiers"].append(0)

        # Push group
        results["group_indices"].append(puzzle_id)

    # Char mappings
    assert len(all_chars - set(CHARSET)) == 0

    char2id = np.zeros(256, np.uint8)
    char2id[np.array(list(map(ord, CHARSET)))] = np.arange(len(CHARSET)) + 1

    # To Numpy
    def _seq_to_numpy(seq):
        arr = np.vstack([char2id[s.reshape(-1)] for s in seq])

        return arr

    results = {
        "inputs": _seq_to_numpy(results["inputs"]),
        "labels": _seq_to_numpy(results["labels"]),

        "group_indices": np.array(results["group_indices"], dtype=np.int32),
        "puzzle_indices": np.array(results["puzzle_indices"], dtype=np.int32),
        "puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32),
    }

    # Metadata
    metadata = PuzzleDatasetMetadata(
        seq_len=int(math.prod(grid_size)),  # type: ignore
        vocab_size=len(CHARSET) + 1,  # PAD + Charset
        pad_id=0,
        ignore_label_id=0,
        blank_identifier_id=0,
        num_puzzle_identifiers=1,
        total_groups=len(results["group_indices"]) - 1,
        mean_puzzle_examples=1,
        total_puzzles=len(results["group_indices"]) - 1,
        sets=["all"]
    )

    # Save metadata as JSON.
    save_dir = os.path.join(config.output_dir, set_name)
    os.makedirs(save_dir, exist_ok=True)

    with open(os.path.join(save_dir, "dataset.json"), "w") as f:
        json.dump(metadata.model_dump(), f)

    # Save data
    for k, v in results.items():
        np.save(os.path.join(save_dir, f"all__{k}.npy"), v)

    # Save IDs mapping (for visualization only)
    with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
        json.dump(["<blank>"], f)


@cli.command(singleton=True)
def preprocess_data(config: DataProcessConfig):
    convert_subset("train", config)
    convert_subset("test", config)


if __name__ == "__main__":
    cli()


@@ -0,0 +1,167 @@
from typing import Optional
import os
import csv
import json

import numpy as np
from argdantic import ArgParser
from pydantic import BaseModel
from tqdm import tqdm
from huggingface_hub import hf_hub_download

from common import PuzzleDatasetMetadata


cli = ArgParser()


class DataProcessConfig(BaseModel):
    source_repo: str = "sapientinc/sudoku-extreme"
    output_dir: str = "data/sudoku-extreme-full"
    subsample_size: Optional[int] = None
    min_difficulty: Optional[int] = None
    num_aug: int = 0


def shuffle_sudoku(board: np.ndarray, solution: np.ndarray):
    # Create a random digit mapping: a permutation of 1..9, with zero (blank) unchanged
    digit_map = np.pad(np.random.permutation(np.arange(1, 10)), (1, 0))

    # Randomly decide whether to transpose.
    transpose_flag = np.random.rand() < 0.5

    # Generate a valid row permutation:
    # - Shuffle the 3 bands (each band = 3 rows) and for each band, shuffle its 3 rows.
    bands = np.random.permutation(3)
    row_perm = np.concatenate([b * 3 + np.random.permutation(3) for b in bands])

    # Similarly for columns (stacks).
    stacks = np.random.permutation(3)
    col_perm = np.concatenate([s * 3 + np.random.permutation(3) for s in stacks])

    # Build an 81->81 mapping. For each new flat cell index i
    # (row index = i // 9, col index = i % 9),
    # its value comes from old row = row_perm[i // 9] and old col = col_perm[i % 9].
    mapping = np.array([row_perm[i // 9] * 9 + col_perm[i % 9] for i in range(81)])

    def apply_transformation(x: np.ndarray) -> np.ndarray:
        # Apply transpose flag
        if transpose_flag:
            x = x.T
        # Apply the position mapping.
        new_board = x.flatten()[mapping].reshape(9, 9).copy()
        # Apply digit mapping
        return digit_map[new_board]

    return apply_transformation(board), apply_transformation(solution)


def convert_subset(set_name: str, config: DataProcessConfig):
    # Read CSV
    inputs = []
    labels = []

    with open(hf_hub_download(config.source_repo, f"{set_name}.csv", repo_type="dataset"), newline="") as csvfile:
        reader = csv.reader(csvfile)
        next(reader)  # Skip header
        for source, q, a, rating in reader:
            if (config.min_difficulty is None) or (int(rating) >= config.min_difficulty):
                assert len(q) == 81 and len(a) == 81

                inputs.append(np.frombuffer(q.replace('.', '0').encode(), dtype=np.uint8).reshape(9, 9) - ord('0'))
                labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(9, 9) - ord('0'))

    # If subsample_size is specified for the training set,
    # randomly sample the desired number of examples.
    if set_name == "train" and config.subsample_size is not None:
        total_samples = len(inputs)
        if config.subsample_size < total_samples:
            indices = np.random.choice(total_samples, size=config.subsample_size, replace=False)
            inputs = [inputs[i] for i in indices]
            labels = [labels[i] for i in indices]

    # Generate dataset
    num_augments = config.num_aug if set_name == "train" else 0

    results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
    puzzle_id = 0
    example_id = 0

    results["puzzle_indices"].append(0)
    results["group_indices"].append(0)

    for orig_inp, orig_out in zip(tqdm(inputs), labels):
        for aug_idx in range(1 + num_augments):
            # First index is not augmented
            if aug_idx == 0:
                inp, out = orig_inp, orig_out
            else:
                inp, out = shuffle_sudoku(orig_inp, orig_out)

            # Push puzzle (only single example)
            results["inputs"].append(inp)
            results["labels"].append(out)
            example_id += 1
            puzzle_id += 1

            results["puzzle_indices"].append(example_id)
            results["puzzle_identifiers"].append(0)

        # Push group
        results["group_indices"].append(puzzle_id)

    # To Numpy
    def _seq_to_numpy(seq):
        arr = np.concatenate(seq).reshape(len(seq), -1)

        assert np.all((arr >= 0) & (arr <= 9))
        return arr + 1

    results = {
        "inputs": _seq_to_numpy(results["inputs"]),
        "labels": _seq_to_numpy(results["labels"]),

        "group_indices": np.array(results["group_indices"], dtype=np.int32),
        "puzzle_indices": np.array(results["puzzle_indices"], dtype=np.int32),
        "puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32),
    }

    # Metadata
    metadata = PuzzleDatasetMetadata(
        seq_len=81,
        vocab_size=10 + 1,  # PAD + "0" ... "9"
        pad_id=0,
        ignore_label_id=0,
        blank_identifier_id=0,
        num_puzzle_identifiers=1,
        total_groups=len(results["group_indices"]) - 1,
        mean_puzzle_examples=1,
        total_puzzles=len(results["group_indices"]) - 1,
        sets=["all"]
    )

    # Save metadata as JSON.
    save_dir = os.path.join(config.output_dir, set_name)
    os.makedirs(save_dir, exist_ok=True)

    with open(os.path.join(save_dir, "dataset.json"), "w") as f:
        json.dump(metadata.model_dump(), f)

    # Save data
    for k, v in results.items():
        np.save(os.path.join(save_dir, f"all__{k}.npy"), v)

    # Save IDs mapping (for visualization only)
    with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
        json.dump(["<blank>"], f)


@cli.command(singleton=True)
def preprocess_data(config: DataProcessConfig):
    convert_subset("train", config)
    convert_subset("test", config)


if __name__ == "__main__":
    cli()
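The augmentation in `shuffle_sudoku` relies on the fact that relabeling digits, transposing, and permuting rows within bands (and bands themselves, likewise for columns) all map a valid Sudoku to another valid Sudoku. The sketch below checks this on a solved grid using only the standard library. `shuffle_solution` and `is_valid_solution` are hypothetical helper names, and the base grid is a standard shifting-pattern solution rather than a puzzle from the dataset.

```python
import random


def shuffle_solution(solution, rng):
    """Pure-Python analogue of the band/stack/digit shuffle for a solved 9x9 grid."""
    digits = list(range(1, 10))
    rng.shuffle(digits)
    digit_map = [0] + digits  # 0 (blank) stays 0

    def axis_perm():
        # Shuffle the 3 bands, then the 3 lines inside each band.
        bands = [0, 1, 2]
        rng.shuffle(bands)
        perm = []
        for band in bands:
            lines = [0, 1, 2]
            rng.shuffle(lines)
            perm.extend(band * 3 + line for line in lines)
        return perm

    row_perm, col_perm = axis_perm(), axis_perm()
    if rng.random() < 0.5:
        solution = [list(col) for col in zip(*solution)]  # transpose
    return [[digit_map[solution[row_perm[r]][col_perm[c]]] for c in range(9)] for r in range(9)]


def is_valid_solution(grid):
    full = set(range(1, 10))
    rows = all(set(row) == full for row in grid)
    cols = all(set(col) == full for col in zip(*grid))
    boxes = all(
        {grid[br + r][bc + c] for r in range(3) for c in range(3)} == full
        for br in (0, 3, 6) for bc in (0, 3, 6)
    )
    return rows and cols and boxes


# A known valid solved grid built from a standard shifting pattern.
base = [[(3 * (r % 3) + r // 3 + c) % 9 + 1 for c in range(9)] for r in range(9)]
shuffled = shuffle_solution(base, random.Random(0))
print(is_valid_solution(base), is_valid_solution(shuffled))
```

Because the same transformation is applied to the board and its solution, each augmented pair stays consistent.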

49
dataset/common.py Normal file

@@ -0,0 +1,49 @@
from typing import List, Optional

import pydantic
import numpy as np


# Global list mapping each dihedral transform id to its inverse.
# Index corresponds to the original tid, and the value is its inverse.
DIHEDRAL_INVERSE = [0, 3, 2, 1, 4, 5, 6, 7]


class PuzzleDatasetMetadata(pydantic.BaseModel):
    pad_id: int
    ignore_label_id: Optional[int]
    blank_identifier_id: int
    vocab_size: int
    seq_len: int
    num_puzzle_identifiers: int
    total_groups: int
    mean_puzzle_examples: float
    total_puzzles: int
    sets: List[str]


def dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:
    """8 dihedral symmetries by rotate, flip and mirror"""

    if tid == 0:
        return arr  # identity
    elif tid == 1:
        return np.rot90(arr, k=1)
    elif tid == 2:
        return np.rot90(arr, k=2)
    elif tid == 3:
        return np.rot90(arr, k=3)
    elif tid == 4:
        return np.fliplr(arr)  # horizontal flip
    elif tid == 5:
        return np.flipud(arr)  # vertical flip
    elif tid == 6:
        return arr.T  # transpose (reflection along main diagonal)
    elif tid == 7:
        return np.fliplr(np.rot90(arr, k=1))  # anti-diagonal reflection
    else:
        return arr


def inverse_dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:
    return dihedral_transform(arr, DIHEDRAL_INVERSE[tid])
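The inverse table can be checked without NumPy. The sketch below re-implements the eight transforms on nested lists (the helper names `rot90`, `fliplr`, `flipud`, `transpose`, and `dihedral` are local to this example and only mirror the NumPy conventions used above) and verifies that applying a transform followed by its listed inverse returns the original grid.

```python
def rot90(g):
    """Rotate counterclockwise by 90 degrees (same convention as np.rot90 with k=1)."""
    return [list(row) for row in zip(*g)][::-1]


def fliplr(g):
    return [list(row[::-1]) for row in g]


def flipud(g):
    return [list(row) for row in g[::-1]]


def transpose(g):
    return [list(row) for row in zip(*g)]


def dihedral(g, tid):
    if tid in (1, 2, 3):
        for _ in range(tid):
            g = rot90(g)
        return g
    if tid == 4:
        return fliplr(g)
    if tid == 5:
        return flipud(g)
    if tid == 6:
        return transpose(g)
    if tid == 7:
        return fliplr(rot90(g))  # anti-diagonal reflection
    return [list(row) for row in g]  # tid == 0: identity


DIHEDRAL_INVERSE = [0, 3, 2, 1, 4, 5, 6, 7]

grid = [[1, 2, 3], [4, 5, 6]]  # non-square and asymmetric on purpose
for tid in range(8):
    restored = dihedral(dihedral(grid, tid), DIHEDRAL_INVERSE[tid])
    assert restored == grid, tid
print("all 8 transforms invert correctly")
```

Only the two quarter-turn rotations swap places in the table; the half-turn and the four reflections are their own inverses.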

177
evaluators/arc.py Normal file

@@ -0,0 +1,177 @@
from typing import Dict, Sequence, Optional
import os
import json

import torch
import numpy as np
from numba import njit
import torch.distributed as dist

from dataset.build_arc_dataset import inverse_aug, grid_hash, arc_grid_to_np
from dataset.common import PuzzleDatasetMetadata


@njit
def _crop(grid: np.ndarray):
    """Find maximum-sized rectangle without any EOS token inside. """
    grid = grid.reshape(30, 30)

    max_area = 0
    max_size = (0, 0)
    nr, nc = grid.shape

    num_c = nc
    for num_r in range(1, nr + 1):
        # Scan for maximum c
        for c in range(1, num_c + 1):
            x = grid[num_r - 1, c - 1]
            if (x < 2) | (x > 11):
                num_c = c - 1
                break

        area = num_r * num_c
        if area > max_area:
            max_area = area
            max_size = (num_r, num_c)

    return (grid[:max_size[0], :max_size[1]] - 2).astype(np.uint8)


class ARC:
    required_outputs = {"inputs", "puzzle_identifiers", "q_halt_logits", "preds"}

    def __init__(self, data_path: str,
                 eval_metadata: PuzzleDatasetMetadata,
                 submission_K: int = 2,
                 pass_Ks: Sequence[int] = (1, 2, 5, 10, 100, 1000),
                 aggregated_voting: bool = True):
        super().__init__()
        self.pass_Ks = pass_Ks
        self.submission_K = submission_K
        self.aggregated_voting = aggregated_voting
        self.blank_identifier_id = eval_metadata.blank_identifier_id

        # Load identifiers and test puzzles
        with open(os.path.join(data_path, "identifiers.json"), "r") as f:
            self.identifier_map = json.load(f)
        with open(os.path.join(data_path, "test_puzzles.json"), "r") as f:
            self.test_puzzles = json.load(f)

        # States
        self._local_hmap = {}
        self._local_preds = {}

    def begin_eval(self):
        if not self.aggregated_voting:
            # Clear previous predictions
            self._local_hmap = {}
            self._local_preds = {}

    def update_batch(self, batch: Dict[str, torch.Tensor], preds: Dict[str, torch.Tensor]):
        # Collect required outputs to CPU
        outputs = {}
        q_values = None

        for collection in (batch, preds):
            for k, v in collection.items():
                if k in self.required_outputs:
                    if k == "q_halt_logits":
                        q_values = v.to(torch.float64).sigmoid().cpu()
                    else:
                        outputs[k] = v.cpu()

        assert q_values is not None

        # Remove padding from outputs
        mask = outputs["puzzle_identifiers"] != self.blank_identifier_id
        outputs = {k: v[mask] for k, v in outputs.items()}
        # Apply the same mask to the halt probabilities so they stay aligned
        # with the filtered outputs even if padding is not at the end of the batch
        q_values = q_values[mask]

        # Get predictions
        for identifier, input, pred, q in zip(outputs["puzzle_identifiers"].numpy(), outputs["inputs"].numpy(), outputs["preds"].numpy(), q_values.numpy()):
            name = self.identifier_map[identifier]
            orig_name, _inverse_fn = inverse_aug(name)

            input_hash = grid_hash(_inverse_fn(_crop(input)))

            pred = _inverse_fn(_crop(pred))
            assert np.all((pred >= 0) & (pred <= 9)), f"Puzzle {name}'s prediction out of 0-9 range."  # Sanity check

            # Store into local state
            pred_hash = grid_hash(pred)

            self._local_hmap[pred_hash] = pred

            self._local_preds.setdefault(orig_name, {})
            self._local_preds[orig_name].setdefault(input_hash, [])
            self._local_preds[orig_name][input_hash].append((pred_hash, float(q)))
def result(self, save_path: Optional[str], rank: int, world_size: int, group: Optional[torch.distributed.ProcessGroup] = None) -> Optional[Dict[str, float]]:
# Gather predictions to rank 0 for voting
global_hmap_preds = [None for _ in range(world_size)] if rank == 0 else None
dist.gather_object((self._local_hmap, self._local_preds), global_hmap_preds, dst=0, group=group)
# Rank 0 logic
if rank != 0:
return
submission = {}
correct = [0.0 for _ in range(len(self.pass_Ks))]
for name, puzzle in self.test_puzzles.items():
# Process test examples in this puzzle
submission[name] = []
num_test_correct = [0 for _ in range(len(self.pass_Ks))]
for pair in puzzle["test"]:
input_hash = grid_hash(arc_grid_to_np(pair["input"]))
label_hash = grid_hash(arc_grid_to_np(pair["output"]))
p_map = {}
for hmap, preds in global_hmap_preds: # type: ignore
for h, q in preds.get(name, {}).get(input_hash, []):
p_map.setdefault(h, [0, 0])
p_map[h][0] += 1
p_map[h][1] += q
if not len(p_map):
print(f"Puzzle {name} has no predictions.")
continue
for h, stats in p_map.items():
stats[1] /= stats[0]
p_map = sorted(p_map.items(), key=lambda kv: kv[1], reverse=True)
# vote for different Ks
for i, k in enumerate(self.pass_Ks):
ok = False
for h, stats in p_map[:k]:
ok |= h == label_hash
num_test_correct[i] += ok
# Query grids
pred_grids = []
for h, stats in p_map[:self.submission_K]:
for hmap, preds in global_hmap_preds: # type: ignore
if h in hmap:
pred_grids.append(hmap[h])
break
# Pad to K
while len(pred_grids) < self.submission_K:
pred_grids.append(pred_grids[0])
submission[name].append({f"attempt_{i + 1}": grid.tolist() for i, grid in enumerate(pred_grids)})
# Total correctness
for i in range(len(self.pass_Ks)):
correct[i] += num_test_correct[i] / len(puzzle["test"])
# Save submission
if save_path is not None:
with open(os.path.join(save_path, "submission.json"), "w") as f:
json.dump(submission, f)
# Final result
all_results = {f"ARC/pass@{k}": correct[i] / len(self.test_puzzles) for i, k in enumerate(self.pass_Ks)}
return all_results
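The voting step in `result` can be illustrated in isolation. The sketch below is a simplified, standard-library-only illustration and not the evaluator itself: `rank_candidates` and `pass_at_k` are hypothetical helper names, and the string keys stand in for the grid hashes. It assumes the same ordering as the code above, where candidates are ranked by vote count first and by mean halting confidence second.

```python
from collections import defaultdict
from typing import Dict, Hashable, Iterable, List, Tuple


def rank_candidates(votes: Iterable[Tuple[Hashable, float]]) -> List[Tuple[Hashable, int, float]]:
    """Rank candidate answers by vote count, then by mean halting confidence."""
    counts: Dict[Hashable, int] = defaultdict(int)
    q_sums: Dict[Hashable, float] = defaultdict(float)
    for pred_hash, q in votes:
        counts[pred_hash] += 1
        q_sums[pred_hash] += q
    ranked = [(h, counts[h], q_sums[h] / counts[h]) for h in counts]
    # Lexicographic sort on (count, mean q), highest first
    ranked.sort(key=lambda item: (item[1], item[2]), reverse=True)
    return ranked


def pass_at_k(ranked: List[Tuple[Hashable, int, float]], label_hash: Hashable, k: int) -> bool:
    """True if the correct answer is among the top-k ranked candidates."""
    return any(h == label_hash for h, _, _ in ranked[:k])


# Hypothetical votes gathered across augmentations: (prediction hash, q_halt probability)
votes = [("A", 0.9), ("B", 0.6), ("A", 0.7), ("C", 0.95), ("B", 0.8)]
ranked = rank_candidates(votes)
print([h for h, _, _ in ranked])
```

Note that a single high-confidence prediction (`"C"` here) still ranks below answers that received more votes, because the count is compared first.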

View File

File diff suppressed because it is too large Load Diff

View File

File diff suppressed because one or more lines are too long

32
models/common.py Normal file
View File

@@ -0,0 +1,32 @@
import math

import torch
from torch import nn


def trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, lower: float = -2.0, upper: float = 2.0):
    # NOTE: PyTorch nn.init.trunc_normal_ is not mathematically correct: its std argument is not the std dev of the initialized tensor.
    # This function is a PyTorch version of the JAX truncated normal init (default init method in Flax).
    # https://github.com/jax-ml/jax/blob/main/jax/_src/random.py#L807-L848
    # https://github.com/jax-ml/jax/blob/main/jax/_src/nn/initializers.py#L162-L199
    with torch.no_grad():
        if std == 0:
            tensor.zero_()
        else:
            sqrt2 = math.sqrt(2)
            a = math.erf(lower / sqrt2)
            b = math.erf(upper / sqrt2)
            z = (b - a) / 2  # Probability mass of N(0, 1) inside [lower, upper]

            c = (2 * math.pi) ** -0.5
            # Standard normal pdf at each bound. The original names were swapped,
            # which only gives the right result for symmetric bounds (the default).
            pdf_u = c * math.exp(-0.5 * upper ** 2)
            pdf_l = c * math.exp(-0.5 * lower ** 2)
            # Rescale so the truncated distribution has standard deviation `std`
            comp_std = std / math.sqrt(1 - (upper * pdf_u - lower * pdf_l) / z - ((pdf_u - pdf_l) / z) ** 2)

            # Inverse-CDF sampling: uniform in erf space, then map back through erfinv
            tensor.uniform_(a, b)
            tensor.erfinv_()
            tensor.mul_(sqrt2 * comp_std)
            tensor.clip_(lower * comp_std, upper * comp_std)

    return tensor
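The variance correction in `trunc_normal_init_` can be checked without PyTorch. The following is a minimal sketch using only the standard library; `truncated_unit_normal_std` and `sample_truncated_normal` are hypothetical names introduced for illustration, and the sampler uses `statistics.NormalDist.inv_cdf` in place of the `erfinv_` call above.

```python
import math
import random
from statistics import NormalDist, pstdev
from typing import List


def _pdf(x: float) -> float:
    """Standard normal density."""
    return math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)


def truncated_unit_normal_std(lower: float = -2.0, upper: float = 2.0) -> float:
    """Std dev of N(0, 1) truncated to [lower, upper]."""
    std_normal = NormalDist()
    z = std_normal.cdf(upper) - std_normal.cdf(lower)
    mean = -(_pdf(upper) - _pdf(lower)) / z
    var = 1.0 - (upper * _pdf(upper) - lower * _pdf(lower)) / z - mean ** 2
    return math.sqrt(var)


def sample_truncated_normal(n: int, std: float = 1.0, lower: float = -2.0,
                            upper: float = 2.0, seed: int = 0) -> List[float]:
    """Inverse-CDF sampling, rescaled so the samples have std dev close to `std`."""
    rng = random.Random(seed)
    std_normal = NormalDist()
    lo, hi = std_normal.cdf(lower), std_normal.cdf(upper)
    scale = std / truncated_unit_normal_std(lower, upper)
    return [scale * std_normal.inv_cdf(rng.uniform(lo, hi)) for _ in range(n)]


samples = sample_truncated_normal(20000, std=0.5)
print(round(truncated_unit_normal_std(), 4), round(pstdev(samples), 2))
```

Without the rescaling, a unit normal truncated at two standard deviations has a standard deviation of roughly 0.88 rather than 1, which is the discrepancy the comment in `trunc_normal_init_` refers to.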

40
models/ema.py Normal file
View File

@@ -0,0 +1,40 @@
import copy

import torch.nn as nn


class EMAHelper(object):
    def __init__(self, mu=0.999):
        self.mu = mu
        self.shadow = {}

    def register(self, module):
        if isinstance(module, nn.DataParallel):
            module = module.module
        for name, param in module.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self, module):
        if isinstance(module, nn.DataParallel):
            module = module.module
        for name, param in module.named_parameters():
            if param.requires_grad:
                self.shadow[name].data = (1. - self.mu) * param.data + self.mu * self.shadow[name].data

    def ema(self, module):
        if isinstance(module, nn.DataParallel):
            module = module.module
        for name, param in module.named_parameters():
            if param.requires_grad:
                param.data.copy_(self.shadow[name].data)

    def ema_copy(self, module):
        module_copy = copy.deepcopy(module)
        self.ema(module_copy)
        return module_copy

    def state_dict(self):
        return self.shadow

    def load_state_dict(self, state_dict):
        self.shadow = state_dict
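The update rule in `EMAHelper.update` is easier to follow on plain numbers. The sketch below is a toy illustration only: `ScalarEMA` is a hypothetical class that tracks a dict of floats instead of module parameters, using the same formula `shadow = (1 - mu) * param + mu * shadow`.

```python
from typing import Dict


class ScalarEMA:
    """Toy exponential moving average over a dict of floats."""

    def __init__(self, mu: float = 0.999):
        self.mu = mu
        self.shadow: Dict[str, float] = {}

    def register(self, params: Dict[str, float]) -> None:
        # Start the shadow values as a copy of the current parameters
        self.shadow = dict(params)

    def update(self, params: Dict[str, float]) -> None:
        # Move each shadow value a fraction (1 - mu) toward the current parameter
        for name, value in params.items():
            self.shadow[name] = (1.0 - self.mu) * value + self.mu * self.shadow[name]


ema = ScalarEMA(mu=0.5)
ema.register({"w": 1.0})
ema.update({"w": 3.0})
print(ema.shadow["w"])
```

With `mu` close to 1, as in the default of 0.999, the shadow weights change slowly and average over many training steps.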

169
models/layers.py Normal file
View File

@@ -0,0 +1,169 @@
from typing import Tuple
import einops
import torch
from torch import nn
import torch.nn.functional as F
#try:
# from flash_attn_interface import flash_attn_func # type: ignore[import]
#except ImportError:
# # Fallback to FlashAttention 2
# from flash_attn import flash_attn_func # type: ignore[import]
from torch.nn.functional import scaled_dot_product_attention
from models.common import trunc_normal_init_
CosSin = Tuple[torch.Tensor, torch.Tensor]
def _find_multiple(a, b):
return (-(a // -b)) * b
def rotate_half(x: torch.Tensor):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
# q, k: [bs, seq_len, num_heads, head_dim]
# cos, sin: [seq_len, head_dim]
orig_dtype = q.dtype
q = q.to(cos.dtype)
k = k.to(cos.dtype)
q_embed = (q * cos.unsqueeze(-2)) + (rotate_half(q) * sin.unsqueeze(-2))
k_embed = (k * cos.unsqueeze(-2)) + (rotate_half(k) * sin.unsqueeze(-2))
return q_embed.to(orig_dtype), k_embed.to(orig_dtype)
class CastedLinear(nn.Module):
def __init__(self,
in_features: int,
out_features: int,
bias: bool):
super().__init__()
# Truncated LeCun normal init
self.weight = nn.Parameter(
trunc_normal_init_(torch.empty((out_features, in_features)), std=1.0 / (in_features ** 0.5))
)
self.bias = None
if bias:
# Zero init bias
self.bias = nn.Parameter(torch.zeros((out_features, )))
def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.linear(input, self.weight.to(input.dtype), bias=self.bias.to(input.dtype) if self.bias is not None else None)
class CastedEmbedding(nn.Module):
def __init__(self,
num_embeddings: int,
embedding_dim: int,
init_std: float,
cast_to: torch.dtype):
super().__init__()
self.cast_to = cast_to
# Truncated normal init with the caller-provided std
self.embedding_weight = nn.Parameter(
trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std)
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.embedding(input, self.embedding_weight.to(self.cast_to))
class RotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings, base, device=None):
super().__init__()
# RoPE
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
t = torch.arange(max_position_embeddings, dtype=torch.float32, device=device)
freqs = torch.outer(t, inv_freq)
# Layout differs from the RoPE paper: a different permutation of dimensions is used, which yields the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.cos_cached = nn.Buffer(emb.cos(), persistent=False)
self.sin_cached = nn.Buffer(emb.sin(), persistent=False)
def forward(self):
return self.cos_cached, self.sin_cached
class Attention(nn.Module):
def __init__(self, hidden_size, head_dim, num_heads, num_key_value_heads, causal=False):
super().__init__()
self.hidden_size = hidden_size
self.head_dim = head_dim
self.output_size = head_dim * num_heads
self.num_heads = num_heads
self.num_key_value_heads = num_key_value_heads
self.causal = causal
self.qkv_proj = CastedLinear(self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False)
self.o_proj = CastedLinear(self.output_size, self.hidden_size, bias=False)
def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, seq_len, _ = hidden_states.shape
# hidden_states: [bs, seq_len, hidden_size]
qkv = self.qkv_proj(hidden_states)
# Split head
qkv = qkv.view(batch_size, seq_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
query = qkv[:, :, :self.num_heads]
key = qkv[:, :, self.num_heads: self.num_heads + self.num_key_value_heads]
value = qkv[:, :, self.num_heads + self.num_key_value_heads:]
# RoPE
if cos_sin is not None:
cos, sin = cos_sin
query, key = apply_rotary_pos_emb(query, key, cos, sin)
# Attention via PyTorch scaled_dot_product_attention (the flash_attn_func import above is commented out)
query, key, value = map(lambda t: einops.rearrange(t, 'B S H D -> B H S D'), (query, key, value)) # needed for scaled_dot_product_attention but not flash_attn_func
attn_output = scaled_dot_product_attention(query=query, key=key, value=value, is_causal=self.causal)
attn_output = einops.rearrange(attn_output, 'B H S D -> B S H D')
attn_output = attn_output.view(batch_size, seq_len, self.output_size) # type: ignore
return self.o_proj(attn_output)
class LinearSwish(nn.Module):
def __init__(self, hidden_size: int, reverse=False):
super().__init__()
self.linear = CastedLinear(hidden_size, hidden_size, bias=False)
self.reverse = reverse
def forward(self, x):
if self.reverse:
return F.silu(self.linear(x))
else:
return self.linear(F.silu(x))
class SwiGLU(nn.Module):
def __init__(self, hidden_size: int, expansion: float):
super().__init__()
inter = _find_multiple(round(expansion * hidden_size * 2 / 3), 256)
self.gate_up_proj = CastedLinear(hidden_size, inter * 2, bias=False)
self.down_proj = CastedLinear(inter, hidden_size, bias=False)
def forward(self, x):
gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
return self.down_proj(F.silu(gate) * up)
def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float) -> torch.Tensor:
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.square().mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
return hidden_states.to(input_dtype)
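The `rotate_half` layout used by `apply_rotary_pos_emb` pairs dimension `i` with dimension `i + head_dim // 2` and rotates each pair by a position-dependent angle. The following is a minimal pure-Python sketch for a single vector; `rope` and `dot` are hypothetical helper names, and the default `base` of 10000 is taken from the `rope_theta` default in the model configs. It illustrates the property that makes RoPE useful: the dot product of a rotated query and key depends only on their relative position.

```python
import math
from typing import List


def rope(x: List[float], position: int, base: float = 10000.0) -> List[float]:
    """Apply rotary position embedding to one vector using the rotate_half layout."""
    dim = len(x)
    half = dim // 2
    inv_freq = [1.0 / (base ** (2 * i / dim)) for i in range(half)]
    angles = [position * f for f in inv_freq] * 2      # mirrors cat((freqs, freqs))
    rotated = [-v for v in x[half:]] + x[:half]        # mirrors rotate_half
    return [xi * math.cos(a) + ri * math.sin(a) for xi, ri, a in zip(x, rotated, angles)]


def dot(a: List[float], b: List[float]) -> float:
    return sum(u * v for u, v in zip(a, b))


q = [1.0, 2.0, 3.0, 4.0]
k = [0.5, -1.0, 2.0, 0.25]

# Same relative offset (2) at different absolute positions gives the same score
score_near = dot(rope(q, 3), rope(k, 1))
score_far = dot(rope(q, 7), rope(k, 5))
print(abs(score_near - score_far) < 1e-9)
```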

103
models/losses.py Normal file
View File

@@ -0,0 +1,103 @@
from typing import Any, Tuple, Dict, Sequence, Optional

import torch
import torch.nn.functional as F
from torch import nn
import math

IGNORE_LABEL_ID = -100


def s(x, epsilon=1e-30):
    # Stablemax transform: linear for x >= 0, 1 / (1 - x) for x < 0
    return torch.where(
        x<0,
        1/(1-x+ epsilon),
        x + 1
    )


def log_stablemax(x, dim=-1):
    s_x = s(x)
    return torch.log(s_x/torch.sum(s_x, dim=dim, keepdim=True))


def stablemax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None):
    logprobs = log_stablemax(logits.to(torch.float64), dim=-1)

    if valid_mask is None:
        valid_mask = (labels != ignore_index)
    transformed_labels = torch.where(valid_mask, labels, 0)
    prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)

    return -torch.where(valid_mask, prediction_logprobs, 0)


def softmax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None):
    # valid_mask is accepted (and unused) so that ACTLossHead can call both loss
    # functions with the same keyword arguments; masking is handled by ignore_index.
    # Cast logits to f32
    # Flatten logits
    return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape)


class ACTLossHead(nn.Module):
    def __init__(self, model: nn.Module, loss_type: str):
        super().__init__()
        self.model = model
        # Look up the loss function by name, e.g. "stablemax_cross_entropy"
        self.loss_fn = globals()[loss_type]

    def initial_carry(self, *args, **kwargs):
        return self.model.initial_carry(*args, **kwargs)  # type: ignore

    def forward(
        self,
        return_keys: Sequence[str],
        # Model args
        **model_kwargs,
    ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
        # Model logits
        # B x SeqLen x D
        new_carry, outputs = self.model(**model_kwargs)
        labels = new_carry.current_data["labels"]

        with torch.no_grad():
            # Preds
            outputs["preds"] = torch.argmax(outputs["logits"], dim=-1)

            # Correctness
            mask = (labels != IGNORE_LABEL_ID)
            loss_counts = mask.sum(-1)
            loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1)  # Avoid NaNs in division

            is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
            seq_is_correct = is_correct.sum(-1) == loss_counts

            # Metrics (halted)
            valid_metrics = new_carry.halted & (loss_counts > 0)
            metrics = {
                "count": valid_metrics.sum(),

                "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
                "exact_accuracy": (valid_metrics & seq_is_correct).sum(),

                "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
                "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
            }

        # Losses
        lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID, valid_mask=mask) / loss_divisor).sum()
        q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")
        metrics.update({
            "lm_loss": lm_loss.detach(),
            "q_halt_loss": q_halt_loss.detach(),
        })

        # Q continue (bootstrapping target loss); Alexia: This fits Q-learning, but seems totally unnecessary
        q_continue_loss = 0
        if "target_q_continue" in outputs:
            q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")

            metrics["q_continue_loss"] = q_continue_loss.detach()

        # Filter outputs for return
        detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}

        return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
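The stablemax transform defined by `s` and `log_stablemax` can be worked through by hand on a short list of logits. The sketch below is a scalar, standard-library-only illustration; `stablemax` and `stablemax_nll` are hypothetical names, and it handles a single position rather than a batched tensor.

```python
import math
from typing import List


def s(x: float, epsilon: float = 1e-30) -> float:
    """Stablemax transform: x + 1 for x >= 0, 1 / (1 - x) for x < 0."""
    return 1.0 / (1.0 - x + epsilon) if x < 0 else x + 1.0


def stablemax(logits: List[float]) -> List[float]:
    """Normalize the transformed logits so they sum to 1."""
    values = [s(v) for v in logits]
    total = sum(values)
    return [v / total for v in values]


def stablemax_nll(logits: List[float], label: int) -> float:
    """Negative log-likelihood of `label` under stablemax."""
    return -math.log(stablemax(logits)[label])


# s maps [0, 1, -1] to [1, 2, 0.5], so the probabilities are 2/7, 4/7 and 1/7
probs = stablemax([0.0, 1.0, -1.0])
print([round(p, 4) for p in probs])
```

Unlike softmax, the transform grows linearly for positive logits instead of exponentially, so very large logits do not dominate the normalization as sharply.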

View File

@@ -0,0 +1,294 @@
from typing import Tuple, List, Dict, Optional
from dataclasses import dataclass
import math
import torch
import torch.nn.functional as F
from torch import nn
from pydantic import BaseModel
from models.common import trunc_normal_init_
from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
from models.sparse_embedding import CastedSparseEmbedding
@dataclass
class HierarchicalReasoningModel_ACTV1InnerCarry:
z_H: torch.Tensor
z_L: torch.Tensor
@dataclass
class HierarchicalReasoningModel_ACTV1Carry:
inner_carry: HierarchicalReasoningModel_ACTV1InnerCarry
steps: torch.Tensor
halted: torch.Tensor
current_data: Dict[str, torch.Tensor]
class HierarchicalReasoningModel_ACTV1Config(BaseModel):
batch_size: int
seq_len: int
puzzle_emb_ndim: int = 0
num_puzzle_identifiers: int
vocab_size: int
H_cycles: int
L_cycles: int
H_layers: int
L_layers: int
# Transformer config
hidden_size: int
expansion: float
num_heads: int
pos_encodings: str
rms_norm_eps: float = 1e-5
rope_theta: float = 10000.0
# Halting Q-learning config
halt_max_steps: int
halt_exploration_prob: float
forward_dtype: str = "bfloat16"
# Alexia: added
mlp_t: bool=False # use mlp on L instead of transformer
class HierarchicalReasoningModel_ACTV1Block(nn.Module):
def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
super().__init__()
self.config = config
if self.config.mlp_t:
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)
self.mlp_t = SwiGLU(
hidden_size=self.config.seq_len + self.puzzle_emb_len, # L
expansion=config.expansion,
)
else:
self.self_attn = Attention(
hidden_size=config.hidden_size,
head_dim=config.hidden_size // config.num_heads,
num_heads=config.num_heads,
num_key_value_heads=config.num_heads,
causal=False
)
self.mlp = SwiGLU(
hidden_size=config.hidden_size,
expansion=config.expansion,
)
self.norm_eps = config.rms_norm_eps
def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
# B, L, D = hidden_states.shape
# Post Norm
if self.config.mlp_t:
hidden_states = hidden_states.transpose(1,2)
out = self.mlp_t(hidden_states)
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
hidden_states = hidden_states.transpose(1,2)
else:
# Self Attention
hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
# Fully Connected
out = self.mlp(hidden_states)
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
return hidden_states
class HierarchicalReasoningModel_ACTV1ReasoningModule(nn.Module):
def __init__(self, layers: List[HierarchicalReasoningModel_ACTV1Block]):
super().__init__()
self.layers = torch.nn.ModuleList(layers)
def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
# Input injection (add)
hidden_states = hidden_states + input_injection
# Layers
for layer in self.layers:
hidden_states = layer(hidden_states=hidden_states, **kwargs)
return hidden_states
class HierarchicalReasoningModel_ACTV1_Inner(nn.Module):
def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
super().__init__()
self.config = config
self.forward_dtype = getattr(torch, self.config.forward_dtype)
# I/O
self.embed_scale = math.sqrt(self.config.hidden_size)
embed_init_std = 1.0 / self.embed_scale
self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div
if self.config.puzzle_emb_ndim > 0:
# Zero init puzzle embeddings
self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
# LM Blocks
if self.config.pos_encodings == "rope":
self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
base=self.config.rope_theta)
elif self.config.pos_encodings == "learned":
self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
else:
pass
# Reasoning Layers
self.H_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.H_layers)])
self.L_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])
# Initial states
self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
# Q head special init
# Init Q to (almost) zero for faster learning during bootstrapping
with torch.no_grad():
self.q_head.weight.zero_()
self.q_head.bias.fill_(-5) # type: ignore
def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
# Token embedding
embedding = self.embed_tokens(input.to(torch.int32))
# Puzzle embeddings
if self.config.puzzle_emb_ndim > 0:
puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
if pad_count > 0:
puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
# Position embeddings
if self.config.pos_encodings == "learned":
# scale by 1/sqrt(2) to maintain forward variance
embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
# Scale
return self.embed_scale * embedding
def empty_carry(self, batch_size: int):
return HierarchicalReasoningModel_ACTV1InnerCarry(
z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
)
def reset_carry(self, reset_flag: torch.Tensor, carry: HierarchicalReasoningModel_ACTV1InnerCarry):
return HierarchicalReasoningModel_ACTV1InnerCarry(
z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
)
def forward(self, carry: HierarchicalReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
seq_info = dict(
cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
)
# Input encoding
input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
# Forward iterations
with torch.no_grad():
z_H, z_L = carry.z_H, carry.z_L
for _H_step in range(self.config.H_cycles):
for _L_step in range(self.config.L_cycles):
if not ((_H_step == self.config.H_cycles - 1) and (_L_step == self.config.L_cycles - 1)):
z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
if not (_H_step == self.config.H_cycles - 1):
z_H = self.H_level(z_H, z_L, **seq_info)
assert not z_H.requires_grad and not z_L.requires_grad
# 1-step grad
z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
z_H = self.H_level(z_H, z_L, **seq_info)
# LM Outputs
new_carry = HierarchicalReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach()) # New carry no grad
output = self.lm_head(z_H)[:, self.puzzle_emb_len:]
# Q head
q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
class HierarchicalReasoningModel_ACTV1(nn.Module):
"""ACT wrapper."""
def __init__(self, config_dict: dict):
super().__init__()
self.config = HierarchicalReasoningModel_ACTV1Config(**config_dict)
self.inner = HierarchicalReasoningModel_ACTV1_Inner(self.config)
@property
def puzzle_emb(self):
return self.inner.puzzle_emb
def initial_carry(self, batch: Dict[str, torch.Tensor]):
batch_size = batch["inputs"].shape[0]
return HierarchicalReasoningModel_ACTV1Carry(
inner_carry=self.inner.empty_carry(batch_size), # Empty is expected; it will be reset in the first pass because all sequences start halted.
steps=torch.zeros((batch_size, ), dtype=torch.int32),
halted=torch.ones((batch_size, ), dtype=torch.bool), # Default to halted
current_data={k: torch.empty_like(v) for k, v in batch.items()}
)
def forward(self, carry: HierarchicalReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
# Update data, carry (removing halted sequences)
new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
new_steps = torch.where(carry.halted, 0, carry.steps)
new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
# Forward inner model
new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)
outputs = {
"logits": logits,
"q_halt_logits": q_halt_logits,
"q_continue_logits": q_continue_logits
}
with torch.no_grad():
# Step
new_steps = new_steps + 1
is_last_step = new_steps >= self.config.halt_max_steps
halted = is_last_step
# if training, and ACT is enabled
if self.training and (self.config.halt_max_steps > 1):
# Halt signal
# NOTE: During evaluation, always use max steps, this is to guarantee the same halting steps inside a batch for batching purposes
halted = halted | (q_halt_logits > q_continue_logits)
# Exploration
min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
halted = halted & (new_steps >= min_halt_steps)
# Compute target Q
# NOTE: No replay buffer and target networks for computing target Q-value.
# As batch_size is large, there are many parallel envs.
# Similar concept as PQN https://arxiv.org/abs/2407.04811
next_q_halt_logits, next_q_continue_logits = self.inner(new_inner_carry, new_current_data)[-1]
outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))
return HierarchicalReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
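The nested loop in `HierarchicalReasoningModel_ACTV1_Inner.forward` runs all but the final L and H updates under `torch.no_grad()`, then performs one L update and one H update with gradients. The sketch below only enumerates that call order and does no tensor computation; `recursion_schedule` is a hypothetical helper name, and the `"L"` and `"H"` strings stand for calls to `L_level` and `H_level`.

```python
from typing import List, Tuple


def recursion_schedule(h_cycles: int, l_cycles: int) -> Tuple[List[str], List[str]]:
    """Return (calls made without gradients, calls made with gradients)."""
    no_grad: List[str] = []
    for h_step in range(h_cycles):
        for l_step in range(l_cycles):
            is_last = h_step == h_cycles - 1 and l_step == l_cycles - 1
            if not is_last:
                no_grad.append("L")
        if h_step != h_cycles - 1:
            no_grad.append("H")
    # The final L and H updates are the only ones that are backpropagated through
    with_grad = ["L", "H"]
    return no_grad, with_grad


no_grad, with_grad = recursion_schedule(2, 2)
print(no_grad, with_grad)
```

In total this gives `H_cycles * L_cycles` L-level calls and `H_cycles` H-level calls per forward pass, of which only the last of each contributes to the gradient.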

View File

@@ -0,0 +1,342 @@
"""
HRM ACT V2: Transformer Baseline for Architecture Ablation
This is an architecture ablation of the Hierarchical Reasoning Model (HRM).
Key changes from V1:
1. REMOVED hierarchical split (no separate H and L levels)
2. REMOVED inner cycles (no H_cycles/L_cycles loops within reasoning)
3. KEPT ACT outer loop structure intact
4. KEPT all data preprocessing, embeddings, and evaluation infrastructure
Architecture: Single-level transformer that processes the full 30x30 grid as a
900-token sequence, with the same positional encodings and sparse embeddings as V1.
"""
from typing import Tuple, List, Dict, Optional
from dataclasses import dataclass
import math
import torch
import torch.nn.functional as F
from torch import nn
from pydantic import BaseModel
from models.common import trunc_normal_init_
from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
from models.sparse_embedding import CastedSparseEmbedding
@dataclass
class Model_ACTV2InnerCarry:
z_H: torch.Tensor
@dataclass
class Model_ACTV2Carry:
inner_carry: Model_ACTV2InnerCarry
steps: torch.Tensor
halted: torch.Tensor
current_data: Dict[str, torch.Tensor]
class Model_ACTV2Config(BaseModel):
batch_size: int
seq_len: int
puzzle_emb_ndim: int = 0
num_puzzle_identifiers: int
vocab_size: int
H_cycles: int
H_layers: int
# Transformer config
hidden_size: int
expansion: float
num_heads: int
pos_encodings: str
rms_norm_eps: float = 1e-5
rope_theta: float = 10000.0
# Halting Q-learning config
halt_max_steps: int
halt_exploration_prob: float
act_enabled: bool = True # If False, always run halt_max_steps (no early stopping during training)
act_inference: bool = False # If True, use adaptive computation during inference
forward_dtype: str = "bfloat16"
class Model_ACTV2Block(nn.Module):
def __init__(self, config: Model_ACTV2Config) -> None:
super().__init__()
self.self_attn = Attention(
hidden_size=config.hidden_size,
head_dim=config.hidden_size // config.num_heads,
num_heads=config.num_heads,
num_key_value_heads=config.num_heads,
causal=False,
)
self.mlp = SwiGLU(
hidden_size=config.hidden_size,
expansion=config.expansion,
)
self.norm_eps = config.rms_norm_eps
def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
# Post Norm
# Self Attention
hidden_states = rms_norm(
hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states),
variance_epsilon=self.norm_eps,
)
# Fully Connected
hidden_states = rms_norm(hidden_states + self.mlp(hidden_states), variance_epsilon=self.norm_eps)
return hidden_states
class Model_ACTV2ReasoningModule(nn.Module):
def __init__(self, layers: List[Model_ACTV2Block]):
super().__init__()
self.layers = torch.nn.ModuleList(layers)
def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
# Input injection (add)
hidden_states = hidden_states + input_injection
# Layers
for layer in self.layers:
hidden_states = layer(hidden_states=hidden_states, **kwargs)
return hidden_states
class Model_ACTV2_Inner(nn.Module):
def __init__(self, config: Model_ACTV2Config) -> None:
super().__init__()
self.config = config
self.forward_dtype = getattr(torch, self.config.forward_dtype)
# I/O
self.embed_scale = math.sqrt(self.config.hidden_size)
embed_init_std = 1.0 / self.embed_scale
self.embed_tokens = CastedEmbedding(
self.config.vocab_size,
self.config.hidden_size,
init_std=embed_init_std,
cast_to=self.forward_dtype,
)
self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div
if self.config.puzzle_emb_ndim > 0:
# Zero init puzzle embeddings
self.puzzle_emb = CastedSparseEmbedding(
self.config.num_puzzle_identifiers,
self.config.puzzle_emb_ndim,
batch_size=self.config.batch_size,
init_std=0,
cast_to=self.forward_dtype,
)
# LM Blocks
if self.config.pos_encodings == "rope":
self.rotary_emb = RotaryEmbedding(
dim=self.config.hidden_size // self.config.num_heads,
max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
base=self.config.rope_theta,
)
elif self.config.pos_encodings == "learned":
self.embed_pos = CastedEmbedding(
self.config.seq_len + self.puzzle_emb_len,
self.config.hidden_size,
init_std=embed_init_std,
cast_to=self.forward_dtype,
)
else:
raise NotImplementedError()
# Reasoning Layers
self.H_level = Model_ACTV2ReasoningModule(
layers=[Model_ACTV2Block(self.config) for _i in range(self.config.H_layers)]
)
# Initial states
self.H_init = nn.Buffer(
trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1),
persistent=True,
)
# Q head special init
# Init Q to (almost) zero for faster learning during bootstrapping
with torch.no_grad():
self.q_head.weight.zero_()
self.q_head.bias.fill_(-5) # type: ignore
def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
# Token embedding
embedding = self.embed_tokens(input.to(torch.int32))
# Puzzle embeddings
if self.config.puzzle_emb_ndim > 0:
puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
if pad_count > 0:
puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
embedding = torch.cat(
(puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2
)
# Position embeddings
if self.config.pos_encodings == "learned":
# scale by 1/sqrt(2) to maintain forward variance
embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
# Scale
return self.embed_scale * embedding
def empty_carry(self, batch_size: int):
return Model_ACTV2InnerCarry(
z_H=torch.empty(
batch_size,
self.config.seq_len + self.puzzle_emb_len,
self.config.hidden_size,
dtype=self.forward_dtype,
),
)
def reset_carry(self, reset_flag: torch.Tensor, carry: Model_ACTV2InnerCarry):
return Model_ACTV2InnerCarry(
z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
)
def forward(
self, carry: Model_ACTV2InnerCarry, batch: Dict[str, torch.Tensor]
) -> Tuple[Model_ACTV2InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
seq_info = dict(
cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
)
# Input encoding
input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
# 1-step grad
z_H = self.H_level(carry.z_H, input_embeddings, **seq_info)
# LM Outputs
new_carry = Model_ACTV2InnerCarry(
z_H=z_H.detach(),
) # New carry no grad
output = self.lm_head(z_H)[:, self.puzzle_emb_len :]
# Q head
q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
class Model_ACTV2(nn.Module):
"""ACT wrapper."""
def __init__(self, config_dict: dict):
super().__init__()
self.config = Model_ACTV2Config(**config_dict)
self.inner = Model_ACTV2_Inner(self.config)
@property
def puzzle_emb(self):
return self.inner.puzzle_emb
def initial_carry(self, batch: Dict[str, torch.Tensor]):
batch_size = batch["inputs"].shape[0]
return Model_ACTV2Carry(
inner_carry=self.inner.empty_carry(
batch_size
), # Empty is expected; it will be reset in the first pass because all sequences start halted.
steps=torch.zeros((batch_size,), dtype=torch.int32),
halted=torch.ones((batch_size,), dtype=torch.bool), # Default to halted
current_data={k: torch.empty_like(v) for k, v in batch.items()},
)
def forward(
self,
carry: Model_ACTV2Carry,
batch: Dict[str, torch.Tensor],
compute_target_q: bool = False,
) -> Tuple[Model_ACTV2Carry, Dict[str, torch.Tensor]]:
# Update data, carry (removing halted sequences)
new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
new_steps = torch.where(carry.halted, 0, carry.steps)
new_current_data = {
k: torch.where(carry.halted.view((-1,) + (1,) * (batch[k].ndim - 1)), batch[k], v)
for k, v in carry.current_data.items()
}
# Forward inner model
new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(
new_inner_carry, new_current_data
)
outputs = {"logits": logits, "q_halt_logits": q_halt_logits, "q_continue_logits": q_continue_logits}
with torch.no_grad():
# Step
new_steps = new_steps + 1
is_last_step = new_steps >= self.config.halt_max_steps
halted = is_last_step
# Check if adaptive computation should be used
use_adaptive = (self.config.halt_max_steps > 1) and (
(self.training and self.config.act_enabled)
or (not self.training and self.config.act_inference)
)
if use_adaptive:
# Halt signal based on Q-values (but always halt at max steps)
q_halt_signal = q_halt_logits > q_continue_logits
halted = halted | q_halt_signal
# Store actual steps used for logging (only during inference)
if not self.training:
outputs["actual_steps"] = new_steps.float()
# Exploration (only during training)
if self.training:
min_halt_steps = (
torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob
) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
halted = halted & (new_steps >= min_halt_steps)
# Compute target Q (only during training)
# NOTE: No replay buffer and target networks for computing target Q-value.
# As batch_size is large, there're many parallel envs.
# Similar concept as PQN https://arxiv.org/abs/2407.04811
if self.training and compute_target_q:
next_q_halt_logits, next_q_continue_logits = self.inner(
new_inner_carry, new_current_data
)[-1]
outputs["target_q_continue"] = torch.sigmoid(
torch.where(
is_last_step,
next_q_halt_logits,
torch.maximum(next_q_halt_logits, next_q_continue_logits),
)
)
return Model_ACTV2Carry(
new_inner_carry, new_steps, halted, new_current_data
), outputs
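
Note (illustrative, not part of the commit): the halting rule in the wrapper's `forward` above can be sketched in plain Python, using lists in place of tensors. The function name `act_halting_step` and its argument layout are hypothetical. The sketch assumes the adaptive path is active (that is, `use_adaptive` is true) and that `min_halt_steps` is 0 for sequences that are not exploring.

```python
from typing import List, Optional, Tuple


def act_halting_step(
    steps: List[int],
    q_halt: List[float],
    q_continue: List[float],
    halt_max_steps: int,
    min_halt_steps: Optional[List[int]] = None,
) -> Tuple[List[int], List[bool]]:
    """One ACT bookkeeping step per sequence (hypothetical helper).

    A sequence halts when it reaches halt_max_steps, or when its halt
    Q-value exceeds its continue Q-value. If min_halt_steps is given
    (exploration), halting is suppressed until that many steps are done.
    """
    new_steps = [s + 1 for s in steps]
    halted = []
    for i, step in enumerate(new_steps):
        is_last_step = step >= halt_max_steps
        halt = is_last_step or (q_halt[i] > q_continue[i])
        if min_halt_steps is not None:
            # Exploration: force at least min_halt_steps[i] steps.
            halt = halt and (step >= min_halt_steps[i])
        halted.append(halt)
    return new_steps, halted


if __name__ == "__main__":
    print(act_halting_step([0, 3, 1], [1.0, -1.0, 2.0], [0.0, 0.0, 0.0], 4))
```

Because the exploration threshold never exceeds `halt_max_steps`, a sequence at its last step still halts, which matches the "always halt at max steps" comment in the code above.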

@@ -0,0 +1,297 @@
from typing import Tuple, List, Dict, Optional
from dataclasses import dataclass
import math
import torch
import copy
import torch.nn.functional as F
from torch import nn
from pydantic import BaseModel
import random
from models.common import trunc_normal_init_
from models.layers import rms_norm, LinearSwish, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
from models.sparse_embedding import CastedSparseEmbedding
IGNORE_LABEL_ID = -100
@dataclass
class TinyRecursiveReasoningModel_ACTV1InnerCarry:
z_H: torch.Tensor
z_L: torch.Tensor
@dataclass
class TinyRecursiveReasoningModel_ACTV1Carry:
inner_carry: TinyRecursiveReasoningModel_ACTV1InnerCarry
steps: torch.Tensor
halted: torch.Tensor
current_data: Dict[str, torch.Tensor]
class TinyRecursiveReasoningModel_ACTV1Config(BaseModel):
batch_size: int
seq_len: int
puzzle_emb_ndim: int = 0
num_puzzle_identifiers: int
vocab_size: int
H_cycles: int
L_cycles: int
H_layers: int # ignored
L_layers: int
# Transformer config
hidden_size: int
expansion: float
num_heads: int
pos_encodings: str
rms_norm_eps: float = 1e-5
rope_theta: float = 10000.0
# Halting Q-learning config
halt_max_steps: int
halt_exploration_prob: float
forward_dtype: str = "bfloat16"
# Alexia: added
mlp_t: bool = False # use mlp on L instead of transformer
puzzle_emb_len: int = 16 # if non-zero, used as-is; if 0, derived as ceil(puzzle_emb_ndim / hidden_size)
no_ACT_continue: bool = True # No continue ACT loss, only use the sigmoid of the halt which makes much more sense
class TinyRecursiveReasoningModel_ACTV1Block(nn.Module):
def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
super().__init__()
self.config = config
if self.config.mlp_t:
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len
self.mlp_t = SwiGLU(
hidden_size=self.config.seq_len + self.puzzle_emb_len, # L
expansion=config.expansion,
)
else:
self.self_attn = Attention(
hidden_size=config.hidden_size,
head_dim=config.hidden_size // config.num_heads,
num_heads=config.num_heads,
num_key_value_heads=config.num_heads,
causal=False
)
self.mlp = SwiGLU(
hidden_size=config.hidden_size,
expansion=config.expansion,
)
self.norm_eps = config.rms_norm_eps
def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
# B, L, D = hidden_states.shape
# Post Norm
if self.config.mlp_t:
hidden_states = hidden_states.transpose(1,2)
out = self.mlp_t(hidden_states)
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
hidden_states = hidden_states.transpose(1,2)
else:
# Self Attention
hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
# Fully Connected
out = self.mlp(hidden_states)
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
return hidden_states
class TinyRecursiveReasoningModel_ACTV1ReasoningModule(nn.Module):
def __init__(self, layers: List[TinyRecursiveReasoningModel_ACTV1Block]):
super().__init__()
self.layers = torch.nn.ModuleList(layers)
def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
hidden_states = hidden_states + input_injection
for layer in self.layers:
hidden_states = layer(hidden_states=hidden_states, **kwargs)
return hidden_states
class TinyRecursiveReasoningModel_ACTV1_Inner(nn.Module):
def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
super().__init__()
self.config = config
self.forward_dtype = getattr(torch, self.config.forward_dtype)
# I/O
self.embed_scale = math.sqrt(self.config.hidden_size)
embed_init_std = 1.0 / self.embed_scale
self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len # ceil div
if self.config.puzzle_emb_ndim > 0:
# Zero init puzzle embeddings
self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
# LM Blocks
if self.config.pos_encodings == "rope":
self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
base=self.config.rope_theta)
elif self.config.pos_encodings == "learned":
self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
else:
pass
# Reasoning Layers
self.L_level = TinyRecursiveReasoningModel_ACTV1ReasoningModule(layers=[TinyRecursiveReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])
# Initial states
self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
# Q head special init
# Init Q to (almost) zero for faster learning during bootstrapping
with torch.no_grad():
self.q_head.weight.zero_()
self.q_head.bias.fill_(-5) # type: ignore
def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
# Token embedding
embedding = self.embed_tokens(input.to(torch.int32))
# Puzzle embeddings
if self.config.puzzle_emb_ndim > 0:
puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
if pad_count > 0:
puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
# Position embeddings
if self.config.pos_encodings == "learned":
# scale by 1/sqrt(2) to maintain forward variance
embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
# Scale
return self.embed_scale * embedding
def empty_carry(self, batch_size: int):
return TinyRecursiveReasoningModel_ACTV1InnerCarry(
z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
)
def reset_carry(self, reset_flag: torch.Tensor, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry):
return TinyRecursiveReasoningModel_ACTV1InnerCarry(
z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
)
def forward(self, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
seq_info = dict(
cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
)
# Input encoding
input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
# Forward iterations
it = 0
z_H, z_L = carry.z_H, carry.z_L
# H_cycles-1 without grad
with torch.no_grad():
for _H_step in range(self.config.H_cycles-1):
for _L_step in range(self.config.L_cycles):
z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
z_H = self.L_level(z_H, z_L, **seq_info)
# 1 with grad
for _L_step in range(self.config.L_cycles):
z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
z_H = self.L_level(z_H, z_L, **seq_info)
# LM Outputs
new_carry = TinyRecursiveReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach()) # New carry no grad
output = self.lm_head(z_H)[:, self.puzzle_emb_len:]
q_logits = self.q_head(z_H[:, 0]).to(torch.float32) # Q-head; uses the first puzzle_emb position
return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
class TinyRecursiveReasoningModel_ACTV1(nn.Module):
"""ACT wrapper."""
def __init__(self, config_dict: dict):
super().__init__()
self.config = TinyRecursiveReasoningModel_ACTV1Config(**config_dict)
self.inner = TinyRecursiveReasoningModel_ACTV1_Inner(self.config)
@property
def puzzle_emb(self):
return self.inner.puzzle_emb
def initial_carry(self, batch: Dict[str, torch.Tensor]):
batch_size = batch["inputs"].shape[0]
return TinyRecursiveReasoningModel_ACTV1Carry(
inner_carry=self.inner.empty_carry(batch_size), # Empty is expected; it is reset on the first pass because all sequences start halted.
steps=torch.zeros((batch_size, ), dtype=torch.int32),
halted=torch.ones((batch_size, ), dtype=torch.bool), # Default to halted
current_data={k: torch.empty_like(v) for k, v in batch.items()}
)
def forward(self, carry: TinyRecursiveReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
# Update data, carry (removing halted sequences)
new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
new_steps = torch.where(carry.halted, 0, carry.steps)
new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
# Forward inner model
new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)
outputs = {
"logits": logits,
"q_halt_logits": q_halt_logits,
"q_continue_logits": q_continue_logits
}
with torch.no_grad():
# Step
new_steps = new_steps + 1
is_last_step = new_steps >= self.config.halt_max_steps
halted = is_last_step
# if training, and ACT is enabled
if self.training and (self.config.halt_max_steps > 1):
# Halt signal
# NOTE: During evaluation, always use max steps; this guarantees the same number of halting steps within a batch, for batching purposes
if self.config.no_ACT_continue:
halted = halted | (q_halt_logits > 0)
else:
halted = halted | (q_halt_logits > q_continue_logits)
# Exploration
min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
halted = halted & (new_steps >= min_halt_steps)
if not self.config.no_ACT_continue:
# Compute target Q
# NOTE: No replay buffer and target networks for computing target Q-value.
# As batch_size is large, there're many parallel envs.
# Similar concept as PQN https://arxiv.org/abs/2407.04811
# self.inner returns (new_carry, output, (q_halt_logits, q_continue_logits)), so unpack three values
_, _, (next_q_halt_logits, next_q_continue_logits) = self.inner(new_inner_carry, new_current_data)
outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))
return TinyRecursiveReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
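
Note (illustrative, not part of the commit): the `puzzle_emb_len` expression in this file uses negated floor division as a ceiling division, and `_input_embeddings` then zero-pads the puzzle embedding up to a whole number of `hidden_size` rows. A minimal sketch of that arithmetic follows. The helper names `resolve_puzzle_emb_len` and `puzzle_pad_count` are hypothetical and do not exist in the repository.

```python
def resolve_puzzle_emb_len(puzzle_emb_ndim: int, hidden_size: int, puzzle_emb_len: int = 16) -> int:
    """Number of sequence positions reserved for the puzzle embedding (hypothetical helper)."""
    if puzzle_emb_len != 0:
        # An explicit non-zero value is used as-is.
        return puzzle_emb_len
    # -(a // -b) equals ceil(a / b) for positive b, using only integer arithmetic.
    return -(puzzle_emb_ndim // -hidden_size)


def puzzle_pad_count(puzzle_emb_ndim: int, hidden_size: int, puzzle_emb_len: int) -> int:
    """Zeros appended so the embedding reshapes to (puzzle_emb_len, hidden_size)."""
    return puzzle_emb_len * hidden_size - puzzle_emb_ndim


if __name__ == "__main__":
    n = resolve_puzzle_emb_len(513, 512, puzzle_emb_len=0)
    print(n, puzzle_pad_count(513, 512, n))
```

With the default `puzzle_emb_len=16`, the derived value is bypassed, so a 512-dimensional puzzle embedding is padded out to 16 positions.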

@@ -0,0 +1,323 @@
from typing import Tuple, List, Dict, Optional
from dataclasses import dataclass
import math
import torch
import copy
import torch.nn.functional as F
from torch import nn
from pydantic import BaseModel
import random
from models.common import trunc_normal_init_
from models.layers import rms_norm, LinearSwish, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
from models.sparse_embedding import CastedSparseEmbedding
IGNORE_LABEL_ID = -100
@dataclass
class TinyRecursiveReasoningModel_ACTV1InnerCarry:
z_H: torch.Tensor
z_L1: torch.Tensor
z_L2: torch.Tensor
z_L3: torch.Tensor
z_L4: torch.Tensor
z_L5: torch.Tensor
z_L6: torch.Tensor
@dataclass
class TinyRecursiveReasoningModel_ACTV1Carry:
inner_carry: TinyRecursiveReasoningModel_ACTV1InnerCarry
steps: torch.Tensor
halted: torch.Tensor
current_data: Dict[str, torch.Tensor]
class TinyRecursiveReasoningModel_ACTV1Config(BaseModel):
batch_size: int
seq_len: int
puzzle_emb_ndim: int = 0
num_puzzle_identifiers: int
vocab_size: int
H_cycles: int
L_cycles: int
H_layers: int # ignored
L_layers: int
# Transformer config
hidden_size: int
expansion: float
num_heads: int
pos_encodings: str
rms_norm_eps: float = 1e-5
rope_theta: float = 10000.0
# Halting Q-learning config
halt_max_steps: int
halt_exploration_prob: float
forward_dtype: str = "bfloat16"
# Alexia: added
mlp_t: bool = False # use mlp on L instead of transformer
puzzle_emb_len: int = 16 # if non-zero, used as-is; if 0, derived as ceil(puzzle_emb_ndim / hidden_size)
no_ACT_continue: bool = True # No continue ACT loss, only use the sigmoid of the halt which makes much more sense
class TinyRecursiveReasoningModel_ACTV1Block(nn.Module):
def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
super().__init__()
self.config = config
if self.config.mlp_t:
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len
self.mlp_t = SwiGLU(
hidden_size=self.config.seq_len + self.puzzle_emb_len, # L
expansion=config.expansion,
)
else:
self.self_attn = Attention(
hidden_size=config.hidden_size,
head_dim=config.hidden_size // config.num_heads,
num_heads=config.num_heads,
num_key_value_heads=config.num_heads,
causal=False
)
self.mlp = SwiGLU(
hidden_size=config.hidden_size,
expansion=config.expansion,
)
self.norm_eps = config.rms_norm_eps
def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
# B, L, D = hidden_states.shape
# Post Norm
if self.config.mlp_t:
hidden_states = hidden_states.transpose(1,2)
out = self.mlp_t(hidden_states)
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
hidden_states = hidden_states.transpose(1,2)
else:
# Self Attention
hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
# Fully Connected
out = self.mlp(hidden_states)
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
return hidden_states
class TinyRecursiveReasoningModel_ACTV1ReasoningModule(nn.Module):
def __init__(self, layers: List[TinyRecursiveReasoningModel_ACTV1Block]):
super().__init__()
self.layers = torch.nn.ModuleList(layers)
def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
hidden_states = hidden_states + input_injection
for layer in self.layers:
hidden_states = layer(hidden_states=hidden_states, **kwargs)
return hidden_states
class TinyRecursiveReasoningModel_ACTV1_Inner(nn.Module):
def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
super().__init__()
self.config = config
self.forward_dtype = getattr(torch, self.config.forward_dtype)
# I/O
self.embed_scale = math.sqrt(self.config.hidden_size)
embed_init_std = 1.0 / self.embed_scale
self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len # ceil div
if self.config.puzzle_emb_ndim > 0:
# Zero init puzzle embeddings
self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
# LM Blocks
if self.config.pos_encodings == "rope":
self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
base=self.config.rope_theta)
elif self.config.pos_encodings == "learned":
self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
else:
pass
# Reasoning Layers
self.L_level = TinyRecursiveReasoningModel_ACTV1ReasoningModule(layers=[TinyRecursiveReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])
# Initial states
self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
self.L1_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
self.L2_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
self.L3_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
self.L4_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
self.L5_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
self.L6_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
# Q head special init
# Init Q to (almost) zero for faster learning during bootstrapping
with torch.no_grad():
self.q_head.weight.zero_()
self.q_head.bias.fill_(-5) # type: ignore
def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
# Token embedding
embedding = self.embed_tokens(input.to(torch.int32))
# Puzzle embeddings
if self.config.puzzle_emb_ndim > 0:
puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
if pad_count > 0:
puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
# Position embeddings
if self.config.pos_encodings == "learned":
# scale by 1/sqrt(2) to maintain forward variance
embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
# Scale
return self.embed_scale * embedding
def empty_carry(self, batch_size: int):
return TinyRecursiveReasoningModel_ACTV1InnerCarry(
z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
z_L1=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
z_L2=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
z_L3=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
z_L4=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
z_L5=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
z_L6=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
)
def reset_carry(self, reset_flag: torch.Tensor, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry):
return TinyRecursiveReasoningModel_ACTV1InnerCarry(
z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
z_L1=torch.where(reset_flag.view(-1, 1, 1), self.L1_init, carry.z_L1),
z_L2=torch.where(reset_flag.view(-1, 1, 1), self.L2_init, carry.z_L2),
z_L3=torch.where(reset_flag.view(-1, 1, 1), self.L3_init, carry.z_L3),
z_L4=torch.where(reset_flag.view(-1, 1, 1), self.L4_init, carry.z_L4),
z_L5=torch.where(reset_flag.view(-1, 1, 1), self.L5_init, carry.z_L5),
z_L6=torch.where(reset_flag.view(-1, 1, 1), self.L6_init, carry.z_L6),
)
def forward(self, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
seq_info = dict(
cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
)
# Input encoding
input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
# Forward iterations
it = 0
z_H, z_L = carry.z_H, [carry.z_L1, carry.z_L2, carry.z_L3, carry.z_L4, carry.z_L5, carry.z_L6]
# H_cycles-1 without grad
with torch.no_grad():
for _H_step in range(self.config.H_cycles-1):
for _L_step in range(self.config.L_cycles):
# Sum all six latents; the result overwrites slot _L_step below, so L_cycles must not exceed 6
z_L_ = z_L[0] + z_L[1] + z_L[2] + z_L[3] + z_L[4] + z_L[5]
z_L[_L_step] = self.L_level(z_L_, z_H + input_embeddings, **seq_info)
z_L_ = z_L[0] + z_L[1] + z_L[2] + z_L[3] + z_L[4] + z_L[5]
z_H = self.L_level(z_H, z_L_, **seq_info)
# 1 with grad
for _L_step in range(self.config.L_cycles):
z_L_ = z_L[0] + z_L[1] + z_L[2] + z_L[3] + z_L[4] + z_L[5]
z_L[_L_step] = self.L_level(z_L_, z_H + input_embeddings, **seq_info)
z_L_ = z_L[0] + z_L[1] + z_L[2] + z_L[3] + z_L[4] + z_L[5]
z_H = self.L_level(z_H, z_L_, **seq_info)
# LM Outputs
new_carry = TinyRecursiveReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L1=z_L[0].detach(), z_L2=z_L[1].detach(), z_L3=z_L[2].detach(), z_L4=z_L[3].detach(), z_L5=z_L[4].detach(), z_L6=z_L[5].detach()) # New carry no grad
output = self.lm_head(z_H)[:, self.puzzle_emb_len:]
q_logits = self.q_head(z_H[:, 0]).to(torch.float32) # Q-head; uses the first puzzle_emb position
return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
class TinyRecursiveReasoningModel_ACTV1(nn.Module):
"""ACT wrapper."""
def __init__(self, config_dict: dict):
super().__init__()
self.config = TinyRecursiveReasoningModel_ACTV1Config(**config_dict)
self.inner = TinyRecursiveReasoningModel_ACTV1_Inner(self.config)
@property
def puzzle_emb(self):
return self.inner.puzzle_emb
def initial_carry(self, batch: Dict[str, torch.Tensor]):
batch_size = batch["inputs"].shape[0]
return TinyRecursiveReasoningModel_ACTV1Carry(
inner_carry=self.inner.empty_carry(batch_size), # Empty is expected; it is reset on the first pass because all sequences start halted.
steps=torch.zeros((batch_size, ), dtype=torch.int32),
halted=torch.ones((batch_size, ), dtype=torch.bool), # Default to halted
current_data={k: torch.empty_like(v) for k, v in batch.items()}
)
def forward(self, carry: TinyRecursiveReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
# Update data, carry (removing halted sequences)
new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
new_steps = torch.where(carry.halted, 0, carry.steps)
new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
# Forward inner model
new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)
outputs = {
"logits": logits,
"q_halt_logits": q_halt_logits,
"q_continue_logits": q_continue_logits
}
with torch.no_grad():
# Step
new_steps = new_steps + 1
is_last_step = new_steps >= self.config.halt_max_steps
halted = is_last_step
# if training, and ACT is enabled
if self.training and (self.config.halt_max_steps > 1):
# Halt signal
# NOTE: During evaluation, always use max steps; this guarantees the same number of halting steps within a batch, for batching purposes
if self.config.no_ACT_continue:
halted = halted | (q_halt_logits > 0)
else:
halted = halted | (q_halt_logits > q_continue_logits)
# Exploration
min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
halted = halted & (new_steps >= min_halt_steps)
if not self.config.no_ACT_continue:
# Compute target Q
# NOTE: No replay buffer and target networks for computing target Q-value.
# As batch_size is large, there're many parallel envs.
# Similar concept as PQN https://arxiv.org/abs/2407.04811
# self.inner returns (new_carry, output, (q_halt_logits, q_continue_logits)), so unpack three values
_, _, (next_q_halt_logits, next_q_continue_logits) = self.inner(new_inner_carry, new_current_data)
outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))
return TinyRecursiveReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
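
Note (illustrative, not part of the commit): the update schedule of the multi-latent variant in this file can be sketched with plain numbers. The name `recurse_multi_latent` is hypothetical, and `f(hidden, injection)` is a stand-in for `L_level`, which in the real model adds the injection to the hidden state and runs the shared blocks. The sketch ignores the gradient handling, where only the last H cycle is tracked.

```python
from typing import Callable, List, Sequence, Tuple

Number = int


def recurse_multi_latent(
    z_H: Number,
    z_L: Sequence[Number],
    x: Number,
    f: Callable[[Number, Number], Number],
    H_cycles: int,
    L_cycles: int,
) -> Tuple[Number, List[Number]]:
    """Schedule sketch: each L step writes one latent slot, then z_H is refreshed."""
    z_L = list(z_L)
    # Step i overwrites slot i, so there must be at least L_cycles slots.
    assert L_cycles <= len(z_L)
    for _ in range(H_cycles):
        for i in range(L_cycles):
            # Every latent update reads the sum of all latents plus the question.
            z_L[i] = f(sum(z_L), z_H + x)
        # The answer state is then updated from the summed latents.
        z_H = f(z_H, sum(z_L))
    return z_H, z_L


if __name__ == "__main__":
    add = lambda hidden, injection: hidden + injection
    print(recurse_multi_latent(0, [0] * 6, 1, add, H_cycles=1, L_cycles=2))
```

Slots beyond `L_cycles` are never written and keep their initial value, which is visible in the returned list.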

@@ -0,0 +1,294 @@
from typing import Tuple, List, Dict, Optional
from dataclasses import dataclass
import math
import torch
import copy
import torch.nn.functional as F
from torch import nn
from pydantic import BaseModel
import random
from models.common import trunc_normal_init_
from models.layers import rms_norm, LinearSwish, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
from models.sparse_embedding import CastedSparseEmbedding
IGNORE_LABEL_ID = -100
@dataclass
class TinyRecursiveReasoningModel_ACTV1InnerCarry:
z_L: torch.Tensor
@dataclass
class TinyRecursiveReasoningModel_ACTV1Carry:
inner_carry: TinyRecursiveReasoningModel_ACTV1InnerCarry
steps: torch.Tensor
halted: torch.Tensor
current_data: Dict[str, torch.Tensor]
class TinyRecursiveReasoningModel_ACTV1Config(BaseModel):
batch_size: int
seq_len: int
puzzle_emb_ndim: int = 0
num_puzzle_identifiers: int
vocab_size: int
H_cycles: int
L_cycles: int
H_layers: int # ignored
L_layers: int
# Transformer config
hidden_size: int
expansion: float
num_heads: int
pos_encodings: str
rms_norm_eps: float = 1e-5
rope_theta: float = 10000.0
# Halting Q-learning config
halt_max_steps: int
halt_exploration_prob: float
forward_dtype: str = "bfloat16"
# Alexia: added
mlp_t: bool = False # use mlp on L instead of transformer
puzzle_emb_len: int = 16 # if non-zero, used as-is; if 0, derived as ceil(puzzle_emb_ndim / hidden_size)
no_ACT_continue: bool = True # No continue ACT loss, only use the sigmoid of the halt which makes much more sense
class TinyRecursiveReasoningModel_ACTV1Block(nn.Module):
def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
super().__init__()
self.config = config
if self.config.mlp_t:
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len
self.mlp_t = SwiGLU(
hidden_size=self.config.seq_len + self.puzzle_emb_len, # L
expansion=config.expansion,
)
else:
self.self_attn = Attention(
hidden_size=config.hidden_size,
head_dim=config.hidden_size // config.num_heads,
num_heads=config.num_heads,
num_key_value_heads=config.num_heads,
causal=False
)
self.mlp = SwiGLU(
hidden_size=config.hidden_size,
expansion=config.expansion,
)
self.norm_eps = config.rms_norm_eps
def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
# B, L, D = hidden_states.shape
# Post Norm
if self.config.mlp_t:
hidden_states = hidden_states.transpose(1,2)
out = self.mlp_t(hidden_states)
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
hidden_states = hidden_states.transpose(1,2)
else:
# Self Attention
hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
# Fully Connected
out = self.mlp(hidden_states)
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
return hidden_states
class TinyRecursiveReasoningModel_ACTV1ReasoningModule(nn.Module):
def __init__(self, layers: List[TinyRecursiveReasoningModel_ACTV1Block]):
super().__init__()
self.layers = torch.nn.ModuleList(layers)
def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
for layer in self.layers:
hidden_states = layer(hidden_states=hidden_states, **kwargs)
return hidden_states
class TinyRecursiveReasoningModel_ACTV1_Inner(nn.Module):
def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
super().__init__()
self.config = config
self.forward_dtype = getattr(torch, self.config.forward_dtype)
# I/O
self.embed_scale = math.sqrt(self.config.hidden_size)
embed_init_std = 1.0 / self.embed_scale
self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len # ceil div
if self.config.puzzle_emb_ndim > 0:
# Zero init puzzle embeddings
self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
# LM Blocks
if self.config.pos_encodings == "rope":
self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
base=self.config.rope_theta)
elif self.config.pos_encodings == "learned":
self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
else:
pass
# Reasoning Layers
self.L_level = TinyRecursiveReasoningModel_ACTV1ReasoningModule(layers=[TinyRecursiveReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])
# Initial states
self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
# Q head special init
# Init Q to (almost) zero for faster learning during bootstrapping
with torch.no_grad():
self.q_head.weight.zero_()
self.q_head.bias.fill_(-5) # type: ignore
def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
# Token embedding
embedding = self.embed_tokens(input.to(torch.int32))
# Puzzle embeddings
if self.config.puzzle_emb_ndim > 0:
puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
if pad_count > 0:
puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
# Position embeddings
if self.config.pos_encodings == "learned":
# scale by 1/sqrt(2) to maintain forward variance
embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
# Scale
return self.embed_scale * embedding
def empty_carry(self, batch_size: int):
return TinyRecursiveReasoningModel_ACTV1InnerCarry(
z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
)
def reset_carry(self, reset_flag: torch.Tensor, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry):
return TinyRecursiveReasoningModel_ACTV1InnerCarry(
z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
)
def forward(self, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
seq_info = dict(
cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
)
# Input encoding
input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
# Forward iterations
it = 0
z_L = carry.z_L
# H_cycles-1 without grad
with torch.no_grad():
for _H_step in range(self.config.H_cycles-1):
for _L_step in range(self.config.L_cycles):
z_L = self.L_level(z_L + input_embeddings, **seq_info)
z_L = self.L_level(z_L, **seq_info)
# 1 with grad
for _L_step in range(self.config.L_cycles):
z_L = self.L_level(z_L + input_embeddings, **seq_info)
z_L = self.L_level(z_L, **seq_info)
z_out = z_L
# LM Outputs
new_carry = TinyRecursiveReasoningModel_ACTV1InnerCarry(z_L=z_L.detach()) # New carry no grad
output = self.lm_head(z_out)[:, self.puzzle_emb_len:]
q_logits = self.q_head(z_out[:, 0]).to(torch.float32) # Q-head; uses the first puzzle_emb position
return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
class TinyRecursiveReasoningModel_ACTV1(nn.Module):
"""ACT wrapper."""
def __init__(self, config_dict: dict):
super().__init__()
self.config = TinyRecursiveReasoningModel_ACTV1Config(**config_dict)
self.inner = TinyRecursiveReasoningModel_ACTV1_Inner(self.config)
@property
def puzzle_emb(self):
return self.inner.puzzle_emb
def initial_carry(self, batch: Dict[str, torch.Tensor]):
batch_size = batch["inputs"].shape[0]
return TinyRecursiveReasoningModel_ACTV1Carry(
inner_carry=self.inner.empty_carry(batch_size), # Empty is expected; it will be reset in the first pass because all sequences start halted.
steps=torch.zeros((batch_size, ), dtype=torch.int32),
halted=torch.ones((batch_size, ), dtype=torch.bool), # Default to halted
current_data={k: torch.empty_like(v) for k, v in batch.items()}
)
def forward(self, carry: TinyRecursiveReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
# Update data, carry (removing halted sequences)
new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
new_steps = torch.where(carry.halted, 0, carry.steps)
new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
# Forward inner model
new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)
outputs = {
"logits": logits,
"q_halt_logits": q_halt_logits,
"q_continue_logits": q_continue_logits
}
with torch.no_grad():
# Step
new_steps = new_steps + 1
is_last_step = new_steps >= self.config.halt_max_steps
halted = is_last_step
# if training, and ACT is enabled
if self.training and (self.config.halt_max_steps > 1):
# Halt signal
# NOTE: During evaluation, always run the maximum number of steps so that every sequence in a batch halts at the same step (needed for batching).
if self.config.no_ACT_continue:
halted = halted | (q_halt_logits > 0)
else:
halted = halted | (q_halt_logits > q_continue_logits)
# Exploration
min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
halted = halted & (new_steps >= min_halt_steps)
if not self.config.no_ACT_continue:
# Compute target Q
# NOTE: No replay buffer or target network is used for computing the target Q-value.
# As batch_size is large, there are many parallel envs.
# Similar concept as PQN https://arxiv.org/abs/2407.04811
# The inner forward returns (carry, logits, (q_halt_logits, q_continue_logits)).
_, _, (next_q_halt_logits, next_q_continue_logits) = self.inner(new_inner_carry, new_current_data)
outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))
return TinyRecursiveReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
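# Illustrative sketch (not part of the original file).
# The training-time halting rule in TinyRecursiveReasoningModel_ACTV1.forward
# above can be easier to follow for a single sequence with plain Python scalars.
# The helper name `_sketch_should_halt` and its arguments are hypothetical;
# the real code applies the same logic element-wise on tensors, and at
# evaluation time it always runs until halt_max_steps.
def _sketch_should_halt(step, halt_max_steps, q_halt, q_continue,
                        no_act_continue=True, min_halt_steps=0):
    # A sequence always halts once it has used up its step budget.
    is_last_step = step >= halt_max_steps
    # Otherwise the Q-head decides: compare the halt logit against zero
    # (no_ACT_continue) or against the continue logit.
    if no_act_continue:
        wants_halt = q_halt > 0
    else:
        wants_halt = q_halt > q_continue
    halted = is_last_step or wants_halt
    # Exploration: a randomly drawn minimum step count can delay halting.
    return halted and step >= min_halt_steps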

132
models/sparse_embedding.py Normal file

@@ -0,0 +1,132 @@
from typing import Union
import torch
from torch import nn
import torch.distributed as dist
from torch.optim.optimizer import Optimizer, ParamsT
from models.common import trunc_normal_init_
class CastedSparseEmbedding(nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int, batch_size: int, init_std: float, cast_to: torch.dtype):
        super().__init__()
        self.cast_to = cast_to

        # Real Weights
        # Truncated LeCun normal init
        self.weights = nn.Buffer(
            trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std), persistent=True
        )

        # Local weights and IDs
        # Local embeddings, with gradient, not persistent
        self.local_weights = nn.Buffer(torch.zeros(batch_size, embedding_dim, requires_grad=True), persistent=False)
        # Local embedding IDs, not persistent
        self.local_ids = nn.Buffer(torch.zeros(batch_size, dtype=torch.int32), persistent=False)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        if not self.training:
            # Test mode, no gradient
            return self.weights[inputs].to(self.cast_to)

        # Training mode, fill puzzle embedding from weights
        with torch.no_grad():
            self.local_weights.copy_(self.weights[inputs])
            self.local_ids.copy_(inputs)

        return self.local_weights.to(self.cast_to)
class CastedSparseEmbeddingSignSGD_Distributed(Optimizer):
    def __init__(
        self,
        params: ParamsT,
        world_size: int,
        lr: Union[float, torch.Tensor] = 1e-3,
        weight_decay: float = 1e-2,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            weight_decay=weight_decay,
            world_size=world_size
        )
        super().__init__(params, defaults)

    @torch.no_grad
    def step(self, closure=None):  # type: ignore
        for group in self.param_groups:
            # Find the sparse embedding weights
            local_weights_grad = None
            local_ids = None
            weights = None

            assert len(group["params"]) == 3
            for p in group["params"]:
                if p.requires_grad:
                    local_weights_grad = p.grad
                elif p.ndim == 1:
                    local_ids = p
                elif p.ndim == 2:
                    weights = p
                else:
                    assert False

            assert local_ids is not None
            assert weights is not None

            # Apply SignSGD
            # Adam ≈ SignSGD if gradient is very sparse
            if local_weights_grad is not None:
                _sparse_emb_signsgd_dist(
                    local_weights_grad,
                    local_ids,
                    weights,
                    lr=group["lr"],
                    weight_decay=group["weight_decay"],
                    world_size=group["world_size"]
                )
def _sparse_emb_signsgd_dist(
    local_weights_grad: torch.Tensor,
    local_ids: torch.Tensor,
    weights: torch.Tensor,
    lr: float,
    weight_decay: float,
    world_size: int
) -> None:
    N, D = local_weights_grad.shape

    # All-gather
    all_weights_grad = local_weights_grad
    all_ids = local_ids

    if world_size > 1:
        all_weights_grad = torch.empty((world_size * N, D), dtype=local_weights_grad.dtype, device=local_weights_grad.device)
        all_ids = torch.empty(world_size * N, dtype=local_ids.dtype, device=local_ids.device)

        dist.all_gather_into_tensor(all_weights_grad, local_weights_grad)
        dist.all_gather_into_tensor(all_ids, local_ids)

    # Unique
    grad_ids, inv = all_ids.unique(return_inverse=True)

    grad = torch.zeros((grad_ids.shape[0], D), dtype=all_weights_grad.dtype, device=all_weights_grad.device)
    grad.scatter_add_(0, inv.unsqueeze(-1).expand(-1, D), all_weights_grad)

    # SignSGD with decoupled weight decay
    p = weights[grad_ids]

    p.mul_(1.0 - lr * weight_decay).add_(torch.sign(grad), alpha=-lr)

    # Write updated slices back
    weights[grad_ids] = p
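

# Illustrative sketch (not part of the original file).
# A minimal pure-Python rendering of the arithmetic in _sparse_emb_signsgd_dist
# above, assuming a single process (world_size == 1). The helper name
# `_sketch_sign_sgd_update` and the dict-of-lists storage are hypothetical and
# chosen only to make the steps easy to trace by hand; the real code works on
# tensors.
def _sketch_sign_sgd_update(weights, ids, grads, lr, weight_decay):
    # Sum the gradients of rows that share an embedding id
    # (the role of unique + scatter_add_ in the tensor version).
    summed = {}
    for row_id, grad_row in zip(ids, grads):
        acc = summed.setdefault(row_id, [0.0] * len(grad_row))
        for j, g in enumerate(grad_row):
            acc[j] += g
    # Decoupled weight decay, then a step of size lr against the gradient sign.
    # (g > 0) - (g < 0) is the sign of g, with sign(0) == 0.
    for row_id, grad_row in summed.items():
        weights[row_id] = [
            w * (1.0 - lr * weight_decay) - lr * ((g > 0) - (g < 0))
            for w, g in zip(weights[row_id], grad_row)
        ]
    return weights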

654
pretrain.py Normal file

@@ -0,0 +1,654 @@
from typing import Optional, Any, Sequence, List
from dataclasses import dataclass
import os
import math
import yaml
import shutil
import copy
import torch
import torch.distributed as dist
from torch import nn
from torch.utils.data import DataLoader
import tqdm
import wandb
import coolname
import hydra
import pydantic
from omegaconf import DictConfig
from adam_atan2 import AdamATan2
from puzzle_dataset import PuzzleDataset, PuzzleDatasetConfig, PuzzleDatasetMetadata
from utils.functions import load_model_class, get_model_source_path
from models.sparse_embedding import CastedSparseEmbeddingSignSGD_Distributed
from models.ema import EMAHelper
class LossConfig(pydantic.BaseModel):
model_config = pydantic.ConfigDict(extra='allow')
name: str
class ArchConfig(pydantic.BaseModel):
model_config = pydantic.ConfigDict(extra='allow')
name: str
loss: LossConfig
class EvaluatorConfig(pydantic.BaseModel):
model_config = pydantic.ConfigDict(extra="allow")
name: str
class PretrainConfig(pydantic.BaseModel):
# Config
arch: ArchConfig
# Data
data_paths: List[str]
data_paths_test: List[str] = []
# Evaluators
evaluators: List[EvaluatorConfig] = []
# Hyperparams
global_batch_size: int
epochs: int
lr: float
lr_min_ratio: float
lr_warmup_steps: int
weight_decay: float
beta1: float
beta2: float
# Puzzle embedding
puzzle_emb_lr: float
puzzle_emb_weight_decay: float
# Names
project_name: Optional[str] = None
run_name: Optional[str] = None
load_checkpoint: Optional[str] = None
checkpoint_path: Optional[str] = None
# Extras
seed: int = 0
checkpoint_every_eval: bool = False
eval_interval: Optional[int] = None
min_eval_interval: Optional[int] = 0 # first training iteration (0-based) at which evaluation runs
eval_save_outputs: List[str] = []
ema: bool = False # use Exponential-Moving-Average
ema_rate: float = 0.999 # EMA-rate
freeze_weights: bool = False # If True, freeze weights and only learn the embeddings
@dataclass
class TrainState:
model: nn.Module
optimizers: Sequence[torch.optim.Optimizer]
optimizer_lrs: Sequence[float]
carry: Any
step: int
total_steps: int
def create_dataloader(config: PretrainConfig, split: str, rank: int, world_size: int, **kwargs):
dataset = PuzzleDataset(PuzzleDatasetConfig(
seed=config.seed,
dataset_paths=config.data_paths_test if len(config.data_paths_test)>0 and split=="test" else config.data_paths,
rank=rank,
num_replicas=world_size,
**kwargs
), split=split)
dataloader = DataLoader(
dataset,
batch_size=None,
num_workers=1,
prefetch_factor=8,
pin_memory=True,
persistent_workers=True
)
return dataloader, dataset.metadata
def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, rank: int, world_size: int):
model_cfg = dict(
**config.arch.__pydantic_extra__, # type: ignore
batch_size=config.global_batch_size // world_size,
vocab_size=train_metadata.vocab_size,
seq_len=train_metadata.seq_len,
num_puzzle_identifiers=train_metadata.num_puzzle_identifiers,
causal=False # Non-autoregressive
)
# Instantiate model with loss head
model_cls = load_model_class(config.arch.name)
loss_head_cls = load_model_class(config.arch.loss.name)
with torch.device("cuda"):
model: nn.Module = model_cls(model_cfg)
print(model)
model = loss_head_cls(model, **config.arch.loss.__pydantic_extra__) # type: ignore
if "DISABLE_COMPILE" not in os.environ:
model = torch.compile(model) # type: ignore
# Load checkpoint
if rank == 0:
load_checkpoint(model, config)
# Broadcast parameters from rank 0
if world_size > 1:
with torch.no_grad():
for param in list(model.parameters()) + list(model.buffers()):
dist.broadcast(param, src=0)
# Optimizers and lr
if config.arch.puzzle_emb_ndim == 0:
optimizers = [
AdamATan2(
model.parameters(),
lr=0, # Needs to be set by scheduler
weight_decay=config.weight_decay,
betas=(config.beta1, config.beta2)
)
]
optimizer_lrs = [
config.lr
]
elif config.freeze_weights:
optimizers = [
CastedSparseEmbeddingSignSGD_Distributed(
model.model.puzzle_emb.buffers(), # type: ignore
lr=0, # Needs to be set by scheduler
weight_decay=config.puzzle_emb_weight_decay,
world_size=world_size
)
]
optimizer_lrs = [
config.puzzle_emb_lr
]
else:
optimizers = [
CastedSparseEmbeddingSignSGD_Distributed(
model.model.puzzle_emb.buffers(), # type: ignore
lr=0, # Needs to be set by scheduler
weight_decay=config.puzzle_emb_weight_decay,
world_size=world_size
),
AdamATan2(
model.parameters(),
lr=0, # Needs to be set by scheduler
weight_decay=config.weight_decay,
betas=(config.beta1, config.beta2)
)
]
optimizer_lrs = [
config.puzzle_emb_lr,
config.lr
]
return model, optimizers, optimizer_lrs
def mix_weights_direct(device, alpha, net, nets):
    sd = []
    for i in range(len(nets)):
        sd += [nets[i].state_dict()]
    sd_alpha = {}
    for k in sd[0].keys():
        comb_net = alpha[0]*sd[0][k].to(device)
        for i in range(1,len(nets)):
            comb_net += alpha[i]*sd[i][k].to(device)
        sd_alpha[k] = comb_net
    net.load_state_dict(sd_alpha)
    return net
def cosine_schedule_with_warmup_lr_lambda(
    current_step: int, *, base_lr: float, num_warmup_steps: int, num_training_steps: int, min_ratio: float = 0.0, num_cycles: float = 0.5
):
    if current_step < num_warmup_steps:
        return base_lr * float(current_step) / float(max(1, num_warmup_steps))

    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
    return base_lr * (min_ratio + max(0.0, (1 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))))
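

# Illustrative sketch (not part of the original file).
# The schedule above, restated for the default num_cycles=0.5, where the cosine
# argument reduces to pi * progress. The helper name `_sketch_warmup_cosine_lr`
# is hypothetical. The learning rate ramps linearly from 0 to base_lr during
# warmup, then decays along a half cosine to base_lr * min_ratio at the final
# step.
import math


def _sketch_warmup_cosine_lr(step, base_lr, warmup_steps, total_steps, min_ratio=0.0):
    if step < warmup_steps:
        # Linear warmup from 0 to base_lr
        return base_lr * step / max(1, warmup_steps)
    # Fraction of the post-warmup phase completed, in [0, 1]
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * (min_ratio + max(0.0, (1 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * progress))))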
def init_train_state(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, rank: int, world_size: int):
# Estimated total training steps
total_steps = int(config.epochs * train_metadata.total_groups * train_metadata.mean_puzzle_examples / config.global_batch_size)
# Model
model, optimizers, optimizer_lrs = create_model(config, train_metadata, rank=rank, world_size=world_size)
return TrainState(
step=0,
total_steps=total_steps,
model=model,
optimizers=optimizers,
optimizer_lrs=optimizer_lrs,
carry=None
)
def save_train_state(config: PretrainConfig, train_state: TrainState):
# FIXME: Only the model weights are saved (optimizer state, step, and carry are not).
if config.checkpoint_path is None:
return
os.makedirs(config.checkpoint_path, exist_ok=True)
torch.save(train_state.model.state_dict(), os.path.join(config.checkpoint_path, f"step_{train_state.step}"))
def load_checkpoint(model: nn.Module, config: PretrainConfig):
if config.load_checkpoint is not None:
print(f"Loading checkpoint {config.load_checkpoint}")
# Load state dict
state_dict = torch.load(config.load_checkpoint, map_location="cuda")
# Resize and reset puzzle emb if needed
puzzle_emb_name = "_orig_mod.model.inner.puzzle_emb.weights"
expected_shape: torch.Size = model.model.puzzle_emb.weights.shape # type: ignore
if puzzle_emb_name in state_dict:
puzzle_emb = state_dict[puzzle_emb_name]
if puzzle_emb.shape != expected_shape:
print(f"Resetting puzzle embedding as shape is different. Found {puzzle_emb.shape}, Expected {expected_shape}")
# Re-initialize using mean
state_dict[puzzle_emb_name] = (
torch.mean(puzzle_emb, dim=0, keepdim=True).expand(expected_shape).contiguous()
)
model.load_state_dict(state_dict, assign=True)
def compute_lr(base_lr: float, config: PretrainConfig, train_state: TrainState):
return cosine_schedule_with_warmup_lr_lambda(
current_step=train_state.step,
base_lr=base_lr,
num_warmup_steps=round(config.lr_warmup_steps),
num_training_steps=train_state.total_steps,
min_ratio=config.lr_min_ratio
)
def create_evaluators(config: PretrainConfig, eval_metadata: PuzzleDatasetMetadata) -> List[Any]:
data_paths = config.data_paths_test if len(config.data_paths_test) > 0 else config.data_paths
# Initialize evaluators
evaluators = []
for cfg in config.evaluators:
for data_path in data_paths:
cls = load_model_class(cfg.name, "evaluators.")(
data_path=data_path, eval_metadata=eval_metadata, **cfg.__pydantic_extra__
) # type: ignore
evaluators.append(cls)
return evaluators
def train_batch(config: PretrainConfig, train_state: TrainState, batch: Any, global_batch_size: int, rank: int, world_size: int):
train_state.step += 1
if train_state.step > train_state.total_steps: # At most train_total_steps
return
# To device
batch = {k: v.cuda() for k, v in batch.items()}
# Init carry if it is None
if train_state.carry is None:
with torch.device("cuda"):
train_state.carry = train_state.model.initial_carry(batch) # type: ignore
# Forward
train_state.carry, loss, metrics, _, _ = train_state.model(carry=train_state.carry, batch=batch, return_keys=[])
((1 / global_batch_size) * loss).backward()
# Allreduce
if world_size > 1:
for param in train_state.model.parameters():
if param.grad is not None:
dist.all_reduce(param.grad)
# Apply optimizer
lr_this_step = None
for optim, base_lr in zip(train_state.optimizers, train_state.optimizer_lrs):
lr_this_step = compute_lr(base_lr, config, train_state)
for param_group in optim.param_groups:
param_group['lr'] = lr_this_step
optim.step()
optim.zero_grad()
# Reduce metrics
if len(metrics):
assert not any(v.requires_grad for v in metrics.values())
metric_keys = list(sorted(metrics.keys())) # Sort keys to guarantee all processes use the same order.
# Reduce and reconstruct
metric_values = torch.stack([metrics[k] for k in metric_keys])
if world_size > 1:
dist.reduce(metric_values, dst=0)
if rank == 0:
metric_values = metric_values.cpu().numpy()
reduced_metrics = {k: metric_values[i] for i, k in enumerate(metric_keys)}
# Postprocess
count = max(reduced_metrics["count"], 1) # Avoid NaNs
reduced_metrics = {f"train/{k}": v / (global_batch_size if k.endswith("loss") else count) for k, v in reduced_metrics.items()}
reduced_metrics["train/lr"] = lr_this_step
return reduced_metrics
def evaluate(
config: PretrainConfig,
train_state: TrainState,
eval_loader: torch.utils.data.DataLoader,
eval_metadata: PuzzleDatasetMetadata,
evaluators: List[Any],
rank: int,
world_size: int,
cpu_group: Optional[dist.ProcessGroup],
):
reduced_metrics = None
with torch.inference_mode():
return_keys = set(config.eval_save_outputs)
for evaluator in evaluators:
evaluator.begin_eval()
return_keys.update(evaluator.required_outputs)
# Run evaluation
set_ids = {k: idx for idx, k in enumerate(eval_metadata.sets)}
save_preds = {}
metric_keys = []
metric_values = None
carry = None
processed_batches = 0
for set_name, batch, global_batch_size in eval_loader:
processed_batches += 1
if rank == 0:
print(f"Processing batch {processed_batches}: {set_name}")
# To device
batch = {k: v.cuda() for k, v in batch.items()}
with torch.device("cuda"):
carry = train_state.model.initial_carry(batch) # type: ignore
# Forward
inference_steps = 0
while True:
carry, loss, metrics, preds, all_finish = train_state.model(
carry=carry, batch=batch, return_keys=return_keys
)
inference_steps += 1
if all_finish:
break
if rank == 0:
print(f" Completed inference in {inference_steps} steps")
for collection in (batch, preds):
for k, v in collection.items():
if k in config.eval_save_outputs:
save_preds.setdefault(k, [])
save_preds[k].append(v.cpu()) # Move to CPU for saving GPU memory
for evaluator in evaluators:
evaluator.update_batch(batch, preds)
del carry, loss, preds, batch, all_finish
# Aggregate metrics
set_id = set_ids[set_name]
if metric_values is None:
metric_keys = list(
sorted(metrics.keys())
) # Sort keys to guarantee all processes use the same order.
metric_values = torch.zeros(
(len(set_ids), len(metrics.values())), dtype=torch.float32, device="cuda"
)
metric_values[set_id] += torch.stack([metrics[k] for k in metric_keys])
del metrics
# concatenate save preds
save_preds = {k: torch.cat(v, dim=0) for k, v in save_preds.items()}
# Save preds
if config.checkpoint_path is not None and len(save_preds):
# Each rank saves its predictions independently
os.makedirs(config.checkpoint_path, exist_ok=True)
torch.save(
save_preds, os.path.join(config.checkpoint_path, f"step_{train_state.step}_all_preds.{rank}")
)
del save_preds
# Reduce to rank 0
if metric_values is not None:
if world_size > 1:
dist.reduce(metric_values, dst=0)
if rank == 0:
reduced_metrics = metric_values.cpu().numpy()
reduced_metrics = {
set_name: {
metric_name: reduced_metrics[set_id, metric_id]
for metric_id, metric_name in enumerate(metric_keys)
}
for set_id, set_name in enumerate(set_ids)
}
# Postprocess
for set_name, m in reduced_metrics.items():
count = m.pop("count")
reduced_metrics[set_name] = {k: v / count for k, v in m.items()}
# Run evaluators
if rank == 0:
print(f"\nRunning {len(evaluators)} evaluator(s)...")
for i, evaluator in enumerate(evaluators):
if rank == 0:
print(f"Running evaluator {i+1}/{len(evaluators)}: {evaluator.__class__.__name__}")
# Path for saving
evaluator_save_path = None
if config.checkpoint_path is not None:
evaluator_save_path = os.path.join(
config.checkpoint_path,
f"evaluator_{evaluator.__class__.__name__}_step_{train_state.step}",
)
os.makedirs(evaluator_save_path, exist_ok=True)
# Run and log
metrics = evaluator.result(evaluator_save_path, rank=rank, world_size=world_size, group=cpu_group)
if rank == 0 and metrics is not None:
if reduced_metrics is None:
reduced_metrics = {}
reduced_metrics.update(metrics)
print(f" Completed {evaluator.__class__.__name__}")
if rank == 0:
print("All evaluators completed!")
return reduced_metrics
def save_code_and_config(config: PretrainConfig):
if config.checkpoint_path is None or wandb.run is None:
return
os.makedirs(config.checkpoint_path, exist_ok=True)
# Copy code
code_list = [
get_model_source_path(config.arch.name),
get_model_source_path(config.arch.loss.name)
]
for code_file in code_list:
if code_file is not None:
code_name = os.path.basename(code_file)
shutil.copy(code_file, os.path.join(config.checkpoint_path, code_name))
# Dump config as yaml
config_file = os.path.join(config.checkpoint_path, "all_config.yaml")
with open(config_file, "wt") as f:
yaml.dump(config.model_dump(), f)
# Log code
wandb.run.log_code(config.checkpoint_path)
def load_synced_config(hydra_config: DictConfig, rank: int, world_size: int) -> PretrainConfig:
objects = [None]
if rank == 0:
config = PretrainConfig(**hydra_config) # type: ignore
# Naming
if config.project_name is None:
config.project_name = f"{os.path.basename(config.data_paths[0]).capitalize()}-ACT-torch"
if config.run_name is None:
config.run_name = f"{config.arch.name.split('@')[-1]} {coolname.generate_slug(2)}"
if config.checkpoint_path is None:
config.checkpoint_path = os.path.join("checkpoints", config.project_name, config.run_name)
objects = [config]
if world_size > 1:
dist.broadcast_object_list(objects, src=0)
return objects[0] # type: ignore
@hydra.main(config_path="config", config_name="cfg_pretrain", version_base=None)
def launch(hydra_config: DictConfig):
RANK = 0
WORLD_SIZE = 1
CPU_PROCESS_GROUP = None
# Initialize distributed training if in distributed environment (e.g. torchrun)
if "LOCAL_RANK" in os.environ:
# Initialize distributed, default device and dtype
dist.init_process_group(backend="nccl")
RANK = dist.get_rank()
WORLD_SIZE = dist.get_world_size()
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
# CPU GLOO process group
CPU_PROCESS_GROUP = dist.new_group(backend="gloo")
assert (
dist.get_rank(CPU_PROCESS_GROUP) == RANK and dist.get_world_size(CPU_PROCESS_GROUP) == WORLD_SIZE
)
# Load sync'ed config
config = load_synced_config(hydra_config, rank=RANK, world_size=WORLD_SIZE)
# Seed RNGs to ensure consistency
torch.random.manual_seed(config.seed + RANK)
# Dataset
train_epochs_per_iter = config.eval_interval if config.eval_interval is not None else config.epochs
total_iters = config.epochs // train_epochs_per_iter
assert config.epochs % train_epochs_per_iter == 0, "Eval interval must be a divisor of total epochs."
train_loader, train_metadata = create_dataloader(config, "train", test_set_mode=False, epochs_per_iter=train_epochs_per_iter, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)
try:
eval_loader, eval_metadata = create_dataloader(config, "test", test_set_mode=True, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)
except Exception:
print("NO EVAL DATA FOUND")
eval_loader = eval_metadata = None
try:
evaluators = create_evaluators(config, eval_metadata)
except Exception:
print("No evaluator found")
evaluators = []
# Train state
train_state = init_train_state(config, train_metadata, rank=RANK, world_size=WORLD_SIZE)
# Progress bar and logger
progress_bar = None
ema_helper = None
if RANK == 0:
progress_bar = tqdm.tqdm(total=train_state.total_steps)
wandb.init(project=config.project_name, name=config.run_name, config=config.model_dump(), settings=wandb.Settings(_disable_stats=True)) # type: ignore
wandb.log({"num_params": sum(x.numel() for x in train_state.model.parameters())}, step=0)
save_code_and_config(config)
if config.ema:
print('Setup EMA')
ema_helper = EMAHelper(mu=config.ema_rate)
ema_helper.register(train_state.model)
# Training Loop
for _iter_id in range(total_iters):
print(f"[Rank {RANK}, World Size {WORLD_SIZE}]: Epoch {_iter_id * train_epochs_per_iter}")
############ Train Iter
if RANK == 0:
print("TRAIN")
train_state.model.train()
for set_name, batch, global_batch_size in train_loader:
metrics = train_batch(config, train_state, batch, global_batch_size, rank=RANK, world_size=WORLD_SIZE)
if RANK == 0 and metrics is not None:
wandb.log(metrics, step=train_state.step)
progress_bar.update(train_state.step - progress_bar.n) # type: ignore
if config.ema:
ema_helper.update(train_state.model)
if _iter_id >= config.min_eval_interval:
############ Evaluation
if RANK == 0:
print("EVALUATE")
if config.ema:
print("SWITCH TO EMA")
train_state_eval = copy.deepcopy(train_state)
train_state_eval.model = ema_helper.ema_copy(train_state_eval.model)
else:
train_state_eval = train_state
train_state_eval.model.eval()
metrics = evaluate(config,
train_state_eval,
eval_loader,
eval_metadata,
evaluators,
rank=RANK,
world_size=WORLD_SIZE,
cpu_group=CPU_PROCESS_GROUP)
if RANK == 0 and metrics is not None:
wandb.log(metrics, step=train_state.step)
############ Checkpointing
if RANK == 0:
print("SAVE CHECKPOINT")
if RANK == 0 and (config.checkpoint_every_eval or (_iter_id == total_iters - 1)):
save_train_state(config, train_state_eval)
if config.ema:
del train_state_eval
# finalize
if dist.is_initialized():
dist.destroy_process_group()
wandb.finish()
if __name__ == "__main__":
launch()

250
puzzle_dataset.py Normal file

@@ -0,0 +1,250 @@
import os
import json
from typing import Tuple, List, Dict, Optional
import numpy as np
import pydantic
import torch
from torch.utils.data import IterableDataset, get_worker_info
from models.losses import IGNORE_LABEL_ID
from dataset.common import PuzzleDatasetMetadata
from argdantic import ArgParser
from pydantic import BaseModel
def _sample_batch(rng: np.random.Generator, group_order: np.ndarray, puzzle_indices: np.ndarray, group_indices: np.ndarray, start_index: int, global_batch_size: int):
# Pack examples into a full batch
batch = []
batch_puzzle_indices = []
current_size = 0
while (start_index < group_order.size) and (current_size < global_batch_size):
# Pick a group and a puzzle from that group
group_id = group_order[start_index]
puzzle_id = rng.integers(group_indices[group_id], group_indices[group_id + 1])
start_index += 1
# Get range of the puzzle
puzzle_start = puzzle_indices[puzzle_id]
puzzle_size = int(puzzle_indices[puzzle_id + 1] - puzzle_start)
append_size = min(puzzle_size, global_batch_size - current_size)
# Put into batch
batch_puzzle_indices.append(np.full(append_size, puzzle_id, dtype=np.int32))
batch.append(puzzle_start + rng.choice(puzzle_size, append_size, replace=False)) # Use the seeded generator so sampling is reproducible and identical across ranks
current_size += append_size
return start_index, np.concatenate(batch), np.concatenate(batch_puzzle_indices)
class PuzzleDatasetConfig(pydantic.BaseModel):
seed: int
dataset_paths: List[str]
global_batch_size: int
test_set_mode: bool
epochs_per_iter: int # Batch X epochs in an iteration to reduce overhead.
rank: int
num_replicas: int
class PuzzleDataset(IterableDataset):
def __init__(self, config: PuzzleDatasetConfig, split: str = "train"):
super().__init__()
self.config = config
self.split = split
# Merge multiple metadata
prev_seq_len = None
prev_vocab_size = None
prev_pad_id = None
prev_ignore_label_id = None
prev_blank_identifier_id = None
prev_sets = None
prev_num_identifiers = None
mean_puzzle_examples = 0
total_puzzles = 0
total_groups = 0
num_identifiers = 0
for dataset_path in config.dataset_paths:
current_metadata = self._load_metadata(dataset_path)
if prev_seq_len is None:
prev_seq_len = current_metadata.seq_len
prev_vocab_size = current_metadata.vocab_size
prev_pad_id = current_metadata.pad_id
prev_ignore_label_id = current_metadata.ignore_label_id
prev_blank_identifier_id = current_metadata.blank_identifier_id
prev_sets = current_metadata.sets
prev_num_identifiers = current_metadata.num_puzzle_identifiers
else:
assert prev_seq_len == current_metadata.seq_len
assert prev_vocab_size == current_metadata.vocab_size
assert prev_pad_id == current_metadata.pad_id
assert prev_ignore_label_id == current_metadata.ignore_label_id
assert prev_blank_identifier_id == current_metadata.blank_identifier_id
assert prev_sets == current_metadata.sets
assert prev_num_identifiers == current_metadata.num_puzzle_identifiers
mean_puzzle_examples += current_metadata.mean_puzzle_examples*current_metadata.total_puzzles
total_puzzles += current_metadata.total_puzzles
total_groups += current_metadata.total_groups
num_identifiers += current_metadata.num_puzzle_identifiers
mean_puzzle_examples = mean_puzzle_examples / total_puzzles
self.metadata = PuzzleDatasetMetadata(
seq_len=prev_seq_len,
vocab_size=prev_vocab_size,
pad_id=prev_pad_id,
ignore_label_id=prev_ignore_label_id,
blank_identifier_id=prev_blank_identifier_id,
num_puzzle_identifiers=num_identifiers,
total_groups=total_groups,
mean_puzzle_examples=mean_puzzle_examples,
total_puzzles=total_puzzles,
sets=prev_sets
)
# Checks
assert self.config.global_batch_size % self.config.num_replicas == 0, f"Global batch size {self.config.global_batch_size} must be a multiple of the number of replicas ({self.config.num_replicas})."
self.local_batch_size = self.config.global_batch_size // self.config.num_replicas
# State
self._data = None
self._iters = 0
def _load_metadata(self, dataset_path) -> PuzzleDatasetMetadata:
with open(os.path.join(dataset_path, self.split, "dataset.json"), "r") as f:
return PuzzleDatasetMetadata(**json.load(f))
def _lazy_load_dataset(self):
if self._data is not None:
return
field_mmap_modes = {
"inputs": "r",
"labels": "r",
# Keep indices in memory
"puzzle_identifiers": None,
"puzzle_indices": None,
"group_indices": None
}
# Load data
self._data = {}
for set_name in self.metadata.sets: # Load subset
for i, dataset_path in enumerate(self.config.dataset_paths):
if i > 0:
set_name_ = set_name + str(i)
else:
set_name_ = set_name
self._data[set_name_] = {
field_name: np.load(os.path.join(dataset_path, self.split, f"{set_name}__{field_name}.npy"), mmap_mode=mmap_mode)
for field_name, mmap_mode in field_mmap_modes.items()
}
def _collate_batch(self, batch):
# Convert dtype
batch = {k: v.astype(np.int32) for k, v in batch.items()}
# Convert ignore label IDs
if self.metadata.ignore_label_id is not None:
batch["labels"][batch["labels"] == self.metadata.ignore_label_id] = IGNORE_LABEL_ID
# Pad
if batch["puzzle_identifiers"].size < self.local_batch_size:
pad_size = self.local_batch_size - batch["puzzle_identifiers"].size
pad_values = {
"inputs": self.metadata.pad_id,
"labels": IGNORE_LABEL_ID,
"puzzle_identifiers": self.metadata.blank_identifier_id
}
batch = {k: np.pad(v, ((0, pad_size), ) + ((0, 0), ) * (v.ndim - 1), constant_values=pad_values[k]) for k, v in batch.items()}
# To tensor
return {k: torch.from_numpy(v) for k, v in batch.items()}
    def _iter_test(self):
        for set_i, (set_name, dataset) in enumerate(self._data.items()):  # type: ignore
            total_examples = len(dataset["inputs"])

            # Walk through the examples one global batch at a time
            start_index = 0
            while start_index < total_examples:
                # Compute the global window and this rank's slice of it
                end_index = min(total_examples, start_index + self.config.global_batch_size)
                local_start = start_index + self.config.rank * self.local_batch_size
                local_end = min(start_index + (self.config.rank + 1) * self.local_batch_size, end_index)

                # Get batch of examples, and also puzzle IDs
                puzzle_indices = []
                puzzle_index = np.searchsorted(dataset["puzzle_indices"], local_start, side="right") - 1
                for i in range(local_start, local_end):
                    while puzzle_index + 1 < len(dataset["puzzle_indices"]) and i >= dataset["puzzle_indices"][puzzle_index + 1]:
                        puzzle_index += 1
                    puzzle_indices.append(puzzle_index)

                batch = self._collate_batch({
                    "inputs": dataset["inputs"][local_start: local_end],
                    "labels": dataset["labels"][local_start: local_end],
                    "puzzle_identifiers": dataset["puzzle_identifiers"][puzzle_indices]
                })
                yield set_name, batch, end_index - start_index

                # Advance to next batch
                start_index += self.config.global_batch_size
    def _iter_train(self):
        for set_name, dataset in self._data.items():  # type: ignore
            # Increase epoch count
            self._iters += 1

            # Randomly shuffle groups
            rng = np.random.Generator(np.random.Philox(seed=self.config.seed + self._iters))
            group_order = np.concatenate([rng.permutation(dataset["group_indices"].size - 1) for _i in range(self.config.epochs_per_iter)])
            start_index = 0

            while start_index < group_order.size:
                start_index, batch_indices, batch_puzzle_indices = _sample_batch(
                    rng,
                    group_order=group_order,
                    puzzle_indices=dataset["puzzle_indices"],
                    group_indices=dataset["group_indices"],
                    start_index=start_index,
                    global_batch_size=self.config.global_batch_size,
                )

                # Select current rank and collate
                global_effective_batch_size = batch_puzzle_indices.size  # Global effective batch size, excluding pads

                # Drop last batch
                if global_effective_batch_size < self.config.global_batch_size:
                    break

                batch_indices = batch_indices[self.config.rank * self.local_batch_size: (self.config.rank + 1) * self.local_batch_size]
                batch_puzzle_indices = batch_puzzle_indices[self.config.rank * self.local_batch_size: (self.config.rank + 1) * self.local_batch_size]
                batch = self._collate_batch({
                    "inputs": dataset["inputs"][batch_indices],
                    "labels": dataset["labels"][batch_indices],
                    "puzzle_identifiers": dataset["puzzle_identifiers"][batch_puzzle_indices]
                })
                yield set_name, batch, global_effective_batch_size
    def __iter__(self):
        worker_info = get_worker_info()
        assert worker_info is None or worker_info.num_workers == 1, "Multithreaded data loading is not currently supported."

        self._lazy_load_dataset()

        # Iterate using specified mode
        if self.config.test_set_mode:
            yield from self._iter_test()
        else:
            yield from self._iter_train()
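
The index arithmetic in `_iter_test` can be checked in isolation. The following is a minimal sketch and not part of the repository: `shard_test_batches` and `puzzle_of_example` are hypothetical helper names, the standard-library `bisect_right` stands in for `np.searchsorted(..., side="right")`, and `puzzle_indices` is assumed to be a sorted list of the first example index of each puzzle followed by a final sentinel.

```python
from bisect import bisect_right


def shard_test_batches(total_examples, global_batch_size, num_replicas, rank):
    """Yield (local_start, local_end, effective_global_size) for one rank.

    Hypothetical helper that mirrors the window arithmetic of _iter_test.
    """
    assert global_batch_size % num_replicas == 0
    local_batch_size = global_batch_size // num_replicas
    start = 0
    while start < total_examples:
        # Global window covered by all ranks in this step
        end = min(total_examples, start + global_batch_size)
        # Slice of that window owned by this rank
        local_start = start + rank * local_batch_size
        local_end = min(start + (rank + 1) * local_batch_size, end)
        yield local_start, local_end, end - start
        start += global_batch_size


def puzzle_of_example(puzzle_indices, example_index):
    """Return the index of the puzzle containing example_index.

    Assumes puzzle_indices[p] is the first example of puzzle p.
    """
    return bisect_right(puzzle_indices, example_index) - 1


rank0 = list(shard_test_batches(total_examples=10, global_batch_size=8, num_replicas=2, rank=0))
rank1 = list(shard_test_batches(total_examples=10, global_batch_size=8, num_replicas=2, rank=1))
print(rank0)  # [(0, 4, 8), (8, 10, 2)]
print(rank1)  # [(4, 8, 8), (12, 10, 2)]
print(puzzle_of_example([0, 3, 5, 10], 4))  # 1
```

In the last step, rank 1 gets `local_start > local_end`, so its slice is empty. In the dataset class, `_collate_batch` would then pad that batch up to `local_batch_size`, which keeps every rank yielding the same number of batches.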

20
requirements.txt Normal file

@@ -0,0 +1,20 @@
torch
adam-atan2
einops
tqdm
coolname
pydantic
argdantic
wandb
omegaconf
hydra-core
huggingface_hub
packaging
ninja
wheel
setuptools
setuptools-scm
pydantic-core
huggingface_hub
numba
triton

19
utils/functions.py Normal file

@@ -0,0 +1,19 @@
import importlib
import inspect


def load_model_class(identifier: str, prefix: str = "models."):
    """Resolve a "module_path@ClassName" identifier to a class under `prefix`."""
    module_path, class_name = identifier.split('@')

    # Import the module
    module = importlib.import_module(prefix + module_path)
    cls = getattr(module, class_name)

    return cls


def get_model_source_path(identifier: str, prefix: str = "models."):
    """Return the source file of the module named by a "module_path@ClassName" identifier."""
    module_path, class_name = identifier.split('@')

    module = importlib.import_module(prefix + module_path)
    return inspect.getsourcefile(module)
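
The `module_path@ClassName` identifier format can be tried without the repository's `models` package. This is a minimal sketch under stated assumptions: `load_class` is a hypothetical stand-in that repeats the lookup logic of `load_model_class`, and `collections@OrderedDict` with an empty prefix is only a standard-library placeholder, not an identifier used by this project.

```python
import importlib


def load_class(identifier: str, prefix: str = "models."):
    """Hypothetical copy of the lookup logic, for illustration only."""
    # "package.module@ClassName" -> ("package.module", "ClassName")
    module_path, class_name = identifier.split("@")
    module = importlib.import_module(prefix + module_path)
    return getattr(module, class_name)


# An empty prefix lets the sketch resolve a standard-library class
cls = load_class("collections@OrderedDict", prefix="")
print(cls.__name__)  # OrderedDict
```

Because the identifier is unpacked into exactly two names, a string with no `@` or with more than one `@` raises `ValueError` before any import is attempted.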