refactor: reorgnize all examples/ and test/

This commit is contained in:
Andy Lee
2025-08-03 22:37:45 -07:00
parent 58556ef44c
commit b0239b6e4d
41 changed files with 127 additions and 1926 deletions

120
benchmarks/README.md Normal file
View File

@@ -0,0 +1,120 @@
# 🧪 Leann Sanity Checks
This directory contains comprehensive sanity checks for the Leann system, ensuring all components work correctly across different configurations.
## 📁 Test Files
### `test_distance_functions.py`
Tests all supported distance functions across DiskANN backend:
-**MIPS** (Maximum Inner Product Search)
-**L2** (Euclidean Distance)
-**Cosine** (Cosine Similarity)
```bash
uv run python tests/sanity_checks/test_distance_functions.py
```
### `test_l2_verification.py`
Specifically verifies that L2 distance is correctly implemented by:
- Building indices with L2 vs Cosine metrics
- Comparing search results and score ranges
- Validating that different metrics produce expected score patterns
```bash
uv run python tests/sanity_checks/test_l2_verification.py
```
### `test_sanity_check.py`
Comprehensive end-to-end verification including:
- Distance function testing
- Embedding model compatibility
- Search result correctness validation
- Backend integration testing
```bash
uv run python tests/sanity_checks/test_sanity_check.py
```
## 🎯 What These Tests Verify
### ✅ Distance Function Support
- All three distance metrics (MIPS, L2, Cosine) work correctly
- Score ranges are appropriate for each metric type
- Different metrics can produce different rankings (as expected)
### ✅ Backend Integration
- DiskANN backend properly initializes and builds indices
- Graph construction completes without errors
- Search operations return valid results
### ✅ Embedding Pipeline
- Real-time embedding computation works
- Multiple embedding models are supported
- ZMQ server communication functions correctly
### ✅ End-to-End Functionality
- Index building → searching → result retrieval pipeline
- Metadata preservation through the entire flow
- Error handling and graceful degradation
## 🔍 Expected Output
When all tests pass, you should see:
```
📊 测试结果总结:
mips : ✅ 通过
l2 : ✅ 通过
cosine : ✅ 通过
🎉 测试完成!
```
## 🐛 Troubleshooting
### Common Issues
**Import Errors**: Ensure you're running from the project root:
```bash
cd /path/to/leann
uv run python tests/sanity_checks/test_distance_functions.py
```
**Memory Issues**: Reduce graph complexity for resource-constrained systems:
```python
builder = LeannBuilder(
backend_name="diskann",
graph_degree=8, # Reduced from 16
complexity=16 # Reduced from 32
)
```
**ZMQ Port Conflicts**: The tests use different ports to avoid conflicts, but you may need to kill existing processes:
```bash
pkill -f "embedding_server"
```
## 📊 Performance Expectations
### Typical Timing (3 documents, consumer hardware):
- **Index Building**: 2-5 seconds per distance function
- **Search Query**: 50-200ms
- **Recompute Mode**: 5-15 seconds (higher accuracy)
### Memory Usage:
- **Index Storage**: ~1-2 MB per distance function
- **Runtime Memory**: ~500MB (including model loading)
## 🔗 Integration with CI/CD
These tests are designed to be run in automated environments:
```yaml
# GitHub Actions example
- name: Run Sanity Checks
run: |
uv run python tests/sanity_checks/test_distance_functions.py
uv run python tests/sanity_checks/test_l2_verification.py
```
The tests are deterministic and should produce consistent results across different platforms.

View File

@@ -0,0 +1,141 @@
import time
import matplotlib.pyplot as plt
import mlx.core as mx
import numpy as np
import torch
from mlx_lm import load
from sentence_transformers import SentenceTransformer
# --- Configuration ---
MODEL_NAME_TORCH = "Qwen/Qwen3-Embedding-0.6B"
MODEL_NAME_MLX = "mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ"
BATCH_SIZES = [1, 8, 16, 32, 64, 128]
NUM_RUNS = 10 # Number of runs to average for each batch size
WARMUP_RUNS = 2 # Number of warm-up runs
# --- Generate Dummy Data ---
DUMMY_SENTENCES = ["This is a test sentence for benchmarking." * 5] * max(BATCH_SIZES)
# --- Benchmark Functions ---b
def benchmark_torch(model, sentences):
start_time = time.time()
model.encode(sentences, convert_to_numpy=True)
end_time = time.time()
return (end_time - start_time) * 1000 # Return time in ms
def benchmark_mlx(model, tokenizer, sentences):
start_time = time.time()
# Tokenize sentences using MLX tokenizer
tokens = []
for sentence in sentences:
token_ids = tokenizer.encode(sentence)
tokens.append(token_ids)
# Pad sequences to the same length
max_len = max(len(t) for t in tokens)
input_ids = []
attention_mask = []
for token_seq in tokens:
# Pad sequence
padded = token_seq + [tokenizer.eos_token_id] * (max_len - len(token_seq))
input_ids.append(padded)
# Create attention mask (1 for real tokens, 0 for padding)
mask = [1] * len(token_seq) + [0] * (max_len - len(token_seq))
attention_mask.append(mask)
# Convert to MLX arrays
input_ids = mx.array(input_ids)
attention_mask = mx.array(attention_mask)
# Get embeddings
embeddings = model(input_ids)
# Mean pooling
mask = mx.expand_dims(attention_mask, -1)
sum_embeddings = (embeddings * mask).sum(axis=1)
sum_mask = mask.sum(axis=1)
_ = sum_embeddings / sum_mask
mx.eval() # Ensure computation is finished
end_time = time.time()
return (end_time - start_time) * 1000 # Return time in ms
# --- Main Execution ---
def main():
print("--- Initializing Models ---")
# Load PyTorch model
print(f"Loading PyTorch model: {MODEL_NAME_TORCH}")
device = "mps" if torch.backends.mps.is_available() else "cpu"
model_torch = SentenceTransformer(MODEL_NAME_TORCH, device=device)
print(f"PyTorch model loaded on: {device}")
# Load MLX model
print(f"Loading MLX model: {MODEL_NAME_MLX}")
model_mlx, tokenizer_mlx = load(MODEL_NAME_MLX)
print("MLX model loaded.")
# --- Warm-up ---
print("\n--- Performing Warm-up Runs ---")
for _ in range(WARMUP_RUNS):
benchmark_torch(model_torch, DUMMY_SENTENCES[:1])
benchmark_mlx(model_mlx, tokenizer_mlx, DUMMY_SENTENCES[:1])
print("Warm-up complete.")
# --- Benchmarking ---
print("\n--- Starting Benchmark ---")
results_torch = []
results_mlx = []
for batch_size in BATCH_SIZES:
print(f"Benchmarking batch size: {batch_size}")
sentences_batch = DUMMY_SENTENCES[:batch_size]
# Benchmark PyTorch
torch_times = [benchmark_torch(model_torch, sentences_batch) for _ in range(NUM_RUNS)]
results_torch.append(np.mean(torch_times))
# Benchmark MLX
mlx_times = [
benchmark_mlx(model_mlx, tokenizer_mlx, sentences_batch) for _ in range(NUM_RUNS)
]
results_mlx.append(np.mean(mlx_times))
print("\n--- Benchmark Results (Average time per batch in ms) ---")
print(f"Batch Sizes: {BATCH_SIZES}")
print(f"PyTorch (mps): {[f'{t:.2f}' for t in results_torch]}")
print(f"MLX: {[f'{t:.2f}' for t in results_mlx]}")
# --- Plotting ---
print("\n--- Generating Plot ---")
plt.figure(figsize=(10, 6))
plt.plot(
BATCH_SIZES,
results_torch,
marker="o",
linestyle="-",
label=f"PyTorch ({device})",
)
plt.plot(BATCH_SIZES, results_mlx, marker="s", linestyle="-", label="MLX")
plt.title(f"Embedding Performance: MLX vs PyTorch\nModel: {MODEL_NAME_TORCH}")
plt.xlabel("Batch Size")
plt.ylabel("Average Time per Batch (ms)")
plt.xticks(BATCH_SIZES)
plt.grid(True)
plt.legend()
# Save the plot
output_filename = "embedding_benchmark.png"
plt.savefig(output_filename)
print(f"Plot saved to {output_filename}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,326 @@
#!/usr/bin/env python3
"""
Memory comparison between Faiss HNSW and LEANN HNSW backend
"""
import gc
import logging
import os
import subprocess
import sys
import time
from pathlib import Path
import psutil
from llama_index.core.node_parser import SentenceSplitter
# Setup logging
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)
def get_memory_usage():
"""Get current memory usage in MB"""
process = psutil.Process()
return process.memory_info().rss / 1024 / 1024
def print_memory_stats(stage: str, start_mem: float):
"""Print memory statistics"""
current_mem = get_memory_usage()
diff = current_mem - start_mem
print(f"[{stage}] Memory: {current_mem:.1f} MB (+{diff:.1f} MB)")
return current_mem
class MemoryTracker:
def __init__(self, name: str):
self.name = name
self.start_mem = get_memory_usage()
self.stages = []
def checkpoint(self, stage: str):
current_mem = print_memory_stats(f"{self.name} - {stage}", self.start_mem)
self.stages.append((stage, current_mem))
return current_mem
def summary(self):
print(f"\n=== {self.name} Memory Summary ===")
for stage, mem in self.stages:
print(f"{stage}: {mem:.1f} MB")
peak_mem = max(mem for _, mem in self.stages)
print(f"Peak Memory: {peak_mem:.1f} MB")
print(f"Total Memory Increase: {peak_mem - self.start_mem:.1f} MB")
return peak_mem
def test_faiss_hnsw():
"""Test Faiss HNSW Vector Store in subprocess"""
print("\n" + "=" * 50)
print("TESTING FAISS HNSW VECTOR STORE")
print("=" * 50)
try:
result = subprocess.run(
[sys.executable, "benchmarks/faiss_only.py"],
capture_output=True,
text=True,
timeout=300,
)
print(result.stdout)
if result.stderr:
print("Stderr:", result.stderr)
if result.returncode != 0:
return {
"peak_memory": float("inf"),
"error": f"Process failed with code {result.returncode}",
}
# Parse peak memory from output
lines = result.stdout.split("\n")
peak_memory = 0.0
for line in lines:
if "Peak Memory:" in line:
peak_memory = float(line.split("Peak Memory:")[1].split("MB")[0].strip())
return {"peak_memory": peak_memory}
except Exception as e:
return {
"peak_memory": float("inf"),
"error": str(e),
}
def test_leann_hnsw():
"""Test LEANN HNSW Search Memory (load existing index)"""
print("\n" + "=" * 50)
print("TESTING LEANN HNSW SEARCH MEMORY")
print("=" * 50)
tracker = MemoryTracker("LEANN HNSW Search")
# Import and setup
tracker.checkpoint("Initial")
from leann.api import LeannSearcher
tracker.checkpoint("After imports")
from leann.api import LeannBuilder
from llama_index.core import SimpleDirectoryReader
# Load and parse documents
documents = SimpleDirectoryReader(
"data",
recursive=True,
encoding="utf-8",
required_exts=[".pdf", ".txt", ".md"],
).load_data()
tracker.checkpoint("After document loading")
# Parse into chunks
node_parser = SentenceSplitter(
chunk_size=256, chunk_overlap=20, separator=" ", paragraph_separator="\n\n"
)
all_texts = []
for doc in documents:
nodes = node_parser.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
print(f"Total number of chunks: {len(all_texts)}")
tracker.checkpoint("After text chunking")
# Build LEANN index
INDEX_DIR = Path("./test_leann_comparison")
INDEX_PATH = str(INDEX_DIR / "comparison.leann")
# Check if index already exists
if os.path.exists(INDEX_PATH + ".meta.json"):
print("Loading existing LEANN HNSW index...")
tracker.checkpoint("After loading existing index")
else:
print("Building new LEANN HNSW index...")
# Clean up previous index
import shutil
if INDEX_DIR.exists():
shutil.rmtree(INDEX_DIR)
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="facebook/contriever",
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1,
)
tracker.checkpoint("After builder setup")
print("Building LEANN HNSW index...")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(INDEX_PATH)
del builder
gc.collect()
tracker.checkpoint("After index building")
# Find existing LEANN index
index_paths = [
"./test_leann_comparison/comparison.leann",
]
index_path = None
for path in index_paths:
if os.path.exists(path + ".meta.json"):
index_path = path
break
if not index_path:
print("❌ LEANN index not found. Please build it first")
return {"peak_memory": float("inf"), "error": "Index not found"}
# Measure runtime memory overhead
print("\nMeasuring runtime memory overhead...")
runtime_start_mem = get_memory_usage()
print(f"Before load memory: {runtime_start_mem:.1f} MB")
tracker.checkpoint("Before load memory")
# Load searcher
searcher = LeannSearcher(index_path)
tracker.checkpoint("After searcher loading")
print("Running search queries...")
queries = [
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
"What is LEANN and how does it work?",
"华为诺亚方舟实验室的主要研究内容",
]
for i, query in enumerate(queries):
start_time = time.time()
# Use same parameters as Faiss: top_k=20, ef=120 (complexity parameter)
_ = searcher.search(query, top_k=20, ef=120)
query_time = time.time() - start_time
print(f"Query {i + 1} time: {query_time:.3f}s")
tracker.checkpoint(f"After query {i + 1}")
runtime_end_mem = get_memory_usage()
runtime_overhead = runtime_end_mem - runtime_start_mem
peak_memory = tracker.summary()
print(f"Runtime Memory Overhead: {runtime_overhead:.1f} MB")
# Get storage size before cleanup
storage_size = 0
INDEX_DIR = Path(index_path).parent
if INDEX_DIR.exists():
total_size = 0
for dirpath, _, filenames in os.walk(str(INDEX_DIR)):
for filename in filenames:
# Only count actual index files, skip text data and backups
if filename.endswith((".old", ".tmp", ".bak", ".jsonl", ".json")):
continue
# Count .index, .idx, .map files (actual index structures)
if filename.endswith((".index", ".idx", ".map")):
filepath = os.path.join(dirpath, filename)
total_size += os.path.getsize(filepath)
storage_size = total_size / (1024 * 1024) # Convert to MB
# Clean up
del searcher
gc.collect()
return {
"peak_memory": peak_memory,
"storage_size": storage_size,
}
def main():
"""Run comparison tests"""
print("Storage + Search Memory Comparison: Faiss HNSW vs LEANN HNSW")
print("=" * 60)
# Test Faiss HNSW
faiss_results = test_faiss_hnsw()
# Force garbage collection
gc.collect()
time.sleep(2)
# Test LEANN HNSW
leann_results = test_leann_hnsw()
# Final comparison
print("\n" + "=" * 60)
print("STORAGE + SEARCH MEMORY COMPARISON")
print("=" * 60)
# Get storage sizes
faiss_storage_size = 0
leann_storage_size = leann_results.get("storage_size", 0)
# Get Faiss storage size using Python
if os.path.exists("./storage_faiss"):
total_size = 0
for dirpath, _, filenames in os.walk("./storage_faiss"):
for filename in filenames:
filepath = os.path.join(dirpath, filename)
total_size += os.path.getsize(filepath)
faiss_storage_size = total_size / (1024 * 1024) # Convert to MB
print("Faiss HNSW:")
if "error" in faiss_results:
print(f" ❌ Failed: {faiss_results['error']}")
else:
print(f" Search Memory: {faiss_results['peak_memory']:.1f} MB")
print(f" Storage Size: {faiss_storage_size:.1f} MB")
print("\nLEANN HNSW:")
if "error" in leann_results:
print(f" ❌ Failed: {leann_results['error']}")
else:
print(f" Search Memory: {leann_results['peak_memory']:.1f} MB")
print(f" Storage Size: {leann_storage_size:.1f} MB")
# Calculate improvements only if both tests succeeded
if "error" not in faiss_results and "error" not in leann_results:
memory_ratio = faiss_results["peak_memory"] / leann_results["peak_memory"]
print("\nLEANN vs Faiss Performance:")
memory_saving = faiss_results["peak_memory"] - leann_results["peak_memory"]
print(f" Search Memory: {memory_ratio:.1f}x less ({memory_saving:.1f} MB saved)")
# Storage comparison
if leann_storage_size > faiss_storage_size:
storage_ratio = leann_storage_size / faiss_storage_size
print(f" Storage Size: {storage_ratio:.1f}x larger (LEANN uses more storage)")
elif faiss_storage_size > leann_storage_size:
storage_ratio = faiss_storage_size / leann_storage_size
print(f" Storage Size: {storage_ratio:.1f}x smaller (LEANN uses less storage)")
else:
print(" Storage Size: similar")
else:
if "error" not in leann_results:
print("\n✅ LEANN HNSW completed successfully!")
print(f"📊 Search Memory: {leann_results['peak_memory']:.1f} MB")
print(f"📊 Storage Size: {leann_storage_size:.1f} MB")
if "error" not in faiss_results:
print("\n✅ Faiss HNSW completed successfully!")
print(f"📊 Search Memory: {faiss_results['peak_memory']:.1f} MB")
print(f"📊 Storage Size: {faiss_storage_size:.1f} MB")
if __name__ == "__main__":
main()

151
benchmarks/faiss_only.py Normal file
View File

@@ -0,0 +1,151 @@
#!/usr/bin/env python3
"""Test only Faiss HNSW"""
import os
import sys
import time
import psutil
def get_memory_usage():
process = psutil.Process()
return process.memory_info().rss / 1024 / 1024
class MemoryTracker:
def __init__(self, name: str):
self.name = name
self.start_mem = get_memory_usage()
self.stages = []
def checkpoint(self, stage: str):
current_mem = get_memory_usage()
diff = current_mem - self.start_mem
print(f"[{self.name} - {stage}] Memory: {current_mem:.1f} MB (+{diff:.1f} MB)")
self.stages.append((stage, current_mem))
return current_mem
def summary(self):
peak_mem = max(mem for _, mem in self.stages)
print(f"Peak Memory: {peak_mem:.1f} MB")
return peak_mem
def main():
try:
import faiss
except ImportError:
print("Faiss is not installed.")
print(
"Please install it with `uv pip install faiss-cpu` and you can then run this script again"
)
sys.exit(1)
from llama_index.core import (
Settings,
SimpleDirectoryReader,
StorageContext,
VectorStoreIndex,
)
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.faiss import FaissVectorStore
tracker = MemoryTracker("Faiss HNSW")
tracker.checkpoint("Initial")
embed_model = HuggingFaceEmbedding(model_name="facebook/contriever")
Settings.embed_model = embed_model
tracker.checkpoint("After embedding model setup")
d = 768
faiss_index = faiss.IndexHNSWFlat(d, 32)
faiss_index.hnsw.efConstruction = 64
tracker.checkpoint("After Faiss index creation")
documents = SimpleDirectoryReader(
"data",
recursive=True,
encoding="utf-8",
required_exts=[".pdf", ".txt", ".md"],
).load_data()
tracker.checkpoint("After document loading")
# Parse into chunks using the same splitter as LEANN
node_parser = SentenceSplitter(
chunk_size=256, chunk_overlap=20, separator=" ", paragraph_separator="\n\n"
)
tracker.checkpoint("After text splitter setup")
# Check if index already exists and try to load it
index_loaded = False
if os.path.exists("./storage_faiss"):
print("Loading existing Faiss HNSW index...")
try:
# Use the correct Faiss loading pattern from the example
vector_store = FaissVectorStore.from_persist_dir("./storage_faiss")
storage_context = StorageContext.from_defaults(
vector_store=vector_store, persist_dir="./storage_faiss"
)
from llama_index.core import load_index_from_storage
index = load_index_from_storage(storage_context=storage_context)
print("Index loaded from ./storage_faiss")
tracker.checkpoint("After loading existing index")
index_loaded = True
except Exception as e:
print(f"Failed to load existing index: {e}")
print("Cleaning up corrupted index and building new one...")
# Clean up corrupted index
import shutil
if os.path.exists("./storage_faiss"):
shutil.rmtree("./storage_faiss")
if not index_loaded:
print("Building new Faiss HNSW index...")
# Use the correct Faiss building pattern from the example
vector_store = FaissVectorStore(faiss_index=faiss_index)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = VectorStoreIndex.from_documents(
documents, storage_context=storage_context, transformations=[node_parser]
)
tracker.checkpoint("After index building")
# Save index to disk using the correct pattern
index.storage_context.persist(persist_dir="./storage_faiss")
tracker.checkpoint("After index saving")
# Measure runtime memory overhead
print("\nMeasuring runtime memory overhead...")
runtime_start_mem = get_memory_usage()
print(f"Before load memory: {runtime_start_mem:.1f} MB")
tracker.checkpoint("Before load memory")
query_engine = index.as_query_engine(similarity_top_k=20)
queries = [
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
"What is LEANN and how does it work?",
"华为诺亚方舟实验室的主要研究内容",
]
for i, query in enumerate(queries):
start_time = time.time()
_ = query_engine.query(query)
query_time = time.time() - start_time
print(f"Query {i + 1} time: {query_time:.3f}s")
tracker.checkpoint(f"After query {i + 1}")
runtime_end_mem = get_memory_usage()
runtime_overhead = runtime_end_mem - runtime_start_mem
peak_memory = tracker.summary()
print(f"Peak Memory: {peak_memory:.1f} MB")
print(f"Runtime Memory Overhead: {runtime_overhead:.1f} MB")
if __name__ == "__main__":
main()

659
benchmarks/micro_tpt.py Normal file
View File

@@ -0,0 +1,659 @@
# python embedd_micro.py --use_int8 Fastest
import argparse
import time
from contextlib import contextmanager
from dataclasses import dataclass
import numpy as np
import torch
from torch import nn
from tqdm import tqdm
from transformers import AutoModel, BitsAndBytesConfig
@dataclass
class BenchmarkConfig:
model_path: str
batch_sizes: list[int]
seq_length: int
num_runs: int
use_fp16: bool = True
use_int4: bool = False
use_int8: bool = False # Add this parameter
use_cuda_graphs: bool = False
use_flash_attention: bool = False
use_linear8bitlt: bool = False
class GraphContainer:
"""Container for managing graphs for different batch sizes (CUDA graphs on NVIDIA, regular on others)."""
def __init__(self, model: nn.Module, seq_length: int):
self.model = model
self.seq_length = seq_length
self.graphs: dict[int, GraphWrapper] = {}
def get_or_create(self, batch_size: int) -> "GraphWrapper":
if batch_size not in self.graphs:
self.graphs[batch_size] = GraphWrapper(self.model, batch_size, self.seq_length)
return self.graphs[batch_size]
class GraphWrapper:
"""Wrapper for graph capture and replay (CUDA graphs on NVIDIA, regular on others)."""
def __init__(self, model: nn.Module, batch_size: int, seq_length: int):
self.model = model
self.device = self._get_device()
self.static_input = self._create_random_batch(batch_size, seq_length)
self.static_attention_mask = torch.ones_like(self.static_input)
# Warm up
self._warmup()
# Only use CUDA graphs on NVIDIA GPUs
if torch.cuda.is_available() and hasattr(torch.cuda, "CUDAGraph"):
# Capture graph
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph):
self.static_output = self.model(
input_ids=self.static_input,
attention_mask=self.static_attention_mask,
)
self.use_cuda_graph = True
else:
# For MPS or CPU, just store the model
self.use_cuda_graph = False
self.static_output = None
def _get_device(self) -> str:
if torch.cuda.is_available():
return "cuda"
elif torch.backends.mps.is_available():
return "mps"
else:
return "cpu"
def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
return torch.randint(
0, 1000, (batch_size, seq_length), device=self.device, dtype=torch.long
)
def _warmup(self, num_warmup: int = 3):
with torch.no_grad():
for _ in range(num_warmup):
self.model(
input_ids=self.static_input,
attention_mask=self.static_attention_mask,
)
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
if self.use_cuda_graph:
self.static_input.copy_(input_ids)
self.static_attention_mask.copy_(attention_mask)
self.graph.replay()
return self.static_output
else:
# For MPS/CPU, just run normally
return self.model(input_ids=input_ids, attention_mask=attention_mask)
class ModelOptimizer:
"""Applies various optimizations to the model."""
@staticmethod
def optimize(model: nn.Module, config: BenchmarkConfig) -> nn.Module:
print("\nApplying model optimizations:")
if model is None:
raise ValueError("Cannot optimize None model")
# Move to GPU
if torch.cuda.is_available():
model = model.cuda()
device = "cuda"
elif torch.backends.mps.is_available():
model = model.to("mps")
device = "mps"
else:
model = model.cpu()
device = "cpu"
print(f"- Model moved to {device}")
# FP16
if config.use_fp16 and not config.use_int4:
model = model.half()
# use torch compile
model = torch.compile(model)
print("- Using FP16 precision")
# Check if using SDPA (only on CUDA)
if (
torch.cuda.is_available()
and torch.version.cuda
and float(torch.version.cuda[:3]) >= 11.6
):
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
else:
print("- PyTorch SDPA not available")
# Flash Attention (only on CUDA)
if config.use_flash_attention and torch.cuda.is_available():
try:
from flash_attn.flash_attention import FlashAttention # noqa: F401
print("- Flash Attention 2 available")
if hasattr(model.config, "attention_mode"):
model.config.attention_mode = "flash_attention_2"
print(" - Enabled Flash Attention 2 mode")
except ImportError:
print("- Flash Attention not available")
# Memory efficient attention (only on CUDA)
if torch.cuda.is_available():
try:
from xformers.ops import memory_efficient_attention # noqa: F401
if hasattr(model, "enable_xformers_memory_efficient_attention"):
model.enable_xformers_memory_efficient_attention()
print("- Enabled xformers memory efficient attention")
else:
print("- Model doesn't support xformers")
except (ImportError, AttributeError):
print("- Xformers not available")
model.eval()
print("- Model set to eval mode")
return model
class Timer:
"""Handles accurate GPU timing using GPU events or CPU timing."""
def __init__(self):
if torch.cuda.is_available():
self.start_event = torch.cuda.Event(enable_timing=True)
self.end_event = torch.cuda.Event(enable_timing=True)
self.use_gpu_timing = True
elif torch.backends.mps.is_available():
# MPS doesn't have events, use CPU timing
self.use_gpu_timing = False
else:
# CPU timing
self.use_gpu_timing = False
@contextmanager
def timing(self):
if self.use_gpu_timing:
self.start_event.record()
yield
self.end_event.record()
self.end_event.synchronize()
else:
# Use CPU timing for MPS/CPU
start_time = time.time()
yield
self.cpu_elapsed = time.time() - start_time
def elapsed_time(self) -> float:
if self.use_gpu_timing:
return self.start_event.elapsed_time(self.end_event) / 1000 # ms to seconds
else:
return self.cpu_elapsed
class Benchmark:
"""Main benchmark runner."""
def __init__(self, config: BenchmarkConfig):
self.config = config
try:
self.model = self._load_model()
if self.model is None:
raise ValueError("Model initialization failed - model is None")
# Only use CUDA graphs on NVIDIA GPUs
if config.use_cuda_graphs and torch.cuda.is_available():
self.graphs = GraphContainer(self.model, config.seq_length)
else:
self.graphs = None
self.timer = Timer()
except Exception as e:
print(f"ERROR in benchmark initialization: {e!s}")
raise
def _load_model(self) -> nn.Module:
print(f"Loading model from {self.config.model_path}...")
try:
# Int4 quantization using HuggingFace integration
if self.config.use_int4:
import bitsandbytes as bnb
print(f"- bitsandbytes version: {bnb.__version__}")
# Check if using custom 8bit quantization
if hasattr(self.config, "use_linear8bitlt") and self.config.use_linear8bitlt:
print("- Using custom Linear8bitLt replacement for all linear layers")
# Load original model (without quantization config)
import bitsandbytes as bnb
import torch
# set default to half
torch.set_default_dtype(torch.float16)
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
model = AutoModel.from_pretrained(
self.config.model_path,
torch_dtype=compute_dtype,
)
# Define replacement function
def replace_linear_with_linear8bitlt(model):
"""Recursively replace all nn.Linear layers with Linear8bitLt"""
for name, module in list(model.named_children()):
if isinstance(module, nn.Linear):
# Get original linear layer parameters
in_features = module.in_features
out_features = module.out_features
bias = module.bias is not None
# Create 8bit linear layer
# print size
print(f"in_features: {in_features}, out_features: {out_features}")
new_module = bnb.nn.Linear8bitLt(
in_features,
out_features,
bias=bias,
has_fp16_weights=False,
)
# Copy weights and bias
new_module.weight.data = module.weight.data
if bias:
new_module.bias.data = module.bias.data
# Replace module
setattr(model, name, new_module)
else:
# Process child modules recursively
replace_linear_with_linear8bitlt(module)
return model
# Replace all linear layers
model = replace_linear_with_linear8bitlt(model)
# add torch compile
model = torch.compile(model)
# Move model to GPU (quantization happens here)
device = (
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
model = model.to(device)
print("- All linear layers replaced with Linear8bitLt")
else:
# Use original Int4 quantization method
print("- Using bitsandbytes for Int4 quantization")
# Create quantization config
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
print("- Quantization config:", quantization_config)
# Load model directly with quantization config
model = AutoModel.from_pretrained(
self.config.model_path,
quantization_config=quantization_config,
torch_dtype=compute_dtype,
device_map="auto", # Let HF decide on device mapping
)
# Check if model loaded successfully
if model is None:
raise ValueError("Model loading returned None")
print(f"- Model type: {type(model)}")
# Apply optimizations directly here
print("\nApplying model optimizations:")
if hasattr(self.config, "use_linear8bitlt") and self.config.use_linear8bitlt:
print("- Model moved to GPU with Linear8bitLt quantization")
else:
# Skip moving to GPU since device_map="auto" already did that
print("- Model already on GPU due to device_map='auto'")
# Skip FP16 conversion since we specified compute_dtype
print(f"- Using {compute_dtype} for compute dtype")
# Check CUDA and SDPA
if (
torch.cuda.is_available()
and torch.version.cuda
and float(torch.version.cuda[:3]) >= 11.6
):
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
else:
print("- PyTorch SDPA not available")
# Try xformers if available (only on CUDA)
if torch.cuda.is_available():
try:
if hasattr(model, "enable_xformers_memory_efficient_attention"):
model.enable_xformers_memory_efficient_attention()
print("- Enabled xformers memory efficient attention")
else:
print("- Model doesn't support xformers")
except (ImportError, AttributeError):
print("- Xformers not available")
# Set to eval mode
model.eval()
print("- Model set to eval mode")
# Int8 quantization using HuggingFace integration
elif self.config.use_int8:
print("- Using INT8 quantization")
# For now, just use standard loading with INT8 config
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
quantization_config = BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
)
model = AutoModel.from_pretrained(
self.config.model_path,
quantization_config=quantization_config,
torch_dtype=compute_dtype,
device_map="auto",
)
if model is None:
raise ValueError("Model loading returned None")
print(f"- Model type: {type(model)}")
model.eval()
print("- Model set to eval mode")
else:
# Standard loading for FP16/FP32
model = AutoModel.from_pretrained(self.config.model_path)
print("- Model loaded in standard precision")
print(f"- Model type: {type(model)}")
# Apply standard optimizations
# set default to half
import torch
torch.set_default_dtype(torch.bfloat16)
model = ModelOptimizer.optimize(model, self.config)
model = model.half()
# add torch compile
model = torch.compile(model)
# Final check to ensure model is not None
if model is None:
raise ValueError("Model is None after optimization")
print(f"- Final model type: {type(model)}")
return model
except Exception as e:
print(f"ERROR loading model: {e!s}")
import traceback
traceback.print_exc()
raise
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
device = (
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
return torch.randint(
0,
1000,
(batch_size, self.config.seq_length),
device=device,
dtype=torch.long,
)
def _run_inference(
self, input_ids: torch.Tensor, graph_wrapper: GraphWrapper | None = None
) -> tuple[float, torch.Tensor]:
attention_mask = torch.ones_like(input_ids)
with torch.no_grad(), self.timer.timing():
if graph_wrapper is not None:
output = graph_wrapper(input_ids, attention_mask)
else:
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
return self.timer.elapsed_time(), output
def run(self) -> dict[int, dict[str, float]]:
results = {}
# Reset peak memory stats
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
elif torch.backends.mps.is_available():
# MPS doesn't have reset_peak_memory_stats, skip it
pass
else:
print("- No GPU memory stats available")
for batch_size in self.config.batch_sizes:
print(f"\nTesting batch size: {batch_size}")
times = []
# Get or create graph for this batch size
graph_wrapper = (
self.graphs.get_or_create(batch_size) if self.graphs is not None else None
)
# Pre-allocate input tensor
input_ids = self._create_random_batch(batch_size)
print(f"Input shape: {input_ids.shape}")
# Run benchmark
for i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
try:
elapsed_time, output = self._run_inference(input_ids, graph_wrapper)
if i == 0: # Only print on first run
print(f"Output shape: {output.last_hidden_state.shape}")
times.append(elapsed_time)
except Exception as e:
print(f"Error during inference: {e}")
break
if not times:
print(f"No successful runs for batch size {batch_size}, skipping")
continue
# Calculate statistics
avg_time = np.mean(times)
std_time = np.std(times)
throughput = batch_size / avg_time
results[batch_size] = {
"avg_time": avg_time,
"std_time": std_time,
"throughput": throughput,
}
print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
print(f"Throughput: {throughput:.2f} sequences/second")
# Log memory usage
if torch.cuda.is_available():
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
elif torch.backends.mps.is_available():
# MPS doesn't have max_memory_allocated, use 0
peak_memory_gb = 0.0
else:
peak_memory_gb = 0.0
print("- No GPU memory usage available")
if peak_memory_gb > 0:
print(f"\nPeak GPU memory usage: {peak_memory_gb:.2f} GB")
else:
print("\n- GPU memory usage not available")
# Add memory info to results
for batch_size in results:
results[batch_size]["peak_memory_gb"] = peak_memory_gb
return results
def main():
parser = argparse.ArgumentParser(description="Model Inference Benchmark")
parser.add_argument(
"--model_path",
type=str,
default="facebook/contriever",
help="Path to the model",
)
parser.add_argument(
"--batch_sizes",
type=str,
default="1,2,4,8,16,32",
help="Comma-separated list of batch sizes",
)
parser.add_argument(
"--seq_length",
type=int,
default=256,
help="Sequence length for input",
)
parser.add_argument(
"--num_runs",
type=int,
default=5,
help="Number of runs for each batch size",
)
parser.add_argument(
"--use_fp16",
action="store_true",
help="Enable FP16 inference",
)
parser.add_argument(
"--use_int4",
action="store_true",
help="Enable INT4 quantization using bitsandbytes",
)
parser.add_argument(
"--use_int8",
action="store_true",
help="Enable INT8 quantization for both activations and weights using bitsandbytes",
)
parser.add_argument(
"--use_cuda_graphs",
action="store_true",
help="Enable CUDA Graphs optimization (only on NVIDIA GPUs)",
)
parser.add_argument(
"--use_flash_attention",
action="store_true",
help="Enable Flash Attention 2 if available (only on NVIDIA GPUs)",
)
parser.add_argument(
"--use_linear8bitlt",
action="store_true",
help="Enable Linear8bitLt quantization for all linear layers",
)
args = parser.parse_args()
# Print arguments for debugging
print("\nCommand line arguments:")
for arg, value in vars(args).items():
print(f"- {arg}: {value}")
config = BenchmarkConfig(
model_path=args.model_path,
batch_sizes=[int(bs) for bs in args.batch_sizes.split(",")],
seq_length=args.seq_length,
num_runs=args.num_runs,
use_fp16=args.use_fp16,
use_int4=args.use_int4,
use_int8=args.use_int8, # Add this line
use_cuda_graphs=args.use_cuda_graphs,
use_flash_attention=args.use_flash_attention,
use_linear8bitlt=args.use_linear8bitlt,
)
# Print configuration for debugging
print("\nBenchmark configuration:")
for field, value in vars(config).items():
print(f"- {field}: {value}")
try:
benchmark = Benchmark(config)
results = benchmark.run()
# Save results to file
import json
import os
# Create results directory if it doesn't exist
os.makedirs("results", exist_ok=True)
# Generate filename based on configuration
precision_type = (
"int4"
if config.use_int4
else "int8"
if config.use_int8
else "fp16"
if config.use_fp16
else "fp32"
)
model_name = os.path.basename(config.model_path)
output_file = f"results/benchmark_{model_name}_{precision_type}.json"
# Save results
with open(output_file, "w") as f:
json.dump(
{
"config": {
k: str(v) if isinstance(v, list) else v for k, v in vars(config).items()
},
"results": {str(k): v for k, v in results.items()},
},
f,
indent=2,
)
print(f"Results saved to {output_file}")
except Exception as e:
print(f"Benchmark failed: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,359 @@
#!/usr/bin/env python3
"""
This script runs a recall evaluation on a given LEANN index.
It correctly compares results by fetching the text content for both the new search
results and the golden standard results, making the comparison robust to ID changes.
"""
import argparse
import json
import sys
import time
from pathlib import Path
import numpy as np
from leann.api import LeannBuilder, LeannSearcher
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
"""Checks if the data directory exists, and if not, downloads it from HF Hub."""
if not data_root.exists():
print(f"Data directory '{data_root}' not found.")
print("Downloading evaluation data from Hugging Face Hub... (this may take a moment)")
try:
from huggingface_hub import snapshot_download
if download_embeddings:
# Download everything including embeddings (large files)
snapshot_download(
repo_id="LEANN-RAG/leann-rag-evaluation-data",
repo_type="dataset",
local_dir=data_root,
local_dir_use_symlinks=False,
)
print("Data download complete (including embeddings)!")
else:
# Download only specific folders, excluding embeddings
allow_patterns = [
"ground_truth/**",
"indices/**",
"queries/**",
"*.md",
"*.txt",
]
snapshot_download(
repo_id="LEANN-RAG/leann-rag-evaluation-data",
repo_type="dataset",
local_dir=data_root,
local_dir_use_symlinks=False,
allow_patterns=allow_patterns,
)
print("Data download complete (excluding embeddings)!")
except ImportError:
print(
"Error: huggingface_hub is not installed. Please install it to download the data:"
)
print("uv pip install -e '.[dev]'")
sys.exit(1)
except Exception as e:
print(f"An error occurred during data download: {e}")
sys.exit(1)
def download_embeddings_if_needed(data_root: Path, dataset_type: str | None = None):
"""Download embeddings files specifically."""
embeddings_dir = data_root / "embeddings"
if dataset_type:
# Check if specific dataset embeddings exist
target_file = embeddings_dir / dataset_type / "passages_00.pkl"
if target_file.exists():
print(f"Embeddings for {dataset_type} already exist")
return str(target_file)
print("Downloading embeddings from HuggingFace Hub...")
try:
from huggingface_hub import snapshot_download
# Download only embeddings folder
snapshot_download(
repo_id="LEANN-RAG/leann-rag-evaluation-data",
repo_type="dataset",
local_dir=data_root,
local_dir_use_symlinks=False,
allow_patterns=["embeddings/**/*.pkl"],
)
print("Embeddings download complete!")
if dataset_type:
target_file = embeddings_dir / dataset_type / "passages_00.pkl"
if target_file.exists():
return str(target_file)
return str(embeddings_dir)
except Exception as e:
print(f"Error downloading embeddings: {e}")
sys.exit(1)
# --- Helper Function to get Golden Passages ---
def get_golden_texts(searcher: LeannSearcher, golden_ids: list[int]) -> set:
"""
Retrieves the text for golden passage IDs directly from the LeannSearcher's
passage manager.
"""
golden_texts = set()
for gid in golden_ids:
try:
# PassageManager uses string IDs
passage_data = searcher.passage_manager.get_passage(str(gid))
golden_texts.add(passage_data["text"])
except KeyError:
print(f"Warning: Golden passage ID '{gid}' not found in the index's passage data.")
return golden_texts
def load_queries(file_path: Path) -> list[str]:
queries = []
with open(file_path, encoding="utf-8") as f:
for line in f:
data = json.loads(line)
queries.append(data["query"])
return queries
def build_index_from_embeddings(embeddings_file: str, output_path: str, backend: str = "hnsw"):
"""
Build a LEANN index from pre-computed embeddings.
Args:
embeddings_file: Path to pickle file with (ids, embeddings) tuple
output_path: Path where to save the index
backend: Backend to use ("hnsw" or "diskann")
"""
print(f"Building {backend} index from embeddings: {embeddings_file}")
# Create builder with appropriate parameters
if backend == "hnsw":
builder_kwargs = {
"M": 32, # Graph degree
"efConstruction": 256, # Construction complexity
"is_compact": True, # Use compact storage
"is_recompute": True, # Enable pruning for better recall
}
elif backend == "diskann":
builder_kwargs = {
"complexity": 64,
"graph_degree": 32,
"search_memory_maximum": 8.0, # GB
"build_memory_maximum": 16.0, # GB
}
else:
builder_kwargs = {}
builder = LeannBuilder(
backend_name=backend,
embedding_model="facebook/contriever-msmarco", # Model used to create embeddings
dimensions=768, # Will be auto-detected from embeddings
**builder_kwargs,
)
# Build index from precomputed embeddings
builder.build_index_from_embeddings(output_path, embeddings_file)
print(f"Index saved to: {output_path}")
return output_path
def main():
parser = argparse.ArgumentParser(description="Run recall evaluation on a LEANN index.")
parser.add_argument(
"index_path",
type=str,
nargs="?",
help="Path to the LEANN index to evaluate or build (optional).",
)
parser.add_argument(
"--mode",
choices=["evaluate", "build"],
default="evaluate",
help="Mode: 'evaluate' existing index or 'build' from embeddings",
)
parser.add_argument(
"--embeddings-file",
type=str,
help="Path to embeddings pickle file (optional for build mode)",
)
parser.add_argument(
"--backend",
choices=["hnsw", "diskann"],
default="hnsw",
help="Backend to use for building index (default: hnsw)",
)
parser.add_argument(
"--num-queries", type=int, default=10, help="Number of queries to evaluate."
)
parser.add_argument("--top-k", type=int, default=3, help="The 'k' value for recall@k.")
parser.add_argument(
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
)
args = parser.parse_args()
# --- Path Configuration ---
# Assumes a project structure where the script is in 'benchmarks/'
# and evaluation data is in 'benchmarks/data/'.
script_dir = Path(__file__).resolve().parent
data_root = script_dir / "data"
# Download data based on mode
if args.mode == "build":
# For building mode, we need embeddings
download_data_if_needed(data_root, download_embeddings=False) # Basic data first
# Auto-detect dataset type and download embeddings
if args.embeddings_file:
embeddings_file = args.embeddings_file
# Try to detect dataset type from embeddings file path
if "rpj_wiki" in str(embeddings_file):
dataset_type = "rpj_wiki"
elif "dpr" in str(embeddings_file):
dataset_type = "dpr"
else:
dataset_type = "dpr" # Default
else:
# Auto-detect from index path if provided, otherwise default to DPR
if args.index_path:
index_path_str = str(args.index_path)
if "rpj_wiki" in index_path_str:
dataset_type = "rpj_wiki"
elif "dpr" in index_path_str:
dataset_type = "dpr"
else:
dataset_type = "dpr" # Default to DPR
else:
dataset_type = "dpr" # Default to DPR
embeddings_file = download_embeddings_if_needed(data_root, dataset_type)
# Auto-generate index path if not provided
if not args.index_path:
indices_dir = data_root / "indices" / dataset_type
indices_dir.mkdir(parents=True, exist_ok=True)
args.index_path = str(indices_dir / f"{dataset_type}_from_embeddings")
print(f"Auto-generated index path: {args.index_path}")
print(f"Building index from embeddings: {embeddings_file}")
built_index_path = build_index_from_embeddings(
embeddings_file, args.index_path, args.backend
)
print(f"Index built successfully: {built_index_path}")
# Ask if user wants to run evaluation
eval_response = input("Run evaluation on the built index? (y/n): ").strip().lower()
if eval_response != "y":
print("Index building complete. Exiting.")
return
else:
# For evaluation mode, don't need embeddings
download_data_if_needed(data_root, download_embeddings=False)
# Auto-detect index path if not provided
if not args.index_path:
# Default to using downloaded indices
indices_dir = data_root / "indices"
# Try common datasets in order of preference
for dataset in ["dpr", "rpj_wiki"]:
dataset_dir = indices_dir / dataset
if dataset_dir.exists():
# Look for index files
index_files = list(dataset_dir.glob("*.index")) + list(
dataset_dir.glob("*_disk.index")
)
if index_files:
args.index_path = str(
index_files[0].with_suffix("")
) # Remove .index extension
print(f"Using index: {args.index_path}")
break
if not args.index_path:
print("No indices found. The data download should have included pre-built indices.")
print(
"Please check the benchmarks/data/indices/ directory or provide --index-path manually."
)
sys.exit(1)
# Detect dataset type from index path to select the correct ground truth
index_path_str = str(args.index_path)
if "rpj_wiki" in index_path_str:
dataset_type = "rpj_wiki"
elif "dpr" in index_path_str:
dataset_type = "dpr"
else:
# Fallback: try to infer from the index directory name
dataset_type = Path(args.index_path).name
print(f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'.")
queries_file = data_root / "queries" / "nq_open.jsonl"
golden_results_file = data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
print(f"INFO: Detected dataset type: {dataset_type}")
print(f"INFO: Using queries file: {queries_file}")
print(f"INFO: Using ground truth file: {golden_results_file}")
try:
searcher = LeannSearcher(args.index_path)
queries = load_queries(queries_file)
with open(golden_results_file) as f:
golden_results_data = json.load(f)
num_eval_queries = min(args.num_queries, len(queries))
queries = queries[:num_eval_queries]
print(f"\nRunning evaluation on {num_eval_queries} queries...")
recall_scores = []
search_times = []
for i in range(num_eval_queries):
start_time = time.time()
new_results = searcher.search(queries[i], top_k=args.top_k, ef=args.ef_search)
search_times.append(time.time() - start_time)
# Correct Recall Calculation: Based on TEXT content
new_texts = {result.text for result in new_results}
# Get golden texts directly from the searcher's passage manager
golden_ids = golden_results_data["indices"][i][: args.top_k]
golden_texts = get_golden_texts(searcher, golden_ids)
overlap = len(new_texts & golden_texts)
recall = overlap / len(golden_texts) if golden_texts else 0
recall_scores.append(recall)
print("\n--- EVALUATION RESULTS ---")
print(f"Query: {queries[i]}")
print(f"New Results: {new_texts}")
print(f"Golden Results: {golden_texts}")
print(f"Overlap: {overlap}")
print(f"Recall: {recall}")
print(f"Search Time: {search_times[-1]:.4f}s")
print("--------------------------------")
avg_recall = np.mean(recall_scores) if recall_scores else 0
avg_time = np.mean(search_times) if search_times else 0
print("\n🎉 --- Evaluation Complete ---")
print(f"Avg. Recall@{args.top_k} (efSearch={args.ef_search}): {avg_recall:.4f}")
print(f"Avg. Search Time: {avg_time:.4f}s")
except Exception as e:
print(f"\n❌ An error occurred during evaluation: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,311 @@
import time
from dataclasses import dataclass
import numpy as np
import torch
from torch import nn
from tqdm import tqdm
from transformers import AutoModel
# Add MLX imports
try:
import mlx.core as mx
from mlx_lm.utils import load
MLX_AVAILABLE = True
except ImportError:
print("MLX not available. Install with: uv pip install mlx mlx-lm")
MLX_AVAILABLE = False
@dataclass
class BenchmarkConfig:
model_path: str = "facebook/contriever"
batch_sizes: list[int] = None
seq_length: int = 256
num_runs: int = 5
use_fp16: bool = True
use_int4: bool = False
use_int8: bool = False
use_cuda_graphs: bool = False
use_flash_attention: bool = False
use_linear8bitlt: bool = False
use_mlx: bool = False # New flag for MLX testing
def __post_init__(self):
if self.batch_sizes is None:
self.batch_sizes = [1, 2, 4, 8, 16, 32, 64]
class MLXBenchmark:
"""MLX-specific benchmark for embedding models"""
def __init__(self, config: BenchmarkConfig):
self.config = config
self.model, self.tokenizer = self._load_model()
def _load_model(self):
"""Load MLX model and tokenizer following the API pattern"""
print(f"Loading MLX model from {self.config.model_path}...")
try:
model, tokenizer = load(self.config.model_path)
print("MLX model loaded successfully")
return model, tokenizer
except Exception as e:
print(f"Error loading MLX model: {e}")
raise
def _create_random_batch(self, batch_size: int):
"""Create random input batches for MLX testing - same as PyTorch"""
return torch.randint(0, 1000, (batch_size, self.config.seq_length), dtype=torch.long)
def _run_inference(self, input_ids: torch.Tensor) -> float:
"""Run MLX inference with same input as PyTorch"""
start_time = time.time()
try:
# Convert PyTorch tensor to MLX array
input_ids_mlx = mx.array(input_ids.numpy())
# Get embeddings
embeddings = self.model(input_ids_mlx)
# Mean pooling (following the API pattern)
pooled = embeddings.mean(axis=1)
# Convert to numpy (following the API pattern)
pooled_numpy = np.array(pooled.tolist(), dtype=np.float32)
# Force computation
_ = pooled_numpy.shape
except Exception as e:
print(f"MLX inference error: {e}")
return float("inf")
end_time = time.time()
return end_time - start_time
def run(self) -> dict[int, dict[str, float]]:
"""Run the MLX benchmark across all batch sizes"""
results = {}
print(f"Starting MLX benchmark with model: {self.config.model_path}")
print(f"Testing batch sizes: {self.config.batch_sizes}")
for batch_size in self.config.batch_sizes:
print(f"\n=== Testing MLX batch size: {batch_size} ===")
times = []
# Create input batch (same as PyTorch)
input_ids = self._create_random_batch(batch_size)
# Warm up
print("Warming up...")
for _ in range(3):
try:
self._run_inference(input_ids[:2]) # Warm up with smaller batch
except Exception as e:
print(f"Warmup error: {e}")
break
# Run benchmark
for _i in tqdm(range(self.config.num_runs), desc=f"MLX Batch size {batch_size}"):
try:
elapsed_time = self._run_inference(input_ids)
if elapsed_time != float("inf"):
times.append(elapsed_time)
except Exception as e:
print(f"Error during MLX inference: {e}")
break
if not times:
print(f"Skipping batch size {batch_size} due to errors")
continue
# Calculate statistics
avg_time = np.mean(times)
std_time = np.std(times)
throughput = batch_size / avg_time
results[batch_size] = {
"avg_time": avg_time,
"std_time": std_time,
"throughput": throughput,
"min_time": np.min(times),
"max_time": np.max(times),
}
print(f"MLX Results for batch size {batch_size}:")
print(f" Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
print(f" Min Time: {np.min(times):.4f}s")
print(f" Max Time: {np.max(times):.4f}s")
print(f" Throughput: {throughput:.2f} sequences/second")
return results
class Benchmark:
def __init__(self, config: BenchmarkConfig):
self.config = config
self.device = (
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
self.model = self._load_model()
def _load_model(self) -> nn.Module:
print(f"Loading model from {self.config.model_path}...")
model = AutoModel.from_pretrained(self.config.model_path)
if self.config.use_fp16:
model = model.half()
model = torch.compile(model)
model = model.to(self.device)
model.eval()
return model
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
return torch.randint(
0,
1000,
(batch_size, self.config.seq_length),
device=self.device,
dtype=torch.long,
)
def _run_inference(self, input_ids: torch.Tensor) -> float:
attention_mask = torch.ones_like(input_ids)
start_time = time.time()
with torch.no_grad():
self.model(input_ids=input_ids, attention_mask=attention_mask)
end_time = time.time()
return end_time - start_time
def run(self) -> dict[int, dict[str, float]]:
results = {}
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
for batch_size in self.config.batch_sizes:
print(f"\nTesting batch size: {batch_size}")
times = []
input_ids = self._create_random_batch(batch_size)
for _i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
try:
elapsed_time = self._run_inference(input_ids)
times.append(elapsed_time)
except Exception as e:
print(f"Error during inference: {e}")
break
if not times:
continue
avg_time = np.mean(times)
std_time = np.std(times)
throughput = batch_size / avg_time
results[batch_size] = {
"avg_time": avg_time,
"std_time": std_time,
"throughput": throughput,
}
print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
print(f"Throughput: {throughput:.2f} sequences/second")
if torch.cuda.is_available():
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
else:
peak_memory_gb = 0.0
for batch_size in results:
results[batch_size]["peak_memory_gb"] = peak_memory_gb
return results
def run_benchmark():
"""Main function to run the benchmark with optimized parameters."""
config = BenchmarkConfig()
try:
benchmark = Benchmark(config)
results = benchmark.run()
max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
avg_throughput = np.mean([results[batch_size]["throughput"] for batch_size in results])
return {
"max_throughput": max_throughput,
"avg_throughput": avg_throughput,
"results": results,
}
except Exception as e:
print(f"Benchmark failed: {e}")
return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": str(e)}
def run_mlx_benchmark():
"""Run MLX-specific benchmark"""
if not MLX_AVAILABLE:
print("MLX not available, skipping MLX benchmark")
return {
"max_throughput": 0.0,
"avg_throughput": 0.0,
"error": "MLX not available",
}
config = BenchmarkConfig(model_path="mlx-community/all-MiniLM-L6-v2-4bit", use_mlx=True)
try:
benchmark = MLXBenchmark(config)
results = benchmark.run()
if not results:
return {
"max_throughput": 0.0,
"avg_throughput": 0.0,
"error": "No valid results",
}
max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
avg_throughput = np.mean([results[batch_size]["throughput"] for batch_size in results])
return {
"max_throughput": max_throughput,
"avg_throughput": avg_throughput,
"results": results,
}
except Exception as e:
print(f"MLX benchmark failed: {e}")
return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": str(e)}
if __name__ == "__main__":
print("=== PyTorch Benchmark ===")
pytorch_result = run_benchmark()
print(f"PyTorch Max throughput: {pytorch_result['max_throughput']:.2f} sequences/second")
print(f"PyTorch Average throughput: {pytorch_result['avg_throughput']:.2f} sequences/second")
print("\n=== MLX Benchmark ===")
mlx_result = run_mlx_benchmark()
print(f"MLX Max throughput: {mlx_result['max_throughput']:.2f} sequences/second")
print(f"MLX Average throughput: {mlx_result['avg_throughput']:.2f} sequences/second")
# Compare results
if pytorch_result["max_throughput"] > 0 and mlx_result["max_throughput"] > 0:
speedup = mlx_result["max_throughput"] / pytorch_result["max_throughput"]
print("\n=== Comparison ===")
print(f"MLX is {speedup:.2f}x {'faster' if speedup > 1 else 'slower'} than PyTorch")