Compare commits
28 Commits
readme-pol
...
perf-build
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1d084f678c | ||
|
|
54155e8b10 | ||
|
|
5259ace111 | ||
|
|
48ea5566e9 | ||
|
|
3f8b6c5bbd | ||
|
|
725b32e74f | ||
|
|
d0b71f393f | ||
|
|
8a92efdae3 | ||
|
|
019cdce2e8 | ||
|
|
b64aa54fac | ||
|
|
c0d040f9d4 | ||
|
|
32364320f8 | ||
|
|
f47f76d6d7 | ||
|
|
1dc3923b53 | ||
|
|
7e226a51c9 | ||
|
|
f4998bb316 | ||
|
|
7522de1d41 | ||
|
|
15f8bd1cc9 | ||
|
|
34c71c072d | ||
|
|
6d2149c503 | ||
|
|
043b0bf69d | ||
|
|
9b07e392c6 | ||
|
|
e60fad8c73 | ||
|
|
19c1b182c3 | ||
|
|
49edea780c | ||
|
|
12ef5a1900 | ||
|
|
d21a134b2a | ||
|
|
1cd809aa41 |
4
.gitmodules
vendored
4
.gitmodules
vendored
@@ -1,9 +1,9 @@
|
||||
[submodule "packages/leann-backend-diskann/third_party/DiskANN"]
|
||||
path = packages/leann-backend-diskann/third_party/DiskANN
|
||||
url = https://github.com/yichuan520030910320/DiskANN.git
|
||||
url = https://github.com/yichuan-w/DiskANN.git
|
||||
[submodule "packages/leann-backend-hnsw/third_party/faiss"]
|
||||
path = packages/leann-backend-hnsw/third_party/faiss
|
||||
url = https://github.com/yichuan520030910320/faiss.git
|
||||
url = https://github.com/yichuan-w/faiss.git
|
||||
[submodule "packages/leann-backend-hnsw/third_party/msgpack-c"]
|
||||
path = packages/leann-backend-hnsw/third_party/msgpack-c
|
||||
url = https://github.com/msgpack/msgpack-c.git
|
||||
|
||||
336
README.md
336
README.md
@@ -1,101 +1,86 @@
|
||||
<h1 align="center">🚀 LEANN: A Low-Storage Vector Index</h1>
|
||||
<p align="center">
|
||||
<img src="assets/logo-text.png" alt="LEANN Logo" width="400">
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<img src="https://img.shields.io/badge/Python-3.9%2B-blue.svg" alt="Python 3.9+">
|
||||
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
|
||||
<img src="https://img.shields.io/badge/PRs-welcome-brightgreen.svg" alt="PRs Welcome">
|
||||
<img src="https://img.shields.io/badge/Platform-Linux%20%7C%20macOS-lightgrey" alt="Platform">
|
||||
<img src="https://img.shields.io/badge/Platform-Linux%20%7C%20macOS-lightgrey" alt="Platform">
|
||||
</p>
|
||||
|
||||
<h2 align="center" tabindex="-1" class="heading-element" dir="auto">
|
||||
The smallest vector index in the world. RAG Everything with LEANN!
|
||||
</h2>
|
||||
|
||||
LEANN is a revolutionary vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **[97% less storage]** than traditional solutions **without accuracy loss**.
|
||||
|
||||
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
|
||||
|
||||
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can search your **[file system](#process-any-documents-pdf-txt-md)**, **[emails](#search-your-entire-life)**, **[browser history](#time-machine-for-the-web)**, **[chat history](#wechat-detective)**, or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
||||
|
||||
|
||||
|
||||
## Why LEANN?
|
||||
|
||||
<p align="center">
|
||||
<strong>💾 Extreme Storage Saving • 🔒 100% Private • 📚 RAG Everything • ⚡ Easy & Accurate</strong>
|
||||
<img src="assets/effects.png" alt="LEANN vs Traditional Vector DB Storage Comparison" width="70%">
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="#-quick-start">Quick Start</a> •
|
||||
<a href="#-features">Features</a> •
|
||||
<a href="#-benchmarks">Benchmarks</a> •
|
||||
<a href="https://arxiv.org/abs/2506.08276" target="_blank">Paper</a>
|
||||
</p>
|
||||
**The numbers speak for themselves:** Index 60 million Wikipedia articles in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks below ↓](#storage-usage-comparison)
|
||||
|
||||
---
|
||||
## Why This Matters
|
||||
|
||||
## 🌟 What is LEANN-RAG?
|
||||
🔒 **Privacy:** Your data never leaves your laptop. No OpenAI, no cloud, no "terms of service".
|
||||
|
||||
**LEANN-RAG** is a lightweight, locally deployable **Retrieval-Augmented Generation (RAG)** engine designed for personal devices. It combines **compact storage**, **clean usability**, and **privacy-by-design**, making it easy to build personalized retrieval systems over your own data — emails, notes, documents, chats, or anything else.
|
||||
🪶 **Lightweight:** Graph-based recomputation eliminates heavy embedding storage, while smart graph pruning and CSR format minimize graph storage overhead. Always less storage, less memory usage!
|
||||
|
||||
Unlike traditional vector databases that rely on massive embedding storage, LEANN reduces storage needs dramatically by using **graph-based recomputation** and **pruned HNSW search**, while maintaining responsive and reliable performance — all without sending any data to the cloud.
|
||||
📈 **Scalability:** Handle messy personal data that would crash traditional vector DBs, easily managing your growing personalized data and agent generated memory!
|
||||
|
||||
---
|
||||
✨ **No Accuracy Loss:** Maintain the same search quality as heavyweight solutions while using 97% less storage.
|
||||
|
||||
## 🔥 Key Highlights
|
||||
|
||||
### 💾 1. Extreme Storage Efficiency
|
||||
LEANN reduces storage usage by **up to 97%** compared to conventional vector DBs (e.g., FAISS), by storing only pruned graph structures and computing embeddings at query time.
|
||||
> For example: 60M chunks can be indexed in just **6GB**, compared to **200GB+** with dense storage.
|
||||
|
||||
### 🔒 2. Fully Private, Cloud-Free
|
||||
LEANN runs entirely locally. No cloud services, no API keys, and no risk of leaking sensitive data.
|
||||
> Converse with your own files **without compromising privacy**.
|
||||
|
||||
### 🧠 3. RAG Everything
|
||||
Build truly personalized assistants by querying over **your own** chat logs, email archives, browser history, or agent memory.
|
||||
> LEANN makes it easy to integrate personal context into RAG workflows.
|
||||
|
||||
### ⚡ 4. Easy, Accurate, and Fast
|
||||
LEANN is designed to be **easy to install**, with a **clean API** and minimal setup. It runs efficiently on consumer hardware without sacrificing retrieval accuracy.
|
||||
> One command to install, one click to run.
|
||||
|
||||
---
|
||||
|
||||
## 🚀 Why Choose LEANN?
|
||||
|
||||
Traditional RAG systems often require trade-offs between storage, privacy, and usability. **LEANN-RAG aims to simplify the stack** with a more practical design:
|
||||
|
||||
- ✅ **No embedding storage** — compute on demand, save disk space
|
||||
- ✅ **Low memory footprint** — lightweight and hardware-friendly
|
||||
- ✅ **Privacy-first** — 100% local, no network dependency
|
||||
- ✅ **Simple to use** — developer-friendly API and seamless setup
|
||||
|
||||
> 📄 For more details, see our [academic paper](https://arxiv.org/abs/2506.08276)
|
||||
## 🚀 Quick Start
|
||||
|
||||
### Installation
|
||||
## Quick Start in 1 minute
|
||||
|
||||
```bash
|
||||
git clone git@github.com:yichuan520030910320/LEANN-RAG.git leann
|
||||
git clone git@github.com:yichuan-w/LEANN.git leann
|
||||
cd leann
|
||||
git submodule update --init --recursive
|
||||
```
|
||||
|
||||
**macOS:**
|
||||
```bash
|
||||
brew install llvm libomp boost protobuf
|
||||
brew install llvm libomp boost protobuf zeromq
|
||||
export CC=$(brew --prefix llvm)/bin/clang
|
||||
export CXX=$(brew --prefix llvm)/bin/clang++
|
||||
|
||||
# Install with HNSW backend (default, recommended for most users)
|
||||
uv sync
|
||||
|
||||
# Or add DiskANN backend if you want to test more options
|
||||
uv sync --extra diskann
|
||||
```
|
||||
|
||||
**Linux (Ubuntu/Debian):**
|
||||
```bash
|
||||
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev
|
||||
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
|
||||
|
||||
# Install with HNSW backend (default, recommended for most users)
|
||||
uv sync
|
||||
|
||||
# Or add DiskANN backend if you want to test more options
|
||||
uv sync --extra diskann
|
||||
```
|
||||
|
||||
**Ollama Setup (Optional for Local LLM):**
|
||||
|
||||
*We support both hf-transformers and Ollama for local LLMs. Ollama is recommended for faster performance.*
|
||||
|
||||
*macOS:*
|
||||
|
||||
First, [download Ollama for macOS](https://ollama.com/download/mac).
|
||||
```bash
|
||||
# Install Ollama
|
||||
brew install ollama
|
||||
|
||||
```bash
|
||||
# Pull a lightweight model (recommended for consumer hardware)
|
||||
ollama pull llama3.2:1b
|
||||
|
||||
# For better performance but higher memory usage
|
||||
ollama pull llama3.2:3b
|
||||
```
|
||||
|
||||
*Linux:*
|
||||
@@ -108,18 +93,17 @@ ollama serve &
|
||||
|
||||
# Pull a lightweight model (recommended for consumer hardware)
|
||||
ollama pull llama3.2:1b
|
||||
|
||||
# For better performance but higher memory usage
|
||||
ollama pull llama3.2:3b
|
||||
```
|
||||
|
||||
**Note:** For Hugging Face models >1B parameters, you may encounter OOM errors on consumer hardware. Consider using smaller models like Qwen3-0.6B or switch to Ollama for better memory management.
|
||||
You can also replace `llama3.2:1b` to `deepseek-r1:1.5b` or `qwen3:4b` for better performance but higher memory usage.
|
||||
|
||||
### 30-Second Example
|
||||
Try it out in [**demo.ipynb**](demo.ipynb)
|
||||
## Dead Simple API
|
||||
|
||||
Just 3 lines of code. Our declarative API makes RAG as easy as writing a config file:
|
||||
|
||||
```python
|
||||
from leann.api import LeannBuilder, LeannSearcher
|
||||
|
||||
# 1. Build index (no embeddings stored!)
|
||||
builder = LeannBuilder(backend_name="hnsw")
|
||||
builder.add_text("C# is a powerful programming language")
|
||||
@@ -128,63 +112,45 @@ builder.add_text("Machine learning transforms industries")
|
||||
builder.add_text("Neural networks process complex data")
|
||||
builder.add_text("Leann is a great storage saving engine for RAG on your macbook")
|
||||
builder.build_index("knowledge.leann")
|
||||
|
||||
# 2. Search with real-time embeddings
|
||||
searcher = LeannSearcher("knowledge.leann")
|
||||
results = searcher.search("C++ programming languages", top_k=2, recompute_beighbor_embeddings=True)
|
||||
print(results)
|
||||
```
|
||||
|
||||
### Run the Demo (support .pdf,.txt,.docx, .pptx, .csv, .md etc)
|
||||
**That's it.** No cloud setup, no API keys, no "fine-tuning". Just your data, your questions, your laptop.
|
||||
|
||||
[Try the interactive demo →](demo.ipynb)
|
||||
|
||||
## Wild Things You Can Do
|
||||
|
||||
LEANN supports RAGing a lot of data sources, like .pdf, .txt, .md, and also supports RAGing your WeChat, Google Search History, and more.
|
||||
|
||||
### Process Any Documents (.pdf, .txt, .md)
|
||||
|
||||
Above we showed the Python API, while this CLI script demonstrates the same concepts while directly processing PDFs and documents.
|
||||
|
||||
```bash
|
||||
# Drop your PDFs, .txt, .md files into examples/data/
|
||||
uv run ./examples/main_cli_example.py
|
||||
```
|
||||
|
||||
or you want to use python
|
||||
|
||||
```bash
|
||||
# Or use python directly
|
||||
source .venv/bin/activate
|
||||
python ./examples/main_cli_example.py
|
||||
```
|
||||
**PDF RAG Demo (using LlamaIndex for document parsing and Leann for indexing/search)**
|
||||
|
||||
This demo showcases how to build a RAG system for PDF/md documents using Leann.
|
||||
Uses Ollama `qwen3:8b` by default. For other models: `--llm openai --model gpt-4o` (requires `OPENAI_API_KEY` environment variable) or `--llm hf --model Qwen/Qwen3-4B`.
|
||||
|
||||
1. Place your PDF files (and other supported formats like .docx, .pptx, .xlsx) into the `examples/data/` directory.
|
||||
2. Ensure you have an `OPENAI_API_KEY` set in your environment variables or in a `.env` file for the LLM to function.
|
||||
**Works with any text format** - research papers, personal notes, presentations. Built with LlamaIndex for document parsing.
|
||||
|
||||
|
||||
|
||||
## ✨ Features
|
||||
|
||||
### 🔥 Core Features
|
||||
|
||||
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage with dynamic computation using optimized ZMQ servers and highly optimized search paradigm (overlapping and batching) with highly optimized embedding engine
|
||||
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
|
||||
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
|
||||
- **🏗️ Pluggable Backends** - DiskANN, HNSW/FAISS with unified API
|
||||
|
||||
### 🛠️ Technical Highlights
|
||||
- **🔄 Recompute Mode** - Highest accuracy scenarios while eliminating vector storage overhead
|
||||
- **⚡ Zero-copy Operations** - Minimize IPC overhead by transferring distances instead of embeddings
|
||||
- **🚀 High-throughput Embedding Pipeline** - Optimized batched processing for maximum efficiency
|
||||
- **🎯 Two-level Search** - Novel coarse-to-fine search overlap for accelerated query processing (optional)
|
||||
- **💾 Memory-mapped Indices** - Fast startup with raw text mapping to reduce memory overhead
|
||||
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](test/build_mlx_index.py))
|
||||
|
||||
### 🎨 Developer Experience
|
||||
|
||||
- **Simple Python API** - Get started in minutes
|
||||
- **Extensible backend system** - Easy to add new algorithms
|
||||
- **Comprehensive examples** - From basic usage to production deployment
|
||||
|
||||
## Applications on your MacBook
|
||||
|
||||
### 📧 Lightweight RAG on your Apple Mail
|
||||
|
||||
LEANN can create a searchable index of your Apple Mail emails, allowing you to query your email history using natural language.
|
||||
|
||||
#### Quick Start
|
||||
### Search Your Entire Life
|
||||
```bash
|
||||
python examples/mail_reader_leann.py
|
||||
# "What did my boss say about the Christmas party last year?"
|
||||
# "Find all emails from my mom about birthday plans"
|
||||
```
|
||||
**90K emails → 14MB.** Finally, search your email like you search Google.
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
|
||||
@@ -193,52 +159,45 @@ LEANN can create a searchable index of your Apple Mail emails, allowing you to q
|
||||
# Use default mail path (works for most macOS setups)
|
||||
python examples/mail_reader_leann.py
|
||||
|
||||
|
||||
# Run with custom index directory
|
||||
python examples/mail_reader_leann.py --index-dir "./my_mail_index"
|
||||
|
||||
# embedd and search all of your email(this may take a long preprocessing time but it will encode all your emails)
|
||||
# Process all emails (may take time but indexes everything)
|
||||
python examples/mail_reader_leann.py --max-emails -1
|
||||
|
||||
# Limit number of emails processed (useful for testing)
|
||||
python examples/mail_reader_leann.py --max-emails 1000
|
||||
|
||||
# Run a single query
|
||||
python examples/mail_reader_leann.py --query "Whats the number of class recommend to take per semester for incoming EECS students"
|
||||
python examples/mail_reader_leann.py --query "What did my boss say about deadlines?"
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
#### Example Queries
|
||||
|
||||
<details>
|
||||
<summary><strong>💬 Click to expand: Example queries you can try</strong></summary>
|
||||
<summary><strong>📋 Click to expand: Example queries you can try</strong></summary>
|
||||
|
||||
Once the index is built, you can ask questions like:
|
||||
- "Find emails from my boss about deadlines"
|
||||
- "What did John say about the project timeline?"
|
||||
- "Show me emails about travel expenses"
|
||||
|
||||
</details>
|
||||
|
||||
### 🌐 Lightweight RAG on your Google Chrome History
|
||||
|
||||
LEANN can create a searchable index of your Chrome browser history, allowing you to query your browsing history using natural language.
|
||||
|
||||
#### Quick Start
|
||||
### Time Machine for the Web
|
||||
```bash
|
||||
python examples/google_history_reader_leann.py
|
||||
# "What was that AI paper I read last month?"
|
||||
# "Show me all the cooking videos I watched"
|
||||
```
|
||||
**38K browser entries → 6MB.** Your browser history becomes your personal search engine.
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
|
||||
|
||||
Note you need to quit google right now to successfully run this.
|
||||
|
||||
```bash
|
||||
# Use default Chrome profile (auto-finds all profiles) and recommand method to run this because usually default file is enough
|
||||
# Use default Chrome profile (auto-finds all profiles)
|
||||
python examples/google_history_reader_leann.py
|
||||
|
||||
|
||||
# Run with custom index directory
|
||||
python examples/google_history_reader_leann.py --index-dir "./my_chrome_index"
|
||||
|
||||
@@ -247,17 +206,12 @@ python examples/google_history_reader_leann.py --max-entries 500
|
||||
|
||||
# Run a single query
|
||||
python examples/google_history_reader_leann.py --query "What websites did I visit about machine learning?"
|
||||
|
||||
# Use only a specific profile (disable auto-find)
|
||||
python examples/google_history_reader_leann.py --chrome-profile "~/Library/Application Support/Google/Chrome/Default" --no-auto-find-profiles
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
#### Finding Your Chrome Profile
|
||||
|
||||
<details>
|
||||
<summary><strong>🔍 Click to expand: How to find your Chrome profile</strong></summary>
|
||||
<summary><strong>📋 Click to expand: How to find your Chrome profile</strong></summary>
|
||||
|
||||
The default Chrome profile path is configured for a typical macOS setup. If you need to find your specific Chrome profile:
|
||||
|
||||
@@ -272,12 +226,11 @@ The default Chrome profile path is configured for a typical macOS setup. If you
|
||||
|
||||
</details>
|
||||
|
||||
#### Example Queries
|
||||
|
||||
<details>
|
||||
<summary><strong>💬 Click to expand: Example queries you can try</strong></summary>
|
||||
|
||||
Once the index is built, you can ask questions like:
|
||||
|
||||
- "What websites did I visit about machine learning?"
|
||||
- "Find my search history about programming"
|
||||
- "What YouTube videos did I watch recently?"
|
||||
@@ -285,12 +238,13 @@ Once the index is built, you can ask questions like:
|
||||
|
||||
</details>
|
||||
|
||||
### WeChat Detective
|
||||
|
||||
### 💬 Lightweight RAG on your WeChat History
|
||||
|
||||
LEANN can create a searchable index of your WeChat chat history, allowing you to query your conversations using natural language.
|
||||
|
||||
#### Prerequisites
|
||||
```bash
|
||||
python examples/wechat_history_reader_leann.py
|
||||
# "Show me all group chats about weekend plans"
|
||||
```
|
||||
**400K messages → 64MB.** Search years of chat history in any language.
|
||||
|
||||
<details>
|
||||
<summary><strong>🔧 Click to expand: Installation Requirements</strong></summary>
|
||||
@@ -302,11 +256,8 @@ sudo packages/wechat-exporter/wechattweak-cli install
|
||||
```
|
||||
|
||||
**Troubleshooting**: If you encounter installation issues, check the [WeChatTweak-CLI issues page](https://github.com/sunnyyoung/WeChatTweak-CLI/issues/41).
|
||||
|
||||
</details>
|
||||
|
||||
#### Quick Start
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
|
||||
|
||||
@@ -325,84 +276,60 @@ python examples/wechat_history_reader_leann.py --max-entries 1000
|
||||
|
||||
# Run a single query
|
||||
python examples/wechat_history_reader_leann.py --query "Show me conversations about travel plans"
|
||||
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
#### Example Queries
|
||||
|
||||
<details>
|
||||
<summary><strong>💬 Click to expand: Example queries you can try</strong></summary>
|
||||
|
||||
Once the index is built, you can ask questions like:
|
||||
|
||||
- "我想买魔术师约翰逊的球衣,给我一些对应聊天记录?" (Chinese: Show me chat records about buying Magic Johnson's jersey)
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
## ⚡ Performance Comparison
|
||||
## 🏗️ Architecture & How It Works
|
||||
|
||||
### LEANN vs Faiss HNSW
|
||||
<p align="center">
|
||||
<img src="assets/arch.png" alt="LEANN Architecture" width="800">
|
||||
</p>
|
||||
|
||||
We benchmarked LEANN against the popular Faiss HNSW implementation to demonstrate the significant memory and storage savings our approach provides:
|
||||
**The magic:** Most vector DBs store every single embedding (expensive). LEANN stores a pruned graph structure (cheap) and recomputes embeddings only when needed (fast).
|
||||
|
||||
**Core techniques:**
|
||||
- **Graph-based selective recomputation:** Only compute embeddings for nodes in the search path
|
||||
- **High-degree preserving pruning:** Keep important "hub" nodes while removing redundant connections
|
||||
- **Dynamic batching:** Efficiently batch embedding computations for GPU utilization
|
||||
- **Two-level search:** Smart graph traversal that prioritizes promising nodes
|
||||
|
||||
**Backends:** DiskANN or HNSW - pick what works for your data size.
|
||||
|
||||
## Benchmarks
|
||||
|
||||
Run the comparison yourself:
|
||||
```bash
|
||||
# Run the comparison benchmark
|
||||
python examples/compare_faiss_vs_leann.py
|
||||
```
|
||||
|
||||
#### 🎯 Results Summary
|
||||
| System | Storage |
|
||||
|--------|---------|
|
||||
| FAISS HNSW | 5.5 MB |
|
||||
| LEANN | 0.5 MB |
|
||||
| **Savings** | **91%** |
|
||||
|
||||
| Metric | Faiss HNSW | LEANN HNSW | **Improvement** |
|
||||
|--------|------------|-------------|-----------------|
|
||||
| **Storage Size** | 5.5 MB | 0.5 MB | **11.4x smaller** (5.0 MB saved) |
|
||||
Same dataset, same hardware, same embedding model. LEANN just works better.
|
||||
|
||||
#### 📈 Key Takeaways
|
||||
|
||||
|
||||
- **💾 Storage Optimization**: LEANN requires **91% less storage** for the same dataset
|
||||
|
||||
- **⚖️ Fair Comparison**: Both systems tested on identical hardware with the same 2,573 document dataset and the same embedding model and chunk method
|
||||
|
||||
> **Note**: Results may vary based on dataset size, hardware configuration, and query patterns. The comparison excludes text storage to focus purely on index structures.
|
||||
|
||||
|
||||
|
||||
*Benchmark results obtained on Apple Silicon with consistent environmental conditions*
|
||||
|
||||
## 📊 Benchmarks
|
||||
|
||||
### How to Reproduce Evaluation Results
|
||||
|
||||
Reproducing our benchmarks is straightforward. The evaluation script is designed to be self-contained, automatically downloading all necessary data on its first run.
|
||||
|
||||
#### 1. Environment Setup
|
||||
|
||||
First, ensure you have followed the installation instructions in the [Quick Start](#-quick-start) section. This will install all core dependencies.
|
||||
|
||||
Next, install the optional development dependencies, which include the `huggingface-hub` library required for automatic data download:
|
||||
## Reproduce Our Results
|
||||
|
||||
```bash
|
||||
# This command installs all development dependencies
|
||||
uv pip install -e ".[dev]"
|
||||
uv pip install -e ".[dev]" # Install dev dependencies
|
||||
python examples/run_evaluation.py data/indices/dpr/dpr_diskann # DPR dataset
|
||||
python examples/run_evaluation.py data/indices/rpj_wiki/rpj_wiki.index # Wikipedia
|
||||
```
|
||||
|
||||
#### 2. Run the Evaluation
|
||||
|
||||
Simply run the evaluation script. The first time you run it, it will detect that the data is missing, download it from Hugging Face Hub, and then proceed with the evaluation.
|
||||
|
||||
**To evaluate the DPR dataset:**
|
||||
```bash
|
||||
python examples/run_evaluation.py data/indices/dpr/dpr_diskann
|
||||
```
|
||||
|
||||
**To evaluate the RPJ-Wiki dataset:**
|
||||
```bash
|
||||
python examples/run_evaluation.py data/indices/rpj_wiki/rpj_wiki.index
|
||||
```
|
||||
|
||||
The script will print the recall and search time for each query, followed by the average results.
|
||||
The evaluation script downloads data automatically on first run.
|
||||
|
||||
### Storage Usage Comparison
|
||||
|
||||
@@ -429,13 +356,6 @@ The script will print the recall and search time for each query, followed by the
|
||||
|
||||
*Benchmarks run on Apple M3 Pro 36 GB*
|
||||
|
||||
|
||||
## 🏗️ Architecture
|
||||
|
||||
<p align="center">
|
||||
<img src="asset/arch.png" alt="LEANN Architecture" width="800">
|
||||
</p>
|
||||
|
||||
## 🔬 Paper
|
||||
|
||||
If you find Leann useful, please cite:
|
||||
@@ -454,6 +374,28 @@ If you find Leann useful, please cite:
|
||||
}
|
||||
```
|
||||
|
||||
## ✨ Features
|
||||
|
||||
### 🔥 Core Features
|
||||
|
||||
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage with dynamic computation using optimized ZMQ servers and highly optimized search paradigm (overlapping and batching) with highly optimized embedding engine
|
||||
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
|
||||
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
|
||||
- **🏗️ Pluggable Backends** - DiskANN, HNSW/FAISS with unified API
|
||||
|
||||
### 🛠️ Technical Highlights
|
||||
- **🔄 Recompute Mode** - Highest accuracy scenarios while eliminating vector storage overhead
|
||||
- **⚡ Zero-copy Operations** - Minimize IPC overhead by transferring distances instead of embeddings
|
||||
- **🚀 High-throughput Embedding Pipeline** - Optimized batched processing for maximum efficiency
|
||||
- **🎯 Two-level Search** - Novel coarse-to-fine search overlap for accelerated query processing (optional)
|
||||
- **💾 Memory-mapped Indices** - Fast startup with raw text mapping to reduce memory overhead
|
||||
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](test/build_mlx_index.py))
|
||||
|
||||
### 🎨 Developer Experience
|
||||
|
||||
- **Simple Python API** - Get started in minutes
|
||||
- **Extensible backend system** - Easy to add new algorithms
|
||||
- **Comprehensive examples** - From basic usage to production deployment
|
||||
|
||||
## 🤝 Contributing
|
||||
|
||||
|
||||
|
Before Width: | Height: | Size: 78 KiB After Width: | Height: | Size: 78 KiB |
BIN
assets/effects.png
Normal file
BIN
assets/effects.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 339 KiB |
BIN
assets/logo-text.png
Normal file
BIN
assets/logo-text.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 818 KiB |
BIN
assets/logo.png
Normal file
BIN
assets/logo.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 276 KiB |
@@ -197,8 +197,8 @@ class WeChatHistoryReader(BaseReader):
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries
|
||||
max_length: Maximum length for concatenated message groups
|
||||
time_window_minutes: Time window in minutes to group messages together
|
||||
max_length: Maximum length for concatenated message groups. Use -1 to disable length constraint.
|
||||
time_window_minutes: Time window in minutes to group messages together. Use -1 to disable time constraint.
|
||||
overlap_messages: Number of messages to overlap between consecutive groups
|
||||
|
||||
Returns:
|
||||
@@ -230,8 +230,8 @@ class WeChatHistoryReader(BaseReader):
|
||||
if not readable_text.strip():
|
||||
continue
|
||||
|
||||
# Check time window constraint
|
||||
if last_timestamp is not None and create_time > 0:
|
||||
# Check time window constraint (only if time_window_minutes != -1)
|
||||
if time_window_minutes != -1 and last_timestamp is not None and create_time > 0:
|
||||
time_diff_minutes = (create_time - last_timestamp) / 60
|
||||
if time_diff_minutes > time_window_minutes:
|
||||
# Time gap too large, start new group
|
||||
@@ -250,9 +250,9 @@ class WeChatHistoryReader(BaseReader):
|
||||
current_group = []
|
||||
current_length = 0
|
||||
|
||||
# Check length constraint
|
||||
# Check length constraint (only if max_length != -1)
|
||||
message_length = len(readable_text)
|
||||
if current_length + message_length > max_length and current_group:
|
||||
if max_length != -1 and current_length + message_length > max_length and current_group:
|
||||
# Current group would exceed max length, save it and start new
|
||||
concatenated_groups.append({
|
||||
'messages': current_group,
|
||||
@@ -431,9 +431,9 @@ Contact: {contact_name}
|
||||
# Concatenate messages based on rules
|
||||
message_groups = self._concatenate_messages(
|
||||
readable_messages,
|
||||
max_length=max_length,
|
||||
time_window_minutes=time_window_minutes,
|
||||
overlap_messages=2 # Keep 2 messages overlap between groups
|
||||
max_length=-1,
|
||||
time_window_minutes=-1,
|
||||
overlap_messages=0 # Keep 2 messages overlap between groups
|
||||
)
|
||||
|
||||
# Create documents from concatenated groups
|
||||
|
||||
@@ -52,7 +52,7 @@ def create_leann_index_from_multiple_wechat_exports(
|
||||
documents = reader.load_data(
|
||||
wechat_export_dir=str(export_dir),
|
||||
max_count=max_count,
|
||||
concatenate_messages=True, # Disable concatenation - one message per document
|
||||
concatenate_messages=False, # Disable concatenation - one message per document
|
||||
)
|
||||
if documents:
|
||||
print(f"Loaded {len(documents)} chat documents from {export_dir}")
|
||||
@@ -78,7 +78,7 @@ def create_leann_index_from_multiple_wechat_exports(
|
||||
)
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||
text_splitter = SentenceSplitter(chunk_size=128, chunk_overlap=64)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
@@ -224,7 +224,7 @@ async def query_leann_index(index_path: str, query: str):
|
||||
query,
|
||||
top_k=20,
|
||||
recompute_beighbor_embeddings=True,
|
||||
complexity=64,
|
||||
complexity=128,
|
||||
beam_width=1,
|
||||
llm_config={
|
||||
"type": "openai",
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# packages/leann-backend-diskann/CMakeLists.txt (最终简化版)
|
||||
# packages/leann-backend-diskann/CMakeLists.txt (simplified version)
|
||||
|
||||
cmake_minimum_required(VERSION 3.20)
|
||||
project(leann_backend_diskann_wrapper)
|
||||
|
||||
# 告诉 CMake 直接进入 DiskANN 子模块并执行它自己的 CMakeLists.txt
|
||||
# DiskANN 会自己处理所有事情,包括编译 Python 绑定
|
||||
# Tell CMake to directly enter the DiskANN submodule and execute its own CMakeLists.txt
|
||||
# DiskANN will handle everything itself, including compiling Python bindings
|
||||
add_subdirectory(src/third_party/DiskANN)
|
||||
|
||||
@@ -70,10 +70,6 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
||||
data_filename = f"{index_prefix}_data.bin"
|
||||
_write_vectors_to_bin(data, index_dir / data_filename)
|
||||
|
||||
label_map = {i: str_id for i, str_id in enumerate(ids)}
|
||||
label_map_file = index_dir / "leann.labels.map"
|
||||
with open(label_map_file, "wb") as f:
|
||||
pickle.dump(label_map, f)
|
||||
|
||||
build_kwargs = {**self.build_params, **kwargs}
|
||||
metric_enum = _get_diskann_metrics().get(
|
||||
@@ -211,10 +207,7 @@ class DiskannSearcher(BaseSearcher):
|
||||
)
|
||||
|
||||
string_labels = [
|
||||
[
|
||||
self.label_map.get(int_label, f"unknown_{int_label}")
|
||||
for int_label in batch_labels
|
||||
]
|
||||
[str(int_label) for int_label in batch_labels]
|
||||
for batch_labels in labels
|
||||
]
|
||||
|
||||
|
||||
@@ -76,24 +76,11 @@ def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
|
||||
finally:
|
||||
sys.path.pop(0)
|
||||
|
||||
# Load label map
|
||||
passages_dir = Path(meta_file).parent
|
||||
label_map_file = passages_dir / "leann.labels.map"
|
||||
|
||||
if label_map_file.exists():
|
||||
import pickle
|
||||
with open(label_map_file, 'rb') as f:
|
||||
label_map = pickle.load(f)
|
||||
print(f"Loaded label map with {len(label_map)} entries")
|
||||
else:
|
||||
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
|
||||
|
||||
print(f"Initialized lazy passage loading for {len(label_map)} passages")
|
||||
print(f"Initialized lazy passage loading for {len(passage_manager.global_offset_map)} passages")
|
||||
|
||||
class LazyPassageLoader(SimplePassageLoader):
|
||||
def __init__(self, passage_manager, label_map):
|
||||
def __init__(self, passage_manager):
|
||||
self.passage_manager = passage_manager
|
||||
self.label_map = label_map
|
||||
# Initialize parent with empty data
|
||||
super().__init__({})
|
||||
|
||||
@@ -101,25 +88,22 @@ def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
|
||||
"""Get passage by ID with lazy loading"""
|
||||
try:
|
||||
int_id = int(passage_id)
|
||||
if int_id in self.label_map:
|
||||
string_id = self.label_map[int_id]
|
||||
passage_data = self.passage_manager.get_passage(string_id)
|
||||
if passage_data and passage_data.get("text"):
|
||||
return {"text": passage_data["text"]}
|
||||
else:
|
||||
raise RuntimeError(f"FATAL: Empty text for ID {int_id} -> {string_id}")
|
||||
string_id = str(int_id)
|
||||
passage_data = self.passage_manager.get_passage(string_id)
|
||||
if passage_data and passage_data.get("text"):
|
||||
return {"text": passage_data["text"]}
|
||||
else:
|
||||
raise RuntimeError(f"FATAL: ID {int_id} not found in label_map")
|
||||
raise RuntimeError(f"FATAL: Empty text for ID {int_id} -> {string_id}")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"FATAL: Exception getting passage {passage_id}: {e}")
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.label_map)
|
||||
return len(self.passage_manager.global_offset_map)
|
||||
|
||||
def keys(self):
|
||||
return self.label_map.keys()
|
||||
return self.passage_manager.global_offset_map.keys()
|
||||
|
||||
loader = LazyPassageLoader(passage_manager, label_map)
|
||||
loader = LazyPassageLoader(passage_manager)
|
||||
loader._meta_path = meta_file
|
||||
return loader
|
||||
|
||||
@@ -135,35 +119,15 @@ def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
|
||||
if not passages_file.endswith('.jsonl'):
|
||||
raise ValueError(f"Expected .jsonl file format, got: {passages_file}")
|
||||
|
||||
# Load label map (int -> string_id)
|
||||
passages_dir = Path(passages_file).parent
|
||||
label_map_file = passages_dir / "leann.labels.map"
|
||||
|
||||
label_map = {}
|
||||
if label_map_file.exists():
|
||||
with open(label_map_file, 'rb') as f:
|
||||
label_map = pickle.load(f)
|
||||
print(f"Loaded label map with {len(label_map)} entries")
|
||||
else:
|
||||
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
|
||||
|
||||
# Load passages by string ID
|
||||
string_id_passages = {}
|
||||
# Load passages directly by their sequential IDs
|
||||
passages_data = {}
|
||||
with open(passages_file, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
passage = json.loads(line)
|
||||
string_id_passages[passage['id']] = passage['text']
|
||||
passages_data[passage['id']] = passage['text']
|
||||
|
||||
# Create int ID -> text mapping using label map
|
||||
passages_data = {}
|
||||
for int_id, string_id in label_map.items():
|
||||
if string_id in string_id_passages:
|
||||
passages_data[str(int_id)] = string_id_passages[string_id]
|
||||
else:
|
||||
print(f"WARNING: String ID {string_id} from label map not found in passages")
|
||||
|
||||
print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file} using label map")
|
||||
print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file}")
|
||||
return SimplePassageLoader(passages_data)
|
||||
|
||||
def create_embedding_server_thread(
|
||||
|
||||
@@ -8,9 +8,12 @@ version = "0.1.0"
|
||||
dependencies = ["leann-core==0.1.0", "numpy"]
|
||||
|
||||
[tool.scikit-build]
|
||||
# 关键:简化的 CMake 路径
|
||||
# Key: simplified CMake path
|
||||
cmake.source-dir = "third_party/DiskANN"
|
||||
# 关键:Python 包在根目录,路径完全匹配
|
||||
# Key: Python package in root directory, paths match exactly
|
||||
wheel.packages = ["leann_backend_diskann"]
|
||||
# 使用默认的 redirect 模式
|
||||
editable.mode = "redirect"
|
||||
# Use default redirect mode
|
||||
editable.mode = "redirect"
|
||||
cmake.build-type = "Release"
|
||||
build.verbose = true
|
||||
build.tool-args = ["-j8"]
|
||||
@@ -1,6 +1,7 @@
|
||||
# 最终简化版
|
||||
cmake_minimum_required(VERSION 3.24)
|
||||
project(leann_backend_hnsw_wrapper)
|
||||
set(CMAKE_C_COMPILER_WORKS 1)
|
||||
set(CMAKE_CXX_COMPILER_WORKS 1)
|
||||
|
||||
# Set OpenMP path for macOS
|
||||
if(APPLE)
|
||||
@@ -11,15 +12,9 @@ if(APPLE)
|
||||
set(OpenMP_omp_LIBRARY "/opt/homebrew/opt/libomp/lib/libomp.dylib")
|
||||
endif()
|
||||
|
||||
# Build ZeroMQ from source
|
||||
set(ZMQ_BUILD_TESTS OFF CACHE BOOL "" FORCE)
|
||||
set(ENABLE_DRAFTS OFF CACHE BOOL "" FORCE)
|
||||
set(ENABLE_PRECOMPILED OFF CACHE BOOL "" FORCE)
|
||||
set(WITH_PERF_TOOL OFF CACHE BOOL "" FORCE)
|
||||
set(WITH_DOCS OFF CACHE BOOL "" FORCE)
|
||||
set(BUILD_SHARED OFF CACHE BOOL "" FORCE)
|
||||
set(BUILD_STATIC ON CACHE BOOL "" FORCE)
|
||||
add_subdirectory(third_party/libzmq)
|
||||
# Use system ZeroMQ instead of building from source
|
||||
find_package(PkgConfig REQUIRED)
|
||||
pkg_check_modules(ZMQ REQUIRED libzmq)
|
||||
|
||||
# Add cppzmq headers
|
||||
include_directories(third_party/cppzmq)
|
||||
@@ -29,6 +24,7 @@ set(MSGPACK_USE_BOOST OFF CACHE BOOL "" FORCE)
|
||||
add_compile_definitions(MSGPACK_NO_BOOST)
|
||||
include_directories(third_party/msgpack-c/include)
|
||||
|
||||
# Faiss configuration - streamlined build
|
||||
set(FAISS_ENABLE_PYTHON ON CACHE BOOL "" FORCE)
|
||||
set(FAISS_ENABLE_GPU OFF CACHE BOOL "" FORCE)
|
||||
set(FAISS_ENABLE_EXTRAS OFF CACHE BOOL "" FORCE)
|
||||
@@ -36,4 +32,24 @@ set(BUILD_TESTING OFF CACHE BOOL "" FORCE)
|
||||
set(FAISS_ENABLE_C_API OFF CACHE BOOL "" FORCE)
|
||||
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
|
||||
|
||||
# Disable additional SIMD versions to speed up compilation
|
||||
set(FAISS_ENABLE_AVX2 OFF CACHE BOOL "" FORCE)
|
||||
set(FAISS_ENABLE_AVX512 OFF CACHE BOOL "" FORCE)
|
||||
|
||||
# Additional optimization options from INSTALL.md
|
||||
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "" FORCE)
|
||||
set(BUILD_SHARED_LIBS OFF CACHE BOOL "" FORCE) # Static library is faster to build
|
||||
|
||||
# Avoid building demos and benchmarks
|
||||
set(BUILD_DEMOS OFF CACHE BOOL "" FORCE)
|
||||
set(BUILD_BENCHS OFF CACHE BOOL "" FORCE)
|
||||
|
||||
# NEW: Tell Faiss to only build the generic version
|
||||
set(FAISS_BUILD_GENERIC ON CACHE BOOL "" FORCE)
|
||||
set(FAISS_BUILD_AVX2 OFF CACHE BOOL "" FORCE)
|
||||
set(FAISS_BUILD_AVX512 OFF CACHE BOOL "" FORCE)
|
||||
|
||||
# IMPORTANT: Disable building AVX versions to speed up compilation
|
||||
set(FAISS_BUILD_AVX_VERSIONS OFF CACHE BOOL "" FORCE)
|
||||
|
||||
add_subdirectory(third_party/faiss)
|
||||
@@ -59,10 +59,6 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
||||
if data.dtype != np.float32:
|
||||
data = data.astype(np.float32)
|
||||
|
||||
label_map = {i: str_id for i, str_id in enumerate(ids)}
|
||||
label_map_file = index_dir / "leann.labels.map"
|
||||
with open(label_map_file, "wb") as f:
|
||||
pickle.dump(label_map, f)
|
||||
|
||||
metric_enum = get_metric_map().get(self.distance_metric.lower())
|
||||
if metric_enum is None:
|
||||
@@ -142,13 +138,6 @@ class HNSWSearcher(BaseSearcher):
|
||||
|
||||
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
|
||||
|
||||
# Load label mapping
|
||||
label_map_file = self.index_dir / "leann.labels.map"
|
||||
if not label_map_file.exists():
|
||||
raise FileNotFoundError(f"Label map file not found at {label_map_file}")
|
||||
|
||||
with open(label_map_file, "rb") as f:
|
||||
self.label_map = pickle.load(f)
|
||||
|
||||
def search(
|
||||
self,
|
||||
@@ -239,10 +228,7 @@ class HNSWSearcher(BaseSearcher):
|
||||
)
|
||||
|
||||
string_labels = [
|
||||
[
|
||||
self.label_map.get(int_label, f"unknown_{int_label}")
|
||||
for int_label in batch_labels
|
||||
]
|
||||
[str(int_label) for int_label in batch_labels]
|
||||
for batch_labels in labels
|
||||
]
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -13,5 +13,10 @@ dependencies = ["leann-core==0.1.0", "numpy"]
|
||||
[tool.scikit-build]
|
||||
wheel.packages = ["leann_backend_hnsw"]
|
||||
editable.mode = "redirect"
|
||||
cmake.build-type = "Debug"
|
||||
build.verbose = true
|
||||
cmake.build-type = "Release"
|
||||
build.verbose = true
|
||||
build.tool-args = ["-j8"]
|
||||
|
||||
# CMake definitions to optimize compilation
|
||||
[tool.scikit-build.cmake.define]
|
||||
CMAKE_BUILD_PARALLEL_LEVEL = "8"
|
||||
Submodule packages/leann-backend-hnsw/third_party/faiss updated: 2547df4377...ff22e2c86b
@@ -15,5 +15,8 @@ dependencies = [
|
||||
"tqdm>=4.60.0"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
leann = "leann.cli:main"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
@@ -9,9 +9,6 @@ import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional, Literal
|
||||
from dataclasses import dataclass, field
|
||||
import uuid
|
||||
import torch
|
||||
|
||||
from .registry import BACKEND_REGISTRY
|
||||
from .interface import LeannBackendFactoryInterface
|
||||
from .chat import get_llm
|
||||
@@ -22,7 +19,7 @@ def compute_embeddings(
|
||||
model_name: str,
|
||||
mode: str = "sentence-transformers",
|
||||
use_server: bool = True,
|
||||
use_mlx: bool = False # Backward compatibility: if True, override mode to 'mlx',
|
||||
port: int = 5557,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Computes embeddings using different backends.
|
||||
@@ -39,251 +36,60 @@ def compute_embeddings(
|
||||
Returns:
|
||||
numpy array of embeddings
|
||||
"""
|
||||
# Override mode for backward compatibility
|
||||
if use_mlx:
|
||||
mode = "mlx"
|
||||
|
||||
# Auto-detect mode based on model name if not explicitly set
|
||||
if mode == "sentence-transformers" and model_name.startswith("text-embedding-"):
|
||||
mode = "openai"
|
||||
|
||||
if mode == "mlx":
|
||||
return compute_embeddings_mlx(chunks, model_name, batch_size=16)
|
||||
elif mode == "openai":
|
||||
return compute_embeddings_openai(chunks, model_name)
|
||||
elif mode == "sentence-transformers":
|
||||
return compute_embeddings_sentence_transformers(
|
||||
chunks, model_name, use_server=use_server
|
||||
)
|
||||
if use_server:
|
||||
# Use embedding server (for search/query)
|
||||
return compute_embeddings_via_server(chunks, model_name, port=port)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported embedding mode: {mode}. Supported modes: sentence-transformers, mlx, openai"
|
||||
# Use direct computation (for build_index)
|
||||
from .embedding_compute import (
|
||||
compute_embeddings as compute_embeddings_direct,
|
||||
)
|
||||
|
||||
return compute_embeddings_direct(
|
||||
chunks,
|
||||
model_name,
|
||||
mode=mode,
|
||||
)
|
||||
|
||||
|
||||
def compute_embeddings_sentence_transformers(
|
||||
chunks: List[str], model_name: str, use_server: bool = True
|
||||
def compute_embeddings_via_server(
|
||||
chunks: List[str], model_name: str, port: int
|
||||
) -> np.ndarray:
|
||||
"""Computes embeddings using sentence-transformers.
|
||||
|
||||
Args:
|
||||
chunks: List of text chunks to embed
|
||||
model_name: Name of the sentence transformer model
|
||||
use_server: If True, use embedding server (good for search). If False, use direct computation (good for build).
|
||||
"""
|
||||
if not use_server:
|
||||
print(
|
||||
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)..."
|
||||
)
|
||||
return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
|
||||
|
||||
print(
|
||||
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
|
||||
)
|
||||
import zmq
|
||||
import msgpack
|
||||
import numpy as np
|
||||
|
||||
# Use embedding server for sentence-transformers too
|
||||
# This avoids loading the model twice (once in API, once in server)
|
||||
try:
|
||||
# Import ZMQ client functionality and server manager
|
||||
import zmq
|
||||
import msgpack
|
||||
import numpy as np
|
||||
from .embedding_server_manager import EmbeddingServerManager
|
||||
# Connect to embedding server
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.connect(f"tcp://localhost:{port}")
|
||||
|
||||
# Ensure embedding server is running
|
||||
port = 5557
|
||||
server_manager = EmbeddingServerManager(
|
||||
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
|
||||
)
|
||||
# Send chunks to server for embedding computation
|
||||
request = chunks
|
||||
socket.send(msgpack.packb(request))
|
||||
|
||||
server_started = server_manager.start_server(
|
||||
port=port,
|
||||
model_name=model_name,
|
||||
embedding_mode="sentence-transformers",
|
||||
enable_warmup=False,
|
||||
)
|
||||
# Receive embeddings from server
|
||||
response = socket.recv()
|
||||
embeddings_list = msgpack.unpackb(response)
|
||||
|
||||
if not server_started:
|
||||
raise RuntimeError(f"Failed to start embedding server on port {port}")
|
||||
# Convert back to numpy array
|
||||
embeddings = np.array(embeddings_list, dtype=np.float32)
|
||||
|
||||
# Connect to embedding server
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.connect(f"tcp://localhost:{port}")
|
||||
|
||||
# Send chunks to server for embedding computation
|
||||
request = chunks
|
||||
socket.send(msgpack.packb(request))
|
||||
|
||||
# Receive embeddings from server
|
||||
response = socket.recv()
|
||||
embeddings_list = msgpack.unpackb(response)
|
||||
|
||||
# Convert back to numpy array
|
||||
embeddings = np.array(embeddings_list, dtype=np.float32)
|
||||
|
||||
socket.close()
|
||||
context.term()
|
||||
|
||||
return embeddings
|
||||
|
||||
except Exception as e:
|
||||
# Fallback to direct sentence-transformers if server connection fails
|
||||
print(
|
||||
f"Warning: Failed to connect to embedding server, falling back to direct computation: {e}"
|
||||
)
|
||||
return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
|
||||
|
||||
|
||||
def _compute_embeddings_sentence_transformers_direct(
|
||||
chunks: List[str], model_name: str
|
||||
) -> np.ndarray:
|
||||
"""Direct sentence-transformers computation (fallback)."""
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"sentence-transformers not available. Install with: uv pip install sentence-transformers"
|
||||
) from e
|
||||
|
||||
# Load model using sentence-transformers
|
||||
model = SentenceTransformer(model_name)
|
||||
|
||||
model = model.half()
|
||||
print(
|
||||
f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)..."
|
||||
)
|
||||
# use acclerater GPU or MAC GPU
|
||||
|
||||
if torch.cuda.is_available():
|
||||
model = model.to("cuda")
|
||||
elif torch.backends.mps.is_available():
|
||||
model = model.to("mps")
|
||||
|
||||
# Generate embeddings
|
||||
# give use an warning if OOM here means we need to turn down the batch size
|
||||
embeddings = model.encode(
|
||||
chunks, convert_to_numpy=True, show_progress_bar=True, batch_size=16
|
||||
)
|
||||
socket.close()
|
||||
context.term()
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
def compute_embeddings_openai(chunks: List[str], model_name: str) -> np.ndarray:
|
||||
"""Computes embeddings using OpenAI API."""
|
||||
try:
|
||||
import openai
|
||||
import os
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"openai not available. Install with: uv pip install openai"
|
||||
) from e
|
||||
|
||||
# Get API key from environment
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
raise RuntimeError("OPENAI_API_KEY environment variable not set")
|
||||
|
||||
client = openai.OpenAI(api_key=api_key)
|
||||
|
||||
print(
|
||||
f"INFO: Computing embeddings for {len(chunks)} chunks using OpenAI model '{model_name}'..."
|
||||
)
|
||||
|
||||
# OpenAI has a limit on batch size and input length
|
||||
max_batch_size = 100 # Conservative batch size
|
||||
all_embeddings = []
|
||||
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
total_batches = (len(chunks) + max_batch_size - 1) // max_batch_size
|
||||
batch_range = range(0, len(chunks), max_batch_size)
|
||||
batch_iterator = tqdm(batch_range, desc="Computing embeddings", unit="batch", total=total_batches)
|
||||
except ImportError:
|
||||
# Fallback without progress bar
|
||||
batch_iterator = range(0, len(chunks), max_batch_size)
|
||||
|
||||
for i in batch_iterator:
|
||||
batch_chunks = chunks[i:i + max_batch_size]
|
||||
|
||||
try:
|
||||
response = client.embeddings.create(model=model_name, input=batch_chunks)
|
||||
batch_embeddings = [embedding.embedding for embedding in response.data]
|
||||
all_embeddings.extend(batch_embeddings)
|
||||
except Exception as e:
|
||||
print(f"ERROR: Failed to get embeddings for batch starting at {i}: {e}")
|
||||
raise
|
||||
|
||||
embeddings = np.array(all_embeddings, dtype=np.float32)
|
||||
print(
|
||||
f"INFO: Generated {len(embeddings)} embeddings with dimension {embeddings.shape[1]}"
|
||||
)
|
||||
return embeddings
|
||||
|
||||
|
||||
def compute_embeddings_mlx(chunks: List[str], model_name: str, batch_size: int = 16) -> np.ndarray:
|
||||
"""Computes embeddings using an MLX model."""
|
||||
try:
|
||||
import mlx.core as mx
|
||||
from mlx_lm.utils import load
|
||||
from tqdm import tqdm
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"MLX or related libraries not available. Install with: uv pip install mlx mlx-lm"
|
||||
) from e
|
||||
|
||||
print(
|
||||
f"INFO: Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}' with batch_size={batch_size}..."
|
||||
)
|
||||
|
||||
# Load model and tokenizer
|
||||
model, tokenizer = load(model_name)
|
||||
|
||||
# Process chunks in batches with progress bar
|
||||
all_embeddings = []
|
||||
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
batch_iterator = tqdm(range(0, len(chunks), batch_size), desc="Computing embeddings", unit="batch")
|
||||
except ImportError:
|
||||
batch_iterator = range(0, len(chunks), batch_size)
|
||||
|
||||
for i in batch_iterator:
|
||||
batch_chunks = chunks[i:i + batch_size]
|
||||
|
||||
# Tokenize all chunks in the batch
|
||||
batch_token_ids = []
|
||||
for chunk in batch_chunks:
|
||||
token_ids = tokenizer.encode(chunk) # type: ignore
|
||||
batch_token_ids.append(token_ids)
|
||||
|
||||
# Pad sequences to the same length for batch processing
|
||||
max_length = max(len(ids) for ids in batch_token_ids)
|
||||
padded_token_ids = []
|
||||
for token_ids in batch_token_ids:
|
||||
# Pad with tokenizer.pad_token_id or 0
|
||||
padded = token_ids + [0] * (max_length - len(token_ids))
|
||||
padded_token_ids.append(padded)
|
||||
|
||||
# Convert to MLX array with batch dimension
|
||||
input_ids = mx.array(padded_token_ids)
|
||||
|
||||
# Get embeddings for the batch
|
||||
embeddings = model(input_ids)
|
||||
|
||||
# Mean pooling for each sequence in the batch
|
||||
pooled = embeddings.mean(axis=1) # Shape: (batch_size, hidden_size)
|
||||
|
||||
# Convert batch embeddings to numpy
|
||||
for j in range(len(batch_chunks)):
|
||||
pooled_list = pooled[j].tolist() # Convert to list
|
||||
pooled_numpy = np.array(pooled_list, dtype=np.float32)
|
||||
all_embeddings.append(pooled_numpy)
|
||||
|
||||
# Stack numpy arrays
|
||||
return np.stack(all_embeddings)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchResult:
|
||||
id: str
|
||||
@@ -344,14 +150,12 @@ class LeannBuilder:
|
||||
self.dimensions = dimensions
|
||||
self.embedding_mode = embedding_mode
|
||||
self.backend_kwargs = backend_kwargs
|
||||
if 'mlx' in self.embedding_model:
|
||||
self.embedding_mode = "mlx"
|
||||
self.chunks: List[Dict[str, Any]] = []
|
||||
|
||||
def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
passage_id = metadata.get("id", str(uuid.uuid4()))
|
||||
passage_id = metadata.get("id", str(len(self.chunks)))
|
||||
chunk_data = {"id": passage_id, "text": text, "metadata": metadata}
|
||||
self.chunks.append(chunk_data)
|
||||
|
||||
@@ -377,10 +181,13 @@ class LeannBuilder:
|
||||
with open(passages_file, "w", encoding="utf-8") as f:
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
chunk_iterator = tqdm(self.chunks, desc="Writing passages", unit="chunk")
|
||||
|
||||
chunk_iterator = tqdm(
|
||||
self.chunks, desc="Writing passages", unit="chunk"
|
||||
)
|
||||
except ImportError:
|
||||
chunk_iterator = self.chunks
|
||||
|
||||
|
||||
for chunk in chunk_iterator:
|
||||
offset = f.tell()
|
||||
json.dump(
|
||||
@@ -398,7 +205,11 @@ class LeannBuilder:
|
||||
pickle.dump(offset_map, f)
|
||||
texts_to_embed = [c["text"] for c in self.chunks]
|
||||
embeddings = compute_embeddings(
|
||||
texts_to_embed, self.embedding_model, self.embedding_mode, use_server=False
|
||||
texts_to_embed,
|
||||
self.embedding_model,
|
||||
self.embedding_mode,
|
||||
use_server=False,
|
||||
port=5557,
|
||||
)
|
||||
string_ids = [chunk["id"] for chunk in self.chunks]
|
||||
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
|
||||
|
||||
287
packages/leann-core/src/leann/cli.py
Normal file
287
packages/leann-core/src/leann/cli.py
Normal file
@@ -0,0 +1,287 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import os
|
||||
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
from .api import LeannBuilder, LeannSearcher, LeannChat
|
||||
|
||||
|
||||
class LeannCLI:
|
||||
def __init__(self):
|
||||
self.indexes_dir = Path.home() / ".leann" / "indexes"
|
||||
self.indexes_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.node_parser = SentenceSplitter(
|
||||
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
|
||||
)
|
||||
|
||||
def get_index_path(self, index_name: str) -> str:
|
||||
index_dir = self.indexes_dir / index_name
|
||||
return str(index_dir / "documents.leann")
|
||||
|
||||
def index_exists(self, index_name: str) -> bool:
|
||||
index_dir = self.indexes_dir / index_name
|
||||
meta_file = index_dir / "documents.leann.meta.json"
|
||||
return meta_file.exists()
|
||||
|
||||
def create_parser(self) -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="leann",
|
||||
description="LEANN - Local Enhanced AI Navigation",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
leann build my-docs --docs ./documents # Build index named my-docs
|
||||
leann search my-docs "query" # Search in my-docs index
|
||||
leann ask my-docs "question" # Ask my-docs index
|
||||
leann list # List all stored indexes
|
||||
"""
|
||||
)
|
||||
|
||||
subparsers = parser.add_subparsers(dest="command", help="Available commands")
|
||||
|
||||
# Build command
|
||||
build_parser = subparsers.add_parser("build", help="Build document index")
|
||||
build_parser.add_argument("index_name", help="Index name")
|
||||
build_parser.add_argument("--docs", type=str, required=True, help="Documents directory")
|
||||
build_parser.add_argument("--backend", type=str, default="hnsw", choices=["hnsw", "diskann"])
|
||||
build_parser.add_argument("--embedding-model", type=str, default="facebook/contriever")
|
||||
build_parser.add_argument("--force", "-f", action="store_true", help="Force rebuild")
|
||||
build_parser.add_argument("--graph-degree", type=int, default=32)
|
||||
build_parser.add_argument("--complexity", type=int, default=64)
|
||||
build_parser.add_argument("--num-threads", type=int, default=1)
|
||||
build_parser.add_argument("--compact", action="store_true", default=True)
|
||||
build_parser.add_argument("--recompute", action="store_true", default=True)
|
||||
|
||||
# Search command
|
||||
search_parser = subparsers.add_parser("search", help="Search documents")
|
||||
search_parser.add_argument("index_name", help="Index name")
|
||||
search_parser.add_argument("query", help="Search query")
|
||||
search_parser.add_argument("--top-k", type=int, default=5)
|
||||
search_parser.add_argument("--complexity", type=int, default=64)
|
||||
search_parser.add_argument("--beam-width", type=int, default=1)
|
||||
search_parser.add_argument("--prune-ratio", type=float, default=0.0)
|
||||
search_parser.add_argument("--recompute-embeddings", action="store_true")
|
||||
search_parser.add_argument("--pruning-strategy", choices=["global", "local", "proportional"], default="global")
|
||||
|
||||
# Ask command
|
||||
ask_parser = subparsers.add_parser("ask", help="Ask questions")
|
||||
ask_parser.add_argument("index_name", help="Index name")
|
||||
ask_parser.add_argument("--llm", type=str, default="ollama", choices=["simulated", "ollama", "hf", "openai"])
|
||||
ask_parser.add_argument("--model", type=str, default="qwen3:8b")
|
||||
ask_parser.add_argument("--host", type=str, default="http://localhost:11434")
|
||||
ask_parser.add_argument("--interactive", "-i", action="store_true")
|
||||
ask_parser.add_argument("--top-k", type=int, default=20)
|
||||
ask_parser.add_argument("--complexity", type=int, default=32)
|
||||
ask_parser.add_argument("--beam-width", type=int, default=1)
|
||||
ask_parser.add_argument("--prune-ratio", type=float, default=0.0)
|
||||
ask_parser.add_argument("--recompute-embeddings", action="store_true")
|
||||
ask_parser.add_argument("--pruning-strategy", choices=["global", "local", "proportional"], default="global")
|
||||
|
||||
# List command
|
||||
list_parser = subparsers.add_parser("list", help="List all indexes")
|
||||
|
||||
return parser
|
||||
|
||||
def list_indexes(self):
|
||||
print("Stored LEANN indexes:")
|
||||
|
||||
if not self.indexes_dir.exists():
|
||||
print("No indexes found. Use 'leann build <name> --docs <dir>' to create one.")
|
||||
return
|
||||
|
||||
index_dirs = [d for d in self.indexes_dir.iterdir() if d.is_dir()]
|
||||
|
||||
if not index_dirs:
|
||||
print("No indexes found. Use 'leann build <name> --docs <dir>' to create one.")
|
||||
return
|
||||
|
||||
print(f"Found {len(index_dirs)} indexes:")
|
||||
for i, index_dir in enumerate(index_dirs, 1):
|
||||
index_name = index_dir.name
|
||||
status = "✓" if self.index_exists(index_name) else "✗"
|
||||
|
||||
print(f" {i}. {index_name} [{status}]")
|
||||
if self.index_exists(index_name):
|
||||
meta_file = index_dir / "documents.leann.meta.json"
|
||||
size_mb = sum(f.stat().st_size for f in index_dir.iterdir() if f.is_file()) / (1024 * 1024)
|
||||
print(f" Size: {size_mb:.1f} MB")
|
||||
|
||||
if index_dirs:
|
||||
example_name = index_dirs[0].name
|
||||
print(f"\nUsage:")
|
||||
print(f" leann search {example_name} \"your query\"")
|
||||
print(f" leann ask {example_name} --interactive")
|
||||
|
||||
def load_documents(self, docs_dir: str):
|
||||
print(f"Loading documents from {docs_dir}...")
|
||||
|
||||
documents = SimpleDirectoryReader(
|
||||
docs_dir,
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=[".pdf", ".txt", ".md", ".docx"],
|
||||
).load_data(show_progress=True)
|
||||
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
nodes = self.node_parser.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
print(f"Loaded {len(documents)} documents, {len(all_texts)} chunks")
|
||||
return all_texts
|
||||
|
||||
async def build_index(self, args):
|
||||
docs_dir = args.docs
|
||||
index_name = args.index_name
|
||||
index_dir = self.indexes_dir / index_name
|
||||
index_path = self.get_index_path(index_name)
|
||||
|
||||
if index_dir.exists() and not args.force:
|
||||
print(f"Index '{index_name}' already exists. Use --force to rebuild.")
|
||||
return
|
||||
|
||||
all_texts = self.load_documents(docs_dir)
|
||||
if not all_texts:
|
||||
print("No documents found")
|
||||
return
|
||||
|
||||
index_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"Building index '{index_name}' with {args.backend} backend...")
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name=args.backend,
|
||||
embedding_model=args.embedding_model,
|
||||
graph_degree=args.graph_degree,
|
||||
complexity=args.complexity,
|
||||
is_compact=args.compact,
|
||||
is_recompute=args.recompute,
|
||||
num_threads=args.num_threads,
|
||||
)
|
||||
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"Index built at {index_path}")
|
||||
|
||||
async def search_documents(self, args):
|
||||
index_name = args.index_name
|
||||
query = args.query
|
||||
index_path = self.get_index_path(index_name)
|
||||
|
||||
if not self.index_exists(index_name):
|
||||
print(f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it.")
|
||||
return
|
||||
|
||||
searcher = LeannSearcher(index_path=index_path)
|
||||
results = searcher.search(
|
||||
query,
|
||||
top_k=args.top_k,
|
||||
complexity=args.complexity,
|
||||
beam_width=args.beam_width,
|
||||
prune_ratio=args.prune_ratio,
|
||||
recompute_embeddings=args.recompute_embeddings,
|
||||
pruning_strategy=args.pruning_strategy
|
||||
)
|
||||
|
||||
print(f"Search results for '{query}' (top {len(results)}):")
|
||||
for i, result in enumerate(results, 1):
|
||||
print(f"{i}. Score: {result.score:.3f}")
|
||||
print(f" {result.text[:200]}...")
|
||||
print()
|
||||
|
||||
async def ask_questions(self, args):
|
||||
index_name = args.index_name
|
||||
index_path = self.get_index_path(index_name)
|
||||
|
||||
if not self.index_exists(index_name):
|
||||
print(f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it.")
|
||||
return
|
||||
|
||||
print(f"Starting chat with index '{index_name}'...")
|
||||
print(f"Using {args.model} ({args.llm})")
|
||||
|
||||
llm_config = {"type": args.llm, "model": args.model}
|
||||
if args.llm == "ollama":
|
||||
llm_config["host"] = args.host
|
||||
|
||||
chat = LeannChat(index_path=index_path, llm_config=llm_config)
|
||||
|
||||
if args.interactive:
|
||||
print("LEANN Assistant ready! Type 'quit' to exit")
|
||||
print("=" * 40)
|
||||
|
||||
while True:
|
||||
user_input = input("\nYou: ").strip()
|
||||
if user_input.lower() in ['quit', 'exit', 'q']:
|
||||
print("Goodbye!")
|
||||
break
|
||||
|
||||
if not user_input:
|
||||
continue
|
||||
|
||||
response = chat.ask(
|
||||
user_input,
|
||||
top_k=args.top_k,
|
||||
complexity=args.complexity,
|
||||
beam_width=args.beam_width,
|
||||
prune_ratio=args.prune_ratio,
|
||||
recompute_embeddings=args.recompute_embeddings,
|
||||
pruning_strategy=args.pruning_strategy
|
||||
)
|
||||
print(f"LEANN: {response}")
|
||||
else:
|
||||
query = input("Enter your question: ").strip()
|
||||
if query:
|
||||
response = chat.ask(
|
||||
query,
|
||||
top_k=args.top_k,
|
||||
complexity=args.complexity,
|
||||
beam_width=args.beam_width,
|
||||
prune_ratio=args.prune_ratio,
|
||||
recompute_embeddings=args.recompute_embeddings,
|
||||
pruning_strategy=args.pruning_strategy
|
||||
)
|
||||
print(f"LEANN: {response}")
|
||||
|
||||
async def run(self, args=None):
|
||||
parser = self.create_parser()
|
||||
|
||||
if args is None:
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.command:
|
||||
parser.print_help()
|
||||
return
|
||||
|
||||
if args.command == "list":
|
||||
self.list_indexes()
|
||||
elif args.command == "build":
|
||||
await self.build_index(args)
|
||||
elif args.command == "search":
|
||||
await self.search_documents(args)
|
||||
elif args.command == "ask":
|
||||
await self.ask_questions(args)
|
||||
else:
|
||||
parser.print_help()
|
||||
|
||||
|
||||
def main():
|
||||
import dotenv
|
||||
dotenv.load_dotenv()
|
||||
|
||||
cli = LeannCLI()
|
||||
asyncio.run(cli.run())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
272
packages/leann-core/src/leann/embedding_compute.py
Normal file
272
packages/leann-core/src/leann/embedding_compute.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""
|
||||
Unified embedding computation module
|
||||
Consolidates all embedding computation logic using SentenceTransformer
|
||||
Preserves all optimization parameters to ensure performance
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing import List
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def compute_embeddings(
|
||||
texts: List[str], model_name: str, mode: str = "sentence-transformers"
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Unified embedding computation entry point
|
||||
|
||||
Args:
|
||||
texts: List of texts to compute embeddings for
|
||||
model_name: Model name
|
||||
mode: Computation mode ('sentence-transformers', 'openai', 'mlx')
|
||||
|
||||
Returns:
|
||||
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
||||
"""
|
||||
if mode == "sentence-transformers":
|
||||
return compute_embeddings_sentence_transformers(texts, model_name)
|
||||
elif mode == "openai":
|
||||
return compute_embeddings_openai(texts, model_name)
|
||||
elif mode == "mlx":
|
||||
return compute_embeddings_mlx(texts, model_name)
|
||||
else:
|
||||
raise ValueError(f"Unsupported embedding mode: {mode}")
|
||||
|
||||
|
||||
def compute_embeddings_sentence_transformers(
|
||||
texts: List[str],
|
||||
model_name: str,
|
||||
use_fp16: bool = True,
|
||||
device: str = "auto",
|
||||
batch_size: int = 32,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Compute embeddings using SentenceTransformer
|
||||
Preserves all optimization parameters to ensure consistency with original embedding_server
|
||||
|
||||
Args:
|
||||
texts: List of texts to compute embeddings for
|
||||
model_name: SentenceTransformer model name
|
||||
use_fp16: Whether to use FP16 precision
|
||||
device: Device selection ('auto', 'cuda', 'mps', 'cpu')
|
||||
batch_size: Batch size for processing
|
||||
|
||||
Returns:
|
||||
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
||||
"""
|
||||
print(
|
||||
f"INFO: Computing embeddings for {len(texts)} texts using SentenceTransformer, model: '{model_name}'"
|
||||
)
|
||||
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
# Auto-detect device
|
||||
if device == "auto":
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
device = "mps"
|
||||
else:
|
||||
device = "cpu"
|
||||
|
||||
print(f"INFO: Using device: {device}")
|
||||
|
||||
# Prepare model and tokenizer optimization parameters (consistent with original embedding_server)
|
||||
model_kwargs = {
|
||||
"torch_dtype": torch.float16 if use_fp16 else torch.float32,
|
||||
"low_cpu_mem_usage": True,
|
||||
"_fast_init": True, # Skip weight initialization checks for faster loading
|
||||
}
|
||||
|
||||
tokenizer_kwargs = {
|
||||
"use_fast": True, # Use fast tokenizer for better runtime performance
|
||||
}
|
||||
|
||||
# Load SentenceTransformer (try local first, then network)
|
||||
print(f"INFO: Loading SentenceTransformer model: {model_name}")
|
||||
|
||||
try:
|
||||
# Try local loading (avoid network delays)
|
||||
model_kwargs["local_files_only"] = True
|
||||
tokenizer_kwargs["local_files_only"] = True
|
||||
|
||||
model = SentenceTransformer(
|
||||
model_name,
|
||||
device=device,
|
||||
model_kwargs=model_kwargs,
|
||||
tokenizer_kwargs=tokenizer_kwargs,
|
||||
local_files_only=True,
|
||||
)
|
||||
print("✅ Model loaded successfully! (local + optimized)")
|
||||
except Exception as e:
|
||||
print(f"Local loading failed ({e}), trying network download...")
|
||||
# Fallback to network loading
|
||||
model_kwargs["local_files_only"] = False
|
||||
tokenizer_kwargs["local_files_only"] = False
|
||||
|
||||
model = SentenceTransformer(
|
||||
model_name,
|
||||
device=device,
|
||||
model_kwargs=model_kwargs,
|
||||
tokenizer_kwargs=tokenizer_kwargs,
|
||||
local_files_only=False,
|
||||
)
|
||||
print("✅ Model loaded successfully! (network + optimized)")
|
||||
|
||||
# Apply additional optimizations (if supported)
|
||||
if use_fp16 and device in ["cuda", "mps"]:
|
||||
try:
|
||||
model = model.half()
|
||||
model = torch.compile(model)
|
||||
print(f"✅ Using FP16 precision and compile optimization: {model_name}")
|
||||
except Exception as e:
|
||||
print(
|
||||
f"FP16 or compile optimization failed, continuing with default settings: {e}"
|
||||
)
|
||||
|
||||
# Compute embeddings (using SentenceTransformer's optimized implementation)
|
||||
print("INFO: Starting embedding computation...")
|
||||
|
||||
embeddings = model.encode(
|
||||
texts,
|
||||
batch_size=batch_size,
|
||||
show_progress_bar=False, # Don't show progress bar in server environment
|
||||
convert_to_numpy=True,
|
||||
normalize_embeddings=False, # Keep consistent with original API behavior
|
||||
device=device,
|
||||
)
|
||||
|
||||
print(
|
||||
f"INFO: Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
|
||||
)
|
||||
|
||||
# Validate results
|
||||
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||
raise RuntimeError(
|
||||
f"Detected NaN or Inf values in embeddings, model: {model_name}"
|
||||
)
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
|
||||
"""Compute embeddings using OpenAI API"""
|
||||
try:
|
||||
import openai
|
||||
import os
|
||||
except ImportError as e:
|
||||
raise ImportError(f"OpenAI package not installed: {e}")
|
||||
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
raise RuntimeError("OPENAI_API_KEY environment variable not set")
|
||||
|
||||
client = openai.OpenAI(api_key=api_key)
|
||||
|
||||
print(
|
||||
f"INFO: Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
|
||||
)
|
||||
|
||||
# OpenAI has limits on batch size and input length
|
||||
max_batch_size = 100 # Conservative batch size
|
||||
all_embeddings = []
|
||||
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
|
||||
total_batches = (len(texts) + max_batch_size - 1) // max_batch_size
|
||||
batch_range = range(0, len(texts), max_batch_size)
|
||||
batch_iterator = tqdm(
|
||||
batch_range, desc="Computing embeddings", unit="batch", total=total_batches
|
||||
)
|
||||
except ImportError:
|
||||
# Fallback when tqdm is not available
|
||||
batch_iterator = range(0, len(texts), max_batch_size)
|
||||
|
||||
for i in batch_iterator:
|
||||
batch_texts = texts[i : i + max_batch_size]
|
||||
|
||||
try:
|
||||
response = client.embeddings.create(model=model_name, input=batch_texts)
|
||||
batch_embeddings = [embedding.embedding for embedding in response.data]
|
||||
all_embeddings.extend(batch_embeddings)
|
||||
except Exception as e:
|
||||
print(f"ERROR: Batch {i} failed: {e}")
|
||||
raise
|
||||
|
||||
embeddings = np.array(all_embeddings, dtype=np.float32)
|
||||
print(
|
||||
f"INFO: Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
|
||||
)
|
||||
return embeddings
|
||||
|
||||
|
||||
def compute_embeddings_mlx(
|
||||
chunks: List[str], model_name: str, batch_size: int = 16
|
||||
) -> np.ndarray:
|
||||
"""Computes embeddings using an MLX model."""
|
||||
try:
|
||||
import mlx.core as mx
|
||||
from mlx_lm.utils import load
|
||||
from tqdm import tqdm
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"MLX or related libraries not available. Install with: uv pip install mlx mlx-lm"
|
||||
) from e
|
||||
|
||||
print(
|
||||
f"INFO: Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}' with batch_size={batch_size}..."
|
||||
)
|
||||
|
||||
# Load model and tokenizer
|
||||
model, tokenizer = load(model_name)
|
||||
|
||||
# Process chunks in batches with progress bar
|
||||
all_embeddings = []
|
||||
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
|
||||
batch_iterator = tqdm(
|
||||
range(0, len(chunks), batch_size), desc="Computing embeddings", unit="batch"
|
||||
)
|
||||
except ImportError:
|
||||
batch_iterator = range(0, len(chunks), batch_size)
|
||||
|
||||
for i in batch_iterator:
|
||||
batch_chunks = chunks[i : i + batch_size]
|
||||
|
||||
# Tokenize all chunks in the batch
|
||||
batch_token_ids = []
|
||||
for chunk in batch_chunks:
|
||||
token_ids = tokenizer.encode(chunk) # type: ignore
|
||||
batch_token_ids.append(token_ids)
|
||||
|
||||
# Pad sequences to the same length for batch processing
|
||||
max_length = max(len(ids) for ids in batch_token_ids)
|
||||
padded_token_ids = []
|
||||
for token_ids in batch_token_ids:
|
||||
# Pad with tokenizer.pad_token_id or 0
|
||||
padded = token_ids + [0] * (max_length - len(token_ids))
|
||||
padded_token_ids.append(padded)
|
||||
|
||||
# Convert to MLX array with batch dimension
|
||||
input_ids = mx.array(padded_token_ids)
|
||||
|
||||
# Get embeddings for the batch
|
||||
embeddings = model(input_ids)
|
||||
|
||||
# Mean pooling for each sequence in the batch
|
||||
pooled = embeddings.mean(axis=1) # Shape: (batch_size, hidden_size)
|
||||
|
||||
# Convert batch embeddings to numpy
|
||||
for j in range(len(batch_chunks)):
|
||||
pooled_list = pooled[j].tolist() # Convert to list
|
||||
pooled_numpy = np.array(pooled_list, dtype=np.float32)
|
||||
all_embeddings.append(pooled_numpy)
|
||||
|
||||
# Stack numpy arrays
|
||||
return np.stack(all_embeddings)
|
||||
@@ -4,11 +4,10 @@ import atexit
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import zmq
|
||||
import msgpack
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import select
|
||||
import psutil
|
||||
|
||||
|
||||
def _check_port(port: int) -> bool:
|
||||
@@ -17,151 +16,135 @@ def _check_port(port: int) -> bool:
|
||||
return s.connect_ex(("localhost", port)) == 0
|
||||
|
||||
|
||||
def _check_server_meta_path(port: int, expected_meta_path: str) -> bool:
|
||||
def _check_process_matches_config(
|
||||
port: int, expected_model: str, expected_passages_file: str
|
||||
) -> bool:
|
||||
"""
|
||||
Check if the existing server on the port is using the correct meta file.
|
||||
Returns True if the server has the right meta path, False otherwise.
|
||||
Check if the process using the port matches our expected model and passages file.
|
||||
Returns True if matches, False otherwise.
|
||||
"""
|
||||
try:
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.setsockopt(zmq.RCVTIMEO, 3000) # 3 second timeout
|
||||
socket.connect(f"tcp://localhost:{port}")
|
||||
for proc in psutil.process_iter(["pid", "cmdline"]):
|
||||
if not _is_process_listening_on_port(proc, port):
|
||||
continue
|
||||
|
||||
# Send a special control message to query the server's meta path
|
||||
control_request = ["__QUERY_META_PATH__"]
|
||||
request_bytes = msgpack.packb(control_request)
|
||||
socket.send(request_bytes)
|
||||
cmdline = proc.info["cmdline"]
|
||||
if not cmdline:
|
||||
continue
|
||||
|
||||
# Wait for response
|
||||
response_bytes = socket.recv()
|
||||
response = msgpack.unpackb(response_bytes)
|
||||
|
||||
socket.close()
|
||||
context.term()
|
||||
|
||||
# Check if the response contains the meta path and if it matches
|
||||
if isinstance(response, list) and len(response) > 0:
|
||||
server_meta_path = response[0]
|
||||
# Normalize paths for comparison
|
||||
expected_path = Path(expected_meta_path).resolve()
|
||||
server_path = Path(server_meta_path).resolve() if server_meta_path else None
|
||||
return server_path == expected_path
|
||||
return _check_cmdline_matches_config(
|
||||
cmdline, port, expected_model, expected_passages_file
|
||||
)
|
||||
|
||||
print(f"DEBUG: No process found listening on port {port}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"WARNING: Could not query server meta path on port {port}: {e}")
|
||||
print(f"WARNING: Could not check process on port {port}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _update_server_meta_path(port: int, new_meta_path: str) -> bool:
|
||||
"""
|
||||
Send a control message to update the server's meta path.
|
||||
Returns True if successful, False otherwise.
|
||||
"""
|
||||
def _is_process_listening_on_port(proc, port: int) -> bool:
|
||||
"""Check if a process is listening on the given port."""
|
||||
try:
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.setsockopt(zmq.RCVTIMEO, 5000) # 5 second timeout
|
||||
socket.connect(f"tcp://localhost:{port}")
|
||||
|
||||
# Send a control message to update the meta path
|
||||
control_request = ["__UPDATE_META_PATH__", new_meta_path]
|
||||
request_bytes = msgpack.packb(control_request)
|
||||
socket.send(request_bytes)
|
||||
|
||||
# Wait for response
|
||||
response_bytes = socket.recv()
|
||||
response = msgpack.unpackb(response_bytes)
|
||||
|
||||
socket.close()
|
||||
context.term()
|
||||
|
||||
# Check if the update was successful
|
||||
if isinstance(response, list) and len(response) > 0:
|
||||
return response[0] == "SUCCESS"
|
||||
|
||||
connections = proc.net_connections()
|
||||
for conn in connections:
|
||||
if conn.laddr.port == port and conn.status == psutil.CONN_LISTEN:
|
||||
return True
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"ERROR: Could not update server meta path on port {port}: {e}")
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
|
||||
return False
|
||||
|
||||
|
||||
def _check_server_model(port: int, expected_model: str) -> bool:
|
||||
def _check_cmdline_matches_config(
|
||||
cmdline: list, port: int, expected_model: str, expected_passages_file: str
|
||||
) -> bool:
|
||||
"""Check if command line matches our expected configuration."""
|
||||
cmdline_str = " ".join(cmdline)
|
||||
print(f"DEBUG: Found process on port {port}: {cmdline_str}")
|
||||
|
||||
# Check if it's our embedding server
|
||||
is_embedding_server = any(
|
||||
server_type in cmdline_str
|
||||
for server_type in [
|
||||
"embedding_server",
|
||||
"leann_backend_diskann.embedding_server",
|
||||
"leann_backend_hnsw.hnsw_embedding_server",
|
||||
]
|
||||
)
|
||||
|
||||
if not is_embedding_server:
|
||||
print(f"DEBUG: Process on port {port} is not our embedding server")
|
||||
return False
|
||||
|
||||
# Check model name
|
||||
model_matches = _check_model_in_cmdline(cmdline, expected_model)
|
||||
|
||||
# Check passages file if provided
|
||||
passages_matches = _check_passages_in_cmdline(cmdline, expected_passages_file)
|
||||
|
||||
result = model_matches and passages_matches
|
||||
print(
|
||||
f"DEBUG: model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}"
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _check_model_in_cmdline(cmdline: list, expected_model: str) -> bool:
|
||||
"""Check if the command line contains the expected model."""
|
||||
if "--model-name" not in cmdline:
|
||||
return False
|
||||
|
||||
model_idx = cmdline.index("--model-name")
|
||||
if model_idx + 1 >= len(cmdline):
|
||||
return False
|
||||
|
||||
actual_model = cmdline[model_idx + 1]
|
||||
return actual_model == expected_model
|
||||
|
||||
|
||||
def _check_passages_in_cmdline(cmdline: list, expected_passages_file: str) -> bool:
|
||||
"""Check if the command line contains the expected passages file."""
|
||||
if "--passages-file" not in cmdline:
|
||||
return False # Expected but not found
|
||||
|
||||
passages_idx = cmdline.index("--passages-file")
|
||||
if passages_idx + 1 >= len(cmdline):
|
||||
return False
|
||||
|
||||
actual_passages = cmdline[passages_idx + 1]
|
||||
expected_path = Path(expected_passages_file).resolve()
|
||||
actual_path = Path(actual_passages).resolve()
|
||||
return actual_path == expected_path
|
||||
|
||||
|
||||
def _find_compatible_port_or_next_available(
|
||||
start_port: int, model_name: str, passages_file: str, max_attempts: int = 100
|
||||
) -> tuple[int, bool]:
|
||||
"""
|
||||
Check if the existing server on the port is using the correct embedding model.
|
||||
Returns True if the server has the right model, False otherwise.
|
||||
Find a port that either has a compatible server or is available.
|
||||
Returns (port, is_compatible) where is_compatible indicates if we found a matching server.
|
||||
"""
|
||||
try:
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.setsockopt(zmq.RCVTIMEO, 3000) # 3 second timeout
|
||||
socket.connect(f"tcp://localhost:{port}")
|
||||
for port in range(start_port, start_port + max_attempts):
|
||||
if not _check_port(port):
|
||||
# Port is available
|
||||
return port, False
|
||||
|
||||
# Send a special control message to query the server's model
|
||||
control_request = ["__QUERY_MODEL__"]
|
||||
request_bytes = msgpack.packb(control_request)
|
||||
socket.send(request_bytes)
|
||||
# Port is in use, check if it's compatible
|
||||
if _check_process_matches_config(port, model_name, passages_file):
|
||||
print(f"✅ Found compatible server on port {port}")
|
||||
return port, True
|
||||
else:
|
||||
print(f"⚠️ Port {port} has incompatible server, trying next port...")
|
||||
|
||||
# Wait for response
|
||||
response_bytes = socket.recv()
|
||||
response = msgpack.unpackb(response_bytes)
|
||||
|
||||
socket.close()
|
||||
context.term()
|
||||
|
||||
# Check if the response contains the model name and if it matches
|
||||
if isinstance(response, list) and len(response) > 0:
|
||||
server_model = response[0]
|
||||
return server_model == expected_model
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"WARNING: Could not query server model on port {port}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _update_server_model(port: int, new_model: str) -> bool:
|
||||
"""
|
||||
Send a control message to update the server's embedding model.
|
||||
Returns True if successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.setsockopt(zmq.RCVTIMEO, 30000) # 30 second timeout for model loading
|
||||
socket.setsockopt(zmq.SNDTIMEO, 5000) # 5 second timeout for sending
|
||||
socket.connect(f"tcp://localhost:{port}")
|
||||
|
||||
# Send a control message to update the model
|
||||
control_request = ["__UPDATE_MODEL__", new_model]
|
||||
request_bytes = msgpack.packb(control_request)
|
||||
socket.send(request_bytes)
|
||||
|
||||
# Wait for response
|
||||
response_bytes = socket.recv()
|
||||
response = msgpack.unpackb(response_bytes)
|
||||
|
||||
socket.close()
|
||||
context.term()
|
||||
|
||||
# Check if the update was successful
|
||||
if isinstance(response, list) and len(response) > 0:
|
||||
return response[0] == "SUCCESS"
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"ERROR: Could not update server model on port {port}: {e}")
|
||||
return False
|
||||
raise RuntimeError(
|
||||
f"Could not find compatible or available port in range {start_port}-{start_port + max_attempts}"
|
||||
)
|
||||
|
||||
|
||||
class EmbeddingServerManager:
|
||||
"""
|
||||
A generic manager for handling the lifecycle of a backend-specific embedding server process.
|
||||
A simplified manager for embedding server processes that avoids complex update mechanisms.
|
||||
"""
|
||||
|
||||
def __init__(self, backend_module_name: str):
|
||||
@@ -175,210 +158,162 @@ class EmbeddingServerManager:
|
||||
self.backend_module_name = backend_module_name
|
||||
self.server_process: Optional[subprocess.Popen] = None
|
||||
self.server_port: Optional[int] = None
|
||||
atexit.register(self.stop_server)
|
||||
self._atexit_registered = False
|
||||
|
||||
def start_server(self, port: int, model_name: str, embedding_mode: str = "sentence-transformers", **kwargs) -> bool:
|
||||
def start_server(
|
||||
self,
|
||||
port: int,
|
||||
model_name: str,
|
||||
embedding_mode: str = "sentence-transformers",
|
||||
**kwargs,
|
||||
) -> tuple[bool, int]:
|
||||
"""
|
||||
Starts the embedding server process.
|
||||
|
||||
Args:
|
||||
port (int): The ZMQ port for the server.
|
||||
port (int): The preferred ZMQ port for the server.
|
||||
model_name (str): The name of the embedding model to use.
|
||||
**kwargs: Additional arguments for the server (e.g., passages_file, distance_metric, enable_warmup).
|
||||
**kwargs: Additional arguments for the server.
|
||||
|
||||
Returns:
|
||||
bool: True if the server is started successfully or already running, False otherwise.
|
||||
tuple[bool, int]: (success, actual_port_used)
|
||||
"""
|
||||
if self.server_process and self.server_process.poll() is None:
|
||||
# Even if we have a running process, check if model/meta path match
|
||||
if self.server_port is not None:
|
||||
port_in_use = _check_port(self.server_port)
|
||||
if port_in_use:
|
||||
print(
|
||||
f"INFO: Checking compatibility of existing server process (PID {self.server_process.pid})"
|
||||
)
|
||||
passages_file = kwargs.get("passages_file")
|
||||
assert isinstance(passages_file, str), "passages_file must be a string"
|
||||
|
||||
# Check model compatibility
|
||||
model_matches = _check_server_model(self.server_port, model_name)
|
||||
if model_matches:
|
||||
print(
|
||||
f"✅ Existing server already using correct model: {model_name}"
|
||||
)
|
||||
|
||||
# Still check meta path if provided
|
||||
passages_file = kwargs.get("passages_file")
|
||||
if passages_file and str(passages_file).endswith(
|
||||
".meta.json"
|
||||
):
|
||||
meta_matches = _check_server_meta_path(
|
||||
self.server_port, str(passages_file)
|
||||
)
|
||||
if not meta_matches:
|
||||
print("⚠️ Updating meta path to: {passages_file}")
|
||||
_update_server_meta_path(
|
||||
self.server_port, str(passages_file)
|
||||
)
|
||||
|
||||
return True
|
||||
else:
|
||||
print(
|
||||
f"⚠️ Existing server has different model. Attempting to update to: {model_name}"
|
||||
)
|
||||
if not _update_server_model(self.server_port, model_name):
|
||||
print(
|
||||
"❌ Failed to update existing server model. Restarting server..."
|
||||
)
|
||||
self.stop_server()
|
||||
# Continue to start new server below
|
||||
else:
|
||||
print(
|
||||
f"✅ Successfully updated existing server model to: {model_name}"
|
||||
)
|
||||
# Check if we have a compatible running server
|
||||
if self._has_compatible_running_server(model_name, passages_file):
|
||||
assert self.server_port is not None, (
|
||||
"a compatible running server should set server_port"
|
||||
)
|
||||
return True, self.server_port
|
||||
|
||||
# Also check meta path if provided
|
||||
passages_file = kwargs.get("passages_file")
|
||||
if passages_file and str(passages_file).endswith(
|
||||
".meta.json"
|
||||
):
|
||||
meta_matches = _check_server_meta_path(
|
||||
self.server_port, str(passages_file)
|
||||
)
|
||||
if not meta_matches:
|
||||
print("⚠️ Updating meta path to: {passages_file}")
|
||||
_update_server_meta_path(
|
||||
self.server_port, str(passages_file)
|
||||
)
|
||||
# Find available port (compatible or free)
|
||||
try:
|
||||
actual_port, is_compatible = _find_compatible_port_or_next_available(
|
||||
port, model_name, passages_file
|
||||
)
|
||||
except RuntimeError as e:
|
||||
print(f"❌ {e}")
|
||||
return False, port
|
||||
|
||||
return True
|
||||
else:
|
||||
# Server process exists but port not responding - restart
|
||||
print("⚠️ Server process exists but not responding. Restarting...")
|
||||
self.stop_server()
|
||||
# Continue to start new server below
|
||||
else:
|
||||
# No port stored - restart
|
||||
print("⚠️ No port information stored. Restarting server...")
|
||||
self.stop_server()
|
||||
# Continue to start new server below
|
||||
if is_compatible:
|
||||
print(f"✅ Using existing compatible server on port {actual_port}")
|
||||
self.server_port = actual_port
|
||||
self.server_process = None # We don't own this process
|
||||
return True, actual_port
|
||||
|
||||
if _check_port(port):
|
||||
# Port is in use, check if it's using the correct meta file and model
|
||||
passages_file = kwargs.get("passages_file")
|
||||
if actual_port != port:
|
||||
print(f"⚠️ Using port {actual_port} instead of {port}")
|
||||
|
||||
print(f"INFO: Port {port} is in use. Checking server compatibility...")
|
||||
# Start new server
|
||||
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
|
||||
|
||||
# Check model compatibility first
|
||||
model_matches = _check_server_model(port, model_name)
|
||||
if model_matches:
|
||||
print(
|
||||
f"✅ Existing server on port {port} is using correct model: {model_name}"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"⚠️ Existing server on port {port} has different model. Attempting to update to: {model_name}"
|
||||
)
|
||||
if not _update_server_model(port, model_name):
|
||||
raise RuntimeError(
|
||||
f"❌ Failed to update server model to {model_name}. Consider using a different port."
|
||||
)
|
||||
print(f"✅ Successfully updated server model to: {model_name}")
|
||||
def _has_compatible_running_server(
|
||||
self, model_name: str, passages_file: str
|
||||
) -> bool:
|
||||
"""Check if we have a compatible running server."""
|
||||
if not (
|
||||
self.server_process
|
||||
and self.server_process.poll() is None
|
||||
and self.server_port
|
||||
):
|
||||
return False
|
||||
|
||||
# Check meta path compatibility if provided
|
||||
if passages_file and str(passages_file).endswith(".meta.json"):
|
||||
meta_matches = _check_server_meta_path(port, str(passages_file))
|
||||
if not meta_matches:
|
||||
print(
|
||||
f"⚠️ Existing server on port {port} has different meta path. Attempting to update..."
|
||||
)
|
||||
if not _update_server_meta_path(port, str(passages_file)):
|
||||
raise RuntimeError(
|
||||
"❌ Failed to update server meta path. This may cause data synchronization issues."
|
||||
)
|
||||
print(
|
||||
f"✅ Successfully updated server meta path to: {passages_file}"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"✅ Existing server on port {port} is using correct meta path: {passages_file}"
|
||||
)
|
||||
|
||||
print(f"✅ Server on port {port} is compatible and ready to use.")
|
||||
if _check_process_matches_config(self.server_port, model_name, passages_file):
|
||||
print(
|
||||
f"✅ Existing server process (PID {self.server_process.pid}) is compatible"
|
||||
)
|
||||
return True
|
||||
|
||||
print(
|
||||
f"INFO: Starting session-level embedding server for '{self.backend_module_name}'..."
|
||||
)
|
||||
print("⚠️ Existing server process is incompatible. Should start a new server.")
|
||||
return False
|
||||
|
||||
def _start_new_server(
|
||||
self, port: int, model_name: str, embedding_mode: str, **kwargs
|
||||
) -> tuple[bool, int]:
|
||||
"""Start a new embedding server on the given port."""
|
||||
print(f"INFO: Starting embedding server on port {port}...")
|
||||
|
||||
command = self._build_server_command(port, model_name, embedding_mode, **kwargs)
|
||||
|
||||
try:
|
||||
command = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
self.backend_module_name,
|
||||
"--zmq-port",
|
||||
str(port),
|
||||
"--model-name",
|
||||
model_name,
|
||||
]
|
||||
|
||||
# Add extra arguments for specific backends
|
||||
if "passages_file" in kwargs and kwargs["passages_file"]:
|
||||
command.extend(["--passages-file", str(kwargs["passages_file"])])
|
||||
# if "distance_metric" in kwargs and kwargs["distance_metric"]:
|
||||
# command.extend(["--distance-metric", kwargs["distance_metric"]])
|
||||
if embedding_mode != "sentence-transformers":
|
||||
command.extend(["--embedding-mode", embedding_mode])
|
||||
if "enable_warmup" in kwargs and not kwargs["enable_warmup"]:
|
||||
command.extend(["--disable-warmup"])
|
||||
|
||||
project_root = Path(__file__).parent.parent.parent.parent.parent
|
||||
print(f"INFO: Running command from project root: {project_root}")
|
||||
print(f"INFO: Command: {' '.join(command)}") # Debug: show actual command
|
||||
|
||||
self.server_process = subprocess.Popen(
|
||||
command,
|
||||
cwd=project_root,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT, # Merge stderr into stdout for easier monitoring
|
||||
text=True,
|
||||
encoding="utf-8",
|
||||
bufsize=1, # Line buffered
|
||||
universal_newlines=True,
|
||||
)
|
||||
self.server_port = port
|
||||
print(f"INFO: Server process started with PID: {self.server_process.pid}")
|
||||
|
||||
max_wait, wait_interval = 120, 0.5
|
||||
for _ in range(int(max_wait / wait_interval)):
|
||||
if _check_port(port):
|
||||
print("✅ Embedding server is up and ready for this session.")
|
||||
log_thread = threading.Thread(target=self._log_monitor, daemon=True)
|
||||
log_thread.start()
|
||||
return True
|
||||
if self.server_process.poll() is not None:
|
||||
print(
|
||||
"❌ ERROR: Server process terminated unexpectedly during startup."
|
||||
)
|
||||
self._print_recent_output()
|
||||
return False
|
||||
time.sleep(wait_interval)
|
||||
|
||||
print(
|
||||
f"❌ ERROR: Server process failed to start listening within {max_wait} seconds."
|
||||
)
|
||||
self.stop_server()
|
||||
return False
|
||||
|
||||
self._launch_server_process(command, port)
|
||||
return self._wait_for_server_ready(port)
|
||||
except Exception as e:
|
||||
print(f"❌ ERROR: Failed to start embedding server process: {e}")
|
||||
return False
|
||||
print(f"❌ ERROR: Failed to start embedding server: {e}")
|
||||
return False, port
|
||||
|
||||
def _build_server_command(
|
||||
self, port: int, model_name: str, embedding_mode: str, **kwargs
|
||||
) -> list:
|
||||
"""Build the command to start the embedding server."""
|
||||
command = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
self.backend_module_name,
|
||||
"--zmq-port",
|
||||
str(port),
|
||||
"--model-name",
|
||||
model_name,
|
||||
]
|
||||
|
||||
if kwargs.get("passages_file"):
|
||||
command.extend(["--passages-file", str(kwargs["passages_file"])])
|
||||
if embedding_mode != "sentence-transformers":
|
||||
command.extend(["--embedding-mode", embedding_mode])
|
||||
|
||||
return command
|
||||
|
||||
def _launch_server_process(self, command: list, port: int) -> None:
|
||||
"""Launch the server process."""
|
||||
project_root = Path(__file__).parent.parent.parent.parent.parent
|
||||
print(f"INFO: Command: {' '.join(command)}")
|
||||
|
||||
self.server_process = subprocess.Popen(
|
||||
command,
|
||||
cwd=project_root,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
encoding="utf-8",
|
||||
bufsize=1,
|
||||
universal_newlines=True,
|
||||
)
|
||||
self.server_port = port
|
||||
print(f"INFO: Server process started with PID: {self.server_process.pid}")
|
||||
|
||||
# Register atexit callback only when we actually start a process
|
||||
if not self._atexit_registered:
|
||||
# Use a lambda to avoid issues with bound methods
|
||||
atexit.register(lambda: self.stop_server() if self.server_process else None)
|
||||
self._atexit_registered = True
|
||||
|
||||
def _wait_for_server_ready(self, port: int) -> tuple[bool, int]:
|
||||
"""Wait for the server to be ready."""
|
||||
max_wait, wait_interval = 120, 0.5
|
||||
for _ in range(int(max_wait / wait_interval)):
|
||||
if _check_port(port):
|
||||
print("✅ Embedding server is ready!")
|
||||
threading.Thread(target=self._log_monitor, daemon=True).start()
|
||||
return True, port
|
||||
|
||||
if self.server_process.poll() is not None:
|
||||
print("❌ ERROR: Server terminated during startup.")
|
||||
self._print_recent_output()
|
||||
return False, port
|
||||
|
||||
time.sleep(wait_interval)
|
||||
|
||||
print(f"❌ ERROR: Server failed to start within {max_wait} seconds.")
|
||||
self.stop_server()
|
||||
return False, port
|
||||
|
||||
def _print_recent_output(self):
|
||||
"""Print any recent output from the server process."""
|
||||
if not self.server_process or not self.server_process.stdout:
|
||||
return
|
||||
try:
|
||||
# Read any available output
|
||||
|
||||
if select.select([self.server_process.stdout], [], [], 0)[0]:
|
||||
output = self.server_process.stdout.read()
|
||||
if output:
|
||||
@@ -404,17 +339,26 @@ class EmbeddingServerManager:
|
||||
|
||||
def stop_server(self):
|
||||
"""Stops the embedding server process if it's running."""
|
||||
if self.server_process and self.server_process.poll() is None:
|
||||
if not self.server_process:
|
||||
return
|
||||
|
||||
if self.server_process.poll() is not None:
|
||||
# Process already terminated
|
||||
self.server_process = None
|
||||
return
|
||||
|
||||
print(
|
||||
f"INFO: Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
|
||||
)
|
||||
self.server_process.terminate()
|
||||
|
||||
try:
|
||||
self.server_process.wait(timeout=5)
|
||||
print(f"INFO: Server process {self.server_process.pid} terminated.")
|
||||
except subprocess.TimeoutExpired:
|
||||
print(
|
||||
f"INFO: Terminating session server process (PID: {self.server_process.pid})..."
|
||||
f"WARNING: Server process {self.server_process.pid} did not terminate gracefully, killing it."
|
||||
)
|
||||
self.server_process.terminate()
|
||||
try:
|
||||
self.server_process.wait(timeout=5)
|
||||
print("INFO: Server process terminated.")
|
||||
except subprocess.TimeoutExpired:
|
||||
print(
|
||||
"WARNING: Server process did not terminate gracefully, killing it."
|
||||
)
|
||||
self.server_process.kill()
|
||||
self.server_process.kill()
|
||||
|
||||
self.server_process = None
|
||||
|
||||
@@ -7,30 +7,37 @@ import importlib.metadata
|
||||
if TYPE_CHECKING:
|
||||
from leann.interface import LeannBackendFactoryInterface
|
||||
|
||||
BACKEND_REGISTRY: Dict[str, 'LeannBackendFactoryInterface'] = {}
|
||||
BACKEND_REGISTRY: Dict[str, "LeannBackendFactoryInterface"] = {}
|
||||
|
||||
|
||||
def register_backend(name: str):
|
||||
"""A decorator to register a new backend class."""
|
||||
|
||||
def decorator(cls):
|
||||
print(f"INFO: Registering backend '{name}'")
|
||||
BACKEND_REGISTRY[name] = cls
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def autodiscover_backends():
|
||||
"""Automatically discovers and imports all 'leann-backend-*' packages."""
|
||||
print("INFO: Starting backend auto-discovery...")
|
||||
# print("INFO: Starting backend auto-discovery...")
|
||||
discovered_backends = []
|
||||
for dist in importlib.metadata.distributions():
|
||||
dist_name = dist.metadata['name']
|
||||
if dist_name.startswith('leann-backend-'):
|
||||
backend_module_name = dist_name.replace('-', '_')
|
||||
dist_name = dist.metadata["name"]
|
||||
if dist_name.startswith("leann-backend-"):
|
||||
backend_module_name = dist_name.replace("-", "_")
|
||||
discovered_backends.append(backend_module_name)
|
||||
|
||||
for backend_module_name in sorted(discovered_backends): # sort for deterministic loading
|
||||
|
||||
for backend_module_name in sorted(
|
||||
discovered_backends
|
||||
): # sort for deterministic loading
|
||||
try:
|
||||
importlib.import_module(backend_module_name)
|
||||
# Registration message is printed by the decorator
|
||||
except ImportError as e:
|
||||
print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
|
||||
print("INFO: Backend auto-discovery finished.")
|
||||
# print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
|
||||
pass
|
||||
# print("INFO: Backend auto-discovery finished.")
|
||||
|
||||
@@ -43,8 +43,6 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
"WARNING: embedding_model not found in meta.json. Recompute will fail."
|
||||
)
|
||||
|
||||
self.label_map = self._load_label_map()
|
||||
|
||||
self.embedding_server_manager = EmbeddingServerManager(
|
||||
backend_module_name=backend_module_name
|
||||
)
|
||||
@@ -58,17 +56,9 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
with open(meta_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
def _load_label_map(self) -> Dict[int, str]:
|
||||
"""Loads the mapping from integer IDs to string IDs."""
|
||||
label_map_file = self.index_dir / "leann.labels.map"
|
||||
if not label_map_file.exists():
|
||||
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
|
||||
with open(label_map_file, "rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
def _ensure_server_running(
|
||||
self, passages_source_file: str, port: int, **kwargs
|
||||
) -> None:
|
||||
) -> int:
|
||||
"""
|
||||
Ensures the embedding server is running if recompute is needed.
|
||||
This is a helper for subclasses.
|
||||
@@ -79,8 +69,8 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
)
|
||||
|
||||
embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
|
||||
|
||||
server_started = self.embedding_server_manager.start_server(
|
||||
|
||||
server_started, actual_port = self.embedding_server_manager.start_server(
|
||||
port=port,
|
||||
model_name=self.embedding_model,
|
||||
passages_file=passages_source_file,
|
||||
@@ -89,7 +79,11 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
enable_warmup=kwargs.get("enable_warmup", False),
|
||||
)
|
||||
if not server_started:
|
||||
raise RuntimeError(f"Failed to start embedding server on port {port}")
|
||||
raise RuntimeError(
|
||||
f"Failed to start embedding server on port {actual_port}"
|
||||
)
|
||||
|
||||
return actual_port
|
||||
|
||||
def compute_query_embedding(
|
||||
self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True
|
||||
@@ -106,12 +100,16 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
Query embedding as numpy array
|
||||
"""
|
||||
# Try to use embedding server if available and requested
|
||||
if (
|
||||
use_server_if_available
|
||||
and self.embedding_server_manager
|
||||
and self.embedding_server_manager.server_process
|
||||
):
|
||||
if use_server_if_available:
|
||||
try:
|
||||
# Ensure we have a server with passages_file for compatibility
|
||||
passages_source_file = (
|
||||
self.index_dir / f"{self.index_path.name}.meta.json"
|
||||
)
|
||||
zmq_port = self._ensure_server_running(
|
||||
str(passages_source_file), zmq_port
|
||||
)
|
||||
|
||||
return self._compute_embedding_via_server([query], zmq_port)[
|
||||
0:1
|
||||
] # Return (1, D) shape
|
||||
|
||||
@@ -9,7 +9,6 @@ requires-python = ">=3.10"
|
||||
|
||||
dependencies = [
|
||||
"leann-core",
|
||||
"leann-backend-diskann",
|
||||
"leann-backend-hnsw",
|
||||
"numpy>=1.26.0",
|
||||
"torch",
|
||||
@@ -36,6 +35,7 @@ dependencies = [
|
||||
"llama-index-embeddings-huggingface>=0.5.5",
|
||||
"mlx>=0.26.3",
|
||||
"mlx-lm>=0.26.0",
|
||||
"psutil>=5.8.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
@@ -48,6 +48,10 @@ dev = [
|
||||
"huggingface-hub>=0.20.0",
|
||||
]
|
||||
|
||||
diskann = [
|
||||
"leann-backend-diskann",
|
||||
]
|
||||
|
||||
[tool.setuptools]
|
||||
py-modules = []
|
||||
|
||||
|
||||
16
uv.lock
generated
16
uv.lock
generated
@@ -1834,10 +1834,14 @@ source = { editable = "packages/leann-core" }
|
||||
dependencies = [
|
||||
{ name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
|
||||
{ name = "numpy", version = "2.3.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
|
||||
{ name = "tqdm" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [{ name = "numpy", specifier = ">=1.20.0" }]
|
||||
requires-dist = [
|
||||
{ name = "numpy", specifier = ">=1.20.0" },
|
||||
{ name = "tqdm", specifier = ">=4.60.0" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "leann-workspace"
|
||||
@@ -1851,7 +1855,6 @@ dependencies = [
|
||||
{ name = "flask" },
|
||||
{ name = "flask-compress" },
|
||||
{ name = "ipykernel" },
|
||||
{ name = "leann-backend-diskann" },
|
||||
{ name = "leann-backend-hnsw" },
|
||||
{ name = "leann-core" },
|
||||
{ name = "llama-index" },
|
||||
@@ -1867,6 +1870,7 @@ dependencies = [
|
||||
{ name = "ollama" },
|
||||
{ name = "openai" },
|
||||
{ name = "protobuf" },
|
||||
{ name = "psutil" },
|
||||
{ name = "pypdf2" },
|
||||
{ name = "requests" },
|
||||
{ name = "sentence-transformers" },
|
||||
@@ -1884,6 +1888,9 @@ dev = [
|
||||
{ name = "pytest-cov" },
|
||||
{ name = "ruff" },
|
||||
]
|
||||
diskann = [
|
||||
{ name = "leann-backend-diskann" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
@@ -1896,7 +1903,7 @@ requires-dist = [
|
||||
{ name = "flask-compress" },
|
||||
{ name = "huggingface-hub", marker = "extra == 'dev'", specifier = ">=0.20.0" },
|
||||
{ name = "ipykernel", specifier = "==6.29.5" },
|
||||
{ name = "leann-backend-diskann", editable = "packages/leann-backend-diskann" },
|
||||
{ name = "leann-backend-diskann", marker = "extra == 'diskann'", editable = "packages/leann-backend-diskann" },
|
||||
{ name = "leann-backend-hnsw", editable = "packages/leann-backend-hnsw" },
|
||||
{ name = "leann-core", editable = "packages/leann-core" },
|
||||
{ name = "llama-index", specifier = ">=0.12.44" },
|
||||
@@ -1912,6 +1919,7 @@ requires-dist = [
|
||||
{ name = "ollama" },
|
||||
{ name = "openai", specifier = ">=1.0.0" },
|
||||
{ name = "protobuf", specifier = "==4.25.3" },
|
||||
{ name = "psutil", specifier = ">=5.8.0" },
|
||||
{ name = "pypdf2", specifier = ">=3.0.0" },
|
||||
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0" },
|
||||
{ name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.0" },
|
||||
@@ -1922,7 +1930,7 @@ requires-dist = [
|
||||
{ name = "torch" },
|
||||
{ name = "tqdm" },
|
||||
]
|
||||
provides-extras = ["dev"]
|
||||
provides-extras = ["dev", "diskann"]
|
||||
|
||||
[[package]]
|
||||
name = "llama-cloud"
|
||||
|
||||
Reference in New Issue
Block a user