Initial commit
This commit is contained in:
12
research/micro/analyze_HNSW.py
Normal file
12
research/micro/analyze_HNSW.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import faiss

# Load the serialized HNSW index from its fixed on-disk location.
hnsw_index = faiss.read_index(
    "/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/hnsw/hnsw_IP_M30_efC128.index",
    faiss.IO_FLAG_ONDISK_SAME_DIR,
)

# Total number of vectors stored in the index.
print(hnsw_index.ntotal)

# Neighbor statistics for level 0 of the graph. print_neighbor_stats() writes
# its report itself and returns nothing, so wrapping it in print() only added
# a stray "None" line after the statistics.
hnsw_index.hnsw.print_neighbor_stats(0)

# Write the level-0 degree distribution to a text file in the working directory.
# NOTE(review): save_degree_distribution is presumably provided by a patched
# faiss build -- confirm it exists in the faiss version this script runs with.
hnsw_index.hnsw.save_degree_distribution(0, "degree_distribution_HNSW_M30.txt")
|
||||
11
research/micro/analyze_NSG.py
Normal file
11
research/micro/analyze_NSG.py
Normal file
@@ -0,0 +1,11 @@
|
||||
import faiss

# Load the serialized NSG index from its fixed on-disk location.
nsg_index = faiss.read_index(
    "/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/nsg_R16.index",
    faiss.IO_FLAG_ONDISK_SAME_DIR,
)

# Total number of vectors stored in the index.
print(nsg_index.ntotal)

# Neighbor statistics of the graph (argument 0 is passed through unchanged).
# NOTE(review): print_neighbor_stats / save_degree_distribution on `.nsg` are
# presumably provided by a patched faiss build -- confirm before running.
print(nsg_index.nsg.print_neighbor_stats(0))

# Write the degree distribution to a text file in the working directory.
# NOTE(review): the output name says "R60" while the index loaded above is
# "nsg_R16" -- confirm which label is the intended one.
nsg_index.nsg.save_degree_distribution("degree_distribution_NSG_R60.txt")
|
||||
63
research/micro/bnbtest.py
Normal file
63
research/micro/bnbtest.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import time

import torch
import torch.nn as nn

# import bitsandbytes as bnb
from bitsandbytes.nn import Linear8bitLt

# NOTE(review): nothing in this script uses the torch_int names; the imports
# are kept so the script's dependencies stay exactly as they were.
import torch_int
from torch_int.nn.linear import W8A8BFP32OFP32Linear, W8A8B8O8Linear, W8A8B8O8LinearReLU

# Make float16 the default floating-point dtype, so the nn.Linear weights and
# the random input created below are half precision.
torch.set_default_dtype(torch.float16)

M = 2048  # in_features of the benchmarked linear layer
N = 2048  # out_features of the benchmarked linear layer

bsz = 2048  # number of rows in the benchmark input

# Reference model: a single half-precision linear layer.
fp16_model = nn.Sequential(
    nn.Linear(M, N),
    # nn.Linear(2048, 2048)
)

# Same architecture built from bitsandbytes 8-bit linear layers.
int8_model = nn.Sequential(
    Linear8bitLt(M, N, has_fp16_weights=False),
    # Linear8bitLt(2048, 2048, has_fp16_weights=False)
)

# Give both models identical weights before moving them to the GPU.
int8_model.load_state_dict(fp16_model.state_dict())
int8_model = int8_model.to(0)  # Quantization happens here
fp16_model = fp16_model.to(0)  # Move fp16 model to GPU as well

# Random input on GPU 0: `bsz` rows of `M` features each.
input_tensor = torch.randn(bsz, M, device=0)
|
||||
|
||||
# Speed test function
|
||||
def speed_test(model, input_tensor, name, num_iterations=100):
    """Time repeated calls of ``model(input_tensor)`` and report the average.

    Ten untimed warm-up calls are made first, then ``num_iterations`` timed
    calls. When CUDA is available the device is synchronized before the clock
    starts and before it stops, so queued GPU work is included in the
    measurement.

    Args:
        model: Callable invoked as ``model(input_tensor)``.
        input_tensor: Argument handed to ``model`` on every call.
        name: Label used in the printed report line.
        num_iterations: Number of timed calls; must be positive.

    Returns:
        Average wall-clock seconds per timed call.

    Raises:
        ValueError: If ``num_iterations`` is not positive.
    """
    if num_iterations <= 0:
        raise ValueError("num_iterations must be a positive integer")

    def _sync():
        # Synchronizing without a CUDA device raises, so only do it when one
        # exists; this keeps the helper usable for CPU-only callables.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    # Warmup
    for _ in range(10):
        model(input_tensor)

    # Actual timing. perf_counter() is monotonic and high resolution, unlike
    # time.time(), which can jump when the system clock is adjusted.
    _sync()
    start_time = time.perf_counter()

    for _ in range(num_iterations):
        model(input_tensor)

    _sync()
    end_time = time.perf_counter()

    avg_time = (end_time - start_time) / num_iterations
    print(f"{name} model: {avg_time:.6f} seconds per iteration")
    return avg_time
|
||||
|
||||
# Time both models with autograd disabled (inference only): FP16 first, then INT8.
with torch.no_grad():
    fp16_time, int8_time = (
        speed_test(candidate, input_tensor, label)
        for candidate, label in ((fp16_model, "FP16"), (int8_model, "INT8"))
    )

# A ratio above 1 means the INT8 model needed less time per iteration.
speedup = fp16_time / int8_time
print(f"INT8 is {speedup:.2f}x faster than FP16")
|
||||
89
research/micro/data/transformer-batching-microbenchmarks.csv
Normal file
89
research/micro/data/transformer-batching-microbenchmarks.csv
Normal file
@@ -0,0 +1,89 @@
|
||||
n,d,seqlen,bs,latency,h,flop,io,intensity,throughput,series
|
||||
3,256,256,2048,0.009623501679245285,768,618475290624,167.48502132816208,3692720015.912285,64267177503366.266,dense
|
||||
3,256,256,1024,0.004853848615384615,768,309237645312,166.15392854317415,1861151572.059558,63709783682138.234,dense
|
||||
3,256,256,512,0.0024687246971962615,768,154618822656,163.57953256539062,945221081.3366361,62631051097597.516,dense
|
||||
3,256,256,256,0.0012845360838052097,768,77309411328,157.64931990085577,490388486.1451936,60184694149645.54,dense
|
||||
3,256,256,128,0.0006901147179878049,768,38654705664,147.57393422494675,261934506.70684624,56012000116019.945,dense
|
||||
3,256,256,64,0.0003363830693015702,768,19327352832,153.1328437752606,126212981.84970059,57456378146882.51,dense
|
||||
3,256,256,32,0.00018671159748991485,768,9663676416,141.10249365427362,68486928.65540518,51757237075334.75,dense
|
||||
3,256,256,16,0.00012353640857142858,768,4831838208,111.40488993609125,43371868.24359184,39112665358133.98,dense
|
||||
3,256,256,8,9.774760007849294e-05,768,2415919104,76.43260800265766,31608487.09906635,24715891766754.14,dense
|
||||
3,256,256,4,6.672271167474822e-05,768,1207959552,64.82614227498455,18633833.660438772,18104173551704.773,dense
|
||||
3,256,256,2,4.9758770289855074e-05,768,603979776,55.317122669351576,10918495.880745342,12138157202874.861,dense
|
||||
3,256,1,2048,9.785507940251571e-05,768,2415919104,76.34865809334705,31643242.518371396,24688745017132.86,dense
|
||||
3,256,1,1024,6.692813470149253e-05,768,1207959552,64.62717090938949,18691202.70936228,18048606275785.867,dense
|
||||
3,256,1,512,4.9680950036205655e-05,768,603979776,55.40377142534654,10901419.893658841,12157170415618.898,dense
|
||||
3,256,1,256,4.2781118741058655e-05,768,301989888,45.95672244805227,6571179.83862661,7058952568020.829,dense
|
||||
3,256,1,128,5.0662328255350016e-05,768,150994944,31.046026784880404,4863583.512513602,2980418571348.519,dense
|
||||
3,256,1,64,4.475009253945481e-05,768,75497472,30.75426042497223,2454862.219307235,1687090857598.4766,dense
|
||||
3,256,1,32,4.51682671454219e-05,768,37748736,28.29313765537115,1334201.1218340008,835735758435.5786,dense
|
||||
3,256,1,16,5.03585186661834e-05,768,18874368,24.401035466223117,773506.846712577,374799904761.1871,dense
|
||||
3,256,1,8,5.023459565217391e-05,768,9437184,23.972005435021096,393675.19858030166,187862246674.45105,dense
|
||||
3,256,1,4,5.053219391083726e-05,768,4718592,23.58765586356967,200044.97383259286,93377936614.54384,dense
|
||||
3,256,1,2,4.4607398995335484e-05,768,2359296,26.58285456464288,88752.54515134107,52890239133.797226,dense
|
||||
12,256,256,2048,0.14480779847058822,3072,9895604649984,44.620009282941716,221775046868.20184,68336130750540.26,dense
|
||||
12,256,256,1024,0.07254347629166667,3072,4947802324992,44.664248332585096,110777691547.58836,68204648824643.82,dense
|
||||
12,256,256,512,0.036310761444444443,3072,2473901162496,44.876147984203506,55127306456.13385,68131349056975.164,dense
|
||||
12,256,256,256,0.01821551906896552,3072,1236950581248,45.24607467289738,27338295977.947884,67906414116709.98,dense
|
||||
12,256,256,128,0.009229417903030302,3072,618475290624,45.67217092440895,13541622351.335684,67011299859001.46,dense
|
||||
12,256,256,64,0.004754550595394737,3072,309237645312,46.31372736116993,6677019167.566916,65040352207320.695,dense
|
||||
12,256,256,32,0.002405752659340659,3072,154618822656,49.68826015254682,3111777755.5766335,64270456921525.82,dense
|
||||
12,256,256,16,0.0012287219045005488,3072,77309411328,56.323579604557374,1372594069.3184311,62918558743709.18,dense
|
||||
12,256,256,8,0.0006206816149425287,3072,38654705664,70.95456179103653,544781120.315271,62277832520589.78,dense
|
||||
12,256,256,4,0.0003875502697142857,3072,19327352832,81.16954743236613,238110885.71245712,49870569942445.75,dense
|
||||
12,256,256,2,0.00027502018627941914,3072,9663676416,91.50537035282076,105607751.53129694,35138062215483.168,dense
|
||||
12,256,1,2048,0.0006202853873290136,3072,38654705664,70.99988634205897,544433345.6784943,62317614526515.766,dense
|
||||
12,256,1,1024,0.00038721467732724153,3072,19327352832,81.2398957010995,237904697.74985722,49913791918755.53,dense
|
||||
12,256,1,512,0.000274364799,3072,9663676416,91.72395326121995,105356082.81599998,35221998052308.45,dense
|
||||
12,256,1,256,0.00012488918589482266,3072,4831838208,176.31707535146046,27404255.647778228,38689003962834.75,dense
|
||||
12,256,1,128,8.976711102514506e-05,3072,2415919104,227.78088507574267,10606329.425740216,26913187652026.21,dense
|
||||
12,256,1,64,8.715176287471176e-05,3072,1207959552,225.59268282689945,5354604.31102229,13860414432884.701,dense
|
||||
12,256,1,32,8.523013435114503e-05,3072,603979776,226.06539514085782,2671703.8033338524,7086458100741.991,dense
|
||||
12,256,1,16,7.901561645904116e-05,3072,301989888,241.35704882952732,1251216.3595988373,3821901309300.556,dense
|
||||
12,256,1,8,7.827949114210329e-05,3072,150994944,242.37091635608994,622991.1833900034,1928920867994.581,dense
|
||||
12,256,1,4,7.779445951035782e-05,3072,75497472,243.25022783249054,310369.58391664835,970473636235.5986,dense
|
||||
12,256,1,2,7.758845406626506e-05,3072,37748736,243.57933441822672,154975.11761480253,486525172518.07056,dense
|
||||
3,256,256,2048,0.00507974918466899,768,206158430208,475.59810852303485,433471930.42508715,40584371927298.98,qk_init
|
||||
3,256,256,1024,0.0025616677649325623,768,103079215104,471.5519977009198,218595649.27424532,40239103803811.82,qk_init
|
||||
3,256,256,512,0.0013029336670480549,768,51539607552,463.55374128015677,111183672.92143403,39556585922573.38,qk_init
|
||||
3,256,256,256,0.0006738189029345373,768,25769803776,448.1766342333362,57499213.050413854,38244406121244.69,qk_init
|
||||
3,256,256,128,0.000358254672959467,768,12884901888,421.47375986100144,30571065.425874516,35965760841472.125,qk_init
|
||||
3,256,256,64,0.0002007051105022831,768,6442450944,376.1611839930762,17126836.096194826,32099087700742.5,qk_init
|
||||
3,256,256,32,0.00012189697230142565,768,3221225472,309.6773881032524,10401874.969721656,26425803784810.87,qk_init
|
||||
3,256,256,16,8.453561698040722e-05,768,1610612736,223.2711923587723,7213705.982328083,19052475081281.902,qk_init
|
||||
3,256,256,8,6.407660705009276e-05,768,805306368,147.2797083750448,5467870.468274581,12567868448003.822,qk_init
|
||||
3,256,256,4,5.036328747284576e-05,768,402653184,93.69110391262903,4297667.197682838,7994974200544.344,qk_init
|
||||
3,256,256,2,4.5488761135057476e-05,768,201326592,51.865470527877875,3881707.616858238,4425853485045.578,qk_init
|
||||
12,256,256,2048,0.020202365999999996,3072,824633720832,478.3437947812648,1723935231.9999998,40818670488001.266,qk_init
|
||||
12,256,256,1024,0.010124155888157895,3072,412316860416,477.2583770318811,863927969.1228071,40726048173387.19,qk_init
|
||||
12,256,256,512,0.005085633937062937,3072,206158430208,475.04777848703077,433974095.9627039,40537410430893.29,qk_init
|
||||
12,256,256,256,0.0025654916853281853,3072,103079215104,470.84913933193053,218921957.14800516,40179126556324.74,qk_init
|
||||
12,256,256,128,0.0013045765704467354,3072,51539607552,462.9699702434292,111323867.34478809,39506770794105.96,qk_init
|
||||
12,256,256,64,0.0006742801519939804,3072,25769803776,447.87005387442576,57538572.970153,38218244597284.33,qk_init
|
||||
12,256,256,32,0.00035831976790671853,3072,12884901888,421.3971919051604,30576620.194706645,35959227042573.69,qk_init
|
||||
12,256,256,16,0.0002005369068918302,3072,6442450944,376.4766953382971,17112482.721436176,32126011335534.68,qk_init
|
||||
12,256,256,8,0.00012179187250509165,3072,3221225472,309.94462293386505,10392906.453767821,26448607823689.82,qk_init
|
||||
12,256,256,4,8.452507263643351e-05,3072,1610612736,223.2990450204527,7212806.198308992,19054851841745.297,qk_init
|
||||
12,256,256,2,6.412381767545489e-05,3072,805306368,147.17127491946468,5471899.108305484,12558615459794.32,qk_init
|
||||
3,256,256,2048,0.0016183739398395718,768,805306368,811597824.0,0.9922480620155039,1265467.7325087283,qk_ar
|
||||
3,256,256,1024,0.0008322699728813558,768,402653184,405798912.0,0.9922480620155039,1230369.9921491416,qk_ar
|
||||
3,256,256,512,0.00043886859397590365,768,201326592,202899456.0,0.9922480620155039,1166636.2255762408,qk_ar
|
||||
3,256,256,256,0.00024185948322147648,768,100663296,101449728.0,0.9922480620155039,1058465.8355760013,qk_ar
|
||||
3,256,256,128,0.00014308985100166944,768,50331648,50724864.0,0.9922480620155039,894542.82818777,qk_ar
|
||||
3,256,256,64,9.382939365815932e-05,768,25165824,25362432.0,0.9922480620155039,682089.028872613,qk_ar
|
||||
3,256,256,32,6.856070612244899e-05,768,12582912,12681216.0,0.9922480620155039,466739.6503012703,qk_ar
|
||||
3,256,256,16,5.452260553129549e-05,768,6291456,6340608.0,0.9922480620155039,293456.26174846216,qk_ar
|
||||
3,256,256,8,4.608557533261417e-05,768,3145728,3170304.0,0.9922480620155039,173590.1080166944,qk_ar
|
||||
3,256,256,4,4.386146957766642e-05,768,1572864,1585152.0,0.9922480620155039,91196.21477609445,qk_ar
|
||||
3,256,256,2,4.330941094420601e-05,768,786432,792576.0,0.9922480620155039,46179.33969539622,qk_ar
|
||||
12,256,256,2048,0.006347041645299144,3072,3221225472,3246391296.0,0.9922480620155039,322670.011392918,qk_ar
|
||||
12,256,256,1024,0.0031943104467592586,3072,1610612736,1623195648.0,0.9922480620155039,320569.96872013,qk_ar
|
||||
12,256,256,512,0.0016183416350267381,3072,805306368,811597824.0,0.9922480620155039,316373.2483416833,qk_ar
|
||||
12,256,256,256,0.0008325934893977947,3072,402653184,405798912.0,0.9922480620155039,307472.9784221131,qk_ar
|
||||
12,256,256,128,0.0004389725746987952,3072,201326592,202899456.0,0.9922480620155039,291589.9702568624,qk_ar
|
||||
12,256,256,64,0.00024191767449664432,3072,100663296,101449728.0,0.9922480620155039,264552.8076159138,qk_ar
|
||||
12,256,256,32,0.0001431546143572621,3072,50331648,50724864.0,0.9922480620155039,223534.53392804778,qk_ar
|
||||
12,256,256,16,9.404283597678917e-05,3072,25165824,25362432.0,0.9922480620155039,170135.23501087292,qk_ar
|
||||
12,256,256,8,6.855550037091989e-05,3072,12582912,12681216.0,0.9922480620155039,116693.773026467,qk_ar
|
||||
12,256,256,4,5.4802094978165945e-05,3072,6291456,6340608.0,0.9922480620155039,72989.91036006316,qk_ar
|
||||
12,256,256,2,4.608510707869206e-05,3072,3145728,3170304.0,0.9922480620155039,43397.96795057727,qk_ar
|
||||
|
Binary file not shown.
|
After Width: | Height: | Size: 45 KiB |
594
research/micro/embedd_micro.py
Executable file
594
research/micro/embedd_micro.py
Executable file
@@ -0,0 +1,594 @@
|
||||
# Usage: python embedd_micro.py --use_int8   (noted as the fastest configuration)
|
||||
|
||||
import argparse
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torchao import quantize_
|
||||
from transformers import AutoModel, BitsAndBytesConfig
|
||||
from tqdm import tqdm
|
||||
from contextlib import contextmanager
|
||||
|
||||
@dataclass
class BenchmarkConfig:
    """Settings for one benchmark invocation.

    The first four fields are required and positional; the remaining fields
    are optional feature switches.
    """

    # --- what to run ---
    model_path: str          # model identifier or path handed to the loader
    batch_sizes: List[int]   # batch sizes to measure
    seq_length: int          # number of token ids per input sequence
    num_runs: int            # timed repetitions per batch size

    # --- precision / quantization switches ---
    use_fp16: bool = True
    use_int4: bool = False
    use_int8: bool = False

    # --- execution switches ---
    use_cuda_graphs: bool = False
    use_flash_attention: bool = False
    use_linear8bitlt: bool = False
|
||||
|
||||
|
||||
class CUDAGraphContainer:
    """Caches one ``CUDAGraphWrapper`` per batch size for a single model."""

    def __init__(self, model: nn.Module, seq_length: int):
        self.model = model
        self.seq_length = seq_length
        # batch size -> wrapper built for that batch size
        self.graphs: Dict[int, CUDAGraphWrapper] = {}

    def get_or_create(self, batch_size: int) -> 'CUDAGraphWrapper':
        """Return the wrapper for ``batch_size``, building it on first request."""
        try:
            return self.graphs[batch_size]
        except KeyError:
            pass

        wrapper = CUDAGraphWrapper(self.model, batch_size, self.seq_length)
        self.graphs[batch_size] = wrapper
        return self.graphs[batch_size]
|
||||
|
||||
|
||||
class CUDAGraphWrapper:
    """Captures one forward pass of ``model`` in a CUDA graph and replays it.

    The wrapper owns fixed ("static") input buffers. ``__call__`` copies new
    inputs into those buffers, replays the captured graph and returns the
    output object produced during capture.
    """

    def __init__(self, model: nn.Module, batch_size: int, seq_length: int):
        """Allocate static inputs, warm the model up and capture the graph.

        Args:
            model: Module called as ``model(input_ids=..., attention_mask=...)``.
            batch_size: Number of rows in the static input buffer.
            seq_length: Number of token ids per row in the static input buffer.
        """
        self.model = model
        self.static_input = self._create_random_batch(batch_size, seq_length)
        self.static_attention_mask = torch.ones_like(self.static_input)

        # Warm up
        self._warmup()

        # Capture graph. This wrapper is only used for inference, so capture
        # with autograd disabled, matching the warm-up and replay conditions;
        # previously the forward pass was captured with gradient tracking on.
        self.graph = torch.cuda.CUDAGraph()
        with torch.no_grad(), torch.cuda.graph(self.graph):
            self.static_output = self.model(
                input_ids=self.static_input,
                attention_mask=self.static_attention_mask
            )

    def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
        """Return random int64 token ids in [0, 1000) of shape (batch_size, seq_length) on CUDA."""
        return torch.randint(
            0, 1000, (batch_size, seq_length),
            device="cuda",
            dtype=torch.long
        )

    def _warmup(self, num_warmup: int = 3):
        """Run ``num_warmup`` untimed forward passes before graph capture.

        The passes run on a side stream, as the PyTorch CUDA Graphs notes
        recommend for pre-capture warm-up; both streams wait on each other so
        ordering with surrounding work is preserved.
        """
        side_stream = torch.cuda.Stream()
        side_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side_stream), torch.no_grad():
            for _ in range(num_warmup):
                self.model(
                    input_ids=self.static_input,
                    attention_mask=self.static_attention_mask
                )
        torch.cuda.current_stream().wait_stream(side_stream)

    def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Copy the inputs into the static buffers, replay the graph, return the static output.

        The returned object is the one created during capture; replaying
        overwrites its contents in place.
        """
        self.static_input.copy_(input_ids)
        self.static_attention_mask.copy_(attention_mask)
        self.graph.replay()
        return self.static_output
|
||||
|
||||
|
||||
class ModelOptimizer:
    """Applies various optimizations to the model."""

    @staticmethod
    def _cuda_version_at_least(version: Optional[str], minimum: Tuple[int, int]) -> bool:
        """Return True when a "major.minor[...]" version string is >= ``minimum``.

        The previous check, ``float(version[:3])``, cut "11.8" down to "11."
        and therefore treated every CUDA 11.x release as older than 11.6.

        Args:
            version: Version string such as ``torch.version.cuda``; may be None.
            minimum: ``(major, minor)`` lower bound.

        Returns:
            False for a missing or unparseable version, otherwise the result
            of comparing the leading ``(major, minor)`` pair with ``minimum``.
        """
        if not version:
            return False
        try:
            parsed = tuple(int(part) for part in version.split(".")[:2])
        except ValueError:
            return False
        # A bare major version such as "12" is treated as "12.0".
        parsed = parsed + (0,) * (2 - len(parsed))
        return parsed >= tuple(minimum)

    @staticmethod
    def optimize(model: nn.Module, config: "BenchmarkConfig") -> nn.Module:
        """Move ``model`` to the GPU and apply the optimizations enabled in ``config``.

        Args:
            model: Model to optimize.
            config: Benchmark settings; ``use_fp16``, ``use_int4`` and
                ``use_flash_attention`` are read here.

        Returns:
            The optimized model in eval mode (a compiled module when the FP16
            path is taken).

        Raises:
            ValueError: If ``model`` is None.
        """
        print("\nApplying model optimizations:")

        if model is None:
            raise ValueError("Cannot optimize None model")

        # Move to GPU
        model = model.cuda()
        print("- Model moved to GPU")

        # FP16 (skipped when int4 is requested)
        if config.use_fp16 and not config.use_int4:
            model = model.half()
            # use torch compile
            model = torch.compile(model)
            print("- Using FP16 precision")

        # Report whether PyTorch SDPA is present (only checked on CUDA >= 11.6)
        if ModelOptimizer._cuda_version_at_least(torch.version.cuda, (11, 6)):
            if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
                print("- Using PyTorch SDPA (scaled_dot_product_attention)")
            else:
                print("- PyTorch SDPA not available")

        # Flash Attention
        if config.use_flash_attention:
            try:
                # Imported only to probe availability.
                # NOTE(review): this module path may not exist in flash-attn 2.x
                # releases -- confirm against the installed package.
                from flash_attn.flash_attention import FlashAttention  # noqa: F401
                print("- Flash Attention 2 available")
                if hasattr(model.config, "attention_mode"):
                    model.config.attention_mode = "flash_attention_2"
                    print(" - Enabled Flash Attention 2 mode")
            except ImportError:
                print("- Flash Attention not available")

        # Memory efficient attention
        try:
            # Imported only to probe availability.
            from xformers.ops import memory_efficient_attention  # noqa: F401
            if hasattr(model, 'enable_xformers_memory_efficient_attention'):
                model.enable_xformers_memory_efficient_attention()
                print("- Enabled xformers memory efficient attention")
            else:
                print("- Model doesn't support xformers")
        except (ImportError, AttributeError):
            print("- Xformers not available")

        model.eval()
        print("- Model set to eval mode")

        return model
|
||||
|
||||
|
||||
class Timer:
    """Measures the time between two recorded CUDA events."""

    def __init__(self):
        # Two timing-enabled events: one marks the start, one the end.
        self.start_event, self.end_event = (
            torch.cuda.Event(enable_timing=True) for _ in range(2)
        )

    @contextmanager
    def timing(self):
        """Record the start event on entry; on exit record the end event and wait for it."""
        self.start_event.record()
        yield
        stop = self.end_event
        stop.record()
        stop.synchronize()

    def elapsed_time(self) -> float:
        """Return the start-to-end interval in seconds (the events report milliseconds)."""
        milliseconds = self.start_event.elapsed_time(self.end_event)
        return milliseconds / 1000
|
||||
|
||||
|
||||
class Benchmark:
    """Main benchmark runner.

    Loads the model described by a ``BenchmarkConfig`` (bitsandbytes, TorchAO
    int8 or standard loading) and times its forward pass on random token ids
    for every configured batch size.
    """

    def __init__(self, config: "BenchmarkConfig"):
        """Load the model and create the CUDA-graph container and timer.

        Args:
            config: Benchmark settings.

        Raises:
            ValueError: If model loading yields None.
            Exception: Any other initialization error is printed and re-raised.
        """
        self.config = config
        try:
            self.model = self._load_model()
            if self.model is None:
                raise ValueError("Model initialization failed - model is None")

            self.cuda_graphs = (
                CUDAGraphContainer(self.model, config.seq_length)
                if config.use_cuda_graphs
                else None
            )
            self.timer = Timer()
        except Exception as e:
            print(f"ERROR in benchmark initialization: {str(e)}")
            raise

    @staticmethod
    def _version_at_least(version: Optional[str], minimum: Tuple[int, int]) -> bool:
        """Return True when a "major.minor[...]" version string is >= ``minimum``.

        Used for both ``torch.version.cuda`` and ``torch.__version__``. The
        earlier ``float(version[:3])`` check cut "11.8" down to "11." and so
        treated every CUDA 11.x release as older than 11.6.

        Returns:
            False for a missing or unparseable version.
        """
        if not version:
            return False
        try:
            parsed = tuple(int(part) for part in version.split(".")[:2])
        except ValueError:
            return False
        # A bare major version such as "12" is treated as "12.0".
        parsed = parsed + (0,) * (2 - len(parsed))
        return parsed >= tuple(minimum)

    @staticmethod
    def _report_sdpa_support() -> None:
        """Print whether PyTorch SDPA is present (only checked on CUDA >= 11.6)."""
        if Benchmark._version_at_least(torch.version.cuda, (11, 6)):
            if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
                print("- Using PyTorch SDPA (scaled_dot_product_attention)")
            else:
                print("- PyTorch SDPA not available")

    def _load_model(self) -> nn.Module:
        """Load the model through the path selected by the config flags.

        ``use_int4`` takes priority over ``use_int8``; with neither set the
        standard loader is used.

        Raises:
            ValueError: If the selected loader returns None.
            Exception: Any loader error is printed with a traceback and re-raised.
        """
        print(f"Loading model from {self.config.model_path}...")

        try:
            if self.config.use_int4:
                model = self._load_bitsandbytes_model()
            elif self.config.use_int8:
                model = self._load_torchao_int8_model()
            else:
                model = self._load_standard_model()

            # Final check to ensure model is not None
            if model is None:
                raise ValueError("Model is None after optimization")

            print(f"- Final model type: {type(model)}")
            return model

        except Exception as e:
            print(f"ERROR loading model: {str(e)}")
            import traceback
            traceback.print_exc()
            raise

    def _load_bitsandbytes_model(self) -> nn.Module:
        """Load via bitsandbytes: Linear8bitLt layer replacement or NF4 4-bit loading.

        This relies on the module-level ``torch`` import. The earlier
        function-local ``import torch`` statements made ``torch`` a local
        name, so the plain int4 path failed with UnboundLocalError.
        """
        import bitsandbytes as bnb

        print(f"- bitsandbytes version: {bnb.__version__}")

        compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
        # NOTE(review): use_linear8bitlt is only honoured together with use_int4.
        use_linear8bitlt = bool(getattr(self.config, "use_linear8bitlt", False))

        if use_linear8bitlt:
            print("- Using custom Linear8bitLt replacement for all linear layers")

            # Load the original model without a quantization config.
            # set default to half
            torch.set_default_dtype(torch.float16)
            model = AutoModel.from_pretrained(
                self.config.model_path,
                torch_dtype=compute_dtype,
            )

            def replace_linear_with_linear8bitlt(parent):
                """Recursively replace every nn.Linear below ``parent`` with Linear8bitLt."""
                for name, module in list(parent.named_children()):
                    if isinstance(module, nn.Linear):
                        # Shape and bias setting of the original layer
                        in_features = module.in_features
                        out_features = module.out_features
                        bias = module.bias is not None

                        # Create the 8-bit linear layer
                        print(f"in_features: {in_features}, out_features: {out_features}")
                        new_module = bnb.nn.Linear8bitLt(
                            in_features,
                            out_features,
                            bias=bias,
                            has_fp16_weights=False
                        )

                        # Copy weight and bias
                        new_module.weight.data = module.weight.data
                        if bias:
                            new_module.bias.data = module.bias.data

                        # Swap the module in place
                        setattr(parent, name, new_module)
                    else:
                        # Recurse into child modules
                        replace_linear_with_linear8bitlt(module)

                return parent

            # Replace all linear layers
            model = replace_linear_with_linear8bitlt(model)
            # add torch compile
            model = torch.compile(model)

            # Move the model to the GPU (quantization happens here)
            device = "cuda" if torch.cuda.is_available() else "cpu"
            model = model.to(device)

            print("- All linear layers replaced with Linear8bitLt")
        else:
            # Original Int4 quantization path
            print("- Using bitsandbytes for Int4 quantization")

            # Create quantization config
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )

            print("- Quantization config:", quantization_config)

            # Load model directly with quantization config
            model = AutoModel.from_pretrained(
                self.config.model_path,
                quantization_config=quantization_config,
                torch_dtype=compute_dtype,
                device_map="auto"  # Let HF decide on device mapping
            )

        # Check if model loaded successfully
        if model is None:
            raise ValueError("Model loading returned None")

        print(f"- Model type: {type(model)}")

        # Apply optimizations directly here
        print("\nApplying model optimizations:")

        if use_linear8bitlt:
            print("- Model moved to GPU with Linear8bitLt quantization")
        else:
            # Skip moving to GPU since device_map="auto" already did that
            print("- Model already on GPU due to device_map='auto'")

        # Skip FP16 conversion since we specified compute_dtype
        print(f"- Using {compute_dtype} for compute dtype")

        # Check CUDA and SDPA
        self._report_sdpa_support()

        # Try xformers if available
        try:
            # Imported only to probe availability.
            from xformers.ops import memory_efficient_attention  # noqa: F401
            if hasattr(model, 'enable_xformers_memory_efficient_attention'):
                model.enable_xformers_memory_efficient_attention()
                print("- Enabled xformers memory efficient attention")
            else:
                print("- Model doesn't support xformers")
        except (ImportError, AttributeError):
            print("- Xformers not available")

        # Set to eval mode
        model.eval()
        print("- Model set to eval mode")
        return model

    def _load_torchao_int8_model(self) -> nn.Module:
        """Load the model, quantize it with TorchAO int8 (dynamic activations + weights), compile it."""
        print("- Using TorchAO for Int8 dynamic activation and Int8 weight quantization")

        # Only the names actually used are imported; the legacy
        # int8_dynamic_activation_int8_weight import was unused.
        from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig
        print("- Successfully imported TorchAO")

        # Load model normally first
        torch.set_default_dtype(torch.bfloat16)
        model = AutoModel.from_pretrained(
            self.config.model_path,
            device_map="auto"
        )

        print("- Model loaded in full precision")
        print(f"- Model type: {type(model)}")

        # Apply quantization
        quantize_(model, Int8DynamicActivationInt8WeightConfig())
        print("- Model successfully quantized with int8 weights and int8 activations")

        # For older PyTorch versions that have issues with tensor subclasses.
        # The previous guard compared the torch.version module with a string
        # behind a hasattr() test that never held, so this never ran; it must
        # also happen before torch.compile to have any effect.
        if not self._version_at_least(torch.__version__, (2, 5)):
            from torchao.utils import unwrap_tensor_subclass
            print("- Unwrapping tensor subclasses for compatibility with older PyTorch")
            unwrap_tensor_subclass(model)

        # add torch compile
        model = torch.compile(model)

        # Apply optimizations
        self._report_sdpa_support()

        # Set to eval mode
        model.eval()
        print("- Model set to eval mode")

        # For better performance with int8 dynamic quantization
        torch._inductor.config.force_fuse_int_mm_with_mul = True
        print("- Enabled fusion of int matmul with mul operations")
        return model

    def _load_standard_model(self) -> nn.Module:
        """Load the model without quantization, then optimize, halve and compile it."""
        # Standard loading for FP16/FP32
        model = AutoModel.from_pretrained(self.config.model_path)
        print("- Model loaded in standard precision")
        print(f"- Model type: {type(model)}")

        # Apply standard optimizations
        torch.set_default_dtype(torch.bfloat16)
        model = ModelOptimizer.optimize(model, self.config)
        # NOTE(review): the model is converted to half precision and compiled
        # here regardless of use_fp16, and ModelOptimizer.optimize may already
        # have done both -- confirm whether an FP32 run is meant to be possible.
        model = model.half()
        # add torch compile
        model = torch.compile(model)
        return model

    def _create_random_batch(self, batch_size: int) -> torch.Tensor:
        """Return random int64 token ids in [0, 1000) of shape (batch_size, seq_length) on CUDA."""
        return torch.randint(
            0, 1000,
            (batch_size, self.config.seq_length),
            device="cuda",
            dtype=torch.long
        )

    def _run_inference(
        self,
        input_ids: torch.Tensor,
        cuda_graph_wrapper: Optional["CUDAGraphWrapper"] = None
    ) -> Tuple[float, torch.Tensor]:
        """Run one timed forward pass with an all-ones attention mask.

        Args:
            input_ids: Token ids fed to the model.
            cuda_graph_wrapper: When given, the pass is replayed through it
                instead of calling the model directly.

        Returns:
            ``(elapsed_seconds, model_output)``.
        """
        attention_mask = torch.ones_like(input_ids)

        with torch.no_grad(), self.timer.timing():
            if cuda_graph_wrapper is not None:
                output = cuda_graph_wrapper(input_ids, attention_mask)
            else:
                output = self.model(input_ids=input_ids, attention_mask=attention_mask)

        return self.timer.elapsed_time(), output

    def run(self) -> Dict[int, Dict[str, float]]:
        """Benchmark every configured batch size.

        Returns:
            Mapping of batch size to ``avg_time``, ``std_time``, ``throughput``
            and ``peak_memory_gb``. Batch sizes with no successful run are
            omitted. ``peak_memory_gb`` is the peak over the whole run, so it
            is the same value for every batch size.
        """
        results = {}

        # Reset peak memory stats
        torch.cuda.reset_peak_memory_stats()

        for batch_size in self.config.batch_sizes:
            print(f"\nTesting batch size: {batch_size}")
            times = []

            # Get or create CUDA graph for this batch size
            cuda_graph_wrapper = (
                self.cuda_graphs.get_or_create(batch_size)
                if self.cuda_graphs is not None
                else None
            )

            # Pre-allocate input tensor
            input_ids = self._create_random_batch(batch_size)
            print(f"Input shape: {input_ids.shape}")

            # Run benchmark.
            # NOTE(review): there are no untimed warm-up iterations, so the
            # first timed run presumably includes one-off compilation cost.
            for i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
                try:
                    elapsed_time, output = self._run_inference(input_ids, cuda_graph_wrapper)
                    if i == 0:  # Only print on first run
                        print(f"Output shape: {output.last_hidden_state.shape}")
                    times.append(elapsed_time)
                except Exception as e:
                    # Stop measuring this batch size; earlier successful runs are kept.
                    print(f"Error during inference: {e}")
                    break

            if not times:
                print(f"No successful runs for batch size {batch_size}, skipping")
                continue

            # Calculate statistics
            avg_time = np.mean(times)
            std_time = np.std(times)
            throughput = batch_size / avg_time

            results[batch_size] = {
                "avg_time": avg_time,
                "std_time": std_time,
                "throughput": throughput,
            }

            print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
            print(f"Throughput: {throughput:.2f} sequences/second")

        # Log memory usage
        peak_memory_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)
        print(f"\nPeak GPU memory usage: {peak_memory_gb:.2f} GB")

        # Add memory info to results
        for batch_size in results:
            results[batch_size]["peak_memory_gb"] = peak_memory_gb

        return results
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Model Inference Benchmark")
|
||||
parser.add_argument(
|
||||
"--model_path",
|
||||
type=str,
|
||||
default="facebook/contriever",
|
||||
help="Path to the model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_sizes",
|
||||
type=str,
|
||||
default="1,2,4,8,10,16,20,32,40,64,128,256,512,1024,2048,4096,8192",
|
||||
help="Comma-separated list of batch sizes",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seq_length",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Sequence length for input",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_runs",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Number of runs for each batch size",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_fp16",
|
||||
action="store_true",
|
||||
help="Enable FP16 inference",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_int4",
|
||||
action="store_true",
|
||||
help="Enable INT4 quantization using bitsandbytes",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_int8",
|
||||
action="store_true",
|
||||
help="Enable INT8 quantization for both activations and weights using bitsandbytes",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_cuda_graphs",
|
||||
action="store_true",
|
||||
help="Enable CUDA Graphs optimization",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_flash_attention",
|
||||
action="store_true",
|
||||
help="Enable Flash Attention 2 if available",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_linear8bitlt",
|
||||
action="store_true",
|
||||
help="Enable Linear8bitLt quantization for all linear layers",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Print arguments for debugging
|
||||
print("\nCommand line arguments:")
|
||||
for arg, value in vars(args).items():
|
||||
print(f"- {arg}: {value}")
|
||||
|
||||
config = BenchmarkConfig(
|
||||
model_path=args.model_path,
|
||||
batch_sizes=[int(bs) for bs in args.batch_sizes.split(",")],
|
||||
seq_length=args.seq_length,
|
||||
num_runs=args.num_runs,
|
||||
use_fp16=args.use_fp16,
|
||||
use_int4=args.use_int4,
|
||||
use_int8=args.use_int8, # Add this line
|
||||
use_cuda_graphs=args.use_cuda_graphs,
|
||||
use_flash_attention=args.use_flash_attention,
|
||||
use_linear8bitlt=args.use_linear8bitlt,
|
||||
)
|
||||
|
||||
# Print configuration for debugging
|
||||
print("\nBenchmark configuration:")
|
||||
for field, value in vars(config).items():
|
||||
print(f"- {field}: {value}")
|
||||
|
||||
try:
|
||||
benchmark = Benchmark(config)
|
||||
results = benchmark.run()
|
||||
|
||||
# Save results to file
|
||||
import json
|
||||
import os
|
||||
|
||||
# Create results directory if it doesn't exist
|
||||
os.makedirs("results", exist_ok=True)
|
||||
|
||||
# Generate filename based on configuration
|
||||
precision_type = "int4" if config.use_int4 else "fp16" if config.use_fp16 else "fp32"
|
||||
model_name = os.path.basename(config.model_path)
|
||||
output_file = f"results/benchmark_{model_name}_{precision_type}.json"
|
||||
|
||||
# Save results
|
||||
with open(output_file, "w") as f:
|
||||
json.dump(
|
||||
{
|
||||
"config": {k: str(v) if isinstance(v, list) else v for k, v in vars(config).items()},
|
||||
"results": {str(k): v for k, v in results.items()}
|
||||
},
|
||||
f,
|
||||
indent=2
|
||||
)
|
||||
print(f"Results saved to {output_file}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Benchmark failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
# Run the benchmark CLI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
||||
376
research/micro/embedd_micro_seq.py
Normal file
376
research/micro/embedd_micro_seq.py
Normal file
@@ -0,0 +1,376 @@
|
||||
import argparse
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import AutoModel
|
||||
from tqdm import tqdm
|
||||
from contextlib import contextmanager
|
||||
import math
|
||||
|
||||
@dataclass
class BenchmarkConfig:
    """Settings for one benchmark session."""

    model_path: str  # handed to AutoModel.from_pretrained
    batch_sizes: List[int]  # every entry is benchmarked in turn
    seq_length: int  # second dimension of the random token batches
    num_runs: int  # timed repetitions per batch size
    use_fp16: bool = True  # when set, the model is converted with .half()
    use_cuda_graphs: bool = False  # when set, inference goes through captured CUDA graphs
    use_flash_attention: bool = False  # when set, the optimizer probes for flash_attn
    max_batch_size: int = 256  # Maximum batch size before splitting
|
||||
|
||||
|
||||
class CUDAGraphContainer:
    """Lazily builds and caches one ``CUDAGraphWrapper`` per batch size."""

    def __init__(self, model: nn.Module, seq_length: int, max_batch_size: int):
        # Cache of captured graphs, keyed by the batch size they were built for.
        self.graphs: Dict[int, CUDAGraphWrapper] = {}
        self.max_batch_size = max_batch_size
        self.seq_length = seq_length
        self.model = model

    def get_or_create(self, batch_size: int) -> 'CUDAGraphWrapper':
        """Return the wrapper for ``batch_size``, creating it on first request.

        Requests larger than ``max_batch_size`` are served by the wrapper
        built for ``max_batch_size``.
        """
        # Clamp the request: graphs never exceed the configured maximum.
        key = batch_size if batch_size < self.max_batch_size else self.max_batch_size

        if key in self.graphs:
            return self.graphs[key]

        wrapper = CUDAGraphWrapper(self.model, key, self.seq_length)
        self.graphs[key] = wrapper
        return wrapper
|
||||
|
||||
|
||||
class CUDAGraphWrapper:
    """Wrapper for CUDA graph capture and replay.

    One forward pass of ``model`` is captured for a fixed
    ``(batch_size, seq_length)`` input. Later calls copy fresh data into the
    static input buffers and replay the captured kernels.

    The object returned by ``__call__`` is the *same* static output on every
    call; its contents are overwritten by the next replay, so callers that
    need to keep a result must clone it.
    """

    def __init__(self, model: nn.Module, batch_size: int, seq_length: int):
        self.model = model
        self.static_input = self._create_random_batch(batch_size, seq_length)
        self.static_attention_mask = torch.ones_like(self.static_input)

        # Warm up
        self._warmup()

        # Capture graph. Capturing under no_grad keeps autograd bookkeeping
        # out of the graph: this wrapper is only used for inference.
        self.graph = torch.cuda.CUDAGraph()
        with torch.no_grad(), torch.cuda.graph(self.graph):
            self.static_output = self.model(
                input_ids=self.static_input,
                attention_mask=self.static_attention_mask
            )

    def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
        """Return a ``(batch_size, seq_length)`` CUDA tensor of random token ids in [0, 1000)."""
        return torch.randint(
            0, 1000, (batch_size, seq_length),
            device="cuda",
            dtype=torch.long
        )

    def _warmup(self, num_warmup: int = 3):
        """Run ``num_warmup`` eager forward passes before the graph is captured.

        The passes run on a side stream, as the PyTorch CUDA Graphs guide
        recommends for pre-capture warm-up.
        """
        side_stream = torch.cuda.Stream()
        side_stream.wait_stream(torch.cuda.current_stream())
        with torch.no_grad(), torch.cuda.stream(side_stream):
            for _ in range(num_warmup):
                self.model(
                    input_ids=self.static_input,
                    attention_mask=self.static_attention_mask
                )
        # Make the default stream wait for warm-up to finish before capture.
        torch.cuda.current_stream().wait_stream(side_stream)

    def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Copy the inputs into the static buffers, replay, and return the static output.

        NOTE(review): the annotation says ``torch.Tensor`` but the value is
        whatever the model returned at capture time; callers in this file
        read ``.last_hidden_state`` from it.

        Raises:
            ValueError: if either input's shape differs from the captured one.
        """
        # copy_ would silently broadcast e.g. a batch of 1 into the whole
        # buffer; a captured graph is only valid for its exact shape.
        if input_ids.shape != self.static_input.shape:
            raise ValueError(
                f"input_ids shape {tuple(input_ids.shape)} does not match captured "
                f"shape {tuple(self.static_input.shape)}"
            )
        if attention_mask.shape != self.static_attention_mask.shape:
            raise ValueError(
                f"attention_mask shape {tuple(attention_mask.shape)} does not match captured "
                f"shape {tuple(self.static_attention_mask.shape)}"
            )
        self.static_input.copy_(input_ids)
        self.static_attention_mask.copy_(attention_mask)
        self.graph.replay()
        return self.static_output
|
||||
|
||||
|
||||
class ModelOptimizer:
    """Applies various optimizations to the model."""

    @staticmethod
    def _cuda_version_at_least(version, minimum=(11, 6)) -> bool:
        """Return True when a ``"major.minor"`` CUDA version string is >= ``minimum``.

        Components are compared as integers. Slicing the string and calling
        ``float`` turns "11.8" into 11.0, which wrongly fails an 11.6 check.
        ``None``, empty or unparsable versions count as "not new enough".
        """
        if not version:
            return False
        fields = str(version).split(".")
        try:
            major = int(fields[0])
            minor = int(fields[1]) if len(fields) > 1 else 0
        except ValueError:
            return False
        return (major, minor) >= tuple(minimum)

    @staticmethod
    def optimize(model: nn.Module, config: "BenchmarkConfig") -> nn.Module:
        """Move ``model`` to the GPU, apply the configured optimizations, and return it.

        Steps, in order: ``.cuda()``, optional ``.half()``, an informational
        SDPA availability report, an optional flash-attn probe, TorchScript
        of LayerNorm forwards (best effort), an xformers probe (best effort),
        and finally ``.eval()``.
        """
        print("\nApplying model optimizations:")

        # Move to GPU
        model = model.cuda()
        print("- Model moved to GPU")

        # FP16
        if config.use_fp16:
            model = model.half()
            print("- Using FP16 precision")

        # Check if using SDPA (report only; nothing on the model is changed here)
        if ModelOptimizer._cuda_version_at_least(torch.version.cuda, (11, 6)):
            if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
                print("- Using PyTorch SDPA (scaled_dot_product_attention)")
                # No need to do anything as it's automatically enabled
            else:
                print("- PyTorch SDPA not available")

        # Flash Attention
        if config.use_flash_attention:
            try:
                # The import is only an availability probe.
                from flash_attn.flash_attention import FlashAttention  # noqa: F401
                print("- Flash Attention 2 available")
                if hasattr(model.config, "attention_mode"):
                    model.config.attention_mode = "flash_attention_2"
                    print(" - Enabled Flash Attention 2 mode")
            except ImportError:
                print("- Flash Attention not available")

        # Optimize LayerNorm
        try:
            num_layernorms = 0
            for module in model.modules():
                if isinstance(module, torch.nn.LayerNorm):
                    module.forward = torch.jit.script(module.forward)
                    num_layernorms += 1
            if num_layernorms > 0:
                print(f"- Optimized {num_layernorms} LayerNorm modules with TorchScript")
        except Exception as e:
            # Best effort: a scripting failure is reported, not fatal.
            print(f"- LayerNorm optimization failed: {e}")

        # Memory efficient attention
        try:
            # Availability probe; AttributeError covers models that lack the
            # enable_* method.
            from xformers.ops import memory_efficient_attention  # noqa: F401
            model.enable_xformers_memory_efficient_attention()
            print("- Enabled xformers memory efficient attention")
        except (ImportError, AttributeError):
            print("- Xformers not available")

        model.eval()
        print("- Model set to eval mode")

        return model
|
||||
|
||||
|
||||
class Timer:
    """Handles accurate GPU timing using CUDA events."""

    def __init__(self):
        self.start_event = torch.cuda.Event(enable_timing=True)
        self.end_event = torch.cuda.Event(enable_timing=True)

    @contextmanager
    def timing(self):
        """Record the start event, run the ``with`` body, then record and wait for the end event.

        The end event is recorded in a ``finally`` clause so that a body that
        raises still leaves the two events as a matched pair.
        """
        self.start_event.record()
        try:
            yield
        finally:
            self.end_event.record()
            self.end_event.synchronize()

    def elapsed_time(self) -> float:
        """Return the time between the last recorded start and end events, in seconds."""
        return self.start_event.elapsed_time(self.end_event) / 1000  # ms to seconds
|
||||
|
||||
|
||||
class Benchmark:
    """Main benchmark runner."""

    def __init__(self, config: "BenchmarkConfig"):
        self.config = config
        self.model = self._load_model()
        # A graph container exists only when CUDA graphs were requested.
        self.cuda_graphs = (
            CUDAGraphContainer(self.model, config.seq_length, config.max_batch_size)
            if config.use_cuda_graphs
            else None
        )
        self.timer = Timer()

    def _load_model(self) -> nn.Module:
        """Load the model named by the config and run it through ``ModelOptimizer``."""
        print(f"Loading model from {self.config.model_path}...")
        model = AutoModel.from_pretrained(self.config.model_path)
        return ModelOptimizer.optimize(model, self.config)

    def _create_random_batch(self, batch_size: int) -> torch.Tensor:
        """Return a ``(batch_size, seq_length)`` CUDA tensor of random token ids in [0, 1000)."""
        return torch.randint(
            0, 1000,
            (batch_size, self.config.seq_length),
            device="cuda",
            dtype=torch.long
        )

    def _run_inference(
        self,
        input_ids: torch.Tensor,
        cuda_graph_wrapper: Optional["CUDAGraphWrapper"] = None
    ) -> Tuple[float, torch.Tensor]:
        """Run one timed forward pass and return ``(seconds, output)``.

        Batches larger than ``config.max_batch_size`` are processed in chunks.
        The reported time is the sum of the per-chunk times and the output is
        an object whose ``last_hidden_state`` is the concatenation of the
        chunk results. A non-None ``cuda_graph_wrapper`` selects the CUDA
        graph path; in the chunked case the wrappers are looked up per chunk
        size from ``self.cuda_graphs``.
        """
        attention_mask = torch.ones_like(input_ids)
        original_batch_size = input_ids.shape[0]
        print(f"Original input_ids shape: {input_ids.shape}")

        # Split large batches to avoid OOM
        max_batch_size = self.config.max_batch_size
        if original_batch_size > max_batch_size:
            print(f"Splitting batch of size {original_batch_size} into chunks of {max_batch_size}")
            total_time = 0
            outputs = []

            with torch.no_grad():
                for i in range(0, original_batch_size, max_batch_size):
                    end_idx = min(i + max_batch_size, original_batch_size)
                    batch_slice = input_ids[i:end_idx]
                    mask_slice = attention_mask[i:end_idx]

                    print(f"Processing chunk {i//max_batch_size + 1}: shape {batch_slice.shape}")

                    # Use CUDA graph if available (with the smaller batch size)
                    chunk_cuda_graph = None
                    if cuda_graph_wrapper is not None:
                        chunk_cuda_graph = self.cuda_graphs.get_or_create(batch_slice.shape[0])

                    with self.timer.timing():
                        if chunk_cuda_graph is not None:
                            chunk_output = chunk_cuda_graph(batch_slice, mask_slice)
                        else:
                            chunk_output = self.model(input_ids=batch_slice, attention_mask=mask_slice)

                    total_time += self.timer.elapsed_time()

                    hidden = chunk_output.last_hidden_state
                    if chunk_cuda_graph is not None:
                        # A graph replay writes into the same static output
                        # tensor every time. Without a snapshot every list
                        # entry would alias it and the concatenation would be
                        # the last chunk repeated. The clone is outside the
                        # timed region.
                        hidden = hidden.clone()
                    outputs.append(hidden)

            # Combine outputs
            combined_output = torch.cat(outputs, dim=0)
            print(f"Combined output shape: {combined_output.shape}")

            # Create a wrapper object similar to model output to maintain consistency
            class DummyOutput:
                def __init__(self, hidden_states):
                    self.last_hidden_state = hidden_states

            output = DummyOutput(combined_output)
            return total_time, output
        else:
            # Process normally for small batches
            with torch.no_grad(), self.timer.timing():
                if cuda_graph_wrapper is not None:
                    output = cuda_graph_wrapper(input_ids, attention_mask)
                else:
                    output = self.model(input_ids=input_ids, attention_mask=attention_mask)

            print(f"Output shape: {output.last_hidden_state.shape}")
            return self.timer.elapsed_time(), output

    def run(self) -> Dict[int, Dict[str, float]]:
        """Benchmark every configured batch size.

        Returns:
            A dict mapping batch size to ``avg_time`` (s), ``std_time`` (s)
            and ``throughput`` (sequences/s), all plain Python floats.
        """
        results = {}

        for batch_size in self.config.batch_sizes:
            print(f"\nTesting batch size: {batch_size}")
            times = []

            # Get or create CUDA graph for this batch size
            cuda_graph_wrapper = None
            if self.cuda_graphs is not None:
                # get_or_create clamps to max_batch_size, so oversized batches
                # receive the max-size graph. _run_inference only uses it as a
                # "graphs enabled" marker in that case and looks up per-chunk
                # graphs itself.
                with torch.no_grad():
                    cuda_graph_wrapper = self.cuda_graphs.get_or_create(batch_size)

            # Pre-allocate input tensor
            input_ids = self._create_random_batch(batch_size)

            # Run benchmark
            for run_idx in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
                elapsed_time, _ = self._run_inference(input_ids, cuda_graph_wrapper)
                times.append(elapsed_time)
                print(f"Run {run_idx+1}: {elapsed_time:.4f}s")

            # Calculate statistics (converted from numpy scalars to floats)
            avg_time = float(np.mean(times))
            std_time = float(np.std(times))
            throughput = batch_size / avg_time

            results[batch_size] = {
                "avg_time": avg_time,
                "std_time": std_time,
                "throughput": throughput,
            }

            print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
            print(f"Throughput: {throughput:.2f} sequences/second")

        return results
|
||||
|
||||
|
||||
def main():
    """Parse the CLI flags, run the benchmark, and print a summary table."""
    # (flag, add_argument keyword arguments), in the order they are registered.
    option_specs = (
        ("--model_path", dict(type=str, default="facebook/contriever", help="Path to the model")),
        (
            "--batch_sizes",
            dict(
                type=str,
                default="1,2,4,8,16,32,64,128,256,512,1024,2048,4096",
                help="Comma-separated list of batch sizes",
            ),
        ),
        ("--seq_length", dict(type=int, default=256, help="Sequence length for input")),
        ("--num_runs", dict(type=int, default=5, help="Number of runs for each batch size")),
        ("--no_fp16", dict(action="store_true", help="Disable FP16 inference")),
        ("--use_cuda_graphs", dict(action="store_true", help="Enable CUDA Graphs optimization")),
        ("--use_flash_attention", dict(action="store_true", help="Enable Flash Attention 2 if available")),
        (
            "--max_batch_size",
            dict(type=int, default=256, help="Maximum batch size before splitting to prevent OOM"),
        ),
    )

    parser = argparse.ArgumentParser(description="Model Inference Benchmark")
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # "--batch_sizes" arrives as one comma-separated string.
    requested_sizes = [int(bs) for bs in args.batch_sizes.split(",")]

    config = BenchmarkConfig(
        model_path=args.model_path,
        batch_sizes=requested_sizes,
        seq_length=args.seq_length,
        num_runs=args.num_runs,
        use_fp16=not args.no_fp16,  # the CLI flag is the negation of the setting
        use_cuda_graphs=args.use_cuda_graphs,
        use_flash_attention=args.use_flash_attention,
        max_batch_size=args.max_batch_size,
    )

    results = Benchmark(config).run()

    # Print overall summary
    print("\n===== BENCHMARK SUMMARY =====")
    print(f"Model: {config.model_path}")
    print(f"Sequence Length: {config.seq_length}")
    print(f"FP16: {config.use_fp16}")
    print(f"CUDA Graphs: {config.use_cuda_graphs}")
    print(f"Flash Attention: {config.use_flash_attention}")
    print(f"Max Batch Size: {config.max_batch_size}")
    print("\nResults:")

    print("\nBatch Size | Avg Time (s) | Throughput (seq/s)")
    print("-" * 50)
    for bs in sorted(results):
        r = results[bs]
        print(f"{bs:^10} | {r['avg_time']:^12.4f} | {r['throughput']:^17.2f}")


if __name__ == "__main__":
    main()
|
||||
218
research/micro/int4benchmark.py
Normal file
218
research/micro/int4benchmark.py
Normal file
@@ -0,0 +1,218 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import time
|
||||
import torch.nn.functional as F
|
||||
|
||||
# Import necessary functions from the quantize.py file
|
||||
def get_group_qparams(w, n_bit=4, groupsize=128):
    """Compute per-group quantization scales and zero points for a 2-D tensor.

    Each row of ``w`` is cut into consecutive groups of ``groupsize`` values.
    For every group the scale is ``(max - min) / (2**n_bit - 1)``, floored at
    1e-6 before the division, and the zero point is
    ``min + scale * 2**(n_bit - 1)``.

    Returns:
        ``(scales, zeros)``, both bfloat16 with shape ``(w.shape[0], -1)``.
    """
    row_width = w.shape[-1]
    # needed for GPTQ with padding
    groupsize = min(groupsize, row_width)
    assert groupsize > 1
    assert row_width % groupsize == 0
    assert w.dim() == 2

    grouped = w.reshape(-1, groupsize)
    assert torch.isnan(grouped).sum() == 0

    group_max = grouped.amax(dim=1, keepdim=True)
    group_min = grouped.amin(dim=1, keepdim=True)

    levels = 2**n_bit - 1
    half_range = 2 ** (n_bit - 1)
    scales = (group_max - group_min).clamp(min=1e-6) / levels
    zeros = group_min + scales * half_range

    rows = w.shape[0]
    scales_bf16 = scales.to(torch.bfloat16).reshape(rows, -1)
    zeros_bf16 = zeros.to(torch.bfloat16).reshape(rows, -1)
    return scales_bf16, zeros_bf16
|
||||
|
||||
def pack_scales_and_zeros(scales, zeros):
    """Interleave bfloat16 scales and zero points into one contiguous tensor.

    Both inputs must have the same shape ``(rows, groups)``. The result has
    shape ``(groups, rows, 2)``, with the scale at index 0 and the zero point
    at index 1 of the last dimension.
    """
    assert scales.shape == zeros.shape
    assert scales.dtype == torch.bfloat16
    assert zeros.dtype == torch.bfloat16

    # Give each input a trailing axis of length 1 so they can be joined on it.
    scales_3d = scales.reshape(scales.size(0), scales.size(1), 1)
    zeros_3d = zeros.reshape(zeros.size(0), zeros.size(1), 1)

    paired = torch.cat([scales_3d, zeros_3d], 2)
    return paired.transpose(0, 1).contiguous()
|
||||
|
||||
def group_quantize_tensor(w, n_bit=4, groupsize=128):
    """Group-quantize ``w`` and return ``(w_int32, scales_and_zeros)``.

    The quantization parameters come from ``get_group_qparams``, the integer
    codes from ``group_quantize_tensor_from_qparams``, and the parameters are
    packed with ``pack_scales_and_zeros``.
    """
    group_scales, group_zeros = get_group_qparams(w, n_bit, groupsize)
    quantized = group_quantize_tensor_from_qparams(
        w, group_scales, group_zeros, n_bit, groupsize
    )
    return quantized, pack_scales_and_zeros(group_scales, group_zeros)
|
||||
|
||||
def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
    """Quantize ``w`` to integer codes using precomputed per-group parameters.

    Every value is mapped to ``round((x - min_val) / scale)`` with
    ``min_val = zero - scale * 2**(n_bit - 1)``, then clamped to
    ``[0, 2**n_bit - 1]``.

    Returns:
        An int32 tensor with the same shape as ``w``.
    """
    assert groupsize > 1
    # needed for GPTQ single column quantize
    row_width = w.shape[-1]
    if groupsize > row_width and scales.shape[-1] == 1:
        groupsize = row_width

    assert w.shape[-1] % groupsize == 0
    assert w.dim() == 2

    grouped = w.reshape(-1, groupsize)
    assert torch.isnan(grouped).sum() == 0

    # One (scale, zero) pair per group, as column vectors for broadcasting.
    col_scales = scales.reshape(-1, 1)
    col_zeros = zeros.reshape(-1, 1)
    floor_val = col_zeros - col_scales * (2 ** (n_bit - 1))

    lowest_code = 0
    highest_code = 2**n_bit - 1

    codes = grouped.sub(floor_val).div(col_scales).round()
    codes.clamp_(lowest_code, highest_code)
    w_int32 = codes.to(torch.int32).reshape_as(w)

    return w_int32
|
||||
|
||||
def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
    """Quantize a weight matrix to 4 bits and convert it to the packed int4 layout.

    Returns:
        ``(weight_int4pack, scales_and_zeros)``: the output of
        ``torch.ops.aten._convert_weight_to_int4pack`` and the packed
        quantization parameters from ``group_quantize_tensor``.
    """
    quantized, packed_qparams = group_quantize_tensor(
        weight_bf16, n_bit=4, groupsize=groupsize
    )
    packed_weight = torch.ops.aten._convert_weight_to_int4pack(quantized, inner_k_tiles)
    return packed_weight, packed_qparams
|
||||
|
||||
def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
    """Apply an int4 weight-only linear layer to ``x``.

    All leading dimensions of ``x`` are flattened into one, the flattened
    matrix is multiplied with ``torch.ops.aten._weight_int4pack_mm``, and the
    result is reshaped to the leading dimensions of ``x`` followed by
    ``out_features``.
    """
    input_shape = x.size()
    flat_input = x.reshape(-1, input_shape[-1])
    flat_output = torch.ops.aten._weight_int4pack_mm(
        flat_input, weight_int4pack, groupsize, scales_and_zeros
    )
    output_shape = input_shape[:-1] + (out_features,)
    return flat_output.reshape(output_shape)
|
||||
|
||||
class WeightOnlyInt4Linear(torch.nn.Module):
    """Linear layer whose weight is stored in a packed int4 layout.

    The layer holds two buffers: ``weight`` (int32, packed) and
    ``scales_and_zeros`` (bfloat16). Both are allocated uninitialized and
    must be filled by the caller. The ``bias``, ``device`` and ``dtype``
    arguments are accepted but not used by this implementation.
    """

    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self, in_features: int, out_features: int,
        bias=False, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles

        k_tile_width = inner_k_tiles * 16
        assert out_features % 8 == 0, "require out_features % 8 == 0"
        assert in_features % k_tile_width == 0, "require in_features % (innerKTiles * 16) == 0"

        packed_weight_shape = (
            out_features // 8,
            in_features // k_tile_width,
            32,
            inner_k_tiles // 2,
        )
        qparams_shape = (in_features // groupsize, out_features, 2)

        self.register_buffer("weight", torch.empty(packed_weight_shape, dtype=torch.int32))
        self.register_buffer("scales_and_zeros", torch.empty(qparams_shape, dtype=torch.bfloat16))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Cast ``input`` to bfloat16 and run the int4 matmul."""
        activations = input.to(torch.bfloat16)
        return linear_forward_int4(
            activations,
            self.weight,
            self.scales_and_zeros,
            self.out_features,
            self.groupsize,
        )
|
||||
|
||||
# Define dimensions that satisfy the requirements for INT4 quantization
# in_features must be divisible by inner_k_tiles * 16
# out_features must be divisible by 8
in_features = 1024  # Must be divisible by inner_k_tiles * 16
out_features = 2048  # Must be divisible by 8
groupsize = 128
inner_k_tiles = 8

# Pick the device first so that both models, and the weight packing below,
# live on the device the kernels will run on.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create models
# The baseline is cast to bfloat16. It is later called with a bfloat16 input,
# and a float32 nn.Linear would reject that with a dtype-mismatch error.
# The variable keeps its historical "fp16" name.
fp16_model = nn.Sequential(
    nn.Linear(in_features, out_features, bias=False)
).to(device=device, dtype=torch.bfloat16)

# Create INT4 model
int4_model = nn.Sequential(
    WeightOnlyInt4Linear(in_features, out_features, bias=False,
                         groupsize=groupsize, inner_k_tiles=inner_k_tiles)
).to(device)

# Quantize the weights and set up the INT4 model
with torch.no_grad():
    # Convert the baseline weights to INT4.
    # NOTE(review): packing is done on `device` on the assumption that the
    # layout produced by _convert_weight_to_int4pack is device-specific —
    # confirm against the installed PyTorch version.
    fp16_weight = fp16_model[0].weight.data.to(torch.bfloat16)
    weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros(
        fp16_weight, groupsize, inner_k_tiles
    )

    # Set the quantized weights in the INT4 model
    int4_model[0].weight.copy_(weight_int4pack)
    int4_model[0].scales_and_zeros.copy_(scales_and_zeros)

# Create random input tensor
batch_size = 1024
input_tensor = torch.randn(batch_size, in_features, device=device)
input_tensor_bf16 = input_tensor.to(torch.bfloat16)
|
||||
|
||||
# Speed test function
|
||||
def speed_test(model, input_tensor, name, num_iterations=100):
    """Time ``model(input_tensor)`` and return the mean seconds per call.

    Ten untimed warm-up calls are made first. When CUDA is available the
    device is synchronized before the clock starts and before it stops, so
    asynchronous kernels are included in the measurement. On a CPU-only
    machine the synchronization is skipped.

    Args:
        model: callable taking ``input_tensor``.
        input_tensor: the argument passed to every call.
        name: label used in the printed result line.
        num_iterations: number of timed calls; must be at least 1.

    Returns:
        Average wall time per call, in seconds.

    Raises:
        ValueError: if ``num_iterations`` is smaller than 1.
    """
    if num_iterations < 1:
        raise ValueError("num_iterations must be at least 1")

    def _sync():
        # torch.cuda.synchronize() raises when no CUDA device is present.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    # Warmup
    for _ in range(10):
        _ = model(input_tensor)

    # Actual timing. perf_counter is monotonic and high resolution, unlike
    # time.time(), which can jump when the system clock is adjusted.
    _sync()
    start_time = time.perf_counter()

    for _ in range(num_iterations):
        _ = model(input_tensor)

    _sync()
    end_time = time.perf_counter()

    avg_time = (end_time - start_time) / num_iterations
    print(f"{name} model: {avg_time:.6f} seconds per iteration")
    return avg_time
|
||||
|
||||
# --- Speed comparison ---------------------------------------------------
print(f"Running benchmark with batch_size={batch_size}, in_features={in_features}, out_features={out_features}")
print(f"INT4 parameters: groupsize={groupsize}, inner_k_tiles={inner_k_tiles}")

with torch.no_grad():  # inference only, no gradients needed
    fp16_time = speed_test(fp16_model, input_tensor_bf16, "FP16")
    int4_time = speed_test(int4_model, input_tensor, "INT4")

# Ratio above 1 means the INT4 layer took less time per iteration.
speedup = fp16_time / int4_time
print(f"INT4 is {speedup:.2f}x faster than FP16")

# --- Memory comparison --------------------------------------------------
# Bytes held by the baseline weight versus the packed weight plus its
# quantization parameters.
baseline_weight = fp16_model[0].weight
packed_weight = int4_model[0].weight
packed_qparams = int4_model[0].scales_and_zeros

fp16_memory = baseline_weight.nelement() * baseline_weight.element_size()
int4_memory = (
    packed_weight.nelement() * packed_weight.element_size()
    + packed_qparams.nelement() * packed_qparams.element_size()
)

memory_reduction = fp16_memory / int4_memory
print(f"Memory reduction: {memory_reduction:.2f}x ({fp16_memory/1024/1024:.2f} MB vs {int4_memory/1024/1024:.2f} MB)")

# --- Accuracy comparison ------------------------------------------------
with torch.no_grad():
    fp16_output = fp16_model(input_tensor_bf16)
    int4_output = int4_model(input_tensor)

    # Element-wise error of the INT4 output against the baseline output;
    # the small epsilon keeps the relative error finite where the baseline
    # output is zero.
    abs_error = torch.abs(fp16_output - int4_output)
    rel_error = abs_error / (torch.abs(fp16_output) + 1e-7)

    print(f"Mean absolute error: {abs_error.mean().item():.6f}")
    print(f"Max absolute error: {abs_error.max().item():.6f}")
    print(f"Mean relative error: {rel_error.mean().item():.6f}")
|
||||
83
research/micro/int8.py
Normal file
83
research/micro/int8.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import torch
|
||||
import nvmath.bindings.cublas
|
||||
import ctypes
|
||||
|
||||
# Script: call nvmath's cuBLAS binding `uint8gemm_bias` on random uint8
# matrices created with torch and print the output buffer.

# Create the cuBLAS handle
handle = nvmath.bindings.cublas.create()

# Prepare the data: uint8 tensors, made contiguous in memory
m, n, k = 64, 32, 48
a = (torch.rand(m, k, device="cuda") * 255).to(torch.uint8).contiguous()
b = (torch.rand(k, n, device="cuda") * 255).to(torch.uint8).contiguous()
c = torch.zeros(m, n, device="cuda", dtype=torch.uint8).contiguous()

# Make sure the tensors are on CUDA
assert a.is_cuda and b.is_cuda and c.is_cuda
# Make sure the tensors are contiguous
assert a.is_contiguous() and b.is_contiguous() and c.is_contiguous()

# Get the raw device pointers (plain Python ints)
a_ptr = a.data_ptr()
b_ptr = b.data_ptr()
c_ptr = c.data_ptr()

# Set up the parameters
transa = 0  # CUBLAS_OP_N (no transpose)
transb = 0  # CUBLAS_OP_N (no transpose)
transc = 0  # CUBLAS_OP_N (no transpose)

# Set the bias values
a_bias = 0
b_bias = 0
c_bias = 0

# Set the leading dimensions
# NOTE(review): these are the row lengths of the row-major torch tensors;
# cuBLAS routines conventionally assume column-major storage — confirm that
# these values (and the operand order) match what uint8gemm_bias expects.
lda = k  # leading dimension of A
ldb = n  # leading dimension of B
ldc = n  # leading dimension of C

c_mult = 1
c_shift = 0

# Print debug information
print(f"a shape: {a.shape}, a_ptr: {a_ptr}")
print(f"b shape: {b.shape}, b_ptr: {b_ptr}")
print(f"c shape: {c.shape}, c_ptr: {c_ptr}")

try:
    # Call uint8gemm_bias
    nvmath.bindings.cublas.uint8gemm_bias(
        handle,
        transa, transb, transc,
        m, n, k,
        a_ptr, a_bias, lda,
        b_ptr, b_bias, ldb,
        c_ptr, c_bias, ldc,
        c_mult, c_shift
    )
except Exception as e:
    print(f"Error: {e}")
    # Try converting the pointers through ctypes
    # NOTE(review): c_void_p(x).value of a non-zero int is the same int, so
    # the retry below appears to pass the same arguments as the first call —
    # confirm the intent. If the retry raises, the handle is not destroyed.
    a_ptr_c = ctypes.c_void_p(a_ptr).value
    b_ptr_c = ctypes.c_void_p(b_ptr).value
    c_ptr_c = ctypes.c_void_p(c_ptr).value

    print(f"Using ctypes: a_ptr: {a_ptr_c}, b_ptr: {b_ptr_c}, c_ptr: {c_ptr_c}")

    # Try the call again
    nvmath.bindings.cublas.uint8gemm_bias(
        handle,
        transa, transb, transc,
        m, n, k,
        a_ptr_c, a_bias, lda,
        b_ptr_c, b_bias, ldb,
        c_ptr_c, c_bias, ldc,
        c_mult, c_shift
    )

# Destroy the cuBLAS handle
nvmath.bindings.cublas.destroy(handle)

# Print the result
print("Result:")
print(c)
|
||||
23
research/micro/llm_compress.py
Normal file
23
research/micro/llm_compress.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
|
||||
from llmcompressor.modifiers.quantization import GPTQModifier
|
||||
from llmcompressor import oneshot
|
||||
|
||||
# Quantization recipe, applied in order:
#   1. SmoothQuant, to make the activations easier to quantize
#   2. GPTQ with the W8A8 scheme on Linear layers: int8 weights (static per
#      channel) and int8 activations (dynamic per token), skipping "lm_head"
smoothing_step = SmoothQuantModifier(smoothing_strength=0.8)
gptq_step = GPTQModifier(scheme="W8A8", targets="Linear", ignore=["lm_head"])
recipe = [smoothing_step, gptq_step]

# One-shot calibration settings. The built-in open_platypus dataset is used;
# see the llmcompressor examples for passing a custom calibration set.
# NOTE(review): the output directory is named "INT4" although the recipe
# above uses the W8A8 (int8) scheme — confirm the intended name.
oneshot_options = {
    "model": "facebook/contriever",
    "dataset": "open_platypus",
    "recipe": recipe,
    "output_dir": "contriever-INT4",
    "max_seq_length": 2048,
    "num_calibration_samples": 512,
}

oneshot(**oneshot_options)
|
||||
41
research/micro/nvmath_test.py
Normal file
41
research/micro/nvmath_test.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
"""
|
||||
This example demonstrates basic matrix multiplication of FP8 tensors.
|
||||
|
||||
In narrow-precision operations, quantization scales must be provided for each tensor. These
|
||||
scales are used to dequantize input operands and quantize the result. Without proper
|
||||
scaling, the results of FP8 operations will likely exceed the type's range.
|
||||
|
||||
FP8 is only supported with cuBLAS 12.8 or newer and on devices with compute
|
||||
capability 8.9 or higher.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
import nvmath
|
||||
|
||||
# Problem size. N, M and K must be divisible by 16 for FP8.
m, n, k = 64, 32, 48

# Build the FP8 operands from random values scaled into [0, 10).
# cuBLAS requires B to be column-major, so B is created row-major as (n, k)
# and then transposed.
a_source = torch.rand(m, k, device="cuda") * 10
a = a_source.type(torch.float8_e4m3fn)
b_source = torch.rand(n, k, device="cuda") * 10
b = b_source.type(torch.float8_e4m3fn).T

# Quantization scales. They must let the result fit within the dynamic range
# of the data type used. Scales may be given as a dictionary or as a
# MatmulQuantizationScales object, and are only allowed for FP8 operands.
scales = {"a": 1, "b": 1, "d": 0.1}
unit_scales = {"a": 1, "b": 1, "d": 1}

# Scaled multiplication. The result is:
# (scales.a * A) @ (scales.b * B) * scales.d
result = nvmath.linalg.advanced.matmul(a, b, quantization_scales=scales)

# Same multiplication with every scale set to 1, to show what the D scale
# achieves for the float8_e4m3fn output.
result_without_scaling = nvmath.linalg.advanced.matmul(a, b, quantization_scales=unit_scales)

print("Without scaling, most of the elements were clamped to the maximum value of float8_e4m3fn type (448):")
print(result_without_scaling)
print(f"\nWith D scale set to {scales['d']}, they were scaled down to fit into the dynamic range of float8_e4m3fn:")
print(result)
|
||||
0
research/micro/result.md
Normal file
0
research/micro/result.md
Normal file
58
research/micro/save_small_model.py
Normal file
58
research/micro/save_small_model.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import os
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from pathlib import Path
|
||||
|
||||
def save_model_in_pth_format(model_name, output_dir):
    """
    Fetch a Hugging Face causal-LM and store its weights as ``model.pth``.

    The tokenizer is saved into ``output_dir`` as well, and the parameter
    count and parameter size in MB are printed.

    Args:
        model_name: Name of the model on Hugging Face
        output_dir: Directory to save the model (created if missing)

    Returns:
        Path of the written ``model.pth`` file.
    """
    print(f"Loading model {model_name}...")

    # Make sure the destination directory exists
    os.makedirs(output_dir, exist_ok=True)

    # Load the tokenizer first, then the model in bfloat16
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True
    )

    # Persist the tokenizer next to the weights
    tokenizer.save_pretrained(output_dir)

    # Write the raw state dict in PTH format
    model_path = Path(output_dir) / "model.pth"
    torch.save(model.state_dict(), model_path)
    print(f"Model saved to {model_path}")

    # Tally the parameter count and byte size in a single pass
    param_count = 0
    total_bytes = 0
    for param in model.parameters():
        elements = param.numel()
        param_count += elements
        total_bytes += elements * param.element_size()
    model_size_mb = total_bytes / (1024 * 1024)

    print(f"Model parameters: {param_count:,}")
    print(f"Model size: {model_size_mb:.2f} MB")

    return model_path
|
||||
|
||||
if __name__ == "__main__":
    # A small model keeps the download and the saved file manageable for testing.
    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    output_dir = "./tinyllama-1.1b-chat"

    model_path = save_model_in_pth_format(model_name, output_dir)

    # NOTE(review): the hint below passes --model_path to int4benchmark.py;
    # confirm that script actually parses such an argument.
    print(
        "\nYou can now use this model with the INT4 benchmark script.",
        "Example command:",
        f"python int4benchmark.py --model_path {model_path}",
        sep="\n",
    )
|
||||
677
research/micro/transformer-batching-benchmark.ipynb
Normal file
677
research/micro/transformer-batching-benchmark.ipynb
Normal file
@@ -0,0 +1,677 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "cab91cfc",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/ubuntu/Power-RAG/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||||
" from .autonotebook import tqdm as notebook_tqdm\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import copy\n",
|
||||
"import dataclasses\n",
|
||||
"import os\n",
|
||||
"import time\n",
|
||||
"import pathlib\n",
|
||||
"import itertools\n",
|
||||
"import multiprocessing\n",
|
||||
"import scipy\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"import pickle\n",
|
||||
"import gzip\n",
|
||||
"import threading\n",
|
||||
"import queue\n",
|
||||
"import pytz\n",
|
||||
"import traceback\n",
|
||||
"from datetime import datetime\n",
|
||||
"from tqdm.auto import tqdm, trange\n",
|
||||
"from typing import Any\n",
|
||||
"\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import matplotlib.ticker as mtick\n",
|
||||
"%matplotlib inline\n",
|
||||
"%config InlineBackend.figure_format='retina'"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "8d24fbd7",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Sat Apr 12 00:10:05 2025 \n",
|
||||
"+-----------------------------------------------------------------------------------------+\n",
|
||||
"| NVIDIA-SMI 550.120 Driver Version: 550.120 CUDA Version: 12.4 |\n",
|
||||
"|-----------------------------------------+------------------------+----------------------+\n",
|
||||
"| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
|
||||
"| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
|
||||
"| | | MIG M. |\n",
|
||||
"|=========================================+========================+======================|\n",
|
||||
"| 0 NVIDIA A10G Off | 00000000:00:1E.0 Off | 0 |\n",
|
||||
"| 0% 27C P8 15W / 300W | 4MiB / 23028MiB | 0% Default |\n",
|
||||
"| | | N/A |\n",
|
||||
"+-----------------------------------------+------------------------+----------------------+\n",
|
||||
" \n",
|
||||
"+-----------------------------------------------------------------------------------------+\n",
|
||||
"| Processes: |\n",
|
||||
"| GPU GI CI PID Type Process name GPU Memory |\n",
|
||||
"| ID ID Usage |\n",
|
||||
"|=========================================================================================|\n",
|
||||
"| No running processes found |\n",
|
||||
"+-----------------------------------------------------------------------------------------+\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"!nvidia-smi"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "538b2c11",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def benchmark(f, *, f_setup=None, min_repeat: int, min_secs: float, tqdm_kwargs: dict | None=None) -> np.ndarray:\n",
|
||||
" latency = []\n",
|
||||
" \n",
|
||||
" # First run, ignore min_secs\n",
|
||||
" if f_setup is not None:\n",
|
||||
" f_setup()\n",
|
||||
" st = time.perf_counter_ns()\n",
|
||||
" f()\n",
|
||||
" ed = time.perf_counter_ns()\n",
|
||||
" latency.append((ed-st)/1e9)\n",
|
||||
" \n",
|
||||
" # Subsequent runs, until reaching both min_repeat and min_secs\n",
|
||||
" min_nanos = int(min_secs * 1e9)\n",
|
||||
" start_nanos = time.perf_counter_ns()\n",
|
||||
" while True:\n",
|
||||
" now_nanos = time.perf_counter_ns()\n",
|
||||
" if len(latency) > min_repeat and now_nanos - start_nanos > min_nanos:\n",
|
||||
" break\n",
|
||||
" if f_setup is not None:\n",
|
||||
" f_setup()\n",
|
||||
" st = time.perf_counter_ns()\n",
|
||||
" f()\n",
|
||||
" ed = time.perf_counter_ns()\n",
|
||||
" latency.append((ed-st)/1e9)\n",
|
||||
" return np.array(latency)\n",
|
||||
"\n",
|
||||
"def tail_mean(xs, skip=0.2):\n",
|
||||
" return xs[int(len(xs) * skip):].mean()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "02c9c9b1",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<torch.autograd.grad_mode.set_grad_enabled at 0x7c5afc12b850>"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"torch.set_grad_enabled(False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "3405fdc7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"nd_list = list(itertools.chain(itertools.product([12, 3], [256])))\n",
|
||||
"seqlen_list = [256]\n",
|
||||
"bs_list = [2,4,8,16,32,64,128,256,512,1024,2048]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "10dc981a",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[(12, 256), (3, 256)]\n",
|
||||
"[256]\n",
|
||||
"[2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(nd_list)\n",
|
||||
"print(seqlen_list)\n",
|
||||
"print(bs_list)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "7e0ee385",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def benchmark_dense(out, nd_list, seqlen_list, bs_list):\n",
|
||||
" seqlen_list = [1] + seqlen_list\n",
|
||||
" total = len(list(itertools.product(nd_list, seqlen_list, bs_list)))\n",
|
||||
" pbar = tqdm(total=total)\n",
|
||||
" for (n, d), seqlen in reversed(list(itertools.product(nd_list, seqlen_list))):\n",
|
||||
" h = n * d\n",
|
||||
" maxbs = max(bs_list)\n",
|
||||
" print(maxbs, n, d, seqlen)\n",
|
||||
" cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=\"cuda:0\")\n",
|
||||
" X = torch.rand((maxbs, seqlen, h), dtype=torch.bfloat16, device=\"cuda:0\")\n",
|
||||
" W = torch.rand((h, h), dtype=torch.bfloat16, device=\"cuda:0\")\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" for bs in reversed(bs_list):\n",
|
||||
" pbar.set_postfix(n=n, h=h, d=d, seqlen=seqlen, bs=bs)\n",
|
||||
" def run():\n",
|
||||
" torch.matmul(X[:bs], W)\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" def clear_cache():\n",
|
||||
" cache.zero_()\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" latency = benchmark(run, f_setup=clear_cache, min_repeat=20, min_secs=2)\n",
|
||||
" l = tail_mean(latency)\n",
|
||||
" out.append({\n",
|
||||
" \"n\": n,\n",
|
||||
" \"d\": d,\n",
|
||||
" \"seqlen\": seqlen,\n",
|
||||
" \"bs\": bs,\n",
|
||||
" \"latency\": l\n",
|
||||
" })\n",
|
||||
" pbar.update()\n",
|
||||
" del cache, X, W\n",
|
||||
" torch.cuda.empty_cache()\n",
|
||||
" pbar.close()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "c206a502",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def benchmark_qk_init(out, nd_list, seqlen_list, bs_list):\n",
|
||||
" total = len(list(itertools.product(nd_list, seqlen_list, bs_list)))\n",
|
||||
" pbar = tqdm(total=total)\n",
|
||||
" for (n, d), seqlen in reversed(list(itertools.product(nd_list, seqlen_list))):\n",
|
||||
" h = n * d\n",
|
||||
" try:\n",
|
||||
" maxbs = max(b for b in bs_list if b*n*seqlen*d*2*2+b*n*seqlen**2*2 < 80e9)\n",
|
||||
" except ValueError:\n",
|
||||
" pbar.update(len(bs_list))\n",
|
||||
" continue\n",
|
||||
" cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=\"cuda:0\")\n",
|
||||
" Qmax = torch.rand((maxbs, n, seqlen, d), dtype=torch.bfloat16, device=\"cuda:0\")\n",
|
||||
" Kmax = torch.rand((maxbs, n, seqlen, d), dtype=torch.bfloat16, device=\"cuda:0\")\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" for bs in reversed(bs_list):\n",
|
||||
" pbar.set_postfix(n=n, h=h, d=d, seqlen=seqlen, bs=bs)\n",
|
||||
" if bs > maxbs:\n",
|
||||
" pbar.update()\n",
|
||||
" continue\n",
|
||||
" Q = Qmax[:bs]\n",
|
||||
" K = Kmax[:bs]\n",
|
||||
" def run():\n",
|
||||
" torch.bmm(Q.view(bs * n, seqlen, d), K.view(bs * n, seqlen, d).transpose(1, 2))\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" def clear_cache():\n",
|
||||
" cache.zero_()\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" latency = benchmark(run, f_setup=clear_cache, min_repeat=20, min_secs=2)\n",
|
||||
" l = tail_mean(latency)\n",
|
||||
" out.append({\n",
|
||||
" \"n\": n,\n",
|
||||
" \"d\": d,\n",
|
||||
" \"seqlen\": seqlen,\n",
|
||||
" \"bs\": bs,\n",
|
||||
" \"latency\": l\n",
|
||||
" })\n",
|
||||
" pbar.update()\n",
|
||||
" del cache, Q, K, Qmax, Kmax\n",
|
||||
" torch.cuda.empty_cache()\n",
|
||||
" pbar.close()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "a3a2103c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def benchmark_qk_ar(out, nd_list, seqlen_list, bs_list):\n",
|
||||
" total = len(list(itertools.product(nd_list, seqlen_list, bs_list)))\n",
|
||||
" pbar = tqdm(total=total)\n",
|
||||
" for (n, d), seqlen in reversed(list(itertools.product(nd_list, seqlen_list))):\n",
|
||||
" h = n * d\n",
|
||||
" try:\n",
|
||||
" maxbs = max(b for b in bs_list if b*n*(1+seqlen)*d*2+b*n*seqlen*2 < 80e9)\n",
|
||||
" except ValueError:\n",
|
||||
" pbar.update(len(bs_list))\n",
|
||||
" continue\n",
|
||||
" cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=\"cuda:0\")\n",
|
||||
" Qmax = torch.rand((maxbs, n, 1, d), dtype=torch.bfloat16, device=\"cuda:0\")\n",
|
||||
" Kmax = torch.rand((maxbs, n, seqlen, d), dtype=torch.bfloat16, device=\"cuda:0\")\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" for bs in reversed(bs_list):\n",
|
||||
" pbar.set_postfix(n=n, h=h, d=d, seqlen=seqlen, bs=bs)\n",
|
||||
" if bs > maxbs:\n",
|
||||
" pbar.update()\n",
|
||||
" continue\n",
|
||||
" Q = Qmax[:bs]\n",
|
||||
" K = Kmax[:bs]\n",
|
||||
" def run():\n",
|
||||
" torch.bmm(Q.view(bs * n, 1, d), K.view(bs * n, seqlen, d).transpose(1, 2))\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" def clear_cache():\n",
|
||||
" cache.zero_()\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" latency = benchmark(run, f_setup=clear_cache, min_repeat=20, min_secs=2)\n",
|
||||
" l = tail_mean(latency)\n",
|
||||
" out.append({\n",
|
||||
" \"n\": n,\n",
|
||||
" \"d\": d,\n",
|
||||
" \"seqlen\": seqlen,\n",
|
||||
" \"bs\": bs,\n",
|
||||
" \"latency\": l\n",
|
||||
" })\n",
|
||||
" pbar.update()\n",
|
||||
" del cache, Q, K, Qmax, Kmax\n",
|
||||
" torch.cuda.empty_cache()\n",
|
||||
" pbar.close()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "3aaad98a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"data = {}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "18137de3",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" 0%| | 0/22 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 22/22 [00:44<00:00, 2.04s/it, bs=2, d=256, h=3072, n=12, seqlen=256] \n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"db = []\n",
|
||||
"benchmark_qk_init(db, nd_list, seqlen_list, bs_list)\n",
|
||||
"data[\"qk_init\"] = db"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "26c76e15",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 22/22 [00:44<00:00, 2.01s/it, bs=2, d=256, h=3072, n=12, seqlen=256] \n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"db = []\n",
|
||||
"benchmark_qk_ar(db, nd_list, seqlen_list, bs_list)\n",
|
||||
"data[\"qk_ar\"] = db"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"id": "313e36eb",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" 0%| | 0/44 [00:00<?, ?it/s, bs=2048, d=256, h=768, n=3, seqlen=256]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2048 3 256 256\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" 25%|██▌ | 11/44 [00:22<01:06, 2.00s/it, bs=2048, d=256, h=768, n=3, seqlen=1] "
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2048 3 256 1\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" 50%|█████ | 22/44 [00:44<00:44, 2.00s/it, bs=2048, d=256, h=3072, n=12, seqlen=256]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2048 12 256 256\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" 75%|███████▌ | 33/44 [01:07<00:22, 2.02s/it, bs=2048, d=256, h=3072, n=12, seqlen=1] "
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2048 12 256 1\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"100%|██████████| 44/44 [01:29<00:00, 2.03s/it, bs=2, d=256, h=3072, n=12, seqlen=1] \n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"db = []\n",
|
||||
"benchmark_dense(db, nd_list, seqlen_list, bs_list)\n",
|
||||
"data[\"dense\"] = db"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"id": "50c37959",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with gzip.open(\"data/20230516-transformer-batching1.pkl.gz\", \"wb\") as f:\n",
|
||||
" pickle.dump(data, f)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"id": "828ddb54",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df_dense = (\n",
|
||||
" pd.DataFrame.from_dict(data[\"dense\"])\n",
|
||||
" .assign(h=lambda x: x[\"n\"] * x[\"d\"])\n",
|
||||
" .assign(flop=lambda x: (x[\"bs\"] * x[\"seqlen\"] * x[\"h\"]**2) * 2)\n",
|
||||
" .assign(io=lambda x: (x[\"bs\"]*x[\"seqlen\"]*x[\"h\"]*2 + x[\"h\"]**2) * 2/x['latency']/1e9)\n",
|
||||
" .assign(intensity=lambda x: x[\"flop\"] / x[\"io\"])\n",
|
||||
" .assign(throughput=lambda x: x[\"flop\"] / x[\"latency\"])\n",
|
||||
" .assign(series=\"dense\")\n",
|
||||
")\n",
|
||||
"df_qk_init = (\n",
|
||||
" pd.DataFrame.from_dict(data[\"qk_init\"])\n",
|
||||
" .assign(h=lambda x: x[\"n\"] * x[\"d\"])\n",
|
||||
" .assign(flop=lambda x: (x[\"bs\"]*x[\"n\"]*x[\"d\"]*x[\"seqlen\"]**2) * 2)\n",
|
||||
" .assign(io=lambda x: (x[\"bs\"]*x[\"n\"]*(x[\"seqlen\"]*x[\"d\"]*2 + x[\"seqlen\"]**2)) * 2/x['latency']/1e9)\n",
|
||||
" .assign(intensity=lambda x: x[\"flop\"] / x[\"io\"])\n",
|
||||
" .assign(throughput=lambda x: x[\"flop\"] / x[\"latency\"])\n",
|
||||
" .assign(series=\"qk_init\")\n",
|
||||
")\n",
|
||||
"df_qk_ar = (\n",
|
||||
" pd.DataFrame.from_dict(data[\"qk_ar\"])\n",
|
||||
" .assign(h=lambda x: x[\"n\"] * x[\"d\"])\n",
|
||||
" .assign(flop=lambda x: (x[\"bs\"]*x[\"n\"]*x[\"d\"]*x[\"seqlen\"]) * 2)\n",
|
||||
" .assign(io=lambda x: (x[\"bs\"]*x[\"n\"]*(x[\"d\"] + x[\"seqlen\"]*x[\"d\"] + x[\"seqlen\"])) * 2)\n",
|
||||
" .assign(intensity=lambda x: x[\"flop\"] / x[\"io\"])\n",
|
||||
" .assign(throughput=lambda x: x[\"bs\"] / x[\"latency\"])\n",
|
||||
" .assign(series=\"qk_ar\")\n",
|
||||
")\n",
|
||||
"pd.concat([df_dense, df_qk_init, df_qk_ar]).to_csv(\"data/transformer-batching-microbenchmarks.csv\", index=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 39,
|
||||
"id": "c296a395",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<module 'pandas' from '/home/ubuntu/Power-RAG/.venv/lib/python3.10/site-packages/pandas/__init__.py'>"
|
||||
]
|
||||
},
|
||||
"execution_count": 39,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"pd\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a25cdd5a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "63b8a531",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import transformers"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "af90eff1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def _gen_opt_cfg(n_layers: int, d_model: int, n_heads: int, **kwargs) -> transformers.OPTConfig:\n",
|
||||
" return transformers.OPTConfig(\n",
|
||||
" num_hidden_layers=n_layers,\n",
|
||||
" hidden_size=d_model,\n",
|
||||
" ffn_dim=d_model*4,\n",
|
||||
" num_attention_heads=n_heads,\n",
|
||||
" **kwargs\n",
|
||||
" )\n",
|
||||
"optcfg = {\n",
|
||||
" # https://arxiv.org/pdf/2205.01068.pdf Table 2.1\n",
|
||||
" \"125m\": _gen_opt_cfg(12, 768, 12),\n",
|
||||
" \"350m\": _gen_opt_cfg(24, 1024, 16),\n",
|
||||
" \"760m\": _gen_opt_cfg(24, 1536, 16),\n",
|
||||
" \"1.3b\": _gen_opt_cfg(24, 2048, 32),\n",
|
||||
" \"2.7b\": _gen_opt_cfg(32, 2560, 32),\n",
|
||||
" \"6.7b\": _gen_opt_cfg(32, 4096, 32),\n",
|
||||
" \"13b\": _gen_opt_cfg(40, 5120, 40),\n",
|
||||
" \"13b_1layer\": _gen_opt_cfg(1, 5120, 40),\n",
|
||||
" \"30b\": _gen_opt_cfg(48, 7168, 56),\n",
|
||||
" \"66b\": _gen_opt_cfg(64, 9216, 72),\n",
|
||||
" \"175b\": _gen_opt_cfg(96, 12288, 96),\n",
|
||||
" \"175b_1layer\": _gen_opt_cfg(1, 12288, 96),\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5b9ebbec",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def greedy_sample_one(model, input_ids, attention_mask=None, past_key_values=None):\n",
|
||||
" bs, tgt_len = input_ids.shape\n",
|
||||
" if past_key_values is not None:\n",
|
||||
" _bs, _num_heads, src_len, _head_dims = past_key_values[0][0].shape\n",
|
||||
" assert bs == _bs\n",
|
||||
" else:\n",
|
||||
" src_len = 0\n",
|
||||
" if attention_mask is None:\n",
|
||||
" attention_mask = torch.ones((bs, src_len + tgt_len), device=model.device)\n",
|
||||
" ret = model(\n",
|
||||
" input_ids=input_ids,\n",
|
||||
" attention_mask=attention_mask,\n",
|
||||
" past_key_values=past_key_values,\n",
|
||||
" use_cache=True, output_hidden_states=False, return_dict=True,\n",
|
||||
" )\n",
|
||||
" return ret\n",
|
||||
"\n",
|
||||
"def time_greedy_generate(model, input_ids, new_tokens):\n",
|
||||
" ts = []\n",
|
||||
" output = input_ids\n",
|
||||
" past_key_values = None\n",
|
||||
" cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=model.device)\n",
|
||||
" attention_mask = torch.ones(input_ids.shape, device=model.device) \n",
|
||||
" for _ in range(new_tokens):\n",
|
||||
" cache.zero_()\n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" st = time.perf_counter_ns()\n",
|
||||
" \n",
|
||||
" ret = greedy_sample_one(model, input_ids, attention_mask, past_key_values)\n",
|
||||
" input_ids = torch.argmax(ret.logits[:, -1, :], axis=-1)[:, None]\n",
|
||||
" output = torch.cat([output, input_ids], axis=1)\n",
|
||||
" past_key_values = ret.past_key_values\n",
|
||||
" attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)\n",
|
||||
" \n",
|
||||
" torch.cuda.synchronize()\n",
|
||||
" ed = time.perf_counter_ns()\n",
|
||||
" ts.append((ed-st)/1e9)\n",
|
||||
" return np.array(ts)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "fc92f940",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"opt_config = optcfg[\"6.7b\"]\n",
|
||||
"\n",
|
||||
"torch.set_default_dtype(torch.bfloat16)\n",
|
||||
"with transformers.modeling_utils.no_init_weights():\n",
|
||||
" model = transformers.models.opt.OPTForCausalLM(opt_config).to(\"cuda\")\n",
|
||||
"torch.set_default_dtype(torch.float32)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c19fa396",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"db = {}\n",
|
||||
"input_tokens = 200\n",
|
||||
"new_tokens = 500\n",
|
||||
"for bs in tqdm(list(itertools.chain(range(1, 8), range(8, 16, 2), [16]))):\n",
|
||||
" x = torch.randint(1000, 10000, (bs, input_tokens), device=model.device)\n",
|
||||
" stack = []\n",
|
||||
" for _ in range(10):\n",
|
||||
" l = time_greedy_generate(model, x, new_tokens=new_tokens)\n",
|
||||
" stack.append(l)\n",
|
||||
" db[bs] = np.median(np.stack(stack), axis=0)\n",
|
||||
" del x\n",
|
||||
" torch.cuda.empty_cache()\n",
|
||||
"del model\n",
|
||||
"torch.cuda.empty_cache()\n",
|
||||
"\n",
|
||||
"with gzip.open(\"data/20230516-e2e-text-generation-batch.pkl.gz\", \"wb\") as f:\n",
|
||||
" pickle.dump(db, f)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
Reference in New Issue
Block a user