fix micro bench and fix pre commit

This commit is contained in:
yichuan-w
2025-08-20 16:59:35 -07:00
parent a913903d73
commit 35f4fbd9d1
6 changed files with 35 additions and 33 deletions

View File

@@ -13,4 +13,5 @@ repos:
rev: v0.12.7 # Fixed version to match pyproject.toml
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- id: ruff-format

View File

@@ -6,8 +6,6 @@ results and the golden standard results, making the comparison robust to ID chan
"""
import argparse
import logging
import os
import json
import sys
import time
@@ -16,11 +14,6 @@ from pathlib import Path
import numpy as np
from leann.api import LeannBuilder, LeannChat, LeannSearcher
# Configure logging level (default INFO; override with LEANN_LOG_LEVEL)
_log_level_str = os.getenv("LEANN_LOG_LEVEL", "INFO").upper()
_log_level = getattr(logging, _log_level_str, logging.INFO)
logging.basicConfig(level=_log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
"""Checks if the data directory exists, and if not, downloads it from HF Hub."""

View File

@@ -20,7 +20,7 @@ except ImportError:
@dataclass
class BenchmarkConfig:
model_path: str = "facebook/contriever"
model_path: str = "facebook/contriever-msmarco"
batch_sizes: list[int] = None
seq_length: int = 256
num_runs: int = 5
@@ -34,7 +34,7 @@ class BenchmarkConfig:
def __post_init__(self):
if self.batch_sizes is None:
self.batch_sizes = [1, 2, 4, 8, 16, 32, 64]
self.batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
class MLXBenchmark:
@@ -179,10 +179,16 @@ class Benchmark:
def _run_inference(self, input_ids: torch.Tensor) -> float:
attention_mask = torch.ones_like(input_ids)
# print shape of input_ids and attention_mask
print(f"input_ids shape: {input_ids.shape}")
print(f"attention_mask shape: {attention_mask.shape}")
start_time = time.time()
with torch.no_grad():
self.model(input_ids=input_ids, attention_mask=attention_mask)
if torch.cuda.is_available():
torch.cuda.synchronize()
if torch.backends.mps.is_available():
torch.mps.synchronize()
end_time = time.time()
return end_time - start_time

View File

@@ -233,18 +233,6 @@ class HNSWSearcher(BaseSearcher):
# HNSW-specific batch processing parameter
params.batch_size = batch_size
# Log recompute mode and batching for visibility
logger.info(
"HNSW search: recompute=%s, zmq_port=%s, batch_size=%d, efSearch=%d, beam=%d, prune_ratio=%.3f, strategy=%s",
bool(recompute_embeddings),
str(zmq_port),
int(batch_size),
int(complexity),
int(beam_width),
float(prune_ratio),
pruning_strategy,
)
batch_size_query = query.shape[0]
distances = np.empty((batch_size_query, top_k), dtype=np.float32)
labels = np.empty((batch_size_query, top_k), dtype=np.int64)

View File

@@ -632,6 +632,7 @@ class LeannSearcher:
# Only HNSW supports batching; forward conditionally
if self.backend_name == "hnsw":
backend_search_kwargs["batch_size"] = batch_size
# Merge any extra kwargs last
backend_search_kwargs.update(kwargs)

View File

@@ -6,11 +6,11 @@ Preserves all optimization parameters to ensure performance
import logging
import os
import time
from typing import Any
import numpy as np
import torch
import time
# Set up logger with proper level
logger = logging.getLogger(__name__)
@@ -248,9 +248,7 @@ def compute_embeddings_sentence_transformers(
try:
from transformers import AutoModel, AutoTokenizer # type: ignore
except Exception as e:
raise ImportError(
f"transformers is required for manual_tokenize=True: {e}"
)
raise ImportError(f"transformers is required for manual_tokenize=True: {e}")
# Cache tokenizer and model
tok_cache_key = f"hf_tokenizer_{model_name}"
@@ -261,13 +259,9 @@ def compute_embeddings_sentence_transformers(
logger.info("Using cached HF tokenizer/model for manual path")
else:
logger.info("Loading HF tokenizer/model for manual tokenization path")
hf_tokenizer = AutoTokenizer.from_pretrained(
model_name, use_fast=True
)
hf_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
torch_dtype = torch.float16 if (use_fp16 and device == "cuda") else torch.float32
hf_model = AutoModel.from_pretrained(
model_name, torch_dtype=torch_dtype
)
hf_model = AutoModel.from_pretrained(model_name, torch_dtype=torch_dtype)
hf_model.to(device)
hf_model.eval()
# Optional compile on supported devices
@@ -285,6 +279,7 @@ def compute_embeddings_sentence_transformers(
try:
if show_progress:
from tqdm import tqdm # type: ignore
batch_iter = tqdm(
range(0, len(texts), batch_size),
desc="Embedding (manual)",
@@ -295,10 +290,12 @@ def compute_embeddings_sentence_transformers(
except Exception:
batch_iter = range(0, len(texts), batch_size)
start_time_manual = time.time()
with torch.inference_mode():
for start_index in batch_iter:
end_index = min(start_index + batch_size, len(texts))
batch_texts = texts[start_index:end_index]
tokenize_start_time = time.time()
inputs = hf_tokenizer(
batch_texts,
padding=True,
@@ -306,8 +303,23 @@ def compute_embeddings_sentence_transformers(
max_length=max_length,
return_tensors="pt",
)
tokenize_end_time = time.time()
logger.info(
f"Tokenize time taken: {tokenize_end_time - tokenize_start_time} seconds"
)
# Print shapes of all input tensors for debugging
for k, v in inputs.items():
print(f"inputs[{k!r}] shape: {getattr(v, 'shape', type(v))}")
to_device_start_time = time.time()
inputs = {k: v.to(device) for k, v in inputs.items()}
to_device_end_time = time.time()
logger.info(
f"To device time taken: {to_device_end_time - to_device_start_time} seconds"
)
forward_start_time = time.time()
outputs = hf_model(**inputs)
forward_end_time = time.time()
logger.info(f"Forward time taken: {forward_end_time - forward_start_time} seconds")
last_hidden_state = outputs.last_hidden_state # (B, L, H)
attention_mask = inputs.get("attention_mask")
if attention_mask is None:
@@ -328,7 +340,8 @@ def compute_embeddings_sentence_transformers(
torch.cuda.synchronize()
except Exception:
pass
end_time = time.time()
logger.info(f"Manual tokenize time taken: {end_time - start_time_manual} seconds")
end_time = time.time()
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
logger.info(f"Time taken: {end_time - start_time} seconds")