Compare commits
13 Commits
Comparing fix/clean-... with v0.2.8
| Author | SHA1 | Date |
|---|---|---|
| | 609fa62fd5 | |
| | eab13434ef | |
| | b2390ccc14 | |
| | e8fca2c84a | |
| | 790ae14f69 | |
| | ac363072e6 | |
| | 93465af46c | |
| | 792ece67dc | |
| | 239e35e2e6 | |
| | 2fac0c6fbf | |
| | 9801aa581b | |
| | 5e97916608 | |
| | 8b9c2be8c9 | |
87 .github/workflows/build-reusable.yml (vendored)
@@ -28,7 +28,7 @@ jobs:
       - name: Install ruff
         run: |
-          uv tool install ruff==0.12.7
+          uv tool install ruff

       - name: Run ruff check
         run: |
@@ -54,16 +54,26 @@ jobs:
           python: '3.12'
         - os: ubuntu-22.04
           python: '3.13'
-        - os: macos-latest
+        - os: macos-14
           python: '3.9'
-        - os: macos-latest
+        - os: macos-14
           python: '3.10'
-        - os: macos-latest
+        - os: macos-14
           python: '3.11'
-        - os: macos-latest
+        - os: macos-14
           python: '3.12'
-        - os: macos-latest
+        - os: macos-14
           python: '3.13'
+        - os: macos-13
+          python: '3.9'
+        - os: macos-13
+          python: '3.10'
+        - os: macos-13
+          python: '3.11'
+        - os: macos-13
+          python: '3.12'
+        # Note: macos-13 + Python 3.13 excluded due to PyTorch compatibility
+        # (PyTorch 2.5+ supports Python 3.13 but not Intel Mac x86_64)
     runs-on: ${{ matrix.os }}

     steps:
@@ -109,41 +119,56 @@ jobs:
             uv pip install --system delocate
           fi

+      - name: Set macOS environment variables
+        if: runner.os == 'macOS'
+        run: |
+          # Use brew --prefix to automatically detect Homebrew installation path
+          HOMEBREW_PREFIX=$(brew --prefix)
+          echo "HOMEBREW_PREFIX=${HOMEBREW_PREFIX}" >> $GITHUB_ENV
+          echo "OpenMP_ROOT=${HOMEBREW_PREFIX}/opt/libomp" >> $GITHUB_ENV
+
+          # Set CMAKE_PREFIX_PATH to let CMake find all packages automatically
+          echo "CMAKE_PREFIX_PATH=${HOMEBREW_PREFIX}" >> $GITHUB_ENV
+
+          # Set compiler flags for OpenMP (required for both backends)
+          echo "LDFLAGS=-L${HOMEBREW_PREFIX}/opt/libomp/lib" >> $GITHUB_ENV
+          echo "CPPFLAGS=-I${HOMEBREW_PREFIX}/opt/libomp/include" >> $GITHUB_ENV
+
       - name: Build packages
         run: |
-          # Build core (platform independent) on all platforms for consistency
+          # Build core (platform independent)
           cd packages/leann-core
           uv build
           cd ../..

           # Build HNSW backend
           cd packages/leann-backend-hnsw
-          if [ "${{ matrix.os }}" == "macos-latest" ]; then
-            # Use system clang instead of homebrew LLVM for better compatibility
+          if [[ "${{ matrix.os }}" == macos-* ]]; then
+            # Use system clang for better compatibility
             export CC=clang
             export CXX=clang++
             export MACOSX_DEPLOYMENT_TARGET=11.0
-            uv build --wheel --python python
+            uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
           else
-            uv build --wheel --python python
+            uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
           fi
           cd ../..

           # Build DiskANN backend
           cd packages/leann-backend-diskann
-          if [ "${{ matrix.os }}" == "macos-latest" ]; then
-            # Use system clang instead of homebrew LLVM for better compatibility
+          if [[ "${{ matrix.os }}" == macos-* ]]; then
+            # Use system clang for better compatibility
             export CC=clang
             export CXX=clang++
-            # sgesdd_ is only available on macOS 13.3+
+            # DiskANN requires macOS 13.3+ for sgesdd_ LAPACK function
             export MACOSX_DEPLOYMENT_TARGET=13.3
-            uv build --wheel --python python
+            uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
           else
-            uv build --wheel --python python
+            uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
           fi
           cd ../..

-          # Build meta package (platform independent) on all platforms
+          # Build meta package (platform independent)
           cd packages/leann
           uv build
           cd ../..
@@ -160,15 +185,10 @@ jobs:
           fi
           cd ../..

-          # Repair DiskANN wheel - use show first to debug
+          # Repair DiskANN wheel
           cd packages/leann-backend-diskann
           if [ -d dist ]; then
-            echo "Checking DiskANN wheel contents before repair:"
-            unzip -l dist/*.whl | grep -E "\.so|\.pyd|_diskannpy" || echo "No .so files found"
-            auditwheel show dist/*.whl || echo "auditwheel show failed"
             auditwheel repair dist/*.whl -w dist_repaired
-            echo "Checking DiskANN wheel contents after repair:"
-            unzip -l dist_repaired/*.whl | grep -E "\.so|\.pyd|_diskannpy" || echo "No .so files found after repair"
             rm -rf dist
             mv dist_repaired dist
           fi
@@ -200,29 +220,22 @@ jobs:
           echo "📦 Built packages:"
           find packages/*/dist -name "*.whl" -o -name "*.tar.gz" | sort

       - name: Install built packages for testing
         run: |
           # Create a virtual environment with the correct Python version
-          uv venv --python python${{ matrix.python }}
+          uv venv --python ${{ matrix.python }}
           source .venv/bin/activate || source .venv/Scripts/activate

-          # Install the built wheels directly to ensure we use locally built packages
-          # Use only locally built wheels on all platforms for full consistency
-          FIND_LINKS="--find-links packages/leann-core/dist --find-links packages/leann/dist"
-          FIND_LINKS="$FIND_LINKS --find-links packages/leann-backend-hnsw/dist --find-links packages/leann-backend-diskann/dist"
-
-          uv pip install leann-core leann leann-backend-hnsw leann-backend-diskann \
-            $FIND_LINKS --force-reinstall
+          # Install packages using --find-links to prioritize local builds
+          uv pip install --find-links packages/leann-core/dist --find-links packages/leann-backend-hnsw/dist --find-links packages/leann-backend-diskann/dist packages/leann-core/dist/*.whl || uv pip install --find-links packages/leann-core/dist packages/leann-core/dist/*.tar.gz
+          uv pip install --find-links packages/leann-core/dist packages/leann-backend-hnsw/dist/*.whl
+          uv pip install --find-links packages/leann-core/dist packages/leann-backend-diskann/dist/*.whl
+          uv pip install packages/leann/dist/*.whl || uv pip install packages/leann/dist/*.tar.gz

           # Install test dependencies using extras
           uv pip install -e ".[test]"

-          # Debug: Check if _diskannpy module is installed correctly
-          echo "Checking installed DiskANN module structure:"
-          python -c "import leann_backend_diskann; print('leann_backend_diskann location:', leann_backend_diskann.__file__)" || echo "Failed to import leann_backend_diskann"
-          python -c "from leann_backend_diskann import _diskannpy; print('_diskannpy imported successfully')" || echo "Failed to import _diskannpy"
-          ls -la $(python -c "import leann_backend_diskann; import os; print(os.path.dirname(leann_backend_diskann.__file__))" 2>/dev/null) 2>/dev/null || echo "Failed to list module directory"

       - name: Run tests with pytest
         env:
           CI: true  # Mark as CI environment to skip memory-intensive tests
.pre-commit-config.yaml
@@ -1,6 +1,6 @@
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v5.0.0
+    rev: v4.5.0
     hooks:
       - id: trailing-whitespace
       - id: end-of-file-fixer
@@ -10,7 +10,7 @@ repos:
       - id: debug-statements

   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.12.7  # Fixed version to match pyproject.toml
+    rev: v0.2.1
     hooks:
       - id: ruff
       - id: ruff-format
25 README.md
@@ -3,10 +3,11 @@
 </p>

 <p align="center">
-  <img src="https://img.shields.io/badge/Python-3.9%2B-blue.svg" alt="Python 3.9+">
+  <img src="https://img.shields.io/badge/Python-3.9%20%7C%203.10%20%7C%203.11%20%7C%203.12%20%7C%203.13-blue.svg" alt="Python Versions">
+  <img src="https://github.com/yichuan-w/LEANN/actions/workflows/build-and-publish.yml/badge.svg" alt="CI Status">
+  <img src="https://img.shields.io/badge/Platform-Ubuntu%20%7C%20macOS%20(ARM64%2FIntel)-lightgrey" alt="Platform">
   <img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
-  <img src="https://img.shields.io/badge/Platform-Linux%20%7C%20macOS-lightgrey" alt="Platform">
-  <img src="https://img.shields.io/badge/MCP-Native%20Integration-blue?style=flat-square" alt="MCP Integration">
+  <img src="https://img.shields.io/badge/MCP-Native%20Integration-blue" alt="MCP Integration">
 </p>

 <h2 align="center">
@@ -97,6 +98,7 @@ uv sync

 </details>

+
 ## Quick Start

 Our declarative API makes RAG as easy as writing a config file.
@@ -188,7 +190,7 @@ All RAG examples share these common parameters. **Interactive mode** is available
 --force-rebuild          # Force rebuild index even if it exists

 # Embedding Parameters
---embedding-model MODEL  # e.g., facebook/contriever, text-embedding-3-small, nomic-embed-text, or mlx-community/multilingual-e5-base-mlx
+--embedding-model MODEL  # e.g., facebook/contriever, text-embedding-3-small, nomic-embed-text, or mlx-community/Qwen3-Embedding-0.6B-8bit
 --embedding-mode MODE    # sentence-transformers, openai, mlx, or ollama

 # LLM Parameters (Text generation models)
@@ -453,7 +455,7 @@ leann --help

 **To make it globally available:**
 ```bash
 # Install the LEANN CLI globally using uv tool
-uv tool install leann-core
+uv tool install leann

 # Now you can use leann from anywhere without activating venv
 leann --help
@@ -466,7 +468,7 @@ leann --help

 ### Usage Examples

 ```bash
-# build from a specific directory, and my_docs is the index name
+# Build from a specific directory; my-docs is the index name (you can also build from multiple directories or files)
 leann build my-docs --docs ./your_documents

 # Search your documents
@@ -541,16 +543,12 @@ Options:
 - **Dynamic batching:** Efficiently batch embedding computations for GPU utilization
 - **Two-level search:** Smart graph traversal that prioritizes promising nodes

-**Backends:**
-- **HNSW** (default): Ideal for most datasets with maximum storage savings through full recomputation
-- **DiskANN**: Advanced option with superior search performance, using PQ-based graph traversal with real-time reranking for the best speed-accuracy trade-off
+**Backends:** HNSW (default) for most use cases, with optional DiskANN support for billion-scale datasets.

 ## Benchmarks

-**[DiskANN vs HNSW Performance Comparison →](benchmarks/diskann_vs_hnsw_speed_comparison.py)** - Compare search performance between both backends
-
-**[Simple Example: Compare LEANN vs FAISS →](benchmarks/compare_faiss_vs_leann.py)** - See storage savings in action
+**[Simple Example: Compare LEANN vs FAISS →](benchmarks/compare_faiss_vs_leann.py)**
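Both backends are driven through the same Python interface. As a quick orientation, here is a minimal sketch using the `LeannBuilder`/`LeannSearcher` API from `leann.api`, the same calls exercised by the `benchmarks/diskann_vs_hnsw_speed_comparison.py` script removed later in this diff; the index path and documents are placeholders:

```python
# Minimal sketch of driving either backend through leann.api
# (mirrors the usage in the removed benchmark script).
# The index path and documents are illustrative placeholders.
from leann.api import LeannBuilder, LeannSearcher

builder = LeannBuilder(
    backend_name="hnsw",  # or "diskann"
    embedding_model="facebook/contriever",
    embedding_mode="sentence-transformers",
    is_recompute=True,  # recompute exact embeddings instead of storing them all
)
for text in ["LEANN trades storage for recomputation.", "HNSW is the default backend."]:
    builder.add_text(text)
builder.build_index("./my_docs.leann")

searcher = LeannSearcher("./my_docs.leann")
for result in searcher.search("storage savings", top_k=2):
    print(result.score)
```

Swapping `backend_name` between `"hnsw"` and `"diskann"` is the only change needed to compare the two backends.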
### 📊 Storage Comparison

| System | DPR (2.1M) | Wiki (60M) | Chat (400K) | Email (780K) | Browser (38K) |
@@ -609,8 +607,9 @@ We welcome more contributors! Feel free to open issues or submit PRs.

 This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/).

 ---
 ## Star History

 [](https://www.star-history.com/#yichuan-w/LEANN&Date)
 <p align="center">
   <strong>⭐ Star us on GitHub if Leann is useful for your research or applications!</strong>
 </p>
benchmarks/README.md
@@ -1,24 +1,9 @@
-# 🧪 LEANN Benchmarks & Testing
+# 🧪 Leann Sanity Checks

-This directory contains performance benchmarks and comprehensive tests for the LEANN system, including backend comparisons and sanity checks across different configurations.
+This directory contains comprehensive sanity checks for the Leann system, ensuring all components work correctly across different configurations.

 ## 📁 Test Files

-### `diskann_vs_hnsw_speed_comparison.py`
-Performance comparison between DiskANN and HNSW backends:
-- ✅ **Search latency** comparison with both backends using recompute
-- ✅ **Index size** and **build time** measurements
-- ✅ **Score validity** testing (ensures no -inf scores)
-- ✅ **Configurable dataset sizes** for different scales
-
-```bash
-# Quick comparison with 500 docs, 10 queries
-python benchmarks/diskann_vs_hnsw_speed_comparison.py
-
-# Large-scale comparison with 2000 docs, 20 queries
-python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20
-```
-
 ### `test_distance_functions.py`
 Tests all supported distance functions across DiskANN backend:
 - ✅ **MIPS** (Maximum Inner Product Search)
benchmarks/diskann_vs_hnsw_speed_comparison.py (deleted)
@@ -1,268 +0,0 @@
#!/usr/bin/env python3
"""
DiskANN vs HNSW Search Performance Comparison

This benchmark compares search performance between DiskANN and HNSW backends:
- DiskANN: With graph partitioning enabled (is_recompute=True)
- HNSW: With recompute enabled (is_recompute=True)
- Tests performance across different dataset sizes
- Measures search latency, recall, and index size
"""

import gc
import tempfile
import time
from pathlib import Path
from typing import Any

import numpy as np


def create_test_texts(n_docs: int) -> list[str]:
    """Create synthetic test documents for benchmarking."""
    np.random.seed(42)
    topics = [
        "machine learning and artificial intelligence",
        "natural language processing and text analysis",
        "computer vision and image recognition",
        "data science and statistical analysis",
        "deep learning and neural networks",
        "information retrieval and search engines",
        "database systems and data management",
        "software engineering and programming",
        "cybersecurity and network protection",
        "cloud computing and distributed systems",
    ]

    texts = []
    for i in range(n_docs):
        topic = topics[i % len(topics)]
        variation = np.random.randint(1, 100)
        text = (
            f"This is document {i} about {topic}. Content variation {variation}. "
            f"Additional information about {topic} with details and examples. "
            f"Technical discussion of {topic} including implementation aspects."
        )
        texts.append(text)

    return texts


def benchmark_backend(
    backend_name: str, texts: list[str], test_queries: list[str], backend_kwargs: dict[str, Any]
) -> dict[str, float]:
    """Benchmark a specific backend with the given configuration."""
    from leann.api import LeannBuilder, LeannSearcher

    print(f"\n🔧 Testing {backend_name.upper()} backend...")

    with tempfile.TemporaryDirectory() as temp_dir:
        index_path = str(Path(temp_dir) / f"benchmark_{backend_name}.leann")

        # Build index
        print(f"📦 Building {backend_name} index with {len(texts)} documents...")
        start_time = time.time()

        builder = LeannBuilder(
            backend_name=backend_name,
            embedding_model="facebook/contriever",
            embedding_mode="sentence-transformers",
            **backend_kwargs,
        )

        for text in texts:
            builder.add_text(text)

        builder.build_index(index_path)
        build_time = time.time() - start_time

        # Measure index size
        index_dir = Path(index_path).parent
        index_files = list(index_dir.glob(f"{Path(index_path).stem}.*"))
        total_size = sum(f.stat().st_size for f in index_files if f.is_file())
        size_mb = total_size / (1024 * 1024)

        print(f" ✅ Build completed in {build_time:.2f}s, index size: {size_mb:.1f}MB")

        # Search benchmark
        print("🔍 Running search benchmark...")
        searcher = LeannSearcher(index_path)

        search_times = []
        all_results = []

        for query in test_queries:
            start_time = time.time()
            results = searcher.search(query, top_k=5)
            search_time = time.time() - start_time
            search_times.append(search_time)
            all_results.append(results)

        avg_search_time = np.mean(search_times) * 1000  # Convert to ms
        print(f" ✅ Average search time: {avg_search_time:.1f}ms")

        # Check for valid scores (detect -inf issues)
        all_scores = [
            result.score
            for results in all_results
            for result in results
            if result.score is not None
        ]
        valid_scores = [
            score for score in all_scores if score != float("-inf") and score != float("inf")
        ]
        score_validity_rate = len(valid_scores) / len(all_scores) if all_scores else 0

        # Clean up
        try:
            if hasattr(searcher, "__del__"):
                searcher.__del__()
            del searcher
            del builder
            gc.collect()
        except Exception as e:
            print(f"⚠️ Warning: Resource cleanup error: {e}")

        return {
            "build_time": build_time,
            "avg_search_time_ms": avg_search_time,
            "index_size_mb": size_mb,
            "score_validity_rate": score_validity_rate,
        }


def run_comparison(n_docs: int = 500, n_queries: int = 10):
    """Run performance comparison between DiskANN and HNSW."""
    print("🚀 Starting DiskANN vs HNSW Performance Comparison")
    print(f"📊 Dataset: {n_docs} documents, {n_queries} test queries")

    # Create test data
    texts = create_test_texts(n_docs)
    test_queries = [
        "machine learning algorithms",
        "natural language processing",
        "computer vision techniques",
        "data analysis methods",
        "neural network architectures",
        "database query optimization",
        "software development practices",
        "security vulnerabilities",
        "cloud infrastructure",
        "distributed computing",
    ][:n_queries]

    # HNSW benchmark
    hnsw_results = benchmark_backend(
        backend_name="hnsw",
        texts=texts,
        test_queries=test_queries,
        backend_kwargs={
            "is_recompute": True,  # Enable recompute for fair comparison
            "M": 16,
            "efConstruction": 200,
        },
    )

    # DiskANN benchmark
    diskann_results = benchmark_backend(
        backend_name="diskann",
        texts=texts,
        test_queries=test_queries,
        backend_kwargs={
            "is_recompute": True,  # Enable graph partitioning
            "num_neighbors": 32,
            "search_list_size": 50,
        },
    )

    # Performance comparison
    print("\n📈 Performance Comparison Results")
    print(f"{'=' * 60}")
    print(f"{'Metric':<25} {'HNSW':<15} {'DiskANN':<15} {'Speedup':<10}")
    print(f"{'-' * 60}")

    # Build time comparison
    build_speedup = hnsw_results["build_time"] / diskann_results["build_time"]
    print(
        f"{'Build Time (s)':<25} {hnsw_results['build_time']:<15.2f} {diskann_results['build_time']:<15.2f} {build_speedup:<10.2f}x"
    )

    # Search time comparison
    search_speedup = hnsw_results["avg_search_time_ms"] / diskann_results["avg_search_time_ms"]
    print(
        f"{'Search Time (ms)':<25} {hnsw_results['avg_search_time_ms']:<15.1f} {diskann_results['avg_search_time_ms']:<15.1f} {search_speedup:<10.2f}x"
    )

    # Index size comparison
    size_ratio = diskann_results["index_size_mb"] / hnsw_results["index_size_mb"]
    print(
        f"{'Index Size (MB)':<25} {hnsw_results['index_size_mb']:<15.1f} {diskann_results['index_size_mb']:<15.1f} {size_ratio:<10.2f}x"
    )

    # Score validity
    print(
        f"{'Score Validity (%)':<25} {hnsw_results['score_validity_rate'] * 100:<15.1f} {diskann_results['score_validity_rate'] * 100:<15.1f}"
    )

    print(f"{'=' * 60}")
    print("\n🎯 Summary:")
    if search_speedup > 1:
        print(f" DiskANN is {search_speedup:.2f}x faster than HNSW for search")
    else:
        print(f" HNSW is {1 / search_speedup:.2f}x faster than DiskANN for search")

    if size_ratio > 1:
        print(f" DiskANN uses {size_ratio:.2f}x more storage than HNSW")
    else:
        print(f" DiskANN uses {1 / size_ratio:.2f}x less storage than HNSW")

    print(
        f" Both backends achieved {min(hnsw_results['score_validity_rate'], diskann_results['score_validity_rate']) * 100:.1f}% score validity"
    )


if __name__ == "__main__":
    import sys

    try:
        # Handle help request
        if len(sys.argv) > 1 and sys.argv[1] in ["-h", "--help", "help"]:
            print("DiskANN vs HNSW Performance Comparison")
            print("=" * 50)
            print(f"Usage: python {sys.argv[0]} [n_docs] [n_queries]")
            print()
            print("Arguments:")
            print("  n_docs     Number of documents to index (default: 500)")
            print("  n_queries  Number of test queries to run (default: 10)")
            print()
            print("Examples:")
            print("  python benchmarks/diskann_vs_hnsw_speed_comparison.py")
            print("  python benchmarks/diskann_vs_hnsw_speed_comparison.py 1000")
            print("  python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20")
            sys.exit(0)

        # Parse command line arguments
        n_docs = int(sys.argv[1]) if len(sys.argv) > 1 else 500
        n_queries = int(sys.argv[2]) if len(sys.argv) > 2 else 10

        print("DiskANN vs HNSW Performance Comparison")
        print("=" * 50)
        print(f"Dataset: {n_docs} documents, {n_queries} queries")
        print()

        run_comparison(n_docs=n_docs, n_queries=n_queries)

    except KeyboardInterrupt:
        print("\n⚠️ Benchmark interrupted by user")
        sys.exit(130)
    except Exception as e:
        print(f"\n❌ Benchmark failed: {e}")
        sys.exit(1)
    finally:
        # Ensure clean exit
        try:
            gc.collect()
            print("\n🧹 Cleanup completed")
        except Exception:
            pass
    sys.exit(0)
@@ -97,30 +97,16 @@ ollama pull nomic-embed-text
 ```

 ### DiskANN
-**Best for**: Performance-critical applications and large datasets - **Production-ready with automatic graph partitioning**
-
-**How it works:**
-- **Product Quantization (PQ) + Real-time Reranking**: Uses compressed PQ codes for fast graph traversal, then recomputes exact embeddings for final candidates
-- **Automatic Graph Partitioning**: When `is_recompute=True`, automatically partitions large indices and safely removes redundant files to save storage
-- **Superior Speed-Accuracy Trade-off**: Faster search than HNSW while maintaining high accuracy
-
-**Trade-offs compared to HNSW:**
-- ✅ **Faster search latency** (typically 2-8x speedup)
-- ✅ **Better scaling** for large datasets
-- ✅ **Smart storage management** with automatic partitioning
-- ✅ **Better graph locality** with `--ldg-times` parameter for SSD optimization
-- ⚠️ **Slightly larger index size** due to PQ tables and graph metadata
+**Best for**: Large datasets (> 10M vectors, 10GB+ index size) - **⚠️ Beta version, still in active development**
+- Uses Product Quantization (PQ) for coarse filtering during graph traversal
+- Novel approach: stores only PQ codes, performs rerank with exact computation in the final step
+- Implements a corner case of double-queue: prunes all neighbors and recomputes at the end

 ```bash
 # Recommended for most use cases
 --backend-name diskann --graph-degree 32 --build-complexity 64

-# For large-scale deployments
+# For billion-scale deployments
 --backend-name diskann --graph-degree 64 --build-complexity 128
 ```

-**Performance Benchmark**: Run `python benchmarks/diskann_vs_hnsw_speed_comparison.py` to compare DiskANN and HNSW on your system.

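The PQ-then-rerank idea described above can be illustrated with a small NumPy sketch: approximate scores computed from compressed codes select a candidate set cheaply, and exact scores are recomputed only for that shortlist. This is a toy model of the technique, not DiskANN's implementation; all shapes and parameters are arbitrary.

```python
# Toy illustration of PQ-style coarse filtering followed by exact reranking.
# Not DiskANN's actual code path; sizes and the int8 "codes" are stand-ins for PQ.
import numpy as np

rng = np.random.default_rng(0)
docs = rng.normal(size=(1000, 64)).astype(np.float32)  # exact vectors (recomputed on demand)
query = rng.normal(size=64).astype(np.float32)

# "Compress": quantize each vector to int8 as a stand-in for PQ codes
scale = np.abs(docs).max()
codes = np.round(docs / scale * 127).astype(np.int8)

# Coarse step: approximate inner products from the compressed codes
approx = (codes.astype(np.float32) * scale / 127) @ query
candidates = np.argsort(-approx)[:32]  # cheap shortlist

# Rerank step: exact inner products only for the shortlisted candidates
exact = docs[candidates] @ query
top5 = candidates[np.argsort(-exact)[:5]]
print(top5)
```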
 ## LLM Selection: Engine and Model Comparison

 ### LLM Engines
@@ -236,9 +222,15 @@ python apps/document_rag.py --query "What are the main techniques LEANN explores..."

 3. **Use MLX on Apple Silicon** (optional optimization):
    ```bash
-   --embedding-mode mlx --embedding-model mlx-community/multilingual-e5-base-mlx
+   --embedding-mode mlx --embedding-model mlx-community/Qwen3-Embedding-0.6B-8bit
    ```
+   In our tests, MLX gave only about a 1.3x speedup over Hugging Face embeddings, so Ollama may be the better choice for embedding generation.

+4. **Use Ollama**:
+   ```bash
+   --embedding-mode ollama --embedding-model nomic-embed-text
+   ```
+   To discover additional embedding models in Ollama, see https://ollama.com/search?c=embedding or read https://ollama.com/blog/embedding-models, and check which model size works best for you.
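As a concrete illustration of the Ollama route, the server exposes a local REST endpoint for embeddings. The sketch below assumes `ollama serve` is running on its default port and that `nomic-embed-text` has already been pulled:

```python
# Sketch: requesting an embedding from a local Ollama server via its REST API.
# Assumes `ollama serve` is running and `ollama pull nomic-embed-text` was done.
import json
from urllib.request import Request, urlopen

payload = json.dumps({
    "model": "nomic-embed-text",
    "prompt": "LEANN keeps vector storage small.",
}).encode()

req = Request(
    "http://localhost:11434/api/embeddings",
    data=payload,
    headers={"Content-Type": "application/json"},
)
with urlopen(req) as resp:
    embedding = json.loads(resp.read())["embedding"]

print(len(embedding))  # dimensionality of the returned vector
```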
 ### If Search Quality is Poor

 1. **Increase retrieval count**:
@@ -291,4 +283,3 @@ LEANN's recomputation feature provides exact distance calculations but can be disabled
 - [Lessons Learned Developing LEANN](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
 - [LEANN Technical Paper](https://arxiv.org/abs/2506.08276)
 - [DiskANN Original Paper](https://papers.nips.cc/paper/2019/file/09853c7fb1d3f8ee67a61b6bf4a7f8e6-Paper.pdf)
-- [SSD-based Graph Partitioning](https://github.com/SonglinLife/SSD_BASED_PLAN)
@@ -1,8 +0,0 @@
-# packages/leann-backend-diskann/CMakeLists.txt (simplified version)
-
-cmake_minimum_required(VERSION 3.20)
-project(leann_backend_diskann_wrapper)
-
-# Tell CMake to directly enter the DiskANN submodule and execute its own CMakeLists.txt
-# DiskANN will handle everything itself, including compiling Python bindings
-add_subdirectory(src/third_party/DiskANN)
@@ -1,7 +1 @@
 from . import diskann_backend as diskann_backend
-from . import graph_partition
-
-# Export main classes and functions
-from .graph_partition import GraphPartitioner, partition_graph
-
-__all__ = ["GraphPartitioner", "diskann_backend", "graph_partition", "partition_graph"]
@@ -137,71 +137,6 @@ class DiskannBuilder(LeannBackendBuilderInterface):
     def __init__(self, **kwargs):
         self.build_params = kwargs

-    def _safe_cleanup_after_partition(self, index_dir: Path, index_prefix: str):
-        """
-        Safely cleanup files after partition.
-        In partition mode, C++ doesn't read _disk.index content,
-        so we can delete it if all derived files exist.
-        """
-        disk_index_file = index_dir / f"{index_prefix}_disk.index"
-        beam_search_file = index_dir / f"{index_prefix}_disk_beam_search.index"
-
-        # Required files that C++ partition mode needs
-        # Note: C++ generates these with _disk.index suffix
-        disk_suffix = "_disk.index"
-        required_files = [
-            f"{index_prefix}{disk_suffix}_medoids.bin",  # Critical: assert fails if missing
-            # Note: _centroids.bin is not created in single-shot build - C++ handles this automatically
-            f"{index_prefix}_pq_pivots.bin",  # PQ table
-            f"{index_prefix}_pq_compressed.bin",  # PQ compressed vectors
-        ]
-
-        # Check if all required files exist
-        missing_files = []
-        for filename in required_files:
-            file_path = index_dir / filename
-            if not file_path.exists():
-                missing_files.append(filename)
-
-        if missing_files:
-            logger.warning(
-                f"Cannot safely delete _disk.index - missing required files: {missing_files}"
-            )
-            logger.info("Keeping all original files for safety")
-            return
-
-        # Calculate space savings
-        space_saved = 0
-        files_to_delete = []
-
-        if disk_index_file.exists():
-            space_saved += disk_index_file.stat().st_size
-            files_to_delete.append(disk_index_file)
-
-        if beam_search_file.exists():
-            space_saved += beam_search_file.stat().st_size
-            files_to_delete.append(beam_search_file)
-
-        # Safe to delete!
-        for file_to_delete in files_to_delete:
-            try:
-                os.remove(file_to_delete)
-                logger.info(f"✅ Safely deleted: {file_to_delete.name}")
-            except Exception as e:
-                logger.warning(f"Failed to delete {file_to_delete.name}: {e}")
-
-        if space_saved > 0:
-            space_saved_mb = space_saved / (1024 * 1024)
-            logger.info(f"💾 Space saved: {space_saved_mb:.1f} MB")
-
-        # Show what files are kept
-        logger.info("📁 Kept essential files for partition mode:")
-        for filename in required_files:
-            file_path = index_dir / filename
-            if file_path.exists():
-                size_mb = file_path.stat().st_size / (1024 * 1024)
-                logger.info(f"    - {filename} ({size_mb:.1f} MB)")
-
     def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
         path = Path(index_path)
         index_dir = path.parent
@@ -216,17 +151,6 @@ class DiskannBuilder(LeannBackendBuilderInterface):
             _write_vectors_to_bin(data, index_dir / data_filename)

             build_kwargs = {**self.build_params, **kwargs}

-            # Extract is_recompute from nested backend_kwargs if needed
-            is_recompute = build_kwargs.get("is_recompute", False)
-            if not is_recompute and "backend_kwargs" in build_kwargs:
-                is_recompute = build_kwargs["backend_kwargs"].get("is_recompute", False)
-
-            # Flatten all backend_kwargs parameters to top level for compatibility
-            if "backend_kwargs" in build_kwargs:
-                nested_params = build_kwargs.pop("backend_kwargs")
-                build_kwargs.update(nested_params)
-
             metric_enum = _get_diskann_metrics().get(
                 build_kwargs.get("distance_metric", "mips").lower()
             )
@@ -261,30 +185,6 @@ class DiskannBuilder(LeannBackendBuilderInterface):
                 build_kwargs.get("pq_disk_bytes", 0),
                 "",
             )

-            # Auto-partition if is_recompute is enabled
-            if build_kwargs.get("is_recompute", False):
-                logger.info("is_recompute=True, starting automatic graph partitioning...")
-                from .graph_partition import partition_graph
-
-                # Partition the index using absolute paths
-                # Convert to absolute paths to avoid issues with working directory changes
-                absolute_index_dir = Path(index_dir).resolve()
-                absolute_index_prefix_path = str(absolute_index_dir / index_prefix)
-                disk_graph_path, partition_bin_path = partition_graph(
-                    index_prefix_path=absolute_index_prefix_path,
-                    output_dir=str(absolute_index_dir),
-                    partition_prefix=index_prefix,
-                )
-
-                # Safe cleanup: In partition mode, C++ doesn't read _disk.index content
-                # but still needs the derived files (_medoids.bin, _centroids.bin, etc.)
-                self._safe_cleanup_after_partition(index_dir, index_prefix)
-
-                logger.info("✅ Graph partitioning completed successfully!")
-                logger.info(f"    - Disk graph: {disk_graph_path}")
-                logger.info(f"    - Partition file: {partition_bin_path}")
-
         finally:
             temp_data_file = index_dir / data_filename
             if temp_data_file.exists():
@@ -313,26 +213,7 @@ class DiskannSearcher(BaseSearcher):

         # For DiskANN, we need to reinitialize the index when zmq_port changes
         # Store the initialization parameters for later use
-        # Note: C++ load method expects the BASE path (without _disk.index suffix)
-        # C++ internally constructs: index_prefix + "_disk.index"
-        index_name = self.index_path.stem  # "simple_test.leann" -> "simple_test"
-        diskann_index_prefix = str(self.index_dir / index_name)  # /path/to/simple_test
-        full_index_prefix = diskann_index_prefix  # /path/to/simple_test (base path)
-
-        # Auto-detect partition files and set partition_prefix
-        partition_graph_file = self.index_dir / f"{index_name}_disk_graph.index"
-        partition_bin_file = self.index_dir / f"{index_name}_partition.bin"
-
-        partition_prefix = ""
-        if partition_graph_file.exists() and partition_bin_file.exists():
-            # C++ expects full path prefix, not just filename
-            partition_prefix = str(self.index_dir / index_name)  # /path/to/simple_test
-            logger.info(
-                f"✅ Detected partition files, using partition_prefix='{partition_prefix}'"
-            )
-        else:
-            logger.debug("No partition files detected, using standard index files")
-
+        full_index_prefix = str(self.index_dir / self.index_path.stem)
         self._init_params = {
             "metric_enum": metric_enum,
             "full_index_prefix": full_index_prefix,
@@ -340,14 +221,8 @@ class DiskannSearcher(BaseSearcher):
             "num_nodes_to_cache": kwargs.get("num_nodes_to_cache", 0),
             "cache_mechanism": 1,
             "pq_prefix": "",
-            "partition_prefix": partition_prefix,
+            "partition_prefix": "",
         }

-        # Log partition configuration for debugging
-        if partition_prefix:
-            logger.info(
-                f"✅ Detected partition files, using partition_prefix='{partition_prefix}'"
-            )
         self._diskannpy = diskannpy
         self._current_zmq_port = None
         self._index = None
@@ -459,25 +334,3 @@ class DiskannSearcher(BaseSearcher):
         string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]

         return {"labels": string_labels, "distances": distances}
-
-    def cleanup(self):
-        """Cleanup DiskANN-specific resources including C++ index."""
-        # Call parent cleanup first
-        super().cleanup()
-
-        # Delete the C++ index to trigger destructors
-        try:
-            if hasattr(self, "_index") and self._index is not None:
-                del self._index
-            self._index = None
-            self._current_zmq_port = None
-        except Exception:
-            pass
-
-        # Force garbage collection to ensure C++ objects are destroyed
-        try:
-            import gc
-
-            gc.collect()
-        except Exception:
-            pass
@@ -81,8 +81,7 @@ def create_diskann_embedding_server(
     with open(passages_file) as f:
         meta = json.load(f)

-    logger.info(f"Loading PassageManager with metadata_file_path: {passages_file}")
-    passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file)
+    passages = PassageManager(meta["passage_sources"])
     logger.info(
         f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
     )
leann_backend_diskann/graph_partition.py (deleted)
@@ -1,299 +0,0 @@
#!/usr/bin/env python3
"""
Graph Partition Module for LEANN DiskANN Backend

This module provides Python bindings for the graph partition functionality
of DiskANN, allowing users to partition disk-based indices for better
performance.
"""

import os
import shutil
import subprocess
import tempfile
from pathlib import Path
from typing import Optional


class GraphPartitioner:
    """
    A Python interface for DiskANN's graph partition functionality.

    This class provides methods to partition disk-based indices for improved
    search performance and memory efficiency.
    """

    def __init__(self, build_type: str = "release"):
        """
        Initialize the GraphPartitioner.

        Args:
            build_type: Build type for the executables ("debug" or "release")
        """
        self.build_type = build_type
        self._ensure_executables()

    def _get_executable_path(self, name: str) -> str:
        """Get the path to a graph partition executable."""
        # Get the directory where this Python module is located
        module_dir = Path(__file__).parent
        # Navigate to the graph_partition directory
        graph_partition_dir = module_dir.parent / "third_party" / "DiskANN" / "graph_partition"
        executable_path = graph_partition_dir / "build" / self.build_type / "graph_partition" / name

        if not executable_path.exists():
            raise FileNotFoundError(f"Executable {name} not found at {executable_path}")

        return str(executable_path)

    def _ensure_executables(self):
        """Ensure that the required executables are built."""
        try:
            self._get_executable_path("partitioner")
            self._get_executable_path("index_relayout")
        except FileNotFoundError:
            # Try to build the executables automatically
            print("Executables not found, attempting to build them...")
            self._build_executables()

    def _build_executables(self):
        """Build the required executables."""
        graph_partition_dir = (
            Path(__file__).parent.parent / "third_party" / "DiskANN" / "graph_partition"
        )
        original_dir = os.getcwd()

        try:
            os.chdir(graph_partition_dir)

            # Clean any existing build
            if (graph_partition_dir / "build").exists():
                shutil.rmtree(graph_partition_dir / "build")

            # Run the build script
            cmd = ["./build.sh", self.build_type, "split_graph", "/tmp/dummy"]
            subprocess.run(cmd, capture_output=True, text=True, cwd=graph_partition_dir)

            # Check if executables were created
            partitioner_path = self._get_executable_path("partitioner")
            relayout_path = self._get_executable_path("index_relayout")

            print(f"✅ Built partitioner: {partitioner_path}")
            print(f"✅ Built index_relayout: {relayout_path}")

        except Exception as e:
            raise RuntimeError(f"Failed to build executables: {e}")
        finally:
            os.chdir(original_dir)

    def partition_graph(
        self,
        index_prefix_path: str,
        output_dir: Optional[str] = None,
        partition_prefix: Optional[str] = None,
        **kwargs,
    ) -> tuple[str, str]:
        """
        Partition a disk-based index for improved performance.

        Args:
            index_prefix_path: Path to the index prefix (e.g., "/path/to/index")
            output_dir: Output directory for results (defaults to parent of index_prefix_path)
            partition_prefix: Prefix for output files (defaults to basename of index_prefix_path)
            **kwargs: Additional parameters for graph partitioning:
                - gp_times: Number of LDG partition iterations (default: 10)
                - lock_nums: Number of lock nodes (default: 10)
                - cut: Cut adjacency list degree (default: 100)
                - scale_factor: Scale factor (default: 1)
                - data_type: Data type (default: "float")
                - thread_nums: Number of threads (default: 10)

        Returns:
            Tuple of (disk_graph_index_path, partition_bin_path)

        Raises:
            RuntimeError: If the partitioning process fails
        """
        # Set default parameters
        params = {
            "gp_times": 10,
            "lock_nums": 10,
            "cut": 100,
            "scale_factor": 1,
            "data_type": "float",
            "thread_nums": 10,
            **kwargs,
        }

        # Determine output directory
        if output_dir is None:
            output_dir = str(Path(index_prefix_path).parent)

        # Create output directory if it doesn't exist
        Path(output_dir).mkdir(parents=True, exist_ok=True)

        # Determine partition prefix
        if partition_prefix is None:
            partition_prefix = Path(index_prefix_path).name

        # Get executable paths
        partitioner_path = self._get_executable_path("partitioner")
        relayout_path = self._get_executable_path("index_relayout")

        # Create temporary directory for processing
        with tempfile.TemporaryDirectory() as temp_dir:
            # Change to the graph_partition directory for temporary files
            graph_partition_dir = (
                Path(__file__).parent.parent / "third_party" / "DiskANN" / "graph_partition"
            )
            original_dir = os.getcwd()

            try:
                os.chdir(graph_partition_dir)

                # Create temporary data directory
                temp_data_dir = Path(temp_dir) / "data"
                temp_data_dir.mkdir(parents=True, exist_ok=True)

                # Set up paths for temporary files
                graph_path = temp_data_dir / "starling" / "_M_R_L_B" / "GRAPH"
                graph_gp_path = (
                    graph_path
                    / f"GP_TIMES_{params['gp_times']}_LOCK_{params['lock_nums']}_GP_USE_FREQ0_CUT{params['cut']}_SCALE{params['scale_factor']}"
                )
                graph_gp_path.mkdir(parents=True, exist_ok=True)

                # Find input index file
                old_index_file = f"{index_prefix_path}_disk_beam_search.index"
                if not os.path.exists(old_index_file):
                    old_index_file = f"{index_prefix_path}_disk.index"

                if not os.path.exists(old_index_file):
                    raise RuntimeError(f"Index file not found: {old_index_file}")

                # Run partitioner
                gp_file_path = graph_gp_path / "_part.bin"
                partitioner_cmd = [
                    partitioner_path,
                    "--index_file",
                    old_index_file,
                    "--data_type",
                    params["data_type"],
                    "--gp_file",
                    str(gp_file_path),
                    "-T",
                    str(params["thread_nums"]),
                    "--ldg_times",
                    str(params["gp_times"]),
                    "--scale",
                    str(params["scale_factor"]),
                    "--mode",
                    "1",
                ]

                print(f"Running partitioner: {' '.join(partitioner_cmd)}")
                result = subprocess.run(
                    partitioner_cmd, capture_output=True, text=True, cwd=graph_partition_dir
                )

                if result.returncode != 0:
                    raise RuntimeError(
                        f"Partitioner failed with return code {result.returncode}.\n"
                        f"stdout: {result.stdout}\n"
                        f"stderr: {result.stderr}"
                    )

                # Run relayout
                part_tmp_index = graph_gp_path / "_part_tmp.index"
                relayout_cmd = [
                    relayout_path,
                    old_index_file,
                    str(gp_file_path),
                    params["data_type"],
                    "1",
                ]

                print(f"Running relayout: {' '.join(relayout_cmd)}")
                result = subprocess.run(
                    relayout_cmd, capture_output=True, text=True, cwd=graph_partition_dir
                )

                if result.returncode != 0:
                    raise RuntimeError(
                        f"Relayout failed with return code {result.returncode}.\n"
                        f"stdout: {result.stdout}\n"
                        f"stderr: {result.stderr}"
                    )

                # Copy results to output directory
                disk_graph_path = Path(output_dir) / f"{partition_prefix}_disk_graph.index"
                partition_bin_path = Path(output_dir) / f"{partition_prefix}_partition.bin"

                shutil.copy2(part_tmp_index, disk_graph_path)
                shutil.copy2(gp_file_path, partition_bin_path)

                print(f"Results copied to: {output_dir}")
                return str(disk_graph_path), str(partition_bin_path)

            finally:
                os.chdir(original_dir)

    def get_partition_info(self, partition_bin_path: str) -> dict:
        """
        Get information about a partition file.

        Args:
            partition_bin_path: Path to the partition binary file

        Returns:
            Dictionary containing partition information
        """
        if not os.path.exists(partition_bin_path):
            raise FileNotFoundError(f"Partition file not found: {partition_bin_path}")

        # For now, return basic file information
        # In the future, this could parse the binary file for detailed info
        stat = os.stat(partition_bin_path)
        return {
            "file_size": stat.st_size,
            "file_path": partition_bin_path,
            "modified_time": stat.st_mtime,
        }


def partition_graph(
    index_prefix_path: str,
    output_dir: Optional[str] = None,
    partition_prefix: Optional[str] = None,
    build_type: str = "release",
    **kwargs,
) -> tuple[str, str]:
    """
    Convenience function to partition a graph index.

    Args:
        index_prefix_path: Path to the index prefix
        output_dir: Output directory (defaults to parent of index_prefix_path)
        partition_prefix: Prefix for output files (defaults to basename of index_prefix_path)
        build_type: Build type for executables ("debug" or "release")
        **kwargs: Additional parameters for graph partitioning

    Returns:
        Tuple of (disk_graph_index_path, partition_bin_path)
    """
    partitioner = GraphPartitioner(build_type=build_type)
    return partitioner.partition_graph(index_prefix_path, output_dir, partition_prefix, **kwargs)


# Example usage:
if __name__ == "__main__":
    # Example: partition an index
    try:
        disk_graph_path, partition_bin_path = partition_graph(
            "/path/to/your/index_prefix", gp_times=10, lock_nums=10, cut=100
        )
        print("Partitioning completed successfully!")
        print(f"Disk graph index: {disk_graph_path}")
        print(f"Partition binary: {partition_bin_path}")
    except Exception as e:
        print(f"Partitioning failed: {e}")
@@ -1,137 +0,0 @@
#!/usr/bin/env python3
"""
Simplified Graph Partition Module for LEANN DiskANN Backend

This module provides a simple Python interface for graph partitioning
that directly calls the existing executables.
"""

import os
import subprocess
import tempfile
from pathlib import Path
from typing import Optional


def partition_graph_simple(
    index_prefix_path: str, output_dir: Optional[str] = None, **kwargs
) -> tuple[str, str]:
    """
    Simple function to partition a graph index.

    Args:
        index_prefix_path: Path to the index prefix (e.g., "/path/to/index")
        output_dir: Output directory (defaults to parent of index_prefix_path)
        **kwargs: Additional parameters for graph partitioning

    Returns:
        Tuple of (disk_graph_index_path, partition_bin_path)
    """
    # Set default parameters
    params = {
        "gp_times": 10,
        "lock_nums": 10,
        "cut": 100,
        "scale_factor": 1,
        "data_type": "float",
        "thread_nums": 10,
        **kwargs,
    }

    # Determine output directory
    if output_dir is None:
        output_dir = str(Path(index_prefix_path).parent)

    # Find the graph_partition directory
    current_file = Path(__file__)
    graph_partition_dir = current_file.parent.parent / "third_party" / "DiskANN" / "graph_partition"

    if not graph_partition_dir.exists():
        raise RuntimeError(f"Graph partition directory not found: {graph_partition_dir}")

    # Find input index file
    old_index_file = f"{index_prefix_path}_disk_beam_search.index"
    if not os.path.exists(old_index_file):
        old_index_file = f"{index_prefix_path}_disk.index"

    if not os.path.exists(old_index_file):
        raise RuntimeError(f"Index file not found: {old_index_file}")

    # Create temporary directory for processing
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_data_dir = Path(temp_dir) / "data"
        temp_data_dir.mkdir(parents=True, exist_ok=True)

        # Set up paths for temporary files
        graph_path = temp_data_dir / "starling" / "_M_R_L_B" / "GRAPH"
        graph_gp_path = (
            graph_path
            / f"GP_TIMES_{params['gp_times']}_LOCK_{params['lock_nums']}_GP_USE_FREQ0_CUT{params['cut']}_SCALE{params['scale_factor']}"
        )
        graph_gp_path.mkdir(parents=True, exist_ok=True)

        # Run the build script with our parameters
        cmd = [str(graph_partition_dir / "build.sh"), "release", "split_graph", index_prefix_path]

        # Set environment variables for parameters
        env = os.environ.copy()
        env.update(
            {
                "GP_TIMES": str(params["gp_times"]),
                "GP_LOCK_NUMS": str(params["lock_nums"]),
                "GP_CUT": str(params["cut"]),
                "GP_SCALE_F": str(params["scale_factor"]),
                "DATA_TYPE": params["data_type"],
                "GP_T": str(params["thread_nums"]),
            }
        )

        print(f"Running graph partition with command: {' '.join(cmd)}")
        print(f"Working directory: {graph_partition_dir}")

        # Run the command
        result = subprocess.run(
            cmd, env=env, capture_output=True, text=True, cwd=graph_partition_dir
        )

        if result.returncode != 0:
            print(f"Command failed with return code {result.returncode}")
            print(f"stdout: {result.stdout}")
            print(f"stderr: {result.stderr}")
            raise RuntimeError(
                f"Graph partitioning failed with return code {result.returncode}.\n"
                f"stdout: {result.stdout}\n"
                f"stderr: {result.stderr}"
            )

        # Check if output files were created
        disk_graph_path = Path(output_dir) / "_disk_graph.index"
        partition_bin_path = Path(output_dir) / "_partition.bin"

        if not disk_graph_path.exists():
            raise RuntimeError(f"Expected output file not found: {disk_graph_path}")

        if not partition_bin_path.exists():
            raise RuntimeError(f"Expected output file not found: {partition_bin_path}")

        print("✅ Partitioning completed successfully!")
        print(f"  Disk graph index: {disk_graph_path}")
        print(f"  Partition binary: {partition_bin_path}")

        return str(disk_graph_path), str(partition_bin_path)


# Example usage
if __name__ == "__main__":
    try:
        disk_graph_path, partition_bin_path = partition_graph_simple(
            "/Users/yichuan/Desktop/release2/leann/diskannbuild/test_doc_files",
            gp_times=5,
            lock_nums=5,
            cut=50,
        )
        print("Success! Output files:")
        print(f"  - {disk_graph_path}")
        print(f"  - {partition_bin_path}")
    except Exception as e:
        print(f"Error: {e}")
packages/leann-backend-diskann/pyproject.toml
@@ -4,8 +4,8 @@ build-backend = "scikit_build_core.build"

 [project]
 name = "leann-backend-diskann"
-version = "0.2.5"
-dependencies = ["leann-core==0.2.5", "numpy", "protobuf>=3.19.0"]
+version = "0.2.8"
+dependencies = ["leann-core==0.2.8", "numpy", "protobuf>=3.19.0"]

 [tool.scikit-build]
 # Key: simplified CMake path
@@ -17,3 +17,5 @@ editable.mode = "redirect"
 cmake.build-type = "Release"
 build.verbose = true
 build.tool-args = ["-j8"]
+# Let CMake find packages via Homebrew prefix
+cmake.define = {CMAKE_PREFIX_PATH = {env = "CMAKE_PREFIX_PATH"}, OpenMP_ROOT = {env = "OpenMP_ROOT"}}
Submodule packages/leann-backend-diskann/third_party/DiskANN updated: b2dc4ea2c7...04048bb302
@@ -5,11 +5,20 @@ set(CMAKE_CXX_COMPILER_WORKS 1)

 # Set OpenMP path for macOS
 if(APPLE)
-  set(OpenMP_C_FLAGS "-Xpreprocessor -fopenmp -I/opt/homebrew/opt/libomp/include")
-  set(OpenMP_CXX_FLAGS "-Xpreprocessor -fopenmp -I/opt/homebrew/opt/libomp/include")
+  # Detect Homebrew installation path (Apple Silicon vs Intel)
+  if(EXISTS "/opt/homebrew/opt/libomp")
+    set(HOMEBREW_PREFIX "/opt/homebrew")
+  elseif(EXISTS "/usr/local/opt/libomp")
+    set(HOMEBREW_PREFIX "/usr/local")
+  else()
+    message(FATAL_ERROR "Could not find libomp installation. Please install with: brew install libomp")
+  endif()
+
+  set(OpenMP_C_FLAGS "-Xpreprocessor -fopenmp -I${HOMEBREW_PREFIX}/opt/libomp/include")
+  set(OpenMP_CXX_FLAGS "-Xpreprocessor -fopenmp -I${HOMEBREW_PREFIX}/opt/libomp/include")
   set(OpenMP_C_LIB_NAMES "omp")
   set(OpenMP_CXX_LIB_NAMES "omp")
-  set(OpenMP_omp_LIBRARY "/opt/homebrew/opt/libomp/lib/libomp.dylib")
+  set(OpenMP_omp_LIBRARY "${HOMEBREW_PREFIX}/opt/libomp/lib/libomp.dylib")

   # Force use of system libc++ to avoid version mismatch
   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -stdlib=libc++")
@@ -10,7 +10,7 @@ import sys
 import threading
 import time
 from pathlib import Path
-from typing import Optional
+from typing import Union

 import msgpack
 import numpy as np
@@ -34,7 +34,7 @@ if not logger.handlers:


 def create_hnsw_embedding_server(
-    passages_file: Optional[str] = None,
+    passages_file: Union[str, None] = None,
     zmq_port: int = 5555,
     model_name: str = "sentence-transformers/all-mpnet-base-v2",
     distance_metric: str = "mips",
@@ -82,8 +82,19 @@ def create_hnsw_embedding_server(
     with open(passages_file) as f:
         meta = json.load(f)

-    # Let PassageManager handle path resolution uniformly
-    passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file)
+    # Convert relative paths to absolute paths based on metadata file location
+    metadata_dir = Path(passages_file).parent.parent  # Go up one level from the metadata file
+    passage_sources = []
+    for source in meta["passage_sources"]:
+        source_copy = source.copy()
+        # Convert relative paths to absolute paths
+        if not Path(source_copy["path"]).is_absolute():
+            source_copy["path"] = str(metadata_dir / source_copy["path"])
+        if not Path(source_copy["index_path"]).is_absolute():
+            source_copy["index_path"] = str(metadata_dir / source_copy["index_path"])
+        passage_sources.append(source_copy)
+
+    passages = PassageManager(passage_sources)
     logger.info(
         f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
     )
packages/leann-backend-hnsw/pyproject.toml
@@ -6,10 +6,10 @@ build-backend = "scikit_build_core.build"

 [project]
 name = "leann-backend-hnsw"
-version = "0.2.5"
+version = "0.2.8"
 description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
 dependencies = [
-    "leann-core==0.2.5",
+    "leann-core==0.2.8",
     "numpy",
     "pyzmq>=23.0.0",
     "msgpack>=1.0.0",
@@ -22,6 +22,8 @@ cmake.build-type = "Release"
 build.verbose = true
 build.tool-args = ["-j8"]

-# CMake definitions to optimize compilation
+# CMake definitions to optimize compilation and find Homebrew packages
 [tool.scikit-build.cmake.define]
 CMAKE_BUILD_PARALLEL_LEVEL = "8"
+CMAKE_PREFIX_PATH = {env = "CMAKE_PREFIX_PATH"}
+OpenMP_ROOT = {env = "OpenMP_ROOT"}
Submodule packages/leann-backend-hnsw/third_party/faiss updated: ff22e2c86b...4a2c0d67d3
packages/leann-core/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "leann-core"
-version = "0.2.5"
+version = "0.2.8"
 description = "Core API and plugin system for LEANN"
 readme = "README.md"
 requires-python = ">=3.9"
@@ -31,8 +31,10 @@ dependencies = [
     "PyPDF2>=3.0.0",
     "pymupdf>=1.23.0",
     "pdfplumber>=0.10.0",
-    "mlx>=0.26.3; sys_platform == 'darwin'",
-    "mlx-lm>=0.26.0; sys_platform == 'darwin'",
+    "nbconvert>=7.0.0",  # For .ipynb file support
+    "gitignore-parser>=0.1.12",  # For proper .gitignore handling
+    "mlx>=0.26.3; sys_platform == 'darwin' and platform_machine == 'arm64'",
+    "mlx-lm>=0.26.0; sys_platform == 'darwin' and platform_machine == 'arm64'",
 ]

 [project.optional-dependencies]
@@ -87,26 +87,21 @@ def compute_embeddings_via_server(chunks: list[str], model_name: str, port: int)
     # Connect to embedding server
     context = zmq.Context()
     socket = context.socket(zmq.REQ)
-    socket.setsockopt(zmq.LINGER, 0)  # Don't block on close
     socket.setsockopt(zmq.RCVTIMEO, 300000)
     socket.setsockopt(zmq.SNDTIMEO, 300000)
-    socket.setsockopt(zmq.IMMEDIATE, 1)
     socket.connect(f"tcp://localhost:{port}")

-    try:
-        # Send chunks to server for embedding computation
-        request = chunks
-        socket.send(msgpack.packb(request))
+    # Send chunks to server for embedding computation
+    request = chunks
+    socket.send(msgpack.packb(request))

-        # Receive embeddings from server
-        response = socket.recv()
-        embeddings_list = msgpack.unpackb(response)
+    # Receive embeddings from server
+    response = socket.recv()
+    embeddings_list = msgpack.unpackb(response)

-        # Convert back to numpy array
-        embeddings = np.array(embeddings_list, dtype=np.float32)
-    finally:
-        socket.close()
-        # Don't call context.term() - this was causing hangs
+    # Convert back to numpy array
+    embeddings = np.array(embeddings_list, dtype=np.float32)

+    socket.close()
+    context.term()

     return embeddings
@@ -120,9 +115,7 @@ class SearchResult:

class PassageManager:
def __init__(
self, passage_sources: list[dict[str, Any]], metadata_file_path: Optional[str] = None
):
def __init__(self, passage_sources: list[dict[str, Any]]):
self.offset_maps = {}
self.passage_files = {}
self.global_offset_map = {}  # Combined map for fast lookup
@@ -132,26 +125,10 @@ class PassageManager:
passage_file = source["path"]
index_file = source["index_path"]  # .idx file

# Fix path resolution - relative paths should be relative to metadata file directory
# Fix path resolution for Colab and other environments
if not Path(index_file).is_absolute():
if metadata_file_path:
# Resolve relative to metadata file directory
metadata_dir = Path(metadata_file_path).parent
logger.debug(
f"PassageManager: Resolving relative paths from metadata_dir: {metadata_dir}"
)
index_file = str((metadata_dir / index_file).resolve())
passage_file = str((metadata_dir / passage_file).resolve())
logger.debug(f"PassageManager: Resolved index_file: {index_file}")
else:
# Fallback to current directory resolution (legacy behavior)
logger.warning(
"PassageManager: No metadata_file_path provided, using fallback resolution from cwd"
)
logger.debug(f"PassageManager: Current working directory: {Path.cwd()}")
index_file = str(Path(index_file).resolve())
passage_file = str(Path(passage_file).resolve())
logger.debug(f"PassageManager: Fallback resolved index_file: {index_file}")
# If relative path, try to resolve it properly
index_file = str(Path(index_file).resolve())

if not Path(index_file).exists():
raise FileNotFoundError(f"Passage index file not found: {index_file}")
@@ -337,8 +314,8 @@ class LeannBuilder:
"passage_sources": [
{
"type": "jsonl",
"path": passages_file.name,  # Use relative path (just filename)
"index_path": offset_file.name,  # Use relative path (just filename)
"path": str(passages_file),
"index_path": str(offset_file),
}
],
}
@@ -453,8 +430,8 @@ class LeannBuilder:
"passage_sources": [
{
"type": "jsonl",
"path": passages_file.name,  # Use relative path (just filename)
"index_path": offset_file.name,  # Use relative path (just filename)
"path": str(passages_file),
"index_path": str(offset_file),
}
],
"built_from_precomputed_embeddings": True,
@@ -496,9 +473,7 @@ class LeannSearcher:
self.embedding_model = self.meta_data["embedding_model"]
# Support both old and new format
self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
self.passage_manager = PassageManager(
self.meta_data.get("passage_sources", []), metadata_file_path=self.meta_path_str
)
self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
backend_factory = BACKEND_REGISTRY.get(backend_name)
if backend_factory is None:
raise ValueError(f"Backend '{backend_name}' not found.")
@@ -571,6 +546,7 @@ class LeannSearcher:
zmq_port=zmq_port,
**kwargs,
)
time.time() - start_time
# logger.info(f"  Search time: {search_time} seconds")
logger.info(f"  Backend returned: labels={len(results.get('labels', [[]])[0])} results")
@@ -611,11 +587,6 @@ class LeannSearcher:
logger.info(f"  {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}")
return enriched_results

def cleanup(self):
"""Cleanup embedding server and other resources."""
if hasattr(self.backend_impl, "cleanup"):
self.backend_impl.cleanup()

class LeannChat:
def __init__(
@@ -1,10 +1,11 @@
import argparse
import asyncio
from pathlib import Path
from typing import Optional
from typing import Union

from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
from tqdm import tqdm

from .api import LeannBuilder, LeannChat, LeannSearcher
@@ -75,11 +76,14 @@ class LeannCLI:
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
leann build my-docs --docs ./documents  # Build index named my-docs
leann build my-ppts --docs ./ --file-types .pptx,.pdf  # Index only PowerPoint and PDF files
leann search my-docs "query"  # Search in my-docs index
leann ask my-docs "question"  # Ask my-docs index
leann list  # List all stored indexes
leann build my-docs --docs ./documents  # Build index from directory
leann build my-code --docs ./src ./tests ./config  # Build index from multiple directories
leann build my-files --docs ./file1.py ./file2.txt ./docs/  # Build index from files and directories
leann build my-mixed --docs ./readme.md ./src/ ./config.json  # Build index from mixed files/dirs
leann build my-ppts --docs ./ --file-types .pptx,.pdf  # Index only PowerPoint and PDF files
leann search my-docs "query"  # Search in my-docs index
leann ask my-docs "question"  # Ask my-docs index
leann list  # List all stored indexes
""",
)
@@ -87,9 +91,15 @@ Examples:

# Build command
build_parser = subparsers.add_parser("build", help="Build document index")
build_parser.add_argument("index_name", help="Index name")
build_parser.add_argument(
"--docs", type=str, default=".", help="Documents directory (default: current directory)"
"index_name", nargs="?", help="Index name (default: current directory name)"
)
build_parser.add_argument(
"--docs",
type=str,
nargs="+",
default=["."],
help="Documents directories and/or files (default: current directory)",
)
build_parser.add_argument(
"--backend", type=str, default="hnsw", choices=["hnsw", "diskann"]
@@ -202,6 +212,63 @@ Examples:
with open(global_registry, "w") as f:
json.dump(projects, f, indent=2)

def _build_gitignore_parser(self, docs_dir: str):
"""Build gitignore parser using gitignore-parser library."""
from gitignore_parser import parse_gitignore

# Try to parse the root .gitignore
gitignore_path = Path(docs_dir) / ".gitignore"

if gitignore_path.exists():
try:
# gitignore-parser automatically handles all subdirectory .gitignore files!
matches = parse_gitignore(str(gitignore_path))
print(f"📋 Loaded .gitignore from {docs_dir} (includes all subdirectories)")
return matches
except Exception as e:
print(f"Warning: Could not parse .gitignore: {e}")
else:
print("📋 No .gitignore found")

# Fallback: basic pattern matching for essential files
essential_patterns = {".git", ".DS_Store", "__pycache__", "node_modules", ".venv", "venv"}

def basic_matches(file_path):
path_parts = Path(file_path).parts
return any(part in essential_patterns for part in path_parts)

return basic_matches
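`parse_gitignore` returns a plain callable, which is why the fallback above can mirror its interface. A minimal usage sketch (the `.gitignore` contents here are made up):

```python
from gitignore_parser import parse_gitignore

# Assume .gitignore contains "node_modules/" and "*.log"
matches = parse_gitignore(".gitignore")
print(matches("node_modules/react/index.js"))  # True -> excluded
print(matches("src/main.py"))                  # False -> kept
```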
def _should_exclude_file(self, relative_path: Path, gitignore_matches) -> bool:
"""Check if a file should be excluded using gitignore parser."""
return gitignore_matches(str(relative_path))

def _is_git_submodule(self, path: Path) -> bool:
"""Check if a path is a git submodule."""
try:
# Find the git repo root
current_dir = Path.cwd()
while current_dir != current_dir.parent:
if (current_dir / ".git").exists():
gitmodules_path = current_dir / ".gitmodules"
if gitmodules_path.exists():
# Read .gitmodules to check if this path is a submodule
gitmodules_content = gitmodules_path.read_text()
# Convert path to relative to git root
try:
relative_path = path.resolve().relative_to(current_dir)
# Check if this path appears in .gitmodules
return f"path = {relative_path}" in gitmodules_content
except ValueError:
# Path is not under git root
return False
break
current_dir = current_dir.parent
return False
except Exception:
# If anything goes wrong, assume it's not a submodule
return False

def list_indexes(self):
print("Stored LEANN indexes:")
@@ -231,7 +298,9 @@ Examples:
valid_projects.append(current_path)

if not valid_projects:
print("No indexes found. Use 'leann build <name> --docs <dir>' to create one.")
print(
"No indexes found. Use 'leann build <name> --docs <dir> [<dir2> ...]' to create one."
)
return

total_indexes = 0
@@ -278,41 +347,88 @@ Examples:
print(f'  leann search {example_name} "your query"')
print(f"  leann ask {example_name} --interactive")

def load_documents(self, docs_dir: str, custom_file_types: Optional[str] = None):
print(f"Loading documents from {docs_dir}...")
def load_documents(
self, docs_paths: Union[str, list], custom_file_types: Union[str, None] = None
):
# Handle both single path (string) and multiple paths (list) for backward compatibility
if isinstance(docs_paths, str):
docs_paths = [docs_paths]

# Separate files and directories
files = []
directories = []
for path in docs_paths:
path_obj = Path(path)
if path_obj.is_file():
files.append(str(path_obj))
elif path_obj.is_dir():
# Check if this is a git submodule - if so, skip it
if self._is_git_submodule(path_obj):
print(f"⚠️  Skipping git submodule: {path}")
continue
directories.append(str(path_obj))
else:
print(f"⚠️  Warning: Path '{path}' does not exist, skipping...")
continue

# Print summary of what we're processing
total_items = len(files) + len(directories)
items_desc = []
if files:
items_desc.append(f"{len(files)} file{'s' if len(files) > 1 else ''}")
if directories:
items_desc.append(
f"{len(directories)} director{'ies' if len(directories) > 1 else 'y'}"
)

print(f"Loading documents from {' and '.join(items_desc)} ({total_items} total):")
if files:
print(f"  📄 Files: {', '.join([Path(f).name for f in files])}")
if directories:
print(f"  📁 Directories: {', '.join(directories)}")

if custom_file_types:
print(f"Using custom file types: {custom_file_types}")

# Try to use better PDF parsers first
documents = []
docs_path = Path(docs_dir)
all_documents = []

for file_path in docs_path.rglob("*.pdf"):
print(f"Processing PDF: {file_path}")
# First, process individual files if any
if files:
print(f"\n🔄 Processing {len(files)} individual file{'s' if len(files) > 1 else ''}...")

# Try PyMuPDF first (best quality)
text = extract_pdf_text_with_pymupdf(str(file_path))
if text is None:
# Try pdfplumber
text = extract_pdf_text_with_pdfplumber(str(file_path))
# Load individual files using SimpleDirectoryReader with input_files
# Note: We skip gitignore filtering for explicitly specified files
try:
# Group files by their parent directory for efficient loading
from collections import defaultdict

if text:
# Create a simple document structure
from llama_index.core import Document
files_by_dir = defaultdict(list)
for file_path in files:
parent_dir = str(Path(file_path).parent)
files_by_dir[parent_dir].append(file_path)

doc = Document(text=text, metadata={"source": str(file_path)})
documents.append(doc)
else:
# Fallback to default reader
print(f"Using default reader for {file_path}")
default_docs = SimpleDirectoryReader(
str(file_path.parent),
filename_as_id=True,
required_exts=[file_path.suffix],
).load_data()
documents.extend(default_docs)
# Load files from each parent directory
for parent_dir, file_list in files_by_dir.items():
print(
f"  Loading {len(file_list)} file{'s' if len(file_list) > 1 else ''} from {parent_dir}"
)
try:
file_docs = SimpleDirectoryReader(
parent_dir,
input_files=file_list,
filename_as_id=True,
).load_data()
all_documents.extend(file_docs)
print(
f"    ✅ Loaded {len(file_docs)} document{'s' if len(file_docs) > 1 else ''}"
)
except Exception as e:
print(f"    ❌ Warning: Could not load files from {parent_dir}: {e}")

# Load other file types with default reader
except Exception as e:
print(f"❌ Error processing individual files: {e}")

# Define file extensions to process
if custom_file_types:
# Parse custom file types from comma-separated string
code_extensions = [ext.strip() for ext in custom_file_types.split(",") if ext.strip()]
@@ -374,20 +490,106 @@ Examples:
".py",
".jl",
]
# Try to load other file types, but don't fail if none are found
try:
other_docs = SimpleDirectoryReader(
docs_dir,
recursive=True,
encoding="utf-8",
required_exts=code_extensions,
).load_data(show_progress=True)
documents.extend(other_docs)
except ValueError as e:
if "No files found" in str(e):
print("No additional files found for other supported types.")
else:
raise e

# Process each directory
if directories:
print(
f"\n🔄 Processing {len(directories)} director{'ies' if len(directories) > 1 else 'y'}..."
)

for docs_dir in directories:
print(f"Processing directory: {docs_dir}")
# Build gitignore parser for each directory
gitignore_matches = self._build_gitignore_parser(docs_dir)

# Try to use better PDF parsers first, but only if PDFs are requested
documents = []
docs_path = Path(docs_dir)

# Check if we should process PDFs
should_process_pdfs = custom_file_types is None or ".pdf" in custom_file_types

if should_process_pdfs:
for file_path in docs_path.rglob("*.pdf"):
# Check if file matches any exclude pattern
try:
relative_path = file_path.relative_to(docs_path)
if self._should_exclude_file(relative_path, gitignore_matches):
continue
except ValueError:
# Skip files that can't be made relative to docs_path
print(f"⚠️  Skipping file outside directory scope: {file_path}")
continue

print(f"Processing PDF: {file_path}")

# Try PyMuPDF first (best quality)
text = extract_pdf_text_with_pymupdf(str(file_path))
if text is None:
# Try pdfplumber
text = extract_pdf_text_with_pdfplumber(str(file_path))

if text:
# Create a simple document structure
from llama_index.core import Document

doc = Document(text=text, metadata={"source": str(file_path)})
documents.append(doc)
else:
# Fallback to default reader
print(f"Using default reader for {file_path}")
try:
default_docs = SimpleDirectoryReader(
str(file_path.parent),
filename_as_id=True,
required_exts=[file_path.suffix],
).load_data()
documents.extend(default_docs)
except Exception as e:
print(f"Warning: Could not process {file_path}: {e}")

# Load other file types with default reader
try:
# Create a custom file filter function using our PathSpec
def file_filter(
file_path: str, docs_dir=docs_dir, gitignore_matches=gitignore_matches
) -> bool:
"""Return True if file should be included (not excluded)"""
try:
docs_path_obj = Path(docs_dir)
file_path_obj = Path(file_path)
relative_path = file_path_obj.relative_to(docs_path_obj)
return not self._should_exclude_file(relative_path, gitignore_matches)
except (ValueError, OSError):
return True  # Include files that can't be processed

other_docs = SimpleDirectoryReader(
docs_dir,
recursive=True,
encoding="utf-8",
required_exts=code_extensions,
file_extractor={},  # Use default extractors
filename_as_id=True,
).load_data(show_progress=True)

# Filter documents after loading based on gitignore rules
filtered_docs = []
for doc in other_docs:
file_path = doc.metadata.get("file_path", "")
if file_filter(file_path):
filtered_docs.append(doc)

documents.extend(filtered_docs)
except ValueError as e:
if "No files found" in str(e):
print(f"No additional files found for other supported types in {docs_dir}.")
else:
raise e

all_documents.extend(documents)
print(f"Loaded {len(documents)} documents from {docs_dir}")

documents = all_documents

all_texts = []
@@ -438,7 +640,9 @@ Examples:
".jl",
}

for doc in documents:
print("start chunking documents")
# Add progress bar for document chunking
for doc in tqdm(documents, desc="Chunking documents", unit="doc"):
# Check if this is a code file based on source path
source_path = doc.metadata.get("source", "")
is_code_file = any(source_path.endswith(ext) for ext in code_file_exts)
|
||||
return all_texts
|
||||
|
||||
async def build_index(self, args):
|
||||
docs_dir = args.docs
|
||||
index_name = args.index_name
|
||||
docs_paths = args.docs
|
||||
# Use current directory name if index_name not provided
|
||||
if args.index_name:
|
||||
index_name = args.index_name
|
||||
else:
|
||||
index_name = Path.cwd().name
|
||||
print(f"Using current directory name as index: '{index_name}'")
|
||||
|
||||
index_dir = self.indexes_dir / index_name
|
||||
index_path = self.get_index_path(index_name)
|
||||
|
||||
print(f"📂 Indexing: {Path(docs_dir).resolve()}")
|
||||
# Display all paths being indexed with file/directory distinction
|
||||
files = [p for p in docs_paths if Path(p).is_file()]
|
||||
directories = [p for p in docs_paths if Path(p).is_dir()]
|
||||
|
||||
print(f"📂 Indexing {len(docs_paths)} path{'s' if len(docs_paths) > 1 else ''}:")
|
||||
if files:
|
||||
print(f" 📄 Files ({len(files)}):")
|
||||
for i, file_path in enumerate(files, 1):
|
||||
print(f" {i}. {Path(file_path).resolve()}")
|
||||
if directories:
|
||||
print(f" 📁 Directories ({len(directories)}):")
|
||||
for i, dir_path in enumerate(directories, 1):
|
||||
print(f" {i}. {Path(dir_path).resolve()}")
|
||||
|
||||
if index_dir.exists() and not args.force:
|
||||
print(f"Index '{index_name}' already exists. Use --force to rebuild.")
|
||||
return
|
||||
|
||||
all_texts = self.load_documents(docs_dir, args.file_types)
|
||||
all_texts = self.load_documents(docs_paths, args.file_types)
|
||||
if not all_texts:
|
||||
print("No documents found")
|
||||
return
|
||||
@@ -501,7 +723,7 @@ Examples:

if not self.index_exists(index_name):
print(
f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it."
f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir> [<dir2> ...]' to create it."
)
return

@@ -528,7 +750,7 @@ Examples:

if not self.index_exists(index_name):
print(
f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it."
f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir> [<dir2> ...]' to create it."
)
return
@@ -6,7 +6,6 @@ Preserves all optimization parameters to ensure performance

import logging
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any

import numpy as np
@@ -374,7 +373,9 @@ def compute_embeddings_ollama(
texts: list[str], model_name: str, is_build: bool = False, host: str = "http://localhost:11434"
) -> np.ndarray:
"""
Compute embeddings using Ollama API.
Compute embeddings using Ollama API with simplified batch processing.

Uses batch size of 32 for MPS/CPU and 128 for CUDA to optimize performance.

Args:
texts: List of texts to compute embeddings for
@@ -438,12 +439,19 @@ def compute_embeddings_ollama(
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5"]):
embedding_models.append(model)

# Check if model exists (handle versioned names)
model_found = any(
model_name == name.split(":")[0] or model_name == name for name in model_names
)
# Check if model exists (handle versioned names) and resolve to full name
resolved_model_name = None
for name in model_names:
# Exact match
if model_name == name:
resolved_model_name = name
break
# Match without version tag (use the versioned name)
elif model_name == name.split(":")[0]:
resolved_model_name = name
break

if not model_found:
if not resolved_model_name:
error_msg = f"❌ Model '{model_name}' not found in local Ollama.\n\n"

# Suggest pulling the model
@@ -465,6 +473,11 @@ def compute_embeddings_ollama(
error_msg += "\n📚 Browse more: https://ollama.com/library"
raise ValueError(error_msg)

# Use the resolved model name for all subsequent operations
if resolved_model_name != model_name:
logger.info(f"Resolved model name '{model_name}' to '{resolved_model_name}'")
model_name = resolved_model_name
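A minimal sketch of that resolution rule in isolation, with a made-up list of installed tags:

```python
def resolve(model_name, installed):
    # Prefer an exact match, then fall back to matching the name without its version tag.
    for name in installed:
        if model_name == name or model_name == name.split(":")[0]:
            return name
    return None

print(resolve("nomic-embed-text", ["nomic-embed-text:latest"]))  # nomic-embed-text:latest
print(resolve("bge-m3:567m", ["bge-m3:567m"]))                   # bge-m3:567m
```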
# Verify the model supports embeddings by testing it
try:
test_response = requests.post(
@@ -485,138 +498,148 @@ def compute_embeddings_ollama(
except requests.exceptions.RequestException as e:
logger.warning(f"Could not verify model existence: {e}")

# Process embeddings with optimized concurrent processing
import requests
# Determine batch size based on device availability
# Check for CUDA/MPS availability using torch if available
batch_size = 32  # Default for MPS/CPU
try:
import torch

def get_single_embedding(text_idx_tuple):
"""Helper function to get embedding for a single text."""
text, idx = text_idx_tuple
max_retries = 3
retry_count = 0
if torch.cuda.is_available():
batch_size = 128  # CUDA gets larger batch size
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
batch_size = 32  # MPS gets smaller batch size
except ImportError:
# If torch is not available, use conservative batch size
batch_size = 32

# Truncate very long texts to avoid API issues
truncated_text = text[:8000] if len(text) > 8000 else text
logger.info(f"Using batch size: {batch_size}")

while retry_count < max_retries:
try:
response = requests.post(
f"{host}/api/embeddings",
json={"model": model_name, "prompt": truncated_text},
timeout=30,
)
response.raise_for_status()
def get_batch_embeddings(batch_texts):
"""Get embeddings for a batch of texts."""
all_embeddings = []
failed_indices = []

result = response.json()
embedding = result.get("embedding")
for i, text in enumerate(batch_texts):
max_retries = 3
retry_count = 0

if embedding is None:
raise ValueError(f"No embedding returned for text {idx}")

return idx, embedding

except requests.exceptions.Timeout:
retry_count += 1
if retry_count >= max_retries:
logger.warning(f"Timeout for text {idx} after {max_retries} retries")
return idx, None

except Exception as e:
if retry_count >= max_retries - 1:
logger.error(f"Failed to get embedding for text {idx}: {e}")
return idx, None
retry_count += 1

return idx, None

# Determine if we should use concurrent processing
use_concurrent = (
len(texts) > 5 and not is_build
)  # Don't use concurrent in build mode to avoid overwhelming
max_workers = min(4, len(texts))  # Limit concurrent requests to avoid overwhelming Ollama

all_embeddings = [None] * len(texts)  # Pre-allocate list to maintain order
failed_indices = []

if use_concurrent:
logger.info(
f"Using concurrent processing with {max_workers} workers for {len(texts)} texts"
)

with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit all tasks
future_to_idx = {
executor.submit(get_single_embedding, (text, idx)): idx
for idx, text in enumerate(texts)
}

# Add progress bar for concurrent processing
try:
if is_build or len(texts) > 10:
from tqdm import tqdm

futures_iterator = tqdm(
as_completed(future_to_idx),
total=len(texts),
desc="Computing Ollama embeddings",
)
else:
futures_iterator = as_completed(future_to_idx)
except ImportError:
futures_iterator = as_completed(future_to_idx)

# Collect results as they complete
for future in futures_iterator:
# Truncate very long texts to avoid API issues
truncated_text = text[:8000] if len(text) > 8000 else text
while retry_count < max_retries:
try:
idx, embedding = future.result()
if embedding is not None:
all_embeddings[idx] = embedding
else:
failed_indices.append(idx)
response = requests.post(
f"{host}/api/embeddings",
json={"model": model_name, "prompt": truncated_text},
timeout=30,
)
response.raise_for_status()

result = response.json()
embedding = result.get("embedding")

if embedding is None:
raise ValueError(f"No embedding returned for text {i}")

if not isinstance(embedding, list) or len(embedding) == 0:
raise ValueError(f"Invalid embedding format for text {i}")

all_embeddings.append(embedding)
break

except requests.exceptions.Timeout:
retry_count += 1
if retry_count >= max_retries:
logger.warning(f"Timeout for text {i} after {max_retries} retries")
failed_indices.append(i)
all_embeddings.append(None)
break

except Exception as e:
idx = future_to_idx[future]
logger.error(f"Exception for text {idx}: {e}")
failed_indices.append(idx)
retry_count += 1
if retry_count >= max_retries:
logger.error(f"Failed to get embedding for text {i}: {e}")
failed_indices.append(i)
all_embeddings.append(None)
break
return all_embeddings, failed_indices

# Process texts in batches
all_embeddings = []
all_failed_indices = []

# Setup progress bar if needed
show_progress = is_build or len(texts) > 10
try:
if show_progress:
from tqdm import tqdm
except ImportError:
show_progress = False

# Process batches
num_batches = (len(texts) + batch_size - 1) // batch_size

if show_progress:
batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings")
else:
# Sequential processing with progress bar
show_progress = is_build or len(texts) > 10
batch_iterator = range(num_batches)

try:
if show_progress:
from tqdm import tqdm
for batch_idx in batch_iterator:
start_idx = batch_idx * batch_size
end_idx = min(start_idx + batch_size, len(texts))
batch_texts = texts[start_idx:end_idx]

iterator = tqdm(
enumerate(texts), total=len(texts), desc="Computing Ollama embeddings"
)
else:
iterator = enumerate(texts)
except ImportError:
iterator = enumerate(texts)
batch_embeddings, batch_failed = get_batch_embeddings(batch_texts)

for idx, text in iterator:
result_idx, embedding = get_single_embedding((text, idx))
if embedding is not None:
all_embeddings[idx] = embedding
else:
failed_indices.append(idx)
# Adjust failed indices to global indices
global_failed = [start_idx + idx for idx in batch_failed]
all_failed_indices.extend(global_failed)
all_embeddings.extend(batch_embeddings)

# Handle failed embeddings
if failed_indices:
if len(failed_indices) == len(texts):
if all_failed_indices:
if len(all_failed_indices) == len(texts):
raise RuntimeError("Failed to compute any embeddings")

logger.warning(f"Failed to compute embeddings for {len(failed_indices)}/{len(texts)} texts")
logger.warning(
f"Failed to compute embeddings for {len(all_failed_indices)}/{len(texts)} texts"
)

# Use zero embeddings as fallback for failed ones
valid_embedding = next((e for e in all_embeddings if e is not None), None)
if valid_embedding:
embedding_dim = len(valid_embedding)
for idx in failed_indices:
all_embeddings[idx] = [0.0] * embedding_dim
for i, embedding in enumerate(all_embeddings):
if embedding is None:
all_embeddings[i] = [0.0] * embedding_dim

# Remove None values and convert to numpy array
# Remove None values
all_embeddings = [e for e in all_embeddings if e is not None]

if not all_embeddings:
raise RuntimeError("No valid embeddings were computed")

# Validate embedding dimensions
expected_dim = len(all_embeddings[0])
inconsistent_dims = []
for i, embedding in enumerate(all_embeddings):
if len(embedding) != expected_dim:
inconsistent_dims.append((i, len(embedding)))

if inconsistent_dims:
error_msg = f"Ollama returned inconsistent embedding dimensions. Expected {expected_dim}, but got:\n"
for idx, dim in inconsistent_dims[:10]:  # Show first 10 inconsistent ones
error_msg += f"  - Text {idx}: {dim} dimensions\n"
if len(inconsistent_dims) > 10:
error_msg += f"  ... and {len(inconsistent_dims) - 10} more\n"
error_msg += f"\nThis is likely an Ollama API bug with model '{model_name}'. Please try:\n"
error_msg += "1. Restart Ollama service: 'ollama serve'\n"
error_msg += f"2. Re-pull the model: 'ollama pull {model_name}'\n"
error_msg += (
"3. Use sentence-transformers instead: --embedding-mode sentence-transformers\n"
)
error_msg += "4. Report this issue to Ollama: https://github.com/ollama/ollama/issues"
raise ValueError(error_msg)

# Convert to numpy array and normalize
embeddings = np.array(all_embeddings, dtype=np.float32)
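Assuming the signature shown in the hunk above, a call might look like this (model name and inputs are examples only; a local Ollama server with an embedding model pulled is required):

```python
import numpy as np

# Hypothetical call exercising the batched path above.
embeddings = compute_embeddings_ollama(
    ["first chunk", "second chunk"],
    model_name="nomic-embed-text",
    is_build=True,  # enables the build-time progress bar
)
assert isinstance(embeddings, np.ndarray) and embeddings.dtype == np.float32
```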
@@ -1,7 +1,6 @@
import atexit
import logging
import os
import signal
import socket
import subprocess
import sys
@@ -312,7 +311,6 @@ class EmbeddingServerManager:
cwd=project_root,
stdout=None,  # Direct to console
stderr=None,  # Direct to console
start_new_session=True,  # Create new process group for better cleanup
)
self.server_port = port
logger.info(f"Server process started with PID: {self.server_process.pid}")
@@ -354,14 +352,7 @@ class EmbeddingServerManager:
logger.info(
f"Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
)

# Try terminating the whole process group first
try:
pgid = os.getpgid(self.server_process.pid)
os.killpg(pgid, signal.SIGTERM)
except Exception:
# Fallback to terminating just the process
self.server_process.terminate()
self.server_process.terminate()

try:
self.server_process.wait(timeout=3)
@@ -370,13 +361,7 @@ class EmbeddingServerManager:
logger.warning(
f"Server process {self.server_process.pid} did not terminate gracefully within 3 seconds, killing it."
)
# Try killing the whole process group
try:
pgid = os.getpgid(self.server_process.pid)
os.killpg(pgid, signal.SIGKILL)
except Exception:
# Fallback to killing just the process
self.server_process.kill()
self.server_process.kill()
try:
self.server_process.wait(timeout=2)
logger.info(f"Server process {self.server_process.pid} killed successfully.")
@@ -388,12 +373,7 @@ class EmbeddingServerManager:

# Clean up process resources to prevent resource tracker warnings
try:
self.server_process.wait(timeout=1)  # Give it one final chance with timeout
except subprocess.TimeoutExpired:
logger.warning(
f"Process {self.server_process.pid} still hanging after all kill attempts"
)
# Don't wait indefinitely - just abandon it
self.server_process.wait()  # Ensure process is fully cleaned up
except Exception:
pass
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Literal, Optional
from typing import Any, Literal, Union

import numpy as np

@@ -35,7 +35,7 @@ class LeannBackendSearcherInterface(ABC):

@abstractmethod
def _ensure_server_running(
self, passages_source_file: str, port: Optional[int], **kwargs
self, passages_source_file: str, port: Union[int, None], **kwargs
) -> int:
"""Ensure server is running"""
pass
@@ -50,7 +50,7 @@ class LeannBackendSearcherInterface(ABC):
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: Optional[int] = None,
zmq_port: Union[int, None] = None,
**kwargs,
) -> dict[str, Any]:
"""Search for nearest neighbors
@@ -76,7 +76,7 @@ class LeannBackendSearcherInterface(ABC):
self,
query: str,
use_server_if_available: bool = True,
zmq_port: Optional[int] = None,
zmq_port: Union[int, None] = None,
) -> np.ndarray:
"""Compute embedding for a query string
@@ -25,32 +25,61 @@ def handle_request(request):
"tools": [
{
"name": "leann_search",
"description": "Search LEANN index",
"description": """🔍 Search code using natural language - like having a coding assistant who knows your entire codebase!

🎯 **Perfect for**:
- "How does authentication work?" → finds auth-related code
- "Error handling patterns" → locates try-catch blocks and error logic
- "Database connection setup" → finds DB initialization code
- "API endpoint definitions" → locates route handlers
- "Configuration management" → finds config files and usage

💡 **Pro tip**: Use this before making any changes to understand existing patterns and conventions.""",
"inputSchema": {
"type": "object",
"properties": {
"index_name": {"type": "string"},
"query": {"type": "string"},
"top_k": {"type": "integer", "default": 5},
"index_name": {
"type": "string",
"description": "Name of the LEANN index to search. Use 'leann_list' first to see available indexes.",
},
"query": {
"type": "string",
"description": "Search query - can be natural language (e.g., 'how to handle errors') or technical terms (e.g., 'async function definition')",
},
"top_k": {
"type": "integer",
"default": 5,
"minimum": 1,
"maximum": 20,
"description": "Number of search results to return. Use 5-10 for focused results, 15-20 for comprehensive exploration.",
},
"complexity": {
"type": "integer",
"default": 32,
"minimum": 16,
"maximum": 128,
"description": "Search complexity level. Use 16-32 for fast searches (recommended), 64+ for higher precision when needed.",
},
},
"required": ["index_name", "query"],
},
},
{
"name": "leann_ask",
"description": "Ask question using LEANN RAG",
"name": "leann_status",
"description": "📊 Check the health and stats of your code indexes - like a medical checkup for your codebase knowledge!",
"inputSchema": {
"type": "object",
"properties": {
"index_name": {"type": "string"},
"question": {"type": "string"},
"index_name": {
"type": "string",
"description": "Optional: Name of specific index to check. If not provided, shows status of all indexes.",
}
},
"required": ["index_name", "question"],
},
},
{
"name": "leann_list",
"description": "List all LEANN indexes",
"description": "📋 Show all your indexed codebases - your personal code library! Use this to see what's available for search.",
"inputSchema": {"type": "object", "properties": {}},
},
]
@@ -63,19 +92,41 @@ def handle_request(request):

try:
if tool_name == "leann_search":
# Validate required parameters
if not args.get("index_name") or not args.get("query"):
return {
"jsonrpc": "2.0",
"id": request.get("id"),
"result": {
"content": [
{
"type": "text",
"text": "Error: Both index_name and query are required",
}
]
},
}

# Build simplified command
cmd = [
"leann",
"search",
args["index_name"],
args["query"],
"--recompute-embeddings",
f"--top-k={args.get('top_k', 5)}",
f"--complexity={args.get('complexity', 32)}",
]

result = subprocess.run(cmd, capture_output=True, text=True)

elif tool_name == "leann_ask":
cmd = f'echo "{args["question"]}" | leann ask {args["index_name"]} --recompute-embeddings --llm ollama --model qwen3:8b'
result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
elif tool_name == "leann_status":
if args.get("index_name"):
# Check specific index status - for now, we'll use leann list and filter
result = subprocess.run(["leann", "list"], capture_output=True, text=True)
# We could enhance this to show more detailed status per index
else:
# Show all indexes status
result = subprocess.run(["leann", "list"], capture_output=True, text=True)

elif tool_name == "leann_list":
result = subprocess.run(["leann", "list"], capture_output=True, text=True)
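For reference, a request that exercises the `leann_search` branch above might be shaped like this; the `tools/call` envelope is a hand-rolled sketch of an MCP message, not taken from the repo:

```python
request = {
    "jsonrpc": "2.0",
    "id": 1,
    "method": "tools/call",
    "params": {
        "name": "leann_search",
        "arguments": {"index_name": "my-docs", "query": "error handling patterns", "top_k": 5},
    },
}
response = handle_request(request)  # dispatches to `leann search my-docs ...`
```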
@@ -132,15 +132,10 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
import msgpack
import zmq

context = None
socket = None
try:
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.setsockopt(zmq.LINGER, 0)  # Don't block on close
socket.setsockopt(zmq.RCVTIMEO, 300000)
socket.setsockopt(zmq.SNDTIMEO, 300000)
socket.setsockopt(zmq.IMMEDIATE, 1)
socket.setsockopt(zmq.RCVTIMEO, 30000)  # 30 second timeout
socket.connect(f"tcp://localhost:{zmq_port}")

# Send embedding request
@@ -152,6 +147,9 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
response_bytes = socket.recv()
response = msgpack.unpackb(response_bytes)

socket.close()
context.term()

# Convert response to numpy array
if isinstance(response, list) and len(response) > 0:
return np.array(response, dtype=np.float32)
@@ -160,10 +158,6 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):

except Exception as e:
raise RuntimeError(f"Failed to compute embeddings via server: {e}")
finally:
if socket:
socket.close()
# Don't call context.term() - this was causing hangs

@abstractmethod
def search(
@@ -197,15 +191,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
"""
pass

def cleanup(self):
"""Cleanup resources including embedding server."""
def __del__(self):
"""Ensures the embedding server is stopped when the searcher is destroyed."""
if hasattr(self, "embedding_server_manager"):
self.embedding_server_manager.stop_server()

def __del__(self):
"""Ensures resources are cleaned up when the searcher is destroyed."""
try:
self.cleanup()
except Exception:
# Ignore errors during destruction
pass
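Given the explicit `cleanup()` plus best-effort `__del__` pattern above, callers can release the embedding server deterministically instead of waiting for garbage collection. A small sketch (the index path is illustrative):

```python
searcher = LeannSearcher("./indexes/my-docs/my-docs.leann")  # illustrative path
try:
    results = searcher.search("authentication flow", top_k=5)
finally:
    searcher.cleanup()  # stop the embedding server immediately
```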
@@ -45,6 +45,42 @@ leann build my-project --docs ./
claude
```

## 🚀 Advanced Usage Examples

### Index Entire Git Repository
```bash
# Index all tracked files in your git repository (git submodules are currently skipped)
leann build my-repo --docs $(git ls-files) --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw

# Index only specific file types from git
leann build my-python-code --docs $(git ls-files "*.py") --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
```

### Multiple Directories and Files
```bash
# Index multiple directories
leann build my-codebase --docs ./src ./tests ./docs ./config --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw

# Mix files and directories
leann build my-project --docs ./README.md ./src/ ./package.json ./docs/ --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw

# Specific files only
leann build my-configs --docs ./tsconfig.json ./package.json ./webpack.config.js --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
```

### Advanced Git Integration
```bash
# Index recently modified files
leann build recent-changes --docs $(git diff --name-only HEAD~10..HEAD) --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw

# Index files matching pattern
leann build frontend --docs $(git ls-files "*.tsx" "*.ts" "*.jsx" "*.js") --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw

# Index documentation and config files
leann build docs-and-configs --docs $(git ls-files "*.md" "*.yml" "*.yaml" "*.json" "*.toml") --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
```

**Try this in Claude Code:**
```
Help me understand this codebase. List available indexes and search for authentication patterns.
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "leann"
version = "0.2.5"
version = "0.2.8"
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
readme = "README.md"
requires-python = ">=3.9"

@@ -40,10 +40,12 @@ dependencies = [
# Other dependencies
"ipykernel==6.29.5",
"msgpack>=1.1.1",
"mlx>=0.26.3; sys_platform == 'darwin'",
"mlx-lm>=0.26.0; sys_platform == 'darwin'",
"mlx>=0.26.3; sys_platform == 'darwin' and platform_machine == 'arm64'",
"mlx-lm>=0.26.0; sys_platform == 'darwin' and platform_machine == 'arm64'",
"psutil>=5.8.0",
"pybind11>=3.0.0",
"pathspec>=0.12.1",
"nbconvert>=7.16.6",
"gitignore-parser>=0.1.12",
]

[project.optional-dependencies]
@@ -52,7 +54,7 @@ dev = [
"pytest-cov>=4.0",
"pytest-xdist>=3.0",  # For parallel test execution
"black>=23.0",
"ruff==0.12.7",  # Fixed version to ensure consistent formatting across all environments
"ruff>=0.1.0",
"matplotlib",
"huggingface-hub>=0.20.0",
"pre-commit>=3.5.0",
@@ -60,7 +62,7 @@ dev = [

test = [
"pytest>=7.0",
"pytest-timeout>=2.0",  # Simple timeout protection for CI
"pytest-timeout>=2.0",
"llama-index-core>=0.12.0",
"llama-index-readers-file>=0.4.0",
"python-dotenv>=1.0.0",
@@ -152,7 +154,7 @@ markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
"openai: marks tests that require OpenAI API key",
]
timeout = 300  # Simple timeout for CI safety (5 minutes)
timeout = 600
addopts = [
"-v",
"--tb=short",
@@ -6,11 +6,10 @@ This directory contains automated tests for the LEANN project using pytest.

### `test_readme_examples.py`
Tests the examples shown in README.md:
- The basic example code that users see first (parametrized for both HNSW and DiskANN backends)
- The basic example code that users see first
- Import statements work correctly
- Different backend options (HNSW, DiskANN)
- Different LLM configuration options (parametrized for both backends)
- **All main README examples are tested with both HNSW and DiskANN backends using pytest parametrization**
- Different LLM configuration options

### `test_basic.py`
Basic functionality tests that verify:
@@ -26,16 +25,6 @@ Tests the document RAG example functionality:
- Tests error handling with invalid parameters
- Verifies that normalized embeddings are detected and cosine distance is used

### `test_diskann_partition.py`
Tests DiskANN graph partitioning functionality:
- Tests DiskANN index building without partitioning (baseline)
- Tests automatic graph partitioning with `is_recompute=True`
- Verifies that partition files are created and large files are cleaned up for storage saving
- Tests search functionality with partitioned indices
- Validates medoid and max_base_norm file generation and usage
- Includes performance comparison between DiskANN (with partition) and HNSW
- **Note**: These tests are skipped in CI due to hardware requirements and computation time

## Running Tests

### Install test dependencies:
@@ -65,23 +54,15 @@ pytest tests/ -m "not openai"

# Skip slow tests
pytest tests/ -m "not slow"

# Run DiskANN partition tests (requires local machine, not CI)
pytest tests/test_diskann_partition.py
```

### Run with specific backend:
```bash
# Test only HNSW backend
pytest tests/test_basic.py::test_backend_basic[hnsw]
pytest tests/test_readme_examples.py::test_readme_basic_example[hnsw]

# Test only DiskANN backend
pytest tests/test_basic.py::test_backend_basic[diskann]
pytest tests/test_readme_examples.py::test_readme_basic_example[diskann]

# All DiskANN tests (parametrized + specialized partition tests)
pytest tests/ -k diskann
```

## CI/CD Integration
@@ -1,41 +0,0 @@
"""Pytest configuration and fixtures for LEANN tests."""

import os

import pytest

@pytest.fixture(autouse=True)
def test_environment():
"""Set up test environment variables."""
# Mark as test environment to skip memory-intensive operations
os.environ["CI"] = "true"
yield

@pytest.fixture(scope="session", autouse=True)
def cleanup_session():
"""Session-level cleanup to ensure no hanging processes."""
yield

# Basic cleanup after all tests
try:
import os

import psutil

current_process = psutil.Process(os.getpid())
children = current_process.children(recursive=True)

for child in children:
try:
child.terminate()
except psutil.NoSuchProcess:
pass

# Give them time to terminate gracefully
psutil.wait_procs(children, timeout=3)

except Exception:
# Don't fail tests due to cleanup errors
pass
@@ -1,369 +0,0 @@
|
||||
"""
|
||||
Test DiskANN graph partitioning functionality.
|
||||
|
||||
Tests the automatic graph partitioning feature that was implemented to save
|
||||
storage space by partitioning large DiskANN indices and safely deleting
|
||||
redundant files while maintaining search functionality.
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true",
|
||||
reason="Skip DiskANN partition tests in CI - requires specific hardware and large memory",
|
||||
)
|
||||
def test_diskann_without_partition():
|
||||
"""Test DiskANN index building without partition (baseline)."""
|
||||
from leann.api import LeannBuilder, LeannSearcher
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
index_path = str(Path(temp_dir) / "test_no_partition.leann")
|
||||
|
||||
# Test data - enough to trigger index building
|
||||
texts = [
|
||||
f"Document {i} discusses topic {i % 10} with detailed analysis of subject {i // 10}."
|
||||
for i in range(500)
|
||||
]
|
||||
|
||||
# Build without partition (is_recompute=False)
|
||||
builder = LeannBuilder(
|
||||
backend_name="diskann",
|
||||
embedding_model="facebook/contriever",
|
||||
embedding_mode="sentence-transformers",
|
||||
num_neighbors=32,
|
||||
search_list_size=50,
|
||||
is_recompute=False, # No partition
|
||||
)
|
||||
|
||||
for text in texts:
|
||||
builder.add_text(text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
|
||||
# Verify index was created
|
||||
index_dir = Path(index_path).parent
|
||||
assert index_dir.exists()
|
||||
|
||||
# Check that traditional DiskANN files exist
|
||||
index_prefix = Path(index_path).stem
|
||||
# Core DiskANN files (beam search index may not be created for small datasets)
|
||||
required_files = [
|
||||
f"{index_prefix}_disk.index",
|
||||
f"{index_prefix}_pq_compressed.bin",
|
||||
f"{index_prefix}_pq_pivots.bin",
|
||||
]
|
||||
|
||||
# Check all generated files first for debugging
|
||||
generated_files = [f.name for f in index_dir.glob(f"{index_prefix}*")]
|
||||
print(f"Generated files: {generated_files}")
|
||||
|
||||
for required_file in required_files:
|
||||
file_path = index_dir / required_file
|
||||
assert file_path.exists(), f"Required file {required_file} not found"
|
||||
|
||||
# Ensure no partition files exist in non-partition mode
|
||||
partition_files = [f"{index_prefix}_disk_graph.index", f"{index_prefix}_partition.bin"]
|
||||
|
||||
for partition_file in partition_files:
|
||||
file_path = index_dir / partition_file
|
||||
assert not file_path.exists(), (
|
||||
f"Partition file {partition_file} should not exist in non-partition mode"
|
||||
)
|
||||
|
||||
# Test search functionality
|
||||
searcher = LeannSearcher(index_path)
|
||||
results = searcher.search("topic 3 analysis", top_k=3)
|
||||
|
||||
assert len(results) > 0
|
||||
assert all(result.score is not None and result.score != float("-inf") for result in results)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true",
|
||||
reason="Skip DiskANN partition tests in CI - requires specific hardware and large memory",
|
||||
)
|
||||
def test_diskann_with_partition():
|
||||
"""Test DiskANN index building with automatic graph partitioning."""
|
||||
from leann.api import LeannBuilder
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
index_path = str(Path(temp_dir) / "test_with_partition.leann")
|
||||
|
||||
# Test data - enough to trigger partitioning
|
||||
texts = [
|
||||
f"Document {i} explores subject {i % 15} with comprehensive coverage of area {i // 15}."
|
||||
for i in range(500)
|
||||
]
|
||||
|
||||
# Build with partition (is_recompute=True)
|
||||
builder = LeannBuilder(
|
||||
backend_name="diskann",
|
||||
embedding_model="facebook/contriever",
|
||||
embedding_mode="sentence-transformers",
|
||||
num_neighbors=32,
|
||||
search_list_size=50,
|
||||
is_recompute=True, # Enable automatic partitioning
|
||||
)
|
||||
|
||||
for text in texts:
|
||||
builder.add_text(text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
|
||||
# Verify index was created
|
||||
index_dir = Path(index_path).parent
|
||||
assert index_dir.exists()
|
||||
|
||||
# Check that partition files exist
|
||||
index_prefix = Path(index_path).stem
|
||||
partition_files = [
|
||||
f"{index_prefix}_disk_graph.index", # Partitioned graph
|
||||
f"{index_prefix}_partition.bin", # Partition metadata
|
||||
f"{index_prefix}_pq_compressed.bin",
|
||||
f"{index_prefix}_pq_pivots.bin",
|
||||
]
|
||||
|
||||
for partition_file in partition_files:
|
||||
file_path = index_dir / partition_file
|
||||
assert file_path.exists(), f"Expected partition file {partition_file} not found"
|
||||
|
||||
# Check that large files were cleaned up (storage saving goal)
|
||||
large_files = [f"{index_prefix}_disk.index", f"{index_prefix}_disk_beam_search.index"]
|
||||
|
||||
for large_file in large_files:
|
||||
file_path = index_dir / large_file
|
||||
assert not file_path.exists(), (
|
||||
f"Large file {large_file} should have been deleted for storage saving"
|
||||
)
|
||||
|
||||
# Verify required auxiliary files for partition mode exist
|
||||
required_files = [
|
||||
f"{index_prefix}_disk.index_medoids.bin",
|
||||
f"{index_prefix}_disk.index_max_base_norm.bin",
|
||||
]
|
||||
|
||||
for req_file in required_files:
|
||||
file_path = index_dir / req_file
|
||||
assert file_path.exists(), (
|
||||
f"Required auxiliary file {req_file} missing for partition mode"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    os.environ.get("CI") == "true",
    reason="Skip DiskANN partition tests in CI - requires specific hardware and large memory",
)
def test_diskann_partition_search_functionality():
    """Test that search works correctly with partitioned indices."""
    from leann.api import LeannBuilder, LeannSearcher

    with tempfile.TemporaryDirectory() as temp_dir:
        index_path = str(Path(temp_dir) / "test_partition_search.leann")

        # Create diverse test data
        texts = [
            "LEANN is a storage-efficient approximate nearest neighbor search system.",
            "Graph partitioning helps reduce memory usage in large scale vector search.",
            "DiskANN provides high-performance disk-based approximate nearest neighbor search.",
            "Vector embeddings enable semantic search over unstructured text data.",
            "Approximate nearest neighbor algorithms trade accuracy for speed and storage.",
        ] * 100  # Repeat to get enough data

        # Build with partitioning
        builder = LeannBuilder(
            backend_name="diskann",
            embedding_model="facebook/contriever",
            embedding_mode="sentence-transformers",
            is_recompute=True,  # Enable partitioning
        )

        for text in texts:
            builder.add_text(text)

        builder.build_index(index_path)

        # Test search with partitioned index
        searcher = LeannSearcher(index_path)

        # Test various queries
        test_queries = [
            ("vector search algorithms", 5),
            ("LEANN storage efficiency", 3),
            ("graph partitioning memory", 4),
            ("approximate nearest neighbor", 7),
        ]

        for query, top_k in test_queries:
            results = searcher.search(query, top_k=top_k)

            # Verify search results
            assert len(results) == top_k, f"Expected {top_k} results for query '{query}'"
            assert all(result.score is not None for result in results), (
                "All results should have scores"
            )
            assert all(result.score != float("-inf") for result in results), (
                "No result should have -inf score"
            )
            assert all(result.text is not None for result in results), (
                "All results should have text"
            )

            # Scores should be in descending order (higher similarity first)
            scores = [result.score for result in results]
            assert scores == sorted(scores, reverse=True), (
                "Results should be sorted by score descending"
            )


@pytest.mark.skipif(
    os.environ.get("CI") == "true",
    reason="Skip DiskANN partition tests in CI - requires specific hardware and large memory",
)
def test_diskann_medoid_and_norm_files():
    """Test that medoid and max_base_norm files are correctly generated and used."""
    import struct

    from leann.api import LeannBuilder, LeannSearcher

    with tempfile.TemporaryDirectory() as temp_dir:
        index_path = str(Path(temp_dir) / "test_medoid_norm.leann")

        # Small but sufficient dataset
        texts = [f"Test document {i} with content about subject {i % 10}." for i in range(200)]

        builder = LeannBuilder(
            backend_name="diskann",
            embedding_model="facebook/contriever",
            embedding_mode="sentence-transformers",
            is_recompute=True,
        )

        for text in texts:
            builder.add_text(text)

        builder.build_index(index_path)

        index_dir = Path(index_path).parent
        index_prefix = Path(index_path).stem

        # Test medoids file
        medoids_file = index_dir / f"{index_prefix}_disk.index_medoids.bin"
        assert medoids_file.exists(), "Medoids file should be generated"

        # Read and validate medoids file format:
        # three little-endian uint32 values (nshards, a constant 1, medoid_id)
        with open(medoids_file, "rb") as f:
            nshards = struct.unpack("<I", f.read(4))[0]
            one_val = struct.unpack("<I", f.read(4))[0]
            medoid_id = struct.unpack("<I", f.read(4))[0]

        assert nshards == 1, "Single-shot build should have 1 shard"
        assert one_val == 1, "Expected value should be 1"
        # medoid_id is decoded as an unsigned int, so this check documents intent
        # rather than catching negative values
        assert medoid_id >= 0, "Medoid ID should be valid (not hardcoded 0)"

        # Test max_base_norm file
        norm_file = index_dir / f"{index_prefix}_disk.index_max_base_norm.bin"
        assert norm_file.exists(), "Max base norm file should be generated"

        # Read and validate norm file: two uint32 header values, then one float32
        with open(norm_file, "rb") as f:
            npts = struct.unpack("<I", f.read(4))[0]
            ndims = struct.unpack("<I", f.read(4))[0]
            norm_val = struct.unpack("<f", f.read(4))[0]

        assert npts == 1, "Should have 1 norm point"
        assert ndims == 1, "Should have 1 dimension"
        assert norm_val > 0, "Norm value should be positive"
        assert norm_val != float("inf"), "Norm value should be finite"

        # Test that search works with these files
        searcher = LeannSearcher(index_path)
        results = searcher.search("test subject", top_k=3)

        # Verify that scores are not -inf (which indicates norm file was loaded correctly)
        assert len(results) > 0
        assert all(result.score != float("-inf") for result in results), (
            "Scores should not be -inf when norm file is correct"
        )


@pytest.mark.skipif(
    os.environ.get("CI") == "true",
    reason="Skip performance comparison in CI - requires significant compute time",
)
def test_diskann_vs_hnsw_performance():
    """Compare DiskANN (with partition) vs HNSW performance."""
    import time

    from leann.api import LeannBuilder, LeannSearcher

    with tempfile.TemporaryDirectory() as temp_dir:
        # Test data
        texts = [
            f"Performance test document {i} covering topic {i % 20} in detail." for i in range(1000)
        ]
        query = "performance topic test"

        # Test DiskANN with partitioning
        diskann_path = str(Path(temp_dir) / "perf_diskann.leann")
        diskann_builder = LeannBuilder(
            backend_name="diskann",
            embedding_model="facebook/contriever",
            embedding_mode="sentence-transformers",
            is_recompute=True,
        )

        for text in texts:
            diskann_builder.add_text(text)

        # Time the DiskANN build
        start_time = time.time()
        diskann_builder.build_index(diskann_path)
        diskann_build_time = time.time() - start_time

        # Test HNSW
        hnsw_path = str(Path(temp_dir) / "perf_hnsw.leann")
        hnsw_builder = LeannBuilder(
            backend_name="hnsw",
            embedding_model="facebook/contriever",
            embedding_mode="sentence-transformers",
            is_recompute=True,
        )

        for text in texts:
            hnsw_builder.add_text(text)

        # Time the HNSW build
        start_time = time.time()
        hnsw_builder.build_index(hnsw_path)
        hnsw_build_time = time.time() - start_time

        # Compare search performance
        diskann_searcher = LeannSearcher(diskann_path)
        hnsw_searcher = LeannSearcher(hnsw_path)

        # Warm up searches
        diskann_searcher.search(query, top_k=5)
        hnsw_searcher.search(query, top_k=5)

        # Timed searches
        start_time = time.time()
        diskann_results = diskann_searcher.search(query, top_k=10)
        diskann_search_time = time.time() - start_time

        start_time = time.time()
        hnsw_results = hnsw_searcher.search(query, top_k=10)
        hnsw_search_time = time.time() - start_time

        # Basic assertions
        assert len(diskann_results) == 10
        assert len(hnsw_results) == 10
        assert all(r.score != float("-inf") for r in diskann_results)
        assert all(r.score != float("-inf") for r in hnsw_results)

        # Performance ratio (informational; guard against zero timings)
        print(f"DiskANN build time: {diskann_build_time:.2f}s")
        print(f"HNSW build time: {hnsw_build_time:.2f}s")
        if diskann_search_time > 0 and hnsw_search_time > 0:
            speed_ratio = hnsw_search_time / diskann_search_time
            print(f"DiskANN search time: {diskann_search_time:.4f}s")
            print(f"HNSW search time: {hnsw_search_time:.4f}s")
            print(f"Search speed ratio (HNSW/DiskANN): {speed_ratio:.2f}x")

@@ -10,9 +10,8 @@ from pathlib import Path
 import pytest
 
 
-@pytest.mark.parametrize("backend_name", ["hnsw", "diskann"])
-def test_readme_basic_example(backend_name):
-    """Test the basic example from README.md with both backends."""
+def test_readme_basic_example():
+    """Test the basic example from README.md."""
     # Skip on macOS CI due to MPS environment issues with all-MiniLM-L6-v2
     if os.environ.get("CI") == "true" and platform.system() == "Darwin":
         pytest.skip("Skipping on macOS CI due to MPS environment issues with all-MiniLM-L6-v2")
@@ -22,18 +21,18 @@ def test_readme_basic_example(backend_name):
     from leann.api import SearchResult
 
     with tempfile.TemporaryDirectory() as temp_dir:
-        INDEX_PATH = str(Path(temp_dir) / f"demo_{backend_name}.leann")
+        INDEX_PATH = str(Path(temp_dir) / "demo.leann")
 
         # Build an index
         # In CI, use a smaller model to avoid memory issues
         if os.environ.get("CI") == "true":
             builder = LeannBuilder(
-                backend_name=backend_name,
+                backend_name="hnsw",
                 embedding_model="sentence-transformers/all-MiniLM-L6-v2",  # Smaller model
                 dimensions=384,  # Smaller dimensions
             )
         else:
-            builder = LeannBuilder(backend_name=backend_name)
+            builder = LeannBuilder(backend_name="hnsw")
         builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
         builder.add_text("Tung Tung Tung Sahur called—they need their banana-crocodile hybrid back")
         builder.build_index(INDEX_PATH)
@@ -53,9 +52,6 @@ def test_readme_basic_example(backend_name):
         # Verify search results
         assert len(results) > 0
         assert isinstance(results[0], SearchResult)
-        assert results[0].score != float("-inf"), (
-            f"should return valid scores, got {results[0].score}"
-        )
         # The second text about banana-crocodile should be more relevant
         assert "banana" in results[0].text or "crocodile" in results[0].text
 
@@ -114,31 +110,26 @@ def test_backend_options():
     assert len(list(Path(diskann_path).parent.glob(f"{Path(diskann_path).stem}.*"))) > 0
 
 
-@pytest.mark.parametrize("backend_name", ["hnsw", "diskann"])
-def test_llm_config_simulated(backend_name):
-    """Test simulated LLM configuration option with both backends."""
+def test_llm_config_simulated():
+    """Test simulated LLM configuration option."""
     # Skip on macOS CI due to MPS environment issues with all-MiniLM-L6-v2
     if os.environ.get("CI") == "true" and platform.system() == "Darwin":
         pytest.skip("Skipping on macOS CI due to MPS environment issues with all-MiniLM-L6-v2")
 
-    # Skip DiskANN tests in CI due to hardware requirements
-    if os.environ.get("CI") == "true" and backend_name == "diskann":
-        pytest.skip("Skip DiskANN tests in CI - requires specific hardware and large memory")
-
     from leann import LeannBuilder, LeannChat
 
     with tempfile.TemporaryDirectory() as temp_dir:
         # Build a simple index
-        index_path = str(Path(temp_dir) / f"test_{backend_name}.leann")
+        index_path = str(Path(temp_dir) / "test.leann")
         # Use smaller model in CI to avoid memory issues
         if os.environ.get("CI") == "true":
             builder = LeannBuilder(
-                backend_name=backend_name,
+                backend_name="hnsw",
                 embedding_model="sentence-transformers/all-MiniLM-L6-v2",
                 dimensions=384,
             )
         else:
-            builder = LeannBuilder(backend_name=backend_name)
+            builder = LeannBuilder(backend_name="hnsw")
         builder.add_text("Test document for LLM testing")
         builder.build_index(index_path)
 