upload
21
LICENSE
Normal file
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025. Samsung Electronics Co., Ltd. All Rights Reserved.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
147
README.md
Normal file
@@ -0,0 +1,147 @@
# Less is More: Recursive Reasoning with Tiny Networks

This is the codebase for the paper "Less is More: Recursive Reasoning with Tiny Networks", where we present a recursive reasoning approach that reaches 45% on ARC-AGI-1 and 8% on ARC-AGI-2 with a tiny 7M-parameter neural network.

[Paper](https://arxiv.org/abs/2510.04871)

### How TRM works

Tiny Recursion Model (TRM) recursively improves its predicted answer y with a tiny network. It starts from the embedded input question x, an initial embedded answer y, and a latent z. For up to K improvement steps, it tries to improve its answer y. It does so by i) recursively updating its latent z n times given the question x, the current answer y, and the current latent z (recursive reasoning), and then ii) updating its answer y given the current answer y and the current latent z. This recursive process lets the model progressively improve its answer (potentially correcting errors from its previous answer) in an extremely parameter-efficient manner while minimizing overfitting.

<p align="center">
<img src="assets/TRM_fig.png" alt="TRM-Figure" style="width:50%">
</p>
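
In code terms, one refinement pass can be sketched as follows. This is a simplified illustration of the recursion described above, not the actual code in this repository: `latent_step` and `answer_step` are placeholder names for the two roles of the single tiny network, and K and n correspond to the `H_cycles` and `L_cycles` settings used in the configs.

```python
# Minimal sketch of TRM's recursive refinement loop (illustrative pseudocode,
# not the actual pretrain.py / models code).
def trm_refine(x, y, z, latent_step, answer_step, K=3, n=6):
    # x: embedded question, y: current embedded answer, z: latent state
    for _ in range(K):                # up to K improvement steps (H_cycles)
        for _ in range(n):            # n recursive latent updates (L_cycles)
            z = latent_step(x, y, z)  # update the latent given question, answer, latent
        y = answer_step(y, z)         # update the answer given answer and latent
    return y
```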

### Requirements

- Python 3.10 (or similar)
- CUDA 12.6.0 (or similar)

```bash
pip install --upgrade pip wheel setuptools
pip install --pre --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126  # install torch based on your CUDA version
pip install -r requirements.txt  # install requirements
pip install --no-cache-dir --no-build-isolation adam-atan2
wandb login YOUR-LOGIN  # log in if you want the logger to sync results to your Weights & Biases (https://wandb.ai/)
```

### Dataset Preparation

```bash
# ARC-AGI-1
python -m dataset.build_arc_dataset \
  --input-file-prefix kaggle/combined/arc-agi \
  --output-dir data/arc1concept-aug-1000 \
  --subsets training evaluation concept \
  --test-set-name evaluation

# ARC-AGI-2
python -m dataset.build_arc_dataset \
  --input-file-prefix kaggle/combined/arc-agi \
  --output-dir data/arc2concept-aug-1000 \
  --subsets training2 evaluation2 concept \
  --test-set-name evaluation2

# Note: you cannot train on both ARC-AGI-1 and ARC-AGI-2 and evaluate on both, because the ARC-AGI-2 training data contains some ARC-AGI-1 eval data.

# Sudoku-Extreme
python dataset/build_sudoku_dataset.py --output-dir data/sudoku-extreme-1k-aug-1000 --subsample-size 1000 --num-aug 1000  # 1000 examples, 1000 augments

# Maze-Hard
python dataset/build_maze_dataset.py  # 1000 examples, 8 augments
```
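
Each builder writes the processed puzzles as NumPy arrays plus a `dataset.json` metadata file under the chosen `--output-dir` (see `dataset/build_sudoku_dataset.py` and the other builders added in this commit). A quick, optional sanity check, assuming the Sudoku command above was run with the paths shown:

```python
# Optional sanity check of a built dataset (paths assume the Sudoku command above).
import json
import numpy as np

root = "data/sudoku-extreme-1k-aug-1000/train"
inputs = np.load(f"{root}/all__inputs.npy")   # (num_examples, 81) for Sudoku
labels = np.load(f"{root}/all__labels.npy")
with open(f"{root}/dataset.json") as f:
    meta = json.load(f)

print(inputs.shape, labels.shape)
print(meta["seq_len"], meta["vocab_size"])    # 81 and 11 (PAD + digits 0-9)
```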

## Experiments

### ARC-AGI (assuming 4 H100 GPUs):

```bash
run_name="pretrain_att_arc12concept_4"
torchrun --nproc-per-node 4 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 --nnodes=1 pretrain.py \
  arch=trm \
  data_paths="[data/arc12concept-aug-1000]" \
  arch.L_layers=2 \
  arch.H_cycles=3 arch.L_cycles=4 \
  +run_name=${run_name} ema=True
```

*Runtime:* ~3 days
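
During ARC evaluation, the `arc@ARC` evaluator (`evaluators/arc.py` in this commit) writes a `submission.json` with up to two attempts per test input and reports `ARC/pass@k` metrics. A minimal sketch for inspecting that file; the path below is hypothetical and depends on where pretrain.py points the evaluator's save directory:

```python
# Inspect the ARC submission written by evaluators/arc.py (path is illustrative).
import json

with open("checkpoints/pretrain_att_arc12concept_4/submission.json") as f:
    submission = json.load(f)

puzzle_id, attempts = next(iter(submission.items()))
print(puzzle_id, len(attempts))   # one entry per test input of this puzzle
print(list(attempts[0].keys()))   # ['attempt_1', 'attempt_2']
```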

### Sudoku-Extreme (assuming 1 L40S GPU):

```bash
run_name="pretrain_mlp_t_sudoku"
python pretrain.py \
  arch=trm \
  data_paths="[data/sudoku-extreme-1k-aug-1000]" \
  evaluators="[]" \
  epochs=50000 eval_interval=5000 \
  lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0 \
  arch.mlp_t=True arch.pos_encodings=none \
  arch.L_layers=2 \
  arch.H_cycles=3 arch.L_cycles=6 \
  +run_name=${run_name} ema=True

run_name="pretrain_att_sudoku"
python pretrain.py \
  arch=trm \
  data_paths="[data/sudoku-extreme-1k-aug-1000]" \
  evaluators="[]" \
  epochs=50000 eval_interval=5000 \
  lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0 \
  arch.L_layers=2 \
  arch.H_cycles=3 arch.L_cycles=6 \
  +run_name=${run_name} ema=True
```

*Runtime:* < 36 hours

### Maze-Hard (assuming 4 L40S GPUs):

```bash
run_name="pretrain_att_maze30x30"
torchrun --nproc-per-node 4 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 --nnodes=1 pretrain.py \
  arch=trm \
  data_paths="[data/maze-30x30-hard-1k]" \
  evaluators="[]" \
  epochs=50000 eval_interval=5000 \
  lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0 \
  arch.L_layers=2 \
  arch.H_cycles=3 arch.L_cycles=4 \
  +run_name=${run_name} ema=True
```

*Runtime:* < 24 hours

## Reference

If you find our work useful, please consider citing:

```bibtex
@misc{jolicoeurmartineau2025tinyrecursionmodel,
      title={Less is More: Recursive Reasoning with Tiny Networks},
      author={Alexia Jolicoeur-Martineau},
      year={2025},
      eprint={2510.04871},
      archivePrefix={arXiv},
      primaryClass={cs.AI},
      url={https://arxiv.org/abs/2510.04871},
}
```

and the Hierarchical Reasoning Model (HRM):

```bibtex
@misc{wang2025hierarchicalreasoningmodel,
      title={Hierarchical Reasoning Model},
      author={Guan Wang and Jin Li and Yuhao Sun and Xing Chen and Changling Liu and Yue Wu and Meng Lu and Sen Song and Yasin Abbasi Yadkori},
      year={2025},
      eprint={2506.21734},
      archivePrefix={arXiv},
      primaryClass={cs.AI},
      url={https://arxiv.org/abs/2506.21734},
}
```

This code is based on the Hierarchical Reasoning Model [code](https://github.com/sapientinc/HRM) and the Hierarchical Reasoning Model Analysis [code](https://github.com/arcprize/hierarchical-reasoning-model-analysis).
BIN
assets/TRM_fig.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 346 KiB
BIN
assets/TRM_pseudocode.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 261 KiB
24
config/arch/hrm.yaml
Normal file
@@ -0,0 +1,24 @@
name: recursive_reasoning.hrm@HierarchicalReasoningModel_ACTV1
loss:
  name: losses@ACTLossHead
  loss_type: stablemax_cross_entropy

halt_exploration_prob: 0.1
halt_max_steps: 16

H_cycles: 2
L_cycles: 2

H_layers: 4
L_layers: 4

hidden_size: 512
num_heads: 8 # min(2, hidden_size // 64)
expansion: 4

puzzle_emb_ndim: ${.hidden_size}

pos_encodings: rope
forward_dtype: bfloat16

mlp_t: False # use mlp on L instead of transformer
18
config/arch/transformers_baseline.yaml
Normal file
@@ -0,0 +1,18 @@
name: recursive_reasoning.transformers_baseline@Model_ACTV2
loss:
  name: losses@ACTLossHead
  loss_type: stablemax_cross_entropy

halt_exploration_prob: 0.1
halt_max_steps: 16

H_cycles: 1 # kept for compatibility
H_layers: 8

hidden_size: 512
num_heads: 12
expansion: 4

puzzle_emb_ndim: ${.hidden_size}

pos_encodings: rope
26
config/arch/trm.yaml
Normal file
@@ -0,0 +1,26 @@
name: recursive_reasoning.trm@TinyRecursiveReasoningModel_ACTV1
loss:
  name: losses@ACTLossHead
  loss_type: stablemax_cross_entropy

halt_exploration_prob: 0.1
halt_max_steps: 16

H_cycles: 3
L_cycles: 6

H_layers: 0
L_layers: 2

hidden_size: 512
num_heads: 8 # min(2, hidden_size // 64)
expansion: 4

puzzle_emb_ndim: ${.hidden_size}

pos_encodings: rope
forward_dtype: bfloat16

mlp_t: False # use mlp on L instead of transformer
puzzle_emb_len: 16 # if non-zero, the puzzle embedding length is set to this value
no_ACT_continue: True # no "continue" ACT loss; only use the sigmoid of the halt logit, which makes much more sense
26
config/arch/trm_hier6.yaml
Normal file
@@ -0,0 +1,26 @@
name: recursive_reasoning.trm_hier6@TinyRecursiveReasoningModel_ACTV1
loss:
  name: losses@ACTLossHead
  loss_type: stablemax_cross_entropy

halt_exploration_prob: 0.1
halt_max_steps: 16

H_cycles: 3
L_cycles: 6

H_layers: 0
L_layers: 2

hidden_size: 512
num_heads: 8 # min(2, hidden_size // 64)
expansion: 4

puzzle_emb_ndim: ${.hidden_size}

pos_encodings: rope
forward_dtype: bfloat16

mlp_t: False # use mlp on L instead of transformer
puzzle_emb_len: 16 # if non-zero, the puzzle embedding length is set to this value
no_ACT_continue: True # no "continue" ACT loss; only use the sigmoid of the halt logit, which makes much more sense
26
config/arch/trm_singlez.yaml
Normal file
@@ -0,0 +1,26 @@
name: recursive_reasoning.trm_singlez@TinyRecursiveReasoningModel_ACTV1
loss:
  name: losses@ACTLossHead
  loss_type: stablemax_cross_entropy

halt_exploration_prob: 0.1
halt_max_steps: 16

H_cycles: 3
L_cycles: 6

H_layers: 0
L_layers: 2

hidden_size: 512
num_heads: 8 # min(2, hidden_size // 64)
expansion: 4

puzzle_emb_ndim: ${.hidden_size}

pos_encodings: rope
forward_dtype: bfloat16

mlp_t: False # use mlp on L instead of transformer
puzzle_emb_len: 16 # if non-zero, the puzzle embedding length is set to this value
no_ACT_continue: True # no "continue" ACT loss; only use the sigmoid of the halt logit, which makes much more sense
42
config/cfg_pretrain.yaml
Normal file
@@ -0,0 +1,42 @@
# ARC training config

defaults:
  - arch: trm
  - _self_

hydra:
  output_subdir: null

# Data path
data_paths: ['data/arc-aug-1000']
data_paths_test: []

evaluators:
  - name: arc@ARC

# Hyperparams - Training
global_batch_size: 768

epochs: 100000
eval_interval: 10000
checkpoint_every_eval: True

lr: 1e-4
lr_min_ratio: 1.0
lr_warmup_steps: 2000

# Standard hyperparameter settings for LM, as used in Llama
beta1: 0.9
beta2: 0.95
weight_decay: 0.1
puzzle_emb_weight_decay: 0.1

# Hyperparams - Puzzle embeddings training
puzzle_emb_lr: 1e-2

seed: 0
min_eval_interval: 0 # when to start the eval

ema: False # use Exponential-Moving-Average
ema_rate: 0.999 # EMA-rate
freeze_weights: False # If True, freeze weights and only learn the embeddings
341
dataset/build_arc_dataset.py
Normal file
@@ -0,0 +1,341 @@
|
||||
from typing import List, Tuple, Dict
|
||||
from dataclasses import dataclass
|
||||
import os
|
||||
import json
|
||||
import hashlib
|
||||
import numpy as np
|
||||
|
||||
from argdantic import ArgParser
|
||||
from pydantic import BaseModel
|
||||
|
||||
from dataset.common import PuzzleDatasetMetadata, dihedral_transform, inverse_dihedral_transform
|
||||
|
||||
|
||||
cli = ArgParser()
|
||||
|
||||
|
||||
class DataProcessConfig(BaseModel):
|
||||
input_file_prefix: str
|
||||
output_dir: str
|
||||
subsets: List[str]
|
||||
test_set_name: str
|
||||
test_set_name2: str = "your_test_set"
|
||||
seed: int = 42
|
||||
num_aug: int = 1000
|
||||
puzzle_identifiers_start: int = 1 # start > 1 to handle multiple datasets
|
||||
|
||||
ARCMaxGridSize = 30
|
||||
ARCAugmentRetriesFactor = 5
|
||||
|
||||
PuzzleIdSeparator = "|||"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ARCPuzzle:
|
||||
id: str
|
||||
examples: List[Tuple[np.ndarray, np.ndarray]]
|
||||
|
||||
|
||||
def arc_grid_to_np(grid: List[List[int]]):
|
||||
arr = np.array(grid)
|
||||
|
||||
# Shape check
|
||||
assert arr.ndim == 2
|
||||
assert arr.shape[0] <= ARCMaxGridSize and arr.shape[1] <= ARCMaxGridSize
|
||||
# Element check
|
||||
assert np.all((arr >= 0) & (arr <= 9))
|
||||
return arr.astype(np.uint8)
|
||||
|
||||
|
||||
def np_grid_to_seq_translational_augment(inp: np.ndarray, out: np.ndarray, do_translation: bool):
|
||||
# PAD: 0, <eos>: 1, digits: 2 ... 11
|
||||
# Compute random top-left pad
|
||||
if do_translation:
|
||||
pad_r = np.random.randint(0, ARCMaxGridSize - max(inp.shape[0], out.shape[0]) + 1)
|
||||
pad_c = np.random.randint(0, ARCMaxGridSize - max(inp.shape[1], out.shape[1]) + 1)
|
||||
else:
|
||||
pad_r = pad_c = 0
|
||||
|
||||
# Pad grid
|
||||
result = []
|
||||
for grid in [inp, out]:
|
||||
nrow, ncol = grid.shape
|
||||
grid = np.pad(grid + 2, ((pad_r, ARCMaxGridSize - pad_r - nrow), (pad_c, ARCMaxGridSize - pad_c - ncol)), constant_values=0)
|
||||
|
||||
# Add <eos>
|
||||
eos_row, eos_col = pad_r + nrow, pad_c + ncol
|
||||
if eos_row < ARCMaxGridSize:
|
||||
grid[eos_row, pad_c:eos_col] = 1
|
||||
if eos_col < ARCMaxGridSize:
|
||||
grid[pad_r:eos_row, eos_col] = 1
|
||||
|
||||
result.append(grid.flatten())
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def grid_hash(grid: np.ndarray):
|
||||
assert grid.ndim == 2
|
||||
assert grid.dtype == np.uint8
|
||||
|
||||
buffer = [x.to_bytes(1, byteorder='big') for x in grid.shape]
|
||||
buffer.append(grid.tobytes())
|
||||
|
||||
return hashlib.sha256(b"".join(buffer)).hexdigest()
|
||||
|
||||
|
||||
def puzzle_hash(puzzle: dict):
|
||||
# Hash the puzzle for checking equivalence
|
||||
hashes = []
|
||||
for example_type, example in puzzle.items():
|
||||
for input, label in example.examples:
|
||||
hashes.append(f"{grid_hash(input)}|{grid_hash(label)}")
|
||||
|
||||
hashes.sort()
|
||||
return hashlib.sha256("|".join(hashes).encode()).hexdigest()
|
||||
|
||||
|
||||
def aug(name: str):
|
||||
# Augment plan
|
||||
trans_id = np.random.randint(0, 8)
|
||||
mapping = np.concatenate([np.arange(0, 1, dtype=np.uint8), np.random.permutation(np.arange(1, 10, dtype=np.uint8))]) # Permute colors, Excluding "0" (black)
|
||||
|
||||
name_with_aug_repr = f"{name}{PuzzleIdSeparator}t{trans_id}{PuzzleIdSeparator}{''.join(str(x) for x in mapping)}"
|
||||
|
||||
def _map_grid(grid: np.ndarray):
|
||||
return dihedral_transform(mapping[grid], trans_id)
|
||||
|
||||
return name_with_aug_repr, _map_grid
|
||||
|
||||
|
||||
def inverse_aug(name: str):
|
||||
# Inverse the "aug" function
|
||||
if PuzzleIdSeparator not in name:
|
||||
return name, lambda x: x
|
||||
|
||||
trans_id, perm = name.split(PuzzleIdSeparator)[-2:]
|
||||
trans_id = int(trans_id[1:]) # Remove "t" letter
|
||||
inv_perm = np.argsort(list(perm)).astype(np.uint8)
|
||||
|
||||
def _map_grid(grid: np.ndarray):
|
||||
return inv_perm[inverse_dihedral_transform(grid, trans_id)]
|
||||
|
||||
return name.split(PuzzleIdSeparator)[0], _map_grid
|
||||
|
||||
|
||||
def convert_single_arc_puzzle(results: dict, name: str, puzzle: dict, aug_count: int, dest_mapping: Dict[str, Tuple[str, str]]):
|
||||
# Convert
|
||||
dests = set(dest_mapping.values())
|
||||
converted = {dest: ARCPuzzle(name, []) for dest in dests}
|
||||
for example_type, examples in puzzle.items():
|
||||
# Map to target split
|
||||
dest = dest_mapping[example_type]
|
||||
converted[dest].examples.extend([(arc_grid_to_np(example["input"]), arc_grid_to_np(example["output"])) for example in examples])
|
||||
|
||||
group = [converted]
|
||||
|
||||
# Augment
|
||||
if aug_count > 0:
|
||||
hashes = {puzzle_hash(converted)}
|
||||
|
||||
for _trial in range(ARCAugmentRetriesFactor * aug_count):
|
||||
aug_name, _map_grid = aug(name)
|
||||
|
||||
# Check duplicate
|
||||
augmented = {dest: ARCPuzzle(aug_name, [(_map_grid(input), _map_grid(label)) for (input, label) in puzzle.examples]) for dest, puzzle in converted.items()}
|
||||
h = puzzle_hash(augmented)
|
||||
if h not in hashes:
|
||||
hashes.add(h)
|
||||
group.append(augmented)
|
||||
|
||||
if len(group) >= aug_count + 1:
|
||||
break
|
||||
|
||||
if len(group) < aug_count + 1:
|
||||
print (f"[Puzzle {name}] augmentation not full, only {len(group)}")
|
||||
|
||||
# Append
|
||||
for dest in dests:
|
||||
# Convert the examples
|
||||
dest_split, dest_set = dest
|
||||
|
||||
results.setdefault(dest_split, {})
|
||||
results[dest_split].setdefault(dest_set, [])
|
||||
results[dest_split][dest_set].append([converted[dest] for converted in group])
|
||||
|
||||
|
||||
def load_puzzles_arcagi(config: DataProcessConfig):
|
||||
train_examples_dest = ("train", "all")
|
||||
test_examples_map = {
|
||||
config.test_set_name: [(1.0, ("test", "all"))],
|
||||
config.test_set_name2: [(1.0, ("test", "all"))],
|
||||
"_default": [(1.0, ("train", "all"))]
|
||||
}
|
||||
|
||||
test_puzzles = {}
|
||||
results = {}
|
||||
|
||||
total_puzzles = 0
|
||||
for subset_name in config.subsets:
|
||||
# Load all puzzles in this subset
|
||||
with open(f"{config.input_file_prefix}_{subset_name}_challenges.json", "r") as f:
|
||||
puzzles = json.load(f)
|
||||
|
||||
sols_filename = f"{config.input_file_prefix}_{subset_name}_solutions.json"
|
||||
if os.path.isfile(sols_filename):
|
||||
with open(sols_filename, "r") as f:
|
||||
sols = json.load(f)
|
||||
|
||||
for puzzle_id in puzzles.keys():
|
||||
for idx, sol_grid in enumerate(sols[puzzle_id]):
|
||||
puzzles[puzzle_id]["test"][idx]["output"] = sol_grid
|
||||
else:
|
||||
# Fill with dummy
|
||||
print (f"{subset_name} solutions not found, filling with dummy")
|
||||
|
||||
for puzzle_id, puzzle in puzzles.items():
|
||||
for example in puzzle["test"]:
|
||||
example.setdefault("output", [[0]])
|
||||
|
||||
# Shuffle puzzles
|
||||
puzzles = list(puzzles.items())
|
||||
np.random.shuffle(puzzles)
|
||||
|
||||
# Assign by fraction
|
||||
for idx, (name, puzzle) in enumerate(puzzles):
|
||||
fraction = idx / len(puzzles)
|
||||
test_examples_dest = None
|
||||
for f, dest in test_examples_map.get(subset_name, test_examples_map["_default"]):
|
||||
if fraction < f:
|
||||
test_examples_dest = dest
|
||||
break
|
||||
|
||||
assert test_examples_dest is not None
|
||||
|
||||
if test_examples_dest[0] == "test":
|
||||
test_puzzles[name] = puzzle
|
||||
|
||||
convert_single_arc_puzzle(results, name, puzzle, config.num_aug, {"train": train_examples_dest, "test": test_examples_dest})
|
||||
total_puzzles += 1
|
||||
|
||||
print (f"Total puzzles: {total_puzzles}")
|
||||
return results, test_puzzles
|
||||
|
||||
|
||||
def convert_dataset(config: DataProcessConfig):
|
||||
np.random.seed(config.seed)
|
||||
|
||||
# Read dataset
|
||||
data, test_puzzles = load_puzzles_arcagi(config)
|
||||
|
||||
# Map global puzzle identifiers
|
||||
num_identifiers = config.puzzle_identifiers_start # 0 is blank, start at 1
|
||||
identifier_map = {}
|
||||
for split_name, split in data.items():
|
||||
for subset_name, subset in split.items():
|
||||
for group in subset:
|
||||
for puzzle in group:
|
||||
if puzzle.id not in identifier_map:
|
||||
identifier_map[puzzle.id] = num_identifiers
|
||||
num_identifiers += 1
|
||||
print (f"Total puzzle IDs (including <blank>): {num_identifiers}")
|
||||
|
||||
# Save
|
||||
for split_name, split in data.items():
|
||||
os.makedirs(os.path.join(config.output_dir, split_name), exist_ok=True)
|
||||
|
||||
# Translational augmentations
|
||||
enable_translational_augment = split_name == "train"
|
||||
|
||||
# Statistics
|
||||
total_examples = 0
|
||||
total_puzzles = 0
|
||||
total_groups = 0
|
||||
|
||||
for subset_name, subset in split.items(): # "all" is the only subset
|
||||
# Construct subset
|
||||
results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
|
||||
results["puzzle_indices"].append(0)
|
||||
results["group_indices"].append(0)
|
||||
|
||||
example_id = 0
|
||||
puzzle_id = 0
|
||||
|
||||
for group in subset:
|
||||
for puzzle in group:
|
||||
# Push puzzle
|
||||
no_aug_id = np.random.randint(0, len(puzzle.examples))
|
||||
for _idx_ex, (inp, out) in enumerate(puzzle.examples):
|
||||
inp, out = np_grid_to_seq_translational_augment(inp, out, do_translation=enable_translational_augment and _idx_ex != no_aug_id)
|
||||
|
||||
results["inputs"].append(inp)
|
||||
results["labels"].append(out)
|
||||
example_id += 1
|
||||
|
||||
total_examples += 1
|
||||
|
||||
results["puzzle_indices"].append(example_id)
|
||||
results["puzzle_identifiers"].append(identifier_map[puzzle.id])
|
||||
|
||||
puzzle_id += 1
|
||||
total_puzzles += 1
|
||||
|
||||
# Push group
|
||||
results["group_indices"].append(puzzle_id)
|
||||
total_groups += 1
|
||||
|
||||
for k, v in results.items():
|
||||
if k in {"inputs", "labels"}:
|
||||
v = np.stack(v, 0)
|
||||
else:
|
||||
v = np.array(v, dtype=np.int32)
|
||||
|
||||
np.save(os.path.join(config.output_dir, split_name, f"{subset_name}__{k}.npy"), v)
|
||||
|
||||
# Metadata
|
||||
metadata = PuzzleDatasetMetadata(
|
||||
seq_len=ARCMaxGridSize * ARCMaxGridSize,
|
||||
vocab_size=10 + 2, # PAD + EOS + "0" ... "9"
|
||||
pad_id=0,
|
||||
ignore_label_id=0,
|
||||
blank_identifier_id=0,
|
||||
num_puzzle_identifiers=num_identifiers,
|
||||
total_groups=total_groups,
|
||||
mean_puzzle_examples=total_examples / total_puzzles,
|
||||
total_puzzles=total_puzzles,
|
||||
sets=list(split.keys())
|
||||
)
|
||||
|
||||
# Save metadata as JSON.
|
||||
with open(os.path.join(config.output_dir, split_name, "dataset.json"), "w") as f:
|
||||
json.dump(metadata.model_dump(), f)
|
||||
|
||||
# Save IDs mapping
|
||||
with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
|
||||
ids_mapping = {v: k for k, v in identifier_map.items()}
|
||||
json.dump([ids_mapping.get(i, "<blank>") for i in range(num_identifiers)], f)
|
||||
|
||||
# Save Test Puzzles
|
||||
with open(os.path.join(config.output_dir, "test_puzzles.json"), "w") as f:
|
||||
json.dump(test_puzzles, f)
|
||||
|
||||
|
||||
@cli.command(singleton=True)
|
||||
def main(config: DataProcessConfig):
|
||||
convert_dataset(config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
140
dataset/build_maze_dataset.py
Normal file
@@ -0,0 +1,140 @@
|
||||
from typing import Optional
|
||||
import math
|
||||
import os
|
||||
import csv
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
from argdantic import ArgParser
|
||||
from pydantic import BaseModel
|
||||
from tqdm import tqdm
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from common import PuzzleDatasetMetadata, dihedral_transform
|
||||
|
||||
|
||||
CHARSET = "# SGo"
|
||||
|
||||
|
||||
cli = ArgParser()
|
||||
|
||||
|
||||
class DataProcessConfig(BaseModel):
|
||||
source_repo: str = "sapientinc/maze-30x30-hard-1k"
|
||||
output_dir: str = "data/maze-30x30-hard-1k"
|
||||
|
||||
subsample_size: Optional[int] = None
|
||||
aug: bool = False
|
||||
|
||||
|
||||
def convert_subset(set_name: str, config: DataProcessConfig):
|
||||
# Read CSV
|
||||
all_chars = set()
|
||||
grid_size = None
|
||||
inputs = []
|
||||
labels = []
|
||||
|
||||
with open(hf_hub_download(config.source_repo, f"{set_name}.csv", repo_type="dataset"), newline="") as csvfile: # type: ignore
|
||||
reader = csv.reader(csvfile)
|
||||
next(reader) # Skip header
|
||||
for source, q, a, rating in reader:
|
||||
all_chars.update(q)
|
||||
all_chars.update(a)
|
||||
|
||||
if grid_size is None:
|
||||
n = int(len(q) ** 0.5)
|
||||
grid_size = (n, n)
|
||||
|
||||
inputs.append(np.frombuffer(q.encode(), dtype=np.uint8).reshape(grid_size))
|
||||
labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(grid_size))
|
||||
|
||||
# If subsample_size is specified for the training set,
|
||||
# randomly sample the desired number of examples.
|
||||
if set_name == "train" and config.subsample_size is not None:
|
||||
total_samples = len(inputs)
|
||||
if config.subsample_size < total_samples:
|
||||
indices = np.random.choice(total_samples, size=config.subsample_size, replace=False)
|
||||
inputs = [inputs[i] for i in indices]
|
||||
labels = [labels[i] for i in indices]
|
||||
|
||||
# Generate dataset
|
||||
results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
|
||||
puzzle_id = 0
|
||||
example_id = 0
|
||||
|
||||
results["puzzle_indices"].append(0)
|
||||
results["group_indices"].append(0)
|
||||
|
||||
for inp, out in zip(tqdm(inputs), labels):
|
||||
# Dihedral transformations for augmentation
|
||||
for aug_idx in range(8 if (set_name == "train" and config.aug) else 1):
|
||||
results["inputs"].append(dihedral_transform(inp, aug_idx))
|
||||
results["labels"].append(dihedral_transform(out, aug_idx))
|
||||
example_id += 1
|
||||
puzzle_id += 1
|
||||
|
||||
results["puzzle_indices"].append(example_id)
|
||||
results["puzzle_identifiers"].append(0)
|
||||
|
||||
# Push group
|
||||
results["group_indices"].append(puzzle_id)
|
||||
|
||||
# Char mappings
|
||||
assert len(all_chars - set(CHARSET)) == 0
|
||||
|
||||
char2id = np.zeros(256, np.uint8)
|
||||
char2id[np.array(list(map(ord, CHARSET)))] = np.arange(len(CHARSET)) + 1
|
||||
|
||||
# To Numpy
|
||||
def _seq_to_numpy(seq):
|
||||
arr = np.vstack([char2id[s.reshape(-1)] for s in seq])
|
||||
|
||||
return arr
|
||||
|
||||
results = {
|
||||
"inputs": _seq_to_numpy(results["inputs"]),
|
||||
"labels": _seq_to_numpy(results["labels"]),
|
||||
|
||||
"group_indices": np.array(results["group_indices"], dtype=np.int32),
|
||||
"puzzle_indices": np.array(results["puzzle_indices"], dtype=np.int32),
|
||||
"puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32),
|
||||
}
|
||||
|
||||
# Metadata
|
||||
metadata = PuzzleDatasetMetadata(
|
||||
seq_len=int(math.prod(grid_size)), # type: ignore
|
||||
vocab_size=len(CHARSET) + 1, # PAD + Charset
|
||||
pad_id=0,
|
||||
ignore_label_id=0,
|
||||
blank_identifier_id=0,
|
||||
num_puzzle_identifiers=1,
|
||||
total_groups=len(results["group_indices"]) - 1,
|
||||
mean_puzzle_examples=1,
|
||||
total_puzzles=len(results["group_indices"]) - 1,
|
||||
sets=["all"]
|
||||
)
|
||||
|
||||
# Save metadata as JSON.
|
||||
save_dir = os.path.join(config.output_dir, set_name)
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
with open(os.path.join(save_dir, "dataset.json"), "w") as f:
|
||||
json.dump(metadata.model_dump(), f)
|
||||
|
||||
# Save data
|
||||
for k, v in results.items():
|
||||
np.save(os.path.join(save_dir, f"all__{k}.npy"), v)
|
||||
|
||||
# Save IDs mapping (for visualization only)
|
||||
with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
|
||||
json.dump(["<blank>"], f)
|
||||
|
||||
|
||||
@cli.command(singleton=True)
|
||||
def preprocess_data(config: DataProcessConfig):
|
||||
convert_subset("train", config)
|
||||
convert_subset("test", config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
167
dataset/build_sudoku_dataset.py
Normal file
@@ -0,0 +1,167 @@
|
||||
from typing import Optional
|
||||
import os
|
||||
import csv
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
from argdantic import ArgParser
|
||||
from pydantic import BaseModel
|
||||
from tqdm import tqdm
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from common import PuzzleDatasetMetadata
|
||||
|
||||
|
||||
cli = ArgParser()
|
||||
|
||||
|
||||
class DataProcessConfig(BaseModel):
|
||||
source_repo: str = "sapientinc/sudoku-extreme"
|
||||
output_dir: str = "data/sudoku-extreme-full"
|
||||
|
||||
subsample_size: Optional[int] = None
|
||||
min_difficulty: Optional[int] = None
|
||||
num_aug: int = 0
|
||||
|
||||
|
||||
def shuffle_sudoku(board: np.ndarray, solution: np.ndarray):
|
||||
# Create a random digit mapping: a permutation of 1..9, with zero (blank) unchanged
|
||||
digit_map = np.pad(np.random.permutation(np.arange(1, 10)), (1, 0))
|
||||
|
||||
# Randomly decide whether to transpose.
|
||||
transpose_flag = np.random.rand() < 0.5
|
||||
|
||||
# Generate a valid row permutation:
|
||||
# - Shuffle the 3 bands (each band = 3 rows) and for each band, shuffle its 3 rows.
|
||||
bands = np.random.permutation(3)
|
||||
row_perm = np.concatenate([b * 3 + np.random.permutation(3) for b in bands])
|
||||
|
||||
# Similarly for columns (stacks).
|
||||
stacks = np.random.permutation(3)
|
||||
col_perm = np.concatenate([s * 3 + np.random.permutation(3) for s in stacks])
|
||||
|
||||
# Build an 81->81 mapping. For each new cell at (i, j)
|
||||
# (row index = i // 9, col index = i % 9),
|
||||
# its value comes from old row = row_perm[i//9] and old col = col_perm[i%9].
|
||||
mapping = np.array([row_perm[i // 9] * 9 + col_perm[i % 9] for i in range(81)])
|
||||
|
||||
def apply_transformation(x: np.ndarray) -> np.ndarray:
|
||||
# Apply transpose flag
|
||||
if transpose_flag:
|
||||
x = x.T
|
||||
# Apply the position mapping.
|
||||
new_board = x.flatten()[mapping].reshape(9, 9).copy()
|
||||
# Apply digit mapping
|
||||
return digit_map[new_board]
|
||||
|
||||
return apply_transformation(board), apply_transformation(solution)
|
||||
|
||||
|
||||
def convert_subset(set_name: str, config: DataProcessConfig):
|
||||
# Read CSV
|
||||
inputs = []
|
||||
labels = []
|
||||
|
||||
with open(hf_hub_download(config.source_repo, f"{set_name}.csv", repo_type="dataset"), newline="") as csvfile:
|
||||
reader = csv.reader(csvfile)
|
||||
next(reader) # Skip header
|
||||
for source, q, a, rating in reader:
|
||||
if (config.min_difficulty is None) or (int(rating) >= config.min_difficulty):
|
||||
assert len(q) == 81 and len(a) == 81
|
||||
|
||||
inputs.append(np.frombuffer(q.replace('.', '0').encode(), dtype=np.uint8).reshape(9, 9) - ord('0'))
|
||||
labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(9, 9) - ord('0'))
|
||||
|
||||
# If subsample_size is specified for the training set,
|
||||
# randomly sample the desired number of examples.
|
||||
if set_name == "train" and config.subsample_size is not None:
|
||||
total_samples = len(inputs)
|
||||
if config.subsample_size < total_samples:
|
||||
indices = np.random.choice(total_samples, size=config.subsample_size, replace=False)
|
||||
inputs = [inputs[i] for i in indices]
|
||||
labels = [labels[i] for i in indices]
|
||||
|
||||
# Generate dataset
|
||||
num_augments = config.num_aug if set_name == "train" else 0
|
||||
|
||||
results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
|
||||
puzzle_id = 0
|
||||
example_id = 0
|
||||
|
||||
results["puzzle_indices"].append(0)
|
||||
results["group_indices"].append(0)
|
||||
|
||||
for orig_inp, orig_out in zip(tqdm(inputs), labels):
|
||||
for aug_idx in range(1 + num_augments):
|
||||
# First index is not augmented
|
||||
if aug_idx == 0:
|
||||
inp, out = orig_inp, orig_out
|
||||
else:
|
||||
inp, out = shuffle_sudoku(orig_inp, orig_out)
|
||||
|
||||
# Push puzzle (only single example)
|
||||
results["inputs"].append(inp)
|
||||
results["labels"].append(out)
|
||||
example_id += 1
|
||||
puzzle_id += 1
|
||||
|
||||
results["puzzle_indices"].append(example_id)
|
||||
results["puzzle_identifiers"].append(0)
|
||||
|
||||
# Push group
|
||||
results["group_indices"].append(puzzle_id)
|
||||
|
||||
# To Numpy
|
||||
def _seq_to_numpy(seq):
|
||||
arr = np.concatenate(seq).reshape(len(seq), -1)
|
||||
|
||||
assert np.all((arr >= 0) & (arr <= 9))
|
||||
return arr + 1
|
||||
|
||||
results = {
|
||||
"inputs": _seq_to_numpy(results["inputs"]),
|
||||
"labels": _seq_to_numpy(results["labels"]),
|
||||
|
||||
"group_indices": np.array(results["group_indices"], dtype=np.int32),
|
||||
"puzzle_indices": np.array(results["puzzle_indices"], dtype=np.int32),
|
||||
"puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32),
|
||||
}
|
||||
|
||||
# Metadata
|
||||
metadata = PuzzleDatasetMetadata(
|
||||
seq_len=81,
|
||||
vocab_size=10 + 1, # PAD + "0" ... "9"
|
||||
pad_id=0,
|
||||
ignore_label_id=0,
|
||||
blank_identifier_id=0,
|
||||
num_puzzle_identifiers=1,
|
||||
total_groups=len(results["group_indices"]) - 1,
|
||||
mean_puzzle_examples=1,
|
||||
total_puzzles=len(results["group_indices"]) - 1,
|
||||
sets=["all"]
|
||||
)
|
||||
|
||||
# Save metadata as JSON.
|
||||
save_dir = os.path.join(config.output_dir, set_name)
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
with open(os.path.join(save_dir, "dataset.json"), "w") as f:
|
||||
json.dump(metadata.model_dump(), f)
|
||||
|
||||
# Save data
|
||||
for k, v in results.items():
|
||||
np.save(os.path.join(save_dir, f"all__{k}.npy"), v)
|
||||
|
||||
# Save IDs mapping (for visualization only)
|
||||
with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
|
||||
json.dump(["<blank>"], f)
|
||||
|
||||
|
||||
@cli.command(singleton=True)
|
||||
def preprocess_data(config: DataProcessConfig):
|
||||
convert_subset("train", config)
|
||||
convert_subset("test", config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
49
dataset/common.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from typing import List, Optional
|
||||
|
||||
import pydantic
|
||||
import numpy as np
|
||||
|
||||
|
||||
# Global list mapping each dihedral transform id to its inverse.
|
||||
# Index corresponds to the original tid, and the value is its inverse.
|
||||
DIHEDRAL_INVERSE = [0, 3, 2, 1, 4, 5, 6, 7]
|
||||
|
||||
|
||||
class PuzzleDatasetMetadata(pydantic.BaseModel):
|
||||
pad_id: int
|
||||
ignore_label_id: Optional[int]
|
||||
blank_identifier_id: int
|
||||
vocab_size: int
|
||||
seq_len: int
|
||||
num_puzzle_identifiers: int
|
||||
total_groups: int
|
||||
mean_puzzle_examples: float
|
||||
total_puzzles: int
|
||||
sets: List[str]
|
||||
|
||||
|
||||
def dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:
|
||||
"""8 dihedral symmetries by rotate, flip and mirror"""
|
||||
|
||||
if tid == 0:
|
||||
return arr # identity
|
||||
elif tid == 1:
|
||||
return np.rot90(arr, k=1)
|
||||
elif tid == 2:
|
||||
return np.rot90(arr, k=2)
|
||||
elif tid == 3:
|
||||
return np.rot90(arr, k=3)
|
||||
elif tid == 4:
|
||||
return np.fliplr(arr) # horizontal flip
|
||||
elif tid == 5:
|
||||
return np.flipud(arr) # vertical flip
|
||||
elif tid == 6:
|
||||
return arr.T # transpose (reflection along main diagonal)
|
||||
elif tid == 7:
|
||||
return np.fliplr(np.rot90(arr, k=1)) # anti-diagonal reflection
|
||||
else:
|
||||
return arr
|
||||
|
||||
|
||||
def inverse_dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:
|
||||
return dihedral_transform(arr, DIHEDRAL_INVERSE[tid])
|
||||
177
evaluators/arc.py
Normal file
@@ -0,0 +1,177 @@
|
||||
from typing import Dict, Sequence, Optional
|
||||
import os
|
||||
import json
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from numba import njit
|
||||
import torch.distributed as dist
|
||||
|
||||
from dataset.build_arc_dataset import inverse_aug, grid_hash, arc_grid_to_np
|
||||
from dataset.common import PuzzleDatasetMetadata
|
||||
|
||||
@njit
|
||||
def _crop(grid: np.ndarray):
|
||||
"""Find maximum-sized rectangle without any EOS token inside. """
|
||||
grid = grid.reshape(30, 30)
|
||||
|
||||
max_area = 0
|
||||
max_size = (0, 0)
|
||||
nr, nc = grid.shape
|
||||
|
||||
num_c = nc
|
||||
for num_r in range(1, nr + 1):
|
||||
# Scan for maximum c
|
||||
for c in range(1, num_c + 1):
|
||||
x = grid[num_r - 1, c - 1]
|
||||
if (x < 2) | (x > 11):
|
||||
num_c = c - 1
|
||||
break
|
||||
|
||||
area = num_r * num_c
|
||||
if area > max_area:
|
||||
max_area = area
|
||||
max_size = (num_r, num_c)
|
||||
|
||||
return (grid[:max_size[0], :max_size[1]] - 2).astype(np.uint8)
|
||||
|
||||
|
||||
class ARC:
|
||||
required_outputs = {"inputs", "puzzle_identifiers", "q_halt_logits", "preds"}
|
||||
|
||||
def __init__(self, data_path: str,
|
||||
eval_metadata: PuzzleDatasetMetadata,
|
||||
submission_K: int = 2,
|
||||
pass_Ks: Sequence[int] = (1, 2, 5, 10, 100, 1000),
|
||||
aggregated_voting: bool = True):
|
||||
super().__init__()
|
||||
self.pass_Ks = pass_Ks
|
||||
self.submission_K = submission_K
|
||||
self.aggregated_voting = aggregated_voting
|
||||
self.blank_identifier_id = eval_metadata.blank_identifier_id
|
||||
|
||||
# Load identifiers and test puzzles
|
||||
with open(os.path.join(data_path, "identifiers.json"), "r") as f:
|
||||
self.identifier_map = json.load(f)
|
||||
with open(os.path.join(data_path, "test_puzzles.json"), "r") as f:
|
||||
self.test_puzzles = json.load(f)
|
||||
|
||||
# States
|
||||
self._local_hmap = {}
|
||||
self._local_preds = {}
|
||||
|
||||
def begin_eval(self):
|
||||
if not self.aggregated_voting:
|
||||
# Clear previous predictions
|
||||
self._local_hmap = {}
|
||||
self._local_preds = {}
|
||||
|
||||
def update_batch(self, batch: Dict[str, torch.Tensor], preds: Dict[str, torch.Tensor]):
|
||||
# Collect required outputs to CPU
|
||||
outputs = {}
|
||||
q_values = None
|
||||
|
||||
for collection in (batch, preds):
|
||||
for k, v in collection.items():
|
||||
if k in self.required_outputs:
|
||||
if k == "q_halt_logits":
|
||||
q_values = v.to(torch.float64).sigmoid().cpu()
|
||||
else:
|
||||
outputs[k] = v.cpu()
|
||||
|
||||
assert q_values is not None
|
||||
|
||||
# Remove padding from outputs
|
||||
mask = outputs["puzzle_identifiers"] != self.blank_identifier_id
|
||||
outputs = {k: v[mask] for k, v in outputs.items()}
|
||||
|
||||
# Get predictions
|
||||
for identifier, input, pred, q in zip(outputs["puzzle_identifiers"].numpy(), outputs["inputs"].numpy(), outputs["preds"].numpy(), q_values.numpy()):
|
||||
name = self.identifier_map[identifier]
|
||||
orig_name, _inverse_fn = inverse_aug(name)
|
||||
|
||||
input_hash = grid_hash(_inverse_fn(_crop(input)))
|
||||
|
||||
pred = _inverse_fn(_crop(pred))
|
||||
assert np.all((pred >= 0) & (pred <= 9)), f"Puzzle {name}'s prediction out of 0-9 range." # Sanity check
|
||||
|
||||
# Store into local state
|
||||
pred_hash = grid_hash(pred)
|
||||
|
||||
self._local_hmap[pred_hash] = pred
|
||||
|
||||
self._local_preds.setdefault(orig_name, {})
|
||||
self._local_preds[orig_name].setdefault(input_hash, [])
|
||||
self._local_preds[orig_name][input_hash].append((pred_hash, float(q)))
|
||||
|
||||
def result(self, save_path: Optional[str], rank: int, world_size: int, group: Optional[torch.distributed.ProcessGroup] = None) -> Optional[Dict[str, float]]:
|
||||
# Gather predictions to rank 0 for voting
|
||||
global_hmap_preds = [None for _ in range(world_size)] if rank == 0 else None
|
||||
dist.gather_object((self._local_hmap, self._local_preds), global_hmap_preds, dst=0, group=group)
|
||||
|
||||
# Rank 0 logic
|
||||
if rank != 0:
|
||||
return
|
||||
|
||||
submission = {}
|
||||
correct = [0.0 for _ in range(len(self.pass_Ks))]
|
||||
|
||||
for name, puzzle in self.test_puzzles.items():
|
||||
# Process test examples in this puzzle
|
||||
submission[name] = []
|
||||
num_test_correct = [0 for _ in range(len(self.pass_Ks))]
|
||||
for pair in puzzle["test"]:
|
||||
input_hash = grid_hash(arc_grid_to_np(pair["input"]))
|
||||
label_hash = grid_hash(arc_grid_to_np(pair["output"]))
|
||||
|
||||
p_map = {}
|
||||
for hmap, preds in global_hmap_preds: # type: ignore
|
||||
for h, q in preds.get(name, {}).get(input_hash, {}):
|
||||
p_map.setdefault(h, [0, 0])
|
||||
p_map[h][0] += 1
|
||||
p_map[h][1] += q
|
||||
|
||||
if not len(p_map):
|
||||
print (f"Puzzle {name} has no predictions.")
|
||||
continue
|
||||
|
||||
for h, stats in p_map.items():
|
||||
stats[1] /= stats[0]
|
||||
|
||||
p_map = sorted(p_map.items(), key=lambda kv: kv[1], reverse=True)
|
||||
|
||||
# vote for different Ks
|
||||
for i, k in enumerate(self.pass_Ks):
|
||||
ok = False
|
||||
for h, stats in p_map[:k]:
|
||||
ok |= h == label_hash
|
||||
|
||||
num_test_correct[i] += ok
|
||||
|
||||
# Query grids
|
||||
pred_grids = []
|
||||
for h, stats in p_map[:self.submission_K]:
|
||||
for hmap, preds in global_hmap_preds: # type: ignore
|
||||
if h in hmap:
|
||||
pred_grids.append(hmap[h])
|
||||
break
|
||||
|
||||
# Pad to K
|
||||
while len(pred_grids) < self.submission_K:
|
||||
pred_grids.append(pred_grids[0])
|
||||
|
||||
submission[name].append({f"attempt_{i + 1}": grid.tolist() for i, grid in enumerate(pred_grids)})
|
||||
|
||||
# Total correctness
|
||||
for i in range(len(self.pass_Ks)):
|
||||
correct[i] += num_test_correct[i] / len(puzzle["test"])
|
||||
|
||||
# Save submission
|
||||
if save_path is not None:
|
||||
with open(os.path.join(save_path, "submission.json"), "w") as f:
|
||||
json.dump(submission, f)
|
||||
|
||||
# Final result
|
||||
all_results = {f"ARC/pass@{k}": correct[i] / len(self.test_puzzles) for i, k in enumerate(self.pass_Ks)}
|
||||
|
||||
return all_results
|
||||
17863
kaggle/combined/arc-agi_concept_challenges.json
Normal file
File diff suppressed because it is too large
5384
kaggle/combined/arc-agi_concept_solutions.json
Normal file
File diff suppressed because it is too large
1
kaggle/combined/arc-agi_evaluation2_challenges.json
Normal file
File diff suppressed because one or more lines are too long
1
kaggle/combined/arc-agi_evaluation2_solutions.json
Normal file
File diff suppressed because one or more lines are too long
1
kaggle/combined/arc-agi_evaluation_challenges.json
Normal file
File diff suppressed because one or more lines are too long
1
kaggle/combined/arc-agi_evaluation_solutions.json
Normal file
File diff suppressed because one or more lines are too long
1
kaggle/combined/arc-agi_training2_challenges.json
Normal file
File diff suppressed because one or more lines are too long
1
kaggle/combined/arc-agi_training2_solutions.json
Normal file
File diff suppressed because one or more lines are too long
1
kaggle/combined/arc-agi_training_challenges.json
Normal file
File diff suppressed because one or more lines are too long
1
kaggle/combined/arc-agi_training_solutions.json
Normal file
File diff suppressed because one or more lines are too long
32
models/common.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
def trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, lower: float = -2.0, upper: float = 2.0):
|
||||
# NOTE: PyTorch nn.init.trunc_normal_ is not mathematically correct: the requested std is not the actual std of the initialized tensor
|
||||
# This function is a PyTorch version of jax truncated normal init (default init method in flax)
|
||||
# https://github.com/jax-ml/jax/blob/main/jax/_src/random.py#L807-L848
|
||||
# https://github.com/jax-ml/jax/blob/main/jax/_src/nn/initializers.py#L162-L199
|
||||
|
||||
with torch.no_grad():
|
||||
if std == 0:
|
||||
tensor.zero_()
|
||||
else:
|
||||
sqrt2 = math.sqrt(2)
|
||||
a = math.erf(lower / sqrt2)
|
||||
b = math.erf(upper / sqrt2)
|
||||
z = (b - a) / 2
|
||||
|
||||
c = (2 * math.pi) ** -0.5
|
||||
pdf_u = c * math.exp(-0.5 * lower ** 2)
|
||||
pdf_l = c * math.exp(-0.5 * upper ** 2)
|
||||
comp_std = std / math.sqrt(1 - (upper * pdf_u - lower * pdf_l) / z - ((pdf_u - pdf_l) / z) ** 2)
|
||||
|
||||
tensor.uniform_(a, b)
|
||||
tensor.erfinv_()
|
||||
tensor.mul_(sqrt2 * comp_std)
|
||||
tensor.clip_(lower * comp_std, upper * comp_std)
|
||||
|
||||
return tensor
|
||||
40
models/ema.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import copy
|
||||
import torch.nn as nn
|
||||
|
||||
class EMAHelper(object):
|
||||
def __init__(self, mu=0.999):
|
||||
self.mu = mu
|
||||
self.shadow = {}
|
||||
|
||||
def register(self, module):
|
||||
if isinstance(module, nn.DataParallel):
|
||||
module = module.module
|
||||
for name, param in module.named_parameters():
|
||||
if param.requires_grad:
|
||||
self.shadow[name] = param.data.clone()
|
||||
|
||||
def update(self, module):
|
||||
if isinstance(module, nn.DataParallel):
|
||||
module = module.module
|
||||
for name, param in module.named_parameters():
|
||||
if param.requires_grad:
|
||||
self.shadow[name].data = (1. - self.mu) * param.data + self.mu * self.shadow[name].data
|
||||
|
||||
def ema(self, module):
|
||||
if isinstance(module, nn.DataParallel):
|
||||
module = module.module
|
||||
for name, param in module.named_parameters():
|
||||
if param.requires_grad:
|
||||
param.data.copy_(self.shadow[name].data)
|
||||
|
||||
def ema_copy(self, module):
|
||||
module_copy = copy.deepcopy(module)
|
||||
self.ema(module_copy)
|
||||
return module_copy
|
||||
|
||||
def state_dict(self):
|
||||
return self.shadow
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.shadow = state_dict
|
||||
|
||||
169
models/layers.py
Normal file
@@ -0,0 +1,169 @@
|
||||
from typing import Tuple
|
||||
import einops
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
#try:
|
||||
# from flash_attn_interface import flash_attn_func # type: ignore[import]
|
||||
#except ImportError:
|
||||
# # Fallback to FlashAttention 2
|
||||
# from flash_attn import flash_attn_func # type: ignore[import]
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
|
||||
from models.common import trunc_normal_init_
|
||||
|
||||
|
||||
CosSin = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
|
||||
def _find_multiple(a, b):
|
||||
return (-(a // -b)) * b
|
||||
|
||||
|
||||
def rotate_half(x: torch.Tensor):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
|
||||
# q, k: [bs, seq_len, num_heads, head_dim]
|
||||
# cos, sin: [seq_len, head_dim]
|
||||
orig_dtype = q.dtype
|
||||
q = q.to(cos.dtype)
|
||||
k = k.to(cos.dtype)
|
||||
|
||||
q_embed = (q * cos.unsqueeze(-2)) + (rotate_half(q) * sin.unsqueeze(-2))
|
||||
k_embed = (k * cos.unsqueeze(-2)) + (rotate_half(k) * sin.unsqueeze(-2))
|
||||
|
||||
return q_embed.to(orig_dtype), k_embed.to(orig_dtype)
|
||||
|
||||
|
||||
class CastedLinear(nn.Module):
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool):
|
||||
super().__init__()
|
||||
# Truncated LeCun normal init
|
||||
self.weight = nn.Parameter(
|
||||
trunc_normal_init_(torch.empty((out_features, in_features)), std=1.0 / (in_features ** 0.5))
|
||||
)
|
||||
self.bias = None
|
||||
if bias:
|
||||
# Zero init bias
|
||||
self.bias = nn.Parameter(torch.zeros((out_features, )))
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
return F.linear(input, self.weight.to(input.dtype), bias=self.bias.to(input.dtype) if self.bias is not None else None)
|
||||
|
||||
|
||||
class CastedEmbedding(nn.Module):
|
||||
def __init__(self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
init_std: float,
|
||||
cast_to: torch.dtype):
|
||||
super().__init__()
|
||||
self.cast_to = cast_to
|
||||
|
||||
# Truncated LeCun normal init
|
||||
self.embedding_weight = nn.Parameter(
|
||||
trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std)
|
||||
)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
return F.embedding(input, self.embedding_weight.to(self.cast_to))
|
||||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_position_embeddings, base, device=None):
|
||||
super().__init__()
|
||||
|
||||
# RoPE
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
|
||||
t = torch.arange(max_position_embeddings, dtype=torch.float32, device=device)
|
||||
freqs = torch.outer(t, inv_freq)
|
||||
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.cos_cached = nn.Buffer(emb.cos(), persistent=False)
|
||||
self.sin_cached = nn.Buffer(emb.sin(), persistent=False)
|
||||
|
||||
def forward(self):
|
||||
return self.cos_cached, self.sin_cached
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, hidden_size, head_dim, num_heads, num_key_value_heads, causal=False):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.head_dim = head_dim
|
||||
self.output_size = head_dim * num_heads
|
||||
self.num_heads = num_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.causal = causal
|
||||
|
||||
self.qkv_proj = CastedLinear(self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False)
|
||||
self.o_proj = CastedLinear(self.output_size, self.hidden_size, bias=False)
|
||||
|
||||
def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, seq_len, _ = hidden_states.shape
|
||||
|
||||
# hidden_states: [bs, seq_len, num_heads, head_dim]
|
||||
qkv = self.qkv_proj(hidden_states)
|
||||
|
||||
# Split head
|
||||
qkv = qkv.view(batch_size, seq_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
|
||||
query = qkv[:, :, :self.num_heads]
|
||||
key = qkv[:, :, self.num_heads: self.num_heads + self.num_key_value_heads]
|
||||
value = qkv[:, :, self.num_heads + self.num_key_value_heads:]
|
||||
|
||||
# RoPE
|
||||
if cos_sin is not None:
|
||||
cos, sin = cos_sin
|
||||
query, key = apply_rotary_pos_emb(query, key, cos, sin)
|
||||
|
||||
# flash attn
|
||||
query, key, value = map(lambda t: einops.rearrange(t, 'B S H D -> B H S D'), (query, key, value)) # needed for scaled_dot_product_attention but not flash_attn_func
|
||||
attn_output = scaled_dot_product_attention(query=query, key=key, value=value, is_causal=self.causal)
|
||||
attn_output = einops.rearrange(attn_output, 'B H S D -> B S H D')
|
||||
attn_output = attn_output.view(batch_size, seq_len, self.output_size) # type: ignore
|
||||
return self.o_proj(attn_output)
|
||||
|
||||
class LinearSwish(nn.Module):
|
||||
def __init__(self, hidden_size: int, reverse=False):
|
||||
super().__init__()
|
||||
|
||||
self.linear = CastedLinear(hidden_size, hidden_size, bias=False)
|
||||
self.reverse = reverse
|
||||
|
||||
def forward(self, x):
|
||||
if self.reverse:
|
||||
return F.silu(self.linear(x))
|
||||
else:
|
||||
return self.linear(F.silu(x))
|
||||
|
||||
|
||||
class SwiGLU(nn.Module):
|
||||
def __init__(self, hidden_size: int, expansion: float):
|
||||
super().__init__()
|
||||
inter = _find_multiple(round(expansion * hidden_size * 2 / 3), 256)
|
||||
|
||||
self.gate_up_proj = CastedLinear(hidden_size, inter * 2, bias=False)
|
||||
self.down_proj = CastedLinear(inter, hidden_size, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
|
||||
return self.down_proj(F.silu(gate) * up)
|
||||
|
||||
def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float) -> torch.Tensor:
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
|
||||
variance = hidden_states.square().mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
|
||||
return hidden_states.to(input_dtype)
|
||||
103
models/losses.py
Normal file
@@ -0,0 +1,103 @@
from typing import Any, Tuple, Dict, Sequence, Optional

import torch
import torch.nn.functional as F
from torch import nn
import math

IGNORE_LABEL_ID = -100


def s(x, epsilon=1e-30):
    return torch.where(
        x < 0,
        1 / (1 - x + epsilon),
        x + 1
    )


def log_stablemax(x, dim=-1):
    s_x = s(x)
    return torch.log(s_x / torch.sum(s_x, dim=dim, keepdim=True))


def stablemax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None):
    logprobs = log_stablemax(logits.to(torch.float64), dim=-1)

    if valid_mask is None:
        valid_mask = (labels != ignore_index)
    transformed_labels = torch.where(valid_mask, labels, 0)
    prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)

    return -torch.where(valid_mask, prediction_logprobs, 0)


def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
    # Cast logits to f32 and flatten before computing the per-token cross-entropy
    return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape)


class ACTLossHead(nn.Module):
    def __init__(self, model: nn.Module, loss_type: str):
        super().__init__()
        self.model = model
        self.loss_fn = globals()[loss_type]

    def initial_carry(self, *args, **kwargs):
        return self.model.initial_carry(*args, **kwargs)  # type: ignore

    def forward(
        self,
        return_keys: Sequence[str],
        # Model args
        **model_kwargs,
    ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
        # Model logits: B x SeqLen x D
        new_carry, outputs = self.model(**model_kwargs)
        labels = new_carry.current_data["labels"]

        with torch.no_grad():
            # Preds
            outputs["preds"] = torch.argmax(outputs["logits"], dim=-1)

            # Correctness
            mask = (labels != IGNORE_LABEL_ID)
            loss_counts = mask.sum(-1)
            loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1)  # Avoid NaNs in division

            is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
            seq_is_correct = is_correct.sum(-1) == loss_counts

            # Metrics (halted)
            valid_metrics = new_carry.halted & (loss_counts > 0)
            metrics = {
                "count": valid_metrics.sum(),

                "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
                "exact_accuracy": (valid_metrics & seq_is_correct).sum(),

                "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
                "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
            }

        # Losses
        lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID, valid_mask=mask) / loss_divisor).sum()
        q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")
        metrics.update({
            "lm_loss": lm_loss.detach(),
            "q_halt_loss": q_halt_loss.detach(),
        })

        # Q continue (bootstrapping target loss); Alexia: This fits Q-learning, but seems totally unnecessary
        q_continue_loss = 0
        if "target_q_continue" in outputs:
            q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")
            metrics["q_continue_loss"] = q_continue_loss.detach()

        # Filter outputs for return
        detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}

        return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
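As a quick sanity check of the two loss variants above, a minimal sketch (illustration only; it assumes the repository root is on `PYTHONPATH` so this file imports as `models.losses`):

```python
# Sketch only: both losses return one value per token; ignored positions contribute zero.
import torch

from models.losses import IGNORE_LABEL_ID, softmax_cross_entropy, stablemax_cross_entropy

logits = torch.randn(2, 4, 10)                 # (batch, seq_len, vocab)
labels = torch.randint(0, 10, (2, 4))
labels[0, 0] = IGNORE_LABEL_ID                 # this position is masked out of the loss

print(softmax_cross_entropy(logits, labels, ignore_index=IGNORE_LABEL_ID).shape)    # torch.Size([2, 4])
print(stablemax_cross_entropy(logits, labels, ignore_index=IGNORE_LABEL_ID).shape)  # torch.Size([2, 4])
# ACTLossHead then divides each sequence's loss by its number of valid tokens and sums.
```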
models/recursive_reasoning/hrm.py (new file, 294 lines)
@@ -0,0 +1,294 @@
|
||||
from typing import Tuple, List, Dict, Optional
|
||||
from dataclasses import dataclass
|
||||
import math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from pydantic import BaseModel
|
||||
|
||||
from models.common import trunc_normal_init_
|
||||
from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
|
||||
from models.sparse_embedding import CastedSparseEmbedding
|
||||
|
||||
@dataclass
|
||||
class HierarchicalReasoningModel_ACTV1InnerCarry:
|
||||
z_H: torch.Tensor
|
||||
z_L: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class HierarchicalReasoningModel_ACTV1Carry:
|
||||
inner_carry: HierarchicalReasoningModel_ACTV1InnerCarry
|
||||
|
||||
steps: torch.Tensor
|
||||
halted: torch.Tensor
|
||||
|
||||
current_data: Dict[str, torch.Tensor]
|
||||
|
||||
|
||||
class HierarchicalReasoningModel_ACTV1Config(BaseModel):
|
||||
batch_size: int
|
||||
seq_len: int
|
||||
puzzle_emb_ndim: int = 0
|
||||
num_puzzle_identifiers: int
|
||||
vocab_size: int
|
||||
|
||||
H_cycles: int
|
||||
L_cycles: int
|
||||
|
||||
H_layers: int
|
||||
L_layers: int
|
||||
|
||||
# Transformer config
|
||||
hidden_size: int
|
||||
expansion: float
|
||||
num_heads: int
|
||||
pos_encodings: str
|
||||
|
||||
rms_norm_eps: float = 1e-5
|
||||
rope_theta: float = 10000.0
|
||||
|
||||
# Halting Q-learning config
|
||||
halt_max_steps: int
|
||||
halt_exploration_prob: float
|
||||
|
||||
forward_dtype: str = "bfloat16"
|
||||
|
||||
# Alexia: added
|
||||
mlp_t: bool = False # use an MLP over the sequence (L) dimension instead of self-attention
|
||||
|
||||
class HierarchicalReasoningModel_ACTV1Block(nn.Module):
|
||||
def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
if self.config.mlp_t:
|
||||
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)
|
||||
self.mlp_t = SwiGLU(
|
||||
hidden_size=self.config.seq_len + self.puzzle_emb_len, # L
|
||||
expansion=config.expansion,
|
||||
)
|
||||
else:
|
||||
self.self_attn = Attention(
|
||||
hidden_size=config.hidden_size,
|
||||
head_dim=config.hidden_size // config.num_heads,
|
||||
num_heads=config.num_heads,
|
||||
num_key_value_heads=config.num_heads,
|
||||
causal=False
|
||||
)
|
||||
self.mlp = SwiGLU(
|
||||
hidden_size=config.hidden_size,
|
||||
expansion=config.expansion,
|
||||
)
|
||||
self.norm_eps = config.rms_norm_eps
|
||||
|
||||
def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
# B, L, D = hidden_states.shape
|
||||
# Post Norm
|
||||
if self.config.mlp_t:
|
||||
hidden_states = hidden_states.transpose(1,2)
|
||||
out = self.mlp_t(hidden_states)
|
||||
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
|
||||
hidden_states = hidden_states.transpose(1,2)
|
||||
else:
|
||||
# Self Attention
|
||||
hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
|
||||
# Fully Connected
|
||||
out = self.mlp(hidden_states)
|
||||
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
|
||||
return hidden_states
|
||||
|
||||
class HierarchicalReasoningModel_ACTV1ReasoningModule(nn.Module):
|
||||
def __init__(self, layers: List[HierarchicalReasoningModel_ACTV1Block]):
|
||||
super().__init__()
|
||||
|
||||
self.layers = torch.nn.ModuleList(layers)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
# Input injection (add)
|
||||
hidden_states = hidden_states + input_injection
|
||||
# Layers
|
||||
for layer in self.layers:
|
||||
hidden_states = layer(hidden_states=hidden_states, **kwargs)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HierarchicalReasoningModel_ACTV1_Inner(nn.Module):
|
||||
def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.forward_dtype = getattr(torch, self.config.forward_dtype)
|
||||
|
||||
# I/O
|
||||
self.embed_scale = math.sqrt(self.config.hidden_size)
|
||||
embed_init_std = 1.0 / self.embed_scale
|
||||
|
||||
self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
|
||||
self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
|
||||
self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
|
||||
|
||||
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div
|
||||
if self.config.puzzle_emb_ndim > 0:
|
||||
# Zero init puzzle embeddings
|
||||
self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
|
||||
batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
|
||||
|
||||
# LM Blocks
|
||||
if self.config.pos_encodings == "rope":
|
||||
self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
|
||||
max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
|
||||
base=self.config.rope_theta)
|
||||
elif self.config.pos_encodings == "learned":
|
||||
self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
|
||||
else:
|
||||
pass
|
||||
|
||||
# Reasoning Layers
|
||||
self.H_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.H_layers)])
|
||||
self.L_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])
|
||||
|
||||
# Initial states
|
||||
self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
|
||||
self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
|
||||
|
||||
# Q head special init
|
||||
# Init Q to (almost) zero for faster learning during bootstrapping
|
||||
with torch.no_grad():
|
||||
self.q_head.weight.zero_()
|
||||
self.q_head.bias.fill_(-5) # type: ignore
|
||||
|
||||
def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
|
||||
# Token embedding
|
||||
embedding = self.embed_tokens(input.to(torch.int32))
|
||||
|
||||
# Puzzle embeddings
|
||||
if self.config.puzzle_emb_ndim > 0:
|
||||
puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
|
||||
|
||||
pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
|
||||
if pad_count > 0:
|
||||
puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
|
||||
|
||||
embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
|
||||
|
||||
# Position embeddings
|
||||
if self.config.pos_encodings == "learned":
|
||||
# scale by 1/sqrt(2) to maintain forward variance
|
||||
embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
|
||||
|
||||
# Scale
|
||||
return self.embed_scale * embedding
|
||||
|
||||
def empty_carry(self, batch_size: int):
|
||||
return HierarchicalReasoningModel_ACTV1InnerCarry(
|
||||
z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
|
||||
z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
|
||||
)
|
||||
|
||||
def reset_carry(self, reset_flag: torch.Tensor, carry: HierarchicalReasoningModel_ACTV1InnerCarry):
|
||||
return HierarchicalReasoningModel_ACTV1InnerCarry(
|
||||
z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
|
||||
z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
|
||||
)
|
||||
|
||||
def forward(self, carry: HierarchicalReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
seq_info = dict(
|
||||
cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
|
||||
)
|
||||
|
||||
# Input encoding
|
||||
input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
|
||||
|
||||
# Forward iterations
|
||||
with torch.no_grad():
|
||||
z_H, z_L = carry.z_H, carry.z_L
|
||||
for _H_step in range(self.config.H_cycles):
|
||||
for _L_step in range(self.config.L_cycles):
|
||||
if not ((_H_step == self.config.H_cycles - 1) and (_L_step == self.config.L_cycles - 1)):
|
||||
z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
|
||||
if not (_H_step == self.config.H_cycles - 1):
|
||||
z_H = self.H_level(z_H, z_L, **seq_info)
|
||||
assert not z_H.requires_grad and not z_L.requires_grad
|
||||
# 1-step grad
|
||||
z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
|
||||
z_H = self.H_level(z_H, z_L, **seq_info)
|
||||
|
||||
# LM Outputs
|
||||
new_carry = HierarchicalReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach()) # New carry no grad
|
||||
output = self.lm_head(z_H)[:, self.puzzle_emb_len:]
|
||||
|
||||
# Q head
|
||||
q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
|
||||
|
||||
return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
|
||||
|
||||
|
||||
class HierarchicalReasoningModel_ACTV1(nn.Module):
|
||||
"""ACT wrapper."""
|
||||
|
||||
def __init__(self, config_dict: dict):
|
||||
super().__init__()
|
||||
self.config = HierarchicalReasoningModel_ACTV1Config(**config_dict)
|
||||
self.inner = HierarchicalReasoningModel_ACTV1_Inner(self.config)
|
||||
|
||||
@property
|
||||
def puzzle_emb(self):
|
||||
return self.inner.puzzle_emb
|
||||
|
||||
def initial_carry(self, batch: Dict[str, torch.Tensor]):
|
||||
batch_size = batch["inputs"].shape[0]
|
||||
|
||||
return HierarchicalReasoningModel_ACTV1Carry(
|
||||
inner_carry=self.inner.empty_carry(batch_size), # Empty is expected; it will be reset in the first pass, since all sequences start halted.
|
||||
|
||||
steps=torch.zeros((batch_size, ), dtype=torch.int32),
|
||||
halted=torch.ones((batch_size, ), dtype=torch.bool), # Default to halted
|
||||
|
||||
current_data={k: torch.empty_like(v) for k, v in batch.items()}
|
||||
)
|
||||
|
||||
def forward(self, carry: HierarchicalReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
|
||||
# Update data, carry (removing halted sequences)
|
||||
new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
|
||||
|
||||
new_steps = torch.where(carry.halted, 0, carry.steps)
|
||||
|
||||
new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
|
||||
|
||||
# Forward inner model
|
||||
new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)
|
||||
|
||||
outputs = {
|
||||
"logits": logits,
|
||||
"q_halt_logits": q_halt_logits,
|
||||
"q_continue_logits": q_continue_logits
|
||||
}
|
||||
|
||||
with torch.no_grad():
|
||||
# Step
|
||||
new_steps = new_steps + 1
|
||||
is_last_step = new_steps >= self.config.halt_max_steps
|
||||
|
||||
halted = is_last_step
|
||||
|
||||
# if training, and ACT is enabled
|
||||
if self.training and (self.config.halt_max_steps > 1):
|
||||
# Halt signal
|
||||
# NOTE: During evaluation, always use max steps; this guarantees the same number of halting steps within a batch (needed for batching)
|
||||
halted = halted | (q_halt_logits > q_continue_logits)
|
||||
|
||||
# Exploration
|
||||
min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
|
||||
|
||||
halted = halted & (new_steps >= min_halt_steps)
|
||||
|
||||
# Compute target Q
|
||||
# NOTE: No replay buffer and target networks for computing target Q-value.
|
||||
# As batch_size is large, there are many parallel environments.
|
||||
# Similar concept as PQN https://arxiv.org/abs/2407.04811
|
||||
next_q_halt_logits, next_q_continue_logits = self.inner(new_inner_carry, new_current_data)[-1]
|
||||
|
||||
outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))
|
||||
|
||||
return HierarchicalReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
|
||||
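The ACT wrapper is meant to be called repeatedly on the same batch until every sequence halts. A minimal sketch of that outer loop (illustration only; the actual training loop, optimizer handling, and any EMA live elsewhere in the repository), assuming `model` is an `ACTLossHead` wrapping one of the recursive-reasoning models and `batch` contains "inputs", "labels", and "puzzle_identifiers":

```python
# Sketch only: the deep-supervision / ACT outer loop implied by the carry API.
carry = model.initial_carry(batch=batch)   # all sequences start halted, so the first pass resets them

all_halted = False
while not all_halted:
    carry, loss, metrics, outputs, all_halted = model(
        return_keys=["logits", "q_halt_logits"],
        carry=carry,
        batch=batch,
    )
    loss.backward()          # one supervision signal per outer step
    # optimizer.step(); optimizer.zero_grad()  # omitted: depends on the training setup
```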
models/recursive_reasoning/transformers_baseline.py (new file, 342 lines)
@@ -0,0 +1,342 @@
|
||||
"""
|
||||
HRM ACT V2: Transformer Baseline for Architecture Ablation
|
||||
|
||||
This is an architecture ablation of the Hierarchical Reasoning Model (HRM).
|
||||
Key changes from V1:
|
||||
1. REMOVED hierarchical split (no separate H and L levels)
|
||||
2. REMOVED inner cycles (no H_cycles/L_cycles loops within reasoning)
|
||||
3. KEPT ACT outer loop structure intact
|
||||
4. KEPT all data preprocessing, embeddings, and evaluation infrastructure
|
||||
|
||||
Architecture: Single-level transformer that processes the full 30x30 grid as a
|
||||
900-token sequence, with the same positional encodings and sparse embeddings as V1.
|
||||
|
||||
"""
|
||||
|
||||
from typing import Tuple, List, Dict, Optional
|
||||
from dataclasses import dataclass
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from pydantic import BaseModel
|
||||
|
||||
from models.common import trunc_normal_init_
|
||||
from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
|
||||
from models.sparse_embedding import CastedSparseEmbedding
|
||||
|
||||
|
||||
@dataclass
|
||||
class Model_ACTV2InnerCarry:
|
||||
z_H: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class Model_ACTV2Carry:
|
||||
inner_carry: Model_ACTV2InnerCarry
|
||||
|
||||
steps: torch.Tensor
|
||||
halted: torch.Tensor
|
||||
|
||||
current_data: Dict[str, torch.Tensor]
|
||||
|
||||
|
||||
class Model_ACTV2Config(BaseModel):
|
||||
batch_size: int
|
||||
seq_len: int
|
||||
puzzle_emb_ndim: int = 0
|
||||
num_puzzle_identifiers: int
|
||||
vocab_size: int
|
||||
|
||||
H_cycles: int
|
||||
|
||||
H_layers: int
|
||||
|
||||
# Transformer config
|
||||
hidden_size: int
|
||||
expansion: float
|
||||
num_heads: int
|
||||
pos_encodings: str
|
||||
|
||||
rms_norm_eps: float = 1e-5
|
||||
rope_theta: float = 10000.0
|
||||
|
||||
# Halting Q-learning config
|
||||
halt_max_steps: int
|
||||
halt_exploration_prob: float
|
||||
act_enabled: bool = True # If False, always run halt_max_steps (no early stopping during training)
|
||||
act_inference: bool = False # If True, use adaptive computation during inference
|
||||
|
||||
forward_dtype: str = "bfloat16"
|
||||
|
||||
|
||||
class Model_ACTV2Block(nn.Module):
|
||||
def __init__(self, config: Model_ACTV2Config) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.self_attn = Attention(
|
||||
hidden_size=config.hidden_size,
|
||||
head_dim=config.hidden_size // config.num_heads,
|
||||
num_heads=config.num_heads,
|
||||
num_key_value_heads=config.num_heads,
|
||||
causal=False,
|
||||
)
|
||||
self.mlp = SwiGLU(
|
||||
hidden_size=config.hidden_size,
|
||||
expansion=config.expansion,
|
||||
)
|
||||
self.norm_eps = config.rms_norm_eps
|
||||
|
||||
def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
# Post Norm
|
||||
# Self Attention
|
||||
hidden_states = rms_norm(
|
||||
hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states),
|
||||
variance_epsilon=self.norm_eps,
|
||||
)
|
||||
# Fully Connected
|
||||
hidden_states = rms_norm(hidden_states + self.mlp(hidden_states), variance_epsilon=self.norm_eps)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Model_ACTV2ReasoningModule(nn.Module):
|
||||
def __init__(self, layers: List[Model_ACTV2Block]):
|
||||
super().__init__()
|
||||
|
||||
self.layers = torch.nn.ModuleList(layers)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
# Input injection (add)
|
||||
hidden_states = hidden_states + input_injection
|
||||
# Layers
|
||||
for layer in self.layers:
|
||||
hidden_states = layer(hidden_states=hidden_states, **kwargs)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Model_ACTV2_Inner(nn.Module):
|
||||
def __init__(self, config: Model_ACTV2Config) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.forward_dtype = getattr(torch, self.config.forward_dtype)
|
||||
|
||||
# I/O
|
||||
self.embed_scale = math.sqrt(self.config.hidden_size)
|
||||
embed_init_std = 1.0 / self.embed_scale
|
||||
|
||||
self.embed_tokens = CastedEmbedding(
|
||||
self.config.vocab_size,
|
||||
self.config.hidden_size,
|
||||
init_std=embed_init_std,
|
||||
cast_to=self.forward_dtype,
|
||||
)
|
||||
self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
|
||||
self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
|
||||
|
||||
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div
|
||||
if self.config.puzzle_emb_ndim > 0:
|
||||
# Zero init puzzle embeddings
|
||||
self.puzzle_emb = CastedSparseEmbedding(
|
||||
self.config.num_puzzle_identifiers,
|
||||
self.config.puzzle_emb_ndim,
|
||||
batch_size=self.config.batch_size,
|
||||
init_std=0,
|
||||
cast_to=self.forward_dtype,
|
||||
)
|
||||
|
||||
# LM Blocks
|
||||
if self.config.pos_encodings == "rope":
|
||||
self.rotary_emb = RotaryEmbedding(
|
||||
dim=self.config.hidden_size // self.config.num_heads,
|
||||
max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
|
||||
base=self.config.rope_theta,
|
||||
)
|
||||
elif self.config.pos_encodings == "learned":
|
||||
self.embed_pos = CastedEmbedding(
|
||||
self.config.seq_len + self.puzzle_emb_len,
|
||||
self.config.hidden_size,
|
||||
init_std=embed_init_std,
|
||||
cast_to=self.forward_dtype,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
# Reasoning Layers
|
||||
self.H_level = Model_ACTV2ReasoningModule(
|
||||
layers=[Model_ACTV2Block(self.config) for _i in range(self.config.H_layers)]
|
||||
)
|
||||
|
||||
# Initial states
|
||||
self.H_init = nn.Buffer(
|
||||
trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1),
|
||||
persistent=True,
|
||||
)
|
||||
|
||||
# Q head special init
|
||||
# Init Q to (almost) zero for faster learning during bootstrapping
|
||||
with torch.no_grad():
|
||||
self.q_head.weight.zero_()
|
||||
self.q_head.bias.fill_(-5) # type: ignore
|
||||
|
||||
def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
|
||||
# Token embedding
|
||||
embedding = self.embed_tokens(input.to(torch.int32))
|
||||
|
||||
# Puzzle embeddings
|
||||
if self.config.puzzle_emb_ndim > 0:
|
||||
puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
|
||||
|
||||
pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
|
||||
if pad_count > 0:
|
||||
puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
|
||||
|
||||
embedding = torch.cat(
|
||||
(puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2
|
||||
)
|
||||
|
||||
# Position embeddings
|
||||
if self.config.pos_encodings == "learned":
|
||||
# scale by 1/sqrt(2) to maintain forward variance
|
||||
embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
|
||||
|
||||
# Scale
|
||||
return self.embed_scale * embedding
|
||||
|
||||
def empty_carry(self, batch_size: int):
|
||||
return Model_ACTV2InnerCarry(
|
||||
z_H=torch.empty(
|
||||
batch_size,
|
||||
self.config.seq_len + self.puzzle_emb_len,
|
||||
self.config.hidden_size,
|
||||
dtype=self.forward_dtype,
|
||||
),
|
||||
)
|
||||
|
||||
def reset_carry(self, reset_flag: torch.Tensor, carry: Model_ACTV2InnerCarry):
|
||||
return Model_ACTV2InnerCarry(
|
||||
z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, carry: Model_ACTV2InnerCarry, batch: Dict[str, torch.Tensor]
|
||||
) -> Tuple[Model_ACTV2InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
seq_info = dict(
|
||||
cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
|
||||
)
|
||||
|
||||
# Input encoding
|
||||
input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
|
||||
|
||||
# 1-step grad
|
||||
z_H = self.H_level(carry.z_H, input_embeddings, **seq_info)
|
||||
|
||||
# LM Outputs
|
||||
new_carry = Model_ACTV2InnerCarry(
|
||||
z_H=z_H.detach(),
|
||||
) # New carry no grad
|
||||
output = self.lm_head(z_H)[:, self.puzzle_emb_len :]
|
||||
|
||||
# Q head
|
||||
q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
|
||||
|
||||
return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
|
||||
|
||||
|
||||
class Model_ACTV2(nn.Module):
|
||||
"""ACT wrapper."""
|
||||
|
||||
def __init__(self, config_dict: dict):
|
||||
super().__init__()
|
||||
self.config = Model_ACTV2Config(**config_dict)
|
||||
self.inner = Model_ACTV2_Inner(self.config)
|
||||
|
||||
@property
|
||||
def puzzle_emb(self):
|
||||
return self.inner.puzzle_emb
|
||||
|
||||
def initial_carry(self, batch: Dict[str, torch.Tensor]):
|
||||
batch_size = batch["inputs"].shape[0]
|
||||
|
||||
return Model_ACTV2Carry(
|
||||
inner_carry=self.inner.empty_carry(
|
||||
batch_size
|
||||
), # Empty is expected; it will be reset in the first pass, since all sequences start halted.
|
||||
steps=torch.zeros((batch_size,), dtype=torch.int32),
|
||||
halted=torch.ones((batch_size,), dtype=torch.bool), # Default to halted
|
||||
current_data={k: torch.empty_like(v) for k, v in batch.items()},
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
carry: Model_ACTV2Carry,
|
||||
batch: Dict[str, torch.Tensor],
|
||||
compute_target_q: bool = False,
|
||||
) -> Tuple[Model_ACTV2Carry, Dict[str, torch.Tensor]]:
|
||||
# Update data, carry (removing halted sequences)
|
||||
new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
|
||||
|
||||
new_steps = torch.where(carry.halted, 0, carry.steps)
|
||||
|
||||
new_current_data = {
|
||||
k: torch.where(carry.halted.view((-1,) + (1,) * (batch[k].ndim - 1)), batch[k], v)
|
||||
for k, v in carry.current_data.items()
|
||||
}
|
||||
|
||||
# Forward inner model
|
||||
new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(
|
||||
new_inner_carry, new_current_data
|
||||
)
|
||||
|
||||
outputs = {"logits": logits, "q_halt_logits": q_halt_logits, "q_continue_logits": q_continue_logits}
|
||||
|
||||
with torch.no_grad():
|
||||
# Step
|
||||
new_steps = new_steps + 1
|
||||
is_last_step = new_steps >= self.config.halt_max_steps
|
||||
|
||||
halted = is_last_step
|
||||
|
||||
# Check if adaptive computation should be used
|
||||
use_adaptive = (self.config.halt_max_steps > 1) and (
|
||||
(self.training and self.config.act_enabled)
|
||||
or (not self.training and self.config.act_inference)
|
||||
)
|
||||
|
||||
if use_adaptive:
|
||||
# Halt signal based on Q-values (but always halt at max steps)
|
||||
q_halt_signal = q_halt_logits > q_continue_logits
|
||||
halted = halted | q_halt_signal
|
||||
|
||||
# Store actual steps used for logging (only during inference)
|
||||
if not self.training:
|
||||
outputs["actual_steps"] = new_steps.float()
|
||||
|
||||
# Exploration (only during training)
|
||||
if self.training:
|
||||
min_halt_steps = (
|
||||
torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob
|
||||
) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
|
||||
halted = halted & (new_steps >= min_halt_steps)
|
||||
|
||||
# Compute target Q (only during training)
|
||||
# NOTE: No replay buffer and target networks for computing target Q-value.
|
||||
# As batch_size is large, there are many parallel environments.
|
||||
# Similar concept as PQN https://arxiv.org/abs/2407.04811
|
||||
if self.training and compute_target_q:
|
||||
next_q_halt_logits, next_q_continue_logits = self.inner(
|
||||
new_inner_carry, new_current_data
|
||||
)[-1]
|
||||
|
||||
outputs["target_q_continue"] = torch.sigmoid(
|
||||
torch.where(
|
||||
is_last_step,
|
||||
next_q_halt_logits,
|
||||
torch.maximum(next_q_halt_logits, next_q_continue_logits),
|
||||
)
|
||||
)
|
||||
|
||||
return Model_ACTV2Carry(
|
||||
new_inner_carry, new_steps, halted, new_current_data
|
||||
), outputs
|
||||
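To summarize the halting rule used when adaptive computation is enabled during training, here is a small standalone sketch (illustration only) of how the step limit, the Q-values, and the exploration-driven minimum step count combine:

```python
# Sketch only: the boolean halting rule used during training with ACT enabled.
import torch

def halt_decision(new_steps, q_halt_logits, q_continue_logits, halt_max_steps, halt_exploration_prob):
    is_last_step = new_steps >= halt_max_steps
    halted = is_last_step | (q_halt_logits > q_continue_logits)      # Q-head prefers halting
    # Exploration: with probability halt_exploration_prob, force at least min_halt_steps steps.
    min_halt_steps = (torch.rand_like(q_halt_logits) < halt_exploration_prob) * \
        torch.randint_like(new_steps, low=2, high=halt_max_steps + 1)
    return halted & (new_steps >= min_halt_steps)

steps = torch.tensor([1, 4, 8])
print(halt_decision(steps, torch.tensor([3.0, -2.0, 1.0]), torch.zeros(3),
                    halt_max_steps=8, halt_exploration_prob=0.1))
```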
models/recursive_reasoning/trm.py (new file, 297 lines)
@@ -0,0 +1,297 @@
|
||||
from typing import Tuple, List, Dict, Optional
|
||||
from dataclasses import dataclass
|
||||
import math
|
||||
import torch
|
||||
import copy
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from pydantic import BaseModel
|
||||
import random
|
||||
from models.common import trunc_normal_init_
|
||||
from models.layers import rms_norm, LinearSwish, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
|
||||
from models.sparse_embedding import CastedSparseEmbedding
|
||||
|
||||
IGNORE_LABEL_ID = -100
|
||||
|
||||
@dataclass
|
||||
class TinyRecursiveReasoningModel_ACTV1InnerCarry:
|
||||
z_H: torch.Tensor
|
||||
z_L: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class TinyRecursiveReasoningModel_ACTV1Carry:
|
||||
inner_carry: TinyRecursiveReasoningModel_ACTV1InnerCarry
|
||||
|
||||
steps: torch.Tensor
|
||||
halted: torch.Tensor
|
||||
|
||||
current_data: Dict[str, torch.Tensor]
|
||||
|
||||
|
||||
class TinyRecursiveReasoningModel_ACTV1Config(BaseModel):
|
||||
batch_size: int
|
||||
seq_len: int
|
||||
puzzle_emb_ndim: int = 0
|
||||
num_puzzle_identifiers: int
|
||||
vocab_size: int
|
||||
|
||||
H_cycles: int
|
||||
L_cycles: int
|
||||
|
||||
H_layers: int # ignored
|
||||
L_layers: int
|
||||
|
||||
# Transformer config
|
||||
hidden_size: int
|
||||
expansion: float
|
||||
num_heads: int
|
||||
pos_encodings: str
|
||||
|
||||
rms_norm_eps: float = 1e-5
|
||||
rope_theta: float = 10000.0
|
||||
|
||||
# Halting Q-learning config
|
||||
halt_max_steps: int
|
||||
halt_exploration_prob: float
|
||||
|
||||
forward_dtype: str = "bfloat16"
|
||||
|
||||
# Alexia: added
|
||||
mlp_t: bool = False # use mlp on L instead of transformer
|
||||
puzzle_emb_len: int = 16 # if non-zero, the puzzle embedding length is fixed to this value
|
||||
no_ACT_continue: bool = True # no "continue" ACT loss; use only the sigmoid of the halt logit, which makes much more sense
|
||||
|
||||
class TinyRecursiveReasoningModel_ACTV1Block(nn.Module):
|
||||
def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
if self.config.mlp_t:
|
||||
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len
|
||||
self.mlp_t = SwiGLU(
|
||||
hidden_size=self.config.seq_len + self.puzzle_emb_len, # L
|
||||
expansion=config.expansion,
|
||||
)
|
||||
else:
|
||||
self.self_attn = Attention(
|
||||
hidden_size=config.hidden_size,
|
||||
head_dim=config.hidden_size // config.num_heads,
|
||||
num_heads=config.num_heads,
|
||||
num_key_value_heads=config.num_heads,
|
||||
causal=False
|
||||
)
|
||||
self.mlp = SwiGLU(
|
||||
hidden_size=config.hidden_size,
|
||||
expansion=config.expansion,
|
||||
)
|
||||
self.norm_eps = config.rms_norm_eps
|
||||
|
||||
def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
# B, L, D = hidden_states.shape
|
||||
# Post Norm
|
||||
if self.config.mlp_t:
|
||||
hidden_states = hidden_states.transpose(1,2)
|
||||
out = self.mlp_t(hidden_states)
|
||||
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
|
||||
hidden_states = hidden_states.transpose(1,2)
|
||||
else:
|
||||
# Self Attention
|
||||
hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
|
||||
# Fully Connected
|
||||
out = self.mlp(hidden_states)
|
||||
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
|
||||
return hidden_states
|
||||
|
||||
class TinyRecursiveReasoningModel_ACTV1ReasoningModule(nn.Module):
|
||||
def __init__(self, layers: List[TinyRecursiveReasoningModel_ACTV1Block]):
|
||||
super().__init__()
|
||||
self.layers = torch.nn.ModuleList(layers)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
hidden_states = hidden_states + input_injection
|
||||
for layer in self.layers:
|
||||
hidden_states = layer(hidden_states=hidden_states, **kwargs)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class TinyRecursiveReasoningModel_ACTV1_Inner(nn.Module):
|
||||
def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.forward_dtype = getattr(torch, self.config.forward_dtype)
|
||||
|
||||
# I/O
|
||||
|
||||
self.embed_scale = math.sqrt(self.config.hidden_size)
|
||||
embed_init_std = 1.0 / self.embed_scale
|
||||
|
||||
self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
|
||||
self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
|
||||
self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
|
||||
|
||||
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len # ceil div
|
||||
if self.config.puzzle_emb_ndim > 0:
|
||||
# Zero init puzzle embeddings
|
||||
self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
|
||||
batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
|
||||
|
||||
# LM Blocks
|
||||
if self.config.pos_encodings == "rope":
|
||||
self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
|
||||
max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
|
||||
base=self.config.rope_theta)
|
||||
elif self.config.pos_encodings == "learned":
|
||||
self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
|
||||
else:
|
||||
pass
|
||||
|
||||
# Reasoning Layers
|
||||
self.L_level = TinyRecursiveReasoningModel_ACTV1ReasoningModule(layers=[TinyRecursiveReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])
|
||||
|
||||
# Initial states
|
||||
self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
|
||||
self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
|
||||
|
||||
# Q head special init
|
||||
# Init Q to (almost) zero for faster learning during bootstrapping
|
||||
with torch.no_grad():
|
||||
self.q_head.weight.zero_()
|
||||
self.q_head.bias.fill_(-5) # type: ignore
|
||||
|
||||
def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
|
||||
# Token embedding
|
||||
embedding = self.embed_tokens(input.to(torch.int32))
|
||||
|
||||
# Puzzle embeddings
|
||||
if self.config.puzzle_emb_ndim > 0:
|
||||
puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
|
||||
|
||||
pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
|
||||
if pad_count > 0:
|
||||
puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
|
||||
|
||||
embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
|
||||
|
||||
# Position embeddings
|
||||
if self.config.pos_encodings == "learned":
|
||||
# scale by 1/sqrt(2) to maintain forward variance
|
||||
embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
|
||||
|
||||
# Scale
|
||||
return self.embed_scale * embedding
|
||||
|
||||
def empty_carry(self, batch_size: int):
|
||||
return TinyRecursiveReasoningModel_ACTV1InnerCarry(
|
||||
z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
|
||||
z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
|
||||
)
|
||||
|
||||
def reset_carry(self, reset_flag: torch.Tensor, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry):
|
||||
return TinyRecursiveReasoningModel_ACTV1InnerCarry(
|
||||
z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
|
||||
z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
|
||||
)
|
||||
|
||||
def forward(self, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
seq_info = dict(
|
||||
cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
|
||||
)
|
||||
|
||||
# Input encoding
|
||||
input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
|
||||
|
||||
# Forward iterations
|
||||
it = 0
|
||||
z_H, z_L = carry.z_H, carry.z_L
|
||||
# H_cycles-1 without grad
|
||||
with torch.no_grad():
|
||||
for _H_step in range(self.config.H_cycles-1):
|
||||
for _L_step in range(self.config.L_cycles):
|
||||
z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
|
||||
z_H = self.L_level(z_H, z_L, **seq_info)
|
||||
# 1 with grad
|
||||
for _L_step in range(self.config.L_cycles):
|
||||
z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
|
||||
z_H = self.L_level(z_H, z_L, **seq_info)
|
||||
|
||||
# LM Outputs
|
||||
new_carry = TinyRecursiveReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach()) # New carry no grad
|
||||
output = self.lm_head(z_H)[:, self.puzzle_emb_len:]
|
||||
q_logits = self.q_head(z_H[:, 0]).to(torch.float32) # Q-head; uses the first puzzle_emb position
|
||||
return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
|
||||
|
||||
|
||||
class TinyRecursiveReasoningModel_ACTV1(nn.Module):
|
||||
"""ACT wrapper."""
|
||||
|
||||
def __init__(self, config_dict: dict):
|
||||
super().__init__()
|
||||
self.config = TinyRecursiveReasoningModel_ACTV1Config(**config_dict)
|
||||
self.inner = TinyRecursiveReasoningModel_ACTV1_Inner(self.config)
|
||||
|
||||
@property
|
||||
def puzzle_emb(self):
|
||||
return self.inner.puzzle_emb
|
||||
|
||||
def initial_carry(self, batch: Dict[str, torch.Tensor]):
|
||||
batch_size = batch["inputs"].shape[0]
|
||||
|
||||
return TinyRecursiveReasoningModel_ACTV1Carry(
|
||||
inner_carry=self.inner.empty_carry(batch_size), # Empty is expected; it will be reset in the first pass, since all sequences start halted.
|
||||
|
||||
steps=torch.zeros((batch_size, ), dtype=torch.int32),
|
||||
halted=torch.ones((batch_size, ), dtype=torch.bool), # Default to halted
|
||||
|
||||
current_data={k: torch.empty_like(v) for k, v in batch.items()}
|
||||
)
|
||||
|
||||
def forward(self, carry: TinyRecursiveReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
|
||||
|
||||
# Update data, carry (removing halted sequences)
|
||||
new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
|
||||
|
||||
new_steps = torch.where(carry.halted, 0, carry.steps)
|
||||
|
||||
new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
|
||||
|
||||
# Forward inner model
|
||||
new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)
|
||||
|
||||
outputs = {
|
||||
"logits": logits,
|
||||
"q_halt_logits": q_halt_logits,
|
||||
"q_continue_logits": q_continue_logits
|
||||
}
|
||||
|
||||
with torch.no_grad():
|
||||
# Step
|
||||
new_steps = new_steps + 1
|
||||
is_last_step = new_steps >= self.config.halt_max_steps
|
||||
|
||||
halted = is_last_step
|
||||
|
||||
# if training, and ACT is enabled
|
||||
if self.training and (self.config.halt_max_steps > 1):
|
||||
|
||||
# Halt signal
|
||||
# NOTE: During evaluation, always use max steps; this guarantees the same number of halting steps within a batch (needed for batching)
|
||||
|
||||
if self.config.no_ACT_continue:
|
||||
halted = halted | (q_halt_logits > 0)
|
||||
else:
|
||||
halted = halted | (q_halt_logits > q_continue_logits)
|
||||
|
||||
# Exploration
|
||||
min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
|
||||
halted = halted & (new_steps >= min_halt_steps)
|
||||
|
||||
if not self.config.no_ACT_continue:
|
||||
# Compute target Q
|
||||
# NOTE: No replay buffer and target networks for computing target Q-value.
|
||||
# As batch_size is large, there are many parallel environments.
|
||||
# Similar concept as PQN https://arxiv.org/abs/2407.04811
|
||||
_, _, (next_q_halt_logits, next_q_continue_logits) = self.inner(new_inner_carry, new_current_data)
|
||||
outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))
|
||||
|
||||
return TinyRecursiveReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
|
||||
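The forward pass above is the recursion described in the README: `z_H` plays the role of the current answer y, `z_L` plays the role of the latent z, and one shared tiny network updates both. A compact sketch of the schedule (illustration only; `net` stands in for the shared `L_level` module, and in `trm.py` only the final H cycle runs with gradients, the earlier ones run under `torch.no_grad()`):

```python
# Sketch only: the TRM update schedule (x = embedded question, y = answer, z = latent).
def trm_recursion(net, x, y, z, H_cycles, L_cycles):
    for _ in range(H_cycles):
        for _ in range(L_cycles):
            z = net(z, y + x)   # refine the latent given the question and current answer
        y = net(y, z)           # refine the answer given the latent
    return y, z

# Toy usage with scalars and a stand-in "network":
net = lambda state, injection: 0.5 * (state + injection)
y, z = trm_recursion(net, x=1.0, y=0.0, z=0.0, H_cycles=3, L_cycles=6)
```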
models/recursive_reasoning/trm_hier6.py (new file, 323 lines)
@@ -0,0 +1,323 @@
|
||||
from typing import Tuple, List, Dict, Optional
|
||||
from dataclasses import dataclass
|
||||
import math
|
||||
import torch
|
||||
import copy
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from pydantic import BaseModel
|
||||
import random
|
||||
from models.common import trunc_normal_init_
|
||||
from models.layers import rms_norm, LinearSwish, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
|
||||
from models.sparse_embedding import CastedSparseEmbedding
|
||||
|
||||
IGNORE_LABEL_ID = -100
|
||||
|
||||
@dataclass
|
||||
class TinyRecursiveReasoningModel_ACTV1InnerCarry:
|
||||
z_H: torch.Tensor
|
||||
z_L1: torch.Tensor
|
||||
z_L2: torch.Tensor
|
||||
z_L3: torch.Tensor
|
||||
z_L4: torch.Tensor
|
||||
z_L5: torch.Tensor
|
||||
z_L6: torch.Tensor
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class TinyRecursiveReasoningModel_ACTV1Carry:
|
||||
inner_carry: TinyRecursiveReasoningModel_ACTV1InnerCarry
|
||||
|
||||
steps: torch.Tensor
|
||||
halted: torch.Tensor
|
||||
|
||||
current_data: Dict[str, torch.Tensor]
|
||||
|
||||
|
||||
class TinyRecursiveReasoningModel_ACTV1Config(BaseModel):
|
||||
batch_size: int
|
||||
seq_len: int
|
||||
puzzle_emb_ndim: int = 0
|
||||
num_puzzle_identifiers: int
|
||||
vocab_size: int
|
||||
|
||||
H_cycles: int
|
||||
L_cycles: int
|
||||
|
||||
H_layers: int # ignored
|
||||
L_layers: int
|
||||
|
||||
# Transformer config
|
||||
hidden_size: int
|
||||
expansion: float
|
||||
num_heads: int
|
||||
pos_encodings: str
|
||||
|
||||
rms_norm_eps: float = 1e-5
|
||||
rope_theta: float = 10000.0
|
||||
|
||||
# Halting Q-learning config
|
||||
halt_max_steps: int
|
||||
halt_exploration_prob: float
|
||||
|
||||
forward_dtype: str = "bfloat16"
|
||||
|
||||
# Alexia: added
|
||||
mlp_t: bool = False # use mlp on L instead of transformer
|
||||
puzzle_emb_len: int = 16 # if non-zero, the puzzle embedding length is fixed to this value
|
||||
no_ACT_continue: bool = True # no "continue" ACT loss; use only the sigmoid of the halt logit, which makes much more sense
|
||||
|
||||
class TinyRecursiveReasoningModel_ACTV1Block(nn.Module):
|
||||
def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
if self.config.mlp_t:
|
||||
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len
|
||||
self.mlp_t = SwiGLU(
|
||||
hidden_size=self.config.seq_len + self.puzzle_emb_len, # L
|
||||
expansion=config.expansion,
|
||||
)
|
||||
else:
|
||||
self.self_attn = Attention(
|
||||
hidden_size=config.hidden_size,
|
||||
head_dim=config.hidden_size // config.num_heads,
|
||||
num_heads=config.num_heads,
|
||||
num_key_value_heads=config.num_heads,
|
||||
causal=False
|
||||
)
|
||||
self.mlp = SwiGLU(
|
||||
hidden_size=config.hidden_size,
|
||||
expansion=config.expansion,
|
||||
)
|
||||
self.norm_eps = config.rms_norm_eps
|
||||
|
||||
def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
# B, L, D = hidden_states.shape
|
||||
# Post Norm
|
||||
if self.config.mlp_t:
|
||||
hidden_states = hidden_states.transpose(1,2)
|
||||
out = self.mlp_t(hidden_states)
|
||||
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
|
||||
hidden_states = hidden_states.transpose(1,2)
|
||||
else:
|
||||
# Self Attention
|
||||
hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
|
||||
# Fully Connected
|
||||
out = self.mlp(hidden_states)
|
||||
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
|
||||
return hidden_states
|
||||
|
||||
class TinyRecursiveReasoningModel_ACTV1ReasoningModule(nn.Module):
|
||||
def __init__(self, layers: List[TinyRecursiveReasoningModel_ACTV1Block]):
|
||||
super().__init__()
|
||||
self.layers = torch.nn.ModuleList(layers)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
hidden_states = hidden_states + input_injection
|
||||
for layer in self.layers:
|
||||
hidden_states = layer(hidden_states=hidden_states, **kwargs)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class TinyRecursiveReasoningModel_ACTV1_Inner(nn.Module):
|
||||
def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.forward_dtype = getattr(torch, self.config.forward_dtype)
|
||||
|
||||
# I/O
|
||||
|
||||
self.embed_scale = math.sqrt(self.config.hidden_size)
|
||||
embed_init_std = 1.0 / self.embed_scale
|
||||
|
||||
self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
|
||||
self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
|
||||
self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
|
||||
|
||||
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len # ceil div
|
||||
if self.config.puzzle_emb_ndim > 0:
|
||||
# Zero init puzzle embeddings
|
||||
self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
|
||||
batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
|
||||
|
||||
# LM Blocks
|
||||
if self.config.pos_encodings == "rope":
|
||||
self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
|
||||
max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
|
||||
base=self.config.rope_theta)
|
||||
elif self.config.pos_encodings == "learned":
|
||||
self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
|
||||
else:
|
||||
pass
|
||||
|
||||
# Reasoning Layers
|
||||
self.L_level = TinyRecursiveReasoningModel_ACTV1ReasoningModule(layers=[TinyRecursiveReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])
|
||||
|
||||
# Initial states
|
||||
self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
|
||||
self.L1_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
|
||||
self.L2_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
|
||||
self.L3_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
|
||||
self.L4_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
|
||||
self.L5_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
|
||||
self.L6_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
|
||||
|
||||
# Q head special init
|
||||
# Init Q to (almost) zero for faster learning during bootstrapping
|
||||
with torch.no_grad():
|
||||
self.q_head.weight.zero_()
|
||||
self.q_head.bias.fill_(-5) # type: ignore
|
||||
|
||||
def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
|
||||
# Token embedding
|
||||
embedding = self.embed_tokens(input.to(torch.int32))
|
||||
|
||||
# Puzzle embeddings
|
||||
if self.config.puzzle_emb_ndim > 0:
|
||||
puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
|
||||
|
||||
pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
|
||||
if pad_count > 0:
|
||||
puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
|
||||
|
||||
embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
|
||||
|
||||
# Position embeddings
|
||||
if self.config.pos_encodings == "learned":
|
||||
# scale by 1/sqrt(2) to maintain forward variance
|
||||
embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
|
||||
|
||||
# Scale
|
||||
return self.embed_scale * embedding
|
||||
|
||||
def empty_carry(self, batch_size: int):
|
||||
return TinyRecursiveReasoningModel_ACTV1InnerCarry(
|
||||
z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
|
||||
z_L1=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
|
||||
z_L2=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
|
||||
z_L3=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
|
||||
z_L4=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
|
||||
z_L5=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
|
||||
z_L6=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
|
||||
)
|
||||
|
||||
def reset_carry(self, reset_flag: torch.Tensor, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry):
|
||||
return TinyRecursiveReasoningModel_ACTV1InnerCarry(
|
||||
z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
|
||||
z_L1=torch.where(reset_flag.view(-1, 1, 1), self.L1_init, carry.z_L1),
|
||||
z_L2=torch.where(reset_flag.view(-1, 1, 1), self.L2_init, carry.z_L2),
|
||||
z_L3=torch.where(reset_flag.view(-1, 1, 1), self.L3_init, carry.z_L3),
|
||||
z_L4=torch.where(reset_flag.view(-1, 1, 1), self.L4_init, carry.z_L4),
|
||||
z_L5=torch.where(reset_flag.view(-1, 1, 1), self.L5_init, carry.z_L5),
|
||||
z_L6=torch.where(reset_flag.view(-1, 1, 1), self.L6_init, carry.z_L6),
|
||||
)
|
||||
|
||||
|
||||
def forward(self, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
seq_info = dict(
|
||||
cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
|
||||
)
|
||||
|
||||
# Input encoding
|
||||
input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
|
||||
|
||||
# Forward iterations
|
||||
it = 0
|
||||
z_H, z_L = carry.z_H, [carry.z_L1, carry.z_L2, carry.z_L3, carry.z_L4, carry.z_L5, carry.z_L6]
|
||||
# H_cycles-1 without grad
|
||||
with torch.no_grad():
|
||||
for _H_step in range(self.config.H_cycles-1):
|
||||
for _L_step in range(self.config.L_cycles):
|
||||
z_L_ = z_L[0] + z_L[1] + z_L[2] + z_L[3] + z_L[4] + z_L[5]
|
||||
z_L[_L_step] = self.L_level(z_L_, z_H + input_embeddings, **seq_info)
|
||||
z_L_ = z_L[0] + z_L[1] + z_L[2] + z_L[3] + z_L[4] + z_L[5]
|
||||
z_H = self.L_level(z_H, z_L_, **seq_info)
|
||||
# 1 with grad
|
||||
for _L_step in range(self.config.L_cycles):
|
||||
z_L_ = z_L[0] + z_L[1] + z_L[2] + z_L[3] + z_L[4] + z_L[5]
|
||||
z_L[_L_step] = self.L_level(z_L_, z_H + input_embeddings, **seq_info)
|
||||
z_L_ = z_L[0] + z_L[1] + z_L[2] + z_L[3] + z_L[4] + z_L[5]
|
||||
z_H = self.L_level(z_H, z_L_, **seq_info)
|
||||
|
||||
# LM Outputs
|
||||
new_carry = TinyRecursiveReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L1=z_L[0].detach(), z_L2=z_L[1].detach(), z_L3=z_L[2].detach(), z_L4=z_L[3].detach(), z_L5=z_L[4].detach(), z_L6=z_L[5].detach()) # New carry no grad
|
||||
output = self.lm_head(z_H)[:, self.puzzle_emb_len:]
|
||||
q_logits = self.q_head(z_H[:, 0]).to(torch.float32) # Q-head; uses the first puzzle_emb position
|
||||
return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
|
||||
|
||||
|
||||
class TinyRecursiveReasoningModel_ACTV1(nn.Module):
|
||||
"""ACT wrapper."""
|
||||
|
||||
def __init__(self, config_dict: dict):
|
||||
super().__init__()
|
||||
self.config = TinyRecursiveReasoningModel_ACTV1Config(**config_dict)
|
||||
self.inner = TinyRecursiveReasoningModel_ACTV1_Inner(self.config)
|
||||
|
||||
@property
|
||||
def puzzle_emb(self):
|
||||
return self.inner.puzzle_emb
|
||||
|
||||
def initial_carry(self, batch: Dict[str, torch.Tensor]):
|
||||
batch_size = batch["inputs"].shape[0]
|
||||
|
||||
return TinyRecursiveReasoningModel_ACTV1Carry(
|
||||
inner_carry=self.inner.empty_carry(batch_size), # Empty is expected; it will be reset in the first pass, since all sequences start halted.
|
||||
|
||||
steps=torch.zeros((batch_size, ), dtype=torch.int32),
|
||||
halted=torch.ones((batch_size, ), dtype=torch.bool), # Default to halted
|
||||
|
||||
current_data={k: torch.empty_like(v) for k, v in batch.items()}
|
||||
)
|
||||
|
||||
def forward(self, carry: TinyRecursiveReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
|
||||
|
||||
# Update data, carry (removing halted sequences)
|
||||
new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
|
||||
|
||||
new_steps = torch.where(carry.halted, 0, carry.steps)
|
||||
|
||||
new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
|
||||
|
||||
# Forward inner model
|
||||
new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)
|
||||
|
||||
outputs = {
|
||||
"logits": logits,
|
||||
"q_halt_logits": q_halt_logits,
|
||||
"q_continue_logits": q_continue_logits
|
||||
}
|
||||
|
||||
with torch.no_grad():
|
||||
# Step
|
||||
new_steps = new_steps + 1
|
||||
is_last_step = new_steps >= self.config.halt_max_steps
|
||||
|
||||
halted = is_last_step
|
||||
|
||||
# if training, and ACT is enabled
|
||||
if self.training and (self.config.halt_max_steps > 1):
|
||||
|
||||
# Halt signal
|
||||
# NOTE: During evaluation, always use max steps; this guarantees the same number of halting steps within a batch (needed for batching)
|
||||
|
||||
if self.config.no_ACT_continue:
|
||||
halted = halted | (q_halt_logits > 0)
|
||||
else:
|
||||
halted = halted | (q_halt_logits > q_continue_logits)
|
||||
|
||||
# Exploration
|
||||
min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
|
||||
halted = halted & (new_steps >= min_halt_steps)
|
||||
|
||||
if not self.config.no_ACT_continue:
|
||||
# Compute target Q
|
||||
# NOTE: No replay buffer and target networks for computing target Q-value.
|
||||
# As batch_size is large, there are many parallel environments.
|
||||
# Similar concept as PQN https://arxiv.org/abs/2407.04811
|
||||
_, _, (next_q_halt_logits, next_q_continue_logits) = self.inner(new_inner_carry, new_current_data)
|
||||
outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))
|
||||
|
||||
return TinyRecursiveReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
|
||||
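This six-latent variant always conditions each update on the sum of all latents. A compact sketch of one H cycle (illustration only; `net` stands in for `L_level`, and `L_cycles` is assumed to equal the number of latents, matching the indexing above):

```python
# Sketch only: one H cycle of the six-latent variant.
def hier6_cycle(net, x, y, zs):          # zs: list of latent tensors (6 of them)
    for i in range(len(zs)):
        zs[i] = net(sum(zs), y + x)      # each latent is refreshed from the sum of all latents
    y = net(y, sum(zs))                  # the answer is refreshed from the summed latents
    return y, zs
```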
models/recursive_reasoning/trm_singlez.py (new file, 294 lines)
@@ -0,0 +1,294 @@
|
||||
from typing import Tuple, List, Dict, Optional
|
||||
from dataclasses import dataclass
|
||||
import math
|
||||
import torch
|
||||
import copy
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from pydantic import BaseModel
|
||||
import random
|
||||
from models.common import trunc_normal_init_
|
||||
from models.layers import rms_norm, LinearSwish, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
|
||||
from models.sparse_embedding import CastedSparseEmbedding
|
||||
|
||||
IGNORE_LABEL_ID = -100
|
||||
|
||||
@dataclass
|
||||
class TinyRecursiveReasoningModel_ACTV1InnerCarry:
|
||||
z_L: torch.Tensor
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class TinyRecursiveReasoningModel_ACTV1Carry:
|
||||
inner_carry: TinyRecursiveReasoningModel_ACTV1InnerCarry
|
||||
|
||||
steps: torch.Tensor
|
||||
halted: torch.Tensor
|
||||
|
||||
current_data: Dict[str, torch.Tensor]
|
||||
|
||||
|
||||
class TinyRecursiveReasoningModel_ACTV1Config(BaseModel):
|
||||
batch_size: int
|
||||
seq_len: int
|
||||
puzzle_emb_ndim: int = 0
|
||||
num_puzzle_identifiers: int
|
||||
vocab_size: int
|
||||
|
||||
H_cycles: int
|
||||
L_cycles: int
|
||||
|
||||
H_layers: int # ignored
|
||||
L_layers: int
|
||||
|
||||
# Transformer config
|
||||
hidden_size: int
|
||||
expansion: float
|
||||
num_heads: int
|
||||
pos_encodings: str
|
||||
|
||||
rms_norm_eps: float = 1e-5
|
||||
rope_theta: float = 10000.0
|
||||
|
||||
# Halting Q-learning config
|
||||
halt_max_steps: int
|
||||
halt_exploration_prob: float
|
||||
|
||||
forward_dtype: str = "bfloat16"
|
||||
|
||||
# Alexia: added
|
||||
mlp_t: bool = False # use mlp on L instead of transformer
|
||||
puzzle_emb_len: int = 16 # if non-zero, the puzzle embedding length is set to this value instead of being derived from puzzle_emb_ndim
|
||||
no_ACT_continue: bool = True # No "continue" ACT loss; only use the sigmoid of the halt logit, which makes much more sense
|
||||
|
||||
class TinyRecursiveReasoningModel_ACTV1Block(nn.Module):
|
||||
def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
if self.config.mlp_t:
|
||||
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len
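# The line above uses -(a // -b), i.e. ceiling division: when puzzle_emb_len is left at 0, ceil(puzzle_emb_ndim / hidden_size) positions are reserved for the puzzle embedding.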
|
||||
self.mlp_t = SwiGLU(
|
||||
hidden_size=self.config.seq_len + self.puzzle_emb_len, # L
|
||||
expansion=config.expansion,
|
||||
)
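# With mlp_t, this SwiGLU runs over the sequence axis: the block's forward transposes (B, L, D) to (B, D, L) first, so the MLP mixes information across token positions rather than across channels.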
|
||||
else:
|
||||
self.self_attn = Attention(
|
||||
hidden_size=config.hidden_size,
|
||||
head_dim=config.hidden_size // config.num_heads,
|
||||
num_heads=config.num_heads,
|
||||
num_key_value_heads=config.num_heads,
|
||||
causal=False
|
||||
)
|
||||
self.mlp = SwiGLU(
|
||||
hidden_size=config.hidden_size,
|
||||
expansion=config.expansion,
|
||||
)
|
||||
self.norm_eps = config.rms_norm_eps
|
||||
|
||||
def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
# B, L, D = hidden_states.shape
|
||||
# Post Norm
|
||||
if self.config.mlp_t:
|
||||
hidden_states = hidden_states.transpose(1,2)
|
||||
out = self.mlp_t(hidden_states)
|
||||
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
|
||||
hidden_states = hidden_states.transpose(1,2)
|
||||
else:
|
||||
# Self Attention
|
||||
hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
|
||||
# Fully Connected
|
||||
out = self.mlp(hidden_states)
|
||||
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
|
||||
return hidden_states
|
||||
|
||||
class TinyRecursiveReasoningModel_ACTV1ReasoningModule(nn.Module):
|
||||
def __init__(self, layers: List[TinyRecursiveReasoningModel_ACTV1Block]):
|
||||
super().__init__()
|
||||
self.layers = torch.nn.ModuleList(layers)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
for layer in self.layers:
|
||||
hidden_states = layer(hidden_states=hidden_states, **kwargs)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class TinyRecursiveReasoningModel_ACTV1_Inner(nn.Module):
|
||||
def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.forward_dtype = getattr(torch, self.config.forward_dtype)
|
||||
|
||||
# I/O
|
||||
|
||||
self.embed_scale = math.sqrt(self.config.hidden_size)
|
||||
embed_init_std = 1.0 / self.embed_scale
|
||||
|
||||
self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
|
||||
self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
|
||||
self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
|
||||
|
||||
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len # ceil div
|
||||
if self.config.puzzle_emb_ndim > 0:
|
||||
# Zero init puzzle embeddings
|
||||
self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
|
||||
batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
|
||||
|
||||
# LM Blocks
|
||||
if self.config.pos_encodings == "rope":
|
||||
self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
|
||||
max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
|
||||
base=self.config.rope_theta)
|
||||
elif self.config.pos_encodings == "learned":
|
||||
self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
|
||||
else:
|
||||
pass
|
||||
|
||||
# Reasoning Layers
|
||||
self.L_level = TinyRecursiveReasoningModel_ACTV1ReasoningModule(layers=[TinyRecursiveReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])
|
||||
|
||||
# Initial states
|
||||
self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
|
||||
|
||||
# Q head special init
|
||||
# Init Q to (almost) zero for faster learning during bootstrapping
|
||||
with torch.no_grad():
|
||||
self.q_head.weight.zero_()
|
||||
self.q_head.bias.fill_(-5) # type: ignore
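# sigmoid(-5) is roughly 0.0067, so the freshly initialized Q-head almost never predicts "halt" before it has been trained.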
|
||||
|
||||
def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
|
||||
# Token embedding
|
||||
embedding = self.embed_tokens(input.to(torch.int32))
|
||||
|
||||
# Puzzle embeddings
|
||||
if self.config.puzzle_emb_ndim > 0:
|
||||
puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
|
||||
|
||||
pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
|
||||
if pad_count > 0:
|
||||
puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
|
||||
|
||||
embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
|
||||
|
||||
# Position embeddings
|
||||
if self.config.pos_encodings == "learned":
|
||||
# scale by 1/sqrt(2) to maintain forward variance
|
||||
embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
|
||||
|
||||
# Scale
|
||||
return self.embed_scale * embedding
|
||||
|
||||
def empty_carry(self, batch_size: int):
|
||||
return TinyRecursiveReasoningModel_ACTV1InnerCarry(
|
||||
z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
|
||||
)
|
||||
|
||||
def reset_carry(self, reset_flag: torch.Tensor, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry):
|
||||
return TinyRecursiveReasoningModel_ACTV1InnerCarry(
|
||||
z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
|
||||
)
|
||||
|
||||
def forward(self, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
seq_info = dict(
|
||||
cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
|
||||
)
|
||||
|
||||
# Input encoding
|
||||
input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
|
||||
|
||||
# Forward iterations
|
||||
it = 0
|
||||
z_L = carry.z_L
|
||||
# H_cycles-1 without grad
|
||||
with torch.no_grad():
|
||||
for _H_step in range(self.config.H_cycles-1):
|
||||
for _L_step in range(self.config.L_cycles):
|
||||
z_L = self.L_level(z_L + input_embeddings, **seq_info)
|
||||
z_L = self.L_level(z_L, **seq_info)
|
||||
# 1 with grad
|
||||
for _L_step in range(self.config.L_cycles):
|
||||
z_L = self.L_level(z_L + input_embeddings, **seq_info)
|
||||
z_L = self.L_level(z_L, **seq_info)
|
||||
z_out = z_L
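# Only the final H-cycle above carries gradients; the earlier cycles run under torch.no_grad(), so backpropagation only flows through a single recursion step.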
|
||||
|
||||
# LM Outputs
|
||||
new_carry = TinyRecursiveReasoningModel_ACTV1InnerCarry(z_L=z_L.detach()) # New carry no grad
|
||||
output = self.lm_head(z_out)[:, self.puzzle_emb_len:]
|
||||
q_logits = self.q_head(z_out[:, 0]).to(torch.float32) # Q-head; uses the first puzzle_emb position
|
||||
return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
|
||||
|
||||
|
||||
class TinyRecursiveReasoningModel_ACTV1(nn.Module):
|
||||
"""ACT wrapper."""
|
||||
|
||||
def __init__(self, config_dict: dict):
|
||||
super().__init__()
|
||||
self.config = TinyRecursiveReasoningModel_ACTV1Config(**config_dict)
|
||||
self.inner = TinyRecursiveReasoningModel_ACTV1_Inner(self.config)
|
||||
|
||||
@property
|
||||
def puzzle_emb(self):
|
||||
return self.inner.puzzle_emb
|
||||
|
||||
def initial_carry(self, batch: Dict[str, torch.Tensor]):
|
||||
batch_size = batch["inputs"].shape[0]
|
||||
|
||||
return TinyRecursiveReasoningModel_ACTV1Carry(
|
||||
inner_carry=self.inner.empty_carry(batch_size), # Empty is expected; it will be reset in the first pass as all sequences start halted.
|
||||
|
||||
steps=torch.zeros((batch_size, ), dtype=torch.int32),
|
||||
halted=torch.ones((batch_size, ), dtype=torch.bool), # Default to halted
|
||||
|
||||
current_data={k: torch.empty_like(v) for k, v in batch.items()}
|
||||
)
|
||||
|
||||
def forward(self, carry: TinyRecursiveReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
|
||||
|
||||
# Update data, carry (removing halted sequences)
|
||||
new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
|
||||
|
||||
new_steps = torch.where(carry.halted, 0, carry.steps)
|
||||
|
||||
new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
|
||||
|
||||
# Forward inner model
|
||||
new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)
|
||||
|
||||
outputs = {
|
||||
"logits": logits,
|
||||
"q_halt_logits": q_halt_logits,
|
||||
"q_continue_logits": q_continue_logits
|
||||
}
|
||||
|
||||
with torch.no_grad():
|
||||
# Step
|
||||
new_steps = new_steps + 1
|
||||
is_last_step = new_steps >= self.config.halt_max_steps
|
||||
|
||||
halted = is_last_step
|
||||
|
||||
# if training, and ACT is enabled
|
||||
if self.training and (self.config.halt_max_steps > 1):
|
||||
|
||||
# Halt signal
|
||||
# NOTE: During evaluation, always use max steps, this is to guarantee the same halting steps inside a batch for batching purposes
|
||||
|
||||
if self.config.no_ACT_continue:
|
||||
halted = halted | (q_halt_logits > 0)
|
||||
else:
|
||||
halted = halted | (q_halt_logits > q_continue_logits)
|
||||
|
||||
# Exploration
|
||||
min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
|
||||
halted = halted & (new_steps >= min_halt_steps)
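# Exploration: with probability halt_exploration_prob, a random minimum number of steps in [2, halt_max_steps] is enforced before the sequence may actually halt.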
|
||||
|
||||
if not self.config.no_ACT_continue:
|
||||
# Compute target Q
|
||||
# NOTE: No replay buffer and target networks for computing target Q-value.
|
||||
# As batch_size is large, there're many parallel envs.
|
||||
# Similar concept as PQN https://arxiv.org/abs/2407.04811
|
||||
_, _, (next_q_halt_logits, next_q_continue_logits) = self.inner(new_inner_carry, new_current_data) # the inner model returns (carry, logits, (q_halt, q_continue))
|
||||
outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))
|
||||
|
||||
return TinyRecursiveReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
|
||||
132
models/sparse_embedding.py
Normal file
@@ -0,0 +1,132 @@
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.distributed as dist
|
||||
from torch.optim.optimizer import Optimizer, ParamsT
|
||||
|
||||
from models.common import trunc_normal_init_
|
||||
|
||||
|
||||
class CastedSparseEmbedding(nn.Module):
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int, batch_size: int, init_std: float, cast_to: torch.dtype):
|
||||
super().__init__()
|
||||
self.cast_to = cast_to
|
||||
|
||||
# Real Weights
|
||||
# Truncated LeCun normal init
|
||||
self.weights = nn.Buffer(
|
||||
trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std), persistent=True
|
||||
)
|
||||
|
||||
# Local weights and IDs
|
||||
# Local embeddings, with gradient, not persistent
|
||||
self.local_weights = nn.Buffer(torch.zeros(batch_size, embedding_dim, requires_grad=True), persistent=False)
|
||||
# Local embedding IDs, not persistent
|
||||
self.local_ids = nn.Buffer(torch.zeros(batch_size, dtype=torch.int32), persistent=False)
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
if not self.training:
|
||||
# Test mode, no gradient
|
||||
return self.weights[inputs].to(self.cast_to)
|
||||
|
||||
# Training mode, fill puzzle embedding from weights
|
||||
with torch.no_grad():
|
||||
self.local_weights.copy_(self.weights[inputs])
|
||||
self.local_ids.copy_(inputs)
|
||||
|
||||
return self.local_weights.to(self.cast_to)
|
||||
|
||||
|
||||
class CastedSparseEmbeddingSignSGD_Distributed(Optimizer):
|
||||
def __init__(
|
||||
self,
|
||||
params: ParamsT,
|
||||
|
||||
world_size: int,
|
||||
lr: Union[float, torch.Tensor] = 1e-3,
|
||||
weight_decay: float = 1e-2,
|
||||
):
|
||||
if not 0.0 <= lr:
|
||||
raise ValueError(f"Invalid learning rate: {lr}")
|
||||
if not 0.0 <= weight_decay:
|
||||
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
|
||||
|
||||
defaults = dict(
|
||||
lr=lr,
|
||||
weight_decay=weight_decay,
|
||||
world_size=world_size
|
||||
)
|
||||
super().__init__(params, defaults)
|
||||
|
||||
@torch.no_grad
|
||||
def step(self, closure=None): # type: ignore
|
||||
for group in self.param_groups:
|
||||
# Find the sparse embedding weights
|
||||
local_weights_grad = None
|
||||
local_ids = None
|
||||
weights = None
|
||||
|
||||
assert len(group["params"]) == 3
|
||||
for p in group["params"]:
|
||||
if p.requires_grad:
|
||||
local_weights_grad = p.grad
|
||||
elif p.ndim == 1:
|
||||
local_ids = p
|
||||
elif p.ndim == 2:
|
||||
weights = p
|
||||
else:
|
||||
assert False
|
||||
|
||||
assert local_ids is not None
|
||||
assert weights is not None
|
||||
|
||||
# Apply SignSGD
|
||||
# Adam ≈ SignSGD if gradient is very sparse
|
||||
if local_weights_grad is not None:
|
||||
_sparse_emb_signsgd_dist(
|
||||
local_weights_grad,
|
||||
local_ids,
|
||||
weights,
|
||||
|
||||
lr=group["lr"],
|
||||
weight_decay=group["weight_decay"],
|
||||
world_size=group["world_size"]
|
||||
)
|
||||
|
||||
|
||||
def _sparse_emb_signsgd_dist(
|
||||
local_weights_grad: torch.Tensor,
|
||||
local_ids: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
|
||||
lr: float,
|
||||
weight_decay: float,
|
||||
world_size: int
|
||||
) -> None:
|
||||
N, D = local_weights_grad.shape
|
||||
|
||||
# All-gather
|
||||
all_weights_grad = local_weights_grad
|
||||
all_ids = local_ids
|
||||
|
||||
if world_size > 1:
|
||||
all_weights_grad = torch.empty((world_size * N, D), dtype=local_weights_grad.dtype, device=local_weights_grad.device)
|
||||
all_ids = torch.empty(world_size * N, dtype=local_ids.dtype, device=local_ids.device)
|
||||
|
||||
dist.all_gather_into_tensor(all_weights_grad, local_weights_grad)
|
||||
dist.all_gather_into_tensor(all_ids, local_ids)
|
||||
|
||||
# Unique
|
||||
grad_ids, inv = all_ids.unique(return_inverse=True)
|
||||
|
||||
grad = torch.zeros((grad_ids.shape[0], D), dtype=all_weights_grad.dtype, device=all_weights_grad.device)
|
||||
grad.scatter_add_(0, inv.unsqueeze(-1).expand(-1, D), all_weights_grad)
|
||||
|
||||
# SignSGD with decoupled weight decay
|
||||
p = weights[grad_ids]
|
||||
|
||||
p.mul_(1.0 - lr * weight_decay).add_(torch.sign(grad), alpha=-lr)
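# Decoupled weight decay followed by a sign-SGD update: p <- p * (1 - lr * weight_decay) - lr * sign(accumulated gradient).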
|
||||
|
||||
# Write updated slices back
|
||||
weights[grad_ids] = p
|
||||
654
pretrain.py
Normal file
@@ -0,0 +1,654 @@
|
||||
from typing import Optional, Any, Sequence, List
|
||||
from dataclasses import dataclass
|
||||
import os
|
||||
import math
|
||||
import yaml
|
||||
import shutil
|
||||
import copy
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
import tqdm
|
||||
import wandb
|
||||
import coolname
|
||||
import hydra
|
||||
import pydantic
|
||||
from omegaconf import DictConfig
|
||||
from adam_atan2 import AdamATan2
|
||||
|
||||
from puzzle_dataset import PuzzleDataset, PuzzleDatasetConfig, PuzzleDatasetMetadata
|
||||
from utils.functions import load_model_class, get_model_source_path
|
||||
from models.sparse_embedding import CastedSparseEmbeddingSignSGD_Distributed
|
||||
from models.ema import EMAHelper
|
||||
|
||||
|
||||
class LossConfig(pydantic.BaseModel):
|
||||
model_config = pydantic.ConfigDict(extra='allow')
|
||||
name: str
|
||||
|
||||
|
||||
class ArchConfig(pydantic.BaseModel):
|
||||
model_config = pydantic.ConfigDict(extra='allow')
|
||||
name: str
|
||||
loss: LossConfig
|
||||
|
||||
|
||||
class EvaluatorConfig(pydantic.BaseModel):
|
||||
model_config = pydantic.ConfigDict(extra="allow")
|
||||
name: str
|
||||
|
||||
|
||||
class PretrainConfig(pydantic.BaseModel):
|
||||
# Config
|
||||
arch: ArchConfig
|
||||
# Data
|
||||
data_paths: List[str]
|
||||
data_paths_test: List[str] = []
|
||||
# Evaluators
|
||||
evaluators: List[EvaluatorConfig] = []
|
||||
|
||||
# Hyperparams
|
||||
global_batch_size: int
|
||||
epochs: int
|
||||
|
||||
lr: float
|
||||
lr_min_ratio: float
|
||||
lr_warmup_steps: int
|
||||
|
||||
weight_decay: float
|
||||
beta1: float
|
||||
beta2: float
|
||||
|
||||
# Puzzle embedding
|
||||
puzzle_emb_lr: float
|
||||
puzzle_emb_weight_decay: float
|
||||
|
||||
# Names
|
||||
project_name: Optional[str] = None
|
||||
run_name: Optional[str] = None
|
||||
load_checkpoint: Optional[str] = None
|
||||
checkpoint_path: Optional[str] = None
|
||||
|
||||
# Extras
|
||||
seed: int = 0
|
||||
checkpoint_every_eval: bool = False
|
||||
eval_interval: Optional[int] = None
|
||||
min_eval_interval: Optional[int] = 0 # when to start eval
|
||||
eval_save_outputs: List[str] = []
|
||||
|
||||
ema: bool = False # use Exponential-Moving-Average
|
||||
ema_rate: float = 0.999 # EMA-rate
|
||||
freeze_weights: bool = False # If True, freeze weights and only learn the embeddings
|
||||
|
||||
@dataclass
|
||||
class TrainState:
|
||||
model: nn.Module
|
||||
optimizers: Sequence[torch.optim.Optimizer]
|
||||
optimizer_lrs: Sequence[float]
|
||||
carry: Any
|
||||
|
||||
step: int
|
||||
total_steps: int
|
||||
|
||||
|
||||
def create_dataloader(config: PretrainConfig, split: str, rank: int, world_size: int, **kwargs):
|
||||
dataset = PuzzleDataset(PuzzleDatasetConfig(
|
||||
seed=config.seed,
|
||||
dataset_paths=config.data_paths_test if len(config.data_paths_test)>0 and split=="test" else config.data_paths,
|
||||
rank=rank,
|
||||
num_replicas=world_size,
|
||||
**kwargs
|
||||
), split=split)
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=None,
|
||||
num_workers=1,
|
||||
prefetch_factor=8,
|
||||
pin_memory=True,
|
||||
persistent_workers=True
|
||||
)
|
||||
return dataloader, dataset.metadata
|
||||
|
||||
|
||||
def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, rank: int, world_size: int):
|
||||
model_cfg = dict(
|
||||
**config.arch.__pydantic_extra__, # type: ignore
|
||||
batch_size=config.global_batch_size // world_size,
|
||||
vocab_size=train_metadata.vocab_size,
|
||||
seq_len=train_metadata.seq_len,
|
||||
num_puzzle_identifiers=train_metadata.num_puzzle_identifiers,
|
||||
causal=False # Non-autoregressive
|
||||
)
|
||||
|
||||
# Instantiate model with loss head
|
||||
model_cls = load_model_class(config.arch.name)
|
||||
loss_head_cls = load_model_class(config.arch.loss.name)
|
||||
|
||||
with torch.device("cuda"):
|
||||
model: nn.Module = model_cls(model_cfg)
|
||||
print(model)
|
||||
model = loss_head_cls(model, **config.arch.loss.__pydantic_extra__) # type: ignore
|
||||
if "DISABLE_COMPILE" not in os.environ:
|
||||
model = torch.compile(model) # type: ignore
|
||||
|
||||
# Load checkpoint
|
||||
if rank == 0:
|
||||
load_checkpoint(model, config)
|
||||
|
||||
# Broadcast parameters from rank 0
|
||||
if world_size > 1:
|
||||
with torch.no_grad():
|
||||
for param in list(model.parameters()) + list(model.buffers()):
|
||||
dist.broadcast(param, src=0)
|
||||
|
||||
# Optimizers and lr
|
||||
if config.arch.puzzle_emb_ndim == 0:
|
||||
optimizers = [
|
||||
AdamATan2(
|
||||
model.parameters(),
|
||||
lr=0, # Needs to be set by scheduler
|
||||
weight_decay=config.weight_decay,
|
||||
betas=(config.beta1, config.beta2)
|
||||
)
|
||||
]
|
||||
optimizer_lrs = [
|
||||
config.lr
|
||||
]
|
||||
elif config.freeze_weights:
|
||||
optimizers = [
|
||||
CastedSparseEmbeddingSignSGD_Distributed(
|
||||
model.model.puzzle_emb.buffers(), # type: ignore
|
||||
lr=0, # Needs to be set by scheduler
|
||||
weight_decay=config.puzzle_emb_weight_decay,
|
||||
world_size=world_size
|
||||
)
|
||||
]
|
||||
optimizer_lrs = [
|
||||
config.puzzle_emb_lr
|
||||
]
|
||||
else:
|
||||
optimizers = [
|
||||
CastedSparseEmbeddingSignSGD_Distributed(
|
||||
model.model.puzzle_emb.buffers(), # type: ignore
|
||||
lr=0, # Needs to be set by scheduler
|
||||
weight_decay=config.puzzle_emb_weight_decay,
|
||||
world_size=world_size
|
||||
),
|
||||
AdamATan2(
|
||||
model.parameters(),
|
||||
lr=0, # Needs to be set by scheduler
|
||||
weight_decay=config.weight_decay,
|
||||
betas=(config.beta1, config.beta2)
|
||||
)
|
||||
]
|
||||
optimizer_lrs = [
|
||||
config.puzzle_emb_lr,
|
||||
config.lr
|
||||
]
|
||||
|
||||
return model, optimizers, optimizer_lrs
|
||||
|
||||
def mix_weights_direct(device, alpha, net, nets):
|
||||
sd = []
|
||||
for i in range(len(nets)):
|
||||
sd += [nets[i].state_dict()]
|
||||
sd_alpha = {}
|
||||
for k in sd[0].keys():
|
||||
comb_net = alpha[0]*sd[0][k].to(device)
|
||||
for i in range(1,len(nets)):
|
||||
comb_net += alpha[i]*sd[i][k].to(device)
|
||||
sd_alpha[k] = comb_net
|
||||
net.load_state_dict(sd_alpha)
|
||||
return net
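# Loads a weighted combination sum_i alpha[i] * state_dict_i of the given networks into net, i.e. simple parameter averaging when the alphas sum to 1.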
|
||||
|
||||
def cosine_schedule_with_warmup_lr_lambda(
|
||||
current_step: int, *, base_lr: float, num_warmup_steps: int, num_training_steps: int, min_ratio: float = 0.0, num_cycles: float = 0.5
|
||||
):
|
||||
if current_step < num_warmup_steps:
|
||||
return base_lr * float(current_step) / float(max(1, num_warmup_steps))
|
||||
|
||||
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
|
||||
return base_lr * (min_ratio + max(0.0, (1 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))))
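# Worked example (assumed values: base_lr=1e-4, num_warmup_steps=100, num_training_steps=1000, min_ratio=0.1):
#   step 50   -> 1e-4 * 50/100                  = 5e-5   (linear warmup)
#   step 550  -> progress=0.5, cos(pi*0.5)=0    -> 1e-4 * (0.1 + 0.9*0.5) = 5.5e-5
#   step 1000 -> progress=1.0, cos(pi)=-1       -> 1e-4 * 0.1             = 1e-5 (floor at min_ratio)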
|
||||
|
||||
|
||||
def init_train_state(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, rank: int, world_size: int):
|
||||
# Estimated total training steps
|
||||
total_steps = int(config.epochs * train_metadata.total_groups * train_metadata.mean_puzzle_examples / config.global_batch_size)
|
||||
|
||||
# Model
|
||||
model, optimizers, optimizer_lrs = create_model(config, train_metadata, rank=rank, world_size=world_size)
|
||||
|
||||
return TrainState(
|
||||
step=0,
|
||||
total_steps=total_steps,
|
||||
|
||||
model=model,
|
||||
optimizers=optimizers,
|
||||
optimizer_lrs=optimizer_lrs,
|
||||
carry=None
|
||||
)
|
||||
|
||||
|
||||
def save_train_state(config: PretrainConfig, train_state: TrainState):
|
||||
# FIXME: Only the model state is saved (optimizer state is not).
|
||||
if config.checkpoint_path is None:
|
||||
return
|
||||
|
||||
os.makedirs(config.checkpoint_path, exist_ok=True)
|
||||
torch.save(train_state.model.state_dict(), os.path.join(config.checkpoint_path, f"step_{train_state.step}"))
|
||||
|
||||
|
||||
def load_checkpoint(model: nn.Module, config: PretrainConfig):
|
||||
if config.load_checkpoint is not None:
|
||||
print(f"Loading checkpoint {config.load_checkpoint}")
|
||||
|
||||
# Load state dict
|
||||
state_dict = torch.load(config.load_checkpoint, map_location="cuda")
|
||||
|
||||
# Resize and reset puzzle emb if needed
|
||||
puzzle_emb_name = "_orig_mod.model.inner.puzzle_emb.weights"
|
||||
expected_shape: torch.Size = model.model.puzzle_emb.weights.shape # type: ignore
|
||||
if puzzle_emb_name in state_dict:
|
||||
puzzle_emb = state_dict[puzzle_emb_name]
|
||||
if puzzle_emb.shape != expected_shape:
|
||||
print(f"Resetting puzzle embedding as shape is different. Found {puzzle_emb.shape}, Expected {expected_shape}")
|
||||
# Re-initialize using mean
|
||||
state_dict[puzzle_emb_name] = (
|
||||
torch.mean(puzzle_emb, dim=0, keepdim=True).expand(expected_shape).contiguous()
|
||||
)
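# If the checkpoint was trained with a different number of puzzle identifiers, every row of the new puzzle-embedding table is re-initialized to the mean embedding from the checkpoint.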
|
||||
model.load_state_dict(state_dict, assign=True)
|
||||
|
||||
|
||||
def compute_lr(base_lr: float, config: PretrainConfig, train_state: TrainState):
|
||||
return cosine_schedule_with_warmup_lr_lambda(
|
||||
current_step=train_state.step,
|
||||
base_lr=base_lr,
|
||||
num_warmup_steps=round(config.lr_warmup_steps),
|
||||
num_training_steps=train_state.total_steps,
|
||||
min_ratio=config.lr_min_ratio
|
||||
)
|
||||
|
||||
|
||||
|
||||
def create_evaluators(config: PretrainConfig, eval_metadata: PuzzleDatasetMetadata) -> List[Any]:
|
||||
data_paths = config.data_paths_test if len(config.data_paths_test) > 0 else config.data_paths
|
||||
# Initialize evaluators
|
||||
evaluators = []
|
||||
for cfg in config.evaluators:
|
||||
for data_path in data_paths:
|
||||
cls = load_model_class(cfg.name, "evaluators.")(
|
||||
data_path=data_path, eval_metadata=eval_metadata, **cfg.__pydantic_extra__
|
||||
) # type: ignore
|
||||
evaluators.append(cls)
|
||||
|
||||
return evaluators
|
||||
|
||||
def train_batch(config: PretrainConfig, train_state: TrainState, batch: Any, global_batch_size: int, rank: int, world_size: int):
|
||||
train_state.step += 1
|
||||
if train_state.step > train_state.total_steps: # At most train_total_steps
|
||||
return
|
||||
|
||||
# To device
|
||||
batch = {k: v.cuda() for k, v in batch.items()}
|
||||
|
||||
# Init carry if it is None
|
||||
if train_state.carry is None:
|
||||
with torch.device("cuda"):
|
||||
train_state.carry = train_state.model.initial_carry(batch) # type: ignore
|
||||
|
||||
# Forward
|
||||
train_state.carry, loss, metrics, _, _ = train_state.model(carry=train_state.carry, batch=batch, return_keys=[])
|
||||
|
||||
((1 / global_batch_size) * loss).backward()
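# The loss is scaled by 1/global_batch_size before backward; together with the all_reduce (sum) of gradients across ranks below, this averages gradients over the full global batch (assuming the loss head returns a per-example sum).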
|
||||
|
||||
# Allreduce
|
||||
if world_size > 1:
|
||||
for param in train_state.model.parameters():
|
||||
if param.grad is not None:
|
||||
dist.all_reduce(param.grad)
|
||||
|
||||
# Apply optimizer
|
||||
lr_this_step = None
|
||||
for optim, base_lr in zip(train_state.optimizers, train_state.optimizer_lrs):
|
||||
lr_this_step = compute_lr(base_lr, config, train_state)
|
||||
|
||||
for param_group in optim.param_groups:
|
||||
param_group['lr'] = lr_this_step
|
||||
|
||||
optim.step()
|
||||
optim.zero_grad()
|
||||
|
||||
# Reduce metrics
|
||||
if len(metrics):
|
||||
assert not any(v.requires_grad for v in metrics.values())
|
||||
|
||||
metric_keys = list(sorted(metrics.keys())) # Sort keys to guarantee all processes use the same order.
|
||||
# Reduce and reconstruct
|
||||
metric_values = torch.stack([metrics[k] for k in metric_keys])
|
||||
if world_size > 1:
|
||||
dist.reduce(metric_values, dst=0)
|
||||
|
||||
if rank == 0:
|
||||
metric_values = metric_values.cpu().numpy()
|
||||
reduced_metrics = {k: metric_values[i] for i, k in enumerate(metric_keys)}
|
||||
|
||||
# Postprocess
|
||||
count = max(reduced_metrics["count"], 1) # Avoid NaNs
|
||||
reduced_metrics = {f"train/{k}": v / (global_batch_size if k.endswith("loss") else count) for k, v in reduced_metrics.items()}
|
||||
|
||||
reduced_metrics["train/lr"] = lr_this_step
|
||||
return reduced_metrics
|
||||
|
||||
def evaluate(
|
||||
config: PretrainConfig,
|
||||
train_state: TrainState,
|
||||
eval_loader: torch.utils.data.DataLoader,
|
||||
eval_metadata: PuzzleDatasetMetadata,
|
||||
evaluators: List[Any],
|
||||
rank: int,
|
||||
world_size: int,
|
||||
cpu_group: Optional[dist.ProcessGroup],
|
||||
):
|
||||
reduced_metrics = None
|
||||
|
||||
with torch.inference_mode():
|
||||
return_keys = set(config.eval_save_outputs)
|
||||
for evaluator in evaluators:
|
||||
evaluator.begin_eval()
|
||||
return_keys.update(evaluator.required_outputs)
|
||||
|
||||
# Run evaluation
|
||||
set_ids = {k: idx for idx, k in enumerate(eval_metadata.sets)}
|
||||
|
||||
save_preds = {}
|
||||
|
||||
metric_keys = []
|
||||
metric_values = None
|
||||
|
||||
carry = None
|
||||
processed_batches = 0
|
||||
|
||||
for set_name, batch, global_batch_size in eval_loader:
|
||||
processed_batches += 1
|
||||
if rank == 0:
|
||||
print(f"Processing batch {processed_batches}: {set_name}")
|
||||
|
||||
# To device
|
||||
batch = {k: v.cuda() for k, v in batch.items()}
|
||||
with torch.device("cuda"):
|
||||
carry = train_state.model.initial_carry(batch) # type: ignore
|
||||
|
||||
# Forward
|
||||
inference_steps = 0
|
||||
while True:
|
||||
carry, loss, metrics, preds, all_finish = train_state.model(
|
||||
carry=carry, batch=batch, return_keys=return_keys
|
||||
)
|
||||
inference_steps += 1
|
||||
|
||||
if all_finish:
|
||||
break
|
||||
|
||||
if rank == 0:
|
||||
print(f" Completed inference in {inference_steps} steps")
|
||||
|
||||
for collection in (batch, preds):
|
||||
for k, v in collection.items():
|
||||
if k in config.eval_save_outputs:
|
||||
save_preds.setdefault(k, [])
|
||||
save_preds[k].append(v.cpu()) # Move to CPU for saving GPU memory
|
||||
|
||||
for evaluator in evaluators:
|
||||
evaluator.update_batch(batch, preds)
|
||||
|
||||
del carry, loss, preds, batch, all_finish
|
||||
|
||||
# Aggregate metrics
|
||||
set_id = set_ids[set_name]
|
||||
|
||||
if metric_values is None:
|
||||
metric_keys = list(
|
||||
sorted(metrics.keys())
|
||||
) # Sort keys to guarantee all processes use the same order.
|
||||
metric_values = torch.zeros(
|
||||
(len(set_ids), len(metrics.values())), dtype=torch.float32, device="cuda"
|
||||
)
|
||||
|
||||
metric_values[set_id] += torch.stack([metrics[k] for k in metric_keys])
|
||||
|
||||
del metrics
|
||||
|
||||
# concatenate save preds
|
||||
save_preds = {k: torch.cat(v, dim=0) for k, v in save_preds.items()}
|
||||
|
||||
# Save preds
|
||||
if config.checkpoint_path is not None and len(save_preds):
|
||||
# Each rank save predictions independently
|
||||
os.makedirs(os.path.dirname(config.checkpoint_path), exist_ok=True)
|
||||
torch.save(
|
||||
save_preds, os.path.join(config.checkpoint_path, f"step_{train_state.step}_all_preds.{rank}")
|
||||
)
|
||||
|
||||
del save_preds
|
||||
|
||||
# Reduce to rank 0
|
||||
if metric_values is not None:
|
||||
if world_size > 1:
|
||||
dist.reduce(metric_values, dst=0)
|
||||
|
||||
if rank == 0:
|
||||
reduced_metrics = metric_values.cpu().numpy()
|
||||
reduced_metrics = {
|
||||
set_name: {
|
||||
metric_name: reduced_metrics[set_id, metric_id]
|
||||
for metric_id, metric_name in enumerate(metric_keys)
|
||||
}
|
||||
for set_id, set_name in enumerate(set_ids)
|
||||
}
|
||||
|
||||
# Postprocess
|
||||
for set_name, m in reduced_metrics.items():
|
||||
count = m.pop("count")
|
||||
reduced_metrics[set_name] = {k: v / count for k, v in m.items()}
|
||||
|
||||
# Run evaluators
|
||||
if rank == 0:
|
||||
print(f"\nRunning {len(evaluators)} evaluator(s)...")
|
||||
|
||||
for i, evaluator in enumerate(evaluators):
|
||||
if rank == 0:
|
||||
print(f"Running evaluator {i+1}/{len(evaluators)}: {evaluator.__class__.__name__}")
|
||||
|
||||
# Path for saving
|
||||
evaluator_save_path = None
|
||||
if config.checkpoint_path is not None:
|
||||
evaluator_save_path = os.path.join(
|
||||
config.checkpoint_path,
|
||||
f"evaluator_{evaluator.__class__.__name__}_step_{train_state.step}",
|
||||
)
|
||||
os.makedirs(evaluator_save_path, exist_ok=True)
|
||||
|
||||
# Run and log
|
||||
metrics = evaluator.result(evaluator_save_path, rank=rank, world_size=world_size, group=cpu_group)
|
||||
if rank == 0 and metrics is not None:
|
||||
if reduced_metrics is None:
|
||||
reduced_metrics = {}
|
||||
|
||||
reduced_metrics.update(metrics)
|
||||
print(f" Completed {evaluator.__class__.__name__}")
|
||||
|
||||
if rank == 0:
|
||||
print("All evaluators completed!")
|
||||
|
||||
return reduced_metrics
|
||||
|
||||
def save_code_and_config(config: PretrainConfig):
|
||||
if config.checkpoint_path is None or wandb.run is None:
|
||||
return
|
||||
|
||||
os.makedirs(config.checkpoint_path, exist_ok=True)
|
||||
|
||||
# Copy code
|
||||
code_list = [
|
||||
get_model_source_path(config.arch.name),
|
||||
get_model_source_path(config.arch.loss.name)
|
||||
]
|
||||
for code_file in code_list:
|
||||
if code_file is not None:
|
||||
code_name = os.path.basename(code_file)
|
||||
|
||||
shutil.copy(code_file, os.path.join(config.checkpoint_path, code_name))
|
||||
|
||||
# Dump config as yaml
|
||||
config_file = os.path.join(config.checkpoint_path, "all_config.yaml")
|
||||
with open(config_file, "wt") as f:
|
||||
yaml.dump(config.model_dump(), f)
|
||||
|
||||
# Log code
|
||||
wandb.run.log_code(config.checkpoint_path)
|
||||
|
||||
|
||||
def load_synced_config(hydra_config: DictConfig, rank: int, world_size: int) -> PretrainConfig:
|
||||
objects = [None]
|
||||
if rank == 0:
|
||||
config = PretrainConfig(**hydra_config) # type: ignore
|
||||
|
||||
# Naming
|
||||
if config.project_name is None:
|
||||
config.project_name = f"{os.path.basename(config.data_paths[0]).capitalize()}-ACT-torch"
|
||||
if config.run_name is None:
|
||||
config.run_name = f"{config.arch.name.split('@')[-1]} {coolname.generate_slug(2)}"
|
||||
if config.checkpoint_path is None:
|
||||
config.checkpoint_path = os.path.join("checkpoints", config.project_name, config.run_name)
|
||||
|
||||
objects = [config]
|
||||
|
||||
if world_size > 1:
|
||||
dist.broadcast_object_list(objects, src=0)
|
||||
|
||||
return objects[0] # type: ignore
|
||||
|
||||
|
||||
@hydra.main(config_path="config", config_name="cfg_pretrain", version_base=None)
|
||||
def launch(hydra_config: DictConfig):
|
||||
RANK = 0
|
||||
WORLD_SIZE = 1
|
||||
CPU_PROCESS_GROUP = None
|
||||
|
||||
# Initialize distributed training if in distributed environment (e.g. torchrun)
|
||||
if "LOCAL_RANK" in os.environ:
|
||||
# Initialize distributed, default device and dtype
|
||||
dist.init_process_group(backend="nccl")
|
||||
|
||||
RANK = dist.get_rank()
|
||||
WORLD_SIZE = dist.get_world_size()
|
||||
|
||||
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
|
||||
|
||||
# CPU GLOO process group
|
||||
CPU_PROCESS_GROUP = dist.new_group(backend="gloo")
|
||||
assert (
|
||||
dist.get_rank(CPU_PROCESS_GROUP) == RANK and dist.get_world_size(CPU_PROCESS_GROUP) == WORLD_SIZE
|
||||
)
|
||||
|
||||
# Load sync'ed config
|
||||
config = load_synced_config(hydra_config, rank=RANK, world_size=WORLD_SIZE)
|
||||
|
||||
# Seed RNGs to ensure consistency
|
||||
torch.random.manual_seed(config.seed + RANK)
|
||||
|
||||
# Dataset
|
||||
train_epochs_per_iter = config.eval_interval if config.eval_interval is not None else config.epochs
|
||||
total_iters = config.epochs // train_epochs_per_iter
|
||||
|
||||
assert config.epochs % train_epochs_per_iter == 0, "Eval interval must be a divisor of total epochs."
|
||||
|
||||
train_loader, train_metadata = create_dataloader(config, "train", test_set_mode=False, epochs_per_iter=train_epochs_per_iter, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)
|
||||
try:
|
||||
eval_loader, eval_metadata = create_dataloader(config, "test", test_set_mode=True, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)
|
||||
except Exception:
|
||||
print("NO EVAL DATA FOUND")
|
||||
eval_loader = eval_metadata = None
|
||||
|
||||
try:
|
||||
evaluators = create_evaluators(config, eval_metadata)
|
||||
except Exception:
|
||||
print("No evaluator found")
|
||||
evaluators = []
|
||||
|
||||
# Train state
|
||||
train_state = init_train_state(config, train_metadata, rank=RANK, world_size=WORLD_SIZE)
|
||||
|
||||
# Progress bar and logger
|
||||
progress_bar = None
|
||||
ema_helper = None
|
||||
if RANK == 0:
|
||||
progress_bar = tqdm.tqdm(total=train_state.total_steps)
|
||||
wandb.init(project=config.project_name, name=config.run_name, config=config.model_dump(), settings=wandb.Settings(_disable_stats=True)) # type: ignore
|
||||
wandb.log({"num_params": sum(x.numel() for x in train_state.model.parameters())}, step=0)
|
||||
save_code_and_config(config)
|
||||
if config.ema:
|
||||
print('Setup EMA')
|
||||
ema_helper = EMAHelper(mu=config.ema_rate)
|
||||
ema_helper.register(train_state.model)
|
||||
|
||||
# Training Loop
|
||||
for _iter_id in range(total_iters):
|
||||
print (f"[Rank {RANK}, World Size {WORLD_SIZE}]: Epoch {_iter_id * train_epochs_per_iter}")
|
||||
|
||||
############ Train Iter
|
||||
if RANK == 0:
|
||||
print("TRAIN")
|
||||
train_state.model.train()
|
||||
for set_name, batch, global_batch_size in train_loader:
|
||||
metrics = train_batch(config, train_state, batch, global_batch_size, rank=RANK, world_size=WORLD_SIZE)
|
||||
|
||||
if RANK == 0 and metrics is not None:
|
||||
wandb.log(metrics, step=train_state.step)
|
||||
progress_bar.update(train_state.step - progress_bar.n) # type: ignore
|
||||
if config.ema:
|
||||
ema_helper.update(train_state.model)
|
||||
|
||||
if _iter_id >= config.min_eval_interval:
|
||||
############ Evaluation
|
||||
if RANK == 0:
|
||||
print("EVALUATE")
|
||||
if config.ema:
|
||||
print("SWITCH TO EMA")
|
||||
train_state_eval = copy.deepcopy(train_state)
|
||||
train_state_eval.model = ema_helper.ema_copy(train_state_eval.model)
|
||||
else:
|
||||
train_state_eval = train_state
|
||||
train_state_eval.model.eval()
|
||||
metrics = evaluate(config,
|
||||
train_state_eval,
|
||||
eval_loader,
|
||||
eval_metadata,
|
||||
evaluators,
|
||||
rank=RANK,
|
||||
world_size=WORLD_SIZE,
|
||||
cpu_group=CPU_PROCESS_GROUP)
|
||||
|
||||
if RANK == 0 and metrics is not None:
|
||||
wandb.log(metrics, step=train_state.step)
|
||||
|
||||
############ Checkpointing
|
||||
if RANK == 0:
|
||||
print("SAVE CHECKPOINT")
|
||||
if RANK == 0 and (config.checkpoint_every_eval or (_iter_id == total_iters - 1)):
|
||||
save_train_state(config, train_state_eval)
|
||||
|
||||
if config.ema:
|
||||
del train_state_eval
|
||||
|
||||
# finalize
|
||||
if dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
wandb.finish()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
launch()
|
||||
250
puzzle_dataset.py
Normal file
@@ -0,0 +1,250 @@
|
||||
import os
|
||||
import json
|
||||
from typing import Tuple, List, Dict, Optional
|
||||
import numpy as np
|
||||
import pydantic
|
||||
|
||||
import torch
|
||||
from torch.utils.data import IterableDataset, get_worker_info
|
||||
|
||||
from models.losses import IGNORE_LABEL_ID
|
||||
from dataset.common import PuzzleDatasetMetadata
|
||||
|
||||
from argdantic import ArgParser
|
||||
from pydantic import BaseModel
|
||||
|
||||
def _sample_batch(rng: np.random.Generator, group_order: np.ndarray, puzzle_indices: np.ndarray, group_indices: np.ndarray, start_index: int, global_batch_size: int):
|
||||
# Pack examples into a full batch
|
||||
batch = []
|
||||
batch_puzzle_indices = []
|
||||
current_size = 0
|
||||
|
||||
while (start_index < group_order.size) and (current_size < global_batch_size):
|
||||
# Pick a group and a puzzle from that group
|
||||
group_id = group_order[start_index]
|
||||
puzzle_id = rng.integers(group_indices[group_id], group_indices[group_id + 1])
|
||||
start_index += 1
|
||||
|
||||
# Get range of the puzzle
|
||||
puzzle_start = puzzle_indices[puzzle_id]
|
||||
puzzle_size = int(puzzle_indices[puzzle_id + 1] - puzzle_start)
|
||||
|
||||
append_size = min(puzzle_size, global_batch_size - current_size)
|
||||
|
||||
# Put into batch
|
||||
batch_puzzle_indices.append(np.full(append_size, puzzle_id, dtype=np.int32))
|
||||
batch.append(puzzle_start + np.random.choice(puzzle_size, append_size, replace=False))
|
||||
|
||||
current_size += append_size
|
||||
|
||||
return start_index, np.concatenate(batch), np.concatenate(batch_puzzle_indices)
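# The loop above walks the shuffled group order, samples one puzzle per group, and fills the batch with up to global_batch_size of its examples without replacement.
# Note that the example selection uses np.random.choice (the global NumPy RNG) rather than the rng argument, so it is not covered by the dataset seed.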
|
||||
|
||||
|
||||
class PuzzleDatasetConfig(pydantic.BaseModel):
|
||||
seed: int
|
||||
dataset_paths: List[str]
|
||||
global_batch_size: int
|
||||
test_set_mode: bool
|
||||
epochs_per_iter: int # Batch X epochs in an iteration to reduce overhead.
|
||||
rank: int
|
||||
num_replicas: int
|
||||
|
||||
class PuzzleDataset(IterableDataset):
|
||||
def __init__(self, config: PuzzleDatasetConfig, split: str = "train"):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.split = split
|
||||
|
||||
# Merge multiple metadata
|
||||
prev_seq_len = None
|
||||
prev_vocab_size = None
|
||||
prev_pad_id = None
|
||||
prev_ignore_label_id = None
|
||||
prev_blank_identifier_id = None
|
||||
prev_sets = None
|
||||
prev_num_identifiers = None
|
||||
mean_puzzle_examples = 0
|
||||
total_puzzles = 0
|
||||
total_groups = 0
|
||||
num_identifiers = 0
|
||||
for dataset_path in config.dataset_paths:
|
||||
current_metadata = self._load_metadata(dataset_path)
|
||||
if prev_seq_len is None:
|
||||
prev_seq_len = current_metadata.seq_len
|
||||
prev_vocab_size = current_metadata.vocab_size
|
||||
prev_pad_id = current_metadata.pad_id
|
||||
prev_ignore_label_id = current_metadata.ignore_label_id
|
||||
prev_blank_identifier_id = current_metadata.blank_identifier_id
|
||||
prev_sets = current_metadata.sets
|
||||
prev_num_identifiers = current_metadata.num_puzzle_identifiers
|
||||
else:
|
||||
assert prev_seq_len == current_metadata.seq_len
|
||||
assert prev_vocab_size == current_metadata.vocab_size
|
||||
assert prev_pad_id == current_metadata.pad_id
|
||||
assert prev_ignore_label_id == current_metadata.ignore_label_id
|
||||
assert prev_blank_identifier_id == current_metadata.blank_identifier_id
|
||||
assert prev_sets == current_metadata.sets
|
||||
assert prev_num_identifiers == current_metadata.num_puzzle_identifiers
|
||||
mean_puzzle_examples += current_metadata.mean_puzzle_examples*current_metadata.total_puzzles
|
||||
total_puzzles += current_metadata.total_puzzles
|
||||
total_groups += current_metadata.total_groups
|
||||
num_identifiers += current_metadata.num_puzzle_identifiers
|
||||
mean_puzzle_examples = mean_puzzle_examples / total_puzzles
|
||||
|
||||
self.metadata = PuzzleDatasetMetadata(
|
||||
seq_len=prev_seq_len,
|
||||
vocab_size=prev_vocab_size,
|
||||
pad_id=prev_pad_id,
|
||||
ignore_label_id=prev_ignore_label_id,
|
||||
blank_identifier_id=prev_blank_identifier_id,
|
||||
num_puzzle_identifiers=num_identifiers,
|
||||
total_groups=total_groups,
|
||||
mean_puzzle_examples=mean_puzzle_examples,
|
||||
total_puzzles=total_puzzles,
|
||||
sets=prev_sets
|
||||
)
|
||||
|
||||
# Checks
|
||||
assert self.config.global_batch_size % self.config.num_replicas == 0, f"Global batch size {self.config.global_batch_size} must be multiples of nodes {self.config.num_replicas}."
|
||||
self.local_batch_size = self.config.global_batch_size // self.config.num_replicas
|
||||
|
||||
# State
|
||||
self._data = None
|
||||
self._iters = 0
|
||||
|
||||
def _load_metadata(self, dataset_path) -> PuzzleDatasetMetadata:
|
||||
with open(os.path.join(dataset_path, self.split, "dataset.json"), "r") as f:
|
||||
return PuzzleDatasetMetadata(**json.load(f))
|
||||
|
||||
def _lazy_load_dataset(self):
|
||||
if self._data is not None:
|
||||
return
|
||||
|
||||
field_mmap_modes = {
|
||||
"inputs": "r",
|
||||
"labels": "r",
|
||||
|
||||
# Keep indices in memory
|
||||
"puzzle_identifiers": None,
|
||||
"puzzle_indices": None,
|
||||
"group_indices": None
|
||||
}
|
||||
|
||||
# Load data
|
||||
self._data = {}
|
||||
for set_name in self.metadata.sets: # Load subset
|
||||
for i, dataset_path in enumerate(self.config.dataset_paths):
|
||||
if i > 0:
|
||||
set_name_ = set_name + str(i)
|
||||
else:
|
||||
set_name_ = set_name
|
||||
self._data[set_name_] = {
|
||||
field_name: np.load(os.path.join(dataset_path, self.split, f"{set_name}__{field_name}.npy"), mmap_mode=mmap_mode)
|
||||
for field_name, mmap_mode in field_mmap_modes.items()
|
||||
}
|
||||
|
||||
|
||||
def _collate_batch(self, batch):
|
||||
# Convert dtype
|
||||
batch = {k: v.astype(np.int32) for k, v in batch.items()}
|
||||
|
||||
# Convert ignore label IDs
|
||||
if self.metadata.ignore_label_id is not None:
|
||||
batch["labels"][batch["labels"] == self.metadata.ignore_label_id] = IGNORE_LABEL_ID
|
||||
|
||||
# Pad
|
||||
if batch["puzzle_identifiers"].size < self.local_batch_size:
|
||||
pad_size = self.local_batch_size - batch["puzzle_identifiers"].size
|
||||
pad_values = {
|
||||
"inputs": self.metadata.pad_id,
|
||||
"labels": IGNORE_LABEL_ID,
|
||||
"puzzle_identifiers": self.metadata.blank_identifier_id
|
||||
}
|
||||
batch = {k: np.pad(v, ((0, pad_size), ) + ((0, 0), ) * (v.ndim - 1), constant_values=pad_values[k]) for k, v in batch.items()}
|
||||
|
||||
# To tensor
|
||||
return {k: torch.from_numpy(v) for k, v in batch.items()}
|
||||
|
||||
def _iter_test(self):
|
||||
for set_i, (set_name, dataset) in enumerate(self._data.items()): # type: ignore
|
||||
total_examples = len(dataset["inputs"])
|
||||
|
||||
# Load examples one by one
|
||||
start_index = 0
|
||||
while start_index < total_examples:
|
||||
# Compute indices
|
||||
end_index = min(total_examples, start_index + self.config.global_batch_size)
|
||||
|
||||
local_start = start_index + self.config.rank * self.local_batch_size
|
||||
local_end = min(start_index + (self.config.rank + 1) * self.local_batch_size, end_index)
|
||||
|
||||
# Get batch of examples, and also puzzle IDs
|
||||
puzzle_indices = []
|
||||
puzzle_index = np.searchsorted(dataset["puzzle_indices"], local_start, side="right") - 1
|
||||
for i in range(local_start, local_end):
|
||||
while puzzle_index + 1 < len(dataset["puzzle_indices"]) and i >= dataset["puzzle_indices"][puzzle_index + 1]:
|
||||
puzzle_index += 1
|
||||
|
||||
puzzle_indices.append(puzzle_index)
|
||||
|
||||
batch = self._collate_batch({
|
||||
"inputs": dataset["inputs"][local_start: local_end],
|
||||
"labels": dataset["labels"][local_start: local_end],
|
||||
"puzzle_identifiers": dataset["puzzle_identifiers"][puzzle_indices]
|
||||
})
|
||||
|
||||
yield set_name, batch, end_index - start_index
|
||||
|
||||
# Advance to next batch
|
||||
start_index += self.config.global_batch_size
|
||||
|
||||
def _iter_train(self):
|
||||
for set_name, dataset in self._data.items(): # type: ignore
|
||||
# Increase epoch count
|
||||
self._iters += 1
|
||||
|
||||
# Randomly shuffle groups
|
||||
rng = np.random.Generator(np.random.Philox(seed=self.config.seed + self._iters))
|
||||
|
||||
group_order = np.concatenate([rng.permutation(dataset["group_indices"].size - 1) for _i in range(self.config.epochs_per_iter)])
|
||||
start_index = 0
|
||||
|
||||
while start_index < group_order.size:
|
||||
start_index, batch_indices, batch_puzzle_indices = _sample_batch(
|
||||
rng,
|
||||
group_order=group_order,
|
||||
puzzle_indices=dataset["puzzle_indices"],
|
||||
group_indices=dataset["group_indices"],
|
||||
start_index=start_index,
|
||||
global_batch_size=self.config.global_batch_size,
|
||||
)
|
||||
|
||||
# Select current rank and collate
|
||||
global_effective_batch_size = batch_puzzle_indices.size # Global effective batch size, excluding pads
|
||||
|
||||
# Drop last batch
|
||||
if global_effective_batch_size < self.config.global_batch_size:
|
||||
break
|
||||
|
||||
batch_indices = batch_indices [self.config.rank * self.local_batch_size: (self.config.rank + 1) * self.local_batch_size]
|
||||
batch_puzzle_indices = batch_puzzle_indices[self.config.rank * self.local_batch_size: (self.config.rank + 1) * self.local_batch_size]
|
||||
batch = self._collate_batch({
|
||||
"inputs": dataset["inputs"][batch_indices],
|
||||
"labels": dataset["labels"][batch_indices],
|
||||
"puzzle_identifiers": dataset["puzzle_identifiers"][batch_puzzle_indices]
|
||||
})
|
||||
|
||||
yield set_name, batch, global_effective_batch_size
|
||||
|
||||
def __iter__(self):
|
||||
worker_info = get_worker_info()
|
||||
assert worker_info is None or worker_info.num_workers == 1, "Multi-worker data loading is not currently supported."
|
||||
|
||||
self._lazy_load_dataset()
|
||||
|
||||
# Iterate using specified mode
|
||||
if self.config.test_set_mode:
|
||||
yield from self._iter_test()
|
||||
else:
|
||||
yield from self._iter_train()
|
||||
|
||||
20
requirements.txt
Normal file
@@ -0,0 +1,20 @@
torch
adam-atan2
einops
tqdm
coolname
pydantic
argdantic
wandb
omegaconf
hydra-core
huggingface_hub
packaging
ninja
wheel
setuptools
setuptools-scm
pydantic-core
huggingface_hub
numba
triton
19
utils/functions.py
Normal file
@@ -0,0 +1,19 @@
import importlib
import inspect


def load_model_class(identifier: str, prefix: str = "models."):
    module_path, class_name = identifier.split('@')

    # Import the module
    module = importlib.import_module(prefix + module_path)
    cls = getattr(module, class_name)

    return cls


def get_model_source_path(identifier: str, prefix: str = "models."):
    module_path, class_name = identifier.split('@')

    module = importlib.import_module(prefix + module_path)
    return inspect.getsourcefile(module)