refactor: reorganize all examples/ and test/
120
benchmarks/README.md
Normal file
@@ -0,0 +1,120 @@
# 🧪 Leann Sanity Checks

This directory contains comprehensive sanity checks for the Leann system, ensuring all components work correctly across different configurations.

## 📁 Test Files

### `test_distance_functions.py`
Tests all supported distance functions across the DiskANN backend:
- ✅ **MIPS** (Maximum Inner Product Search)
- ✅ **L2** (Euclidean Distance)
- ✅ **Cosine** (Cosine Similarity)

```bash
uv run python tests/sanity_checks/test_distance_functions.py
```

### `test_l2_verification.py`
Specifically verifies that L2 distance is correctly implemented by:
- Building indices with L2 vs. Cosine metrics
- Comparing search results and score ranges
- Validating that different metrics produce the expected score patterns

```bash
uv run python tests/sanity_checks/test_l2_verification.py
```

### `test_sanity_check.py`
Comprehensive end-to-end verification, including:
- Distance function testing
- Embedding model compatibility
- Search result correctness validation
- Backend integration testing

```bash
uv run python tests/sanity_checks/test_sanity_check.py
```

## 🎯 What These Tests Verify

### ✅ Distance Function Support
- All three distance metrics (MIPS, L2, Cosine) work correctly
- Score ranges are appropriate for each metric type (see the sketch below)
- Different metrics can produce different rankings (as expected)
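
As a rough guide to what "appropriate score ranges" means per metric, here is a minimal sketch of the invariants involved (the actual assertions in the test files may differ):

```python
import numpy as np

def plausible_score_range(metric: str, scores: np.ndarray) -> bool:
    """Check that scores fall in the range each metric implies.

    MIPS returns raw inner products (unbounded, higher = more similar),
    cosine similarity lies in [-1, 1], and L2 is a non-negative distance
    where lower = more similar.
    """
    if metric == "mips":
        return True  # inner products are unbounded
    if metric == "cosine":
        return bool(np.all(scores >= -1.0) and np.all(scores <= 1.0))
    if metric == "l2":
        return bool(np.all(scores >= 0.0))
    raise ValueError(f"unknown metric: {metric}")
```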

### ✅ Backend Integration
- The DiskANN backend properly initializes and builds indices
- Graph construction completes without errors
- Search operations return valid results

### ✅ Embedding Pipeline
- Real-time embedding computation works
- Multiple embedding models are supported
- ZMQ server communication functions correctly

### ✅ End-to-End Functionality
- Index building → searching → result retrieval pipeline (condensed in the sketch below)
- Metadata preservation through the entire flow
- Error handling and graceful degradation
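
For orientation, here is a minimal sketch of that pipeline using the same `LeannBuilder`/`LeannSearcher` calls the benchmark scripts in this repository use; the paths, texts, and parameter choices are illustrative:

```python
from leann.api import LeannBuilder, LeannSearcher

# Build: add raw text chunks, then write the index to disk.
builder = LeannBuilder(backend_name="diskann", embedding_model="facebook/contriever")
builder.add_text("LEANN builds a graph index over text chunks.")
builder.add_text("Searches recompute embeddings in real time.")
builder.build_index("./sanity_demo.leann")

# Search: load the index and retrieve the top-k passages.
searcher = LeannSearcher("./sanity_demo.leann")
results = searcher.search("How does LEANN index text?", top_k=2)
for result in results:
    print(result.text)  # result objects carry the original passage text
```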

## 🔍 Expected Output

When all tests pass, you should see:

```
📊 Test result summary:
   mips   : ✅ passed
   l2     : ✅ passed
   cosine : ✅ passed

🎉 Testing complete!
```

## 🐛 Troubleshooting

### Common Issues

**Import Errors**: Ensure you're running from the project root:
```bash
cd /path/to/leann
uv run python tests/sanity_checks/test_distance_functions.py
```

**Memory Issues**: Reduce graph complexity for resource-constrained systems:
```python
builder = LeannBuilder(
    backend_name="diskann",
    graph_degree=8,  # Reduced from 16
    complexity=16,   # Reduced from 32
)
```

**ZMQ Port Conflicts**: The tests use different ports to avoid conflicts, but you may need to kill existing processes:
```bash
pkill -f "embedding_server"
```

## 📊 Performance Expectations

### Typical Timing (3 documents, consumer hardware)
- **Index Building**: 2-5 seconds per distance function
- **Search Query**: 50-200 ms
- **Recompute Mode**: 5-15 seconds (higher accuracy)

### Memory Usage
- **Index Storage**: ~1-2 MB per distance function
- **Runtime Memory**: ~500 MB (including model loading)

## 🔗 Integration with CI/CD

These tests are designed to run in automated environments:

```yaml
# GitHub Actions example
- name: Run Sanity Checks
  run: |
    uv run python tests/sanity_checks/test_distance_functions.py
    uv run python tests/sanity_checks/test_l2_verification.py
```

The tests are deterministic and should produce consistent results across platforms.
141
benchmarks/benchmark_embeddings.py
Normal file
@@ -0,0 +1,141 @@
import time

import matplotlib.pyplot as plt
import mlx.core as mx
import numpy as np
import torch
from mlx_lm import load
from sentence_transformers import SentenceTransformer

# --- Configuration ---
MODEL_NAME_TORCH = "Qwen/Qwen3-Embedding-0.6B"
MODEL_NAME_MLX = "mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ"
BATCH_SIZES = [1, 8, 16, 32, 64, 128]
NUM_RUNS = 10  # Number of runs to average for each batch size
WARMUP_RUNS = 2  # Number of warm-up runs

# --- Generate Dummy Data ---
DUMMY_SENTENCES = ["This is a test sentence for benchmarking." * 5] * max(BATCH_SIZES)

# --- Benchmark Functions ---


def benchmark_torch(model, sentences):
    start_time = time.time()
    model.encode(sentences, convert_to_numpy=True)
    end_time = time.time()
    return (end_time - start_time) * 1000  # Return time in ms


def benchmark_mlx(model, tokenizer, sentences):
    start_time = time.time()

    # Tokenize sentences using the MLX tokenizer
    tokens = []
    for sentence in sentences:
        token_ids = tokenizer.encode(sentence)
        tokens.append(token_ids)

    # Pad sequences to the same length
    max_len = max(len(t) for t in tokens)
    input_ids = []
    attention_mask = []

    for token_seq in tokens:
        # Pad sequence
        padded = token_seq + [tokenizer.eos_token_id] * (max_len - len(token_seq))
        input_ids.append(padded)
        # Create attention mask (1 for real tokens, 0 for padding)
        mask = [1] * len(token_seq) + [0] * (max_len - len(token_seq))
        attention_mask.append(mask)

    # Convert to MLX arrays
    input_ids = mx.array(input_ids)
    attention_mask = mx.array(attention_mask)

    # Get embeddings
    embeddings = model(input_ids)

    # Mean pooling
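    # The math: mean[i] = sum_t(mask[i, t] * emb[i, t]) / sum_t(mask[i, t]),
    # so padded positions are zeroed out and excluded from the average.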
    mask = mx.expand_dims(attention_mask, -1)
    sum_embeddings = (embeddings * mask).sum(axis=1)
    sum_mask = mask.sum(axis=1)
    mean_embeddings = sum_embeddings / sum_mask

    mx.eval(mean_embeddings)  # Force the lazy MLX computation to actually run
    end_time = time.time()
    return (end_time - start_time) * 1000  # Return time in ms


# --- Main Execution ---
def main():
    print("--- Initializing Models ---")
    # Load PyTorch model
    print(f"Loading PyTorch model: {MODEL_NAME_TORCH}")
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    model_torch = SentenceTransformer(MODEL_NAME_TORCH, device=device)
    print(f"PyTorch model loaded on: {device}")

    # Load MLX model
    print(f"Loading MLX model: {MODEL_NAME_MLX}")
    model_mlx, tokenizer_mlx = load(MODEL_NAME_MLX)
    print("MLX model loaded.")

    # --- Warm-up ---
    print("\n--- Performing Warm-up Runs ---")
    for _ in range(WARMUP_RUNS):
        benchmark_torch(model_torch, DUMMY_SENTENCES[:1])
        benchmark_mlx(model_mlx, tokenizer_mlx, DUMMY_SENTENCES[:1])
    print("Warm-up complete.")

    # --- Benchmarking ---
    print("\n--- Starting Benchmark ---")
    results_torch = []
    results_mlx = []

    for batch_size in BATCH_SIZES:
        print(f"Benchmarking batch size: {batch_size}")
        sentences_batch = DUMMY_SENTENCES[:batch_size]

        # Benchmark PyTorch
        torch_times = [benchmark_torch(model_torch, sentences_batch) for _ in range(NUM_RUNS)]
        results_torch.append(np.mean(torch_times))

        # Benchmark MLX
        mlx_times = [
            benchmark_mlx(model_mlx, tokenizer_mlx, sentences_batch) for _ in range(NUM_RUNS)
        ]
        results_mlx.append(np.mean(mlx_times))

    print("\n--- Benchmark Results (Average time per batch in ms) ---")
    print(f"Batch Sizes: {BATCH_SIZES}")
    print(f"PyTorch ({device}): {[f'{t:.2f}' for t in results_torch]}")
    print(f"MLX: {[f'{t:.2f}' for t in results_mlx]}")

    # --- Plotting ---
    print("\n--- Generating Plot ---")
    plt.figure(figsize=(10, 6))
    plt.plot(
        BATCH_SIZES,
        results_torch,
        marker="o",
        linestyle="-",
        label=f"PyTorch ({device})",
    )
    plt.plot(BATCH_SIZES, results_mlx, marker="s", linestyle="-", label="MLX")

    plt.title(f"Embedding Performance: MLX vs PyTorch\nModel: {MODEL_NAME_TORCH}")
    plt.xlabel("Batch Size")
    plt.ylabel("Average Time per Batch (ms)")
    plt.xticks(BATCH_SIZES)
    plt.grid(True)
    plt.legend()

    # Save the plot
    output_filename = "embedding_benchmark.png"
    plt.savefig(output_filename)
    print(f"Plot saved to {output_filename}")


if __name__ == "__main__":
    main()
326
benchmarks/compare_faiss_vs_leann.py
Normal file
@@ -0,0 +1,326 @@
#!/usr/bin/env python3
"""
Memory comparison between Faiss HNSW and LEANN HNSW backend
"""
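# Note: run this from the repository root (e.g. `python benchmarks/compare_faiss_vs_leann.py`)
# so that the relative `benchmarks/faiss_only.py` subprocess path and the `data/` directory resolve.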

import gc
import logging
import os
import subprocess
import sys
import time
from pathlib import Path

import psutil
from llama_index.core.node_parser import SentenceSplitter

# Setup logging
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)


def get_memory_usage():
    """Get current memory usage in MB"""
    process = psutil.Process()
    return process.memory_info().rss / 1024 / 1024


def print_memory_stats(stage: str, start_mem: float):
    """Print memory statistics"""
    current_mem = get_memory_usage()
    diff = current_mem - start_mem
    print(f"[{stage}] Memory: {current_mem:.1f} MB (+{diff:.1f} MB)")
    return current_mem


class MemoryTracker:
    def __init__(self, name: str):
        self.name = name
        self.start_mem = get_memory_usage()
        self.stages = []

    def checkpoint(self, stage: str):
        current_mem = print_memory_stats(f"{self.name} - {stage}", self.start_mem)
        self.stages.append((stage, current_mem))
        return current_mem

    def summary(self):
        print(f"\n=== {self.name} Memory Summary ===")
        for stage, mem in self.stages:
            print(f"{stage}: {mem:.1f} MB")
        peak_mem = max(mem for _, mem in self.stages)
        print(f"Peak Memory: {peak_mem:.1f} MB")
        print(f"Total Memory Increase: {peak_mem - self.start_mem:.1f} MB")
        return peak_mem


def test_faiss_hnsw():
    """Test Faiss HNSW Vector Store in a subprocess"""
    print("\n" + "=" * 50)
    print("TESTING FAISS HNSW VECTOR STORE")
    print("=" * 50)

    try:
        result = subprocess.run(
            [sys.executable, "benchmarks/faiss_only.py"],
            capture_output=True,
            text=True,
            timeout=300,
        )

        print(result.stdout)
        if result.stderr:
            print("Stderr:", result.stderr)

        if result.returncode != 0:
            return {
                "peak_memory": float("inf"),
                "error": f"Process failed with code {result.returncode}",
            }

        # Parse peak memory from output
        lines = result.stdout.split("\n")
        peak_memory = 0.0

        for line in lines:
            if "Peak Memory:" in line:
                peak_memory = float(line.split("Peak Memory:")[1].split("MB")[0].strip())

        return {"peak_memory": peak_memory}

    except Exception as e:
        return {
            "peak_memory": float("inf"),
            "error": str(e),
        }


def test_leann_hnsw():
    """Test LEANN HNSW search memory (loads an existing index if present)"""
    print("\n" + "=" * 50)
    print("TESTING LEANN HNSW SEARCH MEMORY")
    print("=" * 50)

    tracker = MemoryTracker("LEANN HNSW Search")

    # Import and setup
    tracker.checkpoint("Initial")

    from leann.api import LeannSearcher

    tracker.checkpoint("After imports")

    from leann.api import LeannBuilder
    from llama_index.core import SimpleDirectoryReader

    # Load and parse documents
    documents = SimpleDirectoryReader(
        "data",
        recursive=True,
        encoding="utf-8",
        required_exts=[".pdf", ".txt", ".md"],
    ).load_data()

    tracker.checkpoint("After document loading")

    # Parse into chunks
    node_parser = SentenceSplitter(
        chunk_size=256, chunk_overlap=20, separator=" ", paragraph_separator="\n\n"
    )

    all_texts = []
    for doc in documents:
        nodes = node_parser.get_nodes_from_documents([doc])
        for node in nodes:
            all_texts.append(node.get_content())
    print(f"Total number of chunks: {len(all_texts)}")

    tracker.checkpoint("After text chunking")

    # Build LEANN index
    INDEX_DIR = Path("./test_leann_comparison")
    INDEX_PATH = str(INDEX_DIR / "comparison.leann")

    # Check if index already exists
    if os.path.exists(INDEX_PATH + ".meta.json"):
        print("Loading existing LEANN HNSW index...")
        tracker.checkpoint("After loading existing index")
    else:
        print("Building new LEANN HNSW index...")
        # Clean up previous index
        import shutil

        if INDEX_DIR.exists():
            shutil.rmtree(INDEX_DIR)

        builder = LeannBuilder(
            backend_name="hnsw",
            embedding_model="facebook/contriever",
            graph_degree=32,
            complexity=64,
            is_compact=True,
            is_recompute=True,
            num_threads=1,
        )

        tracker.checkpoint("After builder setup")

        print("Building LEANN HNSW index...")

        for chunk_text in all_texts:
            builder.add_text(chunk_text)

        builder.build_index(INDEX_PATH)
        del builder
        gc.collect()

        tracker.checkpoint("After index building")

    # Find existing LEANN index
    index_paths = [
        "./test_leann_comparison/comparison.leann",
    ]
    index_path = None
    for path in index_paths:
        if os.path.exists(path + ".meta.json"):
            index_path = path
            break

    if not index_path:
        print("❌ LEANN index not found. Please build it first")
        return {"peak_memory": float("inf"), "error": "Index not found"}

    # Measure runtime memory overhead
    print("\nMeasuring runtime memory overhead...")
    runtime_start_mem = get_memory_usage()
    print(f"Before load memory: {runtime_start_mem:.1f} MB")
    tracker.checkpoint("Before load memory")

    # Load searcher
    searcher = LeannSearcher(index_path)
    tracker.checkpoint("After searcher loading")

    print("Running search queries...")
    queries = [
        "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
        "What is LEANN and how does it work?",
        "华为诺亚方舟实验室的主要研究内容",
    ]

    for i, query in enumerate(queries):
        start_time = time.time()
        # Use same parameters as Faiss: top_k=20, ef=120 (complexity parameter)
        _ = searcher.search(query, top_k=20, ef=120)
        query_time = time.time() - start_time
        print(f"Query {i + 1} time: {query_time:.3f}s")
        tracker.checkpoint(f"After query {i + 1}")

    runtime_end_mem = get_memory_usage()
    runtime_overhead = runtime_end_mem - runtime_start_mem

    peak_memory = tracker.summary()
    print(f"Runtime Memory Overhead: {runtime_overhead:.1f} MB")

    # Get storage size before cleanup
    storage_size = 0
    INDEX_DIR = Path(index_path).parent
    if INDEX_DIR.exists():
        total_size = 0
        for dirpath, _, filenames in os.walk(str(INDEX_DIR)):
            for filename in filenames:
                # Only count actual index files, skip text data and backups
                if filename.endswith((".old", ".tmp", ".bak", ".jsonl", ".json")):
                    continue
                # Count .index, .idx, .map files (actual index structures)
                if filename.endswith((".index", ".idx", ".map")):
                    filepath = os.path.join(dirpath, filename)
                    total_size += os.path.getsize(filepath)
        storage_size = total_size / (1024 * 1024)  # Convert to MB

    # Clean up
    del searcher
    gc.collect()

    return {
        "peak_memory": peak_memory,
        "storage_size": storage_size,
    }


def main():
    """Run comparison tests"""
    print("Storage + Search Memory Comparison: Faiss HNSW vs LEANN HNSW")
    print("=" * 60)

    # Test Faiss HNSW
    faiss_results = test_faiss_hnsw()

    # Force garbage collection
    gc.collect()
    time.sleep(2)

    # Test LEANN HNSW
    leann_results = test_leann_hnsw()

    # Final comparison
    print("\n" + "=" * 60)
    print("STORAGE + SEARCH MEMORY COMPARISON")
    print("=" * 60)

    # Get storage sizes
    faiss_storage_size = 0
    leann_storage_size = leann_results.get("storage_size", 0)

    # Get Faiss storage size using Python
    if os.path.exists("./storage_faiss"):
        total_size = 0
        for dirpath, _, filenames in os.walk("./storage_faiss"):
            for filename in filenames:
                filepath = os.path.join(dirpath, filename)
                total_size += os.path.getsize(filepath)
        faiss_storage_size = total_size / (1024 * 1024)  # Convert to MB

    print("Faiss HNSW:")
    if "error" in faiss_results:
        print(f"  ❌ Failed: {faiss_results['error']}")
    else:
        print(f"  Search Memory: {faiss_results['peak_memory']:.1f} MB")
        print(f"  Storage Size: {faiss_storage_size:.1f} MB")

    print("\nLEANN HNSW:")
    if "error" in leann_results:
        print(f"  ❌ Failed: {leann_results['error']}")
    else:
        print(f"  Search Memory: {leann_results['peak_memory']:.1f} MB")
        print(f"  Storage Size: {leann_storage_size:.1f} MB")

    # Calculate improvements only if both tests succeeded
    if "error" not in faiss_results and "error" not in leann_results:
        memory_ratio = faiss_results["peak_memory"] / leann_results["peak_memory"]

        print("\nLEANN vs Faiss Performance:")
        memory_saving = faiss_results["peak_memory"] - leann_results["peak_memory"]
        print(f"  Search Memory: {memory_ratio:.1f}x less ({memory_saving:.1f} MB saved)")

        # Storage comparison
        if leann_storage_size > faiss_storage_size:
            storage_ratio = leann_storage_size / faiss_storage_size
            print(f"  Storage Size: {storage_ratio:.1f}x larger (LEANN uses more storage)")
        elif faiss_storage_size > leann_storage_size:
            storage_ratio = faiss_storage_size / leann_storage_size
            print(f"  Storage Size: {storage_ratio:.1f}x smaller (LEANN uses less storage)")
        else:
            print("  Storage Size: similar")
    else:
        if "error" not in leann_results:
            print("\n✅ LEANN HNSW completed successfully!")
            print(f"📊 Search Memory: {leann_results['peak_memory']:.1f} MB")
            print(f"📊 Storage Size: {leann_storage_size:.1f} MB")
        if "error" not in faiss_results:
            print("\n✅ Faiss HNSW completed successfully!")
            print(f"📊 Search Memory: {faiss_results['peak_memory']:.1f} MB")
            print(f"📊 Storage Size: {faiss_storage_size:.1f} MB")


if __name__ == "__main__":
    main()
151
benchmarks/faiss_only.py
Normal file
@@ -0,0 +1,151 @@
#!/usr/bin/env python3
"""Test only Faiss HNSW"""

import os
import sys
import time

import psutil


def get_memory_usage():
    process = psutil.Process()
    return process.memory_info().rss / 1024 / 1024


class MemoryTracker:
    def __init__(self, name: str):
        self.name = name
        self.start_mem = get_memory_usage()
        self.stages = []

    def checkpoint(self, stage: str):
        current_mem = get_memory_usage()
        diff = current_mem - self.start_mem
        print(f"[{self.name} - {stage}] Memory: {current_mem:.1f} MB (+{diff:.1f} MB)")
        self.stages.append((stage, current_mem))
        return current_mem

    def summary(self):
        peak_mem = max(mem for _, mem in self.stages)
        print(f"Peak Memory: {peak_mem:.1f} MB")
        return peak_mem


def main():
    try:
        import faiss
    except ImportError:
        print("Faiss is not installed.")
        print(
            "Please install it with `uv pip install faiss-cpu` and run this script again"
        )
        sys.exit(1)

    from llama_index.core import (
        Settings,
        SimpleDirectoryReader,
        StorageContext,
        VectorStoreIndex,
    )
    from llama_index.core.node_parser import SentenceSplitter
    from llama_index.embeddings.huggingface import HuggingFaceEmbedding
    from llama_index.vector_stores.faiss import FaissVectorStore

    tracker = MemoryTracker("Faiss HNSW")
    tracker.checkpoint("Initial")

    embed_model = HuggingFaceEmbedding(model_name="facebook/contriever")
    Settings.embed_model = embed_model
    tracker.checkpoint("After embedding model setup")

    d = 768  # Contriever embedding dimension
    faiss_index = faiss.IndexHNSWFlat(d, 32)
    faiss_index.hnsw.efConstruction = 64
    tracker.checkpoint("After Faiss index creation")

    documents = SimpleDirectoryReader(
        "data",
        recursive=True,
        encoding="utf-8",
        required_exts=[".pdf", ".txt", ".md"],
    ).load_data()
    tracker.checkpoint("After document loading")

    # Parse into chunks using the same splitter as LEANN
    node_parser = SentenceSplitter(
        chunk_size=256, chunk_overlap=20, separator=" ", paragraph_separator="\n\n"
    )

    tracker.checkpoint("After text splitter setup")

    # Check if index already exists and try to load it
    index_loaded = False
    if os.path.exists("./storage_faiss"):
        print("Loading existing Faiss HNSW index...")
        try:
            # Use the correct Faiss loading pattern from the example
            vector_store = FaissVectorStore.from_persist_dir("./storage_faiss")
            storage_context = StorageContext.from_defaults(
                vector_store=vector_store, persist_dir="./storage_faiss"
            )
            from llama_index.core import load_index_from_storage

            index = load_index_from_storage(storage_context=storage_context)
            print("Index loaded from ./storage_faiss")
            tracker.checkpoint("After loading existing index")
            index_loaded = True
        except Exception as e:
            print(f"Failed to load existing index: {e}")
            print("Cleaning up corrupted index and building new one...")
            # Clean up corrupted index
            import shutil

            if os.path.exists("./storage_faiss"):
                shutil.rmtree("./storage_faiss")

    if not index_loaded:
        print("Building new Faiss HNSW index...")

        # Use the correct Faiss building pattern from the example
        vector_store = FaissVectorStore(faiss_index=faiss_index)
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        index = VectorStoreIndex.from_documents(
            documents, storage_context=storage_context, transformations=[node_parser]
        )
        tracker.checkpoint("After index building")

        # Save index to disk using the correct pattern
        index.storage_context.persist(persist_dir="./storage_faiss")
        tracker.checkpoint("After index saving")

    # Measure runtime memory overhead
    print("\nMeasuring runtime memory overhead...")
    runtime_start_mem = get_memory_usage()
    print(f"Before load memory: {runtime_start_mem:.1f} MB")
    tracker.checkpoint("Before load memory")

    query_engine = index.as_query_engine(similarity_top_k=20)
    queries = [
        "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
        "What is LEANN and how does it work?",
        "华为诺亚方舟实验室的主要研究内容",
    ]

    for i, query in enumerate(queries):
        start_time = time.time()
        _ = query_engine.query(query)
        query_time = time.time() - start_time
        print(f"Query {i + 1} time: {query_time:.3f}s")
        tracker.checkpoint(f"After query {i + 1}")

    runtime_end_mem = get_memory_usage()
    runtime_overhead = runtime_end_mem - runtime_start_mem

    peak_memory = tracker.summary()
    print(f"Peak Memory: {peak_memory:.1f} MB")
    print(f"Runtime Memory Overhead: {runtime_overhead:.1f} MB")


if __name__ == "__main__":
    main()
659
benchmarks/micro_tpt.py
Normal file
@@ -0,0 +1,659 @@
# Fastest configuration: python micro_tpt.py --use_int8

import argparse
import time
from contextlib import contextmanager
from dataclasses import dataclass

import numpy as np
import torch
from torch import nn
from tqdm import tqdm
from transformers import AutoModel, BitsAndBytesConfig


@dataclass
class BenchmarkConfig:
    model_path: str
    batch_sizes: list[int]
    seq_length: int
    num_runs: int
    use_fp16: bool = True
    use_int4: bool = False
    use_int8: bool = False
    use_cuda_graphs: bool = False
    use_flash_attention: bool = False
    use_linear8bitlt: bool = False


class GraphContainer:
    """Container for managing graphs for different batch sizes (CUDA graphs on NVIDIA, regular execution on others)."""

    def __init__(self, model: nn.Module, seq_length: int):
        self.model = model
        self.seq_length = seq_length
        self.graphs: dict[int, GraphWrapper] = {}

    def get_or_create(self, batch_size: int) -> "GraphWrapper":
        if batch_size not in self.graphs:
            self.graphs[batch_size] = GraphWrapper(self.model, batch_size, self.seq_length)
        return self.graphs[batch_size]


class GraphWrapper:
    """Wrapper for graph capture and replay (CUDA graphs on NVIDIA, regular execution on others)."""

    def __init__(self, model: nn.Module, batch_size: int, seq_length: int):
        self.model = model
        self.device = self._get_device()
        self.static_input = self._create_random_batch(batch_size, seq_length)
        self.static_attention_mask = torch.ones_like(self.static_input)

        # Warm up
        self._warmup()
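
        # CUDA graphs replay a recorded kernel sequence against fixed memory
        # addresses, so inputs must live in pre-allocated "static" tensors;
        # __call__ below copies fresh data into them before each replay.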
        # Only use CUDA graphs on NVIDIA GPUs
        if torch.cuda.is_available() and hasattr(torch.cuda, "CUDAGraph"):
            # Capture graph
            self.graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(self.graph):
                self.static_output = self.model(
                    input_ids=self.static_input,
                    attention_mask=self.static_attention_mask,
                )
            self.use_cuda_graph = True
        else:
            # For MPS or CPU, just store the model
            self.use_cuda_graph = False
            self.static_output = None

    def _get_device(self) -> str:
        if torch.cuda.is_available():
            return "cuda"
        elif torch.backends.mps.is_available():
            return "mps"
        else:
            return "cpu"

    def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
        return torch.randint(
            0, 1000, (batch_size, seq_length), device=self.device, dtype=torch.long
        )

    def _warmup(self, num_warmup: int = 3):
        with torch.no_grad():
            for _ in range(num_warmup):
                self.model(
                    input_ids=self.static_input,
                    attention_mask=self.static_attention_mask,
                )

    def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        if self.use_cuda_graph:
            self.static_input.copy_(input_ids)
            self.static_attention_mask.copy_(attention_mask)
            self.graph.replay()
            return self.static_output
        else:
            # For MPS/CPU, just run normally
            return self.model(input_ids=input_ids, attention_mask=attention_mask)


class ModelOptimizer:
    """Applies various optimizations to the model."""

    @staticmethod
    def optimize(model: nn.Module, config: BenchmarkConfig) -> nn.Module:
        print("\nApplying model optimizations:")

        if model is None:
            raise ValueError("Cannot optimize None model")

        # Move to GPU
        if torch.cuda.is_available():
            model = model.cuda()
            device = "cuda"
        elif torch.backends.mps.is_available():
            model = model.to("mps")
            device = "mps"
        else:
            model = model.cpu()
            device = "cpu"
        print(f"- Model moved to {device}")

        # FP16
        if config.use_fp16 and not config.use_int4:
            model = model.half()
            # Use torch.compile
            model = torch.compile(model)
            print("- Using FP16 precision")

        # Check if SDPA is available (only on CUDA)
        if (
            torch.cuda.is_available()
            and torch.version.cuda
            and float(torch.version.cuda[:3]) >= 11.6
        ):
            if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
                print("- Using PyTorch SDPA (scaled_dot_product_attention)")
            else:
                print("- PyTorch SDPA not available")

        # Flash Attention (only on CUDA)
        if config.use_flash_attention and torch.cuda.is_available():
            try:
                from flash_attn.flash_attention import FlashAttention  # noqa: F401

                print("- Flash Attention 2 available")
                if hasattr(model.config, "attention_mode"):
                    model.config.attention_mode = "flash_attention_2"
                    print("  - Enabled Flash Attention 2 mode")
            except ImportError:
                print("- Flash Attention not available")

        # Memory efficient attention (only on CUDA)
        if torch.cuda.is_available():
            try:
                from xformers.ops import memory_efficient_attention  # noqa: F401

                if hasattr(model, "enable_xformers_memory_efficient_attention"):
                    model.enable_xformers_memory_efficient_attention()
                    print("- Enabled xformers memory efficient attention")
                else:
                    print("- Model doesn't support xformers")
            except (ImportError, AttributeError):
                print("- Xformers not available")

        model.eval()
        print("- Model set to eval mode")

        return model


class Timer:
    """Handles accurate timing using GPU events on CUDA, or CPU timing elsewhere."""

    def __init__(self):
        if torch.cuda.is_available():
            self.start_event = torch.cuda.Event(enable_timing=True)
            self.end_event = torch.cuda.Event(enable_timing=True)
            self.use_gpu_timing = True
        elif torch.backends.mps.is_available():
            # MPS doesn't have events, use CPU timing
            self.use_gpu_timing = False
        else:
            # CPU timing
            self.use_gpu_timing = False

    @contextmanager
    def timing(self):
        if self.use_gpu_timing:
            self.start_event.record()
            yield
            self.end_event.record()
            self.end_event.synchronize()
        else:
            # Use CPU timing for MPS/CPU
            start_time = time.time()
            yield
            self.cpu_elapsed = time.time() - start_time

    def elapsed_time(self) -> float:
        if self.use_gpu_timing:
            return self.start_event.elapsed_time(self.end_event) / 1000  # ms to seconds
        else:
            return self.cpu_elapsed


class Benchmark:
    """Main benchmark runner."""

    def __init__(self, config: BenchmarkConfig):
        self.config = config
        try:
            self.model = self._load_model()
            if self.model is None:
                raise ValueError("Model initialization failed - model is None")

            # Only use CUDA graphs on NVIDIA GPUs
            if config.use_cuda_graphs and torch.cuda.is_available():
                self.graphs = GraphContainer(self.model, config.seq_length)
            else:
                self.graphs = None
            self.timer = Timer()
        except Exception as e:
            print(f"ERROR in benchmark initialization: {e!s}")
            raise

    def _load_model(self) -> nn.Module:
        print(f"Loading model from {self.config.model_path}...")

        try:
            # Int4 quantization using HuggingFace integration
            if self.config.use_int4:
                import bitsandbytes as bnb

                print(f"- bitsandbytes version: {bnb.__version__}")

                # Check if using custom 8bit quantization
                if hasattr(self.config, "use_linear8bitlt") and self.config.use_linear8bitlt:
                    print("- Using custom Linear8bitLt replacement for all linear layers")

                    # Load original model (without quantization config)
                    # Set the default dtype to half
                    torch.set_default_dtype(torch.float16)
                    compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
                    model = AutoModel.from_pretrained(
                        self.config.model_path,
                        torch_dtype=compute_dtype,
                    )

                    # Define replacement function
                    def replace_linear_with_linear8bitlt(model):
                        """Recursively replace all nn.Linear layers with Linear8bitLt"""
                        for name, module in list(model.named_children()):
                            if isinstance(module, nn.Linear):
                                # Get original linear layer parameters
                                in_features = module.in_features
                                out_features = module.out_features
                                bias = module.bias is not None

                                # Create 8bit linear layer and log its size
                                print(f"in_features: {in_features}, out_features: {out_features}")
                                new_module = bnb.nn.Linear8bitLt(
                                    in_features,
                                    out_features,
                                    bias=bias,
                                    has_fp16_weights=False,
                                )

                                # Copy weights and bias
                                new_module.weight.data = module.weight.data
                                if bias:
                                    new_module.bias.data = module.bias.data

                                # Replace module
                                setattr(model, name, new_module)
                            else:
                                # Process child modules recursively
                                replace_linear_with_linear8bitlt(module)

                        return model

                    # Replace all linear layers
                    model = replace_linear_with_linear8bitlt(model)
                    # Add torch.compile
                    model = torch.compile(model)

                    # Move model to GPU (quantization happens here)
                    device = (
                        "cuda"
                        if torch.cuda.is_available()
                        else "mps"
                        if torch.backends.mps.is_available()
                        else "cpu"
                    )
                    model = model.to(device)

                    print("- All linear layers replaced with Linear8bitLt")

                else:
                    # Use original Int4 quantization method
                    print("- Using bitsandbytes for Int4 quantization")

                    # Create quantization config
                    compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
                    quantization_config = BitsAndBytesConfig(
                        load_in_4bit=True,
                        bnb_4bit_compute_dtype=compute_dtype,
                        bnb_4bit_use_double_quant=True,
                        bnb_4bit_quant_type="nf4",
                    )

                    print("- Quantization config:", quantization_config)

                    # Load model directly with quantization config
                    model = AutoModel.from_pretrained(
                        self.config.model_path,
                        quantization_config=quantization_config,
                        torch_dtype=compute_dtype,
                        device_map="auto",  # Let HF decide on device mapping
                    )

                # Check if model loaded successfully
                if model is None:
                    raise ValueError("Model loading returned None")

                print(f"- Model type: {type(model)}")

                # Apply optimizations directly here
                print("\nApplying model optimizations:")

                if hasattr(self.config, "use_linear8bitlt") and self.config.use_linear8bitlt:
                    print("- Model moved to GPU with Linear8bitLt quantization")
                else:
                    # Skip moving to GPU since device_map="auto" already did that
                    print("- Model already on GPU due to device_map='auto'")

                # Skip FP16 conversion since we specified compute_dtype
                print(f"- Using {compute_dtype} for compute dtype")

                # Check CUDA and SDPA
                if (
                    torch.cuda.is_available()
                    and torch.version.cuda
                    and float(torch.version.cuda[:3]) >= 11.6
                ):
                    if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
                        print("- Using PyTorch SDPA (scaled_dot_product_attention)")
                    else:
                        print("- PyTorch SDPA not available")

                # Try xformers if available (only on CUDA)
                if torch.cuda.is_available():
                    try:
                        if hasattr(model, "enable_xformers_memory_efficient_attention"):
                            model.enable_xformers_memory_efficient_attention()
                            print("- Enabled xformers memory efficient attention")
                        else:
                            print("- Model doesn't support xformers")
                    except (ImportError, AttributeError):
                        print("- Xformers not available")

                # Set to eval mode
                model.eval()
                print("- Model set to eval mode")
            # Int8 quantization using HuggingFace integration
            elif self.config.use_int8:
                print("- Using INT8 quantization")
                # For now, just use standard loading with an INT8 config
                compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
                quantization_config = BitsAndBytesConfig(
                    load_in_8bit=True,
                    llm_int8_threshold=6.0,
                    llm_int8_has_fp16_weight=False,
                )

                model = AutoModel.from_pretrained(
                    self.config.model_path,
                    quantization_config=quantization_config,
                    torch_dtype=compute_dtype,
                    device_map="auto",
                )

                if model is None:
                    raise ValueError("Model loading returned None")

                print(f"- Model type: {type(model)}")
                model.eval()
                print("- Model set to eval mode")

            else:
                # Standard loading for FP16/FP32
                model = AutoModel.from_pretrained(self.config.model_path)
                print("- Model loaded in standard precision")
                print(f"- Model type: {type(model)}")

                # Apply standard optimizations
                # Set the default dtype to bfloat16
                torch.set_default_dtype(torch.bfloat16)
                model = ModelOptimizer.optimize(model, self.config)
                model = model.half()
                # Add torch.compile
                model = torch.compile(model)

            # Final check to ensure model is not None
            if model is None:
                raise ValueError("Model is None after optimization")

            print(f"- Final model type: {type(model)}")
            return model

        except Exception as e:
            print(f"ERROR loading model: {e!s}")
            import traceback

            traceback.print_exc()
            raise

    def _create_random_batch(self, batch_size: int) -> torch.Tensor:
        device = (
            "cuda"
            if torch.cuda.is_available()
            else "mps"
            if torch.backends.mps.is_available()
            else "cpu"
        )
        return torch.randint(
            0,
            1000,
            (batch_size, self.config.seq_length),
            device=device,
            dtype=torch.long,
        )

    def _run_inference(
        self, input_ids: torch.Tensor, graph_wrapper: GraphWrapper | None = None
    ) -> tuple[float, torch.Tensor]:
        attention_mask = torch.ones_like(input_ids)

        with torch.no_grad(), self.timer.timing():
            if graph_wrapper is not None:
                output = graph_wrapper(input_ids, attention_mask)
            else:
                output = self.model(input_ids=input_ids, attention_mask=attention_mask)

        return self.timer.elapsed_time(), output

    def run(self) -> dict[int, dict[str, float]]:
        results = {}

        # Reset peak memory stats
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
        elif torch.backends.mps.is_available():
            # MPS doesn't have reset_peak_memory_stats, skip it
            pass
        else:
            print("- No GPU memory stats available")

        for batch_size in self.config.batch_sizes:
            print(f"\nTesting batch size: {batch_size}")
            times = []

            # Get or create graph for this batch size
            graph_wrapper = (
                self.graphs.get_or_create(batch_size) if self.graphs is not None else None
            )

            # Pre-allocate input tensor
            input_ids = self._create_random_batch(batch_size)
            print(f"Input shape: {input_ids.shape}")

            # Run benchmark
            for i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
                try:
                    elapsed_time, output = self._run_inference(input_ids, graph_wrapper)
                    if i == 0:  # Only print on first run
                        print(f"Output shape: {output.last_hidden_state.shape}")
                    times.append(elapsed_time)
                except Exception as e:
                    print(f"Error during inference: {e}")
                    break

            if not times:
                print(f"No successful runs for batch size {batch_size}, skipping")
                continue

            # Calculate statistics
            avg_time = np.mean(times)
            std_time = np.std(times)
            throughput = batch_size / avg_time

            results[batch_size] = {
                "avg_time": avg_time,
                "std_time": std_time,
                "throughput": throughput,
            }

            print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
            print(f"Throughput: {throughput:.2f} sequences/second")

        # Log memory usage
        if torch.cuda.is_available():
            peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
        elif torch.backends.mps.is_available():
            # MPS doesn't have max_memory_allocated, use 0
            peak_memory_gb = 0.0
        else:
            peak_memory_gb = 0.0
            print("- No GPU memory usage available")

        if peak_memory_gb > 0:
            print(f"\nPeak GPU memory usage: {peak_memory_gb:.2f} GB")
        else:
            print("\n- GPU memory usage not available")

        # Add memory info to results
        for batch_size in results:
            results[batch_size]["peak_memory_gb"] = peak_memory_gb

        return results


def main():
    parser = argparse.ArgumentParser(description="Model Inference Benchmark")
    parser.add_argument(
        "--model_path",
        type=str,
        default="facebook/contriever",
        help="Path to the model",
    )
    parser.add_argument(
        "--batch_sizes",
        type=str,
        default="1,2,4,8,16,32",
        help="Comma-separated list of batch sizes",
    )
    parser.add_argument(
        "--seq_length",
        type=int,
        default=256,
        help="Sequence length for input",
    )
    parser.add_argument(
        "--num_runs",
        type=int,
        default=5,
        help="Number of runs for each batch size",
    )
    parser.add_argument(
        "--use_fp16",
        action="store_true",
        help="Enable FP16 inference",
    )
    parser.add_argument(
        "--use_int4",
        action="store_true",
        help="Enable INT4 quantization using bitsandbytes",
    )
    parser.add_argument(
        "--use_int8",
        action="store_true",
        help="Enable INT8 quantization for both activations and weights using bitsandbytes",
    )
    parser.add_argument(
        "--use_cuda_graphs",
        action="store_true",
        help="Enable CUDA Graphs optimization (only on NVIDIA GPUs)",
    )
    parser.add_argument(
        "--use_flash_attention",
        action="store_true",
        help="Enable Flash Attention 2 if available (only on NVIDIA GPUs)",
    )
    parser.add_argument(
        "--use_linear8bitlt",
        action="store_true",
        help="Enable Linear8bitLt quantization for all linear layers",
    )

    args = parser.parse_args()

    # Print arguments for debugging
    print("\nCommand line arguments:")
    for arg, value in vars(args).items():
        print(f"- {arg}: {value}")

    config = BenchmarkConfig(
        model_path=args.model_path,
        batch_sizes=[int(bs) for bs in args.batch_sizes.split(",")],
        seq_length=args.seq_length,
        num_runs=args.num_runs,
        use_fp16=args.use_fp16,
        use_int4=args.use_int4,
        use_int8=args.use_int8,
        use_cuda_graphs=args.use_cuda_graphs,
        use_flash_attention=args.use_flash_attention,
        use_linear8bitlt=args.use_linear8bitlt,
    )

    # Print configuration for debugging
    print("\nBenchmark configuration:")
    for field, value in vars(config).items():
        print(f"- {field}: {value}")

    try:
        benchmark = Benchmark(config)
        results = benchmark.run()

        # Save results to file
        import json
        import os

        # Create results directory if it doesn't exist
        os.makedirs("results", exist_ok=True)

        # Generate filename based on configuration
        precision_type = (
            "int4"
            if config.use_int4
            else "int8"
            if config.use_int8
            else "fp16"
            if config.use_fp16
            else "fp32"
        )
        model_name = os.path.basename(config.model_path)
        output_file = f"results/benchmark_{model_name}_{precision_type}.json"

        # Save results
        with open(output_file, "w") as f:
            json.dump(
                {
                    "config": {
                        k: str(v) if isinstance(v, list) else v for k, v in vars(config).items()
                    },
                    "results": {str(k): v for k, v in results.items()},
                },
                f,
                indent=2,
            )
        print(f"Results saved to {output_file}")

    except Exception as e:
        print(f"Benchmark failed: {e}")
        import traceback

        traceback.print_exc()


if __name__ == "__main__":
    main()
359
benchmarks/run_evaluation.py
Normal file
@@ -0,0 +1,359 @@
#!/usr/bin/env python3
"""
This script runs a recall evaluation on a given LEANN index.
It compares results by fetching the text content for both the new search results
and the golden-standard results, making the comparison robust to ID changes.
"""
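
# Example invocations (flags are defined in main() below):
#   python benchmarks/run_evaluation.py --num-queries 10 --top-k 3 --ef-search 120
#   python benchmarks/run_evaluation.py --mode build --backend hnsw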
import argparse
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from leann.api import LeannBuilder, LeannSearcher
|
||||
|
||||
|
||||
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
||||
"""Checks if the data directory exists, and if not, downloads it from HF Hub."""
|
||||
if not data_root.exists():
|
||||
print(f"Data directory '{data_root}' not found.")
|
||||
print("Downloading evaluation data from Hugging Face Hub... (this may take a moment)")
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
if download_embeddings:
|
||||
# Download everything including embeddings (large files)
|
||||
snapshot_download(
|
||||
repo_id="LEANN-RAG/leann-rag-evaluation-data",
|
||||
repo_type="dataset",
|
||||
local_dir=data_root,
|
||||
local_dir_use_symlinks=False,
|
||||
)
|
||||
print("Data download complete (including embeddings)!")
|
||||
else:
|
||||
# Download only specific folders, excluding embeddings
|
||||
allow_patterns = [
|
||||
"ground_truth/**",
|
||||
"indices/**",
|
||||
"queries/**",
|
||||
"*.md",
|
||||
"*.txt",
|
||||
]
|
||||
snapshot_download(
|
||||
repo_id="LEANN-RAG/leann-rag-evaluation-data",
|
||||
repo_type="dataset",
|
||||
local_dir=data_root,
|
||||
local_dir_use_symlinks=False,
|
||||
allow_patterns=allow_patterns,
|
||||
)
|
||||
print("Data download complete (excluding embeddings)!")
|
||||
except ImportError:
|
||||
print(
|
||||
"Error: huggingface_hub is not installed. Please install it to download the data:"
|
||||
)
|
||||
print("uv pip install -e '.[dev]'")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"An error occurred during data download: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def download_embeddings_if_needed(data_root: Path, dataset_type: str | None = None):
|
||||
"""Download embeddings files specifically."""
|
||||
embeddings_dir = data_root / "embeddings"
|
||||
|
||||
if dataset_type:
|
||||
# Check if specific dataset embeddings exist
|
||||
target_file = embeddings_dir / dataset_type / "passages_00.pkl"
|
||||
if target_file.exists():
|
||||
print(f"Embeddings for {dataset_type} already exist")
|
||||
return str(target_file)
|
||||
|
||||
print("Downloading embeddings from HuggingFace Hub...")
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# Download only embeddings folder
|
||||
snapshot_download(
|
||||
repo_id="LEANN-RAG/leann-rag-evaluation-data",
|
||||
repo_type="dataset",
|
||||
local_dir=data_root,
|
||||
local_dir_use_symlinks=False,
|
||||
allow_patterns=["embeddings/**/*.pkl"],
|
||||
)
|
||||
print("Embeddings download complete!")
|
||||
|
||||
if dataset_type:
|
||||
target_file = embeddings_dir / dataset_type / "passages_00.pkl"
|
||||
if target_file.exists():
|
||||
return str(target_file)
|
||||
|
||||
return str(embeddings_dir)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error downloading embeddings: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
# --- Helper Function to get Golden Passages ---
|
||||
def get_golden_texts(searcher: LeannSearcher, golden_ids: list[int]) -> set:
|
||||
"""
|
||||
Retrieves the text for golden passage IDs directly from the LeannSearcher's
|
||||
passage manager.
|
||||
"""
|
||||
golden_texts = set()
|
||||
for gid in golden_ids:
|
||||
try:
|
||||
# PassageManager uses string IDs
|
||||
passage_data = searcher.passage_manager.get_passage(str(gid))
|
||||
golden_texts.add(passage_data["text"])
|
||||
except KeyError:
|
||||
print(f"Warning: Golden passage ID '{gid}' not found in the index's passage data.")
|
||||
return golden_texts
|
||||
|
||||
|
||||
def load_queries(file_path: Path) -> list[str]:
|
||||
queries = []
|
||||
with open(file_path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
data = json.loads(line)
|
||||
queries.append(data["query"])
|
||||
return queries
|
||||
|
||||
|
||||
def build_index_from_embeddings(embeddings_file: str, output_path: str, backend: str = "hnsw"):
|
||||
"""
|
||||
Build a LEANN index from pre-computed embeddings.
|
||||
|
||||
Args:
|
||||
embeddings_file: Path to pickle file with (ids, embeddings) tuple
|
||||
output_path: Path where to save the index
|
||||
backend: Backend to use ("hnsw" or "diskann")
|
||||
"""
|
||||
print(f"Building {backend} index from embeddings: {embeddings_file}")
|
||||
|
||||
# Create builder with appropriate parameters
|
||||
if backend == "hnsw":
|
||||
builder_kwargs = {
|
||||
"M": 32, # Graph degree
|
||||
"efConstruction": 256, # Construction complexity
|
||||
"is_compact": True, # Use compact storage
|
||||
"is_recompute": True, # Enable pruning for better recall
|
||||
}
|
||||
elif backend == "diskann":
|
||||
builder_kwargs = {
|
||||
"complexity": 64,
|
||||
"graph_degree": 32,
|
||||
"search_memory_maximum": 8.0, # GB
|
||||
"build_memory_maximum": 16.0, # GB
|
||||
}
|
||||
else:
|
||||
builder_kwargs = {}
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name=backend,
|
||||
embedding_model="facebook/contriever-msmarco", # Model used to create embeddings
|
||||
dimensions=768, # Will be auto-detected from embeddings
|
||||
**builder_kwargs,
|
||||
)
|
||||
|
||||
# Build index from precomputed embeddings
|
||||
builder.build_index_from_embeddings(output_path, embeddings_file)
|
||||
print(f"Index saved to: {output_path}")
|
||||
return output_path
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Run recall evaluation on a LEANN index.")
|
||||
parser.add_argument(
|
||||
"index_path",
|
||||
type=str,
|
||||
nargs="?",
|
||||
help="Path to the LEANN index to evaluate or build (optional).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
choices=["evaluate", "build"],
|
||||
default="evaluate",
|
||||
help="Mode: 'evaluate' existing index or 'build' from embeddings",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embeddings-file",
|
||||
type=str,
|
||||
help="Path to embeddings pickle file (optional for build mode)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
choices=["hnsw", "diskann"],
|
||||
default="hnsw",
|
||||
help="Backend to use for building index (default: hnsw)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-queries", type=int, default=10, help="Number of queries to evaluate."
|
||||
)
|
||||
parser.add_argument("--top-k", type=int, default=3, help="The 'k' value for recall@k.")
|
||||
parser.add_argument(
|
||||
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
|
||||
)
|
||||
args = parser.parse_args()

    # --- Path Configuration ---
    # Assumes a project structure where the script is in 'benchmarks/'
    # and evaluation data is in 'benchmarks/data/'.
    script_dir = Path(__file__).resolve().parent
    data_root = script_dir / "data"

    # Download data based on mode
    if args.mode == "build":
        # Build mode also needs source embeddings; fetch the basic data first,
        # then download the embeddings for the detected dataset below.
        download_data_if_needed(data_root, download_embeddings=False)

        # Auto-detect dataset type and download embeddings
        if args.embeddings_file:
            embeddings_file = args.embeddings_file
            # Try to detect the dataset type from the embeddings file path
            if "rpj_wiki" in str(embeddings_file):
                dataset_type = "rpj_wiki"
            elif "dpr" in str(embeddings_file):
                dataset_type = "dpr"
            else:
                dataset_type = "dpr"  # Default
        else:
            # Auto-detect from the index path if provided, otherwise default to DPR
            if args.index_path:
                index_path_str = str(args.index_path)
                if "rpj_wiki" in index_path_str:
                    dataset_type = "rpj_wiki"
                elif "dpr" in index_path_str:
                    dataset_type = "dpr"
                else:
                    dataset_type = "dpr"  # Default to DPR
            else:
                dataset_type = "dpr"  # Default to DPR

            embeddings_file = download_embeddings_if_needed(data_root, dataset_type)

        # Auto-generate an index path if not provided
        if not args.index_path:
            indices_dir = data_root / "indices" / dataset_type
            indices_dir.mkdir(parents=True, exist_ok=True)
            args.index_path = str(indices_dir / f"{dataset_type}_from_embeddings")
            print(f"Auto-generated index path: {args.index_path}")

        print(f"Building index from embeddings: {embeddings_file}")
        built_index_path = build_index_from_embeddings(
            embeddings_file, args.index_path, args.backend
        )
        print(f"Index built successfully: {built_index_path}")

        # Ask whether the user wants to evaluate the freshly built index
        eval_response = input("Run evaluation on the built index? (y/n): ").strip().lower()
        if eval_response != "y":
            print("Index building complete. Exiting.")
            return
    else:
        # Evaluation mode does not need embeddings
        download_data_if_needed(data_root, download_embeddings=False)

        # Auto-detect the index path if not provided
        if not args.index_path:
            # Default to using downloaded indices
            indices_dir = data_root / "indices"

            # Try common datasets in order of preference
            for dataset in ["dpr", "rpj_wiki"]:
                dataset_dir = indices_dir / dataset
                if dataset_dir.exists():
                    # Look for index files
                    index_files = list(dataset_dir.glob("*.index")) + list(
                        dataset_dir.glob("*_disk.index")
                    )
                    if index_files:
                        args.index_path = str(
                            index_files[0].with_suffix("")
                        )  # Remove the .index extension
                        print(f"Using index: {args.index_path}")
                        break

        if not args.index_path:
            print("No indices found. The data download should have included pre-built indices.")
            print(
                "Please check the benchmarks/data/indices/ directory or provide --index-path manually."
            )
            sys.exit(1)

    # Detect the dataset type from the index path to select the correct ground truth
    index_path_str = str(args.index_path)
    if "rpj_wiki" in index_path_str:
        dataset_type = "rpj_wiki"
    elif "dpr" in index_path_str:
        dataset_type = "dpr"
    else:
        # Fallback: try to infer from the index directory name
        dataset_type = Path(args.index_path).name
        print(f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'.")

    queries_file = data_root / "queries" / "nq_open.jsonl"
    golden_results_file = data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
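
    # Expected on-disk layout under benchmarks/data/ (as used above):
    #   queries/nq_open.jsonl                               - evaluation queries
    #   ground_truth/<dataset_type>/flat_results_nq_k3.json - golden top-k indices
    #   indices/<dataset_type>/                             - pre-built or generated indices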

    print(f"INFO: Detected dataset type: {dataset_type}")
    print(f"INFO: Using queries file: {queries_file}")
    print(f"INFO: Using ground truth file: {golden_results_file}")

    try:
        searcher = LeannSearcher(args.index_path)
        queries = load_queries(queries_file)

        with open(golden_results_file) as f:
            golden_results_data = json.load(f)

        num_eval_queries = min(args.num_queries, len(queries))
        queries = queries[:num_eval_queries]

        print(f"\nRunning evaluation on {num_eval_queries} queries...")
        recall_scores = []
        search_times = []

        for i in range(num_eval_queries):
            start_time = time.time()
            new_results = searcher.search(queries[i], top_k=args.top_k, ef=args.ef_search)
            search_times.append(time.time() - start_time)

            # Recall is computed on TEXT content rather than internal IDs,
            # so results stay comparable across differently built indices.
            new_texts = {result.text for result in new_results}

            # Get golden texts directly from the searcher's passage manager
            golden_ids = golden_results_data["indices"][i][: args.top_k]
            golden_texts = get_golden_texts(searcher, golden_ids)

            overlap = len(new_texts & golden_texts)
            recall = overlap / len(golden_texts) if golden_texts else 0
            recall_scores.append(recall)
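
            # Worked example: with top_k=3, if 2 of the 3 returned texts match
            # the golden set, recall@3 = 2 / 3 ≈ 0.667 for this query; the score
            # reported at the end is the mean over all evaluated queries.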

            print("\n--- EVALUATION RESULTS ---")
            print(f"Query: {queries[i]}")
            print(f"New Results: {new_texts}")
            print(f"Golden Results: {golden_texts}")
            print(f"Overlap: {overlap}")
            print(f"Recall: {recall}")
            print(f"Search Time: {search_times[-1]:.4f}s")
            print("--------------------------------")

        avg_recall = np.mean(recall_scores) if recall_scores else 0
        avg_time = np.mean(search_times) if search_times else 0

        print("\n🎉 --- Evaluation Complete ---")
        print(f"Avg. Recall@{args.top_k} (efSearch={args.ef_search}): {avg_recall:.4f}")
        print(f"Avg. Search Time: {avg_time:.4f}s")

    except Exception as e:
        print(f"\n❌ An error occurred during evaluation: {e}")
        import traceback

        traceback.print_exc()


if __name__ == "__main__":
    main()
311
benchmarks/simple_mac_tpt_test.py
Normal file
@@ -0,0 +1,311 @@
import time
from dataclasses import dataclass

import numpy as np
import torch
from torch import nn
from tqdm import tqdm
from transformers import AutoModel

# Add MLX imports
try:
    import mlx.core as mx
    from mlx_lm.utils import load

    MLX_AVAILABLE = True
except ImportError:
    print("MLX not available. Install with: uv pip install mlx mlx-lm")
    MLX_AVAILABLE = False


@dataclass
class BenchmarkConfig:
    model_path: str = "facebook/contriever"
    batch_sizes: list[int] | None = None
    seq_length: int = 256
    num_runs: int = 5
    use_fp16: bool = True
    use_int4: bool = False
    use_int8: bool = False
    use_cuda_graphs: bool = False
    use_flash_attention: bool = False
    use_linear8bitlt: bool = False
    use_mlx: bool = False  # New flag for MLX testing

    def __post_init__(self):
        if self.batch_sizes is None:
            self.batch_sizes = [1, 2, 4, 8, 16, 32, 64]
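

# A minimal usage sketch (illustrative values; Benchmark and MLXBenchmark are
# the classes defined below):
#   config = BenchmarkConfig(model_path="facebook/contriever", num_runs=5)
#   results = Benchmark(config).run()
#   mlx_results = MLXBenchmark(config).run()  # on Apple Silicon with MLX installed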


class MLXBenchmark:
    """MLX-specific benchmark for embedding models"""

    def __init__(self, config: BenchmarkConfig):
        self.config = config
        self.model, self.tokenizer = self._load_model()

    def _load_model(self):
        """Load the MLX model and tokenizer following the mlx_lm API pattern"""
        print(f"Loading MLX model from {self.config.model_path}...")
        try:
            model, tokenizer = load(self.config.model_path)
            print("MLX model loaded successfully")
            return model, tokenizer
        except Exception as e:
            print(f"Error loading MLX model: {e}")
            raise

    def _create_random_batch(self, batch_size: int):
        """Create random input batches for MLX testing - same shape as PyTorch"""
        return torch.randint(0, 1000, (batch_size, self.config.seq_length), dtype=torch.long)

    def _run_inference(self, input_ids: torch.Tensor) -> float:
        """Run MLX inference with the same input as PyTorch"""
        start_time = time.time()
        try:
            # Convert the PyTorch tensor to an MLX array
            input_ids_mlx = mx.array(input_ids.numpy())

            # Get embeddings
            embeddings = self.model(input_ids_mlx)

            # Mean pooling (following the API pattern)
            pooled = embeddings.mean(axis=1)

            # Convert to numpy (following the API pattern)
            pooled_numpy = np.array(pooled.tolist(), dtype=np.float32)

            # Force computation
            _ = pooled_numpy.shape

        except Exception as e:
            print(f"MLX inference error: {e}")
            return float("inf")
        end_time = time.time()

        return end_time - start_time
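
    # Note: MLX evaluates lazily, so the numpy conversion in _run_inference is
    # what actually forces computation; the measured time therefore covers the
    # torch->MLX conversion, the forward pass, mean pooling, and the copy out.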

    def run(self) -> dict[int, dict[str, float]]:
        """Run the MLX benchmark across all batch sizes"""
        results = {}

        print(f"Starting MLX benchmark with model: {self.config.model_path}")
        print(f"Testing batch sizes: {self.config.batch_sizes}")

        for batch_size in self.config.batch_sizes:
            print(f"\n=== Testing MLX batch size: {batch_size} ===")
            times = []

            # Create the input batch (same as PyTorch)
            input_ids = self._create_random_batch(batch_size)

            # Warm up
            print("Warming up...")
            for _ in range(3):
                try:
                    self._run_inference(input_ids[:2])  # Warm up with a smaller batch
                except Exception as e:
                    print(f"Warmup error: {e}")
                    break

            # Run the benchmark
            for _i in tqdm(range(self.config.num_runs), desc=f"MLX Batch size {batch_size}"):
                try:
                    elapsed_time = self._run_inference(input_ids)
                    if elapsed_time != float("inf"):
                        times.append(elapsed_time)
                except Exception as e:
                    print(f"Error during MLX inference: {e}")
                    break

            if not times:
                print(f"Skipping batch size {batch_size} due to errors")
                continue

            # Calculate statistics
            avg_time = np.mean(times)
            std_time = np.std(times)
            throughput = batch_size / avg_time

            results[batch_size] = {
                "avg_time": avg_time,
                "std_time": std_time,
                "throughput": throughput,
                "min_time": np.min(times),
                "max_time": np.max(times),
            }

            print(f"MLX Results for batch size {batch_size}:")
            print(f"  Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
            print(f"  Min Time: {np.min(times):.4f}s")
            print(f"  Max Time: {np.max(times):.4f}s")
            print(f"  Throughput: {throughput:.2f} sequences/second")

        return results


class Benchmark:
    def __init__(self, config: BenchmarkConfig):
        self.config = config
        self.device = (
            "cuda"
            if torch.cuda.is_available()
            else "mps"
            if torch.backends.mps.is_available()
            else "cpu"
        )
        self.model = self._load_model()

    def _load_model(self) -> nn.Module:
        print(f"Loading model from {self.config.model_path}...")

        model = AutoModel.from_pretrained(self.config.model_path)
        if self.config.use_fp16:
            model = model.half()
        model = torch.compile(model)
        model = model.to(self.device)

        model.eval()
        return model

    def _create_random_batch(self, batch_size: int) -> torch.Tensor:
        return torch.randint(
            0,
            1000,
            (batch_size, self.config.seq_length),
            device=self.device,
            dtype=torch.long,
        )

    def _run_inference(self, input_ids: torch.Tensor) -> float:
        attention_mask = torch.ones_like(input_ids)

        start_time = time.time()
        with torch.no_grad():
            self.model(input_ids=input_ids, attention_mask=attention_mask)
        end_time = time.time()

        return end_time - start_time
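
    # Note: CUDA (and MPS) kernels launch asynchronously, so wall-clock timing
    # around the forward pass can under-report latency; calling
    # torch.cuda.synchronize() before reading time.time() would tighten this.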

    def run(self) -> dict[int, dict[str, float]]:
        results = {}

        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()

        for batch_size in self.config.batch_sizes:
            print(f"\nTesting batch size: {batch_size}")
            times = []

            input_ids = self._create_random_batch(batch_size)

            for _i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
                try:
                    elapsed_time = self._run_inference(input_ids)
                    times.append(elapsed_time)
                except Exception as e:
                    print(f"Error during inference: {e}")
                    break

            if not times:
                continue

            avg_time = np.mean(times)
            std_time = np.std(times)
            throughput = batch_size / avg_time

            results[batch_size] = {
                "avg_time": avg_time,
                "std_time": std_time,
                "throughput": throughput,
            }

            print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
            print(f"Throughput: {throughput:.2f} sequences/second")

        if torch.cuda.is_available():
            peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
        else:
            peak_memory_gb = 0.0

        for batch_size in results:
            results[batch_size]["peak_memory_gb"] = peak_memory_gb

        return results


def run_benchmark():
    """Main function to run the PyTorch benchmark with default parameters."""
    config = BenchmarkConfig()

    try:
        benchmark = Benchmark(config)
        results = benchmark.run()

        max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
        avg_throughput = np.mean([results[batch_size]["throughput"] for batch_size in results])

        return {
            "max_throughput": max_throughput,
            "avg_throughput": avg_throughput,
            "results": results,
        }

    except Exception as e:
        print(f"Benchmark failed: {e}")
        return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": str(e)}


def run_mlx_benchmark():
    """Run the MLX-specific benchmark"""
    if not MLX_AVAILABLE:
        print("MLX not available, skipping MLX benchmark")
        return {
            "max_throughput": 0.0,
            "avg_throughput": 0.0,
            "error": "MLX not available",
        }

    config = BenchmarkConfig(model_path="mlx-community/all-MiniLM-L6-v2-4bit", use_mlx=True)

    try:
        benchmark = MLXBenchmark(config)
        results = benchmark.run()

        if not results:
            return {
                "max_throughput": 0.0,
                "avg_throughput": 0.0,
                "error": "No valid results",
            }

        max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
        avg_throughput = np.mean([results[batch_size]["throughput"] for batch_size in results])

        return {
            "max_throughput": max_throughput,
            "avg_throughput": avg_throughput,
            "results": results,
        }

    except Exception as e:
        print(f"MLX benchmark failed: {e}")
        return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": str(e)}


if __name__ == "__main__":
    print("=== PyTorch Benchmark ===")
    pytorch_result = run_benchmark()
    print(f"PyTorch Max throughput: {pytorch_result['max_throughput']:.2f} sequences/second")
    print(f"PyTorch Average throughput: {pytorch_result['avg_throughput']:.2f} sequences/second")

    print("\n=== MLX Benchmark ===")
    mlx_result = run_mlx_benchmark()
    print(f"MLX Max throughput: {mlx_result['max_throughput']:.2f} sequences/second")
    print(f"MLX Average throughput: {mlx_result['avg_throughput']:.2f} sequences/second")

    # Compare results
    if pytorch_result["max_throughput"] > 0 and mlx_result["max_throughput"] > 0:
        speedup = mlx_result["max_throughput"] / pytorch_result["max_throughput"]
        print("\n=== Comparison ===")
        print(f"MLX is {speedup:.2f}x {'faster' if speedup > 1 else 'slower'} than PyTorch")