"""Benchmark transformer model inference latency and throughput across batch sizes,
with optional FP16, CUDA Graphs, and Flash Attention optimizations."""

import argparse
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from torch import nn
from transformers import AutoModel
from tqdm import tqdm
from contextlib import contextmanager


@dataclass
class BenchmarkConfig:
    model_path: str
    batch_sizes: List[int]
    seq_length: int
    num_runs: int
    use_fp16: bool = True
    use_cuda_graphs: bool = False
    use_flash_attention: bool = False
    max_batch_size: int = 256  # Maximum batch size before splitting


class CUDAGraphContainer:
    """Container for managing CUDA graphs for different batch sizes."""

    def __init__(self, model: nn.Module, seq_length: int, max_batch_size: int):
        self.model = model
        self.seq_length = seq_length
        self.max_batch_size = max_batch_size
        self.graphs: Dict[int, "CUDAGraphWrapper"] = {}

    def get_or_create(self, batch_size: int) -> "CUDAGraphWrapper":
        # For CUDA graphs, we always use the actual batch size or max_batch_size
        effective_batch_size = min(batch_size, self.max_batch_size)
        if effective_batch_size not in self.graphs:
            self.graphs[effective_batch_size] = CUDAGraphWrapper(
                self.model, effective_batch_size, self.seq_length
            )
        return self.graphs[effective_batch_size]


class CUDAGraphWrapper:
    """Wrapper for CUDA graph capture and replay."""

    def __init__(self, model: nn.Module, batch_size: int, seq_length: int):
        self.model = model
        self.static_input = self._create_random_batch(batch_size, seq_length)
        self.static_attention_mask = torch.ones_like(self.static_input)

        # Warm up before capture so lazy initialization does not end up in the graph
        self._warmup()

        # Capture graph (inference only, so disable autograd during capture)
        self.graph = torch.cuda.CUDAGraph()
        with torch.no_grad(), torch.cuda.graph(self.graph):
            self.static_output = self.model(
                input_ids=self.static_input,
                attention_mask=self.static_attention_mask,
            )

    def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
        return torch.randint(
            0, 1000, (batch_size, seq_length), device="cuda", dtype=torch.long
        )

    def _warmup(self, num_warmup: int = 3):
        with torch.no_grad():
            for _ in range(num_warmup):
                self.model(
                    input_ids=self.static_input,
                    attention_mask=self.static_attention_mask,
                )
        # Ensure warm-up work has finished before graph capture begins
        torch.cuda.synchronize()

    def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):
        # Copy new inputs into the static buffers, replay the captured graph,
        # and return the (static) output object.
        self.static_input.copy_(input_ids)
        self.static_attention_mask.copy_(attention_mask)
        self.graph.replay()
        return self.static_output


class ModelOptimizer:
    """Applies various optimizations to the model."""

    @staticmethod
    def optimize(model: nn.Module, config: BenchmarkConfig) -> nn.Module:
        print("\nApplying model optimizations:")

        # Move to GPU
        model = model.cuda()
        print("- Model moved to GPU")

        # FP16
        if config.use_fp16:
            model = model.half()
            print("- Using FP16 precision")

        # Check if using SDPA (parse the CUDA version numerically; slicing the
        # version string breaks on versions like "11.8")
        cuda_version = torch.version.cuda
        if cuda_version and tuple(int(v) for v in cuda_version.split(".")[:2]) >= (11, 6):
            if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
                print("- Using PyTorch SDPA (scaled_dot_product_attention)")
                # No need to do anything as it's automatically enabled
            else:
                print("- PyTorch SDPA not available")

        # Flash Attention
        if config.use_flash_attention:
            try:
                import flash_attn  # imported only to check availability

                print("- Flash Attention 2 available")
                if hasattr(model.config, "attention_mode"):
                    model.config.attention_mode = "flash_attention_2"
                    print(" - Enabled Flash Attention 2 mode")
            except ImportError:
                print("- Flash Attention not available")

        # Optimize LayerNorm
        try:
            num_layernorms = 0
            for module in model.modules():
                if isinstance(module, torch.nn.LayerNorm):
                    module.forward = torch.jit.script(module.forward)
                    num_layernorms += 1
            if num_layernorms > 0:
                print(f"- Optimized {num_layernorms} LayerNorm modules with TorchScript")
        except Exception as e:
            print(f"- LayerNorm optimization failed: {e}")

        # Memory efficient attention
        try:
            from xformers.ops import memory_efficient_attention  # availability check

            model.enable_xformers_memory_efficient_attention()
            print("- Enabled xformers memory efficient attention")
        except (ImportError, AttributeError):
            print("- Xformers not available")

        model.eval()
        print("- Model set to eval mode")

        return model


class Timer:
    """Handles accurate GPU timing using CUDA events."""

    def __init__(self):
        self.start_event = torch.cuda.Event(enable_timing=True)
        self.end_event = torch.cuda.Event(enable_timing=True)

    @contextmanager
    def timing(self):
        self.start_event.record()
        yield
        self.end_event.record()
        self.end_event.synchronize()

    def elapsed_time(self) -> float:
        return self.start_event.elapsed_time(self.end_event) / 1000  # ms to seconds


class Benchmark:
    """Main benchmark runner."""

    def __init__(self, config: BenchmarkConfig):
        self.config = config
        self.model = self._load_model()
        self.cuda_graphs = (
            CUDAGraphContainer(self.model, config.seq_length, config.max_batch_size)
            if config.use_cuda_graphs
            else None
        )
        self.timer = Timer()

    def _load_model(self) -> nn.Module:
        print(f"Loading model from {self.config.model_path}...")
        model = AutoModel.from_pretrained(self.config.model_path)
        return ModelOptimizer.optimize(model, self.config)

    def _create_random_batch(self, batch_size: int) -> torch.Tensor:
        return torch.randint(
            0,
            1000,
            (batch_size, self.config.seq_length),
            device="cuda",
            dtype=torch.long,
        )

    def _run_inference(
        self,
        input_ids: torch.Tensor,
        cuda_graph_wrapper: Optional[CUDAGraphWrapper] = None,
    ) -> Tuple[float, object]:
        attention_mask = torch.ones_like(input_ids)
        original_batch_size = input_ids.shape[0]
        print(f"Original input_ids shape: {input_ids.shape}")

        # Split large batches to avoid OOM
        max_batch_size = self.config.max_batch_size
        if original_batch_size > max_batch_size:
            print(f"Splitting batch of size {original_batch_size} into chunks of {max_batch_size}")
            total_time = 0.0
            outputs = []

            with torch.no_grad():
                for i in range(0, original_batch_size, max_batch_size):
                    end_idx = min(i + max_batch_size, original_batch_size)
                    batch_slice = input_ids[i:end_idx]
                    mask_slice = attention_mask[i:end_idx]
                    print(f"Processing chunk {i // max_batch_size + 1}: shape {batch_slice.shape}")

                    # Use CUDA graph if available (with the smaller batch size)
                    chunk_cuda_graph = None
                    if cuda_graph_wrapper is not None:
                        chunk_cuda_graph = self.cuda_graphs.get_or_create(batch_slice.shape[0])

                    with self.timer.timing():
                        if chunk_cuda_graph is not None:
                            chunk_output = chunk_cuda_graph(batch_slice, mask_slice)
                        else:
                            chunk_output = self.model(
                                input_ids=batch_slice, attention_mask=mask_slice
                            )

                    total_time += self.timer.elapsed_time()
                    # Clone so the graph's static output buffer is not overwritten
                    # by the next replay before the chunks are concatenated.
                    outputs.append(chunk_output.last_hidden_state.clone())

            # Combine outputs
            combined_output = torch.cat(outputs, dim=0)
            print(f"Combined output shape: {combined_output.shape}")

            # Create a wrapper object similar to model output to maintain consistency
            class DummyOutput:
                def __init__(self, hidden_states):
                    self.last_hidden_state = hidden_states

            output = DummyOutput(combined_output)
            return total_time, output
        else:
            # Process normally for small batches
            with torch.no_grad(), self.timer.timing():
                if cuda_graph_wrapper is not None:
                    output = cuda_graph_wrapper(input_ids, attention_mask)
                else:
                    output = self.model(input_ids=input_ids, attention_mask=attention_mask)

            print(f"Output shape: {output.last_hidden_state.shape}")
            return self.timer.elapsed_time(), output

    def run(self) -> Dict[int, Dict[str, float]]:
        results = {}

        for batch_size in self.config.batch_sizes:
            print(f"\nTesting batch size: {batch_size}")
            times = []

            # Get or create CUDA graph for this batch size
            cuda_graph_wrapper = None
            if self.cuda_graphs is not None:
                if batch_size <= self.config.max_batch_size:
                    cuda_graph_wrapper = self.cuda_graphs.get_or_create(batch_size)
                else:
                    # For large batches, we'll use the max_batch_size graph in chunks;
                    # any non-None value signals that CUDA graphs should be used.
                    cuda_graph_wrapper = True

            # Pre-allocate input tensor
            input_ids = self._create_random_batch(batch_size)

            # Run benchmark
            for run_idx in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
                elapsed_time, _ = self._run_inference(input_ids, cuda_graph_wrapper)
                times.append(elapsed_time)
                print(f"Run {run_idx + 1}: {elapsed_time:.4f}s")

            # Calculate statistics
            avg_time = np.mean(times)
            std_time = np.std(times)
            throughput = batch_size / avg_time

            results[batch_size] = {
                "avg_time": avg_time,
                "std_time": std_time,
                "throughput": throughput,
            }

            print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
            print(f"Throughput: {throughput:.2f} sequences/second")

        return results


def main():
    parser = argparse.ArgumentParser(description="Model Inference Benchmark")
    parser.add_argument(
        "--model_path",
        type=str,
        default="facebook/contriever",
        help="Path to the model",
    )
    parser.add_argument(
        "--batch_sizes",
        type=str,
        default="1,2,4,8,16,32,64,128,256,512,1024,2048,4096",
        help="Comma-separated list of batch sizes",
    )
    parser.add_argument(
        "--seq_length",
        type=int,
        default=256,
        help="Sequence length for input",
    )
    parser.add_argument(
        "--num_runs",
        type=int,
        default=5,
        help="Number of runs for each batch size",
    )
    parser.add_argument(
        "--no_fp16",
        action="store_true",
        help="Disable FP16 inference",
    )
    parser.add_argument(
        "--use_cuda_graphs",
        action="store_true",
        help="Enable CUDA Graphs optimization",
    )
    parser.add_argument(
        "--use_flash_attention",
        action="store_true",
        help="Enable Flash Attention 2 if available",
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        default=256,
        help="Maximum batch size before splitting to prevent OOM",
    )
    args = parser.parse_args()

    config = BenchmarkConfig(
        model_path=args.model_path,
        batch_sizes=[int(bs) for bs in args.batch_sizes.split(",")],
        seq_length=args.seq_length,
        num_runs=args.num_runs,
        use_fp16=not args.no_fp16,
        use_cuda_graphs=args.use_cuda_graphs,
        use_flash_attention=args.use_flash_attention,
        max_batch_size=args.max_batch_size,
    )

    benchmark = Benchmark(config)
    results = benchmark.run()

    # Print overall summary
    print("\n===== BENCHMARK SUMMARY =====")
    print(f"Model: {config.model_path}")
    print(f"Sequence Length: {config.seq_length}")
    print(f"FP16: {config.use_fp16}")
    print(f"CUDA Graphs: {config.use_cuda_graphs}")
    print(f"Flash Attention: {config.use_flash_attention}")
    print(f"Max Batch Size: {config.max_batch_size}")
    print("\nResults:")
    print("\nBatch Size | Avg Time (s) | Throughput (seq/s)")
    print("-" * 50)
    for bs in sorted(results.keys()):
        r = results[bs]
        print(f"{bs:^10} | {r['avg_time']:^12.4f} | {r['throughput']:^17.2f}")


if __name__ == "__main__":
    main()
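
# Example invocation (a sketch, not part of the script's logic): the filename
# "benchmark.py" is an assumption, and the flags simply mirror the argparse
# options defined above. Requires a CUDA-capable GPU.
#
#   python benchmark.py \
#       --model_path facebook/contriever \
#       --batch_sizes 1,8,64,256,1024 \
#       --seq_length 256 \
#       --num_runs 5 \
#       --use_cuda_graphs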