Compare commits

..

55 Commits

Author SHA1 Message Date
Andy Lee
4e9e2f3da0 CI: add build venv scripts path for wheel repair 2025-09-24 01:39:35 -07:00
Andy Lee
ed167f43b0 CI: use temporary uv venv for build deps 2025-09-24 01:23:16 -07:00
Andy Lee
f9746d3fe2 CI: install build deps with uv python interpreter 2025-09-24 01:19:19 -07:00
Andy Lee
a090a3444a CI: rely on setup-uv for Python and tighten group install 2025-09-24 01:14:54 -07:00
Andy Lee
aaaba27a4f CI: use uv group install with local wheel selection 2025-09-24 01:10:16 -07:00
Andy Lee
f40f539456 CI: revert install step to match main 2025-09-24 00:50:27 -07:00
Andy Lee
576a2dcb49 CI: use matrix python venv and set macOS deployment target 2025-09-24 00:48:27 -07:00
Andy Lee
ad8ab84675 CI: handle python tag mismatches for local wheels 2025-09-23 23:24:02 -07:00
Andy Lee
58b96b64d8 CI: pick wheels matching current Python tag 2025-09-23 23:05:32 -07:00
Andy Lee
a76c3cdac4 CI: install local wheels via file paths 2025-09-23 22:53:44 -07:00
Andy Lee
520619deab CI: force local wheels in uv install step 2025-09-23 22:27:31 -07:00
Andy Lee
dea08c95b4 Merge remote-tracking branch 'origin/main' into financebench 2025-09-23 21:52:14 -07:00
Andy Lee
ec889f7ef4 Allow 'leann ask' to accept a positional question (#116) 2025-09-23 21:18:57 -07:00
Yi-Ting Chiu
322e5c162d docs: open ai api compatibility (#118) 2025-09-23 21:17:50 -07:00
Yichuan Wang
edde0cdeb2 [Feat] ColQwen intergration (#111)
* add colqwen stuff

* add colqwen stuff and pass ruff

* remove ipynb
2025-09-23 17:51:29 -07:00
Andy Lee
db7ba27ff6 feat: Add support for configurable local LLM endpoints (#115)
* feat: support configurable local llm endpoints

* docs
2025-09-23 15:12:13 -07:00
Andy Lee
5f7806e16f Introducing dynamic index update (#108)
* feat: Add GitHub PR and issue templates for better contributor experience

* simplify: Make templates more concise and user-friendly

* fix: enable is_compact=False, is_recompute=True

* feat: update when recompute

* test

* fix: real recompute

* refactor

* fix: compare with no-recompute

* fix: test
2025-09-21 22:56:27 -07:00
yichuan-w
d034e2195b fix build from source in diskann 2025-09-20 19:52:29 +00:00
yichuan520030910320
43894ff605 update submodule 2025-09-19 17:03:55 -07:00
yichuan520030910320
10311cc611 change the submodule for easy pull 2025-09-19 17:02:09 -07:00
Andy Lee
ad0d2faabc feat: Add GitHub PR and issue templates (#105)
* feat: Add GitHub PR and issue templates for better contributor experience

* simplify: Make templates more concise and user-friendly
2025-09-19 13:51:36 -07:00
Andy Lee
e93c0dec6f [Fix] Enable AST chunking when installed (package chunking utils) (#101)
* fix(core): package chunking utils for AST chunking; re-export in apps; CLI imports packaged utils

* style

* chore: fix ruff warnings (RUF059, F401)

* style
2025-09-17 18:44:00 -07:00
GitHub Actions
c5a29f849a chore: release v0.3.4 2025-09-16 20:45:22 +00:00
Andy Lee
3357d5765e fix: find links to install wheels available 2025-09-15 22:22:38 -07:00
Andy Lee
9dbd0c64cc fix(ci): run with lint only 2025-09-15 21:55:19 -07:00
Andy Lee
9c400acd7e fix(ci): should checkout modules as well since uv sync checks 2025-09-15 21:40:35 -07:00
Andy Lee
ac560964f5 chore: use http url of astchunk; use group for some dev deps 2025-09-15 21:21:09 -07:00
Andy Lee
07e4f176e1 fix(ci): only run pre-commit 2025-09-15 19:57:56 -07:00
Andy Lee
b1daf021e0 Merge remote-tracking branch 'origin/main' into financebench 2025-09-15 19:52:37 -07:00
Andy Lee
3578680cb6 fix: as package 2025-09-15 19:50:45 -07:00
Andy Lee
a0d6857faa docs: data updated 2025-09-15 19:50:02 -07:00
Yichuan Wang
3b8dc6368e Ast fork (#92) 2025-09-08 18:43:31 -07:00
Aiden Huang
e309f292de docs(mcp): add root llms.txt for MCP discovery; update MCP README to reference it; refs #76 (#91) 2025-09-07 14:39:58 -07:00
AWS Mcleod
0d9f92ea0f Add grep search functionality - Issue #86 (#87)
* Add grep search functionality to LeannSearcher

- Add use_grep parameter to search method
- Implement grep-based search on .jsonl files
- Add fallback Python regex search
- Support same SearchResult format as semantic search

Addresses issue #86

* fix: resolve linting errors

* docs: add grep search example

* docs: add grep search to README examples

* refactor: remove regex fallback, move grep example to features section

* docs: add grep search to Advanced Features with comprehensive guide
2025-09-05 13:48:07 -07:00
GitHub Actions
b0b353d279 chore: release v0.3.3 2025-09-02 21:29:56 +00:00
Andy Lee
4dffdfedbe feat: Add ARM64 Linux wheel support for leann-backend-hnsw (#83)
* feat: Add ARM64 Linux wheel support for leann-backend-hnsw

* fix: Use OpenBLAS for ARM64 Linux builds instead of Intel MKL

* fix: Configure Faiss with SVE optimization for ARM64 builds

- Set FAISS_OPT_LEVEL to "sve" for ARM64 architecture
- Disable x86-specific SIMD instructions (AVX2, AVX512, SSE4.1)
- Use ARM64-native SVE optimization as per Faiss conda build scripts
- Add architecture detection and proper configuration messages

Fixes compilation error: "xmmintrin.h: No such file or directory"
on ubuntu-24.04-arm runners.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: Apply ARM64 compatibility fix directly to Faiss submodule

- Modify faiss/impl/pq.cpp to use x86-specific preprocessor conditions
- Remove patch file approach in favor of direct submodule modification
- Update CMakeLists.txt to reflect the submodule changes
- Fixes ARM64 Linux compilation by preventing x86 SIMD header inclusion

This resolves the "xmmintrin.h: No such file or directory" error
when building ARM64 Linux wheels for Docker compatibility.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* chore: Update Faiss submodule to include ARM64 compatibility fix

- Points to commit ed96ff7d with x86-specific preprocessor conditions
- Enables successful ARM64 Linux wheel builds

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* retrigger ci

* fix: Use different optimization levels for ARM64 based on platform

- Use SVE optimization only for ARM64 Linux
- Use generic optimization for ARM64 macOS to avoid clang SVE issues
- Fixes macOS ARM64 compilation errors with SVE instructions

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* feat: Update DiskANN submodule with OpenBLAS fallback support

- Points to commit 5c396c4 with ARM64 Linux OpenBLAS support
- Enables DiskANN to build on ARM64 Linux using standard BLAS libraries
- Resolves Intel MKL dependency issues for Docker ARM64 deployments

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: Update DiskANN submodule with ZeroMQ polling configuration

- Points to commit 3a1016e with explicit polling method setup
- Resolves ZeroMQ autodetection issues on ARM64 Linux
- Ensures stable cross-platform ZeroMQ builds

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* retrigger ci

* fix: Update DiskANN submodule with ARM64 compiler flags fix

- Points to commit a0dc600 with architecture-specific compiler flags
- Removes x86 SIMD flags on ARM64 Linux to fix compilation errors
- Enables successful ARM64 Linux wheel builds

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: Update DiskANN submodule with ARM64 compiler flags fix

- Points to commit 0921664 with architecture-specific compiler flags
- Removes x86 SIMD flags on ARM64 Linux to fix compilation errors
- Enables successful ARM64 Linux wheel builds

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* retrigger ci

* fix: Update DiskANN submodule with cross-platform prefetch support

- Points to commit 39192d6 with unified prefetch macros
- Replaces all Intel-specific _mm_prefetch calls with cross-platform macros
- Enables ARM64 Linux compatibility while maintaining x86 performance
- Resolves all remaining compilation errors for ARM64 builds

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: Update DiskANN submodule with corrected ARM64 compatibility fixes

- Points to commit 3cb87a8 with proper x86 platform detection
- Includes ARM64 fallback for AVXDistanceInnerProductFloat function
- Resolves all remaining '__m256 was not declared' compilation errors
- Enables successful ARM64 Linux wheel builds for Docker compatibility

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: Update DiskANN submodule with template type handling fix

- Points to commit d396bc3 with corrected template type handling
- Fixes DistanceInnerProduct template instantiation for int8_t/uint8_t types
- Resolves 'cannot convert const signed char* to const float*' error
- Completes ARM64 Linux compilation compatibility

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: Update DiskANN submodule with DistanceFastL2::norm template fix

- Points to commit 69d9a99 with corrected template type handling
- Fixes DistanceFastL2::norm template instantiation for int8_t/uint8_t types
- Resolves another 'cannot convert const signed char* to const float*' error
- Continues ARM64 Linux compilation compatibility improvements

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: Update DiskANN submodule with LAPACKE header detection

- Points to commit 64a9e01 with LAPACKE header path configuration
- Adds pkg-config based detection for LAPACKE include directories
- Resolves 'lapacke.h: No such file or directory' compilation error
- Completes OpenBLAS integration for ARM64 Linux builds

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: Update DiskANN submodule with enhanced LAPACKE header detection

- Points to commit 18d0721 with fallback LAPACKE header search paths
- Checks multiple standard locations for lapacke.h on various systems
- Improves ARM64 Linux compatibility for OpenBLAS builds
- Should resolve 'lapacke.h: No such file or directory' errors

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: Add liblapacke-dev package for ARM64 Linux builds

- Add liblapacke-dev to ARM64 dependencies alongside libopenblas-dev
- Provides lapacke.h header file needed for LAPACK C interface
- Fixes 'lapacke.h: No such file or directory' compilation error
- Enables complete OpenBLAS + LAPACKE support for ARM64 wheel builds

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: Update DiskANN submodule with cosine_similarity.h x86 intrinsics fix

- Points to commit dbb17eb with corrected conditional compilation
- Fixes immintrin.h inclusion for ARM64 compatibility in cosine_similarity.h
- Resolves 'immintrin.h: No such file or directory' error
- Continues systematic ARM64 Linux compilation fixes

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: Update DiskANN submodule with LAPACKE library linking fix

- Points to commit 19f9603 with explicit LAPACKE library discovery and linking
- Resolves 'undefined symbol: LAPACKE_sgesdd' runtime error on ARM64 Linux
- Completes ARM64 Linux wheel build compatibility for Docker deployments

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-09-02 14:27:06 -07:00
Yichuan Wang
d41e467df9 [CLI] More robust leann list and leann build (#84)
* chore(submodule): bump faiss to latest storage-efficient build

* [chore] add slack to share use case

* [cli] better gitignore / better leann list

* [cli] fix # 81
2025-09-01 18:36:27 -07:00
yichuan520030910320
4ca0489cb1 [chore] add slack to share use case 2025-09-01 13:31:16 -07:00
yichuan520030910320
e83a671918 chore(submodule): bump faiss to latest storage-efficient build 2025-09-01 13:31:12 -07:00
Andy Lee
d7011bbea0 docs: data 2025-08-25 16:25:59 -07:00
Andy Lee
ef4c69d128 chore(ci): remove paru-bin submodule and config to fix checkout --recurse-submodules 2025-08-25 16:08:16 -07:00
Andy Lee
75c8aeee5f style: format 2025-08-25 15:48:04 -07:00
Andy Lee
3d79741f9c experiments for running DiskANN & BM25 on Arch 4090 2025-08-25 15:46:48 -07:00
Andy Lee
df34c84bd3 feat: enron email bench 2025-08-24 23:06:57 -07:00
Andy Lee
8dfd2f015c fix: resolve ruff linting errors
- Remove unused variables in benchmark scripts
- Rename unused loop variables to follow convention
2025-08-22 13:53:30 -07:00
Andy Lee
ed72232bab style: format 2025-08-22 13:51:10 -07:00
Andy Lee
26d961bfc5 style: format 2025-08-22 13:44:26 -07:00
Andy Lee
722bda4ebb Merge remote-tracking branch 'origin/main' into financebench 2025-08-22 13:39:08 -07:00
Andy Lee
a7c7e8801d feat: laion, also required idmaps support 2025-08-22 13:32:33 -07:00
Andy Lee
069bce558b feat: fix financebench 2025-08-22 13:32:23 -07:00
Andy Lee
772894012e Merge branch 'main' into financebench 2025-08-20 20:40:27 -07:00
Andy Lee
5c163737c4 Merge remote-tracking branch 'origin/main' into financebench 2025-08-17 11:58:34 -07:00
Andy Lee
6d1d67ead7 chore: ignroe data README 2025-08-17 11:58:32 -07:00
Andy Lee
ed27ea6990 docs: results 2025-08-16 16:48:01 -07:00
Andy Lee
baf2d76e0e feat: finance bench 2025-08-16 16:22:50 -07:00
67 changed files with 12247 additions and 4177 deletions

50
.github/ISSUE_TEMPLATE/bug_report.yml vendored Normal file
View File

@@ -0,0 +1,50 @@
name: Bug Report
description: Report a bug in LEANN
labels: ["bug"]
body:
- type: textarea
id: description
attributes:
label: What happened?
description: A clear description of the bug
validations:
required: true
- type: textarea
id: reproduce
attributes:
label: How to reproduce
placeholder: |
1. Install with...
2. Run command...
3. See error
validations:
required: true
- type: textarea
id: error
attributes:
label: Error message
description: Paste any error messages
render: shell
- type: input
id: version
attributes:
label: LEANN Version
placeholder: "0.1.0"
validations:
required: true
- type: dropdown
id: os
attributes:
label: Operating System
options:
- macOS
- Linux
- Windows
- Docker
validations:
required: true

8
.github/ISSUE_TEMPLATE/config.yml vendored Normal file
View File

@@ -0,0 +1,8 @@
blank_issues_enabled: true
contact_links:
- name: Documentation
url: https://github.com/LEANN-RAG/LEANN-RAG/tree/main/docs
about: Read the docs first
- name: Discussions
url: https://github.com/LEANN-RAG/LEANN-RAG/discussions
about: Ask questions and share ideas

View File

@@ -0,0 +1,27 @@
name: Feature Request
description: Suggest a new feature for LEANN
labels: ["enhancement"]
body:
- type: textarea
id: problem
attributes:
label: What problem does this solve?
description: Describe the problem or need
validations:
required: true
- type: textarea
id: solution
attributes:
label: Proposed solution
description: How would you like this to work?
validations:
required: true
- type: textarea
id: example
attributes:
label: Example usage
description: Show how the API might look
render: python

13
.github/pull_request_template.md vendored Normal file
View File

@@ -0,0 +1,13 @@
## What does this PR do?
<!-- Brief description of your changes -->
## Related Issues
Fixes #
## Checklist
- [ ] Tests pass (`uv run pytest`)
- [ ] Code formatted (`ruff format` and `ruff check`)
- [ ] Pre-commit hooks pass (`pre-commit run --all-files`)

View File

@@ -17,26 +17,17 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
with: with:
ref: ${{ inputs.ref }} ref: ${{ inputs.ref }}
submodules: recursive
- name: Setup Python - name: Install uv and Python
uses: actions/setup-python@v5 uses: astral-sh/setup-uv@v6
with: with:
python-version: '3.11' python-version: '3.11'
- name: Install uv - name: Run pre-commit with only lint group (no project deps)
uses: astral-sh/setup-uv@v4
- name: Install ruff
run: | run: |
uv tool install ruff uv run --only-group lint pre-commit run --all-files --show-diff-on-failure
- name: Run ruff check
run: |
ruff check .
- name: Run ruff format check
run: |
ruff format --check .
build: build:
needs: lint needs: lint
@@ -54,6 +45,17 @@ jobs:
python: '3.12' python: '3.12'
- os: ubuntu-22.04 - os: ubuntu-22.04
python: '3.13' python: '3.13'
# ARM64 Linux builds
- os: ubuntu-24.04-arm
python: '3.9'
- os: ubuntu-24.04-arm
python: '3.10'
- os: ubuntu-24.04-arm
python: '3.11'
- os: ubuntu-24.04-arm
python: '3.12'
- os: ubuntu-24.04-arm
python: '3.13'
- os: macos-14 - os: macos-14
python: '3.9' python: '3.9'
- os: macos-14 - os: macos-14
@@ -92,14 +94,11 @@ jobs:
ref: ${{ inputs.ref }} ref: ${{ inputs.ref }}
submodules: recursive submodules: recursive
- name: Setup Python - name: Install uv and Python
uses: actions/setup-python@v5 uses: astral-sh/setup-uv@v6
with: with:
python-version: ${{ matrix.python }} python-version: ${{ matrix.python }}
- name: Install uv
uses: astral-sh/setup-uv@v6
- name: Install system dependencies (Ubuntu) - name: Install system dependencies (Ubuntu)
if: runner.os == 'Linux' if: runner.os == 'Linux'
run: | run: |
@@ -108,13 +107,46 @@ jobs:
pkg-config libabsl-dev libaio-dev libprotobuf-dev \ pkg-config libabsl-dev libaio-dev libprotobuf-dev \
patchelf patchelf
# Install Intel MKL for DiskANN # Debug: Show system information
echo "🔍 System Information:"
echo "Architecture: $(uname -m)"
echo "OS: $(uname -a)"
echo "CPU info: $(lscpu | head -5)"
# Install math library based on architecture
ARCH=$(uname -m)
echo "🔍 Setting up math library for architecture: $ARCH"
if [[ "$ARCH" == "x86_64" ]]; then
# Install Intel MKL for DiskANN on x86_64
echo "📦 Installing Intel MKL for x86_64..."
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
source /opt/intel/oneapi/setvars.sh source /opt/intel/oneapi/setvars.sh
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/compiler/latest/linux/compiler/lib/intel64_lin" >> $GITHUB_ENV echo "LD_LIBRARY_PATH=/opt/intel/oneapi/compiler/latest/linux/compiler/lib/intel64_lin" >> $GITHUB_ENV
echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/mkl/latest/lib/intel64" >> $GITHUB_ENV echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/mkl/latest/lib/intel64" >> $GITHUB_ENV
echo "✅ Intel MKL installed for x86_64"
# Debug: Check MKL installation
echo "🔍 MKL Installation Check:"
ls -la /opt/intel/oneapi/mkl/latest/ || echo "MKL directory not found"
ls -la /opt/intel/oneapi/mkl/latest/lib/ || echo "MKL lib directory not found"
elif [[ "$ARCH" == "aarch64" ]]; then
# Use OpenBLAS for ARM64 (MKL installer not compatible with ARM64)
echo "📦 Installing OpenBLAS for ARM64..."
sudo apt-get install -y libopenblas-dev liblapack-dev liblapacke-dev
echo "✅ OpenBLAS installed for ARM64"
# Debug: Check OpenBLAS installation
echo "🔍 OpenBLAS Installation Check:"
dpkg -l | grep openblas || echo "OpenBLAS package not found"
ls -la /usr/lib/aarch64-linux-gnu/openblas/ || echo "OpenBLAS directory not found"
fi
# Debug: Show final library paths
echo "🔍 Final LD_LIBRARY_PATH: $LD_LIBRARY_PATH"
- name: Install system dependencies (macOS) - name: Install system dependencies (macOS)
if: runner.os == 'macOS' if: runner.os == 'macOS'
@@ -124,11 +156,24 @@ jobs:
- name: Install build dependencies - name: Install build dependencies
run: | run: |
uv pip install --system scikit-build-core numpy swig Cython pybind11 uv python install ${{ matrix.python }}
if [[ "$RUNNER_OS" == "Linux" ]]; then uv venv --python ${{ matrix.python }} .uv-build
uv pip install --system auditwheel if [[ "$RUNNER_OS" == "Windows" ]]; then
BUILD_PY=".uv-build\\Scripts\\python.exe"
else else
uv pip install --system delocate BUILD_PY=".uv-build/bin/python"
fi
uv pip install --python "$BUILD_PY" scikit-build-core numpy swig Cython pybind11
if [[ "$RUNNER_OS" == "Linux" ]]; then
uv pip install --python "$BUILD_PY" auditwheel
else
uv pip install --python "$BUILD_PY" delocate
fi
if [[ "$RUNNER_OS" == "Windows" ]]; then
echo "$(pwd)\\.uv-build\\Scripts" >> $GITHUB_PATH
else
echo "$(pwd)/.uv-build/bin" >> $GITHUB_PATH
fi fi
- name: Set macOS environment variables - name: Set macOS environment variables
@@ -264,18 +309,66 @@ jobs:
- name: Install built packages for testing - name: Install built packages for testing
run: | run: |
# Create a virtual environment with the correct Python version # Create uv-managed virtual environment with the requested interpreter
uv python install ${{ matrix.python }}
uv venv --python ${{ matrix.python }} uv venv --python ${{ matrix.python }}
source .venv/bin/activate || source .venv/Scripts/activate source .venv/bin/activate || source .venv/Scripts/activate
# Install packages using --find-links to prioritize local builds if [[ "$RUNNER_OS" == "Windows" ]]; then
uv pip install --find-links packages/leann-core/dist --find-links packages/leann-backend-hnsw/dist --find-links packages/leann-backend-diskann/dist packages/leann-core/dist/*.whl || uv pip install --find-links packages/leann-core/dist packages/leann-core/dist/*.tar.gz UV_PY=".venv\\Scripts\\python.exe"
uv pip install --find-links packages/leann-core/dist packages/leann-backend-hnsw/dist/*.whl else
uv pip install --find-links packages/leann-core/dist packages/leann-backend-diskann/dist/*.whl UV_PY=".venv/bin/python"
uv pip install packages/leann/dist/*.whl || uv pip install packages/leann/dist/*.tar.gz fi
# Install test dependencies using extras # Install test dependency group only (avoids reinstalling project package)
uv pip install -e ".[test]" uv pip install --python "$UV_PY" --group test
# Install core wheel built in this job
CORE_WHL=$(find packages/leann-core/dist -maxdepth 1 -name "*.whl" -print -quit)
if [[ -n "$CORE_WHL" ]]; then
uv pip install --python "$UV_PY" "$CORE_WHL"
else
uv pip install --python "$UV_PY" packages/leann-core/dist/*.tar.gz
fi
PY_TAG=$($UV_PY -c "import sys; print(f'cp{sys.version_info[0]}{sys.version_info[1]}')")
if [[ "$RUNNER_OS" == "macOS" ]]; then
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
export MACOSX_DEPLOYMENT_TARGET=13.3
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
export MACOSX_DEPLOYMENT_TARGET=14.0
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
export MACOSX_DEPLOYMENT_TARGET=15.0
fi
fi
HNSW_WHL=$(find packages/leann-backend-hnsw/dist -maxdepth 1 -name "*-${PY_TAG}-*.whl" -print -quit)
if [[ -z "$HNSW_WHL" ]]; then
HNSW_WHL=$(find packages/leann-backend-hnsw/dist -maxdepth 1 -name "*-py3-*.whl" -print -quit)
fi
if [[ -n "$HNSW_WHL" ]]; then
uv pip install --python "$UV_PY" "$HNSW_WHL"
else
uv pip install --python "$UV_PY" ./packages/leann-backend-hnsw
fi
DISKANN_WHL=$(find packages/leann-backend-diskann/dist -maxdepth 1 -name "*-${PY_TAG}-*.whl" -print -quit)
if [[ -z "$DISKANN_WHL" ]]; then
DISKANN_WHL=$(find packages/leann-backend-diskann/dist -maxdepth 1 -name "*-py3-*.whl" -print -quit)
fi
if [[ -n "$DISKANN_WHL" ]]; then
uv pip install --python "$UV_PY" "$DISKANN_WHL"
else
uv pip install --python "$UV_PY" ./packages/leann-backend-diskann
fi
LEANN_WHL=$(find packages/leann/dist -maxdepth 1 -name "*.whl" -print -quit)
if [[ -n "$LEANN_WHL" ]]; then
uv pip install --python "$UV_PY" "$LEANN_WHL"
else
uv pip install --python "$UV_PY" packages/leann/dist/*.tar.gz
fi
- name: Run tests with pytest - name: Run tests with pytest
env: env:

11
.gitignore vendored
View File

@@ -18,10 +18,12 @@ demo/experiment_results/**/*.json
*.eml *.eml
*.emlx *.emlx
*.json *.json
*.png
!.vscode/*.json !.vscode/*.json
*.sh *.sh
*.txt *.txt
!CMakeLists.txt !CMakeLists.txt
!llms.txt
latency_breakdown*.json latency_breakdown*.json
experiment_results/eval_results/diskann/*.json experiment_results/eval_results/diskann/*.json
aws/ aws/
@@ -93,10 +95,7 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/
batchtest.py batchtest.py
tests/__pytest_cache__/ tests/__pytest_cache__/
tests/__pycache__/ tests/__pycache__/
paru-bin/
CLAUDE.md
CLAUDE.local.md
.claude/*.local.*
.claude/local/*
benchmarks/data/ benchmarks/data/
## multi vector
apps/multimodal/vision-based-pdf-multi-vector/multi-vector-colpali-native-weaviate.py

3
.gitmodules vendored
View File

@@ -14,3 +14,6 @@
[submodule "packages/leann-backend-hnsw/third_party/libzmq"] [submodule "packages/leann-backend-hnsw/third_party/libzmq"]
path = packages/leann-backend-hnsw/third_party/libzmq path = packages/leann-backend-hnsw/third_party/libzmq
url = https://github.com/zeromq/libzmq.git url = https://github.com/zeromq/libzmq.git
[submodule "packages/astchunk-leann"]
path = packages/astchunk-leann
url = https://github.com/yichuan-w/astchunk-leann.git

View File

@@ -182,7 +182,10 @@ LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`,
### Generation Model Setup ### Generation Model Setup
LEANN supports multiple LLM providers for text generation (OpenAI API, HuggingFace, Ollama). #### LLM Backend
LEANN supports many LLM providers for text generation (HuggingFace, Ollama, and Any OpenAI compatible API).
<details> <details>
<summary><strong>🔑 OpenAI API Setup (Default)</strong></summary> <summary><strong>🔑 OpenAI API Setup (Default)</strong></summary>
@@ -193,6 +196,68 @@ Set your OpenAI API key as an environment variable:
export OPENAI_API_KEY="your-api-key-here" export OPENAI_API_KEY="your-api-key-here"
``` ```
Make sure to use `--llm openai` flag when using the CLI.
You can also specify the model name with `--llm-model <model-name>` flag.
</details>
<details>
<summary><strong>🛠️ Supported LLM & Embedding Providers (via OpenAI Compatibility)</strong></summary>
Thanks to the widespread adoption of the OpenAI API format, LEANN is compatible out-of-the-box with a vast array of LLM and embedding providers. Simply set the `OPENAI_BASE_URL` and `OPENAI_API_KEY` environment variables to connect to your preferred service.
```sh
export OPENAI_API_KEY="xxx"
export OPENAI_BASE_URL="http://localhost:1234/v1" # base url of the provider
```
To use OpenAI compatible endpoint with the CLI interface:
If you are using it for text generation, make sure to use `--llm openai` flag and specify the model name with `--llm-model <model-name>` flag.
If you are using it for embedding, set the `--embedding-mode openai` flag and specify the model name with `--embedding-model <MODEL>`.
-----
Below is a list of base URLs for common providers to get you started.
### 🖥️ Local Inference Engines (Recommended for full privacy)
| Provider | Sample Base URL |
| ---------------- | --------------------------- |
| **Ollama** | `http://localhost:11434/v1` |
| **LM Studio** | `http://localhost:1234/v1` |
| **vLLM** | `http://localhost:8000/v1` |
| **llama.cpp** | `http://localhost:8080/v1` |
| **SGLang** | `http://localhost:30000/v1` |
| **LiteLLM** | `http://localhost:4000` |
-----
### ☁️ Cloud Providers
> **🚨 A Note on Privacy:** Before choosing a cloud provider, carefully review their privacy and data retention policies. Depending on their terms, your data may be used for their own purposes, including but not limited to human reviews and model training, which can lead to serious consequences if not handled properly.
| Provider | Base URL |
| ---------------- | ---------------------------------------------------------- |
| **OpenAI** | `https://api.openai.com/v1` |
| **OpenRouter** | `https://openrouter.ai/api/v1` |
| **Gemini** | `https://generativelanguage.googleapis.com/v1beta/openai/` |
| **x.AI (Grok)** | `https://api.x.ai/v1` |
| **Groq AI** | `https://api.groq.com/openai/v1` |
| **DeepSeek** | `https://api.deepseek.com/v1` |
| **SiliconFlow** | `https://api.siliconflow.cn/v1` |
| **Zhipu (BigModel)** | `https://open.bigmodel.cn/api/paas/v4/` |
| **Mistral AI** | `https://api.mistral.ai/v1` |
If your provider isn't on this list, don't worry! Check their documentation for an OpenAI-compatible endpoint—chances are, it's OpenAI Compatible too!
</details> </details>
<details> <details>
@@ -546,6 +611,9 @@ leann search my-docs "machine learning concepts"
# Interactive chat with your documents # Interactive chat with your documents
leann ask my-docs --interactive leann ask my-docs --interactive
# Ask a single question (non-interactive)
leann ask my-docs "Where are prompts configured?"
# List all your indexes # List all your indexes
leann list leann list
@@ -656,6 +724,19 @@ results = searcher.search(
📖 **[Complete Metadata filtering guide →](docs/metadata_filtering.md)** 📖 **[Complete Metadata filtering guide →](docs/metadata_filtering.md)**
### 🔍 Grep Search
For exact text matching instead of semantic search, use the `use_grep` parameter:
```python
# Exact text search
results = searcher.search("bananacrocodile", use_grep=True, top_k=1)
```
**Use cases**: Finding specific code patterns, error messages, function names, or exact phrases where semantic similarity isn't needed.
📖 **[Complete grep search guide →](docs/grep_search.md)**
## 🏗️ Architecture & How It Works ## 🏗️ Architecture & How It Works
<p align="center"> <p align="center">
@@ -693,9 +774,8 @@ results = searcher.search(
## Reproduce Our Results ## Reproduce Our Results
```bash ```bash
uv pip install -e ".[dev]" # Install dev dependencies uv run benchmarks/run_evaluation.py # Will auto-download evaluation data and run benchmarks
python benchmarks/run_evaluation.py # Will auto-download evaluation data and run benchmarks uv run benchmarks/run_evaluation.py benchmarks/data/indices/rpj_wiki/rpj_wiki --num-queries 2000 # After downloading data, you can run the benchmark with our biggest index
python benchmarks/run_evaluation.py benchmarks/data/indices/rpj_wiki/rpj_wiki --num-queries 2000 # After downloading data, you can run the benchmark with our biggest index
``` ```
The evaluation script downloads data automatically on first run. The last three results were tested with partial personal data, and you can reproduce them with your own data! The evaluation script downloads data automatically on first run. The last three results were tested with partial personal data, and you can reproduce them with your own data!

View File

@@ -11,6 +11,7 @@ from typing import Any
import dotenv import dotenv
from leann.api import LeannBuilder, LeannChat from leann.api import LeannBuilder, LeannChat
from leann.registry import register_project_directory from leann.registry import register_project_directory
from leann.settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
dotenv.load_dotenv() dotenv.load_dotenv()
@@ -78,6 +79,24 @@ class BaseRAGExample(ABC):
choices=["sentence-transformers", "openai", "mlx", "ollama"], choices=["sentence-transformers", "openai", "mlx", "ollama"],
help="Embedding backend mode (default: sentence-transformers), we provide sentence-transformers, openai, mlx, or ollama", help="Embedding backend mode (default: sentence-transformers), we provide sentence-transformers, openai, mlx, or ollama",
) )
embedding_group.add_argument(
"--embedding-host",
type=str,
default=None,
help="Override Ollama-compatible embedding host",
)
embedding_group.add_argument(
"--embedding-api-base",
type=str,
default=None,
help="Base URL for OpenAI-compatible embedding services",
)
embedding_group.add_argument(
"--embedding-api-key",
type=str,
default=None,
help="API key for embedding service (defaults to OPENAI_API_KEY)",
)
# LLM parameters # LLM parameters
llm_group = parser.add_argument_group("LLM Parameters") llm_group = parser.add_argument_group("LLM Parameters")
@@ -97,8 +116,8 @@ class BaseRAGExample(ABC):
llm_group.add_argument( llm_group.add_argument(
"--llm-host", "--llm-host",
type=str, type=str,
default="http://localhost:11434", default=None,
help="Host for Ollama API (default: http://localhost:11434)", help="Host for Ollama-compatible APIs (defaults to LEANN_OLLAMA_HOST/OLLAMA_HOST)",
) )
llm_group.add_argument( llm_group.add_argument(
"--thinking-budget", "--thinking-budget",
@@ -107,6 +126,18 @@ class BaseRAGExample(ABC):
default=None, default=None,
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.", help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
) )
llm_group.add_argument(
"--llm-api-base",
type=str,
default=None,
help="Base URL for OpenAI-compatible APIs",
)
llm_group.add_argument(
"--llm-api-key",
type=str,
default=None,
help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)",
)
# AST Chunking parameters # AST Chunking parameters
ast_group = parser.add_argument_group("AST Chunking Parameters") ast_group = parser.add_argument_group("AST Chunking Parameters")
@@ -205,9 +236,13 @@ class BaseRAGExample(ABC):
if args.llm == "openai": if args.llm == "openai":
config["model"] = args.llm_model or "gpt-4o" config["model"] = args.llm_model or "gpt-4o"
config["base_url"] = resolve_openai_base_url(args.llm_api_base)
resolved_key = resolve_openai_api_key(args.llm_api_key)
if resolved_key:
config["api_key"] = resolved_key
elif args.llm == "ollama": elif args.llm == "ollama":
config["model"] = args.llm_model or "llama3.2:1b" config["model"] = args.llm_model or "llama3.2:1b"
config["host"] = args.llm_host config["host"] = resolve_ollama_host(args.llm_host)
elif args.llm == "hf": elif args.llm == "hf":
config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct" config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
elif args.llm == "simulated": elif args.llm == "simulated":
@@ -223,10 +258,20 @@ class BaseRAGExample(ABC):
print(f"\n[Building Index] Creating {self.name} index...") print(f"\n[Building Index] Creating {self.name} index...")
print(f"Total text chunks: {len(texts)}") print(f"Total text chunks: {len(texts)}")
embedding_options: dict[str, Any] = {}
if args.embedding_mode == "ollama":
embedding_options["host"] = resolve_ollama_host(args.embedding_host)
elif args.embedding_mode == "openai":
embedding_options["base_url"] = resolve_openai_base_url(args.embedding_api_base)
resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
if resolved_embedding_key:
embedding_options["api_key"] = resolved_embedding_key
builder = LeannBuilder( builder = LeannBuilder(
backend_name=args.backend_name, backend_name=args.backend_name,
embedding_model=args.embedding_model, embedding_model=args.embedding_model,
embedding_mode=args.embedding_mode, embedding_mode=args.embedding_mode,
embedding_options=embedding_options or None,
graph_degree=args.graph_degree, graph_degree=args.graph_degree,
complexity=args.build_complexity, complexity=args.build_complexity,
is_compact=not args.no_compact, is_compact=not args.no_compact,

View File

@@ -1,9 +1,16 @@
""" """Unified chunking utilities facade.
Chunking utilities for LEANN RAG applications.
Provides AST-aware and traditional text chunking functionality. This module re-exports the packaged utilities from `leann.chunking_utils` so
that both repo apps (importing `chunking`) and installed wheels share one
single implementation. When running from the repo without installation, it
adds the `packages/leann-core/src` directory to `sys.path` as a fallback.
""" """
from .utils import ( import sys
from pathlib import Path
try:
from leann.chunking_utils import (
CODE_EXTENSIONS, CODE_EXTENSIONS,
create_ast_chunks, create_ast_chunks,
create_text_chunks, create_text_chunks,
@@ -11,6 +18,21 @@ from .utils import (
detect_code_files, detect_code_files,
get_language_from_extension, get_language_from_extension,
) )
except Exception: # pragma: no cover - best-effort fallback for dev environment
repo_root = Path(__file__).resolve().parents[2]
leann_src = repo_root / "packages" / "leann-core" / "src"
if leann_src.exists():
sys.path.insert(0, str(leann_src))
from leann.chunking_utils import (
CODE_EXTENSIONS,
create_ast_chunks,
create_text_chunks,
create_traditional_chunks,
detect_code_files,
get_language_from_extension,
)
else:
raise
__all__ = [ __all__ = [
"CODE_EXTENSIONS", "CODE_EXTENSIONS",

View File

@@ -74,7 +74,7 @@ class ChromeHistoryReader(BaseReader):
if count >= max_count and max_count > 0: if count >= max_count and max_count > 0:
break break
last_visit, url, title, visit_count, typed_count, hidden = row last_visit, url, title, visit_count, typed_count, _hidden = row
# Create document content with metadata embedded in text # Create document content with metadata embedded in text
doc_content = f""" doc_content = f"""

View File

@@ -0,0 +1,182 @@
from __future__ import annotations
import sys
from pathlib import Path
import numpy as np
def _ensure_repo_paths_importable(current_file: str) -> None:
_repo_root = Path(current_file).resolve().parents[3]
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
if str(_leann_core_src) not in sys.path:
sys.path.append(str(_leann_core_src))
if str(_leann_hnsw_pkg) not in sys.path:
sys.path.append(str(_leann_hnsw_pkg))
_ensure_repo_paths_importable(__file__)
from leann_backend_hnsw.hnsw_backend import HNSWBuilder, HNSWSearcher # noqa: E402
class LeannMultiVector:
def __init__(
self,
index_path: str,
dim: int = 128,
distance_metric: str = "mips",
m: int = 16,
ef_construction: int = 500,
is_compact: bool = False,
is_recompute: bool = False,
embedding_model_name: str = "colvision",
) -> None:
self.index_path = index_path
self.dim = dim
self.embedding_model_name = embedding_model_name
self._pending_items: list[dict] = []
self._backend_kwargs = {
"distance_metric": distance_metric,
"M": m,
"efConstruction": ef_construction,
"is_compact": is_compact,
"is_recompute": is_recompute,
}
self._labels_meta: list[dict] = []
def _meta_dict(self) -> dict:
return {
"version": "1.0",
"backend_name": "hnsw",
"embedding_model": self.embedding_model_name,
"embedding_mode": "custom",
"dimensions": self.dim,
"backend_kwargs": self._backend_kwargs,
"is_compact": self._backend_kwargs.get("is_compact", True),
"is_pruned": self._backend_kwargs.get("is_compact", True)
and self._backend_kwargs.get("is_recompute", True),
}
def create_collection(self) -> None:
path = Path(self.index_path)
path.parent.mkdir(parents=True, exist_ok=True)
def insert(self, data: dict) -> None:
self._pending_items.append(
{
"doc_id": int(data["doc_id"]),
"filepath": data.get("filepath", ""),
"colbert_vecs": [np.asarray(v, dtype=np.float32) for v in data["colbert_vecs"]],
}
)
def _labels_path(self) -> Path:
index_path_obj = Path(self.index_path)
return index_path_obj.parent / f"{index_path_obj.name}.labels.json"
def _meta_path(self) -> Path:
index_path_obj = Path(self.index_path)
return index_path_obj.parent / f"{index_path_obj.name}.meta.json"
def create_index(self) -> None:
if not self._pending_items:
return
embeddings: list[np.ndarray] = []
labels_meta: list[dict] = []
for item in self._pending_items:
doc_id = int(item["doc_id"])
filepath = item.get("filepath", "")
colbert_vecs = item["colbert_vecs"]
for seq_id, vec in enumerate(colbert_vecs):
vec_np = np.asarray(vec, dtype=np.float32)
embeddings.append(vec_np)
labels_meta.append(
{
"id": f"{doc_id}:{seq_id}",
"doc_id": doc_id,
"seq_id": int(seq_id),
"filepath": filepath,
}
)
if not embeddings:
return
embeddings_np = np.vstack(embeddings).astype(np.float32)
# print shape of embeddings_np
print(embeddings_np.shape)
builder = HNSWBuilder(**{**self._backend_kwargs, "dimensions": self.dim})
ids = [str(i) for i in range(embeddings_np.shape[0])]
builder.build(embeddings_np, ids, self.index_path)
import json as _json
with open(self._meta_path(), "w", encoding="utf-8") as f:
_json.dump(self._meta_dict(), f, indent=2)
with open(self._labels_path(), "w", encoding="utf-8") as f:
_json.dump(labels_meta, f)
self._labels_meta = labels_meta
def _load_labels_meta_if_needed(self) -> None:
if self._labels_meta:
return
labels_path = self._labels_path()
if labels_path.exists():
import json as _json
with open(labels_path, encoding="utf-8") as f:
self._labels_meta = _json.load(f)
def search(
self, data: np.ndarray, topk: int, first_stage_k: int = 50
) -> list[tuple[float, int]]:
if data.ndim == 1:
data = data.reshape(1, -1)
if data.dtype != np.float32:
data = data.astype(np.float32)
self._load_labels_meta_if_needed()
searcher = HNSWSearcher(self.index_path, meta=self._meta_dict())
raw = searcher.search(
data,
first_stage_k,
recompute_embeddings=False,
complexity=128,
beam_width=1,
prune_ratio=0.0,
batch_size=0,
)
labels = raw.get("labels")
distances = raw.get("distances")
if labels is None or distances is None:
return []
doc_scores: dict[int, float] = {}
B = len(labels)
for b in range(B):
per_doc_best: dict[int, float] = {}
for k, sid in enumerate(labels[b]):
try:
idx = int(sid)
except Exception:
continue
if 0 <= idx < len(self._labels_meta):
doc_id = int(self._labels_meta[idx]["doc_id"]) # type: ignore[index]
else:
continue
score = float(distances[b][k])
if (doc_id not in per_doc_best) or (score > per_doc_best[doc_id]):
per_doc_best[doc_id] = score
for doc_id, best_score in per_doc_best.items():
doc_scores[doc_id] = doc_scores.get(doc_id, 0.0) + best_score
scores = sorted(((v, k) for k, v in doc_scores.items()), key=lambda x: x[0], reverse=True)
return scores[:topk] if len(scores) >= topk else scores

View File

@@ -0,0 +1,477 @@
## Jupyter-style notebook script
# %%
# uv pip install matplotlib qwen_vl_utils
import os
import re
import sys
from pathlib import Path
from typing import Any, Optional, cast
from PIL import Image
from tqdm import tqdm
def _ensure_repo_paths_importable(current_file: str) -> None:
"""Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
_repo_root = Path(current_file).resolve().parents[3]
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
if str(_leann_core_src) not in sys.path:
sys.path.append(str(_leann_core_src))
if str(_leann_hnsw_pkg) not in sys.path:
sys.path.append(str(_leann_hnsw_pkg))
_ensure_repo_paths_importable(__file__)
from leann_multi_vector import LeannMultiVector # noqa: E402
# %%
# Config
os.environ["TOKENIZERS_PARALLELISM"] = "false"
QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"
MODEL: str = "colqwen2" # "colpali" or "colqwen2"
# Data source: set to True to use the Hugging Face dataset example (recommended)
USE_HF_DATASET: bool = True
DATASET_NAME: str = "weaviate/arXiv-AI-papers-multi-vector"
DATASET_SPLIT: str = "train"
MAX_DOCS: Optional[int] = None # limit number of pages to index; None = all
# Local pages (used when USE_HF_DATASET == False)
PDF: Optional[str] = None # e.g., "./pdfs/2004.12832v2.pdf"
PAGES_DIR: str = "./pages"
# Index + retrieval settings
INDEX_PATH: str = "./indexes/colvision.leann"
TOPK: int = 1
FIRST_STAGE_K: int = 500
REBUILD_INDEX: bool = False
# Artifacts
SAVE_TOP_IMAGE: Optional[str] = "./figures/retrieved_page.png"
SIMILARITY_MAP: bool = True
SIM_TOKEN_IDX: int = 13 # -1 means auto-select the most salient token
SIM_OUTPUT: str = "./figures/similarity_map.png"
ANSWER: bool = True
MAX_NEW_TOKENS: int = 128
# %%
# Helpers
def _natural_sort_key(name: str) -> int:
m = re.search(r"\d+", name)
return int(m.group()) if m else 0
def _load_images_from_dir(pages_dir: str) -> tuple[list[str], list[Image.Image]]:
filenames = [n for n in os.listdir(pages_dir) if n.lower().endswith((".png", ".jpg", ".jpeg"))]
filenames = sorted(filenames, key=_natural_sort_key)
filepaths = [os.path.join(pages_dir, n) for n in filenames]
images = [Image.open(p) for p in filepaths]
return filepaths, images
def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None:
if not pdf_path:
return
os.makedirs(pages_dir, exist_ok=True)
try:
from pdf2image import convert_from_path
except Exception as e:
raise RuntimeError(
"pdf2image is required to convert PDF to images. Install via pip install pdf2image"
) from e
images = convert_from_path(pdf_path, dpi=dpi)
for i, image in enumerate(images):
image.save(os.path.join(pages_dir, f"page_{i + 1}.png"), "PNG")
def _select_device_and_dtype():
import torch
from colpali_engine.utils.torch_utils import get_torch_device
device_str = (
"cuda"
if torch.cuda.is_available()
else (
"mps"
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
else "cpu"
)
)
device = get_torch_device(device_str)
# Stable dtype selection to avoid NaNs:
# - CUDA: prefer bfloat16 if supported, else float16
# - MPS: use float32 (fp16 on MPS can produce NaNs in some ops)
# - CPU: float32
if device_str == "cuda":
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
try:
torch.backends.cuda.matmul.allow_tf32 = True # Better stability/perf on Ampere+
except Exception:
pass
elif device_str == "mps":
dtype = torch.float32
else:
dtype = torch.float32
return device_str, device, dtype
def _load_colvision(model_choice: str):
import torch
from colpali_engine.models import ColPali, ColQwen2, ColQwen2Processor
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
from transformers.utils.import_utils import is_flash_attn_2_available
device_str, device, dtype = _select_device_and_dtype()
if model_choice == "colqwen2":
model_name = "vidore/colqwen2-v1.0"
# On CPU/MPS we must avoid flash-attn and stay eager; on CUDA prefer flash-attn if available
attn_implementation = (
"flash_attention_2"
if (device_str == "cuda" and is_flash_attn_2_available())
else "eager"
)
model = ColQwen2.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
attn_implementation=attn_implementation,
).eval()
processor = ColQwen2Processor.from_pretrained(model_name)
else:
model_name = "vidore/colpali-v1.2"
model = ColPali.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
).eval()
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
return model_name, model, processor, device_str, device, dtype
def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
import torch
from colpali_engine.utils.torch_utils import ListDataset
from torch.utils.data import DataLoader
# Ensure deterministic eval and autocast for stability
model.eval()
dataloader = DataLoader(
dataset=ListDataset[Image.Image](images),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_images(x),
)
doc_vecs: list[Any] = []
for batch_doc in dataloader:
with torch.no_grad():
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
# autocast on CUDA for bf16/fp16; on CPU/MPS stay in fp32
if model.device.type == "cuda":
with torch.autocast(
device_type="cuda",
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
):
embeddings_doc = model(**batch_doc)
else:
embeddings_doc = model(**batch_doc)
doc_vecs.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
return doc_vecs
def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
import torch
from colpali_engine.utils.torch_utils import ListDataset
from torch.utils.data import DataLoader
model.eval()
dataloader = DataLoader(
dataset=ListDataset[str](queries),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_queries(x),
)
q_vecs: list[Any] = []
for batch_query in dataloader:
with torch.no_grad():
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
if model.device.type == "cuda":
with torch.autocast(
device_type="cuda",
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
):
embeddings_query = model(**batch_query)
else:
embeddings_query = model(**batch_query)
q_vecs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
return q_vecs
def _build_index(index_path: str, doc_vecs: list[Any], filepaths: list[str]) -> LeannMultiVector:
dim = int(doc_vecs[0].shape[-1])
retriever = LeannMultiVector(index_path=index_path, dim=dim)
retriever.create_collection()
for i, vec in enumerate(doc_vecs):
data = {
"colbert_vecs": vec.float().numpy(),
"doc_id": i,
"filepath": filepaths[i],
}
retriever.insert(data)
retriever.create_index()
return retriever
def _load_retriever_if_index_exists(index_path: str, dim: int) -> Optional[LeannMultiVector]:
index_base = Path(index_path)
# Rough heuristic: index dir exists AND meta+labels files exist
meta = index_base.parent / f"{index_base.name}.meta.json"
labels = index_base.parent / f"{index_base.name}.labels.json"
if index_base.exists() and meta.exists() and labels.exists():
return LeannMultiVector(index_path=index_path, dim=dim)
return None
def _generate_similarity_map(
model,
processor,
image: Image.Image,
query: str,
token_idx: Optional[int] = None,
output_path: Optional[str] = None,
) -> tuple[int, float]:
import torch
from colpali_engine.interpretability import (
get_similarity_maps_from_embeddings,
plot_similarity_map,
)
batch_images = processor.process_images([image]).to(model.device)
batch_queries = processor.process_queries([query]).to(model.device)
with torch.no_grad():
image_embeddings = model.forward(**batch_images)
query_embeddings = model.forward(**batch_queries)
n_patches = processor.get_n_patches(
image_size=image.size,
spatial_merge_size=getattr(model, "spatial_merge_size", None),
)
image_mask = processor.get_image_mask(batch_images)
batched_similarity_maps = get_similarity_maps_from_embeddings(
image_embeddings=image_embeddings,
query_embeddings=query_embeddings,
n_patches=n_patches,
image_mask=image_mask,
)
similarity_maps = batched_similarity_maps[0]
# Determine token index if not provided: choose the token with highest max score
if token_idx is None:
per_token_max = similarity_maps.view(similarity_maps.shape[0], -1).max(dim=1).values
token_idx = int(per_token_max.argmax().item())
max_sim_score = similarity_maps[token_idx, :, :].max().item()
if output_path:
import matplotlib.pyplot as plt
fig, ax = plot_similarity_map(
image=image,
similarity_map=similarity_maps[token_idx],
figsize=(14, 14),
show_colorbar=False,
)
ax.set_title(f"Token #{token_idx}. MaxSim score: {max_sim_score:.2f}", fontsize=12)
os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.savefig(output_path, bbox_inches="tight")
plt.close(fig)
return token_idx, float(max_sim_score)
class QwenVL:
def __init__(self, device: str):
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from transformers.utils.import_utils import is_flash_attn_2_available
attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "eager"
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen2.5-VL-3B-Instruct",
torch_dtype="auto",
device_map=device,
attn_implementation=attn_implementation,
)
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28
self.processor = AutoProcessor.from_pretrained(
"Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
)
def answer(self, query: str, images: list[Image.Image], max_new_tokens: int = 128) -> str:
import base64
from io import BytesIO
from qwen_vl_utils import process_vision_info
content = []
for img in images:
buffer = BytesIO()
img.save(buffer, format="jpeg")
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
content.append({"type": "image", "image": f"data:image;base64,{img_base64}"})
content.append({"type": "text", "text": query})
messages = [{"role": "user", "content": content}]
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = self.processor(
text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
)
inputs = inputs.to(self.model.device)
generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
return self.processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
# %%
# Step 1: Prepare data
if USE_HF_DATASET:
from datasets import load_dataset
dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
filepaths: list[str] = []
images: list[Image.Image] = []
for i in tqdm(range(N), desc="Loading dataset"):
p = dataset[i]
# Compose a descriptive identifier for printing later
identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"
print(identifier)
filepaths.append(identifier)
images.append(p["page_image"]) # PIL Image
else:
_maybe_convert_pdf_to_images(PDF, PAGES_DIR)
filepaths, images = _load_images_from_dir(PAGES_DIR)
if not images:
raise RuntimeError(
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
)
# %%
# Step 2: Load model and processor
model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
print(f"Using model={model_name}, device={device_str}, dtype={dtype}")
# %%
# %%
# Step 3: Build or load index
retriever: Optional[LeannMultiVector] = None
if not REBUILD_INDEX:
try:
one_vec = _embed_images(model, processor, [images[0]])[0]
retriever = _load_retriever_if_index_exists(INDEX_PATH, dim=int(one_vec.shape[-1]))
except Exception:
retriever = None
if retriever is None:
doc_vecs = _embed_images(model, processor, images)
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths)
# %%
# Step 4: Embed query and search
q_vec = _embed_queries(model, processor, [QUERY])[0]
results = retriever.search(q_vec.float().numpy(), topk=TOPK, first_stage_k=FIRST_STAGE_K)
if not results:
print("No results found.")
else:
print(f'Top {len(results)} results for query: "{QUERY}"')
top_images: list[Image.Image] = []
for rank, (score, doc_id) in enumerate(results, start=1):
path = filepaths[doc_id]
# For HF dataset, path is a descriptive identifier, not a real file path
print(f"{rank}) MaxSim: {score:.4f}, Page: {path}")
top_images.append(images[doc_id])
if SAVE_TOP_IMAGE:
from pathlib import Path as _Path
base = _Path(SAVE_TOP_IMAGE)
base.parent.mkdir(parents=True, exist_ok=True)
for rank, img in enumerate(top_images[:TOPK], start=1):
if base.suffix:
out_path = base.parent / f"{base.stem}_rank{rank}{base.suffix}"
else:
out_path = base / f"retrieved_page_rank{rank}.png"
img.save(str(out_path))
print(f"Saved retrieved page (rank {rank}) to: {out_path}")
## TODO stange results of second page of DeepSeek-V2 rather than the first page
# %%
# Step 5: Similarity maps for top-K results
if results and SIMILARITY_MAP:
token_idx = None if SIM_TOKEN_IDX < 0 else int(SIM_TOKEN_IDX)
from pathlib import Path as _Path
output_base = _Path(SIM_OUTPUT) if SIM_OUTPUT else None
for rank, img in enumerate(top_images[:TOPK], start=1):
if output_base:
if output_base.suffix:
out_dir = output_base.parent
out_name = f"{output_base.stem}_rank{rank}{output_base.suffix}"
out_path = str(out_dir / out_name)
else:
out_dir = output_base
out_dir.mkdir(parents=True, exist_ok=True)
out_path = str(out_dir / f"similarity_map_rank{rank}.png")
else:
out_path = None
chosen_idx, max_sim = _generate_similarity_map(
model=model,
processor=processor,
image=img,
query=QUERY,
token_idx=token_idx,
output_path=out_path,
)
if out_path:
print(
f"Saved similarity map for rank {rank}, token #{chosen_idx} (max={max_sim:.2f}) to: {out_path}"
)
else:
print(
f"Computed similarity map for rank {rank}, token #{chosen_idx} (max={max_sim:.2f})"
)
# %%
# Step 6: Optional answer generation
if results and ANSWER:
qwen = QwenVL(device=device_str)
response = qwen.answer(QUERY, top_images[:TOPK], max_new_tokens=MAX_NEW_TOKENS)
print("\nAnswer:")
print(response)

View File

@@ -0,0 +1,134 @@
# pip install pdf2image
# pip install pymilvus
# pip install colpali_engine
# pip install tqdm
# pip install pillow
# %%
from pdf2image import convert_from_path
pdf_path = "pdfs/2004.12832v2.pdf"
images = convert_from_path(pdf_path)
for i, image in enumerate(images):
image.save(f"pages/page_{i + 1}.png", "PNG")
# %%
import os
from pathlib import Path
# Make local leann packages importable without installing
_repo_root = Path(__file__).resolve().parents[3]
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
import sys
if str(_leann_core_src) not in sys.path:
sys.path.append(str(_leann_core_src))
if str(_leann_hnsw_pkg) not in sys.path:
sys.path.append(str(_leann_hnsw_pkg))
from leann_multi_vector import LeannMultiVector
class LeannRetriever(LeannMultiVector):
pass
# %%
from typing import cast
import torch
from colpali_engine.models import ColPali
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
from colpali_engine.utils.torch_utils import ListDataset, get_torch_device
from torch.utils.data import DataLoader
# Auto-select device: CUDA > MPS (mac) > CPU
_device_str = (
"cuda"
if torch.cuda.is_available()
else (
"mps"
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
else "cpu"
)
)
device = get_torch_device(_device_str)
# Prefer fp16 on GPU/MPS, bfloat16 on CPU
_dtype = torch.float16 if _device_str in ("cuda", "mps") else torch.bfloat16
model_name = "vidore/colpali-v1.2"
model = ColPali.from_pretrained(
model_name,
torch_dtype=_dtype,
device_map=device,
).eval()
print(f"Using device={_device_str}, dtype={_dtype}")
queries = [
"How to end-to-end retrieval with ColBert",
"Where is ColBERT performance Table, including text representation results?",
]
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
dataloader = DataLoader(
dataset=ListDataset[str](queries),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_queries(x),
)
qs: list[torch.Tensor] = []
for batch_query in dataloader:
with torch.no_grad():
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
embeddings_query = model(**batch_query)
qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
print(qs[0].shape)
# %%
import re
from PIL import Image
from tqdm import tqdm
page_filenames = sorted(os.listdir("./pages"), key=lambda n: int(re.search(r"\d+", n).group()))
images = [Image.open(os.path.join("./pages", name)) for name in page_filenames]
dataloader = DataLoader(
dataset=ListDataset[str](images),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_images(x),
)
ds: list[torch.Tensor] = []
for batch_doc in tqdm(dataloader):
with torch.no_grad():
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
embeddings_doc = model(**batch_doc)
ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
print(ds[0].shape)
# %%
# Build HNSW index via LeannRetriever primitives and run search
index_path = "./indexes/colpali.leann"
retriever = LeannRetriever(index_path=index_path, dim=int(ds[0].shape[-1]))
retriever.create_collection()
filepaths = [os.path.join("./pages", name) for name in page_filenames]
for i in range(len(filepaths)):
data = {
"colbert_vecs": ds[i].float().numpy(),
"doc_id": i,
"filepath": filepaths[i],
}
retriever.insert(data)
retriever.create_index()
for query in qs:
query_np = query.float().numpy()
result = retriever.search(query_np, topk=1)
print(filepaths[result[0][1]])

0
benchmarks/__init__.py Normal file
View File

View File

@@ -0,0 +1,23 @@
BM25 vs DiskANN Baselines
```bash
aws s3 sync s3://powerrag-diskann-rpj-wiki-20250824-224037-194d640c/bm25_rpj_wiki/index_en_only/ benchmarks/data/indices/bm25_index/
aws s3 sync s3://powerrag-diskann-rpj-wiki-20250824-224037-194d640c/diskann_rpj_wiki/ benchmarks/data/indices/diskann_rpj_wiki/
```
- Dataset: `benchmarks/data/queries/nq_open.jsonl` (Natural Questions)
- Machine-specific; results measured locally with the current repo.
DiskANN (NQ queries, search-only)
- Command: `uv run --script benchmarks/bm25_diskann_baselines/run_diskann.py`
- Settings: `recompute_embeddings=False`, embeddings precomputed (excluded from timing), batching off, caching off (`cache_mechanism=2`, `num_nodes_to_cache=0`)
- Result: avg 0.011093 s/query, QPS 90.15 (p50 0.010731 s, p95 0.015000 s)
BM25
- Command: `uv run --script benchmarks/bm25_diskann_baselines/run_bm25.py`
- Settings: `k=10`, `k1=0.9`, `b=0.4`, queries=100
- Result: avg 0.028589 s/query, QPS 34.97 (p50 0.026060 s, p90 0.043695 s, p95 0.053260 s, p99 0.055257 s)
Notes
- DiskANN measures search-only latency on real NQ queries (embeddings computed beforehand and excluded from timing).
- Use `benchmarks/bm25_diskann_baselines/run_diskann.py` for DiskANN; `benchmarks/bm25_diskann_baselines/run_bm25.py` for BM25.

After

Width:  |  Height:  |  Size: 1.3 KiB

View File

@@ -0,0 +1,183 @@
# /// script
# dependencies = [
# "pyserini"
# ]
# ///
# sudo pacman -S jdk21-openjdk
# export JAVA_HOME=/usr/lib/jvm/java-21-openjdk
# sudo archlinux-java status
# sudo archlinux-java set java-21-openjdk
# set -Ux JAVA_HOME /usr/lib/jvm/java-21-openjdk
# fish_add_path --global $JAVA_HOME/bin
# set -Ux LD_LIBRARY_PATH $JAVA_HOME/lib/server $LD_LIBRARY_PATH
# which javac # Should be /usr/lib/jvm/java-21-openjdk/bin/javac
import argparse
import json
import os
import sys
import time
from statistics import mean
def load_queries(path: str, limit: int | None) -> list[str]:
queries: list[str] = []
# Try JSONL with a 'query' or 'text' field; fallback to plain text (one query per line)
_, ext = os.path.splitext(path)
if ext.lower() in {".jsonl", ".json"}:
with open(path, encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
try:
obj = json.loads(line)
except json.JSONDecodeError:
# Not strict JSONL? treat the whole line as the query
queries.append(line)
continue
q = obj.get("query") or obj.get("text") or obj.get("question")
if q:
queries.append(str(q))
else:
with open(path, encoding="utf-8") as f:
for line in f:
s = line.strip()
if s:
queries.append(s)
if limit is not None and limit > 0:
queries = queries[:limit]
return queries
def percentile(values: list[float], p: float) -> float:
if not values:
return 0.0
s = sorted(values)
k = (len(s) - 1) * (p / 100.0)
f = int(k)
c = min(f + 1, len(s) - 1)
if f == c:
return s[f]
return s[f] + (s[c] - s[f]) * (k - f)
def main():
ap = argparse.ArgumentParser(description="Standalone BM25 latency benchmark (Pyserini)")
ap.add_argument(
"--bm25-index",
default="benchmarks/data/indices/bm25_index",
help="Path to Pyserini Lucene index directory",
)
ap.add_argument(
"--queries",
default="benchmarks/data/queries/nq_open.jsonl",
help="Path to queries file (JSONL with 'query'/'text' or plain txt one-per-line)",
)
ap.add_argument("--k", type=int, default=10, help="Top-k to retrieve (default: 10)")
ap.add_argument("--k1", type=float, default=0.9, help="BM25 k1 (default: 0.9)")
ap.add_argument("--b", type=float, default=0.4, help="BM25 b (default: 0.4)")
ap.add_argument("--limit", type=int, default=100, help="Max queries to run (default: 100)")
ap.add_argument(
"--warmup", type=int, default=5, help="Warmup queries not counted in latency (default: 5)"
)
ap.add_argument(
"--fetch-docs", action="store_true", help="Also fetch doc contents (slower; default: off)"
)
ap.add_argument("--report", type=str, default=None, help="Optional JSON report path")
args = ap.parse_args()
try:
from pyserini.search.lucene import LuceneSearcher
except Exception:
print("Pyserini not found. Install with: pip install pyserini", file=sys.stderr)
raise
if not os.path.isdir(args.bm25_index):
print(f"Index directory not found: {args.bm25_index}", file=sys.stderr)
sys.exit(1)
queries = load_queries(args.queries, args.limit)
if not queries:
print("No queries loaded.", file=sys.stderr)
sys.exit(1)
print(f"Loaded {len(queries)} queries from {args.queries}")
print(f"Opening BM25 index: {args.bm25_index}")
searcher = LuceneSearcher(args.bm25_index)
# Some builds of pyserini require explicit set_bm25; others ignore
try:
searcher.set_bm25(k1=args.k1, b=args.b)
except Exception:
pass
latencies: list[float] = []
total_searches = 0
# Warmup
for i in range(min(args.warmup, len(queries))):
_ = searcher.search(queries[i], k=args.k)
t0 = time.time()
for i, q in enumerate(queries):
t1 = time.time()
hits = searcher.search(q, k=args.k)
t2 = time.time()
latencies.append(t2 - t1)
total_searches += 1
if args.fetch_docs:
# Optional doc fetch to include I/O time
for h in hits:
try:
_ = searcher.doc(h.docid)
except Exception:
pass
if (i + 1) % 50 == 0:
print(f"Processed {i + 1}/{len(queries)} queries")
t1 = time.time()
total_time = t1 - t0
if latencies:
avg = mean(latencies)
p50 = percentile(latencies, 50)
p90 = percentile(latencies, 90)
p95 = percentile(latencies, 95)
p99 = percentile(latencies, 99)
qps = total_searches / total_time if total_time > 0 else 0.0
else:
avg = p50 = p90 = p95 = p99 = qps = 0.0
print("BM25 Latency Report")
print(f" queries: {total_searches}")
print(f" k: {args.k}, k1: {args.k1}, b: {args.b}")
print(f" avg per query: {avg:.6f} s")
print(f" p50/p90/p95/p99: {p50:.6f}/{p90:.6f}/{p95:.6f}/{p99:.6f} s")
print(f" total time: {total_time:.3f} s, qps: {qps:.2f}")
if args.report:
payload = {
"queries": total_searches,
"k": args.k,
"k1": args.k1,
"b": args.b,
"avg_s": avg,
"p50_s": p50,
"p90_s": p90,
"p95_s": p95,
"p99_s": p99,
"total_time_s": total_time,
"qps": qps,
"index_dir": os.path.abspath(args.bm25_index),
"fetch_docs": bool(args.fetch_docs),
}
with open(args.report, "w", encoding="utf-8") as f:
json.dump(payload, f, indent=2)
print(f"Saved report to {args.report}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,124 @@
# /// script
# dependencies = [
# "leann-backend-diskann"
# ]
# ///
import argparse
import json
import time
from pathlib import Path
import numpy as np
def load_queries(path: Path, limit: int | None) -> list[str]:
out: list[str] = []
with open(path, encoding="utf-8") as f:
for line in f:
obj = json.loads(line)
out.append(obj["query"])
if limit and len(out) >= limit:
break
return out
def main() -> None:
ap = argparse.ArgumentParser(
description="DiskANN baseline on real NQ queries (search-only timing)"
)
ap.add_argument(
"--index-dir",
default="benchmarks/data/indices/diskann_rpj_wiki",
help="Directory containing DiskANN files",
)
ap.add_argument("--index-prefix", default="ann")
ap.add_argument("--queries-file", default="benchmarks/data/queries/nq_open.jsonl")
ap.add_argument("--num-queries", type=int, default=200)
ap.add_argument("--top-k", type=int, default=10)
ap.add_argument("--complexity", type=int, default=62)
ap.add_argument("--threads", type=int, default=1)
ap.add_argument("--beam-width", type=int, default=1)
ap.add_argument("--cache-mechanism", type=int, default=2)
ap.add_argument("--num-nodes-to-cache", type=int, default=0)
args = ap.parse_args()
index_dir = Path(args.index_dir).resolve()
if not index_dir.is_dir():
raise SystemExit(f"Index dir not found: {index_dir}")
qpath = Path(args.queries_file).resolve()
if not qpath.exists():
raise SystemExit(f"Queries file not found: {qpath}")
queries = load_queries(qpath, args.num_queries)
print(f"Loaded {len(queries)} queries from {qpath}")
# Compute embeddings once (exclude from timing)
from leann.api import compute_embeddings as _compute
embs = _compute(
queries,
model_name="facebook/contriever-msmarco",
mode="sentence-transformers",
use_server=False,
).astype(np.float32)
if embs.ndim != 2:
raise SystemExit("Embedding compute failed or returned wrong shape")
# Build searcher
from leann_backend_diskann.diskann_backend import DiskannSearcher as _DiskannSearcher
index_prefix_path = str(index_dir / args.index_prefix)
searcher = _DiskannSearcher(
index_prefix_path,
num_threads=int(args.threads),
cache_mechanism=int(args.cache_mechanism),
num_nodes_to_cache=int(args.num_nodes_to_cache),
)
# Warmup (not timed)
_ = searcher.search(
embs[0:1],
top_k=args.top_k,
complexity=args.complexity,
beam_width=args.beam_width,
prune_ratio=0.0,
recompute_embeddings=False,
batch_recompute=False,
dedup_node_dis=False,
)
# Timed loop
times: list[float] = []
for i in range(embs.shape[0]):
t0 = time.time()
_ = searcher.search(
embs[i : i + 1],
top_k=args.top_k,
complexity=args.complexity,
beam_width=args.beam_width,
prune_ratio=0.0,
recompute_embeddings=False,
batch_recompute=False,
dedup_node_dis=False,
)
times.append(time.time() - t0)
times_sorted = sorted(times)
avg = float(sum(times) / len(times))
p50 = times_sorted[len(times) // 2]
p95 = times_sorted[max(0, int(len(times) * 0.95) - 1)]
print("\nDiskANN (NQ, search-only) Report")
print(f" queries: {len(times)}")
print(
f" k: {args.top_k}, complexity: {args.complexity}, beam_width: {args.beam_width}, threads: {args.threads}"
)
print(f" avg per query: {avg:.6f} s")
print(f" p50/p95: {p50:.6f}/{p95:.6f} s")
print(f" QPS: {1.0 / avg:.2f}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,141 @@
# Enron Emails Benchmark
A comprehensive RAG benchmark for evaluating LEANN search and generation on the Enron email corpus. It mirrors the structure and CLI of the existing FinanceBench and LAION benches, using stage-based evaluation with Recall@3 and generation timing.
- Dataset: Enron email CSV (e.g., Kaggle wcukierski/enron-email-dataset) for passages
- Queries: corbt/enron_emails_sample_questions (filtered for realistic questions)
- Metrics: Recall@3 vs FAISS Flat baseline + Generation evaluation with Qwen3-8B
## Layout
benchmarks/enron_emails/
- setup_enron_emails.py: Prepare passages, build LEANN index, build FAISS baseline
- evaluate_enron_emails.py: Evaluate retrieval recall (Stages 2-5) + generation with Qwen3-8B
- data/: Generated passages, queries, embeddings-related files
- baseline/: FAISS Flat baseline files
- llm_utils.py: LLM utilities for Qwen3-8B generation (in parent directory)
## Quickstart
1) Prepare the data and index
cd benchmarks/enron_emails
python setup_enron_emails.py --data-dir data
Notes:
- If `--emails-csv` is omitted, the script attempts to download from Kaggle dataset `wcukierski/enron-email-dataset` using Kaggle API (requires `KAGGLE_USERNAME` and `KAGGLE_KEY`).
Alternatively, pass a local path to `--emails-csv`.
Notes:
- The script parses emails, chunks header/body into passages, builds a compact LEANN index, and then builds a FAISS Flat baseline from the same passages and embedding model.
- Optionally, it will also create evaluation queries from HuggingFace dataset `corbt/enron_emails_sample_questions`.
2) Run recall evaluation (Stage 2)
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 2
3) Complexity sweep (Stage 3)
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 3 --target-recall 0.90 --max-queries 200
Stage 3 uses binary search over complexity to find the minimal value achieving the target Recall@3 (assumes recall is non-decreasing with complexity). The search expands the upper bound as needed and snaps complexity to multiples of 8.
4) Index comparison (Stage 4)
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 4 --complexity 88 --max-queries 100 --output results.json
5) Generation evaluation (Stage 5)
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 5 --complexity 88 --llm-backend hf --model-name Qwen/Qwen3-8B
6) Combined index + generation evaluation (Stages 4+5, recommended)
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 45 --complexity 88 --llm-backend hf
Notes:
- Minimal CLI: you can run from repo root with only `--index`, defaults match financebench/laion patterns:
- `--stage` defaults to `all` (runs 2, 3, 4, 5)
- `--baseline-dir` defaults to `baseline`
- `--queries` defaults to `data/evaluation_queries.jsonl` (or falls back to the index directory)
- `--llm-backend` defaults to `hf` (HuggingFace), can use `vllm`
- `--model-name` defaults to `Qwen/Qwen3-8B`
- Fail-fast behavior: no silent fallbacks. If compact index cannot run with recompute, it errors out.
- Stage 5 requires Stage 4 retrieval results. Use `--stage 45` to run both efficiently.
Optional flags:
- --queries data/evaluation_queries.jsonl (custom queries file)
- --baseline-dir baseline (where FAISS baseline lives)
- --complexity 88 (LEANN complexity parameter, optimal for 90% recall)
- --llm-backend hf|vllm (LLM backend for generation)
- --model-name Qwen/Qwen3-8B (LLM model for generation)
- --max-queries 1000 (limit number of queries for evaluation)
## Files Produced
- data/enron_passages_preview.jsonl: Small preview of passages used (for inspection)
- data/enron_index_hnsw.leann.*: LEANN index files
- baseline/faiss_flat.index + baseline/metadata.pkl: FAISS baseline with passage IDs
- data/evaluation_queries.jsonl: Query file (id + query; includes GT IDs for reference)
## Notes
- Evaluates both retrieval Recall@3 and generation timing with Qwen3-8B thinking model.
- The emails CSV must contain a column named "message" (raw RFC822 email) and a column named "file" for source identifier. Message-ID headers are parsed as canonical message IDs when present.
- Qwen3-8B requires special handling for thinking models with chat templates and <think></think> tag processing.
## Stages Summary
- Stage 2 (Recall@3):
- Compares LEANN vs FAISS Flat baseline on Recall@3.
- Compact index runs with `recompute_embeddings=True`.
- Stage 3 (Binary Search for Complexity):
- Builds a non-compact index (`<index>_noncompact.leann`) and runs binary search with `recompute_embeddings=False` to find the minimal complexity achieving target Recall@3 (default 90%).
- Stage 4 (Index Comparison):
- Reports .index-only sizes for compact vs non-compact.
- Measures timings on queries by default: non-compact (no recompute) vs compact (with recompute).
- Stores retrieval results for Stage 5 generation evaluation.
- Fails fast if compact recompute cannot run.
- If `--complexity` is not provided, the script tries to use the best complexity from Stage 3:
- First from the current run (when running `--stage all`), otherwise
- From `enron_stage3_results.json` saved next to the index during the last Stage 3 run.
- If neither exists, Stage 4 will error and ask you to run Stage 3 or pass `--complexity`.
- Stage 5 (Generation Evaluation):
- Uses Qwen3-8B thinking model for RAG generation on retrieved documents from Stage 4.
- Supports HuggingFace (`hf`) and vLLM (`vllm`) backends.
- Measures generation timing separately from search timing.
- Requires Stage 4 results (no additional searching performed).
## Example Results
These are sample results obtained on Enron data using all-mpnet-base-v2 and Qwen3-8B.
- Stage 3 (Binary Search):
- Minimal complexity achieving 90% Recall@3: 88
- Sampled points:
- C=8 → 59.9% Recall@3
- C=72 → 89.4% Recall@3
- C=88 → 90.2% Recall@3
- C=96 → 90.7% Recall@3
- C=112 → 91.1% Recall@3
- C=136 → 91.3% Recall@3
- C=256 → 92.0% Recall@3
- Stage 4 (Index Sizes, .index only):
- Compact: ~2.2 MB
- Non-compact: ~82.0 MB
- Storage saving by compact: ~97.3%
- Stage 4 (Search Timing, 988 queries, complexity=88):
- Non-compact (no recompute): ~0.0075 s avg per query
- Compact (with recompute): ~1.981 s avg per query
- Speed ratio (non-compact/compact): ~0.0038x
- Stage 5 (RAG Generation, 988 queries, Qwen3-8B):
- Average generation time: ~22.302 s per query
- Total queries processed: 988
- LLM backend: HuggingFace transformers
- Model: Qwen/Qwen3-8B (thinking model with <think></think> processing)
Full JSON output is saved by the script (see `--output`), e.g.:
`benchmarks/enron_emails/results_enron_stage45.json`.

View File

@@ -0,0 +1 @@
downloads/

View File

@@ -0,0 +1,614 @@
"""
Enron Emails Benchmark Evaluation - Retrieval Recall@3 (Stages 2/3/4)
Follows the style of FinanceBench/LAION: Stage 2 recall vs FAISS baseline,
Stage 3 complexity sweep to target recall, Stage 4 index comparison.
On errors, fail fast without fallbacks.
"""
import argparse
import json
import logging
import os
import pickle
from pathlib import Path
import numpy as np
from leann import LeannBuilder, LeannSearcher
from leann_backend_hnsw import faiss
from ..llm_utils import generate_hf, generate_vllm, load_hf_model, load_vllm_model
# Setup logging to reduce verbose output
logging.basicConfig(level=logging.WARNING)
logging.getLogger("leann.api").setLevel(logging.WARNING)
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)
class RecallEvaluator:
"""Stage 2: Evaluate Recall@3 (LEANN vs FAISS)"""
def __init__(self, index_path: str, baseline_dir: str):
self.index_path = index_path
self.baseline_dir = baseline_dir
self.searcher = LeannSearcher(index_path)
baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
metadata_path = os.path.join(baseline_dir, "metadata.pkl")
self.faiss_index = faiss.read_index(baseline_index_path)
with open(metadata_path, "rb") as f:
self.passage_ids = pickle.load(f)
print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} vectors")
# No fallbacks here; if embedding server is needed but fails, the caller will see the error.
def evaluate_recall_at_3(
self, queries: list[str], complexity: int = 64, recompute_embeddings: bool = True
) -> float:
"""Evaluate recall@3 using FAISS Flat as ground truth"""
from leann.api import compute_embeddings
recompute_str = "with recompute" if recompute_embeddings else "no recompute"
print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")
total_recall = 0.0
for i, query in enumerate(queries):
# Compute query embedding with the same model/mode as the index
q_emb = compute_embeddings(
[query],
self.searcher.embedding_model,
mode=self.searcher.embedding_mode,
use_server=False,
).astype(np.float32)
# Search FAISS Flat ground truth
n = q_emb.shape[0]
k = 3
distances = np.zeros((n, k), dtype=np.float32)
labels = np.zeros((n, k), dtype=np.int64)
self.faiss_index.search(
n,
faiss.swig_ptr(q_emb),
k,
faiss.swig_ptr(distances),
faiss.swig_ptr(labels),
)
baseline_ids = {self.passage_ids[idx] for idx in labels[0]}
# Search with LEANN (may require embedding server depending on index configuration)
results = self.searcher.search(
query,
top_k=3,
complexity=complexity,
recompute_embeddings=recompute_embeddings,
)
test_ids = {r.id for r in results}
intersection = test_ids.intersection(baseline_ids)
recall = len(intersection) / 3.0
total_recall += recall
if i < 3:
print(f" Q{i + 1}: '{query[:60]}...' -> Recall@3: {recall:.3f}")
print(f" FAISS: {list(baseline_ids)}")
print(f" LEANN: {list(test_ids)}")
print(f" ∩: {list(intersection)}")
avg = total_recall / max(1, len(queries))
print(f"📊 Average Recall@3: {avg:.3f} ({avg * 100:.1f}%)")
return avg
def cleanup(self):
if hasattr(self, "searcher"):
self.searcher.cleanup()
class EnronEvaluator:
def __init__(self, index_path: str):
self.index_path = index_path
self.searcher = LeannSearcher(index_path)
def load_queries(self, queries_file: str) -> list[str]:
queries: list[str] = []
with open(queries_file, encoding="utf-8") as f:
for line in f:
if not line.strip():
continue
data = json.loads(line)
if "query" in data:
queries.append(data["query"])
print(f"📊 Loaded {len(queries)} queries from {queries_file}")
return queries
def cleanup(self):
if self.searcher:
self.searcher.cleanup()
def analyze_index_sizes(self) -> dict:
"""Analyze index sizes (.index only), similar to LAION bench."""
print("📏 Analyzing index sizes (.index only)...")
index_path = Path(self.index_path)
index_dir = index_path.parent
index_name = index_path.stem
sizes: dict[str, float] = {}
index_file = index_dir / f"{index_name}.index"
meta_file = index_dir / f"{index_path.name}.meta.json"
passages_file = index_dir / f"{index_path.name}.passages.jsonl"
passages_idx_file = index_dir / f"{index_path.name}.passages.idx"
sizes["index_only_mb"] = (
index_file.stat().st_size / (1024 * 1024) if index_file.exists() else 0.0
)
sizes["metadata_mb"] = (
meta_file.stat().st_size / (1024 * 1024) if meta_file.exists() else 0.0
)
sizes["passages_text_mb"] = (
passages_file.stat().st_size / (1024 * 1024) if passages_file.exists() else 0.0
)
sizes["passages_index_mb"] = (
passages_idx_file.stat().st_size / (1024 * 1024) if passages_idx_file.exists() else 0.0
)
print(f" 📁 .index size: {sizes['index_only_mb']:.1f} MB")
return sizes
def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
"""Create a non-compact index for comparison using current passages and embeddings."""
current_index_path = Path(self.index_path)
current_index_dir = current_index_path.parent
current_index_name = current_index_path.name
# Read metadata to get passage source and embedding model
meta_path = current_index_dir / f"{current_index_name}.meta.json"
with open(meta_path, encoding="utf-8") as f:
meta = json.load(f)
passage_source = meta["passage_sources"][0]
passage_file = passage_source["path"]
# Convert relative path to absolute
if not Path(passage_file).is_absolute():
passage_file = current_index_dir / Path(passage_file).name
# Load all passages and ids
ids: list[str] = []
texts: list[str] = []
with open(passage_file, encoding="utf-8") as f:
for line in f:
if line.strip():
data = json.loads(line)
ids.append(str(data["id"]))
texts.append(data["text"])
# Compute embeddings using the same method as LEANN
from leann.api import compute_embeddings
embeddings = compute_embeddings(
texts,
meta["embedding_model"],
mode=meta.get("embedding_mode", "sentence-transformers"),
use_server=False,
).astype(np.float32)
# Build non-compact index with same passages and embeddings
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=meta["embedding_model"],
embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
is_recompute=False,
is_compact=False,
**{
k: v
for k, v in meta.get("backend_kwargs", {}).items()
if k not in ["is_recompute", "is_compact"]
},
)
# Persist a pickle for build_index_from_embeddings
pkl_path = current_index_dir / f"{Path(non_compact_index_path).stem}_embeddings.pkl"
with open(pkl_path, "wb") as pf:
pickle.dump((ids, embeddings), pf)
print(
f"🔨 Building non-compact index at {non_compact_index_path} from precomputed embeddings..."
)
builder.build_index_from_embeddings(non_compact_index_path, str(pkl_path))
# Analyze the non-compact index size
temp_evaluator = EnronEvaluator(non_compact_index_path)
non_compact_sizes = temp_evaluator.analyze_index_sizes()
non_compact_sizes["index_type"] = "non_compact"
return non_compact_sizes
def compare_index_performance(
self, non_compact_path: str, compact_path: str, test_queries: list[str], complexity: int
) -> dict:
"""Compare search speed for non-compact vs compact indexes."""
import time
results: dict = {
"non_compact": {"search_times": []},
"compact": {"search_times": []},
"avg_search_times": {},
"speed_ratio": 0.0,
"retrieval_results": [], # Store retrieval results for Stage 5
}
print("⚡ Comparing search performance between indexes...")
# Non-compact (no recompute)
print(" 🔍 Testing non-compact index (no recompute)...")
non_compact_searcher = LeannSearcher(non_compact_path)
for q in test_queries:
t0 = time.time()
_ = non_compact_searcher.search(
q, top_k=3, complexity=complexity, recompute_embeddings=False
)
results["non_compact"]["search_times"].append(time.time() - t0)
# Compact (with recompute). Fail fast if it cannot run.
print(" 🔍 Testing compact index (with recompute)...")
compact_searcher = LeannSearcher(compact_path)
for q in test_queries:
t0 = time.time()
docs = compact_searcher.search(
q, top_k=3, complexity=complexity, recompute_embeddings=True
)
results["compact"]["search_times"].append(time.time() - t0)
# Store retrieval results for Stage 5
results["retrieval_results"].append(
{"query": q, "retrieved_docs": [{"id": doc.id, "text": doc.text} for doc in docs]}
)
compact_searcher.cleanup()
if results["non_compact"]["search_times"]:
results["avg_search_times"]["non_compact"] = sum(
results["non_compact"]["search_times"]
) / len(results["non_compact"]["search_times"])
if results["compact"]["search_times"]:
results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
results["compact"]["search_times"]
)
if results["avg_search_times"].get("compact", 0) > 0:
results["speed_ratio"] = (
results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
)
else:
results["speed_ratio"] = 0.0
non_compact_searcher.cleanup()
return results
def evaluate_complexity(
self,
recall_eval: "RecallEvaluator",
queries: list[str],
target: float = 0.90,
c_min: int = 8,
c_max: int = 256,
max_iters: int = 10,
recompute: bool = False,
) -> dict:
"""Binary search minimal complexity achieving target recall (monotonic assumption)."""
def round_c(x: int) -> int:
# snap to multiple of 8 like other benches typically do
return max(1, int((x + 7) // 8) * 8)
metrics: list[dict] = []
lo = round_c(c_min)
hi = round_c(c_max)
print(
f"🧪 Binary search complexity in [{lo}, {hi}] for target Recall@3>={int(target * 100)}%..."
)
# Ensure upper bound can reach target; expand if needed (up to a cap)
r_lo = recall_eval.evaluate_recall_at_3(
queries, complexity=lo, recompute_embeddings=recompute
)
metrics.append({"complexity": lo, "recall_at_3": r_lo})
r_hi = recall_eval.evaluate_recall_at_3(
queries, complexity=hi, recompute_embeddings=recompute
)
metrics.append({"complexity": hi, "recall_at_3": r_hi})
cap = 1024
while r_hi < target and hi < cap:
lo = hi
r_lo = r_hi
hi = round_c(hi * 2)
r_hi = recall_eval.evaluate_recall_at_3(
queries, complexity=hi, recompute_embeddings=recompute
)
metrics.append({"complexity": hi, "recall_at_3": r_hi})
if r_hi < target:
print(f"⚠️ Max complexity {hi} did not reach target recall {target:.2f}.")
print("📈 Observations:")
for m in metrics:
print(f" C={m['complexity']:>4} -> Recall@3={m['recall_at_3'] * 100:.1f}%")
return {"metrics": metrics, "best_complexity": None, "target_recall": target}
# Binary search within [lo, hi]
best = hi
iters = 0
while lo < hi and iters < max_iters:
mid = round_c((lo + hi) // 2)
r_mid = recall_eval.evaluate_recall_at_3(
queries, complexity=mid, recompute_embeddings=recompute
)
metrics.append({"complexity": mid, "recall_at_3": r_mid})
if r_mid >= target:
best = mid
hi = mid
else:
lo = mid + 8 # move past mid, respecting multiple-of-8 step
iters += 1
print("📈 Binary search results (sampled points):")
# Print unique complexity entries ordered by complexity
for m in sorted(
{m["complexity"]: m for m in metrics}.values(), key=lambda x: x["complexity"]
):
print(f" C={m['complexity']:>4} -> Recall@3={m['recall_at_3'] * 100:.1f}%")
print(f"✅ Minimal complexity achieving {int(target * 100)}% recall: {best}")
return {"metrics": metrics, "best_complexity": best, "target_recall": target}
def main():
parser = argparse.ArgumentParser(description="Enron Emails Benchmark Evaluation")
parser.add_argument("--index", required=True, help="Path to LEANN index")
parser.add_argument(
"--queries", default="data/evaluation_queries.jsonl", help="Path to evaluation queries"
)
parser.add_argument(
"--stage",
choices=["2", "3", "4", "5", "all", "45"],
default="all",
help="Which stage to run (2=recall, 3=complexity, 4=index comparison, 5=generation)",
)
parser.add_argument("--complexity", type=int, default=None, help="LEANN search complexity")
parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
parser.add_argument(
"--max-queries", type=int, help="Limit number of queries to evaluate", default=1000
)
parser.add_argument(
"--target-recall", type=float, default=0.90, help="Target Recall@3 for Stage 3"
)
parser.add_argument("--output", help="Save results to JSON file")
parser.add_argument("--llm-backend", choices=["hf", "vllm"], default="hf", help="LLM backend")
parser.add_argument("--model-name", default="Qwen/Qwen3-8B", help="Model name")
args = parser.parse_args()
# Resolve queries file: if default path not found, fall back to index's directory
if not os.path.exists(args.queries):
from pathlib import Path
idx_dir = Path(args.index).parent
fallback_q = idx_dir / "evaluation_queries.jsonl"
if fallback_q.exists():
args.queries = str(fallback_q)
baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
if not os.path.exists(baseline_index_path):
print(f"❌ FAISS baseline not found at {baseline_index_path}")
print("💡 Please run setup_enron_emails.py first to build the baseline")
raise SystemExit(1)
results_out: dict = {}
if args.stage in ("2", "all"):
print("🚀 Starting Stage 2: Recall@3 evaluation")
evaluator = RecallEvaluator(args.index, args.baseline_dir)
enron_eval = EnronEvaluator(args.index)
queries = enron_eval.load_queries(args.queries)
queries = queries[:10]
print(f"🧪 Using first {len(queries)} queries")
complexity = args.complexity or 64
r = evaluator.evaluate_recall_at_3(queries, complexity)
results_out["stage2"] = {"complexity": complexity, "recall_at_3": r}
evaluator.cleanup()
enron_eval.cleanup()
print("✅ Stage 2 completed!\n")
if args.stage in ("3", "all"):
print("🚀 Starting Stage 3: Binary search for target recall (no recompute)")
enron_eval = EnronEvaluator(args.index)
queries = enron_eval.load_queries(args.queries)
queries = queries[: args.max_queries]
print(f"🧪 Using first {len(queries)} queries")
# Build non-compact index for fast binary search (recompute_embeddings=False)
from pathlib import Path
index_path = Path(args.index)
non_compact_index_path = str(index_path.parent / f"{index_path.stem}_noncompact.leann")
enron_eval.create_non_compact_index_for_comparison(non_compact_index_path)
# Use non-compact evaluator for binary search with recompute=False
evaluator_nc = RecallEvaluator(non_compact_index_path, args.baseline_dir)
sweep = enron_eval.evaluate_complexity(
evaluator_nc, queries, target=args.target_recall, recompute=False
)
results_out["stage3"] = sweep
# Persist default stage 3 results near the index for Stage 4 auto-pickup
from pathlib import Path
default_stage3_path = Path(args.index).parent / "enron_stage3_results.json"
with open(default_stage3_path, "w", encoding="utf-8") as f:
json.dump({"stage3": sweep}, f, indent=2)
print(f"📝 Saved Stage 3 summary to {default_stage3_path}")
evaluator_nc.cleanup()
enron_eval.cleanup()
print("✅ Stage 3 completed!\n")
if args.stage in ("4", "all", "45"):
print("🚀 Starting Stage 4: Index size + performance comparison")
evaluator = RecallEvaluator(args.index, args.baseline_dir)
enron_eval = EnronEvaluator(args.index)
queries = enron_eval.load_queries(args.queries)
test_q = queries[: min(args.max_queries, len(queries))]
current_sizes = enron_eval.analyze_index_sizes()
# Build non-compact index for comparison (no fallback)
from pathlib import Path
index_path = Path(args.index)
non_compact_path = str(index_path.parent / f"{index_path.stem}_noncompact.leann")
non_compact_sizes = enron_eval.create_non_compact_index_for_comparison(non_compact_path)
nc_eval = EnronEvaluator(non_compact_path)
if (
current_sizes.get("index_only_mb", 0) > 0
and non_compact_sizes.get("index_only_mb", 0) > 0
):
storage_saving_percent = max(
0.0,
100.0 * (1.0 - current_sizes["index_only_mb"] / non_compact_sizes["index_only_mb"]),
)
else:
storage_saving_percent = 0.0
if args.complexity is None:
# Prefer in-session Stage 3 result
if "stage3" in results_out and results_out["stage3"].get("best_complexity") is not None:
complexity = results_out["stage3"]["best_complexity"]
print(f"📥 Using best complexity from Stage 3 in-session: {complexity}")
else:
# Try to load last saved Stage 3 result near index
default_stage3_path = Path(args.index).parent / "enron_stage3_results.json"
if default_stage3_path.exists():
with open(default_stage3_path, encoding="utf-8") as f:
prev = json.load(f)
complexity = prev.get("stage3", {}).get("best_complexity")
if complexity is None:
raise SystemExit(
"❌ Stage 4: No --complexity and no best_complexity found in saved Stage 3 results"
)
print(f"📥 Using best complexity from saved Stage 3: {complexity}")
else:
raise SystemExit(
"❌ Stage 4 requires --complexity if Stage 3 hasn't been run. Run stage 3 first or pass --complexity."
)
else:
complexity = args.complexity
comp = enron_eval.compare_index_performance(
non_compact_path, args.index, test_q, complexity=complexity
)
results_out["stage4"] = {
"current_index": current_sizes,
"non_compact_index": non_compact_sizes,
"storage_saving_percent": storage_saving_percent,
"performance_comparison": comp,
}
nc_eval.cleanup()
evaluator.cleanup()
enron_eval.cleanup()
print("✅ Stage 4 completed!\n")
if args.stage in ("5", "all"):
print("🚀 Starting Stage 5: Generation evaluation with Qwen3-8B")
# Check if Stage 4 results exist
if "stage4" not in results_out or "performance_comparison" not in results_out["stage4"]:
print("❌ Stage 5 requires Stage 4 retrieval results")
print("💡 Run Stage 4 first or use --stage all")
raise SystemExit(1)
retrieval_results = results_out["stage4"]["performance_comparison"]["retrieval_results"]
if not retrieval_results:
print("❌ No retrieval results found from Stage 4")
raise SystemExit(1)
print(f"📁 Using {len(retrieval_results)} retrieval results from Stage 4")
# Load LLM
try:
if args.llm_backend == "hf":
tokenizer, model = load_hf_model(args.model_name)
def llm_func(prompt):
return generate_hf(tokenizer, model, prompt)
else: # vllm
llm, sampling_params = load_vllm_model(args.model_name)
def llm_func(prompt):
return generate_vllm(llm, sampling_params, prompt)
# Run generation using stored retrieval results
import time
from llm_utils import create_prompt
generation_times = []
responses = []
print("🤖 Running generation on pre-retrieved results...")
for i, item in enumerate(retrieval_results):
query = item["query"]
retrieved_docs = item["retrieved_docs"]
# Prepare context from retrieved docs
context = "\n\n".join([doc["text"] for doc in retrieved_docs])
prompt = create_prompt(context, query, "emails")
# Time generation only
gen_start = time.time()
response = llm_func(prompt)
gen_time = time.time() - gen_start
generation_times.append(gen_time)
responses.append(response)
if i < 3:
print(f" Q{i + 1}: Gen={gen_time:.3f}s")
avg_gen_time = sum(generation_times) / len(generation_times)
print("\n📊 Generation Results:")
print(f" Total Queries: {len(retrieval_results)}")
print(f" Avg Generation Time: {avg_gen_time:.3f}s")
print(" (Search time from Stage 4)")
results_out["stage5"] = {
"total_queries": len(retrieval_results),
"avg_generation_time": avg_gen_time,
"generation_times": generation_times,
"responses": responses,
}
# Show sample results
print("\n📝 Sample Results:")
for i in range(min(3, len(retrieval_results))):
query = retrieval_results[i]["query"]
response = responses[i]
print(f" Q{i + 1}: {query[:60]}...")
print(f" A{i + 1}: {response[:100]}...")
print()
except Exception as e:
print(f"❌ Generation evaluation failed: {e}")
print("💡 Make sure transformers/vllm is installed and model is available")
print("✅ Stage 5 completed!\n")
if args.output and results_out:
with open(args.output, "w", encoding="utf-8") as f:
json.dump(results_out, f, indent=2)
print(f"📝 Saved results to {args.output}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,359 @@
"""
Enron Emails Benchmark Setup Script
Prepares passages from emails.csv, builds LEANN index, and FAISS Flat baseline
"""
import argparse
import csv
import json
import os
import re
from collections.abc import Iterable
from email import message_from_string
from email.policy import default
from pathlib import Path
from typing import Optional
from leann import LeannBuilder
class EnronSetup:
def __init__(self, data_dir: str = "data"):
self.data_dir = Path(data_dir)
self.data_dir.mkdir(parents=True, exist_ok=True)
self.passages_preview = self.data_dir / "enron_passages_preview.jsonl"
self.index_path = self.data_dir / "enron_index_hnsw.leann"
self.queries_file = self.data_dir / "evaluation_queries.jsonl"
self.downloads_dir = self.data_dir / "downloads"
self.downloads_dir.mkdir(parents=True, exist_ok=True)
# ----------------------------
# Dataset acquisition
# ----------------------------
def ensure_emails_csv(self, emails_csv: Optional[str]) -> str:
"""Return a path to emails.csv, downloading from Kaggle if needed."""
if emails_csv:
p = Path(emails_csv)
if not p.exists():
raise FileNotFoundError(f"emails.csv not found: {emails_csv}")
return str(p)
print(
"📥 Trying to download Enron emails.csv from Kaggle (wcukierski/enron-email-dataset)..."
)
try:
from kaggle.api.kaggle_api_extended import KaggleApi
api = KaggleApi()
api.authenticate()
api.dataset_download_files(
"wcukierski/enron-email-dataset", path=str(self.downloads_dir), unzip=True
)
candidate = self.downloads_dir / "emails.csv"
if candidate.exists():
print(f"✅ Downloaded emails.csv: {candidate}")
return str(candidate)
else:
raise FileNotFoundError(
f"emails.csv was not found in {self.downloads_dir} after Kaggle download"
)
except Exception as e:
print(
"❌ Could not download via Kaggle automatically. Provide --emails-csv or configure Kaggle API."
)
print(
" Set KAGGLE_USERNAME and KAGGLE_KEY env vars, or place emails.csv locally and pass --emails-csv."
)
raise e
# ----------------------------
# Data preparation
# ----------------------------
@staticmethod
def _extract_message_id(raw_email: str) -> str:
msg = message_from_string(raw_email, policy=default)
val = msg.get("Message-ID", "")
if val.startswith("<") and val.endswith(">"):
val = val[1:-1]
return val or ""
@staticmethod
def _split_header_body(raw_email: str) -> tuple[str, str]:
parts = raw_email.split("\n\n", 1)
if len(parts) == 2:
return parts[0].strip(), parts[1].strip()
# Heuristic fallback
first_lines = raw_email.splitlines()
if first_lines and ":" in first_lines[0]:
return raw_email.strip(), ""
return "", raw_email.strip()
@staticmethod
def _split_fixed_words(text: str, chunk_words: int, keep_last: bool) -> list[str]:
text = (text or "").strip()
if not text:
return []
if chunk_words <= 0:
return [text]
words = text.split()
if not words:
return []
limit = len(words)
if not keep_last:
limit = (len(words) // chunk_words) * chunk_words
if limit == 0:
return []
chunks = [" ".join(words[i : i + chunk_words]) for i in range(0, limit, chunk_words)]
return [c for c in (s.strip() for s in chunks) if c]
def _iter_passages_from_csv(
self,
emails_csv: Path,
chunk_words: int = 256,
keep_last_header: bool = True,
keep_last_body: bool = True,
max_emails: int | None = None,
) -> Iterable[dict]:
with open(emails_csv, encoding="utf-8") as f:
reader = csv.DictReader(f)
count = 0
for i, row in enumerate(reader):
if max_emails is not None and count >= max_emails:
break
raw_message = row.get("message", "")
email_file_id = row.get("file", "")
if not raw_message.strip():
continue
message_id = self._extract_message_id(raw_message)
if not message_id:
# Fallback ID based on CSV position and file path
safe_file = re.sub(r"[^A-Za-z0-9_.-]", "_", email_file_id)
message_id = f"enron_{i}_{safe_file}"
header, body = self._split_header_body(raw_message)
# Header chunks
for chunk in self._split_fixed_words(header, chunk_words, keep_last_header):
yield {
"text": chunk,
"metadata": {
"message_id": message_id,
"is_header": True,
"email_file_id": email_file_id,
},
}
# Body chunks
for chunk in self._split_fixed_words(body, chunk_words, keep_last_body):
yield {
"text": chunk,
"metadata": {
"message_id": message_id,
"is_header": False,
"email_file_id": email_file_id,
},
}
count += 1
# ----------------------------
# Build LEANN index and FAISS baseline
# ----------------------------
def build_leann_index(
self,
emails_csv: Optional[str],
backend: str = "hnsw",
embedding_model: str = "sentence-transformers/all-mpnet-base-v2",
chunk_words: int = 256,
max_emails: int | None = None,
) -> str:
emails_csv_path = self.ensure_emails_csv(emails_csv)
print(f"🏗️ Building LEANN index from {emails_csv_path}...")
builder = LeannBuilder(
backend_name=backend,
embedding_model=embedding_model,
embedding_mode="sentence-transformers",
graph_degree=32,
complexity=64,
is_recompute=True,
is_compact=True,
num_threads=4,
)
# Stream passages and add to builder
preview_written = 0
with open(self.passages_preview, "w", encoding="utf-8") as preview_out:
for p in self._iter_passages_from_csv(
Path(emails_csv_path), chunk_words=chunk_words, max_emails=max_emails
):
builder.add_text(p["text"], metadata=p["metadata"])
if preview_written < 200:
preview_out.write(json.dumps({"text": p["text"][:200], **p["metadata"]}) + "\n")
preview_written += 1
print(f"🔨 Building index at {self.index_path}...")
builder.build_index(str(self.index_path))
print("✅ LEANN index built!")
return str(self.index_path)
def build_faiss_flat_baseline(self, index_path: str, output_dir: str = "baseline") -> str:
print("🔨 Building FAISS Flat baseline from LEANN passages...")
import pickle
import numpy as np
from leann.api import compute_embeddings
from leann_backend_hnsw import faiss
os.makedirs(output_dir, exist_ok=True)
baseline_path = os.path.join(output_dir, "faiss_flat.index")
metadata_path = os.path.join(output_dir, "metadata.pkl")
if os.path.exists(baseline_path) and os.path.exists(metadata_path):
print(f"✅ Baseline already exists at {baseline_path}")
return baseline_path
# Read meta for passage source and embedding model
meta_path = f"{index_path}.meta.json"
with open(meta_path, encoding="utf-8") as f:
meta = json.load(f)
embedding_model = meta["embedding_model"]
passage_source = meta["passage_sources"][0]
passage_file = passage_source["path"]
if not os.path.isabs(passage_file):
index_dir = os.path.dirname(index_path)
passage_file = os.path.join(index_dir, os.path.basename(passage_file))
# Load passages from builder output so IDs match LEANN
passages: list[str] = []
passage_ids: list[str] = []
with open(passage_file, encoding="utf-8") as f:
for line in f:
if not line.strip():
continue
data = json.loads(line)
passages.append(data["text"])
passage_ids.append(data["id"]) # builder-assigned ID
print(f"📄 Loaded {len(passages)} passages for baseline")
print(f"🤖 Embedding model: {embedding_model}")
embeddings = compute_embeddings(
passages,
embedding_model,
mode="sentence-transformers",
use_server=False,
)
# Build FAISS IndexFlatIP
dim = embeddings.shape[1]
index = faiss.IndexFlatIP(dim)
emb_f32 = embeddings.astype(np.float32)
index.add(emb_f32.shape[0], faiss.swig_ptr(emb_f32))
faiss.write_index(index, baseline_path)
with open(metadata_path, "wb") as pf:
pickle.dump(passage_ids, pf)
print(f"✅ FAISS baseline saved: {baseline_path}")
print(f"✅ Metadata saved: {metadata_path}")
print(f"📊 Total vectors: {index.ntotal}")
return baseline_path
# ----------------------------
# Queries (optional): prepare evaluation queries file
# ----------------------------
def prepare_queries(self, min_realism: float = 0.85) -> Path:
print(
"📝 Preparing evaluation queries from HuggingFace dataset corbt/enron_emails_sample_questions ..."
)
try:
from datasets import load_dataset
ds = load_dataset("corbt/enron_emails_sample_questions", split="train")
except Exception as e:
print(f"⚠️ Failed to load dataset: {e}")
return self.queries_file
kept = 0
with open(self.queries_file, "w", encoding="utf-8") as out:
for i, item in enumerate(ds):
how_realistic = float(item.get("how_realistic", 0.0))
if how_realistic < min_realism:
continue
qid = str(item.get("id", f"enron_q_{i}"))
query = item.get("question", "")
if not query:
continue
record = {
"id": qid,
"query": query,
# For reference only, not used in recall metric below
"gt_message_ids": item.get("message_ids", []),
}
out.write(json.dumps(record) + "\n")
kept += 1
print(f"✅ Wrote {kept} queries to {self.queries_file}")
return self.queries_file
def main():
parser = argparse.ArgumentParser(description="Setup Enron Emails Benchmark")
parser.add_argument(
"--emails-csv",
help="Path to emails.csv (Enron dataset). If omitted, attempt Kaggle download.",
)
parser.add_argument("--data-dir", default="data", help="Data directory")
parser.add_argument("--backend", choices=["hnsw", "diskann"], default="hnsw")
parser.add_argument(
"--embedding-model",
default="sentence-transformers/all-mpnet-base-v2",
help="Embedding model for LEANN",
)
parser.add_argument("--chunk-words", type=int, default=256, help="Fixed word chunk size")
parser.add_argument("--max-emails", type=int, help="Limit number of emails to process")
parser.add_argument("--skip-queries", action="store_true", help="Skip creating queries file")
parser.add_argument("--skip-build", action="store_true", help="Skip building LEANN index")
args = parser.parse_args()
setup = EnronSetup(args.data_dir)
# Build index
if not args.skip_build:
index_path = setup.build_leann_index(
emails_csv=args.emails_csv,
backend=args.backend,
embedding_model=args.embedding_model,
chunk_words=args.chunk_words,
max_emails=args.max_emails,
)
# Build FAISS baseline from the same passages & embeddings
setup.build_faiss_flat_baseline(index_path)
else:
print("⏭️ Skipping LEANN index build and baseline")
# Queries file (optional)
if not args.skip_queries:
setup.prepare_queries()
else:
print("⏭️ Skipping query preparation")
print("\n🎉 Enron Emails setup completed!")
print(f"📁 Data directory: {setup.data_dir.absolute()}")
print("Next steps:")
print(
"1) Evaluate recall: python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 2"
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,115 @@
# FinanceBench Benchmark for LEANN-RAG
FinanceBench is a benchmark for evaluating retrieval-augmented generation (RAG) systems on financial document question-answering tasks.
## Dataset
- **Source**: [PatronusAI/financebench](https://huggingface.co/datasets/PatronusAI/financebench)
- **Questions**: 150 financial Q&A examples
- **Documents**: 368 PDF files (10-K, 10-Q, 8-K, earnings reports)
- **Companies**: Major public companies (3M, Apple, Microsoft, Amazon, etc.)
- **Paper**: [FinanceBench: A New Benchmark for Financial Question Answering](https://arxiv.org/abs/2311.11944)
## Structure
```
benchmarks/financebench/
├── setup_financebench.py # Downloads PDFs and builds index
├── evaluate_financebench.py # Intelligent evaluation script
├── data/
│ ├── financebench_merged.jsonl # Q&A dataset
│ ├── pdfs/ # Downloaded financial documents
│ └── index/ # LEANN indexes
│ └── financebench_full_hnsw.leann
└── README.md
```
## Usage
### 1. Setup (Download & Build Index)
```bash
cd benchmarks/financebench
python setup_financebench.py
```
This will:
- Download the 150 Q&A examples
- Download all 368 PDF documents (parallel processing)
- Build a LEANN index from 53K+ text chunks
- Verify setup with test query
### 2. Evaluation
```bash
# Basic retrieval evaluation
python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann
# RAG generation evaluation with Qwen3-8B
python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann --stage 4 --complexity 64 --llm-backend hf --model-name Qwen/Qwen3-8B --output results_qwen3.json
```
## Evaluation Methods
### Retrieval Evaluation
Uses intelligent matching with three strategies:
1. **Exact text overlap** - Direct substring matches
2. **Number matching** - Key financial figures ($1,577, 1.2B, etc.)
3. **Semantic similarity** - Word overlap with 20% threshold
### QA Evaluation
LLM-based answer evaluation using GPT-4o:
- Handles numerical rounding and equivalent representations
- Considers fractions, percentages, and decimal equivalents
- Evaluates semantic meaning rather than exact text match
## Benchmark Results
### LEANN-RAG Performance (sentence-transformers/all-mpnet-base-v2)
**Retrieval Metrics:**
- **Question Coverage**: 100.0% (all questions retrieve relevant docs)
- **Exact Match Rate**: 0.7% (substring overlap with evidence)
- **Number Match Rate**: 120.7% (key financial figures matched)*
- **Semantic Match Rate**: 4.7% (word overlap ≥20%)
- **Average Search Time**: 0.097s
**QA Metrics:**
- **Accuracy**: 42.7% (LLM-evaluated answer correctness)
- **Average QA Time**: 4.71s (end-to-end response time)
**System Performance:**
- **Index Size**: 53,985 chunks from 368 PDFs
- **Build Time**: ~5-10 minutes with sentence-transformers/all-mpnet-base-v2
*Note: Number match rate >100% indicates multiple retrieved documents contain the same financial figures, which is expected behavior for financial data appearing across multiple document sections.
### LEANN-RAG Generation Performance (Qwen3-8B)
- **Stage 4 (Index Comparison):**
- Compact Index: 5.0 MB
- Non-compact Index: 172.2 MB
- **Storage Saving**: 97.1%
- **Search Performance**:
- Non-compact (no recompute): 0.009s avg per query
- Compact (with recompute): 2.203s avg per query
- Speed ratio: 0.004x
**Generation Evaluation (20 queries, complexity=64):**
- **Average Search Time**: 1.638s per query
- **Average Generation Time**: 45.957s per query
- **LLM Backend**: HuggingFace transformers
- **Model**: Qwen/Qwen3-8B (thinking model with <think></think> processing)
- **Total Questions Processed**: 20
## Options
```bash
# Use different backends
python setup_financebench.py --backend diskann
python evaluate_financebench.py --index data/index/financebench_full_diskann.leann
# Use different embedding models
python setup_financebench.py --embedding-model facebook/contriever
```

View File

@@ -0,0 +1,923 @@
"""
FinanceBench Evaluation Script - Modular Recall-based Evaluation
"""
import argparse
import json
import logging
import os
import pickle
import time
from pathlib import Path
from typing import Optional
import numpy as np
import openai
from leann import LeannChat, LeannSearcher
from leann_backend_hnsw import faiss
from ..llm_utils import evaluate_rag, generate_hf, generate_vllm, load_hf_model, load_vllm_model
# Setup logging to reduce verbose output
logging.basicConfig(level=logging.WARNING)
logging.getLogger("leann.api").setLevel(logging.WARNING)
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)
class RecallEvaluator:
"""Stage 2: Evaluate Recall@3 (searcher vs baseline)"""
def __init__(self, index_path: str, baseline_dir: str):
self.index_path = index_path
self.baseline_dir = baseline_dir
self.searcher = LeannSearcher(index_path)
# Load FAISS flat baseline
baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
metadata_path = os.path.join(baseline_dir, "metadata.pkl")
self.faiss_index = faiss.read_index(baseline_index_path)
with open(metadata_path, "rb") as f:
self.passage_ids = pickle.load(f)
print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} vectors")
def evaluate_recall_at_3(
self, queries: list[str], complexity: int = 64, recompute_embeddings: bool = True
) -> float:
"""Evaluate recall@3 for given queries at specified complexity"""
recompute_str = "with recompute" if recompute_embeddings else "no recompute"
print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")
total_recall = 0.0
num_queries = len(queries)
for i, query in enumerate(queries):
# Get ground truth: search with FAISS flat
from leann.api import compute_embeddings
query_embedding = compute_embeddings(
[query],
self.searcher.embedding_model,
mode=self.searcher.embedding_mode,
use_server=False,
).astype(np.float32)
# Search FAISS flat for ground truth using LEANN's modified faiss API
n = query_embedding.shape[0] # Number of queries
k = 3 # Number of nearest neighbors
distances = np.zeros((n, k), dtype=np.float32)
labels = np.zeros((n, k), dtype=np.int64)
self.faiss_index.search(
n,
faiss.swig_ptr(query_embedding),
k,
faiss.swig_ptr(distances),
faiss.swig_ptr(labels),
)
# Extract the results
baseline_ids = {self.passage_ids[idx] for idx in labels[0]}
# Search with LEANN at specified complexity
test_results = self.searcher.search(
query,
top_k=3,
complexity=complexity,
recompute_embeddings=recompute_embeddings,
)
test_ids = {result.id for result in test_results}
# Calculate recall@3 = |intersection| / |ground_truth|
intersection = test_ids.intersection(baseline_ids)
recall = len(intersection) / 3.0 # Ground truth size is 3
total_recall += recall
if i < 3: # Show first few examples
print(f" Query {i + 1}: '{query[:50]}...' -> Recall@3: {recall:.3f}")
print(f" FAISS ground truth: {list(baseline_ids)}")
print(f" LEANN results (C={complexity}, {recompute_str}): {list(test_ids)}")
print(f" Intersection: {list(intersection)}")
avg_recall = total_recall / num_queries
print(f"📊 Average Recall@3: {avg_recall:.3f} ({avg_recall * 100:.1f}%)")
return avg_recall
def cleanup(self):
"""Cleanup resources"""
if hasattr(self, "searcher"):
self.searcher.cleanup()
class FinanceBenchEvaluator:
def __init__(self, index_path: str, openai_api_key: Optional[str] = None):
self.index_path = index_path
self.openai_client = openai.OpenAI(api_key=openai_api_key) if openai_api_key else None
self.searcher = LeannSearcher(index_path)
self.chat = LeannChat(index_path) if openai_api_key else None
def load_dataset(self, dataset_path: str = "data/financebench_merged.jsonl"):
"""Load FinanceBench dataset"""
data = []
with open(dataset_path, encoding="utf-8") as f:
for line in f:
if line.strip():
data.append(json.loads(line))
print(f"📊 Loaded {len(data)} FinanceBench examples")
return data
def analyze_index_sizes(self) -> dict:
"""Analyze index sizes with and without embeddings"""
print("📏 Analyzing index sizes...")
# Get all index-related files
index_path = Path(self.index_path)
index_dir = index_path.parent
index_name = index_path.stem # Remove .leann extension
sizes = {}
total_with_embeddings = 0
# Core index files
index_file = index_dir / f"{index_name}.index"
meta_file = index_dir / f"{index_path.name}.meta.json" # Keep .leann for meta file
passages_file = index_dir / f"{index_path.name}.passages.jsonl" # Keep .leann for passages
passages_idx_file = index_dir / f"{index_path.name}.passages.idx" # Keep .leann for idx
for file_path, name in [
(index_file, "index"),
(meta_file, "metadata"),
(passages_file, "passages_text"),
(passages_idx_file, "passages_index"),
]:
if file_path.exists():
size_mb = file_path.stat().st_size / (1024 * 1024)
sizes[name] = size_mb
total_with_embeddings += size_mb
else:
sizes[name] = 0
sizes["total_with_embeddings"] = total_with_embeddings
sizes["index_only_mb"] = sizes["index"] # Just the .index file for fair comparison
print(f" 📁 Total index size: {total_with_embeddings:.1f} MB")
print(f" 📁 Index file only: {sizes['index']:.1f} MB")
return sizes
def create_compact_index_for_comparison(self, compact_index_path: str) -> dict:
"""Create a compact index for comparison purposes"""
print("🏗️ Building compact index from existing passages...")
# Load existing passages from current index
from leann import LeannBuilder
current_index_path = Path(self.index_path)
current_index_dir = current_index_path.parent
current_index_name = current_index_path.name
# Read metadata to get passage source
meta_path = current_index_dir / f"{current_index_name}.meta.json"
with open(meta_path) as f:
import json
meta = json.load(f)
passage_source = meta["passage_sources"][0]
passage_file = passage_source["path"]
# Convert relative path to absolute
if not Path(passage_file).is_absolute():
passage_file = current_index_dir / Path(passage_file).name
print(f"📄 Loading passages from {passage_file}...")
# Build compact index with same passages
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=meta["embedding_model"],
embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
is_recompute=True, # Enable recompute (no stored embeddings)
is_compact=True, # Enable compact storage
**meta.get("backend_kwargs", {}),
)
# Load all passages
with open(passage_file, encoding="utf-8") as f:
for line in f:
if line.strip():
data = json.loads(line)
builder.add_text(data["text"], metadata=data.get("metadata", {}))
print(f"🔨 Building compact index at {compact_index_path}...")
builder.build_index(compact_index_path)
# Analyze the compact index size
temp_evaluator = FinanceBenchEvaluator(compact_index_path)
compact_sizes = temp_evaluator.analyze_index_sizes()
compact_sizes["index_type"] = "compact"
return compact_sizes
def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
"""Create a non-compact index for comparison purposes"""
print("🏗️ Building non-compact index from existing passages...")
# Load existing passages from current index
from leann import LeannBuilder
current_index_path = Path(self.index_path)
current_index_dir = current_index_path.parent
current_index_name = current_index_path.name
# Read metadata to get passage source
meta_path = current_index_dir / f"{current_index_name}.meta.json"
with open(meta_path) as f:
import json
meta = json.load(f)
passage_source = meta["passage_sources"][0]
passage_file = passage_source["path"]
# Convert relative path to absolute
if not Path(passage_file).is_absolute():
passage_file = current_index_dir / Path(passage_file).name
print(f"📄 Loading passages from {passage_file}...")
# Build non-compact index with same passages
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=meta["embedding_model"],
embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
is_recompute=False, # Disable recompute (store embeddings)
is_compact=False, # Disable compact storage
**{
k: v
for k, v in meta.get("backend_kwargs", {}).items()
if k not in ["is_recompute", "is_compact"]
},
)
# Load all passages
with open(passage_file, encoding="utf-8") as f:
for line in f:
if line.strip():
data = json.loads(line)
builder.add_text(data["text"], metadata=data.get("metadata", {}))
print(f"🔨 Building non-compact index at {non_compact_index_path}...")
builder.build_index(non_compact_index_path)
# Analyze the non-compact index size
temp_evaluator = FinanceBenchEvaluator(non_compact_index_path)
non_compact_sizes = temp_evaluator.analyze_index_sizes()
non_compact_sizes["index_type"] = "non_compact"
return non_compact_sizes
def compare_index_performance(
self, non_compact_path: str, compact_path: str, test_data: list, complexity: int
) -> dict:
"""Compare performance between non-compact and compact indexes"""
print("⚡ Comparing search performance between indexes...")
import time
from leann import LeannSearcher
# Test queries
test_queries = [item["question"] for item in test_data[:5]]
results = {
"non_compact": {"search_times": []},
"compact": {"search_times": []},
"avg_search_times": {},
"speed_ratio": 0.0,
}
# Test non-compact index (no recompute)
print(" 🔍 Testing non-compact index (no recompute)...")
non_compact_searcher = LeannSearcher(non_compact_path)
for query in test_queries:
start_time = time.time()
_ = non_compact_searcher.search(
query, top_k=3, complexity=complexity, recompute_embeddings=False
)
search_time = time.time() - start_time
results["non_compact"]["search_times"].append(search_time)
# Test compact index (with recompute)
print(" 🔍 Testing compact index (with recompute)...")
compact_searcher = LeannSearcher(compact_path)
for query in test_queries:
start_time = time.time()
_ = compact_searcher.search(
query, top_k=3, complexity=complexity, recompute_embeddings=True
)
search_time = time.time() - start_time
results["compact"]["search_times"].append(search_time)
# Calculate averages
results["avg_search_times"]["non_compact"] = sum(
results["non_compact"]["search_times"]
) / len(results["non_compact"]["search_times"])
results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
results["compact"]["search_times"]
)
# Performance ratio
if results["avg_search_times"]["compact"] > 0:
results["speed_ratio"] = (
results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
)
else:
results["speed_ratio"] = float("inf")
print(
f" Non-compact (no recompute): {results['avg_search_times']['non_compact']:.3f}s avg"
)
print(f" Compact (with recompute): {results['avg_search_times']['compact']:.3f}s avg")
print(f" Speed ratio: {results['speed_ratio']:.2f}x")
# Cleanup
non_compact_searcher.cleanup()
compact_searcher.cleanup()
return results
def evaluate_timing_breakdown(
self, data: list[dict], max_samples: Optional[int] = None
) -> dict:
"""Evaluate timing breakdown and accuracy by hacking LeannChat.ask() for separated timing"""
if not self.chat or not self.openai_client:
print("⚠️ Skipping timing evaluation (no OpenAI API key provided)")
return {
"total_questions": 0,
"avg_search_time": 0.0,
"avg_generation_time": 0.0,
"avg_total_time": 0.0,
"accuracy": 0.0,
}
print("🔍🤖 Evaluating timing breakdown and accuracy (search + generation)...")
if max_samples:
data = data[:max_samples]
print(f"📝 Using first {max_samples} samples for timing evaluation")
search_times = []
generation_times = []
total_times = []
correct_answers = 0
for i, item in enumerate(data):
question = item["question"]
ground_truth = item["answer"]
try:
# Hack: Monkey-patch the ask method to capture internal timing
original_ask = self.chat.ask
captured_search_time = None
captured_generation_time = None
def patched_ask(*args, **kwargs):
nonlocal captured_search_time, captured_generation_time
# Time the search part
search_start = time.time()
results = self.chat.searcher.search(args[0], top_k=3, complexity=64)
captured_search_time = time.time() - search_start
# Time the generation part
context = "\n\n".join([r.text for r in results])
prompt = (
"Here is some retrieved context that might help answer your question:\n\n"
f"{context}\n\n"
f"Question: {args[0]}\n\n"
"Please provide the best answer you can based on this context and your knowledge."
)
generation_start = time.time()
answer = self.chat.llm.ask(prompt)
captured_generation_time = time.time() - generation_start
return answer
# Apply the patch
self.chat.ask = patched_ask
# Time the total QA
total_start = time.time()
generated_answer = self.chat.ask(question)
total_time = time.time() - total_start
# Restore original method
self.chat.ask = original_ask
# Store the timings
search_times.append(captured_search_time)
generation_times.append(captured_generation_time)
total_times.append(total_time)
# Check accuracy using LLM as judge
is_correct = self._check_answer_accuracy(generated_answer, ground_truth, question)
if is_correct:
correct_answers += 1
status = "" if is_correct else ""
print(
f"Question {i + 1}/{len(data)}: {status} Search={captured_search_time:.3f}s, Gen={captured_generation_time:.3f}s, Total={total_time:.3f}s"
)
print(f" GT: {ground_truth}")
print(f" Gen: {generated_answer[:100]}...")
except Exception as e:
print(f" ❌ Error: {e}")
search_times.append(0.0)
generation_times.append(0.0)
total_times.append(0.0)
accuracy = correct_answers / len(data) if data else 0.0
metrics = {
"total_questions": len(data),
"avg_search_time": sum(search_times) / len(search_times) if search_times else 0.0,
"avg_generation_time": sum(generation_times) / len(generation_times)
if generation_times
else 0.0,
"avg_total_time": sum(total_times) / len(total_times) if total_times else 0.0,
"accuracy": accuracy,
"correct_answers": correct_answers,
"search_times": search_times,
"generation_times": generation_times,
"total_times": total_times,
}
return metrics
def _check_answer_accuracy(
self, generated_answer: str, ground_truth: str, question: str
) -> bool:
"""Check if generated answer matches ground truth using LLM as judge"""
judge_prompt = f"""You are an expert judge evaluating financial question answering.
Question: {question}
Ground Truth Answer: {ground_truth}
Generated Answer: {generated_answer}
Task: Determine if the generated answer is factually correct compared to the ground truth. Focus on:
1. Numerical accuracy (exact values, units, currency)
2. Key financial concepts and terminology
3. Overall factual correctness
For financial data, small formatting differences are OK (e.g., "$1,577" vs "1577 million" vs "$1.577 billion"), but the core numerical value must match.
Respond with exactly one word: "CORRECT" if the generated answer is factually accurate, or "INCORRECT" if it's wrong or significantly different."""
try:
judge_response = self.openai_client.chat.completions.create(
model="gpt-4o-mini",
messages=[{"role": "user", "content": judge_prompt}],
max_tokens=10,
temperature=0,
)
judgment = judge_response.choices[0].message.content.strip().upper()
return judgment == "CORRECT"
except Exception as e:
print(f" ⚠️ Judge error: {e}, falling back to string matching")
# Fallback to simple string matching
gen_clean = generated_answer.strip().lower().replace("$", "").replace(",", "")
gt_clean = ground_truth.strip().lower().replace("$", "").replace(",", "")
return gt_clean in gen_clean
def _print_results(self, timing_metrics: dict):
"""Print evaluation results"""
print("\n🎯 EVALUATION RESULTS")
print("=" * 50)
# Index comparison analysis
if "current_index" in timing_metrics and "non_compact_index" in timing_metrics:
print("\n📏 Index Comparison Analysis:")
current = timing_metrics["current_index"]
non_compact = timing_metrics["non_compact_index"]
print(f" Compact index (current): {current.get('total_with_embeddings', 0):.1f} MB")
print(
f" Non-compact index (with embeddings): {non_compact.get('total_with_embeddings', 0):.1f} MB"
)
print(
f" Storage saving by compact: {timing_metrics.get('storage_saving_percent', 0):.1f}%"
)
print(" Component breakdown (non-compact):")
print(f" - Main index: {non_compact.get('index', 0):.1f} MB")
print(f" - Passages text: {non_compact.get('passages_text', 0):.1f} MB")
print(f" - Passages index: {non_compact.get('passages_index', 0):.1f} MB")
print(f" - Metadata: {non_compact.get('metadata', 0):.1f} MB")
# Performance comparison
if "performance_comparison" in timing_metrics:
perf = timing_metrics["performance_comparison"]
print("\n⚡ Performance Comparison:")
print(
f" Non-compact (no recompute): {perf.get('avg_search_times', {}).get('non_compact', 0):.3f}s avg"
)
print(
f" Compact (with recompute): {perf.get('avg_search_times', {}).get('compact', 0):.3f}s avg"
)
print(f" Speed ratio: {perf.get('speed_ratio', 0):.2f}x")
# Legacy single index analysis (fallback)
if "total_with_embeddings" in timing_metrics and "current_index" not in timing_metrics:
print("\n📏 Index Size Analysis:")
print(f" Total index size: {timing_metrics.get('total_with_embeddings', 0):.1f} MB")
print("\n📊 Accuracy:")
print(f" Accuracy: {timing_metrics.get('accuracy', 0) * 100:.1f}%")
print(
f" Correct Answers: {timing_metrics.get('correct_answers', 0)}/{timing_metrics.get('total_questions', 0)}"
)
print("\n📊 Timing Breakdown:")
print(f" Total Questions: {timing_metrics.get('total_questions', 0)}")
print(f" Avg Search Time: {timing_metrics.get('avg_search_time', 0):.3f}s")
print(f" Avg Generation Time: {timing_metrics.get('avg_generation_time', 0):.3f}s")
print(f" Avg Total Time: {timing_metrics.get('avg_total_time', 0):.3f}s")
if timing_metrics.get("avg_total_time", 0) > 0:
search_pct = (
timing_metrics.get("avg_search_time", 0)
/ timing_metrics.get("avg_total_time", 1)
* 100
)
gen_pct = (
timing_metrics.get("avg_generation_time", 0)
/ timing_metrics.get("avg_total_time", 1)
* 100
)
print("\n📈 Time Distribution:")
print(f" Search: {search_pct:.1f}%")
print(f" Generation: {gen_pct:.1f}%")
def cleanup(self):
"""Cleanup resources"""
if self.searcher:
self.searcher.cleanup()
def main():
parser = argparse.ArgumentParser(description="Modular FinanceBench Evaluation")
parser.add_argument("--index", required=True, help="Path to LEANN index")
parser.add_argument("--dataset", default="data/financebench_merged.jsonl", help="Dataset path")
parser.add_argument(
"--stage",
choices=["2", "3", "4", "all"],
default="all",
help="Which stage to run (2=recall, 3=complexity, 4=generation)",
)
parser.add_argument("--complexity", type=int, default=None, help="Complexity for search")
parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
parser.add_argument("--openai-api-key", help="OpenAI API key for generation evaluation")
parser.add_argument("--output", help="Save results to JSON file")
parser.add_argument(
"--llm-backend", choices=["openai", "hf", "vllm"], default="openai", help="LLM backend"
)
parser.add_argument("--model-name", default="Qwen3-8B", help="Model name for HF/vLLM")
args = parser.parse_args()
try:
# Check if baseline exists
baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
if not os.path.exists(baseline_index_path):
print(f"❌ FAISS baseline not found at {baseline_index_path}")
print("💡 Please run setup_financebench.py first to build the baseline")
exit(1)
if args.stage == "2" or args.stage == "all":
# Stage 2: Recall@3 evaluation
print("🚀 Starting Stage 2: Recall@3 evaluation")
evaluator = RecallEvaluator(args.index, args.baseline_dir)
# Load FinanceBench queries for testing
print("📖 Loading FinanceBench dataset...")
queries = []
with open(args.dataset, encoding="utf-8") as f:
for line in f:
if line.strip():
data = json.loads(line)
queries.append(data["question"])
# Test with more queries for robust measurement
test_queries = queries[:2000]
print(f"🧪 Testing with {len(test_queries)} queries")
# Test with complexity 64
complexity = 64
recall = evaluator.evaluate_recall_at_3(test_queries, complexity)
print(f"📈 Recall@3 at complexity {complexity}: {recall * 100:.1f}%")
evaluator.cleanup()
print("✅ Stage 2 completed!\n")
# Shared non-compact index path for Stage 3 and 4
non_compact_index_path = args.index.replace(".leann", "_noncompact.leann")
complexity = args.complexity
if args.stage == "3" or args.stage == "all":
# Stage 3: Binary search for 90% recall complexity (using non-compact index for speed)
print("🚀 Starting Stage 3: Binary search for 90% recall complexity")
print(
"💡 Creating non-compact index for fast binary search with recompute_embeddings=False"
)
# Create non-compact index for binary search (will be reused in Stage 4)
print("🏗️ Creating non-compact index for binary search...")
evaluator = FinanceBenchEvaluator(args.index)
evaluator.create_non_compact_index_for_comparison(non_compact_index_path)
# Use non-compact index for binary search
binary_search_evaluator = RecallEvaluator(non_compact_index_path, args.baseline_dir)
# Load queries for testing
print("📖 Loading FinanceBench dataset...")
queries = []
with open(args.dataset, encoding="utf-8") as f:
for line in f:
if line.strip():
data = json.loads(line)
queries.append(data["question"])
# Use more queries for robust measurement
test_queries = queries[:200]
print(f"🧪 Testing with {len(test_queries)} queries")
# Binary search for 90% recall complexity (without recompute for speed)
target_recall = 0.9
min_complexity, max_complexity = 1, 32
print(f"🔍 Binary search for {target_recall * 100}% recall complexity...")
print(f"Search range: {min_complexity} to {max_complexity}")
best_complexity = None
best_recall = 0.0
while min_complexity <= max_complexity:
mid_complexity = (min_complexity + max_complexity) // 2
print(
f"\n🧪 Testing complexity {mid_complexity} (no recompute, non-compact index)..."
)
# Use recompute_embeddings=False on non-compact index for fast binary search
recall = binary_search_evaluator.evaluate_recall_at_3(
test_queries, mid_complexity, recompute_embeddings=False
)
print(
f" Complexity {mid_complexity}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%)"
)
if recall >= target_recall:
best_complexity = mid_complexity
best_recall = recall
max_complexity = mid_complexity - 1
print(" ✅ Target reached! Searching for lower complexity...")
else:
min_complexity = mid_complexity + 1
print(" ❌ Below target. Searching for higher complexity...")
if best_complexity is not None:
print("\n🎯 Optimal complexity found!")
print(f" Complexity: {best_complexity}")
print(f" Recall@3: {best_recall:.3f} ({best_recall * 100:.1f}%)")
# Test a few complexities around the optimal one for verification
print("\n🔬 Verification test around optimal complexity:")
verification_complexities = [
max(1, best_complexity - 2),
max(1, best_complexity - 1),
best_complexity,
best_complexity + 1,
best_complexity + 2,
]
for complexity in verification_complexities:
if complexity <= 512: # reasonable upper bound
recall = binary_search_evaluator.evaluate_recall_at_3(
test_queries, complexity, recompute_embeddings=False
)
status = "" if recall >= target_recall else ""
print(f" {status} Complexity {complexity:3d}: {recall * 100:5.1f}%")
# Now test the optimal complexity with compact index and recompute for comparison
print(
f"\n🔄 Testing optimal complexity {best_complexity} on compact index WITH recompute..."
)
compact_evaluator = RecallEvaluator(args.index, args.baseline_dir)
recall_with_recompute = compact_evaluator.evaluate_recall_at_3(
test_queries[:10], best_complexity, recompute_embeddings=True
)
print(
f" ✅ Complexity {best_complexity} (compact index with recompute): {recall_with_recompute * 100:.1f}%"
)
complexity = best_complexity
print(
f" 📊 Recall difference: {abs(best_recall - recall_with_recompute) * 100:.2f}%"
)
compact_evaluator.cleanup()
else:
print(f"\n❌ Could not find complexity achieving {target_recall * 100}% recall")
print("All tested complexities were below target.")
# Cleanup evaluators (keep non-compact index for Stage 4)
binary_search_evaluator.cleanup()
evaluator.cleanup()
print("✅ Stage 3 completed! Non-compact index saved for Stage 4.\n")
if args.stage == "4" or args.stage == "all":
# Stage 4: Comprehensive evaluation with dual index comparison
print("🚀 Starting Stage 4: Comprehensive evaluation with dual index comparison")
# Use FinanceBench evaluator for QA evaluation
evaluator = FinanceBenchEvaluator(
args.index, args.openai_api_key if args.llm_backend == "openai" else None
)
print("📖 Loading FinanceBench dataset...")
data = evaluator.load_dataset(args.dataset)
# Step 1: Analyze current (compact) index
print("\n📏 Analyzing current index (compact, pruned)...")
compact_size_metrics = evaluator.analyze_index_sizes()
compact_size_metrics["index_type"] = "compact"
# Step 2: Use existing non-compact index or create if needed
from pathlib import Path
if Path(non_compact_index_path).exists():
print(
f"\n📁 Using existing non-compact index from Stage 3: {non_compact_index_path}"
)
temp_evaluator = FinanceBenchEvaluator(non_compact_index_path)
non_compact_size_metrics = temp_evaluator.analyze_index_sizes()
non_compact_size_metrics["index_type"] = "non_compact"
else:
print("\n🏗️ Creating non-compact index (with embeddings) for comparison...")
non_compact_size_metrics = evaluator.create_non_compact_index_for_comparison(
non_compact_index_path
)
# Step 3: Compare index sizes
print("\n📊 Index size comparison:")
print(
f" Compact index (current): {compact_size_metrics['total_with_embeddings']:.1f} MB"
)
print(
f" Non-compact index: {non_compact_size_metrics['total_with_embeddings']:.1f} MB"
)
print("\n📊 Index-only size comparison (.index file only):")
print(f" Compact index: {compact_size_metrics['index_only_mb']:.1f} MB")
print(f" Non-compact index: {non_compact_size_metrics['index_only_mb']:.1f} MB")
# Use index-only size for fair comparison (same as Enron emails)
storage_saving = (
(non_compact_size_metrics["index_only_mb"] - compact_size_metrics["index_only_mb"])
/ non_compact_size_metrics["index_only_mb"]
* 100
)
print(f" Storage saving by compact: {storage_saving:.1f}%")
# Step 4: Performance comparison between the two indexes
if complexity is None:
raise ValueError("Complexity is required for performance comparison")
print("\n⚡ Performance comparison between indexes...")
performance_metrics = evaluator.compare_index_performance(
non_compact_index_path, args.index, data[:10], complexity=complexity
)
# Step 5: Generation evaluation
test_samples = 20
print(f"\n🧪 Testing with first {test_samples} samples for generation analysis")
if args.llm_backend == "openai" and args.openai_api_key:
print("🔍🤖 Running OpenAI-based generation evaluation...")
evaluation_start = time.time()
timing_metrics = evaluator.evaluate_timing_breakdown(data[:test_samples])
evaluation_time = time.time() - evaluation_start
else:
print(
f"🔍🤖 Running {args.llm_backend} generation evaluation with {args.model_name}..."
)
try:
# Load LLM
if args.llm_backend == "hf":
tokenizer, model = load_hf_model(args.model_name)
def llm_func(prompt):
return generate_hf(tokenizer, model, prompt)
else: # vllm
llm, sampling_params = load_vllm_model(args.model_name)
def llm_func(prompt):
return generate_vllm(llm, sampling_params, prompt)
# Simple generation evaluation
queries = [item["question"] for item in data[:test_samples]]
gen_results = evaluate_rag(
evaluator.searcher,
llm_func,
queries,
domain="finance",
complexity=complexity,
)
timing_metrics = {
"total_questions": len(queries),
"avg_search_time": gen_results["avg_search_time"],
"avg_generation_time": gen_results["avg_generation_time"],
"results": gen_results["results"],
}
evaluation_time = time.time()
except Exception as e:
print(f"❌ Generation evaluation failed: {e}")
timing_metrics = {
"total_questions": 0,
"avg_search_time": 0,
"avg_generation_time": 0,
}
evaluation_time = 0
# Combine all metrics
combined_metrics = {
**timing_metrics,
"total_evaluation_time": evaluation_time,
"current_index": compact_size_metrics,
"non_compact_index": non_compact_size_metrics,
"performance_comparison": performance_metrics,
"storage_saving_percent": storage_saving,
}
# Print results
print("\n📊 Generation Results:")
print(f" Total Questions: {timing_metrics.get('total_questions', 0)}")
print(f" Avg Search Time: {timing_metrics.get('avg_search_time', 0):.3f}s")
print(f" Avg Generation Time: {timing_metrics.get('avg_generation_time', 0):.3f}s")
# Save results if requested
if args.output:
print(f"\n💾 Saving results to {args.output}...")
with open(args.output, "w") as f:
json.dump(combined_metrics, f, indent=2, default=str)
print(f"✅ Results saved to {args.output}")
evaluator.cleanup()
print("✅ Stage 4 completed!\n")
if args.stage == "all":
print("🎉 All evaluation stages completed successfully!")
print("\n📋 Summary:")
print(" Stage 2: ✅ Recall@3 evaluation completed")
print(" Stage 3: ✅ Optimal complexity found")
print(" Stage 4: ✅ Generation accuracy & timing evaluation completed")
print("\n🔧 Recommended next steps:")
print(" - Use optimal complexity for best speed/accuracy balance")
print(" - Review accuracy and timing breakdown for performance optimization")
print(" - Run full evaluation on complete dataset if needed")
# Clean up non-compact index after all stages complete
print("\n🧹 Cleaning up temporary non-compact index...")
from pathlib import Path
if Path(non_compact_index_path).exists():
temp_index_dir = Path(non_compact_index_path).parent
temp_index_name = Path(non_compact_index_path).name
for temp_file in temp_index_dir.glob(f"{temp_index_name}*"):
temp_file.unlink()
print(f"✅ Cleaned up {non_compact_index_path}")
else:
print("📝 No temporary index to clean up")
except KeyboardInterrupt:
print("\n⚠️ Evaluation interrupted by user")
exit(1)
except Exception as e:
print(f"\n❌ Stage {args.stage} failed: {e}")
exit(1)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,462 @@
#!/usr/bin/env python3
"""
FinanceBench Complete Setup Script
Downloads all PDFs and builds full LEANN datastore
"""
import argparse
import os
import re
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from threading import Lock
import pymupdf
import requests
from leann import LeannBuilder, LeannSearcher
from tqdm import tqdm
class FinanceBenchSetup:
def __init__(self, data_dir: str = "data"):
self.base_dir = Path(__file__).parent # benchmarks/financebench/
self.data_dir = self.base_dir / data_dir
self.pdf_dir = self.data_dir / "pdfs"
self.dataset_file = self.data_dir / "financebench_merged.jsonl"
self.index_dir = self.data_dir / "index"
self.download_lock = Lock()
def download_dataset(self):
"""Download the main FinanceBench dataset"""
print("📊 Downloading FinanceBench dataset...")
self.data_dir.mkdir(parents=True, exist_ok=True)
if self.dataset_file.exists():
print(f"✅ Dataset already exists: {self.dataset_file}")
return
url = "https://huggingface.co/datasets/PatronusAI/financebench/raw/main/financebench_merged.jsonl"
response = requests.get(url, stream=True)
response.raise_for_status()
with open(self.dataset_file, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
print(f"✅ Dataset downloaded: {self.dataset_file}")
def get_pdf_list(self):
"""Get list of all PDF files from GitHub"""
print("📋 Fetching PDF list from GitHub...")
response = requests.get(
"https://api.github.com/repos/patronus-ai/financebench/contents/pdfs"
)
response.raise_for_status()
pdf_files = response.json()
print(f"Found {len(pdf_files)} PDF files")
return pdf_files
def download_single_pdf(self, pdf_info, position):
"""Download a single PDF file"""
pdf_name = pdf_info["name"]
pdf_path = self.pdf_dir / pdf_name
# Skip if already downloaded
if pdf_path.exists() and pdf_path.stat().st_size > 0:
return f"{pdf_name} (cached)"
try:
# Download PDF
response = requests.get(pdf_info["download_url"], timeout=60)
response.raise_for_status()
# Write to file
with self.download_lock:
with open(pdf_path, "wb") as f:
f.write(response.content)
return f"{pdf_name} ({len(response.content) // 1024}KB)"
except Exception as e:
return f"{pdf_name}: {e!s}"
def download_all_pdfs(self, max_workers: int = 5):
"""Download all PDF files with parallel processing"""
self.pdf_dir.mkdir(parents=True, exist_ok=True)
pdf_files = self.get_pdf_list()
print(f"📥 Downloading {len(pdf_files)} PDFs with {max_workers} workers...")
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit all download tasks
future_to_pdf = {
executor.submit(self.download_single_pdf, pdf_info, i): pdf_info["name"]
for i, pdf_info in enumerate(pdf_files)
}
# Process completed downloads with progress bar
with tqdm(total=len(pdf_files), desc="Downloading PDFs") as pbar:
for future in as_completed(future_to_pdf):
result = future.result()
pbar.set_postfix_str(result.split()[-1] if "" in result else "Error")
pbar.update(1)
# Verify downloads
downloaded_pdfs = list(self.pdf_dir.glob("*.pdf"))
print(f"✅ Successfully downloaded {len(downloaded_pdfs)}/{len(pdf_files)} PDFs")
# Show any failures
missing_pdfs = []
for pdf_info in pdf_files:
pdf_path = self.pdf_dir / pdf_info["name"]
if not pdf_path.exists() or pdf_path.stat().st_size == 0:
missing_pdfs.append(pdf_info["name"])
if missing_pdfs:
print(f"⚠️ Failed to download {len(missing_pdfs)} PDFs:")
for pdf in missing_pdfs[:5]: # Show first 5
print(f" - {pdf}")
if len(missing_pdfs) > 5:
print(f" ... and {len(missing_pdfs) - 5} more")
def build_leann_index(
self,
backend: str = "hnsw",
embedding_model: str = "sentence-transformers/all-mpnet-base-v2",
):
"""Build LEANN index from all PDFs"""
print(f"🏗️ Building LEANN index with {backend} backend...")
# Check if we have PDFs
pdf_files = list(self.pdf_dir.glob("*.pdf"))
if not pdf_files:
raise RuntimeError("No PDF files found! Run download first.")
print(f"Found {len(pdf_files)} PDF files to process")
start_time = time.time()
# Initialize builder with standard compact configuration
builder = LeannBuilder(
backend_name=backend,
embedding_model=embedding_model,
embedding_mode="sentence-transformers",
graph_degree=32,
complexity=64,
is_recompute=True, # Enable recompute (no stored embeddings)
is_compact=True, # Enable compact storage (pruned)
num_threads=4,
)
# Process PDFs and extract text
total_chunks = 0
failed_pdfs = []
for pdf_path in tqdm(pdf_files, desc="Processing PDFs"):
try:
chunks = self.extract_pdf_text(pdf_path)
for chunk in chunks:
builder.add_text(chunk["text"], metadata=chunk["metadata"])
total_chunks += 1
except Exception as e:
print(f"❌ Failed to process {pdf_path.name}: {e}")
failed_pdfs.append(pdf_path.name)
continue
# Build index in index directory
self.index_dir.mkdir(parents=True, exist_ok=True)
index_path = self.index_dir / f"financebench_full_{backend}.leann"
print(f"🔨 Building index: {index_path}")
builder.build_index(str(index_path))
build_time = time.time() - start_time
print("✅ Index built successfully!")
print(f" 📁 Index path: {index_path}")
print(f" 📊 Total chunks: {total_chunks:,}")
print(f" 📄 Processed PDFs: {len(pdf_files) - len(failed_pdfs)}/{len(pdf_files)}")
print(f" ⏱️ Build time: {build_time:.1f}s")
if failed_pdfs:
print(f" ⚠️ Failed PDFs: {failed_pdfs}")
return str(index_path)
def build_faiss_flat_baseline(self, index_path: str, output_dir: str = "baseline"):
"""Build FAISS flat baseline using the same embeddings as LEANN index"""
print("🔨 Building FAISS Flat baseline...")
import os
import pickle
import numpy as np
from leann.api import compute_embeddings
from leann_backend_hnsw import faiss
os.makedirs(output_dir, exist_ok=True)
baseline_path = os.path.join(output_dir, "faiss_flat.index")
metadata_path = os.path.join(output_dir, "metadata.pkl")
if os.path.exists(baseline_path) and os.path.exists(metadata_path):
print(f"✅ Baseline already exists at {baseline_path}")
return baseline_path
# Read metadata from the built index
meta_path = f"{index_path}.meta.json"
with open(meta_path) as f:
import json
meta = json.loads(f.read())
embedding_model = meta["embedding_model"]
passage_source = meta["passage_sources"][0]
passage_file = passage_source["path"]
# Convert relative path to absolute
if not os.path.isabs(passage_file):
index_dir = os.path.dirname(index_path)
passage_file = os.path.join(index_dir, os.path.basename(passage_file))
print(f"📊 Loading passages from {passage_file}...")
print(f"🤖 Using embedding model: {embedding_model}")
# Load all passages for baseline
passages = []
passage_ids = []
with open(passage_file, encoding="utf-8") as f:
for line in f:
if line.strip():
data = json.loads(line)
passages.append(data["text"])
passage_ids.append(data["id"])
print(f"📄 Loaded {len(passages)} passages")
# Compute embeddings using the same method as LEANN
print("🧮 Computing embeddings...")
embeddings = compute_embeddings(
passages,
embedding_model,
mode="sentence-transformers",
use_server=False,
)
print(f"📐 Embedding shape: {embeddings.shape}")
# Build FAISS flat index
print("🏗️ Building FAISS IndexFlatIP...")
dimension = embeddings.shape[1]
index = faiss.IndexFlatIP(dimension)
# Add embeddings to flat index
embeddings_f32 = embeddings.astype(np.float32)
index.add(embeddings_f32.shape[0], faiss.swig_ptr(embeddings_f32))
# Save index and metadata
faiss.write_index(index, baseline_path)
with open(metadata_path, "wb") as f:
pickle.dump(passage_ids, f)
print(f"✅ FAISS baseline saved to {baseline_path}")
print(f"✅ Metadata saved to {metadata_path}")
print(f"📊 Total vectors: {index.ntotal}")
return baseline_path
def extract_pdf_text(self, pdf_path: Path) -> list[dict]:
"""Extract and chunk text from a PDF file"""
chunks = []
doc = pymupdf.open(pdf_path)
for page_num in range(len(doc)):
page = doc[page_num]
text = page.get_text() # type: ignore
if not text.strip():
continue
# Create metadata
metadata = {
"source_file": pdf_path.name,
"page_number": page_num + 1,
"document_type": "10K" if "10K" in pdf_path.name else "10Q",
"company": pdf_path.name.split("_")[0],
"doc_period": self.extract_year_from_filename(pdf_path.name),
}
# Use recursive character splitting like LangChain
if len(text.split()) > 500:
# Split by double newlines (paragraphs)
paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
current_chunk = ""
for para in paragraphs:
# If adding this paragraph would make chunk too long, save current chunk
if current_chunk and len((current_chunk + " " + para).split()) > 300:
if current_chunk.strip():
chunks.append(
{
"text": current_chunk.strip(),
"metadata": {
**metadata,
"chunk_id": f"page_{page_num + 1}_chunk_{len(chunks)}",
},
}
)
current_chunk = para
else:
current_chunk = (current_chunk + " " + para).strip()
# Add the last chunk
if current_chunk.strip():
chunks.append(
{
"text": current_chunk.strip(),
"metadata": {
**metadata,
"chunk_id": f"page_{page_num + 1}_chunk_{len(chunks)}",
},
}
)
else:
# Page is short enough, use as single chunk
chunks.append(
{
"text": text.strip(),
"metadata": {**metadata, "chunk_id": f"page_{page_num + 1}"},
}
)
doc.close()
return chunks
def extract_year_from_filename(self, filename: str) -> str:
"""Extract year from PDF filename"""
# Try to find 4-digit year in filename
match = re.search(r"(\d{4})", filename)
return match.group(1) if match else "unknown"
def verify_setup(self, index_path: str):
"""Verify the setup by testing a simple query"""
print("🧪 Verifying setup with test query...")
try:
searcher = LeannSearcher(index_path)
# Test query
test_query = "What is the capital expenditure for 3M in 2018?"
results = searcher.search(test_query, top_k=3)
print(f"✅ Test query successful! Found {len(results)} results:")
for i, result in enumerate(results, 1):
company = result.metadata.get("company", "Unknown")
year = result.metadata.get("doc_period", "Unknown")
page = result.metadata.get("page_number", "Unknown")
print(f" {i}. {company} {year} (page {page}) - Score: {result.score:.3f}")
print(f" {result.text[:100]}...")
searcher.cleanup()
print("✅ Setup verification completed successfully!")
except Exception as e:
print(f"❌ Setup verification failed: {e}")
raise
def main():
parser = argparse.ArgumentParser(description="Setup FinanceBench with full PDF datastore")
parser.add_argument("--data-dir", default="data", help="Data directory")
parser.add_argument(
"--backend", choices=["hnsw", "diskann"], default="hnsw", help="LEANN backend"
)
parser.add_argument(
"--embedding-model",
default="sentence-transformers/all-mpnet-base-v2",
help="Embedding model",
)
parser.add_argument("--max-workers", type=int, default=5, help="Parallel download workers")
parser.add_argument("--skip-download", action="store_true", help="Skip PDF download")
parser.add_argument("--skip-build", action="store_true", help="Skip index building")
parser.add_argument(
"--build-baseline-only",
action="store_true",
help="Only build FAISS baseline from existing index",
)
args = parser.parse_args()
print("🏦 FinanceBench Complete Setup")
print("=" * 50)
setup = FinanceBenchSetup(args.data_dir)
try:
if args.build_baseline_only:
# Only build baseline from existing index
index_path = setup.index_dir / f"financebench_full_{args.backend}"
index_file = f"{index_path}.index"
meta_file = f"{index_path}.leann.meta.json"
if not os.path.exists(index_file) or not os.path.exists(meta_file):
print("❌ Index files not found:")
print(f" Index: {index_file}")
print(f" Meta: {meta_file}")
print("💡 Run without --build-baseline-only to build the index first")
exit(1)
print(f"🔨 Building baseline from existing index: {index_path}")
baseline_path = setup.build_faiss_flat_baseline(str(index_path))
print(f"✅ Baseline built at {baseline_path}")
return
# Step 1: Download dataset
setup.download_dataset()
# Step 2: Download PDFs
if not args.skip_download:
setup.download_all_pdfs(max_workers=args.max_workers)
else:
print("⏭️ Skipping PDF download")
# Step 3: Build LEANN index
if not args.skip_build:
index_path = setup.build_leann_index(
backend=args.backend, embedding_model=args.embedding_model
)
# Step 4: Build FAISS flat baseline
print("\n🔨 Building FAISS flat baseline...")
baseline_path = setup.build_faiss_flat_baseline(index_path)
print(f"✅ Baseline built at {baseline_path}")
# Step 5: Verify setup
setup.verify_setup(index_path)
else:
print("⏭️ Skipping index building")
print("\n🎉 FinanceBench setup completed!")
print(f"📁 Data directory: {setup.data_dir.absolute()}")
print("\nNext steps:")
print(
"1. Run evaluation: python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann"
)
print(
"2. Or test manually: python -c \"from leann import LeannSearcher; s = LeannSearcher('data/index/financebench_full_hnsw.leann'); print(s.search('3M capital expenditure 2018'))\""
)
except KeyboardInterrupt:
print("\n⚠️ Setup interrupted by user")
exit(1)
except Exception as e:
print(f"\n❌ Setup failed: {e}")
exit(1)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,214 @@
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.9"
# dependencies = [
# "faiss-cpu",
# "numpy",
# "sentence-transformers",
# "torch",
# "tqdm",
# ]
# ///
"""
Independent recall verification script using standard FAISS.
Creates two indexes (HNSW and Flat) and compares recall@3 at different complexities.
"""
import json
import time
from pathlib import Path
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
def compute_embeddings_direct(chunks: list[str], model_name: str) -> np.ndarray:
"""
Direct embedding computation using sentence-transformers.
Copied logic to avoid dependency issues.
"""
print(f"Loading model: {model_name}")
model = SentenceTransformer(model_name)
print(f"Computing embeddings for {len(chunks)} chunks...")
embeddings = model.encode(
chunks,
show_progress_bar=True,
batch_size=32,
convert_to_numpy=True,
normalize_embeddings=False,
)
return embeddings.astype(np.float32)
def load_financebench_queries(dataset_path: str, max_queries: int = 200) -> list[str]:
"""Load FinanceBench queries from dataset"""
queries = []
with open(dataset_path, encoding="utf-8") as f:
for line in f:
if line.strip():
data = json.loads(line)
queries.append(data["question"])
if len(queries) >= max_queries:
break
return queries
def load_passages_from_leann_index(index_path: str) -> tuple[list[str], list[str]]:
"""Load passages from LEANN index structure"""
meta_path = f"{index_path}.meta.json"
with open(meta_path) as f:
meta = json.load(f)
passage_source = meta["passage_sources"][0]
passage_file = passage_source["path"]
# Convert relative path to absolute
if not Path(passage_file).is_absolute():
index_dir = Path(index_path).parent
passage_file = index_dir / Path(passage_file).name
print(f"Loading passages from {passage_file}")
passages = []
passage_ids = []
with open(passage_file, encoding="utf-8") as f:
for line in tqdm(f, desc="Loading passages"):
if line.strip():
data = json.loads(line)
passages.append(data["text"])
passage_ids.append(data["id"])
print(f"Loaded {len(passages)} passages")
return passages, passage_ids
def build_faiss_indexes(embeddings: np.ndarray) -> tuple[faiss.Index, faiss.Index]:
"""Build FAISS indexes: Flat (ground truth) and HNSW"""
dimension = embeddings.shape[1]
# Build Flat index (ground truth)
print("Building FAISS IndexFlatIP (ground truth)...")
flat_index = faiss.IndexFlatIP(dimension)
flat_index.add(embeddings)
# Build HNSW index
print("Building FAISS IndexHNSWFlat...")
M = 32 # Same as LEANN default
hnsw_index = faiss.IndexHNSWFlat(dimension, M, faiss.METRIC_INNER_PRODUCT)
hnsw_index.hnsw.efConstruction = 200 # Same as LEANN default
hnsw_index.add(embeddings)
print(f"Built indexes with {flat_index.ntotal} vectors, dimension {dimension}")
return flat_index, hnsw_index
def evaluate_recall_at_k(
query_embeddings: np.ndarray,
flat_index: faiss.Index,
hnsw_index: faiss.Index,
passage_ids: list[str],
k: int = 3,
ef_search: int = 64,
) -> float:
"""Evaluate recall@k comparing HNSW vs Flat"""
# Set search parameters for HNSW
hnsw_index.hnsw.efSearch = ef_search
total_recall = 0.0
num_queries = query_embeddings.shape[0]
for i in range(num_queries):
query = query_embeddings[i : i + 1] # Keep 2D shape
# Get ground truth from Flat index (standard FAISS API)
flat_distances, flat_indices = flat_index.search(query, k)
ground_truth_ids = {passage_ids[idx] for idx in flat_indices[0]}
# Get results from HNSW index (standard FAISS API)
hnsw_distances, hnsw_indices = hnsw_index.search(query, k)
hnsw_ids = {passage_ids[idx] for idx in hnsw_indices[0]}
# Calculate recall
intersection = ground_truth_ids.intersection(hnsw_ids)
recall = len(intersection) / k
total_recall += recall
if i < 3: # Show first few examples
print(f" Query {i + 1}: Recall@{k} = {recall:.3f}")
print(f" Flat: {list(ground_truth_ids)}")
print(f" HNSW: {list(hnsw_ids)}")
print(f" Intersection: {list(intersection)}")
avg_recall = total_recall / num_queries
return avg_recall
def main():
# Configuration
dataset_path = "data/financebench_merged.jsonl"
index_path = "data/index/financebench_full_hnsw.leann"
embedding_model = "sentence-transformers/all-mpnet-base-v2"
print("🔍 FAISS Recall Verification")
print("=" * 50)
# Check if files exist
if not Path(dataset_path).exists():
print(f"❌ Dataset not found: {dataset_path}")
return
if not Path(f"{index_path}.meta.json").exists():
print(f"❌ Index metadata not found: {index_path}.meta.json")
return
# Load data
print("📖 Loading FinanceBench queries...")
queries = load_financebench_queries(dataset_path, max_queries=50)
print(f"Loaded {len(queries)} queries")
print("📄 Loading passages from LEANN index...")
passages, passage_ids = load_passages_from_leann_index(index_path)
# Compute embeddings
print("🧮 Computing passage embeddings...")
passage_embeddings = compute_embeddings_direct(passages, embedding_model)
print("🧮 Computing query embeddings...")
query_embeddings = compute_embeddings_direct(queries, embedding_model)
# Build FAISS indexes
print("🏗️ Building FAISS indexes...")
flat_index, hnsw_index = build_faiss_indexes(passage_embeddings)
# Test different efSearch values (equivalent to LEANN complexity)
print("\n📊 Evaluating Recall@3 at different efSearch values...")
ef_search_values = [16, 32, 64, 128, 256]
for ef_search in ef_search_values:
print(f"\n🧪 Testing efSearch = {ef_search}")
start_time = time.time()
recall = evaluate_recall_at_k(
query_embeddings, flat_index, hnsw_index, passage_ids, k=3, ef_search=ef_search
)
elapsed = time.time() - start_time
print(
f"📈 efSearch {ef_search}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%) in {elapsed:.2f}s"
)
print("\n✅ Verification completed!")
print("\n📋 Summary:")
print(" - Built independent FAISS Flat and HNSW indexes")
print(" - Compared recall@3 at different efSearch values")
print(" - Used same embedding model as LEANN")
print(" - This validates LEANN's recall measurements")
if __name__ == "__main__":
main()

1
benchmarks/laion/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
data/

199
benchmarks/laion/README.md Normal file
View File

@@ -0,0 +1,199 @@
# LAION Multimodal Benchmark
A multimodal benchmark for evaluating image retrieval and generation performance using LEANN with CLIP embeddings and Qwen2.5-VL for multimodal generation on LAION dataset subset.
## Overview
This benchmark evaluates:
- **Image retrieval timing** using caption-based queries
- **Recall@K performance** for image search
- **Complexity analysis** across different search parameters
- **Index size and storage efficiency**
- **Multimodal generation** with Qwen2.5-VL for image understanding and description
## Dataset Configuration
- **Dataset**: LAION-400M subset (10,000 images)
- **Embeddings**: Pre-computed CLIP ViT-B/32 (512 dimensions)
- **Queries**: 200 random captions from the dataset
- **Ground Truth**: Self-recall (query caption → original image)
## Quick Start
### 1. Setup the benchmark
```bash
cd benchmarks/laion
python setup_laion.py --num-samples 10000 --num-queries 200
```
This will:
- Create dummy LAION data (10K samples)
- Generate CLIP embeddings (512-dim)
- Build LEANN index with HNSW backend
- Create 200 evaluation queries
### 2. Run evaluation
```bash
# Run all evaluation stages
python evaluate_laion.py --index data/laion_index.leann
# Run specific stages
python evaluate_laion.py --index data/laion_index.leann --stage 2 # Recall evaluation
python evaluate_laion.py --index data/laion_index.leann --stage 3 # Complexity analysis
python evaluate_laion.py --index data/laion_index.leann --stage 4 # Index comparison
python evaluate_laion.py --index data/laion_index.leann --stage 5 # Multimodal generation
# Multimodal generation with Qwen2.5-VL
python evaluate_laion.py --index data/laion_index.leann --stage 5 --model-name Qwen/Qwen2.5-VL-7B-Instruct
```
### 3. Save results
```bash
python evaluate_laion.py --index data/laion_index.leann --output results.json
```
## Configuration Options
### Setup Options
```bash
python setup_laion.py \
--num-samples 10000 \
--num-queries 200 \
--index-path data/laion_index.leann \
--backend hnsw
```
### Evaluation Options
```bash
python evaluate_laion.py \
--index data/laion_index.leann \
--queries data/evaluation_queries.jsonl \
--complexity 64 \
--top-k 3 \
--num-samples 100 \
--stage all
```
## Evaluation Stages
### Stage 2: Recall Evaluation
- Evaluates Recall@3 for multimodal retrieval
- Compares LEANN vs FAISS baseline performance
- Self-recall: query caption should retrieve original image
### Stage 3: Complexity Analysis
- Binary search for optimal complexity (90% recall target)
- Tests performance across different complexity levels
- Analyzes speed vs. accuracy tradeoffs
### Stage 4: Index Comparison
- Compares compact vs non-compact index sizes
- Measures search performance differences
- Reports storage efficiency and speed ratios
### Stage 5: Multimodal Generation
- Uses Qwen2.5-VL for image understanding and description
- Retrieval-Augmented Generation (RAG) with multimodal context
- Measures both search and generation timing
## Output Metrics
### Timing Metrics
- Average/median/min/max search time
- Standard deviation
- Searches per second
- Latency in milliseconds
### Recall Metrics
- Recall@3 percentage for image retrieval
- Number of queries with ground truth
### Index Metrics
- Total index size (MB)
- Component breakdown (index, passages, metadata)
- Storage savings (compact vs non-compact)
- Backend and embedding model info
### Generation Metrics (Stage 5)
- Average search time per query
- Average generation time per query
- Time distribution (search vs generation)
- Sample multimodal responses
- Model: Qwen2.5-VL performance
## Benchmark Results
### LEANN-RAG Performance (CLIP ViT-L/14 + Qwen2.5-VL)
**Stage 3: Optimal Complexity Analysis**
- **Optimal Complexity**: 85 (achieving 90% Recall@3)
- **Binary Search Range**: 1-128
- **Target Recall**: 90%
- **Index Type**: Non-compact (for fast binary search)
**Stage 5: Multimodal Generation Performance (Qwen2.5-VL)**
- **Total Queries**: 20
- **Average Search Time**: 1.200s per query
- **Average Generation Time**: 6.558s per query
- **Time Distribution**: Search 15.5%, Generation 84.5%
- **LLM Backend**: HuggingFace transformers
- **Model**: Qwen/Qwen2.5-VL-7B-Instruct
- **Optimal Complexity**: 85
**System Performance:**
- **Index Size**: ~10,000 image embeddings from LAION subset
- **Embedding Model**: CLIP ViT-L/14 (768 dimensions)
- **Backend**: HNSW with cosine distance
### Example Results
```
🎯 LAION MULTIMODAL BENCHMARK RESULTS
============================================================
📊 Multimodal Generation Results:
Total Queries: 20
Avg Search Time: 1.200s
Avg Generation Time: 6.558s
Time Distribution: Search 15.5%, Generation 84.5%
LLM Backend: HuggingFace transformers
Model: Qwen/Qwen2.5-VL-7B-Instruct
⚙️ Optimal Complexity Analysis:
Target Recall: 90%
Optimal Complexity: 85
Binary Search Range: 1-128
Non-compact Index (fast search, no recompute)
🚀 Performance Summary:
Multimodal RAG: 7.758s total per query
Search: 15.5% of total time
Generation: 84.5% of total time
```
## Directory Structure
```
benchmarks/laion/
├── setup_laion.py # Setup script
├── evaluate_laion.py # Evaluation script
├── README.md # This file
└── data/ # Generated data
├── laion_images/ # Image files (placeholder)
├── laion_metadata.jsonl # Image metadata
├── laion_passages.jsonl # LEANN passages
├── laion_embeddings.npy # CLIP embeddings
├── evaluation_queries.jsonl # Evaluation queries
└── laion_index.leann/ # LEANN index files
```
## Notes
- Current implementation uses dummy data for demonstration
- For real LAION data, implement actual download logic in `setup_laion.py`
- CLIP embeddings are randomly generated - replace with real CLIP model for production
- Adjust `num_samples` and `num_queries` based on available resources
- Consider using `--num-samples` during evaluation for faster testing

View File

@@ -0,0 +1,725 @@
"""
LAION Multimodal Benchmark Evaluation Script - Modular Recall-based Evaluation
"""
import argparse
import json
import logging
import os
import pickle
import time
from pathlib import Path
import numpy as np
from leann import LeannSearcher
from leann_backend_hnsw import faiss
from sentence_transformers import SentenceTransformer
from ..llm_utils import evaluate_multimodal_rag, load_qwen_vl_model
# Setup logging to reduce verbose output
logging.basicConfig(level=logging.WARNING)
logging.getLogger("leann.api").setLevel(logging.WARNING)
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)
class RecallEvaluator:
"""Stage 2: Evaluate Recall@3 (LEANN vs FAISS baseline for multimodal retrieval)"""
def __init__(self, index_path: str, baseline_dir: str):
self.index_path = index_path
self.baseline_dir = baseline_dir
self.searcher = LeannSearcher(index_path)
# Load FAISS flat baseline (image embeddings)
baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
metadata_path = os.path.join(baseline_dir, "metadata.pkl")
self.faiss_index = faiss.read_index(baseline_index_path)
with open(metadata_path, "rb") as f:
self.image_ids = pickle.load(f)
print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} image vectors")
# Load sentence-transformers CLIP for text embedding (ViT-L/14)
self.st_clip = SentenceTransformer("clip-ViT-L-14")
def evaluate_recall_at_3(
self, captions: list[str], complexity: int = 64, recompute_embeddings: bool = True
) -> float:
"""Evaluate recall@3 for multimodal retrieval: caption queries -> image results"""
recompute_str = "with recompute" if recompute_embeddings else "no recompute"
print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")
total_recall = 0.0
num_queries = len(captions)
for i, caption in enumerate(captions):
# Get ground truth: search with FAISS flat using caption text embedding
# Generate CLIP text embedding for caption via sentence-transformers (normalized)
query_embedding = self.st_clip.encode(
[caption], convert_to_numpy=True, normalize_embeddings=True, show_progress_bar=False
).astype(np.float32)
# Search FAISS flat for ground truth using LEANN's modified faiss API
n = query_embedding.shape[0] # Number of queries
k = 3 # Number of nearest neighbors
distances = np.zeros((n, k), dtype=np.float32)
labels = np.zeros((n, k), dtype=np.int64)
self.faiss_index.search(
n,
faiss.swig_ptr(query_embedding),
k,
faiss.swig_ptr(distances),
faiss.swig_ptr(labels),
)
# Extract the results (image IDs from FAISS)
baseline_ids = {self.image_ids[idx] for idx in labels[0]}
# Search with LEANN at specified complexity (using caption as text query)
test_results = self.searcher.search(
caption,
top_k=3,
complexity=complexity,
recompute_embeddings=recompute_embeddings,
)
test_ids = {result.id for result in test_results}
# Calculate recall@3 = |intersection| / |ground_truth|
intersection = test_ids.intersection(baseline_ids)
recall = len(intersection) / 3.0 # Ground truth size is 3
total_recall += recall
if i < 3: # Show first few examples
print(f" Query {i + 1}: '{caption[:50]}...' -> Recall@3: {recall:.3f}")
print(f" FAISS ground truth: {list(baseline_ids)}")
print(f" LEANN results (C={complexity}, {recompute_str}): {list(test_ids)}")
print(f" Intersection: {list(intersection)}")
avg_recall = total_recall / num_queries
print(f"📊 Average Recall@3: {avg_recall:.3f} ({avg_recall * 100:.1f}%)")
return avg_recall
def cleanup(self):
"""Cleanup resources"""
if hasattr(self, "searcher"):
self.searcher.cleanup()
class LAIONEvaluator:
def __init__(self, index_path: str):
self.index_path = index_path
self.searcher = LeannSearcher(index_path)
def load_queries(self, queries_file: str) -> list[str]:
"""Load caption queries from evaluation file"""
captions = []
with open(queries_file, encoding="utf-8") as f:
for line in f:
if line.strip():
query_data = json.loads(line)
captions.append(query_data["query"])
print(f"📊 Loaded {len(captions)} caption queries")
return captions
def analyze_index_sizes(self) -> dict:
"""Analyze index sizes, emphasizing .index only (exclude passages)."""
print("📏 Analyzing index sizes (.index only)...")
# Get all index-related files
index_path = Path(self.index_path)
index_dir = index_path.parent
index_name = index_path.stem # Remove .leann extension
sizes: dict[str, float] = {}
# Core index files
index_file = index_dir / f"{index_name}.index"
meta_file = index_dir / f"{index_path.name}.meta.json" # Keep .leann for meta file
passages_file = index_dir / f"{index_path.name}.passages.jsonl" # Keep .leann for passages
passages_idx_file = index_dir / f"{index_path.name}.passages.idx" # Keep .leann for idx
# Core index size (.index only)
index_mb = index_file.stat().st_size / (1024 * 1024) if index_file.exists() else 0.0
sizes["index_only_mb"] = index_mb
# Other files for reference (not counted in index_only_mb)
sizes["metadata_mb"] = (
meta_file.stat().st_size / (1024 * 1024) if meta_file.exists() else 0.0
)
sizes["passages_text_mb"] = (
passages_file.stat().st_size / (1024 * 1024) if passages_file.exists() else 0.0
)
sizes["passages_index_mb"] = (
passages_idx_file.stat().st_size / (1024 * 1024) if passages_idx_file.exists() else 0.0
)
print(f" 📁 .index size: {index_mb:.1f} MB")
if sizes["metadata_mb"]:
print(f" 🧾 metadata: {sizes['metadata_mb']:.3f} MB")
if sizes["passages_text_mb"] or sizes["passages_index_mb"]:
print(
f" (passages excluded) text: {sizes['passages_text_mb']:.1f} MB, idx: {sizes['passages_index_mb']:.1f} MB"
)
return sizes
def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
"""Create a non-compact index for comparison purposes"""
print("🏗️ Building non-compact index from existing passages...")
# Load existing passages from current index
from leann import LeannBuilder
current_index_path = Path(self.index_path)
current_index_dir = current_index_path.parent
current_index_name = current_index_path.name
# Read metadata to get passage source
meta_path = current_index_dir / f"{current_index_name}.meta.json"
with open(meta_path) as f:
meta = json.load(f)
passage_source = meta["passage_sources"][0]
passage_file = passage_source["path"]
# Convert relative path to absolute
if not Path(passage_file).is_absolute():
passage_file = current_index_dir / Path(passage_file).name
print(f"📄 Loading passages from {passage_file}...")
# Load CLIP embeddings
embeddings_file = current_index_dir / "clip_image_embeddings.npy"
embeddings = np.load(embeddings_file)
print(f"📐 Loaded embeddings shape: {embeddings.shape}")
# Build non-compact index with same passages and embeddings
builder = LeannBuilder(
backend_name="hnsw",
# Use CLIP text encoder (ViT-L/14) to match image embeddings (768-dim)
embedding_model="clip-ViT-L-14",
embedding_mode="sentence-transformers",
is_recompute=False, # Disable recompute (store embeddings)
is_compact=False, # Disable compact storage
distance_metric="cosine",
**{
k: v
for k, v in meta.get("backend_kwargs", {}).items()
if k not in ["is_recompute", "is_compact", "distance_metric"]
},
)
# Prepare ids and add passages
ids: list[str] = []
with open(passage_file, encoding="utf-8") as f:
for line in f:
if line.strip():
data = json.loads(line)
ids.append(str(data["id"]))
# Ensure metadata contains the id used by the vector index
metadata = {**data.get("metadata", {}), "id": data["id"]}
builder.add_text(text=data["text"], metadata=metadata)
if len(ids) != embeddings.shape[0]:
raise ValueError(
f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
)
# Persist a pickle for build_index_from_embeddings
pkl_path = current_index_dir / "clip_image_embeddings.pkl"
with open(pkl_path, "wb") as pf:
pickle.dump((ids, embeddings.astype(np.float32)), pf)
print(
f"🔨 Building non-compact index at {non_compact_index_path} from precomputed embeddings..."
)
builder.build_index_from_embeddings(non_compact_index_path, str(pkl_path))
# Analyze the non-compact index size
temp_evaluator = LAIONEvaluator(non_compact_index_path)
non_compact_sizes = temp_evaluator.analyze_index_sizes()
non_compact_sizes["index_type"] = "non_compact"
return non_compact_sizes
def compare_index_performance(
self, non_compact_path: str, compact_path: str, test_captions: list, complexity: int
) -> dict:
"""Compare performance between non-compact and compact indexes"""
print("⚡ Comparing search performance between indexes...")
# Test queries
test_queries = test_captions[:5]
results = {
"non_compact": {"search_times": []},
"compact": {"search_times": []},
"avg_search_times": {},
"speed_ratio": 0.0,
}
# Test non-compact index (no recompute)
print(" 🔍 Testing non-compact index (no recompute)...")
non_compact_searcher = LeannSearcher(non_compact_path)
for caption in test_queries:
start_time = time.time()
_ = non_compact_searcher.search(
caption, top_k=3, complexity=complexity, recompute_embeddings=False
)
search_time = time.time() - start_time
results["non_compact"]["search_times"].append(search_time)
# Test compact index (with recompute)
print(" 🔍 Testing compact index (with recompute)...")
compact_searcher = LeannSearcher(compact_path)
for caption in test_queries:
start_time = time.time()
_ = compact_searcher.search(
caption, top_k=3, complexity=complexity, recompute_embeddings=True
)
search_time = time.time() - start_time
results["compact"]["search_times"].append(search_time)
# Calculate averages
results["avg_search_times"]["non_compact"] = sum(
results["non_compact"]["search_times"]
) / len(results["non_compact"]["search_times"])
results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
results["compact"]["search_times"]
)
# Performance ratio
if results["avg_search_times"]["compact"] > 0:
results["speed_ratio"] = (
results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
)
else:
results["speed_ratio"] = float("inf")
print(
f" Non-compact (no recompute): {results['avg_search_times']['non_compact']:.3f}s avg"
)
print(f" Compact (with recompute): {results['avg_search_times']['compact']:.3f}s avg")
print(f" Speed ratio: {results['speed_ratio']:.2f}x")
# Cleanup
non_compact_searcher.cleanup()
compact_searcher.cleanup()
return results
def _print_results(self, timing_metrics: dict):
"""Print evaluation results"""
print("\n🎯 LAION MULTIMODAL BENCHMARK RESULTS")
print("=" * 60)
# Index comparison analysis (prefer .index-only view if present)
if "current_index" in timing_metrics and "non_compact_index" in timing_metrics:
current = timing_metrics["current_index"]
non_compact = timing_metrics["non_compact_index"]
if "index_only_mb" in current and "index_only_mb" in non_compact:
print("\n📏 Index Comparison Analysis (.index only):")
print(f" Compact index (current): {current.get('index_only_mb', 0):.1f} MB")
print(f" Non-compact index: {non_compact.get('index_only_mb', 0):.1f} MB")
print(
f" Storage saving by compact: {timing_metrics.get('storage_saving_percent', 0):.1f}%"
)
# Show excluded components for reference if available
if any(
k in non_compact
for k in ("passages_text_mb", "passages_index_mb", "metadata_mb")
):
print(" (passages excluded in totals, shown for reference):")
print(
f" - Passages text: {non_compact.get('passages_text_mb', 0):.1f} MB, "
f"Passages index: {non_compact.get('passages_index_mb', 0):.1f} MB, "
f"Metadata: {non_compact.get('metadata_mb', 0):.3f} MB"
)
else:
# Fallback to legacy totals if running with older metrics
print("\n📏 Index Comparison Analysis:")
print(
f" Compact index (current): {current.get('total_with_embeddings', 0):.1f} MB"
)
print(
f" Non-compact index (with embeddings): {non_compact.get('total_with_embeddings', 0):.1f} MB"
)
print(
f" Storage saving by compact: {timing_metrics.get('storage_saving_percent', 0):.1f}%"
)
print(" Component breakdown (non-compact):")
print(f" - Main index: {non_compact.get('index', 0):.1f} MB")
print(f" - Passages text: {non_compact.get('passages_text', 0):.1f} MB")
print(f" - Passages index: {non_compact.get('passages_index', 0):.1f} MB")
print(f" - Metadata: {non_compact.get('metadata', 0):.1f} MB")
# Performance comparison
if "performance_comparison" in timing_metrics:
perf = timing_metrics["performance_comparison"]
print("\n⚡ Performance Comparison:")
print(
f" Non-compact (no recompute): {perf.get('avg_search_times', {}).get('non_compact', 0):.3f}s avg"
)
print(
f" Compact (with recompute): {perf.get('avg_search_times', {}).get('compact', 0):.3f}s avg"
)
print(f" Speed ratio: {perf.get('speed_ratio', 0):.2f}x")
# Legacy single index analysis (fallback)
if "total_with_embeddings" in timing_metrics and "current_index" not in timing_metrics:
print("\n📏 Index Size Analysis:")
print(
f" Index with embeddings: {timing_metrics.get('total_with_embeddings', 0):.1f} MB"
)
print(
f" Estimated pruned index: {timing_metrics.get('total_without_embeddings', 0):.1f} MB"
)
print(f" Compression ratio: {timing_metrics.get('compression_ratio', 0):.2f}x")
def cleanup(self):
"""Cleanup resources"""
if self.searcher:
self.searcher.cleanup()
def main():
parser = argparse.ArgumentParser(description="LAION Multimodal Benchmark Evaluation")
parser.add_argument("--index", required=True, help="Path to LEANN index")
parser.add_argument(
"--queries", default="data/evaluation_queries.jsonl", help="Path to evaluation queries"
)
parser.add_argument(
"--stage",
choices=["2", "3", "4", "5", "all"],
default="all",
help="Which stage to run (2=recall, 3=complexity, 4=index comparison, 5=generation)",
)
parser.add_argument("--complexity", type=int, default=None, help="Complexity for search")
parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
parser.add_argument("--output", help="Save results to JSON file")
parser.add_argument(
"--llm-backend",
choices=["hf"],
default="hf",
help="LLM backend (Qwen2.5-VL only supports HF)",
)
parser.add_argument(
"--model-name", default="Qwen/Qwen2.5-VL-7B-Instruct", help="Multimodal model name"
)
args = parser.parse_args()
try:
# Check if baseline exists
baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
if not os.path.exists(baseline_index_path):
print(f"❌ FAISS baseline not found at {baseline_index_path}")
print("💡 Please run setup_laion.py first to build the baseline")
exit(1)
if args.stage == "2" or args.stage == "all":
# Stage 2: Recall@3 evaluation
print("🚀 Starting Stage 2: Recall@3 evaluation for multimodal retrieval")
evaluator = RecallEvaluator(args.index, args.baseline_dir)
# Load caption queries for testing
laion_evaluator = LAIONEvaluator(args.index)
captions = laion_evaluator.load_queries(args.queries)
# Test with queries for robust measurement
test_captions = captions[:100] # Use subset for speed
print(f"🧪 Testing with {len(test_captions)} caption queries")
# Test with complexity 64
complexity = 64
recall = evaluator.evaluate_recall_at_3(test_captions, complexity)
print(f"📈 Recall@3 at complexity {complexity}: {recall * 100:.1f}%")
evaluator.cleanup()
print("✅ Stage 2 completed!\n")
# Shared non-compact index path for Stage 3 and 4
non_compact_index_path = args.index.replace(".leann", "_noncompact.leann")
complexity = args.complexity
if args.stage == "3" or args.stage == "all":
# Stage 3: Binary search for 90% recall complexity
print("🚀 Starting Stage 3: Binary search for 90% recall complexity")
print(
"💡 Creating non-compact index for fast binary search with recompute_embeddings=False"
)
# Create non-compact index for binary search
print("🏗️ Creating non-compact index for binary search...")
evaluator = LAIONEvaluator(args.index)
evaluator.create_non_compact_index_for_comparison(non_compact_index_path)
# Use non-compact index for binary search
binary_search_evaluator = RecallEvaluator(non_compact_index_path, args.baseline_dir)
# Load caption queries for testing
captions = evaluator.load_queries(args.queries)
# Use subset for robust measurement
test_captions = captions[:50] # Smaller subset for binary search speed
print(f"🧪 Testing with {len(test_captions)} caption queries")
# Binary search for 90% recall complexity
target_recall = 0.9
min_complexity, max_complexity = 1, 128
print(f"🔍 Binary search for {target_recall * 100}% recall complexity...")
print(f"Search range: {min_complexity} to {max_complexity}")
best_complexity = None
best_recall = 0.0
while min_complexity <= max_complexity:
mid_complexity = (min_complexity + max_complexity) // 2
print(
f"\n🧪 Testing complexity {mid_complexity} (no recompute, non-compact index)..."
)
# Use recompute_embeddings=False on non-compact index for fast binary search
recall = binary_search_evaluator.evaluate_recall_at_3(
test_captions, mid_complexity, recompute_embeddings=False
)
print(
f" Complexity {mid_complexity}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%)"
)
if recall >= target_recall:
best_complexity = mid_complexity
best_recall = recall
max_complexity = mid_complexity - 1
print(" ✅ Target reached! Searching for lower complexity...")
else:
min_complexity = mid_complexity + 1
print(" ❌ Below target. Searching for higher complexity...")
if best_complexity is not None:
print("\n🎯 Optimal complexity found!")
print(f" Complexity: {best_complexity}")
print(f" Recall@3: {best_recall:.3f} ({best_recall * 100:.1f}%)")
# Test a few complexities around the optimal one for verification
print("\n🔬 Verification test around optimal complexity:")
verification_complexities = [
max(1, best_complexity - 2),
max(1, best_complexity - 1),
best_complexity,
best_complexity + 1,
best_complexity + 2,
]
for complexity in verification_complexities:
if complexity <= 512: # reasonable upper bound
recall = binary_search_evaluator.evaluate_recall_at_3(
test_captions, complexity, recompute_embeddings=False
)
status = "" if recall >= target_recall else ""
print(f" {status} Complexity {complexity:3d}: {recall * 100:5.1f}%")
# Now test the optimal complexity with compact index and recompute for comparison
print(
f"\n🔄 Testing optimal complexity {best_complexity} on compact index WITH recompute..."
)
compact_evaluator = RecallEvaluator(args.index, args.baseline_dir)
recall_with_recompute = compact_evaluator.evaluate_recall_at_3(
test_captions[:10], best_complexity, recompute_embeddings=True
)
print(
f" ✅ Complexity {best_complexity} (compact index with recompute): {recall_with_recompute * 100:.1f}%"
)
complexity = best_complexity
print(
f" 📊 Recall difference: {abs(best_recall - recall_with_recompute) * 100:.2f}%"
)
compact_evaluator.cleanup()
else:
print(f"\n❌ Could not find complexity achieving {target_recall * 100}% recall")
print("All tested complexities were below target.")
# Cleanup evaluators (keep non-compact index for Stage 4)
binary_search_evaluator.cleanup()
evaluator.cleanup()
print("✅ Stage 3 completed! Non-compact index saved for Stage 4.\n")
if args.stage == "4" or args.stage == "all":
# Stage 4: Index comparison (without LLM generation)
print("🚀 Starting Stage 4: Index comparison analysis")
# Use LAION evaluator for index comparison
evaluator = LAIONEvaluator(args.index)
# Load caption queries
captions = evaluator.load_queries(args.queries)
# Step 1: Analyze current (compact) index
print("\n📏 Analyzing current index (compact, pruned)...")
compact_size_metrics = evaluator.analyze_index_sizes()
compact_size_metrics["index_type"] = "compact"
# Step 2: Use existing non-compact index or create if needed
if Path(non_compact_index_path).exists():
print(
f"\n📁 Using existing non-compact index from Stage 3: {non_compact_index_path}"
)
temp_evaluator = LAIONEvaluator(non_compact_index_path)
non_compact_size_metrics = temp_evaluator.analyze_index_sizes()
non_compact_size_metrics["index_type"] = "non_compact"
else:
print("\n🏗️ Creating non-compact index (with embeddings) for comparison...")
non_compact_size_metrics = evaluator.create_non_compact_index_for_comparison(
non_compact_index_path
)
# Step 3: Compare index sizes (.index only)
print("\n📊 Index size comparison (.index only):")
print(
f" Compact index (current): {compact_size_metrics.get('index_only_mb', 0):.1f} MB"
)
print(f" Non-compact index: {non_compact_size_metrics.get('index_only_mb', 0):.1f} MB")
storage_saving = 0.0
if non_compact_size_metrics.get("index_only_mb", 0) > 0:
storage_saving = (
(
non_compact_size_metrics.get("index_only_mb", 0)
- compact_size_metrics.get("index_only_mb", 0)
)
/ non_compact_size_metrics.get("index_only_mb", 1)
* 100
)
print(f" Storage saving by compact: {storage_saving:.1f}%")
# Step 4: Performance comparison between the two indexes
if complexity is None:
raise ValueError("Complexity is required for index comparison")
print("\n⚡ Performance comparison between indexes...")
performance_metrics = evaluator.compare_index_performance(
non_compact_index_path, args.index, captions[:10], complexity=complexity
)
# Combine all metrics
combined_metrics = {
"current_index": compact_size_metrics,
"non_compact_index": non_compact_size_metrics,
"performance_comparison": performance_metrics,
"storage_saving_percent": storage_saving,
}
# Print comprehensive results
evaluator._print_results(combined_metrics)
# Save results if requested
if args.output:
print(f"\n💾 Saving results to {args.output}...")
with open(args.output, "w") as f:
json.dump(combined_metrics, f, indent=2, default=str)
print(f"✅ Results saved to {args.output}")
evaluator.cleanup()
print("✅ Stage 4 completed!\n")
if args.stage in ("5", "all"):
print("🚀 Starting Stage 5: Multimodal generation with Qwen2.5-VL")
evaluator = LAIONEvaluator(args.index)
captions = evaluator.load_queries(args.queries)
test_captions = captions[: min(20, len(captions))] # Use subset for generation
print(f"🧪 Testing multimodal generation with {len(test_captions)} queries")
# Load Qwen2.5-VL model
try:
print("Loading Qwen2.5-VL model...")
processor, model = load_qwen_vl_model(args.model_name)
# Run multimodal generation evaluation
complexity = args.complexity or 64
gen_results = evaluate_multimodal_rag(
evaluator.searcher,
test_captions,
processor=processor,
model=model,
complexity=complexity,
)
print("\n📊 Multimodal Generation Results:")
print(f" Total Queries: {len(test_captions)}")
print(f" Avg Search Time: {gen_results['avg_search_time']:.3f}s")
print(f" Avg Generation Time: {gen_results['avg_generation_time']:.3f}s")
total_time = gen_results["avg_search_time"] + gen_results["avg_generation_time"]
search_pct = (gen_results["avg_search_time"] / total_time) * 100
gen_pct = (gen_results["avg_generation_time"] / total_time) * 100
print(f" Time Distribution: Search {search_pct:.1f}%, Generation {gen_pct:.1f}%")
print(" LLM Backend: HuggingFace transformers")
print(f" Model: {args.model_name}")
# Show sample results
print("\n📝 Sample Multimodal Generations:")
for i, response in enumerate(gen_results["results"][:3]):
# Handle both string and dict formats for captions
if isinstance(test_captions[i], dict):
caption_text = test_captions[i].get("query", str(test_captions[i]))
else:
caption_text = str(test_captions[i])
print(f" Query {i + 1}: {caption_text[:60]}...")
print(f" Response {i + 1}: {response[:100]}...")
print()
except Exception as e:
print(f"❌ Multimodal generation evaluation failed: {e}")
print("💡 Make sure transformers and Qwen2.5-VL are installed")
import traceback
traceback.print_exc()
evaluator.cleanup()
print("✅ Stage 5 completed!\n")
if args.stage == "all":
print("🎉 All evaluation stages completed successfully!")
print("\n📋 Summary:")
print(" Stage 2: ✅ Multimodal Recall@3 evaluation completed")
print(" Stage 3: ✅ Optimal complexity found")
print(" Stage 4: ✅ Index comparison analysis completed")
print(" Stage 5: ✅ Multimodal generation evaluation completed")
print("\n🔧 Recommended next steps:")
print(" - Use optimal complexity for best speed/accuracy balance")
print(" - Review index comparison for storage vs performance tradeoffs")
# Clean up non-compact index after all stages complete
print("\n🧹 Cleaning up temporary non-compact index...")
if Path(non_compact_index_path).exists():
temp_index_dir = Path(non_compact_index_path).parent
temp_index_name = Path(non_compact_index_path).name
for temp_file in temp_index_dir.glob(f"{temp_index_name}*"):
temp_file.unlink()
print(f"✅ Cleaned up {non_compact_index_path}")
else:
print("📝 No temporary index to clean up")
except KeyboardInterrupt:
print("\n⚠️ Evaluation interrupted by user")
exit(1)
except Exception as e:
print(f"\n❌ Stage {args.stage} failed: {e}")
import traceback
traceback.print_exc()
exit(1)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,576 @@
"""
LAION Multimodal Benchmark Setup Script
Downloads LAION subset and builds LEANN index with sentence embeddings
"""
import argparse
import asyncio
import io
import json
import os
import pickle
import time
from pathlib import Path
import aiohttp
import numpy as np
from datasets import load_dataset
from leann import LeannBuilder
from PIL import Image
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
class LAIONSetup:
def __init__(self, data_dir: str = "data"):
self.data_dir = Path(data_dir)
self.images_dir = self.data_dir / "laion_images"
self.metadata_file = self.data_dir / "laion_metadata.jsonl"
# Create directories
self.data_dir.mkdir(exist_ok=True)
self.images_dir.mkdir(exist_ok=True)
async def download_single_image(self, session, sample_data, semaphore, progress_bar):
"""Download a single image asynchronously"""
async with semaphore: # Limit concurrent downloads
try:
image_url = sample_data["url"]
image_path = sample_data["image_path"]
# Skip if already exists
if os.path.exists(image_path):
progress_bar.update(1)
return sample_data
async with session.get(image_url, timeout=10) as response:
if response.status == 200:
content = await response.read()
# Verify it's a valid image
try:
img = Image.open(io.BytesIO(content))
img = img.convert("RGB")
img.save(image_path, "JPEG")
progress_bar.update(1)
return sample_data
except Exception:
progress_bar.update(1)
return None # Skip invalid images
else:
progress_bar.update(1)
return None
except Exception:
progress_bar.update(1)
return None
def download_laion_subset(self, num_samples: int = 1000):
"""Download LAION subset from HuggingFace datasets with async parallel downloading"""
print(f"📥 Downloading LAION subset ({num_samples} samples)...")
# Load LAION-400M subset from HuggingFace
print("🤗 Loading from HuggingFace datasets...")
dataset = load_dataset("laion/laion400m", split="train", streaming=True)
# Collect sample metadata first (fast)
print("📋 Collecting sample metadata...")
candidates = []
for sample in dataset:
if len(candidates) >= num_samples * 3: # Get 3x more candidates in case some fail
break
image_url = sample.get("url", "")
caption = sample.get("caption", "")
if not image_url or not caption:
continue
image_filename = f"laion_{len(candidates):06d}.jpg"
image_path = self.images_dir / image_filename
candidate = {
"id": f"laion_{len(candidates):06d}",
"url": image_url,
"caption": caption,
"image_path": str(image_path),
"width": sample.get("original_width", 512),
"height": sample.get("original_height", 512),
"similarity": sample.get("similarity", 0.0),
}
candidates.append(candidate)
print(
f"📊 Collected {len(candidates)} candidates, downloading {num_samples} in parallel..."
)
# Download images in parallel
async def download_batch():
semaphore = asyncio.Semaphore(20) # Limit to 20 concurrent downloads
connector = aiohttp.TCPConnector(limit=100, limit_per_host=20)
timeout = aiohttp.ClientTimeout(total=30)
progress_bar = tqdm(total=len(candidates[: num_samples * 2]), desc="Downloading images")
async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
tasks = []
for candidate in candidates[: num_samples * 2]: # Try 2x more than needed
task = self.download_single_image(session, candidate, semaphore, progress_bar)
tasks.append(task)
# Wait for all downloads
results = await asyncio.gather(*tasks, return_exceptions=True)
progress_bar.close()
# Filter successful downloads
successful = [r for r in results if r is not None and not isinstance(r, Exception)]
return successful[:num_samples]
# Run async download
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
samples = loop.run_until_complete(download_batch())
finally:
loop.close()
# Save metadata
with open(self.metadata_file, "w", encoding="utf-8") as f:
for sample in samples:
f.write(json.dumps(sample) + "\n")
print(f"✅ Downloaded {len(samples)} real LAION samples with async parallel downloading")
return samples
def generate_clip_image_embeddings(self, samples: list[dict]):
"""Generate CLIP image embeddings for downloaded images"""
print("🔍 Generating CLIP image embeddings...")
# Load sentence-transformers CLIP (ViT-L/14, 768-dim) for image embeddings
# This single model can encode both images and text.
model = SentenceTransformer("clip-ViT-L-14")
embeddings = []
valid_samples = []
for sample in tqdm(samples, desc="Processing images"):
try:
# Load image
image_path = sample["image_path"]
image = Image.open(image_path).convert("RGB")
# Encode image to 768-dim embedding via sentence-transformers (normalized)
vec = model.encode(
[image],
convert_to_numpy=True,
normalize_embeddings=True,
batch_size=1,
show_progress_bar=False,
)[0]
embeddings.append(vec.astype(np.float32))
valid_samples.append(sample)
except Exception as e:
print(f" ⚠️ Failed to process {sample['id']}: {e}")
# Skip invalid images
embeddings = np.array(embeddings, dtype=np.float32)
# Save embeddings
embeddings_file = self.data_dir / "clip_image_embeddings.npy"
np.save(embeddings_file, embeddings)
print(f"✅ Generated {len(embeddings)} image embeddings, shape: {embeddings.shape}")
return embeddings, valid_samples
def build_faiss_baseline(
self, embeddings: np.ndarray, samples: list[dict], output_dir: str = "baseline"
):
"""Build FAISS flat baseline using CLIP image embeddings"""
print("🔨 Building FAISS Flat baseline...")
from leann_backend_hnsw import faiss
os.makedirs(output_dir, exist_ok=True)
baseline_path = os.path.join(output_dir, "faiss_flat.index")
metadata_path = os.path.join(output_dir, "metadata.pkl")
if os.path.exists(baseline_path) and os.path.exists(metadata_path):
print(f"✅ Baseline already exists at {baseline_path}")
return baseline_path
# Extract image IDs (must be present)
if not samples or "id" not in samples[0]:
raise KeyError("samples missing 'id' field for FAISS baseline")
image_ids: list[str] = [str(sample["id"]) for sample in samples]
print(f"📐 Embedding shape: {embeddings.shape}")
print(f"📄 Processing {len(image_ids)} images")
# Build FAISS flat index
print("🏗️ Building FAISS IndexFlatIP...")
dimension = embeddings.shape[1]
index = faiss.IndexFlatIP(dimension)
# Add embeddings to flat index
embeddings_f32 = embeddings.astype(np.float32)
index.add(embeddings_f32.shape[0], faiss.swig_ptr(embeddings_f32))
# Save index and metadata
faiss.write_index(index, baseline_path)
with open(metadata_path, "wb") as f:
pickle.dump(image_ids, f)
print(f"✅ FAISS baseline saved to {baseline_path}")
print(f"✅ Metadata saved to {metadata_path}")
print(f"📊 Total vectors: {index.ntotal}")
return baseline_path
def create_leann_passages(self, samples: list[dict]):
"""Create LEANN-compatible passages from LAION data"""
print("📝 Creating LEANN passages...")
passages_file = self.data_dir / "laion_passages.jsonl"
with open(passages_file, "w", encoding="utf-8") as f:
for i, sample in enumerate(samples):
passage = {
"id": sample["id"],
"text": sample["caption"], # Use caption as searchable text
"metadata": {
"image_url": sample["url"],
"image_path": sample.get("image_path", ""),
"width": sample["width"],
"height": sample["height"],
"similarity": sample["similarity"],
"image_index": i, # Index for embedding lookup
},
}
f.write(json.dumps(passage) + "\n")
print(f"✅ Created {len(samples)} passages")
return passages_file
def build_compact_index(
self, passages_file: Path, embeddings: np.ndarray, index_path: str, backend: str = "hnsw"
):
"""Build compact LEANN index with CLIP embeddings (recompute=True, compact=True)"""
print(f"🏗️ Building compact LEANN index with {backend} backend...")
start_time = time.time()
# Save CLIP embeddings (npy) and also a pickle with (ids, embeddings)
npy_path = self.data_dir / "clip_image_embeddings.npy"
np.save(npy_path, embeddings)
print(f"💾 Saved CLIP embeddings to {npy_path}")
# Prepare ids in the same order as passages_file (matches embeddings order)
ids: list[str] = []
with open(passages_file, encoding="utf-8") as f:
for line in f:
if line.strip():
rec = json.loads(line)
ids.append(str(rec["id"]))
if len(ids) != embeddings.shape[0]:
raise ValueError(
f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
)
pkl_path = self.data_dir / "clip_image_embeddings.pkl"
with open(pkl_path, "wb") as pf:
pickle.dump((ids, embeddings.astype(np.float32)), pf)
print(f"💾 Saved (ids, embeddings) pickle to {pkl_path}")
# Initialize builder - compact with recompute
# Note: For multimodal case, we need to handle embeddings differently
# Let's try using sentence-transformers mode but with custom embeddings
builder = LeannBuilder(
backend_name=backend,
# Use CLIP text encoder (ViT-L/14) to match image space (768-dim)
embedding_model="clip-ViT-L-14",
embedding_mode="sentence-transformers",
# HNSW params (or forwarded to chosen backend)
graph_degree=32,
complexity=64,
# Compact/pruned with recompute at query time
is_recompute=True,
is_compact=True,
distance_metric="cosine", # CLIP uses normalized vectors; cosine is appropriate
num_threads=4,
)
# Add passages (text + metadata)
print("📚 Adding passages...")
self._add_passages_with_embeddings(builder, passages_file, embeddings)
print(f"🔨 Building compact index at {index_path} from precomputed embeddings...")
builder.build_index_from_embeddings(index_path, str(pkl_path))
build_time = time.time() - start_time
print(f"✅ Compact index built in {build_time:.2f}s")
# Analyze index size
self._analyze_index_size(index_path)
return index_path
def build_non_compact_index(
self, passages_file: Path, embeddings: np.ndarray, index_path: str, backend: str = "hnsw"
):
"""Build non-compact LEANN index with CLIP embeddings (recompute=False, compact=False)"""
print(f"🏗️ Building non-compact LEANN index with {backend} backend...")
start_time = time.time()
# Ensure embeddings are saved (npy + pickle)
npy_path = self.data_dir / "clip_image_embeddings.npy"
if not npy_path.exists():
np.save(npy_path, embeddings)
print(f"💾 Saved CLIP embeddings to {npy_path}")
# Prepare ids in same order as passages_file
ids: list[str] = []
with open(passages_file, encoding="utf-8") as f:
for line in f:
if line.strip():
rec = json.loads(line)
ids.append(str(rec["id"]))
if len(ids) != embeddings.shape[0]:
raise ValueError(
f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
)
pkl_path = self.data_dir / "clip_image_embeddings.pkl"
if not pkl_path.exists():
with open(pkl_path, "wb") as pf:
pickle.dump((ids, embeddings.astype(np.float32)), pf)
print(f"💾 Saved (ids, embeddings) pickle to {pkl_path}")
# Initialize builder - non-compact without recompute
builder = LeannBuilder(
backend_name=backend,
embedding_model="clip-ViT-L-14",
embedding_mode="sentence-transformers",
graph_degree=32,
complexity=64,
is_recompute=False, # Store embeddings (no recompute needed)
is_compact=False, # Store full index (not pruned)
distance_metric="cosine",
num_threads=4,
)
# Add passages - embeddings will be loaded from file
print("📚 Adding passages...")
self._add_passages_with_embeddings(builder, passages_file, embeddings)
print(f"🔨 Building non-compact index at {index_path} from precomputed embeddings...")
builder.build_index_from_embeddings(index_path, str(pkl_path))
build_time = time.time() - start_time
print(f"✅ Non-compact index built in {build_time:.2f}s")
# Analyze index size
self._analyze_index_size(index_path)
return index_path
def _add_passages_with_embeddings(self, builder, passages_file: Path, embeddings: np.ndarray):
"""Helper to add passages with pre-computed CLIP embeddings"""
with open(passages_file, encoding="utf-8") as f:
for line in tqdm(f, desc="Adding passages"):
if line.strip():
passage = json.loads(line)
# Add image metadata - LEANN will handle embeddings separately
# Note: We store image metadata and caption text for searchability
# Important: ensure passage ID in metadata matches vector ID
builder.add_text(
text=passage["text"], # Image caption for searchability
metadata={**passage["metadata"], "id": passage["id"]},
)
def _analyze_index_size(self, index_path: str):
"""Analyze index file sizes"""
print("📏 Analyzing index sizes...")
index_path = Path(index_path)
index_dir = index_path.parent
index_name = index_path.name # e.g., laion_index.leann
index_prefix = index_path.stem # e.g., laion_index
files = [
(f"{index_prefix}.index", ".index", "core"),
(f"{index_name}.meta.json", ".meta.json", "core"),
(f"{index_name}.ids.txt", ".ids.txt", "core"),
(f"{index_name}.passages.jsonl", ".passages.jsonl", "passages"),
(f"{index_name}.passages.idx", ".passages.idx", "passages"),
]
def _fmt_size(bytes_val: int) -> str:
if bytes_val < 1024:
return f"{bytes_val} B"
kb = bytes_val / 1024
if kb < 1024:
return f"{kb:.1f} KB"
mb = kb / 1024
if mb < 1024:
return f"{mb:.2f} MB"
gb = mb / 1024
return f"{gb:.2f} GB"
total_index_only_mb = 0.0
total_all_mb = 0.0
for filename, label, group in files:
file_path = index_dir / filename
if file_path.exists():
size_bytes = file_path.stat().st_size
print(f" {label}: {_fmt_size(size_bytes)}")
size_mb = size_bytes / (1024 * 1024)
total_all_mb += size_mb
if group == "core":
total_index_only_mb += size_mb
else:
print(f" {label}: (missing)")
print(f" Total (index only, exclude passages): {total_index_only_mb:.2f} MB")
print(f" Total (including passages): {total_all_mb:.2f} MB")
def create_evaluation_queries(self, samples: list[dict], num_queries: int = 200):
"""Create evaluation queries from captions"""
print(f"📝 Creating {num_queries} evaluation queries...")
# Sample random captions as queries
import random
random.seed(42) # For reproducibility
query_samples = random.sample(samples, min(num_queries, len(samples)))
queries_file = self.data_dir / "evaluation_queries.jsonl"
with open(queries_file, "w", encoding="utf-8") as f:
for sample in query_samples:
query = {
"id": sample["id"],
"query": sample["caption"],
"ground_truth_id": sample["id"], # For potential recall evaluation
}
f.write(json.dumps(query) + "\n")
print(f"✅ Created {len(query_samples)} evaluation queries")
return queries_file
def main():
parser = argparse.ArgumentParser(description="Setup LAION Multimodal Benchmark")
parser.add_argument("--data-dir", default="data", help="Data directory")
parser.add_argument("--num-samples", type=int, default=1000, help="Number of LAION samples")
parser.add_argument("--num-queries", type=int, default=50, help="Number of evaluation queries")
parser.add_argument("--index-path", default="data/laion_index.leann", help="Output index path")
parser.add_argument(
"--backend", default="hnsw", choices=["hnsw", "diskann"], help="LEANN backend"
)
parser.add_argument("--skip-download", action="store_true", help="Skip LAION dataset download")
parser.add_argument("--skip-build", action="store_true", help="Skip index building")
args = parser.parse_args()
print("🚀 Setting up LAION Multimodal Benchmark")
print("=" * 50)
try:
# Initialize setup
setup = LAIONSetup(args.data_dir)
# Step 1: Download LAION subset
if not args.skip_download:
print("\n📦 Step 1: Download LAION subset")
samples = setup.download_laion_subset(args.num_samples)
# Step 2: Generate CLIP image embeddings
print("\n🔍 Step 2: Generate CLIP image embeddings")
embeddings, valid_samples = setup.generate_clip_image_embeddings(samples)
# Step 3: Create LEANN passages (image metadata with embeddings)
print("\n📝 Step 3: Create LEANN passages")
passages_file = setup.create_leann_passages(valid_samples)
else:
print("⏭️ Skipping LAION dataset download")
# Load existing data
passages_file = setup.data_dir / "laion_passages.jsonl"
embeddings_file = setup.data_dir / "clip_image_embeddings.npy"
if not passages_file.exists() or not embeddings_file.exists():
raise FileNotFoundError(
"Passages or embeddings file not found. Run without --skip-download first."
)
embeddings = np.load(embeddings_file)
print(f"📊 Loaded {len(embeddings)} embeddings from {embeddings_file}")
# Step 4: Build LEANN indexes (both compact and non-compact)
if not args.skip_build:
print("\n🏗️ Step 4: Build LEANN indexes with CLIP image embeddings")
# Build compact index (production mode - small, recompute required)
compact_index_path = args.index_path
print(f"Building compact index: {compact_index_path}")
setup.build_compact_index(passages_file, embeddings, compact_index_path, args.backend)
# Build non-compact index (comparison mode - large, fast search)
non_compact_index_path = args.index_path.replace(".leann", "_noncompact.leann")
print(f"Building non-compact index: {non_compact_index_path}")
setup.build_non_compact_index(
passages_file, embeddings, non_compact_index_path, args.backend
)
# Step 5: Build FAISS flat baseline
print("\n🔨 Step 5: Build FAISS flat baseline")
if not args.skip_download:
baseline_path = setup.build_faiss_baseline(embeddings, valid_samples)
else:
# Load valid_samples from passages file for FAISS baseline
valid_samples = []
with open(passages_file, encoding="utf-8") as f:
for line in f:
if line.strip():
passage = json.loads(line)
valid_samples.append({"id": passage["id"], "caption": passage["text"]})
baseline_path = setup.build_faiss_baseline(embeddings, valid_samples)
# Step 6: Create evaluation queries
print("\n📝 Step 6: Create evaluation queries")
queries_file = setup.create_evaluation_queries(valid_samples, args.num_queries)
else:
print("⏭️ Skipping index building")
baseline_path = "data/baseline/faiss_index.bin"
queries_file = setup.data_dir / "evaluation_queries.jsonl"
print("\n🎉 Setup completed successfully!")
print("📊 Summary:")
if not args.skip_download:
print(f" Downloaded samples: {len(samples)}")
print(f" Valid samples with embeddings: {len(valid_samples)}")
else:
print(f" Loaded {len(embeddings)} embeddings")
if not args.skip_build:
print(f" Compact index: {compact_index_path}")
print(f" Non-compact index: {non_compact_index_path}")
print(f" FAISS baseline: {baseline_path}")
print(f" Queries: {queries_file}")
print("\n🔧 Next steps:")
print(f" Run evaluation: python evaluate_laion.py --index {compact_index_path}")
print(f" Or compare with: python evaluate_laion.py --index {non_compact_index_path}")
else:
print(" Skipped building indexes")
except KeyboardInterrupt:
print("\n⚠️ Setup interrupted by user")
exit(1)
except Exception as e:
print(f"\n❌ Setup failed: {e}")
exit(1)
if __name__ == "__main__":
main()

301
benchmarks/llm_utils.py Normal file
View File

@@ -0,0 +1,301 @@
"""
LLM utils for RAG benchmarks with Qwen3-8B and Qwen2.5-VL (multimodal)
"""
import time
try:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
HF_AVAILABLE = True
except ImportError:
HF_AVAILABLE = False
try:
from vllm import LLM, SamplingParams
VLLM_AVAILABLE = True
except ImportError:
VLLM_AVAILABLE = False
def is_qwen3_model(model_name):
"""Check if model is Qwen3"""
return "Qwen3" in model_name or "qwen3" in model_name.lower()
def is_qwen_vl_model(model_name):
"""Check if model is Qwen2.5-VL"""
return "Qwen2.5-VL" in model_name or "qwen2.5-vl" in model_name.lower()
def apply_qwen3_chat_template(tokenizer, prompt):
"""Apply Qwen3 chat template with thinking enabled"""
messages = [{"role": "user", "content": prompt}]
return tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=True,
)
def extract_thinking_answer(response):
"""Extract final answer from Qwen3 thinking model response"""
if "<think>" in response and "</think>" in response:
try:
think_end = response.index("</think>") + len("</think>")
final_answer = response[think_end:].strip()
return final_answer
except (ValueError, IndexError):
pass
return response.strip()
def load_hf_model(model_name="Qwen/Qwen3-8B"):
"""Load HuggingFace model"""
if not HF_AVAILABLE:
raise ImportError("transformers not available")
print(f"Loading HF: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
device_map="auto",
trust_remote_code=True,
)
return tokenizer, model
def load_vllm_model(model_name="Qwen/Qwen3-8B"):
"""Load vLLM model"""
if not VLLM_AVAILABLE:
raise ImportError("vllm not available")
print(f"Loading vLLM: {model_name}")
llm = LLM(model=model_name, trust_remote_code=True)
# Qwen3 specific config
if is_qwen3_model(model_name):
stop_tokens = ["<|im_end|>", "<|end_of_text|>"]
max_tokens = 2048
else:
stop_tokens = None
max_tokens = 1024
sampling_params = SamplingParams(temperature=0.7, max_tokens=max_tokens, stop=stop_tokens)
return llm, sampling_params
def generate_hf(tokenizer, model, prompt, max_tokens=None):
"""Generate with HF - supports Qwen3 thinking models"""
model_name = getattr(model, "name_or_path", "unknown")
is_qwen3 = is_qwen3_model(model_name)
# Apply chat template for Qwen3
if is_qwen3:
prompt = apply_qwen3_chat_template(tokenizer, prompt)
max_tokens = max_tokens or 2048
else:
max_tokens = max_tokens or 1024
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=max_tokens,
temperature=0.7,
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
response = response[len(prompt) :].strip()
# Extract final answer for thinking models
if is_qwen3:
return extract_thinking_answer(response)
return response
def generate_vllm(llm, sampling_params, prompt):
"""Generate with vLLM - supports Qwen3 thinking models"""
outputs = llm.generate([prompt], sampling_params)
response = outputs[0].outputs[0].text.strip()
# Extract final answer for Qwen3 thinking models
model_name = str(llm.llm_engine.model_config.model)
if is_qwen3_model(model_name):
return extract_thinking_answer(response)
return response
def create_prompt(context, query, domain="default"):
"""Create RAG prompt"""
if domain == "emails":
return f"Email content:\n{context}\n\nQuestion: {query}\n\nAnswer:"
elif domain == "finance":
return f"Financial content:\n{context}\n\nQuestion: {query}\n\nAnswer:"
elif domain == "multimodal":
return f"Image context:\n{context}\n\nQuestion: {query}\n\nAnswer:"
else:
return f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
def evaluate_rag(searcher, llm_func, queries, domain="default", top_k=3, complexity=64):
"""Simple RAG evaluation with timing"""
search_times = []
gen_times = []
results = []
for i, query in enumerate(queries):
# Search
start = time.time()
docs = searcher.search(query, top_k=top_k, complexity=complexity)
search_time = time.time() - start
# Generate
context = "\n\n".join([doc.text for doc in docs])
prompt = create_prompt(context, query, domain)
start = time.time()
response = llm_func(prompt)
gen_time = time.time() - start
search_times.append(search_time)
gen_times.append(gen_time)
results.append(response)
if i < 3:
print(f"Q{i + 1}: Search={search_time:.3f}s, Gen={gen_time:.3f}s")
return {
"avg_search_time": sum(search_times) / len(search_times),
"avg_generation_time": sum(gen_times) / len(gen_times),
"results": results,
}
def load_qwen_vl_model(model_name="Qwen/Qwen2.5-VL-7B-Instruct"):
"""Load Qwen2.5-VL multimodal model"""
if not HF_AVAILABLE:
raise ImportError("transformers not available")
print(f"Loading Qwen2.5-VL: {model_name}")
try:
from transformers import AutoModelForVision2Seq, AutoProcessor
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForVision2Seq.from_pretrained(
model_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
)
return processor, model
except Exception as e:
print(f"Failed to load with AutoModelForVision2Seq, trying specific class: {e}")
# Fallback to specific class
try:
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
model = Qwen2VLForConditionalGeneration.from_pretrained(
model_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
)
return processor, model
except Exception as e2:
raise ImportError(f"Failed to load Qwen2.5-VL model: {e2}")
def generate_qwen_vl(processor, model, prompt, image_path=None, max_tokens=512):
"""Generate with Qwen2.5-VL multimodal model"""
from PIL import Image
# Prepare inputs
if image_path:
image = Image.open(image_path)
inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
else:
inputs = processor(text=prompt, return_tensors="pt").to(model.device)
# Generate
with torch.no_grad():
generated_ids = model.generate(
**inputs, max_new_tokens=max_tokens, do_sample=False, temperature=0.1
)
# Decode response
generated_ids = generated_ids[:, inputs["input_ids"].shape[1] :]
response = processor.decode(generated_ids[0], skip_special_tokens=True)
return response
def create_multimodal_prompt(context, query, image_descriptions, task_type="images"):
"""Create prompt for multimodal RAG"""
if task_type == "images":
return f"""Based on the retrieved images and their descriptions, answer the following question.
Retrieved Image Descriptions:
{context}
Question: {query}
Provide a detailed answer based on the visual content described above."""
return f"Context: {context}\nQuestion: {query}\nAnswer:"
def evaluate_multimodal_rag(searcher, queries, processor=None, model=None, complexity=64):
"""Evaluate multimodal RAG with Qwen2.5-VL"""
search_times = []
gen_times = []
results = []
for i, query_item in enumerate(queries):
# Handle both string and dict formats for queries
if isinstance(query_item, dict):
query = query_item.get("query", "")
image_path = query_item.get("image_path") # Optional reference image
else:
query = str(query_item)
image_path = None
# Search
start_time = time.time()
search_results = searcher.search(query, top_k=3, complexity=complexity)
search_time = time.time() - start_time
search_times.append(search_time)
# Prepare context from search results
context_parts = []
for result in search_results:
context_parts.append(f"- {result.text}")
context = "\n".join(context_parts)
# Generate with multimodal model
start_time = time.time()
if processor and model:
prompt = create_multimodal_prompt(context, query, context_parts)
response = generate_qwen_vl(processor, model, prompt, image_path)
else:
response = f"Context: {context}"
gen_time = time.time() - start_time
gen_times.append(gen_time)
results.append(response)
if i < 3:
print(f"Q{i + 1}: Search={search_time:.3f}s, Gen={gen_time:.3f}s")
return {
"avg_search_time": sum(search_times) / len(search_times),
"avg_generation_time": sum(gen_times) / len(gen_times),
"results": results,
}

View File

@@ -53,7 +53,7 @@ def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
print( print(
"Error: huggingface_hub is not installed. Please install it to download the data:" "Error: huggingface_hub is not installed. Please install it to download the data:"
) )
print("uv pip install -e '.[dev]'") print("uv sync --only-group dev")
sys.exit(1) sys.exit(1)
except Exception as e: except Exception as e:
print(f"An error occurred during data download: {e}") print(f"An error occurred during data download: {e}")

View File

@@ -53,9 +53,9 @@ We use pre-commit hooks to ensure code quality and consistency. This runs automa
### Setup Pre-commit ### Setup Pre-commit
1. **Install pre-commit** (already included when you run `uv sync`): 1. **Install pre-commit tools**:
```bash ```bash
uv pip install pre-commit uv sync lint
``` ```
2. **Install the git hooks**: 2. **Install the git hooks**:
@@ -65,7 +65,7 @@ We use pre-commit hooks to ensure code quality and consistency. This runs automa
3. **Run pre-commit manually** (optional): 3. **Run pre-commit manually** (optional):
```bash ```bash
pre-commit run --all-files uv run pre-commit run --all-files
``` ```
### Pre-commit Checks ### Pre-commit Checks
@@ -85,6 +85,9 @@ Our pre-commit configuration includes:
### Running Tests ### Running Tests
```bash ```bash
# Install test tools only (no project runtime)
uv sync --group test
# Run all tests # Run all tests
uv run pytest uv run pytest

View File

@@ -26,6 +26,21 @@ leann build my-code-index --docs ./src --use-ast-chunking
uv pip install -e "." uv pip install -e "."
``` ```
#### For normal users (PyPI install)
- Use `pip install leann` or `uv pip install leann`.
- `astchunk` is pulled automatically from PyPI as a dependency; no extra steps.
#### For developers (from source, editable)
```bash
git clone https://github.com/yichuan-w/LEANN.git leann
cd leann
git submodule update --init --recursive
uv sync
```
- This repo vendors `astchunk` as a git submodule at `packages/astchunk-leann` (our fork).
- `[tool.uv.sources]` maps the `astchunk` package to that path in editable mode.
- You can edit code under `packages/astchunk-leann` and Python will use your changes immediately (no separate `pip install astchunk` needed).
## Best Practices ## Best Practices
### When to Use AST Chunking ### When to Use AST Chunking

View File

@@ -83,6 +83,81 @@ ollama pull nomic-embed-text
</details> </details>
## Local & Remote Inference Endpoints
> Applies to both LLMs (`leann ask`) and embeddings (`leann build`).
LEANN now treats Ollama, LM Studio, and other OpenAI-compatible runtimes as first-class providers. You can point LEANN at any compatible endpoint either on the same machine or across the network with a couple of flags or environment variables.
### One-Time Environment Setup
```bash
# Works for OpenAI-compatible runtimes such as LM Studio, vLLM, SGLang, llamafile, etc.
export OPENAI_API_KEY="your-key" # or leave unset for local servers that do not check keys
export OPENAI_BASE_URL="http://localhost:1234/v1"
# Ollama-compatible runtimes (Ollama, Ollama on another host, llamacpp-server, etc.)
export LEANN_OLLAMA_HOST="http://localhost:11434" # falls back to OLLAMA_HOST or LOCAL_LLM_ENDPOINT
```
LEANN also recognises `LEANN_LOCAL_LLM_HOST` (highest priority), `LEANN_OPENAI_BASE_URL`, and `LOCAL_OPENAI_BASE_URL`, so existing scripts continue to work.
### Passing Hosts Per Command
```bash
# Build an index with a remote embedding server
leann build my-notes \
--docs ./notes \
--embedding-mode openai \
--embedding-model text-embedding-qwen3-embedding-0.6b \
--embedding-api-base http://192.168.1.50:1234/v1 \
--embedding-api-key local-dev-key
# Query using a local LM Studio instance via OpenAI-compatible API
leann ask my-notes \
--llm openai \
--llm-model qwen3-8b \
--api-base http://localhost:1234/v1 \
--api-key local-dev-key
# Query an Ollama instance running on another box
leann ask my-notes \
--llm ollama \
--llm-model qwen3:14b \
--host http://192.168.1.101:11434
```
⚠️ **Make sure the endpoint is reachable**: when your inference server runs on a home/workstation and the index/search job runs in the cloud, the server must be able to reach the host you configured. Typical options include:
- Expose a public IP (and open the relevant port) on the machine that hosts LM Studio/Ollama.
- Configure router or cloud provider port forwarding.
- Tunnel traffic through tools like `tailscale`, `cloudflared`, or `ssh -R`.
When you set these options while building an index, LEANN stores them in `meta.json`. Any subsequent `leann ask` or searcher process automatically reuses the same provider settings even when we spawn background embedding servers. This makes the “server without GPU talking to my local workstation” workflow from [issue #80](https://github.com/yichuan-w/LEANN/issues/80#issuecomment-2287230548) work out-of-the-box.
**Tip:** If your runtime does not require an API key (many local stacks dont), leave `--api-key` unset. LEANN will skip injecting credentials.
### Python API Usage
You can pass the same configuration from Python:
```python
from leann.api import LeannBuilder
builder = LeannBuilder(
backend_name="hnsw",
embedding_mode="openai",
embedding_model="text-embedding-qwen3-embedding-0.6b",
embedding_options={
"base_url": "http://192.168.1.50:1234/v1",
"api_key": "local-dev-key",
},
)
builder.build_index("./indexes/my-notes", chunks)
```
`embedding_options` is persisted to the index `meta.json`, so subsequent `LeannSearcher` or `LeannChat` sessions automatically reuse the same provider settings (the embedding server manager forwards them to the provider for you).
## Index Selection: Matching Your Scale ## Index Selection: Matching Your Scale
### HNSW (Hierarchical Navigable Small World) ### HNSW (Hierarchical Navigable Small World)

149
docs/grep_search.md Normal file
View File

@@ -0,0 +1,149 @@
# LEANN Grep Search Usage Guide
## Overview
LEANN's grep search functionality provides exact text matching for finding specific code patterns, error messages, function names, or exact phrases in your indexed documents.
## Basic Usage
### Simple Grep Search
```python
from leann.api import LeannSearcher
searcher = LeannSearcher("your_index_path")
# Exact text search
results = searcher.search("def authenticate_user", use_grep=True, top_k=5)
for result in results:
print(f"Score: {result.score}")
print(f"Text: {result.text[:100]}...")
print("-" * 40)
```
### Comparison: Semantic vs Grep Search
```python
# Semantic search - finds conceptually similar content
semantic_results = searcher.search("machine learning algorithms", top_k=3)
# Grep search - finds exact text matches
grep_results = searcher.search("def train_model", use_grep=True, top_k=3)
```
## When to Use Grep Search
### Use Cases
- **Code Search**: Finding specific function definitions, class names, or variable references
- **Error Debugging**: Locating exact error messages or stack traces
- **Documentation**: Finding specific API endpoints or exact terminology
### Examples
```python
# Find function definitions
functions = searcher.search("def __init__", use_grep=True)
# Find import statements
imports = searcher.search("from sklearn import", use_grep=True)
# Find specific error types
errors = searcher.search("FileNotFoundError", use_grep=True)
# Find TODO comments
todos = searcher.search("TODO:", use_grep=True)
# Find configuration entries
configs = searcher.search("server_port=", use_grep=True)
```
## Technical Details
### How It Works
1. **File Location**: Grep search operates on the raw text stored in `.jsonl` files
2. **Command Execution**: Uses the system `grep` command with case-insensitive search
3. **Result Processing**: Parses JSON lines and extracts text and metadata
4. **Scoring**: Simple frequency-based scoring based on query term occurrences
### Search Process
```
Query: "def train_model"
grep -i -n "def train_model" documents.leann.passages.jsonl
Parse matching JSON lines
Calculate scores based on term frequency
Return top_k results
```
### Scoring Algorithm
```python
# Term frequency in document
score = text.lower().count(query.lower())
```
Results are ranked by score (highest first), with higher scores indicating more occurrences of the search term.
## Error Handling
### Common Issues
#### Grep Command Not Found
```
RuntimeError: grep command not found. Please install grep or use semantic search.
```
**Solution**: Install grep on your system:
- **Ubuntu/Debian**: `sudo apt-get install grep`
- **macOS**: grep is pre-installed
- **Windows**: Use WSL or install grep via Git Bash/MSYS2
#### No Results Found
```python
# Check if your query exists in the raw data
results = searcher.search("your_query", use_grep=True)
if not results:
print("No exact matches found. Try:")
print("1. Check spelling and case")
print("2. Use partial terms")
print("3. Switch to semantic search")
```
## Complete Example
```python
#!/usr/bin/env python3
"""
Grep Search Example
Demonstrates grep search for exact text matching.
"""
from leann.api import LeannSearcher
def demonstrate_grep_search():
# Initialize searcher
searcher = LeannSearcher("my_index")
print("=== Function Search ===")
functions = searcher.search("def __init__", use_grep=True, top_k=5)
for i, result in enumerate(functions, 1):
print(f"{i}. Score: {result.score}")
print(f" Preview: {result.text[:60]}...")
print()
print("=== Error Search ===")
errors = searcher.search("FileNotFoundError", use_grep=True, top_k=3)
for result in errors:
print(f"Content: {result.text.strip()}")
print("-" * 40)
if __name__ == "__main__":
demonstrate_grep_search()
```

0
examples/__init__.py Normal file
View File

View File

@@ -0,0 +1,404 @@
"""Dynamic HNSW update demo without compact storage.
This script reproduces the minimal scenario we used while debugging on-the-fly
recompute:
1. Build a non-compact HNSW index from the first few paragraphs of a text file.
2. Print the top results with `recompute_embeddings=True`.
3. Append additional paragraphs with :meth:`LeannBuilder.update_index`.
4. Run the same query again to show the newly inserted passages.
Run it with ``uv`` (optionally pointing LEANN_HNSW_LOG_PATH at a file to inspect
ZMQ activity)::
LEANN_HNSW_LOG_PATH=embedding_fetch.log \
uv run -m examples.dynamic_update_no_recompute \
--index-path .leann/examples/leann-demo.leann
By default the script builds an index from ``data/2501.14312v1 (1).pdf`` and
then updates it with LEANN-related material from ``data/2506.08276v1.pdf``.
It issues the query "What's LEANN?" before and after the update to show how the
new passages become immediately searchable. The script uses the
``sentence-transformers/all-MiniLM-L6-v2`` model with ``is_recompute=True`` so
Faiss pulls existing vectors on demand via the ZMQ embedding server, while
freshly added passages are embedded locally just like the initial build.
To make storage comparisons easy, the script can also build a matching
``is_recompute=False`` baseline (enabled by default) and report the index size
delta after the update. Disable the baseline run with
``--skip-compare-no-recompute`` if you only need the recompute flow.
"""
import argparse
import json
from collections.abc import Iterable
from pathlib import Path
from typing import Any
from leann.api import LeannBuilder, LeannSearcher
from leann.registry import register_project_directory
from apps.chunking import create_text_chunks
REPO_ROOT = Path(__file__).resolve().parents[1]
DEFAULT_QUERY = "What's LEANN?"
DEFAULT_INITIAL_FILES = [REPO_ROOT / "data" / "2501.14312v1 (1).pdf"]
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
def load_chunks_from_files(paths: list[Path]) -> list[str]:
from llama_index.core import SimpleDirectoryReader
documents = []
for path in paths:
p = path.expanduser().resolve()
if not p.exists():
raise FileNotFoundError(f"Input path not found: {p}")
if p.is_dir():
reader = SimpleDirectoryReader(str(p), recursive=False)
documents.extend(reader.load_data(show_progress=True))
else:
reader = SimpleDirectoryReader(input_files=[str(p)])
documents.extend(reader.load_data(show_progress=True))
if not documents:
return []
chunks = create_text_chunks(
documents,
chunk_size=512,
chunk_overlap=128,
use_ast_chunking=False,
)
return [c for c in chunks if isinstance(c, str) and c.strip()]
def run_search(index_path: Path, query: str, top_k: int, *, recompute_embeddings: bool) -> list:
searcher = LeannSearcher(str(index_path))
try:
return searcher.search(
query=query,
top_k=top_k,
recompute_embeddings=recompute_embeddings,
batch_size=16,
)
finally:
searcher.cleanup()
def print_results(title: str, results: Iterable) -> None:
print(f"\n=== {title} ===")
res_list = list(results)
print(f"results count: {len(res_list)}")
print("passages:")
if not res_list:
print(" (no passages returned)")
for res in res_list:
snippet = res.text.replace("\n", " ")[:120]
print(f" - {res.id}: {snippet}... (score={res.score:.4f})")
def build_initial_index(
index_path: Path,
paragraphs: list[str],
model_name: str,
embedding_mode: str,
is_recompute: bool,
) -> None:
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=model_name,
embedding_mode=embedding_mode,
is_compact=False,
is_recompute=is_recompute,
)
for idx, passage in enumerate(paragraphs):
builder.add_text(passage, metadata={"id": str(idx)})
builder.build_index(str(index_path))
def update_index(
index_path: Path,
start_id: int,
paragraphs: list[str],
model_name: str,
embedding_mode: str,
is_recompute: bool,
) -> None:
updater = LeannBuilder(
backend_name="hnsw",
embedding_model=model_name,
embedding_mode=embedding_mode,
is_compact=False,
is_recompute=is_recompute,
)
for offset, passage in enumerate(paragraphs, start=start_id):
updater.add_text(passage, metadata={"id": str(offset)})
updater.update_index(str(index_path))
def ensure_index_dir(index_path: Path) -> None:
index_path.parent.mkdir(parents=True, exist_ok=True)
def cleanup_index_files(index_path: Path) -> None:
"""Remove leftover index artifacts for a clean rebuild."""
parent = index_path.parent
if not parent.exists():
return
stem = index_path.stem
for file in parent.glob(f"{stem}*"):
if file.is_file():
file.unlink()
def index_file_size(index_path: Path) -> int:
"""Return the size of the primary .index file for the given index path."""
index_file = index_path.parent / f"{index_path.stem}.index"
return index_file.stat().st_size if index_file.exists() else 0
def load_metadata_snapshot(index_path: Path) -> dict[str, Any] | None:
meta_path = index_path.parent / f"{index_path.name}.meta.json"
if not meta_path.exists():
return None
try:
return json.loads(meta_path.read_text())
except json.JSONDecodeError:
return None
def run_workflow(
*,
label: str,
index_path: Path,
initial_paragraphs: list[str],
update_paragraphs: list[str],
model_name: str,
embedding_mode: str,
is_recompute: bool,
query: str,
top_k: int,
) -> dict[str, Any]:
prefix = f"[{label}] " if label else ""
ensure_index_dir(index_path)
cleanup_index_files(index_path)
print(f"{prefix}Building initial index...")
build_initial_index(
index_path,
initial_paragraphs,
model_name,
embedding_mode,
is_recompute=is_recompute,
)
initial_size = index_file_size(index_path)
before_results = run_search(
index_path,
query,
top_k,
recompute_embeddings=is_recompute,
)
print(f"\n{prefix}Updating index with additional passages...")
update_index(
index_path,
start_id=len(initial_paragraphs),
paragraphs=update_paragraphs,
model_name=model_name,
embedding_mode=embedding_mode,
is_recompute=is_recompute,
)
after_results = run_search(
index_path,
query,
top_k,
recompute_embeddings=is_recompute,
)
updated_size = index_file_size(index_path)
return {
"initial_size": initial_size,
"updated_size": updated_size,
"delta": updated_size - initial_size,
"before_results": before_results,
"after_results": after_results,
"metadata": load_metadata_snapshot(index_path),
}
def main() -> None:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--initial-files",
type=Path,
nargs="+",
default=DEFAULT_INITIAL_FILES,
help="Initial document files (PDF/TXT) used to build the base index",
)
parser.add_argument(
"--index-path",
type=Path,
default=Path(".leann/examples/leann-demo.leann"),
help="Destination index path (default: .leann/examples/leann-demo.leann)",
)
parser.add_argument(
"--initial-count",
type=int,
default=8,
help="Number of chunks to use from the initial documents (default: 8)",
)
parser.add_argument(
"--update-files",
type=Path,
nargs="*",
default=DEFAULT_UPDATE_FILES,
help="Additional documents to add during update (PDF/TXT)",
)
parser.add_argument(
"--update-count",
type=int,
default=4,
help="Number of chunks to append from update documents (default: 4)",
)
parser.add_argument(
"--update-text",
type=str,
default=(
"LEANN (Lightweight Embedding ANN) is an indexing toolkit focused on "
"recompute-aware HNSW graphs, allowing embeddings to be regenerated "
"on demand to keep disk usage minimal."
),
help="Fallback text to append if --update-files is omitted",
)
parser.add_argument(
"--top-k",
type=int,
default=4,
help="Number of results to show for each search (default: 4)",
)
parser.add_argument(
"--query",
type=str,
default=DEFAULT_QUERY,
help="Query to run before/after the update",
)
parser.add_argument(
"--embedding-model",
type=str,
default="sentence-transformers/all-MiniLM-L6-v2",
help="Embedding model name",
)
parser.add_argument(
"--embedding-mode",
type=str,
default="sentence-transformers",
choices=["sentence-transformers", "openai", "mlx", "ollama"],
help="Embedding backend mode",
)
parser.add_argument(
"--compare-no-recompute",
dest="compare_no_recompute",
action="store_true",
help="Also run a baseline with is_recompute=False and report its index growth.",
)
parser.add_argument(
"--skip-compare-no-recompute",
dest="compare_no_recompute",
action="store_false",
help="Skip building the no-recompute baseline.",
)
parser.set_defaults(compare_no_recompute=True)
args = parser.parse_args()
ensure_index_dir(args.index_path)
register_project_directory(REPO_ROOT)
initial_chunks = load_chunks_from_files(list(args.initial_files))
if not initial_chunks:
raise ValueError("No text chunks extracted from the initial files.")
initial = initial_chunks[: args.initial_count]
if not initial:
raise ValueError("Initial chunk set is empty after applying --initial-count.")
if args.update_files:
update_chunks = load_chunks_from_files(list(args.update_files))
if not update_chunks:
raise ValueError("No text chunks extracted from the update files.")
to_add = update_chunks[: args.update_count]
else:
if not args.update_text:
raise ValueError("Provide --update-files or --update-text for the update step.")
to_add = [args.update_text]
if not to_add:
raise ValueError("Update chunk set is empty after applying --update-count.")
recompute_stats = run_workflow(
label="recompute",
index_path=args.index_path,
initial_paragraphs=initial,
update_paragraphs=to_add,
model_name=args.embedding_model,
embedding_mode=args.embedding_mode,
is_recompute=True,
query=args.query,
top_k=args.top_k,
)
print_results("initial search", recompute_stats["before_results"])
print_results("after update", recompute_stats["after_results"])
print(
f"\n[recompute] Index file size change: {recompute_stats['initial_size']} -> {recompute_stats['updated_size']} bytes"
f"{recompute_stats['delta']})"
)
if recompute_stats["metadata"]:
meta_view = {k: recompute_stats["metadata"].get(k) for k in ("is_compact", "is_pruned")}
print("[recompute] metadata snapshot:")
print(json.dumps(meta_view, indent=2))
if args.compare_no_recompute:
baseline_path = (
args.index_path.parent / f"{args.index_path.stem}-norecompute{args.index_path.suffix}"
)
baseline_stats = run_workflow(
label="no-recompute",
index_path=baseline_path,
initial_paragraphs=initial,
update_paragraphs=to_add,
model_name=args.embedding_model,
embedding_mode=args.embedding_mode,
is_recompute=False,
query=args.query,
top_k=args.top_k,
)
print(
f"\n[no-recompute] Index file size change: {baseline_stats['initial_size']} -> {baseline_stats['updated_size']} bytes"
f"{baseline_stats['delta']})"
)
after_texts = [res.text for res in recompute_stats["after_results"]]
baseline_after_texts = [res.text for res in baseline_stats["after_results"]]
if after_texts == baseline_after_texts:
print(
"[no-recompute] Search results match recompute baseline; see above for the shared output."
)
else:
print("[no-recompute] WARNING: search results differ from recompute baseline.")
if baseline_stats["metadata"]:
meta_view = {k: baseline_stats["metadata"].get(k) for k in ("is_compact", "is_pruned")}
print("[no-recompute] metadata snapshot:")
print(json.dumps(meta_view, indent=2))
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,35 @@
"""
Grep Search Example
Shows how to use grep-based text search instead of semantic search.
Useful when you need exact text matches rather than meaning-based results.
"""
from leann import LeannSearcher
# Load your index
searcher = LeannSearcher("my-documents.leann")
# Regular semantic search
print("=== Semantic Search ===")
results = searcher.search("machine learning algorithms", top_k=3)
for result in results:
print(f"Score: {result.score:.3f}")
print(f"Text: {result.text[:80]}...")
print()
# Grep-based search for exact text matches
print("=== Grep Search ===")
results = searcher.search("def train_model", top_k=3, use_grep=True)
for result in results:
print(f"Score: {result.score}")
print(f"Text: {result.text[:80]}...")
print()
# Find specific error messages
error_results = searcher.search("FileNotFoundError", use_grep=True)
print(f"Found {len(error_results)} files mentioning FileNotFoundError")
# Search for function definitions
func_results = searcher.search("class SearchResult", use_grep=True, top_k=5)
print(f"Found {len(func_results)} class definitions")

28
llms.txt Normal file
View File

@@ -0,0 +1,28 @@
# llms.txt — LEANN MCP and Agent Integration
product: LEANN
homepage: https://github.com/yichuan-w/LEANN
contact: https://github.com/yichuan-w/LEANN/issues
# Installation
install: uv tool install leann-core --with leann
# MCP Server Entry Point
mcp.server: leann_mcp
mcp.protocol_version: 2024-11-05
# Tools
mcp.tools: leann_list, leann_search
mcp.tool.leann_list.description: List available LEANN indexes
mcp.tool.leann_list.input: {}
mcp.tool.leann_search.description: Semantic search across a named LEANN index
mcp.tool.leann_search.input.index_name: string, required
mcp.tool.leann_search.input.query: string, required
mcp.tool.leann_search.input.top_k: integer, optional, default=5, min=1, max=20
mcp.tool.leann_search.input.complexity: integer, optional, default=32, min=16, max=128
# Notes
note: Build indexes with `leann build <name> --docs <files...>` before searching.
example.add: claude mcp add --scope user leann-server -- leann_mcp
example.verify: claude mcp list | cat

View File

@@ -343,7 +343,8 @@ class DiskannSearcher(BaseSearcher):
"full_index_prefix": full_index_prefix, "full_index_prefix": full_index_prefix,
"num_threads": self.num_threads, "num_threads": self.num_threads,
"num_nodes_to_cache": kwargs.get("num_nodes_to_cache", 0), "num_nodes_to_cache": kwargs.get("num_nodes_to_cache", 0),
"cache_mechanism": 1, # 1 -> initialize cache using sample_data; 2 -> ready cache without init; others disable cache
"cache_mechanism": kwargs.get("cache_mechanism", 1),
"pq_prefix": "", "pq_prefix": "",
"partition_prefix": partition_prefix, "partition_prefix": partition_prefix,
} }

View File

@@ -10,7 +10,7 @@ import sys
import threading import threading
import time import time
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Any, Optional
import numpy as np import numpy as np
import zmq import zmq
@@ -32,6 +32,16 @@ if not logger.handlers:
logger.propagate = False logger.propagate = False
_RAW_PROVIDER_OPTIONS = os.getenv("LEANN_EMBEDDING_OPTIONS")
try:
PROVIDER_OPTIONS: dict[str, Any] = (
json.loads(_RAW_PROVIDER_OPTIONS) if _RAW_PROVIDER_OPTIONS else {}
)
except json.JSONDecodeError:
logger.warning("Failed to parse LEANN_EMBEDDING_OPTIONS; ignoring provider options")
PROVIDER_OPTIONS = {}
def create_diskann_embedding_server( def create_diskann_embedding_server(
passages_file: Optional[str] = None, passages_file: Optional[str] = None,
zmq_port: int = 5555, zmq_port: int = 5555,
@@ -181,7 +191,12 @@ def create_diskann_embedding_server(
logger.debug(f"Text lengths: {[len(t) for t in texts[:5]]}") # Show first 5 logger.debug(f"Text lengths: {[len(t) for t in texts[:5]]}") # Show first 5
# Process embeddings using unified computation # Process embeddings using unified computation
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode) embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info( logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}" f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
) )
@@ -296,7 +311,12 @@ def create_diskann_embedding_server(
continue continue
# Process the request # Process the request
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode) embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(f"Computed embeddings shape: {embeddings.shape}") logger.info(f"Computed embeddings shape: {embeddings.shape}")
# Validation # Validation

View File

@@ -1,11 +1,11 @@
[build-system] [build-system]
requires = ["scikit-build-core>=0.10", "pybind11>=2.12.0", "numpy"] requires = ["scikit-build-core>=0.10", "pybind11>=2.12.0", "numpy", "cmake>=3.30"]
build-backend = "scikit_build_core.build" build-backend = "scikit_build_core.build"
[project] [project]
name = "leann-backend-diskann" name = "leann-backend-diskann"
version = "0.3.2" version = "0.3.4"
dependencies = ["leann-core==0.3.2", "numpy", "protobuf>=3.19.0"] dependencies = ["leann-core==0.3.4", "numpy", "protobuf>=3.19.0"]
[tool.scikit-build] [tool.scikit-build]
# Key: simplified CMake path # Key: simplified CMake path

View File

@@ -49,9 +49,28 @@ set(BUILD_TESTING OFF CACHE BOOL "" FORCE)
set(FAISS_ENABLE_C_API OFF CACHE BOOL "" FORCE) set(FAISS_ENABLE_C_API OFF CACHE BOOL "" FORCE)
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE) set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
# Disable additional SIMD versions to speed up compilation # Disable x86-specific SIMD optimizations (important for ARM64 compatibility)
set(FAISS_ENABLE_AVX2 OFF CACHE BOOL "" FORCE) set(FAISS_ENABLE_AVX2 OFF CACHE BOOL "" FORCE)
set(FAISS_ENABLE_AVX512 OFF CACHE BOOL "" FORCE) set(FAISS_ENABLE_AVX512 OFF CACHE BOOL "" FORCE)
set(FAISS_ENABLE_SSE4_1 OFF CACHE BOOL "" FORCE)
# ARM64-specific configuration
if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64")
message(STATUS "Configuring Faiss for ARM64 architecture")
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
# Use SVE optimization level for ARM64 Linux (as seen in Faiss conda build)
set(FAISS_OPT_LEVEL "sve" CACHE STRING "" FORCE)
message(STATUS "Setting FAISS_OPT_LEVEL to 'sve' for ARM64 Linux")
else()
# Use generic optimization for other ARM64 platforms (like macOS)
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
message(STATUS "Setting FAISS_OPT_LEVEL to 'generic' for ARM64 ${CMAKE_SYSTEM_NAME}")
endif()
# ARM64 compatibility: Faiss submodule has been modified to fix x86 header inclusion
message(STATUS "Using ARM64-compatible Faiss submodule")
endif()
# Additional optimization options from INSTALL.md # Additional optimization options from INSTALL.md
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "" FORCE) set(CMAKE_BUILD_TYPE "Release" CACHE STRING "" FORCE)

View File

@@ -5,6 +5,8 @@ import os
import struct import struct
import sys import sys
import time import time
from dataclasses import dataclass
from typing import Any, Optional
import numpy as np import numpy as np
@@ -237,6 +239,288 @@ def write_compact_format(
f_out.write(storage_data) f_out.write(storage_data)
@dataclass
class HNSWComponents:
original_hnsw_data: dict[str, Any]
assign_probas_np: np.ndarray
cum_nneighbor_per_level_np: np.ndarray
levels_np: np.ndarray
is_compact: bool
compact_level_ptr: Optional[np.ndarray] = None
compact_node_offsets_np: Optional[np.ndarray] = None
compact_neighbors_data: Optional[list[int]] = None
offsets_np: Optional[np.ndarray] = None
neighbors_np: Optional[np.ndarray] = None
storage_fourcc: int = NULL_INDEX_FOURCC
storage_data: bytes = b""
def _read_hnsw_structure(f) -> HNSWComponents:
original_hnsw_data: dict[str, Any] = {}
hnsw_index_fourcc = read_struct(f, "<I")
if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS:
raise ValueError(
f"Unexpected HNSW FourCC: {hnsw_index_fourcc:08x}. Expected one of {EXPECTED_HNSW_FOURCCS}."
)
original_hnsw_data["index_fourcc"] = hnsw_index_fourcc
original_hnsw_data["d"] = read_struct(f, "<i")
original_hnsw_data["ntotal"] = read_struct(f, "<q")
original_hnsw_data["dummy1"] = read_struct(f, "<q")
original_hnsw_data["dummy2"] = read_struct(f, "<q")
original_hnsw_data["is_trained"] = read_struct(f, "?")
original_hnsw_data["metric_type"] = read_struct(f, "<i")
original_hnsw_data["metric_arg"] = 0.0
if original_hnsw_data["metric_type"] > 1:
original_hnsw_data["metric_arg"] = read_struct(f, "<f")
assign_probas_np = read_numpy_vector(f, np.float64, "d")
cum_nneighbor_per_level_np = read_numpy_vector(f, np.int32, "i")
levels_np = read_numpy_vector(f, np.int32, "i")
ntotal = len(levels_np)
if ntotal != original_hnsw_data["ntotal"]:
original_hnsw_data["ntotal"] = ntotal
pos_before_compact = f.tell()
is_compact_flag = None
try:
is_compact_flag = read_struct(f, "<?")
except EOFError:
is_compact_flag = None
if is_compact_flag:
compact_level_ptr = read_numpy_vector(f, np.uint64, "Q")
compact_node_offsets_np = read_numpy_vector(f, np.uint64, "Q")
original_hnsw_data["entry_point"] = read_struct(f, "<i")
original_hnsw_data["max_level"] = read_struct(f, "<i")
original_hnsw_data["efConstruction"] = read_struct(f, "<i")
original_hnsw_data["efSearch"] = read_struct(f, "<i")
original_hnsw_data["dummy_upper_beam"] = read_struct(f, "<i")
storage_fourcc = read_struct(f, "<I")
compact_neighbors_data_np = read_numpy_vector(f, np.int32, "i")
compact_neighbors_data = compact_neighbors_data_np.tolist()
storage_data = f.read()
return HNSWComponents(
original_hnsw_data=original_hnsw_data,
assign_probas_np=assign_probas_np,
cum_nneighbor_per_level_np=cum_nneighbor_per_level_np,
levels_np=levels_np,
is_compact=True,
compact_level_ptr=compact_level_ptr,
compact_node_offsets_np=compact_node_offsets_np,
compact_neighbors_data=compact_neighbors_data,
storage_fourcc=storage_fourcc,
storage_data=storage_data,
)
# Non-compact case
f.seek(pos_before_compact)
pos_before_probe = f.tell()
try:
suspected_flag = read_struct(f, "<B")
if suspected_flag != 0x00:
f.seek(pos_before_probe)
except EOFError:
f.seek(pos_before_probe)
offsets_np = read_numpy_vector(f, np.uint64, "Q")
neighbors_np = read_numpy_vector(f, np.int32, "i")
original_hnsw_data["entry_point"] = read_struct(f, "<i")
original_hnsw_data["max_level"] = read_struct(f, "<i")
original_hnsw_data["efConstruction"] = read_struct(f, "<i")
original_hnsw_data["efSearch"] = read_struct(f, "<i")
original_hnsw_data["dummy_upper_beam"] = read_struct(f, "<i")
storage_fourcc = NULL_INDEX_FOURCC
storage_data = b""
try:
storage_fourcc = read_struct(f, "<I")
storage_data = f.read()
except EOFError:
storage_fourcc = NULL_INDEX_FOURCC
return HNSWComponents(
original_hnsw_data=original_hnsw_data,
assign_probas_np=assign_probas_np,
cum_nneighbor_per_level_np=cum_nneighbor_per_level_np,
levels_np=levels_np,
is_compact=False,
offsets_np=offsets_np,
neighbors_np=neighbors_np,
storage_fourcc=storage_fourcc,
storage_data=storage_data,
)
def _read_hnsw_structure_from_file(path: str) -> HNSWComponents:
with open(path, "rb") as f:
return _read_hnsw_structure(f)
def write_original_format(
f_out,
original_hnsw_data,
assign_probas_np,
cum_nneighbor_per_level_np,
levels_np,
offsets_np,
neighbors_np,
storage_fourcc,
storage_data,
):
"""Write non-compact HNSW data in original FAISS order."""
f_out.write(struct.pack("<I", original_hnsw_data["index_fourcc"]))
f_out.write(struct.pack("<i", original_hnsw_data["d"]))
f_out.write(struct.pack("<q", original_hnsw_data["ntotal"]))
f_out.write(struct.pack("<q", original_hnsw_data["dummy1"]))
f_out.write(struct.pack("<q", original_hnsw_data["dummy2"]))
f_out.write(struct.pack("<?", original_hnsw_data["is_trained"]))
f_out.write(struct.pack("<i", original_hnsw_data["metric_type"]))
if original_hnsw_data["metric_type"] > 1:
f_out.write(struct.pack("<f", original_hnsw_data["metric_arg"]))
write_numpy_vector(f_out, assign_probas_np, "d")
write_numpy_vector(f_out, cum_nneighbor_per_level_np, "i")
write_numpy_vector(f_out, levels_np, "i")
write_numpy_vector(f_out, offsets_np, "Q")
write_numpy_vector(f_out, neighbors_np, "i")
f_out.write(struct.pack("<i", original_hnsw_data["entry_point"]))
f_out.write(struct.pack("<i", original_hnsw_data["max_level"]))
f_out.write(struct.pack("<i", original_hnsw_data["efConstruction"]))
f_out.write(struct.pack("<i", original_hnsw_data["efSearch"]))
f_out.write(struct.pack("<i", original_hnsw_data["dummy_upper_beam"]))
f_out.write(struct.pack("<I", storage_fourcc))
if storage_fourcc != NULL_INDEX_FOURCC and storage_data:
f_out.write(storage_data)
def prune_hnsw_embeddings(input_filename: str, output_filename: str) -> bool:
"""Rewrite an HNSW index while dropping the embedded storage section."""
start_time = time.time()
try:
with open(input_filename, "rb") as f_in, open(output_filename, "wb") as f_out:
original_hnsw_data: dict[str, Any] = {}
hnsw_index_fourcc = read_struct(f_in, "<I")
if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS:
print(
f"Error: Expected HNSW Index FourCC ({list(EXPECTED_HNSW_FOURCCS)}), got {hnsw_index_fourcc:08x}.",
file=sys.stderr,
)
return False
original_hnsw_data["index_fourcc"] = hnsw_index_fourcc
original_hnsw_data["d"] = read_struct(f_in, "<i")
original_hnsw_data["ntotal"] = read_struct(f_in, "<q")
original_hnsw_data["dummy1"] = read_struct(f_in, "<q")
original_hnsw_data["dummy2"] = read_struct(f_in, "<q")
original_hnsw_data["is_trained"] = read_struct(f_in, "?")
original_hnsw_data["metric_type"] = read_struct(f_in, "<i")
original_hnsw_data["metric_arg"] = 0.0
if original_hnsw_data["metric_type"] > 1:
original_hnsw_data["metric_arg"] = read_struct(f_in, "<f")
assign_probas_np = read_numpy_vector(f_in, np.float64, "d")
cum_nneighbor_per_level_np = read_numpy_vector(f_in, np.int32, "i")
levels_np = read_numpy_vector(f_in, np.int32, "i")
ntotal = len(levels_np)
if ntotal != original_hnsw_data["ntotal"]:
original_hnsw_data["ntotal"] = ntotal
pos_before_compact = f_in.tell()
is_compact_flag = None
try:
is_compact_flag = read_struct(f_in, "<?")
except EOFError:
is_compact_flag = None
if is_compact_flag:
compact_level_ptr = read_numpy_vector(f_in, np.uint64, "Q")
compact_node_offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
original_hnsw_data["max_level"] = read_struct(f_in, "<i")
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
_storage_fourcc = read_struct(f_in, "<I")
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i")
compact_neighbors_data = compact_neighbors_data_np.tolist()
_storage_data = f_in.read()
write_compact_format(
f_out,
original_hnsw_data,
assign_probas_np,
cum_nneighbor_per_level_np,
levels_np,
compact_level_ptr,
compact_node_offsets_np,
compact_neighbors_data,
NULL_INDEX_FOURCC,
b"",
)
else:
f_in.seek(pos_before_compact)
pos_before_probe = f_in.tell()
try:
suspected_flag = read_struct(f_in, "<B")
if suspected_flag != 0x00:
f_in.seek(pos_before_probe)
except EOFError:
f_in.seek(pos_before_probe)
offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
neighbors_np = read_numpy_vector(f_in, np.int32, "i")
original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
original_hnsw_data["max_level"] = read_struct(f_in, "<i")
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
_storage_fourcc = None
_storage_data = b""
try:
_storage_fourcc = read_struct(f_in, "<I")
_storage_data = f_in.read()
except EOFError:
_storage_fourcc = NULL_INDEX_FOURCC
write_original_format(
f_out,
original_hnsw_data,
assign_probas_np,
cum_nneighbor_per_level_np,
levels_np,
offsets_np,
neighbors_np,
NULL_INDEX_FOURCC,
b"",
)
print(f"[{time.time() - start_time:.2f}s] Pruned embeddings from {input_filename}")
return True
except Exception as exc:
print(f"Failed to prune embeddings: {exc}", file=sys.stderr)
return False
# --- Main Conversion Logic --- # --- Main Conversion Logic ---
@@ -700,6 +984,29 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
pass pass
def prune_hnsw_embeddings_inplace(index_filename: str) -> bool:
"""Convenience wrapper to prune embeddings in-place."""
temp_path = f"{index_filename}.prune.tmp"
success = prune_hnsw_embeddings(index_filename, temp_path)
if success:
try:
os.replace(temp_path, index_filename)
except Exception as exc: # pragma: no cover - defensive
logger.error(f"Failed to replace original index with pruned version: {exc}")
try:
os.remove(temp_path)
except OSError:
pass
return False
else:
try:
os.remove(temp_path)
except OSError:
pass
return success
# --- Script Execution --- # --- Script Execution ---
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(

View File

@@ -14,7 +14,7 @@ from leann.interface import (
from leann.registry import register_backend from leann.registry import register_backend
from leann.searcher_base import BaseSearcher from leann.searcher_base import BaseSearcher
from .convert_to_csr import convert_hnsw_graph_to_csr from .convert_to_csr import convert_hnsw_graph_to_csr, prune_hnsw_embeddings_inplace
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -90,8 +90,19 @@ class HNSWBuilder(LeannBackendBuilderInterface):
index_file = index_dir / f"{index_prefix}.index" index_file = index_dir / f"{index_prefix}.index"
faiss.write_index(index, str(index_file)) faiss.write_index(index, str(index_file))
# Persist ID map so searcher can map FAISS integer labels back to passage IDs
try:
idmap_file = index_dir / f"{index_prefix}.ids.txt"
with open(idmap_file, "w", encoding="utf-8") as f:
for id_str in ids:
f.write(str(id_str) + "\n")
except Exception as e:
logger.warning(f"Failed to write ID map: {e}")
if self.is_compact: if self.is_compact:
self._convert_to_csr(index_file) self._convert_to_csr(index_file)
elif self.is_recompute:
prune_hnsw_embeddings_inplace(str(index_file))
def _convert_to_csr(self, index_file: Path): def _convert_to_csr(self, index_file: Path):
"""Convert built index to CSR format""" """Convert built index to CSR format"""
@@ -133,10 +144,10 @@ class HNSWSearcher(BaseSearcher):
if metric_enum is None: if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.") raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
self.is_compact, self.is_pruned = ( backend_meta_kwargs = self.meta.get("backend_kwargs", {})
self.meta.get("is_compact", True), self.is_compact = self.meta.get("is_compact", backend_meta_kwargs.get("is_compact", True))
self.meta.get("is_pruned", True), default_pruned = backend_meta_kwargs.get("is_recompute", self.is_compact)
) self.is_pruned = bool(self.meta.get("is_pruned", default_pruned))
index_file = self.index_dir / f"{self.index_path.stem}.index" index_file = self.index_dir / f"{self.index_path.stem}.index"
if not index_file.exists(): if not index_file.exists():
@@ -150,6 +161,16 @@ class HNSWSearcher(BaseSearcher):
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config) self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
# Load ID map if available
self._id_map: list[str] = []
try:
idmap_file = self.index_dir / f"{self.index_path.stem}.ids.txt"
if idmap_file.exists():
with open(idmap_file, encoding="utf-8") as f:
self._id_map = [line.rstrip("\n") for line in f]
except Exception as e:
logger.warning(f"Failed to load ID map: {e}")
def search( def search(
self, self,
query: np.ndarray, query: np.ndarray,
@@ -248,6 +269,19 @@ class HNSWSearcher(BaseSearcher):
) )
search_time = time.time() - search_time search_time = time.time() - search_time
logger.info(f" Search time in HNSWSearcher.search() backend: {search_time} seconds") logger.info(f" Search time in HNSWSearcher.search() backend: {search_time} seconds")
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels] if self._id_map:
def map_label(x: int) -> str:
if 0 <= x < len(self._id_map):
return self._id_map[x]
return str(x)
string_labels = [
[map_label(int(label)) for label in batch_labels] for batch_labels in labels
]
else:
string_labels = [
[str(int_label) for int_label in batch_labels] for batch_labels in labels
]
return {"labels": string_labels, "distances": distances} return {"labels": string_labels, "distances": distances}

View File

@@ -10,7 +10,7 @@ import sys
import threading import threading
import time import time
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Any, Optional
import msgpack import msgpack
import numpy as np import numpy as np
@@ -24,14 +24,36 @@ logger = logging.getLogger(__name__)
log_level = getattr(logging, LOG_LEVEL, logging.WARNING) log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
logger.setLevel(log_level) logger.setLevel(log_level)
# Ensure we have a handler if none exists # Ensure we have handlers if none exist
if not logger.handlers: if not logger.handlers:
handler = logging.StreamHandler() stream_handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter) stream_handler.setFormatter(formatter)
logger.addHandler(handler) logger.addHandler(stream_handler)
log_path = os.getenv("LEANN_HNSW_LOG_PATH")
if log_path:
try:
file_handler = logging.FileHandler(log_path, mode="a", encoding="utf-8")
file_formatter = logging.Formatter(
"%(asctime)s - %(levelname)s - [pid=%(process)d] %(message)s"
)
file_handler.setFormatter(file_formatter)
logger.addHandler(file_handler)
except Exception as exc: # pragma: no cover - best effort logging
logger.warning(f"Failed to attach file handler for log path {log_path}: {exc}")
logger.propagate = False logger.propagate = False
_RAW_PROVIDER_OPTIONS = os.getenv("LEANN_EMBEDDING_OPTIONS")
try:
PROVIDER_OPTIONS: dict[str, Any] = (
json.loads(_RAW_PROVIDER_OPTIONS) if _RAW_PROVIDER_OPTIONS else {}
)
except json.JSONDecodeError:
logger.warning("Failed to parse LEANN_EMBEDDING_OPTIONS; ignoring provider options")
PROVIDER_OPTIONS = {}
def create_hnsw_embedding_server( def create_hnsw_embedding_server(
passages_file: Optional[str] = None, passages_file: Optional[str] = None,
@@ -92,6 +114,35 @@ def create_hnsw_embedding_server(
embedding_dim = 0 embedding_dim = 0
logger.info(f"Loaded PassageManager with {len(passages)} passages from metadata") logger.info(f"Loaded PassageManager with {len(passages)} passages from metadata")
# Attempt to load ID map (maps FAISS integer labels -> passage IDs)
id_map: list[str] = []
try:
meta_path = Path(passages_file)
base = meta_path.name
if base.endswith(".meta.json"):
base = base[: -len(".meta.json")] # e.g., laion_index.leann
if base.endswith(".leann"):
base = base[: -len(".leann")] # e.g., laion_index
idmap_file = meta_path.parent / f"{base}.ids.txt"
if idmap_file.exists():
with open(idmap_file, encoding="utf-8") as f:
id_map = [line.rstrip("\n") for line in f]
logger.info(f"Loaded ID map with {len(id_map)} entries from {idmap_file}")
else:
logger.warning(f"ID map file not found at {idmap_file}; will use raw labels")
except Exception as e:
logger.warning(f"Failed to load ID map: {e}")
def _map_node_id(nid) -> str:
try:
if id_map is not None and len(id_map) > 0 and isinstance(nid, (int, np.integer)):
idx = int(nid)
if 0 <= idx < len(id_map):
return id_map[idx]
except Exception:
pass
return str(nid)
# (legacy ZMQ thread removed; using shutdown-capable server only) # (legacy ZMQ thread removed; using shutdown-capable server only)
def zmq_server_thread_with_shutdown(shutdown_event): def zmq_server_thread_with_shutdown(shutdown_event):
@@ -138,7 +189,12 @@ def create_hnsw_embedding_server(
): ):
last_request_type = "text" last_request_type = "text"
last_request_length = len(request) last_request_length = len(request)
embeddings = compute_embeddings(request, model_name, mode=embedding_mode) embeddings = compute_embeddings(
request,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
rep_socket.send(msgpack.packb(embeddings.tolist())) rep_socket.send(msgpack.packb(embeddings.tolist()))
e2e_end = time.time() e2e_end = time.time()
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s") logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
@@ -168,13 +224,14 @@ def create_hnsw_embedding_server(
found_indices: list[int] = [] found_indices: list[int] = []
for idx, nid in enumerate(node_ids): for idx, nid in enumerate(node_ids):
try: try:
passage_data = passages.get_passage(str(nid)) passage_id = _map_node_id(nid)
passage_data = passages.get_passage(passage_id)
txt = passage_data.get("text", "") txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0: if isinstance(txt, str) and len(txt) > 0:
texts.append(txt) texts.append(txt)
found_indices.append(idx) found_indices.append(idx)
else: else:
logger.error(f"Empty text for passage ID {nid}") logger.error(f"Empty text for passage ID {passage_id}")
except KeyError: except KeyError:
logger.error(f"Passage ID {nid} not found") logger.error(f"Passage ID {nid} not found")
except Exception as e: except Exception as e:
@@ -187,7 +244,10 @@ def create_hnsw_embedding_server(
if texts: if texts:
try: try:
embeddings = compute_embeddings( embeddings = compute_embeddings(
texts, model_name, mode=embedding_mode texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
) )
logger.info( logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}" f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
@@ -238,13 +298,14 @@ def create_hnsw_embedding_server(
found_indices: list[int] = [] found_indices: list[int] = []
for idx, nid in enumerate(node_ids): for idx, nid in enumerate(node_ids):
try: try:
passage_data = passages.get_passage(str(nid)) passage_id = _map_node_id(nid)
passage_data = passages.get_passage(passage_id)
txt = passage_data.get("text", "") txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0: if isinstance(txt, str) and len(txt) > 0:
texts.append(txt) texts.append(txt)
found_indices.append(idx) found_indices.append(idx)
else: else:
logger.error(f"Empty text for passage ID {nid}") logger.error(f"Empty text for passage ID {passage_id}")
except KeyError: except KeyError:
logger.error(f"Passage with ID {nid} not found") logger.error(f"Passage with ID {nid} not found")
except Exception as e: except Exception as e:
@@ -252,7 +313,12 @@ def create_hnsw_embedding_server(
if texts: if texts:
try: try:
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode) embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info( logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}" f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
) )

View File

@@ -6,10 +6,10 @@ build-backend = "scikit_build_core.build"
[project] [project]
name = "leann-backend-hnsw" name = "leann-backend-hnsw"
version = "0.3.2" version = "0.3.4"
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit." description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
dependencies = [ dependencies = [
"leann-core==0.3.2", "leann-core==0.3.4",
"numpy", "numpy",
"pyzmq>=23.0.0", "pyzmq>=23.0.0",
"msgpack>=1.0.0", "msgpack>=1.0.0",

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "leann-core" name = "leann-core"
version = "0.3.2" version = "0.3.4"
description = "Core API and plugin system for LEANN" description = "Core API and plugin system for LEANN"
readme = "README.md" readme = "README.md"
requires-python = ">=3.9" requires-python = ">=3.9"

View File

@@ -6,6 +6,8 @@ with the correct, original embedding logic from the user's reference code.
import json import json
import logging import logging
import pickle import pickle
import re
import subprocess
import time import time
import warnings import warnings
from dataclasses import dataclass, field from dataclasses import dataclass, field
@@ -13,6 +15,7 @@ from pathlib import Path
from typing import Any, Literal, Optional, Union from typing import Any, Literal, Optional, Union
import numpy as np import numpy as np
from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
from leann.interface import LeannBackendSearcherInterface from leann.interface import LeannBackendSearcherInterface
@@ -36,6 +39,7 @@ def compute_embeddings(
use_server: bool = True, use_server: bool = True,
port: Optional[int] = None, port: Optional[int] = None,
is_build=False, is_build=False,
provider_options: Optional[dict[str, Any]] = None,
) -> np.ndarray: ) -> np.ndarray:
""" """
Computes embeddings using different backends. Computes embeddings using different backends.
@@ -69,6 +73,7 @@ def compute_embeddings(
model_name, model_name,
mode=mode, mode=mode,
is_build=is_build, is_build=is_build,
provider_options=provider_options,
) )
@@ -275,6 +280,7 @@ class LeannBuilder:
embedding_model: str = "facebook/contriever", embedding_model: str = "facebook/contriever",
dimensions: Optional[int] = None, dimensions: Optional[int] = None,
embedding_mode: str = "sentence-transformers", embedding_mode: str = "sentence-transformers",
embedding_options: Optional[dict[str, Any]] = None,
**backend_kwargs, **backend_kwargs,
): ):
self.backend_name = backend_name self.backend_name = backend_name
@@ -297,6 +303,7 @@ class LeannBuilder:
self.embedding_model = embedding_model self.embedding_model = embedding_model
self.dimensions = dimensions self.dimensions = dimensions
self.embedding_mode = embedding_mode self.embedding_mode = embedding_mode
self.embedding_options = embedding_options or {}
# Check if we need to use cosine distance for normalized embeddings # Check if we need to use cosine distance for normalized embeddings
normalized_embeddings_models = { normalized_embeddings_models = {
@@ -404,6 +411,7 @@ class LeannBuilder:
self.embedding_model, self.embedding_model,
self.embedding_mode, self.embedding_mode,
use_server=False, use_server=False,
provider_options=self.embedding_options,
)[0] )[0]
) )
path = Path(index_path) path = Path(index_path)
@@ -443,8 +451,20 @@ class LeannBuilder:
self.embedding_mode, self.embedding_mode,
use_server=False, use_server=False,
is_build=True, is_build=True,
provider_options=self.embedding_options,
) )
string_ids = [chunk["id"] for chunk in self.chunks] string_ids = [chunk["id"] for chunk in self.chunks]
# Persist ID map alongside index so backends that return integer labels can remap to passage IDs
try:
idmap_file = (
index_dir
/ f"{index_name[: -len('.leann')] if index_name.endswith('.leann') else index_name}.ids.txt"
)
with open(idmap_file, "w", encoding="utf-8") as f:
for sid in string_ids:
f.write(str(sid) + "\n")
except Exception:
pass
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions} current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
builder_instance = self.backend_factory.builder(**current_backend_kwargs) builder_instance = self.backend_factory.builder(**current_backend_kwargs)
builder_instance.build(embeddings, string_ids, index_path, **current_backend_kwargs) builder_instance.build(embeddings, string_ids, index_path, **current_backend_kwargs)
@@ -469,14 +489,15 @@ class LeannBuilder:
], ],
} }
if self.embedding_options:
meta_data["embedding_options"] = self.embedding_options
# Add storage status flags for HNSW backend # Add storage status flags for HNSW backend
if self.backend_name == "hnsw": if self.backend_name == "hnsw":
is_compact = self.backend_kwargs.get("is_compact", True) is_compact = self.backend_kwargs.get("is_compact", True)
is_recompute = self.backend_kwargs.get("is_recompute", True) is_recompute = self.backend_kwargs.get("is_recompute", True)
meta_data["is_compact"] = is_compact meta_data["is_compact"] = is_compact
meta_data["is_pruned"] = ( meta_data["is_pruned"] = bool(is_recompute)
is_compact and is_recompute
) # Pruned only if compact and recompute
with open(leann_meta_path, "w", encoding="utf-8") as f: with open(leann_meta_path, "w", encoding="utf-8") as f:
json.dump(meta_data, f, indent=2) json.dump(meta_data, f, indent=2)
@@ -563,6 +584,17 @@ class LeannBuilder:
# Build the vector index using precomputed embeddings # Build the vector index using precomputed embeddings
string_ids = [str(id_val) for id_val in ids] string_ids = [str(id_val) for id_val in ids]
# Persist ID map (order == embeddings order)
try:
idmap_file = (
index_dir
/ f"{index_name[: -len('.leann')] if index_name.endswith('.leann') else index_name}.ids.txt"
)
with open(idmap_file, "w", encoding="utf-8") as f:
for sid in string_ids:
f.write(str(sid) + "\n")
except Exception:
pass
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions} current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
builder_instance = self.backend_factory.builder(**current_backend_kwargs) builder_instance = self.backend_factory.builder(**current_backend_kwargs)
builder_instance.build(embeddings, string_ids, index_path) builder_instance.build(embeddings, string_ids, index_path)
@@ -591,18 +623,166 @@ class LeannBuilder:
"embeddings_source": str(embeddings_file), "embeddings_source": str(embeddings_file),
} }
if self.embedding_options:
meta_data["embedding_options"] = self.embedding_options
# Add storage status flags for HNSW backend # Add storage status flags for HNSW backend
if self.backend_name == "hnsw": if self.backend_name == "hnsw":
is_compact = self.backend_kwargs.get("is_compact", True) is_compact = self.backend_kwargs.get("is_compact", True)
is_recompute = self.backend_kwargs.get("is_recompute", True) is_recompute = self.backend_kwargs.get("is_recompute", True)
meta_data["is_compact"] = is_compact meta_data["is_compact"] = is_compact
meta_data["is_pruned"] = is_compact and is_recompute meta_data["is_pruned"] = bool(is_recompute)
with open(leann_meta_path, "w", encoding="utf-8") as f: with open(leann_meta_path, "w", encoding="utf-8") as f:
json.dump(meta_data, f, indent=2) json.dump(meta_data, f, indent=2)
logger.info(f"Index built successfully from precomputed embeddings: {index_path}") logger.info(f"Index built successfully from precomputed embeddings: {index_path}")
def update_index(self, index_path: str):
"""Append new passages and vectors to an existing HNSW index."""
if not self.chunks:
raise ValueError("No new chunks provided for update.")
path = Path(index_path)
index_dir = path.parent
index_name = path.name
index_prefix = path.stem
meta_path = index_dir / f"{index_name}.meta.json"
passages_file = index_dir / f"{index_name}.passages.jsonl"
offset_file = index_dir / f"{index_name}.passages.idx"
index_file = index_dir / f"{index_prefix}.index"
if not meta_path.exists() or not passages_file.exists() or not offset_file.exists():
raise FileNotFoundError("Index metadata or passage files are missing; cannot update.")
if not index_file.exists():
raise FileNotFoundError(f"HNSW index file not found: {index_file}")
with open(meta_path, encoding="utf-8") as f:
meta = json.load(f)
backend_name = meta.get("backend_name")
if backend_name != self.backend_name:
raise ValueError(
f"Index was built with backend '{backend_name}', cannot update with '{self.backend_name}'."
)
meta_backend_kwargs = meta.get("backend_kwargs", {})
index_is_compact = meta.get("is_compact", meta_backend_kwargs.get("is_compact", True))
if index_is_compact:
raise ValueError(
"Compact HNSW indices do not support in-place updates. Rebuild required."
)
distance_metric = meta_backend_kwargs.get(
"distance_metric", self.backend_kwargs.get("distance_metric", "mips")
).lower()
needs_recompute = bool(
meta.get("is_pruned")
or meta_backend_kwargs.get("is_recompute")
or self.backend_kwargs.get("is_recompute")
)
with open(offset_file, "rb") as f:
offset_map: dict[str, int] = pickle.load(f)
existing_ids = set(offset_map.keys())
valid_chunks: list[dict[str, Any]] = []
for chunk in self.chunks:
text = chunk.get("text", "")
if not isinstance(text, str) or not text.strip():
continue
metadata = chunk.setdefault("metadata", {})
passage_id = chunk.get("id") or metadata.get("id")
if passage_id and passage_id in existing_ids:
raise ValueError(f"Passage ID '{passage_id}' already exists in the index.")
valid_chunks.append(chunk)
if not valid_chunks:
raise ValueError("No valid chunks to append.")
texts_to_embed = [chunk["text"] for chunk in valid_chunks]
embeddings = compute_embeddings(
texts_to_embed,
self.embedding_model,
self.embedding_mode,
use_server=False,
is_build=True,
provider_options=self.embedding_options,
)
embedding_dim = embeddings.shape[1]
expected_dim = meta.get("dimensions")
if expected_dim is not None and expected_dim != embedding_dim:
raise ValueError(
f"Dimension mismatch during update: existing index uses {expected_dim}, got {embedding_dim}."
)
from leann_backend_hnsw import faiss # type: ignore
embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
if distance_metric == "cosine":
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
norms[norms == 0] = 1
embeddings = embeddings / norms
index = faiss.read_index(str(index_file))
if hasattr(index, "is_recompute"):
index.is_recompute = needs_recompute
if getattr(index, "storage", None) is None:
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
storage_index = faiss.IndexFlatIP(index.d)
else:
storage_index = faiss.IndexFlatL2(index.d)
index.storage = storage_index
index.own_fields = True
if index.d != embedding_dim:
raise ValueError(
f"Existing index dimension ({index.d}) does not match new embeddings ({embedding_dim})."
)
base_id = index.ntotal
for offset, chunk in enumerate(valid_chunks):
new_id = str(base_id + offset)
chunk.setdefault("metadata", {})["id"] = new_id
chunk["id"] = new_id
index.add(embeddings.shape[0], faiss.swig_ptr(embeddings))
faiss.write_index(index, str(index_file))
with open(passages_file, "a", encoding="utf-8") as f:
for chunk in valid_chunks:
offset = f.tell()
json.dump(
{
"id": chunk["id"],
"text": chunk["text"],
"metadata": chunk.get("metadata", {}),
},
f,
ensure_ascii=False,
)
f.write("\n")
offset_map[chunk["id"]] = offset
with open(offset_file, "wb") as f:
pickle.dump(offset_map, f)
meta["total_passages"] = len(offset_map)
with open(meta_path, "w", encoding="utf-8") as f:
json.dump(meta, f, indent=2)
logger.info(
"Appended %d passages to index '%s'. New total: %d",
len(valid_chunks),
index_path,
len(offset_map),
)
self.chunks.clear()
if needs_recompute:
prune_hnsw_embeddings_inplace(str(index_file))
class LeannSearcher: class LeannSearcher:
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs): def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
@@ -626,6 +806,7 @@ class LeannSearcher:
self.embedding_model = self.meta_data["embedding_model"] self.embedding_model = self.meta_data["embedding_model"]
# Support both old and new format # Support both old and new format
self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers") self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
self.embedding_options = self.meta_data.get("embedding_options", {})
# Delegate portability handling to PassageManager # Delegate portability handling to PassageManager
self.passage_manager = PassageManager( self.passage_manager = PassageManager(
self.meta_data.get("passage_sources", []), metadata_file_path=self.meta_path_str self.meta_data.get("passage_sources", []), metadata_file_path=self.meta_path_str
@@ -637,6 +818,8 @@ class LeannSearcher:
raise ValueError(f"Backend '{backend_name}' not found.") raise ValueError(f"Backend '{backend_name}' not found.")
final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs} final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
final_kwargs["enable_warmup"] = enable_warmup final_kwargs["enable_warmup"] = enable_warmup
if self.embedding_options:
final_kwargs.setdefault("embedding_options", self.embedding_options)
self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher( self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
index_path, **final_kwargs index_path, **final_kwargs
) )
@@ -653,6 +836,7 @@ class LeannSearcher:
expected_zmq_port: int = 5557, expected_zmq_port: int = 5557,
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None, metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
batch_size: int = 0, batch_size: int = 0,
use_grep: bool = False,
**kwargs, **kwargs,
) -> list[SearchResult]: ) -> list[SearchResult]:
""" """
@@ -679,6 +863,10 @@ class LeannSearcher:
Returns: Returns:
List of SearchResult objects with text, metadata, and similarity scores List of SearchResult objects with text, metadata, and similarity scores
""" """
# Handle grep search
if use_grep:
return self._grep_search(query, top_k)
logger.info("🔍 LeannSearcher.search() called:") logger.info("🔍 LeannSearcher.search() called:")
logger.info(f" Query: '{query}'") logger.info(f" Query: '{query}'")
logger.info(f" Top_k: {top_k}") logger.info(f" Top_k: {top_k}")
@@ -795,9 +983,96 @@ class LeannSearcher:
logger.info(f" {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}") logger.info(f" {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}")
return enriched_results return enriched_results
def _find_jsonl_file(self) -> Optional[str]:
"""Find the .jsonl file containing raw passages for grep search"""
index_path = Path(self.meta_path_str).parent
potential_files = [
index_path / "documents.leann.passages.jsonl",
index_path.parent / "documents.leann.passages.jsonl",
]
for file_path in potential_files:
if file_path.exists():
return str(file_path)
return None
def _grep_search(self, query: str, top_k: int = 5) -> list[SearchResult]:
"""Perform grep-based search on raw passages"""
jsonl_file = self._find_jsonl_file()
if not jsonl_file:
raise FileNotFoundError("No .jsonl passages file found for grep search")
try:
cmd = ["grep", "-i", "-n", query, jsonl_file]
result = subprocess.run(cmd, capture_output=True, text=True, check=False)
if result.returncode == 1:
return []
elif result.returncode != 0:
raise RuntimeError(f"Grep failed: {result.stderr}")
matches = []
for line in result.stdout.strip().split("\n"):
if not line:
continue
parts = line.split(":", 1)
if len(parts) != 2:
continue
try:
data = json.loads(parts[1])
text = data.get("text", "")
score = text.lower().count(query.lower())
matches.append(
SearchResult(
id=data.get("id", parts[0]),
text=text,
metadata=data.get("metadata", {}),
score=float(score),
)
)
except json.JSONDecodeError:
continue
matches.sort(key=lambda x: x.score, reverse=True)
return matches[:top_k]
except FileNotFoundError:
raise RuntimeError(
"grep command not found. Please install grep or use semantic search."
)
def _python_regex_search(self, query: str, top_k: int = 5) -> list[SearchResult]:
"""Fallback regex search"""
jsonl_file = self._find_jsonl_file()
if not jsonl_file:
raise FileNotFoundError("No .jsonl file found")
pattern = re.compile(re.escape(query), re.IGNORECASE)
matches = []
with open(jsonl_file, encoding="utf-8") as f:
for line_num, line in enumerate(f, 1):
if pattern.search(line):
try:
data = json.loads(line.strip())
matches.append(
SearchResult(
id=data.get("id", str(line_num)),
text=data.get("text", ""),
metadata=data.get("metadata", {}),
score=float(len(pattern.findall(data.get("text", "")))),
)
)
except json.JSONDecodeError:
continue
matches.sort(key=lambda x: x.score, reverse=True)
return matches[:top_k]
def cleanup(self): def cleanup(self):
"""Explicitly cleanup embedding server resources. """Explicitly cleanup embedding server resources.
This method should be called after you're done using the searcher, This method should be called after you're done using the searcher,
especially in test environments or batch processing scenarios. especially in test environments or batch processing scenarios.
""" """
@@ -853,6 +1128,7 @@ class LeannChat:
expected_zmq_port: int = 5557, expected_zmq_port: int = 5557,
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None, metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
batch_size: int = 0, batch_size: int = 0,
use_grep: bool = False,
**search_kwargs, **search_kwargs,
): ):
if llm_kwargs is None: if llm_kwargs is None:

View File

@@ -12,6 +12,8 @@ from typing import Any, Optional
import torch import torch
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
# Configure logging # Configure logging
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -310,11 +312,12 @@ def search_hf_models(query: str, limit: int = 10) -> list[str]:
def validate_model_and_suggest( def validate_model_and_suggest(
model_name: str, llm_type: str, host: str = "http://localhost:11434" model_name: str, llm_type: str, host: Optional[str] = None
) -> Optional[str]: ) -> Optional[str]:
"""Validate model name and provide suggestions if invalid""" """Validate model name and provide suggestions if invalid"""
if llm_type == "ollama": if llm_type == "ollama":
available_models = check_ollama_models(host) resolved_host = resolve_ollama_host(host)
available_models = check_ollama_models(resolved_host)
if available_models and model_name not in available_models: if available_models and model_name not in available_models:
error_msg = f"Model '{model_name}' not found in your local Ollama installation." error_msg = f"Model '{model_name}' not found in your local Ollama installation."
@@ -457,19 +460,19 @@ class LLMInterface(ABC):
class OllamaChat(LLMInterface): class OllamaChat(LLMInterface):
"""LLM interface for Ollama models.""" """LLM interface for Ollama models."""
def __init__(self, model: str = "llama3:8b", host: str = "http://localhost:11434"): def __init__(self, model: str = "llama3:8b", host: Optional[str] = None):
self.model = model self.model = model
self.host = host self.host = resolve_ollama_host(host)
logger.info(f"Initializing OllamaChat with model='{model}' and host='{host}'") logger.info(f"Initializing OllamaChat with model='{model}' and host='{self.host}'")
try: try:
import requests import requests
# Check if the Ollama server is responsive # Check if the Ollama server is responsive
if host: if self.host:
requests.get(host) requests.get(self.host)
# Pre-check model availability with helpful suggestions # Pre-check model availability with helpful suggestions
model_error = validate_model_and_suggest(model, "ollama", host) model_error = validate_model_and_suggest(model, "ollama", self.host)
if model_error: if model_error:
raise ValueError(model_error) raise ValueError(model_error)
@@ -478,9 +481,11 @@ class OllamaChat(LLMInterface):
"The 'requests' library is required for Ollama. Please install it with 'pip install requests'." "The 'requests' library is required for Ollama. Please install it with 'pip install requests'."
) )
except requests.exceptions.ConnectionError: except requests.exceptions.ConnectionError:
logger.error(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.") logger.error(
f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
)
raise ConnectionError( raise ConnectionError(
f"Could not connect to Ollama at {host}. Please ensure Ollama is running." f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
) )
def ask(self, prompt: str, **kwargs) -> str: def ask(self, prompt: str, **kwargs) -> str:
@@ -737,21 +742,31 @@ class GeminiChat(LLMInterface):
class OpenAIChat(LLMInterface): class OpenAIChat(LLMInterface):
"""LLM interface for OpenAI models.""" """LLM interface for OpenAI models."""
def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None): def __init__(
self,
model: str = "gpt-4o",
api_key: Optional[str] = None,
base_url: Optional[str] = None,
):
self.model = model self.model = model
self.api_key = api_key or os.getenv("OPENAI_API_KEY") self.base_url = resolve_openai_base_url(base_url)
self.api_key = resolve_openai_api_key(api_key)
if not self.api_key: if not self.api_key:
raise ValueError( raise ValueError(
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter." "OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter."
) )
logger.info(f"Initializing OpenAI Chat with model='{model}'") logger.info(
"Initializing OpenAI Chat with model='%s' and base_url='%s'",
model,
self.base_url,
)
try: try:
import openai import openai
self.client = openai.OpenAI(api_key=self.api_key) self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'." "The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'."
@@ -841,12 +856,16 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
if llm_type == "ollama": if llm_type == "ollama":
return OllamaChat( return OllamaChat(
model=model or "llama3:8b", model=model or "llama3:8b",
host=llm_config.get("host", "http://localhost:11434"), host=llm_config.get("host"),
) )
elif llm_type == "hf": elif llm_type == "hf":
return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat") return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
elif llm_type == "openai": elif llm_type == "openai":
return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key")) return OpenAIChat(
model=model or "gpt-4o",
api_key=llm_config.get("api_key"),
base_url=llm_config.get("base_url"),
)
elif llm_type == "gemini": elif llm_type == "gemini":
return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key")) return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
elif llm_type == "simulated": elif llm_type == "simulated":

View File

@@ -1,6 +1,6 @@
""" """
Enhanced chunking utilities with AST-aware code chunking support. Enhanced chunking utilities with AST-aware code chunking support.
Provides unified interface for both traditional and AST-based text chunking. Packaged within leann-core so installed wheels can import it reliably.
""" """
import logging import logging
@@ -22,30 +22,9 @@ CODE_EXTENSIONS = {
".jsx": "typescript", ".jsx": "typescript",
} }
# Default chunk parameters for different content types
DEFAULT_CHUNK_PARAMS = {
"code": {
"max_chunk_size": 512,
"chunk_overlap": 64,
},
"text": {
"chunk_size": 256,
"chunk_overlap": 128,
},
}
def detect_code_files(documents, code_extensions=None) -> tuple[list, list]: def detect_code_files(documents, code_extensions=None) -> tuple[list, list]:
""" """Separate documents into code files and regular text files."""
Separate documents into code files and regular text files.
Args:
documents: List of LlamaIndex Document objects
code_extensions: Dict mapping file extensions to languages (defaults to CODE_EXTENSIONS)
Returns:
Tuple of (code_documents, text_documents)
"""
if code_extensions is None: if code_extensions is None:
code_extensions = CODE_EXTENSIONS code_extensions = CODE_EXTENSIONS
@@ -53,16 +32,10 @@ def detect_code_files(documents, code_extensions=None) -> tuple[list, list]:
text_docs = [] text_docs = []
for doc in documents: for doc in documents:
# Get file path from metadata file_path = doc.metadata.get("file_path", "") or doc.metadata.get("file_name", "")
file_path = doc.metadata.get("file_path", "")
if not file_path:
# Fallback to file_name
file_path = doc.metadata.get("file_name", "")
if file_path: if file_path:
file_ext = Path(file_path).suffix.lower() file_ext = Path(file_path).suffix.lower()
if file_ext in code_extensions: if file_ext in code_extensions:
# Add language info to metadata
doc.metadata["language"] = code_extensions[file_ext] doc.metadata["language"] = code_extensions[file_ext]
doc.metadata["is_code"] = True doc.metadata["is_code"] = True
code_docs.append(doc) code_docs.append(doc)
@@ -70,7 +43,6 @@ def detect_code_files(documents, code_extensions=None) -> tuple[list, list]:
doc.metadata["is_code"] = False doc.metadata["is_code"] = False
text_docs.append(doc) text_docs.append(doc)
else: else:
# If no file path, treat as text
doc.metadata["is_code"] = False doc.metadata["is_code"] = False
text_docs.append(doc) text_docs.append(doc)
@@ -79,7 +51,7 @@ def detect_code_files(documents, code_extensions=None) -> tuple[list, list]:
def get_language_from_extension(file_path: str) -> Optional[str]: def get_language_from_extension(file_path: str) -> Optional[str]:
"""Get the programming language from file extension.""" """Return language string from a filename/extension using CODE_EXTENSIONS."""
ext = Path(file_path).suffix.lower() ext = Path(file_path).suffix.lower()
return CODE_EXTENSIONS.get(ext) return CODE_EXTENSIONS.get(ext)
@@ -90,40 +62,26 @@ def create_ast_chunks(
chunk_overlap: int = 64, chunk_overlap: int = 64,
metadata_template: str = "default", metadata_template: str = "default",
) -> list[str]: ) -> list[str]:
""" """Create AST-aware chunks from code documents using astchunk.
Create AST-aware chunks from code documents using astchunk.
Args: Falls back to traditional chunking if astchunk is unavailable.
documents: List of code documents
max_chunk_size: Maximum characters per chunk
chunk_overlap: Number of AST nodes to overlap between chunks
metadata_template: Template for chunk metadata
Returns:
List of text chunks with preserved code structure
""" """
try: try:
from astchunk import ASTChunkBuilder from astchunk import ASTChunkBuilder # optional dependency
except ImportError as e: except ImportError as e:
logger.error(f"astchunk not available: {e}") logger.error(f"astchunk not available: {e}")
logger.info("Falling back to traditional chunking for code files") logger.info("Falling back to traditional chunking for code files")
return create_traditional_chunks(documents, max_chunk_size, chunk_overlap) return create_traditional_chunks(documents, max_chunk_size, chunk_overlap)
all_chunks = [] all_chunks = []
for doc in documents: for doc in documents:
# Get language from metadata (set by detect_code_files)
language = doc.metadata.get("language") language = doc.metadata.get("language")
if not language: if not language:
logger.warning( logger.warning("No language detected; falling back to traditional chunking")
"No language detected for document, falling back to traditional chunking" all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
)
traditional_chunks = create_traditional_chunks([doc], max_chunk_size, chunk_overlap)
all_chunks.extend(traditional_chunks)
continue continue
try: try:
# Configure astchunk
configs = { configs = {
"max_chunk_size": max_chunk_size, "max_chunk_size": max_chunk_size,
"language": language, "language": language,
@@ -131,7 +89,6 @@ def create_ast_chunks(
"chunk_overlap": chunk_overlap if chunk_overlap > 0 else 0, "chunk_overlap": chunk_overlap if chunk_overlap > 0 else 0,
} }
# Add repository-level metadata if available
repo_metadata = { repo_metadata = {
"file_path": doc.metadata.get("file_path", ""), "file_path": doc.metadata.get("file_path", ""),
"file_name": doc.metadata.get("file_name", ""), "file_name": doc.metadata.get("file_name", ""),
@@ -140,17 +97,13 @@ def create_ast_chunks(
} }
configs["repo_level_metadata"] = repo_metadata configs["repo_level_metadata"] = repo_metadata
# Create chunk builder and process
chunk_builder = ASTChunkBuilder(**configs) chunk_builder = ASTChunkBuilder(**configs)
code_content = doc.get_content() code_content = doc.get_content()
if not code_content or not code_content.strip(): if not code_content or not code_content.strip():
logger.warning("Empty code content, skipping") logger.warning("Empty code content, skipping")
continue continue
chunks = chunk_builder.chunkify(code_content) chunks = chunk_builder.chunkify(code_content)
# Extract text content from chunks
for chunk in chunks: for chunk in chunks:
if hasattr(chunk, "text"): if hasattr(chunk, "text"):
chunk_text = chunk.text chunk_text = chunk.text
@@ -159,7 +112,6 @@ def create_ast_chunks(
elif isinstance(chunk, str): elif isinstance(chunk, str):
chunk_text = chunk chunk_text = chunk
else: else:
# Try to convert to string
chunk_text = str(chunk) chunk_text = str(chunk)
if chunk_text and chunk_text.strip(): if chunk_text and chunk_text.strip():
@@ -168,12 +120,10 @@ def create_ast_chunks(
logger.info( logger.info(
f"Created {len(chunks)} AST chunks from {language} file: {doc.metadata.get('file_name', 'unknown')}" f"Created {len(chunks)} AST chunks from {language} file: {doc.metadata.get('file_name', 'unknown')}"
) )
except Exception as e: except Exception as e:
logger.warning(f"AST chunking failed for {language} file: {e}") logger.warning(f"AST chunking failed for {language} file: {e}")
logger.info("Falling back to traditional chunking") logger.info("Falling back to traditional chunking")
traditional_chunks = create_traditional_chunks([doc], max_chunk_size, chunk_overlap) all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
all_chunks.extend(traditional_chunks)
return all_chunks return all_chunks
@@ -181,23 +131,10 @@ def create_ast_chunks(
def create_traditional_chunks( def create_traditional_chunks(
documents, chunk_size: int = 256, chunk_overlap: int = 128 documents, chunk_size: int = 256, chunk_overlap: int = 128
) -> list[str]: ) -> list[str]:
""" """Create traditional text chunks using LlamaIndex SentenceSplitter."""
Create traditional text chunks using LlamaIndex SentenceSplitter.
Args:
documents: List of documents to chunk
chunk_size: Size of each chunk in characters
chunk_overlap: Overlap between chunks
Returns:
List of text chunks
"""
# Handle invalid chunk_size values
if chunk_size <= 0: if chunk_size <= 0:
logger.warning(f"Invalid chunk_size={chunk_size}, using default value of 256") logger.warning(f"Invalid chunk_size={chunk_size}, using default value of 256")
chunk_size = 256 chunk_size = 256
# Ensure chunk_overlap is not negative and not larger than chunk_size
if chunk_overlap < 0: if chunk_overlap < 0:
chunk_overlap = 0 chunk_overlap = 0
if chunk_overlap >= chunk_size: if chunk_overlap >= chunk_size:
@@ -215,12 +152,9 @@ def create_traditional_chunks(
try: try:
nodes = node_parser.get_nodes_from_documents([doc]) nodes = node_parser.get_nodes_from_documents([doc])
if nodes: if nodes:
chunk_texts = [node.get_content() for node in nodes] all_texts.extend(node.get_content() for node in nodes)
all_texts.extend(chunk_texts)
logger.debug(f"Created {len(chunk_texts)} traditional chunks from document")
except Exception as e: except Exception as e:
logger.error(f"Traditional chunking failed for document: {e}") logger.error(f"Traditional chunking failed for document: {e}")
# As last resort, add the raw content
content = doc.get_content() content = doc.get_content()
if content and content.strip(): if content and content.strip():
all_texts.append(content.strip()) all_texts.append(content.strip())
@@ -238,32 +172,13 @@ def create_text_chunks(
code_file_extensions: Optional[list[str]] = None, code_file_extensions: Optional[list[str]] = None,
ast_fallback_traditional: bool = True, ast_fallback_traditional: bool = True,
) -> list[str]: ) -> list[str]:
""" """Create text chunks from documents with optional AST support for code files."""
Create text chunks from documents with optional AST support for code files.
Args:
documents: List of LlamaIndex Document objects
chunk_size: Size for traditional text chunks
chunk_overlap: Overlap for traditional text chunks
use_ast_chunking: Whether to use AST chunking for code files
ast_chunk_size: Size for AST chunks
ast_chunk_overlap: Overlap for AST chunks
code_file_extensions: Custom list of code file extensions
ast_fallback_traditional: Fall back to traditional chunking on AST errors
Returns:
List of text chunks
"""
if not documents: if not documents:
logger.warning("No documents provided for chunking") logger.warning("No documents provided for chunking")
return [] return []
# Create a local copy of supported extensions for this function call
local_code_extensions = CODE_EXTENSIONS.copy() local_code_extensions = CODE_EXTENSIONS.copy()
# Update supported extensions if provided
if code_file_extensions: if code_file_extensions:
# Map extensions to languages (simplified mapping)
ext_mapping = { ext_mapping = {
".py": "python", ".py": "python",
".java": "java", ".java": "java",
@@ -273,47 +188,32 @@ def create_text_chunks(
} }
for ext in code_file_extensions: for ext in code_file_extensions:
if ext.lower() not in local_code_extensions: if ext.lower() not in local_code_extensions:
# Try to guess language from extension
if ext.lower() in ext_mapping: if ext.lower() in ext_mapping:
local_code_extensions[ext.lower()] = ext_mapping[ext.lower()] local_code_extensions[ext.lower()] = ext_mapping[ext.lower()]
else: else:
logger.warning(f"Unsupported extension {ext}, will use traditional chunking") logger.warning(f"Unsupported extension {ext}, will use traditional chunking")
all_chunks = [] all_chunks = []
if use_ast_chunking: if use_ast_chunking:
# Separate code and text documents using local extensions
code_docs, text_docs = detect_code_files(documents, local_code_extensions) code_docs, text_docs = detect_code_files(documents, local_code_extensions)
# Process code files with AST chunking
if code_docs: if code_docs:
logger.info(f"Processing {len(code_docs)} code files with AST chunking")
try: try:
ast_chunks = create_ast_chunks( all_chunks.extend(
create_ast_chunks(
code_docs, max_chunk_size=ast_chunk_size, chunk_overlap=ast_chunk_overlap code_docs, max_chunk_size=ast_chunk_size, chunk_overlap=ast_chunk_overlap
) )
all_chunks.extend(ast_chunks) )
logger.info(f"Created {len(ast_chunks)} AST chunks from code files")
except Exception as e: except Exception as e:
logger.error(f"AST chunking failed: {e}") logger.error(f"AST chunking failed: {e}")
if ast_fallback_traditional: if ast_fallback_traditional:
logger.info("Falling back to traditional chunking for code files") all_chunks.extend(
traditional_code_chunks = create_traditional_chunks( create_traditional_chunks(code_docs, chunk_size, chunk_overlap)
code_docs, chunk_size, chunk_overlap
) )
all_chunks.extend(traditional_code_chunks)
else: else:
raise raise
# Process text files with traditional chunking
if text_docs: if text_docs:
logger.info(f"Processing {len(text_docs)} text files with traditional chunking") all_chunks.extend(create_traditional_chunks(text_docs, chunk_size, chunk_overlap))
text_chunks = create_traditional_chunks(text_docs, chunk_size, chunk_overlap)
all_chunks.extend(text_chunks)
logger.info(f"Created {len(text_chunks)} traditional chunks from text files")
else: else:
# Use traditional chunking for all files
logger.info(f"Processing {len(documents)} documents with traditional chunking")
all_chunks = create_traditional_chunks(documents, chunk_size, chunk_overlap) all_chunks = create_traditional_chunks(documents, chunk_size, chunk_overlap)
logger.info(f"Total chunks created: {len(all_chunks)}") logger.info(f"Total chunks created: {len(all_chunks)}")

View File

@@ -1,6 +1,5 @@
import argparse import argparse
import asyncio import asyncio
import sys
from pathlib import Path from pathlib import Path
from typing import Any, Optional, Union from typing import Any, Optional, Union
@@ -10,6 +9,7 @@ from tqdm import tqdm
from .api import LeannBuilder, LeannChat, LeannSearcher from .api import LeannBuilder, LeannChat, LeannSearcher
from .registry import register_project_directory from .registry import register_project_directory
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
def extract_pdf_text_with_pymupdf(file_path: str) -> str: def extract_pdf_text_with_pymupdf(file_path: str) -> str:
@@ -124,6 +124,24 @@ Examples:
choices=["sentence-transformers", "openai", "mlx", "ollama"], choices=["sentence-transformers", "openai", "mlx", "ollama"],
help="Embedding backend mode (default: sentence-transformers)", help="Embedding backend mode (default: sentence-transformers)",
) )
build_parser.add_argument(
"--embedding-host",
type=str,
default=None,
help="Override Ollama-compatible embedding host",
)
build_parser.add_argument(
"--embedding-api-base",
type=str,
default=None,
help="Base URL for OpenAI-compatible embedding services",
)
build_parser.add_argument(
"--embedding-api-key",
type=str,
default=None,
help="API key for embedding service (defaults to OPENAI_API_KEY)",
)
build_parser.add_argument( build_parser.add_argument(
"--force", "-f", action="store_true", help="Force rebuild existing index" "--force", "-f", action="store_true", help="Force rebuild existing index"
) )
@@ -239,6 +257,11 @@ Examples:
# Ask command # Ask command
ask_parser = subparsers.add_parser("ask", help="Ask questions") ask_parser = subparsers.add_parser("ask", help="Ask questions")
ask_parser.add_argument("index_name", help="Index name") ask_parser.add_argument("index_name", help="Index name")
ask_parser.add_argument(
"query",
nargs="?",
help="Question to ask (omit for prompt or when using --interactive)",
)
ask_parser.add_argument( ask_parser.add_argument(
"--llm", "--llm",
type=str, type=str,
@@ -249,7 +272,12 @@ Examples:
ask_parser.add_argument( ask_parser.add_argument(
"--model", type=str, default="qwen3:8b", help="Model name (default: qwen3:8b)" "--model", type=str, default="qwen3:8b", help="Model name (default: qwen3:8b)"
) )
ask_parser.add_argument("--host", type=str, default="http://localhost:11434") ask_parser.add_argument(
"--host",
type=str,
default=None,
help="Override Ollama-compatible host (defaults to LEANN_OLLAMA_HOST/OLLAMA_HOST)",
)
ask_parser.add_argument( ask_parser.add_argument(
"--interactive", "-i", action="store_true", help="Interactive chat mode" "--interactive", "-i", action="store_true", help="Interactive chat mode"
) )
@@ -278,6 +306,18 @@ Examples:
default=None, default=None,
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.", help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
) )
ask_parser.add_argument(
"--api-base",
type=str,
default=None,
help="Base URL for OpenAI-compatible APIs (e.g., http://localhost:10000/v1)",
)
ask_parser.add_argument(
"--api-key",
type=str,
default=None,
help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)",
)
# List command # List command
subparsers.add_parser("list", help="List all indexes") subparsers.add_parser("list", help="List all indexes")
@@ -1216,13 +1256,8 @@ Examples:
if use_ast: if use_ast:
print("🧠 Using AST-aware chunking for code files") print("🧠 Using AST-aware chunking for code files")
try: try:
# Import enhanced chunking utilities # Import enhanced chunking utilities from packaged module
# Add apps directory to path to import chunking utilities from .chunking_utils import create_text_chunks
apps_dir = Path(__file__).parent.parent.parent.parent.parent / "apps"
if apps_dir.exists():
sys.path.insert(0, str(apps_dir))
from chunking import create_text_chunks
# Use enhanced chunking with AST support # Use enhanced chunking with AST support
all_texts = create_text_chunks( all_texts = create_text_chunks(
@@ -1237,7 +1272,9 @@ Examples:
) )
except ImportError as e: except ImportError as e:
print(f"⚠️ AST chunking not available ({e}), falling back to traditional chunking") print(
f"⚠️ AST chunking utilities not available in package ({e}), falling back to traditional chunking"
)
use_ast = False use_ast = False
if not use_ast: if not use_ast:
@@ -1329,10 +1366,20 @@ Examples:
print(f"Building index '{index_name}' with {args.backend} backend...") print(f"Building index '{index_name}' with {args.backend} backend...")
embedding_options: dict[str, Any] = {}
if args.embedding_mode == "ollama":
embedding_options["host"] = resolve_ollama_host(args.embedding_host)
elif args.embedding_mode == "openai":
embedding_options["base_url"] = resolve_openai_base_url(args.embedding_api_base)
resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
if resolved_embedding_key:
embedding_options["api_key"] = resolved_embedding_key
builder = LeannBuilder( builder = LeannBuilder(
backend_name=args.backend, backend_name=args.backend,
embedding_model=args.embedding_model, embedding_model=args.embedding_model,
embedding_mode=args.embedding_mode, embedding_mode=args.embedding_mode,
embedding_options=embedding_options or None,
graph_degree=args.graph_degree, graph_degree=args.graph_degree,
complexity=args.complexity, complexity=args.complexity,
is_compact=args.compact, is_compact=args.compact,
@@ -1480,11 +1527,38 @@ Examples:
llm_config = {"type": args.llm, "model": args.model} llm_config = {"type": args.llm, "model": args.model}
if args.llm == "ollama": if args.llm == "ollama":
llm_config["host"] = args.host llm_config["host"] = resolve_ollama_host(args.host)
elif args.llm == "openai":
llm_config["base_url"] = resolve_openai_base_url(args.api_base)
resolved_api_key = resolve_openai_api_key(args.api_key)
if resolved_api_key:
llm_config["api_key"] = resolved_api_key
chat = LeannChat(index_path=index_path, llm_config=llm_config) chat = LeannChat(index_path=index_path, llm_config=llm_config)
llm_kwargs: dict[str, Any] = {}
if args.thinking_budget:
llm_kwargs["thinking_budget"] = args.thinking_budget
def _ask_once(prompt: str) -> None:
response = chat.ask(
prompt,
top_k=args.top_k,
complexity=args.complexity,
beam_width=args.beam_width,
prune_ratio=args.prune_ratio,
recompute_embeddings=args.recompute_embeddings,
pruning_strategy=args.pruning_strategy,
llm_kwargs=llm_kwargs,
)
print(f"LEANN: {response}")
initial_query = (args.query or "").strip()
if args.interactive: if args.interactive:
if initial_query:
_ask_once(initial_query)
print("LEANN Assistant ready! Type 'quit' to exit") print("LEANN Assistant ready! Type 'quit' to exit")
print("=" * 40) print("=" * 40)
@@ -1497,41 +1571,14 @@ Examples:
if not user_input: if not user_input:
continue continue
# Prepare LLM kwargs with thinking budget if specified _ask_once(user_input)
llm_kwargs = {}
if args.thinking_budget:
llm_kwargs["thinking_budget"] = args.thinking_budget
response = chat.ask(
user_input,
top_k=args.top_k,
complexity=args.complexity,
beam_width=args.beam_width,
prune_ratio=args.prune_ratio,
recompute_embeddings=args.recompute_embeddings,
pruning_strategy=args.pruning_strategy,
llm_kwargs=llm_kwargs,
)
print(f"LEANN: {response}")
else: else:
query = input("Enter your question: ").strip() query = initial_query or input("Enter your question: ").strip()
if query: if not query:
# Prepare LLM kwargs with thinking budget if specified print("No question provided. Exiting.")
llm_kwargs = {} return
if args.thinking_budget:
llm_kwargs["thinking_budget"] = args.thinking_budget
response = chat.ask( _ask_once(query)
query,
top_k=args.top_k,
complexity=args.complexity,
beam_width=args.beam_width,
prune_ratio=args.prune_ratio,
recompute_embeddings=args.recompute_embeddings,
pruning_strategy=args.pruning_strategy,
llm_kwargs=llm_kwargs,
)
print(f"LEANN: {response}")
async def run(self, args=None): async def run(self, args=None):
parser = self.create_parser() parser = self.create_parser()

View File

@@ -7,11 +7,13 @@ Preserves all optimization parameters to ensure performance
import logging import logging
import os import os
import time import time
from typing import Any from typing import Any, Optional
import numpy as np import numpy as np
import torch import torch
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
# Set up logger with proper level # Set up logger with proper level
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper() LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
@@ -31,6 +33,7 @@ def compute_embeddings(
adaptive_optimization: bool = True, adaptive_optimization: bool = True,
manual_tokenize: bool = False, manual_tokenize: bool = False,
max_length: int = 512, max_length: int = 512,
provider_options: Optional[dict[str, Any]] = None,
) -> np.ndarray: ) -> np.ndarray:
""" """
Unified embedding computation entry point Unified embedding computation entry point
@@ -46,6 +49,8 @@ def compute_embeddings(
Returns: Returns:
Normalized embeddings array, shape: (len(texts), embedding_dim) Normalized embeddings array, shape: (len(texts), embedding_dim)
""" """
provider_options = provider_options or {}
if mode == "sentence-transformers": if mode == "sentence-transformers":
return compute_embeddings_sentence_transformers( return compute_embeddings_sentence_transformers(
texts, texts,
@@ -57,11 +62,21 @@ def compute_embeddings(
max_length=max_length, max_length=max_length,
) )
elif mode == "openai": elif mode == "openai":
return compute_embeddings_openai(texts, model_name) return compute_embeddings_openai(
texts,
model_name,
base_url=provider_options.get("base_url"),
api_key=provider_options.get("api_key"),
)
elif mode == "mlx": elif mode == "mlx":
return compute_embeddings_mlx(texts, model_name) return compute_embeddings_mlx(texts, model_name)
elif mode == "ollama": elif mode == "ollama":
return compute_embeddings_ollama(texts, model_name, is_build=is_build) return compute_embeddings_ollama(
texts,
model_name,
is_build=is_build,
host=provider_options.get("host"),
)
elif mode == "gemini": elif mode == "gemini":
return compute_embeddings_gemini(texts, model_name, is_build=is_build) return compute_embeddings_gemini(texts, model_name, is_build=is_build)
else: else:
@@ -353,12 +368,15 @@ def compute_embeddings_sentence_transformers(
return embeddings return embeddings
def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray: def compute_embeddings_openai(
texts: list[str],
model_name: str,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
) -> np.ndarray:
# TODO: @yichuan-w add progress bar only in build mode # TODO: @yichuan-w add progress bar only in build mode
"""Compute embeddings using OpenAI API""" """Compute embeddings using OpenAI API"""
try: try:
import os
import openai import openai
except ImportError as e: except ImportError as e:
raise ImportError(f"OpenAI package not installed: {e}") raise ImportError(f"OpenAI package not installed: {e}")
@@ -373,16 +391,18 @@ def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
f"Found {invalid_count} empty/invalid text(s) in input. Upstream should filter before calling OpenAI." f"Found {invalid_count} empty/invalid text(s) in input. Upstream should filter before calling OpenAI."
) )
api_key = os.getenv("OPENAI_API_KEY") resolved_base_url = resolve_openai_base_url(base_url)
if not api_key: resolved_api_key = resolve_openai_api_key(api_key)
if not resolved_api_key:
raise RuntimeError("OPENAI_API_KEY environment variable not set") raise RuntimeError("OPENAI_API_KEY environment variable not set")
# Cache OpenAI client # Cache OpenAI client
cache_key = "openai_client" cache_key = f"openai_client::{resolved_base_url}"
if cache_key in _model_cache: if cache_key in _model_cache:
client = _model_cache[cache_key] client = _model_cache[cache_key]
else: else:
client = openai.OpenAI(api_key=api_key) client = openai.OpenAI(api_key=resolved_api_key, base_url=resolved_base_url)
_model_cache[cache_key] = client _model_cache[cache_key] = client
logger.info("OpenAI client cached") logger.info("OpenAI client cached")
@@ -507,7 +527,10 @@ def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int =
def compute_embeddings_ollama( def compute_embeddings_ollama(
texts: list[str], model_name: str, is_build: bool = False, host: str = "http://localhost:11434" texts: list[str],
model_name: str,
is_build: bool = False,
host: Optional[str] = None,
) -> np.ndarray: ) -> np.ndarray:
""" """
Compute embeddings using Ollama API with simplified batch processing. Compute embeddings using Ollama API with simplified batch processing.
@@ -518,7 +541,7 @@ def compute_embeddings_ollama(
texts: List of texts to compute embeddings for texts: List of texts to compute embeddings for
model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large") model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large")
is_build: Whether this is a build operation (shows progress bar) is_build: Whether this is a build operation (shows progress bar)
host: Ollama host URL (default: http://localhost:11434) host: Ollama host URL (defaults to environment or http://localhost:11434)
Returns: Returns:
Normalized embeddings array, shape: (len(texts), embedding_dim) Normalized embeddings array, shape: (len(texts), embedding_dim)
@@ -533,17 +556,19 @@ def compute_embeddings_ollama(
if not texts: if not texts:
raise ValueError("Cannot compute embeddings for empty text list") raise ValueError("Cannot compute embeddings for empty text list")
resolved_host = resolve_ollama_host(host)
logger.info( logger.info(
f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}'" f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}', host: '{resolved_host}'"
) )
# Check if Ollama is running # Check if Ollama is running
try: try:
response = requests.get(f"{host}/api/version", timeout=5) response = requests.get(f"{resolved_host}/api/version", timeout=5)
response.raise_for_status() response.raise_for_status()
except requests.exceptions.ConnectionError: except requests.exceptions.ConnectionError:
error_msg = ( error_msg = (
f"❌ Could not connect to Ollama at {host}.\n\n" f"❌ Could not connect to Ollama at {resolved_host}.\n\n"
"Please ensure Ollama is running:\n" "Please ensure Ollama is running:\n"
" • macOS/Linux: ollama serve\n" " • macOS/Linux: ollama serve\n"
" • Windows: Make sure Ollama is running in the system tray\n\n" " • Windows: Make sure Ollama is running in the system tray\n\n"
@@ -555,7 +580,7 @@ def compute_embeddings_ollama(
# Check if model exists and provide helpful suggestions # Check if model exists and provide helpful suggestions
try: try:
response = requests.get(f"{host}/api/tags", timeout=5) response = requests.get(f"{resolved_host}/api/tags", timeout=5)
response.raise_for_status() response.raise_for_status()
models = response.json() models = response.json()
model_names = [model["name"] for model in models.get("models", [])] model_names = [model["name"] for model in models.get("models", [])]
@@ -618,7 +643,9 @@ def compute_embeddings_ollama(
# Verify the model supports embeddings by testing it # Verify the model supports embeddings by testing it
try: try:
test_response = requests.post( test_response = requests.post(
f"{host}/api/embeddings", json={"model": model_name, "prompt": "test"}, timeout=10 f"{resolved_host}/api/embeddings",
json={"model": model_name, "prompt": "test"},
timeout=10,
) )
if test_response.status_code != 200: if test_response.status_code != 200:
error_msg = ( error_msg = (
@@ -665,7 +692,7 @@ def compute_embeddings_ollama(
while retry_count < max_retries: while retry_count < max_retries:
try: try:
response = requests.post( response = requests.post(
f"{host}/api/embeddings", f"{resolved_host}/api/embeddings",
json={"model": model_name, "prompt": truncated_text}, json={"model": model_name, "prompt": truncated_text},
timeout=30, timeout=30,
) )

View File

@@ -8,6 +8,8 @@ import time
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
from .settings import encode_provider_options
# Lightweight, self-contained server manager with no cross-process inspection # Lightweight, self-contained server manager with no cross-process inspection
# Set up logging based on environment variable # Set up logging based on environment variable
@@ -82,16 +84,40 @@ class EmbeddingServerManager:
) -> tuple[bool, int]: ) -> tuple[bool, int]:
"""Start the embedding server.""" """Start the embedding server."""
# passages_file may be present in kwargs for server CLI, but we don't need it here # passages_file may be present in kwargs for server CLI, but we don't need it here
provider_options = kwargs.pop("provider_options", None)
config_signature = {
"model_name": model_name,
"passages_file": kwargs.get("passages_file", ""),
"embedding_mode": embedding_mode,
"provider_options": provider_options or {},
}
# If this manager already has a live server, just reuse it # If this manager already has a live server, just reuse it
if self.server_process and self.server_process.poll() is None and self.server_port: if (
self.server_process
and self.server_process.poll() is None
and self.server_port
and self._server_config == config_signature
):
logger.info("Reusing in-process server") logger.info("Reusing in-process server")
return True, self.server_port return True, self.server_port
# Configuration changed, stop existing server before starting a new one
if self.server_process and self.server_process.poll() is None:
logger.info("Existing server configuration differs; restarting embedding server")
self.stop_server()
# For Colab environment, use a different strategy # For Colab environment, use a different strategy
if _is_colab_environment(): if _is_colab_environment():
logger.info("Detected Colab environment, using alternative startup strategy") logger.info("Detected Colab environment, using alternative startup strategy")
return self._start_server_colab(port, model_name, embedding_mode, **kwargs) return self._start_server_colab(
port,
model_name,
embedding_mode,
provider_options=provider_options,
**kwargs,
)
# Always pick a fresh available port # Always pick a fresh available port
try: try:
@@ -101,13 +127,21 @@ class EmbeddingServerManager:
return False, port return False, port
# Start a new server # Start a new server
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs) return self._start_new_server(
actual_port,
model_name,
embedding_mode,
provider_options=provider_options,
config_signature=config_signature,
**kwargs,
)
def _start_server_colab( def _start_server_colab(
self, self,
port: int, port: int,
model_name: str, model_name: str,
embedding_mode: str = "sentence-transformers", embedding_mode: str = "sentence-transformers",
provider_options: Optional[dict] = None,
**kwargs, **kwargs,
) -> tuple[bool, int]: ) -> tuple[bool, int]:
"""Start server with Colab-specific configuration.""" """Start server with Colab-specific configuration."""
@@ -125,8 +159,20 @@ class EmbeddingServerManager:
try: try:
# In Colab, we'll use a more direct approach # In Colab, we'll use a more direct approach
self._launch_server_process_colab(command, actual_port) self._launch_server_process_colab(
return self._wait_for_server_ready_colab(actual_port) command,
actual_port,
provider_options=provider_options,
)
started, ready_port = self._wait_for_server_ready_colab(actual_port)
if started:
self._server_config = {
"model_name": model_name,
"passages_file": kwargs.get("passages_file", ""),
"embedding_mode": embedding_mode,
"provider_options": provider_options or {},
}
return started, ready_port
except Exception as e: except Exception as e:
logger.error(f"Failed to start embedding server in Colab: {e}") logger.error(f"Failed to start embedding server in Colab: {e}")
return False, actual_port return False, actual_port
@@ -134,7 +180,13 @@ class EmbeddingServerManager:
# Note: No compatibility check needed; manager is per-searcher and configs are stable per instance # Note: No compatibility check needed; manager is per-searcher and configs are stable per instance
def _start_new_server( def _start_new_server(
self, port: int, model_name: str, embedding_mode: str, **kwargs self,
port: int,
model_name: str,
embedding_mode: str,
provider_options: Optional[dict] = None,
config_signature: Optional[dict] = None,
**kwargs,
) -> tuple[bool, int]: ) -> tuple[bool, int]:
"""Start a new embedding server on the given port.""" """Start a new embedding server on the given port."""
logger.info(f"Starting embedding server on port {port}...") logger.info(f"Starting embedding server on port {port}...")
@@ -142,8 +194,20 @@ class EmbeddingServerManager:
command = self._build_server_command(port, model_name, embedding_mode, **kwargs) command = self._build_server_command(port, model_name, embedding_mode, **kwargs)
try: try:
self._launch_server_process(command, port) self._launch_server_process(
return self._wait_for_server_ready(port) command,
port,
provider_options=provider_options,
)
started, ready_port = self._wait_for_server_ready(port)
if started:
self._server_config = config_signature or {
"model_name": model_name,
"passages_file": kwargs.get("passages_file", ""),
"embedding_mode": embedding_mode,
"provider_options": provider_options or {},
}
return started, ready_port
except Exception as e: except Exception as e:
logger.error(f"Failed to start embedding server: {e}") logger.error(f"Failed to start embedding server: {e}")
return False, port return False, port
@@ -173,7 +237,12 @@ class EmbeddingServerManager:
return command return command
def _launch_server_process(self, command: list, port: int) -> None: def _launch_server_process(
self,
command: list,
port: int,
provider_options: Optional[dict] = None,
) -> None:
"""Launch the server process.""" """Launch the server process."""
project_root = Path(__file__).parent.parent.parent.parent.parent project_root = Path(__file__).parent.parent.parent.parent.parent
logger.info(f"Command: {' '.join(command)}") logger.info(f"Command: {' '.join(command)}")
@@ -193,14 +262,20 @@ class EmbeddingServerManager:
# Start embedding server subprocess # Start embedding server subprocess
logger.info(f"Starting server process with command: {' '.join(command)}") logger.info(f"Starting server process with command: {' '.join(command)}")
env = os.environ.copy()
encoded_options = encode_provider_options(provider_options)
if encoded_options:
env["LEANN_EMBEDDING_OPTIONS"] = encoded_options
self.server_process = subprocess.Popen( self.server_process = subprocess.Popen(
command, command,
cwd=project_root, cwd=project_root,
stdout=stdout_target, stdout=stdout_target,
stderr=stderr_target, stderr=stderr_target,
env=env,
) )
self.server_port = port self.server_port = port
# Record config for in-process reuse # Record config for in-process reuse (best effort; refined later when ready)
try: try:
self._server_config = { self._server_config = {
"model_name": command[command.index("--model-name") + 1] "model_name": command[command.index("--model-name") + 1]
@@ -212,12 +287,14 @@ class EmbeddingServerManager:
"embedding_mode": command[command.index("--embedding-mode") + 1] "embedding_mode": command[command.index("--embedding-mode") + 1]
if "--embedding-mode" in command if "--embedding-mode" in command
else "sentence-transformers", else "sentence-transformers",
"provider_options": provider_options or {},
} }
except Exception: except Exception:
self._server_config = { self._server_config = {
"model_name": "", "model_name": "",
"passages_file": "", "passages_file": "",
"embedding_mode": "sentence-transformers", "embedding_mode": "sentence-transformers",
"provider_options": provider_options or {},
} }
logger.info(f"Server process started with PID: {self.server_process.pid}") logger.info(f"Server process started with PID: {self.server_process.pid}")
@@ -322,16 +399,27 @@ class EmbeddingServerManager:
# Removed: cross-process adoption no longer supported # Removed: cross-process adoption no longer supported
return return
def _launch_server_process_colab(self, command: list, port: int) -> None: def _launch_server_process_colab(
self,
command: list,
port: int,
provider_options: Optional[dict] = None,
) -> None:
"""Launch the server process with Colab-specific settings.""" """Launch the server process with Colab-specific settings."""
logger.info(f"Colab Command: {' '.join(command)}") logger.info(f"Colab Command: {' '.join(command)}")
# In Colab, we need to be more careful about process management # In Colab, we need to be more careful about process management
env = os.environ.copy()
encoded_options = encode_provider_options(provider_options)
if encoded_options:
env["LEANN_EMBEDDING_OPTIONS"] = encoded_options
self.server_process = subprocess.Popen( self.server_process = subprocess.Popen(
command, command,
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, stderr=subprocess.PIPE,
text=True, text=True,
env=env,
) )
self.server_port = port self.server_port = port
logger.info(f"Colab server process started with PID: {self.server_process.pid}") logger.info(f"Colab server process started with PID: {self.server_process.pid}")
@@ -345,6 +433,7 @@ class EmbeddingServerManager:
"model_name": "", "model_name": "",
"passages_file": "", "passages_file": "",
"embedding_mode": "sentence-transformers", "embedding_mode": "sentence-transformers",
"provider_options": provider_options or {},
} }
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]: def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:

View File

@@ -41,6 +41,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
print("WARNING: embedding_model not found in meta.json. Recompute will fail.") print("WARNING: embedding_model not found in meta.json. Recompute will fail.")
self.embedding_mode = self.meta.get("embedding_mode", "sentence-transformers") self.embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
self.embedding_options = self.meta.get("embedding_options", {})
self.embedding_server_manager = EmbeddingServerManager( self.embedding_server_manager = EmbeddingServerManager(
backend_module_name=backend_module_name, backend_module_name=backend_module_name,
@@ -77,6 +78,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
passages_file=passages_source_file, passages_file=passages_source_file,
distance_metric=distance_metric, distance_metric=distance_metric,
enable_warmup=kwargs.get("enable_warmup", False), enable_warmup=kwargs.get("enable_warmup", False),
provider_options=self.embedding_options,
) )
if not server_started: if not server_started:
raise RuntimeError(f"Failed to start embedding server on port {actual_port}") raise RuntimeError(f"Failed to start embedding server on port {actual_port}")
@@ -125,7 +127,12 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
from .embedding_compute import compute_embeddings from .embedding_compute import compute_embeddings
embedding_mode = self.meta.get("embedding_mode", "sentence-transformers") embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
return compute_embeddings([query], self.embedding_model, embedding_mode) return compute_embeddings(
[query],
self.embedding_model,
embedding_mode,
provider_options=self.embedding_options,
)
def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray: def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray:
"""Compute embeddings using the ZMQ embedding server.""" """Compute embeddings using the ZMQ embedding server."""

View File

@@ -0,0 +1,74 @@
"""Runtime configuration helpers for LEANN."""
from __future__ import annotations
import json
import os
from typing import Any
# Default fallbacks to preserve current behaviour while keeping them in one place.
_DEFAULT_OLLAMA_HOST = "http://localhost:11434"
_DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1"
def _clean_url(value: str) -> str:
"""Normalize URL strings by stripping trailing slashes."""
return value.rstrip("/") if value else value
def resolve_ollama_host(explicit: str | None = None) -> str:
"""Resolve the Ollama-compatible endpoint to use."""
candidates = (
explicit,
os.getenv("LEANN_LOCAL_LLM_HOST"),
os.getenv("LEANN_OLLAMA_HOST"),
os.getenv("OLLAMA_HOST"),
os.getenv("LOCAL_LLM_ENDPOINT"),
)
for candidate in candidates:
if candidate:
return _clean_url(candidate)
return _clean_url(_DEFAULT_OLLAMA_HOST)
def resolve_openai_base_url(explicit: str | None = None) -> str:
"""Resolve the base URL for OpenAI-compatible services."""
candidates = (
explicit,
os.getenv("LEANN_OPENAI_BASE_URL"),
os.getenv("OPENAI_BASE_URL"),
os.getenv("LOCAL_OPENAI_BASE_URL"),
)
for candidate in candidates:
if candidate:
return _clean_url(candidate)
return _clean_url(_DEFAULT_OPENAI_BASE_URL)
def resolve_openai_api_key(explicit: str | None = None) -> str | None:
"""Resolve the API key for OpenAI-compatible services."""
if explicit:
return explicit
return os.getenv("OPENAI_API_KEY")
def encode_provider_options(options: dict[str, Any] | None) -> str | None:
"""Serialize provider options for child processes."""
if not options:
return None
try:
return json.dumps(options)
except (TypeError, ValueError):
# Fall back to empty payload if serialization fails
return None

View File

@@ -2,6 +2,8 @@
Transform your development workflow with intelligent code assistance using LEANN's semantic search directly in Claude Code. Transform your development workflow with intelligent code assistance using LEANN's semantic search directly in Claude Code.
For agent-facing discovery details, see `llms.txt` in the repository root.
## Prerequisites ## Prerequisites
Install LEANN globally for MCP integration (with default backend): Install LEANN globally for MCP integration (with default backend):

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "leann" name = "leann"
version = "0.3.2" version = "0.3.4"
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!" description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
readme = "README.md" readme = "README.md"
requires-python = ">=3.9" requires-python = ">=3.9"

View File

@@ -53,27 +53,10 @@ dependencies = [
"tree-sitter-java>=0.20.0", "tree-sitter-java>=0.20.0",
"tree-sitter-c-sharp>=0.20.0", "tree-sitter-c-sharp>=0.20.0",
"tree-sitter-typescript>=0.20.0", "tree-sitter-typescript>=0.20.0",
"torchvision>=0.23.0",
] ]
[project.optional-dependencies] [project.optional-dependencies]
dev = [
"pytest>=7.0",
"pytest-cov>=4.0",
"pytest-xdist>=3.0", # For parallel test execution
"black>=23.0",
"ruff==0.12.7", # Fixed version to ensure consistent formatting across all environments
"matplotlib",
"huggingface-hub>=0.20.0",
"pre-commit>=3.5.0",
]
test = [
"pytest>=7.0",
"pytest-timeout>=2.0",
"llama-index-core>=0.12.0",
"python-dotenv>=1.0.0",
]
diskann = [ diskann = [
"leann-backend-diskann", "leann-backend-diskann",
] ]
@@ -99,11 +82,38 @@ wechat-exporter = "wechat_exporter.main:main"
leann-core = { path = "packages/leann-core", editable = true } leann-core = { path = "packages/leann-core", editable = true }
leann-backend-diskann = { path = "packages/leann-backend-diskann", editable = true } leann-backend-diskann = { path = "packages/leann-backend-diskann", editable = true }
leann-backend-hnsw = { path = "packages/leann-backend-hnsw", editable = true } leann-backend-hnsw = { path = "packages/leann-backend-hnsw", editable = true }
astchunk = { path = "packages/astchunk-leann", editable = true }
[dependency-groups]
# Minimal lint toolchain for CI and local hooks
lint = [
"pre-commit>=3.5.0",
"ruff==0.12.7", # Fixed version to ensure consistent formatting across all environments
]
# Test toolchain (no heavy project runtime deps)
test = [
"pytest>=7.0",
"pytest-cov>=4.0",
"pytest-xdist>=3.0",
"pytest-timeout>=2.0",
"python-dotenv>=1.0.0",
]
# dependencies by apps/ should list here
dev = [
"matplotlib",
"huggingface-hub>=0.20.0",
]
[tool.ruff] [tool.ruff]
target-version = "py39" target-version = "py39"
line-length = 100 line-length = 100
extend-exclude = ["third_party"] extend-exclude = [
"third_party",
"apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann.py",
"apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-similarity-map.py"
]
[tool.ruff.lint] [tool.ruff.lint]

121
scripts/hf_upload.py Normal file
View File

@@ -0,0 +1,121 @@
#!/usr/bin/env python3
"""
Upload local evaluation data to Hugging Face Hub, excluding diskann_rpj_wiki.
Defaults:
- repo_id: LEANN-RAG/leann-rag-evaluation-data (dataset)
- folder_path: benchmarks/data
- ignore_patterns: diskann_rpj_wiki/** and .cache/**
Requires authentication via `huggingface-cli login` or HF_TOKEN env var.
"""
from __future__ import annotations
import argparse
import os
try:
from huggingface_hub import HfApi
except Exception as e:
raise SystemExit(
"huggingface_hub is required. Install with: pip install huggingface_hub hf_transfer"
) from e
def _enable_transfer_accel_if_available() -> None:
"""Best-effort enabling of accelerated transfers across hub versions.
Tries the public util if present; otherwise, falls back to env flag when
hf_transfer is installed. Silently no-ops if unavailable.
"""
try:
# Newer huggingface_hub exposes this under utils
from huggingface_hub.utils import hf_hub_enable_hf_transfer # type: ignore
hf_hub_enable_hf_transfer()
return
except Exception:
pass
try:
# If hf_transfer is installed, set env flag recognized by the hub
import hf_transfer # noqa: F401
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
except Exception:
# Acceleration not available; proceed without it
pass
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description="Upload local data to HF, excluding diskann_rpj_wiki")
p.add_argument(
"--repo-id",
default="LEANN-RAG/leann-rag-evaluation-data",
help="Target dataset repo id (namespace/name)",
)
p.add_argument(
"--folder-path",
default="benchmarks/data",
help="Local folder to upload (default: benchmarks/data)",
)
p.add_argument(
"--ignore",
default=["diskann_rpj_wiki/**", ".cache/**"],
nargs="+",
help="Glob patterns to ignore (space-separated)",
)
p.add_argument(
"--allow",
default=["**"],
nargs="+",
help="Glob patterns to allow (space-separated). Defaults to everything.",
)
p.add_argument(
"--message",
default="sync local data (exclude diskann_rpj_wiki)",
help="Commit message",
)
p.add_argument(
"--no-transfer-accel",
action="store_true",
help="Disable hf_transfer accelerated uploads",
)
return p.parse_args()
def main() -> None:
args = parse_args()
if not args.no_transfer_accel:
_enable_transfer_accel_if_available()
if not os.path.isdir(args.folder_path):
raise SystemExit(f"Folder not found: {args.folder_path}")
print("Uploading to Hugging Face Hub:")
print(f" repo_id: {args.repo_id}")
print(" repo_type: dataset")
print(f" folder_path: {args.folder_path}")
print(f" allow_patterns: {args.allow}")
print(f" ignore_patterns:{args.ignore}")
api = HfApi()
# Perform upload. This skips unchanged files by content hash.
api.upload_folder(
repo_id=args.repo_id,
repo_type="dataset",
folder_path=args.folder_path,
path_in_repo=".",
allow_patterns=args.allow,
ignore_patterns=args.ignore,
commit_message=args.message,
)
print("Upload completed (unchanged files were skipped by the Hub).")
if __name__ == "__main__":
main()

View File

@@ -40,8 +40,8 @@ Tests DiskANN graph partitioning functionality:
### Install test dependencies: ### Install test dependencies:
```bash ```bash
# Using extras # Using uv dependency groups (tools only)
uv pip install -e ".[test]" uv sync --only-group test
``` ```
### Run all tests: ### Run all tests:

14
tests/test_cli_ask.py Normal file
View File

@@ -0,0 +1,14 @@
from leann.cli import LeannCLI
def test_cli_ask_accepts_positional_query(tmp_path, monkeypatch):
monkeypatch.chdir(tmp_path)
cli = LeannCLI()
parser = cli.create_parser()
args = parser.parse_args(["ask", "my-docs", "Where are prompts configured?"])
assert args.command == "ask"
assert args.index_name == "my-docs"
assert args.query == "Where are prompts configured?"

7915
uv.lock generated
View File

File diff suppressed because it is too large Load Diff