from typing import Optional, Any, Sequence, List from dataclasses import dataclass import os import math import yaml import shutil import copy import torch import torch.distributed as dist from torch import nn from torch.utils.data import DataLoader import tqdm import wandb import coolname import hydra import pydantic from omegaconf import DictConfig from adam_atan2 import AdamATan2 from puzzle_dataset import PuzzleDataset, PuzzleDatasetConfig, PuzzleDatasetMetadata from utils.functions import load_model_class, get_model_source_path from models.sparse_embedding import CastedSparseEmbeddingSignSGD_Distributed from models.ema import EMAHelper class LossConfig(pydantic.BaseModel): model_config = pydantic.ConfigDict(extra='allow') name: str class ArchConfig(pydantic.BaseModel): model_config = pydantic.ConfigDict(extra='allow') name: str loss: LossConfig class EvaluatorConfig(pydantic.BaseModel): model_config = pydantic.ConfigDict(extra="allow") name: str class PretrainConfig(pydantic.BaseModel): # Config arch: ArchConfig # Data data_paths: List[str] data_paths_test: List[str] = [] # Evaluators evaluators: List[EvaluatorConfig] = [] # Hyperparams global_batch_size: int epochs: int lr: float lr_min_ratio: float lr_warmup_steps: int weight_decay: float beta1: float beta2: float # Puzzle embedding puzzle_emb_lr: float puzzle_emb_weight_decay: float # Names project_name: Optional[str] = None run_name: Optional[str] = None load_checkpoint: Optional[str] = None checkpoint_path: Optional[str] = None # Extras seed: int = 0 checkpoint_every_eval: bool = False eval_interval: Optional[int] = None min_eval_interval: Optional[int] = 0 # when to start eval eval_save_outputs: List[str] = [] ema: bool = False # use Exponential-Moving-Average ema_rate: float = 0.999 # EMA-rate freeze_weights: bool = False # If True, freeze weights and only learn the embeddings @dataclass class TrainState: model: nn.Module optimizers: Sequence[torch.optim.Optimizer] optimizer_lrs: Sequence[float] carry: Any step: int total_steps: int def create_dataloader(config: PretrainConfig, split: str, rank: int, world_size: int, **kwargs): dataset = PuzzleDataset(PuzzleDatasetConfig( seed=config.seed, dataset_paths=config.data_paths_test if len(config.data_paths_test)>0 and split=="test" else config.data_paths, rank=rank, num_replicas=world_size, **kwargs ), split=split) dataloader = DataLoader( dataset, batch_size=None, num_workers=1, prefetch_factor=8, pin_memory=True, persistent_workers=True ) return dataloader, dataset.metadata def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, rank: int, world_size: int): model_cfg = dict( **config.arch.__pydantic_extra__, # type: ignore batch_size=config.global_batch_size // world_size, vocab_size=train_metadata.vocab_size, seq_len=train_metadata.seq_len, num_puzzle_identifiers=train_metadata.num_puzzle_identifiers, causal=False # Non-autoregressive ) # Instantiate model with loss head model_cls = load_model_class(config.arch.name) loss_head_cls = load_model_class(config.arch.loss.name) with torch.device("cuda"): model: nn.Module = model_cls(model_cfg) print(model) model = loss_head_cls(model, **config.arch.loss.__pydantic_extra__) # type: ignore if "DISABLE_COMPILE" not in os.environ: model = torch.compile(model) # type: ignore # Load checkpoint if rank == 0: load_checkpoint(model, config) # Broadcast parameters from rank 0 if world_size > 1: with torch.no_grad(): for param in list(model.parameters()) + list(model.buffers()): dist.broadcast(param, src=0) # Optimizers and lr if config.arch.puzzle_emb_ndim == 0: optimizers = [ AdamATan2( model.parameters(), lr=0, # Needs to be set by scheduler weight_decay=config.weight_decay, betas=(config.beta1, config.beta2) ) ] optimizer_lrs = [ config.lr ] elif config.freeze_weights: optimizers = [ CastedSparseEmbeddingSignSGD_Distributed( model.model.puzzle_emb.buffers(), # type: ignore lr=0, # Needs to be set by scheduler weight_decay=config.puzzle_emb_weight_decay, world_size=world_size ) ] optimizer_lrs = [ config.puzzle_emb_lr ] else: optimizers = [ CastedSparseEmbeddingSignSGD_Distributed( model.model.puzzle_emb.buffers(), # type: ignore lr=0, # Needs to be set by scheduler weight_decay=config.puzzle_emb_weight_decay, world_size=world_size ), AdamATan2( model.parameters(), lr=0, # Needs to be set by scheduler weight_decay=config.weight_decay, betas=(config.beta1, config.beta2) ) ] optimizer_lrs = [ config.puzzle_emb_lr, config.lr ] return model, optimizers, optimizer_lrs def mix_weights_direct(device, alpha, net, nets): sd = [] for i in range(len(nets)): sd += [nets[i].state_dict()] sd_alpha = {} for k in sd[0].keys(): comb_net = alpha[0]*sd[0][k].to(device) for i in range(1,len(nets)): comb_net += alpha[i]*sd[i][k].to(device) sd_alpha[k] = comb_net net.load_state_dict(sd_alpha) return net def cosine_schedule_with_warmup_lr_lambda( current_step: int, *, base_lr: float, num_warmup_steps: int, num_training_steps: int, min_ratio: float = 0.0, num_cycles: float = 0.5 ): if current_step < num_warmup_steps: return base_lr * float(current_step) / float(max(1, num_warmup_steps)) progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) return base_lr * (min_ratio + max(0.0, (1 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))) def init_train_state(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, rank: int, world_size: int): # Estimated total training steps total_steps = int(config.epochs * train_metadata.total_groups * train_metadata.mean_puzzle_examples / config.global_batch_size) # Model model, optimizers, optimizer_lrs = create_model(config, train_metadata, rank=rank, world_size=world_size) return TrainState( step=0, total_steps=total_steps, model=model, optimizers=optimizers, optimizer_lrs=optimizer_lrs, carry=None ) def save_train_state(config: PretrainConfig, train_state: TrainState): # FIXME: Only saved model. if config.checkpoint_path is None: return os.makedirs(config.checkpoint_path, exist_ok=True) torch.save(train_state.model.state_dict(), os.path.join(config.checkpoint_path, f"step_{train_state.step}")) def load_checkpoint(model: nn.Module, config: PretrainConfig): if config.load_checkpoint is not None: print(f"Loading checkpoint {config.load_checkpoint}") # Load state dict state_dict = torch.load(config.load_checkpoint, map_location="cuda") # Resize and reset puzzle emb if needed puzzle_emb_name = "_orig_mod.model.inner.puzzle_emb.weights" expected_shape: torch.Size = model.model.puzzle_emb.weights.shape # type: ignore if puzzle_emb_name in state_dict: puzzle_emb = state_dict[puzzle_emb_name] if puzzle_emb.shape != expected_shape: print(f"Resetting puzzle embedding as shape is different. Found {puzzle_emb.shape}, Expected {expected_shape}") # Re-initialize using mean state_dict[puzzle_emb_name] = ( torch.mean(puzzle_emb, dim=0, keepdim=True).expand(expected_shape).contiguous() ) model.load_state_dict(state_dict, assign=True) def compute_lr(base_lr: float, config: PretrainConfig, train_state: TrainState): return cosine_schedule_with_warmup_lr_lambda( current_step=train_state.step, base_lr=base_lr, num_warmup_steps=round(config.lr_warmup_steps), num_training_steps=train_state.total_steps, min_ratio=config.lr_min_ratio ) def create_evaluators(config: PretrainConfig, eval_metadata: PuzzleDatasetMetadata) -> List[Any]: data_paths =config.data_paths_test if len(config.data_paths_test)>0 else config.data_paths # Initialize evaluators evaluators = [] for cfg in config.evaluators: for data_path in data_paths: cls = load_model_class(cfg.name, "evaluators.")( data_path=data_path, eval_metadata=eval_metadata, **cfg.__pydantic_extra__ ) # type: ignore evaluators.append(cls) return evaluators def train_batch(config: PretrainConfig, train_state: TrainState, batch: Any, global_batch_size: int, rank: int, world_size: int): train_state.step += 1 if train_state.step > train_state.total_steps: # At most train_total_steps return # To device batch = {k: v.cuda() for k, v in batch.items()} # Init carry if it is None if train_state.carry is None: with torch.device("cuda"): train_state.carry = train_state.model.initial_carry(batch) # type: ignore # Forward train_state.carry, loss, metrics, _, _ = train_state.model(carry=train_state.carry, batch=batch, return_keys=[]) ((1 / global_batch_size) * loss).backward() # Allreduce if world_size > 1: for param in train_state.model.parameters(): if param.grad is not None: dist.all_reduce(param.grad) # Apply optimizer lr_this_step = None for optim, base_lr in zip(train_state.optimizers, train_state.optimizer_lrs): lr_this_step = compute_lr(base_lr, config, train_state) for param_group in optim.param_groups: param_group['lr'] = lr_this_step optim.step() optim.zero_grad() # Reduce metrics if len(metrics): assert not any(v.requires_grad for v in metrics.values()) metric_keys = list(sorted(metrics.keys())) # Sort keys to guarantee all processes use the same order. # Reduce and reconstruct metric_values = torch.stack([metrics[k] for k in metric_keys]) if world_size > 1: dist.reduce(metric_values, dst=0) if rank == 0: metric_values = metric_values.cpu().numpy() reduced_metrics = {k: metric_values[i] for i, k in enumerate(metric_keys)} # Postprocess count = max(reduced_metrics["count"], 1) # Avoid NaNs reduced_metrics = {f"train/{k}": v / (global_batch_size if k.endswith("loss") else count) for k, v in reduced_metrics.items()} reduced_metrics["train/lr"] = lr_this_step return reduced_metrics def evaluate( config: PretrainConfig, train_state: TrainState, eval_loader: torch.utils.data.DataLoader, eval_metadata: PuzzleDatasetMetadata, evaluators: List[Any], rank: int, world_size: int, cpu_group: Optional[dist.ProcessGroup], ): reduced_metrics = None with torch.inference_mode(): return_keys = set(config.eval_save_outputs) for evaluator in evaluators: evaluator.begin_eval() return_keys.update(evaluator.required_outputs) # Run evaluation set_ids = {k: idx for idx, k in enumerate(eval_metadata.sets)} save_preds = {} metric_keys = [] metric_values = None carry = None processed_batches = 0 for set_name, batch, global_batch_size in eval_loader: processed_batches += 1 if rank == 0: print(f"Processing batch {processed_batches}: {set_name}") # To device batch = {k: v.cuda() for k, v in batch.items()} with torch.device("cuda"): carry = train_state.model.initial_carry(batch) # type: ignore # Forward inference_steps = 0 while True: carry, loss, metrics, preds, all_finish = train_state.model( carry=carry, batch=batch, return_keys=return_keys ) inference_steps += 1 if all_finish: break if rank == 0: print(f" Completed inference in {inference_steps} steps") for collection in (batch, preds): for k, v in collection.items(): if k in config.eval_save_outputs: save_preds.setdefault(k, []) save_preds[k].append(v.cpu()) # Move to CPU for saving GPU memory for evaluator in evaluators: evaluator.update_batch(batch, preds) del carry, loss, preds, batch, all_finish # Aggregate metrics set_id = set_ids[set_name] if metric_values is None: metric_keys = list( sorted(metrics.keys()) ) # Sort keys to guarantee all processes use the same order. metric_values = torch.zeros( (len(set_ids), len(metrics.values())), dtype=torch.float32, device="cuda" ) metric_values[set_id] += torch.stack([metrics[k] for k in metric_keys]) del metrics # concatenate save preds save_preds = {k: torch.cat(v, dim=0) for k, v in save_preds.items()} # Save preds if config.checkpoint_path is not None and len(save_preds): # Each rank save predictions independently os.makedirs(os.path.dirname(config.checkpoint_path), exist_ok=True) torch.save( save_preds, os.path.join(config.checkpoint_path, f"step_{train_state.step}_all_preds.{rank}") ) del save_preds # Reduce to rank 0 if metric_values is not None: if world_size > 1: dist.reduce(metric_values, dst=0) if rank == 0: reduced_metrics = metric_values.cpu().numpy() reduced_metrics = { set_name: { metric_name: reduced_metrics[set_id, metric_id] for metric_id, metric_name in enumerate(metric_keys) } for set_id, set_name in enumerate(set_ids) } # Postprocess for set_name, m in reduced_metrics.items(): count = m.pop("count") reduced_metrics[set_name] = {k: v / count for k, v in m.items()} # Run evaluators if rank == 0: print(f"\nRunning {len(evaluators)} evaluator(s)...") for i, evaluator in enumerate(evaluators): if rank == 0: print(f"Running evaluator {i+1}/{len(evaluators)}: {evaluator.__class__.__name__}") # Path for saving evaluator_save_path = None if config.checkpoint_path is not None: evaluator_save_path = os.path.join( config.checkpoint_path, f"evaluator_{evaluator.__class__.__name__}_step_{train_state.step}", ) os.makedirs(evaluator_save_path, exist_ok=True) # Run and log metrics = evaluator.result(evaluator_save_path, rank=rank, world_size=world_size, group=cpu_group) if rank == 0 and metrics is not None: if reduced_metrics is None: reduced_metrics = {} reduced_metrics.update(metrics) print(f" Completed {evaluator.__class__.__name__}") if rank == 0: print("All evaluators completed!") return reduced_metrics def save_code_and_config(config: PretrainConfig): if config.checkpoint_path is None or wandb.run is None: return os.makedirs(config.checkpoint_path, exist_ok=True) # Copy code code_list = [ get_model_source_path(config.arch.name), get_model_source_path(config.arch.loss.name) ] for code_file in code_list: if code_file is not None: code_name = os.path.basename(code_file) shutil.copy(code_file, os.path.join(config.checkpoint_path, code_name)) # Dump config as yaml config_file = os.path.join(config.checkpoint_path, "all_config.yaml") with open(config_file, "wt") as f: yaml.dump(config.model_dump(), f) # Log code wandb.run.log_code(config.checkpoint_path) def load_synced_config(hydra_config: DictConfig, rank: int, world_size: int) -> PretrainConfig: objects = [None] if rank == 0: config = PretrainConfig(**hydra_config) # type: ignore # Naming if config.project_name is None: config.project_name = f"{os.path.basename(config.data_paths[0]).capitalize()}-ACT-torch" if config.run_name is None: config.run_name = f"{config.arch.name.split('@')[-1]} {coolname.generate_slug(2)}" if config.checkpoint_path is None: config.checkpoint_path = os.path.join("checkpoints", config.project_name, config.run_name) objects = [config] if world_size > 1: dist.broadcast_object_list(objects, src=0) return objects[0] # type: ignore @hydra.main(config_path="config", config_name="cfg_pretrain", version_base=None) def launch(hydra_config: DictConfig): RANK = 0 WORLD_SIZE = 1 CPU_PROCESS_GROUP = None # Initialize distributed training if in distributed environment (e.g. torchrun) if "LOCAL_RANK" in os.environ: # Initialize distributed, default device and dtype dist.init_process_group(backend="nccl") RANK = dist.get_rank() WORLD_SIZE = dist.get_world_size() torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) # CPU GLOO process group CPU_PROCESS_GROUP = dist.new_group(backend="gloo") assert ( dist.get_rank(CPU_PROCESS_GROUP) == RANK and dist.get_world_size(CPU_PROCESS_GROUP) == WORLD_SIZE ) # Load sync'ed config config = load_synced_config(hydra_config, rank=RANK, world_size=WORLD_SIZE) # Seed RNGs to ensure consistency torch.random.manual_seed(config.seed + RANK) # Dataset train_epochs_per_iter = config.eval_interval if config.eval_interval is not None else config.epochs total_iters = config.epochs // train_epochs_per_iter assert config.epochs % train_epochs_per_iter == 0, "Eval interval must be a divisor of total epochs." train_loader, train_metadata = create_dataloader(config, "train", test_set_mode=False, epochs_per_iter=train_epochs_per_iter, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE) try: eval_loader, eval_metadata = create_dataloader(config, "test", test_set_mode=True, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE) except: print("NO EVAL DATA FOUND") eval_loader = eval_metadata = None try: evaluators = create_evaluators(config, eval_metadata) except: print("No evaluator found") evaluators = [] # Train state train_state = init_train_state(config, train_metadata, rank=RANK, world_size=WORLD_SIZE) # Progress bar and logger progress_bar = None ema_helper = None if RANK == 0: progress_bar = tqdm.tqdm(total=train_state.total_steps) wandb.init(project=config.project_name, name=config.run_name, config=config.model_dump(), settings=wandb.Settings(_disable_stats=True)) # type: ignore wandb.log({"num_params": sum(x.numel() for x in train_state.model.parameters())}, step=0) save_code_and_config(config) if config.ema: print('Setup EMA') ema_helper = EMAHelper(mu=config.ema_rate) ema_helper.register(train_state.model) # Training Loop for _iter_id in range(total_iters): print (f"[Rank {RANK}, World Size {WORLD_SIZE}]: Epoch {_iter_id * train_epochs_per_iter}") ############ Train Iter if RANK == 0: print("TRAIN") train_state.model.train() for set_name, batch, global_batch_size in train_loader: metrics = train_batch(config, train_state, batch, global_batch_size, rank=RANK, world_size=WORLD_SIZE) if RANK == 0 and metrics is not None: wandb.log(metrics, step=train_state.step) progress_bar.update(train_state.step - progress_bar.n) # type: ignore if config.ema: ema_helper.update(train_state.model) if _iter_id >= config.min_eval_interval: ############ Evaluation if RANK == 0: print("EVALUATE") if config.ema: print("SWITCH TO EMA") train_state_eval = copy.deepcopy(train_state) train_state_eval.model = ema_helper.ema_copy(train_state_eval.model) else: train_state_eval = train_state train_state_eval.model.eval() metrics = evaluate(config, train_state_eval, eval_loader, eval_metadata, evaluators, rank=RANK, world_size=WORLD_SIZE, cpu_group=CPU_PROCESS_GROUP) if RANK == 0 and metrics is not None: wandb.log(metrics, step=train_state.step) ############ Checkpointing if RANK == 0: print("SAVE CHECKPOINT") if RANK == 0 and (config.checkpoint_every_eval or (_iter_id == total_iters - 1)): save_train_state(config, train_state_eval) if config.ema: del train_state_eval # finalize if dist.is_initialized(): dist.destroy_process_group() wandb.finish() if __name__ == "__main__": launch()