190 lines
6.3 KiB
Python
190 lines
6.3 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Sanity check script for Leann DiskANN backend
|
|
Tests different distance functions and embedding models
|
|
"""
|
|
|
|
import os
|
|
import numpy as np
|
|
from pathlib import Path
|
|
import shutil
|
|
import time
|
|
|
|
# Import the backend packages to trigger plugin registration.
import sys

# Make the in-repo packages importable. These paths are relative, so they
# resolve against the current working directory.
sys.path.extend([
    'packages/leann-core/src',
    'packages/leann-backend-diskann',
    'packages/leann-backend-hnsw',
])

try:
    # Imported only for their registration side effect; the names are unused.
    import leann_backend_diskann  # noqa: F401
    import leann_backend_hnsw  # noqa: F401
    print("INFO: Backend packages imported successfully.")
except ImportError as e:
    # Best-effort: report the failure and keep going so the checks below can
    # still run (and report their own errors).
    print(f"WARNING: Could not import backend packages. Error: {e}")
|
|
|
|
# 从 leann-core 导入上层 API
|
|
from leann.api import LeannBuilder, LeannSearcher
|
|
|
|
def test_distance_functions():
    """Build a tiny DiskANN index once per distance metric and query it.

    For each metric in ("mips", "l2", "cosine") this removes the whole
    ``test_indices`` directory, builds a fresh index (graph_degree=16,
    complexity=32) over three short documents, searches it for
    "neural network technology" with top_k=2, and prints each hit's score
    and text. Any exception is printed for that metric and the loop moves
    on to the next one.
    """
    print("\n=== 测试不同距离函数 ===")

    # Small fixed corpus shared by every metric.
    corpus = [
        "Machine learning is a powerful technology",
        "Deep learning uses neural networks",
        "Artificial intelligence transforms industries",
    ]

    for metric in ("mips", "l2", "cosine"):
        print(f"\n[测试 {metric} 距离函数]")
        try:
            index_path = f"test_indices/test_{metric}.diskann"
            index_dir = Path(index_path).parent
            # Wipe the entire parent directory (not just this metric's index)
            # before building.
            if index_dir.exists():
                shutil.rmtree(index_dir)

            builder = LeannBuilder(
                backend_name="diskann",
                distance_metric=metric,
                graph_degree=16,
                complexity=32,
            )
            for text in corpus:
                builder.add_text(text)
            builder.build_index(index_path)

            searcher = LeannSearcher(index_path, distance_metric=metric)
            hits = searcher.search("neural network technology", top_k=2)

            print(f"✅ {metric} 距离函数工作正常")
            for rank, hit in enumerate(hits, start=1):
                print(f" {rank}. Score: {hit['score']:.4f}, Text: {hit['text'][:50]}...")
        except Exception as e:
            print(f"❌ {metric} 距离函数失败: {e}")
|
|
|
|
def test_embedding_models():
    """Build a small DiskANN index with each embedding model and query it.

    For every model in ``models_to_test`` this removes the ``test_indices``
    directory, builds a cosine-distance index over two short documents,
    searches for "artificial intelligence" with top_k=1 and prints the best
    hit. Failures are printed per model and do not stop the loop.
    """
    print("\n=== 测试不同Embedding模型 ===")

    documents = ["AI is transforming the world", "Technology advances rapidly"]

    # Embedding models to exercise.
    models_to_test = [
        "sentence-transformers/all-mpnet-base-v2",
        "sentence-transformers/all-MiniLM-L6-v2",
        # "sentence-transformers/distilbert-base-nli-mean-tokens",  # may not exist
    ]

    # Every model reuses the same index path. The parent directory is wiped
    # before each build, so each model starts from an empty directory.
    index_path = "test_indices/test_model.diskann"
    index_dir = Path(index_path).parent

    for model_name in models_to_test:
        print(f"\n[测试 {model_name}]")
        try:
            if index_dir.exists():
                shutil.rmtree(index_dir)

            builder = LeannBuilder(
                backend_name="diskann",
                embedding_model=model_name,
                distance_metric="cosine",
            )
            for doc in documents:
                builder.add_text(doc)
            builder.build_index(index_path)

            # Unlike test_distance_functions, no distance_metric is passed here;
            # presumably the searcher reads it from the built index — TODO confirm.
            searcher = LeannSearcher(index_path)
            results = searcher.search("artificial intelligence", top_k=1)

            # Without this guard, results[0] would raise IndexError only after the
            # success line was printed, yielding both a ✅ and a ❌ line with a
            # misleading "list index out of range" message.
            if not results:
                print(f"⚠️ {model_name} 模型未返回任何搜索结果")
                continue

            print(f"✅ {model_name} 模型工作正常")
            print(f" 结果: {results[0]['text'][:50]}...")
        except Exception as e:
            print(f"❌ {model_name} 模型失败: {e}")
|
|
|
|
def test_search_correctness():
    """Check that topical queries rank the matching documents first.

    Builds a cosine-distance DiskANN index over five documents covering two
    topics (programming / machine learning, and pets), runs one query per
    topic with top_k=3 and prints the hits. A query is reported as correct
    only when every one of its top-2 hits is an expected document id.

    The expected ids are list positions in ``documents``. This assumes the
    builder assigns ids in insertion order starting at 0 — TODO confirm
    against LeannBuilder.
    """
    print("\n=== 验证搜索结果正确性 ===")

    # Documents with clear-cut topics so relevance is unambiguous.
    documents = [
        "Python is a programming language used for machine learning",  # programming
        "Dogs are loyal pets that love to play fetch",  # animals
        "Machine learning algorithms can predict future trends",  # ML
        "Cats are independent animals that sleep a lot",  # animals
        "Deep learning neural networks process complex data",  # ML
    ]

    try:
        index_path = "test_indices/correctness_test.diskann"
        if Path(index_path).parent.exists():
            shutil.rmtree(Path(index_path).parent)

        builder = LeannBuilder(
            backend_name="diskann",
            distance_metric="cosine",
        )
        for doc in documents:
            builder.add_text(doc)
        builder.build_index(index_path)

        searcher = LeannSearcher(index_path)

        # (query, ids of the documents expected at the top of the ranking)
        test_queries = [
            ("machine learning programming", {0, 2, 4}),  # ML-related documents
            ("pet animals behavior", {1, 3}),  # animal-related documents
        ]

        for query, expected_ids in test_queries:
            print(f"\n查询: '{query}'")
            results = searcher.search(query, top_k=3)

            print("搜索结果:")
            for i, result in enumerate(results):
                print(f" {i+1}. ID:{result['id']}, Score:{result['score']:.4f}")
                print(f" Text: {result['text'][:60]}...")

            # Simple check: each of the top-2 hits must be an expected id.
            # `any` would pass when only one of them is relevant. The emptiness
            # test keeps `all` from passing vacuously on an empty result list.
            top_ids = [result['id'] for result in results[:2]]
            relevant_found = bool(top_ids) and all(
                doc_id in expected_ids for doc_id in top_ids
            )

            if relevant_found:
                print("✅ 搜索结果相关性正确")
            else:
                print("⚠️ 搜索结果相关性可能有问题")
    except Exception as e:
        print(f"❌ 正确性测试失败: {e}")
|
|
|
|
def main():
    """Run all three sanity checks in order, framed by a header and footer.

    Any ``test_indices`` directory left over from a previous run is removed
    before the checks start. Each check catches and prints its own
    exceptions.
    """
    banner = "=" * 50
    print("🔍 Leann DiskANN Sanity Check")
    print(banner)

    # Remove leftovers from an earlier run before any check starts.
    scratch_dir = Path("test_indices")
    if scratch_dir.exists():
        shutil.rmtree(scratch_dir)

    for check in (test_distance_functions, test_embedding_models, test_search_correctness):
        check()

    print("\n" + banner)
    print("🎉 Sanity check 完成!")
|
|
|
|
# Script entry point. The trailing `|` after main() was a stray character
# that made this line a syntax error; it has been removed.
if __name__ == "__main__":
    main()