Fix ruff errors and formatting

This commit is contained in:
yichuan520030910320
2025-07-27 02:22:54 -07:00
parent 383c6d8d7e
commit af1790395a
35 changed files with 166 additions and 107 deletions

View File

@@ -58,7 +58,8 @@ class GraphWrapper:
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph):
self.static_output = self.model(
input_ids=self.static_input, attention_mask=self.static_attention_mask
input_ids=self.static_input,
attention_mask=self.static_attention_mask,
)
self.use_cuda_graph = True
else:
@@ -82,7 +83,10 @@ class GraphWrapper:
def _warmup(self, num_warmup: int = 3):
with torch.no_grad():
for _ in range(num_warmup):
self.model(input_ids=self.static_input, attention_mask=self.static_attention_mask)
self.model(
input_ids=self.static_input,
attention_mask=self.static_attention_mask,
)
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
if self.use_cuda_graph:
@@ -261,7 +265,10 @@ class Benchmark:
# print size
print(f"in_features: {in_features}, out_features: {out_features}")
new_module = bnb.nn.Linear8bitLt(
in_features, out_features, bias=bias, has_fp16_weights=False
in_features,
out_features,
bias=bias,
has_fp16_weights=False,
)
# Copy weights and bias
@@ -350,8 +357,6 @@ class Benchmark:
# Try xformers if available (only on CUDA)
if torch.cuda.is_available():
try:
from xformers.ops import memory_efficient_attention # noqa: F401
if hasattr(model, "enable_xformers_memory_efficient_attention"):
model.enable_xformers_memory_efficient_attention()
print("- Enabled xformers memory efficient attention")
@@ -427,7 +432,11 @@ class Benchmark:
else "cpu"
)
return torch.randint(
0, 1000, (batch_size, self.config.seq_length), device=device, dtype=torch.long
0,
1000,
(batch_size, self.config.seq_length),
device=device,
dtype=torch.long,
)
def _run_inference(