refactor: Unify examples interface with BaseRAGExample (#12)

* refactor: Unify examples interface with BaseRAGExample

- Create BaseRAGExample base class for all RAG examples
- Refactor 4 examples to use unified interface:
  - document_rag.py (replaces main_cli_example.py)
  - email_rag.py (replaces mail_reader_leann.py)
  - browser_rag.py (replaces google_history_reader_leann.py)
  - wechat_rag.py (replaces wechat_history_reader_leann.py)
- Maintain 100% parameter compatibility with original files
- Add interactive mode support for all examples
- Unify parameter names (--max-items replaces --max-emails/--max-entries)
- Update README.md with new examples usage
- Add PARAMETER_CONSISTENCY.md documenting all parameter mappings
- Keep main_cli_example.py for backward compatibility with migration notice

All default values, LeannBuilder parameters, and chunking settings
remain identical to ensure full compatibility with existing indexes.

* fix: Update CI tests for new unified examples interface

- Rename test_main_cli.py to test_document_rag.py
- Update all references from main_cli_example.py to document_rag.py
- Update tests/README.md documentation

The tests now properly test the new unified interface while maintaining
the same test coverage and functionality.

* fix: Fix pre-commit issues and update tests

- Fix import sorting and unused imports
- Update type annotations to use built-in types (list, dict) instead of typing.List/Dict
- Fix trailing whitespace and end-of-file issues
- Fix Chinese fullwidth comma to regular comma
- Update test_main_cli.py to test_document_rag.py
- Add backward compatibility test for main_cli_example.py
- Pass all pre-commit hooks (ruff, ruff-format, etc.)

* refactor: Remove old example scripts and migration references

- Delete old example scripts (mail_reader_leann.py, google_history_reader_leann.py, etc.)
- Remove migration hints and backward compatibility
- Update tests to use new unified examples directly
- Clean up all references to old script names
- Users now only see the new unified interface

* fix: Restore embedding-mode parameter to all examples

- All examples now have --embedding-mode parameter (unified interface benefit)
- Default is 'sentence-transformers' (consistent with original behavior)
- Users can now use OpenAI or MLX embeddings with any data source
- Maintains functional equivalence with original scripts

* docs: Improve parameter categorization in README

- Clearly separate core (shared) vs specific parameters
- Move LLM and embedding examples to 'Example Commands' section
- Add descriptive comments for all specific parameters
- Keep only truly data-source-specific parameters in specific sections

* docs: Make example commands more representative

- Add default values to parameter descriptions
- Replace generic examples with real-world use cases
- Focus on data-source-specific features in examples
- Remove redundant demonstrations of common parameters

* docs: Reorganize parameter documentation structure

- Move common parameters to a dedicated section before all examples
- Rename sections to 'X-Specific Arguments' for clarity
- Remove duplicate common parameters from individual examples
- Better information architecture for users

* docs: polish applications

* docs: Add CLI installation instructions

- Add two installation options: venv and global uv tool
- Clearly explain when to use each option
- Make CLI more accessible for daily use

* docs: Clarify CLI global installation process

- Explain the transition from venv to global installation
- Add upgrade command for global installation
- Make it clear that global install allows usage without venv activation

* docs: Add collapsible section for CLI installation

- Wrap CLI installation instructions in details/summary tags
- Keep consistent with other collapsible sections in README
- Improve document readability and navigation

* style: format

* docs: Fix collapsible sections

- Make Common Parameters collapsible (as it's lengthy reference material)
- Keep CLI Installation visible (important for users to see immediately)
- Better information hierarchy

* docs: Add introduction for Common Parameters section

- Add 'Flexible Configuration' heading with descriptive sentence
- Create parallel structure with 'Generation Model Setup' section
- Improve document flow and readability

* docs: nit

* fix: Fix issues in unified examples

- Add smart path detection for data directory
- Fix add_texts -> add_text method call
- Handle both running from project root and examples directory

* fix: Fix async/await and add_text issues in unified examples

- Remove incorrect await from chat.ask() calls (not async)
- Fix add_texts -> add_text method calls
- Verify search-complexity correctly maps to efSearch parameter
- All examples now run successfully

* feat: Address review comments

- Add complexity parameter to LeannChat initialization (default: search_complexity)
- Fix chunk-size default in README documentation (256, not 2048)
- Add more index building parameters as CLI arguments:
  - --backend-name (hnsw/diskann)
  - --graph-degree (default: 32)
  - --build-complexity (default: 64)
  - --no-compact (disable compact storage)
  - --no-recompute (disable embedding recomputation)
- Update README to document all new parameters

* feat: Add chunk-size parameters and improve file type filtering

- Add --chunk-size and --chunk-overlap parameters to all RAG examples
- Preserve original default values for each data source:
  - Document: 256/128 (optimized for general documents)
  - Email: 256/25 (smaller overlap for email threads)
  - Browser: 256/128 (standard for web content)
  - WeChat: 192/64 (smaller chunks for chat messages)
- Make --file-types optional filter instead of restriction in document_rag
- Update README to clarify interactive mode and parameter usage
- Fix LLM default model documentation (gpt-4o, not gpt-4o-mini)

* feat: Update documentation based on review feedback

- Add MLX embedding example to README
- Clarify examples/data content description (two papers, Pride and Prejudice, Chinese README)
- Move chunk parameters to common parameters section
- Remove duplicate chunk parameters from document-specific section

* docs: Emphasize diverse data sources in examples/data description

* fix: update default embedding models for better performance

- Change WeChat, Browser, and Email RAG examples to use all-MiniLM-L6-v2
- Previous Qwen/Qwen3-Embedding-0.6B was too slow for these use cases
- all-MiniLM-L6-v2 is a fast 384-dim model, ideal for large-scale personal data

* add response highlight

* change rebuild logic

* fix some example

* feat: check if k is larger than #docs

* fix: WeChat history reader bugs and refactor wechat_rag to use unified architecture

* fix email wrong -1 to process all file

* refactor: reorgnize all examples/ and test/

* refactor: reorganize examples and add link checker

* fix: add init.py

* fix: handle certificate errors in link checker

* fix wechat

* merge

* docs: update README to use proper module imports for apps

- Change from 'python apps/xxx.py' to 'python -m apps.xxx'
- More professional and pythonic module calling
- Ensures proper module resolution and imports
- Better separation between apps/ (production tools) and examples/ (demos)

---------

Co-authored-by: yichuan520030910320 <yichuan_wang@berkeley.edu>
This commit is contained in:
Andy Lee
2025-08-03 23:06:24 -07:00
committed by GitHub
parent 54df6310c5
commit 8899734952
50 changed files with 1293 additions and 3193 deletions

120
benchmarks/README.md Normal file
View File

@@ -0,0 +1,120 @@
# 🧪 Leann Sanity Checks
This directory contains comprehensive sanity checks for the Leann system, ensuring all components work correctly across different configurations.
## 📁 Test Files
### `test_distance_functions.py`
Tests all supported distance functions across DiskANN backend:
-**MIPS** (Maximum Inner Product Search)
-**L2** (Euclidean Distance)
-**Cosine** (Cosine Similarity)
```bash
uv run python tests/sanity_checks/test_distance_functions.py
```
### `test_l2_verification.py`
Specifically verifies that L2 distance is correctly implemented by:
- Building indices with L2 vs Cosine metrics
- Comparing search results and score ranges
- Validating that different metrics produce expected score patterns
```bash
uv run python tests/sanity_checks/test_l2_verification.py
```
### `test_sanity_check.py`
Comprehensive end-to-end verification including:
- Distance function testing
- Embedding model compatibility
- Search result correctness validation
- Backend integration testing
```bash
uv run python tests/sanity_checks/test_sanity_check.py
```
## 🎯 What These Tests Verify
### ✅ Distance Function Support
- All three distance metrics (MIPS, L2, Cosine) work correctly
- Score ranges are appropriate for each metric type
- Different metrics can produce different rankings (as expected)
### ✅ Backend Integration
- DiskANN backend properly initializes and builds indices
- Graph construction completes without errors
- Search operations return valid results
### ✅ Embedding Pipeline
- Real-time embedding computation works
- Multiple embedding models are supported
- ZMQ server communication functions correctly
### ✅ End-to-End Functionality
- Index building → searching → result retrieval pipeline
- Metadata preservation through the entire flow
- Error handling and graceful degradation
## 🔍 Expected Output
When all tests pass, you should see:
```
📊 测试结果总结:
mips : ✅ 通过
l2 : ✅ 通过
cosine : ✅ 通过
🎉 测试完成!
```
## 🐛 Troubleshooting
### Common Issues
**Import Errors**: Ensure you're running from the project root:
```bash
cd /path/to/leann
uv run python tests/sanity_checks/test_distance_functions.py
```
**Memory Issues**: Reduce graph complexity for resource-constrained systems:
```python
builder = LeannBuilder(
backend_name="diskann",
graph_degree=8, # Reduced from 16
complexity=16 # Reduced from 32
)
```
**ZMQ Port Conflicts**: The tests use different ports to avoid conflicts, but you may need to kill existing processes:
```bash
pkill -f "embedding_server"
```
## 📊 Performance Expectations
### Typical Timing (3 documents, consumer hardware):
- **Index Building**: 2-5 seconds per distance function
- **Search Query**: 50-200ms
- **Recompute Mode**: 5-15 seconds (higher accuracy)
### Memory Usage:
- **Index Storage**: ~1-2 MB per distance function
- **Runtime Memory**: ~500MB (including model loading)
## 🔗 Integration with CI/CD
These tests are designed to be run in automated environments:
```yaml
# GitHub Actions example
- name: Run Sanity Checks
run: |
uv run python tests/sanity_checks/test_distance_functions.py
uv run python tests/sanity_checks/test_l2_verification.py
```
The tests are deterministic and should produce consistent results across different platforms.

View File

@@ -0,0 +1,141 @@
import time
import matplotlib.pyplot as plt
import mlx.core as mx
import numpy as np
import torch
from mlx_lm import load
from sentence_transformers import SentenceTransformer
# --- Configuration ---
MODEL_NAME_TORCH = "Qwen/Qwen3-Embedding-0.6B"
MODEL_NAME_MLX = "mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ"
BATCH_SIZES = [1, 8, 16, 32, 64, 128]
NUM_RUNS = 10 # Number of runs to average for each batch size
WARMUP_RUNS = 2 # Number of warm-up runs
# --- Generate Dummy Data ---
DUMMY_SENTENCES = ["This is a test sentence for benchmarking." * 5] * max(BATCH_SIZES)
# --- Benchmark Functions ---b
def benchmark_torch(model, sentences):
start_time = time.time()
model.encode(sentences, convert_to_numpy=True)
end_time = time.time()
return (end_time - start_time) * 1000 # Return time in ms
def benchmark_mlx(model, tokenizer, sentences):
start_time = time.time()
# Tokenize sentences using MLX tokenizer
tokens = []
for sentence in sentences:
token_ids = tokenizer.encode(sentence)
tokens.append(token_ids)
# Pad sequences to the same length
max_len = max(len(t) for t in tokens)
input_ids = []
attention_mask = []
for token_seq in tokens:
# Pad sequence
padded = token_seq + [tokenizer.eos_token_id] * (max_len - len(token_seq))
input_ids.append(padded)
# Create attention mask (1 for real tokens, 0 for padding)
mask = [1] * len(token_seq) + [0] * (max_len - len(token_seq))
attention_mask.append(mask)
# Convert to MLX arrays
input_ids = mx.array(input_ids)
attention_mask = mx.array(attention_mask)
# Get embeddings
embeddings = model(input_ids)
# Mean pooling
mask = mx.expand_dims(attention_mask, -1)
sum_embeddings = (embeddings * mask).sum(axis=1)
sum_mask = mask.sum(axis=1)
_ = sum_embeddings / sum_mask
mx.eval() # Ensure computation is finished
end_time = time.time()
return (end_time - start_time) * 1000 # Return time in ms
# --- Main Execution ---
def main():
print("--- Initializing Models ---")
# Load PyTorch model
print(f"Loading PyTorch model: {MODEL_NAME_TORCH}")
device = "mps" if torch.backends.mps.is_available() else "cpu"
model_torch = SentenceTransformer(MODEL_NAME_TORCH, device=device)
print(f"PyTorch model loaded on: {device}")
# Load MLX model
print(f"Loading MLX model: {MODEL_NAME_MLX}")
model_mlx, tokenizer_mlx = load(MODEL_NAME_MLX)
print("MLX model loaded.")
# --- Warm-up ---
print("\n--- Performing Warm-up Runs ---")
for _ in range(WARMUP_RUNS):
benchmark_torch(model_torch, DUMMY_SENTENCES[:1])
benchmark_mlx(model_mlx, tokenizer_mlx, DUMMY_SENTENCES[:1])
print("Warm-up complete.")
# --- Benchmarking ---
print("\n--- Starting Benchmark ---")
results_torch = []
results_mlx = []
for batch_size in BATCH_SIZES:
print(f"Benchmarking batch size: {batch_size}")
sentences_batch = DUMMY_SENTENCES[:batch_size]
# Benchmark PyTorch
torch_times = [benchmark_torch(model_torch, sentences_batch) for _ in range(NUM_RUNS)]
results_torch.append(np.mean(torch_times))
# Benchmark MLX
mlx_times = [
benchmark_mlx(model_mlx, tokenizer_mlx, sentences_batch) for _ in range(NUM_RUNS)
]
results_mlx.append(np.mean(mlx_times))
print("\n--- Benchmark Results (Average time per batch in ms) ---")
print(f"Batch Sizes: {BATCH_SIZES}")
print(f"PyTorch (mps): {[f'{t:.2f}' for t in results_torch]}")
print(f"MLX: {[f'{t:.2f}' for t in results_mlx]}")
# --- Plotting ---
print("\n--- Generating Plot ---")
plt.figure(figsize=(10, 6))
plt.plot(
BATCH_SIZES,
results_torch,
marker="o",
linestyle="-",
label=f"PyTorch ({device})",
)
plt.plot(BATCH_SIZES, results_mlx, marker="s", linestyle="-", label="MLX")
plt.title(f"Embedding Performance: MLX vs PyTorch\nModel: {MODEL_NAME_TORCH}")
plt.xlabel("Batch Size")
plt.ylabel("Average Time per Batch (ms)")
plt.xticks(BATCH_SIZES)
plt.grid(True)
plt.legend()
# Save the plot
output_filename = "embedding_benchmark.png"
plt.savefig(output_filename)
print(f"Plot saved to {output_filename}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,326 @@
#!/usr/bin/env python3
"""
Memory comparison between Faiss HNSW and LEANN HNSW backend
"""
import gc
import logging
import os
import subprocess
import sys
import time
from pathlib import Path
import psutil
from llama_index.core.node_parser import SentenceSplitter
# Setup logging
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)
def get_memory_usage():
"""Get current memory usage in MB"""
process = psutil.Process()
return process.memory_info().rss / 1024 / 1024
def print_memory_stats(stage: str, start_mem: float):
"""Print memory statistics"""
current_mem = get_memory_usage()
diff = current_mem - start_mem
print(f"[{stage}] Memory: {current_mem:.1f} MB (+{diff:.1f} MB)")
return current_mem
class MemoryTracker:
def __init__(self, name: str):
self.name = name
self.start_mem = get_memory_usage()
self.stages = []
def checkpoint(self, stage: str):
current_mem = print_memory_stats(f"{self.name} - {stage}", self.start_mem)
self.stages.append((stage, current_mem))
return current_mem
def summary(self):
print(f"\n=== {self.name} Memory Summary ===")
for stage, mem in self.stages:
print(f"{stage}: {mem:.1f} MB")
peak_mem = max(mem for _, mem in self.stages)
print(f"Peak Memory: {peak_mem:.1f} MB")
print(f"Total Memory Increase: {peak_mem - self.start_mem:.1f} MB")
return peak_mem
def test_faiss_hnsw():
"""Test Faiss HNSW Vector Store in subprocess"""
print("\n" + "=" * 50)
print("TESTING FAISS HNSW VECTOR STORE")
print("=" * 50)
try:
result = subprocess.run(
[sys.executable, "benchmarks/faiss_only.py"],
capture_output=True,
text=True,
timeout=300,
)
print(result.stdout)
if result.stderr:
print("Stderr:", result.stderr)
if result.returncode != 0:
return {
"peak_memory": float("inf"),
"error": f"Process failed with code {result.returncode}",
}
# Parse peak memory from output
lines = result.stdout.split("\n")
peak_memory = 0.0
for line in lines:
if "Peak Memory:" in line:
peak_memory = float(line.split("Peak Memory:")[1].split("MB")[0].strip())
return {"peak_memory": peak_memory}
except Exception as e:
return {
"peak_memory": float("inf"),
"error": str(e),
}
def test_leann_hnsw():
"""Test LEANN HNSW Search Memory (load existing index)"""
print("\n" + "=" * 50)
print("TESTING LEANN HNSW SEARCH MEMORY")
print("=" * 50)
tracker = MemoryTracker("LEANN HNSW Search")
# Import and setup
tracker.checkpoint("Initial")
from leann.api import LeannSearcher
tracker.checkpoint("After imports")
from leann.api import LeannBuilder
from llama_index.core import SimpleDirectoryReader
# Load and parse documents
documents = SimpleDirectoryReader(
"data",
recursive=True,
encoding="utf-8",
required_exts=[".pdf", ".txt", ".md"],
).load_data()
tracker.checkpoint("After document loading")
# Parse into chunks
node_parser = SentenceSplitter(
chunk_size=256, chunk_overlap=20, separator=" ", paragraph_separator="\n\n"
)
all_texts = []
for doc in documents:
nodes = node_parser.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
print(f"Total number of chunks: {len(all_texts)}")
tracker.checkpoint("After text chunking")
# Build LEANN index
INDEX_DIR = Path("./test_leann_comparison")
INDEX_PATH = str(INDEX_DIR / "comparison.leann")
# Check if index already exists
if os.path.exists(INDEX_PATH + ".meta.json"):
print("Loading existing LEANN HNSW index...")
tracker.checkpoint("After loading existing index")
else:
print("Building new LEANN HNSW index...")
# Clean up previous index
import shutil
if INDEX_DIR.exists():
shutil.rmtree(INDEX_DIR)
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="facebook/contriever",
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1,
)
tracker.checkpoint("After builder setup")
print("Building LEANN HNSW index...")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(INDEX_PATH)
del builder
gc.collect()
tracker.checkpoint("After index building")
# Find existing LEANN index
index_paths = [
"./test_leann_comparison/comparison.leann",
]
index_path = None
for path in index_paths:
if os.path.exists(path + ".meta.json"):
index_path = path
break
if not index_path:
print("❌ LEANN index not found. Please build it first")
return {"peak_memory": float("inf"), "error": "Index not found"}
# Measure runtime memory overhead
print("\nMeasuring runtime memory overhead...")
runtime_start_mem = get_memory_usage()
print(f"Before load memory: {runtime_start_mem:.1f} MB")
tracker.checkpoint("Before load memory")
# Load searcher
searcher = LeannSearcher(index_path)
tracker.checkpoint("After searcher loading")
print("Running search queries...")
queries = [
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
"What is LEANN and how does it work?",
"华为诺亚方舟实验室的主要研究内容",
]
for i, query in enumerate(queries):
start_time = time.time()
# Use same parameters as Faiss: top_k=20, ef=120 (complexity parameter)
_ = searcher.search(query, top_k=20, ef=120)
query_time = time.time() - start_time
print(f"Query {i + 1} time: {query_time:.3f}s")
tracker.checkpoint(f"After query {i + 1}")
runtime_end_mem = get_memory_usage()
runtime_overhead = runtime_end_mem - runtime_start_mem
peak_memory = tracker.summary()
print(f"Runtime Memory Overhead: {runtime_overhead:.1f} MB")
# Get storage size before cleanup
storage_size = 0
INDEX_DIR = Path(index_path).parent
if INDEX_DIR.exists():
total_size = 0
for dirpath, _, filenames in os.walk(str(INDEX_DIR)):
for filename in filenames:
# Only count actual index files, skip text data and backups
if filename.endswith((".old", ".tmp", ".bak", ".jsonl", ".json")):
continue
# Count .index, .idx, .map files (actual index structures)
if filename.endswith((".index", ".idx", ".map")):
filepath = os.path.join(dirpath, filename)
total_size += os.path.getsize(filepath)
storage_size = total_size / (1024 * 1024) # Convert to MB
# Clean up
del searcher
gc.collect()
return {
"peak_memory": peak_memory,
"storage_size": storage_size,
}
def main():
"""Run comparison tests"""
print("Storage + Search Memory Comparison: Faiss HNSW vs LEANN HNSW")
print("=" * 60)
# Test Faiss HNSW
faiss_results = test_faiss_hnsw()
# Force garbage collection
gc.collect()
time.sleep(2)
# Test LEANN HNSW
leann_results = test_leann_hnsw()
# Final comparison
print("\n" + "=" * 60)
print("STORAGE + SEARCH MEMORY COMPARISON")
print("=" * 60)
# Get storage sizes
faiss_storage_size = 0
leann_storage_size = leann_results.get("storage_size", 0)
# Get Faiss storage size using Python
if os.path.exists("./storage_faiss"):
total_size = 0
for dirpath, _, filenames in os.walk("./storage_faiss"):
for filename in filenames:
filepath = os.path.join(dirpath, filename)
total_size += os.path.getsize(filepath)
faiss_storage_size = total_size / (1024 * 1024) # Convert to MB
print("Faiss HNSW:")
if "error" in faiss_results:
print(f" ❌ Failed: {faiss_results['error']}")
else:
print(f" Search Memory: {faiss_results['peak_memory']:.1f} MB")
print(f" Storage Size: {faiss_storage_size:.1f} MB")
print("\nLEANN HNSW:")
if "error" in leann_results:
print(f" ❌ Failed: {leann_results['error']}")
else:
print(f" Search Memory: {leann_results['peak_memory']:.1f} MB")
print(f" Storage Size: {leann_storage_size:.1f} MB")
# Calculate improvements only if both tests succeeded
if "error" not in faiss_results and "error" not in leann_results:
memory_ratio = faiss_results["peak_memory"] / leann_results["peak_memory"]
print("\nLEANN vs Faiss Performance:")
memory_saving = faiss_results["peak_memory"] - leann_results["peak_memory"]
print(f" Search Memory: {memory_ratio:.1f}x less ({memory_saving:.1f} MB saved)")
# Storage comparison
if leann_storage_size > faiss_storage_size:
storage_ratio = leann_storage_size / faiss_storage_size
print(f" Storage Size: {storage_ratio:.1f}x larger (LEANN uses more storage)")
elif faiss_storage_size > leann_storage_size:
storage_ratio = faiss_storage_size / leann_storage_size
print(f" Storage Size: {storage_ratio:.1f}x smaller (LEANN uses less storage)")
else:
print(" Storage Size: similar")
else:
if "error" not in leann_results:
print("\n✅ LEANN HNSW completed successfully!")
print(f"📊 Search Memory: {leann_results['peak_memory']:.1f} MB")
print(f"📊 Storage Size: {leann_storage_size:.1f} MB")
if "error" not in faiss_results:
print("\n✅ Faiss HNSW completed successfully!")
print(f"📊 Search Memory: {faiss_results['peak_memory']:.1f} MB")
print(f"📊 Storage Size: {faiss_storage_size:.1f} MB")
if __name__ == "__main__":
main()

151
benchmarks/faiss_only.py Normal file
View File

@@ -0,0 +1,151 @@
#!/usr/bin/env python3
"""Test only Faiss HNSW"""
import os
import sys
import time
import psutil
def get_memory_usage():
process = psutil.Process()
return process.memory_info().rss / 1024 / 1024
class MemoryTracker:
def __init__(self, name: str):
self.name = name
self.start_mem = get_memory_usage()
self.stages = []
def checkpoint(self, stage: str):
current_mem = get_memory_usage()
diff = current_mem - self.start_mem
print(f"[{self.name} - {stage}] Memory: {current_mem:.1f} MB (+{diff:.1f} MB)")
self.stages.append((stage, current_mem))
return current_mem
def summary(self):
peak_mem = max(mem for _, mem in self.stages)
print(f"Peak Memory: {peak_mem:.1f} MB")
return peak_mem
def main():
try:
import faiss
except ImportError:
print("Faiss is not installed.")
print(
"Please install it with `uv pip install faiss-cpu` and you can then run this script again"
)
sys.exit(1)
from llama_index.core import (
Settings,
SimpleDirectoryReader,
StorageContext,
VectorStoreIndex,
)
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.faiss import FaissVectorStore
tracker = MemoryTracker("Faiss HNSW")
tracker.checkpoint("Initial")
embed_model = HuggingFaceEmbedding(model_name="facebook/contriever")
Settings.embed_model = embed_model
tracker.checkpoint("After embedding model setup")
d = 768
faiss_index = faiss.IndexHNSWFlat(d, 32)
faiss_index.hnsw.efConstruction = 64
tracker.checkpoint("After Faiss index creation")
documents = SimpleDirectoryReader(
"data",
recursive=True,
encoding="utf-8",
required_exts=[".pdf", ".txt", ".md"],
).load_data()
tracker.checkpoint("After document loading")
# Parse into chunks using the same splitter as LEANN
node_parser = SentenceSplitter(
chunk_size=256, chunk_overlap=20, separator=" ", paragraph_separator="\n\n"
)
tracker.checkpoint("After text splitter setup")
# Check if index already exists and try to load it
index_loaded = False
if os.path.exists("./storage_faiss"):
print("Loading existing Faiss HNSW index...")
try:
# Use the correct Faiss loading pattern from the example
vector_store = FaissVectorStore.from_persist_dir("./storage_faiss")
storage_context = StorageContext.from_defaults(
vector_store=vector_store, persist_dir="./storage_faiss"
)
from llama_index.core import load_index_from_storage
index = load_index_from_storage(storage_context=storage_context)
print("Index loaded from ./storage_faiss")
tracker.checkpoint("After loading existing index")
index_loaded = True
except Exception as e:
print(f"Failed to load existing index: {e}")
print("Cleaning up corrupted index and building new one...")
# Clean up corrupted index
import shutil
if os.path.exists("./storage_faiss"):
shutil.rmtree("./storage_faiss")
if not index_loaded:
print("Building new Faiss HNSW index...")
# Use the correct Faiss building pattern from the example
vector_store = FaissVectorStore(faiss_index=faiss_index)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = VectorStoreIndex.from_documents(
documents, storage_context=storage_context, transformations=[node_parser]
)
tracker.checkpoint("After index building")
# Save index to disk using the correct pattern
index.storage_context.persist(persist_dir="./storage_faiss")
tracker.checkpoint("After index saving")
# Measure runtime memory overhead
print("\nMeasuring runtime memory overhead...")
runtime_start_mem = get_memory_usage()
print(f"Before load memory: {runtime_start_mem:.1f} MB")
tracker.checkpoint("Before load memory")
query_engine = index.as_query_engine(similarity_top_k=20)
queries = [
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
"What is LEANN and how does it work?",
"华为诺亚方舟实验室的主要研究内容",
]
for i, query in enumerate(queries):
start_time = time.time()
_ = query_engine.query(query)
query_time = time.time() - start_time
print(f"Query {i + 1} time: {query_time:.3f}s")
tracker.checkpoint(f"After query {i + 1}")
runtime_end_mem = get_memory_usage()
runtime_overhead = runtime_end_mem - runtime_start_mem
peak_memory = tracker.summary()
print(f"Peak Memory: {peak_memory:.1f} MB")
print(f"Runtime Memory Overhead: {runtime_overhead:.1f} MB")
if __name__ == "__main__":
main()

659
benchmarks/micro_tpt.py Normal file
View File

@@ -0,0 +1,659 @@
# python embedd_micro.py --use_int8 Fastest
import argparse
import time
from contextlib import contextmanager
from dataclasses import dataclass
import numpy as np
import torch
from torch import nn
from tqdm import tqdm
from transformers import AutoModel, BitsAndBytesConfig
@dataclass
class BenchmarkConfig:
model_path: str
batch_sizes: list[int]
seq_length: int
num_runs: int
use_fp16: bool = True
use_int4: bool = False
use_int8: bool = False # Add this parameter
use_cuda_graphs: bool = False
use_flash_attention: bool = False
use_linear8bitlt: bool = False
class GraphContainer:
"""Container for managing graphs for different batch sizes (CUDA graphs on NVIDIA, regular on others)."""
def __init__(self, model: nn.Module, seq_length: int):
self.model = model
self.seq_length = seq_length
self.graphs: dict[int, GraphWrapper] = {}
def get_or_create(self, batch_size: int) -> "GraphWrapper":
if batch_size not in self.graphs:
self.graphs[batch_size] = GraphWrapper(self.model, batch_size, self.seq_length)
return self.graphs[batch_size]
class GraphWrapper:
"""Wrapper for graph capture and replay (CUDA graphs on NVIDIA, regular on others)."""
def __init__(self, model: nn.Module, batch_size: int, seq_length: int):
self.model = model
self.device = self._get_device()
self.static_input = self._create_random_batch(batch_size, seq_length)
self.static_attention_mask = torch.ones_like(self.static_input)
# Warm up
self._warmup()
# Only use CUDA graphs on NVIDIA GPUs
if torch.cuda.is_available() and hasattr(torch.cuda, "CUDAGraph"):
# Capture graph
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph):
self.static_output = self.model(
input_ids=self.static_input,
attention_mask=self.static_attention_mask,
)
self.use_cuda_graph = True
else:
# For MPS or CPU, just store the model
self.use_cuda_graph = False
self.static_output = None
def _get_device(self) -> str:
if torch.cuda.is_available():
return "cuda"
elif torch.backends.mps.is_available():
return "mps"
else:
return "cpu"
def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
return torch.randint(
0, 1000, (batch_size, seq_length), device=self.device, dtype=torch.long
)
def _warmup(self, num_warmup: int = 3):
with torch.no_grad():
for _ in range(num_warmup):
self.model(
input_ids=self.static_input,
attention_mask=self.static_attention_mask,
)
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
if self.use_cuda_graph:
self.static_input.copy_(input_ids)
self.static_attention_mask.copy_(attention_mask)
self.graph.replay()
return self.static_output
else:
# For MPS/CPU, just run normally
return self.model(input_ids=input_ids, attention_mask=attention_mask)
class ModelOptimizer:
"""Applies various optimizations to the model."""
@staticmethod
def optimize(model: nn.Module, config: BenchmarkConfig) -> nn.Module:
print("\nApplying model optimizations:")
if model is None:
raise ValueError("Cannot optimize None model")
# Move to GPU
if torch.cuda.is_available():
model = model.cuda()
device = "cuda"
elif torch.backends.mps.is_available():
model = model.to("mps")
device = "mps"
else:
model = model.cpu()
device = "cpu"
print(f"- Model moved to {device}")
# FP16
if config.use_fp16 and not config.use_int4:
model = model.half()
# use torch compile
model = torch.compile(model)
print("- Using FP16 precision")
# Check if using SDPA (only on CUDA)
if (
torch.cuda.is_available()
and torch.version.cuda
and float(torch.version.cuda[:3]) >= 11.6
):
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
else:
print("- PyTorch SDPA not available")
# Flash Attention (only on CUDA)
if config.use_flash_attention and torch.cuda.is_available():
try:
from flash_attn.flash_attention import FlashAttention # noqa: F401
print("- Flash Attention 2 available")
if hasattr(model.config, "attention_mode"):
model.config.attention_mode = "flash_attention_2"
print(" - Enabled Flash Attention 2 mode")
except ImportError:
print("- Flash Attention not available")
# Memory efficient attention (only on CUDA)
if torch.cuda.is_available():
try:
from xformers.ops import memory_efficient_attention # noqa: F401
if hasattr(model, "enable_xformers_memory_efficient_attention"):
model.enable_xformers_memory_efficient_attention()
print("- Enabled xformers memory efficient attention")
else:
print("- Model doesn't support xformers")
except (ImportError, AttributeError):
print("- Xformers not available")
model.eval()
print("- Model set to eval mode")
return model
class Timer:
"""Handles accurate GPU timing using GPU events or CPU timing."""
def __init__(self):
if torch.cuda.is_available():
self.start_event = torch.cuda.Event(enable_timing=True)
self.end_event = torch.cuda.Event(enable_timing=True)
self.use_gpu_timing = True
elif torch.backends.mps.is_available():
# MPS doesn't have events, use CPU timing
self.use_gpu_timing = False
else:
# CPU timing
self.use_gpu_timing = False
@contextmanager
def timing(self):
if self.use_gpu_timing:
self.start_event.record()
yield
self.end_event.record()
self.end_event.synchronize()
else:
# Use CPU timing for MPS/CPU
start_time = time.time()
yield
self.cpu_elapsed = time.time() - start_time
def elapsed_time(self) -> float:
if self.use_gpu_timing:
return self.start_event.elapsed_time(self.end_event) / 1000 # ms to seconds
else:
return self.cpu_elapsed
class Benchmark:
"""Main benchmark runner."""
def __init__(self, config: BenchmarkConfig):
self.config = config
try:
self.model = self._load_model()
if self.model is None:
raise ValueError("Model initialization failed - model is None")
# Only use CUDA graphs on NVIDIA GPUs
if config.use_cuda_graphs and torch.cuda.is_available():
self.graphs = GraphContainer(self.model, config.seq_length)
else:
self.graphs = None
self.timer = Timer()
except Exception as e:
print(f"ERROR in benchmark initialization: {e!s}")
raise
def _load_model(self) -> nn.Module:
print(f"Loading model from {self.config.model_path}...")
try:
# Int4 quantization using HuggingFace integration
if self.config.use_int4:
import bitsandbytes as bnb
print(f"- bitsandbytes version: {bnb.__version__}")
# Check if using custom 8bit quantization
if hasattr(self.config, "use_linear8bitlt") and self.config.use_linear8bitlt:
print("- Using custom Linear8bitLt replacement for all linear layers")
# Load original model (without quantization config)
import bitsandbytes as bnb
import torch
# set default to half
torch.set_default_dtype(torch.float16)
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
model = AutoModel.from_pretrained(
self.config.model_path,
torch_dtype=compute_dtype,
)
# Define replacement function
def replace_linear_with_linear8bitlt(model):
"""Recursively replace all nn.Linear layers with Linear8bitLt"""
for name, module in list(model.named_children()):
if isinstance(module, nn.Linear):
# Get original linear layer parameters
in_features = module.in_features
out_features = module.out_features
bias = module.bias is not None
# Create 8bit linear layer
# print size
print(f"in_features: {in_features}, out_features: {out_features}")
new_module = bnb.nn.Linear8bitLt(
in_features,
out_features,
bias=bias,
has_fp16_weights=False,
)
# Copy weights and bias
new_module.weight.data = module.weight.data
if bias:
new_module.bias.data = module.bias.data
# Replace module
setattr(model, name, new_module)
else:
# Process child modules recursively
replace_linear_with_linear8bitlt(module)
return model
# Replace all linear layers
model = replace_linear_with_linear8bitlt(model)
# add torch compile
model = torch.compile(model)
# Move model to GPU (quantization happens here)
device = (
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
model = model.to(device)
print("- All linear layers replaced with Linear8bitLt")
else:
# Use original Int4 quantization method
print("- Using bitsandbytes for Int4 quantization")
# Create quantization config
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
print("- Quantization config:", quantization_config)
# Load model directly with quantization config
model = AutoModel.from_pretrained(
self.config.model_path,
quantization_config=quantization_config,
torch_dtype=compute_dtype,
device_map="auto", # Let HF decide on device mapping
)
# Check if model loaded successfully
if model is None:
raise ValueError("Model loading returned None")
print(f"- Model type: {type(model)}")
# Apply optimizations directly here
print("\nApplying model optimizations:")
if hasattr(self.config, "use_linear8bitlt") and self.config.use_linear8bitlt:
print("- Model moved to GPU with Linear8bitLt quantization")
else:
# Skip moving to GPU since device_map="auto" already did that
print("- Model already on GPU due to device_map='auto'")
# Skip FP16 conversion since we specified compute_dtype
print(f"- Using {compute_dtype} for compute dtype")
# Check CUDA and SDPA
if (
torch.cuda.is_available()
and torch.version.cuda
and float(torch.version.cuda[:3]) >= 11.6
):
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
else:
print("- PyTorch SDPA not available")
# Try xformers if available (only on CUDA)
if torch.cuda.is_available():
try:
if hasattr(model, "enable_xformers_memory_efficient_attention"):
model.enable_xformers_memory_efficient_attention()
print("- Enabled xformers memory efficient attention")
else:
print("- Model doesn't support xformers")
except (ImportError, AttributeError):
print("- Xformers not available")
# Set to eval mode
model.eval()
print("- Model set to eval mode")
# Int8 quantization using HuggingFace integration
elif self.config.use_int8:
print("- Using INT8 quantization")
# For now, just use standard loading with INT8 config
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
quantization_config = BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
)
model = AutoModel.from_pretrained(
self.config.model_path,
quantization_config=quantization_config,
torch_dtype=compute_dtype,
device_map="auto",
)
if model is None:
raise ValueError("Model loading returned None")
print(f"- Model type: {type(model)}")
model.eval()
print("- Model set to eval mode")
else:
# Standard loading for FP16/FP32
model = AutoModel.from_pretrained(self.config.model_path)
print("- Model loaded in standard precision")
print(f"- Model type: {type(model)}")
# Apply standard optimizations
# set default to half
import torch
torch.set_default_dtype(torch.bfloat16)
model = ModelOptimizer.optimize(model, self.config)
model = model.half()
# add torch compile
model = torch.compile(model)
# Final check to ensure model is not None
if model is None:
raise ValueError("Model is None after optimization")
print(f"- Final model type: {type(model)}")
return model
except Exception as e:
print(f"ERROR loading model: {e!s}")
import traceback
traceback.print_exc()
raise
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
device = (
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
return torch.randint(
0,
1000,
(batch_size, self.config.seq_length),
device=device,
dtype=torch.long,
)
def _run_inference(
self, input_ids: torch.Tensor, graph_wrapper: GraphWrapper | None = None
) -> tuple[float, torch.Tensor]:
attention_mask = torch.ones_like(input_ids)
with torch.no_grad(), self.timer.timing():
if graph_wrapper is not None:
output = graph_wrapper(input_ids, attention_mask)
else:
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
return self.timer.elapsed_time(), output
def run(self) -> dict[int, dict[str, float]]:
results = {}
# Reset peak memory stats
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
elif torch.backends.mps.is_available():
# MPS doesn't have reset_peak_memory_stats, skip it
pass
else:
print("- No GPU memory stats available")
for batch_size in self.config.batch_sizes:
print(f"\nTesting batch size: {batch_size}")
times = []
# Get or create graph for this batch size
graph_wrapper = (
self.graphs.get_or_create(batch_size) if self.graphs is not None else None
)
# Pre-allocate input tensor
input_ids = self._create_random_batch(batch_size)
print(f"Input shape: {input_ids.shape}")
# Run benchmark
for i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
try:
elapsed_time, output = self._run_inference(input_ids, graph_wrapper)
if i == 0: # Only print on first run
print(f"Output shape: {output.last_hidden_state.shape}")
times.append(elapsed_time)
except Exception as e:
print(f"Error during inference: {e}")
break
if not times:
print(f"No successful runs for batch size {batch_size}, skipping")
continue
# Calculate statistics
avg_time = np.mean(times)
std_time = np.std(times)
throughput = batch_size / avg_time
results[batch_size] = {
"avg_time": avg_time,
"std_time": std_time,
"throughput": throughput,
}
print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
print(f"Throughput: {throughput:.2f} sequences/second")
# Log memory usage
if torch.cuda.is_available():
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
elif torch.backends.mps.is_available():
# MPS doesn't have max_memory_allocated, use 0
peak_memory_gb = 0.0
else:
peak_memory_gb = 0.0
print("- No GPU memory usage available")
if peak_memory_gb > 0:
print(f"\nPeak GPU memory usage: {peak_memory_gb:.2f} GB")
else:
print("\n- GPU memory usage not available")
# Add memory info to results
for batch_size in results:
results[batch_size]["peak_memory_gb"] = peak_memory_gb
return results
def main():
parser = argparse.ArgumentParser(description="Model Inference Benchmark")
parser.add_argument(
"--model_path",
type=str,
default="facebook/contriever",
help="Path to the model",
)
parser.add_argument(
"--batch_sizes",
type=str,
default="1,2,4,8,16,32",
help="Comma-separated list of batch sizes",
)
parser.add_argument(
"--seq_length",
type=int,
default=256,
help="Sequence length for input",
)
parser.add_argument(
"--num_runs",
type=int,
default=5,
help="Number of runs for each batch size",
)
parser.add_argument(
"--use_fp16",
action="store_true",
help="Enable FP16 inference",
)
parser.add_argument(
"--use_int4",
action="store_true",
help="Enable INT4 quantization using bitsandbytes",
)
parser.add_argument(
"--use_int8",
action="store_true",
help="Enable INT8 quantization for both activations and weights using bitsandbytes",
)
parser.add_argument(
"--use_cuda_graphs",
action="store_true",
help="Enable CUDA Graphs optimization (only on NVIDIA GPUs)",
)
parser.add_argument(
"--use_flash_attention",
action="store_true",
help="Enable Flash Attention 2 if available (only on NVIDIA GPUs)",
)
parser.add_argument(
"--use_linear8bitlt",
action="store_true",
help="Enable Linear8bitLt quantization for all linear layers",
)
args = parser.parse_args()
# Print arguments for debugging
print("\nCommand line arguments:")
for arg, value in vars(args).items():
print(f"- {arg}: {value}")
config = BenchmarkConfig(
model_path=args.model_path,
batch_sizes=[int(bs) for bs in args.batch_sizes.split(",")],
seq_length=args.seq_length,
num_runs=args.num_runs,
use_fp16=args.use_fp16,
use_int4=args.use_int4,
use_int8=args.use_int8, # Add this line
use_cuda_graphs=args.use_cuda_graphs,
use_flash_attention=args.use_flash_attention,
use_linear8bitlt=args.use_linear8bitlt,
)
# Print configuration for debugging
print("\nBenchmark configuration:")
for field, value in vars(config).items():
print(f"- {field}: {value}")
try:
benchmark = Benchmark(config)
results = benchmark.run()
# Save results to file
import json
import os
# Create results directory if it doesn't exist
os.makedirs("results", exist_ok=True)
# Generate filename based on configuration
precision_type = (
"int4"
if config.use_int4
else "int8"
if config.use_int8
else "fp16"
if config.use_fp16
else "fp32"
)
model_name = os.path.basename(config.model_path)
output_file = f"results/benchmark_{model_name}_{precision_type}.json"
# Save results
with open(output_file, "w") as f:
json.dump(
{
"config": {
k: str(v) if isinstance(v, list) else v for k, v in vars(config).items()
},
"results": {str(k): v for k, v in results.items()},
},
f,
indent=2,
)
print(f"Results saved to {output_file}")
except Exception as e:
print(f"Benchmark failed: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,359 @@
#!/usr/bin/env python3
"""
This script runs a recall evaluation on a given LEANN index.
It correctly compares results by fetching the text content for both the new search
results and the golden standard results, making the comparison robust to ID changes.
"""
import argparse
import json
import sys
import time
from pathlib import Path
import numpy as np
from leann.api import LeannBuilder, LeannSearcher
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
"""Checks if the data directory exists, and if not, downloads it from HF Hub."""
if not data_root.exists():
print(f"Data directory '{data_root}' not found.")
print("Downloading evaluation data from Hugging Face Hub... (this may take a moment)")
try:
from huggingface_hub import snapshot_download
if download_embeddings:
# Download everything including embeddings (large files)
snapshot_download(
repo_id="LEANN-RAG/leann-rag-evaluation-data",
repo_type="dataset",
local_dir=data_root,
local_dir_use_symlinks=False,
)
print("Data download complete (including embeddings)!")
else:
# Download only specific folders, excluding embeddings
allow_patterns = [
"ground_truth/**",
"indices/**",
"queries/**",
"*.md",
"*.txt",
]
snapshot_download(
repo_id="LEANN-RAG/leann-rag-evaluation-data",
repo_type="dataset",
local_dir=data_root,
local_dir_use_symlinks=False,
allow_patterns=allow_patterns,
)
print("Data download complete (excluding embeddings)!")
except ImportError:
print(
"Error: huggingface_hub is not installed. Please install it to download the data:"
)
print("uv pip install -e '.[dev]'")
sys.exit(1)
except Exception as e:
print(f"An error occurred during data download: {e}")
sys.exit(1)
def download_embeddings_if_needed(data_root: Path, dataset_type: str | None = None):
"""Download embeddings files specifically."""
embeddings_dir = data_root / "embeddings"
if dataset_type:
# Check if specific dataset embeddings exist
target_file = embeddings_dir / dataset_type / "passages_00.pkl"
if target_file.exists():
print(f"Embeddings for {dataset_type} already exist")
return str(target_file)
print("Downloading embeddings from HuggingFace Hub...")
try:
from huggingface_hub import snapshot_download
# Download only embeddings folder
snapshot_download(
repo_id="LEANN-RAG/leann-rag-evaluation-data",
repo_type="dataset",
local_dir=data_root,
local_dir_use_symlinks=False,
allow_patterns=["embeddings/**/*.pkl"],
)
print("Embeddings download complete!")
if dataset_type:
target_file = embeddings_dir / dataset_type / "passages_00.pkl"
if target_file.exists():
return str(target_file)
return str(embeddings_dir)
except Exception as e:
print(f"Error downloading embeddings: {e}")
sys.exit(1)
# --- Helper Function to get Golden Passages ---
def get_golden_texts(searcher: LeannSearcher, golden_ids: list[int]) -> set:
"""
Retrieves the text for golden passage IDs directly from the LeannSearcher's
passage manager.
"""
golden_texts = set()
for gid in golden_ids:
try:
# PassageManager uses string IDs
passage_data = searcher.passage_manager.get_passage(str(gid))
golden_texts.add(passage_data["text"])
except KeyError:
print(f"Warning: Golden passage ID '{gid}' not found in the index's passage data.")
return golden_texts
def load_queries(file_path: Path) -> list[str]:
queries = []
with open(file_path, encoding="utf-8") as f:
for line in f:
data = json.loads(line)
queries.append(data["query"])
return queries
def build_index_from_embeddings(embeddings_file: str, output_path: str, backend: str = "hnsw"):
"""
Build a LEANN index from pre-computed embeddings.
Args:
embeddings_file: Path to pickle file with (ids, embeddings) tuple
output_path: Path where to save the index
backend: Backend to use ("hnsw" or "diskann")
"""
print(f"Building {backend} index from embeddings: {embeddings_file}")
# Create builder with appropriate parameters
if backend == "hnsw":
builder_kwargs = {
"M": 32, # Graph degree
"efConstruction": 256, # Construction complexity
"is_compact": True, # Use compact storage
"is_recompute": True, # Enable pruning for better recall
}
elif backend == "diskann":
builder_kwargs = {
"complexity": 64,
"graph_degree": 32,
"search_memory_maximum": 8.0, # GB
"build_memory_maximum": 16.0, # GB
}
else:
builder_kwargs = {}
builder = LeannBuilder(
backend_name=backend,
embedding_model="facebook/contriever-msmarco", # Model used to create embeddings
dimensions=768, # Will be auto-detected from embeddings
**builder_kwargs,
)
# Build index from precomputed embeddings
builder.build_index_from_embeddings(output_path, embeddings_file)
print(f"Index saved to: {output_path}")
return output_path
def main():
parser = argparse.ArgumentParser(description="Run recall evaluation on a LEANN index.")
parser.add_argument(
"index_path",
type=str,
nargs="?",
help="Path to the LEANN index to evaluate or build (optional).",
)
parser.add_argument(
"--mode",
choices=["evaluate", "build"],
default="evaluate",
help="Mode: 'evaluate' existing index or 'build' from embeddings",
)
parser.add_argument(
"--embeddings-file",
type=str,
help="Path to embeddings pickle file (optional for build mode)",
)
parser.add_argument(
"--backend",
choices=["hnsw", "diskann"],
default="hnsw",
help="Backend to use for building index (default: hnsw)",
)
parser.add_argument(
"--num-queries", type=int, default=10, help="Number of queries to evaluate."
)
parser.add_argument("--top-k", type=int, default=3, help="The 'k' value for recall@k.")
parser.add_argument(
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
)
args = parser.parse_args()
# --- Path Configuration ---
# Assumes a project structure where the script is in 'benchmarks/'
# and evaluation data is in 'benchmarks/data/'.
script_dir = Path(__file__).resolve().parent
data_root = script_dir / "data"
# Download data based on mode
if args.mode == "build":
# For building mode, we need embeddings
download_data_if_needed(data_root, download_embeddings=False) # Basic data first
# Auto-detect dataset type and download embeddings
if args.embeddings_file:
embeddings_file = args.embeddings_file
# Try to detect dataset type from embeddings file path
if "rpj_wiki" in str(embeddings_file):
dataset_type = "rpj_wiki"
elif "dpr" in str(embeddings_file):
dataset_type = "dpr"
else:
dataset_type = "dpr" # Default
else:
# Auto-detect from index path if provided, otherwise default to DPR
if args.index_path:
index_path_str = str(args.index_path)
if "rpj_wiki" in index_path_str:
dataset_type = "rpj_wiki"
elif "dpr" in index_path_str:
dataset_type = "dpr"
else:
dataset_type = "dpr" # Default to DPR
else:
dataset_type = "dpr" # Default to DPR
embeddings_file = download_embeddings_if_needed(data_root, dataset_type)
# Auto-generate index path if not provided
if not args.index_path:
indices_dir = data_root / "indices" / dataset_type
indices_dir.mkdir(parents=True, exist_ok=True)
args.index_path = str(indices_dir / f"{dataset_type}_from_embeddings")
print(f"Auto-generated index path: {args.index_path}")
print(f"Building index from embeddings: {embeddings_file}")
built_index_path = build_index_from_embeddings(
embeddings_file, args.index_path, args.backend
)
print(f"Index built successfully: {built_index_path}")
# Ask if user wants to run evaluation
eval_response = input("Run evaluation on the built index? (y/n): ").strip().lower()
if eval_response != "y":
print("Index building complete. Exiting.")
return
else:
# For evaluation mode, don't need embeddings
download_data_if_needed(data_root, download_embeddings=False)
# Auto-detect index path if not provided
if not args.index_path:
# Default to using downloaded indices
indices_dir = data_root / "indices"
# Try common datasets in order of preference
for dataset in ["dpr", "rpj_wiki"]:
dataset_dir = indices_dir / dataset
if dataset_dir.exists():
# Look for index files
index_files = list(dataset_dir.glob("*.index")) + list(
dataset_dir.glob("*_disk.index")
)
if index_files:
args.index_path = str(
index_files[0].with_suffix("")
) # Remove .index extension
print(f"Using index: {args.index_path}")
break
if not args.index_path:
print("No indices found. The data download should have included pre-built indices.")
print(
"Please check the benchmarks/data/indices/ directory or provide --index-path manually."
)
sys.exit(1)
# Detect dataset type from index path to select the correct ground truth
index_path_str = str(args.index_path)
if "rpj_wiki" in index_path_str:
dataset_type = "rpj_wiki"
elif "dpr" in index_path_str:
dataset_type = "dpr"
else:
# Fallback: try to infer from the index directory name
dataset_type = Path(args.index_path).name
print(f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'.")
queries_file = data_root / "queries" / "nq_open.jsonl"
golden_results_file = data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
print(f"INFO: Detected dataset type: {dataset_type}")
print(f"INFO: Using queries file: {queries_file}")
print(f"INFO: Using ground truth file: {golden_results_file}")
try:
searcher = LeannSearcher(args.index_path)
queries = load_queries(queries_file)
with open(golden_results_file) as f:
golden_results_data = json.load(f)
num_eval_queries = min(args.num_queries, len(queries))
queries = queries[:num_eval_queries]
print(f"\nRunning evaluation on {num_eval_queries} queries...")
recall_scores = []
search_times = []
for i in range(num_eval_queries):
start_time = time.time()
new_results = searcher.search(queries[i], top_k=args.top_k, ef=args.ef_search)
search_times.append(time.time() - start_time)
# Correct Recall Calculation: Based on TEXT content
new_texts = {result.text for result in new_results}
# Get golden texts directly from the searcher's passage manager
golden_ids = golden_results_data["indices"][i][: args.top_k]
golden_texts = get_golden_texts(searcher, golden_ids)
overlap = len(new_texts & golden_texts)
recall = overlap / len(golden_texts) if golden_texts else 0
recall_scores.append(recall)
print("\n--- EVALUATION RESULTS ---")
print(f"Query: {queries[i]}")
print(f"New Results: {new_texts}")
print(f"Golden Results: {golden_texts}")
print(f"Overlap: {overlap}")
print(f"Recall: {recall}")
print(f"Search Time: {search_times[-1]:.4f}s")
print("--------------------------------")
avg_recall = np.mean(recall_scores) if recall_scores else 0
avg_time = np.mean(search_times) if search_times else 0
print("\n🎉 --- Evaluation Complete ---")
print(f"Avg. Recall@{args.top_k} (efSearch={args.ef_search}): {avg_recall:.4f}")
print(f"Avg. Search Time: {avg_time:.4f}s")
except Exception as e:
print(f"\n❌ An error occurred during evaluation: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,311 @@
import time
from dataclasses import dataclass
import numpy as np
import torch
from torch import nn
from tqdm import tqdm
from transformers import AutoModel
# Add MLX imports
try:
import mlx.core as mx
from mlx_lm.utils import load
MLX_AVAILABLE = True
except ImportError:
print("MLX not available. Install with: uv pip install mlx mlx-lm")
MLX_AVAILABLE = False
@dataclass
class BenchmarkConfig:
model_path: str = "facebook/contriever"
batch_sizes: list[int] = None
seq_length: int = 256
num_runs: int = 5
use_fp16: bool = True
use_int4: bool = False
use_int8: bool = False
use_cuda_graphs: bool = False
use_flash_attention: bool = False
use_linear8bitlt: bool = False
use_mlx: bool = False # New flag for MLX testing
def __post_init__(self):
if self.batch_sizes is None:
self.batch_sizes = [1, 2, 4, 8, 16, 32, 64]
class MLXBenchmark:
"""MLX-specific benchmark for embedding models"""
def __init__(self, config: BenchmarkConfig):
self.config = config
self.model, self.tokenizer = self._load_model()
def _load_model(self):
"""Load MLX model and tokenizer following the API pattern"""
print(f"Loading MLX model from {self.config.model_path}...")
try:
model, tokenizer = load(self.config.model_path)
print("MLX model loaded successfully")
return model, tokenizer
except Exception as e:
print(f"Error loading MLX model: {e}")
raise
def _create_random_batch(self, batch_size: int):
"""Create random input batches for MLX testing - same as PyTorch"""
return torch.randint(0, 1000, (batch_size, self.config.seq_length), dtype=torch.long)
def _run_inference(self, input_ids: torch.Tensor) -> float:
"""Run MLX inference with same input as PyTorch"""
start_time = time.time()
try:
# Convert PyTorch tensor to MLX array
input_ids_mlx = mx.array(input_ids.numpy())
# Get embeddings
embeddings = self.model(input_ids_mlx)
# Mean pooling (following the API pattern)
pooled = embeddings.mean(axis=1)
# Convert to numpy (following the API pattern)
pooled_numpy = np.array(pooled.tolist(), dtype=np.float32)
# Force computation
_ = pooled_numpy.shape
except Exception as e:
print(f"MLX inference error: {e}")
return float("inf")
end_time = time.time()
return end_time - start_time
def run(self) -> dict[int, dict[str, float]]:
"""Run the MLX benchmark across all batch sizes"""
results = {}
print(f"Starting MLX benchmark with model: {self.config.model_path}")
print(f"Testing batch sizes: {self.config.batch_sizes}")
for batch_size in self.config.batch_sizes:
print(f"\n=== Testing MLX batch size: {batch_size} ===")
times = []
# Create input batch (same as PyTorch)
input_ids = self._create_random_batch(batch_size)
# Warm up
print("Warming up...")
for _ in range(3):
try:
self._run_inference(input_ids[:2]) # Warm up with smaller batch
except Exception as e:
print(f"Warmup error: {e}")
break
# Run benchmark
for _i in tqdm(range(self.config.num_runs), desc=f"MLX Batch size {batch_size}"):
try:
elapsed_time = self._run_inference(input_ids)
if elapsed_time != float("inf"):
times.append(elapsed_time)
except Exception as e:
print(f"Error during MLX inference: {e}")
break
if not times:
print(f"Skipping batch size {batch_size} due to errors")
continue
# Calculate statistics
avg_time = np.mean(times)
std_time = np.std(times)
throughput = batch_size / avg_time
results[batch_size] = {
"avg_time": avg_time,
"std_time": std_time,
"throughput": throughput,
"min_time": np.min(times),
"max_time": np.max(times),
}
print(f"MLX Results for batch size {batch_size}:")
print(f" Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
print(f" Min Time: {np.min(times):.4f}s")
print(f" Max Time: {np.max(times):.4f}s")
print(f" Throughput: {throughput:.2f} sequences/second")
return results
class Benchmark:
def __init__(self, config: BenchmarkConfig):
self.config = config
self.device = (
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
self.model = self._load_model()
def _load_model(self) -> nn.Module:
print(f"Loading model from {self.config.model_path}...")
model = AutoModel.from_pretrained(self.config.model_path)
if self.config.use_fp16:
model = model.half()
model = torch.compile(model)
model = model.to(self.device)
model.eval()
return model
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
return torch.randint(
0,
1000,
(batch_size, self.config.seq_length),
device=self.device,
dtype=torch.long,
)
def _run_inference(self, input_ids: torch.Tensor) -> float:
attention_mask = torch.ones_like(input_ids)
start_time = time.time()
with torch.no_grad():
self.model(input_ids=input_ids, attention_mask=attention_mask)
end_time = time.time()
return end_time - start_time
def run(self) -> dict[int, dict[str, float]]:
results = {}
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
for batch_size in self.config.batch_sizes:
print(f"\nTesting batch size: {batch_size}")
times = []
input_ids = self._create_random_batch(batch_size)
for _i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
try:
elapsed_time = self._run_inference(input_ids)
times.append(elapsed_time)
except Exception as e:
print(f"Error during inference: {e}")
break
if not times:
continue
avg_time = np.mean(times)
std_time = np.std(times)
throughput = batch_size / avg_time
results[batch_size] = {
"avg_time": avg_time,
"std_time": std_time,
"throughput": throughput,
}
print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
print(f"Throughput: {throughput:.2f} sequences/second")
if torch.cuda.is_available():
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
else:
peak_memory_gb = 0.0
for batch_size in results:
results[batch_size]["peak_memory_gb"] = peak_memory_gb
return results
def run_benchmark():
"""Main function to run the benchmark with optimized parameters."""
config = BenchmarkConfig()
try:
benchmark = Benchmark(config)
results = benchmark.run()
max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
avg_throughput = np.mean([results[batch_size]["throughput"] for batch_size in results])
return {
"max_throughput": max_throughput,
"avg_throughput": avg_throughput,
"results": results,
}
except Exception as e:
print(f"Benchmark failed: {e}")
return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": str(e)}
def run_mlx_benchmark():
"""Run MLX-specific benchmark"""
if not MLX_AVAILABLE:
print("MLX not available, skipping MLX benchmark")
return {
"max_throughput": 0.0,
"avg_throughput": 0.0,
"error": "MLX not available",
}
config = BenchmarkConfig(model_path="mlx-community/all-MiniLM-L6-v2-4bit", use_mlx=True)
try:
benchmark = MLXBenchmark(config)
results = benchmark.run()
if not results:
return {
"max_throughput": 0.0,
"avg_throughput": 0.0,
"error": "No valid results",
}
max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
avg_throughput = np.mean([results[batch_size]["throughput"] for batch_size in results])
return {
"max_throughput": max_throughput,
"avg_throughput": avg_throughput,
"results": results,
}
except Exception as e:
print(f"MLX benchmark failed: {e}")
return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": str(e)}
if __name__ == "__main__":
print("=== PyTorch Benchmark ===")
pytorch_result = run_benchmark()
print(f"PyTorch Max throughput: {pytorch_result['max_throughput']:.2f} sequences/second")
print(f"PyTorch Average throughput: {pytorch_result['avg_throughput']:.2f} sequences/second")
print("\n=== MLX Benchmark ===")
mlx_result = run_mlx_benchmark()
print(f"MLX Max throughput: {mlx_result['max_throughput']:.2f} sequences/second")
print(f"MLX Average throughput: {mlx_result['avg_throughput']:.2f} sequences/second")
# Compare results
if pytorch_result["max_throughput"] > 0 and mlx_result["max_throughput"] > 0:
speedup = mlx_result["max_throughput"] / pytorch_result["max_throughput"]
print("\n=== Comparison ===")
print(f"MLX is {speedup:.2f}x {'faster' if speedup > 1 else 'slower'} than PyTorch")