Files
yichuan520030910320 46f6cc100b Initial commit
2025-06-30 09:05:05 +00:00

190 lines
6.3 KiB
Python

#!/usr/bin/env python3
"""
Sanity check script for Leann DiskANN backend
Tests different distance functions and embedding models
"""
import os
import numpy as np
from pathlib import Path
import shutil
import time
# 导入后端包以触发插件注册
import sys
sys.path.append('packages/leann-core/src')
sys.path.append('packages/leann-backend-diskann')
sys.path.append('packages/leann-backend-hnsw')
try:
import leann_backend_diskann
import leann_backend_hnsw
print("INFO: Backend packages imported successfully.")
except ImportError as e:
print(f"WARNING: Could not import backend packages. Error: {e}")
# 从 leann-core 导入上层 API
from leann.api import LeannBuilder, LeannSearcher
def test_distance_functions():
"""测试不同的距离函数"""
print("\n=== 测试不同距离函数 ===")
# 测试数据
documents = [
"Machine learning is a powerful technology",
"Deep learning uses neural networks",
"Artificial intelligence transforms industries"
]
distance_functions = ["mips", "l2", "cosine"]
for distance_func in distance_functions:
print(f"\n[测试 {distance_func} 距离函数]")
try:
index_path = f"test_indices/test_{distance_func}.diskann"
if Path(index_path).parent.exists():
shutil.rmtree(Path(index_path).parent)
# 构建索引
builder = LeannBuilder(
backend_name="diskann",
distance_metric=distance_func,
graph_degree=16,
complexity=32
)
for doc in documents:
builder.add_text(doc)
builder.build_index(index_path)
# 测试搜索
searcher = LeannSearcher(index_path, distance_metric=distance_func)
results = searcher.search("neural network technology", top_k=2)
print(f"{distance_func} 距离函数工作正常")
for i, result in enumerate(results):
print(f" {i+1}. Score: {result['score']:.4f}, Text: {result['text'][:50]}...")
except Exception as e:
print(f"{distance_func} 距离函数失败: {e}")
def test_embedding_models():
"""测试不同的embedding模型"""
print("\n=== 测试不同Embedding模型 ===")
documents = ["AI is transforming the world", "Technology advances rapidly"]
# 测试不同的embedding模型
models_to_test = [
"sentence-transformers/all-mpnet-base-v2",
"sentence-transformers/all-MiniLM-L6-v2",
# "sentence-transformers/distilbert-base-nli-mean-tokens", # 可能不存在
]
for model_name in models_to_test:
print(f"\n[测试 {model_name}]")
try:
index_path = f"test_indices/test_model.diskann"
if Path(index_path).parent.exists():
shutil.rmtree(Path(index_path).parent)
# 构建索引
builder = LeannBuilder(
backend_name="diskann",
embedding_model=model_name,
distance_metric="cosine"
)
for doc in documents:
builder.add_text(doc)
builder.build_index(index_path)
# 测试搜索
searcher = LeannSearcher(index_path)
results = searcher.search("artificial intelligence", top_k=1)
print(f"{model_name} 模型工作正常")
print(f" 结果: {results[0]['text'][:50]}...")
except Exception as e:
print(f"{model_name} 模型失败: {e}")
def test_search_correctness():
"""验证搜索结果的正确性"""
print("\n=== 验证搜索结果正确性 ===")
# 创建有明确相关性的测试文档
documents = [
"Python is a programming language used for machine learning", # 与编程相关
"Dogs are loyal pets that love to play fetch", # 与动物相关
"Machine learning algorithms can predict future trends", # 与ML相关
"Cats are independent animals that sleep a lot", # 与动物相关
"Deep learning neural networks process complex data" # 与ML相关
]
try:
index_path = "test_indices/correctness_test.diskann"
if Path(index_path).parent.exists():
shutil.rmtree(Path(index_path).parent)
# 构建索引
builder = LeannBuilder(
backend_name="diskann",
distance_metric="cosine"
)
for doc in documents:
builder.add_text(doc)
builder.build_index(index_path)
# 测试相关性查询
searcher = LeannSearcher(index_path)
test_queries = [
("machine learning programming", [0, 2, 4]), # 应该返回ML相关文档
("pet animals behavior", [1, 3]), # 应该返回动物相关文档
]
for query, expected_topics in test_queries:
print(f"\n查询: '{query}'")
results = searcher.search(query, top_k=3)
print("搜索结果:")
for i, result in enumerate(results):
print(f" {i+1}. ID:{result['id']}, Score:{result['score']:.4f}")
print(f" Text: {result['text'][:60]}...")
# 简单验证:检查前两个结果是否在预期范围内
top_ids = [result['id'] for result in results[:2]]
relevant_found = any(id in expected_topics for id in top_ids)
if relevant_found:
print("✅ 搜索结果相关性正确")
else:
print("⚠️ 搜索结果相关性可能有问题")
except Exception as e:
print(f"❌ 正确性测试失败: {e}")
def main():
print("🔍 Leann DiskANN Sanity Check")
print("=" * 50)
# 清理旧的测试数据
if Path("test_indices").exists():
shutil.rmtree("test_indices")
# 运行测试
test_distance_functions()
test_embedding_models()
test_search_correctness()
print("\n" + "=" * 50)
print("🎉 Sanity check 完成!")
if __name__ == "__main__":
main()