Compare commits

22 commits

| Author | SHA1 | Date |
|---|---|---|
| | 08eac5c821 | |
| | 4671ed9b36 | |
| | 055c086398 | |
| | d505dcc5e3 | |
| | 261006c36a | |
| | b2eba23e21 | |
| | e9ee687472 | |
| | 6f5d5e4a77 | |
| | 5c8921673a | |
| | e9d2d420bd | |
| | ebabfad066 | |
| | e6f612b5e8 | |
| | 51c41acd82 | |
| | 455f93fb7c | |
| | 48207c3b69 | |
| | 4de1caa40f | |
| | 60eaa8165c | |
| | c1a5d0c624 | |
| | af1790395a | |
| | 383c6d8d7e | |
| | bc0d839693 | |
| | 8596562de5 | |
.github/workflows/build-and-publish.yml (vendored) — 2 lines changed

```diff
@@ -8,4 +8,4 @@ on:
 jobs:
   build:
     uses: ./.github/workflows/build-reusable.yml
```
.github/workflows/build-reusable.yml (vendored) — 107 lines changed

```diff
@@ -17,23 +17,23 @@ jobs:
       - uses: actions/checkout@v4
         with:
           ref: ${{ inputs.ref }}

       - name: Setup Python
         uses: actions/setup-python@v5
         with:
           python-version: '3.11'

       - name: Install uv
         uses: astral-sh/setup-uv@v4

       - name: Install ruff
         run: |
           uv tool install ruff

       - name: Run ruff check
         run: |
           ruff check .

       - name: Run ruff format check
         run: |
           ruff format --check .
@@ -65,40 +65,41 @@ jobs:
           - os: macos-latest
             python: '3.13'
     runs-on: ${{ matrix.os }}

     steps:
       - uses: actions/checkout@v4
         with:
           ref: ${{ inputs.ref }}
           submodules: recursive

       - name: Setup Python
         uses: actions/setup-python@v5
         with:
           python-version: ${{ matrix.python }}

       - name: Install uv
         uses: astral-sh/setup-uv@v4

       - name: Install system dependencies (Ubuntu)
         if: runner.os == 'Linux'
         run: |
           sudo apt-get update
           sudo apt-get install -y libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
             pkg-config libopenblas-dev patchelf libabsl-dev libaio-dev libprotobuf-dev

           # Install Intel MKL for DiskANN
           wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
           sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
           source /opt/intel/oneapi/setvars.sh
           echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
           echo "LD_LIBRARY_PATH=/opt/intel/oneapi/mkl/latest/lib/intel64:$LD_LIBRARY_PATH" >> $GITHUB_ENV

       - name: Install system dependencies (macOS)
         if: runner.os == 'macOS'
         run: |
-          brew install llvm libomp boost protobuf zeromq
+          # Don't install LLVM, use system clang for better compatibility
+          brew install libomp boost protobuf zeromq

       - name: Install build dependencies
         run: |
           uv pip install --system scikit-build-core numpy swig Cython pybind11
@@ -107,7 +108,7 @@ jobs:
           else
             uv pip install --system delocate
           fi

       - name: Build packages
         run: |
           # Build core (platform independent)
@@ -116,32 +117,41 @@ jobs:
             uv build
             cd ../..
           fi

           # Build HNSW backend
           cd packages/leann-backend-hnsw
           if [ "${{ matrix.os }}" == "macos-latest" ]; then
-            CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv build --wheel --python python
+            # Use system clang instead of homebrew LLVM for better compatibility
+            export CC=clang
+            export CXX=clang++
+            export MACOSX_DEPLOYMENT_TARGET=11.0
+            uv build --wheel --python python
           else
             uv build --wheel --python python
           fi
           cd ../..

           # Build DiskANN backend
           cd packages/leann-backend-diskann
           if [ "${{ matrix.os }}" == "macos-latest" ]; then
-            CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv build --wheel --python python
+            # Use system clang instead of homebrew LLVM for better compatibility
+            export CC=clang
+            export CXX=clang++
+            # DiskANN requires macOS 13.3+ for sgesdd_ LAPACK function
+            export MACOSX_DEPLOYMENT_TARGET=13.3
+            uv build --wheel --python python
           else
             uv build --wheel --python python
           fi
           cd ../..

           # Build meta package (platform independent)
           if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
             cd packages/leann
             uv build
             cd ../..
           fi

       - name: Repair wheels (Linux)
         if: runner.os == 'Linux'
         run: |
@@ -153,7 +163,7 @@ jobs:
             mv dist_repaired dist
           fi
           cd ../..

           # Repair DiskANN wheel
           cd packages/leann-backend-diskann
           if [ -d dist ]; then
@@ -162,7 +172,7 @@ jobs:
             mv dist_repaired dist
           fi
           cd ../..

       - name: Repair wheels (macOS)
         if: runner.os == 'macOS'
         run: |
@@ -174,7 +184,7 @@ jobs:
             mv dist_repaired dist
           fi
           cd ../..

           # Repair DiskANN wheel
           cd packages/leann-backend-diskann
           if [ -d dist ]; then
@@ -183,14 +193,59 @@ jobs:
             mv dist_repaired dist
           fi
           cd ../..

       - name: List built packages
         run: |
           echo "📦 Built packages:"
           find packages/*/dist -name "*.whl" -o -name "*.tar.gz" | sort

       - name: Install built packages for testing
         run: |
           # Create a virtual environment
           uv venv
           source .venv/bin/activate || source .venv/Scripts/activate

           # Install the built wheels
           # Use --find-links to let uv choose the correct wheel for the platform
           if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
             uv pip install leann-core --find-links packages/leann-core/dist
             uv pip install leann --find-links packages/leann/dist
           fi
           uv pip install leann-backend-hnsw --find-links packages/leann-backend-hnsw/dist
           uv pip install leann-backend-diskann --find-links packages/leann-backend-diskann/dist

           # Install test dependencies using extras
           uv pip install -e ".[test]"

       - name: Run tests with pytest
         env:
           CI: true  # Mark as CI environment to skip memory-intensive tests
           OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
           HF_HUB_DISABLE_SYMLINKS: 1
           TOKENIZERS_PARALLELISM: false
           PYTORCH_ENABLE_MPS_FALLBACK: 0  # Disable MPS on macOS CI to avoid memory issues
           OMP_NUM_THREADS: 1  # Disable OpenMP parallelism to avoid libomp crashes
           MKL_NUM_THREADS: 1  # Single thread for MKL operations
         run: |
           # Activate virtual environment
           source .venv/bin/activate || source .venv/Scripts/activate

           # Run all tests
           pytest tests/

       - name: Run sanity checks (optional)
         run: |
           # Activate virtual environment
           source .venv/bin/activate || source .venv/Scripts/activate

           # Run distance function tests if available
           if [ -f test/sanity_checks/test_distance_functions.py ]; then
             echo "Running distance function sanity checks..."
             python test/sanity_checks/test_distance_functions.py || echo "⚠️ Distance function test failed, continuing..."
           fi

       - name: Upload artifacts
         uses: actions/upload-artifact@v4
         with:
           name: packages-${{ matrix.os }}-py${{ matrix.python }}
           path: packages/*/dist/
```
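The test step sets `CI: true` so that memory-intensive tests can be skipped on CI runners. A minimal sketch of how a test suite might honor that flag (the test name and marker below are illustrative assumptions, not part of the LEANN test suite):

```python
import os

import pytest

# Assumption: tests key off the CI environment variable set in the workflow above.
RUNNING_IN_CI = os.environ.get("CI", "").lower() == "true"


@pytest.mark.skipif(RUNNING_IN_CI, reason="memory-intensive; skipped on CI runners")
def test_build_large_index():
    # Placeholder for a test that would allocate a large in-memory index locally.
    assert True
```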
.github/workflows/release-manual.yml (vendored) — 34 lines changed

```diff
@@ -16,10 +16,10 @@ jobs:
       contents: write
     outputs:
       commit-sha: ${{ steps.push.outputs.commit-sha }}

     steps:
       - uses: actions/checkout@v4

       - name: Validate version
         run: |
           # Remove 'v' prefix if present for validation
@@ -30,7 +30,7 @@ jobs:
             exit 1
           fi
           echo "✅ Version format valid: ${{ inputs.version }}"

       - name: Update versions and push
         id: push
         run: |
@@ -38,7 +38,7 @@ jobs:
           CURRENT_VERSION=$(grep "^version" packages/leann-core/pyproject.toml | cut -d'"' -f2)
           echo "Current version: $CURRENT_VERSION"
           echo "Target version: ${{ inputs.version }}"

           if [ "$CURRENT_VERSION" = "${{ inputs.version }}" ]; then
             echo "⚠️ Version is already ${{ inputs.version }}, skipping update"
             COMMIT_SHA=$(git rev-parse HEAD)
@@ -52,7 +52,7 @@ jobs:
             COMMIT_SHA=$(git rev-parse HEAD)
             echo "✅ Pushed version update: $COMMIT_SHA"
           fi

           echo "commit-sha=$COMMIT_SHA" >> $GITHUB_OUTPUT

   build-packages:
@@ -60,7 +60,7 @@ jobs:
     needs: update-version
     uses: ./.github/workflows/build-reusable.yml
     with:
       ref: 'main'

   publish:
     name: Publish and Release
@@ -69,26 +69,26 @@ jobs:
     runs-on: ubuntu-latest
     permissions:
       contents: write

     steps:
       - uses: actions/checkout@v4
         with:
           ref: 'main'

       - name: Download all artifacts
         uses: actions/download-artifact@v4
         with:
           path: dist-artifacts

       - name: Collect packages
         run: |
           mkdir -p dist
           find dist-artifacts -name "*.whl" -exec cp {} dist/ \;
           find dist-artifacts -name "*.tar.gz" -exec cp {} dist/ \;

           echo "📦 Packages to publish:"
           ls -la dist/

       - name: Publish to PyPI
         env:
           TWINE_USERNAME: __token__
@@ -98,12 +98,12 @@ jobs:
             echo "❌ PYPI_API_TOKEN not configured!"
             exit 1
           fi

           pip install twine
           twine upload dist/* --skip-existing --verbose

           echo "✅ Published to PyPI!"

       - name: Create release
         run: |
           # Check if tag already exists
@@ -114,7 +114,7 @@ jobs:
             git push origin "v${{ inputs.version }}"
             echo "✅ Created and pushed tag v${{ inputs.version }}"
           fi

           # Check if release already exists
           if gh release view "v${{ inputs.version }}" >/dev/null 2>&1; then
             echo "⚠️ Release v${{ inputs.version }} already exists, skipping release creation"
@@ -126,4 +126,4 @@ jobs:
             echo "✅ Created GitHub release v${{ inputs.version }}"
           fi
         env:
           GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
```
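The `Validate version` step rejects a manual run whose version input is malformed before anything is pushed. A rough Python equivalent of that check (the exact regex used by the workflow is not shown in this diff, so the `X.Y.Z` pattern below is an assumption):

```python
import re
import sys

def validate_version(version: str) -> str:
    """Strip an optional leading 'v' and check for an X.Y.Z version string."""
    version = version.removeprefix("v")  # mirrors the workflow's "remove 'v' prefix" step
    if not re.fullmatch(r"\d+\.\d+\.\d+", version):  # assumed pattern, not taken from the diff
        sys.exit(f"❌ Invalid version format: {version}")
    print(f"✅ Version format valid: {version}")
    return version

validate_version("v0.1.2")
```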
.gitignore (vendored) — 6 lines changed

```diff
@@ -9,7 +9,7 @@ demo/indices/
 outputs/
 *.pkl
 *.pdf
 *.idx
 *.map
 .history/
 lm_eval.egg-info/
@@ -85,4 +85,6 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/
 *.meta.json
 *.passages.json

 batchtest.py
+tests/__pytest_cache__/
+tests/__pycache__/
```
```diff
@@ -9,15 +9,8 @@ repos:
       - id: check-merge-conflict
       - id: debug-statements

-  - repo: https://github.com/psf/black
-    rev: 24.1.1
-    hooks:
-      - id: black
-        language_version: python3
-
   - repo: https://github.com/astral-sh/ruff-pre-commit
     rev: v0.2.1
     hooks:
       - id: ruff
         args: [--fix]
       - id: ruff-format
```
README.md — 152 lines changed

`@@ -33,12 +33,46 @@ LEANN achieves this through *graph-based selective recomputation* with *high-degree`

🪶 **Lightweight:** Graph-based recomputation eliminates heavy embedding storage, while smart graph pruning and CSR format minimize graph storage overhead. Always less storage, less memory usage!

📦 **Portable:** Transfer your entire knowledge base between devices (even with others) with minimal cost - your personal AI memory travels with you.

📈 **Scalability:** Handle messy personal data that would crash traditional vector DBs, easily managing your growing personalized data and agent generated memory!

✨ **No Accuracy Loss:** Maintain the same search quality as heavyweight solutions while using 97% less storage.

## Installation
> `pip leann` coming soon!

<details>
<summary><strong>📦 Prerequisites: Install uv (if you don't have it)</strong></summary>

Install uv first if you don't have it:

```bash
curl -LsSf https://astral.sh/uv/install.sh | sh
```

📖 [Detailed uv installation methods →](https://docs.astral.sh/uv/getting-started/installation/#installation-methods)

</details>

LEANN provides two installation methods: **pip install** (quick and easy) and **build from source** (recommended for development).

### 🚀 Quick Install (Recommended for most users)

Clone the repository to access all examples and install LEANN from [PyPI](https://pypi.org/project/leann/) to run them immediately:

```bash
git clone git@github.com:yichuan-w/LEANN.git leann
cd leann
uv venv
source .venv/bin/activate
uv pip install leann
```

### 🔧 Build from Source (Recommended for development)

```bash
git clone git@github.com:yichuan-w/LEANN.git leann
cd leann
```

`@@ -48,27 +82,65 @@ git submodule update --init --recursive`

**macOS:**
```bash
brew install llvm libomp boost protobuf zeromq pkgconf

# Install with HNSW backend (default, recommended for most users)
# Install uv first if you don't have it:
# curl -LsSf https://astral.sh/uv/install.sh | sh
# See: https://docs.astral.sh/uv/getting-started/installation/#installation-methods
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync
```

**Linux:**
```bash
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev

# Install with HNSW backend (default, recommended for most users)
uv sync
```

**Ollama Setup (Recommended for full privacy):**

> *You can skip this installation if you only want to use OpenAI API for generation.*

## Quick Start

Our declarative API makes RAG as easy as writing a config file.

[Open In Colab](https://colab.research.google.com/github/yichuan-w/LEANN/blob/main/demo.ipynb) [Try in this ipynb file →](demo.ipynb)

```python
from leann import LeannBuilder, LeannSearcher, LeannChat
from pathlib import Path
INDEX_PATH = str(Path("./").resolve() / "demo.leann")

# Build an index
builder = LeannBuilder(backend_name="hnsw")
builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
builder.add_text("Tung Tung Tung Sahur called—they need their banana‑crocodile hybrid back")
builder.build_index(INDEX_PATH)

# Search
searcher = LeannSearcher(INDEX_PATH)
results = searcher.search("fantastical AI-generated creatures", top_k=1)

# Chat with your data
chat = LeannChat(INDEX_PATH, llm_config={"type": "hf", "model": "Qwen/Qwen3-0.6B"})
response = chat.ask("How much storage does LEANN save?", top_k=1)
```

## RAG on Everything!

LEANN supports RAG on various data sources including documents (.pdf, .txt, .md), Apple Mail, Google Search History, WeChat, and more.

> **Generation Model Setup**
> LEANN supports multiple LLM providers for text generation (OpenAI API, HuggingFace, Ollama).

<details>
<summary><strong>🔑 OpenAI API Setup (Default)</strong></summary>

Set your OpenAI API key as an environment variable:

```bash
export OPENAI_API_KEY="your-api-key-here"
```

</details>

<details>
<summary><strong>🔧 Ollama Setup (Recommended for full privacy)</strong></summary>

**macOS:**

`@@ -80,6 +152,7 @@ ollama pull llama3.2:1b`

**Linux:**

```bash
# Install Ollama
curl -fsSL https://ollama.ai/install.sh | sh
```

`@@ -91,43 +164,7 @@ ollama serve &`

```bash
ollama pull llama3.2:1b
```

## Quick Start in 30s

Our declarative API makes RAG as easy as writing a config file.
[Try in this ipynb file →](demo.ipynb) [Open In Colab](https://colab.research.google.com/github/yichuan-w/LEANN/blob/main/demo.ipynb)

```python
from leann.api import LeannBuilder, LeannSearcher, LeannChat

# 1. Build the index (no embeddings stored!)
builder = LeannBuilder(backend_name="hnsw")
builder.add_text("C# is a powerful programming language")
builder.add_text("Python is a powerful programming language and it is very popular")
builder.add_text("Machine learning transforms industries")
builder.add_text("Neural networks process complex data")
builder.add_text("Leann is a great storage saving engine for RAG on your MacBook")
builder.build_index("knowledge.leann")

# 2. Search with real-time embeddings
searcher = LeannSearcher("knowledge.leann")
results = searcher.search("programming languages", top_k=2)

# 3. Chat with LEANN using retrieved results
llm_config = {
    "type": "ollama",
    "model": "llama3.2:1b"
}

chat = LeannChat(index_path="knowledge.leann", llm_config=llm_config)
response = chat.ask(
    "Compare the two retrieved programming languages and say which one is more popular today.",
    top_k=2,
)
```

## RAG on Everything!

LEANN supports RAG on various data sources including documents (.pdf, .txt, .md), Apple Mail, Google Search History, WeChat, and more.
</details>

### 📄 Personal Data Manager: Process Any Documents (.pdf, .txt, .md)!

`@@ -139,11 +176,6 @@ Ask questions directly about your personal PDFs, documents, and any directory co`

The example below asks a question about summarizing two papers (uses default data in `examples/data`):

```bash
# Drop your PDFs, .txt, .md files into examples/data/
uv run ./examples/main_cli_example.py
```

```bash
# Or use python directly
source .venv/bin/activate
```

`@@ -154,6 +186,9 @@ python ./examples/main_cli_example.py`

### 📧 Your Personal Email Secretary: RAG on Apple Mail!

> **Note:** The examples below currently support macOS only. Windows support coming soon.

<p align="center">
  <img src="videos/mail_clear.gif" alt="LEANN Email Search Demo" width="600">
</p>

`@@ -324,7 +359,7 @@ LEANN includes a powerful CLI for document processing and search. Perfect for qu`

```bash
# Build an index from documents
leann build my-docs --docs ./documents

# Search your documents
leann search my-docs "machine learning concepts"

# Interactive chat with your documents
```

`@@ -392,7 +427,7 @@ Options:`

**Core techniques:**
- **Graph-based selective recomputation:** Only compute embeddings for nodes in the search path
- **High-degree preserving pruning:** Keep important "hub" nodes while removing redundant connections
- **Dynamic batching:** Efficiently batch embedding computations for GPU utilization
- **Two-level search:** Smart graph traversal that prioritizes promising nodes
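To make the first technique concrete, here is a heavily simplified sketch of graph-based selective recomputation: embeddings are computed only for nodes actually visited during a best-first graph traversal, instead of being stored for the whole corpus. This is an illustration of the idea only, not LEANN's implementation; the names and data structures are invented for the example.

```python
import heapq
import numpy as np

def selective_search(query_vec, graph, texts, embed, entry, top_k=2, budget=64):
    """Best-first graph search that embeds nodes lazily as they are visited."""
    cache = {}  # node id -> embedding, computed on demand (never persisted)

    def score(node):
        if node not in cache:
            cache[node] = embed(texts[node])  # selective recomputation on the search path
        return float(np.dot(query_vec, cache[node]))

    visited = {entry}
    frontier = [(-score(entry), entry)]
    results = []
    while frontier and len(cache) < budget:
        neg_score, node = heapq.heappop(frontier)
        results.append((-neg_score, node))
        for nbr in graph[node]:  # only neighbors along the traversal are ever embedded
            if nbr not in visited:
                visited.add(nbr)
                heapq.heappush(frontier, (-score(nbr), nbr))
    return sorted(results, reverse=True)[:top_k]
```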
`@@ -429,22 +464,22 @@ If you find Leann useful, please cite:`

```bibtex
@misc{wang2025leannlowstoragevectorindex,
      title={LEANN: A Low-Storage Vector Index},
      author={Yichuan Wang and Shu Liu and Zhifei Li and Yongji Wu and Ziming Mao and Yilong Zhao and Xiao Yan and Zhiying Xu and Yang Zhou and Ion Stoica and Sewon Min and Matei Zaharia and Joseph E. Gonzalez},
      year={2025},
      eprint={2506.08276},
      archivePrefix={arXiv},
      primaryClass={cs.DB},
      url={https://arxiv.org/abs/2506.08276},
}
```

## ✨ [Detailed Features →](docs/features.md)

## 🤝 [Contributing →](docs/contributing.md)
## 🤝 [CONTRIBUTING →](docs/CONTRIBUTING.md)

## [FAQ →](docs/faq.md)
## ❓ [FAQ →](docs/faq.md)

## 📈 [Roadmap →](docs/roadmap.md)

`@@ -465,4 +500,3 @@ This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.e`

<p align="center">
Made with ❤️ by the Leann team
</p>
demo.ipynb — 141 lines changed

```diff
@@ -4,7 +4,7 @@
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-   "# Quick Start in 30s\n",
+   "# Quick Start \n",
    "\n",
    "**Home GitHub Repository:** [LEANN on GitHub](https://github.com/yichuan-w/LEANN)\n",
    "\n",
@@ -49,68 +49,7 @@
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
-  "outputs": [
-   {
-    "name": "stderr",
-    "output_type": "stream",
-    "text": [
-     "Writing passages: 100%|██████████| 5/5 [00:00<00:00, 17077.79chunk/s]\n",
-     "Batches: 100%|██████████| 1/1 [00:00<00:00, 36.43it/s]\n",
-     "WARNING:leann_backend_hnsw.hnsw_backend:Converting data to float32, shape: (5, 768)\n",
-     "INFO:leann_backend_hnsw.hnsw_backend:INFO: Converting HNSW index to CSR-pruned format...\n"
-    ]
-   },
-   {
-    "name": "stdout",
-    "output_type": "stream",
-    "text": [
-     "M: 64 for level: 0\n",
-     "Starting conversion: index.index -> index.csr.tmp\n",
-     "[0.00s] Reading Index HNSW header...\n",
-     "[0.00s] Header read: d=768, ntotal=5\n",
-     "[0.00s] Reading HNSW struct vectors...\n",
-     "  Reading vector (dtype=<class 'numpy.float64'>, fmt='d')... Count=6, Bytes=48\n",
-     "[0.00s] Read assign_probas (6)\n",
-     "  Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=7, Bytes=28\n",
-     "[0.14s] Read cum_nneighbor_per_level (7)\n",
-     "  Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=5, Bytes=20\n",
-     "[0.24s] Read levels (5)\n",
-     "[0.33s] Probing for compact storage flag...\n",
-     "[0.33s] Found compact flag: False\n",
-     "[0.33s] Compact flag is False, reading original format...\n",
-     "[0.33s] Probing for potential extra byte before non-compact offsets...\n",
-     "[0.33s] Found and consumed an unexpected 0x00 byte.\n",
-     "  Reading vector (dtype=<class 'numpy.uint64'>, fmt='Q')... Count=6, Bytes=48\n",
-     "[0.33s] Read offsets (6)\n",
-     "[0.41s] Attempting to read neighbors vector...\n",
-     "  Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=320, Bytes=1280\n",
-     "[0.41s] Read neighbors (320)\n",
-     "[0.54s] Read scalar params (ep=4, max_lvl=0)\n",
-     "[0.54s] Checking for storage data...\n",
-     "[0.54s] Found storage fourcc: 49467849.\n",
-     "[0.54s] Converting to CSR format...\n",
-     "[0.54s] Conversion loop finished. \n",
-     "[0.54s] Running validation checks...\n",
-     "  Checking total valid neighbor count...\n",
-     "  OK: Total valid neighbors = 20\n",
-     "  Checking final pointer indices...\n",
-     "  OK: Final pointers match data size.\n",
-     "[0.54s] Deleting original neighbors and offsets arrays...\n",
-     "  CSR Stats: |data|=20, |level_ptr|=10\n",
-     "[0.63s] Writing CSR HNSW graph data in FAISS-compatible order...\n",
-     "  Pruning embeddings: Writing NULL storage marker.\n",
-     "[0.71s] Conversion complete.\n"
-    ]
-   },
-   {
-    "name": "stderr",
-    "output_type": "stream",
-    "text": [
-     "INFO:leann_backend_hnsw.hnsw_backend:✅ CSR conversion successful.\n",
-     "INFO:leann_backend_hnsw.hnsw_backend:INFO: Replaced original index with CSR-pruned version at 'index.index'\n"
-    ]
-   }
-  ],
+  "outputs": [],
   "source": [
    "from leann.api import LeannBuilder\n",
    "\n",
@@ -136,81 +75,7 @@
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
-  "outputs": [
-   {
-    "name": "stderr",
-    "output_type": "stream",
-    "text": [
-     "INFO:leann.api:🔍 LeannSearcher.search() called:\n",
-     "INFO:leann.api:  Query: 'programming languages'\n",
-     "INFO:leann.api:  Top_k: 2\n",
-     "INFO:leann.api:  Additional kwargs: {}\n",
-     "INFO:leann.embedding_server_manager:Port 5557 has incompatible server, trying next port...\n",
-     "INFO:leann.embedding_server_manager:Port 5558 has incompatible server, trying next port...\n",
-     "INFO:leann.embedding_server_manager:Port 5559 has incompatible server, trying next port...\n",
-     "INFO:leann.embedding_server_manager:Port 5560 has incompatible server, trying next port...\n",
-     "INFO:leann.embedding_server_manager:Port 5561 has incompatible server, trying next port...\n",
-     "INFO:leann.embedding_server_manager:Port 5562 has incompatible server, trying next port...\n",
-     "INFO:leann.embedding_server_manager:Starting embedding server on port 5563...\n",
-     "INFO:leann.embedding_server_manager:Command: /Users/yichuan/Desktop/code/test_leann_pip/LEANN/.venv/bin/python -m leann_backend_hnsw.hnsw_embedding_server --zmq-port 5563 --model-name facebook/contriever --passages-file /Users/yichuan/Desktop/code/test_leann_pip/LEANN/content/index.meta.json\n",
-     "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
-     "To disable this warning, you can either:\n",
-     "\t- Avoid using `tokenizers` before the fork if possible\n",
-     "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
-     "INFO:leann.embedding_server_manager:Server process started with PID: 31699\n"
-    ]
-   },
-   {
-    "name": "stdout",
-    "output_type": "stream",
-    "text": [
-     "[read_HNSW - CSR NL v4] Reading metadata & CSR indices (manual offset)...\n",
-     "[read_HNSW NL v4] Read levels vector, size: 5\n",
-     "[read_HNSW NL v4] Reading Compact Storage format indices...\n",
-     "[read_HNSW NL v4] Read compact_level_ptr, size: 10\n",
-     "[read_HNSW NL v4] Read compact_node_offsets, size: 6\n",
-     "[read_HNSW NL v4] Read entry_point: 4, max_level: 0\n",
-     "[read_HNSW NL v4] Read storage fourcc: 0x6c6c756e\n",
-     "[read_HNSW NL v4 FIX] Detected FileIOReader. Neighbors size field offset: 326\n",
-     "[read_HNSW NL v4] Reading neighbors data into memory.\n",
-     "[read_HNSW NL v4] Read neighbors data, size: 20\n",
-     "[read_HNSW NL v4] Finished reading metadata and CSR indices.\n",
-     "INFO: Skipping external storage loading, since is_recompute is true.\n",
-     "INFO: Registering backend 'hnsw'\n"
-    ]
-   },
-   {
-    "name": "stderr",
-    "output_type": "stream",
-    "text": [
-     "Traceback (most recent call last):\n",
-     "  File \"<frozen runpy>\", line 198, in _run_module_as_main\n",
-     "  File \"<frozen runpy>\", line 88, in _run_code\n",
-     "  File \"/Users/yichuan/Desktop/code/test_leann_pip/LEANN/.venv/lib/python3.11/site-packages/leann_backend_hnsw/hnsw_embedding_server.py\", line 323, in <module>\n",
-     "    create_hnsw_embedding_server(\n",
-     "  File \"/Users/yichuan/Desktop/code/test_leann_pip/LEANN/.venv/lib/python3.11/site-packages/leann_backend_hnsw/hnsw_embedding_server.py\", line 98, in create_hnsw_embedding_server\n",
-     "    passages = PassageManager(passage_sources)\n",
-     "  File \"/Users/yichuan/Desktop/code/test_leann_pip/LEANN/.venv/lib/python3.11/site-packages/leann/api.py\", line 127, in __init__\n",
-     "    raise FileNotFoundError(f\"Passage index file not found: {index_file}\")\n",
-     "FileNotFoundError: Passage index file not found: /Users/yichuan/Desktop/code/test_leann_pip/LEANN/index.passages.idx\n",
-     "ERROR:leann.embedding_server_manager:Server terminated during startup.\n"
-    ]
-   },
-   {
-    "ename": "RuntimeError",
-    "evalue": "Failed to start embedding server on port 5563",
-    "output_type": "error",
-    "traceback": [
-     "RuntimeError                              Traceback (most recent call last)",
-     "Cell In[4], line 4: results = searcher.search(\"programming languages\", top_k=2)",
-     "File ~/Desktop/code/test_leann_pip/LEANN/.venv/lib/python3.11/site-packages/leann/api.py:439, in LeannSearcher.search: zmq_port = self.backend_impl._ensure_server_running(...)",
-     "File ~/Desktop/code/test_leann_pip/LEANN/.venv/lib/python3.11/site-packages/leann/searcher_base.py:81, in BaseSearcher._ensure_server_running: raise RuntimeError(f\"Failed to start embedding server on port {actual_port}\")",
-     "RuntimeError: Failed to start embedding server on port 5563"
-    ]
-   }
-  ],
+  "outputs": [],
   "source": [
    "from leann.api import LeannSearcher\n",
    "\n",
```
docs/CONTRIBUTING.md — new file, 220 lines

`@@ -0,0 +1,220 @@`

# 🤝 Contributing

We welcome contributions! Leann is built by the community, for the community.

## Ways to Contribute

- 🐛 **Bug Reports**: Found an issue? Let us know!
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
- 🔧 **Code Contributions**: PRs welcome for all skill levels
- 📖 **Documentation**: Help make Leann more accessible
- 🧪 **Benchmarks**: Share your performance results

## 🚀 Development Setup

### Prerequisites

1. **Install uv** (fast Python package installer):
   ```bash
   curl -LsSf https://astral.sh/uv/install.sh | sh
   ```

2. **Clone the repository**:
   ```bash
   git clone https://github.com/LEANN-RAG/LEANN-RAG.git
   cd LEANN-RAG
   ```

3. **Install system dependencies**:

   **macOS:**
   ```bash
   brew install llvm libomp boost protobuf zeromq pkgconf
   ```

   **Ubuntu/Debian:**
   ```bash
   sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler \
     libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
   ```

4. **Build from source**:
   ```bash
   # macOS
   CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync

   # Ubuntu/Debian
   uv sync
   ```

## 🔨 Pre-commit Hooks

We use pre-commit hooks to ensure code quality and consistency. This runs automatically before each commit.

### Setup Pre-commit

1. **Install pre-commit** (already included when you run `uv sync`):
   ```bash
   uv pip install pre-commit
   ```

2. **Install the git hooks**:
   ```bash
   pre-commit install
   ```

3. **Run pre-commit manually** (optional):
   ```bash
   pre-commit run --all-files
   ```

### Pre-commit Checks

Our pre-commit configuration includes:
- **Trailing whitespace removal**
- **End-of-file fixing**
- **YAML validation**
- **Large file prevention**
- **Merge conflict detection**
- **Debug statement detection**
- **Code formatting with ruff**
- **Code linting with ruff**

## 🧪 Testing

### Running Tests

```bash
# Run all tests
uv run pytest

# Run specific test file
uv run pytest test/test_filename.py

# Run with coverage
uv run pytest --cov=leann
```

### Writing Tests

- Place tests in the `test/` directory
- Follow the naming convention `test_*.py`
- Use descriptive test names that explain what's being tested
- Include both positive and negative test cases
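A minimal test following these conventions might look like the sketch below. It is illustrative only, not an existing LEANN test; the `LeannBuilder` usage mirrors the README examples in this diff, and the file would live under `test/`.

```python
# test/test_builder_smoke.py — illustrative sketch, not part of the LEANN test suite
import pytest

from leann.api import LeannBuilder


def test_add_text_accepts_plain_strings():
    builder = LeannBuilder(backend_name="hnsw")
    builder.add_text("LEANN stores graphs, not embeddings.")  # positive case


def test_add_text_rejects_missing_argument():
    builder = LeannBuilder(backend_name="hnsw")
    with pytest.raises(TypeError):  # negative case: text argument is required
        builder.add_text()
```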
## 📝 Code Style

We use `ruff` for both linting and formatting to ensure consistent code style.

### Format Your Code

```bash
# Format all files
ruff format

# Check formatting without changing files
ruff format --check
```

### Lint Your Code

```bash
# Run linter with auto-fix
ruff check --fix

# Just check without fixing
ruff check
```

### Style Guidelines

- Follow PEP 8 conventions
- Use descriptive variable names
- Add type hints where appropriate
- Write docstrings for all public functions and classes
- Keep functions focused and single-purpose

## 🚦 CI/CD

Our CI pipeline runs automatically on all pull requests. It includes:

1. **Linting and Formatting**: Ensures code follows our style guidelines
2. **Multi-platform builds**: Tests on Ubuntu and macOS
3. **Python version matrix**: Tests on Python 3.9-3.13
4. **Wheel building**: Ensures packages can be built and distributed

### CI Commands

The CI uses the same commands as pre-commit to ensure consistency:
```bash
# Linting
ruff check .

# Format checking
ruff format --check .
```

Make sure your code passes these checks locally before pushing!

## 🔄 Pull Request Process

1. **Fork the repository** and create your branch from `main`:
   ```bash
   git checkout -b feature/your-feature-name
   ```

2. **Make your changes**:
   - Write clean, documented code
   - Add tests for new functionality
   - Update documentation as needed

3. **Run pre-commit checks**:
   ```bash
   pre-commit run --all-files
   ```

4. **Test your changes**:
   ```bash
   uv run pytest
   ```

5. **Commit with descriptive messages**:
   ```bash
   git commit -m "feat: add new search algorithm"
   ```

   Follow [Conventional Commits](https://www.conventionalcommits.org/):
   - `feat:` for new features
   - `fix:` for bug fixes
   - `docs:` for documentation changes
   - `test:` for test additions/changes
   - `refactor:` for code refactoring
   - `perf:` for performance improvements

6. **Push and create a pull request**:
   - Provide a clear description of your changes
   - Reference any related issues
   - Include examples or screenshots if applicable

## 📚 Documentation

When adding new features or making significant changes:

1. Update relevant documentation in `/docs`
2. Add docstrings to new functions/classes
3. Update README.md if needed
4. Include usage examples

## 🤔 Getting Help

- **Discord**: Join our community for discussions
- **Issues**: Check existing issues or create a new one
- **Discussions**: For general questions and ideas

## 📄 License

By contributing, you agree that your contributions will be licensed under the same license as the project (MIT).

---

Thank you for contributing to LEANN! Every contribution, no matter how small, helps make the project better for everyone. 🌟
`@@ -19,4 +19,4 @@ That's it! The workflow will automatically:`

- ✅ Publish to PyPI
- ✅ Create GitHub tag and release

Check progress: https://github.com/yichuan-w/LEANN/actions
docs/code/embedding_model_compare.py — new file, 98 lines

`@@ -0,0 +1,98 @@`

```python
"""
Comparison between Sentence Transformers and OpenAI embeddings

This example shows how different embedding models handle complex queries
and demonstrates the differences between local and API-based embeddings.
"""

import numpy as np
from leann.embedding_compute import compute_embeddings

# OpenAI API key should be set as environment variable
# export OPENAI_API_KEY="your-api-key-here"

# Test data
conference_text = "[Title]: COLING 2025 Conference\n[URL]: https://coling2025.org/"
browser_text = "[Title]: Browser Use Tool\n[URL]: https://github.com/browser-use"

# Two queries with same intent but different wording
query1 = "Tell me my browser history about some conference i often visit"
query2 = "browser history about conference I often visit"

texts = [query1, query2, conference_text, browser_text]


def cosine_similarity(a, b):
    return np.dot(a, b)  # Already normalized


def analyze_embeddings(embeddings, model_name):
    print(f"\n=== {model_name} Results ===")

    # Results for Query 1
    sim1_conf = cosine_similarity(embeddings[0], embeddings[2])
    sim1_browser = cosine_similarity(embeddings[0], embeddings[3])

    print(f"Query 1: '{query1}'")
    print(f"  → Conference similarity: {sim1_conf:.4f} {'✓' if sim1_conf > sim1_browser else ''}")
    print(
        f"  → Browser similarity: {sim1_browser:.4f} {'✓' if sim1_browser > sim1_conf else ''}"
    )
    print(f"  Winner: {'Conference' if sim1_conf > sim1_browser else 'Browser'}")

    # Results for Query 2
    sim2_conf = cosine_similarity(embeddings[1], embeddings[2])
    sim2_browser = cosine_similarity(embeddings[1], embeddings[3])

    print(f"\nQuery 2: '{query2}'")
    print(f"  → Conference similarity: {sim2_conf:.4f} {'✓' if sim2_conf > sim2_browser else ''}")
    print(
        f"  → Browser similarity: {sim2_browser:.4f} {'✓' if sim2_browser > sim2_conf else ''}"
    )
    print(f"  Winner: {'Conference' if sim2_conf > sim2_browser else 'Browser'}")

    # Show the impact
    print("\n=== Impact Analysis ===")
    print(f"Conference similarity change: {sim2_conf - sim1_conf:+.4f}")
    print(f"Browser similarity change: {sim2_browser - sim1_browser:+.4f}")

    if sim1_conf > sim1_browser and sim2_browser > sim2_conf:
        print("❌ FLIP: Adding 'browser history' flips winner from Conference to Browser!")
    elif sim1_conf > sim1_browser and sim2_conf > sim2_browser:
        print("✅ STABLE: Conference remains winner in both queries")
    elif sim1_browser > sim1_conf and sim2_browser > sim2_conf:
        print("✅ STABLE: Browser remains winner in both queries")
    else:
        print("🔄 MIXED: Results vary between queries")

    return {
        "query1_conf": sim1_conf,
        "query1_browser": sim1_browser,
        "query2_conf": sim2_conf,
        "query2_browser": sim2_browser,
    }


# Test Sentence Transformers
print("Testing Sentence Transformers (facebook/contriever)...")
try:
    st_embeddings = compute_embeddings(texts, "facebook/contriever", mode="sentence-transformers")
    st_results = analyze_embeddings(st_embeddings, "Sentence Transformers (facebook/contriever)")
except Exception as e:
    print(f"❌ Sentence Transformers failed: {e}")
    st_results = None

# Test OpenAI
print("\n" + "=" * 60)
print("Testing OpenAI (text-embedding-3-small)...")
try:
    openai_embeddings = compute_embeddings(texts, "text-embedding-3-small", mode="openai")
    openai_results = analyze_embeddings(openai_embeddings, "OpenAI (text-embedding-3-small)")
except Exception as e:
    print(f"❌ OpenAI failed: {e}")
    openai_results = None

# Compare results
if st_results and openai_results:
    print("\n" + "=" * 60)
    print("=== COMPARISON SUMMARY ===")
```
A short contributing page is deleted in full (the same content now lives in the new docs/CONTRIBUTING.md above):

```diff
@@ -1,11 +0,0 @@
-# 🤝 Contributing
-
-We welcome contributions! Leann is built by the community, for the community.
-
-## Ways to Contribute
-
-- 🐛 **Bug Reports**: Found an issue? Let us know!
-- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
-- 🔧 **Code Contributions**: PRs welcome for all skill levels
-- 📖 **Documentation**: Help make Leann more accessible
-- 🧪 **Benchmarks**: Share your performance results
```
`@@ -7,4 +7,4 @@ You can speed up the process by using a lightweight embedding model. Add this to`

```bash
--embedding-model sentence-transformers/all-MiniLM-L6-v2
```
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)
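The same switch is available from the Python API. A minimal sketch, assuming the `LeannBuilder` constructor arguments shown elsewhere in this diff; trading the 30M-parameter MiniLM model for the default contriever speeds up indexing at some cost in accuracy:

```python
from leann.api import LeannBuilder

# Sketch only: parameter names follow the examples in this diff.
builder = LeannBuilder(
    backend_name="hnsw",
    embedding_model="sentence-transformers/all-MiniLM-L6-v2",
    embedding_mode="sentence-transformers",
)
builder.add_text("A small document to index quickly.")
builder.build_index("quick.leann")
```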
`@@ -19,4 +19,4 @@`

- **Simple Python API** - Get started in minutes
- **Extensible backend system** - Easy to add new algorithms
- **Comprehensive examples** - From basic usage to production deployment
docs/normalized_embeddings.md — new file, 75 lines

`@@ -0,0 +1,75 @@`

# Normalized Embeddings Support in LEANN

LEANN now automatically detects normalized embedding models and sets the appropriate distance metric for optimal performance.

## What are Normalized Embeddings?

Normalized embeddings are vectors with L2 norm = 1 (unit vectors). These embeddings are optimized for cosine similarity rather than Maximum Inner Product Search (MIPS).

## Automatic Detection

When you create a `LeannBuilder` instance with a normalized embedding model, LEANN will:

1. **Automatically set `distance_metric="cosine"`** if not specified
2. **Show a warning** if you manually specify a different distance metric
3. **Provide optimal search performance** with the correct metric

## Supported Normalized Embedding Models

### OpenAI
All OpenAI text embedding models are normalized:
- `text-embedding-ada-002`
- `text-embedding-3-small`
- `text-embedding-3-large`

### Voyage AI
All Voyage AI embedding models are normalized:
- `voyage-2`
- `voyage-3`
- `voyage-large-2`
- `voyage-multilingual-2`
- `voyage-code-2`

### Cohere
All Cohere embedding models are normalized:
- `embed-english-v3.0`
- `embed-multilingual-v3.0`
- `embed-english-light-v3.0`
- `embed-multilingual-light-v3.0`
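A rough sketch of what such a detection step could look like (the model families mirror the lists above, but the helper name, prefixes, and placement are assumptions, not LEANN's actual implementation):

```python
# Hypothetical helper illustrating the detection described above.
NORMALIZED_MODEL_PREFIXES = (
    "text-embedding-",                        # OpenAI
    "voyage-",                                # Voyage AI
    "embed-english", "embed-multilingual",    # Cohere
)


def infer_distance_metric(embedding_model: str, requested: str | None = None) -> str:
    """Pick a distance metric, preferring cosine for known normalized models."""
    is_normalized = embedding_model.startswith(NORMALIZED_MODEL_PREFIXES)
    if requested is None:
        return "cosine" if is_normalized else "mips"
    if is_normalized and requested != "cosine":
        print(f"Warning: using '{requested}' with normalized embeddings; 'cosine' is recommended")
    return requested


print(infer_distance_metric("text-embedding-3-small"))  # -> cosine
print(infer_distance_metric("facebook/contriever"))     # -> mips
```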
## Example Usage

```python
from leann.api import LeannBuilder

# Automatic detection - will use cosine distance
builder = LeannBuilder(
    backend_name="hnsw",
    embedding_model="text-embedding-3-small",
    embedding_mode="openai"
)
# Warning: Detected normalized embeddings model 'text-embedding-3-small'...
# Automatically setting distance_metric='cosine'

# Manual override (not recommended)
builder = LeannBuilder(
    backend_name="hnsw",
    embedding_model="text-embedding-3-small",
    embedding_mode="openai",
    distance_metric="mips"  # Will show warning
)
# Warning: Using 'mips' distance metric with normalized embeddings...
```

## Non-Normalized Embeddings

Models like `facebook/contriever` and other sentence-transformers models that are not normalized will continue to use MIPS by default, which is optimal for them.

## Why This Matters

Using the wrong distance metric with normalized embeddings can lead to:
- **Poor search quality** due to HNSW's early termination with narrow score ranges
- **Incorrect ranking** of search results
- **Suboptimal performance** compared to using the correct metric

For more details on why this happens, see our analysis of [OpenAI embeddings with MIPS](../examples/main_cli_example.py).
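If you are unsure whether a model's embeddings are normalized, a quick check of the L2 norms settles it. A small sketch, using the `compute_embeddings` helper that appears in docs/code/embedding_model_compare.py (assuming it returns an array-like of vectors):

```python
import numpy as np
from leann.embedding_compute import compute_embeddings

# Values near 1.0 indicate a normalized model, for which cosine distance is the better fit.
emb = compute_embeddings(["a short probe sentence"], "facebook/contriever",
                         mode="sentence-transformers")
norms = np.linalg.norm(np.asarray(emb), axis=1)
print("L2 norms:", norms)  # ~1.0 for normalized models, arbitrary magnitudes otherwise
```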
`@@ -18,4 +18,4 @@`

- [ ] Integration with LangChain/LlamaIndex
- [ ] Visual similarity search
- [ ] Query rewrtiting, rerank and expansion
Sample data: the Project Gutenberg eBook of Pride and Prejudice

```diff
@@ -1,5 +1,5 @@
 The Project Gutenberg eBook of Pride and Prejudice


 This ebook is for the use of anyone anywhere in the United States and
 most other parts of the world at no cost and with almost no restrictions
 whatsoever. You may copy it, give it away or re-use it under the terms
@@ -14557,7 +14557,7 @@ her into Derbyshire, had been the means of uniting them.
 *** END OF THE PROJECT GUTENBERG EBOOK PRIDE AND PREJUDICE ***




 Updated editions will replace the previous one—the old editions will
 be renamed.
@@ -14662,7 +14662,7 @@ performed, viewed, copied or distributed:
 at www.gutenberg.org. If you
 are not located in the United States, you will have to check the laws
 of the country where you are located before using this eBook.

 1.E.2. If an individual Project Gutenberg™ electronic work is
 derived from texts not protected by U.S. copyright law (does not
 contain a notice indicating that it is posted with permission of the
@@ -14724,7 +14724,7 @@ provided that:
 Gutenberg Literary Archive Foundation at the address specified in
 Section 4, “Information about donations to the Project Gutenberg
 Literary Archive Foundation.”

 • You provide a full refund of any money paid by a user who notifies
 you in writing (or by e-mail) within 30 days of receipt that s/he
 does not agree to the terms of the full Project Gutenberg™
@@ -14732,15 +14732,15 @@ provided that:
 copies of the works possessed in a physical medium and discontinue
 all use of and all access to other copies of Project Gutenberg™
 works.

 • You provide, in accordance with paragraph 1.F.3, a full refund of
 any money paid for a work or a replacement copy, if a defect in the
 electronic work is discovered and reported to you within 90 days of
 receipt of the work.

 • You comply with all other terms of this agreement for free
 distribution of Project Gutenberg™ works.

 1.E.9. If you wish to charge a fee or distribute a Project
 Gutenberg™ electronic work or group of works on different terms than
@@ -14903,5 +14903,3 @@ This website includes information about Project Gutenberg™,
 including how to make donations to the Project Gutenberg Literary
 Archive Foundation, how to help produce our new eBooks, and how to
 subscribe to our email newsletter to hear about new eBooks.
```
```diff
@@ -27,7 +27,10 @@ def load_sample_documents():
             "title": "Intro to Python",
             "content": "Python is a high-level, interpreted language known for simplicity.",
         },
-        {"title": "ML Basics", "content": "Machine learning builds systems that learn from data."},
+        {
+            "title": "ML Basics",
+            "content": "Machine learning builds systems that learn from data.",
+        },
         {
             "title": "Data Structures",
             "content": "Data structures like arrays, lists, and graphs organize data.",
```
```diff
@@ -21,7 +21,11 @@ DEFAULT_CHROME_PROFILE = os.path.expanduser("~/Library/Application Support/Googl

 def create_leann_index_from_multiple_chrome_profiles(
-    profile_dirs: list[Path], index_path: str = "chrome_history_index.leann", max_count: int = -1
+    profile_dirs: list[Path],
+    index_path: str = "chrome_history_index.leann",
+    max_count: int = -1,
+    embedding_model: str = "facebook/contriever",
+    embedding_mode: str = "sentence-transformers",
 ):
     """
     Create LEANN index from multiple Chrome profile data sources.
@@ -30,6 +34,8 @@ def create_leann_index_from_multiple_chrome_profiles(
         profile_dirs: List of Path objects pointing to Chrome profile directories
         index_path: Path to save the LEANN index
         max_count: Maximum number of history entries to process per profile
+        embedding_model: The embedding model to use
+        embedding_mode: The embedding backend mode
     """
     print("Creating LEANN index from multiple Chrome profile data sources...")

@@ -104,9 +110,11 @@ def create_leann_index_from_multiple_chrome_profiles(
     print("\n[PHASE 1] Building Leann index...")

     # Use HNSW backend for better macOS compatibility
+    # LeannBuilder will automatically detect normalized embeddings and set appropriate distance metric
     builder = LeannBuilder(
         backend_name="hnsw",
-        embedding_model="facebook/contriever",
+        embedding_model=embedding_model,
+        embedding_mode=embedding_mode,
         graph_degree=32,
         complexity=64,
         is_compact=True,
@@ -130,6 +138,8 @@ def create_leann_index(
     profile_path: str | None = None,
     index_path: str = "chrome_history_index.leann",
     max_count: int = 1000,
+    embedding_model: str = "facebook/contriever",
+    embedding_mode: str = "sentence-transformers",
 ):
     """
     Create LEANN index from Chrome history data.
@@ -138,6 +148,8 @@ def create_leann_index(
         profile_path: Path to the Chrome profile directory (optional, uses default if None)
         index_path: Path to save the LEANN index
         max_count: Maximum number of history entries to process
+        embedding_model: The embedding model to use
+        embedding_mode: The embedding backend mode
     """
     print("Creating LEANN index from Chrome history data...")
     INDEX_DIR = Path(index_path).parent
@@ -185,9 +197,11 @@ def create_leann_index(
     print("\n[PHASE 1] Building Leann index...")

     # Use HNSW backend for better macOS compatibility
+    # LeannBuilder will automatically detect normalized embeddings and set appropriate distance metric
     builder = LeannBuilder(
         backend_name="hnsw",
-        embedding_model="facebook/contriever",
+        embedding_model=embedding_model,
+        embedding_mode=embedding_mode,
         graph_degree=32,
         complexity=64,
         is_compact=True,
@@ -271,6 +285,24 @@ async def main():
         default=True,
         help="Automatically find all Chrome profiles (default: True)",
     )
+    parser.add_argument(
+        "--embedding-model",
+        type=str,
+        default="facebook/contriever",
+        help="The embedding model to use (e.g., 'facebook/contriever', 'text-embedding-3-small')",
+    )
+    parser.add_argument(
+        "--embedding-mode",
+        type=str,
+        default="sentence-transformers",
+        choices=["sentence-transformers", "openai", "mlx"],
+        help="The embedding backend mode",
+    )
+    parser.add_argument(
+        "--use-existing-index",
+        action="store_true",
+        help="Use existing index without rebuilding",
+    )

     args = parser.parse_args()

@@ -281,26 +313,34 @@ async def main():
     print(f"Index directory: {INDEX_DIR}")
     print(f"Max entries: {args.max_entries}")

-    # Find Chrome profile directories
-    from history_data.history import ChromeHistoryReader
-
-    if args.auto_find_profiles:
-        profile_dirs = ChromeHistoryReader.find_chrome_profiles()
-        if not profile_dirs:
-            print("No Chrome profiles found automatically. Exiting.")
-            return
-    else:
-        # Use single specified profile
-        profile_path = Path(args.chrome_profile)
-        if not profile_path.exists():
-            print(f"Chrome profile not found: {profile_path}")
-            return
-        profile_dirs = [profile_path]
-
-    # Create or load the LEANN index from all sources
-    index_path = create_leann_index_from_multiple_chrome_profiles(
-        profile_dirs, INDEX_PATH, args.max_entries
-    )
+    if args.use_existing_index:
+        # Use existing index without rebuilding
+        if not Path(INDEX_PATH).exists():
+            print(f"Error: Index file not found at {INDEX_PATH}")
+            return
+        print(f"Using existing index at {INDEX_PATH}")
+        index_path = INDEX_PATH
+    else:
+        # Find Chrome profile directories
+        from history_data.history import ChromeHistoryReader
+
+        if args.auto_find_profiles:
+            profile_dirs = ChromeHistoryReader.find_chrome_profiles()
+            if not profile_dirs:
+                print("No Chrome profiles found automatically. Exiting.")
+                return
+        else:
+            # Use single specified profile
+            profile_path = Path(args.chrome_profile)
+            if not profile_path.exists():
+                print(f"Chrome profile not found: {profile_path}")
+                return
+            profile_dirs = [profile_path]
+
+        # Create or load the LEANN index from all sources
+        index_path = create_leann_index_from_multiple_chrome_profiles(
+            profile_dirs, INDEX_PATH, args.max_entries, args.embedding_model, args.embedding_mode
+        )

     if index_path:
         if args.query:
```
@@ -474,7 +474,8 @@ Messages ({len(messages)} messages, {message_group["total_length"]} chars):
message_group, contact_name
)
doc = Document(
text=doc_content, metadata={"contact_name": contact_name}
text=doc_content,
metadata={"contact_name": contact_name},
)
docs.append(doc)
count += 1

@@ -315,7 +315,11 @@ async def main():

# Create or load the LEANN index from all sources
index_path = create_leann_index_from_multiple_sources(
messages_dirs, INDEX_PATH, args.max_emails, args.include_html, args.embedding_model
messages_dirs,
INDEX_PATH,
args.max_emails,
args.include_html,
args.embedding_model,
)

if index_path:
@@ -92,7 +92,10 @@ def main():
help="Directory to store the index (default: mail_index_embedded)",
)
parser.add_argument(
"--max-emails", type=int, default=10000, help="Maximum number of emails to process"
"--max-emails",
type=int,
default=10000,
help="Maximum number of emails to process",
)
parser.add_argument(
"--include-html",
@@ -112,7 +115,10 @@ def main():
else:
print("Creating new index...")
index = create_and_save_index(
mail_path, save_dir, max_count=args.max_emails, include_html=args.include_html
mail_path,
save_dir,
max_count=args.max_emails,
include_html=args.include_html,
)
if index:
queries = [
@@ -30,17 +30,22 @@ async def main(args):
all_texts = []
for doc in documents:
nodes = node_parser.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
if nodes:
all_texts.extend(node.get_content() for node in nodes)

print("--- Index directory not found, building new index ---")

print("\n[PHASE 1] Building Leann index...")

# LeannBuilder now automatically detects normalized embeddings and sets appropriate distance metric
print(f"Using {args.embedding_model} with {args.embedding_mode} mode")

# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="facebook/contriever",
embedding_model=args.embedding_model,
embedding_mode=args.embedding_mode,
# distance_metric is automatically set based on embedding model
graph_degree=32,
complexity=64,
is_compact=True,
@@ -59,9 +64,19 @@ async def main(args):

print("\n[PHASE 2] Starting Leann chat session...")

llm_config = {"type": "hf", "model": "Qwen/Qwen3-4B"}
llm_config = {"type": "ollama", "model": "qwen3:8b"}
llm_config = {"type": "openai", "model": "gpt-4o"}
# Build llm_config based on command line arguments
if args.llm == "simulated":
llm_config = {"type": "simulated"}
elif args.llm == "ollama":
llm_config = {"type": "ollama", "model": args.model, "host": args.host}
elif args.llm == "hf":
llm_config = {"type": "hf", "model": args.model}
elif args.llm == "openai":
llm_config = {"type": "openai", "model": args.model}
else:
raise ValueError(f"Unknown LLM type: {args.llm}")

print(f"Using LLM: {args.llm} with model: {args.model if args.llm != 'simulated' else 'N/A'}")

chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
# query = (
@@ -89,6 +104,19 @@ if __name__ == "__main__":
default="Qwen/Qwen3-0.6B",
help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf, 'gpt-4o' for openai).",
)
parser.add_argument(
"--embedding-model",
type=str,
default="facebook/contriever",
help="The embedding model to use (e.g., 'facebook/contriever', 'text-embedding-3-small').",
)
parser.add_argument(
"--embedding-mode",
type=str,
default="sentence-transformers",
choices=["sentence-transformers", "openai", "mlx"],
help="The embedding backend mode.",
)
parser.add_argument(
"--host",
type=str,
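Together with the hard-coded llm_config lines being replaced by argument-driven selection above, these arguments make the CLI example configurable end to end. A sketch of a typical run (the script path comes from the tests added later in this change set; the flag combination itself is illustrative):

```bash
python examples/main_cli_example.py \
  --llm ollama --model qwen3:8b \
  --embedding-model facebook/contriever \
  --embedding-mode sentence-transformers
```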
@@ -347,7 +347,9 @@ def demo_aggregation():
print(f"\n{'=' * 20} {method.upper()} AGGREGATION {'=' * 20}")

aggregator = MultiVectorAggregator(
aggregation_method=method, spatial_clustering=True, cluster_distance_threshold=100.0
aggregation_method=method,
spatial_clustering=True,
cluster_distance_threshold=100.0,
)

aggregated = aggregator.aggregate_results(mock_results, top_k=5)

@@ -1 +0,0 @@
@@ -163,18 +163,44 @@ class DiskannSearcher(BaseSearcher):

self.num_threads = kwargs.get("num_threads", 8)

fake_zmq_port = 6666
# For DiskANN, we need to reinitialize the index when zmq_port changes
# Store the initialization parameters for later use
full_index_prefix = str(self.index_dir / self.index_path.stem)
self._index = diskannpy.StaticDiskFloatIndex(
metric_enum,
full_index_prefix,
self.num_threads,
kwargs.get("num_nodes_to_cache", 0),
1,
fake_zmq_port, # Initial port, can be updated at runtime
"",
"",
)
self._init_params = {
"metric_enum": metric_enum,
"full_index_prefix": full_index_prefix,
"num_threads": self.num_threads,
"num_nodes_to_cache": kwargs.get("num_nodes_to_cache", 0),
"cache_mechanism": 1,
"pq_prefix": "",
"partition_prefix": "",
}
self._diskannpy = diskannpy
self._current_zmq_port = None
self._index = None
logger.debug("DiskANN searcher initialized (index will be loaded on first search)")

def _ensure_index_loaded(self, zmq_port: int):
"""Ensure the index is loaded with the correct zmq_port."""
if self._index is None or self._current_zmq_port != zmq_port:
# Need to (re)load the index with the correct zmq_port
with suppress_cpp_output_if_needed():
if self._index is not None:
logger.debug(f"Reloading DiskANN index with new zmq_port: {zmq_port}")
else:
logger.debug(f"Loading DiskANN index with zmq_port: {zmq_port}")

self._index = self._diskannpy.StaticDiskFloatIndex(
self._init_params["metric_enum"],
self._init_params["full_index_prefix"],
self._init_params["num_threads"],
self._init_params["num_nodes_to_cache"],
self._init_params["cache_mechanism"],
zmq_port,
self._init_params["pq_prefix"],
self._init_params["partition_prefix"],
)
self._current_zmq_port = zmq_port

def search(
self,
@@ -212,14 +238,15 @@ class DiskannSearcher(BaseSearcher):
Returns:
Dict with 'labels' (list of lists) and 'distances' (ndarray)
"""
# Handle zmq_port compatibility: DiskANN can now update port at runtime
# Handle zmq_port compatibility: Ensure index is loaded with correct port
if recompute_embeddings:
if zmq_port is None:
raise ValueError("zmq_port must be provided if recompute_embeddings is True")
current_port = self._index.get_zmq_port()
if zmq_port != current_port:
logger.debug(f"Updating DiskANN zmq_port from {current_port} to {zmq_port}")
self._index.set_zmq_port(zmq_port)
self._ensure_index_loaded(zmq_port)
else:
# If not recomputing, we still need an index, use a default port
if self._index is None:
self._ensure_index_loaded(6666)  # Default port when not recomputing

# DiskANN doesn't support "proportional" strategy
if pruning_strategy == "proportional":
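The searcher now defers index construction until the first search and rebuilds only when the requested ZMQ port differs from the one the index was loaded with. Reduced to its essentials, the pattern looks like this (illustrative names, not the actual diskannpy API):

```python
class LazyPortBoundIndex:
    """Rebuild an expensive resource only when it is missing or its port changed."""

    def __init__(self, factory, **init_params):
        self._factory = factory          # callable that builds the real index
        self._init_params = init_params  # everything else needed to build it
        self._index = None
        self._current_port = None

    def ensure_loaded(self, port: int):
        # (Re)construct only on first use or when the port changes.
        if self._index is None or self._current_port != port:
            self._index = self._factory(zmq_port=port, **self._init_params)
            self._current_port = port
        return self._index
```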
@@ -36,6 +36,7 @@ def create_diskann_embedding_server(
zmq_port: int = 5555,
model_name: str = "sentence-transformers/all-mpnet-base-v2",
embedding_mode: str = "sentence-transformers",
distance_metric: str = "l2",
):
"""
Create and start a ZMQ-based embedding server for DiskANN backend.
@@ -263,6 +264,13 @@ if __name__ == "__main__":
choices=["sentence-transformers", "openai", "mlx"],
help="Embedding backend mode",
)
parser.add_argument(
"--distance-metric",
type=str,
default="l2",
choices=["l2", "mips", "cosine"],
help="Distance metric for similarity computation",
)

args = parser.parse_args()

@@ -272,4 +280,5 @@ if __name__ == "__main__":
zmq_port=args.zmq_port,
model_name=args.model_name,
embedding_mode=args.embedding_mode,
distance_metric=args.distance_metric,
)
@@ -4,8 +4,8 @@ build-backend = "scikit_build_core.build"

[project]
name = "leann-backend-diskann"
version = "0.1.14"
dependencies = ["leann-core==0.1.14", "numpy", "protobuf>=3.19.0"]
version = "0.1.16"
dependencies = ["leann-core==0.1.16", "numpy", "protobuf>=3.19.0"]

[tool.scikit-build]
# Key: simplified CMake path
@@ -16,4 +16,4 @@ wheel.packages = ["leann_backend_diskann"]
editable.mode = "redirect"
cmake.build-type = "Release"
build.verbose = true
build.tool-args = ["-j8"]
build.tool-args = ["-j8"]

@@ -2,12 +2,12 @@ syntax = "proto3";

package protoembedding;

message NodeEmbeddingRequest {
repeated uint32 node_ids = 1;
message NodeEmbeddingRequest {
repeated uint32 node_ids = 1;
}

message NodeEmbeddingResponse {
bytes embeddings_data = 1; // All embedded binary datas
repeated int32 dimensions = 2; // Shape [batch_size, embedding_dim]
repeated uint32 missing_ids = 3; // Missing node ids
}
}
@@ -10,6 +10,14 @@ if(APPLE)
set(OpenMP_C_LIB_NAMES "omp")
set(OpenMP_CXX_LIB_NAMES "omp")
set(OpenMP_omp_LIBRARY "/opt/homebrew/opt/libomp/lib/libomp.dylib")

# Force use of system libc++ to avoid version mismatch
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -stdlib=libc++")
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -stdlib=libc++")
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -stdlib=libc++")

# Set minimum macOS version for better compatibility
set(CMAKE_OSX_DEPLOYMENT_TARGET "11.0" CACHE STRING "Minimum macOS version")
endif()

# Use system ZeroMQ instead of building from source
@@ -52,4 +60,4 @@ set(FAISS_BUILD_AVX512 OFF CACHE BOOL "" FORCE)
# IMPORTANT: Disable building AVX versions to speed up compilation
set(FAISS_BUILD_AVX_VERSIONS OFF CACHE BOOL "" FORCE)

add_subdirectory(third_party/faiss)
add_subdirectory(third_party/faiss)
@@ -72,7 +72,11 @@ def read_vector_raw(f, element_fmt_char):
def read_numpy_vector(f, np_dtype, struct_fmt_char):
"""Reads a vector into a NumPy array."""
count = -1  # Initialize count for robust error handling
print(f" Reading vector (dtype={np_dtype}, fmt='{struct_fmt_char}')... ", end="", flush=True)
print(
f" Reading vector (dtype={np_dtype}, fmt='{struct_fmt_char}')... ",
end="",
flush=True,
)
try:
count, data_bytes = read_vector_raw(f, struct_fmt_char)
print(f"Count={count}, Bytes={len(data_bytes)}")
@@ -647,7 +651,10 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
print(f"Error: Input file not found: {input_filename}", file=sys.stderr)
return False
except MemoryError as e:
print(f"\nFatal MemoryError during conversion: {e}. Insufficient RAM.", file=sys.stderr)
print(
f"\nFatal MemoryError during conversion: {e}. Insufficient RAM.",
file=sys.stderr,
)
# Clean up potentially partially written output file?
try:
os.remove(output_filename)
@@ -28,6 +28,12 @@ def get_metric_map():
}


def normalize_l2(data: np.ndarray) -> np.ndarray:
norms = np.linalg.norm(data, axis=1, keepdims=True)
norms[norms == 0] = 1  # Avoid division by zero
return data / norms


@register_backend("hnsw")
class HNSWBackend(LeannBackendFactoryInterface):
@staticmethod
@@ -76,7 +82,7 @@ class HNSWBuilder(LeannBackendBuilderInterface):
index.hnsw.efConstruction = self.efConstruction

if self.distance_metric.lower() == "cosine":
faiss.normalize_L2(data)
data = normalize_l2(data)

index.add(data.shape[0], faiss.swig_ptr(data))
index_file = index_dir / f"{index_prefix}.index"
@@ -118,7 +124,9 @@ class HNSWSearcher(BaseSearcher):
)
from . import faiss  # type: ignore

self.distance_metric = self.meta.get("distance_metric", "mips").lower()
self.distance_metric = (
self.meta.get("backend_kwargs", {}).get("distance_metric", "mips").lower()
)
metric_enum = get_metric_map().get(self.distance_metric)
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
@@ -186,7 +194,7 @@ class HNSWSearcher(BaseSearcher):
if query.dtype != np.float32:
query = query.astype(np.float32)
if self.distance_metric == "cosine":
faiss.normalize_L2(query)
query = normalize_l2(query)

params = faiss.SearchParametersHNSW()
if zmq_port is not None:
@@ -194,6 +202,16 @@ class HNSWSearcher(BaseSearcher):
params.efSearch = complexity
params.beam_size = beam_width

# For OpenAI embeddings with cosine distance, disable relative distance check
# This prevents early termination when all scores are in a narrow range
embedding_model = self.meta.get("embedding_model", "").lower()
if self.distance_metric == "cosine" and any(
openai_model in embedding_model for openai_model in ["text-embedding", "openai"]
):
params.check_relative_distance = False
else:
params.check_relative_distance = True

# PQ pruning: direct mapping to HNSW's pq_pruning_ratio
params.pq_pruning_ratio = prune_ratio
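normalize_l2 replaces the previous faiss.normalize_L2 calls; unlike the FAISS helper, which normalizes in place, it returns a new array (one plausible motivation for the swap, which the diff itself does not state). A quick numeric check of what it computes:

```python
import numpy as np

def normalize_l2(data: np.ndarray) -> np.ndarray:
    norms = np.linalg.norm(data, axis=1, keepdims=True)
    norms[norms == 0] = 1  # avoid division by zero for all-zero rows
    return data / norms

v = np.array([[3.0, 4.0], [0.0, 0.0]], dtype=np.float32)
print(normalize_l2(v))
# [[0.6 0.8]
#  [0.  0. ]]  -- every non-zero row now has unit L2 norm
```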
@@ -6,10 +6,10 @@ build-backend = "scikit_build_core.build"

[project]
name = "leann-backend-hnsw"
version = "0.1.14"
version = "0.1.16"
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
dependencies = [
"leann-core==0.1.14",
"leann-core==0.1.16",
"numpy",
"pyzmq>=23.0.0",
"msgpack>=1.0.0",
@@ -24,4 +24,4 @@ build.tool-args = ["-j8"]

# CMake definitions to optimize compilation
[tool.scikit-build.cmake.define]
CMAKE_BUILD_PARALLEL_LEVEL = "8"
CMAKE_BUILD_PARALLEL_LEVEL = "8"

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "leann-core"
version = "0.1.14"
version = "0.1.16"
description = "Core API and plugin system for LEANN"
readme = "README.md"
requires-python = ">=3.9"
@@ -46,4 +46,4 @@ colab = [
leann = "leann.cli:main"

[tool.setuptools.packages.find]
where = ["src"]
where = ["src"]
@@ -8,6 +8,10 @@ if platform.system() == "Darwin":
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
os.environ["KMP_BLOCKTIME"] = "0"
# Additional fixes for PyTorch/sentence-transformers on macOS ARM64 only in CI
if os.environ.get("CI") == "true":
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "0"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from .api import LeannBuilder, LeannChat, LeannSearcher
from .registry import BACKEND_REGISTRY, autodiscover_backends

@@ -7,6 +7,7 @@ import json
import logging
import pickle
import time
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal
@@ -22,6 +23,11 @@ from .registry import BACKEND_REGISTRY
logger = logging.getLogger(__name__)


def get_registered_backends() -> list[str]:
"""Get list of registered backend names."""
return list(BACKEND_REGISTRY.keys())


def compute_embeddings(
chunks: list[str],
model_name: str,
@@ -163,6 +169,76 @@ class LeannBuilder:
self.embedding_model = embedding_model
self.dimensions = dimensions
self.embedding_mode = embedding_mode

# Check if we need to use cosine distance for normalized embeddings
normalized_embeddings_models = {
# OpenAI models
("openai", "text-embedding-ada-002"),
("openai", "text-embedding-3-small"),
("openai", "text-embedding-3-large"),
# Voyage AI models
("voyage", "voyage-2"),
("voyage", "voyage-3"),
("voyage", "voyage-large-2"),
("voyage", "voyage-multilingual-2"),
("voyage", "voyage-code-2"),
# Cohere models
("cohere", "embed-english-v3.0"),
("cohere", "embed-multilingual-v3.0"),
("cohere", "embed-english-light-v3.0"),
("cohere", "embed-multilingual-light-v3.0"),
}

# Also check for patterns in model names
is_normalized = False
current_model_lower = embedding_model.lower()
current_mode_lower = embedding_mode.lower()

# Check exact matches
for mode, model in normalized_embeddings_models:
if (current_mode_lower == mode and current_model_lower == model) or (
mode in current_mode_lower and model in current_model_lower
):
is_normalized = True
break

# Check patterns
if not is_normalized:
# OpenAI patterns
if "openai" in current_mode_lower or "openai" in current_model_lower:
if any(
pattern in current_model_lower
for pattern in ["text-embedding", "ada", "3-small", "3-large"]
):
is_normalized = True
# Voyage patterns
elif "voyage" in current_mode_lower or "voyage" in current_model_lower:
is_normalized = True
# Cohere patterns
elif "cohere" in current_mode_lower or "cohere" in current_model_lower:
if "embed" in current_model_lower:
is_normalized = True

# Handle distance metric
if is_normalized and "distance_metric" not in backend_kwargs:
backend_kwargs["distance_metric"] = "cosine"
warnings.warn(
f"Detected normalized embeddings model '{embedding_model}' with mode '{embedding_mode}'. "
f"Automatically setting distance_metric='cosine' for optimal performance. "
f"Normalized embeddings (L2 norm = 1) should use cosine similarity instead of MIPS.",
UserWarning,
stacklevel=2,
)
elif is_normalized and backend_kwargs.get("distance_metric", "").lower() != "cosine":
current_metric = backend_kwargs.get("distance_metric", "mips")
warnings.warn(
f"Warning: Using '{current_metric}' distance metric with normalized embeddings model "
f"'{embedding_model}' may lead to suboptimal search results. "
f"Consider using 'cosine' distance metric for better performance.",
UserWarning,
stacklevel=2,
)

self.backend_kwargs = backend_kwargs
self.chunks: list[dict[str, Any]] = []
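The practical effect of this block: when the builder is given a model known to emit L2-normalized embeddings (OpenAI, Voyage, or Cohere families) and the caller has not pinned a metric, backend_kwargs gains distance_metric="cosine" and a UserWarning is emitted; an explicitly chosen non-cosine metric only triggers a warning. A small sketch of the observable behaviour:

```python
import warnings
from leann import LeannBuilder

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    builder = LeannBuilder(
        backend_name="hnsw",
        embedding_model="text-embedding-3-small",
        embedding_mode="openai",
    )

# Expect a UserWarning mentioning that distance_metric='cosine' was set automatically.
assert any("cosine" in str(w.message) for w in caught)
```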
@@ -245,7 +245,11 @@ def search_hf_models_fuzzy(query: str, limit: int = 10) -> list[str]:

# HF Hub's search is already fuzzy! It handles typos and partial matches
models = list_models(
search=query, filter="text-generation", sort="downloads", direction=-1, limit=limit
search=query,
filter="text-generation",
sort="downloads",
direction=-1,
limit=limit,
)

model_names = [model.id if hasattr(model, "id") else str(model) for model in models]
@@ -582,7 +586,11 @@ class HFChat(LLMInterface):

# Tokenize input
inputs = self.tokenizer(
formatted_prompt, return_tensors="pt", padding=True, truncation=True, max_length=2048
formatted_prompt,
return_tensors="pt",
padding=True,
truncation=True,
max_length=2048,
)

# Move inputs to device
@@ -293,6 +293,8 @@ class EmbeddingServerManager:
command.extend(["--passages-file", str(passages_file)])
if embedding_mode != "sentence-transformers":
command.extend(["--embedding-mode", embedding_mode])
if kwargs.get("distance_metric"):
command.extend(["--distance-metric", kwargs["distance_metric"]])

return command


@@ -63,12 +63,19 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
if not self.embedding_model:
raise ValueError("Cannot use recompute mode without 'embedding_model' in meta.json.")

# Get distance_metric from meta if not provided in kwargs
distance_metric = (
kwargs.get("distance_metric")
or self.meta.get("backend_kwargs", {}).get("distance_metric")
or "mips"
)

server_started, actual_port = self.embedding_server_manager.start_server(
port=port,
model_name=self.embedding_model,
embedding_mode=self.embedding_mode,
passages_file=passages_source_file,
distance_metric=kwargs.get("distance_metric"),
distance_metric=distance_metric,
enable_warmup=kwargs.get("enable_warmup", False),
)
if not server_started:
@@ -16,25 +16,24 @@ uv pip install leann[diskann]

```python
from leann import LeannBuilder, LeannSearcher, LeannChat
from pathlib import Path
INDEX_PATH = str(Path("./").resolve() / "demo.leann")

# Build an index
builder = LeannBuilder(backend_name="hnsw")
builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
builder.build_index("my_index.leann")
builder.add_text("Tung Tung Tung Sahur called—they need their banana‑crocodile hybrid back")
builder.build_index(INDEX_PATH)

# Search
searcher = LeannSearcher("my_index.leann")
results = searcher.search("storage savings", top_k=3)
searcher = LeannSearcher(INDEX_PATH)
results = searcher.search("fantastical AI-generated creatures", top_k=1)

# Chat with your data
chat = LeannChat("my_index.leann", llm_config={"type": "ollama", "model": "llama3.2:1b"})
response = chat.ask("How much storage does LEANN save?")
chat = LeannChat(INDEX_PATH, llm_config={"type": "hf", "model": "Qwen/Qwen3-0.6B"})
response = chat.ask("How much storage does LEANN save?", top_k=1)
```

## Documentation

For full documentation, visit [https://leann.readthedocs.io](https://leann.readthedocs.io)

## License

MIT License
MIT License
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "leann"
version = "0.1.14"
version = "0.1.16"
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
readme = "README.md"
requires-python = ">=3.9"
@@ -36,7 +36,5 @@ diskann = [
]

[project.urls]
Homepage = "https://github.com/yourusername/leann"
Documentation = "https://leann.readthedocs.io"
Repository = "https://github.com/yourusername/leann"
Issues = "https://github.com/yourusername/leann/issues"
Repository = "https://github.com/yichuan-w/LEANN"
Issues = "https://github.com/yichuan-w/LEANN/issues"
@@ -1,6 +1,6 @@
import json
import sqlite3
import xml.etree.ElementTree as ET
import xml.etree.ElementTree as ElementTree
from pathlib import Path
from typing import Annotated

@@ -26,7 +26,7 @@ def get_safe_path(s: str) -> str:
def process_history(history: str):
if history.startswith("<?xml") or history.startswith("<msg>"):
try:
root = ET.fromstring(history)
root = ElementTree.fromstring(history)
title = root.find(".//title").text if root.find(".//title") is not None else None
quoted = (
root.find(".//refermsg/content").text
@@ -52,7 +52,8 @@ def get_message(history: dict | str):

def export_chathistory(user_id: str):
res = requests.get(
"http://localhost:48065/wechat/chatlog", params={"userId": user_id, "count": 100000}
"http://localhost:48065/wechat/chatlog",
params={"userId": user_id, "count": 100000},
).json()
for i in range(len(res["chatLogs"])):
res["chatLogs"][i]["content"] = process_history(res["chatLogs"][i]["content"])
@@ -116,7 +117,8 @@ def export_sqlite(
all_users = requests.get("http://localhost:48065/wechat/allcontacts").json()
for user in tqdm(all_users):
cursor.execute(
"INSERT OR IGNORE INTO users (id, name) VALUES (?, ?)", (user["arg"], user["title"])
"INSERT OR IGNORE INTO users (id, name) VALUES (?, ?)",
(user["arg"], user["title"]),
)
usr_chatlog = export_chathistory(user["arg"])
for msg in usr_chatlog:
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "leann-workspace"
version = "0.1.0"
requires-python = ">=3.10"
requires-python = ">=3.9"

dependencies = [
"leann-core",
@@ -33,8 +33,8 @@ dependencies = [
# LlamaIndex core and readers - updated versions
"llama-index>=0.12.44",
"llama-index-readers-file>=0.4.0", # Essential for PDF parsing
"llama-index-readers-docling",
"llama-index-node-parser-docling",
# "llama-index-readers-docling", # Requires Python >= 3.10
# "llama-index-node-parser-docling", # Requires Python >= 3.10
"llama-index-vector-stores-faiss>=0.4.0",
"llama-index-embeddings-huggingface>=0.5.5",
# Other dependencies
@@ -49,10 +49,21 @@ dependencies = [
dev = [
"pytest>=7.0",
"pytest-cov>=4.0",
"pytest-xdist>=3.0", # For parallel test execution
"black>=23.0",
"ruff>=0.1.0",
"matplotlib",
"huggingface-hub>=0.20.0",
"pre-commit>=3.5.0",
]

test = [
"pytest>=7.0",
"pytest-timeout>=2.0",
"llama-index-core>=0.12.0",
"llama-index-readers-file>=0.4.0",
"python-dotenv>=1.0.0",
"sentence-transformers>=2.2.0",
]

diskann = [
@@ -122,3 +133,24 @@ line-ending = "auto"
dev = [
"ruff>=0.12.4",
]

[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
"openai: marks tests that require OpenAI API key",
]
timeout = 600
addopts = [
"-v",
"--tb=short",
"--strict-markers",
"--disable-warnings",
]
env = [
"HF_HUB_DISABLE_SYMLINKS=1",
"TOKENIZERS_PARALLELISM=false",
]
@@ -19,16 +19,16 @@ uv pip install build twine delocate auditwheel scikit-build-core cmake pybind11
build_package() {
local package_dir=$1
local package_name=$(basename $package_dir)

echo "Building $package_name..."
cd $package_dir

# Clean previous builds
rm -rf dist/ build/ _skbuild/

# Build directly with pip wheel (avoids sdist issues)
pip wheel . --no-deps -w dist

# Repair wheel for binary packages
if [[ "$package_name" != "leann-core" ]] && [[ "$package_name" != "leann" ]]; then
if [[ "$OSTYPE" == "darwin"* ]]; then
@@ -57,7 +57,7 @@ build_package() {
fi
fi
fi

echo "Built wheels in $package_dir/dist/"
ls -la dist/
cd - > /dev/null
@@ -84,4 +84,4 @@ else
fi

echo -e "\nBuild complete! Test with:"
echo "uv pip install packages/*/dist/*.whl"
echo "uv pip install packages/*/dist/*.whl"

@@ -28,4 +28,4 @@ else
fi

echo "✅ Version updated to $NEW_VERSION"
echo "✅ Dependencies updated to use leann-core==$NEW_VERSION"
echo "✅ Dependencies updated to use leann-core==$NEW_VERSION"

@@ -15,4 +15,4 @@ VERSION=$1
git add . && git commit -m "chore: bump version to $VERSION" && git push

# Create release (triggers CI)
gh release create v$VERSION --generate-notes
gh release create v$VERSION --generate-notes

@@ -27,4 +27,4 @@ else
else
echo "Cancelled"
fi
fi
fi
@@ -58,7 +58,8 @@ class GraphWrapper:
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph):
self.static_output = self.model(
input_ids=self.static_input, attention_mask=self.static_attention_mask
input_ids=self.static_input,
attention_mask=self.static_attention_mask,
)
self.use_cuda_graph = True
else:
@@ -82,7 +83,10 @@ class GraphWrapper:
def _warmup(self, num_warmup: int = 3):
with torch.no_grad():
for _ in range(num_warmup):
self.model(input_ids=self.static_input, attention_mask=self.static_attention_mask)
self.model(
input_ids=self.static_input,
attention_mask=self.static_attention_mask,
)

def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
if self.use_cuda_graph:
@@ -261,7 +265,10 @@ class Benchmark:
# print size
print(f"in_features: {in_features}, out_features: {out_features}")
new_module = bnb.nn.Linear8bitLt(
in_features, out_features, bias=bias, has_fp16_weights=False
in_features,
out_features,
bias=bias,
has_fp16_weights=False,
)

# Copy weights and bias
@@ -350,8 +357,6 @@ class Benchmark:
# Try xformers if available (only on CUDA)
if torch.cuda.is_available():
try:
from xformers.ops import memory_efficient_attention  # noqa: F401

if hasattr(model, "enable_xformers_memory_efficient_attention"):
model.enable_xformers_memory_efficient_attention()
print("- Enabled xformers memory efficient attention")
@@ -427,7 +432,11 @@ class Benchmark:
else "cpu"
)
return torch.randint(
0, 1000, (batch_size, self.config.seq_length), device=device, dtype=torch.long
0,
1000,
(batch_size, self.config.seq_length),
device=device,
dtype=torch.long,
)

def _run_inference(
@@ -7,7 +7,7 @@ This directory contains comprehensive sanity checks for the Leann system, ensuri
### `test_distance_functions.py`
Tests all supported distance functions across DiskANN backend:
- ✅ **MIPS** (Maximum Inner Product Search)
- ✅ **L2** (Euclidean Distance)
- ✅ **L2** (Euclidean Distance)
- ✅ **Cosine** (Cosine Similarity)

```bash
@@ -27,7 +27,7 @@ uv run python tests/sanity_checks/test_l2_verification.py
### `test_sanity_check.py`
Comprehensive end-to-end verification including:
- Distance function testing
- Embedding model compatibility
- Embedding model compatibility
- Search result correctness validation
- Backend integration testing

@@ -64,7 +64,7 @@ When all tests pass, you should see:
```
📊 测试结果总结:
mips : ✅ 通过
l2 : ✅ 通过
l2 : ✅ 通过
cosine : ✅ 通过

🎉 测试完成!
@@ -98,7 +98,7 @@ pkill -f "embedding_server"

### Typical Timing (3 documents, consumer hardware):
- **Index Building**: 2-5 seconds per distance function
- **Search Query**: 50-200ms
- **Search Query**: 50-200ms
- **Recompute Mode**: 5-15 seconds (higher accuracy)

### Memory Usage:
@@ -117,4 +117,4 @@ These tests are designed to be run in automated environments:
uv run python tests/sanity_checks/test_l2_verification.py
```

The tests are deterministic and should produce consistent results across different platforms.
The tests are deterministic and should produce consistent results across different platforms.
@@ -115,7 +115,13 @@ def main():
# --- Plotting ---
print("\n--- Generating Plot ---")
plt.figure(figsize=(10, 6))
plt.plot(BATCH_SIZES, results_torch, marker="o", linestyle="-", label=f"PyTorch ({device})")
plt.plot(
BATCH_SIZES,
results_torch,
marker="o",
linestyle="-",
label=f"PyTorch ({device})",
)
plt.plot(BATCH_SIZES, results_mlx, marker="s", linestyle="-", label="MLX")

plt.title(f"Embedding Performance: MLX vs PyTorch\nModel: {MODEL_NAME_TORCH}")

@@ -170,7 +170,11 @@ class Benchmark:

def _create_random_batch(self, batch_size: int) -> torch.Tensor:
return torch.randint(
0, 1000, (batch_size, self.config.seq_length), device=self.device, dtype=torch.long
0,
1000,
(batch_size, self.config.seq_length),
device=self.device,
dtype=torch.long,
)

def _run_inference(self, input_ids: torch.Tensor) -> float:
@@ -256,7 +260,11 @@ def run_mlx_benchmark():
"""Run MLX-specific benchmark"""
if not MLX_AVAILABLE:
print("MLX not available, skipping MLX benchmark")
return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": "MLX not available"}
return {
"max_throughput": 0.0,
"avg_throughput": 0.0,
"error": "MLX not available",
}

config = BenchmarkConfig(model_path="mlx-community/all-MiniLM-L6-v2-4bit", use_mlx=True)

@@ -265,7 +273,11 @@ def run_mlx_benchmark():
results = benchmark.run()

if not results:
return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": "No valid results"}
return {
"max_throughput": 0.0,
"avg_throughput": 0.0,
"error": "No valid results",
}

max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
avg_throughput = np.mean([results[batch_size]["throughput"] for batch_size in results])
87
tests/README.md
Normal file
@@ -0,0 +1,87 @@
# LEANN Tests

This directory contains automated tests for the LEANN project using pytest.

## Test Files

### `test_readme_examples.py`
Tests the examples shown in README.md:
- The basic example code that users see first
- Import statements work correctly
- Different backend options (HNSW, DiskANN)
- Different LLM configuration options

### `test_basic.py`
Basic functionality tests that verify:
- All packages can be imported correctly
- C++ extensions (FAISS, DiskANN) load properly
- Basic index building and searching works for both HNSW and DiskANN backends
- Uses parametrized tests to test both backends

### `test_main_cli.py`
Tests the main CLI example functionality:
- Tests with facebook/contriever embeddings
- Tests with OpenAI embeddings (if API key is available)
- Tests error handling with invalid parameters
- Verifies that normalized embeddings are detected and cosine distance is used

## Running Tests

### Install test dependencies:
```bash
# Using extras
uv pip install -e ".[test]"
```

### Run all tests:
```bash
pytest tests/

# Or with coverage
pytest tests/ --cov=leann --cov-report=html

# Run in parallel (faster)
pytest tests/ -n auto
```

### Run specific tests:
```bash
# Only basic tests
pytest tests/test_basic.py

# Only tests that don't require OpenAI
pytest tests/ -m "not openai"

# Skip slow tests
pytest tests/ -m "not slow"
```

### Run with specific backend:
```bash
# Test only HNSW backend
pytest tests/test_basic.py::test_backend_basic[hnsw]

# Test only DiskANN backend
pytest tests/test_basic.py::test_backend_basic[diskann]
```

## CI/CD Integration

Tests are automatically run in GitHub Actions:
1. After building wheel packages
2. On multiple Python versions (3.9 - 3.13)
3. On both Ubuntu and macOS
4. Using pytest with appropriate markers and flags

### pytest.ini Configuration

The `pytest.ini` file configures:
- Test discovery paths
- Default timeout (600 seconds)
- Environment variables (HF_HUB_DISABLE_SYMLINKS, TOKENIZERS_PARALLELISM)
- Custom markers for slow and OpenAI tests
- Verbose output with short tracebacks

### Known Issues

- OpenAI tests are automatically skipped if no API key is provided
92
tests/test_basic.py
Normal file
@@ -0,0 +1,92 @@
"""
Basic functionality tests for CI pipeline using pytest.
"""

import os
import tempfile
from pathlib import Path

import pytest


def test_imports():
"""Test that all packages can be imported."""

# Test C++ extensions


@pytest.mark.skipif(
os.environ.get("CI") == "true", reason="Skip model tests in CI to avoid MPS memory issues"
)
@pytest.mark.parametrize("backend_name", ["hnsw", "diskann"])
def test_backend_basic(backend_name):
"""Test basic functionality for each backend."""
from leann.api import LeannBuilder, LeannSearcher, SearchResult

# Create temporary directory for index
with tempfile.TemporaryDirectory() as temp_dir:
index_path = str(Path(temp_dir) / f"test.{backend_name}")

# Test with small data
texts = [f"This is document {i} about topic {i % 5}" for i in range(100)]

# Configure builder based on backend
if backend_name == "hnsw":
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="facebook/contriever",
embedding_mode="sentence-transformers",
M=16,
efConstruction=200,
)
else: # diskann
builder = LeannBuilder(
backend_name="diskann",
embedding_model="facebook/contriever",
embedding_mode="sentence-transformers",
num_neighbors=32,
search_list_size=50,
)

# Add texts
for text in texts:
builder.add_text(text)

# Build index
builder.build_index(index_path)

# Test search
searcher = LeannSearcher(index_path)
results = searcher.search("document about topic 2", top_k=5)

# Verify results
assert len(results) > 0
assert isinstance(results[0], SearchResult)
assert "topic 2" in results[0].text or "document" in results[0].text


@pytest.mark.skipif(
os.environ.get("CI") == "true", reason="Skip model tests in CI to avoid MPS memory issues"
)
def test_large_index():
"""Test with larger dataset."""
from leann.api import LeannBuilder, LeannSearcher

with tempfile.TemporaryDirectory() as temp_dir:
index_path = str(Path(temp_dir) / "test_large.hnsw")
texts = [f"Document {i}: {' '.join([f'word{j}' for j in range(50)])}" for i in range(1000)]

builder = LeannBuilder(
backend_name="hnsw",
embedding_model="facebook/contriever",
embedding_mode="sentence-transformers",
)

for text in texts:
builder.add_text(text)

builder.build_index(index_path)

searcher = LeannSearcher(index_path)
results = searcher.search(["word10 word20"], top_k=10)
assert len(results[0]) == 10
49
tests/test_ci_minimal.py
Normal file
@@ -0,0 +1,49 @@
"""
Minimal tests for CI that don't require model loading or significant memory.
"""

import subprocess
import sys


def test_package_imports():
"""Test that all core packages can be imported."""
# Core package

# Backend packages

# Core modules

assert True  # If we get here, imports worked


def test_cli_help():
"""Test that CLI example shows help."""
result = subprocess.run(
[sys.executable, "examples/main_cli_example.py", "--help"], capture_output=True, text=True
)

assert result.returncode == 0
assert "usage:" in result.stdout.lower() or "usage:" in result.stderr.lower()
assert "--llm" in result.stdout or "--llm" in result.stderr


def test_backend_registration():
"""Test that backends are properly registered."""
from leann.api import get_registered_backends

backends = get_registered_backends()
assert "hnsw" in backends
assert "diskann" in backends


def test_version_info():
"""Test that packages have version information."""
import leann
import leann_backend_diskann
import leann_backend_hnsw

# Check that packages have __version__ or can be imported
assert hasattr(leann, "__version__") or True
assert hasattr(leann_backend_hnsw, "__version__") or True
assert hasattr(leann_backend_diskann, "__version__") or True
120
tests/test_main_cli.py
Normal file
@@ -0,0 +1,120 @@
"""
Test main_cli_example functionality using pytest.
"""

import os
import subprocess
import sys
import tempfile
from pathlib import Path

import pytest


@pytest.fixture
def test_data_dir():
"""Return the path to test data directory."""
return Path("examples/data")


@pytest.mark.skipif(
os.environ.get("CI") == "true", reason="Skip model tests in CI to avoid MPS memory issues"
)
def test_main_cli_simulated(test_data_dir):
"""Test main_cli with simulated LLM."""
with tempfile.TemporaryDirectory() as temp_dir:
# Use a subdirectory that doesn't exist yet to force index creation
index_dir = Path(temp_dir) / "test_index"
cmd = [
sys.executable,
"examples/main_cli_example.py",
"--llm",
"simulated",
"--embedding-model",
"facebook/contriever",
"--embedding-mode",
"sentence-transformers",
"--index-dir",
str(index_dir),
"--data-dir",
str(test_data_dir),
"--query",
"What is Pride and Prejudice about?",
]

env = os.environ.copy()
env["HF_HUB_DISABLE_SYMLINKS"] = "1"
env["TOKENIZERS_PARALLELISM"] = "false"

result = subprocess.run(cmd, capture_output=True, text=True, timeout=600, env=env)

# Check return code
assert result.returncode == 0, f"Command failed: {result.stderr}"

# Verify output
output = result.stdout + result.stderr
assert "Leann index built at" in output or "Using existing index" in output
assert "This is a simulated answer" in output


@pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OpenAI API key not available")
def test_main_cli_openai(test_data_dir):
"""Test main_cli with OpenAI embeddings."""
with tempfile.TemporaryDirectory() as temp_dir:
# Use a subdirectory that doesn't exist yet to force index creation
index_dir = Path(temp_dir) / "test_index_openai"
cmd = [
sys.executable,
"examples/main_cli_example.py",
"--llm",
"simulated",  # Use simulated LLM to avoid GPT-4 costs
"--embedding-model",
"text-embedding-3-small",
"--embedding-mode",
"openai",
"--index-dir",
str(index_dir),
"--data-dir",
str(test_data_dir),
"--query",
"What is Pride and Prejudice about?",
]

env = os.environ.copy()
env["TOKENIZERS_PARALLELISM"] = "false"

result = subprocess.run(cmd, capture_output=True, text=True, timeout=600, env=env)

assert result.returncode == 0, f"Command failed: {result.stderr}"

# Verify cosine distance was used
output = result.stdout + result.stderr
assert any(
msg in output
for msg in [
"distance_metric='cosine'",
"Automatically setting distance_metric='cosine'",
"Using cosine distance",
]
)


def test_main_cli_error_handling(test_data_dir):
"""Test main_cli with invalid parameters."""
with tempfile.TemporaryDirectory() as temp_dir:
cmd = [
sys.executable,
"examples/main_cli_example.py",
"--llm",
"invalid_llm_type",
"--index-dir",
temp_dir,
"--data-dir",
str(test_data_dir),
]

result = subprocess.run(cmd, capture_output=True, text=True, timeout=60)

# Should fail with invalid LLM type
assert result.returncode != 0
assert "Unknown LLM type" in result.stderr or "invalid_llm_type" in result.stderr
165
tests/test_readme_examples.py
Normal file
@@ -0,0 +1,165 @@
"""
Test examples from README.md to ensure documentation is accurate.
"""

import os
import platform
import tempfile
from pathlib import Path

import pytest


def test_readme_basic_example():
"""Test the basic example from README.md."""
# Skip on macOS CI due to MPS environment issues with all-MiniLM-L6-v2
if os.environ.get("CI") == "true" and platform.system() == "Darwin":
pytest.skip("Skipping on macOS CI due to MPS environment issues with all-MiniLM-L6-v2")

# This is the exact code from README (with smaller model for CI)
from leann import LeannBuilder, LeannChat, LeannSearcher
from leann.api import SearchResult

with tempfile.TemporaryDirectory() as temp_dir:
INDEX_PATH = str(Path(temp_dir) / "demo.leann")

# Build an index
# In CI, use a smaller model to avoid memory issues
if os.environ.get("CI") == "true":
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="sentence-transformers/all-MiniLM-L6-v2",  # Smaller model
dimensions=384,  # Smaller dimensions
)
else:
builder = LeannBuilder(backend_name="hnsw")
builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
builder.add_text("Tung Tung Tung Sahur called—they need their banana-crocodile hybrid back")
builder.build_index(INDEX_PATH)

# Verify index was created
# The index path should be a directory containing index files
index_dir = Path(INDEX_PATH).parent
assert index_dir.exists()
# Check that index files were created
index_files = list(index_dir.glob(f"{Path(INDEX_PATH).stem}.*"))
assert len(index_files) > 0

# Search
searcher = LeannSearcher(INDEX_PATH)
results = searcher.search("fantastical AI-generated creatures", top_k=1)

# Verify search results
assert len(results) > 0
assert isinstance(results[0], SearchResult)
# The second text about banana-crocodile should be more relevant
assert "banana" in results[0].text or "crocodile" in results[0].text

# Chat with your data (using simulated LLM to avoid external dependencies)
chat = LeannChat(INDEX_PATH, llm_config={"type": "simulated"})
response = chat.ask("How much storage does LEANN save?", top_k=1)

# Verify chat works
assert isinstance(response, str)
assert len(response) > 0


def test_readme_imports():
"""Test that the imports shown in README work correctly."""
# These are the imports shown in README
from leann import LeannBuilder, LeannChat, LeannSearcher

# Verify they are the correct types
assert callable(LeannBuilder)
assert callable(LeannSearcher)
assert callable(LeannChat)


def test_backend_options():
"""Test different backend options mentioned in documentation."""
# Skip on macOS CI due to MPS environment issues with all-MiniLM-L6-v2
if os.environ.get("CI") == "true" and platform.system() == "Darwin":
pytest.skip("Skipping on macOS CI due to MPS environment issues with all-MiniLM-L6-v2")

from leann import LeannBuilder

with tempfile.TemporaryDirectory() as temp_dir:
# Use smaller model in CI to avoid memory issues
if os.environ.get("CI") == "true":
model_args = {
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
"dimensions": 384,
}
else:
model_args = {}

# Test HNSW backend (as shown in README)
hnsw_path = str(Path(temp_dir) / "test_hnsw.leann")
builder_hnsw = LeannBuilder(backend_name="hnsw", **model_args)
builder_hnsw.add_text("Test document for HNSW backend")
builder_hnsw.build_index(hnsw_path)
assert Path(hnsw_path).parent.exists()
assert len(list(Path(hnsw_path).parent.glob(f"{Path(hnsw_path).stem}.*"))) > 0

# Test DiskANN backend (mentioned as available option)
diskann_path = str(Path(temp_dir) / "test_diskann.leann")
builder_diskann = LeannBuilder(backend_name="diskann", **model_args)
builder_diskann.add_text("Test document for DiskANN backend")
builder_diskann.build_index(diskann_path)
assert Path(diskann_path).parent.exists()
assert len(list(Path(diskann_path).parent.glob(f"{Path(diskann_path).stem}.*"))) > 0


def test_llm_config_simulated():
"""Test simulated LLM configuration option."""
# Skip on macOS CI due to MPS environment issues with all-MiniLM-L6-v2
if os.environ.get("CI") == "true" and platform.system() == "Darwin":
pytest.skip("Skipping on macOS CI due to MPS environment issues with all-MiniLM-L6-v2")

from leann import LeannBuilder, LeannChat

with tempfile.TemporaryDirectory() as temp_dir:
# Build a simple index
index_path = str(Path(temp_dir) / "test.leann")
# Use smaller model in CI to avoid memory issues
if os.environ.get("CI") == "true":
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="sentence-transformers/all-MiniLM-L6-v2",
dimensions=384,
)
else:
builder = LeannBuilder(backend_name="hnsw")
builder.add_text("Test document for LLM testing")
builder.build_index(index_path)

# Test simulated LLM config
llm_config = {"type": "simulated"}
chat = LeannChat(index_path, llm_config=llm_config)
response = chat.ask("What is this document about?", top_k=1)

assert isinstance(response, str)
assert len(response) > 0


@pytest.mark.skip(reason="Requires HF model download and may timeout")
def test_llm_config_hf():
"""Test HuggingFace LLM configuration option."""
from leann import LeannBuilder, LeannChat

pytest.importorskip("transformers")  # Skip if transformers not installed

with tempfile.TemporaryDirectory() as temp_dir:
# Build a simple index
index_path = str(Path(temp_dir) / "test.leann")
builder = LeannBuilder(backend_name="hnsw")
builder.add_text("Test document for LLM testing")
builder.build_index(index_path)

# Test HF LLM config
llm_config = {"type": "hf", "model": "Qwen/Qwen3-0.6B"}
chat = LeannChat(index_path, llm_config=llm_config)
response = chat.ask("What is this document about?", top_k=1)

assert isinstance(response, str)
assert len(response) > 0