fix micro bench and fix pre commit

yichuan-w
2025-08-20 16:59:35 -07:00
parent a913903d73
commit 35f4fbd9d1
6 changed files with 35 additions and 33 deletions

View File

@@ -13,4 +13,5 @@ repos:
     rev: v0.12.7 # Fixed version to match pyproject.toml
     hooks:
       - id: ruff
+        args: [--fix, --exit-non-zero-on-fix]
       - id: ruff-format

View File

@@ -6,8 +6,6 @@ results and the golden standard results, making the comparison robust to ID chan
""" """
import argparse import argparse
import logging
import os
import json import json
import sys import sys
import time import time
@@ -16,11 +14,6 @@ from pathlib import Path
 import numpy as np
 from leann.api import LeannBuilder, LeannChat, LeannSearcher
-# Configure logging level (default INFO; override with LEANN_LOG_LEVEL)
-_log_level_str = os.getenv("LEANN_LOG_LEVEL", "INFO").upper()
-_log_level = getattr(logging, _log_level_str, logging.INFO)
-logging.basicConfig(level=_log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
 def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
     """Checks if the data directory exists, and if not, downloads it from HF Hub."""

View File

@@ -20,7 +20,7 @@ except ImportError:
 @dataclass
 class BenchmarkConfig:
-    model_path: str = "facebook/contriever"
+    model_path: str = "facebook/contriever-msmarco"
     batch_sizes: list[int] = None
     seq_length: int = 256
     num_runs: int = 5
@@ -34,7 +34,7 @@ class BenchmarkConfig:
     def __post_init__(self):
         if self.batch_sizes is None:
-            self.batch_sizes = [1, 2, 4, 8, 16, 32, 64]
+            self.batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
 class MLXBenchmark:
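With these defaults, a no-argument config now sweeps batch sizes up to 1024. A self-contained sketch of just the fields shown in these hunks (the real class has more fields; this is illustrative only):

from dataclasses import dataclass
from typing import Optional

@dataclass
class BenchmarkConfigSketch:
    model_path: str = "facebook/contriever-msmarco"
    batch_sizes: Optional[list[int]] = None
    seq_length: int = 256
    num_runs: int = 5

    def __post_init__(self):
        # None as the default so the mutable list is created per instance.
        if self.batch_sizes is None:
            self.batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]

print(BenchmarkConfigSketch().batch_sizes)  # sweeps up to 1024 by default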
@@ -179,10 +179,16 @@ class Benchmark:
     def _run_inference(self, input_ids: torch.Tensor) -> float:
         attention_mask = torch.ones_like(input_ids)
+        # print shape of input_ids and attention_mask
+        print(f"input_ids shape: {input_ids.shape}")
+        print(f"attention_mask shape: {attention_mask.shape}")
         start_time = time.time()
         with torch.no_grad():
             self.model(input_ids=input_ids, attention_mask=attention_mask)
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+        if torch.backends.mps.is_available():
+            torch.mps.synchronize()
         end_time = time.time()
         return end_time - start_time
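The added synchronize calls matter because CUDA and MPS execute kernels asynchronously: without them, end_time would be captured right after kernel launch and the measured latency would be misleadingly small. A standalone sketch of the same timing pattern (model is a placeholder for any transformers-style encoder):

import time
import torch

def timed_forward(model, input_ids: torch.Tensor) -> float:
    """Wall-clock seconds for one forward pass, waiting for the device to finish."""
    attention_mask = torch.ones_like(input_ids)
    start = time.time()
    with torch.no_grad():
        model(input_ids=input_ids, attention_mask=attention_mask)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # drain queued CUDA kernels before reading the clock
    if torch.backends.mps.is_available():
        torch.mps.synchronize()  # same idea on Apple Silicon
    return time.time() - start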

View File

@@ -233,18 +233,6 @@ class HNSWSearcher(BaseSearcher):
             # HNSW-specific batch processing parameter
             params.batch_size = batch_size
-        # Log recompute mode and batching for visibility
-        logger.info(
-            "HNSW search: recompute=%s, zmq_port=%s, batch_size=%d, efSearch=%d, beam=%d, prune_ratio=%.3f, strategy=%s",
-            bool(recompute_embeddings),
-            str(zmq_port),
-            int(batch_size),
-            int(complexity),
-            int(beam_width),
-            float(prune_ratio),
-            pruning_strategy,
-        )
         batch_size_query = query.shape[0]
         distances = np.empty((batch_size_query, top_k), dtype=np.float32)
         labels = np.empty((batch_size_query, top_k), dtype=np.int64)
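The surviving context shows batched results being collected into preallocated (batch, top_k) arrays. A small illustrative sketch of that pattern; search_one is a placeholder here, not the HNSW backend's real API:

import numpy as np

def collect_topk(queries: np.ndarray, top_k: int, search_one) -> tuple[np.ndarray, np.ndarray]:
    """Gather per-query (distance, label) results into (B, top_k) arrays."""
    batch_size_query = queries.shape[0]
    distances = np.empty((batch_size_query, top_k), dtype=np.float32)
    labels = np.empty((batch_size_query, top_k), dtype=np.int64)
    for i, query in enumerate(queries):
        dists, ids = search_one(query, top_k)  # placeholder returning two length-top_k sequences
        distances[i] = dists
        labels[i] = ids
    return distances, labels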

View File

@@ -632,6 +632,7 @@ class LeannSearcher:
         # Only HNSW supports batching; forward conditionally
         if self.backend_name == "hnsw":
             backend_search_kwargs["batch_size"] = batch_size
+        # Merge any extra kwargs last
         backend_search_kwargs.update(kwargs)
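The ordering here is deliberate: the backend-specific default is set first, and caller-supplied kwargs are merged last so they win. A minimal sketch of the pattern with illustrative names (not LeannSearcher's actual signature):

def build_backend_kwargs(backend_name: str, batch_size: int, **kwargs) -> dict:
    backend_search_kwargs = {}
    if backend_name == "hnsw":  # only HNSW supports batching
        backend_search_kwargs["batch_size"] = batch_size
    backend_search_kwargs.update(kwargs)  # extras merged last, so callers can override
    return backend_search_kwargs

print(build_backend_kwargs("hnsw", 32))                # {'batch_size': 32}
print(build_backend_kwargs("hnsw", 32, batch_size=8))  # {'batch_size': 8}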

View File

@@ -6,11 +6,11 @@ Preserves all optimization parameters to ensure performance
 import logging
 import os
+import time
 from typing import Any
 import numpy as np
 import torch
-import time
 # Set up logger with proper level
 logger = logging.getLogger(__name__)
@@ -248,9 +248,7 @@ def compute_embeddings_sentence_transformers(
         try:
             from transformers import AutoModel, AutoTokenizer  # type: ignore
         except Exception as e:
-            raise ImportError(
-                f"transformers is required for manual_tokenize=True: {e}"
-            )
+            raise ImportError(f"transformers is required for manual_tokenize=True: {e}")
         # Cache tokenizer and model
         tok_cache_key = f"hf_tokenizer_{model_name}"
@@ -261,13 +259,9 @@ def compute_embeddings_sentence_transformers(
logger.info("Using cached HF tokenizer/model for manual path") logger.info("Using cached HF tokenizer/model for manual path")
else: else:
logger.info("Loading HF tokenizer/model for manual tokenization path") logger.info("Loading HF tokenizer/model for manual tokenization path")
hf_tokenizer = AutoTokenizer.from_pretrained( hf_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
model_name, use_fast=True
)
torch_dtype = torch.float16 if (use_fp16 and device == "cuda") else torch.float32 torch_dtype = torch.float16 if (use_fp16 and device == "cuda") else torch.float32
hf_model = AutoModel.from_pretrained( hf_model = AutoModel.from_pretrained(model_name, torch_dtype=torch_dtype)
model_name, torch_dtype=torch_dtype
)
hf_model.to(device) hf_model.to(device)
hf_model.eval() hf_model.eval()
# Optional compile on supported devices # Optional compile on supported devices
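Outside the caching logic, the load path collapsed above reduces to the usual two transformers calls. A hedged sketch using the model name that appears elsewhere in this commit (an assumption, not confirmed for this module), with fp16 only on CUDA as in the diff:

import torch
from transformers import AutoModel, AutoTokenizer

model_name = "facebook/contriever-msmarco"  # assumed: same model as the benchmark config
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

hf_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
hf_model = AutoModel.from_pretrained(model_name, torch_dtype=dtype).to(device).eval()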
@@ -285,6 +279,7 @@ def compute_embeddings_sentence_transformers(
         try:
             if show_progress:
                 from tqdm import tqdm  # type: ignore
                 batch_iter = tqdm(
                     range(0, len(texts), batch_size),
                     desc="Embedding (manual)",
@@ -295,10 +290,12 @@ def compute_embeddings_sentence_transformers(
         except Exception:
             batch_iter = range(0, len(texts), batch_size)
+        start_time_manual = time.time()
         with torch.inference_mode():
             for start_index in batch_iter:
                 end_index = min(start_index + batch_size, len(texts))
                 batch_texts = texts[start_index:end_index]
+                tokenize_start_time = time.time()
                 inputs = hf_tokenizer(
                     batch_texts,
                     padding=True,
@@ -306,8 +303,23 @@ def compute_embeddings_sentence_transformers(
                     max_length=max_length,
                     return_tensors="pt",
                 )
+                tokenize_end_time = time.time()
+                logger.info(
+                    f"Tokenize time taken: {tokenize_end_time - tokenize_start_time} seconds"
+                )
+                # Print shapes of all input tensors for debugging
+                for k, v in inputs.items():
+                    print(f"inputs[{k!r}] shape: {getattr(v, 'shape', type(v))}")
+                to_device_start_time = time.time()
                 inputs = {k: v.to(device) for k, v in inputs.items()}
+                to_device_end_time = time.time()
+                logger.info(
+                    f"To device time taken: {to_device_end_time - to_device_start_time} seconds"
+                )
+                forward_start_time = time.time()
                 outputs = hf_model(**inputs)
+                forward_end_time = time.time()
+                logger.info(f"Forward time taken: {forward_end_time - forward_start_time} seconds")
                 last_hidden_state = outputs.last_hidden_state  # (B, L, H)
                 attention_mask = inputs.get("attention_mask")
                 if attention_mask is None:
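The context lines above grab last_hidden_state (B, L, H) and the attention mask; the pooling itself sits below this hunk, so the following masked mean pooling is only a sketch of the contriever-style approach, not necessarily the project's exact code:

import torch

def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token embeddings over non-padding positions: (B, L, H), (B, L) -> (B, H)."""
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)  # (B, L, 1)
    summed = (last_hidden_state * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)  # avoid division by zero for all-padding rows
    return summed / counts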
@@ -328,7 +340,8 @@ def compute_embeddings_sentence_transformers(
             torch.cuda.synchronize()
         except Exception:
             pass
+        end_time = time.time()
+        logger.info(f"Manual tokenize time taken: {end_time - start_time_manual} seconds")
     end_time = time.time()
     logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
     logger.info(f"Time taken: {end_time - start_time} seconds")