clean dict

This commit is contained in:
yichuan520030910320
2025-07-15 22:30:52 -07:00
parent b1c93fe178
commit 4a2cb914d7
8 changed files with 1 additions and 425 deletions

View File

@@ -0,0 +1,120 @@
# 🧪 Leann Sanity Checks
This directory contains comprehensive sanity checks for the Leann system, ensuring all components work correctly across different configurations.
## 📁 Test Files
### `test_distance_functions.py`
Tests all supported distance functions on the DiskANN backend:
- **MIPS** (Maximum Inner Product Search)
- **L2** (Euclidean Distance)
- **Cosine** (Cosine Similarity)
```bash
uv run python tests/sanity_checks/test_distance_functions.py
```
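In outline, the check builds one small index per metric and runs the same queries against each. The sketch below is illustrative only: the `distance_metric` keyword and the `from leann import LeannBuilder` import path are assumptions, not confirmed API; the authoritative parameter names are in the test script itself.
```python
# Illustrative sketch, not the actual test: build one tiny DiskANN index per metric.
from leann import LeannBuilder  # assumed import path

for metric in ["mips", "l2", "cosine"]:
    builder = LeannBuilder(
        backend_name="diskann",
        distance_metric=metric,  # assumed keyword name
        graph_degree=16,
        complexity=32,
    )
    # ... add a handful of documents, build the index, then issue the same query
    # against each index and compare score ranges and rankings.
```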
### `test_l2_verification.py`
Specifically verifies that L2 distance is correctly implemented by:
- Building indices with L2 vs Cosine metrics
- Comparing search results and score ranges
- Validating that different metrics produce expected score patterns
```bash
uv run python tests/sanity_checks/test_l2_verification.py
```
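A quick way to see why the two metrics give different score ranges while describing the same geometry: for unit-length embeddings, squared L2 distance and cosine similarity are tied by ||a − b||² = 2 − 2·cos(a, b), so L2 is a distance in [0, 4] (lower is better) while cosine similarity lies in [−1, 1] (higher is better). A self-contained NumPy illustration, not Leann-specific:
```python
import numpy as np

a = np.random.randn(384); a /= np.linalg.norm(a)
b = np.random.randn(384); b /= np.linalg.norm(b)

cos_sim = float(a @ b)                 # in [-1, 1], higher = more similar
l2_sq = float(np.sum((a - b) ** 2))    # in [0, 4] for unit vectors, lower = more similar

# The identity ||a - b||^2 = 2 - 2*cos(a, b) holds for unit-length vectors.
assert abs(l2_sq - (2 - 2 * cos_sim)) < 1e-6
```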
### `test_sanity_check.py`
Comprehensive end-to-end verification including:
- Distance function testing
- Embedding model compatibility
- Search result correctness validation
- Backend integration testing
```bash
uv run python tests/sanity_checks/test_sanity_check.py
```
## 🎯 What These Tests Verify
### ✅ Distance Function Support
- All three distance metrics (MIPS, L2, Cosine) work correctly
- Score ranges are appropriate for each metric type
- Different metrics can produce different rankings, as expected (see the example below)
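The "different rankings" point is easy to reproduce with plain NumPy: with unnormalized vectors, a long vector can win under inner product (MIPS) while a shorter but better-aligned vector wins under cosine. Illustrative only; it does not call any Leann code.
```python
import numpy as np

query = np.array([1.0, 0.0])
doc_a = np.array([3.0, 3.0])   # large norm, 45 degrees off the query
doc_b = np.array([0.9, 0.0])   # small norm, perfectly aligned with the query

def cos(u, v):
    return u @ v / (np.linalg.norm(u) * np.linalg.norm(v))

# MIPS prefers doc_a (inner products 3.0 vs 0.9),
# cosine prefers doc_b (similarities ~0.707 vs 1.0).
print(query @ doc_a, query @ doc_b)
print(cos(query, doc_a), cos(query, doc_b))
```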
### ✅ Backend Integration
- DiskANN backend properly initializes and builds indices
- Graph construction completes without errors
- Search operations return valid results
### ✅ Embedding Pipeline
- Real-time embedding computation works
- Multiple embedding models are supported
- ZMQ server communication functions correctly
### ✅ End-to-End Functionality
- Index building → searching → result retrieval pipeline (sketched below)
- Metadata preservation through the entire flow
- Error handling and graceful degradation
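The end-to-end check follows the build → search → retrieve shape sketched here. Apart from `LeannBuilder` (which also appears in the troubleshooting snippet further down), every class and method name below is a placeholder for illustration, not the actual Leann API.
```python
# Pseudo-flow of the end-to-end check; names other than LeannBuilder are placeholders.
builder = LeannBuilder(backend_name="diskann")
builder.add_text("Paris is the capital of France.", metadata={"doc_id": 1})  # placeholder method
index_path = builder.build("sanity_index")                                   # placeholder method

searcher = LeannSearcher(index_path)                  # placeholder class
results = searcher.search("capital of France", k=1)   # placeholder method

assert results, "search should return at least one hit"
assert results[0].metadata["doc_id"] == 1, "metadata must survive the round trip"
```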
## 🔍 Expected Output
When all tests pass, you should see:
```
📊 Test results summary:
  mips   : ✅ PASS
  l2     : ✅ PASS
  cosine : ✅ PASS
🎉 Testing complete!
```
## 🐛 Troubleshooting
### Common Issues
**Import Errors**: Ensure you're running from the project root:
```bash
cd /path/to/leann
uv run python tests/sanity_checks/test_distance_functions.py
```
**Memory Issues**: Reduce graph complexity for resource-constrained systems:
```python
builder = LeannBuilder(
    backend_name="diskann",
    graph_degree=8,   # Reduced from 16
    complexity=16     # Reduced from 32
)
```
**ZMQ Port Conflicts**: The tests use different ports to avoid conflicts, but you may need to kill existing processes:
```bash
pkill -f "embedding_server"
```
## 📊 Performance Expectations
### Typical Timing (3 documents, consumer hardware):
- **Index Building**: 2-5 seconds per distance function
- **Search Query**: 50-200ms
- **Recompute Mode**: 5-15 seconds (higher accuracy)
### Memory Usage:
- **Index Storage**: ~1-2 MB per distance function
- **Runtime Memory**: ~500MB (including model loading)
## 🔗 Integration with CI/CD
These tests are designed to be run in automated environments:
```yaml
# GitHub Actions example
- name: Run Sanity Checks
  run: |
    uv run python tests/sanity_checks/test_distance_functions.py
    uv run python tests/sanity_checks/test_l2_verification.py
```
The tests are deterministic and should produce consistent results across different platforms.

View File

@@ -0,0 +1,128 @@
import time
import numpy as np
import matplotlib.pyplot as plt
import torch
from sentence_transformers import SentenceTransformer
import mlx.core as mx
from mlx_lm import load
# --- Configuration ---
MODEL_NAME_TORCH = "Qwen/Qwen3-Embedding-0.6B"
MODEL_NAME_MLX = "mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ"
BATCH_SIZES = [1, 8, 16, 32, 64, 128]
NUM_RUNS = 10 # Number of runs to average for each batch size
WARMUP_RUNS = 2 # Number of warm-up runs
# --- Generate Dummy Data ---
DUMMY_SENTENCES = ["This is a test sentence for benchmarking." * 5] * max(BATCH_SIZES)
# --- Benchmark Functions ---
def benchmark_torch(model, sentences):
    start_time = time.time()
    model.encode(sentences, convert_to_numpy=True)
    end_time = time.time()
    return (end_time - start_time) * 1000  # Return time in ms
def benchmark_mlx(model, tokenizer, sentences):
    start_time = time.time()
    # Tokenize sentences using the MLX tokenizer
    tokens = []
    for sentence in sentences:
        token_ids = tokenizer.encode(sentence)
        tokens.append(token_ids)
    # Pad sequences to the same length
    max_len = max(len(t) for t in tokens)
    input_ids = []
    attention_mask = []
    for token_seq in tokens:
        # Pad sequence
        padded = token_seq + [tokenizer.eos_token_id] * (max_len - len(token_seq))
        input_ids.append(padded)
        # Create attention mask (1 for real tokens, 0 for padding)
        mask = [1] * len(token_seq) + [0] * (max_len - len(token_seq))
        attention_mask.append(mask)
    # Convert to MLX arrays
    input_ids = mx.array(input_ids)
    attention_mask = mx.array(attention_mask)
    # Get embeddings
    embeddings = model(input_ids)
    # Mean pooling
    mask = mx.expand_dims(attention_mask, -1)
    sum_embeddings = (embeddings * mask).sum(axis=1)
    sum_mask = mask.sum(axis=1)
    pooled = sum_embeddings / sum_mask
    mx.eval(pooled)  # Force the lazy computation so the timing includes the actual work
    end_time = time.time()
    return (end_time - start_time) * 1000  # Return time in ms
# --- Main Execution ---
def main():
    print("--- Initializing Models ---")

    # Load PyTorch model
    print(f"Loading PyTorch model: {MODEL_NAME_TORCH}")
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    model_torch = SentenceTransformer(MODEL_NAME_TORCH, device=device)
    print(f"PyTorch model loaded on: {device}")

    # Load MLX model
    print(f"Loading MLX model: {MODEL_NAME_MLX}")
    model_mlx, tokenizer_mlx = load(MODEL_NAME_MLX)
    print("MLX model loaded.")

    # --- Warm-up ---
    print("\n--- Performing Warm-up Runs ---")
    for _ in range(WARMUP_RUNS):
        benchmark_torch(model_torch, DUMMY_SENTENCES[:1])
        benchmark_mlx(model_mlx, tokenizer_mlx, DUMMY_SENTENCES[:1])
    print("Warm-up complete.")

    # --- Benchmarking ---
    print("\n--- Starting Benchmark ---")
    results_torch = []
    results_mlx = []
    for batch_size in BATCH_SIZES:
        print(f"Benchmarking batch size: {batch_size}")
        sentences_batch = DUMMY_SENTENCES[:batch_size]
        # Benchmark PyTorch
        torch_times = [benchmark_torch(model_torch, sentences_batch) for _ in range(NUM_RUNS)]
        results_torch.append(np.mean(torch_times))
        # Benchmark MLX
        mlx_times = [benchmark_mlx(model_mlx, tokenizer_mlx, sentences_batch) for _ in range(NUM_RUNS)]
        results_mlx.append(np.mean(mlx_times))

    print("\n--- Benchmark Results (Average time per batch in ms) ---")
    print(f"Batch Sizes: {BATCH_SIZES}")
    print(f"PyTorch ({device}): {[f'{t:.2f}' for t in results_torch]}")
    print(f"MLX: {[f'{t:.2f}' for t in results_mlx]}")

    # --- Plotting ---
    print("\n--- Generating Plot ---")
    plt.figure(figsize=(10, 6))
    plt.plot(BATCH_SIZES, results_torch, marker='o', linestyle='-', label=f'PyTorch ({device})')
    plt.plot(BATCH_SIZES, results_mlx, marker='s', linestyle='-', label='MLX')
    plt.title(f'Embedding Performance: MLX vs PyTorch\nModel: {MODEL_NAME_TORCH}')
    plt.xlabel("Batch Size")
    plt.ylabel("Average Time per Batch (ms)")
    plt.xticks(BATCH_SIZES)
    plt.grid(True)
    plt.legend()

    # Save the plot
    output_filename = "embedding_benchmark.png"
    plt.savefig(output_filename)
    print(f"Plot saved to {output_filename}")
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,113 @@
#!/usr/bin/env python3
"""
Debug script to test ZMQ communication with the exact same setup as main_cli_example.py
"""
import zmq
import time
import threading
import sys
sys.path.append('packages/leann-backend-diskann')
from leann_backend_diskann import embedding_pb2
def test_zmq_with_same_model():
    print("=== Testing ZMQ with same model as main_cli_example.py ===")

    # Test the exact same model that main_cli_example.py uses
    model_name = "sentence-transformers/all-mpnet-base-v2"

    # Start server with the same model
    import subprocess
    server_cmd = [
        sys.executable, "-m",
        "packages.leann-backend-diskann.leann_backend_diskann.embedding_server",
        "--zmq-port", "5556",  # Use different port to avoid conflicts
        "--model-name", model_name
    ]
    print(f"Starting server with command: {' '.join(server_cmd)}")
    server_process = subprocess.Popen(
        server_cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True
    )

    # Wait for server to start
    print("Waiting for server to start...")
    time.sleep(10)

    # Check if server is running
    if server_process.poll() is not None:
        stdout, stderr = server_process.communicate()
        print(f"Server failed to start. stdout: {stdout}")
        print(f"Server failed to start. stderr: {stderr}")
        return False

    print(f"Server started with PID: {server_process.pid}")

    try:
        # Test client
        context = zmq.Context()
        socket = context.socket(zmq.REQ)
        socket.connect("tcp://127.0.0.1:5556")
        socket.setsockopt(zmq.RCVTIMEO, 30000)  # 30 second timeout like C++
        socket.setsockopt(zmq.SNDTIMEO, 30000)

        # Create request with same format as C++
        request = embedding_pb2.NodeEmbeddingRequest()
        request.node_ids.extend([0, 1, 2, 3, 4])  # Test with some node IDs

        print(f"Sending request with {len(request.node_ids)} node IDs...")
        start_time = time.time()

        # Send request
        socket.send(request.SerializeToString())

        # Receive response
        response_data = socket.recv()
        end_time = time.time()

        print(f"Received response in {end_time - start_time:.3f} seconds")
        print(f"Response size: {len(response_data)} bytes")

        # Parse response
        response = embedding_pb2.NodeEmbeddingResponse()
        response.ParseFromString(response_data)

        print(f"Response dimensions: {list(response.dimensions)}")
        print(f"Embeddings data size: {len(response.embeddings_data)} bytes")
        print(f"Missing IDs: {list(response.missing_ids)}")

        # Calculate expected size
        if len(response.dimensions) == 2:
            batch_size = response.dimensions[0]
            embedding_dim = response.dimensions[1]
            expected_bytes = batch_size * embedding_dim * 4  # 4 bytes per float
            print(f"Expected bytes: {expected_bytes}, Actual: {len(response.embeddings_data)}")
            if len(response.embeddings_data) == expected_bytes:
                print("✅ Response format is correct!")
                return True
            else:
                print("❌ Response format mismatch!")
                return False
        else:
            print("❌ Invalid response dimensions!")
            return False

    except Exception as e:
        print(f"❌ Error during ZMQ test: {e}")
        return False

    finally:
        # Clean up
        server_process.terminate()
        server_process.wait()
        print("Server terminated")
if __name__ == "__main__":
    success = test_zmq_with_same_model()
    if success:
        print("\n✅ ZMQ communication test passed!")
    else:
        print("\n❌ ZMQ communication test failed!")