Compare commits

..

1 Commits

Author SHA1 Message Date
Andy Lee
2c6b65d69f fix: detect and report Ollama embedding dimension inconsistency
- Add validation for embedding dimension consistency in Ollama mode
- Provide clear error message with troubleshooting steps when dimensions mismatch
- Fail fast instead of silent fallback to prevent data corruption

Fixes #31
2025-08-11 17:36:44 -07:00
33 changed files with 643 additions and 2356 deletions

View File

@@ -5,7 +5,6 @@ on:
branches: [ main ]
pull_request:
branches: [ main ]
workflow_dispatch:
jobs:
build:

View File

@@ -64,16 +64,6 @@ jobs:
python: '3.12'
- os: macos-14
python: '3.13'
- os: macos-15
python: '3.9'
- os: macos-15
python: '3.10'
- os: macos-15
python: '3.11'
- os: macos-15
python: '3.12'
- os: macos-15
python: '3.13'
- os: macos-13
python: '3.9'
- os: macos-13
@@ -157,14 +147,7 @@ jobs:
# Use system clang for better compatibility
export CC=clang
export CXX=clang++
# Homebrew libraries on each macOS version require matching minimum version
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
export MACOSX_DEPLOYMENT_TARGET=13.0
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
export MACOSX_DEPLOYMENT_TARGET=14.0
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
export MACOSX_DEPLOYMENT_TARGET=15.0
fi
export MACOSX_DEPLOYMENT_TARGET=11.0
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
else
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
@@ -178,14 +161,7 @@ jobs:
export CC=clang
export CXX=clang++
# DiskANN requires macOS 13.3+ for sgesdd_ LAPACK function
# But Homebrew libraries on each macOS version require matching minimum version
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
export MACOSX_DEPLOYMENT_TARGET=13.3
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
export MACOSX_DEPLOYMENT_TARGET=14.0
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
export MACOSX_DEPLOYMENT_TARGET=15.0
fi
export MACOSX_DEPLOYMENT_TARGET=13.3
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
else
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
@@ -221,24 +197,10 @@ jobs:
- name: Repair wheels (macOS)
if: runner.os == 'macOS'
run: |
# Determine deployment target based on runner OS
# Must match the Homebrew libraries for each macOS version
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
HNSW_TARGET="13.0"
DISKANN_TARGET="13.3"
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
HNSW_TARGET="14.0"
DISKANN_TARGET="14.0"
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
HNSW_TARGET="15.0"
DISKANN_TARGET="15.0"
fi
# Repair HNSW wheel
cd packages/leann-backend-hnsw
if [ -d dist ]; then
export MACOSX_DEPLOYMENT_TARGET=$HNSW_TARGET
delocate-wheel -w dist_repaired -v --require-target-macos-version $HNSW_TARGET dist/*.whl
delocate-wheel -w dist_repaired -v dist/*.whl
rm -rf dist
mv dist_repaired dist
fi
@@ -247,8 +209,7 @@ jobs:
# Repair DiskANN wheel
cd packages/leann-backend-diskann
if [ -d dist ]; then
export MACOSX_DEPLOYMENT_TARGET=$DISKANN_TARGET
delocate-wheel -w dist_repaired -v --require-target-macos-version $DISKANN_TARGET dist/*.whl
delocate-wheel -w dist_repaired -v dist/*.whl
rm -rf dist
mv dist_repaired dist
fi
@@ -277,16 +238,19 @@ jobs:
- name: Run tests with pytest
env:
CI: true
CI: true # Mark as CI environment to skip memory-intensive tests
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
HF_HUB_DISABLE_SYMLINKS: 1
TOKENIZERS_PARALLELISM: false
PYTORCH_ENABLE_MPS_FALLBACK: 0
OMP_NUM_THREADS: 1
MKL_NUM_THREADS: 1
PYTORCH_ENABLE_MPS_FALLBACK: 0 # Disable MPS on macOS CI to avoid memory issues
OMP_NUM_THREADS: 1 # Disable OpenMP parallelism to avoid libomp crashes
MKL_NUM_THREADS: 1 # Single thread for MKL operations
run: |
# Activate virtual environment
source .venv/bin/activate || source .venv/Scripts/activate
pytest tests/ -v --tb=short
# Run all tests
pytest tests/
- name: Run sanity checks (optional)
run: |

View File

@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
rev: v4.5.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
@@ -10,7 +10,7 @@ repos:
- id: debug-statements
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.12.7 # Fixed version to match pyproject.toml
rev: v0.2.1
hooks:
- id: ruff
- id: ruff-format

View File

@@ -455,7 +455,7 @@ leann --help
**To make it globally available:**
```bash
# Install the LEANN CLI globally using uv tool
uv tool install leann-core
uv tool install leann
# Now you can use leann from anywhere without activating venv
leann --help
@@ -468,7 +468,7 @@ leann --help
### Usage Examples
```bash
# build from a specific directory, and my_docs is the index name(Here you can also build from multiple dict or multiple files)
# build from a specific directory, and my_docs is the index name
leann build my-docs --docs ./your_documents
# Search your documents
@@ -543,16 +543,12 @@ Options:
- **Dynamic batching:** Efficiently batch embedding computations for GPU utilization
- **Two-level search:** Smart graph traversal that prioritizes promising nodes
**Backends:**
- **HNSW** (default): Ideal for most datasets with maximum storage savings through full recomputation
- **DiskANN**: Advanced option with superior search performance, using PQ-based graph traversal with real-time reranking for the best speed-accuracy trade-off
**Backends:** HNSW (default) for most use cases, with optional DiskANN support for billion-scale datasets.
## Benchmarks
**[DiskANN vs HNSW Performance Comparison →](benchmarks/diskann_vs_hnsw_speed_comparison.py)** - Compare search performance between both backends
**[Simple Example: Compare LEANN vs FAISS →](benchmarks/compare_faiss_vs_leann.py)** - See storage savings in action
**[Simple Example: Compare LEANN vs FAISS →](benchmarks/compare_faiss_vs_leann.py)**
### 📊 Storage Comparison
| System | DPR (2.1M) | Wiki (60M) | Chat (400K) | Email (780K) | Browser (38K) |

View File

@@ -178,9 +178,6 @@ class BaseRAGExample(ABC):
config["host"] = args.llm_host
elif args.llm == "hf":
config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
elif args.llm == "simulated":
# Simulated LLM doesn't need additional configuration
pass
return config

View File

@@ -1,24 +1,9 @@
# 🧪 LEANN Benchmarks & Testing
# 🧪 Leann Sanity Checks
This directory contains performance benchmarks and comprehensive tests for the LEANN system, including backend comparisons and sanity checks across different configurations.
This directory contains comprehensive sanity checks for the Leann system, ensuring all components work correctly across different configurations.
## 📁 Test Files
### `diskann_vs_hnsw_speed_comparison.py`
Performance comparison between DiskANN and HNSW backends:
-**Search latency** comparison with both backends using recompute
-**Index size** and **build time** measurements
-**Score validity** testing (ensures no -inf scores)
-**Configurable dataset sizes** for different scales
```bash
# Quick comparison with 500 docs, 10 queries
python benchmarks/diskann_vs_hnsw_speed_comparison.py
# Large-scale comparison with 2000 docs, 20 queries
python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20
```
### `test_distance_functions.py`
Tests all supported distance functions across DiskANN backend:
-**MIPS** (Maximum Inner Product Search)

View File

@@ -1,268 +0,0 @@
#!/usr/bin/env python3
"""
DiskANN vs HNSW Search Performance Comparison
This benchmark compares search performance between DiskANN and HNSW backends:
- DiskANN: With graph partitioning enabled (is_recompute=True)
- HNSW: With recompute enabled (is_recompute=True)
- Tests performance across different dataset sizes
- Measures search latency, recall, and index size
"""
import gc
import tempfile
import time
from pathlib import Path
from typing import Any
import numpy as np
def create_test_texts(n_docs: int) -> list[str]:
"""Create synthetic test documents for benchmarking."""
np.random.seed(42)
topics = [
"machine learning and artificial intelligence",
"natural language processing and text analysis",
"computer vision and image recognition",
"data science and statistical analysis",
"deep learning and neural networks",
"information retrieval and search engines",
"database systems and data management",
"software engineering and programming",
"cybersecurity and network protection",
"cloud computing and distributed systems",
]
texts = []
for i in range(n_docs):
topic = topics[i % len(topics)]
variation = np.random.randint(1, 100)
text = (
f"This is document {i} about {topic}. Content variation {variation}. "
f"Additional information about {topic} with details and examples. "
f"Technical discussion of {topic} including implementation aspects."
)
texts.append(text)
return texts
def benchmark_backend(
backend_name: str, texts: list[str], test_queries: list[str], backend_kwargs: dict[str, Any]
) -> dict[str, float]:
"""Benchmark a specific backend with the given configuration."""
from leann.api import LeannBuilder, LeannSearcher
print(f"\n🔧 Testing {backend_name.upper()} backend...")
with tempfile.TemporaryDirectory() as temp_dir:
index_path = str(Path(temp_dir) / f"benchmark_{backend_name}.leann")
# Build index
print(f"📦 Building {backend_name} index with {len(texts)} documents...")
start_time = time.time()
builder = LeannBuilder(
backend_name=backend_name,
embedding_model="facebook/contriever",
embedding_mode="sentence-transformers",
**backend_kwargs,
)
for text in texts:
builder.add_text(text)
builder.build_index(index_path)
build_time = time.time() - start_time
# Measure index size
index_dir = Path(index_path).parent
index_files = list(index_dir.glob(f"{Path(index_path).stem}.*"))
total_size = sum(f.stat().st_size for f in index_files if f.is_file())
size_mb = total_size / (1024 * 1024)
print(f" ✅ Build completed in {build_time:.2f}s, index size: {size_mb:.1f}MB")
# Search benchmark
print("🔍 Running search benchmark...")
searcher = LeannSearcher(index_path)
search_times = []
all_results = []
for query in test_queries:
start_time = time.time()
results = searcher.search(query, top_k=5)
search_time = time.time() - start_time
search_times.append(search_time)
all_results.append(results)
avg_search_time = np.mean(search_times) * 1000 # Convert to ms
print(f" ✅ Average search time: {avg_search_time:.1f}ms")
# Check for valid scores (detect -inf issues)
all_scores = [
result.score
for results in all_results
for result in results
if result.score is not None
]
valid_scores = [
score for score in all_scores if score != float("-inf") and score != float("inf")
]
score_validity_rate = len(valid_scores) / len(all_scores) if all_scores else 0
# Clean up
try:
if hasattr(searcher, "__del__"):
searcher.__del__()
del searcher
del builder
gc.collect()
except Exception as e:
print(f"⚠️ Warning: Resource cleanup error: {e}")
return {
"build_time": build_time,
"avg_search_time_ms": avg_search_time,
"index_size_mb": size_mb,
"score_validity_rate": score_validity_rate,
}
def run_comparison(n_docs: int = 500, n_queries: int = 10):
"""Run performance comparison between DiskANN and HNSW."""
print("🚀 Starting DiskANN vs HNSW Performance Comparison")
print(f"📊 Dataset: {n_docs} documents, {n_queries} test queries")
# Create test data
texts = create_test_texts(n_docs)
test_queries = [
"machine learning algorithms",
"natural language processing",
"computer vision techniques",
"data analysis methods",
"neural network architectures",
"database query optimization",
"software development practices",
"security vulnerabilities",
"cloud infrastructure",
"distributed computing",
][:n_queries]
# HNSW benchmark
hnsw_results = benchmark_backend(
backend_name="hnsw",
texts=texts,
test_queries=test_queries,
backend_kwargs={
"is_recompute": True, # Enable recompute for fair comparison
"M": 16,
"efConstruction": 200,
},
)
# DiskANN benchmark
diskann_results = benchmark_backend(
backend_name="diskann",
texts=texts,
test_queries=test_queries,
backend_kwargs={
"is_recompute": True, # Enable graph partitioning
"num_neighbors": 32,
"search_list_size": 50,
},
)
# Performance comparison
print("\n📈 Performance Comparison Results")
print(f"{'=' * 60}")
print(f"{'Metric':<25} {'HNSW':<15} {'DiskANN':<15} {'Speedup':<10}")
print(f"{'-' * 60}")
# Build time comparison
build_speedup = hnsw_results["build_time"] / diskann_results["build_time"]
print(
f"{'Build Time (s)':<25} {hnsw_results['build_time']:<15.2f} {diskann_results['build_time']:<15.2f} {build_speedup:<10.2f}x"
)
# Search time comparison
search_speedup = hnsw_results["avg_search_time_ms"] / diskann_results["avg_search_time_ms"]
print(
f"{'Search Time (ms)':<25} {hnsw_results['avg_search_time_ms']:<15.1f} {diskann_results['avg_search_time_ms']:<15.1f} {search_speedup:<10.2f}x"
)
# Index size comparison
size_ratio = diskann_results["index_size_mb"] / hnsw_results["index_size_mb"]
print(
f"{'Index Size (MB)':<25} {hnsw_results['index_size_mb']:<15.1f} {diskann_results['index_size_mb']:<15.1f} {size_ratio:<10.2f}x"
)
# Score validity
print(
f"{'Score Validity (%)':<25} {hnsw_results['score_validity_rate'] * 100:<15.1f} {diskann_results['score_validity_rate'] * 100:<15.1f}"
)
print(f"{'=' * 60}")
print("\n🎯 Summary:")
if search_speedup > 1:
print(f" DiskANN is {search_speedup:.2f}x faster than HNSW for search")
else:
print(f" HNSW is {1 / search_speedup:.2f}x faster than DiskANN for search")
if size_ratio > 1:
print(f" DiskANN uses {size_ratio:.2f}x more storage than HNSW")
else:
print(f" DiskANN uses {1 / size_ratio:.2f}x less storage than HNSW")
print(
f" Both backends achieved {min(hnsw_results['score_validity_rate'], diskann_results['score_validity_rate']) * 100:.1f}% score validity"
)
if __name__ == "__main__":
import sys
try:
# Handle help request
if len(sys.argv) > 1 and sys.argv[1] in ["-h", "--help", "help"]:
print("DiskANN vs HNSW Performance Comparison")
print("=" * 50)
print(f"Usage: python {sys.argv[0]} [n_docs] [n_queries]")
print()
print("Arguments:")
print(" n_docs Number of documents to index (default: 500)")
print(" n_queries Number of test queries to run (default: 10)")
print()
print("Examples:")
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py")
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 1000")
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20")
sys.exit(0)
# Parse command line arguments
n_docs = int(sys.argv[1]) if len(sys.argv) > 1 else 500
n_queries = int(sys.argv[2]) if len(sys.argv) > 2 else 10
print("DiskANN vs HNSW Performance Comparison")
print("=" * 50)
print(f"Dataset: {n_docs} documents, {n_queries} queries")
print()
run_comparison(n_docs=n_docs, n_queries=n_queries)
except KeyboardInterrupt:
print("\n⚠️ Benchmark interrupted by user")
sys.exit(130)
except Exception as e:
print(f"\n❌ Benchmark failed: {e}")
sys.exit(1)
finally:
# Ensure clean exit
try:
gc.collect()
print("\n🧹 Cleanup completed")
except Exception:
pass
sys.exit(0)

View File

@@ -97,30 +97,16 @@ ollama pull nomic-embed-text
```
### DiskANN
**Best for**: Performance-critical applications and large datasets - **Production-ready with automatic graph partitioning**
**How it works:**
- **Product Quantization (PQ) + Real-time Reranking**: Uses compressed PQ codes for fast graph traversal, then recomputes exact embeddings for final candidates
- **Automatic Graph Partitioning**: When `is_recompute=True`, automatically partitions large indices and safely removes redundant files to save storage
- **Superior Speed-Accuracy Trade-off**: Faster search than HNSW while maintaining high accuracy
**Trade-offs compared to HNSW:**
-**Faster search latency** (typically 2-8x speedup)
-**Better scaling** for large datasets
-**Smart storage management** with automatic partitioning
-**Better graph locality** with `--ldg-times` parameter for SSD optimization
- ⚠️ **Slightly larger index size** due to PQ tables and graph metadata
**Best for**: Large datasets (> 10M vectors, 10GB+ index size) - **⚠️ Beta version, still in active development**
- Uses Product Quantization (PQ) for coarse filtering during graph traversal
- Novel approach: stores only PQ codes, performs rerank with exact computation in final step
- Implements a corner case of double-queue: prunes all neighbors and recomputes at the end
```bash
# Recommended for most use cases
--backend-name diskann --graph-degree 32 --build-complexity 64
# For large-scale deployments
# For billion-scale deployments
--backend-name diskann --graph-degree 64 --build-complexity 128
```
**Performance Benchmark**: Run `python benchmarks/diskann_vs_hnsw_speed_comparison.py` to compare DiskANN and HNSW on your system.
## LLM Selection: Engine and Model Comparison
### LLM Engines
@@ -297,4 +283,3 @@ LEANN's recomputation feature provides exact distance calculations but can be di
- [Lessons Learned Developing LEANN](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
- [LEANN Technical Paper](https://arxiv.org/abs/2506.08276)
- [DiskANN Original Paper](https://papers.nips.cc/paper/2019/file/09853c7fb1d3f8ee67a61b6bf4a7f8e6-Paper.pdf)
- [SSD-based Graph Partitioning](https://github.com/SonglinLife/SSD_BASED_PLAN)

View File

@@ -1,7 +1 @@
from . import diskann_backend as diskann_backend
from . import graph_partition
# Export main classes and functions
from .graph_partition import GraphPartitioner, partition_graph
__all__ = ["GraphPartitioner", "diskann_backend", "graph_partition", "partition_graph"]

View File

@@ -22,11 +22,6 @@ logger = logging.getLogger(__name__)
@contextlib.contextmanager
def suppress_cpp_output_if_needed():
"""Suppress C++ stdout/stderr based on LEANN_LOG_LEVEL"""
# In CI we avoid fiddling with low-level file descriptors to prevent aborts
if os.getenv("CI") == "true":
yield
return
log_level = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
# Only suppress if log level is WARNING or higher (ERROR, CRITICAL)
@@ -142,71 +137,6 @@ class DiskannBuilder(LeannBackendBuilderInterface):
def __init__(self, **kwargs):
self.build_params = kwargs
def _safe_cleanup_after_partition(self, index_dir: Path, index_prefix: str):
"""
Safely cleanup files after partition.
In partition mode, C++ doesn't read _disk.index content,
so we can delete it if all derived files exist.
"""
disk_index_file = index_dir / f"{index_prefix}_disk.index"
beam_search_file = index_dir / f"{index_prefix}_disk_beam_search.index"
# Required files that C++ partition mode needs
# Note: C++ generates these with _disk.index suffix
disk_suffix = "_disk.index"
required_files = [
f"{index_prefix}{disk_suffix}_medoids.bin", # Critical: assert fails if missing
# Note: _centroids.bin is not created in single-shot build - C++ handles this automatically
f"{index_prefix}_pq_pivots.bin", # PQ table
f"{index_prefix}_pq_compressed.bin", # PQ compressed vectors
]
# Check if all required files exist
missing_files = []
for filename in required_files:
file_path = index_dir / filename
if not file_path.exists():
missing_files.append(filename)
if missing_files:
logger.warning(
f"Cannot safely delete _disk.index - missing required files: {missing_files}"
)
logger.info("Keeping all original files for safety")
return
# Calculate space savings
space_saved = 0
files_to_delete = []
if disk_index_file.exists():
space_saved += disk_index_file.stat().st_size
files_to_delete.append(disk_index_file)
if beam_search_file.exists():
space_saved += beam_search_file.stat().st_size
files_to_delete.append(beam_search_file)
# Safe to delete!
for file_to_delete in files_to_delete:
try:
os.remove(file_to_delete)
logger.info(f"✅ Safely deleted: {file_to_delete.name}")
except Exception as e:
logger.warning(f"Failed to delete {file_to_delete.name}: {e}")
if space_saved > 0:
space_saved_mb = space_saved / (1024 * 1024)
logger.info(f"💾 Space saved: {space_saved_mb:.1f} MB")
# Show what files are kept
logger.info("📁 Kept essential files for partition mode:")
for filename in required_files:
file_path = index_dir / filename
if file_path.exists():
size_mb = file_path.stat().st_size / (1024 * 1024)
logger.info(f" - {filename} ({size_mb:.1f} MB)")
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
path = Path(index_path)
index_dir = path.parent
@@ -221,17 +151,6 @@ class DiskannBuilder(LeannBackendBuilderInterface):
_write_vectors_to_bin(data, index_dir / data_filename)
build_kwargs = {**self.build_params, **kwargs}
# Extract is_recompute from nested backend_kwargs if needed
is_recompute = build_kwargs.get("is_recompute", False)
if not is_recompute and "backend_kwargs" in build_kwargs:
is_recompute = build_kwargs["backend_kwargs"].get("is_recompute", False)
# Flatten all backend_kwargs parameters to top level for compatibility
if "backend_kwargs" in build_kwargs:
nested_params = build_kwargs.pop("backend_kwargs")
build_kwargs.update(nested_params)
metric_enum = _get_diskann_metrics().get(
build_kwargs.get("distance_metric", "mips").lower()
)
@@ -266,30 +185,6 @@ class DiskannBuilder(LeannBackendBuilderInterface):
build_kwargs.get("pq_disk_bytes", 0),
"",
)
# Auto-partition if is_recompute is enabled
if build_kwargs.get("is_recompute", False):
logger.info("is_recompute=True, starting automatic graph partitioning...")
from .graph_partition import partition_graph
# Partition the index using absolute paths
# Convert to absolute paths to avoid issues with working directory changes
absolute_index_dir = Path(index_dir).resolve()
absolute_index_prefix_path = str(absolute_index_dir / index_prefix)
disk_graph_path, partition_bin_path = partition_graph(
index_prefix_path=absolute_index_prefix_path,
output_dir=str(absolute_index_dir),
partition_prefix=index_prefix,
)
# Safe cleanup: In partition mode, C++ doesn't read _disk.index content
# but still needs the derived files (_medoids.bin, _centroids.bin, etc.)
self._safe_cleanup_after_partition(index_dir, index_prefix)
logger.info("✅ Graph partitioning completed successfully!")
logger.info(f" - Disk graph: {disk_graph_path}")
logger.info(f" - Partition file: {partition_bin_path}")
finally:
temp_data_file = index_dir / data_filename
if temp_data_file.exists():
@@ -318,26 +213,7 @@ class DiskannSearcher(BaseSearcher):
# For DiskANN, we need to reinitialize the index when zmq_port changes
# Store the initialization parameters for later use
# Note: C++ load method expects the BASE path (without _disk.index suffix)
# C++ internally constructs: index_prefix + "_disk.index"
index_name = self.index_path.stem # "simple_test.leann" -> "simple_test"
diskann_index_prefix = str(self.index_dir / index_name) # /path/to/simple_test
full_index_prefix = diskann_index_prefix # /path/to/simple_test (base path)
# Auto-detect partition files and set partition_prefix
partition_graph_file = self.index_dir / f"{index_name}_disk_graph.index"
partition_bin_file = self.index_dir / f"{index_name}_partition.bin"
partition_prefix = ""
if partition_graph_file.exists() and partition_bin_file.exists():
# C++ expects full path prefix, not just filename
partition_prefix = str(self.index_dir / index_name) # /path/to/simple_test
logger.info(
f"✅ Detected partition files, using partition_prefix='{partition_prefix}'"
)
else:
logger.debug("No partition files detected, using standard index files")
full_index_prefix = str(self.index_dir / self.index_path.stem)
self._init_params = {
"metric_enum": metric_enum,
"full_index_prefix": full_index_prefix,
@@ -345,14 +221,8 @@ class DiskannSearcher(BaseSearcher):
"num_nodes_to_cache": kwargs.get("num_nodes_to_cache", 0),
"cache_mechanism": 1,
"pq_prefix": "",
"partition_prefix": partition_prefix,
"partition_prefix": "",
}
# Log partition configuration for debugging
if partition_prefix:
logger.info(
f"✅ Detected partition files, using partition_prefix='{partition_prefix}'"
)
self._diskannpy = diskannpy
self._current_zmq_port = None
self._index = None

View File

@@ -81,8 +81,7 @@ def create_diskann_embedding_server(
with open(passages_file) as f:
meta = json.load(f)
logger.info(f"Loading PassageManager with metadata_file_path: {passages_file}")
passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file)
passages = PassageManager(meta["passage_sources"])
logger.info(
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
)
@@ -103,9 +102,8 @@ def create_diskann_embedding_server(
socket.bind(f"tcp://*:{zmq_port}")
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
socket.setsockopt(zmq.RCVTIMEO, 1000)
socket.setsockopt(zmq.SNDTIMEO, 1000)
socket.setsockopt(zmq.LINGER, 0)
socket.setsockopt(zmq.RCVTIMEO, 300000)
socket.setsockopt(zmq.SNDTIMEO, 300000)
while True:
try:
@@ -222,217 +220,30 @@ def create_diskann_embedding_server(
traceback.print_exc()
raise
def zmq_server_thread_with_shutdown(shutdown_event):
"""ZMQ server thread that respects shutdown signal.
This creates its own REP socket, binds to zmq_port, and periodically
checks shutdown_event using recv timeouts to exit cleanly.
"""
logger.info("DiskANN ZMQ server thread started with shutdown support")
context = zmq.Context()
rep_socket = context.socket(zmq.REP)
rep_socket.bind(f"tcp://*:{zmq_port}")
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
# Set receive timeout so we can check shutdown_event periodically
rep_socket.setsockopt(zmq.RCVTIMEO, 1000) # 1 second timeout
rep_socket.setsockopt(zmq.SNDTIMEO, 1000)
rep_socket.setsockopt(zmq.LINGER, 0)
try:
while not shutdown_event.is_set():
try:
e2e_start = time.time()
# REP socket receives single-part messages
message = rep_socket.recv()
# Check for empty messages - REP socket requires response to every request
if not message:
logger.warning("Received empty message, sending empty response")
rep_socket.send(b"")
continue
# Try protobuf first (same logic as original)
texts = []
is_text_request = False
try:
req_proto = embedding_pb2.NodeEmbeddingRequest()
req_proto.ParseFromString(message)
node_ids = list(req_proto.node_ids)
# Look up texts by node IDs
for nid in node_ids:
try:
passage_data = passages.get_passage(str(nid))
txt = passage_data["text"]
if not txt:
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
texts.append(txt)
except KeyError:
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
logger.info(f"ZMQ received protobuf request for {len(node_ids)} node IDs")
except Exception:
# Fallback to msgpack for text requests
try:
import msgpack
request = msgpack.unpackb(message)
if isinstance(request, list) and all(
isinstance(item, str) for item in request
):
texts = request
is_text_request = True
logger.info(
f"ZMQ received msgpack text request for {len(texts)} texts"
)
else:
raise ValueError("Not a valid msgpack text request")
except Exception:
logger.error("Both protobuf and msgpack parsing failed!")
# Send error response
resp_proto = embedding_pb2.NodeEmbeddingResponse()
rep_socket.send(resp_proto.SerializeToString())
continue
# Process the request
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
logger.info(f"Computed embeddings shape: {embeddings.shape}")
# Validation
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
logger.error("NaN or Inf detected in embeddings!")
# Send error response
if is_text_request:
import msgpack
response_data = msgpack.packb([])
else:
resp_proto = embedding_pb2.NodeEmbeddingResponse()
response_data = resp_proto.SerializeToString()
rep_socket.send(response_data)
continue
# Prepare response based on request type
if is_text_request:
# For direct text requests, return msgpack
import msgpack
response_data = msgpack.packb(embeddings.tolist())
else:
# For protobuf requests, return protobuf
resp_proto = embedding_pb2.NodeEmbeddingResponse()
hidden_contiguous = np.ascontiguousarray(embeddings, dtype=np.float32)
resp_proto.embeddings_data = hidden_contiguous.tobytes()
resp_proto.dimensions.append(hidden_contiguous.shape[0])
resp_proto.dimensions.append(hidden_contiguous.shape[1])
response_data = resp_proto.SerializeToString()
# Send response back to the client
rep_socket.send(response_data)
e2e_end = time.time()
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
except zmq.Again:
# Timeout - check shutdown_event and continue
continue
except Exception as e:
if not shutdown_event.is_set():
logger.error(f"Error in ZMQ server loop: {e}")
try:
# Send error response for REP socket
resp_proto = embedding_pb2.NodeEmbeddingResponse()
rep_socket.send(resp_proto.SerializeToString())
except Exception:
pass
else:
logger.info("Shutdown in progress, ignoring ZMQ error")
break
finally:
try:
rep_socket.close(0)
except Exception:
pass
try:
context.term()
except Exception:
pass
logger.info("DiskANN ZMQ server thread exiting gracefully")
# Add shutdown coordination
shutdown_event = threading.Event()
def shutdown_zmq_server():
"""Gracefully shutdown ZMQ server."""
logger.info("Initiating graceful shutdown...")
shutdown_event.set()
if zmq_thread.is_alive():
logger.info("Waiting for ZMQ thread to finish...")
zmq_thread.join(timeout=5)
if zmq_thread.is_alive():
logger.warning("ZMQ thread did not finish in time")
# Clean up ZMQ resources
try:
# Note: socket and context are cleaned up by thread exit
logger.info("ZMQ resources cleaned up")
except Exception as e:
logger.warning(f"Error cleaning ZMQ resources: {e}")
# Clean up other resources
try:
import gc
gc.collect()
logger.info("Additional resources cleaned up")
except Exception as e:
logger.warning(f"Error cleaning additional resources: {e}")
logger.info("Graceful shutdown completed")
sys.exit(0)
# Register signal handlers within this function scope
import signal
def signal_handler(sig, frame):
logger.info(f"Received signal {sig}, shutting down gracefully...")
shutdown_zmq_server()
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
# Start ZMQ thread (NOT daemon!)
zmq_thread = threading.Thread(
target=lambda: zmq_server_thread_with_shutdown(shutdown_event),
daemon=False, # Not daemon - we want to wait for it
)
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
zmq_thread.start()
logger.info(f"Started DiskANN ZMQ server thread on port {zmq_port}")
# Keep the main thread alive
try:
while not shutdown_event.is_set():
time.sleep(0.1) # Check shutdown more frequently
while True:
time.sleep(1)
except KeyboardInterrupt:
logger.info("DiskANN Server shutting down...")
shutdown_zmq_server()
return
# If we reach here, shutdown was triggered by signal
logger.info("Main loop exited, process should be shutting down")
if __name__ == "__main__":
import signal
import sys
# Signal handlers are now registered within create_diskann_embedding_server
def signal_handler(sig, frame):
logger.info(f"Received signal {sig}, shutting down gracefully...")
sys.exit(0)
# Register signal handlers for graceful shutdown
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
parser = argparse.ArgumentParser(description="DiskANN Embedding service")
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")

View File

@@ -1,299 +0,0 @@
#!/usr/bin/env python3
"""
Graph Partition Module for LEANN DiskANN Backend
This module provides Python bindings for the graph partition functionality
of DiskANN, allowing users to partition disk-based indices for better
performance.
"""
import os
import shutil
import subprocess
import tempfile
from pathlib import Path
from typing import Optional
class GraphPartitioner:
"""
A Python interface for DiskANN's graph partition functionality.
This class provides methods to partition disk-based indices for improved
search performance and memory efficiency.
"""
def __init__(self, build_type: str = "release"):
"""
Initialize the GraphPartitioner.
Args:
build_type: Build type for the executables ("debug" or "release")
"""
self.build_type = build_type
self._ensure_executables()
def _get_executable_path(self, name: str) -> str:
"""Get the path to a graph partition executable."""
# Get the directory where this Python module is located
module_dir = Path(__file__).parent
# Navigate to the graph_partition directory
graph_partition_dir = module_dir.parent / "third_party" / "DiskANN" / "graph_partition"
executable_path = graph_partition_dir / "build" / self.build_type / "graph_partition" / name
if not executable_path.exists():
raise FileNotFoundError(f"Executable {name} not found at {executable_path}")
return str(executable_path)
def _ensure_executables(self):
"""Ensure that the required executables are built."""
try:
self._get_executable_path("partitioner")
self._get_executable_path("index_relayout")
except FileNotFoundError:
# Try to build the executables automatically
print("Executables not found, attempting to build them...")
self._build_executables()
def _build_executables(self):
"""Build the required executables."""
graph_partition_dir = (
Path(__file__).parent.parent / "third_party" / "DiskANN" / "graph_partition"
)
original_dir = os.getcwd()
try:
os.chdir(graph_partition_dir)
# Clean any existing build
if (graph_partition_dir / "build").exists():
shutil.rmtree(graph_partition_dir / "build")
# Run the build script
cmd = ["./build.sh", self.build_type, "split_graph", "/tmp/dummy"]
subprocess.run(cmd, capture_output=True, text=True, cwd=graph_partition_dir)
# Check if executables were created
partitioner_path = self._get_executable_path("partitioner")
relayout_path = self._get_executable_path("index_relayout")
print(f"✅ Built partitioner: {partitioner_path}")
print(f"✅ Built index_relayout: {relayout_path}")
except Exception as e:
raise RuntimeError(f"Failed to build executables: {e}")
finally:
os.chdir(original_dir)
def partition_graph(
self,
index_prefix_path: str,
output_dir: Optional[str] = None,
partition_prefix: Optional[str] = None,
**kwargs,
) -> tuple[str, str]:
"""
Partition a disk-based index for improved performance.
Args:
index_prefix_path: Path to the index prefix (e.g., "/path/to/index")
output_dir: Output directory for results (defaults to parent of index_prefix_path)
partition_prefix: Prefix for output files (defaults to basename of index_prefix_path)
**kwargs: Additional parameters for graph partitioning:
- gp_times: Number of LDG partition iterations (default: 10)
- lock_nums: Number of lock nodes (default: 10)
- cut: Cut adjacency list degree (default: 100)
- scale_factor: Scale factor (default: 1)
- data_type: Data type (default: "float")
- thread_nums: Number of threads (default: 10)
Returns:
Tuple of (disk_graph_index_path, partition_bin_path)
Raises:
RuntimeError: If the partitioning process fails
"""
# Set default parameters
params = {
"gp_times": 10,
"lock_nums": 10,
"cut": 100,
"scale_factor": 1,
"data_type": "float",
"thread_nums": 10,
**kwargs,
}
# Determine output directory
if output_dir is None:
output_dir = str(Path(index_prefix_path).parent)
# Create output directory if it doesn't exist
Path(output_dir).mkdir(parents=True, exist_ok=True)
# Determine partition prefix
if partition_prefix is None:
partition_prefix = Path(index_prefix_path).name
# Get executable paths
partitioner_path = self._get_executable_path("partitioner")
relayout_path = self._get_executable_path("index_relayout")
# Create temporary directory for processing
with tempfile.TemporaryDirectory() as temp_dir:
# Change to the graph_partition directory for temporary files
graph_partition_dir = (
Path(__file__).parent.parent / "third_party" / "DiskANN" / "graph_partition"
)
original_dir = os.getcwd()
try:
os.chdir(graph_partition_dir)
# Create temporary data directory
temp_data_dir = Path(temp_dir) / "data"
temp_data_dir.mkdir(parents=True, exist_ok=True)
# Set up paths for temporary files
graph_path = temp_data_dir / "starling" / "_M_R_L_B" / "GRAPH"
graph_gp_path = (
graph_path
/ f"GP_TIMES_{params['gp_times']}_LOCK_{params['lock_nums']}_GP_USE_FREQ0_CUT{params['cut']}_SCALE{params['scale_factor']}"
)
graph_gp_path.mkdir(parents=True, exist_ok=True)
# Find input index file
old_index_file = f"{index_prefix_path}_disk_beam_search.index"
if not os.path.exists(old_index_file):
old_index_file = f"{index_prefix_path}_disk.index"
if not os.path.exists(old_index_file):
raise RuntimeError(f"Index file not found: {old_index_file}")
# Run partitioner
gp_file_path = graph_gp_path / "_part.bin"
partitioner_cmd = [
partitioner_path,
"--index_file",
old_index_file,
"--data_type",
params["data_type"],
"--gp_file",
str(gp_file_path),
"-T",
str(params["thread_nums"]),
"--ldg_times",
str(params["gp_times"]),
"--scale",
str(params["scale_factor"]),
"--mode",
"1",
]
print(f"Running partitioner: {' '.join(partitioner_cmd)}")
result = subprocess.run(
partitioner_cmd, capture_output=True, text=True, cwd=graph_partition_dir
)
if result.returncode != 0:
raise RuntimeError(
f"Partitioner failed with return code {result.returncode}.\n"
f"stdout: {result.stdout}\n"
f"stderr: {result.stderr}"
)
# Run relayout
part_tmp_index = graph_gp_path / "_part_tmp.index"
relayout_cmd = [
relayout_path,
old_index_file,
str(gp_file_path),
params["data_type"],
"1",
]
print(f"Running relayout: {' '.join(relayout_cmd)}")
result = subprocess.run(
relayout_cmd, capture_output=True, text=True, cwd=graph_partition_dir
)
if result.returncode != 0:
raise RuntimeError(
f"Relayout failed with return code {result.returncode}.\n"
f"stdout: {result.stdout}\n"
f"stderr: {result.stderr}"
)
# Copy results to output directory
disk_graph_path = Path(output_dir) / f"{partition_prefix}_disk_graph.index"
partition_bin_path = Path(output_dir) / f"{partition_prefix}_partition.bin"
shutil.copy2(part_tmp_index, disk_graph_path)
shutil.copy2(gp_file_path, partition_bin_path)
print(f"Results copied to: {output_dir}")
return str(disk_graph_path), str(partition_bin_path)
finally:
os.chdir(original_dir)
def get_partition_info(self, partition_bin_path: str) -> dict:
"""
Get information about a partition file.
Args:
partition_bin_path: Path to the partition binary file
Returns:
Dictionary containing partition information
"""
if not os.path.exists(partition_bin_path):
raise FileNotFoundError(f"Partition file not found: {partition_bin_path}")
# For now, return basic file information
# In the future, this could parse the binary file for detailed info
stat = os.stat(partition_bin_path)
return {
"file_size": stat.st_size,
"file_path": partition_bin_path,
"modified_time": stat.st_mtime,
}
def partition_graph(
index_prefix_path: str,
output_dir: Optional[str] = None,
partition_prefix: Optional[str] = None,
build_type: str = "release",
**kwargs,
) -> tuple[str, str]:
"""
Convenience function to partition a graph index.
Args:
index_prefix_path: Path to the index prefix
output_dir: Output directory (defaults to parent of index_prefix_path)
partition_prefix: Prefix for output files (defaults to basename of index_prefix_path)
build_type: Build type for executables ("debug" or "release")
**kwargs: Additional parameters for graph partitioning
Returns:
Tuple of (disk_graph_index_path, partition_bin_path)
"""
partitioner = GraphPartitioner(build_type=build_type)
return partitioner.partition_graph(index_prefix_path, output_dir, partition_prefix, **kwargs)
# Example usage:
if __name__ == "__main__":
# Example: partition an index
try:
disk_graph_path, partition_bin_path = partition_graph(
"/path/to/your/index_prefix", gp_times=10, lock_nums=10, cut=100
)
print("Partitioning completed successfully!")
print(f"Disk graph index: {disk_graph_path}")
print(f"Partition binary: {partition_bin_path}")
except Exception as e:
print(f"Partitioning failed: {e}")

View File

@@ -4,8 +4,8 @@ build-backend = "scikit_build_core.build"
[project]
name = "leann-backend-diskann"
version = "0.2.8"
dependencies = ["leann-core==0.2.8", "numpy", "protobuf>=3.19.0"]
version = "0.2.7"
dependencies = ["leann-core==0.2.7", "numpy", "protobuf>=3.19.0"]
[tool.scikit-build]
# Key: simplified CMake path

View File

@@ -13,7 +13,7 @@ if(APPLE)
else()
message(FATAL_ERROR "Could not find libomp installation. Please install with: brew install libomp")
endif()
set(OpenMP_C_FLAGS "-Xpreprocessor -fopenmp -I${HOMEBREW_PREFIX}/opt/libomp/include")
set(OpenMP_CXX_FLAGS "-Xpreprocessor -fopenmp -I${HOMEBREW_PREFIX}/opt/libomp/include")
set(OpenMP_C_LIB_NAMES "omp")

View File

@@ -1,6 +1,5 @@
import argparse
import gc # Import garbage collector interface
import logging
import os
import struct
import sys
@@ -8,12 +7,6 @@ import time
import numpy as np
# Set up logging to avoid print buffer issues
logger = logging.getLogger(__name__)
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
logger.setLevel(log_level)
# --- FourCCs (add more if needed) ---
INDEX_HNSW_FLAT_FOURCC = int.from_bytes(b"IHNf", "little")
# Add other HNSW fourccs if you expect different storage types inside HNSW
@@ -250,8 +243,6 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
output_filename: Output CSR index file
prune_embeddings: Whether to prune embedding storage (write NULL storage marker)
"""
# Keep prints simple; rely on CI runner to flush output as needed
print(f"Starting conversion: {input_filename} -> {output_filename}")
start_time = time.time()
original_hnsw_data = {}

View File

@@ -10,7 +10,7 @@ import sys
import threading
import time
from pathlib import Path
from typing import Optional
from typing import Union
import msgpack
import numpy as np
@@ -34,7 +34,7 @@ if not logger.handlers:
def create_hnsw_embedding_server(
passages_file: Optional[str] = None,
passages_file: Union[str, None] = None,
zmq_port: int = 5555,
model_name: str = "sentence-transformers/all-mpnet-base-v2",
distance_metric: str = "mips",
@@ -82,317 +82,199 @@ def create_hnsw_embedding_server(
with open(passages_file) as f:
meta = json.load(f)
# Let PassageManager handle path resolution uniformly. It supports fallback order:
# 1) path/index_path; 2) *_relative; 3) standard siblings next to meta
passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file)
# Dimension from metadata for shaping responses
try:
embedding_dim: int = int(meta.get("dimensions", 0))
except Exception:
embedding_dim = 0
# Convert relative paths to absolute paths based on metadata file location
metadata_dir = Path(passages_file).parent.parent # Go up one level from the metadata file
passage_sources = []
for source in meta["passage_sources"]:
source_copy = source.copy()
# Convert relative paths to absolute paths
if not Path(source_copy["path"]).is_absolute():
source_copy["path"] = str(metadata_dir / source_copy["path"])
if not Path(source_copy["index_path"]).is_absolute():
source_copy["index_path"] = str(metadata_dir / source_copy["index_path"])
passage_sources.append(source_copy)
passages = PassageManager(passage_sources)
logger.info(
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
)
# (legacy ZMQ thread removed; using shutdown-capable server only)
def zmq_server_thread_with_shutdown(shutdown_event):
"""ZMQ server thread that respects shutdown signal.
Creates its own REP socket bound to zmq_port and polls with timeouts
to allow graceful shutdown.
"""
logger.info("ZMQ server thread started with shutdown support")
def zmq_server_thread():
"""ZMQ server thread"""
context = zmq.Context()
rep_socket = context.socket(zmq.REP)
rep_socket.bind(f"tcp://*:{zmq_port}")
logger.info(f"HNSW ZMQ REP server listening on port {zmq_port}")
rep_socket.setsockopt(zmq.RCVTIMEO, 1000)
# Keep sends from blocking during shutdown; fail fast and drop on close
rep_socket.setsockopt(zmq.SNDTIMEO, 1000)
rep_socket.setsockopt(zmq.LINGER, 0)
socket = context.socket(zmq.REP)
socket.bind(f"tcp://*:{zmq_port}")
logger.info(f"HNSW ZMQ server listening on port {zmq_port}")
# Track last request type/length for shape-correct fallbacks
last_request_type = "unknown" # 'text' | 'distance' | 'embedding' | 'unknown'
last_request_length = 0
socket.setsockopt(zmq.RCVTIMEO, 300000)
socket.setsockopt(zmq.SNDTIMEO, 300000)
try:
while not shutdown_event.is_set():
try:
e2e_start = time.time()
logger.debug("🔍 Waiting for ZMQ message...")
request_bytes = rep_socket.recv()
while True:
try:
message_bytes = socket.recv()
logger.debug(f"Received ZMQ request of size {len(message_bytes)} bytes")
# Rest of the processing logic (same as original)
request = msgpack.unpackb(request_bytes)
e2e_start = time.time()
request_payload = msgpack.unpackb(message_bytes)
if len(request) == 1 and request[0] == "__QUERY_MODEL__":
response_bytes = msgpack.packb([model_name])
rep_socket.send(response_bytes)
continue
# Handle direct text embedding request
if isinstance(request_payload, list) and len(request_payload) > 0:
# Check if this is a direct text request (list of strings)
if all(isinstance(item, str) for item in request_payload):
logger.info(
f"Processing direct text embedding request for {len(request_payload)} texts in {embedding_mode} mode"
)
# Handle direct text embedding request
if (
isinstance(request, list)
and request
and all(isinstance(item, str) for item in request)
):
last_request_type = "text"
last_request_length = len(request)
embeddings = compute_embeddings(request, model_name, mode=embedding_mode)
rep_socket.send(msgpack.packb(embeddings.tolist()))
# Use unified embedding computation (now with model caching)
embeddings = compute_embeddings(
request_payload, model_name, mode=embedding_mode
)
response = embeddings.tolist()
socket.send(msgpack.packb(response))
e2e_end = time.time()
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
continue
# Handle distance calculation request: [[ids], [query_vector]]
if (
isinstance(request, list)
and len(request) == 2
and isinstance(request[0], list)
and isinstance(request[1], list)
):
node_ids = request[0]
# Handle nested [[ids]] shape defensively
if len(node_ids) == 1 and isinstance(node_ids[0], list):
node_ids = node_ids[0]
query_vector = np.array(request[1], dtype=np.float32)
last_request_type = "distance"
last_request_length = len(node_ids)
# Handle distance calculation requests
if (
isinstance(request_payload, list)
and len(request_payload) == 2
and isinstance(request_payload[0], list)
and isinstance(request_payload[1], list)
):
node_ids = request_payload[0]
query_vector = np.array(request_payload[1], dtype=np.float32)
logger.debug("Distance calculation request received")
logger.debug(f" Node IDs: {node_ids}")
logger.debug(f" Query vector dim: {len(query_vector)}")
logger.debug("Distance calculation request received")
logger.debug(f" Node IDs: {node_ids}")
logger.debug(f" Query vector dim: {len(query_vector)}")
# Gather texts for found ids
texts: list[str] = []
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
try:
passage_data = passages.get_passage(str(nid))
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {nid}")
except KeyError:
logger.error(f"Passage ID {nid} not found")
except Exception as e:
logger.error(f"Exception looking up passage ID {nid}: {e}")
# Prepare full-length response with large sentinel values
large_distance = 1e9
response_distances = [large_distance] * len(node_ids)
if texts:
try:
embeddings = compute_embeddings(
texts, model_name, mode=embedding_mode
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
if distance_metric == "l2":
partial = np.sum(
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
)
else: # mips or cosine
partial = -np.dot(embeddings, query_vector)
for pos, dval in zip(found_indices, partial.flatten().tolist()):
response_distances[pos] = float(dval)
except Exception as e:
logger.error(f"Distance computation error, using sentinels: {e}")
# Send response in expected shape [[distances]]
rep_socket.send(msgpack.packb([response_distances], use_single_float=True))
e2e_end = time.time()
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
continue
# Fallback: treat as embedding-by-id request
if (
isinstance(request, list)
and len(request) == 1
and isinstance(request[0], list)
):
node_ids = request[0]
elif isinstance(request, list):
node_ids = request
else:
node_ids = []
last_request_type = "embedding"
last_request_length = len(node_ids)
logger.info(f"ZMQ received {len(node_ids)} node IDs for embedding fetch")
# Preallocate zero-filled flat data for robustness
if embedding_dim <= 0:
dims = [0, 0]
flat_data: list[float] = []
else:
dims = [len(node_ids), embedding_dim]
flat_data = [0.0] * (dims[0] * dims[1])
# Collect texts for found ids
texts: list[str] = []
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
# Get embeddings for node IDs
texts = []
for nid in node_ids:
try:
passage_data = passages.get_passage(str(nid))
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {nid}")
txt = passage_data["text"]
texts.append(txt)
except KeyError:
logger.error(f"Passage with ID {nid} not found")
logger.error(f"Passage ID {nid} not found")
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
except Exception as e:
logger.error(f"Exception looking up passage ID {nid}: {e}")
raise
if texts:
try:
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
# Process embeddings
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
logger.error(
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
)
dims = [0, embedding_dim]
flat_data = []
else:
emb_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
flat = emb_f32.flatten().tolist()
for j, pos in enumerate(found_indices):
start = pos * embedding_dim
end = start + embedding_dim
if end <= len(flat_data):
flat_data[start:end] = flat[
j * embedding_dim : (j + 1) * embedding_dim
]
except Exception as e:
logger.error(f"Embedding computation error, returning zeros: {e}")
# Calculate distances
if distance_metric == "l2":
distances = np.sum(
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
)
else: # mips or cosine
distances = -np.dot(embeddings, query_vector)
response_payload = [dims, flat_data]
response_bytes = msgpack.packb(response_payload, use_single_float=True)
response_payload = distances.flatten().tolist()
response_bytes = msgpack.packb([response_payload], use_single_float=True)
logger.debug(f"Sending distance response with {len(distances)} distances")
rep_socket.send(response_bytes)
socket.send(response_bytes)
e2e_end = time.time()
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
except zmq.Again:
# Timeout - check shutdown_event and continue
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
continue
except Exception as e:
if not shutdown_event.is_set():
logger.error(f"Error in ZMQ server loop: {e}")
# Shape-correct fallback
try:
if last_request_type == "distance":
large_distance = 1e9
fallback_len = max(0, int(last_request_length))
safe = [[large_distance] * fallback_len]
elif last_request_type == "embedding":
bsz = max(0, int(last_request_length))
dim = max(0, int(embedding_dim))
safe = (
[[bsz, dim], [0.0] * (bsz * dim)] if dim > 0 else [[0, 0], []]
)
elif last_request_type == "text":
safe = [] # direct text embeddings expectation is a flat list
else:
safe = [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
rep_socket.send(msgpack.packb(safe, use_single_float=True))
except Exception:
pass
else:
logger.info("Shutdown in progress, ignoring ZMQ error")
break
finally:
try:
rep_socket.close(0)
except Exception:
pass
try:
context.term()
except Exception:
pass
logger.info("ZMQ server thread exiting gracefully")
# Standard embedding request (passage ID lookup)
if (
not isinstance(request_payload, list)
or len(request_payload) != 1
or not isinstance(request_payload[0], list)
):
logger.error(
f"Invalid MessagePack request format. Expected [[ids...]] or [texts...], got: {type(request_payload)}"
)
socket.send(msgpack.packb([[], []]))
continue
# Add shutdown coordination
shutdown_event = threading.Event()
node_ids = request_payload[0]
logger.debug(f"Request for {len(node_ids)} node embeddings")
def shutdown_zmq_server():
"""Gracefully shutdown ZMQ server."""
logger.info("Initiating graceful shutdown...")
shutdown_event.set()
# Look up texts by node IDs
texts = []
for nid in node_ids:
try:
passage_data = passages.get_passage(str(nid))
txt = passage_data["text"]
if not txt:
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
texts.append(txt)
except KeyError:
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
except Exception as e:
logger.error(f"Exception looking up passage ID {nid}: {e}")
raise
if zmq_thread.is_alive():
logger.info("Waiting for ZMQ thread to finish...")
zmq_thread.join(timeout=5)
if zmq_thread.is_alive():
logger.warning("ZMQ thread did not finish in time")
# Process embeddings
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
# Clean up ZMQ resources
try:
# Note: socket and context are cleaned up by thread exit
logger.info("ZMQ resources cleaned up")
except Exception as e:
logger.warning(f"Error cleaning ZMQ resources: {e}")
# Serialization and response
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
logger.error(
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
)
raise AssertionError()
# Clean up other resources
try:
import gc
hidden_contiguous_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
response_payload = [
list(hidden_contiguous_f32.shape),
hidden_contiguous_f32.flatten().tolist(),
]
response_bytes = msgpack.packb(response_payload, use_single_float=True)
gc.collect()
logger.info("Additional resources cleaned up")
except Exception as e:
logger.warning(f"Error cleaning additional resources: {e}")
socket.send(response_bytes)
e2e_end = time.time()
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
logger.info("Graceful shutdown completed")
sys.exit(0)
except zmq.Again:
logger.debug("ZMQ socket timeout, continuing to listen")
continue
except Exception as e:
logger.error(f"Error in ZMQ server loop: {e}")
import traceback
# Register signal handlers within this function scope
import signal
traceback.print_exc()
socket.send(msgpack.packb([[], []]))
def signal_handler(sig, frame):
logger.info(f"Received signal {sig}, shutting down gracefully...")
shutdown_zmq_server()
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
# Pass shutdown_event to ZMQ thread
zmq_thread = threading.Thread(
target=lambda: zmq_server_thread_with_shutdown(shutdown_event),
daemon=False, # Not daemon - we want to wait for it
)
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
zmq_thread.start()
logger.info(f"Started HNSW ZMQ server thread on port {zmq_port}")
# Keep the main thread alive
try:
while not shutdown_event.is_set():
time.sleep(0.1) # Check shutdown more frequently
while True:
time.sleep(1)
except KeyboardInterrupt:
logger.info("HNSW Server shutting down...")
shutdown_zmq_server()
return
# If we reach here, shutdown was triggered by signal
logger.info("Main loop exited, process should be shutting down")
if __name__ == "__main__":
import signal
import sys
# Signal handlers are now registered within create_hnsw_embedding_server
def signal_handler(sig, frame):
logger.info(f"Received signal {sig}, shutting down gracefully...")
sys.exit(0)
# Register signal handlers for graceful shutdown
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
parser = argparse.ArgumentParser(description="HNSW Embedding service")
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")

View File

@@ -6,10 +6,10 @@ build-backend = "scikit_build_core.build"
[project]
name = "leann-backend-hnsw"
version = "0.2.8"
version = "0.2.7"
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
dependencies = [
"leann-core==0.2.8",
"leann-core==0.2.7",
"numpy",
"pyzmq>=23.0.0",
"msgpack>=1.0.0",

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "leann-core"
version = "0.2.8"
version = "0.2.7"
description = "Core API and plugin system for LEANN"
readme = "README.md"
requires-python = ">=3.9"

View File

@@ -115,62 +115,20 @@ class SearchResult:
class PassageManager:
def __init__(
self, passage_sources: list[dict[str, Any]], metadata_file_path: Optional[str] = None
):
def __init__(self, passage_sources: list[dict[str, Any]]):
self.offset_maps = {}
self.passage_files = {}
self.global_offset_map = {} # Combined map for fast lookup
# Derive index base name for standard sibling fallbacks, e.g., <index_name>.passages.*
index_name_base = None
if metadata_file_path:
meta_name = Path(metadata_file_path).name
if meta_name.endswith(".meta.json"):
index_name_base = meta_name[: -len(".meta.json")]
for source in passage_sources:
assert source["type"] == "jsonl", "only jsonl is supported"
passage_file = source.get("path", "")
index_file = source.get("index_path", "") # .idx file
passage_file = source["path"]
index_file = source["index_path"] # .idx file
# Fix path resolution - relative paths should be relative to metadata file directory
def _resolve_candidates(
primary: str,
relative_key: str,
default_name: Optional[str],
source_dict: dict[str, Any],
) -> list[Path]:
candidates: list[Path] = []
# 1) Primary as-is (absolute or relative)
if primary:
p = Path(primary)
candidates.append(p if p.is_absolute() else (Path.cwd() / p))
# 2) metadata-relative explicit relative key
if metadata_file_path and source_dict.get(relative_key):
candidates.append(Path(metadata_file_path).parent / source_dict[relative_key])
# 3) metadata-relative standard sibling filename
if metadata_file_path and default_name:
candidates.append(Path(metadata_file_path).parent / default_name)
return candidates
# Build candidate lists and pick first existing; otherwise keep last candidate for error message
idx_default = f"{index_name_base}.passages.idx" if index_name_base else None
idx_candidates = _resolve_candidates(
index_file, "index_path_relative", idx_default, source
)
pas_default = f"{index_name_base}.passages.jsonl" if index_name_base else None
pas_candidates = _resolve_candidates(passage_file, "path_relative", pas_default, source)
def _pick_existing(cands: list[Path]) -> str:
for c in cands:
if c.exists():
return str(c.resolve())
# Fallback to last candidate (best guess) even if not exists; will error below
return str(cands[-1].resolve()) if cands else ""
index_file = _pick_existing(idx_candidates)
passage_file = _pick_existing(pas_candidates)
# Fix path resolution for Colab and other environments
if not Path(index_file).is_absolute():
# If relative path, try to resolve it properly
index_file = str(Path(index_file).resolve())
if not Path(index_file).exists():
raise FileNotFoundError(f"Passage index file not found: {index_file}")
@@ -356,12 +314,8 @@ class LeannBuilder:
"passage_sources": [
{
"type": "jsonl",
# Preserve existing relative file names (backward-compatible)
"path": passages_file.name,
"index_path": offset_file.name,
# Add optional redundant relative keys for remote build portability (non-breaking)
"path_relative": passages_file.name,
"index_path_relative": offset_file.name,
"path": str(passages_file),
"index_path": str(offset_file),
}
],
}
@@ -476,12 +430,8 @@ class LeannBuilder:
"passage_sources": [
{
"type": "jsonl",
# Preserve existing relative file names (backward-compatible)
"path": passages_file.name,
"index_path": offset_file.name,
# Add optional redundant relative keys for remote build portability (non-breaking)
"path_relative": passages_file.name,
"index_path_relative": offset_file.name,
"path": str(passages_file),
"index_path": str(offset_file),
}
],
"built_from_precomputed_embeddings": True,
@@ -523,9 +473,7 @@ class LeannSearcher:
self.embedding_model = self.meta_data["embedding_model"]
# Support both old and new format
self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
self.passage_manager = PassageManager(
self.meta_data.get("passage_sources", []), metadata_file_path=self.meta_path_str
)
self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
backend_factory = BACKEND_REGISTRY.get(backend_name)
if backend_factory is None:
raise ValueError(f"Backend '{backend_name}' not found.")
@@ -598,13 +546,13 @@ class LeannSearcher:
zmq_port=zmq_port,
**kwargs,
)
time.time() - start_time
# logger.info(f" Search time: {search_time} seconds")
logger.info(f" Backend returned: labels={len(results.get('labels', [[]])[0])} results")
enriched_results = []
if "labels" in results and "distances" in results:
logger.info(f" Processing {len(results['labels'][0])} passage IDs:")
# Python 3.9 does not support zip(strict=...); lengths are expected to match
for i, (string_id, dist) in enumerate(
zip(results["labels"][0], results["distances"][0])
):
@@ -632,26 +580,13 @@ class LeannSearcher:
)
except KeyError:
RED = "\033[91m"
RESET = "\033[0m"
logger.error(
f" {RED}{RESET} [{i + 1:2d}] ID: '{string_id}' -> {RED}ERROR: Passage not found!{RESET}"
)
# Define color codes outside the loop for final message
GREEN = "\033[92m"
RESET = "\033[0m"
logger.info(f" {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}")
return enriched_results
def cleanup(self):
"""Explicitly cleanup embedding server resources.
This method should be called after you're done using the searcher,
especially in test environments or batch processing scenarios.
"""
if hasattr(self.backend_impl, "embedding_server_manager"):
self.backend_impl.embedding_server_manager.stop_server()
class LeannChat:
def __init__(
@@ -721,12 +656,3 @@ class LeannChat:
except (KeyboardInterrupt, EOFError):
print("\nGoodbye!")
break
def cleanup(self):
"""Explicitly cleanup embedding server resources.
This method should be called after you're done using the chat interface,
especially in test environments or batch processing scenarios.
"""
if hasattr(self.searcher, "cleanup"):
self.searcher.cleanup()

View File

@@ -5,7 +5,6 @@ from typing import Union
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
from tqdm import tqdm
from .api import LeannBuilder, LeannChat, LeannSearcher
@@ -76,14 +75,11 @@ class LeannCLI:
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
leann build my-docs --docs ./documents # Build index from directory
leann build my-code --docs ./src ./tests ./config # Build index from multiple directories
leann build my-files --docs ./file1.py ./file2.txt ./docs/ # Build index from files and directories
leann build my-mixed --docs ./readme.md ./src/ ./config.json # Build index from mixed files/dirs
leann build my-ppts --docs ./ --file-types .pptx,.pdf # Index only PowerPoint and PDF files
leann search my-docs "query" # Search in my-docs index
leann ask my-docs "question" # Ask my-docs index
leann list # List all stored indexes
leann build my-docs --docs ./documents # Build index named my-docs
leann build my-ppts --docs ./ --file-types .pptx,.pdf # Index only PowerPoint and PDF files
leann search my-docs "query" # Search in my-docs index
leann ask my-docs "question" # Ask my-docs index
leann list # List all stored indexes
""",
)
@@ -95,11 +91,7 @@ Examples:
"index_name", nargs="?", help="Index name (default: current directory name)"
)
build_parser.add_argument(
"--docs",
type=str,
nargs="+",
default=["."],
help="Documents directories and/or files (default: current directory)",
"--docs", type=str, default=".", help="Documents directory (default: current directory)"
)
build_parser.add_argument(
"--backend", type=str, default="hnsw", choices=["hnsw", "diskann"]
@@ -243,32 +235,6 @@ Examples:
"""Check if a file should be excluded using gitignore parser."""
return gitignore_matches(str(relative_path))
def _is_git_submodule(self, path: Path) -> bool:
"""Check if a path is a git submodule."""
try:
# Find the git repo root
current_dir = Path.cwd()
while current_dir != current_dir.parent:
if (current_dir / ".git").exists():
gitmodules_path = current_dir / ".gitmodules"
if gitmodules_path.exists():
# Read .gitmodules to check if this path is a submodule
gitmodules_content = gitmodules_path.read_text()
# Convert path to relative to git root
try:
relative_path = path.resolve().relative_to(current_dir)
# Check if this path appears in .gitmodules
return f"path = {relative_path}" in gitmodules_content
except ValueError:
# Path is not under git root
return False
break
current_dir = current_dir.parent
return False
except Exception:
# If anything goes wrong, assume it's not a submodule
return False
def list_indexes(self):
print("Stored LEANN indexes:")
@@ -298,9 +264,7 @@ Examples:
valid_projects.append(current_path)
if not valid_projects:
print(
"No indexes found. Use 'leann build <name> --docs <dir> [<dir2> ...]' to create one."
)
print("No indexes found. Use 'leann build <name> --docs <dir>' to create one.")
return
total_indexes = 0
@@ -347,88 +311,56 @@ Examples:
print(f' leann search {example_name} "your query"')
print(f" leann ask {example_name} --interactive")
def load_documents(
self, docs_paths: Union[str, list], custom_file_types: Union[str, None] = None
):
# Handle both single path (string) and multiple paths (list) for backward compatibility
if isinstance(docs_paths, str):
docs_paths = [docs_paths]
# Separate files and directories
files = []
directories = []
for path in docs_paths:
path_obj = Path(path)
if path_obj.is_file():
files.append(str(path_obj))
elif path_obj.is_dir():
# Check if this is a git submodule - if so, skip it
if self._is_git_submodule(path_obj):
print(f"⚠️ Skipping git submodule: {path}")
continue
directories.append(str(path_obj))
else:
print(f"⚠️ Warning: Path '{path}' does not exist, skipping...")
continue
# Print summary of what we're processing
total_items = len(files) + len(directories)
items_desc = []
if files:
items_desc.append(f"{len(files)} file{'s' if len(files) > 1 else ''}")
if directories:
items_desc.append(
f"{len(directories)} director{'ies' if len(directories) > 1 else 'y'}"
)
print(f"Loading documents from {' and '.join(items_desc)} ({total_items} total):")
if files:
print(f" 📄 Files: {', '.join([Path(f).name for f in files])}")
if directories:
print(f" 📁 Directories: {', '.join(directories)}")
def load_documents(self, docs_dir: str, custom_file_types: Union[str, None] = None):
print(f"Loading documents from {docs_dir}...")
if custom_file_types:
print(f"Using custom file types: {custom_file_types}")
all_documents = []
# Build gitignore parser
gitignore_matches = self._build_gitignore_parser(docs_dir)
# First, process individual files if any
if files:
print(f"\n🔄 Processing {len(files)} individual file{'s' if len(files) > 1 else ''}...")
# Try to use better PDF parsers first, but only if PDFs are requested
documents = []
docs_path = Path(docs_dir)
# Load individual files using SimpleDirectoryReader with input_files
# Note: We skip gitignore filtering for explicitly specified files
try:
# Group files by their parent directory for efficient loading
from collections import defaultdict
# Check if we should process PDFs
should_process_pdfs = custom_file_types is None or ".pdf" in custom_file_types
files_by_dir = defaultdict(list)
for file_path in files:
parent_dir = str(Path(file_path).parent)
files_by_dir[parent_dir].append(file_path)
if should_process_pdfs:
for file_path in docs_path.rglob("*.pdf"):
# Check if file matches any exclude pattern
relative_path = file_path.relative_to(docs_path)
if self._should_exclude_file(relative_path, gitignore_matches):
continue
# Load files from each parent directory
for parent_dir, file_list in files_by_dir.items():
print(
f" Loading {len(file_list)} file{'s' if len(file_list) > 1 else ''} from {parent_dir}"
)
print(f"Processing PDF: {file_path}")
# Try PyMuPDF first (best quality)
text = extract_pdf_text_with_pymupdf(str(file_path))
if text is None:
# Try pdfplumber
text = extract_pdf_text_with_pdfplumber(str(file_path))
if text:
# Create a simple document structure
from llama_index.core import Document
doc = Document(text=text, metadata={"source": str(file_path)})
documents.append(doc)
else:
# Fallback to default reader
print(f"Using default reader for {file_path}")
try:
file_docs = SimpleDirectoryReader(
parent_dir,
input_files=file_list,
default_docs = SimpleDirectoryReader(
str(file_path.parent),
filename_as_id=True,
required_exts=[file_path.suffix],
).load_data()
all_documents.extend(file_docs)
print(
f" ✅ Loaded {len(file_docs)} document{'s' if len(file_docs) > 1 else ''}"
)
documents.extend(default_docs)
except Exception as e:
print(f"Warning: Could not load files from {parent_dir}: {e}")
print(f"Warning: Could not process {file_path}: {e}")
except Exception as e:
print(f"❌ Error processing individual files: {e}")
# Define file extensions to process
# Load other file types with default reader
if custom_file_types:
# Parse custom file types from comma-separated string
code_extensions = [ext.strip() for ext in custom_file_types.split(",") if ext.strip()]
@@ -490,106 +422,41 @@ Examples:
".py",
".jl",
]
# Try to load other file types, but don't fail if none are found
try:
# Create a custom file filter function using our PathSpec
def file_filter(file_path: str) -> bool:
"""Return True if file should be included (not excluded)"""
try:
docs_path_obj = Path(docs_dir)
file_path_obj = Path(file_path)
relative_path = file_path_obj.relative_to(docs_path_obj)
return not self._should_exclude_file(relative_path, gitignore_matches)
except (ValueError, OSError):
return True # Include files that can't be processed
# Process each directory
if directories:
print(
f"\n🔄 Processing {len(directories)} director{'ies' if len(directories) > 1 else 'y'}..."
)
other_docs = SimpleDirectoryReader(
docs_dir,
recursive=True,
encoding="utf-8",
required_exts=code_extensions,
file_extractor={}, # Use default extractors
filename_as_id=True,
).load_data(show_progress=True)
for docs_dir in directories:
print(f"Processing directory: {docs_dir}")
# Build gitignore parser for each directory
gitignore_matches = self._build_gitignore_parser(docs_dir)
# Filter documents after loading based on gitignore rules
filtered_docs = []
for doc in other_docs:
file_path = doc.metadata.get("file_path", "")
if file_filter(file_path):
filtered_docs.append(doc)
# Try to use better PDF parsers first, but only if PDFs are requested
documents = []
docs_path = Path(docs_dir)
# Check if we should process PDFs
should_process_pdfs = custom_file_types is None or ".pdf" in custom_file_types
if should_process_pdfs:
for file_path in docs_path.rglob("*.pdf"):
# Check if file matches any exclude pattern
try:
relative_path = file_path.relative_to(docs_path)
if self._should_exclude_file(relative_path, gitignore_matches):
continue
except ValueError:
# Skip files that can't be made relative to docs_path
print(f"⚠️ Skipping file outside directory scope: {file_path}")
continue
print(f"Processing PDF: {file_path}")
# Try PyMuPDF first (best quality)
text = extract_pdf_text_with_pymupdf(str(file_path))
if text is None:
# Try pdfplumber
text = extract_pdf_text_with_pdfplumber(str(file_path))
if text:
# Create a simple document structure
from llama_index.core import Document
doc = Document(text=text, metadata={"source": str(file_path)})
documents.append(doc)
else:
# Fallback to default reader
print(f"Using default reader for {file_path}")
try:
default_docs = SimpleDirectoryReader(
str(file_path.parent),
filename_as_id=True,
required_exts=[file_path.suffix],
).load_data()
documents.extend(default_docs)
except Exception as e:
print(f"Warning: Could not process {file_path}: {e}")
# Load other file types with default reader
try:
# Create a custom file filter function using our PathSpec
def file_filter(
file_path: str, docs_dir=docs_dir, gitignore_matches=gitignore_matches
) -> bool:
"""Return True if file should be included (not excluded)"""
try:
docs_path_obj = Path(docs_dir)
file_path_obj = Path(file_path)
relative_path = file_path_obj.relative_to(docs_path_obj)
return not self._should_exclude_file(relative_path, gitignore_matches)
except (ValueError, OSError):
return True # Include files that can't be processed
other_docs = SimpleDirectoryReader(
docs_dir,
recursive=True,
encoding="utf-8",
required_exts=code_extensions,
file_extractor={}, # Use default extractors
filename_as_id=True,
).load_data(show_progress=True)
# Filter documents after loading based on gitignore rules
filtered_docs = []
for doc in other_docs:
file_path = doc.metadata.get("file_path", "")
if file_filter(file_path):
filtered_docs.append(doc)
documents.extend(filtered_docs)
except ValueError as e:
if "No files found" in str(e):
print(f"No additional files found for other supported types in {docs_dir}.")
else:
raise e
all_documents.extend(documents)
print(f"Loaded {len(documents)} documents from {docs_dir}")
documents = all_documents
documents.extend(filtered_docs)
except ValueError as e:
if "No files found" in str(e):
print("No additional files found for other supported types.")
else:
raise e
all_texts = []
@@ -640,9 +507,7 @@ Examples:
".jl",
}
print("start chunking documents")
# Add progress bar for document chunking
for doc in tqdm(documents, desc="Chunking documents", unit="doc"):
for doc in documents:
# Check if this is a code file based on source path
source_path = doc.metadata.get("source", "")
is_code_file = any(source_path.endswith(ext) for ext in code_file_exts)
@@ -658,7 +523,7 @@ Examples:
return all_texts
async def build_index(self, args):
docs_paths = args.docs
docs_dir = args.docs
# Use current directory name if index_name not provided
if args.index_name:
index_name = args.index_name
@@ -669,25 +534,13 @@ Examples:
index_dir = self.indexes_dir / index_name
index_path = self.get_index_path(index_name)
# Display all paths being indexed with file/directory distinction
files = [p for p in docs_paths if Path(p).is_file()]
directories = [p for p in docs_paths if Path(p).is_dir()]
print(f"📂 Indexing {len(docs_paths)} path{'s' if len(docs_paths) > 1 else ''}:")
if files:
print(f" 📄 Files ({len(files)}):")
for i, file_path in enumerate(files, 1):
print(f" {i}. {Path(file_path).resolve()}")
if directories:
print(f" 📁 Directories ({len(directories)}):")
for i, dir_path in enumerate(directories, 1):
print(f" {i}. {Path(dir_path).resolve()}")
print(f"📂 Indexing: {Path(docs_dir).resolve()}")
if index_dir.exists() and not args.force:
print(f"Index '{index_name}' already exists. Use --force to rebuild.")
return
all_texts = self.load_documents(docs_paths, args.file_types)
all_texts = self.load_documents(docs_dir, args.file_types)
if not all_texts:
print("No documents found")
return
@@ -723,7 +576,7 @@ Examples:
if not self.index_exists(index_name):
print(
f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir> [<dir2> ...]' to create it."
f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it."
)
return
@@ -750,7 +603,7 @@ Examples:
if not self.index_exists(index_name):
print(
f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir> [<dir2> ...]' to create it."
f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it."
)
return

View File

@@ -6,6 +6,7 @@ Preserves all optimization parameters to ensure performance
import logging
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any
import numpy as np
@@ -373,9 +374,7 @@ def compute_embeddings_ollama(
texts: list[str], model_name: str, is_build: bool = False, host: str = "http://localhost:11434"
) -> np.ndarray:
"""
Compute embeddings using Ollama API with simplified batch processing.
Uses batch size of 32 for MPS/CPU and 128 for CUDA to optimize performance.
Compute embeddings using Ollama API.
Args:
texts: List of texts to compute embeddings for
@@ -439,19 +438,12 @@ def compute_embeddings_ollama(
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5"]):
embedding_models.append(model)
# Check if model exists (handle versioned names) and resolve to full name
resolved_model_name = None
for name in model_names:
# Exact match
if model_name == name:
resolved_model_name = name
break
# Match without version tag (use the versioned name)
elif model_name == name.split(":")[0]:
resolved_model_name = name
break
# Check if model exists (handle versioned names)
model_found = any(
model_name == name.split(":")[0] or model_name == name for name in model_names
)
if not resolved_model_name:
if not model_found:
error_msg = f"❌ Model '{model_name}' not found in local Ollama.\n\n"
# Suggest pulling the model
@@ -473,11 +465,6 @@ def compute_embeddings_ollama(
error_msg += "\n📚 Browse more: https://ollama.com/library"
raise ValueError(error_msg)
# Use the resolved model name for all subsequent operations
if resolved_model_name != model_name:
logger.info(f"Resolved model name '{model_name}' to '{resolved_model_name}'")
model_name = resolved_model_name
# Verify the model supports embeddings by testing it
try:
test_response = requests.post(
@@ -498,147 +485,162 @@ def compute_embeddings_ollama(
except requests.exceptions.RequestException as e:
logger.warning(f"Could not verify model existence: {e}")
# Determine batch size based on device availability
# Check for CUDA/MPS availability using torch if available
batch_size = 32 # Default for MPS/CPU
try:
import torch
# Process embeddings with optimized concurrent processing
import requests
if torch.cuda.is_available():
batch_size = 128 # CUDA gets larger batch size
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
batch_size = 32 # MPS gets smaller batch size
except ImportError:
# If torch is not available, use conservative batch size
batch_size = 32
def get_single_embedding(text_idx_tuple):
"""Helper function to get embedding for a single text."""
text, idx = text_idx_tuple
max_retries = 3
retry_count = 0
logger.info(f"Using batch size: {batch_size}")
# Truncate very long texts to avoid API issues
truncated_text = text[:8000] if len(text) > 8000 else text
def get_batch_embeddings(batch_texts):
"""Get embeddings for a batch of texts."""
all_embeddings = []
failed_indices = []
while retry_count < max_retries:
try:
response = requests.post(
f"{host}/api/embeddings",
json={"model": model_name, "prompt": truncated_text},
timeout=30,
)
response.raise_for_status()
for i, text in enumerate(batch_texts):
max_retries = 3
retry_count = 0
result = response.json()
embedding = result.get("embedding")
# Truncate very long texts to avoid API issues
truncated_text = text[:8000] if len(text) > 8000 else text
while retry_count < max_retries:
try:
response = requests.post(
f"{host}/api/embeddings",
json={"model": model_name, "prompt": truncated_text},
timeout=30,
if embedding is None:
raise ValueError(f"No embedding returned for text {idx}")
return idx, embedding
except requests.exceptions.Timeout:
retry_count += 1
if retry_count >= max_retries:
logger.warning(f"Timeout for text {idx} after {max_retries} retries")
return idx, None
except Exception as e:
if retry_count >= max_retries - 1:
logger.error(f"Failed to get embedding for text {idx}: {e}")
return idx, None
retry_count += 1
return idx, None
# Determine if we should use concurrent processing
use_concurrent = (
len(texts) > 5 and not is_build
) # Don't use concurrent in build mode to avoid overwhelming
max_workers = min(4, len(texts)) # Limit concurrent requests to avoid overwhelming Ollama
all_embeddings = [None] * len(texts) # Pre-allocate list to maintain order
failed_indices = []
if use_concurrent:
logger.info(
f"Using concurrent processing with {max_workers} workers for {len(texts)} texts"
)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit all tasks
future_to_idx = {
executor.submit(get_single_embedding, (text, idx)): idx
for idx, text in enumerate(texts)
}
# Add progress bar for concurrent processing
try:
if is_build or len(texts) > 10:
from tqdm import tqdm
futures_iterator = tqdm(
as_completed(future_to_idx),
total=len(texts),
desc="Computing Ollama embeddings",
)
response.raise_for_status()
result = response.json()
embedding = result.get("embedding")
if embedding is None:
raise ValueError(f"No embedding returned for text {i}")
if not isinstance(embedding, list) or len(embedding) == 0:
raise ValueError(f"Invalid embedding format for text {i}")
all_embeddings.append(embedding)
break
except requests.exceptions.Timeout:
retry_count += 1
if retry_count >= max_retries:
logger.warning(f"Timeout for text {i} after {max_retries} retries")
failed_indices.append(i)
all_embeddings.append(None)
break
else:
futures_iterator = as_completed(future_to_idx)
except ImportError:
futures_iterator = as_completed(future_to_idx)
# Collect results as they complete
for future in futures_iterator:
try:
idx, embedding = future.result()
if embedding is not None:
all_embeddings[idx] = embedding
else:
failed_indices.append(idx)
except Exception as e:
retry_count += 1
if retry_count >= max_retries:
logger.error(f"Failed to get embedding for text {i}: {e}")
failed_indices.append(i)
all_embeddings.append(None)
break
return all_embeddings, failed_indices
idx = future_to_idx[future]
logger.error(f"Exception for text {idx}: {e}")
failed_indices.append(idx)
# Process texts in batches
all_embeddings = []
all_failed_indices = []
# Setup progress bar if needed
show_progress = is_build or len(texts) > 10
try:
if show_progress:
from tqdm import tqdm
except ImportError:
show_progress = False
# Process batches
num_batches = (len(texts) + batch_size - 1) // batch_size
if show_progress:
batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings")
else:
batch_iterator = range(num_batches)
# Sequential processing with progress bar
show_progress = is_build or len(texts) > 10
for batch_idx in batch_iterator:
start_idx = batch_idx * batch_size
end_idx = min(start_idx + batch_size, len(texts))
batch_texts = texts[start_idx:end_idx]
try:
if show_progress:
from tqdm import tqdm
batch_embeddings, batch_failed = get_batch_embeddings(batch_texts)
iterator = tqdm(
enumerate(texts), total=len(texts), desc="Computing Ollama embeddings"
)
else:
iterator = enumerate(texts)
except ImportError:
iterator = enumerate(texts)
# Adjust failed indices to global indices
global_failed = [start_idx + idx for idx in batch_failed]
all_failed_indices.extend(global_failed)
all_embeddings.extend(batch_embeddings)
for idx, text in iterator:
result_idx, embedding = get_single_embedding((text, idx))
if embedding is not None:
all_embeddings[idx] = embedding
else:
failed_indices.append(idx)
# Handle failed embeddings
if all_failed_indices:
if len(all_failed_indices) == len(texts):
if failed_indices:
if len(failed_indices) == len(texts):
raise RuntimeError("Failed to compute any embeddings")
logger.warning(
f"Failed to compute embeddings for {len(all_failed_indices)}/{len(texts)} texts"
)
logger.warning(f"Failed to compute embeddings for {len(failed_indices)}/{len(texts)} texts")
# Use zero embeddings as fallback for failed ones
valid_embedding = next((e for e in all_embeddings if e is not None), None)
if valid_embedding:
embedding_dim = len(valid_embedding)
for i, embedding in enumerate(all_embeddings):
if embedding is None:
all_embeddings[i] = [0.0] * embedding_dim
for idx in failed_indices:
all_embeddings[idx] = [0.0] * embedding_dim
# Remove None values
# Remove None values and convert to numpy array
all_embeddings = [e for e in all_embeddings if e is not None]
if not all_embeddings:
raise RuntimeError("No valid embeddings were computed")
# Validate embedding dimensions before creating numpy array
if all_embeddings:
expected_dim = len(all_embeddings[0])
inconsistent_dims = []
for i, embedding in enumerate(all_embeddings):
if len(embedding) != expected_dim:
inconsistent_dims.append((i, len(embedding)))
# Validate embedding dimensions
expected_dim = len(all_embeddings[0])
inconsistent_dims = []
for i, embedding in enumerate(all_embeddings):
if len(embedding) != expected_dim:
inconsistent_dims.append((i, len(embedding)))
if inconsistent_dims:
error_msg = f"Ollama returned inconsistent embedding dimensions. Expected {expected_dim}, but got:\n"
for idx, dim in inconsistent_dims[:10]: # Show first 10 inconsistent ones
error_msg += f" - Text {idx}: {dim} dimensions\n"
if len(inconsistent_dims) > 10:
error_msg += f" ... and {len(inconsistent_dims) - 10} more\n"
error_msg += f"\nThis is likely an Ollama API bug with model '{model_name}'. Please try:\n"
error_msg += "1. Restart Ollama service: 'ollama serve'\n"
error_msg += f"2. Re-pull the model: 'ollama pull {model_name}'\n"
error_msg += (
"3. Use sentence-transformers instead: --embedding-mode sentence-transformers\n"
)
error_msg += "4. Report this issue to Ollama: https://github.com/ollama/ollama/issues"
raise ValueError(error_msg)
if inconsistent_dims:
error_msg = f"Ollama returned inconsistent embedding dimensions. Expected {expected_dim}, but got:\n"
for idx, dim in inconsistent_dims[:10]: # Show first 10 inconsistent ones
error_msg += f" - Text {idx}: {dim} dimensions\n"
if len(inconsistent_dims) > 10:
error_msg += f" ... and {len(inconsistent_dims) - 10} more\n"
error_msg += (
f"\nThis is likely an Ollama API bug with model '{model_name}'. Please try:\n"
)
error_msg += "1. Restart Ollama service: 'ollama serve'\n"
error_msg += f"2. Re-pull the model: 'ollama pull {model_name}'\n"
error_msg += (
"3. Use sentence-transformers instead: --embedding-mode sentence-transformers\n"
)
error_msg += "4. Report this issue to Ollama: https://github.com/ollama/ollama/issues"
raise ValueError(error_msg)
# Convert to numpy array and normalize
embeddings = np.array(all_embeddings, dtype=np.float32)

View File

@@ -8,7 +8,7 @@ import time
from pathlib import Path
from typing import Optional
# Lightweight, self-contained server manager with no cross-process inspection
import psutil
# Set up logging based on environment variable
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
@@ -43,7 +43,130 @@ def _check_port(port: int) -> bool:
return s.connect_ex(("localhost", port)) == 0
# Note: All cross-process scanning helpers removed for simplicity
def _check_process_matches_config(
port: int, expected_model: str, expected_passages_file: str
) -> bool:
"""
Check if the process using the port matches our expected model and passages file.
Returns True if matches, False otherwise.
"""
try:
for proc in psutil.process_iter(["pid", "cmdline"]):
if not _is_process_listening_on_port(proc, port):
continue
cmdline = proc.info["cmdline"]
if not cmdline:
continue
return _check_cmdline_matches_config(
cmdline, port, expected_model, expected_passages_file
)
logger.debug(f"No process found listening on port {port}")
return False
except Exception as e:
logger.warning(f"Could not check process on port {port}: {e}")
return False
def _is_process_listening_on_port(proc, port: int) -> bool:
"""Check if a process is listening on the given port."""
try:
connections = proc.net_connections()
for conn in connections:
if conn.laddr.port == port and conn.status == psutil.CONN_LISTEN:
return True
return False
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
return False
def _check_cmdline_matches_config(
cmdline: list, port: int, expected_model: str, expected_passages_file: str
) -> bool:
"""Check if command line matches our expected configuration."""
cmdline_str = " ".join(cmdline)
logger.debug(f"Found process on port {port}: {cmdline_str}")
# Check if it's our embedding server
is_embedding_server = any(
server_type in cmdline_str
for server_type in [
"embedding_server",
"leann_backend_diskann.embedding_server",
"leann_backend_hnsw.hnsw_embedding_server",
]
)
if not is_embedding_server:
logger.debug(f"Process on port {port} is not our embedding server")
return False
# Check model name
model_matches = _check_model_in_cmdline(cmdline, expected_model)
# Check passages file if provided
passages_matches = _check_passages_in_cmdline(cmdline, expected_passages_file)
result = model_matches and passages_matches
logger.debug(
f"model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}"
)
return result
def _check_model_in_cmdline(cmdline: list, expected_model: str) -> bool:
"""Check if the command line contains the expected model."""
if "--model-name" not in cmdline:
return False
model_idx = cmdline.index("--model-name")
if model_idx + 1 >= len(cmdline):
return False
actual_model = cmdline[model_idx + 1]
return actual_model == expected_model
def _check_passages_in_cmdline(cmdline: list, expected_passages_file: str) -> bool:
"""Check if the command line contains the expected passages file."""
if "--passages-file" not in cmdline:
return False # Expected but not found
passages_idx = cmdline.index("--passages-file")
if passages_idx + 1 >= len(cmdline):
return False
actual_passages = cmdline[passages_idx + 1]
expected_path = Path(expected_passages_file).resolve()
actual_path = Path(actual_passages).resolve()
return actual_path == expected_path
def _find_compatible_port_or_next_available(
start_port: int, model_name: str, passages_file: str, max_attempts: int = 100
) -> tuple[int, bool]:
"""
Find a port that either has a compatible server or is available.
Returns (port, is_compatible) where is_compatible indicates if we found a matching server.
"""
for port in range(start_port, start_port + max_attempts):
if not _check_port(port):
# Port is available
return port, False
# Port is in use, check if it's compatible
if _check_process_matches_config(port, model_name, passages_file):
logger.info(f"Found compatible server on port {port}")
return port, True
else:
logger.info(f"Port {port} has incompatible server, trying next port...")
raise RuntimeError(
f"Could not find compatible or available port in range {start_port}-{start_port + max_attempts}"
)
class EmbeddingServerManager:
@@ -62,16 +185,7 @@ class EmbeddingServerManager:
self.backend_module_name = backend_module_name
self.server_process: Optional[subprocess.Popen] = None
self.server_port: Optional[int] = None
# Track last-started config for in-process reuse only
self._server_config: Optional[dict] = None
self._atexit_registered = False
# Also register a weakref finalizer to ensure cleanup when manager is GC'ed
try:
import weakref
self._finalizer = weakref.finalize(self, self._finalize_process)
except Exception:
self._finalizer = None
def start_server(
self,
@@ -81,24 +195,26 @@ class EmbeddingServerManager:
**kwargs,
) -> tuple[bool, int]:
"""Start the embedding server."""
# passages_file may be present in kwargs for server CLI, but we don't need it here
passages_file = kwargs.get("passages_file")
# If this manager already has a live server, just reuse it
if self.server_process and self.server_process.poll() is None and self.server_port:
logger.info("Reusing in-process server")
return True, self.server_port
# Check if we have a compatible server already running
if self._has_compatible_running_server(model_name, passages_file):
logger.info("Found compatible running server!")
return True, port
# For Colab environment, use a different strategy
if _is_colab_environment():
logger.info("Detected Colab environment, using alternative startup strategy")
return self._start_server_colab(port, model_name, embedding_mode, **kwargs)
# Always pick a fresh available port
try:
actual_port = _get_available_port(port)
except RuntimeError:
logger.error("No available ports found")
return False, port
# Find a compatible port or next available
actual_port, is_compatible = _find_compatible_port_or_next_available(
port, model_name, passages_file
)
if is_compatible:
logger.info(f"Found compatible server on port {actual_port}")
return True, actual_port
# Start a new server
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
@@ -131,7 +247,17 @@ class EmbeddingServerManager:
logger.error(f"Failed to start embedding server in Colab: {e}")
return False, actual_port
# Note: No compatibility check needed; manager is per-searcher and configs are stable per instance
def _has_compatible_running_server(self, model_name: str, passages_file: str) -> bool:
"""Check if we have a compatible running server."""
if not (self.server_process and self.server_process.poll() is None and self.server_port):
return False
if _check_process_matches_config(self.server_port, model_name, passages_file):
logger.info(f"Existing server process (PID {self.server_process.pid}) is compatible")
return True
logger.info("Existing server process is incompatible. Should start a new server.")
return False
def _start_new_server(
self, port: int, model_name: str, embedding_mode: str, **kwargs
@@ -178,61 +304,22 @@ class EmbeddingServerManager:
project_root = Path(__file__).parent.parent.parent.parent.parent
logger.info(f"Command: {' '.join(command)}")
# In CI environment, redirect stdout to avoid buffer deadlock but keep stderr for debugging
# Embedding servers use many print statements that can fill stdout buffers
is_ci = os.environ.get("CI") == "true"
if is_ci:
stdout_target = subprocess.DEVNULL
stderr_target = None # Keep stderr for error debugging in CI
logger.info(
"CI environment detected, redirecting embedding server stdout to DEVNULL, keeping stderr"
)
else:
stdout_target = None # Direct to console for visible logs
stderr_target = None # Direct to console for visible logs
# Start embedding server subprocess
# Let server output go directly to console
# The server will respect LEANN_LOG_LEVEL environment variable
self.server_process = subprocess.Popen(
command,
cwd=project_root,
stdout=stdout_target,
stderr=stderr_target,
stdout=None, # Direct to console
stderr=None, # Direct to console
)
self.server_port = port
# Record config for in-process reuse
try:
self._server_config = {
"model_name": command[command.index("--model-name") + 1]
if "--model-name" in command
else "",
"passages_file": command[command.index("--passages-file") + 1]
if "--passages-file" in command
else "",
"embedding_mode": command[command.index("--embedding-mode") + 1]
if "--embedding-mode" in command
else "sentence-transformers",
}
except Exception:
self._server_config = {
"model_name": "",
"passages_file": "",
"embedding_mode": "sentence-transformers",
}
logger.info(f"Server process started with PID: {self.server_process.pid}")
# Register atexit callback only when we actually start a process
if not self._atexit_registered:
# Always attempt best-effort finalize at interpreter exit
atexit.register(self._finalize_process)
# Use a lambda to avoid issues with bound methods
atexit.register(lambda: self.stop_server() if self.server_process else None)
self._atexit_registered = True
# Touch finalizer so it knows there is a live process
if getattr(self, "_finalizer", None) is not None and not self._finalizer.alive:
try:
import weakref
self._finalizer = weakref.finalize(self, self._finalize_process)
except Exception:
pass
def _wait_for_server_ready(self, port: int) -> tuple[bool, int]:
"""Wait for the server to be ready."""
@@ -257,26 +344,22 @@ class EmbeddingServerManager:
if not self.server_process:
return
if self.server_process and self.server_process.poll() is not None:
if self.server_process.poll() is not None:
# Process already terminated
self.server_process = None
self.server_port = None
self._server_config = None
return
logger.info(
f"Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
)
# Use simple termination - our improved server shutdown should handle this properly
self.server_process.terminate()
try:
self.server_process.wait(timeout=5) # Give more time for graceful shutdown
logger.info(f"Server process {self.server_process.pid} terminated gracefully.")
self.server_process.wait(timeout=3)
logger.info(f"Server process {self.server_process.pid} terminated.")
except subprocess.TimeoutExpired:
logger.warning(
f"Server process {self.server_process.pid} did not terminate within 5 seconds, force killing..."
f"Server process {self.server_process.pid} did not terminate gracefully within 3 seconds, killing it."
)
self.server_process.kill()
try:
@@ -286,33 +369,15 @@ class EmbeddingServerManager:
logger.error(
f"Failed to kill server process {self.server_process.pid} - it may be hung"
)
# Don't hang indefinitely
# Clean up process resources with timeout to avoid CI hang
# Clean up process resources to prevent resource tracker warnings
try:
# Use shorter timeout in CI environments
is_ci = os.environ.get("CI") == "true"
timeout = 3 if is_ci else 10
self.server_process.wait(timeout=timeout)
logger.info(f"Server process {self.server_process.pid} cleanup completed")
except subprocess.TimeoutExpired:
logger.warning(f"Process cleanup timeout after {timeout}s, proceeding anyway")
except Exception as e:
logger.warning(f"Error during process cleanup: {e}")
finally:
self.server_process = None
self.server_port = None
self._server_config = None
def _finalize_process(self) -> None:
"""Best-effort cleanup used by weakref.finalize/atexit."""
try:
self.stop_server()
self.server_process.wait() # Ensure process is fully cleaned up
except Exception:
pass
def _adopt_existing_server(self, *args, **kwargs) -> None:
# Removed: cross-process adoption no longer supported
return
self.server_process = None
def _launch_server_process_colab(self, command: list, port: int) -> None:
"""Launch the server process with Colab-specific settings."""
@@ -328,16 +393,10 @@ class EmbeddingServerManager:
self.server_port = port
logger.info(f"Colab server process started with PID: {self.server_process.pid}")
# Register atexit callback (unified)
# Register atexit callback
if not self._atexit_registered:
atexit.register(self._finalize_process)
atexit.register(lambda: self.stop_server() if self.server_process else None)
self._atexit_registered = True
# Record config for in-process reuse is best-effort in Colab mode
self._server_config = {
"model_name": "",
"passages_file": "",
"embedding_mode": "sentence-transformers",
}
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
"""Wait for the server to be ready with Colab-specific timeout."""

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Literal, Optional
from typing import Any, Literal, Union
import numpy as np
@@ -35,7 +35,7 @@ class LeannBackendSearcherInterface(ABC):
@abstractmethod
def _ensure_server_running(
self, passages_source_file: str, port: Optional[int], **kwargs
self, passages_source_file: str, port: Union[int, None], **kwargs
) -> int:
"""Ensure server is running"""
pass
@@ -50,7 +50,7 @@ class LeannBackendSearcherInterface(ABC):
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: Optional[int] = None,
zmq_port: Union[int, None] = None,
**kwargs,
) -> dict[str, Any]:
"""Search for nearest neighbors
@@ -76,7 +76,7 @@ class LeannBackendSearcherInterface(ABC):
self,
query: str,
use_server_if_available: bool = True,
zmq_port: Optional[int] = None,
zmq_port: Union[int, None] = None,
) -> np.ndarray:
"""Compute embedding for a query string

View File

@@ -116,6 +116,7 @@ def handle_request(request):
f"--top-k={args.get('top_k', 5)}",
f"--complexity={args.get('complexity', 32)}",
]
result = subprocess.run(cmd, capture_output=True, text=True)
elif tool_name == "leann_status":

View File

@@ -45,42 +45,6 @@ leann build my-project --docs ./
claude
```
## 🚀 Advanced Usage Examples
### Index Entire Git Repository
```bash
# Index all tracked files in your git repository, note right now we will skip submodules, but we can add it back easily if you want
leann build my-repo --docs $(git ls-files) --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
# Index only specific file types from git
leann build my-python-code --docs $(git ls-files "*.py") --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
```
### Multiple Directories and Files
```bash
# Index multiple directories
leann build my-codebase --docs ./src ./tests ./docs ./config --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
# Mix files and directories
leann build my-project --docs ./README.md ./src/ ./package.json ./docs/ --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
# Specific files only
leann build my-configs --docs ./tsconfig.json ./package.json ./webpack.config.js --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
```
### Advanced Git Integration
```bash
# Index recently modified files
leann build recent-changes --docs $(git diff --name-only HEAD~10..HEAD) --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
# Index files matching pattern
leann build frontend --docs $(git ls-files "*.tsx" "*.ts" "*.jsx" "*.js") --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
# Index documentation and config files
leann build docs-and-configs --docs $(git ls-files "*.md" "*.yml" "*.yaml" "*.json" "*.toml") --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
```
**Try this in Claude Code:**
```
Help me understand this codebase. List available indexes and search for authentication patterns.

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "leann"
version = "0.2.8"
version = "0.2.7"
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
readme = "README.md"
requires-python = ">=3.9"

View File

@@ -43,7 +43,6 @@ dependencies = [
"mlx>=0.26.3; sys_platform == 'darwin' and platform_machine == 'arm64'",
"mlx-lm>=0.26.0; sys_platform == 'darwin' and platform_machine == 'arm64'",
"psutil>=5.8.0",
"pybind11>=3.0.0",
"pathspec>=0.12.1",
"nbconvert>=7.16.6",
"gitignore-parser>=0.1.12",
@@ -55,7 +54,7 @@ dev = [
"pytest-cov>=4.0",
"pytest-xdist>=3.0", # For parallel test execution
"black>=23.0",
"ruff==0.12.7", # Fixed version to ensure consistent formatting across all environments
"ruff>=0.1.0",
"matplotlib",
"huggingface-hub>=0.20.0",
"pre-commit>=3.5.0",
@@ -155,7 +154,7 @@ markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
"openai: marks tests that require OpenAI API key",
]
timeout = 300 # Reduced from 600s (10min) to 300s (5min) for CI safety
timeout = 600
addopts = [
"-v",
"--tb=short",

View File

@@ -6,11 +6,10 @@ This directory contains automated tests for the LEANN project using pytest.
### `test_readme_examples.py`
Tests the examples shown in README.md:
- The basic example code that users see first (parametrized for both HNSW and DiskANN backends)
- The basic example code that users see first
- Import statements work correctly
- Different backend options (HNSW, DiskANN)
- Different LLM configuration options (parametrized for both backends)
- **All main README examples are tested with both HNSW and DiskANN backends using pytest parametrization**
- Different LLM configuration options
### `test_basic.py`
Basic functionality tests that verify:
@@ -26,16 +25,6 @@ Tests the document RAG example functionality:
- Tests error handling with invalid parameters
- Verifies that normalized embeddings are detected and cosine distance is used
### `test_diskann_partition.py`
Tests DiskANN graph partitioning functionality:
- Tests DiskANN index building without partitioning (baseline)
- Tests automatic graph partitioning with `is_recompute=True`
- Verifies that partition files are created and large files are cleaned up for storage saving
- Tests search functionality with partitioned indices
- Validates medoid and max_base_norm file generation and usage
- Includes performance comparison between DiskANN (with partition) and HNSW
- **Note**: These tests are skipped in CI due to hardware requirements and computation time
## Running Tests
### Install test dependencies:
@@ -65,23 +54,15 @@ pytest tests/ -m "not openai"
# Skip slow tests
pytest tests/ -m "not slow"
# Run DiskANN partition tests (requires local machine, not CI)
pytest tests/test_diskann_partition.py
```
### Run with specific backend:
```bash
# Test only HNSW backend
pytest tests/test_basic.py::test_backend_basic[hnsw]
pytest tests/test_readme_examples.py::test_readme_basic_example[hnsw]
# Test only DiskANN backend
pytest tests/test_basic.py::test_backend_basic[diskann]
pytest tests/test_readme_examples.py::test_readme_basic_example[diskann]
# All DiskANN tests (parametrized + specialized partition tests)
pytest tests/ -k diskann
```
## CI/CD Integration

View File

@@ -64,9 +64,6 @@ def test_backend_basic(backend_name):
assert isinstance(results[0], SearchResult)
assert "topic 2" in results[0].text or "document" in results[0].text
# Ensure cleanup to avoid hanging background servers
searcher.cleanup()
@pytest.mark.skipif(
os.environ.get("CI") == "true", reason="Skip model tests in CI to avoid MPS memory issues"
@@ -93,5 +90,3 @@ def test_large_index():
searcher = LeannSearcher(index_path)
results = searcher.search(["word10 word20"], top_k=10)
assert len(results[0]) == 10
# Cleanup
searcher.cleanup()

View File

@@ -1,369 +0,0 @@
"""
Test DiskANN graph partitioning functionality.
Tests the automatic graph partitioning feature that was implemented to save
storage space by partitioning large DiskANN indices and safely deleting
redundant files while maintaining search functionality.
"""
import os
import tempfile
from pathlib import Path
import pytest
@pytest.mark.skipif(
os.environ.get("CI") == "true",
reason="Skip DiskANN partition tests in CI - requires specific hardware and large memory",
)
def test_diskann_without_partition():
"""Test DiskANN index building without partition (baseline)."""
from leann.api import LeannBuilder, LeannSearcher
with tempfile.TemporaryDirectory() as temp_dir:
index_path = str(Path(temp_dir) / "test_no_partition.leann")
# Test data - enough to trigger index building
texts = [
f"Document {i} discusses topic {i % 10} with detailed analysis of subject {i // 10}."
for i in range(500)
]
# Build without partition (is_recompute=False)
builder = LeannBuilder(
backend_name="diskann",
embedding_model="facebook/contriever",
embedding_mode="sentence-transformers",
num_neighbors=32,
search_list_size=50,
is_recompute=False, # No partition
)
for text in texts:
builder.add_text(text)
builder.build_index(index_path)
# Verify index was created
index_dir = Path(index_path).parent
assert index_dir.exists()
# Check that traditional DiskANN files exist
index_prefix = Path(index_path).stem
# Core DiskANN files (beam search index may not be created for small datasets)
required_files = [
f"{index_prefix}_disk.index",
f"{index_prefix}_pq_compressed.bin",
f"{index_prefix}_pq_pivots.bin",
]
# Check all generated files first for debugging
generated_files = [f.name for f in index_dir.glob(f"{index_prefix}*")]
print(f"Generated files: {generated_files}")
for required_file in required_files:
file_path = index_dir / required_file
assert file_path.exists(), f"Required file {required_file} not found"
# Ensure no partition files exist in non-partition mode
partition_files = [f"{index_prefix}_disk_graph.index", f"{index_prefix}_partition.bin"]
for partition_file in partition_files:
file_path = index_dir / partition_file
assert not file_path.exists(), (
f"Partition file {partition_file} should not exist in non-partition mode"
)
# Test search functionality
searcher = LeannSearcher(index_path)
results = searcher.search("topic 3 analysis", top_k=3)
assert len(results) > 0
assert all(result.score is not None and result.score != float("-inf") for result in results)
@pytest.mark.skipif(
os.environ.get("CI") == "true",
reason="Skip DiskANN partition tests in CI - requires specific hardware and large memory",
)
def test_diskann_with_partition():
"""Test DiskANN index building with automatic graph partitioning."""
from leann.api import LeannBuilder
with tempfile.TemporaryDirectory() as temp_dir:
index_path = str(Path(temp_dir) / "test_with_partition.leann")
# Test data - enough to trigger partitioning
texts = [
f"Document {i} explores subject {i % 15} with comprehensive coverage of area {i // 15}."
for i in range(500)
]
# Build with partition (is_recompute=True)
builder = LeannBuilder(
backend_name="diskann",
embedding_model="facebook/contriever",
embedding_mode="sentence-transformers",
num_neighbors=32,
search_list_size=50,
is_recompute=True, # Enable automatic partitioning
)
for text in texts:
builder.add_text(text)
builder.build_index(index_path)
# Verify index was created
index_dir = Path(index_path).parent
assert index_dir.exists()
# Check that partition files exist
index_prefix = Path(index_path).stem
partition_files = [
f"{index_prefix}_disk_graph.index", # Partitioned graph
f"{index_prefix}_partition.bin", # Partition metadata
f"{index_prefix}_pq_compressed.bin",
f"{index_prefix}_pq_pivots.bin",
]
for partition_file in partition_files:
file_path = index_dir / partition_file
assert file_path.exists(), f"Expected partition file {partition_file} not found"
# Check that large files were cleaned up (storage saving goal)
large_files = [f"{index_prefix}_disk.index", f"{index_prefix}_disk_beam_search.index"]
for large_file in large_files:
file_path = index_dir / large_file
assert not file_path.exists(), (
f"Large file {large_file} should have been deleted for storage saving"
)
# Verify required auxiliary files for partition mode exist
required_files = [
f"{index_prefix}_disk.index_medoids.bin",
f"{index_prefix}_disk.index_max_base_norm.bin",
]
for req_file in required_files:
file_path = index_dir / req_file
assert file_path.exists(), (
f"Required auxiliary file {req_file} missing for partition mode"
)
@pytest.mark.skipif(
os.environ.get("CI") == "true",
reason="Skip DiskANN partition tests in CI - requires specific hardware and large memory",
)
def test_diskann_partition_search_functionality():
"""Test that search works correctly with partitioned indices."""
from leann.api import LeannBuilder, LeannSearcher
with tempfile.TemporaryDirectory() as temp_dir:
index_path = str(Path(temp_dir) / "test_partition_search.leann")
# Create diverse test data
texts = [
"LEANN is a storage-efficient approximate nearest neighbor search system.",
"Graph partitioning helps reduce memory usage in large scale vector search.",
"DiskANN provides high-performance disk-based approximate nearest neighbor search.",
"Vector embeddings enable semantic search over unstructured text data.",
"Approximate nearest neighbor algorithms trade accuracy for speed and storage.",
] * 100 # Repeat to get enough data
# Build with partitioning
builder = LeannBuilder(
backend_name="diskann",
embedding_model="facebook/contriever",
embedding_mode="sentence-transformers",
is_recompute=True, # Enable partitioning
)
for text in texts:
builder.add_text(text)
builder.build_index(index_path)
# Test search with partitioned index
searcher = LeannSearcher(index_path)
# Test various queries
test_queries = [
("vector search algorithms", 5),
("LEANN storage efficiency", 3),
("graph partitioning memory", 4),
("approximate nearest neighbor", 7),
]
for query, top_k in test_queries:
results = searcher.search(query, top_k=top_k)
# Verify search results
assert len(results) == top_k, f"Expected {top_k} results for query '{query}'"
assert all(result.score is not None for result in results), (
"All results should have scores"
)
assert all(result.score != float("-inf") for result in results), (
"No result should have -inf score"
)
assert all(result.text is not None for result in results), (
"All results should have text"
)
# Scores should be in descending order (higher similarity first)
scores = [result.score for result in results]
assert scores == sorted(scores, reverse=True), (
"Results should be sorted by score descending"
)
@pytest.mark.skipif(
os.environ.get("CI") == "true",
reason="Skip DiskANN partition tests in CI - requires specific hardware and large memory",
)
def test_diskann_medoid_and_norm_files():
"""Test that medoid and max_base_norm files are correctly generated and used."""
import struct
from leann.api import LeannBuilder, LeannSearcher
with tempfile.TemporaryDirectory() as temp_dir:
index_path = str(Path(temp_dir) / "test_medoid_norm.leann")
# Small but sufficient dataset
texts = [f"Test document {i} with content about subject {i % 10}." for i in range(200)]
builder = LeannBuilder(
backend_name="diskann",
embedding_model="facebook/contriever",
embedding_mode="sentence-transformers",
is_recompute=True,
)
for text in texts:
builder.add_text(text)
builder.build_index(index_path)
index_dir = Path(index_path).parent
index_prefix = Path(index_path).stem
# Test medoids file
medoids_file = index_dir / f"{index_prefix}_disk.index_medoids.bin"
assert medoids_file.exists(), "Medoids file should be generated"
# Read and validate medoids file format
with open(medoids_file, "rb") as f:
nshards = struct.unpack("<I", f.read(4))[0]
one_val = struct.unpack("<I", f.read(4))[0]
medoid_id = struct.unpack("<I", f.read(4))[0]
assert nshards == 1, "Single-shot build should have 1 shard"
assert one_val == 1, "Expected value should be 1"
assert medoid_id >= 0, "Medoid ID should be valid (not hardcoded 0)"
# Test max_base_norm file
norm_file = index_dir / f"{index_prefix}_disk.index_max_base_norm.bin"
assert norm_file.exists(), "Max base norm file should be generated"
# Read and validate norm file
with open(norm_file, "rb") as f:
npts = struct.unpack("<I", f.read(4))[0]
ndims = struct.unpack("<I", f.read(4))[0]
norm_val = struct.unpack("<f", f.read(4))[0]
assert npts == 1, "Should have 1 norm point"
assert ndims == 1, "Should have 1 dimension"
assert norm_val > 0, "Norm value should be positive"
assert norm_val != float("inf"), "Norm value should be finite"
# Test that search works with these files
searcher = LeannSearcher(index_path)
results = searcher.search("test subject", top_k=3)
# Verify that scores are not -inf (which indicates norm file was loaded correctly)
assert len(results) > 0
assert all(result.score != float("-inf") for result in results), (
"Scores should not be -inf when norm file is correct"
)
@pytest.mark.skipif(
os.environ.get("CI") == "true",
reason="Skip performance comparison in CI - requires significant compute time",
)
def test_diskann_vs_hnsw_performance():
"""Compare DiskANN (with partition) vs HNSW performance."""
import time
from leann.api import LeannBuilder, LeannSearcher
with tempfile.TemporaryDirectory() as temp_dir:
# Test data
texts = [
f"Performance test document {i} covering topic {i % 20} in detail." for i in range(1000)
]
query = "performance topic test"
# Test DiskANN with partitioning
diskann_path = str(Path(temp_dir) / "perf_diskann.leann")
diskann_builder = LeannBuilder(
backend_name="diskann",
embedding_model="facebook/contriever",
embedding_mode="sentence-transformers",
is_recompute=True,
)
for text in texts:
diskann_builder.add_text(text)
start_time = time.time()
diskann_builder.build_index(diskann_path)
# Test HNSW
hnsw_path = str(Path(temp_dir) / "perf_hnsw.leann")
hnsw_builder = LeannBuilder(
backend_name="hnsw",
embedding_model="facebook/contriever",
embedding_mode="sentence-transformers",
is_recompute=True,
)
for text in texts:
hnsw_builder.add_text(text)
start_time = time.time()
hnsw_builder.build_index(hnsw_path)
# Compare search performance
diskann_searcher = LeannSearcher(diskann_path)
hnsw_searcher = LeannSearcher(hnsw_path)
# Warm up searches
diskann_searcher.search(query, top_k=5)
hnsw_searcher.search(query, top_k=5)
# Timed searches
start_time = time.time()
diskann_results = diskann_searcher.search(query, top_k=10)
diskann_search_time = time.time() - start_time
start_time = time.time()
hnsw_results = hnsw_searcher.search(query, top_k=10)
hnsw_search_time = time.time() - start_time
# Basic assertions
assert len(diskann_results) == 10
assert len(hnsw_results) == 10
assert all(r.score != float("-inf") for r in diskann_results)
assert all(r.score != float("-inf") for r in hnsw_results)
# Performance ratio (informational)
if hnsw_search_time > 0:
speed_ratio = hnsw_search_time / diskann_search_time
print(f"DiskANN search time: {diskann_search_time:.4f}s")
print(f"HNSW search time: {hnsw_search_time:.4f}s")
print(f"DiskANN is {speed_ratio:.2f}x faster than HNSW")

View File

@@ -58,9 +58,6 @@ def test_document_rag_simulated(test_data_dir):
@pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OpenAI API key not available")
@pytest.mark.skipif(
os.environ.get("CI") == "true", reason="Skip OpenAI tests in CI to avoid API costs"
)
def test_document_rag_openai(test_data_dir):
"""Test document_rag with OpenAI embeddings."""
with tempfile.TemporaryDirectory() as temp_dir:

View File

@@ -10,33 +10,29 @@ from pathlib import Path
import pytest
@pytest.mark.parametrize("backend_name", ["hnsw", "diskann"])
def test_readme_basic_example(backend_name):
"""Test the basic example from README.md with both backends."""
def test_readme_basic_example():
"""Test the basic example from README.md."""
# Skip on macOS CI due to MPS environment issues with all-MiniLM-L6-v2
if os.environ.get("CI") == "true" and platform.system() == "Darwin":
pytest.skip("Skipping on macOS CI due to MPS environment issues with all-MiniLM-L6-v2")
# Skip DiskANN on CI (Linux runners) due to C++ extension memory/hardware constraints
if os.environ.get("CI") == "true" and backend_name == "diskann":
pytest.skip("Skip DiskANN tests in CI due to resource constraints and instability")
# This is the exact code from README (with smaller model for CI)
from leann import LeannBuilder, LeannChat, LeannSearcher
from leann.api import SearchResult
with tempfile.TemporaryDirectory() as temp_dir:
INDEX_PATH = str(Path(temp_dir) / f"demo_{backend_name}.leann")
INDEX_PATH = str(Path(temp_dir) / "demo.leann")
# Build an index
# In CI, use a smaller model to avoid memory issues
if os.environ.get("CI") == "true":
builder = LeannBuilder(
backend_name=backend_name,
backend_name="hnsw",
embedding_model="sentence-transformers/all-MiniLM-L6-v2", # Smaller model
dimensions=384, # Smaller dimensions
)
else:
builder = LeannBuilder(backend_name=backend_name)
builder = LeannBuilder(backend_name="hnsw")
builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
builder.add_text("Tung Tung Tung Sahur called—they need their banana-crocodile hybrid back")
builder.build_index(INDEX_PATH)
@@ -56,15 +52,9 @@ def test_readme_basic_example(backend_name):
# Verify search results
assert len(results) > 0
assert isinstance(results[0], SearchResult)
assert results[0].score != float("-inf"), (
f"should return valid scores, got {results[0].score}"
)
# The second text about banana-crocodile should be more relevant
assert "banana" in results[0].text or "crocodile" in results[0].text
# Ensure we cleanup background embedding server
searcher.cleanup()
# Chat with your data (using simulated LLM to avoid external dependencies)
chat = LeannChat(INDEX_PATH, llm_config={"type": "simulated"})
response = chat.ask("How much storage does LEANN save?", top_k=1)
@@ -72,8 +62,6 @@ def test_readme_basic_example(backend_name):
# Verify chat works
assert isinstance(response, str)
assert len(response) > 0
# Cleanup chat resources
chat.cleanup()
def test_readme_imports():
@@ -122,31 +110,26 @@ def test_backend_options():
assert len(list(Path(diskann_path).parent.glob(f"{Path(diskann_path).stem}.*"))) > 0
@pytest.mark.parametrize("backend_name", ["hnsw", "diskann"])
def test_llm_config_simulated(backend_name):
"""Test simulated LLM configuration option with both backends."""
def test_llm_config_simulated():
"""Test simulated LLM configuration option."""
# Skip on macOS CI due to MPS environment issues with all-MiniLM-L6-v2
if os.environ.get("CI") == "true" and platform.system() == "Darwin":
pytest.skip("Skipping on macOS CI due to MPS environment issues with all-MiniLM-L6-v2")
# Skip DiskANN tests in CI due to hardware requirements
if os.environ.get("CI") == "true" and backend_name == "diskann":
pytest.skip("Skip DiskANN tests in CI - requires specific hardware and large memory")
from leann import LeannBuilder, LeannChat
with tempfile.TemporaryDirectory() as temp_dir:
# Build a simple index
index_path = str(Path(temp_dir) / f"test_{backend_name}.leann")
index_path = str(Path(temp_dir) / "test.leann")
# Use smaller model in CI to avoid memory issues
if os.environ.get("CI") == "true":
builder = LeannBuilder(
backend_name=backend_name,
backend_name="hnsw",
embedding_model="sentence-transformers/all-MiniLM-L6-v2",
dimensions=384,
)
else:
builder = LeannBuilder(backend_name=backend_name)
builder = LeannBuilder(backend_name="hnsw")
builder.add_text("Test document for LLM testing")
builder.build_index(index_path)

77
uv.lock generated
View File

@@ -2223,7 +2223,7 @@ wheels = [
[[package]]
name = "leann-backend-diskann"
version = "0.2.8"
version = "0.2.6"
source = { editable = "packages/leann-backend-diskann" }
dependencies = [
{ name = "leann-core" },
@@ -2235,14 +2235,14 @@ dependencies = [
[package.metadata]
requires-dist = [
{ name = "leann-core", specifier = "==0.2.8" },
{ name = "leann-core", specifier = "==0.2.6" },
{ name = "numpy" },
{ name = "protobuf", specifier = ">=3.19.0" },
]
[[package]]
name = "leann-backend-hnsw"
version = "0.2.8"
version = "0.2.6"
source = { editable = "packages/leann-backend-hnsw" }
dependencies = [
{ name = "leann-core" },
@@ -2255,7 +2255,7 @@ dependencies = [
[package.metadata]
requires-dist = [
{ name = "leann-core", specifier = "==0.2.8" },
{ name = "leann-core", specifier = "==0.2.6" },
{ name = "msgpack", specifier = ">=1.0.0" },
{ name = "numpy" },
{ name = "pyzmq", specifier = ">=23.0.0" },
@@ -2263,7 +2263,7 @@ requires-dist = [
[[package]]
name = "leann-core"
version = "0.2.8"
version = "0.2.6"
source = { editable = "packages/leann-core" }
dependencies = [
{ name = "accelerate" },
@@ -2272,8 +2272,8 @@ dependencies = [
{ name = "llama-index-core" },
{ name = "llama-index-embeddings-huggingface" },
{ name = "llama-index-readers-file" },
{ name = "mlx", marker = "platform_machine == 'arm64' and sys_platform == 'darwin'" },
{ name = "mlx-lm", marker = "platform_machine == 'arm64' and sys_platform == 'darwin'" },
{ name = "mlx", marker = "sys_platform == 'darwin'" },
{ name = "mlx-lm", marker = "sys_platform == 'darwin'" },
{ name = "msgpack" },
{ name = "nbconvert" },
{ name = "numpy", version = "2.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },
@@ -2302,8 +2302,8 @@ requires-dist = [
{ name = "llama-index-core", specifier = ">=0.12.0" },
{ name = "llama-index-embeddings-huggingface", specifier = ">=0.5.5" },
{ name = "llama-index-readers-file", specifier = ">=0.4.0" },
{ name = "mlx", marker = "platform_machine == 'arm64' and sys_platform == 'darwin'", specifier = ">=0.26.3" },
{ name = "mlx-lm", marker = "platform_machine == 'arm64' and sys_platform == 'darwin'", specifier = ">=0.26.0" },
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = ">=0.26.3" },
{ name = "mlx-lm", marker = "sys_platform == 'darwin'", specifier = ">=0.26.0" },
{ name = "msgpack", specifier = ">=1.0.0" },
{ name = "nbconvert", specifier = ">=7.0.0" },
{ name = "numpy", specifier = ">=1.20.0" },
@@ -2343,8 +2343,8 @@ dependencies = [
{ name = "llama-index-embeddings-huggingface" },
{ name = "llama-index-readers-file" },
{ name = "llama-index-vector-stores-faiss" },
{ name = "mlx", marker = "platform_machine == 'arm64' and sys_platform == 'darwin'" },
{ name = "mlx-lm", marker = "platform_machine == 'arm64' and sys_platform == 'darwin'" },
{ name = "mlx", marker = "sys_platform == 'darwin'" },
{ name = "mlx-lm", marker = "sys_platform == 'darwin'" },
{ name = "msgpack" },
{ name = "nbconvert" },
{ name = "numpy", version = "2.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },
@@ -2356,7 +2356,6 @@ dependencies = [
{ name = "pdfplumber" },
{ name = "protobuf" },
{ name = "psutil" },
{ name = "pybind11" },
{ name = "pymupdf" },
{ name = "pypdf2" },
{ name = "pypdfium2" },
@@ -2425,8 +2424,8 @@ requires-dist = [
{ name = "llama-index-readers-file", marker = "extra == 'test'", specifier = ">=0.4.0" },
{ name = "llama-index-vector-stores-faiss", specifier = ">=0.4.0" },
{ name = "matplotlib", marker = "extra == 'dev'" },
{ name = "mlx", marker = "platform_machine == 'arm64' and sys_platform == 'darwin'", specifier = ">=0.26.3" },
{ name = "mlx-lm", marker = "platform_machine == 'arm64' and sys_platform == 'darwin'", specifier = ">=0.26.0" },
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = ">=0.26.3" },
{ name = "mlx-lm", marker = "sys_platform == 'darwin'", specifier = ">=0.26.0" },
{ name = "msgpack", specifier = ">=1.1.1" },
{ name = "nbconvert", specifier = ">=7.16.6" },
{ name = "numpy", specifier = ">=1.26.0" },
@@ -2439,7 +2438,6 @@ requires-dist = [
{ name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.5.0" },
{ name = "protobuf", specifier = "==4.25.3" },
{ name = "psutil", specifier = ">=5.8.0" },
{ name = "pybind11", specifier = ">=3.0.0" },
{ name = "pymupdf", specifier = ">=1.26.0" },
{ name = "pypdf2", specifier = ">=3.0.0" },
{ name = "pypdfium2", specifier = ">=4.30.0" },
@@ -2451,7 +2449,7 @@ requires-dist = [
{ name = "python-docx", marker = "extra == 'documents'", specifier = ">=0.8.11" },
{ name = "python-dotenv", marker = "extra == 'test'", specifier = ">=1.0.0" },
{ name = "requests", specifier = ">=2.25.0" },
{ name = "ruff", marker = "extra == 'dev'", specifier = "==0.12.7" },
{ name = "ruff", marker = "extra == 'dev'", specifier = ">=0.1.0" },
{ name = "sentence-transformers", specifier = ">=2.2.0" },
{ name = "sentence-transformers", marker = "extra == 'test'", specifier = ">=2.2.0" },
{ name = "sglang" },
@@ -4360,15 +4358,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/10/15/6b30e77872012bbfe8265d42a01d5b3c17ef0ac0f2fae531ad91b6a6c02e/pyarrow-21.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:cdc4c17afda4dab2a9c0b79148a43a7f4e1094916b3e18d8975bfd6d6d52241f", size = 26227521 },
]
[[package]]
name = "pybind11"
version = "3.0.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/ef/83/698d120e257a116f2472c710932023ad779409adf2734d2e940f34eea2c5/pybind11-3.0.0.tar.gz", hash = "sha256:c3f07bce3ada51c3e4b76badfa85df11688d12c46111f9d242bc5c9415af7862", size = 544819 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/41/9c/85f50a5476832c3efc67b6d7997808388236ae4754bf53e1749b3bc27577/pybind11-3.0.0-py3-none-any.whl", hash = "sha256:7c5cac504da5a701b5163f0e6a7ba736c713a096a5378383c5b4b064b753f607", size = 292118 },
]
[[package]]
name = "pycparser"
version = "2.22"
@@ -5215,27 +5204,27 @@ wheels = [
[[package]]
name = "ruff"
version = "0.12.7"
version = "0.12.5"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/a1/81/0bd3594fa0f690466e41bd033bdcdf86cba8288345ac77ad4afbe5ec743a/ruff-0.12.7.tar.gz", hash = "sha256:1fc3193f238bc2d7968772c82831a4ff69252f673be371fb49663f0068b7ec71", size = 5197814 }
sdist = { url = "https://files.pythonhosted.org/packages/30/cd/01015eb5034605fd98d829c5839ec2c6b4582b479707f7c1c2af861e8258/ruff-0.12.5.tar.gz", hash = "sha256:b209db6102b66f13625940b7f8c7d0f18e20039bb7f6101fbdac935c9612057e", size = 5170722 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/e1/d2/6cb35e9c85e7a91e8d22ab32ae07ac39cc34a71f1009a6f9e4a2a019e602/ruff-0.12.7-py3-none-linux_armv6l.whl", hash = "sha256:76e4f31529899b8c434c3c1dede98c4483b89590e15fb49f2d46183801565303", size = 11852189 },
{ url = "https://files.pythonhosted.org/packages/63/5b/a4136b9921aa84638f1a6be7fb086f8cad0fde538ba76bda3682f2599a2f/ruff-0.12.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:789b7a03e72507c54fb3ba6209e4bb36517b90f1a3569ea17084e3fd295500fb", size = 12519389 },
{ url = "https://files.pythonhosted.org/packages/a8/c9/3e24a8472484269b6b1821794141f879c54645a111ded4b6f58f9ab0705f/ruff-0.12.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:2e1c2a3b8626339bb6369116e7030a4cf194ea48f49b64bb505732a7fce4f4e3", size = 11743384 },
{ url = "https://files.pythonhosted.org/packages/26/7c/458dd25deeb3452c43eaee853c0b17a1e84169f8021a26d500ead77964fd/ruff-0.12.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32dec41817623d388e645612ec70d5757a6d9c035f3744a52c7b195a57e03860", size = 11943759 },
{ url = "https://files.pythonhosted.org/packages/7f/8b/658798472ef260ca050e400ab96ef7e85c366c39cf3dfbef4d0a46a528b6/ruff-0.12.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:47ef751f722053a5df5fa48d412dbb54d41ab9b17875c6840a58ec63ff0c247c", size = 11654028 },
{ url = "https://files.pythonhosted.org/packages/a8/86/9c2336f13b2a3326d06d39178fd3448dcc7025f82514d1b15816fe42bfe8/ruff-0.12.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a828a5fc25a3efd3e1ff7b241fd392686c9386f20e5ac90aa9234a5faa12c423", size = 13225209 },
{ url = "https://files.pythonhosted.org/packages/76/69/df73f65f53d6c463b19b6b312fd2391dc36425d926ec237a7ed028a90fc1/ruff-0.12.7-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:5726f59b171111fa6a69d82aef48f00b56598b03a22f0f4170664ff4d8298efb", size = 14182353 },
{ url = "https://files.pythonhosted.org/packages/58/1e/de6cda406d99fea84b66811c189b5ea139814b98125b052424b55d28a41c/ruff-0.12.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:74e6f5c04c4dd4aba223f4fe6e7104f79e0eebf7d307e4f9b18c18362124bccd", size = 13631555 },
{ url = "https://files.pythonhosted.org/packages/6f/ae/625d46d5164a6cc9261945a5e89df24457dc8262539ace3ac36c40f0b51e/ruff-0.12.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d0bfe4e77fba61bf2ccadf8cf005d6133e3ce08793bbe870dd1c734f2699a3e", size = 12667556 },
{ url = "https://files.pythonhosted.org/packages/55/bf/9cb1ea5e3066779e42ade8d0cd3d3b0582a5720a814ae1586f85014656b6/ruff-0.12.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06bfb01e1623bf7f59ea749a841da56f8f653d641bfd046edee32ede7ff6c606", size = 12939784 },
{ url = "https://files.pythonhosted.org/packages/55/7f/7ead2663be5627c04be83754c4f3096603bf5e99ed856c7cd29618c691bd/ruff-0.12.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:e41df94a957d50083fd09b916d6e89e497246698c3f3d5c681c8b3e7b9bb4ac8", size = 11771356 },
{ url = "https://files.pythonhosted.org/packages/17/40/a95352ea16edf78cd3a938085dccc55df692a4d8ba1b3af7accbe2c806b0/ruff-0.12.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:4000623300563c709458d0ce170c3d0d788c23a058912f28bbadc6f905d67afa", size = 11612124 },
{ url = "https://files.pythonhosted.org/packages/4d/74/633b04871c669e23b8917877e812376827c06df866e1677f15abfadc95cb/ruff-0.12.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:69ffe0e5f9b2cf2b8e289a3f8945b402a1b19eff24ec389f45f23c42a3dd6fb5", size = 12479945 },
{ url = "https://files.pythonhosted.org/packages/be/34/c3ef2d7799c9778b835a76189c6f53c179d3bdebc8c65288c29032e03613/ruff-0.12.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:a07a5c8ffa2611a52732bdc67bf88e243abd84fe2d7f6daef3826b59abbfeda4", size = 12998677 },
{ url = "https://files.pythonhosted.org/packages/77/ab/aca2e756ad7b09b3d662a41773f3edcbd262872a4fc81f920dc1ffa44541/ruff-0.12.7-py3-none-win32.whl", hash = "sha256:c928f1b2ec59fb77dfdf70e0419408898b63998789cc98197e15f560b9e77f77", size = 11756687 },
{ url = "https://files.pythonhosted.org/packages/b4/71/26d45a5042bc71db22ddd8252ca9d01e9ca454f230e2996bb04f16d72799/ruff-0.12.7-py3-none-win_amd64.whl", hash = "sha256:9c18f3d707ee9edf89da76131956aba1270c6348bfee8f6c647de841eac7194f", size = 12912365 },
{ url = "https://files.pythonhosted.org/packages/4c/9b/0b8aa09817b63e78d94b4977f18b1fcaead3165a5ee49251c5d5c245bb2d/ruff-0.12.7-py3-none-win_arm64.whl", hash = "sha256:dfce05101dbd11833a0776716d5d1578641b7fddb537fe7fa956ab85d1769b69", size = 11982083 },
{ url = "https://files.pythonhosted.org/packages/d4/de/ad2f68f0798ff15dd8c0bcc2889558970d9a685b3249565a937cd820ad34/ruff-0.12.5-py3-none-linux_armv6l.whl", hash = "sha256:1de2c887e9dec6cb31fcb9948299de5b2db38144e66403b9660c9548a67abd92", size = 11819133 },
{ url = "https://files.pythonhosted.org/packages/f8/fc/c6b65cd0e7fbe60f17e7ad619dca796aa49fbca34bb9bea5f8faf1ec2643/ruff-0.12.5-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:d1ab65e7d8152f519e7dea4de892317c9da7a108da1c56b6a3c1d5e7cf4c5e9a", size = 12501114 },
{ url = "https://files.pythonhosted.org/packages/c5/de/c6bec1dce5ead9f9e6a946ea15e8d698c35f19edc508289d70a577921b30/ruff-0.12.5-py3-none-macosx_11_0_arm64.whl", hash = "sha256:962775ed5b27c7aa3fdc0d8f4d4433deae7659ef99ea20f783d666e77338b8cf", size = 11716873 },
{ url = "https://files.pythonhosted.org/packages/a1/16/cf372d2ebe91e4eb5b82a2275c3acfa879e0566a7ac94d331ea37b765ac8/ruff-0.12.5-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:73b4cae449597e7195a49eb1cdca89fd9fbb16140c7579899e87f4c85bf82f73", size = 11958829 },
{ url = "https://files.pythonhosted.org/packages/25/bf/cd07e8f6a3a6ec746c62556b4c4b79eeb9b0328b362bb8431b7b8afd3856/ruff-0.12.5-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8b13489c3dc50de5e2d40110c0cce371e00186b880842e245186ca862bf9a1ac", size = 11626619 },
{ url = "https://files.pythonhosted.org/packages/d8/c9/c2ccb3b8cbb5661ffda6925f81a13edbb786e623876141b04919d1128370/ruff-0.12.5-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f1504fea81461cf4841778b3ef0a078757602a3b3ea4b008feb1308cb3f23e08", size = 13221894 },
{ url = "https://files.pythonhosted.org/packages/6b/58/68a5be2c8e5590ecdad922b2bcd5583af19ba648f7648f95c51c3c1eca81/ruff-0.12.5-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:c7da4129016ae26c32dfcbd5b671fe652b5ab7fc40095d80dcff78175e7eddd4", size = 14163909 },
{ url = "https://files.pythonhosted.org/packages/bd/d1/ef6b19622009ba8386fdb792c0743f709cf917b0b2f1400589cbe4739a33/ruff-0.12.5-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ca972c80f7ebcfd8af75a0f18b17c42d9f1ef203d163669150453f50ca98ab7b", size = 13583652 },
{ url = "https://files.pythonhosted.org/packages/62/e3/1c98c566fe6809a0c83751d825a03727f242cdbe0d142c9e292725585521/ruff-0.12.5-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8dbbf9f25dfb501f4237ae7501d6364b76a01341c6f1b2cd6764fe449124bb2a", size = 12700451 },
{ url = "https://files.pythonhosted.org/packages/24/ff/96058f6506aac0fbc0d0fc0d60b0d0bd746240a0594657a2d94ad28033ba/ruff-0.12.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c47dea6ae39421851685141ba9734767f960113d51e83fd7bb9958d5be8763a", size = 12937465 },
{ url = "https://files.pythonhosted.org/packages/eb/d3/68bc5e7ab96c94b3589d1789f2dd6dd4b27b263310019529ac9be1e8f31b/ruff-0.12.5-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:c5076aa0e61e30f848846f0265c873c249d4b558105b221be1828f9f79903dc5", size = 11771136 },
{ url = "https://files.pythonhosted.org/packages/52/75/7356af30a14584981cabfefcf6106dea98cec9a7af4acb5daaf4b114845f/ruff-0.12.5-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:a5a4c7830dadd3d8c39b1cc85386e2c1e62344f20766be6f173c22fb5f72f293", size = 11601644 },
{ url = "https://files.pythonhosted.org/packages/c2/67/91c71d27205871737cae11025ee2b098f512104e26ffd8656fd93d0ada0a/ruff-0.12.5-py3-none-musllinux_1_2_i686.whl", hash = "sha256:46699f73c2b5b137b9dc0fc1a190b43e35b008b398c6066ea1350cce6326adcb", size = 12478068 },
{ url = "https://files.pythonhosted.org/packages/34/04/b6b00383cf2f48e8e78e14eb258942fdf2a9bf0287fbf5cdd398b749193a/ruff-0.12.5-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:5a655a0a0d396f0f072faafc18ebd59adde8ca85fb848dc1b0d9f024b9c4d3bb", size = 12991537 },
{ url = "https://files.pythonhosted.org/packages/3e/b9/053d6445dc7544fb6594785056d8ece61daae7214859ada4a152ad56b6e0/ruff-0.12.5-py3-none-win32.whl", hash = "sha256:dfeb2627c459b0b78ca2bbdc38dd11cc9a0a88bf91db982058b26ce41714ffa9", size = 11751575 },
{ url = "https://files.pythonhosted.org/packages/bc/0f/ab16e8259493137598b9149734fec2e06fdeda9837e6f634f5c4e35916da/ruff-0.12.5-py3-none-win_amd64.whl", hash = "sha256:ae0d90cf5f49466c954991b9d8b953bd093c32c27608e409ae3564c63c5306a5", size = 12882273 },
{ url = "https://files.pythonhosted.org/packages/00/db/c376b0661c24cf770cb8815268190668ec1330eba8374a126ceef8c72d55/ruff-0.12.5-py3-none-win_arm64.whl", hash = "sha256:48cdbfc633de2c5c37d9f090ba3b352d1576b0015bfc3bc98eaf230275b7e805", size = 11951564 },
]
[[package]]