fix micro bench and fix pre commit

This commit is contained in:
yichuan-w
2025-08-20 16:59:35 -07:00
parent a913903d73
commit 35f4fbd9d1
6 changed files with 35 additions and 33 deletions

View File

@@ -233,18 +233,6 @@ class HNSWSearcher(BaseSearcher):
# HNSW-specific batch processing parameter
params.batch_size = batch_size
# Log recompute mode and batching for visibility
logger.info(
"HNSW search: recompute=%s, zmq_port=%s, batch_size=%d, efSearch=%d, beam=%d, prune_ratio=%.3f, strategy=%s",
bool(recompute_embeddings),
str(zmq_port),
int(batch_size),
int(complexity),
int(beam_width),
float(prune_ratio),
pruning_strategy,
)
batch_size_query = query.shape[0]
distances = np.empty((batch_size_query, top_k), dtype=np.float32)
labels = np.empty((batch_size_query, top_k), dtype=np.int64)

View File

@@ -632,6 +632,7 @@ class LeannSearcher:
# Only HNSW supports batching; forward conditionally
if self.backend_name == "hnsw":
backend_search_kwargs["batch_size"] = batch_size
# Merge any extra kwargs last
backend_search_kwargs.update(kwargs)

View File

@@ -6,11 +6,11 @@ Preserves all optimization parameters to ensure performance
import logging
import os
import time
from typing import Any
import numpy as np
import torch
import time
# Set up logger with proper level
logger = logging.getLogger(__name__)
@@ -248,9 +248,7 @@ def compute_embeddings_sentence_transformers(
try:
from transformers import AutoModel, AutoTokenizer # type: ignore
except Exception as e:
raise ImportError(
f"transformers is required for manual_tokenize=True: {e}"
)
raise ImportError(f"transformers is required for manual_tokenize=True: {e}")
# Cache tokenizer and model
tok_cache_key = f"hf_tokenizer_{model_name}"
@@ -261,13 +259,9 @@ def compute_embeddings_sentence_transformers(
logger.info("Using cached HF tokenizer/model for manual path")
else:
logger.info("Loading HF tokenizer/model for manual tokenization path")
hf_tokenizer = AutoTokenizer.from_pretrained(
model_name, use_fast=True
)
hf_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
torch_dtype = torch.float16 if (use_fp16 and device == "cuda") else torch.float32
hf_model = AutoModel.from_pretrained(
model_name, torch_dtype=torch_dtype
)
hf_model = AutoModel.from_pretrained(model_name, torch_dtype=torch_dtype)
hf_model.to(device)
hf_model.eval()
# Optional compile on supported devices
@@ -285,6 +279,7 @@ def compute_embeddings_sentence_transformers(
try:
if show_progress:
from tqdm import tqdm # type: ignore
batch_iter = tqdm(
range(0, len(texts), batch_size),
desc="Embedding (manual)",
@@ -295,10 +290,12 @@ def compute_embeddings_sentence_transformers(
except Exception:
batch_iter = range(0, len(texts), batch_size)
start_time_manual = time.time()
with torch.inference_mode():
for start_index in batch_iter:
end_index = min(start_index + batch_size, len(texts))
batch_texts = texts[start_index:end_index]
tokenize_start_time = time.time()
inputs = hf_tokenizer(
batch_texts,
padding=True,
@@ -306,8 +303,23 @@ def compute_embeddings_sentence_transformers(
max_length=max_length,
return_tensors="pt",
)
tokenize_end_time = time.time()
logger.info(
f"Tokenize time taken: {tokenize_end_time - tokenize_start_time} seconds"
)
# Print shapes of all input tensors for debugging
for k, v in inputs.items():
print(f"inputs[{k!r}] shape: {getattr(v, 'shape', type(v))}")
to_device_start_time = time.time()
inputs = {k: v.to(device) for k, v in inputs.items()}
to_device_end_time = time.time()
logger.info(
f"To device time taken: {to_device_end_time - to_device_start_time} seconds"
)
forward_start_time = time.time()
outputs = hf_model(**inputs)
forward_end_time = time.time()
logger.info(f"Forward time taken: {forward_end_time - forward_start_time} seconds")
last_hidden_state = outputs.last_hidden_state # (B, L, H)
attention_mask = inputs.get("attention_mask")
if attention_mask is None:
@@ -328,7 +340,8 @@ def compute_embeddings_sentence_transformers(
torch.cuda.synchronize()
except Exception:
pass
end_time = time.time()
logger.info(f"Manual tokenize time taken: {end_time - start_time_manual} seconds")
end_time = time.time()
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
logger.info(f"Time taken: {end_time - start_time} seconds")