Compare commits


1 Commits

Author SHA1 Message Date
Andy Lee
fcbcde1ea8 feat: implement smart memory configuration for DiskANN
- Add intelligent memory calculation based on data size and system specs
- search_memory_maximum: 1/10 of embedding size (controls PQ compression)
- build_memory_maximum: 50% of available RAM (controls sharding)
- Provides optimal balance between performance and memory usage
- Automatic fallback to default values if parameters are explicitly provided
2025-08-03 22:54:08 -07:00
182 changed files with 7860 additions and 31479 deletions
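The commit message above describes the memory heuristic only in prose. Below is a minimal sketch of that calculation, assuming `psutil` for reading available RAM; the function name `smart_memory_config` is hypothetical, and the actual wiring into the DiskANN backend in this diff may differ.

```python
import psutil  # assumed dependency for reading available system RAM


def smart_memory_config(num_vectors: int, dim: int, dtype_bytes: int = 4) -> dict:
    """Derive DiskANN memory limits from data size and system specs (sketch)."""
    embedding_size_gb = num_vectors * dim * dtype_bytes / 1024**3
    available_ram_gb = psutil.virtual_memory().available / 1024**3
    return {
        # 1/10 of the raw embedding size (controls PQ compression)
        "search_memory_maximum": embedding_size_gb / 10,
        # 50% of available RAM (controls sharding during build)
        "build_memory_maximum": available_ram_gb * 0.5,
    }
```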

1
.gitattributes vendored Normal file
View File

@@ -0,0 +1 @@
paper_plot/data/big_graph_degree_data.npz filter=lfs diff=lfs merge=lfs -text

View File

@@ -1,50 +0,0 @@
name: Bug Report
description: Report a bug in LEANN
labels: ["bug"]
body:
- type: textarea
id: description
attributes:
label: What happened?
description: A clear description of the bug
validations:
required: true
- type: textarea
id: reproduce
attributes:
label: How to reproduce
placeholder: |
1. Install with...
2. Run command...
3. See error
validations:
required: true
- type: textarea
id: error
attributes:
label: Error message
description: Paste any error messages
render: shell
- type: input
id: version
attributes:
label: LEANN Version
placeholder: "0.1.0"
validations:
required: true
- type: dropdown
id: os
attributes:
label: Operating System
options:
- macOS
- Linux
- Windows
- Docker
validations:
required: true

View File

@@ -1,8 +0,0 @@
blank_issues_enabled: true
contact_links:
- name: Documentation
url: https://github.com/LEANN-RAG/LEANN-RAG/tree/main/docs
about: Read the docs first
- name: Discussions
url: https://github.com/LEANN-RAG/LEANN-RAG/discussions
about: Ask questions and share ideas

View File

@@ -1,27 +0,0 @@
name: Feature Request
description: Suggest a new feature for LEANN
labels: ["enhancement"]
body:
- type: textarea
id: problem
attributes:
label: What problem does this solve?
description: Describe the problem or need
validations:
required: true
- type: textarea
id: solution
attributes:
label: Proposed solution
description: How would you like this to work?
validations:
required: true
- type: textarea
id: example
attributes:
label: Example usage
description: Show how the API might look
render: python

View File

@@ -1,13 +0,0 @@
## What does this PR do?
<!-- Brief description of your changes -->
## Related Issues
Fixes #
## Checklist
- [ ] Tests pass (`uv run pytest`)
- [ ] Code formatted (`ruff format` and `ruff check`)
- [ ] Pre-commit hooks pass (`pre-commit run --all-files`)

View File

@@ -5,7 +5,6 @@ on:
branches: [ main ]
pull_request:
branches: [ main ]
workflow_dispatch:
jobs:
build:

View File

@@ -17,17 +17,26 @@ jobs:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.ref }}
submodules: recursive
- name: Install uv and Python
uses: astral-sh/setup-uv@v6
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
- name: Run pre-commit with only lint group (no project deps)
run: |
uv run --only-group lint pre-commit run --all-files --show-diff-on-failure
- name: Install uv
uses: astral-sh/setup-uv@v4
- name: Install ruff
run: |
uv tool install ruff
- name: Run ruff check
run: |
ruff check .
- name: Run ruff format check
run: |
ruff format --check .
build:
needs: lint
@@ -45,108 +54,45 @@ jobs:
python: '3.12'
- os: ubuntu-22.04
python: '3.13'
# ARM64 Linux builds
- os: ubuntu-24.04-arm
- os: macos-latest
python: '3.9'
- os: ubuntu-24.04-arm
- os: macos-latest
python: '3.10'
- os: ubuntu-24.04-arm
- os: macos-latest
python: '3.11'
- os: ubuntu-24.04-arm
- os: macos-latest
python: '3.12'
- os: ubuntu-24.04-arm
- os: macos-latest
python: '3.13'
- os: macos-14
python: '3.9'
- os: macos-14
python: '3.10'
- os: macos-14
python: '3.11'
- os: macos-14
python: '3.12'
- os: macos-14
python: '3.13'
- os: macos-15
python: '3.9'
- os: macos-15
python: '3.10'
- os: macos-15
python: '3.11'
- os: macos-15
python: '3.12'
- os: macos-15
python: '3.13'
- os: macos-13
python: '3.9'
- os: macos-13
python: '3.10'
- os: macos-13
python: '3.11'
- os: macos-13
python: '3.12'
# Note: macos-13 + Python 3.13 excluded due to PyTorch compatibility
# (PyTorch 2.5+ supports Python 3.13 but not Intel Mac x86_64)
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v4
with:
ref: ${{ inputs.ref }}
submodules: recursive
- name: Install uv and Python
uses: astral-sh/setup-uv@v6
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python }}
- name: Install uv
uses: astral-sh/setup-uv@v4
- name: Install system dependencies (Ubuntu)
if: runner.os == 'Linux'
run: |
sudo apt-get update
sudo apt-get install -y libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
pkg-config libabsl-dev libaio-dev libprotobuf-dev \
patchelf
pkg-config libopenblas-dev patchelf libabsl-dev libaio-dev libprotobuf-dev
# Debug: Show system information
echo "🔍 System Information:"
echo "Architecture: $(uname -m)"
echo "OS: $(uname -a)"
echo "CPU info: $(lscpu | head -5)"
# Install math library based on architecture
ARCH=$(uname -m)
echo "🔍 Setting up math library for architecture: $ARCH"
if [[ "$ARCH" == "x86_64" ]]; then
# Install Intel MKL for DiskANN on x86_64
echo "📦 Installing Intel MKL for x86_64..."
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
source /opt/intel/oneapi/setvars.sh
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/compiler/latest/linux/compiler/lib/intel64_lin" >> $GITHUB_ENV
echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/mkl/latest/lib/intel64" >> $GITHUB_ENV
echo "✅ Intel MKL installed for x86_64"
# Debug: Check MKL installation
echo "🔍 MKL Installation Check:"
ls -la /opt/intel/oneapi/mkl/latest/ || echo "MKL directory not found"
ls -la /opt/intel/oneapi/mkl/latest/lib/ || echo "MKL lib directory not found"
elif [[ "$ARCH" == "aarch64" ]]; then
# Use OpenBLAS for ARM64 (MKL installer not compatible with ARM64)
echo "📦 Installing OpenBLAS for ARM64..."
sudo apt-get install -y libopenblas-dev liblapack-dev liblapacke-dev
echo "✅ OpenBLAS installed for ARM64"
# Debug: Check OpenBLAS installation
echo "🔍 OpenBLAS Installation Check:"
dpkg -l | grep openblas || echo "OpenBLAS package not found"
ls -la /usr/lib/aarch64-linux-gnu/openblas/ || echo "OpenBLAS directory not found"
fi
# Debug: Show final library paths
echo "🔍 Final LD_LIBRARY_PATH: $LD_LIBRARY_PATH"
# Install Intel MKL for DiskANN
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
source /opt/intel/oneapi/setvars.sh
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/mkl/latest/lib/intel64:$LD_LIBRARY_PATH" >> $GITHUB_ENV
- name: Install system dependencies (macOS)
if: runner.os == 'macOS'
@@ -156,93 +102,55 @@ jobs:
- name: Install build dependencies
run: |
uv python install ${{ matrix.python }}
uv venv --python ${{ matrix.python }} .uv-build
if [[ "$RUNNER_OS" == "Windows" ]]; then
BUILD_PY=".uv-build\\Scripts\\python.exe"
else
BUILD_PY=".uv-build/bin/python"
fi
uv pip install --python "$BUILD_PY" scikit-build-core numpy swig Cython pybind11
uv pip install --system scikit-build-core numpy swig Cython pybind11
if [[ "$RUNNER_OS" == "Linux" ]]; then
uv pip install --python "$BUILD_PY" auditwheel
uv pip install --system auditwheel
else
uv pip install --python "$BUILD_PY" delocate
uv pip install --system delocate
fi
if [[ "$RUNNER_OS" == "Windows" ]]; then
echo "$(pwd)\\.uv-build\\Scripts" >> $GITHUB_PATH
else
echo "$(pwd)/.uv-build/bin" >> $GITHUB_PATH
fi
- name: Set macOS environment variables
if: runner.os == 'macOS'
run: |
# Use brew --prefix to automatically detect Homebrew installation path
HOMEBREW_PREFIX=$(brew --prefix)
echo "HOMEBREW_PREFIX=${HOMEBREW_PREFIX}" >> $GITHUB_ENV
echo "OpenMP_ROOT=${HOMEBREW_PREFIX}/opt/libomp" >> $GITHUB_ENV
# Set CMAKE_PREFIX_PATH to let CMake find all packages automatically
echo "CMAKE_PREFIX_PATH=${HOMEBREW_PREFIX}" >> $GITHUB_ENV
# Set compiler flags for OpenMP (required for both backends)
echo "LDFLAGS=-L${HOMEBREW_PREFIX}/opt/libomp/lib" >> $GITHUB_ENV
echo "CPPFLAGS=-I${HOMEBREW_PREFIX}/opt/libomp/include" >> $GITHUB_ENV
- name: Build packages
run: |
# Build core (platform independent)
cd packages/leann-core
uv build
cd ../..
if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
cd packages/leann-core
uv build
cd ../..
fi
# Build HNSW backend
cd packages/leann-backend-hnsw
if [[ "${{ matrix.os }}" == macos-* ]]; then
# Use system clang for better compatibility
if [ "${{ matrix.os }}" == "macos-latest" ]; then
# Use system clang instead of homebrew LLVM for better compatibility
export CC=clang
export CXX=clang++
# Homebrew libraries on each macOS version require matching minimum version
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
export MACOSX_DEPLOYMENT_TARGET=13.0
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
export MACOSX_DEPLOYMENT_TARGET=14.0
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
export MACOSX_DEPLOYMENT_TARGET=15.0
fi
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
export MACOSX_DEPLOYMENT_TARGET=11.0
uv build --wheel --python python
else
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
uv build --wheel --python python
fi
cd ../..
# Build DiskANN backend
cd packages/leann-backend-diskann
if [[ "${{ matrix.os }}" == macos-* ]]; then
# Use system clang for better compatibility
if [ "${{ matrix.os }}" == "macos-latest" ]; then
# Use system clang instead of homebrew LLVM for better compatibility
export CC=clang
export CXX=clang++
# DiskANN requires macOS 13.3+ for sgesdd_ LAPACK function
# But Homebrew libraries on each macOS version require matching minimum version
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
export MACOSX_DEPLOYMENT_TARGET=13.3
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
export MACOSX_DEPLOYMENT_TARGET=14.0
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
export MACOSX_DEPLOYMENT_TARGET=15.0
fi
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
export MACOSX_DEPLOYMENT_TARGET=13.3
uv build --wheel --python python
else
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
uv build --wheel --python python
fi
cd ../..
# Build meta package (platform independent)
cd packages/leann
uv build
cd ../..
if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
cd packages/leann
uv build
cd ../..
fi
- name: Repair wheels (Linux)
if: runner.os == 'Linux'
@@ -268,24 +176,10 @@ jobs:
- name: Repair wheels (macOS)
if: runner.os == 'macOS'
run: |
# Determine deployment target based on runner OS
# Must match the Homebrew libraries for each macOS version
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
HNSW_TARGET="13.0"
DISKANN_TARGET="13.3"
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
HNSW_TARGET="14.0"
DISKANN_TARGET="14.0"
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
HNSW_TARGET="15.0"
DISKANN_TARGET="15.0"
fi
# Repair HNSW wheel
cd packages/leann-backend-hnsw
if [ -d dist ]; then
export MACOSX_DEPLOYMENT_TARGET=$HNSW_TARGET
delocate-wheel -w dist_repaired -v --require-target-macos-version $HNSW_TARGET dist/*.whl
delocate-wheel -w dist_repaired -v dist/*.whl
rm -rf dist
mv dist_repaired dist
fi
@@ -294,8 +188,7 @@ jobs:
# Repair DiskANN wheel
cd packages/leann-backend-diskann
if [ -d dist ]; then
export MACOSX_DEPLOYMENT_TARGET=$DISKANN_TARGET
delocate-wheel -w dist_repaired -v --require-target-macos-version $DISKANN_TARGET dist/*.whl
delocate-wheel -w dist_repaired -v dist/*.whl
rm -rf dist
mv dist_repaired dist
fi
@@ -306,82 +199,39 @@ jobs:
echo "📦 Built packages:"
find packages/*/dist -name "*.whl" -o -name "*.tar.gz" | sort
- name: Install built packages for testing
run: |
# Create uv-managed virtual environment with the requested interpreter
uv python install ${{ matrix.python }}
uv venv --python ${{ matrix.python }}
# Create a virtual environment
uv venv
source .venv/bin/activate || source .venv/Scripts/activate
if [[ "$RUNNER_OS" == "Windows" ]]; then
UV_PY=".venv\\Scripts\\python.exe"
else
UV_PY=".venv/bin/python"
# Install the built wheels
# Use --find-links to let uv choose the correct wheel for the platform
if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
uv pip install leann-core --find-links packages/leann-core/dist
uv pip install leann --find-links packages/leann/dist
fi
uv pip install leann-backend-hnsw --find-links packages/leann-backend-hnsw/dist
uv pip install leann-backend-diskann --find-links packages/leann-backend-diskann/dist
# Install test dependency group only (avoids reinstalling project package)
uv pip install --python "$UV_PY" --group test
# Install core wheel built in this job
CORE_WHL=$(find packages/leann-core/dist -maxdepth 1 -name "*.whl" -print -quit)
if [[ -n "$CORE_WHL" ]]; then
uv pip install --python "$UV_PY" "$CORE_WHL"
else
uv pip install --python "$UV_PY" packages/leann-core/dist/*.tar.gz
fi
PY_TAG=$($UV_PY -c "import sys; print(f'cp{sys.version_info[0]}{sys.version_info[1]}')")
if [[ "$RUNNER_OS" == "macOS" ]]; then
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
export MACOSX_DEPLOYMENT_TARGET=13.3
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
export MACOSX_DEPLOYMENT_TARGET=14.0
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
export MACOSX_DEPLOYMENT_TARGET=15.0
fi
fi
HNSW_WHL=$(find packages/leann-backend-hnsw/dist -maxdepth 1 -name "*-${PY_TAG}-*.whl" -print -quit)
if [[ -z "$HNSW_WHL" ]]; then
HNSW_WHL=$(find packages/leann-backend-hnsw/dist -maxdepth 1 -name "*-py3-*.whl" -print -quit)
fi
if [[ -n "$HNSW_WHL" ]]; then
uv pip install --python "$UV_PY" "$HNSW_WHL"
else
uv pip install --python "$UV_PY" ./packages/leann-backend-hnsw
fi
DISKANN_WHL=$(find packages/leann-backend-diskann/dist -maxdepth 1 -name "*-${PY_TAG}-*.whl" -print -quit)
if [[ -z "$DISKANN_WHL" ]]; then
DISKANN_WHL=$(find packages/leann-backend-diskann/dist -maxdepth 1 -name "*-py3-*.whl" -print -quit)
fi
if [[ -n "$DISKANN_WHL" ]]; then
uv pip install --python "$UV_PY" "$DISKANN_WHL"
else
uv pip install --python "$UV_PY" ./packages/leann-backend-diskann
fi
LEANN_WHL=$(find packages/leann/dist -maxdepth 1 -name "*.whl" -print -quit)
if [[ -n "$LEANN_WHL" ]]; then
uv pip install --python "$UV_PY" "$LEANN_WHL"
else
uv pip install --python "$UV_PY" packages/leann/dist/*.tar.gz
fi
# Install test dependencies using extras
uv pip install -e ".[test]"
- name: Run tests with pytest
env:
CI: true
CI: true # Mark as CI environment to skip memory-intensive tests
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
HF_HUB_DISABLE_SYMLINKS: 1
TOKENIZERS_PARALLELISM: false
PYTORCH_ENABLE_MPS_FALLBACK: 0
OMP_NUM_THREADS: 1
MKL_NUM_THREADS: 1
PYTORCH_ENABLE_MPS_FALLBACK: 0 # Disable MPS on macOS CI to avoid memory issues
OMP_NUM_THREADS: 1 # Disable OpenMP parallelism to avoid libomp crashes
MKL_NUM_THREADS: 1 # Single thread for MKL operations
run: |
# Activate virtual environment
source .venv/bin/activate || source .venv/Scripts/activate
pytest tests/ -v --tb=short
# Run all tests
pytest tests/
- name: Run sanity checks (optional)
run: |
@@ -399,53 +249,3 @@ jobs:
with:
name: packages-${{ matrix.os }}-py${{ matrix.python }}
path: packages/*/dist/
arch-smoke:
name: Arch Linux smoke test (install & import)
needs: build
runs-on: ubuntu-latest
container:
image: archlinux:latest
steps:
- name: Prepare system
run: |
pacman -Syu --noconfirm
pacman -S --noconfirm python python-pip gcc git zlib openssl
- name: Download ALL wheel artifacts from this run
uses: actions/download-artifact@v5
with:
# Don't specify name, download all artifacts
path: ./wheels
- name: Install uv
uses: astral-sh/setup-uv@v6
- name: Create virtual environment and install wheels
run: |
uv venv
source .venv/bin/activate || source .venv/Scripts/activate
uv pip install --find-links wheels leann-core
uv pip install --find-links wheels leann-backend-hnsw
uv pip install --find-links wheels leann-backend-diskann
uv pip install --find-links wheels leann
- name: Import & tiny runtime check
env:
OMP_NUM_THREADS: 1
MKL_NUM_THREADS: 1
run: |
source .venv/bin/activate || source .venv/Scripts/activate
python - <<'PY'
import leann
import leann_backend_hnsw as h
import leann_backend_diskann as d
from leann import LeannBuilder, LeannSearcher
b = LeannBuilder(backend_name="hnsw")
b.add_text("hello arch")
b.build_index("arch_demo.leann")
s = LeannSearcher("arch_demo.leann")
print("search:", s.search("hello", top_k=1))
PY

View File

@@ -1,19 +0,0 @@
name: Link Check
on:
push:
branches: [ main, master ]
pull_request:
schedule:
- cron: "0 3 * * 1"
jobs:
link-check:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: lycheeverse/lychee-action@v2
with:
args: --no-progress --insecure --user-agent 'curl/7.68.0' README.md docs/ apps/ examples/ benchmarks/
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

30
.gitignore vendored
View File

@@ -18,12 +18,9 @@ demo/experiment_results/**/*.json
*.eml
*.emlx
*.json
*.png
!.vscode/*.json
*.sh
*.txt
!CMakeLists.txt
!llms.txt
latency_breakdown*.json
experiment_results/eval_results/diskann/*.json
aws/
@@ -37,15 +34,11 @@ build/
nprobe_logs/
micro/results
micro/contriever-INT8
data/*
!data/2501.14312v1 (1).pdf
!data/2506.08276v1.pdf
!data/PrideandPrejudice.txt
!data/huawei_pangu.md
!data/ground_truth/
!data/indices/
!data/queries/
!data/.gitattributes
examples/data/*
!examples/data/2501.14312v1 (1).pdf
!examples/data/2506.08276v1.pdf
!examples/data/PrideandPrejudice.txt
!examples/data/README.md
*.qdstrm
benchmark_results/
results/
@@ -95,16 +88,3 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/
batchtest.py
tests/__pytest_cache__/
tests/__pycache__/
benchmarks/data/
## multi vector
apps/multimodal/vision-based-pdf-multi-vector/multi-vector-colpali-native-weaviate.py
# Ignore all PDFs (keep data exceptions above) and do not track demo PDFs
# If you need to commit a specific demo PDF, remove this negation locally.
# The following line used to force-add a large demo PDF; remove it to satisfy pre-commit:
# !apps/multimodal/vision-based-pdf-multi-vector/pdfs/2004.12832v2.pdf
!apps/multimodal/vision-based-pdf-multi-vector/fig/*
# AUR build directory (Arch Linux)
paru-bin/

3
.gitmodules vendored
View File

@@ -14,6 +14,3 @@
[submodule "packages/leann-backend-hnsw/third_party/libzmq"]
path = packages/leann-backend-hnsw/third_party/libzmq
url = https://github.com/zeromq/libzmq.git
[submodule "packages/astchunk-leann"]
path = packages/astchunk-leann
url = https://github.com/yichuan-w/astchunk-leann.git

View File

@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
rev: v4.5.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
@@ -10,8 +10,7 @@ repos:
- id: debug-statements
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.12.7 # Fixed version to match pyproject.toml
rev: v0.2.1
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- id: ruff-format

View File

@@ -1,5 +0,0 @@
{
"recommendations": [
"charliermarsh.ruff",
]
}

22
.vscode/settings.json vendored
View File

@@ -1,22 +0,0 @@
{
"python.defaultInterpreterPath": ".venv/bin/python",
"python.terminal.activateEnvironment": true,
"[python]": {
"editor.defaultFormatter": "charliermarsh.ruff",
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.organizeImports": "explicit",
"source.fixAll": "explicit"
},
"editor.insertSpaces": true,
"editor.tabSize": 4
},
"ruff.enable": true,
"files.watcherExclude": {
"**/.venv/**": true,
"**/__pycache__/**": true,
"**/*.egg-info/**": true,
"**/build/**": true,
"**/dist/**": true
}
}

View File

@@ -1,110 +0,0 @@
# Issue #159 Performance Analysis - Conclusion
## Problem Summary
User reported search times of 15-30 seconds instead of the ~2 seconds mentioned in the paper.
**Configuration:**
- GPU: 4090×1
- Embedding Model: BAAI/bge-large-zh-v1.5 (~300M parameters)
- Data Size: 180MB text (~90K chunks)
- Backend: HNSW
- beam_width: 10
- Other parameters: Default values
## Root Cause Analysis
### 1. **Search Complexity Parameter**
The **default `complexity` parameter is 64**, which is too high for achieving ~2 second search times with this configuration.
**Test Results (Reproduced):**
- **Complexity 64 (default)**: **36.17 seconds**
- **Complexity 32**: **2.49 seconds**
- **Complexity 16**: **2.24 seconds** ✅ (Close to paper's ~2 seconds)
- **Complexity 8**: **1.67 seconds**
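A minimal timing sketch for reproducing these numbers, assuming an existing index at a placeholder path and the same `search()` call shown in the recommended fix below:

```python
import time

from leann.api import LeannSearcher

INDEX_PATH = "your-index.leann"  # placeholder: path to an existing HNSW index

searcher = LeannSearcher(INDEX_PATH)
for complexity in (64, 32, 16, 8):
    start = time.time()
    searcher.search(query="your query", top_k=10, complexity=complexity)
    print(f"complexity={complexity}: {time.time() - start:.2f}s")
```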
### 2. **beam_width Parameter**
The `beam_width` parameter is **mainly for DiskANN backend**, not HNSW. Setting it to 10 has minimal/no effect on HNSW search performance.
### 3. **Embedding Model Size**
The paper uses a smaller embedding model (~100M parameters), while the user is using `BAAI/bge-large-zh-v1.5` (~300M parameters). This contributes to slower embedding computation during search, but the main bottleneck is the search complexity parameter.
## Solution
### **Recommended Fix: Reduce Search Complexity**
To achieve search times close to ~2 seconds, use:
```python
from leann.api import LeannSearcher
searcher = LeannSearcher(INDEX_PATH)
results = searcher.search(
query="your query",
top_k=10,
complexity=16, # or complexity=32 for slightly better accuracy
# beam_width parameter doesn't affect HNSW, can be ignored
)
```
Or via CLI:
```bash
leann search your-index "your query" --complexity 16
```
### **Alternative Solutions**
1. **Use DiskANN Backend** (Recommended by maintainer)
- DiskANN is faster for large datasets
- Better performance scaling
- `beam_width` parameter is relevant here (see the sketch after this list)
```python
builder = LeannBuilder(backend_name="diskann")
```
2. **Use Smaller Embedding Model**
- Switch to a smaller model (~100M parameters) like the paper
- Faster embedding computation
- Example: `BAAI/bge-base-zh-v1.5` instead of `bge-large-zh-v1.5`
3. **Disable Recomputation** (Trade storage for speed)
- Use `--no-recompute` flag
- Stores all embeddings (much larger storage)
- Faster search (no embedding recomputation)
```bash
leann build your-index --no-recompute --no-compact
leann search your-index "query" --no-recompute
```
## Performance Comparison
| Complexity | Search Time | Accuracy | Recommendation |
|------------|-------------|----------|---------------|
| 64 (default) | ~36s | Highest | ❌ Too slow |
| 32 | ~2.5s | High | ✅ Good balance |
| 16 | ~2.2s | Good | ✅ **Recommended** (matches paper) |
| 8 | ~1.7s | Lower | ⚠️ May sacrifice accuracy |
## Key Takeaways
1. **The default `complexity=64` is optimized for accuracy, not speed**
2. **For ~2 second search times, use `complexity=16` or `complexity=32`**
3. **`beam_width` parameter is for DiskANN, not HNSW**
4. **The paper's ~2 second results likely used:**
- Smaller embedding model (~100M params)
- Lower complexity (16-32)
- Possibly DiskANN backend
## Verification
The issue has been reproduced and verified. The test script `test_issue_159.py` demonstrates:
- Default complexity (64) results in ~36 second search times
- Reducing complexity to 16-32 achieves ~2 second search times
- This matches the user's reported issue and provides a clear solution
## Next Steps
1. ✅ Issue reproduced and root cause identified
2. ✅ Solution provided (reduce complexity parameter)
3. ⏳ User should test with `complexity=16` or `complexity=32`
4. ⏳ Consider updating documentation to clarify complexity parameter trade-offs

922
README.md
View File

File diff suppressed because it is too large.


View File

@@ -1,407 +0,0 @@
"""
Base class for unified RAG examples interface.
Provides common parameters and functionality for all RAG examples.
"""
import argparse
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any
import dotenv
from leann.api import LeannBuilder, LeannChat
# Optional import: older PyPI builds may not include interactive_utils
try:
from leann.interactive_utils import create_rag_session
except ImportError:
def create_rag_session(app_name: str, data_description: str):
class _SimpleSession:
def run_interactive_loop(self, handler):
print(f"Interactive session for {app_name}: {data_description}")
print("Interactive mode not available in this build")
return _SimpleSession()
from leann.registry import register_project_directory
# Optional import: older PyPI builds may not include settings
try:
from leann.settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
except ImportError:
# Minimal fallbacks if settings helpers are unavailable
import os
def resolve_ollama_host(value: str | None) -> str | None:
return value or os.getenv("LEANN_OLLAMA_HOST") or os.getenv("OLLAMA_HOST")
def resolve_openai_api_key(value: str | None) -> str | None:
return value or os.getenv("OPENAI_API_KEY")
def resolve_openai_base_url(value: str | None) -> str | None:
return value or os.getenv("OPENAI_BASE_URL")
dotenv.load_dotenv()
class BaseRAGExample(ABC):
"""Base class for all RAG examples with unified interface."""
def __init__(
self,
name: str,
description: str,
default_index_name: str,
):
self.name = name
self.description = description
self.default_index_name = default_index_name
self.parser = self._create_parser()
def _create_parser(self) -> argparse.ArgumentParser:
"""Create argument parser with common parameters."""
parser = argparse.ArgumentParser(
description=self.description, formatter_class=argparse.RawDescriptionHelpFormatter
)
# Core parameters (all examples share these)
core_group = parser.add_argument_group("Core Parameters")
core_group.add_argument(
"--index-dir",
type=str,
default=f"./{self.default_index_name}",
help=f"Directory to store the index (default: ./{self.default_index_name})",
)
core_group.add_argument(
"--query",
type=str,
default=None,
help="Query to run (if not provided, will run in interactive mode)",
)
# Allow subclasses to override default max_items
max_items_default = getattr(self, "max_items_default", -1)
core_group.add_argument(
"--max-items",
type=int,
default=max_items_default,
help="Maximum number of items to process -1 for all, means index all documents, and you should set it to a reasonable number if you have a large dataset and try at the first time)",
)
core_group.add_argument(
"--force-rebuild", action="store_true", help="Force rebuild index even if it exists"
)
# Embedding parameters
embedding_group = parser.add_argument_group("Embedding Parameters")
# Allow subclasses to override default embedding_model
embedding_model_default = getattr(self, "embedding_model_default", "facebook/contriever")
embedding_group.add_argument(
"--embedding-model",
type=str,
default=embedding_model_default,
help=f"Embedding model to use (default: {embedding_model_default}), we provide facebook/contriever, text-embedding-3-small,mlx-community/Qwen3-Embedding-0.6B-8bit or nomic-embed-text",
)
embedding_group.add_argument(
"--embedding-mode",
type=str,
default="sentence-transformers",
choices=["sentence-transformers", "openai", "mlx", "ollama"],
help="Embedding backend mode (default: sentence-transformers), we provide sentence-transformers, openai, mlx, or ollama",
)
embedding_group.add_argument(
"--embedding-host",
type=str,
default=None,
help="Override Ollama-compatible embedding host",
)
embedding_group.add_argument(
"--embedding-api-base",
type=str,
default=None,
help="Base URL for OpenAI-compatible embedding services",
)
embedding_group.add_argument(
"--embedding-api-key",
type=str,
default=None,
help="API key for embedding service (defaults to OPENAI_API_KEY)",
)
# LLM parameters
llm_group = parser.add_argument_group("LLM Parameters")
llm_group.add_argument(
"--llm",
type=str,
default="openai",
choices=["openai", "ollama", "hf", "simulated"],
help="LLM backend: openai, ollama, or hf (default: openai)",
)
llm_group.add_argument(
"--llm-model",
type=str,
default=None,
help="Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct",
)
llm_group.add_argument(
"--llm-host",
type=str,
default=None,
help="Host for Ollama-compatible APIs (defaults to LEANN_OLLAMA_HOST/OLLAMA_HOST)",
)
llm_group.add_argument(
"--thinking-budget",
type=str,
choices=["low", "medium", "high"],
default=None,
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
)
llm_group.add_argument(
"--llm-api-base",
type=str,
default=None,
help="Base URL for OpenAI-compatible APIs",
)
llm_group.add_argument(
"--llm-api-key",
type=str,
default=None,
help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)",
)
# AST Chunking parameters
ast_group = parser.add_argument_group("AST Chunking Parameters")
ast_group.add_argument(
"--use-ast-chunking",
action="store_true",
help="Enable AST-aware chunking for code files (requires astchunk)",
)
ast_group.add_argument(
"--ast-chunk-size",
type=int,
default=300,
help="Maximum CHARACTERS per AST chunk (default: 300). Final chunks may be larger due to overlap. For 512 token models: recommended 300 chars",
)
ast_group.add_argument(
"--ast-chunk-overlap",
type=int,
default=64,
help="Overlap between AST chunks in CHARACTERS (default: 64). Added to chunk size, not included in it",
)
ast_group.add_argument(
"--code-file-extensions",
nargs="+",
default=None,
help="Additional code file extensions to process with AST chunking (e.g., .py .java .cs .ts)",
)
ast_group.add_argument(
"--ast-fallback-traditional",
action="store_true",
default=True,
help="Fall back to traditional chunking if AST chunking fails (default: True)",
)
# Search parameters
search_group = parser.add_argument_group("Search Parameters")
search_group.add_argument(
"--top-k", type=int, default=20, help="Number of results to retrieve (default: 20)"
)
search_group.add_argument(
"--search-complexity",
type=int,
default=32,
help="Search complexity for graph traversal (default: 64)",
)
# Index building parameters
index_group = parser.add_argument_group("Index Building Parameters")
index_group.add_argument(
"--backend-name",
type=str,
default="hnsw",
choices=["hnsw", "diskann"],
help="Backend to use for index (default: hnsw)",
)
index_group.add_argument(
"--graph-degree",
type=int,
default=32,
help="Graph degree for index construction (default: 32)",
)
index_group.add_argument(
"--build-complexity",
type=int,
default=64,
help="Build complexity for index construction (default: 64)",
)
index_group.add_argument(
"--no-compact",
action="store_true",
help="Disable compact index storage",
)
index_group.add_argument(
"--no-recompute",
action="store_true",
help="Disable embedding recomputation",
)
# Add source-specific parameters
self._add_specific_arguments(parser)
return parser
@abstractmethod
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
"""Add source-specific arguments. Override in subclasses."""
pass
@abstractmethod
async def load_data(self, args) -> list[str]:
"""Load data from the source. Returns list of text chunks."""
pass
def get_llm_config(self, args) -> dict[str, Any]:
"""Get LLM configuration based on arguments."""
config = {"type": args.llm}
if args.llm == "openai":
config["model"] = args.llm_model or "gpt-4o"
config["base_url"] = resolve_openai_base_url(args.llm_api_base)
resolved_key = resolve_openai_api_key(args.llm_api_key)
if resolved_key:
config["api_key"] = resolved_key
elif args.llm == "ollama":
config["model"] = args.llm_model or "llama3.2:1b"
config["host"] = resolve_ollama_host(args.llm_host)
elif args.llm == "hf":
config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
elif args.llm == "simulated":
# Simulated LLM doesn't need additional configuration
pass
return config
async def build_index(self, args, texts: list[str]) -> str:
"""Build LEANN index from texts."""
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
print(f"\n[Building Index] Creating {self.name} index...")
print(f"Total text chunks: {len(texts)}")
embedding_options: dict[str, Any] = {}
if args.embedding_mode == "ollama":
embedding_options["host"] = resolve_ollama_host(args.embedding_host)
elif args.embedding_mode == "openai":
embedding_options["base_url"] = resolve_openai_base_url(args.embedding_api_base)
resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
if resolved_embedding_key:
embedding_options["api_key"] = resolved_embedding_key
builder = LeannBuilder(
backend_name=args.backend_name,
embedding_model=args.embedding_model,
embedding_mode=args.embedding_mode,
embedding_options=embedding_options or None,
graph_degree=args.graph_degree,
complexity=args.build_complexity,
is_compact=not args.no_compact,
is_recompute=not args.no_recompute,
num_threads=1, # Force single-threaded mode
)
# Add texts in batches for better progress tracking
batch_size = 1000
for i in range(0, len(texts), batch_size):
batch = texts[i : i + batch_size]
for text in batch:
builder.add_text(text)
print(f"Added {min(i + batch_size, len(texts))}/{len(texts)} texts...")
print("Building index structure...")
builder.build_index(index_path)
print(f"Index saved to: {index_path}")
# Register project directory so leann list can discover this index
# The index is saved as args.index_dir/index_name.leann
# We want to register the current working directory where the app is run
register_project_directory(Path.cwd())
return index_path
async def run_interactive_chat(self, args, index_path: str):
"""Run interactive chat with the index."""
chat = LeannChat(
index_path,
llm_config=self.get_llm_config(args),
system_prompt=f"You are a helpful assistant that answers questions about {self.name} data.",
complexity=args.search_complexity,
)
# Create interactive session
session = create_rag_session(
app_name=self.name.lower().replace(" ", "_"), data_description=self.name
)
def handle_query(query: str):
# Prepare LLM kwargs with thinking budget if specified
llm_kwargs = {}
if hasattr(args, "thinking_budget") and args.thinking_budget:
llm_kwargs["thinking_budget"] = args.thinking_budget
response = chat.ask(
query,
top_k=args.top_k,
complexity=args.search_complexity,
llm_kwargs=llm_kwargs,
)
print(f"\nAssistant: {response}\n")
session.run_interactive_loop(handle_query)
async def run_single_query(self, args, index_path: str, query: str):
"""Run a single query against the index."""
chat = LeannChat(
index_path,
llm_config=self.get_llm_config(args),
complexity=args.search_complexity,
)
print(f"\n[Query]: \033[36m{query}\033[0m")
# Prepare LLM kwargs with thinking budget if specified
llm_kwargs = {}
if hasattr(args, "thinking_budget") and args.thinking_budget:
llm_kwargs["thinking_budget"] = args.thinking_budget
response = chat.ask(
query, top_k=args.top_k, complexity=args.search_complexity, llm_kwargs=llm_kwargs
)
print(f"\n[Response]: \033[36m{response}\033[0m")
async def run(self):
"""Main entry point for the example."""
args = self.parser.parse_args()
# Check if index exists
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
index_exists = Path(args.index_dir).exists()
if not index_exists or args.force_rebuild:
# Load data and build index
print(f"\n{'Rebuilding' if index_exists else 'Building'} index...")
texts = await self.load_data(args)
if not texts:
print("No data found to index!")
return
index_path = await self.build_index(args, texts)
else:
print(f"\nUsing existing index in {args.index_dir}")
# Run query or interactive mode
if args.query:
await self.run_single_query(args, index_path, args.query)
else:
await self.run_interactive_chat(args, index_path)

View File

@@ -1,171 +0,0 @@
"""
Browser History RAG example using the unified interface.
Supports Chrome browser history.
"""
import os
import sys
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from base_rag_example import BaseRAGExample
from chunking import create_text_chunks
from history_data.history import ChromeHistoryReader
class BrowserRAG(BaseRAGExample):
"""RAG example for Chrome browser history."""
def __init__(self):
# Set default values BEFORE calling super().__init__
self.embedding_model_default = (
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
)
super().__init__(
name="Browser History",
description="Process and query Chrome browser history with LEANN",
default_index_name="google_history_index",
)
def _add_specific_arguments(self, parser):
"""Add browser-specific arguments."""
browser_group = parser.add_argument_group("Browser Parameters")
browser_group.add_argument(
"--chrome-profile",
type=str,
default=None,
help="Path to Chrome profile directory (auto-detected if not specified)",
)
browser_group.add_argument(
"--auto-find-profiles",
action="store_true",
default=True,
help="Automatically find all Chrome profiles (default: True)",
)
browser_group.add_argument(
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
)
browser_group.add_argument(
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
)
def _get_chrome_base_path(self) -> Path:
"""Get the base Chrome profile path based on OS."""
if sys.platform == "darwin":
return Path.home() / "Library" / "Application Support" / "Google" / "Chrome"
elif sys.platform.startswith("linux"):
return Path.home() / ".config" / "google-chrome"
elif sys.platform == "win32":
return Path(os.environ["LOCALAPPDATA"]) / "Google" / "Chrome" / "User Data"
else:
raise ValueError(f"Unsupported platform: {sys.platform}")
def _find_chrome_profiles(self) -> list[Path]:
"""Auto-detect all Chrome profiles."""
base_path = self._get_chrome_base_path()
if not base_path.exists():
return []
profiles = []
# Check Default profile
default_profile = base_path / "Default"
if default_profile.exists() and (default_profile / "History").exists():
profiles.append(default_profile)
# Check numbered profiles
for item in base_path.iterdir():
if item.is_dir() and item.name.startswith("Profile "):
if (item / "History").exists():
profiles.append(item)
return profiles
async def load_data(self, args) -> list[str]:
"""Load browser history and convert to text chunks."""
# Determine Chrome profiles
if args.chrome_profile and not args.auto_find_profiles:
profile_dirs = [Path(args.chrome_profile)]
else:
print("Auto-detecting Chrome profiles...")
profile_dirs = self._find_chrome_profiles()
# If specific profile given, filter to just that one
if args.chrome_profile:
profile_path = Path(args.chrome_profile)
profile_dirs = [p for p in profile_dirs if p == profile_path]
if not profile_dirs:
print("No Chrome profiles found!")
print("Please specify --chrome-profile manually")
return []
print(f"Found {len(profile_dirs)} Chrome profiles")
# Create reader
reader = ChromeHistoryReader()
# Process each profile
all_documents = []
total_processed = 0
for i, profile_dir in enumerate(profile_dirs):
print(f"\nProcessing profile {i + 1}/{len(profile_dirs)}: {profile_dir.name}")
try:
# Apply max_items limit per profile
max_per_profile = -1
if args.max_items > 0:
remaining = args.max_items - total_processed
if remaining <= 0:
break
max_per_profile = remaining
# Load history
documents = reader.load_data(
chrome_profile_path=str(profile_dir),
max_count=max_per_profile,
)
if documents:
all_documents.extend(documents)
total_processed += len(documents)
print(f"Processed {len(documents)} history entries from this profile")
except Exception as e:
print(f"Error processing {profile_dir}: {e}")
continue
if not all_documents:
print("No browser history found to process!")
return []
print(f"\nTotal history entries processed: {len(all_documents)}")
# Convert to text chunks
all_texts = create_text_chunks(
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
)
return all_texts
if __name__ == "__main__":
import asyncio
# Example queries for browser history RAG
print("\n🌐 Browser History RAG Example")
print("=" * 50)
print("\nExample queries you can try:")
print("- 'What websites did I visit about machine learning?'")
print("- 'Find my search history about programming'")
print("- 'What YouTube videos did I watch recently?'")
print("- 'Show me websites about travel planning'")
print("\nNote: Make sure Chrome is closed before running\n")
rag = BrowserRAG()
asyncio.run(rag.run())


View File

@@ -1,413 +0,0 @@
"""
ChatGPT export data reader.
Reads and processes ChatGPT export data from chat.html files.
"""
import re
from pathlib import Path
from typing import Any
from zipfile import ZipFile
from bs4 import BeautifulSoup
from llama_index.core import Document
from llama_index.core.readers.base import BaseReader
class ChatGPTReader(BaseReader):
"""
ChatGPT export data reader.
Reads ChatGPT conversation data from exported chat.html files or zip archives.
Processes conversations into structured documents with metadata.
"""
def __init__(self, concatenate_conversations: bool = True) -> None:
"""
Initialize.
Args:
concatenate_conversations: Whether to concatenate messages within conversations for better context
"""
try:
from bs4 import BeautifulSoup # noqa
except ImportError:
raise ImportError("`beautifulsoup4` package not found: `pip install beautifulsoup4`")
self.concatenate_conversations = concatenate_conversations
def _extract_html_from_zip(self, zip_path: Path) -> str | None:
"""
Extract chat.html from ChatGPT export zip file.
Args:
zip_path: Path to the ChatGPT export zip file
Returns:
HTML content as string, or None if not found
"""
try:
with ZipFile(zip_path, "r") as zip_file:
# Look for chat.html or conversations.html
html_files = [
f
for f in zip_file.namelist()
if f.endswith(".html") and ("chat" in f.lower() or "conversation" in f.lower())
]
if not html_files:
print(f"No HTML chat file found in {zip_path}")
return None
# Use the first HTML file found
html_file = html_files[0]
print(f"Found HTML file: {html_file}")
with zip_file.open(html_file) as f:
return f.read().decode("utf-8", errors="ignore")
except Exception as e:
print(f"Error extracting HTML from zip {zip_path}: {e}")
return None
def _parse_chatgpt_html(self, html_content: str) -> list[dict]:
"""
Parse ChatGPT HTML export to extract conversations.
Args:
html_content: HTML content from ChatGPT export
Returns:
List of conversation dictionaries
"""
soup = BeautifulSoup(html_content, "html.parser")
conversations = []
# Try different possible structures for ChatGPT exports
# Structure 1: Look for conversation containers
conversation_containers = soup.find_all(
["div", "section"], class_=re.compile(r"conversation|chat", re.I)
)
if not conversation_containers:
# Structure 2: Look for message containers directly
conversation_containers = [soup] # Use the entire document as one conversation
for container in conversation_containers:
conversation = self._extract_conversation_from_container(container)
if conversation and conversation.get("messages"):
conversations.append(conversation)
# If no structured conversations found, try to extract all text as one conversation
if not conversations:
all_text = soup.get_text(separator="\n", strip=True)
if all_text:
conversations.append(
{
"title": "ChatGPT Conversation",
"messages": [{"role": "mixed", "content": all_text, "timestamp": None}],
"timestamp": None,
}
)
return conversations
def _extract_conversation_from_container(self, container) -> dict | None:
"""
Extract conversation data from a container element.
Args:
container: BeautifulSoup element containing conversation
Returns:
Dictionary with conversation data or None
"""
messages = []
# Look for message elements with various possible structures
message_selectors = ['[class*="message"]', '[class*="chat"]', "[data-message]", "p", "div"]
for selector in message_selectors:
message_elements = container.select(selector)
if message_elements:
break
else:
message_elements = []
# If no structured messages found, treat the entire container as one message
if not message_elements:
text_content = container.get_text(separator="\n", strip=True)
if text_content:
messages.append({"role": "mixed", "content": text_content, "timestamp": None})
else:
for element in message_elements:
message = self._extract_message_from_element(element)
if message:
messages.append(message)
if not messages:
return None
# Try to extract conversation title
title_element = container.find(["h1", "h2", "h3", "title"])
title = title_element.get_text(strip=True) if title_element else "ChatGPT Conversation"
# Try to extract timestamp from various possible locations
timestamp = self._extract_timestamp_from_container(container)
return {"title": title, "messages": messages, "timestamp": timestamp}
def _extract_message_from_element(self, element) -> dict | None:
"""
Extract message data from an element.
Args:
element: BeautifulSoup element containing message
Returns:
Dictionary with message data or None
"""
text_content = element.get_text(separator=" ", strip=True)
# Skip empty or very short messages
if not text_content or len(text_content.strip()) < 3:
return None
# Try to determine role (user/assistant) from class names or content
role = "mixed" # Default role
class_names = " ".join(element.get("class", [])).lower()
if "user" in class_names or "human" in class_names:
role = "user"
elif "assistant" in class_names or "ai" in class_names or "gpt" in class_names:
role = "assistant"
elif text_content.lower().startswith(("you:", "user:", "me:")):
role = "user"
text_content = re.sub(r"^(you|user|me):\s*", "", text_content, flags=re.IGNORECASE)
elif text_content.lower().startswith(("chatgpt:", "assistant:", "ai:")):
role = "assistant"
text_content = re.sub(
r"^(chatgpt|assistant|ai):\s*", "", text_content, flags=re.IGNORECASE
)
# Try to extract timestamp
timestamp = self._extract_timestamp_from_element(element)
return {"role": role, "content": text_content, "timestamp": timestamp}
def _extract_timestamp_from_element(self, element) -> str | None:
"""Extract timestamp from element."""
# Look for timestamp in various attributes and child elements
timestamp_attrs = ["data-timestamp", "timestamp", "datetime"]
for attr in timestamp_attrs:
if element.get(attr):
return element.get(attr)
# Look for time elements
time_element = element.find("time")
if time_element:
return time_element.get("datetime") or time_element.get_text(strip=True)
# Look for date-like text patterns
text = element.get_text()
date_patterns = [r"\d{4}-\d{2}-\d{2}", r"\d{1,2}/\d{1,2}/\d{4}", r"\w+ \d{1,2}, \d{4}"]
for pattern in date_patterns:
match = re.search(pattern, text)
if match:
return match.group()
return None
def _extract_timestamp_from_container(self, container) -> str | None:
"""Extract timestamp from conversation container."""
return self._extract_timestamp_from_element(container)
def _create_concatenated_content(self, conversation: dict) -> str:
"""
Create concatenated content from conversation messages.
Args:
conversation: Dictionary containing conversation data
Returns:
Formatted concatenated content
"""
title = conversation.get("title", "ChatGPT Conversation")
messages = conversation.get("messages", [])
timestamp = conversation.get("timestamp", "Unknown")
# Build message content
message_parts = []
for message in messages:
role = message.get("role", "mixed")
content = message.get("content", "")
msg_timestamp = message.get("timestamp", "")
if role == "user":
prefix = "[You]"
elif role == "assistant":
prefix = "[ChatGPT]"
else:
prefix = "[Message]"
# Add timestamp if available
if msg_timestamp:
prefix += f" ({msg_timestamp})"
message_parts.append(f"{prefix}: {content}")
concatenated_text = "\n\n".join(message_parts)
# Create final document content
doc_content = f"""Conversation: {title}
Date: {timestamp}
Messages ({len(messages)} messages):
{concatenated_text}
"""
return doc_content
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
"""
Load ChatGPT export data.
Args:
input_dir: Directory containing ChatGPT export files or path to specific file
**load_kwargs:
max_count (int): Maximum number of conversations to process
chatgpt_export_path (str): Specific path to ChatGPT export file/directory
include_metadata (bool): Whether to include metadata in documents
"""
docs: list[Document] = []
max_count = load_kwargs.get("max_count", -1)
chatgpt_export_path = load_kwargs.get("chatgpt_export_path", input_dir)
include_metadata = load_kwargs.get("include_metadata", True)
if not chatgpt_export_path:
print("No ChatGPT export path provided")
return docs
export_path = Path(chatgpt_export_path)
if not export_path.exists():
print(f"ChatGPT export path not found: {export_path}")
return docs
html_content = None
# Handle different input types
if export_path.is_file():
if export_path.suffix.lower() == ".zip":
# Extract HTML from zip file
html_content = self._extract_html_from_zip(export_path)
elif export_path.suffix.lower() == ".html":
# Read HTML file directly
try:
with open(export_path, encoding="utf-8", errors="ignore") as f:
html_content = f.read()
except Exception as e:
print(f"Error reading HTML file {export_path}: {e}")
return docs
else:
print(f"Unsupported file type: {export_path.suffix}")
return docs
elif export_path.is_dir():
# Look for HTML files in directory
html_files = list(export_path.glob("*.html"))
zip_files = list(export_path.glob("*.zip"))
if html_files:
# Use first HTML file found
html_file = html_files[0]
print(f"Found HTML file: {html_file}")
try:
with open(html_file, encoding="utf-8", errors="ignore") as f:
html_content = f.read()
except Exception as e:
print(f"Error reading HTML file {html_file}: {e}")
return docs
elif zip_files:
# Use first zip file found
zip_file = zip_files[0]
print(f"Found zip file: {zip_file}")
html_content = self._extract_html_from_zip(zip_file)
else:
print(f"No HTML or zip files found in {export_path}")
return docs
if not html_content:
print("No HTML content found to process")
return docs
# Parse conversations from HTML
print("Parsing ChatGPT conversations from HTML...")
conversations = self._parse_chatgpt_html(html_content)
if not conversations:
print("No conversations found in HTML content")
return docs
print(f"Found {len(conversations)} conversations")
# Process conversations into documents
count = 0
for conversation in conversations:
if max_count > 0 and count >= max_count:
break
if self.concatenate_conversations:
# Create one document per conversation with concatenated messages
doc_content = self._create_concatenated_content(conversation)
metadata = {}
if include_metadata:
metadata = {
"title": conversation.get("title", "ChatGPT Conversation"),
"timestamp": conversation.get("timestamp", "Unknown"),
"message_count": len(conversation.get("messages", [])),
"source": "ChatGPT Export",
}
doc = Document(text=doc_content, metadata=metadata)
docs.append(doc)
count += 1
else:
# Create separate documents for each message
for message in conversation.get("messages", []):
if max_count > 0 and count >= max_count:
break
role = message.get("role", "mixed")
content = message.get("content", "")
msg_timestamp = message.get("timestamp", "")
if not content.strip():
continue
# Create document content with context
doc_content = f"""Conversation: {conversation.get("title", "ChatGPT Conversation")}
Role: {role}
Timestamp: {msg_timestamp or conversation.get("timestamp", "Unknown")}
Message: {content}
"""
metadata = {}
if include_metadata:
metadata = {
"conversation_title": conversation.get("title", "ChatGPT Conversation"),
"role": role,
"timestamp": msg_timestamp or conversation.get("timestamp", "Unknown"),
"source": "ChatGPT Export",
}
doc = Document(text=doc_content, metadata=metadata)
docs.append(doc)
count += 1
print(f"Created {len(docs)} documents from ChatGPT export")
return docs

View File

@@ -1,186 +0,0 @@
"""
ChatGPT RAG example using the unified interface.
Supports ChatGPT export data from chat.html files.
"""
import sys
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from base_rag_example import BaseRAGExample
from chunking import create_text_chunks
from chatgpt_data.chatgpt_reader import ChatGPTReader
class ChatGPTRAG(BaseRAGExample):
"""RAG example for ChatGPT conversation data."""
def __init__(self):
# Set default values BEFORE calling super().__init__
self.max_items_default = -1 # Process all conversations by default
self.embedding_model_default = (
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
)
super().__init__(
name="ChatGPT",
description="Process and query ChatGPT conversation exports with LEANN",
default_index_name="chatgpt_conversations_index",
)
def _add_specific_arguments(self, parser):
"""Add ChatGPT-specific arguments."""
chatgpt_group = parser.add_argument_group("ChatGPT Parameters")
chatgpt_group.add_argument(
"--export-path",
type=str,
default="./chatgpt_export",
help="Path to ChatGPT export file (.zip or .html) or directory containing exports (default: ./chatgpt_export)",
)
chatgpt_group.add_argument(
"--concatenate-conversations",
action="store_true",
default=True,
help="Concatenate messages within conversations for better context (default: True)",
)
chatgpt_group.add_argument(
"--separate-messages",
action="store_true",
help="Process each message as a separate document (overrides --concatenate-conversations)",
)
chatgpt_group.add_argument(
"--chunk-size", type=int, default=512, help="Text chunk size (default: 512)"
)
chatgpt_group.add_argument(
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
)
def _find_chatgpt_exports(self, export_path: Path) -> list[Path]:
"""
Find ChatGPT export files in the given path.
Args:
export_path: Path to search for exports
Returns:
List of paths to ChatGPT export files
"""
export_files = []
if export_path.is_file():
if export_path.suffix.lower() in [".zip", ".html"]:
export_files.append(export_path)
elif export_path.is_dir():
# Look for zip and html files
export_files.extend(export_path.glob("*.zip"))
export_files.extend(export_path.glob("*.html"))
return export_files
async def load_data(self, args) -> list[str]:
"""Load ChatGPT export data and convert to text chunks."""
export_path = Path(args.export_path)
if not export_path.exists():
print(f"ChatGPT export path not found: {export_path}")
print(
"Please ensure you have exported your ChatGPT data and placed it in the correct location."
)
print("\nTo export your ChatGPT data:")
print("1. Sign in to ChatGPT")
print("2. Click on your profile icon → Settings → Data Controls")
print("3. Click 'Export' under Export Data")
print("4. Download the zip file from the email link")
print("5. Extract or place the file/directory at the specified path")
return []
# Find export files
export_files = self._find_chatgpt_exports(export_path)
if not export_files:
print(f"No ChatGPT export files (.zip or .html) found in: {export_path}")
return []
print(f"Found {len(export_files)} ChatGPT export files")
# Create reader with appropriate settings
concatenate = args.concatenate_conversations and not args.separate_messages
reader = ChatGPTReader(concatenate_conversations=concatenate)
# Process each export file
all_documents = []
total_processed = 0
for i, export_file in enumerate(export_files):
print(f"\nProcessing export file {i + 1}/{len(export_files)}: {export_file.name}")
try:
# Apply max_items limit per file
max_per_file = -1
if args.max_items > 0:
remaining = args.max_items - total_processed
if remaining <= 0:
break
max_per_file = remaining
# Load conversations
documents = reader.load_data(
chatgpt_export_path=str(export_file),
max_count=max_per_file,
include_metadata=True,
)
if documents:
all_documents.extend(documents)
total_processed += len(documents)
print(f"Processed {len(documents)} conversations from this file")
else:
print(f"No conversations loaded from {export_file}")
except Exception as e:
print(f"Error processing {export_file}: {e}")
continue
if not all_documents:
print("No conversations found to process!")
print("\nTroubleshooting:")
print("- Ensure the export file is a valid ChatGPT export")
print("- Check that the HTML file contains conversation data")
print("- Try extracting the zip file and pointing to the HTML file directly")
return []
print(f"\nTotal conversations processed: {len(all_documents)}")
print("Now starting to split into text chunks... this may take some time")
# Convert to text chunks
all_texts = create_text_chunks(
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
)
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} conversations")
return all_texts
if __name__ == "__main__":
import asyncio
# Example queries for ChatGPT RAG
print("\n🤖 ChatGPT RAG Example")
print("=" * 50)
print("\nExample queries you can try:")
print("- 'What did I ask about Python programming?'")
print("- 'Show me conversations about machine learning'")
print("- 'Find discussions about travel planning'")
print("- 'What advice did ChatGPT give me about career development?'")
print("- 'Search for conversations about cooking recipes'")
print("\nTo get started:")
print("1. Export your ChatGPT data from Settings → Data Controls → Export")
print("2. Place the downloaded zip file or extracted HTML in ./chatgpt_export/")
print("3. Run this script to build your personal ChatGPT knowledge base!")
print("\nOr run without --query for interactive mode\n")
rag = ChatGPTRAG()
asyncio.run(rag.run())

View File

@@ -1,47 +0,0 @@
"""Unified chunking utilities facade.
This module re-exports the packaged utilities from `leann.chunking_utils` so
that both repo apps (importing `chunking`) and installed wheels share one
single implementation. When running from the repo without installation, it
adds the `packages/leann-core/src` directory to `sys.path` as a fallback.
"""
import sys
from pathlib import Path
try:
from leann.chunking_utils import (
CODE_EXTENSIONS,
_traditional_chunks_as_dicts,
create_ast_chunks,
create_text_chunks,
create_traditional_chunks,
detect_code_files,
get_language_from_extension,
)
except Exception: # pragma: no cover - best-effort fallback for dev environment
repo_root = Path(__file__).resolve().parents[2]
leann_src = repo_root / "packages" / "leann-core" / "src"
if leann_src.exists():
sys.path.insert(0, str(leann_src))
from leann.chunking_utils import (
CODE_EXTENSIONS,
_traditional_chunks_as_dicts,
create_ast_chunks,
create_text_chunks,
create_traditional_chunks,
detect_code_files,
get_language_from_extension,
)
else:
raise
__all__ = [
"CODE_EXTENSIONS",
"_traditional_chunks_as_dicts",
"create_ast_chunks",
"create_text_chunks",
"create_traditional_chunks",
"detect_code_files",
"get_language_from_extension",
]

View File

View File

@@ -1,420 +0,0 @@
"""
Claude export data reader.
Reads and processes Claude conversation data from exported JSON files.
"""
import json
from pathlib import Path
from typing import Any
from zipfile import ZipFile
from llama_index.core import Document
from llama_index.core.readers.base import BaseReader
class ClaudeReader(BaseReader):
"""
Claude export data reader.
Reads Claude conversation data from exported JSON files or zip archives.
Processes conversations into structured documents with metadata.
"""
def __init__(self, concatenate_conversations: bool = True) -> None:
"""
Initialize.
Args:
concatenate_conversations: Whether to concatenate messages within conversations for better context
"""
self.concatenate_conversations = concatenate_conversations
def _extract_json_from_zip(self, zip_path: Path) -> list[str]:
"""
Extract JSON files from Claude export zip file.
Args:
zip_path: Path to the Claude export zip file
Returns:
List of JSON content strings, or empty list if not found
"""
json_contents = []
try:
with ZipFile(zip_path, "r") as zip_file:
# Look for JSON files
json_files = [f for f in zip_file.namelist() if f.endswith(".json")]
if not json_files:
print(f"No JSON files found in {zip_path}")
return []
print(f"Found {len(json_files)} JSON files in archive")
for json_file in json_files:
with zip_file.open(json_file) as f:
content = f.read().decode("utf-8", errors="ignore")
json_contents.append(content)
except Exception as e:
print(f"Error extracting JSON from zip {zip_path}: {e}")
return json_contents
def _parse_claude_json(self, json_content: str) -> list[dict]:
"""
Parse Claude JSON export to extract conversations.
Args:
json_content: JSON content from Claude export
Returns:
List of conversation dictionaries
"""
try:
data = json.loads(json_content)
except json.JSONDecodeError as e:
print(f"Error parsing JSON: {e}")
return []
conversations = []
# Handle different possible JSON structures
if isinstance(data, list):
# If data is a list of conversations
for item in data:
conversation = self._extract_conversation_from_json(item)
if conversation:
conversations.append(conversation)
elif isinstance(data, dict):
# Check for common structures
if "conversations" in data:
# Structure: {"conversations": [...]}
for item in data["conversations"]:
conversation = self._extract_conversation_from_json(item)
if conversation:
conversations.append(conversation)
elif "messages" in data:
# Single conversation with messages
conversation = self._extract_conversation_from_json(data)
if conversation:
conversations.append(conversation)
else:
# Try to treat the whole object as a conversation
conversation = self._extract_conversation_from_json(data)
if conversation:
conversations.append(conversation)
return conversations
def _extract_conversation_from_json(self, conv_data: dict) -> dict | None:
"""
Extract conversation data from a JSON object.
Args:
conv_data: Dictionary containing conversation data
Returns:
Dictionary with conversation data or None
"""
if not isinstance(conv_data, dict):
return None
messages = []
# Look for messages in various possible structures
message_sources = []
if "messages" in conv_data:
message_sources = conv_data["messages"]
elif "chat" in conv_data:
message_sources = conv_data["chat"]
elif "conversation" in conv_data:
message_sources = conv_data["conversation"]
else:
# If no clear message structure, try to extract from the object itself
if "content" in conv_data and "role" in conv_data:
message_sources = [conv_data]
for msg_data in message_sources:
message = self._extract_message_from_json(msg_data)
if message:
messages.append(message)
if not messages:
return None
# Extract conversation metadata
title = self._extract_title_from_conversation(conv_data, messages)
timestamp = self._extract_timestamp_from_conversation(conv_data)
return {"title": title, "messages": messages, "timestamp": timestamp}
def _extract_message_from_json(self, msg_data: dict) -> dict | None:
"""
Extract message data from a JSON message object.
Args:
msg_data: Dictionary containing message data
Returns:
Dictionary with message data or None
"""
if not isinstance(msg_data, dict):
return None
# Extract content from various possible fields
content = ""
content_fields = ["content", "text", "message", "body"]
for field in content_fields:
if msg_data.get(field):
content = str(msg_data[field])
break
if not content or len(content.strip()) < 3:
return None
# Extract role (user/assistant/human/ai/claude)
role = "mixed" # Default role
role_fields = ["role", "sender", "from", "author", "type"]
for field in role_fields:
if msg_data.get(field):
role_value = str(msg_data[field]).lower()
if role_value in ["user", "human", "person"]:
role = "user"
elif role_value in ["assistant", "ai", "claude", "bot"]:
role = "assistant"
break
# Extract timestamp
timestamp = self._extract_timestamp_from_message(msg_data)
return {"role": role, "content": content, "timestamp": timestamp}
def _extract_timestamp_from_message(self, msg_data: dict) -> str | None:
"""Extract timestamp from message data."""
timestamp_fields = ["timestamp", "created_at", "date", "time"]
for field in timestamp_fields:
if msg_data.get(field):
return str(msg_data[field])
return None
def _extract_timestamp_from_conversation(self, conv_data: dict) -> str | None:
"""Extract timestamp from conversation data."""
timestamp_fields = ["timestamp", "created_at", "date", "updated_at", "last_updated"]
for field in timestamp_fields:
if conv_data.get(field):
return str(conv_data[field])
return None
def _extract_title_from_conversation(self, conv_data: dict, messages: list) -> str:
"""Extract or generate title for conversation."""
# Try to find explicit title
title_fields = ["title", "name", "subject", "topic"]
for field in title_fields:
if conv_data.get(field):
return str(conv_data[field])
# Generate title from first user message
for message in messages:
if message.get("role") == "user":
content = message.get("content", "")
if content:
# Use first 50 characters as title
title = content[:50].strip()
if len(content) > 50:
title += "..."
return title
return "Claude Conversation"
def _create_concatenated_content(self, conversation: dict) -> str:
"""
Create concatenated content from conversation messages.
Args:
conversation: Dictionary containing conversation data
Returns:
Formatted concatenated content
"""
title = conversation.get("title", "Claude Conversation")
messages = conversation.get("messages", [])
timestamp = conversation.get("timestamp", "Unknown")
# Build message content
message_parts = []
for message in messages:
role = message.get("role", "mixed")
content = message.get("content", "")
msg_timestamp = message.get("timestamp", "")
if role == "user":
prefix = "[You]"
elif role == "assistant":
prefix = "[Claude]"
else:
prefix = "[Message]"
# Add timestamp if available
if msg_timestamp:
prefix += f" ({msg_timestamp})"
message_parts.append(f"{prefix}: {content}")
concatenated_text = "\n\n".join(message_parts)
# Create final document content
doc_content = f"""Conversation: {title}
Date: {timestamp}
Messages ({len(messages)} messages):
{concatenated_text}
"""
return doc_content
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
"""
Load Claude export data.
Args:
input_dir: Directory containing Claude export files or path to specific file
**load_kwargs:
max_count (int): Maximum number of conversations to process
claude_export_path (str): Specific path to Claude export file/directory
include_metadata (bool): Whether to include metadata in documents
"""
docs: list[Document] = []
max_count = load_kwargs.get("max_count", -1)
claude_export_path = load_kwargs.get("claude_export_path", input_dir)
include_metadata = load_kwargs.get("include_metadata", True)
if not claude_export_path:
print("No Claude export path provided")
return docs
export_path = Path(claude_export_path)
if not export_path.exists():
print(f"Claude export path not found: {export_path}")
return docs
json_contents = []
# Handle different input types
if export_path.is_file():
if export_path.suffix.lower() == ".zip":
# Extract JSON from zip file
json_contents = self._extract_json_from_zip(export_path)
elif export_path.suffix.lower() == ".json":
# Read JSON file directly
try:
with open(export_path, encoding="utf-8", errors="ignore") as f:
json_contents.append(f.read())
except Exception as e:
print(f"Error reading JSON file {export_path}: {e}")
return docs
else:
print(f"Unsupported file type: {export_path.suffix}")
return docs
elif export_path.is_dir():
# Look for JSON files in directory
json_files = list(export_path.glob("*.json"))
zip_files = list(export_path.glob("*.zip"))
if json_files:
print(f"Found {len(json_files)} JSON files in directory")
for json_file in json_files:
try:
with open(json_file, encoding="utf-8", errors="ignore") as f:
json_contents.append(f.read())
except Exception as e:
print(f"Error reading JSON file {json_file}: {e}")
continue
if zip_files:
print(f"Found {len(zip_files)} ZIP files in directory")
for zip_file in zip_files:
zip_contents = self._extract_json_from_zip(zip_file)
json_contents.extend(zip_contents)
if not json_files and not zip_files:
print(f"No JSON or ZIP files found in {export_path}")
return docs
if not json_contents:
print("No JSON content found to process")
return docs
# Parse conversations from JSON content
print("Parsing Claude conversations from JSON...")
all_conversations = []
for json_content in json_contents:
conversations = self._parse_claude_json(json_content)
all_conversations.extend(conversations)
if not all_conversations:
print("No conversations found in JSON content")
return docs
print(f"Found {len(all_conversations)} conversations")
# Process conversations into documents
count = 0
for conversation in all_conversations:
if max_count > 0 and count >= max_count:
break
if self.concatenate_conversations:
# Create one document per conversation with concatenated messages
doc_content = self._create_concatenated_content(conversation)
metadata = {}
if include_metadata:
metadata = {
"title": conversation.get("title", "Claude Conversation"),
"timestamp": conversation.get("timestamp", "Unknown"),
"message_count": len(conversation.get("messages", [])),
"source": "Claude Export",
}
doc = Document(text=doc_content, metadata=metadata)
docs.append(doc)
count += 1
else:
# Create separate documents for each message
for message in conversation.get("messages", []):
if max_count > 0 and count >= max_count:
break
role = message.get("role", "mixed")
content = message.get("content", "")
msg_timestamp = message.get("timestamp", "")
if not content.strip():
continue
# Create document content with context
doc_content = f"""Conversation: {conversation.get("title", "Claude Conversation")}
Role: {role}
Timestamp: {msg_timestamp or conversation.get("timestamp", "Unknown")}
Message: {content}
"""
metadata = {}
if include_metadata:
metadata = {
"conversation_title": conversation.get("title", "Claude Conversation"),
"role": role,
"timestamp": msg_timestamp or conversation.get("timestamp", "Unknown"),
"source": "Claude Export",
}
doc = Document(text=doc_content, metadata=metadata)
docs.append(doc)
count += 1
print(f"Created {len(docs)} documents from Claude export")
return docs

View File

@@ -1,189 +0,0 @@
"""
Claude RAG example using the unified interface.
Supports Claude export data from JSON files.
"""
import sys
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from base_rag_example import BaseRAGExample
from chunking import create_text_chunks
from .claude_data.claude_reader import ClaudeReader
class ClaudeRAG(BaseRAGExample):
"""RAG example for Claude conversation data."""
def __init__(self):
# Set default values BEFORE calling super().__init__
self.max_items_default = -1 # Process all conversations by default
self.embedding_model_default = (
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
)
super().__init__(
name="Claude",
description="Process and query Claude conversation exports with LEANN",
default_index_name="claude_conversations_index",
)
def _add_specific_arguments(self, parser):
"""Add Claude-specific arguments."""
claude_group = parser.add_argument_group("Claude Parameters")
claude_group.add_argument(
"--export-path",
type=str,
default="./claude_export",
help="Path to Claude export file (.json or .zip) or directory containing exports (default: ./claude_export)",
)
claude_group.add_argument(
"--concatenate-conversations",
action="store_true",
default=True,
help="Concatenate messages within conversations for better context (default: True)",
)
claude_group.add_argument(
"--separate-messages",
action="store_true",
help="Process each message as a separate document (overrides --concatenate-conversations)",
)
claude_group.add_argument(
"--chunk-size", type=int, default=512, help="Text chunk size (default: 512)"
)
claude_group.add_argument(
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
)
def _find_claude_exports(self, export_path: Path) -> list[Path]:
"""
Find Claude export files in the given path.
Args:
export_path: Path to search for exports
Returns:
List of paths to Claude export files
"""
export_files = []
if export_path.is_file():
if export_path.suffix.lower() in [".zip", ".json"]:
export_files.append(export_path)
elif export_path.is_dir():
# Look for zip and json files
export_files.extend(export_path.glob("*.zip"))
export_files.extend(export_path.glob("*.json"))
return export_files
async def load_data(self, args) -> list[str]:
"""Load Claude export data and convert to text chunks."""
export_path = Path(args.export_path)
if not export_path.exists():
print(f"Claude export path not found: {export_path}")
print(
"Please ensure you have exported your Claude data and placed it in the correct location."
)
print("\nTo export your Claude data:")
print("1. Open Claude in your browser")
print("2. Look for export/download options in settings or conversation menu")
print("3. Download the conversation data (usually in JSON format)")
print("4. Place the file/directory at the specified path")
print(
"\nNote: Claude export methods may vary. Check Claude's help documentation for current instructions."
)
return []
# Find export files
export_files = self._find_claude_exports(export_path)
if not export_files:
print(f"No Claude export files (.json or .zip) found in: {export_path}")
return []
print(f"Found {len(export_files)} Claude export files")
# Create reader with appropriate settings
concatenate = args.concatenate_conversations and not args.separate_messages
reader = ClaudeReader(concatenate_conversations=concatenate)
# Process each export file
all_documents = []
total_processed = 0
for i, export_file in enumerate(export_files):
print(f"\nProcessing export file {i + 1}/{len(export_files)}: {export_file.name}")
try:
# Apply max_items limit per file
max_per_file = -1
if args.max_items > 0:
remaining = args.max_items - total_processed
if remaining <= 0:
break
max_per_file = remaining
# Load conversations
documents = reader.load_data(
claude_export_path=str(export_file),
max_count=max_per_file,
include_metadata=True,
)
if documents:
all_documents.extend(documents)
total_processed += len(documents)
print(f"Processed {len(documents)} conversations from this file")
else:
print(f"No conversations loaded from {export_file}")
except Exception as e:
print(f"Error processing {export_file}: {e}")
continue
if not all_documents:
print("No conversations found to process!")
print("\nTroubleshooting:")
print("- Ensure the export file is a valid Claude export")
print("- Check that the JSON file contains conversation data")
print("- Try using a different export format or method")
print("- Check Claude's documentation for current export procedures")
return []
print(f"\nTotal conversations processed: {len(all_documents)}")
print("Now starting to split into text chunks... this may take some time")
# Convert to text chunks
all_texts = create_text_chunks(
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
)
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} conversations")
return all_texts
if __name__ == "__main__":
import asyncio
# Example queries for Claude RAG
print("\n🤖 Claude RAG Example")
print("=" * 50)
print("\nExample queries you can try:")
print("- 'What did I ask Claude about Python programming?'")
print("- 'Show me conversations about machine learning'")
print("- 'Find discussions about code optimization'")
print("- 'What advice did Claude give me about software design?'")
print("- 'Search for conversations about debugging techniques'")
print("\nTo get started:")
print("1. Export your Claude conversation data")
print("2. Place the JSON/ZIP file in ./claude_export/")
print("3. Run this script to build your personal Claude knowledge base!")
print("\nOr run without --query for interactive mode\n")
rag = ClaudeRAG()
asyncio.run(rag.run())

View File

@@ -1,211 +0,0 @@
"""
Code RAG example using AST-aware chunking for optimal code understanding.
Specialized for code repositories with automatic language detection and
optimized chunking parameters.
"""
import sys
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from base_rag_example import BaseRAGExample
from chunking import CODE_EXTENSIONS, create_text_chunks
from llama_index.core import SimpleDirectoryReader
class CodeRAG(BaseRAGExample):
"""Specialized RAG example for code repositories with AST-aware chunking."""
def __init__(self):
super().__init__(
name="Code",
description="Process and query code repositories with AST-aware chunking",
default_index_name="code_index",
)
# Override defaults for code-specific usage
self.embedding_model_default = "facebook/contriever" # Good for code
self.max_items_default = -1 # Process all code files by default
def _add_specific_arguments(self, parser):
"""Add code-specific arguments."""
code_group = parser.add_argument_group("Code Repository Parameters")
code_group.add_argument(
"--repo-dir",
type=str,
default=".",
help="Code repository directory to index (default: current directory)",
)
code_group.add_argument(
"--include-extensions",
nargs="+",
default=list(CODE_EXTENSIONS.keys()),
help="File extensions to include (default: supported code extensions)",
)
code_group.add_argument(
"--exclude-dirs",
nargs="+",
default=[
".git",
"__pycache__",
"node_modules",
"venv",
".venv",
"build",
"dist",
"target",
],
help="Directories to exclude from indexing",
)
code_group.add_argument(
"--max-file-size",
type=int,
default=1000000, # 1MB
help="Maximum file size in bytes to process (default: 1MB)",
)
code_group.add_argument(
"--include-comments",
action="store_true",
help="Include comments in chunking (useful for documentation)",
)
code_group.add_argument(
"--preserve-imports",
action="store_true",
default=True,
help="Try to preserve import statements in chunks (default: True)",
)
async def load_data(self, args) -> list[str]:
"""Load code files and convert to AST-aware chunks."""
print(f"🔍 Scanning code repository: {args.repo_dir}")
print(f"📁 Including extensions: {args.include_extensions}")
print(f"🚫 Excluding directories: {args.exclude_dirs}")
# Check if repository directory exists
repo_path = Path(args.repo_dir)
if not repo_path.exists():
raise ValueError(f"Repository directory not found: {args.repo_dir}")
# Load code files with filtering
reader_kwargs = {
"recursive": True,
"encoding": "utf-8",
"required_exts": args.include_extensions,
"exclude_hidden": True,
}
# Create exclusion filter
def file_filter(file_path: str) -> bool:
"""Filter out unwanted files and directories."""
path = Path(file_path)
# Check file size
try:
if path.stat().st_size > args.max_file_size:
print(f"⚠️ Skipping large file: {path.name} ({path.stat().st_size} bytes)")
return False
except Exception:
return False
# Check if in excluded directory
for exclude_dir in args.exclude_dirs:
if exclude_dir in path.parts:
return False
return True
try:
# Load documents with file filtering
documents = SimpleDirectoryReader(
args.repo_dir,
file_extractor=None, # Use default extractors
**reader_kwargs,
).load_data(show_progress=True)
# Apply custom filtering
filtered_docs = []
for doc in documents:
file_path = doc.metadata.get("file_path", "")
if file_filter(file_path):
filtered_docs.append(doc)
documents = filtered_docs
except Exception as e:
print(f"❌ Error loading code files: {e}")
return []
if not documents:
print(
f"❌ No code files found in {args.repo_dir} with extensions {args.include_extensions}"
)
return []
print(f"✅ Loaded {len(documents)} code files")
# Show breakdown by language/extension
ext_counts = {}
for doc in documents:
file_path = doc.metadata.get("file_path", "")
if file_path:
ext = Path(file_path).suffix.lower()
ext_counts[ext] = ext_counts.get(ext, 0) + 1
print("📊 Files by extension:")
for ext, count in sorted(ext_counts.items()):
print(f" {ext}: {count} files")
# Use AST-aware chunking by default for code
print(
f"🧠 Using AST-aware chunking (chunk_size: {args.ast_chunk_size}, overlap: {args.ast_chunk_overlap})"
)
all_texts = create_text_chunks(
documents,
chunk_size=256, # Fallback for non-code files
chunk_overlap=64,
use_ast_chunking=True, # Always use AST for code RAG
ast_chunk_size=args.ast_chunk_size,
ast_chunk_overlap=args.ast_chunk_overlap,
code_file_extensions=args.include_extensions,
ast_fallback_traditional=True,
)
# Apply max_items limit if specified
if args.max_items > 0 and len(all_texts) > args.max_items:
print(f"⏳ Limiting to {args.max_items} chunks (from {len(all_texts)})")
all_texts = all_texts[: args.max_items]
print(f"✅ Generated {len(all_texts)} code chunks")
return all_texts
if __name__ == "__main__":
import asyncio
# Example queries for code RAG
print("\n💻 Code RAG Example")
print("=" * 50)
print("\nExample queries you can try:")
print("- 'How does the embedding computation work?'")
print("- 'What are the main classes in this codebase?'")
print("- 'Show me the search implementation'")
print("- 'How is error handling implemented?'")
print("- 'What design patterns are used?'")
print("- 'Explain the chunking logic'")
print("\n🚀 Features:")
print("- ✅ AST-aware chunking preserves code structure")
print("- ✅ Automatic language detection")
print("- ✅ Smart filtering of large files and common excludes")
print("- ✅ Optimized for code understanding")
print("\nUsage examples:")
print(" python -m apps.code_rag --repo-dir ./my_project")
print(
" python -m apps.code_rag --include-extensions .py .js --query 'How does authentication work?'"
)
print("\nOr run without --query for interactive mode\n")
rag = CodeRAG()
asyncio.run(rag.run())

View File

@@ -1,131 +0,0 @@
"""
Document RAG example using the unified interface.
Supports PDF, TXT, MD, and other document formats.
"""
import sys
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from base_rag_example import BaseRAGExample
from chunking import create_text_chunks
from llama_index.core import SimpleDirectoryReader
class DocumentRAG(BaseRAGExample):
"""RAG example for document processing (PDF, TXT, MD, etc.)."""
def __init__(self):
super().__init__(
name="Document",
description="Process and query documents (PDF, TXT, MD, etc.) with LEANN",
default_index_name="test_doc_files",
)
def _add_specific_arguments(self, parser):
"""Add document-specific arguments."""
doc_group = parser.add_argument_group("Document Parameters")
doc_group.add_argument(
"--data-dir",
type=str,
default="data",
help="Directory containing documents to index (default: data)",
)
doc_group.add_argument(
"--file-types",
nargs="+",
default=None,
help="Filter by file types (e.g., .pdf .txt .md). If not specified, all supported types are processed",
)
doc_group.add_argument(
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
)
doc_group.add_argument(
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
)
doc_group.add_argument(
"--enable-code-chunking",
action="store_true",
help="Enable AST-aware chunking for code files in the data directory",
)
async def load_data(self, args) -> list[str]:
"""Load documents and convert to text chunks."""
print(f"Loading documents from: {args.data_dir}")
if args.file_types:
print(f"Filtering by file types: {args.file_types}")
else:
print("Processing all supported file types")
# Check if data directory exists
data_path = Path(args.data_dir)
if not data_path.exists():
raise ValueError(f"Data directory not found: {args.data_dir}")
# Load documents
reader_kwargs = {
"recursive": True,
"encoding": "utf-8",
}
if args.file_types:
reader_kwargs["required_exts"] = args.file_types
documents = SimpleDirectoryReader(args.data_dir, **reader_kwargs).load_data(
show_progress=True
)
if not documents:
print(f"No documents found in {args.data_dir} with extensions {args.file_types}")
return []
print(f"Loaded {len(documents)} documents")
# Determine chunking strategy
use_ast = args.enable_code_chunking or getattr(args, "use_ast_chunking", False)
if use_ast:
print("Using AST-aware chunking for code files")
# Convert to text chunks with optional AST support
all_texts = create_text_chunks(
documents,
chunk_size=args.chunk_size,
chunk_overlap=args.chunk_overlap,
use_ast_chunking=use_ast,
ast_chunk_size=getattr(args, "ast_chunk_size", 512),
ast_chunk_overlap=getattr(args, "ast_chunk_overlap", 64),
code_file_extensions=getattr(args, "code_file_extensions", None),
ast_fallback_traditional=getattr(args, "ast_fallback_traditional", True),
)
# Apply max_items limit if specified
if args.max_items > 0 and len(all_texts) > args.max_items:
print(f"Limiting to {args.max_items} chunks (from {len(all_texts)})")
all_texts = all_texts[: args.max_items]
return all_texts
if __name__ == "__main__":
import asyncio
# Example queries for document RAG
print("\n📄 Document RAG Example")
print("=" * 50)
print("\nExample queries you can try:")
print("- 'What are the main techniques LEANN uses?'")
print("- 'What is the technique DLPM?'")
print("- 'Who does Elizabeth Bennet marry?'")
print(
"- 'What is the problem of developing pan gu model Huawei meets? (盘古大模型开发中遇到什么问题?)'"
)
print("\n🚀 NEW: Code-aware chunking available!")
print("- Use --enable-code-chunking to enable AST-aware chunking for code files")
print("- Supports Python, Java, C#, TypeScript files")
print("- Better semantic understanding of code structure")
print("\nOr run without --query for interactive mode\n")
rag = DocumentRAG()
asyncio.run(rag.run())

View File

@@ -1,157 +0,0 @@
"""
Email RAG example using the unified interface.
Supports Apple Mail on macOS.
"""
import sys
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from base_rag_example import BaseRAGExample
from chunking import create_text_chunks
from .email_data.LEANN_email_reader import EmlxReader
class EmailRAG(BaseRAGExample):
"""RAG example for Apple Mail processing."""
def __init__(self):
# Set default values BEFORE calling super().__init__
self.max_items_default = -1 # Process all emails by default
self.embedding_model_default = (
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
)
super().__init__(
name="Email",
description="Process and query Apple Mail emails with LEANN",
default_index_name="mail_index",
)
def _add_specific_arguments(self, parser):
"""Add email-specific arguments."""
email_group = parser.add_argument_group("Email Parameters")
email_group.add_argument(
"--mail-path",
type=str,
default=None,
help="Path to Apple Mail directory (auto-detected if not specified)",
)
email_group.add_argument(
"--include-html", action="store_true", help="Include HTML content in email processing"
)
email_group.add_argument(
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
)
email_group.add_argument(
"--chunk-overlap", type=int, default=25, help="Text chunk overlap (default: 25)"
)
def _find_mail_directories(self) -> list[Path]:
"""Auto-detect all Apple Mail directories."""
mail_base = Path.home() / "Library" / "Mail"
if not mail_base.exists():
return []
# Find all Messages directories
messages_dirs = []
for item in mail_base.rglob("Messages"):
if item.is_dir():
messages_dirs.append(item)
return messages_dirs
async def load_data(self, args) -> list[str]:
"""Load emails and convert to text chunks."""
# Determine mail directories
if args.mail_path:
messages_dirs = [Path(args.mail_path)]
else:
print("Auto-detecting Apple Mail directories...")
messages_dirs = self._find_mail_directories()
if not messages_dirs:
print("No Apple Mail directories found!")
print("Please specify --mail-path manually")
return []
print(f"Found {len(messages_dirs)} mail directories")
# Create reader
reader = EmlxReader(include_html=args.include_html)
# Process each directory
all_documents = []
total_processed = 0
for i, messages_dir in enumerate(messages_dirs):
print(f"\nProcessing directory {i + 1}/{len(messages_dirs)}: {messages_dir}")
try:
# Count emlx files
emlx_files = list(messages_dir.glob("*.emlx"))
print(f"Found {len(emlx_files)} email files")
# Apply max_items limit per directory
max_per_dir = -1 # Default to process all
if args.max_items > 0:
remaining = args.max_items - total_processed
if remaining <= 0:
break
max_per_dir = remaining
# If args.max_items == -1, max_per_dir stays -1 (process all)
# Load emails - fix the parameter passing
documents = reader.load_data(
input_dir=str(messages_dir),
max_count=max_per_dir,
)
if documents:
all_documents.extend(documents)
total_processed += len(documents)
print(f"Processed {len(documents)} emails from this directory")
except Exception as e:
print(f"Error processing {messages_dir}: {e}")
continue
if not all_documents:
print("No emails found to process!")
return []
print(f"\nTotal emails processed: {len(all_documents)}")
print("now starting to split into text chunks ... take some time")
# Convert to text chunks
# Email reader uses chunk_overlap=25 as in original
all_texts = create_text_chunks(
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
)
return all_texts
if __name__ == "__main__":
import asyncio
# Check platform
if sys.platform != "darwin":
print("\n⚠️ Warning: This example is designed for macOS (Apple Mail)")
print(" Windows/Linux support coming soon!\n")
# Example queries for email RAG
print("\n📧 Email RAG Example")
print("=" * 50)
print("\nExample queries you can try:")
print("- 'What did my boss say about deadlines?'")
print("- 'Find emails about travel expenses'")
print("- 'Show me emails from last month about the project'")
print("- 'What food did I order from DoorDash?'")
print("\nNote: You may need to grant Full Disk Access to your terminal\n")
rag = EmailRAG()
asyncio.run(rag.run())

View File

@@ -1 +0,0 @@
"""iMessage data processing module."""

View File

@@ -1,342 +0,0 @@
"""
iMessage data reader.
Reads and processes iMessage conversation data from the macOS Messages database.
"""
import sqlite3
from datetime import datetime
from pathlib import Path
from typing import Any
from llama_index.core import Document
from llama_index.core.readers.base import BaseReader
class IMessageReader(BaseReader):
"""
iMessage data reader.
Reads iMessage conversation data from the macOS Messages database (chat.db).
Processes conversations into structured documents with metadata.
"""
def __init__(self, concatenate_conversations: bool = True) -> None:
"""
Initialize.
Args:
concatenate_conversations: Whether to concatenate messages within conversations for better context
"""
self.concatenate_conversations = concatenate_conversations
def _get_default_chat_db_path(self) -> Path:
"""
Get the default path to the iMessage chat database.
Returns:
Path to the chat.db file
"""
home = Path.home()
return home / "Library" / "Messages" / "chat.db"
def _convert_cocoa_timestamp(self, cocoa_timestamp: int) -> str:
"""
Convert Cocoa timestamp to readable format.
Args:
cocoa_timestamp: Timestamp in Cocoa format (nanoseconds since 2001-01-01)
Returns:
Formatted timestamp string
"""
if cocoa_timestamp == 0:
return "Unknown"
try:
# Cocoa timestamp is nanoseconds since 2001-01-01 00:00:00 UTC
# Convert to seconds and add to Unix epoch
cocoa_epoch = datetime(2001, 1, 1)
unix_timestamp = cocoa_timestamp / 1_000_000_000 # Convert nanoseconds to seconds
message_time = cocoa_epoch.timestamp() + unix_timestamp
return datetime.fromtimestamp(message_time).strftime("%Y-%m-%d %H:%M:%S")
except (ValueError, OSError):
return "Unknown"
def _get_contact_name(self, handle_id: str) -> str:
"""
Get a readable contact name from handle ID.
Args:
handle_id: The handle ID (phone number or email)
Returns:
Formatted contact name
"""
if not handle_id:
return "Unknown"
# Clean up phone numbers and emails for display
if "@" in handle_id:
return handle_id # Email address
elif handle_id.startswith("+"):
return handle_id # International phone number
else:
# Try to format as phone number
digits = "".join(filter(str.isdigit, handle_id))
if len(digits) == 10:
return f"({digits[:3]}) {digits[3:6]}-{digits[6:]}"
elif len(digits) == 11 and digits[0] == "1":
return f"+1 ({digits[1:4]}) {digits[4:7]}-{digits[7:]}"
else:
return handle_id
def _read_messages_from_db(self, db_path: Path) -> list[dict]:
"""
Read messages from the iMessage database.
Args:
db_path: Path to the chat.db file
Returns:
List of message dictionaries
"""
if not db_path.exists():
print(f"iMessage database not found at: {db_path}")
return []
try:
# Connect to the database
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
# Query to get messages with chat and handle information
query = """
SELECT
m.ROWID as message_id,
m.text,
m.date,
m.is_from_me,
m.service,
c.chat_identifier,
c.display_name as chat_display_name,
h.id as handle_id,
c.ROWID as chat_id
FROM message m
LEFT JOIN chat_message_join cmj ON m.ROWID = cmj.message_id
LEFT JOIN chat c ON cmj.chat_id = c.ROWID
LEFT JOIN handle h ON m.handle_id = h.ROWID
WHERE m.text IS NOT NULL AND m.text != ''
ORDER BY c.ROWID, m.date
"""
cursor.execute(query)
rows = cursor.fetchall()
messages = []
for row in rows:
(
message_id,
text,
date,
is_from_me,
service,
chat_identifier,
chat_display_name,
handle_id,
chat_id,
) = row
message = {
"message_id": message_id,
"text": text,
"timestamp": self._convert_cocoa_timestamp(date),
"is_from_me": bool(is_from_me),
"service": service or "iMessage",
"chat_identifier": chat_identifier or "Unknown",
"chat_display_name": chat_display_name or "Unknown Chat",
"handle_id": handle_id or "Unknown",
"contact_name": self._get_contact_name(handle_id or ""),
"chat_id": chat_id,
}
messages.append(message)
conn.close()
print(f"Found {len(messages)} messages in database")
return messages
except sqlite3.Error as e:
print(f"Error reading iMessage database: {e}")
return []
except Exception as e:
print(f"Unexpected error reading iMessage database: {e}")
return []
def _group_messages_by_chat(self, messages: list[dict]) -> dict[int, list[dict]]:
"""
Group messages by chat ID.
Args:
messages: List of message dictionaries
Returns:
Dictionary mapping chat_id to list of messages
"""
chats = {}
for message in messages:
chat_id = message["chat_id"]
if chat_id not in chats:
chats[chat_id] = []
chats[chat_id].append(message)
return chats
def _create_concatenated_content(self, chat_id: int, messages: list[dict]) -> str:
"""
Create concatenated content from chat messages.
Args:
chat_id: The chat ID
messages: List of messages in the chat
Returns:
Concatenated text content
"""
if not messages:
return ""
# Get chat info from first message
first_msg = messages[0]
chat_name = first_msg["chat_display_name"]
chat_identifier = first_msg["chat_identifier"]
# Build message content
message_parts = []
for message in messages:
timestamp = message["timestamp"]
is_from_me = message["is_from_me"]
text = message["text"]
contact_name = message["contact_name"]
if is_from_me:
prefix = "[You]"
else:
prefix = f"[{contact_name}]"
if timestamp != "Unknown":
prefix += f" ({timestamp})"
message_parts.append(f"{prefix}: {text}")
concatenated_text = "\n\n".join(message_parts)
doc_content = f"""Chat: {chat_name}
Identifier: {chat_identifier}
Messages ({len(messages)} messages):
{concatenated_text}
"""
return doc_content
def _create_individual_content(self, message: dict) -> str:
"""
Create content for individual message.
Args:
message: Message dictionary
Returns:
Formatted message content
"""
timestamp = message["timestamp"]
is_from_me = message["is_from_me"]
text = message["text"]
contact_name = message["contact_name"]
chat_name = message["chat_display_name"]
sender = "You" if is_from_me else contact_name
return f"""Message from {sender} in chat "{chat_name}"
Time: {timestamp}
Content: {text}
"""
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
"""
Load iMessage data and return as documents.
Args:
input_dir: Optional path to directory containing chat.db file.
If not provided, uses default macOS location.
**load_kwargs: Additional arguments (unused)
Returns:
List of Document objects containing iMessage data
"""
docs = []
# Determine database path
if input_dir:
db_path = Path(input_dir) / "chat.db"
else:
db_path = self._get_default_chat_db_path()
print(f"Reading iMessage database from: {db_path}")
# Read messages from database
messages = self._read_messages_from_db(db_path)
if not messages:
return docs
if self.concatenate_conversations:
# Group messages by chat and create concatenated documents
chats = self._group_messages_by_chat(messages)
for chat_id, chat_messages in chats.items():
if not chat_messages:
continue
content = self._create_concatenated_content(chat_id, chat_messages)
# Create metadata
first_msg = chat_messages[0]
last_msg = chat_messages[-1]
metadata = {
"source": "iMessage",
"chat_id": chat_id,
"chat_name": first_msg["chat_display_name"],
"chat_identifier": first_msg["chat_identifier"],
"message_count": len(chat_messages),
"first_message_date": first_msg["timestamp"],
"last_message_date": last_msg["timestamp"],
"participants": list(
{msg["contact_name"] for msg in chat_messages if not msg["is_from_me"]}
),
}
doc = Document(text=content, metadata=metadata)
docs.append(doc)
else:
# Create individual documents for each message
for message in messages:
content = self._create_individual_content(message)
metadata = {
"source": "iMessage",
"message_id": message["message_id"],
"chat_id": message["chat_id"],
"chat_name": message["chat_display_name"],
"chat_identifier": message["chat_identifier"],
"timestamp": message["timestamp"],
"is_from_me": message["is_from_me"],
"contact_name": message["contact_name"],
"service": message["service"],
}
doc = Document(text=content, metadata=metadata)
docs.append(doc)
print(f"Created {len(docs)} documents from iMessage data")
return docs

View File

@@ -1,125 +0,0 @@
"""
iMessage RAG Example.
This example demonstrates how to build a RAG system on your iMessage conversation history.
"""
import asyncio
from pathlib import Path
from leann.chunking_utils import create_text_chunks
from apps.base_rag_example import BaseRAGExample
from apps.imessage_data.imessage_reader import IMessageReader
class IMessageRAG(BaseRAGExample):
"""RAG example for iMessage conversation history."""
def __init__(self):
super().__init__(
name="iMessage",
description="RAG on your iMessage conversation history",
default_index_name="imessage_index",
)
def _add_specific_arguments(self, parser):
"""Add iMessage-specific arguments."""
imessage_group = parser.add_argument_group("iMessage Parameters")
imessage_group.add_argument(
"--db-path",
type=str,
default=None,
help="Path to iMessage chat.db file (default: ~/Library/Messages/chat.db)",
)
imessage_group.add_argument(
"--concatenate-conversations",
action="store_true",
default=True,
help="Concatenate messages within conversations for better context (default: True)",
)
imessage_group.add_argument(
"--no-concatenate-conversations",
action="store_true",
help="Process each message individually instead of concatenating by conversation",
)
imessage_group.add_argument(
"--chunk-size",
type=int,
default=1000,
help="Maximum characters per text chunk (default: 1000)",
)
imessage_group.add_argument(
"--chunk-overlap",
type=int,
default=200,
help="Overlap between text chunks (default: 200)",
)
async def load_data(self, args) -> list[str]:
"""Load iMessage history and convert to text chunks."""
print("Loading iMessage conversation history...")
# Determine concatenation setting
concatenate = args.concatenate_conversations and not args.no_concatenate_conversations
# Initialize iMessage reader
reader = IMessageReader(concatenate_conversations=concatenate)
# Load documents
try:
if args.db_path:
# Use custom database path
db_dir = str(Path(args.db_path).parent)
documents = reader.load_data(input_dir=db_dir)
else:
# Use default macOS location
documents = reader.load_data()
except Exception as e:
print(f"Error loading iMessage data: {e}")
print("\nTroubleshooting tips:")
print("1. Make sure you have granted Full Disk Access to your terminal/IDE")
print("2. Check that the iMessage database exists at ~/Library/Messages/chat.db")
print("3. Try specifying a custom path with --db-path if you have a backup")
return []
if not documents:
print("No iMessage conversations found!")
return []
print(f"Loaded {len(documents)} iMessage documents")
# Show some statistics
total_messages = sum(doc.metadata.get("message_count", 1) for doc in documents)
print(f"Total messages: {total_messages}")
if concatenate:
# Show chat statistics
chat_names = [doc.metadata.get("chat_name", "Unknown") for doc in documents]
unique_chats = len(set(chat_names))
print(f"Unique conversations: {unique_chats}")
# Convert to text chunks
all_texts = create_text_chunks(
documents,
chunk_size=args.chunk_size,
chunk_overlap=args.chunk_overlap,
)
# Apply max_items limit if specified
if args.max_items > 0:
all_texts = all_texts[: args.max_items]
print(f"Limited to {len(all_texts)} text chunks (max_items={args.max_items})")
return all_texts
async def main():
"""Main entry point."""
app = IMessageRAG()
await app.run()
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,113 +0,0 @@
## Vision-based PDF Multi-Vector Demos (macOS/MPS)
This folder contains two demos to index PDF pages as images and run multi-vector retrieval with ColPali/ColQwen2, plus optional similarity map visualization and answer generation.
### What you'll run
- `multi-vector-leann-paper-example.py`: local PDF → pages → embed → build HNSW index → search.
- `multi-vector-leann-similarity-map.py`: HF dataset (default) or local pages → embed → index → retrieve → similarity maps → optional Qwen-VL answer.
## Prerequisites (macOS)
### 1) Homebrew poppler (for pdf2image)
```bash
brew install poppler
which pdfinfo && pdfinfo -v
```
### 2) Python environment
Use uv (recommended) or pip. Python 3.9+.
Using uv:
```bash
uv pip install \
  colpali_engine \
  pdf2image \
  pillow \
  matplotlib \
  qwen_vl_utils \
  einops \
  seaborn
```
Notes:
- On first run, models are downloaded from Hugging Face; log in or configure credentials if needed.
- The scripts auto-select device: CUDA > MPS > CPU. Verify MPS:
```bash
python -c "import torch; print('MPS available:', bool(getattr(torch.backends, 'mps', None) and torch.backends.mps.is_available()))"
```
## Run the demos
### A) Local PDF example
Converts a local PDF into page images, embeds them, builds an index, and searches.
```bash
cd apps/multimodal/vision-based-pdf-multi-vector
# If you don't have the sample PDF locally, download it (ignored by Git)
mkdir -p pdfs
curl -L -o pdfs/2004.12832v2.pdf https://arxiv.org/pdf/2004.12832.pdf
ls pdfs/2004.12832v2.pdf
# Ensure output dir exists
mkdir -p pages
python multi-vector-leann-paper-example.py
```
Expected:
- Page images in `pages/`.
- Console prints like `Using device=mps, dtype=...` and retrieved file paths for queries.
To use your own PDF: edit `pdf_path` near the top of the script.
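For example, the edit might look like this (the `pdf_path` variable name comes from the script; the path below is just a placeholder for your own file):
```python
# Near the top of multi-vector-leann-paper-example.py (illustrative)
pdf_path = "pdfs/my_paper.pdf"  # placeholder: point this at your own PDF
```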
### B) Similarity map + answer demo
Uses HF dataset `weaviate/arXiv-AI-papers-multi-vector` by default; can switch to local pages.
```bash
cd apps/multimodal/vision-based-pdf-multi-vector
python multi-vector-leann-similarity-map.py
```
Artifacts (when enabled):
- Retrieved pages: `./figures/retrieved_page_rank{K}.png`
- Similarity maps: `./figures/similarity_map_rank{K}.png`
Key knobs in the script (top of file; a short sketch follows this list):
- `QUERY`: your question
- `MODEL`: `"colqwen2"` or `"colpali"`
- `USE_HF_DATASET`: set `False` to use local pages
- `PDF`, `PAGES_DIR`: for local mode
- `INDEX_PATH`, `TOPK`, `FIRST_STAGE_K`, `REBUILD_INDEX`
- `SIMILARITY_MAP`, `SIM_TOKEN_IDX`, `SIM_OUTPUT`
- `ANSWER`, `MAX_NEW_TOKENS` (Qwen-VL)
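A minimal sketch of how these knobs might look at the top of the script, using the names above; the values are illustrative, not the script's defaults:
```python
QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"
MODEL = "colqwen2"                    # or "colpali"
USE_HF_DATASET = True                 # set False to use local pages (PDF / PAGES_DIR)
PDF = "pdfs/2004.12832v2.pdf"         # only used in local mode
PAGES_DIR = "pages"                   # only used in local mode
INDEX_PATH = "./indexes/arxiv_demo"   # illustrative index location
TOPK = 1                              # pages to retrieve / visualize
FIRST_STAGE_K = 20                    # illustrative first-stage candidate count
REBUILD_INDEX = False                 # set True to force a rebuild
SIMILARITY_MAP = True                 # generate heatmaps for retrieved pages
SIM_TOKEN_IDX = -1                    # -1 auto-selects the most salient token
SIM_OUTPUT = "./figures/similarity_map.png"
ANSWER = False                        # enable Qwen-VL answer generation
MAX_NEW_TOKENS = 128                  # generation length for the answer
```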
## Troubleshooting
- pdf2image errors on macOS: ensure `brew install poppler` and `pdfinfo` works in terminal.
- Slow or OOM on MPS: reduce dataset size (e.g., set `MAX_DOCS`) or switch to CPU.
- NaNs on MPS: keep fp32 on MPS (default in similarity-map script); avoid fp16 there.
- First-run model downloads can be large; ensure network access (HF mirrors if needed).
## Notes
- Index files are under `./indexes/`. Delete or set `REBUILD_INDEX=True` to rebuild.
- For local PDFs, page images go to `./pages/`.
### Retrieval and Visualization Example
Example settings in `multi-vector-leann-similarity-map.py`:
- `QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"`
- `SIMILARITY_MAP = True` (to generate heatmaps)
- `TOPK = 1` (save the top retrieved page and its similarity map)
Run:
```bash
cd apps/multimodal/vision-based-pdf-multi-vector
python multi-vector-leann-similarity-map.py
```
Outputs (by default):
- Retrieved page: `./figures/retrieved_page_rank1.png`
- Similarity map: `./figures/similarity_map_rank1.png`
Sample visualization (example result; the query for this image was "How does Vim model performance and efficiency compare to other models?"):
![Similarity map example](fig/image.png)
Notes:
- Set `SIM_TOKEN_IDX` to visualize a specific token index; set `-1` to auto-select the most salient token.
- If you change `SIM_OUTPUT` to a file path (e.g., `./figures/my_map.png`), multiple ranks are saved as `my_map_rank{K}.png`.
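For instance, to pin the heatmap to one token and write maps to a custom file (illustrative values; names as listed above):
```python
SIM_TOKEN_IDX = 12                   # visualize token 12; -1 would auto-select
SIM_OUTPUT = "./figures/my_map.png"  # with TOPK > 1, saved as my_map_rank{K}.png per rank
```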

View File

Binary file not shown.


View File

@@ -1,780 +0,0 @@
import concurrent.futures
import json
import os
import re
import sys
from pathlib import Path
from typing import Any, Optional, cast
import numpy as np
from PIL import Image
from tqdm import tqdm
def _ensure_repo_paths_importable(current_file: str) -> None:
"""Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
_repo_root = Path(current_file).resolve().parents[3]
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
if str(_leann_core_src) not in sys.path:
sys.path.append(str(_leann_core_src))
if str(_leann_hnsw_pkg) not in sys.path:
sys.path.append(str(_leann_hnsw_pkg))
def _find_backend_module_file() -> Optional[Path]:
"""Best-effort locate the backend leann_multi_vector.py file, avoiding this file."""
this_file = Path(__file__).resolve()
candidates: list[Path] = []
# Common in-repo location
repo_root = this_file.parents[3]
candidates.append(repo_root / "packages" / "leann-backend-hnsw" / "leann_multi_vector.py")
candidates.append(
repo_root / "packages" / "leann-backend-hnsw" / "src" / "leann_multi_vector.py"
)
for cand in candidates:
try:
if cand.exists() and cand.resolve() != this_file:
return cand.resolve()
except Exception:
pass
# Fallback: scan sys.path for another leann_multi_vector.py different from this file
for p in list(sys.path):
try:
cand = Path(p) / "leann_multi_vector.py"
if cand.exists() and cand.resolve() != this_file:
return cand.resolve()
except Exception:
continue
return None
_BACKEND_LEANN_CLASS: Optional[type] = None
def _get_backend_leann_multi_vector() -> type:
"""Load backend LeannMultiVector class even if this file shadows its module name."""
global _BACKEND_LEANN_CLASS
if _BACKEND_LEANN_CLASS is not None:
return _BACKEND_LEANN_CLASS
backend_path = _find_backend_module_file()
if backend_path is None:
# Fallback to local implementation in this module
try:
cls = LeannMultiVector # type: ignore[name-defined]
_BACKEND_LEANN_CLASS = cls
return cls
except Exception as e:
raise ImportError(
"Could not locate backend 'leann_multi_vector.py' and no local implementation found. "
"Ensure the leann backend is available under packages/leann-backend-hnsw or installed."
) from e
import importlib.util
module_name = "leann_hnsw_backend_module"
spec = importlib.util.spec_from_file_location(module_name, str(backend_path))
if spec is None or spec.loader is None:
raise ImportError(f"Failed to create spec for backend module at {backend_path}")
backend_module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = backend_module
spec.loader.exec_module(backend_module) # type: ignore[assignment]
if not hasattr(backend_module, "LeannMultiVector"):
raise ImportError(f"'LeannMultiVector' not found in backend module at {backend_path}")
_BACKEND_LEANN_CLASS = backend_module.LeannMultiVector
return _BACKEND_LEANN_CLASS
def _natural_sort_key(name: str) -> int:
m = re.search(r"\d+", name)
return int(m.group()) if m else 0
def _load_images_from_dir(pages_dir: str) -> tuple[list[str], list[Image.Image]]:
filenames = [n for n in os.listdir(pages_dir) if n.lower().endswith((".png", ".jpg", ".jpeg"))]
filenames = sorted(filenames, key=_natural_sort_key)
filepaths = [os.path.join(pages_dir, n) for n in filenames]
images = [Image.open(p) for p in filepaths]
return filepaths, images
def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None:
if not pdf_path:
return
os.makedirs(pages_dir, exist_ok=True)
try:
from pdf2image import convert_from_path
except Exception as e:
raise RuntimeError(
"pdf2image is required to convert PDF to images. Install via pip install pdf2image"
) from e
images = convert_from_path(pdf_path, dpi=dpi)
for i, image in enumerate(images):
image.save(os.path.join(pages_dir, f"page_{i + 1}.png"), "PNG")
def _select_device_and_dtype():
import torch
from colpali_engine.utils.torch_utils import get_torch_device
device_str = (
"cuda"
if torch.cuda.is_available()
else (
"mps"
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
else "cpu"
)
)
device = get_torch_device(device_str)
# Stable dtype selection to avoid NaNs:
# - CUDA: prefer bfloat16 if supported, else float16
# - MPS: use float32 (fp16 on MPS can produce NaNs in some ops)
# - CPU: float32
if device_str == "cuda":
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
try:
torch.backends.cuda.matmul.allow_tf32 = True # Better stability/perf on Ampere+
except Exception:
pass
elif device_str == "mps":
dtype = torch.float32
else:
dtype = torch.float32
return device_str, device, dtype
def _load_colvision(model_choice: str):
import torch
from colpali_engine.models import ColPali, ColQwen2, ColQwen2Processor
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
from transformers.utils.import_utils import is_flash_attn_2_available
device_str, device, dtype = _select_device_and_dtype()
if model_choice == "colqwen2":
model_name = "vidore/colqwen2-v1.0"
# On CPU/MPS we must avoid flash-attn and stay eager; on CUDA prefer flash-attn if available
attn_implementation = (
"flash_attention_2"
if (device_str == "cuda" and is_flash_attn_2_available())
else "eager"
)
model = ColQwen2.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
attn_implementation=attn_implementation,
).eval()
processor = ColQwen2Processor.from_pretrained(model_name)
else:
model_name = "vidore/colpali-v1.2"
model = ColPali.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map=device,
).eval()
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
return model_name, model, processor, device_str, device, dtype
def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
import torch
from colpali_engine.utils.torch_utils import ListDataset
from torch.utils.data import DataLoader
# Ensure deterministic eval and autocast for stability
model.eval()
dataloader = DataLoader(
dataset=ListDataset[Image.Image](images),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_images(x),
)
doc_vecs: list[Any] = []
for batch_doc in tqdm(dataloader, desc="Embedding images"):
with torch.no_grad():
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
# autocast on CUDA for bf16/fp16; on CPU/MPS stay in fp32
if model.device.type == "cuda":
with torch.autocast(
device_type="cuda",
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
):
embeddings_doc = model(**batch_doc)
else:
embeddings_doc = model(**batch_doc)
doc_vecs.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
return doc_vecs
def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
import torch
from colpali_engine.utils.torch_utils import ListDataset
from torch.utils.data import DataLoader
model.eval()
dataloader = DataLoader(
dataset=ListDataset[str](queries),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_queries(x),
)
q_vecs: list[Any] = []
for batch_query in tqdm(dataloader, desc="Embedding queries"):
with torch.no_grad():
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
if model.device.type == "cuda":
with torch.autocast(
device_type="cuda",
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
):
embeddings_query = model(**batch_query)
else:
embeddings_query = model(**batch_query)
q_vecs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
return q_vecs
def _build_index(
index_path: str, doc_vecs: list[Any], filepaths: list[str], images: list[Image.Image]
) -> Any:
LeannMultiVector = _get_backend_leann_multi_vector()
dim = int(doc_vecs[0].shape[-1])
retriever = LeannMultiVector(index_path=index_path, dim=dim)
retriever.create_collection()
for i, vec in enumerate(doc_vecs):
data = {
"colbert_vecs": vec.float().numpy(),
"doc_id": i,
"filepath": filepaths[i],
"image": images[i], # Include the original image
}
retriever.insert(data)
retriever.create_index()
return retriever
def _load_retriever_if_index_exists(index_path: str) -> Optional[Any]:
LeannMultiVector = _get_backend_leann_multi_vector()
index_base = Path(index_path)
# Check for the actual HNSW index file written by the backend + our sidecar files
index_file = index_base.parent / f"{index_base.stem}.index"
meta = index_base.parent / f"{index_base.name}.meta.json"
labels = index_base.parent / f"{index_base.name}.labels.json"
if index_file.exists() and meta.exists() and labels.exists():
try:
with open(meta, encoding="utf-8") as f:
meta_json = json.load(f)
dim = int(meta_json.get("dimensions", 128))
except Exception:
dim = 128
return LeannMultiVector(index_path=index_path, dim=dim)
return None
def _generate_similarity_map(
model,
processor,
image: Image.Image,
query: str,
token_idx: Optional[int] = None,
output_path: Optional[str] = None,
) -> tuple[int, float]:
import torch
from colpali_engine.interpretability import (
get_similarity_maps_from_embeddings,
plot_similarity_map,
)
batch_images = processor.process_images([image]).to(model.device)
batch_queries = processor.process_queries([query]).to(model.device)
with torch.no_grad():
image_embeddings = model.forward(**batch_images)
query_embeddings = model.forward(**batch_queries)
n_patches = processor.get_n_patches(
image_size=image.size,
spatial_merge_size=getattr(model, "spatial_merge_size", None),
)
image_mask = processor.get_image_mask(batch_images)
batched_similarity_maps = get_similarity_maps_from_embeddings(
image_embeddings=image_embeddings,
query_embeddings=query_embeddings,
n_patches=n_patches,
image_mask=image_mask,
)
similarity_maps = batched_similarity_maps[0]
# Determine token index if not provided: choose the token with highest max score
if token_idx is None:
per_token_max = similarity_maps.view(similarity_maps.shape[0], -1).max(dim=1).values
token_idx = int(per_token_max.argmax().item())
max_sim_score = similarity_maps[token_idx, :, :].max().item()
if output_path:
import matplotlib.pyplot as plt
fig, ax = plot_similarity_map(
image=image,
similarity_map=similarity_maps[token_idx],
figsize=(14, 14),
show_colorbar=False,
)
ax.set_title(f"Token #{token_idx}. MaxSim score: {max_sim_score:.2f}", fontsize=12)
os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.savefig(output_path, bbox_inches="tight")
plt.close(fig)
return token_idx, float(max_sim_score)
class QwenVL:
def __init__(self, device: str):
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from transformers.utils.import_utils import is_flash_attn_2_available
attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "eager"
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen2.5-VL-3B-Instruct",
torch_dtype="auto",
device_map=device,
attn_implementation=attn_implementation,
)
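# Bound the visual token budget: images are resized so each maps to between 256 and 1280 patches of 28x28 pixels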
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28
self.processor = AutoProcessor.from_pretrained(
"Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
)
def answer(self, query: str, images: list[Image.Image], max_new_tokens: int = 128) -> str:
import base64
from io import BytesIO
from qwen_vl_utils import process_vision_info
content = []
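# Encode each PIL image as a base64 data URI so it can be passed inline in the chat content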
for img in images:
buffer = BytesIO()
img.save(buffer, format="jpeg")
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
content.append({"type": "image", "image": f"data:image;base64,{img_base64}"})
content.append({"type": "text", "text": query})
messages = [{"role": "user", "content": content}]
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = self.processor(
text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
)
inputs = inputs.to(self.model.device)
generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
return self.processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
# Ensure repo paths are importable for dynamic backend loading
_ensure_repo_paths_importable(__file__)
from leann_backend_hnsw.hnsw_backend import HNSWBuilder, HNSWSearcher # noqa: E402
class LeannMultiVector:
def __init__(
self,
index_path: str,
dim: int = 128,
distance_metric: str = "mips",
m: int = 16,
ef_construction: int = 500,
is_compact: bool = False,
is_recompute: bool = False,
embedding_model_name: str = "colvision",
) -> None:
self.index_path = index_path
self.dim = dim
self.embedding_model_name = embedding_model_name
self._pending_items: list[dict] = []
self._backend_kwargs = {
"distance_metric": distance_metric,
"M": m,
"efConstruction": ef_construction,
"is_compact": is_compact,
"is_recompute": is_recompute,
}
self._labels_meta: list[dict] = []
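# Lazily-built map from doc_id to the row indices of that document's token embeddings (used for exact MaxSim reranking)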
self._docid_to_indices: dict[int, list[int]] | None = None
def _meta_dict(self) -> dict:
return {
"version": "1.0",
"backend_name": "hnsw",
"embedding_model": self.embedding_model_name,
"embedding_mode": "custom",
"dimensions": self.dim,
"backend_kwargs": self._backend_kwargs,
"is_compact": self._backend_kwargs.get("is_compact", True),
"is_pruned": self._backend_kwargs.get("is_compact", True)
and self._backend_kwargs.get("is_recompute", True),
}
def create_collection(self) -> None:
path = Path(self.index_path)
path.parent.mkdir(parents=True, exist_ok=True)
def insert(self, data: dict) -> None:
self._pending_items.append(
{
"doc_id": int(data["doc_id"]),
"filepath": data.get("filepath", ""),
"colbert_vecs": [np.asarray(v, dtype=np.float32) for v in data["colbert_vecs"]],
"image": data.get("image"), # PIL Image object (optional)
}
)
def _labels_path(self) -> Path:
index_path_obj = Path(self.index_path)
return index_path_obj.parent / f"{index_path_obj.name}.labels.json"
def _meta_path(self) -> Path:
index_path_obj = Path(self.index_path)
return index_path_obj.parent / f"{index_path_obj.name}.meta.json"
def _embeddings_path(self) -> Path:
index_path_obj = Path(self.index_path)
return index_path_obj.parent / f"{index_path_obj.name}.emb.npy"
def _images_dir_path(self) -> Path:
"""Directory where original images are stored."""
index_path_obj = Path(self.index_path)
return index_path_obj.parent / f"{index_path_obj.name}.images"
def create_index(self) -> None:
if not self._pending_items:
return
embeddings: list[np.ndarray] = []
labels_meta: list[dict] = []
# Create images directory if needed
images_dir = self._images_dir_path()
images_dir.mkdir(parents=True, exist_ok=True)
for item in self._pending_items:
doc_id = int(item["doc_id"])
filepath = item.get("filepath", "")
colbert_vecs = item["colbert_vecs"]
image = item.get("image")
# Save image if provided
image_path = ""
if image is not None and isinstance(image, Image.Image):
image_filename = f"doc_{doc_id}.png"
image_path = str(images_dir / image_filename)
image.save(image_path, "PNG")
for seq_id, vec in enumerate(colbert_vecs):
vec_np = np.asarray(vec, dtype=np.float32)
embeddings.append(vec_np)
labels_meta.append(
{
"id": f"{doc_id}:{seq_id}",
"doc_id": doc_id,
"seq_id": int(seq_id),
"filepath": filepath,
"image_path": image_path, # Store the path to the saved image
}
)
if not embeddings:
return
embeddings_np = np.vstack(embeddings).astype(np.float32)
print(f"Stacked token embeddings: shape {embeddings_np.shape}")
builder = HNSWBuilder(**{**self._backend_kwargs, "dimensions": self.dim})
ids = [str(i) for i in range(embeddings_np.shape[0])]
builder.build(embeddings_np, ids, self.index_path)
import json as _json
with open(self._meta_path(), "w", encoding="utf-8") as f:
_json.dump(self._meta_dict(), f, indent=2)
with open(self._labels_path(), "w", encoding="utf-8") as f:
_json.dump(labels_meta, f)
# Persist embeddings for exact reranking
np.save(self._embeddings_path(), embeddings_np)
self._labels_meta = labels_meta
def _load_labels_meta_if_needed(self) -> None:
if self._labels_meta:
return
labels_path = self._labels_path()
if labels_path.exists():
import json as _json
with open(labels_path, encoding="utf-8") as f:
self._labels_meta = _json.load(f)
def _build_docid_to_indices_if_needed(self) -> None:
if self._docid_to_indices is not None:
return
self._load_labels_meta_if_needed()
mapping: dict[int, list[int]] = {}
for idx, meta in enumerate(self._labels_meta):
try:
doc_id = int(meta["doc_id"]) # type: ignore[index]
except Exception:
continue
mapping.setdefault(doc_id, []).append(idx)
self._docid_to_indices = mapping
def search(
self, data: np.ndarray, topk: int, first_stage_k: int = 50
) -> list[tuple[float, int]]:
if data.ndim == 1:
data = data.reshape(1, -1)
if data.dtype != np.float32:
data = data.astype(np.float32)
self._load_labels_meta_if_needed()
searcher = HNSWSearcher(self.index_path, meta=self._meta_dict())
raw = searcher.search(
data,
first_stage_k,
recompute_embeddings=False,
complexity=128,
beam_width=1,
prune_ratio=0.0,
batch_size=0,
)
labels = raw.get("labels")
distances = raw.get("distances")
if labels is None or distances is None:
return []
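# Approximate document-level MaxSim: for each query token, keep the best score among its
# retrieved neighbors per document, then sum those per-token bests across query tokens.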
doc_scores: dict[int, float] = {}
B = len(labels)
for b in range(B):
per_doc_best: dict[int, float] = {}
for k, sid in enumerate(labels[b]):
try:
idx = int(sid)
except Exception:
continue
if 0 <= idx < len(self._labels_meta):
doc_id = int(self._labels_meta[idx]["doc_id"]) # type: ignore[index]
else:
continue
score = float(distances[b][k])
if (doc_id not in per_doc_best) or (score > per_doc_best[doc_id]):
per_doc_best[doc_id] = score
for doc_id, best_score in per_doc_best.items():
doc_scores[doc_id] = doc_scores.get(doc_id, 0.0) + best_score
scores = sorted(((v, k) for k, v in doc_scores.items()), key=lambda x: x[0], reverse=True)
return scores[:topk] if len(scores) >= topk else scores
def search_exact(
self,
data: np.ndarray,
topk: int,
*,
first_stage_k: int = 200,
max_workers: int = 32,
) -> list[tuple[float, int]]:
"""
High-precision MaxSim reranking over candidate documents.
Steps:
1) Run a first-stage ANN to collect candidate doc_ids (using seq-level neighbors).
2) For each candidate doc, load all its token embeddings and compute
MaxSim(query_tokens, doc_tokens) exactly: sum(max(dot(q_i, d_j))).
Returns top-k list of (score, doc_id).
"""
# Normalize inputs
if data.ndim == 1:
data = data.reshape(1, -1)
if data.dtype != np.float32:
data = data.astype(np.float32)
self._load_labels_meta_if_needed()
self._build_docid_to_indices_if_needed()
emb_path = self._embeddings_path()
if not emb_path.exists():
# Fallback to approximate if we don't have persisted embeddings
return self.search(data, topk, first_stage_k=first_stage_k)
# Memory-map embeddings to avoid loading all into RAM
all_embeddings = np.load(emb_path, mmap_mode="r")
if all_embeddings.dtype != np.float32:
all_embeddings = all_embeddings.astype(np.float32)
# First-stage ANN to collect candidate doc_ids
searcher = HNSWSearcher(self.index_path, meta=self._meta_dict())
raw = searcher.search(
data,
first_stage_k,
recompute_embeddings=False,
complexity=128,
beam_width=1,
prune_ratio=0.0,
batch_size=0,
)
labels = raw.get("labels")
if labels is None:
return []
candidate_doc_ids: set[int] = set()
for batch in labels:
for sid in batch:
try:
idx = int(sid)
except Exception:
continue
if 0 <= idx < len(self._labels_meta):
candidate_doc_ids.add(int(self._labels_meta[idx]["doc_id"])) # type: ignore[index]
# Exact scoring per doc (parallelized)
assert self._docid_to_indices is not None
def _score_one(doc_id: int) -> tuple[float, int]:
token_indices = self._docid_to_indices.get(doc_id, [])
if not token_indices:
return (0.0, doc_id)
doc_vecs = np.asarray(all_embeddings[token_indices], dtype=np.float32)
# (Q, D) x (P, D)^T -> (Q, P) then MaxSim over P, sum over Q
sim = np.dot(data, doc_vecs.T)
# nan-safe
sim = np.nan_to_num(sim, nan=-1e30, posinf=1e30, neginf=-1e30)
score = sim.max(axis=2).sum(axis=1) if sim.ndim == 3 else sim.max(axis=1).sum()
return (float(score), doc_id)
scores: list[tuple[float, int]] = []
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
futures = [ex.submit(_score_one, doc_id) for doc_id in candidate_doc_ids]
for fut in concurrent.futures.as_completed(futures):
scores.append(fut.result())
scores.sort(key=lambda x: x[0], reverse=True)
return scores[:topk] if len(scores) >= topk else scores
def search_exact_all(
self,
data: np.ndarray,
topk: int,
*,
max_workers: int = 32,
) -> list[tuple[float, int]]:
"""
Exact MaxSim over ALL documents (no ANN pre-filtering).
This computes, for each document, sum_i max_j dot(q_i, d_j).
It memory-maps the persisted token-embedding matrix for scalability.
"""
if data.ndim == 1:
data = data.reshape(1, -1)
if data.dtype != np.float32:
data = data.astype(np.float32)
self._load_labels_meta_if_needed()
self._build_docid_to_indices_if_needed()
emb_path = self._embeddings_path()
if not emb_path.exists():
return self.search(data, topk)
all_embeddings = np.load(emb_path, mmap_mode="r")
if all_embeddings.dtype != np.float32:
all_embeddings = all_embeddings.astype(np.float32)
assert self._docid_to_indices is not None
candidate_doc_ids = list(self._docid_to_indices.keys())
def _score_one(doc_id: int) -> tuple[float, int]:
token_indices = self._docid_to_indices.get(doc_id, [])
if not token_indices:
return (0.0, doc_id)
doc_vecs = np.asarray(all_embeddings[token_indices], dtype=np.float32)
sim = np.dot(data, doc_vecs.T)
sim = np.nan_to_num(sim, nan=-1e30, posinf=1e30, neginf=-1e30)
score = sim.max(axis=2).sum(axis=1) if sim.ndim == 3 else sim.max(axis=1).sum()
return (float(score), doc_id)
scores: list[tuple[float, int]] = []
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex:
futures = [ex.submit(_score_one, d) for d in candidate_doc_ids]
for fut in concurrent.futures.as_completed(futures):
scores.append(fut.result())
scores.sort(key=lambda x: x[0], reverse=True)
return scores[:topk] if len(scores) >= topk else scores
def get_image(self, doc_id: int) -> Optional[Image.Image]:
"""
Retrieve the original image for a given doc_id from the index.
Args:
doc_id: The document ID
Returns:
PIL Image object if found, None otherwise
"""
self._load_labels_meta_if_needed()
# Find the image_path for this doc_id (all seq_ids for same doc share the same image_path)
for meta in self._labels_meta:
if meta.get("doc_id") == doc_id:
image_path = meta.get("image_path", "")
if image_path and Path(image_path).exists():
return Image.open(image_path)
break
return None
def get_metadata(self, doc_id: int) -> Optional[dict]:
"""
Retrieve metadata for a given doc_id.
Args:
doc_id: The document ID
Returns:
Dictionary with metadata (filepath, image_path, etc.) if found, None otherwise
"""
self._load_labels_meta_if_needed()
for meta in self._labels_meta:
if meta.get("doc_id") == doc_id:
return {
"doc_id": doc_id,
"filepath": meta.get("filepath", ""),
"image_path": meta.get("image_path", ""),
}
return None

View File

@@ -1,112 +0,0 @@
# pip install pdf2image
# pip install pymilvus
# pip install colpali_engine
# pip install tqdm
# pip install pillow
import os
import re
import sys
from pathlib import Path
from typing import cast
from PIL import Image
from tqdm import tqdm
# Ensure local leann packages are importable before importing them
_repo_root = Path(__file__).resolve().parents[3]
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
if str(_leann_core_src) not in sys.path:
sys.path.append(str(_leann_core_src))
if str(_leann_hnsw_pkg) not in sys.path:
sys.path.append(str(_leann_hnsw_pkg))
import torch
from colpali_engine.models import ColPali
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
from colpali_engine.utils.torch_utils import ListDataset, get_torch_device
from torch.utils.data import DataLoader
# Auto-select device: CUDA > MPS (mac) > CPU
_device_str = (
"cuda"
if torch.cuda.is_available()
else (
"mps"
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
else "cpu"
)
)
device = get_torch_device(_device_str)
# Prefer fp16 on GPU/MPS, bfloat16 on CPU
_dtype = torch.float16 if _device_str in ("cuda", "mps") else torch.bfloat16
model_name = "vidore/colpali-v1.2"
model = ColPali.from_pretrained(
model_name,
torch_dtype=_dtype,
device_map=device,
).eval()
print(f"Using device={_device_str}, dtype={_dtype}")
queries = [
"How to end-to-end retrieval with ColBert",
"Where is ColBERT performance Table, including text representation results?",
]
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
dataloader = DataLoader(
dataset=ListDataset[str](queries),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_queries(x),
)
qs: list[torch.Tensor] = []
for batch_query in dataloader:
with torch.no_grad():
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
embeddings_query = model(**batch_query)
qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
print(qs[0].shape)
# %%
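# Sort page images numerically by the number embedded in each filename (e.g. page_2 before page_10)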
page_filenames = sorted(os.listdir("./pages"), key=lambda n: int(re.search(r"\d+", n).group()))
images = [Image.open(os.path.join("./pages", name)) for name in page_filenames]
dataloader = DataLoader(
dataset=ListDataset[Image.Image](images),
batch_size=1,
shuffle=False,
collate_fn=lambda x: processor.process_images(x),
)
ds: list[torch.Tensor] = []
for batch_doc in tqdm(dataloader):
with torch.no_grad():
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
embeddings_doc = model(**batch_doc)
ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
print(ds[0].shape)
# %%
# Build HNSW index via LeannMultiVector primitives and run search
# (assumes the LeannMultiVector class from leann_multi_vector.py is importable from this directory)
from leann_multi_vector import LeannMultiVector  # noqa: E402
index_path = "./indexes/colpali.leann"
retriever = LeannMultiVector(index_path=index_path, dim=int(ds[0].shape[-1]))
retriever.create_collection()
filepaths = [os.path.join("./pages", name) for name in page_filenames]
for i in range(len(filepaths)):
data = {
"colbert_vecs": ds[i].float().numpy(),
"doc_id": i,
"filepath": filepaths[i],
}
retriever.insert(data)
retriever.create_index()
for query in qs:
query_np = query.float().numpy()
result = retriever.search(query_np, topk=1)
print(filepaths[result[0][1]])

View File

@@ -1,209 +0,0 @@
## Jupyter-style notebook script
# %%
# uv pip install matplotlib qwen_vl_utils
import os
from typing import Any, Optional
from PIL import Image
from tqdm import tqdm
from leann_multi_vector import ( # utility functions/classes
_ensure_repo_paths_importable,
_load_images_from_dir,
_maybe_convert_pdf_to_images,
_load_colvision,
_embed_images,
_embed_queries,
_build_index,
_load_retriever_if_index_exists,
_generate_similarity_map,
QwenVL,
)
_ensure_repo_paths_importable(__file__)
# %%
# Config
os.environ["TOKENIZERS_PARALLELISM"] = "false"
QUERY = "The paper talk about the latent video generative model and data curation in the related work part?"
MODEL: str = "colqwen2" # "colpali" or "colqwen2"
# Data source: set to True to use the Hugging Face dataset example (recommended)
USE_HF_DATASET: bool = True
DATASET_NAME: str = "weaviate/arXiv-AI-papers-multi-vector"
DATASET_SPLIT: str = "train"
MAX_DOCS: Optional[int] = None # limit number of pages to index; None = all
# Local pages (used when USE_HF_DATASET == False)
PDF: Optional[str] = None # e.g., "./pdfs/2004.12832v2.pdf"
PAGES_DIR: str = "./pages"
# Index + retrieval settings
INDEX_PATH: str = "./indexes/colvision.leann"
TOPK: int = 3
FIRST_STAGE_K: int = 500
REBUILD_INDEX: bool = False
# Artifacts
SAVE_TOP_IMAGE: Optional[str] = "./figures/retrieved_page.png"
SIMILARITY_MAP: bool = True
SIM_TOKEN_IDX: int = 13 # -1 means auto-select the most salient token
SIM_OUTPUT: str = "./figures/similarity_map.png"
ANSWER: bool = True
MAX_NEW_TOKENS: int = 1024
# %%
# Step 1: Check if we can skip data loading (index already exists)
retriever: Optional[Any] = None
need_to_build_index = REBUILD_INDEX
if not REBUILD_INDEX:
retriever = _load_retriever_if_index_exists(INDEX_PATH)
if retriever is not None:
print(f"✓ Index loaded from {INDEX_PATH}")
print(f"✓ Images available at: {retriever._images_dir_path()}")
need_to_build_index = False
else:
print(f"Index not found, will build new index")
need_to_build_index = True
# Step 2: Load data only if we need to build the index
if need_to_build_index:
print("Loading dataset...")
if USE_HF_DATASET:
from datasets import load_dataset
dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
filepaths: list[str] = []
images: list[Image.Image] = []
for i in tqdm(range(N), desc="Loading dataset", total=N):
p = dataset[i]
# Compose a descriptive identifier for printing later
identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"
filepaths.append(identifier)
images.append(p["page_image"]) # PIL Image
else:
_maybe_convert_pdf_to_images(PDF, PAGES_DIR)
filepaths, images = _load_images_from_dir(PAGES_DIR)
if not images:
raise RuntimeError(
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
)
print(f"Loaded {len(images)} images")
else:
print("Skipping dataset loading (using existing index)")
filepaths = [] # Not needed when using existing index
images = [] # Not needed when using existing index
# %%
# Step 3: Load model and processor (only if we need to build index or perform search)
model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
print(f"Using model={model_name}, device={device_str}, dtype={dtype}")
# %%
# Step 4: Build index if needed
if need_to_build_index and retriever is None:
print("Building index...")
doc_vecs = _embed_images(model, processor, images)
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths, images)
print(f"✓ Index built and images saved to: {retriever._images_dir_path()}")
# Clear memory
del images, filepaths, doc_vecs
# Note: Images are now stored in the index, retriever will load them on-demand from disk
# %%
# Step 5: Embed query and search
q_vec = _embed_queries(model, processor, [QUERY])[0]
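# retriever.search() aggregates an approximate MaxSim over ANN neighbors; retriever.search_exact()
# is available for exact reranking of candidate pages if higher precision is needed.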
results = retriever.search(q_vec.float().numpy(), topk=TOPK)
if not results:
print("No results found.")
else:
print(f'Top {len(results)} results for query: "{QUERY}"')
top_images: list[Image.Image] = []
for rank, (score, doc_id) in enumerate(results, start=1):
# Retrieve image from index instead of memory
image = retriever.get_image(doc_id)
if image is None:
print(f"Warning: Could not retrieve image for doc_id {doc_id}")
continue
metadata = retriever.get_metadata(doc_id)
path = metadata.get("filepath", "unknown") if metadata else "unknown"
# For HF dataset, path is a descriptive identifier, not a real file path
print(f"{rank}) MaxSim: {score:.4f}, Page: {path}")
top_images.append(image)
if SAVE_TOP_IMAGE:
from pathlib import Path as _Path
base = _Path(SAVE_TOP_IMAGE)
base.parent.mkdir(parents=True, exist_ok=True)
for rank, img in enumerate(top_images[:TOPK], start=1):
if base.suffix:
out_path = base.parent / f"{base.stem}_rank{rank}{base.suffix}"
else:
out_path = base / f"retrieved_page_rank{rank}.png"
img.save(str(out_path))
# Print the retrieval score (document-level MaxSim) alongside the saved path
try:
score, _doc_id = results[rank - 1]
print(f"Saved retrieved page (rank {rank}) [MaxSim={score:.4f}] to: {out_path}")
except Exception:
print(f"Saved retrieved page (rank {rank}) to: {out_path}")
## TODO: strange result: the second page of DeepSeek-V2 is retrieved rather than the first page
# %%
# Step 6: Similarity maps for top-K results
if results and SIMILARITY_MAP:
token_idx = None if SIM_TOKEN_IDX < 0 else int(SIM_TOKEN_IDX)
from pathlib import Path as _Path
output_base = _Path(SIM_OUTPUT) if SIM_OUTPUT else None
for rank, img in enumerate(top_images[:TOPK], start=1):
if output_base:
if output_base.suffix:
out_dir = output_base.parent
out_name = f"{output_base.stem}_rank{rank}{output_base.suffix}"
out_path = str(out_dir / out_name)
else:
out_dir = output_base
out_dir.mkdir(parents=True, exist_ok=True)
out_path = str(out_dir / f"similarity_map_rank{rank}.png")
else:
out_path = None
chosen_idx, max_sim = _generate_similarity_map(
model=model,
processor=processor,
image=img,
query=QUERY,
token_idx=token_idx,
output_path=out_path,
)
if out_path:
print(
f"Saved similarity map for rank {rank}, token #{chosen_idx} (max={max_sim:.2f}) to: {out_path}"
)
else:
print(
f"Computed similarity map for rank {rank}, token #{chosen_idx} (max={max_sim:.2f})"
)
# %%
# Step 7: Optional answer generation
if results and ANSWER:
qwen = QwenVL(device=device_str)
response = qwen.answer(QUERY, top_images[:TOPK], max_new_tokens=MAX_NEW_TOKENS)
print("\nAnswer:")
print(response)

View File

@@ -1,183 +0,0 @@
#!/usr/bin/env python3
import re
import sys
from datetime import datetime, timedelta
from pathlib import Path
from leann import LeannSearcher
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
class TimeParser:
def __init__(self):
# Main pattern: captures optional fuzzy modifier, number, unit, and optional "ago"
self.pattern = r"(?:(around|about|roughly|approximately)\s+)?(\d+)\s+(hour|day|week|month|year)s?(?:\s+ago)?"
# Compile for performance
self.regex = re.compile(self.pattern, re.IGNORECASE)
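# Example: "about 3 weeks ago" -> fuzzy="about", number=3, unit="week"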
# Stop words to remove before regex parsing
self.stop_words = {
"in",
"at",
"of",
"by",
"as",
"me",
"the",
"a",
"an",
"and",
"any",
"find",
"search",
"list",
"ago",
"back",
"past",
"earlier",
}
def clean_text(self, text):
"""Remove stop words from text"""
words = text.split()
cleaned = " ".join(word for word in words if word.lower() not in self.stop_words)
return cleaned
def parse(self, text):
"""Extract all time expressions from text"""
# Clean text first
cleaned_text = self.clean_text(text)
matches = []
for match in self.regex.finditer(cleaned_text):
fuzzy = match.group(1) # "around", "about", etc.
number = int(match.group(2))
unit = match.group(3).lower()
matches.append(
{
"full_match": match.group(0),
"fuzzy": bool(fuzzy),
"number": number,
"unit": unit,
"range": self.calculate_range(number, unit, bool(fuzzy)),
}
)
return matches
def calculate_range(self, number, unit, is_fuzzy):
"""Convert to actual datetime range and return ISO format strings"""
units = {
"hour": timedelta(hours=number),
"day": timedelta(days=number),
"week": timedelta(weeks=number),
"month": timedelta(days=number * 30),
"year": timedelta(days=number * 365),
}
delta = units[unit]
now = datetime.now()
target = now - delta
if is_fuzzy:
buffer = delta * 0.2 # 20% buffer for fuzzy
start = (target - buffer).isoformat()
end = (target + buffer).isoformat()
else:
start = target.isoformat()
end = now.isoformat()
return (start, end)
def search_files(query, top_k=15):
"""Search the index and return results"""
# Parse time expressions
parser = TimeParser()
time_matches = parser.parse(query)
# Remove time expressions from query for semantic search
clean_query = query
if time_matches:
for match in time_matches:
clean_query = clean_query.replace(match["full_match"], "").strip()
# Check if clean_query is less than 4 characters
if len(clean_query) < 4:
print("Error: add more input for accurate results.")
return
# Single query to vector DB
searcher = LeannSearcher(INDEX_PATH)
results = searcher.search(
clean_query if clean_query else query, top_k=top_k, recompute_embeddings=False
)
# Filter by time if time expression found
if time_matches:
time_range = time_matches[0]["range"] # Use first time expression
start_time, end_time = time_range
filtered_results = []
for result in results:
# Access metadata attribute directly (not .get())
metadata = result.metadata if hasattr(result, "metadata") else {}
if metadata:
# Check modification date first, fall back to creation date
date_str = metadata.get("modification_date") or metadata.get("creation_date")
if date_str:
# Convert strings to datetime objects for proper comparison
try:
file_date = datetime.fromisoformat(date_str)
start_dt = datetime.fromisoformat(start_time)
end_dt = datetime.fromisoformat(end_time)
# Compare dates properly
if start_dt <= file_date <= end_dt:
filtered_results.append(result)
except (ValueError, TypeError):
# Handle invalid date formats
print(f"Warning: Invalid date format in metadata: {date_str}")
continue
results = filtered_results
# Print results
print(f"\nSearch results for: '{query}'")
if time_matches:
print(
f"Time filter: {time_matches[0]['number']} {time_matches[0]['unit']}(s) {'(fuzzy)' if time_matches[0]['fuzzy'] else ''}"
)
print(
f"Date range: {time_matches[0]['range'][0][:10]} to {time_matches[0]['range'][1][:10]}"
)
print("-" * 80)
for i, result in enumerate(results, 1):
print(f"\n[{i}] Score: {result.score:.4f}")
print(f"Content: {result.text}")
# Show metadata if present
metadata = result.metadata if hasattr(result, "metadata") else None
if metadata:
if "creation_date" in metadata:
print(f"Created: {metadata['creation_date']}")
if "modification_date" in metadata:
print(f"Modified: {metadata['modification_date']}")
print("-" * 80)
if __name__ == "__main__":
if len(sys.argv) < 2:
print('Usage: python search_index.py "<search query>" [top_k]')
sys.exit(1)
query = sys.argv[1]
top_k = int(sys.argv[2]) if len(sys.argv) > 2 else 15
search_files(query, top_k)

View File

@@ -1,82 +0,0 @@
#!/usr/bin/env python3
import json
import sys
from pathlib import Path
from leann import LeannBuilder
def process_json_items(json_file_path):
"""Load and process JSON file with metadata items"""
with open(json_file_path, encoding="utf-8") as f:
items = json.load(f)
# Guard against empty JSON
if not items:
print("⚠️ No items found in the JSON file. Exiting gracefully.")
return
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
builder = LeannBuilder(backend_name="hnsw", is_recompute=False)
total_items = len(items)
items_added = 0
print(f"Processing {total_items} items...")
for idx, item in enumerate(items):
try:
# Create embedding text sentence
embedding_text = f"{item.get('Name', 'unknown')} located at {item.get('Path', 'unknown')} and size {item.get('Size', 'unknown')} bytes with content type {item.get('ContentType', 'unknown')} and kind {item.get('Kind', 'unknown')}"
# Prepare metadata with dates
metadata = {}
if "CreationDate" in item:
metadata["creation_date"] = item["CreationDate"]
if "ContentChangeDate" in item:
metadata["modification_date"] = item["ContentChangeDate"]
# Add to builder
builder.add_text(embedding_text, metadata=metadata)
items_added += 1
except Exception as e:
print(f"\n⚠️ Warning: Failed to process item {idx}: {e}")
continue
# Show progress
progress = (idx + 1) / total_items * 100
sys.stdout.write(f"\rProgress: {idx + 1}/{total_items} ({progress:.1f}%)")
sys.stdout.flush()
print() # New line after progress
# Guard against no successfully added items
if items_added == 0:
print("⚠️ No items were successfully added to the index. Exiting gracefully.")
return
print(f"\n✅ Successfully processed {items_added}/{total_items} items")
print("Building index...")
try:
builder.build_index(INDEX_PATH)
print(f"✓ Index saved to {INDEX_PATH}")
except ValueError as e:
if "No chunks added" in str(e):
print("⚠️ No chunks were added to the builder. Index not created.")
else:
raise
if __name__ == "__main__":
if len(sys.argv) != 2:
print("Usage: python build_index.py <json_file>")
sys.exit(1)
json_file = sys.argv[1]
if not Path(json_file).exists():
print(f"Error: File {json_file} not found")
sys.exit(1)
process_json_items(json_file)

View File

@@ -1,265 +0,0 @@
#!/usr/bin/env python3
"""
Spotlight Metadata Dumper for Vector DB
Extracts only essential metadata for semantic search embeddings
Output is optimized for vector database storage with minimal fields
"""
import json
import sys
from datetime import datetime
# Check platform before importing macOS-specific modules
if sys.platform != "darwin":
print("This script requires macOS (uses Spotlight)")
sys.exit(1)
from Foundation import NSDate, NSMetadataQuery, NSPredicate, NSRunLoop
# EDIT THIS LIST: Add or remove folders to search
# Can be either:
# - Folder names relative to home directory (e.g., "Desktop", "Downloads")
# - Absolute paths (e.g., "/Applications", "/System/Library")
SEARCH_FOLDERS = [
"Desktop",
"Downloads",
"Documents",
"Music",
"Pictures",
"Movies",
# "Library", # Uncomment to include
# "/Applications", # Absolute path example
# "Code/Projects", # Subfolder example
# Add any other folders here
]
def convert_to_serializable(obj):
"""Convert NS objects to Python serializable types"""
if obj is None:
return None
# Handle NSDate
if hasattr(obj, "timeIntervalSince1970"):
return datetime.fromtimestamp(obj.timeIntervalSince1970()).isoformat()
# Handle NSArray
if hasattr(obj, "count") and hasattr(obj, "objectAtIndex_"):
return [convert_to_serializable(obj.objectAtIndex_(i)) for i in range(obj.count())]
# Convert to string
try:
return str(obj)
except Exception:
return repr(obj)
def dump_spotlight_data(max_items=10, output_file="spotlight_dump.json"):
"""
Dump Spotlight data using public.item predicate
"""
# Build full paths from SEARCH_FOLDERS
import os
home_dir = os.path.expanduser("~")
search_paths = []
print("Search locations:")
for folder in SEARCH_FOLDERS:
# Check if it's an absolute path or relative
if folder.startswith("/"):
full_path = folder
else:
full_path = os.path.join(home_dir, folder)
if os.path.exists(full_path):
search_paths.append(full_path)
print(f"{full_path}")
else:
print(f"{full_path} (not found)")
if not search_paths:
print("No valid search paths found!")
return []
print(f"\nDumping {max_items} items from Spotlight (public.item)...")
# Create query with public.item predicate
query = NSMetadataQuery.alloc().init()
predicate = NSPredicate.predicateWithFormat_("kMDItemContentTypeTree CONTAINS 'public.item'")
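# 'public.item' sits at the root of the content-type tree, so this matches essentially every indexed file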
query.setPredicate_(predicate)
# Set search scopes to our specific folders
query.setSearchScopes_(search_paths)
print("Starting query...")
query.startQuery()
# Wait for gathering to complete
run_loop = NSRunLoop.currentRunLoop()
print("Gathering results...")
# Let it gather for a few seconds
for i in range(50): # 5 seconds max
run_loop.runMode_beforeDate_(
"NSDefaultRunLoopMode", NSDate.dateWithTimeIntervalSinceNow_(0.1)
)
# Check gathering status periodically
if i % 10 == 0:
current_count = query.resultCount()
if current_count > 0:
print(f" Found {current_count} items so far...")
# Continue while still gathering (up to 2 more seconds)
timeout = NSDate.dateWithTimeIntervalSinceNow_(2.0)
while query.isGathering() and timeout.timeIntervalSinceNow() > 0:
run_loop.runMode_beforeDate_(
"NSDefaultRunLoopMode", NSDate.dateWithTimeIntervalSinceNow_(0.1)
)
query.stopQuery()
total_results = query.resultCount()
print(f"Found {total_results} total items")
if total_results == 0:
print("No results found")
return []
# Process items
items_to_process = min(total_results, max_items)
results = []
# ONLY relevant attributes for vector embeddings
# These provide essential context for semantic search without bloat
attributes = [
"kMDItemPath", # Full path for file retrieval
"kMDItemFSName", # Filename for display & embedding
"kMDItemFSSize", # Size for filtering/ranking
"kMDItemContentType", # File type for categorization
"kMDItemKind", # Human-readable type for embedding
"kMDItemFSCreationDate", # Temporal context
"kMDItemFSContentChangeDate", # Recency for ranking
]
print(f"Processing {items_to_process} items...")
for i in range(items_to_process):
try:
item = query.resultAtIndex_(i)
metadata = {}
# Extract ONLY the relevant attributes
for attr in attributes:
try:
value = item.valueForAttribute_(attr)
if value is not None:
# Keep the attribute name clean (remove kMDItem prefix for cleaner JSON)
clean_key = attr.replace("kMDItem", "").replace("FS", "")
metadata[clean_key] = convert_to_serializable(value)
except (AttributeError, ValueError, TypeError):
continue
# Only add if we have at least a path
if metadata.get("Path"):
results.append(metadata)
except Exception as e:
print(f"Error processing item {i}: {e}")
continue
# Save to JSON
with open(output_file, "w", encoding="utf-8") as f:
json.dump(results, f, indent=2, ensure_ascii=False)
print(f"\n✓ Saved {len(results)} items to {output_file}")
# Show summary
print("\nSample items:")
import os
home_dir = os.path.expanduser("~")
for i, item in enumerate(results[:3]):
print(f"\n[Item {i + 1}]")
print(f" Path: {item.get('Path', 'N/A')}")
print(f" Name: {item.get('Name', 'N/A')}")
print(f" Type: {item.get('ContentType', 'N/A')}")
print(f" Kind: {item.get('Kind', 'N/A')}")
# Handle size properly
size = item.get("Size")
if size:
try:
size_int = int(size)
if size_int > 1024 * 1024:
print(f" Size: {size_int / (1024 * 1024):.2f} MB")
elif size_int > 1024:
print(f" Size: {size_int / 1024:.2f} KB")
else:
print(f" Size: {size_int} bytes")
except (ValueError, TypeError):
print(f" Size: {size}")
# Show dates
if "CreationDate" in item:
print(f" Created: {item['CreationDate']}")
if "ContentChangeDate" in item:
print(f" Modified: {item['ContentChangeDate']}")
# Count by type
type_counts = {}
for item in results:
content_type = item.get("ContentType", "unknown")
type_counts[content_type] = type_counts.get(content_type, 0) + 1
print(f"\nTotal items saved: {len(results)}")
if type_counts:
print("\nTop content types:")
for ct, count in sorted(type_counts.items(), key=lambda x: x[1], reverse=True)[:5]:
print(f" {ct}: {count} items")
# Count by folder
folder_counts = {}
for item in results:
path = item.get("Path", "")
for folder in SEARCH_FOLDERS:
# Build the full folder path
if folder.startswith("/"):
folder_path = folder
else:
folder_path = os.path.join(home_dir, folder)
if path.startswith(folder_path):
folder_counts[folder] = folder_counts.get(folder, 0) + 1
break
if folder_counts:
print("\nItems by location:")
for folder, count in sorted(folder_counts.items(), key=lambda x: x[1], reverse=True):
print(f" {folder}: {count} items")
return results
def main():
# Parse arguments
if len(sys.argv) > 1:
try:
max_items = int(sys.argv[1])
except ValueError:
print("Usage: python spot.py [number_of_items]")
print("Default: 10 items")
sys.exit(1)
else:
max_items = 10
output_file = sys.argv[2] if len(sys.argv) > 2 else "spotlight_dump.json"
# Run dump
dump_spotlight_data(max_items=max_items, output_file=output_file)
if __name__ == "__main__":
main()

View File

@@ -1 +0,0 @@
# Slack MCP data integration for LEANN

View File

@@ -1,510 +0,0 @@
#!/usr/bin/env python3
"""
Slack MCP Reader for LEANN
This module provides functionality to connect to Slack MCP servers and fetch message data
for indexing in LEANN. It supports various Slack MCP server implementations and provides
flexible message processing options.
"""
import asyncio
import json
import logging
from typing import Any, Optional
logger = logging.getLogger(__name__)
class SlackMCPReader:
"""
Reader for Slack data via MCP (Model Context Protocol) servers.
This class connects to Slack MCP servers to fetch message data and convert it
into a format suitable for LEANN indexing.
"""
def __init__(
self,
mcp_server_command: str,
workspace_name: Optional[str] = None,
concatenate_conversations: bool = True,
max_messages_per_conversation: int = 100,
max_retries: int = 5,
retry_delay: float = 2.0,
):
"""
Initialize the Slack MCP Reader.
Args:
mcp_server_command: Command to start the MCP server (e.g., 'slack-mcp-server')
workspace_name: Optional workspace name to filter messages
concatenate_conversations: Whether to group messages by channel/thread
max_messages_per_conversation: Maximum messages to include per conversation
max_retries: Maximum number of retries for failed operations
retry_delay: Initial delay between retries in seconds
"""
self.mcp_server_command = mcp_server_command
self.workspace_name = workspace_name
self.concatenate_conversations = concatenate_conversations
self.max_messages_per_conversation = max_messages_per_conversation
self.max_retries = max_retries
self.retry_delay = retry_delay
self.mcp_process = None
async def start_mcp_server(self):
"""Start the MCP server process."""
try:
self.mcp_process = await asyncio.create_subprocess_exec(
*self.mcp_server_command.split(),
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
logger.info(f"Started MCP server: {self.mcp_server_command}")
except Exception as e:
logger.error(f"Failed to start MCP server: {e}")
raise
async def stop_mcp_server(self):
"""Stop the MCP server process."""
if self.mcp_process:
self.mcp_process.terminate()
await self.mcp_process.wait()
logger.info("Stopped MCP server")
async def send_mcp_request(self, request: dict[str, Any]) -> dict[str, Any]:
"""Send a request to the MCP server and get response."""
if not self.mcp_process:
raise RuntimeError("MCP server not started")
request_json = json.dumps(request) + "\n"
self.mcp_process.stdin.write(request_json.encode())
await self.mcp_process.stdin.drain()
response_line = await self.mcp_process.stdout.readline()
if not response_line:
raise RuntimeError("No response from MCP server")
return json.loads(response_line.decode().strip())
async def initialize_mcp_connection(self):
"""Initialize the MCP connection."""
init_request = {
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": {"name": "leann-slack-reader", "version": "1.0.0"},
},
}
response = await self.send_mcp_request(init_request)
if "error" in response:
raise RuntimeError(f"MCP initialization failed: {response['error']}")
logger.info("MCP connection initialized successfully")
async def list_available_tools(self) -> list[dict[str, Any]]:
"""List available tools from the MCP server."""
list_request = {"jsonrpc": "2.0", "id": 2, "method": "tools/list", "params": {}}
response = await self.send_mcp_request(list_request)
if "error" in response:
raise RuntimeError(f"Failed to list tools: {response['error']}")
return response.get("result", {}).get("tools", [])
def _is_cache_sync_error(self, error: dict) -> bool:
"""Check if the error is related to users cache not being ready."""
if isinstance(error, dict):
message = error.get("message", "").lower()
return (
"users cache is not ready" in message or "sync process is still running" in message
)
return False
async def _retry_with_backoff(self, func, *args, **kwargs):
"""Retry a function with exponential backoff, especially for cache sync issues."""
last_exception = None
for attempt in range(self.max_retries + 1):
try:
return await func(*args, **kwargs)
except Exception as e:
last_exception = e
# Check if this is a cache sync error
error_dict = {}
if hasattr(e, "args") and e.args and isinstance(e.args[0], dict):
error_dict = e.args[0]
elif "Failed to fetch messages" in str(e):
# Try to extract the error dict embedded in the exception message
import ast
import re
match = re.search(r"'error':\s*(\{[^}]+\})", str(e))
if match:
try:
# ast.literal_eval safely parses the dict literal (a safer alternative to eval)
error_dict = ast.literal_eval(match.group(1))
except (ValueError, SyntaxError):
pass
else:
# Try alternative format
match = re.search(r"Failed to fetch messages:\s*(\{[^}]+\})", str(e))
if match:
try:
error_dict = ast.literal_eval(match.group(1))
except (ValueError, SyntaxError):
pass
if self._is_cache_sync_error(error_dict):
if attempt < self.max_retries:
delay = self.retry_delay * (2**attempt) # Exponential backoff
logger.info(
f"Cache sync not ready, waiting {delay:.1f}s before retry {attempt + 1}/{self.max_retries}"
)
await asyncio.sleep(delay)
continue
else:
logger.warning(
f"Cache sync still not ready after {self.max_retries} retries, giving up"
)
break
else:
# Not a cache sync error, don't retry
break
# If we get here, all retries failed or it's not a retryable error
raise last_exception
async def fetch_slack_messages(
self, channel: Optional[str] = None, limit: int = 100
) -> list[dict[str, Any]]:
"""
Fetch Slack messages using MCP tools with retry logic for cache sync issues.
Args:
channel: Optional channel name to filter messages
limit: Maximum number of messages to fetch
Returns:
List of message dictionaries
"""
return await self._retry_with_backoff(self._fetch_slack_messages_impl, channel, limit)
async def _fetch_slack_messages_impl(
self, channel: Optional[str] = None, limit: int = 100
) -> list[dict[str, Any]]:
"""
Internal implementation of fetch_slack_messages without retry logic.
"""
# This is a generic implementation - specific MCP servers may have different tool names
# Common tool names might be: 'get_messages', 'list_messages', 'fetch_channel_history'
tools = await self.list_available_tools()
logger.info(f"Available tools: {[tool.get('name') for tool in tools]}")
# Look for a tool that can fetch messages - prioritize conversations_history
message_tool = None
# First, try to find conversations_history specifically
for tool in tools:
tool_name = tool.get("name", "").lower()
if "conversations_history" in tool_name:
message_tool = tool
logger.info(f"Found conversations_history tool: {tool}")
break
# If not found, look for other message-fetching tools
if not message_tool:
for tool in tools:
tool_name = tool.get("name", "").lower()
if any(
keyword in tool_name
for keyword in ["conversations_search", "message", "history"]
):
message_tool = tool
break
if not message_tool:
raise RuntimeError("No message fetching tool found in MCP server")
# Prepare tool call parameters
tool_params = {"limit": "180d"} # Use 180 days to get older messages
if channel:
# For conversations_history, use channel_id parameter
if message_tool["name"] == "conversations_history":
tool_params["channel_id"] = channel
else:
# Fall back to the most common parameter name for channel specification
tool_params["channel"] = channel
logger.info(f"Tool parameters: {tool_params}")
fetch_request = {
"jsonrpc": "2.0",
"id": 3,
"method": "tools/call",
"params": {"name": message_tool["name"], "arguments": tool_params},
}
response = await self.send_mcp_request(fetch_request)
if "error" in response:
raise RuntimeError(f"Failed to fetch messages: {response['error']}")
# Extract messages from response - format may vary by MCP server
result = response.get("result", {})
if "content" in result and isinstance(result["content"], list):
# Some MCP servers return content as a list
content = result["content"][0] if result["content"] else {}
if "text" in content:
try:
messages = json.loads(content["text"])
except json.JSONDecodeError:
# If not JSON, try to parse as CSV format (Slack MCP server format)
messages = self._parse_csv_messages(content["text"], channel)
else:
messages = result["content"]
else:
# Direct message format
messages = result.get("messages", [result])
return messages if isinstance(messages, list) else [messages]
def _parse_csv_messages(self, csv_text: str, channel: str) -> list[dict[str, Any]]:
"""Parse CSV format messages from Slack MCP server."""
import csv
import io
messages = []
try:
# Split by lines and process each line as a CSV row
lines = csv_text.strip().split("\n")
if not lines:
return messages
# Skip header line if it exists
start_idx = 0
if lines[0].startswith("MsgID,UserID,UserName"):
start_idx = 1
for line in lines[start_idx:]:
if not line.strip():
continue
# Parse CSV line
reader = csv.reader(io.StringIO(line))
try:
row = next(reader)
if len(row) >= 7: # Ensure we have enough columns
message = {
"ts": row[0],
"user": row[1],
"username": row[2],
"real_name": row[3],
"channel": row[4],
"thread_ts": row[5],
"text": row[6],
"time": row[7] if len(row) > 7 else "",
"reactions": row[8] if len(row) > 8 else "",
"cursor": row[9] if len(row) > 9 else "",
}
messages.append(message)
except Exception as e:
logger.warning(f"Failed to parse CSV line: {line[:100]}... Error: {e}")
continue
except Exception as e:
logger.warning(f"Failed to parse CSV messages: {e}")
# Fallback: treat entire text as one message
messages = [{"text": csv_text, "channel": channel or "unknown"}]
return messages
def _format_message(self, message: dict[str, Any]) -> str:
"""Format a single message for indexing."""
text = message.get("text", "")
user = message.get("user", message.get("username", "Unknown"))
channel = message.get("channel", message.get("channel_name", "Unknown"))
timestamp = message.get("ts", message.get("timestamp", ""))
# Format timestamp if available
formatted_time = ""
if timestamp:
try:
import datetime
if isinstance(timestamp, str) and "." in timestamp:
dt = datetime.datetime.fromtimestamp(float(timestamp))
formatted_time = dt.strftime("%Y-%m-%d %H:%M:%S")
elif isinstance(timestamp, (int, float)):
dt = datetime.datetime.fromtimestamp(timestamp)
formatted_time = dt.strftime("%Y-%m-%d %H:%M:%S")
else:
formatted_time = str(timestamp)
except (ValueError, TypeError):
formatted_time = str(timestamp)
# Build formatted message
parts = []
if channel:
parts.append(f"Channel: #{channel}")
if user:
parts.append(f"User: {user}")
if formatted_time:
parts.append(f"Time: {formatted_time}")
if text:
parts.append(f"Message: {text}")
return "\n".join(parts)
def _create_concatenated_content(self, messages: list[dict[str, Any]], channel: str) -> str:
"""Create concatenated content from multiple messages in a channel."""
if not messages:
return ""
# Sort messages by timestamp if available
try:
messages.sort(key=lambda x: float(x.get("ts", x.get("timestamp", 0))))
except (ValueError, TypeError):
pass # Keep original order if timestamps aren't numeric
# Limit messages per conversation
if len(messages) > self.max_messages_per_conversation:
messages = messages[-self.max_messages_per_conversation :]
# Create header
content_parts = [
f"Slack Channel: #{channel}",
f"Message Count: {len(messages)}",
f"Workspace: {self.workspace_name or 'Unknown'}",
"=" * 50,
"",
]
# Add messages
for message in messages:
formatted_msg = self._format_message(message)
if formatted_msg.strip():
content_parts.append(formatted_msg)
content_parts.append("-" * 30)
content_parts.append("")
return "\n".join(content_parts)
async def get_all_channels(self) -> list[str]:
"""Get list of all available channels."""
try:
channels_list_request = {
"jsonrpc": "2.0",
"id": 4,
"method": "tools/call",
"params": {"name": "channels_list", "arguments": {}},
}
channels_response = await self.send_mcp_request(channels_list_request)
if "result" in channels_response:
result = channels_response["result"]
if "content" in result and isinstance(result["content"], list):
content = result["content"][0] if result["content"] else {}
if "text" in content:
# Parse the channels from the response
channels = []
lines = content["text"].split("\n")
for line in lines:
if line.strip() and ("#" in line or "C" in line[:10]):
# Extract channel ID or name
parts = line.split()
for part in parts:
if part.startswith("C") and len(part) > 5:
channels.append(part)
elif part.startswith("#"):
channels.append(part[1:]) # Remove #
logger.info(f"Found {len(channels)} channels: {channels}")
return channels
return []
except Exception as e:
logger.warning(f"Failed to get channels list: {e}")
return []
async def read_slack_data(self, channels: Optional[list[str]] = None) -> list[str]:
"""
Read Slack data and return formatted text chunks.
Args:
channels: Optional list of channel names to fetch. If None, fetches from all available channels.
Returns:
List of formatted text chunks ready for LEANN indexing
"""
try:
await self.start_mcp_server()
await self.initialize_mcp_connection()
all_texts = []
if channels:
# Fetch specific channels
for channel in channels:
try:
messages = await self.fetch_slack_messages(channel=channel, limit=1000)
if messages:
if self.concatenate_conversations:
text_content = self._create_concatenated_content(messages, channel)
if text_content.strip():
all_texts.append(text_content)
else:
# Process individual messages
for message in messages:
formatted_msg = self._format_message(message)
if formatted_msg.strip():
all_texts.append(formatted_msg)
except Exception as e:
logger.warning(f"Failed to fetch messages from channel {channel}: {e}")
continue
else:
# Fetch from all available channels
logger.info("Fetching from all available channels...")
all_channels = await self.get_all_channels()
if not all_channels:
# Fallback to common channel names if we can't get the list
all_channels = ["general", "random", "announcements", "C0GN5BX0F"]
logger.info(f"Using fallback channels: {all_channels}")
for channel in all_channels:
try:
logger.info(f"Searching channel: {channel}")
messages = await self.fetch_slack_messages(channel=channel, limit=1000)
if messages:
if self.concatenate_conversations:
text_content = self._create_concatenated_content(messages, channel)
if text_content.strip():
all_texts.append(text_content)
else:
# Process individual messages
for message in messages:
formatted_msg = self._format_message(message)
if formatted_msg.strip():
all_texts.append(formatted_msg)
except Exception as e:
logger.warning(f"Failed to fetch messages from channel {channel}: {e}")
continue
return all_texts
finally:
await self.stop_mcp_server()
async def __aenter__(self):
"""Async context manager entry."""
await self.start_mcp_server()
await self.initialize_mcp_connection()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit."""
await self.stop_mcp_server()

View File

@@ -1,227 +0,0 @@
#!/usr/bin/env python3
"""
Slack RAG Application with MCP Support
This application enables RAG (Retrieval-Augmented Generation) on Slack messages
by connecting to Slack MCP servers to fetch live data and index it in LEANN.
Usage:
python -m apps.slack_rag --mcp-server "slack-mcp-server" --query "What did the team discuss about the project?"
"""
import argparse
import asyncio
from apps.base_rag_example import BaseRAGExample
from apps.slack_data.slack_mcp_reader import SlackMCPReader
class SlackMCPRAG(BaseRAGExample):
"""
RAG application for Slack messages via MCP servers.
This class provides a complete RAG pipeline for Slack data, including
MCP server connection, data fetching, indexing, and interactive chat.
"""
def __init__(self):
super().__init__(
name="Slack MCP RAG",
description="RAG application for Slack messages via MCP servers",
default_index_name="slack_messages",
)
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
"""Add Slack MCP-specific arguments."""
parser.add_argument(
"--mcp-server",
type=str,
required=True,
help="Command to start the Slack MCP server (e.g., 'slack-mcp-server' or 'npx slack-mcp-server')",
)
parser.add_argument(
"--workspace-name",
type=str,
help="Slack workspace name for better organization and filtering",
)
parser.add_argument(
"--channels",
nargs="+",
help="Specific Slack channels to index (e.g., general random). If not specified, fetches from all available channels",
)
parser.add_argument(
"--concatenate-conversations",
action="store_true",
default=True,
help="Group messages by channel/thread for better context (default: True)",
)
parser.add_argument(
"--no-concatenate-conversations",
action="store_true",
help="Process individual messages instead of grouping by channel",
)
parser.add_argument(
"--max-messages-per-channel",
type=int,
default=100,
help="Maximum number of messages to include per channel (default: 100)",
)
parser.add_argument(
"--test-connection",
action="store_true",
help="Test MCP server connection and list available tools without indexing",
)
parser.add_argument(
"--max-retries",
type=int,
default=5,
help="Maximum number of retries for failed operations (default: 5)",
)
parser.add_argument(
"--retry-delay",
type=float,
default=2.0,
help="Initial delay between retries in seconds (default: 2.0)",
)
async def test_mcp_connection(self, args) -> bool:
"""Test the MCP server connection and display available tools."""
print(f"Testing connection to MCP server: {args.mcp_server}")
try:
reader = SlackMCPReader(
mcp_server_command=args.mcp_server,
workspace_name=args.workspace_name,
concatenate_conversations=not args.no_concatenate_conversations,
max_messages_per_conversation=args.max_messages_per_channel,
max_retries=args.max_retries,
retry_delay=args.retry_delay,
)
async with reader:
tools = await reader.list_available_tools()
print("Successfully connected to MCP server!")
print(f"Available tools ({len(tools)}):")
for i, tool in enumerate(tools, 1):
name = tool.get("name", "Unknown")
description = tool.get("description", "No description available")
print(f"\n{i}. {name}")
print(
f" Description: {description[:100]}{'...' if len(description) > 100 else ''}"
)
# Show input schema if available
schema = tool.get("inputSchema", {})
if schema.get("properties"):
props = list(schema["properties"].keys())[:3] # Show first 3 properties
print(
f" Parameters: {', '.join(props)}{'...' if len(schema['properties']) > 3 else ''}"
)
return True
except Exception as e:
print(f"Failed to connect to MCP server: {e}")
print("\nTroubleshooting tips:")
print("1. Make sure the MCP server is installed and accessible")
print("2. Check if the server command is correct")
print("3. Ensure you have proper authentication/credentials configured")
print("4. Try running the MCP server command directly to test it")
return False
async def load_data(self, args) -> list[str]:
"""Load Slack messages via MCP server."""
print(f"Connecting to Slack MCP server: {args.mcp_server}")
if args.workspace_name:
print(f"Workspace: {args.workspace_name}")
# Filter out empty strings from channels
channels = [ch for ch in args.channels if ch.strip()] if args.channels else None
if channels:
print(f"Channels: {', '.join(channels)}")
else:
print("Fetching from all available channels")
concatenate = not args.no_concatenate_conversations
print(
f"Processing mode: {'Concatenated conversations' if concatenate else 'Individual messages'}"
)
try:
reader = SlackMCPReader(
mcp_server_command=args.mcp_server,
workspace_name=args.workspace_name,
concatenate_conversations=concatenate,
max_messages_per_conversation=args.max_messages_per_channel,
max_retries=args.max_retries,
retry_delay=args.retry_delay,
)
texts = await reader.read_slack_data(channels=channels)
if not texts:
print("No messages found! This could mean:")
print("- The MCP server couldn't fetch messages")
print("- The specified channels don't exist or are empty")
print("- Authentication issues with the Slack workspace")
return []
print(f"Successfully loaded {len(texts)} text chunks from Slack")
# Show sample of what was loaded
if texts:
sample_text = texts[0][:200] + "..." if len(texts[0]) > 200 else texts[0]
print("\nSample content:")
print("-" * 40)
print(sample_text)
print("-" * 40)
return texts
except Exception as e:
print(f"Error loading Slack data: {e}")
print("\nThis might be due to:")
print("- MCP server connection issues")
print("- Authentication problems")
print("- Network connectivity issues")
print("- Incorrect channel names")
raise
async def run(self):
"""Main entry point with MCP connection testing."""
args = self.parser.parse_args()
# Test connection if requested
if args.test_connection:
success = await self.test_mcp_connection(args)
if not success:
return
print(
"MCP server is working! You can now run without --test-connection to start indexing."
)
return
# Run the standard RAG pipeline
await super().run()
async def main():
"""Main entry point for the Slack MCP RAG application."""
app = SlackMCPRAG()
await app.run()
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1 +0,0 @@
# Twitter MCP data integration for LEANN

View File

@@ -1,295 +0,0 @@
#!/usr/bin/env python3
"""
Twitter MCP Reader for LEANN
This module provides functionality to connect to Twitter MCP servers and fetch bookmark data
for indexing in LEANN. It supports various Twitter MCP server implementations and provides
flexible bookmark processing options.
"""
import asyncio
import json
import logging
from typing import Any, Optional
logger = logging.getLogger(__name__)
class TwitterMCPReader:
"""
Reader for Twitter bookmark data via MCP (Model Context Protocol) servers.
This class connects to Twitter MCP servers to fetch bookmark data and convert it
into a format suitable for LEANN indexing.
"""
def __init__(
self,
mcp_server_command: str,
username: Optional[str] = None,
include_tweet_content: bool = True,
include_metadata: bool = True,
max_bookmarks: int = 1000,
):
"""
Initialize the Twitter MCP Reader.
Args:
mcp_server_command: Command to start the MCP server (e.g., 'twitter-mcp-server')
username: Optional Twitter username to filter bookmarks
include_tweet_content: Whether to include full tweet content
include_metadata: Whether to include tweet metadata (likes, retweets, etc.)
max_bookmarks: Maximum number of bookmarks to fetch
"""
self.mcp_server_command = mcp_server_command
self.username = username
self.include_tweet_content = include_tweet_content
self.include_metadata = include_metadata
self.max_bookmarks = max_bookmarks
self.mcp_process = None
async def start_mcp_server(self):
"""Start the MCP server process."""
try:
self.mcp_process = await asyncio.create_subprocess_exec(
*self.mcp_server_command.split(),
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
logger.info(f"Started MCP server: {self.mcp_server_command}")
except Exception as e:
logger.error(f"Failed to start MCP server: {e}")
raise
async def stop_mcp_server(self):
"""Stop the MCP server process."""
if self.mcp_process:
self.mcp_process.terminate()
await self.mcp_process.wait()
logger.info("Stopped MCP server")
async def send_mcp_request(self, request: dict[str, Any]) -> dict[str, Any]:
"""Send a request to the MCP server and get response."""
if not self.mcp_process:
raise RuntimeError("MCP server not started")
request_json = json.dumps(request) + "\n"
self.mcp_process.stdin.write(request_json.encode())
await self.mcp_process.stdin.drain()
response_line = await self.mcp_process.stdout.readline()
if not response_line:
raise RuntimeError("No response from MCP server")
return json.loads(response_line.decode().strip())
async def initialize_mcp_connection(self):
"""Initialize the MCP connection."""
init_request = {
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": {"name": "leann-twitter-reader", "version": "1.0.0"},
},
}
response = await self.send_mcp_request(init_request)
if "error" in response:
raise RuntimeError(f"MCP initialization failed: {response['error']}")
logger.info("MCP connection initialized successfully")
async def list_available_tools(self) -> list[dict[str, Any]]:
"""List available tools from the MCP server."""
list_request = {"jsonrpc": "2.0", "id": 2, "method": "tools/list", "params": {}}
response = await self.send_mcp_request(list_request)
if "error" in response:
raise RuntimeError(f"Failed to list tools: {response['error']}")
return response.get("result", {}).get("tools", [])
async def fetch_twitter_bookmarks(self, limit: Optional[int] = None) -> list[dict[str, Any]]:
"""
Fetch Twitter bookmarks using MCP tools.
Args:
limit: Maximum number of bookmarks to fetch
Returns:
List of bookmark dictionaries
"""
tools = await self.list_available_tools()
bookmark_tool = None
# Look for a tool that can fetch bookmarks
for tool in tools:
tool_name = tool.get("name", "").lower()
if any(keyword in tool_name for keyword in ["bookmark", "saved", "favorite"]):
bookmark_tool = tool
break
if not bookmark_tool:
raise RuntimeError("No bookmark fetching tool found in MCP server")
# Prepare tool call parameters
tool_params = {}
if limit or self.max_bookmarks:
tool_params["limit"] = limit or self.max_bookmarks
if self.username:
tool_params["username"] = self.username
fetch_request = {
"jsonrpc": "2.0",
"id": 3,
"method": "tools/call",
"params": {"name": bookmark_tool["name"], "arguments": tool_params},
}
response = await self.send_mcp_request(fetch_request)
if "error" in response:
raise RuntimeError(f"Failed to fetch bookmarks: {response['error']}")
# Extract bookmarks from response
result = response.get("result", {})
if "content" in result and isinstance(result["content"], list):
content = result["content"][0] if result["content"] else {}
if "text" in content:
try:
bookmarks = json.loads(content["text"])
except json.JSONDecodeError:
# If not JSON, treat as plain text
bookmarks = [{"text": content["text"], "source": "twitter"}]
else:
bookmarks = result["content"]
else:
bookmarks = result.get("bookmarks", result.get("tweets", [result]))
return bookmarks if isinstance(bookmarks, list) else [bookmarks]
def _format_bookmark(self, bookmark: dict[str, Any]) -> str:
"""Format a single bookmark for indexing."""
# Extract tweet information
text = bookmark.get("text", bookmark.get("content", ""))
author = bookmark.get(
"author", bookmark.get("username", bookmark.get("user", {}).get("username", "Unknown"))
)
timestamp = bookmark.get("created_at", bookmark.get("timestamp", ""))
url = bookmark.get("url", bookmark.get("tweet_url", ""))
# Extract metadata if available
likes = bookmark.get("likes", bookmark.get("favorite_count", 0))
retweets = bookmark.get("retweets", bookmark.get("retweet_count", 0))
replies = bookmark.get("replies", bookmark.get("reply_count", 0))
# Build formatted bookmark
parts = []
# Header
parts.append("=== Twitter Bookmark ===")
if author:
parts.append(f"Author: @{author}")
if timestamp:
# Format timestamp if it's a standard format
try:
import datetime
if "T" in str(timestamp): # ISO format
dt = datetime.datetime.fromisoformat(timestamp.replace("Z", "+00:00"))
formatted_time = dt.strftime("%Y-%m-%d %H:%M:%S")
else:
formatted_time = str(timestamp)
parts.append(f"Date: {formatted_time}")
except (ValueError, TypeError):
parts.append(f"Date: {timestamp}")
if url:
parts.append(f"URL: {url}")
# Tweet content
if text and self.include_tweet_content:
parts.append("")
parts.append("Content:")
parts.append(text)
# Metadata
if self.include_metadata and any([likes, retweets, replies]):
parts.append("")
parts.append("Engagement:")
if likes:
parts.append(f" Likes: {likes}")
if retweets:
parts.append(f" Retweets: {retweets}")
if replies:
parts.append(f" Replies: {replies}")
# Extract hashtags and mentions if available
hashtags = bookmark.get("hashtags", [])
mentions = bookmark.get("mentions", [])
if hashtags or mentions:
parts.append("")
if hashtags:
parts.append(f"Hashtags: {', '.join(hashtags)}")
if mentions:
parts.append(f"Mentions: {', '.join(mentions)}")
return "\n".join(parts)
async def read_twitter_bookmarks(self) -> list[str]:
"""
Read Twitter bookmark data and return formatted text chunks.
Returns:
List of formatted text chunks ready for LEANN indexing
"""
try:
await self.start_mcp_server()
await self.initialize_mcp_connection()
print(f"Fetching up to {self.max_bookmarks} bookmarks...")
if self.username:
print(f"Filtering for user: @{self.username}")
bookmarks = await self.fetch_twitter_bookmarks()
if not bookmarks:
print("No bookmarks found")
return []
print(f"Processing {len(bookmarks)} bookmarks...")
all_texts = []
processed_count = 0
for bookmark in bookmarks:
try:
formatted_bookmark = self._format_bookmark(bookmark)
if formatted_bookmark.strip():
all_texts.append(formatted_bookmark)
processed_count += 1
except Exception as e:
logger.warning(f"Failed to format bookmark: {e}")
continue
print(f"Successfully processed {processed_count} bookmarks")
return all_texts
finally:
await self.stop_mcp_server()
async def __aenter__(self):
"""Async context manager entry."""
await self.start_mcp_server()
await self.initialize_mcp_connection()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit."""
await self.stop_mcp_server()

View File

@@ -1,195 +0,0 @@
#!/usr/bin/env python3
"""
Twitter RAG Application with MCP Support
This application enables RAG (Retrieval-Augmented Generation) on Twitter bookmarks
by connecting to Twitter MCP servers to fetch live data and index it in LEANN.
Usage:
python -m apps.twitter_rag --mcp-server "twitter-mcp-server" --query "What articles did I bookmark about AI?"
"""
import argparse
import asyncio
from apps.base_rag_example import BaseRAGExample
from apps.twitter_data.twitter_mcp_reader import TwitterMCPReader
class TwitterMCPRAG(BaseRAGExample):
"""
RAG application for Twitter bookmarks via MCP servers.
This class provides a complete RAG pipeline for Twitter bookmark data, including
MCP server connection, data fetching, indexing, and interactive chat.
"""
def __init__(self):
super().__init__(
name="Twitter MCP RAG",
description="RAG application for Twitter bookmarks via MCP servers",
default_index_name="twitter_bookmarks",
)
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
"""Add Twitter MCP-specific arguments."""
parser.add_argument(
"--mcp-server",
type=str,
required=True,
help="Command to start the Twitter MCP server (e.g., 'twitter-mcp-server' or 'npx twitter-mcp-server')",
)
parser.add_argument(
"--username", type=str, help="Twitter username to filter bookmarks (without @)"
)
parser.add_argument(
"--max-bookmarks",
type=int,
default=1000,
help="Maximum number of bookmarks to fetch (default: 1000)",
)
parser.add_argument(
"--no-tweet-content",
action="store_true",
help="Exclude tweet content, only include metadata",
)
parser.add_argument(
"--no-metadata",
action="store_true",
help="Exclude engagement metadata (likes, retweets, etc.)",
)
parser.add_argument(
"--test-connection",
action="store_true",
help="Test MCP server connection and list available tools without indexing",
)
async def test_mcp_connection(self, args) -> bool:
"""Test the MCP server connection and display available tools."""
print(f"Testing connection to MCP server: {args.mcp_server}")
try:
reader = TwitterMCPReader(
mcp_server_command=args.mcp_server,
username=args.username,
include_tweet_content=not args.no_tweet_content,
include_metadata=not args.no_metadata,
max_bookmarks=args.max_bookmarks,
)
async with reader:
tools = await reader.list_available_tools()
print("\n✅ Successfully connected to MCP server!")
print(f"Available tools ({len(tools)}):")
for i, tool in enumerate(tools, 1):
name = tool.get("name", "Unknown")
description = tool.get("description", "No description available")
print(f"\n{i}. {name}")
print(
f" Description: {description[:100]}{'...' if len(description) > 100 else ''}"
)
# Show input schema if available
schema = tool.get("inputSchema", {})
if schema.get("properties"):
props = list(schema["properties"].keys())[:3] # Show first 3 properties
print(
f" Parameters: {', '.join(props)}{'...' if len(schema['properties']) > 3 else ''}"
)
return True
except Exception as e:
print(f"\n❌ Failed to connect to MCP server: {e}")
print("\nTroubleshooting tips:")
print("1. Make sure the Twitter MCP server is installed and accessible")
print("2. Check if the server command is correct")
print("3. Ensure you have proper Twitter API credentials configured")
print("4. Verify your Twitter account has bookmarks to fetch")
print("5. Try running the MCP server command directly to test it")
return False
async def load_data(self, args) -> list[str]:
"""Load Twitter bookmarks via MCP server."""
print(f"Connecting to Twitter MCP server: {args.mcp_server}")
if args.username:
print(f"Username filter: @{args.username}")
print(f"Max bookmarks: {args.max_bookmarks}")
print(f"Include tweet content: {not args.no_tweet_content}")
print(f"Include metadata: {not args.no_metadata}")
try:
reader = TwitterMCPReader(
mcp_server_command=args.mcp_server,
username=args.username,
include_tweet_content=not args.no_tweet_content,
include_metadata=not args.no_metadata,
max_bookmarks=args.max_bookmarks,
)
texts = await reader.read_twitter_bookmarks()
if not texts:
print("❌ No bookmarks found! This could mean:")
print("- You don't have any bookmarks on Twitter")
print("- The MCP server couldn't access your bookmarks")
print("- Authentication issues with Twitter API")
print("- The username filter didn't match any bookmarks")
return []
print(f"✅ Successfully loaded {len(texts)} bookmarks from Twitter")
# Show sample of what was loaded
if texts:
sample_text = texts[0][:300] + "..." if len(texts[0]) > 300 else texts[0]
print("\nSample bookmark:")
print("-" * 50)
print(sample_text)
print("-" * 50)
return texts
except Exception as e:
print(f"❌ Error loading Twitter bookmarks: {e}")
print("\nThis might be due to:")
print("- MCP server connection issues")
print("- Twitter API authentication problems")
print("- Network connectivity issues")
print("- Rate limiting from Twitter API")
raise
async def run(self):
"""Main entry point with MCP connection testing."""
args = self.parser.parse_args()
# Test connection if requested
if args.test_connection:
success = await self.test_mcp_connection(args)
if not success:
return
print(
"\n🎉 MCP server is working! You can now run without --test-connection to start indexing."
)
return
# Run the standard RAG pipeline
await super().run()
async def main():
"""Main entry point for the Twitter MCP RAG application."""
app = TwitterMCPRAG()
await app.run()
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,189 +0,0 @@
"""
WeChat History RAG example using the unified interface.
Supports WeChat chat history export and search.
"""
import subprocess
import sys
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from base_rag_example import BaseRAGExample
from .history_data.wechat_history import WeChatHistoryReader
class WeChatRAG(BaseRAGExample):
"""RAG example for WeChat chat history."""
def __init__(self):
# Set default values BEFORE calling super().__init__
self.max_items_default = -1 # Match original default
self.embedding_model_default = (
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
)
super().__init__(
name="WeChat History",
description="Process and query WeChat chat history with LEANN",
default_index_name="wechat_history_magic_test_11Debug_new",
)
def _add_specific_arguments(self, parser):
"""Add WeChat-specific arguments."""
wechat_group = parser.add_argument_group("WeChat Parameters")
wechat_group.add_argument(
"--export-dir",
type=str,
default="./wechat_export",
help="Directory to store WeChat exports (default: ./wechat_export)",
)
wechat_group.add_argument(
"--force-export",
action="store_true",
help="Force re-export of WeChat data even if exports exist",
)
wechat_group.add_argument(
"--chunk-size", type=int, default=192, help="Text chunk size (default: 192)"
)
wechat_group.add_argument(
"--chunk-overlap", type=int, default=64, help="Text chunk overlap (default: 64)"
)
def _export_wechat_data(self, export_dir: Path) -> bool:
"""Export WeChat data using wechattweak-cli."""
print("Exporting WeChat data...")
# Check if WeChat is running
try:
result = subprocess.run(["pgrep", "WeChat"], capture_output=True, text=True)
if result.returncode != 0:
print("WeChat is not running. Please start WeChat first.")
return False
except Exception:
pass # pgrep might not be available on all systems
# Create export directory
export_dir.mkdir(parents=True, exist_ok=True)
# Run export command
cmd = ["packages/wechat-exporter/wechattweak-cli", "export", str(export_dir)]
try:
print(f"Running: {' '.join(cmd)}")
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode == 0:
print("WeChat data exported successfully!")
return True
else:
print(f"Export failed: {result.stderr}")
return False
except FileNotFoundError:
print("\nError: wechattweak-cli not found!")
print("Please install it first:")
print(" sudo packages/wechat-exporter/wechattweak-cli install")
return False
except Exception as e:
print(f"Export error: {e}")
return False
async def load_data(self, args) -> list[str]:
"""Load WeChat history and convert to text chunks."""
# Initialize WeChat reader with export capabilities
reader = WeChatHistoryReader()
# Find existing exports or create new ones using the centralized method
export_dirs = reader.find_or_export_wechat_data(args.export_dir)
if not export_dirs:
print("Failed to find or export WeChat data. Trying to find any existing exports...")
# Try to find any existing exports in common locations
export_dirs = reader.find_wechat_export_dirs()
if not export_dirs:
print("No WeChat data found. Please ensure WeChat exports exist.")
return []
# Load documents from all found export directories
all_documents = []
total_processed = 0
for i, export_dir in enumerate(export_dirs):
print(f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}")
try:
# Apply max_items limit per export
max_per_export = -1
if args.max_items > 0:
remaining = args.max_items - total_processed
if remaining <= 0:
break
max_per_export = remaining
documents = reader.load_data(
wechat_export_dir=str(export_dir),
max_count=max_per_export,
concatenate_messages=True, # Enable message concatenation for better context
)
if documents:
print(f"Loaded {len(documents)} chat documents from {export_dir}")
all_documents.extend(documents)
total_processed += len(documents)
else:
print(f"No documents loaded from {export_dir}")
except Exception as e:
print(f"Error processing {export_dir}: {e}")
continue
if not all_documents:
print("No documents loaded from any source. Exiting.")
return []
print(f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports")
print("now starting to split into text chunks ... take some time")
# Convert to text chunks with contact information
all_texts = []
for doc in all_documents:
# Split the document into chunks
from llama_index.core.node_parser import SentenceSplitter
text_splitter = SentenceSplitter(
chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
)
nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes:
# Add contact information to each chunk
contact_name = doc.metadata.get("contact_name", "Unknown")
text = f"[Contact] means the message is from: {contact_name}\n" + node.get_content()
all_texts.append(text)
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
return all_texts
if __name__ == "__main__":
import asyncio
# Check platform
if sys.platform != "darwin":
print("\n⚠️ Warning: WeChat export is only supported on macOS")
print(" You can still query existing exports on other platforms\n")
# Example queries for WeChat RAG
print("\n💬 WeChat History RAG Example")
print("=" * 50)
print("\nExample queries you can try:")
print("- 'Show me conversations about travel plans'")
print("- 'Find group chats about weekend activities'")
print("- '我想买魔术师约翰逊的球衣,给我一些对应聊天记录?'")
print("- 'What did we discuss about the project last month?'")
print("\nNote: WeChat must be running for export to work\n")
rag = WeChatRAG()
asyncio.run(rag.run())

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 73 KiB

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 224 KiB

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 152 KiB

View File

View File

@@ -1,148 +0,0 @@
import argparse
import os
import time
from pathlib import Path
from leann import LeannBuilder, LeannSearcher
def _meta_exists(index_path: str) -> bool:
p = Path(index_path)
return (p.parent / f"{p.stem}.meta.json").exists()
def ensure_index(index_path: str, backend_name: str, num_docs: int, is_recompute: bool) -> None:
# if _meta_exists(index_path):
# return
kwargs = {}
if backend_name == "hnsw":
kwargs["is_compact"] = is_recompute
builder = LeannBuilder(
backend_name=backend_name,
embedding_model=os.getenv("LEANN_EMBED_MODEL", "facebook/contriever"),
embedding_mode=os.getenv("LEANN_EMBED_MODE", "sentence-transformers"),
graph_degree=32,
complexity=64,
is_recompute=is_recompute,
num_threads=4,
**kwargs,
)
for i in range(num_docs):
builder.add_text(
f"This is a test document number {i}. It contains some repeated text for benchmarking."
)
builder.build_index(index_path)
def _bench_group(
index_path: str,
recompute: bool,
query: str,
repeats: int,
complexity: int = 32,
top_k: int = 10,
) -> float:
# Independent searcher per group; fixed port when recompute
searcher = LeannSearcher(index_path=index_path)
# Warm-up once
_ = searcher.search(
query,
top_k=top_k,
complexity=complexity,
recompute_embeddings=recompute,
)
def _once() -> float:
t0 = time.time()
_ = searcher.search(
query,
top_k=top_k,
complexity=complexity,
recompute_embeddings=recompute,
)
return time.time() - t0
if repeats <= 1:
t = _once()
else:
vals = [_once() for _ in range(repeats)]
vals.sort()
t = vals[len(vals) // 2]
searcher.cleanup()
return t
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--num-docs", type=int, default=5000)
parser.add_argument("--repeats", type=int, default=3)
parser.add_argument("--complexity", type=int, default=32)
args = parser.parse_args()
base = Path.cwd() / ".leann" / "indexes" / f"bench_n{args.num_docs}"
base.parent.mkdir(parents=True, exist_ok=True)
# ---------- Build HNSW variants ----------
hnsw_r = str(base / f"hnsw_recompute_n{args.num_docs}.leann")
hnsw_nr = str(base / f"hnsw_norecompute_n{args.num_docs}.leann")
ensure_index(hnsw_r, "hnsw", args.num_docs, True)
ensure_index(hnsw_nr, "hnsw", args.num_docs, False)
# ---------- Build DiskANN variants ----------
diskann_r = str(base / "diskann_r.leann")
diskann_nr = str(base / "diskann_nr.leann")
ensure_index(diskann_r, "diskann", args.num_docs, True)
ensure_index(diskann_nr, "diskann", args.num_docs, False)
# ---------- Helpers ----------
def _size_for(prefix: str) -> int:
p = Path(prefix)
base_dir = p.parent
stem = p.stem
total = 0
for f in base_dir.iterdir():
if f.is_file() and f.name.startswith(stem):
total += f.stat().st_size
return total
# ---------- HNSW benchmark ----------
t_hnsw_r = _bench_group(
hnsw_r, True, "test document number 42", repeats=args.repeats, complexity=args.complexity
)
t_hnsw_nr = _bench_group(
hnsw_nr, False, "test document number 42", repeats=args.repeats, complexity=args.complexity
)
size_hnsw_r = _size_for(hnsw_r)
size_hnsw_nr = _size_for(hnsw_nr)
print("Benchmark results (HNSW):")
print(f" recompute=True: search_time={t_hnsw_r:.3f}s, size={size_hnsw_r / 1024 / 1024:.1f}MB")
print(
f" recompute=False: search_time={t_hnsw_nr:.3f}s, size={size_hnsw_nr / 1024 / 1024:.1f}MB"
)
print(" Expectation: no-recompute should be faster but larger on disk.")
# ---------- DiskANN benchmark ----------
t_diskann_r = _bench_group(
diskann_r, True, "DiskANN R test doc 123", repeats=args.repeats, complexity=args.complexity
)
t_diskann_nr = _bench_group(
diskann_nr,
False,
"DiskANN NR test doc 123",
repeats=args.repeats,
complexity=args.complexity,
)
size_diskann_r = _size_for(diskann_r)
size_diskann_nr = _size_for(diskann_nr)
print("\nBenchmark results (DiskANN):")
print(f" build(recompute=True, partition): size={size_diskann_r / 1024 / 1024:.1f}MB")
print(f" build(recompute=False): size={size_diskann_nr / 1024 / 1024:.1f}MB")
print(f" search recompute=True (final rerank): {t_diskann_r:.3f}s")
print(f" search recompute=False (PQ only): {t_diskann_nr:.3f}s")
if __name__ == "__main__":
main()

View File

@@ -1,23 +0,0 @@
# BM25 vs DiskANN Baselines
Download the prebuilt indices:
```bash
aws s3 sync s3://powerrag-diskann-rpj-wiki-20250824-224037-194d640c/bm25_rpj_wiki/index_en_only/ benchmarks/data/indices/bm25_index/
aws s3 sync s3://powerrag-diskann-rpj-wiki-20250824-224037-194d640c/diskann_rpj_wiki/ benchmarks/data/indices/diskann_rpj_wiki/
```
- Dataset: `benchmarks/data/queries/nq_open.jsonl` (Natural Questions)
- Results are machine-specific; measured locally with the current repo.
## DiskANN (NQ queries, search-only)
- Command: `uv run --script benchmarks/bm25_diskann_baselines/run_diskann.py`
- Settings: `recompute_embeddings=False`, embeddings precomputed (excluded from timing), batching off, caching off (`cache_mechanism=2`, `num_nodes_to_cache=0`)
- Result: avg 0.011093 s/query, QPS 90.15 (p50 0.010731 s, p95 0.015000 s)
## BM25
- Command: `uv run --script benchmarks/bm25_diskann_baselines/run_bm25.py`
- Settings: `k=10`, `k1=0.9`, `b=0.4`, queries=100
- Result: avg 0.028589 s/query, QPS 34.97 (p50 0.026060 s, p90 0.043695 s, p95 0.053260 s, p99 0.055257 s)
## Notes
- DiskANN measures search-only latency on real NQ queries (embeddings computed beforehand and excluded from timing); per-query latencies are aggregated as sketched below.
- Use `benchmarks/bm25_diskann_baselines/run_diskann.py` for DiskANN and `benchmarks/bm25_diskann_baselines/run_bm25.py` for BM25.
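For reference, a minimal sketch (not part of the scripts) of how per-query latencies are aggregated into the figures above; `latencies` is assumed to be a list of per-query wall-clock seconds:
```python
from statistics import mean

def summarize(latencies: list[float]) -> dict[str, float]:
    """Aggregate per-query latencies into avg / percentiles / QPS as reported above."""
    s = sorted(latencies)

    def pct(p: float) -> float:
        # Nearest-rank percentile; the actual scripts interpolate, which differs only slightly.
        return s[min(len(s) - 1, round((len(s) - 1) * p / 100.0))]

    avg = mean(s)
    return {
        "avg_s": avg,
        "p50_s": pct(50),
        "p95_s": pct(95),
        "qps": 1.0 / avg,  # e.g. 1 / 0.011093 s ≈ 90.15 queries/s
    }
```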

Before

Width:  |  Height:  |  Size: 1.3 KiB

View File

@@ -1,183 +0,0 @@
# /// script
# dependencies = [
# "pyserini"
# ]
# ///
# sudo pacman -S jdk21-openjdk
# export JAVA_HOME=/usr/lib/jvm/java-21-openjdk
# sudo archlinux-java status
# sudo archlinux-java set java-21-openjdk
# set -Ux JAVA_HOME /usr/lib/jvm/java-21-openjdk
# fish_add_path --global $JAVA_HOME/bin
# set -Ux LD_LIBRARY_PATH $JAVA_HOME/lib/server $LD_LIBRARY_PATH
# which javac # Should be /usr/lib/jvm/java-21-openjdk/bin/javac
import argparse
import json
import os
import sys
import time
from statistics import mean
def load_queries(path: str, limit: int | None) -> list[str]:
queries: list[str] = []
# Try JSONL with a 'query' or 'text' field; fallback to plain text (one query per line)
_, ext = os.path.splitext(path)
if ext.lower() in {".jsonl", ".json"}:
with open(path, encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
try:
obj = json.loads(line)
except json.JSONDecodeError:
# Not strict JSONL? treat the whole line as the query
queries.append(line)
continue
q = obj.get("query") or obj.get("text") or obj.get("question")
if q:
queries.append(str(q))
else:
with open(path, encoding="utf-8") as f:
for line in f:
s = line.strip()
if s:
queries.append(s)
if limit is not None and limit > 0:
queries = queries[:limit]
return queries
def percentile(values: list[float], p: float) -> float:
if not values:
return 0.0
s = sorted(values)
k = (len(s) - 1) * (p / 100.0)
f = int(k)
c = min(f + 1, len(s) - 1)
if f == c:
return s[f]
return s[f] + (s[c] - s[f]) * (k - f)
def main():
ap = argparse.ArgumentParser(description="Standalone BM25 latency benchmark (Pyserini)")
ap.add_argument(
"--bm25-index",
default="benchmarks/data/indices/bm25_index",
help="Path to Pyserini Lucene index directory",
)
ap.add_argument(
"--queries",
default="benchmarks/data/queries/nq_open.jsonl",
help="Path to queries file (JSONL with 'query'/'text' or plain txt one-per-line)",
)
ap.add_argument("--k", type=int, default=10, help="Top-k to retrieve (default: 10)")
ap.add_argument("--k1", type=float, default=0.9, help="BM25 k1 (default: 0.9)")
ap.add_argument("--b", type=float, default=0.4, help="BM25 b (default: 0.4)")
ap.add_argument("--limit", type=int, default=100, help="Max queries to run (default: 100)")
ap.add_argument(
"--warmup", type=int, default=5, help="Warmup queries not counted in latency (default: 5)"
)
ap.add_argument(
"--fetch-docs", action="store_true", help="Also fetch doc contents (slower; default: off)"
)
ap.add_argument("--report", type=str, default=None, help="Optional JSON report path")
args = ap.parse_args()
try:
from pyserini.search.lucene import LuceneSearcher
except Exception:
print("Pyserini not found. Install with: pip install pyserini", file=sys.stderr)
raise
if not os.path.isdir(args.bm25_index):
print(f"Index directory not found: {args.bm25_index}", file=sys.stderr)
sys.exit(1)
queries = load_queries(args.queries, args.limit)
if not queries:
print("No queries loaded.", file=sys.stderr)
sys.exit(1)
print(f"Loaded {len(queries)} queries from {args.queries}")
print(f"Opening BM25 index: {args.bm25_index}")
searcher = LuceneSearcher(args.bm25_index)
# Some builds of pyserini require explicit set_bm25; others ignore
try:
searcher.set_bm25(k1=args.k1, b=args.b)
except Exception:
pass
latencies: list[float] = []
total_searches = 0
# Warmup
for i in range(min(args.warmup, len(queries))):
_ = searcher.search(queries[i], k=args.k)
t0 = time.time()
for i, q in enumerate(queries):
t1 = time.time()
hits = searcher.search(q, k=args.k)
t2 = time.time()
latencies.append(t2 - t1)
total_searches += 1
if args.fetch_docs:
# Optional doc fetch to include I/O time
for h in hits:
try:
_ = searcher.doc(h.docid)
except Exception:
pass
if (i + 1) % 50 == 0:
print(f"Processed {i + 1}/{len(queries)} queries")
t1 = time.time()
total_time = t1 - t0
if latencies:
avg = mean(latencies)
p50 = percentile(latencies, 50)
p90 = percentile(latencies, 90)
p95 = percentile(latencies, 95)
p99 = percentile(latencies, 99)
qps = total_searches / total_time if total_time > 0 else 0.0
else:
avg = p50 = p90 = p95 = p99 = qps = 0.0
print("BM25 Latency Report")
print(f" queries: {total_searches}")
print(f" k: {args.k}, k1: {args.k1}, b: {args.b}")
print(f" avg per query: {avg:.6f} s")
print(f" p50/p90/p95/p99: {p50:.6f}/{p90:.6f}/{p95:.6f}/{p99:.6f} s")
print(f" total time: {total_time:.3f} s, qps: {qps:.2f}")
if args.report:
payload = {
"queries": total_searches,
"k": args.k,
"k1": args.k1,
"b": args.b,
"avg_s": avg,
"p50_s": p50,
"p90_s": p90,
"p95_s": p95,
"p99_s": p99,
"total_time_s": total_time,
"qps": qps,
"index_dir": os.path.abspath(args.bm25_index),
"fetch_docs": bool(args.fetch_docs),
}
with open(args.report, "w", encoding="utf-8") as f:
json.dump(payload, f, indent=2)
print(f"Saved report to {args.report}")
if __name__ == "__main__":
main()

View File

@@ -1,124 +0,0 @@
# /// script
# dependencies = [
# "leann-backend-diskann"
# ]
# ///
import argparse
import json
import time
from pathlib import Path
import numpy as np
def load_queries(path: Path, limit: int | None) -> list[str]:
out: list[str] = []
with open(path, encoding="utf-8") as f:
for line in f:
obj = json.loads(line)
out.append(obj["query"])
if limit and len(out) >= limit:
break
return out
def main() -> None:
ap = argparse.ArgumentParser(
description="DiskANN baseline on real NQ queries (search-only timing)"
)
ap.add_argument(
"--index-dir",
default="benchmarks/data/indices/diskann_rpj_wiki",
help="Directory containing DiskANN files",
)
ap.add_argument("--index-prefix", default="ann")
ap.add_argument("--queries-file", default="benchmarks/data/queries/nq_open.jsonl")
ap.add_argument("--num-queries", type=int, default=200)
ap.add_argument("--top-k", type=int, default=10)
ap.add_argument("--complexity", type=int, default=62)
ap.add_argument("--threads", type=int, default=1)
ap.add_argument("--beam-width", type=int, default=1)
ap.add_argument("--cache-mechanism", type=int, default=2)
ap.add_argument("--num-nodes-to-cache", type=int, default=0)
args = ap.parse_args()
index_dir = Path(args.index_dir).resolve()
if not index_dir.is_dir():
raise SystemExit(f"Index dir not found: {index_dir}")
qpath = Path(args.queries_file).resolve()
if not qpath.exists():
raise SystemExit(f"Queries file not found: {qpath}")
queries = load_queries(qpath, args.num_queries)
print(f"Loaded {len(queries)} queries from {qpath}")
# Compute embeddings once (exclude from timing)
from leann.api import compute_embeddings as _compute
embs = _compute(
queries,
model_name="facebook/contriever-msmarco",
mode="sentence-transformers",
use_server=False,
).astype(np.float32)
if embs.ndim != 2:
raise SystemExit("Embedding compute failed or returned wrong shape")
# Build searcher
from leann_backend_diskann.diskann_backend import DiskannSearcher as _DiskannSearcher
index_prefix_path = str(index_dir / args.index_prefix)
searcher = _DiskannSearcher(
index_prefix_path,
num_threads=int(args.threads),
cache_mechanism=int(args.cache_mechanism),
num_nodes_to_cache=int(args.num_nodes_to_cache),
)
# Warmup (not timed)
_ = searcher.search(
embs[0:1],
top_k=args.top_k,
complexity=args.complexity,
beam_width=args.beam_width,
prune_ratio=0.0,
recompute_embeddings=False,
batch_recompute=False,
dedup_node_dis=False,
)
# Timed loop
times: list[float] = []
for i in range(embs.shape[0]):
t0 = time.time()
_ = searcher.search(
embs[i : i + 1],
top_k=args.top_k,
complexity=args.complexity,
beam_width=args.beam_width,
prune_ratio=0.0,
recompute_embeddings=False,
batch_recompute=False,
dedup_node_dis=False,
)
times.append(time.time() - t0)
times_sorted = sorted(times)
avg = float(sum(times) / len(times))
p50 = times_sorted[len(times) // 2]
p95 = times_sorted[max(0, int(len(times) * 0.95) - 1)]
print("\nDiskANN (NQ, search-only) Report")
print(f" queries: {len(times)}")
print(
f" k: {args.top_k}, complexity: {args.complexity}, beam_width: {args.beam_width}, threads: {args.threads}"
)
print(f" avg per query: {avg:.6f} s")
print(f" p50/p95: {p50:.6f}/{p95:.6f} s")
print(f" QPS: {1.0 / avg:.2f}")
if __name__ == "__main__":
main()

View File

@@ -1,286 +0,0 @@
#!/usr/bin/env python3
"""
DiskANN vs HNSW Search Performance Comparison
This benchmark compares search performance between DiskANN and HNSW backends:
- DiskANN: With graph partitioning enabled (is_recompute=True)
- HNSW: With recompute enabled (is_recompute=True)
- Tests performance across different dataset sizes
- Measures search latency, recall, and index size
"""
import gc
import multiprocessing as mp
import tempfile
import time
from pathlib import Path
from typing import Any
import numpy as np
# Prefer 'fork' start method to avoid POSIX semaphore leaks on macOS
try:
mp.set_start_method("fork", force=True)
except Exception:
pass
def create_test_texts(n_docs: int) -> list[str]:
"""Create synthetic test documents for benchmarking."""
np.random.seed(42)
topics = [
"machine learning and artificial intelligence",
"natural language processing and text analysis",
"computer vision and image recognition",
"data science and statistical analysis",
"deep learning and neural networks",
"information retrieval and search engines",
"database systems and data management",
"software engineering and programming",
"cybersecurity and network protection",
"cloud computing and distributed systems",
]
texts = []
for i in range(n_docs):
topic = topics[i % len(topics)]
variation = np.random.randint(1, 100)
text = (
f"This is document {i} about {topic}. Content variation {variation}. "
f"Additional information about {topic} with details and examples. "
f"Technical discussion of {topic} including implementation aspects."
)
texts.append(text)
return texts
def benchmark_backend(
backend_name: str, texts: list[str], test_queries: list[str], backend_kwargs: dict[str, Any]
) -> dict[str, float]:
"""Benchmark a specific backend with the given configuration."""
from leann.api import LeannBuilder, LeannSearcher
print(f"\n🔧 Testing {backend_name.upper()} backend...")
with tempfile.TemporaryDirectory() as temp_dir:
index_path = str(Path(temp_dir) / f"benchmark_{backend_name}.leann")
# Build index
print(f"📦 Building {backend_name} index with {len(texts)} documents...")
start_time = time.time()
builder = LeannBuilder(
backend_name=backend_name,
embedding_model="facebook/contriever",
embedding_mode="sentence-transformers",
**backend_kwargs,
)
for text in texts:
builder.add_text(text)
builder.build_index(index_path)
build_time = time.time() - start_time
# Measure index size
index_dir = Path(index_path).parent
index_files = list(index_dir.glob(f"{Path(index_path).stem}.*"))
total_size = sum(f.stat().st_size for f in index_files if f.is_file())
size_mb = total_size / (1024 * 1024)
print(f" ✅ Build completed in {build_time:.2f}s, index size: {size_mb:.1f}MB")
# Search benchmark
print("🔍 Running search benchmark...")
searcher = LeannSearcher(index_path)
search_times = []
all_results = []
for query in test_queries:
start_time = time.time()
results = searcher.search(query, top_k=5)
search_time = time.time() - start_time
search_times.append(search_time)
all_results.append(results)
avg_search_time = np.mean(search_times) * 1000 # Convert to ms
print(f" ✅ Average search time: {avg_search_time:.1f}ms")
# Check for valid scores (detect -inf issues)
all_scores = [
result.score
for results in all_results
for result in results
if result.score is not None
]
valid_scores = [
score for score in all_scores if score != float("-inf") and score != float("inf")
]
score_validity_rate = len(valid_scores) / len(all_scores) if all_scores else 0
# Clean up (ensure embedding server shutdown and object GC)
try:
if hasattr(searcher, "cleanup"):
searcher.cleanup()
del searcher
del builder
gc.collect()
except Exception as e:
print(f"⚠️ Warning: Resource cleanup error: {e}")
return {
"build_time": build_time,
"avg_search_time_ms": avg_search_time,
"index_size_mb": size_mb,
"score_validity_rate": score_validity_rate,
}
def run_comparison(n_docs: int = 500, n_queries: int = 10):
"""Run performance comparison between DiskANN and HNSW."""
print("🚀 Starting DiskANN vs HNSW Performance Comparison")
print(f"📊 Dataset: {n_docs} documents, {n_queries} test queries")
# Create test data
texts = create_test_texts(n_docs)
test_queries = [
"machine learning algorithms",
"natural language processing",
"computer vision techniques",
"data analysis methods",
"neural network architectures",
"database query optimization",
"software development practices",
"security vulnerabilities",
"cloud infrastructure",
"distributed computing",
][:n_queries]
# HNSW benchmark
hnsw_results = benchmark_backend(
backend_name="hnsw",
texts=texts,
test_queries=test_queries,
backend_kwargs={
"is_recompute": True, # Enable recompute for fair comparison
"M": 16,
"efConstruction": 200,
},
)
# DiskANN benchmark
diskann_results = benchmark_backend(
backend_name="diskann",
texts=texts,
test_queries=test_queries,
backend_kwargs={
"is_recompute": True, # Enable graph partitioning
"num_neighbors": 32,
"search_list_size": 50,
},
)
# Performance comparison
print("\n📈 Performance Comparison Results")
print(f"{'=' * 60}")
print(f"{'Metric':<25} {'HNSW':<15} {'DiskANN':<15} {'Speedup':<10}")
print(f"{'-' * 60}")
# Build time comparison
build_speedup = hnsw_results["build_time"] / diskann_results["build_time"]
print(
f"{'Build Time (s)':<25} {hnsw_results['build_time']:<15.2f} {diskann_results['build_time']:<15.2f} {build_speedup:<10.2f}x"
)
# Search time comparison
search_speedup = hnsw_results["avg_search_time_ms"] / diskann_results["avg_search_time_ms"]
print(
f"{'Search Time (ms)':<25} {hnsw_results['avg_search_time_ms']:<15.1f} {diskann_results['avg_search_time_ms']:<15.1f} {search_speedup:<10.2f}x"
)
# Index size comparison
size_ratio = diskann_results["index_size_mb"] / hnsw_results["index_size_mb"]
print(
f"{'Index Size (MB)':<25} {hnsw_results['index_size_mb']:<15.1f} {diskann_results['index_size_mb']:<15.1f} {size_ratio:<10.2f}x"
)
# Score validity
print(
f"{'Score Validity (%)':<25} {hnsw_results['score_validity_rate'] * 100:<15.1f} {diskann_results['score_validity_rate'] * 100:<15.1f}"
)
print(f"{'=' * 60}")
print("\n🎯 Summary:")
if search_speedup > 1:
print(f" DiskANN is {search_speedup:.2f}x faster than HNSW for search")
else:
print(f" HNSW is {1 / search_speedup:.2f}x faster than DiskANN for search")
if size_ratio > 1:
print(f" DiskANN uses {size_ratio:.2f}x more storage than HNSW")
else:
print(f" DiskANN uses {1 / size_ratio:.2f}x less storage than HNSW")
print(
f" Both backends achieved {min(hnsw_results['score_validity_rate'], diskann_results['score_validity_rate']) * 100:.1f}% score validity"
)
if __name__ == "__main__":
import sys
try:
# Handle help request
if len(sys.argv) > 1 and sys.argv[1] in ["-h", "--help", "help"]:
print("DiskANN vs HNSW Performance Comparison")
print("=" * 50)
print(f"Usage: python {sys.argv[0]} [n_docs] [n_queries]")
print()
print("Arguments:")
print(" n_docs Number of documents to index (default: 500)")
print(" n_queries Number of test queries to run (default: 10)")
print()
print("Examples:")
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py")
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 1000")
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20")
sys.exit(0)
# Parse command line arguments
n_docs = int(sys.argv[1]) if len(sys.argv) > 1 else 500
n_queries = int(sys.argv[2]) if len(sys.argv) > 2 else 10
print("DiskANN vs HNSW Performance Comparison")
print("=" * 50)
print(f"Dataset: {n_docs} documents, {n_queries} queries")
print()
run_comparison(n_docs=n_docs, n_queries=n_queries)
except KeyboardInterrupt:
print("\n⚠️ Benchmark interrupted by user")
sys.exit(130)
except Exception as e:
print(f"\n❌ Benchmark failed: {e}")
sys.exit(1)
finally:
# Ensure clean exit (forceful to prevent rare hangs from atexit/threads)
try:
gc.collect()
print("\n🧹 Cleanup completed")
# Flush stdio to ensure message is visible before hard-exit
try:
import sys as _sys
_sys.stdout.flush()
_sys.stderr.flush()
except Exception:
pass
except Exception:
pass
# Use os._exit to bypass atexit handlers that may hang in rare cases
import os as _os
_os._exit(0)

View File

@@ -1,141 +0,0 @@
# Enron Emails Benchmark
A comprehensive RAG benchmark for evaluating LEANN search and generation on the Enron email corpus. It mirrors the structure and CLI of the existing FinanceBench and LAION benches, using stage-based evaluation with Recall@3 and generation timing.
- Dataset: Enron email CSV (e.g., Kaggle wcukierski/enron-email-dataset) for passages
- Queries: corbt/enron_emails_sample_questions (filtered for realistic questions)
- Metrics: Recall@3 vs FAISS Flat baseline + Generation evaluation with Qwen3-8B
## Layout
benchmarks/enron_emails/
- setup_enron_emails.py: Prepare passages, build LEANN index, build FAISS baseline
- evaluate_enron_emails.py: Evaluate retrieval recall (Stages 2-5) + generation with Qwen3-8B
- data/: Generated passages, queries, embeddings-related files
- baseline/: FAISS Flat baseline files
- llm_utils.py: LLM utilities for Qwen3-8B generation (in parent directory)
## Quickstart
1) Prepare the data and index
cd benchmarks/enron_emails
python setup_enron_emails.py --data-dir data
Notes:
- If `--emails-csv` is omitted, the script attempts to download from Kaggle dataset `wcukierski/enron-email-dataset` using the Kaggle API (requires `KAGGLE_USERNAME` and `KAGGLE_KEY`). Alternatively, pass a local path to `--emails-csv`.
- The script parses emails, chunks header/body into passages, builds a compact LEANN index, and then builds a FAISS Flat baseline from the same passages and embedding model.
- Optionally, it will also create evaluation queries from the HuggingFace dataset `corbt/enron_emails_sample_questions`.
2) Run recall evaluation (Stage 2)
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 2
3) Complexity sweep (Stage 3)
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 3 --target-recall 0.90 --max-queries 200
Stage 3 uses binary search over complexity to find the minimal value achieving the target Recall@3 (assumes recall is non-decreasing with complexity). The search expands the upper bound as needed and snaps complexity to multiples of 8.
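A minimal sketch of that strategy, where the hypothetical `recall_at_3(c)` callable stands in for a Stage 2 evaluation at complexity `c` (the real implementation lives in `evaluate_enron_emails.py`):
```python
def find_min_complexity(recall_at_3, target=0.90, lo=8, hi=256, cap=1024):
    """Binary-search the smallest complexity (a multiple of 8) whose Recall@3 >= target.

    Assumes recall_at_3(c) is non-decreasing in c.
    """
    def snap(x: int) -> int:
        return max(1, ((x + 7) // 8) * 8)

    lo, hi = snap(lo), snap(hi)
    r_hi = recall_at_3(hi)
    while r_hi < target and hi < cap:  # expand the upper bound until it reaches the target
        lo, hi = hi, snap(hi * 2)
        r_hi = recall_at_3(hi)
    if r_hi < target:
        return None  # target unreachable within the cap
    best = hi
    while lo < hi:
        mid = snap((lo + hi) // 2)
        if mid >= hi:  # snapping pushed mid onto hi; nothing left to test
            break
        if recall_at_3(mid) >= target:
            best, hi = mid, mid
        else:
            lo = mid + 8
    return best
```
On the sample run reported under Example Results below, this lands on a minimal complexity of 88 for a 90% target.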
4) Index comparison (Stage 4)
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 4 --complexity 88 --max-queries 100 --output results.json
5) Generation evaluation (Stage 5)
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 5 --complexity 88 --llm-backend hf --model-name Qwen/Qwen3-8B
6) Combined index + generation evaluation (Stages 4+5, recommended)
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 45 --complexity 88 --llm-backend hf
Notes:
- Minimal CLI: you can run from repo root with only `--index`, defaults match financebench/laion patterns:
- `--stage` defaults to `all` (runs 2, 3, 4, 5)
- `--baseline-dir` defaults to `baseline`
- `--queries` defaults to `data/evaluation_queries.jsonl` (or falls back to the index directory)
- `--llm-backend` defaults to `hf` (HuggingFace), can use `vllm`
- `--model-name` defaults to `Qwen/Qwen3-8B`
- Fail-fast behavior: no silent fallbacks. If compact index cannot run with recompute, it errors out.
- Stage 5 requires Stage 4 retrieval results. Use `--stage 45` to run both efficiently.
Optional flags:
- --queries data/evaluation_queries.jsonl (custom queries file)
- --baseline-dir baseline (where FAISS baseline lives)
- --complexity 88 (LEANN complexity parameter, optimal for 90% recall)
- --llm-backend hf|vllm (LLM backend for generation)
- --model-name Qwen/Qwen3-8B (LLM model for generation)
- --max-queries 1000 (limit number of queries for evaluation)
## Files Produced
- data/enron_passages_preview.jsonl: Small preview of passages used (for inspection)
- data/enron_index_hnsw.leann.*: LEANN index files
- baseline/faiss_flat.index + baseline/metadata.pkl: FAISS baseline with passage IDs
- data/evaluation_queries.jsonl: Query file (id + query; includes GT IDs for reference)
## Notes
- Evaluates both retrieval Recall@3 and generation timing with the Qwen3-8B thinking model.
- The emails CSV must contain a column named "message" (raw RFC822 email) and a column named "file" as the source identifier. Message-ID headers are parsed as canonical message IDs when present (a minimal parsing sketch follows this list).
- Qwen3-8B is a thinking model, so generation requires chat-template handling and <think></think> tag processing.
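A minimal sketch of how such a row can be parsed (assumes the Kaggle CSV layout described above; `iter_messages` is a hypothetical helper for illustration only):
```python
import csv
import email

def iter_messages(csv_path: str):
    """Yield (source_file, message_id, subject, raw_body) for each CSV row."""
    csv.field_size_limit(10_000_000)  # raw RFC822 messages can exceed the default field limit
    with open(csv_path, newline="", encoding="utf-8") as f:
        for row in csv.DictReader(f):
            msg = email.message_from_string(row["message"])
            # Fall back to the 'file' column when no Message-ID header is present.
            msg_id = msg.get("Message-ID", row["file"])
            body = msg.get_payload() if not msg.is_multipart() else ""
            yield row["file"], msg_id, msg.get("Subject", ""), body
```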
## Stages Summary
- Stage 2 (Recall@3):
- Compares LEANN vs the FAISS Flat baseline on Recall@3 (sketched after this summary).
- Compact index runs with `recompute_embeddings=True`.
- Stage 3 (Binary Search for Complexity):
- Builds a non-compact index (`<index>_noncompact.leann`) and runs binary search with `recompute_embeddings=False` to find the minimal complexity achieving target Recall@3 (default 90%).
- Stage 4 (Index Comparison):
- Reports .index-only sizes for compact vs non-compact.
- Measures timings on queries by default: non-compact (no recompute) vs compact (with recompute).
- Stores retrieval results for Stage 5 generation evaluation.
- Fails fast if compact recompute cannot run.
- If `--complexity` is not provided, the script tries to use the best complexity from Stage 3:
- First from the current run (when running `--stage all`), otherwise
- From `enron_stage3_results.json` saved next to the index during the last Stage 3 run.
- If neither exists, Stage 4 will error and ask you to run Stage 3 or pass `--complexity`.
- Stage 5 (Generation Evaluation):
- Uses Qwen3-8B thinking model for RAG generation on retrieved documents from Stage 4.
- Supports HuggingFace (`hf`) and vLLM (`vllm`) backends.
- Measures generation timing separately from search timing.
- Requires Stage 4 results (no additional searching performed).
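For reference, the Recall@3 figure used in Stages 2 and 3 is just the overlap between LEANN's top-3 passage IDs and the FAISS Flat top-3 for the same query; a minimal sketch:
```python
def recall_at_3(leann_top3_ids, faiss_top3_ids) -> float:
    """Recall@3 for one query: how many of FAISS Flat's top-3 LEANN also returned, over 3."""
    return len(set(leann_top3_ids) & set(faiss_top3_ids)) / 3.0

def average_recall_at_3(per_query_pairs) -> float:
    """Average over (leann_ids, faiss_ids) pairs, one pair per query."""
    return sum(recall_at_3(a, b) for a, b in per_query_pairs) / max(1, len(per_query_pairs))
```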
## Example Results
These are sample results obtained on Enron data using all-mpnet-base-v2 and Qwen3-8B.
- Stage 3 (Binary Search):
- Minimal complexity achieving 90% Recall@3: 88
- Sampled points:
- C=8 → 59.9% Recall@3
- C=72 → 89.4% Recall@3
- C=88 → 90.2% Recall@3
- C=96 → 90.7% Recall@3
- C=112 → 91.1% Recall@3
- C=136 → 91.3% Recall@3
- C=256 → 92.0% Recall@3
- Stage 4 (Index Sizes, .index only):
- Compact: ~2.2 MB
- Non-compact: ~82.0 MB
- Storage saving by compact: ~97.3%
- Stage 4 (Search Timing, 988 queries, complexity=88):
- Non-compact (no recompute): ~0.0075 s avg per query
- Compact (with recompute): ~1.981 s avg per query
- Speed ratio (non-compact/compact): ~0.0038x
- Stage 5 (RAG Generation, 988 queries, Qwen3-8B):
- Average generation time: ~22.302 s per query
- Total queries processed: 988
- LLM backend: HuggingFace transformers
- Model: Qwen/Qwen3-8B (thinking model with <think></think> processing)
Full JSON output is saved by the script (see `--output`), e.g.:
`benchmarks/enron_emails/results_enron_stage45.json`.

View File

@@ -1 +0,0 @@
downloads/

View File

@@ -1,614 +0,0 @@
"""
Enron Emails Benchmark Evaluation - Retrieval Recall@3 (Stages 2/3/4)
Follows the style of FinanceBench/LAION: Stage 2 recall vs FAISS baseline,
Stage 3 complexity sweep to target recall, Stage 4 index comparison.
On errors, fail fast without fallbacks.
"""
import argparse
import json
import logging
import os
import pickle
from pathlib import Path
import numpy as np
from leann import LeannBuilder, LeannSearcher
from leann_backend_hnsw import faiss
from ..llm_utils import generate_hf, generate_vllm, load_hf_model, load_vllm_model
# Setup logging to reduce verbose output
logging.basicConfig(level=logging.WARNING)
logging.getLogger("leann.api").setLevel(logging.WARNING)
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)
class RecallEvaluator:
"""Stage 2: Evaluate Recall@3 (LEANN vs FAISS)"""
def __init__(self, index_path: str, baseline_dir: str):
self.index_path = index_path
self.baseline_dir = baseline_dir
self.searcher = LeannSearcher(index_path)
baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
metadata_path = os.path.join(baseline_dir, "metadata.pkl")
self.faiss_index = faiss.read_index(baseline_index_path)
with open(metadata_path, "rb") as f:
self.passage_ids = pickle.load(f)
print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} vectors")
# No fallbacks here; if embedding server is needed but fails, the caller will see the error.
def evaluate_recall_at_3(
self, queries: list[str], complexity: int = 64, recompute_embeddings: bool = True
) -> float:
"""Evaluate recall@3 using FAISS Flat as ground truth"""
from leann.api import compute_embeddings
recompute_str = "with recompute" if recompute_embeddings else "no recompute"
print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")
total_recall = 0.0
for i, query in enumerate(queries):
# Compute query embedding with the same model/mode as the index
q_emb = compute_embeddings(
[query],
self.searcher.embedding_model,
mode=self.searcher.embedding_mode,
use_server=False,
).astype(np.float32)
# Search FAISS Flat ground truth
n = q_emb.shape[0]
k = 3
distances = np.zeros((n, k), dtype=np.float32)
labels = np.zeros((n, k), dtype=np.int64)
self.faiss_index.search(
n,
faiss.swig_ptr(q_emb),
k,
faiss.swig_ptr(distances),
faiss.swig_ptr(labels),
)
baseline_ids = {self.passage_ids[idx] for idx in labels[0]}
# Search with LEANN (may require embedding server depending on index configuration)
results = self.searcher.search(
query,
top_k=3,
complexity=complexity,
recompute_embeddings=recompute_embeddings,
)
test_ids = {r.id for r in results}
intersection = test_ids.intersection(baseline_ids)
recall = len(intersection) / 3.0
total_recall += recall
if i < 3:
print(f" Q{i + 1}: '{query[:60]}...' -> Recall@3: {recall:.3f}")
print(f" FAISS: {list(baseline_ids)}")
print(f" LEANN: {list(test_ids)}")
print(f" ∩: {list(intersection)}")
avg = total_recall / max(1, len(queries))
print(f"📊 Average Recall@3: {avg:.3f} ({avg * 100:.1f}%)")
return avg
def cleanup(self):
if hasattr(self, "searcher"):
self.searcher.cleanup()
class EnronEvaluator:
def __init__(self, index_path: str):
self.index_path = index_path
self.searcher = LeannSearcher(index_path)
def load_queries(self, queries_file: str) -> list[str]:
queries: list[str] = []
with open(queries_file, encoding="utf-8") as f:
for line in f:
if not line.strip():
continue
data = json.loads(line)
if "query" in data:
queries.append(data["query"])
print(f"📊 Loaded {len(queries)} queries from {queries_file}")
return queries
def cleanup(self):
if self.searcher:
self.searcher.cleanup()
def analyze_index_sizes(self) -> dict:
"""Analyze index sizes (.index only), similar to LAION bench."""
print("📏 Analyzing index sizes (.index only)...")
index_path = Path(self.index_path)
index_dir = index_path.parent
index_name = index_path.stem
sizes: dict[str, float] = {}
index_file = index_dir / f"{index_name}.index"
meta_file = index_dir / f"{index_path.name}.meta.json"
passages_file = index_dir / f"{index_path.name}.passages.jsonl"
passages_idx_file = index_dir / f"{index_path.name}.passages.idx"
sizes["index_only_mb"] = (
index_file.stat().st_size / (1024 * 1024) if index_file.exists() else 0.0
)
sizes["metadata_mb"] = (
meta_file.stat().st_size / (1024 * 1024) if meta_file.exists() else 0.0
)
sizes["passages_text_mb"] = (
passages_file.stat().st_size / (1024 * 1024) if passages_file.exists() else 0.0
)
sizes["passages_index_mb"] = (
passages_idx_file.stat().st_size / (1024 * 1024) if passages_idx_file.exists() else 0.0
)
print(f" 📁 .index size: {sizes['index_only_mb']:.1f} MB")
return sizes
def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
"""Create a non-compact index for comparison using current passages and embeddings."""
current_index_path = Path(self.index_path)
current_index_dir = current_index_path.parent
current_index_name = current_index_path.name
# Read metadata to get passage source and embedding model
meta_path = current_index_dir / f"{current_index_name}.meta.json"
with open(meta_path, encoding="utf-8") as f:
meta = json.load(f)
passage_source = meta["passage_sources"][0]
passage_file = passage_source["path"]
# Convert relative path to absolute
if not Path(passage_file).is_absolute():
passage_file = current_index_dir / Path(passage_file).name
# Load all passages and ids
ids: list[str] = []
texts: list[str] = []
with open(passage_file, encoding="utf-8") as f:
for line in f:
if line.strip():
data = json.loads(line)
ids.append(str(data["id"]))
texts.append(data["text"])
# Compute embeddings using the same method as LEANN
from leann.api import compute_embeddings
embeddings = compute_embeddings(
texts,
meta["embedding_model"],
mode=meta.get("embedding_mode", "sentence-transformers"),
use_server=False,
).astype(np.float32)
# Build non-compact index with same passages and embeddings
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=meta["embedding_model"],
embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
is_recompute=False,
is_compact=False,
**{
k: v
for k, v in meta.get("backend_kwargs", {}).items()
if k not in ["is_recompute", "is_compact"]
},
)
# Persist a pickle for build_index_from_embeddings
pkl_path = current_index_dir / f"{Path(non_compact_index_path).stem}_embeddings.pkl"
with open(pkl_path, "wb") as pf:
pickle.dump((ids, embeddings), pf)
print(
f"🔨 Building non-compact index at {non_compact_index_path} from precomputed embeddings..."
)
builder.build_index_from_embeddings(non_compact_index_path, str(pkl_path))
# Analyze the non-compact index size
temp_evaluator = EnronEvaluator(non_compact_index_path)
non_compact_sizes = temp_evaluator.analyze_index_sizes()
non_compact_sizes["index_type"] = "non_compact"
return non_compact_sizes
def compare_index_performance(
self, non_compact_path: str, compact_path: str, test_queries: list[str], complexity: int
) -> dict:
"""Compare search speed for non-compact vs compact indexes."""
import time
results: dict = {
"non_compact": {"search_times": []},
"compact": {"search_times": []},
"avg_search_times": {},
"speed_ratio": 0.0,
"retrieval_results": [], # Store retrieval results for Stage 5
}
print("⚡ Comparing search performance between indexes...")
# Non-compact (no recompute)
print(" 🔍 Testing non-compact index (no recompute)...")
non_compact_searcher = LeannSearcher(non_compact_path)
for q in test_queries:
t0 = time.time()
_ = non_compact_searcher.search(
q, top_k=3, complexity=complexity, recompute_embeddings=False
)
results["non_compact"]["search_times"].append(time.time() - t0)
# Compact (with recompute). Fail fast if it cannot run.
print(" 🔍 Testing compact index (with recompute)...")
compact_searcher = LeannSearcher(compact_path)
for q in test_queries:
t0 = time.time()
docs = compact_searcher.search(
q, top_k=3, complexity=complexity, recompute_embeddings=True
)
results["compact"]["search_times"].append(time.time() - t0)
# Store retrieval results for Stage 5
results["retrieval_results"].append(
{"query": q, "retrieved_docs": [{"id": doc.id, "text": doc.text} for doc in docs]}
)
compact_searcher.cleanup()
if results["non_compact"]["search_times"]:
results["avg_search_times"]["non_compact"] = sum(
results["non_compact"]["search_times"]
) / len(results["non_compact"]["search_times"])
if results["compact"]["search_times"]:
results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
results["compact"]["search_times"]
)
if results["avg_search_times"].get("compact", 0) > 0:
results["speed_ratio"] = (
results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
)
else:
results["speed_ratio"] = 0.0
non_compact_searcher.cleanup()
return results
def evaluate_complexity(
self,
recall_eval: "RecallEvaluator",
queries: list[str],
target: float = 0.90,
c_min: int = 8,
c_max: int = 256,
max_iters: int = 10,
recompute: bool = False,
) -> dict:
"""Binary search minimal complexity achieving target recall (monotonic assumption)."""
def round_c(x: int) -> int:
# snap to multiple of 8 like other benches typically do
return max(1, int((x + 7) // 8) * 8)
metrics: list[dict] = []
lo = round_c(c_min)
hi = round_c(c_max)
print(
f"🧪 Binary search complexity in [{lo}, {hi}] for target Recall@3>={int(target * 100)}%..."
)
# Ensure upper bound can reach target; expand if needed (up to a cap)
r_lo = recall_eval.evaluate_recall_at_3(
queries, complexity=lo, recompute_embeddings=recompute
)
metrics.append({"complexity": lo, "recall_at_3": r_lo})
r_hi = recall_eval.evaluate_recall_at_3(
queries, complexity=hi, recompute_embeddings=recompute
)
metrics.append({"complexity": hi, "recall_at_3": r_hi})
cap = 1024
while r_hi < target and hi < cap:
lo = hi
r_lo = r_hi
hi = round_c(hi * 2)
r_hi = recall_eval.evaluate_recall_at_3(
queries, complexity=hi, recompute_embeddings=recompute
)
metrics.append({"complexity": hi, "recall_at_3": r_hi})
if r_hi < target:
print(f"⚠️ Max complexity {hi} did not reach target recall {target:.2f}.")
print("📈 Observations:")
for m in metrics:
print(f" C={m['complexity']:>4} -> Recall@3={m['recall_at_3'] * 100:.1f}%")
return {"metrics": metrics, "best_complexity": None, "target_recall": target}
# Binary search within [lo, hi]
best = hi
iters = 0
while lo < hi and iters < max_iters:
mid = round_c((lo + hi) // 2)
r_mid = recall_eval.evaluate_recall_at_3(
queries, complexity=mid, recompute_embeddings=recompute
)
metrics.append({"complexity": mid, "recall_at_3": r_mid})
if r_mid >= target:
best = mid
hi = mid
else:
lo = mid + 8 # move past mid, respecting multiple-of-8 step
iters += 1
print("📈 Binary search results (sampled points):")
# Print unique complexity entries ordered by complexity
for m in sorted(
{m["complexity"]: m for m in metrics}.values(), key=lambda x: x["complexity"]
):
print(f" C={m['complexity']:>4} -> Recall@3={m['recall_at_3'] * 100:.1f}%")
print(f"✅ Minimal complexity achieving {int(target * 100)}% recall: {best}")
return {"metrics": metrics, "best_complexity": best, "target_recall": target}
def main():
parser = argparse.ArgumentParser(description="Enron Emails Benchmark Evaluation")
parser.add_argument("--index", required=True, help="Path to LEANN index")
parser.add_argument(
"--queries", default="data/evaluation_queries.jsonl", help="Path to evaluation queries"
)
parser.add_argument(
"--stage",
choices=["2", "3", "4", "5", "all", "45"],
default="all",
help="Which stage to run (2=recall, 3=complexity, 4=index comparison, 5=generation)",
)
parser.add_argument("--complexity", type=int, default=None, help="LEANN search complexity")
parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
parser.add_argument(
"--max-queries", type=int, help="Limit number of queries to evaluate", default=1000
)
parser.add_argument(
"--target-recall", type=float, default=0.90, help="Target Recall@3 for Stage 3"
)
parser.add_argument("--output", help="Save results to JSON file")
parser.add_argument("--llm-backend", choices=["hf", "vllm"], default="hf", help="LLM backend")
parser.add_argument("--model-name", default="Qwen/Qwen3-8B", help="Model name")
args = parser.parse_args()
# Resolve queries file: if default path not found, fall back to index's directory
if not os.path.exists(args.queries):
from pathlib import Path
idx_dir = Path(args.index).parent
fallback_q = idx_dir / "evaluation_queries.jsonl"
if fallback_q.exists():
args.queries = str(fallback_q)
baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
if not os.path.exists(baseline_index_path):
print(f"❌ FAISS baseline not found at {baseline_index_path}")
print("💡 Please run setup_enron_emails.py first to build the baseline")
raise SystemExit(1)
results_out: dict = {}
if args.stage in ("2", "all"):
print("🚀 Starting Stage 2: Recall@3 evaluation")
evaluator = RecallEvaluator(args.index, args.baseline_dir)
enron_eval = EnronEvaluator(args.index)
queries = enron_eval.load_queries(args.queries)
        queries = queries[:10]  # Stage 2 evaluates a small fixed sample of queries
print(f"🧪 Using first {len(queries)} queries")
complexity = args.complexity or 64
r = evaluator.evaluate_recall_at_3(queries, complexity)
results_out["stage2"] = {"complexity": complexity, "recall_at_3": r}
evaluator.cleanup()
enron_eval.cleanup()
print("✅ Stage 2 completed!\n")
if args.stage in ("3", "all"):
print("🚀 Starting Stage 3: Binary search for target recall (no recompute)")
enron_eval = EnronEvaluator(args.index)
queries = enron_eval.load_queries(args.queries)
queries = queries[: args.max_queries]
print(f"🧪 Using first {len(queries)} queries")
# Build non-compact index for fast binary search (recompute_embeddings=False)
from pathlib import Path
index_path = Path(args.index)
non_compact_index_path = str(index_path.parent / f"{index_path.stem}_noncompact.leann")
enron_eval.create_non_compact_index_for_comparison(non_compact_index_path)
# Use non-compact evaluator for binary search with recompute=False
evaluator_nc = RecallEvaluator(non_compact_index_path, args.baseline_dir)
sweep = enron_eval.evaluate_complexity(
evaluator_nc, queries, target=args.target_recall, recompute=False
)
results_out["stage3"] = sweep
# Persist default stage 3 results near the index for Stage 4 auto-pickup
from pathlib import Path
default_stage3_path = Path(args.index).parent / "enron_stage3_results.json"
with open(default_stage3_path, "w", encoding="utf-8") as f:
json.dump({"stage3": sweep}, f, indent=2)
print(f"📝 Saved Stage 3 summary to {default_stage3_path}")
evaluator_nc.cleanup()
enron_eval.cleanup()
print("✅ Stage 3 completed!\n")
if args.stage in ("4", "all", "45"):
print("🚀 Starting Stage 4: Index size + performance comparison")
evaluator = RecallEvaluator(args.index, args.baseline_dir)
enron_eval = EnronEvaluator(args.index)
queries = enron_eval.load_queries(args.queries)
test_q = queries[: min(args.max_queries, len(queries))]
current_sizes = enron_eval.analyze_index_sizes()
# Build non-compact index for comparison (no fallback)
from pathlib import Path
index_path = Path(args.index)
non_compact_path = str(index_path.parent / f"{index_path.stem}_noncompact.leann")
non_compact_sizes = enron_eval.create_non_compact_index_for_comparison(non_compact_path)
nc_eval = EnronEvaluator(non_compact_path)
if (
current_sizes.get("index_only_mb", 0) > 0
and non_compact_sizes.get("index_only_mb", 0) > 0
):
storage_saving_percent = max(
0.0,
100.0 * (1.0 - current_sizes["index_only_mb"] / non_compact_sizes["index_only_mb"]),
)
else:
storage_saving_percent = 0.0
if args.complexity is None:
# Prefer in-session Stage 3 result
if "stage3" in results_out and results_out["stage3"].get("best_complexity") is not None:
complexity = results_out["stage3"]["best_complexity"]
print(f"📥 Using best complexity from Stage 3 in-session: {complexity}")
else:
# Try to load last saved Stage 3 result near index
default_stage3_path = Path(args.index).parent / "enron_stage3_results.json"
if default_stage3_path.exists():
with open(default_stage3_path, encoding="utf-8") as f:
prev = json.load(f)
complexity = prev.get("stage3", {}).get("best_complexity")
if complexity is None:
raise SystemExit(
"❌ Stage 4: No --complexity and no best_complexity found in saved Stage 3 results"
)
print(f"📥 Using best complexity from saved Stage 3: {complexity}")
else:
raise SystemExit(
"❌ Stage 4 requires --complexity if Stage 3 hasn't been run. Run stage 3 first or pass --complexity."
)
else:
complexity = args.complexity
comp = enron_eval.compare_index_performance(
non_compact_path, args.index, test_q, complexity=complexity
)
results_out["stage4"] = {
"current_index": current_sizes,
"non_compact_index": non_compact_sizes,
"storage_saving_percent": storage_saving_percent,
"performance_comparison": comp,
}
nc_eval.cleanup()
evaluator.cleanup()
enron_eval.cleanup()
print("✅ Stage 4 completed!\n")
if args.stage in ("5", "all"):
print("🚀 Starting Stage 5: Generation evaluation with Qwen3-8B")
# Check if Stage 4 results exist
if "stage4" not in results_out or "performance_comparison" not in results_out["stage4"]:
print("❌ Stage 5 requires Stage 4 retrieval results")
print("💡 Run Stage 4 first or use --stage all")
raise SystemExit(1)
retrieval_results = results_out["stage4"]["performance_comparison"]["retrieval_results"]
if not retrieval_results:
print("❌ No retrieval results found from Stage 4")
raise SystemExit(1)
print(f"📁 Using {len(retrieval_results)} retrieval results from Stage 4")
# Load LLM
try:
if args.llm_backend == "hf":
tokenizer, model = load_hf_model(args.model_name)
def llm_func(prompt):
return generate_hf(tokenizer, model, prompt)
else: # vllm
llm, sampling_params = load_vllm_model(args.model_name)
def llm_func(prompt):
return generate_vllm(llm, sampling_params, prompt)
# Run generation using stored retrieval results
import time
from llm_utils import create_prompt
generation_times = []
responses = []
print("🤖 Running generation on pre-retrieved results...")
for i, item in enumerate(retrieval_results):
query = item["query"]
retrieved_docs = item["retrieved_docs"]
# Prepare context from retrieved docs
context = "\n\n".join([doc["text"] for doc in retrieved_docs])
prompt = create_prompt(context, query, "emails")
# Time generation only
gen_start = time.time()
response = llm_func(prompt)
gen_time = time.time() - gen_start
generation_times.append(gen_time)
responses.append(response)
if i < 3:
print(f" Q{i + 1}: Gen={gen_time:.3f}s")
avg_gen_time = sum(generation_times) / len(generation_times)
print("\n📊 Generation Results:")
print(f" Total Queries: {len(retrieval_results)}")
print(f" Avg Generation Time: {avg_gen_time:.3f}s")
print(" (Search time from Stage 4)")
results_out["stage5"] = {
"total_queries": len(retrieval_results),
"avg_generation_time": avg_gen_time,
"generation_times": generation_times,
"responses": responses,
}
# Show sample results
print("\n📝 Sample Results:")
for i in range(min(3, len(retrieval_results))):
query = retrieval_results[i]["query"]
response = responses[i]
print(f" Q{i + 1}: {query[:60]}...")
print(f" A{i + 1}: {response[:100]}...")
print()
except Exception as e:
print(f"❌ Generation evaluation failed: {e}")
print("💡 Make sure transformers/vllm is installed and model is available")
print("✅ Stage 5 completed!\n")
if args.output and results_out:
with open(args.output, "w", encoding="utf-8") as f:
json.dump(results_out, f, indent=2)
print(f"📝 Saved results to {args.output}")
if __name__ == "__main__":
main()

View File

@@ -1,359 +0,0 @@
"""
Enron Emails Benchmark Setup Script
Prepares passages from emails.csv, builds LEANN index, and FAISS Flat baseline
"""
import argparse
import csv
import json
import os
import re
from collections.abc import Iterable
from email import message_from_string
from email.policy import default
from pathlib import Path
from typing import Optional
from leann import LeannBuilder
class EnronSetup:
def __init__(self, data_dir: str = "data"):
self.data_dir = Path(data_dir)
self.data_dir.mkdir(parents=True, exist_ok=True)
self.passages_preview = self.data_dir / "enron_passages_preview.jsonl"
self.index_path = self.data_dir / "enron_index_hnsw.leann"
self.queries_file = self.data_dir / "evaluation_queries.jsonl"
self.downloads_dir = self.data_dir / "downloads"
self.downloads_dir.mkdir(parents=True, exist_ok=True)
# ----------------------------
# Dataset acquisition
# ----------------------------
def ensure_emails_csv(self, emails_csv: Optional[str]) -> str:
"""Return a path to emails.csv, downloading from Kaggle if needed."""
if emails_csv:
p = Path(emails_csv)
if not p.exists():
raise FileNotFoundError(f"emails.csv not found: {emails_csv}")
return str(p)
print(
"📥 Trying to download Enron emails.csv from Kaggle (wcukierski/enron-email-dataset)..."
)
try:
from kaggle.api.kaggle_api_extended import KaggleApi
api = KaggleApi()
api.authenticate()
api.dataset_download_files(
"wcukierski/enron-email-dataset", path=str(self.downloads_dir), unzip=True
)
candidate = self.downloads_dir / "emails.csv"
if candidate.exists():
print(f"✅ Downloaded emails.csv: {candidate}")
return str(candidate)
else:
raise FileNotFoundError(
f"emails.csv was not found in {self.downloads_dir} after Kaggle download"
)
except Exception as e:
print(
"❌ Could not download via Kaggle automatically. Provide --emails-csv or configure Kaggle API."
)
print(
" Set KAGGLE_USERNAME and KAGGLE_KEY env vars, or place emails.csv locally and pass --emails-csv."
)
raise e
# ----------------------------
# Data preparation
# ----------------------------
@staticmethod
def _extract_message_id(raw_email: str) -> str:
msg = message_from_string(raw_email, policy=default)
val = msg.get("Message-ID", "")
if val.startswith("<") and val.endswith(">"):
val = val[1:-1]
return val or ""
@staticmethod
def _split_header_body(raw_email: str) -> tuple[str, str]:
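        """Split a raw email into (header, body) at the first blank line, with a heuristic fallback."""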
parts = raw_email.split("\n\n", 1)
if len(parts) == 2:
return parts[0].strip(), parts[1].strip()
# Heuristic fallback
first_lines = raw_email.splitlines()
if first_lines and ":" in first_lines[0]:
return raw_email.strip(), ""
return "", raw_email.strip()
@staticmethod
def _split_fixed_words(text: str, chunk_words: int, keep_last: bool) -> list[str]:
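        """Split text into fixed-size word chunks; drop any trailing partial chunk unless keep_last is True."""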
text = (text or "").strip()
if not text:
return []
if chunk_words <= 0:
return [text]
words = text.split()
if not words:
return []
limit = len(words)
if not keep_last:
limit = (len(words) // chunk_words) * chunk_words
if limit == 0:
return []
chunks = [" ".join(words[i : i + chunk_words]) for i in range(0, limit, chunk_words)]
return [c for c in (s.strip() for s in chunks) if c]
def _iter_passages_from_csv(
self,
emails_csv: Path,
chunk_words: int = 256,
keep_last_header: bool = True,
keep_last_body: bool = True,
max_emails: int | None = None,
) -> Iterable[dict]:
with open(emails_csv, encoding="utf-8") as f:
reader = csv.DictReader(f)
count = 0
for i, row in enumerate(reader):
if max_emails is not None and count >= max_emails:
break
raw_message = row.get("message", "")
email_file_id = row.get("file", "")
if not raw_message.strip():
continue
message_id = self._extract_message_id(raw_message)
if not message_id:
# Fallback ID based on CSV position and file path
safe_file = re.sub(r"[^A-Za-z0-9_.-]", "_", email_file_id)
message_id = f"enron_{i}_{safe_file}"
header, body = self._split_header_body(raw_message)
# Header chunks
for chunk in self._split_fixed_words(header, chunk_words, keep_last_header):
yield {
"text": chunk,
"metadata": {
"message_id": message_id,
"is_header": True,
"email_file_id": email_file_id,
},
}
# Body chunks
for chunk in self._split_fixed_words(body, chunk_words, keep_last_body):
yield {
"text": chunk,
"metadata": {
"message_id": message_id,
"is_header": False,
"email_file_id": email_file_id,
},
}
count += 1
# ----------------------------
# Build LEANN index and FAISS baseline
# ----------------------------
def build_leann_index(
self,
emails_csv: Optional[str],
backend: str = "hnsw",
embedding_model: str = "sentence-transformers/all-mpnet-base-v2",
chunk_words: int = 256,
max_emails: int | None = None,
) -> str:
emails_csv_path = self.ensure_emails_csv(emails_csv)
print(f"🏗️ Building LEANN index from {emails_csv_path}...")
builder = LeannBuilder(
backend_name=backend,
embedding_model=embedding_model,
embedding_mode="sentence-transformers",
graph_degree=32,
complexity=64,
is_recompute=True,
is_compact=True,
num_threads=4,
)
# Stream passages and add to builder
preview_written = 0
with open(self.passages_preview, "w", encoding="utf-8") as preview_out:
for p in self._iter_passages_from_csv(
Path(emails_csv_path), chunk_words=chunk_words, max_emails=max_emails
):
builder.add_text(p["text"], metadata=p["metadata"])
if preview_written < 200:
preview_out.write(json.dumps({"text": p["text"][:200], **p["metadata"]}) + "\n")
preview_written += 1
print(f"🔨 Building index at {self.index_path}...")
builder.build_index(str(self.index_path))
print("✅ LEANN index built!")
return str(self.index_path)
def build_faiss_flat_baseline(self, index_path: str, output_dir: str = "baseline") -> str:
print("🔨 Building FAISS Flat baseline from LEANN passages...")
import pickle
import numpy as np
from leann.api import compute_embeddings
from leann_backend_hnsw import faiss
os.makedirs(output_dir, exist_ok=True)
baseline_path = os.path.join(output_dir, "faiss_flat.index")
metadata_path = os.path.join(output_dir, "metadata.pkl")
if os.path.exists(baseline_path) and os.path.exists(metadata_path):
print(f"✅ Baseline already exists at {baseline_path}")
return baseline_path
# Read meta for passage source and embedding model
meta_path = f"{index_path}.meta.json"
with open(meta_path, encoding="utf-8") as f:
meta = json.load(f)
embedding_model = meta["embedding_model"]
passage_source = meta["passage_sources"][0]
passage_file = passage_source["path"]
if not os.path.isabs(passage_file):
index_dir = os.path.dirname(index_path)
passage_file = os.path.join(index_dir, os.path.basename(passage_file))
# Load passages from builder output so IDs match LEANN
passages: list[str] = []
passage_ids: list[str] = []
with open(passage_file, encoding="utf-8") as f:
for line in f:
if not line.strip():
continue
data = json.loads(line)
passages.append(data["text"])
passage_ids.append(data["id"]) # builder-assigned ID
print(f"📄 Loaded {len(passages)} passages for baseline")
print(f"🤖 Embedding model: {embedding_model}")
embeddings = compute_embeddings(
passages,
embedding_model,
mode="sentence-transformers",
use_server=False,
)
# Build FAISS IndexFlatIP
dim = embeddings.shape[1]
index = faiss.IndexFlatIP(dim)
emb_f32 = embeddings.astype(np.float32)
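        # leann_backend_hnsw exposes the raw SWIG faiss signature, so add() takes (n, swig_ptr(xb)),
        # matching the search(n, swig_ptr(q), k, ...) convention used in the evaluator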
index.add(emb_f32.shape[0], faiss.swig_ptr(emb_f32))
faiss.write_index(index, baseline_path)
with open(metadata_path, "wb") as pf:
pickle.dump(passage_ids, pf)
print(f"✅ FAISS baseline saved: {baseline_path}")
print(f"✅ Metadata saved: {metadata_path}")
print(f"📊 Total vectors: {index.ntotal}")
return baseline_path
# ----------------------------
# Queries (optional): prepare evaluation queries file
# ----------------------------
def prepare_queries(self, min_realism: float = 0.85) -> Path:
print(
"📝 Preparing evaluation queries from HuggingFace dataset corbt/enron_emails_sample_questions ..."
)
try:
from datasets import load_dataset
ds = load_dataset("corbt/enron_emails_sample_questions", split="train")
except Exception as e:
print(f"⚠️ Failed to load dataset: {e}")
return self.queries_file
kept = 0
with open(self.queries_file, "w", encoding="utf-8") as out:
for i, item in enumerate(ds):
how_realistic = float(item.get("how_realistic", 0.0))
if how_realistic < min_realism:
continue
qid = str(item.get("id", f"enron_q_{i}"))
query = item.get("question", "")
if not query:
continue
record = {
"id": qid,
"query": query,
# For reference only, not used in recall metric below
"gt_message_ids": item.get("message_ids", []),
}
out.write(json.dumps(record) + "\n")
kept += 1
print(f"✅ Wrote {kept} queries to {self.queries_file}")
return self.queries_file
def main():
parser = argparse.ArgumentParser(description="Setup Enron Emails Benchmark")
parser.add_argument(
"--emails-csv",
help="Path to emails.csv (Enron dataset). If omitted, attempt Kaggle download.",
)
parser.add_argument("--data-dir", default="data", help="Data directory")
parser.add_argument("--backend", choices=["hnsw", "diskann"], default="hnsw")
parser.add_argument(
"--embedding-model",
default="sentence-transformers/all-mpnet-base-v2",
help="Embedding model for LEANN",
)
parser.add_argument("--chunk-words", type=int, default=256, help="Fixed word chunk size")
parser.add_argument("--max-emails", type=int, help="Limit number of emails to process")
parser.add_argument("--skip-queries", action="store_true", help="Skip creating queries file")
parser.add_argument("--skip-build", action="store_true", help="Skip building LEANN index")
args = parser.parse_args()
setup = EnronSetup(args.data_dir)
# Build index
if not args.skip_build:
index_path = setup.build_leann_index(
emails_csv=args.emails_csv,
backend=args.backend,
embedding_model=args.embedding_model,
chunk_words=args.chunk_words,
max_emails=args.max_emails,
)
# Build FAISS baseline from the same passages & embeddings
setup.build_faiss_flat_baseline(index_path)
else:
print("⏭️ Skipping LEANN index build and baseline")
# Queries file (optional)
if not args.skip_queries:
setup.prepare_queries()
else:
print("⏭️ Skipping query preparation")
print("\n🎉 Enron Emails setup completed!")
print(f"📁 Data directory: {setup.data_dir.absolute()}")
print("Next steps:")
print(
"1) Evaluate recall: python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 2"
)
if __name__ == "__main__":
main()

View File

@@ -1,115 +0,0 @@
# FinanceBench Benchmark for LEANN-RAG
FinanceBench is a benchmark for evaluating retrieval-augmented generation (RAG) systems on financial document question-answering tasks.
## Dataset
- **Source**: [PatronusAI/financebench](https://huggingface.co/datasets/PatronusAI/financebench)
- **Questions**: 150 financial Q&A examples
- **Documents**: 368 PDF files (10-K, 10-Q, 8-K, earnings reports)
- **Companies**: Major public companies (3M, Apple, Microsoft, Amazon, etc.)
- **Paper**: [FinanceBench: A New Benchmark for Financial Question Answering](https://arxiv.org/abs/2311.11944)
## Structure
```
benchmarks/financebench/
├── setup_financebench.py # Downloads PDFs and builds index
├── evaluate_financebench.py # Intelligent evaluation script
├── data/
│ ├── financebench_merged.jsonl # Q&A dataset
│ ├── pdfs/ # Downloaded financial documents
│ └── index/ # LEANN indexes
│ └── financebench_full_hnsw.leann
└── README.md
```
## Usage
### 1. Setup (Download & Build Index)
```bash
cd benchmarks/financebench
python setup_financebench.py
```
This will:
- Download the 150 Q&A examples
- Download all 368 PDF documents (parallel processing)
- Build a LEANN index from 53K+ text chunks
- Verify the setup with a test query
### 2. Evaluation
```bash
# Basic retrieval evaluation
python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann
# RAG generation evaluation with Qwen3-8B
python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann --stage 4 --complexity 64 --llm-backend hf --model-name Qwen/Qwen3-8B --output results_qwen3.json
```
## Evaluation Methods
### Retrieval Evaluation
Retrieval evaluation uses intelligent matching with three strategies (a rough sketch follows the list):
1. **Exact text overlap** - Direct substring matches
2. **Number matching** - Key financial figures ($1,577, 1.2B, etc.)
3. **Semantic similarity** - Word overlap with 20% threshold
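
For orientation, the following is a minimal sketch of how these three checks could be implemented. The helper names, the number regex, and the normalization rules are illustrative assumptions, not the exact logic in `evaluate_financebench.py`.

```python
import re


def extract_numbers(text: str) -> set[str]:
    # Normalize common financial formatting ("$1,577" -> "1577") before extracting figures
    cleaned = text.replace("$", "").replace(",", "")
    return set(re.findall(r"\d+(?:\.\d+)?[BMK]?", cleaned))


def match_evidence(chunk: str, evidence: str, overlap_threshold: float = 0.2) -> dict:
    # 1) Exact text overlap: direct substring match
    exact = evidence.lower() in chunk.lower()
    # 2) Number matching: any key figure from the evidence appears in the chunk
    number = bool(extract_numbers(evidence) & extract_numbers(chunk))
    # 3) Semantic similarity: word-overlap ratio against the evidence, 20% threshold
    ev_words, chunk_words = set(evidence.lower().split()), set(chunk.lower().split())
    semantic = len(ev_words & chunk_words) / max(len(ev_words), 1) >= overlap_threshold
    return {"exact": exact, "number": number, "semantic": semantic}
```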
### QA Evaluation
LLM-based answer evaluation using GPT-4o:
- Handles numerical rounding and equivalent representations
- Considers fractions, percentages, and decimal equivalents
- Evaluates semantic meaning rather than exact text match
## Benchmark Results
### LEANN-RAG Performance (sentence-transformers/all-mpnet-base-v2)
**Retrieval Metrics:**
- **Question Coverage**: 100.0% (all questions retrieve relevant docs)
- **Exact Match Rate**: 0.7% (substring overlap with evidence)
- **Number Match Rate**: 120.7% (key financial figures matched)*
- **Semantic Match Rate**: 4.7% (word overlap ≥20%)
- **Average Search Time**: 0.097s
**QA Metrics:**
- **Accuracy**: 42.7% (LLM-evaluated answer correctness)
- **Average QA Time**: 4.71s (end-to-end response time)
**System Performance:**
- **Index Contents**: 53,985 chunks from 368 PDFs
- **Build Time**: ~5-10 minutes with sentence-transformers/all-mpnet-base-v2
*Note: Number match rate >100% indicates multiple retrieved documents contain the same financial figures, which is expected behavior for financial data appearing across multiple document sections.
### LEANN-RAG Generation Performance (Qwen3-8B)
- **Stage 4 (Index Comparison):**
- Compact Index: 5.0 MB
- Non-compact Index: 172.2 MB
- **Storage Saving**: 97.1%
- **Search Performance**:
- Non-compact (no recompute): 0.009s avg per query
- Compact (with recompute): 2.203s avg per query
- Speed ratio: 0.004x
**Generation Evaluation (20 queries, complexity=64):**
- **Average Search Time**: 1.638s per query
- **Average Generation Time**: 45.957s per query
- **LLM Backend**: HuggingFace transformers
- **Model**: Qwen/Qwen3-8B (thinking model with <think></think> processing)
- **Total Questions Processed**: 20
## Options
```bash
# Use different backends
python setup_financebench.py --backend diskann
python evaluate_financebench.py --index data/index/financebench_full_diskann.leann
# Use different embedding models
python setup_financebench.py --embedding-model facebook/contriever
```

View File

@@ -1,923 +0,0 @@
"""
FinanceBench Evaluation Script - Modular Recall-based Evaluation
"""
import argparse
import json
import logging
import os
import pickle
import time
from pathlib import Path
from typing import Optional
import numpy as np
import openai
from leann import LeannChat, LeannSearcher
from leann_backend_hnsw import faiss
from ..llm_utils import evaluate_rag, generate_hf, generate_vllm, load_hf_model, load_vllm_model
# Setup logging to reduce verbose output
logging.basicConfig(level=logging.WARNING)
logging.getLogger("leann.api").setLevel(logging.WARNING)
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)
class RecallEvaluator:
"""Stage 2: Evaluate Recall@3 (searcher vs baseline)"""
def __init__(self, index_path: str, baseline_dir: str):
self.index_path = index_path
self.baseline_dir = baseline_dir
self.searcher = LeannSearcher(index_path)
# Load FAISS flat baseline
baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
metadata_path = os.path.join(baseline_dir, "metadata.pkl")
self.faiss_index = faiss.read_index(baseline_index_path)
with open(metadata_path, "rb") as f:
self.passage_ids = pickle.load(f)
print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} vectors")
def evaluate_recall_at_3(
self, queries: list[str], complexity: int = 64, recompute_embeddings: bool = True
) -> float:
"""Evaluate recall@3 for given queries at specified complexity"""
recompute_str = "with recompute" if recompute_embeddings else "no recompute"
print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")
total_recall = 0.0
num_queries = len(queries)
for i, query in enumerate(queries):
# Get ground truth: search with FAISS flat
from leann.api import compute_embeddings
query_embedding = compute_embeddings(
[query],
self.searcher.embedding_model,
mode=self.searcher.embedding_mode,
use_server=False,
).astype(np.float32)
# Search FAISS flat for ground truth using LEANN's modified faiss API
n = query_embedding.shape[0] # Number of queries
k = 3 # Number of nearest neighbors
distances = np.zeros((n, k), dtype=np.float32)
labels = np.zeros((n, k), dtype=np.int64)
self.faiss_index.search(
n,
faiss.swig_ptr(query_embedding),
k,
faiss.swig_ptr(distances),
faiss.swig_ptr(labels),
)
# Extract the results
baseline_ids = {self.passage_ids[idx] for idx in labels[0]}
# Search with LEANN at specified complexity
test_results = self.searcher.search(
query,
top_k=3,
complexity=complexity,
recompute_embeddings=recompute_embeddings,
)
test_ids = {result.id for result in test_results}
# Calculate recall@3 = |intersection| / |ground_truth|
intersection = test_ids.intersection(baseline_ids)
recall = len(intersection) / 3.0 # Ground truth size is 3
total_recall += recall
if i < 3: # Show first few examples
print(f" Query {i + 1}: '{query[:50]}...' -> Recall@3: {recall:.3f}")
print(f" FAISS ground truth: {list(baseline_ids)}")
print(f" LEANN results (C={complexity}, {recompute_str}): {list(test_ids)}")
print(f" Intersection: {list(intersection)}")
avg_recall = total_recall / num_queries
print(f"📊 Average Recall@3: {avg_recall:.3f} ({avg_recall * 100:.1f}%)")
return avg_recall
def cleanup(self):
"""Cleanup resources"""
if hasattr(self, "searcher"):
self.searcher.cleanup()
class FinanceBenchEvaluator:
def __init__(self, index_path: str, openai_api_key: Optional[str] = None):
self.index_path = index_path
self.openai_client = openai.OpenAI(api_key=openai_api_key) if openai_api_key else None
self.searcher = LeannSearcher(index_path)
self.chat = LeannChat(index_path) if openai_api_key else None
def load_dataset(self, dataset_path: str = "data/financebench_merged.jsonl"):
"""Load FinanceBench dataset"""
data = []
with open(dataset_path, encoding="utf-8") as f:
for line in f:
if line.strip():
data.append(json.loads(line))
print(f"📊 Loaded {len(data)} FinanceBench examples")
return data
def analyze_index_sizes(self) -> dict:
"""Analyze index sizes with and without embeddings"""
print("📏 Analyzing index sizes...")
# Get all index-related files
index_path = Path(self.index_path)
index_dir = index_path.parent
index_name = index_path.stem # Remove .leann extension
sizes = {}
total_with_embeddings = 0
# Core index files
index_file = index_dir / f"{index_name}.index"
meta_file = index_dir / f"{index_path.name}.meta.json" # Keep .leann for meta file
passages_file = index_dir / f"{index_path.name}.passages.jsonl" # Keep .leann for passages
passages_idx_file = index_dir / f"{index_path.name}.passages.idx" # Keep .leann for idx
for file_path, name in [
(index_file, "index"),
(meta_file, "metadata"),
(passages_file, "passages_text"),
(passages_idx_file, "passages_index"),
]:
if file_path.exists():
size_mb = file_path.stat().st_size / (1024 * 1024)
sizes[name] = size_mb
total_with_embeddings += size_mb
else:
sizes[name] = 0
sizes["total_with_embeddings"] = total_with_embeddings
sizes["index_only_mb"] = sizes["index"] # Just the .index file for fair comparison
print(f" 📁 Total index size: {total_with_embeddings:.1f} MB")
print(f" 📁 Index file only: {sizes['index']:.1f} MB")
return sizes
def create_compact_index_for_comparison(self, compact_index_path: str) -> dict:
"""Create a compact index for comparison purposes"""
print("🏗️ Building compact index from existing passages...")
# Load existing passages from current index
from leann import LeannBuilder
current_index_path = Path(self.index_path)
current_index_dir = current_index_path.parent
current_index_name = current_index_path.name
# Read metadata to get passage source
meta_path = current_index_dir / f"{current_index_name}.meta.json"
with open(meta_path) as f:
import json
meta = json.load(f)
passage_source = meta["passage_sources"][0]
passage_file = passage_source["path"]
# Convert relative path to absolute
if not Path(passage_file).is_absolute():
passage_file = current_index_dir / Path(passage_file).name
print(f"📄 Loading passages from {passage_file}...")
# Build compact index with same passages
        builder = LeannBuilder(
            backend_name="hnsw",
            embedding_model=meta["embedding_model"],
            embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
            is_recompute=True,  # Enable recompute (no stored embeddings)
            is_compact=True,  # Enable compact storage
            **{
                k: v
                for k, v in meta.get("backend_kwargs", {}).items()
                if k not in ["is_recompute", "is_compact"]
            },
        )
# Load all passages
with open(passage_file, encoding="utf-8") as f:
for line in f:
if line.strip():
data = json.loads(line)
builder.add_text(data["text"], metadata=data.get("metadata", {}))
print(f"🔨 Building compact index at {compact_index_path}...")
builder.build_index(compact_index_path)
# Analyze the compact index size
temp_evaluator = FinanceBenchEvaluator(compact_index_path)
compact_sizes = temp_evaluator.analyze_index_sizes()
compact_sizes["index_type"] = "compact"
return compact_sizes
def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
"""Create a non-compact index for comparison purposes"""
print("🏗️ Building non-compact index from existing passages...")
# Load existing passages from current index
from leann import LeannBuilder
current_index_path = Path(self.index_path)
current_index_dir = current_index_path.parent
current_index_name = current_index_path.name
# Read metadata to get passage source
meta_path = current_index_dir / f"{current_index_name}.meta.json"
with open(meta_path) as f:
import json
meta = json.load(f)
passage_source = meta["passage_sources"][0]
passage_file = passage_source["path"]
# Convert relative path to absolute
if not Path(passage_file).is_absolute():
passage_file = current_index_dir / Path(passage_file).name
print(f"📄 Loading passages from {passage_file}...")
# Build non-compact index with same passages
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=meta["embedding_model"],
embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
is_recompute=False, # Disable recompute (store embeddings)
is_compact=False, # Disable compact storage
**{
k: v
for k, v in meta.get("backend_kwargs", {}).items()
if k not in ["is_recompute", "is_compact"]
},
)
# Load all passages
with open(passage_file, encoding="utf-8") as f:
for line in f:
if line.strip():
data = json.loads(line)
builder.add_text(data["text"], metadata=data.get("metadata", {}))
print(f"🔨 Building non-compact index at {non_compact_index_path}...")
builder.build_index(non_compact_index_path)
# Analyze the non-compact index size
temp_evaluator = FinanceBenchEvaluator(non_compact_index_path)
non_compact_sizes = temp_evaluator.analyze_index_sizes()
non_compact_sizes["index_type"] = "non_compact"
return non_compact_sizes
def compare_index_performance(
self, non_compact_path: str, compact_path: str, test_data: list, complexity: int
) -> dict:
"""Compare performance between non-compact and compact indexes"""
print("⚡ Comparing search performance between indexes...")
import time
from leann import LeannSearcher
# Test queries
test_queries = [item["question"] for item in test_data[:5]]
results = {
"non_compact": {"search_times": []},
"compact": {"search_times": []},
"avg_search_times": {},
"speed_ratio": 0.0,
}
# Test non-compact index (no recompute)
print(" 🔍 Testing non-compact index (no recompute)...")
non_compact_searcher = LeannSearcher(non_compact_path)
for query in test_queries:
start_time = time.time()
_ = non_compact_searcher.search(
query, top_k=3, complexity=complexity, recompute_embeddings=False
)
search_time = time.time() - start_time
results["non_compact"]["search_times"].append(search_time)
# Test compact index (with recompute)
print(" 🔍 Testing compact index (with recompute)...")
compact_searcher = LeannSearcher(compact_path)
for query in test_queries:
start_time = time.time()
_ = compact_searcher.search(
query, top_k=3, complexity=complexity, recompute_embeddings=True
)
search_time = time.time() - start_time
results["compact"]["search_times"].append(search_time)
# Calculate averages
results["avg_search_times"]["non_compact"] = sum(
results["non_compact"]["search_times"]
) / len(results["non_compact"]["search_times"])
results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
results["compact"]["search_times"]
)
# Performance ratio
if results["avg_search_times"]["compact"] > 0:
results["speed_ratio"] = (
results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
)
else:
results["speed_ratio"] = float("inf")
print(
f" Non-compact (no recompute): {results['avg_search_times']['non_compact']:.3f}s avg"
)
print(f" Compact (with recompute): {results['avg_search_times']['compact']:.3f}s avg")
print(f" Speed ratio: {results['speed_ratio']:.2f}x")
# Cleanup
non_compact_searcher.cleanup()
compact_searcher.cleanup()
return results
def evaluate_timing_breakdown(
self, data: list[dict], max_samples: Optional[int] = None
) -> dict:
"""Evaluate timing breakdown and accuracy by hacking LeannChat.ask() for separated timing"""
if not self.chat or not self.openai_client:
print("⚠️ Skipping timing evaluation (no OpenAI API key provided)")
return {
"total_questions": 0,
"avg_search_time": 0.0,
"avg_generation_time": 0.0,
"avg_total_time": 0.0,
"accuracy": 0.0,
}
print("🔍🤖 Evaluating timing breakdown and accuracy (search + generation)...")
if max_samples:
data = data[:max_samples]
print(f"📝 Using first {max_samples} samples for timing evaluation")
search_times = []
generation_times = []
total_times = []
correct_answers = 0
for i, item in enumerate(data):
question = item["question"]
ground_truth = item["answer"]
try:
# Hack: Monkey-patch the ask method to capture internal timing
original_ask = self.chat.ask
captured_search_time = None
captured_generation_time = None
def patched_ask(*args, **kwargs):
nonlocal captured_search_time, captured_generation_time
# Time the search part
search_start = time.time()
results = self.chat.searcher.search(args[0], top_k=3, complexity=64)
captured_search_time = time.time() - search_start
# Time the generation part
context = "\n\n".join([r.text for r in results])
prompt = (
"Here is some retrieved context that might help answer your question:\n\n"
f"{context}\n\n"
f"Question: {args[0]}\n\n"
"Please provide the best answer you can based on this context and your knowledge."
)
generation_start = time.time()
answer = self.chat.llm.ask(prompt)
captured_generation_time = time.time() - generation_start
return answer
# Apply the patch
self.chat.ask = patched_ask
# Time the total QA
total_start = time.time()
generated_answer = self.chat.ask(question)
total_time = time.time() - total_start
# Restore original method
self.chat.ask = original_ask
# Store the timings
search_times.append(captured_search_time)
generation_times.append(captured_generation_time)
total_times.append(total_time)
# Check accuracy using LLM as judge
is_correct = self._check_answer_accuracy(generated_answer, ground_truth, question)
if is_correct:
correct_answers += 1
status = "" if is_correct else ""
print(
f"Question {i + 1}/{len(data)}: {status} Search={captured_search_time:.3f}s, Gen={captured_generation_time:.3f}s, Total={total_time:.3f}s"
)
print(f" GT: {ground_truth}")
print(f" Gen: {generated_answer[:100]}...")
except Exception as e:
print(f" ❌ Error: {e}")
search_times.append(0.0)
generation_times.append(0.0)
total_times.append(0.0)
accuracy = correct_answers / len(data) if data else 0.0
metrics = {
"total_questions": len(data),
"avg_search_time": sum(search_times) / len(search_times) if search_times else 0.0,
"avg_generation_time": sum(generation_times) / len(generation_times)
if generation_times
else 0.0,
"avg_total_time": sum(total_times) / len(total_times) if total_times else 0.0,
"accuracy": accuracy,
"correct_answers": correct_answers,
"search_times": search_times,
"generation_times": generation_times,
"total_times": total_times,
}
return metrics
def _check_answer_accuracy(
self, generated_answer: str, ground_truth: str, question: str
) -> bool:
"""Check if generated answer matches ground truth using LLM as judge"""
judge_prompt = f"""You are an expert judge evaluating financial question answering.
Question: {question}
Ground Truth Answer: {ground_truth}
Generated Answer: {generated_answer}
Task: Determine if the generated answer is factually correct compared to the ground truth. Focus on:
1. Numerical accuracy (exact values, units, currency)
2. Key financial concepts and terminology
3. Overall factual correctness
For financial data, small formatting differences are OK (e.g., "$1,577" vs "1577 million" vs "$1.577 billion"), but the core numerical value must match.
Respond with exactly one word: "CORRECT" if the generated answer is factually accurate, or "INCORRECT" if it's wrong or significantly different."""
try:
judge_response = self.openai_client.chat.completions.create(
model="gpt-4o-mini",
messages=[{"role": "user", "content": judge_prompt}],
max_tokens=10,
temperature=0,
)
judgment = judge_response.choices[0].message.content.strip().upper()
return judgment == "CORRECT"
except Exception as e:
print(f" ⚠️ Judge error: {e}, falling back to string matching")
# Fallback to simple string matching
gen_clean = generated_answer.strip().lower().replace("$", "").replace(",", "")
gt_clean = ground_truth.strip().lower().replace("$", "").replace(",", "")
return gt_clean in gen_clean
def _print_results(self, timing_metrics: dict):
"""Print evaluation results"""
print("\n🎯 EVALUATION RESULTS")
print("=" * 50)
# Index comparison analysis
if "current_index" in timing_metrics and "non_compact_index" in timing_metrics:
print("\n📏 Index Comparison Analysis:")
current = timing_metrics["current_index"]
non_compact = timing_metrics["non_compact_index"]
print(f" Compact index (current): {current.get('total_with_embeddings', 0):.1f} MB")
print(
f" Non-compact index (with embeddings): {non_compact.get('total_with_embeddings', 0):.1f} MB"
)
print(
f" Storage saving by compact: {timing_metrics.get('storage_saving_percent', 0):.1f}%"
)
print(" Component breakdown (non-compact):")
print(f" - Main index: {non_compact.get('index', 0):.1f} MB")
print(f" - Passages text: {non_compact.get('passages_text', 0):.1f} MB")
print(f" - Passages index: {non_compact.get('passages_index', 0):.1f} MB")
print(f" - Metadata: {non_compact.get('metadata', 0):.1f} MB")
# Performance comparison
if "performance_comparison" in timing_metrics:
perf = timing_metrics["performance_comparison"]
print("\n⚡ Performance Comparison:")
print(
f" Non-compact (no recompute): {perf.get('avg_search_times', {}).get('non_compact', 0):.3f}s avg"
)
print(
f" Compact (with recompute): {perf.get('avg_search_times', {}).get('compact', 0):.3f}s avg"
)
print(f" Speed ratio: {perf.get('speed_ratio', 0):.2f}x")
# Legacy single index analysis (fallback)
if "total_with_embeddings" in timing_metrics and "current_index" not in timing_metrics:
print("\n📏 Index Size Analysis:")
print(f" Total index size: {timing_metrics.get('total_with_embeddings', 0):.1f} MB")
print("\n📊 Accuracy:")
print(f" Accuracy: {timing_metrics.get('accuracy', 0) * 100:.1f}%")
print(
f" Correct Answers: {timing_metrics.get('correct_answers', 0)}/{timing_metrics.get('total_questions', 0)}"
)
print("\n📊 Timing Breakdown:")
print(f" Total Questions: {timing_metrics.get('total_questions', 0)}")
print(f" Avg Search Time: {timing_metrics.get('avg_search_time', 0):.3f}s")
print(f" Avg Generation Time: {timing_metrics.get('avg_generation_time', 0):.3f}s")
print(f" Avg Total Time: {timing_metrics.get('avg_total_time', 0):.3f}s")
if timing_metrics.get("avg_total_time", 0) > 0:
search_pct = (
timing_metrics.get("avg_search_time", 0)
/ timing_metrics.get("avg_total_time", 1)
* 100
)
gen_pct = (
timing_metrics.get("avg_generation_time", 0)
/ timing_metrics.get("avg_total_time", 1)
* 100
)
print("\n📈 Time Distribution:")
print(f" Search: {search_pct:.1f}%")
print(f" Generation: {gen_pct:.1f}%")
def cleanup(self):
"""Cleanup resources"""
if self.searcher:
self.searcher.cleanup()
def main():
parser = argparse.ArgumentParser(description="Modular FinanceBench Evaluation")
parser.add_argument("--index", required=True, help="Path to LEANN index")
parser.add_argument("--dataset", default="data/financebench_merged.jsonl", help="Dataset path")
parser.add_argument(
"--stage",
choices=["2", "3", "4", "all"],
default="all",
help="Which stage to run (2=recall, 3=complexity, 4=generation)",
)
parser.add_argument("--complexity", type=int, default=None, help="Complexity for search")
parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
parser.add_argument("--openai-api-key", help="OpenAI API key for generation evaluation")
parser.add_argument("--output", help="Save results to JSON file")
parser.add_argument(
"--llm-backend", choices=["openai", "hf", "vllm"], default="openai", help="LLM backend"
)
parser.add_argument("--model-name", default="Qwen3-8B", help="Model name for HF/vLLM")
args = parser.parse_args()
try:
# Check if baseline exists
baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
if not os.path.exists(baseline_index_path):
print(f"❌ FAISS baseline not found at {baseline_index_path}")
print("💡 Please run setup_financebench.py first to build the baseline")
exit(1)
if args.stage == "2" or args.stage == "all":
# Stage 2: Recall@3 evaluation
print("🚀 Starting Stage 2: Recall@3 evaluation")
evaluator = RecallEvaluator(args.index, args.baseline_dir)
# Load FinanceBench queries for testing
print("📖 Loading FinanceBench dataset...")
queries = []
with open(args.dataset, encoding="utf-8") as f:
for line in f:
if line.strip():
data = json.loads(line)
queries.append(data["question"])
# Test with more queries for robust measurement
test_queries = queries[:2000]
print(f"🧪 Testing with {len(test_queries)} queries")
# Test with complexity 64
complexity = 64
recall = evaluator.evaluate_recall_at_3(test_queries, complexity)
print(f"📈 Recall@3 at complexity {complexity}: {recall * 100:.1f}%")
evaluator.cleanup()
print("✅ Stage 2 completed!\n")
# Shared non-compact index path for Stage 3 and 4
non_compact_index_path = args.index.replace(".leann", "_noncompact.leann")
complexity = args.complexity
if args.stage == "3" or args.stage == "all":
# Stage 3: Binary search for 90% recall complexity (using non-compact index for speed)
print("🚀 Starting Stage 3: Binary search for 90% recall complexity")
print(
"💡 Creating non-compact index for fast binary search with recompute_embeddings=False"
)
# Create non-compact index for binary search (will be reused in Stage 4)
print("🏗️ Creating non-compact index for binary search...")
evaluator = FinanceBenchEvaluator(args.index)
evaluator.create_non_compact_index_for_comparison(non_compact_index_path)
# Use non-compact index for binary search
binary_search_evaluator = RecallEvaluator(non_compact_index_path, args.baseline_dir)
# Load queries for testing
print("📖 Loading FinanceBench dataset...")
queries = []
with open(args.dataset, encoding="utf-8") as f:
for line in f:
if line.strip():
data = json.loads(line)
queries.append(data["question"])
# Use more queries for robust measurement
test_queries = queries[:200]
print(f"🧪 Testing with {len(test_queries)} queries")
# Binary search for 90% recall complexity (without recompute for speed)
target_recall = 0.9
min_complexity, max_complexity = 1, 32
print(f"🔍 Binary search for {target_recall * 100}% recall complexity...")
print(f"Search range: {min_complexity} to {max_complexity}")
best_complexity = None
best_recall = 0.0
while min_complexity <= max_complexity:
mid_complexity = (min_complexity + max_complexity) // 2
print(
f"\n🧪 Testing complexity {mid_complexity} (no recompute, non-compact index)..."
)
# Use recompute_embeddings=False on non-compact index for fast binary search
recall = binary_search_evaluator.evaluate_recall_at_3(
test_queries, mid_complexity, recompute_embeddings=False
)
print(
f" Complexity {mid_complexity}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%)"
)
if recall >= target_recall:
best_complexity = mid_complexity
best_recall = recall
max_complexity = mid_complexity - 1
print(" ✅ Target reached! Searching for lower complexity...")
else:
min_complexity = mid_complexity + 1
print(" ❌ Below target. Searching for higher complexity...")
if best_complexity is not None:
print("\n🎯 Optimal complexity found!")
print(f" Complexity: {best_complexity}")
print(f" Recall@3: {best_recall:.3f} ({best_recall * 100:.1f}%)")
# Test a few complexities around the optimal one for verification
print("\n🔬 Verification test around optimal complexity:")
verification_complexities = [
max(1, best_complexity - 2),
max(1, best_complexity - 1),
best_complexity,
best_complexity + 1,
best_complexity + 2,
]
for complexity in verification_complexities:
if complexity <= 512: # reasonable upper bound
recall = binary_search_evaluator.evaluate_recall_at_3(
test_queries, complexity, recompute_embeddings=False
)
status = "" if recall >= target_recall else ""
print(f" {status} Complexity {complexity:3d}: {recall * 100:5.1f}%")
# Now test the optimal complexity with compact index and recompute for comparison
print(
f"\n🔄 Testing optimal complexity {best_complexity} on compact index WITH recompute..."
)
compact_evaluator = RecallEvaluator(args.index, args.baseline_dir)
recall_with_recompute = compact_evaluator.evaluate_recall_at_3(
test_queries[:10], best_complexity, recompute_embeddings=True
)
print(
f" ✅ Complexity {best_complexity} (compact index with recompute): {recall_with_recompute * 100:.1f}%"
)
complexity = best_complexity
print(
f" 📊 Recall difference: {abs(best_recall - recall_with_recompute) * 100:.2f}%"
)
compact_evaluator.cleanup()
else:
print(f"\n❌ Could not find complexity achieving {target_recall * 100}% recall")
print("All tested complexities were below target.")
# Cleanup evaluators (keep non-compact index for Stage 4)
binary_search_evaluator.cleanup()
evaluator.cleanup()
print("✅ Stage 3 completed! Non-compact index saved for Stage 4.\n")
if args.stage == "4" or args.stage == "all":
# Stage 4: Comprehensive evaluation with dual index comparison
print("🚀 Starting Stage 4: Comprehensive evaluation with dual index comparison")
# Use FinanceBench evaluator for QA evaluation
evaluator = FinanceBenchEvaluator(
args.index, args.openai_api_key if args.llm_backend == "openai" else None
)
print("📖 Loading FinanceBench dataset...")
data = evaluator.load_dataset(args.dataset)
# Step 1: Analyze current (compact) index
print("\n📏 Analyzing current index (compact, pruned)...")
compact_size_metrics = evaluator.analyze_index_sizes()
compact_size_metrics["index_type"] = "compact"
# Step 2: Use existing non-compact index or create if needed
from pathlib import Path
if Path(non_compact_index_path).exists():
print(
f"\n📁 Using existing non-compact index from Stage 3: {non_compact_index_path}"
)
temp_evaluator = FinanceBenchEvaluator(non_compact_index_path)
non_compact_size_metrics = temp_evaluator.analyze_index_sizes()
non_compact_size_metrics["index_type"] = "non_compact"
else:
print("\n🏗️ Creating non-compact index (with embeddings) for comparison...")
non_compact_size_metrics = evaluator.create_non_compact_index_for_comparison(
non_compact_index_path
)
# Step 3: Compare index sizes
print("\n📊 Index size comparison:")
print(
f" Compact index (current): {compact_size_metrics['total_with_embeddings']:.1f} MB"
)
print(
f" Non-compact index: {non_compact_size_metrics['total_with_embeddings']:.1f} MB"
)
print("\n📊 Index-only size comparison (.index file only):")
print(f" Compact index: {compact_size_metrics['index_only_mb']:.1f} MB")
print(f" Non-compact index: {non_compact_size_metrics['index_only_mb']:.1f} MB")
# Use index-only size for fair comparison (same as Enron emails)
storage_saving = (
(non_compact_size_metrics["index_only_mb"] - compact_size_metrics["index_only_mb"])
/ non_compact_size_metrics["index_only_mb"]
* 100
)
print(f" Storage saving by compact: {storage_saving:.1f}%")
# Step 4: Performance comparison between the two indexes
if complexity is None:
raise ValueError("Complexity is required for performance comparison")
print("\n⚡ Performance comparison between indexes...")
performance_metrics = evaluator.compare_index_performance(
non_compact_index_path, args.index, data[:10], complexity=complexity
)
# Step 5: Generation evaluation
test_samples = 20
print(f"\n🧪 Testing with first {test_samples} samples for generation analysis")
if args.llm_backend == "openai" and args.openai_api_key:
print("🔍🤖 Running OpenAI-based generation evaluation...")
evaluation_start = time.time()
timing_metrics = evaluator.evaluate_timing_breakdown(data[:test_samples])
evaluation_time = time.time() - evaluation_start
else:
print(
f"🔍🤖 Running {args.llm_backend} generation evaluation with {args.model_name}..."
)
try:
# Load LLM
if args.llm_backend == "hf":
tokenizer, model = load_hf_model(args.model_name)
def llm_func(prompt):
return generate_hf(tokenizer, model, prompt)
else: # vllm
llm, sampling_params = load_vllm_model(args.model_name)
def llm_func(prompt):
return generate_vllm(llm, sampling_params, prompt)
# Simple generation evaluation
                    queries = [item["question"] for item in data[:test_samples]]
                    evaluation_start = time.time()
gen_results = evaluate_rag(
evaluator.searcher,
llm_func,
queries,
domain="finance",
complexity=complexity,
)
timing_metrics = {
"total_questions": len(queries),
"avg_search_time": gen_results["avg_search_time"],
"avg_generation_time": gen_results["avg_generation_time"],
"results": gen_results["results"],
}
                    evaluation_time = time.time() - evaluation_start
except Exception as e:
print(f"❌ Generation evaluation failed: {e}")
timing_metrics = {
"total_questions": 0,
"avg_search_time": 0,
"avg_generation_time": 0,
}
evaluation_time = 0
# Combine all metrics
combined_metrics = {
**timing_metrics,
"total_evaluation_time": evaluation_time,
"current_index": compact_size_metrics,
"non_compact_index": non_compact_size_metrics,
"performance_comparison": performance_metrics,
"storage_saving_percent": storage_saving,
}
# Print results
print("\n📊 Generation Results:")
print(f" Total Questions: {timing_metrics.get('total_questions', 0)}")
print(f" Avg Search Time: {timing_metrics.get('avg_search_time', 0):.3f}s")
print(f" Avg Generation Time: {timing_metrics.get('avg_generation_time', 0):.3f}s")
# Save results if requested
if args.output:
print(f"\n💾 Saving results to {args.output}...")
with open(args.output, "w") as f:
json.dump(combined_metrics, f, indent=2, default=str)
print(f"✅ Results saved to {args.output}")
evaluator.cleanup()
print("✅ Stage 4 completed!\n")
if args.stage == "all":
print("🎉 All evaluation stages completed successfully!")
print("\n📋 Summary:")
print(" Stage 2: ✅ Recall@3 evaluation completed")
print(" Stage 3: ✅ Optimal complexity found")
print(" Stage 4: ✅ Generation accuracy & timing evaluation completed")
print("\n🔧 Recommended next steps:")
print(" - Use optimal complexity for best speed/accuracy balance")
print(" - Review accuracy and timing breakdown for performance optimization")
print(" - Run full evaluation on complete dataset if needed")
# Clean up non-compact index after all stages complete
print("\n🧹 Cleaning up temporary non-compact index...")
from pathlib import Path
if Path(non_compact_index_path).exists():
temp_index_dir = Path(non_compact_index_path).parent
temp_index_name = Path(non_compact_index_path).name
for temp_file in temp_index_dir.glob(f"{temp_index_name}*"):
temp_file.unlink()
print(f"✅ Cleaned up {non_compact_index_path}")
else:
print("📝 No temporary index to clean up")
except KeyboardInterrupt:
print("\n⚠️ Evaluation interrupted by user")
exit(1)
except Exception as e:
print(f"\n❌ Stage {args.stage} failed: {e}")
exit(1)
if __name__ == "__main__":
main()

View File

@@ -1,462 +0,0 @@
#!/usr/bin/env python3
"""
FinanceBench Complete Setup Script
Downloads all PDFs and builds full LEANN datastore
"""
import argparse
import os
import re
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from threading import Lock
import pymupdf
import requests
from leann import LeannBuilder, LeannSearcher
from tqdm import tqdm
class FinanceBenchSetup:
def __init__(self, data_dir: str = "data"):
self.base_dir = Path(__file__).parent # benchmarks/financebench/
self.data_dir = self.base_dir / data_dir
self.pdf_dir = self.data_dir / "pdfs"
self.dataset_file = self.data_dir / "financebench_merged.jsonl"
self.index_dir = self.data_dir / "index"
self.download_lock = Lock()
def download_dataset(self):
"""Download the main FinanceBench dataset"""
print("📊 Downloading FinanceBench dataset...")
self.data_dir.mkdir(parents=True, exist_ok=True)
if self.dataset_file.exists():
print(f"✅ Dataset already exists: {self.dataset_file}")
return
url = "https://huggingface.co/datasets/PatronusAI/financebench/raw/main/financebench_merged.jsonl"
response = requests.get(url, stream=True)
response.raise_for_status()
with open(self.dataset_file, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
print(f"✅ Dataset downloaded: {self.dataset_file}")
def get_pdf_list(self):
"""Get list of all PDF files from GitHub"""
print("📋 Fetching PDF list from GitHub...")
response = requests.get(
"https://api.github.com/repos/patronus-ai/financebench/contents/pdfs"
)
response.raise_for_status()
pdf_files = response.json()
print(f"Found {len(pdf_files)} PDF files")
return pdf_files
def download_single_pdf(self, pdf_info, position):
"""Download a single PDF file"""
pdf_name = pdf_info["name"]
pdf_path = self.pdf_dir / pdf_name
# Skip if already downloaded
if pdf_path.exists() and pdf_path.stat().st_size > 0:
return f"{pdf_name} (cached)"
try:
# Download PDF
response = requests.get(pdf_info["download_url"], timeout=60)
response.raise_for_status()
# Write to file
with self.download_lock:
with open(pdf_path, "wb") as f:
f.write(response.content)
return f"{pdf_name} ({len(response.content) // 1024}KB)"
except Exception as e:
return f"{pdf_name}: {e!s}"
def download_all_pdfs(self, max_workers: int = 5):
"""Download all PDF files with parallel processing"""
self.pdf_dir.mkdir(parents=True, exist_ok=True)
pdf_files = self.get_pdf_list()
print(f"📥 Downloading {len(pdf_files)} PDFs with {max_workers} workers...")
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit all download tasks
future_to_pdf = {
executor.submit(self.download_single_pdf, pdf_info, i): pdf_info["name"]
for i, pdf_info in enumerate(pdf_files)
}
# Process completed downloads with progress bar
with tqdm(total=len(pdf_files), desc="Downloading PDFs") as pbar:
for future in as_completed(future_to_pdf):
result = future.result()
pbar.set_postfix_str(result.split()[-1] if "(" in result else "Error")  # success results contain "(cached)" or "(NKB)"
pbar.update(1)
# Verify downloads
downloaded_pdfs = list(self.pdf_dir.glob("*.pdf"))
print(f"✅ Successfully downloaded {len(downloaded_pdfs)}/{len(pdf_files)} PDFs")
# Show any failures
missing_pdfs = []
for pdf_info in pdf_files:
pdf_path = self.pdf_dir / pdf_info["name"]
if not pdf_path.exists() or pdf_path.stat().st_size == 0:
missing_pdfs.append(pdf_info["name"])
if missing_pdfs:
print(f"⚠️ Failed to download {len(missing_pdfs)} PDFs:")
for pdf in missing_pdfs[:5]: # Show first 5
print(f" - {pdf}")
if len(missing_pdfs) > 5:
print(f" ... and {len(missing_pdfs) - 5} more")
def build_leann_index(
self,
backend: str = "hnsw",
embedding_model: str = "sentence-transformers/all-mpnet-base-v2",
):
"""Build LEANN index from all PDFs"""
print(f"🏗️ Building LEANN index with {backend} backend...")
# Check if we have PDFs
pdf_files = list(self.pdf_dir.glob("*.pdf"))
if not pdf_files:
raise RuntimeError("No PDF files found! Run download first.")
print(f"Found {len(pdf_files)} PDF files to process")
start_time = time.time()
# Initialize builder with standard compact configuration
builder = LeannBuilder(
backend_name=backend,
embedding_model=embedding_model,
embedding_mode="sentence-transformers",
graph_degree=32,
complexity=64,
is_recompute=True, # Enable recompute (no stored embeddings)
is_compact=True, # Enable compact storage (pruned)
num_threads=4,
)
# Process PDFs and extract text
total_chunks = 0
failed_pdfs = []
for pdf_path in tqdm(pdf_files, desc="Processing PDFs"):
try:
chunks = self.extract_pdf_text(pdf_path)
for chunk in chunks:
builder.add_text(chunk["text"], metadata=chunk["metadata"])
total_chunks += 1
except Exception as e:
print(f"❌ Failed to process {pdf_path.name}: {e}")
failed_pdfs.append(pdf_path.name)
continue
# Build index in index directory
self.index_dir.mkdir(parents=True, exist_ok=True)
index_path = self.index_dir / f"financebench_full_{backend}.leann"
print(f"🔨 Building index: {index_path}")
builder.build_index(str(index_path))
build_time = time.time() - start_time
print("✅ Index built successfully!")
print(f" 📁 Index path: {index_path}")
print(f" 📊 Total chunks: {total_chunks:,}")
print(f" 📄 Processed PDFs: {len(pdf_files) - len(failed_pdfs)}/{len(pdf_files)}")
print(f" ⏱️ Build time: {build_time:.1f}s")
if failed_pdfs:
print(f" ⚠️ Failed PDFs: {failed_pdfs}")
return str(index_path)
def build_faiss_flat_baseline(self, index_path: str, output_dir: str = "baseline"):
"""Build FAISS flat baseline using the same embeddings as LEANN index"""
print("🔨 Building FAISS Flat baseline...")
import os
import pickle
import numpy as np
from leann.api import compute_embeddings
from leann_backend_hnsw import faiss
os.makedirs(output_dir, exist_ok=True)
baseline_path = os.path.join(output_dir, "faiss_flat.index")
metadata_path = os.path.join(output_dir, "metadata.pkl")
if os.path.exists(baseline_path) and os.path.exists(metadata_path):
print(f"✅ Baseline already exists at {baseline_path}")
return baseline_path
# Read metadata from the built index
meta_path = f"{index_path}.meta.json"
with open(meta_path) as f:
import json
meta = json.loads(f.read())
embedding_model = meta["embedding_model"]
passage_source = meta["passage_sources"][0]
passage_file = passage_source["path"]
# Convert relative path to absolute
if not os.path.isabs(passage_file):
index_dir = os.path.dirname(index_path)
passage_file = os.path.join(index_dir, os.path.basename(passage_file))
print(f"📊 Loading passages from {passage_file}...")
print(f"🤖 Using embedding model: {embedding_model}")
# Load all passages for baseline
passages = []
passage_ids = []
with open(passage_file, encoding="utf-8") as f:
for line in f:
if line.strip():
data = json.loads(line)
passages.append(data["text"])
passage_ids.append(data["id"])
print(f"📄 Loaded {len(passages)} passages")
# Compute embeddings using the same method as LEANN
print("🧮 Computing embeddings...")
embeddings = compute_embeddings(
passages,
embedding_model,
mode="sentence-transformers",
use_server=False,
)
print(f"📐 Embedding shape: {embeddings.shape}")
# Build FAISS flat index
print("🏗️ Building FAISS IndexFlatIP...")
dimension = embeddings.shape[1]
index = faiss.IndexFlatIP(dimension)
# Add embeddings to flat index
embeddings_f32 = embeddings.astype(np.float32)
index.add(embeddings_f32.shape[0], faiss.swig_ptr(embeddings_f32))
# Save index and metadata
faiss.write_index(index, baseline_path)
with open(metadata_path, "wb") as f:
pickle.dump(passage_ids, f)
print(f"✅ FAISS baseline saved to {baseline_path}")
print(f"✅ Metadata saved to {metadata_path}")
print(f"📊 Total vectors: {index.ntotal}")
return baseline_path
def extract_pdf_text(self, pdf_path: Path) -> list[dict]:
"""Extract and chunk text from a PDF file"""
chunks = []
doc = pymupdf.open(pdf_path)
for page_num in range(len(doc)):
page = doc[page_num]
text = page.get_text() # type: ignore
if not text.strip():
continue
# Create metadata
metadata = {
"source_file": pdf_path.name,
"page_number": page_num + 1,
"document_type": "10K" if "10K" in pdf_path.name else "10Q",
"company": pdf_path.name.split("_")[0],
"doc_period": self.extract_year_from_filename(pdf_path.name),
}
# Use recursive character splitting like LangChain
if len(text.split()) > 500:
# Split by double newlines (paragraphs)
paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
current_chunk = ""
for para in paragraphs:
# If adding this paragraph would make chunk too long, save current chunk
if current_chunk and len((current_chunk + " " + para).split()) > 300:
if current_chunk.strip():
chunks.append(
{
"text": current_chunk.strip(),
"metadata": {
**metadata,
"chunk_id": f"page_{page_num + 1}_chunk_{len(chunks)}",
},
}
)
current_chunk = para
else:
current_chunk = (current_chunk + " " + para).strip()
# Add the last chunk
if current_chunk.strip():
chunks.append(
{
"text": current_chunk.strip(),
"metadata": {
**metadata,
"chunk_id": f"page_{page_num + 1}_chunk_{len(chunks)}",
},
}
)
else:
# Page is short enough, use as single chunk
chunks.append(
{
"text": text.strip(),
"metadata": {**metadata, "chunk_id": f"page_{page_num + 1}"},
}
)
doc.close()
return chunks
def extract_year_from_filename(self, filename: str) -> str:
"""Extract year from PDF filename"""
# Try to find 4-digit year in filename
match = re.search(r"(\d{4})", filename)
return match.group(1) if match else "unknown"
def verify_setup(self, index_path: str):
"""Verify the setup by testing a simple query"""
print("🧪 Verifying setup with test query...")
try:
searcher = LeannSearcher(index_path)
# Test query
test_query = "What is the capital expenditure for 3M in 2018?"
results = searcher.search(test_query, top_k=3)
print(f"✅ Test query successful! Found {len(results)} results:")
for i, result in enumerate(results, 1):
company = result.metadata.get("company", "Unknown")
year = result.metadata.get("doc_period", "Unknown")
page = result.metadata.get("page_number", "Unknown")
print(f" {i}. {company} {year} (page {page}) - Score: {result.score:.3f}")
print(f" {result.text[:100]}...")
searcher.cleanup()
print("✅ Setup verification completed successfully!")
except Exception as e:
print(f"❌ Setup verification failed: {e}")
raise
def main():
parser = argparse.ArgumentParser(description="Setup FinanceBench with full PDF datastore")
parser.add_argument("--data-dir", default="data", help="Data directory")
parser.add_argument(
"--backend", choices=["hnsw", "diskann"], default="hnsw", help="LEANN backend"
)
parser.add_argument(
"--embedding-model",
default="sentence-transformers/all-mpnet-base-v2",
help="Embedding model",
)
parser.add_argument("--max-workers", type=int, default=5, help="Parallel download workers")
parser.add_argument("--skip-download", action="store_true", help="Skip PDF download")
parser.add_argument("--skip-build", action="store_true", help="Skip index building")
parser.add_argument(
"--build-baseline-only",
action="store_true",
help="Only build FAISS baseline from existing index",
)
args = parser.parse_args()
print("🏦 FinanceBench Complete Setup")
print("=" * 50)
setup = FinanceBenchSetup(args.data_dir)
try:
if args.build_baseline_only:
# Only build baseline from existing index
index_path = setup.index_dir / f"financebench_full_{args.backend}"
index_file = f"{index_path}.index"
meta_file = f"{index_path}.leann.meta.json"
if not os.path.exists(index_file) or not os.path.exists(meta_file):
print("❌ Index files not found:")
print(f" Index: {index_file}")
print(f" Meta: {meta_file}")
print("💡 Run without --build-baseline-only to build the index first")
exit(1)
print(f"🔨 Building baseline from existing index: {index_path}")
baseline_path = setup.build_faiss_flat_baseline(f"{index_path}.leann")  # pass the .leann path so the metadata lookup matches build_leann_index output
print(f"✅ Baseline built at {baseline_path}")
return
# Step 1: Download dataset
setup.download_dataset()
# Step 2: Download PDFs
if not args.skip_download:
setup.download_all_pdfs(max_workers=args.max_workers)
else:
print("⏭️ Skipping PDF download")
# Step 3: Build LEANN index
if not args.skip_build:
index_path = setup.build_leann_index(
backend=args.backend, embedding_model=args.embedding_model
)
# Step 4: Build FAISS flat baseline
print("\n🔨 Building FAISS flat baseline...")
baseline_path = setup.build_faiss_flat_baseline(index_path)
print(f"✅ Baseline built at {baseline_path}")
# Step 5: Verify setup
setup.verify_setup(index_path)
else:
print("⏭️ Skipping index building")
print("\n🎉 FinanceBench setup completed!")
print(f"📁 Data directory: {setup.data_dir.absolute()}")
print("\nNext steps:")
print(
"1. Run evaluation: python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann"
)
print(
"2. Or test manually: python -c \"from leann import LeannSearcher; s = LeannSearcher('data/index/financebench_full_hnsw.leann'); print(s.search('3M capital expenditure 2018'))\""
)
except KeyboardInterrupt:
print("\n⚠️ Setup interrupted by user")
exit(1)
except Exception as e:
print(f"\n❌ Setup failed: {e}")
exit(1)
if __name__ == "__main__":
main()

View File

@@ -1,214 +0,0 @@
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.9"
# dependencies = [
# "faiss-cpu",
# "numpy",
# "sentence-transformers",
# "torch",
# "tqdm",
# ]
# ///
"""
Independent recall verification script using standard FAISS.
Creates two indexes (HNSW and Flat) and compares recall@3 at different complexities.
"""
import json
import time
from pathlib import Path
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
def compute_embeddings_direct(chunks: list[str], model_name: str) -> np.ndarray:
"""
Direct embedding computation using sentence-transformers.
Copied logic to avoid dependency issues.
"""
print(f"Loading model: {model_name}")
model = SentenceTransformer(model_name)
print(f"Computing embeddings for {len(chunks)} chunks...")
embeddings = model.encode(
chunks,
show_progress_bar=True,
batch_size=32,
convert_to_numpy=True,
normalize_embeddings=False,
)
return embeddings.astype(np.float32)
def load_financebench_queries(dataset_path: str, max_queries: int = 200) -> list[str]:
"""Load FinanceBench queries from dataset"""
queries = []
with open(dataset_path, encoding="utf-8") as f:
for line in f:
if line.strip():
data = json.loads(line)
queries.append(data["question"])
if len(queries) >= max_queries:
break
return queries
def load_passages_from_leann_index(index_path: str) -> tuple[list[str], list[str]]:
"""Load passages from LEANN index structure"""
meta_path = f"{index_path}.meta.json"
with open(meta_path) as f:
meta = json.load(f)
passage_source = meta["passage_sources"][0]
passage_file = passage_source["path"]
# Convert relative path to absolute
if not Path(passage_file).is_absolute():
index_dir = Path(index_path).parent
passage_file = index_dir / Path(passage_file).name
print(f"Loading passages from {passage_file}")
passages = []
passage_ids = []
with open(passage_file, encoding="utf-8") as f:
for line in tqdm(f, desc="Loading passages"):
if line.strip():
data = json.loads(line)
passages.append(data["text"])
passage_ids.append(data["id"])
print(f"Loaded {len(passages)} passages")
return passages, passage_ids
def build_faiss_indexes(embeddings: np.ndarray) -> tuple[faiss.Index, faiss.Index]:
"""Build FAISS indexes: Flat (ground truth) and HNSW"""
dimension = embeddings.shape[1]
# Build Flat index (ground truth)
print("Building FAISS IndexFlatIP (ground truth)...")
flat_index = faiss.IndexFlatIP(dimension)
flat_index.add(embeddings)
# Build HNSW index
print("Building FAISS IndexHNSWFlat...")
M = 32 # Same as LEANN default
hnsw_index = faiss.IndexHNSWFlat(dimension, M, faiss.METRIC_INNER_PRODUCT)
hnsw_index.hnsw.efConstruction = 200 # Same as LEANN default
hnsw_index.add(embeddings)
print(f"Built indexes with {flat_index.ntotal} vectors, dimension {dimension}")
return flat_index, hnsw_index
def evaluate_recall_at_k(
query_embeddings: np.ndarray,
flat_index: faiss.Index,
hnsw_index: faiss.Index,
passage_ids: list[str],
k: int = 3,
ef_search: int = 64,
) -> float:
"""Evaluate recall@k comparing HNSW vs Flat"""
# Set search parameters for HNSW
hnsw_index.hnsw.efSearch = ef_search
total_recall = 0.0
num_queries = query_embeddings.shape[0]
for i in range(num_queries):
query = query_embeddings[i : i + 1] # Keep 2D shape
# Get ground truth from Flat index (standard FAISS API)
flat_distances, flat_indices = flat_index.search(query, k)
ground_truth_ids = {passage_ids[idx] for idx in flat_indices[0]}
# Get results from HNSW index (standard FAISS API)
hnsw_distances, hnsw_indices = hnsw_index.search(query, k)
hnsw_ids = {passage_ids[idx] for idx in hnsw_indices[0]}
# Calculate recall
intersection = ground_truth_ids.intersection(hnsw_ids)
recall = len(intersection) / k
total_recall += recall
if i < 3: # Show first few examples
print(f" Query {i + 1}: Recall@{k} = {recall:.3f}")
print(f" Flat: {list(ground_truth_ids)}")
print(f" HNSW: {list(hnsw_ids)}")
print(f" Intersection: {list(intersection)}")
avg_recall = total_recall / num_queries
return avg_recall
def main():
# Configuration
dataset_path = "data/financebench_merged.jsonl"
index_path = "data/index/financebench_full_hnsw.leann"
embedding_model = "sentence-transformers/all-mpnet-base-v2"
print("🔍 FAISS Recall Verification")
print("=" * 50)
# Check if files exist
if not Path(dataset_path).exists():
print(f"❌ Dataset not found: {dataset_path}")
return
if not Path(f"{index_path}.meta.json").exists():
print(f"❌ Index metadata not found: {index_path}.meta.json")
return
# Load data
print("📖 Loading FinanceBench queries...")
queries = load_financebench_queries(dataset_path, max_queries=50)
print(f"Loaded {len(queries)} queries")
print("📄 Loading passages from LEANN index...")
passages, passage_ids = load_passages_from_leann_index(index_path)
# Compute embeddings
print("🧮 Computing passage embeddings...")
passage_embeddings = compute_embeddings_direct(passages, embedding_model)
print("🧮 Computing query embeddings...")
query_embeddings = compute_embeddings_direct(queries, embedding_model)
# Build FAISS indexes
print("🏗️ Building FAISS indexes...")
flat_index, hnsw_index = build_faiss_indexes(passage_embeddings)
# Test different efSearch values (equivalent to LEANN complexity)
print("\n📊 Evaluating Recall@3 at different efSearch values...")
ef_search_values = [16, 32, 64, 128, 256]
for ef_search in ef_search_values:
print(f"\n🧪 Testing efSearch = {ef_search}")
start_time = time.time()
recall = evaluate_recall_at_k(
query_embeddings, flat_index, hnsw_index, passage_ids, k=3, ef_search=ef_search
)
elapsed = time.time() - start_time
print(
f"📈 efSearch {ef_search}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%) in {elapsed:.2f}s"
)
print("\n✅ Verification completed!")
print("\n📋 Summary:")
print(" - Built independent FAISS Flat and HNSW indexes")
print(" - Compared recall@3 at different efSearch values")
print(" - Used same embedding model as LEANN")
print(" - This validates LEANN's recall measurements")
if __name__ == "__main__":
main()

View File

@@ -1 +0,0 @@
data/

View File

@@ -1,199 +0,0 @@
# LAION Multimodal Benchmark
A multimodal benchmark that evaluates image retrieval and generation on a LAION subset, using LEANN with CLIP embeddings for retrieval and Qwen2.5-VL for multimodal generation.
## Overview
This benchmark evaluates:
- **Image retrieval timing** using caption-based queries
- **Recall@K performance** for image search
- **Complexity analysis** across different search parameters
- **Index size and storage efficiency**
- **Multimodal generation** with Qwen2.5-VL for image understanding and description
## Dataset Configuration
- **Dataset**: LAION-400M subset (10,000 images)
- **Embeddings**: Pre-computed CLIP ViT-B/32 (512 dimensions)
- **Queries**: 200 random captions from the dataset
- **Ground Truth**: Self-recall (query caption → original image)
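For reference, `data/evaluation_queries.jsonl` stores one JSON object per line, and the evaluation script only reads the `query` field (the caption); a minimal line looks like the following (any additional fields are illustrative assumptions):
```
{"query": "a golden retriever puppy playing in the snow"}
```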
## Quick Start
### 1. Setup the benchmark
```bash
cd benchmarks/laion
python setup_laion.py --num-samples 10000 --num-queries 200
```
This will:
- Create dummy LAION data (10K samples)
- Generate CLIP embeddings (512-dim)
- Build LEANN index with HNSW backend
- Create 200 evaluation queries
### 2. Run evaluation
```bash
# Run all evaluation stages
python evaluate_laion.py --index data/laion_index.leann
# Run specific stages
python evaluate_laion.py --index data/laion_index.leann --stage 2 # Recall evaluation
python evaluate_laion.py --index data/laion_index.leann --stage 3 # Complexity analysis
python evaluate_laion.py --index data/laion_index.leann --stage 4 # Index comparison
python evaluate_laion.py --index data/laion_index.leann --stage 5 # Multimodal generation
# Multimodal generation with Qwen2.5-VL
python evaluate_laion.py --index data/laion_index.leann --stage 5 --model-name Qwen/Qwen2.5-VL-7B-Instruct
```
### 3. Save results
```bash
python evaluate_laion.py --index data/laion_index.leann --output results.json
```
## Configuration Options
### Setup Options
```bash
python setup_laion.py \
--num-samples 10000 \
--num-queries 200 \
--index-path data/laion_index.leann \
--backend hnsw
```
### Evaluation Options
```bash
python evaluate_laion.py \
--index data/laion_index.leann \
--queries data/evaluation_queries.jsonl \
--complexity 64 \
--top-k 3 \
--num-samples 100 \
--stage all
```
## Evaluation Stages
### Stage 2: Recall Evaluation
- Evaluates Recall@3 for multimodal retrieval
- Compares LEANN vs FAISS baseline performance
- Self-recall: query caption should retrieve original image
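In other words, each caption's LEANN top-3 image ids are compared against a FAISS-flat ground truth. A minimal sketch of the metric, assuming `ground_truth_top3` and `leann_top3` are helpers returning sets of image ids:
```python
def recall_at_3(captions, ground_truth_top3, leann_top3):
    """Average |LEANN top-3 ∩ FAISS-flat top-3| / 3 over all caption queries."""
    total = 0.0
    for caption in captions:
        truth = ground_truth_top3(caption)  # set of 3 image ids from the flat index
        found = leann_top3(caption)         # set of 3 image ids from LEANN
        total += len(truth & found) / 3.0
    return total / len(captions)
```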
### Stage 3: Complexity Analysis
- Binary search for optimal complexity (90% recall target)
- Tests performance across different complexity levels
- Analyzes speed vs. accuracy tradeoffs
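The search itself is a standard binary search over the complexity parameter; a minimal sketch, assuming `recall_at_3(complexity)` runs the Stage 2 evaluation at that complexity on the non-compact index:
```python
def find_optimal_complexity(recall_at_3, target=0.90, lo=1, hi=128):
    # Smallest complexity whose Recall@3 reaches the target (None if none does).
    best = None
    while lo <= hi:
        mid = (lo + hi) // 2
        if recall_at_3(mid) >= target:
            best, hi = mid, mid - 1  # target met: try a lower complexity
        else:
            lo = mid + 1             # below target: need a higher complexity
    return best
```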
### Stage 4: Index Comparison
- Compares compact vs non-compact index sizes
- Measures search performance differences
- Reports storage efficiency and speed ratios
### Stage 5: Multimodal Generation
- Uses Qwen2.5-VL for image understanding and description
- Retrieval-Augmented Generation (RAG) with multimodal context
- Measures both search and generation timing
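End to end this stage is a retrieval-augmented generation loop; a rough sketch using the benchmark's `llm_utils` helpers (the call shapes mirror the evaluation script, but treat the exact import path and argument details as assumptions):
```python
from leann import LeannSearcher
from llm_utils import load_qwen_vl_model, evaluate_multimodal_rag  # benchmark helpers

searcher = LeannSearcher("data/laion_index.leann")
processor, model = load_qwen_vl_model("Qwen/Qwen2.5-VL-7B-Instruct")

# Retrieve top-k images per caption, then generate a description with Qwen2.5-VL.
results = evaluate_multimodal_rag(
    searcher,
    ["a golden retriever puppy playing in the snow"],
    processor=processor,
    model=model,
    complexity=64,
)
print(results["avg_search_time"], results["avg_generation_time"])
```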
## Output Metrics
### Timing Metrics
- Average/median/min/max search time
- Standard deviation
- Searches per second
- Latency in milliseconds
### Recall Metrics
- Recall@3 percentage for image retrieval
- Number of queries with ground truth
### Index Metrics
- Total index size (MB)
- Component breakdown (index, passages, metadata)
- Storage savings (compact vs non-compact)
- Backend and embedding model info
### Generation Metrics (Stage 5)
- Average search time per query
- Average generation time per query
- Time distribution (search vs generation)
- Sample multimodal responses
- Model: Qwen2.5-VL performance
## Benchmark Results
### LEANN-RAG Performance (CLIP ViT-L/14 + Qwen2.5-VL)
**Stage 3: Optimal Complexity Analysis**
- **Optimal Complexity**: 85 (achieving 90% Recall@3)
- **Binary Search Range**: 1-128
- **Target Recall**: 90%
- **Index Type**: Non-compact (for fast binary search)
**Stage 5: Multimodal Generation Performance (Qwen2.5-VL)**
- **Total Queries**: 20
- **Average Search Time**: 1.200s per query
- **Average Generation Time**: 6.558s per query
- **Time Distribution**: Search 15.5%, Generation 84.5%
- **LLM Backend**: HuggingFace transformers
- **Model**: Qwen/Qwen2.5-VL-7B-Instruct
- **Optimal Complexity**: 85
**System Performance:**
- **Index Contents**: ~10,000 image embeddings from the LAION subset
- **Embedding Model**: CLIP ViT-L/14 (768 dimensions)
- **Backend**: HNSW with cosine distance
### Example Results
```
🎯 LAION MULTIMODAL BENCHMARK RESULTS
============================================================
📊 Multimodal Generation Results:
Total Queries: 20
Avg Search Time: 1.200s
Avg Generation Time: 6.558s
Time Distribution: Search 15.5%, Generation 84.5%
LLM Backend: HuggingFace transformers
Model: Qwen/Qwen2.5-VL-7B-Instruct
⚙️ Optimal Complexity Analysis:
Target Recall: 90%
Optimal Complexity: 85
Binary Search Range: 1-128
Non-compact Index (fast search, no recompute)
🚀 Performance Summary:
Multimodal RAG: 7.758s total per query
Search: 15.5% of total time
Generation: 84.5% of total time
```
## Directory Structure
```
benchmarks/laion/
├── setup_laion.py # Setup script
├── evaluate_laion.py # Evaluation script
├── README.md # This file
└── data/ # Generated data
├── laion_images/ # Image files (placeholder)
├── laion_metadata.jsonl # Image metadata
├── laion_passages.jsonl # LEANN passages
├── laion_embeddings.npy # CLIP embeddings
├── evaluation_queries.jsonl # Evaluation queries
└── laion_index.leann/ # LEANN index files
```
## Notes
- Current implementation uses dummy data for demonstration
- For real LAION data, implement actual download logic in `setup_laion.py`
- CLIP embeddings are randomly generated - replace with real CLIP model for production
- Adjust `num_samples` and `num_queries` based on available resources
- Consider using `--num-samples` during evaluation for faster testing

View File

@@ -1,725 +0,0 @@
"""
LAION Multimodal Benchmark Evaluation Script - Modular Recall-based Evaluation
"""
import argparse
import json
import logging
import os
import pickle
import time
from pathlib import Path
import numpy as np
from leann import LeannSearcher
from leann_backend_hnsw import faiss
from sentence_transformers import SentenceTransformer
from ..llm_utils import evaluate_multimodal_rag, load_qwen_vl_model
# Setup logging to reduce verbose output
logging.basicConfig(level=logging.WARNING)
logging.getLogger("leann.api").setLevel(logging.WARNING)
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)
class RecallEvaluator:
"""Stage 2: Evaluate Recall@3 (LEANN vs FAISS baseline for multimodal retrieval)"""
def __init__(self, index_path: str, baseline_dir: str):
self.index_path = index_path
self.baseline_dir = baseline_dir
self.searcher = LeannSearcher(index_path)
# Load FAISS flat baseline (image embeddings)
baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
metadata_path = os.path.join(baseline_dir, "metadata.pkl")
self.faiss_index = faiss.read_index(baseline_index_path)
with open(metadata_path, "rb") as f:
self.image_ids = pickle.load(f)
print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} image vectors")
# Load sentence-transformers CLIP for text embedding (ViT-L/14)
self.st_clip = SentenceTransformer("clip-ViT-L-14")
def evaluate_recall_at_3(
self, captions: list[str], complexity: int = 64, recompute_embeddings: bool = True
) -> float:
"""Evaluate recall@3 for multimodal retrieval: caption queries -> image results"""
recompute_str = "with recompute" if recompute_embeddings else "no recompute"
print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")
total_recall = 0.0
num_queries = len(captions)
for i, caption in enumerate(captions):
# Get ground truth: search with FAISS flat using caption text embedding
# Generate CLIP text embedding for caption via sentence-transformers (normalized)
query_embedding = self.st_clip.encode(
[caption], convert_to_numpy=True, normalize_embeddings=True, show_progress_bar=False
).astype(np.float32)
# Search FAISS flat for ground truth using LEANN's modified faiss API
n = query_embedding.shape[0] # Number of queries
k = 3 # Number of nearest neighbors
distances = np.zeros((n, k), dtype=np.float32)
labels = np.zeros((n, k), dtype=np.int64)
self.faiss_index.search(
n,
faiss.swig_ptr(query_embedding),
k,
faiss.swig_ptr(distances),
faiss.swig_ptr(labels),
)
# Extract the results (image IDs from FAISS)
baseline_ids = {self.image_ids[idx] for idx in labels[0]}
# Search with LEANN at specified complexity (using caption as text query)
test_results = self.searcher.search(
caption,
top_k=3,
complexity=complexity,
recompute_embeddings=recompute_embeddings,
)
test_ids = {result.id for result in test_results}
# Calculate recall@3 = |intersection| / |ground_truth|
intersection = test_ids.intersection(baseline_ids)
recall = len(intersection) / 3.0 # Ground truth size is 3
total_recall += recall
if i < 3: # Show first few examples
print(f" Query {i + 1}: '{caption[:50]}...' -> Recall@3: {recall:.3f}")
print(f" FAISS ground truth: {list(baseline_ids)}")
print(f" LEANN results (C={complexity}, {recompute_str}): {list(test_ids)}")
print(f" Intersection: {list(intersection)}")
avg_recall = total_recall / num_queries
print(f"📊 Average Recall@3: {avg_recall:.3f} ({avg_recall * 100:.1f}%)")
return avg_recall
def cleanup(self):
"""Cleanup resources"""
if hasattr(self, "searcher"):
self.searcher.cleanup()
class LAIONEvaluator:
def __init__(self, index_path: str):
self.index_path = index_path
self.searcher = LeannSearcher(index_path)
def load_queries(self, queries_file: str) -> list[str]:
"""Load caption queries from evaluation file"""
captions = []
with open(queries_file, encoding="utf-8") as f:
for line in f:
if line.strip():
query_data = json.loads(line)
captions.append(query_data["query"])
print(f"📊 Loaded {len(captions)} caption queries")
return captions
def analyze_index_sizes(self) -> dict:
"""Analyze index sizes, emphasizing .index only (exclude passages)."""
print("📏 Analyzing index sizes (.index only)...")
# Get all index-related files
index_path = Path(self.index_path)
index_dir = index_path.parent
index_name = index_path.stem # Remove .leann extension
sizes: dict[str, float] = {}
# Core index files
index_file = index_dir / f"{index_name}.index"
meta_file = index_dir / f"{index_path.name}.meta.json" # Keep .leann for meta file
passages_file = index_dir / f"{index_path.name}.passages.jsonl" # Keep .leann for passages
passages_idx_file = index_dir / f"{index_path.name}.passages.idx" # Keep .leann for idx
# Core index size (.index only)
index_mb = index_file.stat().st_size / (1024 * 1024) if index_file.exists() else 0.0
sizes["index_only_mb"] = index_mb
# Other files for reference (not counted in index_only_mb)
sizes["metadata_mb"] = (
meta_file.stat().st_size / (1024 * 1024) if meta_file.exists() else 0.0
)
sizes["passages_text_mb"] = (
passages_file.stat().st_size / (1024 * 1024) if passages_file.exists() else 0.0
)
sizes["passages_index_mb"] = (
passages_idx_file.stat().st_size / (1024 * 1024) if passages_idx_file.exists() else 0.0
)
print(f" 📁 .index size: {index_mb:.1f} MB")
if sizes["metadata_mb"]:
print(f" 🧾 metadata: {sizes['metadata_mb']:.3f} MB")
if sizes["passages_text_mb"] or sizes["passages_index_mb"]:
print(
f" (passages excluded) text: {sizes['passages_text_mb']:.1f} MB, idx: {sizes['passages_index_mb']:.1f} MB"
)
return sizes
def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
"""Create a non-compact index for comparison purposes"""
print("🏗️ Building non-compact index from existing passages...")
# Load existing passages from current index
from leann import LeannBuilder
current_index_path = Path(self.index_path)
current_index_dir = current_index_path.parent
current_index_name = current_index_path.name
# Read metadata to get passage source
meta_path = current_index_dir / f"{current_index_name}.meta.json"
with open(meta_path) as f:
meta = json.load(f)
passage_source = meta["passage_sources"][0]
passage_file = passage_source["path"]
# Convert relative path to absolute
if not Path(passage_file).is_absolute():
passage_file = current_index_dir / Path(passage_file).name
print(f"📄 Loading passages from {passage_file}...")
# Load CLIP embeddings
embeddings_file = current_index_dir / "clip_image_embeddings.npy"
embeddings = np.load(embeddings_file)
print(f"📐 Loaded embeddings shape: {embeddings.shape}")
# Build non-compact index with same passages and embeddings
builder = LeannBuilder(
backend_name="hnsw",
# Use CLIP text encoder (ViT-L/14) to match image embeddings (768-dim)
embedding_model="clip-ViT-L-14",
embedding_mode="sentence-transformers",
is_recompute=False, # Disable recompute (store embeddings)
is_compact=False, # Disable compact storage
distance_metric="cosine",
**{
k: v
for k, v in meta.get("backend_kwargs", {}).items()
if k not in ["is_recompute", "is_compact", "distance_metric"]
},
)
# Prepare ids and add passages
ids: list[str] = []
with open(passage_file, encoding="utf-8") as f:
for line in f:
if line.strip():
data = json.loads(line)
ids.append(str(data["id"]))
# Ensure metadata contains the id used by the vector index
metadata = {**data.get("metadata", {}), "id": data["id"]}
builder.add_text(text=data["text"], metadata=metadata)
if len(ids) != embeddings.shape[0]:
raise ValueError(
f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
)
# Persist a pickle for build_index_from_embeddings
pkl_path = current_index_dir / "clip_image_embeddings.pkl"
with open(pkl_path, "wb") as pf:
pickle.dump((ids, embeddings.astype(np.float32)), pf)
print(
f"🔨 Building non-compact index at {non_compact_index_path} from precomputed embeddings..."
)
builder.build_index_from_embeddings(non_compact_index_path, str(pkl_path))
# Analyze the non-compact index size
temp_evaluator = LAIONEvaluator(non_compact_index_path)
non_compact_sizes = temp_evaluator.analyze_index_sizes()
non_compact_sizes["index_type"] = "non_compact"
return non_compact_sizes
def compare_index_performance(
self, non_compact_path: str, compact_path: str, test_captions: list, complexity: int
) -> dict:
"""Compare performance between non-compact and compact indexes"""
print("⚡ Comparing search performance between indexes...")
# Test queries
test_queries = test_captions[:5]
results = {
"non_compact": {"search_times": []},
"compact": {"search_times": []},
"avg_search_times": {},
"speed_ratio": 0.0,
}
# Test non-compact index (no recompute)
print(" 🔍 Testing non-compact index (no recompute)...")
non_compact_searcher = LeannSearcher(non_compact_path)
for caption in test_queries:
start_time = time.time()
_ = non_compact_searcher.search(
caption, top_k=3, complexity=complexity, recompute_embeddings=False
)
search_time = time.time() - start_time
results["non_compact"]["search_times"].append(search_time)
# Test compact index (with recompute)
print(" 🔍 Testing compact index (with recompute)...")
compact_searcher = LeannSearcher(compact_path)
for caption in test_queries:
start_time = time.time()
_ = compact_searcher.search(
caption, top_k=3, complexity=complexity, recompute_embeddings=True
)
search_time = time.time() - start_time
results["compact"]["search_times"].append(search_time)
# Calculate averages
results["avg_search_times"]["non_compact"] = sum(
results["non_compact"]["search_times"]
) / len(results["non_compact"]["search_times"])
results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
results["compact"]["search_times"]
)
# Performance ratio
if results["avg_search_times"]["compact"] > 0:
results["speed_ratio"] = (
results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
)
else:
results["speed_ratio"] = float("inf")
print(
f" Non-compact (no recompute): {results['avg_search_times']['non_compact']:.3f}s avg"
)
print(f" Compact (with recompute): {results['avg_search_times']['compact']:.3f}s avg")
print(f" Speed ratio: {results['speed_ratio']:.2f}x")
# Cleanup
non_compact_searcher.cleanup()
compact_searcher.cleanup()
return results
def _print_results(self, timing_metrics: dict):
"""Print evaluation results"""
print("\n🎯 LAION MULTIMODAL BENCHMARK RESULTS")
print("=" * 60)
# Index comparison analysis (prefer .index-only view if present)
if "current_index" in timing_metrics and "non_compact_index" in timing_metrics:
current = timing_metrics["current_index"]
non_compact = timing_metrics["non_compact_index"]
if "index_only_mb" in current and "index_only_mb" in non_compact:
print("\n📏 Index Comparison Analysis (.index only):")
print(f" Compact index (current): {current.get('index_only_mb', 0):.1f} MB")
print(f" Non-compact index: {non_compact.get('index_only_mb', 0):.1f} MB")
print(
f" Storage saving by compact: {timing_metrics.get('storage_saving_percent', 0):.1f}%"
)
# Show excluded components for reference if available
if any(
k in non_compact
for k in ("passages_text_mb", "passages_index_mb", "metadata_mb")
):
print(" (passages excluded in totals, shown for reference):")
print(
f" - Passages text: {non_compact.get('passages_text_mb', 0):.1f} MB, "
f"Passages index: {non_compact.get('passages_index_mb', 0):.1f} MB, "
f"Metadata: {non_compact.get('metadata_mb', 0):.3f} MB"
)
else:
# Fallback to legacy totals if running with older metrics
print("\n📏 Index Comparison Analysis:")
print(
f" Compact index (current): {current.get('total_with_embeddings', 0):.1f} MB"
)
print(
f" Non-compact index (with embeddings): {non_compact.get('total_with_embeddings', 0):.1f} MB"
)
print(
f" Storage saving by compact: {timing_metrics.get('storage_saving_percent', 0):.1f}%"
)
print(" Component breakdown (non-compact):")
print(f" - Main index: {non_compact.get('index', 0):.1f} MB")
print(f" - Passages text: {non_compact.get('passages_text', 0):.1f} MB")
print(f" - Passages index: {non_compact.get('passages_index', 0):.1f} MB")
print(f" - Metadata: {non_compact.get('metadata', 0):.1f} MB")
# Performance comparison
if "performance_comparison" in timing_metrics:
perf = timing_metrics["performance_comparison"]
print("\n⚡ Performance Comparison:")
print(
f" Non-compact (no recompute): {perf.get('avg_search_times', {}).get('non_compact', 0):.3f}s avg"
)
print(
f" Compact (with recompute): {perf.get('avg_search_times', {}).get('compact', 0):.3f}s avg"
)
print(f" Speed ratio: {perf.get('speed_ratio', 0):.2f}x")
# Legacy single index analysis (fallback)
if "total_with_embeddings" in timing_metrics and "current_index" not in timing_metrics:
print("\n📏 Index Size Analysis:")
print(
f" Index with embeddings: {timing_metrics.get('total_with_embeddings', 0):.1f} MB"
)
print(
f" Estimated pruned index: {timing_metrics.get('total_without_embeddings', 0):.1f} MB"
)
print(f" Compression ratio: {timing_metrics.get('compression_ratio', 0):.2f}x")
def cleanup(self):
"""Cleanup resources"""
if self.searcher:
self.searcher.cleanup()
def main():
parser = argparse.ArgumentParser(description="LAION Multimodal Benchmark Evaluation")
parser.add_argument("--index", required=True, help="Path to LEANN index")
parser.add_argument(
"--queries", default="data/evaluation_queries.jsonl", help="Path to evaluation queries"
)
parser.add_argument(
"--stage",
choices=["2", "3", "4", "5", "all"],
default="all",
help="Which stage to run (2=recall, 3=complexity, 4=index comparison, 5=generation)",
)
parser.add_argument("--complexity", type=int, default=None, help="Complexity for search")
parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
parser.add_argument("--output", help="Save results to JSON file")
parser.add_argument(
"--llm-backend",
choices=["hf"],
default="hf",
help="LLM backend (Qwen2.5-VL only supports HF)",
)
parser.add_argument(
"--model-name", default="Qwen/Qwen2.5-VL-7B-Instruct", help="Multimodal model name"
)
args = parser.parse_args()
try:
# Check if baseline exists
baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
if not os.path.exists(baseline_index_path):
print(f"❌ FAISS baseline not found at {baseline_index_path}")
print("💡 Please run setup_laion.py first to build the baseline")
exit(1)
if args.stage == "2" or args.stage == "all":
# Stage 2: Recall@3 evaluation
print("🚀 Starting Stage 2: Recall@3 evaluation for multimodal retrieval")
evaluator = RecallEvaluator(args.index, args.baseline_dir)
# Load caption queries for testing
laion_evaluator = LAIONEvaluator(args.index)
captions = laion_evaluator.load_queries(args.queries)
# Test with queries for robust measurement
test_captions = captions[:100] # Use subset for speed
print(f"🧪 Testing with {len(test_captions)} caption queries")
# Test with complexity 64
complexity = 64
recall = evaluator.evaluate_recall_at_3(test_captions, complexity)
print(f"📈 Recall@3 at complexity {complexity}: {recall * 100:.1f}%")
evaluator.cleanup()
print("✅ Stage 2 completed!\n")
# Shared non-compact index path for Stage 3 and 4
non_compact_index_path = args.index.replace(".leann", "_noncompact.leann")
complexity = args.complexity
if args.stage == "3" or args.stage == "all":
# Stage 3: Binary search for 90% recall complexity
print("🚀 Starting Stage 3: Binary search for 90% recall complexity")
print(
"💡 Creating non-compact index for fast binary search with recompute_embeddings=False"
)
# Create non-compact index for binary search
print("🏗️ Creating non-compact index for binary search...")
evaluator = LAIONEvaluator(args.index)
evaluator.create_non_compact_index_for_comparison(non_compact_index_path)
# Use non-compact index for binary search
binary_search_evaluator = RecallEvaluator(non_compact_index_path, args.baseline_dir)
# Load caption queries for testing
captions = evaluator.load_queries(args.queries)
# Use subset for robust measurement
test_captions = captions[:50] # Smaller subset for binary search speed
print(f"🧪 Testing with {len(test_captions)} caption queries")
# Binary search for 90% recall complexity
target_recall = 0.9
min_complexity, max_complexity = 1, 128
print(f"🔍 Binary search for {target_recall * 100}% recall complexity...")
print(f"Search range: {min_complexity} to {max_complexity}")
best_complexity = None
best_recall = 0.0
while min_complexity <= max_complexity:
mid_complexity = (min_complexity + max_complexity) // 2
print(
f"\n🧪 Testing complexity {mid_complexity} (no recompute, non-compact index)..."
)
# Use recompute_embeddings=False on non-compact index for fast binary search
recall = binary_search_evaluator.evaluate_recall_at_3(
test_captions, mid_complexity, recompute_embeddings=False
)
print(
f" Complexity {mid_complexity}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%)"
)
if recall >= target_recall:
best_complexity = mid_complexity
best_recall = recall
max_complexity = mid_complexity - 1
print(" ✅ Target reached! Searching for lower complexity...")
else:
min_complexity = mid_complexity + 1
print(" ❌ Below target. Searching for higher complexity...")
if best_complexity is not None:
print("\n🎯 Optimal complexity found!")
print(f" Complexity: {best_complexity}")
print(f" Recall@3: {best_recall:.3f} ({best_recall * 100:.1f}%)")
# Test a few complexities around the optimal one for verification
print("\n🔬 Verification test around optimal complexity:")
verification_complexities = [
max(1, best_complexity - 2),
max(1, best_complexity - 1),
best_complexity,
best_complexity + 1,
best_complexity + 2,
]
for complexity in verification_complexities:
if complexity <= 512: # reasonable upper bound
recall = binary_search_evaluator.evaluate_recall_at_3(
test_captions, complexity, recompute_embeddings=False
)
status = "" if recall >= target_recall else ""
print(f" {status} Complexity {complexity:3d}: {recall * 100:5.1f}%")
# Now test the optimal complexity with compact index and recompute for comparison
print(
f"\n🔄 Testing optimal complexity {best_complexity} on compact index WITH recompute..."
)
compact_evaluator = RecallEvaluator(args.index, args.baseline_dir)
recall_with_recompute = compact_evaluator.evaluate_recall_at_3(
test_captions[:10], best_complexity, recompute_embeddings=True
)
print(
f" ✅ Complexity {best_complexity} (compact index with recompute): {recall_with_recompute * 100:.1f}%"
)
complexity = best_complexity
print(
f" 📊 Recall difference: {abs(best_recall - recall_with_recompute) * 100:.2f}%"
)
compact_evaluator.cleanup()
else:
print(f"\n❌ Could not find complexity achieving {target_recall * 100}% recall")
print("All tested complexities were below target.")
# Cleanup evaluators (keep non-compact index for Stage 4)
binary_search_evaluator.cleanup()
evaluator.cleanup()
print("✅ Stage 3 completed! Non-compact index saved for Stage 4.\n")
if args.stage == "4" or args.stage == "all":
# Stage 4: Index comparison (without LLM generation)
print("🚀 Starting Stage 4: Index comparison analysis")
# Use LAION evaluator for index comparison
evaluator = LAIONEvaluator(args.index)
# Load caption queries
captions = evaluator.load_queries(args.queries)
# Step 1: Analyze current (compact) index
print("\n📏 Analyzing current index (compact, pruned)...")
compact_size_metrics = evaluator.analyze_index_sizes()
compact_size_metrics["index_type"] = "compact"
# Step 2: Use existing non-compact index or create if needed
if Path(non_compact_index_path).exists():
print(
f"\n📁 Using existing non-compact index from Stage 3: {non_compact_index_path}"
)
temp_evaluator = LAIONEvaluator(non_compact_index_path)
non_compact_size_metrics = temp_evaluator.analyze_index_sizes()
non_compact_size_metrics["index_type"] = "non_compact"
else:
print("\n🏗️ Creating non-compact index (with embeddings) for comparison...")
non_compact_size_metrics = evaluator.create_non_compact_index_for_comparison(
non_compact_index_path
)
# Step 3: Compare index sizes (.index only)
print("\n📊 Index size comparison (.index only):")
print(
f" Compact index (current): {compact_size_metrics.get('index_only_mb', 0):.1f} MB"
)
print(f" Non-compact index: {non_compact_size_metrics.get('index_only_mb', 0):.1f} MB")
storage_saving = 0.0
if non_compact_size_metrics.get("index_only_mb", 0) > 0:
storage_saving = (
(
non_compact_size_metrics.get("index_only_mb", 0)
- compact_size_metrics.get("index_only_mb", 0)
)
/ non_compact_size_metrics.get("index_only_mb", 1)
* 100
)
print(f" Storage saving by compact: {storage_saving:.1f}%")
# Step 4: Performance comparison between the two indexes
if complexity is None:
raise ValueError("Complexity is required for index comparison")
print("\n⚡ Performance comparison between indexes...")
performance_metrics = evaluator.compare_index_performance(
non_compact_index_path, args.index, captions[:10], complexity=complexity
)
# Combine all metrics
combined_metrics = {
"current_index": compact_size_metrics,
"non_compact_index": non_compact_size_metrics,
"performance_comparison": performance_metrics,
"storage_saving_percent": storage_saving,
}
# Print comprehensive results
evaluator._print_results(combined_metrics)
# Save results if requested
if args.output:
print(f"\n💾 Saving results to {args.output}...")
with open(args.output, "w") as f:
json.dump(combined_metrics, f, indent=2, default=str)
print(f"✅ Results saved to {args.output}")
evaluator.cleanup()
print("✅ Stage 4 completed!\n")
if args.stage in ("5", "all"):
print("🚀 Starting Stage 5: Multimodal generation with Qwen2.5-VL")
evaluator = LAIONEvaluator(args.index)
captions = evaluator.load_queries(args.queries)
test_captions = captions[: min(20, len(captions))] # Use subset for generation
print(f"🧪 Testing multimodal generation with {len(test_captions)} queries")
# Load Qwen2.5-VL model
try:
print("Loading Qwen2.5-VL model...")
processor, model = load_qwen_vl_model(args.model_name)
# Run multimodal generation evaluation
complexity = args.complexity or 64
gen_results = evaluate_multimodal_rag(
evaluator.searcher,
test_captions,
processor=processor,
model=model,
complexity=complexity,
)
print("\n📊 Multimodal Generation Results:")
print(f" Total Queries: {len(test_captions)}")
print(f" Avg Search Time: {gen_results['avg_search_time']:.3f}s")
print(f" Avg Generation Time: {gen_results['avg_generation_time']:.3f}s")
total_time = gen_results["avg_search_time"] + gen_results["avg_generation_time"]
search_pct = (gen_results["avg_search_time"] / total_time) * 100
gen_pct = (gen_results["avg_generation_time"] / total_time) * 100
print(f" Time Distribution: Search {search_pct:.1f}%, Generation {gen_pct:.1f}%")
print(" LLM Backend: HuggingFace transformers")
print(f" Model: {args.model_name}")
# Show sample results
print("\n📝 Sample Multimodal Generations:")
for i, response in enumerate(gen_results["results"][:3]):
# Handle both string and dict formats for captions
if isinstance(test_captions[i], dict):
caption_text = test_captions[i].get("query", str(test_captions[i]))
else:
caption_text = str(test_captions[i])
print(f" Query {i + 1}: {caption_text[:60]}...")
print(f" Response {i + 1}: {response[:100]}...")
print()
except Exception as e:
print(f"❌ Multimodal generation evaluation failed: {e}")
print("💡 Make sure transformers and Qwen2.5-VL are installed")
import traceback
traceback.print_exc()
evaluator.cleanup()
print("✅ Stage 5 completed!\n")
if args.stage == "all":
print("🎉 All evaluation stages completed successfully!")
print("\n📋 Summary:")
print(" Stage 2: ✅ Multimodal Recall@3 evaluation completed")
print(" Stage 3: ✅ Optimal complexity found")
print(" Stage 4: ✅ Index comparison analysis completed")
print(" Stage 5: ✅ Multimodal generation evaluation completed")
print("\n🔧 Recommended next steps:")
print(" - Use optimal complexity for best speed/accuracy balance")
print(" - Review index comparison for storage vs performance tradeoffs")
# Clean up non-compact index after all stages complete
print("\n🧹 Cleaning up temporary non-compact index...")
if Path(non_compact_index_path).exists():
temp_index_dir = Path(non_compact_index_path).parent
temp_index_name = Path(non_compact_index_path).name
for temp_file in temp_index_dir.glob(f"{temp_index_name}*"):
temp_file.unlink()
print(f"✅ Cleaned up {non_compact_index_path}")
else:
print("📝 No temporary index to clean up")
except KeyboardInterrupt:
print("\n⚠️ Evaluation interrupted by user")
exit(1)
except Exception as e:
print(f"\n❌ Stage {args.stage} failed: {e}")
import traceback
traceback.print_exc()
exit(1)
if __name__ == "__main__":
main()

View File

@@ -1,576 +0,0 @@
"""
LAION Multimodal Benchmark Setup Script
Downloads LAION subset and builds LEANN index with sentence embeddings
"""
import argparse
import asyncio
import io
import json
import os
import pickle
import time
from pathlib import Path
import aiohttp
import numpy as np
from datasets import load_dataset
from leann import LeannBuilder
from PIL import Image
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
class LAIONSetup:
def __init__(self, data_dir: str = "data"):
self.data_dir = Path(data_dir)
self.images_dir = self.data_dir / "laion_images"
self.metadata_file = self.data_dir / "laion_metadata.jsonl"
# Create directories
self.data_dir.mkdir(exist_ok=True)
self.images_dir.mkdir(exist_ok=True)
async def download_single_image(self, session, sample_data, semaphore, progress_bar):
"""Download a single image asynchronously"""
async with semaphore: # Limit concurrent downloads
try:
image_url = sample_data["url"]
image_path = sample_data["image_path"]
# Skip if already exists
if os.path.exists(image_path):
progress_bar.update(1)
return sample_data
async with session.get(image_url, timeout=10) as response:
if response.status == 200:
content = await response.read()
# Verify it's a valid image
try:
img = Image.open(io.BytesIO(content))
img = img.convert("RGB")
img.save(image_path, "JPEG")
progress_bar.update(1)
return sample_data
except Exception:
progress_bar.update(1)
return None # Skip invalid images
else:
progress_bar.update(1)
return None
except Exception:
progress_bar.update(1)
return None
def download_laion_subset(self, num_samples: int = 1000):
"""Download LAION subset from HuggingFace datasets with async parallel downloading"""
print(f"📥 Downloading LAION subset ({num_samples} samples)...")
# Load LAION-400M subset from HuggingFace
print("🤗 Loading from HuggingFace datasets...")
dataset = load_dataset("laion/laion400m", split="train", streaming=True)
# Collect sample metadata first (fast)
print("📋 Collecting sample metadata...")
candidates = []
for sample in dataset:
if len(candidates) >= num_samples * 3: # Get 3x more candidates in case some fail
break
image_url = sample.get("url", "")
caption = sample.get("caption", "")
if not image_url or not caption:
continue
image_filename = f"laion_{len(candidates):06d}.jpg"
image_path = self.images_dir / image_filename
candidate = {
"id": f"laion_{len(candidates):06d}",
"url": image_url,
"caption": caption,
"image_path": str(image_path),
"width": sample.get("original_width", 512),
"height": sample.get("original_height", 512),
"similarity": sample.get("similarity", 0.0),
}
candidates.append(candidate)
print(
f"📊 Collected {len(candidates)} candidates, downloading {num_samples} in parallel..."
)
# Download images in parallel
async def download_batch():
semaphore = asyncio.Semaphore(20) # Limit to 20 concurrent downloads
connector = aiohttp.TCPConnector(limit=100, limit_per_host=20)
timeout = aiohttp.ClientTimeout(total=30)
progress_bar = tqdm(total=len(candidates[: num_samples * 2]), desc="Downloading images")
async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
tasks = []
for candidate in candidates[: num_samples * 2]: # Try 2x more than needed
task = self.download_single_image(session, candidate, semaphore, progress_bar)
tasks.append(task)
# Wait for all downloads
results = await asyncio.gather(*tasks, return_exceptions=True)
progress_bar.close()
# Filter successful downloads
successful = [r for r in results if r is not None and not isinstance(r, Exception)]
return successful[:num_samples]
# Run async download
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
samples = loop.run_until_complete(download_batch())
finally:
loop.close()
# Save metadata
with open(self.metadata_file, "w", encoding="utf-8") as f:
for sample in samples:
f.write(json.dumps(sample) + "\n")
print(f"✅ Downloaded {len(samples)} real LAION samples with async parallel downloading")
return samples
def generate_clip_image_embeddings(self, samples: list[dict]):
"""Generate CLIP image embeddings for downloaded images"""
print("🔍 Generating CLIP image embeddings...")
# Load sentence-transformers CLIP (ViT-L/14, 768-dim) for image embeddings
# This single model can encode both images and text.
model = SentenceTransformer("clip-ViT-L-14")
embeddings = []
valid_samples = []
for sample in tqdm(samples, desc="Processing images"):
try:
# Load image
image_path = sample["image_path"]
image = Image.open(image_path).convert("RGB")
# Encode image to 768-dim embedding via sentence-transformers (normalized)
vec = model.encode(
[image],
convert_to_numpy=True,
normalize_embeddings=True,
batch_size=1,
show_progress_bar=False,
)[0]
embeddings.append(vec.astype(np.float32))
valid_samples.append(sample)
except Exception as e:
print(f" ⚠️ Failed to process {sample['id']}: {e}")
# Skip invalid images
embeddings = np.array(embeddings, dtype=np.float32)
# Save embeddings
embeddings_file = self.data_dir / "clip_image_embeddings.npy"
np.save(embeddings_file, embeddings)
print(f"✅ Generated {len(embeddings)} image embeddings, shape: {embeddings.shape}")
return embeddings, valid_samples
def build_faiss_baseline(
self, embeddings: np.ndarray, samples: list[dict], output_dir: str = "baseline"
):
"""Build FAISS flat baseline using CLIP image embeddings"""
print("🔨 Building FAISS Flat baseline...")
from leann_backend_hnsw import faiss
os.makedirs(output_dir, exist_ok=True)
baseline_path = os.path.join(output_dir, "faiss_flat.index")
metadata_path = os.path.join(output_dir, "metadata.pkl")
if os.path.exists(baseline_path) and os.path.exists(metadata_path):
print(f"✅ Baseline already exists at {baseline_path}")
return baseline_path
# Extract image IDs (must be present)
if not samples or "id" not in samples[0]:
raise KeyError("samples missing 'id' field for FAISS baseline")
image_ids: list[str] = [str(sample["id"]) for sample in samples]
print(f"📐 Embedding shape: {embeddings.shape}")
print(f"📄 Processing {len(image_ids)} images")
# Build FAISS flat index
print("🏗️ Building FAISS IndexFlatIP...")
dimension = embeddings.shape[1]
index = faiss.IndexFlatIP(dimension)
# Add embeddings to flat index
embeddings_f32 = embeddings.astype(np.float32)
index.add(embeddings_f32.shape[0], faiss.swig_ptr(embeddings_f32))
# Save index and metadata
faiss.write_index(index, baseline_path)
with open(metadata_path, "wb") as f:
pickle.dump(image_ids, f)
print(f"✅ FAISS baseline saved to {baseline_path}")
print(f"✅ Metadata saved to {metadata_path}")
print(f"📊 Total vectors: {index.ntotal}")
return baseline_path
def create_leann_passages(self, samples: list[dict]):
"""Create LEANN-compatible passages from LAION data"""
print("📝 Creating LEANN passages...")
passages_file = self.data_dir / "laion_passages.jsonl"
with open(passages_file, "w", encoding="utf-8") as f:
for i, sample in enumerate(samples):
passage = {
"id": sample["id"],
"text": sample["caption"], # Use caption as searchable text
"metadata": {
"image_url": sample["url"],
"image_path": sample.get("image_path", ""),
"width": sample["width"],
"height": sample["height"],
"similarity": sample["similarity"],
"image_index": i, # Index for embedding lookup
},
}
f.write(json.dumps(passage) + "\n")
print(f"✅ Created {len(samples)} passages")
return passages_file
def build_compact_index(
self, passages_file: Path, embeddings: np.ndarray, index_path: str, backend: str = "hnsw"
):
"""Build compact LEANN index with CLIP embeddings (recompute=True, compact=True)"""
print(f"🏗️ Building compact LEANN index with {backend} backend...")
start_time = time.time()
# Save CLIP embeddings (npy) and also a pickle with (ids, embeddings)
npy_path = self.data_dir / "clip_image_embeddings.npy"
np.save(npy_path, embeddings)
print(f"💾 Saved CLIP embeddings to {npy_path}")
# Prepare ids in the same order as passages_file (matches embeddings order)
ids: list[str] = []
with open(passages_file, encoding="utf-8") as f:
for line in f:
if line.strip():
rec = json.loads(line)
ids.append(str(rec["id"]))
if len(ids) != embeddings.shape[0]:
raise ValueError(
f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
)
pkl_path = self.data_dir / "clip_image_embeddings.pkl"
with open(pkl_path, "wb") as pf:
pickle.dump((ids, embeddings.astype(np.float32)), pf)
print(f"💾 Saved (ids, embeddings) pickle to {pkl_path}")
# Initialize builder - compact with recompute
# Note: For multimodal case, we need to handle embeddings differently
# Let's try using sentence-transformers mode but with custom embeddings
builder = LeannBuilder(
backend_name=backend,
# Use CLIP text encoder (ViT-L/14) to match image space (768-dim)
embedding_model="clip-ViT-L-14",
embedding_mode="sentence-transformers",
# HNSW params (or forwarded to chosen backend)
graph_degree=32,
complexity=64,
# Compact/pruned with recompute at query time
is_recompute=True,
is_compact=True,
distance_metric="cosine", # CLIP uses normalized vectors; cosine is appropriate
num_threads=4,
)
# Add passages (text + metadata)
print("📚 Adding passages...")
self._add_passages_with_embeddings(builder, passages_file, embeddings)
print(f"🔨 Building compact index at {index_path} from precomputed embeddings...")
builder.build_index_from_embeddings(index_path, str(pkl_path))
build_time = time.time() - start_time
print(f"✅ Compact index built in {build_time:.2f}s")
# Analyze index size
self._analyze_index_size(index_path)
return index_path
def build_non_compact_index(
self, passages_file: Path, embeddings: np.ndarray, index_path: str, backend: str = "hnsw"
):
"""Build non-compact LEANN index with CLIP embeddings (recompute=False, compact=False)"""
print(f"🏗️ Building non-compact LEANN index with {backend} backend...")
start_time = time.time()
# Ensure embeddings are saved (npy + pickle)
npy_path = self.data_dir / "clip_image_embeddings.npy"
if not npy_path.exists():
np.save(npy_path, embeddings)
print(f"💾 Saved CLIP embeddings to {npy_path}")
# Prepare ids in same order as passages_file
ids: list[str] = []
with open(passages_file, encoding="utf-8") as f:
for line in f:
if line.strip():
rec = json.loads(line)
ids.append(str(rec["id"]))
if len(ids) != embeddings.shape[0]:
raise ValueError(
f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
)
pkl_path = self.data_dir / "clip_image_embeddings.pkl"
if not pkl_path.exists():
with open(pkl_path, "wb") as pf:
pickle.dump((ids, embeddings.astype(np.float32)), pf)
print(f"💾 Saved (ids, embeddings) pickle to {pkl_path}")
# Initialize builder - non-compact without recompute
builder = LeannBuilder(
backend_name=backend,
embedding_model="clip-ViT-L-14",
embedding_mode="sentence-transformers",
graph_degree=32,
complexity=64,
is_recompute=False, # Store embeddings (no recompute needed)
is_compact=False, # Store full index (not pruned)
distance_metric="cosine",
num_threads=4,
)
# Add passages - embeddings will be loaded from file
print("📚 Adding passages...")
self._add_passages_with_embeddings(builder, passages_file, embeddings)
print(f"🔨 Building non-compact index at {index_path} from precomputed embeddings...")
builder.build_index_from_embeddings(index_path, str(pkl_path))
build_time = time.time() - start_time
print(f"✅ Non-compact index built in {build_time:.2f}s")
# Analyze index size
self._analyze_index_size(index_path)
return index_path
def _add_passages_with_embeddings(self, builder, passages_file: Path, embeddings: np.ndarray):
"""Helper to add passages with pre-computed CLIP embeddings"""
with open(passages_file, encoding="utf-8") as f:
for line in tqdm(f, desc="Adding passages"):
if line.strip():
passage = json.loads(line)
# Add image metadata - LEANN will handle embeddings separately
# Note: We store image metadata and caption text for searchability
# Important: ensure passage ID in metadata matches vector ID
builder.add_text(
text=passage["text"], # Image caption for searchability
metadata={**passage["metadata"], "id": passage["id"]},
)
def _analyze_index_size(self, index_path: str):
"""Analyze index file sizes"""
print("📏 Analyzing index sizes...")
index_path = Path(index_path)
index_dir = index_path.parent
index_name = index_path.name # e.g., laion_index.leann
index_prefix = index_path.stem # e.g., laion_index
files = [
(f"{index_prefix}.index", ".index", "core"),
(f"{index_name}.meta.json", ".meta.json", "core"),
(f"{index_name}.ids.txt", ".ids.txt", "core"),
(f"{index_name}.passages.jsonl", ".passages.jsonl", "passages"),
(f"{index_name}.passages.idx", ".passages.idx", "passages"),
]
def _fmt_size(bytes_val: int) -> str:
if bytes_val < 1024:
return f"{bytes_val} B"
kb = bytes_val / 1024
if kb < 1024:
return f"{kb:.1f} KB"
mb = kb / 1024
if mb < 1024:
return f"{mb:.2f} MB"
gb = mb / 1024
return f"{gb:.2f} GB"
total_index_only_mb = 0.0
total_all_mb = 0.0
for filename, label, group in files:
file_path = index_dir / filename
if file_path.exists():
size_bytes = file_path.stat().st_size
print(f" {label}: {_fmt_size(size_bytes)}")
size_mb = size_bytes / (1024 * 1024)
total_all_mb += size_mb
if group == "core":
total_index_only_mb += size_mb
else:
print(f" {label}: (missing)")
print(f" Total (index only, exclude passages): {total_index_only_mb:.2f} MB")
print(f" Total (including passages): {total_all_mb:.2f} MB")
def create_evaluation_queries(self, samples: list[dict], num_queries: int = 200):
"""Create evaluation queries from captions"""
print(f"📝 Creating {num_queries} evaluation queries...")
# Sample random captions as queries
import random
random.seed(42) # For reproducibility
query_samples = random.sample(samples, min(num_queries, len(samples)))
queries_file = self.data_dir / "evaluation_queries.jsonl"
with open(queries_file, "w", encoding="utf-8") as f:
for sample in query_samples:
query = {
"id": sample["id"],
"query": sample["caption"],
"ground_truth_id": sample["id"], # For potential recall evaluation
}
f.write(json.dumps(query) + "\n")
print(f"✅ Created {len(query_samples)} evaluation queries")
return queries_file
def main():
parser = argparse.ArgumentParser(description="Setup LAION Multimodal Benchmark")
parser.add_argument("--data-dir", default="data", help="Data directory")
parser.add_argument("--num-samples", type=int, default=1000, help="Number of LAION samples")
parser.add_argument("--num-queries", type=int, default=50, help="Number of evaluation queries")
parser.add_argument("--index-path", default="data/laion_index.leann", help="Output index path")
parser.add_argument(
"--backend", default="hnsw", choices=["hnsw", "diskann"], help="LEANN backend"
)
parser.add_argument("--skip-download", action="store_true", help="Skip LAION dataset download")
parser.add_argument("--skip-build", action="store_true", help="Skip index building")
args = parser.parse_args()
print("🚀 Setting up LAION Multimodal Benchmark")
print("=" * 50)
try:
# Initialize setup
setup = LAIONSetup(args.data_dir)
# Step 1: Download LAION subset
if not args.skip_download:
print("\n📦 Step 1: Download LAION subset")
samples = setup.download_laion_subset(args.num_samples)
# Step 2: Generate CLIP image embeddings
print("\n🔍 Step 2: Generate CLIP image embeddings")
embeddings, valid_samples = setup.generate_clip_image_embeddings(samples)
# Step 3: Create LEANN passages (image metadata with embeddings)
print("\n📝 Step 3: Create LEANN passages")
passages_file = setup.create_leann_passages(valid_samples)
else:
print("⏭️ Skipping LAION dataset download")
# Load existing data
passages_file = setup.data_dir / "laion_passages.jsonl"
embeddings_file = setup.data_dir / "clip_image_embeddings.npy"
if not passages_file.exists() or not embeddings_file.exists():
raise FileNotFoundError(
"Passages or embeddings file not found. Run without --skip-download first."
)
embeddings = np.load(embeddings_file)
print(f"📊 Loaded {len(embeddings)} embeddings from {embeddings_file}")
# Step 4: Build LEANN indexes (both compact and non-compact)
if not args.skip_build:
print("\n🏗️ Step 4: Build LEANN indexes with CLIP image embeddings")
# Build compact index (production mode - small, recompute required)
compact_index_path = args.index_path
print(f"Building compact index: {compact_index_path}")
setup.build_compact_index(passages_file, embeddings, compact_index_path, args.backend)
# Build non-compact index (comparison mode - large, fast search)
non_compact_index_path = args.index_path.replace(".leann", "_noncompact.leann")
print(f"Building non-compact index: {non_compact_index_path}")
setup.build_non_compact_index(
passages_file, embeddings, non_compact_index_path, args.backend
)
# Step 5: Build FAISS flat baseline
print("\n🔨 Step 5: Build FAISS flat baseline")
if not args.skip_download:
baseline_path = setup.build_faiss_baseline(embeddings, valid_samples)
else:
# Load valid_samples from passages file for FAISS baseline
valid_samples = []
with open(passages_file, encoding="utf-8") as f:
for line in f:
if line.strip():
passage = json.loads(line)
valid_samples.append({"id": passage["id"], "caption": passage["text"]})
baseline_path = setup.build_faiss_baseline(embeddings, valid_samples)
# Step 6: Create evaluation queries
print("\n📝 Step 6: Create evaluation queries")
queries_file = setup.create_evaluation_queries(valid_samples, args.num_queries)
else:
print("⏭️ Skipping index building")
baseline_path = "data/baseline/faiss_index.bin"
queries_file = setup.data_dir / "evaluation_queries.jsonl"
print("\n🎉 Setup completed successfully!")
print("📊 Summary:")
if not args.skip_download:
print(f" Downloaded samples: {len(samples)}")
print(f" Valid samples with embeddings: {len(valid_samples)}")
else:
print(f" Loaded {len(embeddings)} embeddings")
if not args.skip_build:
print(f" Compact index: {compact_index_path}")
print(f" Non-compact index: {non_compact_index_path}")
print(f" FAISS baseline: {baseline_path}")
print(f" Queries: {queries_file}")
print("\n🔧 Next steps:")
print(f" Run evaluation: python evaluate_laion.py --index {compact_index_path}")
print(f" Or compare with: python evaluate_laion.py --index {non_compact_index_path}")
else:
print(" Skipped building indexes")
except KeyboardInterrupt:
print("\n⚠️ Setup interrupted by user")
exit(1)
except Exception as e:
print(f"\n❌ Setup failed: {e}")
exit(1)
if __name__ == "__main__":
main()

View File

@@ -1,342 +0,0 @@
"""
LLM utils for RAG benchmarks with Qwen3-8B and Qwen2.5-VL (multimodal)
"""
import time
try:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
HF_AVAILABLE = True
except ImportError:
HF_AVAILABLE = False
try:
from vllm import LLM, SamplingParams
VLLM_AVAILABLE = True
except ImportError:
VLLM_AVAILABLE = False
def is_qwen3_model(model_name):
"""Check if model is Qwen3"""
return "Qwen3" in model_name or "qwen3" in model_name.lower()
def is_qwen_vl_model(model_name):
"""Check if model is Qwen2.5-VL"""
return "Qwen2.5-VL" in model_name or "qwen2.5-vl" in model_name.lower()
def apply_qwen3_chat_template(tokenizer, prompt):
"""Apply Qwen3 chat template with thinking enabled"""
messages = [{"role": "user", "content": prompt}]
return tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=True,
)
def extract_thinking_answer(response):
"""Extract final answer from Qwen3 thinking model response"""
if "<think>" in response and "</think>" in response:
try:
think_end = response.index("</think>") + len("</think>")
final_answer = response[think_end:].strip()
return final_answer
except (ValueError, IndexError):
pass
return response.strip()
def load_hf_model(model_name="Qwen/Qwen3-8B", trust_remote_code=False):
"""Load HuggingFace model
Args:
model_name (str): Name of the model to load
trust_remote_code (bool): Whether to allow execution of code from the model repository.
Defaults to False for security. Only enable for trusted models.
"""
if not HF_AVAILABLE:
raise ImportError("transformers not available")
if trust_remote_code:
print(
"⚠️ WARNING: Loading model with trust_remote_code=True. This can execute arbitrary code."
)
print(f"Loading HF: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
device_map="auto",
trust_remote_code=trust_remote_code,
)
return tokenizer, model
def load_vllm_model(model_name="Qwen/Qwen3-8B", trust_remote_code=False):
"""Load vLLM model
Args:
model_name (str): Name of the model to load
trust_remote_code (bool): Whether to allow execution of code from the model repository.
Defaults to False for security. Only enable for trusted models.
"""
if not VLLM_AVAILABLE:
raise ImportError("vllm not available")
if trust_remote_code:
print(
"⚠️ WARNING: Loading model with trust_remote_code=True. This can execute arbitrary code."
)
print(f"Loading vLLM: {model_name}")
llm = LLM(model=model_name, trust_remote_code=trust_remote_code)
# Qwen3 specific config
if is_qwen3_model(model_name):
stop_tokens = ["<|im_end|>", "<|end_of_text|>"]
max_tokens = 2048
else:
stop_tokens = None
max_tokens = 1024
sampling_params = SamplingParams(temperature=0.7, max_tokens=max_tokens, stop=stop_tokens)
return llm, sampling_params
def generate_hf(tokenizer, model, prompt, max_tokens=None):
"""Generate with HF - supports Qwen3 thinking models"""
model_name = getattr(model, "name_or_path", "unknown")
is_qwen3 = is_qwen3_model(model_name)
# Apply chat template for Qwen3
if is_qwen3:
prompt = apply_qwen3_chat_template(tokenizer, prompt)
max_tokens = max_tokens or 2048
else:
max_tokens = max_tokens or 1024
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=max_tokens,
temperature=0.7,
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
)
# Decode only the newly generated tokens; slicing the decoded string by
# len(prompt) misaligns once the chat template adds special tokens.
generated_tokens = outputs[0][inputs["input_ids"].shape[1] :]
response = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
# Extract final answer for thinking models
if is_qwen3:
return extract_thinking_answer(response)
return response
def generate_vllm(llm, sampling_params, prompt):
"""Generate with vLLM - supports Qwen3 thinking models"""
outputs = llm.generate([prompt], sampling_params)
response = outputs[0].outputs[0].text.strip()
# Extract final answer for Qwen3 thinking models
model_name = str(llm.llm_engine.model_config.model)
if is_qwen3_model(model_name):
return extract_thinking_answer(response)
return response
def create_prompt(context, query, domain="default"):
"""Create RAG prompt"""
if domain == "emails":
return f"Email content:\n{context}\n\nQuestion: {query}\n\nAnswer:"
elif domain == "finance":
return f"Financial content:\n{context}\n\nQuestion: {query}\n\nAnswer:"
elif domain == "multimodal":
return f"Image context:\n{context}\n\nQuestion: {query}\n\nAnswer:"
else:
return f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
def evaluate_rag(searcher, llm_func, queries, domain="default", top_k=3, complexity=64):
"""Simple RAG evaluation with timing"""
search_times = []
gen_times = []
results = []
for i, query in enumerate(queries):
# Search
start = time.time()
docs = searcher.search(query, top_k=top_k, complexity=complexity)
search_time = time.time() - start
# Generate
context = "\n\n".join([doc.text for doc in docs])
prompt = create_prompt(context, query, domain)
start = time.time()
response = llm_func(prompt)
gen_time = time.time() - start
search_times.append(search_time)
gen_times.append(gen_time)
results.append(response)
if i < 3:
print(f"Q{i + 1}: Search={search_time:.3f}s, Gen={gen_time:.3f}s")
return {
"avg_search_time": sum(search_times) / len(search_times),
"avg_generation_time": sum(gen_times) / len(gen_times),
"results": results,
}
def load_qwen_vl_model(model_name="Qwen/Qwen2.5-VL-7B-Instruct", trust_remote_code=False):
"""Load Qwen2.5-VL multimodal model
Args:
model_name (str): Name of the model to load
trust_remote_code (bool): Whether to allow execution of code from the model repository.
Defaults to False for security. Only enable for trusted models.
"""
if not HF_AVAILABLE:
raise ImportError("transformers not available")
if trust_remote_code:
print(
"⚠️ WARNING: Loading model with trust_remote_code=True. This can execute arbitrary code."
)
print(f"Loading Qwen2.5-VL: {model_name}")
try:
from transformers import AutoModelForVision2Seq, AutoProcessor
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=trust_remote_code)
model = AutoModelForVision2Seq.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=trust_remote_code,
)
return processor, model
except Exception as e:
print(f"Failed to load with AutoModelForVision2Seq, trying specific class: {e}")
# Fallback to specific class
try:
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
processor = AutoProcessor.from_pretrained(
model_name, trust_remote_code=trust_remote_code
)
model = Qwen2VLForConditionalGeneration.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=trust_remote_code,
)
return processor, model
except Exception as e2:
raise ImportError(f"Failed to load Qwen2.5-VL model: {e2}")
def generate_qwen_vl(processor, model, prompt, image_path=None, max_tokens=512):
"""Generate with Qwen2.5-VL multimodal model"""
from PIL import Image
# Prepare inputs
if image_path:
image = Image.open(image_path)
inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
else:
inputs = processor(text=prompt, return_tensors="pt").to(model.device)
# Generate
with torch.no_grad():
generated_ids = model.generate(
**inputs, max_new_tokens=max_tokens, do_sample=False, temperature=0.1
)
# Decode response
generated_ids = generated_ids[:, inputs["input_ids"].shape[1] :]
response = processor.decode(generated_ids[0], skip_special_tokens=True)
return response
def create_multimodal_prompt(context, query, image_descriptions, task_type="images"):
"""Create prompt for multimodal RAG"""
if task_type == "images":
return f"""Based on the retrieved images and their descriptions, answer the following question.
Retrieved Image Descriptions:
{context}
Question: {query}
Provide a detailed answer based on the visual content described above."""
return f"Context: {context}\nQuestion: {query}\nAnswer:"
def evaluate_multimodal_rag(searcher, queries, processor=None, model=None, complexity=64):
"""Evaluate multimodal RAG with Qwen2.5-VL"""
search_times = []
gen_times = []
results = []
for i, query_item in enumerate(queries):
# Handle both string and dict formats for queries
if isinstance(query_item, dict):
query = query_item.get("query", "")
image_path = query_item.get("image_path") # Optional reference image
else:
query = str(query_item)
image_path = None
# Search
start_time = time.time()
search_results = searcher.search(query, top_k=3, complexity=complexity)
search_time = time.time() - start_time
search_times.append(search_time)
# Prepare context from search results
context_parts = []
for result in search_results:
context_parts.append(f"- {result.text}")
context = "\n".join(context_parts)
# Generate with multimodal model
start_time = time.time()
if processor and model:
prompt = create_multimodal_prompt(context, query, context_parts)
response = generate_qwen_vl(processor, model, prompt, image_path)
else:
response = f"Context: {context}"
gen_time = time.time() - start_time
gen_times.append(gen_time)
results.append(response)
if i < 3:
print(f"Q{i + 1}: Search={search_time:.3f}s, Gen={gen_time:.3f}s")
return {
"avg_search_time": sum(search_times) / len(search_times),
"avg_generation_time": sum(gen_times) / len(gen_times),
"results": results,
}

View File

@@ -1,143 +0,0 @@
# Update Benchmarks
This directory hosts two benchmark suites that exercise LEANN's HNSW “update +
search” pipeline under different assumptions:
1. **RNG recompute latency**: measure how random-neighbour pruning and cache
settings influence incremental `add()` latency when embeddings are fetched
over the ZMQ embedding server.
2. **Update strategy comparison**: compare a fully sequential update pipeline
against an offline approach that keeps the graph static and fuses results.
Both suites build a non-compact, `is_recompute=True` index so that embeddings
for newly added passages are pulled from the embedding server. Index artifacts
and server logs are written under `.leann/bench/` by default, and per-scenario
results are appended to CSV files under `benchmarks/update/` for later plotting.
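For orientation, the index both suites start from is configured roughly as
below; this is a minimal sketch based on the `build_initial_index` helper in
the benchmark scripts, with `initial_passages` and the output path as
placeholders:

```python
from leann.api import LeannBuilder

# Non-compact + recompute: node embeddings are not kept alongside the graph,
# so every incremental add() must fetch neighbour embeddings from the ZMQ server.
builder = LeannBuilder(
    backend_name="hnsw",
    embedding_model="sentence-transformers/all-MiniLM-L6-v2",
    embedding_mode="sentence-transformers",
    is_compact=False,
    is_recompute=True,
    distance_metric="mips",
)
for idx, passage in enumerate(initial_passages):  # placeholder list[str]
    builder.add_text(passage, metadata={"id": str(idx)})
builder.build_index(".leann/bench/test.leann")
```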
## Benchmarks
### 1. HNSW RNG Recompute Benchmark
`bench_hnsw_rng_recompute.py` evaluates incremental update latency under four
random-neighbour (RNG) configurations. Each scenario uses the same dataset but
changes the forward / reverse RNG pruning flags and whether the embedding cache
is enabled:
| Scenario name | Forward RNG | Reverse RNG | ZMQ embedding cache |
| ---------------------------------- | ----------- | ----------- | ------------------- |
| `baseline` | Enabled | Enabled | Enabled |
| `no_cache_baseline` | Enabled | Enabled | **Disabled** |
| `disable_forward_rng` | **Disabled**| Enabled | Enabled |
| `disable_forward_and_reverse_rng` | **Disabled**| **Disabled**| Enabled |
For each scenario the script:
1. (Re)builds an `is_recompute=True` index and writes it to `.leann/bench/`.
2. Starts `leann_backend_hnsw.hnsw_embedding_server` for remote embeddings.
3. Appends the requested updates using the scenario's RNG flags (see the sketch after this list).
4. Records total time, latency per passage, ZMQ fetch counts, and stage-level
timings before appending a row to the CSV output.
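The RNG and cache toggles themselves are applied roughly as follows; a
condensed sketch of what `bench_hnsw_rng_recompute.py` does, using the
SWIG-exposed setters the benchmark relies on, with the index path as a
placeholder:

```python
import os

from leann_backend_hnsw import faiss

# Read by the ZMQ embedding path; "0" reproduces the `no_cache_baseline` scenario.
os.environ["LEANN_ZMQ_EMBED_CACHE"] = "1"

index = faiss.read_index(".leann/bench/test.index")
index.is_recompute = True  # fetch embeddings over ZMQ instead of local storage

# Forward / reverse RNG pruning flags applied to subsequent add() calls.
index.hnsw.set_disable_rng_during_add(True)   # `disable_forward_rng`
index.hnsw.set_disable_reverse_prune(False)   # keep reverse pruning enabled
```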
**Run:**
```bash
LEANN_HNSW_LOG_PATH=.leann/bench/hnsw_server.log \
LEANN_LOG_LEVEL=INFO \
uv run -m benchmarks.update.bench_hnsw_rng_recompute \
--runs 1 \
--index-path .leann/bench/test.leann \
--initial-files data/PrideandPrejudice.txt \
--update-files data/huawei_pangu.md \
--max-initial 300 \
--max-updates 1 \
--add-timeout 120
```
**Output:**
- `benchmarks/update/bench_results.csv`: per-scenario timing statistics
(including ms/passage) for each run.
- `.leann/bench/hnsw_server.log`: detailed ZMQ/server logs (path controlled by
`LEANN_HNSW_LOG_PATH`).
_The reference CSVs checked into this branch were generated on a workstation with an NVIDIA RTX 4090 GPU; throughput numbers will differ on other hardware._
### 2. Sequential vs. Offline Update Benchmark
`bench_update_vs_offline_search.py` compares two end-to-end strategies on the
same dataset:
- **Scenario A: Sequential Update**
- Start an embedding server.
- Sequentially call `index.add()`; each call fetches embeddings via ZMQ and
mutates the HNSW graph.
- After all inserts, run a search on the updated graph.
- Metrics recorded: update time (`add_total_s`), post-update search time
(`search_time_s`), combined total (`total_time_s`), and per-passage
latency.
- **Scenario B: Offline Embedding + Concurrent Search**
- Stop Scenario A's server and start a fresh embedding server.
- Spawn two threads: one generates embeddings for the new passages offline
(graph unchanged); the other computes the query embedding and searches the
existing graph.
- Merge offline similarities with the graph search results to emulate late
fusion, then report the merged top-k preview (see the sketch after this list).
- Metrics recorded: embedding time (`emb_time_s`), search time
(`search_time_s`), concurrent makespan (`makespan_s`), and scenario total.
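The late-fusion step in Scenario B boils down to ranking both result sets on a
common score; here is a condensed sketch mirroring `_merge_results` /
`_score_for_metric` in `bench_update_vs_offline_search.py` (function name and
argument layout are illustrative):

```python
def merge_topk(index_distances, index_ids, offline_scores, k, metric="mips"):
    """Fuse graph-search hits with offline query-vs-new-passage scores."""

    def score(dist: float) -> float:
        # MIPS/cosine distances are already "higher is better"; negate L2.
        return float(dist) if metric in ("mips", "cosine") else -float(dist)

    merged = [(f"idx:{i}", score(d)) for d, i in zip(index_distances, index_ids)]
    merged += [(f"offline:{j}", s) for j, s in offline_scores]
    merged.sort(key=lambda item: item[1], reverse=True)
    return merged[:k]
```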
**Run (both scenarios):**
```bash
uv run -m benchmarks.update.bench_update_vs_offline_search \
--index-path .leann/bench/offline_vs_update.leann \
--max-initial 300 \
--num-updates 1
```
You can pass `--only A` or `--only B` to run a single scenario. The script will
print timing summaries to stdout and append the results to CSV.
**Output:**
- `benchmarks/update/offline_vs_update.csv`: per-scenario timing statistics for
Scenarios A and B.
- Console output includes Scenario B's merged top-k preview for quick sanity
checks.
_The sample results committed here come from runs on an RTX 4090-equipped machine; expect variations if you benchmark on different GPUs._
### 3. Visualisation
`plot_bench_results.py` combines the RNG benchmark and the update strategy
benchmark into a single two-panel plot.
**Run:**
```bash
uv run -m benchmarks.update.plot_bench_results \
--csv benchmarks/update/bench_results.csv \
--csv-right benchmarks/update/offline_vs_update.csv \
--out benchmarks/update/bench_latency_from_csv.png
```
**Options:**
- `--broken-y`: Enable a broken Y-axis (default: true when appropriate).
- `--csv`: RNG benchmark results CSV (left panel).
- `--csv-right`: Update strategy results CSV (right panel).
- `--out`: Output image path (PNG/PDF supported).
**Output:**
- `benchmarks/update/bench_latency_from_csv.png`: visual comparison of the two
suites.
- `benchmarks/update/bench_latency_from_csv.pdf`: PDF version, suitable for
slides/papers.
## Parameters & Environment
### Common CLI Flags
- `--max-initial`: Number of initial passages used to seed the index.
- `--max-updates` / `--num-updates`: Number of passages to treat as updates.
- `--index-path`: Base path (without extension) where the LEANN index is stored.
- `--runs`: Number of repetitions (RNG benchmark only).
### Environment Variables
- `LEANN_HNSW_LOG_PATH`: File to receive embedding-server logs (optional).
- `LEANN_LOG_LEVEL`: Logging verbosity (DEBUG/INFO/WARNING/ERROR).
- `CUDA_VISIBLE_DEVICES`: Set to an empty string if you want to force CPU
execution of the embedding model.
With these scripts you can easily replicate LEANN's update benchmarks, compare
multiple RNG strategies, and evaluate whether sequential updates or offline
fusion better matches your latency/accuracy trade-offs.

View File

@@ -1,16 +0,0 @@
"""Benchmarks for LEANN update workflows."""
# Expose helper to locate repository root for other modules that need it.
from pathlib import Path
def find_repo_root() -> Path:
"""Return the project root containing pyproject.toml."""
current = Path(__file__).resolve()
for parent in current.parents:
if (parent / "pyproject.toml").exists():
return parent
return current.parents[1]
__all__ = ["find_repo_root"]

View File

@@ -1,804 +0,0 @@
"""Benchmark incremental HNSW add() under different RNG pruning modes with real
embedding recomputation.
This script clones the structure of ``examples/dynamic_update_no_recompute.py``
so that we build a non-compact ``is_recompute=True`` index, spin up the
standard HNSW embedding server, and measure how long incremental ``add`` takes
when RNG pruning is fully enabled vs. partially/fully disabled.
Example usage (run from the repo root; downloads the model on first run)::
uv run -m benchmarks.update.bench_hnsw_rng_recompute \
--index-path .leann/bench/leann-demo.leann \
--runs 1
You can tweak the input documents with ``--initial-files`` / ``--update-files``
if you want a larger or different workload, and change the embedding model via
``--model-name``.
"""
import argparse
import json
import logging
import os
import pickle
import re
import sys
import time
from pathlib import Path
from typing import Any
import msgpack
import numpy as np
import zmq
from leann.api import LeannBuilder
if os.environ.get("LEANN_FORCE_CPU", "").lower() in ("1", "true", "yes"):
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
from leann.embedding_compute import compute_embeddings
from leann.embedding_server_manager import EmbeddingServerManager
from leann.registry import register_project_directory
from leann_backend_hnsw import faiss # type: ignore
from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
logger = logging.getLogger(__name__)
if not logging.getLogger().handlers:
logging.basicConfig(level=logging.INFO)
def _find_repo_root() -> Path:
"""Locate project root by walking up until pyproject.toml is found."""
current = Path(__file__).resolve()
for parent in current.parents:
if (parent / "pyproject.toml").exists():
return parent
# Fallback: assume repo is two levels up (../..)
return current.parents[2]
REPO_ROOT = _find_repo_root()
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
from apps.chunking import create_text_chunks # noqa: E402
DEFAULT_INITIAL_FILES = [
REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
REPO_ROOT / "data" / "huawei_pangu.md",
]
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
DEFAULT_HNSW_LOG = Path(".leann/bench/hnsw_server.log")
def load_chunks_from_files(paths: list[Path], limit: int | None = None) -> list[str]:
from llama_index.core import SimpleDirectoryReader
documents = []
for path in paths:
p = path.expanduser().resolve()
if not p.exists():
raise FileNotFoundError(f"Input path not found: {p}")
if p.is_dir():
reader = SimpleDirectoryReader(str(p), recursive=False)
documents.extend(reader.load_data(show_progress=True))
else:
reader = SimpleDirectoryReader(input_files=[str(p)])
documents.extend(reader.load_data(show_progress=True))
if not documents:
return []
chunks = create_text_chunks(
documents,
chunk_size=512,
chunk_overlap=128,
use_ast_chunking=False,
)
cleaned = [c for c in chunks if isinstance(c, str) and c.strip()]
if limit is not None:
cleaned = cleaned[:limit]
return cleaned
def ensure_index_dir(index_path: Path) -> None:
index_path.parent.mkdir(parents=True, exist_ok=True)
def cleanup_index_files(index_path: Path) -> None:
parent = index_path.parent
if not parent.exists():
return
stem = index_path.stem
for file in parent.glob(f"{stem}*"):
if file.is_file():
file.unlink()
def build_initial_index(
index_path: Path,
paragraphs: list[str],
model_name: str,
embedding_mode: str,
distance_metric: str,
ef_construction: int,
) -> None:
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=model_name,
embedding_mode=embedding_mode,
is_compact=False,
is_recompute=True,
distance_metric=distance_metric,
backend_kwargs={
"distance_metric": distance_metric,
"is_compact": False,
"is_recompute": True,
"efConstruction": ef_construction,
},
)
for idx, passage in enumerate(paragraphs):
builder.add_text(passage, metadata={"id": str(idx)})
builder.build_index(str(index_path))
def prepare_new_chunks(paragraphs: list[str]) -> list[dict[str, Any]]:
return [{"text": text, "metadata": {}} for text in paragraphs]
def benchmark_update_with_mode(
index_path: Path,
new_chunks: list[dict[str, Any]],
model_name: str,
embedding_mode: str,
distance_metric: str,
disable_forward_rng: bool,
disable_reverse_rng: bool,
server_port: int,
add_timeout: int,
ef_construction: int,
) -> tuple[float, float]:
meta_path = index_path.parent / f"{index_path.name}.meta.json"
passages_file = index_path.parent / f"{index_path.name}.passages.jsonl"
offset_file = index_path.parent / f"{index_path.name}.passages.idx"
index_file = index_path.parent / f"{index_path.stem}.index"
with open(meta_path, encoding="utf-8") as f:
meta = json.load(f)
with open(offset_file, "rb") as f:
offset_map: dict[str, int] = pickle.load(f)
existing_ids = set(offset_map.keys())
valid_chunks: list[dict[str, Any]] = []
for chunk in new_chunks:
text = chunk.get("text", "")
if not isinstance(text, str) or not text.strip():
continue
metadata = chunk.setdefault("metadata", {})
passage_id = chunk.get("id") or metadata.get("id")
if passage_id and passage_id in existing_ids:
raise ValueError(f"Passage ID '{passage_id}' already exists in the index.")
valid_chunks.append(chunk)
if not valid_chunks:
raise ValueError("No valid chunks to append.")
texts_to_embed = [chunk["text"] for chunk in valid_chunks]
embeddings = compute_embeddings(
texts_to_embed,
model_name,
mode=embedding_mode,
is_build=False,
batch_size=16,
)
embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
if distance_metric == "cosine":
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
norms[norms == 0] = 1
embeddings = embeddings / norms
index = faiss.read_index(str(index_file))
index.is_recompute = True
if getattr(index, "storage", None) is None:
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
storage_index = faiss.IndexFlatIP(index.d)
else:
storage_index = faiss.IndexFlatL2(index.d)
index.storage = storage_index
index.own_fields = True
try:
storage_index.ntotal = index.ntotal
except AttributeError:
pass
try:
index.hnsw.set_disable_rng_during_add(disable_forward_rng)
index.hnsw.set_disable_reverse_prune(disable_reverse_rng)
if ef_construction is not None:
index.hnsw.efConstruction = ef_construction
except AttributeError:
pass
applied_forward = getattr(index.hnsw, "disable_rng_during_add", None)
applied_reverse = getattr(index.hnsw, "disable_reverse_prune", None)
logger.info(
"HNSW RNG config -> requested forward=%s, reverse=%s | applied forward=%s, reverse=%s",
disable_forward_rng,
disable_reverse_rng,
applied_forward,
applied_reverse,
)
base_id = index.ntotal
for offset, chunk in enumerate(valid_chunks):
new_id = str(base_id + offset)
chunk.setdefault("metadata", {})["id"] = new_id
chunk["id"] = new_id
rollback_size = passages_file.stat().st_size if passages_file.exists() else 0
offset_map_backup = offset_map.copy()
try:
with open(passages_file, "a", encoding="utf-8") as f:
for chunk in valid_chunks:
offset = f.tell()
json.dump(
{
"id": chunk["id"],
"text": chunk["text"],
"metadata": chunk.get("metadata", {}),
},
f,
ensure_ascii=False,
)
f.write("\n")
offset_map[chunk["id"]] = offset
with open(offset_file, "wb") as f:
pickle.dump(offset_map, f)
server_manager = EmbeddingServerManager(
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
)
server_started, actual_port = server_manager.start_server(
port=server_port,
model_name=model_name,
embedding_mode=embedding_mode,
passages_file=str(meta_path),
distance_metric=distance_metric,
)
if not server_started:
raise RuntimeError("Failed to start embedding server.")
if hasattr(index.hnsw, "set_zmq_port"):
index.hnsw.set_zmq_port(actual_port)
elif hasattr(index, "set_zmq_port"):
index.set_zmq_port(actual_port)
_warmup_embedding_server(actual_port)
total_start = time.time()
add_elapsed = 0.0
try:
import signal
def _timeout_handler(signum, frame):
raise TimeoutError("incremental add timed out")
if add_timeout > 0:
signal.signal(signal.SIGALRM, _timeout_handler)
signal.alarm(add_timeout)
add_start = time.time()
for i in range(embeddings.shape[0]):
index.add(1, faiss.swig_ptr(embeddings[i : i + 1]))
add_elapsed = time.time() - add_start
if add_timeout > 0:
signal.alarm(0)
faiss.write_index(index, str(index_file))
finally:
server_manager.stop_server()
except TimeoutError:
raise
except Exception:
if passages_file.exists():
with open(passages_file, "rb+") as f:
f.truncate(rollback_size)
with open(offset_file, "wb") as f:
pickle.dump(offset_map_backup, f)
raise
prune_hnsw_embeddings_inplace(str(index_file))
meta["total_passages"] = len(offset_map)
with open(meta_path, "w", encoding="utf-8") as f:
json.dump(meta, f, indent=2)
# Reset toggles so the index on disk returns to baseline behaviour.
try:
index.hnsw.set_disable_rng_during_add(False)
index.hnsw.set_disable_reverse_prune(False)
except AttributeError:
pass
faiss.write_index(index, str(index_file))
total_elapsed = time.time() - total_start
return total_elapsed, add_elapsed
def _total_zmq_nodes(log_path: Path) -> int:
if not log_path.exists():
return 0
with log_path.open("r", encoding="utf-8") as log_file:
text = log_file.read()
return sum(int(match) for match in re.findall(r"ZMQ received (\d+) node IDs", text))
def _warmup_embedding_server(port: int) -> None:
"""Send a dummy REQ so the embedding server loads its model."""
ctx = zmq.Context()
try:
sock = ctx.socket(zmq.REQ)
sock.setsockopt(zmq.LINGER, 0)
sock.setsockopt(zmq.RCVTIMEO, 5000)
sock.setsockopt(zmq.SNDTIMEO, 5000)
sock.connect(f"tcp://127.0.0.1:{port}")
payload = msgpack.packb(["__WARMUP__"], use_bin_type=True)
sock.send(payload)
try:
sock.recv()
except zmq.error.Again:
pass
finally:
sock.close()
ctx.term()
def main() -> None:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--index-path",
type=Path,
default=Path(".leann/bench/leann-demo.leann"),
help="Output index base path (without extension).",
)
parser.add_argument(
"--initial-files",
nargs="*",
type=Path,
default=DEFAULT_INITIAL_FILES,
help="Files used to build the initial index.",
)
parser.add_argument(
"--update-files",
nargs="*",
type=Path,
default=DEFAULT_UPDATE_FILES,
help="Files appended during the benchmark.",
)
parser.add_argument(
"--runs", type=int, default=1, help="How many times to repeat each scenario."
)
parser.add_argument(
"--model-name",
default="sentence-transformers/all-MiniLM-L6-v2",
help="Embedding model used for build/update.",
)
parser.add_argument(
"--embedding-mode",
default="sentence-transformers",
help="Embedding mode passed to LeannBuilder/embedding server.",
)
parser.add_argument(
"--distance-metric",
default="mips",
choices=["mips", "l2", "cosine"],
help="Distance metric for HNSW backend.",
)
parser.add_argument(
"--ef-construction",
type=int,
default=200,
help="efConstruction setting for initial build.",
)
parser.add_argument(
"--server-port",
type=int,
default=5557,
help="Port for the real embedding server.",
)
parser.add_argument(
"--max-initial",
type=int,
default=300,
help="Optional cap on initial passages (after chunking).",
)
parser.add_argument(
"--max-updates",
type=int,
default=1,
help="Optional cap on update passages (after chunking).",
)
parser.add_argument(
"--add-timeout",
type=int,
default=900,
help="Timeout in seconds for the incremental add loop (0 = no timeout).",
)
parser.add_argument(
"--plot-path",
type=Path,
default=Path("bench_latency.png"),
help="Where to save the latency bar plot.",
)
parser.add_argument(
"--cap-y",
type=float,
default=None,
help="Cap Y-axis (ms). Bars above are hatched and annotated.",
)
parser.add_argument(
"--broken-y",
action="store_true",
help="Use broken Y-axis (two stacked axes with gap). Overrides --cap-y unless both provided.",
)
parser.add_argument(
"--lower-cap-y",
type=float,
default=None,
help="Lower axes upper bound for broken Y (ms). Default=1.1x second-highest.",
)
parser.add_argument(
"--upper-start-y",
type=float,
default=None,
help="Upper axes lower bound for broken Y (ms). Default=1.2x second-highest.",
)
parser.add_argument(
"--csv-path",
type=Path,
default=Path("benchmarks/update/bench_results.csv"),
help="Where to append per-scenario results as CSV.",
)
args = parser.parse_args()
register_project_directory(REPO_ROOT)
initial_paragraphs = load_chunks_from_files(args.initial_files, args.max_initial)
update_paragraphs = load_chunks_from_files(args.update_files, args.max_updates)
if not update_paragraphs:
raise ValueError("No update passages found; please provide --update-files with content.")
update_chunks = prepare_new_chunks(update_paragraphs)
ensure_index_dir(args.index_path)
scenarios = [
("baseline", False, False, True),
("no_cache_baseline", False, False, False),
("disable_forward_rng", True, False, True),
("disable_forward_and_reverse_rng", True, True, True),
]
log_path = Path(os.environ.get("LEANN_HNSW_LOG_PATH", DEFAULT_HNSW_LOG))
log_path.parent.mkdir(parents=True, exist_ok=True)
os.environ["LEANN_HNSW_LOG_PATH"] = str(log_path.resolve())
os.environ.setdefault("LEANN_LOG_LEVEL", "INFO")
results_total: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
results_add: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
results_zmq: dict[str, list[int]] = {name: [] for name, *_ in scenarios}
results_stageA: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
results_stageBC: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
results_ms_per_passage: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
# CSV setup
import csv
run_id = time.strftime("%Y%m%d-%H%M%S")
csv_fields = [
"run_id",
"scenario",
"cache_enabled",
"ef_construction",
"max_initial",
"max_updates",
"total_time_s",
"add_only_s",
"latency_ms_per_passage",
"zmq_nodes",
"stageA_time_s",
"stageBC_time_s",
"model_name",
"embedding_mode",
"distance_metric",
]
# Create CSV with header if missing
if args.csv_path:
args.csv_path.parent.mkdir(parents=True, exist_ok=True)
if not args.csv_path.exists() or args.csv_path.stat().st_size == 0:
with args.csv_path.open("w", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=csv_fields)
writer.writeheader()
for run in range(args.runs):
print(f"\n=== Benchmark run {run + 1}/{args.runs} ===")
for name, disable_forward, disable_reverse, cache_enabled in scenarios:
print(f"\nScenario: {name}")
cleanup_index_files(args.index_path)
if log_path.exists():
try:
log_path.unlink()
except OSError:
pass
os.environ["LEANN_ZMQ_EMBED_CACHE"] = "1" if cache_enabled else "0"
build_initial_index(
args.index_path,
initial_paragraphs,
args.model_name,
args.embedding_mode,
args.distance_metric,
args.ef_construction,
)
prev_size = log_path.stat().st_size if log_path.exists() else 0
try:
total_elapsed, add_elapsed = benchmark_update_with_mode(
args.index_path,
update_chunks,
args.model_name,
args.embedding_mode,
args.distance_metric,
disable_forward,
disable_reverse,
args.server_port,
args.add_timeout,
args.ef_construction,
)
except TimeoutError as exc:
print(f"Scenario {name} timed out: {exc}")
continue
curr_size = log_path.stat().st_size if log_path.exists() else 0
if curr_size < prev_size:
prev_size = 0
zmq_count = 0
if log_path.exists():
with log_path.open("r", encoding="utf-8") as log_file:
log_file.seek(prev_size)
new_entries = log_file.read()
zmq_count = sum(
int(match) for match in re.findall(r"ZMQ received (\d+) node IDs", new_entries)
)
stageA = sum(
float(x)
for x in re.findall(r"Distance calculation E2E time: ([0-9.]+)s", new_entries)
)
stageBC = sum(
float(x) for x in re.findall(r"ZMQ E2E time: ([0-9.]+)s", new_entries)
)
else:
stageA = 0.0
stageBC = 0.0
per_chunk = add_elapsed / len(update_chunks)
print(
f"Total time: {total_elapsed:.3f} s | add-only: {add_elapsed:.3f} s "
f"for {len(update_chunks)} passages => {per_chunk * 1e3:.3f} ms/passage"
)
print(f"ZMQ node fetch total: {zmq_count}")
results_total[name].append(total_elapsed)
results_add[name].append(add_elapsed)
results_zmq[name].append(zmq_count)
results_ms_per_passage[name].append(per_chunk * 1e3)
results_stageA[name].append(stageA)
results_stageBC[name].append(stageBC)
# Append row to CSV
if args.csv_path:
row = {
"run_id": run_id,
"scenario": name,
"cache_enabled": 1 if cache_enabled else 0,
"ef_construction": args.ef_construction,
"max_initial": args.max_initial,
"max_updates": args.max_updates,
"total_time_s": round(total_elapsed, 6),
"add_only_s": round(add_elapsed, 6),
"latency_ms_per_passage": round(per_chunk * 1e3, 6),
"zmq_nodes": int(zmq_count),
"stageA_time_s": round(stageA, 6),
"stageBC_time_s": round(stageBC, 6),
"model_name": args.model_name,
"embedding_mode": args.embedding_mode,
"distance_metric": args.distance_metric,
}
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=csv_fields)
writer.writerow(row)
print("\n=== Summary ===")
for name in results_add:
add_values = results_add[name]
total_values = results_total[name]
zmq_values = results_zmq[name]
latency_values = results_ms_per_passage[name]
if not add_values:
print(f"{name}: no successful runs")
continue
avg_add = sum(add_values) / len(add_values)
avg_total = sum(total_values) / len(total_values)
avg_zmq = sum(zmq_values) / len(zmq_values) if zmq_values else 0.0
avg_latency = sum(latency_values) / len(latency_values) if latency_values else 0.0
runs = len(add_values)
print(
f"{name}: add-only avg {avg_add:.3f} s | total avg {avg_total:.3f} s "
f"| ZMQ avg {avg_zmq:.1f} node fetches | latency {avg_latency:.2f} ms/passage over {runs} run(s)"
)
if args.plot_path:
try:
import matplotlib.pyplot as plt
labels = [name for name, *_ in scenarios]
values = [
sum(results_ms_per_passage[name]) / len(results_ms_per_passage[name])
if results_ms_per_passage[name]
else 0.0
for name in labels
]
def _auto_cap(vals: list[float]) -> float | None:
s = sorted(vals, reverse=True)
if len(s) < 2:
return None
if s[1] > 0 and s[0] >= 2.5 * s[1]:
return s[1] * 1.1
return None
def _fmt_ms(v: float) -> str:
return f"{v / 1000:.1f}k" if v >= 1000 else f"{v:.1f}"
colors = ["#4e79a7", "#f28e2c", "#e15759", "#76b7b2"]
if args.broken_y:
s = sorted(values, reverse=True)
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
lower_cap = args.lower_cap_y if args.lower_cap_y is not None else second * 1.1
upper_start = (
args.upper_start_y
if args.upper_start_y is not None
else max(second * 1.2, lower_cap * 1.02)
)
ymax = max(values) * 1.10 if values else 1.0
fig, (ax_top, ax_bottom) = plt.subplots(
2,
1,
sharex=True,
figsize=(7.4, 5.0),
gridspec_kw={"height_ratios": [1, 3], "hspace": 0.05},
)
x = list(range(len(labels)))
ax_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
ax_top.bar(x, values, color=colors[: len(labels)], width=0.8)
ax_bottom.set_ylim(0, lower_cap)
ax_top.set_ylim(upper_start, ymax)
for i, v in enumerate(values):
if v <= lower_cap:
ax_bottom.text(
i,
v + lower_cap * 0.02,
_fmt_ms(v),
ha="center",
va="bottom",
fontsize=9,
)
else:
ax_top.text(i, v, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
ax_top.spines["bottom"].set_visible(False)
ax_bottom.spines["top"].set_visible(False)
ax_top.tick_params(labeltop=False)
ax_bottom.xaxis.tick_bottom()
d = 0.015
kwargs = {"transform": ax_top.transAxes, "color": "k", "clip_on": False}
ax_top.plot((-d, +d), (-d, +d), **kwargs)
ax_top.plot((1 - d, 1 + d), (-d, +d), **kwargs)
kwargs.update({"transform": ax_bottom.transAxes})
ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)
ax_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)
ax_bottom.set_xticks(range(len(labels)))
ax_bottom.set_xticklabels(labels)
ax = ax_bottom
else:
cap = args.cap_y or _auto_cap(values)
plt.figure(figsize=(7.2, 4.2))
ax = plt.gca()
if cap is not None:
show_vals = [min(v, cap) for v in values]
bars = []
for i, (v, show) in enumerate(zip(values, show_vals)):
b = ax.bar(i, show, color=colors[i], width=0.8)
bars.append(b[0])
if v > cap:
bars[-1].set_hatch("//")
ax.text(i, cap * 1.02, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
else:
ax.text(
i,
show + max(1.0, 0.01 * (cap or show)),
_fmt_ms(v),
ha="center",
va="bottom",
fontsize=9,
)
ax.set_ylim(0, cap * 1.10)
ax.plot(
[0.02 - 0.02, 0.02 + 0.02],
[0.98 + 0.02, 0.98 - 0.02],
transform=ax.transAxes,
color="k",
lw=1,
)
ax.plot(
[0.98 - 0.02, 0.98 + 0.02],
[0.98 + 0.02, 0.98 - 0.02],
transform=ax.transAxes,
color="k",
lw=1,
)
if any(v > cap for v in values):
ax.legend(
[bars[0]], ["capped"], fontsize=8, frameon=False, loc="upper right"
)
ax.set_xticks(range(len(labels)))
ax.set_xticklabels(labels)
else:
ax.bar(labels, values, color=colors[: len(labels)])
for idx, val in enumerate(values):
ax.text(idx, val + 1.0, f"{val:.1f}", ha="center", va="bottom")
plt.ylabel("Average add latency (ms per passage)")
plt.title(f"Initial passages {args.max_initial}, updates {args.max_updates}")
plt.tight_layout()
plt.savefig(args.plot_path)
print(f"Saved latency bar plot to {args.plot_path}")
# ZMQ time split (Stage A vs B/C)
try:
plt.figure(figsize=(6, 4))
a_vals = [sum(results_stageA[n]) / max(1, len(results_stageA[n])) for n in labels]
bc_vals = [
sum(results_stageBC[n]) / max(1, len(results_stageBC[n])) for n in labels
]
ind = range(len(labels))
plt.bar(ind, a_vals, color="#4e79a7", label="Stage A distance (s)")
plt.bar(
ind, bc_vals, bottom=a_vals, color="#e15759", label="Stage B/C embed-by-id (s)"
)
plt.xticks(list(ind), labels, rotation=10)
plt.ylabel("Server ZMQ time (s)")
plt.title(
f"ZMQ time split (initial {args.max_initial}, updates {args.max_updates})"
)
plt.legend()
out2 = args.plot_path.with_name(
args.plot_path.stem + "_zmq_split" + args.plot_path.suffix
)
plt.tight_layout()
plt.savefig(out2)
print(f"Saved ZMQ time split plot to {out2}")
except Exception as e:
print("Failed to plot ZMQ split:", e)
except ImportError:
print("matplotlib not available; skipping plot generation")
# leave the last build on disk for inspection
if __name__ == "__main__":
main()

View File

@@ -1,5 +0,0 @@
run_id,scenario,cache_enabled,ef_construction,max_initial,max_updates,total_time_s,add_only_s,latency_ms_per_passage,zmq_nodes,stageA_time_s,stageBC_time_s,model_name,embedding_mode,distance_metric
20251024-133101,baseline,1,200,300,1,3.391856,1.120359,1120.359421,126,0.507821,0.601608,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251024-133101,no_cache_baseline,0,200,300,1,34.941514,32.91376,32913.760185,4033,0.506933,32.159928,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251024-133101,disable_forward_rng,1,200,300,1,2.746756,0.8202,820.200443,66,0.474354,0.338454,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251024-133101,disable_forward_and_reverse_rng,1,200,300,1,2.396566,0.521478,521.478415,1,0.508973,0.006938,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips

View File

@@ -1,704 +0,0 @@
"""
Compare two latency models for small incremental updates vs. search:
Scenario A (sequential update then search):
- Build initial HNSW (is_recompute=True)
- Start embedding server (ZMQ) for recompute
- Add N passages one-by-one (each triggers recompute over ZMQ)
- Then run a search query on the updated index
- Report total time = sum(add_i) + search_time, with breakdowns
Scenario B (offline embeds + concurrent search; no graph updates):
- Do NOT insert the N passages into the graph
- In parallel: (1) compute embeddings for the N passages; (2) compute query
embedding and run a search on the existing index
- After both finish, compute similarity between the query embedding and the N
new passage embeddings, merge with the index search results by score, and
report time = max(embed_time, search_time) (i.e., no blocking on updates)
This script reuses the model/data loading conventions of
examples/bench_hnsw_rng_recompute.py but focuses on end-to-end latency
comparison for the two execution strategies above.
Example (from the repository root):
uv run -m benchmarks.update.bench_update_vs_offline_search \
--index-path .leann/bench/offline_vs_update.leann \
--max-initial 300 --num-updates 5 --k 10
"""
import argparse
import csv
import json
import logging
import os
import pickle
import sys
import threading
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import numpy as np
import psutil # type: ignore
from leann.api import LeannBuilder
if os.environ.get("LEANN_FORCE_CPU", "").lower() in ("1", "true", "yes"):
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
from leann.embedding_compute import compute_embeddings
from leann.embedding_server_manager import EmbeddingServerManager
from leann.registry import register_project_directory
from leann_backend_hnsw import faiss # type: ignore
logger = logging.getLogger(__name__)
if not logging.getLogger().handlers:
logging.basicConfig(level=logging.INFO)
def _find_repo_root() -> Path:
"""Locate project root by walking up until pyproject.toml is found."""
current = Path(__file__).resolve()
for parent in current.parents:
if (parent / "pyproject.toml").exists():
return parent
# Fallback: assume repo is two levels up (../..)
return current.parents[2]
REPO_ROOT = _find_repo_root()
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
from apps.chunking import create_text_chunks # noqa: E402
DEFAULT_INITIAL_FILES = [
REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
REPO_ROOT / "data" / "huawei_pangu.md",
]
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
def load_chunks_from_files(paths: list[Path], limit: int | None = None) -> list[str]:
from llama_index.core import SimpleDirectoryReader
documents = []
for path in paths:
p = path.expanduser().resolve()
if not p.exists():
raise FileNotFoundError(f"Input path not found: {p}")
if p.is_dir():
reader = SimpleDirectoryReader(str(p), recursive=False)
documents.extend(reader.load_data(show_progress=True))
else:
reader = SimpleDirectoryReader(input_files=[str(p)])
documents.extend(reader.load_data(show_progress=True))
if not documents:
return []
chunks = create_text_chunks(
documents,
chunk_size=512,
chunk_overlap=128,
use_ast_chunking=False,
)
cleaned = [c for c in chunks if isinstance(c, str) and c.strip()]
if limit is not None:
cleaned = cleaned[:limit]
return cleaned
def ensure_index_dir(index_path: Path) -> None:
index_path.parent.mkdir(parents=True, exist_ok=True)
def cleanup_index_files(index_path: Path) -> None:
parent = index_path.parent
if not parent.exists():
return
stem = index_path.stem
for file in parent.glob(f"{stem}*"):
if file.is_file():
file.unlink()
def build_initial_index(
index_path: Path,
paragraphs: list[str],
model_name: str,
embedding_mode: str,
distance_metric: str,
ef_construction: int,
) -> None:
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=model_name,
embedding_mode=embedding_mode,
is_compact=False,
is_recompute=True,
distance_metric=distance_metric,
backend_kwargs={
"distance_metric": distance_metric,
"is_compact": False,
"is_recompute": True,
"efConstruction": ef_construction,
},
)
for idx, passage in enumerate(paragraphs):
builder.add_text(passage, metadata={"id": str(idx)})
builder.build_index(str(index_path))
def _maybe_norm_cosine(vecs: np.ndarray, metric: str) -> np.ndarray:
if metric == "cosine":
vecs = np.ascontiguousarray(vecs, dtype=np.float32)
norms = np.linalg.norm(vecs, axis=1, keepdims=True)
norms[norms == 0] = 1
vecs = vecs / norms
return vecs
def _read_index_for_search(index_path: Path) -> Any:
index_file = index_path.parent / f"{index_path.stem}.index"
# Force-disable experimental disk cache when loading the index so that
# incremental benchmarks don't pick up stale top-degree bitmaps.
cfg = faiss.HNSWIndexConfig()
cfg.is_recompute = True
if hasattr(cfg, "disk_cache_ratio"):
cfg.disk_cache_ratio = 0.0
if hasattr(cfg, "external_storage_path"):
cfg.external_storage_path = None
io_flags = getattr(faiss, "IO_FLAG_MMAP", 0)
index = faiss.read_index(str(index_file), io_flags, cfg)
# ensure recompute mode persists after reload
try:
index.is_recompute = True
except AttributeError:
pass
try:
actual_ntotal = index.hnsw.levels.size()
except AttributeError:
actual_ntotal = index.ntotal
if actual_ntotal != index.ntotal:
print(
f"[bench_update_vs_offline_search] Correcting ntotal from {index.ntotal} to {actual_ntotal}",
flush=True,
)
index.ntotal = actual_ntotal
if getattr(index, "storage", None) is None:
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
storage_index = faiss.IndexFlatIP(index.d)
else:
storage_index = faiss.IndexFlatL2(index.d)
index.storage = storage_index
index.own_fields = True
return index
def _append_passages_for_updates(
meta_path: Path,
start_id: int,
texts: list[str],
) -> list[str]:
"""Append update passages so the embedding server can serve recompute fetches."""
if not texts:
return []
index_dir = meta_path.parent
meta_name = meta_path.name
if not meta_name.endswith(".meta.json"):
raise ValueError(f"Unexpected meta filename: {meta_path}")
index_base = meta_name[: -len(".meta.json")]
passages_file = index_dir / f"{index_base}.passages.jsonl"
offsets_file = index_dir / f"{index_base}.passages.idx"
if not passages_file.exists() or not offsets_file.exists():
raise FileNotFoundError(
"Passage store missing; cannot register update passages for recompute mode."
)
with open(offsets_file, "rb") as f:
offset_map: dict[str, int] = pickle.load(f)
assigned_ids: list[str] = []
with open(passages_file, "a", encoding="utf-8") as f:
for i, text in enumerate(texts):
passage_id = str(start_id + i)
offset = f.tell()
json.dump({"id": passage_id, "text": text, "metadata": {}}, f, ensure_ascii=False)
f.write("\n")
offset_map[passage_id] = offset
assigned_ids.append(passage_id)
with open(offsets_file, "wb") as f:
pickle.dump(offset_map, f)
try:
with open(meta_path, encoding="utf-8") as f:
meta = json.load(f)
except json.JSONDecodeError:
meta = {}
meta["total_passages"] = len(offset_map)
with open(meta_path, "w", encoding="utf-8") as f:
json.dump(meta, f, indent=2)
return assigned_ids
def _search(index: Any, q: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]:
q = np.ascontiguousarray(q, dtype=np.float32)
distances = np.zeros((1, k), dtype=np.float32)
indices = np.zeros((1, k), dtype=np.int64)
index.search(
1,
faiss.swig_ptr(q),
k,
faiss.swig_ptr(distances),
faiss.swig_ptr(indices),
)
return distances[0], indices[0]
def _score_for_metric(dist: float, metric: str) -> float:
# Convert FAISS distance to a "higher is better" score
if metric in ("mips", "cosine"):
return float(dist)
# l2 distance (smaller better) -> negative distance as score
return -float(dist)
def _merge_results(
index_results: tuple[np.ndarray, np.ndarray],
offline_scores: list[tuple[int, float]],
k: int,
metric: str,
) -> list[tuple[str, float]]:
distances, indices = index_results
merged: list[tuple[str, float]] = []
for distance, idx in zip(distances.tolist(), indices.tolist()):
merged.append((f"idx:{idx}", _score_for_metric(distance, metric)))
for j, s in offline_scores:
merged.append((f"offline:{j}", s))
merged.sort(key=lambda x: x[1], reverse=True)
return merged[:k]
@dataclass
class ScenarioResult:
name: str
update_total_s: float
search_s: float
overall_s: float
def main() -> None:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--index-path",
type=Path,
default=Path(".leann/bench/offline-vs-update.leann"),
)
parser.add_argument(
"--initial-files",
nargs="*",
type=Path,
default=DEFAULT_INITIAL_FILES,
)
parser.add_argument(
"--update-files",
nargs="*",
type=Path,
default=DEFAULT_UPDATE_FILES,
)
parser.add_argument("--max-initial", type=int, default=300)
parser.add_argument("--num-updates", type=int, default=5)
parser.add_argument("--k", type=int, default=10, help="Top-k for search/merge")
parser.add_argument(
"--query",
type=str,
default="neural network",
help="Query text used for the search benchmark.",
)
parser.add_argument("--server-port", type=int, default=5557)
parser.add_argument("--add-timeout", type=int, default=600)
parser.add_argument("--model-name", default="sentence-transformers/all-MiniLM-L6-v2")
parser.add_argument("--embedding-mode", default="sentence-transformers")
parser.add_argument(
"--distance-metric",
default="mips",
choices=["mips", "l2", "cosine"],
)
parser.add_argument("--ef-construction", type=int, default=200)
parser.add_argument(
"--only",
choices=["A", "B", "both"],
default="both",
help="Run only Scenario A, Scenario B, or both",
)
parser.add_argument(
"--csv-path",
type=Path,
default=Path("benchmarks/update/offline_vs_update.csv"),
help="Where to append results (CSV).",
)
args = parser.parse_args()
register_project_directory(REPO_ROOT)
# Load data
initial_paragraphs = load_chunks_from_files(args.initial_files, args.max_initial)
update_paragraphs = load_chunks_from_files(args.update_files, None)
if not update_paragraphs:
raise ValueError("No update passages loaded from --update-files")
update_paragraphs = update_paragraphs[: args.num_updates]
if len(update_paragraphs) < args.num_updates:
raise ValueError(
f"Not enough update passages ({len(update_paragraphs)}) for --num-updates={args.num_updates}"
)
ensure_index_dir(args.index_path)
cleanup_index_files(args.index_path)
# Build initial index
build_initial_index(
args.index_path,
initial_paragraphs,
args.model_name,
args.embedding_mode,
args.distance_metric,
args.ef_construction,
)
# Prepare index object and meta
meta_path = args.index_path.parent / f"{args.index_path.name}.meta.json"
index = _read_index_for_search(args.index_path)
# CSV setup
run_id = time.strftime("%Y%m%d-%H%M%S")
if args.csv_path:
args.csv_path.parent.mkdir(parents=True, exist_ok=True)
csv_fields = [
"run_id",
"scenario",
"max_initial",
"num_updates",
"k",
"total_time_s",
"add_total_s",
"search_time_s",
"emb_time_s",
"makespan_s",
"model_name",
"embedding_mode",
"distance_metric",
]
if not args.csv_path.exists() or args.csv_path.stat().st_size == 0:
with args.csv_path.open("w", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=csv_fields)
writer.writeheader()
# Debug: list existing HNSW server PIDs before starting
try:
existing = [
p
for p in psutil.process_iter(attrs=["pid", "cmdline"])
if any(
isinstance(arg, str) and "leann_backend_hnsw.hnsw_embedding_server" in arg
for arg in (p.info.get("cmdline") or [])
)
]
if existing:
print("[debug] Found existing hnsw_embedding_server processes before run:")
for p in existing:
print(f"[debug] PID={p.info['pid']} cmd={' '.join(p.info.get('cmdline') or [])}")
except Exception as _e:
pass
add_total = 0.0
search_after_add = 0.0
total_seq = 0.0
port_a = None
if args.only in ("A", "both"):
# Scenario A: sequential update then search
start_id = index.ntotal
assigned_ids = _append_passages_for_updates(meta_path, start_id, update_paragraphs)
if assigned_ids:
logger.debug(
"Registered %d update passages starting at id %s",
len(assigned_ids),
assigned_ids[0],
)
server_manager = EmbeddingServerManager(
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
)
ok, port = server_manager.start_server(
port=args.server_port,
model_name=args.model_name,
embedding_mode=args.embedding_mode,
passages_file=str(meta_path),
distance_metric=args.distance_metric,
)
if not ok:
raise RuntimeError("Failed to start embedding server")
try:
# Set ZMQ port for recompute mode
if hasattr(index.hnsw, "set_zmq_port"):
index.hnsw.set_zmq_port(port)
elif hasattr(index, "set_zmq_port"):
index.set_zmq_port(port)
# Start A overall timer BEFORE computing update embeddings
t0 = time.time()
# Compute embeddings for updates (counted into A's overall)
t_emb0 = time.time()
upd_embs = compute_embeddings(
update_paragraphs,
args.model_name,
mode=args.embedding_mode,
is_build=False,
batch_size=16,
)
emb_time_updates = time.time() - t_emb0
upd_embs = np.asarray(upd_embs, dtype=np.float32)
upd_embs = _maybe_norm_cosine(upd_embs, args.distance_metric)
# Perform sequential adds
for i in range(upd_embs.shape[0]):
t_add0 = time.time()
index.add(1, faiss.swig_ptr(upd_embs[i : i + 1]))
add_total += time.time() - t_add0
# Don't persist index after adds to avoid contaminating Scenario B
# index_file = args.index_path.parent / f"{args.index_path.stem}.index"
# faiss.write_index(index, str(index_file))
# Search after updates
q_emb = compute_embeddings(
[args.query], args.model_name, mode=args.embedding_mode, is_build=False
)
q_emb = np.asarray(q_emb, dtype=np.float32)
q_emb = _maybe_norm_cosine(q_emb, args.distance_metric)
# Warm up search with a dummy query first
print("[DEBUG] Warming up search...")
_ = _search(index, q_emb, 1)
t_s0 = time.time()
D_upd, I_upd = _search(index, q_emb, args.k)
search_after_add = time.time() - t_s0
total_seq = time.time() - t0
finally:
server_manager.stop_server()
port_a = port
print("\n=== Scenario A: update->search (sequential) ===")
# emb_time_updates is defined only when A runs
try:
_emb_a = emb_time_updates
except NameError:
_emb_a = 0.0
print(
f"Adds: {args.num_updates} passages; embeds={_emb_a:.3f}s; add_total={add_total:.3f}s; "
f"search={search_after_add:.3f}s; overall={total_seq:.3f}s"
)
# CSV row for A
if args.csv_path:
row_a = {
"run_id": run_id,
"scenario": "A",
"max_initial": args.max_initial,
"num_updates": args.num_updates,
"k": args.k,
"total_time_s": round(total_seq, 6),
"add_total_s": round(add_total, 6),
"search_time_s": round(search_after_add, 6),
"emb_time_s": round(_emb_a, 6),
"makespan_s": 0.0,
"model_name": args.model_name,
"embedding_mode": args.embedding_mode,
"distance_metric": args.distance_metric,
}
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=csv_fields)
writer.writerow(row_a)
# Verify server cleanup
try:
# short sleep to allow signal handling to finish
time.sleep(0.5)
leftovers = [
p
for p in psutil.process_iter(attrs=["pid", "cmdline"])
if any(
isinstance(arg, str) and "leann_backend_hnsw.hnsw_embedding_server" in arg
for arg in (p.info.get("cmdline") or [])
)
]
if leftovers:
print("[warn] hnsw_embedding_server process(es) still alive after A-stop:")
for p in leftovers:
print(
f"[warn] PID={p.info['pid']} cmd={' '.join(p.info.get('cmdline') or [])}"
)
else:
print("[debug] server cleanup confirmed: no hnsw_embedding_server found")
except Exception:
pass
# Scenario B: offline embeds + concurrent search (no graph updates)
if args.only in ("B", "both"):
# ensure a server is available for recompute search
server_manager_b = EmbeddingServerManager(
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
)
requested_port = args.server_port if port_a is None else port_a
ok_b, port_b = server_manager_b.start_server(
port=requested_port,
model_name=args.model_name,
embedding_mode=args.embedding_mode,
passages_file=str(meta_path),
distance_metric=args.distance_metric,
)
if not ok_b:
raise RuntimeError("Failed to start embedding server for Scenario B")
# Wait for server to fully initialize
print("[DEBUG] Waiting 2s for embedding server to fully initialize...")
time.sleep(2)
try:
# Read the index first
index_no_update = _read_index_for_search(args.index_path) # unchanged index
# Then configure ZMQ port on the correct index object
if hasattr(index_no_update.hnsw, "set_zmq_port"):
index_no_update.hnsw.set_zmq_port(port_b)
elif hasattr(index_no_update, "set_zmq_port"):
index_no_update.set_zmq_port(port_b)
# Warmup the embedding model before benchmarking (do this for both --only B and --only both)
# This ensures fair comparison as Scenario A has warmed up the model during update embeddings
logger.info("Warming up embedding model for Scenario B...")
_ = compute_embeddings(
["warmup text"], args.model_name, mode=args.embedding_mode, is_build=False
)
# Prepare worker A: compute embeddings for the same N passages
emb_time = 0.0
updates_embs_offline: np.ndarray | None = None
def _worker_emb():
nonlocal emb_time, updates_embs_offline
t = time.time()
updates_embs_offline = compute_embeddings(
update_paragraphs,
args.model_name,
mode=args.embedding_mode,
is_build=False,
batch_size=16,
)
emb_time = time.time() - t
# Pre-compute query embedding and warm up search outside of timed section.
q_vec = compute_embeddings(
[args.query], args.model_name, mode=args.embedding_mode, is_build=False
)
q_vec = np.asarray(q_vec, dtype=np.float32)
q_vec = _maybe_norm_cosine(q_vec, args.distance_metric)
print("[DEBUG B] Warming up search...")
_ = _search(index_no_update, q_vec, 1)
# Worker B: timed search on the warmed index
search_time = 0.0
offline_elapsed = 0.0
index_results: tuple[np.ndarray, np.ndarray] | None = None
def _worker_search():
nonlocal search_time, index_results
t = time.time()
distances, indices = _search(index_no_update, q_vec, args.k)
search_time = time.time() - t
index_results = (distances, indices)
# Run two workers concurrently
t0 = time.time()
th1 = threading.Thread(target=_worker_emb)
th2 = threading.Thread(target=_worker_search)
th1.start()
th2.start()
th1.join()
th2.join()
offline_elapsed = time.time() - t0
# For mixing: compute query vs. offline update similarities (pure client-side)
offline_scores: list[tuple[int, float]] = []
if updates_embs_offline is not None:
upd2 = np.asarray(updates_embs_offline, dtype=np.float32)
upd2 = _maybe_norm_cosine(upd2, args.distance_metric)
# For mips/cosine, score = dot; for l2, score = -||x-y||^2
for j in range(upd2.shape[0]):
if args.distance_metric in ("mips", "cosine"):
s = float(np.dot(q_vec[0], upd2[j]))
else:
diff = q_vec[0] - upd2[j]
s = -float(np.dot(diff, diff))
offline_scores.append((j, s))
merged_topk = (
_merge_results(index_results, offline_scores, args.k, args.distance_metric)
if index_results
else []
)
print("\n=== Scenario B: offline embeds + concurrent search (no add) ===")
print(
f"embeddings({args.num_updates})={emb_time:.3f}s; search={search_time:.3f}s; makespan≈{offline_elapsed:.3f}s (≈max)"
)
if merged_topk:
preview = ", ".join([f"{lab}:{score:.3f}" for lab, score in merged_topk[:5]])
print(f"Merged top-5 preview: {preview}")
# CSV row for B
if args.csv_path:
row_b = {
"run_id": run_id,
"scenario": "B",
"max_initial": args.max_initial,
"num_updates": args.num_updates,
"k": args.k,
"total_time_s": 0.0,
"add_total_s": 0.0,
"search_time_s": round(search_time, 6),
"emb_time_s": round(emb_time, 6),
"makespan_s": round(offline_elapsed, 6),
"model_name": args.model_name,
"embedding_mode": args.embedding_mode,
"distance_metric": args.distance_metric,
}
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=csv_fields)
writer.writerow(row_b)
finally:
server_manager_b.stop_server()
# Summary
print("\n=== Summary ===")
msg_a = (
f"A: seq-add+search overall={total_seq:.3f}s (adds={add_total:.3f}s, search={search_after_add:.3f}s)"
if args.only in ("A", "both")
else "A: skipped"
)
msg_b = (
f"B: offline+concurrent overall≈{offline_elapsed:.3f}s (emb={emb_time:.3f}s, search={search_time:.3f}s)"
if args.only in ("B", "both")
else "B: skipped"
)
print(msg_a + "\n" + msg_b)
if __name__ == "__main__":
main()

View File

@@ -1,5 +0,0 @@
run_id,scenario,max_initial,num_updates,k,total_time_s,add_total_s,search_time_s,emb_time_s,makespan_s,model_name,embedding_mode,distance_metric
20251024-141607,A,300,1,10,3.273957,3.050168,0.097825,0.017339,0.0,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251024-141607,B,300,1,10,0.0,0.0,0.111892,0.007869,0.112635,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251025-160652,A,300,5,10,5.061945,4.805962,0.123271,0.015008,0.0,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251025-160652,B,300,5,10,0.0,0.0,0.101809,0.008817,0.102447,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips

View File

@@ -1,645 +0,0 @@
#!/usr/bin/env python3
"""
Plot latency bars from the benchmark CSV produced by
benchmarks/update/bench_hnsw_rng_recompute.py.
If you also provide an offline_vs_update.csv via --csv-right
(from benchmarks/update/bench_update_vs_offline_search.py), this script will
output a side-by-side figure:
- Left: ms/passage bars (four RNG scenarios).
- Right: seconds bars (Scenario A seq add+search vs Scenario B offline+search).
Usage:
uv run python benchmarks/update/plot_bench_results.py \
--csv benchmarks/update/bench_results.csv \
--out benchmarks/update/bench_latency_from_csv.png
The script selects the latest run_id in the CSV and plots four bars for
the default scenarios:
- baseline
- no_cache_baseline
- disable_forward_rng
- disable_forward_and_reverse_rng
If multiple rows exist per scenario for that run_id, the script averages
their latency_ms_per_passage values.
"""
import argparse
import csv
from collections import defaultdict
from pathlib import Path
DEFAULT_SCENARIOS = [
"no_cache_baseline",
"baseline",
"disable_forward_rng",
"disable_forward_and_reverse_rng",
]
SCENARIO_LABELS = {
"baseline": "+ Cache",
"no_cache_baseline": "Naive \n Recompute",
"disable_forward_rng": "+ w/o \n Fwd RNG",
"disable_forward_and_reverse_rng": "+ w/o \n Bwd RNG",
}
# Paper-style colors and hatches for scenarios
SCENARIO_STYLES = {
"no_cache_baseline": {"edgecolor": "dimgrey", "hatch": "/////"},
"baseline": {"edgecolor": "#63B8B6", "hatch": "xxxxx"},
"disable_forward_rng": {"edgecolor": "green", "hatch": "....."},
"disable_forward_and_reverse_rng": {"edgecolor": "tomato", "hatch": "\\\\\\\\\\"},
}
def load_latest_run(csv_path: Path):
rows = []
with csv_path.open("r", encoding="utf-8") as f:
reader = csv.DictReader(f)
for row in reader:
rows.append(row)
if not rows:
raise SystemExit("CSV is empty: no rows to plot")
# Choose latest run_id lexicographically (YYYYMMDD-HHMMSS)
run_ids = [r.get("run_id", "") for r in rows]
latest = max(run_ids)
latest_rows = [r for r in rows if r.get("run_id", "") == latest]
if not latest_rows:
# Fallback: take last 4 rows
latest_rows = rows[-4:]
latest = latest_rows[-1].get("run_id", "unknown")
return latest, latest_rows
def aggregate_latency(rows):
acc = defaultdict(list)
for r in rows:
sc = r.get("scenario", "")
try:
val = float(r.get("latency_ms_per_passage", "nan"))
except ValueError:
continue
acc[sc].append(val)
avg = {k: (sum(v) / len(v) if v else 0.0) for k, v in acc.items()}
return avg
def _auto_cap(values: list[float]) -> float | None:
if not values:
return None
sorted_vals = sorted(values, reverse=True)
if len(sorted_vals) < 2:
return None
max_v, second = sorted_vals[0], sorted_vals[1]
if second <= 0:
return None
# If the tallest bar dwarfs the second by 2.5x+, cap near the second
if max_v >= 2.5 * second:
return second * 1.1
return None
def _add_break_marker(ax, y, rel_x0=0.02, rel_x1=0.98, size=0.02):
# Draw small diagonal ticks near left/right to signal cap
x0, x1 = rel_x0, rel_x1
ax.plot([x0 - size, x0 + size], [y + size, y - size], transform=ax.transAxes, color="k", lw=1)
ax.plot([x1 - size, x1 + size], [y + size, y - size], transform=ax.transAxes, color="k", lw=1)
def _fmt_ms(v: float) -> str:
if v >= 1000:
return f"{v / 1000:.1f}k"
return f"{v:.1f}"
def main():
# Set LaTeX style for paper figures (matching paper_fig.py)
import matplotlib.pyplot as plt
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["hatch.linewidth"] = 1.5
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True
ap = argparse.ArgumentParser(description=__doc__)
ap.add_argument(
"--csv",
type=Path,
default=Path("benchmarks/update/bench_results.csv"),
help="Path to results CSV (defaults to bench_results.csv)",
)
ap.add_argument(
"--out",
type=Path,
default=Path("add_ablation.pdf"),
help="Output image path",
)
ap.add_argument(
"--csv-right",
type=Path,
default=Path("benchmarks/update/offline_vs_update.csv"),
help="Optional: offline_vs_update.csv to render right subplot (A vs B)",
)
ap.add_argument(
"--cap-y",
type=float,
default=None,
help="Cap Y-axis at this ms value; bars above are hatched and annotated.",
)
ap.add_argument(
"--no-auto-cap",
action="store_true",
help="Disable auto-cap heuristic when --cap-y is not provided.",
)
ap.add_argument(
"--broken-y",
action="store_true",
default=True,
help="Use a broken Y-axis (two stacked axes with a gap). Overrides --cap-y unless both provided.",
)
ap.add_argument(
"--lower-cap-y",
type=float,
default=None,
help="Lower axes upper bound for broken Y (ms). Default = 1.1x second-highest.",
)
ap.add_argument(
"--upper-start-y",
type=float,
default=None,
help="Upper axes lower bound for broken Y (ms). Default = 1.2x second-highest.",
)
args = ap.parse_args()
latest_run, latest_rows = load_latest_run(args.csv)
avg = aggregate_latency(latest_rows)
try:
import matplotlib.pyplot as plt
except Exception as e:
raise SystemExit(f"matplotlib not available: {e}")
scenarios = DEFAULT_SCENARIOS
values = [avg.get(name, 0.0) for name in scenarios]
labels = [SCENARIO_LABELS.get(name, name) for name in scenarios]
colors = ["#4e79a7", "#f28e2c", "#e15759", "#76b7b2"]
# If right CSV is provided, build side-by-side figure
if args.csv_right is not None:
try:
right_rows_all = []
with args.csv_right.open("r", encoding="utf-8") as f:
rreader = csv.DictReader(f)
right_rows_all = list(rreader)
if right_rows_all:
r_latest = max(r.get("run_id", "") for r in right_rows_all)
right_rows = [r for r in right_rows_all if r.get("run_id", "") == r_latest]
else:
r_latest = None
right_rows = []
except Exception:
r_latest = None
right_rows = []
a_total = 0.0
b_makespan = 0.0
for r in right_rows:
sc = (r.get("scenario", "") or "").strip().upper()
if sc == "A":
try:
a_total = float(r.get("total_time_s", 0.0))
except Exception:
pass
elif sc == "B":
try:
b_makespan = float(r.get("makespan_s", 0.0))
except Exception:
pass
import matplotlib.pyplot as plt
from matplotlib import gridspec
# Left subplot (reuse current style, with optional cap)
cap = args.cap_y
if cap is None and not args.no_auto_cap:
cap = _auto_cap(values)
x = list(range(len(labels)))
if args.broken_y:
# Use broken axis for left subplot
# Auto-adjust width ratios: left has 4 bars, right has 2 bars
fig = plt.figure(figsize=(4.8, 1.8)) # Scaled down to 80%
gs = gridspec.GridSpec(
2, 2, height_ratios=[1, 3], width_ratios=[1.5, 1], hspace=0.08, wspace=0.35
)
ax_left_top = fig.add_subplot(gs[0, 0])
ax_left_bottom = fig.add_subplot(gs[1, 0], sharex=ax_left_top)
ax_right = fig.add_subplot(gs[:, 1])
# Determine break points
s = sorted(values, reverse=True)
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
lower_cap = (
args.lower_cap_y if args.lower_cap_y is not None else second * 1.4
) # Increased to show more range
upper_start = (
args.upper_start_y
if args.upper_start_y is not None
else max(second * 1.5, lower_cap * 1.02)
)
ymax = (
max(values) * 1.90 if values else 1.0
) # Increase headroom to 1.90 for text label and tick range
# Draw bars on both axes
ax_left_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
ax_left_top.bar(x, values, color=colors[: len(labels)], width=0.8)
# Set limits
ax_left_bottom.set_ylim(0, lower_cap)
ax_left_top.set_ylim(upper_start, ymax)
# Annotate values (convert ms to s)
values_s = [v / 1000.0 for v in values]
lower_cap_s = lower_cap / 1000.0
upper_start_s = upper_start / 1000.0
ymax_s = ymax / 1000.0
ax_left_bottom.set_ylim(0, lower_cap_s)
ax_left_top.set_ylim(upper_start_s, ymax_s)
# Redraw bars with s values (paper style: white fill + colored edge + hatch)
ax_left_bottom.clear()
ax_left_top.clear()
bar_width = 0.50 # Reduced for wider spacing between bars
for i, (scenario_name, v) in enumerate(zip(scenarios, values_s)):
style = SCENARIO_STYLES.get(scenario_name, {"edgecolor": "black", "hatch": ""})
# Draw in bottom axis for all bars
ax_left_bottom.bar(
i,
v,
width=bar_width,
color="white",
edgecolor=style["edgecolor"],
hatch=style["hatch"],
linewidth=1.2,
)
# Only draw in top axis if the bar is tall enough to reach the upper range
if v > upper_start_s:
ax_left_top.bar(
i,
v,
width=bar_width,
color="white",
edgecolor=style["edgecolor"],
hatch=style["hatch"],
linewidth=1.2,
)
ax_left_bottom.set_ylim(0, lower_cap_s)
ax_left_top.set_ylim(upper_start_s, ymax_s)
for i, v in enumerate(values_s):
if v <= lower_cap_s:
ax_left_bottom.text(
i,
v + lower_cap_s * 0.02,
f"{v:.2f}",
ha="center",
va="bottom",
fontsize=8,
fontweight="bold",
)
else:
ax_left_top.text(
i,
v + (ymax_s - upper_start_s) * 0.02,
f"{v:.2f}",
ha="center",
va="bottom",
fontsize=8,
fontweight="bold",
)
# Hide spines between axes
ax_left_top.spines["bottom"].set_visible(False)
ax_left_bottom.spines["top"].set_visible(False)
ax_left_top.tick_params(
labeltop=False, labelbottom=False, bottom=False
) # Hide tick marks
ax_left_bottom.xaxis.tick_bottom()
ax_left_bottom.tick_params(top=False) # Hide top tick marks
# Draw break marks (matching paper_fig.py style)
d = 0.015
kwargs = {
"transform": ax_left_top.transAxes,
"color": "k",
"clip_on": False,
"linewidth": 0.8,
"zorder": 10,
}
ax_left_top.plot((-d, +d), (-d, +d), **kwargs)
ax_left_top.plot((1 - d, 1 + d), (-d, +d), **kwargs)
kwargs.update({"transform": ax_left_bottom.transAxes})
ax_left_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)
ax_left_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)
ax_left_bottom.set_xticks(x)
ax_left_bottom.set_xticklabels(labels, rotation=0, fontsize=7)
# Don't set ylabel here - will use fig.text for alignment
ax_left_bottom.tick_params(axis="y", labelsize=10)
ax_left_top.tick_params(axis="y", labelsize=10)
# Add subtle grid for better readability
ax_left_bottom.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
ax_left_top.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
ax_left_top.set_title("Single Add Operation", fontsize=11, pad=10, fontweight="bold")
# Set x-axis limits to match bar width with right subplot
ax_left_bottom.set_xlim(-0.6, 3.6)
ax_left_top.set_xlim(-0.6, 3.6)
ax_left = ax_left_bottom # for compatibility
else:
# Regular side-by-side layout
fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(8.4, 3.15))
if cap is not None:
show_vals = [min(v, cap) for v in values]
bars = ax_left.bar(x, show_vals, color=colors[: len(labels)], width=0.8)
for i, (val, show) in enumerate(zip(values, show_vals)):
if val > cap:
bars[i].set_hatch("//")
ax_left.text(
i, cap * 1.02, _fmt_ms(val), ha="center", va="bottom", fontsize=9
)
else:
ax_left.text(
i,
show + max(1.0, 0.01 * (cap or show)),
_fmt_ms(val),
ha="center",
va="bottom",
fontsize=9,
)
ax_left.set_ylim(0, cap * 1.10)
_add_break_marker(ax_left, y=0.98)
ax_left.set_xticks(x)
ax_left.set_xticklabels(labels, rotation=0, fontsize=10)
else:
ax_left.bar(x, values, color=colors[: len(labels)], width=0.8)
for i, v in enumerate(values):
ax_left.text(i, v + 1.0, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
ax_left.set_xticks(x)
ax_left.set_xticklabels(labels, rotation=0, fontsize=10)
ax_left.set_ylabel("Latency (ms per passage)")
max_initial = latest_rows[0].get("max_initial", "?")
max_updates = latest_rows[0].get("max_updates", "?")
ax_left.set_title(
f"HNSW RNG (run {latest_run}) | init={max_initial}, upd={max_updates}"
)
# Right subplot (A vs B, seconds) - paper style
r_labels = ["Sequential", "Delayed \n Add+Search"]
r_values = [a_total or 0.0, b_makespan or 0.0]
r_styles = [
{"edgecolor": "#59a14f", "hatch": "xxxxx"},
{"edgecolor": "#edc948", "hatch": "/////"},
]
# 2 bars, centered with proper spacing
xr = [0, 1]
bar_width = 0.50 # Reduced for wider spacing between bars
for i, (v, style) in enumerate(zip(r_values, r_styles)):
ax_right.bar(
xr[i],
v,
width=bar_width,
color="white",
edgecolor=style["edgecolor"],
hatch=style["hatch"],
linewidth=1.2,
)
for i, v in enumerate(r_values):
max_v = max(r_values) if r_values else 1.0
offset = max(0.0002, 0.02 * max_v)
ax_right.text(
xr[i],
v + offset,
f"{v:.2f}",
ha="center",
va="bottom",
fontsize=8,
fontweight="bold",
)
ax_right.set_xticks(xr)
ax_right.set_xticklabels(r_labels, rotation=0, fontsize=7)
# Don't set ylabel here - will use fig.text for alignment
ax_right.tick_params(axis="y", labelsize=10)
# Add subtle grid for better readability
ax_right.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
ax_right.set_title("Batched Add Operation", fontsize=11, pad=10, fontweight="bold")
# Set x-axis limits to match left subplot's bar width visually
# Accounting for width_ratios=[1.5, 1]:
# Left: 4 bars, xlim(-0.6, 3.6), range=4.2, physical_width=1.5*unit
# bar_width_visual = 0.72 * (1.5*unit / 4.2)
# Right: 2 bars, need same visual width
# 0.72 * (1.0*unit / range_right) = 0.72 * (1.5*unit / 4.2)
# range_right = 4.2 / 1.5 = 2.8
# For bars at 0, 1: padding = (2.8 - 1) / 2 = 0.9
ax_right.set_xlim(-0.9, 1.9)
# Set y-axis limit with headroom for text labels
if r_values:
max_v = max(r_values)
ax_right.set_ylim(0, max_v * 1.15)
# Format y-axis to avoid scientific notation
ax_right.ticklabel_format(style="plain", axis="y")
plt.tight_layout()
# Add aligned ylabels using fig.text (after tight_layout)
# Get the vertical center of the entire figure
fig_center_y = 0.5
# Left ylabel - closer to left plot
left_x = 0.05
fig.text(
left_x,
fig_center_y,
"Latency (s)",
va="center",
rotation="vertical",
fontsize=11,
fontweight="bold",
)
# Right ylabel - closer to right plot
right_bbox = ax_right.get_position()
right_x = right_bbox.x0 - 0.07
fig.text(
right_x,
fig_center_y,
"Latency (s)",
va="center",
rotation="vertical",
fontsize=11,
fontweight="bold",
)
plt.savefig(args.out, bbox_inches="tight", pad_inches=0.05)
# Also save PDF for paper
pdf_out = args.out.with_suffix(".pdf")
plt.savefig(pdf_out, bbox_inches="tight", pad_inches=0.05)
print(f"Saved: {args.out}")
print(f"Saved: {pdf_out}")
return
# Broken-Y mode
if args.broken_y:
import matplotlib.pyplot as plt
fig, (ax_top, ax_bottom) = plt.subplots(
2,
1,
sharex=True,
figsize=(7.5, 6.75),
gridspec_kw={"height_ratios": [1, 3], "hspace": 0.08},
)
# Determine default breaks from second-highest
s = sorted(values, reverse=True)
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
lower_cap = args.lower_cap_y if args.lower_cap_y is not None else second * 1.1
upper_start = (
args.upper_start_y
if args.upper_start_y is not None
else max(second * 1.2, lower_cap * 1.02)
)
ymax = max(values) * 1.10 if values else 1.0
x = list(range(len(labels)))
ax_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
ax_top.bar(x, values, color=colors[: len(labels)], width=0.8)
# Limits
ax_bottom.set_ylim(0, lower_cap)
ax_top.set_ylim(upper_start, ymax)
# Annotate values
for i, v in enumerate(values):
if v <= lower_cap:
ax_bottom.text(
i, v + lower_cap * 0.02, _fmt_ms(v), ha="center", va="bottom", fontsize=9
)
else:
ax_top.text(i, v, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
# Hide spines between axes and draw diagonal break marks
ax_top.spines["bottom"].set_visible(False)
ax_bottom.spines["top"].set_visible(False)
ax_top.tick_params(labeltop=False) # don't put tick labels at the top
ax_bottom.xaxis.tick_bottom()
# Diagonal lines at the break (matching paper_fig.py style)
d = 0.015
kwargs = {
"transform": ax_top.transAxes,
"color": "k",
"clip_on": False,
"linewidth": 0.8,
"zorder": 10,
}
ax_top.plot((-d, +d), (-d, +d), **kwargs) # top-left diagonal
ax_top.plot((1 - d, 1 + d), (-d, +d), **kwargs) # top-right diagonal
kwargs.update({"transform": ax_bottom.transAxes})
ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs) # bottom-left diagonal
ax_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs) # bottom-right diagonal
ax_bottom.set_xticks(x)
ax_bottom.set_xticklabels(labels, rotation=0, fontsize=10)
ax = ax_bottom # for labeling below
else:
cap = args.cap_y
if cap is None and not args.no_auto_cap:
cap = _auto_cap(values)
plt.figure(figsize=(5.4, 3.15))
ax = plt.gca()
if cap is not None:
show_vals = [min(v, cap) for v in values]
bars = []
for i, (_label, val, show) in enumerate(zip(labels, values, show_vals)):
bar = ax.bar(i, show, color=colors[i], width=0.8)
bars.append(bar[0])
# Hatch and annotate when capped
if val > cap:
bars[-1].set_hatch("//")
ax.text(i, cap * 1.02, f"{_fmt_ms(val)}", ha="center", va="bottom", fontsize=9)
else:
ax.text(
i,
show + max(1.0, 0.01 * (cap or show)),
f"{_fmt_ms(val)}",
ha="center",
va="bottom",
fontsize=9,
)
ax.set_ylim(0, cap * 1.10)
_add_break_marker(ax, y=0.98)
ax.legend([bars[1]], ["capped"], fontsize=8, frameon=False, loc="upper right") if any(
v > cap for v in values
) else None
ax.set_xticks(range(len(labels)))
ax.set_xticklabels(labels, fontsize=11, fontweight="bold")
else:
ax.bar(labels, values, color=colors[: len(labels)])
for idx, val in enumerate(values):
ax.text(
idx,
val + 1.0,
f"{_fmt_ms(val)}",
ha="center",
va="bottom",
fontsize=10,
fontweight="bold",
)
ax.set_xticklabels(labels, fontsize=11, fontweight="bold")
# Try to extract some context for title
max_initial = latest_rows[0].get("max_initial", "?")
max_updates = latest_rows[0].get("max_updates", "?")
if args.broken_y:
fig.text(
0.02,
0.5,
"Latency (s)",
va="center",
rotation="vertical",
fontsize=11,
fontweight="bold",
)
fig.suptitle(
"Add Operation Latency",
fontsize=11,
y=0.98,
fontweight="bold",
)
plt.tight_layout(rect=(0.03, 0.04, 1, 0.96))
else:
plt.ylabel("Latency (s)", fontsize=11, fontweight="bold")
plt.title("Add Operation Latency", fontsize=11, fontweight="bold")
plt.tight_layout()
plt.savefig(args.out, bbox_inches="tight", pad_inches=0.05)
# Also save PDF for paper
pdf_out = args.out.with_suffix(".pdf")
plt.savefig(pdf_out, bbox_inches="tight", pad_inches=0.05)
print(f"Saved: {args.out}")
print(f"Saved: {pdf_out}")
if __name__ == "__main__":
main()

82
data/.gitattributes vendored Normal file
View File

@@ -0,0 +1,82 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.lz4 filter=lfs diff=lfs merge=lfs -text
*.mds filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
# Audio files - uncompressed
*.pcm filter=lfs diff=lfs merge=lfs -text
*.sam filter=lfs diff=lfs merge=lfs -text
*.raw filter=lfs diff=lfs merge=lfs -text
# Audio files - compressed
*.aac filter=lfs diff=lfs merge=lfs -text
*.flac filter=lfs diff=lfs merge=lfs -text
*.mp3 filter=lfs diff=lfs merge=lfs -text
*.ogg filter=lfs diff=lfs merge=lfs -text
*.wav filter=lfs diff=lfs merge=lfs -text
# Image files - uncompressed
*.bmp filter=lfs diff=lfs merge=lfs -text
*.gif filter=lfs diff=lfs merge=lfs -text
*.png filter=lfs diff=lfs merge=lfs -text
*.tiff filter=lfs diff=lfs merge=lfs -text
# Image files - compressed
*.jpg filter=lfs diff=lfs merge=lfs -text
*.jpeg filter=lfs diff=lfs merge=lfs -text
*.webp filter=lfs diff=lfs merge=lfs -text
# Video files - compressed
*.mp4 filter=lfs diff=lfs merge=lfs -text
*.webm filter=lfs diff=lfs merge=lfs -text
ground_truth/dpr/id_map.json filter=lfs diff=lfs merge=lfs -text
indices/dpr/dpr_diskann.passages.idx filter=lfs diff=lfs merge=lfs -text
indices/dpr/dpr_diskann.passages.jsonl filter=lfs diff=lfs merge=lfs -text
indices/dpr/dpr_diskann_disk.index filter=lfs diff=lfs merge=lfs -text
indices/dpr/leann.labels.map filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/leann.labels.map filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.index filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.0.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.0.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.1.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.1.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.2.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.2.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.3.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.3.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.4.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.4.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.5.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.5.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.6.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.6.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.7.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.7.jsonl filter=lfs diff=lfs merge=lfs -text

0
benchmarks/data/README.md → data/README.md Executable file → Normal file
View File

View File

@@ -53,9 +53,9 @@ We use pre-commit hooks to ensure code quality and consistency. This runs automa
### Setup Pre-commit
1. **Install pre-commit tools**:
1. **Install pre-commit** (already included when you run `uv sync`):
```bash
uv sync lint
uv pip install pre-commit
```
2. **Install the git hooks**:
@@ -65,7 +65,7 @@ We use pre-commit hooks to ensure code quality and consistency. This runs automa
3. **Run pre-commit manually** (optional):
```bash
uv run pre-commit run --all-files
pre-commit run --all-files
```
### Pre-commit Checks
@@ -85,9 +85,6 @@ Our pre-commit configuration includes:
### Running Tests
```bash
# Install test tools only (no project runtime)
uv sync --group test
# Run all tests
uv run pytest

View File

@@ -1,123 +0,0 @@
# Thinking Budget Feature Implementation
## Overview
This document describes the implementation of the **thinking budget** feature for LEANN, which allows users to control the computational effort for reasoning models like GPT-Oss:20b.
## Feature Description
The thinking budget feature provides three levels of computational effort for reasoning models:
- **`low`**: Fast responses, basic reasoning (default for simple queries)
- **`medium`**: Balanced speed and reasoning depth
- **`high`**: Maximum reasoning effort, best for complex analytical questions
## Implementation Details
### 1. Command Line Interface
Added `--thinking-budget` parameter to both CLI and RAG examples:
```bash
# LEANN CLI
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget high
# RAG Examples
python apps/email_rag.py --llm ollama --llm-model gpt-oss:20b --thinking-budget high
python apps/document_rag.py --llm openai --llm-model o3 --thinking-budget medium
```
### 2. LLM Backend Support
#### Ollama Backend (`packages/leann-core/src/leann/chat.py`)
```python
def ask(self, prompt: str, **kwargs) -> str:
# Handle thinking budget for reasoning models
options = kwargs.copy()
thinking_budget = kwargs.get("thinking_budget")
if thinking_budget:
options.pop("thinking_budget", None)
if thinking_budget in ["low", "medium", "high"]:
options["reasoning"] = {"effort": thinking_budget, "exclude": False}
```
**API Format**: Uses Ollama's `reasoning` parameter with `effort` and `exclude` fields.
#### OpenAI Backend (`packages/leann-core/src/leann/chat.py`)
```python
def ask(self, prompt: str, **kwargs) -> str:
# Handle thinking budget for reasoning models
thinking_budget = kwargs.get("thinking_budget")
if thinking_budget and thinking_budget in ["low", "medium", "high"]:
# Check if this is an o-series model
o_series_models = ["o3", "o3-mini", "o4-mini", "o1", "o3-pro", "o3-deep-research"]
if any(model in self.model for model in o_series_models):
params["reasoning_effort"] = thinking_budget
```
**API Format**: Uses OpenAI's `reasoning_effort` parameter for o-series models.
### 3. Parameter Propagation
The thinking budget parameter is properly propagated through the LEANN architecture:
1. **CLI** (`packages/leann-core/src/leann/cli.py`): Captures `--thinking-budget` argument
2. **Base RAG** (`apps/base_rag_example.py`): Adds parameter to argument parser
3. **LeannChat** (`packages/leann-core/src/leann/api.py`): Passes `llm_kwargs` to LLM
4. **LLM Interface**: Handles the parameter in backend-specific implementations
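A minimal end-to-end sketch of this flow from Python (the `LeannChat` constructor arguments and `ask` signature below are illustrative assumptions; only the `llm_kwargs`/`thinking_budget` handling is taken from the snippets above):
```python
from leann.api import LeannChat  # module path from step 3 above

# Hypothetical constructor arguments -- check the actual LeannChat API for exact names.
chat = LeannChat(
    index_path="./indexes/my-notes.leann",
    llm_config={"type": "ollama", "model": "gpt-oss:20b"},
)
# llm_kwargs travel to the backend, where OllamaChat turns thinking_budget into
# {"reasoning": {"effort": "high", "exclude": False}} and OpenAIChat into reasoning_effort.
answer = chat.ask(
    "Compare the retrieval strategies discussed in these notes.",
    llm_kwargs={"thinking_budget": "high"},
)
print(answer)
```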
## Files Modified
### Core Implementation
- `packages/leann-core/src/leann/chat.py`: Added thinking budget support to OllamaChat and OpenAIChat
- `packages/leann-core/src/leann/cli.py`: Added `--thinking-budget` argument
- `apps/base_rag_example.py`: Added thinking budget parameter to RAG examples
### Documentation
- `README.md`: Added thinking budget parameter to usage examples
- `docs/configuration-guide.md`: Added detailed documentation and usage guidelines
### Examples
- `examples/thinking_budget_demo.py`: Comprehensive demo script with usage examples
## Usage Examples
### Basic Usage
```bash
# High reasoning effort for complex questions
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget high
# Medium reasoning for balanced performance
leann ask my-index --llm openai --model gpt-4o --thinking-budget medium
# Low reasoning for fast responses
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget low
```
### RAG Examples
```bash
# Email RAG with high reasoning
python apps/email_rag.py --llm ollama --llm-model gpt-oss:20b --thinking-budget high
# Document RAG with medium reasoning
python apps/document_rag.py --llm openai --llm-model gpt-4o --thinking-budget medium
```
## Supported Models
### Ollama Models
- **GPT-Oss:20b**: Primary target model with reasoning capabilities
- **Other reasoning models**: Any Ollama model that supports the `reasoning` parameter
### OpenAI Models
- **o3, o3-mini, o4-mini, o1**: o-series reasoning models with `reasoning_effort` parameter
- **GPT-OSS models**: Models that support reasoning capabilities
## Testing
The implementation includes comprehensive testing:
- Parameter handling verification
- Backend-specific API format validation
- CLI argument parsing tests
- Integration with existing LEANN architecture

View File

@@ -1,143 +0,0 @@
# AST-Aware Code chunking guide
## Overview
This guide covers best practices for using AST-aware code chunking in LEANN. AST chunking provides better semantic understanding of code structure compared to traditional text-based chunking.
## Quick Start
### Basic Usage
```bash
# Enable AST chunking for mixed content (code + docs)
python -m apps.document_rag --enable-code-chunking --data-dir ./my_project
# Specialized code repository indexing
python -m apps.code_rag --repo-dir ./my_codebase
# Global CLI with AST support
leann build my-code-index --docs ./src --use-ast-chunking
```
### Installation
```bash
# Install LEANN with AST chunking support
uv pip install -e "."
```
#### For normal users (PyPI install)
- Use `pip install leann` or `uv pip install leann`.
- `astchunk` is pulled automatically from PyPI as a dependency; no extra steps.
#### For developers (from source, editable)
```bash
git clone https://github.com/yichuan-w/LEANN.git leann
cd leann
git submodule update --init --recursive
uv sync
```
- This repo vendors `astchunk` as a git submodule at `packages/astchunk-leann` (our fork).
- `[tool.uv.sources]` maps the `astchunk` package to that path in editable mode.
- You can edit code under `packages/astchunk-leann` and Python will use your changes immediately (no separate `pip install astchunk` needed).
## Best Practices
### When to Use AST Chunking
**Recommended for:**
- Code repositories with multiple languages
- Mixed documentation and code content
- Complex codebases with deep function/class hierarchies
- When working with Claude Code for code assistance
**Not recommended for:**
- Pure text documents
- Very large files (>1MB)
- Languages not supported by tree-sitter
### Optimal Configuration
```bash
# Recommended settings for most codebases
python -m apps.code_rag \
--repo-dir ./src \
--ast-chunk-size 768 \
--ast-chunk-overlap 96 \
--exclude-dirs .git __pycache__ node_modules build dist
```
### Supported Languages
| Extension | Language | Status |
|-----------|----------|--------|
| `.py` | Python | ✅ Full support |
| `.java` | Java | ✅ Full support |
| `.cs` | C# | ✅ Full support |
| `.ts`, `.tsx` | TypeScript | ✅ Full support |
| `.js`, `.jsx` | JavaScript | ✅ Via TypeScript parser |
## Integration Examples
### Document RAG with Code Support
```python
# Enable code chunking in document RAG
python -m apps.document_rag \
--enable-code-chunking \
--data-dir ./project \
--query "How does authentication work in the codebase?"
```
### Claude Code Integration
When using with Claude Code MCP server, AST chunking provides better context for:
- Code completion and suggestions
- Bug analysis and debugging
- Architecture understanding
- Refactoring assistance
## Troubleshooting
### Common Issues
1. **Fallback to Traditional Chunking**
- Normal behavior for unsupported languages
- Check logs for specific language support
2. **Performance with Large Files**
- Adjust `--max-file-size` parameter
- Use `--exclude-dirs` to skip unnecessary directories
3. **Quality Issues**
- Try different `--ast-chunk-size` values (512, 768, 1024)
- Adjust overlap for better context preservation
### Debug Mode
```bash
export LEANN_LOG_LEVEL=DEBUG
python -m apps.code_rag --repo-dir ./my_code
```
## Migration from Traditional Chunking
Existing workflows continue to work without changes. To enable AST chunking:
```bash
# Before
python -m apps.document_rag --chunk-size 256
# After (maintains traditional chunking for non-code files)
python -m apps.document_rag --enable-code-chunking --chunk-size 256 --ast-chunk-size 768
```
## References
- [astchunk GitHub Repository](https://github.com/yilinjz/astchunk)
- [LEANN MCP Integration](../packages/leann-mcp/README.md)
- [Research Paper](https://arxiv.org/html/2506.15655v1)
---
**Note**: AST chunking maintains full backward compatibility while enhancing code understanding capabilities.

View File

@@ -1,459 +0,0 @@
# LEANN Configuration Guide
This guide helps you optimize LEANN for different use cases and understand the trade-offs between various configuration options.
## Getting Started: Simple is Better
When first trying LEANN, start with a small dataset to quickly validate your approach:
**For document RAG**: The default `data/` directory works perfectly - it includes two AI research papers, the novel Pride and Prejudice, and a technical report
```bash
python -m apps.document_rag --query "What techniques does LEANN use?"
```
**For other data sources**: Limit the dataset size for quick testing
```bash
# WeChat: Test with recent messages only
python -m apps.wechat_rag --max-items 100 --query "What did we discuss about the project timeline?"
# Browser history: Last few days
python -m apps.browser_rag --max-items 500 --query "Find documentation about vector databases"
# Email: Recent inbox
python -m apps.email_rag --max-items 200 --query "Who sent updates about the deployment status?"
```
Once validated, scale up gradually:
- 100 documents → 1,000 → 10,000 → full dataset (`--max-items -1`)
- This helps identify issues early before committing to long processing times
## Embedding Model Selection: Understanding the Trade-offs
Based on our experience developing LEANN, embedding models fall into three categories:
### Small Models (< 100M parameters)
**Example**: `sentence-transformers/all-MiniLM-L6-v2` (22M params)
- **Pros**: Lightweight, fast for both indexing and inference
- **Cons**: Lower semantic understanding, may miss nuanced relationships
- **Use when**: Speed is critical, queries are simple, you need interactive mode, or you are just experimenting with LEANN. If time is not a constraint, consider a larger/better embedding model
### Medium Models (100M-500M parameters)
**Example**: `facebook/contriever` (110M params), `BAAI/bge-base-en-v1.5` (110M params)
- **Pros**: Balanced performance, good multilingual support, reasonable speed
- **Cons**: Requires more compute than small models
- **Use when**: Need quality results without extreme compute requirements, general-purpose RAG applications
### Large Models (500M+ parameters)
**Example**: `Qwen/Qwen3-Embedding-0.6B` (600M params), `intfloat/multilingual-e5-large` (560M params)
- **Pros**: Best semantic understanding, captures complex relationships, excellent multilingual support. **Qwen3-Embedding-0.6B achieves nearly OpenAI API performance!**
- **Cons**: Slower inference, longer index build times
- **Use when**: Quality is paramount and you have sufficient compute resources. **Highly recommended** for production use
### Quick Start: Cloud and Local Embedding Options
**OpenAI Embeddings (Fastest Setup)**
For immediate testing without local model downloads (also recommended if you [do not have a GPU](https://github.com/yichuan-w/LEANN/issues/43) and are not concerned about your documents leaving your machine; embeddings are computed and recomputed via the OpenAI API):
```bash
# Set OpenAI embeddings (requires OPENAI_API_KEY)
--embedding-mode openai --embedding-model text-embedding-3-small
```
**Ollama Embeddings (Privacy-Focused)**
For local embeddings with complete privacy:
```bash
# First, pull an embedding model
ollama pull nomic-embed-text
# Use Ollama embeddings
--embedding-mode ollama --embedding-model nomic-embed-text
```
<details>
<summary><strong>Cloud vs Local Trade-offs</strong></summary>
**OpenAI Embeddings** (`text-embedding-3-small/large`)
- **Pros**: No local compute needed, consistently fast, high quality
- **Cons**: Requires API key, costs money, data leaves your system, [known limitations with certain languages](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
- **When to use**: Prototyping, non-sensitive data, need immediate results
**Local Embeddings**
- **Pros**: Complete privacy, no ongoing costs, full control, can sometimes outperform OpenAI embeddings
- **Cons**: Slower than cloud APIs, requires local compute resources
- **When to use**: Production systems, sensitive data, cost-sensitive applications
</details>
## Local & Remote Inference Endpoints
> Applies to both LLMs (`leann ask`) and embeddings (`leann build`).
LEANN now treats Ollama, LM Studio, and other OpenAI-compatible runtimes as first-class providers. You can point LEANN at any compatible endpoint either on the same machine or across the network with a couple of flags or environment variables.
### One-Time Environment Setup
```bash
# Works for OpenAI-compatible runtimes such as LM Studio, vLLM, SGLang, llamafile, etc.
export OPENAI_API_KEY="your-key" # or leave unset for local servers that do not check keys
export OPENAI_BASE_URL="http://localhost:1234/v1"
# Ollama-compatible runtimes (Ollama, Ollama on another host, llamacpp-server, etc.)
export LEANN_OLLAMA_HOST="http://localhost:11434" # falls back to OLLAMA_HOST or LOCAL_LLM_ENDPOINT
```
LEANN also recognises `LEANN_LOCAL_LLM_HOST` (highest priority), `LEANN_OPENAI_BASE_URL`, and `LOCAL_OPENAI_BASE_URL`, so existing scripts continue to work.
### Passing Hosts Per Command
```bash
# Build an index with a remote embedding server
leann build my-notes \
--docs ./notes \
--embedding-mode openai \
--embedding-model text-embedding-qwen3-embedding-0.6b \
--embedding-api-base http://192.168.1.50:1234/v1 \
--embedding-api-key local-dev-key
# Query using a local LM Studio instance via OpenAI-compatible API
leann ask my-notes \
--llm openai \
--llm-model qwen3-8b \
--api-base http://localhost:1234/v1 \
--api-key local-dev-key
# Query an Ollama instance running on another box
leann ask my-notes \
--llm ollama \
--llm-model qwen3:14b \
--host http://192.168.1.101:11434
```
⚠️ **Make sure the endpoint is reachable**: when your inference server runs on a home machine/workstation and the index/search job runs in the cloud, the machine running the index/search job must be able to reach the host you configured. Typical options include:
- Expose a public IP (and open the relevant port) on the machine that hosts LM Studio/Ollama.
- Configure router or cloud provider port forwarding.
- Tunnel traffic through tools like `tailscale`, `cloudflared`, or `ssh -R`.
When you set these options while building an index, LEANN stores them in `meta.json`. Any subsequent `leann ask` or searcher process automatically reuses the same provider settings even when we spawn background embedding servers. This makes the “server without GPU talking to my local workstation” workflow from [issue #80](https://github.com/yichuan-w/LEANN/issues/80#issuecomment-2287230548) work out-of-the-box.
**Tip:** If your runtime does not require an API key (many local stacks don't), leave `--api-key` unset. LEANN will skip injecting credentials.
### Python API Usage
You can pass the same configuration from Python:
```python
from leann.api import LeannBuilder
builder = LeannBuilder(
backend_name="hnsw",
embedding_mode="openai",
embedding_model="text-embedding-qwen3-embedding-0.6b",
embedding_options={
"base_url": "http://192.168.1.50:1234/v1",
"api_key": "local-dev-key",
},
)
builder.build_index("./indexes/my-notes", chunks)
```
`embedding_options` is persisted to the index `meta.json`, so subsequent `LeannSearcher` or `LeannChat` sessions automatically reuse the same provider settings (the embedding server manager forwards them to the provider for you).
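To confirm this behavior from Python, here is a minimal sketch (it assumes the index above has finished building and that the index path matches the `build_index()` call; the search call mirrors the `LeannSearcher` usage shown later in this guide):
```python
from leann import LeannSearcher

# Because embedding_options were persisted to meta.json at build time,
# no provider settings need to be repeated when opening the index.
searcher = LeannSearcher("./indexes/my-notes.leann")  # path assumed to match the build step above
results = searcher.search("vector database latency", top_k=10)
for result in results:
    print(result)
```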
## Index Selection: Matching Your Scale
### HNSW (Hierarchical Navigable Small World)
**Best for**: Small to medium datasets (< 10M vectors) - **Default and recommended for extremely low storage**
- Full recomputation required
- High memory usage during build phase
- Excellent recall (95%+)
```bash
# Optimal for most use cases
--backend-name hnsw --graph-degree 32 --build-complexity 64
```
### DiskANN
**Best for**: Large datasets, especially when you want `recompute=True`.
**Key advantages:**
- **Faster search** on large datasets (3x+ speedup vs HNSW in many cases)
- **Smart storage**: `recompute=True` enables automatic graph partitioning for smaller indexes
- **Better scaling**: Designed for 100k+ documents
**Recompute behavior:**
- `recompute=True` (recommended): Pure PQ traversal + final reranking - faster and enables partitioning
- `recompute=False`: PQ + partial real distances during traversal - slower but higher accuracy
```bash
# Recommended for most use cases
--backend-name diskann --graph-degree 32 --build-complexity 64
```
**Performance Benchmark**: Run `uv run benchmarks/diskann_vs_hnsw_speed_comparison.py` to compare DiskANN and HNSW on your system.
## LLM Selection: Engine and Model Comparison
### LLM Engines
**OpenAI** (`--llm openai`)
- **Pros**: Best quality, consistent performance, no local resources needed
- **Cons**: Costs money ($0.15-2.5 per million tokens), requires internet, data privacy concerns
- **Models**: `gpt-4o-mini` (fast, cheap), `gpt-4o` (best quality), `o3` (reasoning), `o3-mini` (reasoning, cheaper)
- **Thinking Budget**: Use `--thinking-budget low/medium/high` for o-series reasoning models (o3, o3-mini, o4-mini)
- **Note**: Our current default, but we recommend switching to Ollama for most use cases
**Ollama** (`--llm ollama`)
- **Pros**: Fully local, free, privacy-preserving, good model variety
- **Cons**: Requires local GPU/CPU resources, slower than cloud APIs, and you need to install the extra [Ollama app](https://github.com/ollama/ollama?tab=readme-ov-file#ollama) and pre-download models with `ollama pull`
- **Models**: `qwen3:0.6b` (ultra-fast), `qwen3:1.7b` (balanced), `qwen3:4b` (good quality), `qwen3:7b` (high quality), `deepseek-r1:1.5b` (reasoning)
- **Thinking Budget**: Use `--thinking-budget low/medium/high` for reasoning models like GPT-Oss:20b
**HuggingFace** (`--llm hf`)
- **Pros**: Free tier available, huge model selection, direct model loading (vs Ollama's server-based approach)
- **Cons**: More complex initial setup
- **Models**: `Qwen/Qwen3-1.7B-FP8`
## Parameter Tuning Guide
### Search Complexity Parameters
**`--build-complexity`** (index building)
- Controls thoroughness during index construction
- Higher = better recall but slower build
- Recommendations:
- 32: Quick prototyping
- 64: Balanced (default)
- 128: Production systems
- 256: Maximum quality
**`--search-complexity`** (query time)
- Controls search thoroughness
- Higher = better results but slower
- Recommendations:
- 16: Fast/Interactive search
- 32: High quality with diversity
- 64+: Maximum accuracy
### Top-K Selection
**`--top-k`** (number of retrieved chunks)
- More chunks = better context but slower LLM processing
- Should be always smaller than `--search-complexity`
- Guidelines:
- 10-20: General questions (default: 20)
- 30+: Complex multi-hop reasoning requiring comprehensive context
**Trade-off formula**:
- Retrieval time ∝ log(n) × search_complexity
- LLM processing time ∝ top_k × chunk_size
- Total context = top_k × chunk_size tokens
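To make these proportionalities concrete, here is a toy calculation (the unit constants are arbitrary placeholders; only the scaling relationships come from the formulas above):
```python
import math

def estimate_relative_cost(n_chunks, search_complexity, top_k, chunk_size_tokens,
                           retrieval_unit=1.0, llm_tokens_per_unit=1000.0):
    """Toy model of the trade-off formulas above; the unit constants are arbitrary."""
    retrieval = retrieval_unit * math.log(n_chunks) * search_complexity
    llm = (top_k * chunk_size_tokens) / llm_tokens_per_unit
    total_context = top_k * chunk_size_tokens
    return retrieval, llm, total_context

# Example: 100k chunks with the defaults discussed in this guide
retrieval, llm, context_tokens = estimate_relative_cost(
    n_chunks=100_000, search_complexity=32, top_k=20, chunk_size_tokens=256
)
print(f"relative retrieval cost={retrieval:.0f}, relative LLM cost={llm:.1f}, "
      f"context={context_tokens} tokens")
```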
### Thinking Budget for Reasoning Models
**`--thinking-budget`** (reasoning effort level)
- Controls the computational effort for reasoning models
- Options: `low`, `medium`, `high`
- Guidelines:
- `low`: Fast responses, basic reasoning (default for simple queries)
- `medium`: Balanced speed and reasoning depth
- `high`: Maximum reasoning effort, best for complex analytical questions
- **Supported Models**:
- **Ollama**: `gpt-oss:20b`, `gpt-oss:120b`
- **OpenAI**: `o3`, `o3-mini`, `o4-mini`, `o1` (o-series reasoning models)
- **Note**: Models without reasoning support will show a warning and proceed without reasoning parameters
- **Example**: `--thinking-budget high` for complex analytical questions
**📖 For detailed usage examples and implementation details, check out [Thinking Budget Documentation](THINKING_BUDGET_FEATURE.md)**
**💡 Quick Examples:**
```bash
# OpenAI o-series reasoning model
python apps/document_rag.py --query "What are the main techniques LEANN explores?" \
--index-dir hnswbuild --backend hnsw \
--llm openai --llm-model o3 --thinking-budget medium
# Ollama reasoning model
python apps/document_rag.py --query "What are the main techniques LEANN explores?" \
--index-dir hnswbuild --backend hnsw \
--llm ollama --llm-model gpt-oss:20b --thinking-budget high
```
### Graph Degree (HNSW/DiskANN)
**`--graph-degree`**
- Number of connections per node in the graph
- Higher = better recall but more memory
- HNSW: 16-32 (default: 32)
- DiskANN: 32-128 (default: 64)
## Performance Optimization Checklist
### If Embedding is Too Slow
1. **Switch to smaller model**:
```bash
# From large model
--embedding-model Qwen/Qwen3-Embedding-0.6B
# To small model
--embedding-model sentence-transformers/all-MiniLM-L6-v2
```
2. **Limit dataset size for testing**:
```bash
--max-items 1000 # Process first 1k items only
```
3. **Use MLX on Apple Silicon** (optional optimization):
```bash
--embedding-mode mlx --embedding-model mlx-community/Qwen3-Embedding-0.6B-8bit
```
MLX may not be the best choice: in our tests it offered only about a 1.3x speedup over HuggingFace, so Ollama is often a better option for embedding generation
4. **Use Ollama**
```bash
--embedding-mode ollama --embedding-model nomic-embed-text
```
To discover additional embedding models in Ollama, check out https://ollama.com/search?c=embedding or read more about embedding models at https://ollama.com/blog/embedding-models; be sure to pick a model size that works best for your hardware
### If Search Quality is Poor
1. **Increase retrieval count**:
```bash
--top-k 30 # Retrieve more candidates
```
2. **Upgrade embedding model**:
```bash
# For English
--embedding-model BAAI/bge-base-en-v1.5
# For multilingual
--embedding-model intfloat/multilingual-e5-large
```
## Understanding the Trade-offs
Every configuration choice involves trade-offs:
| Factor | Small/Fast | Large/Quality |
|--------|------------|---------------|
| Embedding Model | `all-MiniLM-L6-v2` | `Qwen/Qwen3-Embedding-0.6B` |
| Chunk Size | 512 tokens | 128 tokens |
| Index Type | HNSW | DiskANN |
| LLM | `qwen3:1.7b` | `gpt-4o` |
The key is finding the right balance for your specific use case. Start small and simple, measure performance, then scale up only where needed.
## Low-resource setups
If you don't have a local GPU, or builds/searches are too slow, use one or more of the options below.
### 1) Use OpenAI embeddings (no local compute)
Fastest path with zero local GPU requirements. Set your API key and use OpenAI embeddings during build and search:
```bash
export OPENAI_API_KEY=sk-...
# Build with OpenAI embeddings
leann build my-index \
--embedding-mode openai \
--embedding-model text-embedding-3-small
# Search with OpenAI embeddings (recompute at query time)
leann search my-index "your query" \
--recompute
```
### 2) Run remote builds with SkyPilot (cloud GPU)
Offload embedding generation and index building to a GPU VM using [SkyPilot](https://skypilot.readthedocs.io/en/latest/). A template is provided at `sky/leann-build.yaml`.
```bash
# One-time: install and configure SkyPilot
pip install skypilot
# Launch with defaults (L4:1) and mount ./data to ~/leann-data; the build runs automatically
sky launch -c leann-gpu sky/leann-build.yaml
# Override parameters via -e key=value (optional)
sky launch -c leann-gpu sky/leann-build.yaml \
-e index_name=my-index \
-e backend=hnsw \
-e embedding_mode=sentence-transformers \
-e embedding_model=Qwen/Qwen3-Embedding-0.6B
# Copy the built index back to your local .leann (use rsync)
rsync -Pavz leann-gpu:~/.leann/indexes/my-index ./.leann/indexes/
```
### 3) Disable recomputation to trade storage for speed
If you need lower latency and have more storage/memory, disable recomputation. This stores full embeddings and avoids recomputing at search time.
```bash
# Build without recomputation (HNSW requires non-compact in this mode)
leann build my-index --no-recompute --no-compact
# Search without recomputation
leann search my-index "your query" --no-recompute
```
When to use:
- Extreme low latency requirements (high QPS, interactive assistants)
- Read-heavy workloads where storage is cheaper than latency
- No always-available GPU
Constraints:
- HNSW: when `--no-recompute` is set, LEANN automatically disables compact mode during build
- DiskANN: supported; `--no-recompute` skips selective recompute during search
Storage impact:
- Storing N embeddings of dimension D with float32 requires approximately N × D × 4 bytes
- Example: 1,000,000 chunks × 768 dims × 4 bytes ≈ 2.86 GiB (plus graph/metadata)
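A quick way to estimate this for your own data (pure arithmetic, no LEANN dependencies):
```python
# Estimate raw embedding storage when recomputation is disabled
n_chunks = 1_000_000
dims = 768
bytes_per_float = 4  # float32

total_bytes = n_chunks * dims * bytes_per_float
print(f"{total_bytes / 1024**3:.2f} GiB")  # ~2.86 GiB, graph/metadata not included
```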
Converting an existing index (rebuild required):
```bash
# Rebuild in-place (ensure you still have original docs or can regenerate chunks)
leann build my-index --force --no-recompute --no-compact
```
Python API usage:
```python
from leann import LeannSearcher
searcher = LeannSearcher("/path/to/my-index.leann")
results = searcher.search("your query", top_k=10, recompute_embeddings=False)
```
Trade-offs:
- Lower latency and fewer network hops at query time
- Significantly higher storage (10-100× vs selective recomputation)
- Slightly larger memory footprint during build and search
Quick benchmark results (`benchmarks/benchmark_no_recompute.py` with 5k texts, complexity=32):
- HNSW
```text
recompute=True: search_time=0.818s, size=1.1MB
recompute=False: search_time=0.012s, size=16.6MB
```
- DiskANN
```text
recompute=True: search_time=0.041s, size=5.9MB
recompute=False: search_time=0.013s, size=24.6MB
```
Conclusion:
- **HNSW**: `no-recompute` is significantly faster (no embedding recomputation) but requires much more storage (stores all embeddings)
- **DiskANN**: `no-recompute` uses PQ + partial real distances during traversal (slower but higher accuracy), while `recompute=True` uses pure PQ traversal + final reranking (faster traversal, enables build-time partitioning for smaller storage)
## Further Reading
- [Lessons Learned Developing LEANN](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
- [LEANN Technical Paper](https://arxiv.org/abs/2506.08276)
- [DiskANN Original Paper](https://suhasjs.github.io/files/diskann_neurips19.pdf)
- [SSD-based Graph Partitioning](https://github.com/SonglinLife/SSD_BASED_PLAN)

View File

@@ -3,10 +3,9 @@
## 🔥 Core Features
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage with dynamic computation using optimized ZMQ servers and a highly optimized search paradigm (overlapping and batching) backed by an efficient embedding engine
- **🧠 AST-Aware Code Chunking** - Intelligent code chunking that preserves semantic boundaries (functions, classes, methods) for Python, Java, C#, and TypeScript files
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
- **🏗️ Pluggable Backends** - HNSW/FAISS (default), with optional DiskANN for large-scale deployments
- **🏗️ Pluggable Backends** - DiskANN, HNSW/FAISS with unified API
## 🛠️ Technical Highlights
- **🔄 Recompute Mode** - Highest accuracy scenarios while eliminating vector storage overhead
@@ -14,7 +13,7 @@
- **🚀 High-throughput Embedding Pipeline** - Optimized batched processing for maximum efficiency
- **🎯 Two-level Search** - Novel coarse-to-fine search overlap for accelerated query processing (optional)
- **💾 Memory-mapped Indices** - Fast startup with raw text mapping to reduce memory overhead
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](../examples/mlx_demo.py))
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](test/build_mlx_index.py))
## 🎨 Developer Experience

View File

@@ -1,149 +0,0 @@
# LEANN Grep Search Usage Guide
## Overview
LEANN's grep search functionality provides exact text matching for finding specific code patterns, error messages, function names, or exact phrases in your indexed documents.
## Basic Usage
### Simple Grep Search
```python
from leann.api import LeannSearcher
searcher = LeannSearcher("your_index_path")
# Exact text search
results = searcher.search("def authenticate_user", use_grep=True, top_k=5)
for result in results:
print(f"Score: {result.score}")
print(f"Text: {result.text[:100]}...")
print("-" * 40)
```
### Comparison: Semantic vs Grep Search
```python
# Semantic search - finds conceptually similar content
semantic_results = searcher.search("machine learning algorithms", top_k=3)
# Grep search - finds exact text matches
grep_results = searcher.search("def train_model", use_grep=True, top_k=3)
```
## When to Use Grep Search
### Use Cases
- **Code Search**: Finding specific function definitions, class names, or variable references
- **Error Debugging**: Locating exact error messages or stack traces
- **Documentation**: Finding specific API endpoints or exact terminology
### Examples
```python
# Find function definitions
functions = searcher.search("def __init__", use_grep=True)
# Find import statements
imports = searcher.search("from sklearn import", use_grep=True)
# Find specific error types
errors = searcher.search("FileNotFoundError", use_grep=True)
# Find TODO comments
todos = searcher.search("TODO:", use_grep=True)
# Find configuration entries
configs = searcher.search("server_port=", use_grep=True)
```
## Technical Details
### How It Works
1. **File Location**: Grep search operates on the raw text stored in `.jsonl` files
2. **Command Execution**: Uses the system `grep` command with case-insensitive search
3. **Result Processing**: Parses JSON lines and extracts text and metadata
4. **Scoring**: Simple frequency-based scoring based on query term occurrences
### Search Process
```
Query: "def train_model"
grep -i -n "def train_model" documents.leann.passages.jsonl
Parse matching JSON lines
Calculate scores based on term frequency
Return top_k results
```
### Scoring Algorithm
```python
# Term frequency in document
score = text.lower().count(query.lower())
```
Results are ranked by score (highest first), with higher scores indicating more occurrences of the search term.
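Putting the pieces together, here is a minimal standalone sketch of the flow described above (illustrative only; it assumes the `.jsonl` layout named earlier, with a `"text"` field and optional `"metadata"` per line, and is not LEANN's actual implementation, which lives in `leann.api`):
```python
import json
import subprocess

def grep_search(passages_path: str, query: str, top_k: int = 5):
    """Sketch of the grep-based search flow: grep -> parse JSONL -> score -> rank."""
    # 1. Case-insensitive grep over the raw JSONL passages file
    proc = subprocess.run(
        ["grep", "-i", "-n", query, passages_path],
        capture_output=True,
        text=True,
    )
    if proc.returncode not in (0, 1):  # 1 = no matches; anything else is an error
        raise RuntimeError("grep command failed. Is grep installed?")

    results = []
    for line in proc.stdout.splitlines():
        # grep -n output looks like "<line_no>:<json>"; drop the line number
        _, _, raw = line.partition(":")
        try:
            record = json.loads(raw)
        except json.JSONDecodeError:
            continue
        text = record.get("text", "")
        # 2. Frequency-based scoring: count occurrences of the query term
        score = text.lower().count(query.lower())
        results.append({"text": text, "metadata": record.get("metadata", {}), "score": score})

    # 3. Rank by score (highest first) and keep the top_k
    results.sort(key=lambda r: r["score"], reverse=True)
    return results[:top_k]
```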
## Error Handling
### Common Issues
#### Grep Command Not Found
```
RuntimeError: grep command not found. Please install grep or use semantic search.
```
**Solution**: Install grep on your system:
- **Ubuntu/Debian**: `sudo apt-get install grep`
- **macOS**: grep is pre-installed
- **Windows**: Use WSL or install grep via Git Bash/MSYS2
#### No Results Found
```python
# Check if your query exists in the raw data
results = searcher.search("your_query", use_grep=True)
if not results:
print("No exact matches found. Try:")
print("1. Check spelling and case")
print("2. Use partial terms")
print("3. Switch to semantic search")
```
## Complete Example
```python
#!/usr/bin/env python3
"""
Grep Search Example
Demonstrates grep search for exact text matching.
"""
from leann.api import LeannSearcher
def demonstrate_grep_search():
# Initialize searcher
searcher = LeannSearcher("my_index")
print("=== Function Search ===")
functions = searcher.search("def __init__", use_grep=True, top_k=5)
for i, result in enumerate(functions, 1):
print(f"{i}. Score: {result.score}")
print(f" Preview: {result.text[:60]}...")
print()
print("=== Error Search ===")
errors = searcher.search("FileNotFoundError", use_grep=True, top_k=3)
for result in errors:
print(f"Content: {result.text.strip()}")
print("-" * 40)
if __name__ == "__main__":
demonstrate_grep_search()
```

View File

@@ -1,300 +0,0 @@
# LEANN Metadata Filtering Usage Guide
## Overview
LEANN provides metadata filtering capabilities that let you filter search results based on arbitrary metadata fields set during chunking. This feature enables use cases like spoiler-free book search, document filtering by date or type, code search by file type, and much more.
## Basic Usage
### Adding Metadata to Your Documents
When building your index, add metadata to each text chunk:
```python
from leann.api import LeannBuilder
builder = LeannBuilder("hnsw")
# Add text with metadata
builder.add_text(
text="Chapter 1: Alice falls down the rabbit hole",
metadata={
"chapter": 1,
"character": "Alice",
"themes": ["adventure", "curiosity"],
"word_count": 150
}
)
builder.build_index("alice_in_wonderland_index")
```
### Searching with Metadata Filters
Use the `metadata_filters` parameter in search calls:
```python
from leann.api import LeannSearcher
searcher = LeannSearcher("alice_in_wonderland_index")
# Search with filters
results = searcher.search(
query="What happens to Alice?",
top_k=10,
metadata_filters={
"chapter": {"<=": 5}, # Only chapters 1-5
"spoiler_level": {"!=": "high"} # No high spoilers
}
)
```
## Filter Syntax
### Basic Structure
```python
metadata_filters = {
"field_name": {"operator": value},
"another_field": {"operator": value}
}
```
### Supported Operators
#### Comparison Operators
- `"=="`: Equal to
- `"!="`: Not equal to
- `"<"`: Less than
- `"<="`: Less than or equal
- `">"`: Greater than
- `">="`: Greater than or equal
```python
# Examples
{"chapter": {"==": 1}} # Exactly chapter 1
{"page": {">": 100}} # Pages after 100
{"rating": {">=": 4.0}} # Rating 4.0 or higher
{"word_count": {"<": 500}} # Short passages
```
#### Membership Operators
- `"in"`: Value is in list
- `"not_in"`: Value is not in list
```python
# Examples
{"character": {"in": ["Alice", "Bob"]}} # Alice OR Bob
{"genre": {"not_in": ["horror", "thriller"]}} # Exclude genres
{"tags": {"in": ["fiction", "adventure"]}} # Any of these tags
```
#### String Operators
- `"contains"`: String contains substring
- `"starts_with"`: String starts with prefix
- `"ends_with"`: String ends with suffix
```python
# Examples
{"title": {"contains": "alice"}} # Title contains "alice"
{"filename": {"ends_with": ".py"}} # Python files
{"author": {"starts_with": "Dr."}} # Authors with "Dr." prefix
```
#### Boolean Operators
- `"is_true"`: Field is truthy
- `"is_false"`: Field is falsy
```python
# Examples
{"is_published": {"is_true": True}} # Published content
{"is_draft": {"is_false": False}} # Not drafts
```
### Multiple Operators on Same Field
You can apply multiple operators to the same field (AND logic):
```python
metadata_filters = {
"word_count": {
">=": 100, # At least 100 words
"<=": 500 # At most 500 words
}
}
```
### Compound Filters
Multiple fields are combined with AND logic:
```python
metadata_filters = {
"chapter": {"<=": 10}, # Up to chapter 10
"character": {"==": "Alice"}, # About Alice
"spoiler_level": {"!=": "high"} # No major spoilers
}
```
## Use Case Examples
### 1. Spoiler-Free Book Search
```python
# Reader has only read up to chapter 5
def search_spoiler_free(query, max_chapter):
return searcher.search(
query=query,
metadata_filters={
"chapter": {"<=": max_chapter},
"spoiler_level": {"in": ["none", "low"]}
}
)
results = search_spoiler_free("What happens to Alice?", max_chapter=5)
```
### 2. Document Management by Date
```python
# Find recent documents
recent_docs = searcher.search(
query="project updates",
metadata_filters={
"date": {">=": "2024-01-01"},
"document_type": {"==": "report"}
}
)
```
### 3. Code Search by File Type
```python
# Search only Python files
python_code = searcher.search(
query="authentication function",
metadata_filters={
"file_extension": {"==": ".py"},
"lines_of_code": {"<": 100}
}
)
```
### 4. Content Filtering by Audience
```python
# Age-appropriate content
family_content = searcher.search(
query="adventure stories",
metadata_filters={
"age_rating": {"in": ["G", "PG"]},
"content_warnings": {"not_in": ["violence", "adult_themes"]}
}
)
```
### 5. Multi-Book Series Management
```python
# Search across first 3 books only
early_series = searcher.search(
query="character development",
metadata_filters={
"series": {"==": "Harry Potter"},
"book_number": {"<=": 3}
}
)
```
## Running the Example
You can see metadata filtering in action with our spoiler-free book RAG example:
```bash
# Don't forget to set up the environment
uv venv
source .venv/bin/activate
# Set your OpenAI API key (required for embeddings, but you can update the example locally and use ollama instead)
export OPENAI_API_KEY="your-api-key-here"
# Run the spoiler-free book RAG example
uv run examples/spoiler_free_book_rag.py
```
This example demonstrates:
- Building an index with metadata (chapter numbers, characters, themes, locations)
- Searching with filters to avoid spoilers (e.g., only show results up to chapter 5)
- Different scenarios for readers at various points in the book
The example uses Alice's Adventures in Wonderland as sample data and shows how you can search for information without revealing plot points from later chapters.
## Advanced Patterns
### Custom Chunking with metadata
```python
def chunk_book_with_metadata(book_text, book_info):
chunks = []
for chapter_num, chapter_text in parse_chapters(book_text):
# Extract entities, themes, etc.
characters = extract_characters(chapter_text)
themes = classify_themes(chapter_text)
spoiler_level = assess_spoiler_level(chapter_text, chapter_num)
# Create chunks with rich metadata
for paragraph in split_paragraphs(chapter_text):
chunks.append({
"text": paragraph,
"metadata": {
"book_title": book_info["title"],
"chapter": chapter_num,
"characters": characters,
"themes": themes,
"spoiler_level": spoiler_level,
"word_count": len(paragraph.split()),
"reading_level": calculate_reading_level(paragraph)
}
})
return chunks
```
## Performance Considerations
### Efficient Filtering Strategies
1. **Post-search filtering**: Applies filters after vector search, which should be efficient for typical result sets (10-100 results); see the sketch after this list.
2. **Metadata design**: Keep metadata fields simple and avoid deeply nested structures.
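To make item 1 concrete, here is a minimal sketch that emulates post-search filtering in user code (LEANN applies `metadata_filters` internally; this is purely illustrative and reuses the index name from the examples above):
```python
from leann.api import LeannSearcher

searcher = LeannSearcher("alice_in_wonderland_index")

# Over-fetch candidates, then keep only chunks whose metadata passes the filter
candidates = searcher.search("What happens to Alice?", top_k=50)
filtered = [r for r in candidates if r.metadata.get("chapter", 0) <= 5][:10]

for r in filtered:
    print(r.score, r.metadata.get("chapter"), r.text[:60])
```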
### Best Practices
1. **Consistent metadata schema**: Use consistent field names and value types across your documents.
2. **Reasonable metadata size**: Keep metadata reasonably sized to avoid storage overhead.
3. **Type consistency**: Use consistent data types for the same fields (e.g., always integers for chapter numbers).
4. **Index multiple granularities**: Consider chunking at different levels (paragraph, section, chapter) with appropriate metadata.
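A small sketch of practices 1 and 3 in code: route all metadata through one helper so every chunk uses the same field names and types (the fields below are just examples):
```python
from leann.api import LeannBuilder

def make_chunk_metadata(chapter, characters, spoiler_level="none"):
    """Coerce metadata into one consistently typed schema."""
    return {
        "chapter": int(chapter),              # always an integer
        "characters": list(characters),       # always a list
        "spoiler_level": str(spoiler_level),  # always a string
    }

builder = LeannBuilder("hnsw")
builder.add_text(
    "Chapter 2: Alice grows and shrinks.",
    metadata=make_chunk_metadata(chapter=2, characters=["Alice"]),
)
```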
### Adding Metadata to Existing Indices
To add metadata filtering to existing indices, you'll need to rebuild them with metadata:
```python
# Read existing passages and add metadata
def add_metadata_to_existing_chunks(chunks):
for chunk in chunks:
# Extract or assign metadata based on content
chunk["metadata"] = extract_metadata(chunk["text"])
return chunks
# Rebuild index with metadata
enhanced_chunks = add_metadata_to_existing_chunks(existing_chunks)
builder = LeannBuilder("hnsw")
for chunk in enhanced_chunks:
builder.add_text(chunk["text"], chunk["metadata"])
builder.build_index("enhanced_index")
```

View File

@@ -72,4 +72,4 @@ Using the wrong distance metric with normalized embeddings can lead to:
- **Incorrect ranking** of search results
- **Suboptimal performance** compared to using the correct metric
For more details on why this happens, see our analysis in the [embedding detection code](../packages/leann-core/src/leann/api.py) which automatically handles normalized embeddings and MIPS distance metric issues.
For more details on why this happens, see our analysis of [OpenAI embeddings with MIPS](../examples/main_cli_example.py).

View File

@@ -2,8 +2,8 @@
## 🎯 Q2 2025
- [X] HNSW backend integration
- [X] DiskANN backend with MIPS/L2/Cosine support
- [X] HNSW backend integration
- [X] Real-time embedding pipeline
- [X] Memory-efficient graph pruning

View File

@@ -1,395 +0,0 @@
# Slack Integration Setup Guide
This guide provides step-by-step instructions for setting up Slack integration with LEANN.
## Overview
LEANN's Slack integration uses MCP (Model Context Protocol) servers to fetch and index your Slack messages for RAG (Retrieval-Augmented Generation). This allows you to search through your Slack conversations using natural language queries.
## Prerequisites
1. **Slack Workspace Access**: You need admin or owner permissions in your Slack workspace to create apps and configure OAuth tokens.
2. **Slack MCP Server**: Install a Slack MCP server (e.g., `slack-mcp-server` via npm)
3. **LEANN**: Ensure you have LEANN installed and working
## Step 1: Create a Slack App
### 1.1 Go to Slack API Dashboard
1. Visit [https://api.slack.com/apps](https://api.slack.com/apps)
2. Click **"Create New App"**
3. Choose **"From scratch"**
4. Enter your app name (e.g., "LEANN Slack Integration")
5. Select your workspace
6. Click **"Create App"**
### 1.2 Configure App Permissions
#### Token Scopes
1. In your app dashboard, go to **"OAuth & Permissions"** in the left sidebar
2. Scroll down to **"Scopes"** section
3. Under **"Bot Token Scopes"**, click **"Add an OAuth Scope"**
4. Add the following scopes:
- `channels:read` - Read public channel information
- `channels:history` - Read messages in public channels
- `groups:read` - Read private channel information
- `groups:history` - Read messages in private channels
- `im:read` - Read direct message information
- `im:history` - Read direct messages
- `mpim:read` - Read group direct message information
- `mpim:history` - Read group direct messages
- `users:read` - Read user information
- `team:read` - Read workspace information
#### App-Level Tokens (Optional)
Some MCP servers may require app-level tokens:
1. Go to **"Basic Information"** in the left sidebar
2. Scroll down to **"App-Level Tokens"**
3. Click **"Generate Token and Scopes"**
4. Enter a name (e.g., "LEANN Integration")
5. Add the `connections:write` scope
6. Click **"Generate"**
7. Copy the token (starts with `xapp-`)
### 1.3 Install App to Workspace
1. Go to **"OAuth & Permissions"** in the left sidebar
2. Click **"Install to Workspace"**
3. Review the permissions and click **"Allow"**
4. Copy the **"Bot User OAuth Token"** (starts with `xoxb-`)
5. Copy the **"User OAuth Token"** (starts with `xoxp-`)
## Step 2: Install Slack MCP Server
### Option A: Using npm (Recommended)
```bash
# Install globally
npm install -g slack-mcp-server
# Or install locally
npm install slack-mcp-server
```
### Option B: Using npx (No installation required)
```bash
# Use directly without installation
npx slack-mcp-server
```
## Step 3: Install and Configure Ollama (for Real LLM Responses)
### 3.1 Install Ollama
```bash
# Install Ollama using Homebrew (macOS)
brew install ollama
# Or download from https://ollama.ai/
```
### 3.2 Start Ollama Service
```bash
# Start Ollama as a service
brew services start ollama
# Or start manually
ollama serve
```
### 3.3 Pull a Model
```bash
# Pull a lightweight model for testing
ollama pull llama3.2:1b
# Verify the model is available
ollama list
```
## Step 4: Configure Environment Variables
Create a `.env` file or set environment variables:
```bash
# Required: User OAuth Token
SLACK_OAUTH_TOKEN=xoxp-your-user-oauth-token-here
# Optional: App-Level Token (if your MCP server requires it)
SLACK_APP_TOKEN=xapp-your-app-token-here
# Optional: Workspace-specific settings
SLACK_WORKSPACE_ID=T1234567890 # Your workspace ID (optional)
```
## Step 5: Test the Setup
### 5.1 Test MCP Server Connection
```bash
python -m apps.slack_rag \
--mcp-server "slack-mcp-server" \
--test-connection \
--workspace-name "Your Workspace Name"
```
This will test the connection and list available tools without indexing any data.
### 5.2 Index a Specific Channel
```bash
python -m apps.slack_rag \
--mcp-server "slack-mcp-server" \
--workspace-name "Your Workspace Name" \
--channels general \
--query "What did we discuss about the project?"
```
### 5.3 Real RAG Query Examples
This section demonstrates Slack RAG integration queries against the Sky Lab Computing workspace's "random" channel. The system retrieves actual conversation messages and performs semantic search with high relevance scores, including finding specific research paper announcements and technical discussions.
### Example 1: Advisor Models Query
**Query:** "train black-box models to adopt to your personal data"
This query demonstrates the system's ability to find specific research announcements about training black-box models for personal data adaptation.
![Advisor Models Query - Command Setup](videos/slack_integration_1.1.png)
![Advisor Models Query - Search Results](videos/slack_integration_1.2.png)
![Advisor Models Query - LLM Response](videos/slack_integration_1.3.png)
### Example 2: Barbarians at the Gate Query
**Query:** "AI-driven research systems ADRS"
This query demonstrates the system's ability to find specific research announcements about AI-driven research systems and algorithm discovery.
![Barbarians Query - Command Setup](videos/slack_integration_2.1.png)
![Barbarians Query - Search Results](videos/slack_integration_2.2.png)
![Barbarians Query - LLM Response](videos/slack_integration_2.3.png)
### Prerequisites
- Bot is installed in the Sky Lab Computing workspace and invited to the target channel (run `/invite @YourBotName` in the channel if needed)
- Bot token available and exported in the same terminal session
### Commands
1) Set the workspace token for this shell
```bash
export SLACK_MCP_XOXP_TOKEN="xoxp-***-redacted-***"
```
2) Run queries against the "random" channel by channel ID (C0GN5BX0F)
**Advisor Models Query:**
```bash
python -m apps.slack_rag \
--mcp-server "slack-mcp-server" \
--workspace-name "Sky Lab Computing" \
--channels C0GN5BX0F \
--max-messages-per-channel 100000 \
--query "train black-box models to adopt to your personal data" \
--llm ollama \
--llm-model "llama3.2:1b" \
--llm-host "http://localhost:11434" \
--no-concatenate-conversations
```
**Barbarians at the Gate Query:**
```bash
python -m apps.slack_rag \
--mcp-server "slack-mcp-server" \
--workspace-name "Sky Lab Computing" \
--channels C0GN5BX0F \
--max-messages-per-channel 100000 \
--query "AI-driven research systems ADRS" \
--llm ollama \
--llm-model "llama3.2:1b" \
--llm-host "http://localhost:11434" \
--no-concatenate-conversations
```
These examples demonstrate the system's ability to find and retrieve specific research announcements and technical discussions from the conversation history, showcasing the power of semantic search in Slack data.
3) Optional: Ask a broader question
```bash
python test_channel_by_id_or_name.py \
--channel-id C0GN5BX0F \
--workspace-name "Sky Lab Computing" \
--query "What is LEANN about?"
```
Notes:
- If you see `not_in_channel`, invite the bot to the channel and re-run.
- If you see `channel_not_found`, confirm the channel ID and workspace.
- Deep search via server-side “search” tools may require additional Slack scopes; the example above performs client-side filtering over retrieved history.
## Common Issues and Solutions
### Issue 1: "users cache is not ready yet" Error
**Problem**: You see this warning:
```
WARNING - Failed to fetch messages from channel random: Failed to fetch messages: {'code': -32603, 'message': 'users cache is not ready yet, sync process is still running... please wait'}
```
**Solution**: This is a common timing issue. The LEANN integration now includes automatic retry logic:
1. **Wait and Retry**: The system will automatically retry with exponential backoff (2s, 4s, 8s, etc.)
2. **Increase Retry Parameters**: If needed, you can customize retry behavior:
```bash
python -m apps.slack_rag \
--mcp-server "slack-mcp-server" \
--max-retries 10 \
--retry-delay 3.0 \
--channels general \
--query "Your query here"
```
3. **Keep MCP Server Running**: Start the MCP server separately and keep it running:
```bash
# Terminal 1: Start MCP server
slack-mcp-server
# Terminal 2: Run LEANN (it will connect to the running server)
python -m apps.slack_rag --mcp-server "slack-mcp-server" --channels general --query "test"
```
### Issue 2: "No message fetching tool found"
**Problem**: The MCP server doesn't have the expected tools.
**Solution**:
1. Check if your MCP server is properly installed and configured
2. Verify your Slack tokens are correct
3. Try a different MCP server implementation
4. Check the MCP server documentation for required configuration
### Issue 3: Permission Denied Errors
**Problem**: You get permission errors when trying to access channels.
**Solutions**:
1. **Check Bot Permissions**: Ensure your bot has been added to the channels you want to access
2. **Verify Token Scopes**: Make sure you have all required scopes configured
3. **Channel Access**: For private channels, the bot needs to be explicitly invited
4. **Workspace Permissions**: Ensure your Slack app has the necessary workspace permissions
### Issue 4: Empty Results
**Problem**: No messages are returned even though the channel has messages.
**Solutions**:
1. **Check Channel Names**: Ensure channel names are correct (without the # symbol)
2. **Verify Bot Access**: Make sure the bot can access the channels
3. **Check Date Ranges**: Some MCP servers have limitations on message history
4. **Increase Message Limits**: Try increasing the message limit:
```bash
python -m apps.slack_rag \
--mcp-server "slack-mcp-server" \
--channels general \
--max-messages-per-channel 1000 \
--query "test"
```
## Advanced Configuration
### Custom MCP Server Commands
If you need to pass additional parameters to your MCP server:
```bash
python -m apps.slack_rag \
--mcp-server "slack-mcp-server --token-file /path/to/tokens.json" \
--workspace-name "Your Workspace" \
--channels general \
--query "Your query"
```
### Multiple Workspaces
To work with multiple Slack workspaces, you can:
1. Create separate apps for each workspace
2. Use different environment variables
3. Run separate instances with different configurations
### Performance Optimization
For better performance with large workspaces:
```bash
python -m apps.slack_rag \
--mcp-server "slack-mcp-server" \
--workspace-name "Your Workspace" \
--max-messages-per-channel 500 \
--no-concatenate-conversations \
--query "Your query"
```
---
## Troubleshooting Checklist
- [ ] Slack app created with proper permissions
- [ ] Bot token (xoxb-) copied correctly
- [ ] App-level token (xapp-) created if needed
- [ ] MCP server installed and accessible
- [ ] Ollama installed and running (`brew services start ollama`)
- [ ] Ollama model pulled (`ollama pull llama3.2:1b`)
- [ ] Environment variables set correctly
- [ ] Bot invited to relevant channels
- [ ] Channel names specified without # symbol
- [ ] Sufficient retry attempts configured
- [ ] Network connectivity to Slack APIs
## Getting Help
If you continue to have issues:
1. **Check Logs**: Look for detailed error messages in the console output
2. **Test MCP Server**: Use `--test-connection` to verify the MCP server is working
3. **Verify Tokens**: Double-check that your Slack tokens are valid and have the right scopes
4. **Check Ollama**: Ensure Ollama is running (`ollama serve`) and the model is available (`ollama list`)
5. **Community Support**: Reach out to the LEANN community for help
## Example Commands
### Basic Usage
```bash
# Test connection
python -m apps.slack_rag --mcp-server "slack-mcp-server" --test-connection
# Index specific channels
python -m apps.slack_rag \
--mcp-server "slack-mcp-server" \
--workspace-name "My Company" \
--channels general random \
--query "What did we decide about the project timeline?"
```
### Advanced Usage
```bash
# With custom retry settings
python -m apps.slack_rag \
--mcp-server "slack-mcp-server" \
--workspace-name "My Company" \
--channels general \
--max-retries 10 \
--retry-delay 5.0 \
--max-messages-per-channel 2000 \
--query "Show me all decisions made in the last month"
```

View File

Binary file not shown.

Before

Width | Height | Size: 445 KiB

View File

Binary file not shown.

Before

Width | Height | Size: 508 KiB

View File

Binary file not shown.

Before

Width | Height | Size: 437 KiB

View File

Binary file not shown.

Before

Width | Height | Size: 474 KiB

View File

Binary file not shown.

Before

Width | Height | Size: 501 KiB

View File

Binary file not shown.

Before

Width | Height | Size: 454 KiB

View File

View File

@@ -62,7 +62,7 @@ def test_faiss_hnsw():
try:
result = subprocess.run(
[sys.executable, "benchmarks/faiss_only.py"],
[sys.executable, "examples/faiss_only.py"],
capture_output=True,
text=True,
timeout=300,
@@ -115,7 +115,7 @@ def test_leann_hnsw():
# Load and parse documents
documents = SimpleDirectoryReader(
"data",
"examples/data",
recursive=True,
encoding="utf-8",
required_exts=[".pdf", ".txt", ".md"],

158
examples/document_search.py Normal file
View File

@@ -0,0 +1,158 @@
#!/usr/bin/env python3
"""
Document search demo with recompute mode
"""
import shutil
import time
from pathlib import Path
# Import backend packages to trigger plugin registration
try:
import leann_backend_diskann # noqa: F401
import leann_backend_hnsw # noqa: F401
print("INFO: Backend packages imported successfully.")
except ImportError as e:
print(f"WARNING: Could not import backend packages. Error: {e}")
# Import upper-level API from leann-core
from leann.api import LeannBuilder, LeannChat, LeannSearcher
def load_sample_documents():
"""Create sample documents for demonstration"""
docs = [
{
"title": "Intro to Python",
"content": "Python is a high-level, interpreted language known for simplicity.",
},
{
"title": "ML Basics",
"content": "Machine learning builds systems that learn from data.",
},
{
"title": "Data Structures",
"content": "Data structures like arrays, lists, and graphs organize data.",
},
]
return docs
def main():
print("==========================================================")
print("=== Leann Document Search Demo (DiskANN + Recompute) ===")
print("==========================================================")
INDEX_DIR = Path("./test_indices")
INDEX_PATH = str(INDEX_DIR / "documents.diskann")
BACKEND_TO_TEST = "diskann"
if INDEX_DIR.exists():
print(f"--- Cleaning up old index directory: {INDEX_DIR} ---")
shutil.rmtree(INDEX_DIR)
# --- 1. Build index ---
print(f"\n[PHASE 1] Building index using '{BACKEND_TO_TEST}' backend...")
builder = LeannBuilder(backend_name=BACKEND_TO_TEST, graph_degree=32, complexity=64)
documents = load_sample_documents()
print(f"Loaded {len(documents)} sample documents.")
for doc in documents:
builder.add_text(doc["content"], metadata={"title": doc["title"]})
builder.build_index(INDEX_PATH)
print("\nIndex built!")
# --- 2. Basic search demo ---
print(f"\n[PHASE 2] Basic search using '{BACKEND_TO_TEST}' backend...")
searcher = LeannSearcher(index_path=INDEX_PATH)
query = "What is machine learning?"
print(f"\nQuery: '{query}'")
print("\n--- Basic search mode (PQ computation) ---")
start_time = time.time()
results = searcher.search(query, top_k=2)
basic_time = time.time() - start_time
print(f"⏱️ Basic search time: {basic_time:.3f} seconds")
print(">>> Basic search results <<<")
for i, res in enumerate(results, 1):
print(
f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}"
)
# --- 3. Recompute search demo ---
print("\n[PHASE 3] Recompute search using embedding server...")
print("\n--- Recompute search mode (get real embeddings via network) ---")
# Configure recompute parameters
recompute_params = {
"recompute_beighbor_embeddings": True, # Enable network recomputation
"USE_DEFERRED_FETCH": False, # Don't use deferred fetch
"skip_search_reorder": True, # Skip search reordering
"dedup_node_dis": True, # Enable node distance deduplication
"prune_ratio": 0.1, # Pruning ratio 10%
"batch_recompute": False, # Don't use batch recomputation
"global_pruning": False, # Don't use global pruning
"zmq_port": 5555, # ZMQ port
"embedding_model": "sentence-transformers/all-mpnet-base-v2",
}
print("Recompute parameter configuration:")
for key, value in recompute_params.items():
print(f" {key}: {value}")
print("\n🔄 Executing Recompute search...")
try:
start_time = time.time()
recompute_results = searcher.search(query, top_k=2, **recompute_params)
recompute_time = time.time() - start_time
print(f"⏱️ Recompute search time: {recompute_time:.3f} seconds")
print(">>> Recompute search results <<<")
for i, res in enumerate(recompute_results, 1):
print(
f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}"
)
# Compare results
print("\n--- Result comparison ---")
print(f"Basic search time: {basic_time:.3f} seconds")
print(f"Recompute time: {recompute_time:.3f} seconds")
print("\nBasic search vs Recompute results:")
for i in range(min(len(results), len(recompute_results))):
basic_score = results[i].score
recompute_score = recompute_results[i].score
score_diff = abs(basic_score - recompute_score)
print(
f" Position {i + 1}: PQ={basic_score:.4f}, Recompute={recompute_score:.4f}, Difference={score_diff:.4f}"
)
if recompute_time > basic_time:
print("✅ Recompute mode working correctly (more accurate but slower)")
else:
print("i Recompute time is unusually fast, network recomputation may not be enabled")
except Exception as e:
print(f"❌ Recompute search failed: {e}")
print("This usually indicates an embedding server connection issue")
# --- 4. Chat demo ---
print("\n[PHASE 4] Starting chat session...")
chat = LeannChat(index_path=INDEX_PATH)
chat_response = chat.ask(query)
print(f"You: {query}")
print(f"Leann: {chat_response}")
print("\n==========================================================")
print("✅ Demo finished successfully!")
print("==========================================================")
if __name__ == "__main__":
main()

Some files were not shown because too many files have changed in this diff.