Files
LEANN/research/micro/int8.py
yichuan520030910320 46f6cc100b Initial commit
2025-06-30 09:05:05 +00:00

83 lines
2.0 KiB
Python

import torch
import nvmath.bindings.cublas
import ctypes
# Create the cuBLAS handle used by the GEMM call further down.
handle = nvmath.bindings.cublas.create()

# Problem size: a is (m, k), b is (k, n), c is (m, n).
m, n, k = 64, 32, 48

# Operands: random uint8 inputs and a zeroed uint8 output, all on the GPU
# and explicitly made contiguous.
a = torch.rand(m, k, device="cuda").mul(255).to(torch.uint8).contiguous()
b = torch.rand(k, n, device="cuda").mul(255).to(torch.uint8).contiguous()
c = torch.zeros(m, n, dtype=torch.uint8, device="cuda").contiguous()

# Sanity checks: every tensor must live on CUDA and be contiguous, because
# only raw device pointers are handed to cuBLAS.
assert all(t.is_cuda for t in (a, b, c))
assert all(t.is_contiguous() for t in (a, b, c))

# Raw device addresses (plain Python ints) of the three buffers.
a_ptr, b_ptr, c_ptr = (t.data_ptr() for t in (a, b, c))

# Transpose flags: 0 is CUBLAS_OP_N (no transpose) for all three matrices.
transa = transb = transc = 0

# Bias values applied to a, b and c.
a_bias = b_bias = c_bias = 0

# Leading dimensions, taken from the row length of each torch tensor.
# NOTE(review): cuBLAS conventionally assumes column-major storage while
# these tensors are row-major -- confirm these are the values
# uint8gemm_bias expects for non-transposed operands.
lda, ldb, ldc = k, n, n

# Output multiplier and shift passed to uint8gemm_bias.
c_mult, c_shift = 1, 0

# Debug output: shapes and device pointers.
print(f"a shape: {a.shape}, a_ptr: {a_ptr}")
print(f"b shape: {b.shape}, b_ptr: {b_ptr}")
print(f"c shape: {c.shape}, c_ptr: {c_ptr}")
# Run the uint8 GEMM. The outer try/finally guarantees the cuBLAS handle is
# destroyed even when both the first call and the retry raise; previously a
# failing retry skipped the destroy call and leaked the handle.
try:
    try:
        # Call uint8gemm_bias with the raw device pointers.
        nvmath.bindings.cublas.uint8gemm_bias(
            handle,
            transa, transb, transc,
            m, n, k,
            a_ptr, a_bias, lda,
            b_ptr, b_bias, ldb,
            c_ptr, c_bias, ldc,
            c_mult, c_shift,
        )
    except Exception as e:
        print(f"Error: {e}")
        # Try again with the pointers passed through ctypes.
        # NOTE(review): ctypes.c_void_p(int).value returns the same integer,
        # so this retry passes the same arguments as the first attempt --
        # confirm whether a different pointer representation was intended.
        a_ptr_c = ctypes.c_void_p(a_ptr).value
        b_ptr_c = ctypes.c_void_p(b_ptr).value
        c_ptr_c = ctypes.c_void_p(c_ptr).value
        print(f"Using ctypes: a_ptr: {a_ptr_c}, b_ptr: {b_ptr_c}, c_ptr: {c_ptr_c}")
        # Second attempt; an exception here propagates after cleanup.
        nvmath.bindings.cublas.uint8gemm_bias(
            handle,
            transa, transb, transc,
            m, n, k,
            a_ptr_c, a_bias, lda,
            b_ptr_c, b_bias, ldb,
            c_ptr_c, c_bias, ldc,
            c_mult, c_shift,
        )
finally:
    # Destroy the cuBLAS handle on every exit path.
    nvmath.bindings.cublas.destroy(handle)

# Print the result (only reached when one of the calls succeeded).
print("Result:")
print(c)