Compare commits
99 Commits
fix/drop-p
...
debug/clea
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b241c17f5e | ||
|
|
8cfd5d6a8a | ||
|
|
10bfe9c980 | ||
|
|
a4346ef701 | ||
|
|
6db0a7747d | ||
|
|
b6efe3a726 | ||
|
|
0f110dc7b9 | ||
|
|
dfe60a152f | ||
|
|
6af8101977 | ||
|
|
17e0d7470f | ||
|
|
d6a923f52e | ||
|
|
d79d0af7b1 | ||
|
|
eb71969d2c | ||
|
|
183e523be9 | ||
|
|
f096e62bfa | ||
|
|
27215dfcce | ||
|
|
b8cf7198dd | ||
|
|
317d9e9ed7 | ||
|
|
751b5f8735 | ||
|
|
a7ad0bc3d6 | ||
|
|
f496621034 | ||
|
|
91d4b4fd94 | ||
|
|
4b714f3b44 | ||
|
|
b381278c3e | ||
|
|
f30166f9d5 | ||
|
|
24676970eb | ||
|
|
e26d6d9d14 | ||
|
|
2530939c0f | ||
|
|
8496828a90 | ||
|
|
7244518901 | ||
|
|
3c1207c35c | ||
|
|
364a546863 | ||
|
|
2001edf22b | ||
|
|
c1d39eead8 | ||
|
|
8d06aa99f4 | ||
|
|
2d8a1ac328 | ||
|
|
ffbf0282c3 | ||
|
|
aa2002dc3a | ||
|
|
19faa020c7 | ||
|
|
360a3ec732 | ||
|
|
341141cf8b | ||
|
|
fdf47852f0 | ||
|
|
491979c057 | ||
|
|
8e43066e10 | ||
|
|
0cc29f5edc | ||
|
|
ce9ae5f7f9 | ||
|
|
101a45a04f | ||
|
|
fbf619f087 | ||
|
|
aa8ed87bda | ||
|
|
33616c493b | ||
|
|
b0c27f3a12 | ||
|
|
7b28f81194 | ||
|
|
eb6c9e0a32 | ||
|
|
51bbf3efbf | ||
|
|
3806f2a3ba | ||
|
|
8f3cda2100 | ||
|
|
d88d0c0295 | ||
|
|
042da1fe09 | ||
|
|
2d9c183ebb | ||
|
|
a8421c0475 | ||
|
|
0ec00e1a60 | ||
|
|
777b5fed01 | ||
|
|
440ad6e816 | ||
|
|
8714472cd8 | ||
|
|
c799d61a5a | ||
|
|
e409933149 | ||
|
|
bc31876a9f | ||
|
|
e421c44b8b | ||
|
|
af69aa0508 | ||
|
|
575b354976 | ||
|
|
65bbff1d93 | ||
|
|
df798d350d | ||
|
|
3fa6b2aa17 | ||
|
|
ba95554fe7 | ||
|
|
677eb0bae3 | ||
|
|
9cdfcec331 | ||
|
|
f30d1a2530 | ||
|
|
df69a49123 | ||
|
|
65b54ff905 | ||
|
|
4db3e94f35 | ||
|
|
a2568f3ddc | ||
|
|
45bdad4fa7 | ||
|
|
8b538d1ef9 | ||
|
|
ada8bcbc70 | ||
|
|
6061e8f2de | ||
|
|
9842ad8330 | ||
|
|
7d920f9071 | ||
|
|
f28f15000c | ||
|
|
1d657fd9f6 | ||
|
|
d217adbe40 | ||
|
|
f790ec634f | ||
|
|
b8da9d7b12 | ||
|
|
0cb0463929 | ||
|
|
b982241249 | ||
|
|
c66f197e1d | ||
|
|
4a1353761a | ||
|
|
a72090d2ab | ||
|
|
669e622430 | ||
|
|
77d7b60a61 |
1
.github/workflows/build-and-publish.yml
vendored
1
.github/workflows/build-and-publish.yml
vendored
@@ -5,6 +5,7 @@ on:
|
|||||||
branches: [ main ]
|
branches: [ main ]
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [ main ]
|
branches: [ main ]
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build:
|
build:
|
||||||
|
|||||||
13
.github/workflows/build-reusable.yml
vendored
13
.github/workflows/build-reusable.yml
vendored
@@ -277,19 +277,16 @@ jobs:
|
|||||||
|
|
||||||
- name: Run tests with pytest
|
- name: Run tests with pytest
|
||||||
env:
|
env:
|
||||||
CI: true # Mark as CI environment to skip memory-intensive tests
|
CI: true
|
||||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||||
HF_HUB_DISABLE_SYMLINKS: 1
|
HF_HUB_DISABLE_SYMLINKS: 1
|
||||||
TOKENIZERS_PARALLELISM: false
|
TOKENIZERS_PARALLELISM: false
|
||||||
PYTORCH_ENABLE_MPS_FALLBACK: 0 # Disable MPS on macOS CI to avoid memory issues
|
PYTORCH_ENABLE_MPS_FALLBACK: 0
|
||||||
OMP_NUM_THREADS: 1 # Disable OpenMP parallelism to avoid libomp crashes
|
OMP_NUM_THREADS: 1
|
||||||
MKL_NUM_THREADS: 1 # Single thread for MKL operations
|
MKL_NUM_THREADS: 1
|
||||||
run: |
|
run: |
|
||||||
# Activate virtual environment
|
|
||||||
source .venv/bin/activate || source .venv/Scripts/activate
|
source .venv/bin/activate || source .venv/Scripts/activate
|
||||||
|
pytest tests/ -v --tb=short
|
||||||
# Run tests
|
|
||||||
pytest -v tests/
|
|
||||||
|
|
||||||
- name: Run sanity checks (optional)
|
- name: Run sanity checks (optional)
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
rev: v4.5.0
|
rev: v5.0.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: trailing-whitespace
|
- id: trailing-whitespace
|
||||||
- id: end-of-file-fixer
|
- id: end-of-file-fixer
|
||||||
@@ -10,7 +10,7 @@ repos:
|
|||||||
- id: debug-statements
|
- id: debug-statements
|
||||||
|
|
||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
rev: v0.2.1
|
rev: v0.12.7 # Fixed version to match pyproject.toml
|
||||||
hooks:
|
hooks:
|
||||||
- id: ruff
|
- id: ruff
|
||||||
- id: ruff-format
|
- id: ruff-format
|
||||||
|
|||||||
10
README.md
10
README.md
@@ -455,7 +455,7 @@ leann --help
|
|||||||
**To make it globally available:**
|
**To make it globally available:**
|
||||||
```bash
|
```bash
|
||||||
# Install the LEANN CLI globally using uv tool
|
# Install the LEANN CLI globally using uv tool
|
||||||
uv tool install leann
|
uv tool install leann-core
|
||||||
|
|
||||||
# Now you can use leann from anywhere without activating venv
|
# Now you can use leann from anywhere without activating venv
|
||||||
leann --help
|
leann --help
|
||||||
@@ -543,12 +543,16 @@ Options:
|
|||||||
- **Dynamic batching:** Efficiently batch embedding computations for GPU utilization
|
- **Dynamic batching:** Efficiently batch embedding computations for GPU utilization
|
||||||
- **Two-level search:** Smart graph traversal that prioritizes promising nodes
|
- **Two-level search:** Smart graph traversal that prioritizes promising nodes
|
||||||
|
|
||||||
**Backends:** HNSW (default) for most use cases, with optional DiskANN support for billion-scale datasets.
|
**Backends:**
|
||||||
|
- **HNSW** (default): Ideal for most datasets with maximum storage savings through full recomputation
|
||||||
|
- **DiskANN**: Advanced option with superior search performance, using PQ-based graph traversal with real-time reranking for the best speed-accuracy trade-off
|
||||||
|
|
||||||
## Benchmarks
|
## Benchmarks
|
||||||
|
|
||||||
|
**[DiskANN vs HNSW Performance Comparison →](benchmarks/diskann_vs_hnsw_speed_comparison.py)** - Compare search performance between both backends
|
||||||
|
|
||||||
|
**[Simple Example: Compare LEANN vs FAISS →](benchmarks/compare_faiss_vs_leann.py)** - See storage savings in action
|
||||||
|
|
||||||
**[Simple Example: Compare LEANN vs FAISS →](benchmarks/compare_faiss_vs_leann.py)**
|
|
||||||
### 📊 Storage Comparison
|
### 📊 Storage Comparison
|
||||||
|
|
||||||
| System | DPR (2.1M) | Wiki (60M) | Chat (400K) | Email (780K) | Browser (38K) |
|
| System | DPR (2.1M) | Wiki (60M) | Chat (400K) | Email (780K) | Browser (38K) |
|
||||||
|
|||||||
@@ -178,6 +178,9 @@ class BaseRAGExample(ABC):
|
|||||||
config["host"] = args.llm_host
|
config["host"] = args.llm_host
|
||||||
elif args.llm == "hf":
|
elif args.llm == "hf":
|
||||||
config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
|
config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
|
||||||
|
elif args.llm == "simulated":
|
||||||
|
# Simulated LLM doesn't need additional configuration
|
||||||
|
pass
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,24 @@
|
|||||||
# 🧪 Leann Sanity Checks
|
# 🧪 LEANN Benchmarks & Testing
|
||||||
|
|
||||||
This directory contains comprehensive sanity checks for the Leann system, ensuring all components work correctly across different configurations.
|
This directory contains performance benchmarks and comprehensive tests for the LEANN system, including backend comparisons and sanity checks across different configurations.
|
||||||
|
|
||||||
## 📁 Test Files
|
## 📁 Test Files
|
||||||
|
|
||||||
|
### `diskann_vs_hnsw_speed_comparison.py`
|
||||||
|
Performance comparison between DiskANN and HNSW backends:
|
||||||
|
- ✅ **Search latency** comparison with both backends using recompute
|
||||||
|
- ✅ **Index size** and **build time** measurements
|
||||||
|
- ✅ **Score validity** testing (ensures no -inf scores)
|
||||||
|
- ✅ **Configurable dataset sizes** for different scales
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Quick comparison with 500 docs, 10 queries
|
||||||
|
python benchmarks/diskann_vs_hnsw_speed_comparison.py
|
||||||
|
|
||||||
|
# Large-scale comparison with 2000 docs, 20 queries
|
||||||
|
python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20
|
||||||
|
```
|
||||||
|
|
||||||
### `test_distance_functions.py`
|
### `test_distance_functions.py`
|
||||||
Tests all supported distance functions across DiskANN backend:
|
Tests all supported distance functions across DiskANN backend:
|
||||||
- ✅ **MIPS** (Maximum Inner Product Search)
|
- ✅ **MIPS** (Maximum Inner Product Search)
|
||||||
|
|||||||
268
benchmarks/diskann_vs_hnsw_speed_comparison.py
Normal file
268
benchmarks/diskann_vs_hnsw_speed_comparison.py
Normal file
@@ -0,0 +1,268 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
DiskANN vs HNSW Search Performance Comparison
|
||||||
|
|
||||||
|
This benchmark compares search performance between DiskANN and HNSW backends:
|
||||||
|
- DiskANN: With graph partitioning enabled (is_recompute=True)
|
||||||
|
- HNSW: With recompute enabled (is_recompute=True)
|
||||||
|
- Tests performance across different dataset sizes
|
||||||
|
- Measures search latency, recall, and index size
|
||||||
|
"""
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def create_test_texts(n_docs: int) -> list[str]:
|
||||||
|
"""Create synthetic test documents for benchmarking."""
|
||||||
|
np.random.seed(42)
|
||||||
|
topics = [
|
||||||
|
"machine learning and artificial intelligence",
|
||||||
|
"natural language processing and text analysis",
|
||||||
|
"computer vision and image recognition",
|
||||||
|
"data science and statistical analysis",
|
||||||
|
"deep learning and neural networks",
|
||||||
|
"information retrieval and search engines",
|
||||||
|
"database systems and data management",
|
||||||
|
"software engineering and programming",
|
||||||
|
"cybersecurity and network protection",
|
||||||
|
"cloud computing and distributed systems",
|
||||||
|
]
|
||||||
|
|
||||||
|
texts = []
|
||||||
|
for i in range(n_docs):
|
||||||
|
topic = topics[i % len(topics)]
|
||||||
|
variation = np.random.randint(1, 100)
|
||||||
|
text = (
|
||||||
|
f"This is document {i} about {topic}. Content variation {variation}. "
|
||||||
|
f"Additional information about {topic} with details and examples. "
|
||||||
|
f"Technical discussion of {topic} including implementation aspects."
|
||||||
|
)
|
||||||
|
texts.append(text)
|
||||||
|
|
||||||
|
return texts
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_backend(
|
||||||
|
backend_name: str, texts: list[str], test_queries: list[str], backend_kwargs: dict[str, Any]
|
||||||
|
) -> dict[str, float]:
|
||||||
|
"""Benchmark a specific backend with the given configuration."""
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher
|
||||||
|
|
||||||
|
print(f"\n🔧 Testing {backend_name.upper()} backend...")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
index_path = str(Path(temp_dir) / f"benchmark_{backend_name}.leann")
|
||||||
|
|
||||||
|
# Build index
|
||||||
|
print(f"📦 Building {backend_name} index with {len(texts)} documents...")
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name=backend_name,
|
||||||
|
embedding_model="facebook/contriever",
|
||||||
|
embedding_mode="sentence-transformers",
|
||||||
|
**backend_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
for text in texts:
|
||||||
|
builder.add_text(text)
|
||||||
|
|
||||||
|
builder.build_index(index_path)
|
||||||
|
build_time = time.time() - start_time
|
||||||
|
|
||||||
|
# Measure index size
|
||||||
|
index_dir = Path(index_path).parent
|
||||||
|
index_files = list(index_dir.glob(f"{Path(index_path).stem}.*"))
|
||||||
|
total_size = sum(f.stat().st_size for f in index_files if f.is_file())
|
||||||
|
size_mb = total_size / (1024 * 1024)
|
||||||
|
|
||||||
|
print(f" ✅ Build completed in {build_time:.2f}s, index size: {size_mb:.1f}MB")
|
||||||
|
|
||||||
|
# Search benchmark
|
||||||
|
print("🔍 Running search benchmark...")
|
||||||
|
searcher = LeannSearcher(index_path)
|
||||||
|
|
||||||
|
search_times = []
|
||||||
|
all_results = []
|
||||||
|
|
||||||
|
for query in test_queries:
|
||||||
|
start_time = time.time()
|
||||||
|
results = searcher.search(query, top_k=5)
|
||||||
|
search_time = time.time() - start_time
|
||||||
|
search_times.append(search_time)
|
||||||
|
all_results.append(results)
|
||||||
|
|
||||||
|
avg_search_time = np.mean(search_times) * 1000 # Convert to ms
|
||||||
|
print(f" ✅ Average search time: {avg_search_time:.1f}ms")
|
||||||
|
|
||||||
|
# Check for valid scores (detect -inf issues)
|
||||||
|
all_scores = [
|
||||||
|
result.score
|
||||||
|
for results in all_results
|
||||||
|
for result in results
|
||||||
|
if result.score is not None
|
||||||
|
]
|
||||||
|
valid_scores = [
|
||||||
|
score for score in all_scores if score != float("-inf") and score != float("inf")
|
||||||
|
]
|
||||||
|
score_validity_rate = len(valid_scores) / len(all_scores) if all_scores else 0
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
try:
|
||||||
|
if hasattr(searcher, "__del__"):
|
||||||
|
searcher.__del__()
|
||||||
|
del searcher
|
||||||
|
del builder
|
||||||
|
gc.collect()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"⚠️ Warning: Resource cleanup error: {e}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"build_time": build_time,
|
||||||
|
"avg_search_time_ms": avg_search_time,
|
||||||
|
"index_size_mb": size_mb,
|
||||||
|
"score_validity_rate": score_validity_rate,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def run_comparison(n_docs: int = 500, n_queries: int = 10):
|
||||||
|
"""Run performance comparison between DiskANN and HNSW."""
|
||||||
|
print("🚀 Starting DiskANN vs HNSW Performance Comparison")
|
||||||
|
print(f"📊 Dataset: {n_docs} documents, {n_queries} test queries")
|
||||||
|
|
||||||
|
# Create test data
|
||||||
|
texts = create_test_texts(n_docs)
|
||||||
|
test_queries = [
|
||||||
|
"machine learning algorithms",
|
||||||
|
"natural language processing",
|
||||||
|
"computer vision techniques",
|
||||||
|
"data analysis methods",
|
||||||
|
"neural network architectures",
|
||||||
|
"database query optimization",
|
||||||
|
"software development practices",
|
||||||
|
"security vulnerabilities",
|
||||||
|
"cloud infrastructure",
|
||||||
|
"distributed computing",
|
||||||
|
][:n_queries]
|
||||||
|
|
||||||
|
# HNSW benchmark
|
||||||
|
hnsw_results = benchmark_backend(
|
||||||
|
backend_name="hnsw",
|
||||||
|
texts=texts,
|
||||||
|
test_queries=test_queries,
|
||||||
|
backend_kwargs={
|
||||||
|
"is_recompute": True, # Enable recompute for fair comparison
|
||||||
|
"M": 16,
|
||||||
|
"efConstruction": 200,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# DiskANN benchmark
|
||||||
|
diskann_results = benchmark_backend(
|
||||||
|
backend_name="diskann",
|
||||||
|
texts=texts,
|
||||||
|
test_queries=test_queries,
|
||||||
|
backend_kwargs={
|
||||||
|
"is_recompute": True, # Enable graph partitioning
|
||||||
|
"num_neighbors": 32,
|
||||||
|
"search_list_size": 50,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Performance comparison
|
||||||
|
print("\n📈 Performance Comparison Results")
|
||||||
|
print(f"{'=' * 60}")
|
||||||
|
print(f"{'Metric':<25} {'HNSW':<15} {'DiskANN':<15} {'Speedup':<10}")
|
||||||
|
print(f"{'-' * 60}")
|
||||||
|
|
||||||
|
# Build time comparison
|
||||||
|
build_speedup = hnsw_results["build_time"] / diskann_results["build_time"]
|
||||||
|
print(
|
||||||
|
f"{'Build Time (s)':<25} {hnsw_results['build_time']:<15.2f} {diskann_results['build_time']:<15.2f} {build_speedup:<10.2f}x"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Search time comparison
|
||||||
|
search_speedup = hnsw_results["avg_search_time_ms"] / diskann_results["avg_search_time_ms"]
|
||||||
|
print(
|
||||||
|
f"{'Search Time (ms)':<25} {hnsw_results['avg_search_time_ms']:<15.1f} {diskann_results['avg_search_time_ms']:<15.1f} {search_speedup:<10.2f}x"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Index size comparison
|
||||||
|
size_ratio = diskann_results["index_size_mb"] / hnsw_results["index_size_mb"]
|
||||||
|
print(
|
||||||
|
f"{'Index Size (MB)':<25} {hnsw_results['index_size_mb']:<15.1f} {diskann_results['index_size_mb']:<15.1f} {size_ratio:<10.2f}x"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Score validity
|
||||||
|
print(
|
||||||
|
f"{'Score Validity (%)':<25} {hnsw_results['score_validity_rate'] * 100:<15.1f} {diskann_results['score_validity_rate'] * 100:<15.1f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"{'=' * 60}")
|
||||||
|
print("\n🎯 Summary:")
|
||||||
|
if search_speedup > 1:
|
||||||
|
print(f" DiskANN is {search_speedup:.2f}x faster than HNSW for search")
|
||||||
|
else:
|
||||||
|
print(f" HNSW is {1 / search_speedup:.2f}x faster than DiskANN for search")
|
||||||
|
|
||||||
|
if size_ratio > 1:
|
||||||
|
print(f" DiskANN uses {size_ratio:.2f}x more storage than HNSW")
|
||||||
|
else:
|
||||||
|
print(f" DiskANN uses {1 / size_ratio:.2f}x less storage than HNSW")
|
||||||
|
|
||||||
|
print(
|
||||||
|
f" Both backends achieved {min(hnsw_results['score_validity_rate'], diskann_results['score_validity_rate']) * 100:.1f}% score validity"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import sys
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Handle help request
|
||||||
|
if len(sys.argv) > 1 and sys.argv[1] in ["-h", "--help", "help"]:
|
||||||
|
print("DiskANN vs HNSW Performance Comparison")
|
||||||
|
print("=" * 50)
|
||||||
|
print(f"Usage: python {sys.argv[0]} [n_docs] [n_queries]")
|
||||||
|
print()
|
||||||
|
print("Arguments:")
|
||||||
|
print(" n_docs Number of documents to index (default: 500)")
|
||||||
|
print(" n_queries Number of test queries to run (default: 10)")
|
||||||
|
print()
|
||||||
|
print("Examples:")
|
||||||
|
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py")
|
||||||
|
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 1000")
|
||||||
|
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20")
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
# Parse command line arguments
|
||||||
|
n_docs = int(sys.argv[1]) if len(sys.argv) > 1 else 500
|
||||||
|
n_queries = int(sys.argv[2]) if len(sys.argv) > 2 else 10
|
||||||
|
|
||||||
|
print("DiskANN vs HNSW Performance Comparison")
|
||||||
|
print("=" * 50)
|
||||||
|
print(f"Dataset: {n_docs} documents, {n_queries} queries")
|
||||||
|
print()
|
||||||
|
|
||||||
|
run_comparison(n_docs=n_docs, n_queries=n_queries)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n⚠️ Benchmark interrupted by user")
|
||||||
|
sys.exit(130)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ Benchmark failed: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
finally:
|
||||||
|
# Ensure clean exit
|
||||||
|
try:
|
||||||
|
gc.collect()
|
||||||
|
print("\n🧹 Cleanup completed")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
sys.exit(0)
|
||||||
@@ -97,16 +97,30 @@ ollama pull nomic-embed-text
|
|||||||
```
|
```
|
||||||
|
|
||||||
### DiskANN
|
### DiskANN
|
||||||
**Best for**: Large datasets (> 10M vectors, 10GB+ index size) - **⚠️ Beta version, still in active development**
|
**Best for**: Performance-critical applications and large datasets - **Production-ready with automatic graph partitioning**
|
||||||
- Uses Product Quantization (PQ) for coarse filtering during graph traversal
|
|
||||||
- Novel approach: stores only PQ codes, performs rerank with exact computation in final step
|
**How it works:**
|
||||||
- Implements a corner case of double-queue: prunes all neighbors and recomputes at the end
|
- **Product Quantization (PQ) + Real-time Reranking**: Uses compressed PQ codes for fast graph traversal, then recomputes exact embeddings for final candidates
|
||||||
|
- **Automatic Graph Partitioning**: When `is_recompute=True`, automatically partitions large indices and safely removes redundant files to save storage
|
||||||
|
- **Superior Speed-Accuracy Trade-off**: Faster search than HNSW while maintaining high accuracy
|
||||||
|
|
||||||
|
**Trade-offs compared to HNSW:**
|
||||||
|
- ✅ **Faster search latency** (typically 2-8x speedup)
|
||||||
|
- ✅ **Better scaling** for large datasets
|
||||||
|
- ✅ **Smart storage management** with automatic partitioning
|
||||||
|
- ✅ **Better graph locality** with `--ldg-times` parameter for SSD optimization
|
||||||
|
- ⚠️ **Slightly larger index size** due to PQ tables and graph metadata
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# For billion-scale deployments
|
# Recommended for most use cases
|
||||||
|
--backend-name diskann --graph-degree 32 --build-complexity 64
|
||||||
|
|
||||||
|
# For large-scale deployments
|
||||||
--backend-name diskann --graph-degree 64 --build-complexity 128
|
--backend-name diskann --graph-degree 64 --build-complexity 128
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**Performance Benchmark**: Run `python benchmarks/diskann_vs_hnsw_speed_comparison.py` to compare DiskANN and HNSW on your system.
|
||||||
|
|
||||||
## LLM Selection: Engine and Model Comparison
|
## LLM Selection: Engine and Model Comparison
|
||||||
|
|
||||||
### LLM Engines
|
### LLM Engines
|
||||||
@@ -283,3 +297,4 @@ LEANN's recomputation feature provides exact distance calculations but can be di
|
|||||||
- [Lessons Learned Developing LEANN](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
|
- [Lessons Learned Developing LEANN](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
|
||||||
- [LEANN Technical Paper](https://arxiv.org/abs/2506.08276)
|
- [LEANN Technical Paper](https://arxiv.org/abs/2506.08276)
|
||||||
- [DiskANN Original Paper](https://papers.nips.cc/paper/2019/file/09853c7fb1d3f8ee67a61b6bf4a7f8e6-Paper.pdf)
|
- [DiskANN Original Paper](https://papers.nips.cc/paper/2019/file/09853c7fb1d3f8ee67a61b6bf4a7f8e6-Paper.pdf)
|
||||||
|
- [SSD-based Graph Partitioning](https://github.com/SonglinLife/SSD_BASED_PLAN)
|
||||||
|
|||||||
@@ -1 +1,7 @@
|
|||||||
from . import diskann_backend as diskann_backend
|
from . import diskann_backend as diskann_backend
|
||||||
|
from . import graph_partition
|
||||||
|
|
||||||
|
# Export main classes and functions
|
||||||
|
from .graph_partition import GraphPartitioner, partition_graph
|
||||||
|
|
||||||
|
__all__ = ["GraphPartitioner", "diskann_backend", "graph_partition", "partition_graph"]
|
||||||
|
|||||||
@@ -22,6 +22,11 @@ logger = logging.getLogger(__name__)
|
|||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def suppress_cpp_output_if_needed():
|
def suppress_cpp_output_if_needed():
|
||||||
"""Suppress C++ stdout/stderr based on LEANN_LOG_LEVEL"""
|
"""Suppress C++ stdout/stderr based on LEANN_LOG_LEVEL"""
|
||||||
|
# In CI we avoid fiddling with low-level file descriptors to prevent aborts
|
||||||
|
if os.getenv("CI") == "true":
|
||||||
|
yield
|
||||||
|
return
|
||||||
|
|
||||||
log_level = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
log_level = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||||
|
|
||||||
# Only suppress if log level is WARNING or higher (ERROR, CRITICAL)
|
# Only suppress if log level is WARNING or higher (ERROR, CRITICAL)
|
||||||
@@ -137,6 +142,71 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
|||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.build_params = kwargs
|
self.build_params = kwargs
|
||||||
|
|
||||||
|
def _safe_cleanup_after_partition(self, index_dir: Path, index_prefix: str):
|
||||||
|
"""
|
||||||
|
Safely cleanup files after partition.
|
||||||
|
In partition mode, C++ doesn't read _disk.index content,
|
||||||
|
so we can delete it if all derived files exist.
|
||||||
|
"""
|
||||||
|
disk_index_file = index_dir / f"{index_prefix}_disk.index"
|
||||||
|
beam_search_file = index_dir / f"{index_prefix}_disk_beam_search.index"
|
||||||
|
|
||||||
|
# Required files that C++ partition mode needs
|
||||||
|
# Note: C++ generates these with _disk.index suffix
|
||||||
|
disk_suffix = "_disk.index"
|
||||||
|
required_files = [
|
||||||
|
f"{index_prefix}{disk_suffix}_medoids.bin", # Critical: assert fails if missing
|
||||||
|
# Note: _centroids.bin is not created in single-shot build - C++ handles this automatically
|
||||||
|
f"{index_prefix}_pq_pivots.bin", # PQ table
|
||||||
|
f"{index_prefix}_pq_compressed.bin", # PQ compressed vectors
|
||||||
|
]
|
||||||
|
|
||||||
|
# Check if all required files exist
|
||||||
|
missing_files = []
|
||||||
|
for filename in required_files:
|
||||||
|
file_path = index_dir / filename
|
||||||
|
if not file_path.exists():
|
||||||
|
missing_files.append(filename)
|
||||||
|
|
||||||
|
if missing_files:
|
||||||
|
logger.warning(
|
||||||
|
f"Cannot safely delete _disk.index - missing required files: {missing_files}"
|
||||||
|
)
|
||||||
|
logger.info("Keeping all original files for safety")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Calculate space savings
|
||||||
|
space_saved = 0
|
||||||
|
files_to_delete = []
|
||||||
|
|
||||||
|
if disk_index_file.exists():
|
||||||
|
space_saved += disk_index_file.stat().st_size
|
||||||
|
files_to_delete.append(disk_index_file)
|
||||||
|
|
||||||
|
if beam_search_file.exists():
|
||||||
|
space_saved += beam_search_file.stat().st_size
|
||||||
|
files_to_delete.append(beam_search_file)
|
||||||
|
|
||||||
|
# Safe to delete!
|
||||||
|
for file_to_delete in files_to_delete:
|
||||||
|
try:
|
||||||
|
os.remove(file_to_delete)
|
||||||
|
logger.info(f"✅ Safely deleted: {file_to_delete.name}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to delete {file_to_delete.name}: {e}")
|
||||||
|
|
||||||
|
if space_saved > 0:
|
||||||
|
space_saved_mb = space_saved / (1024 * 1024)
|
||||||
|
logger.info(f"💾 Space saved: {space_saved_mb:.1f} MB")
|
||||||
|
|
||||||
|
# Show what files are kept
|
||||||
|
logger.info("📁 Kept essential files for partition mode:")
|
||||||
|
for filename in required_files:
|
||||||
|
file_path = index_dir / filename
|
||||||
|
if file_path.exists():
|
||||||
|
size_mb = file_path.stat().st_size / (1024 * 1024)
|
||||||
|
logger.info(f" - {filename} ({size_mb:.1f} MB)")
|
||||||
|
|
||||||
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
|
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
|
||||||
path = Path(index_path)
|
path = Path(index_path)
|
||||||
index_dir = path.parent
|
index_dir = path.parent
|
||||||
@@ -151,6 +221,17 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
|||||||
_write_vectors_to_bin(data, index_dir / data_filename)
|
_write_vectors_to_bin(data, index_dir / data_filename)
|
||||||
|
|
||||||
build_kwargs = {**self.build_params, **kwargs}
|
build_kwargs = {**self.build_params, **kwargs}
|
||||||
|
|
||||||
|
# Extract is_recompute from nested backend_kwargs if needed
|
||||||
|
is_recompute = build_kwargs.get("is_recompute", False)
|
||||||
|
if not is_recompute and "backend_kwargs" in build_kwargs:
|
||||||
|
is_recompute = build_kwargs["backend_kwargs"].get("is_recompute", False)
|
||||||
|
|
||||||
|
# Flatten all backend_kwargs parameters to top level for compatibility
|
||||||
|
if "backend_kwargs" in build_kwargs:
|
||||||
|
nested_params = build_kwargs.pop("backend_kwargs")
|
||||||
|
build_kwargs.update(nested_params)
|
||||||
|
|
||||||
metric_enum = _get_diskann_metrics().get(
|
metric_enum = _get_diskann_metrics().get(
|
||||||
build_kwargs.get("distance_metric", "mips").lower()
|
build_kwargs.get("distance_metric", "mips").lower()
|
||||||
)
|
)
|
||||||
@@ -185,6 +266,30 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
|||||||
build_kwargs.get("pq_disk_bytes", 0),
|
build_kwargs.get("pq_disk_bytes", 0),
|
||||||
"",
|
"",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Auto-partition if is_recompute is enabled
|
||||||
|
if build_kwargs.get("is_recompute", False):
|
||||||
|
logger.info("is_recompute=True, starting automatic graph partitioning...")
|
||||||
|
from .graph_partition import partition_graph
|
||||||
|
|
||||||
|
# Partition the index using absolute paths
|
||||||
|
# Convert to absolute paths to avoid issues with working directory changes
|
||||||
|
absolute_index_dir = Path(index_dir).resolve()
|
||||||
|
absolute_index_prefix_path = str(absolute_index_dir / index_prefix)
|
||||||
|
disk_graph_path, partition_bin_path = partition_graph(
|
||||||
|
index_prefix_path=absolute_index_prefix_path,
|
||||||
|
output_dir=str(absolute_index_dir),
|
||||||
|
partition_prefix=index_prefix,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Safe cleanup: In partition mode, C++ doesn't read _disk.index content
|
||||||
|
# but still needs the derived files (_medoids.bin, _centroids.bin, etc.)
|
||||||
|
self._safe_cleanup_after_partition(index_dir, index_prefix)
|
||||||
|
|
||||||
|
logger.info("✅ Graph partitioning completed successfully!")
|
||||||
|
logger.info(f" - Disk graph: {disk_graph_path}")
|
||||||
|
logger.info(f" - Partition file: {partition_bin_path}")
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
temp_data_file = index_dir / data_filename
|
temp_data_file = index_dir / data_filename
|
||||||
if temp_data_file.exists():
|
if temp_data_file.exists():
|
||||||
@@ -213,7 +318,26 @@ class DiskannSearcher(BaseSearcher):
|
|||||||
|
|
||||||
# For DiskANN, we need to reinitialize the index when zmq_port changes
|
# For DiskANN, we need to reinitialize the index when zmq_port changes
|
||||||
# Store the initialization parameters for later use
|
# Store the initialization parameters for later use
|
||||||
full_index_prefix = str(self.index_dir / self.index_path.stem)
|
# Note: C++ load method expects the BASE path (without _disk.index suffix)
|
||||||
|
# C++ internally constructs: index_prefix + "_disk.index"
|
||||||
|
index_name = self.index_path.stem # "simple_test.leann" -> "simple_test"
|
||||||
|
diskann_index_prefix = str(self.index_dir / index_name) # /path/to/simple_test
|
||||||
|
full_index_prefix = diskann_index_prefix # /path/to/simple_test (base path)
|
||||||
|
|
||||||
|
# Auto-detect partition files and set partition_prefix
|
||||||
|
partition_graph_file = self.index_dir / f"{index_name}_disk_graph.index"
|
||||||
|
partition_bin_file = self.index_dir / f"{index_name}_partition.bin"
|
||||||
|
|
||||||
|
partition_prefix = ""
|
||||||
|
if partition_graph_file.exists() and partition_bin_file.exists():
|
||||||
|
# C++ expects full path prefix, not just filename
|
||||||
|
partition_prefix = str(self.index_dir / index_name) # /path/to/simple_test
|
||||||
|
logger.info(
|
||||||
|
f"✅ Detected partition files, using partition_prefix='{partition_prefix}'"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.debug("No partition files detected, using standard index files")
|
||||||
|
|
||||||
self._init_params = {
|
self._init_params = {
|
||||||
"metric_enum": metric_enum,
|
"metric_enum": metric_enum,
|
||||||
"full_index_prefix": full_index_prefix,
|
"full_index_prefix": full_index_prefix,
|
||||||
@@ -221,8 +345,14 @@ class DiskannSearcher(BaseSearcher):
|
|||||||
"num_nodes_to_cache": kwargs.get("num_nodes_to_cache", 0),
|
"num_nodes_to_cache": kwargs.get("num_nodes_to_cache", 0),
|
||||||
"cache_mechanism": 1,
|
"cache_mechanism": 1,
|
||||||
"pq_prefix": "",
|
"pq_prefix": "",
|
||||||
"partition_prefix": "",
|
"partition_prefix": partition_prefix,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Log partition configuration for debugging
|
||||||
|
if partition_prefix:
|
||||||
|
logger.info(
|
||||||
|
f"✅ Detected partition files, using partition_prefix='{partition_prefix}'"
|
||||||
|
)
|
||||||
self._diskannpy = diskannpy
|
self._diskannpy = diskannpy
|
||||||
self._current_zmq_port = None
|
self._current_zmq_port = None
|
||||||
self._index = None
|
self._index = None
|
||||||
|
|||||||
@@ -81,7 +81,8 @@ def create_diskann_embedding_server(
|
|||||||
with open(passages_file) as f:
|
with open(passages_file) as f:
|
||||||
meta = json.load(f)
|
meta = json.load(f)
|
||||||
|
|
||||||
passages = PassageManager(meta["passage_sources"])
|
logger.info(f"Loading PassageManager with metadata_file_path: {passages_file}")
|
||||||
|
passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
||||||
)
|
)
|
||||||
@@ -102,8 +103,9 @@ def create_diskann_embedding_server(
|
|||||||
socket.bind(f"tcp://*:{zmq_port}")
|
socket.bind(f"tcp://*:{zmq_port}")
|
||||||
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
|
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
|
||||||
|
|
||||||
socket.setsockopt(zmq.RCVTIMEO, 300000)
|
socket.setsockopt(zmq.RCVTIMEO, 1000)
|
||||||
socket.setsockopt(zmq.SNDTIMEO, 300000)
|
socket.setsockopt(zmq.SNDTIMEO, 1000)
|
||||||
|
socket.setsockopt(zmq.LINGER, 0)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@@ -220,30 +222,217 @@ def create_diskann_embedding_server(
|
|||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
|
def zmq_server_thread_with_shutdown(shutdown_event):
|
||||||
|
"""ZMQ server thread that respects shutdown signal.
|
||||||
|
|
||||||
|
This creates its own REP socket, binds to zmq_port, and periodically
|
||||||
|
checks shutdown_event using recv timeouts to exit cleanly.
|
||||||
|
"""
|
||||||
|
logger.info("DiskANN ZMQ server thread started with shutdown support")
|
||||||
|
|
||||||
|
context = zmq.Context()
|
||||||
|
rep_socket = context.socket(zmq.REP)
|
||||||
|
rep_socket.bind(f"tcp://*:{zmq_port}")
|
||||||
|
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
|
||||||
|
|
||||||
|
# Set receive timeout so we can check shutdown_event periodically
|
||||||
|
rep_socket.setsockopt(zmq.RCVTIMEO, 1000) # 1 second timeout
|
||||||
|
rep_socket.setsockopt(zmq.SNDTIMEO, 1000)
|
||||||
|
rep_socket.setsockopt(zmq.LINGER, 0)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while not shutdown_event.is_set():
|
||||||
|
try:
|
||||||
|
e2e_start = time.time()
|
||||||
|
# REP socket receives single-part messages
|
||||||
|
message = rep_socket.recv()
|
||||||
|
|
||||||
|
# Check for empty messages - REP socket requires response to every request
|
||||||
|
if not message:
|
||||||
|
logger.warning("Received empty message, sending empty response")
|
||||||
|
rep_socket.send(b"")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Try protobuf first (same logic as original)
|
||||||
|
texts = []
|
||||||
|
is_text_request = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
req_proto = embedding_pb2.NodeEmbeddingRequest()
|
||||||
|
req_proto.ParseFromString(message)
|
||||||
|
node_ids = list(req_proto.node_ids)
|
||||||
|
|
||||||
|
# Look up texts by node IDs
|
||||||
|
for nid in node_ids:
|
||||||
|
try:
|
||||||
|
passage_data = passages.get_passage(str(nid))
|
||||||
|
txt = passage_data["text"]
|
||||||
|
if not txt:
|
||||||
|
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
|
||||||
|
texts.append(txt)
|
||||||
|
except KeyError:
|
||||||
|
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
|
||||||
|
|
||||||
|
logger.info(f"ZMQ received protobuf request for {len(node_ids)} node IDs")
|
||||||
|
except Exception:
|
||||||
|
# Fallback to msgpack for text requests
|
||||||
|
try:
|
||||||
|
import msgpack
|
||||||
|
|
||||||
|
request = msgpack.unpackb(message)
|
||||||
|
if isinstance(request, list) and all(
|
||||||
|
isinstance(item, str) for item in request
|
||||||
|
):
|
||||||
|
texts = request
|
||||||
|
is_text_request = True
|
||||||
|
logger.info(
|
||||||
|
f"ZMQ received msgpack text request for {len(texts)} texts"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("Not a valid msgpack text request")
|
||||||
|
except Exception:
|
||||||
|
logger.error("Both protobuf and msgpack parsing failed!")
|
||||||
|
# Send error response
|
||||||
|
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||||
|
rep_socket.send(resp_proto.SerializeToString())
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Process the request
|
||||||
|
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||||
|
logger.info(f"Computed embeddings shape: {embeddings.shape}")
|
||||||
|
|
||||||
|
# Validation
|
||||||
|
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||||
|
logger.error("NaN or Inf detected in embeddings!")
|
||||||
|
# Send error response
|
||||||
|
if is_text_request:
|
||||||
|
import msgpack
|
||||||
|
|
||||||
|
response_data = msgpack.packb([])
|
||||||
|
else:
|
||||||
|
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||||
|
response_data = resp_proto.SerializeToString()
|
||||||
|
rep_socket.send(response_data)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Prepare response based on request type
|
||||||
|
if is_text_request:
|
||||||
|
# For direct text requests, return msgpack
|
||||||
|
import msgpack
|
||||||
|
|
||||||
|
response_data = msgpack.packb(embeddings.tolist())
|
||||||
|
else:
|
||||||
|
# For protobuf requests, return protobuf
|
||||||
|
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||||
|
hidden_contiguous = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||||
|
|
||||||
|
resp_proto.embeddings_data = hidden_contiguous.tobytes()
|
||||||
|
resp_proto.dimensions.append(hidden_contiguous.shape[0])
|
||||||
|
resp_proto.dimensions.append(hidden_contiguous.shape[1])
|
||||||
|
|
||||||
|
response_data = resp_proto.SerializeToString()
|
||||||
|
|
||||||
|
# Send response back to the client
|
||||||
|
rep_socket.send(response_data)
|
||||||
|
|
||||||
|
e2e_end = time.time()
|
||||||
|
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
|
||||||
|
|
||||||
|
except zmq.Again:
|
||||||
|
# Timeout - check shutdown_event and continue
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
if not shutdown_event.is_set():
|
||||||
|
logger.error(f"Error in ZMQ server loop: {e}")
|
||||||
|
try:
|
||||||
|
# Send error response for REP socket
|
||||||
|
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||||
|
rep_socket.send(resp_proto.SerializeToString())
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
logger.info("Shutdown in progress, ignoring ZMQ error")
|
||||||
|
break
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
rep_socket.close(0)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
context.term()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
logger.info("DiskANN ZMQ server thread exiting gracefully")
|
||||||
|
|
||||||
|
# Add shutdown coordination
|
||||||
|
shutdown_event = threading.Event()
|
||||||
|
|
||||||
|
def shutdown_zmq_server():
|
||||||
|
"""Gracefully shutdown ZMQ server."""
|
||||||
|
logger.info("Initiating graceful shutdown...")
|
||||||
|
shutdown_event.set()
|
||||||
|
|
||||||
|
if zmq_thread.is_alive():
|
||||||
|
logger.info("Waiting for ZMQ thread to finish...")
|
||||||
|
zmq_thread.join(timeout=5)
|
||||||
|
if zmq_thread.is_alive():
|
||||||
|
logger.warning("ZMQ thread did not finish in time")
|
||||||
|
|
||||||
|
# Clean up ZMQ resources
|
||||||
|
try:
|
||||||
|
# Note: socket and context are cleaned up by thread exit
|
||||||
|
logger.info("ZMQ resources cleaned up")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error cleaning ZMQ resources: {e}")
|
||||||
|
|
||||||
|
# Clean up other resources
|
||||||
|
try:
|
||||||
|
import gc
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
logger.info("Additional resources cleaned up")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error cleaning additional resources: {e}")
|
||||||
|
|
||||||
|
logger.info("Graceful shutdown completed")
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
# Register signal handlers within this function scope
|
||||||
|
import signal
|
||||||
|
|
||||||
|
def signal_handler(sig, frame):
|
||||||
|
logger.info(f"Received signal {sig}, shutting down gracefully...")
|
||||||
|
shutdown_zmq_server()
|
||||||
|
|
||||||
|
signal.signal(signal.SIGTERM, signal_handler)
|
||||||
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
|
|
||||||
|
# Start ZMQ thread (NOT daemon!)
|
||||||
|
zmq_thread = threading.Thread(
|
||||||
|
target=lambda: zmq_server_thread_with_shutdown(shutdown_event),
|
||||||
|
daemon=False, # Not daemon - we want to wait for it
|
||||||
|
)
|
||||||
zmq_thread.start()
|
zmq_thread.start()
|
||||||
logger.info(f"Started DiskANN ZMQ server thread on port {zmq_port}")
|
logger.info(f"Started DiskANN ZMQ server thread on port {zmq_port}")
|
||||||
|
|
||||||
# Keep the main thread alive
|
# Keep the main thread alive
|
||||||
try:
|
try:
|
||||||
while True:
|
while not shutdown_event.is_set():
|
||||||
time.sleep(1)
|
time.sleep(0.1) # Check shutdown more frequently
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logger.info("DiskANN Server shutting down...")
|
logger.info("DiskANN Server shutting down...")
|
||||||
|
shutdown_zmq_server()
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# If we reach here, shutdown was triggered by signal
|
||||||
|
logger.info("Main loop exited, process should be shutting down")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import signal
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
def signal_handler(sig, frame):
|
# Signal handlers are now registered within create_diskann_embedding_server
|
||||||
logger.info(f"Received signal {sig}, shutting down gracefully...")
|
|
||||||
sys.exit(0)
|
|
||||||
|
|
||||||
# Register signal handlers for graceful shutdown
|
|
||||||
signal.signal(signal.SIGTERM, signal_handler)
|
|
||||||
signal.signal(signal.SIGINT, signal_handler)
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="DiskANN Embedding service")
|
parser = argparse.ArgumentParser(description="DiskANN Embedding service")
|
||||||
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
||||||
|
|||||||
@@ -0,0 +1,299 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Graph Partition Module for LEANN DiskANN Backend
|
||||||
|
|
||||||
|
This module provides Python bindings for the graph partition functionality
|
||||||
|
of DiskANN, allowing users to partition disk-based indices for better
|
||||||
|
performance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class GraphPartitioner:
|
||||||
|
"""
|
||||||
|
A Python interface for DiskANN's graph partition functionality.
|
||||||
|
|
||||||
|
This class provides methods to partition disk-based indices for improved
|
||||||
|
search performance and memory efficiency.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, build_type: str = "release"):
|
||||||
|
"""
|
||||||
|
Initialize the GraphPartitioner.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
build_type: Build type for the executables ("debug" or "release")
|
||||||
|
"""
|
||||||
|
self.build_type = build_type
|
||||||
|
self._ensure_executables()
|
||||||
|
|
||||||
|
def _get_executable_path(self, name: str) -> str:
|
||||||
|
"""Get the path to a graph partition executable."""
|
||||||
|
# Get the directory where this Python module is located
|
||||||
|
module_dir = Path(__file__).parent
|
||||||
|
# Navigate to the graph_partition directory
|
||||||
|
graph_partition_dir = module_dir.parent / "third_party" / "DiskANN" / "graph_partition"
|
||||||
|
executable_path = graph_partition_dir / "build" / self.build_type / "graph_partition" / name
|
||||||
|
|
||||||
|
if not executable_path.exists():
|
||||||
|
raise FileNotFoundError(f"Executable {name} not found at {executable_path}")
|
||||||
|
|
||||||
|
return str(executable_path)
|
||||||
|
|
||||||
|
def _ensure_executables(self):
|
||||||
|
"""Ensure that the required executables are built."""
|
||||||
|
try:
|
||||||
|
self._get_executable_path("partitioner")
|
||||||
|
self._get_executable_path("index_relayout")
|
||||||
|
except FileNotFoundError:
|
||||||
|
# Try to build the executables automatically
|
||||||
|
print("Executables not found, attempting to build them...")
|
||||||
|
self._build_executables()
|
||||||
|
|
||||||
|
def _build_executables(self):
|
||||||
|
"""Build the required executables."""
|
||||||
|
graph_partition_dir = (
|
||||||
|
Path(__file__).parent.parent / "third_party" / "DiskANN" / "graph_partition"
|
||||||
|
)
|
||||||
|
original_dir = os.getcwd()
|
||||||
|
|
||||||
|
try:
|
||||||
|
os.chdir(graph_partition_dir)
|
||||||
|
|
||||||
|
# Clean any existing build
|
||||||
|
if (graph_partition_dir / "build").exists():
|
||||||
|
shutil.rmtree(graph_partition_dir / "build")
|
||||||
|
|
||||||
|
# Run the build script
|
||||||
|
cmd = ["./build.sh", self.build_type, "split_graph", "/tmp/dummy"]
|
||||||
|
subprocess.run(cmd, capture_output=True, text=True, cwd=graph_partition_dir)
|
||||||
|
|
||||||
|
# Check if executables were created
|
||||||
|
partitioner_path = self._get_executable_path("partitioner")
|
||||||
|
relayout_path = self._get_executable_path("index_relayout")
|
||||||
|
|
||||||
|
print(f"✅ Built partitioner: {partitioner_path}")
|
||||||
|
print(f"✅ Built index_relayout: {relayout_path}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Failed to build executables: {e}")
|
||||||
|
finally:
|
||||||
|
os.chdir(original_dir)
|
||||||
|
|
||||||
|
def partition_graph(
|
||||||
|
self,
|
||||||
|
index_prefix_path: str,
|
||||||
|
output_dir: Optional[str] = None,
|
||||||
|
partition_prefix: Optional[str] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> tuple[str, str]:
|
||||||
|
"""
|
||||||
|
Partition a disk-based index for improved performance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index_prefix_path: Path to the index prefix (e.g., "/path/to/index")
|
||||||
|
output_dir: Output directory for results (defaults to parent of index_prefix_path)
|
||||||
|
partition_prefix: Prefix for output files (defaults to basename of index_prefix_path)
|
||||||
|
**kwargs: Additional parameters for graph partitioning:
|
||||||
|
- gp_times: Number of LDG partition iterations (default: 10)
|
||||||
|
- lock_nums: Number of lock nodes (default: 10)
|
||||||
|
- cut: Cut adjacency list degree (default: 100)
|
||||||
|
- scale_factor: Scale factor (default: 1)
|
||||||
|
- data_type: Data type (default: "float")
|
||||||
|
- thread_nums: Number of threads (default: 10)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (disk_graph_index_path, partition_bin_path)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If the partitioning process fails
|
||||||
|
"""
|
||||||
|
# Set default parameters
|
||||||
|
params = {
|
||||||
|
"gp_times": 10,
|
||||||
|
"lock_nums": 10,
|
||||||
|
"cut": 100,
|
||||||
|
"scale_factor": 1,
|
||||||
|
"data_type": "float",
|
||||||
|
"thread_nums": 10,
|
||||||
|
**kwargs,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Determine output directory
|
||||||
|
if output_dir is None:
|
||||||
|
output_dir = str(Path(index_prefix_path).parent)
|
||||||
|
|
||||||
|
# Create output directory if it doesn't exist
|
||||||
|
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Determine partition prefix
|
||||||
|
if partition_prefix is None:
|
||||||
|
partition_prefix = Path(index_prefix_path).name
|
||||||
|
|
||||||
|
# Get executable paths
|
||||||
|
partitioner_path = self._get_executable_path("partitioner")
|
||||||
|
relayout_path = self._get_executable_path("index_relayout")
|
||||||
|
|
||||||
|
# Create temporary directory for processing
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
# Change to the graph_partition directory for temporary files
|
||||||
|
graph_partition_dir = (
|
||||||
|
Path(__file__).parent.parent / "third_party" / "DiskANN" / "graph_partition"
|
||||||
|
)
|
||||||
|
original_dir = os.getcwd()
|
||||||
|
|
||||||
|
try:
|
||||||
|
os.chdir(graph_partition_dir)
|
||||||
|
|
||||||
|
# Create temporary data directory
|
||||||
|
temp_data_dir = Path(temp_dir) / "data"
|
||||||
|
temp_data_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Set up paths for temporary files
|
||||||
|
graph_path = temp_data_dir / "starling" / "_M_R_L_B" / "GRAPH"
|
||||||
|
graph_gp_path = (
|
||||||
|
graph_path
|
||||||
|
/ f"GP_TIMES_{params['gp_times']}_LOCK_{params['lock_nums']}_GP_USE_FREQ0_CUT{params['cut']}_SCALE{params['scale_factor']}"
|
||||||
|
)
|
||||||
|
graph_gp_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Find input index file
|
||||||
|
old_index_file = f"{index_prefix_path}_disk_beam_search.index"
|
||||||
|
if not os.path.exists(old_index_file):
|
||||||
|
old_index_file = f"{index_prefix_path}_disk.index"
|
||||||
|
|
||||||
|
if not os.path.exists(old_index_file):
|
||||||
|
raise RuntimeError(f"Index file not found: {old_index_file}")
|
||||||
|
|
||||||
|
# Run partitioner
|
||||||
|
gp_file_path = graph_gp_path / "_part.bin"
|
||||||
|
partitioner_cmd = [
|
||||||
|
partitioner_path,
|
||||||
|
"--index_file",
|
||||||
|
old_index_file,
|
||||||
|
"--data_type",
|
||||||
|
params["data_type"],
|
||||||
|
"--gp_file",
|
||||||
|
str(gp_file_path),
|
||||||
|
"-T",
|
||||||
|
str(params["thread_nums"]),
|
||||||
|
"--ldg_times",
|
||||||
|
str(params["gp_times"]),
|
||||||
|
"--scale",
|
||||||
|
str(params["scale_factor"]),
|
||||||
|
"--mode",
|
||||||
|
"1",
|
||||||
|
]
|
||||||
|
|
||||||
|
print(f"Running partitioner: {' '.join(partitioner_cmd)}")
|
||||||
|
result = subprocess.run(
|
||||||
|
partitioner_cmd, capture_output=True, text=True, cwd=graph_partition_dir
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.returncode != 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Partitioner failed with return code {result.returncode}.\n"
|
||||||
|
f"stdout: {result.stdout}\n"
|
||||||
|
f"stderr: {result.stderr}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run relayout
|
||||||
|
part_tmp_index = graph_gp_path / "_part_tmp.index"
|
||||||
|
relayout_cmd = [
|
||||||
|
relayout_path,
|
||||||
|
old_index_file,
|
||||||
|
str(gp_file_path),
|
||||||
|
params["data_type"],
|
||||||
|
"1",
|
||||||
|
]
|
||||||
|
|
||||||
|
print(f"Running relayout: {' '.join(relayout_cmd)}")
|
||||||
|
result = subprocess.run(
|
||||||
|
relayout_cmd, capture_output=True, text=True, cwd=graph_partition_dir
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.returncode != 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Relayout failed with return code {result.returncode}.\n"
|
||||||
|
f"stdout: {result.stdout}\n"
|
||||||
|
f"stderr: {result.stderr}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Copy results to output directory
|
||||||
|
disk_graph_path = Path(output_dir) / f"{partition_prefix}_disk_graph.index"
|
||||||
|
partition_bin_path = Path(output_dir) / f"{partition_prefix}_partition.bin"
|
||||||
|
|
||||||
|
shutil.copy2(part_tmp_index, disk_graph_path)
|
||||||
|
shutil.copy2(gp_file_path, partition_bin_path)
|
||||||
|
|
||||||
|
print(f"Results copied to: {output_dir}")
|
||||||
|
return str(disk_graph_path), str(partition_bin_path)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
os.chdir(original_dir)
|
||||||
|
|
||||||
|
def get_partition_info(self, partition_bin_path: str) -> dict:
|
||||||
|
"""
|
||||||
|
Get information about a partition file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
partition_bin_path: Path to the partition binary file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary containing partition information
|
||||||
|
"""
|
||||||
|
if not os.path.exists(partition_bin_path):
|
||||||
|
raise FileNotFoundError(f"Partition file not found: {partition_bin_path}")
|
||||||
|
|
||||||
|
# For now, return basic file information
|
||||||
|
# In the future, this could parse the binary file for detailed info
|
||||||
|
stat = os.stat(partition_bin_path)
|
||||||
|
return {
|
||||||
|
"file_size": stat.st_size,
|
||||||
|
"file_path": partition_bin_path,
|
||||||
|
"modified_time": stat.st_mtime,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def partition_graph(
|
||||||
|
index_prefix_path: str,
|
||||||
|
output_dir: Optional[str] = None,
|
||||||
|
partition_prefix: Optional[str] = None,
|
||||||
|
build_type: str = "release",
|
||||||
|
**kwargs,
|
||||||
|
) -> tuple[str, str]:
|
||||||
|
"""
|
||||||
|
Convenience function to partition a graph index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index_prefix_path: Path to the index prefix
|
||||||
|
output_dir: Output directory (defaults to parent of index_prefix_path)
|
||||||
|
partition_prefix: Prefix for output files (defaults to basename of index_prefix_path)
|
||||||
|
build_type: Build type for executables ("debug" or "release")
|
||||||
|
**kwargs: Additional parameters for graph partitioning
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (disk_graph_index_path, partition_bin_path)
|
||||||
|
"""
|
||||||
|
partitioner = GraphPartitioner(build_type=build_type)
|
||||||
|
return partitioner.partition_graph(index_prefix_path, output_dir, partition_prefix, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# Example usage:
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Example: partition an index
|
||||||
|
try:
|
||||||
|
disk_graph_path, partition_bin_path = partition_graph(
|
||||||
|
"/path/to/your/index_prefix", gp_times=10, lock_nums=10, cut=100
|
||||||
|
)
|
||||||
|
print("Partitioning completed successfully!")
|
||||||
|
print(f"Disk graph index: {disk_graph_path}")
|
||||||
|
print(f"Partition binary: {partition_bin_path}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Partitioning failed: {e}")
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import gc # Import garbage collector interface
|
import gc # Import garbage collector interface
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import struct
|
import struct
|
||||||
import sys
|
import sys
|
||||||
@@ -7,6 +8,12 @@ import time
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
# Set up logging to avoid print buffer issues
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||||
|
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
|
||||||
|
logger.setLevel(log_level)
|
||||||
|
|
||||||
# --- FourCCs (add more if needed) ---
|
# --- FourCCs (add more if needed) ---
|
||||||
INDEX_HNSW_FLAT_FOURCC = int.from_bytes(b"IHNf", "little")
|
INDEX_HNSW_FLAT_FOURCC = int.from_bytes(b"IHNf", "little")
|
||||||
# Add other HNSW fourccs if you expect different storage types inside HNSW
|
# Add other HNSW fourccs if you expect different storage types inside HNSW
|
||||||
@@ -243,6 +250,8 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
|||||||
output_filename: Output CSR index file
|
output_filename: Output CSR index file
|
||||||
prune_embeddings: Whether to prune embedding storage (write NULL storage marker)
|
prune_embeddings: Whether to prune embedding storage (write NULL storage marker)
|
||||||
"""
|
"""
|
||||||
|
# Keep prints simple; rely on CI runner to flush output as needed
|
||||||
|
|
||||||
print(f"Starting conversion: {input_filename} -> {output_filename}")
|
print(f"Starting conversion: {input_filename} -> {output_filename}")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
original_hnsw_data = {}
|
original_hnsw_data = {}
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import sys
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Optional
|
||||||
|
|
||||||
import msgpack
|
import msgpack
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -34,7 +34,7 @@ if not logger.handlers:
|
|||||||
|
|
||||||
|
|
||||||
def create_hnsw_embedding_server(
|
def create_hnsw_embedding_server(
|
||||||
passages_file: Union[str, None] = None,
|
passages_file: Optional[str] = None,
|
||||||
zmq_port: int = 5555,
|
zmq_port: int = 5555,
|
||||||
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
||||||
distance_metric: str = "mips",
|
distance_metric: str = "mips",
|
||||||
@@ -82,199 +82,317 @@ def create_hnsw_embedding_server(
|
|||||||
with open(passages_file) as f:
|
with open(passages_file) as f:
|
||||||
meta = json.load(f)
|
meta = json.load(f)
|
||||||
|
|
||||||
# Convert relative paths to absolute paths based on metadata file location
|
# Let PassageManager handle path resolution uniformly. It supports fallback order:
|
||||||
metadata_dir = Path(passages_file).parent.parent # Go up one level from the metadata file
|
# 1) path/index_path; 2) *_relative; 3) standard siblings next to meta
|
||||||
passage_sources = []
|
passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file)
|
||||||
for source in meta["passage_sources"]:
|
# Dimension from metadata for shaping responses
|
||||||
source_copy = source.copy()
|
try:
|
||||||
# Convert relative paths to absolute paths
|
embedding_dim: int = int(meta.get("dimensions", 0))
|
||||||
if not Path(source_copy["path"]).is_absolute():
|
except Exception:
|
||||||
source_copy["path"] = str(metadata_dir / source_copy["path"])
|
embedding_dim = 0
|
||||||
if not Path(source_copy["index_path"]).is_absolute():
|
|
||||||
source_copy["index_path"] = str(metadata_dir / source_copy["index_path"])
|
|
||||||
passage_sources.append(source_copy)
|
|
||||||
|
|
||||||
passages = PassageManager(passage_sources)
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
||||||
)
|
)
|
||||||
|
|
||||||
def zmq_server_thread():
|
# (legacy ZMQ thread removed; using shutdown-capable server only)
|
||||||
"""ZMQ server thread"""
|
|
||||||
|
def zmq_server_thread_with_shutdown(shutdown_event):
|
||||||
|
"""ZMQ server thread that respects shutdown signal.
|
||||||
|
|
||||||
|
Creates its own REP socket bound to zmq_port and polls with timeouts
|
||||||
|
to allow graceful shutdown.
|
||||||
|
"""
|
||||||
|
logger.info("ZMQ server thread started with shutdown support")
|
||||||
|
|
||||||
context = zmq.Context()
|
context = zmq.Context()
|
||||||
socket = context.socket(zmq.REP)
|
rep_socket = context.socket(zmq.REP)
|
||||||
socket.bind(f"tcp://*:{zmq_port}")
|
rep_socket.bind(f"tcp://*:{zmq_port}")
|
||||||
logger.info(f"HNSW ZMQ server listening on port {zmq_port}")
|
logger.info(f"HNSW ZMQ REP server listening on port {zmq_port}")
|
||||||
|
rep_socket.setsockopt(zmq.RCVTIMEO, 1000)
|
||||||
|
# Keep sends from blocking during shutdown; fail fast and drop on close
|
||||||
|
rep_socket.setsockopt(zmq.SNDTIMEO, 1000)
|
||||||
|
rep_socket.setsockopt(zmq.LINGER, 0)
|
||||||
|
|
||||||
socket.setsockopt(zmq.RCVTIMEO, 300000)
|
# Track last request type/length for shape-correct fallbacks
|
||||||
socket.setsockopt(zmq.SNDTIMEO, 300000)
|
last_request_type = "unknown" # 'text' | 'distance' | 'embedding' | 'unknown'
|
||||||
|
last_request_length = 0
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
try:
|
||||||
message_bytes = socket.recv()
|
while not shutdown_event.is_set():
|
||||||
logger.debug(f"Received ZMQ request of size {len(message_bytes)} bytes")
|
try:
|
||||||
|
|
||||||
e2e_start = time.time()
|
e2e_start = time.time()
|
||||||
request_payload = msgpack.unpackb(message_bytes)
|
logger.debug("🔍 Waiting for ZMQ message...")
|
||||||
|
request_bytes = rep_socket.recv()
|
||||||
|
|
||||||
|
# Rest of the processing logic (same as original)
|
||||||
|
request = msgpack.unpackb(request_bytes)
|
||||||
|
|
||||||
|
if len(request) == 1 and request[0] == "__QUERY_MODEL__":
|
||||||
|
response_bytes = msgpack.packb([model_name])
|
||||||
|
rep_socket.send(response_bytes)
|
||||||
|
continue
|
||||||
|
|
||||||
# Handle direct text embedding request
|
# Handle direct text embedding request
|
||||||
if isinstance(request_payload, list) and len(request_payload) > 0:
|
if (
|
||||||
# Check if this is a direct text request (list of strings)
|
isinstance(request, list)
|
||||||
if all(isinstance(item, str) for item in request_payload):
|
and request
|
||||||
logger.info(
|
and all(isinstance(item, str) for item in request)
|
||||||
f"Processing direct text embedding request for {len(request_payload)} texts in {embedding_mode} mode"
|
):
|
||||||
)
|
last_request_type = "text"
|
||||||
|
last_request_length = len(request)
|
||||||
# Use unified embedding computation (now with model caching)
|
embeddings = compute_embeddings(request, model_name, mode=embedding_mode)
|
||||||
embeddings = compute_embeddings(
|
rep_socket.send(msgpack.packb(embeddings.tolist()))
|
||||||
request_payload, model_name, mode=embedding_mode
|
|
||||||
)
|
|
||||||
|
|
||||||
response = embeddings.tolist()
|
|
||||||
socket.send(msgpack.packb(response))
|
|
||||||
e2e_end = time.time()
|
e2e_end = time.time()
|
||||||
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
|
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Handle distance calculation requests
|
# Handle distance calculation request: [[ids], [query_vector]]
|
||||||
if (
|
if (
|
||||||
isinstance(request_payload, list)
|
isinstance(request, list)
|
||||||
and len(request_payload) == 2
|
and len(request) == 2
|
||||||
and isinstance(request_payload[0], list)
|
and isinstance(request[0], list)
|
||||||
and isinstance(request_payload[1], list)
|
and isinstance(request[1], list)
|
||||||
):
|
):
|
||||||
node_ids = request_payload[0]
|
node_ids = request[0]
|
||||||
query_vector = np.array(request_payload[1], dtype=np.float32)
|
# Handle nested [[ids]] shape defensively
|
||||||
|
if len(node_ids) == 1 and isinstance(node_ids[0], list):
|
||||||
|
node_ids = node_ids[0]
|
||||||
|
query_vector = np.array(request[1], dtype=np.float32)
|
||||||
|
last_request_type = "distance"
|
||||||
|
last_request_length = len(node_ids)
|
||||||
|
|
||||||
logger.debug("Distance calculation request received")
|
logger.debug("Distance calculation request received")
|
||||||
logger.debug(f" Node IDs: {node_ids}")
|
logger.debug(f" Node IDs: {node_ids}")
|
||||||
logger.debug(f" Query vector dim: {len(query_vector)}")
|
logger.debug(f" Query vector dim: {len(query_vector)}")
|
||||||
|
|
||||||
# Get embeddings for node IDs
|
# Gather texts for found ids
|
||||||
texts = []
|
texts: list[str] = []
|
||||||
for nid in node_ids:
|
found_indices: list[int] = []
|
||||||
|
for idx, nid in enumerate(node_ids):
|
||||||
try:
|
try:
|
||||||
passage_data = passages.get_passage(str(nid))
|
passage_data = passages.get_passage(str(nid))
|
||||||
txt = passage_data["text"]
|
txt = passage_data.get("text", "")
|
||||||
|
if isinstance(txt, str) and len(txt) > 0:
|
||||||
texts.append(txt)
|
texts.append(txt)
|
||||||
|
found_indices.append(idx)
|
||||||
|
else:
|
||||||
|
logger.error(f"Empty text for passage ID {nid}")
|
||||||
except KeyError:
|
except KeyError:
|
||||||
logger.error(f"Passage ID {nid} not found")
|
logger.error(f"Passage ID {nid} not found")
|
||||||
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
||||||
raise
|
|
||||||
|
|
||||||
# Process embeddings
|
# Prepare full-length response with large sentinel values
|
||||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
large_distance = 1e9
|
||||||
|
response_distances = [large_distance] * len(node_ids)
|
||||||
|
|
||||||
|
if texts:
|
||||||
|
try:
|
||||||
|
embeddings = compute_embeddings(
|
||||||
|
texts, model_name, mode=embedding_mode
|
||||||
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Calculate distances
|
|
||||||
if distance_metric == "l2":
|
if distance_metric == "l2":
|
||||||
distances = np.sum(
|
partial = np.sum(
|
||||||
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
|
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
|
||||||
)
|
)
|
||||||
else: # mips or cosine
|
else: # mips or cosine
|
||||||
distances = -np.dot(embeddings, query_vector)
|
partial = -np.dot(embeddings, query_vector)
|
||||||
|
|
||||||
response_payload = distances.flatten().tolist()
|
for pos, dval in zip(found_indices, partial.flatten().tolist()):
|
||||||
response_bytes = msgpack.packb([response_payload], use_single_float=True)
|
response_distances[pos] = float(dval)
|
||||||
logger.debug(f"Sending distance response with {len(distances)} distances")
|
except Exception as e:
|
||||||
|
logger.error(f"Distance computation error, using sentinels: {e}")
|
||||||
|
|
||||||
socket.send(response_bytes)
|
# Send response in expected shape [[distances]]
|
||||||
|
rep_socket.send(msgpack.packb([response_distances], use_single_float=True))
|
||||||
e2e_end = time.time()
|
e2e_end = time.time()
|
||||||
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
|
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Standard embedding request (passage ID lookup)
|
# Fallback: treat as embedding-by-id request
|
||||||
if (
|
if (
|
||||||
not isinstance(request_payload, list)
|
isinstance(request, list)
|
||||||
or len(request_payload) != 1
|
and len(request) == 1
|
||||||
or not isinstance(request_payload[0], list)
|
and isinstance(request[0], list)
|
||||||
):
|
):
|
||||||
logger.error(
|
node_ids = request[0]
|
||||||
f"Invalid MessagePack request format. Expected [[ids...]] or [texts...], got: {type(request_payload)}"
|
elif isinstance(request, list):
|
||||||
)
|
node_ids = request
|
||||||
socket.send(msgpack.packb([[], []]))
|
else:
|
||||||
continue
|
node_ids = []
|
||||||
|
last_request_type = "embedding"
|
||||||
|
last_request_length = len(node_ids)
|
||||||
|
logger.info(f"ZMQ received {len(node_ids)} node IDs for embedding fetch")
|
||||||
|
|
||||||
node_ids = request_payload[0]
|
# Preallocate zero-filled flat data for robustness
|
||||||
logger.debug(f"Request for {len(node_ids)} node embeddings")
|
if embedding_dim <= 0:
|
||||||
|
dims = [0, 0]
|
||||||
|
flat_data: list[float] = []
|
||||||
|
else:
|
||||||
|
dims = [len(node_ids), embedding_dim]
|
||||||
|
flat_data = [0.0] * (dims[0] * dims[1])
|
||||||
|
|
||||||
# Look up texts by node IDs
|
# Collect texts for found ids
|
||||||
texts = []
|
texts: list[str] = []
|
||||||
for nid in node_ids:
|
found_indices: list[int] = []
|
||||||
|
for idx, nid in enumerate(node_ids):
|
||||||
try:
|
try:
|
||||||
passage_data = passages.get_passage(str(nid))
|
passage_data = passages.get_passage(str(nid))
|
||||||
txt = passage_data["text"]
|
txt = passage_data.get("text", "")
|
||||||
if not txt:
|
if isinstance(txt, str) and len(txt) > 0:
|
||||||
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
|
|
||||||
texts.append(txt)
|
texts.append(txt)
|
||||||
|
found_indices.append(idx)
|
||||||
|
else:
|
||||||
|
logger.error(f"Empty text for passage ID {nid}")
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
|
logger.error(f"Passage with ID {nid} not found")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
||||||
raise
|
|
||||||
|
|
||||||
# Process embeddings
|
if texts:
|
||||||
|
try:
|
||||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Serialization and response
|
|
||||||
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||||
logger.error(
|
logger.error(
|
||||||
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
|
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
|
||||||
)
|
)
|
||||||
raise AssertionError()
|
dims = [0, embedding_dim]
|
||||||
|
flat_data = []
|
||||||
hidden_contiguous_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
|
else:
|
||||||
response_payload = [
|
emb_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||||
list(hidden_contiguous_f32.shape),
|
flat = emb_f32.flatten().tolist()
|
||||||
hidden_contiguous_f32.flatten().tolist(),
|
for j, pos in enumerate(found_indices):
|
||||||
|
start = pos * embedding_dim
|
||||||
|
end = start + embedding_dim
|
||||||
|
if end <= len(flat_data):
|
||||||
|
flat_data[start:end] = flat[
|
||||||
|
j * embedding_dim : (j + 1) * embedding_dim
|
||||||
]
|
]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Embedding computation error, returning zeros: {e}")
|
||||||
|
|
||||||
|
response_payload = [dims, flat_data]
|
||||||
response_bytes = msgpack.packb(response_payload, use_single_float=True)
|
response_bytes = msgpack.packb(response_payload, use_single_float=True)
|
||||||
|
|
||||||
socket.send(response_bytes)
|
rep_socket.send(response_bytes)
|
||||||
e2e_end = time.time()
|
e2e_end = time.time()
|
||||||
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
|
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
|
||||||
|
|
||||||
except zmq.Again:
|
except zmq.Again:
|
||||||
logger.debug("ZMQ socket timeout, continuing to listen")
|
# Timeout - check shutdown_event and continue
|
||||||
continue
|
continue
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
if not shutdown_event.is_set():
|
||||||
logger.error(f"Error in ZMQ server loop: {e}")
|
logger.error(f"Error in ZMQ server loop: {e}")
|
||||||
import traceback
|
# Shape-correct fallback
|
||||||
|
try:
|
||||||
|
if last_request_type == "distance":
|
||||||
|
large_distance = 1e9
|
||||||
|
fallback_len = max(0, int(last_request_length))
|
||||||
|
safe = [[large_distance] * fallback_len]
|
||||||
|
elif last_request_type == "embedding":
|
||||||
|
bsz = max(0, int(last_request_length))
|
||||||
|
dim = max(0, int(embedding_dim))
|
||||||
|
safe = (
|
||||||
|
[[bsz, dim], [0.0] * (bsz * dim)] if dim > 0 else [[0, 0], []]
|
||||||
|
)
|
||||||
|
elif last_request_type == "text":
|
||||||
|
safe = [] # direct text embeddings expectation is a flat list
|
||||||
|
else:
|
||||||
|
safe = [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
|
||||||
|
rep_socket.send(msgpack.packb(safe, use_single_float=True))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
logger.info("Shutdown in progress, ignoring ZMQ error")
|
||||||
|
break
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
rep_socket.close(0)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
context.term()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
traceback.print_exc()
|
logger.info("ZMQ server thread exiting gracefully")
|
||||||
socket.send(msgpack.packb([[], []]))
|
|
||||||
|
|
||||||
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
|
# Add shutdown coordination
|
||||||
|
shutdown_event = threading.Event()
|
||||||
|
|
||||||
|
def shutdown_zmq_server():
|
||||||
|
"""Gracefully shutdown ZMQ server."""
|
||||||
|
logger.info("Initiating graceful shutdown...")
|
||||||
|
shutdown_event.set()
|
||||||
|
|
||||||
|
if zmq_thread.is_alive():
|
||||||
|
logger.info("Waiting for ZMQ thread to finish...")
|
||||||
|
zmq_thread.join(timeout=5)
|
||||||
|
if zmq_thread.is_alive():
|
||||||
|
logger.warning("ZMQ thread did not finish in time")
|
||||||
|
|
||||||
|
# Clean up ZMQ resources
|
||||||
|
try:
|
||||||
|
# Note: socket and context are cleaned up by thread exit
|
||||||
|
logger.info("ZMQ resources cleaned up")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error cleaning ZMQ resources: {e}")
|
||||||
|
|
||||||
|
# Clean up other resources
|
||||||
|
try:
|
||||||
|
import gc
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
logger.info("Additional resources cleaned up")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error cleaning additional resources: {e}")
|
||||||
|
|
||||||
|
logger.info("Graceful shutdown completed")
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
# Register signal handlers within this function scope
|
||||||
|
import signal
|
||||||
|
|
||||||
|
def signal_handler(sig, frame):
|
||||||
|
logger.info(f"Received signal {sig}, shutting down gracefully...")
|
||||||
|
shutdown_zmq_server()
|
||||||
|
|
||||||
|
signal.signal(signal.SIGTERM, signal_handler)
|
||||||
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
|
|
||||||
|
# Pass shutdown_event to ZMQ thread
|
||||||
|
zmq_thread = threading.Thread(
|
||||||
|
target=lambda: zmq_server_thread_with_shutdown(shutdown_event),
|
||||||
|
daemon=False, # Not daemon - we want to wait for it
|
||||||
|
)
|
||||||
zmq_thread.start()
|
zmq_thread.start()
|
||||||
logger.info(f"Started HNSW ZMQ server thread on port {zmq_port}")
|
logger.info(f"Started HNSW ZMQ server thread on port {zmq_port}")
|
||||||
|
|
||||||
# Keep the main thread alive
|
# Keep the main thread alive
|
||||||
try:
|
try:
|
||||||
while True:
|
while not shutdown_event.is_set():
|
||||||
time.sleep(1)
|
time.sleep(0.1) # Check shutdown more frequently
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logger.info("HNSW Server shutting down...")
|
logger.info("HNSW Server shutting down...")
|
||||||
|
shutdown_zmq_server()
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# If we reach here, shutdown was triggered by signal
|
||||||
|
logger.info("Main loop exited, process should be shutting down")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import signal
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
def signal_handler(sig, frame):
|
# Signal handlers are now registered within create_hnsw_embedding_server
|
||||||
logger.info(f"Received signal {sig}, shutting down gracefully...")
|
|
||||||
sys.exit(0)
|
|
||||||
|
|
||||||
# Register signal handlers for graceful shutdown
|
|
||||||
signal.signal(signal.SIGTERM, signal_handler)
|
|
||||||
signal.signal(signal.SIGINT, signal_handler)
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="HNSW Embedding service")
|
parser = argparse.ArgumentParser(description="HNSW Embedding service")
|
||||||
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
||||||
|
|||||||
@@ -115,20 +115,62 @@ class SearchResult:
|
|||||||
|
|
||||||
|
|
||||||
class PassageManager:
|
class PassageManager:
|
||||||
def __init__(self, passage_sources: list[dict[str, Any]]):
|
def __init__(
|
||||||
|
self, passage_sources: list[dict[str, Any]], metadata_file_path: Optional[str] = None
|
||||||
|
):
|
||||||
self.offset_maps = {}
|
self.offset_maps = {}
|
||||||
self.passage_files = {}
|
self.passage_files = {}
|
||||||
self.global_offset_map = {} # Combined map for fast lookup
|
self.global_offset_map = {} # Combined map for fast lookup
|
||||||
|
|
||||||
|
# Derive index base name for standard sibling fallbacks, e.g., <index_name>.passages.*
|
||||||
|
index_name_base = None
|
||||||
|
if metadata_file_path:
|
||||||
|
meta_name = Path(metadata_file_path).name
|
||||||
|
if meta_name.endswith(".meta.json"):
|
||||||
|
index_name_base = meta_name[: -len(".meta.json")]
|
||||||
|
|
||||||
for source in passage_sources:
|
for source in passage_sources:
|
||||||
assert source["type"] == "jsonl", "only jsonl is supported"
|
assert source["type"] == "jsonl", "only jsonl is supported"
|
||||||
passage_file = source["path"]
|
passage_file = source.get("path", "")
|
||||||
index_file = source["index_path"] # .idx file
|
index_file = source.get("index_path", "") # .idx file
|
||||||
|
|
||||||
# Fix path resolution for Colab and other environments
|
# Fix path resolution - relative paths should be relative to metadata file directory
|
||||||
if not Path(index_file).is_absolute():
|
def _resolve_candidates(
|
||||||
# If relative path, try to resolve it properly
|
primary: str,
|
||||||
index_file = str(Path(index_file).resolve())
|
relative_key: str,
|
||||||
|
default_name: Optional[str],
|
||||||
|
source_dict: dict[str, Any],
|
||||||
|
) -> list[Path]:
|
||||||
|
candidates: list[Path] = []
|
||||||
|
# 1) Primary as-is (absolute or relative)
|
||||||
|
if primary:
|
||||||
|
p = Path(primary)
|
||||||
|
candidates.append(p if p.is_absolute() else (Path.cwd() / p))
|
||||||
|
# 2) metadata-relative explicit relative key
|
||||||
|
if metadata_file_path and source_dict.get(relative_key):
|
||||||
|
candidates.append(Path(metadata_file_path).parent / source_dict[relative_key])
|
||||||
|
# 3) metadata-relative standard sibling filename
|
||||||
|
if metadata_file_path and default_name:
|
||||||
|
candidates.append(Path(metadata_file_path).parent / default_name)
|
||||||
|
return candidates
|
||||||
|
|
||||||
|
# Build candidate lists and pick first existing; otherwise keep last candidate for error message
|
||||||
|
idx_default = f"{index_name_base}.passages.idx" if index_name_base else None
|
||||||
|
idx_candidates = _resolve_candidates(
|
||||||
|
index_file, "index_path_relative", idx_default, source
|
||||||
|
)
|
||||||
|
pas_default = f"{index_name_base}.passages.jsonl" if index_name_base else None
|
||||||
|
pas_candidates = _resolve_candidates(passage_file, "path_relative", pas_default, source)
|
||||||
|
|
||||||
|
def _pick_existing(cands: list[Path]) -> str:
|
||||||
|
for c in cands:
|
||||||
|
if c.exists():
|
||||||
|
return str(c.resolve())
|
||||||
|
# Fallback to last candidate (best guess) even if not exists; will error below
|
||||||
|
return str(cands[-1].resolve()) if cands else ""
|
||||||
|
|
||||||
|
index_file = _pick_existing(idx_candidates)
|
||||||
|
passage_file = _pick_existing(pas_candidates)
|
||||||
|
|
||||||
if not Path(index_file).exists():
|
if not Path(index_file).exists():
|
||||||
raise FileNotFoundError(f"Passage index file not found: {index_file}")
|
raise FileNotFoundError(f"Passage index file not found: {index_file}")
|
||||||
@@ -314,8 +356,12 @@ class LeannBuilder:
|
|||||||
"passage_sources": [
|
"passage_sources": [
|
||||||
{
|
{
|
||||||
"type": "jsonl",
|
"type": "jsonl",
|
||||||
"path": str(passages_file),
|
# Preserve existing relative file names (backward-compatible)
|
||||||
"index_path": str(offset_file),
|
"path": passages_file.name,
|
||||||
|
"index_path": offset_file.name,
|
||||||
|
# Add optional redundant relative keys for remote build portability (non-breaking)
|
||||||
|
"path_relative": passages_file.name,
|
||||||
|
"index_path_relative": offset_file.name,
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
@@ -430,8 +476,12 @@ class LeannBuilder:
|
|||||||
"passage_sources": [
|
"passage_sources": [
|
||||||
{
|
{
|
||||||
"type": "jsonl",
|
"type": "jsonl",
|
||||||
"path": str(passages_file),
|
# Preserve existing relative file names (backward-compatible)
|
||||||
"index_path": str(offset_file),
|
"path": passages_file.name,
|
||||||
|
"index_path": offset_file.name,
|
||||||
|
# Add optional redundant relative keys for remote build portability (non-breaking)
|
||||||
|
"path_relative": passages_file.name,
|
||||||
|
"index_path_relative": offset_file.name,
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"built_from_precomputed_embeddings": True,
|
"built_from_precomputed_embeddings": True,
|
||||||
@@ -473,7 +523,9 @@ class LeannSearcher:
|
|||||||
self.embedding_model = self.meta_data["embedding_model"]
|
self.embedding_model = self.meta_data["embedding_model"]
|
||||||
# Support both old and new format
|
# Support both old and new format
|
||||||
self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
|
self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
|
||||||
self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
|
self.passage_manager = PassageManager(
|
||||||
|
self.meta_data.get("passage_sources", []), metadata_file_path=self.meta_path_str
|
||||||
|
)
|
||||||
backend_factory = BACKEND_REGISTRY.get(backend_name)
|
backend_factory = BACKEND_REGISTRY.get(backend_name)
|
||||||
if backend_factory is None:
|
if backend_factory is None:
|
||||||
raise ValueError(f"Backend '{backend_name}' not found.")
|
raise ValueError(f"Backend '{backend_name}' not found.")
|
||||||
@@ -546,13 +598,13 @@ class LeannSearcher:
|
|||||||
zmq_port=zmq_port,
|
zmq_port=zmq_port,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
time.time() - start_time
|
|
||||||
# logger.info(f" Search time: {search_time} seconds")
|
# logger.info(f" Search time: {search_time} seconds")
|
||||||
logger.info(f" Backend returned: labels={len(results.get('labels', [[]])[0])} results")
|
logger.info(f" Backend returned: labels={len(results.get('labels', [[]])[0])} results")
|
||||||
|
|
||||||
enriched_results = []
|
enriched_results = []
|
||||||
if "labels" in results and "distances" in results:
|
if "labels" in results and "distances" in results:
|
||||||
logger.info(f" Processing {len(results['labels'][0])} passage IDs:")
|
logger.info(f" Processing {len(results['labels'][0])} passage IDs:")
|
||||||
|
# Python 3.9 does not support zip(strict=...); lengths are expected to match
|
||||||
for i, (string_id, dist) in enumerate(
|
for i, (string_id, dist) in enumerate(
|
||||||
zip(results["labels"][0], results["distances"][0])
|
zip(results["labels"][0], results["distances"][0])
|
||||||
):
|
):
|
||||||
@@ -580,13 +632,26 @@ class LeannSearcher:
|
|||||||
)
|
)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
RED = "\033[91m"
|
RED = "\033[91m"
|
||||||
|
RESET = "\033[0m"
|
||||||
logger.error(
|
logger.error(
|
||||||
f" {RED}✗{RESET} [{i + 1:2d}] ID: '{string_id}' -> {RED}ERROR: Passage not found!{RESET}"
|
f" {RED}✗{RESET} [{i + 1:2d}] ID: '{string_id}' -> {RED}ERROR: Passage not found!{RESET}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Define color codes outside the loop for final message
|
||||||
|
GREEN = "\033[92m"
|
||||||
|
RESET = "\033[0m"
|
||||||
logger.info(f" {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}")
|
logger.info(f" {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}")
|
||||||
return enriched_results
|
return enriched_results
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""Explicitly cleanup embedding server resources.
|
||||||
|
|
||||||
|
This method should be called after you're done using the searcher,
|
||||||
|
especially in test environments or batch processing scenarios.
|
||||||
|
"""
|
||||||
|
if hasattr(self.backend_impl, "embedding_server_manager"):
|
||||||
|
self.backend_impl.embedding_server_manager.stop_server()
|
||||||
|
|
||||||
|
|
||||||
class LeannChat:
|
class LeannChat:
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -656,3 +721,12 @@ class LeannChat:
|
|||||||
except (KeyboardInterrupt, EOFError):
|
except (KeyboardInterrupt, EOFError):
|
||||||
print("\nGoodbye!")
|
print("\nGoodbye!")
|
||||||
break
|
break
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""Explicitly cleanup embedding server resources.
|
||||||
|
|
||||||
|
This method should be called after you're done using the chat interface,
|
||||||
|
especially in test environments or batch processing scenarios.
|
||||||
|
"""
|
||||||
|
if hasattr(self.searcher, "cleanup"):
|
||||||
|
self.searcher.cleanup()
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import time
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import psutil
|
# Lightweight, self-contained server manager with no cross-process inspection
|
||||||
|
|
||||||
# Set up logging based on environment variable
|
# Set up logging based on environment variable
|
||||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||||
@@ -43,130 +43,7 @@ def _check_port(port: int) -> bool:
|
|||||||
return s.connect_ex(("localhost", port)) == 0
|
return s.connect_ex(("localhost", port)) == 0
|
||||||
|
|
||||||
|
|
||||||
def _check_process_matches_config(
|
# Note: All cross-process scanning helpers removed for simplicity
|
||||||
port: int, expected_model: str, expected_passages_file: str
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
Check if the process using the port matches our expected model and passages file.
|
|
||||||
Returns True if matches, False otherwise.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
for proc in psutil.process_iter(["pid", "cmdline"]):
|
|
||||||
if not _is_process_listening_on_port(proc, port):
|
|
||||||
continue
|
|
||||||
|
|
||||||
cmdline = proc.info["cmdline"]
|
|
||||||
if not cmdline:
|
|
||||||
continue
|
|
||||||
|
|
||||||
return _check_cmdline_matches_config(
|
|
||||||
cmdline, port, expected_model, expected_passages_file
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(f"No process found listening on port {port}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Could not check process on port {port}: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _is_process_listening_on_port(proc, port: int) -> bool:
|
|
||||||
"""Check if a process is listening on the given port."""
|
|
||||||
try:
|
|
||||||
connections = proc.net_connections()
|
|
||||||
for conn in connections:
|
|
||||||
if conn.laddr.port == port and conn.status == psutil.CONN_LISTEN:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _check_cmdline_matches_config(
|
|
||||||
cmdline: list, port: int, expected_model: str, expected_passages_file: str
|
|
||||||
) -> bool:
|
|
||||||
"""Check if command line matches our expected configuration."""
|
|
||||||
cmdline_str = " ".join(cmdline)
|
|
||||||
logger.debug(f"Found process on port {port}: {cmdline_str}")
|
|
||||||
|
|
||||||
# Check if it's our embedding server
|
|
||||||
is_embedding_server = any(
|
|
||||||
server_type in cmdline_str
|
|
||||||
for server_type in [
|
|
||||||
"embedding_server",
|
|
||||||
"leann_backend_diskann.embedding_server",
|
|
||||||
"leann_backend_hnsw.hnsw_embedding_server",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
if not is_embedding_server:
|
|
||||||
logger.debug(f"Process on port {port} is not our embedding server")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Check model name
|
|
||||||
model_matches = _check_model_in_cmdline(cmdline, expected_model)
|
|
||||||
|
|
||||||
# Check passages file if provided
|
|
||||||
passages_matches = _check_passages_in_cmdline(cmdline, expected_passages_file)
|
|
||||||
|
|
||||||
result = model_matches and passages_matches
|
|
||||||
logger.debug(
|
|
||||||
f"model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}"
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def _check_model_in_cmdline(cmdline: list, expected_model: str) -> bool:
|
|
||||||
"""Check if the command line contains the expected model."""
|
|
||||||
if "--model-name" not in cmdline:
|
|
||||||
return False
|
|
||||||
|
|
||||||
model_idx = cmdline.index("--model-name")
|
|
||||||
if model_idx + 1 >= len(cmdline):
|
|
||||||
return False
|
|
||||||
|
|
||||||
actual_model = cmdline[model_idx + 1]
|
|
||||||
return actual_model == expected_model
|
|
||||||
|
|
||||||
|
|
||||||
def _check_passages_in_cmdline(cmdline: list, expected_passages_file: str) -> bool:
|
|
||||||
"""Check if the command line contains the expected passages file."""
|
|
||||||
if "--passages-file" not in cmdline:
|
|
||||||
return False # Expected but not found
|
|
||||||
|
|
||||||
passages_idx = cmdline.index("--passages-file")
|
|
||||||
if passages_idx + 1 >= len(cmdline):
|
|
||||||
return False
|
|
||||||
|
|
||||||
actual_passages = cmdline[passages_idx + 1]
|
|
||||||
expected_path = Path(expected_passages_file).resolve()
|
|
||||||
actual_path = Path(actual_passages).resolve()
|
|
||||||
return actual_path == expected_path
|
|
||||||
|
|
||||||
|
|
||||||
def _find_compatible_port_or_next_available(
|
|
||||||
start_port: int, model_name: str, passages_file: str, max_attempts: int = 100
|
|
||||||
) -> tuple[int, bool]:
|
|
||||||
"""
|
|
||||||
Find a port that either has a compatible server or is available.
|
|
||||||
Returns (port, is_compatible) where is_compatible indicates if we found a matching server.
|
|
||||||
"""
|
|
||||||
for port in range(start_port, start_port + max_attempts):
|
|
||||||
if not _check_port(port):
|
|
||||||
# Port is available
|
|
||||||
return port, False
|
|
||||||
|
|
||||||
# Port is in use, check if it's compatible
|
|
||||||
if _check_process_matches_config(port, model_name, passages_file):
|
|
||||||
logger.info(f"Found compatible server on port {port}")
|
|
||||||
return port, True
|
|
||||||
else:
|
|
||||||
logger.info(f"Port {port} has incompatible server, trying next port...")
|
|
||||||
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Could not find compatible or available port in range {start_port}-{start_port + max_attempts}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingServerManager:
|
class EmbeddingServerManager:
|
||||||
@@ -185,7 +62,16 @@ class EmbeddingServerManager:
|
|||||||
self.backend_module_name = backend_module_name
|
self.backend_module_name = backend_module_name
|
||||||
self.server_process: Optional[subprocess.Popen] = None
|
self.server_process: Optional[subprocess.Popen] = None
|
||||||
self.server_port: Optional[int] = None
|
self.server_port: Optional[int] = None
|
||||||
|
# Track last-started config for in-process reuse only
|
||||||
|
self._server_config: Optional[dict] = None
|
||||||
self._atexit_registered = False
|
self._atexit_registered = False
|
||||||
|
# Also register a weakref finalizer to ensure cleanup when manager is GC'ed
|
||||||
|
try:
|
||||||
|
import weakref
|
||||||
|
|
||||||
|
self._finalizer = weakref.finalize(self, self._finalize_process)
|
||||||
|
except Exception:
|
||||||
|
self._finalizer = None
|
||||||
|
|
||||||
def start_server(
|
def start_server(
|
||||||
self,
|
self,
|
||||||
@@ -195,26 +81,24 @@ class EmbeddingServerManager:
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
) -> tuple[bool, int]:
|
) -> tuple[bool, int]:
|
||||||
"""Start the embedding server."""
|
"""Start the embedding server."""
|
||||||
passages_file = kwargs.get("passages_file")
|
# passages_file may be present in kwargs for server CLI, but we don't need it here
|
||||||
|
|
||||||
# Check if we have a compatible server already running
|
# If this manager already has a live server, just reuse it
|
||||||
if self._has_compatible_running_server(model_name, passages_file):
|
if self.server_process and self.server_process.poll() is None and self.server_port:
|
||||||
logger.info("Found compatible running server!")
|
logger.info("Reusing in-process server")
|
||||||
return True, port
|
return True, self.server_port
|
||||||
|
|
||||||
# For Colab environment, use a different strategy
|
# For Colab environment, use a different strategy
|
||||||
if _is_colab_environment():
|
if _is_colab_environment():
|
||||||
logger.info("Detected Colab environment, using alternative startup strategy")
|
logger.info("Detected Colab environment, using alternative startup strategy")
|
||||||
return self._start_server_colab(port, model_name, embedding_mode, **kwargs)
|
return self._start_server_colab(port, model_name, embedding_mode, **kwargs)
|
||||||
|
|
||||||
# Find a compatible port or next available
|
# Always pick a fresh available port
|
||||||
actual_port, is_compatible = _find_compatible_port_or_next_available(
|
try:
|
||||||
port, model_name, passages_file
|
actual_port = _get_available_port(port)
|
||||||
)
|
except RuntimeError:
|
||||||
|
logger.error("No available ports found")
|
||||||
if is_compatible:
|
return False, port
|
||||||
logger.info(f"Found compatible server on port {actual_port}")
|
|
||||||
return True, actual_port
|
|
||||||
|
|
||||||
# Start a new server
|
# Start a new server
|
||||||
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
|
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
|
||||||
@@ -247,17 +131,7 @@ class EmbeddingServerManager:
|
|||||||
logger.error(f"Failed to start embedding server in Colab: {e}")
|
logger.error(f"Failed to start embedding server in Colab: {e}")
|
||||||
return False, actual_port
|
return False, actual_port
|
||||||
|
|
||||||
def _has_compatible_running_server(self, model_name: str, passages_file: str) -> bool:
|
# Note: No compatibility check needed; manager is per-searcher and configs are stable per instance
|
||||||
"""Check if we have a compatible running server."""
|
|
||||||
if not (self.server_process and self.server_process.poll() is None and self.server_port):
|
|
||||||
return False
|
|
||||||
|
|
||||||
if _check_process_matches_config(self.server_port, model_name, passages_file):
|
|
||||||
logger.info(f"Existing server process (PID {self.server_process.pid}) is compatible")
|
|
||||||
return True
|
|
||||||
|
|
||||||
logger.info("Existing server process is incompatible. Should start a new server.")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _start_new_server(
|
def _start_new_server(
|
||||||
self, port: int, model_name: str, embedding_mode: str, **kwargs
|
self, port: int, model_name: str, embedding_mode: str, **kwargs
|
||||||
@@ -304,22 +178,61 @@ class EmbeddingServerManager:
|
|||||||
project_root = Path(__file__).parent.parent.parent.parent.parent
|
project_root = Path(__file__).parent.parent.parent.parent.parent
|
||||||
logger.info(f"Command: {' '.join(command)}")
|
logger.info(f"Command: {' '.join(command)}")
|
||||||
|
|
||||||
# Let server output go directly to console
|
# In CI environment, redirect stdout to avoid buffer deadlock but keep stderr for debugging
|
||||||
# The server will respect LEANN_LOG_LEVEL environment variable
|
# Embedding servers use many print statements that can fill stdout buffers
|
||||||
|
is_ci = os.environ.get("CI") == "true"
|
||||||
|
if is_ci:
|
||||||
|
stdout_target = subprocess.DEVNULL
|
||||||
|
stderr_target = None # Keep stderr for error debugging in CI
|
||||||
|
logger.info(
|
||||||
|
"CI environment detected, redirecting embedding server stdout to DEVNULL, keeping stderr"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
stdout_target = None # Direct to console for visible logs
|
||||||
|
stderr_target = None # Direct to console for visible logs
|
||||||
|
|
||||||
|
# Start embedding server subprocess
|
||||||
self.server_process = subprocess.Popen(
|
self.server_process = subprocess.Popen(
|
||||||
command,
|
command,
|
||||||
cwd=project_root,
|
cwd=project_root,
|
||||||
stdout=None, # Direct to console
|
stdout=stdout_target,
|
||||||
stderr=None, # Direct to console
|
stderr=stderr_target,
|
||||||
)
|
)
|
||||||
self.server_port = port
|
self.server_port = port
|
||||||
|
# Record config for in-process reuse
|
||||||
|
try:
|
||||||
|
self._server_config = {
|
||||||
|
"model_name": command[command.index("--model-name") + 1]
|
||||||
|
if "--model-name" in command
|
||||||
|
else "",
|
||||||
|
"passages_file": command[command.index("--passages-file") + 1]
|
||||||
|
if "--passages-file" in command
|
||||||
|
else "",
|
||||||
|
"embedding_mode": command[command.index("--embedding-mode") + 1]
|
||||||
|
if "--embedding-mode" in command
|
||||||
|
else "sentence-transformers",
|
||||||
|
}
|
||||||
|
except Exception:
|
||||||
|
self._server_config = {
|
||||||
|
"model_name": "",
|
||||||
|
"passages_file": "",
|
||||||
|
"embedding_mode": "sentence-transformers",
|
||||||
|
}
|
||||||
logger.info(f"Server process started with PID: {self.server_process.pid}")
|
logger.info(f"Server process started with PID: {self.server_process.pid}")
|
||||||
|
|
||||||
# Register atexit callback only when we actually start a process
|
# Register atexit callback only when we actually start a process
|
||||||
if not self._atexit_registered:
|
if not self._atexit_registered:
|
||||||
# Use a lambda to avoid issues with bound methods
|
# Always attempt best-effort finalize at interpreter exit
|
||||||
atexit.register(lambda: self.stop_server() if self.server_process else None)
|
atexit.register(self._finalize_process)
|
||||||
self._atexit_registered = True
|
self._atexit_registered = True
|
||||||
|
# Touch finalizer so it knows there is a live process
|
||||||
|
if getattr(self, "_finalizer", None) is not None and not self._finalizer.alive:
|
||||||
|
try:
|
||||||
|
import weakref
|
||||||
|
|
||||||
|
self._finalizer = weakref.finalize(self, self._finalize_process)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
def _wait_for_server_ready(self, port: int) -> tuple[bool, int]:
|
def _wait_for_server_ready(self, port: int) -> tuple[bool, int]:
|
||||||
"""Wait for the server to be ready."""
|
"""Wait for the server to be ready."""
|
||||||
@@ -344,22 +257,26 @@ class EmbeddingServerManager:
|
|||||||
if not self.server_process:
|
if not self.server_process:
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.server_process.poll() is not None:
|
if self.server_process and self.server_process.poll() is not None:
|
||||||
# Process already terminated
|
# Process already terminated
|
||||||
self.server_process = None
|
self.server_process = None
|
||||||
|
self.server_port = None
|
||||||
|
self._server_config = None
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
|
f"Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Use simple termination - our improved server shutdown should handle this properly
|
||||||
self.server_process.terminate()
|
self.server_process.terminate()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.server_process.wait(timeout=3)
|
self.server_process.wait(timeout=5) # Give more time for graceful shutdown
|
||||||
logger.info(f"Server process {self.server_process.pid} terminated.")
|
logger.info(f"Server process {self.server_process.pid} terminated gracefully.")
|
||||||
except subprocess.TimeoutExpired:
|
except subprocess.TimeoutExpired:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Server process {self.server_process.pid} did not terminate gracefully within 3 seconds, killing it."
|
f"Server process {self.server_process.pid} did not terminate within 5 seconds, force killing..."
|
||||||
)
|
)
|
||||||
self.server_process.kill()
|
self.server_process.kill()
|
||||||
try:
|
try:
|
||||||
@@ -369,15 +286,33 @@ class EmbeddingServerManager:
|
|||||||
logger.error(
|
logger.error(
|
||||||
f"Failed to kill server process {self.server_process.pid} - it may be hung"
|
f"Failed to kill server process {self.server_process.pid} - it may be hung"
|
||||||
)
|
)
|
||||||
# Don't hang indefinitely
|
|
||||||
|
|
||||||
# Clean up process resources to prevent resource tracker warnings
|
# Clean up process resources with timeout to avoid CI hang
|
||||||
try:
|
try:
|
||||||
self.server_process.wait() # Ensure process is fully cleaned up
|
# Use shorter timeout in CI environments
|
||||||
|
is_ci = os.environ.get("CI") == "true"
|
||||||
|
timeout = 3 if is_ci else 10
|
||||||
|
self.server_process.wait(timeout=timeout)
|
||||||
|
logger.info(f"Server process {self.server_process.pid} cleanup completed")
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
logger.warning(f"Process cleanup timeout after {timeout}s, proceeding anyway")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error during process cleanup: {e}")
|
||||||
|
finally:
|
||||||
|
self.server_process = None
|
||||||
|
self.server_port = None
|
||||||
|
self._server_config = None
|
||||||
|
|
||||||
|
def _finalize_process(self) -> None:
|
||||||
|
"""Best-effort cleanup used by weakref.finalize/atexit."""
|
||||||
|
try:
|
||||||
|
self.stop_server()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
self.server_process = None
|
def _adopt_existing_server(self, *args, **kwargs) -> None:
|
||||||
|
# Removed: cross-process adoption no longer supported
|
||||||
|
return
|
||||||
|
|
||||||
def _launch_server_process_colab(self, command: list, port: int) -> None:
|
def _launch_server_process_colab(self, command: list, port: int) -> None:
|
||||||
"""Launch the server process with Colab-specific settings."""
|
"""Launch the server process with Colab-specific settings."""
|
||||||
@@ -393,10 +328,16 @@ class EmbeddingServerManager:
|
|||||||
self.server_port = port
|
self.server_port = port
|
||||||
logger.info(f"Colab server process started with PID: {self.server_process.pid}")
|
logger.info(f"Colab server process started with PID: {self.server_process.pid}")
|
||||||
|
|
||||||
# Register atexit callback
|
# Register atexit callback (unified)
|
||||||
if not self._atexit_registered:
|
if not self._atexit_registered:
|
||||||
atexit.register(lambda: self.stop_server() if self.server_process else None)
|
atexit.register(self._finalize_process)
|
||||||
self._atexit_registered = True
|
self._atexit_registered = True
|
||||||
|
# Record config for in-process reuse is best-effort in Colab mode
|
||||||
|
self._server_config = {
|
||||||
|
"model_name": "",
|
||||||
|
"passages_file": "",
|
||||||
|
"embedding_mode": "sentence-transformers",
|
||||||
|
}
|
||||||
|
|
||||||
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
|
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
|
||||||
"""Wait for the server to be ready with Colab-specific timeout."""
|
"""Wait for the server to be ready with Colab-specific timeout."""
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Literal, Union
|
from typing import Any, Literal, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -35,7 +35,7 @@ class LeannBackendSearcherInterface(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _ensure_server_running(
|
def _ensure_server_running(
|
||||||
self, passages_source_file: str, port: Union[int, None], **kwargs
|
self, passages_source_file: str, port: Optional[int], **kwargs
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Ensure server is running"""
|
"""Ensure server is running"""
|
||||||
pass
|
pass
|
||||||
@@ -50,7 +50,7 @@ class LeannBackendSearcherInterface(ABC):
|
|||||||
prune_ratio: float = 0.0,
|
prune_ratio: float = 0.0,
|
||||||
recompute_embeddings: bool = False,
|
recompute_embeddings: bool = False,
|
||||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
zmq_port: Union[int, None] = None,
|
zmq_port: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Search for nearest neighbors
|
"""Search for nearest neighbors
|
||||||
@@ -76,7 +76,7 @@ class LeannBackendSearcherInterface(ABC):
|
|||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
use_server_if_available: bool = True,
|
use_server_if_available: bool = True,
|
||||||
zmq_port: Union[int, None] = None,
|
zmq_port: Optional[int] = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Compute embedding for a query string
|
"""Compute embedding for a query string
|
||||||
|
|
||||||
|
|||||||
@@ -116,7 +116,6 @@ def handle_request(request):
|
|||||||
f"--top-k={args.get('top_k', 5)}",
|
f"--top-k={args.get('top_k', 5)}",
|
||||||
f"--complexity={args.get('complexity', 32)}",
|
f"--complexity={args.get('complexity', 32)}",
|
||||||
]
|
]
|
||||||
|
|
||||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||||
|
|
||||||
elif tool_name == "leann_status":
|
elif tool_name == "leann_status":
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ dependencies = [
|
|||||||
"mlx>=0.26.3; sys_platform == 'darwin' and platform_machine == 'arm64'",
|
"mlx>=0.26.3; sys_platform == 'darwin' and platform_machine == 'arm64'",
|
||||||
"mlx-lm>=0.26.0; sys_platform == 'darwin' and platform_machine == 'arm64'",
|
"mlx-lm>=0.26.0; sys_platform == 'darwin' and platform_machine == 'arm64'",
|
||||||
"psutil>=5.8.0",
|
"psutil>=5.8.0",
|
||||||
|
"pybind11>=3.0.0",
|
||||||
"pathspec>=0.12.1",
|
"pathspec>=0.12.1",
|
||||||
"nbconvert>=7.16.6",
|
"nbconvert>=7.16.6",
|
||||||
"gitignore-parser>=0.1.12",
|
"gitignore-parser>=0.1.12",
|
||||||
@@ -54,7 +55,7 @@ dev = [
|
|||||||
"pytest-cov>=4.0",
|
"pytest-cov>=4.0",
|
||||||
"pytest-xdist>=3.0", # For parallel test execution
|
"pytest-xdist>=3.0", # For parallel test execution
|
||||||
"black>=23.0",
|
"black>=23.0",
|
||||||
"ruff>=0.1.0",
|
"ruff==0.12.7", # Fixed version to ensure consistent formatting across all environments
|
||||||
"matplotlib",
|
"matplotlib",
|
||||||
"huggingface-hub>=0.20.0",
|
"huggingface-hub>=0.20.0",
|
||||||
"pre-commit>=3.5.0",
|
"pre-commit>=3.5.0",
|
||||||
@@ -154,7 +155,7 @@ markers = [
|
|||||||
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
|
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
|
||||||
"openai: marks tests that require OpenAI API key",
|
"openai: marks tests that require OpenAI API key",
|
||||||
]
|
]
|
||||||
timeout = 600
|
timeout = 300 # Reduced from 600s (10min) to 300s (5min) for CI safety
|
||||||
addopts = [
|
addopts = [
|
||||||
"-v",
|
"-v",
|
||||||
"--tb=short",
|
"--tb=short",
|
||||||
|
|||||||
@@ -6,10 +6,11 @@ This directory contains automated tests for the LEANN project using pytest.
|
|||||||
|
|
||||||
### `test_readme_examples.py`
|
### `test_readme_examples.py`
|
||||||
Tests the examples shown in README.md:
|
Tests the examples shown in README.md:
|
||||||
- The basic example code that users see first
|
- The basic example code that users see first (parametrized for both HNSW and DiskANN backends)
|
||||||
- Import statements work correctly
|
- Import statements work correctly
|
||||||
- Different backend options (HNSW, DiskANN)
|
- Different backend options (HNSW, DiskANN)
|
||||||
- Different LLM configuration options
|
- Different LLM configuration options (parametrized for both backends)
|
||||||
|
- **All main README examples are tested with both HNSW and DiskANN backends using pytest parametrization**
|
||||||
|
|
||||||
### `test_basic.py`
|
### `test_basic.py`
|
||||||
Basic functionality tests that verify:
|
Basic functionality tests that verify:
|
||||||
@@ -25,6 +26,16 @@ Tests the document RAG example functionality:
|
|||||||
- Tests error handling with invalid parameters
|
- Tests error handling with invalid parameters
|
||||||
- Verifies that normalized embeddings are detected and cosine distance is used
|
- Verifies that normalized embeddings are detected and cosine distance is used
|
||||||
|
|
||||||
|
### `test_diskann_partition.py`
|
||||||
|
Tests DiskANN graph partitioning functionality:
|
||||||
|
- Tests DiskANN index building without partitioning (baseline)
|
||||||
|
- Tests automatic graph partitioning with `is_recompute=True`
|
||||||
|
- Verifies that partition files are created and large files are cleaned up for storage saving
|
||||||
|
- Tests search functionality with partitioned indices
|
||||||
|
- Validates medoid and max_base_norm file generation and usage
|
||||||
|
- Includes performance comparison between DiskANN (with partition) and HNSW
|
||||||
|
- **Note**: These tests are skipped in CI due to hardware requirements and computation time
|
||||||
|
|
||||||
## Running Tests
|
## Running Tests
|
||||||
|
|
||||||
### Install test dependencies:
|
### Install test dependencies:
|
||||||
@@ -54,15 +65,23 @@ pytest tests/ -m "not openai"
|
|||||||
|
|
||||||
# Skip slow tests
|
# Skip slow tests
|
||||||
pytest tests/ -m "not slow"
|
pytest tests/ -m "not slow"
|
||||||
|
|
||||||
|
# Run DiskANN partition tests (requires local machine, not CI)
|
||||||
|
pytest tests/test_diskann_partition.py
|
||||||
```
|
```
|
||||||
|
|
||||||
### Run with specific backend:
|
### Run with specific backend:
|
||||||
```bash
|
```bash
|
||||||
# Test only HNSW backend
|
# Test only HNSW backend
|
||||||
pytest tests/test_basic.py::test_backend_basic[hnsw]
|
pytest tests/test_basic.py::test_backend_basic[hnsw]
|
||||||
|
pytest tests/test_readme_examples.py::test_readme_basic_example[hnsw]
|
||||||
|
|
||||||
# Test only DiskANN backend
|
# Test only DiskANN backend
|
||||||
pytest tests/test_basic.py::test_backend_basic[diskann]
|
pytest tests/test_basic.py::test_backend_basic[diskann]
|
||||||
|
pytest tests/test_readme_examples.py::test_readme_basic_example[diskann]
|
||||||
|
|
||||||
|
# All DiskANN tests (parametrized + specialized partition tests)
|
||||||
|
pytest tests/ -k diskann
|
||||||
```
|
```
|
||||||
|
|
||||||
## CI/CD Integration
|
## CI/CD Integration
|
||||||
|
|||||||
@@ -64,6 +64,9 @@ def test_backend_basic(backend_name):
|
|||||||
assert isinstance(results[0], SearchResult)
|
assert isinstance(results[0], SearchResult)
|
||||||
assert "topic 2" in results[0].text or "document" in results[0].text
|
assert "topic 2" in results[0].text or "document" in results[0].text
|
||||||
|
|
||||||
|
# Ensure cleanup to avoid hanging background servers
|
||||||
|
searcher.cleanup()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
os.environ.get("CI") == "true", reason="Skip model tests in CI to avoid MPS memory issues"
|
os.environ.get("CI") == "true", reason="Skip model tests in CI to avoid MPS memory issues"
|
||||||
@@ -90,3 +93,5 @@ def test_large_index():
|
|||||||
searcher = LeannSearcher(index_path)
|
searcher = LeannSearcher(index_path)
|
||||||
results = searcher.search(["word10 word20"], top_k=10)
|
results = searcher.search(["word10 word20"], top_k=10)
|
||||||
assert len(results[0]) == 10
|
assert len(results[0]) == 10
|
||||||
|
# Cleanup
|
||||||
|
searcher.cleanup()
|
||||||
|
|||||||
369
tests/test_diskann_partition.py
Normal file
369
tests/test_diskann_partition.py
Normal file
@@ -0,0 +1,369 @@
|
|||||||
|
"""
|
||||||
|
Test DiskANN graph partitioning functionality.
|
||||||
|
|
||||||
|
Tests the automatic graph partitioning feature that was implemented to save
|
||||||
|
storage space by partitioning large DiskANN indices and safely deleting
|
||||||
|
redundant files while maintaining search functionality.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
os.environ.get("CI") == "true",
|
||||||
|
reason="Skip DiskANN partition tests in CI - requires specific hardware and large memory",
|
||||||
|
)
|
||||||
|
def test_diskann_without_partition():
|
||||||
|
"""Test DiskANN index building without partition (baseline)."""
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
index_path = str(Path(temp_dir) / "test_no_partition.leann")
|
||||||
|
|
||||||
|
# Test data - enough to trigger index building
|
||||||
|
texts = [
|
||||||
|
f"Document {i} discusses topic {i % 10} with detailed analysis of subject {i // 10}."
|
||||||
|
for i in range(500)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Build without partition (is_recompute=False)
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="diskann",
|
||||||
|
embedding_model="facebook/contriever",
|
||||||
|
embedding_mode="sentence-transformers",
|
||||||
|
num_neighbors=32,
|
||||||
|
search_list_size=50,
|
||||||
|
is_recompute=False, # No partition
|
||||||
|
)
|
||||||
|
|
||||||
|
for text in texts:
|
||||||
|
builder.add_text(text)
|
||||||
|
|
||||||
|
builder.build_index(index_path)
|
||||||
|
|
||||||
|
# Verify index was created
|
||||||
|
index_dir = Path(index_path).parent
|
||||||
|
assert index_dir.exists()
|
||||||
|
|
||||||
|
# Check that traditional DiskANN files exist
|
||||||
|
index_prefix = Path(index_path).stem
|
||||||
|
# Core DiskANN files (beam search index may not be created for small datasets)
|
||||||
|
required_files = [
|
||||||
|
f"{index_prefix}_disk.index",
|
||||||
|
f"{index_prefix}_pq_compressed.bin",
|
||||||
|
f"{index_prefix}_pq_pivots.bin",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Check all generated files first for debugging
|
||||||
|
generated_files = [f.name for f in index_dir.glob(f"{index_prefix}*")]
|
||||||
|
print(f"Generated files: {generated_files}")
|
||||||
|
|
||||||
|
for required_file in required_files:
|
||||||
|
file_path = index_dir / required_file
|
||||||
|
assert file_path.exists(), f"Required file {required_file} not found"
|
||||||
|
|
||||||
|
# Ensure no partition files exist in non-partition mode
|
||||||
|
partition_files = [f"{index_prefix}_disk_graph.index", f"{index_prefix}_partition.bin"]
|
||||||
|
|
||||||
|
for partition_file in partition_files:
|
||||||
|
file_path = index_dir / partition_file
|
||||||
|
assert not file_path.exists(), (
|
||||||
|
f"Partition file {partition_file} should not exist in non-partition mode"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test search functionality
|
||||||
|
searcher = LeannSearcher(index_path)
|
||||||
|
results = searcher.search("topic 3 analysis", top_k=3)
|
||||||
|
|
||||||
|
assert len(results) > 0
|
||||||
|
assert all(result.score is not None and result.score != float("-inf") for result in results)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
os.environ.get("CI") == "true",
|
||||||
|
reason="Skip DiskANN partition tests in CI - requires specific hardware and large memory",
|
||||||
|
)
|
||||||
|
def test_diskann_with_partition():
|
||||||
|
"""Test DiskANN index building with automatic graph partitioning."""
|
||||||
|
from leann.api import LeannBuilder
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
index_path = str(Path(temp_dir) / "test_with_partition.leann")
|
||||||
|
|
||||||
|
# Test data - enough to trigger partitioning
|
||||||
|
texts = [
|
||||||
|
f"Document {i} explores subject {i % 15} with comprehensive coverage of area {i // 15}."
|
||||||
|
for i in range(500)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Build with partition (is_recompute=True)
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="diskann",
|
||||||
|
embedding_model="facebook/contriever",
|
||||||
|
embedding_mode="sentence-transformers",
|
||||||
|
num_neighbors=32,
|
||||||
|
search_list_size=50,
|
||||||
|
is_recompute=True, # Enable automatic partitioning
|
||||||
|
)
|
||||||
|
|
||||||
|
for text in texts:
|
||||||
|
builder.add_text(text)
|
||||||
|
|
||||||
|
builder.build_index(index_path)
|
||||||
|
|
||||||
|
# Verify index was created
|
||||||
|
index_dir = Path(index_path).parent
|
||||||
|
assert index_dir.exists()
|
||||||
|
|
||||||
|
# Check that partition files exist
|
||||||
|
index_prefix = Path(index_path).stem
|
||||||
|
partition_files = [
|
||||||
|
f"{index_prefix}_disk_graph.index", # Partitioned graph
|
||||||
|
f"{index_prefix}_partition.bin", # Partition metadata
|
||||||
|
f"{index_prefix}_pq_compressed.bin",
|
||||||
|
f"{index_prefix}_pq_pivots.bin",
|
||||||
|
]
|
||||||
|
|
||||||
|
for partition_file in partition_files:
|
||||||
|
file_path = index_dir / partition_file
|
||||||
|
assert file_path.exists(), f"Expected partition file {partition_file} not found"
|
||||||
|
|
||||||
|
# Check that large files were cleaned up (storage saving goal)
|
||||||
|
large_files = [f"{index_prefix}_disk.index", f"{index_prefix}_disk_beam_search.index"]
|
||||||
|
|
||||||
|
for large_file in large_files:
|
||||||
|
file_path = index_dir / large_file
|
||||||
|
assert not file_path.exists(), (
|
||||||
|
f"Large file {large_file} should have been deleted for storage saving"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify required auxiliary files for partition mode exist
|
||||||
|
required_files = [
|
||||||
|
f"{index_prefix}_disk.index_medoids.bin",
|
||||||
|
f"{index_prefix}_disk.index_max_base_norm.bin",
|
||||||
|
]
|
||||||
|
|
||||||
|
for req_file in required_files:
|
||||||
|
file_path = index_dir / req_file
|
||||||
|
assert file_path.exists(), (
|
||||||
|
f"Required auxiliary file {req_file} missing for partition mode"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
os.environ.get("CI") == "true",
|
||||||
|
reason="Skip DiskANN partition tests in CI - requires specific hardware and large memory",
|
||||||
|
)
|
||||||
|
def test_diskann_partition_search_functionality():
|
||||||
|
"""Test that search works correctly with partitioned indices."""
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
index_path = str(Path(temp_dir) / "test_partition_search.leann")
|
||||||
|
|
||||||
|
# Create diverse test data
|
||||||
|
texts = [
|
||||||
|
"LEANN is a storage-efficient approximate nearest neighbor search system.",
|
||||||
|
"Graph partitioning helps reduce memory usage in large scale vector search.",
|
||||||
|
"DiskANN provides high-performance disk-based approximate nearest neighbor search.",
|
||||||
|
"Vector embeddings enable semantic search over unstructured text data.",
|
||||||
|
"Approximate nearest neighbor algorithms trade accuracy for speed and storage.",
|
||||||
|
] * 100 # Repeat to get enough data
|
||||||
|
|
||||||
|
# Build with partitioning
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="diskann",
|
||||||
|
embedding_model="facebook/contriever",
|
||||||
|
embedding_mode="sentence-transformers",
|
||||||
|
is_recompute=True, # Enable partitioning
|
||||||
|
)
|
||||||
|
|
||||||
|
for text in texts:
|
||||||
|
builder.add_text(text)
|
||||||
|
|
||||||
|
builder.build_index(index_path)
|
||||||
|
|
||||||
|
# Test search with partitioned index
|
||||||
|
searcher = LeannSearcher(index_path)
|
||||||
|
|
||||||
|
# Test various queries
|
||||||
|
test_queries = [
|
||||||
|
("vector search algorithms", 5),
|
||||||
|
("LEANN storage efficiency", 3),
|
||||||
|
("graph partitioning memory", 4),
|
||||||
|
("approximate nearest neighbor", 7),
|
||||||
|
]
|
||||||
|
|
||||||
|
for query, top_k in test_queries:
|
||||||
|
results = searcher.search(query, top_k=top_k)
|
||||||
|
|
||||||
|
# Verify search results
|
||||||
|
assert len(results) == top_k, f"Expected {top_k} results for query '{query}'"
|
||||||
|
assert all(result.score is not None for result in results), (
|
||||||
|
"All results should have scores"
|
||||||
|
)
|
||||||
|
assert all(result.score != float("-inf") for result in results), (
|
||||||
|
"No result should have -inf score"
|
||||||
|
)
|
||||||
|
assert all(result.text is not None for result in results), (
|
||||||
|
"All results should have text"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Scores should be in descending order (higher similarity first)
|
||||||
|
scores = [result.score for result in results]
|
||||||
|
assert scores == sorted(scores, reverse=True), (
|
||||||
|
"Results should be sorted by score descending"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
os.environ.get("CI") == "true",
|
||||||
|
reason="Skip DiskANN partition tests in CI - requires specific hardware and large memory",
|
||||||
|
)
|
||||||
|
def test_diskann_medoid_and_norm_files():
|
||||||
|
"""Test that medoid and max_base_norm files are correctly generated and used."""
|
||||||
|
import struct
|
||||||
|
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
index_path = str(Path(temp_dir) / "test_medoid_norm.leann")
|
||||||
|
|
||||||
|
# Small but sufficient dataset
|
||||||
|
texts = [f"Test document {i} with content about subject {i % 10}." for i in range(200)]
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="diskann",
|
||||||
|
embedding_model="facebook/contriever",
|
||||||
|
embedding_mode="sentence-transformers",
|
||||||
|
is_recompute=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
for text in texts:
|
||||||
|
builder.add_text(text)
|
||||||
|
|
||||||
|
builder.build_index(index_path)
|
||||||
|
|
||||||
|
index_dir = Path(index_path).parent
|
||||||
|
index_prefix = Path(index_path).stem
|
||||||
|
|
||||||
|
# Test medoids file
|
||||||
|
medoids_file = index_dir / f"{index_prefix}_disk.index_medoids.bin"
|
||||||
|
assert medoids_file.exists(), "Medoids file should be generated"
|
||||||
|
|
||||||
|
# Read and validate medoids file format
|
||||||
|
with open(medoids_file, "rb") as f:
|
||||||
|
nshards = struct.unpack("<I", f.read(4))[0]
|
||||||
|
one_val = struct.unpack("<I", f.read(4))[0]
|
||||||
|
medoid_id = struct.unpack("<I", f.read(4))[0]
|
||||||
|
|
||||||
|
assert nshards == 1, "Single-shot build should have 1 shard"
|
||||||
|
assert one_val == 1, "Expected value should be 1"
|
||||||
|
assert medoid_id >= 0, "Medoid ID should be valid (not hardcoded 0)"
|
||||||
|
|
||||||
|
# Test max_base_norm file
|
||||||
|
norm_file = index_dir / f"{index_prefix}_disk.index_max_base_norm.bin"
|
||||||
|
assert norm_file.exists(), "Max base norm file should be generated"
|
||||||
|
|
||||||
|
# Read and validate norm file
|
||||||
|
with open(norm_file, "rb") as f:
|
||||||
|
npts = struct.unpack("<I", f.read(4))[0]
|
||||||
|
ndims = struct.unpack("<I", f.read(4))[0]
|
||||||
|
norm_val = struct.unpack("<f", f.read(4))[0]
|
||||||
|
|
||||||
|
assert npts == 1, "Should have 1 norm point"
|
||||||
|
assert ndims == 1, "Should have 1 dimension"
|
||||||
|
assert norm_val > 0, "Norm value should be positive"
|
||||||
|
assert norm_val != float("inf"), "Norm value should be finite"
|
||||||
|
|
||||||
|
# Test that search works with these files
|
||||||
|
searcher = LeannSearcher(index_path)
|
||||||
|
results = searcher.search("test subject", top_k=3)
|
||||||
|
|
||||||
|
# Verify that scores are not -inf (which indicates norm file was loaded correctly)
|
||||||
|
assert len(results) > 0
|
||||||
|
assert all(result.score != float("-inf") for result in results), (
|
||||||
|
"Scores should not be -inf when norm file is correct"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
os.environ.get("CI") == "true",
|
||||||
|
reason="Skip performance comparison in CI - requires significant compute time",
|
||||||
|
)
|
||||||
|
def test_diskann_vs_hnsw_performance():
|
||||||
|
"""Compare DiskANN (with partition) vs HNSW performance."""
|
||||||
|
import time
|
||||||
|
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
# Test data
|
||||||
|
texts = [
|
||||||
|
f"Performance test document {i} covering topic {i % 20} in detail." for i in range(1000)
|
||||||
|
]
|
||||||
|
query = "performance topic test"
|
||||||
|
|
||||||
|
# Test DiskANN with partitioning
|
||||||
|
diskann_path = str(Path(temp_dir) / "perf_diskann.leann")
|
||||||
|
diskann_builder = LeannBuilder(
|
||||||
|
backend_name="diskann",
|
||||||
|
embedding_model="facebook/contriever",
|
||||||
|
embedding_mode="sentence-transformers",
|
||||||
|
is_recompute=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
for text in texts:
|
||||||
|
diskann_builder.add_text(text)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
diskann_builder.build_index(diskann_path)
|
||||||
|
|
||||||
|
# Test HNSW
|
||||||
|
hnsw_path = str(Path(temp_dir) / "perf_hnsw.leann")
|
||||||
|
hnsw_builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="facebook/contriever",
|
||||||
|
embedding_mode="sentence-transformers",
|
||||||
|
is_recompute=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
for text in texts:
|
||||||
|
hnsw_builder.add_text(text)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
hnsw_builder.build_index(hnsw_path)
|
||||||
|
|
||||||
|
# Compare search performance
|
||||||
|
diskann_searcher = LeannSearcher(diskann_path)
|
||||||
|
hnsw_searcher = LeannSearcher(hnsw_path)
|
||||||
|
|
||||||
|
# Warm up searches
|
||||||
|
diskann_searcher.search(query, top_k=5)
|
||||||
|
hnsw_searcher.search(query, top_k=5)
|
||||||
|
|
||||||
|
# Timed searches
|
||||||
|
start_time = time.time()
|
||||||
|
diskann_results = diskann_searcher.search(query, top_k=10)
|
||||||
|
diskann_search_time = time.time() - start_time
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
hnsw_results = hnsw_searcher.search(query, top_k=10)
|
||||||
|
hnsw_search_time = time.time() - start_time
|
||||||
|
|
||||||
|
# Basic assertions
|
||||||
|
assert len(diskann_results) == 10
|
||||||
|
assert len(hnsw_results) == 10
|
||||||
|
assert all(r.score != float("-inf") for r in diskann_results)
|
||||||
|
assert all(r.score != float("-inf") for r in hnsw_results)
|
||||||
|
|
||||||
|
# Performance ratio (informational)
|
||||||
|
if hnsw_search_time > 0:
|
||||||
|
speed_ratio = hnsw_search_time / diskann_search_time
|
||||||
|
print(f"DiskANN search time: {diskann_search_time:.4f}s")
|
||||||
|
print(f"HNSW search time: {hnsw_search_time:.4f}s")
|
||||||
|
print(f"DiskANN is {speed_ratio:.2f}x faster than HNSW")
|
||||||
@@ -58,6 +58,9 @@ def test_document_rag_simulated(test_data_dir):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OpenAI API key not available")
|
@pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OpenAI API key not available")
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
os.environ.get("CI") == "true", reason="Skip OpenAI tests in CI to avoid API costs"
|
||||||
|
)
|
||||||
def test_document_rag_openai(test_data_dir):
|
def test_document_rag_openai(test_data_dir):
|
||||||
"""Test document_rag with OpenAI embeddings."""
|
"""Test document_rag with OpenAI embeddings."""
|
||||||
with tempfile.TemporaryDirectory() as temp_dir:
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
|||||||
@@ -10,29 +10,33 @@ from pathlib import Path
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
def test_readme_basic_example():
|
@pytest.mark.parametrize("backend_name", ["hnsw", "diskann"])
|
||||||
"""Test the basic example from README.md."""
|
def test_readme_basic_example(backend_name):
|
||||||
|
"""Test the basic example from README.md with both backends."""
|
||||||
# Skip on macOS CI due to MPS environment issues with all-MiniLM-L6-v2
|
# Skip on macOS CI due to MPS environment issues with all-MiniLM-L6-v2
|
||||||
if os.environ.get("CI") == "true" and platform.system() == "Darwin":
|
if os.environ.get("CI") == "true" and platform.system() == "Darwin":
|
||||||
pytest.skip("Skipping on macOS CI due to MPS environment issues with all-MiniLM-L6-v2")
|
pytest.skip("Skipping on macOS CI due to MPS environment issues with all-MiniLM-L6-v2")
|
||||||
|
# Skip DiskANN on CI (Linux runners) due to C++ extension memory/hardware constraints
|
||||||
|
if os.environ.get("CI") == "true" and backend_name == "diskann":
|
||||||
|
pytest.skip("Skip DiskANN tests in CI due to resource constraints and instability")
|
||||||
|
|
||||||
# This is the exact code from README (with smaller model for CI)
|
# This is the exact code from README (with smaller model for CI)
|
||||||
from leann import LeannBuilder, LeannChat, LeannSearcher
|
from leann import LeannBuilder, LeannChat, LeannSearcher
|
||||||
from leann.api import SearchResult
|
from leann.api import SearchResult
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as temp_dir:
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
INDEX_PATH = str(Path(temp_dir) / "demo.leann")
|
INDEX_PATH = str(Path(temp_dir) / f"demo_{backend_name}.leann")
|
||||||
|
|
||||||
# Build an index
|
# Build an index
|
||||||
# In CI, use a smaller model to avoid memory issues
|
# In CI, use a smaller model to avoid memory issues
|
||||||
if os.environ.get("CI") == "true":
|
if os.environ.get("CI") == "true":
|
||||||
builder = LeannBuilder(
|
builder = LeannBuilder(
|
||||||
backend_name="hnsw",
|
backend_name=backend_name,
|
||||||
embedding_model="sentence-transformers/all-MiniLM-L6-v2", # Smaller model
|
embedding_model="sentence-transformers/all-MiniLM-L6-v2", # Smaller model
|
||||||
dimensions=384, # Smaller dimensions
|
dimensions=384, # Smaller dimensions
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
builder = LeannBuilder(backend_name="hnsw")
|
builder = LeannBuilder(backend_name=backend_name)
|
||||||
builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
|
builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
|
||||||
builder.add_text("Tung Tung Tung Sahur called—they need their banana-crocodile hybrid back")
|
builder.add_text("Tung Tung Tung Sahur called—they need their banana-crocodile hybrid back")
|
||||||
builder.build_index(INDEX_PATH)
|
builder.build_index(INDEX_PATH)
|
||||||
@@ -52,9 +56,15 @@ def test_readme_basic_example():
|
|||||||
# Verify search results
|
# Verify search results
|
||||||
assert len(results) > 0
|
assert len(results) > 0
|
||||||
assert isinstance(results[0], SearchResult)
|
assert isinstance(results[0], SearchResult)
|
||||||
|
assert results[0].score != float("-inf"), (
|
||||||
|
f"should return valid scores, got {results[0].score}"
|
||||||
|
)
|
||||||
# The second text about banana-crocodile should be more relevant
|
# The second text about banana-crocodile should be more relevant
|
||||||
assert "banana" in results[0].text or "crocodile" in results[0].text
|
assert "banana" in results[0].text or "crocodile" in results[0].text
|
||||||
|
|
||||||
|
# Ensure we cleanup background embedding server
|
||||||
|
searcher.cleanup()
|
||||||
|
|
||||||
# Chat with your data (using simulated LLM to avoid external dependencies)
|
# Chat with your data (using simulated LLM to avoid external dependencies)
|
||||||
chat = LeannChat(INDEX_PATH, llm_config={"type": "simulated"})
|
chat = LeannChat(INDEX_PATH, llm_config={"type": "simulated"})
|
||||||
response = chat.ask("How much storage does LEANN save?", top_k=1)
|
response = chat.ask("How much storage does LEANN save?", top_k=1)
|
||||||
@@ -62,6 +72,8 @@ def test_readme_basic_example():
|
|||||||
# Verify chat works
|
# Verify chat works
|
||||||
assert isinstance(response, str)
|
assert isinstance(response, str)
|
||||||
assert len(response) > 0
|
assert len(response) > 0
|
||||||
|
# Cleanup chat resources
|
||||||
|
chat.cleanup()
|
||||||
|
|
||||||
|
|
||||||
def test_readme_imports():
|
def test_readme_imports():
|
||||||
@@ -110,26 +122,31 @@ def test_backend_options():
|
|||||||
assert len(list(Path(diskann_path).parent.glob(f"{Path(diskann_path).stem}.*"))) > 0
|
assert len(list(Path(diskann_path).parent.glob(f"{Path(diskann_path).stem}.*"))) > 0
|
||||||
|
|
||||||
|
|
||||||
def test_llm_config_simulated():
|
@pytest.mark.parametrize("backend_name", ["hnsw", "diskann"])
|
||||||
"""Test simulated LLM configuration option."""
|
def test_llm_config_simulated(backend_name):
|
||||||
|
"""Test simulated LLM configuration option with both backends."""
|
||||||
# Skip on macOS CI due to MPS environment issues with all-MiniLM-L6-v2
|
# Skip on macOS CI due to MPS environment issues with all-MiniLM-L6-v2
|
||||||
if os.environ.get("CI") == "true" and platform.system() == "Darwin":
|
if os.environ.get("CI") == "true" and platform.system() == "Darwin":
|
||||||
pytest.skip("Skipping on macOS CI due to MPS environment issues with all-MiniLM-L6-v2")
|
pytest.skip("Skipping on macOS CI due to MPS environment issues with all-MiniLM-L6-v2")
|
||||||
|
|
||||||
|
# Skip DiskANN tests in CI due to hardware requirements
|
||||||
|
if os.environ.get("CI") == "true" and backend_name == "diskann":
|
||||||
|
pytest.skip("Skip DiskANN tests in CI - requires specific hardware and large memory")
|
||||||
|
|
||||||
from leann import LeannBuilder, LeannChat
|
from leann import LeannBuilder, LeannChat
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as temp_dir:
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
# Build a simple index
|
# Build a simple index
|
||||||
index_path = str(Path(temp_dir) / "test.leann")
|
index_path = str(Path(temp_dir) / f"test_{backend_name}.leann")
|
||||||
# Use smaller model in CI to avoid memory issues
|
# Use smaller model in CI to avoid memory issues
|
||||||
if os.environ.get("CI") == "true":
|
if os.environ.get("CI") == "true":
|
||||||
builder = LeannBuilder(
|
builder = LeannBuilder(
|
||||||
backend_name="hnsw",
|
backend_name=backend_name,
|
||||||
embedding_model="sentence-transformers/all-MiniLM-L6-v2",
|
embedding_model="sentence-transformers/all-MiniLM-L6-v2",
|
||||||
dimensions=384,
|
dimensions=384,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
builder = LeannBuilder(backend_name="hnsw")
|
builder = LeannBuilder(backend_name=backend_name)
|
||||||
builder.add_text("Test document for LLM testing")
|
builder.add_text("Test document for LLM testing")
|
||||||
builder.build_index(index_path)
|
builder.build_index(index_path)
|
||||||
|
|
||||||
|
|||||||
77
uv.lock
generated
77
uv.lock
generated
@@ -2223,7 +2223,7 @@ wheels = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "leann-backend-diskann"
|
name = "leann-backend-diskann"
|
||||||
version = "0.2.6"
|
version = "0.2.8"
|
||||||
source = { editable = "packages/leann-backend-diskann" }
|
source = { editable = "packages/leann-backend-diskann" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "leann-core" },
|
{ name = "leann-core" },
|
||||||
@@ -2235,14 +2235,14 @@ dependencies = [
|
|||||||
|
|
||||||
[package.metadata]
|
[package.metadata]
|
||||||
requires-dist = [
|
requires-dist = [
|
||||||
{ name = "leann-core", specifier = "==0.2.6" },
|
{ name = "leann-core", specifier = "==0.2.8" },
|
||||||
{ name = "numpy" },
|
{ name = "numpy" },
|
||||||
{ name = "protobuf", specifier = ">=3.19.0" },
|
{ name = "protobuf", specifier = ">=3.19.0" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "leann-backend-hnsw"
|
name = "leann-backend-hnsw"
|
||||||
version = "0.2.6"
|
version = "0.2.8"
|
||||||
source = { editable = "packages/leann-backend-hnsw" }
|
source = { editable = "packages/leann-backend-hnsw" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "leann-core" },
|
{ name = "leann-core" },
|
||||||
@@ -2255,7 +2255,7 @@ dependencies = [
|
|||||||
|
|
||||||
[package.metadata]
|
[package.metadata]
|
||||||
requires-dist = [
|
requires-dist = [
|
||||||
{ name = "leann-core", specifier = "==0.2.6" },
|
{ name = "leann-core", specifier = "==0.2.8" },
|
||||||
{ name = "msgpack", specifier = ">=1.0.0" },
|
{ name = "msgpack", specifier = ">=1.0.0" },
|
||||||
{ name = "numpy" },
|
{ name = "numpy" },
|
||||||
{ name = "pyzmq", specifier = ">=23.0.0" },
|
{ name = "pyzmq", specifier = ">=23.0.0" },
|
||||||
@@ -2263,7 +2263,7 @@ requires-dist = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "leann-core"
|
name = "leann-core"
|
||||||
version = "0.2.6"
|
version = "0.2.8"
|
||||||
source = { editable = "packages/leann-core" }
|
source = { editable = "packages/leann-core" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "accelerate" },
|
{ name = "accelerate" },
|
||||||
@@ -2272,8 +2272,8 @@ dependencies = [
|
|||||||
{ name = "llama-index-core" },
|
{ name = "llama-index-core" },
|
||||||
{ name = "llama-index-embeddings-huggingface" },
|
{ name = "llama-index-embeddings-huggingface" },
|
||||||
{ name = "llama-index-readers-file" },
|
{ name = "llama-index-readers-file" },
|
||||||
{ name = "mlx", marker = "sys_platform == 'darwin'" },
|
{ name = "mlx", marker = "platform_machine == 'arm64' and sys_platform == 'darwin'" },
|
||||||
{ name = "mlx-lm", marker = "sys_platform == 'darwin'" },
|
{ name = "mlx-lm", marker = "platform_machine == 'arm64' and sys_platform == 'darwin'" },
|
||||||
{ name = "msgpack" },
|
{ name = "msgpack" },
|
||||||
{ name = "nbconvert" },
|
{ name = "nbconvert" },
|
||||||
{ name = "numpy", version = "2.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },
|
{ name = "numpy", version = "2.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },
|
||||||
@@ -2302,8 +2302,8 @@ requires-dist = [
|
|||||||
{ name = "llama-index-core", specifier = ">=0.12.0" },
|
{ name = "llama-index-core", specifier = ">=0.12.0" },
|
||||||
{ name = "llama-index-embeddings-huggingface", specifier = ">=0.5.5" },
|
{ name = "llama-index-embeddings-huggingface", specifier = ">=0.5.5" },
|
||||||
{ name = "llama-index-readers-file", specifier = ">=0.4.0" },
|
{ name = "llama-index-readers-file", specifier = ">=0.4.0" },
|
||||||
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = ">=0.26.3" },
|
{ name = "mlx", marker = "platform_machine == 'arm64' and sys_platform == 'darwin'", specifier = ">=0.26.3" },
|
||||||
{ name = "mlx-lm", marker = "sys_platform == 'darwin'", specifier = ">=0.26.0" },
|
{ name = "mlx-lm", marker = "platform_machine == 'arm64' and sys_platform == 'darwin'", specifier = ">=0.26.0" },
|
||||||
{ name = "msgpack", specifier = ">=1.0.0" },
|
{ name = "msgpack", specifier = ">=1.0.0" },
|
||||||
{ name = "nbconvert", specifier = ">=7.0.0" },
|
{ name = "nbconvert", specifier = ">=7.0.0" },
|
||||||
{ name = "numpy", specifier = ">=1.20.0" },
|
{ name = "numpy", specifier = ">=1.20.0" },
|
||||||
@@ -2343,8 +2343,8 @@ dependencies = [
|
|||||||
{ name = "llama-index-embeddings-huggingface" },
|
{ name = "llama-index-embeddings-huggingface" },
|
||||||
{ name = "llama-index-readers-file" },
|
{ name = "llama-index-readers-file" },
|
||||||
{ name = "llama-index-vector-stores-faiss" },
|
{ name = "llama-index-vector-stores-faiss" },
|
||||||
{ name = "mlx", marker = "sys_platform == 'darwin'" },
|
{ name = "mlx", marker = "platform_machine == 'arm64' and sys_platform == 'darwin'" },
|
||||||
{ name = "mlx-lm", marker = "sys_platform == 'darwin'" },
|
{ name = "mlx-lm", marker = "platform_machine == 'arm64' and sys_platform == 'darwin'" },
|
||||||
{ name = "msgpack" },
|
{ name = "msgpack" },
|
||||||
{ name = "nbconvert" },
|
{ name = "nbconvert" },
|
||||||
{ name = "numpy", version = "2.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },
|
{ name = "numpy", version = "2.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },
|
||||||
@@ -2356,6 +2356,7 @@ dependencies = [
|
|||||||
{ name = "pdfplumber" },
|
{ name = "pdfplumber" },
|
||||||
{ name = "protobuf" },
|
{ name = "protobuf" },
|
||||||
{ name = "psutil" },
|
{ name = "psutil" },
|
||||||
|
{ name = "pybind11" },
|
||||||
{ name = "pymupdf" },
|
{ name = "pymupdf" },
|
||||||
{ name = "pypdf2" },
|
{ name = "pypdf2" },
|
||||||
{ name = "pypdfium2" },
|
{ name = "pypdfium2" },
|
||||||
@@ -2424,8 +2425,8 @@ requires-dist = [
|
|||||||
{ name = "llama-index-readers-file", marker = "extra == 'test'", specifier = ">=0.4.0" },
|
{ name = "llama-index-readers-file", marker = "extra == 'test'", specifier = ">=0.4.0" },
|
||||||
{ name = "llama-index-vector-stores-faiss", specifier = ">=0.4.0" },
|
{ name = "llama-index-vector-stores-faiss", specifier = ">=0.4.0" },
|
||||||
{ name = "matplotlib", marker = "extra == 'dev'" },
|
{ name = "matplotlib", marker = "extra == 'dev'" },
|
||||||
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = ">=0.26.3" },
|
{ name = "mlx", marker = "platform_machine == 'arm64' and sys_platform == 'darwin'", specifier = ">=0.26.3" },
|
||||||
{ name = "mlx-lm", marker = "sys_platform == 'darwin'", specifier = ">=0.26.0" },
|
{ name = "mlx-lm", marker = "platform_machine == 'arm64' and sys_platform == 'darwin'", specifier = ">=0.26.0" },
|
||||||
{ name = "msgpack", specifier = ">=1.1.1" },
|
{ name = "msgpack", specifier = ">=1.1.1" },
|
||||||
{ name = "nbconvert", specifier = ">=7.16.6" },
|
{ name = "nbconvert", specifier = ">=7.16.6" },
|
||||||
{ name = "numpy", specifier = ">=1.26.0" },
|
{ name = "numpy", specifier = ">=1.26.0" },
|
||||||
@@ -2438,6 +2439,7 @@ requires-dist = [
|
|||||||
{ name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.5.0" },
|
{ name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.5.0" },
|
||||||
{ name = "protobuf", specifier = "==4.25.3" },
|
{ name = "protobuf", specifier = "==4.25.3" },
|
||||||
{ name = "psutil", specifier = ">=5.8.0" },
|
{ name = "psutil", specifier = ">=5.8.0" },
|
||||||
|
{ name = "pybind11", specifier = ">=3.0.0" },
|
||||||
{ name = "pymupdf", specifier = ">=1.26.0" },
|
{ name = "pymupdf", specifier = ">=1.26.0" },
|
||||||
{ name = "pypdf2", specifier = ">=3.0.0" },
|
{ name = "pypdf2", specifier = ">=3.0.0" },
|
||||||
{ name = "pypdfium2", specifier = ">=4.30.0" },
|
{ name = "pypdfium2", specifier = ">=4.30.0" },
|
||||||
@@ -2449,7 +2451,7 @@ requires-dist = [
|
|||||||
{ name = "python-docx", marker = "extra == 'documents'", specifier = ">=0.8.11" },
|
{ name = "python-docx", marker = "extra == 'documents'", specifier = ">=0.8.11" },
|
||||||
{ name = "python-dotenv", marker = "extra == 'test'", specifier = ">=1.0.0" },
|
{ name = "python-dotenv", marker = "extra == 'test'", specifier = ">=1.0.0" },
|
||||||
{ name = "requests", specifier = ">=2.25.0" },
|
{ name = "requests", specifier = ">=2.25.0" },
|
||||||
{ name = "ruff", marker = "extra == 'dev'", specifier = ">=0.1.0" },
|
{ name = "ruff", marker = "extra == 'dev'", specifier = "==0.12.7" },
|
||||||
{ name = "sentence-transformers", specifier = ">=2.2.0" },
|
{ name = "sentence-transformers", specifier = ">=2.2.0" },
|
||||||
{ name = "sentence-transformers", marker = "extra == 'test'", specifier = ">=2.2.0" },
|
{ name = "sentence-transformers", marker = "extra == 'test'", specifier = ">=2.2.0" },
|
||||||
{ name = "sglang" },
|
{ name = "sglang" },
|
||||||
@@ -4358,6 +4360,15 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/10/15/6b30e77872012bbfe8265d42a01d5b3c17ef0ac0f2fae531ad91b6a6c02e/pyarrow-21.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:cdc4c17afda4dab2a9c0b79148a43a7f4e1094916b3e18d8975bfd6d6d52241f", size = 26227521 },
|
{ url = "https://files.pythonhosted.org/packages/10/15/6b30e77872012bbfe8265d42a01d5b3c17ef0ac0f2fae531ad91b6a6c02e/pyarrow-21.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:cdc4c17afda4dab2a9c0b79148a43a7f4e1094916b3e18d8975bfd6d6d52241f", size = 26227521 },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pybind11"
|
||||||
|
version = "3.0.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/ef/83/698d120e257a116f2472c710932023ad779409adf2734d2e940f34eea2c5/pybind11-3.0.0.tar.gz", hash = "sha256:c3f07bce3ada51c3e4b76badfa85df11688d12c46111f9d242bc5c9415af7862", size = 544819 }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/41/9c/85f50a5476832c3efc67b6d7997808388236ae4754bf53e1749b3bc27577/pybind11-3.0.0-py3-none-any.whl", hash = "sha256:7c5cac504da5a701b5163f0e6a7ba736c713a096a5378383c5b4b064b753f607", size = 292118 },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pycparser"
|
name = "pycparser"
|
||||||
version = "2.22"
|
version = "2.22"
|
||||||
@@ -5204,27 +5215,27 @@ wheels = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ruff"
|
name = "ruff"
|
||||||
version = "0.12.5"
|
version = "0.12.7"
|
||||||
source = { registry = "https://pypi.org/simple" }
|
source = { registry = "https://pypi.org/simple" }
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/30/cd/01015eb5034605fd98d829c5839ec2c6b4582b479707f7c1c2af861e8258/ruff-0.12.5.tar.gz", hash = "sha256:b209db6102b66f13625940b7f8c7d0f18e20039bb7f6101fbdac935c9612057e", size = 5170722 }
|
sdist = { url = "https://files.pythonhosted.org/packages/a1/81/0bd3594fa0f690466e41bd033bdcdf86cba8288345ac77ad4afbe5ec743a/ruff-0.12.7.tar.gz", hash = "sha256:1fc3193f238bc2d7968772c82831a4ff69252f673be371fb49663f0068b7ec71", size = 5197814 }
|
||||||
wheels = [
|
wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/d4/de/ad2f68f0798ff15dd8c0bcc2889558970d9a685b3249565a937cd820ad34/ruff-0.12.5-py3-none-linux_armv6l.whl", hash = "sha256:1de2c887e9dec6cb31fcb9948299de5b2db38144e66403b9660c9548a67abd92", size = 11819133 },
|
{ url = "https://files.pythonhosted.org/packages/e1/d2/6cb35e9c85e7a91e8d22ab32ae07ac39cc34a71f1009a6f9e4a2a019e602/ruff-0.12.7-py3-none-linux_armv6l.whl", hash = "sha256:76e4f31529899b8c434c3c1dede98c4483b89590e15fb49f2d46183801565303", size = 11852189 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/f8/fc/c6b65cd0e7fbe60f17e7ad619dca796aa49fbca34bb9bea5f8faf1ec2643/ruff-0.12.5-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:d1ab65e7d8152f519e7dea4de892317c9da7a108da1c56b6a3c1d5e7cf4c5e9a", size = 12501114 },
|
{ url = "https://files.pythonhosted.org/packages/63/5b/a4136b9921aa84638f1a6be7fb086f8cad0fde538ba76bda3682f2599a2f/ruff-0.12.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:789b7a03e72507c54fb3ba6209e4bb36517b90f1a3569ea17084e3fd295500fb", size = 12519389 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/c5/de/c6bec1dce5ead9f9e6a946ea15e8d698c35f19edc508289d70a577921b30/ruff-0.12.5-py3-none-macosx_11_0_arm64.whl", hash = "sha256:962775ed5b27c7aa3fdc0d8f4d4433deae7659ef99ea20f783d666e77338b8cf", size = 11716873 },
|
{ url = "https://files.pythonhosted.org/packages/a8/c9/3e24a8472484269b6b1821794141f879c54645a111ded4b6f58f9ab0705f/ruff-0.12.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:2e1c2a3b8626339bb6369116e7030a4cf194ea48f49b64bb505732a7fce4f4e3", size = 11743384 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/a1/16/cf372d2ebe91e4eb5b82a2275c3acfa879e0566a7ac94d331ea37b765ac8/ruff-0.12.5-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:73b4cae449597e7195a49eb1cdca89fd9fbb16140c7579899e87f4c85bf82f73", size = 11958829 },
|
{ url = "https://files.pythonhosted.org/packages/26/7c/458dd25deeb3452c43eaee853c0b17a1e84169f8021a26d500ead77964fd/ruff-0.12.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32dec41817623d388e645612ec70d5757a6d9c035f3744a52c7b195a57e03860", size = 11943759 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/25/bf/cd07e8f6a3a6ec746c62556b4c4b79eeb9b0328b362bb8431b7b8afd3856/ruff-0.12.5-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8b13489c3dc50de5e2d40110c0cce371e00186b880842e245186ca862bf9a1ac", size = 11626619 },
|
{ url = "https://files.pythonhosted.org/packages/7f/8b/658798472ef260ca050e400ab96ef7e85c366c39cf3dfbef4d0a46a528b6/ruff-0.12.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:47ef751f722053a5df5fa48d412dbb54d41ab9b17875c6840a58ec63ff0c247c", size = 11654028 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/d8/c9/c2ccb3b8cbb5661ffda6925f81a13edbb786e623876141b04919d1128370/ruff-0.12.5-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f1504fea81461cf4841778b3ef0a078757602a3b3ea4b008feb1308cb3f23e08", size = 13221894 },
|
{ url = "https://files.pythonhosted.org/packages/a8/86/9c2336f13b2a3326d06d39178fd3448dcc7025f82514d1b15816fe42bfe8/ruff-0.12.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a828a5fc25a3efd3e1ff7b241fd392686c9386f20e5ac90aa9234a5faa12c423", size = 13225209 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/6b/58/68a5be2c8e5590ecdad922b2bcd5583af19ba648f7648f95c51c3c1eca81/ruff-0.12.5-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:c7da4129016ae26c32dfcbd5b671fe652b5ab7fc40095d80dcff78175e7eddd4", size = 14163909 },
|
{ url = "https://files.pythonhosted.org/packages/76/69/df73f65f53d6c463b19b6b312fd2391dc36425d926ec237a7ed028a90fc1/ruff-0.12.7-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:5726f59b171111fa6a69d82aef48f00b56598b03a22f0f4170664ff4d8298efb", size = 14182353 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/bd/d1/ef6b19622009ba8386fdb792c0743f709cf917b0b2f1400589cbe4739a33/ruff-0.12.5-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ca972c80f7ebcfd8af75a0f18b17c42d9f1ef203d163669150453f50ca98ab7b", size = 13583652 },
|
{ url = "https://files.pythonhosted.org/packages/58/1e/de6cda406d99fea84b66811c189b5ea139814b98125b052424b55d28a41c/ruff-0.12.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:74e6f5c04c4dd4aba223f4fe6e7104f79e0eebf7d307e4f9b18c18362124bccd", size = 13631555 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/62/e3/1c98c566fe6809a0c83751d825a03727f242cdbe0d142c9e292725585521/ruff-0.12.5-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8dbbf9f25dfb501f4237ae7501d6364b76a01341c6f1b2cd6764fe449124bb2a", size = 12700451 },
|
{ url = "https://files.pythonhosted.org/packages/6f/ae/625d46d5164a6cc9261945a5e89df24457dc8262539ace3ac36c40f0b51e/ruff-0.12.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d0bfe4e77fba61bf2ccadf8cf005d6133e3ce08793bbe870dd1c734f2699a3e", size = 12667556 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/24/ff/96058f6506aac0fbc0d0fc0d60b0d0bd746240a0594657a2d94ad28033ba/ruff-0.12.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c47dea6ae39421851685141ba9734767f960113d51e83fd7bb9958d5be8763a", size = 12937465 },
|
{ url = "https://files.pythonhosted.org/packages/55/bf/9cb1ea5e3066779e42ade8d0cd3d3b0582a5720a814ae1586f85014656b6/ruff-0.12.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06bfb01e1623bf7f59ea749a841da56f8f653d641bfd046edee32ede7ff6c606", size = 12939784 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/eb/d3/68bc5e7ab96c94b3589d1789f2dd6dd4b27b263310019529ac9be1e8f31b/ruff-0.12.5-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:c5076aa0e61e30f848846f0265c873c249d4b558105b221be1828f9f79903dc5", size = 11771136 },
|
{ url = "https://files.pythonhosted.org/packages/55/7f/7ead2663be5627c04be83754c4f3096603bf5e99ed856c7cd29618c691bd/ruff-0.12.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:e41df94a957d50083fd09b916d6e89e497246698c3f3d5c681c8b3e7b9bb4ac8", size = 11771356 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/52/75/7356af30a14584981cabfefcf6106dea98cec9a7af4acb5daaf4b114845f/ruff-0.12.5-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:a5a4c7830dadd3d8c39b1cc85386e2c1e62344f20766be6f173c22fb5f72f293", size = 11601644 },
|
{ url = "https://files.pythonhosted.org/packages/17/40/a95352ea16edf78cd3a938085dccc55df692a4d8ba1b3af7accbe2c806b0/ruff-0.12.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:4000623300563c709458d0ce170c3d0d788c23a058912f28bbadc6f905d67afa", size = 11612124 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/c2/67/91c71d27205871737cae11025ee2b098f512104e26ffd8656fd93d0ada0a/ruff-0.12.5-py3-none-musllinux_1_2_i686.whl", hash = "sha256:46699f73c2b5b137b9dc0fc1a190b43e35b008b398c6066ea1350cce6326adcb", size = 12478068 },
|
{ url = "https://files.pythonhosted.org/packages/4d/74/633b04871c669e23b8917877e812376827c06df866e1677f15abfadc95cb/ruff-0.12.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:69ffe0e5f9b2cf2b8e289a3f8945b402a1b19eff24ec389f45f23c42a3dd6fb5", size = 12479945 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/34/04/b6b00383cf2f48e8e78e14eb258942fdf2a9bf0287fbf5cdd398b749193a/ruff-0.12.5-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:5a655a0a0d396f0f072faafc18ebd59adde8ca85fb848dc1b0d9f024b9c4d3bb", size = 12991537 },
|
{ url = "https://files.pythonhosted.org/packages/be/34/c3ef2d7799c9778b835a76189c6f53c179d3bdebc8c65288c29032e03613/ruff-0.12.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:a07a5c8ffa2611a52732bdc67bf88e243abd84fe2d7f6daef3826b59abbfeda4", size = 12998677 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/3e/b9/053d6445dc7544fb6594785056d8ece61daae7214859ada4a152ad56b6e0/ruff-0.12.5-py3-none-win32.whl", hash = "sha256:dfeb2627c459b0b78ca2bbdc38dd11cc9a0a88bf91db982058b26ce41714ffa9", size = 11751575 },
|
{ url = "https://files.pythonhosted.org/packages/77/ab/aca2e756ad7b09b3d662a41773f3edcbd262872a4fc81f920dc1ffa44541/ruff-0.12.7-py3-none-win32.whl", hash = "sha256:c928f1b2ec59fb77dfdf70e0419408898b63998789cc98197e15f560b9e77f77", size = 11756687 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/bc/0f/ab16e8259493137598b9149734fec2e06fdeda9837e6f634f5c4e35916da/ruff-0.12.5-py3-none-win_amd64.whl", hash = "sha256:ae0d90cf5f49466c954991b9d8b953bd093c32c27608e409ae3564c63c5306a5", size = 12882273 },
|
{ url = "https://files.pythonhosted.org/packages/b4/71/26d45a5042bc71db22ddd8252ca9d01e9ca454f230e2996bb04f16d72799/ruff-0.12.7-py3-none-win_amd64.whl", hash = "sha256:9c18f3d707ee9edf89da76131956aba1270c6348bfee8f6c647de841eac7194f", size = 12912365 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/00/db/c376b0661c24cf770cb8815268190668ec1330eba8374a126ceef8c72d55/ruff-0.12.5-py3-none-win_arm64.whl", hash = "sha256:48cdbfc633de2c5c37d9f090ba3b352d1576b0015bfc3bc98eaf230275b7e805", size = 11951564 },
|
{ url = "https://files.pythonhosted.org/packages/4c/9b/0b8aa09817b63e78d94b4977f18b1fcaead3165a5ee49251c5d5c245bb2d/ruff-0.12.7-py3-none-win_arm64.whl", hash = "sha256:dfce05101dbd11833a0776716d5d1578641b7fddb537fe7fa956ab85d1769b69", size = 11982083 },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|||||||
Reference in New Issue
Block a user