import time

import torch
import torch.nn as nn
import torch.nn.functional as F


# Helper functions from the quantize.py file (reproduced inline here rather than imported)
def get_group_qparams(w, n_bit=4, groupsize=128):
    # needed for GPTQ with padding
    if groupsize > w.shape[-1]:
        groupsize = w.shape[-1]
    assert groupsize > 1
    assert w.shape[-1] % groupsize == 0
    assert w.dim() == 2

    to_quant = w.reshape(-1, groupsize)
    assert torch.isnan(to_quant).sum() == 0

    max_val = to_quant.amax(dim=1, keepdim=True)
    min_val = to_quant.amin(dim=1, keepdim=True)
    max_int = 2**n_bit - 1
    scales = (max_val - min_val).clamp(min=1e-6) / max_int
    zeros = min_val + scales * (2 ** (n_bit - 1))
    return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(
        torch.bfloat16
    ).reshape(w.shape[0], -1)


def pack_scales_and_zeros(scales, zeros):
    assert scales.shape == zeros.shape
    assert scales.dtype == torch.bfloat16
    assert zeros.dtype == torch.bfloat16
    return (
        torch.cat(
            [
                scales.reshape(scales.size(0), scales.size(1), 1),
                zeros.reshape(zeros.size(0), zeros.size(1), 1),
            ],
            2,
        )
        .transpose(0, 1)
        .contiguous()
    )


def group_quantize_tensor(w, n_bit=4, groupsize=128):
    scales, zeros = get_group_qparams(w, n_bit, groupsize)
    w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
    scales_and_zeros = pack_scales_and_zeros(scales, zeros)
    return w_int32, scales_and_zeros


def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
    assert groupsize > 1
    # needed for GPTQ single column quantize
    if groupsize > w.shape[-1] and scales.shape[-1] == 1:
        groupsize = w.shape[-1]

    assert w.shape[-1] % groupsize == 0
    assert w.dim() == 2

    to_quant = w.reshape(-1, groupsize)
    assert torch.isnan(to_quant).sum() == 0

    scales = scales.reshape(-1, 1)
    zeros = zeros.reshape(-1, 1)
    min_val = zeros - scales * (2 ** (n_bit - 1))
    max_int = 2**n_bit - 1
    min_int = 0
    w_int32 = (
        to_quant.sub(min_val)
        .div(scales)
        .round()
        .clamp_(min_int, max_int)
        .to(torch.int32)
        .reshape_as(w)
    )

    return w_int32


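# --- Added for illustration; not part of the original script. ---
# A minimal sketch of the inverse of group_quantize_tensor_from_qparams, using the
# same (scales, zeros) convention as above: w ~= (w_int32 - 2**(n_bit - 1)) * scales + zeros.
# Handy for checking the round-trip error of the group quantizer.
def group_dequantize_tensor_from_qparams(w_int32, scales, zeros, n_bit=4, groupsize=128):
    assert groupsize > 1
    if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
        groupsize = w_int32.shape[-1]
    assert w_int32.shape[-1] % groupsize == 0
    assert w_int32.dim() == 2

    w_grouped = w_int32.reshape(-1, groupsize)
    scales = scales.reshape(-1, 1)
    zeros = zeros.reshape(-1, 1)

    # Undo the affine quantization: shift back by the zero-point offset, rescale, re-center
    w_dq = (w_grouped - 2 ** (n_bit - 1)).to(scales.dtype) * scales + zeros
    return w_dq.reshape_as(w_int32)

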
def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
    weight_int32, scales_and_zeros = group_quantize_tensor(
        weight_bf16, n_bit=4, groupsize=groupsize
    )
    # NOTE: _convert_weight_to_int4pack is a private ATen op; the input it expects
    # (int32 weights vs. pre-packed uint8) and the devices it supports have varied
    # across PyTorch releases. This follows the int32 convention and may require a
    # CUDA tensor on some builds.
    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
    return weight_int4pack, scales_and_zeros


def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
    origin_x_size = x.size()
    x = x.reshape(-1, origin_x_size[-1])
    c = torch.ops.aten._weight_int4pack_mm(x, weight_int4pack, groupsize, scales_and_zeros)
    new_shape = origin_x_size[:-1] + (out_features,)
    c = c.reshape(new_shape)
    return c


class WeightOnlyInt4Linear(torch.nn.Module):
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self, in_features: int, out_features: int,
        bias=False, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8
    ) -> None:
        # bias, device and dtype are accepted for API compatibility but are unused here
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles

        assert out_features % 8 == 0, "require out_features % 8 == 0"
        assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
        self.register_buffer(
            "weight",
            torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
        )
        self.register_buffer(
            "scales_and_zeros",
            torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16)
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        input = input.to(torch.bfloat16)
        return linear_forward_int4(
            input,
            self.weight, self.scales_and_zeros, self.out_features, self.groupsize
        )


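# --- Added for illustration; not part of the original script. ---
# Typical standalone usage (assumes a CUDA device and a PyTorch build whose int4
# kernels accept this packed layout):
#   layer = WeightOnlyInt4Linear(1024, 2048, groupsize=128, inner_k_tiles=8).cuda()
#   # ...copy a packed weight and scales_and_zeros into the buffers, as done below...
#   y = layer(torch.randn(4, 1024, device="cuda", dtype=torch.bfloat16))  # -> (4, 2048) bf16

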
# Define dimensions that satisfy the requirements for INT4 quantization:
# in_features must be divisible by inner_k_tiles * 16, out_features by 8.
in_features = 1024   # divisible by inner_k_tiles * 16
out_features = 2048  # divisible by 8
groupsize = 128
inner_k_tiles = 8

# Create the baseline model (labelled "FP16" below; its weights are cast to
# bfloat16 later so its matmul dtype matches the INT4 path's activations)
fp16_model = nn.Sequential(
    nn.Linear(in_features, out_features, bias=False)
)

# Create INT4 model
int4_model = nn.Sequential(
    WeightOnlyInt4Linear(in_features, out_features, bias=False,
                         groupsize=groupsize, inner_k_tiles=inner_k_tiles)
)

# Quantize the weights and set up the INT4 model
with torch.no_grad():
    # Cast the baseline weights to bfloat16 and group-quantize them to INT4
    fp16_weight = fp16_model[0].weight.data.to(torch.bfloat16)
    weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros(
        fp16_weight, groupsize, inner_k_tiles
    )

    # Set the quantized weights in the INT4 model
    int4_model[0].weight.copy_(weight_int4pack)
    int4_model[0].scales_and_zeros.copy_(scales_and_zeros)

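# --- Added for illustration; not part of the original script. ---
# Optional sanity check: round-trip the bf16 weight through the group quantizer
# using the group_dequantize_tensor_from_qparams sketch defined above.
# scales, zeros = get_group_qparams(fp16_weight, n_bit=4, groupsize=groupsize)
# w_q = group_quantize_tensor_from_qparams(fp16_weight, scales, zeros, 4, groupsize)
# w_dq = group_dequantize_tensor_from_qparams(w_q, scales, zeros, 4, groupsize)
# print(f"Max int4 round-trip error: {(fp16_weight - w_dq).abs().max().item():.6f}")
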
# Move models to the available device (GPU if present, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Cast the baseline model to bfloat16 so its matmul dtype matches the bf16 inputs
# (a float32 Linear fed bf16 activations would raise a dtype-mismatch error)
fp16_model = fp16_model.to(device=device, dtype=torch.bfloat16)
int4_model = int4_model.to(device)

# Create random input tensors
batch_size = 1024
input_tensor = torch.randn(batch_size, in_features, device=device)
input_tensor_bf16 = input_tensor.to(torch.bfloat16)

# Speed test function
def speed_test(model, input_tensor, name, num_iterations=100):
    # Warmup so one-time kernel selection/caching does not skew the timing
    for _ in range(10):
        _ = model(input_tensor)

    # Actual timing (synchronize so queued CUDA work is fully attributed;
    # guarded so the CPU fallback path does not raise)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start_time = time.time()

    for _ in range(num_iterations):
        _ = model(input_tensor)

    if torch.cuda.is_available():
        torch.cuda.synchronize()
    end_time = time.time()

    avg_time = (end_time - start_time) / num_iterations
    print(f"{name} model: {avg_time:.6f} seconds per iteration")
    return avg_time

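# Note (added): time.time() around a synchronized loop is adequate for this coarse
# comparison; torch.utils.benchmark.Timer or CUDA events give more robust numbers.
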
# Run speed tests
with torch.no_grad():  # Disable gradient calculation for inference
    print(f"Running benchmark with batch_size={batch_size}, in_features={in_features}, out_features={out_features}")
    print(f"INT4 parameters: groupsize={groupsize}, inner_k_tiles={inner_k_tiles}")

    fp16_time = speed_test(fp16_model, input_tensor_bf16, "FP16")
    int4_time = speed_test(int4_model, input_tensor, "INT4")

    # Calculate speedup
    speedup = fp16_time / int4_time
    print(f"INT4 is {speedup:.2f}x faster than FP16")

    # Calculate memory savings (packed INT4 weights plus the bf16 scales/zeros)
    fp16_memory = fp16_model[0].weight.nelement() * fp16_model[0].weight.element_size()
    int4_memory = (int4_model[0].weight.nelement() * int4_model[0].weight.element_size() +
                   int4_model[0].scales_and_zeros.nelement() * int4_model[0].scales_and_zeros.element_size())

    memory_reduction = fp16_memory / int4_memory
    print(f"Memory reduction: {memory_reduction:.2f}x ({fp16_memory/1024/1024:.2f} MB vs {int4_memory/1024/1024:.2f} MB)")

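# Added note: with these dimensions and the bf16 baseline above, the baseline weight is
# 1024 * 2048 * 2 bytes (4 MiB), while the packed INT4 weight is ~1 MiB plus ~64 KiB of
# bf16 scales/zeros, so a reduction of roughly 3.8x is expected.
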
# Check accuracy
with torch.no_grad():
    fp16_output = fp16_model(input_tensor_bf16)
    int4_output = int4_model(input_tensor)

    # Calculate error metrics
    abs_error = torch.abs(fp16_output - int4_output)
    rel_error = abs_error / (torch.abs(fp16_output) + 1e-7)

    print(f"Mean absolute error: {abs_error.mean().item():.6f}")
    print(f"Max absolute error: {abs_error.max().item():.6f}")
    print(f"Mean relative error: {rel_error.mean().item():.6f}")