Fix micro benchmark and pre-commit config
@@ -13,4 +13,5 @@ repos:
     rev: v0.12.7 # Fixed version to match pyproject.toml
     hooks:
       - id: ruff
+        args: [--fix, --exit-non-zero-on-fix]
       - id: ruff-format
@@ -6,8 +6,6 @@ results and the golden standard results, making the comparison robust to ID changes
 """
 
 import argparse
-import logging
-import os
 import json
 import sys
 import time
@@ -16,11 +14,6 @@ from pathlib import Path
 import numpy as np
 from leann.api import LeannBuilder, LeannChat, LeannSearcher
 
-# Configure logging level (default INFO; override with LEANN_LOG_LEVEL)
-_log_level_str = os.getenv("LEANN_LOG_LEVEL", "INFO").upper()
-_log_level = getattr(logging, _log_level_str, logging.INFO)
-logging.basicConfig(level=_log_level, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
-
 
 def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
     """Checks if the data directory exists, and if not, downloads it from HF Hub."""
@@ -20,7 +20,7 @@ except ImportError:
 
 @dataclass
 class BenchmarkConfig:
-    model_path: str = "facebook/contriever"
+    model_path: str = "facebook/contriever-msmarco"
     batch_sizes: list[int] = None
     seq_length: int = 256
     num_runs: int = 5
@@ -34,7 +34,7 @@ class BenchmarkConfig:
 
     def __post_init__(self):
        if self.batch_sizes is None:
-            self.batch_sizes = [1, 2, 4, 8, 16, 32, 64]
+            self.batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
 
 
 class MLXBenchmark:
@@ -179,10 +179,16 @@ class Benchmark:
 
     def _run_inference(self, input_ids: torch.Tensor) -> float:
         attention_mask = torch.ones_like(input_ids)
+        # print shape of input_ids and attention_mask
+        print(f"input_ids shape: {input_ids.shape}")
+        print(f"attention_mask shape: {attention_mask.shape}")
         start_time = time.time()
         with torch.no_grad():
             self.model(input_ids=input_ids, attention_mask=attention_mask)
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+        if torch.backends.mps.is_available():
+            torch.mps.synchronize()
         end_time = time.time()
 
         return end_time - start_time
@@ -233,18 +233,6 @@ class HNSWSearcher(BaseSearcher):
         # HNSW-specific batch processing parameter
         params.batch_size = batch_size
 
-        # Log recompute mode and batching for visibility
-        logger.info(
-            "HNSW search: recompute=%s, zmq_port=%s, batch_size=%d, efSearch=%d, beam=%d, prune_ratio=%.3f, strategy=%s",
-            bool(recompute_embeddings),
-            str(zmq_port),
-            int(batch_size),
-            int(complexity),
-            int(beam_width),
-            float(prune_ratio),
-            pruning_strategy,
-        )
-
         batch_size_query = query.shape[0]
         distances = np.empty((batch_size_query, top_k), dtype=np.float32)
         labels = np.empty((batch_size_query, top_k), dtype=np.int64)
@@ -632,6 +632,7 @@ class LeannSearcher:
         # Only HNSW supports batching; forward conditionally
         if self.backend_name == "hnsw":
             backend_search_kwargs["batch_size"] = batch_size
+
         # Merge any extra kwargs last
         backend_search_kwargs.update(kwargs)
 
@@ -6,11 +6,11 @@ Preserves all optimization parameters to ensure performance
 
 import logging
 import os
+import time
 from typing import Any
 
 import numpy as np
 import torch
-import time
 
 # Set up logger with proper level
 logger = logging.getLogger(__name__)
@@ -248,9 +248,7 @@ def compute_embeddings_sentence_transformers(
     try:
         from transformers import AutoModel, AutoTokenizer  # type: ignore
     except Exception as e:
-        raise ImportError(
-            f"transformers is required for manual_tokenize=True: {e}"
-        )
+        raise ImportError(f"transformers is required for manual_tokenize=True: {e}")
 
     # Cache tokenizer and model
     tok_cache_key = f"hf_tokenizer_{model_name}"
@@ -261,13 +259,9 @@ def compute_embeddings_sentence_transformers(
         logger.info("Using cached HF tokenizer/model for manual path")
     else:
         logger.info("Loading HF tokenizer/model for manual tokenization path")
-        hf_tokenizer = AutoTokenizer.from_pretrained(
-            model_name, use_fast=True
-        )
+        hf_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
         torch_dtype = torch.float16 if (use_fp16 and device == "cuda") else torch.float32
-        hf_model = AutoModel.from_pretrained(
-            model_name, torch_dtype=torch_dtype
-        )
+        hf_model = AutoModel.from_pretrained(model_name, torch_dtype=torch_dtype)
         hf_model.to(device)
         hf_model.eval()
     # Optional compile on supported devices
@@ -285,6 +279,7 @@ def compute_embeddings_sentence_transformers(
     try:
         if show_progress:
             from tqdm import tqdm  # type: ignore
+
             batch_iter = tqdm(
                 range(0, len(texts), batch_size),
                 desc="Embedding (manual)",
@@ -295,10 +290,12 @@ def compute_embeddings_sentence_transformers(
     except Exception:
         batch_iter = range(0, len(texts), batch_size)
 
+    start_time_manual = time.time()
     with torch.inference_mode():
         for start_index in batch_iter:
             end_index = min(start_index + batch_size, len(texts))
             batch_texts = texts[start_index:end_index]
+            tokenize_start_time = time.time()
             inputs = hf_tokenizer(
                 batch_texts,
                 padding=True,
@@ -306,8 +303,23 @@ def compute_embeddings_sentence_transformers(
                 max_length=max_length,
                 return_tensors="pt",
             )
+            tokenize_end_time = time.time()
+            logger.info(
+                f"Tokenize time taken: {tokenize_end_time - tokenize_start_time} seconds"
+            )
+            # Print shapes of all input tensors for debugging
+            for k, v in inputs.items():
+                print(f"inputs[{k!r}] shape: {getattr(v, 'shape', type(v))}")
+            to_device_start_time = time.time()
             inputs = {k: v.to(device) for k, v in inputs.items()}
+            to_device_end_time = time.time()
+            logger.info(
+                f"To device time taken: {to_device_end_time - to_device_start_time} seconds"
+            )
+            forward_start_time = time.time()
             outputs = hf_model(**inputs)
+            forward_end_time = time.time()
+            logger.info(f"Forward time taken: {forward_end_time - forward_start_time} seconds")
             last_hidden_state = outputs.last_hidden_state  # (B, L, H)
             attention_mask = inputs.get("attention_mask")
             if attention_mask is None:
@@ -328,7 +340,8 @@ def compute_embeddings_sentence_transformers(
             torch.cuda.synchronize()
         except Exception:
             pass
+    end_time = time.time()
+    logger.info(f"Manual tokenize time taken: {end_time - start_time_manual} seconds")
     end_time = time.time()
     logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
     logger.info(f"Time taken: {end_time - start_time} seconds")