107 lines
3.1 KiB
Python
107 lines
3.1 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
DiskANN 距离函数测试
|
|
"""
|
|
|
|
import os
|
|
from pathlib import Path
|
|
import shutil
|
|
import time
|
|
|
|
# 导入后端包以触发插件注册
|
|
try:
|
|
import leann_backend_diskann
|
|
import leann_backend_hnsw
|
|
print("INFO: Backend packages imported successfully.")
|
|
except ImportError as e:
|
|
print(f"WARNING: Could not import backend packages. Error: {e}")
|
|
|
|
# 从 leann-core 导入上层 API
|
|
from leann.api import LeannBuilder, LeannSearcher
|
|
|
|
|
|
def load_sample_documents():
|
|
"""创建用于演示的样本文档"""
|
|
docs = [
|
|
{"title": "Intro to Python", "content": "Python is a programming language for machine learning"},
|
|
{"title": "ML Basics", "content": "Machine learning algorithms build intelligent systems"},
|
|
{"title": "Data Structures", "content": "Data structures like arrays and graphs organize information"},
|
|
]
|
|
return docs
|
|
|
|
|
|
def test_distance_function(distance_func, test_name):
|
|
"""测试特定距离函数"""
|
|
print(f"\n=== 测试 {test_name} ({distance_func}) ===")
|
|
|
|
INDEX_DIR = Path(f"./test_indices_{distance_func}")
|
|
INDEX_PATH = str(INDEX_DIR / "documents.diskann")
|
|
|
|
if INDEX_DIR.exists():
|
|
shutil.rmtree(INDEX_DIR)
|
|
|
|
# 构建索引
|
|
print(f"构建索引 (距离函数: {distance_func})...")
|
|
builder = LeannBuilder(
|
|
backend_name="diskann",
|
|
distance_metric=distance_func,
|
|
graph_degree=16,
|
|
complexity=32
|
|
)
|
|
|
|
documents = load_sample_documents()
|
|
for doc in documents:
|
|
builder.add_text(doc["content"], metadata=doc)
|
|
|
|
try:
|
|
builder.build_index(INDEX_PATH)
|
|
print(f"✅ 索引构建成功")
|
|
|
|
# 测试搜索
|
|
searcher = LeannSearcher(INDEX_PATH, distance_metric=distance_func)
|
|
results = searcher.search("machine learning programming", top_k=2)
|
|
|
|
print(f"搜索结果:")
|
|
for i, result in enumerate(results):
|
|
print(f" {i+1}. Score: {result['score']:.4f}")
|
|
print(f" Text: {result['text'][:50]}...")
|
|
|
|
return True
|
|
|
|
except Exception as e:
|
|
print(f"❌ 测试失败: {e}")
|
|
return False
|
|
|
|
|
|
def main():
|
|
print("🔍 DiskANN 距离函数测试")
|
|
print("=" * 50)
|
|
|
|
# 测试不同距离函数
|
|
distance_tests = [
|
|
("mips", "Maximum Inner Product Search"),
|
|
("l2", "L2 Euclidean Distance"),
|
|
("cosine", "Cosine Similarity")
|
|
]
|
|
|
|
results = {}
|
|
for distance_func, test_name in distance_tests:
|
|
try:
|
|
success = test_distance_function(distance_func, test_name)
|
|
results[distance_func] = success
|
|
except Exception as e:
|
|
print(f"❌ {distance_func} 测试异常: {e}")
|
|
results[distance_func] = False
|
|
|
|
# 总结
|
|
print("\n" + "=" * 50)
|
|
print("📊 测试结果总结:")
|
|
for distance_func, success in results.items():
|
|
status = "✅ 通过" if success else "❌ 失败"
|
|
print(f" {distance_func:10s}: {status}")
|
|
|
|
print("\n🎉 测试完成!")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main() |