Alexia Jolicoeur-Martineau
2025-10-07 09:26:04 -04:00
commit 8120f2bdf7
39 changed files with 27428 additions and 0 deletions

models/common.py Normal file

@@ -0,0 +1,32 @@
import math
import torch
from torch import nn
def trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, lower: float = -2.0, upper: float = 2.0):
# NOTE: PyTorch nn.init.trunc_normal_ is not mathematically correct: the requested std is not the actual std of the initialized tensor
# This function is a PyTorch port of the JAX truncated normal init (the default init method in flax)
# https://github.com/jax-ml/jax/blob/main/jax/_src/random.py#L807-L848
# https://github.com/jax-ml/jax/blob/main/jax/_src/nn/initializers.py#L162-L199
with torch.no_grad():
if std == 0:
tensor.zero_()
else:
sqrt2 = math.sqrt(2)
a = math.erf(lower / sqrt2)
b = math.erf(upper / sqrt2)
z = (b - a) / 2
c = (2 * math.pi) ** -0.5
pdf_l = c * math.exp(-0.5 * lower ** 2)
pdf_u = c * math.exp(-0.5 * upper ** 2)
# std of a standard normal truncated to [lower, upper]
comp_std = std / math.sqrt(1 + (lower * pdf_l - upper * pdf_u) / z - ((pdf_l - pdf_u) / z) ** 2)
tensor.uniform_(a, b)
tensor.erfinv_()
tensor.mul_(sqrt2 * comp_std)
tensor.clip_(lower * comp_std, upper * comp_std)
return tensor
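A minimal usage sketch (assuming models/common.py is importable as models.common):

import torch
from models.common import trunc_normal_init_

# LeCun-style init for a 512x256 linear weight: std = 1/sqrt(fan_in)
weight = torch.empty(512, 256)
trunc_normal_init_(weight, std=256 ** -0.5)
print(weight.std())  # close to 0.0625; values truncated near +-2 standard deviations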

models/ema.py Normal file

@@ -0,0 +1,40 @@
import copy
import torch.nn as nn
class EMAHelper(object):
def __init__(self, mu=0.999):
self.mu = mu
self.shadow = {}
def register(self, module):
if isinstance(module, nn.DataParallel):
module = module.module
for name, param in module.named_parameters():
if param.requires_grad:
self.shadow[name] = param.data.clone()
def update(self, module):
if isinstance(module, nn.DataParallel):
module = module.module
for name, param in module.named_parameters():
if param.requires_grad:
self.shadow[name].data = (1. - self.mu) * param.data + self.mu * self.shadow[name].data
def ema(self, module):
if isinstance(module, nn.DataParallel):
module = module.module
for name, param in module.named_parameters():
if param.requires_grad:
param.data.copy_(self.shadow[name].data)
def ema_copy(self, module):
module_copy = copy.deepcopy(module)
self.ema(module_copy)
return module_copy
def state_dict(self):
return self.shadow
def load_state_dict(self, state_dict):
self.shadow = state_dict
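A short sketch of the intended pattern (toy module and optimizer, purely illustrative):

import torch
from torch import nn
from models.ema import EMAHelper

model = nn.Linear(4, 4)
ema = EMAHelper(mu=0.999)
ema.register(model)

opt = torch.optim.SGD(model.parameters(), lr=1e-2)
for _ in range(10):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
    ema.update(model)  # shadow <- (1 - mu) * param + mu * shadow

eval_model = ema.ema_copy(model)  # deep copy carrying the EMA weights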

models/layers.py Normal file

@@ -0,0 +1,169 @@
from typing import Tuple
import einops
import torch
from torch import nn
import torch.nn.functional as F
#try:
# from flash_attn_interface import flash_attn_func # type: ignore[import]
#except ImportError:
# # Fallback to FlashAttention 2
# from flash_attn import flash_attn_func # type: ignore[import]
from torch.nn.functional import scaled_dot_product_attention
from models.common import trunc_normal_init_
CosSin = Tuple[torch.Tensor, torch.Tensor]
def _find_multiple(a, b):
# round a up to the nearest multiple of b
return (-(a // -b)) * b
def rotate_half(x: torch.Tensor):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
# q, k: [bs, seq_len, num_heads, head_dim]
# cos, sin: [seq_len, head_dim]
orig_dtype = q.dtype
q = q.to(cos.dtype)
k = k.to(cos.dtype)
q_embed = (q * cos.unsqueeze(-2)) + (rotate_half(q) * sin.unsqueeze(-2))
k_embed = (k * cos.unsqueeze(-2)) + (rotate_half(k) * sin.unsqueeze(-2))
return q_embed.to(orig_dtype), k_embed.to(orig_dtype)
class CastedLinear(nn.Module):
def __init__(self,
in_features: int,
out_features: int,
bias: bool):
super().__init__()
# Truncated LeCun normal init
self.weight = nn.Parameter(
trunc_normal_init_(torch.empty((out_features, in_features)), std=1.0 / (in_features ** 0.5))
)
self.bias = None
if bias:
# Zero init bias
self.bias = nn.Parameter(torch.zeros((out_features, )))
def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.linear(input, self.weight.to(input.dtype), bias=self.bias.to(input.dtype) if self.bias is not None else None)
class CastedEmbedding(nn.Module):
def __init__(self,
num_embeddings: int,
embedding_dim: int,
init_std: float,
cast_to: torch.dtype):
super().__init__()
self.cast_to = cast_to
# Truncated LeCun normal init
self.embedding_weight = nn.Parameter(
trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std)
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.embedding(input, self.embedding_weight.to(self.cast_to))
class RotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings, base, device=None):
super().__init__()
# RoPE
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
t = torch.arange(max_position_embeddings, dtype=torch.float32, device=device)
freqs = torch.outer(t, inv_freq)
# Differs from the paper, but uses a different permutation to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.cos_cached = nn.Buffer(emb.cos(), persistent=False)
self.sin_cached = nn.Buffer(emb.sin(), persistent=False)
def forward(self):
return self.cos_cached, self.sin_cached
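A shape-level sketch of the RoPE path (seq_len equal to max_position_embeddings, as the cached tables assume):

import torch
from models.layers import RotaryEmbedding, apply_rotary_pos_emb

rope = RotaryEmbedding(dim=64, max_position_embeddings=128, base=10000.0)
cos, sin = rope()                 # each [seq_len, head_dim] = [128, 64]
q = torch.randn(2, 128, 8, 64)    # [bs, seq_len, num_heads, head_dim]
k = torch.randn(2, 128, 8, 64)
q, k = apply_rotary_pos_emb(q, k, cos, sin)  # shapes unchanged, positions encoded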
class Attention(nn.Module):
def __init__(self, hidden_size, head_dim, num_heads, num_key_value_heads, causal=False):
super().__init__()
self.hidden_size = hidden_size
self.head_dim = head_dim
self.output_size = head_dim * num_heads
self.num_heads = num_heads
self.num_key_value_heads = num_key_value_heads
self.causal = causal
self.qkv_proj = CastedLinear(self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False)
self.o_proj = CastedLinear(self.output_size, self.hidden_size, bias=False)
def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, seq_len, _ = hidden_states.shape
# hidden_states: [bs, seq_len, num_heads, head_dim]
qkv = self.qkv_proj(hidden_states)
# Split head
qkv = qkv.view(batch_size, seq_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
query = qkv[:, :, :self.num_heads]
key = qkv[:, :, self.num_heads: self.num_heads + self.num_key_value_heads]
value = qkv[:, :, self.num_heads + self.num_key_value_heads:]
# RoPE
if cos_sin is not None:
cos, sin = cos_sin
query, key = apply_rotary_pos_emb(query, key, cos, sin)
# flash attn
query, key, value = map(lambda t: einops.rearrange(t, 'B S H D -> B H S D'), (query, key, value)) # needed for scaled_dot_product_attention but not flash_attn_func
attn_output = scaled_dot_product_attention(query=query, key=key, value=value, is_causal=self.causal)
attn_output = einops.rearrange(attn_output, 'B H S D -> B S H D')
attn_output = attn_output.view(batch_size, seq_len, self.output_size) # type: ignore
return self.o_proj(attn_output)
class LinearSwish(nn.Module):
def __init__(self, hidden_size: int, reverse=False):
super().__init__()
self.linear = CastedLinear(hidden_size, hidden_size, bias=False)
self.reverse = reverse
def forward(self, x):
if self.reverse:
return F.silu(self.linear(x))
else:
return self.linear(F.silu(x))
class SwiGLU(nn.Module):
def __init__(self, hidden_size: int, expansion: float):
super().__init__()
inter = _find_multiple(round(expansion * hidden_size * 2 / 3), 256)
self.gate_up_proj = CastedLinear(hidden_size, inter * 2, bias=False)
self.down_proj = CastedLinear(inter, hidden_size, bias=False)
def forward(self, x):
gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
return self.down_proj(F.silu(gate) * up)
def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float) -> torch.Tensor:
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.square().mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
return hidden_states.to(input_dtype)
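A quick sizing check for the SwiGLU block; the 2/3 factor plus rounding to a multiple of 256 mirrors LLaMA-style FFN sizing:

import torch
from models.layers import SwiGLU, _find_multiple

d = 512
assert _find_multiple(round(4.0 * d * 2 / 3), 256) == 1536  # intermediate width
mlp = SwiGLU(hidden_size=d, expansion=4.0)
y = mlp(torch.randn(2, 16, d))  # output keeps the model width: [2, 16, 512]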

models/losses.py Normal file

@@ -0,0 +1,103 @@
from typing import Any, Tuple, Dict, Sequence, Optional
import torch
import torch.nn.functional as F
from torch import nn
import math
IGNORE_LABEL_ID = -100
def s(x, epsilon=1e-30):
return torch.where(
x < 0,
1 / (1 - x + epsilon),
x + 1
)
def log_stablemax(x, dim=-1):
s_x = s(x)
return torch.log(s_x/torch.sum(s_x, dim=dim, keepdim=True))
def stablemax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None):
logprobs = log_stablemax(logits.to(torch.float64), dim=-1)
if valid_mask is None:
valid_mask = (labels != ignore_index)
transformed_labels = torch.where(valid_mask, labels, 0)
prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)
return -torch.where(valid_mask, prediction_logprobs, 0)
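For intuition, s(x) grows linearly above zero and decays as a reciprocal below it, so extreme negative logits are compressed rather than dominating the loss; a quick numeric sketch:

import torch
from models.losses import log_stablemax

logits = torch.tensor([[-1000.0, 0.0, 5.0]], dtype=torch.float64)
print(torch.log_softmax(logits, dim=-1)[0])  # approx [-1005.0, -5.0, -0.007]
print(log_stablemax(logits)[0])              # approx [-8.9, -1.9, -0.15]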
def softmax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None):
# valid_mask is unused here; accepted so the signature matches stablemax_cross_entropy when called from ACTLossHead
# Cast logits to f32 and flatten for F.cross_entropy
return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape)
class ACTLossHead(nn.Module):
def __init__(self, model: nn.Module, loss_type: str):
super().__init__()
self.model = model
self.loss_fn = globals()[loss_type]
def initial_carry(self, *args, **kwargs):
return self.model.initial_carry(*args, **kwargs) # type: ignore
def forward(
self,
return_keys: Sequence[str],
# Model args
**model_kwargs,
) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
# Model logits
# B x SeqLen x D
new_carry, outputs = self.model(**model_kwargs)
labels = new_carry.current_data["labels"]
with torch.no_grad():
# Preds
outputs["preds"] = torch.argmax(outputs["logits"], dim=-1)
# Correctness
mask = (labels != IGNORE_LABEL_ID)
loss_counts = mask.sum(-1)
loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1) # Avoid NaNs in division
is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
seq_is_correct = is_correct.sum(-1) == loss_counts
# Metrics (halted)
valid_metrics = new_carry.halted & (loss_counts > 0)
metrics = {
"count": valid_metrics.sum(),
"accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
"exact_accuracy": (valid_metrics & seq_is_correct).sum(),
"q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
"steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
}
# Losses
lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID, valid_mask=mask) / loss_divisor).sum()
q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")
metrics.update({
"lm_loss": lm_loss.detach(),
"q_halt_loss": q_halt_loss.detach(),
})
# Q continue (bootstrapping target loss); Alexia: This fits Q-learning, but seems totally unnecessary
q_continue_loss = 0
if "target_q_continue" in outputs:
q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")
metrics["q_continue_loss"] = q_continue_loss.detach()
# Filter outputs for return
detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}
return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
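A hedged sketch of the contract ACTLossHead expects from the wrapped model; the stub below is hypothetical and stands in for the real ACT models:

from dataclasses import dataclass
from typing import Dict
import torch
from torch import nn
from models.losses import ACTLossHead

@dataclass
class StubCarry:
    current_data: Dict[str, torch.Tensor]
    halted: torch.Tensor
    steps: torch.Tensor

class StubModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.head = nn.Linear(4, 10)
    def initial_carry(self, batch):
        b = batch["labels"].shape[0]
        return StubCarry(batch, torch.ones(b, dtype=torch.bool), torch.zeros(b, dtype=torch.int32))
    def forward(self, carry, batch):
        b, s = batch["labels"].shape
        new_carry = StubCarry(batch, torch.ones(b, dtype=torch.bool), carry.steps + 1)
        outputs = {"logits": self.head(torch.randn(b, s, 4)), "q_halt_logits": torch.zeros(b)}
        return new_carry, outputs

head = ACTLossHead(StubModel(), loss_type="softmax_cross_entropy")
batch = {"labels": torch.randint(0, 10, (2, 7))}
carry = head.initial_carry(batch)
carry, loss, metrics, outs, all_halted = head(return_keys=["preds"], carry=carry, batch=batch)
loss.backward()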


@@ -0,0 +1,294 @@
from typing import Tuple, List, Dict, Optional
from dataclasses import dataclass
import math
import torch
import torch.nn.functional as F
from torch import nn
from pydantic import BaseModel
from models.common import trunc_normal_init_
from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
from models.sparse_embedding import CastedSparseEmbedding
@dataclass
class HierarchicalReasoningModel_ACTV1InnerCarry:
z_H: torch.Tensor
z_L: torch.Tensor
@dataclass
class HierarchicalReasoningModel_ACTV1Carry:
inner_carry: HierarchicalReasoningModel_ACTV1InnerCarry
steps: torch.Tensor
halted: torch.Tensor
current_data: Dict[str, torch.Tensor]
class HierarchicalReasoningModel_ACTV1Config(BaseModel):
batch_size: int
seq_len: int
puzzle_emb_ndim: int = 0
num_puzzle_identifiers: int
vocab_size: int
H_cycles: int
L_cycles: int
H_layers: int
L_layers: int
# Transformer config
hidden_size: int
expansion: float
num_heads: int
pos_encodings: str
rms_norm_eps: float = 1e-5
rope_theta: float = 10000.0
# Halting Q-learning config
halt_max_steps: int
halt_exploration_prob: float
forward_dtype: str = "bfloat16"
# Alexia: added
mlp_t: bool=False # use mlp on L instead of transformer
class HierarchicalReasoningModel_ACTV1Block(nn.Module):
def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
super().__init__()
self.config = config
if self.config.mlp_t:
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)
self.mlp_t = SwiGLU(
hidden_size=self.config.seq_len + self.puzzle_emb_len, # L
expansion=config.expansion,
)
else:
self.self_attn = Attention(
hidden_size=config.hidden_size,
head_dim=config.hidden_size // config.num_heads,
num_heads=config.num_heads,
num_key_value_heads=config.num_heads,
causal=False
)
self.mlp = SwiGLU(
hidden_size=config.hidden_size,
expansion=config.expansion,
)
self.norm_eps = config.rms_norm_eps
def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
# B, L, D = hidden_states.shape
# Post Norm
if self.config.mlp_t:
hidden_states = hidden_states.transpose(1,2)
out = self.mlp_t(hidden_states)
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
hidden_states = hidden_states.transpose(1,2)
else:
# Self Attention
hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
# Fully Connected
out = self.mlp(hidden_states)
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
return hidden_states
class HierarchicalReasoningModel_ACTV1ReasoningModule(nn.Module):
def __init__(self, layers: List[HierarchicalReasoningModel_ACTV1Block]):
super().__init__()
self.layers = torch.nn.ModuleList(layers)
def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
# Input injection (add)
hidden_states = hidden_states + input_injection
# Layers
for layer in self.layers:
hidden_states = layer(hidden_states=hidden_states, **kwargs)
return hidden_states
class HierarchicalReasoningModel_ACTV1_Inner(nn.Module):
def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
super().__init__()
self.config = config
self.forward_dtype = getattr(torch, self.config.forward_dtype)
# I/O
self.embed_scale = math.sqrt(self.config.hidden_size)
embed_init_std = 1.0 / self.embed_scale
self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div
if self.config.puzzle_emb_ndim > 0:
# Zero init puzzle embeddings
self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
# LM Blocks
if self.config.pos_encodings == "rope":
self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
base=self.config.rope_theta)
elif self.config.pos_encodings == "learned":
self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
else:
pass
# Reasoning Layers
self.H_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.H_layers)])
self.L_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])
# Initial states
self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
# Q head special init
# Init Q to (almost) zero for faster learning during bootstrapping
with torch.no_grad():
self.q_head.weight.zero_()
self.q_head.bias.fill_(-5) # type: ignore
def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
# Token embedding
embedding = self.embed_tokens(input.to(torch.int32))
# Puzzle embeddings
if self.config.puzzle_emb_ndim > 0:
puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
if pad_count > 0:
puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
# Position embeddings
if self.config.pos_encodings == "learned":
# scale by 1/sqrt(2) to maintain forward variance
embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
# Scale
return self.embed_scale * embedding
def empty_carry(self, batch_size: int):
return HierarchicalReasoningModel_ACTV1InnerCarry(
z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
)
def reset_carry(self, reset_flag: torch.Tensor, carry: HierarchicalReasoningModel_ACTV1InnerCarry):
return HierarchicalReasoningModel_ACTV1InnerCarry(
z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
)
def forward(self, carry: HierarchicalReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
seq_info = dict(
cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
)
# Input encoding
input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
# Forward iterations
with torch.no_grad():
z_H, z_L = carry.z_H, carry.z_L
for _H_step in range(self.config.H_cycles):
for _L_step in range(self.config.L_cycles):
if not ((_H_step == self.config.H_cycles - 1) and (_L_step == self.config.L_cycles - 1)):
z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
if not (_H_step == self.config.H_cycles - 1):
z_H = self.H_level(z_H, z_L, **seq_info)
assert not z_H.requires_grad and not z_L.requires_grad
# 1-step grad
z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
z_H = self.H_level(z_H, z_L, **seq_info)
# LM Outputs
new_carry = HierarchicalReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach()) # New carry no grad
output = self.lm_head(z_H)[:, self.puzzle_emb_len:]
# Q head
q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
class HierarchicalReasoningModel_ACTV1(nn.Module):
"""ACT wrapper."""
def __init__(self, config_dict: dict):
super().__init__()
self.config = HierarchicalReasoningModel_ACTV1Config(**config_dict)
self.inner = HierarchicalReasoningModel_ACTV1_Inner(self.config)
@property
def puzzle_emb(self):
return self.inner.puzzle_emb
def initial_carry(self, batch: Dict[str, torch.Tensor]):
batch_size = batch["inputs"].shape[0]
return HierarchicalReasoningModel_ACTV1Carry(
inner_carry=self.inner.empty_carry(batch_size), # Empty is expected; it will be reset on the first pass since all sequences start halted.
steps=torch.zeros((batch_size, ), dtype=torch.int32),
halted=torch.ones((batch_size, ), dtype=torch.bool), # Default to halted
current_data={k: torch.empty_like(v) for k, v in batch.items()}
)
def forward(self, carry: HierarchicalReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
# Update data, carry (removing halted sequences)
new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
new_steps = torch.where(carry.halted, 0, carry.steps)
new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
# Forward inner model
new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)
outputs = {
"logits": logits,
"q_halt_logits": q_halt_logits,
"q_continue_logits": q_continue_logits
}
with torch.no_grad():
# Step
new_steps = new_steps + 1
is_last_step = new_steps >= self.config.halt_max_steps
halted = is_last_step
# if training, and ACT is enabled
if self.training and (self.config.halt_max_steps > 1):
# Halt signal
# NOTE: During evaluation, always use max steps; this guarantees the same halting steps within a batch for batching purposes
halted = halted | (q_halt_logits > q_continue_logits)
# Exploration
min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
halted = halted & (new_steps >= min_halt_steps)
# Compute target Q
# NOTE: No replay buffer and target networks for computing target Q-value.
# As batch_size is large, there are many parallel envs.
# Similar concept as PQN https://arxiv.org/abs/2407.04811
next_q_halt_logits, next_q_continue_logits = self.inner(new_inner_carry, new_current_data)[-1]
outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))
return HierarchicalReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
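An outer-loop usage sketch for the ACT wrapper; the module path in the import and the config values are assumptions for illustration:

import torch
from models.hrm_act_v1 import HierarchicalReasoningModel_ACTV1  # path assumed

config = dict(batch_size=2, seq_len=16, puzzle_emb_ndim=0, num_puzzle_identifiers=1,
              vocab_size=11, H_cycles=2, L_cycles=2, H_layers=2, L_layers=2,
              hidden_size=64, expansion=4.0, num_heads=4, pos_encodings="rope",
              halt_max_steps=4, halt_exploration_prob=0.1)
model = HierarchicalReasoningModel_ACTV1(config).eval()

batch = {"inputs": torch.randint(0, 11, (2, 16)),
         "puzzle_identifiers": torch.zeros(2, dtype=torch.int32)}
carry = model.initial_carry(batch)
while True:  # one ACT segment per call; at eval time this runs halt_max_steps times
    carry, outputs = model(carry=carry, batch=batch)
    if carry.halted.all():
        break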


@@ -0,0 +1,342 @@
"""
HRM ACT V2: Transformer Baseline for Architecture Ablation
This is an architecture ablation of the Hierarchical Reasoning Model (HRM).
Key changes from V1:
1. REMOVED hierarchical split (no separate H and L levels)
2. REMOVED inner cycles (no H_cycles/L_cycles loops within reasoning)
3. KEPT ACT outer loop structure intact
4. KEPT all data preprocessing, embeddings, and evaluation infrastructure
Architecture: Single-level transformer that processes the full 30x30 grid as a
900-token sequence, with the same positional encodings and sparse embeddings as V1.
"""
from typing import Tuple, List, Dict, Optional
from dataclasses import dataclass
import math
import torch
import torch.nn.functional as F
from torch import nn
from pydantic import BaseModel
from models.common import trunc_normal_init_
from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
from models.sparse_embedding import CastedSparseEmbedding
@dataclass
class Model_ACTV2InnerCarry:
z_H: torch.Tensor
@dataclass
class Model_ACTV2Carry:
inner_carry: Model_ACTV2InnerCarry
steps: torch.Tensor
halted: torch.Tensor
current_data: Dict[str, torch.Tensor]
class Model_ACTV2Config(BaseModel):
batch_size: int
seq_len: int
puzzle_emb_ndim: int = 0
num_puzzle_identifiers: int
vocab_size: int
H_cycles: int
H_layers: int
# Transformer config
hidden_size: int
expansion: float
num_heads: int
pos_encodings: str
rms_norm_eps: float = 1e-5
rope_theta: float = 10000.0
# Halting Q-learning config
halt_max_steps: int
halt_exploration_prob: float
act_enabled: bool = True # If False, always run halt_max_steps (no early stopping during training)
act_inference: bool = False # If True, use adaptive computation during inference
forward_dtype: str = "bfloat16"
class Model_ACTV2Block(nn.Module):
def __init__(self, config: Model_ACTV2Config) -> None:
super().__init__()
self.self_attn = Attention(
hidden_size=config.hidden_size,
head_dim=config.hidden_size // config.num_heads,
num_heads=config.num_heads,
num_key_value_heads=config.num_heads,
causal=False,
)
self.mlp = SwiGLU(
hidden_size=config.hidden_size,
expansion=config.expansion,
)
self.norm_eps = config.rms_norm_eps
def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
# Post Norm
# Self Attention
hidden_states = rms_norm(
hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states),
variance_epsilon=self.norm_eps,
)
# Fully Connected
hidden_states = rms_norm(hidden_states + self.mlp(hidden_states), variance_epsilon=self.norm_eps)
return hidden_states
class Model_ACTV2ReasoningModule(nn.Module):
def __init__(self, layers: List[Model_ACTV2Block]):
super().__init__()
self.layers = torch.nn.ModuleList(layers)
def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
# Input injection (add)
hidden_states = hidden_states + input_injection
# Layers
for layer in self.layers:
hidden_states = layer(hidden_states=hidden_states, **kwargs)
return hidden_states
class Model_ACTV2_Inner(nn.Module):
def __init__(self, config: Model_ACTV2Config) -> None:
super().__init__()
self.config = config
self.forward_dtype = getattr(torch, self.config.forward_dtype)
# I/O
self.embed_scale = math.sqrt(self.config.hidden_size)
embed_init_std = 1.0 / self.embed_scale
self.embed_tokens = CastedEmbedding(
self.config.vocab_size,
self.config.hidden_size,
init_std=embed_init_std,
cast_to=self.forward_dtype,
)
self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div
if self.config.puzzle_emb_ndim > 0:
# Zero init puzzle embeddings
self.puzzle_emb = CastedSparseEmbedding(
self.config.num_puzzle_identifiers,
self.config.puzzle_emb_ndim,
batch_size=self.config.batch_size,
init_std=0,
cast_to=self.forward_dtype,
)
# LM Blocks
if self.config.pos_encodings == "rope":
self.rotary_emb = RotaryEmbedding(
dim=self.config.hidden_size // self.config.num_heads,
max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
base=self.config.rope_theta,
)
elif self.config.pos_encodings == "learned":
self.embed_pos = CastedEmbedding(
self.config.seq_len + self.puzzle_emb_len,
self.config.hidden_size,
init_std=embed_init_std,
cast_to=self.forward_dtype,
)
else:
raise NotImplementedError()
# Reasoning Layers
self.H_level = Model_ACTV2ReasoningModule(
layers=[Model_ACTV2Block(self.config) for _i in range(self.config.H_layers)]
)
# Initial states
self.H_init = nn.Buffer(
trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1),
persistent=True,
)
# Q head special init
# Init Q to (almost) zero for faster learning during bootstrapping
with torch.no_grad():
self.q_head.weight.zero_()
self.q_head.bias.fill_(-5) # type: ignore
def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
# Token embedding
embedding = self.embed_tokens(input.to(torch.int32))
# Puzzle embeddings
if self.config.puzzle_emb_ndim > 0:
puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
if pad_count > 0:
puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
embedding = torch.cat(
(puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2
)
# Position embeddings
if self.config.pos_encodings == "learned":
# scale by 1/sqrt(2) to maintain forward variance
embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
# Scale
return self.embed_scale * embedding
def empty_carry(self, batch_size: int):
return Model_ACTV2InnerCarry(
z_H=torch.empty(
batch_size,
self.config.seq_len + self.puzzle_emb_len,
self.config.hidden_size,
dtype=self.forward_dtype,
),
)
def reset_carry(self, reset_flag: torch.Tensor, carry: Model_ACTV2InnerCarry):
return Model_ACTV2InnerCarry(
z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
)
def forward(
self, carry: Model_ACTV2InnerCarry, batch: Dict[str, torch.Tensor]
) -> Tuple[Model_ACTV2InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
seq_info = dict(
cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
)
# Input encoding
input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
# 1-step grad
z_H = self.H_level(carry.z_H, input_embeddings, **seq_info)
# LM Outputs
new_carry = Model_ACTV2InnerCarry(
z_H=z_H.detach(),
) # New carry no grad
output = self.lm_head(z_H)[:, self.puzzle_emb_len :]
# Q head
q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
class Model_ACTV2(nn.Module):
"""ACT wrapper."""
def __init__(self, config_dict: dict):
super().__init__()
self.config = Model_ACTV2Config(**config_dict)
self.inner = Model_ACTV2_Inner(self.config)
@property
def puzzle_emb(self):
return self.inner.puzzle_emb
def initial_carry(self, batch: Dict[str, torch.Tensor]):
batch_size = batch["inputs"].shape[0]
return Model_ACTV2Carry(
inner_carry=self.inner.empty_carry(
batch_size
), # Empty is expected; it will be reset on the first pass since all sequences start halted.
steps=torch.zeros((batch_size,), dtype=torch.int32),
halted=torch.ones((batch_size,), dtype=torch.bool), # Default to halted
current_data={k: torch.empty_like(v) for k, v in batch.items()},
)
def forward(
self,
carry: Model_ACTV2Carry,
batch: Dict[str, torch.Tensor],
compute_target_q: bool = False,
) -> Tuple[Model_ACTV2Carry, Dict[str, torch.Tensor]]:
# Update data, carry (removing halted sequences)
new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
new_steps = torch.where(carry.halted, 0, carry.steps)
new_current_data = {
k: torch.where(carry.halted.view((-1,) + (1,) * (batch[k].ndim - 1)), batch[k], v)
for k, v in carry.current_data.items()
}
# Forward inner model
new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(
new_inner_carry, new_current_data
)
outputs = {"logits": logits, "q_halt_logits": q_halt_logits, "q_continue_logits": q_continue_logits}
with torch.no_grad():
# Step
new_steps = new_steps + 1
is_last_step = new_steps >= self.config.halt_max_steps
halted = is_last_step
# Check if adaptive computation should be used
use_adaptive = (self.config.halt_max_steps > 1) and (
(self.training and self.config.act_enabled)
or (not self.training and self.config.act_inference)
)
if use_adaptive:
# Halt signal based on Q-values (but always halt at max steps)
q_halt_signal = q_halt_logits > q_continue_logits
halted = halted | q_halt_signal
# Store actual steps used for logging (only during inference)
if not self.training:
outputs["actual_steps"] = new_steps.float()
# Exploration (only during training)
if self.training:
min_halt_steps = (
torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob
) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
halted = halted & (new_steps >= min_halt_steps)
# Compute target Q (only during training)
# NOTE: No replay buffer and target networks for computing target Q-value.
# As batch_size is large, there are many parallel envs.
# Similar concept as PQN https://arxiv.org/abs/2407.04811
if self.training and compute_target_q:
next_q_halt_logits, next_q_continue_logits = self.inner(
new_inner_carry, new_current_data
)[-1]
outputs["target_q_continue"] = torch.sigmoid(
torch.where(
is_last_step,
next_q_halt_logits,
torch.maximum(next_q_halt_logits, next_q_continue_logits),
)
)
return Model_ACTV2Carry(
new_inner_carry, new_steps, halted, new_current_data
), outputs
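An illustrative V2 config (values are assumptions): the hierarchical fields are gone except H_cycles/H_layers, and two flags control when adaptive computation is active:

from models.act_v2 import Model_ACTV2  # path assumed

config = dict(batch_size=2, seq_len=16, num_puzzle_identifiers=1, vocab_size=11,
              H_cycles=1, H_layers=4, hidden_size=64, expansion=4.0, num_heads=4,
              pos_encodings="rope", halt_max_steps=4, halt_exploration_prob=0.1,
              act_enabled=True,     # allow early halting during training
              act_inference=False)  # eval always runs halt_max_steps
model = Model_ACTV2(config)
# Training code opts into the bootstrapped target with compute_target_q=True:
# carry, outputs = model(carry=carry, batch=batch, compute_target_q=True)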


@@ -0,0 +1,297 @@
from typing import Tuple, List, Dict, Optional
from dataclasses import dataclass
import math
import torch
import copy
import torch.nn.functional as F
from torch import nn
from pydantic import BaseModel
import random
from models.common import trunc_normal_init_
from models.layers import rms_norm, LinearSwish, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
from models.sparse_embedding import CastedSparseEmbedding
IGNORE_LABEL_ID = -100
@dataclass
class TinyRecursiveReasoningModel_ACTV1InnerCarry:
z_H: torch.Tensor
z_L: torch.Tensor
@dataclass
class TinyRecursiveReasoningModel_ACTV1Carry:
inner_carry: TinyRecursiveReasoningModel_ACTV1InnerCarry
steps: torch.Tensor
halted: torch.Tensor
current_data: Dict[str, torch.Tensor]
class TinyRecursiveReasoningModel_ACTV1Config(BaseModel):
batch_size: int
seq_len: int
puzzle_emb_ndim: int = 0
num_puzzle_identifiers: int
vocab_size: int
H_cycles: int
L_cycles: int
H_layers: int # ignored
L_layers: int
# Transformer config
hidden_size: int
expansion: float
num_heads: int
pos_encodings: str
rms_norm_eps: float = 1e-5
rope_theta: float = 10000.0
# Halting Q-learning config
halt_max_steps: int
halt_exploration_prob: float
forward_dtype: str = "bfloat16"
# Alexia: added
mlp_t: bool = False # use mlp on L instead of transformer
puzzle_emb_len: int = 16 # if non-zero, puzzle_emb_len is fixed to this value
no_ACT_continue: bool = True # Drop the ACT continue loss and use only the sigmoid of the halt logit, which makes much more sense
class TinyRecursiveReasoningModel_ACTV1Block(nn.Module):
def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
super().__init__()
self.config = config
if self.config.mlp_t:
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len
self.mlp_t = SwiGLU(
hidden_size=self.config.seq_len + self.puzzle_emb_len, # L
expansion=config.expansion,
)
else:
self.self_attn = Attention(
hidden_size=config.hidden_size,
head_dim=config.hidden_size // config.num_heads,
num_heads=config.num_heads,
num_key_value_heads=config.num_heads,
causal=False
)
self.mlp = SwiGLU(
hidden_size=config.hidden_size,
expansion=config.expansion,
)
self.norm_eps = config.rms_norm_eps
def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
# B, L, D = hidden_states.shape
# Post Norm
if self.config.mlp_t:
hidden_states = hidden_states.transpose(1,2)
out = self.mlp_t(hidden_states)
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
hidden_states = hidden_states.transpose(1,2)
else:
# Self Attention
hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
# Fully Connected
out = self.mlp(hidden_states)
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
return hidden_states
class TinyRecursiveReasoningModel_ACTV1ReasoningModule(nn.Module):
def __init__(self, layers: List[TinyRecursiveReasoningModel_ACTV1Block]):
super().__init__()
self.layers = torch.nn.ModuleList(layers)
def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
hidden_states = hidden_states + input_injection
for layer in self.layers:
hidden_states = layer(hidden_states=hidden_states, **kwargs)
return hidden_states
class TinyRecursiveReasoningModel_ACTV1_Inner(nn.Module):
def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
super().__init__()
self.config = config
self.forward_dtype = getattr(torch, self.config.forward_dtype)
# I/O
self.embed_scale = math.sqrt(self.config.hidden_size)
embed_init_std = 1.0 / self.embed_scale
self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len # ceil div
if self.config.puzzle_emb_ndim > 0:
# Zero init puzzle embeddings
self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
# LM Blocks
if self.config.pos_encodings == "rope":
self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
base=self.config.rope_theta)
elif self.config.pos_encodings == "learned":
self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
else:
pass
# Reasoning Layers
self.L_level = TinyRecursiveReasoningModel_ACTV1ReasoningModule(layers=[TinyRecursiveReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])
# Initial states
self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
# Q head special init
# Init Q to (almost) zero for faster learning during bootstrapping
with torch.no_grad():
self.q_head.weight.zero_()
self.q_head.bias.fill_(-5) # type: ignore
def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
# Token embedding
embedding = self.embed_tokens(input.to(torch.int32))
# Puzzle embeddings
if self.config.puzzle_emb_ndim > 0:
puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
if pad_count > 0:
puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
# Position embeddings
if self.config.pos_encodings == "learned":
# scale by 1/sqrt(2) to maintain forward variance
embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
# Scale
return self.embed_scale * embedding
def empty_carry(self, batch_size: int):
return TinyRecursiveReasoningModel_ACTV1InnerCarry(
z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
)
def reset_carry(self, reset_flag: torch.Tensor, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry):
return TinyRecursiveReasoningModel_ACTV1InnerCarry(
z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
)
def forward(self, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
seq_info = dict(
cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
)
# Input encoding
input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
# Forward iterations
z_H, z_L = carry.z_H, carry.z_L
# H_cycles-1 without grad
with torch.no_grad():
for _H_step in range(self.config.H_cycles-1):
for _L_step in range(self.config.L_cycles):
z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
z_H = self.L_level(z_H, z_L, **seq_info)
# 1 with grad
for _L_step in range(self.config.L_cycles):
z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
z_H = self.L_level(z_H, z_L, **seq_info)
# LM Outputs
new_carry = TinyRecursiveReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach()) # New carry no grad
output = self.lm_head(z_H)[:, self.puzzle_emb_len:]
q_logits = self.q_head(z_H[:, 0]).to(torch.float32) # Q-head; uses the first puzzle_emb position
return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
class TinyRecursiveReasoningModel_ACTV1(nn.Module):
"""ACT wrapper."""
def __init__(self, config_dict: dict):
super().__init__()
self.config = TinyRecursiveReasoningModel_ACTV1Config(**config_dict)
self.inner = TinyRecursiveReasoningModel_ACTV1_Inner(self.config)
@property
def puzzle_emb(self):
return self.inner.puzzle_emb
def initial_carry(self, batch: Dict[str, torch.Tensor]):
batch_size = batch["inputs"].shape[0]
return TinyRecursiveReasoningModel_ACTV1Carry(
inner_carry=self.inner.empty_carry(batch_size), # Empty is expected; it will be reset on the first pass since all sequences start halted.
steps=torch.zeros((batch_size, ), dtype=torch.int32),
halted=torch.ones((batch_size, ), dtype=torch.bool), # Default to halted
current_data={k: torch.empty_like(v) for k, v in batch.items()}
)
def forward(self, carry: TinyRecursiveReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
# Update data, carry (removing halted sequences)
new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
new_steps = torch.where(carry.halted, 0, carry.steps)
new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
# Forward inner model
new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)
outputs = {
"logits": logits,
"q_halt_logits": q_halt_logits,
"q_continue_logits": q_continue_logits
}
with torch.no_grad():
# Step
new_steps = new_steps + 1
is_last_step = new_steps >= self.config.halt_max_steps
halted = is_last_step
# if training, and ACT is enabled
if self.training and (self.config.halt_max_steps > 1):
# Halt signal
# NOTE: During evaluation, always use max steps; this guarantees the same halting steps within a batch for batching purposes
if self.config.no_ACT_continue:
halted = halted | (q_halt_logits > 0)
else:
halted = halted | (q_halt_logits > q_continue_logits)
# Exploration
min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
halted = halted & (new_steps >= min_halt_steps)
if not self.config.no_ACT_continue:
# Compute target Q
# NOTE: No replay buffer and target networks for computing target Q-value.
# As batch_size is large, there are many parallel envs.
# Similar concept as PQN https://arxiv.org/abs/2407.04811
_, _, (next_q_halt_logits, next_q_continue_logits) = self.inner(new_inner_carry, new_current_data)
outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))
return TinyRecursiveReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
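A pure-Python sketch of the TRM update schedule above: a single L_level network plays both roles, and only the final cycle keeps gradients:

def trm_schedule(H_cycles, L_cycles):
    ops = []
    for h in range(H_cycles):
        grad = (h == H_cycles - 1)  # only the last cycle is differentiated
        for _ in range(L_cycles):
            ops.append(("z_L <- L_level(z_L, z_H + x)", grad))
        ops.append(("z_H <- L_level(z_H, z_L)", grad))
    return ops

for op, grad in trm_schedule(H_cycles=3, L_cycles=2):
    print(op, "(with grad)" if grad else "(no grad)")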


@@ -0,0 +1,323 @@
from typing import Tuple, List, Dict, Optional
from dataclasses import dataclass
import math
import torch
import copy
import torch.nn.functional as F
from torch import nn
from pydantic import BaseModel
import random
from models.common import trunc_normal_init_
from models.layers import rms_norm, LinearSwish, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
from models.sparse_embedding import CastedSparseEmbedding
IGNORE_LABEL_ID = -100
@dataclass
class TinyRecursiveReasoningModel_ACTV1InnerCarry:
z_H: torch.Tensor
z_L1: torch.Tensor
z_L2: torch.Tensor
z_L3: torch.Tensor
z_L4: torch.Tensor
z_L5: torch.Tensor
z_L6: torch.Tensor
@dataclass
class TinyRecursiveReasoningModel_ACTV1Carry:
inner_carry: TinyRecursiveReasoningModel_ACTV1InnerCarry
steps: torch.Tensor
halted: torch.Tensor
current_data: Dict[str, torch.Tensor]
class TinyRecursiveReasoningModel_ACTV1Config(BaseModel):
batch_size: int
seq_len: int
puzzle_emb_ndim: int = 0
num_puzzle_identifiers: int
vocab_size: int
H_cycles: int
L_cycles: int
H_layers: int # ignored
L_layers: int
# Transformer config
hidden_size: int
expansion: float
num_heads: int
pos_encodings: str
rms_norm_eps: float = 1e-5
rope_theta: float = 10000.0
# Halting Q-learning config
halt_max_steps: int
halt_exploration_prob: float
forward_dtype: str = "bfloat16"
# Alexia: added
mlp_t: bool = False # use mlp on L instead of transformer
puzzle_emb_len: int = 16 # if non-zero, puzzle_emb_len is fixed to this value
no_ACT_continue: bool = True # Drop the ACT continue loss and use only the sigmoid of the halt logit, which makes much more sense
class TinyRecursiveReasoningModel_ACTV1Block(nn.Module):
def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
super().__init__()
self.config = config
if self.config.mlp_t:
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len
self.mlp_t = SwiGLU(
hidden_size=self.config.seq_len + self.puzzle_emb_len, # L
expansion=config.expansion,
)
else:
self.self_attn = Attention(
hidden_size=config.hidden_size,
head_dim=config.hidden_size // config.num_heads,
num_heads=config.num_heads,
num_key_value_heads=config.num_heads,
causal=False
)
self.mlp = SwiGLU(
hidden_size=config.hidden_size,
expansion=config.expansion,
)
self.norm_eps = config.rms_norm_eps
def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
# B, L, D = hidden_states.shape
# Post Norm
if self.config.mlp_t:
hidden_states = hidden_states.transpose(1,2)
out = self.mlp_t(hidden_states)
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
hidden_states = hidden_states.transpose(1,2)
else:
# Self Attention
hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
# Fully Connected
out = self.mlp(hidden_states)
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
return hidden_states
class TinyRecursiveReasoningModel_ACTV1ReasoningModule(nn.Module):
def __init__(self, layers: List[TinyRecursiveReasoningModel_ACTV1Block]):
super().__init__()
self.layers = torch.nn.ModuleList(layers)
def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
hidden_states = hidden_states + input_injection
for layer in self.layers:
hidden_states = layer(hidden_states=hidden_states, **kwargs)
return hidden_states
class TinyRecursiveReasoningModel_ACTV1_Inner(nn.Module):
def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
super().__init__()
self.config = config
self.forward_dtype = getattr(torch, self.config.forward_dtype)
# I/O
self.embed_scale = math.sqrt(self.config.hidden_size)
embed_init_std = 1.0 / self.embed_scale
self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len # ceil div
if self.config.puzzle_emb_ndim > 0:
# Zero init puzzle embeddings
self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
# LM Blocks
if self.config.pos_encodings == "rope":
self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
base=self.config.rope_theta)
elif self.config.pos_encodings == "learned":
self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
else:
pass
# Reasoning Layers
self.L_level = TinyRecursiveReasoningModel_ACTV1ReasoningModule(layers=[TinyRecursiveReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])
# Initial states
self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
self.L1_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
self.L2_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
self.L3_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
self.L4_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
self.L5_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
self.L6_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
# Q head special init
# Init Q to (almost) zero for faster learning during bootstrapping
with torch.no_grad():
self.q_head.weight.zero_()
self.q_head.bias.fill_(-5) # type: ignore
def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
# Token embedding
embedding = self.embed_tokens(input.to(torch.int32))
# Puzzle embeddings
if self.config.puzzle_emb_ndim > 0:
puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
if pad_count > 0:
puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
# Position embeddings
if self.config.pos_encodings == "learned":
# scale by 1/sqrt(2) to maintain forward variance
embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
# Scale
return self.embed_scale * embedding
def empty_carry(self, batch_size: int):
return TinyRecursiveReasoningModel_ACTV1InnerCarry(
z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
z_L1=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
z_L2=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
z_L3=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
z_L4=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
z_L5=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
z_L6=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
)
def reset_carry(self, reset_flag: torch.Tensor, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry):
return TinyRecursiveReasoningModel_ACTV1InnerCarry(
z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
z_L1=torch.where(reset_flag.view(-1, 1, 1), self.L1_init, carry.z_L1),
z_L2=torch.where(reset_flag.view(-1, 1, 1), self.L2_init, carry.z_L2),
z_L3=torch.where(reset_flag.view(-1, 1, 1), self.L3_init, carry.z_L3),
z_L4=torch.where(reset_flag.view(-1, 1, 1), self.L4_init, carry.z_L4),
z_L5=torch.where(reset_flag.view(-1, 1, 1), self.L5_init, carry.z_L5),
z_L6=torch.where(reset_flag.view(-1, 1, 1), self.L6_init, carry.z_L6),
)
def forward(self, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
seq_info = dict(
cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
)
# Input encoding
input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
# Forward iterations
z_H, z_L = carry.z_H, [carry.z_L1, carry.z_L2, carry.z_L3, carry.z_L4, carry.z_L5, carry.z_L6]
# H_cycles-1 without grad
with torch.no_grad():
for _H_step in range(self.config.H_cycles-1):
for _L_step in range(self.config.L_cycles):
z_L_ = z_L[0] + z_L[1] + z_L[2] + z_L[3] + z_L[4] + z_L[5]
z_L[_L_step] = self.L_level(z_L_, z_H + input_embeddings, **seq_info)
z_L_ = z_L[0] + z_L[1] + z_L[2] + z_L[3] + z_L[4] + z_L[5]
z_H = self.L_level(z_H, z_L_, **seq_info)
# 1 with grad
for _L_step in range(self.config.L_cycles):
z_L_ = z_L[0] + z_L[1] + z_L[2] + z_L[3] + z_L[4] + z_L[5]
z_L[_L_step] = self.L_level(z_L_, z_H + input_embeddings, **seq_info)
z_L_ = z_L[0] + z_L[1] + z_L[2] + z_L[3] + z_L[4] + z_L[5]
z_H = self.L_level(z_H, z_L_, **seq_info)
# LM Outputs
new_carry = TinyRecursiveReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L1=z_L[0].detach(), z_L2=z_L[1].detach(), z_L3=z_L[2].detach(), z_L4=z_L[3].detach(), z_L5=z_L[4].detach(), z_L6=z_L[5].detach()) # New carry no grad
output = self.lm_head(z_H)[:, self.puzzle_emb_len:]
q_logits = self.q_head(z_H[:, 0]).to(torch.float32) # Q-head; uses the first puzzle_emb position
return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
class TinyRecursiveReasoningModel_ACTV1(nn.Module):
"""ACT wrapper."""
def __init__(self, config_dict: dict):
super().__init__()
self.config = TinyRecursiveReasoningModel_ACTV1Config(**config_dict)
self.inner = TinyRecursiveReasoningModel_ACTV1_Inner(self.config)
@property
def puzzle_emb(self):
return self.inner.puzzle_emb
def initial_carry(self, batch: Dict[str, torch.Tensor]):
batch_size = batch["inputs"].shape[0]
return TinyRecursiveReasoningModel_ACTV1Carry(
inner_carry=self.inner.empty_carry(batch_size), # Empty is expected; it will be reset on the first pass since all sequences start halted.
steps=torch.zeros((batch_size, ), dtype=torch.int32),
halted=torch.ones((batch_size, ), dtype=torch.bool), # Default to halted
current_data={k: torch.empty_like(v) for k, v in batch.items()}
)
def forward(self, carry: TinyRecursiveReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
# Update data, carry (removing halted sequences)
new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
new_steps = torch.where(carry.halted, 0, carry.steps)
new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
# Forward inner model
new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)
outputs = {
"logits": logits,
"q_halt_logits": q_halt_logits,
"q_continue_logits": q_continue_logits
}
with torch.no_grad():
# Step
new_steps = new_steps + 1
is_last_step = new_steps >= self.config.halt_max_steps
halted = is_last_step
# if training, and ACT is enabled
if self.training and (self.config.halt_max_steps > 1):
# Halt signal
# NOTE: During evaluation, always use max steps; this guarantees the same halting steps within a batch for batching purposes
if self.config.no_ACT_continue:
halted = halted | (q_halt_logits > 0)
else:
halted = halted | (q_halt_logits > q_continue_logits)
# Exploration
min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
halted = halted & (new_steps >= min_halt_steps)
if not self.config.no_ACT_continue:
# Compute target Q
# NOTE: No replay buffer and target networks for computing target Q-value.
# As batch_size is large, there are many parallel envs.
# Similar concept as PQN https://arxiv.org/abs/2407.04811
_, _, (next_q_halt_logits, next_q_continue_logits) = self.inner(new_inner_carry, new_current_data)
outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))
return TinyRecursiveReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
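A scalar stand-in for the multi-latent schedule above; this variant appears to assume L_cycles matches the six latents, so each z_Li is rewritten once per cycle from the sum of all six:

def multi_latent_cycle(z_H, z_L, L_level, x):
    for k in range(len(z_L)):
        z_L[k] = L_level(sum(z_L), z_H + x)  # rewrite one latent from the sum
    return L_level(z_H, sum(z_L)), z_L       # then refine z_H from the sum

z_H, z_L = 0.0, [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
z_H, z_L = multi_latent_cycle(z_H, z_L, lambda a, b: 0.5 * (a + b), 1.0)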


@@ -0,0 +1,294 @@
from typing import Tuple, List, Dict, Optional
from dataclasses import dataclass
import math
import torch
import copy
import torch.nn.functional as F
from torch import nn
from pydantic import BaseModel
import random
from models.common import trunc_normal_init_
from models.layers import rms_norm, LinearSwish, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
from models.sparse_embedding import CastedSparseEmbedding
IGNORE_LABEL_ID = -100
@dataclass
class TinyRecursiveReasoningModel_ACTV1InnerCarry:
z_L: torch.Tensor
@dataclass
class TinyRecursiveReasoningModel_ACTV1Carry:
inner_carry: TinyRecursiveReasoningModel_ACTV1InnerCarry
steps: torch.Tensor
halted: torch.Tensor
current_data: Dict[str, torch.Tensor]
class TinyRecursiveReasoningModel_ACTV1Config(BaseModel):
batch_size: int
seq_len: int
puzzle_emb_ndim: int = 0
num_puzzle_identifiers: int
vocab_size: int
H_cycles: int
L_cycles: int
H_layers: int # ignored
L_layers: int
# Transformer config
hidden_size: int
expansion: float
num_heads: int
pos_encodings: str
rms_norm_eps: float = 1e-5
rope_theta: float = 10000.0
# Halting Q-learning config
halt_max_steps: int
halt_exploration_prob: float
forward_dtype: str = "bfloat16"
# Alexia: added
mlp_t: bool = False # use mlp on L instead of transformer
    puzzle_emb_len: int = 16  # if non-zero, the puzzle-embedding length is fixed to this value (otherwise it is derived from puzzle_emb_ndim)
    no_ACT_continue: bool = True  # No ACT continue loss; only use the sigmoid of the halt logit, which makes much more sense
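# Illustrative configuration sketch (hypothetical values, e.g. a 9x9 Sudoku-style
# task with 81 cells and a small vocabulary):
# config = TinyRecursiveReasoningModel_ACTV1Config(
#     batch_size=256, seq_len=81, num_puzzle_identifiers=1000, vocab_size=11,
#     H_cycles=3, L_cycles=6, H_layers=0, L_layers=2,
#     hidden_size=512, expansion=4.0, num_heads=8, pos_encodings="rope",
#     halt_max_steps=16, halt_exploration_prob=0.1,
# )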
class TinyRecursiveReasoningModel_ACTV1Block(nn.Module):
def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
super().__init__()
self.config = config
if self.config.mlp_t:
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len
self.mlp_t = SwiGLU(
hidden_size=self.config.seq_len + self.puzzle_emb_len, # L
expansion=config.expansion,
)
else:
self.self_attn = Attention(
hidden_size=config.hidden_size,
head_dim=config.hidden_size // config.num_heads,
num_heads=config.num_heads,
num_key_value_heads=config.num_heads,
causal=False
)
self.mlp = SwiGLU(
hidden_size=config.hidden_size,
expansion=config.expansion,
)
self.norm_eps = config.rms_norm_eps
def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
# B, L, D = hidden_states.shape
# Post Norm
if self.config.mlp_t:
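            # Token-mixing MLP: transpose to [B, D, L] so the SwiGLU mixes
            # information across sequence positions rather than channels
            # (similar in spirit to MLP-Mixer token mixing), then transpose back.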
hidden_states = hidden_states.transpose(1,2)
out = self.mlp_t(hidden_states)
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
hidden_states = hidden_states.transpose(1,2)
else:
# Self Attention
hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
# Fully Connected
out = self.mlp(hidden_states)
hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps)
return hidden_states
class TinyRecursiveReasoningModel_ACTV1ReasoningModule(nn.Module):
def __init__(self, layers: List[TinyRecursiveReasoningModel_ACTV1Block]):
super().__init__()
self.layers = torch.nn.ModuleList(layers)
def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
for layer in self.layers:
hidden_states = layer(hidden_states=hidden_states, **kwargs)
return hidden_states
class TinyRecursiveReasoningModel_ACTV1_Inner(nn.Module):
def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
super().__init__()
self.config = config
self.forward_dtype = getattr(torch, self.config.forward_dtype)
# I/O
self.embed_scale = math.sqrt(self.config.hidden_size)
embed_init_std = 1.0 / self.embed_scale
self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len # ceil div
if self.config.puzzle_emb_ndim > 0:
# Zero init puzzle embeddings
self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
# LM Blocks
if self.config.pos_encodings == "rope":
self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
base=self.config.rope_theta)
elif self.config.pos_encodings == "learned":
self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
else:
pass
# Reasoning Layers
self.L_level = TinyRecursiveReasoningModel_ACTV1ReasoningModule(layers=[TinyRecursiveReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])
# Initial states
self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
# Q head special init
# Init Q to (almost) zero for faster learning during bootstrapping
with torch.no_grad():
self.q_head.weight.zero_()
self.q_head.bias.fill_(-5) # type: ignore
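            # With zeroed weights the Q logits start at the bias value -5;
            # sigmoid(-5) ~= 0.0067, so the model almost never halts early at first.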
def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
# Token embedding
embedding = self.embed_tokens(input.to(torch.int32))
# Puzzle embeddings
if self.config.puzzle_emb_ndim > 0:
puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
if pad_count > 0:
puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
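            # e.g. with hidden_size=512, puzzle_emb_ndim=512, puzzle_emb_len=16:
            # the 512-dim puzzle vector is right-padded to 16*512 values and
            # reshaped into 16 prefix tokens prepended to the token embeddings.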
# Position embeddings
if self.config.pos_encodings == "learned":
# scale by 1/sqrt(2) to maintain forward variance
embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
# Scale
return self.embed_scale * embedding
def empty_carry(self, batch_size: int):
return TinyRecursiveReasoningModel_ACTV1InnerCarry(
z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
)
def reset_carry(self, reset_flag: torch.Tensor, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry):
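        # Sequences flagged for reset get the learned initial state L_init
        # (broadcast over batch and sequence positions); the rest keep their carry.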
return TinyRecursiveReasoningModel_ACTV1InnerCarry(
z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
)
def forward(self, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
seq_info = dict(
cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
)
# Input encoding
input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
# Forward iterations
it = 0
z_L = carry.z_L
# H_cycles-1 without grad
with torch.no_grad():
for _H_step in range(self.config.H_cycles-1):
for _L_step in range(self.config.L_cycles):
z_L = self.L_level(z_L + input_embeddings, **seq_info)
z_L = self.L_level(z_L, **seq_info)
# 1 with grad
for _L_step in range(self.config.L_cycles):
z_L = self.L_level(z_L + input_embeddings, **seq_info)
z_L = self.L_level(z_L, **seq_info)
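        # Only the final H cycle above carries gradient; truncating backprop
        # through the earlier cycles keeps activation memory roughly constant in
        # H_cycles while the recursion still deepens the effective computation.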
z_out = z_L
# LM Outputs
new_carry = TinyRecursiveReasoningModel_ACTV1InnerCarry(z_L=z_L.detach()) # New carry no grad
output = self.lm_head(z_out)[:, self.puzzle_emb_len:]
q_logits = self.q_head(z_out[:, 0]).to(torch.float32) # Q-head; uses the first puzzle_emb position
return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
class TinyRecursiveReasoningModel_ACTV1(nn.Module):
"""ACT wrapper."""
def __init__(self, config_dict: dict):
super().__init__()
self.config = TinyRecursiveReasoningModel_ACTV1Config(**config_dict)
self.inner = TinyRecursiveReasoningModel_ACTV1_Inner(self.config)
@property
def puzzle_emb(self):
return self.inner.puzzle_emb
def initial_carry(self, batch: Dict[str, torch.Tensor]):
batch_size = batch["inputs"].shape[0]
return TinyRecursiveReasoningModel_ACTV1Carry(
            inner_carry=self.inner.empty_carry(batch_size),  # Empty is expected; it will be reset on the first pass since all sequences start halted.
steps=torch.zeros((batch_size, ), dtype=torch.int32),
halted=torch.ones((batch_size, ), dtype=torch.bool), # Default to halted
current_data={k: torch.empty_like(v) for k, v in batch.items()}
)
def forward(self, carry: TinyRecursiveReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
# Update data, carry (removing halted sequences)
new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
new_steps = torch.where(carry.halted, 0, carry.steps)
new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
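        # Halted slots just pulled a fresh sample from the incoming batch; slots
        # still running keep the data they were already working on.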
# Forward inner model
new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)
outputs = {
"logits": logits,
"q_halt_logits": q_halt_logits,
"q_continue_logits": q_continue_logits
}
with torch.no_grad():
# Step
new_steps = new_steps + 1
is_last_step = new_steps >= self.config.halt_max_steps
halted = is_last_step
# if training, and ACT is enabled
if self.training and (self.config.halt_max_steps > 1):
# Halt signal
                # NOTE: During evaluation we always run for halt_max_steps, which guarantees the same number of steps for every sequence in a batch (for batching purposes)
if self.config.no_ACT_continue:
halted = halted | (q_halt_logits > 0)
else:
halted = halted | (q_halt_logits > q_continue_logits)
# Exploration
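                # With probability halt_exploration_prob, enforce a random minimum
                # number of steps in [2, halt_max_steps] so the Q-head is trained
                # on trajectories of varied lengths.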
min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
halted = halted & (new_steps >= min_halt_steps)
if not self.config.no_ACT_continue:
# Compute target Q
# NOTE: No replay buffer and target networks for computing target Q-value.
                    # As batch_size is large, there are many parallel envs.
# Similar concept as PQN https://arxiv.org/abs/2407.04811
                    _, _, (next_q_halt_logits, next_q_continue_logits) = self.inner(new_inner_carry, new_current_data)
outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))
return TinyRecursiveReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs

132
models/sparse_embedding.py Normal file
View File

@@ -0,0 +1,132 @@
from typing import Union
import torch
from torch import nn
import torch.distributed as dist
from torch.optim.optimizer import Optimizer, ParamsT
from models.common import trunc_normal_init_
class CastedSparseEmbedding(nn.Module):
def __init__(self, num_embeddings: int, embedding_dim: int, batch_size: int, init_std: float, cast_to: torch.dtype):
super().__init__()
self.cast_to = cast_to
# Real Weights
# Truncated LeCun normal init
self.weights = nn.Buffer(
trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std), persistent=True
)
# Local weights and IDs
# Local embeddings, with gradient, not persistent
self.local_weights = nn.Buffer(torch.zeros(batch_size, embedding_dim, requires_grad=True), persistent=False)
# Local embedding IDs, not persistent
self.local_ids = nn.Buffer(torch.zeros(batch_size, dtype=torch.int32), persistent=False)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
if not self.training:
# Test mode, no gradient
return self.weights[inputs].to(self.cast_to)
# Training mode, fill puzzle embedding from weights
with torch.no_grad():
self.local_weights.copy_(self.weights[inputs])
self.local_ids.copy_(inputs)
return self.local_weights.to(self.cast_to)
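# Gradients flow only into `local_weights` (a dense [batch_size, embedding_dim]
# buffer); the full `weights` table never receives autograd gradients. The
# optimizer below gathers these per-slot gradients by puzzle ID and applies
# sign-SGD updates directly to the matching rows of the table.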
class CastedSparseEmbeddingSignSGD_Distributed(Optimizer):
def __init__(
self,
params: ParamsT,
world_size: int,
lr: Union[float, torch.Tensor] = 1e-3,
weight_decay: float = 1e-2,
):
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = dict(
lr=lr,
weight_decay=weight_decay,
world_size=world_size
)
super().__init__(params, defaults)
@torch.no_grad
def step(self, closure=None): # type: ignore
for group in self.param_groups:
# Find the sparse embedding weights
local_weights_grad = None
local_ids = None
weights = None
assert len(group["params"]) == 3
for p in group["params"]:
if p.requires_grad:
local_weights_grad = p.grad
elif p.ndim == 1:
local_ids = p
elif p.ndim == 2:
weights = p
else:
assert False
assert local_ids is not None
assert weights is not None
# Apply SignSGD
# Adam ≈ SignSGD if gradient is very sparse
if local_weights_grad is not None:
_sparse_emb_signsgd_dist(
local_weights_grad,
local_ids,
weights,
lr=group["lr"],
weight_decay=group["weight_decay"],
world_size=group["world_size"]
)
def _sparse_emb_signsgd_dist(
local_weights_grad: torch.Tensor,
local_ids: torch.Tensor,
weights: torch.Tensor,
lr: float,
weight_decay: float,
world_size: int
) -> None:
N, D = local_weights_grad.shape
# All-gather
all_weights_grad = local_weights_grad
all_ids = local_ids
if world_size > 1:
all_weights_grad = torch.empty((world_size * N, D), dtype=local_weights_grad.dtype, device=local_weights_grad.device)
all_ids = torch.empty(world_size * N, dtype=local_ids.dtype, device=local_ids.device)
dist.all_gather_into_tensor(all_weights_grad, local_weights_grad)
dist.all_gather_into_tensor(all_ids, local_ids)
# Unique
grad_ids, inv = all_ids.unique(return_inverse=True)
grad = torch.zeros((grad_ids.shape[0], D), dtype=all_weights_grad.dtype, device=all_weights_grad.device)
grad.scatter_add_(0, inv.unsqueeze(-1).expand(-1, D), all_weights_grad)
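    # Rows referenced by several slots/replicas have their gradients summed here;
    # `inv` maps each gathered row to its position in the deduplicated `grad_ids`.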
# SignSGD with decoupled weight decay
p = weights[grad_ids]
p.mul_(1.0 - lr * weight_decay).add_(torch.sign(grad), alpha=-lr)
# Write updated slices back
weights[grad_ids] = p