Compare commits

Branches compared: fix/clean-… → refactor-a… (38 commits)

Commit SHAs:
0877960547, d68af63d05, b844aca968, 85277ba67a, e9562acdc2, 7fd3db1ddb, c1ccc51a75, b0239b6e4d, 58556ef44c, 87c930d705, 86f919a6da, f8d34663b4, 568cf597f4, baf70dc411, 7ad2ec39d6, 31fd3c816a, 1f6c7f2f5a, c1124eb349, 274bbb19ea, 8c152c7a31, ce77eef13a, 9d77175ac8, 7fbb6c98ef, 914a248c28, 55fc5862f9, fd97b8dfa8, 57959947a1, cc0c091ca5, ff389c7d8d, 6780a8eaba, 984056f126, bd4451bf50, 34e313f64a, ddc789b231, ff1b622bdd, 3cde4fc7b3, 4e3bcda5fa, 46f6f76fc3
.github/workflows/build-reusable.yml (vendored) — 56 changed lines

@@ -28,7 +28,7 @@ jobs:
      - name: Install ruff
        run: |
-         uv tool install ruff==0.12.7
+         uv tool install ruff

      - name: Run ruff check
        run: |

@@ -111,10 +111,12 @@ jobs:
      - name: Build packages
        run: |
-         # Build core (platform independent) on all platforms for consistency
-         cd packages/leann-core
-         uv build
-         cd ../..
+         # Build core (platform independent)
+         if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
+           cd packages/leann-core
+           uv build
+           cd ../..
+         fi

          # Build HNSW backend
          cd packages/leann-backend-hnsw

@@ -135,7 +137,7 @@ jobs:
          # Use system clang instead of homebrew LLVM for better compatibility
          export CC=clang
          export CXX=clang++
-         # sgesdd_ is only available on macOS 13.3+
+         # DiskANN requires macOS 13.3+ for sgesdd_ LAPACK function
          export MACOSX_DEPLOYMENT_TARGET=13.3
          uv build --wheel --python python
        else

@@ -143,10 +145,12 @@ jobs:
          fi
          cd ../..

-         # Build meta package (platform independent) on all platforms
-         cd packages/leann
-         uv build
-         cd ../..
+         # Build meta package (platform independent)
+         if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
+           cd packages/leann
+           uv build
+           cd ../..
+         fi

      - name: Repair wheels (Linux)
        if: runner.os == 'Linux'

@@ -160,15 +164,10 @@ jobs:
          fi
          cd ../..

-         # Repair DiskANN wheel - use show first to debug
+         # Repair DiskANN wheel
          cd packages/leann-backend-diskann
          if [ -d dist ]; then
-           echo "Checking DiskANN wheel contents before repair:"
-           unzip -l dist/*.whl | grep -E "\.so|\.pyd|_diskannpy" || echo "No .so files found"
-           auditwheel show dist/*.whl || echo "auditwheel show failed"
            auditwheel repair dist/*.whl -w dist_repaired
-           echo "Checking DiskANN wheel contents after repair:"
-           unzip -l dist_repaired/*.whl | grep -E "\.so|\.pyd|_diskannpy" || echo "No .so files found after repair"
            rm -rf dist
            mv dist_repaired dist
          fi

@@ -202,27 +201,22 @@ jobs:
      - name: Install built packages for testing
        run: |
-         # Create a virtual environment with the correct Python version
-         uv venv --python python${{ matrix.python }}
+         # Create a virtual environment
+         uv venv
          source .venv/bin/activate || source .venv/Scripts/activate

-         # Install the built wheels directly to ensure we use locally built packages
-         # Use only locally built wheels on all platforms for full consistency
-         FIND_LINKS="--find-links packages/leann-core/dist --find-links packages/leann/dist"
-         FIND_LINKS="$FIND_LINKS --find-links packages/leann-backend-hnsw/dist --find-links packages/leann-backend-diskann/dist"
-
-         uv pip install leann-core leann leann-backend-hnsw leann-backend-diskann \
-           $FIND_LINKS --force-reinstall
+         # Install the built wheels
+         # Use --find-links to let uv choose the correct wheel for the platform
+         if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
+           uv pip install leann-core --find-links packages/leann-core/dist
+           uv pip install leann --find-links packages/leann/dist
+         fi
+         uv pip install leann-backend-hnsw --find-links packages/leann-backend-hnsw/dist
+         uv pip install leann-backend-diskann --find-links packages/leann-backend-diskann/dist

          # Install test dependencies using extras
          uv pip install -e ".[test]"

-         # Debug: Check if _diskannpy module is installed correctly
-         echo "Checking installed DiskANN module structure:"
-         python -c "import leann_backend_diskann; print('leann_backend_diskann location:', leann_backend_diskann.__file__)" || echo "Failed to import leann_backend_diskann"
-         python -c "from leann_backend_diskann import _diskannpy; print('_diskannpy imported successfully')" || echo "Failed to import _diskannpy"
-         ls -la $(python -c "import leann_backend_diskann; import os; print(os.path.dirname(leann_backend_diskann.__file__))" 2>/dev/null) 2>/dev/null || echo "Failed to list module directory"

      - name: Run tests with pytest
        env:
          CI: true # Mark as CI environment to skip memory-intensive tests
.gitignore (vendored) — 2 changed lines

@@ -38,7 +38,7 @@ data/*
!data/2501.14312v1 (1).pdf
!data/2506.08276v1.pdf
!data/PrideandPrejudice.txt
- !data/huawei_pangu.md
+ !data/README.md
!data/ground_truth/
!data/indices/
!data/queries/
(pre-commit configuration) — 2 changed lines

@@ -1,6 +1,6 @@
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
-   rev: v5.0.0
+   rev: v4.5.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer

@@ -10,7 +10,7 @@ repos:
      - id: debug-statements

  - repo: https://github.com/astral-sh/ruff-pre-commit
-   rev: v0.12.7 # Fixed version to match pyproject.toml
+   rev: v0.2.1
    hooks:
      - id: ruff
      - id: ruff-format
README.md — 65 changed lines

@@ -6,7 +6,6 @@
<img src="https://img.shields.io/badge/Python-3.9%2B-blue.svg" alt="Python 3.9+">
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
<img src="https://img.shields.io/badge/Platform-Linux%20%7C%20macOS-lightgrey" alt="Platform">
- <img src="https://img.shields.io/badge/MCP-Native%20Integration-blue?style=flat-square" alt="MCP Integration">
</p>

<h2 align="center" tabindex="-1" class="heading-element" dir="auto">

@@ -17,10 +16,7 @@ LEANN is an innovative vector database that democratizes personal AI. Transform

LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)

- **Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)**, **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
+ **Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)**, or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.

- \* Claude Code only supports basic `grep`-style keyword search. **LEANN** is a drop-in **semantic search MCP service fully compatible with Claude Code**, unlocking intelligent retrieval without changing your workflow. 🔥 Check out [the easy setup →](packages/leann-mcp/README.md)

@@ -30,7 +26,7 @@ LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*
<img src="assets/effects.png" alt="LEANN vs Traditional Vector DB Storage Comparison" width="70%">
</p>

- > **The numbers speak for themselves:** Index 60 million text chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#storage-comparison)
+ > **The numbers speak for themselves:** Index 60 million Wikipedia chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#storage-comparison)

🔒 **Privacy:** Your data never leaves your laptop. No OpenAI, no cloud, no "terms of service".

@@ -97,6 +93,7 @@ uv sync
</details>

## Quick Start

Our declarative API makes RAG as easy as writing a config file.

@@ -169,12 +166,10 @@ ollama pull llama3.2:1b
</details>

- ### ⭐ Flexible Configuration
+ ### Flexible Configuration

LEANN provides flexible parameters for embedding models, search strategies, and data processing to fit your specific needs.

- 📚 **Need configuration best practices?** Check our [Configuration Guide](docs/configuration-guide.md) for detailed optimization tips, model selection advice, and solutions to common issues like slow embeddings or poor search quality.

<details>
<summary><strong>📋 Click to expand: Common Parameters (Available in All Examples)</strong></summary>

@@ -188,13 +183,12 @@ All RAG examples share these common parameters. **Interactive mode** is available
--force-rebuild # Force rebuild index even if it exists

# Embedding Parameters
- --embedding-model MODEL # e.g., facebook/contriever, text-embedding-3-small, nomic-embed-text, or mlx-community/multilingual-e5-base-mlx
+ --embedding-model MODEL # e.g., facebook/contriever, text-embedding-3-small or mlx-community/multilingual-e5-base-mlx
- --embedding-mode MODE # sentence-transformers, openai, mlx, or ollama
+ --embedding-mode MODE # sentence-transformers, openai, or mlx

# LLM Parameters (Text generation models)
--llm TYPE # LLM backend: openai, ollama, or hf (default: openai)
--llm-model MODEL # Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct
- --thinking-budget LEVEL # Thinking budget for reasoning models: low/medium/high (supported by o3, o3-mini, GPT-Oss:20b, and other reasoning models)

# Search Parameters
--top-k N # Number of results to retrieve (default: 20)

@@ -222,7 +216,7 @@ Ask questions directly about your personal PDFs, documents, and any directory containing them
<img src="videos/paper_clear.gif" alt="LEANN Document Search Demo" width="600">
</p>

- The example below asks a question about summarizing our paper (uses default data in `data/`, which is a directory with diverse data sources: two papers, Pride and Prejudice, and a Technical report about LLM in Huawei in Chinese), and this is the **easiest example** to run here:
+ The example below asks a question about summarizing our paper (uses default data in `data/`, which is a directory with diverse data sources: two papers, Pride and Prejudice, and a README in Chinese) and this is the **easiest example** to run here:

```bash
source .venv/bin/activate # Don't forget to activate the virtual environment

@@ -417,26 +411,7 @@ Once the index is built, you can ask questions like:
</details>

- ### 🚀 Claude Code Integration: Transform Your Development Workflow!
-
- **The future of code assistance is here.** Transform your development workflow with LEANN's native MCP integration for Claude Code. Index your entire codebase and get intelligent code assistance directly in your IDE.
-
- **Key features:**
- - 🔍 **Semantic code search** across your entire project
- - 📚 **Context-aware assistance** for debugging and development
- - 🚀 **Zero-config setup** with automatic language detection
-
- ```bash
- # Install LEANN globally for MCP integration
- uv tool install leann-core
-
- # Setup is automatic - just start using Claude Code!
- ```
- Try our fully agentic pipeline with auto query rewriting, semantic search planning, and more:
-
- 
-
- **Ready to supercharge your coding?** [Complete Setup Guide →](packages/leann-mcp/README.md)

## 🖥️ Command Line Interface

@@ -450,24 +425,22 @@ source .venv/bin/activate
leann --help
```

- **To make it globally available:**
+ **To make it globally available (recommended for daily use):**
```bash
# Install the LEANN CLI globally using uv tool
- uv tool install leann-core
+ uv tool install leann

# Now you can use leann from anywhere without activating venv
leann --help
```

- > **Note**: Global installation is required for Claude Code integration. The `leann_mcp` server depends on the globally available `leann` command.

### Usage Examples

```bash
- # build from a specific directory, and my_docs is the index name
- leann build my-docs --docs ./your_documents
+ # Build an index from documents
+ leann build my-docs --docs ./documents

# Search your documents
leann search my-docs "machine learning concepts"

@@ -541,16 +514,12 @@ Options:
- **Dynamic batching:** Efficiently batch embedding computations for GPU utilization
- **Two-level search:** Smart graph traversal that prioritizes promising nodes

- **Backends:**
- - **HNSW** (default): Ideal for most datasets with maximum storage savings through full recomputation
- - **DiskANN**: Advanced option with superior search performance, using PQ-based graph traversal with real-time reranking for the best speed-accuracy trade-off
+ **Backends:** DiskANN or HNSW - pick what works for your data size.

## Benchmarks

- **[DiskANN vs HNSW Performance Comparison →](benchmarks/diskann_vs_hnsw_speed_comparison.py)** - Compare search performance between both backends
-
- **[Simple Example: Compare LEANN vs FAISS →](benchmarks/compare_faiss_vs_leann.py)** - See storage savings in action
+ **[Simple Example: Compare LEANN vs FAISS →](benchmarks/compare_faiss_vs_leann.py)**

### 📊 Storage Comparison

| System | DPR (2.1M) | Wiki (60M) | Chat (400K) | Email (780K) | Browser (38K) |

@@ -565,7 +534,8 @@ Options:
```bash
uv pip install -e ".[dev]" # Install dev dependencies
- python benchmarks/run_evaluation.py # Will auto-download evaluation data and run benchmarks
+ python benchmarks/run_evaluation.py data/indices/dpr/dpr_diskann # DPR dataset
+ python benchmarks/run_evaluation.py data/indices/rpj_wiki/rpj_wiki.index # Wikipedia
```

The evaluation script downloads data automatically on first run. The last three results were tested with partial personal data, and you can reproduce them with your own data!

@@ -603,11 +573,8 @@ MIT License - see [LICENSE](LICENSE) for details.
## 🙏 Acknowledgments

- Core Contributors: [Yichuan Wang](https://yichuan-w.github.io/) & [Zhifei Li](https://github.com/andylizf).
+ This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/).

- We welcome more contributors! Feel free to open issues or submit PRs.
-
- This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/).

---
@@ -75,7 +75,7 @@ class BaseRAGExample(ABC):
            "--embedding-mode",
            type=str,
            default="sentence-transformers",
-           choices=["sentence-transformers", "openai", "mlx", "ollama"],
+           choices=["sentence-transformers", "openai", "mlx"],
            help="Embedding backend mode (default: sentence-transformers)",
        )

@@ -85,7 +85,7 @@ class BaseRAGExample(ABC):
            "--llm",
            type=str,
            default="openai",
-           choices=["openai", "ollama", "hf", "simulated"],
+           choices=["openai", "ollama", "hf"],
            help="LLM backend to use (default: openai)",
        )
        llm_group.add_argument(

@@ -100,13 +100,6 @@ class BaseRAGExample(ABC):
            default="http://localhost:11434",
            help="Host for Ollama API (default: http://localhost:11434)",
        )
-       llm_group.add_argument(
-           "--thinking-budget",
-           type=str,
-           choices=["low", "medium", "high"],
-           default=None,
-           help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
-       )

        # Search parameters
        search_group = parser.add_argument_group("Search Parameters")

@@ -235,17 +228,7 @@ class BaseRAGExample(ABC):
                if not query:
                    continue

-               # Prepare LLM kwargs with thinking budget if specified
-               llm_kwargs = {}
-               if hasattr(args, "thinking_budget") and args.thinking_budget:
-                   llm_kwargs["thinking_budget"] = args.thinking_budget
-
-               response = chat.ask(
-                   query,
-                   top_k=args.top_k,
-                   complexity=args.search_complexity,
-                   llm_kwargs=llm_kwargs,
-               )
+               response = chat.ask(query, top_k=args.top_k, complexity=args.search_complexity)
                print(f"\nAssistant: {response}\n")

            except KeyboardInterrupt:

@@ -264,15 +247,7 @@ class BaseRAGExample(ABC):
        )

        print(f"\n[Query]: \033[36m{query}\033[0m")
-       # Prepare LLM kwargs with thinking budget if specified
-       llm_kwargs = {}
-       if hasattr(args, "thinking_budget") and args.thinking_budget:
-           llm_kwargs["thinking_budget"] = args.thinking_budget
-
-       response = chat.ask(
-           query, top_k=args.top_k, complexity=args.search_complexity, llm_kwargs=llm_kwargs
-       )
+       response = chat.ask(query, top_k=args.top_k, complexity=args.search_complexity)
        print(f"\n[Response]: \033[36m{response}\033[0m")

    async def run(self):
@@ -99,9 +99,7 @@ if __name__ == "__main__":
    print("- 'What are the main techniques LEANN uses?'")
    print("- 'What is the technique DLPM?'")
    print("- 'Who does Elizabeth Bennet marry?'")
-   print(
-       "- 'What is the problem of developing pan gu model Huawei meets? (盘古大模型开发中遇到什么问题?)'"
-   )
+   print("- 'What is the problem of developing pan gu model? (盘古大模型开发中遇到什么问题?)'")
    print("\nOr run without --query for interactive mode\n")

    rag = DocumentRAG()
Two binary image files removed (73 KiB and 224 KiB); binary contents not shown.
@@ -1,24 +1,9 @@
- # 🧪 LEANN Benchmarks & Testing
+ # 🧪 Leann Sanity Checks

- This directory contains performance benchmarks and comprehensive tests for the LEANN system, including backend comparisons and sanity checks across different configurations.
+ This directory contains comprehensive sanity checks for the Leann system, ensuring all components work correctly across different configurations.

## 📁 Test Files

- ### `diskann_vs_hnsw_speed_comparison.py`
- Performance comparison between DiskANN and HNSW backends:
- - ✅ **Search latency** comparison with both backends using recompute
- - ✅ **Index size** and **build time** measurements
- - ✅ **Score validity** testing (ensures no -inf scores)
- - ✅ **Configurable dataset sizes** for different scales
-
- ```bash
- # Quick comparison with 500 docs, 10 queries
- python benchmarks/diskann_vs_hnsw_speed_comparison.py
-
- # Large-scale comparison with 2000 docs, 20 queries
- python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20
- ```
-
### `test_distance_functions.py`
Tests all supported distance functions across DiskANN backend:
- ✅ **MIPS** (Maximum Inner Product Search)
@@ -1,268 +0,0 @@ — file deleted; original content:

#!/usr/bin/env python3
"""
DiskANN vs HNSW Search Performance Comparison

This benchmark compares search performance between DiskANN and HNSW backends:
- DiskANN: With graph partitioning enabled (is_recompute=True)
- HNSW: With recompute enabled (is_recompute=True)
- Tests performance across different dataset sizes
- Measures search latency, recall, and index size
"""

import gc
import tempfile
import time
from pathlib import Path
from typing import Any

import numpy as np


def create_test_texts(n_docs: int) -> list[str]:
    """Create synthetic test documents for benchmarking."""
    np.random.seed(42)
    topics = [
        "machine learning and artificial intelligence",
        "natural language processing and text analysis",
        "computer vision and image recognition",
        "data science and statistical analysis",
        "deep learning and neural networks",
        "information retrieval and search engines",
        "database systems and data management",
        "software engineering and programming",
        "cybersecurity and network protection",
        "cloud computing and distributed systems",
    ]

    texts = []
    for i in range(n_docs):
        topic = topics[i % len(topics)]
        variation = np.random.randint(1, 100)
        text = (
            f"This is document {i} about {topic}. Content variation {variation}. "
            f"Additional information about {topic} with details and examples. "
            f"Technical discussion of {topic} including implementation aspects."
        )
        texts.append(text)

    return texts


def benchmark_backend(
    backend_name: str, texts: list[str], test_queries: list[str], backend_kwargs: dict[str, Any]
) -> dict[str, float]:
    """Benchmark a specific backend with the given configuration."""
    from leann.api import LeannBuilder, LeannSearcher

    print(f"\n🔧 Testing {backend_name.upper()} backend...")

    with tempfile.TemporaryDirectory() as temp_dir:
        index_path = str(Path(temp_dir) / f"benchmark_{backend_name}.leann")

        # Build index
        print(f"📦 Building {backend_name} index with {len(texts)} documents...")
        start_time = time.time()

        builder = LeannBuilder(
            backend_name=backend_name,
            embedding_model="facebook/contriever",
            embedding_mode="sentence-transformers",
            **backend_kwargs,
        )

        for text in texts:
            builder.add_text(text)

        builder.build_index(index_path)
        build_time = time.time() - start_time

        # Measure index size
        index_dir = Path(index_path).parent
        index_files = list(index_dir.glob(f"{Path(index_path).stem}.*"))
        total_size = sum(f.stat().st_size for f in index_files if f.is_file())
        size_mb = total_size / (1024 * 1024)

        print(f"  ✅ Build completed in {build_time:.2f}s, index size: {size_mb:.1f}MB")

        # Search benchmark
        print("🔍 Running search benchmark...")
        searcher = LeannSearcher(index_path)

        search_times = []
        all_results = []

        for query in test_queries:
            start_time = time.time()
            results = searcher.search(query, top_k=5)
            search_time = time.time() - start_time
            search_times.append(search_time)
            all_results.append(results)

        avg_search_time = np.mean(search_times) * 1000  # Convert to ms
        print(f"  ✅ Average search time: {avg_search_time:.1f}ms")

        # Check for valid scores (detect -inf issues)
        all_scores = [
            result.score
            for results in all_results
            for result in results
            if result.score is not None
        ]
        valid_scores = [
            score for score in all_scores if score != float("-inf") and score != float("inf")
        ]
        score_validity_rate = len(valid_scores) / len(all_scores) if all_scores else 0

        # Clean up
        try:
            if hasattr(searcher, "__del__"):
                searcher.__del__()
            del searcher
            del builder
            gc.collect()
        except Exception as e:
            print(f"⚠️ Warning: Resource cleanup error: {e}")

        return {
            "build_time": build_time,
            "avg_search_time_ms": avg_search_time,
            "index_size_mb": size_mb,
            "score_validity_rate": score_validity_rate,
        }


def run_comparison(n_docs: int = 500, n_queries: int = 10):
    """Run performance comparison between DiskANN and HNSW."""
    print("🚀 Starting DiskANN vs HNSW Performance Comparison")
    print(f"📊 Dataset: {n_docs} documents, {n_queries} test queries")

    # Create test data
    texts = create_test_texts(n_docs)
    test_queries = [
        "machine learning algorithms",
        "natural language processing",
        "computer vision techniques",
        "data analysis methods",
        "neural network architectures",
        "database query optimization",
        "software development practices",
        "security vulnerabilities",
        "cloud infrastructure",
        "distributed computing",
    ][:n_queries]

    # HNSW benchmark
    hnsw_results = benchmark_backend(
        backend_name="hnsw",
        texts=texts,
        test_queries=test_queries,
        backend_kwargs={
            "is_recompute": True,  # Enable recompute for fair comparison
            "M": 16,
            "efConstruction": 200,
        },
    )

    # DiskANN benchmark
    diskann_results = benchmark_backend(
        backend_name="diskann",
        texts=texts,
        test_queries=test_queries,
        backend_kwargs={
            "is_recompute": True,  # Enable graph partitioning
            "num_neighbors": 32,
            "search_list_size": 50,
        },
    )

    # Performance comparison
    print("\n📈 Performance Comparison Results")
    print(f"{'=' * 60}")
    print(f"{'Metric':<25} {'HNSW':<15} {'DiskANN':<15} {'Speedup':<10}")
    print(f"{'-' * 60}")

    # Build time comparison
    build_speedup = hnsw_results["build_time"] / diskann_results["build_time"]
    print(
        f"{'Build Time (s)':<25} {hnsw_results['build_time']:<15.2f} {diskann_results['build_time']:<15.2f} {build_speedup:<10.2f}x"
    )

    # Search time comparison
    search_speedup = hnsw_results["avg_search_time_ms"] / diskann_results["avg_search_time_ms"]
    print(
        f"{'Search Time (ms)':<25} {hnsw_results['avg_search_time_ms']:<15.1f} {diskann_results['avg_search_time_ms']:<15.1f} {search_speedup:<10.2f}x"
    )

    # Index size comparison
    size_ratio = diskann_results["index_size_mb"] / hnsw_results["index_size_mb"]
    print(
        f"{'Index Size (MB)':<25} {hnsw_results['index_size_mb']:<15.1f} {diskann_results['index_size_mb']:<15.1f} {size_ratio:<10.2f}x"
    )

    # Score validity
    print(
        f"{'Score Validity (%)':<25} {hnsw_results['score_validity_rate'] * 100:<15.1f} {diskann_results['score_validity_rate'] * 100:<15.1f}"
    )

    print(f"{'=' * 60}")
    print("\n🎯 Summary:")
    if search_speedup > 1:
        print(f"  DiskANN is {search_speedup:.2f}x faster than HNSW for search")
    else:
        print(f"  HNSW is {1 / search_speedup:.2f}x faster than DiskANN for search")

    if size_ratio > 1:
        print(f"  DiskANN uses {size_ratio:.2f}x more storage than HNSW")
    else:
        print(f"  DiskANN uses {1 / size_ratio:.2f}x less storage than HNSW")

    print(
        f"  Both backends achieved {min(hnsw_results['score_validity_rate'], diskann_results['score_validity_rate']) * 100:.1f}% score validity"
    )


if __name__ == "__main__":
    import sys

    try:
        # Handle help request
        if len(sys.argv) > 1 and sys.argv[1] in ["-h", "--help", "help"]:
            print("DiskANN vs HNSW Performance Comparison")
            print("=" * 50)
            print(f"Usage: python {sys.argv[0]} [n_docs] [n_queries]")
            print()
            print("Arguments:")
            print("  n_docs     Number of documents to index (default: 500)")
            print("  n_queries  Number of test queries to run (default: 10)")
            print()
            print("Examples:")
            print("  python benchmarks/diskann_vs_hnsw_speed_comparison.py")
            print("  python benchmarks/diskann_vs_hnsw_speed_comparison.py 1000")
            print("  python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20")
            sys.exit(0)

        # Parse command line arguments
        n_docs = int(sys.argv[1]) if len(sys.argv) > 1 else 500
        n_queries = int(sys.argv[2]) if len(sys.argv) > 2 else 10

        print("DiskANN vs HNSW Performance Comparison")
        print("=" * 50)
        print(f"Dataset: {n_docs} documents, {n_queries} queries")
        print()

        run_comparison(n_docs=n_docs, n_queries=n_queries)

    except KeyboardInterrupt:
        print("\n⚠️ Benchmark interrupted by user")
        sys.exit(130)
    except Exception as e:
        print(f"\n❌ Benchmark failed: {e}")
        sys.exit(1)
    finally:
        # Ensure clean exit
        try:
            gc.collect()
            print("\n🧹 Cleanup completed")
        except Exception:
            pass
        sys.exit(0)
@@ -1,123 +0,0 @@ — file deleted; original content:

# Thinking Budget Feature Implementation

## Overview

This document describes the implementation of the **thinking budget** feature for LEANN, which allows users to control the computational effort for reasoning models like GPT-Oss:20b.

## Feature Description

The thinking budget feature provides three levels of computational effort for reasoning models:
- **`low`**: Fast responses, basic reasoning (default for simple queries)
- **`medium`**: Balanced speed and reasoning depth
- **`high`**: Maximum reasoning effort, best for complex analytical questions

## Implementation Details

### 1. Command Line Interface

Added `--thinking-budget` parameter to both CLI and RAG examples:

```bash
# LEANN CLI
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget high

# RAG Examples
python apps/email_rag.py --llm ollama --llm-model gpt-oss:20b --thinking-budget high
python apps/document_rag.py --llm openai --llm-model o3 --thinking-budget medium
```

### 2. LLM Backend Support

#### Ollama Backend (`packages/leann-core/src/leann/chat.py`)

```python
def ask(self, prompt: str, **kwargs) -> str:
    # Handle thinking budget for reasoning models
    options = kwargs.copy()
    thinking_budget = kwargs.get("thinking_budget")
    if thinking_budget:
        options.pop("thinking_budget", None)
        if thinking_budget in ["low", "medium", "high"]:
            options["reasoning"] = {"effort": thinking_budget, "exclude": False}
```

**API Format**: Uses Ollama's `reasoning` parameter with `effort` and `exclude` fields.

#### OpenAI Backend (`packages/leann-core/src/leann/chat.py`)

```python
def ask(self, prompt: str, **kwargs) -> str:
    # Handle thinking budget for reasoning models
    thinking_budget = kwargs.get("thinking_budget")
    if thinking_budget and thinking_budget in ["low", "medium", "high"]:
        # Check if this is an o-series model
        o_series_models = ["o3", "o3-mini", "o4-mini", "o1", "o3-pro", "o3-deep-research"]
        if any(model in self.model for model in o_series_models):
            params["reasoning_effort"] = thinking_budget
```

**API Format**: Uses OpenAI's `reasoning_effort` parameter for o-series models.

### 3. Parameter Propagation

The thinking budget parameter is properly propagated through the LEANN architecture:

1. **CLI** (`packages/leann-core/src/leann/cli.py`): Captures `--thinking-budget` argument
2. **Base RAG** (`apps/base_rag_example.py`): Adds parameter to argument parser
3. **LeannChat** (`packages/leann-core/src/leann/api.py`): Passes `llm_kwargs` to LLM
4. **LLM Interface**: Handles the parameter in backend-specific implementations
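A minimal sketch of that propagation path, assembled from the hunks elsewhere in this compare (the surrounding scaffolding is illustrative only, not the actual implementation):

```python
# args comes from the RAG example's argparse setup; chat is a LeannChat instance.
# Only forward thinking_budget when the user actually set it.
llm_kwargs = {}
if getattr(args, "thinking_budget", None):
    llm_kwargs["thinking_budget"] = args.thinking_budget  # "low" | "medium" | "high"

# LeannChat forwards llm_kwargs to the backend-specific LLM wrapper,
# which maps it to Ollama's `reasoning` or OpenAI's `reasoning_effort`.
response = chat.ask(
    query,
    top_k=args.top_k,
    complexity=args.search_complexity,
    llm_kwargs=llm_kwargs,
)
```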
## Files Modified

### Core Implementation
- `packages/leann-core/src/leann/chat.py`: Added thinking budget support to OllamaChat and OpenAIChat
- `packages/leann-core/src/leann/cli.py`: Added `--thinking-budget` argument
- `apps/base_rag_example.py`: Added thinking budget parameter to RAG examples

### Documentation
- `README.md`: Added thinking budget parameter to usage examples
- `docs/configuration-guide.md`: Added detailed documentation and usage guidelines

### Examples
- `examples/thinking_budget_demo.py`: Comprehensive demo script with usage examples

## Usage Examples

### Basic Usage
```bash
# High reasoning effort for complex questions
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget high

# Medium reasoning for balanced performance
leann ask my-index --llm openai --model gpt-4o --thinking-budget medium

# Low reasoning for fast responses
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget low
```

### RAG Examples
```bash
# Email RAG with high reasoning
python apps/email_rag.py --llm ollama --llm-model gpt-oss:20b --thinking-budget high

# Document RAG with medium reasoning
python apps/document_rag.py --llm openai --llm-model gpt-4o --thinking-budget medium
```

## Supported Models

### Ollama Models
- **GPT-Oss:20b**: Primary target model with reasoning capabilities
- **Other reasoning models**: Any Ollama model that supports the `reasoning` parameter

### OpenAI Models
- **o3, o3-mini, o4-mini, o1**: o-series reasoning models with `reasoning_effort` parameter
- **GPT-OSS models**: Models that support reasoning capabilities

## Testing

The implementation includes comprehensive testing:
- Parameter handling verification
- Backend-specific API format validation
- CLI argument parsing tests
- Integration with existing LEANN architecture
@@ -1,294 +0,0 @@ — file deleted; original content:

# LEANN Configuration Guide

This guide helps you optimize LEANN for different use cases and understand the trade-offs between various configuration options.

## Getting Started: Simple is Better

When first trying LEANN, start with a small dataset to quickly validate your approach:

**For document RAG**: The default `data/` directory works perfectly - includes 2 AI research papers, Pride and Prejudice literature, and a technical report
```bash
python -m apps.document_rag --query "What techniques does LEANN use?"
```

**For other data sources**: Limit the dataset size for quick testing
```bash
# WeChat: Test with recent messages only
python -m apps.wechat_rag --max-items 100 --query "What did we discuss about the project timeline?"

# Browser history: Last few days
python -m apps.browser_rag --max-items 500 --query "Find documentation about vector databases"

# Email: Recent inbox
python -m apps.email_rag --max-items 200 --query "Who sent updates about the deployment status?"
```

Once validated, scale up gradually:
- 100 documents → 1,000 → 10,000 → full dataset (`--max-items -1`)
- This helps identify issues early before committing to long processing times

## Embedding Model Selection: Understanding the Trade-offs

Based on our experience developing LEANN, embedding models fall into three categories:

### Small Models (< 100M parameters)
**Example**: `sentence-transformers/all-MiniLM-L6-v2` (22M params)
- **Pros**: Lightweight, fast for both indexing and inference
- **Cons**: Lower semantic understanding, may miss nuanced relationships
- **Use when**: Speed is critical, handling simple queries, interactive mode, or just experimenting with LEANN. If time is not a constraint, consider using a larger/better embedding model

### Medium Models (100M-500M parameters)
**Example**: `facebook/contriever` (110M params), `BAAI/bge-base-en-v1.5` (110M params)
- **Pros**: Balanced performance, good multilingual support, reasonable speed
- **Cons**: Requires more compute than small models
- **Use when**: Need quality results without extreme compute requirements, general-purpose RAG applications

### Large Models (500M+ parameters)
**Example**: `Qwen/Qwen3-Embedding-0.6B` (600M params), `intfloat/multilingual-e5-large` (560M params)
- **Pros**: Best semantic understanding, captures complex relationships, excellent multilingual support. **Qwen3-Embedding-0.6B achieves nearly OpenAI API performance!**
- **Cons**: Slower inference, longer index build times
- **Use when**: Quality is paramount and you have sufficient compute resources. **Highly recommended** for production use

### Quick Start: Cloud and Local Embedding Options

**OpenAI Embeddings (Fastest Setup)**
For immediate testing without local model downloads:
```bash
# Set OpenAI embeddings (requires OPENAI_API_KEY)
--embedding-mode openai --embedding-model text-embedding-3-small
```

**Ollama Embeddings (Privacy-Focused)**
For local embeddings with complete privacy:
```bash
# First, pull an embedding model
ollama pull nomic-embed-text

# Use Ollama embeddings
--embedding-mode ollama --embedding-model nomic-embed-text
```

<details>
<summary><strong>Cloud vs Local Trade-offs</strong></summary>

**OpenAI Embeddings** (`text-embedding-3-small/large`)
- **Pros**: No local compute needed, consistently fast, high quality
- **Cons**: Requires API key, costs money, data leaves your system, [known limitations with certain languages](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
- **When to use**: Prototyping, non-sensitive data, need immediate results

**Local Embeddings**
- **Pros**: Complete privacy, no ongoing costs, full control, can sometimes outperform OpenAI embeddings
- **Cons**: Slower than cloud APIs, requires local compute resources
- **When to use**: Production systems, sensitive data, cost-sensitive applications

</details>

## Index Selection: Matching Your Scale

### HNSW (Hierarchical Navigable Small World)
**Best for**: Small to medium datasets (< 10M vectors) - **Default and recommended for extreme low storage**
- Full recomputation required
- High memory usage during build phase
- Excellent recall (95%+)

```bash
# Optimal for most use cases
--backend-name hnsw --graph-degree 32 --build-complexity 64
```

### DiskANN
**Best for**: Performance-critical applications and large datasets - **Production-ready with automatic graph partitioning**

**How it works:**
- **Product Quantization (PQ) + Real-time Reranking**: Uses compressed PQ codes for fast graph traversal, then recomputes exact embeddings for final candidates
- **Automatic Graph Partitioning**: When `is_recompute=True`, automatically partitions large indices and safely removes redundant files to save storage
- **Superior Speed-Accuracy Trade-off**: Faster search than HNSW while maintaining high accuracy

**Trade-offs compared to HNSW:**
- ✅ **Faster search latency** (typically 2-8x speedup)
- ✅ **Better scaling** for large datasets
- ✅ **Smart storage management** with automatic partitioning
- ✅ **Better graph locality** with `--ldg-times` parameter for SSD optimization
- ⚠️ **Slightly larger index size** due to PQ tables and graph metadata

```bash
# Recommended for most use cases
--backend-name diskann --graph-degree 32 --build-complexity 64

# For large-scale deployments
--backend-name diskann --graph-degree 64 --build-complexity 128
```

**Performance Benchmark**: Run `python benchmarks/diskann_vs_hnsw_speed_comparison.py` to compare DiskANN and HNSW on your system.

## LLM Selection: Engine and Model Comparison

### LLM Engines

**OpenAI** (`--llm openai`)
- **Pros**: Best quality, consistent performance, no local resources needed
- **Cons**: Costs money ($0.15-2.5 per million tokens), requires internet, data privacy concerns
- **Models**: `gpt-4o-mini` (fast, cheap), `gpt-4o` (best quality), `o3` (reasoning), `o3-mini` (reasoning, cheaper)
- **Thinking Budget**: Use `--thinking-budget low/medium/high` for o-series reasoning models (o3, o3-mini, o4-mini)
- **Note**: Our current default, but we recommend switching to Ollama for most use cases

**Ollama** (`--llm ollama`)
- **Pros**: Fully local, free, privacy-preserving, good model variety
- **Cons**: Requires local GPU/CPU resources, slower than cloud APIs, need to install extra [ollama app](https://github.com/ollama/ollama?tab=readme-ov-file#ollama) and pre-download models by `ollama pull`
- **Models**: `qwen3:0.6b` (ultra-fast), `qwen3:1.7b` (balanced), `qwen3:4b` (good quality), `qwen3:7b` (high quality), `deepseek-r1:1.5b` (reasoning)
- **Thinking Budget**: Use `--thinking-budget low/medium/high` for reasoning models like GPT-Oss:20b

**HuggingFace** (`--llm hf`)
- **Pros**: Free tier available, huge model selection, direct model loading (vs Ollama's server-based approach)
- **Cons**: More complex initial setup
- **Models**: `Qwen/Qwen3-1.7B-FP8`

## Parameter Tuning Guide

### Search Complexity Parameters

**`--build-complexity`** (index building)
- Controls thoroughness during index construction
- Higher = better recall but slower build
- Recommendations:
  - 32: Quick prototyping
  - 64: Balanced (default)
  - 128: Production systems
  - 256: Maximum quality

**`--search-complexity`** (query time)
- Controls search thoroughness
- Higher = better results but slower
- Recommendations:
  - 16: Fast/Interactive search
  - 32: High quality with diversity
  - 64+: Maximum accuracy

### Top-K Selection

**`--top-k`** (number of retrieved chunks)
- More chunks = better context but slower LLM processing
- Should be always smaller than `--search-complexity`
- Guidelines:
  - 10-20: General questions (default: 20)
  - 30+: Complex multi-hop reasoning requiring comprehensive context

**Trade-off formula**:
- Retrieval time ∝ log(n) × search_complexity
- LLM processing time ∝ top_k × chunk_size
- Total context = top_k × chunk_size tokens
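For intuition, a quick worked example of the formula above, using the default `top_k` of 20 and an assumed chunk size of 256 tokens (the chunk size here is an assumption for illustration):

```python
top_k = 20        # default number of retrieved chunks
chunk_size = 256  # assumed tokens per chunk (illustrative)

total_context = top_k * chunk_size
print(total_context)  # 5120 tokens of retrieved context handed to the LLM
```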
### Thinking Budget for Reasoning Models
|
|
||||||
|
|
||||||
**`--thinking-budget`** (reasoning effort level)
|
|
||||||
- Controls the computational effort for reasoning models
|
|
||||||
- Options: `low`, `medium`, `high`
|
|
||||||
- Guidelines:
|
|
||||||
- `low`: Fast responses, basic reasoning (default for simple queries)
|
|
||||||
- `medium`: Balanced speed and reasoning depth
|
|
||||||
- `high`: Maximum reasoning effort, best for complex analytical questions
|
|
||||||
- **Supported Models**:
|
|
||||||
- **Ollama**: `gpt-oss:20b`, `gpt-oss:120b`
|
|
||||||
- **OpenAI**: `o3`, `o3-mini`, `o4-mini`, `o1` (o-series reasoning models)
|
|
||||||
- **Note**: Models without reasoning support will show a warning and proceed without reasoning parameters
|
|
||||||
- **Example**: `--thinking-budget high` for complex analytical questions
|
|
||||||
|
|
||||||
**📖 For detailed usage examples and implementation details, check out [Thinking Budget Documentation](THINKING_BUDGET_FEATURE.md)**
|
|
||||||
|
|
||||||
**💡 Quick Examples:**
|
|
||||||
```bash
|
|
||||||
# OpenAI o-series reasoning model
|
|
||||||
python apps/document_rag.py --query "What are the main techniques LEANN explores?" \
|
|
||||||
--index-dir hnswbuild --backend hnsw \
|
|
||||||
--llm openai --llm-model o3 --thinking-budget medium
|
|
||||||
|
|
||||||
# Ollama reasoning model
|
|
||||||
python apps/document_rag.py --query "What are the main techniques LEANN explores?" \
|
|
||||||
--index-dir hnswbuild --backend hnsw \
|
|
||||||
--llm ollama --llm-model gpt-oss:20b --thinking-budget high
|
|
||||||
```
|
|
||||||
|
|
||||||
### Graph Degree (HNSW/DiskANN)
|
|
||||||
|
|
||||||
**`--graph-degree`**
|
|
||||||
- Number of connections per node in the graph
|
|
||||||
- Higher = better recall but more memory
|
|
||||||
- HNSW: 16-32 (default: 32)
|
|
||||||
- DiskANN: 32-128 (default: 64)
|
|
||||||
|
|
||||||
|
|
||||||
## Performance Optimization Checklist
|
|
||||||
|
|
||||||
### If Embedding is Too Slow
|
|
||||||
|
|
||||||
1. **Switch to smaller model**:
|
|
||||||
```bash
|
|
||||||
# From large model
|
|
||||||
--embedding-model Qwen/Qwen3-Embedding-0.6B
|
|
||||||
# To small model
|
|
||||||
--embedding-model sentence-transformers/all-MiniLM-L6-v2
|
|
||||||
```
|
|
||||||
|
|
||||||
2. **Limit dataset size for testing**:
|
|
||||||
```bash
|
|
||||||
--max-items 1000 # Process first 1k items only
|
|
||||||
```
|
|
||||||
|
|
||||||
3. **Use MLX on Apple Silicon** (optional optimization):
|
|
||||||
```bash
|
|
||||||
--embedding-mode mlx --embedding-model mlx-community/multilingual-e5-base-mlx
|
|
||||||
```
|
|
||||||
|
|
||||||
### If Search Quality is Poor
|
|
||||||
|
|
||||||
1. **Increase retrieval count**:
|
|
||||||
```bash
|
|
||||||
--top-k 30 # Retrieve more candidates
|
|
||||||
```
|
|
||||||
|
|
||||||
2. **Upgrade embedding model**:
|
|
||||||
```bash
|
|
||||||
# For English
|
|
||||||
--embedding-model BAAI/bge-base-en-v1.5
|
|
||||||
# For multilingual
|
|
||||||
--embedding-model intfloat/multilingual-e5-large
|
|
||||||
```
|
|
||||||
|
|
||||||
## Understanding the Trade-offs
|
|
||||||
|
|
||||||
Every configuration choice involves trade-offs:
|
|
||||||
|
|
||||||
| Factor | Small/Fast | Large/Quality |
|
|
||||||
|--------|------------|---------------|
|
|
||||||
| Embedding Model | `all-MiniLM-L6-v2` | `Qwen/Qwen3-Embedding-0.6B` |
|
|
||||||
| Chunk Size | 512 tokens | 128 tokens |
|
|
||||||
| Index Type | HNSW | DiskANN |
|
|
||||||
| LLM | `qwen3:1.7b` | `gpt-4o` |
|
|
||||||
|
|
||||||
The key is finding the right balance for your specific use case. Start small and simple, measure performance, then scale up only where needed.

## Deep Dive: Critical Configuration Decisions

### When to Disable Recomputation

LEANN's recomputation feature provides exact distance calculations but can be disabled for extreme QPS requirements:

```bash
--no-recompute  # Disable selective recomputation
```

**Trade-offs**:
- **With recomputation** (default): Exact distances, best quality, higher latency, minimal storage (only stores metadata, recomputes embeddings on-demand)
- **Without recomputation**: Must store full embeddings, significantly higher memory and storage usage (10-100x more), but faster search

**Disable when**:
- You have abundant storage and memory
- Need extremely low latency (< 100ms)
- Running a read-heavy workload where storage cost is acceptable
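
A hedged end-to-end example of such a read-heavy, latency-first run, reusing the same `document_rag.py` entry point and flags shown earlier:

```bash
# Trade storage for latency: keep full embeddings and skip selective recomputation
python apps/document_rag.py --query "What are the main techniques LEANN explores?" \
  --index-dir hnswbuild --backend hnsw \
  --no-recompute --top-k 20
```

Expect the index directory to grow substantially in this mode, since the full embedding matrix is stored instead of being recomputed at query time.
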
## Further Reading

- [Lessons Learned Developing LEANN](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
- [LEANN Technical Paper](https://arxiv.org/abs/2506.08276)
- [DiskANN Original Paper](https://papers.nips.cc/paper/2019/file/09853c7fb1d3f8ee67a61b6bf4a7f8e6-Paper.pdf)
- [SSD-based Graph Partitioning](https://github.com/SonglinLife/SSD_BASED_PLAN)

@@ -5,7 +5,7 @@
 - **🔄 Real-time Embeddings** - Eliminate heavy embedding storage with dynamic computation using optimized ZMQ servers and highly optimized search paradigm (overlapping and batching) with highly optimized embedding engine
 - **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
 - **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
-- **🏗️ Pluggable Backends** - HNSW/FAISS (default), with optional DiskANN for large-scale deployments
+- **🏗️ Pluggable Backends** - DiskANN, HNSW/FAISS with unified API

 ## 🛠️ Technical Highlights
 - **🔄 Recompute Mode** - Highest accuracy scenarios while eliminating vector storage overhead

@@ -2,8 +2,8 @@

 ## 🎯 Q2 2025

-- [X] HNSW backend integration
 - [X] DiskANN backend with MIPS/L2/Cosine support
+- [X] HNSW backend integration
 - [X] Real-time embedding pipeline
 - [X] Memory-efficient graph pruning

@@ -1,7 +1 @@
|
|||||||
from . import diskann_backend as diskann_backend
|
from . import diskann_backend as diskann_backend
|
||||||
from . import graph_partition
|
|
||||||
|
|
||||||
# Export main classes and functions
|
|
||||||
from .graph_partition import GraphPartitioner, partition_graph
|
|
||||||
|
|
||||||
__all__ = ["GraphPartitioner", "diskann_backend", "graph_partition", "partition_graph"]
|
|
||||||
|
|||||||
@@ -4,10 +4,9 @@ import os
|
|||||||
import struct
|
import struct
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Literal, Optional
|
from typing import Any, Literal
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import psutil
|
|
||||||
from leann.interface import (
|
from leann.interface import (
|
||||||
LeannBackendBuilderInterface,
|
LeannBackendBuilderInterface,
|
||||||
LeannBackendFactoryInterface,
|
LeannBackendFactoryInterface,
|
||||||
@@ -85,43 +84,6 @@ def _write_vectors_to_bin(data: np.ndarray, file_path: Path):
|
|||||||
f.write(data.tobytes())
|
f.write(data.tobytes())
|
||||||
|
|
||||||
|
|
||||||
def _calculate_smart_memory_config(data: np.ndarray) -> tuple[float, float]:
|
|
||||||
"""
|
|
||||||
Calculate smart memory configuration for DiskANN based on data size and system specs.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data: The embedding data array
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple: (search_memory_maximum, build_memory_maximum) in GB
|
|
||||||
"""
|
|
||||||
num_vectors, dim = data.shape
|
|
||||||
|
|
||||||
# Calculate embedding storage size
|
|
||||||
embedding_size_bytes = num_vectors * dim * 4 # float32 = 4 bytes
|
|
||||||
embedding_size_gb = embedding_size_bytes / (1024**3)
|
|
||||||
|
|
||||||
# search_memory_maximum: 1/10 of embedding size for optimal PQ compression
|
|
||||||
# This controls Product Quantization size - smaller means more compression
|
|
||||||
search_memory_gb = max(0.1, embedding_size_gb / 10) # At least 100MB
|
|
||||||
|
|
||||||
# build_memory_maximum: Based on available system RAM for sharding control
|
|
||||||
# This controls how much memory DiskANN uses during index construction
|
|
||||||
available_memory_gb = psutil.virtual_memory().available / (1024**3)
|
|
||||||
total_memory_gb = psutil.virtual_memory().total / (1024**3)
|
|
||||||
|
|
||||||
# Use 50% of available memory, but at least 2GB and at most 75% of total
|
|
||||||
build_memory_gb = max(2.0, min(available_memory_gb * 0.5, total_memory_gb * 0.75))
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Smart memory config - Data: {embedding_size_gb:.2f}GB, "
|
|
||||||
f"Search mem: {search_memory_gb:.2f}GB (PQ control), "
|
|
||||||
f"Build mem: {build_memory_gb:.2f}GB (sharding control)"
|
|
||||||
)
|
|
||||||
|
|
||||||
return search_memory_gb, build_memory_gb
|
|
||||||
|
|
||||||
|
|
||||||
@register_backend("diskann")
|
@register_backend("diskann")
|
||||||
class DiskannBackend(LeannBackendFactoryInterface):
|
class DiskannBackend(LeannBackendFactoryInterface):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -137,71 +99,6 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
|||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.build_params = kwargs
|
self.build_params = kwargs
|
||||||
|
|
||||||
def _safe_cleanup_after_partition(self, index_dir: Path, index_prefix: str):
|
|
||||||
"""
|
|
||||||
Safely cleanup files after partition.
|
|
||||||
In partition mode, C++ doesn't read _disk.index content,
|
|
||||||
so we can delete it if all derived files exist.
|
|
||||||
"""
|
|
||||||
disk_index_file = index_dir / f"{index_prefix}_disk.index"
|
|
||||||
beam_search_file = index_dir / f"{index_prefix}_disk_beam_search.index"
|
|
||||||
|
|
||||||
# Required files that C++ partition mode needs
|
|
||||||
# Note: C++ generates these with _disk.index suffix
|
|
||||||
disk_suffix = "_disk.index"
|
|
||||||
required_files = [
|
|
||||||
f"{index_prefix}{disk_suffix}_medoids.bin", # Critical: assert fails if missing
|
|
||||||
# Note: _centroids.bin is not created in single-shot build - C++ handles this automatically
|
|
||||||
f"{index_prefix}_pq_pivots.bin", # PQ table
|
|
||||||
f"{index_prefix}_pq_compressed.bin", # PQ compressed vectors
|
|
||||||
]
|
|
||||||
|
|
||||||
# Check if all required files exist
|
|
||||||
missing_files = []
|
|
||||||
for filename in required_files:
|
|
||||||
file_path = index_dir / filename
|
|
||||||
if not file_path.exists():
|
|
||||||
missing_files.append(filename)
|
|
||||||
|
|
||||||
if missing_files:
|
|
||||||
logger.warning(
|
|
||||||
f"Cannot safely delete _disk.index - missing required files: {missing_files}"
|
|
||||||
)
|
|
||||||
logger.info("Keeping all original files for safety")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Calculate space savings
|
|
||||||
space_saved = 0
|
|
||||||
files_to_delete = []
|
|
||||||
|
|
||||||
if disk_index_file.exists():
|
|
||||||
space_saved += disk_index_file.stat().st_size
|
|
||||||
files_to_delete.append(disk_index_file)
|
|
||||||
|
|
||||||
if beam_search_file.exists():
|
|
||||||
space_saved += beam_search_file.stat().st_size
|
|
||||||
files_to_delete.append(beam_search_file)
|
|
||||||
|
|
||||||
# Safe to delete!
|
|
||||||
for file_to_delete in files_to_delete:
|
|
||||||
try:
|
|
||||||
os.remove(file_to_delete)
|
|
||||||
logger.info(f"✅ Safely deleted: {file_to_delete.name}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to delete {file_to_delete.name}: {e}")
|
|
||||||
|
|
||||||
if space_saved > 0:
|
|
||||||
space_saved_mb = space_saved / (1024 * 1024)
|
|
||||||
logger.info(f"💾 Space saved: {space_saved_mb:.1f} MB")
|
|
||||||
|
|
||||||
# Show what files are kept
|
|
||||||
logger.info("📁 Kept essential files for partition mode:")
|
|
||||||
for filename in required_files:
|
|
||||||
file_path = index_dir / filename
|
|
||||||
if file_path.exists():
|
|
||||||
size_mb = file_path.stat().st_size / (1024 * 1024)
|
|
||||||
logger.info(f" - {filename} ({size_mb:.1f} MB)")
|
|
||||||
|
|
||||||
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
|
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
|
||||||
path = Path(index_path)
|
path = Path(index_path)
|
||||||
index_dir = path.parent
|
index_dir = path.parent
|
||||||
@@ -216,17 +113,6 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
|||||||
_write_vectors_to_bin(data, index_dir / data_filename)
|
_write_vectors_to_bin(data, index_dir / data_filename)
|
||||||
|
|
||||||
build_kwargs = {**self.build_params, **kwargs}
|
build_kwargs = {**self.build_params, **kwargs}
|
||||||
|
|
||||||
# Extract is_recompute from nested backend_kwargs if needed
|
|
||||||
is_recompute = build_kwargs.get("is_recompute", False)
|
|
||||||
if not is_recompute and "backend_kwargs" in build_kwargs:
|
|
||||||
is_recompute = build_kwargs["backend_kwargs"].get("is_recompute", False)
|
|
||||||
|
|
||||||
# Flatten all backend_kwargs parameters to top level for compatibility
|
|
||||||
if "backend_kwargs" in build_kwargs:
|
|
||||||
nested_params = build_kwargs.pop("backend_kwargs")
|
|
||||||
build_kwargs.update(nested_params)
|
|
||||||
|
|
||||||
metric_enum = _get_diskann_metrics().get(
|
metric_enum = _get_diskann_metrics().get(
|
||||||
build_kwargs.get("distance_metric", "mips").lower()
|
build_kwargs.get("distance_metric", "mips").lower()
|
||||||
)
|
)
|
||||||
@@ -235,16 +121,6 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
|||||||
f"Unsupported distance_metric '{build_kwargs.get('distance_metric', 'unknown')}'."
|
f"Unsupported distance_metric '{build_kwargs.get('distance_metric', 'unknown')}'."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Calculate smart memory configuration if not explicitly provided
|
|
||||||
if (
|
|
||||||
"search_memory_maximum" not in build_kwargs
|
|
||||||
or "build_memory_maximum" not in build_kwargs
|
|
||||||
):
|
|
||||||
smart_search_mem, smart_build_mem = _calculate_smart_memory_config(data)
|
|
||||||
else:
|
|
||||||
smart_search_mem = build_kwargs.get("search_memory_maximum", 4.0)
|
|
||||||
smart_build_mem = build_kwargs.get("build_memory_maximum", 8.0)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from . import _diskannpy as diskannpy # type: ignore
|
from . import _diskannpy as diskannpy # type: ignore
|
||||||
|
|
||||||
@@ -255,36 +131,12 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
|||||||
index_prefix,
|
index_prefix,
|
||||||
build_kwargs.get("complexity", 64),
|
build_kwargs.get("complexity", 64),
|
||||||
build_kwargs.get("graph_degree", 32),
|
build_kwargs.get("graph_degree", 32),
|
||||||
build_kwargs.get("search_memory_maximum", smart_search_mem),
|
build_kwargs.get("search_memory_maximum", 4.0),
|
||||||
build_kwargs.get("build_memory_maximum", smart_build_mem),
|
build_kwargs.get("build_memory_maximum", 8.0),
|
||||||
build_kwargs.get("num_threads", 8),
|
build_kwargs.get("num_threads", 8),
|
||||||
build_kwargs.get("pq_disk_bytes", 0),
|
build_kwargs.get("pq_disk_bytes", 0),
|
||||||
"",
|
"",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Auto-partition if is_recompute is enabled
|
|
||||||
if build_kwargs.get("is_recompute", False):
|
|
||||||
logger.info("is_recompute=True, starting automatic graph partitioning...")
|
|
||||||
from .graph_partition import partition_graph
|
|
||||||
|
|
||||||
# Partition the index using absolute paths
|
|
||||||
# Convert to absolute paths to avoid issues with working directory changes
|
|
||||||
absolute_index_dir = Path(index_dir).resolve()
|
|
||||||
absolute_index_prefix_path = str(absolute_index_dir / index_prefix)
|
|
||||||
disk_graph_path, partition_bin_path = partition_graph(
|
|
||||||
index_prefix_path=absolute_index_prefix_path,
|
|
||||||
output_dir=str(absolute_index_dir),
|
|
||||||
partition_prefix=index_prefix,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Safe cleanup: In partition mode, C++ doesn't read _disk.index content
|
|
||||||
# but still needs the derived files (_medoids.bin, _centroids.bin, etc.)
|
|
||||||
self._safe_cleanup_after_partition(index_dir, index_prefix)
|
|
||||||
|
|
||||||
logger.info("✅ Graph partitioning completed successfully!")
|
|
||||||
logger.info(f" - Disk graph: {disk_graph_path}")
|
|
||||||
logger.info(f" - Partition file: {partition_bin_path}")
|
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
temp_data_file = index_dir / data_filename
|
temp_data_file = index_dir / data_filename
|
||||||
if temp_data_file.exists():
|
if temp_data_file.exists():
|
||||||
@@ -313,26 +165,7 @@ class DiskannSearcher(BaseSearcher):
|
|||||||
|
|
||||||
# For DiskANN, we need to reinitialize the index when zmq_port changes
|
# For DiskANN, we need to reinitialize the index when zmq_port changes
|
||||||
# Store the initialization parameters for later use
|
# Store the initialization parameters for later use
|
||||||
# Note: C++ load method expects the BASE path (without _disk.index suffix)
|
full_index_prefix = str(self.index_dir / self.index_path.stem)
|
||||||
# C++ internally constructs: index_prefix + "_disk.index"
|
|
||||||
index_name = self.index_path.stem # "simple_test.leann" -> "simple_test"
|
|
||||||
diskann_index_prefix = str(self.index_dir / index_name) # /path/to/simple_test
|
|
||||||
full_index_prefix = diskann_index_prefix # /path/to/simple_test (base path)
|
|
||||||
|
|
||||||
# Auto-detect partition files and set partition_prefix
|
|
||||||
partition_graph_file = self.index_dir / f"{index_name}_disk_graph.index"
|
|
||||||
partition_bin_file = self.index_dir / f"{index_name}_partition.bin"
|
|
||||||
|
|
||||||
partition_prefix = ""
|
|
||||||
if partition_graph_file.exists() and partition_bin_file.exists():
|
|
||||||
# C++ expects full path prefix, not just filename
|
|
||||||
partition_prefix = str(self.index_dir / index_name) # /path/to/simple_test
|
|
||||||
logger.info(
|
|
||||||
f"✅ Detected partition files, using partition_prefix='{partition_prefix}'"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.debug("No partition files detected, using standard index files")
|
|
||||||
|
|
||||||
self._init_params = {
|
self._init_params = {
|
||||||
"metric_enum": metric_enum,
|
"metric_enum": metric_enum,
|
||||||
"full_index_prefix": full_index_prefix,
|
"full_index_prefix": full_index_prefix,
|
||||||
@@ -340,14 +173,8 @@ class DiskannSearcher(BaseSearcher):
|
|||||||
"num_nodes_to_cache": kwargs.get("num_nodes_to_cache", 0),
|
"num_nodes_to_cache": kwargs.get("num_nodes_to_cache", 0),
|
||||||
"cache_mechanism": 1,
|
"cache_mechanism": 1,
|
||||||
"pq_prefix": "",
|
"pq_prefix": "",
|
||||||
"partition_prefix": partition_prefix,
|
"partition_prefix": "",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Log partition configuration for debugging
|
|
||||||
if partition_prefix:
|
|
||||||
logger.info(
|
|
||||||
f"✅ Detected partition files, using partition_prefix='{partition_prefix}'"
|
|
||||||
)
|
|
||||||
self._diskannpy = diskannpy
|
self._diskannpy = diskannpy
|
||||||
self._current_zmq_port = None
|
self._current_zmq_port = None
|
||||||
self._index = None
|
self._index = None
|
||||||
@@ -384,7 +211,7 @@ class DiskannSearcher(BaseSearcher):
|
|||||||
prune_ratio: float = 0.0,
|
prune_ratio: float = 0.0,
|
||||||
recompute_embeddings: bool = False,
|
recompute_embeddings: bool = False,
|
||||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
zmq_port: Optional[int] = None,
|
zmq_port: int | None = None,
|
||||||
batch_recompute: bool = False,
|
batch_recompute: bool = False,
|
||||||
dedup_node_dis: bool = False,
|
dedup_node_dis: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -437,8 +264,6 @@ class DiskannSearcher(BaseSearcher):
|
|||||||
use_global_pruning = True
|
use_global_pruning = True
|
||||||
|
|
||||||
# Perform search with suppressed C++ output based on log level
|
# Perform search with suppressed C++ output based on log level
|
||||||
use_deferred_fetch = kwargs.get("USE_DEFERRED_FETCH", True)
|
|
||||||
recompute_neighors = False
|
|
||||||
with suppress_cpp_output_if_needed():
|
with suppress_cpp_output_if_needed():
|
||||||
labels, distances = self._index.batch_search(
|
labels, distances = self._index.batch_search(
|
||||||
query,
|
query,
|
||||||
@@ -447,9 +272,9 @@ class DiskannSearcher(BaseSearcher):
|
|||||||
complexity,
|
complexity,
|
||||||
beam_width,
|
beam_width,
|
||||||
self.num_threads,
|
self.num_threads,
|
||||||
use_deferred_fetch,
|
kwargs.get("USE_DEFERRED_FETCH", False),
|
||||||
kwargs.get("skip_search_reorder", False),
|
kwargs.get("skip_search_reorder", False),
|
||||||
recompute_neighors,
|
recompute_embeddings,
|
||||||
dedup_node_dis,
|
dedup_node_dis,
|
||||||
prune_ratio,
|
prune_ratio,
|
||||||
batch_recompute,
|
batch_recompute,
|
||||||
@@ -459,25 +284,3 @@ class DiskannSearcher(BaseSearcher):
|
|||||||
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
|
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
|
||||||
|
|
||||||
return {"labels": string_labels, "distances": distances}
|
return {"labels": string_labels, "distances": distances}
|
||||||
|
|
||||||
def cleanup(self):
|
|
||||||
"""Cleanup DiskANN-specific resources including C++ index."""
|
|
||||||
# Call parent cleanup first
|
|
||||||
super().cleanup()
|
|
||||||
|
|
||||||
# Delete the C++ index to trigger destructors
|
|
||||||
try:
|
|
||||||
if hasattr(self, "_index") and self._index is not None:
|
|
||||||
del self._index
|
|
||||||
self._index = None
|
|
||||||
self._current_zmq_port = None
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Force garbage collection to ensure C++ objects are destroyed
|
|
||||||
try:
|
|
||||||
import gc
|
|
||||||
|
|
||||||
gc.collect()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import sys
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import zmq
|
import zmq
|
||||||
@@ -33,7 +32,7 @@ if not logger.handlers:
|
|||||||
|
|
||||||
|
|
||||||
def create_diskann_embedding_server(
|
def create_diskann_embedding_server(
|
||||||
passages_file: Optional[str] = None,
|
passages_file: str | None = None,
|
||||||
zmq_port: int = 5555,
|
zmq_port: int = 5555,
|
||||||
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
||||||
embedding_mode: str = "sentence-transformers",
|
embedding_mode: str = "sentence-transformers",
|
||||||
@@ -81,8 +80,7 @@ def create_diskann_embedding_server(
|
|||||||
with open(passages_file) as f:
|
with open(passages_file) as f:
|
||||||
meta = json.load(f)
|
meta = json.load(f)
|
||||||
|
|
||||||
logger.info(f"Loading PassageManager with metadata_file_path: {passages_file}")
|
passages = PassageManager(meta["passage_sources"])
|
||||||
passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file)
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
||||||
)
|
)
|
||||||
@@ -263,7 +261,7 @@ if __name__ == "__main__":
|
|||||||
"--embedding-mode",
|
"--embedding-mode",
|
||||||
type=str,
|
type=str,
|
||||||
default="sentence-transformers",
|
default="sentence-transformers",
|
||||||
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
choices=["sentence-transformers", "openai", "mlx"],
|
||||||
help="Embedding backend mode",
|
help="Embedding backend mode",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
@@ -1,299 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Graph Partition Module for LEANN DiskANN Backend
|
|
||||||
|
|
||||||
This module provides Python bindings for the graph partition functionality
|
|
||||||
of DiskANN, allowing users to partition disk-based indices for better
|
|
||||||
performance.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import shutil
|
|
||||||
import subprocess
|
|
||||||
import tempfile
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
|
|
||||||
class GraphPartitioner:
|
|
||||||
"""
|
|
||||||
A Python interface for DiskANN's graph partition functionality.
|
|
||||||
|
|
||||||
This class provides methods to partition disk-based indices for improved
|
|
||||||
search performance and memory efficiency.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, build_type: str = "release"):
|
|
||||||
"""
|
|
||||||
Initialize the GraphPartitioner.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
build_type: Build type for the executables ("debug" or "release")
|
|
||||||
"""
|
|
||||||
self.build_type = build_type
|
|
||||||
self._ensure_executables()
|
|
||||||
|
|
||||||
def _get_executable_path(self, name: str) -> str:
|
|
||||||
"""Get the path to a graph partition executable."""
|
|
||||||
# Get the directory where this Python module is located
|
|
||||||
module_dir = Path(__file__).parent
|
|
||||||
# Navigate to the graph_partition directory
|
|
||||||
graph_partition_dir = module_dir.parent / "third_party" / "DiskANN" / "graph_partition"
|
|
||||||
executable_path = graph_partition_dir / "build" / self.build_type / "graph_partition" / name
|
|
||||||
|
|
||||||
if not executable_path.exists():
|
|
||||||
raise FileNotFoundError(f"Executable {name} not found at {executable_path}")
|
|
||||||
|
|
||||||
return str(executable_path)
|
|
||||||
|
|
||||||
def _ensure_executables(self):
|
|
||||||
"""Ensure that the required executables are built."""
|
|
||||||
try:
|
|
||||||
self._get_executable_path("partitioner")
|
|
||||||
self._get_executable_path("index_relayout")
|
|
||||||
except FileNotFoundError:
|
|
||||||
# Try to build the executables automatically
|
|
||||||
print("Executables not found, attempting to build them...")
|
|
||||||
self._build_executables()
|
|
||||||
|
|
||||||
def _build_executables(self):
|
|
||||||
"""Build the required executables."""
|
|
||||||
graph_partition_dir = (
|
|
||||||
Path(__file__).parent.parent / "third_party" / "DiskANN" / "graph_partition"
|
|
||||||
)
|
|
||||||
original_dir = os.getcwd()
|
|
||||||
|
|
||||||
try:
|
|
||||||
os.chdir(graph_partition_dir)
|
|
||||||
|
|
||||||
# Clean any existing build
|
|
||||||
if (graph_partition_dir / "build").exists():
|
|
||||||
shutil.rmtree(graph_partition_dir / "build")
|
|
||||||
|
|
||||||
# Run the build script
|
|
||||||
cmd = ["./build.sh", self.build_type, "split_graph", "/tmp/dummy"]
|
|
||||||
subprocess.run(cmd, capture_output=True, text=True, cwd=graph_partition_dir)
|
|
||||||
|
|
||||||
# Check if executables were created
|
|
||||||
partitioner_path = self._get_executable_path("partitioner")
|
|
||||||
relayout_path = self._get_executable_path("index_relayout")
|
|
||||||
|
|
||||||
print(f"✅ Built partitioner: {partitioner_path}")
|
|
||||||
print(f"✅ Built index_relayout: {relayout_path}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(f"Failed to build executables: {e}")
|
|
||||||
finally:
|
|
||||||
os.chdir(original_dir)
|
|
||||||
|
|
||||||
def partition_graph(
|
|
||||||
self,
|
|
||||||
index_prefix_path: str,
|
|
||||||
output_dir: Optional[str] = None,
|
|
||||||
partition_prefix: Optional[str] = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> tuple[str, str]:
|
|
||||||
"""
|
|
||||||
Partition a disk-based index for improved performance.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
index_prefix_path: Path to the index prefix (e.g., "/path/to/index")
|
|
||||||
output_dir: Output directory for results (defaults to parent of index_prefix_path)
|
|
||||||
partition_prefix: Prefix for output files (defaults to basename of index_prefix_path)
|
|
||||||
**kwargs: Additional parameters for graph partitioning:
|
|
||||||
- gp_times: Number of LDG partition iterations (default: 10)
|
|
||||||
- lock_nums: Number of lock nodes (default: 10)
|
|
||||||
- cut: Cut adjacency list degree (default: 100)
|
|
||||||
- scale_factor: Scale factor (default: 1)
|
|
||||||
- data_type: Data type (default: "float")
|
|
||||||
- thread_nums: Number of threads (default: 10)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (disk_graph_index_path, partition_bin_path)
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: If the partitioning process fails
|
|
||||||
"""
|
|
||||||
# Set default parameters
|
|
||||||
params = {
|
|
||||||
"gp_times": 10,
|
|
||||||
"lock_nums": 10,
|
|
||||||
"cut": 100,
|
|
||||||
"scale_factor": 1,
|
|
||||||
"data_type": "float",
|
|
||||||
"thread_nums": 10,
|
|
||||||
**kwargs,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Determine output directory
|
|
||||||
if output_dir is None:
|
|
||||||
output_dir = str(Path(index_prefix_path).parent)
|
|
||||||
|
|
||||||
# Create output directory if it doesn't exist
|
|
||||||
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# Determine partition prefix
|
|
||||||
if partition_prefix is None:
|
|
||||||
partition_prefix = Path(index_prefix_path).name
|
|
||||||
|
|
||||||
# Get executable paths
|
|
||||||
partitioner_path = self._get_executable_path("partitioner")
|
|
||||||
relayout_path = self._get_executable_path("index_relayout")
|
|
||||||
|
|
||||||
# Create temporary directory for processing
|
|
||||||
with tempfile.TemporaryDirectory() as temp_dir:
|
|
||||||
# Change to the graph_partition directory for temporary files
|
|
||||||
graph_partition_dir = (
|
|
||||||
Path(__file__).parent.parent / "third_party" / "DiskANN" / "graph_partition"
|
|
||||||
)
|
|
||||||
original_dir = os.getcwd()
|
|
||||||
|
|
||||||
try:
|
|
||||||
os.chdir(graph_partition_dir)
|
|
||||||
|
|
||||||
# Create temporary data directory
|
|
||||||
temp_data_dir = Path(temp_dir) / "data"
|
|
||||||
temp_data_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# Set up paths for temporary files
|
|
||||||
graph_path = temp_data_dir / "starling" / "_M_R_L_B" / "GRAPH"
|
|
||||||
graph_gp_path = (
|
|
||||||
graph_path
|
|
||||||
/ f"GP_TIMES_{params['gp_times']}_LOCK_{params['lock_nums']}_GP_USE_FREQ0_CUT{params['cut']}_SCALE{params['scale_factor']}"
|
|
||||||
)
|
|
||||||
graph_gp_path.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# Find input index file
|
|
||||||
old_index_file = f"{index_prefix_path}_disk_beam_search.index"
|
|
||||||
if not os.path.exists(old_index_file):
|
|
||||||
old_index_file = f"{index_prefix_path}_disk.index"
|
|
||||||
|
|
||||||
if not os.path.exists(old_index_file):
|
|
||||||
raise RuntimeError(f"Index file not found: {old_index_file}")
|
|
||||||
|
|
||||||
# Run partitioner
|
|
||||||
gp_file_path = graph_gp_path / "_part.bin"
|
|
||||||
partitioner_cmd = [
|
|
||||||
partitioner_path,
|
|
||||||
"--index_file",
|
|
||||||
old_index_file,
|
|
||||||
"--data_type",
|
|
||||||
params["data_type"],
|
|
||||||
"--gp_file",
|
|
||||||
str(gp_file_path),
|
|
||||||
"-T",
|
|
||||||
str(params["thread_nums"]),
|
|
||||||
"--ldg_times",
|
|
||||||
str(params["gp_times"]),
|
|
||||||
"--scale",
|
|
||||||
str(params["scale_factor"]),
|
|
||||||
"--mode",
|
|
||||||
"1",
|
|
||||||
]
|
|
||||||
|
|
||||||
print(f"Running partitioner: {' '.join(partitioner_cmd)}")
|
|
||||||
result = subprocess.run(
|
|
||||||
partitioner_cmd, capture_output=True, text=True, cwd=graph_partition_dir
|
|
||||||
)
|
|
||||||
|
|
||||||
if result.returncode != 0:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Partitioner failed with return code {result.returncode}.\n"
|
|
||||||
f"stdout: {result.stdout}\n"
|
|
||||||
f"stderr: {result.stderr}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Run relayout
|
|
||||||
part_tmp_index = graph_gp_path / "_part_tmp.index"
|
|
||||||
relayout_cmd = [
|
|
||||||
relayout_path,
|
|
||||||
old_index_file,
|
|
||||||
str(gp_file_path),
|
|
||||||
params["data_type"],
|
|
||||||
"1",
|
|
||||||
]
|
|
||||||
|
|
||||||
print(f"Running relayout: {' '.join(relayout_cmd)}")
|
|
||||||
result = subprocess.run(
|
|
||||||
relayout_cmd, capture_output=True, text=True, cwd=graph_partition_dir
|
|
||||||
)
|
|
||||||
|
|
||||||
if result.returncode != 0:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Relayout failed with return code {result.returncode}.\n"
|
|
||||||
f"stdout: {result.stdout}\n"
|
|
||||||
f"stderr: {result.stderr}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Copy results to output directory
|
|
||||||
disk_graph_path = Path(output_dir) / f"{partition_prefix}_disk_graph.index"
|
|
||||||
partition_bin_path = Path(output_dir) / f"{partition_prefix}_partition.bin"
|
|
||||||
|
|
||||||
shutil.copy2(part_tmp_index, disk_graph_path)
|
|
||||||
shutil.copy2(gp_file_path, partition_bin_path)
|
|
||||||
|
|
||||||
print(f"Results copied to: {output_dir}")
|
|
||||||
return str(disk_graph_path), str(partition_bin_path)
|
|
||||||
|
|
||||||
finally:
|
|
||||||
os.chdir(original_dir)
|
|
||||||
|
|
||||||
def get_partition_info(self, partition_bin_path: str) -> dict:
|
|
||||||
"""
|
|
||||||
Get information about a partition file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
partition_bin_path: Path to the partition binary file
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary containing partition information
|
|
||||||
"""
|
|
||||||
if not os.path.exists(partition_bin_path):
|
|
||||||
raise FileNotFoundError(f"Partition file not found: {partition_bin_path}")
|
|
||||||
|
|
||||||
# For now, return basic file information
|
|
||||||
# In the future, this could parse the binary file for detailed info
|
|
||||||
stat = os.stat(partition_bin_path)
|
|
||||||
return {
|
|
||||||
"file_size": stat.st_size,
|
|
||||||
"file_path": partition_bin_path,
|
|
||||||
"modified_time": stat.st_mtime,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def partition_graph(
|
|
||||||
index_prefix_path: str,
|
|
||||||
output_dir: Optional[str] = None,
|
|
||||||
partition_prefix: Optional[str] = None,
|
|
||||||
build_type: str = "release",
|
|
||||||
**kwargs,
|
|
||||||
) -> tuple[str, str]:
|
|
||||||
"""
|
|
||||||
Convenience function to partition a graph index.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
index_prefix_path: Path to the index prefix
|
|
||||||
output_dir: Output directory (defaults to parent of index_prefix_path)
|
|
||||||
partition_prefix: Prefix for output files (defaults to basename of index_prefix_path)
|
|
||||||
build_type: Build type for executables ("debug" or "release")
|
|
||||||
**kwargs: Additional parameters for graph partitioning
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (disk_graph_index_path, partition_bin_path)
|
|
||||||
"""
|
|
||||||
partitioner = GraphPartitioner(build_type=build_type)
|
|
||||||
return partitioner.partition_graph(index_prefix_path, output_dir, partition_prefix, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
# Example usage:
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# Example: partition an index
|
|
||||||
try:
|
|
||||||
disk_graph_path, partition_bin_path = partition_graph(
|
|
||||||
"/path/to/your/index_prefix", gp_times=10, lock_nums=10, cut=100
|
|
||||||
)
|
|
||||||
print("Partitioning completed successfully!")
|
|
||||||
print(f"Disk graph index: {disk_graph_path}")
|
|
||||||
print(f"Partition binary: {partition_bin_path}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Partitioning failed: {e}")
|
|
||||||
@@ -1,137 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Simplified Graph Partition Module for LEANN DiskANN Backend
|
|
||||||
|
|
||||||
This module provides a simple Python interface for graph partitioning
|
|
||||||
that directly calls the existing executables.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import subprocess
|
|
||||||
import tempfile
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
|
|
||||||
def partition_graph_simple(
|
|
||||||
index_prefix_path: str, output_dir: Optional[str] = None, **kwargs
|
|
||||||
) -> tuple[str, str]:
|
|
||||||
"""
|
|
||||||
Simple function to partition a graph index.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
index_prefix_path: Path to the index prefix (e.g., "/path/to/index")
|
|
||||||
output_dir: Output directory (defaults to parent of index_prefix_path)
|
|
||||||
**kwargs: Additional parameters for graph partitioning
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (disk_graph_index_path, partition_bin_path)
|
|
||||||
"""
|
|
||||||
# Set default parameters
|
|
||||||
params = {
|
|
||||||
"gp_times": 10,
|
|
||||||
"lock_nums": 10,
|
|
||||||
"cut": 100,
|
|
||||||
"scale_factor": 1,
|
|
||||||
"data_type": "float",
|
|
||||||
"thread_nums": 10,
|
|
||||||
**kwargs,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Determine output directory
|
|
||||||
if output_dir is None:
|
|
||||||
output_dir = str(Path(index_prefix_path).parent)
|
|
||||||
|
|
||||||
# Find the graph_partition directory
|
|
||||||
current_file = Path(__file__)
|
|
||||||
graph_partition_dir = current_file.parent.parent / "third_party" / "DiskANN" / "graph_partition"
|
|
||||||
|
|
||||||
if not graph_partition_dir.exists():
|
|
||||||
raise RuntimeError(f"Graph partition directory not found: {graph_partition_dir}")
|
|
||||||
|
|
||||||
# Find input index file
|
|
||||||
old_index_file = f"{index_prefix_path}_disk_beam_search.index"
|
|
||||||
if not os.path.exists(old_index_file):
|
|
||||||
old_index_file = f"{index_prefix_path}_disk.index"
|
|
||||||
|
|
||||||
if not os.path.exists(old_index_file):
|
|
||||||
raise RuntimeError(f"Index file not found: {old_index_file}")
|
|
||||||
|
|
||||||
# Create temporary directory for processing
|
|
||||||
with tempfile.TemporaryDirectory() as temp_dir:
|
|
||||||
temp_data_dir = Path(temp_dir) / "data"
|
|
||||||
temp_data_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# Set up paths for temporary files
|
|
||||||
graph_path = temp_data_dir / "starling" / "_M_R_L_B" / "GRAPH"
|
|
||||||
graph_gp_path = (
|
|
||||||
graph_path
|
|
||||||
/ f"GP_TIMES_{params['gp_times']}_LOCK_{params['lock_nums']}_GP_USE_FREQ0_CUT{params['cut']}_SCALE{params['scale_factor']}"
|
|
||||||
)
|
|
||||||
graph_gp_path.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# Run the build script with our parameters
|
|
||||||
cmd = [str(graph_partition_dir / "build.sh"), "release", "split_graph", index_prefix_path]
|
|
||||||
|
|
||||||
# Set environment variables for parameters
|
|
||||||
env = os.environ.copy()
|
|
||||||
env.update(
|
|
||||||
{
|
|
||||||
"GP_TIMES": str(params["gp_times"]),
|
|
||||||
"GP_LOCK_NUMS": str(params["lock_nums"]),
|
|
||||||
"GP_CUT": str(params["cut"]),
|
|
||||||
"GP_SCALE_F": str(params["scale_factor"]),
|
|
||||||
"DATA_TYPE": params["data_type"],
|
|
||||||
"GP_T": str(params["thread_nums"]),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Running graph partition with command: {' '.join(cmd)}")
|
|
||||||
print(f"Working directory: {graph_partition_dir}")
|
|
||||||
|
|
||||||
# Run the command
|
|
||||||
result = subprocess.run(
|
|
||||||
cmd, env=env, capture_output=True, text=True, cwd=graph_partition_dir
|
|
||||||
)
|
|
||||||
|
|
||||||
if result.returncode != 0:
|
|
||||||
print(f"Command failed with return code {result.returncode}")
|
|
||||||
print(f"stdout: {result.stdout}")
|
|
||||||
print(f"stderr: {result.stderr}")
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Graph partitioning failed with return code {result.returncode}.\n"
|
|
||||||
f"stdout: {result.stdout}\n"
|
|
||||||
f"stderr: {result.stderr}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check if output files were created
|
|
||||||
disk_graph_path = Path(output_dir) / "_disk_graph.index"
|
|
||||||
partition_bin_path = Path(output_dir) / "_partition.bin"
|
|
||||||
|
|
||||||
if not disk_graph_path.exists():
|
|
||||||
raise RuntimeError(f"Expected output file not found: {disk_graph_path}")
|
|
||||||
|
|
||||||
if not partition_bin_path.exists():
|
|
||||||
raise RuntimeError(f"Expected output file not found: {partition_bin_path}")
|
|
||||||
|
|
||||||
print("✅ Partitioning completed successfully!")
|
|
||||||
print(f" Disk graph index: {disk_graph_path}")
|
|
||||||
print(f" Partition binary: {partition_bin_path}")
|
|
||||||
|
|
||||||
return str(disk_graph_path), str(partition_bin_path)
|
|
||||||
|
|
||||||
|
|
||||||
# Example usage
|
|
||||||
if __name__ == "__main__":
|
|
||||||
try:
|
|
||||||
disk_graph_path, partition_bin_path = partition_graph_simple(
|
|
||||||
"/Users/yichuan/Desktop/release2/leann/diskannbuild/test_doc_files",
|
|
||||||
gp_times=5,
|
|
||||||
lock_nums=5,
|
|
||||||
cut=50,
|
|
||||||
)
|
|
||||||
print("Success! Output files:")
|
|
||||||
print(f" - {disk_graph_path}")
|
|
||||||
print(f" - {partition_bin_path}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error: {e}")
|
|
||||||
@@ -4,8 +4,8 @@ build-backend = "scikit_build_core.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann-backend-diskann"
|
name = "leann-backend-diskann"
|
||||||
version = "0.2.5"
|
version = "0.1.16"
|
||||||
dependencies = ["leann-core==0.2.5", "numpy", "protobuf>=3.19.0"]
|
dependencies = ["leann-core==0.1.16", "numpy", "protobuf>=3.19.0"]
|
||||||
|
|
||||||
[tool.scikit-build]
|
[tool.scikit-build]
|
||||||
# Key: simplified CMake path
|
# Key: simplified CMake path
|
||||||
|
|||||||
Submodule packages/leann-backend-diskann/third_party/DiskANN updated: b2dc4ea2c7...af2a26481e
@@ -2,7 +2,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Literal, Optional
|
from typing import Any, Literal
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from leann.interface import (
|
from leann.interface import (
|
||||||
@@ -152,7 +152,7 @@ class HNSWSearcher(BaseSearcher):
|
|||||||
self,
|
self,
|
||||||
query: np.ndarray,
|
query: np.ndarray,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
zmq_port: Optional[int] = None,
|
zmq_port: int | None = None,
|
||||||
complexity: int = 64,
|
complexity: int = 64,
|
||||||
beam_width: int = 1,
|
beam_width: int = 1,
|
||||||
prune_ratio: float = 0.0,
|
prune_ratio: float = 0.0,
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import sys
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import msgpack
|
import msgpack
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -34,7 +33,7 @@ if not logger.handlers:
|
|||||||
|
|
||||||
|
|
||||||
def create_hnsw_embedding_server(
|
def create_hnsw_embedding_server(
|
||||||
passages_file: Optional[str] = None,
|
passages_file: str | None = None,
|
||||||
zmq_port: int = 5555,
|
zmq_port: int = 5555,
|
||||||
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
||||||
distance_metric: str = "mips",
|
distance_metric: str = "mips",
|
||||||
@@ -82,8 +81,19 @@ def create_hnsw_embedding_server(
|
|||||||
with open(passages_file) as f:
|
with open(passages_file) as f:
|
||||||
meta = json.load(f)
|
meta = json.load(f)
|
||||||
|
|
||||||
# Let PassageManager handle path resolution uniformly
|
# Convert relative paths to absolute paths based on metadata file location
|
||||||
passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file)
|
metadata_dir = Path(passages_file).parent.parent # Go up one level from the metadata file
|
||||||
|
passage_sources = []
|
||||||
|
for source in meta["passage_sources"]:
|
||||||
|
source_copy = source.copy()
|
||||||
|
# Convert relative paths to absolute paths
|
||||||
|
if not Path(source_copy["path"]).is_absolute():
|
||||||
|
source_copy["path"] = str(metadata_dir / source_copy["path"])
|
||||||
|
if not Path(source_copy["index_path"]).is_absolute():
|
||||||
|
source_copy["index_path"] = str(metadata_dir / source_copy["index_path"])
|
||||||
|
passage_sources.append(source_copy)
|
||||||
|
|
||||||
|
passages = PassageManager(passage_sources)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
||||||
)
|
)
|
||||||
@@ -285,7 +295,7 @@ if __name__ == "__main__":
|
|||||||
"--embedding-mode",
|
"--embedding-mode",
|
||||||
type=str,
|
type=str,
|
||||||
default="sentence-transformers",
|
default="sentence-transformers",
|
||||||
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
choices=["sentence-transformers", "openai", "mlx"],
|
||||||
help="Embedding backend mode",
|
help="Embedding backend mode",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -6,10 +6,10 @@ build-backend = "scikit_build_core.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann-backend-hnsw"
|
name = "leann-backend-hnsw"
|
||||||
version = "0.2.5"
|
version = "0.1.16"
|
||||||
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
|
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"leann-core==0.2.5",
|
"leann-core==0.1.16",
|
||||||
"numpy",
|
"numpy",
|
||||||
"pyzmq>=23.0.0",
|
"pyzmq>=23.0.0",
|
||||||
"msgpack>=1.0.0",
|
"msgpack>=1.0.0",
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann-core"
|
name = "leann-core"
|
||||||
version = "0.2.5"
|
version = "0.1.16"
|
||||||
description = "Core API and plugin system for LEANN"
|
description = "Core API and plugin system for LEANN"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.9"
|
||||||
@@ -44,7 +44,6 @@ colab = [
|
|||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
leann = "leann.cli:main"
|
leann = "leann.cli:main"
|
||||||
leann_mcp = "leann.mcp:main"
|
|
||||||
|
|
||||||
[tool.setuptools.packages.find]
|
[tool.setuptools.packages.find]
|
||||||
where = ["src"]
|
where = ["src"]
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import time
|
|||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Literal, Optional
|
from typing import Any, Literal
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -33,7 +33,7 @@ def compute_embeddings(
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
mode: str = "sentence-transformers",
|
mode: str = "sentence-transformers",
|
||||||
use_server: bool = True,
|
use_server: bool = True,
|
||||||
port: Optional[int] = None,
|
port: int | None = None,
|
||||||
is_build=False,
|
is_build=False,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
@@ -87,26 +87,21 @@ def compute_embeddings_via_server(chunks: list[str], model_name: str, port: int)
|
|||||||
# Connect to embedding server
|
# Connect to embedding server
|
||||||
context = zmq.Context()
|
context = zmq.Context()
|
||||||
socket = context.socket(zmq.REQ)
|
socket = context.socket(zmq.REQ)
|
||||||
socket.setsockopt(zmq.LINGER, 0) # Don't block on close
|
|
||||||
socket.setsockopt(zmq.RCVTIMEO, 300000)
|
|
||||||
socket.setsockopt(zmq.SNDTIMEO, 300000)
|
|
||||||
socket.setsockopt(zmq.IMMEDIATE, 1)
|
|
||||||
socket.connect(f"tcp://localhost:{port}")
|
socket.connect(f"tcp://localhost:{port}")
|
||||||
|
|
||||||
try:
|
# Send chunks to server for embedding computation
|
||||||
# Send chunks to server for embedding computation
|
request = chunks
|
||||||
request = chunks
|
socket.send(msgpack.packb(request))
|
||||||
socket.send(msgpack.packb(request))
|
|
||||||
|
|
||||||
# Receive embeddings from server
|
# Receive embeddings from server
|
||||||
response = socket.recv()
|
response = socket.recv()
|
||||||
embeddings_list = msgpack.unpackb(response)
|
embeddings_list = msgpack.unpackb(response)
|
||||||
|
|
||||||
# Convert back to numpy array
|
# Convert back to numpy array
|
||||||
embeddings = np.array(embeddings_list, dtype=np.float32)
|
embeddings = np.array(embeddings_list, dtype=np.float32)
|
||||||
finally:
|
|
||||||
socket.close()
|
socket.close()
|
||||||
# Don't call context.term() - this was causing hangs
|
context.term()
|
||||||
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
@@ -120,9 +115,7 @@ class SearchResult:
|
|||||||
|
|
||||||
|
|
||||||
class PassageManager:
|
class PassageManager:
|
||||||
def __init__(
|
def __init__(self, passage_sources: list[dict[str, Any]]):
|
||||||
self, passage_sources: list[dict[str, Any]], metadata_file_path: Optional[str] = None
|
|
||||||
):
|
|
||||||
self.offset_maps = {}
|
self.offset_maps = {}
|
||||||
self.passage_files = {}
|
self.passage_files = {}
|
||||||
self.global_offset_map = {} # Combined map for fast lookup
|
self.global_offset_map = {} # Combined map for fast lookup
|
||||||
@@ -132,26 +125,10 @@ class PassageManager:
|
|||||||
passage_file = source["path"]
|
passage_file = source["path"]
|
||||||
index_file = source["index_path"] # .idx file
|
index_file = source["index_path"] # .idx file
|
||||||
|
|
||||||
# Fix path resolution - relative paths should be relative to metadata file directory
|
# Fix path resolution for Colab and other environments
|
||||||
if not Path(index_file).is_absolute():
|
if not Path(index_file).is_absolute():
|
||||||
if metadata_file_path:
|
# If relative path, try to resolve it properly
|
||||||
# Resolve relative to metadata file directory
|
index_file = str(Path(index_file).resolve())
|
||||||
metadata_dir = Path(metadata_file_path).parent
|
|
||||||
logger.debug(
|
|
||||||
f"PassageManager: Resolving relative paths from metadata_dir: {metadata_dir}"
|
|
||||||
)
|
|
||||||
index_file = str((metadata_dir / index_file).resolve())
|
|
||||||
passage_file = str((metadata_dir / passage_file).resolve())
|
|
||||||
logger.debug(f"PassageManager: Resolved index_file: {index_file}")
|
|
||||||
else:
|
|
||||||
# Fallback to current directory resolution (legacy behavior)
|
|
||||||
logger.warning(
|
|
||||||
"PassageManager: No metadata_file_path provided, using fallback resolution from cwd"
|
|
||||||
)
|
|
||||||
logger.debug(f"PassageManager: Current working directory: {Path.cwd()}")
|
|
||||||
index_file = str(Path(index_file).resolve())
|
|
||||||
passage_file = str(Path(passage_file).resolve())
|
|
||||||
logger.debug(f"PassageManager: Fallback resolved index_file: {index_file}")
|
|
||||||
|
|
||||||
if not Path(index_file).exists():
|
if not Path(index_file).exists():
|
||||||
raise FileNotFoundError(f"Passage index file not found: {index_file}")
|
raise FileNotFoundError(f"Passage index file not found: {index_file}")
|
||||||
@@ -180,12 +157,12 @@ class LeannBuilder:
|
|||||||
self,
|
self,
|
||||||
backend_name: str,
|
backend_name: str,
|
||||||
embedding_model: str = "facebook/contriever",
|
embedding_model: str = "facebook/contriever",
|
||||||
dimensions: Optional[int] = None,
|
dimensions: int | None = None,
|
||||||
embedding_mode: str = "sentence-transformers",
|
embedding_mode: str = "sentence-transformers",
|
||||||
**backend_kwargs,
|
**backend_kwargs,
|
||||||
):
|
):
|
||||||
self.backend_name = backend_name
|
self.backend_name = backend_name
|
||||||
backend_factory: Optional[LeannBackendFactoryInterface] = BACKEND_REGISTRY.get(backend_name)
|
backend_factory: LeannBackendFactoryInterface | None = BACKEND_REGISTRY.get(backend_name)
|
||||||
if backend_factory is None:
|
if backend_factory is None:
|
||||||
raise ValueError(f"Backend '{backend_name}' not found or not registered.")
|
raise ValueError(f"Backend '{backend_name}' not found or not registered.")
|
||||||
self.backend_factory = backend_factory
|
self.backend_factory = backend_factory
|
||||||
@@ -265,7 +242,7 @@ class LeannBuilder:
|
|||||||
self.backend_kwargs = backend_kwargs
|
self.backend_kwargs = backend_kwargs
|
||||||
self.chunks: list[dict[str, Any]] = []
|
self.chunks: list[dict[str, Any]] = []
|
||||||
|
|
||||||
def add_text(self, text: str, metadata: Optional[dict[str, Any]] = None):
|
def add_text(self, text: str, metadata: dict[str, Any] | None = None):
|
||||||
if metadata is None:
|
if metadata is None:
|
||||||
metadata = {}
|
metadata = {}
|
||||||
passage_id = metadata.get("id", str(len(self.chunks)))
|
passage_id = metadata.get("id", str(len(self.chunks)))
|
||||||
@@ -337,8 +314,8 @@ class LeannBuilder:
|
|||||||
"passage_sources": [
|
"passage_sources": [
|
||||||
{
|
{
|
||||||
"type": "jsonl",
|
"type": "jsonl",
|
||||||
"path": passages_file.name, # Use relative path (just filename)
|
"path": str(passages_file),
|
||||||
"index_path": offset_file.name, # Use relative path (just filename)
|
"index_path": str(offset_file),
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
@@ -453,8 +430,8 @@ class LeannBuilder:
|
|||||||
"passage_sources": [
|
"passage_sources": [
|
||||||
{
|
{
|
||||||
"type": "jsonl",
|
"type": "jsonl",
|
||||||
"path": passages_file.name, # Use relative path (just filename)
|
"path": str(passages_file),
|
||||||
"index_path": offset_file.name, # Use relative path (just filename)
|
"index_path": str(offset_file),
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"built_from_precomputed_embeddings": True,
|
"built_from_precomputed_embeddings": True,
|
||||||
@@ -496,9 +473,7 @@ class LeannSearcher:
|
|||||||
self.embedding_model = self.meta_data["embedding_model"]
|
self.embedding_model = self.meta_data["embedding_model"]
|
||||||
# Support both old and new format
|
# Support both old and new format
|
||||||
self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
|
self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
|
||||||
self.passage_manager = PassageManager(
|
self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
|
||||||
self.meta_data.get("passage_sources", []), metadata_file_path=self.meta_path_str
|
|
||||||
)
|
|
||||||
backend_factory = BACKEND_REGISTRY.get(backend_name)
|
backend_factory = BACKEND_REGISTRY.get(backend_name)
|
||||||
if backend_factory is None:
|
if backend_factory is None:
|
||||||
raise ValueError(f"Backend '{backend_name}' not found.")
|
raise ValueError(f"Backend '{backend_name}' not found.")
|
||||||
@@ -571,6 +546,7 @@ class LeannSearcher:
|
|||||||
zmq_port=zmq_port,
|
zmq_port=zmq_port,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
time.time() - start_time
|
||||||
# logger.info(f" Search time: {search_time} seconds")
|
# logger.info(f" Search time: {search_time} seconds")
|
||||||
logger.info(f" Backend returned: labels={len(results.get('labels', [[]])[0])} results")
|
logger.info(f" Backend returned: labels={len(results.get('labels', [[]])[0])} results")
|
||||||
|
|
||||||
@@ -578,7 +554,7 @@ class LeannSearcher:
|
|||||||
if "labels" in results and "distances" in results:
|
if "labels" in results and "distances" in results:
|
||||||
logger.info(f" Processing {len(results['labels'][0])} passage IDs:")
|
logger.info(f" Processing {len(results['labels'][0])} passage IDs:")
|
||||||
for i, (string_id, dist) in enumerate(
|
for i, (string_id, dist) in enumerate(
|
||||||
zip(results["labels"][0], results["distances"][0])
|
zip(results["labels"][0], results["distances"][0], strict=False)
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
passage_data = self.passage_manager.get_passage(string_id)
|
passage_data = self.passage_manager.get_passage(string_id)
|
||||||
@@ -611,17 +587,12 @@ class LeannSearcher:
|
|||||||
logger.info(f" {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}")
|
logger.info(f" {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}")
|
||||||
return enriched_results
|
return enriched_results
|
||||||
|
|
||||||
def cleanup(self):
|
|
||||||
"""Cleanup embedding server and other resources."""
|
|
||||||
if hasattr(self.backend_impl, "cleanup"):
|
|
||||||
self.backend_impl.cleanup()
|
|
||||||
|
|
||||||
|
|
||||||
class LeannChat:
|
class LeannChat:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
index_path: str,
|
index_path: str,
|
||||||
llm_config: Optional[dict[str, Any]] = None,
|
llm_config: dict[str, Any] | None = None,
|
||||||
enable_warmup: bool = False,
|
enable_warmup: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
@@ -637,7 +608,7 @@ class LeannChat:
|
|||||||
prune_ratio: float = 0.0,
|
prune_ratio: float = 0.0,
|
||||||
recompute_embeddings: bool = True,
|
recompute_embeddings: bool = True,
|
||||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
llm_kwargs: Optional[dict[str, Any]] = None,
|
llm_kwargs: dict[str, Any] | None = None,
|
||||||
expected_zmq_port: int = 5557,
|
expected_zmq_port: int = 5557,
|
||||||
**search_kwargs,
|
**search_kwargs,
|
||||||
):
|
):
|
||||||
@@ -665,10 +636,7 @@ class LeannChat:
|
|||||||
"Please provide the best answer you can based on this context and your knowledge."
|
"Please provide the best answer you can based on this context and your knowledge."
|
||||||
)
|
)
|
||||||
|
|
||||||
ask_time = time.time()
|
|
||||||
ans = self.llm.ask(prompt, **llm_kwargs)
|
ans = self.llm.ask(prompt, **llm_kwargs)
|
||||||
ask_time = time.time() - ask_time
|
|
||||||
logger.info(f" Ask time: {ask_time} seconds")
|
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
def start_interactive(self):
|
def start_interactive(self):
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import difflib
 import logging
 import os
 from abc import ABC, abstractmethod
-from typing import Any, Optional
+from typing import Any
 
 import torch
 
@@ -17,12 +17,12 @@ logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 
-def check_ollama_models(host: str) -> list[str]:
+def check_ollama_models() -> list[str]:
     """Check available Ollama models and return a list"""
     try:
         import requests
 
-        response = requests.get(f"{host}/api/tags", timeout=5)
+        response = requests.get("http://localhost:11434/api/tags", timeout=5)
         if response.status_code == 200:
             data = response.json()
             return [model["name"] for model in data.get("models", [])]
@@ -309,12 +309,10 @@ def search_hf_models(query: str, limit: int = 10) -> list[str]:
     return search_hf_models_fuzzy(query, limit)
 
 
-def validate_model_and_suggest(
-    model_name: str, llm_type: str, host: str = "http://localhost:11434"
-) -> Optional[str]:
+def validate_model_and_suggest(model_name: str, llm_type: str) -> str | None:
     """Validate model name and provide suggestions if invalid"""
     if llm_type == "ollama":
-        available_models = check_ollama_models(host)
+        available_models = check_ollama_models()
         if available_models and model_name not in available_models:
             error_msg = f"Model '{model_name}' not found in your local Ollama installation."
 
@@ -360,11 +358,7 @@ def validate_model_and_suggest(
         error_msg += f"\n\nModel '{model_name}' was not found in Ollama's library."
 
         if suggestions:
-            error_msg += (
-                "\n\nDid you mean one of these installed models?\n"
-                + "\nTry to use ollama pull to install the model you need\n"
-            )
+            error_msg += "\n\nDid you mean one of these installed models?\n"
             for i, suggestion in enumerate(suggestions, 1):
                 error_msg += f"  {i}. {suggestion}\n"
         else:
@@ -471,7 +465,7 @@ class OllamaChat(LLMInterface):
             requests.get(host)
 
             # Pre-check model availability with helpful suggestions
-            model_error = validate_model_and_suggest(model, "ollama", host)
+            model_error = validate_model_and_suggest(model, "ollama")
             if model_error:
                 raise ValueError(model_error)
 
@@ -491,35 +485,11 @@ class OllamaChat(LLMInterface):
         import requests
 
         full_url = f"{self.host}/api/generate"
 
-        # Handle thinking budget for reasoning models
-        options = kwargs.copy()
-        thinking_budget = kwargs.get("thinking_budget")
-        if thinking_budget:
-            # Remove thinking_budget from options as it's not a standard Ollama option
-            options.pop("thinking_budget", None)
-            # Only apply reasoning parameters to models that support it
-            reasoning_supported_models = [
-                "gpt-oss:20b",
-                "gpt-oss:120b",
-                "deepseek-r1",
-                "deepseek-coder",
-            ]
-
-            if thinking_budget in ["low", "medium", "high"]:
-                if any(model in self.model.lower() for model in reasoning_supported_models):
-                    options["reasoning"] = {"effort": thinking_budget, "exclude": False}
-                    logger.info(f"Applied reasoning effort={thinking_budget} to model {self.model}")
-                else:
-                    logger.warning(
-                        f"Thinking budget '{thinking_budget}' requested but model '{self.model}' may not support reasoning parameters. Proceeding without reasoning."
-                    )
-
         payload = {
             "model": self.model,
             "prompt": prompt,
             "stream": False,  # Keep it simple for now
-            "options": options,
+            "options": kwargs,
         }
         logger.debug(f"Sending request to Ollama: {payload}")
         try:
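The non-streaming `/api/generate` call that this payload feeds can be exercised on its own. A minimal standalone sketch, assuming Ollama is serving on the default local port and that a model named `llama3` has been pulled (both assumptions, not part of the patch):

```python
import requests

# Minimal sketch of the non-streaming Ollama generate call.
payload = {
    "model": "llama3",                       # assumed: any locally pulled model
    "prompt": "Summarize LEANN in one sentence.",
    "stream": False,                         # one JSON response instead of a token stream
    "options": {"temperature": 0.7},
}
resp = requests.post("http://localhost:11434/api/generate", json=payload, timeout=60)
resp.raise_for_status()
print(resp.json()["response"])               # the generated text
```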
@@ -572,41 +542,14 @@ class HFChat(LLMInterface):
             self.device = "cpu"
             logger.info("No GPU detected. Using CPU.")
 
-        # Load tokenizer and model with timeout protection
-        try:
-            import signal
-
-            def timeout_handler(signum, frame):
-                raise TimeoutError("Model download/loading timed out")
-
-            # Set timeout for model loading (60 seconds)
-            old_handler = signal.signal(signal.SIGALRM, timeout_handler)
-            signal.alarm(60)
-
-            try:
-                logger.info(f"Loading tokenizer for {model_name}...")
-                self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-
-                logger.info(f"Loading model {model_name}...")
-                self.model = AutoModelForCausalLM.from_pretrained(
-                    model_name,
-                    torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
-                    device_map="auto" if self.device != "cpu" else None,
-                    trust_remote_code=True,
-                )
-                logger.info(f"Successfully loaded {model_name}")
-            finally:
-                signal.alarm(0)  # Cancel the alarm
-                signal.signal(signal.SIGALRM, old_handler)  # Restore old handler
-
-        except TimeoutError:
-            logger.error(f"Model loading timed out for {model_name}")
-            raise RuntimeError(
-                f"Model loading timed out for {model_name}. Please check your internet connection or try a smaller model."
-            )
-        except Exception as e:
-            logger.error(f"Failed to load model {model_name}: {e}")
-            raise
+        # Load tokenizer and model
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
+            device_map="auto" if self.device != "cpu" else None,
+            trust_remote_code=True,
+        )
 
         # Move model to device if not using device_map
         if self.device != "cpu" and "device_map" not in str(self.model):
@@ -685,7 +628,7 @@ class HFChat(LLMInterface):
 class OpenAIChat(LLMInterface):
     """LLM interface for OpenAI models."""
 
-    def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None):
+    def __init__(self, model: str = "gpt-4o", api_key: str | None = None):
         self.model = model
         self.api_key = api_key or os.getenv("OPENAI_API_KEY")
 
@@ -710,38 +653,11 @@ class OpenAIChat(LLMInterface):
         params = {
             "model": self.model,
             "messages": [{"role": "user", "content": prompt}],
+            "max_tokens": kwargs.get("max_tokens", 1000),
             "temperature": kwargs.get("temperature", 0.7),
+            **{k: v for k, v in kwargs.items() if k not in ["max_tokens", "temperature"]},
         }
 
-        # Handle max_tokens vs max_completion_tokens based on model
-        max_tokens = kwargs.get("max_tokens", 1000)
-        if "o3" in self.model or "o4" in self.model or "o1" in self.model:
-            # o-series models use max_completion_tokens
-            params["max_completion_tokens"] = max_tokens
-            params["temperature"] = 1.0
-        else:
-            # Other models use max_tokens
-            params["max_tokens"] = max_tokens
-
-        # Handle thinking budget for reasoning models
-        thinking_budget = kwargs.get("thinking_budget")
-        if thinking_budget and thinking_budget in ["low", "medium", "high"]:
-            # Check if this is an o-series model (partial match for model names)
-            o_series_models = ["o3", "o3-mini", "o4-mini", "o1", "o3-pro", "o3-deep-research"]
-            if any(model in self.model for model in o_series_models):
-                # Use the correct OpenAI reasoning parameter format
-                params["reasoning_effort"] = thinking_budget
-                logger.info(f"Applied reasoning_effort={thinking_budget} to model {self.model}")
-            else:
-                logger.warning(
-                    f"Thinking budget '{thinking_budget}' requested but model '{self.model}' may not support reasoning parameters. Proceeding without reasoning."
-                )
-
-        # Add other kwargs (excluding thinking_budget as it's handled above)
-        for k, v in kwargs.items():
-            if k not in ["max_tokens", "temperature", "thinking_budget"]:
-                params[k] = v
-
         logger.info(f"Sending request to OpenAI with model {self.model}")
 
         try:
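The new `params` construction folds the remaining keyword arguments in with dict unpacking instead of a separate loop. A small standalone sketch of that merge behavior (not from the patch; the function name is illustrative):

```python
def build_params(model: str, prompt: str, **kwargs) -> dict:
    """Explicit keys first, then every other keyword argument merged in via ** unpacking."""
    return {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": kwargs.get("max_tokens", 1000),
        "temperature": kwargs.get("temperature", 0.7),
        **{k: v for k, v in kwargs.items() if k not in ["max_tokens", "temperature"]},
    }


print(build_params("gpt-4o", "hi", temperature=0.2, top_p=0.9))
# temperature comes from kwargs (0.2); top_p passes straight through unchanged
```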
@@ -761,7 +677,7 @@ class SimulatedChat(LLMInterface):
     return "This is a simulated answer from the LLM based on the retrieved context."
 
 
-def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
+def get_llm(llm_config: dict[str, Any] | None = None) -> LLMInterface:
     """
     Factory function to get an LLM interface based on configuration.
 
@@ -1,7 +1,6 @@
 import argparse
 import asyncio
 from pathlib import Path
-from typing import Optional
 
 from llama_index.core import SimpleDirectoryReader
 from llama_index.core.node_parser import SentenceSplitter
@@ -42,23 +41,13 @@ def extract_pdf_text_with_pdfplumber(file_path: str) -> str:
 
 class LeannCLI:
     def __init__(self):
-        # Always use project-local .leann directory (like .git)
-        self.indexes_dir = Path.cwd() / ".leann" / "indexes"
+        self.indexes_dir = Path.home() / ".leann" / "indexes"
         self.indexes_dir.mkdir(parents=True, exist_ok=True)
 
-        # Default parser for documents
         self.node_parser = SentenceSplitter(
             chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
         )
 
-        # Code-optimized parser
-        self.code_parser = SentenceSplitter(
-            chunk_size=512,  # Larger chunks for code context
-            chunk_overlap=50,  # Less overlap to preserve function boundaries
-            separator="\n",  # Split by lines for code
-            paragraph_separator="\n\n",  # Preserve logical code blocks
-        )
-
     def get_index_path(self, index_name: str) -> str:
         index_dir = self.indexes_dir / index_name
         return str(index_dir / "documents.leann")
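The splitter kept above drives all chunking once the code-specific parser is gone. A short sketch of that chunking in isolation, using the llama_index API already imported in this file (the sample text is a placeholder):

```python
from llama_index.core import Document
from llama_index.core.node_parser import SentenceSplitter

# Same configuration as the retained node_parser: 256-token chunks, 128-token overlap.
parser = SentenceSplitter(
    chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
)
nodes = parser.get_nodes_from_documents([Document(text="Some long document text ...")])
texts = [node.get_content() for node in nodes]
print(len(texts))  # number of chunks produced
```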
@@ -75,11 +64,10 @@ class LeannCLI:
             formatter_class=argparse.RawDescriptionHelpFormatter,
             epilog="""
 Examples:
   leann build my-docs --docs ./documents                  # Build index named my-docs
-  leann build my-ppts --docs ./ --file-types .pptx,.pdf   # Index only PowerPoint and PDF files
   leann search my-docs "query"                            # Search in my-docs index
   leann ask my-docs "question"                            # Ask my-docs index
   leann list                                              # List all stored indexes
 """,
         )
 
@@ -88,31 +76,17 @@ Examples:
         # Build command
         build_parser = subparsers.add_parser("build", help="Build document index")
         build_parser.add_argument("index_name", help="Index name")
-        build_parser.add_argument(
-            "--docs", type=str, default=".", help="Documents directory (default: current directory)"
-        )
+        build_parser.add_argument("--docs", type=str, required=True, help="Documents directory")
         build_parser.add_argument(
             "--backend", type=str, default="hnsw", choices=["hnsw", "diskann"]
         )
         build_parser.add_argument("--embedding-model", type=str, default="facebook/contriever")
-        build_parser.add_argument(
-            "--embedding-mode",
-            type=str,
-            default="sentence-transformers",
-            choices=["sentence-transformers", "openai", "mlx", "ollama"],
-            help="Embedding backend mode (default: sentence-transformers)",
-        )
         build_parser.add_argument("--force", "-f", action="store_true", help="Force rebuild")
         build_parser.add_argument("--graph-degree", type=int, default=32)
         build_parser.add_argument("--complexity", type=int, default=64)
         build_parser.add_argument("--num-threads", type=int, default=1)
         build_parser.add_argument("--compact", action="store_true", default=True)
         build_parser.add_argument("--recompute", action="store_true", default=True)
-        build_parser.add_argument(
-            "--file-types",
-            type=str,
-            help="Comma-separated list of file extensions to include (e.g., '.txt,.pdf,.pptx'). If not specified, uses default supported types.",
-        )
 
         # Search command
         search_parser = subparsers.add_parser("search", help="Search documents")
@@ -122,12 +96,7 @@ Examples:
         search_parser.add_argument("--complexity", type=int, default=64)
         search_parser.add_argument("--beam-width", type=int, default=1)
         search_parser.add_argument("--prune-ratio", type=float, default=0.0)
-        search_parser.add_argument(
-            "--recompute-embeddings",
-            action="store_true",
-            default=True,
-            help="Recompute embeddings (default: True)",
-        )
+        search_parser.add_argument("--recompute-embeddings", action="store_true")
         search_parser.add_argument(
             "--pruning-strategy",
            choices=["global", "local", "proportional"],
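Worth noting: dropping `default=True` changes behavior, because `action="store_true"` defaults to `False` when the flag is absent. A tiny standalone sketch (not part of the patch) illustrating the difference:

```python
import argparse

parser = argparse.ArgumentParser()
# Old form: effectively always on, even without the flag.
parser.add_argument("--old-recompute", action="store_true", default=True)
# New form: off unless the user passes the flag explicitly.
parser.add_argument("--new-recompute", action="store_true")

args = parser.parse_args([])
print(args.old_recompute, args.new_recompute)  # True False
```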
@@ -150,138 +119,52 @@ Examples:
         ask_parser.add_argument("--complexity", type=int, default=32)
         ask_parser.add_argument("--beam-width", type=int, default=1)
         ask_parser.add_argument("--prune-ratio", type=float, default=0.0)
-        ask_parser.add_argument(
-            "--recompute-embeddings",
-            action="store_true",
-            default=True,
-            help="Recompute embeddings (default: True)",
-        )
+        ask_parser.add_argument("--recompute-embeddings", action="store_true")
         ask_parser.add_argument(
             "--pruning-strategy",
             choices=["global", "local", "proportional"],
             default="global",
         )
-        ask_parser.add_argument(
-            "--thinking-budget",
-            type=str,
-            choices=["low", "medium", "high"],
-            default=None,
-            help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
-        )
 
         # List command
         subparsers.add_parser("list", help="List all indexes")
 
         return parser
 
-    def register_project_dir(self):
-        """Register current project directory in global registry"""
-        global_registry = Path.home() / ".leann" / "projects.json"
-        global_registry.parent.mkdir(exist_ok=True)
-
-        current_dir = str(Path.cwd())
-
-        # Load existing registry
-        projects = []
-        if global_registry.exists():
-            try:
-                import json
-
-                with open(global_registry) as f:
-                    projects = json.load(f)
-            except Exception:
-                projects = []
-
-        # Add current directory if not already present
-        if current_dir not in projects:
-            projects.append(current_dir)
-
-        # Save registry
-        import json
-
-        with open(global_registry, "w") as f:
-            json.dump(projects, f, indent=2)
-
     def list_indexes(self):
         print("Stored LEANN indexes:")
 
-        # Get all project directories with .leann
-        global_registry = Path.home() / ".leann" / "projects.json"
-        all_projects = []
-
-        if global_registry.exists():
-            try:
-                import json
-
-                with open(global_registry) as f:
-                    all_projects = json.load(f)
-            except Exception:
-                pass
-
-        # Filter to only existing directories with .leann
-        valid_projects = []
-        for project_dir in all_projects:
-            project_path = Path(project_dir)
-            if project_path.exists() and (project_path / ".leann" / "indexes").exists():
-                valid_projects.append(project_path)
-
-        # Add current project if it has .leann but not in registry
-        current_path = Path.cwd()
-        if (current_path / ".leann" / "indexes").exists() and current_path not in valid_projects:
-            valid_projects.append(current_path)
-
-        if not valid_projects:
+        if not self.indexes_dir.exists():
             print("No indexes found. Use 'leann build <name> --docs <dir>' to create one.")
             return
 
-        total_indexes = 0
-        current_dir = Path.cwd()
-
-        for project_path in valid_projects:
-            indexes_dir = project_path / ".leann" / "indexes"
-            if not indexes_dir.exists():
-                continue
-
-            index_dirs = [d for d in indexes_dir.iterdir() if d.is_dir()]
-            if not index_dirs:
-                continue
-
-            # Show project header
-            if project_path == current_dir:
-                print(f"\n📁 Current project ({project_path}):")
-            else:
-                print(f"\n📂 {project_path}:")
-
-            for index_dir in index_dirs:
-                total_indexes += 1
-                index_name = index_dir.name
-                meta_file = index_dir / "documents.leann.meta.json"
-                status = "✓" if meta_file.exists() else "✗"
-
-                print(f"  {total_indexes}. {index_name} [{status}]")
-                if status == "✓":
-                    size_mb = sum(f.stat().st_size for f in index_dir.iterdir() if f.is_file()) / (
-                        1024 * 1024
-                    )
-                    print(f"     Size: {size_mb:.1f} MB")
-
-        if total_indexes > 0:
-            print(f"\nTotal: {total_indexes} indexes across {len(valid_projects)} projects")
-            print("\nUsage (current project only):")
-
-            # Show example from current project
-            current_indexes_dir = current_dir / ".leann" / "indexes"
-            if current_indexes_dir.exists():
-                current_index_dirs = [d for d in current_indexes_dir.iterdir() if d.is_dir()]
-                if current_index_dirs:
-                    example_name = current_index_dirs[0].name
-                    print(f'  leann search {example_name} "your query"')
-                    print(f"  leann ask {example_name} --interactive")
-
-    def load_documents(self, docs_dir: str, custom_file_types: Optional[str] = None):
+        index_dirs = [d for d in self.indexes_dir.iterdir() if d.is_dir()]
+
+        if not index_dirs:
+            print("No indexes found. Use 'leann build <name> --docs <dir>' to create one.")
+            return
+
+        print(f"Found {len(index_dirs)} indexes:")
+        for i, index_dir in enumerate(index_dirs, 1):
+            index_name = index_dir.name
+            status = "✓" if self.index_exists(index_name) else "✗"
+
+            print(f"  {i}. {index_name} [{status}]")
+            if self.index_exists(index_name):
+                index_dir / "documents.leann.meta.json"
+                size_mb = sum(f.stat().st_size for f in index_dir.iterdir() if f.is_file()) / (
+                    1024 * 1024
+                )
+                print(f"     Size: {size_mb:.1f} MB")
+
+        if index_dirs:
+            example_name = index_dirs[0].name
+            print("\nUsage:")
+            print(f'  leann search {example_name} "your query"')
+            print(f"  leann ask {example_name} --interactive")
+
+    def load_documents(self, docs_dir: str):
         print(f"Loading documents from {docs_dir}...")
-        if custom_file_types:
-            print(f"Using custom file types: {custom_file_types}")
 
         # Try to use better PDF parsers first
         documents = []
@@ -313,140 +196,17 @@ Examples:
             documents.extend(default_docs)
 
         # Load other file types with default reader
-        if custom_file_types:
-            # Parse custom file types from comma-separated string
-            code_extensions = [ext.strip() for ext in custom_file_types.split(",") if ext.strip()]
-            # Ensure extensions start with a dot
-            code_extensions = [ext if ext.startswith(".") else f".{ext}" for ext in code_extensions]
-        else:
-            # Use default supported file types
-            code_extensions = [
-                # Original document types
-                ".txt", ".md", ".docx", ".pptx",
-                # Code files for Claude Code integration
-                ".py", ".js", ".ts", ".jsx", ".tsx", ".java", ".cpp", ".c", ".h", ".hpp",
-                ".cs", ".go", ".rs", ".rb", ".php", ".swift", ".kt", ".scala", ".r", ".sql",
-                ".sh", ".bash", ".zsh", ".fish", ".ps1", ".bat",
-                # Config and markup files
-                ".json", ".yaml", ".yml", ".xml", ".toml", ".ini", ".cfg", ".conf",
-                ".html", ".css", ".scss", ".less", ".vue", ".svelte",
-                # Data science
-                ".ipynb", ".R", ".py", ".jl",
-            ]
-        # Try to load other file types, but don't fail if none are found
-        try:
-            other_docs = SimpleDirectoryReader(
-                docs_dir,
-                recursive=True,
-                encoding="utf-8",
-                required_exts=code_extensions,
-            ).load_data(show_progress=True)
-            documents.extend(other_docs)
-        except ValueError as e:
-            if "No files found" in str(e):
-                print("No additional files found for other supported types.")
-            else:
-                raise e
+        other_docs = SimpleDirectoryReader(
+            docs_dir,
+            recursive=True,
+            encoding="utf-8",
+            required_exts=[".txt", ".md", ".docx"],
+        ).load_data(show_progress=True)
+        documents.extend(other_docs)
 
         all_texts = []
 
-        # Define code file extensions for intelligent chunking
-        code_file_exts = {
-            ".py", ".js", ".ts", ".jsx", ".tsx", ".java", ".cpp", ".c", ".h", ".hpp",
-            ".cs", ".go", ".rs", ".rb", ".php", ".swift", ".kt", ".scala", ".r", ".sql",
-            ".sh", ".bash", ".zsh", ".fish", ".ps1", ".bat",
-            ".json", ".yaml", ".yml", ".xml", ".toml", ".ini", ".cfg", ".conf",
-            ".html", ".css", ".scss", ".less", ".vue", ".svelte",
-            ".ipynb", ".R", ".jl",
-        }
-
         for doc in documents:
-            # Check if this is a code file based on source path
-            source_path = doc.metadata.get("source", "")
-            is_code_file = any(source_path.endswith(ext) for ext in code_file_exts)
-
-            # Use appropriate parser based on file type
-            parser = self.code_parser if is_code_file else self.node_parser
-            nodes = parser.get_nodes_from_documents([doc])
-
+            nodes = self.node_parser.get_nodes_from_documents([doc])
             for node in nodes:
                 all_texts.append(node.get_content())
 
@@ -459,13 +219,11 @@ Examples:
         index_dir = self.indexes_dir / index_name
         index_path = self.get_index_path(index_name)
 
-        print(f"📂 Indexing: {Path(docs_dir).resolve()}")
-
         if index_dir.exists() and not args.force:
             print(f"Index '{index_name}' already exists. Use --force to rebuild.")
             return
 
-        all_texts = self.load_documents(docs_dir, args.file_types)
+        all_texts = self.load_documents(docs_dir)
         if not all_texts:
             print("No documents found")
             return
@@ -477,7 +235,6 @@ Examples:
         builder = LeannBuilder(
             backend_name=args.backend,
             embedding_model=args.embedding_model,
-            embedding_mode=args.embedding_mode,
             graph_degree=args.graph_degree,
             complexity=args.complexity,
             is_compact=args.compact,
@@ -491,9 +248,6 @@ Examples:
         builder.build_index(index_path)
         print(f"Index built at {index_path}")
 
-        # Register this project directory in global registry
-        self.register_project_dir()
-
     async def search_documents(self, args):
         index_name = args.index_name
         query = args.query
@@ -554,11 +308,6 @@ Examples:
                 if not user_input:
                     continue
 
-                # Prepare LLM kwargs with thinking budget if specified
-                llm_kwargs = {}
-                if args.thinking_budget:
-                    llm_kwargs["thinking_budget"] = args.thinking_budget
-
                 response = chat.ask(
                     user_input,
                     top_k=args.top_k,
@@ -567,17 +316,11 @@ Examples:
                     prune_ratio=args.prune_ratio,
                     recompute_embeddings=args.recompute_embeddings,
                     pruning_strategy=args.pruning_strategy,
-                    llm_kwargs=llm_kwargs,
                 )
                 print(f"LEANN: {response}")
         else:
             query = input("Enter your question: ").strip()
             if query:
-                # Prepare LLM kwargs with thinking budget if specified
-                llm_kwargs = {}
-                if args.thinking_budget:
-                    llm_kwargs["thinking_budget"] = args.thinking_budget
-
                 response = chat.ask(
                     query,
                     top_k=args.top_k,
@@ -586,7 +329,6 @@ Examples:
                     prune_ratio=args.prune_ratio,
                     recompute_embeddings=args.recompute_embeddings,
                     pruning_strategy=args.pruning_strategy,
-                    llm_kwargs=llm_kwargs,
                 )
                 print(f"LEANN: {response}")
@@ -6,7 +6,6 @@ Preserves all optimization parameters to ensure performance
 
 import logging
 import os
-from concurrent.futures import ThreadPoolExecutor, as_completed
 from typing import Any
 
 import numpy as np
@@ -36,7 +35,7 @@ def compute_embeddings(
     Args:
         texts: List of texts to compute embeddings for
         model_name: Model name
-        mode: Computation mode ('sentence-transformers', 'openai', 'mlx', 'ollama')
+        mode: Computation mode ('sentence-transformers', 'openai', 'mlx')
         is_build: Whether this is a build operation (shows progress bar)
         batch_size: Batch size for processing
         adaptive_optimization: Whether to use adaptive optimization based on batch size
@@ -56,8 +55,6 @@ def compute_embeddings(
         return compute_embeddings_openai(texts, model_name)
     elif mode == "mlx":
         return compute_embeddings_mlx(texts, model_name)
-    elif mode == "ollama":
-        return compute_embeddings_ollama(texts, model_name, is_build=is_build)
    else:
         raise ValueError(f"Unsupported embedding mode: {mode}")
 
@@ -368,262 +365,3 @@ def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int =
 
     # Stack numpy arrays
     return np.stack(all_embeddings)
-
-
-def compute_embeddings_ollama(
-    texts: list[str], model_name: str, is_build: bool = False, host: str = "http://localhost:11434"
-) -> np.ndarray:
-    """
-    Compute embeddings using Ollama API.
-
-    Args:
-        texts: List of texts to compute embeddings for
-        model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large")
-        is_build: Whether this is a build operation (shows progress bar)
-        host: Ollama host URL (default: http://localhost:11434)
-
-    Returns:
-        Normalized embeddings array, shape: (len(texts), embedding_dim)
-    """
-    try:
-        import requests
-    except ImportError:
-        raise ImportError(
-            "The 'requests' library is required for Ollama embeddings. Install with: uv pip install requests"
-        )
-
-    if not texts:
-        raise ValueError("Cannot compute embeddings for empty text list")
-
-    logger.info(
-        f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}'"
-    )
-
-    # Check if Ollama is running
-    try:
-        response = requests.get(f"{host}/api/version", timeout=5)
-        response.raise_for_status()
-    except requests.exceptions.ConnectionError:
-        error_msg = (
-            f"❌ Could not connect to Ollama at {host}.\n\n"
-            "Please ensure Ollama is running:\n"
-            "  • macOS/Linux: ollama serve\n"
-            "  • Windows: Make sure Ollama is running in the system tray\n\n"
-            "Installation: https://ollama.com/download"
-        )
-        raise RuntimeError(error_msg)
-    except Exception as e:
-        raise RuntimeError(f"Unexpected error connecting to Ollama: {e}")
-
-    # Check if model exists and provide helpful suggestions
-    try:
-        response = requests.get(f"{host}/api/tags", timeout=5)
-        response.raise_for_status()
-        models = response.json()
-        model_names = [model["name"] for model in models.get("models", [])]
-
-        # Filter for embedding models (models that support embeddings)
-        embedding_models = []
-        suggested_embedding_models = [
-            "nomic-embed-text",
-            "mxbai-embed-large",
-            "bge-m3",
-            "all-minilm",
-            "snowflake-arctic-embed",
-        ]
-
-        for model in model_names:
-            # Check if it's an embedding model (by name patterns or known models)
-            base_name = model.split(":")[0]
-            if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5"]):
-                embedding_models.append(model)
-
-        # Check if model exists (handle versioned names)
-        model_found = any(
-            model_name == name.split(":")[0] or model_name == name for name in model_names
-        )
-
-        if not model_found:
-            error_msg = f"❌ Model '{model_name}' not found in local Ollama.\n\n"
-
-            # Suggest pulling the model
-            error_msg += "📦 To install this embedding model:\n"
-            error_msg += f"   ollama pull {model_name}\n\n"
-
-            # Show available embedding models
-            if embedding_models:
-                error_msg += "✅ Available embedding models:\n"
-                for model in embedding_models[:5]:
-                    error_msg += f"   • {model}\n"
-                if len(embedding_models) > 5:
-                    error_msg += f"   ... and {len(embedding_models) - 5} more\n"
-            else:
-                error_msg += "💡 Popular embedding models to install:\n"
-                for model in suggested_embedding_models[:3]:
-                    error_msg += f"   • ollama pull {model}\n"
-
-            error_msg += "\n📚 Browse more: https://ollama.com/library"
-            raise ValueError(error_msg)
-
-        # Verify the model supports embeddings by testing it
-        try:
-            test_response = requests.post(
-                f"{host}/api/embeddings", json={"model": model_name, "prompt": "test"}, timeout=10
-            )
-            if test_response.status_code != 200:
-                error_msg = (
-                    f"⚠️ Model '{model_name}' exists but may not support embeddings.\n\n"
-                    f"Please use an embedding model like:\n"
-                )
-                for model in suggested_embedding_models[:3]:
-                    error_msg += f"   • {model}\n"
-                raise ValueError(error_msg)
-        except requests.exceptions.RequestException:
-            # If test fails, continue anyway - model might still work
-            pass
-
-    except requests.exceptions.RequestException as e:
-        logger.warning(f"Could not verify model existence: {e}")
-
-    # Process embeddings with optimized concurrent processing
-    import requests
-
-    def get_single_embedding(text_idx_tuple):
-        """Helper function to get embedding for a single text."""
-        text, idx = text_idx_tuple
-        max_retries = 3
-        retry_count = 0
-
-        # Truncate very long texts to avoid API issues
-        truncated_text = text[:8000] if len(text) > 8000 else text
-
-        while retry_count < max_retries:
-            try:
-                response = requests.post(
-                    f"{host}/api/embeddings",
-                    json={"model": model_name, "prompt": truncated_text},
-                    timeout=30,
-                )
-                response.raise_for_status()
-
-                result = response.json()
-                embedding = result.get("embedding")
-
-                if embedding is None:
-                    raise ValueError(f"No embedding returned for text {idx}")
-
-                return idx, embedding
-
-            except requests.exceptions.Timeout:
-                retry_count += 1
-                if retry_count >= max_retries:
-                    logger.warning(f"Timeout for text {idx} after {max_retries} retries")
-                    return idx, None
-
-            except Exception as e:
-                if retry_count >= max_retries - 1:
-                    logger.error(f"Failed to get embedding for text {idx}: {e}")
-                    return idx, None
-                retry_count += 1
-
-        return idx, None
-
-    # Determine if we should use concurrent processing
-    use_concurrent = (
-        len(texts) > 5 and not is_build
-    )  # Don't use concurrent in build mode to avoid overwhelming
-    max_workers = min(4, len(texts))  # Limit concurrent requests to avoid overwhelming Ollama
-
-    all_embeddings = [None] * len(texts)  # Pre-allocate list to maintain order
-    failed_indices = []
-
-    if use_concurrent:
-        logger.info(
-            f"Using concurrent processing with {max_workers} workers for {len(texts)} texts"
-        )
-
-        with ThreadPoolExecutor(max_workers=max_workers) as executor:
-            # Submit all tasks
-            future_to_idx = {
-                executor.submit(get_single_embedding, (text, idx)): idx
-                for idx, text in enumerate(texts)
-            }
-
-            # Add progress bar for concurrent processing
-            try:
-                if is_build or len(texts) > 10:
-                    from tqdm import tqdm
-
-                    futures_iterator = tqdm(
-                        as_completed(future_to_idx),
-                        total=len(texts),
-                        desc="Computing Ollama embeddings",
-                    )
-                else:
-                    futures_iterator = as_completed(future_to_idx)
-            except ImportError:
-                futures_iterator = as_completed(future_to_idx)
-
-            # Collect results as they complete
-            for future in futures_iterator:
-                try:
-                    idx, embedding = future.result()
-                    if embedding is not None:
-                        all_embeddings[idx] = embedding
-                    else:
-                        failed_indices.append(idx)
-                except Exception as e:
-                    idx = future_to_idx[future]
-                    logger.error(f"Exception for text {idx}: {e}")
-                    failed_indices.append(idx)
-
-    else:
-        # Sequential processing with progress bar
-        show_progress = is_build or len(texts) > 10
-
-        try:
-            if show_progress:
-                from tqdm import tqdm
-
-                iterator = tqdm(
-                    enumerate(texts), total=len(texts), desc="Computing Ollama embeddings"
-                )
-            else:
-                iterator = enumerate(texts)
-        except ImportError:
-            iterator = enumerate(texts)
-
-        for idx, text in iterator:
-            result_idx, embedding = get_single_embedding((text, idx))
-            if embedding is not None:
-                all_embeddings[idx] = embedding
-            else:
-                failed_indices.append(idx)
-
-    # Handle failed embeddings
-    if failed_indices:
-        if len(failed_indices) == len(texts):
-            raise RuntimeError("Failed to compute any embeddings")
-
-        logger.warning(f"Failed to compute embeddings for {len(failed_indices)}/{len(texts)} texts")
-
-        # Use zero embeddings as fallback for failed ones
-        valid_embedding = next((e for e in all_embeddings if e is not None), None)
-        if valid_embedding:
-            embedding_dim = len(valid_embedding)
-            for idx in failed_indices:
-                all_embeddings[idx] = [0.0] * embedding_dim
-
-    # Remove None values and convert to numpy array
-    all_embeddings = [e for e in all_embeddings if e is not None]
-
-    # Convert to numpy array and normalize
-    embeddings = np.array(all_embeddings, dtype=np.float32)
-
-    # Normalize embeddings (L2 normalization)
-    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
-    embeddings = embeddings / (norms + 1e-8)  # Add small epsilon to avoid division by zero
-
-    logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
-
-    return embeddings
@@ -1,13 +1,11 @@
 import atexit
 import logging
 import os
-import signal
 import socket
 import subprocess
 import sys
 import time
 from pathlib import Path
-from typing import Optional
 
 import psutil
 
@@ -184,8 +182,8 @@ class EmbeddingServerManager:
             e.g., "leann_backend_diskann.embedding_server"
         """
         self.backend_module_name = backend_module_name
-        self.server_process: Optional[subprocess.Popen] = None
-        self.server_port: Optional[int] = None
+        self.server_process: subprocess.Popen | None = None
+        self.server_port: int | None = None
         self._atexit_registered = False
 
     def start_server(
@@ -312,7 +310,6 @@ class EmbeddingServerManager:
                 cwd=project_root,
                 stdout=None,  # Direct to console
                 stderr=None,  # Direct to console
-                start_new_session=True,  # Create new process group for better cleanup
             )
             self.server_port = port
             logger.info(f"Server process started with PID: {self.server_process.pid}")
@@ -354,46 +351,20 @@ class EmbeddingServerManager:
             logger.info(
                 f"Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
             )
-
-            # Try terminating the whole process group first
-            try:
-                pgid = os.getpgid(self.server_process.pid)
-                os.killpg(pgid, signal.SIGTERM)
-            except Exception:
-                # Fallback to terminating just the process
-                self.server_process.terminate()
-
+            self.server_process.terminate()
             try:
-                self.server_process.wait(timeout=3)
+                self.server_process.wait(timeout=5)
                 logger.info(f"Server process {self.server_process.pid} terminated.")
             except subprocess.TimeoutExpired:
                 logger.warning(
-                    f"Server process {self.server_process.pid} did not terminate gracefully within 3 seconds, killing it."
+                    f"Server process {self.server_process.pid} did not terminate gracefully, killing it."
                 )
-                # Try killing the whole process group
-                try:
-                    pgid = os.getpgid(self.server_process.pid)
-                    os.killpg(pgid, signal.SIGKILL)
-                except Exception:
-                    # Fallback to killing just the process
-                    self.server_process.kill()
-                try:
-                    self.server_process.wait(timeout=2)
-                    logger.info(f"Server process {self.server_process.pid} killed successfully.")
-                except subprocess.TimeoutExpired:
-                    logger.error(
-                        f"Failed to kill server process {self.server_process.pid} - it may be hung"
-                    )
-                    # Don't hang indefinitely
+                self.server_process.kill()
 
             # Clean up process resources to prevent resource tracker warnings
             try:
-                self.server_process.wait(timeout=1)  # Give it one final chance with timeout
-            except subprocess.TimeoutExpired:
-                logger.warning(
-                    f"Process {self.server_process.pid} still hanging after all kill attempts"
-                )
-                # Don't wait indefinitely - just abandon it
+                self.server_process.wait()  # Ensure process is fully cleaned up
             except Exception:
                 pass
 
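The simplified shutdown above is the standard terminate-then-kill escalation for `subprocess.Popen`. A self-contained sketch of the same pattern (not from the patch; the `sleep 60` child is just a stand-in for the embedding server):

```python
import subprocess


def stop_process(proc: subprocess.Popen, grace_seconds: float = 5.0) -> None:
    """Ask the child to exit, then force-kill it if it ignores the request."""
    proc.terminate()                      # SIGTERM (TerminateProcess on Windows)
    try:
        proc.wait(timeout=grace_seconds)  # grace period for a clean shutdown
    except subprocess.TimeoutExpired:
        proc.kill()                       # escalate to SIGKILL
        proc.wait()                       # reap the process so no zombie is left


proc = subprocess.Popen(["sleep", "60"])
stop_process(proc)
```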
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Literal, Optional
+from typing import Any, Literal
 
 import numpy as np
 
@@ -34,9 +34,7 @@ class LeannBackendSearcherInterface(ABC):
         pass
 
     @abstractmethod
-    def _ensure_server_running(
-        self, passages_source_file: str, port: Optional[int], **kwargs
-    ) -> int:
+    def _ensure_server_running(self, passages_source_file: str, port: int | None, **kwargs) -> int:
         """Ensure server is running"""
         pass
 
@@ -50,7 +48,7 @@ class LeannBackendSearcherInterface(ABC):
         prune_ratio: float = 0.0,
         recompute_embeddings: bool = False,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        zmq_port: Optional[int] = None,
+        zmq_port: int | None = None,
         **kwargs,
     ) -> dict[str, Any]:
         """Search for nearest neighbors
@@ -76,7 +74,7 @@ class LeannBackendSearcherInterface(ABC):
         self,
         query: str,
         use_server_if_available: bool = True,
-        zmq_port: Optional[int] = None,
+        zmq_port: int | None = None,
     ) -> np.ndarray:
         """Compute embedding for a query string
 
@@ -1,125 +0,0 @@
-#!/usr/bin/env python3
-
-import json
-import subprocess
-import sys
-
-
-def handle_request(request):
-    if request.get("method") == "initialize":
-        return {
-            "jsonrpc": "2.0",
-            "id": request.get("id"),
-            "result": {
-                "capabilities": {"tools": {}},
-                "protocolVersion": "2024-11-05",
-                "serverInfo": {"name": "leann-mcp", "version": "1.0.0"},
-            },
-        }
-
-    elif request.get("method") == "tools/list":
-        return {
-            "jsonrpc": "2.0",
-            "id": request.get("id"),
-            "result": {
-                "tools": [
-                    {
-                        "name": "leann_search",
-                        "description": "Search LEANN index",
-                        "inputSchema": {
-                            "type": "object",
-                            "properties": {
-                                "index_name": {"type": "string"},
-                                "query": {"type": "string"},
-                                "top_k": {"type": "integer", "default": 5},
-                            },
-                            "required": ["index_name", "query"],
-                        },
-                    },
-                    {
-                        "name": "leann_ask",
-                        "description": "Ask question using LEANN RAG",
-                        "inputSchema": {
-                            "type": "object",
-                            "properties": {
-                                "index_name": {"type": "string"},
-                                "question": {"type": "string"},
-                            },
-                            "required": ["index_name", "question"],
-                        },
-                    },
-                    {
-                        "name": "leann_list",
-                        "description": "List all LEANN indexes",
-                        "inputSchema": {"type": "object", "properties": {}},
-                    },
-                ]
-            },
-        }
-
-    elif request.get("method") == "tools/call":
-        tool_name = request["params"]["name"]
-        args = request["params"].get("arguments", {})
-
-        try:
-            if tool_name == "leann_search":
-                cmd = [
-                    "leann",
-                    "search",
-                    args["index_name"],
-                    args["query"],
-                    "--recompute-embeddings",
-                    f"--top-k={args.get('top_k', 5)}",
-                ]
-                result = subprocess.run(cmd, capture_output=True, text=True)
-
-            elif tool_name == "leann_ask":
-                cmd = f'echo "{args["question"]}" | leann ask {args["index_name"]} --recompute-embeddings --llm ollama --model qwen3:8b'
-                result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
-
-            elif tool_name == "leann_list":
-                result = subprocess.run(["leann", "list"], capture_output=True, text=True)
-
-            return {
-                "jsonrpc": "2.0",
-                "id": request.get("id"),
-                "result": {
-                    "content": [
-                        {
-                            "type": "text",
-                            "text": result.stdout
-                            if result.returncode == 0
-                            else f"Error: {result.stderr}",
-                        }
-                    ]
-                },
-            }
-
-        except Exception as e:
-            return {
-                "jsonrpc": "2.0",
-                "id": request.get("id"),
-                "error": {"code": -1, "message": str(e)},
-            }
-
-
-def main():
-    for line in sys.stdin:
-        try:
-            request = json.loads(line.strip())
-            response = handle_request(request)
-            if response:
-                print(json.dumps(response))
-                sys.stdout.flush()
-        except Exception as e:
-            error_response = {
-                "jsonrpc": "2.0",
-                "id": None,
-                "error": {"code": -1, "message": str(e)},
-            }
-            print(json.dumps(error_response))
-            sys.stdout.flush()
-
-
-if __name__ == "__main__":
-    main()
@@ -1,7 +1,7 @@
 import json
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any, Literal, Optional
+from typing import Any, Literal
 
 import numpy as np
 
@@ -132,15 +132,10 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
         import msgpack
         import zmq
 
-        context = None
-        socket = None
         try:
             context = zmq.Context()
             socket = context.socket(zmq.REQ)
-            socket.setsockopt(zmq.LINGER, 0)  # Don't block on close
-            socket.setsockopt(zmq.RCVTIMEO, 300000)
-            socket.setsockopt(zmq.SNDTIMEO, 300000)
-            socket.setsockopt(zmq.IMMEDIATE, 1)
+            socket.setsockopt(zmq.RCVTIMEO, 30000)  # 30 second timeout
             socket.connect(f"tcp://localhost:{zmq_port}")
 
             # Send embedding request
@@ -152,6 +147,9 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
             response_bytes = socket.recv()
             response = msgpack.unpackb(response_bytes)
 
+            socket.close()
+            context.term()
+
             # Convert response to numpy array
             if isinstance(response, list) and len(response) > 0:
                 return np.array(response, dtype=np.float32)
@@ -160,10 +158,6 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
 
         except Exception as e:
             raise RuntimeError(f"Failed to compute embeddings via server: {e}")
-        finally:
-            if socket:
-                socket.close()
-                # Don't call context.term() - this was causing hangs
 
     @abstractmethod
     def search(
@@ -175,7 +169,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
         prune_ratio: float = 0.0,
         recompute_embeddings: bool = False,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        zmq_port: Optional[int] = None,
+        zmq_port: int | None = None,
         **kwargs,
     ) -> dict[str, Any]:
         """
@@ -197,15 +191,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
         """
         pass
 
-    def cleanup(self):
-        """Cleanup resources including embedding server."""
+    def __del__(self):
+        """Ensures the embedding server is stopped when the searcher is destroyed."""
         if hasattr(self, "embedding_server_manager"):
             self.embedding_server_manager.stop_server()
-
-    def __del__(self):
-        """Ensures resources are cleaned up when the searcher is destroyed."""
-        try:
-            self.cleanup()
-        except Exception:
-            # Ignore errors during destruction
-            pass
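The embedding request above is a plain ZeroMQ REQ/REP round trip with msgpack framing. A minimal client-side sketch under stated assumptions: a server is already replying on port 5557, and the request/response payloads are msgpack-encoded lists (the exact message schema is an assumption, not taken from the diff):

```python
import msgpack
import numpy as np
import zmq

# Minimal REQ-side sketch of the request/reply exchange used above.
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.setsockopt(zmq.RCVTIMEO, 30000)     # fail after 30 s instead of blocking forever
socket.connect("tcp://localhost:5557")      # assumed: embedding server on this port

socket.send(msgpack.packb(["example query text"]))  # assumed request format
reply = msgpack.unpackb(socket.recv())               # assumed: nested list of floats

embedding = np.array(reply, dtype=np.float32)
socket.close()
context.term()
print(embedding.shape)
```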
@@ -1,91 +0,0 @@
-# 🔥 LEANN Claude Code Integration
-
-Transform your development workflow with intelligent code assistance using LEANN's semantic search directly in Claude Code.
-
-## Prerequisites
-
-**Step 1:** First, complete the basic LEANN installation following the [📦 Installation guide](../../README.md#installation) in the root README:
-
-```bash
-uv venv
-source .venv/bin/activate
-uv pip install leann
-```
-
-**Step 2:** Install LEANN globally for MCP integration:
-```bash
-uv tool install leann-core
-```
-
-This makes the `leann` command available system-wide, which `leann_mcp` requires.
-
-## 🚀 Quick Setup
-
-Add the LEANN MCP server to Claude Code:
-
-```bash
-claude mcp add leann-server -- leann_mcp
-```
-
-## 🛠️ Available Tools
-
-Once connected, you'll have access to these powerful semantic search tools in Claude Code:
-
-- **`leann_list`** - List all available indexes across your projects
-- **`leann_search`** - Perform semantic searches across code and documents
-- **`leann_ask`** - Ask natural language questions and get AI-powered answers from your codebase
-
-## 🎯 Quick Start Example
-
-```bash
-# Build an index for your project (change to your actual path)
-leann build my-project --docs ./
-
-# Start Claude Code
-claude
-```
-
-**Try this in Claude Code:**
-```
-Help me understand this codebase. List available indexes and search for authentication patterns.
-```
-
-<p align="center">
-  <img src="../../assets/claude_code_leann.png" alt="LEANN in Claude Code" width="80%">
-</p>
-
-## 🧠 How It Works
-
-The integration consists of three key components working seamlessly together:
-
-- **`leann`** - Core CLI tool for indexing and searching (installed globally via `uv tool install`)
-- **`leann_mcp`** - MCP server that wraps `leann` commands for Claude Code integration
-- **Claude Code** - Calls `leann_mcp`, which executes `leann` commands and returns intelligent results
-
-## 📁 File Support
-
-LEANN understands **30+ file types** including:
-- **Programming**: Python, JavaScript, TypeScript, Java, Go, Rust, C++, C#
-- **Data**: SQL, YAML, JSON, CSV, XML
-- **Documentation**: Markdown, TXT, PDF
-- **And many more!**
-
-## 💾 Storage & Organization
-
-- **Project indexes**: Stored in `.leann/` directory (just like `.git`)
-- **Global registry**: Project tracking at `~/.leann/projects.json`
-- **Multi-project support**: Switch between different codebases seamlessly
-- **Portable**: Transfer indexes between machines with minimal overhead
-
-## 🗑️ Uninstalling
-
-To remove the LEANN MCP server from Claude Code:
-
-```bash
-claude mcp remove leann-server
-```
-To remove LEANN
-```
-uv pip uninstall leann leann-backend-hnsw leann-core
-```
@@ -5,8 +5,11 @@ LEANN is a revolutionary vector database that democratizes personal AI. Transfor
 ## Installation

 ```bash
-# Default installation (includes both HNSW and DiskANN backends)
+# Default installation (HNSW backend, recommended)
 uv pip install leann
+
+# With DiskANN backend (for large-scale deployments)
+uv pip install leann[diskann]
 ```

 ## Quick Start
@@ -16,8 +19,8 @@ from leann import LeannBuilder, LeannSearcher, LeannChat
 from pathlib import Path
 INDEX_PATH = str(Path("./").resolve() / "demo.leann")

-# Build an index (choose backend: "hnsw" or "diskann")
-builder = LeannBuilder(backend_name="hnsw")  # or "diskann" for large-scale deployments
+# Build an index
+builder = LeannBuilder(backend_name="hnsw")
 builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
 builder.add_text("Tung Tung Tung Sahur called—they need their banana‑crocodile hybrid back")
 builder.build_index(INDEX_PATH)
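To round out the quick-start snippet above, the matching query side looks roughly like this. It reuses only calls that appear elsewhere in this comparison (`LeannSearcher`, `search`, and the `score`/`text` fields on results); treat it as a sketch rather than authoritative API documentation.

```python
from leann import LeannSearcher

# Query the demo index built by the snippet above.
searcher = LeannSearcher(INDEX_PATH)
results = searcher.search("how much storage does LEANN save?", top_k=1)
for result in results:
    print(f"{result.score:.3f}  {result.text}")
```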
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "leann"
-version = "0.2.5"
+version = "0.1.16"
 description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
 readme = "README.md"
 requires-python = ">=3.9"
@@ -24,15 +24,16 @@ classifiers = [
     "Programming Language :: Python :: 3.12",
 ]

-# Default installation: core + hnsw + diskann
+# Default installation: core + hnsw
 dependencies = [
     "leann-core>=0.1.0",
     "leann-backend-hnsw>=0.1.0",
-    "leann-backend-diskann>=0.1.0",
 ]

 [project.optional-dependencies]
-# All backends now included by default
+diskann = [
+    "leann-backend-diskann>=0.1.0",
+]

 [project.urls]
 Repository = "https://github.com/yichuan-w/LEANN"
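With DiskANN moved to an optional extra, an application might want to pick a backend based on what is actually installed. A rough sketch of one way to do that at runtime; the import name `leann_backend_diskann` is a guess for illustration, not something confirmed by this diff:

```python
import importlib.util

from leann import LeannBuilder

# Fall back to the always-installed HNSW backend when the diskann extra is absent.
backend = "diskann" if importlib.util.find_spec("leann_backend_diskann") else "hnsw"
builder = LeannBuilder(backend_name=backend)
```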
@@ -32,7 +32,7 @@ dependencies = [
     "pypdfium2>=4.30.0",
     # LlamaIndex core and readers - updated versions
     "llama-index>=0.12.44",
     "llama-index-readers-file>=0.4.0",  # Essential for PDF parsing
     # "llama-index-readers-docling",  # Requires Python >= 3.10
     # "llama-index-node-parser-docling",  # Requires Python >= 3.10
     "llama-index-vector-stores-faiss>=0.4.0",
@@ -43,7 +43,6 @@ dependencies = [
     "mlx>=0.26.3; sys_platform == 'darwin'",
     "mlx-lm>=0.26.0; sys_platform == 'darwin'",
     "psutil>=5.8.0",
-    "pybind11>=3.0.0",
 ]

 [project.optional-dependencies]
@@ -52,7 +51,7 @@ dev = [
     "pytest-cov>=4.0",
     "pytest-xdist>=3.0",  # For parallel test execution
     "black>=23.0",
-    "ruff==0.12.7",  # Fixed version to ensure consistent formatting across all environments
+    "ruff>=0.1.0",
     "matplotlib",
     "huggingface-hub>=0.20.0",
     "pre-commit>=3.5.0",
@@ -60,7 +59,7 @@ dev = [

 test = [
     "pytest>=7.0",
-    "pytest-timeout>=2.0",  # Simple timeout protection for CI
+    "pytest-timeout>=2.0",
     "llama-index-core>=0.12.0",
     "llama-index-readers-file>=0.4.0",
     "python-dotenv>=1.0.0",
@@ -89,7 +88,7 @@ leann-backend-diskann = { path = "packages/leann-backend-diskann", editable = tr
 leann-backend-hnsw = { path = "packages/leann-backend-hnsw", editable = true }

 [tool.ruff]
-target-version = "py39"
+target-version = "py310"
 line-length = 100
 extend-exclude = [
     "third_party",
@@ -152,7 +151,7 @@ markers = [
     "slow: marks tests as slow (deselect with '-m \"not slow\"')",
     "openai: marks tests that require OpenAI API key",
 ]
-timeout = 300  # Simple timeout for CI safety (5 minutes)
+timeout = 600
 addopts = [
     "-v",
     "--tb=short",
@@ -6,11 +6,10 @@ This directory contains automated tests for the LEANN project using pytest.

 ### `test_readme_examples.py`
 Tests the examples shown in README.md:
-- The basic example code that users see first (parametrized for both HNSW and DiskANN backends)
+- The basic example code that users see first
 - Import statements work correctly
 - Different backend options (HNSW, DiskANN)
-- Different LLM configuration options (parametrized for both backends)
-- **All main README examples are tested with both HNSW and DiskANN backends using pytest parametrization**
+- Different LLM configuration options

 ### `test_basic.py`
 Basic functionality tests that verify:
@@ -26,16 +25,6 @@ Tests the document RAG example functionality:
 - Tests error handling with invalid parameters
 - Verifies that normalized embeddings are detected and cosine distance is used

-### `test_diskann_partition.py`
-Tests DiskANN graph partitioning functionality:
-- Tests DiskANN index building without partitioning (baseline)
-- Tests automatic graph partitioning with `is_recompute=True`
-- Verifies that partition files are created and large files are cleaned up for storage saving
-- Tests search functionality with partitioned indices
-- Validates medoid and max_base_norm file generation and usage
-- Includes performance comparison between DiskANN (with partition) and HNSW
-- **Note**: These tests are skipped in CI due to hardware requirements and computation time
-
 ## Running Tests

 ### Install test dependencies:
@@ -65,23 +54,15 @@ pytest tests/ -m "not openai"

 # Skip slow tests
 pytest tests/ -m "not slow"
-
-# Run DiskANN partition tests (requires local machine, not CI)
-pytest tests/test_diskann_partition.py
 ```

 ### Run with specific backend:
 ```bash
 # Test only HNSW backend
 pytest tests/test_basic.py::test_backend_basic[hnsw]
-pytest tests/test_readme_examples.py::test_readme_basic_example[hnsw]

 # Test only DiskANN backend
 pytest tests/test_basic.py::test_backend_basic[diskann]
-pytest tests/test_readme_examples.py::test_readme_basic_example[diskann]
-
-# All DiskANN tests (parametrized + specialized partition tests)
-pytest tests/ -k diskann
 ```

 ## CI/CD Integration
@@ -1,41 +0,0 @@
"""Pytest configuration and fixtures for LEANN tests."""

import os

import pytest


@pytest.fixture(autouse=True)
def test_environment():
    """Set up test environment variables."""
    # Mark as test environment to skip memory-intensive operations
    os.environ["CI"] = "true"
    yield


@pytest.fixture(scope="session", autouse=True)
def cleanup_session():
    """Session-level cleanup to ensure no hanging processes."""
    yield

    # Basic cleanup after all tests
    try:
        import os

        import psutil

        current_process = psutil.Process(os.getpid())
        children = current_process.children(recursive=True)

        for child in children:
            try:
                child.terminate()
            except psutil.NoSuchProcess:
                pass

        # Give them time to terminate gracefully
        psutil.wait_procs(children, timeout=3)

    except Exception:
        # Don't fail tests due to cleanup errors
        pass
@@ -1,369 +0,0 @@
"""
Test DiskANN graph partitioning functionality.

Tests the automatic graph partitioning feature that was implemented to save
storage space by partitioning large DiskANN indices and safely deleting
redundant files while maintaining search functionality.
"""

import os
import tempfile
from pathlib import Path

import pytest


@pytest.mark.skipif(
    os.environ.get("CI") == "true",
    reason="Skip DiskANN partition tests in CI - requires specific hardware and large memory",
)
def test_diskann_without_partition():
    """Test DiskANN index building without partition (baseline)."""
    from leann.api import LeannBuilder, LeannSearcher

    with tempfile.TemporaryDirectory() as temp_dir:
        index_path = str(Path(temp_dir) / "test_no_partition.leann")

        # Test data - enough to trigger index building
        texts = [
            f"Document {i} discusses topic {i % 10} with detailed analysis of subject {i // 10}."
            for i in range(500)
        ]

        # Build without partition (is_recompute=False)
        builder = LeannBuilder(
            backend_name="diskann",
            embedding_model="facebook/contriever",
            embedding_mode="sentence-transformers",
            num_neighbors=32,
            search_list_size=50,
            is_recompute=False,  # No partition
        )

        for text in texts:
            builder.add_text(text)

        builder.build_index(index_path)

        # Verify index was created
        index_dir = Path(index_path).parent
        assert index_dir.exists()

        # Check that traditional DiskANN files exist
        index_prefix = Path(index_path).stem
        # Core DiskANN files (beam search index may not be created for small datasets)
        required_files = [
            f"{index_prefix}_disk.index",
            f"{index_prefix}_pq_compressed.bin",
            f"{index_prefix}_pq_pivots.bin",
        ]

        # Check all generated files first for debugging
        generated_files = [f.name for f in index_dir.glob(f"{index_prefix}*")]
        print(f"Generated files: {generated_files}")

        for required_file in required_files:
            file_path = index_dir / required_file
            assert file_path.exists(), f"Required file {required_file} not found"

        # Ensure no partition files exist in non-partition mode
        partition_files = [f"{index_prefix}_disk_graph.index", f"{index_prefix}_partition.bin"]

        for partition_file in partition_files:
            file_path = index_dir / partition_file
            assert not file_path.exists(), (
                f"Partition file {partition_file} should not exist in non-partition mode"
            )

        # Test search functionality
        searcher = LeannSearcher(index_path)
        results = searcher.search("topic 3 analysis", top_k=3)

        assert len(results) > 0
        assert all(result.score is not None and result.score != float("-inf") for result in results)


@pytest.mark.skipif(
    os.environ.get("CI") == "true",
    reason="Skip DiskANN partition tests in CI - requires specific hardware and large memory",
)
def test_diskann_with_partition():
    """Test DiskANN index building with automatic graph partitioning."""
    from leann.api import LeannBuilder

    with tempfile.TemporaryDirectory() as temp_dir:
        index_path = str(Path(temp_dir) / "test_with_partition.leann")

        # Test data - enough to trigger partitioning
        texts = [
            f"Document {i} explores subject {i % 15} with comprehensive coverage of area {i // 15}."
            for i in range(500)
        ]

        # Build with partition (is_recompute=True)
        builder = LeannBuilder(
            backend_name="diskann",
            embedding_model="facebook/contriever",
            embedding_mode="sentence-transformers",
            num_neighbors=32,
            search_list_size=50,
            is_recompute=True,  # Enable automatic partitioning
        )

        for text in texts:
            builder.add_text(text)

        builder.build_index(index_path)

        # Verify index was created
        index_dir = Path(index_path).parent
        assert index_dir.exists()

        # Check that partition files exist
        index_prefix = Path(index_path).stem
        partition_files = [
            f"{index_prefix}_disk_graph.index",  # Partitioned graph
            f"{index_prefix}_partition.bin",  # Partition metadata
            f"{index_prefix}_pq_compressed.bin",
            f"{index_prefix}_pq_pivots.bin",
        ]

        for partition_file in partition_files:
            file_path = index_dir / partition_file
            assert file_path.exists(), f"Expected partition file {partition_file} not found"

        # Check that large files were cleaned up (storage saving goal)
        large_files = [f"{index_prefix}_disk.index", f"{index_prefix}_disk_beam_search.index"]

        for large_file in large_files:
            file_path = index_dir / large_file
            assert not file_path.exists(), (
                f"Large file {large_file} should have been deleted for storage saving"
            )

        # Verify required auxiliary files for partition mode exist
        required_files = [
            f"{index_prefix}_disk.index_medoids.bin",
            f"{index_prefix}_disk.index_max_base_norm.bin",
        ]

        for req_file in required_files:
            file_path = index_dir / req_file
            assert file_path.exists(), (
                f"Required auxiliary file {req_file} missing for partition mode"
            )


@pytest.mark.skipif(
    os.environ.get("CI") == "true",
    reason="Skip DiskANN partition tests in CI - requires specific hardware and large memory",
)
def test_diskann_partition_search_functionality():
    """Test that search works correctly with partitioned indices."""
    from leann.api import LeannBuilder, LeannSearcher

    with tempfile.TemporaryDirectory() as temp_dir:
        index_path = str(Path(temp_dir) / "test_partition_search.leann")

        # Create diverse test data
        texts = [
            "LEANN is a storage-efficient approximate nearest neighbor search system.",
            "Graph partitioning helps reduce memory usage in large scale vector search.",
            "DiskANN provides high-performance disk-based approximate nearest neighbor search.",
            "Vector embeddings enable semantic search over unstructured text data.",
            "Approximate nearest neighbor algorithms trade accuracy for speed and storage.",
        ] * 100  # Repeat to get enough data

        # Build with partitioning
        builder = LeannBuilder(
            backend_name="diskann",
            embedding_model="facebook/contriever",
            embedding_mode="sentence-transformers",
            is_recompute=True,  # Enable partitioning
        )

        for text in texts:
            builder.add_text(text)

        builder.build_index(index_path)

        # Test search with partitioned index
        searcher = LeannSearcher(index_path)

        # Test various queries
        test_queries = [
            ("vector search algorithms", 5),
            ("LEANN storage efficiency", 3),
            ("graph partitioning memory", 4),
            ("approximate nearest neighbor", 7),
        ]

        for query, top_k in test_queries:
            results = searcher.search(query, top_k=top_k)

            # Verify search results
            assert len(results) == top_k, f"Expected {top_k} results for query '{query}'"
            assert all(result.score is not None for result in results), (
                "All results should have scores"
            )
            assert all(result.score != float("-inf") for result in results), (
                "No result should have -inf score"
            )
            assert all(result.text is not None for result in results), (
                "All results should have text"
            )

            # Scores should be in descending order (higher similarity first)
            scores = [result.score for result in results]
            assert scores == sorted(scores, reverse=True), (
                "Results should be sorted by score descending"
            )


@pytest.mark.skipif(
    os.environ.get("CI") == "true",
    reason="Skip DiskANN partition tests in CI - requires specific hardware and large memory",
)
def test_diskann_medoid_and_norm_files():
    """Test that medoid and max_base_norm files are correctly generated and used."""
    import struct

    from leann.api import LeannBuilder, LeannSearcher

    with tempfile.TemporaryDirectory() as temp_dir:
        index_path = str(Path(temp_dir) / "test_medoid_norm.leann")

        # Small but sufficient dataset
        texts = [f"Test document {i} with content about subject {i % 10}." for i in range(200)]

        builder = LeannBuilder(
            backend_name="diskann",
            embedding_model="facebook/contriever",
            embedding_mode="sentence-transformers",
            is_recompute=True,
        )

        for text in texts:
            builder.add_text(text)

        builder.build_index(index_path)

        index_dir = Path(index_path).parent
        index_prefix = Path(index_path).stem

        # Test medoids file
        medoids_file = index_dir / f"{index_prefix}_disk.index_medoids.bin"
        assert medoids_file.exists(), "Medoids file should be generated"

        # Read and validate medoids file format
        with open(medoids_file, "rb") as f:
            nshards = struct.unpack("<I", f.read(4))[0]
            one_val = struct.unpack("<I", f.read(4))[0]
            medoid_id = struct.unpack("<I", f.read(4))[0]

        assert nshards == 1, "Single-shot build should have 1 shard"
        assert one_val == 1, "Expected value should be 1"
        assert medoid_id >= 0, "Medoid ID should be valid (not hardcoded 0)"

        # Test max_base_norm file
        norm_file = index_dir / f"{index_prefix}_disk.index_max_base_norm.bin"
        assert norm_file.exists(), "Max base norm file should be generated"

        # Read and validate norm file
        with open(norm_file, "rb") as f:
            npts = struct.unpack("<I", f.read(4))[0]
            ndims = struct.unpack("<I", f.read(4))[0]
            norm_val = struct.unpack("<f", f.read(4))[0]

        assert npts == 1, "Should have 1 norm point"
        assert ndims == 1, "Should have 1 dimension"
        assert norm_val > 0, "Norm value should be positive"
        assert norm_val != float("inf"), "Norm value should be finite"

        # Test that search works with these files
        searcher = LeannSearcher(index_path)
        results = searcher.search("test subject", top_k=3)

        # Verify that scores are not -inf (which indicates norm file was loaded correctly)
        assert len(results) > 0
        assert all(result.score != float("-inf") for result in results), (
            "Scores should not be -inf when norm file is correct"
        )


@pytest.mark.skipif(
    os.environ.get("CI") == "true",
    reason="Skip performance comparison in CI - requires significant compute time",
)
def test_diskann_vs_hnsw_performance():
    """Compare DiskANN (with partition) vs HNSW performance."""
    import time

    from leann.api import LeannBuilder, LeannSearcher

    with tempfile.TemporaryDirectory() as temp_dir:
        # Test data
        texts = [
            f"Performance test document {i} covering topic {i % 20} in detail." for i in range(1000)
        ]
        query = "performance topic test"

        # Test DiskANN with partitioning
        diskann_path = str(Path(temp_dir) / "perf_diskann.leann")
        diskann_builder = LeannBuilder(
            backend_name="diskann",
            embedding_model="facebook/contriever",
            embedding_mode="sentence-transformers",
            is_recompute=True,
        )

        for text in texts:
            diskann_builder.add_text(text)

        start_time = time.time()
        diskann_builder.build_index(diskann_path)

        # Test HNSW
        hnsw_path = str(Path(temp_dir) / "perf_hnsw.leann")
        hnsw_builder = LeannBuilder(
            backend_name="hnsw",
            embedding_model="facebook/contriever",
            embedding_mode="sentence-transformers",
            is_recompute=True,
        )

        for text in texts:
            hnsw_builder.add_text(text)

        start_time = time.time()
        hnsw_builder.build_index(hnsw_path)

        # Compare search performance
        diskann_searcher = LeannSearcher(diskann_path)
        hnsw_searcher = LeannSearcher(hnsw_path)

        # Warm up searches
        diskann_searcher.search(query, top_k=5)
        hnsw_searcher.search(query, top_k=5)

        # Timed searches
        start_time = time.time()
        diskann_results = diskann_searcher.search(query, top_k=10)
        diskann_search_time = time.time() - start_time

        start_time = time.time()
        hnsw_results = hnsw_searcher.search(query, top_k=10)
        hnsw_search_time = time.time() - start_time

        # Basic assertions
        assert len(diskann_results) == 10
        assert len(hnsw_results) == 10
        assert all(r.score != float("-inf") for r in diskann_results)
        assert all(r.score != float("-inf") for r in hnsw_results)

        # Performance ratio (informational)
        if hnsw_search_time > 0:
            speed_ratio = hnsw_search_time / diskann_search_time
            print(f"DiskANN search time: {diskann_search_time:.4f}s")
            print(f"HNSW search time: {hnsw_search_time:.4f}s")
            print(f"DiskANN is {speed_ratio:.2f}x faster than HNSW")
@@ -10,9 +10,8 @@ from pathlib import Path
 import pytest


-@pytest.mark.parametrize("backend_name", ["hnsw", "diskann"])
-def test_readme_basic_example(backend_name):
-    """Test the basic example from README.md with both backends."""
+def test_readme_basic_example():
+    """Test the basic example from README.md."""
     # Skip on macOS CI due to MPS environment issues with all-MiniLM-L6-v2
     if os.environ.get("CI") == "true" and platform.system() == "Darwin":
         pytest.skip("Skipping on macOS CI due to MPS environment issues with all-MiniLM-L6-v2")
@@ -22,18 +21,18 @@ def test_readme_basic_example(backend_name):
     from leann.api import SearchResult

     with tempfile.TemporaryDirectory() as temp_dir:
-        INDEX_PATH = str(Path(temp_dir) / f"demo_{backend_name}.leann")
+        INDEX_PATH = str(Path(temp_dir) / "demo.leann")

         # Build an index
         # In CI, use a smaller model to avoid memory issues
         if os.environ.get("CI") == "true":
             builder = LeannBuilder(
-                backend_name=backend_name,
+                backend_name="hnsw",
                 embedding_model="sentence-transformers/all-MiniLM-L6-v2",  # Smaller model
                 dimensions=384,  # Smaller dimensions
             )
         else:
-            builder = LeannBuilder(backend_name=backend_name)
+            builder = LeannBuilder(backend_name="hnsw")
         builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
         builder.add_text("Tung Tung Tung Sahur called—they need their banana-crocodile hybrid back")
         builder.build_index(INDEX_PATH)
@@ -53,9 +52,6 @@ def test_readme_basic_example(backend_name):
         # Verify search results
         assert len(results) > 0
         assert isinstance(results[0], SearchResult)
-        assert results[0].score != float("-inf"), (
-            f"should return valid scores, got {results[0].score}"
-        )
         # The second text about banana-crocodile should be more relevant
         assert "banana" in results[0].text or "crocodile" in results[0].text

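For reference, the decorator being dropped in the first hunk above is plain pytest parametrization: each listed value becomes its own test item (e.g. `test_readme_basic_example[hnsw]` and `test_readme_basic_example[diskann]`), which is what the bracketed IDs in the tests README refer to. A minimal, self-contained sketch of the pattern:

```python
import pytest


@pytest.mark.parametrize("backend_name", ["hnsw", "diskann"])
def test_backend_smoke(backend_name):
    # Collected as test_backend_smoke[hnsw] and test_backend_smoke[diskann];
    # select one with `pytest -k hnsw` or by its full bracketed ID.
    assert backend_name in {"hnsw", "diskann"}
```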
@@ -114,31 +110,26 @@ def test_backend_options():
     assert len(list(Path(diskann_path).parent.glob(f"{Path(diskann_path).stem}.*"))) > 0


-@pytest.mark.parametrize("backend_name", ["hnsw", "diskann"])
-def test_llm_config_simulated(backend_name):
-    """Test simulated LLM configuration option with both backends."""
+def test_llm_config_simulated():
+    """Test simulated LLM configuration option."""
     # Skip on macOS CI due to MPS environment issues with all-MiniLM-L6-v2
     if os.environ.get("CI") == "true" and platform.system() == "Darwin":
         pytest.skip("Skipping on macOS CI due to MPS environment issues with all-MiniLM-L6-v2")

-    # Skip DiskANN tests in CI due to hardware requirements
-    if os.environ.get("CI") == "true" and backend_name == "diskann":
-        pytest.skip("Skip DiskANN tests in CI - requires specific hardware and large memory")
-
     from leann import LeannBuilder, LeannChat

     with tempfile.TemporaryDirectory() as temp_dir:
         # Build a simple index
-        index_path = str(Path(temp_dir) / f"test_{backend_name}.leann")
+        index_path = str(Path(temp_dir) / "test.leann")
         # Use smaller model in CI to avoid memory issues
         if os.environ.get("CI") == "true":
             builder = LeannBuilder(
-                backend_name=backend_name,
+                backend_name="hnsw",
                 embedding_model="sentence-transformers/all-MiniLM-L6-v2",
                 dimensions=384,
             )
         else:
-            builder = LeannBuilder(backend_name=backend_name)
+            builder = LeannBuilder(backend_name="hnsw")
         builder.add_text("Test document for LLM testing")
         builder.build_index(index_path)
