Compare commits

..

7 Commits

Author SHA1 Message Date
Andy Lee
3dc130760a fix: restore macOS 15 build matrix and correct test path
- Add back macOS 15 configurations for Python 3.9-3.13
- Fix pytest path from test/ to tests/ (correct directory name)

The macOS 15 support was accidentally missing from the matrix, and
pytest was looking for the wrong directory name.
2025-08-12 12:50:33 -07:00
Andy Lee
2761067b7b fix: correct macOS deployment targets based on Homebrew library requirements
The key insight is that Homebrew libraries on each macOS version are
compiled for that specific version:
- macOS 13: Libraries require macOS 13.0 minimum
- macOS 14: Libraries require macOS 14.0 minimum
- macOS 15: Libraries require macOS 15.0 minimum

We cannot build wheels for older macOS versions than what the bundled
Homebrew libraries require. This means:
- macOS 13 runners: Build for macOS 13.0+ (HNSW) and 13.3+ (DiskANN)
- macOS 14 runners: Build for macOS 14.0+
- macOS 15 runners: Build for macOS 15.0+

This ensures delocate-wheel succeeds by matching deployment targets
with the actual minimum versions required by system libraries.
2025-08-12 12:34:56 -07:00
Andy Lee
5f57f4763b fix: add macOS 15 support to deployment target configuration
The issue extends to macOS 15 runners where Homebrew libraries are built
for macOS 15. We must handle all runner versions explicitly:

- macOS 13 runners: Can build for macOS 11.0 (HNSW) and 13.3 (DiskANN)
- macOS 14 runners: Must build for macOS 14.0 (system libraries)
- macOS 15 runners: Must build for macOS 15.0 (system libraries)

This ensures wheels are properly tagged for their actual minimum
supported macOS version, matching the bundled libraries.
2025-08-12 11:48:06 -07:00
Andy Lee
9e01e69038 fix: match deployment target with runner OS for library compatibility
The issue is that Homebrew libraries on macOS 14 runners are built for
macOS 14 and cannot be downgraded. We must use different deployment
targets based on the runner OS:

- macOS 13 runners: Can build for macOS 11.0 (HNSW) and 13.3 (DiskANN)
- macOS 14 runners: Must build for macOS 14.0 (due to system libraries)

This ensures delocate-wheel succeeds by matching the deployment target
with the actual minimum version required by bundled libraries.
2025-08-12 11:30:23 -07:00
Andy Lee
d336f3dbf6 fix: use macOS 13.3 for DiskANN backend as required by LAPACK
DiskANN requires macOS 13.3+ for sgesdd_ LAPACK function, so we must
use 13.3 as the deployment target, not 13.0.
2025-08-12 10:59:48 -07:00
Andy Lee
acf3034171 fix: ensure wheels are compatible with older macOS versions
- Set MACOSX_DEPLOYMENT_TARGET=11.0 for HNSW backend (broad compatibility)
- Set MACOSX_DEPLOYMENT_TARGET=13.0 for DiskANN backend (required for LAPACK)
- Add --require-target-macos-version to delocate-wheel commands
- This fixes CI failures on macos-13 runners while maintaining M4 Mac support

Fixes the issue where wheels built on macos-14 runners were incorrectly
tagged as macosx_14_0, preventing installation on macos-13 runners.
2025-08-12 10:58:35 -07:00
Andy Lee
04623b6be0 feat: add macOS 15 support for M4 Mac compatibility
- Add macos-15 CI builds for Python 3.9-3.13
- Update MACOSX_DEPLOYMENT_TARGET from 11.0/13.3 to 14.0 for broader compatibility
- Addresses issue #34 with Mac M4 wheel compatibility

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-08-12 00:06:19 -07:00
118 changed files with 3743 additions and 20641 deletions

1
.gitattributes vendored Normal file
View File

@@ -0,0 +1 @@
paper_plot/data/big_graph_degree_data.npz filter=lfs diff=lfs merge=lfs -text

View File

@@ -1,50 +0,0 @@
name: Bug Report
description: Report a bug in LEANN
labels: ["bug"]
body:
- type: textarea
id: description
attributes:
label: What happened?
description: A clear description of the bug
validations:
required: true
- type: textarea
id: reproduce
attributes:
label: How to reproduce
placeholder: |
1. Install with...
2. Run command...
3. See error
validations:
required: true
- type: textarea
id: error
attributes:
label: Error message
description: Paste any error messages
render: shell
- type: input
id: version
attributes:
label: LEANN Version
placeholder: "0.1.0"
validations:
required: true
- type: dropdown
id: os
attributes:
label: Operating System
options:
- macOS
- Linux
- Windows
- Docker
validations:
required: true

View File

@@ -1,8 +0,0 @@
blank_issues_enabled: true
contact_links:
- name: Documentation
url: https://github.com/LEANN-RAG/LEANN-RAG/tree/main/docs
about: Read the docs first
- name: Discussions
url: https://github.com/LEANN-RAG/LEANN-RAG/discussions
about: Ask questions and share ideas

View File

@@ -1,27 +0,0 @@
name: Feature Request
description: Suggest a new feature for LEANN
labels: ["enhancement"]
body:
- type: textarea
id: problem
attributes:
label: What problem does this solve?
description: Describe the problem or need
validations:
required: true
- type: textarea
id: solution
attributes:
label: Proposed solution
description: How would you like this to work?
validations:
required: true
- type: textarea
id: example
attributes:
label: Example usage
description: Show how the API might look
render: python

View File

@@ -1,13 +0,0 @@
## What does this PR do?
<!-- Brief description of your changes -->
## Related Issues
Fixes #
## Checklist
- [ ] Tests pass (`uv run pytest`)
- [ ] Code formatted (`ruff format` and `ruff check`)
- [ ] Pre-commit hooks pass (`pre-commit run --all-files`)

View File

@@ -5,7 +5,6 @@ on:
branches: [ main ]
pull_request:
branches: [ main ]
workflow_dispatch:
jobs:
build:

View File

@@ -17,17 +17,26 @@ jobs:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.ref }}
submodules: recursive
- name: Install uv and Python
uses: astral-sh/setup-uv@v6
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
- name: Run pre-commit with only lint group (no project deps)
run: |
uv run --only-group lint pre-commit run --all-files --show-diff-on-failure
- name: Install uv
uses: astral-sh/setup-uv@v4
- name: Install ruff
run: |
uv tool install ruff
- name: Run ruff check
run: |
ruff check .
- name: Run ruff format check
run: |
ruff format --check .
build:
needs: lint
@@ -45,17 +54,6 @@ jobs:
python: '3.12'
- os: ubuntu-22.04
python: '3.13'
# ARM64 Linux builds
- os: ubuntu-24.04-arm
python: '3.9'
- os: ubuntu-24.04-arm
python: '3.10'
- os: ubuntu-24.04-arm
python: '3.11'
- os: ubuntu-24.04-arm
python: '3.12'
- os: ubuntu-24.04-arm
python: '3.13'
- os: macos-14
python: '3.9'
- os: macos-14
@@ -89,64 +87,32 @@ jobs:
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v4
with:
ref: ${{ inputs.ref }}
submodules: recursive
- name: Install uv and Python
uses: astral-sh/setup-uv@v6
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python }}
- name: Install uv
uses: astral-sh/setup-uv@v4
- name: Install system dependencies (Ubuntu)
if: runner.os == 'Linux'
run: |
sudo apt-get update
sudo apt-get install -y libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
pkg-config libabsl-dev libaio-dev libprotobuf-dev \
patchelf
pkg-config libopenblas-dev patchelf libabsl-dev libaio-dev libprotobuf-dev
# Debug: Show system information
echo "🔍 System Information:"
echo "Architecture: $(uname -m)"
echo "OS: $(uname -a)"
echo "CPU info: $(lscpu | head -5)"
# Install math library based on architecture
ARCH=$(uname -m)
echo "🔍 Setting up math library for architecture: $ARCH"
if [[ "$ARCH" == "x86_64" ]]; then
# Install Intel MKL for DiskANN on x86_64
echo "📦 Installing Intel MKL for x86_64..."
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
source /opt/intel/oneapi/setvars.sh
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/compiler/latest/linux/compiler/lib/intel64_lin" >> $GITHUB_ENV
echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/mkl/latest/lib/intel64" >> $GITHUB_ENV
echo "✅ Intel MKL installed for x86_64"
# Debug: Check MKL installation
echo "🔍 MKL Installation Check:"
ls -la /opt/intel/oneapi/mkl/latest/ || echo "MKL directory not found"
ls -la /opt/intel/oneapi/mkl/latest/lib/ || echo "MKL lib directory not found"
elif [[ "$ARCH" == "aarch64" ]]; then
# Use OpenBLAS for ARM64 (MKL installer not compatible with ARM64)
echo "📦 Installing OpenBLAS for ARM64..."
sudo apt-get install -y libopenblas-dev liblapack-dev liblapacke-dev
echo "✅ OpenBLAS installed for ARM64"
# Debug: Check OpenBLAS installation
echo "🔍 OpenBLAS Installation Check:"
dpkg -l | grep openblas || echo "OpenBLAS package not found"
ls -la /usr/lib/aarch64-linux-gnu/openblas/ || echo "OpenBLAS directory not found"
fi
# Debug: Show final library paths
echo "🔍 Final LD_LIBRARY_PATH: $LD_LIBRARY_PATH"
# Install Intel MKL for DiskANN
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
source /opt/intel/oneapi/setvars.sh
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/mkl/latest/lib/intel64:$LD_LIBRARY_PATH" >> $GITHUB_ENV
- name: Install system dependencies (macOS)
if: runner.os == 'macOS'
@@ -156,24 +122,11 @@ jobs:
- name: Install build dependencies
run: |
uv python install ${{ matrix.python }}
uv venv --python ${{ matrix.python }} .uv-build
if [[ "$RUNNER_OS" == "Windows" ]]; then
BUILD_PY=".uv-build\\Scripts\\python.exe"
else
BUILD_PY=".uv-build/bin/python"
fi
uv pip install --python "$BUILD_PY" scikit-build-core numpy swig Cython pybind11
uv pip install --system scikit-build-core numpy swig Cython pybind11
if [[ "$RUNNER_OS" == "Linux" ]]; then
uv pip install --python "$BUILD_PY" auditwheel
uv pip install --system auditwheel
else
uv pip install --python "$BUILD_PY" delocate
fi
if [[ "$RUNNER_OS" == "Windows" ]]; then
echo "$(pwd)\\.uv-build\\Scripts" >> $GITHUB_PATH
else
echo "$(pwd)/.uv-build/bin" >> $GITHUB_PATH
uv pip install --system delocate
fi
- name: Set macOS environment variables
@@ -309,79 +262,34 @@ jobs:
- name: Install built packages for testing
run: |
# Create uv-managed virtual environment with the requested interpreter
uv python install ${{ matrix.python }}
# Create a virtual environment with the correct Python version
uv venv --python ${{ matrix.python }}
source .venv/bin/activate || source .venv/Scripts/activate
if [[ "$RUNNER_OS" == "Windows" ]]; then
UV_PY=".venv\\Scripts\\python.exe"
else
UV_PY=".venv/bin/python"
fi
# Install packages using --find-links to prioritize local builds
uv pip install --find-links packages/leann-core/dist --find-links packages/leann-backend-hnsw/dist --find-links packages/leann-backend-diskann/dist packages/leann-core/dist/*.whl || uv pip install --find-links packages/leann-core/dist packages/leann-core/dist/*.tar.gz
uv pip install --find-links packages/leann-core/dist packages/leann-backend-hnsw/dist/*.whl
uv pip install --find-links packages/leann-core/dist packages/leann-backend-diskann/dist/*.whl
uv pip install packages/leann/dist/*.whl || uv pip install packages/leann/dist/*.tar.gz
# Install test dependency group only (avoids reinstalling project package)
uv pip install --python "$UV_PY" --group test
# Install core wheel built in this job
CORE_WHL=$(find packages/leann-core/dist -maxdepth 1 -name "*.whl" -print -quit)
if [[ -n "$CORE_WHL" ]]; then
uv pip install --python "$UV_PY" "$CORE_WHL"
else
uv pip install --python "$UV_PY" packages/leann-core/dist/*.tar.gz
fi
PY_TAG=$($UV_PY -c "import sys; print(f'cp{sys.version_info[0]}{sys.version_info[1]}')")
if [[ "$RUNNER_OS" == "macOS" ]]; then
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
export MACOSX_DEPLOYMENT_TARGET=13.3
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
export MACOSX_DEPLOYMENT_TARGET=14.0
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
export MACOSX_DEPLOYMENT_TARGET=15.0
fi
fi
HNSW_WHL=$(find packages/leann-backend-hnsw/dist -maxdepth 1 -name "*-${PY_TAG}-*.whl" -print -quit)
if [[ -z "$HNSW_WHL" ]]; then
HNSW_WHL=$(find packages/leann-backend-hnsw/dist -maxdepth 1 -name "*-py3-*.whl" -print -quit)
fi
if [[ -n "$HNSW_WHL" ]]; then
uv pip install --python "$UV_PY" "$HNSW_WHL"
else
uv pip install --python "$UV_PY" ./packages/leann-backend-hnsw
fi
DISKANN_WHL=$(find packages/leann-backend-diskann/dist -maxdepth 1 -name "*-${PY_TAG}-*.whl" -print -quit)
if [[ -z "$DISKANN_WHL" ]]; then
DISKANN_WHL=$(find packages/leann-backend-diskann/dist -maxdepth 1 -name "*-py3-*.whl" -print -quit)
fi
if [[ -n "$DISKANN_WHL" ]]; then
uv pip install --python "$UV_PY" "$DISKANN_WHL"
else
uv pip install --python "$UV_PY" ./packages/leann-backend-diskann
fi
LEANN_WHL=$(find packages/leann/dist -maxdepth 1 -name "*.whl" -print -quit)
if [[ -n "$LEANN_WHL" ]]; then
uv pip install --python "$UV_PY" "$LEANN_WHL"
else
uv pip install --python "$UV_PY" packages/leann/dist/*.tar.gz
fi
# Install test dependencies using extras
uv pip install -e ".[test]"
- name: Run tests with pytest
env:
CI: true
CI: true # Mark as CI environment to skip memory-intensive tests
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
HF_HUB_DISABLE_SYMLINKS: 1
TOKENIZERS_PARALLELISM: false
PYTORCH_ENABLE_MPS_FALLBACK: 0
OMP_NUM_THREADS: 1
MKL_NUM_THREADS: 1
PYTORCH_ENABLE_MPS_FALLBACK: 0 # Disable MPS on macOS CI to avoid memory issues
OMP_NUM_THREADS: 1 # Disable OpenMP parallelism to avoid libomp crashes
MKL_NUM_THREADS: 1 # Single thread for MKL operations
run: |
# Activate virtual environment
source .venv/bin/activate || source .venv/Scripts/activate
pytest tests/ -v --tb=short
# Run tests
pytest -v tests/
- name: Run sanity checks (optional)
run: |
@@ -399,53 +307,3 @@ jobs:
with:
name: packages-${{ matrix.os }}-py${{ matrix.python }}
path: packages/*/dist/
arch-smoke:
name: Arch Linux smoke test (install & import)
needs: build
runs-on: ubuntu-latest
container:
image: archlinux:latest
steps:
- name: Prepare system
run: |
pacman -Syu --noconfirm
pacman -S --noconfirm python python-pip gcc git zlib openssl
- name: Download ALL wheel artifacts from this run
uses: actions/download-artifact@v5
with:
# Don't specify name, download all artifacts
path: ./wheels
- name: Install uv
uses: astral-sh/setup-uv@v6
- name: Create virtual environment and install wheels
run: |
uv venv
source .venv/bin/activate || source .venv/Scripts/activate
uv pip install --find-links wheels leann-core
uv pip install --find-links wheels leann-backend-hnsw
uv pip install --find-links wheels leann-backend-diskann
uv pip install --find-links wheels leann
- name: Import & tiny runtime check
env:
OMP_NUM_THREADS: 1
MKL_NUM_THREADS: 1
run: |
source .venv/bin/activate || source .venv/Scripts/activate
python - <<'PY'
import leann
import leann_backend_hnsw as h
import leann_backend_diskann as d
from leann import LeannBuilder, LeannSearcher
b = LeannBuilder(backend_name="hnsw")
b.add_text("hello arch")
b.build_index("arch_demo.leann")
s = LeannSearcher("arch_demo.leann")
print("search:", s.search("hello", top_k=1))
PY

View File

@@ -14,6 +14,6 @@ jobs:
- uses: actions/checkout@v4
- uses: lycheeverse/lychee-action@v2
with:
args: --no-progress --insecure --user-agent 'curl/7.68.0' README.md docs/ apps/ examples/ benchmarks/
args: --no-progress --insecure README.md docs/ apps/ examples/ benchmarks/
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

13
.gitignore vendored
View File

@@ -18,12 +18,9 @@ demo/experiment_results/**/*.json
*.eml
*.emlx
*.json
*.png
!.vscode/*.json
*.sh
*.txt
!CMakeLists.txt
!llms.txt
latency_breakdown*.json
experiment_results/eval_results/diskann/*.json
aws/
@@ -95,13 +92,3 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/
batchtest.py
tests/__pytest_cache__/
tests/__pycache__/
benchmarks/data/
## multi vector
apps/multimodal/vision-based-pdf-multi-vector/multi-vector-colpali-native-weaviate.py
# Ignore all PDFs (keep data exceptions above) and do not track demo PDFs
# If you need to commit a specific demo PDF, remove this negation locally.
# The following line used to force-add a large demo PDF; remove it to satisfy pre-commit:
# !apps/multimodal/vision-based-pdf-multi-vector/pdfs/2004.12832v2.pdf
!apps/multimodal/vision-based-pdf-multi-vector/fig/*

3
.gitmodules vendored
View File

@@ -14,6 +14,3 @@
[submodule "packages/leann-backend-hnsw/third_party/libzmq"]
path = packages/leann-backend-hnsw/third_party/libzmq
url = https://github.com/zeromq/libzmq.git
[submodule "packages/astchunk-leann"]
path = packages/astchunk-leann
url = https://github.com/yichuan-w/astchunk-leann.git

View File

@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
rev: v4.5.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
@@ -10,8 +10,7 @@ repos:
- id: debug-statements
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.12.7 # Fixed version to match pyproject.toml
rev: v0.2.1
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- id: ruff-format

View File

@@ -1,5 +0,0 @@
{
"recommendations": [
"charliermarsh.ruff",
]
}

22
.vscode/settings.json vendored
View File

@@ -1,22 +0,0 @@
{
"python.defaultInterpreterPath": ".venv/bin/python",
"python.terminal.activateEnvironment": true,
"[python]": {
"editor.defaultFormatter": "charliermarsh.ruff",
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.organizeImports": "explicit",
"source.fixAll": "explicit"
},
"editor.insertSpaces": true,
"editor.tabSize": 4
},
"ruff.enable": true,
"files.watcherExclude": {
"**/.venv/**": true,
"**/__pycache__/**": true,
"**/*.egg-info/**": true,
"**/build/**": true,
"**/dist/**": true
}
}

549
README.md
View File

@@ -5,11 +5,9 @@
<p align="center">
<img src="https://img.shields.io/badge/Python-3.9%20%7C%203.10%20%7C%203.11%20%7C%203.12%20%7C%203.13-blue.svg" alt="Python Versions">
<img src="https://github.com/yichuan-w/LEANN/actions/workflows/build-and-publish.yml/badge.svg" alt="CI Status">
<img src="https://img.shields.io/badge/Platform-Ubuntu%20%26%20Arch%20%26%20WSL%20%7C%20macOS%20(ARM64%2FIntel)-lightgrey" alt="Platform">
<img src="https://img.shields.io/badge/Platform-Ubuntu%20%7C%20macOS%20(ARM64%2FIntel)-lightgrey" alt="Platform">
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
<img src="https://img.shields.io/badge/MCP-Native%20Integration-blue" alt="MCP Integration">
<a href="https://join.slack.com/t/leann-e2u9779/shared_invite/zt-3ckd2f6w1-OX08~NN4gkWhh10PRVBj1Q"><img src="https://img.shields.io/badge/Slack-Join-4A154B?logo=slack&logoColor=white" alt="Join Slack">
<a href="assets/wechat_user_group.JPG" title="Join WeChat group"><img src="https://img.shields.io/badge/WeChat-Join-2DC100?logo=wechat&logoColor=white" alt="Join WeChat group"></a>
</p>
<h2 align="center" tabindex="-1" class="heading-element" dir="auto">
@@ -20,7 +18,7 @@ LEANN is an innovative vector database that democratizes personal AI. Transform
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)** ([WeChat](#-wechat-detective-unlock-your-golden-memories), [iMessage](#-imessage-history-your-personal-conversation-archive)), **[agent memory](#-chatgpt-chat-history-your-personal-ai-conversation-archive)** ([ChatGPT](#-chatgpt-chat-history-your-personal-ai-conversation-archive), [Claude](#-claude-chat-history-your-personal-ai-conversation-archive)), **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)**, **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
\* Claude Code only supports basic `grep`-style keyword search. **LEANN** is a drop-in **semantic search MCP service fully compatible with Claude Code**, unlocking intelligent retrieval without changing your workflow. 🔥 Check out [the easy setup →](packages/leann-mcp/README.md)
@@ -33,7 +31,7 @@ LEANN achieves this through *graph-based selective recomputation* with *high-deg
<img src="assets/effects.png" alt="LEANN vs Traditional Vector DB Storage Comparison" width="70%">
</p>
> **The numbers speak for themselves:** Index 60 million text chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#-storage-comparison)
> **The numbers speak for themselves:** Index 60 million text chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#storage-comparison)
🔒 **Privacy:** Your data never leaves your laptop. No OpenAI, no cloud, no "terms of service".
@@ -72,8 +70,6 @@ uv venv
source .venv/bin/activate
uv pip install leann
```
<!--
> Low-resource? See “Low-resource setups” in the [Configuration Guide](docs/configuration-guide.md#low-resource-setups). -->
<details>
<summary>
@@ -89,60 +85,15 @@ git submodule update --init --recursive
```
**macOS:**
Note: DiskANN requires MacOS 13.3 or later.
```bash
brew install libomp boost protobuf zeromq pkgconf
uv sync --extra diskann
brew install llvm libomp boost protobuf zeromq pkgconf
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync
```
**Linux (Ubuntu/Debian):**
Note: On Ubuntu 20.04, you may need to build a newer Abseil and pin Protobuf (e.g., v3.20.x) for building DiskANN. See [Issue #30](https://github.com/yichuan-w/LEANN/issues/30) for a step-by-step note.
You can manually install [Intel oneAPI MKL](https://www.intel.com/content/www/us/en/developer/tools/oneapi/onemkl.html) instead of `libmkl-full-dev` for DiskANN. You can also use `libopenblas-dev` for building HNSW only, by removing `--extra diskann` in the command below.
**Linux:**
```bash
sudo apt-get update && sudo apt-get install -y \
libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
pkg-config libabsl-dev libaio-dev libprotobuf-dev \
libmkl-full-dev
uv sync --extra diskann
```
**Linux (Arch Linux):**
```bash
sudo pacman -Syu && sudo pacman -S --needed base-devel cmake pkgconf git gcc \
boost boost-libs protobuf abseil-cpp libaio zeromq
# For MKL in DiskANN
sudo pacman -S --needed base-devel git
git clone https://aur.archlinux.org/paru-bin.git
cd paru-bin && makepkg -si
paru -S intel-oneapi-mkl intel-oneapi-compiler
source /opt/intel/oneapi/setvars.sh
uv sync --extra diskann
```
**Linux (RHEL / CentOS Stream / Oracle / Rocky / AlmaLinux):**
See [Issue #50](https://github.com/yichuan-w/LEANN/issues/50) for more details.
```bash
sudo dnf groupinstall -y "Development Tools"
sudo dnf install -y libomp-devel boost-devel protobuf-compiler protobuf-devel \
abseil-cpp-devel libaio-devel zeromq-devel pkgconf-pkg-config
# For MKL in DiskANN
sudo dnf install -y intel-oneapi-mkl intel-oneapi-mkl-devel \
intel-oneapi-openmp || sudo dnf install -y intel-oneapi-compiler
source /opt/intel/oneapi/setvars.sh
uv sync --extra diskann
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
uv sync
```
</details>
@@ -176,16 +127,11 @@ response = chat.ask("How much storage does LEANN save?", top_k=1)
## RAG on Everything!
LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`, `.md`), Apple Mail, Google Search History, WeChat, ChatGPT conversations, Claude conversations, iMessage conversations, and more.
LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`, `.md`), Apple Mail, Google Search History, WeChat, and more.
### Generation Model Setup
#### LLM Backend
LEANN supports many LLM providers for text generation (HuggingFace, Ollama, and Any OpenAI compatible API).
LEANN supports multiple LLM providers for text generation (OpenAI API, HuggingFace, Ollama).
<details>
<summary><strong>🔑 OpenAI API Setup (Default)</strong></summary>
@@ -196,68 +142,6 @@ Set your OpenAI API key as an environment variable:
export OPENAI_API_KEY="your-api-key-here"
```
Make sure to use `--llm openai` flag when using the CLI.
You can also specify the model name with `--llm-model <model-name>` flag.
</details>
<details>
<summary><strong>🛠️ Supported LLM & Embedding Providers (via OpenAI Compatibility)</strong></summary>
Thanks to the widespread adoption of the OpenAI API format, LEANN is compatible out-of-the-box with a vast array of LLM and embedding providers. Simply set the `OPENAI_BASE_URL` and `OPENAI_API_KEY` environment variables to connect to your preferred service.
```sh
export OPENAI_API_KEY="xxx"
export OPENAI_BASE_URL="http://localhost:1234/v1" # base url of the provider
```
To use OpenAI compatible endpoint with the CLI interface:
If you are using it for text generation, make sure to use `--llm openai` flag and specify the model name with `--llm-model <model-name>` flag.
If you are using it for embedding, set the `--embedding-mode openai` flag and specify the model name with `--embedding-model <MODEL>`.
-----
Below is a list of base URLs for common providers to get you started.
### 🖥️ Local Inference Engines (Recommended for full privacy)
| Provider | Sample Base URL |
| ---------------- | --------------------------- |
| **Ollama** | `http://localhost:11434/v1` |
| **LM Studio** | `http://localhost:1234/v1` |
| **vLLM** | `http://localhost:8000/v1` |
| **llama.cpp** | `http://localhost:8080/v1` |
| **SGLang** | `http://localhost:30000/v1` |
| **LiteLLM** | `http://localhost:4000` |
-----
### ☁️ Cloud Providers
> **🚨 A Note on Privacy:** Before choosing a cloud provider, carefully review their privacy and data retention policies. Depending on their terms, your data may be used for their own purposes, including but not limited to human reviews and model training, which can lead to serious consequences if not handled properly.
| Provider | Base URL |
| ---------------- | ---------------------------------------------------------- |
| **OpenAI** | `https://api.openai.com/v1` |
| **OpenRouter** | `https://openrouter.ai/api/v1` |
| **Gemini** | `https://generativelanguage.googleapis.com/v1beta/openai/` |
| **x.AI (Grok)** | `https://api.x.ai/v1` |
| **Groq AI** | `https://api.groq.com/openai/v1` |
| **DeepSeek** | `https://api.deepseek.com/v1` |
| **SiliconFlow** | `https://api.siliconflow.cn/v1` |
| **Zhipu (BigModel)** | `https://open.bigmodel.cn/api/paas/v4/` |
| **Mistral AI** | `https://api.mistral.ai/v1` |
If your provider isn't on this list, don't worry! Check their documentation for an OpenAI-compatible endpoint—chances are, it's OpenAI Compatible too!
</details>
<details>
@@ -287,8 +171,7 @@ ollama pull llama3.2:1b
</details>
## ⭐ Flexible Configuration
### ⭐ Flexible Configuration
LEANN provides flexible parameters for embedding models, search strategies, and data processing to fit your specific needs.
@@ -301,34 +184,34 @@ All RAG examples share these common parameters. **Interactive mode** is availabl
```bash
# Core Parameters (General preprocessing for all examples)
--index-dir DIR # Directory to store the index (default: current directory)
--query "YOUR QUESTION" # Single query mode. Omit for interactive chat (type 'quit' to exit), and now you can play with your index interactively
--max-items N # Limit data preprocessing (default: -1, process all data)
--force-rebuild # Force rebuild index even if it exists
--index-dir DIR # Directory to store the index (default: current directory)
--query "YOUR QUESTION" # Single query mode. Omit for interactive chat (type 'quit' to exit), and now you can play with your index interactively
--max-items N # Limit data preprocessing (default: -1, process all data)
--force-rebuild # Force rebuild index even if it exists
# Embedding Parameters
--embedding-model MODEL # e.g., facebook/contriever, text-embedding-3-small, mlx-community/Qwen3-Embedding-0.6B-8bit or nomic-embed-text
--embedding-mode MODE # sentence-transformers, openai, mlx, or ollama
--embedding-model MODEL # e.g., facebook/contriever, text-embedding-3-small, nomic-embed-text,mlx-community/Qwen3-Embedding-0.6B-8bit or nomic-embed-text
--embedding-mode MODE # sentence-transformers, openai, mlx, or ollama
# LLM Parameters (Text generation models)
--llm TYPE # LLM backend: openai, ollama, or hf (default: openai)
--llm-model MODEL # Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct
--thinking-budget LEVEL # Thinking budget for reasoning models: low/medium/high (supported by o3, o3-mini, GPT-Oss:20b, and other reasoning models)
--llm TYPE # LLM backend: openai, ollama, or hf (default: openai)
--llm-model MODEL # Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct
--thinking-budget LEVEL # Thinking budget for reasoning models: low/medium/high (supported by o3, o3-mini, GPT-Oss:20b, and other reasoning models)
# Search Parameters
--top-k N # Number of results to retrieve (default: 20)
--search-complexity N # Search complexity for graph traversal (default: 32)
--top-k N # Number of results to retrieve (default: 20)
--search-complexity N # Search complexity for graph traversal (default: 32)
# Chunking Parameters
--chunk-size N # Size of text chunks (default varies by source: 256 for most, 192 for WeChat)
--chunk-overlap N # Overlap between chunks (default varies: 25-128 depending on source)
--chunk-size N # Size of text chunks (default varies by source: 256 for most, 192 for WeChat)
--chunk-overlap N # Overlap between chunks (default varies: 25-128 depending on source)
# Index Building Parameters
--backend-name NAME # Backend to use: hnsw or diskann (default: hnsw)
--graph-degree N # Graph degree for index construction (default: 32)
--build-complexity N # Build complexity for index construction (default: 64)
--compact / --no-compact # Use compact storage (default: true). Must be `no-compact` for `no-recompute` build.
--recompute / --no-recompute # Enable/disable embedding recomputation (default: enabled). Should not do a `no-recompute` search in a `recompute` build.
--backend-name NAME # Backend to use: hnsw or diskann (default: hnsw)
--graph-degree N # Graph degree for index construction (default: 32)
--build-complexity N # Build complexity for index construction (default: 64)
--no-compact # Disable compact index storage (compact storage IS enabled to save storage by default)
--no-recompute # Disable embedding recomputation (recomputation IS enabled to save storage by default)
```
</details>
@@ -364,12 +247,6 @@ python -m apps.document_rag --data-dir "~/Documents/Papers" --chunk-size 1024
# Filter only markdown and Python files with smaller chunks
python -m apps.document_rag --data-dir "./docs" --chunk-size 256 --file-types .md .py
# Enable AST-aware chunking for code files
python -m apps.document_rag --enable-code-chunking --data-dir "./my_project"
# Or use the specialized code RAG for better code understanding
python -m apps.code_rag --repo-dir "./my_codebase" --query "How does authentication work?"
```
</details>
@@ -542,268 +419,26 @@ Once the index is built, you can ask questions like:
</details>
### 🤖 ChatGPT Chat History: Your Personal AI Conversation Archive!
Transform your ChatGPT conversations into a searchable knowledge base! Search through all your ChatGPT discussions about coding, research, brainstorming, and more.
```bash
python -m apps.chatgpt_rag --export-path chatgpt_export.html --query "How do I create a list in Python?"
```
**Unlock your AI conversation history.** Never lose track of valuable insights from your ChatGPT discussions again.
<details>
<summary><strong>📋 Click to expand: How to Export ChatGPT Data</strong></summary>
**Step-by-step export process:**
1. **Sign in to ChatGPT**
2. **Click your profile icon** in the top right corner
3. **Navigate to Settings** → **Data Controls**
4. **Click "Export"** under Export Data
5. **Confirm the export** request
6. **Download the ZIP file** from the email link (expires in 24 hours)
7. **Extract or use directly** with LEANN
**Supported formats:**
- `.html` files from ChatGPT exports
- `.zip` archives from ChatGPT
- Directories with multiple export files
</details>
<details>
<summary><strong>📋 Click to expand: ChatGPT-Specific Arguments</strong></summary>
#### Parameters
```bash
--export-path PATH # Path to ChatGPT export file (.html/.zip) or directory (default: ./chatgpt_export)
--separate-messages # Process each message separately instead of concatenated conversations
--chunk-size N # Text chunk size (default: 512)
--chunk-overlap N # Overlap between chunks (default: 128)
```
#### Example Commands
```bash
# Basic usage with HTML export
python -m apps.chatgpt_rag --export-path conversations.html
# Process ZIP archive from ChatGPT
python -m apps.chatgpt_rag --export-path chatgpt_export.zip
# Search with specific query
python -m apps.chatgpt_rag --export-path chatgpt_data.html --query "Python programming help"
# Process individual messages for fine-grained search
python -m apps.chatgpt_rag --separate-messages --export-path chatgpt_export.html
# Process directory containing multiple exports
python -m apps.chatgpt_rag --export-path ./chatgpt_exports/ --max-items 1000
```
</details>
<details>
<summary><strong>💡 Click to expand: Example queries you can try</strong></summary>
Once your ChatGPT conversations are indexed, you can search with queries like:
- "What did I ask ChatGPT about Python programming?"
- "Show me conversations about machine learning algorithms"
- "Find discussions about web development frameworks"
- "What coding advice did ChatGPT give me?"
- "Search for conversations about debugging techniques"
- "Find ChatGPT's recommendations for learning resources"
</details>
### 🤖 Claude Chat History: Your Personal AI Conversation Archive!
Transform your Claude conversations into a searchable knowledge base! Search through all your Claude discussions about coding, research, brainstorming, and more.
```bash
python -m apps.claude_rag --export-path claude_export.json --query "What did I ask about Python dictionaries?"
```
**Unlock your AI conversation history.** Never lose track of valuable insights from your Claude discussions again.
<details>
<summary><strong>📋 Click to expand: How to Export Claude Data</strong></summary>
**Step-by-step export process:**
1. **Open Claude** in your browser
2. **Navigate to Settings** (look for gear icon or settings menu)
3. **Find Export/Download** options in your account settings
4. **Download conversation data** (usually in JSON format)
5. **Place the file** in your project directory
*Note: Claude export methods may vary depending on the interface you're using. Check Claude's help documentation for the most current export instructions.*
**Supported formats:**
- `.json` files (recommended)
- `.zip` archives containing JSON data
- Directories with multiple export files
</details>
<details>
<summary><strong>📋 Click to expand: Claude-Specific Arguments</strong></summary>
#### Parameters
```bash
--export-path PATH # Path to Claude export file (.json/.zip) or directory (default: ./claude_export)
--separate-messages # Process each message separately instead of concatenated conversations
--chunk-size N # Text chunk size (default: 512)
--chunk-overlap N # Overlap between chunks (default: 128)
```
#### Example Commands
```bash
# Basic usage with JSON export
python -m apps.claude_rag --export-path my_claude_conversations.json
# Process ZIP archive from Claude
python -m apps.claude_rag --export-path claude_export.zip
# Search with specific query
python -m apps.claude_rag --export-path claude_data.json --query "machine learning advice"
# Process individual messages for fine-grained search
python -m apps.claude_rag --separate-messages --export-path claude_export.json
# Process directory containing multiple exports
python -m apps.claude_rag --export-path ./claude_exports/ --max-items 1000
```
</details>
<details>
<summary><strong>💡 Click to expand: Example queries you can try</strong></summary>
Once your Claude conversations are indexed, you can search with queries like:
- "What did I ask Claude about Python programming?"
- "Show me conversations about machine learning algorithms"
- "Find discussions about software architecture patterns"
- "What debugging advice did Claude give me?"
- "Search for conversations about data structures"
- "Find Claude's recommendations for learning resources"
</details>
### 💬 iMessage History: Your Personal Conversation Archive!
Transform your iMessage conversations into a searchable knowledge base! Search through all your text messages, group chats, and conversations with friends, family, and colleagues.
```bash
python -m apps.imessage_rag --query "What did we discuss about the weekend plans?"
```
**Unlock your message history.** Never lose track of important conversations, shared links, or memorable moments from your iMessage history.
<details>
<summary><strong>📋 Click to expand: How to Access iMessage Data</strong></summary>
**iMessage data location:**
iMessage conversations are stored in a SQLite database on your Mac at:
```
~/Library/Messages/chat.db
```
**Important setup requirements:**
1. **Grant Full Disk Access** to your terminal or IDE:
- Open **System Preferences** → **Security & Privacy** → **Privacy**
- Select **Full Disk Access** from the left sidebar
- Click the **+** button and add your terminal app (Terminal, iTerm2) or IDE (VS Code, etc.)
- Restart your terminal/IDE after granting access
2. **Alternative: Use a backup database**
- If you have Time Machine backups or manual copies of the database
- Use `--db-path` to specify a custom location
**Supported formats:**
- Direct access to `~/Library/Messages/chat.db` (default)
- Custom database path with `--db-path`
- Works with backup copies of the database
</details>
<details>
<summary><strong>📋 Click to expand: iMessage-Specific Arguments</strong></summary>
#### Parameters
```bash
--db-path PATH # Path to chat.db file (default: ~/Library/Messages/chat.db)
--concatenate-conversations # Group messages by conversation (default: True)
--no-concatenate-conversations # Process each message individually
--chunk-size N # Text chunk size (default: 1000)
--chunk-overlap N # Overlap between chunks (default: 200)
```
#### Example Commands
```bash
# Basic usage (requires Full Disk Access)
python -m apps.imessage_rag
# Search with specific query
python -m apps.imessage_rag --query "family dinner plans"
# Use custom database path
python -m apps.imessage_rag --db-path /path/to/backup/chat.db
# Process individual messages instead of conversations
python -m apps.imessage_rag --no-concatenate-conversations
# Limit processing for testing
python -m apps.imessage_rag --max-items 100 --query "weekend"
```
</details>
<details>
<summary><strong>💡 Click to expand: Example queries you can try</strong></summary>
Once your iMessage conversations are indexed, you can search with queries like:
- "What did we discuss about vacation plans?"
- "Find messages about restaurant recommendations"
- "Show me conversations with John about the project"
- "Search for shared links about technology"
- "Find group chat discussions about weekend events"
- "What did mom say about the family gathering?"
</details>
### 🚀 Claude Code Integration: Transform Your Development Workflow!
<details>
<summary><strong>NEW!! ASTAware Code Chunking</strong></summary>
LEANN features intelligent code chunking that preserves semantic boundaries (functions, classes, methods) for Python, Java, C#, and TypeScript, improving code understanding compared to text-based chunking.
📖 Read the [AST Chunking Guide →](docs/ast_chunking_guide.md)
</details>
**The future of code assistance is here.** Transform your development workflow with LEANN's native MCP integration for Claude Code. Index your entire codebase and get intelligent code assistance directly in your IDE.
**Key features:**
- 🔍 **Semantic code search** across your entire project, fully local index and lightweight
- 🧠 **AST-aware chunking** preserves code structure (functions, classes)
- 🔍 **Semantic code search** across your entire project
- 📚 **Context-aware assistance** for debugging and development
- 🚀 **Zero-config setup** with automatic language detection
```bash
# Install LEANN globally for MCP integration
uv tool install leann-core --with leann
claude mcp add --scope user leann-server -- leann_mcp
uv tool install leann-core
# Setup is automatic - just start using Claude Code!
```
Try our fully agentic pipeline with auto query rewriting, semantic search planning, and more:
![LEANN MCP Integration](assets/mcp_leann.png)
**🔥 Ready to supercharge your coding?** [Complete Setup Guide →](packages/leann-mcp/README.md)
**Ready to supercharge your coding?** [Complete Setup Guide →](packages/leann-mcp/README.md)
## 🖥️ Command Line Interface
@@ -820,8 +455,7 @@ leann --help
**To make it globally available:**
```bash
# Install the LEANN CLI globally using uv tool
uv tool install leann-core --with leann
uv tool install leann
# Now you can use leann from anywhere without activating venv
leann --help
@@ -834,7 +468,7 @@ leann --help
### Usage Examples
```bash
# build from a specific directory, and my_docs is the index name(Here you can also build from multiple dict or multiple files)
# build from a specific directory, and my_docs is the index name
leann build my-docs --docs ./your_documents
# Search your documents
@@ -843,41 +477,32 @@ leann search my-docs "machine learning concepts"
# Interactive chat with your documents
leann ask my-docs --interactive
# Ask a single question (non-interactive)
leann ask my-docs "Where are prompts configured?"
# List all your indexes
leann list
# Remove an index
leann remove my-docs
```
**Key CLI features:**
- Auto-detects document formats (PDF, TXT, MD, DOCX, PPTX + code files)
- **🧠 AST-aware chunking** for Python, Java, C#, TypeScript files
- Smart text chunking with overlap for all other content
- Auto-detects document formats (PDF, TXT, MD, DOCX)
- Smart text chunking with overlap
- Multiple LLM providers (Ollama, OpenAI, HuggingFace)
- Organized index storage in `.leann/indexes/` (project-local)
- Organized index storage in `~/.leann/indexes/`
- Support for advanced search parameters
<details>
<summary><strong>📋 Click to expand: Complete CLI Reference</strong></summary>
You can use `leann --help`, or `leann build --help`, `leann search --help`, `leann ask --help`, `leann list --help`, `leann remove --help` to get the complete CLI reference.
**Build Command:**
```bash
leann build INDEX_NAME --docs DIRECTORY|FILE [DIRECTORY|FILE ...] [OPTIONS]
leann build INDEX_NAME --docs DIRECTORY [OPTIONS]
Options:
--backend {hnsw,diskann} Backend to use (default: hnsw)
--embedding-model MODEL Embedding model (default: facebook/contriever)
--graph-degree N Graph degree (default: 32)
--complexity N Build complexity (default: 64)
--force Force rebuild existing index
--compact / --no-compact Use compact storage (default: true). Must be `no-compact` for `no-recompute` build.
--recompute / --no-recompute Enable recomputation (default: true)
--graph-degree N Graph degree (default: 32)
--complexity N Build complexity (default: 64)
--force Force rebuild existing index
--compact Use compact storage (default: true)
--recompute Enable recomputation (default: true)
```
**Search Command:**
@@ -885,9 +510,9 @@ Options:
leann search INDEX_NAME QUERY [OPTIONS]
Options:
--top-k N Number of results (default: 5)
--complexity N Search complexity (default: 64)
--recompute / --no-recompute Enable/disable embedding recomputation (default: enabled). Should not do a `no-recompute` search in a `recompute` build.
--top-k N Number of results (default: 5)
--complexity N Search complexity (default: 64)
--recompute-embeddings Use recomputation for highest accuracy
--pruning-strategy {global,local,proportional}
```
@@ -902,73 +527,8 @@ Options:
--top-k N Retrieval count (default: 20)
```
**List Command:**
```bash
leann list
# Lists all indexes across all projects with status indicators:
# ✅ - Index is complete and ready to use
# ❌ - Index is incomplete or corrupted
# 📁 - CLI-created index (in .leann/indexes/)
# 📄 - App-created index (*.leann.meta.json files)
```
**Remove Command:**
```bash
leann remove INDEX_NAME [OPTIONS]
Options:
--force, -f Force removal without confirmation
# Smart removal: automatically finds and safely removes indexes
# - Shows all matching indexes across projects
# - Requires confirmation for cross-project removal
# - Interactive selection when multiple matches found
# - Supports both CLI and app-created indexes
```
</details>
## 🚀 Advanced Features
### 🎯 Metadata Filtering
LEANN supports a simple metadata filtering system to enable sophisticated use cases like document filtering by date/type, code search by file extension, and content management based on custom criteria.
```python
# Add metadata during indexing
builder.add_text(
"def authenticate_user(token): ...",
metadata={"file_extension": ".py", "lines_of_code": 25}
)
# Search with filters
results = searcher.search(
query="authentication function",
metadata_filters={
"file_extension": {"==": ".py"},
"lines_of_code": {"<": 100}
}
)
```
**Supported operators**: `==`, `!=`, `<`, `<=`, `>`, `>=`, `in`, `not_in`, `contains`, `starts_with`, `ends_with`, `is_true`, `is_false`
📖 **[Complete Metadata filtering guide →](docs/metadata_filtering.md)**
### 🔍 Grep Search
For exact text matching instead of semantic search, use the `use_grep` parameter:
```python
# Exact text search
results = searcher.search("bananacrocodile", use_grep=True, top_k=1)
```
**Use cases**: Finding specific code patterns, error messages, function names, or exact phrases where semantic similarity isn't needed.
📖 **[Complete grep search guide →](docs/grep_search.md)**
## 🏗️ Architecture & How It Works
<p align="center">
@@ -983,16 +543,12 @@ results = searcher.search("bananacrocodile", use_grep=True, top_k=1)
- **Dynamic batching:** Efficiently batch embedding computations for GPU utilization
- **Two-level search:** Smart graph traversal that prioritizes promising nodes
**Backends:**
- **HNSW** (default): Ideal for most datasets with maximum storage savings through full recomputation
- **DiskANN**: Advanced option with superior search performance, using PQ-based graph traversal with real-time reranking for the best speed-accuracy trade-off
**Backends:** HNSW (default) for most use cases, with optional DiskANN support for billion-scale datasets.
## Benchmarks
**[DiskANN vs HNSW Performance Comparison →](benchmarks/diskann_vs_hnsw_speed_comparison.py)** - Compare search performance between both backends
**[Simple Example: Compare LEANN vs FAISS →](benchmarks/compare_faiss_vs_leann.py)** - See storage savings in action
**[Simple Example: Compare LEANN vs FAISS →](benchmarks/compare_faiss_vs_leann.py)**
### 📊 Storage Comparison
| System | DPR (2.1M) | Wiki (60M) | Chat (400K) | Email (780K) | Browser (38K) |
@@ -1006,8 +562,8 @@ results = searcher.search("bananacrocodile", use_grep=True, top_k=1)
## Reproduce Our Results
```bash
uv run benchmarks/run_evaluation.py # Will auto-download evaluation data and run benchmarks
uv run benchmarks/run_evaluation.py benchmarks/data/indices/rpj_wiki/rpj_wiki --num-queries 2000 # After downloading data, you can run the benchmark with our biggest index
uv pip install -e ".[dev]" # Install dev dependencies
python benchmarks/run_evaluation.py # Will auto-download evaluation data and run benchmarks
```
The evaluation script downloads data automatically on first run. The last three results were tested with partial personal data, and you can reproduce them with your own data!
@@ -1047,9 +603,6 @@ MIT License - see [LICENSE](LICENSE) for details.
Core Contributors: [Yichuan Wang](https://yichuan-w.github.io/) & [Zhifei Li](https://github.com/andylizf).
Active Contributors: [Gabriel Dehan](https://github.com/gabriel-dehan)
We welcome more contributors! Feel free to open issues or submit PRs.
This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/).

View File

@@ -10,9 +10,7 @@ from typing import Any
import dotenv
from leann.api import LeannBuilder, LeannChat
from leann.interactive_utils import create_rag_session
from leann.registry import register_project_directory
from leann.settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
from llama_index.core.node_parser import SentenceSplitter
dotenv.load_dotenv()
@@ -71,32 +69,14 @@ class BaseRAGExample(ABC):
"--embedding-model",
type=str,
default=embedding_model_default,
help=f"Embedding model to use (default: {embedding_model_default}), we provide facebook/contriever, text-embedding-3-small,mlx-community/Qwen3-Embedding-0.6B-8bit or nomic-embed-text",
help=f"Embedding model to use (default: {embedding_model_default})",
)
embedding_group.add_argument(
"--embedding-mode",
type=str,
default="sentence-transformers",
choices=["sentence-transformers", "openai", "mlx", "ollama"],
help="Embedding backend mode (default: sentence-transformers), we provide sentence-transformers, openai, mlx, or ollama",
)
embedding_group.add_argument(
"--embedding-host",
type=str,
default=None,
help="Override Ollama-compatible embedding host",
)
embedding_group.add_argument(
"--embedding-api-base",
type=str,
default=None,
help="Base URL for OpenAI-compatible embedding services",
)
embedding_group.add_argument(
"--embedding-api-key",
type=str,
default=None,
help="API key for embedding service (defaults to OPENAI_API_KEY)",
help="Embedding backend mode (default: sentence-transformers)",
)
# LLM parameters
@@ -106,19 +86,19 @@ class BaseRAGExample(ABC):
type=str,
default="openai",
choices=["openai", "ollama", "hf", "simulated"],
help="LLM backend: openai, ollama, or hf (default: openai)",
help="LLM backend to use (default: openai)",
)
llm_group.add_argument(
"--llm-model",
type=str,
default=None,
help="Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct",
help="LLM model name (default: gpt-4o for openai, llama3.2:1b for ollama)",
)
llm_group.add_argument(
"--llm-host",
type=str,
default=None,
help="Host for Ollama-compatible APIs (defaults to LEANN_OLLAMA_HOST/OLLAMA_HOST)",
default="http://localhost:11434",
help="Host for Ollama API (default: http://localhost:11434)",
)
llm_group.add_argument(
"--thinking-budget",
@@ -127,50 +107,6 @@ class BaseRAGExample(ABC):
default=None,
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
)
llm_group.add_argument(
"--llm-api-base",
type=str,
default=None,
help="Base URL for OpenAI-compatible APIs",
)
llm_group.add_argument(
"--llm-api-key",
type=str,
default=None,
help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)",
)
# AST Chunking parameters
ast_group = parser.add_argument_group("AST Chunking Parameters")
ast_group.add_argument(
"--use-ast-chunking",
action="store_true",
help="Enable AST-aware chunking for code files (requires astchunk)",
)
ast_group.add_argument(
"--ast-chunk-size",
type=int,
default=512,
help="Maximum characters per AST chunk (default: 512)",
)
ast_group.add_argument(
"--ast-chunk-overlap",
type=int,
default=64,
help="Overlap between AST chunks (default: 64)",
)
ast_group.add_argument(
"--code-file-extensions",
nargs="+",
default=None,
help="Additional code file extensions to process with AST chunking (e.g., .py .java .cs .ts)",
)
ast_group.add_argument(
"--ast-fallback-traditional",
action="store_true",
default=True,
help="Fall back to traditional chunking if AST chunking fails (default: True)",
)
# Search parameters
search_group = parser.add_argument_group("Search Parameters")
@@ -237,18 +173,11 @@ class BaseRAGExample(ABC):
if args.llm == "openai":
config["model"] = args.llm_model or "gpt-4o"
config["base_url"] = resolve_openai_base_url(args.llm_api_base)
resolved_key = resolve_openai_api_key(args.llm_api_key)
if resolved_key:
config["api_key"] = resolved_key
elif args.llm == "ollama":
config["model"] = args.llm_model or "llama3.2:1b"
config["host"] = resolve_ollama_host(args.llm_host)
config["host"] = args.llm_host
elif args.llm == "hf":
config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
elif args.llm == "simulated":
# Simulated LLM doesn't need additional configuration
pass
return config
@@ -259,20 +188,10 @@ class BaseRAGExample(ABC):
print(f"\n[Building Index] Creating {self.name} index...")
print(f"Total text chunks: {len(texts)}")
embedding_options: dict[str, Any] = {}
if args.embedding_mode == "ollama":
embedding_options["host"] = resolve_ollama_host(args.embedding_host)
elif args.embedding_mode == "openai":
embedding_options["base_url"] = resolve_openai_base_url(args.embedding_api_base)
resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
if resolved_embedding_key:
embedding_options["api_key"] = resolved_embedding_key
builder = LeannBuilder(
backend_name=args.backend_name,
embedding_model=args.embedding_model,
embedding_mode=args.embedding_mode,
embedding_options=embedding_options or None,
graph_degree=args.graph_degree,
complexity=args.build_complexity,
is_compact=not args.no_compact,
@@ -292,11 +211,6 @@ class BaseRAGExample(ABC):
builder.build_index(index_path)
print(f"Index saved to: {index_path}")
# Register project directory so leann list can discover this index
# The index is saved as args.index_dir/index_name.leann
# We want to register the current working directory where the app is run
register_project_directory(Path.cwd())
return index_path
async def run_interactive_chat(self, args, index_path: str):
@@ -308,32 +222,44 @@ class BaseRAGExample(ABC):
complexity=args.search_complexity,
)
# Create interactive session
session = create_rag_session(
app_name=self.name.lower().replace(" ", "_"), data_description=self.name
)
print(f"\n[Interactive Mode] Chat with your {self.name} data!")
print("Type 'quit' or 'exit' to stop.\n")
def handle_query(query: str):
# Prepare LLM kwargs with thinking budget if specified
llm_kwargs = {}
if hasattr(args, "thinking_budget") and args.thinking_budget:
llm_kwargs["thinking_budget"] = args.thinking_budget
while True:
try:
query = input("You: ").strip()
if query.lower() in ["quit", "exit", "q"]:
print("Goodbye!")
break
response = chat.ask(
query,
top_k=args.top_k,
complexity=args.search_complexity,
llm_kwargs=llm_kwargs,
)
print(f"\nAssistant: {response}\n")
if not query:
continue
session.run_interactive_loop(handle_query)
# Prepare LLM kwargs with thinking budget if specified
llm_kwargs = {}
if hasattr(args, "thinking_budget") and args.thinking_budget:
llm_kwargs["thinking_budget"] = args.thinking_budget
response = chat.ask(
query,
top_k=args.top_k,
complexity=args.search_complexity,
llm_kwargs=llm_kwargs,
)
print(f"\nAssistant: {response}\n")
except KeyboardInterrupt:
print("\nGoodbye!")
break
except Exception as e:
print(f"Error: {e}")
async def run_single_query(self, args, index_path: str, query: str):
"""Run a single query against the index."""
chat = LeannChat(
index_path,
llm_config=self.get_llm_config(args),
system_prompt=f"You are a helpful assistant that answers questions about {self.name} data.",
complexity=args.search_complexity,
)
@@ -375,3 +301,21 @@ class BaseRAGExample(ABC):
await self.run_single_query(args, index_path, args.query)
else:
await self.run_interactive_chat(args, index_path)
def create_text_chunks(documents, chunk_size=256, chunk_overlap=25) -> list[str]:
"""Helper function to create text chunks from documents."""
node_parser = SentenceSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
separator=" ",
paragraph_separator="\n\n",
)
all_texts = []
for doc in documents:
nodes = node_parser.get_nodes_from_documents([doc])
if nodes:
all_texts.extend(node.get_content() for node in nodes)
return all_texts

View File

@@ -10,8 +10,7 @@ from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from base_rag_example import BaseRAGExample
from chunking import create_text_chunks
from base_rag_example import BaseRAGExample, create_text_chunks
from .history_data.history import ChromeHistoryReader

View File

View File

@@ -1,413 +0,0 @@
"""
ChatGPT export data reader.
Reads and processes ChatGPT export data from chat.html files.
"""
import re
from pathlib import Path
from typing import Any
from zipfile import ZipFile
from bs4 import BeautifulSoup
from llama_index.core import Document
from llama_index.core.readers.base import BaseReader
class ChatGPTReader(BaseReader):
"""
ChatGPT export data reader.
Reads ChatGPT conversation data from exported chat.html files or zip archives.
Processes conversations into structured documents with metadata.
"""
def __init__(self, concatenate_conversations: bool = True) -> None:
"""
Initialize.
Args:
concatenate_conversations: Whether to concatenate messages within conversations for better context
"""
try:
from bs4 import BeautifulSoup # noqa
except ImportError:
raise ImportError("`beautifulsoup4` package not found: `pip install beautifulsoup4`")
self.concatenate_conversations = concatenate_conversations
def _extract_html_from_zip(self, zip_path: Path) -> str | None:
"""
Extract chat.html from ChatGPT export zip file.
Args:
zip_path: Path to the ChatGPT export zip file
Returns:
HTML content as string, or None if not found
"""
try:
with ZipFile(zip_path, "r") as zip_file:
# Look for chat.html or conversations.html
html_files = [
f
for f in zip_file.namelist()
if f.endswith(".html") and ("chat" in f.lower() or "conversation" in f.lower())
]
if not html_files:
print(f"No HTML chat file found in {zip_path}")
return None
# Use the first HTML file found
html_file = html_files[0]
print(f"Found HTML file: {html_file}")
with zip_file.open(html_file) as f:
return f.read().decode("utf-8", errors="ignore")
except Exception as e:
print(f"Error extracting HTML from zip {zip_path}: {e}")
return None
def _parse_chatgpt_html(self, html_content: str) -> list[dict]:
"""
Parse ChatGPT HTML export to extract conversations.
Args:
html_content: HTML content from ChatGPT export
Returns:
List of conversation dictionaries
"""
soup = BeautifulSoup(html_content, "html.parser")
conversations = []
# Try different possible structures for ChatGPT exports
# Structure 1: Look for conversation containers
conversation_containers = soup.find_all(
["div", "section"], class_=re.compile(r"conversation|chat", re.I)
)
if not conversation_containers:
# Structure 2: Look for message containers directly
conversation_containers = [soup] # Use the entire document as one conversation
for container in conversation_containers:
conversation = self._extract_conversation_from_container(container)
if conversation and conversation.get("messages"):
conversations.append(conversation)
# If no structured conversations found, try to extract all text as one conversation
if not conversations:
all_text = soup.get_text(separator="\n", strip=True)
if all_text:
conversations.append(
{
"title": "ChatGPT Conversation",
"messages": [{"role": "mixed", "content": all_text, "timestamp": None}],
"timestamp": None,
}
)
return conversations
def _extract_conversation_from_container(self, container) -> dict | None:
"""
Extract conversation data from a container element.
Args:
container: BeautifulSoup element containing conversation
Returns:
Dictionary with conversation data or None
"""
messages = []
# Look for message elements with various possible structures
message_selectors = ['[class*="message"]', '[class*="chat"]', "[data-message]", "p", "div"]
for selector in message_selectors:
message_elements = container.select(selector)
if message_elements:
break
else:
message_elements = []
# If no structured messages found, treat the entire container as one message
if not message_elements:
text_content = container.get_text(separator="\n", strip=True)
if text_content:
messages.append({"role": "mixed", "content": text_content, "timestamp": None})
else:
for element in message_elements:
message = self._extract_message_from_element(element)
if message:
messages.append(message)
if not messages:
return None
# Try to extract conversation title
title_element = container.find(["h1", "h2", "h3", "title"])
title = title_element.get_text(strip=True) if title_element else "ChatGPT Conversation"
# Try to extract timestamp from various possible locations
timestamp = self._extract_timestamp_from_container(container)
return {"title": title, "messages": messages, "timestamp": timestamp}
def _extract_message_from_element(self, element) -> dict | None:
"""
Extract message data from an element.
Args:
element: BeautifulSoup element containing message
Returns:
Dictionary with message data or None
"""
text_content = element.get_text(separator=" ", strip=True)
# Skip empty or very short messages
if not text_content or len(text_content.strip()) < 3:
return None
# Try to determine role (user/assistant) from class names or content
role = "mixed" # Default role
class_names = " ".join(element.get("class", [])).lower()
if "user" in class_names or "human" in class_names:
role = "user"
elif "assistant" in class_names or "ai" in class_names or "gpt" in class_names:
role = "assistant"
elif text_content.lower().startswith(("you:", "user:", "me:")):
role = "user"
text_content = re.sub(r"^(you|user|me):\s*", "", text_content, flags=re.IGNORECASE)
elif text_content.lower().startswith(("chatgpt:", "assistant:", "ai:")):
role = "assistant"
text_content = re.sub(
r"^(chatgpt|assistant|ai):\s*", "", text_content, flags=re.IGNORECASE
)
# Try to extract timestamp
timestamp = self._extract_timestamp_from_element(element)
return {"role": role, "content": text_content, "timestamp": timestamp}
def _extract_timestamp_from_element(self, element) -> str | None:
"""Extract timestamp from element."""
# Look for timestamp in various attributes and child elements
timestamp_attrs = ["data-timestamp", "timestamp", "datetime"]
for attr in timestamp_attrs:
if element.get(attr):
return element.get(attr)
# Look for time elements
time_element = element.find("time")
if time_element:
return time_element.get("datetime") or time_element.get_text(strip=True)
# Look for date-like text patterns
text = element.get_text()
date_patterns = [r"\d{4}-\d{2}-\d{2}", r"\d{1,2}/\d{1,2}/\d{4}", r"\w+ \d{1,2}, \d{4}"]
for pattern in date_patterns:
match = re.search(pattern, text)
if match:
return match.group()
return None
def _extract_timestamp_from_container(self, container) -> str | None:
"""Extract timestamp from conversation container."""
return self._extract_timestamp_from_element(container)
def _create_concatenated_content(self, conversation: dict) -> str:
"""
Create concatenated content from conversation messages.
Args:
conversation: Dictionary containing conversation data
Returns:
Formatted concatenated content
"""
title = conversation.get("title", "ChatGPT Conversation")
messages = conversation.get("messages", [])
timestamp = conversation.get("timestamp", "Unknown")
# Build message content
message_parts = []
for message in messages:
role = message.get("role", "mixed")
content = message.get("content", "")
msg_timestamp = message.get("timestamp", "")
if role == "user":
prefix = "[You]"
elif role == "assistant":
prefix = "[ChatGPT]"
else:
prefix = "[Message]"
# Add timestamp if available
if msg_timestamp:
prefix += f" ({msg_timestamp})"
message_parts.append(f"{prefix}: {content}")
concatenated_text = "\n\n".join(message_parts)
# Create final document content
doc_content = f"""Conversation: {title}
Date: {timestamp}
Messages ({len(messages)} messages):
{concatenated_text}
"""
return doc_content
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
"""
Load ChatGPT export data.
Args:
input_dir: Directory containing ChatGPT export files or path to specific file
**load_kwargs:
max_count (int): Maximum number of conversations to process
chatgpt_export_path (str): Specific path to ChatGPT export file/directory
include_metadata (bool): Whether to include metadata in documents
"""
docs: list[Document] = []
max_count = load_kwargs.get("max_count", -1)
chatgpt_export_path = load_kwargs.get("chatgpt_export_path", input_dir)
include_metadata = load_kwargs.get("include_metadata", True)
if not chatgpt_export_path:
print("No ChatGPT export path provided")
return docs
export_path = Path(chatgpt_export_path)
if not export_path.exists():
print(f"ChatGPT export path not found: {export_path}")
return docs
html_content = None
# Handle different input types
if export_path.is_file():
if export_path.suffix.lower() == ".zip":
# Extract HTML from zip file
html_content = self._extract_html_from_zip(export_path)
elif export_path.suffix.lower() == ".html":
# Read HTML file directly
try:
with open(export_path, encoding="utf-8", errors="ignore") as f:
html_content = f.read()
except Exception as e:
print(f"Error reading HTML file {export_path}: {e}")
return docs
else:
print(f"Unsupported file type: {export_path.suffix}")
return docs
elif export_path.is_dir():
# Look for HTML files in directory
html_files = list(export_path.glob("*.html"))
zip_files = list(export_path.glob("*.zip"))
if html_files:
# Use first HTML file found
html_file = html_files[0]
print(f"Found HTML file: {html_file}")
try:
with open(html_file, encoding="utf-8", errors="ignore") as f:
html_content = f.read()
except Exception as e:
print(f"Error reading HTML file {html_file}: {e}")
return docs
elif zip_files:
# Use first zip file found
zip_file = zip_files[0]
print(f"Found zip file: {zip_file}")
html_content = self._extract_html_from_zip(zip_file)
else:
print(f"No HTML or zip files found in {export_path}")
return docs
if not html_content:
print("No HTML content found to process")
return docs
# Parse conversations from HTML
print("Parsing ChatGPT conversations from HTML...")
conversations = self._parse_chatgpt_html(html_content)
if not conversations:
print("No conversations found in HTML content")
return docs
print(f"Found {len(conversations)} conversations")
# Process conversations into documents
count = 0
for conversation in conversations:
if max_count > 0 and count >= max_count:
break
if self.concatenate_conversations:
# Create one document per conversation with concatenated messages
doc_content = self._create_concatenated_content(conversation)
metadata = {}
if include_metadata:
metadata = {
"title": conversation.get("title", "ChatGPT Conversation"),
"timestamp": conversation.get("timestamp", "Unknown"),
"message_count": len(conversation.get("messages", [])),
"source": "ChatGPT Export",
}
doc = Document(text=doc_content, metadata=metadata)
docs.append(doc)
count += 1
else:
# Create separate documents for each message
for message in conversation.get("messages", []):
if max_count > 0 and count >= max_count:
break
role = message.get("role", "mixed")
content = message.get("content", "")
msg_timestamp = message.get("timestamp", "")
if not content.strip():
continue
# Create document content with context
doc_content = f"""Conversation: {conversation.get("title", "ChatGPT Conversation")}
Role: {role}
Timestamp: {msg_timestamp or conversation.get("timestamp", "Unknown")}
Message: {content}
"""
metadata = {}
if include_metadata:
metadata = {
"conversation_title": conversation.get("title", "ChatGPT Conversation"),
"role": role,
"timestamp": msg_timestamp or conversation.get("timestamp", "Unknown"),
"source": "ChatGPT Export",
}
doc = Document(text=doc_content, metadata=metadata)
docs.append(doc)
count += 1
print(f"Created {len(docs)} documents from ChatGPT export")
return docs

View File

@@ -1,186 +0,0 @@
"""
ChatGPT RAG example using the unified interface.
Supports ChatGPT export data from chat.html files.
"""
import sys
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from base_rag_example import BaseRAGExample
from chunking import create_text_chunks
from .chatgpt_data.chatgpt_reader import ChatGPTReader
class ChatGPTRAG(BaseRAGExample):
"""RAG example for ChatGPT conversation data."""
def __init__(self):
# Set default values BEFORE calling super().__init__
self.max_items_default = -1 # Process all conversations by default
self.embedding_model_default = (
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
)
super().__init__(
name="ChatGPT",
description="Process and query ChatGPT conversation exports with LEANN",
default_index_name="chatgpt_conversations_index",
)
def _add_specific_arguments(self, parser):
"""Add ChatGPT-specific arguments."""
chatgpt_group = parser.add_argument_group("ChatGPT Parameters")
chatgpt_group.add_argument(
"--export-path",
type=str,
default="./chatgpt_export",
help="Path to ChatGPT export file (.zip or .html) or directory containing exports (default: ./chatgpt_export)",
)
chatgpt_group.add_argument(
"--concatenate-conversations",
action="store_true",
default=True,
help="Concatenate messages within conversations for better context (default: True)",
)
chatgpt_group.add_argument(
"--separate-messages",
action="store_true",
help="Process each message as a separate document (overrides --concatenate-conversations)",
)
chatgpt_group.add_argument(
"--chunk-size", type=int, default=512, help="Text chunk size (default: 512)"
)
chatgpt_group.add_argument(
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
)
def _find_chatgpt_exports(self, export_path: Path) -> list[Path]:
"""
Find ChatGPT export files in the given path.
Args:
export_path: Path to search for exports
Returns:
List of paths to ChatGPT export files
"""
export_files = []
if export_path.is_file():
if export_path.suffix.lower() in [".zip", ".html"]:
export_files.append(export_path)
elif export_path.is_dir():
# Look for zip and html files
export_files.extend(export_path.glob("*.zip"))
export_files.extend(export_path.glob("*.html"))
return export_files
async def load_data(self, args) -> list[str]:
"""Load ChatGPT export data and convert to text chunks."""
export_path = Path(args.export_path)
if not export_path.exists():
print(f"ChatGPT export path not found: {export_path}")
print(
"Please ensure you have exported your ChatGPT data and placed it in the correct location."
)
print("\nTo export your ChatGPT data:")
print("1. Sign in to ChatGPT")
print("2. Click on your profile icon → Settings → Data Controls")
print("3. Click 'Export' under Export Data")
print("4. Download the zip file from the email link")
print("5. Extract or place the file/directory at the specified path")
return []
# Find export files
export_files = self._find_chatgpt_exports(export_path)
if not export_files:
print(f"No ChatGPT export files (.zip or .html) found in: {export_path}")
return []
print(f"Found {len(export_files)} ChatGPT export files")
# Create reader with appropriate settings
concatenate = args.concatenate_conversations and not args.separate_messages
reader = ChatGPTReader(concatenate_conversations=concatenate)
# Process each export file
all_documents = []
total_processed = 0
for i, export_file in enumerate(export_files):
print(f"\nProcessing export file {i + 1}/{len(export_files)}: {export_file.name}")
try:
# Apply max_items limit per file
max_per_file = -1
if args.max_items > 0:
remaining = args.max_items - total_processed
if remaining <= 0:
break
max_per_file = remaining
# Load conversations
documents = reader.load_data(
chatgpt_export_path=str(export_file),
max_count=max_per_file,
include_metadata=True,
)
if documents:
all_documents.extend(documents)
total_processed += len(documents)
print(f"Processed {len(documents)} conversations from this file")
else:
print(f"No conversations loaded from {export_file}")
except Exception as e:
print(f"Error processing {export_file}: {e}")
continue
if not all_documents:
print("No conversations found to process!")
print("\nTroubleshooting:")
print("- Ensure the export file is a valid ChatGPT export")
print("- Check that the HTML file contains conversation data")
print("- Try extracting the zip file and pointing to the HTML file directly")
return []
print(f"\nTotal conversations processed: {len(all_documents)}")
print("Now starting to split into text chunks... this may take some time")
# Convert to text chunks
all_texts = create_text_chunks(
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
)
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} conversations")
return all_texts
if __name__ == "__main__":
import asyncio
# Example queries for ChatGPT RAG
print("\n🤖 ChatGPT RAG Example")
print("=" * 50)
print("\nExample queries you can try:")
print("- 'What did I ask about Python programming?'")
print("- 'Show me conversations about machine learning'")
print("- 'Find discussions about travel planning'")
print("- 'What advice did ChatGPT give me about career development?'")
print("- 'Search for conversations about cooking recipes'")
print("\nTo get started:")
print("1. Export your ChatGPT data from Settings → Data Controls → Export")
print("2. Place the downloaded zip file or extracted HTML in ./chatgpt_export/")
print("3. Run this script to build your personal ChatGPT knowledge base!")
print("\nOr run without --query for interactive mode\n")
rag = ChatGPTRAG()
asyncio.run(rag.run())

View File

@@ -1,44 +0,0 @@
"""Unified chunking utilities facade.
This module re-exports the packaged utilities from `leann.chunking_utils` so
that both repo apps (importing `chunking`) and installed wheels share one
single implementation. When running from the repo without installation, it
adds the `packages/leann-core/src` directory to `sys.path` as a fallback.
"""
import sys
from pathlib import Path
try:
from leann.chunking_utils import (
CODE_EXTENSIONS,
create_ast_chunks,
create_text_chunks,
create_traditional_chunks,
detect_code_files,
get_language_from_extension,
)
except Exception: # pragma: no cover - best-effort fallback for dev environment
repo_root = Path(__file__).resolve().parents[2]
leann_src = repo_root / "packages" / "leann-core" / "src"
if leann_src.exists():
sys.path.insert(0, str(leann_src))
from leann.chunking_utils import (
CODE_EXTENSIONS,
create_ast_chunks,
create_text_chunks,
create_traditional_chunks,
detect_code_files,
get_language_from_extension,
)
else:
raise
__all__ = [
"CODE_EXTENSIONS",
"create_ast_chunks",
"create_text_chunks",
"create_traditional_chunks",
"detect_code_files",
"get_language_from_extension",
]

View File

View File

@@ -1,420 +0,0 @@
"""
Claude export data reader.
Reads and processes Claude conversation data from exported JSON files.
"""
import json
from pathlib import Path
from typing import Any
from zipfile import ZipFile
from llama_index.core import Document
from llama_index.core.readers.base import BaseReader
class ClaudeReader(BaseReader):
"""
Claude export data reader.
Reads Claude conversation data from exported JSON files or zip archives.
Processes conversations into structured documents with metadata.
"""
def __init__(self, concatenate_conversations: bool = True) -> None:
"""
Initialize.
Args:
concatenate_conversations: Whether to concatenate messages within conversations for better context
"""
self.concatenate_conversations = concatenate_conversations
def _extract_json_from_zip(self, zip_path: Path) -> list[str]:
"""
Extract JSON files from Claude export zip file.
Args:
zip_path: Path to the Claude export zip file
Returns:
List of JSON content strings, or empty list if not found
"""
json_contents = []
try:
with ZipFile(zip_path, "r") as zip_file:
# Look for JSON files
json_files = [f for f in zip_file.namelist() if f.endswith(".json")]
if not json_files:
print(f"No JSON files found in {zip_path}")
return []
print(f"Found {len(json_files)} JSON files in archive")
for json_file in json_files:
with zip_file.open(json_file) as f:
content = f.read().decode("utf-8", errors="ignore")
json_contents.append(content)
except Exception as e:
print(f"Error extracting JSON from zip {zip_path}: {e}")
return json_contents
def _parse_claude_json(self, json_content: str) -> list[dict]:
"""
Parse Claude JSON export to extract conversations.
Args:
json_content: JSON content from Claude export
Returns:
List of conversation dictionaries
"""
try:
data = json.loads(json_content)
except json.JSONDecodeError as e:
print(f"Error parsing JSON: {e}")
return []
conversations = []
# Handle different possible JSON structures
if isinstance(data, list):
# If data is a list of conversations
for item in data:
conversation = self._extract_conversation_from_json(item)
if conversation:
conversations.append(conversation)
elif isinstance(data, dict):
# Check for common structures
if "conversations" in data:
# Structure: {"conversations": [...]}
for item in data["conversations"]:
conversation = self._extract_conversation_from_json(item)
if conversation:
conversations.append(conversation)
elif "messages" in data:
# Single conversation with messages
conversation = self._extract_conversation_from_json(data)
if conversation:
conversations.append(conversation)
else:
# Try to treat the whole object as a conversation
conversation = self._extract_conversation_from_json(data)
if conversation:
conversations.append(conversation)
return conversations
def _extract_conversation_from_json(self, conv_data: dict) -> dict | None:
"""
Extract conversation data from a JSON object.
Args:
conv_data: Dictionary containing conversation data
Returns:
Dictionary with conversation data or None
"""
if not isinstance(conv_data, dict):
return None
messages = []
# Look for messages in various possible structures
message_sources = []
if "messages" in conv_data:
message_sources = conv_data["messages"]
elif "chat" in conv_data:
message_sources = conv_data["chat"]
elif "conversation" in conv_data:
message_sources = conv_data["conversation"]
else:
# If no clear message structure, try to extract from the object itself
if "content" in conv_data and "role" in conv_data:
message_sources = [conv_data]
for msg_data in message_sources:
message = self._extract_message_from_json(msg_data)
if message:
messages.append(message)
if not messages:
return None
# Extract conversation metadata
title = self._extract_title_from_conversation(conv_data, messages)
timestamp = self._extract_timestamp_from_conversation(conv_data)
return {"title": title, "messages": messages, "timestamp": timestamp}
def _extract_message_from_json(self, msg_data: dict) -> dict | None:
"""
Extract message data from a JSON message object.
Args:
msg_data: Dictionary containing message data
Returns:
Dictionary with message data or None
"""
if not isinstance(msg_data, dict):
return None
# Extract content from various possible fields
content = ""
content_fields = ["content", "text", "message", "body"]
for field in content_fields:
if msg_data.get(field):
content = str(msg_data[field])
break
if not content or len(content.strip()) < 3:
return None
# Extract role (user/assistant/human/ai/claude)
role = "mixed" # Default role
role_fields = ["role", "sender", "from", "author", "type"]
for field in role_fields:
if msg_data.get(field):
role_value = str(msg_data[field]).lower()
if role_value in ["user", "human", "person"]:
role = "user"
elif role_value in ["assistant", "ai", "claude", "bot"]:
role = "assistant"
break
# Extract timestamp
timestamp = self._extract_timestamp_from_message(msg_data)
return {"role": role, "content": content, "timestamp": timestamp}
def _extract_timestamp_from_message(self, msg_data: dict) -> str | None:
"""Extract timestamp from message data."""
timestamp_fields = ["timestamp", "created_at", "date", "time"]
for field in timestamp_fields:
if msg_data.get(field):
return str(msg_data[field])
return None
def _extract_timestamp_from_conversation(self, conv_data: dict) -> str | None:
"""Extract timestamp from conversation data."""
timestamp_fields = ["timestamp", "created_at", "date", "updated_at", "last_updated"]
for field in timestamp_fields:
if conv_data.get(field):
return str(conv_data[field])
return None
def _extract_title_from_conversation(self, conv_data: dict, messages: list) -> str:
"""Extract or generate title for conversation."""
# Try to find explicit title
title_fields = ["title", "name", "subject", "topic"]
for field in title_fields:
if conv_data.get(field):
return str(conv_data[field])
# Generate title from first user message
for message in messages:
if message.get("role") == "user":
content = message.get("content", "")
if content:
# Use first 50 characters as title
title = content[:50].strip()
if len(content) > 50:
title += "..."
return title
return "Claude Conversation"
def _create_concatenated_content(self, conversation: dict) -> str:
"""
Create concatenated content from conversation messages.
Args:
conversation: Dictionary containing conversation data
Returns:
Formatted concatenated content
"""
title = conversation.get("title", "Claude Conversation")
messages = conversation.get("messages", [])
timestamp = conversation.get("timestamp", "Unknown")
# Build message content
message_parts = []
for message in messages:
role = message.get("role", "mixed")
content = message.get("content", "")
msg_timestamp = message.get("timestamp", "")
if role == "user":
prefix = "[You]"
elif role == "assistant":
prefix = "[Claude]"
else:
prefix = "[Message]"
# Add timestamp if available
if msg_timestamp:
prefix += f" ({msg_timestamp})"
message_parts.append(f"{prefix}: {content}")
concatenated_text = "\n\n".join(message_parts)
# Create final document content
doc_content = f"""Conversation: {title}
Date: {timestamp}
Messages ({len(messages)} messages):
{concatenated_text}
"""
return doc_content
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
"""
Load Claude export data.
Args:
input_dir: Directory containing Claude export files or path to specific file
**load_kwargs:
max_count (int): Maximum number of conversations to process
claude_export_path (str): Specific path to Claude export file/directory
include_metadata (bool): Whether to include metadata in documents
"""
docs: list[Document] = []
max_count = load_kwargs.get("max_count", -1)
claude_export_path = load_kwargs.get("claude_export_path", input_dir)
include_metadata = load_kwargs.get("include_metadata", True)
if not claude_export_path:
print("No Claude export path provided")
return docs
export_path = Path(claude_export_path)
if not export_path.exists():
print(f"Claude export path not found: {export_path}")
return docs
json_contents = []
# Handle different input types
if export_path.is_file():
if export_path.suffix.lower() == ".zip":
# Extract JSON from zip file
json_contents = self._extract_json_from_zip(export_path)
elif export_path.suffix.lower() == ".json":
# Read JSON file directly
try:
with open(export_path, encoding="utf-8", errors="ignore") as f:
json_contents.append(f.read())
except Exception as e:
print(f"Error reading JSON file {export_path}: {e}")
return docs
else:
print(f"Unsupported file type: {export_path.suffix}")
return docs
elif export_path.is_dir():
# Look for JSON files in directory
json_files = list(export_path.glob("*.json"))
zip_files = list(export_path.glob("*.zip"))
if json_files:
print(f"Found {len(json_files)} JSON files in directory")
for json_file in json_files:
try:
with open(json_file, encoding="utf-8", errors="ignore") as f:
json_contents.append(f.read())
except Exception as e:
print(f"Error reading JSON file {json_file}: {e}")
continue
if zip_files:
print(f"Found {len(zip_files)} ZIP files in directory")
for zip_file in zip_files:
zip_contents = self._extract_json_from_zip(zip_file)
json_contents.extend(zip_contents)
if not json_files and not zip_files:
print(f"No JSON or ZIP files found in {export_path}")
return docs
if not json_contents:
print("No JSON content found to process")
return docs
# Parse conversations from JSON content
print("Parsing Claude conversations from JSON...")
all_conversations = []
for json_content in json_contents:
conversations = self._parse_claude_json(json_content)
all_conversations.extend(conversations)
if not all_conversations:
print("No conversations found in JSON content")
return docs
print(f"Found {len(all_conversations)} conversations")
# Process conversations into documents
count = 0
for conversation in all_conversations:
if max_count > 0 and count >= max_count:
break
if self.concatenate_conversations:
# Create one document per conversation with concatenated messages
doc_content = self._create_concatenated_content(conversation)
metadata = {}
if include_metadata:
metadata = {
"title": conversation.get("title", "Claude Conversation"),
"timestamp": conversation.get("timestamp", "Unknown"),
"message_count": len(conversation.get("messages", [])),
"source": "Claude Export",
}
doc = Document(text=doc_content, metadata=metadata)
docs.append(doc)
count += 1
else:
# Create separate documents for each message
for message in conversation.get("messages", []):
if max_count > 0 and count >= max_count:
break
role = message.get("role", "mixed")
content = message.get("content", "")
msg_timestamp = message.get("timestamp", "")
if not content.strip():
continue
# Create document content with context
doc_content = f"""Conversation: {conversation.get("title", "Claude Conversation")}
Role: {role}
Timestamp: {msg_timestamp or conversation.get("timestamp", "Unknown")}
Message: {content}
"""
metadata = {}
if include_metadata:
metadata = {
"conversation_title": conversation.get("title", "Claude Conversation"),
"role": role,
"timestamp": msg_timestamp or conversation.get("timestamp", "Unknown"),
"source": "Claude Export",
}
doc = Document(text=doc_content, metadata=metadata)
docs.append(doc)
count += 1
print(f"Created {len(docs)} documents from Claude export")
return docs

View File

@@ -1,189 +0,0 @@
"""
Claude RAG example using the unified interface.
Supports Claude export data from JSON files.
"""
import sys
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from base_rag_example import BaseRAGExample
from chunking import create_text_chunks
from .claude_data.claude_reader import ClaudeReader
class ClaudeRAG(BaseRAGExample):
"""RAG example for Claude conversation data."""
def __init__(self):
# Set default values BEFORE calling super().__init__
self.max_items_default = -1 # Process all conversations by default
self.embedding_model_default = (
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
)
super().__init__(
name="Claude",
description="Process and query Claude conversation exports with LEANN",
default_index_name="claude_conversations_index",
)
def _add_specific_arguments(self, parser):
"""Add Claude-specific arguments."""
claude_group = parser.add_argument_group("Claude Parameters")
claude_group.add_argument(
"--export-path",
type=str,
default="./claude_export",
help="Path to Claude export file (.json or .zip) or directory containing exports (default: ./claude_export)",
)
claude_group.add_argument(
"--concatenate-conversations",
action="store_true",
default=True,
help="Concatenate messages within conversations for better context (default: True)",
)
claude_group.add_argument(
"--separate-messages",
action="store_true",
help="Process each message as a separate document (overrides --concatenate-conversations)",
)
claude_group.add_argument(
"--chunk-size", type=int, default=512, help="Text chunk size (default: 512)"
)
claude_group.add_argument(
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
)
def _find_claude_exports(self, export_path: Path) -> list[Path]:
"""
Find Claude export files in the given path.
Args:
export_path: Path to search for exports
Returns:
List of paths to Claude export files
"""
export_files = []
if export_path.is_file():
if export_path.suffix.lower() in [".zip", ".json"]:
export_files.append(export_path)
elif export_path.is_dir():
# Look for zip and json files
export_files.extend(export_path.glob("*.zip"))
export_files.extend(export_path.glob("*.json"))
return export_files
async def load_data(self, args) -> list[str]:
"""Load Claude export data and convert to text chunks."""
export_path = Path(args.export_path)
if not export_path.exists():
print(f"Claude export path not found: {export_path}")
print(
"Please ensure you have exported your Claude data and placed it in the correct location."
)
print("\nTo export your Claude data:")
print("1. Open Claude in your browser")
print("2. Look for export/download options in settings or conversation menu")
print("3. Download the conversation data (usually in JSON format)")
print("4. Place the file/directory at the specified path")
print(
"\nNote: Claude export methods may vary. Check Claude's help documentation for current instructions."
)
return []
# Find export files
export_files = self._find_claude_exports(export_path)
if not export_files:
print(f"No Claude export files (.json or .zip) found in: {export_path}")
return []
print(f"Found {len(export_files)} Claude export files")
# Create reader with appropriate settings
concatenate = args.concatenate_conversations and not args.separate_messages
reader = ClaudeReader(concatenate_conversations=concatenate)
# Process each export file
all_documents = []
total_processed = 0
for i, export_file in enumerate(export_files):
print(f"\nProcessing export file {i + 1}/{len(export_files)}: {export_file.name}")
try:
# Apply max_items limit per file
max_per_file = -1
if args.max_items > 0:
remaining = args.max_items - total_processed
if remaining <= 0:
break
max_per_file = remaining
# Load conversations
documents = reader.load_data(
claude_export_path=str(export_file),
max_count=max_per_file,
include_metadata=True,
)
if documents:
all_documents.extend(documents)
total_processed += len(documents)
print(f"Processed {len(documents)} conversations from this file")
else:
print(f"No conversations loaded from {export_file}")
except Exception as e:
print(f"Error processing {export_file}: {e}")
continue
if not all_documents:
print("No conversations found to process!")
print("\nTroubleshooting:")
print("- Ensure the export file is a valid Claude export")
print("- Check that the JSON file contains conversation data")
print("- Try using a different export format or method")
print("- Check Claude's documentation for current export procedures")
return []
print(f"\nTotal conversations processed: {len(all_documents)}")
print("Now starting to split into text chunks... this may take some time")
# Convert to text chunks
all_texts = create_text_chunks(
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
)
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} conversations")
return all_texts
if __name__ == "__main__":
import asyncio
# Example queries for Claude RAG
print("\n🤖 Claude RAG Example")
print("=" * 50)
print("\nExample queries you can try:")
print("- 'What did I ask Claude about Python programming?'")
print("- 'Show me conversations about machine learning'")
print("- 'Find discussions about code optimization'")
print("- 'What advice did Claude give me about software design?'")
print("- 'Search for conversations about debugging techniques'")
print("\nTo get started:")
print("1. Export your Claude conversation data")
print("2. Place the JSON/ZIP file in ./claude_export/")
print("3. Run this script to build your personal Claude knowledge base!")
print("\nOr run without --query for interactive mode\n")
rag = ClaudeRAG()
asyncio.run(rag.run())

View File

@@ -1,211 +0,0 @@
"""
Code RAG example using AST-aware chunking for optimal code understanding.
Specialized for code repositories with automatic language detection and
optimized chunking parameters.
"""
import sys
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from base_rag_example import BaseRAGExample
from chunking import CODE_EXTENSIONS, create_text_chunks
from llama_index.core import SimpleDirectoryReader
class CodeRAG(BaseRAGExample):
"""Specialized RAG example for code repositories with AST-aware chunking."""
def __init__(self):
super().__init__(
name="Code",
description="Process and query code repositories with AST-aware chunking",
default_index_name="code_index",
)
# Override defaults for code-specific usage
self.embedding_model_default = "facebook/contriever" # Good for code
self.max_items_default = -1 # Process all code files by default
def _add_specific_arguments(self, parser):
"""Add code-specific arguments."""
code_group = parser.add_argument_group("Code Repository Parameters")
code_group.add_argument(
"--repo-dir",
type=str,
default=".",
help="Code repository directory to index (default: current directory)",
)
code_group.add_argument(
"--include-extensions",
nargs="+",
default=list(CODE_EXTENSIONS.keys()),
help="File extensions to include (default: supported code extensions)",
)
code_group.add_argument(
"--exclude-dirs",
nargs="+",
default=[
".git",
"__pycache__",
"node_modules",
"venv",
".venv",
"build",
"dist",
"target",
],
help="Directories to exclude from indexing",
)
code_group.add_argument(
"--max-file-size",
type=int,
default=1000000, # 1MB
help="Maximum file size in bytes to process (default: 1MB)",
)
code_group.add_argument(
"--include-comments",
action="store_true",
help="Include comments in chunking (useful for documentation)",
)
code_group.add_argument(
"--preserve-imports",
action="store_true",
default=True,
help="Try to preserve import statements in chunks (default: True)",
)
async def load_data(self, args) -> list[str]:
"""Load code files and convert to AST-aware chunks."""
print(f"🔍 Scanning code repository: {args.repo_dir}")
print(f"📁 Including extensions: {args.include_extensions}")
print(f"🚫 Excluding directories: {args.exclude_dirs}")
# Check if repository directory exists
repo_path = Path(args.repo_dir)
if not repo_path.exists():
raise ValueError(f"Repository directory not found: {args.repo_dir}")
# Load code files with filtering
reader_kwargs = {
"recursive": True,
"encoding": "utf-8",
"required_exts": args.include_extensions,
"exclude_hidden": True,
}
# Create exclusion filter
def file_filter(file_path: str) -> bool:
"""Filter out unwanted files and directories."""
path = Path(file_path)
# Check file size
try:
if path.stat().st_size > args.max_file_size:
print(f"⚠️ Skipping large file: {path.name} ({path.stat().st_size} bytes)")
return False
except Exception:
return False
# Check if in excluded directory
for exclude_dir in args.exclude_dirs:
if exclude_dir in path.parts:
return False
return True
try:
# Load documents with file filtering
documents = SimpleDirectoryReader(
args.repo_dir,
file_extractor=None, # Use default extractors
**reader_kwargs,
).load_data(show_progress=True)
# Apply custom filtering
filtered_docs = []
for doc in documents:
file_path = doc.metadata.get("file_path", "")
if file_filter(file_path):
filtered_docs.append(doc)
documents = filtered_docs
except Exception as e:
print(f"❌ Error loading code files: {e}")
return []
if not documents:
print(
f"❌ No code files found in {args.repo_dir} with extensions {args.include_extensions}"
)
return []
print(f"✅ Loaded {len(documents)} code files")
# Show breakdown by language/extension
ext_counts = {}
for doc in documents:
file_path = doc.metadata.get("file_path", "")
if file_path:
ext = Path(file_path).suffix.lower()
ext_counts[ext] = ext_counts.get(ext, 0) + 1
print("📊 Files by extension:")
for ext, count in sorted(ext_counts.items()):
print(f" {ext}: {count} files")
# Use AST-aware chunking by default for code
print(
f"🧠 Using AST-aware chunking (chunk_size: {args.ast_chunk_size}, overlap: {args.ast_chunk_overlap})"
)
all_texts = create_text_chunks(
documents,
chunk_size=256, # Fallback for non-code files
chunk_overlap=64,
use_ast_chunking=True, # Always use AST for code RAG
ast_chunk_size=args.ast_chunk_size,
ast_chunk_overlap=args.ast_chunk_overlap,
code_file_extensions=args.include_extensions,
ast_fallback_traditional=True,
)
# Apply max_items limit if specified
if args.max_items > 0 and len(all_texts) > args.max_items:
print(f"⏳ Limiting to {args.max_items} chunks (from {len(all_texts)})")
all_texts = all_texts[: args.max_items]
print(f"✅ Generated {len(all_texts)} code chunks")
return all_texts
if __name__ == "__main__":
import asyncio
# Example queries for code RAG
print("\n💻 Code RAG Example")
print("=" * 50)
print("\nExample queries you can try:")
print("- 'How does the embedding computation work?'")
print("- 'What are the main classes in this codebase?'")
print("- 'Show me the search implementation'")
print("- 'How is error handling implemented?'")
print("- 'What design patterns are used?'")
print("- 'Explain the chunking logic'")
print("\n🚀 Features:")
print("- ✅ AST-aware chunking preserves code structure")
print("- ✅ Automatic language detection")
print("- ✅ Smart filtering of large files and common excludes")
print("- ✅ Optimized for code understanding")
print("\nUsage examples:")
print(" python -m apps.code_rag --repo-dir ./my_project")
print(
" python -m apps.code_rag --include-extensions .py .js --query 'How does authentication work?'"
)
print("\nOr run without --query for interactive mode\n")
rag = CodeRAG()
asyncio.run(rag.run())

View File

@@ -9,8 +9,7 @@ from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from base_rag_example import BaseRAGExample
from chunking import create_text_chunks
from base_rag_example import BaseRAGExample, create_text_chunks
from llama_index.core import SimpleDirectoryReader
@@ -45,11 +44,6 @@ class DocumentRAG(BaseRAGExample):
doc_group.add_argument(
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
)
doc_group.add_argument(
"--enable-code-chunking",
action="store_true",
help="Enable AST-aware chunking for code files in the data directory",
)
async def load_data(self, args) -> list[str]:
"""Load documents and convert to text chunks."""
@@ -82,22 +76,9 @@ class DocumentRAG(BaseRAGExample):
print(f"Loaded {len(documents)} documents")
# Determine chunking strategy
use_ast = args.enable_code_chunking or getattr(args, "use_ast_chunking", False)
if use_ast:
print("Using AST-aware chunking for code files")
# Convert to text chunks with optional AST support
# Convert to text chunks
all_texts = create_text_chunks(
documents,
chunk_size=args.chunk_size,
chunk_overlap=args.chunk_overlap,
use_ast_chunking=use_ast,
ast_chunk_size=getattr(args, "ast_chunk_size", 512),
ast_chunk_overlap=getattr(args, "ast_chunk_overlap", 64),
code_file_extensions=getattr(args, "code_file_extensions", None),
ast_fallback_traditional=getattr(args, "ast_fallback_traditional", True),
documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
)
# Apply max_items limit if specified
@@ -121,10 +102,6 @@ if __name__ == "__main__":
print(
"- 'What is the problem of developing pan gu model Huawei meets? (盘古大模型开发中遇到什么问题?)'"
)
print("\n🚀 NEW: Code-aware chunking available!")
print("- Use --enable-code-chunking to enable AST-aware chunking for code files")
print("- Supports Python, Java, C#, TypeScript files")
print("- Better semantic understanding of code structure")
print("\nOr run without --query for interactive mode\n")
rag = DocumentRAG()

View File

@@ -9,8 +9,7 @@ from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from base_rag_example import BaseRAGExample
from chunking import create_text_chunks
from base_rag_example import BaseRAGExample, create_text_chunks
from .email_data.LEANN_email_reader import EmlxReader

View File

@@ -74,7 +74,7 @@ class ChromeHistoryReader(BaseReader):
if count >= max_count and max_count > 0:
break
last_visit, url, title, visit_count, typed_count, _hidden = row
last_visit, url, title, visit_count, typed_count, hidden = row
# Create document content with metadata embedded in text
doc_content = f"""

View File

@@ -1 +0,0 @@
"""iMessage data processing module."""

View File

@@ -1,342 +0,0 @@
"""
iMessage data reader.
Reads and processes iMessage conversation data from the macOS Messages database.
"""
import sqlite3
from datetime import datetime
from pathlib import Path
from typing import Any
from llama_index.core import Document
from llama_index.core.readers.base import BaseReader
class IMessageReader(BaseReader):
"""
iMessage data reader.
Reads iMessage conversation data from the macOS Messages database (chat.db).
Processes conversations into structured documents with metadata.
"""
def __init__(self, concatenate_conversations: bool = True) -> None:
"""
Initialize.
Args:
concatenate_conversations: Whether to concatenate messages within conversations for better context
"""
self.concatenate_conversations = concatenate_conversations
def _get_default_chat_db_path(self) -> Path:
"""
Get the default path to the iMessage chat database.
Returns:
Path to the chat.db file
"""
home = Path.home()
return home / "Library" / "Messages" / "chat.db"
def _convert_cocoa_timestamp(self, cocoa_timestamp: int) -> str:
"""
Convert Cocoa timestamp to readable format.
Args:
cocoa_timestamp: Timestamp in Cocoa format (nanoseconds since 2001-01-01)
Returns:
Formatted timestamp string
"""
if cocoa_timestamp == 0:
return "Unknown"
try:
# Cocoa timestamp is nanoseconds since 2001-01-01 00:00:00 UTC
# Convert to seconds and add to Unix epoch
cocoa_epoch = datetime(2001, 1, 1)
unix_timestamp = cocoa_timestamp / 1_000_000_000 # Convert nanoseconds to seconds
message_time = cocoa_epoch.timestamp() + unix_timestamp
return datetime.fromtimestamp(message_time).strftime("%Y-%m-%d %H:%M:%S")
except (ValueError, OSError):
return "Unknown"
def _get_contact_name(self, handle_id: str) -> str:
"""
Get a readable contact name from handle ID.
Args:
handle_id: The handle ID (phone number or email)
Returns:
Formatted contact name
"""
if not handle_id:
return "Unknown"
# Clean up phone numbers and emails for display
if "@" in handle_id:
return handle_id # Email address
elif handle_id.startswith("+"):
return handle_id # International phone number
else:
# Try to format as phone number
digits = "".join(filter(str.isdigit, handle_id))
if len(digits) == 10:
return f"({digits[:3]}) {digits[3:6]}-{digits[6:]}"
elif len(digits) == 11 and digits[0] == "1":
return f"+1 ({digits[1:4]}) {digits[4:7]}-{digits[7:]}"
else:
return handle_id
def _read_messages_from_db(self, db_path: Path) -> list[dict]:
"""
Read messages from the iMessage database.
Args:
db_path: Path to the chat.db file
Returns:
List of message dictionaries
"""
if not db_path.exists():
print(f"iMessage database not found at: {db_path}")
return []
try:
# Connect to the database
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
# Query to get messages with chat and handle information
query = """
SELECT
m.ROWID as message_id,
m.text,
m.date,
m.is_from_me,
m.service,
c.chat_identifier,
c.display_name as chat_display_name,
h.id as handle_id,
c.ROWID as chat_id
FROM message m
LEFT JOIN chat_message_join cmj ON m.ROWID = cmj.message_id
LEFT JOIN chat c ON cmj.chat_id = c.ROWID
LEFT JOIN handle h ON m.handle_id = h.ROWID
WHERE m.text IS NOT NULL AND m.text != ''
ORDER BY c.ROWID, m.date
"""
cursor.execute(query)
rows = cursor.fetchall()
messages = []
for row in rows:
(
message_id,
text,
date,
is_from_me,
service,
chat_identifier,
chat_display_name,
handle_id,
chat_id,
) = row
message = {
"message_id": message_id,
"text": text,
"timestamp": self._convert_cocoa_timestamp(date),
"is_from_me": bool(is_from_me),
"service": service or "iMessage",
"chat_identifier": chat_identifier or "Unknown",
"chat_display_name": chat_display_name or "Unknown Chat",
"handle_id": handle_id or "Unknown",
"contact_name": self._get_contact_name(handle_id or ""),
"chat_id": chat_id,
}
messages.append(message)
conn.close()
print(f"Found {len(messages)} messages in database")
return messages
except sqlite3.Error as e:
print(f"Error reading iMessage database: {e}")
return []
except Exception as e:
print(f"Unexpected error reading iMessage database: {e}")
return []
def _group_messages_by_chat(self, messages: list[dict]) -> dict[int, list[dict]]:
"""
Group messages by chat ID.
Args:
messages: List of message dictionaries
Returns:
Dictionary mapping chat_id to list of messages
"""
chats = {}
for message in messages:
chat_id = message["chat_id"]
if chat_id not in chats:
chats[chat_id] = []
chats[chat_id].append(message)
return chats
def _create_concatenated_content(self, chat_id: int, messages: list[dict]) -> str:
"""
Create concatenated content from chat messages.
Args:
chat_id: The chat ID
messages: List of messages in the chat
Returns:
Concatenated text content
"""
if not messages:
return ""
# Get chat info from first message
first_msg = messages[0]
chat_name = first_msg["chat_display_name"]
chat_identifier = first_msg["chat_identifier"]
# Build message content
message_parts = []
for message in messages:
timestamp = message["timestamp"]
is_from_me = message["is_from_me"]
text = message["text"]
contact_name = message["contact_name"]
if is_from_me:
prefix = "[You]"
else:
prefix = f"[{contact_name}]"
if timestamp != "Unknown":
prefix += f" ({timestamp})"
message_parts.append(f"{prefix}: {text}")
concatenated_text = "\n\n".join(message_parts)
doc_content = f"""Chat: {chat_name}
Identifier: {chat_identifier}
Messages ({len(messages)} messages):
{concatenated_text}
"""
return doc_content
def _create_individual_content(self, message: dict) -> str:
"""
Create content for individual message.
Args:
message: Message dictionary
Returns:
Formatted message content
"""
timestamp = message["timestamp"]
is_from_me = message["is_from_me"]
text = message["text"]
contact_name = message["contact_name"]
chat_name = message["chat_display_name"]
sender = "You" if is_from_me else contact_name
return f"""Message from {sender} in chat "{chat_name}"
Time: {timestamp}
Content: {text}
"""
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
"""
Load iMessage data and return as documents.
Args:
input_dir: Optional path to directory containing chat.db file.
If not provided, uses default macOS location.
**load_kwargs: Additional arguments (unused)
Returns:
List of Document objects containing iMessage data
"""
docs = []
# Determine database path
if input_dir:
db_path = Path(input_dir) / "chat.db"
else:
db_path = self._get_default_chat_db_path()
print(f"Reading iMessage database from: {db_path}")
# Read messages from database
messages = self._read_messages_from_db(db_path)
if not messages:
return docs
if self.concatenate_conversations:
# Group messages by chat and create concatenated documents
chats = self._group_messages_by_chat(messages)
for chat_id, chat_messages in chats.items():
if not chat_messages:
continue
content = self._create_concatenated_content(chat_id, chat_messages)
# Create metadata
first_msg = chat_messages[0]
last_msg = chat_messages[-1]
metadata = {
"source": "iMessage",
"chat_id": chat_id,
"chat_name": first_msg["chat_display_name"],
"chat_identifier": first_msg["chat_identifier"],
"message_count": len(chat_messages),
"first_message_date": first_msg["timestamp"],
"last_message_date": last_msg["timestamp"],
"participants": list(
{msg["contact_name"] for msg in chat_messages if not msg["is_from_me"]}
),
}
doc = Document(text=content, metadata=metadata)
docs.append(doc)
else:
# Create individual documents for each message
for message in messages:
content = self._create_individual_content(message)
metadata = {
"source": "iMessage",
"message_id": message["message_id"],
"chat_id": message["chat_id"],
"chat_name": message["chat_display_name"],
"chat_identifier": message["chat_identifier"],
"timestamp": message["timestamp"],
"is_from_me": message["is_from_me"],
"contact_name": message["contact_name"],
"service": message["service"],
}
doc = Document(text=content, metadata=metadata)
docs.append(doc)
print(f"Created {len(docs)} documents from iMessage data")
return docs

View File

@@ -1,125 +0,0 @@
"""
iMessage RAG Example.
This example demonstrates how to build a RAG system on your iMessage conversation history.
"""
import asyncio
from pathlib import Path
from leann.chunking_utils import create_text_chunks
from apps.base_rag_example import BaseRAGExample
from apps.imessage_data.imessage_reader import IMessageReader
class IMessageRAG(BaseRAGExample):
"""RAG example for iMessage conversation history."""
def __init__(self):
super().__init__(
name="iMessage",
description="RAG on your iMessage conversation history",
default_index_name="imessage_index",
)
def _add_specific_arguments(self, parser):
"""Add iMessage-specific arguments."""
imessage_group = parser.add_argument_group("iMessage Parameters")
imessage_group.add_argument(
"--db-path",
type=str,
default=None,
help="Path to iMessage chat.db file (default: ~/Library/Messages/chat.db)",
)
imessage_group.add_argument(
"--concatenate-conversations",
action="store_true",
default=True,
help="Concatenate messages within conversations for better context (default: True)",
)
imessage_group.add_argument(
"--no-concatenate-conversations",
action="store_true",
help="Process each message individually instead of concatenating by conversation",
)
imessage_group.add_argument(
"--chunk-size",
type=int,
default=1000,
help="Maximum characters per text chunk (default: 1000)",
)
imessage_group.add_argument(
"--chunk-overlap",
type=int,
default=200,
help="Overlap between text chunks (default: 200)",
)
async def load_data(self, args) -> list[str]:
"""Load iMessage history and convert to text chunks."""
print("Loading iMessage conversation history...")
# Determine concatenation setting
concatenate = args.concatenate_conversations and not args.no_concatenate_conversations
# Initialize iMessage reader
reader = IMessageReader(concatenate_conversations=concatenate)
# Load documents
try:
if args.db_path:
# Use custom database path
db_dir = str(Path(args.db_path).parent)
documents = reader.load_data(input_dir=db_dir)
else:
# Use default macOS location
documents = reader.load_data()
except Exception as e:
print(f"Error loading iMessage data: {e}")
print("\nTroubleshooting tips:")
print("1. Make sure you have granted Full Disk Access to your terminal/IDE")
print("2. Check that the iMessage database exists at ~/Library/Messages/chat.db")
print("3. Try specifying a custom path with --db-path if you have a backup")
return []
if not documents:
print("No iMessage conversations found!")
return []
print(f"Loaded {len(documents)} iMessage documents")
# Show some statistics
total_messages = sum(doc.metadata.get("message_count", 1) for doc in documents)
print(f"Total messages: {total_messages}")
if concatenate:
# Show chat statistics
chat_names = [doc.metadata.get("chat_name", "Unknown") for doc in documents]
unique_chats = len(set(chat_names))
print(f"Unique conversations: {unique_chats}")
# Convert to text chunks
all_texts = create_text_chunks(
documents,
chunk_size=args.chunk_size,
chunk_overlap=args.chunk_overlap,
)
# Apply max_items limit if specified
if args.max_items > 0:
all_texts = all_texts[: args.max_items]
print(f"Limited to {len(all_texts)} text chunks (max_items={args.max_items})")
return all_texts
async def main():
"""Main entry point."""
app = IMessageRAG()
await app.run()
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,113 +0,0 @@
## Vision-based PDF Multi-Vector Demos (macOS/MPS)
This folder contains two demos to index PDF pages as images and run multi-vector retrieval with ColPali/ColQwen2, plus optional similarity map visualization and answer generation.
### What youll run
- `multi-vector-leann-paper-example.py`: local PDF → pages → embed → build HNSW index → search.
- `multi-vector-leann-similarity-map.py`: HF dataset (default) or local pages → embed → index → retrieve → similarity maps → optional Qwen-VL answer.
## Prerequisites (macOS)
### 1) Homebrew poppler (for pdf2image)
```bash
brew install poppler
which pdfinfo && pdfinfo -v
```
### 2) Python environment
Use uv (recommended) or pip. Python 3.9+.
Using uv:
```bash
uv pip install \
colpali_engine \
pdf2image \
pillow \
matplotlib qwen_vl_utils \
einops \
seaborn
```
Notes:
- On first run, models download from Hugging Face. Login/config if needed.
- The scripts auto-select device: CUDA > MPS > CPU. Verify MPS:
```bash
python -c "import torch; print('MPS available:', bool(getattr(torch.backends, 'mps', None) and torch.backends.mps.is_available()))"
```
## Run the demos
### A) Local PDF example
Converts a local PDF into page images, embeds them, builds an index, and searches.
```bash
cd apps/multimodal/vision-based-pdf-multi-vector
# If you don't have the sample PDF locally, download it (ignored by Git)
mkdir -p pdfs
curl -L -o pdfs/2004.12832v2.pdf https://arxiv.org/pdf/2004.12832.pdf
ls pdfs/2004.12832v2.pdf
# Ensure output dir exists
mkdir -p pages
python multi-vector-leann-paper-example.py
```
Expected:
- Page images in `pages/`.
- Console prints like `Using device=mps, dtype=...` and retrieved file paths for queries.
To use your own PDF: edit `pdf_path` near the top of the script.
### B) Similarity map + answer demo
Uses HF dataset `weaviate/arXiv-AI-papers-multi-vector` by default; can switch to local pages.
```bash
cd apps/multimodal/vision-based-pdf-multi-vector
python multi-vector-leann-similarity-map.py
```
Artifacts (when enabled):
- Retrieved pages: `./figures/retrieved_page_rank{K}.png`
- Similarity maps: `./figures/similarity_map_rank{K}.png`
Key knobs in the script (top of file):
- `QUERY`: your question
- `MODEL`: `"colqwen2"` or `"colpali"`
- `USE_HF_DATASET`: set `False` to use local pages
- `PDF`, `PAGES_DIR`: for local mode
- `INDEX_PATH`, `TOPK`, `FIRST_STAGE_K`, `REBUILD_INDEX`
- `SIMILARITY_MAP`, `SIM_TOKEN_IDX`, `SIM_OUTPUT`
- `ANSWER`, `MAX_NEW_TOKENS` (Qwen-VL)
## Troubleshooting
- pdf2image errors on macOS: ensure `brew install poppler` and `pdfinfo` works in terminal.
- Slow or OOM on MPS: reduce dataset size (e.g., set `MAX_DOCS`) or switch to CPU.
- NaNs on MPS: keep fp32 on MPS (default in similarity-map script); avoid fp16 there.
- First-run model downloads can be large; ensure network access (HF mirrors if needed).
## Notes
- Index files are under `./indexes/`. Delete or set `REBUILD_INDEX=True` to rebuild.
- For local PDFs, page images go to `./pages/`.
### Retrieval and Visualization Example
Example settings in `multi-vector-leann-similarity-map.py`:
- `QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"`
- `SIMILARITY_MAP = True` (to generate heatmaps)
- `TOPK = 1` (save the top retrieved page and its similarity map)
Run:
```bash
cd apps/multimodal/vision-based-pdf-multi-vector
python multi-vector-leann-similarity-map.py
```
Outputs (by default):
- Retrieved page: `./figures/retrieved_page_rank1.png`
- Similarity map: `./figures/similarity_map_rank1.png`
Sample visualization (example result, and the query is "QUERY = "How does Vim model performance and efficiency compared to other models?"
"):
![Similarity map example](fig/image.png)
Notes:
- Set `SIM_TOKEN_IDX` to visualize a specific token index; set `-1` to auto-select the most salient token.
- If you change `SIM_OUTPUT` to a file path (e.g., `./figures/my_map.png`), multiple ranks are saved as `my_map_rank{K}.png`.

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 166 KiB

View File

@@ -1,182 +0,0 @@
from __future__ import annotations
import sys
from pathlib import Path
import numpy as np
def _ensure_repo_paths_importable(current_file: str) -> None:
_repo_root = Path(current_file).resolve().parents[3]
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
if str(_leann_core_src) not in sys.path:
sys.path.append(str(_leann_core_src))
if str(_leann_hnsw_pkg) not in sys.path:
sys.path.append(str(_leann_hnsw_pkg))
_ensure_repo_paths_importable(__file__)
from leann_backend_hnsw.hnsw_backend import HNSWBuilder, HNSWSearcher # noqa: E402
class LeannMultiVector:
def __init__(
self,
index_path: str,
dim: int = 128,
distance_metric: str = "mips",
m: int = 16,
ef_construction: int = 500,
is_compact: bool = False,
is_recompute: bool = False,
embedding_model_name: str = "colvision",
) -> None:
self.index_path = index_path
self.dim = dim
self.embedding_model_name = embedding_model_name
self._pending_items: list[dict] = []
self._backend_kwargs = {
"distance_metric": distance_metric,
"M": m,
"efConstruction": ef_construction,
"is_compact": is_compact,
"is_recompute": is_recompute,
}
self._labels_meta: list[dict] = []
def _meta_dict(self) -> dict:
return {
"version": "1.0",
"backend_name": "hnsw",
"embedding_model": self.embedding_model_name,
"embedding_mode": "custom",
"dimensions": self.dim,
"backend_kwargs": self._backend_kwargs,
"is_compact": self._backend_kwargs.get("is_compact", True),
"is_pruned": self._backend_kwargs.get("is_compact", True)
and self._backend_kwargs.get("is_recompute", True),
}
def create_collection(self) -> None:
path = Path(self.index_path)
path.parent.mkdir(parents=True, exist_ok=True)
def insert(self, data: dict) -> None:
self._pending_items.append(
{
"doc_id": int(data["doc_id"]),
"filepath": data.get("filepath", ""),
"colbert_vecs": [np.asarray(v, dtype=np.float32) for v in data["colbert_vecs"]],
}
)
def _labels_path(self) -> Path:
index_path_obj = Path(self.index_path)
return index_path_obj.parent / f"{index_path_obj.name}.labels.json"
def _meta_path(self) -> Path:
index_path_obj = Path(self.index_path)
return index_path_obj.parent / f"{index_path_obj.name}.meta.json"
def create_index(self) -> None:
if not self._pending_items:
return
embeddings: list[np.ndarray] = []
labels_meta: list[dict] = []
for item in self._pending_items:
doc_id = int(item["doc_id"])
filepath = item.get("filepath", "")
colbert_vecs = item["colbert_vecs"]
for seq_id, vec in enumerate(colbert_vecs):
vec_np = np.asarray(vec, dtype=np.float32)
embeddings.append(vec_np)
labels_meta.append(
{
"id": f"{doc_id}:{seq_id}",
"doc_id": doc_id,
"seq_id": int(seq_id),
"filepath": filepath,
}
)
if not embeddings:
return
embeddings_np = np.vstack(embeddings).astype(np.float32)
# print shape of embeddings_np
print(embeddings_np.shape)
builder = HNSWBuilder(**{**self._backend_kwargs, "dimensions": self.dim})
ids = [str(i) for i in range(embeddings_np.shape[0])]
builder.build(embeddings_np, ids, self.index_path)
import json as _json
with open(self._meta_path(), "w", encoding="utf-8") as f:
_json.dump(self._meta_dict(), f, indent=2)
with open(self._labels_path(), "w", encoding="utf-8") as f:
_json.dump(labels_meta, f)
self._labels_meta = labels_meta
def _load_labels_meta_if_needed(self) -> None:
if self._labels_meta:
return
labels_path = self._labels_path()
if labels_path.exists():
import json as _json
with open(labels_path, encoding="utf-8") as f:
self._labels_meta = _json.load(f)
def search(
self, data: np.ndarray, topk: int, first_stage_k: int = 50
) -> list[tuple[float, int]]:
if data.ndim == 1:
data = data.reshape(1, -1)
if data.dtype != np.float32:
data = data.astype(np.float32)
self._load_labels_meta_if_needed()
searcher = HNSWSearcher(self.index_path, meta=self._meta_dict())
raw = searcher.search(
data,
first_stage_k,
recompute_embeddings=False,
complexity=128,
beam_width=1,
prune_ratio=0.0,
batch_size=0,
)
labels = raw.get("labels")
distances = raw.get("distances")
if labels is None or distances is None:
return []
doc_scores: dict[int, float] = {}
B = len(labels)
for b in range(B):
per_doc_best: dict[int, float] = {}
for k, sid in enumerate(labels[b]):
try:
idx = int(sid)
except Exception:
continue
if 0 <= idx < len(self._labels_meta):
doc_id = int(self._labels_meta[idx]["doc_id"]) # type: ignore[index]
else:
continue
score = float(distances[b][k])
if (doc_id not in per_doc_best) or (score > per_doc_best[doc_id]):
per_doc_best[doc_id] = score
for doc_id, best_score in per_doc_best.items():
doc_scores[doc_id] = doc_scores.get(doc_id, 0.0) + best_score
scores = sorted(((v, k) for k, v in doc_scores.items()), key=lambda x: x[0], reverse=True)
return scores[:topk] if len(scores) >= topk else scores

View File

@@ -1,112 +0,0 @@
# pip install pdf2image
# pip install pymilvus
# pip install colpali_engine
# pip install tqdm
# pip install pillow
import os
import re
import sys
from pathlib import Path
from typing import cast
from PIL import Image
from tqdm import tqdm
# Ensure local leann packages are importable before importing them
_repo_root = Path(__file__).resolve().parents[3]
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
if str(_leann_core_src) not in sys.path:
sys.path.append(str(_leann_core_src))
if str(_leann_hnsw_pkg) not in sys.path:
sys.path.append(str(_leann_hnsw_pkg))
import torch
from colpali_engine.models import ColPali
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
from colpali_engine.utils.torch_utils import ListDataset, get_torch_device
from torch.utils.data import DataLoader
# Auto-select device: CUDA > MPS (mac) > CPU
_device_str = (
"cuda"
if torch.cuda.is_available()
else (
"mps"
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
else "cpu"
)
)
device = get_torch_device(_device_str)
# Prefer fp16 on GPU/MPS, bfloat16 on CPU
_dtype = torch.float16 if _device_str in ("cuda", "mps") else torch.bfloat16
model_name = "vidore/colpali-v1.2"
model = ColPali.from_pretrained(
model_name,
torch_dtype=_dtype,
device_map=device,
).eval()
print(f"Using device={_device_str}, dtype={_dtype}")
queries = [
"How to end-to-end retrieval with ColBert",
"Where is ColBERT performance Table, including text representation results?",
]
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
dataloader = DataLoader(
dataset=ListDataset[str](queries),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_queries(x),
)
qs: list[torch.Tensor] = []
for batch_query in dataloader:
with torch.no_grad():
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
embeddings_query = model(**batch_query)
qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
print(qs[0].shape)
# %%
page_filenames = sorted(os.listdir("./pages"), key=lambda n: int(re.search(r"\d+", n).group()))
images = [Image.open(os.path.join("./pages", name)) for name in page_filenames]
dataloader = DataLoader(
dataset=ListDataset[str](images),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_images(x),
)
ds: list[torch.Tensor] = []
for batch_doc in tqdm(dataloader):
with torch.no_grad():
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
embeddings_doc = model(**batch_doc)
ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
print(ds[0].shape)
# %%
# Build HNSW index via LeannRetriever primitives and run search
index_path = "./indexes/colpali.leann"
retriever = LeannRetriever(index_path=index_path, dim=int(ds[0].shape[-1]))
retriever.create_collection()
filepaths = [os.path.join("./pages", name) for name in page_filenames]
for i in range(len(filepaths)):
data = {
"colbert_vecs": ds[i].float().numpy(),
"doc_id": i,
"filepath": filepaths[i],
}
retriever.insert(data)
retriever.create_index()
for query in qs:
query_np = query.float().numpy()
result = retriever.search(query_np, topk=1)
print(filepaths[result[0][1]])

View File

@@ -1,477 +0,0 @@
## Jupyter-style notebook script
# %%
# uv pip install matplotlib qwen_vl_utils
import os
import re
import sys
from pathlib import Path
from typing import Any, Optional, cast
from PIL import Image
from tqdm import tqdm
def _ensure_repo_paths_importable(current_file: str) -> None:
"""Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
_repo_root = Path(current_file).resolve().parents[3]
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
if str(_leann_core_src) not in sys.path:
sys.path.append(str(_leann_core_src))
if str(_leann_hnsw_pkg) not in sys.path:
sys.path.append(str(_leann_hnsw_pkg))
_ensure_repo_paths_importable(__file__)
from leann_multi_vector import LeannMultiVector # noqa: E402
# %%
# Config
os.environ["TOKENIZERS_PARALLELISM"] = "false"
QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"
MODEL: str = "colqwen2" # "colpali" or "colqwen2"
# Data source: set to True to use the Hugging Face dataset example (recommended)
USE_HF_DATASET: bool = True
DATASET_NAME: str = "weaviate/arXiv-AI-papers-multi-vector"
DATASET_SPLIT: str = "train"
MAX_DOCS: Optional[int] = None # limit number of pages to index; None = all
# Local pages (used when USE_HF_DATASET == False)
PDF: Optional[str] = None # e.g., "./pdfs/2004.12832v2.pdf"
PAGES_DIR: str = "./pages"
# Index + retrieval settings
INDEX_PATH: str = "./indexes/colvision.leann"
TOPK: int = 1
FIRST_STAGE_K: int = 500
REBUILD_INDEX: bool = False
# Artifacts
SAVE_TOP_IMAGE: Optional[str] = "./figures/retrieved_page.png"
SIMILARITY_MAP: bool = True
SIM_TOKEN_IDX: int = 13 # -1 means auto-select the most salient token
SIM_OUTPUT: str = "./figures/similarity_map.png"
ANSWER: bool = True
MAX_NEW_TOKENS: int = 128
# %%
# Helpers
def _natural_sort_key(name: str) -> int:
m = re.search(r"\d+", name)
return int(m.group()) if m else 0
def _load_images_from_dir(pages_dir: str) -> tuple[list[str], list[Image.Image]]:
filenames = [n for n in os.listdir(pages_dir) if n.lower().endswith((".png", ".jpg", ".jpeg"))]
filenames = sorted(filenames, key=_natural_sort_key)
filepaths = [os.path.join(pages_dir, n) for n in filenames]
images = [Image.open(p) for p in filepaths]
return filepaths, images
def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None:
if not pdf_path:
return
os.makedirs(pages_dir, exist_ok=True)
try:
from pdf2image import convert_from_path
except Exception as e:
raise RuntimeError(
"pdf2image is required to convert PDF to images. Install via pip install pdf2image"
) from e
images = convert_from_path(pdf_path, dpi=dpi)
for i, image in enumerate(images):
image.save(os.path.join(pages_dir, f"page_{i + 1}.png"), "PNG")
def _select_device_and_dtype():
import torch
from colpali_engine.utils.torch_utils import get_torch_device
device_str = (
"cuda"
if torch.cuda.is_available()
else (
"mps"
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
else "cpu"
)
)
device = get_torch_device(device_str)
# Stable dtype selection to avoid NaNs:
# - CUDA: prefer bfloat16 if supported, else float16
# - MPS: use float32 (fp16 on MPS can produce NaNs in some ops)
# - CPU: float32
if device_str == "cuda":
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
try:
torch.backends.cuda.matmul.allow_tf32 = True # Better stability/perf on Ampere+
except Exception:
pass
elif device_str == "mps":
dtype = torch.float32
else:
dtype = torch.float32
return device_str, device, dtype
def _load_colvision(model_choice: str):
import torch
from colpali_engine.models import ColPali, ColQwen2, ColQwen2Processor
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
from transformers.utils.import_utils import is_flash_attn_2_available
device_str, device, dtype = _select_device_and_dtype()
if model_choice == "colqwen2":
model_name = "vidore/colqwen2-v1.0"
# On CPU/MPS we must avoid flash-attn and stay eager; on CUDA prefer flash-attn if available
attn_implementation = (
"flash_attention_2"
if (device_str == "cuda" and is_flash_attn_2_available())
else "eager"
)
model = ColQwen2.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
attn_implementation=attn_implementation,
).eval()
processor = ColQwen2Processor.from_pretrained(model_name)
else:
model_name = "vidore/colpali-v1.2"
model = ColPali.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
).eval()
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
return model_name, model, processor, device_str, device, dtype
def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
import torch
from colpali_engine.utils.torch_utils import ListDataset
from torch.utils.data import DataLoader
# Ensure deterministic eval and autocast for stability
model.eval()
dataloader = DataLoader(
dataset=ListDataset[Image.Image](images),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_images(x),
)
doc_vecs: list[Any] = []
for batch_doc in tqdm(dataloader, desc="Embedding images"):
with torch.no_grad():
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
# autocast on CUDA for bf16/fp16; on CPU/MPS stay in fp32
if model.device.type == "cuda":
with torch.autocast(
device_type="cuda",
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
):
embeddings_doc = model(**batch_doc)
else:
embeddings_doc = model(**batch_doc)
doc_vecs.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
return doc_vecs
def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
import torch
from colpali_engine.utils.torch_utils import ListDataset
from torch.utils.data import DataLoader
model.eval()
dataloader = DataLoader(
dataset=ListDataset[str](queries),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_queries(x),
)
q_vecs: list[Any] = []
for batch_query in tqdm(dataloader, desc="Embedding queries"):
with torch.no_grad():
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
if model.device.type == "cuda":
with torch.autocast(
device_type="cuda",
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
):
embeddings_query = model(**batch_query)
else:
embeddings_query = model(**batch_query)
q_vecs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
return q_vecs
def _build_index(index_path: str, doc_vecs: list[Any], filepaths: list[str]) -> LeannMultiVector:
dim = int(doc_vecs[0].shape[-1])
retriever = LeannMultiVector(index_path=index_path, dim=dim)
retriever.create_collection()
for i, vec in enumerate(doc_vecs):
data = {
"colbert_vecs": vec.float().numpy(),
"doc_id": i,
"filepath": filepaths[i],
}
retriever.insert(data)
retriever.create_index()
return retriever
def _load_retriever_if_index_exists(index_path: str, dim: int) -> Optional[LeannMultiVector]:
index_base = Path(index_path)
# Rough heuristic: index dir exists AND meta+labels files exist
meta = index_base.parent / f"{index_base.name}.meta.json"
labels = index_base.parent / f"{index_base.name}.labels.json"
if index_base.exists() and meta.exists() and labels.exists():
return LeannMultiVector(index_path=index_path, dim=dim)
return None
def _generate_similarity_map(
model,
processor,
image: Image.Image,
query: str,
token_idx: Optional[int] = None,
output_path: Optional[str] = None,
) -> tuple[int, float]:
import torch
from colpali_engine.interpretability import (
get_similarity_maps_from_embeddings,
plot_similarity_map,
)
batch_images = processor.process_images([image]).to(model.device)
batch_queries = processor.process_queries([query]).to(model.device)
with torch.no_grad():
image_embeddings = model.forward(**batch_images)
query_embeddings = model.forward(**batch_queries)
n_patches = processor.get_n_patches(
image_size=image.size,
spatial_merge_size=getattr(model, "spatial_merge_size", None),
)
image_mask = processor.get_image_mask(batch_images)
batched_similarity_maps = get_similarity_maps_from_embeddings(
image_embeddings=image_embeddings,
query_embeddings=query_embeddings,
n_patches=n_patches,
image_mask=image_mask,
)
similarity_maps = batched_similarity_maps[0]
# Determine token index if not provided: choose the token with highest max score
if token_idx is None:
per_token_max = similarity_maps.view(similarity_maps.shape[0], -1).max(dim=1).values
token_idx = int(per_token_max.argmax().item())
max_sim_score = similarity_maps[token_idx, :, :].max().item()
if output_path:
import matplotlib.pyplot as plt
fig, ax = plot_similarity_map(
image=image,
similarity_map=similarity_maps[token_idx],
figsize=(14, 14),
show_colorbar=False,
)
ax.set_title(f"Token #{token_idx}. MaxSim score: {max_sim_score:.2f}", fontsize=12)
os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.savefig(output_path, bbox_inches="tight")
plt.close(fig)
return token_idx, float(max_sim_score)
class QwenVL:
def __init__(self, device: str):
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from transformers.utils.import_utils import is_flash_attn_2_available
attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "eager"
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen2.5-VL-3B-Instruct",
torch_dtype="auto",
device_map=device,
attn_implementation=attn_implementation,
)
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28
self.processor = AutoProcessor.from_pretrained(
"Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
)
def answer(self, query: str, images: list[Image.Image], max_new_tokens: int = 128) -> str:
import base64
from io import BytesIO
from qwen_vl_utils import process_vision_info
content = []
for img in images:
buffer = BytesIO()
img.save(buffer, format="jpeg")
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
content.append({"type": "image", "image": f"data:image;base64,{img_base64}"})
content.append({"type": "text", "text": query})
messages = [{"role": "user", "content": content}]
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = self.processor(
text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
)
inputs = inputs.to(self.model.device)
generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
return self.processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
# %%
# Step 1: Prepare data
if USE_HF_DATASET:
from datasets import load_dataset
dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
filepaths: list[str] = []
images: list[Image.Image] = []
for i in tqdm(range(N), desc="Loading dataset", total=N ):
p = dataset[i]
# Compose a descriptive identifier for printing later
identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"
print(identifier)
filepaths.append(identifier)
images.append(p["page_image"]) # PIL Image
else:
_maybe_convert_pdf_to_images(PDF, PAGES_DIR)
filepaths, images = _load_images_from_dir(PAGES_DIR)
if not images:
raise RuntimeError(
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
)
# %%
# Step 2: Load model and processor
model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
print(f"Using model={model_name}, device={device_str}, dtype={dtype}")
# %%
# %%
# Step 3: Build or load index
retriever: Optional[LeannMultiVector] = None
if not REBUILD_INDEX:
try:
one_vec = _embed_images(model, processor, [images[0]])[0]
retriever = _load_retriever_if_index_exists(INDEX_PATH, dim=int(one_vec.shape[-1]))
except Exception:
retriever = None
if retriever is None:
doc_vecs = _embed_images(model, processor, images)
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths)
# %%
# Step 4: Embed query and search
q_vec = _embed_queries(model, processor, [QUERY])[0]
results = retriever.search(q_vec.float().numpy(), topk=TOPK, first_stage_k=FIRST_STAGE_K)
if not results:
print("No results found.")
else:
print(f'Top {len(results)} results for query: "{QUERY}"')
top_images: list[Image.Image] = []
for rank, (score, doc_id) in enumerate(results, start=1):
path = filepaths[doc_id]
# For HF dataset, path is a descriptive identifier, not a real file path
print(f"{rank}) MaxSim: {score:.4f}, Page: {path}")
top_images.append(images[doc_id])
if SAVE_TOP_IMAGE:
from pathlib import Path as _Path
base = _Path(SAVE_TOP_IMAGE)
base.parent.mkdir(parents=True, exist_ok=True)
for rank, img in enumerate(top_images[:TOPK], start=1):
if base.suffix:
out_path = base.parent / f"{base.stem}_rank{rank}{base.suffix}"
else:
out_path = base / f"retrieved_page_rank{rank}.png"
img.save(str(out_path))
print(f"Saved retrieved page (rank {rank}) to: {out_path}")
## TODO stange results of second page of DeepSeek-V2 rather than the first page
# %%
# Step 5: Similarity maps for top-K results
if results and SIMILARITY_MAP:
token_idx = None if SIM_TOKEN_IDX < 0 else int(SIM_TOKEN_IDX)
from pathlib import Path as _Path
output_base = _Path(SIM_OUTPUT) if SIM_OUTPUT else None
for rank, img in enumerate(top_images[:TOPK], start=1):
if output_base:
if output_base.suffix:
out_dir = output_base.parent
out_name = f"{output_base.stem}_rank{rank}{output_base.suffix}"
out_path = str(out_dir / out_name)
else:
out_dir = output_base
out_dir.mkdir(parents=True, exist_ok=True)
out_path = str(out_dir / f"similarity_map_rank{rank}.png")
else:
out_path = None
chosen_idx, max_sim = _generate_similarity_map(
model=model,
processor=processor,
image=img,
query=QUERY,
token_idx=token_idx,
output_path=out_path,
)
if out_path:
print(
f"Saved similarity map for rank {rank}, token #{chosen_idx} (max={max_sim:.2f}) to: {out_path}"
)
else:
print(
f"Computed similarity map for rank {rank}, token #{chosen_idx} (max={max_sim:.2f})"
)
# %%
# Step 6: Optional answer generation
if results and ANSWER:
qwen = QwenVL(device=device_str)
response = qwen.answer(QUERY, top_images[:TOPK], max_new_tokens=MAX_NEW_TOKENS)
print("\nAnswer:")
print(response)

View File

@@ -1,183 +0,0 @@
#!/usr/bin/env python3
import re
import sys
from datetime import datetime, timedelta
from pathlib import Path
from leann import LeannSearcher
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
class TimeParser:
def __init__(self):
# Main pattern: captures optional fuzzy modifier, number, unit, and optional "ago"
self.pattern = r"(?:(around|about|roughly|approximately)\s+)?(\d+)\s+(hour|day|week|month|year)s?(?:\s+ago)?"
# Compile for performance
self.regex = re.compile(self.pattern, re.IGNORECASE)
# Stop words to remove before regex parsing
self.stop_words = {
"in",
"at",
"of",
"by",
"as",
"me",
"the",
"a",
"an",
"and",
"any",
"find",
"search",
"list",
"ago",
"back",
"past",
"earlier",
}
def clean_text(self, text):
"""Remove stop words from text"""
words = text.split()
cleaned = " ".join(word for word in words if word.lower() not in self.stop_words)
return cleaned
def parse(self, text):
"""Extract all time expressions from text"""
# Clean text first
cleaned_text = self.clean_text(text)
matches = []
for match in self.regex.finditer(cleaned_text):
fuzzy = match.group(1) # "around", "about", etc.
number = int(match.group(2))
unit = match.group(3).lower()
matches.append(
{
"full_match": match.group(0),
"fuzzy": bool(fuzzy),
"number": number,
"unit": unit,
"range": self.calculate_range(number, unit, bool(fuzzy)),
}
)
return matches
def calculate_range(self, number, unit, is_fuzzy):
"""Convert to actual datetime range and return ISO format strings"""
units = {
"hour": timedelta(hours=number),
"day": timedelta(days=number),
"week": timedelta(weeks=number),
"month": timedelta(days=number * 30),
"year": timedelta(days=number * 365),
}
delta = units[unit]
now = datetime.now()
target = now - delta
if is_fuzzy:
buffer = delta * 0.2 # 20% buffer for fuzzy
start = (target - buffer).isoformat()
end = (target + buffer).isoformat()
else:
start = target.isoformat()
end = now.isoformat()
return (start, end)
def search_files(query, top_k=15):
"""Search the index and return results"""
# Parse time expressions
parser = TimeParser()
time_matches = parser.parse(query)
# Remove time expressions from query for semantic search
clean_query = query
if time_matches:
for match in time_matches:
clean_query = clean_query.replace(match["full_match"], "").strip()
# Check if clean_query is less than 4 characters
if len(clean_query) < 4:
print("Error: add more input for accurate results.")
return
# Single query to vector DB
searcher = LeannSearcher(INDEX_PATH)
results = searcher.search(
clean_query if clean_query else query, top_k=top_k, recompute_embeddings=False
)
# Filter by time if time expression found
if time_matches:
time_range = time_matches[0]["range"] # Use first time expression
start_time, end_time = time_range
filtered_results = []
for result in results:
# Access metadata attribute directly (not .get())
metadata = result.metadata if hasattr(result, "metadata") else {}
if metadata:
# Check modification date first, fall back to creation date
date_str = metadata.get("modification_date") or metadata.get("creation_date")
if date_str:
# Convert strings to datetime objects for proper comparison
try:
file_date = datetime.fromisoformat(date_str)
start_dt = datetime.fromisoformat(start_time)
end_dt = datetime.fromisoformat(end_time)
# Compare dates properly
if start_dt <= file_date <= end_dt:
filtered_results.append(result)
except (ValueError, TypeError):
# Handle invalid date formats
print(f"Warning: Invalid date format in metadata: {date_str}")
continue
results = filtered_results
# Print results
print(f"\nSearch results for: '{query}'")
if time_matches:
print(
f"Time filter: {time_matches[0]['number']} {time_matches[0]['unit']}(s) {'(fuzzy)' if time_matches[0]['fuzzy'] else ''}"
)
print(
f"Date range: {time_matches[0]['range'][0][:10]} to {time_matches[0]['range'][1][:10]}"
)
print("-" * 80)
for i, result in enumerate(results, 1):
print(f"\n[{i}] Score: {result.score:.4f}")
print(f"Content: {result.text}")
# Show metadata if present
metadata = result.metadata if hasattr(result, "metadata") else None
if metadata:
if "creation_date" in metadata:
print(f"Created: {metadata['creation_date']}")
if "modification_date" in metadata:
print(f"Modified: {metadata['modification_date']}")
print("-" * 80)
if __name__ == "__main__":
if len(sys.argv) < 2:
print('Usage: python search_index.py "<search query>" [top_k]')
sys.exit(1)
query = sys.argv[1]
top_k = int(sys.argv[2]) if len(sys.argv) > 2 else 15
search_files(query, top_k)

View File

@@ -1,82 +0,0 @@
#!/usr/bin/env python3
import json
import sys
from pathlib import Path
from leann import LeannBuilder
def process_json_items(json_file_path):
"""Load and process JSON file with metadata items"""
with open(json_file_path, encoding="utf-8") as f:
items = json.load(f)
# Guard against empty JSON
if not items:
print("⚠️ No items found in the JSON file. Exiting gracefully.")
return
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
builder = LeannBuilder(backend_name="hnsw", is_recompute=False)
total_items = len(items)
items_added = 0
print(f"Processing {total_items} items...")
for idx, item in enumerate(items):
try:
# Create embedding text sentence
embedding_text = f"{item.get('Name', 'unknown')} located at {item.get('Path', 'unknown')} and size {item.get('Size', 'unknown')} bytes with content type {item.get('ContentType', 'unknown')} and kind {item.get('Kind', 'unknown')}"
# Prepare metadata with dates
metadata = {}
if "CreationDate" in item:
metadata["creation_date"] = item["CreationDate"]
if "ContentChangeDate" in item:
metadata["modification_date"] = item["ContentChangeDate"]
# Add to builder
builder.add_text(embedding_text, metadata=metadata)
items_added += 1
except Exception as e:
print(f"\n⚠️ Warning: Failed to process item {idx}: {e}")
continue
# Show progress
progress = (idx + 1) / total_items * 100
sys.stdout.write(f"\rProgress: {idx + 1}/{total_items} ({progress:.1f}%)")
sys.stdout.flush()
print() # New line after progress
# Guard against no successfully added items
if items_added == 0:
print("⚠️ No items were successfully added to the index. Exiting gracefully.")
return
print(f"\n✅ Successfully processed {items_added}/{total_items} items")
print("Building index...")
try:
builder.build_index(INDEX_PATH)
print(f"✓ Index saved to {INDEX_PATH}")
except ValueError as e:
if "No chunks added" in str(e):
print("⚠️ No chunks were added to the builder. Index not created.")
else:
raise
if __name__ == "__main__":
if len(sys.argv) != 2:
print("Usage: python build_index.py <json_file>")
sys.exit(1)
json_file = sys.argv[1]
if not Path(json_file).exists():
print(f"Error: File {json_file} not found")
sys.exit(1)
process_json_items(json_file)

View File

@@ -1,265 +0,0 @@
#!/usr/bin/env python3
"""
Spotlight Metadata Dumper for Vector DB
Extracts only essential metadata for semantic search embeddings
Output is optimized for vector database storage with minimal fields
"""
import json
import sys
from datetime import datetime
# Check platform before importing macOS-specific modules
if sys.platform != "darwin":
print("This script requires macOS (uses Spotlight)")
sys.exit(1)
from Foundation import NSDate, NSMetadataQuery, NSPredicate, NSRunLoop
# EDIT THIS LIST: Add or remove folders to search
# Can be either:
# - Folder names relative to home directory (e.g., "Desktop", "Downloads")
# - Absolute paths (e.g., "/Applications", "/System/Library")
SEARCH_FOLDERS = [
"Desktop",
"Downloads",
"Documents",
"Music",
"Pictures",
"Movies",
# "Library", # Uncomment to include
# "/Applications", # Absolute path example
# "Code/Projects", # Subfolder example
# Add any other folders here
]
def convert_to_serializable(obj):
"""Convert NS objects to Python serializable types"""
if obj is None:
return None
# Handle NSDate
if hasattr(obj, "timeIntervalSince1970"):
return datetime.fromtimestamp(obj.timeIntervalSince1970()).isoformat()
# Handle NSArray
if hasattr(obj, "count") and hasattr(obj, "objectAtIndex_"):
return [convert_to_serializable(obj.objectAtIndex_(i)) for i in range(obj.count())]
# Convert to string
try:
return str(obj)
except Exception:
return repr(obj)
def dump_spotlight_data(max_items=10, output_file="spotlight_dump.json"):
"""
Dump Spotlight data using public.item predicate
"""
# Build full paths from SEARCH_FOLDERS
import os
home_dir = os.path.expanduser("~")
search_paths = []
print("Search locations:")
for folder in SEARCH_FOLDERS:
# Check if it's an absolute path or relative
if folder.startswith("/"):
full_path = folder
else:
full_path = os.path.join(home_dir, folder)
if os.path.exists(full_path):
search_paths.append(full_path)
print(f"{full_path}")
else:
print(f"{full_path} (not found)")
if not search_paths:
print("No valid search paths found!")
return []
print(f"\nDumping {max_items} items from Spotlight (public.item)...")
# Create query with public.item predicate
query = NSMetadataQuery.alloc().init()
predicate = NSPredicate.predicateWithFormat_("kMDItemContentTypeTree CONTAINS 'public.item'")
query.setPredicate_(predicate)
# Set search scopes to our specific folders
query.setSearchScopes_(search_paths)
print("Starting query...")
query.startQuery()
# Wait for gathering to complete
run_loop = NSRunLoop.currentRunLoop()
print("Gathering results...")
# Let it gather for a few seconds
for i in range(50): # 5 seconds max
run_loop.runMode_beforeDate_(
"NSDefaultRunLoopMode", NSDate.dateWithTimeIntervalSinceNow_(0.1)
)
# Check gathering status periodically
if i % 10 == 0:
current_count = query.resultCount()
if current_count > 0:
print(f" Found {current_count} items so far...")
# Continue while still gathering (up to 2 more seconds)
timeout = NSDate.dateWithTimeIntervalSinceNow_(2.0)
while query.isGathering() and timeout.timeIntervalSinceNow() > 0:
run_loop.runMode_beforeDate_(
"NSDefaultRunLoopMode", NSDate.dateWithTimeIntervalSinceNow_(0.1)
)
query.stopQuery()
total_results = query.resultCount()
print(f"Found {total_results} total items")
if total_results == 0:
print("No results found")
return []
# Process items
items_to_process = min(total_results, max_items)
results = []
# ONLY relevant attributes for vector embeddings
# These provide essential context for semantic search without bloat
attributes = [
"kMDItemPath", # Full path for file retrieval
"kMDItemFSName", # Filename for display & embedding
"kMDItemFSSize", # Size for filtering/ranking
"kMDItemContentType", # File type for categorization
"kMDItemKind", # Human-readable type for embedding
"kMDItemFSCreationDate", # Temporal context
"kMDItemFSContentChangeDate", # Recency for ranking
]
print(f"Processing {items_to_process} items...")
for i in range(items_to_process):
try:
item = query.resultAtIndex_(i)
metadata = {}
# Extract ONLY the relevant attributes
for attr in attributes:
try:
value = item.valueForAttribute_(attr)
if value is not None:
# Keep the attribute name clean (remove kMDItem prefix for cleaner JSON)
clean_key = attr.replace("kMDItem", "").replace("FS", "")
metadata[clean_key] = convert_to_serializable(value)
except (AttributeError, ValueError, TypeError):
continue
# Only add if we have at least a path
if metadata.get("Path"):
results.append(metadata)
except Exception as e:
print(f"Error processing item {i}: {e}")
continue
# Save to JSON
with open(output_file, "w", encoding="utf-8") as f:
json.dump(results, f, indent=2, ensure_ascii=False)
print(f"\n✓ Saved {len(results)} items to {output_file}")
# Show summary
print("\nSample items:")
import os
home_dir = os.path.expanduser("~")
for i, item in enumerate(results[:3]):
print(f"\n[Item {i + 1}]")
print(f" Path: {item.get('Path', 'N/A')}")
print(f" Name: {item.get('Name', 'N/A')}")
print(f" Type: {item.get('ContentType', 'N/A')}")
print(f" Kind: {item.get('Kind', 'N/A')}")
# Handle size properly
size = item.get("Size")
if size:
try:
size_int = int(size)
if size_int > 1024 * 1024:
print(f" Size: {size_int / (1024 * 1024):.2f} MB")
elif size_int > 1024:
print(f" Size: {size_int / 1024:.2f} KB")
else:
print(f" Size: {size_int} bytes")
except (ValueError, TypeError):
print(f" Size: {size}")
# Show dates
if "CreationDate" in item:
print(f" Created: {item['CreationDate']}")
if "ContentChangeDate" in item:
print(f" Modified: {item['ContentChangeDate']}")
# Count by type
type_counts = {}
for item in results:
content_type = item.get("ContentType", "unknown")
type_counts[content_type] = type_counts.get(content_type, 0) + 1
print(f"\nTotal items saved: {len(results)}")
if type_counts:
print("\nTop content types:")
for ct, count in sorted(type_counts.items(), key=lambda x: x[1], reverse=True)[:5]:
print(f" {ct}: {count} items")
# Count by folder
folder_counts = {}
for item in results:
path = item.get("Path", "")
for folder in SEARCH_FOLDERS:
# Build the full folder path
if folder.startswith("/"):
folder_path = folder
else:
folder_path = os.path.join(home_dir, folder)
if path.startswith(folder_path):
folder_counts[folder] = folder_counts.get(folder, 0) + 1
break
if folder_counts:
print("\nItems by location:")
for folder, count in sorted(folder_counts.items(), key=lambda x: x[1], reverse=True):
print(f" {folder}: {count} items")
return results
def main():
# Parse arguments
if len(sys.argv) > 1:
try:
max_items = int(sys.argv[1])
except ValueError:
print("Usage: python spot.py [number_of_items]")
print("Default: 10 items")
sys.exit(1)
else:
max_items = 10
output_file = sys.argv[2] if len(sys.argv) > 2 else "spotlight_dump.json"
# Run dump
dump_spotlight_data(max_items=max_items, output_file=output_file)
if __name__ == "__main__":
main()

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 152 KiB

View File

@@ -1,24 +1,9 @@
# 🧪 LEANN Benchmarks & Testing
# 🧪 Leann Sanity Checks
This directory contains performance benchmarks and comprehensive tests for the LEANN system, including backend comparisons and sanity checks across different configurations.
This directory contains comprehensive sanity checks for the Leann system, ensuring all components work correctly across different configurations.
## 📁 Test Files
### `diskann_vs_hnsw_speed_comparison.py`
Performance comparison between DiskANN and HNSW backends:
-**Search latency** comparison with both backends using recompute
-**Index size** and **build time** measurements
-**Score validity** testing (ensures no -inf scores)
-**Configurable dataset sizes** for different scales
```bash
# Quick comparison with 500 docs, 10 queries
python benchmarks/diskann_vs_hnsw_speed_comparison.py
# Large-scale comparison with 2000 docs, 20 queries
python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20
```
### `test_distance_functions.py`
Tests all supported distance functions across DiskANN backend:
-**MIPS** (Maximum Inner Product Search)

View File

View File

@@ -1,148 +0,0 @@
import argparse
import os
import time
from pathlib import Path
from leann import LeannBuilder, LeannSearcher
def _meta_exists(index_path: str) -> bool:
p = Path(index_path)
return (p.parent / f"{p.stem}.meta.json").exists()
def ensure_index(index_path: str, backend_name: str, num_docs: int, is_recompute: bool) -> None:
# if _meta_exists(index_path):
# return
kwargs = {}
if backend_name == "hnsw":
kwargs["is_compact"] = is_recompute
builder = LeannBuilder(
backend_name=backend_name,
embedding_model=os.getenv("LEANN_EMBED_MODEL", "facebook/contriever"),
embedding_mode=os.getenv("LEANN_EMBED_MODE", "sentence-transformers"),
graph_degree=32,
complexity=64,
is_recompute=is_recompute,
num_threads=4,
**kwargs,
)
for i in range(num_docs):
builder.add_text(
f"This is a test document number {i}. It contains some repeated text for benchmarking."
)
builder.build_index(index_path)
def _bench_group(
index_path: str,
recompute: bool,
query: str,
repeats: int,
complexity: int = 32,
top_k: int = 10,
) -> float:
# Independent searcher per group; fixed port when recompute
searcher = LeannSearcher(index_path=index_path)
# Warm-up once
_ = searcher.search(
query,
top_k=top_k,
complexity=complexity,
recompute_embeddings=recompute,
)
def _once() -> float:
t0 = time.time()
_ = searcher.search(
query,
top_k=top_k,
complexity=complexity,
recompute_embeddings=recompute,
)
return time.time() - t0
if repeats <= 1:
t = _once()
else:
vals = [_once() for _ in range(repeats)]
vals.sort()
t = vals[len(vals) // 2]
searcher.cleanup()
return t
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--num-docs", type=int, default=5000)
parser.add_argument("--repeats", type=int, default=3)
parser.add_argument("--complexity", type=int, default=32)
args = parser.parse_args()
base = Path.cwd() / ".leann" / "indexes" / f"bench_n{args.num_docs}"
base.parent.mkdir(parents=True, exist_ok=True)
# ---------- Build HNSW variants ----------
hnsw_r = str(base / f"hnsw_recompute_n{args.num_docs}.leann")
hnsw_nr = str(base / f"hnsw_norecompute_n{args.num_docs}.leann")
ensure_index(hnsw_r, "hnsw", args.num_docs, True)
ensure_index(hnsw_nr, "hnsw", args.num_docs, False)
# ---------- Build DiskANN variants ----------
diskann_r = str(base / "diskann_r.leann")
diskann_nr = str(base / "diskann_nr.leann")
ensure_index(diskann_r, "diskann", args.num_docs, True)
ensure_index(diskann_nr, "diskann", args.num_docs, False)
# ---------- Helpers ----------
def _size_for(prefix: str) -> int:
p = Path(prefix)
base_dir = p.parent
stem = p.stem
total = 0
for f in base_dir.iterdir():
if f.is_file() and f.name.startswith(stem):
total += f.stat().st_size
return total
# ---------- HNSW benchmark ----------
t_hnsw_r = _bench_group(
hnsw_r, True, "test document number 42", repeats=args.repeats, complexity=args.complexity
)
t_hnsw_nr = _bench_group(
hnsw_nr, False, "test document number 42", repeats=args.repeats, complexity=args.complexity
)
size_hnsw_r = _size_for(hnsw_r)
size_hnsw_nr = _size_for(hnsw_nr)
print("Benchmark results (HNSW):")
print(f" recompute=True: search_time={t_hnsw_r:.3f}s, size={size_hnsw_r / 1024 / 1024:.1f}MB")
print(
f" recompute=False: search_time={t_hnsw_nr:.3f}s, size={size_hnsw_nr / 1024 / 1024:.1f}MB"
)
print(" Expectation: no-recompute should be faster but larger on disk.")
# ---------- DiskANN benchmark ----------
t_diskann_r = _bench_group(
diskann_r, True, "DiskANN R test doc 123", repeats=args.repeats, complexity=args.complexity
)
t_diskann_nr = _bench_group(
diskann_nr,
False,
"DiskANN NR test doc 123",
repeats=args.repeats,
complexity=args.complexity,
)
size_diskann_r = _size_for(diskann_r)
size_diskann_nr = _size_for(diskann_nr)
print("\nBenchmark results (DiskANN):")
print(f" build(recompute=True, partition): size={size_diskann_r / 1024 / 1024:.1f}MB")
print(f" build(recompute=False): size={size_diskann_nr / 1024 / 1024:.1f}MB")
print(f" search recompute=True (final rerank): {t_diskann_r:.3f}s")
print(f" search recompute=False (PQ only): {t_diskann_nr:.3f}s")
if __name__ == "__main__":
main()

View File

@@ -1,23 +0,0 @@
BM25 vs DiskANN Baselines
```bash
aws s3 sync s3://powerrag-diskann-rpj-wiki-20250824-224037-194d640c/bm25_rpj_wiki/index_en_only/ benchmarks/data/indices/bm25_index/
aws s3 sync s3://powerrag-diskann-rpj-wiki-20250824-224037-194d640c/diskann_rpj_wiki/ benchmarks/data/indices/diskann_rpj_wiki/
```
- Dataset: `benchmarks/data/queries/nq_open.jsonl` (Natural Questions)
- Machine-specific; results measured locally with the current repo.
DiskANN (NQ queries, search-only)
- Command: `uv run --script benchmarks/bm25_diskann_baselines/run_diskann.py`
- Settings: `recompute_embeddings=False`, embeddings precomputed (excluded from timing), batching off, caching off (`cache_mechanism=2`, `num_nodes_to_cache=0`)
- Result: avg 0.011093 s/query, QPS 90.15 (p50 0.010731 s, p95 0.015000 s)
BM25
- Command: `uv run --script benchmarks/bm25_diskann_baselines/run_bm25.py`
- Settings: `k=10`, `k1=0.9`, `b=0.4`, queries=100
- Result: avg 0.028589 s/query, QPS 34.97 (p50 0.026060 s, p90 0.043695 s, p95 0.053260 s, p99 0.055257 s)
Notes
- DiskANN measures search-only latency on real NQ queries (embeddings computed beforehand and excluded from timing).
- Use `benchmarks/bm25_diskann_baselines/run_diskann.py` for DiskANN; `benchmarks/bm25_diskann_baselines/run_bm25.py` for BM25.

Before

Width:  |  Height:  |  Size: 1.3 KiB

View File

@@ -1,183 +0,0 @@
# /// script
# dependencies = [
# "pyserini"
# ]
# ///
# sudo pacman -S jdk21-openjdk
# export JAVA_HOME=/usr/lib/jvm/java-21-openjdk
# sudo archlinux-java status
# sudo archlinux-java set java-21-openjdk
# set -Ux JAVA_HOME /usr/lib/jvm/java-21-openjdk
# fish_add_path --global $JAVA_HOME/bin
# set -Ux LD_LIBRARY_PATH $JAVA_HOME/lib/server $LD_LIBRARY_PATH
# which javac # Should be /usr/lib/jvm/java-21-openjdk/bin/javac
import argparse
import json
import os
import sys
import time
from statistics import mean
def load_queries(path: str, limit: int | None) -> list[str]:
queries: list[str] = []
# Try JSONL with a 'query' or 'text' field; fallback to plain text (one query per line)
_, ext = os.path.splitext(path)
if ext.lower() in {".jsonl", ".json"}:
with open(path, encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
try:
obj = json.loads(line)
except json.JSONDecodeError:
# Not strict JSONL? treat the whole line as the query
queries.append(line)
continue
q = obj.get("query") or obj.get("text") or obj.get("question")
if q:
queries.append(str(q))
else:
with open(path, encoding="utf-8") as f:
for line in f:
s = line.strip()
if s:
queries.append(s)
if limit is not None and limit > 0:
queries = queries[:limit]
return queries
def percentile(values: list[float], p: float) -> float:
if not values:
return 0.0
s = sorted(values)
k = (len(s) - 1) * (p / 100.0)
f = int(k)
c = min(f + 1, len(s) - 1)
if f == c:
return s[f]
return s[f] + (s[c] - s[f]) * (k - f)
def main():
ap = argparse.ArgumentParser(description="Standalone BM25 latency benchmark (Pyserini)")
ap.add_argument(
"--bm25-index",
default="benchmarks/data/indices/bm25_index",
help="Path to Pyserini Lucene index directory",
)
ap.add_argument(
"--queries",
default="benchmarks/data/queries/nq_open.jsonl",
help="Path to queries file (JSONL with 'query'/'text' or plain txt one-per-line)",
)
ap.add_argument("--k", type=int, default=10, help="Top-k to retrieve (default: 10)")
ap.add_argument("--k1", type=float, default=0.9, help="BM25 k1 (default: 0.9)")
ap.add_argument("--b", type=float, default=0.4, help="BM25 b (default: 0.4)")
ap.add_argument("--limit", type=int, default=100, help="Max queries to run (default: 100)")
ap.add_argument(
"--warmup", type=int, default=5, help="Warmup queries not counted in latency (default: 5)"
)
ap.add_argument(
"--fetch-docs", action="store_true", help="Also fetch doc contents (slower; default: off)"
)
ap.add_argument("--report", type=str, default=None, help="Optional JSON report path")
args = ap.parse_args()
try:
from pyserini.search.lucene import LuceneSearcher
except Exception:
print("Pyserini not found. Install with: pip install pyserini", file=sys.stderr)
raise
if not os.path.isdir(args.bm25_index):
print(f"Index directory not found: {args.bm25_index}", file=sys.stderr)
sys.exit(1)
queries = load_queries(args.queries, args.limit)
if not queries:
print("No queries loaded.", file=sys.stderr)
sys.exit(1)
print(f"Loaded {len(queries)} queries from {args.queries}")
print(f"Opening BM25 index: {args.bm25_index}")
searcher = LuceneSearcher(args.bm25_index)
# Some builds of pyserini require explicit set_bm25; others ignore
try:
searcher.set_bm25(k1=args.k1, b=args.b)
except Exception:
pass
latencies: list[float] = []
total_searches = 0
# Warmup
for i in range(min(args.warmup, len(queries))):
_ = searcher.search(queries[i], k=args.k)
t0 = time.time()
for i, q in enumerate(queries):
t1 = time.time()
hits = searcher.search(q, k=args.k)
t2 = time.time()
latencies.append(t2 - t1)
total_searches += 1
if args.fetch_docs:
# Optional doc fetch to include I/O time
for h in hits:
try:
_ = searcher.doc(h.docid)
except Exception:
pass
if (i + 1) % 50 == 0:
print(f"Processed {i + 1}/{len(queries)} queries")
t1 = time.time()
total_time = t1 - t0
if latencies:
avg = mean(latencies)
p50 = percentile(latencies, 50)
p90 = percentile(latencies, 90)
p95 = percentile(latencies, 95)
p99 = percentile(latencies, 99)
qps = total_searches / total_time if total_time > 0 else 0.0
else:
avg = p50 = p90 = p95 = p99 = qps = 0.0
print("BM25 Latency Report")
print(f" queries: {total_searches}")
print(f" k: {args.k}, k1: {args.k1}, b: {args.b}")
print(f" avg per query: {avg:.6f} s")
print(f" p50/p90/p95/p99: {p50:.6f}/{p90:.6f}/{p95:.6f}/{p99:.6f} s")
print(f" total time: {total_time:.3f} s, qps: {qps:.2f}")
if args.report:
payload = {
"queries": total_searches,
"k": args.k,
"k1": args.k1,
"b": args.b,
"avg_s": avg,
"p50_s": p50,
"p90_s": p90,
"p95_s": p95,
"p99_s": p99,
"total_time_s": total_time,
"qps": qps,
"index_dir": os.path.abspath(args.bm25_index),
"fetch_docs": bool(args.fetch_docs),
}
with open(args.report, "w", encoding="utf-8") as f:
json.dump(payload, f, indent=2)
print(f"Saved report to {args.report}")
if __name__ == "__main__":
main()

View File

@@ -1,124 +0,0 @@
# /// script
# dependencies = [
# "leann-backend-diskann"
# ]
# ///
import argparse
import json
import time
from pathlib import Path
import numpy as np
def load_queries(path: Path, limit: int | None) -> list[str]:
out: list[str] = []
with open(path, encoding="utf-8") as f:
for line in f:
obj = json.loads(line)
out.append(obj["query"])
if limit and len(out) >= limit:
break
return out
def main() -> None:
ap = argparse.ArgumentParser(
description="DiskANN baseline on real NQ queries (search-only timing)"
)
ap.add_argument(
"--index-dir",
default="benchmarks/data/indices/diskann_rpj_wiki",
help="Directory containing DiskANN files",
)
ap.add_argument("--index-prefix", default="ann")
ap.add_argument("--queries-file", default="benchmarks/data/queries/nq_open.jsonl")
ap.add_argument("--num-queries", type=int, default=200)
ap.add_argument("--top-k", type=int, default=10)
ap.add_argument("--complexity", type=int, default=62)
ap.add_argument("--threads", type=int, default=1)
ap.add_argument("--beam-width", type=int, default=1)
ap.add_argument("--cache-mechanism", type=int, default=2)
ap.add_argument("--num-nodes-to-cache", type=int, default=0)
args = ap.parse_args()
index_dir = Path(args.index_dir).resolve()
if not index_dir.is_dir():
raise SystemExit(f"Index dir not found: {index_dir}")
qpath = Path(args.queries_file).resolve()
if not qpath.exists():
raise SystemExit(f"Queries file not found: {qpath}")
queries = load_queries(qpath, args.num_queries)
print(f"Loaded {len(queries)} queries from {qpath}")
# Compute embeddings once (exclude from timing)
from leann.api import compute_embeddings as _compute
embs = _compute(
queries,
model_name="facebook/contriever-msmarco",
mode="sentence-transformers",
use_server=False,
).astype(np.float32)
if embs.ndim != 2:
raise SystemExit("Embedding compute failed or returned wrong shape")
# Build searcher
from leann_backend_diskann.diskann_backend import DiskannSearcher as _DiskannSearcher
index_prefix_path = str(index_dir / args.index_prefix)
searcher = _DiskannSearcher(
index_prefix_path,
num_threads=int(args.threads),
cache_mechanism=int(args.cache_mechanism),
num_nodes_to_cache=int(args.num_nodes_to_cache),
)
# Warmup (not timed)
_ = searcher.search(
embs[0:1],
top_k=args.top_k,
complexity=args.complexity,
beam_width=args.beam_width,
prune_ratio=0.0,
recompute_embeddings=False,
batch_recompute=False,
dedup_node_dis=False,
)
# Timed loop
times: list[float] = []
for i in range(embs.shape[0]):
t0 = time.time()
_ = searcher.search(
embs[i : i + 1],
top_k=args.top_k,
complexity=args.complexity,
beam_width=args.beam_width,
prune_ratio=0.0,
recompute_embeddings=False,
batch_recompute=False,
dedup_node_dis=False,
)
times.append(time.time() - t0)
times_sorted = sorted(times)
avg = float(sum(times) / len(times))
p50 = times_sorted[len(times) // 2]
p95 = times_sorted[max(0, int(len(times) * 0.95) - 1)]
print("\nDiskANN (NQ, search-only) Report")
print(f" queries: {len(times)}")
print(
f" k: {args.top_k}, complexity: {args.complexity}, beam_width: {args.beam_width}, threads: {args.threads}"
)
print(f" avg per query: {avg:.6f} s")
print(f" p50/p95: {p50:.6f}/{p95:.6f} s")
print(f" QPS: {1.0 / avg:.2f}")
if __name__ == "__main__":
main()

82
benchmarks/data/.gitattributes vendored Normal file
View File

@@ -0,0 +1,82 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.lz4 filter=lfs diff=lfs merge=lfs -text
*.mds filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
# Audio files - uncompressed
*.pcm filter=lfs diff=lfs merge=lfs -text
*.sam filter=lfs diff=lfs merge=lfs -text
*.raw filter=lfs diff=lfs merge=lfs -text
# Audio files - compressed
*.aac filter=lfs diff=lfs merge=lfs -text
*.flac filter=lfs diff=lfs merge=lfs -text
*.mp3 filter=lfs diff=lfs merge=lfs -text
*.ogg filter=lfs diff=lfs merge=lfs -text
*.wav filter=lfs diff=lfs merge=lfs -text
# Image files - uncompressed
*.bmp filter=lfs diff=lfs merge=lfs -text
*.gif filter=lfs diff=lfs merge=lfs -text
*.png filter=lfs diff=lfs merge=lfs -text
*.tiff filter=lfs diff=lfs merge=lfs -text
# Image files - compressed
*.jpg filter=lfs diff=lfs merge=lfs -text
*.jpeg filter=lfs diff=lfs merge=lfs -text
*.webp filter=lfs diff=lfs merge=lfs -text
# Video files - compressed
*.mp4 filter=lfs diff=lfs merge=lfs -text
*.webm filter=lfs diff=lfs merge=lfs -text
ground_truth/dpr/id_map.json filter=lfs diff=lfs merge=lfs -text
indices/dpr/dpr_diskann.passages.idx filter=lfs diff=lfs merge=lfs -text
indices/dpr/dpr_diskann.passages.jsonl filter=lfs diff=lfs merge=lfs -text
indices/dpr/dpr_diskann_disk.index filter=lfs diff=lfs merge=lfs -text
indices/dpr/leann.labels.map filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/leann.labels.map filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.index filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.0.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.0.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.1.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.1.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.2.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.2.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.3.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.3.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.4.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.4.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.5.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.5.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.6.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.6.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.7.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.7.jsonl filter=lfs diff=lfs merge=lfs -text

View File

@@ -1,44 +0,0 @@
---
license: mit
---
# LEANN-RAG Evaluation Data
This repository contains the necessary data to run the recall evaluation scripts for the [LEANN-RAG](https://huggingface.co/LEANN-RAG) project.
## Dataset Components
This dataset is structured into three main parts:
1. **Pre-built LEANN Indices**:
* `dpr/`: A pre-built index for the DPR dataset.
* `rpj_wiki/`: A pre-built index for the RPJ-Wiki dataset.
These indices were created using the `leann-core` library and are required by the `LeannSearcher`.
2. **Ground Truth Data**:
* `ground_truth/`: Contains the ground truth files (`flat_results_nq_k3.json`) for both the DPR and RPJ-Wiki datasets. These files map queries to the original passage IDs from the Natural Questions benchmark, evaluated using the Contriever model.
3. **Queries**:
* `queries/`: Contains the `nq_open.jsonl` file with the Natural Questions queries used for the evaluation.
## Usage
To use this data, you can download it locally using the `huggingface-hub` library. First, install the library:
```bash
pip install huggingface-hub
```
Then, you can download the entire dataset to a local directory (e.g., `data/`) with the following Python script:
```python
from huggingface_hub import snapshot_download
snapshot_download(
repo_id="LEANN-RAG/leann-rag-evaluation-data",
repo_type="dataset",
local_dir="data"
)
```
This will download all the necessary files into a local `data` folder, preserving the repository structure. The evaluation scripts in the main [LEANN-RAG Space](https://huggingface.co/LEANN-RAG) are configured to work with this data structure.

View File

@@ -1,286 +0,0 @@
#!/usr/bin/env python3
"""
DiskANN vs HNSW Search Performance Comparison
This benchmark compares search performance between DiskANN and HNSW backends:
- DiskANN: With graph partitioning enabled (is_recompute=True)
- HNSW: With recompute enabled (is_recompute=True)
- Tests performance across different dataset sizes
- Measures search latency, recall, and index size
"""
import gc
import multiprocessing as mp
import tempfile
import time
from pathlib import Path
from typing import Any
import numpy as np
# Prefer 'fork' start method to avoid POSIX semaphore leaks on macOS
try:
mp.set_start_method("fork", force=True)
except Exception:
pass
def create_test_texts(n_docs: int) -> list[str]:
"""Create synthetic test documents for benchmarking."""
np.random.seed(42)
topics = [
"machine learning and artificial intelligence",
"natural language processing and text analysis",
"computer vision and image recognition",
"data science and statistical analysis",
"deep learning and neural networks",
"information retrieval and search engines",
"database systems and data management",
"software engineering and programming",
"cybersecurity and network protection",
"cloud computing and distributed systems",
]
texts = []
for i in range(n_docs):
topic = topics[i % len(topics)]
variation = np.random.randint(1, 100)
text = (
f"This is document {i} about {topic}. Content variation {variation}. "
f"Additional information about {topic} with details and examples. "
f"Technical discussion of {topic} including implementation aspects."
)
texts.append(text)
return texts
def benchmark_backend(
backend_name: str, texts: list[str], test_queries: list[str], backend_kwargs: dict[str, Any]
) -> dict[str, float]:
"""Benchmark a specific backend with the given configuration."""
from leann.api import LeannBuilder, LeannSearcher
print(f"\n🔧 Testing {backend_name.upper()} backend...")
with tempfile.TemporaryDirectory() as temp_dir:
index_path = str(Path(temp_dir) / f"benchmark_{backend_name}.leann")
# Build index
print(f"📦 Building {backend_name} index with {len(texts)} documents...")
start_time = time.time()
builder = LeannBuilder(
backend_name=backend_name,
embedding_model="facebook/contriever",
embedding_mode="sentence-transformers",
**backend_kwargs,
)
for text in texts:
builder.add_text(text)
builder.build_index(index_path)
build_time = time.time() - start_time
# Measure index size
index_dir = Path(index_path).parent
index_files = list(index_dir.glob(f"{Path(index_path).stem}.*"))
total_size = sum(f.stat().st_size for f in index_files if f.is_file())
size_mb = total_size / (1024 * 1024)
print(f" ✅ Build completed in {build_time:.2f}s, index size: {size_mb:.1f}MB")
# Search benchmark
print("🔍 Running search benchmark...")
searcher = LeannSearcher(index_path)
search_times = []
all_results = []
for query in test_queries:
start_time = time.time()
results = searcher.search(query, top_k=5)
search_time = time.time() - start_time
search_times.append(search_time)
all_results.append(results)
avg_search_time = np.mean(search_times) * 1000 # Convert to ms
print(f" ✅ Average search time: {avg_search_time:.1f}ms")
# Check for valid scores (detect -inf issues)
all_scores = [
result.score
for results in all_results
for result in results
if result.score is not None
]
valid_scores = [
score for score in all_scores if score != float("-inf") and score != float("inf")
]
score_validity_rate = len(valid_scores) / len(all_scores) if all_scores else 0
# Clean up (ensure embedding server shutdown and object GC)
try:
if hasattr(searcher, "cleanup"):
searcher.cleanup()
del searcher
del builder
gc.collect()
except Exception as e:
print(f"⚠️ Warning: Resource cleanup error: {e}")
return {
"build_time": build_time,
"avg_search_time_ms": avg_search_time,
"index_size_mb": size_mb,
"score_validity_rate": score_validity_rate,
}
def run_comparison(n_docs: int = 500, n_queries: int = 10):
"""Run performance comparison between DiskANN and HNSW."""
print("🚀 Starting DiskANN vs HNSW Performance Comparison")
print(f"📊 Dataset: {n_docs} documents, {n_queries} test queries")
# Create test data
texts = create_test_texts(n_docs)
test_queries = [
"machine learning algorithms",
"natural language processing",
"computer vision techniques",
"data analysis methods",
"neural network architectures",
"database query optimization",
"software development practices",
"security vulnerabilities",
"cloud infrastructure",
"distributed computing",
][:n_queries]
# HNSW benchmark
hnsw_results = benchmark_backend(
backend_name="hnsw",
texts=texts,
test_queries=test_queries,
backend_kwargs={
"is_recompute": True, # Enable recompute for fair comparison
"M": 16,
"efConstruction": 200,
},
)
# DiskANN benchmark
diskann_results = benchmark_backend(
backend_name="diskann",
texts=texts,
test_queries=test_queries,
backend_kwargs={
"is_recompute": True, # Enable graph partitioning
"num_neighbors": 32,
"search_list_size": 50,
},
)
# Performance comparison
print("\n📈 Performance Comparison Results")
print(f"{'=' * 60}")
print(f"{'Metric':<25} {'HNSW':<15} {'DiskANN':<15} {'Speedup':<10}")
print(f"{'-' * 60}")
# Build time comparison
build_speedup = hnsw_results["build_time"] / diskann_results["build_time"]
print(
f"{'Build Time (s)':<25} {hnsw_results['build_time']:<15.2f} {diskann_results['build_time']:<15.2f} {build_speedup:<10.2f}x"
)
# Search time comparison
search_speedup = hnsw_results["avg_search_time_ms"] / diskann_results["avg_search_time_ms"]
print(
f"{'Search Time (ms)':<25} {hnsw_results['avg_search_time_ms']:<15.1f} {diskann_results['avg_search_time_ms']:<15.1f} {search_speedup:<10.2f}x"
)
# Index size comparison
size_ratio = diskann_results["index_size_mb"] / hnsw_results["index_size_mb"]
print(
f"{'Index Size (MB)':<25} {hnsw_results['index_size_mb']:<15.1f} {diskann_results['index_size_mb']:<15.1f} {size_ratio:<10.2f}x"
)
# Score validity
print(
f"{'Score Validity (%)':<25} {hnsw_results['score_validity_rate'] * 100:<15.1f} {diskann_results['score_validity_rate'] * 100:<15.1f}"
)
print(f"{'=' * 60}")
print("\n🎯 Summary:")
if search_speedup > 1:
print(f" DiskANN is {search_speedup:.2f}x faster than HNSW for search")
else:
print(f" HNSW is {1 / search_speedup:.2f}x faster than DiskANN for search")
if size_ratio > 1:
print(f" DiskANN uses {size_ratio:.2f}x more storage than HNSW")
else:
print(f" DiskANN uses {1 / size_ratio:.2f}x less storage than HNSW")
print(
f" Both backends achieved {min(hnsw_results['score_validity_rate'], diskann_results['score_validity_rate']) * 100:.1f}% score validity"
)
if __name__ == "__main__":
import sys
try:
# Handle help request
if len(sys.argv) > 1 and sys.argv[1] in ["-h", "--help", "help"]:
print("DiskANN vs HNSW Performance Comparison")
print("=" * 50)
print(f"Usage: python {sys.argv[0]} [n_docs] [n_queries]")
print()
print("Arguments:")
print(" n_docs Number of documents to index (default: 500)")
print(" n_queries Number of test queries to run (default: 10)")
print()
print("Examples:")
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py")
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 1000")
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20")
sys.exit(0)
# Parse command line arguments
n_docs = int(sys.argv[1]) if len(sys.argv) > 1 else 500
n_queries = int(sys.argv[2]) if len(sys.argv) > 2 else 10
print("DiskANN vs HNSW Performance Comparison")
print("=" * 50)
print(f"Dataset: {n_docs} documents, {n_queries} queries")
print()
run_comparison(n_docs=n_docs, n_queries=n_queries)
except KeyboardInterrupt:
print("\n⚠️ Benchmark interrupted by user")
sys.exit(130)
except Exception as e:
print(f"\n❌ Benchmark failed: {e}")
sys.exit(1)
finally:
# Ensure clean exit (forceful to prevent rare hangs from atexit/threads)
try:
gc.collect()
print("\n🧹 Cleanup completed")
# Flush stdio to ensure message is visible before hard-exit
try:
import sys as _sys
_sys.stdout.flush()
_sys.stderr.flush()
except Exception:
pass
except Exception:
pass
# Use os._exit to bypass atexit handlers that may hang in rare cases
import os as _os
_os._exit(0)

View File

@@ -1,141 +0,0 @@
# Enron Emails Benchmark
A comprehensive RAG benchmark for evaluating LEANN search and generation on the Enron email corpus. It mirrors the structure and CLI of the existing FinanceBench and LAION benches, using stage-based evaluation with Recall@3 and generation timing.
- Dataset: Enron email CSV (e.g., Kaggle wcukierski/enron-email-dataset) for passages
- Queries: corbt/enron_emails_sample_questions (filtered for realistic questions)
- Metrics: Recall@3 vs FAISS Flat baseline + Generation evaluation with Qwen3-8B
## Layout
benchmarks/enron_emails/
- setup_enron_emails.py: Prepare passages, build LEANN index, build FAISS baseline
- evaluate_enron_emails.py: Evaluate retrieval recall (Stages 2-5) + generation with Qwen3-8B
- data/: Generated passages, queries, embeddings-related files
- baseline/: FAISS Flat baseline files
- llm_utils.py: LLM utilities for Qwen3-8B generation (in parent directory)
## Quickstart
1) Prepare the data and index
cd benchmarks/enron_emails
python setup_enron_emails.py --data-dir data
Notes:
- If `--emails-csv` is omitted, the script attempts to download from Kaggle dataset `wcukierski/enron-email-dataset` using Kaggle API (requires `KAGGLE_USERNAME` and `KAGGLE_KEY`).
Alternatively, pass a local path to `--emails-csv`.
Notes:
- The script parses emails, chunks header/body into passages, builds a compact LEANN index, and then builds a FAISS Flat baseline from the same passages and embedding model.
- Optionally, it will also create evaluation queries from HuggingFace dataset `corbt/enron_emails_sample_questions`.
2) Run recall evaluation (Stage 2)
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 2
3) Complexity sweep (Stage 3)
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 3 --target-recall 0.90 --max-queries 200
Stage 3 uses binary search over complexity to find the minimal value achieving the target Recall@3 (assumes recall is non-decreasing with complexity). The search expands the upper bound as needed and snaps complexity to multiples of 8.
4) Index comparison (Stage 4)
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 4 --complexity 88 --max-queries 100 --output results.json
5) Generation evaluation (Stage 5)
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 5 --complexity 88 --llm-backend hf --model-name Qwen/Qwen3-8B
6) Combined index + generation evaluation (Stages 4+5, recommended)
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 45 --complexity 88 --llm-backend hf
Notes:
- Minimal CLI: you can run from repo root with only `--index`, defaults match financebench/laion patterns:
- `--stage` defaults to `all` (runs 2, 3, 4, 5)
- `--baseline-dir` defaults to `baseline`
- `--queries` defaults to `data/evaluation_queries.jsonl` (or falls back to the index directory)
- `--llm-backend` defaults to `hf` (HuggingFace), can use `vllm`
- `--model-name` defaults to `Qwen/Qwen3-8B`
- Fail-fast behavior: no silent fallbacks. If compact index cannot run with recompute, it errors out.
- Stage 5 requires Stage 4 retrieval results. Use `--stage 45` to run both efficiently.
Optional flags:
- --queries data/evaluation_queries.jsonl (custom queries file)
- --baseline-dir baseline (where FAISS baseline lives)
- --complexity 88 (LEANN complexity parameter, optimal for 90% recall)
- --llm-backend hf|vllm (LLM backend for generation)
- --model-name Qwen/Qwen3-8B (LLM model for generation)
- --max-queries 1000 (limit number of queries for evaluation)
## Files Produced
- data/enron_passages_preview.jsonl: Small preview of passages used (for inspection)
- data/enron_index_hnsw.leann.*: LEANN index files
- baseline/faiss_flat.index + baseline/metadata.pkl: FAISS baseline with passage IDs
- data/evaluation_queries.jsonl: Query file (id + query; includes GT IDs for reference)
## Notes
- Evaluates both retrieval Recall@3 and generation timing with Qwen3-8B thinking model.
- The emails CSV must contain a column named "message" (raw RFC822 email) and a column named "file" for source identifier. Message-ID headers are parsed as canonical message IDs when present.
- Qwen3-8B requires special handling for thinking models with chat templates and <think></think> tag processing.
## Stages Summary
- Stage 2 (Recall@3):
- Compares LEANN vs FAISS Flat baseline on Recall@3.
- Compact index runs with `recompute_embeddings=True`.
- Stage 3 (Binary Search for Complexity):
- Builds a non-compact index (`<index>_noncompact.leann`) and runs binary search with `recompute_embeddings=False` to find the minimal complexity achieving target Recall@3 (default 90%).
- Stage 4 (Index Comparison):
- Reports .index-only sizes for compact vs non-compact.
- Measures timings on queries by default: non-compact (no recompute) vs compact (with recompute).
- Stores retrieval results for Stage 5 generation evaluation.
- Fails fast if compact recompute cannot run.
- If `--complexity` is not provided, the script tries to use the best complexity from Stage 3:
- First from the current run (when running `--stage all`), otherwise
- From `enron_stage3_results.json` saved next to the index during the last Stage 3 run.
- If neither exists, Stage 4 will error and ask you to run Stage 3 or pass `--complexity`.
- Stage 5 (Generation Evaluation):
- Uses Qwen3-8B thinking model for RAG generation on retrieved documents from Stage 4.
- Supports HuggingFace (`hf`) and vLLM (`vllm`) backends.
- Measures generation timing separately from search timing.
- Requires Stage 4 results (no additional searching performed).
## Example Results
These are sample results obtained on Enron data using all-mpnet-base-v2 and Qwen3-8B.
- Stage 3 (Binary Search):
- Minimal complexity achieving 90% Recall@3: 88
- Sampled points:
- C=8 → 59.9% Recall@3
- C=72 → 89.4% Recall@3
- C=88 → 90.2% Recall@3
- C=96 → 90.7% Recall@3
- C=112 → 91.1% Recall@3
- C=136 → 91.3% Recall@3
- C=256 → 92.0% Recall@3
- Stage 4 (Index Sizes, .index only):
- Compact: ~2.2 MB
- Non-compact: ~82.0 MB
- Storage saving by compact: ~97.3%
- Stage 4 (Search Timing, 988 queries, complexity=88):
- Non-compact (no recompute): ~0.0075 s avg per query
- Compact (with recompute): ~1.981 s avg per query
- Speed ratio (non-compact/compact): ~0.0038x
- Stage 5 (RAG Generation, 988 queries, Qwen3-8B):
- Average generation time: ~22.302 s per query
- Total queries processed: 988
- LLM backend: HuggingFace transformers
- Model: Qwen/Qwen3-8B (thinking model with <think></think> processing)
Full JSON output is saved by the script (see `--output`), e.g.:
`benchmarks/enron_emails/results_enron_stage45.json`.

View File

@@ -1 +0,0 @@
downloads/

View File

@@ -1,614 +0,0 @@
"""
Enron Emails Benchmark Evaluation - Retrieval Recall@3 (Stages 2/3/4)
Follows the style of FinanceBench/LAION: Stage 2 recall vs FAISS baseline,
Stage 3 complexity sweep to target recall, Stage 4 index comparison.
On errors, fail fast without fallbacks.
"""
import argparse
import json
import logging
import os
import pickle
from pathlib import Path
import numpy as np
from leann import LeannBuilder, LeannSearcher
from leann_backend_hnsw import faiss
from ..llm_utils import generate_hf, generate_vllm, load_hf_model, load_vllm_model
# Setup logging to reduce verbose output
logging.basicConfig(level=logging.WARNING)
logging.getLogger("leann.api").setLevel(logging.WARNING)
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)
class RecallEvaluator:
"""Stage 2: Evaluate Recall@3 (LEANN vs FAISS)"""
def __init__(self, index_path: str, baseline_dir: str):
self.index_path = index_path
self.baseline_dir = baseline_dir
self.searcher = LeannSearcher(index_path)
baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
metadata_path = os.path.join(baseline_dir, "metadata.pkl")
self.faiss_index = faiss.read_index(baseline_index_path)
with open(metadata_path, "rb") as f:
self.passage_ids = pickle.load(f)
print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} vectors")
# No fallbacks here; if embedding server is needed but fails, the caller will see the error.
def evaluate_recall_at_3(
self, queries: list[str], complexity: int = 64, recompute_embeddings: bool = True
) -> float:
"""Evaluate recall@3 using FAISS Flat as ground truth"""
from leann.api import compute_embeddings
recompute_str = "with recompute" if recompute_embeddings else "no recompute"
print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")
total_recall = 0.0
for i, query in enumerate(queries):
# Compute query embedding with the same model/mode as the index
q_emb = compute_embeddings(
[query],
self.searcher.embedding_model,
mode=self.searcher.embedding_mode,
use_server=False,
).astype(np.float32)
# Search FAISS Flat ground truth
n = q_emb.shape[0]
k = 3
distances = np.zeros((n, k), dtype=np.float32)
labels = np.zeros((n, k), dtype=np.int64)
self.faiss_index.search(
n,
faiss.swig_ptr(q_emb),
k,
faiss.swig_ptr(distances),
faiss.swig_ptr(labels),
)
baseline_ids = {self.passage_ids[idx] for idx in labels[0]}
# Search with LEANN (may require embedding server depending on index configuration)
results = self.searcher.search(
query,
top_k=3,
complexity=complexity,
recompute_embeddings=recompute_embeddings,
)
test_ids = {r.id for r in results}
intersection = test_ids.intersection(baseline_ids)
recall = len(intersection) / 3.0
total_recall += recall
if i < 3:
print(f" Q{i + 1}: '{query[:60]}...' -> Recall@3: {recall:.3f}")
print(f" FAISS: {list(baseline_ids)}")
print(f" LEANN: {list(test_ids)}")
print(f" ∩: {list(intersection)}")
avg = total_recall / max(1, len(queries))
print(f"📊 Average Recall@3: {avg:.3f} ({avg * 100:.1f}%)")
return avg
def cleanup(self):
if hasattr(self, "searcher"):
self.searcher.cleanup()
class EnronEvaluator:
def __init__(self, index_path: str):
self.index_path = index_path
self.searcher = LeannSearcher(index_path)
def load_queries(self, queries_file: str) -> list[str]:
queries: list[str] = []
with open(queries_file, encoding="utf-8") as f:
for line in f:
if not line.strip():
continue
data = json.loads(line)
if "query" in data:
queries.append(data["query"])
print(f"📊 Loaded {len(queries)} queries from {queries_file}")
return queries
def cleanup(self):
if self.searcher:
self.searcher.cleanup()
def analyze_index_sizes(self) -> dict:
"""Analyze index sizes (.index only), similar to LAION bench."""
print("📏 Analyzing index sizes (.index only)...")
index_path = Path(self.index_path)
index_dir = index_path.parent
index_name = index_path.stem
sizes: dict[str, float] = {}
index_file = index_dir / f"{index_name}.index"
meta_file = index_dir / f"{index_path.name}.meta.json"
passages_file = index_dir / f"{index_path.name}.passages.jsonl"
passages_idx_file = index_dir / f"{index_path.name}.passages.idx"
sizes["index_only_mb"] = (
index_file.stat().st_size / (1024 * 1024) if index_file.exists() else 0.0
)
sizes["metadata_mb"] = (
meta_file.stat().st_size / (1024 * 1024) if meta_file.exists() else 0.0
)
sizes["passages_text_mb"] = (
passages_file.stat().st_size / (1024 * 1024) if passages_file.exists() else 0.0
)
sizes["passages_index_mb"] = (
passages_idx_file.stat().st_size / (1024 * 1024) if passages_idx_file.exists() else 0.0
)
print(f" 📁 .index size: {sizes['index_only_mb']:.1f} MB")
return sizes
def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
"""Create a non-compact index for comparison using current passages and embeddings."""
current_index_path = Path(self.index_path)
current_index_dir = current_index_path.parent
current_index_name = current_index_path.name
# Read metadata to get passage source and embedding model
meta_path = current_index_dir / f"{current_index_name}.meta.json"
with open(meta_path, encoding="utf-8") as f:
meta = json.load(f)
passage_source = meta["passage_sources"][0]
passage_file = passage_source["path"]
# Convert relative path to absolute
if not Path(passage_file).is_absolute():
passage_file = current_index_dir / Path(passage_file).name
# Load all passages and ids
ids: list[str] = []
texts: list[str] = []
with open(passage_file, encoding="utf-8") as f:
for line in f:
if line.strip():
data = json.loads(line)
ids.append(str(data["id"]))
texts.append(data["text"])
# Compute embeddings using the same method as LEANN
from leann.api import compute_embeddings
embeddings = compute_embeddings(
texts,
meta["embedding_model"],
mode=meta.get("embedding_mode", "sentence-transformers"),
use_server=False,
).astype(np.float32)
# Build non-compact index with same passages and embeddings
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=meta["embedding_model"],
embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
is_recompute=False,
is_compact=False,
**{
k: v
for k, v in meta.get("backend_kwargs", {}).items()
if k not in ["is_recompute", "is_compact"]
},
)
# Persist a pickle for build_index_from_embeddings
pkl_path = current_index_dir / f"{Path(non_compact_index_path).stem}_embeddings.pkl"
with open(pkl_path, "wb") as pf:
pickle.dump((ids, embeddings), pf)
print(
f"🔨 Building non-compact index at {non_compact_index_path} from precomputed embeddings..."
)
builder.build_index_from_embeddings(non_compact_index_path, str(pkl_path))
# Analyze the non-compact index size
temp_evaluator = EnronEvaluator(non_compact_index_path)
non_compact_sizes = temp_evaluator.analyze_index_sizes()
non_compact_sizes["index_type"] = "non_compact"
return non_compact_sizes
def compare_index_performance(
self, non_compact_path: str, compact_path: str, test_queries: list[str], complexity: int
) -> dict:
"""Compare search speed for non-compact vs compact indexes."""
import time
results: dict = {
"non_compact": {"search_times": []},
"compact": {"search_times": []},
"avg_search_times": {},
"speed_ratio": 0.0,
"retrieval_results": [], # Store retrieval results for Stage 5
}
print("⚡ Comparing search performance between indexes...")
# Non-compact (no recompute)
print(" 🔍 Testing non-compact index (no recompute)...")
non_compact_searcher = LeannSearcher(non_compact_path)
for q in test_queries:
t0 = time.time()
_ = non_compact_searcher.search(
q, top_k=3, complexity=complexity, recompute_embeddings=False
)
results["non_compact"]["search_times"].append(time.time() - t0)
# Compact (with recompute). Fail fast if it cannot run.
print(" 🔍 Testing compact index (with recompute)...")
compact_searcher = LeannSearcher(compact_path)
for q in test_queries:
t0 = time.time()
docs = compact_searcher.search(
q, top_k=3, complexity=complexity, recompute_embeddings=True
)
results["compact"]["search_times"].append(time.time() - t0)
# Store retrieval results for Stage 5
results["retrieval_results"].append(
{"query": q, "retrieved_docs": [{"id": doc.id, "text": doc.text} for doc in docs]}
)
compact_searcher.cleanup()
if results["non_compact"]["search_times"]:
results["avg_search_times"]["non_compact"] = sum(
results["non_compact"]["search_times"]
) / len(results["non_compact"]["search_times"])
if results["compact"]["search_times"]:
results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
results["compact"]["search_times"]
)
if results["avg_search_times"].get("compact", 0) > 0:
results["speed_ratio"] = (
results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
)
else:
results["speed_ratio"] = 0.0
non_compact_searcher.cleanup()
return results
def evaluate_complexity(
self,
recall_eval: "RecallEvaluator",
queries: list[str],
target: float = 0.90,
c_min: int = 8,
c_max: int = 256,
max_iters: int = 10,
recompute: bool = False,
) -> dict:
"""Binary search minimal complexity achieving target recall (monotonic assumption)."""
def round_c(x: int) -> int:
# snap to multiple of 8 like other benches typically do
return max(1, int((x + 7) // 8) * 8)
metrics: list[dict] = []
lo = round_c(c_min)
hi = round_c(c_max)
print(
f"🧪 Binary search complexity in [{lo}, {hi}] for target Recall@3>={int(target * 100)}%..."
)
# Ensure upper bound can reach target; expand if needed (up to a cap)
r_lo = recall_eval.evaluate_recall_at_3(
queries, complexity=lo, recompute_embeddings=recompute
)
metrics.append({"complexity": lo, "recall_at_3": r_lo})
r_hi = recall_eval.evaluate_recall_at_3(
queries, complexity=hi, recompute_embeddings=recompute
)
metrics.append({"complexity": hi, "recall_at_3": r_hi})
cap = 1024
while r_hi < target and hi < cap:
lo = hi
r_lo = r_hi
hi = round_c(hi * 2)
r_hi = recall_eval.evaluate_recall_at_3(
queries, complexity=hi, recompute_embeddings=recompute
)
metrics.append({"complexity": hi, "recall_at_3": r_hi})
if r_hi < target:
print(f"⚠️ Max complexity {hi} did not reach target recall {target:.2f}.")
print("📈 Observations:")
for m in metrics:
print(f" C={m['complexity']:>4} -> Recall@3={m['recall_at_3'] * 100:.1f}%")
return {"metrics": metrics, "best_complexity": None, "target_recall": target}
# Binary search within [lo, hi]
best = hi
iters = 0
while lo < hi and iters < max_iters:
mid = round_c((lo + hi) // 2)
r_mid = recall_eval.evaluate_recall_at_3(
queries, complexity=mid, recompute_embeddings=recompute
)
metrics.append({"complexity": mid, "recall_at_3": r_mid})
if r_mid >= target:
best = mid
hi = mid
else:
lo = mid + 8 # move past mid, respecting multiple-of-8 step
iters += 1
print("📈 Binary search results (sampled points):")
# Print unique complexity entries ordered by complexity
for m in sorted(
{m["complexity"]: m for m in metrics}.values(), key=lambda x: x["complexity"]
):
print(f" C={m['complexity']:>4} -> Recall@3={m['recall_at_3'] * 100:.1f}%")
print(f"✅ Minimal complexity achieving {int(target * 100)}% recall: {best}")
return {"metrics": metrics, "best_complexity": best, "target_recall": target}
def main():
parser = argparse.ArgumentParser(description="Enron Emails Benchmark Evaluation")
parser.add_argument("--index", required=True, help="Path to LEANN index")
parser.add_argument(
"--queries", default="data/evaluation_queries.jsonl", help="Path to evaluation queries"
)
parser.add_argument(
"--stage",
choices=["2", "3", "4", "5", "all", "45"],
default="all",
help="Which stage to run (2=recall, 3=complexity, 4=index comparison, 5=generation)",
)
parser.add_argument("--complexity", type=int, default=None, help="LEANN search complexity")
parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
parser.add_argument(
"--max-queries", type=int, help="Limit number of queries to evaluate", default=1000
)
parser.add_argument(
"--target-recall", type=float, default=0.90, help="Target Recall@3 for Stage 3"
)
parser.add_argument("--output", help="Save results to JSON file")
parser.add_argument("--llm-backend", choices=["hf", "vllm"], default="hf", help="LLM backend")
parser.add_argument("--model-name", default="Qwen/Qwen3-8B", help="Model name")
args = parser.parse_args()
# Resolve queries file: if default path not found, fall back to index's directory
if not os.path.exists(args.queries):
from pathlib import Path
idx_dir = Path(args.index).parent
fallback_q = idx_dir / "evaluation_queries.jsonl"
if fallback_q.exists():
args.queries = str(fallback_q)
baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
if not os.path.exists(baseline_index_path):
print(f"❌ FAISS baseline not found at {baseline_index_path}")
print("💡 Please run setup_enron_emails.py first to build the baseline")
raise SystemExit(1)
results_out: dict = {}
if args.stage in ("2", "all"):
print("🚀 Starting Stage 2: Recall@3 evaluation")
evaluator = RecallEvaluator(args.index, args.baseline_dir)
enron_eval = EnronEvaluator(args.index)
queries = enron_eval.load_queries(args.queries)
queries = queries[:10]
print(f"🧪 Using first {len(queries)} queries")
complexity = args.complexity or 64
r = evaluator.evaluate_recall_at_3(queries, complexity)
results_out["stage2"] = {"complexity": complexity, "recall_at_3": r}
evaluator.cleanup()
enron_eval.cleanup()
print("✅ Stage 2 completed!\n")
if args.stage in ("3", "all"):
print("🚀 Starting Stage 3: Binary search for target recall (no recompute)")
enron_eval = EnronEvaluator(args.index)
queries = enron_eval.load_queries(args.queries)
queries = queries[: args.max_queries]
print(f"🧪 Using first {len(queries)} queries")
# Build non-compact index for fast binary search (recompute_embeddings=False)
from pathlib import Path
index_path = Path(args.index)
non_compact_index_path = str(index_path.parent / f"{index_path.stem}_noncompact.leann")
enron_eval.create_non_compact_index_for_comparison(non_compact_index_path)
# Use non-compact evaluator for binary search with recompute=False
evaluator_nc = RecallEvaluator(non_compact_index_path, args.baseline_dir)
sweep = enron_eval.evaluate_complexity(
evaluator_nc, queries, target=args.target_recall, recompute=False
)
results_out["stage3"] = sweep
# Persist default stage 3 results near the index for Stage 4 auto-pickup
from pathlib import Path
default_stage3_path = Path(args.index).parent / "enron_stage3_results.json"
with open(default_stage3_path, "w", encoding="utf-8") as f:
json.dump({"stage3": sweep}, f, indent=2)
print(f"📝 Saved Stage 3 summary to {default_stage3_path}")
evaluator_nc.cleanup()
enron_eval.cleanup()
print("✅ Stage 3 completed!\n")
if args.stage in ("4", "all", "45"):
print("🚀 Starting Stage 4: Index size + performance comparison")
evaluator = RecallEvaluator(args.index, args.baseline_dir)
enron_eval = EnronEvaluator(args.index)
queries = enron_eval.load_queries(args.queries)
test_q = queries[: min(args.max_queries, len(queries))]
current_sizes = enron_eval.analyze_index_sizes()
# Build non-compact index for comparison (no fallback)
from pathlib import Path
index_path = Path(args.index)
non_compact_path = str(index_path.parent / f"{index_path.stem}_noncompact.leann")
non_compact_sizes = enron_eval.create_non_compact_index_for_comparison(non_compact_path)
nc_eval = EnronEvaluator(non_compact_path)
if (
current_sizes.get("index_only_mb", 0) > 0
and non_compact_sizes.get("index_only_mb", 0) > 0
):
storage_saving_percent = max(
0.0,
100.0 * (1.0 - current_sizes["index_only_mb"] / non_compact_sizes["index_only_mb"]),
)
else:
storage_saving_percent = 0.0
if args.complexity is None:
# Prefer in-session Stage 3 result
if "stage3" in results_out and results_out["stage3"].get("best_complexity") is not None:
complexity = results_out["stage3"]["best_complexity"]
print(f"📥 Using best complexity from Stage 3 in-session: {complexity}")
else:
# Try to load last saved Stage 3 result near index
default_stage3_path = Path(args.index).parent / "enron_stage3_results.json"
if default_stage3_path.exists():
with open(default_stage3_path, encoding="utf-8") as f:
prev = json.load(f)
complexity = prev.get("stage3", {}).get("best_complexity")
if complexity is None:
raise SystemExit(
"❌ Stage 4: No --complexity and no best_complexity found in saved Stage 3 results"
)
print(f"📥 Using best complexity from saved Stage 3: {complexity}")
else:
raise SystemExit(
"❌ Stage 4 requires --complexity if Stage 3 hasn't been run. Run stage 3 first or pass --complexity."
)
else:
complexity = args.complexity
comp = enron_eval.compare_index_performance(
non_compact_path, args.index, test_q, complexity=complexity
)
results_out["stage4"] = {
"current_index": current_sizes,
"non_compact_index": non_compact_sizes,
"storage_saving_percent": storage_saving_percent,
"performance_comparison": comp,
}
nc_eval.cleanup()
evaluator.cleanup()
enron_eval.cleanup()
print("✅ Stage 4 completed!\n")
if args.stage in ("5", "all"):
print("🚀 Starting Stage 5: Generation evaluation with Qwen3-8B")
# Check if Stage 4 results exist
if "stage4" not in results_out or "performance_comparison" not in results_out["stage4"]:
print("❌ Stage 5 requires Stage 4 retrieval results")
print("💡 Run Stage 4 first or use --stage all")
raise SystemExit(1)
retrieval_results = results_out["stage4"]["performance_comparison"]["retrieval_results"]
if not retrieval_results:
print("❌ No retrieval results found from Stage 4")
raise SystemExit(1)
print(f"📁 Using {len(retrieval_results)} retrieval results from Stage 4")
# Load LLM
try:
if args.llm_backend == "hf":
tokenizer, model = load_hf_model(args.model_name)
def llm_func(prompt):
return generate_hf(tokenizer, model, prompt)
else: # vllm
llm, sampling_params = load_vllm_model(args.model_name)
def llm_func(prompt):
return generate_vllm(llm, sampling_params, prompt)
# Run generation using stored retrieval results
import time
from llm_utils import create_prompt
generation_times = []
responses = []
print("🤖 Running generation on pre-retrieved results...")
for i, item in enumerate(retrieval_results):
query = item["query"]
retrieved_docs = item["retrieved_docs"]
# Prepare context from retrieved docs
context = "\n\n".join([doc["text"] for doc in retrieved_docs])
prompt = create_prompt(context, query, "emails")
# Time generation only
gen_start = time.time()
response = llm_func(prompt)
gen_time = time.time() - gen_start
generation_times.append(gen_time)
responses.append(response)
if i < 3:
print(f" Q{i + 1}: Gen={gen_time:.3f}s")
avg_gen_time = sum(generation_times) / len(generation_times)
print("\n📊 Generation Results:")
print(f" Total Queries: {len(retrieval_results)}")
print(f" Avg Generation Time: {avg_gen_time:.3f}s")
print(" (Search time from Stage 4)")
results_out["stage5"] = {
"total_queries": len(retrieval_results),
"avg_generation_time": avg_gen_time,
"generation_times": generation_times,
"responses": responses,
}
# Show sample results
print("\n📝 Sample Results:")
for i in range(min(3, len(retrieval_results))):
query = retrieval_results[i]["query"]
response = responses[i]
print(f" Q{i + 1}: {query[:60]}...")
print(f" A{i + 1}: {response[:100]}...")
print()
except Exception as e:
print(f"❌ Generation evaluation failed: {e}")
print("💡 Make sure transformers/vllm is installed and model is available")
print("✅ Stage 5 completed!\n")
if args.output and results_out:
with open(args.output, "w", encoding="utf-8") as f:
json.dump(results_out, f, indent=2)
print(f"📝 Saved results to {args.output}")
if __name__ == "__main__":
main()

View File

@@ -1,359 +0,0 @@
"""
Enron Emails Benchmark Setup Script
Prepares passages from emails.csv, builds LEANN index, and FAISS Flat baseline
"""
import argparse
import csv
import json
import os
import re
from collections.abc import Iterable
from email import message_from_string
from email.policy import default
from pathlib import Path
from typing import Optional
from leann import LeannBuilder
class EnronSetup:
def __init__(self, data_dir: str = "data"):
self.data_dir = Path(data_dir)
self.data_dir.mkdir(parents=True, exist_ok=True)
self.passages_preview = self.data_dir / "enron_passages_preview.jsonl"
self.index_path = self.data_dir / "enron_index_hnsw.leann"
self.queries_file = self.data_dir / "evaluation_queries.jsonl"
self.downloads_dir = self.data_dir / "downloads"
self.downloads_dir.mkdir(parents=True, exist_ok=True)
# ----------------------------
# Dataset acquisition
# ----------------------------
def ensure_emails_csv(self, emails_csv: Optional[str]) -> str:
"""Return a path to emails.csv, downloading from Kaggle if needed."""
if emails_csv:
p = Path(emails_csv)
if not p.exists():
raise FileNotFoundError(f"emails.csv not found: {emails_csv}")
return str(p)
print(
"📥 Trying to download Enron emails.csv from Kaggle (wcukierski/enron-email-dataset)..."
)
try:
from kaggle.api.kaggle_api_extended import KaggleApi
api = KaggleApi()
api.authenticate()
api.dataset_download_files(
"wcukierski/enron-email-dataset", path=str(self.downloads_dir), unzip=True
)
candidate = self.downloads_dir / "emails.csv"
if candidate.exists():
print(f"✅ Downloaded emails.csv: {candidate}")
return str(candidate)
else:
raise FileNotFoundError(
f"emails.csv was not found in {self.downloads_dir} after Kaggle download"
)
except Exception as e:
print(
"❌ Could not download via Kaggle automatically. Provide --emails-csv or configure Kaggle API."
)
print(
" Set KAGGLE_USERNAME and KAGGLE_KEY env vars, or place emails.csv locally and pass --emails-csv."
)
raise e
# ----------------------------
# Data preparation
# ----------------------------
@staticmethod
def _extract_message_id(raw_email: str) -> str:
msg = message_from_string(raw_email, policy=default)
val = msg.get("Message-ID", "")
if val.startswith("<") and val.endswith(">"):
val = val[1:-1]
return val or ""
@staticmethod
def _split_header_body(raw_email: str) -> tuple[str, str]:
parts = raw_email.split("\n\n", 1)
if len(parts) == 2:
return parts[0].strip(), parts[1].strip()
# Heuristic fallback
first_lines = raw_email.splitlines()
if first_lines and ":" in first_lines[0]:
return raw_email.strip(), ""
return "", raw_email.strip()
@staticmethod
def _split_fixed_words(text: str, chunk_words: int, keep_last: bool) -> list[str]:
text = (text or "").strip()
if not text:
return []
if chunk_words <= 0:
return [text]
words = text.split()
if not words:
return []
limit = len(words)
if not keep_last:
limit = (len(words) // chunk_words) * chunk_words
if limit == 0:
return []
chunks = [" ".join(words[i : i + chunk_words]) for i in range(0, limit, chunk_words)]
return [c for c in (s.strip() for s in chunks) if c]
def _iter_passages_from_csv(
self,
emails_csv: Path,
chunk_words: int = 256,
keep_last_header: bool = True,
keep_last_body: bool = True,
max_emails: int | None = None,
) -> Iterable[dict]:
with open(emails_csv, encoding="utf-8") as f:
reader = csv.DictReader(f)
count = 0
for i, row in enumerate(reader):
if max_emails is not None and count >= max_emails:
break
raw_message = row.get("message", "")
email_file_id = row.get("file", "")
if not raw_message.strip():
continue
message_id = self._extract_message_id(raw_message)
if not message_id:
# Fallback ID based on CSV position and file path
safe_file = re.sub(r"[^A-Za-z0-9_.-]", "_", email_file_id)
message_id = f"enron_{i}_{safe_file}"
header, body = self._split_header_body(raw_message)
# Header chunks
for chunk in self._split_fixed_words(header, chunk_words, keep_last_header):
yield {
"text": chunk,
"metadata": {
"message_id": message_id,
"is_header": True,
"email_file_id": email_file_id,
},
}
# Body chunks
for chunk in self._split_fixed_words(body, chunk_words, keep_last_body):
yield {
"text": chunk,
"metadata": {
"message_id": message_id,
"is_header": False,
"email_file_id": email_file_id,
},
}
count += 1
# ----------------------------
# Build LEANN index and FAISS baseline
# ----------------------------
def build_leann_index(
self,
emails_csv: Optional[str],
backend: str = "hnsw",
embedding_model: str = "sentence-transformers/all-mpnet-base-v2",
chunk_words: int = 256,
max_emails: int | None = None,
) -> str:
emails_csv_path = self.ensure_emails_csv(emails_csv)
print(f"🏗️ Building LEANN index from {emails_csv_path}...")
builder = LeannBuilder(
backend_name=backend,
embedding_model=embedding_model,
embedding_mode="sentence-transformers",
graph_degree=32,
complexity=64,
is_recompute=True,
is_compact=True,
num_threads=4,
)
# Stream passages and add to builder
preview_written = 0
with open(self.passages_preview, "w", encoding="utf-8") as preview_out:
for p in self._iter_passages_from_csv(
Path(emails_csv_path), chunk_words=chunk_words, max_emails=max_emails
):
builder.add_text(p["text"], metadata=p["metadata"])
if preview_written < 200:
preview_out.write(json.dumps({"text": p["text"][:200], **p["metadata"]}) + "\n")
preview_written += 1
print(f"🔨 Building index at {self.index_path}...")
builder.build_index(str(self.index_path))
print("✅ LEANN index built!")
return str(self.index_path)
def build_faiss_flat_baseline(self, index_path: str, output_dir: str = "baseline") -> str:
print("🔨 Building FAISS Flat baseline from LEANN passages...")
import pickle
import numpy as np
from leann.api import compute_embeddings
from leann_backend_hnsw import faiss
os.makedirs(output_dir, exist_ok=True)
baseline_path = os.path.join(output_dir, "faiss_flat.index")
metadata_path = os.path.join(output_dir, "metadata.pkl")
if os.path.exists(baseline_path) and os.path.exists(metadata_path):
print(f"✅ Baseline already exists at {baseline_path}")
return baseline_path
# Read meta for passage source and embedding model
meta_path = f"{index_path}.meta.json"
with open(meta_path, encoding="utf-8") as f:
meta = json.load(f)
embedding_model = meta["embedding_model"]
passage_source = meta["passage_sources"][0]
passage_file = passage_source["path"]
if not os.path.isabs(passage_file):
index_dir = os.path.dirname(index_path)
passage_file = os.path.join(index_dir, os.path.basename(passage_file))
# Load passages from builder output so IDs match LEANN
passages: list[str] = []
passage_ids: list[str] = []
with open(passage_file, encoding="utf-8") as f:
for line in f:
if not line.strip():
continue
data = json.loads(line)
passages.append(data["text"])
passage_ids.append(data["id"]) # builder-assigned ID
print(f"📄 Loaded {len(passages)} passages for baseline")
print(f"🤖 Embedding model: {embedding_model}")
embeddings = compute_embeddings(
passages,
embedding_model,
mode="sentence-transformers",
use_server=False,
)
# Build FAISS IndexFlatIP
dim = embeddings.shape[1]
index = faiss.IndexFlatIP(dim)
emb_f32 = embeddings.astype(np.float32)
index.add(emb_f32.shape[0], faiss.swig_ptr(emb_f32))
faiss.write_index(index, baseline_path)
with open(metadata_path, "wb") as pf:
pickle.dump(passage_ids, pf)
print(f"✅ FAISS baseline saved: {baseline_path}")
print(f"✅ Metadata saved: {metadata_path}")
print(f"📊 Total vectors: {index.ntotal}")
return baseline_path
# ----------------------------
# Queries (optional): prepare evaluation queries file
# ----------------------------
def prepare_queries(self, min_realism: float = 0.85) -> Path:
print(
"📝 Preparing evaluation queries from HuggingFace dataset corbt/enron_emails_sample_questions ..."
)
try:
from datasets import load_dataset
ds = load_dataset("corbt/enron_emails_sample_questions", split="train")
except Exception as e:
print(f"⚠️ Failed to load dataset: {e}")
return self.queries_file
kept = 0
with open(self.queries_file, "w", encoding="utf-8") as out:
for i, item in enumerate(ds):
how_realistic = float(item.get("how_realistic", 0.0))
if how_realistic < min_realism:
continue
qid = str(item.get("id", f"enron_q_{i}"))
query = item.get("question", "")
if not query:
continue
record = {
"id": qid,
"query": query,
# For reference only, not used in recall metric below
"gt_message_ids": item.get("message_ids", []),
}
out.write(json.dumps(record) + "\n")
kept += 1
print(f"✅ Wrote {kept} queries to {self.queries_file}")
return self.queries_file
def main():
parser = argparse.ArgumentParser(description="Setup Enron Emails Benchmark")
parser.add_argument(
"--emails-csv",
help="Path to emails.csv (Enron dataset). If omitted, attempt Kaggle download.",
)
parser.add_argument("--data-dir", default="data", help="Data directory")
parser.add_argument("--backend", choices=["hnsw", "diskann"], default="hnsw")
parser.add_argument(
"--embedding-model",
default="sentence-transformers/all-mpnet-base-v2",
help="Embedding model for LEANN",
)
parser.add_argument("--chunk-words", type=int, default=256, help="Fixed word chunk size")
parser.add_argument("--max-emails", type=int, help="Limit number of emails to process")
parser.add_argument("--skip-queries", action="store_true", help="Skip creating queries file")
parser.add_argument("--skip-build", action="store_true", help="Skip building LEANN index")
args = parser.parse_args()
setup = EnronSetup(args.data_dir)
# Build index
if not args.skip_build:
index_path = setup.build_leann_index(
emails_csv=args.emails_csv,
backend=args.backend,
embedding_model=args.embedding_model,
chunk_words=args.chunk_words,
max_emails=args.max_emails,
)
# Build FAISS baseline from the same passages & embeddings
setup.build_faiss_flat_baseline(index_path)
else:
print("⏭️ Skipping LEANN index build and baseline")
# Queries file (optional)
if not args.skip_queries:
setup.prepare_queries()
else:
print("⏭️ Skipping query preparation")
print("\n🎉 Enron Emails setup completed!")
print(f"📁 Data directory: {setup.data_dir.absolute()}")
print("Next steps:")
print(
"1) Evaluate recall: python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 2"
)
if __name__ == "__main__":
main()

View File

@@ -1,115 +0,0 @@
# FinanceBench Benchmark for LEANN-RAG
FinanceBench is a benchmark for evaluating retrieval-augmented generation (RAG) systems on financial document question-answering tasks.
## Dataset
- **Source**: [PatronusAI/financebench](https://huggingface.co/datasets/PatronusAI/financebench)
- **Questions**: 150 financial Q&A examples
- **Documents**: 368 PDF files (10-K, 10-Q, 8-K, earnings reports)
- **Companies**: Major public companies (3M, Apple, Microsoft, Amazon, etc.)
- **Paper**: [FinanceBench: A New Benchmark for Financial Question Answering](https://arxiv.org/abs/2311.11944)
## Structure
```
benchmarks/financebench/
├── setup_financebench.py # Downloads PDFs and builds index
├── evaluate_financebench.py # Intelligent evaluation script
├── data/
│ ├── financebench_merged.jsonl # Q&A dataset
│ ├── pdfs/ # Downloaded financial documents
│ └── index/ # LEANN indexes
│ └── financebench_full_hnsw.leann
└── README.md
```
## Usage
### 1. Setup (Download & Build Index)
```bash
cd benchmarks/financebench
python setup_financebench.py
```
This will:
- Download the 150 Q&A examples
- Download all 368 PDF documents (parallel processing)
- Build a LEANN index from 53K+ text chunks
- Verify setup with test query
### 2. Evaluation
```bash
# Basic retrieval evaluation
python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann
# RAG generation evaluation with Qwen3-8B
python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann --stage 4 --complexity 64 --llm-backend hf --model-name Qwen/Qwen3-8B --output results_qwen3.json
```
## Evaluation Methods
### Retrieval Evaluation
Uses intelligent matching with three strategies:
1. **Exact text overlap** - Direct substring matches
2. **Number matching** - Key financial figures ($1,577, 1.2B, etc.)
3. **Semantic similarity** - Word overlap with 20% threshold
### QA Evaluation
LLM-based answer evaluation using GPT-4o:
- Handles numerical rounding and equivalent representations
- Considers fractions, percentages, and decimal equivalents
- Evaluates semantic meaning rather than exact text match
## Benchmark Results
### LEANN-RAG Performance (sentence-transformers/all-mpnet-base-v2)
**Retrieval Metrics:**
- **Question Coverage**: 100.0% (all questions retrieve relevant docs)
- **Exact Match Rate**: 0.7% (substring overlap with evidence)
- **Number Match Rate**: 120.7% (key financial figures matched)*
- **Semantic Match Rate**: 4.7% (word overlap ≥20%)
- **Average Search Time**: 0.097s
**QA Metrics:**
- **Accuracy**: 42.7% (LLM-evaluated answer correctness)
- **Average QA Time**: 4.71s (end-to-end response time)
**System Performance:**
- **Index Size**: 53,985 chunks from 368 PDFs
- **Build Time**: ~5-10 minutes with sentence-transformers/all-mpnet-base-v2
*Note: Number match rate >100% indicates multiple retrieved documents contain the same financial figures, which is expected behavior for financial data appearing across multiple document sections.
### LEANN-RAG Generation Performance (Qwen3-8B)
- **Stage 4 (Index Comparison):**
- Compact Index: 5.0 MB
- Non-compact Index: 172.2 MB
- **Storage Saving**: 97.1%
- **Search Performance**:
- Non-compact (no recompute): 0.009s avg per query
- Compact (with recompute): 2.203s avg per query
- Speed ratio: 0.004x
**Generation Evaluation (20 queries, complexity=64):**
- **Average Search Time**: 1.638s per query
- **Average Generation Time**: 45.957s per query
- **LLM Backend**: HuggingFace transformers
- **Model**: Qwen/Qwen3-8B (thinking model with <think></think> processing)
- **Total Questions Processed**: 20
## Options
```bash
# Use different backends
python setup_financebench.py --backend diskann
python evaluate_financebench.py --index data/index/financebench_full_diskann.leann
# Use different embedding models
python setup_financebench.py --embedding-model facebook/contriever
```

View File

@@ -1,923 +0,0 @@
"""
FinanceBench Evaluation Script - Modular Recall-based Evaluation
"""
import argparse
import json
import logging
import os
import pickle
import time
from pathlib import Path
from typing import Optional
import numpy as np
import openai
from leann import LeannChat, LeannSearcher
from leann_backend_hnsw import faiss
from ..llm_utils import evaluate_rag, generate_hf, generate_vllm, load_hf_model, load_vllm_model
# Setup logging to reduce verbose output
logging.basicConfig(level=logging.WARNING)
logging.getLogger("leann.api").setLevel(logging.WARNING)
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)
class RecallEvaluator:
"""Stage 2: Evaluate Recall@3 (searcher vs baseline)"""
def __init__(self, index_path: str, baseline_dir: str):
self.index_path = index_path
self.baseline_dir = baseline_dir
self.searcher = LeannSearcher(index_path)
# Load FAISS flat baseline
baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
metadata_path = os.path.join(baseline_dir, "metadata.pkl")
self.faiss_index = faiss.read_index(baseline_index_path)
with open(metadata_path, "rb") as f:
self.passage_ids = pickle.load(f)
print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} vectors")
def evaluate_recall_at_3(
self, queries: list[str], complexity: int = 64, recompute_embeddings: bool = True
) -> float:
"""Evaluate recall@3 for given queries at specified complexity"""
recompute_str = "with recompute" if recompute_embeddings else "no recompute"
print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")
total_recall = 0.0
num_queries = len(queries)
for i, query in enumerate(queries):
# Get ground truth: search with FAISS flat
from leann.api import compute_embeddings
query_embedding = compute_embeddings(
[query],
self.searcher.embedding_model,
mode=self.searcher.embedding_mode,
use_server=False,
).astype(np.float32)
# Search FAISS flat for ground truth using LEANN's modified faiss API
n = query_embedding.shape[0] # Number of queries
k = 3 # Number of nearest neighbors
distances = np.zeros((n, k), dtype=np.float32)
labels = np.zeros((n, k), dtype=np.int64)
self.faiss_index.search(
n,
faiss.swig_ptr(query_embedding),
k,
faiss.swig_ptr(distances),
faiss.swig_ptr(labels),
)
# Extract the results
baseline_ids = {self.passage_ids[idx] for idx in labels[0]}
# Search with LEANN at specified complexity
test_results = self.searcher.search(
query,
top_k=3,
complexity=complexity,
recompute_embeddings=recompute_embeddings,
)
test_ids = {result.id for result in test_results}
# Calculate recall@3 = |intersection| / |ground_truth|
intersection = test_ids.intersection(baseline_ids)
recall = len(intersection) / 3.0 # Ground truth size is 3
total_recall += recall
if i < 3: # Show first few examples
print(f" Query {i + 1}: '{query[:50]}...' -> Recall@3: {recall:.3f}")
print(f" FAISS ground truth: {list(baseline_ids)}")
print(f" LEANN results (C={complexity}, {recompute_str}): {list(test_ids)}")
print(f" Intersection: {list(intersection)}")
avg_recall = total_recall / num_queries
print(f"📊 Average Recall@3: {avg_recall:.3f} ({avg_recall * 100:.1f}%)")
return avg_recall
def cleanup(self):
"""Cleanup resources"""
if hasattr(self, "searcher"):
self.searcher.cleanup()
class FinanceBenchEvaluator:
def __init__(self, index_path: str, openai_api_key: Optional[str] = None):
self.index_path = index_path
self.openai_client = openai.OpenAI(api_key=openai_api_key) if openai_api_key else None
self.searcher = LeannSearcher(index_path)
self.chat = LeannChat(index_path) if openai_api_key else None
def load_dataset(self, dataset_path: str = "data/financebench_merged.jsonl"):
"""Load FinanceBench dataset"""
data = []
with open(dataset_path, encoding="utf-8") as f:
for line in f:
if line.strip():
data.append(json.loads(line))
print(f"📊 Loaded {len(data)} FinanceBench examples")
return data
def analyze_index_sizes(self) -> dict:
"""Analyze index sizes with and without embeddings"""
print("📏 Analyzing index sizes...")
# Get all index-related files
index_path = Path(self.index_path)
index_dir = index_path.parent
index_name = index_path.stem # Remove .leann extension
sizes = {}
total_with_embeddings = 0
# Core index files
index_file = index_dir / f"{index_name}.index"
meta_file = index_dir / f"{index_path.name}.meta.json" # Keep .leann for meta file
passages_file = index_dir / f"{index_path.name}.passages.jsonl" # Keep .leann for passages
passages_idx_file = index_dir / f"{index_path.name}.passages.idx" # Keep .leann for idx
for file_path, name in [
(index_file, "index"),
(meta_file, "metadata"),
(passages_file, "passages_text"),
(passages_idx_file, "passages_index"),
]:
if file_path.exists():
size_mb = file_path.stat().st_size / (1024 * 1024)
sizes[name] = size_mb
total_with_embeddings += size_mb
else:
sizes[name] = 0
sizes["total_with_embeddings"] = total_with_embeddings
sizes["index_only_mb"] = sizes["index"] # Just the .index file for fair comparison
print(f" 📁 Total index size: {total_with_embeddings:.1f} MB")
print(f" 📁 Index file only: {sizes['index']:.1f} MB")
return sizes
def create_compact_index_for_comparison(self, compact_index_path: str) -> dict:
"""Create a compact index for comparison purposes"""
print("🏗️ Building compact index from existing passages...")
# Load existing passages from current index
from leann import LeannBuilder
current_index_path = Path(self.index_path)
current_index_dir = current_index_path.parent
current_index_name = current_index_path.name
# Read metadata to get passage source
meta_path = current_index_dir / f"{current_index_name}.meta.json"
with open(meta_path) as f:
import json
meta = json.load(f)
passage_source = meta["passage_sources"][0]
passage_file = passage_source["path"]
# Convert relative path to absolute
if not Path(passage_file).is_absolute():
passage_file = current_index_dir / Path(passage_file).name
print(f"📄 Loading passages from {passage_file}...")
# Build compact index with same passages
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=meta["embedding_model"],
embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
is_recompute=True, # Enable recompute (no stored embeddings)
is_compact=True, # Enable compact storage
**meta.get("backend_kwargs", {}),
)
# Load all passages
with open(passage_file, encoding="utf-8") as f:
for line in f:
if line.strip():
data = json.loads(line)
builder.add_text(data["text"], metadata=data.get("metadata", {}))
print(f"🔨 Building compact index at {compact_index_path}...")
builder.build_index(compact_index_path)
# Analyze the compact index size
temp_evaluator = FinanceBenchEvaluator(compact_index_path)
compact_sizes = temp_evaluator.analyze_index_sizes()
compact_sizes["index_type"] = "compact"
return compact_sizes
def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
"""Create a non-compact index for comparison purposes"""
print("🏗️ Building non-compact index from existing passages...")
# Load existing passages from current index
from leann import LeannBuilder
current_index_path = Path(self.index_path)
current_index_dir = current_index_path.parent
current_index_name = current_index_path.name
# Read metadata to get passage source
meta_path = current_index_dir / f"{current_index_name}.meta.json"
with open(meta_path) as f:
import json
meta = json.load(f)
passage_source = meta["passage_sources"][0]
passage_file = passage_source["path"]
# Convert relative path to absolute
if not Path(passage_file).is_absolute():
passage_file = current_index_dir / Path(passage_file).name
print(f"📄 Loading passages from {passage_file}...")
# Build non-compact index with same passages
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=meta["embedding_model"],
embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
is_recompute=False, # Disable recompute (store embeddings)
is_compact=False, # Disable compact storage
**{
k: v
for k, v in meta.get("backend_kwargs", {}).items()
if k not in ["is_recompute", "is_compact"]
},
)
# Load all passages
with open(passage_file, encoding="utf-8") as f:
for line in f:
if line.strip():
data = json.loads(line)
builder.add_text(data["text"], metadata=data.get("metadata", {}))
print(f"🔨 Building non-compact index at {non_compact_index_path}...")
builder.build_index(non_compact_index_path)
# Analyze the non-compact index size
temp_evaluator = FinanceBenchEvaluator(non_compact_index_path)
non_compact_sizes = temp_evaluator.analyze_index_sizes()
non_compact_sizes["index_type"] = "non_compact"
return non_compact_sizes
def compare_index_performance(
self, non_compact_path: str, compact_path: str, test_data: list, complexity: int
) -> dict:
"""Compare performance between non-compact and compact indexes"""
print("⚡ Comparing search performance between indexes...")
import time
from leann import LeannSearcher
# Test queries
test_queries = [item["question"] for item in test_data[:5]]
results = {
"non_compact": {"search_times": []},
"compact": {"search_times": []},
"avg_search_times": {},
"speed_ratio": 0.0,
}
# Test non-compact index (no recompute)
print(" 🔍 Testing non-compact index (no recompute)...")
non_compact_searcher = LeannSearcher(non_compact_path)
for query in test_queries:
start_time = time.time()
_ = non_compact_searcher.search(
query, top_k=3, complexity=complexity, recompute_embeddings=False
)
search_time = time.time() - start_time
results["non_compact"]["search_times"].append(search_time)
# Test compact index (with recompute)
print(" 🔍 Testing compact index (with recompute)...")
compact_searcher = LeannSearcher(compact_path)
for query in test_queries:
start_time = time.time()
_ = compact_searcher.search(
query, top_k=3, complexity=complexity, recompute_embeddings=True
)
search_time = time.time() - start_time
results["compact"]["search_times"].append(search_time)
# Calculate averages
results["avg_search_times"]["non_compact"] = sum(
results["non_compact"]["search_times"]
) / len(results["non_compact"]["search_times"])
results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
results["compact"]["search_times"]
)
# Performance ratio
if results["avg_search_times"]["compact"] > 0:
results["speed_ratio"] = (
results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
)
else:
results["speed_ratio"] = float("inf")
print(
f" Non-compact (no recompute): {results['avg_search_times']['non_compact']:.3f}s avg"
)
print(f" Compact (with recompute): {results['avg_search_times']['compact']:.3f}s avg")
print(f" Speed ratio: {results['speed_ratio']:.2f}x")
# Cleanup
non_compact_searcher.cleanup()
compact_searcher.cleanup()
return results
def evaluate_timing_breakdown(
self, data: list[dict], max_samples: Optional[int] = None
) -> dict:
"""Evaluate timing breakdown and accuracy by hacking LeannChat.ask() for separated timing"""
if not self.chat or not self.openai_client:
print("⚠️ Skipping timing evaluation (no OpenAI API key provided)")
return {
"total_questions": 0,
"avg_search_time": 0.0,
"avg_generation_time": 0.0,
"avg_total_time": 0.0,
"accuracy": 0.0,
}
print("🔍🤖 Evaluating timing breakdown and accuracy (search + generation)...")
if max_samples:
data = data[:max_samples]
print(f"📝 Using first {max_samples} samples for timing evaluation")
search_times = []
generation_times = []
total_times = []
correct_answers = 0
for i, item in enumerate(data):
question = item["question"]
ground_truth = item["answer"]
try:
# Hack: Monkey-patch the ask method to capture internal timing
original_ask = self.chat.ask
captured_search_time = None
captured_generation_time = None
def patched_ask(*args, **kwargs):
nonlocal captured_search_time, captured_generation_time
# Time the search part
search_start = time.time()
results = self.chat.searcher.search(args[0], top_k=3, complexity=64)
captured_search_time = time.time() - search_start
# Time the generation part
context = "\n\n".join([r.text for r in results])
prompt = (
"Here is some retrieved context that might help answer your question:\n\n"
f"{context}\n\n"
f"Question: {args[0]}\n\n"
"Please provide the best answer you can based on this context and your knowledge."
)
generation_start = time.time()
answer = self.chat.llm.ask(prompt)
captured_generation_time = time.time() - generation_start
return answer
# Apply the patch
self.chat.ask = patched_ask
# Time the total QA
total_start = time.time()
generated_answer = self.chat.ask(question)
total_time = time.time() - total_start
# Restore original method
self.chat.ask = original_ask
# Store the timings
search_times.append(captured_search_time)
generation_times.append(captured_generation_time)
total_times.append(total_time)
# Check accuracy using LLM as judge
is_correct = self._check_answer_accuracy(generated_answer, ground_truth, question)
if is_correct:
correct_answers += 1
status = "" if is_correct else ""
print(
f"Question {i + 1}/{len(data)}: {status} Search={captured_search_time:.3f}s, Gen={captured_generation_time:.3f}s, Total={total_time:.3f}s"
)
print(f" GT: {ground_truth}")
print(f" Gen: {generated_answer[:100]}...")
except Exception as e:
print(f" ❌ Error: {e}")
search_times.append(0.0)
generation_times.append(0.0)
total_times.append(0.0)
accuracy = correct_answers / len(data) if data else 0.0
metrics = {
"total_questions": len(data),
"avg_search_time": sum(search_times) / len(search_times) if search_times else 0.0,
"avg_generation_time": sum(generation_times) / len(generation_times)
if generation_times
else 0.0,
"avg_total_time": sum(total_times) / len(total_times) if total_times else 0.0,
"accuracy": accuracy,
"correct_answers": correct_answers,
"search_times": search_times,
"generation_times": generation_times,
"total_times": total_times,
}
return metrics
def _check_answer_accuracy(
self, generated_answer: str, ground_truth: str, question: str
) -> bool:
"""Check if generated answer matches ground truth using LLM as judge"""
judge_prompt = f"""You are an expert judge evaluating financial question answering.
Question: {question}
Ground Truth Answer: {ground_truth}
Generated Answer: {generated_answer}
Task: Determine if the generated answer is factually correct compared to the ground truth. Focus on:
1. Numerical accuracy (exact values, units, currency)
2. Key financial concepts and terminology
3. Overall factual correctness
For financial data, small formatting differences are OK (e.g., "$1,577" vs "1577 million" vs "$1.577 billion"), but the core numerical value must match.
Respond with exactly one word: "CORRECT" if the generated answer is factually accurate, or "INCORRECT" if it's wrong or significantly different."""
try:
judge_response = self.openai_client.chat.completions.create(
model="gpt-4o-mini",
messages=[{"role": "user", "content": judge_prompt}],
max_tokens=10,
temperature=0,
)
judgment = judge_response.choices[0].message.content.strip().upper()
return judgment == "CORRECT"
except Exception as e:
print(f" ⚠️ Judge error: {e}, falling back to string matching")
# Fallback to simple string matching
gen_clean = generated_answer.strip().lower().replace("$", "").replace(",", "")
gt_clean = ground_truth.strip().lower().replace("$", "").replace(",", "")
return gt_clean in gen_clean
def _print_results(self, timing_metrics: dict):
"""Print evaluation results"""
print("\n🎯 EVALUATION RESULTS")
print("=" * 50)
# Index comparison analysis
if "current_index" in timing_metrics and "non_compact_index" in timing_metrics:
print("\n📏 Index Comparison Analysis:")
current = timing_metrics["current_index"]
non_compact = timing_metrics["non_compact_index"]
print(f" Compact index (current): {current.get('total_with_embeddings', 0):.1f} MB")
print(
f" Non-compact index (with embeddings): {non_compact.get('total_with_embeddings', 0):.1f} MB"
)
print(
f" Storage saving by compact: {timing_metrics.get('storage_saving_percent', 0):.1f}%"
)
print(" Component breakdown (non-compact):")
print(f" - Main index: {non_compact.get('index', 0):.1f} MB")
print(f" - Passages text: {non_compact.get('passages_text', 0):.1f} MB")
print(f" - Passages index: {non_compact.get('passages_index', 0):.1f} MB")
print(f" - Metadata: {non_compact.get('metadata', 0):.1f} MB")
# Performance comparison
if "performance_comparison" in timing_metrics:
perf = timing_metrics["performance_comparison"]
print("\n⚡ Performance Comparison:")
print(
f" Non-compact (no recompute): {perf.get('avg_search_times', {}).get('non_compact', 0):.3f}s avg"
)
print(
f" Compact (with recompute): {perf.get('avg_search_times', {}).get('compact', 0):.3f}s avg"
)
print(f" Speed ratio: {perf.get('speed_ratio', 0):.2f}x")
# Legacy single index analysis (fallback)
if "total_with_embeddings" in timing_metrics and "current_index" not in timing_metrics:
print("\n📏 Index Size Analysis:")
print(f" Total index size: {timing_metrics.get('total_with_embeddings', 0):.1f} MB")
print("\n📊 Accuracy:")
print(f" Accuracy: {timing_metrics.get('accuracy', 0) * 100:.1f}%")
print(
f" Correct Answers: {timing_metrics.get('correct_answers', 0)}/{timing_metrics.get('total_questions', 0)}"
)
print("\n📊 Timing Breakdown:")
print(f" Total Questions: {timing_metrics.get('total_questions', 0)}")
print(f" Avg Search Time: {timing_metrics.get('avg_search_time', 0):.3f}s")
print(f" Avg Generation Time: {timing_metrics.get('avg_generation_time', 0):.3f}s")
print(f" Avg Total Time: {timing_metrics.get('avg_total_time', 0):.3f}s")
if timing_metrics.get("avg_total_time", 0) > 0:
search_pct = (
timing_metrics.get("avg_search_time", 0)
/ timing_metrics.get("avg_total_time", 1)
* 100
)
gen_pct = (
timing_metrics.get("avg_generation_time", 0)
/ timing_metrics.get("avg_total_time", 1)
* 100
)
print("\n📈 Time Distribution:")
print(f" Search: {search_pct:.1f}%")
print(f" Generation: {gen_pct:.1f}%")
def cleanup(self):
"""Cleanup resources"""
if self.searcher:
self.searcher.cleanup()
def main():
parser = argparse.ArgumentParser(description="Modular FinanceBench Evaluation")
parser.add_argument("--index", required=True, help="Path to LEANN index")
parser.add_argument("--dataset", default="data/financebench_merged.jsonl", help="Dataset path")
parser.add_argument(
"--stage",
choices=["2", "3", "4", "all"],
default="all",
help="Which stage to run (2=recall, 3=complexity, 4=generation)",
)
parser.add_argument("--complexity", type=int, default=None, help="Complexity for search")
parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
parser.add_argument("--openai-api-key", help="OpenAI API key for generation evaluation")
parser.add_argument("--output", help="Save results to JSON file")
parser.add_argument(
"--llm-backend", choices=["openai", "hf", "vllm"], default="openai", help="LLM backend"
)
parser.add_argument("--model-name", default="Qwen3-8B", help="Model name for HF/vLLM")
args = parser.parse_args()
try:
# Check if baseline exists
baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
if not os.path.exists(baseline_index_path):
print(f"❌ FAISS baseline not found at {baseline_index_path}")
print("💡 Please run setup_financebench.py first to build the baseline")
exit(1)
if args.stage == "2" or args.stage == "all":
# Stage 2: Recall@3 evaluation
print("🚀 Starting Stage 2: Recall@3 evaluation")
evaluator = RecallEvaluator(args.index, args.baseline_dir)
# Load FinanceBench queries for testing
print("📖 Loading FinanceBench dataset...")
queries = []
with open(args.dataset, encoding="utf-8") as f:
for line in f:
if line.strip():
data = json.loads(line)
queries.append(data["question"])
# Test with more queries for robust measurement
test_queries = queries[:2000]
print(f"🧪 Testing with {len(test_queries)} queries")
# Test with complexity 64
complexity = 64
recall = evaluator.evaluate_recall_at_3(test_queries, complexity)
print(f"📈 Recall@3 at complexity {complexity}: {recall * 100:.1f}%")
evaluator.cleanup()
print("✅ Stage 2 completed!\n")
# Shared non-compact index path for Stage 3 and 4
non_compact_index_path = args.index.replace(".leann", "_noncompact.leann")
complexity = args.complexity
if args.stage == "3" or args.stage == "all":
# Stage 3: Binary search for 90% recall complexity (using non-compact index for speed)
print("🚀 Starting Stage 3: Binary search for 90% recall complexity")
print(
"💡 Creating non-compact index for fast binary search with recompute_embeddings=False"
)
# Create non-compact index for binary search (will be reused in Stage 4)
print("🏗️ Creating non-compact index for binary search...")
evaluator = FinanceBenchEvaluator(args.index)
evaluator.create_non_compact_index_for_comparison(non_compact_index_path)
# Use non-compact index for binary search
binary_search_evaluator = RecallEvaluator(non_compact_index_path, args.baseline_dir)
# Load queries for testing
print("📖 Loading FinanceBench dataset...")
queries = []
with open(args.dataset, encoding="utf-8") as f:
for line in f:
if line.strip():
data = json.loads(line)
queries.append(data["question"])
# Use more queries for robust measurement
test_queries = queries[:200]
print(f"🧪 Testing with {len(test_queries)} queries")
# Binary search for 90% recall complexity (without recompute for speed)
target_recall = 0.9
min_complexity, max_complexity = 1, 32
print(f"🔍 Binary search for {target_recall * 100}% recall complexity...")
print(f"Search range: {min_complexity} to {max_complexity}")
best_complexity = None
best_recall = 0.0
while min_complexity <= max_complexity:
mid_complexity = (min_complexity + max_complexity) // 2
print(
f"\n🧪 Testing complexity {mid_complexity} (no recompute, non-compact index)..."
)
# Use recompute_embeddings=False on non-compact index for fast binary search
recall = binary_search_evaluator.evaluate_recall_at_3(
test_queries, mid_complexity, recompute_embeddings=False
)
print(
f" Complexity {mid_complexity}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%)"
)
if recall >= target_recall:
best_complexity = mid_complexity
best_recall = recall
max_complexity = mid_complexity - 1
print(" ✅ Target reached! Searching for lower complexity...")
else:
min_complexity = mid_complexity + 1
print(" ❌ Below target. Searching for higher complexity...")
if best_complexity is not None:
print("\n🎯 Optimal complexity found!")
print(f" Complexity: {best_complexity}")
print(f" Recall@3: {best_recall:.3f} ({best_recall * 100:.1f}%)")
# Test a few complexities around the optimal one for verification
print("\n🔬 Verification test around optimal complexity:")
verification_complexities = [
max(1, best_complexity - 2),
max(1, best_complexity - 1),
best_complexity,
best_complexity + 1,
best_complexity + 2,
]
for complexity in verification_complexities:
if complexity <= 512: # reasonable upper bound
recall = binary_search_evaluator.evaluate_recall_at_3(
test_queries, complexity, recompute_embeddings=False
)
status = "" if recall >= target_recall else ""
print(f" {status} Complexity {complexity:3d}: {recall * 100:5.1f}%")
# Now test the optimal complexity with compact index and recompute for comparison
print(
f"\n🔄 Testing optimal complexity {best_complexity} on compact index WITH recompute..."
)
compact_evaluator = RecallEvaluator(args.index, args.baseline_dir)
recall_with_recompute = compact_evaluator.evaluate_recall_at_3(
test_queries[:10], best_complexity, recompute_embeddings=True
)
print(
f" ✅ Complexity {best_complexity} (compact index with recompute): {recall_with_recompute * 100:.1f}%"
)
complexity = best_complexity
print(
f" 📊 Recall difference: {abs(best_recall - recall_with_recompute) * 100:.2f}%"
)
compact_evaluator.cleanup()
else:
print(f"\n❌ Could not find complexity achieving {target_recall * 100}% recall")
print("All tested complexities were below target.")
# Cleanup evaluators (keep non-compact index for Stage 4)
binary_search_evaluator.cleanup()
evaluator.cleanup()
print("✅ Stage 3 completed! Non-compact index saved for Stage 4.\n")
if args.stage == "4" or args.stage == "all":
# Stage 4: Comprehensive evaluation with dual index comparison
print("🚀 Starting Stage 4: Comprehensive evaluation with dual index comparison")
# Use FinanceBench evaluator for QA evaluation
evaluator = FinanceBenchEvaluator(
args.index, args.openai_api_key if args.llm_backend == "openai" else None
)
print("📖 Loading FinanceBench dataset...")
data = evaluator.load_dataset(args.dataset)
# Step 1: Analyze current (compact) index
print("\n📏 Analyzing current index (compact, pruned)...")
compact_size_metrics = evaluator.analyze_index_sizes()
compact_size_metrics["index_type"] = "compact"
# Step 2: Use existing non-compact index or create if needed
from pathlib import Path
if Path(non_compact_index_path).exists():
print(
f"\n📁 Using existing non-compact index from Stage 3: {non_compact_index_path}"
)
temp_evaluator = FinanceBenchEvaluator(non_compact_index_path)
non_compact_size_metrics = temp_evaluator.analyze_index_sizes()
non_compact_size_metrics["index_type"] = "non_compact"
else:
print("\n🏗️ Creating non-compact index (with embeddings) for comparison...")
non_compact_size_metrics = evaluator.create_non_compact_index_for_comparison(
non_compact_index_path
)
# Step 3: Compare index sizes
print("\n📊 Index size comparison:")
print(
f" Compact index (current): {compact_size_metrics['total_with_embeddings']:.1f} MB"
)
print(
f" Non-compact index: {non_compact_size_metrics['total_with_embeddings']:.1f} MB"
)
print("\n📊 Index-only size comparison (.index file only):")
print(f" Compact index: {compact_size_metrics['index_only_mb']:.1f} MB")
print(f" Non-compact index: {non_compact_size_metrics['index_only_mb']:.1f} MB")
# Use index-only size for fair comparison (same as Enron emails)
storage_saving = (
(non_compact_size_metrics["index_only_mb"] - compact_size_metrics["index_only_mb"])
/ non_compact_size_metrics["index_only_mb"]
* 100
)
print(f" Storage saving by compact: {storage_saving:.1f}%")
# Step 4: Performance comparison between the two indexes
if complexity is None:
raise ValueError("Complexity is required for performance comparison")
print("\n⚡ Performance comparison between indexes...")
performance_metrics = evaluator.compare_index_performance(
non_compact_index_path, args.index, data[:10], complexity=complexity
)
# Step 5: Generation evaluation
test_samples = 20
print(f"\n🧪 Testing with first {test_samples} samples for generation analysis")
if args.llm_backend == "openai" and args.openai_api_key:
print("🔍🤖 Running OpenAI-based generation evaluation...")
evaluation_start = time.time()
timing_metrics = evaluator.evaluate_timing_breakdown(data[:test_samples])
evaluation_time = time.time() - evaluation_start
else:
print(
f"🔍🤖 Running {args.llm_backend} generation evaluation with {args.model_name}..."
)
try:
# Load LLM
if args.llm_backend == "hf":
tokenizer, model = load_hf_model(args.model_name)
def llm_func(prompt):
return generate_hf(tokenizer, model, prompt)
else: # vllm
llm, sampling_params = load_vllm_model(args.model_name)
def llm_func(prompt):
return generate_vllm(llm, sampling_params, prompt)
# Simple generation evaluation
queries = [item["question"] for item in data[:test_samples]]
gen_results = evaluate_rag(
evaluator.searcher,
llm_func,
queries,
domain="finance",
complexity=complexity,
)
timing_metrics = {
"total_questions": len(queries),
"avg_search_time": gen_results["avg_search_time"],
"avg_generation_time": gen_results["avg_generation_time"],
"results": gen_results["results"],
}
evaluation_time = time.time()
except Exception as e:
print(f"❌ Generation evaluation failed: {e}")
timing_metrics = {
"total_questions": 0,
"avg_search_time": 0,
"avg_generation_time": 0,
}
evaluation_time = 0
# Combine all metrics
combined_metrics = {
**timing_metrics,
"total_evaluation_time": evaluation_time,
"current_index": compact_size_metrics,
"non_compact_index": non_compact_size_metrics,
"performance_comparison": performance_metrics,
"storage_saving_percent": storage_saving,
}
# Print results
print("\n📊 Generation Results:")
print(f" Total Questions: {timing_metrics.get('total_questions', 0)}")
print(f" Avg Search Time: {timing_metrics.get('avg_search_time', 0):.3f}s")
print(f" Avg Generation Time: {timing_metrics.get('avg_generation_time', 0):.3f}s")
# Save results if requested
if args.output:
print(f"\n💾 Saving results to {args.output}...")
with open(args.output, "w") as f:
json.dump(combined_metrics, f, indent=2, default=str)
print(f"✅ Results saved to {args.output}")
evaluator.cleanup()
print("✅ Stage 4 completed!\n")
if args.stage == "all":
print("🎉 All evaluation stages completed successfully!")
print("\n📋 Summary:")
print(" Stage 2: ✅ Recall@3 evaluation completed")
print(" Stage 3: ✅ Optimal complexity found")
print(" Stage 4: ✅ Generation accuracy & timing evaluation completed")
print("\n🔧 Recommended next steps:")
print(" - Use optimal complexity for best speed/accuracy balance")
print(" - Review accuracy and timing breakdown for performance optimization")
print(" - Run full evaluation on complete dataset if needed")
# Clean up non-compact index after all stages complete
print("\n🧹 Cleaning up temporary non-compact index...")
from pathlib import Path
if Path(non_compact_index_path).exists():
temp_index_dir = Path(non_compact_index_path).parent
temp_index_name = Path(non_compact_index_path).name
for temp_file in temp_index_dir.glob(f"{temp_index_name}*"):
temp_file.unlink()
print(f"✅ Cleaned up {non_compact_index_path}")
else:
print("📝 No temporary index to clean up")
except KeyboardInterrupt:
print("\n⚠️ Evaluation interrupted by user")
exit(1)
except Exception as e:
print(f"\n❌ Stage {args.stage} failed: {e}")
exit(1)
if __name__ == "__main__":
main()

View File

@@ -1,462 +0,0 @@
#!/usr/bin/env python3
"""
FinanceBench Complete Setup Script
Downloads all PDFs and builds full LEANN datastore
"""
import argparse
import os
import re
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from threading import Lock
import pymupdf
import requests
from leann import LeannBuilder, LeannSearcher
from tqdm import tqdm
class FinanceBenchSetup:
def __init__(self, data_dir: str = "data"):
self.base_dir = Path(__file__).parent # benchmarks/financebench/
self.data_dir = self.base_dir / data_dir
self.pdf_dir = self.data_dir / "pdfs"
self.dataset_file = self.data_dir / "financebench_merged.jsonl"
self.index_dir = self.data_dir / "index"
self.download_lock = Lock()
def download_dataset(self):
"""Download the main FinanceBench dataset"""
print("📊 Downloading FinanceBench dataset...")
self.data_dir.mkdir(parents=True, exist_ok=True)
if self.dataset_file.exists():
print(f"✅ Dataset already exists: {self.dataset_file}")
return
url = "https://huggingface.co/datasets/PatronusAI/financebench/raw/main/financebench_merged.jsonl"
response = requests.get(url, stream=True)
response.raise_for_status()
with open(self.dataset_file, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
print(f"✅ Dataset downloaded: {self.dataset_file}")
def get_pdf_list(self):
"""Get list of all PDF files from GitHub"""
print("📋 Fetching PDF list from GitHub...")
response = requests.get(
"https://api.github.com/repos/patronus-ai/financebench/contents/pdfs"
)
response.raise_for_status()
pdf_files = response.json()
print(f"Found {len(pdf_files)} PDF files")
return pdf_files
def download_single_pdf(self, pdf_info, position):
"""Download a single PDF file"""
pdf_name = pdf_info["name"]
pdf_path = self.pdf_dir / pdf_name
# Skip if already downloaded
if pdf_path.exists() and pdf_path.stat().st_size > 0:
return f"{pdf_name} (cached)"
try:
# Download PDF
response = requests.get(pdf_info["download_url"], timeout=60)
response.raise_for_status()
# Write to file
with self.download_lock:
with open(pdf_path, "wb") as f:
f.write(response.content)
return f"{pdf_name} ({len(response.content) // 1024}KB)"
except Exception as e:
return f"{pdf_name}: {e!s}"
def download_all_pdfs(self, max_workers: int = 5):
"""Download all PDF files with parallel processing"""
self.pdf_dir.mkdir(parents=True, exist_ok=True)
pdf_files = self.get_pdf_list()
print(f"📥 Downloading {len(pdf_files)} PDFs with {max_workers} workers...")
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit all download tasks
future_to_pdf = {
executor.submit(self.download_single_pdf, pdf_info, i): pdf_info["name"]
for i, pdf_info in enumerate(pdf_files)
}
# Process completed downloads with progress bar
with tqdm(total=len(pdf_files), desc="Downloading PDFs") as pbar:
for future in as_completed(future_to_pdf):
result = future.result()
pbar.set_postfix_str(result.split()[-1] if "" in result else "Error")
pbar.update(1)
# Verify downloads
downloaded_pdfs = list(self.pdf_dir.glob("*.pdf"))
print(f"✅ Successfully downloaded {len(downloaded_pdfs)}/{len(pdf_files)} PDFs")
# Show any failures
missing_pdfs = []
for pdf_info in pdf_files:
pdf_path = self.pdf_dir / pdf_info["name"]
if not pdf_path.exists() or pdf_path.stat().st_size == 0:
missing_pdfs.append(pdf_info["name"])
if missing_pdfs:
print(f"⚠️ Failed to download {len(missing_pdfs)} PDFs:")
for pdf in missing_pdfs[:5]: # Show first 5
print(f" - {pdf}")
if len(missing_pdfs) > 5:
print(f" ... and {len(missing_pdfs) - 5} more")
def build_leann_index(
self,
backend: str = "hnsw",
embedding_model: str = "sentence-transformers/all-mpnet-base-v2",
):
"""Build LEANN index from all PDFs"""
print(f"🏗️ Building LEANN index with {backend} backend...")
# Check if we have PDFs
pdf_files = list(self.pdf_dir.glob("*.pdf"))
if not pdf_files:
raise RuntimeError("No PDF files found! Run download first.")
print(f"Found {len(pdf_files)} PDF files to process")
start_time = time.time()
# Initialize builder with standard compact configuration
builder = LeannBuilder(
backend_name=backend,
embedding_model=embedding_model,
embedding_mode="sentence-transformers",
graph_degree=32,
complexity=64,
is_recompute=True, # Enable recompute (no stored embeddings)
is_compact=True, # Enable compact storage (pruned)
num_threads=4,
)
# Process PDFs and extract text
total_chunks = 0
failed_pdfs = []
for pdf_path in tqdm(pdf_files, desc="Processing PDFs"):
try:
chunks = self.extract_pdf_text(pdf_path)
for chunk in chunks:
builder.add_text(chunk["text"], metadata=chunk["metadata"])
total_chunks += 1
except Exception as e:
print(f"❌ Failed to process {pdf_path.name}: {e}")
failed_pdfs.append(pdf_path.name)
continue
# Build index in index directory
self.index_dir.mkdir(parents=True, exist_ok=True)
index_path = self.index_dir / f"financebench_full_{backend}.leann"
print(f"🔨 Building index: {index_path}")
builder.build_index(str(index_path))
build_time = time.time() - start_time
print("✅ Index built successfully!")
print(f" 📁 Index path: {index_path}")
print(f" 📊 Total chunks: {total_chunks:,}")
print(f" 📄 Processed PDFs: {len(pdf_files) - len(failed_pdfs)}/{len(pdf_files)}")
print(f" ⏱️ Build time: {build_time:.1f}s")
if failed_pdfs:
print(f" ⚠️ Failed PDFs: {failed_pdfs}")
return str(index_path)
def build_faiss_flat_baseline(self, index_path: str, output_dir: str = "baseline"):
"""Build FAISS flat baseline using the same embeddings as LEANN index"""
print("🔨 Building FAISS Flat baseline...")
import os
import pickle
import numpy as np
from leann.api import compute_embeddings
from leann_backend_hnsw import faiss
os.makedirs(output_dir, exist_ok=True)
baseline_path = os.path.join(output_dir, "faiss_flat.index")
metadata_path = os.path.join(output_dir, "metadata.pkl")
if os.path.exists(baseline_path) and os.path.exists(metadata_path):
print(f"✅ Baseline already exists at {baseline_path}")
return baseline_path
# Read metadata from the built index
meta_path = f"{index_path}.meta.json"
with open(meta_path) as f:
import json
meta = json.loads(f.read())
embedding_model = meta["embedding_model"]
passage_source = meta["passage_sources"][0]
passage_file = passage_source["path"]
# Convert relative path to absolute
if not os.path.isabs(passage_file):
index_dir = os.path.dirname(index_path)
passage_file = os.path.join(index_dir, os.path.basename(passage_file))
print(f"📊 Loading passages from {passage_file}...")
print(f"🤖 Using embedding model: {embedding_model}")
# Load all passages for baseline
passages = []
passage_ids = []
with open(passage_file, encoding="utf-8") as f:
for line in f:
if line.strip():
data = json.loads(line)
passages.append(data["text"])
passage_ids.append(data["id"])
print(f"📄 Loaded {len(passages)} passages")
# Compute embeddings using the same method as LEANN
print("🧮 Computing embeddings...")
embeddings = compute_embeddings(
passages,
embedding_model,
mode="sentence-transformers",
use_server=False,
)
print(f"📐 Embedding shape: {embeddings.shape}")
# Build FAISS flat index
print("🏗️ Building FAISS IndexFlatIP...")
dimension = embeddings.shape[1]
index = faiss.IndexFlatIP(dimension)
# Add embeddings to flat index
embeddings_f32 = embeddings.astype(np.float32)
index.add(embeddings_f32.shape[0], faiss.swig_ptr(embeddings_f32))
# Save index and metadata
faiss.write_index(index, baseline_path)
with open(metadata_path, "wb") as f:
pickle.dump(passage_ids, f)
print(f"✅ FAISS baseline saved to {baseline_path}")
print(f"✅ Metadata saved to {metadata_path}")
print(f"📊 Total vectors: {index.ntotal}")
return baseline_path
def extract_pdf_text(self, pdf_path: Path) -> list[dict]:
"""Extract and chunk text from a PDF file"""
chunks = []
doc = pymupdf.open(pdf_path)
for page_num in range(len(doc)):
page = doc[page_num]
text = page.get_text() # type: ignore
if not text.strip():
continue
# Create metadata
metadata = {
"source_file": pdf_path.name,
"page_number": page_num + 1,
"document_type": "10K" if "10K" in pdf_path.name else "10Q",
"company": pdf_path.name.split("_")[0],
"doc_period": self.extract_year_from_filename(pdf_path.name),
}
# Use recursive character splitting like LangChain
if len(text.split()) > 500:
# Split by double newlines (paragraphs)
paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
current_chunk = ""
for para in paragraphs:
# If adding this paragraph would make chunk too long, save current chunk
if current_chunk and len((current_chunk + " " + para).split()) > 300:
if current_chunk.strip():
chunks.append(
{
"text": current_chunk.strip(),
"metadata": {
**metadata,
"chunk_id": f"page_{page_num + 1}_chunk_{len(chunks)}",
},
}
)
current_chunk = para
else:
current_chunk = (current_chunk + " " + para).strip()
# Add the last chunk
if current_chunk.strip():
chunks.append(
{
"text": current_chunk.strip(),
"metadata": {
**metadata,
"chunk_id": f"page_{page_num + 1}_chunk_{len(chunks)}",
},
}
)
else:
# Page is short enough, use as single chunk
chunks.append(
{
"text": text.strip(),
"metadata": {**metadata, "chunk_id": f"page_{page_num + 1}"},
}
)
doc.close()
return chunks
def extract_year_from_filename(self, filename: str) -> str:
"""Extract year from PDF filename"""
# Try to find 4-digit year in filename
match = re.search(r"(\d{4})", filename)
return match.group(1) if match else "unknown"
def verify_setup(self, index_path: str):
"""Verify the setup by testing a simple query"""
print("🧪 Verifying setup with test query...")
try:
searcher = LeannSearcher(index_path)
# Test query
test_query = "What is the capital expenditure for 3M in 2018?"
results = searcher.search(test_query, top_k=3)
print(f"✅ Test query successful! Found {len(results)} results:")
for i, result in enumerate(results, 1):
company = result.metadata.get("company", "Unknown")
year = result.metadata.get("doc_period", "Unknown")
page = result.metadata.get("page_number", "Unknown")
print(f" {i}. {company} {year} (page {page}) - Score: {result.score:.3f}")
print(f" {result.text[:100]}...")
searcher.cleanup()
print("✅ Setup verification completed successfully!")
except Exception as e:
print(f"❌ Setup verification failed: {e}")
raise
def main():
parser = argparse.ArgumentParser(description="Setup FinanceBench with full PDF datastore")
parser.add_argument("--data-dir", default="data", help="Data directory")
parser.add_argument(
"--backend", choices=["hnsw", "diskann"], default="hnsw", help="LEANN backend"
)
parser.add_argument(
"--embedding-model",
default="sentence-transformers/all-mpnet-base-v2",
help="Embedding model",
)
parser.add_argument("--max-workers", type=int, default=5, help="Parallel download workers")
parser.add_argument("--skip-download", action="store_true", help="Skip PDF download")
parser.add_argument("--skip-build", action="store_true", help="Skip index building")
parser.add_argument(
"--build-baseline-only",
action="store_true",
help="Only build FAISS baseline from existing index",
)
args = parser.parse_args()
print("🏦 FinanceBench Complete Setup")
print("=" * 50)
setup = FinanceBenchSetup(args.data_dir)
try:
if args.build_baseline_only:
# Only build baseline from existing index
index_path = setup.index_dir / f"financebench_full_{args.backend}"
index_file = f"{index_path}.index"
meta_file = f"{index_path}.leann.meta.json"
if not os.path.exists(index_file) or not os.path.exists(meta_file):
print("❌ Index files not found:")
print(f" Index: {index_file}")
print(f" Meta: {meta_file}")
print("💡 Run without --build-baseline-only to build the index first")
exit(1)
print(f"🔨 Building baseline from existing index: {index_path}")
baseline_path = setup.build_faiss_flat_baseline(str(index_path))
print(f"✅ Baseline built at {baseline_path}")
return
# Step 1: Download dataset
setup.download_dataset()
# Step 2: Download PDFs
if not args.skip_download:
setup.download_all_pdfs(max_workers=args.max_workers)
else:
print("⏭️ Skipping PDF download")
# Step 3: Build LEANN index
if not args.skip_build:
index_path = setup.build_leann_index(
backend=args.backend, embedding_model=args.embedding_model
)
# Step 4: Build FAISS flat baseline
print("\n🔨 Building FAISS flat baseline...")
baseline_path = setup.build_faiss_flat_baseline(index_path)
print(f"✅ Baseline built at {baseline_path}")
# Step 5: Verify setup
setup.verify_setup(index_path)
else:
print("⏭️ Skipping index building")
print("\n🎉 FinanceBench setup completed!")
print(f"📁 Data directory: {setup.data_dir.absolute()}")
print("\nNext steps:")
print(
"1. Run evaluation: python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann"
)
print(
"2. Or test manually: python -c \"from leann import LeannSearcher; s = LeannSearcher('data/index/financebench_full_hnsw.leann'); print(s.search('3M capital expenditure 2018'))\""
)
except KeyboardInterrupt:
print("\n⚠️ Setup interrupted by user")
exit(1)
except Exception as e:
print(f"\n❌ Setup failed: {e}")
exit(1)
if __name__ == "__main__":
main()

View File

@@ -1,214 +0,0 @@
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.9"
# dependencies = [
# "faiss-cpu",
# "numpy",
# "sentence-transformers",
# "torch",
# "tqdm",
# ]
# ///
"""
Independent recall verification script using standard FAISS.
Creates two indexes (HNSW and Flat) and compares recall@3 at different complexities.
"""
import json
import time
from pathlib import Path
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
def compute_embeddings_direct(chunks: list[str], model_name: str) -> np.ndarray:
"""
Direct embedding computation using sentence-transformers.
Copied logic to avoid dependency issues.
"""
print(f"Loading model: {model_name}")
model = SentenceTransformer(model_name)
print(f"Computing embeddings for {len(chunks)} chunks...")
embeddings = model.encode(
chunks,
show_progress_bar=True,
batch_size=32,
convert_to_numpy=True,
normalize_embeddings=False,
)
return embeddings.astype(np.float32)
def load_financebench_queries(dataset_path: str, max_queries: int = 200) -> list[str]:
"""Load FinanceBench queries from dataset"""
queries = []
with open(dataset_path, encoding="utf-8") as f:
for line in f:
if line.strip():
data = json.loads(line)
queries.append(data["question"])
if len(queries) >= max_queries:
break
return queries
def load_passages_from_leann_index(index_path: str) -> tuple[list[str], list[str]]:
"""Load passages from LEANN index structure"""
meta_path = f"{index_path}.meta.json"
with open(meta_path) as f:
meta = json.load(f)
passage_source = meta["passage_sources"][0]
passage_file = passage_source["path"]
# Convert relative path to absolute
if not Path(passage_file).is_absolute():
index_dir = Path(index_path).parent
passage_file = index_dir / Path(passage_file).name
print(f"Loading passages from {passage_file}")
passages = []
passage_ids = []
with open(passage_file, encoding="utf-8") as f:
for line in tqdm(f, desc="Loading passages"):
if line.strip():
data = json.loads(line)
passages.append(data["text"])
passage_ids.append(data["id"])
print(f"Loaded {len(passages)} passages")
return passages, passage_ids
def build_faiss_indexes(embeddings: np.ndarray) -> tuple[faiss.Index, faiss.Index]:
"""Build FAISS indexes: Flat (ground truth) and HNSW"""
dimension = embeddings.shape[1]
# Build Flat index (ground truth)
print("Building FAISS IndexFlatIP (ground truth)...")
flat_index = faiss.IndexFlatIP(dimension)
flat_index.add(embeddings)
# Build HNSW index
print("Building FAISS IndexHNSWFlat...")
M = 32 # Same as LEANN default
hnsw_index = faiss.IndexHNSWFlat(dimension, M, faiss.METRIC_INNER_PRODUCT)
hnsw_index.hnsw.efConstruction = 200 # Same as LEANN default
hnsw_index.add(embeddings)
print(f"Built indexes with {flat_index.ntotal} vectors, dimension {dimension}")
return flat_index, hnsw_index
def evaluate_recall_at_k(
query_embeddings: np.ndarray,
flat_index: faiss.Index,
hnsw_index: faiss.Index,
passage_ids: list[str],
k: int = 3,
ef_search: int = 64,
) -> float:
"""Evaluate recall@k comparing HNSW vs Flat"""
# Set search parameters for HNSW
hnsw_index.hnsw.efSearch = ef_search
total_recall = 0.0
num_queries = query_embeddings.shape[0]
for i in range(num_queries):
query = query_embeddings[i : i + 1] # Keep 2D shape
# Get ground truth from Flat index (standard FAISS API)
flat_distances, flat_indices = flat_index.search(query, k)
ground_truth_ids = {passage_ids[idx] for idx in flat_indices[0]}
# Get results from HNSW index (standard FAISS API)
hnsw_distances, hnsw_indices = hnsw_index.search(query, k)
hnsw_ids = {passage_ids[idx] for idx in hnsw_indices[0]}
# Calculate recall
intersection = ground_truth_ids.intersection(hnsw_ids)
recall = len(intersection) / k
total_recall += recall
if i < 3: # Show first few examples
print(f" Query {i + 1}: Recall@{k} = {recall:.3f}")
print(f" Flat: {list(ground_truth_ids)}")
print(f" HNSW: {list(hnsw_ids)}")
print(f" Intersection: {list(intersection)}")
avg_recall = total_recall / num_queries
return avg_recall
def main():
# Configuration
dataset_path = "data/financebench_merged.jsonl"
index_path = "data/index/financebench_full_hnsw.leann"
embedding_model = "sentence-transformers/all-mpnet-base-v2"
print("🔍 FAISS Recall Verification")
print("=" * 50)
# Check if files exist
if not Path(dataset_path).exists():
print(f"❌ Dataset not found: {dataset_path}")
return
if not Path(f"{index_path}.meta.json").exists():
print(f"❌ Index metadata not found: {index_path}.meta.json")
return
# Load data
print("📖 Loading FinanceBench queries...")
queries = load_financebench_queries(dataset_path, max_queries=50)
print(f"Loaded {len(queries)} queries")
print("📄 Loading passages from LEANN index...")
passages, passage_ids = load_passages_from_leann_index(index_path)
# Compute embeddings
print("🧮 Computing passage embeddings...")
passage_embeddings = compute_embeddings_direct(passages, embedding_model)
print("🧮 Computing query embeddings...")
query_embeddings = compute_embeddings_direct(queries, embedding_model)
# Build FAISS indexes
print("🏗️ Building FAISS indexes...")
flat_index, hnsw_index = build_faiss_indexes(passage_embeddings)
# Test different efSearch values (equivalent to LEANN complexity)
print("\n📊 Evaluating Recall@3 at different efSearch values...")
ef_search_values = [16, 32, 64, 128, 256]
for ef_search in ef_search_values:
print(f"\n🧪 Testing efSearch = {ef_search}")
start_time = time.time()
recall = evaluate_recall_at_k(
query_embeddings, flat_index, hnsw_index, passage_ids, k=3, ef_search=ef_search
)
elapsed = time.time() - start_time
print(
f"📈 efSearch {ef_search}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%) in {elapsed:.2f}s"
)
print("\n✅ Verification completed!")
print("\n📋 Summary:")
print(" - Built independent FAISS Flat and HNSW indexes")
print(" - Compared recall@3 at different efSearch values")
print(" - Used same embedding model as LEANN")
print(" - This validates LEANN's recall measurements")
if __name__ == "__main__":
main()

View File

@@ -1 +0,0 @@
data/

View File

@@ -1,199 +0,0 @@
# LAION Multimodal Benchmark
A multimodal benchmark for evaluating image retrieval and generation performance using LEANN with CLIP embeddings and Qwen2.5-VL for multimodal generation on LAION dataset subset.
## Overview
This benchmark evaluates:
- **Image retrieval timing** using caption-based queries
- **Recall@K performance** for image search
- **Complexity analysis** across different search parameters
- **Index size and storage efficiency**
- **Multimodal generation** with Qwen2.5-VL for image understanding and description
## Dataset Configuration
- **Dataset**: LAION-400M subset (10,000 images)
- **Embeddings**: Pre-computed CLIP ViT-B/32 (512 dimensions)
- **Queries**: 200 random captions from the dataset
- **Ground Truth**: Self-recall (query caption → original image)
## Quick Start
### 1. Setup the benchmark
```bash
cd benchmarks/laion
python setup_laion.py --num-samples 10000 --num-queries 200
```
This will:
- Create dummy LAION data (10K samples)
- Generate CLIP embeddings (512-dim)
- Build LEANN index with HNSW backend
- Create 200 evaluation queries
### 2. Run evaluation
```bash
# Run all evaluation stages
python evaluate_laion.py --index data/laion_index.leann
# Run specific stages
python evaluate_laion.py --index data/laion_index.leann --stage 2 # Recall evaluation
python evaluate_laion.py --index data/laion_index.leann --stage 3 # Complexity analysis
python evaluate_laion.py --index data/laion_index.leann --stage 4 # Index comparison
python evaluate_laion.py --index data/laion_index.leann --stage 5 # Multimodal generation
# Multimodal generation with Qwen2.5-VL
python evaluate_laion.py --index data/laion_index.leann --stage 5 --model-name Qwen/Qwen2.5-VL-7B-Instruct
```
### 3. Save results
```bash
python evaluate_laion.py --index data/laion_index.leann --output results.json
```
## Configuration Options
### Setup Options
```bash
python setup_laion.py \
--num-samples 10000 \
--num-queries 200 \
--index-path data/laion_index.leann \
--backend hnsw
```
### Evaluation Options
```bash
python evaluate_laion.py \
--index data/laion_index.leann \
--queries data/evaluation_queries.jsonl \
--complexity 64 \
--top-k 3 \
--num-samples 100 \
--stage all
```
## Evaluation Stages
### Stage 2: Recall Evaluation
- Evaluates Recall@3 for multimodal retrieval
- Compares LEANN vs FAISS baseline performance
- Self-recall: query caption should retrieve original image
### Stage 3: Complexity Analysis
- Binary search for optimal complexity (90% recall target)
- Tests performance across different complexity levels
- Analyzes speed vs. accuracy tradeoffs
### Stage 4: Index Comparison
- Compares compact vs non-compact index sizes
- Measures search performance differences
- Reports storage efficiency and speed ratios
### Stage 5: Multimodal Generation
- Uses Qwen2.5-VL for image understanding and description
- Retrieval-Augmented Generation (RAG) with multimodal context
- Measures both search and generation timing
## Output Metrics
### Timing Metrics
- Average/median/min/max search time
- Standard deviation
- Searches per second
- Latency in milliseconds
### Recall Metrics
- Recall@3 percentage for image retrieval
- Number of queries with ground truth
### Index Metrics
- Total index size (MB)
- Component breakdown (index, passages, metadata)
- Storage savings (compact vs non-compact)
- Backend and embedding model info
### Generation Metrics (Stage 5)
- Average search time per query
- Average generation time per query
- Time distribution (search vs generation)
- Sample multimodal responses
- Model: Qwen2.5-VL performance
## Benchmark Results
### LEANN-RAG Performance (CLIP ViT-L/14 + Qwen2.5-VL)
**Stage 3: Optimal Complexity Analysis**
- **Optimal Complexity**: 85 (achieving 90% Recall@3)
- **Binary Search Range**: 1-128
- **Target Recall**: 90%
- **Index Type**: Non-compact (for fast binary search)
**Stage 5: Multimodal Generation Performance (Qwen2.5-VL)**
- **Total Queries**: 20
- **Average Search Time**: 1.200s per query
- **Average Generation Time**: 6.558s per query
- **Time Distribution**: Search 15.5%, Generation 84.5%
- **LLM Backend**: HuggingFace transformers
- **Model**: Qwen/Qwen2.5-VL-7B-Instruct
- **Optimal Complexity**: 85
**System Performance:**
- **Index Size**: ~10,000 image embeddings from LAION subset
- **Embedding Model**: CLIP ViT-L/14 (768 dimensions)
- **Backend**: HNSW with cosine distance
### Example Results
```
🎯 LAION MULTIMODAL BENCHMARK RESULTS
============================================================
📊 Multimodal Generation Results:
Total Queries: 20
Avg Search Time: 1.200s
Avg Generation Time: 6.558s
Time Distribution: Search 15.5%, Generation 84.5%
LLM Backend: HuggingFace transformers
Model: Qwen/Qwen2.5-VL-7B-Instruct
⚙️ Optimal Complexity Analysis:
Target Recall: 90%
Optimal Complexity: 85
Binary Search Range: 1-128
Non-compact Index (fast search, no recompute)
🚀 Performance Summary:
Multimodal RAG: 7.758s total per query
Search: 15.5% of total time
Generation: 84.5% of total time
```
## Directory Structure
```
benchmarks/laion/
├── setup_laion.py # Setup script
├── evaluate_laion.py # Evaluation script
├── README.md # This file
└── data/ # Generated data
├── laion_images/ # Image files (placeholder)
├── laion_metadata.jsonl # Image metadata
├── laion_passages.jsonl # LEANN passages
├── laion_embeddings.npy # CLIP embeddings
├── evaluation_queries.jsonl # Evaluation queries
└── laion_index.leann/ # LEANN index files
```
## Notes
- Current implementation uses dummy data for demonstration
- For real LAION data, implement actual download logic in `setup_laion.py`
- CLIP embeddings are randomly generated - replace with real CLIP model for production
- Adjust `num_samples` and `num_queries` based on available resources
- Consider using `--num-samples` during evaluation for faster testing

View File

@@ -1,725 +0,0 @@
"""
LAION Multimodal Benchmark Evaluation Script - Modular Recall-based Evaluation
"""
import argparse
import json
import logging
import os
import pickle
import time
from pathlib import Path
import numpy as np
from leann import LeannSearcher
from leann_backend_hnsw import faiss
from sentence_transformers import SentenceTransformer
from ..llm_utils import evaluate_multimodal_rag, load_qwen_vl_model
# Setup logging to reduce verbose output
logging.basicConfig(level=logging.WARNING)
logging.getLogger("leann.api").setLevel(logging.WARNING)
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)
class RecallEvaluator:
"""Stage 2: Evaluate Recall@3 (LEANN vs FAISS baseline for multimodal retrieval)"""
def __init__(self, index_path: str, baseline_dir: str):
self.index_path = index_path
self.baseline_dir = baseline_dir
self.searcher = LeannSearcher(index_path)
# Load FAISS flat baseline (image embeddings)
baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
metadata_path = os.path.join(baseline_dir, "metadata.pkl")
self.faiss_index = faiss.read_index(baseline_index_path)
with open(metadata_path, "rb") as f:
self.image_ids = pickle.load(f)
print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} image vectors")
# Load sentence-transformers CLIP for text embedding (ViT-L/14)
self.st_clip = SentenceTransformer("clip-ViT-L-14")
def evaluate_recall_at_3(
self, captions: list[str], complexity: int = 64, recompute_embeddings: bool = True
) -> float:
"""Evaluate recall@3 for multimodal retrieval: caption queries -> image results"""
recompute_str = "with recompute" if recompute_embeddings else "no recompute"
print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")
total_recall = 0.0
num_queries = len(captions)
for i, caption in enumerate(captions):
# Get ground truth: search with FAISS flat using caption text embedding
# Generate CLIP text embedding for caption via sentence-transformers (normalized)
query_embedding = self.st_clip.encode(
[caption], convert_to_numpy=True, normalize_embeddings=True, show_progress_bar=False
).astype(np.float32)
# Search FAISS flat for ground truth using LEANN's modified faiss API
n = query_embedding.shape[0] # Number of queries
k = 3 # Number of nearest neighbors
distances = np.zeros((n, k), dtype=np.float32)
labels = np.zeros((n, k), dtype=np.int64)
self.faiss_index.search(
n,
faiss.swig_ptr(query_embedding),
k,
faiss.swig_ptr(distances),
faiss.swig_ptr(labels),
)
# Extract the results (image IDs from FAISS)
baseline_ids = {self.image_ids[idx] for idx in labels[0]}
# Search with LEANN at specified complexity (using caption as text query)
test_results = self.searcher.search(
caption,
top_k=3,
complexity=complexity,
recompute_embeddings=recompute_embeddings,
)
test_ids = {result.id for result in test_results}
# Calculate recall@3 = |intersection| / |ground_truth|
intersection = test_ids.intersection(baseline_ids)
recall = len(intersection) / 3.0 # Ground truth size is 3
total_recall += recall
if i < 3: # Show first few examples
print(f" Query {i + 1}: '{caption[:50]}...' -> Recall@3: {recall:.3f}")
print(f" FAISS ground truth: {list(baseline_ids)}")
print(f" LEANN results (C={complexity}, {recompute_str}): {list(test_ids)}")
print(f" Intersection: {list(intersection)}")
avg_recall = total_recall / num_queries
print(f"📊 Average Recall@3: {avg_recall:.3f} ({avg_recall * 100:.1f}%)")
return avg_recall
def cleanup(self):
"""Cleanup resources"""
if hasattr(self, "searcher"):
self.searcher.cleanup()
class LAIONEvaluator:
def __init__(self, index_path: str):
self.index_path = index_path
self.searcher = LeannSearcher(index_path)
def load_queries(self, queries_file: str) -> list[str]:
"""Load caption queries from evaluation file"""
captions = []
with open(queries_file, encoding="utf-8") as f:
for line in f:
if line.strip():
query_data = json.loads(line)
captions.append(query_data["query"])
print(f"📊 Loaded {len(captions)} caption queries")
return captions
def analyze_index_sizes(self) -> dict:
"""Analyze index sizes, emphasizing .index only (exclude passages)."""
print("📏 Analyzing index sizes (.index only)...")
# Get all index-related files
index_path = Path(self.index_path)
index_dir = index_path.parent
index_name = index_path.stem # Remove .leann extension
sizes: dict[str, float] = {}
# Core index files
index_file = index_dir / f"{index_name}.index"
meta_file = index_dir / f"{index_path.name}.meta.json" # Keep .leann for meta file
passages_file = index_dir / f"{index_path.name}.passages.jsonl" # Keep .leann for passages
passages_idx_file = index_dir / f"{index_path.name}.passages.idx" # Keep .leann for idx
# Core index size (.index only)
index_mb = index_file.stat().st_size / (1024 * 1024) if index_file.exists() else 0.0
sizes["index_only_mb"] = index_mb
# Other files for reference (not counted in index_only_mb)
sizes["metadata_mb"] = (
meta_file.stat().st_size / (1024 * 1024) if meta_file.exists() else 0.0
)
sizes["passages_text_mb"] = (
passages_file.stat().st_size / (1024 * 1024) if passages_file.exists() else 0.0
)
sizes["passages_index_mb"] = (
passages_idx_file.stat().st_size / (1024 * 1024) if passages_idx_file.exists() else 0.0
)
print(f" 📁 .index size: {index_mb:.1f} MB")
if sizes["metadata_mb"]:
print(f" 🧾 metadata: {sizes['metadata_mb']:.3f} MB")
if sizes["passages_text_mb"] or sizes["passages_index_mb"]:
print(
f" (passages excluded) text: {sizes['passages_text_mb']:.1f} MB, idx: {sizes['passages_index_mb']:.1f} MB"
)
return sizes
def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
"""Create a non-compact index for comparison purposes"""
print("🏗️ Building non-compact index from existing passages...")
# Load existing passages from current index
from leann import LeannBuilder
current_index_path = Path(self.index_path)
current_index_dir = current_index_path.parent
current_index_name = current_index_path.name
# Read metadata to get passage source
meta_path = current_index_dir / f"{current_index_name}.meta.json"
with open(meta_path) as f:
meta = json.load(f)
passage_source = meta["passage_sources"][0]
passage_file = passage_source["path"]
# Convert relative path to absolute
if not Path(passage_file).is_absolute():
passage_file = current_index_dir / Path(passage_file).name
print(f"📄 Loading passages from {passage_file}...")
# Load CLIP embeddings
embeddings_file = current_index_dir / "clip_image_embeddings.npy"
embeddings = np.load(embeddings_file)
print(f"📐 Loaded embeddings shape: {embeddings.shape}")
# Build non-compact index with same passages and embeddings
builder = LeannBuilder(
backend_name="hnsw",
# Use CLIP text encoder (ViT-L/14) to match image embeddings (768-dim)
embedding_model="clip-ViT-L-14",
embedding_mode="sentence-transformers",
is_recompute=False, # Disable recompute (store embeddings)
is_compact=False, # Disable compact storage
distance_metric="cosine",
**{
k: v
for k, v in meta.get("backend_kwargs", {}).items()
if k not in ["is_recompute", "is_compact", "distance_metric"]
},
)
# Prepare ids and add passages
ids: list[str] = []
with open(passage_file, encoding="utf-8") as f:
for line in f:
if line.strip():
data = json.loads(line)
ids.append(str(data["id"]))
# Ensure metadata contains the id used by the vector index
metadata = {**data.get("metadata", {}), "id": data["id"]}
builder.add_text(text=data["text"], metadata=metadata)
if len(ids) != embeddings.shape[0]:
raise ValueError(
f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
)
# Persist a pickle for build_index_from_embeddings
pkl_path = current_index_dir / "clip_image_embeddings.pkl"
with open(pkl_path, "wb") as pf:
pickle.dump((ids, embeddings.astype(np.float32)), pf)
print(
f"🔨 Building non-compact index at {non_compact_index_path} from precomputed embeddings..."
)
builder.build_index_from_embeddings(non_compact_index_path, str(pkl_path))
# Analyze the non-compact index size
temp_evaluator = LAIONEvaluator(non_compact_index_path)
non_compact_sizes = temp_evaluator.analyze_index_sizes()
non_compact_sizes["index_type"] = "non_compact"
return non_compact_sizes
def compare_index_performance(
self, non_compact_path: str, compact_path: str, test_captions: list, complexity: int
) -> dict:
"""Compare performance between non-compact and compact indexes"""
print("⚡ Comparing search performance between indexes...")
# Test queries
test_queries = test_captions[:5]
results = {
"non_compact": {"search_times": []},
"compact": {"search_times": []},
"avg_search_times": {},
"speed_ratio": 0.0,
}
# Test non-compact index (no recompute)
print(" 🔍 Testing non-compact index (no recompute)...")
non_compact_searcher = LeannSearcher(non_compact_path)
for caption in test_queries:
start_time = time.time()
_ = non_compact_searcher.search(
caption, top_k=3, complexity=complexity, recompute_embeddings=False
)
search_time = time.time() - start_time
results["non_compact"]["search_times"].append(search_time)
# Test compact index (with recompute)
print(" 🔍 Testing compact index (with recompute)...")
compact_searcher = LeannSearcher(compact_path)
for caption in test_queries:
start_time = time.time()
_ = compact_searcher.search(
caption, top_k=3, complexity=complexity, recompute_embeddings=True
)
search_time = time.time() - start_time
results["compact"]["search_times"].append(search_time)
# Calculate averages
results["avg_search_times"]["non_compact"] = sum(
results["non_compact"]["search_times"]
) / len(results["non_compact"]["search_times"])
results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
results["compact"]["search_times"]
)
# Performance ratio
if results["avg_search_times"]["compact"] > 0:
results["speed_ratio"] = (
results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
)
else:
results["speed_ratio"] = float("inf")
print(
f" Non-compact (no recompute): {results['avg_search_times']['non_compact']:.3f}s avg"
)
print(f" Compact (with recompute): {results['avg_search_times']['compact']:.3f}s avg")
print(f" Speed ratio: {results['speed_ratio']:.2f}x")
# Cleanup
non_compact_searcher.cleanup()
compact_searcher.cleanup()
return results
def _print_results(self, timing_metrics: dict):
"""Print evaluation results"""
print("\n🎯 LAION MULTIMODAL BENCHMARK RESULTS")
print("=" * 60)
# Index comparison analysis (prefer .index-only view if present)
if "current_index" in timing_metrics and "non_compact_index" in timing_metrics:
current = timing_metrics["current_index"]
non_compact = timing_metrics["non_compact_index"]
if "index_only_mb" in current and "index_only_mb" in non_compact:
print("\n📏 Index Comparison Analysis (.index only):")
print(f" Compact index (current): {current.get('index_only_mb', 0):.1f} MB")
print(f" Non-compact index: {non_compact.get('index_only_mb', 0):.1f} MB")
print(
f" Storage saving by compact: {timing_metrics.get('storage_saving_percent', 0):.1f}%"
)
# Show excluded components for reference if available
if any(
k in non_compact
for k in ("passages_text_mb", "passages_index_mb", "metadata_mb")
):
print(" (passages excluded in totals, shown for reference):")
print(
f" - Passages text: {non_compact.get('passages_text_mb', 0):.1f} MB, "
f"Passages index: {non_compact.get('passages_index_mb', 0):.1f} MB, "
f"Metadata: {non_compact.get('metadata_mb', 0):.3f} MB"
)
else:
# Fallback to legacy totals if running with older metrics
print("\n📏 Index Comparison Analysis:")
print(
f" Compact index (current): {current.get('total_with_embeddings', 0):.1f} MB"
)
print(
f" Non-compact index (with embeddings): {non_compact.get('total_with_embeddings', 0):.1f} MB"
)
print(
f" Storage saving by compact: {timing_metrics.get('storage_saving_percent', 0):.1f}%"
)
print(" Component breakdown (non-compact):")
print(f" - Main index: {non_compact.get('index', 0):.1f} MB")
print(f" - Passages text: {non_compact.get('passages_text', 0):.1f} MB")
print(f" - Passages index: {non_compact.get('passages_index', 0):.1f} MB")
print(f" - Metadata: {non_compact.get('metadata', 0):.1f} MB")
# Performance comparison
if "performance_comparison" in timing_metrics:
perf = timing_metrics["performance_comparison"]
print("\n⚡ Performance Comparison:")
print(
f" Non-compact (no recompute): {perf.get('avg_search_times', {}).get('non_compact', 0):.3f}s avg"
)
print(
f" Compact (with recompute): {perf.get('avg_search_times', {}).get('compact', 0):.3f}s avg"
)
print(f" Speed ratio: {perf.get('speed_ratio', 0):.2f}x")
# Legacy single index analysis (fallback)
if "total_with_embeddings" in timing_metrics and "current_index" not in timing_metrics:
print("\n📏 Index Size Analysis:")
print(
f" Index with embeddings: {timing_metrics.get('total_with_embeddings', 0):.1f} MB"
)
print(
f" Estimated pruned index: {timing_metrics.get('total_without_embeddings', 0):.1f} MB"
)
print(f" Compression ratio: {timing_metrics.get('compression_ratio', 0):.2f}x")
def cleanup(self):
"""Cleanup resources"""
if self.searcher:
self.searcher.cleanup()
def main():
parser = argparse.ArgumentParser(description="LAION Multimodal Benchmark Evaluation")
parser.add_argument("--index", required=True, help="Path to LEANN index")
parser.add_argument(
"--queries", default="data/evaluation_queries.jsonl", help="Path to evaluation queries"
)
parser.add_argument(
"--stage",
choices=["2", "3", "4", "5", "all"],
default="all",
help="Which stage to run (2=recall, 3=complexity, 4=index comparison, 5=generation)",
)
parser.add_argument("--complexity", type=int, default=None, help="Complexity for search")
parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
parser.add_argument("--output", help="Save results to JSON file")
parser.add_argument(
"--llm-backend",
choices=["hf"],
default="hf",
help="LLM backend (Qwen2.5-VL only supports HF)",
)
parser.add_argument(
"--model-name", default="Qwen/Qwen2.5-VL-7B-Instruct", help="Multimodal model name"
)
args = parser.parse_args()
try:
# Check if baseline exists
baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
if not os.path.exists(baseline_index_path):
print(f"❌ FAISS baseline not found at {baseline_index_path}")
print("💡 Please run setup_laion.py first to build the baseline")
exit(1)
if args.stage == "2" or args.stage == "all":
# Stage 2: Recall@3 evaluation
print("🚀 Starting Stage 2: Recall@3 evaluation for multimodal retrieval")
evaluator = RecallEvaluator(args.index, args.baseline_dir)
# Load caption queries for testing
laion_evaluator = LAIONEvaluator(args.index)
captions = laion_evaluator.load_queries(args.queries)
# Test with queries for robust measurement
test_captions = captions[:100] # Use subset for speed
print(f"🧪 Testing with {len(test_captions)} caption queries")
# Test with complexity 64
complexity = 64
recall = evaluator.evaluate_recall_at_3(test_captions, complexity)
print(f"📈 Recall@3 at complexity {complexity}: {recall * 100:.1f}%")
evaluator.cleanup()
print("✅ Stage 2 completed!\n")
# Shared non-compact index path for Stage 3 and 4
non_compact_index_path = args.index.replace(".leann", "_noncompact.leann")
complexity = args.complexity
if args.stage == "3" or args.stage == "all":
# Stage 3: Binary search for 90% recall complexity
print("🚀 Starting Stage 3: Binary search for 90% recall complexity")
print(
"💡 Creating non-compact index for fast binary search with recompute_embeddings=False"
)
# Create non-compact index for binary search
print("🏗️ Creating non-compact index for binary search...")
evaluator = LAIONEvaluator(args.index)
evaluator.create_non_compact_index_for_comparison(non_compact_index_path)
# Use non-compact index for binary search
binary_search_evaluator = RecallEvaluator(non_compact_index_path, args.baseline_dir)
# Load caption queries for testing
captions = evaluator.load_queries(args.queries)
# Use subset for robust measurement
test_captions = captions[:50] # Smaller subset for binary search speed
print(f"🧪 Testing with {len(test_captions)} caption queries")
# Binary search for 90% recall complexity
target_recall = 0.9
min_complexity, max_complexity = 1, 128
print(f"🔍 Binary search for {target_recall * 100}% recall complexity...")
print(f"Search range: {min_complexity} to {max_complexity}")
best_complexity = None
best_recall = 0.0
while min_complexity <= max_complexity:
mid_complexity = (min_complexity + max_complexity) // 2
print(
f"\n🧪 Testing complexity {mid_complexity} (no recompute, non-compact index)..."
)
# Use recompute_embeddings=False on non-compact index for fast binary search
recall = binary_search_evaluator.evaluate_recall_at_3(
test_captions, mid_complexity, recompute_embeddings=False
)
print(
f" Complexity {mid_complexity}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%)"
)
if recall >= target_recall:
best_complexity = mid_complexity
best_recall = recall
max_complexity = mid_complexity - 1
print(" ✅ Target reached! Searching for lower complexity...")
else:
min_complexity = mid_complexity + 1
print(" ❌ Below target. Searching for higher complexity...")
if best_complexity is not None:
print("\n🎯 Optimal complexity found!")
print(f" Complexity: {best_complexity}")
print(f" Recall@3: {best_recall:.3f} ({best_recall * 100:.1f}%)")
# Test a few complexities around the optimal one for verification
print("\n🔬 Verification test around optimal complexity:")
verification_complexities = [
max(1, best_complexity - 2),
max(1, best_complexity - 1),
best_complexity,
best_complexity + 1,
best_complexity + 2,
]
for complexity in verification_complexities:
if complexity <= 512: # reasonable upper bound
recall = binary_search_evaluator.evaluate_recall_at_3(
test_captions, complexity, recompute_embeddings=False
)
status = "" if recall >= target_recall else ""
print(f" {status} Complexity {complexity:3d}: {recall * 100:5.1f}%")
# Now test the optimal complexity with compact index and recompute for comparison
print(
f"\n🔄 Testing optimal complexity {best_complexity} on compact index WITH recompute..."
)
compact_evaluator = RecallEvaluator(args.index, args.baseline_dir)
recall_with_recompute = compact_evaluator.evaluate_recall_at_3(
test_captions[:10], best_complexity, recompute_embeddings=True
)
print(
f" ✅ Complexity {best_complexity} (compact index with recompute): {recall_with_recompute * 100:.1f}%"
)
complexity = best_complexity
print(
f" 📊 Recall difference: {abs(best_recall - recall_with_recompute) * 100:.2f}%"
)
compact_evaluator.cleanup()
else:
print(f"\n❌ Could not find complexity achieving {target_recall * 100}% recall")
print("All tested complexities were below target.")
# Cleanup evaluators (keep non-compact index for Stage 4)
binary_search_evaluator.cleanup()
evaluator.cleanup()
print("✅ Stage 3 completed! Non-compact index saved for Stage 4.\n")
if args.stage == "4" or args.stage == "all":
# Stage 4: Index comparison (without LLM generation)
print("🚀 Starting Stage 4: Index comparison analysis")
# Use LAION evaluator for index comparison
evaluator = LAIONEvaluator(args.index)
# Load caption queries
captions = evaluator.load_queries(args.queries)
# Step 1: Analyze current (compact) index
print("\n📏 Analyzing current index (compact, pruned)...")
compact_size_metrics = evaluator.analyze_index_sizes()
compact_size_metrics["index_type"] = "compact"
# Step 2: Use existing non-compact index or create if needed
if Path(non_compact_index_path).exists():
print(
f"\n📁 Using existing non-compact index from Stage 3: {non_compact_index_path}"
)
temp_evaluator = LAIONEvaluator(non_compact_index_path)
non_compact_size_metrics = temp_evaluator.analyze_index_sizes()
non_compact_size_metrics["index_type"] = "non_compact"
else:
print("\n🏗️ Creating non-compact index (with embeddings) for comparison...")
non_compact_size_metrics = evaluator.create_non_compact_index_for_comparison(
non_compact_index_path
)
# Step 3: Compare index sizes (.index only)
print("\n📊 Index size comparison (.index only):")
print(
f" Compact index (current): {compact_size_metrics.get('index_only_mb', 0):.1f} MB"
)
print(f" Non-compact index: {non_compact_size_metrics.get('index_only_mb', 0):.1f} MB")
storage_saving = 0.0
if non_compact_size_metrics.get("index_only_mb", 0) > 0:
storage_saving = (
(
non_compact_size_metrics.get("index_only_mb", 0)
- compact_size_metrics.get("index_only_mb", 0)
)
/ non_compact_size_metrics.get("index_only_mb", 1)
* 100
)
print(f" Storage saving by compact: {storage_saving:.1f}%")
# Step 4: Performance comparison between the two indexes
if complexity is None:
raise ValueError("Complexity is required for index comparison")
print("\n⚡ Performance comparison between indexes...")
performance_metrics = evaluator.compare_index_performance(
non_compact_index_path, args.index, captions[:10], complexity=complexity
)
# Combine all metrics
combined_metrics = {
"current_index": compact_size_metrics,
"non_compact_index": non_compact_size_metrics,
"performance_comparison": performance_metrics,
"storage_saving_percent": storage_saving,
}
# Print comprehensive results
evaluator._print_results(combined_metrics)
# Save results if requested
if args.output:
print(f"\n💾 Saving results to {args.output}...")
with open(args.output, "w") as f:
json.dump(combined_metrics, f, indent=2, default=str)
print(f"✅ Results saved to {args.output}")
evaluator.cleanup()
print("✅ Stage 4 completed!\n")
if args.stage in ("5", "all"):
print("🚀 Starting Stage 5: Multimodal generation with Qwen2.5-VL")
evaluator = LAIONEvaluator(args.index)
captions = evaluator.load_queries(args.queries)
test_captions = captions[: min(20, len(captions))] # Use subset for generation
print(f"🧪 Testing multimodal generation with {len(test_captions)} queries")
# Load Qwen2.5-VL model
try:
print("Loading Qwen2.5-VL model...")
processor, model = load_qwen_vl_model(args.model_name)
# Run multimodal generation evaluation
complexity = args.complexity or 64
gen_results = evaluate_multimodal_rag(
evaluator.searcher,
test_captions,
processor=processor,
model=model,
complexity=complexity,
)
print("\n📊 Multimodal Generation Results:")
print(f" Total Queries: {len(test_captions)}")
print(f" Avg Search Time: {gen_results['avg_search_time']:.3f}s")
print(f" Avg Generation Time: {gen_results['avg_generation_time']:.3f}s")
total_time = gen_results["avg_search_time"] + gen_results["avg_generation_time"]
search_pct = (gen_results["avg_search_time"] / total_time) * 100
gen_pct = (gen_results["avg_generation_time"] / total_time) * 100
print(f" Time Distribution: Search {search_pct:.1f}%, Generation {gen_pct:.1f}%")
print(" LLM Backend: HuggingFace transformers")
print(f" Model: {args.model_name}")
# Show sample results
print("\n📝 Sample Multimodal Generations:")
for i, response in enumerate(gen_results["results"][:3]):
# Handle both string and dict formats for captions
if isinstance(test_captions[i], dict):
caption_text = test_captions[i].get("query", str(test_captions[i]))
else:
caption_text = str(test_captions[i])
print(f" Query {i + 1}: {caption_text[:60]}...")
print(f" Response {i + 1}: {response[:100]}...")
print()
except Exception as e:
print(f"❌ Multimodal generation evaluation failed: {e}")
print("💡 Make sure transformers and Qwen2.5-VL are installed")
import traceback
traceback.print_exc()
evaluator.cleanup()
print("✅ Stage 5 completed!\n")
if args.stage == "all":
print("🎉 All evaluation stages completed successfully!")
print("\n📋 Summary:")
print(" Stage 2: ✅ Multimodal Recall@3 evaluation completed")
print(" Stage 3: ✅ Optimal complexity found")
print(" Stage 4: ✅ Index comparison analysis completed")
print(" Stage 5: ✅ Multimodal generation evaluation completed")
print("\n🔧 Recommended next steps:")
print(" - Use optimal complexity for best speed/accuracy balance")
print(" - Review index comparison for storage vs performance tradeoffs")
# Clean up non-compact index after all stages complete
print("\n🧹 Cleaning up temporary non-compact index...")
if Path(non_compact_index_path).exists():
temp_index_dir = Path(non_compact_index_path).parent
temp_index_name = Path(non_compact_index_path).name
for temp_file in temp_index_dir.glob(f"{temp_index_name}*"):
temp_file.unlink()
print(f"✅ Cleaned up {non_compact_index_path}")
else:
print("📝 No temporary index to clean up")
except KeyboardInterrupt:
print("\n⚠️ Evaluation interrupted by user")
exit(1)
except Exception as e:
print(f"\n❌ Stage {args.stage} failed: {e}")
import traceback
traceback.print_exc()
exit(1)
if __name__ == "__main__":
main()

View File

@@ -1,576 +0,0 @@
"""
LAION Multimodal Benchmark Setup Script
Downloads LAION subset and builds LEANN index with sentence embeddings
"""
import argparse
import asyncio
import io
import json
import os
import pickle
import time
from pathlib import Path
import aiohttp
import numpy as np
from datasets import load_dataset
from leann import LeannBuilder
from PIL import Image
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
class LAIONSetup:
def __init__(self, data_dir: str = "data"):
self.data_dir = Path(data_dir)
self.images_dir = self.data_dir / "laion_images"
self.metadata_file = self.data_dir / "laion_metadata.jsonl"
# Create directories
self.data_dir.mkdir(exist_ok=True)
self.images_dir.mkdir(exist_ok=True)
async def download_single_image(self, session, sample_data, semaphore, progress_bar):
"""Download a single image asynchronously"""
async with semaphore: # Limit concurrent downloads
try:
image_url = sample_data["url"]
image_path = sample_data["image_path"]
# Skip if already exists
if os.path.exists(image_path):
progress_bar.update(1)
return sample_data
async with session.get(image_url, timeout=10) as response:
if response.status == 200:
content = await response.read()
# Verify it's a valid image
try:
img = Image.open(io.BytesIO(content))
img = img.convert("RGB")
img.save(image_path, "JPEG")
progress_bar.update(1)
return sample_data
except Exception:
progress_bar.update(1)
return None # Skip invalid images
else:
progress_bar.update(1)
return None
except Exception:
progress_bar.update(1)
return None
def download_laion_subset(self, num_samples: int = 1000):
"""Download LAION subset from HuggingFace datasets with async parallel downloading"""
print(f"📥 Downloading LAION subset ({num_samples} samples)...")
# Load LAION-400M subset from HuggingFace
print("🤗 Loading from HuggingFace datasets...")
dataset = load_dataset("laion/laion400m", split="train", streaming=True)
# Collect sample metadata first (fast)
print("📋 Collecting sample metadata...")
candidates = []
for sample in dataset:
if len(candidates) >= num_samples * 3: # Get 3x more candidates in case some fail
break
image_url = sample.get("url", "")
caption = sample.get("caption", "")
if not image_url or not caption:
continue
image_filename = f"laion_{len(candidates):06d}.jpg"
image_path = self.images_dir / image_filename
candidate = {
"id": f"laion_{len(candidates):06d}",
"url": image_url,
"caption": caption,
"image_path": str(image_path),
"width": sample.get("original_width", 512),
"height": sample.get("original_height", 512),
"similarity": sample.get("similarity", 0.0),
}
candidates.append(candidate)
print(
f"📊 Collected {len(candidates)} candidates, downloading {num_samples} in parallel..."
)
# Download images in parallel
async def download_batch():
semaphore = asyncio.Semaphore(20) # Limit to 20 concurrent downloads
connector = aiohttp.TCPConnector(limit=100, limit_per_host=20)
timeout = aiohttp.ClientTimeout(total=30)
progress_bar = tqdm(total=len(candidates[: num_samples * 2]), desc="Downloading images")
async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
tasks = []
for candidate in candidates[: num_samples * 2]: # Try 2x more than needed
task = self.download_single_image(session, candidate, semaphore, progress_bar)
tasks.append(task)
# Wait for all downloads
results = await asyncio.gather(*tasks, return_exceptions=True)
progress_bar.close()
# Filter successful downloads
successful = [r for r in results if r is not None and not isinstance(r, Exception)]
return successful[:num_samples]
# Run async download
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
samples = loop.run_until_complete(download_batch())
finally:
loop.close()
# Save metadata
with open(self.metadata_file, "w", encoding="utf-8") as f:
for sample in samples:
f.write(json.dumps(sample) + "\n")
print(f"✅ Downloaded {len(samples)} real LAION samples with async parallel downloading")
return samples
def generate_clip_image_embeddings(self, samples: list[dict]):
"""Generate CLIP image embeddings for downloaded images"""
print("🔍 Generating CLIP image embeddings...")
# Load sentence-transformers CLIP (ViT-L/14, 768-dim) for image embeddings
# This single model can encode both images and text.
model = SentenceTransformer("clip-ViT-L-14")
embeddings = []
valid_samples = []
for sample in tqdm(samples, desc="Processing images"):
try:
# Load image
image_path = sample["image_path"]
image = Image.open(image_path).convert("RGB")
# Encode image to 768-dim embedding via sentence-transformers (normalized)
vec = model.encode(
[image],
convert_to_numpy=True,
normalize_embeddings=True,
batch_size=1,
show_progress_bar=False,
)[0]
embeddings.append(vec.astype(np.float32))
valid_samples.append(sample)
except Exception as e:
print(f" ⚠️ Failed to process {sample['id']}: {e}")
# Skip invalid images
embeddings = np.array(embeddings, dtype=np.float32)
# Save embeddings
embeddings_file = self.data_dir / "clip_image_embeddings.npy"
np.save(embeddings_file, embeddings)
print(f"✅ Generated {len(embeddings)} image embeddings, shape: {embeddings.shape}")
return embeddings, valid_samples
def build_faiss_baseline(
self, embeddings: np.ndarray, samples: list[dict], output_dir: str = "baseline"
):
"""Build FAISS flat baseline using CLIP image embeddings"""
print("🔨 Building FAISS Flat baseline...")
from leann_backend_hnsw import faiss
os.makedirs(output_dir, exist_ok=True)
baseline_path = os.path.join(output_dir, "faiss_flat.index")
metadata_path = os.path.join(output_dir, "metadata.pkl")
if os.path.exists(baseline_path) and os.path.exists(metadata_path):
print(f"✅ Baseline already exists at {baseline_path}")
return baseline_path
# Extract image IDs (must be present)
if not samples or "id" not in samples[0]:
raise KeyError("samples missing 'id' field for FAISS baseline")
image_ids: list[str] = [str(sample["id"]) for sample in samples]
print(f"📐 Embedding shape: {embeddings.shape}")
print(f"📄 Processing {len(image_ids)} images")
# Build FAISS flat index
print("🏗️ Building FAISS IndexFlatIP...")
dimension = embeddings.shape[1]
index = faiss.IndexFlatIP(dimension)
# Add embeddings to flat index
embeddings_f32 = embeddings.astype(np.float32)
index.add(embeddings_f32.shape[0], faiss.swig_ptr(embeddings_f32))
# Save index and metadata
faiss.write_index(index, baseline_path)
with open(metadata_path, "wb") as f:
pickle.dump(image_ids, f)
print(f"✅ FAISS baseline saved to {baseline_path}")
print(f"✅ Metadata saved to {metadata_path}")
print(f"📊 Total vectors: {index.ntotal}")
return baseline_path
def create_leann_passages(self, samples: list[dict]):
"""Create LEANN-compatible passages from LAION data"""
print("📝 Creating LEANN passages...")
passages_file = self.data_dir / "laion_passages.jsonl"
with open(passages_file, "w", encoding="utf-8") as f:
for i, sample in enumerate(samples):
passage = {
"id": sample["id"],
"text": sample["caption"], # Use caption as searchable text
"metadata": {
"image_url": sample["url"],
"image_path": sample.get("image_path", ""),
"width": sample["width"],
"height": sample["height"],
"similarity": sample["similarity"],
"image_index": i, # Index for embedding lookup
},
}
f.write(json.dumps(passage) + "\n")
print(f"✅ Created {len(samples)} passages")
return passages_file
def build_compact_index(
self, passages_file: Path, embeddings: np.ndarray, index_path: str, backend: str = "hnsw"
):
"""Build compact LEANN index with CLIP embeddings (recompute=True, compact=True)"""
print(f"🏗️ Building compact LEANN index with {backend} backend...")
start_time = time.time()
# Save CLIP embeddings (npy) and also a pickle with (ids, embeddings)
npy_path = self.data_dir / "clip_image_embeddings.npy"
np.save(npy_path, embeddings)
print(f"💾 Saved CLIP embeddings to {npy_path}")
# Prepare ids in the same order as passages_file (matches embeddings order)
ids: list[str] = []
with open(passages_file, encoding="utf-8") as f:
for line in f:
if line.strip():
rec = json.loads(line)
ids.append(str(rec["id"]))
if len(ids) != embeddings.shape[0]:
raise ValueError(
f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
)
pkl_path = self.data_dir / "clip_image_embeddings.pkl"
with open(pkl_path, "wb") as pf:
pickle.dump((ids, embeddings.astype(np.float32)), pf)
print(f"💾 Saved (ids, embeddings) pickle to {pkl_path}")
# Initialize builder - compact with recompute
# Note: For multimodal case, we need to handle embeddings differently
# Let's try using sentence-transformers mode but with custom embeddings
builder = LeannBuilder(
backend_name=backend,
# Use CLIP text encoder (ViT-L/14) to match image space (768-dim)
embedding_model="clip-ViT-L-14",
embedding_mode="sentence-transformers",
# HNSW params (or forwarded to chosen backend)
graph_degree=32,
complexity=64,
# Compact/pruned with recompute at query time
is_recompute=True,
is_compact=True,
distance_metric="cosine", # CLIP uses normalized vectors; cosine is appropriate
num_threads=4,
)
# Add passages (text + metadata)
print("📚 Adding passages...")
self._add_passages_with_embeddings(builder, passages_file, embeddings)
print(f"🔨 Building compact index at {index_path} from precomputed embeddings...")
builder.build_index_from_embeddings(index_path, str(pkl_path))
build_time = time.time() - start_time
print(f"✅ Compact index built in {build_time:.2f}s")
# Analyze index size
self._analyze_index_size(index_path)
return index_path
def build_non_compact_index(
self, passages_file: Path, embeddings: np.ndarray, index_path: str, backend: str = "hnsw"
):
"""Build non-compact LEANN index with CLIP embeddings (recompute=False, compact=False)"""
print(f"🏗️ Building non-compact LEANN index with {backend} backend...")
start_time = time.time()
# Ensure embeddings are saved (npy + pickle)
npy_path = self.data_dir / "clip_image_embeddings.npy"
if not npy_path.exists():
np.save(npy_path, embeddings)
print(f"💾 Saved CLIP embeddings to {npy_path}")
# Prepare ids in same order as passages_file
ids: list[str] = []
with open(passages_file, encoding="utf-8") as f:
for line in f:
if line.strip():
rec = json.loads(line)
ids.append(str(rec["id"]))
if len(ids) != embeddings.shape[0]:
raise ValueError(
f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
)
pkl_path = self.data_dir / "clip_image_embeddings.pkl"
if not pkl_path.exists():
with open(pkl_path, "wb") as pf:
pickle.dump((ids, embeddings.astype(np.float32)), pf)
print(f"💾 Saved (ids, embeddings) pickle to {pkl_path}")
# Initialize builder - non-compact without recompute
builder = LeannBuilder(
backend_name=backend,
embedding_model="clip-ViT-L-14",
embedding_mode="sentence-transformers",
graph_degree=32,
complexity=64,
is_recompute=False, # Store embeddings (no recompute needed)
is_compact=False, # Store full index (not pruned)
distance_metric="cosine",
num_threads=4,
)
# Add passages - embeddings will be loaded from file
print("📚 Adding passages...")
self._add_passages_with_embeddings(builder, passages_file, embeddings)
print(f"🔨 Building non-compact index at {index_path} from precomputed embeddings...")
builder.build_index_from_embeddings(index_path, str(pkl_path))
build_time = time.time() - start_time
print(f"✅ Non-compact index built in {build_time:.2f}s")
# Analyze index size
self._analyze_index_size(index_path)
return index_path
def _add_passages_with_embeddings(self, builder, passages_file: Path, embeddings: np.ndarray):
"""Helper to add passages with pre-computed CLIP embeddings"""
with open(passages_file, encoding="utf-8") as f:
for line in tqdm(f, desc="Adding passages"):
if line.strip():
passage = json.loads(line)
# Add image metadata - LEANN will handle embeddings separately
# Note: We store image metadata and caption text for searchability
# Important: ensure passage ID in metadata matches vector ID
builder.add_text(
text=passage["text"], # Image caption for searchability
metadata={**passage["metadata"], "id": passage["id"]},
)
def _analyze_index_size(self, index_path: str):
"""Analyze index file sizes"""
print("📏 Analyzing index sizes...")
index_path = Path(index_path)
index_dir = index_path.parent
index_name = index_path.name # e.g., laion_index.leann
index_prefix = index_path.stem # e.g., laion_index
files = [
(f"{index_prefix}.index", ".index", "core"),
(f"{index_name}.meta.json", ".meta.json", "core"),
(f"{index_name}.ids.txt", ".ids.txt", "core"),
(f"{index_name}.passages.jsonl", ".passages.jsonl", "passages"),
(f"{index_name}.passages.idx", ".passages.idx", "passages"),
]
def _fmt_size(bytes_val: int) -> str:
if bytes_val < 1024:
return f"{bytes_val} B"
kb = bytes_val / 1024
if kb < 1024:
return f"{kb:.1f} KB"
mb = kb / 1024
if mb < 1024:
return f"{mb:.2f} MB"
gb = mb / 1024
return f"{gb:.2f} GB"
total_index_only_mb = 0.0
total_all_mb = 0.0
for filename, label, group in files:
file_path = index_dir / filename
if file_path.exists():
size_bytes = file_path.stat().st_size
print(f" {label}: {_fmt_size(size_bytes)}")
size_mb = size_bytes / (1024 * 1024)
total_all_mb += size_mb
if group == "core":
total_index_only_mb += size_mb
else:
print(f" {label}: (missing)")
print(f" Total (index only, exclude passages): {total_index_only_mb:.2f} MB")
print(f" Total (including passages): {total_all_mb:.2f} MB")
def create_evaluation_queries(self, samples: list[dict], num_queries: int = 200):
"""Create evaluation queries from captions"""
print(f"📝 Creating {num_queries} evaluation queries...")
# Sample random captions as queries
import random
random.seed(42) # For reproducibility
query_samples = random.sample(samples, min(num_queries, len(samples)))
queries_file = self.data_dir / "evaluation_queries.jsonl"
with open(queries_file, "w", encoding="utf-8") as f:
for sample in query_samples:
query = {
"id": sample["id"],
"query": sample["caption"],
"ground_truth_id": sample["id"], # For potential recall evaluation
}
f.write(json.dumps(query) + "\n")
print(f"✅ Created {len(query_samples)} evaluation queries")
return queries_file
def main():
parser = argparse.ArgumentParser(description="Setup LAION Multimodal Benchmark")
parser.add_argument("--data-dir", default="data", help="Data directory")
parser.add_argument("--num-samples", type=int, default=1000, help="Number of LAION samples")
parser.add_argument("--num-queries", type=int, default=50, help="Number of evaluation queries")
parser.add_argument("--index-path", default="data/laion_index.leann", help="Output index path")
parser.add_argument(
"--backend", default="hnsw", choices=["hnsw", "diskann"], help="LEANN backend"
)
parser.add_argument("--skip-download", action="store_true", help="Skip LAION dataset download")
parser.add_argument("--skip-build", action="store_true", help="Skip index building")
args = parser.parse_args()
print("🚀 Setting up LAION Multimodal Benchmark")
print("=" * 50)
try:
# Initialize setup
setup = LAIONSetup(args.data_dir)
# Step 1: Download LAION subset
if not args.skip_download:
print("\n📦 Step 1: Download LAION subset")
samples = setup.download_laion_subset(args.num_samples)
# Step 2: Generate CLIP image embeddings
print("\n🔍 Step 2: Generate CLIP image embeddings")
embeddings, valid_samples = setup.generate_clip_image_embeddings(samples)
# Step 3: Create LEANN passages (image metadata with embeddings)
print("\n📝 Step 3: Create LEANN passages")
passages_file = setup.create_leann_passages(valid_samples)
else:
print("⏭️ Skipping LAION dataset download")
# Load existing data
passages_file = setup.data_dir / "laion_passages.jsonl"
embeddings_file = setup.data_dir / "clip_image_embeddings.npy"
if not passages_file.exists() or not embeddings_file.exists():
raise FileNotFoundError(
"Passages or embeddings file not found. Run without --skip-download first."
)
embeddings = np.load(embeddings_file)
print(f"📊 Loaded {len(embeddings)} embeddings from {embeddings_file}")
# Step 4: Build LEANN indexes (both compact and non-compact)
if not args.skip_build:
print("\n🏗️ Step 4: Build LEANN indexes with CLIP image embeddings")
# Build compact index (production mode - small, recompute required)
compact_index_path = args.index_path
print(f"Building compact index: {compact_index_path}")
setup.build_compact_index(passages_file, embeddings, compact_index_path, args.backend)
# Build non-compact index (comparison mode - large, fast search)
non_compact_index_path = args.index_path.replace(".leann", "_noncompact.leann")
print(f"Building non-compact index: {non_compact_index_path}")
setup.build_non_compact_index(
passages_file, embeddings, non_compact_index_path, args.backend
)
# Step 5: Build FAISS flat baseline
print("\n🔨 Step 5: Build FAISS flat baseline")
if not args.skip_download:
baseline_path = setup.build_faiss_baseline(embeddings, valid_samples)
else:
# Load valid_samples from passages file for FAISS baseline
valid_samples = []
with open(passages_file, encoding="utf-8") as f:
for line in f:
if line.strip():
passage = json.loads(line)
valid_samples.append({"id": passage["id"], "caption": passage["text"]})
baseline_path = setup.build_faiss_baseline(embeddings, valid_samples)
# Step 6: Create evaluation queries
print("\n📝 Step 6: Create evaluation queries")
queries_file = setup.create_evaluation_queries(valid_samples, args.num_queries)
else:
print("⏭️ Skipping index building")
baseline_path = "data/baseline/faiss_index.bin"
queries_file = setup.data_dir / "evaluation_queries.jsonl"
print("\n🎉 Setup completed successfully!")
print("📊 Summary:")
if not args.skip_download:
print(f" Downloaded samples: {len(samples)}")
print(f" Valid samples with embeddings: {len(valid_samples)}")
else:
print(f" Loaded {len(embeddings)} embeddings")
if not args.skip_build:
print(f" Compact index: {compact_index_path}")
print(f" Non-compact index: {non_compact_index_path}")
print(f" FAISS baseline: {baseline_path}")
print(f" Queries: {queries_file}")
print("\n🔧 Next steps:")
print(f" Run evaluation: python evaluate_laion.py --index {compact_index_path}")
print(f" Or compare with: python evaluate_laion.py --index {non_compact_index_path}")
else:
print(" Skipped building indexes")
except KeyboardInterrupt:
print("\n⚠️ Setup interrupted by user")
exit(1)
except Exception as e:
print(f"\n❌ Setup failed: {e}")
exit(1)
if __name__ == "__main__":
main()

View File

@@ -1,342 +0,0 @@
"""
LLM utils for RAG benchmarks with Qwen3-8B and Qwen2.5-VL (multimodal)
"""
import time
try:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
HF_AVAILABLE = True
except ImportError:
HF_AVAILABLE = False
try:
from vllm import LLM, SamplingParams
VLLM_AVAILABLE = True
except ImportError:
VLLM_AVAILABLE = False
def is_qwen3_model(model_name):
"""Check if model is Qwen3"""
return "Qwen3" in model_name or "qwen3" in model_name.lower()
def is_qwen_vl_model(model_name):
"""Check if model is Qwen2.5-VL"""
return "Qwen2.5-VL" in model_name or "qwen2.5-vl" in model_name.lower()
def apply_qwen3_chat_template(tokenizer, prompt):
"""Apply Qwen3 chat template with thinking enabled"""
messages = [{"role": "user", "content": prompt}]
return tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=True,
)
def extract_thinking_answer(response):
"""Extract final answer from Qwen3 thinking model response"""
if "<think>" in response and "</think>" in response:
try:
think_end = response.index("</think>") + len("</think>")
final_answer = response[think_end:].strip()
return final_answer
except (ValueError, IndexError):
pass
return response.strip()
def load_hf_model(model_name="Qwen/Qwen3-8B", trust_remote_code=False):
"""Load HuggingFace model
Args:
model_name (str): Name of the model to load
trust_remote_code (bool): Whether to allow execution of code from the model repository.
Defaults to False for security. Only enable for trusted models.
"""
if not HF_AVAILABLE:
raise ImportError("transformers not available")
if trust_remote_code:
print(
"⚠️ WARNING: Loading model with trust_remote_code=True. This can execute arbitrary code."
)
print(f"Loading HF: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
device_map="auto",
trust_remote_code=trust_remote_code,
)
return tokenizer, model
def load_vllm_model(model_name="Qwen/Qwen3-8B", trust_remote_code=False):
"""Load vLLM model
Args:
model_name (str): Name of the model to load
trust_remote_code (bool): Whether to allow execution of code from the model repository.
Defaults to False for security. Only enable for trusted models.
"""
if not VLLM_AVAILABLE:
raise ImportError("vllm not available")
if trust_remote_code:
print(
"⚠️ WARNING: Loading model with trust_remote_code=True. This can execute arbitrary code."
)
print(f"Loading vLLM: {model_name}")
llm = LLM(model=model_name, trust_remote_code=trust_remote_code)
# Qwen3 specific config
if is_qwen3_model(model_name):
stop_tokens = ["<|im_end|>", "<|end_of_text|>"]
max_tokens = 2048
else:
stop_tokens = None
max_tokens = 1024
sampling_params = SamplingParams(temperature=0.7, max_tokens=max_tokens, stop=stop_tokens)
return llm, sampling_params
def generate_hf(tokenizer, model, prompt, max_tokens=None):
"""Generate with HF - supports Qwen3 thinking models"""
model_name = getattr(model, "name_or_path", "unknown")
is_qwen3 = is_qwen3_model(model_name)
# Apply chat template for Qwen3
if is_qwen3:
prompt = apply_qwen3_chat_template(tokenizer, prompt)
max_tokens = max_tokens or 2048
else:
max_tokens = max_tokens or 1024
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=max_tokens,
temperature=0.7,
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
response = response[len(prompt) :].strip()
# Extract final answer for thinking models
if is_qwen3:
return extract_thinking_answer(response)
return response
def generate_vllm(llm, sampling_params, prompt):
"""Generate with vLLM - supports Qwen3 thinking models"""
outputs = llm.generate([prompt], sampling_params)
response = outputs[0].outputs[0].text.strip()
# Extract final answer for Qwen3 thinking models
model_name = str(llm.llm_engine.model_config.model)
if is_qwen3_model(model_name):
return extract_thinking_answer(response)
return response
def create_prompt(context, query, domain="default"):
"""Create RAG prompt"""
if domain == "emails":
return f"Email content:\n{context}\n\nQuestion: {query}\n\nAnswer:"
elif domain == "finance":
return f"Financial content:\n{context}\n\nQuestion: {query}\n\nAnswer:"
elif domain == "multimodal":
return f"Image context:\n{context}\n\nQuestion: {query}\n\nAnswer:"
else:
return f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
def evaluate_rag(searcher, llm_func, queries, domain="default", top_k=3, complexity=64):
"""Simple RAG evaluation with timing"""
search_times = []
gen_times = []
results = []
for i, query in enumerate(queries):
# Search
start = time.time()
docs = searcher.search(query, top_k=top_k, complexity=complexity)
search_time = time.time() - start
# Generate
context = "\n\n".join([doc.text for doc in docs])
prompt = create_prompt(context, query, domain)
start = time.time()
response = llm_func(prompt)
gen_time = time.time() - start
search_times.append(search_time)
gen_times.append(gen_time)
results.append(response)
if i < 3:
print(f"Q{i + 1}: Search={search_time:.3f}s, Gen={gen_time:.3f}s")
return {
"avg_search_time": sum(search_times) / len(search_times),
"avg_generation_time": sum(gen_times) / len(gen_times),
"results": results,
}
def load_qwen_vl_model(model_name="Qwen/Qwen2.5-VL-7B-Instruct", trust_remote_code=False):
"""Load Qwen2.5-VL multimodal model
Args:
model_name (str): Name of the model to load
trust_remote_code (bool): Whether to allow execution of code from the model repository.
Defaults to False for security. Only enable for trusted models.
"""
if not HF_AVAILABLE:
raise ImportError("transformers not available")
if trust_remote_code:
print(
"⚠️ WARNING: Loading model with trust_remote_code=True. This can execute arbitrary code."
)
print(f"Loading Qwen2.5-VL: {model_name}")
try:
from transformers import AutoModelForVision2Seq, AutoProcessor
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=trust_remote_code)
model = AutoModelForVision2Seq.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=trust_remote_code,
)
return processor, model
except Exception as e:
print(f"Failed to load with AutoModelForVision2Seq, trying specific class: {e}")
# Fallback to specific class
try:
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
processor = AutoProcessor.from_pretrained(
model_name, trust_remote_code=trust_remote_code
)
model = Qwen2VLForConditionalGeneration.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=trust_remote_code,
)
return processor, model
except Exception as e2:
raise ImportError(f"Failed to load Qwen2.5-VL model: {e2}")
def generate_qwen_vl(processor, model, prompt, image_path=None, max_tokens=512):
"""Generate with Qwen2.5-VL multimodal model"""
from PIL import Image
# Prepare inputs
if image_path:
image = Image.open(image_path)
inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
else:
inputs = processor(text=prompt, return_tensors="pt").to(model.device)
# Generate
with torch.no_grad():
generated_ids = model.generate(
**inputs, max_new_tokens=max_tokens, do_sample=False, temperature=0.1
)
# Decode response
generated_ids = generated_ids[:, inputs["input_ids"].shape[1] :]
response = processor.decode(generated_ids[0], skip_special_tokens=True)
return response
def create_multimodal_prompt(context, query, image_descriptions, task_type="images"):
"""Create prompt for multimodal RAG"""
if task_type == "images":
return f"""Based on the retrieved images and their descriptions, answer the following question.
Retrieved Image Descriptions:
{context}
Question: {query}
Provide a detailed answer based on the visual content described above."""
return f"Context: {context}\nQuestion: {query}\nAnswer:"
def evaluate_multimodal_rag(searcher, queries, processor=None, model=None, complexity=64):
"""Evaluate multimodal RAG with Qwen2.5-VL"""
search_times = []
gen_times = []
results = []
for i, query_item in enumerate(queries):
# Handle both string and dict formats for queries
if isinstance(query_item, dict):
query = query_item.get("query", "")
image_path = query_item.get("image_path") # Optional reference image
else:
query = str(query_item)
image_path = None
# Search
start_time = time.time()
search_results = searcher.search(query, top_k=3, complexity=complexity)
search_time = time.time() - start_time
search_times.append(search_time)
# Prepare context from search results
context_parts = []
for result in search_results:
context_parts.append(f"- {result.text}")
context = "\n".join(context_parts)
# Generate with multimodal model
start_time = time.time()
if processor and model:
prompt = create_multimodal_prompt(context, query, context_parts)
response = generate_qwen_vl(processor, model, prompt, image_path)
else:
response = f"Context: {context}"
gen_time = time.time() - start_time
gen_times.append(gen_time)
results.append(response)
if i < 3:
print(f"Q{i + 1}: Search={search_time:.3f}s, Gen={gen_time:.3f}s")
return {
"avg_search_time": sum(search_times) / len(search_times),
"avg_generation_time": sum(gen_times) / len(gen_times),
"results": results,
}

View File

@@ -12,7 +12,7 @@ import time
from pathlib import Path
import numpy as np
from leann.api import LeannBuilder, LeannChat, LeannSearcher
from leann.api import LeannBuilder, LeannSearcher
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
@@ -53,7 +53,7 @@ def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
print(
"Error: huggingface_hub is not installed. Please install it to download the data:"
)
print("uv sync --only-group dev")
print("uv pip install -e '.[dev]'")
sys.exit(1)
except Exception as e:
print(f"An error occurred during data download: {e}")
@@ -197,25 +197,6 @@ def main():
parser.add_argument(
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
)
parser.add_argument(
"--batch-size",
type=int,
default=0,
help="Batch size for HNSW batched search (0 disables batching)",
)
parser.add_argument(
"--llm-type",
type=str,
choices=["ollama", "hf", "openai", "gemini", "simulated"],
default="ollama",
help="LLM backend type to optionally query during evaluation (default: ollama)",
)
parser.add_argument(
"--llm-model",
type=str,
default="qwen3:1.7b",
help="LLM model identifier for the chosen backend (default: qwen3:1.7b)",
)
args = parser.parse_args()
# --- Path Configuration ---
@@ -337,24 +318,9 @@ def main():
for i in range(num_eval_queries):
start_time = time.time()
new_results = searcher.search(
queries[i],
top_k=args.top_k,
complexity=args.ef_search,
batch_size=args.batch_size,
)
new_results = searcher.search(queries[i], top_k=args.top_k, ef=args.ef_search)
search_times.append(time.time() - start_time)
# Optional: also call the LLM with configurable backend/model (does not affect recall)
llm_config = {"type": args.llm_type, "model": args.llm_model}
chat = LeannChat(args.index_path, llm_config=llm_config, searcher=searcher)
answer = chat.ask(
queries[i],
top_k=args.top_k,
complexity=args.ef_search,
batch_size=args.batch_size,
)
print(f"Answer: {answer}")
# Correct Recall Calculation: Based on TEXT content
new_texts = {result.text for result in new_results}

View File

@@ -20,7 +20,7 @@ except ImportError:
@dataclass
class BenchmarkConfig:
model_path: str = "facebook/contriever-msmarco"
model_path: str = "facebook/contriever"
batch_sizes: list[int] = None
seq_length: int = 256
num_runs: int = 5
@@ -34,7 +34,7 @@ class BenchmarkConfig:
def __post_init__(self):
if self.batch_sizes is None:
self.batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
self.batch_sizes = [1, 2, 4, 8, 16, 32, 64]
class MLXBenchmark:
@@ -179,16 +179,10 @@ class Benchmark:
def _run_inference(self, input_ids: torch.Tensor) -> float:
attention_mask = torch.ones_like(input_ids)
# print shape of input_ids and attention_mask
print(f"input_ids shape: {input_ids.shape}")
print(f"attention_mask shape: {attention_mask.shape}")
start_time = time.time()
with torch.no_grad():
self.model(input_ids=input_ids, attention_mask=attention_mask)
if torch.cuda.is_available():
torch.cuda.synchronize()
if torch.backends.mps.is_available():
torch.mps.synchronize()
end_time = time.time()
return end_time - start_time

View File

@@ -53,9 +53,9 @@ We use pre-commit hooks to ensure code quality and consistency. This runs automa
### Setup Pre-commit
1. **Install pre-commit tools**:
1. **Install pre-commit** (already included when you run `uv sync`):
```bash
uv sync lint
uv pip install pre-commit
```
2. **Install the git hooks**:
@@ -65,7 +65,7 @@ We use pre-commit hooks to ensure code quality and consistency. This runs automa
3. **Run pre-commit manually** (optional):
```bash
uv run pre-commit run --all-files
pre-commit run --all-files
```
### Pre-commit Checks
@@ -85,9 +85,6 @@ Our pre-commit configuration includes:
### Running Tests
```bash
# Install test tools only (no project runtime)
uv sync --group test
# Run all tests
uv run pytest

View File

@@ -1,143 +0,0 @@
# AST-Aware Code chunking guide
## Overview
This guide covers best practices for using AST-aware code chunking in LEANN. AST chunking provides better semantic understanding of code structure compared to traditional text-based chunking.
## Quick Start
### Basic Usage
```bash
# Enable AST chunking for mixed content (code + docs)
python -m apps.document_rag --enable-code-chunking --data-dir ./my_project
# Specialized code repository indexing
python -m apps.code_rag --repo-dir ./my_codebase
# Global CLI with AST support
leann build my-code-index --docs ./src --use-ast-chunking
```
### Installation
```bash
# Install LEANN with AST chunking support
uv pip install -e "."
```
#### For normal users (PyPI install)
- Use `pip install leann` or `uv pip install leann`.
- `astchunk` is pulled automatically from PyPI as a dependency; no extra steps.
#### For developers (from source, editable)
```bash
git clone https://github.com/yichuan-w/LEANN.git leann
cd leann
git submodule update --init --recursive
uv sync
```
- This repo vendors `astchunk` as a git submodule at `packages/astchunk-leann` (our fork).
- `[tool.uv.sources]` maps the `astchunk` package to that path in editable mode.
- You can edit code under `packages/astchunk-leann` and Python will use your changes immediately (no separate `pip install astchunk` needed).
## Best Practices
### When to Use AST Chunking
**Recommended for:**
- Code repositories with multiple languages
- Mixed documentation and code content
- Complex codebases with deep function/class hierarchies
- When working with Claude Code for code assistance
**Not recommended for:**
- Pure text documents
- Very large files (>1MB)
- Languages not supported by tree-sitter
### Optimal Configuration
```bash
# Recommended settings for most codebases
python -m apps.code_rag \
--repo-dir ./src \
--ast-chunk-size 768 \
--ast-chunk-overlap 96 \
--exclude-dirs .git __pycache__ node_modules build dist
```
### Supported Languages
| Extension | Language | Status |
|-----------|----------|--------|
| `.py` | Python | ✅ Full support |
| `.java` | Java | ✅ Full support |
| `.cs` | C# | ✅ Full support |
| `.ts`, `.tsx` | TypeScript | ✅ Full support |
| `.js`, `.jsx` | JavaScript | ✅ Via TypeScript parser |
## Integration Examples
### Document RAG with Code Support
```python
# Enable code chunking in document RAG
python -m apps.document_rag \
--enable-code-chunking \
--data-dir ./project \
--query "How does authentication work in the codebase?"
```
### Claude Code Integration
When using with Claude Code MCP server, AST chunking provides better context for:
- Code completion and suggestions
- Bug analysis and debugging
- Architecture understanding
- Refactoring assistance
## Troubleshooting
### Common Issues
1. **Fallback to Traditional Chunking**
- Normal behavior for unsupported languages
- Check logs for specific language support
2. **Performance with Large Files**
- Adjust `--max-file-size` parameter
- Use `--exclude-dirs` to skip unnecessary directories
3. **Quality Issues**
- Try different `--ast-chunk-size` values (512, 768, 1024)
- Adjust overlap for better context preservation
### Debug Mode
```bash
export LEANN_LOG_LEVEL=DEBUG
python -m apps.code_rag --repo-dir ./my_code
```
## Migration from Traditional Chunking
Existing workflows continue to work without changes. To enable AST chunking:
```bash
# Before
python -m apps.document_rag --chunk-size 256
# After (maintains traditional chunking for non-code files)
python -m apps.document_rag --enable-code-chunking --chunk-size 256 --ast-chunk-size 768
```
## References
- [astchunk GitHub Repository](https://github.com/yilinjz/astchunk)
- [LEANN MCP Integration](../packages/leann-mcp/README.md)
- [Research Paper](https://arxiv.org/html/2506.15655v1)
---
**Note**: AST chunking maintains full backward compatibility while enhancing code understanding capabilities.

View File

@@ -52,7 +52,7 @@ Based on our experience developing LEANN, embedding models fall into three categ
### Quick Start: Cloud and Local Embedding Options
**OpenAI Embeddings (Fastest Setup)**
For immediate testing without local model downloads(also if you [do not have GPU](https://github.com/yichuan-w/LEANN/issues/43) and do not care that much about your document leak, you should use this, we compute the embedding and recompute using openai API):
For immediate testing without local model downloads:
```bash
# Set OpenAI embeddings (requires OPENAI_API_KEY)
--embedding-mode openai --embedding-model text-embedding-3-small
@@ -83,81 +83,6 @@ ollama pull nomic-embed-text
</details>
## Local & Remote Inference Endpoints
> Applies to both LLMs (`leann ask`) and embeddings (`leann build`).
LEANN now treats Ollama, LM Studio, and other OpenAI-compatible runtimes as first-class providers. You can point LEANN at any compatible endpoint either on the same machine or across the network with a couple of flags or environment variables.
### One-Time Environment Setup
```bash
# Works for OpenAI-compatible runtimes such as LM Studio, vLLM, SGLang, llamafile, etc.
export OPENAI_API_KEY="your-key" # or leave unset for local servers that do not check keys
export OPENAI_BASE_URL="http://localhost:1234/v1"
# Ollama-compatible runtimes (Ollama, Ollama on another host, llamacpp-server, etc.)
export LEANN_OLLAMA_HOST="http://localhost:11434" # falls back to OLLAMA_HOST or LOCAL_LLM_ENDPOINT
```
LEANN also recognises `LEANN_LOCAL_LLM_HOST` (highest priority), `LEANN_OPENAI_BASE_URL`, and `LOCAL_OPENAI_BASE_URL`, so existing scripts continue to work.
### Passing Hosts Per Command
```bash
# Build an index with a remote embedding server
leann build my-notes \
--docs ./notes \
--embedding-mode openai \
--embedding-model text-embedding-qwen3-embedding-0.6b \
--embedding-api-base http://192.168.1.50:1234/v1 \
--embedding-api-key local-dev-key
# Query using a local LM Studio instance via OpenAI-compatible API
leann ask my-notes \
--llm openai \
--llm-model qwen3-8b \
--api-base http://localhost:1234/v1 \
--api-key local-dev-key
# Query an Ollama instance running on another box
leann ask my-notes \
--llm ollama \
--llm-model qwen3:14b \
--host http://192.168.1.101:11434
```
⚠️ **Make sure the endpoint is reachable**: when your inference server runs on a home/workstation and the index/search job runs in the cloud, the server must be able to reach the host you configured. Typical options include:
- Expose a public IP (and open the relevant port) on the machine that hosts LM Studio/Ollama.
- Configure router or cloud provider port forwarding.
- Tunnel traffic through tools like `tailscale`, `cloudflared`, or `ssh -R`.
When you set these options while building an index, LEANN stores them in `meta.json`. Any subsequent `leann ask` or searcher process automatically reuses the same provider settings even when we spawn background embedding servers. This makes the “server without GPU talking to my local workstation” workflow from [issue #80](https://github.com/yichuan-w/LEANN/issues/80#issuecomment-2287230548) work out-of-the-box.
**Tip:** If your runtime does not require an API key (many local stacks dont), leave `--api-key` unset. LEANN will skip injecting credentials.
### Python API Usage
You can pass the same configuration from Python:
```python
from leann.api import LeannBuilder
builder = LeannBuilder(
backend_name="hnsw",
embedding_mode="openai",
embedding_model="text-embedding-qwen3-embedding-0.6b",
embedding_options={
"base_url": "http://192.168.1.50:1234/v1",
"api_key": "local-dev-key",
},
)
builder.build_index("./indexes/my-notes", chunks)
```
`embedding_options` is persisted to the index `meta.json`, so subsequent `LeannSearcher` or `LeannChat` sessions automatically reuse the same provider settings (the embedding server manager forwards them to the provider for you).
## Index Selection: Matching Your Scale
### HNSW (Hierarchical Navigable Small World)
@@ -172,24 +97,16 @@ builder.build_index("./indexes/my-notes", chunks)
```
### DiskANN
**Best for**: Large datasets, especially when you want `recompute=True`.
**Key advantages:**
- **Faster search** on large datasets (3x+ speedup vs HNSW in many cases)
- **Smart storage**: `recompute=True` enables automatic graph partitioning for smaller indexes
- **Better scaling**: Designed for 100k+ documents
**Recompute behavior:**
- `recompute=True` (recommended): Pure PQ traversal + final reranking - faster and enables partitioning
- `recompute=False`: PQ + partial real distances during traversal - slower but higher accuracy
**Best for**: Large datasets (> 10M vectors, 10GB+ index size) - **⚠️ Beta version, still in active development**
- Uses Product Quantization (PQ) for coarse filtering during graph traversal
- Novel approach: stores only PQ codes, performs rerank with exact computation in final step
- Implements a corner case of double-queue: prunes all neighbors and recomputes at the end
```bash
# Recommended for most use cases
--backend-name diskann --graph-degree 32 --build-complexity 64
# For billion-scale deployments
--backend-name diskann --graph-degree 64 --build-complexity 128
```
**Performance Benchmark**: Run `uv run benchmarks/diskann_vs_hnsw_speed_comparison.py` to compare DiskANN and HNSW on your system.
## LLM Selection: Engine and Model Comparison
### LLM Engines
@@ -342,118 +259,27 @@ Every configuration choice involves trade-offs:
The key is finding the right balance for your specific use case. Start small and simple, measure performance, then scale up only where needed.
## Low-resource setups
## Deep Dive: Critical Configuration Decisions
If you dont have a local GPU or builds/searches are too slow, use one or more of the options below.
### When to Disable Recomputation
### 1) Use OpenAI embeddings (no local compute)
Fastest path with zero local GPU requirements. Set your API key and use OpenAI embeddings during build and search:
LEANN's recomputation feature provides exact distance calculations but can be disabled for extreme QPS requirements:
```bash
export OPENAI_API_KEY=sk-...
# Build with OpenAI embeddings
leann build my-index \
--embedding-mode openai \
--embedding-model text-embedding-3-small
# Search with OpenAI embeddings (recompute at query time)
leann search my-index "your query" \
--recompute
--no-recompute # Disable selective recomputation
```
### 2) Run remote builds with SkyPilot (cloud GPU)
Offload embedding generation and index building to a GPU VM using [SkyPilot](https://skypilot.readthedocs.io/en/latest/). A template is provided at `sky/leann-build.yaml`.
```bash
# One-time: install and configure SkyPilot
pip install skypilot
# Launch with defaults (L4:1) and mount ./data to ~/leann-data; the build runs automatically
sky launch -c leann-gpu sky/leann-build.yaml
# Override parameters via -e key=value (optional)
sky launch -c leann-gpu sky/leann-build.yaml \
-e index_name=my-index \
-e backend=hnsw \
-e embedding_mode=sentence-transformers \
-e embedding_model=Qwen/Qwen3-Embedding-0.6B
# Copy the built index back to your local .leann (use rsync)
rsync -Pavz leann-gpu:~/.leann/indexes/my-index ./.leann/indexes/
```
### 3) Disable recomputation to trade storage for speed
If you need lower latency and have more storage/memory, disable recomputation. This stores full embeddings and avoids recomputing at search time.
```bash
# Build without recomputation (HNSW requires non-compact in this mode)
leann build my-index --no-recompute --no-compact
# Search without recomputation
leann search my-index "your query" --no-recompute
```
When to use:
- Extreme low latency requirements (high QPS, interactive assistants)
- Read-heavy workloads where storage is cheaper than latency
- No always-available GPU
Constraints:
- HNSW: when `--no-recompute` is set, LEANN automatically disables compact mode during build
- DiskANN: supported; `--no-recompute` skips selective recompute during search
Storage impact:
- Storing N embeddings of dimension D with float32 requires approximately N × D × 4 bytes
- Example: 1,000,000 chunks × 768 dims × 4 bytes ≈ 2.86 GB (plus graph/metadata)
Converting an existing index (rebuild required):
```bash
# Rebuild in-place (ensure you still have original docs or can regenerate chunks)
leann build my-index --force --no-recompute --no-compact
```
Python API usage:
```python
from leann import LeannSearcher
searcher = LeannSearcher("/path/to/my-index.leann")
results = searcher.search("your query", top_k=10, recompute_embeddings=False)
```
Trade-offs:
- Lower latency and fewer network hops at query time
- Significantly higher storage (10100× vs selective recomputation)
- Slightly larger memory footprint during build and search
Quick benchmark results (`benchmarks/benchmark_no_recompute.py` with 5k texts, complexity=32):
- HNSW
```text
recompute=True: search_time=0.818s, size=1.1MB
recompute=False: search_time=0.012s, size=16.6MB
```
- DiskANN
```text
recompute=True: search_time=0.041s, size=5.9MB
recompute=False: search_time=0.013s, size=24.6MB
```
Conclusion:
- **HNSW**: `no-recompute` is significantly faster (no embedding recomputation) but requires much more storage (stores all embeddings)
- **DiskANN**: `no-recompute` uses PQ + partial real distances during traversal (slower but higher accuracy), while `recompute=True` uses pure PQ traversal + final reranking (faster traversal, enables build-time partitioning for smaller storage)
**Trade-offs**:
- **With recomputation** (default): Exact distances, best quality, higher latency, minimal storage (only stores metadata, recomputes embeddings on-demand)
- **Without recomputation**: Must store full embeddings, significantly higher memory and storage usage (10-100x more), but faster search
**Disable when**:
- You have abundant storage and memory
- Need extremely low latency (< 100ms)
- Running a read-heavy workload where storage cost is acceptable
## Further Reading
- [Lessons Learned Developing LEANN](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
- [LEANN Technical Paper](https://arxiv.org/abs/2506.08276)
- [DiskANN Original Paper](https://suhasjs.github.io/files/diskann_neurips19.pdf)
- [SSD-based Graph Partitioning](https://github.com/SonglinLife/SSD_BASED_PLAN)
- [DiskANN Original Paper](https://papers.nips.cc/paper/2019/file/09853c7fb1d3f8ee67a61b6bf4a7f8e6-Paper.pdf)

View File

@@ -3,7 +3,6 @@
## 🔥 Core Features
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage with dynamic computation using optimized ZMQ servers and highly optimized search paradigm (overlapping and batching) with highly optimized embedding engine
- **🧠 AST-Aware Code Chunking** - Intelligent code chunking that preserves semantic boundaries (functions, classes, methods) for Python, Java, C#, and TypeScript files
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
- **🏗️ Pluggable Backends** - HNSW/FAISS (default), with optional DiskANN for large-scale deployments

View File

@@ -1,149 +0,0 @@
# LEANN Grep Search Usage Guide
## Overview
LEANN's grep search functionality provides exact text matching for finding specific code patterns, error messages, function names, or exact phrases in your indexed documents.
## Basic Usage
### Simple Grep Search
```python
from leann.api import LeannSearcher
searcher = LeannSearcher("your_index_path")
# Exact text search
results = searcher.search("def authenticate_user", use_grep=True, top_k=5)
for result in results:
print(f"Score: {result.score}")
print(f"Text: {result.text[:100]}...")
print("-" * 40)
```
### Comparison: Semantic vs Grep Search
```python
# Semantic search - finds conceptually similar content
semantic_results = searcher.search("machine learning algorithms", top_k=3)
# Grep search - finds exact text matches
grep_results = searcher.search("def train_model", use_grep=True, top_k=3)
```
## When to Use Grep Search
### Use Cases
- **Code Search**: Finding specific function definitions, class names, or variable references
- **Error Debugging**: Locating exact error messages or stack traces
- **Documentation**: Finding specific API endpoints or exact terminology
### Examples
```python
# Find function definitions
functions = searcher.search("def __init__", use_grep=True)
# Find import statements
imports = searcher.search("from sklearn import", use_grep=True)
# Find specific error types
errors = searcher.search("FileNotFoundError", use_grep=True)
# Find TODO comments
todos = searcher.search("TODO:", use_grep=True)
# Find configuration entries
configs = searcher.search("server_port=", use_grep=True)
```
## Technical Details
### How It Works
1. **File Location**: Grep search operates on the raw text stored in `.jsonl` files
2. **Command Execution**: Uses the system `grep` command with case-insensitive search
3. **Result Processing**: Parses JSON lines and extracts text and metadata
4. **Scoring**: Simple frequency-based scoring based on query term occurrences
### Search Process
```
Query: "def train_model"
grep -i -n "def train_model" documents.leann.passages.jsonl
Parse matching JSON lines
Calculate scores based on term frequency
Return top_k results
```
### Scoring Algorithm
```python
# Term frequency in document
score = text.lower().count(query.lower())
```
Results are ranked by score (highest first), with higher scores indicating more occurrences of the search term.
## Error Handling
### Common Issues
#### Grep Command Not Found
```
RuntimeError: grep command not found. Please install grep or use semantic search.
```
**Solution**: Install grep on your system:
- **Ubuntu/Debian**: `sudo apt-get install grep`
- **macOS**: grep is pre-installed
- **Windows**: Use WSL or install grep via Git Bash/MSYS2
#### No Results Found
```python
# Check if your query exists in the raw data
results = searcher.search("your_query", use_grep=True)
if not results:
print("No exact matches found. Try:")
print("1. Check spelling and case")
print("2. Use partial terms")
print("3. Switch to semantic search")
```
## Complete Example
```python
#!/usr/bin/env python3
"""
Grep Search Example
Demonstrates grep search for exact text matching.
"""
from leann.api import LeannSearcher
def demonstrate_grep_search():
# Initialize searcher
searcher = LeannSearcher("my_index")
print("=== Function Search ===")
functions = searcher.search("def __init__", use_grep=True, top_k=5)
for i, result in enumerate(functions, 1):
print(f"{i}. Score: {result.score}")
print(f" Preview: {result.text[:60]}...")
print()
print("=== Error Search ===")
errors = searcher.search("FileNotFoundError", use_grep=True, top_k=3)
for result in errors:
print(f"Content: {result.text.strip()}")
print("-" * 40)
if __name__ == "__main__":
demonstrate_grep_search()
```

View File

@@ -1,300 +0,0 @@
# LEANN Metadata Filtering Usage Guide
## Overview
Leann possesses metadata filtering capabilities that allow you to filter search results based on arbitrary metadata fields set during chunking. This feature enables use cases like spoiler-free book search, document filtering by date/type, code search by file type, and potentially much more.
## Basic Usage
### Adding Metadata to Your Documents
When building your index, add metadata to each text chunk:
```python
from leann.api import LeannBuilder
builder = LeannBuilder("hnsw")
# Add text with metadata
builder.add_text(
text="Chapter 1: Alice falls down the rabbit hole",
metadata={
"chapter": 1,
"character": "Alice",
"themes": ["adventure", "curiosity"],
"word_count": 150
}
)
builder.build_index("alice_in_wonderland_index")
```
### Searching with Metadata Filters
Use the `metadata_filters` parameter in search calls:
```python
from leann.api import LeannSearcher
searcher = LeannSearcher("alice_in_wonderland_index")
# Search with filters
results = searcher.search(
query="What happens to Alice?",
top_k=10,
metadata_filters={
"chapter": {"<=": 5}, # Only chapters 1-5
"spoiler_level": {"!=": "high"} # No high spoilers
}
)
```
## Filter Syntax
### Basic Structure
```python
metadata_filters = {
"field_name": {"operator": value},
"another_field": {"operator": value}
}
```
### Supported Operators
#### Comparison Operators
- `"=="`: Equal to
- `"!="`: Not equal to
- `"<"`: Less than
- `"<="`: Less than or equal
- `">"`: Greater than
- `">="`: Greater than or equal
```python
# Examples
{"chapter": {"==": 1}} # Exactly chapter 1
{"page": {">": 100}} # Pages after 100
{"rating": {">=": 4.0}} # Rating 4.0 or higher
{"word_count": {"<": 500}} # Short passages
```
#### Membership Operators
- `"in"`: Value is in list
- `"not_in"`: Value is not in list
```python
# Examples
{"character": {"in": ["Alice", "Bob"]}} # Alice OR Bob
{"genre": {"not_in": ["horror", "thriller"]}} # Exclude genres
{"tags": {"in": ["fiction", "adventure"]}} # Any of these tags
```
#### String Operators
- `"contains"`: String contains substring
- `"starts_with"`: String starts with prefix
- `"ends_with"`: String ends with suffix
```python
# Examples
{"title": {"contains": "alice"}} # Title contains "alice"
{"filename": {"ends_with": ".py"}} # Python files
{"author": {"starts_with": "Dr."}} # Authors with "Dr." prefix
```
#### Boolean Operators
- `"is_true"`: Field is truthy
- `"is_false"`: Field is falsy
```python
# Examples
{"is_published": {"is_true": True}} # Published content
{"is_draft": {"is_false": False}} # Not drafts
```
### Multiple Operators on Same Field
You can apply multiple operators to the same field (AND logic):
```python
metadata_filters = {
"word_count": {
">=": 100, # At least 100 words
"<=": 500 # At most 500 words
}
}
```
### Compound Filters
Multiple fields are combined with AND logic:
```python
metadata_filters = {
"chapter": {"<=": 10}, # Up to chapter 10
"character": {"==": "Alice"}, # About Alice
"spoiler_level": {"!=": "high"} # No major spoilers
}
```
## Use Case Examples
### 1. Spoiler-Free Book Search
```python
# Reader has only read up to chapter 5
def search_spoiler_free(query, max_chapter):
return searcher.search(
query=query,
metadata_filters={
"chapter": {"<=": max_chapter},
"spoiler_level": {"in": ["none", "low"]}
}
)
results = search_spoiler_free("What happens to Alice?", max_chapter=5)
```
### 2. Document Management by Date
```python
# Find recent documents
recent_docs = searcher.search(
query="project updates",
metadata_filters={
"date": {">=": "2024-01-01"},
"document_type": {"==": "report"}
}
)
```
### 3. Code Search by File Type
```python
# Search only Python files
python_code = searcher.search(
query="authentication function",
metadata_filters={
"file_extension": {"==": ".py"},
"lines_of_code": {"<": 100}
}
)
```
### 4. Content Filtering by Audience
```python
# Age-appropriate content
family_content = searcher.search(
query="adventure stories",
metadata_filters={
"age_rating": {"in": ["G", "PG"]},
"content_warnings": {"not_in": ["violence", "adult_themes"]}
}
)
```
### 5. Multi-Book Series Management
```python
# Search across first 3 books only
early_series = searcher.search(
query="character development",
metadata_filters={
"series": {"==": "Harry Potter"},
"book_number": {"<=": 3}
}
)
```
## Running the Example
You can see metadata filtering in action with our spoiler-free book RAG example:
```bash
# Don't forget to set up the environment
uv venv
source .venv/bin/activate
# Set your OpenAI API key (required for embeddings, but you can update the example locally and use ollama instead)
export OPENAI_API_KEY="your-api-key-here"
# Run the spoiler-free book RAG example
uv run examples/spoiler_free_book_rag.py
```
This example demonstrates:
- Building an index with metadata (chapter numbers, characters, themes, locations)
- Searching with filters to avoid spoilers (e.g., only show results up to chapter 5)
- Different scenarios for readers at various points in the book
The example uses Alice's Adventures in Wonderland as sample data and shows how you can search for information without revealing plot points from later chapters.
## Advanced Patterns
### Custom Chunking with metadata
```python
def chunk_book_with_metadata(book_text, book_info):
chunks = []
for chapter_num, chapter_text in parse_chapters(book_text):
# Extract entities, themes, etc.
characters = extract_characters(chapter_text)
themes = classify_themes(chapter_text)
spoiler_level = assess_spoiler_level(chapter_text, chapter_num)
# Create chunks with rich metadata
for paragraph in split_paragraphs(chapter_text):
chunks.append({
"text": paragraph,
"metadata": {
"book_title": book_info["title"],
"chapter": chapter_num,
"characters": characters,
"themes": themes,
"spoiler_level": spoiler_level,
"word_count": len(paragraph.split()),
"reading_level": calculate_reading_level(paragraph)
}
})
return chunks
```
## Performance Considerations
### Efficient Filtering Strategies
1. **Post-search filtering**: Applies filters after vector search, which should be efficient for typical result sets (10-100 results).
2. **Metadata design**: Keep metadata fields simple and avoid deeply nested structures.
### Best Practices
1. **Consistent metadata schema**: Use consistent field names and value types across your documents.
2. **Reasonable metadata size**: Keep metadata reasonably sized to avoid storage overhead.
3. **Type consistency**: Use consistent data types for the same fields (e.g., always integers for chapter numbers).
4. **Index multiple granularities**: Consider chunking at different levels (paragraph, section, chapter) with appropriate metadata.
### Adding Metadata to Existing Indices
To add metadata filtering to existing indices, you'll need to rebuild them with metadata:
```python
# Read existing passages and add metadata
def add_metadata_to_existing_chunks(chunks):
for chunk in chunks:
# Extract or assign metadata based on content
chunk["metadata"] = extract_metadata(chunk["text"])
return chunks
# Rebuild index with metadata
enhanced_chunks = add_metadata_to_existing_chunks(existing_chunks)
builder = LeannBuilder("hnsw")
for chunk in enhanced_chunks:
builder.add_text(chunk["text"], chunk["metadata"])
builder.build_index("enhanced_index")
```

View File

View File

@@ -1,429 +0,0 @@
"""Dynamic HNSW update demo without compact storage.
This script reproduces the minimal scenario we used while debugging on-the-fly
recompute:
1. Build a non-compact HNSW index from the first few paragraphs of a text file.
2. Print the top results with `recompute_embeddings=True`.
3. Append additional paragraphs with :meth:`LeannBuilder.update_index`.
4. Run the same query again to show the newly inserted passages.
Run it with ``uv`` (optionally pointing LEANN_HNSW_LOG_PATH at a file to inspect
ZMQ activity)::
LEANN_HNSW_LOG_PATH=embedding_fetch.log \
uv run -m examples.dynamic_update_no_recompute \
--index-path .leann/examples/leann-demo.leann
By default the script builds an index from ``data/2501.14312v1 (1).pdf`` and
then updates it with LEANN-related material from ``data/2506.08276v1.pdf``.
It issues the query "What's LEANN?" before and after the update to show how the
new passages become immediately searchable. The script uses the
``sentence-transformers/all-MiniLM-L6-v2`` model with ``is_recompute=True`` so
Faiss pulls existing vectors on demand via the ZMQ embedding server, while
freshly added passages are embedded locally just like the initial build.
To make storage comparisons easy, the script can also build a matching
``is_recompute=False`` baseline (enabled by default) and report the index size
delta after the update. Disable the baseline run with
``--skip-compare-no-recompute`` if you only need the recompute flow.
"""
import argparse
import json
from collections.abc import Iterable
from pathlib import Path
from typing import Any
from leann.api import LeannBuilder, LeannSearcher
from leann.registry import register_project_directory
from apps.chunking import create_text_chunks
REPO_ROOT = Path(__file__).resolve().parents[1]
DEFAULT_QUERY = "What's LEANN?"
DEFAULT_INITIAL_FILES = [
REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
REPO_ROOT / "data" / "huawei_pangu.md",
REPO_ROOT / "data" / "PrideandPrejudice.txt",
]
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
def load_chunks_from_files(paths: list[Path]) -> list[str]:
from llama_index.core import SimpleDirectoryReader
documents = []
for path in paths:
p = path.expanduser().resolve()
if not p.exists():
raise FileNotFoundError(f"Input path not found: {p}")
if p.is_dir():
reader = SimpleDirectoryReader(str(p), recursive=False)
documents.extend(reader.load_data(show_progress=True))
else:
reader = SimpleDirectoryReader(input_files=[str(p)])
documents.extend(reader.load_data(show_progress=True))
if not documents:
return []
chunks = create_text_chunks(
documents,
chunk_size=512,
chunk_overlap=128,
use_ast_chunking=False,
)
return [c for c in chunks if isinstance(c, str) and c.strip()]
def run_search(index_path: Path, query: str, top_k: int, *, recompute_embeddings: bool) -> list:
searcher = LeannSearcher(str(index_path))
try:
return searcher.search(
query=query,
top_k=top_k,
recompute_embeddings=recompute_embeddings,
batch_size=16,
)
finally:
searcher.cleanup()
def print_results(title: str, results: Iterable) -> None:
print(f"\n=== {title} ===")
res_list = list(results)
print(f"results count: {len(res_list)}")
print("passages:")
if not res_list:
print(" (no passages returned)")
for res in res_list:
snippet = res.text.replace("\n", " ")[:120]
print(f" - {res.id}: {snippet}... (score={res.score:.4f})")
def build_initial_index(
index_path: Path,
paragraphs: list[str],
model_name: str,
embedding_mode: str,
is_recompute: bool,
) -> None:
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=model_name,
embedding_mode=embedding_mode,
is_compact=False,
is_recompute=is_recompute,
)
for idx, passage in enumerate(paragraphs):
builder.add_text(passage, metadata={"id": str(idx)})
builder.build_index(str(index_path))
def update_index(
index_path: Path,
start_id: int,
paragraphs: list[str],
model_name: str,
embedding_mode: str,
is_recompute: bool,
) -> None:
updater = LeannBuilder(
backend_name="hnsw",
embedding_model=model_name,
embedding_mode=embedding_mode,
is_compact=False,
is_recompute=is_recompute,
)
for offset, passage in enumerate(paragraphs, start=start_id):
updater.add_text(passage, metadata={"id": str(offset)})
updater.update_index(str(index_path))
def ensure_index_dir(index_path: Path) -> None:
index_path.parent.mkdir(parents=True, exist_ok=True)
def cleanup_index_files(index_path: Path) -> None:
"""Remove leftover index artifacts for a clean rebuild."""
parent = index_path.parent
if not parent.exists():
return
stem = index_path.stem
for file in parent.glob(f"{stem}*"):
if file.is_file():
file.unlink()
def index_file_size(index_path: Path) -> int:
"""Return the size of the primary .index file for the given index path."""
index_file = index_path.parent / f"{index_path.stem}.index"
return index_file.stat().st_size if index_file.exists() else 0
def load_metadata_snapshot(index_path: Path) -> dict[str, Any] | None:
meta_path = index_path.parent / f"{index_path.name}.meta.json"
if not meta_path.exists():
return None
try:
return json.loads(meta_path.read_text())
except json.JSONDecodeError:
return None
def run_workflow(
*,
label: str,
index_path: Path,
initial_paragraphs: list[str],
update_paragraphs: list[str],
model_name: str,
embedding_mode: str,
is_recompute: bool,
query: str,
top_k: int,
skip_search: bool,
) -> dict[str, Any]:
prefix = f"[{label}] " if label else ""
ensure_index_dir(index_path)
cleanup_index_files(index_path)
print(f"{prefix}Building initial index...")
build_initial_index(
index_path,
initial_paragraphs,
model_name,
embedding_mode,
is_recompute=is_recompute,
)
initial_size = index_file_size(index_path)
if not skip_search:
before_results = run_search(
index_path,
query,
top_k,
recompute_embeddings=is_recompute,
)
else:
before_results = None
print(f"\n{prefix}Updating index with additional passages...")
update_index(
index_path,
start_id=len(initial_paragraphs),
paragraphs=update_paragraphs,
model_name=model_name,
embedding_mode=embedding_mode,
is_recompute=is_recompute,
)
if not skip_search:
after_results = run_search(
index_path,
query,
top_k,
recompute_embeddings=is_recompute,
)
else:
after_results = None
updated_size = index_file_size(index_path)
return {
"initial_size": initial_size,
"updated_size": updated_size,
"delta": updated_size - initial_size,
"before_results": before_results if not skip_search else None,
"after_results": after_results if not skip_search else None,
"metadata": load_metadata_snapshot(index_path),
}
def main() -> None:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--initial-files",
type=Path,
nargs="+",
default=DEFAULT_INITIAL_FILES,
help="Initial document files (PDF/TXT) used to build the base index",
)
parser.add_argument(
"--index-path",
type=Path,
default=Path(".leann/examples/leann-demo.leann"),
help="Destination index path (default: .leann/examples/leann-demo.leann)",
)
parser.add_argument(
"--initial-count",
type=int,
default=8,
help="Number of chunks to use from the initial documents (default: 8)",
)
parser.add_argument(
"--update-files",
type=Path,
nargs="*",
default=DEFAULT_UPDATE_FILES,
help="Additional documents to add during update (PDF/TXT)",
)
parser.add_argument(
"--update-count",
type=int,
default=4,
help="Number of chunks to append from update documents (default: 4)",
)
parser.add_argument(
"--update-text",
type=str,
default=(
"LEANN (Lightweight Embedding ANN) is an indexing toolkit focused on "
"recompute-aware HNSW graphs, allowing embeddings to be regenerated "
"on demand to keep disk usage minimal."
),
help="Fallback text to append if --update-files is omitted",
)
parser.add_argument(
"--top-k",
type=int,
default=4,
help="Number of results to show for each search (default: 4)",
)
parser.add_argument(
"--query",
type=str,
default=DEFAULT_QUERY,
help="Query to run before/after the update",
)
parser.add_argument(
"--embedding-model",
type=str,
default="sentence-transformers/all-MiniLM-L6-v2",
help="Embedding model name",
)
parser.add_argument(
"--embedding-mode",
type=str,
default="sentence-transformers",
choices=["sentence-transformers", "openai", "mlx", "ollama"],
help="Embedding backend mode",
)
parser.add_argument(
"--compare-no-recompute",
dest="compare_no_recompute",
action="store_true",
help="Also run a baseline with is_recompute=False and report its index growth.",
)
parser.add_argument(
"--skip-compare-no-recompute",
dest="compare_no_recompute",
action="store_false",
help="Skip building the no-recompute baseline.",
)
parser.add_argument(
"--skip-search",
dest="skip_search",
action="store_true",
help="Skip the search step.",
)
parser.set_defaults(compare_no_recompute=True)
args = parser.parse_args()
ensure_index_dir(args.index_path)
register_project_directory(REPO_ROOT)
initial_chunks = load_chunks_from_files(list(args.initial_files))
if not initial_chunks:
raise ValueError("No text chunks extracted from the initial files.")
initial = initial_chunks[: args.initial_count]
if not initial:
raise ValueError("Initial chunk set is empty after applying --initial-count.")
if args.update_files:
update_chunks = load_chunks_from_files(list(args.update_files))
if not update_chunks:
raise ValueError("No text chunks extracted from the update files.")
to_add = update_chunks[: args.update_count]
else:
if not args.update_text:
raise ValueError("Provide --update-files or --update-text for the update step.")
to_add = [args.update_text]
if not to_add:
raise ValueError("Update chunk set is empty after applying --update-count.")
recompute_stats = run_workflow(
label="recompute",
index_path=args.index_path,
initial_paragraphs=initial,
update_paragraphs=to_add,
model_name=args.embedding_model,
embedding_mode=args.embedding_mode,
is_recompute=True,
query=args.query,
top_k=args.top_k,
skip_search=args.skip_search,
)
if not args.skip_search:
print_results("initial search", recompute_stats["before_results"])
if not args.skip_search:
print_results("after update", recompute_stats["after_results"])
print(
f"\n[recompute] Index file size change: {recompute_stats['initial_size']} -> {recompute_stats['updated_size']} bytes"
f"{recompute_stats['delta']})"
)
if recompute_stats["metadata"]:
meta_view = {k: recompute_stats["metadata"].get(k) for k in ("is_compact", "is_pruned")}
print("[recompute] metadata snapshot:")
print(json.dumps(meta_view, indent=2))
if args.compare_no_recompute:
baseline_path = (
args.index_path.parent / f"{args.index_path.stem}-norecompute{args.index_path.suffix}"
)
baseline_stats = run_workflow(
label="no-recompute",
index_path=baseline_path,
initial_paragraphs=initial,
update_paragraphs=to_add,
model_name=args.embedding_model,
embedding_mode=args.embedding_mode,
is_recompute=False,
query=args.query,
top_k=args.top_k,
skip_search=args.skip_search,
)
print(
f"\n[no-recompute] Index file size change: {baseline_stats['initial_size']} -> {baseline_stats['updated_size']} bytes"
f"{baseline_stats['delta']})"
)
after_texts = (
[res.text for res in recompute_stats["after_results"]] if not args.skip_search else None
)
baseline_after_texts = (
[res.text for res in baseline_stats["after_results"]] if not args.skip_search else None
)
if after_texts == baseline_after_texts:
print(
"[no-recompute] Search results match recompute baseline; see above for the shared output."
)
else:
print("[no-recompute] WARNING: search results differ from recompute baseline.")
if baseline_stats["metadata"]:
meta_view = {k: baseline_stats["metadata"].get(k) for k in ("is_compact", "is_pruned")}
print("[no-recompute] metadata snapshot:")
print(json.dumps(meta_view, indent=2))
if __name__ == "__main__":
main()

View File

@@ -1,35 +0,0 @@
"""
Grep Search Example
Shows how to use grep-based text search instead of semantic search.
Useful when you need exact text matches rather than meaning-based results.
"""
from leann import LeannSearcher
# Load your index
searcher = LeannSearcher("my-documents.leann")
# Regular semantic search
print("=== Semantic Search ===")
results = searcher.search("machine learning algorithms", top_k=3)
for result in results:
print(f"Score: {result.score:.3f}")
print(f"Text: {result.text[:80]}...")
print()
# Grep-based search for exact text matches
print("=== Grep Search ===")
results = searcher.search("def train_model", top_k=3, use_grep=True)
for result in results:
print(f"Score: {result.score}")
print(f"Text: {result.text[:80]}...")
print()
# Find specific error messages
error_results = searcher.search("FileNotFoundError", use_grep=True)
print(f"Found {len(error_results)} files mentioning FileNotFoundError")
# Search for function definitions
func_results = searcher.search("class SearchResult", use_grep=True, top_k=5)
print(f"Found {len(func_results)} class definitions")

View File

@@ -1,250 +0,0 @@
#!/usr/bin/env python3
"""
Spoiler-Free Book RAG Example using LEANN Metadata Filtering
This example demonstrates how to use LEANN's metadata filtering to create
a spoiler-free book RAG system where users can search for information
up to a specific chapter they've read.
Usage:
python spoiler_free_book_rag.py
"""
import os
import sys
from typing import Any, Optional
# Add LEANN to path (adjust path as needed)
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../packages/leann-core/src"))
from leann.api import LeannBuilder, LeannSearcher
def chunk_book_with_metadata(book_title: str = "Sample Book") -> list[dict[str, Any]]:
"""
Create sample book chunks with metadata for demonstration.
In a real implementation, this would parse actual book files (epub, txt, etc.)
and extract chapter boundaries, character mentions, etc.
Args:
book_title: Title of the book
Returns:
List of chunk dictionaries with text and metadata
"""
# Sample book chunks with metadata
# In practice, you'd use proper text processing libraries
sample_chunks = [
{
"text": "Alice was beginning to get very tired of sitting by her sister on the bank, and of having nothing to do.",
"metadata": {
"book": book_title,
"chapter": 1,
"page": 1,
"characters": ["Alice", "Sister"],
"themes": ["boredom", "curiosity"],
"location": "riverbank",
},
},
{
"text": "So she was considering in her own mind (as well as she could, for the hot day made her feel very sleepy and stupid), whether the pleasure of making a daisy-chain would be worth the trouble of getting up and picking the daisies, when suddenly a White Rabbit with pink eyes ran close by her.",
"metadata": {
"book": book_title,
"chapter": 1,
"page": 2,
"characters": ["Alice", "White Rabbit"],
"themes": ["decision", "surprise", "magic"],
"location": "riverbank",
},
},
{
"text": "Alice found herself falling down a very deep well. Either the well was very deep, or she fell very slowly, for she had plenty of time as she fell to look about her and to wonder what was going to happen next.",
"metadata": {
"book": book_title,
"chapter": 2,
"page": 15,
"characters": ["Alice"],
"themes": ["falling", "wonder", "transformation"],
"location": "rabbit hole",
},
},
{
"text": "Alice meets the Cheshire Cat, who tells her that everyone in Wonderland is mad, including Alice herself.",
"metadata": {
"book": book_title,
"chapter": 6,
"page": 85,
"characters": ["Alice", "Cheshire Cat"],
"themes": ["madness", "philosophy", "identity"],
"location": "Duchess's house",
},
},
{
"text": "At the Queen's croquet ground, Alice witnesses the absurd trial that reveals the arbitrary nature of Wonderland's justice system.",
"metadata": {
"book": book_title,
"chapter": 8,
"page": 120,
"characters": ["Alice", "Queen of Hearts", "King of Hearts"],
"themes": ["justice", "absurdity", "authority"],
"location": "Queen's court",
},
},
{
"text": "Alice realizes that Wonderland was all a dream, even the Rabbit, as she wakes up on the riverbank next to her sister.",
"metadata": {
"book": book_title,
"chapter": 12,
"page": 180,
"characters": ["Alice", "Sister", "Rabbit"],
"themes": ["revelation", "reality", "growth"],
"location": "riverbank",
},
},
]
return sample_chunks
def build_spoiler_free_index(book_chunks: list[dict[str, Any]], index_name: str) -> str:
"""
Build a LEANN index with book chunks that include spoiler metadata.
Args:
book_chunks: List of book chunks with metadata
index_name: Name for the index
Returns:
Path to the built index
"""
print(f"📚 Building spoiler-free book index: {index_name}")
# Initialize LEANN builder
builder = LeannBuilder(
backend_name="hnsw", embedding_model="text-embedding-3-small", embedding_mode="openai"
)
# Add each chunk with its metadata
for chunk in book_chunks:
builder.add_text(text=chunk["text"], metadata=chunk["metadata"])
# Build the index
index_path = f"{index_name}_book_index"
builder.build_index(index_path)
print(f"✅ Index built successfully: {index_path}")
return index_path
def spoiler_free_search(
index_path: str,
query: str,
max_chapter: int,
character_filter: Optional[list[str]] = None,
) -> list[dict[str, Any]]:
"""
Perform a spoiler-free search on the book index.
Args:
index_path: Path to the LEANN index
query: Search query
max_chapter: Maximum chapter number to include
character_filter: Optional list of characters to focus on
Returns:
List of search results safe for the reader
"""
print(f"🔍 Searching: '{query}' (up to chapter {max_chapter})")
searcher = LeannSearcher(index_path)
metadata_filters = {"chapter": {"<=": max_chapter}}
if character_filter:
metadata_filters["characters"] = {"contains": character_filter[0]}
results = searcher.search(query=query, top_k=10, metadata_filters=metadata_filters)
return results
def demo_spoiler_free_rag():
"""
Demonstrate the spoiler-free book RAG system.
"""
print("🎭 Spoiler-Free Book RAG Demo")
print("=" * 40)
# Step 1: Prepare book data
book_title = "Alice's Adventures in Wonderland"
book_chunks = chunk_book_with_metadata(book_title)
print(f"📖 Loaded {len(book_chunks)} chunks from '{book_title}'")
# Step 2: Build the index (in practice, this would be done once)
try:
index_path = build_spoiler_free_index(book_chunks, "alice_wonderland")
except Exception as e:
print(f"❌ Failed to build index (likely missing dependencies): {e}")
print(
"💡 This demo shows the filtering logic - actual indexing requires LEANN dependencies"
)
return
# Step 3: Demonstrate various spoiler-free searches
search_scenarios = [
{
"description": "Reader who has only read Chapter 1",
"query": "What can you tell me about the rabbit?",
"max_chapter": 1,
},
{
"description": "Reader who has read up to Chapter 5",
"query": "Tell me about Alice's adventures",
"max_chapter": 5,
},
{
"description": "Reader who has read most of the book",
"query": "What does the Cheshire Cat represent?",
"max_chapter": 10,
},
{
"description": "Reader who has read the whole book",
"query": "What can you tell me about the rabbit?",
"max_chapter": 12,
},
]
for scenario in search_scenarios:
print(f"\n📚 Scenario: {scenario['description']}")
print(f" Query: {scenario['query']}")
try:
results = spoiler_free_search(
index_path=index_path,
query=scenario["query"],
max_chapter=scenario["max_chapter"],
)
print(f" 📄 Found {len(results)} results:")
for i, result in enumerate(results[:3], 1): # Show top 3
chapter = result.metadata.get("chapter", "?")
location = result.metadata.get("location", "?")
print(f" {i}. Chapter {chapter} ({location}): {result.text[:80]}...")
except Exception as e:
print(f" ❌ Search failed: {e}")
if __name__ == "__main__":
print("📚 LEANN Spoiler-Free Book RAG Example")
print("=====================================")
try:
demo_spoiler_free_rag()
except ImportError as e:
print(f"❌ Cannot run demo due to missing dependencies: {e}")
except Exception as e:
print(f"❌ Error running demo: {e}")

View File

@@ -1,28 +0,0 @@
# llms.txt — LEANN MCP and Agent Integration
product: LEANN
homepage: https://github.com/yichuan-w/LEANN
contact: https://github.com/yichuan-w/LEANN/issues
# Installation
install: uv tool install leann-core --with leann
# MCP Server Entry Point
mcp.server: leann_mcp
mcp.protocol_version: 2024-11-05
# Tools
mcp.tools: leann_list, leann_search
mcp.tool.leann_list.description: List available LEANN indexes
mcp.tool.leann_list.input: {}
mcp.tool.leann_search.description: Semantic search across a named LEANN index
mcp.tool.leann_search.input.index_name: string, required
mcp.tool.leann_search.input.query: string, required
mcp.tool.leann_search.input.top_k: integer, optional, default=5, min=1, max=20
mcp.tool.leann_search.input.complexity: integer, optional, default=32, min=16, max=128
# Notes
note: Build indexes with `leann build <name> --docs <files...>` before searching.
example.add: claude mcp add --scope user leann-server -- leann_mcp
example.verify: claude mcp list | cat

View File

@@ -1,7 +1 @@
from . import diskann_backend as diskann_backend
from . import graph_partition
# Export main classes and functions
from .graph_partition import GraphPartitioner, partition_graph
__all__ = ["GraphPartitioner", "diskann_backend", "graph_partition", "partition_graph"]

View File

@@ -22,11 +22,6 @@ logger = logging.getLogger(__name__)
@contextlib.contextmanager
def suppress_cpp_output_if_needed():
"""Suppress C++ stdout/stderr based on LEANN_LOG_LEVEL"""
# In CI we avoid fiddling with low-level file descriptors to prevent aborts
if os.getenv("CI") == "true":
yield
return
log_level = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
# Only suppress if log level is WARNING or higher (ERROR, CRITICAL)
@@ -142,71 +137,6 @@ class DiskannBuilder(LeannBackendBuilderInterface):
def __init__(self, **kwargs):
self.build_params = kwargs
def _safe_cleanup_after_partition(self, index_dir: Path, index_prefix: str):
"""
Safely cleanup files after partition.
In partition mode, C++ doesn't read _disk.index content,
so we can delete it if all derived files exist.
"""
disk_index_file = index_dir / f"{index_prefix}_disk.index"
beam_search_file = index_dir / f"{index_prefix}_disk_beam_search.index"
# Required files that C++ partition mode needs
# Note: C++ generates these with _disk.index suffix
disk_suffix = "_disk.index"
required_files = [
f"{index_prefix}{disk_suffix}_medoids.bin", # Critical: assert fails if missing
# Note: _centroids.bin is not created in single-shot build - C++ handles this automatically
f"{index_prefix}_pq_pivots.bin", # PQ table
f"{index_prefix}_pq_compressed.bin", # PQ compressed vectors
]
# Check if all required files exist
missing_files = []
for filename in required_files:
file_path = index_dir / filename
if not file_path.exists():
missing_files.append(filename)
if missing_files:
logger.warning(
f"Cannot safely delete _disk.index - missing required files: {missing_files}"
)
logger.info("Keeping all original files for safety")
return
# Calculate space savings
space_saved = 0
files_to_delete = []
if disk_index_file.exists():
space_saved += disk_index_file.stat().st_size
files_to_delete.append(disk_index_file)
if beam_search_file.exists():
space_saved += beam_search_file.stat().st_size
files_to_delete.append(beam_search_file)
# Safe to delete!
for file_to_delete in files_to_delete:
try:
os.remove(file_to_delete)
logger.info(f"✅ Safely deleted: {file_to_delete.name}")
except Exception as e:
logger.warning(f"Failed to delete {file_to_delete.name}: {e}")
if space_saved > 0:
space_saved_mb = space_saved / (1024 * 1024)
logger.info(f"💾 Space saved: {space_saved_mb:.1f} MB")
# Show what files are kept
logger.info("📁 Kept essential files for partition mode:")
for filename in required_files:
file_path = index_dir / filename
if file_path.exists():
size_mb = file_path.stat().st_size / (1024 * 1024)
logger.info(f" - {filename} ({size_mb:.1f} MB)")
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
path = Path(index_path)
index_dir = path.parent
@@ -221,17 +151,6 @@ class DiskannBuilder(LeannBackendBuilderInterface):
_write_vectors_to_bin(data, index_dir / data_filename)
build_kwargs = {**self.build_params, **kwargs}
# Extract is_recompute from nested backend_kwargs if needed
is_recompute = build_kwargs.get("is_recompute", False)
if not is_recompute and "backend_kwargs" in build_kwargs:
is_recompute = build_kwargs["backend_kwargs"].get("is_recompute", False)
# Flatten all backend_kwargs parameters to top level for compatibility
if "backend_kwargs" in build_kwargs:
nested_params = build_kwargs.pop("backend_kwargs")
build_kwargs.update(nested_params)
metric_enum = _get_diskann_metrics().get(
build_kwargs.get("distance_metric", "mips").lower()
)
@@ -266,30 +185,6 @@ class DiskannBuilder(LeannBackendBuilderInterface):
build_kwargs.get("pq_disk_bytes", 0),
"",
)
# Auto-partition if is_recompute is enabled
if build_kwargs.get("is_recompute", False):
logger.info("is_recompute=True, starting automatic graph partitioning...")
from .graph_partition import partition_graph
# Partition the index using absolute paths
# Convert to absolute paths to avoid issues with working directory changes
absolute_index_dir = Path(index_dir).resolve()
absolute_index_prefix_path = str(absolute_index_dir / index_prefix)
disk_graph_path, partition_bin_path = partition_graph(
index_prefix_path=absolute_index_prefix_path,
output_dir=str(absolute_index_dir),
partition_prefix=index_prefix,
)
# Safe cleanup: In partition mode, C++ doesn't read _disk.index content
# but still needs the derived files (_medoids.bin, _centroids.bin, etc.)
self._safe_cleanup_after_partition(index_dir, index_prefix)
logger.info("✅ Graph partitioning completed successfully!")
logger.info(f" - Disk graph: {disk_graph_path}")
logger.info(f" - Partition file: {partition_bin_path}")
finally:
temp_data_file = index_dir / data_filename
if temp_data_file.exists():
@@ -318,42 +213,16 @@ class DiskannSearcher(BaseSearcher):
# For DiskANN, we need to reinitialize the index when zmq_port changes
# Store the initialization parameters for later use
# Note: C++ load method expects the BASE path (without _disk.index suffix)
# C++ internally constructs: index_prefix + "_disk.index"
index_name = self.index_path.stem # "simple_test.leann" -> "simple_test"
diskann_index_prefix = str(self.index_dir / index_name) # /path/to/simple_test
full_index_prefix = diskann_index_prefix # /path/to/simple_test (base path)
# Auto-detect partition files and set partition_prefix
partition_graph_file = self.index_dir / f"{index_name}_disk_graph.index"
partition_bin_file = self.index_dir / f"{index_name}_partition.bin"
partition_prefix = ""
if partition_graph_file.exists() and partition_bin_file.exists():
# C++ expects full path prefix, not just filename
partition_prefix = str(self.index_dir / index_name) # /path/to/simple_test
logger.info(
f"✅ Detected partition files, using partition_prefix='{partition_prefix}'"
)
else:
logger.debug("No partition files detected, using standard index files")
full_index_prefix = str(self.index_dir / self.index_path.stem)
self._init_params = {
"metric_enum": metric_enum,
"full_index_prefix": full_index_prefix,
"num_threads": self.num_threads,
"num_nodes_to_cache": kwargs.get("num_nodes_to_cache", 0),
# 1 -> initialize cache using sample_data; 2 -> ready cache without init; others disable cache
"cache_mechanism": kwargs.get("cache_mechanism", 1),
"cache_mechanism": 1,
"pq_prefix": "",
"partition_prefix": partition_prefix,
"partition_prefix": "",
}
# Log partition configuration for debugging
if partition_prefix:
logger.info(
f"✅ Detected partition files, using partition_prefix='{partition_prefix}'"
)
self._diskannpy = diskannpy
self._current_zmq_port = None
self._index = None
@@ -442,14 +311,9 @@ class DiskannSearcher(BaseSearcher):
else: # "global"
use_global_pruning = True
# Strategy:
# - Traversal always uses PQ distances
# - If recompute_embeddings=True, do a single final rerank via deferred fetch
# (fetch embeddings for the final candidate set only)
# - Do not recompute neighbor distances along the path
use_deferred_fetch = True if recompute_embeddings else False
recompute_neighors = False # Expected typo. For backward compatibility.
# Perform search with suppressed C++ output based on log level
use_deferred_fetch = kwargs.get("USE_DEFERRED_FETCH", True)
recompute_neighors = False
with suppress_cpp_output_if_needed():
labels, distances = self._index.batch_search(
query,

View File

@@ -10,7 +10,7 @@ import sys
import threading
import time
from pathlib import Path
from typing import Any, Optional
from typing import Optional
import numpy as np
import zmq
@@ -32,16 +32,6 @@ if not logger.handlers:
logger.propagate = False
_RAW_PROVIDER_OPTIONS = os.getenv("LEANN_EMBEDDING_OPTIONS")
try:
PROVIDER_OPTIONS: dict[str, Any] = (
json.loads(_RAW_PROVIDER_OPTIONS) if _RAW_PROVIDER_OPTIONS else {}
)
except json.JSONDecodeError:
logger.warning("Failed to parse LEANN_EMBEDDING_OPTIONS; ignoring provider options")
PROVIDER_OPTIONS = {}
def create_diskann_embedding_server(
passages_file: Optional[str] = None,
zmq_port: int = 5555,
@@ -91,9 +81,10 @@ def create_diskann_embedding_server(
with open(passages_file) as f:
meta = json.load(f)
logger.info(f"Loading PassageManager with metadata_file_path: {passages_file}")
passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file)
logger.info(f"Loaded PassageManager with {len(passages)} passages from metadata")
passages = PassageManager(meta["passage_sources"])
logger.info(
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
)
# Import protobuf after ensuring the path is correct
try:
@@ -111,9 +102,8 @@ def create_diskann_embedding_server(
socket.bind(f"tcp://*:{zmq_port}")
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
socket.setsockopt(zmq.RCVTIMEO, 1000)
socket.setsockopt(zmq.SNDTIMEO, 1000)
socket.setsockopt(zmq.LINGER, 0)
socket.setsockopt(zmq.RCVTIMEO, 300000)
socket.setsockopt(zmq.SNDTIMEO, 300000)
while True:
try:
@@ -191,12 +181,7 @@ def create_diskann_embedding_server(
logger.debug(f"Text lengths: {[len(t) for t in texts[:5]]}") # Show first 5
# Process embeddings using unified computation
embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
@@ -235,222 +220,30 @@ def create_diskann_embedding_server(
traceback.print_exc()
raise
def zmq_server_thread_with_shutdown(shutdown_event):
"""ZMQ server thread that respects shutdown signal.
This creates its own REP socket, binds to zmq_port, and periodically
checks shutdown_event using recv timeouts to exit cleanly.
"""
logger.info("DiskANN ZMQ server thread started with shutdown support")
context = zmq.Context()
rep_socket = context.socket(zmq.REP)
rep_socket.bind(f"tcp://*:{zmq_port}")
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
# Set receive timeout so we can check shutdown_event periodically
rep_socket.setsockopt(zmq.RCVTIMEO, 1000) # 1 second timeout
rep_socket.setsockopt(zmq.SNDTIMEO, 1000)
rep_socket.setsockopt(zmq.LINGER, 0)
try:
while not shutdown_event.is_set():
try:
e2e_start = time.time()
# REP socket receives single-part messages
message = rep_socket.recv()
# Check for empty messages - REP socket requires response to every request
if not message:
logger.warning("Received empty message, sending empty response")
rep_socket.send(b"")
continue
# Try protobuf first (same logic as original)
texts = []
is_text_request = False
try:
req_proto = embedding_pb2.NodeEmbeddingRequest()
req_proto.ParseFromString(message)
node_ids = list(req_proto.node_ids)
# Look up texts by node IDs
for nid in node_ids:
try:
passage_data = passages.get_passage(str(nid))
txt = passage_data["text"]
if not txt:
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
texts.append(txt)
except KeyError:
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
logger.info(f"ZMQ received protobuf request for {len(node_ids)} node IDs")
except Exception:
# Fallback to msgpack for text requests
try:
import msgpack
request = msgpack.unpackb(message)
if isinstance(request, list) and all(
isinstance(item, str) for item in request
):
texts = request
is_text_request = True
logger.info(
f"ZMQ received msgpack text request for {len(texts)} texts"
)
else:
raise ValueError("Not a valid msgpack text request")
except Exception:
logger.error("Both protobuf and msgpack parsing failed!")
# Send error response
resp_proto = embedding_pb2.NodeEmbeddingResponse()
rep_socket.send(resp_proto.SerializeToString())
continue
# Process the request
embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(f"Computed embeddings shape: {embeddings.shape}")
# Validation
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
logger.error("NaN or Inf detected in embeddings!")
# Send error response
if is_text_request:
import msgpack
response_data = msgpack.packb([])
else:
resp_proto = embedding_pb2.NodeEmbeddingResponse()
response_data = resp_proto.SerializeToString()
rep_socket.send(response_data)
continue
# Prepare response based on request type
if is_text_request:
# For direct text requests, return msgpack
import msgpack
response_data = msgpack.packb(embeddings.tolist())
else:
# For protobuf requests, return protobuf
resp_proto = embedding_pb2.NodeEmbeddingResponse()
hidden_contiguous = np.ascontiguousarray(embeddings, dtype=np.float32)
resp_proto.embeddings_data = hidden_contiguous.tobytes()
resp_proto.dimensions.append(hidden_contiguous.shape[0])
resp_proto.dimensions.append(hidden_contiguous.shape[1])
response_data = resp_proto.SerializeToString()
# Send response back to the client
rep_socket.send(response_data)
e2e_end = time.time()
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
except zmq.Again:
# Timeout - check shutdown_event and continue
continue
except Exception as e:
if not shutdown_event.is_set():
logger.error(f"Error in ZMQ server loop: {e}")
try:
# Send error response for REP socket
resp_proto = embedding_pb2.NodeEmbeddingResponse()
rep_socket.send(resp_proto.SerializeToString())
except Exception:
pass
else:
logger.info("Shutdown in progress, ignoring ZMQ error")
break
finally:
try:
rep_socket.close(0)
except Exception:
pass
try:
context.term()
except Exception:
pass
logger.info("DiskANN ZMQ server thread exiting gracefully")
# Add shutdown coordination
shutdown_event = threading.Event()
def shutdown_zmq_server():
"""Gracefully shutdown ZMQ server."""
logger.info("Initiating graceful shutdown...")
shutdown_event.set()
if zmq_thread.is_alive():
logger.info("Waiting for ZMQ thread to finish...")
zmq_thread.join(timeout=5)
if zmq_thread.is_alive():
logger.warning("ZMQ thread did not finish in time")
# Clean up ZMQ resources
try:
# Note: socket and context are cleaned up by thread exit
logger.info("ZMQ resources cleaned up")
except Exception as e:
logger.warning(f"Error cleaning ZMQ resources: {e}")
# Clean up other resources
try:
import gc
gc.collect()
logger.info("Additional resources cleaned up")
except Exception as e:
logger.warning(f"Error cleaning additional resources: {e}")
logger.info("Graceful shutdown completed")
sys.exit(0)
# Register signal handlers within this function scope
import signal
def signal_handler(sig, frame):
logger.info(f"Received signal {sig}, shutting down gracefully...")
shutdown_zmq_server()
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
# Start ZMQ thread (NOT daemon!)
zmq_thread = threading.Thread(
target=lambda: zmq_server_thread_with_shutdown(shutdown_event),
daemon=False, # Not daemon - we want to wait for it
)
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
zmq_thread.start()
logger.info(f"Started DiskANN ZMQ server thread on port {zmq_port}")
# Keep the main thread alive
try:
while not shutdown_event.is_set():
time.sleep(0.1) # Check shutdown more frequently
while True:
time.sleep(1)
except KeyboardInterrupt:
logger.info("DiskANN Server shutting down...")
shutdown_zmq_server()
return
# If we reach here, shutdown was triggered by signal
logger.info("Main loop exited, process should be shutting down")
if __name__ == "__main__":
import signal
import sys
# Signal handlers are now registered within create_diskann_embedding_server
def signal_handler(sig, frame):
logger.info(f"Received signal {sig}, shutting down gracefully...")
sys.exit(0)
# Register signal handlers for graceful shutdown
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
parser = argparse.ArgumentParser(description="DiskANN Embedding service")
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")

View File

@@ -1,299 +0,0 @@
#!/usr/bin/env python3
"""
Graph Partition Module for LEANN DiskANN Backend
This module provides Python bindings for the graph partition functionality
of DiskANN, allowing users to partition disk-based indices for better
performance.
"""
import os
import shutil
import subprocess
import tempfile
from pathlib import Path
from typing import Optional
class GraphPartitioner:
"""
A Python interface for DiskANN's graph partition functionality.
This class provides methods to partition disk-based indices for improved
search performance and memory efficiency.
"""
def __init__(self, build_type: str = "release"):
"""
Initialize the GraphPartitioner.
Args:
build_type: Build type for the executables ("debug" or "release")
"""
self.build_type = build_type
self._ensure_executables()
def _get_executable_path(self, name: str) -> str:
"""Get the path to a graph partition executable."""
# Get the directory where this Python module is located
module_dir = Path(__file__).parent
# Navigate to the graph_partition directory
graph_partition_dir = module_dir.parent / "third_party" / "DiskANN" / "graph_partition"
executable_path = graph_partition_dir / "build" / self.build_type / "graph_partition" / name
if not executable_path.exists():
raise FileNotFoundError(f"Executable {name} not found at {executable_path}")
return str(executable_path)
def _ensure_executables(self):
"""Ensure that the required executables are built."""
try:
self._get_executable_path("partitioner")
self._get_executable_path("index_relayout")
except FileNotFoundError:
# Try to build the executables automatically
print("Executables not found, attempting to build them...")
self._build_executables()
def _build_executables(self):
"""Build the required executables."""
graph_partition_dir = (
Path(__file__).parent.parent / "third_party" / "DiskANN" / "graph_partition"
)
original_dir = os.getcwd()
try:
os.chdir(graph_partition_dir)
# Clean any existing build
if (graph_partition_dir / "build").exists():
shutil.rmtree(graph_partition_dir / "build")
# Run the build script
cmd = ["./build.sh", self.build_type, "split_graph", "/tmp/dummy"]
subprocess.run(cmd, capture_output=True, text=True, cwd=graph_partition_dir)
# Check if executables were created
partitioner_path = self._get_executable_path("partitioner")
relayout_path = self._get_executable_path("index_relayout")
print(f"✅ Built partitioner: {partitioner_path}")
print(f"✅ Built index_relayout: {relayout_path}")
except Exception as e:
raise RuntimeError(f"Failed to build executables: {e}")
finally:
os.chdir(original_dir)
def partition_graph(
self,
index_prefix_path: str,
output_dir: Optional[str] = None,
partition_prefix: Optional[str] = None,
**kwargs,
) -> tuple[str, str]:
"""
Partition a disk-based index for improved performance.
Args:
index_prefix_path: Path to the index prefix (e.g., "/path/to/index")
output_dir: Output directory for results (defaults to parent of index_prefix_path)
partition_prefix: Prefix for output files (defaults to basename of index_prefix_path)
**kwargs: Additional parameters for graph partitioning:
- gp_times: Number of LDG partition iterations (default: 10)
- lock_nums: Number of lock nodes (default: 10)
- cut: Cut adjacency list degree (default: 100)
- scale_factor: Scale factor (default: 1)
- data_type: Data type (default: "float")
- thread_nums: Number of threads (default: 10)
Returns:
Tuple of (disk_graph_index_path, partition_bin_path)
Raises:
RuntimeError: If the partitioning process fails
"""
# Set default parameters
params = {
"gp_times": 10,
"lock_nums": 10,
"cut": 100,
"scale_factor": 1,
"data_type": "float",
"thread_nums": 10,
**kwargs,
}
# Determine output directory
if output_dir is None:
output_dir = str(Path(index_prefix_path).parent)
# Create output directory if it doesn't exist
Path(output_dir).mkdir(parents=True, exist_ok=True)
# Determine partition prefix
if partition_prefix is None:
partition_prefix = Path(index_prefix_path).name
# Get executable paths
partitioner_path = self._get_executable_path("partitioner")
relayout_path = self._get_executable_path("index_relayout")
# Create temporary directory for processing
with tempfile.TemporaryDirectory() as temp_dir:
# Change to the graph_partition directory for temporary files
graph_partition_dir = (
Path(__file__).parent.parent / "third_party" / "DiskANN" / "graph_partition"
)
original_dir = os.getcwd()
try:
os.chdir(graph_partition_dir)
# Create temporary data directory
temp_data_dir = Path(temp_dir) / "data"
temp_data_dir.mkdir(parents=True, exist_ok=True)
# Set up paths for temporary files
graph_path = temp_data_dir / "starling" / "_M_R_L_B" / "GRAPH"
graph_gp_path = (
graph_path
/ f"GP_TIMES_{params['gp_times']}_LOCK_{params['lock_nums']}_GP_USE_FREQ0_CUT{params['cut']}_SCALE{params['scale_factor']}"
)
graph_gp_path.mkdir(parents=True, exist_ok=True)
# Find input index file
old_index_file = f"{index_prefix_path}_disk_beam_search.index"
if not os.path.exists(old_index_file):
old_index_file = f"{index_prefix_path}_disk.index"
if not os.path.exists(old_index_file):
raise RuntimeError(f"Index file not found: {old_index_file}")
# Run partitioner
gp_file_path = graph_gp_path / "_part.bin"
partitioner_cmd = [
partitioner_path,
"--index_file",
old_index_file,
"--data_type",
params["data_type"],
"--gp_file",
str(gp_file_path),
"-T",
str(params["thread_nums"]),
"--ldg_times",
str(params["gp_times"]),
"--scale",
str(params["scale_factor"]),
"--mode",
"1",
]
print(f"Running partitioner: {' '.join(partitioner_cmd)}")
result = subprocess.run(
partitioner_cmd, capture_output=True, text=True, cwd=graph_partition_dir
)
if result.returncode != 0:
raise RuntimeError(
f"Partitioner failed with return code {result.returncode}.\n"
f"stdout: {result.stdout}\n"
f"stderr: {result.stderr}"
)
# Run relayout
part_tmp_index = graph_gp_path / "_part_tmp.index"
relayout_cmd = [
relayout_path,
old_index_file,
str(gp_file_path),
params["data_type"],
"1",
]
print(f"Running relayout: {' '.join(relayout_cmd)}")
result = subprocess.run(
relayout_cmd, capture_output=True, text=True, cwd=graph_partition_dir
)
if result.returncode != 0:
raise RuntimeError(
f"Relayout failed with return code {result.returncode}.\n"
f"stdout: {result.stdout}\n"
f"stderr: {result.stderr}"
)
# Copy results to output directory
disk_graph_path = Path(output_dir) / f"{partition_prefix}_disk_graph.index"
partition_bin_path = Path(output_dir) / f"{partition_prefix}_partition.bin"
shutil.copy2(part_tmp_index, disk_graph_path)
shutil.copy2(gp_file_path, partition_bin_path)
print(f"Results copied to: {output_dir}")
return str(disk_graph_path), str(partition_bin_path)
finally:
os.chdir(original_dir)
def get_partition_info(self, partition_bin_path: str) -> dict:
"""
Get information about a partition file.
Args:
partition_bin_path: Path to the partition binary file
Returns:
Dictionary containing partition information
"""
if not os.path.exists(partition_bin_path):
raise FileNotFoundError(f"Partition file not found: {partition_bin_path}")
# For now, return basic file information
# In the future, this could parse the binary file for detailed info
stat = os.stat(partition_bin_path)
return {
"file_size": stat.st_size,
"file_path": partition_bin_path,
"modified_time": stat.st_mtime,
}
def partition_graph(
index_prefix_path: str,
output_dir: Optional[str] = None,
partition_prefix: Optional[str] = None,
build_type: str = "release",
**kwargs,
) -> tuple[str, str]:
"""
Convenience function to partition a graph index.
Args:
index_prefix_path: Path to the index prefix
output_dir: Output directory (defaults to parent of index_prefix_path)
partition_prefix: Prefix for output files (defaults to basename of index_prefix_path)
build_type: Build type for executables ("debug" or "release")
**kwargs: Additional parameters for graph partitioning
Returns:
Tuple of (disk_graph_index_path, partition_bin_path)
"""
partitioner = GraphPartitioner(build_type=build_type)
return partitioner.partition_graph(index_prefix_path, output_dir, partition_prefix, **kwargs)
# Example usage:
if __name__ == "__main__":
# Example: partition an index
try:
disk_graph_path, partition_bin_path = partition_graph(
"/path/to/your/index_prefix", gp_times=10, lock_nums=10, cut=100
)
print("Partitioning completed successfully!")
print(f"Disk graph index: {disk_graph_path}")
print(f"Partition binary: {partition_bin_path}")
except Exception as e:
print(f"Partitioning failed: {e}")

View File

@@ -1,11 +1,11 @@
[build-system]
requires = ["scikit-build-core>=0.10", "pybind11>=2.12.0", "numpy", "cmake>=3.30"]
requires = ["scikit-build-core>=0.10", "pybind11>=2.12.0", "numpy"]
build-backend = "scikit_build_core.build"
[project]
name = "leann-backend-diskann"
version = "0.3.4"
dependencies = ["leann-core==0.3.4", "numpy", "protobuf>=3.19.0"]
version = "0.2.7"
dependencies = ["leann-core==0.2.7", "numpy", "protobuf>=3.19.0"]
[tool.scikit-build]
# Key: simplified CMake path

View File

@@ -13,7 +13,7 @@ if(APPLE)
else()
message(FATAL_ERROR "Could not find libomp installation. Please install with: brew install libomp")
endif()
set(OpenMP_C_FLAGS "-Xpreprocessor -fopenmp -I${HOMEBREW_PREFIX}/opt/libomp/include")
set(OpenMP_CXX_FLAGS "-Xpreprocessor -fopenmp -I${HOMEBREW_PREFIX}/opt/libomp/include")
set(OpenMP_C_LIB_NAMES "omp")
@@ -49,28 +49,9 @@ set(BUILD_TESTING OFF CACHE BOOL "" FORCE)
set(FAISS_ENABLE_C_API OFF CACHE BOOL "" FORCE)
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
# Disable x86-specific SIMD optimizations (important for ARM64 compatibility)
# Disable additional SIMD versions to speed up compilation
set(FAISS_ENABLE_AVX2 OFF CACHE BOOL "" FORCE)
set(FAISS_ENABLE_AVX512 OFF CACHE BOOL "" FORCE)
set(FAISS_ENABLE_SSE4_1 OFF CACHE BOOL "" FORCE)
# ARM64-specific configuration
if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64")
message(STATUS "Configuring Faiss for ARM64 architecture")
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
# Use SVE optimization level for ARM64 Linux (as seen in Faiss conda build)
set(FAISS_OPT_LEVEL "sve" CACHE STRING "" FORCE)
message(STATUS "Setting FAISS_OPT_LEVEL to 'sve' for ARM64 Linux")
else()
# Use generic optimization for other ARM64 platforms (like macOS)
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
message(STATUS "Setting FAISS_OPT_LEVEL to 'generic' for ARM64 ${CMAKE_SYSTEM_NAME}")
endif()
# ARM64 compatibility: Faiss submodule has been modified to fix x86 header inclusion
message(STATUS "Using ARM64-compatible Faiss submodule")
endif()
# Additional optimization options from INSTALL.md
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "" FORCE)

View File

@@ -1,21 +1,12 @@
import argparse
import gc # Import garbage collector interface
import logging
import os
import struct
import sys
import time
from dataclasses import dataclass
from typing import Any, Optional
import numpy as np
# Set up logging to avoid print buffer issues
logger = logging.getLogger(__name__)
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
logger.setLevel(log_level)
# --- FourCCs (add more if needed) ---
INDEX_HNSW_FLAT_FOURCC = int.from_bytes(b"IHNf", "little")
# Add other HNSW fourccs if you expect different storage types inside HNSW
@@ -239,288 +230,6 @@ def write_compact_format(
f_out.write(storage_data)
@dataclass
class HNSWComponents:
original_hnsw_data: dict[str, Any]
assign_probas_np: np.ndarray
cum_nneighbor_per_level_np: np.ndarray
levels_np: np.ndarray
is_compact: bool
compact_level_ptr: Optional[np.ndarray] = None
compact_node_offsets_np: Optional[np.ndarray] = None
compact_neighbors_data: Optional[list[int]] = None
offsets_np: Optional[np.ndarray] = None
neighbors_np: Optional[np.ndarray] = None
storage_fourcc: int = NULL_INDEX_FOURCC
storage_data: bytes = b""
def _read_hnsw_structure(f) -> HNSWComponents:
original_hnsw_data: dict[str, Any] = {}
hnsw_index_fourcc = read_struct(f, "<I")
if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS:
raise ValueError(
f"Unexpected HNSW FourCC: {hnsw_index_fourcc:08x}. Expected one of {EXPECTED_HNSW_FOURCCS}."
)
original_hnsw_data["index_fourcc"] = hnsw_index_fourcc
original_hnsw_data["d"] = read_struct(f, "<i")
original_hnsw_data["ntotal"] = read_struct(f, "<q")
original_hnsw_data["dummy1"] = read_struct(f, "<q")
original_hnsw_data["dummy2"] = read_struct(f, "<q")
original_hnsw_data["is_trained"] = read_struct(f, "?")
original_hnsw_data["metric_type"] = read_struct(f, "<i")
original_hnsw_data["metric_arg"] = 0.0
if original_hnsw_data["metric_type"] > 1:
original_hnsw_data["metric_arg"] = read_struct(f, "<f")
assign_probas_np = read_numpy_vector(f, np.float64, "d")
cum_nneighbor_per_level_np = read_numpy_vector(f, np.int32, "i")
levels_np = read_numpy_vector(f, np.int32, "i")
ntotal = len(levels_np)
if ntotal != original_hnsw_data["ntotal"]:
original_hnsw_data["ntotal"] = ntotal
pos_before_compact = f.tell()
is_compact_flag = None
try:
is_compact_flag = read_struct(f, "<?")
except EOFError:
is_compact_flag = None
if is_compact_flag:
compact_level_ptr = read_numpy_vector(f, np.uint64, "Q")
compact_node_offsets_np = read_numpy_vector(f, np.uint64, "Q")
original_hnsw_data["entry_point"] = read_struct(f, "<i")
original_hnsw_data["max_level"] = read_struct(f, "<i")
original_hnsw_data["efConstruction"] = read_struct(f, "<i")
original_hnsw_data["efSearch"] = read_struct(f, "<i")
original_hnsw_data["dummy_upper_beam"] = read_struct(f, "<i")
storage_fourcc = read_struct(f, "<I")
compact_neighbors_data_np = read_numpy_vector(f, np.int32, "i")
compact_neighbors_data = compact_neighbors_data_np.tolist()
storage_data = f.read()
return HNSWComponents(
original_hnsw_data=original_hnsw_data,
assign_probas_np=assign_probas_np,
cum_nneighbor_per_level_np=cum_nneighbor_per_level_np,
levels_np=levels_np,
is_compact=True,
compact_level_ptr=compact_level_ptr,
compact_node_offsets_np=compact_node_offsets_np,
compact_neighbors_data=compact_neighbors_data,
storage_fourcc=storage_fourcc,
storage_data=storage_data,
)
# Non-compact case
f.seek(pos_before_compact)
pos_before_probe = f.tell()
try:
suspected_flag = read_struct(f, "<B")
if suspected_flag != 0x00:
f.seek(pos_before_probe)
except EOFError:
f.seek(pos_before_probe)
offsets_np = read_numpy_vector(f, np.uint64, "Q")
neighbors_np = read_numpy_vector(f, np.int32, "i")
original_hnsw_data["entry_point"] = read_struct(f, "<i")
original_hnsw_data["max_level"] = read_struct(f, "<i")
original_hnsw_data["efConstruction"] = read_struct(f, "<i")
original_hnsw_data["efSearch"] = read_struct(f, "<i")
original_hnsw_data["dummy_upper_beam"] = read_struct(f, "<i")
storage_fourcc = NULL_INDEX_FOURCC
storage_data = b""
try:
storage_fourcc = read_struct(f, "<I")
storage_data = f.read()
except EOFError:
storage_fourcc = NULL_INDEX_FOURCC
return HNSWComponents(
original_hnsw_data=original_hnsw_data,
assign_probas_np=assign_probas_np,
cum_nneighbor_per_level_np=cum_nneighbor_per_level_np,
levels_np=levels_np,
is_compact=False,
offsets_np=offsets_np,
neighbors_np=neighbors_np,
storage_fourcc=storage_fourcc,
storage_data=storage_data,
)
def _read_hnsw_structure_from_file(path: str) -> HNSWComponents:
with open(path, "rb") as f:
return _read_hnsw_structure(f)
def write_original_format(
f_out,
original_hnsw_data,
assign_probas_np,
cum_nneighbor_per_level_np,
levels_np,
offsets_np,
neighbors_np,
storage_fourcc,
storage_data,
):
"""Write non-compact HNSW data in original FAISS order."""
f_out.write(struct.pack("<I", original_hnsw_data["index_fourcc"]))
f_out.write(struct.pack("<i", original_hnsw_data["d"]))
f_out.write(struct.pack("<q", original_hnsw_data["ntotal"]))
f_out.write(struct.pack("<q", original_hnsw_data["dummy1"]))
f_out.write(struct.pack("<q", original_hnsw_data["dummy2"]))
f_out.write(struct.pack("<?", original_hnsw_data["is_trained"]))
f_out.write(struct.pack("<i", original_hnsw_data["metric_type"]))
if original_hnsw_data["metric_type"] > 1:
f_out.write(struct.pack("<f", original_hnsw_data["metric_arg"]))
write_numpy_vector(f_out, assign_probas_np, "d")
write_numpy_vector(f_out, cum_nneighbor_per_level_np, "i")
write_numpy_vector(f_out, levels_np, "i")
write_numpy_vector(f_out, offsets_np, "Q")
write_numpy_vector(f_out, neighbors_np, "i")
f_out.write(struct.pack("<i", original_hnsw_data["entry_point"]))
f_out.write(struct.pack("<i", original_hnsw_data["max_level"]))
f_out.write(struct.pack("<i", original_hnsw_data["efConstruction"]))
f_out.write(struct.pack("<i", original_hnsw_data["efSearch"]))
f_out.write(struct.pack("<i", original_hnsw_data["dummy_upper_beam"]))
f_out.write(struct.pack("<I", storage_fourcc))
if storage_fourcc != NULL_INDEX_FOURCC and storage_data:
f_out.write(storage_data)
def prune_hnsw_embeddings(input_filename: str, output_filename: str) -> bool:
"""Rewrite an HNSW index while dropping the embedded storage section."""
start_time = time.time()
try:
with open(input_filename, "rb") as f_in, open(output_filename, "wb") as f_out:
original_hnsw_data: dict[str, Any] = {}
hnsw_index_fourcc = read_struct(f_in, "<I")
if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS:
print(
f"Error: Expected HNSW Index FourCC ({list(EXPECTED_HNSW_FOURCCS)}), got {hnsw_index_fourcc:08x}.",
file=sys.stderr,
)
return False
original_hnsw_data["index_fourcc"] = hnsw_index_fourcc
original_hnsw_data["d"] = read_struct(f_in, "<i")
original_hnsw_data["ntotal"] = read_struct(f_in, "<q")
original_hnsw_data["dummy1"] = read_struct(f_in, "<q")
original_hnsw_data["dummy2"] = read_struct(f_in, "<q")
original_hnsw_data["is_trained"] = read_struct(f_in, "?")
original_hnsw_data["metric_type"] = read_struct(f_in, "<i")
original_hnsw_data["metric_arg"] = 0.0
if original_hnsw_data["metric_type"] > 1:
original_hnsw_data["metric_arg"] = read_struct(f_in, "<f")
assign_probas_np = read_numpy_vector(f_in, np.float64, "d")
cum_nneighbor_per_level_np = read_numpy_vector(f_in, np.int32, "i")
levels_np = read_numpy_vector(f_in, np.int32, "i")
ntotal = len(levels_np)
if ntotal != original_hnsw_data["ntotal"]:
original_hnsw_data["ntotal"] = ntotal
pos_before_compact = f_in.tell()
is_compact_flag = None
try:
is_compact_flag = read_struct(f_in, "<?")
except EOFError:
is_compact_flag = None
if is_compact_flag:
compact_level_ptr = read_numpy_vector(f_in, np.uint64, "Q")
compact_node_offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
original_hnsw_data["max_level"] = read_struct(f_in, "<i")
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
_storage_fourcc = read_struct(f_in, "<I")
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i")
compact_neighbors_data = compact_neighbors_data_np.tolist()
_storage_data = f_in.read()
write_compact_format(
f_out,
original_hnsw_data,
assign_probas_np,
cum_nneighbor_per_level_np,
levels_np,
compact_level_ptr,
compact_node_offsets_np,
compact_neighbors_data,
NULL_INDEX_FOURCC,
b"",
)
else:
f_in.seek(pos_before_compact)
pos_before_probe = f_in.tell()
try:
suspected_flag = read_struct(f_in, "<B")
if suspected_flag != 0x00:
f_in.seek(pos_before_probe)
except EOFError:
f_in.seek(pos_before_probe)
offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
neighbors_np = read_numpy_vector(f_in, np.int32, "i")
original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
original_hnsw_data["max_level"] = read_struct(f_in, "<i")
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
_storage_fourcc = None
_storage_data = b""
try:
_storage_fourcc = read_struct(f_in, "<I")
_storage_data = f_in.read()
except EOFError:
_storage_fourcc = NULL_INDEX_FOURCC
write_original_format(
f_out,
original_hnsw_data,
assign_probas_np,
cum_nneighbor_per_level_np,
levels_np,
offsets_np,
neighbors_np,
NULL_INDEX_FOURCC,
b"",
)
print(f"[{time.time() - start_time:.2f}s] Pruned embeddings from {input_filename}")
return True
except Exception as exc:
print(f"Failed to prune embeddings: {exc}", file=sys.stderr)
return False
# --- Main Conversion Logic ---
@@ -534,8 +243,6 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
output_filename: Output CSR index file
prune_embeddings: Whether to prune embedding storage (write NULL storage marker)
"""
# Keep prints simple; rely on CI runner to flush output as needed
print(f"Starting conversion: {input_filename} -> {output_filename}")
start_time = time.time()
original_hnsw_data = {}
@@ -984,29 +691,6 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
pass
def prune_hnsw_embeddings_inplace(index_filename: str) -> bool:
"""Convenience wrapper to prune embeddings in-place."""
temp_path = f"{index_filename}.prune.tmp"
success = prune_hnsw_embeddings(index_filename, temp_path)
if success:
try:
os.replace(temp_path, index_filename)
except Exception as exc: # pragma: no cover - defensive
logger.error(f"Failed to replace original index with pruned version: {exc}")
try:
os.remove(temp_path)
except OSError:
pass
return False
else:
try:
os.remove(temp_path)
except OSError:
pass
return success
# --- Script Execution ---
if __name__ == "__main__":
parser = argparse.ArgumentParser(

View File

@@ -1,7 +1,6 @@
import logging
import os
import shutil
import time
from pathlib import Path
from typing import Any, Literal, Optional
@@ -14,7 +13,7 @@ from leann.interface import (
from leann.registry import register_backend
from leann.searcher_base import BaseSearcher
from .convert_to_csr import convert_hnsw_graph_to_csr, prune_hnsw_embeddings_inplace
from .convert_to_csr import convert_hnsw_graph_to_csr
logger = logging.getLogger(__name__)
@@ -55,13 +54,12 @@ class HNSWBuilder(LeannBackendBuilderInterface):
self.efConstruction = self.build_params.setdefault("efConstruction", 200)
self.distance_metric = self.build_params.setdefault("distance_metric", "mips")
self.dimensions = self.build_params.get("dimensions")
if not self.is_recompute and self.is_compact:
# Auto-correct: non-recompute requires non-compact storage for HNSW
logger.warning(
"is_recompute=False requires non-compact HNSW. Forcing is_compact=False."
)
self.is_compact = False
self.build_params["is_compact"] = False
if not self.is_recompute:
if self.is_compact:
# TODO: support this case @andy
raise ValueError(
"is_recompute is False, but is_compact is True. This is not compatible now. change is compact to False and you can use the original HNSW index."
)
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
from . import faiss # type: ignore
@@ -90,19 +88,8 @@ class HNSWBuilder(LeannBackendBuilderInterface):
index_file = index_dir / f"{index_prefix}.index"
faiss.write_index(index, str(index_file))
# Persist ID map so searcher can map FAISS integer labels back to passage IDs
try:
idmap_file = index_dir / f"{index_prefix}.ids.txt"
with open(idmap_file, "w", encoding="utf-8") as f:
for id_str in ids:
f.write(str(id_str) + "\n")
except Exception as e:
logger.warning(f"Failed to write ID map: {e}")
if self.is_compact:
self._convert_to_csr(index_file)
elif self.is_recompute:
prune_hnsw_embeddings_inplace(str(index_file))
def _convert_to_csr(self, index_file: Path):
"""Convert built index to CSR format"""
@@ -144,10 +131,10 @@ class HNSWSearcher(BaseSearcher):
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
backend_meta_kwargs = self.meta.get("backend_kwargs", {})
self.is_compact = self.meta.get("is_compact", backend_meta_kwargs.get("is_compact", True))
default_pruned = backend_meta_kwargs.get("is_recompute", self.is_compact)
self.is_pruned = bool(self.meta.get("is_pruned", default_pruned))
self.is_compact, self.is_pruned = (
self.meta.get("is_compact", True),
self.meta.get("is_pruned", True),
)
index_file = self.index_dir / f"{self.index_path.stem}.index"
if not index_file.exists():
@@ -161,16 +148,6 @@ class HNSWSearcher(BaseSearcher):
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
# Load ID map if available
self._id_map: list[str] = []
try:
idmap_file = self.index_dir / f"{self.index_path.stem}.ids.txt"
if idmap_file.exists():
with open(idmap_file, encoding="utf-8") as f:
self._id_map = [line.rstrip("\n") for line in f]
except Exception as e:
logger.warning(f"Failed to load ID map: {e}")
def search(
self,
query: np.ndarray,
@@ -207,11 +184,9 @@ class HNSWSearcher(BaseSearcher):
"""
from . import faiss # type: ignore
if not recompute_embeddings and self.is_pruned:
raise RuntimeError(
"Recompute is required for pruned/compact HNSW index. "
"Re-run search with --recompute, or rebuild with --no-recompute and --no-compact."
)
if not recompute_embeddings:
if self.is_pruned:
raise RuntimeError("Recompute is required for pruned index.")
if recompute_embeddings:
if zmq_port is None:
raise ValueError("zmq_port must be provided if recompute_embeddings is True")
@@ -258,7 +233,6 @@ class HNSWSearcher(BaseSearcher):
distances = np.empty((batch_size_query, top_k), dtype=np.float32)
labels = np.empty((batch_size_query, top_k), dtype=np.int64)
search_time = time.time()
self._index.search(
query.shape[0],
faiss.swig_ptr(query),
@@ -267,21 +241,7 @@ class HNSWSearcher(BaseSearcher):
faiss.swig_ptr(labels),
params,
)
search_time = time.time() - search_time
logger.info(f" Search time in HNSWSearcher.search() backend: {search_time} seconds")
if self._id_map:
def map_label(x: int) -> str:
if 0 <= x < len(self._id_map):
return self._id_map[x]
return str(x)
string_labels = [
[map_label(int(label)) for label in batch_labels] for batch_labels in labels
]
else:
string_labels = [
[str(int_label) for int_label in batch_labels] for batch_labels in labels
]
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
return {"labels": string_labels, "distances": distances}

View File

@@ -10,7 +10,7 @@ import sys
import threading
import time
from pathlib import Path
from typing import Any, Optional
from typing import Union
import msgpack
import numpy as np
@@ -24,39 +24,17 @@ logger = logging.getLogger(__name__)
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
logger.setLevel(log_level)
# Ensure we have handlers if none exist
# Ensure we have a handler if none exists
if not logger.handlers:
stream_handler = logging.StreamHandler()
handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
log_path = os.getenv("LEANN_HNSW_LOG_PATH")
if log_path:
try:
file_handler = logging.FileHandler(log_path, mode="a", encoding="utf-8")
file_formatter = logging.Formatter(
"%(asctime)s - %(levelname)s - [pid=%(process)d] %(message)s"
)
file_handler.setFormatter(file_formatter)
logger.addHandler(file_handler)
except Exception as exc: # pragma: no cover - best effort logging
logger.warning(f"Failed to attach file handler for log path {log_path}: {exc}")
logger.propagate = False
_RAW_PROVIDER_OPTIONS = os.getenv("LEANN_EMBEDDING_OPTIONS")
try:
PROVIDER_OPTIONS: dict[str, Any] = (
json.loads(_RAW_PROVIDER_OPTIONS) if _RAW_PROVIDER_OPTIONS else {}
)
except json.JSONDecodeError:
logger.warning("Failed to parse LEANN_EMBEDDING_OPTIONS; ignoring provider options")
PROVIDER_OPTIONS = {}
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.propagate = False
def create_hnsw_embedding_server(
passages_file: Optional[str] = None,
passages_file: Union[str, None] = None,
zmq_port: int = 5555,
model_name: str = "sentence-transformers/all-mpnet-base-v2",
distance_metric: str = "mips",
@@ -104,359 +82,199 @@ def create_hnsw_embedding_server(
with open(passages_file) as f:
meta = json.load(f)
# Let PassageManager handle path resolution uniformly. It supports fallback order:
# 1) path/index_path; 2) *_relative; 3) standard siblings next to meta
passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file)
# Dimension from metadata for shaping responses
try:
embedding_dim: int = int(meta.get("dimensions", 0))
except Exception:
embedding_dim = 0
logger.info(f"Loaded PassageManager with {len(passages)} passages from metadata")
# Convert relative paths to absolute paths based on metadata file location
metadata_dir = Path(passages_file).parent.parent # Go up one level from the metadata file
passage_sources = []
for source in meta["passage_sources"]:
source_copy = source.copy()
# Convert relative paths to absolute paths
if not Path(source_copy["path"]).is_absolute():
source_copy["path"] = str(metadata_dir / source_copy["path"])
if not Path(source_copy["index_path"]).is_absolute():
source_copy["index_path"] = str(metadata_dir / source_copy["index_path"])
passage_sources.append(source_copy)
# Attempt to load ID map (maps FAISS integer labels -> passage IDs)
id_map: list[str] = []
try:
meta_path = Path(passages_file)
base = meta_path.name
if base.endswith(".meta.json"):
base = base[: -len(".meta.json")] # e.g., laion_index.leann
if base.endswith(".leann"):
base = base[: -len(".leann")] # e.g., laion_index
idmap_file = meta_path.parent / f"{base}.ids.txt"
if idmap_file.exists():
with open(idmap_file, encoding="utf-8") as f:
id_map = [line.rstrip("\n") for line in f]
logger.info(f"Loaded ID map with {len(id_map)} entries from {idmap_file}")
else:
logger.warning(f"ID map file not found at {idmap_file}; will use raw labels")
except Exception as e:
logger.warning(f"Failed to load ID map: {e}")
def _map_node_id(nid) -> str:
try:
if id_map is not None and len(id_map) > 0 and isinstance(nid, (int, np.integer)):
idx = int(nid)
if 0 <= idx < len(id_map):
return id_map[idx]
except Exception:
pass
return str(nid)
# (legacy ZMQ thread removed; using shutdown-capable server only)
def zmq_server_thread_with_shutdown(shutdown_event):
"""ZMQ server thread that respects shutdown signal.
Creates its own REP socket bound to zmq_port and polls with timeouts
to allow graceful shutdown.
"""
logger.info("ZMQ server thread started with shutdown support")
passages = PassageManager(passage_sources)
logger.info(
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
)
def zmq_server_thread():
"""ZMQ server thread"""
context = zmq.Context()
rep_socket = context.socket(zmq.REP)
rep_socket.bind(f"tcp://*:{zmq_port}")
logger.info(f"HNSW ZMQ REP server listening on port {zmq_port}")
rep_socket.setsockopt(zmq.RCVTIMEO, 1000)
# Keep sends from blocking during shutdown; fail fast and drop on close
rep_socket.setsockopt(zmq.SNDTIMEO, 1000)
rep_socket.setsockopt(zmq.LINGER, 0)
socket = context.socket(zmq.REP)
socket.bind(f"tcp://*:{zmq_port}")
logger.info(f"HNSW ZMQ server listening on port {zmq_port}")
# Track last request type/length for shape-correct fallbacks
last_request_type = "unknown" # 'text' | 'distance' | 'embedding' | 'unknown'
last_request_length = 0
socket.setsockopt(zmq.RCVTIMEO, 300000)
socket.setsockopt(zmq.SNDTIMEO, 300000)
try:
while not shutdown_event.is_set():
try:
e2e_start = time.time()
logger.debug("🔍 Waiting for ZMQ message...")
request_bytes = rep_socket.recv()
while True:
try:
message_bytes = socket.recv()
logger.debug(f"Received ZMQ request of size {len(message_bytes)} bytes")
# Rest of the processing logic (same as original)
request = msgpack.unpackb(request_bytes)
e2e_start = time.time()
request_payload = msgpack.unpackb(message_bytes)
if len(request) == 1 and request[0] == "__QUERY_MODEL__":
response_bytes = msgpack.packb([model_name])
rep_socket.send(response_bytes)
continue
# Handle direct text embedding request
if (
isinstance(request, list)
and request
and all(isinstance(item, str) for item in request)
):
last_request_type = "text"
last_request_length = len(request)
embeddings = compute_embeddings(
request,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
# Handle direct text embedding request
if isinstance(request_payload, list) and len(request_payload) > 0:
# Check if this is a direct text request (list of strings)
if all(isinstance(item, str) for item in request_payload):
logger.info(
f"Processing direct text embedding request for {len(request_payload)} texts in {embedding_mode} mode"
)
rep_socket.send(msgpack.packb(embeddings.tolist()))
# Use unified embedding computation (now with model caching)
embeddings = compute_embeddings(
request_payload, model_name, mode=embedding_mode
)
response = embeddings.tolist()
socket.send(msgpack.packb(response))
e2e_end = time.time()
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
continue
# Handle distance calculation request: [[ids], [query_vector]]
if (
isinstance(request, list)
and len(request) == 2
and isinstance(request[0], list)
and isinstance(request[1], list)
):
node_ids = request[0]
# Handle nested [[ids]] shape defensively
if len(node_ids) == 1 and isinstance(node_ids[0], list):
node_ids = node_ids[0]
query_vector = np.array(request[1], dtype=np.float32)
last_request_type = "distance"
last_request_length = len(node_ids)
# Handle distance calculation requests
if (
isinstance(request_payload, list)
and len(request_payload) == 2
and isinstance(request_payload[0], list)
and isinstance(request_payload[1], list)
):
node_ids = request_payload[0]
query_vector = np.array(request_payload[1], dtype=np.float32)
logger.debug("Distance calculation request received")
logger.debug(f" Node IDs: {node_ids}")
logger.debug(f" Query vector dim: {len(query_vector)}")
logger.debug("Distance calculation request received")
logger.debug(f" Node IDs: {node_ids}")
logger.debug(f" Query vector dim: {len(query_vector)}")
# Gather texts for found ids
texts: list[str] = []
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
try:
passage_id = _map_node_id(nid)
passage_data = passages.get_passage(passage_id)
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {passage_id}")
except KeyError:
logger.error(f"Passage ID {nid} not found")
except Exception as e:
logger.error(f"Exception looking up passage ID {nid}: {e}")
# Prepare full-length response with large sentinel values
large_distance = 1e9
response_distances = [large_distance] * len(node_ids)
if texts:
try:
embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
if distance_metric == "l2":
partial = np.sum(
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
)
else: # mips or cosine
partial = -np.dot(embeddings, query_vector)
for pos, dval in zip(found_indices, partial.flatten().tolist()):
response_distances[pos] = float(dval)
except Exception as e:
logger.error(f"Distance computation error, using sentinels: {e}")
# Send response in expected shape [[distances]]
rep_socket.send(msgpack.packb([response_distances], use_single_float=True))
e2e_end = time.time()
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
continue
# Fallback: treat as embedding-by-id request
if (
isinstance(request, list)
and len(request) == 1
and isinstance(request[0], list)
):
node_ids = request[0]
elif isinstance(request, list):
node_ids = request
else:
node_ids = []
last_request_type = "embedding"
last_request_length = len(node_ids)
logger.info(f"ZMQ received {len(node_ids)} node IDs for embedding fetch")
# Preallocate zero-filled flat data for robustness
if embedding_dim <= 0:
dims = [0, 0]
flat_data: list[float] = []
else:
dims = [len(node_ids), embedding_dim]
flat_data = [0.0] * (dims[0] * dims[1])
# Collect texts for found ids
texts: list[str] = []
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
# Get embeddings for node IDs
texts = []
for nid in node_ids:
try:
passage_id = _map_node_id(nid)
passage_data = passages.get_passage(passage_id)
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {passage_id}")
passage_data = passages.get_passage(str(nid))
txt = passage_data["text"]
texts.append(txt)
except KeyError:
logger.error(f"Passage with ID {nid} not found")
logger.error(f"Passage ID {nid} not found")
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
except Exception as e:
logger.error(f"Exception looking up passage ID {nid}: {e}")
raise
if texts:
try:
embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
# Process embeddings
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
logger.error(
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
)
dims = [0, embedding_dim]
flat_data = []
else:
emb_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
flat = emb_f32.flatten().tolist()
for j, pos in enumerate(found_indices):
start = pos * embedding_dim
end = start + embedding_dim
if end <= len(flat_data):
flat_data[start:end] = flat[
j * embedding_dim : (j + 1) * embedding_dim
]
except Exception as e:
logger.error(f"Embedding computation error, returning zeros: {e}")
# Calculate distances
if distance_metric == "l2":
distances = np.sum(
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
)
else: # mips or cosine
distances = -np.dot(embeddings, query_vector)
response_payload = [dims, flat_data]
response_bytes = msgpack.packb(response_payload, use_single_float=True)
response_payload = distances.flatten().tolist()
response_bytes = msgpack.packb([response_payload], use_single_float=True)
logger.debug(f"Sending distance response with {len(distances)} distances")
rep_socket.send(response_bytes)
socket.send(response_bytes)
e2e_end = time.time()
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
except zmq.Again:
# Timeout - check shutdown_event and continue
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
continue
except Exception as e:
if not shutdown_event.is_set():
logger.error(f"Error in ZMQ server loop: {e}")
# Shape-correct fallback
try:
if last_request_type == "distance":
large_distance = 1e9
fallback_len = max(0, int(last_request_length))
safe = [[large_distance] * fallback_len]
elif last_request_type == "embedding":
bsz = max(0, int(last_request_length))
dim = max(0, int(embedding_dim))
safe = (
[[bsz, dim], [0.0] * (bsz * dim)] if dim > 0 else [[0, 0], []]
)
elif last_request_type == "text":
safe = [] # direct text embeddings expectation is a flat list
else:
safe = [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
rep_socket.send(msgpack.packb(safe, use_single_float=True))
except Exception:
pass
else:
logger.info("Shutdown in progress, ignoring ZMQ error")
break
finally:
try:
rep_socket.close(0)
except Exception:
pass
try:
context.term()
except Exception:
pass
logger.info("ZMQ server thread exiting gracefully")
# Standard embedding request (passage ID lookup)
if (
not isinstance(request_payload, list)
or len(request_payload) != 1
or not isinstance(request_payload[0], list)
):
logger.error(
f"Invalid MessagePack request format. Expected [[ids...]] or [texts...], got: {type(request_payload)}"
)
socket.send(msgpack.packb([[], []]))
continue
# Add shutdown coordination
shutdown_event = threading.Event()
node_ids = request_payload[0]
logger.debug(f"Request for {len(node_ids)} node embeddings")
def shutdown_zmq_server():
"""Gracefully shutdown ZMQ server."""
logger.info("Initiating graceful shutdown...")
shutdown_event.set()
# Look up texts by node IDs
texts = []
for nid in node_ids:
try:
passage_data = passages.get_passage(str(nid))
txt = passage_data["text"]
if not txt:
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
texts.append(txt)
except KeyError:
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
except Exception as e:
logger.error(f"Exception looking up passage ID {nid}: {e}")
raise
if zmq_thread.is_alive():
logger.info("Waiting for ZMQ thread to finish...")
zmq_thread.join(timeout=5)
if zmq_thread.is_alive():
logger.warning("ZMQ thread did not finish in time")
# Process embeddings
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
# Clean up ZMQ resources
try:
# Note: socket and context are cleaned up by thread exit
logger.info("ZMQ resources cleaned up")
except Exception as e:
logger.warning(f"Error cleaning ZMQ resources: {e}")
# Serialization and response
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
logger.error(
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
)
raise AssertionError()
# Clean up other resources
try:
import gc
hidden_contiguous_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
response_payload = [
list(hidden_contiguous_f32.shape),
hidden_contiguous_f32.flatten().tolist(),
]
response_bytes = msgpack.packb(response_payload, use_single_float=True)
gc.collect()
logger.info("Additional resources cleaned up")
except Exception as e:
logger.warning(f"Error cleaning additional resources: {e}")
socket.send(response_bytes)
e2e_end = time.time()
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
logger.info("Graceful shutdown completed")
sys.exit(0)
except zmq.Again:
logger.debug("ZMQ socket timeout, continuing to listen")
continue
except Exception as e:
logger.error(f"Error in ZMQ server loop: {e}")
import traceback
# Register signal handlers within this function scope
import signal
traceback.print_exc()
socket.send(msgpack.packb([[], []]))
def signal_handler(sig, frame):
logger.info(f"Received signal {sig}, shutting down gracefully...")
shutdown_zmq_server()
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
# Pass shutdown_event to ZMQ thread
zmq_thread = threading.Thread(
target=lambda: zmq_server_thread_with_shutdown(shutdown_event),
daemon=False, # Not daemon - we want to wait for it
)
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
zmq_thread.start()
logger.info(f"Started HNSW ZMQ server thread on port {zmq_port}")
# Keep the main thread alive
try:
while not shutdown_event.is_set():
time.sleep(0.1) # Check shutdown more frequently
while True:
time.sleep(1)
except KeyboardInterrupt:
logger.info("HNSW Server shutting down...")
shutdown_zmq_server()
return
# If we reach here, shutdown was triggered by signal
logger.info("Main loop exited, process should be shutting down")
if __name__ == "__main__":
import signal
import sys
# Signal handlers are now registered within create_hnsw_embedding_server
def signal_handler(sig, frame):
logger.info(f"Received signal {sig}, shutting down gracefully...")
sys.exit(0)
# Register signal handlers for graceful shutdown
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
parser = argparse.ArgumentParser(description="HNSW Embedding service")
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")

View File

@@ -6,10 +6,10 @@ build-backend = "scikit_build_core.build"
[project]
name = "leann-backend-hnsw"
version = "0.3.4"
version = "0.2.7"
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
dependencies = [
"leann-core==0.3.4",
"leann-core==0.2.7",
"numpy",
"pyzmq>=23.0.0",
"msgpack>=1.0.0",

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "leann-core"
version = "0.3.4"
version = "0.2.7"
description = "Core API and plugin system for LEANN"
readme = "README.md"
requires-python = ">=3.9"
@@ -18,16 +18,14 @@ dependencies = [
"pyzmq>=23.0.0",
"msgpack>=1.0.0",
"torch>=2.0.0",
"sentence-transformers>=3.0.0",
"sentence-transformers>=2.2.0",
"llama-index-core>=0.12.0",
"llama-index-readers-file>=0.4.0", # Essential for document reading
"llama-index-embeddings-huggingface>=0.5.5", # For embeddings
"python-dotenv>=1.0.0",
"openai>=1.0.0",
"huggingface-hub>=0.20.0",
# Keep transformers below 4.46: 4.46.0 adds Python 3.10-only return type syntax and
# breaks Python 3.9 environments.
"transformers>=4.30.0,<4.46",
"transformers>=4.30.0",
"requests>=2.25.0",
"accelerate>=0.20.0",
"PyPDF2>=3.0.0",
@@ -42,7 +40,7 @@ dependencies = [
[project.optional-dependencies]
colab = [
"torch>=2.0.0,<3.0.0", # Limit torch version to avoid conflicts
"transformers>=4.30.0,<4.46", # 4.46.0 switches to PEP 604 typing (int | None), breaks Py3.9
"transformers>=4.30.0,<5.0.0", # Limit transformers version
"accelerate>=0.20.0,<1.0.0", # Limit accelerate version
]

View File

@@ -5,26 +5,19 @@ with the correct, original embedding logic from the user's reference code.
import json
import logging
import os
import pickle
import re
import subprocess
import time
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal, Optional, Union
from typing import Any, Literal, Optional
import numpy as np
from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
from leann.interactive_utils import create_api_session
from leann.interface import LeannBackendSearcherInterface
from .chat import get_llm
from .embedding_server_manager import EmbeddingServerManager
from .interface import LeannBackendFactoryInterface
from .metadata_filter import MetadataFilterEngine
from .registry import BACKEND_REGISTRY
logger = logging.getLogger(__name__)
@@ -42,7 +35,6 @@ def compute_embeddings(
use_server: bool = True,
port: Optional[int] = None,
is_build=False,
provider_options: Optional[dict[str, Any]] = None,
) -> np.ndarray:
"""
Computes embeddings using different backends.
@@ -54,7 +46,6 @@ def compute_embeddings(
- "sentence-transformers": Use sentence-transformers library (default)
- "mlx": Use MLX backend for Apple Silicon
- "openai": Use OpenAI embedding API
- "gemini": Use Google Gemini embedding API
use_server: Whether to use embedding server (True for search, False for build)
Returns:
@@ -76,7 +67,6 @@ def compute_embeddings(
model_name,
mode=mode,
is_build=is_build,
provider_options=provider_options,
)
@@ -125,156 +115,42 @@ class SearchResult:
class PassageManager:
def __init__(
self, passage_sources: list[dict[str, Any]], metadata_file_path: Optional[str] = None
):
self.offset_maps: dict[str, dict[str, int]] = {}
self.passage_files: dict[str, str] = {}
# Avoid materializing a single gigantic global map to reduce memory
# footprint on very large corpora (e.g., 60M+ passages). Instead, keep
# per-shard maps and do a lightweight per-shard lookup on demand.
self._total_count: int = 0
self.filter_engine = MetadataFilterEngine() # Initialize filter engine
# Derive index base name for standard sibling fallbacks, e.g., <index_name>.passages.*
index_name_base = None
if metadata_file_path:
meta_name = Path(metadata_file_path).name
if meta_name.endswith(".meta.json"):
index_name_base = meta_name[: -len(".meta.json")]
def __init__(self, passage_sources: list[dict[str, Any]]):
self.offset_maps = {}
self.passage_files = {}
self.global_offset_map = {} # Combined map for fast lookup
for source in passage_sources:
assert source["type"] == "jsonl", "only jsonl is supported"
passage_file = source.get("path", "")
index_file = source.get("index_path", "") # .idx file
passage_file = source["path"]
index_file = source["index_path"] # .idx file
# Fix path resolution - relative paths should be relative to metadata file directory
def _resolve_candidates(
primary: str,
relative_key: str,
default_name: Optional[str],
source_dict: dict[str, Any],
) -> list[Path]:
"""
Build an ordered list of candidate paths. For relative paths specified in
metadata, prefer resolution relative to the metadata file directory first,
then fall back to CWD-based resolution, and finally to conventional
sibling defaults (e.g., <index_base>.passages.idx / .jsonl).
"""
candidates: list[Path] = []
# 1) Primary path
if primary:
p = Path(primary)
if p.is_absolute():
candidates.append(p)
else:
# Prefer metadata-relative resolution for relative paths
if metadata_file_path:
candidates.append(Path(metadata_file_path).parent / p)
# Also consider CWD-relative as a fallback for legacy layouts
candidates.append(Path.cwd() / p)
# 2) metadata-relative explicit relative key (if present)
if metadata_file_path and source_dict.get(relative_key):
candidates.append(Path(metadata_file_path).parent / source_dict[relative_key])
# 3) metadata-relative standard sibling filename
if metadata_file_path and default_name:
candidates.append(Path(metadata_file_path).parent / default_name)
return candidates
# Build candidate lists and pick first existing; otherwise keep last candidate for error message
idx_default = f"{index_name_base}.passages.idx" if index_name_base else None
idx_candidates = _resolve_candidates(
index_file, "index_path_relative", idx_default, source
)
pas_default = f"{index_name_base}.passages.jsonl" if index_name_base else None
pas_candidates = _resolve_candidates(passage_file, "path_relative", pas_default, source)
def _pick_existing(cands: list[Path]) -> str:
for c in cands:
if c.exists():
return str(c.resolve())
# Fallback to last candidate (best guess) even if not exists; will error below
return str(cands[-1].resolve()) if cands else ""
index_file = _pick_existing(idx_candidates)
passage_file = _pick_existing(pas_candidates)
# Fix path resolution for Colab and other environments
if not Path(index_file).is_absolute():
# If relative path, try to resolve it properly
index_file = str(Path(index_file).resolve())
if not Path(index_file).exists():
raise FileNotFoundError(f"Passage index file not found: {index_file}")
with open(index_file, "rb") as f:
offset_map: dict[str, int] = pickle.load(f)
offset_map = pickle.load(f)
self.offset_maps[passage_file] = offset_map
self.passage_files[passage_file] = passage_file
self._total_count += len(offset_map)
# Build global map for O(1) lookup
for passage_id, offset in offset_map.items():
self.global_offset_map[passage_id] = (passage_file, offset)
def get_passage(self, passage_id: str) -> dict[str, Any]:
# Fast path: check each shard map (there are typically few shards).
# This avoids building a massive combined dict while keeping lookups
# bounded by the number of shards.
for passage_file, offset_map in self.offset_maps.items():
try:
offset = offset_map[passage_id]
with open(passage_file, encoding="utf-8") as f:
f.seek(offset)
return json.loads(f.readline())
except KeyError:
continue
if passage_id in self.global_offset_map:
passage_file, offset = self.global_offset_map[passage_id]
# Lazy file opening - only open when needed
with open(passage_file, encoding="utf-8") as f:
f.seek(offset)
return json.loads(f.readline())
raise KeyError(f"Passage ID not found: {passage_id}")
def filter_search_results(
self,
search_results: list[SearchResult],
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]],
) -> list[SearchResult]:
"""
Apply metadata filters to search results.
Args:
search_results: List of SearchResult objects
metadata_filters: Filter specifications to apply
Returns:
Filtered list of SearchResult objects
"""
if not metadata_filters:
return search_results
logger.debug(f"Applying metadata filters to {len(search_results)} results")
# Convert SearchResult objects to dictionaries for the filter engine
result_dicts = []
for result in search_results:
result_dicts.append(
{
"id": result.id,
"score": result.score,
"text": result.text,
"metadata": result.metadata,
}
)
# Apply filters using the filter engine
filtered_dicts = self.filter_engine.apply_filters(result_dicts, metadata_filters)
# Convert back to SearchResult objects
filtered_results = []
for result_dict in filtered_dicts:
filtered_results.append(
SearchResult(
id=result_dict["id"],
score=result_dict["score"],
text=result_dict["text"],
metadata=result_dict["metadata"],
)
)
logger.debug(f"Filtered results: {len(filtered_results)} remaining")
return filtered_results
def __len__(self) -> int:
return self._total_count
class LeannBuilder:
def __init__(
@@ -283,22 +159,9 @@ class LeannBuilder:
embedding_model: str = "facebook/contriever",
dimensions: Optional[int] = None,
embedding_mode: str = "sentence-transformers",
embedding_options: Optional[dict[str, Any]] = None,
**backend_kwargs,
):
self.backend_name = backend_name
# Normalize incompatible combinations early (for consistent metadata)
if backend_name == "hnsw":
is_recompute = backend_kwargs.get("is_recompute", True)
is_compact = backend_kwargs.get("is_compact", True)
if is_recompute is False and is_compact is True:
warnings.warn(
"HNSW with is_recompute=False requires non-compact storage. Forcing is_compact=False.",
UserWarning,
stacklevel=2,
)
backend_kwargs["is_compact"] = False
backend_factory: Optional[LeannBackendFactoryInterface] = BACKEND_REGISTRY.get(backend_name)
if backend_factory is None:
raise ValueError(f"Backend '{backend_name}' not found or not registered.")
@@ -306,7 +169,6 @@ class LeannBuilder:
self.embedding_model = embedding_model
self.dimensions = dimensions
self.embedding_mode = embedding_mode
self.embedding_options = embedding_options or {}
# Check if we need to use cosine distance for normalized embeddings
normalized_embeddings_models = {
@@ -390,23 +252,6 @@ class LeannBuilder:
def build_index(self, index_path: str):
if not self.chunks:
raise ValueError("No chunks added.")
# Filter out invalid/empty text chunks early to keep passage and embedding counts aligned
valid_chunks: list[dict[str, Any]] = []
skipped = 0
for chunk in self.chunks:
text = chunk.get("text", "")
if isinstance(text, str) and text.strip():
valid_chunks.append(chunk)
else:
skipped += 1
if skipped > 0:
print(
f"Warning: Skipping {skipped} empty/invalid text chunk(s). Processing {len(valid_chunks)} valid chunks"
)
self.chunks = valid_chunks
if not self.chunks:
raise ValueError("All provided chunks are empty or invalid. Nothing to index.")
if self.dimensions is None:
self.dimensions = len(
compute_embeddings(
@@ -414,7 +259,6 @@ class LeannBuilder:
self.embedding_model,
self.embedding_mode,
use_server=False,
provider_options=self.embedding_options,
)[0]
)
path = Path(index_path)
@@ -454,20 +298,8 @@ class LeannBuilder:
self.embedding_mode,
use_server=False,
is_build=True,
provider_options=self.embedding_options,
)
string_ids = [chunk["id"] for chunk in self.chunks]
# Persist ID map alongside index so backends that return integer labels can remap to passage IDs
try:
idmap_file = (
index_dir
/ f"{index_name[: -len('.leann')] if index_name.endswith('.leann') else index_name}.ids.txt"
)
with open(idmap_file, "w", encoding="utf-8") as f:
for sid in string_ids:
f.write(str(sid) + "\n")
except Exception:
pass
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
builder_instance.build(embeddings, string_ids, index_path, **current_backend_kwargs)
@@ -482,25 +314,20 @@ class LeannBuilder:
"passage_sources": [
{
"type": "jsonl",
# Preserve existing relative file names (backward-compatible)
"path": passages_file.name,
"index_path": offset_file.name,
# Add optional redundant relative keys for remote build portability (non-breaking)
"path_relative": passages_file.name,
"index_path_relative": offset_file.name,
"path": str(passages_file),
"index_path": str(offset_file),
}
],
}
if self.embedding_options:
meta_data["embedding_options"] = self.embedding_options
# Add storage status flags for HNSW backend
if self.backend_name == "hnsw":
is_compact = self.backend_kwargs.get("is_compact", True)
is_recompute = self.backend_kwargs.get("is_recompute", True)
meta_data["is_compact"] = is_compact
meta_data["is_pruned"] = bool(is_recompute)
meta_data["is_pruned"] = (
is_compact and is_recompute
) # Pruned only if compact and recompute
with open(leann_meta_path, "w", encoding="utf-8") as f:
json.dump(meta_data, f, indent=2)
@@ -587,17 +414,6 @@ class LeannBuilder:
# Build the vector index using precomputed embeddings
string_ids = [str(id_val) for id_val in ids]
# Persist ID map (order == embeddings order)
try:
idmap_file = (
index_dir
/ f"{index_name[: -len('.leann')] if index_name.endswith('.leann') else index_name}.ids.txt"
)
with open(idmap_file, "w", encoding="utf-8") as f:
for sid in string_ids:
f.write(str(sid) + "\n")
except Exception:
pass
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
builder_instance.build(embeddings, string_ids, index_path)
@@ -614,254 +430,26 @@ class LeannBuilder:
"passage_sources": [
{
"type": "jsonl",
# Preserve existing relative file names (backward-compatible)
"path": passages_file.name,
"index_path": offset_file.name,
# Add optional redundant relative keys for remote build portability (non-breaking)
"path_relative": passages_file.name,
"index_path_relative": offset_file.name,
"path": str(passages_file),
"index_path": str(offset_file),
}
],
"built_from_precomputed_embeddings": True,
"embeddings_source": str(embeddings_file),
}
if self.embedding_options:
meta_data["embedding_options"] = self.embedding_options
# Add storage status flags for HNSW backend
if self.backend_name == "hnsw":
is_compact = self.backend_kwargs.get("is_compact", True)
is_recompute = self.backend_kwargs.get("is_recompute", True)
meta_data["is_compact"] = is_compact
meta_data["is_pruned"] = bool(is_recompute)
meta_data["is_pruned"] = is_compact and is_recompute
with open(leann_meta_path, "w", encoding="utf-8") as f:
json.dump(meta_data, f, indent=2)
logger.info(f"Index built successfully from precomputed embeddings: {index_path}")
def update_index(self, index_path: str):
"""Append new passages and vectors to an existing HNSW index."""
if not self.chunks:
raise ValueError("No new chunks provided for update.")
path = Path(index_path)
index_dir = path.parent
index_name = path.name
index_prefix = path.stem
meta_path = index_dir / f"{index_name}.meta.json"
passages_file = index_dir / f"{index_name}.passages.jsonl"
offset_file = index_dir / f"{index_name}.passages.idx"
index_file = index_dir / f"{index_prefix}.index"
if not meta_path.exists() or not passages_file.exists() or not offset_file.exists():
raise FileNotFoundError("Index metadata or passage files are missing; cannot update.")
if not index_file.exists():
raise FileNotFoundError(f"HNSW index file not found: {index_file}")
with open(meta_path, encoding="utf-8") as f:
meta = json.load(f)
backend_name = meta.get("backend_name")
if backend_name != self.backend_name:
raise ValueError(
f"Index was built with backend '{backend_name}', cannot update with '{self.backend_name}'."
)
meta_backend_kwargs = meta.get("backend_kwargs", {})
index_is_compact = meta.get("is_compact", meta_backend_kwargs.get("is_compact", True))
if index_is_compact:
raise ValueError(
"Compact HNSW indices do not support in-place updates. Rebuild required."
)
distance_metric = meta_backend_kwargs.get(
"distance_metric", self.backend_kwargs.get("distance_metric", "mips")
).lower()
needs_recompute = bool(
meta.get("is_pruned")
or meta_backend_kwargs.get("is_recompute")
or self.backend_kwargs.get("is_recompute")
)
with open(offset_file, "rb") as f:
offset_map: dict[str, int] = pickle.load(f)
existing_ids = set(offset_map.keys())
valid_chunks: list[dict[str, Any]] = []
for chunk in self.chunks:
text = chunk.get("text", "")
if not isinstance(text, str) or not text.strip():
continue
metadata = chunk.setdefault("metadata", {})
passage_id = chunk.get("id") or metadata.get("id")
if passage_id and passage_id in existing_ids:
raise ValueError(f"Passage ID '{passage_id}' already exists in the index.")
valid_chunks.append(chunk)
if not valid_chunks:
raise ValueError("No valid chunks to append.")
texts_to_embed = [chunk["text"] for chunk in valid_chunks]
embeddings = compute_embeddings(
texts_to_embed,
self.embedding_model,
self.embedding_mode,
use_server=False,
is_build=True,
provider_options=self.embedding_options,
)
embedding_dim = embeddings.shape[1]
expected_dim = meta.get("dimensions")
if expected_dim is not None and expected_dim != embedding_dim:
raise ValueError(
f"Dimension mismatch during update: existing index uses {expected_dim}, got {embedding_dim}."
)
from leann_backend_hnsw import faiss # type: ignore
embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
if distance_metric == "cosine":
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
norms[norms == 0] = 1
embeddings = embeddings / norms
index = faiss.read_index(str(index_file))
if hasattr(index, "is_recompute"):
index.is_recompute = needs_recompute
print(f"index.is_recompute: {index.is_recompute}")
if getattr(index, "storage", None) is None:
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
storage_index = faiss.IndexFlatIP(index.d)
else:
storage_index = faiss.IndexFlatL2(index.d)
index.storage = storage_index
index.own_fields = True
# Faiss expects storage.ntotal to reflect the existing graph's
# population (even if the vectors themselves were pruned from disk
# for recompute mode). When we attach a fresh IndexFlat here its
# ntotal starts at zero, which later causes IndexHNSW::add to
# believe new "preset" levels were provided and trips the
# `n0 + n == levels.size()` assertion. Seed the temporary storage
# with the current ntotal so Faiss maintains the proper offset for
# incoming vectors.
try:
storage_index.ntotal = index.ntotal
except AttributeError:
# Older Faiss builds may not expose ntotal as a writable
# attribute; in that case we fall back to the default behaviour.
pass
if index.d != embedding_dim:
raise ValueError(
f"Existing index dimension ({index.d}) does not match new embeddings ({embedding_dim})."
)
passage_meta_mode = meta.get("embedding_mode", self.embedding_mode)
passage_provider_options = meta.get("embedding_options", self.embedding_options)
base_id = index.ntotal
for offset, chunk in enumerate(valid_chunks):
new_id = str(base_id + offset)
chunk.setdefault("metadata", {})["id"] = new_id
chunk["id"] = new_id
# Append passages/offsets before we attempt index.add so the ZMQ server
# can resolve newly assigned IDs during recompute. Keep rollback hooks
# so we can restore files if the update fails mid-way.
rollback_passages_size = passages_file.stat().st_size if passages_file.exists() else 0
offset_map_backup = offset_map.copy()
try:
with open(passages_file, "a", encoding="utf-8") as f:
for chunk in valid_chunks:
offset = f.tell()
json.dump(
{
"id": chunk["id"],
"text": chunk["text"],
"metadata": chunk.get("metadata", {}),
},
f,
ensure_ascii=False,
)
f.write("\n")
offset_map[chunk["id"]] = offset
with open(offset_file, "wb") as f:
pickle.dump(offset_map, f)
server_manager: Optional[EmbeddingServerManager] = None
server_started = False
requested_zmq_port = int(os.getenv("LEANN_UPDATE_ZMQ_PORT", "5557"))
try:
if needs_recompute:
server_manager = EmbeddingServerManager(
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
)
server_started, actual_port = server_manager.start_server(
port=requested_zmq_port,
model_name=self.embedding_model,
embedding_mode=passage_meta_mode,
passages_file=str(meta_path),
distance_metric=distance_metric,
provider_options=passage_provider_options,
)
if not server_started:
raise RuntimeError(
"Failed to start HNSW embedding server for recompute update."
)
if actual_port != requested_zmq_port:
logger.warning(
"Embedding server started on port %s instead of requested %s. "
"Using reassigned port.",
actual_port,
requested_zmq_port,
)
try:
index.hnsw.zmq_port = actual_port
except AttributeError:
pass
if needs_recompute:
for i in range(embeddings.shape[0]):
print(f"add {i} embeddings")
index.add(1, faiss.swig_ptr(embeddings[i : i + 1]))
else:
index.add(embeddings.shape[0], faiss.swig_ptr(embeddings))
faiss.write_index(index, str(index_file))
finally:
if server_started and server_manager is not None:
server_manager.stop_server()
except Exception:
# Roll back appended passages/offset map to keep files consistent.
if passages_file.exists():
with open(passages_file, "rb+") as f:
f.truncate(rollback_passages_size)
offset_map = offset_map_backup
with open(offset_file, "wb") as f:
pickle.dump(offset_map, f)
raise
meta["total_passages"] = len(offset_map)
with open(meta_path, "w", encoding="utf-8") as f:
json.dump(meta, f, indent=2)
logger.info(
"Appended %d passages to index '%s'. New total: %d",
len(valid_chunks),
index_path,
len(offset_map),
)
self.chunks.clear()
if needs_recompute:
prune_hnsw_embeddings_inplace(str(index_file))
class LeannSearcher:
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
@@ -885,20 +473,12 @@ class LeannSearcher:
self.embedding_model = self.meta_data["embedding_model"]
# Support both old and new format
self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
self.embedding_options = self.meta_data.get("embedding_options", {})
# Delegate portability handling to PassageManager
self.passage_manager = PassageManager(
self.meta_data.get("passage_sources", []), metadata_file_path=self.meta_path_str
)
# Preserve backend name for conditional parameter forwarding
self.backend_name = backend_name
self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
backend_factory = BACKEND_REGISTRY.get(backend_name)
if backend_factory is None:
raise ValueError(f"Backend '{backend_name}' not found.")
final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
final_kwargs["enable_warmup"] = enable_warmup
if self.embedding_options:
final_kwargs.setdefault("embedding_options", self.embedding_options)
self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
index_path, **final_kwargs
)
@@ -913,49 +493,15 @@ class LeannSearcher:
recompute_embeddings: bool = True,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
expected_zmq_port: int = 5557,
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
batch_size: int = 0,
use_grep: bool = False,
**kwargs,
) -> list[SearchResult]:
"""
Search for nearest neighbors with optional metadata filtering.
Args:
query: Text query to search for
top_k: Number of nearest neighbors to return
complexity: Search complexity/candidate list size, higher = more accurate but slower
beam_width: Number of parallel search paths/IO requests per iteration
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored codes
pruning_strategy: Candidate selection strategy - "global" (default), "local", or "proportional"
expected_zmq_port: ZMQ port for embedding server communication
metadata_filters: Optional filters to apply to search results based on metadata.
Format: {"field_name": {"operator": value}}
Supported operators:
- Comparison: "==", "!=", "<", "<=", ">", ">="
- Membership: "in", "not_in"
- String: "contains", "starts_with", "ends_with"
Example: {"chapter": {"<=": 5}, "tags": {"in": ["fiction", "drama"]}}
**kwargs: Backend-specific parameters
Returns:
List of SearchResult objects with text, metadata, and similarity scores
"""
# Handle grep search
if use_grep:
return self._grep_search(query, top_k)
logger.info("🔍 LeannSearcher.search() called:")
logger.info(f" Query: '{query}'")
logger.info(f" Top_k: {top_k}")
logger.info(f" Metadata filters: {metadata_filters}")
logger.info(f" Additional kwargs: {kwargs}")
# Smart top_k detection and adjustment
# Use PassageManager length (sum of shard sizes) to avoid
# depending on a massive combined map
total_docs = len(self.passage_manager)
total_docs = len(self.passage_manager.global_offset_map)
original_top_k = top_k
if top_k > total_docs:
top_k = total_docs
@@ -984,39 +530,29 @@ class LeannSearcher:
use_server_if_available=recompute_embeddings,
zmq_port=zmq_port,
)
logger.info(f" Generated embedding shape: {query_embedding.shape}")
embedding_time = time.time() - start_time
logger.info(f" Embedding time: {embedding_time} seconds")
# logger.info(f" Generated embedding shape: {query_embedding.shape}")
time.time() - start_time
# logger.info(f" Embedding time: {embedding_time} seconds")
start_time = time.time()
backend_search_kwargs: dict[str, Any] = {
"complexity": complexity,
"beam_width": beam_width,
"prune_ratio": prune_ratio,
"recompute_embeddings": recompute_embeddings,
"pruning_strategy": pruning_strategy,
"zmq_port": zmq_port,
}
# Only HNSW supports batching; forward conditionally
if self.backend_name == "hnsw":
backend_search_kwargs["batch_size"] = batch_size
# Merge any extra kwargs last
backend_search_kwargs.update(kwargs)
results = self.backend_impl.search(
query_embedding,
top_k,
**backend_search_kwargs,
complexity=complexity,
beam_width=beam_width,
prune_ratio=prune_ratio,
recompute_embeddings=recompute_embeddings,
pruning_strategy=pruning_strategy,
zmq_port=zmq_port,
**kwargs,
)
search_time = time.time() - start_time
logger.info(f" Search time in search() LEANN searcher: {search_time} seconds")
time.time() - start_time
# logger.info(f" Search time: {search_time} seconds")
logger.info(f" Backend returned: labels={len(results.get('labels', [[]])[0])} results")
enriched_results = []
if "labels" in results and "distances" in results:
logger.info(f" Processing {len(results['labels'][0])} passage IDs:")
# Python 3.9 does not support zip(strict=...); lengths are expected to match
for i, (string_id, dist) in enumerate(
zip(results["labels"][0], results["distances"][0])
):
@@ -1044,138 +580,13 @@ class LeannSearcher:
)
except KeyError:
RED = "\033[91m"
RESET = "\033[0m"
logger.error(
f" {RED}{RESET} [{i + 1:2d}] ID: '{string_id}' -> {RED}ERROR: Passage not found!{RESET}"
)
# Apply metadata filters if specified
if metadata_filters:
logger.info(f" 🔍 Applying metadata filters: {metadata_filters}")
enriched_results = self.passage_manager.filter_search_results(
enriched_results, metadata_filters
)
# Define color codes outside the loop for final message
GREEN = "\033[92m"
RESET = "\033[0m"
logger.info(f" {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}")
return enriched_results
def _find_jsonl_file(self) -> Optional[str]:
"""Find the .jsonl file containing raw passages for grep search"""
index_path = Path(self.meta_path_str).parent
potential_files = [
index_path / "documents.leann.passages.jsonl",
index_path.parent / "documents.leann.passages.jsonl",
]
for file_path in potential_files:
if file_path.exists():
return str(file_path)
return None
def _grep_search(self, query: str, top_k: int = 5) -> list[SearchResult]:
"""Perform grep-based search on raw passages"""
jsonl_file = self._find_jsonl_file()
if not jsonl_file:
raise FileNotFoundError("No .jsonl passages file found for grep search")
try:
cmd = ["grep", "-i", "-n", query, jsonl_file]
result = subprocess.run(cmd, capture_output=True, text=True, check=False)
if result.returncode == 1:
return []
elif result.returncode != 0:
raise RuntimeError(f"Grep failed: {result.stderr}")
matches = []
for line in result.stdout.strip().split("\n"):
if not line:
continue
parts = line.split(":", 1)
if len(parts) != 2:
continue
try:
data = json.loads(parts[1])
text = data.get("text", "")
score = text.lower().count(query.lower())
matches.append(
SearchResult(
id=data.get("id", parts[0]),
text=text,
metadata=data.get("metadata", {}),
score=float(score),
)
)
except json.JSONDecodeError:
continue
matches.sort(key=lambda x: x.score, reverse=True)
return matches[:top_k]
except FileNotFoundError:
raise RuntimeError(
"grep command not found. Please install grep or use semantic search."
)
def _python_regex_search(self, query: str, top_k: int = 5) -> list[SearchResult]:
"""Fallback regex search"""
jsonl_file = self._find_jsonl_file()
if not jsonl_file:
raise FileNotFoundError("No .jsonl file found")
pattern = re.compile(re.escape(query), re.IGNORECASE)
matches = []
with open(jsonl_file, encoding="utf-8") as f:
for line_num, line in enumerate(f, 1):
if pattern.search(line):
try:
data = json.loads(line.strip())
matches.append(
SearchResult(
id=data.get("id", str(line_num)),
text=data.get("text", ""),
metadata=data.get("metadata", {}),
score=float(len(pattern.findall(data.get("text", "")))),
)
)
except json.JSONDecodeError:
continue
matches.sort(key=lambda x: x.score, reverse=True)
return matches[:top_k]
def cleanup(self):
"""Explicitly cleanup embedding server resources.
This method should be called after you're done using the searcher,
especially in test environments or batch processing scenarios.
"""
backend = getattr(self.backend_impl, "embedding_server_manager", None)
if backend is not None:
backend.stop_server()
# Enable automatic cleanup patterns
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
try:
self.cleanup()
except Exception:
pass
def __del__(self):
try:
self.cleanup()
except Exception:
# Avoid noisy errors during interpreter shutdown
pass
class LeannChat:
def __init__(
@@ -1183,15 +594,9 @@ class LeannChat:
index_path: str,
llm_config: Optional[dict[str, Any]] = None,
enable_warmup: bool = False,
searcher: Optional[LeannSearcher] = None,
**kwargs,
):
if searcher is None:
self.searcher = LeannSearcher(index_path, enable_warmup=enable_warmup, **kwargs)
self._owns_searcher = True
else:
self.searcher = searcher
self._owns_searcher = False
self.searcher = LeannSearcher(index_path, enable_warmup=enable_warmup, **kwargs)
self.llm = get_llm(llm_config)
def ask(
@@ -1205,9 +610,6 @@ class LeannChat:
pruning_strategy: Literal["global", "local", "proportional"] = "global",
llm_kwargs: Optional[dict[str, Any]] = None,
expected_zmq_port: int = 5557,
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
batch_size: int = 0,
use_grep: bool = False,
**search_kwargs,
):
if llm_kwargs is None:
@@ -1222,12 +624,10 @@ class LeannChat:
recompute_embeddings=recompute_embeddings,
pruning_strategy=pruning_strategy,
expected_zmq_port=expected_zmq_port,
metadata_filters=metadata_filters,
batch_size=batch_size,
**search_kwargs,
)
search_time = time.time() - search_time
logger.info(f" Search time: {search_time} seconds")
# logger.info(f" Search time: {search_time} seconds")
context = "\n\n".join([r.text for r in results])
prompt = (
"Here is some retrieved context that might help answer your question:\n\n"
@@ -1243,38 +643,16 @@ class LeannChat:
return ans
def start_interactive(self):
"""Start interactive chat session."""
session = create_api_session()
def handle_query(user_input: str):
response = self.ask(user_input)
print(f"Leann: {response}")
session.run_interactive_loop(handle_query)
def cleanup(self):
"""Explicitly cleanup embedding server resources.
This method should be called after you're done using the chat interface,
especially in test environments or batch processing scenarios.
"""
# Only stop the embedding server if this LeannChat instance created the searcher.
# When a shared searcher is passed in, avoid shutting down the server to enable reuse.
if getattr(self, "_owns_searcher", False) and hasattr(self.searcher, "cleanup"):
self.searcher.cleanup()
# Enable automatic cleanup patterns
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
try:
self.cleanup()
except Exception:
pass
def __del__(self):
try:
self.cleanup()
except Exception:
pass
print("\nLeann Chat started (type 'quit' to exit)")
while True:
try:
user_input = input("You: ").strip()
if user_input.lower() in ["quit", "exit"]:
break
if not user_input:
continue
response = self.ask(user_input)
print(f"Leann: {response}")
except (KeyboardInterrupt, EOFError):
print("\nGoodbye!")
break

View File

@@ -12,8 +12,6 @@ from typing import Any, Optional
import torch
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -312,12 +310,11 @@ def search_hf_models(query: str, limit: int = 10) -> list[str]:
def validate_model_and_suggest(
model_name: str, llm_type: str, host: Optional[str] = None
model_name: str, llm_type: str, host: str = "http://localhost:11434"
) -> Optional[str]:
"""Validate model name and provide suggestions if invalid"""
if llm_type == "ollama":
resolved_host = resolve_ollama_host(host)
available_models = check_ollama_models(resolved_host)
available_models = check_ollama_models(host)
if available_models and model_name not in available_models:
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
@@ -425,6 +422,7 @@ class LLMInterface(ABC):
top_k=10,
complexity=64,
beam_width=8,
USE_DEFERRED_FETCH=True,
skip_search_reorder=True,
recompute_beighbor_embeddings=True,
dedup_node_dis=True,
@@ -436,6 +434,7 @@ class LLMInterface(ABC):
Supported kwargs:
- complexity (int): Search complexity parameter (default: 32)
- beam_width (int): Beam width for search (default: 4)
- USE_DEFERRED_FETCH (bool): Enable deferred fetch mode (default: False)
- skip_search_reorder (bool): Skip search reorder step (default: False)
- recompute_beighbor_embeddings (bool): Enable ZMQ embedding server for neighbor recomputation (default: False)
- dedup_node_dis (bool): Deduplicate nodes by distance (default: False)
@@ -460,19 +459,19 @@ class LLMInterface(ABC):
class OllamaChat(LLMInterface):
"""LLM interface for Ollama models."""
def __init__(self, model: str = "llama3:8b", host: Optional[str] = None):
def __init__(self, model: str = "llama3:8b", host: str = "http://localhost:11434"):
self.model = model
self.host = resolve_ollama_host(host)
logger.info(f"Initializing OllamaChat with model='{model}' and host='{self.host}'")
self.host = host
logger.info(f"Initializing OllamaChat with model='{model}' and host='{host}'")
try:
import requests
# Check if the Ollama server is responsive
if self.host:
requests.get(self.host)
if host:
requests.get(host)
# Pre-check model availability with helpful suggestions
model_error = validate_model_and_suggest(model, "ollama", self.host)
model_error = validate_model_and_suggest(model, "ollama", host)
if model_error:
raise ValueError(model_error)
@@ -481,11 +480,9 @@ class OllamaChat(LLMInterface):
"The 'requests' library is required for Ollama. Please install it with 'pip install requests'."
)
except requests.exceptions.ConnectionError:
logger.error(
f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
)
logger.error(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.")
raise ConnectionError(
f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
)
def ask(self, prompt: str, **kwargs) -> str:
@@ -546,30 +543,11 @@ class OllamaChat(LLMInterface):
class HFChat(LLMInterface):
"""LLM interface for local Hugging Face Transformers models with proper chat templates.
"""LLM interface for local Hugging Face Transformers models with proper chat templates."""
Args:
model_name (str): Name of the Hugging Face model to load.
trust_remote_code (bool): Whether to allow execution of code from the model repository.
Defaults to False for security. Only enable for trusted models as this can pose
a security risk if the model repository is compromised.
"""
def __init__(
self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat", trust_remote_code: bool = False
):
def __init__(self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat"):
logger.info(f"Initializing HFChat with model='{model_name}'")
# Security warning when trust_remote_code is enabled
if trust_remote_code:
logger.warning(
"SECURITY WARNING: trust_remote_code=True allows execution of arbitrary code from the model repository. "
"Only enable this for models from trusted sources. This creates a potential security risk if the model "
"repository is compromised."
)
self.trust_remote_code = trust_remote_code
# Pre-check model availability with helpful suggestions
model_error = validate_model_and_suggest(model_name, "hf")
if model_error:
@@ -607,16 +585,14 @@ class HFChat(LLMInterface):
try:
logger.info(f"Loading tokenizer for {model_name}...")
self.tokenizer = AutoTokenizer.from_pretrained(
model_name, trust_remote_code=self.trust_remote_code
)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
logger.info(f"Loading model {model_name}...")
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
device_map="auto" if self.device != "cpu" else None,
trust_remote_code=self.trust_remote_code,
trust_remote_code=True,
)
logger.info(f"Successfully loaded {model_name}")
finally:
@@ -706,88 +682,24 @@ class HFChat(LLMInterface):
return response.strip()
class GeminiChat(LLMInterface):
"""LLM interface for Google Gemini models."""
def __init__(self, model: str = "gemini-2.5-flash", api_key: Optional[str] = None):
self.model = model
self.api_key = api_key or os.getenv("GEMINI_API_KEY")
if not self.api_key:
raise ValueError(
"Gemini API key is required. Set GEMINI_API_KEY environment variable or pass api_key parameter."
)
logger.info(f"Initializing Gemini Chat with model='{model}'")
try:
import google.genai as genai
self.client = genai.Client(api_key=self.api_key)
except ImportError:
raise ImportError(
"The 'google-genai' library is required for Gemini models. Please install it with 'uv pip install google-genai'."
)
def ask(self, prompt: str, **kwargs) -> str:
logger.info(f"Sending request to Gemini with model {self.model}")
try:
from google.genai.types import GenerateContentConfig
generation_config = GenerateContentConfig(
temperature=kwargs.get("temperature", 0.7),
max_output_tokens=kwargs.get("max_tokens", 1000),
)
# Handle top_p parameter
if "top_p" in kwargs:
generation_config.top_p = kwargs["top_p"]
response = self.client.models.generate_content(
model=self.model,
contents=prompt,
config=generation_config,
)
# Handle potential None response text
response_text = response.text
if response_text is None:
logger.warning("Gemini returned None response text")
return ""
return response_text.strip()
except Exception as e:
logger.error(f"Error communicating with Gemini: {e}")
return f"Error: Could not get a response from Gemini. Details: {e}"
class OpenAIChat(LLMInterface):
"""LLM interface for OpenAI models."""
def __init__(
self,
model: str = "gpt-4o",
api_key: Optional[str] = None,
base_url: Optional[str] = None,
):
def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None):
self.model = model
self.base_url = resolve_openai_base_url(base_url)
self.api_key = resolve_openai_api_key(api_key)
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
if not self.api_key:
raise ValueError(
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter."
)
logger.info(
"Initializing OpenAI Chat with model='%s' and base_url='%s'",
model,
self.base_url,
)
logger.info(f"Initializing OpenAI Chat with model='{model}'")
try:
import openai
self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
self.client = openai.OpenAI(api_key=self.api_key)
except ImportError:
raise ImportError(
"The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'."
@@ -877,21 +789,12 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
if llm_type == "ollama":
return OllamaChat(
model=model or "llama3:8b",
host=llm_config.get("host"),
host=llm_config.get("host", "http://localhost:11434"),
)
elif llm_type == "hf":
return HFChat(
model_name=model or "deepseek-ai/deepseek-llm-7b-chat",
trust_remote_code=llm_config.get("trust_remote_code", False),
)
return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
elif llm_type == "openai":
return OpenAIChat(
model=model or "gpt-4o",
api_key=llm_config.get("api_key"),
base_url=llm_config.get("base_url"),
)
elif llm_type == "gemini":
return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key"))
elif llm_type == "simulated":
return SimulatedChat()
else:

View File

@@ -1,220 +0,0 @@
"""
Enhanced chunking utilities with AST-aware code chunking support.
Packaged within leann-core so installed wheels can import it reliably.
"""
import logging
from pathlib import Path
from typing import Optional
from llama_index.core.node_parser import SentenceSplitter
logger = logging.getLogger(__name__)
# Code file extensions supported by astchunk
CODE_EXTENSIONS = {
".py": "python",
".java": "java",
".cs": "csharp",
".ts": "typescript",
".tsx": "typescript",
".js": "typescript",
".jsx": "typescript",
}
def detect_code_files(documents, code_extensions=None) -> tuple[list, list]:
"""Separate documents into code files and regular text files."""
if code_extensions is None:
code_extensions = CODE_EXTENSIONS
code_docs = []
text_docs = []
for doc in documents:
file_path = doc.metadata.get("file_path", "") or doc.metadata.get("file_name", "")
if file_path:
file_ext = Path(file_path).suffix.lower()
if file_ext in code_extensions:
doc.metadata["language"] = code_extensions[file_ext]
doc.metadata["is_code"] = True
code_docs.append(doc)
else:
doc.metadata["is_code"] = False
text_docs.append(doc)
else:
doc.metadata["is_code"] = False
text_docs.append(doc)
logger.info(f"Detected {len(code_docs)} code files and {len(text_docs)} text files")
return code_docs, text_docs
def get_language_from_extension(file_path: str) -> Optional[str]:
"""Return language string from a filename/extension using CODE_EXTENSIONS."""
ext = Path(file_path).suffix.lower()
return CODE_EXTENSIONS.get(ext)
def create_ast_chunks(
documents,
max_chunk_size: int = 512,
chunk_overlap: int = 64,
metadata_template: str = "default",
) -> list[str]:
"""Create AST-aware chunks from code documents using astchunk.
Falls back to traditional chunking if astchunk is unavailable.
"""
try:
from astchunk import ASTChunkBuilder # optional dependency
except ImportError as e:
logger.error(f"astchunk not available: {e}")
logger.info("Falling back to traditional chunking for code files")
return create_traditional_chunks(documents, max_chunk_size, chunk_overlap)
all_chunks = []
for doc in documents:
language = doc.metadata.get("language")
if not language:
logger.warning("No language detected; falling back to traditional chunking")
all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
continue
try:
configs = {
"max_chunk_size": max_chunk_size,
"language": language,
"metadata_template": metadata_template,
"chunk_overlap": chunk_overlap if chunk_overlap > 0 else 0,
}
repo_metadata = {
"file_path": doc.metadata.get("file_path", ""),
"file_name": doc.metadata.get("file_name", ""),
"creation_date": doc.metadata.get("creation_date", ""),
"last_modified_date": doc.metadata.get("last_modified_date", ""),
}
configs["repo_level_metadata"] = repo_metadata
chunk_builder = ASTChunkBuilder(**configs)
code_content = doc.get_content()
if not code_content or not code_content.strip():
logger.warning("Empty code content, skipping")
continue
chunks = chunk_builder.chunkify(code_content)
for chunk in chunks:
if hasattr(chunk, "text"):
chunk_text = chunk.text
elif isinstance(chunk, dict) and "text" in chunk:
chunk_text = chunk["text"]
elif isinstance(chunk, str):
chunk_text = chunk
else:
chunk_text = str(chunk)
if chunk_text and chunk_text.strip():
all_chunks.append(chunk_text.strip())
logger.info(
f"Created {len(chunks)} AST chunks from {language} file: {doc.metadata.get('file_name', 'unknown')}"
)
except Exception as e:
logger.warning(f"AST chunking failed for {language} file: {e}")
logger.info("Falling back to traditional chunking")
all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
return all_chunks
def create_traditional_chunks(
documents, chunk_size: int = 256, chunk_overlap: int = 128
) -> list[str]:
"""Create traditional text chunks using LlamaIndex SentenceSplitter."""
if chunk_size <= 0:
logger.warning(f"Invalid chunk_size={chunk_size}, using default value of 256")
chunk_size = 256
if chunk_overlap < 0:
chunk_overlap = 0
if chunk_overlap >= chunk_size:
chunk_overlap = chunk_size // 2
node_parser = SentenceSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
separator=" ",
paragraph_separator="\n\n",
)
all_texts = []
for doc in documents:
try:
nodes = node_parser.get_nodes_from_documents([doc])
if nodes:
all_texts.extend(node.get_content() for node in nodes)
except Exception as e:
logger.error(f"Traditional chunking failed for document: {e}")
content = doc.get_content()
if content and content.strip():
all_texts.append(content.strip())
return all_texts
def create_text_chunks(
documents,
chunk_size: int = 256,
chunk_overlap: int = 128,
use_ast_chunking: bool = False,
ast_chunk_size: int = 512,
ast_chunk_overlap: int = 64,
code_file_extensions: Optional[list[str]] = None,
ast_fallback_traditional: bool = True,
) -> list[str]:
"""Create text chunks from documents with optional AST support for code files."""
if not documents:
logger.warning("No documents provided for chunking")
return []
local_code_extensions = CODE_EXTENSIONS.copy()
if code_file_extensions:
ext_mapping = {
".py": "python",
".java": "java",
".cs": "c_sharp",
".ts": "typescript",
".tsx": "typescript",
}
for ext in code_file_extensions:
if ext.lower() not in local_code_extensions:
if ext.lower() in ext_mapping:
local_code_extensions[ext.lower()] = ext_mapping[ext.lower()]
else:
logger.warning(f"Unsupported extension {ext}, will use traditional chunking")
all_chunks = []
if use_ast_chunking:
code_docs, text_docs = detect_code_files(documents, local_code_extensions)
if code_docs:
try:
all_chunks.extend(
create_ast_chunks(
code_docs, max_chunk_size=ast_chunk_size, chunk_overlap=ast_chunk_overlap
)
)
except Exception as e:
logger.error(f"AST chunking failed: {e}")
if ast_fallback_traditional:
all_chunks.extend(
create_traditional_chunks(code_docs, chunk_size, chunk_overlap)
)
else:
raise
if text_docs:
all_chunks.extend(create_traditional_chunks(text_docs, chunk_size, chunk_overlap))
else:
all_chunks = create_traditional_chunks(documents, chunk_size, chunk_overlap)
logger.info(f"Total chunks created: {len(all_chunks)}")
return all_chunks

View File

File diff suppressed because it is too large Load Diff

View File

@@ -6,14 +6,12 @@ Preserves all optimization parameters to ensure performance
import logging
import os
import time
from typing import Any, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any
import numpy as np
import torch
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
# Set up logger with proper level
logger = logging.getLogger(__name__)
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
@@ -31,9 +29,6 @@ def compute_embeddings(
is_build: bool = False,
batch_size: int = 32,
adaptive_optimization: bool = True,
manual_tokenize: bool = False,
max_length: int = 512,
provider_options: Optional[dict[str, Any]] = None,
) -> np.ndarray:
"""
Unified embedding computation entry point
@@ -49,8 +44,6 @@ def compute_embeddings(
Returns:
Normalized embeddings array, shape: (len(texts), embedding_dim)
"""
provider_options = provider_options or {}
if mode == "sentence-transformers":
return compute_embeddings_sentence_transformers(
texts,
@@ -58,27 +51,13 @@ def compute_embeddings(
is_build=is_build,
batch_size=batch_size,
adaptive_optimization=adaptive_optimization,
manual_tokenize=manual_tokenize,
max_length=max_length,
)
elif mode == "openai":
return compute_embeddings_openai(
texts,
model_name,
base_url=provider_options.get("base_url"),
api_key=provider_options.get("api_key"),
)
return compute_embeddings_openai(texts, model_name)
elif mode == "mlx":
return compute_embeddings_mlx(texts, model_name)
elif mode == "ollama":
return compute_embeddings_ollama(
texts,
model_name,
is_build=is_build,
host=provider_options.get("host"),
)
elif mode == "gemini":
return compute_embeddings_gemini(texts, model_name, is_build=is_build)
return compute_embeddings_ollama(texts, model_name, is_build=is_build)
else:
raise ValueError(f"Unsupported embedding mode: {mode}")
@@ -91,8 +70,6 @@ def compute_embeddings_sentence_transformers(
batch_size: int = 32,
is_build: bool = False,
adaptive_optimization: bool = True,
manual_tokenize: bool = False,
max_length: int = 512,
) -> np.ndarray:
"""
Compute embeddings using SentenceTransformer with model caching and adaptive optimization
@@ -236,130 +213,20 @@ def compute_embeddings_sentence_transformers(
logger.info(f"Model cached: {cache_key}")
# Compute embeddings with optimized inference mode
logger.info(
f"Starting embedding computation... (batch_size: {batch_size}, manual_tokenize={manual_tokenize})"
)
logger.info(f"Starting embedding computation... (batch_size: {batch_size})")
start_time = time.time()
if not manual_tokenize:
# Use SentenceTransformer's optimized encode path (default)
with torch.inference_mode():
embeddings = model.encode(
texts,
batch_size=batch_size,
show_progress_bar=is_build, # Don't show progress bar in server environment
convert_to_numpy=True,
normalize_embeddings=False,
device=device,
)
# Synchronize if CUDA to measure accurate wall time
try:
if torch.cuda.is_available():
torch.cuda.synchronize()
except Exception:
pass
else:
# Manual tokenization + forward pass using HF AutoTokenizer/AutoModel
try:
from transformers import AutoModel, AutoTokenizer # type: ignore
except Exception as e:
raise ImportError(f"transformers is required for manual_tokenize=True: {e}")
# Use torch.inference_mode for optimal performance
with torch.inference_mode():
embeddings = model.encode(
texts,
batch_size=batch_size,
show_progress_bar=is_build, # Don't show progress bar in server environment
convert_to_numpy=True,
normalize_embeddings=False,
device=device,
)
# Cache tokenizer and model
tok_cache_key = f"hf_tokenizer_{model_name}"
mdl_cache_key = f"hf_model_{model_name}_{device}_{use_fp16}"
if tok_cache_key in _model_cache and mdl_cache_key in _model_cache:
hf_tokenizer = _model_cache[tok_cache_key]
hf_model = _model_cache[mdl_cache_key]
logger.info("Using cached HF tokenizer/model for manual path")
else:
logger.info("Loading HF tokenizer/model for manual tokenization path")
hf_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
torch_dtype = torch.float16 if (use_fp16 and device == "cuda") else torch.float32
hf_model = AutoModel.from_pretrained(model_name, torch_dtype=torch_dtype)
hf_model.to(device)
hf_model.eval()
# Optional compile on supported devices
if device in ["cuda", "mps"]:
try:
hf_model = torch.compile(hf_model, mode="reduce-overhead", dynamic=True) # type: ignore
except Exception:
pass
_model_cache[tok_cache_key] = hf_tokenizer
_model_cache[mdl_cache_key] = hf_model
all_embeddings: list[np.ndarray] = []
# Progress bar when building or for large inputs
show_progress = is_build or len(texts) > 32
try:
if show_progress:
from tqdm import tqdm # type: ignore
batch_iter = tqdm(
range(0, len(texts), batch_size),
desc="Embedding (manual)",
unit="batch",
)
else:
batch_iter = range(0, len(texts), batch_size)
except Exception:
batch_iter = range(0, len(texts), batch_size)
start_time_manual = time.time()
with torch.inference_mode():
for start_index in batch_iter:
end_index = min(start_index + batch_size, len(texts))
batch_texts = texts[start_index:end_index]
tokenize_start_time = time.time()
inputs = hf_tokenizer(
batch_texts,
padding=True,
truncation=True,
max_length=max_length,
return_tensors="pt",
)
tokenize_end_time = time.time()
logger.info(
f"Tokenize time taken: {tokenize_end_time - tokenize_start_time} seconds"
)
# Print shapes of all input tensors for debugging
for k, v in inputs.items():
print(f"inputs[{k!r}] shape: {getattr(v, 'shape', type(v))}")
to_device_start_time = time.time()
inputs = {k: v.to(device) for k, v in inputs.items()}
to_device_end_time = time.time()
logger.info(
f"To device time taken: {to_device_end_time - to_device_start_time} seconds"
)
forward_start_time = time.time()
outputs = hf_model(**inputs)
forward_end_time = time.time()
logger.info(f"Forward time taken: {forward_end_time - forward_start_time} seconds")
last_hidden_state = outputs.last_hidden_state # (B, L, H)
attention_mask = inputs.get("attention_mask")
if attention_mask is None:
# Fallback: assume all tokens are valid
pooled = last_hidden_state.mean(dim=1)
else:
mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
masked = last_hidden_state * mask
lengths = mask.sum(dim=1).clamp(min=1)
pooled = masked.sum(dim=1) / lengths
# Move to CPU float32
batch_embeddings = pooled.detach().to("cpu").float().numpy()
all_embeddings.append(batch_embeddings)
embeddings = np.vstack(all_embeddings).astype(np.float32, copy=False)
try:
if torch.cuda.is_available():
torch.cuda.synchronize()
except Exception:
pass
end_time = time.time()
logger.info(f"Manual tokenize time taken: {end_time - start_time_manual} seconds")
end_time = time.time()
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
logger.info(f"Time taken: {end_time - start_time} seconds")
# Validate results
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
@@ -368,41 +235,26 @@ def compute_embeddings_sentence_transformers(
return embeddings
def compute_embeddings_openai(
texts: list[str],
model_name: str,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
) -> np.ndarray:
def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
# TODO: @yichuan-w add progress bar only in build mode
"""Compute embeddings using OpenAI API"""
try:
import os
import openai
except ImportError as e:
raise ImportError(f"OpenAI package not installed: {e}")
# Validate input list
if not texts:
raise ValueError("Cannot compute embeddings for empty text list")
# Extra validation: abort early if any item is empty/whitespace
invalid_count = sum(1 for t in texts if not isinstance(t, str) or not t.strip())
if invalid_count > 0:
raise ValueError(
f"Found {invalid_count} empty/invalid text(s) in input. Upstream should filter before calling OpenAI."
)
resolved_base_url = resolve_openai_base_url(base_url)
resolved_api_key = resolve_openai_api_key(api_key)
if not resolved_api_key:
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
raise RuntimeError("OPENAI_API_KEY environment variable not set")
# Cache OpenAI client
cache_key = f"openai_client::{resolved_base_url}"
cache_key = "openai_client"
if cache_key in _model_cache:
client = _model_cache[cache_key]
else:
client = openai.OpenAI(api_key=resolved_api_key, base_url=resolved_base_url)
client = openai.OpenAI(api_key=api_key)
_model_cache[cache_key] = client
logger.info("OpenAI client cached")
@@ -412,16 +264,8 @@ def compute_embeddings_openai(
print(f"len of texts: {len(texts)}")
# OpenAI has limits on batch size and input length
max_batch_size = 800 # Conservative batch size because the token limit is 300K
max_batch_size = 1000 # Conservative batch size
all_embeddings = []
# get the avg len of texts
avg_len = sum(len(text) for text in texts) / len(texts)
print(f"avg len of texts: {avg_len}")
# if avg len is less than 1000, use the max batch size
if avg_len > 300:
max_batch_size = 500
# if avg len is less than 1000, use the max batch size
try:
from tqdm import tqdm
@@ -527,21 +371,16 @@ def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int =
def compute_embeddings_ollama(
texts: list[str],
model_name: str,
is_build: bool = False,
host: Optional[str] = None,
texts: list[str], model_name: str, is_build: bool = False, host: str = "http://localhost:11434"
) -> np.ndarray:
"""
Compute embeddings using Ollama API with simplified batch processing.
Uses batch size of 32 for MPS/CPU and 128 for CUDA to optimize performance.
Compute embeddings using Ollama API.
Args:
texts: List of texts to compute embeddings for
model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large")
is_build: Whether this is a build operation (shows progress bar)
host: Ollama host URL (defaults to environment or http://localhost:11434)
host: Ollama host URL (default: http://localhost:11434)
Returns:
Normalized embeddings array, shape: (len(texts), embedding_dim)
@@ -556,19 +395,17 @@ def compute_embeddings_ollama(
if not texts:
raise ValueError("Cannot compute embeddings for empty text list")
resolved_host = resolve_ollama_host(host)
logger.info(
f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}', host: '{resolved_host}'"
f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}'"
)
# Check if Ollama is running
try:
response = requests.get(f"{resolved_host}/api/version", timeout=5)
response = requests.get(f"{host}/api/version", timeout=5)
response.raise_for_status()
except requests.exceptions.ConnectionError:
error_msg = (
f"❌ Could not connect to Ollama at {resolved_host}.\n\n"
f"❌ Could not connect to Ollama at {host}.\n\n"
"Please ensure Ollama is running:\n"
" • macOS/Linux: ollama serve\n"
" • Windows: Make sure Ollama is running in the system tray\n\n"
@@ -580,7 +417,7 @@ def compute_embeddings_ollama(
# Check if model exists and provide helpful suggestions
try:
response = requests.get(f"{resolved_host}/api/tags", timeout=5)
response = requests.get(f"{host}/api/tags", timeout=5)
response.raise_for_status()
models = response.json()
model_names = [model["name"] for model in models.get("models", [])]
@@ -601,19 +438,12 @@ def compute_embeddings_ollama(
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5"]):
embedding_models.append(model)
# Check if model exists (handle versioned names) and resolve to full name
resolved_model_name = None
for name in model_names:
# Exact match
if model_name == name:
resolved_model_name = name
break
# Match without version tag (use the versioned name)
elif model_name == name.split(":")[0]:
resolved_model_name = name
break
# Check if model exists (handle versioned names)
model_found = any(
model_name == name.split(":")[0] or model_name == name for name in model_names
)
if not resolved_model_name:
if not model_found:
error_msg = f"❌ Model '{model_name}' not found in local Ollama.\n\n"
# Suggest pulling the model
@@ -635,17 +465,10 @@ def compute_embeddings_ollama(
error_msg += "\n📚 Browse more: https://ollama.com/library"
raise ValueError(error_msg)
# Use the resolved model name for all subsequent operations
if resolved_model_name != model_name:
logger.info(f"Resolved model name '{model_name}' to '{resolved_model_name}'")
model_name = resolved_model_name
# Verify the model supports embeddings by testing it
try:
test_response = requests.post(
f"{resolved_host}/api/embeddings",
json={"model": model_name, "prompt": "test"},
timeout=10,
f"{host}/api/embeddings", json={"model": model_name, "prompt": "test"}, timeout=10
)
if test_response.status_code != 200:
error_msg = (
@@ -662,147 +485,162 @@ def compute_embeddings_ollama(
except requests.exceptions.RequestException as e:
logger.warning(f"Could not verify model existence: {e}")
# Determine batch size based on device availability
# Check for CUDA/MPS availability using torch if available
batch_size = 32 # Default for MPS/CPU
try:
import torch
# Process embeddings with optimized concurrent processing
import requests
if torch.cuda.is_available():
batch_size = 128 # CUDA gets larger batch size
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
batch_size = 32 # MPS gets smaller batch size
except ImportError:
# If torch is not available, use conservative batch size
batch_size = 32
def get_single_embedding(text_idx_tuple):
"""Helper function to get embedding for a single text."""
text, idx = text_idx_tuple
max_retries = 3
retry_count = 0
logger.info(f"Using batch size: {batch_size}")
# Truncate very long texts to avoid API issues
truncated_text = text[:8000] if len(text) > 8000 else text
def get_batch_embeddings(batch_texts):
"""Get embeddings for a batch of texts."""
all_embeddings = []
failed_indices = []
while retry_count < max_retries:
try:
response = requests.post(
f"{host}/api/embeddings",
json={"model": model_name, "prompt": truncated_text},
timeout=30,
)
response.raise_for_status()
for i, text in enumerate(batch_texts):
max_retries = 3
retry_count = 0
result = response.json()
embedding = result.get("embedding")
# Truncate very long texts to avoid API issues
truncated_text = text[:8000] if len(text) > 8000 else text
while retry_count < max_retries:
try:
response = requests.post(
f"{resolved_host}/api/embeddings",
json={"model": model_name, "prompt": truncated_text},
timeout=30,
if embedding is None:
raise ValueError(f"No embedding returned for text {idx}")
return idx, embedding
except requests.exceptions.Timeout:
retry_count += 1
if retry_count >= max_retries:
logger.warning(f"Timeout for text {idx} after {max_retries} retries")
return idx, None
except Exception as e:
if retry_count >= max_retries - 1:
logger.error(f"Failed to get embedding for text {idx}: {e}")
return idx, None
retry_count += 1
return idx, None
# Determine if we should use concurrent processing
use_concurrent = (
len(texts) > 5 and not is_build
) # Don't use concurrent in build mode to avoid overwhelming
max_workers = min(4, len(texts)) # Limit concurrent requests to avoid overwhelming Ollama
all_embeddings = [None] * len(texts) # Pre-allocate list to maintain order
failed_indices = []
if use_concurrent:
logger.info(
f"Using concurrent processing with {max_workers} workers for {len(texts)} texts"
)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit all tasks
future_to_idx = {
executor.submit(get_single_embedding, (text, idx)): idx
for idx, text in enumerate(texts)
}
# Add progress bar for concurrent processing
try:
if is_build or len(texts) > 10:
from tqdm import tqdm
futures_iterator = tqdm(
as_completed(future_to_idx),
total=len(texts),
desc="Computing Ollama embeddings",
)
response.raise_for_status()
result = response.json()
embedding = result.get("embedding")
if embedding is None:
raise ValueError(f"No embedding returned for text {i}")
if not isinstance(embedding, list) or len(embedding) == 0:
raise ValueError(f"Invalid embedding format for text {i}")
all_embeddings.append(embedding)
break
except requests.exceptions.Timeout:
retry_count += 1
if retry_count >= max_retries:
logger.warning(f"Timeout for text {i} after {max_retries} retries")
failed_indices.append(i)
all_embeddings.append(None)
break
else:
futures_iterator = as_completed(future_to_idx)
except ImportError:
futures_iterator = as_completed(future_to_idx)
# Collect results as they complete
for future in futures_iterator:
try:
idx, embedding = future.result()
if embedding is not None:
all_embeddings[idx] = embedding
else:
failed_indices.append(idx)
except Exception as e:
retry_count += 1
if retry_count >= max_retries:
logger.error(f"Failed to get embedding for text {i}: {e}")
failed_indices.append(i)
all_embeddings.append(None)
break
return all_embeddings, failed_indices
idx = future_to_idx[future]
logger.error(f"Exception for text {idx}: {e}")
failed_indices.append(idx)
# Process texts in batches
all_embeddings = []
all_failed_indices = []
# Setup progress bar if needed
show_progress = is_build or len(texts) > 10
try:
if show_progress:
from tqdm import tqdm
except ImportError:
show_progress = False
# Process batches
num_batches = (len(texts) + batch_size - 1) // batch_size
if show_progress:
batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings")
else:
batch_iterator = range(num_batches)
# Sequential processing with progress bar
show_progress = is_build or len(texts) > 10
for batch_idx in batch_iterator:
start_idx = batch_idx * batch_size
end_idx = min(start_idx + batch_size, len(texts))
batch_texts = texts[start_idx:end_idx]
try:
if show_progress:
from tqdm import tqdm
batch_embeddings, batch_failed = get_batch_embeddings(batch_texts)
iterator = tqdm(
enumerate(texts), total=len(texts), desc="Computing Ollama embeddings"
)
else:
iterator = enumerate(texts)
except ImportError:
iterator = enumerate(texts)
# Adjust failed indices to global indices
global_failed = [start_idx + idx for idx in batch_failed]
all_failed_indices.extend(global_failed)
all_embeddings.extend(batch_embeddings)
for idx, text in iterator:
result_idx, embedding = get_single_embedding((text, idx))
if embedding is not None:
all_embeddings[idx] = embedding
else:
failed_indices.append(idx)
# Handle failed embeddings
if all_failed_indices:
if len(all_failed_indices) == len(texts):
if failed_indices:
if len(failed_indices) == len(texts):
raise RuntimeError("Failed to compute any embeddings")
logger.warning(
f"Failed to compute embeddings for {len(all_failed_indices)}/{len(texts)} texts"
)
logger.warning(f"Failed to compute embeddings for {len(failed_indices)}/{len(texts)} texts")
# Use zero embeddings as fallback for failed ones
valid_embedding = next((e for e in all_embeddings if e is not None), None)
if valid_embedding:
embedding_dim = len(valid_embedding)
for i, embedding in enumerate(all_embeddings):
if embedding is None:
all_embeddings[i] = [0.0] * embedding_dim
for idx in failed_indices:
all_embeddings[idx] = [0.0] * embedding_dim
# Remove None values
# Remove None values and convert to numpy array
all_embeddings = [e for e in all_embeddings if e is not None]
if not all_embeddings:
raise RuntimeError("No valid embeddings were computed")
# Validate embedding dimensions before creating numpy array
if all_embeddings:
expected_dim = len(all_embeddings[0])
inconsistent_dims = []
for i, embedding in enumerate(all_embeddings):
if len(embedding) != expected_dim:
inconsistent_dims.append((i, len(embedding)))
# Validate embedding dimensions
expected_dim = len(all_embeddings[0])
inconsistent_dims = []
for i, embedding in enumerate(all_embeddings):
if len(embedding) != expected_dim:
inconsistent_dims.append((i, len(embedding)))
if inconsistent_dims:
error_msg = f"Ollama returned inconsistent embedding dimensions. Expected {expected_dim}, but got:\n"
for idx, dim in inconsistent_dims[:10]: # Show first 10 inconsistent ones
error_msg += f" - Text {idx}: {dim} dimensions\n"
if len(inconsistent_dims) > 10:
error_msg += f" ... and {len(inconsistent_dims) - 10} more\n"
error_msg += f"\nThis is likely an Ollama API bug with model '{model_name}'. Please try:\n"
error_msg += "1. Restart Ollama service: 'ollama serve'\n"
error_msg += f"2. Re-pull the model: 'ollama pull {model_name}'\n"
error_msg += (
"3. Use sentence-transformers instead: --embedding-mode sentence-transformers\n"
)
error_msg += "4. Report this issue to Ollama: https://github.com/ollama/ollama/issues"
raise ValueError(error_msg)
if inconsistent_dims:
error_msg = f"Ollama returned inconsistent embedding dimensions. Expected {expected_dim}, but got:\n"
for idx, dim in inconsistent_dims[:10]: # Show first 10 inconsistent ones
error_msg += f" - Text {idx}: {dim} dimensions\n"
if len(inconsistent_dims) > 10:
error_msg += f" ... and {len(inconsistent_dims) - 10} more\n"
error_msg += (
f"\nThis is likely an Ollama API bug with model '{model_name}'. Please try:\n"
)
error_msg += "1. Restart Ollama service: 'ollama serve'\n"
error_msg += f"2. Re-pull the model: 'ollama pull {model_name}'\n"
error_msg += (
"3. Use sentence-transformers instead: --embedding-mode sentence-transformers\n"
)
error_msg += "4. Report this issue to Ollama: https://github.com/ollama/ollama/issues"
raise ValueError(error_msg)
# Convert to numpy array and normalize
embeddings = np.array(all_embeddings, dtype=np.float32)
@@ -814,83 +652,3 @@ def compute_embeddings_ollama(
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
return embeddings
def compute_embeddings_gemini(
texts: list[str], model_name: str = "text-embedding-004", is_build: bool = False
) -> np.ndarray:
"""
Compute embeddings using Google Gemini API.
Args:
texts: List of texts to compute embeddings for
model_name: Gemini model name (default: "text-embedding-004")
is_build: Whether this is a build operation (shows progress bar)
Returns:
Embeddings array, shape: (len(texts), embedding_dim)
"""
try:
import os
import google.genai as genai
except ImportError as e:
raise ImportError(f"Google GenAI package not installed: {e}")
api_key = os.getenv("GEMINI_API_KEY")
if not api_key:
raise RuntimeError("GEMINI_API_KEY environment variable not set")
# Cache Gemini client
cache_key = "gemini_client"
if cache_key in _model_cache:
client = _model_cache[cache_key]
else:
client = genai.Client(api_key=api_key)
_model_cache[cache_key] = client
logger.info("Gemini client cached")
logger.info(
f"Computing embeddings for {len(texts)} texts using Gemini API, model: '{model_name}'"
)
# Gemini supports batch embedding
max_batch_size = 100 # Conservative batch size for Gemini
all_embeddings = []
try:
from tqdm import tqdm
total_batches = (len(texts) + max_batch_size - 1) // max_batch_size
batch_range = range(0, len(texts), max_batch_size)
batch_iterator = tqdm(
batch_range, desc="Computing embeddings", unit="batch", total=total_batches
)
except ImportError:
# Fallback when tqdm is not available
batch_iterator = range(0, len(texts), max_batch_size)
for i in batch_iterator:
batch_texts = texts[i : i + max_batch_size]
try:
# Use the embed_content method from the new Google GenAI SDK
response = client.models.embed_content(
model=model_name,
contents=batch_texts,
config=genai.types.EmbedContentConfig(
task_type="RETRIEVAL_DOCUMENT" # For document embedding
),
)
# Extract embeddings from response
for embedding_data in response.embeddings:
all_embeddings.append(embedding_data.values)
except Exception as e:
logger.error(f"Batch {i} failed: {e}")
raise
embeddings = np.array(all_embeddings, dtype=np.float32)
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
return embeddings

View File

@@ -1,5 +1,4 @@
import atexit
import json
import logging
import os
import socket
@@ -9,9 +8,7 @@ import time
from pathlib import Path
from typing import Optional
from .settings import encode_provider_options
# Lightweight, self-contained server manager with no cross-process inspection
import psutil
# Set up logging based on environment variable
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
@@ -46,86 +43,130 @@ def _check_port(port: int) -> bool:
return s.connect_ex(("localhost", port)) == 0
# Note: All cross-process scanning helpers removed for simplicity
def _safe_resolve(path: Path) -> str:
"""Resolve paths safely even if the target does not yet exist."""
def _check_process_matches_config(
port: int, expected_model: str, expected_passages_file: str
) -> bool:
"""
Check if the process using the port matches our expected model and passages file.
Returns True if matches, False otherwise.
"""
try:
return str(path.resolve(strict=False))
except Exception:
return str(path)
def _safe_stat_signature(path: Path) -> dict:
"""Return a lightweight signature describing the current state of a path."""
signature: dict[str, object] = {"path": _safe_resolve(path)}
try:
stat = path.stat()
except FileNotFoundError:
signature["missing"] = True
except Exception as exc: # pragma: no cover - unexpected filesystem errors
signature["error"] = str(exc)
else:
signature["mtime_ns"] = stat.st_mtime_ns
signature["size"] = stat.st_size
return signature
def _build_passages_signature(passages_file: Optional[str]) -> Optional[dict]:
"""Collect modification signatures for metadata and referenced passage files."""
if not passages_file:
return None
meta_path = Path(passages_file)
signature: dict[str, object] = {"meta": _safe_stat_signature(meta_path)}
try:
with meta_path.open(encoding="utf-8") as fh:
meta = json.load(fh)
except FileNotFoundError:
signature["meta_missing"] = True
signature["sources"] = []
return signature
except json.JSONDecodeError as exc:
signature["meta_error"] = f"json_error:{exc}"
signature["sources"] = []
return signature
except Exception as exc: # pragma: no cover - unexpected errors
signature["meta_error"] = str(exc)
signature["sources"] = []
return signature
base_dir = meta_path.parent
seen_paths: set[str] = set()
source_signatures: list[dict[str, object]] = []
for source in meta.get("passage_sources", []):
for key, kind in (
("path", "passages"),
("path_relative", "passages"),
("index_path", "index"),
("index_path_relative", "index"),
):
raw_path = source.get(key)
if not raw_path:
for proc in psutil.process_iter(["pid", "cmdline"]):
if not _is_process_listening_on_port(proc, port):
continue
candidate = Path(raw_path)
if not candidate.is_absolute():
candidate = base_dir / candidate
resolved = _safe_resolve(candidate)
if resolved in seen_paths:
cmdline = proc.info["cmdline"]
if not cmdline:
continue
seen_paths.add(resolved)
sig = _safe_stat_signature(candidate)
sig["kind"] = kind
source_signatures.append(sig)
signature["sources"] = source_signatures
return signature
return _check_cmdline_matches_config(
cmdline, port, expected_model, expected_passages_file
)
logger.debug(f"No process found listening on port {port}")
return False
except Exception as e:
logger.warning(f"Could not check process on port {port}: {e}")
return False
# Note: All cross-process scanning helpers removed for simplicity
def _is_process_listening_on_port(proc, port: int) -> bool:
"""Check if a process is listening on the given port."""
try:
connections = proc.net_connections()
for conn in connections:
if conn.laddr.port == port and conn.status == psutil.CONN_LISTEN:
return True
return False
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
return False
def _check_cmdline_matches_config(
cmdline: list, port: int, expected_model: str, expected_passages_file: str
) -> bool:
"""Check if command line matches our expected configuration."""
cmdline_str = " ".join(cmdline)
logger.debug(f"Found process on port {port}: {cmdline_str}")
# Check if it's our embedding server
is_embedding_server = any(
server_type in cmdline_str
for server_type in [
"embedding_server",
"leann_backend_diskann.embedding_server",
"leann_backend_hnsw.hnsw_embedding_server",
]
)
if not is_embedding_server:
logger.debug(f"Process on port {port} is not our embedding server")
return False
# Check model name
model_matches = _check_model_in_cmdline(cmdline, expected_model)
# Check passages file if provided
passages_matches = _check_passages_in_cmdline(cmdline, expected_passages_file)
result = model_matches and passages_matches
logger.debug(
f"model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}"
)
return result
def _check_model_in_cmdline(cmdline: list, expected_model: str) -> bool:
"""Check if the command line contains the expected model."""
if "--model-name" not in cmdline:
return False
model_idx = cmdline.index("--model-name")
if model_idx + 1 >= len(cmdline):
return False
actual_model = cmdline[model_idx + 1]
return actual_model == expected_model
def _check_passages_in_cmdline(cmdline: list, expected_passages_file: str) -> bool:
"""Check if the command line contains the expected passages file."""
if "--passages-file" not in cmdline:
return False # Expected but not found
passages_idx = cmdline.index("--passages-file")
if passages_idx + 1 >= len(cmdline):
return False
actual_passages = cmdline[passages_idx + 1]
expected_path = Path(expected_passages_file).resolve()
actual_path = Path(actual_passages).resolve()
return actual_path == expected_path
def _find_compatible_port_or_next_available(
start_port: int, model_name: str, passages_file: str, max_attempts: int = 100
) -> tuple[int, bool]:
"""
Find a port that either has a compatible server or is available.
Returns (port, is_compatible) where is_compatible indicates if we found a matching server.
"""
for port in range(start_port, start_port + max_attempts):
if not _check_port(port):
# Port is available
return port, False
# Port is in use, check if it's compatible
if _check_process_matches_config(port, model_name, passages_file):
logger.info(f"Found compatible server on port {port}")
return port, True
else:
logger.info(f"Port {port} has incompatible server, trying next port...")
raise RuntimeError(
f"Could not find compatible or available port in range {start_port}-{start_port + max_attempts}"
)
class EmbeddingServerManager:
@@ -144,16 +185,7 @@ class EmbeddingServerManager:
self.backend_module_name = backend_module_name
self.server_process: Optional[subprocess.Popen] = None
self.server_port: Optional[int] = None
# Track last-started config for in-process reuse only
self._server_config: Optional[dict] = None
self._atexit_registered = False
# Also register a weakref finalizer to ensure cleanup when manager is GC'ed
try:
import weakref
self._finalizer = weakref.finalize(self, self._finalize_process)
except Exception:
self._finalizer = None
def start_server(
self,
@@ -163,86 +195,35 @@ class EmbeddingServerManager:
**kwargs,
) -> tuple[bool, int]:
"""Start the embedding server."""
# passages_file may be present in kwargs for server CLI, but we don't need it here
provider_options = kwargs.pop("provider_options", None)
passages_file = kwargs.get("passages_file", "")
passages_file = kwargs.get("passages_file")
config_signature = self._build_config_signature(
model_name=model_name,
embedding_mode=embedding_mode,
provider_options=provider_options,
passages_file=passages_file,
)
# If this manager already has a live server, just reuse it
if (
self.server_process
and self.server_process.poll() is None
and self.server_port
and self._server_config == config_signature
):
logger.info("Reusing in-process server")
return True, self.server_port
# Configuration changed, stop existing server before starting a new one
if self.server_process and self.server_process.poll() is None:
logger.info("Existing server configuration differs; restarting embedding server")
self.stop_server()
# Check if we have a compatible server already running
if self._has_compatible_running_server(model_name, passages_file):
logger.info("Found compatible running server!")
return True, port
# For Colab environment, use a different strategy
if _is_colab_environment():
logger.info("Detected Colab environment, using alternative startup strategy")
return self._start_server_colab(
port,
model_name,
embedding_mode,
config_signature=config_signature,
provider_options=provider_options,
**kwargs,
)
return self._start_server_colab(port, model_name, embedding_mode, **kwargs)
# Always pick a fresh available port
try:
actual_port = _get_available_port(port)
except RuntimeError:
logger.error("No available ports found")
return False, port
# Start a new server
return self._start_new_server(
actual_port,
model_name,
embedding_mode,
provider_options=provider_options,
config_signature=config_signature,
**kwargs,
# Find a compatible port or next available
actual_port, is_compatible = _find_compatible_port_or_next_available(
port, model_name, passages_file
)
def _build_config_signature(
self,
*,
model_name: str,
embedding_mode: str,
provider_options: Optional[dict],
passages_file: Optional[str],
) -> dict:
"""Create a signature describing the current server configuration."""
return {
"model_name": model_name,
"passages_file": passages_file or "",
"embedding_mode": embedding_mode,
"provider_options": provider_options or {},
"passages_signature": _build_passages_signature(passages_file),
}
if is_compatible:
logger.info(f"Found compatible server on port {actual_port}")
return True, actual_port
# Start a new server
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
def _start_server_colab(
self,
port: int,
model_name: str,
embedding_mode: str = "sentence-transformers",
*,
config_signature: Optional[dict] = None,
provider_options: Optional[dict] = None,
**kwargs,
) -> tuple[bool, int]:
"""Start server with Colab-specific configuration."""
@@ -260,35 +241,26 @@ class EmbeddingServerManager:
try:
# In Colab, we'll use a more direct approach
self._launch_server_process_colab(
command,
actual_port,
provider_options=provider_options,
config_signature=config_signature,
)
started, ready_port = self._wait_for_server_ready_colab(actual_port)
if started:
self._server_config = config_signature or {
"model_name": model_name,
"passages_file": kwargs.get("passages_file", ""),
"embedding_mode": embedding_mode,
"provider_options": provider_options or {},
}
return started, ready_port
self._launch_server_process_colab(command, actual_port)
return self._wait_for_server_ready_colab(actual_port)
except Exception as e:
logger.error(f"Failed to start embedding server in Colab: {e}")
return False, actual_port
# Note: No compatibility check needed; manager is per-searcher and configs are stable per instance
def _has_compatible_running_server(self, model_name: str, passages_file: str) -> bool:
"""Check if we have a compatible running server."""
if not (self.server_process and self.server_process.poll() is None and self.server_port):
return False
if _check_process_matches_config(self.server_port, model_name, passages_file):
logger.info(f"Existing server process (PID {self.server_process.pid}) is compatible")
return True
logger.info("Existing server process is incompatible. Should start a new server.")
return False
def _start_new_server(
self,
port: int,
model_name: str,
embedding_mode: str,
provider_options: Optional[dict] = None,
config_signature: Optional[dict] = None,
**kwargs,
self, port: int, model_name: str, embedding_mode: str, **kwargs
) -> tuple[bool, int]:
"""Start a new embedding server on the given port."""
logger.info(f"Starting embedding server on port {port}...")
@@ -296,21 +268,8 @@ class EmbeddingServerManager:
command = self._build_server_command(port, model_name, embedding_mode, **kwargs)
try:
self._launch_server_process(
command,
port,
provider_options=provider_options,
config_signature=config_signature,
)
started, ready_port = self._wait_for_server_ready(port)
if started:
self._server_config = config_signature or {
"model_name": model_name,
"passages_file": kwargs.get("passages_file", ""),
"embedding_mode": embedding_mode,
"provider_options": provider_options or {},
}
return started, ready_port
self._launch_server_process(command, port)
return self._wait_for_server_ready(port)
except Exception as e:
logger.error(f"Failed to start embedding server: {e}")
return False, port
@@ -340,85 +299,27 @@ class EmbeddingServerManager:
return command
def _launch_server_process(
self,
command: list,
port: int,
*,
provider_options: Optional[dict] = None,
config_signature: Optional[dict] = None,
) -> None:
def _launch_server_process(self, command: list, port: int) -> None:
"""Launch the server process."""
project_root = Path(__file__).parent.parent.parent.parent.parent
logger.info(f"Command: {' '.join(command)}")
# In CI environment, redirect stdout to avoid buffer deadlock but keep stderr for debugging
# Embedding servers use many print statements that can fill stdout buffers
is_ci = os.environ.get("CI") == "true"
if is_ci:
stdout_target = subprocess.DEVNULL
stderr_target = None # Keep stderr for error debugging in CI
logger.info(
"CI environment detected, redirecting embedding server stdout to DEVNULL, keeping stderr"
)
else:
stdout_target = None # Direct to console for visible logs
stderr_target = None # Direct to console for visible logs
# Start embedding server subprocess
logger.info(f"Starting server process with command: {' '.join(command)}")
env = os.environ.copy()
encoded_options = encode_provider_options(provider_options)
if encoded_options:
env["LEANN_EMBEDDING_OPTIONS"] = encoded_options
# Let server output go directly to console
# The server will respect LEANN_LOG_LEVEL environment variable
self.server_process = subprocess.Popen(
command,
cwd=project_root,
stdout=stdout_target,
stderr=stderr_target,
env=env,
stdout=None, # Direct to console
stderr=None, # Direct to console
)
self.server_port = port
# Record config for in-process reuse (best effort; refined later when ready)
if config_signature is not None:
self._server_config = config_signature
else: # Fallback for unexpected code paths
try:
self._server_config = {
"model_name": command[command.index("--model-name") + 1]
if "--model-name" in command
else "",
"passages_file": command[command.index("--passages-file") + 1]
if "--passages-file" in command
else "",
"embedding_mode": command[command.index("--embedding-mode") + 1]
if "--embedding-mode" in command
else "sentence-transformers",
"provider_options": provider_options or {},
}
except Exception:
self._server_config = {
"model_name": "",
"passages_file": "",
"embedding_mode": "sentence-transformers",
"provider_options": provider_options or {},
}
logger.info(f"Server process started with PID: {self.server_process.pid}")
# Register atexit callback only when we actually start a process
if not self._atexit_registered:
# Always attempt best-effort finalize at interpreter exit
atexit.register(self._finalize_process)
# Use a lambda to avoid issues with bound methods
atexit.register(lambda: self.stop_server() if self.server_process else None)
self._atexit_registered = True
# Touch finalizer so it knows there is a live process
if getattr(self, "_finalizer", None) is not None and not self._finalizer.alive:
try:
import weakref
self._finalizer = weakref.finalize(self, self._finalize_process)
except Exception:
pass
def _wait_for_server_ready(self, port: int) -> tuple[bool, int]:
"""Wait for the server to be ready."""
@@ -443,35 +344,24 @@ class EmbeddingServerManager:
if not self.server_process:
return
if self.server_process and self.server_process.poll() is not None:
if self.server_process.poll() is not None:
# Process already terminated
self.server_process = None
self.server_port = None
self._server_config = None
return
logger.info(
f"Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
)
# Use simple termination first; if the server installed signal handlers,
# it will exit cleanly. Otherwise escalate to kill after a short wait.
try:
self.server_process.terminate()
except Exception:
pass
self.server_process.terminate()
try:
self.server_process.wait(timeout=5) # Give more time for graceful shutdown
logger.info(f"Server process {self.server_process.pid} terminated gracefully.")
self.server_process.wait(timeout=3)
logger.info(f"Server process {self.server_process.pid} terminated.")
except subprocess.TimeoutExpired:
logger.warning(
f"Server process {self.server_process.pid} did not terminate within 5 seconds, force killing..."
f"Server process {self.server_process.pid} did not terminate gracefully within 3 seconds, killing it."
)
try:
self.server_process.kill()
except Exception:
pass
self.server_process.kill()
try:
self.server_process.wait(timeout=2)
logger.info(f"Server process {self.server_process.pid} killed successfully.")
@@ -479,75 +369,34 @@ class EmbeddingServerManager:
logger.error(
f"Failed to kill server process {self.server_process.pid} - it may be hung"
)
# Don't hang indefinitely
# Clean up process resources with timeout to avoid CI hang
# Clean up process resources to prevent resource tracker warnings
try:
# Use shorter timeout in CI environments
is_ci = os.environ.get("CI") == "true"
timeout = 3 if is_ci else 10
self.server_process.wait(timeout=timeout)
logger.info(f"Server process {self.server_process.pid} cleanup completed")
except subprocess.TimeoutExpired:
logger.warning(f"Process cleanup timeout after {timeout}s, proceeding anyway")
except Exception as e:
logger.warning(f"Error during process cleanup: {e}")
finally:
self.server_process = None
self.server_port = None
self._server_config = None
def _finalize_process(self) -> None:
"""Best-effort cleanup used by weakref.finalize/atexit."""
try:
self.stop_server()
self.server_process.wait() # Ensure process is fully cleaned up
except Exception:
pass
def _adopt_existing_server(self, *args, **kwargs) -> None:
# Removed: cross-process adoption no longer supported
return
self.server_process = None
def _launch_server_process_colab(
self,
command: list,
port: int,
*,
provider_options: Optional[dict] = None,
config_signature: Optional[dict] = None,
) -> None:
def _launch_server_process_colab(self, command: list, port: int) -> None:
"""Launch the server process with Colab-specific settings."""
logger.info(f"Colab Command: {' '.join(command)}")
# In Colab, we need to be more careful about process management
env = os.environ.copy()
encoded_options = encode_provider_options(provider_options)
if encoded_options:
env["LEANN_EMBEDDING_OPTIONS"] = encoded_options
self.server_process = subprocess.Popen(
command,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
env=env,
)
self.server_port = port
logger.info(f"Colab server process started with PID: {self.server_process.pid}")
# Register atexit callback (unified)
# Register atexit callback
if not self._atexit_registered:
atexit.register(self._finalize_process)
atexit.register(lambda: self.stop_server() if self.server_process else None)
self._atexit_registered = True
# Record config for in-process reuse is best-effort in Colab mode
if config_signature is not None:
self._server_config = config_signature
else:
self._server_config = {
"model_name": "",
"passages_file": "",
"embedding_mode": "sentence-transformers",
"provider_options": provider_options or {},
}
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
"""Wait for the server to be ready with Colab-specific timeout."""

View File

@@ -1,189 +0,0 @@
"""
Interactive session utilities for LEANN applications.
Provides shared readline functionality and command handling across
CLI, API, and RAG example interactive modes.
"""
import atexit
import os
from pathlib import Path
from typing import Callable, Optional
# Try to import readline with fallback for Windows
try:
import readline
HAS_READLINE = True
except ImportError:
# Windows doesn't have readline by default
HAS_READLINE = False
readline = None
class InteractiveSession:
"""Manages interactive session with optional readline support and common commands."""
def __init__(
self,
history_name: str,
prompt: str = "You: ",
welcome_message: str = "",
):
"""
Initialize interactive session with optional readline support.
Args:
history_name: Name for history file (e.g., "cli", "api_chat")
(ignored if readline not available)
prompt: Input prompt to display
welcome_message: Message to show when starting session
Note:
On systems without readline (e.g., Windows), falls back to basic input()
with limited functionality (no history, no line editing).
"""
self.history_name = history_name
self.prompt = prompt
self.welcome_message = welcome_message
self._setup_complete = False
def setup_readline(self):
"""Setup readline with history support (if available)."""
if self._setup_complete:
return
if not HAS_READLINE:
# Readline not available (likely Windows), skip setup
self._setup_complete = True
return
# History file setup
history_dir = Path.home() / ".leann" / "history"
history_dir.mkdir(parents=True, exist_ok=True)
history_file = history_dir / f"{self.history_name}.history"
# Load history if exists
try:
readline.read_history_file(str(history_file))
readline.set_history_length(1000)
except FileNotFoundError:
pass
# Save history on exit
atexit.register(readline.write_history_file, str(history_file))
# Optional: Enable vi editing mode (commented out by default)
# readline.parse_and_bind("set editing-mode vi")
self._setup_complete = True
def _show_help(self):
"""Show available commands."""
print("Commands:")
print(" quit/exit/q - Exit the chat")
print(" help - Show this help message")
print(" clear - Clear screen")
print(" history - Show command history")
def _show_history(self):
"""Show command history."""
if not HAS_READLINE:
print(" History not available (readline not supported on this system)")
return
history_length = readline.get_current_history_length()
if history_length == 0:
print(" No history available")
return
for i in range(history_length):
item = readline.get_history_item(i + 1)
if item:
print(f" {i + 1}: {item}")
def get_user_input(self) -> Optional[str]:
"""
Get user input with readline support.
Returns:
User input string, or None if EOF (Ctrl+D)
"""
try:
return input(self.prompt).strip()
except KeyboardInterrupt:
print("\n(Use 'quit' to exit)")
return "" # Return empty string to continue
except EOFError:
print("\nGoodbye!")
return None
def run_interactive_loop(self, handler_func: Callable[[str], None]):
"""
Run the interactive loop with a custom handler function.
Args:
handler_func: Function to handle user input that's not a built-in command
Should accept a string and handle the user's query
"""
self.setup_readline()
if self.welcome_message:
print(self.welcome_message)
while True:
user_input = self.get_user_input()
if user_input is None: # EOF (Ctrl+D)
break
if not user_input: # Empty input or KeyboardInterrupt
continue
# Handle built-in commands
command = user_input.lower()
if command in ["quit", "exit", "q"]:
print("Goodbye!")
break
elif command == "help":
self._show_help()
elif command == "clear":
os.system("clear" if os.name != "nt" else "cls")
elif command == "history":
self._show_history()
else:
# Regular user input - pass to handler
try:
handler_func(user_input)
except Exception as e:
print(f"Error: {e}")
def create_cli_session(index_name: str) -> InteractiveSession:
"""Create an interactive session for CLI usage."""
return InteractiveSession(
history_name=index_name,
prompt="\nYou: ",
welcome_message="LEANN Assistant ready! Type 'quit' to exit, 'help' for commands\n"
+ "=" * 40,
)
def create_api_session() -> InteractiveSession:
"""Create an interactive session for API chat."""
return InteractiveSession(
history_name="api_chat",
prompt="You: ",
welcome_message="Leann Chat started (type 'quit' to exit, 'help' for commands)\n"
+ "=" * 40,
)
def create_rag_session(app_name: str, data_description: str) -> InteractiveSession:
"""Create an interactive session for RAG examples."""
return InteractiveSession(
history_name=f"{app_name}_rag",
prompt="You: ",
welcome_message=f"[Interactive Mode] Chat with your {data_description} data!\nType 'quit' or 'exit' to stop, 'help' for commands.\n"
+ "=" * 40,
)

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Literal, Optional
from typing import Any, Literal, Union
import numpy as np
@@ -35,7 +35,7 @@ class LeannBackendSearcherInterface(ABC):
@abstractmethod
def _ensure_server_running(
self, passages_source_file: str, port: Optional[int], **kwargs
self, passages_source_file: str, port: Union[int, None], **kwargs
) -> int:
"""Ensure server is running"""
pass
@@ -50,7 +50,7 @@ class LeannBackendSearcherInterface(ABC):
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: Optional[int] = None,
zmq_port: Union[int, None] = None,
**kwargs,
) -> dict[str, Any]:
"""Search for nearest neighbors
@@ -76,7 +76,7 @@ class LeannBackendSearcherInterface(ABC):
self,
query: str,
use_server_if_available: bool = True,
zmq_port: Optional[int] = None,
zmq_port: Union[int, None] = None,
) -> np.ndarray:
"""Compute embedding for a query string

View File

@@ -64,6 +64,19 @@ def handle_request(request):
"required": ["index_name", "query"],
},
},
{
"name": "leann_status",
"description": "📊 Check the health and stats of your code indexes - like a medical checkup for your codebase knowledge!",
"inputSchema": {
"type": "object",
"properties": {
"index_name": {
"type": "string",
"description": "Optional: Name of specific index to check. If not provided, shows status of all indexes.",
}
},
},
},
{
"name": "leann_list",
"description": "📋 Show all your indexed codebases - your personal code library! Use this to see what's available for search.",
@@ -94,7 +107,7 @@ def handle_request(request):
},
}
# Build simplified command with non-interactive flag for MCP compatibility
# Build simplified command
cmd = [
"leann",
"search",
@@ -102,10 +115,19 @@ def handle_request(request):
args["query"],
f"--top-k={args.get('top_k', 5)}",
f"--complexity={args.get('complexity', 32)}",
"--non-interactive",
]
result = subprocess.run(cmd, capture_output=True, text=True)
elif tool_name == "leann_status":
if args.get("index_name"):
# Check specific index status - for now, we'll use leann list and filter
result = subprocess.run(["leann", "list"], capture_output=True, text=True)
# We could enhance this to show more detailed status per index
else:
# Show all indexes status
result = subprocess.run(["leann", "list"], capture_output=True, text=True)
elif tool_name == "leann_list":
result = subprocess.run(["leann", "list"], capture_output=True, text=True)

View File

@@ -1,240 +0,0 @@
"""
Metadata filtering engine for LEANN search results.
This module provides generic metadata filtering capabilities that can be applied
to search results from any LEANN backend. The filtering supports various
operators for different data types including numbers, strings, booleans, and lists.
"""
import logging
from typing import Any, Union
logger = logging.getLogger(__name__)
# Type alias for filter specifications
FilterValue = Union[str, int, float, bool, list]
FilterSpec = dict[str, FilterValue]
MetadataFilters = dict[str, FilterSpec]
class MetadataFilterEngine:
"""
Engine for evaluating metadata filters against search results.
Supports various operators for filtering based on metadata fields:
- Comparison: ==, !=, <, <=, >, >=
- Membership: in, not_in
- String operations: contains, starts_with, ends_with
- Boolean operations: is_true, is_false
"""
def __init__(self):
"""Initialize the filter engine with supported operators."""
self.operators = {
"==": self._equals,
"!=": self._not_equals,
"<": self._less_than,
"<=": self._less_than_or_equal,
">": self._greater_than,
">=": self._greater_than_or_equal,
"in": self._in,
"not_in": self._not_in,
"contains": self._contains,
"starts_with": self._starts_with,
"ends_with": self._ends_with,
"is_true": self._is_true,
"is_false": self._is_false,
}
def apply_filters(
self, search_results: list[dict[str, Any]], metadata_filters: MetadataFilters
) -> list[dict[str, Any]]:
"""
Apply metadata filters to a list of search results.
Args:
search_results: List of result dictionaries, each containing 'metadata' field
metadata_filters: Dictionary of filter specifications
Format: {"field_name": {"operator": value}}
Returns:
Filtered list of search results
"""
if not metadata_filters:
return search_results
logger.debug(f"Applying filters: {metadata_filters}")
logger.debug(f"Input results count: {len(search_results)}")
filtered_results = []
for result in search_results:
if self._evaluate_filters(result, metadata_filters):
filtered_results.append(result)
logger.debug(f"Filtered results count: {len(filtered_results)}")
return filtered_results
def _evaluate_filters(self, result: dict[str, Any], filters: MetadataFilters) -> bool:
"""
Evaluate all filters against a single search result.
All filters must pass (AND logic) for the result to be included.
Args:
result: Full search result dictionary (including metadata, text, etc.)
filters: Filter specifications to evaluate
Returns:
True if all filters pass, False otherwise
"""
for field_name, filter_spec in filters.items():
if not self._evaluate_field_filter(result, field_name, filter_spec):
return False
return True
def _evaluate_field_filter(
self, result: dict[str, Any], field_name: str, filter_spec: FilterSpec
) -> bool:
"""
Evaluate a single field filter against a search result.
Args:
result: Full search result dictionary
field_name: Name of the field to filter on
filter_spec: Filter specification for this field
Returns:
True if the filter passes, False otherwise
"""
# First check top-level fields, then check metadata
field_value = result.get(field_name)
if field_value is None:
# Try to get from metadata if not found at top level
metadata = result.get("metadata", {})
field_value = metadata.get(field_name)
# Handle missing fields - they fail all filters except existence checks
if field_value is None:
logger.debug(f"Field '{field_name}' not found in result or metadata")
return False
# Evaluate each operator in the filter spec
for operator, expected_value in filter_spec.items():
if operator not in self.operators:
logger.warning(f"Unsupported operator: {operator}")
return False
try:
if not self.operators[operator](field_value, expected_value):
logger.debug(
f"Filter failed: {field_name} {operator} {expected_value} "
f"(actual: {field_value})"
)
return False
except Exception as e:
logger.warning(
f"Error evaluating filter {field_name} {operator} {expected_value}: {e}"
)
return False
return True
# Comparison operators
def _equals(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value equals expected value."""
return field_value == expected_value
def _not_equals(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value does not equal expected value."""
return field_value != expected_value
def _less_than(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is less than expected value."""
return self._numeric_compare(field_value, expected_value, lambda a, b: a < b)
def _less_than_or_equal(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is less than or equal to expected value."""
return self._numeric_compare(field_value, expected_value, lambda a, b: a <= b)
def _greater_than(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is greater than expected value."""
return self._numeric_compare(field_value, expected_value, lambda a, b: a > b)
def _greater_than_or_equal(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is greater than or equal to expected value."""
return self._numeric_compare(field_value, expected_value, lambda a, b: a >= b)
# Membership operators
def _in(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is in the expected list/collection."""
if not isinstance(expected_value, (list, tuple, set)):
raise ValueError("'in' operator requires a list, tuple, or set")
return field_value in expected_value
def _not_in(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is not in the expected list/collection."""
if not isinstance(expected_value, (list, tuple, set)):
raise ValueError("'not_in' operator requires a list, tuple, or set")
return field_value not in expected_value
# String operators
def _contains(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value contains the expected substring."""
field_str = str(field_value)
expected_str = str(expected_value)
return expected_str in field_str
def _starts_with(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value starts with the expected prefix."""
field_str = str(field_value)
expected_str = str(expected_value)
return field_str.startswith(expected_str)
def _ends_with(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value ends with the expected suffix."""
field_str = str(field_value)
expected_str = str(expected_value)
return field_str.endswith(expected_str)
# Boolean operators
def _is_true(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is truthy."""
return bool(field_value)
def _is_false(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is falsy."""
return not bool(field_value)
# Helper methods
def _numeric_compare(self, field_value: Any, expected_value: Any, compare_func) -> bool:
"""
Helper for numeric comparisons with type coercion.
Args:
field_value: Value from metadata
expected_value: Value to compare against
compare_func: Comparison function to apply
Returns:
Result of comparison
"""
try:
# Try to convert both values to numbers for comparison
if isinstance(field_value, str) and isinstance(expected_value, str):
# String comparison if both are strings
return compare_func(field_value, expected_value)
# Numeric comparison - attempt to convert to float
field_num = (
float(field_value) if not isinstance(field_value, (int, float)) else field_value
)
expected_num = (
float(expected_value)
if not isinstance(expected_value, (int, float))
else expected_value
)
return compare_func(field_num, expected_num)
except (ValueError, TypeError):
# Fall back to string comparison if numeric conversion fails
return compare_func(str(field_value), str(expected_value))

View File

@@ -2,17 +2,11 @@
import importlib
import importlib.metadata
import json
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from leann.interface import LeannBackendFactoryInterface
# Set up logger for this module
logger = logging.getLogger(__name__)
BACKEND_REGISTRY: dict[str, "LeannBackendFactoryInterface"] = {}
@@ -20,7 +14,7 @@ def register_backend(name: str):
"""A decorator to register a new backend class."""
def decorator(cls):
logger.debug(f"Registering backend '{name}'")
print(f"INFO: Registering backend '{name}'")
BACKEND_REGISTRY[name] = cls
return cls
@@ -45,54 +39,3 @@ def autodiscover_backends():
# print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
pass
# print("INFO: Backend auto-discovery finished.")
def register_project_directory(project_dir: Optional[Union[str, Path]] = None):
"""
Register a project directory in the global LEANN registry.
This allows `leann list` to discover indexes created by apps or other tools.
Args:
project_dir: Directory to register. If None, uses current working directory.
"""
if project_dir is None:
project_dir = Path.cwd()
else:
project_dir = Path(project_dir)
# Only register directories that have some kind of LEANN content
# Either .leann/indexes/ (CLI format) or *.leann.meta.json files (apps format)
has_cli_indexes = (project_dir / ".leann" / "indexes").exists()
has_app_indexes = any(project_dir.rglob("*.leann.meta.json"))
if not (has_cli_indexes or has_app_indexes):
# Don't register if there are no LEANN indexes
return
global_registry = Path.home() / ".leann" / "projects.json"
global_registry.parent.mkdir(exist_ok=True)
project_str = str(project_dir.resolve())
# Load existing registry
projects = []
if global_registry.exists():
try:
with open(global_registry) as f:
projects = json.load(f)
except Exception:
logger.debug("Could not load existing project registry")
projects = []
# Add project if not already present
if project_str not in projects:
projects.append(project_str)
# Save updated registry
try:
with open(global_registry, "w") as f:
json.dump(projects, f, indent=2)
logger.debug(f"Registered project directory: {project_str}")
except Exception as e:
logger.warning(f"Could not save project registry: {e}")

View File

@@ -41,7 +41,6 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
print("WARNING: embedding_model not found in meta.json. Recompute will fail.")
self.embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
self.embedding_options = self.meta.get("embedding_options", {})
self.embedding_server_manager = EmbeddingServerManager(
backend_module_name=backend_module_name,
@@ -78,7 +77,6 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
passages_file=passages_source_file,
distance_metric=distance_metric,
enable_warmup=kwargs.get("enable_warmup", False),
provider_options=self.embedding_options,
)
if not server_started:
raise RuntimeError(f"Failed to start embedding server on port {actual_port}")
@@ -127,12 +125,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
from .embedding_compute import compute_embeddings
embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
return compute_embeddings(
[query],
self.embedding_model,
embedding_mode,
provider_options=self.embedding_options,
)
return compute_embeddings([query], self.embedding_model, embedding_mode)
def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray:
"""Compute embeddings using the ZMQ embedding server."""

Some files were not shown because too many files have changed in this diff Show More