clean dict
BIN tests/.DS_Store vendored
Binary file not shown.
@@ -1,120 +0,0 @@
# 🧪 Leann Sanity Checks

This directory contains comprehensive sanity checks for the Leann system, ensuring all components work correctly across different configurations.

## 📁 Test Files

### `test_distance_functions.py`
Tests all supported distance functions on the DiskANN backend:
- ✅ **MIPS** (Maximum Inner Product Search)
- ✅ **L2** (Euclidean Distance)
- ✅ **Cosine** (Cosine Similarity)

```bash
uv run python tests/sanity_checks/test_distance_functions.py
```

### `test_l2_verification.py`
Specifically verifies that L2 distance is correctly implemented by:
- Building indices with L2 vs. Cosine metrics
- Comparing search results and score ranges
- Validating that different metrics produce the expected score patterns

```bash
uv run python tests/sanity_checks/test_l2_verification.py
```

### `test_sanity_check.py`
Comprehensive end-to-end verification including:
- Distance function testing
- Embedding model compatibility
- Search result correctness validation
- Backend integration testing

```bash
uv run python tests/sanity_checks/test_sanity_check.py
```

## 🎯 What These Tests Verify

### ✅ Distance Function Support
- All three distance metrics (MIPS, L2, Cosine) work correctly
- Score ranges are appropriate for each metric type
- Different metrics can produce different rankings (as expected; see the sketch below)
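Based on what `test_distance_functions.py` does, a minimal sketch of exercising each metric looks like this (the `LeannBuilder`/`LeannSearcher` keyword arguments are taken from the test scripts; the index paths and corpus here are illustrative assumptions):

```python
from leann.api import LeannBuilder, LeannSearcher

docs = [
    "machine learning artificial intelligence",
    "computer programming software development",
]

for metric in ("mips", "l2", "cosine"):
    index_path = f"./sketch_indices_{metric}/documents.diskann"  # hypothetical location
    builder = LeannBuilder(
        backend_name="diskann",
        distance_metric=metric,  # the metric under test
        graph_degree=16,
        complexity=32,
    )
    for doc in docs:
        builder.add_text(doc)
    builder.build_index(index_path)

    searcher = LeannSearcher(index_path, distance_metric=metric)
    results = searcher.search("machine learning", top_k=2)
    # Rankings may differ between metrics; each result carries a metric-specific score.
    print(metric, [round(r["score"], 4) for r in results])
```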
### ✅ Backend Integration
- The DiskANN backend properly initializes and builds indices
- Graph construction completes without errors
- Search operations return valid results

### ✅ Embedding Pipeline
- Real-time embedding computation works
- Multiple embedding models are supported (see the sketch below)
- ZMQ server communication functions correctly
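The embedding model is selected at build time. A minimal sketch, assuming the `embedding_model` keyword that `test_sanity_check.py` passes to `LeannBuilder` (the model name is just one of the models that test tries):

```python
from leann.api import LeannBuilder, LeannSearcher

builder = LeannBuilder(
    backend_name="diskann",
    embedding_model="sentence-transformers/all-MiniLM-L6-v2",  # swap in any supported model
    distance_metric="cosine",
)
builder.add_text("AI is transforming the world")
builder.build_index("test_indices/test_model.diskann")

# Queries are embedded with the same model at search time.
searcher = LeannSearcher("test_indices/test_model.diskann")
print(searcher.search("artificial intelligence", top_k=1))
```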
### ✅ End-to-End Functionality
- Index building → searching → result retrieval pipeline
- Metadata preservation through the entire flow (see the sketch below)
- Error handling and graceful degradation
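Metadata attached at indexing time should come back with search results, which is how the tests check end-to-end preservation. A minimal sketch based on how `test_distance_functions.py` and `test_l2_verification.py` attach and read metadata (field names beyond `id` and `text` are assumptions):

```python
from leann.api import LeannBuilder, LeannSearcher

docs = [
    {"id": 0, "text": "machine learning artificial intelligence"},
    {"id": 1, "text": "data science analytics statistics"},
]

builder = LeannBuilder(backend_name="diskann", distance_metric="cosine")
for doc in docs:
    builder.add_text(doc["text"], metadata=doc)  # metadata rides along with each chunk
builder.build_index("./sketch_metadata/documents.diskann")

searcher = LeannSearcher("./sketch_metadata/documents.diskann")
for result in searcher.search("machine learning", top_k=2):
    # The tests read result['id'], result['text'], and result['score'] directly.
    print(result["id"], result["score"], result["text"])
```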
## 🔍 Expected Output

When all tests pass, you should see:

```
📊 Test results summary:
  mips      : ✅ PASSED
  l2        : ✅ PASSED
  cosine    : ✅ PASSED

🎉 Testing complete!
```

## 🐛 Troubleshooting

### Common Issues

**Import Errors**: Ensure you're running from the project root:
```bash
cd /path/to/leann
uv run python tests/sanity_checks/test_distance_functions.py
```

**Memory Issues**: Reduce graph complexity for resource-constrained systems:
```python
builder = LeannBuilder(
    backend_name="diskann",
    graph_degree=8,   # Reduced from 16
    complexity=16     # Reduced from 32
)
```

**ZMQ Port Conflicts**: The tests use different ports to avoid conflicts, but you may need to kill existing processes:
```bash
pkill -f "embedding_server"
```

## 📊 Performance Expectations

### Typical Timing (3 documents, consumer hardware)
- **Index Building**: 2-5 seconds per distance function
- **Search Query**: 50-200 ms
- **Recompute Mode**: 5-15 seconds (higher accuracy)

### Memory Usage
- **Index Storage**: ~1-2 MB per distance function
- **Runtime Memory**: ~500 MB (including model loading)

## 🔗 Integration with CI/CD

These tests are designed to run in automated environments:

```yaml
# GitHub Actions example
- name: Run Sanity Checks
  run: |
    uv run python tests/sanity_checks/test_distance_functions.py
    uv run python tests/sanity_checks/test_l2_verification.py
```

The tests are deterministic and should produce consistent results across different platforms.
@@ -1,128 +0,0 @@
import time

import matplotlib.pyplot as plt
import mlx.core as mx
import numpy as np
import torch
from mlx_lm import load
from sentence_transformers import SentenceTransformer

# --- Configuration ---
MODEL_NAME_TORCH = "Qwen/Qwen3-Embedding-0.6B"
MODEL_NAME_MLX = "mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ"
BATCH_SIZES = [1, 8, 16, 32, 64, 128]
NUM_RUNS = 10    # Number of runs to average for each batch size
WARMUP_RUNS = 2  # Number of warm-up runs

# --- Generate Dummy Data ---
DUMMY_SENTENCES = ["This is a test sentence for benchmarking." * 5] * max(BATCH_SIZES)


# --- Benchmark Functions ---

def benchmark_torch(model, sentences):
    start_time = time.time()
    model.encode(sentences, convert_to_numpy=True)
    end_time = time.time()
    return (end_time - start_time) * 1000  # Return time in ms


def benchmark_mlx(model, tokenizer, sentences):
    start_time = time.time()

    # Tokenize sentences using the MLX tokenizer
    tokens = [tokenizer.encode(sentence) for sentence in sentences]

    # Pad sequences to the same length
    max_len = max(len(t) for t in tokens)
    input_ids = []
    attention_mask = []

    for token_seq in tokens:
        # Pad sequence
        padded = token_seq + [tokenizer.eos_token_id] * (max_len - len(token_seq))
        input_ids.append(padded)
        # Create attention mask (1 for real tokens, 0 for padding)
        mask = [1] * len(token_seq) + [0] * (max_len - len(token_seq))
        attention_mask.append(mask)

    # Convert to MLX arrays
    input_ids = mx.array(input_ids)
    attention_mask = mx.array(attention_mask)

    # Get embeddings
    embeddings = model(input_ids)

    # Mean pooling
    mask = mx.expand_dims(attention_mask, -1)
    sum_embeddings = (embeddings * mask).sum(axis=1)
    sum_mask = mask.sum(axis=1)
    pooled = sum_embeddings / sum_mask

    mx.eval(pooled)  # Force the lazy computation so the timing is meaningful
    end_time = time.time()
    return (end_time - start_time) * 1000  # Return time in ms


# --- Main Execution ---
def main():
    print("--- Initializing Models ---")
    # Load PyTorch model
    print(f"Loading PyTorch model: {MODEL_NAME_TORCH}")
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    model_torch = SentenceTransformer(MODEL_NAME_TORCH, device=device)
    print(f"PyTorch model loaded on: {device}")

    # Load MLX model
    print(f"Loading MLX model: {MODEL_NAME_MLX}")
    model_mlx, tokenizer_mlx = load(MODEL_NAME_MLX)
    print("MLX model loaded.")

    # --- Warm-up ---
    print("\n--- Performing Warm-up Runs ---")
    for _ in range(WARMUP_RUNS):
        benchmark_torch(model_torch, DUMMY_SENTENCES[:1])
        benchmark_mlx(model_mlx, tokenizer_mlx, DUMMY_SENTENCES[:1])
    print("Warm-up complete.")

    # --- Benchmarking ---
    print("\n--- Starting Benchmark ---")
    results_torch = []
    results_mlx = []

    for batch_size in BATCH_SIZES:
        print(f"Benchmarking batch size: {batch_size}")
        sentences_batch = DUMMY_SENTENCES[:batch_size]

        # Benchmark PyTorch
        torch_times = [benchmark_torch(model_torch, sentences_batch) for _ in range(NUM_RUNS)]
        results_torch.append(np.mean(torch_times))

        # Benchmark MLX
        mlx_times = [benchmark_mlx(model_mlx, tokenizer_mlx, sentences_batch) for _ in range(NUM_RUNS)]
        results_mlx.append(np.mean(mlx_times))

    print("\n--- Benchmark Results (Average time per batch in ms) ---")
    print(f"Batch Sizes: {BATCH_SIZES}")
    print(f"PyTorch ({device}): {[f'{t:.2f}' for t in results_torch]}")
    print(f"MLX: {[f'{t:.2f}' for t in results_mlx]}")

    # --- Plotting ---
    print("\n--- Generating Plot ---")
    plt.figure(figsize=(10, 6))
    plt.plot(BATCH_SIZES, results_torch, marker='o', linestyle='-', label=f'PyTorch ({device})')
    plt.plot(BATCH_SIZES, results_mlx, marker='s', linestyle='-', label='MLX')

    plt.title(f'Embedding Performance: MLX vs PyTorch\nModel: {MODEL_NAME_TORCH}')
    plt.xlabel("Batch Size")
    plt.ylabel("Average Time per Batch (ms)")
    plt.xticks(BATCH_SIZES)
    plt.grid(True)
    plt.legend()

    # Save the plot
    output_filename = "embedding_benchmark.png"
    plt.savefig(output_filename)
    print(f"Plot saved to {output_filename}")


if __name__ == "__main__":
    main()
@@ -1,113 +0,0 @@
#!/usr/bin/env python3
"""
Debug script to test ZMQ communication with the exact same setup as main_cli_example.py
"""

import subprocess
import sys
import time

import zmq

sys.path.append('packages/leann-backend-diskann')
from leann_backend_diskann import embedding_pb2


def test_zmq_with_same_model():
    print("=== Testing ZMQ with same model as main_cli_example.py ===")

    # Test the exact same model that main_cli_example.py uses
    model_name = "sentence-transformers/all-mpnet-base-v2"

    # Start the embedding server with the same model
    server_cmd = [
        sys.executable, "-m",
        "packages.leann-backend-diskann.leann_backend_diskann.embedding_server",
        "--zmq-port", "5556",  # Use a different port to avoid conflicts
        "--model-name", model_name
    ]

    print(f"Starting server with command: {' '.join(server_cmd)}")
    server_process = subprocess.Popen(
        server_cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True
    )

    # Wait for the server to start
    print("Waiting for server to start...")
    time.sleep(10)

    # Check if the server is running
    if server_process.poll() is not None:
        stdout, stderr = server_process.communicate()
        print(f"Server failed to start. stdout: {stdout}")
        print(f"Server failed to start. stderr: {stderr}")
        return False

    print(f"Server started with PID: {server_process.pid}")

    try:
        # Test client
        context = zmq.Context()
        socket = context.socket(zmq.REQ)
        socket.connect("tcp://127.0.0.1:5556")
        socket.setsockopt(zmq.RCVTIMEO, 30000)  # 30 second timeout, same as the C++ client
        socket.setsockopt(zmq.SNDTIMEO, 30000)

        # Create a request with the same format as the C++ client
        request = embedding_pb2.NodeEmbeddingRequest()
        request.node_ids.extend([0, 1, 2, 3, 4])  # Test with some node IDs

        print(f"Sending request with {len(request.node_ids)} node IDs...")
        start_time = time.time()

        # Send request
        socket.send(request.SerializeToString())

        # Receive response
        response_data = socket.recv()
        end_time = time.time()

        print(f"Received response in {end_time - start_time:.3f} seconds")
        print(f"Response size: {len(response_data)} bytes")

        # Parse response
        response = embedding_pb2.NodeEmbeddingResponse()
        response.ParseFromString(response_data)

        print(f"Response dimensions: {list(response.dimensions)}")
        print(f"Embeddings data size: {len(response.embeddings_data)} bytes")
        print(f"Missing IDs: {list(response.missing_ids)}")

        # Calculate the expected payload size
        if len(response.dimensions) == 2:
            batch_size = response.dimensions[0]
            embedding_dim = response.dimensions[1]
            expected_bytes = batch_size * embedding_dim * 4  # 4 bytes per float32
            print(f"Expected bytes: {expected_bytes}, Actual: {len(response.embeddings_data)}")

            if len(response.embeddings_data) == expected_bytes:
                print("✅ Response format is correct!")
                return True
            else:
                print("❌ Response format mismatch!")
                return False
        else:
            print("❌ Invalid response dimensions!")
            return False

    except Exception as e:
        print(f"❌ Error during ZMQ test: {e}")
        return False
    finally:
        # Clean up
        server_process.terminate()
        server_process.wait()
        print("Server terminated")


if __name__ == "__main__":
    success = test_zmq_with_same_model()
    if success:
        print("\n✅ ZMQ communication test passed!")
    else:
        print("\n❌ ZMQ communication test failed!")
@@ -1,107 +0,0 @@
#!/usr/bin/env python3
"""
DiskANN distance function tests
"""

import shutil
from pathlib import Path

# Import the backend packages to trigger plugin registration
try:
    import leann_backend_diskann
    import leann_backend_hnsw
    print("INFO: Backend packages imported successfully.")
except ImportError as e:
    print(f"WARNING: Could not import backend packages. Error: {e}")

# Import the high-level API from leann-core
from leann.api import LeannBuilder, LeannSearcher


def load_sample_documents():
    """Create sample documents for the demo."""
    docs = [
        {"title": "Intro to Python", "content": "Python is a programming language for machine learning"},
        {"title": "ML Basics", "content": "Machine learning algorithms build intelligent systems"},
        {"title": "Data Structures", "content": "Data structures like arrays and graphs organize information"},
    ]
    return docs


def test_distance_function(distance_func, test_name):
    """Test a specific distance function."""
    print(f"\n=== Testing {test_name} ({distance_func}) ===")

    INDEX_DIR = Path(f"./test_indices_{distance_func}")
    INDEX_PATH = str(INDEX_DIR / "documents.diskann")

    if INDEX_DIR.exists():
        shutil.rmtree(INDEX_DIR)

    # Build the index
    print(f"Building index (distance function: {distance_func})...")
    builder = LeannBuilder(
        backend_name="diskann",
        distance_metric=distance_func,
        graph_degree=16,
        complexity=32
    )

    documents = load_sample_documents()
    for doc in documents:
        builder.add_text(doc["content"], metadata=doc)

    try:
        builder.build_index(INDEX_PATH)
        print("✅ Index built successfully")

        # Test search
        searcher = LeannSearcher(INDEX_PATH, distance_metric=distance_func)
        results = searcher.search("machine learning programming", top_k=2)

        print("Search results:")
        for i, result in enumerate(results):
            print(f"  {i+1}. Score: {result['score']:.4f}")
            print(f"     Text: {result['text'][:50]}...")

        return True

    except Exception as e:
        print(f"❌ Test failed: {e}")
        return False


def main():
    print("🔍 DiskANN distance function tests")
    print("=" * 50)

    # Test the different distance functions
    distance_tests = [
        ("mips", "Maximum Inner Product Search"),
        ("l2", "L2 Euclidean Distance"),
        ("cosine", "Cosine Similarity")
    ]

    results = {}
    for distance_func, test_name in distance_tests:
        try:
            success = test_distance_function(distance_func, test_name)
            results[distance_func] = success
        except Exception as e:
            print(f"❌ {distance_func} test raised an exception: {e}")
            results[distance_func] = False

    # Summary
    print("\n" + "=" * 50)
    print("📊 Test results summary:")
    for distance_func, success in results.items():
        status = "✅ PASSED" if success else "❌ FAILED"
        print(f"  {distance_func:10s}: {status}")

    print("\n🎉 Testing complete!")


if __name__ == "__main__":
    main()
@@ -1,127 +0,0 @@
#!/usr/bin/env python3
"""
Verify that DiskANN L2 distance actually works
"""

import shutil
from pathlib import Path

# Import the backend package to trigger plugin registration
try:
    import leann_backend_diskann
    print("INFO: Backend packages imported successfully.")
except ImportError as e:
    print(f"WARNING: Could not import backend packages. Error: {e}")

from leann.api import LeannBuilder, LeannSearcher


def test_l2_verification():
    """Verify that the L2 distance is actually being used."""
    print("=== Verifying the DiskANN L2 distance implementation ===")

    INDEX_DIR = Path("./test_l2_verification")
    INDEX_PATH = str(INDEX_DIR / "documents.diskann")

    if INDEX_DIR.exists():
        shutil.rmtree(INDEX_DIR)

    # Craft test documents so that L2 and cosine can produce different results
    documents = [
        "machine learning artificial intelligence",   # document 0
        "computer programming software development",  # document 1
        "data science analytics statistics"           # document 2
    ]

    print("Building index...")
    builder = LeannBuilder(
        backend_name="diskann",
        distance_metric="l2",  # Explicitly request L2
        graph_degree=16,
        complexity=32
    )

    for i, doc in enumerate(documents):
        builder.add_text(doc, metadata={"id": i, "text": doc})

    builder.build_index(INDEX_PATH)
    print("✅ Index build complete")

    # Test search
    searcher = LeannSearcher(INDEX_PATH, distance_metric="l2")

    # Use a query that is very similar to document 0
    query = "machine learning AI technology"
    results = searcher.search(query, top_k=3)

    print(f"\nQuery: '{query}'")
    print("L2 distance search results:")
    for i, result in enumerate(results):
        print(f"  {i+1}. ID:{result['id']}, Score:{result['score']:.6f}")
        print(f"     Text: {result['text']}")

    # Now rerun the same data with cosine
    print("\n--- Comparison run with cosine distance ---")

    INDEX_DIR_COS = Path("./test_cosine_verification")
    INDEX_PATH_COS = str(INDEX_DIR_COS / "documents.diskann")

    if INDEX_DIR_COS.exists():
        shutil.rmtree(INDEX_DIR_COS)

    builder_cos = LeannBuilder(
        backend_name="diskann",
        distance_metric="cosine",  # Use cosine
        graph_degree=16,
        complexity=32
    )

    for i, doc in enumerate(documents):
        builder_cos.add_text(doc, metadata={"id": i, "text": doc})

    builder_cos.build_index(INDEX_PATH_COS)

    searcher_cos = LeannSearcher(INDEX_PATH_COS, distance_metric="cosine")
    results_cos = searcher_cos.search(query, top_k=3)

    print("Cosine distance search results:")
    for i, result in enumerate(results_cos):
        print(f"  {i+1}. ID:{result['id']}, Score:{result['score']:.6f}")
        print(f"     Text: {result['text']}")

    # Compare the two runs
    print("\n--- Result comparison ---")
    print("L2 scores are squared Euclidean distances: smaller means more similar")
    print("Cosine scores are negated cosine similarities: smaller means more similar")

    l2_top = results[0]
    cos_top = results_cos[0]

    print(f"L2 best match: ID{l2_top['id']}, Score={l2_top['score']:.6f}")
    print(f"Cosine best match: ID{cos_top['id']}, Score={cos_top['score']:.6f}")

    if l2_top['id'] == cos_top['id']:
        print("✅ Both distance functions return the same best match")
    else:
        print("⚠️ The two distance functions return different best matches - "
              "which shows they really do use different distance computations")

    # Check that the score ranges are plausible
    l2_scores = [r['score'] for r in results]
    cos_scores = [r['score'] for r in results_cos]

    print(f"L2 score range: {min(l2_scores):.6f} to {max(l2_scores):.6f}")
    print(f"Cosine score range: {min(cos_scores):.6f} to {max(cos_scores):.6f}")

    # L2 scores should be non-negative; cosine scores should be between -1 and 0
    # (since they are negated similarities)
    if all(score >= 0 for score in l2_scores):
        print("✅ All L2 scores are non-negative, as expected")
    else:
        print("❌ Some L2 scores are negative - something may be wrong")

    if all(-1 <= score <= 0 for score in cos_scores):
        print("✅ Cosine scores are within the expected range")
    else:
        print(f"⚠️ Cosine scores fall outside the expected range: {cos_scores}")


if __name__ == "__main__":
    test_l2_verification()
@@ -1,190 +0,0 @@
#!/usr/bin/env python3
"""
Sanity check script for Leann DiskANN backend
Tests different distance functions and embedding models
"""

import shutil
import sys
from pathlib import Path

# Make the local packages importable and trigger plugin registration
sys.path.append('packages/leann-core/src')
sys.path.append('packages/leann-backend-diskann')
sys.path.append('packages/leann-backend-hnsw')

try:
    import leann_backend_diskann
    import leann_backend_hnsw
    print("INFO: Backend packages imported successfully.")
except ImportError as e:
    print(f"WARNING: Could not import backend packages. Error: {e}")

# Import the high-level API from leann-core
from leann.api import LeannBuilder, LeannSearcher


def test_distance_functions():
    """Test the different distance functions."""
    print("\n=== Testing different distance functions ===")

    # Test data
    documents = [
        "Machine learning is a powerful technology",
        "Deep learning uses neural networks",
        "Artificial intelligence transforms industries"
    ]

    distance_functions = ["mips", "l2", "cosine"]

    for distance_func in distance_functions:
        print(f"\n[Testing the {distance_func} distance function]")
        try:
            index_path = f"test_indices/test_{distance_func}.diskann"
            if Path(index_path).parent.exists():
                shutil.rmtree(Path(index_path).parent)

            # Build the index
            builder = LeannBuilder(
                backend_name="diskann",
                distance_metric=distance_func,
                graph_degree=16,
                complexity=32
            )

            for doc in documents:
                builder.add_text(doc)

            builder.build_index(index_path)

            # Test search
            searcher = LeannSearcher(index_path, distance_metric=distance_func)
            results = searcher.search("neural network technology", top_k=2)

            print(f"✅ The {distance_func} distance function works correctly")
            for i, result in enumerate(results):
                print(f"  {i+1}. Score: {result['score']:.4f}, Text: {result['text'][:50]}...")

        except Exception as e:
            print(f"❌ The {distance_func} distance function failed: {e}")


def test_embedding_models():
    """Test different embedding models."""
    print("\n=== Testing different embedding models ===")

    documents = ["AI is transforming the world", "Technology advances rapidly"]

    # Embedding models to test
    models_to_test = [
        "sentence-transformers/all-mpnet-base-v2",
        "sentence-transformers/all-MiniLM-L6-v2",
        # "sentence-transformers/distilbert-base-nli-mean-tokens",  # may not exist
    ]

    for model_name in models_to_test:
        print(f"\n[Testing {model_name}]")
        try:
            index_path = "test_indices/test_model.diskann"
            if Path(index_path).parent.exists():
                shutil.rmtree(Path(index_path).parent)

            # Build the index
            builder = LeannBuilder(
                backend_name="diskann",
                embedding_model=model_name,
                distance_metric="cosine"
            )

            for doc in documents:
                builder.add_text(doc)

            builder.build_index(index_path)

            # Test search
            searcher = LeannSearcher(index_path)
            results = searcher.search("artificial intelligence", top_k=1)

            print(f"✅ The {model_name} model works correctly")
            print(f"   Result: {results[0]['text'][:50]}...")

        except Exception as e:
            print(f"❌ The {model_name} model failed: {e}")


def test_search_correctness():
    """Validate the correctness of the search results."""
    print("\n=== Validating search result correctness ===")

    # Create test documents with clear topical relevance
    documents = [
        "Python is a programming language used for machine learning",  # programming-related
        "Dogs are loyal pets that love to play fetch",                  # animal-related
        "Machine learning algorithms can predict future trends",        # ML-related
        "Cats are independent animals that sleep a lot",                # animal-related
        "Deep learning neural networks process complex data"            # ML-related
    ]

    try:
        index_path = "test_indices/correctness_test.diskann"
        if Path(index_path).parent.exists():
            shutil.rmtree(Path(index_path).parent)

        # Build the index
        builder = LeannBuilder(
            backend_name="diskann",
            distance_metric="cosine"
        )

        for doc in documents:
            builder.add_text(doc)

        builder.build_index(index_path)

        # Run relevance queries
        searcher = LeannSearcher(index_path)

        test_queries = [
            ("machine learning programming", [0, 2, 4]),  # should return ML-related documents
            ("pet animals behavior", [1, 3]),             # should return animal-related documents
        ]

        for query, expected_topics in test_queries:
            print(f"\nQuery: '{query}'")
            results = searcher.search(query, top_k=3)

            print("Search results:")
            for i, result in enumerate(results):
                print(f"  {i+1}. ID:{result['id']}, Score:{result['score']:.4f}")
                print(f"     Text: {result['text'][:60]}...")

            # Simple check: are the top two results within the expected set?
            top_ids = [result['id'] for result in results[:2]]
            relevant_found = any(id in expected_topics for id in top_ids)

            if relevant_found:
                print("✅ Search results are topically relevant")
            else:
                print("⚠️ Search result relevance may be off")

    except Exception as e:
        print(f"❌ Correctness test failed: {e}")


def main():
    print("🔍 Leann DiskANN Sanity Check")
    print("=" * 50)

    # Clean up old test data
    if Path("test_indices").exists():
        shutil.rmtree("test_indices")

    # Run the tests
    test_distance_functions()
    test_embedding_models()
    test_search_correctness()

    print("\n" + "=" * 50)
    print("🎉 Sanity check complete!")


if __name__ == "__main__":
    main()