Compare commits
108 Commits
fix/52-inc
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
198044d033 | ||
|
|
a2e5f5294b | ||
|
|
8a2ea37871 | ||
|
|
7ddb4772c0 | ||
|
|
a1c21adbce | ||
|
|
d1b3c93a5a | ||
|
|
a6ee95b18a | ||
|
|
17cbd07b25 | ||
|
|
3629ccf8f7 | ||
|
|
0175bc9c20 | ||
|
|
af47dfdde7 | ||
|
|
f13bd02fbd | ||
|
|
a0bbf831db | ||
|
|
86287d8832 | ||
|
|
76cc798e3e | ||
|
|
d599566fd7 | ||
|
|
00770aebbb | ||
|
|
e268392d5b | ||
|
|
eb909ccec5 | ||
|
|
13beb98164 | ||
|
|
969f514564 | ||
|
|
1ef9cba7de | ||
|
|
a63550944b | ||
|
|
97493a2896 | ||
|
|
f7d2dc6e7c | ||
|
|
ea86b283cb | ||
|
|
e7519bceaa | ||
|
|
abf0b2c676 | ||
|
|
3c4785bb63 | ||
|
|
930b79cc98 | ||
|
|
9b7353f336 | ||
|
|
9dd0e0b26f | ||
|
|
3766ad1fd2 | ||
|
|
c3aceed1e0 | ||
|
|
dc6c9f696e | ||
|
|
2406c41eef | ||
|
|
d4f5f2896f | ||
|
|
366984e92e | ||
|
|
64b92a04a7 | ||
|
|
a85d0ad4a7 | ||
|
|
dbb5f4d352 | ||
|
|
f180b83589 | ||
|
|
abf312d998 | ||
|
|
ab251ab751 | ||
|
|
28085f6f04 | ||
|
|
6495833887 | ||
|
|
5543b3c5f7 | ||
|
|
a99983b3d9 | ||
|
|
36482e016c | ||
|
|
32967daf81 | ||
|
|
b4bb8dec75 | ||
|
|
5ba9cf6442 | ||
|
|
1484406a8d | ||
|
|
761ec1f0ac | ||
|
|
4808afc686 | ||
|
|
0bba4b2157 | ||
|
|
e67b5f44fa | ||
|
|
658bce47ef | ||
|
|
6b399ad8d2 | ||
|
|
16f35aa067 | ||
|
|
ab9c6bd69e | ||
|
|
e2b37914ce | ||
|
|
e588100674 | ||
|
|
fecee94af1 | ||
|
|
01475c10a0 | ||
|
|
c8aa063f48 | ||
|
|
576beb13db | ||
|
|
63c7b0c8a3 | ||
|
|
ec889f7ef4 | ||
|
|
322e5c162d | ||
|
|
edde0cdeb2 | ||
|
|
db7ba27ff6 | ||
|
|
5f7806e16f | ||
|
|
d034e2195b | ||
|
|
43894ff605 | ||
|
|
10311cc611 | ||
|
|
ad0d2faabc | ||
|
|
e93c0dec6f | ||
|
|
c5a29f849a | ||
|
|
3b8dc6368e | ||
|
|
e309f292de | ||
|
|
0d9f92ea0f | ||
|
|
b0b353d279 | ||
|
|
4dffdfedbe | ||
|
|
d41e467df9 | ||
|
|
4ca0489cb1 | ||
|
|
e83a671918 | ||
|
|
4e5b73ce7b | ||
|
|
31b4973141 | ||
|
|
dde2221513 | ||
|
|
6d11e86e71 | ||
|
|
13bb561aad | ||
|
|
0174ba5571 | ||
|
|
03af82d695 | ||
|
|
738f1dbab8 | ||
|
|
37d990d51c | ||
|
|
a6f07a54f1 | ||
|
|
46905e0687 | ||
|
|
838ade231e | ||
|
|
da6540decd | ||
|
|
39e18a7c11 | ||
|
|
6bde28584b | ||
|
|
f62632c41f | ||
|
|
27708243ca | ||
|
|
9a1e4652ca | ||
|
|
14e84d9e2d | ||
|
|
2dcfca19ff | ||
|
|
bee2167ee3 |
1
.gitattributes
vendored
@@ -1 +0,0 @@
|
||||
paper_plot/data/big_graph_degree_data.npz filter=lfs diff=lfs merge=lfs -text
|
||||
50
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
Normal file
@@ -0,0 +1,50 @@
|
||||
name: Bug Report
|
||||
description: Report a bug in LEANN
|
||||
labels: ["bug"]
|
||||
|
||||
body:
|
||||
- type: textarea
|
||||
id: description
|
||||
attributes:
|
||||
label: What happened?
|
||||
description: A clear description of the bug
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: reproduce
|
||||
attributes:
|
||||
label: How to reproduce
|
||||
placeholder: |
|
||||
1. Install with...
|
||||
2. Run command...
|
||||
3. See error
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: error
|
||||
attributes:
|
||||
label: Error message
|
||||
description: Paste any error messages
|
||||
render: shell
|
||||
|
||||
- type: input
|
||||
id: version
|
||||
attributes:
|
||||
label: LEANN Version
|
||||
placeholder: "0.1.0"
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: dropdown
|
||||
id: os
|
||||
attributes:
|
||||
label: Operating System
|
||||
options:
|
||||
- macOS
|
||||
- Linux
|
||||
- Windows
|
||||
- Docker
|
||||
validations:
|
||||
required: true
|
||||
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
blank_issues_enabled: true
|
||||
contact_links:
|
||||
- name: Documentation
|
||||
url: https://github.com/LEANN-RAG/LEANN-RAG/tree/main/docs
|
||||
about: Read the docs first
|
||||
- name: Discussions
|
||||
url: https://github.com/LEANN-RAG/LEANN-RAG/discussions
|
||||
about: Ask questions and share ideas
|
||||
27
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
@@ -0,0 +1,27 @@
|
||||
name: Feature Request
|
||||
description: Suggest a new feature for LEANN
|
||||
labels: ["enhancement"]
|
||||
|
||||
body:
|
||||
- type: textarea
|
||||
id: problem
|
||||
attributes:
|
||||
label: What problem does this solve?
|
||||
description: Describe the problem or need
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: solution
|
||||
attributes:
|
||||
label: Proposed solution
|
||||
description: How would you like this to work?
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: example
|
||||
attributes:
|
||||
label: Example usage
|
||||
description: Show how the API might look
|
||||
render: python
|
||||
13
.github/pull_request_template.md
vendored
Normal file
@@ -0,0 +1,13 @@
|
||||
## What does this PR do?
|
||||
|
||||
<!-- Brief description of your changes -->
|
||||
|
||||
## Related Issues
|
||||
|
||||
Fixes #
|
||||
|
||||
## Checklist
|
||||
|
||||
- [ ] Tests pass (`uv run pytest`)
|
||||
- [ ] Code formatted (`ruff format` and `ruff check`)
|
||||
- [ ] Pre-commit hooks pass (`pre-commit run --all-files`)
|
||||
387
.github/workflows/build-reusable.yml
vendored
@@ -17,102 +17,159 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ inputs.ref }}
|
||||
submodules: recursive
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
- name: Install uv and Python
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Install ruff
|
||||
- name: Run pre-commit with only lint group (no project deps)
|
||||
run: |
|
||||
uv tool install ruff
|
||||
|
||||
- name: Run ruff check
|
||||
run: |
|
||||
ruff check .
|
||||
|
||||
- name: Run ruff format check
|
||||
run: |
|
||||
ruff format --check .
|
||||
|
||||
build:
|
||||
needs: lint
|
||||
name: Build ${{ matrix.os }} Python ${{ matrix.python }}
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- os: ubuntu-22.04
|
||||
python: '3.9'
|
||||
- os: ubuntu-22.04
|
||||
python: '3.10'
|
||||
- os: ubuntu-22.04
|
||||
python: '3.11'
|
||||
- os: ubuntu-22.04
|
||||
python: '3.12'
|
||||
- os: ubuntu-22.04
|
||||
python: '3.13'
|
||||
- os: macos-14
|
||||
python: '3.9'
|
||||
- os: macos-14
|
||||
python: '3.10'
|
||||
- os: macos-14
|
||||
python: '3.11'
|
||||
- os: macos-14
|
||||
python: '3.12'
|
||||
- os: macos-14
|
||||
python: '3.13'
|
||||
- os: macos-15
|
||||
python: '3.9'
|
||||
- os: macos-15
|
||||
python: '3.10'
|
||||
- os: macos-15
|
||||
python: '3.11'
|
||||
- os: macos-15
|
||||
python: '3.12'
|
||||
- os: macos-15
|
||||
python: '3.13'
|
||||
- os: macos-13
|
||||
python: '3.9'
|
||||
- os: macos-13
|
||||
python: '3.10'
|
||||
- os: macos-13
|
||||
python: '3.11'
|
||||
- os: macos-13
|
||||
python: '3.12'
|
||||
# Note: macos-13 + Python 3.13 excluded due to PyTorch compatibility
|
||||
# (PyTorch 2.5+ supports Python 3.13 but not Intel Mac x86_64)
|
||||
runs-on: ${{ matrix.os }}
|
||||
uv run --only-group lint pre-commit run --all-files --show-diff-on-failure
|
||||
|
||||
type-check:
|
||||
name: Type Check with ty
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ inputs.ref }}
|
||||
submodules: recursive
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
- name: Install uv and Python
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Install ty
|
||||
run: uv tool install ty
|
||||
|
||||
- name: Run ty type checker
|
||||
run: |
|
||||
# Run ty on core packages, apps, and tests
|
||||
ty check packages/leann-core/src apps tests
|
||||
|
||||
build:
|
||||
needs: [lint, type-check]
|
||||
name: Build ${{ matrix.os }} Python ${{ matrix.python }}
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
# Note: Python 3.9 dropped - uses PEP 604 union syntax (str | None)
|
||||
# which requires Python 3.10+
|
||||
- os: ubuntu-22.04
|
||||
python: '3.10'
|
||||
- os: ubuntu-22.04
|
||||
python: '3.11'
|
||||
- os: ubuntu-22.04
|
||||
python: '3.12'
|
||||
- os: ubuntu-22.04
|
||||
python: '3.13'
|
||||
# ARM64 Linux builds
|
||||
- os: ubuntu-24.04-arm
|
||||
python: '3.10'
|
||||
- os: ubuntu-24.04-arm
|
||||
python: '3.11'
|
||||
- os: ubuntu-24.04-arm
|
||||
python: '3.12'
|
||||
- os: ubuntu-24.04-arm
|
||||
python: '3.13'
|
||||
- os: macos-14
|
||||
python: '3.10'
|
||||
- os: macos-14
|
||||
python: '3.11'
|
||||
- os: macos-14
|
||||
python: '3.12'
|
||||
- os: macos-14
|
||||
python: '3.13'
|
||||
- os: macos-15
|
||||
python: '3.10'
|
||||
- os: macos-15
|
||||
python: '3.11'
|
||||
- os: macos-15
|
||||
python: '3.12'
|
||||
- os: macos-15
|
||||
python: '3.13'
|
||||
# Intel Mac builds (x86_64) - replaces deprecated macos-13
|
||||
# Note: Python 3.13 excluded - PyTorch has no wheels for macOS x86_64 + Python 3.13
|
||||
# (PyTorch <=2.4.1 lacks cp313, PyTorch >=2.5.0 dropped Intel Mac support)
|
||||
- os: macos-15-intel
|
||||
python: '3.10'
|
||||
- os: macos-15-intel
|
||||
python: '3.11'
|
||||
- os: macos-15-intel
|
||||
python: '3.12'
|
||||
# macOS 26 (beta) - arm64
|
||||
- os: macos-26
|
||||
python: '3.10'
|
||||
- os: macos-26
|
||||
python: '3.11'
|
||||
- os: macos-26
|
||||
python: '3.12'
|
||||
- os: macos-26
|
||||
python: '3.13'
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
with:
|
||||
ref: ${{ inputs.ref }}
|
||||
submodules: recursive
|
||||
|
||||
- name: Install uv and Python
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
python-version: ${{ matrix.python }}
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
- name: Install system dependencies (Ubuntu)
|
||||
if: runner.os == 'Linux'
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
|
||||
pkg-config libopenblas-dev patchelf libabsl-dev libaio-dev libprotobuf-dev
|
||||
pkg-config libabsl-dev libaio-dev libprotobuf-dev \
|
||||
patchelf
|
||||
|
||||
# Install Intel MKL for DiskANN
|
||||
# Debug: Show system information
|
||||
echo "🔍 System Information:"
|
||||
echo "Architecture: $(uname -m)"
|
||||
echo "OS: $(uname -a)"
|
||||
echo "CPU info: $(lscpu | head -5)"
|
||||
|
||||
# Install math library based on architecture
|
||||
ARCH=$(uname -m)
|
||||
echo "🔍 Setting up math library for architecture: $ARCH"
|
||||
|
||||
if [[ "$ARCH" == "x86_64" ]]; then
|
||||
# Install Intel MKL for DiskANN on x86_64
|
||||
echo "📦 Installing Intel MKL for x86_64..."
|
||||
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
|
||||
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
|
||||
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/mkl/latest/lib/intel64:$LD_LIBRARY_PATH" >> $GITHUB_ENV
|
||||
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/compiler/latest/linux/compiler/lib/intel64_lin" >> $GITHUB_ENV
|
||||
echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/mkl/latest/lib/intel64" >> $GITHUB_ENV
|
||||
echo "✅ Intel MKL installed for x86_64"
|
||||
|
||||
# Debug: Check MKL installation
|
||||
echo "🔍 MKL Installation Check:"
|
||||
ls -la /opt/intel/oneapi/mkl/latest/ || echo "MKL directory not found"
|
||||
ls -la /opt/intel/oneapi/mkl/latest/lib/ || echo "MKL lib directory not found"
|
||||
|
||||
elif [[ "$ARCH" == "aarch64" ]]; then
|
||||
# Use OpenBLAS for ARM64 (MKL installer not compatible with ARM64)
|
||||
echo "📦 Installing OpenBLAS for ARM64..."
|
||||
sudo apt-get install -y libopenblas-dev liblapack-dev liblapacke-dev
|
||||
echo "✅ OpenBLAS installed for ARM64"
|
||||
|
||||
# Debug: Check OpenBLAS installation
|
||||
echo "🔍 OpenBLAS Installation Check:"
|
||||
dpkg -l | grep openblas || echo "OpenBLAS package not found"
|
||||
ls -la /usr/lib/aarch64-linux-gnu/openblas/ || echo "OpenBLAS directory not found"
|
||||
fi
|
||||
|
||||
# Debug: Show final library paths
|
||||
echo "🔍 Final LD_LIBRARY_PATH: $LD_LIBRARY_PATH"
|
||||
|
||||
- name: Install system dependencies (macOS)
|
||||
if: runner.os == 'macOS'
|
||||
@@ -122,11 +179,24 @@ jobs:
|
||||
|
||||
- name: Install build dependencies
|
||||
run: |
|
||||
uv pip install --system scikit-build-core numpy swig Cython pybind11
|
||||
if [[ "$RUNNER_OS" == "Linux" ]]; then
|
||||
uv pip install --system auditwheel
|
||||
uv python install ${{ matrix.python }}
|
||||
uv venv --python ${{ matrix.python }} .uv-build
|
||||
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
||||
BUILD_PY=".uv-build\\Scripts\\python.exe"
|
||||
else
|
||||
uv pip install --system delocate
|
||||
BUILD_PY=".uv-build/bin/python"
|
||||
fi
|
||||
uv pip install --python "$BUILD_PY" scikit-build-core numpy swig Cython pybind11
|
||||
if [[ "$RUNNER_OS" == "Linux" ]]; then
|
||||
uv pip install --python "$BUILD_PY" auditwheel
|
||||
else
|
||||
uv pip install --python "$BUILD_PY" delocate
|
||||
fi
|
||||
|
||||
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
||||
echo "$(pwd)\\.uv-build\\Scripts" >> $GITHUB_PATH
|
||||
else
|
||||
echo "$(pwd)/.uv-build/bin" >> $GITHUB_PATH
|
||||
fi
|
||||
|
||||
- name: Set macOS environment variables
|
||||
@@ -157,13 +227,16 @@ jobs:
|
||||
# Use system clang for better compatibility
|
||||
export CC=clang
|
||||
export CXX=clang++
|
||||
# Homebrew libraries on each macOS version require matching minimum version
|
||||
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=13.0
|
||||
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=14.0
|
||||
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
|
||||
# Set deployment target based on runner
|
||||
# macos-15-intel runs macOS 15, so target 15.0 (system libraries require it)
|
||||
if [[ "${{ matrix.os }}" == "macos-15-intel" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||
elif [[ "${{ matrix.os }}" == macos-14* ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=14.0
|
||||
elif [[ "${{ matrix.os }}" == macos-15* ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||
elif [[ "${{ matrix.os }}" == macos-26* ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=26.0
|
||||
fi
|
||||
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
||||
else
|
||||
@@ -177,14 +250,16 @@ jobs:
|
||||
# Use system clang for better compatibility
|
||||
export CC=clang
|
||||
export CXX=clang++
|
||||
# DiskANN requires macOS 13.3+ for sgesdd_ LAPACK function
|
||||
# But Homebrew libraries on each macOS version require matching minimum version
|
||||
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=13.3
|
||||
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=14.0
|
||||
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
|
||||
# Set deployment target based on runner
|
||||
# macos-15-intel runs macOS 15, so target 15.0 (system libraries require it)
|
||||
if [[ "${{ matrix.os }}" == "macos-15-intel" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||
elif [[ "${{ matrix.os }}" == macos-14* ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=14.0
|
||||
elif [[ "${{ matrix.os }}" == macos-15* ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||
elif [[ "${{ matrix.os }}" == macos-26* ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=26.0
|
||||
fi
|
||||
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
||||
else
|
||||
@@ -222,16 +297,19 @@ jobs:
|
||||
if: runner.os == 'macOS'
|
||||
run: |
|
||||
# Determine deployment target based on runner OS
|
||||
# Must match the Homebrew libraries for each macOS version
|
||||
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
|
||||
HNSW_TARGET="13.0"
|
||||
DISKANN_TARGET="13.3"
|
||||
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
|
||||
HNSW_TARGET="14.0"
|
||||
DISKANN_TARGET="14.0"
|
||||
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
|
||||
# macos-15-intel runs macOS 15, so target 15.0 (system libraries require it)
|
||||
if [[ "${{ matrix.os }}" == "macos-15-intel" ]]; then
|
||||
HNSW_TARGET="15.0"
|
||||
DISKANN_TARGET="15.0"
|
||||
elif [[ "${{ matrix.os }}" == macos-14* ]]; then
|
||||
HNSW_TARGET="14.0"
|
||||
DISKANN_TARGET="14.0"
|
||||
elif [[ "${{ matrix.os }}" == macos-15* ]]; then
|
||||
HNSW_TARGET="15.0"
|
||||
DISKANN_TARGET="15.0"
|
||||
elif [[ "${{ matrix.os }}" == macos-26* ]]; then
|
||||
HNSW_TARGET="26.0"
|
||||
DISKANN_TARGET="26.0"
|
||||
fi
|
||||
|
||||
# Repair HNSW wheel
|
||||
@@ -262,18 +340,69 @@ jobs:
|
||||
|
||||
- name: Install built packages for testing
|
||||
run: |
|
||||
# Create a virtual environment with the correct Python version
|
||||
# Create uv-managed virtual environment with the requested interpreter
|
||||
uv python install ${{ matrix.python }}
|
||||
uv venv --python ${{ matrix.python }}
|
||||
source .venv/bin/activate || source .venv/Scripts/activate
|
||||
|
||||
# Install packages using --find-links to prioritize local builds
|
||||
uv pip install --find-links packages/leann-core/dist --find-links packages/leann-backend-hnsw/dist --find-links packages/leann-backend-diskann/dist packages/leann-core/dist/*.whl || uv pip install --find-links packages/leann-core/dist packages/leann-core/dist/*.tar.gz
|
||||
uv pip install --find-links packages/leann-core/dist packages/leann-backend-hnsw/dist/*.whl
|
||||
uv pip install --find-links packages/leann-core/dist packages/leann-backend-diskann/dist/*.whl
|
||||
uv pip install packages/leann/dist/*.whl || uv pip install packages/leann/dist/*.tar.gz
|
||||
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
||||
UV_PY=".venv\\Scripts\\python.exe"
|
||||
else
|
||||
UV_PY=".venv/bin/python"
|
||||
fi
|
||||
|
||||
# Install test dependencies using extras
|
||||
uv pip install -e ".[test]"
|
||||
# Install test dependency group only (avoids reinstalling project package)
|
||||
uv pip install --python "$UV_PY" --group test
|
||||
|
||||
# Install core wheel built in this job
|
||||
CORE_WHL=$(find packages/leann-core/dist -maxdepth 1 -name "*.whl" -print -quit)
|
||||
if [[ -n "$CORE_WHL" ]]; then
|
||||
uv pip install --python "$UV_PY" "$CORE_WHL"
|
||||
else
|
||||
uv pip install --python "$UV_PY" packages/leann-core/dist/*.tar.gz
|
||||
fi
|
||||
|
||||
PY_TAG=$($UV_PY -c "import sys; print(f'cp{sys.version_info[0]}{sys.version_info[1]}')")
|
||||
|
||||
if [[ "$RUNNER_OS" == "macOS" ]]; then
|
||||
# macos-15-intel runs macOS 15, so target 15.0 (system libraries require it)
|
||||
if [[ "${{ matrix.os }}" == "macos-15-intel" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||
elif [[ "${{ matrix.os }}" == macos-14* ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=14.0
|
||||
elif [[ "${{ matrix.os }}" == macos-15* ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||
elif [[ "${{ matrix.os }}" == macos-26* ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=26.0
|
||||
fi
|
||||
fi
|
||||
|
||||
HNSW_WHL=$(find packages/leann-backend-hnsw/dist -maxdepth 1 -name "*-${PY_TAG}-*.whl" -print -quit)
|
||||
if [[ -z "$HNSW_WHL" ]]; then
|
||||
HNSW_WHL=$(find packages/leann-backend-hnsw/dist -maxdepth 1 -name "*-py3-*.whl" -print -quit)
|
||||
fi
|
||||
if [[ -n "$HNSW_WHL" ]]; then
|
||||
uv pip install --python "$UV_PY" "$HNSW_WHL"
|
||||
else
|
||||
uv pip install --python "$UV_PY" ./packages/leann-backend-hnsw
|
||||
fi
|
||||
|
||||
DISKANN_WHL=$(find packages/leann-backend-diskann/dist -maxdepth 1 -name "*-${PY_TAG}-*.whl" -print -quit)
|
||||
if [[ -z "$DISKANN_WHL" ]]; then
|
||||
DISKANN_WHL=$(find packages/leann-backend-diskann/dist -maxdepth 1 -name "*-py3-*.whl" -print -quit)
|
||||
fi
|
||||
if [[ -n "$DISKANN_WHL" ]]; then
|
||||
uv pip install --python "$UV_PY" "$DISKANN_WHL"
|
||||
else
|
||||
uv pip install --python "$UV_PY" ./packages/leann-backend-diskann
|
||||
fi
|
||||
|
||||
LEANN_WHL=$(find packages/leann/dist -maxdepth 1 -name "*.whl" -print -quit)
|
||||
if [[ -n "$LEANN_WHL" ]]; then
|
||||
uv pip install --python "$UV_PY" "$LEANN_WHL"
|
||||
else
|
||||
uv pip install --python "$UV_PY" packages/leann/dist/*.tar.gz
|
||||
fi
|
||||
|
||||
- name: Run tests with pytest
|
||||
env:
|
||||
@@ -304,3 +433,53 @@ jobs:
|
||||
with:
|
||||
name: packages-${{ matrix.os }}-py${{ matrix.python }}
|
||||
path: packages/*/dist/
|
||||
|
||||
|
||||
arch-smoke:
|
||||
name: Arch Linux smoke test (install & import)
|
||||
needs: build
|
||||
runs-on: ubuntu-latest
|
||||
container:
|
||||
image: archlinux:latest
|
||||
|
||||
steps:
|
||||
- name: Prepare system
|
||||
run: |
|
||||
pacman -Syu --noconfirm
|
||||
pacman -S --noconfirm python python-pip gcc git zlib openssl
|
||||
|
||||
- name: Download ALL wheel artifacts from this run
|
||||
uses: actions/download-artifact@v5
|
||||
with:
|
||||
# Don't specify name, download all artifacts
|
||||
path: ./wheels
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
|
||||
- name: Create virtual environment and install wheels
|
||||
run: |
|
||||
uv venv
|
||||
source .venv/bin/activate || source .venv/Scripts/activate
|
||||
uv pip install --find-links wheels leann-core
|
||||
uv pip install --find-links wheels leann-backend-hnsw
|
||||
uv pip install --find-links wheels leann-backend-diskann
|
||||
uv pip install --find-links wheels leann
|
||||
|
||||
- name: Import & tiny runtime check
|
||||
env:
|
||||
OMP_NUM_THREADS: 1
|
||||
MKL_NUM_THREADS: 1
|
||||
run: |
|
||||
source .venv/bin/activate || source .venv/Scripts/activate
|
||||
python - <<'PY'
|
||||
import leann
|
||||
import leann_backend_hnsw as h
|
||||
import leann_backend_diskann as d
|
||||
from leann import LeannBuilder, LeannSearcher
|
||||
b = LeannBuilder(backend_name="hnsw")
|
||||
b.add_text("hello arch")
|
||||
b.build_index("arch_demo.leann")
|
||||
s = LeannSearcher("arch_demo.leann")
|
||||
print("search:", s.search("hello", top_k=1))
|
||||
PY
|
||||
|
||||
2
.github/workflows/link-check.yml
vendored
@@ -14,6 +14,6 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: lycheeverse/lychee-action@v2
|
||||
with:
|
||||
args: --no-progress --insecure README.md docs/ apps/ examples/ benchmarks/
|
||||
args: --no-progress --insecure --user-agent 'curl/7.68.0' --exclude '.*api\.star-history\.com.*' --accept 200,201,202,203,204,205,206,207,208,226,300,301,302,303,304,305,306,307,308,503 README.md docs/ apps/ examples/ benchmarks/
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
19
.gitignore
vendored
@@ -18,9 +18,12 @@ demo/experiment_results/**/*.json
|
||||
*.eml
|
||||
*.emlx
|
||||
*.json
|
||||
*.png
|
||||
!.vscode/*.json
|
||||
*.sh
|
||||
*.txt
|
||||
!CMakeLists.txt
|
||||
!llms.txt
|
||||
latency_breakdown*.json
|
||||
experiment_results/eval_results/diskann/*.json
|
||||
aws/
|
||||
@@ -88,7 +91,21 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/
|
||||
|
||||
*.meta.json
|
||||
*.passages.json
|
||||
|
||||
*.npy
|
||||
*.db
|
||||
batchtest.py
|
||||
tests/__pytest_cache__/
|
||||
tests/__pycache__/
|
||||
benchmarks/data/
|
||||
|
||||
## multi vector
|
||||
apps/multimodal/vision-based-pdf-multi-vector/multi-vector-colpali-native-weaviate.py
|
||||
|
||||
# Ignore all PDFs (keep data exceptions above) and do not track demo PDFs
|
||||
# If you need to commit a specific demo PDF, remove this negation locally.
|
||||
# The following line used to force-add a large demo PDF; remove it to satisfy pre-commit:
|
||||
# !apps/multimodal/vision-based-pdf-multi-vector/pdfs/2004.12832v2.pdf
|
||||
!apps/multimodal/vision-based-pdf-multi-vector/fig/*
|
||||
|
||||
# AUR build directory (Arch Linux)
|
||||
paru-bin/
|
||||
|
||||
3
.gitmodules
vendored
@@ -14,3 +14,6 @@
|
||||
[submodule "packages/leann-backend-hnsw/third_party/libzmq"]
|
||||
path = packages/leann-backend-hnsw/third_party/libzmq
|
||||
url = https://github.com/zeromq/libzmq.git
|
||||
[submodule "packages/astchunk-leann"]
|
||||
path = packages/astchunk-leann
|
||||
url = https://github.com/yichuan-w/astchunk-leann.git
|
||||
|
||||
@@ -13,4 +13,5 @@ repos:
|
||||
rev: v0.12.7 # Fixed version to match pyproject.toml
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix, --exit-non-zero-on-fix]
|
||||
- id: ruff-format
|
||||
|
||||
5
.vscode/extensions.json
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"recommendations": [
|
||||
"charliermarsh.ruff",
|
||||
]
|
||||
}
|
||||
22
.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"python.defaultInterpreterPath": ".venv/bin/python",
|
||||
"python.terminal.activateEnvironment": true,
|
||||
"[python]": {
|
||||
"editor.defaultFormatter": "charliermarsh.ruff",
|
||||
"editor.formatOnSave": true,
|
||||
"editor.codeActionsOnSave": {
|
||||
"source.organizeImports": "explicit",
|
||||
"source.fixAll": "explicit"
|
||||
},
|
||||
"editor.insertSpaces": true,
|
||||
"editor.tabSize": 4
|
||||
},
|
||||
"ruff.enable": true,
|
||||
"files.watcherExclude": {
|
||||
"**/.venv/**": true,
|
||||
"**/__pycache__/**": true,
|
||||
"**/*.egg-info/**": true,
|
||||
"**/build/**": true,
|
||||
"**/dist/**": true
|
||||
}
|
||||
}
|
||||
699
README.md
@@ -5,20 +5,38 @@
|
||||
<p align="center">
|
||||
<img src="https://img.shields.io/badge/Python-3.9%20%7C%203.10%20%7C%203.11%20%7C%203.12%20%7C%203.13-blue.svg" alt="Python Versions">
|
||||
<img src="https://github.com/yichuan-w/LEANN/actions/workflows/build-and-publish.yml/badge.svg" alt="CI Status">
|
||||
<img src="https://img.shields.io/badge/Platform-Ubuntu%20%7C%20macOS%20(ARM64%2FIntel)-lightgrey" alt="Platform">
|
||||
<img src="https://img.shields.io/badge/Platform-Ubuntu%20%26%20Arch%20%26%20WSL%20%7C%20macOS%20(ARM64%2FIntel)-lightgrey" alt="Platform">
|
||||
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
|
||||
<img src="https://img.shields.io/badge/MCP-Native%20Integration-blue" alt="MCP Integration">
|
||||
<a href="https://join.slack.com/t/leann-e2u9779/shared_invite/zt-3ckd2f6w1-OX08~NN4gkWhh10PRVBj1Q">
|
||||
<img src="https://img.shields.io/badge/Slack-Join-4A154B?logo=slack&logoColor=white" alt="Join Slack">
|
||||
</a>
|
||||
<a href="assets/wechat_user_group.JPG" title="Join WeChat group">
|
||||
<img src="https://img.shields.io/badge/WeChat-Join-2DC100?logo=wechat&logoColor=white" alt="Join WeChat group">
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
<a href="https://forms.gle/rDbZf864gMNxhpTq8">
|
||||
<img src="https://img.shields.io/badge/📣_Community_Survey-Help_Shape_v0.4-007ec6?style=for-the-badge&logo=google-forms&logoColor=white" alt="Take Survey">
|
||||
</a>
|
||||
<p>
|
||||
We track <b>zero telemetry</b>. This survey is the ONLY way to tell us if you want <br>
|
||||
<b>GPU Acceleration</b> or <b>More Integrations</b> next.<br>
|
||||
👉 <a href="https://forms.gle/rDbZf864gMNxhpTq8"><b>Click here to cast your vote (2 mins)</b></a>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<h2 align="center" tabindex="-1" class="heading-element" dir="auto">
|
||||
The smallest vector index in the world. RAG Everything with LEANN!
|
||||
</h2>
|
||||
|
||||
LEANN is an innovative vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.
|
||||
|
||||
|
||||
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
|
||||
|
||||
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)**, **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
||||
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)** ([WeChat](#-wechat-detective-unlock-your-golden-memories), [iMessage](#-imessage-history-your-personal-conversation-archive)), **[agent memory](#-chatgpt-chat-history-your-personal-ai-conversation-archive)** ([ChatGPT](#-chatgpt-chat-history-your-personal-ai-conversation-archive), [Claude](#-claude-chat-history-your-personal-ai-conversation-archive)), **[live data](#mcp-integration-rag-on-live-data-from-any-platform)** ([Slack](#slack-messages-search-your-team-conversations), [Twitter](#-twitter-bookmarks-your-personal-tweet-library)), **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
||||
|
||||
|
||||
\* Claude Code only supports basic `grep`-style keyword search. **LEANN** is a drop-in **semantic search MCP service fully compatible with Claude Code**, unlocking intelligent retrieval without changing your workflow. 🔥 Check out [the easy setup →](packages/leann-mcp/README.md)
|
||||
@@ -70,8 +88,9 @@ uv venv
|
||||
source .venv/bin/activate
|
||||
uv pip install leann
|
||||
```
|
||||
|
||||
<!--
|
||||
> Low-resource? See “Low-resource setups” in the [Configuration Guide](docs/configuration-guide.md#low-resource-setups). -->
|
||||
> Low-resource? See "Low-resource setups" in the [Configuration Guide](docs/configuration-guide.md#low-resource-setups). -->
|
||||
|
||||
<details>
|
||||
<summary>
|
||||
@@ -87,15 +106,60 @@ git submodule update --init --recursive
|
||||
```
|
||||
|
||||
**macOS:**
|
||||
|
||||
Note: DiskANN requires MacOS 13.3 or later.
|
||||
|
||||
```bash
|
||||
brew install llvm libomp boost protobuf zeromq pkgconf
|
||||
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync
|
||||
brew install libomp boost protobuf zeromq pkgconf
|
||||
uv sync --extra diskann
|
||||
```
|
||||
|
||||
**Linux:**
|
||||
**Linux (Ubuntu/Debian):**
|
||||
|
||||
Note: On Ubuntu 20.04, you may need to build a newer Abseil and pin Protobuf (e.g., v3.20.x) for building DiskANN. See [Issue #30](https://github.com/yichuan-w/LEANN/issues/30) for a step-by-step note.
|
||||
|
||||
You can manually install [Intel oneAPI MKL](https://www.intel.com/content/www/us/en/developer/tools/oneapi/onemkl.html) instead of `libmkl-full-dev` for DiskANN. You can also use `libopenblas-dev` for building HNSW only, by removing `--extra diskann` in the command below.
|
||||
|
||||
```bash
|
||||
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
|
||||
uv sync
|
||||
sudo apt-get update && sudo apt-get install -y \
|
||||
libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
|
||||
pkg-config libabsl-dev libaio-dev libprotobuf-dev \
|
||||
libmkl-full-dev
|
||||
|
||||
uv sync --extra diskann
|
||||
```
|
||||
|
||||
**Linux (Arch Linux):**
|
||||
|
||||
```bash
|
||||
sudo pacman -Syu && sudo pacman -S --needed base-devel cmake pkgconf git gcc \
|
||||
boost boost-libs protobuf abseil-cpp libaio zeromq
|
||||
|
||||
# For MKL in DiskANN
|
||||
sudo pacman -S --needed base-devel git
|
||||
git clone https://aur.archlinux.org/paru-bin.git
|
||||
cd paru-bin && makepkg -si
|
||||
paru -S intel-oneapi-mkl intel-oneapi-compiler
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
|
||||
uv sync --extra diskann
|
||||
```
|
||||
|
||||
**Linux (RHEL / CentOS Stream / Oracle / Rocky / AlmaLinux):**
|
||||
|
||||
See [Issue #50](https://github.com/yichuan-w/LEANN/issues/50) for more details.
|
||||
|
||||
```bash
|
||||
sudo dnf groupinstall -y "Development Tools"
|
||||
sudo dnf install -y libomp-devel boost-devel protobuf-compiler protobuf-devel \
|
||||
abseil-cpp-devel libaio-devel zeromq-devel pkgconf-pkg-config
|
||||
|
||||
# For MKL in DiskANN
|
||||
sudo dnf install -y intel-oneapi-mkl intel-oneapi-mkl-devel \
|
||||
intel-oneapi-openmp || sudo dnf install -y intel-oneapi-compiler
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
|
||||
uv sync --extra diskann
|
||||
```
|
||||
|
||||
</details>
|
||||
@@ -129,11 +193,16 @@ response = chat.ask("How much storage does LEANN save?", top_k=1)
|
||||
|
||||
## RAG on Everything!
|
||||
|
||||
LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`, `.md`), Apple Mail, Google Search History, WeChat, and more.
|
||||
LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`, `.md`), Apple Mail, Google Search History, WeChat, ChatGPT conversations, Claude conversations, iMessage conversations, and **live data from any platform through MCP (Model Context Protocol) servers** - including Slack, Twitter, and more.
|
||||
|
||||
|
||||
|
||||
### Generation Model Setup
|
||||
|
||||
LEANN supports multiple LLM providers for text generation (OpenAI API, HuggingFace, Ollama).
|
||||
#### LLM Backend
|
||||
|
||||
LEANN supports many LLM providers for text generation (HuggingFace, Ollama, Anthropic, and Any OpenAI compatible API).
|
||||
|
||||
|
||||
<details>
|
||||
<summary><strong>🔑 OpenAI API Setup (Default)</strong></summary>
|
||||
@@ -144,6 +213,69 @@ Set your OpenAI API key as an environment variable:
|
||||
export OPENAI_API_KEY="your-api-key-here"
|
||||
```
|
||||
|
||||
Make sure to use `--llm openai` flag when using the CLI.
|
||||
You can also specify the model name with `--llm-model <model-name>` flag.
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>🛠️ Supported LLM & Embedding Providers (via OpenAI Compatibility)</strong></summary>
|
||||
|
||||
Thanks to the widespread adoption of the OpenAI API format, LEANN is compatible out-of-the-box with a vast array of LLM and embedding providers. Simply set the `OPENAI_BASE_URL` and `OPENAI_API_KEY` environment variables to connect to your preferred service.
|
||||
|
||||
```sh
|
||||
export OPENAI_API_KEY="xxx"
|
||||
export OPENAI_BASE_URL="http://localhost:1234/v1" # base url of the provider
|
||||
```
|
||||
|
||||
To use OpenAI compatible endpoint with the CLI interface:
|
||||
|
||||
If you are using it for text generation, make sure to use `--llm openai` flag and specify the model name with `--llm-model <model-name>` flag.
|
||||
|
||||
If you are using it for embedding, set the `--embedding-mode openai` flag and specify the model name with `--embedding-model <MODEL>`.
|
||||
|
||||
-----
|
||||
|
||||
|
||||
Below is a list of base URLs for common providers to get you started.
|
||||
|
||||
|
||||
### 🖥️ Local Inference Engines (Recommended for full privacy)
|
||||
|
||||
| Provider | Sample Base URL |
|
||||
| ---------------- | --------------------------- |
|
||||
| **Ollama** | `http://localhost:11434/v1` |
|
||||
| **LM Studio** | `http://localhost:1234/v1` |
|
||||
| **vLLM** | `http://localhost:8000/v1` |
|
||||
| **llama.cpp** | `http://localhost:8080/v1` |
|
||||
| **SGLang** | `http://localhost:30000/v1` |
|
||||
| **LiteLLM** | `http://localhost:4000` |
|
||||
|
||||
-----
|
||||
|
||||
### ☁️ Cloud Providers
|
||||
|
||||
> **🚨 A Note on Privacy:** Before choosing a cloud provider, carefully review their privacy and data retention policies. Depending on their terms, your data may be used for their own purposes, including but not limited to human reviews and model training, which can lead to serious consequences if not handled properly.
|
||||
|
||||
|
||||
| Provider | Base URL |
|
||||
| ---------------- | ---------------------------------------------------------- |
|
||||
| **OpenAI** | `https://api.openai.com/v1` |
|
||||
| **OpenRouter** | `https://openrouter.ai/api/v1` |
|
||||
| **Gemini** | `https://generativelanguage.googleapis.com/v1beta/openai/` |
|
||||
| **x.AI (Grok)** | `https://api.x.ai/v1` |
|
||||
| **Groq AI** | `https://api.groq.com/openai/v1` |
|
||||
| **DeepSeek** | `https://api.deepseek.com/v1` |
|
||||
| **SiliconFlow** | `https://api.siliconflow.cn/v1` |
|
||||
| **Zhipu (BigModel)** | `https://open.bigmodel.cn/api/paas/v4/` |
|
||||
| **Mistral AI** | `https://api.mistral.ai/v1` |
|
||||
| **Anthropic** | `https://api.anthropic.com/v1` |
|
||||
|
||||
|
||||
|
||||
|
||||
If your provider isn't on this list, don't worry! Check their documentation for an OpenAI-compatible endpoint—chances are, it's OpenAI Compatible too!
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
@@ -173,7 +305,8 @@ ollama pull llama3.2:1b
|
||||
|
||||
</details>
|
||||
|
||||
### ⭐ Flexible Configuration
|
||||
|
||||
## ⭐ Flexible Configuration
|
||||
|
||||
LEANN provides flexible parameters for embedding models, search strategies, and data processing to fit your specific needs.
|
||||
|
||||
@@ -196,7 +329,7 @@ All RAG examples share these common parameters. **Interactive mode** is availabl
|
||||
--embedding-mode MODE # sentence-transformers, openai, mlx, or ollama
|
||||
|
||||
# LLM Parameters (Text generation models)
|
||||
--llm TYPE # LLM backend: openai, ollama, or hf (default: openai)
|
||||
--llm TYPE # LLM backend: openai, ollama, hf, or anthropic (default: openai)
|
||||
--llm-model MODEL # Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct
|
||||
--thinking-budget LEVEL # Thinking budget for reasoning models: low/medium/high (supported by o3, o3-mini, GPT-Oss:20b, and other reasoning models)
|
||||
|
||||
@@ -249,10 +382,64 @@ python -m apps.document_rag --data-dir "~/Documents/Papers" --chunk-size 1024
|
||||
|
||||
# Filter only markdown and Python files with smaller chunks
|
||||
python -m apps.document_rag --data-dir "./docs" --chunk-size 256 --file-types .md .py
|
||||
|
||||
# Enable AST-aware chunking for code files
|
||||
python -m apps.document_rag --enable-code-chunking --data-dir "./my_project"
|
||||
|
||||
# Or use the specialized code RAG for better code understanding
|
||||
python -m apps.code_rag --repo-dir "./my_codebase" --query "How does authentication work?"
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
### 🎨 ColQwen: Multimodal PDF Retrieval with Vision-Language Models
|
||||
|
||||
Search through PDFs using both text and visual understanding with ColQwen2/ColPali models. Perfect for research papers, technical documents, and any PDFs with complex layouts, figures, or diagrams.
|
||||
|
||||
> **🍎 Mac Users**: ColQwen is optimized for Apple Silicon with MPS acceleration for faster inference!
|
||||
|
||||
```bash
|
||||
# Build index from PDFs
|
||||
python -m apps.colqwen_rag build --pdfs ./my_papers/ --index research_papers
|
||||
|
||||
# Search with text queries
|
||||
python -m apps.colqwen_rag search research_papers "How does attention mechanism work?"
|
||||
|
||||
# Interactive Q&A
|
||||
python -m apps.colqwen_rag ask research_papers --interactive
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: ColQwen Setup & Usage</strong></summary>
|
||||
|
||||
#### Prerequisites
|
||||
```bash
|
||||
# Install dependencies
|
||||
uv pip install colpali_engine pdf2image pillow matplotlib qwen_vl_utils einops seaborn
|
||||
brew install poppler # macOS only, for PDF processing
|
||||
```
|
||||
|
||||
#### Build Index
|
||||
```bash
|
||||
python -m apps.colqwen_rag build \
|
||||
--pdfs ./pdf_directory/ \
|
||||
--index my_index \
|
||||
--model colqwen2 # or colpali
|
||||
```
|
||||
|
||||
#### Search
|
||||
```bash
|
||||
python -m apps.colqwen_rag search my_index "your question here" --top-k 5
|
||||
```
|
||||
|
||||
#### Models
|
||||
- **ColQwen2** (`colqwen2`): Latest vision-language model with improved performance
|
||||
- **ColPali** (`colpali`): Proven multimodal retriever
|
||||
|
||||
For detailed usage, see the [ColQwen Guide](docs/COLQWEN_GUIDE.md).
|
||||
|
||||
</details>
|
||||
|
||||
### 📧 Your Personal Email Secretary: RAG on Apple Mail!
|
||||
|
||||
> **Note:** The examples below currently support macOS only. Windows support coming soon.
|
||||
@@ -421,28 +608,414 @@ Once the index is built, you can ask questions like:
|
||||
|
||||
</details>
|
||||
|
||||
### 🤖 ChatGPT Chat History: Your Personal AI Conversation Archive!
|
||||
|
||||
Transform your ChatGPT conversations into a searchable knowledge base! Search through all your ChatGPT discussions about coding, research, brainstorming, and more.
|
||||
|
||||
```bash
|
||||
python -m apps.chatgpt_rag --export-path chatgpt_export.html --query "How do I create a list in Python?"
|
||||
```
|
||||
|
||||
**Unlock your AI conversation history.** Never lose track of valuable insights from your ChatGPT discussions again.
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: How to Export ChatGPT Data</strong></summary>
|
||||
|
||||
**Step-by-step export process:**
|
||||
|
||||
1. **Sign in to ChatGPT**
|
||||
2. **Click your profile icon** in the top right corner
|
||||
3. **Navigate to Settings** → **Data Controls**
|
||||
4. **Click "Export"** under Export Data
|
||||
5. **Confirm the export** request
|
||||
6. **Download the ZIP file** from the email link (expires in 24 hours)
|
||||
7. **Extract or use directly** with LEANN
|
||||
|
||||
**Supported formats:**
|
||||
- `.html` files from ChatGPT exports
|
||||
- `.zip` archives from ChatGPT
|
||||
- Directories with multiple export files
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: ChatGPT-Specific Arguments</strong></summary>
|
||||
|
||||
#### Parameters
|
||||
```bash
|
||||
--export-path PATH # Path to ChatGPT export file (.html/.zip) or directory (default: ./chatgpt_export)
|
||||
--separate-messages # Process each message separately instead of concatenated conversations
|
||||
--chunk-size N # Text chunk size (default: 512)
|
||||
--chunk-overlap N # Overlap between chunks (default: 128)
|
||||
```
|
||||
|
||||
#### Example Commands
|
||||
```bash
|
||||
# Basic usage with HTML export
|
||||
python -m apps.chatgpt_rag --export-path conversations.html
|
||||
|
||||
# Process ZIP archive from ChatGPT
|
||||
python -m apps.chatgpt_rag --export-path chatgpt_export.zip
|
||||
|
||||
# Search with specific query
|
||||
python -m apps.chatgpt_rag --export-path chatgpt_data.html --query "Python programming help"
|
||||
|
||||
# Process individual messages for fine-grained search
|
||||
python -m apps.chatgpt_rag --separate-messages --export-path chatgpt_export.html
|
||||
|
||||
# Process directory containing multiple exports
|
||||
python -m apps.chatgpt_rag --export-path ./chatgpt_exports/ --max-items 1000
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>💡 Click to expand: Example queries you can try</strong></summary>
|
||||
|
||||
Once your ChatGPT conversations are indexed, you can search with queries like:
|
||||
- "What did I ask ChatGPT about Python programming?"
|
||||
- "Show me conversations about machine learning algorithms"
|
||||
- "Find discussions about web development frameworks"
|
||||
- "What coding advice did ChatGPT give me?"
|
||||
- "Search for conversations about debugging techniques"
|
||||
- "Find ChatGPT's recommendations for learning resources"
|
||||
|
||||
</details>
|
||||
|
||||
### 🤖 Claude Chat History: Your Personal AI Conversation Archive!
|
||||
|
||||
Transform your Claude conversations into a searchable knowledge base! Search through all your Claude discussions about coding, research, brainstorming, and more.
|
||||
|
||||
```bash
|
||||
python -m apps.claude_rag --export-path claude_export.json --query "What did I ask about Python dictionaries?"
|
||||
```
|
||||
|
||||
**Unlock your AI conversation history.** Never lose track of valuable insights from your Claude discussions again.
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: How to Export Claude Data</strong></summary>
|
||||
|
||||
**Step-by-step export process:**
|
||||
|
||||
1. **Open Claude** in your browser
|
||||
2. **Navigate to Settings** (look for gear icon or settings menu)
|
||||
3. **Find Export/Download** options in your account settings
|
||||
4. **Download conversation data** (usually in JSON format)
|
||||
5. **Place the file** in your project directory
|
||||
|
||||
*Note: Claude export methods may vary depending on the interface you're using. Check Claude's help documentation for the most current export instructions.*
|
||||
|
||||
**Supported formats:**
|
||||
- `.json` files (recommended)
|
||||
- `.zip` archives containing JSON data
|
||||
- Directories with multiple export files
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Claude-Specific Arguments</strong></summary>
|
||||
|
||||
#### Parameters
|
||||
```bash
|
||||
--export-path PATH # Path to Claude export file (.json/.zip) or directory (default: ./claude_export)
|
||||
--separate-messages # Process each message separately instead of concatenated conversations
|
||||
--chunk-size N # Text chunk size (default: 512)
|
||||
--chunk-overlap N # Overlap between chunks (default: 128)
|
||||
```
|
||||
|
||||
#### Example Commands
|
||||
```bash
|
||||
# Basic usage with JSON export
|
||||
python -m apps.claude_rag --export-path my_claude_conversations.json
|
||||
|
||||
# Process ZIP archive from Claude
|
||||
python -m apps.claude_rag --export-path claude_export.zip
|
||||
|
||||
# Search with specific query
|
||||
python -m apps.claude_rag --export-path claude_data.json --query "machine learning advice"
|
||||
|
||||
# Process individual messages for fine-grained search
|
||||
python -m apps.claude_rag --separate-messages --export-path claude_export.json
|
||||
|
||||
# Process directory containing multiple exports
|
||||
python -m apps.claude_rag --export-path ./claude_exports/ --max-items 1000
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>💡 Click to expand: Example queries you can try</strong></summary>
|
||||
|
||||
Once your Claude conversations are indexed, you can search with queries like:
|
||||
- "What did I ask Claude about Python programming?"
|
||||
- "Show me conversations about machine learning algorithms"
|
||||
- "Find discussions about software architecture patterns"
|
||||
- "What debugging advice did Claude give me?"
|
||||
- "Search for conversations about data structures"
|
||||
- "Find Claude's recommendations for learning resources"
|
||||
|
||||
</details>
|
||||
|
||||
### 💬 iMessage History: Your Personal Conversation Archive!
|
||||
|
||||
Transform your iMessage conversations into a searchable knowledge base! Search through all your text messages, group chats, and conversations with friends, family, and colleagues.
|
||||
|
||||
```bash
|
||||
python -m apps.imessage_rag --query "What did we discuss about the weekend plans?"
|
||||
```
|
||||
|
||||
**Unlock your message history.** Never lose track of important conversations, shared links, or memorable moments from your iMessage history.
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: How to Access iMessage Data</strong></summary>
|
||||
|
||||
**iMessage data location:**
|
||||
|
||||
iMessage conversations are stored in a SQLite database on your Mac at:
|
||||
```
|
||||
~/Library/Messages/chat.db
|
||||
```
|
||||
|
||||
**Important setup requirements:**
|
||||
|
||||
1. **Grant Full Disk Access** to your terminal or IDE:
|
||||
- Open **System Preferences** → **Security & Privacy** → **Privacy**
|
||||
- Select **Full Disk Access** from the left sidebar
|
||||
- Click the **+** button and add your terminal app (Terminal, iTerm2) or IDE (VS Code, etc.)
|
||||
- Restart your terminal/IDE after granting access
|
||||
|
||||
2. **Alternative: Use a backup database**
|
||||
- If you have Time Machine backups or manual copies of the database
|
||||
- Use `--db-path` to specify a custom location
|
||||
|
||||
**Supported formats:**
|
||||
- Direct access to `~/Library/Messages/chat.db` (default)
|
||||
- Custom database path with `--db-path`
|
||||
- Works with backup copies of the database
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: iMessage-Specific Arguments</strong></summary>
|
||||
|
||||
#### Parameters
|
||||
```bash
|
||||
--db-path PATH # Path to chat.db file (default: ~/Library/Messages/chat.db)
|
||||
--concatenate-conversations # Group messages by conversation (default: True)
|
||||
--no-concatenate-conversations # Process each message individually
|
||||
--chunk-size N # Text chunk size (default: 1000)
|
||||
--chunk-overlap N # Overlap between chunks (default: 200)
|
||||
```
|
||||
|
||||
#### Example Commands
|
||||
```bash
|
||||
# Basic usage (requires Full Disk Access)
|
||||
python -m apps.imessage_rag
|
||||
|
||||
# Search with specific query
|
||||
python -m apps.imessage_rag --query "family dinner plans"
|
||||
|
||||
# Use custom database path
|
||||
python -m apps.imessage_rag --db-path /path/to/backup/chat.db
|
||||
|
||||
# Process individual messages instead of conversations
|
||||
python -m apps.imessage_rag --no-concatenate-conversations
|
||||
|
||||
# Limit processing for testing
|
||||
python -m apps.imessage_rag --max-items 100 --query "weekend"
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>💡 Click to expand: Example queries you can try</strong></summary>
|
||||
|
||||
Once your iMessage conversations are indexed, you can search with queries like:
|
||||
- "What did we discuss about vacation plans?"
|
||||
- "Find messages about restaurant recommendations"
|
||||
- "Show me conversations with John about the project"
|
||||
- "Search for shared links about technology"
|
||||
- "Find group chat discussions about weekend events"
|
||||
- "What did mom say about the family gathering?"
|
||||
|
||||
</details>
|
||||
|
||||
### MCP Integration: RAG on Live Data from Any Platform
|
||||
|
||||
Connect to live data sources through the Model Context Protocol (MCP). LEANN now supports real-time RAG on platforms like Slack, Twitter, and more through standardized MCP servers.
|
||||
|
||||
**Key Benefits:**
|
||||
- **Live Data Access**: Fetch real-time data without manual exports
|
||||
- **Standardized Protocol**: Use any MCP-compatible server
|
||||
- **Easy Extension**: Add new platforms with minimal code
|
||||
- **Secure Access**: MCP servers handle authentication
|
||||
|
||||
#### 💬 Slack Messages: Search Your Team Conversations
|
||||
|
||||
Transform your Slack workspace into a searchable knowledge base! Find discussions, decisions, and shared knowledge across all your channels.
|
||||
|
||||
```bash
|
||||
# Test MCP server connection
|
||||
python -m apps.slack_rag --mcp-server "slack-mcp-server" --test-connection
|
||||
|
||||
# Index and search Slack messages
|
||||
python -m apps.slack_rag \
|
||||
--mcp-server "slack-mcp-server" \
|
||||
--workspace-name "my-team" \
|
||||
--channels general dev-team random \
|
||||
--query "What did we decide about the product launch?"
|
||||
```
|
||||
|
||||
**📖 Comprehensive Setup Guide**: For detailed setup instructions, troubleshooting common issues (like "users cache is not ready yet"), and advanced configuration options, see our [**Slack Setup Guide**](docs/slack-setup-guide.md).
|
||||
|
||||
**Quick Setup:**
|
||||
1. Install a Slack MCP server (e.g., `npm install -g slack-mcp-server`)
|
||||
2. Create a Slack App and get API credentials (see detailed guide above)
|
||||
3. Set environment variables:
|
||||
```bash
|
||||
export SLACK_BOT_TOKEN="xoxb-your-bot-token"
|
||||
export SLACK_APP_TOKEN="xapp-your-app-token" # Optional
|
||||
```
|
||||
4. Test connection with `--test-connection` flag
|
||||
|
||||
**Arguments:**
|
||||
- `--mcp-server`: Command to start the Slack MCP server
|
||||
- `--workspace-name`: Slack workspace name for organization
|
||||
- `--channels`: Specific channels to index (optional)
|
||||
- `--concatenate-conversations`: Group messages by channel (default: true)
|
||||
- `--max-messages-per-channel`: Limit messages per channel (default: 100)
|
||||
- `--max-retries`: Maximum retries for cache sync issues (default: 5)
|
||||
- `--retry-delay`: Initial delay between retries in seconds (default: 2.0)
|
||||
|
||||
#### 🐦 Twitter Bookmarks: Your Personal Tweet Library
|
||||
|
||||
Search through your Twitter bookmarks! Find that perfect article, thread, or insight you saved for later.
|
||||
|
||||
```bash
|
||||
# Test MCP server connection
|
||||
python -m apps.twitter_rag --mcp-server "twitter-mcp-server" --test-connection
|
||||
|
||||
# Index and search Twitter bookmarks
|
||||
python -m apps.twitter_rag \
|
||||
--mcp-server "twitter-mcp-server" \
|
||||
--max-bookmarks 1000 \
|
||||
--query "What AI articles did I bookmark about machine learning?"
|
||||
```
|
||||
|
||||
**Setup Requirements:**
|
||||
1. Install a Twitter MCP server (e.g., `npm install -g twitter-mcp-server`)
|
||||
2. Get Twitter API credentials:
|
||||
- Apply for a Twitter Developer Account at [developer.twitter.com](https://developer.twitter.com)
|
||||
- Create a new app in the Twitter Developer Portal
|
||||
- Generate API keys and access tokens with "Read" permissions
|
||||
- For bookmarks access, you may need Twitter API v2 with appropriate scopes
|
||||
```bash
|
||||
export TWITTER_API_KEY="your-api-key"
|
||||
export TWITTER_API_SECRET="your-api-secret"
|
||||
export TWITTER_ACCESS_TOKEN="your-access-token"
|
||||
export TWITTER_ACCESS_TOKEN_SECRET="your-access-token-secret"
|
||||
```
|
||||
3. Test connection with `--test-connection` flag
|
||||
|
||||
**Arguments:**
|
||||
- `--mcp-server`: Command to start the Twitter MCP server
|
||||
- `--username`: Filter bookmarks by username (optional)
|
||||
- `--max-bookmarks`: Maximum bookmarks to fetch (default: 1000)
|
||||
- `--no-tweet-content`: Exclude tweet content, only metadata
|
||||
- `--no-metadata`: Exclude engagement metadata
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>💡 Click to expand: Example queries you can try</strong></summary>
|
||||
|
||||
**Slack Queries:**
|
||||
- "What did the team discuss about the project deadline?"
|
||||
- "Find messages about the new feature launch"
|
||||
- "Show me conversations about budget planning"
|
||||
- "What decisions were made in the dev-team channel?"
|
||||
|
||||
**Twitter Queries:**
|
||||
- "What AI articles did I bookmark last month?"
|
||||
- "Find tweets about machine learning techniques"
|
||||
- "Show me bookmarked threads about startup advice"
|
||||
- "What Python tutorials did I save?"
|
||||
|
||||
</details>
|
||||
<summary><strong>🔧 Using MCP with CLI Commands</strong></summary>
|
||||
|
||||
**Want to use MCP data with regular LEANN CLI?** You can combine MCP apps with CLI commands:
|
||||
|
||||
```bash
|
||||
# Step 1: Use MCP app to fetch and index data
|
||||
python -m apps.slack_rag --mcp-server "slack-mcp-server" --workspace-name "my-team"
|
||||
|
||||
# Step 2: The data is now indexed and available via CLI
|
||||
leann search slack_messages "project deadline"
|
||||
leann ask slack_messages "What decisions were made about the product launch?"
|
||||
|
||||
# Same for Twitter bookmarks
|
||||
python -m apps.twitter_rag --mcp-server "twitter-mcp-server"
|
||||
leann search twitter_bookmarks "machine learning articles"
|
||||
```
|
||||
|
||||
**MCP vs Manual Export:**
|
||||
- **MCP**: Live data, automatic updates, requires server setup
|
||||
- **Manual Export**: One-time setup, works offline, requires manual data export
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>🔧 Adding New MCP Platforms</strong></summary>
|
||||
|
||||
Want to add support for other platforms? LEANN's MCP integration is designed for easy extension:
|
||||
|
||||
1. **Find or create an MCP server** for your platform
|
||||
2. **Create a reader class** following the pattern in `apps/slack_data/slack_mcp_reader.py`
|
||||
3. **Create a RAG application** following the pattern in `apps/slack_rag.py`
|
||||
4. **Test and contribute** back to the community!
|
||||
|
||||
**Popular MCP servers to explore:**
|
||||
- GitHub repositories and issues
|
||||
- Discord messages
|
||||
- Notion pages
|
||||
- Google Drive documents
|
||||
- And many more in the MCP ecosystem!
|
||||
|
||||
</details>
|
||||
|
||||
### 🚀 Claude Code Integration: Transform Your Development Workflow!
|
||||
|
||||
<details>
|
||||
<summary><strong>AST‑Aware Code Chunking</strong></summary>
|
||||
|
||||
LEANN features intelligent code chunking that preserves semantic boundaries (functions, classes, methods) for Python, Java, C#, and TypeScript, improving code understanding compared to text-based chunking.
|
||||
|
||||
📖 Read the [AST Chunking Guide →](docs/ast_chunking_guide.md)
|
||||
|
||||
</details>
|
||||
|
||||
**The future of code assistance is here.** Transform your development workflow with LEANN's native MCP integration for Claude Code. Index your entire codebase and get intelligent code assistance directly in your IDE.
|
||||
|
||||
**Key features:**
|
||||
- 🔍 **Semantic code search** across your entire project
|
||||
- 🔍 **Semantic code search** across your entire project, fully local index and lightweight
|
||||
- 🧠 **AST-aware chunking** preserves code structure (functions, classes)
|
||||
- 📚 **Context-aware assistance** for debugging and development
|
||||
- 🚀 **Zero-config setup** with automatic language detection
|
||||
|
||||
```bash
|
||||
# Install LEANN globally for MCP integration
|
||||
uv tool install leann-core
|
||||
|
||||
uv tool install leann-core --with leann
|
||||
claude mcp add --scope user leann-server -- leann_mcp
|
||||
# Setup is automatic - just start using Claude Code!
|
||||
```
|
||||
Try our fully agentic pipeline with auto query rewriting, semantic search planning, and more:
|
||||
|
||||

|
||||
|
||||
**Ready to supercharge your coding?** [Complete Setup Guide →](packages/leann-mcp/README.md)
|
||||
**🔥 Ready to supercharge your coding?** [Complete Setup Guide →](packages/leann-mcp/README.md)
|
||||
|
||||
## 🖥️ Command Line Interface
|
||||
## Command Line Interface
|
||||
|
||||
LEANN includes a powerful CLI for document processing and search. Perfect for quick document indexing and interactive chat.
|
||||
|
||||
@@ -457,7 +1030,8 @@ leann --help
|
||||
**To make it globally available:**
|
||||
```bash
|
||||
# Install the LEANN CLI globally using uv tool
|
||||
uv tool install leann-core
|
||||
uv tool install leann-core --with leann
|
||||
|
||||
|
||||
# Now you can use leann from anywhere without activating venv
|
||||
leann --help
|
||||
@@ -479,13 +1053,20 @@ leann search my-docs "machine learning concepts"
|
||||
# Interactive chat with your documents
|
||||
leann ask my-docs --interactive
|
||||
|
||||
# Ask a single question (non-interactive)
|
||||
leann ask my-docs "Where are prompts configured?"
|
||||
|
||||
# List all your indexes
|
||||
leann list
|
||||
|
||||
# Remove an index
|
||||
leann remove my-docs
|
||||
```
|
||||
|
||||
**Key CLI features:**
|
||||
- Auto-detects document formats (PDF, TXT, MD, DOCX, PPTX + code files)
|
||||
- Smart text chunking with overlap
|
||||
- **🧠 AST-aware chunking** for Python, Java, C#, TypeScript files
|
||||
- Smart text chunking with overlap for all other content
|
||||
- Multiple LLM providers (Ollama, OpenAI, HuggingFace)
|
||||
- Organized index storage in `.leann/indexes/` (project-local)
|
||||
- Support for advanced search parameters
|
||||
@@ -493,7 +1074,7 @@ leann list
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Complete CLI Reference</strong></summary>
|
||||
|
||||
You can use `leann --help`, or `leann build --help`, `leann search --help`, `leann ask --help` to get the complete CLI reference.
|
||||
You can use `leann --help`, or `leann build --help`, `leann search --help`, `leann ask --help`, `leann list --help`, `leann remove --help` to get the complete CLI reference.
|
||||
|
||||
**Build Command:**
|
||||
```bash
|
||||
@@ -525,14 +1106,79 @@ Options:
|
||||
leann ask INDEX_NAME [OPTIONS]
|
||||
|
||||
Options:
|
||||
--llm {ollama,openai,hf} LLM provider (default: ollama)
|
||||
--llm {ollama,openai,hf,anthropic} LLM provider (default: ollama)
|
||||
--model MODEL Model name (default: qwen3:8b)
|
||||
--interactive Interactive chat mode
|
||||
--top-k N Retrieval count (default: 20)
|
||||
```
|
||||
|
||||
**List Command:**
|
||||
```bash
|
||||
leann list
|
||||
|
||||
# Lists all indexes across all projects with status indicators:
|
||||
# ✅ - Index is complete and ready to use
|
||||
# ❌ - Index is incomplete or corrupted
|
||||
# 📁 - CLI-created index (in .leann/indexes/)
|
||||
# 📄 - App-created index (*.leann.meta.json files)
|
||||
```
|
||||
|
||||
**Remove Command:**
|
||||
```bash
|
||||
leann remove INDEX_NAME [OPTIONS]
|
||||
|
||||
Options:
|
||||
--force, -f Force removal without confirmation
|
||||
|
||||
# Smart removal: automatically finds and safely removes indexes
|
||||
# - Shows all matching indexes across projects
|
||||
# - Requires confirmation for cross-project removal
|
||||
# - Interactive selection when multiple matches found
|
||||
# - Supports both CLI and app-created indexes
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## 🚀 Advanced Features
|
||||
|
||||
### 🎯 Metadata Filtering
|
||||
|
||||
LEANN supports a simple metadata filtering system to enable sophisticated use cases like document filtering by date/type, code search by file extension, and content management based on custom criteria.
|
||||
|
||||
```python
|
||||
# Add metadata during indexing
|
||||
builder.add_text(
|
||||
"def authenticate_user(token): ...",
|
||||
metadata={"file_extension": ".py", "lines_of_code": 25}
|
||||
)
|
||||
|
||||
# Search with filters
|
||||
results = searcher.search(
|
||||
query="authentication function",
|
||||
metadata_filters={
|
||||
"file_extension": {"==": ".py"},
|
||||
"lines_of_code": {"<": 100}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
**Supported operators**: `==`, `!=`, `<`, `<=`, `>`, `>=`, `in`, `not_in`, `contains`, `starts_with`, `ends_with`, `is_true`, `is_false`
|
||||
|
||||
📖 **[Complete Metadata filtering guide →](docs/metadata_filtering.md)**
|
||||
|
||||
### 🔍 Grep Search
|
||||
|
||||
For exact text matching instead of semantic search, use the `use_grep` parameter:
|
||||
|
||||
```python
|
||||
# Exact text search
|
||||
results = searcher.search("banana‑crocodile", use_grep=True, top_k=1)
|
||||
```
|
||||
|
||||
**Use cases**: Finding specific code patterns, error messages, function names, or exact phrases where semantic similarity isn't needed.
|
||||
|
||||
📖 **[Complete grep search guide →](docs/grep_search.md)**
|
||||
|
||||
## 🏗️ Architecture & How It Works
|
||||
|
||||
<p align="center">
|
||||
@@ -570,8 +1216,8 @@ Options:
|
||||
## Reproduce Our Results
|
||||
|
||||
```bash
|
||||
uv pip install -e ".[dev]" # Install dev dependencies
|
||||
python benchmarks/run_evaluation.py # Will auto-download evaluation data and run benchmarks
|
||||
uv run benchmarks/run_evaluation.py # Will auto-download evaluation data and run benchmarks
|
||||
uv run benchmarks/run_evaluation.py benchmarks/data/indices/rpj_wiki/rpj_wiki --num-queries 2000 # After downloading data, you can run the benchmark with our biggest index
|
||||
```
|
||||
|
||||
The evaluation script downloads data automatically on first run. The last three results were tested with partial personal data, and you can reproduce them with your own data!
|
||||
@@ -611,6 +1257,9 @@ MIT License - see [LICENSE](LICENSE) for details.
|
||||
|
||||
Core Contributors: [Yichuan Wang](https://yichuan-w.github.io/) & [Zhifei Li](https://github.com/andylizf).
|
||||
|
||||
Active Contributors: [Gabriel Dehan](https://github.com/gabriel-dehan), [Aakash Suresh](https://github.com/ASuresh0524)
|
||||
|
||||
|
||||
We welcome more contributors! Feel free to open issues or submit PRs.
|
||||
|
||||
This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/).
|
||||
@@ -625,3 +1274,7 @@ This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.ed
|
||||
<p align="center">
|
||||
Made with ❤️ by the Leann team
|
||||
</p>
|
||||
|
||||
## 🤖 Explore LEANN with AI
|
||||
|
||||
LEANN is indexed on [DeepWiki](https://deepwiki.com/yichuan-w/LEANN), so you can ask questions to LLMs using Deep Research to explore the codebase and get help to add new features.
|
||||
|
||||
@@ -10,7 +10,39 @@ from typing import Any
|
||||
|
||||
import dotenv
|
||||
from leann.api import LeannBuilder, LeannChat
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
# Optional import: older PyPI builds may not include interactive_utils
|
||||
try:
|
||||
from leann.interactive_utils import create_rag_session
|
||||
except ImportError:
|
||||
|
||||
def create_rag_session(app_name: str, data_description: str):
|
||||
class _SimpleSession:
|
||||
def run_interactive_loop(self, handler):
|
||||
print(f"Interactive session for {app_name}: {data_description}")
|
||||
print("Interactive mode not available in this build")
|
||||
|
||||
return _SimpleSession()
|
||||
|
||||
|
||||
from leann.registry import register_project_directory
|
||||
|
||||
# Optional import: older PyPI builds may not include settings
|
||||
try:
|
||||
from leann.settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||
except ImportError:
|
||||
# Minimal fallbacks if settings helpers are unavailable
|
||||
import os
|
||||
|
||||
def resolve_ollama_host(value: str | None) -> str | None:
|
||||
return value or os.getenv("LEANN_OLLAMA_HOST") or os.getenv("OLLAMA_HOST")
|
||||
|
||||
def resolve_openai_api_key(value: str | None) -> str | None:
|
||||
return value or os.getenv("OPENAI_API_KEY")
|
||||
|
||||
def resolve_openai_base_url(value: str | None) -> str | None:
|
||||
return value or os.getenv("OPENAI_BASE_URL")
|
||||
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
@@ -78,6 +110,24 @@ class BaseRAGExample(ABC):
|
||||
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||
help="Embedding backend mode (default: sentence-transformers), we provide sentence-transformers, openai, mlx, or ollama",
|
||||
)
|
||||
embedding_group.add_argument(
|
||||
"--embedding-host",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Override Ollama-compatible embedding host",
|
||||
)
|
||||
embedding_group.add_argument(
|
||||
"--embedding-api-base",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Base URL for OpenAI-compatible embedding services",
|
||||
)
|
||||
embedding_group.add_argument(
|
||||
"--embedding-api-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="API key for embedding service (defaults to OPENAI_API_KEY)",
|
||||
)
|
||||
|
||||
# LLM parameters
|
||||
llm_group = parser.add_argument_group("LLM Parameters")
|
||||
@@ -97,8 +147,8 @@ class BaseRAGExample(ABC):
|
||||
llm_group.add_argument(
|
||||
"--llm-host",
|
||||
type=str,
|
||||
default="http://localhost:11434",
|
||||
help="Host for Ollama API (default: http://localhost:11434)",
|
||||
default=None,
|
||||
help="Host for Ollama-compatible APIs (defaults to LEANN_OLLAMA_HOST/OLLAMA_HOST)",
|
||||
)
|
||||
llm_group.add_argument(
|
||||
"--thinking-budget",
|
||||
@@ -107,6 +157,50 @@ class BaseRAGExample(ABC):
|
||||
default=None,
|
||||
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
|
||||
)
|
||||
llm_group.add_argument(
|
||||
"--llm-api-base",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Base URL for OpenAI-compatible APIs",
|
||||
)
|
||||
llm_group.add_argument(
|
||||
"--llm-api-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)",
|
||||
)
|
||||
|
||||
# AST Chunking parameters
|
||||
ast_group = parser.add_argument_group("AST Chunking Parameters")
|
||||
ast_group.add_argument(
|
||||
"--use-ast-chunking",
|
||||
action="store_true",
|
||||
help="Enable AST-aware chunking for code files (requires astchunk)",
|
||||
)
|
||||
ast_group.add_argument(
|
||||
"--ast-chunk-size",
|
||||
type=int,
|
||||
default=300,
|
||||
help="Maximum CHARACTERS per AST chunk (default: 300). Final chunks may be larger due to overlap. For 512 token models: recommended 300 chars",
|
||||
)
|
||||
ast_group.add_argument(
|
||||
"--ast-chunk-overlap",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Overlap between AST chunks in CHARACTERS (default: 64). Added to chunk size, not included in it",
|
||||
)
|
||||
ast_group.add_argument(
|
||||
"--code-file-extensions",
|
||||
nargs="+",
|
||||
default=None,
|
||||
help="Additional code file extensions to process with AST chunking (e.g., .py .java .cs .ts)",
|
||||
)
|
||||
ast_group.add_argument(
|
||||
"--ast-fallback-traditional",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Fall back to traditional chunking if AST chunking fails (default: True)",
|
||||
)
|
||||
|
||||
# Search parameters
|
||||
search_group = parser.add_argument_group("Search Parameters")
|
||||
@@ -163,8 +257,8 @@ class BaseRAGExample(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load data from the source. Returns list of text chunks."""
|
||||
async def load_data(self, args) -> list[dict[str, Any]]:
|
||||
"""Load data from the source. Returns list of text chunks as dicts with 'text' and 'metadata' keys."""
|
||||
pass
|
||||
|
||||
def get_llm_config(self, args) -> dict[str, Any]:
|
||||
@@ -173,9 +267,13 @@ class BaseRAGExample(ABC):
|
||||
|
||||
if args.llm == "openai":
|
||||
config["model"] = args.llm_model or "gpt-4o"
|
||||
config["base_url"] = resolve_openai_base_url(args.llm_api_base)
|
||||
resolved_key = resolve_openai_api_key(args.llm_api_key)
|
||||
if resolved_key:
|
||||
config["api_key"] = resolved_key
|
||||
elif args.llm == "ollama":
|
||||
config["model"] = args.llm_model or "llama3.2:1b"
|
||||
config["host"] = args.llm_host
|
||||
config["host"] = resolve_ollama_host(args.llm_host)
|
||||
elif args.llm == "hf":
|
||||
config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
elif args.llm == "simulated":
|
||||
@@ -184,17 +282,27 @@ class BaseRAGExample(ABC):
|
||||
|
||||
return config
|
||||
|
||||
async def build_index(self, args, texts: list[str]) -> str:
|
||||
"""Build LEANN index from texts."""
|
||||
async def build_index(self, args, texts: list[dict[str, Any]]) -> str:
|
||||
"""Build LEANN index from text chunks (dicts with 'text' and 'metadata' keys)."""
|
||||
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
|
||||
|
||||
print(f"\n[Building Index] Creating {self.name} index...")
|
||||
print(f"Total text chunks: {len(texts)}")
|
||||
|
||||
embedding_options: dict[str, Any] = {}
|
||||
if args.embedding_mode == "ollama":
|
||||
embedding_options["host"] = resolve_ollama_host(args.embedding_host)
|
||||
elif args.embedding_mode == "openai":
|
||||
embedding_options["base_url"] = resolve_openai_base_url(args.embedding_api_base)
|
||||
resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
|
||||
if resolved_embedding_key:
|
||||
embedding_options["api_key"] = resolved_embedding_key
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name=args.backend_name,
|
||||
embedding_model=args.embedding_model,
|
||||
embedding_mode=args.embedding_mode,
|
||||
embedding_options=embedding_options or None,
|
||||
graph_degree=args.graph_degree,
|
||||
complexity=args.build_complexity,
|
||||
is_compact=not args.no_compact,
|
||||
@@ -206,14 +314,25 @@ class BaseRAGExample(ABC):
|
||||
batch_size = 1000
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i : i + batch_size]
|
||||
for text in batch:
|
||||
builder.add_text(text)
|
||||
for item in batch:
|
||||
# Handle both dict format (from create_text_chunks) and plain strings
|
||||
if isinstance(item, dict):
|
||||
text = item.get("text", "")
|
||||
metadata = item.get("metadata")
|
||||
builder.add_text(text, metadata)
|
||||
else:
|
||||
builder.add_text(item)
|
||||
print(f"Added {min(i + batch_size, len(texts))}/{len(texts)} texts...")
|
||||
|
||||
print("Building index structure...")
|
||||
builder.build_index(index_path)
|
||||
print(f"Index saved to: {index_path}")
|
||||
|
||||
# Register project directory so leann list can discover this index
|
||||
# The index is saved as args.index_dir/index_name.leann
|
||||
# We want to register the current working directory where the app is run
|
||||
register_project_directory(Path.cwd())
|
||||
|
||||
return index_path
|
||||
|
||||
async def run_interactive_chat(self, args, index_path: str):
|
||||
@@ -225,19 +344,12 @@ class BaseRAGExample(ABC):
|
||||
complexity=args.search_complexity,
|
||||
)
|
||||
|
||||
print(f"\n[Interactive Mode] Chat with your {self.name} data!")
|
||||
print("Type 'quit' or 'exit' to stop.\n")
|
||||
|
||||
while True:
|
||||
try:
|
||||
query = input("You: ").strip()
|
||||
if query.lower() in ["quit", "exit", "q"]:
|
||||
print("Goodbye!")
|
||||
break
|
||||
|
||||
if not query:
|
||||
continue
|
||||
# Create interactive session
|
||||
session = create_rag_session(
|
||||
app_name=self.name.lower().replace(" ", "_"), data_description=self.name
|
||||
)
|
||||
|
||||
def handle_query(query: str):
|
||||
# Prepare LLM kwargs with thinking budget if specified
|
||||
llm_kwargs = {}
|
||||
if hasattr(args, "thinking_budget") and args.thinking_budget:
|
||||
@@ -251,18 +363,13 @@ class BaseRAGExample(ABC):
|
||||
)
|
||||
print(f"\nAssistant: {response}\n")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nGoodbye!")
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
session.run_interactive_loop(handle_query)
|
||||
|
||||
async def run_single_query(self, args, index_path: str, query: str):
|
||||
"""Run a single query against the index."""
|
||||
chat = LeannChat(
|
||||
index_path,
|
||||
llm_config=self.get_llm_config(args),
|
||||
system_prompt=f"You are a helpful assistant that answers questions about {self.name} data.",
|
||||
complexity=args.search_complexity,
|
||||
)
|
||||
|
||||
@@ -304,21 +411,3 @@ class BaseRAGExample(ABC):
|
||||
await self.run_single_query(args, index_path, args.query)
|
||||
else:
|
||||
await self.run_interactive_chat(args, index_path)
|
||||
|
||||
|
||||
def create_text_chunks(documents, chunk_size=256, chunk_overlap=25) -> list[str]:
|
||||
"""Helper function to create text chunks from documents."""
|
||||
node_parser = SentenceSplitter(
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
separator=" ",
|
||||
paragraph_separator="\n\n",
|
||||
)
|
||||
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
nodes = node_parser.get_nodes_from_documents([doc])
|
||||
if nodes:
|
||||
all_texts.extend(node.get_content() for node in nodes)
|
||||
|
||||
return all_texts
|
||||
|
||||
@@ -6,11 +6,13 @@ Supports Chrome browser history.
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample, create_text_chunks
|
||||
from base_rag_example import BaseRAGExample
|
||||
from chunking import create_text_chunks
|
||||
|
||||
from .history_data.history import ChromeHistoryReader
|
||||
|
||||
@@ -84,7 +86,7 @@ class BrowserRAG(BaseRAGExample):
|
||||
|
||||
return profiles
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
async def load_data(self, args) -> list[dict[str, Any]]:
|
||||
"""Load browser history and convert to text chunks."""
|
||||
# Determine Chrome profiles
|
||||
if args.chrome_profile and not args.auto_find_profiles:
|
||||
|
||||
0
apps/chatgpt_data/__init__.py
Normal file
413
apps/chatgpt_data/chatgpt_reader.py
Normal file
@@ -0,0 +1,413 @@
|
||||
"""
|
||||
ChatGPT export data reader.
|
||||
|
||||
Reads and processes ChatGPT export data from chat.html files.
|
||||
"""
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from zipfile import ZipFile
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
from llama_index.core import Document
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
|
||||
|
||||
class ChatGPTReader(BaseReader):
|
||||
"""
|
||||
ChatGPT export data reader.
|
||||
|
||||
Reads ChatGPT conversation data from exported chat.html files or zip archives.
|
||||
Processes conversations into structured documents with metadata.
|
||||
"""
|
||||
|
||||
def __init__(self, concatenate_conversations: bool = True) -> None:
|
||||
"""
|
||||
Initialize.
|
||||
|
||||
Args:
|
||||
concatenate_conversations: Whether to concatenate messages within conversations for better context
|
||||
"""
|
||||
try:
|
||||
from bs4 import BeautifulSoup # noqa
|
||||
except ImportError:
|
||||
raise ImportError("`beautifulsoup4` package not found: `pip install beautifulsoup4`")
|
||||
|
||||
self.concatenate_conversations = concatenate_conversations
|
||||
|
||||
def _extract_html_from_zip(self, zip_path: Path) -> str | None:
|
||||
"""
|
||||
Extract chat.html from ChatGPT export zip file.
|
||||
|
||||
Args:
|
||||
zip_path: Path to the ChatGPT export zip file
|
||||
|
||||
Returns:
|
||||
HTML content as string, or None if not found
|
||||
"""
|
||||
try:
|
||||
with ZipFile(zip_path, "r") as zip_file:
|
||||
# Look for chat.html or conversations.html
|
||||
html_files = [
|
||||
f
|
||||
for f in zip_file.namelist()
|
||||
if f.endswith(".html") and ("chat" in f.lower() or "conversation" in f.lower())
|
||||
]
|
||||
|
||||
if not html_files:
|
||||
print(f"No HTML chat file found in {zip_path}")
|
||||
return None
|
||||
|
||||
# Use the first HTML file found
|
||||
html_file = html_files[0]
|
||||
print(f"Found HTML file: {html_file}")
|
||||
|
||||
with zip_file.open(html_file) as f:
|
||||
return f.read().decode("utf-8", errors="ignore")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error extracting HTML from zip {zip_path}: {e}")
|
||||
return None
|
||||
|
||||
def _parse_chatgpt_html(self, html_content: str) -> list[dict]:
|
||||
"""
|
||||
Parse ChatGPT HTML export to extract conversations.
|
||||
|
||||
Args:
|
||||
html_content: HTML content from ChatGPT export
|
||||
|
||||
Returns:
|
||||
List of conversation dictionaries
|
||||
"""
|
||||
soup = BeautifulSoup(html_content, "html.parser")
|
||||
conversations = []
|
||||
|
||||
# Try different possible structures for ChatGPT exports
|
||||
# Structure 1: Look for conversation containers
|
||||
conversation_containers = soup.find_all(
|
||||
["div", "section"], class_=re.compile(r"conversation|chat", re.I)
|
||||
)
|
||||
|
||||
if not conversation_containers:
|
||||
# Structure 2: Look for message containers directly
|
||||
conversation_containers = [soup] # Use the entire document as one conversation
|
||||
|
||||
for container in conversation_containers:
|
||||
conversation = self._extract_conversation_from_container(container)
|
||||
if conversation and conversation.get("messages"):
|
||||
conversations.append(conversation)
|
||||
|
||||
# If no structured conversations found, try to extract all text as one conversation
|
||||
if not conversations:
|
||||
all_text = soup.get_text(separator="\n", strip=True)
|
||||
if all_text:
|
||||
conversations.append(
|
||||
{
|
||||
"title": "ChatGPT Conversation",
|
||||
"messages": [{"role": "mixed", "content": all_text, "timestamp": None}],
|
||||
"timestamp": None,
|
||||
}
|
||||
)
|
||||
|
||||
return conversations
|
||||
|
||||
def _extract_conversation_from_container(self, container) -> dict | None:
|
||||
"""
|
||||
Extract conversation data from a container element.
|
||||
|
||||
Args:
|
||||
container: BeautifulSoup element containing conversation
|
||||
|
||||
Returns:
|
||||
Dictionary with conversation data or None
|
||||
"""
|
||||
messages = []
|
||||
|
||||
# Look for message elements with various possible structures
|
||||
message_selectors = ['[class*="message"]', '[class*="chat"]', "[data-message]", "p", "div"]
|
||||
|
||||
for selector in message_selectors:
|
||||
message_elements = container.select(selector)
|
||||
if message_elements:
|
||||
break
|
||||
else:
|
||||
message_elements = []
|
||||
|
||||
# If no structured messages found, treat the entire container as one message
|
||||
if not message_elements:
|
||||
text_content = container.get_text(separator="\n", strip=True)
|
||||
if text_content:
|
||||
messages.append({"role": "mixed", "content": text_content, "timestamp": None})
|
||||
else:
|
||||
for element in message_elements:
|
||||
message = self._extract_message_from_element(element)
|
||||
if message:
|
||||
messages.append(message)
|
||||
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
# Try to extract conversation title
|
||||
title_element = container.find(["h1", "h2", "h3", "title"])
|
||||
title = title_element.get_text(strip=True) if title_element else "ChatGPT Conversation"
|
||||
|
||||
# Try to extract timestamp from various possible locations
|
||||
timestamp = self._extract_timestamp_from_container(container)
|
||||
|
||||
return {"title": title, "messages": messages, "timestamp": timestamp}
|
||||
|
||||
def _extract_message_from_element(self, element) -> dict | None:
|
||||
"""
|
||||
Extract message data from an element.
|
||||
|
||||
Args:
|
||||
element: BeautifulSoup element containing message
|
||||
|
||||
Returns:
|
||||
Dictionary with message data or None
|
||||
"""
|
||||
text_content = element.get_text(separator=" ", strip=True)
|
||||
|
||||
# Skip empty or very short messages
|
||||
if not text_content or len(text_content.strip()) < 3:
|
||||
return None
|
||||
|
||||
# Try to determine role (user/assistant) from class names or content
|
||||
role = "mixed" # Default role
|
||||
|
||||
class_names = " ".join(element.get("class", [])).lower()
|
||||
if "user" in class_names or "human" in class_names:
|
||||
role = "user"
|
||||
elif "assistant" in class_names or "ai" in class_names or "gpt" in class_names:
|
||||
role = "assistant"
|
||||
elif text_content.lower().startswith(("you:", "user:", "me:")):
|
||||
role = "user"
|
||||
text_content = re.sub(r"^(you|user|me):\s*", "", text_content, flags=re.IGNORECASE)
|
||||
elif text_content.lower().startswith(("chatgpt:", "assistant:", "ai:")):
|
||||
role = "assistant"
|
||||
text_content = re.sub(
|
||||
r"^(chatgpt|assistant|ai):\s*", "", text_content, flags=re.IGNORECASE
|
||||
)
|
||||
|
||||
# Try to extract timestamp
|
||||
timestamp = self._extract_timestamp_from_element(element)
|
||||
|
||||
return {"role": role, "content": text_content, "timestamp": timestamp}
|
||||
|
||||
def _extract_timestamp_from_element(self, element) -> str | None:
|
||||
"""Extract timestamp from element."""
|
||||
# Look for timestamp in various attributes and child elements
|
||||
timestamp_attrs = ["data-timestamp", "timestamp", "datetime"]
|
||||
for attr in timestamp_attrs:
|
||||
if element.get(attr):
|
||||
return element.get(attr)
|
||||
|
||||
# Look for time elements
|
||||
time_element = element.find("time")
|
||||
if time_element:
|
||||
return time_element.get("datetime") or time_element.get_text(strip=True)
|
||||
|
||||
# Look for date-like text patterns
|
||||
text = element.get_text()
|
||||
date_patterns = [r"\d{4}-\d{2}-\d{2}", r"\d{1,2}/\d{1,2}/\d{4}", r"\w+ \d{1,2}, \d{4}"]
|
||||
|
||||
for pattern in date_patterns:
|
||||
match = re.search(pattern, text)
|
||||
if match:
|
||||
return match.group()
|
||||
|
||||
return None
|
||||
|
||||
def _extract_timestamp_from_container(self, container) -> str | None:
|
||||
"""Extract timestamp from conversation container."""
|
||||
return self._extract_timestamp_from_element(container)
|
||||
|
||||
def _create_concatenated_content(self, conversation: dict) -> str:
|
||||
"""
|
||||
Create concatenated content from conversation messages.
|
||||
|
||||
Args:
|
||||
conversation: Dictionary containing conversation data
|
||||
|
||||
Returns:
|
||||
Formatted concatenated content
|
||||
"""
|
||||
title = conversation.get("title", "ChatGPT Conversation")
|
||||
messages = conversation.get("messages", [])
|
||||
timestamp = conversation.get("timestamp", "Unknown")
|
||||
|
||||
# Build message content
|
||||
message_parts = []
|
||||
for message in messages:
|
||||
role = message.get("role", "mixed")
|
||||
content = message.get("content", "")
|
||||
msg_timestamp = message.get("timestamp", "")
|
||||
|
||||
if role == "user":
|
||||
prefix = "[You]"
|
||||
elif role == "assistant":
|
||||
prefix = "[ChatGPT]"
|
||||
else:
|
||||
prefix = "[Message]"
|
||||
|
||||
# Add timestamp if available
|
||||
if msg_timestamp:
|
||||
prefix += f" ({msg_timestamp})"
|
||||
|
||||
message_parts.append(f"{prefix}: {content}")
|
||||
|
||||
concatenated_text = "\n\n".join(message_parts)
|
||||
|
||||
# Create final document content
|
||||
doc_content = f"""Conversation: {title}
|
||||
Date: {timestamp}
|
||||
Messages ({len(messages)} messages):
|
||||
|
||||
{concatenated_text}
|
||||
"""
|
||||
return doc_content
|
||||
|
||||
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Load ChatGPT export data.
|
||||
|
||||
Args:
|
||||
input_dir: Directory containing ChatGPT export files or path to specific file
|
||||
**load_kwargs:
|
||||
max_count (int): Maximum number of conversations to process
|
||||
chatgpt_export_path (str): Specific path to ChatGPT export file/directory
|
||||
include_metadata (bool): Whether to include metadata in documents
|
||||
"""
|
||||
docs: list[Document] = []
|
||||
max_count = load_kwargs.get("max_count", -1)
|
||||
chatgpt_export_path = load_kwargs.get("chatgpt_export_path", input_dir)
|
||||
include_metadata = load_kwargs.get("include_metadata", True)
|
||||
|
||||
if not chatgpt_export_path:
|
||||
print("No ChatGPT export path provided")
|
||||
return docs
|
||||
|
||||
export_path = Path(chatgpt_export_path)
|
||||
|
||||
if not export_path.exists():
|
||||
print(f"ChatGPT export path not found: {export_path}")
|
||||
return docs
|
||||
|
||||
html_content = None
|
||||
|
||||
# Handle different input types
|
||||
if export_path.is_file():
|
||||
if export_path.suffix.lower() == ".zip":
|
||||
# Extract HTML from zip file
|
||||
html_content = self._extract_html_from_zip(export_path)
|
||||
elif export_path.suffix.lower() == ".html":
|
||||
# Read HTML file directly
|
||||
try:
|
||||
with open(export_path, encoding="utf-8", errors="ignore") as f:
|
||||
html_content = f.read()
|
||||
except Exception as e:
|
||||
print(f"Error reading HTML file {export_path}: {e}")
|
||||
return docs
|
||||
else:
|
||||
print(f"Unsupported file type: {export_path.suffix}")
|
||||
return docs
|
||||
|
||||
elif export_path.is_dir():
|
||||
# Look for HTML files in directory
|
||||
html_files = list(export_path.glob("*.html"))
|
||||
zip_files = list(export_path.glob("*.zip"))
|
||||
|
||||
if html_files:
|
||||
# Use first HTML file found
|
||||
html_file = html_files[0]
|
||||
print(f"Found HTML file: {html_file}")
|
||||
try:
|
||||
with open(html_file, encoding="utf-8", errors="ignore") as f:
|
||||
html_content = f.read()
|
||||
except Exception as e:
|
||||
print(f"Error reading HTML file {html_file}: {e}")
|
||||
return docs
|
||||
|
||||
elif zip_files:
|
||||
# Use first zip file found
|
||||
zip_file = zip_files[0]
|
||||
print(f"Found zip file: {zip_file}")
|
||||
html_content = self._extract_html_from_zip(zip_file)
|
||||
|
||||
else:
|
||||
print(f"No HTML or zip files found in {export_path}")
|
||||
return docs
|
||||
|
||||
if not html_content:
|
||||
print("No HTML content found to process")
|
||||
return docs
|
||||
|
||||
# Parse conversations from HTML
|
||||
print("Parsing ChatGPT conversations from HTML...")
|
||||
conversations = self._parse_chatgpt_html(html_content)
|
||||
|
||||
if not conversations:
|
||||
print("No conversations found in HTML content")
|
||||
return docs
|
||||
|
||||
print(f"Found {len(conversations)} conversations")
|
||||
|
||||
# Process conversations into documents
|
||||
count = 0
|
||||
for conversation in conversations:
|
||||
if max_count > 0 and count >= max_count:
|
||||
break
|
||||
|
||||
if self.concatenate_conversations:
|
||||
# Create one document per conversation with concatenated messages
|
||||
doc_content = self._create_concatenated_content(conversation)
|
||||
|
||||
metadata = {}
|
||||
if include_metadata:
|
||||
metadata = {
|
||||
"title": conversation.get("title", "ChatGPT Conversation"),
|
||||
"timestamp": conversation.get("timestamp", "Unknown"),
|
||||
"message_count": len(conversation.get("messages", [])),
|
||||
"source": "ChatGPT Export",
|
||||
}
|
||||
|
||||
doc = Document(text=doc_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
else:
|
||||
# Create separate documents for each message
|
||||
for message in conversation.get("messages", []):
|
||||
if max_count > 0 and count >= max_count:
|
||||
break
|
||||
|
||||
role = message.get("role", "mixed")
|
||||
content = message.get("content", "")
|
||||
msg_timestamp = message.get("timestamp", "")
|
||||
|
||||
if not content.strip():
|
||||
continue
|
||||
|
||||
# Create document content with context
|
||||
doc_content = f"""Conversation: {conversation.get("title", "ChatGPT Conversation")}
|
||||
Role: {role}
|
||||
Timestamp: {msg_timestamp or conversation.get("timestamp", "Unknown")}
|
||||
Message: {content}
|
||||
"""
|
||||
|
||||
metadata = {}
|
||||
if include_metadata:
|
||||
metadata = {
|
||||
"conversation_title": conversation.get("title", "ChatGPT Conversation"),
|
||||
"role": role,
|
||||
"timestamp": msg_timestamp or conversation.get("timestamp", "Unknown"),
|
||||
"source": "ChatGPT Export",
|
||||
}
|
||||
|
||||
doc = Document(text=doc_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
print(f"Created {len(docs)} documents from ChatGPT export")
|
||||
return docs
|
||||
187
apps/chatgpt_rag.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""
|
||||
ChatGPT RAG example using the unified interface.
|
||||
Supports ChatGPT export data from chat.html files.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample
|
||||
from chunking import create_text_chunks
|
||||
|
||||
from .chatgpt_data.chatgpt_reader import ChatGPTReader
|
||||
|
||||
|
||||
class ChatGPTRAG(BaseRAGExample):
|
||||
"""RAG example for ChatGPT conversation data."""
|
||||
|
||||
def __init__(self):
|
||||
# Set default values BEFORE calling super().__init__
|
||||
self.max_items_default = -1 # Process all conversations by default
|
||||
self.embedding_model_default = (
|
||||
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
name="ChatGPT",
|
||||
description="Process and query ChatGPT conversation exports with LEANN",
|
||||
default_index_name="chatgpt_conversations_index",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add ChatGPT-specific arguments."""
|
||||
chatgpt_group = parser.add_argument_group("ChatGPT Parameters")
|
||||
chatgpt_group.add_argument(
|
||||
"--export-path",
|
||||
type=str,
|
||||
default="./chatgpt_export",
|
||||
help="Path to ChatGPT export file (.zip or .html) or directory containing exports (default: ./chatgpt_export)",
|
||||
)
|
||||
chatgpt_group.add_argument(
|
||||
"--concatenate-conversations",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Concatenate messages within conversations for better context (default: True)",
|
||||
)
|
||||
chatgpt_group.add_argument(
|
||||
"--separate-messages",
|
||||
action="store_true",
|
||||
help="Process each message as a separate document (overrides --concatenate-conversations)",
|
||||
)
|
||||
chatgpt_group.add_argument(
|
||||
"--chunk-size", type=int, default=512, help="Text chunk size (default: 512)"
|
||||
)
|
||||
chatgpt_group.add_argument(
|
||||
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
||||
)
|
||||
|
||||
def _find_chatgpt_exports(self, export_path: Path) -> list[Path]:
|
||||
"""
|
||||
Find ChatGPT export files in the given path.
|
||||
|
||||
Args:
|
||||
export_path: Path to search for exports
|
||||
|
||||
Returns:
|
||||
List of paths to ChatGPT export files
|
||||
"""
|
||||
export_files = []
|
||||
|
||||
if export_path.is_file():
|
||||
if export_path.suffix.lower() in [".zip", ".html"]:
|
||||
export_files.append(export_path)
|
||||
elif export_path.is_dir():
|
||||
# Look for zip and html files
|
||||
export_files.extend(export_path.glob("*.zip"))
|
||||
export_files.extend(export_path.glob("*.html"))
|
||||
|
||||
return export_files
|
||||
|
||||
async def load_data(self, args) -> list[dict[str, Any]]:
|
||||
"""Load ChatGPT export data and convert to text chunks."""
|
||||
export_path = Path(args.export_path)
|
||||
|
||||
if not export_path.exists():
|
||||
print(f"ChatGPT export path not found: {export_path}")
|
||||
print(
|
||||
"Please ensure you have exported your ChatGPT data and placed it in the correct location."
|
||||
)
|
||||
print("\nTo export your ChatGPT data:")
|
||||
print("1. Sign in to ChatGPT")
|
||||
print("2. Click on your profile icon → Settings → Data Controls")
|
||||
print("3. Click 'Export' under Export Data")
|
||||
print("4. Download the zip file from the email link")
|
||||
print("5. Extract or place the file/directory at the specified path")
|
||||
return []
|
||||
|
||||
# Find export files
|
||||
export_files = self._find_chatgpt_exports(export_path)
|
||||
|
||||
if not export_files:
|
||||
print(f"No ChatGPT export files (.zip or .html) found in: {export_path}")
|
||||
return []
|
||||
|
||||
print(f"Found {len(export_files)} ChatGPT export files")
|
||||
|
||||
# Create reader with appropriate settings
|
||||
concatenate = args.concatenate_conversations and not args.separate_messages
|
||||
reader = ChatGPTReader(concatenate_conversations=concatenate)
|
||||
|
||||
# Process each export file
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
for i, export_file in enumerate(export_files):
|
||||
print(f"\nProcessing export file {i + 1}/{len(export_files)}: {export_file.name}")
|
||||
|
||||
try:
|
||||
# Apply max_items limit per file
|
||||
max_per_file = -1
|
||||
if args.max_items > 0:
|
||||
remaining = args.max_items - total_processed
|
||||
if remaining <= 0:
|
||||
break
|
||||
max_per_file = remaining
|
||||
|
||||
# Load conversations
|
||||
documents = reader.load_data(
|
||||
chatgpt_export_path=str(export_file),
|
||||
max_count=max_per_file,
|
||||
include_metadata=True,
|
||||
)
|
||||
|
||||
if documents:
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
print(f"Processed {len(documents)} conversations from this file")
|
||||
else:
|
||||
print(f"No conversations loaded from {export_file}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {export_file}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No conversations found to process!")
|
||||
print("\nTroubleshooting:")
|
||||
print("- Ensure the export file is a valid ChatGPT export")
|
||||
print("- Check that the HTML file contains conversation data")
|
||||
print("- Try extracting the zip file and pointing to the HTML file directly")
|
||||
return []
|
||||
|
||||
print(f"\nTotal conversations processed: {len(all_documents)}")
|
||||
print("Now starting to split into text chunks... this may take some time")
|
||||
|
||||
# Convert to text chunks
|
||||
all_texts = create_text_chunks(
|
||||
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||
)
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} conversations")
|
||||
return all_texts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# Example queries for ChatGPT RAG
|
||||
print("\n🤖 ChatGPT RAG Example")
|
||||
print("=" * 50)
|
||||
print("\nExample queries you can try:")
|
||||
print("- 'What did I ask about Python programming?'")
|
||||
print("- 'Show me conversations about machine learning'")
|
||||
print("- 'Find discussions about travel planning'")
|
||||
print("- 'What advice did ChatGPT give me about career development?'")
|
||||
print("- 'Search for conversations about cooking recipes'")
|
||||
print("\nTo get started:")
|
||||
print("1. Export your ChatGPT data from Settings → Data Controls → Export")
|
||||
print("2. Place the downloaded zip file or extracted HTML in ./chatgpt_export/")
|
||||
print("3. Run this script to build your personal ChatGPT knowledge base!")
|
||||
print("\nOr run without --query for interactive mode\n")
|
||||
|
||||
rag = ChatGPTRAG()
|
||||
asyncio.run(rag.run())
|
||||
47
apps/chunking/__init__.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""Unified chunking utilities facade.
|
||||
|
||||
This module re-exports the packaged utilities from `leann.chunking_utils` so
|
||||
that both repo apps (importing `chunking`) and installed wheels share one
|
||||
single implementation. When running from the repo without installation, it
|
||||
adds the `packages/leann-core/src` directory to `sys.path` as a fallback.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
from leann.chunking_utils import (
|
||||
CODE_EXTENSIONS,
|
||||
_traditional_chunks_as_dicts,
|
||||
create_ast_chunks,
|
||||
create_text_chunks,
|
||||
create_traditional_chunks,
|
||||
detect_code_files,
|
||||
get_language_from_extension,
|
||||
)
|
||||
except Exception: # pragma: no cover - best-effort fallback for dev environment
|
||||
repo_root = Path(__file__).resolve().parents[2]
|
||||
leann_src = repo_root / "packages" / "leann-core" / "src"
|
||||
if leann_src.exists():
|
||||
sys.path.insert(0, str(leann_src))
|
||||
from leann.chunking_utils import (
|
||||
CODE_EXTENSIONS,
|
||||
_traditional_chunks_as_dicts,
|
||||
create_ast_chunks,
|
||||
create_text_chunks,
|
||||
create_traditional_chunks,
|
||||
detect_code_files,
|
||||
get_language_from_extension,
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
__all__ = [
|
||||
"CODE_EXTENSIONS",
|
||||
"_traditional_chunks_as_dicts",
|
||||
"create_ast_chunks",
|
||||
"create_text_chunks",
|
||||
"create_traditional_chunks",
|
||||
"detect_code_files",
|
||||
"get_language_from_extension",
|
||||
]
|
||||
0
apps/claude_data/__init__.py
Normal file
420
apps/claude_data/claude_reader.py
Normal file
@@ -0,0 +1,420 @@
|
||||
"""
|
||||
Claude export data reader.
|
||||
|
||||
Reads and processes Claude conversation data from exported JSON files.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from zipfile import ZipFile
|
||||
|
||||
from llama_index.core import Document
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
|
||||
|
||||
class ClaudeReader(BaseReader):
|
||||
"""
|
||||
Claude export data reader.
|
||||
|
||||
Reads Claude conversation data from exported JSON files or zip archives.
|
||||
Processes conversations into structured documents with metadata.
|
||||
"""
|
||||
|
||||
def __init__(self, concatenate_conversations: bool = True) -> None:
|
||||
"""
|
||||
Initialize.
|
||||
|
||||
Args:
|
||||
concatenate_conversations: Whether to concatenate messages within conversations for better context
|
||||
"""
|
||||
self.concatenate_conversations = concatenate_conversations
|
||||
|
||||
def _extract_json_from_zip(self, zip_path: Path) -> list[str]:
|
||||
"""
|
||||
Extract JSON files from Claude export zip file.
|
||||
|
||||
Args:
|
||||
zip_path: Path to the Claude export zip file
|
||||
|
||||
Returns:
|
||||
List of JSON content strings, or empty list if not found
|
||||
"""
|
||||
json_contents = []
|
||||
try:
|
||||
with ZipFile(zip_path, "r") as zip_file:
|
||||
# Look for JSON files
|
||||
json_files = [f for f in zip_file.namelist() if f.endswith(".json")]
|
||||
|
||||
if not json_files:
|
||||
print(f"No JSON files found in {zip_path}")
|
||||
return []
|
||||
|
||||
print(f"Found {len(json_files)} JSON files in archive")
|
||||
|
||||
for json_file in json_files:
|
||||
with zip_file.open(json_file) as f:
|
||||
content = f.read().decode("utf-8", errors="ignore")
|
||||
json_contents.append(content)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error extracting JSON from zip {zip_path}: {e}")
|
||||
|
||||
return json_contents
|
||||
|
||||
def _parse_claude_json(self, json_content: str) -> list[dict]:
|
||||
"""
|
||||
Parse Claude JSON export to extract conversations.
|
||||
|
||||
Args:
|
||||
json_content: JSON content from Claude export
|
||||
|
||||
Returns:
|
||||
List of conversation dictionaries
|
||||
"""
|
||||
try:
|
||||
data = json.loads(json_content)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Error parsing JSON: {e}")
|
||||
return []
|
||||
|
||||
conversations = []
|
||||
|
||||
# Handle different possible JSON structures
|
||||
if isinstance(data, list):
|
||||
# If data is a list of conversations
|
||||
for item in data:
|
||||
conversation = self._extract_conversation_from_json(item)
|
||||
if conversation:
|
||||
conversations.append(conversation)
|
||||
elif isinstance(data, dict):
|
||||
# Check for common structures
|
||||
if "conversations" in data:
|
||||
# Structure: {"conversations": [...]}
|
||||
for item in data["conversations"]:
|
||||
conversation = self._extract_conversation_from_json(item)
|
||||
if conversation:
|
||||
conversations.append(conversation)
|
||||
elif "messages" in data:
|
||||
# Single conversation with messages
|
||||
conversation = self._extract_conversation_from_json(data)
|
||||
if conversation:
|
||||
conversations.append(conversation)
|
||||
else:
|
||||
# Try to treat the whole object as a conversation
|
||||
conversation = self._extract_conversation_from_json(data)
|
||||
if conversation:
|
||||
conversations.append(conversation)
|
||||
|
||||
return conversations
|
||||
|
||||
def _extract_conversation_from_json(self, conv_data: dict) -> dict | None:
|
||||
"""
|
||||
Extract conversation data from a JSON object.
|
||||
|
||||
Args:
|
||||
conv_data: Dictionary containing conversation data
|
||||
|
||||
Returns:
|
||||
Dictionary with conversation data or None
|
||||
"""
|
||||
if not isinstance(conv_data, dict):
|
||||
return None
|
||||
|
||||
messages = []
|
||||
|
||||
# Look for messages in various possible structures
|
||||
message_sources = []
|
||||
if "messages" in conv_data:
|
||||
message_sources = conv_data["messages"]
|
||||
elif "chat" in conv_data:
|
||||
message_sources = conv_data["chat"]
|
||||
elif "conversation" in conv_data:
|
||||
message_sources = conv_data["conversation"]
|
||||
else:
|
||||
# If no clear message structure, try to extract from the object itself
|
||||
if "content" in conv_data and "role" in conv_data:
|
||||
message_sources = [conv_data]
|
||||
|
||||
for msg_data in message_sources:
|
||||
message = self._extract_message_from_json(msg_data)
|
||||
if message:
|
||||
messages.append(message)
|
||||
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
# Extract conversation metadata
|
||||
title = self._extract_title_from_conversation(conv_data, messages)
|
||||
timestamp = self._extract_timestamp_from_conversation(conv_data)
|
||||
|
||||
return {"title": title, "messages": messages, "timestamp": timestamp}
|
||||
|
||||
def _extract_message_from_json(self, msg_data: dict) -> dict | None:
|
||||
"""
|
||||
Extract message data from a JSON message object.
|
||||
|
||||
Args:
|
||||
msg_data: Dictionary containing message data
|
||||
|
||||
Returns:
|
||||
Dictionary with message data or None
|
||||
"""
|
||||
if not isinstance(msg_data, dict):
|
||||
return None
|
||||
|
||||
# Extract content from various possible fields
|
||||
content = ""
|
||||
content_fields = ["content", "text", "message", "body"]
|
||||
for field in content_fields:
|
||||
if msg_data.get(field):
|
||||
content = str(msg_data[field])
|
||||
break
|
||||
|
||||
if not content or len(content.strip()) < 3:
|
||||
return None
|
||||
|
||||
# Extract role (user/assistant/human/ai/claude)
|
||||
role = "mixed" # Default role
|
||||
role_fields = ["role", "sender", "from", "author", "type"]
|
||||
for field in role_fields:
|
||||
if msg_data.get(field):
|
||||
role_value = str(msg_data[field]).lower()
|
||||
if role_value in ["user", "human", "person"]:
|
||||
role = "user"
|
||||
elif role_value in ["assistant", "ai", "claude", "bot"]:
|
||||
role = "assistant"
|
||||
break
|
||||
|
||||
# Extract timestamp
|
||||
timestamp = self._extract_timestamp_from_message(msg_data)
|
||||
|
||||
return {"role": role, "content": content, "timestamp": timestamp}
|
||||
|
||||
def _extract_timestamp_from_message(self, msg_data: dict) -> str | None:
|
||||
"""Extract timestamp from message data."""
|
||||
timestamp_fields = ["timestamp", "created_at", "date", "time"]
|
||||
for field in timestamp_fields:
|
||||
if msg_data.get(field):
|
||||
return str(msg_data[field])
|
||||
return None
|
||||
|
||||
def _extract_timestamp_from_conversation(self, conv_data: dict) -> str | None:
|
||||
"""Extract timestamp from conversation data."""
|
||||
timestamp_fields = ["timestamp", "created_at", "date", "updated_at", "last_updated"]
|
||||
for field in timestamp_fields:
|
||||
if conv_data.get(field):
|
||||
return str(conv_data[field])
|
||||
return None
|
||||
|
||||
def _extract_title_from_conversation(self, conv_data: dict, messages: list) -> str:
|
||||
"""Extract or generate title for conversation."""
|
||||
# Try to find explicit title
|
||||
title_fields = ["title", "name", "subject", "topic"]
|
||||
for field in title_fields:
|
||||
if conv_data.get(field):
|
||||
return str(conv_data[field])
|
||||
|
||||
# Generate title from first user message
|
||||
for message in messages:
|
||||
if message.get("role") == "user":
|
||||
content = message.get("content", "")
|
||||
if content:
|
||||
# Use first 50 characters as title
|
||||
title = content[:50].strip()
|
||||
if len(content) > 50:
|
||||
title += "..."
|
||||
return title
|
||||
|
||||
return "Claude Conversation"
|
||||
|
||||
def _create_concatenated_content(self, conversation: dict) -> str:
|
||||
"""
|
||||
Create concatenated content from conversation messages.
|
||||
|
||||
Args:
|
||||
conversation: Dictionary containing conversation data
|
||||
|
||||
Returns:
|
||||
Formatted concatenated content
|
||||
"""
|
||||
title = conversation.get("title", "Claude Conversation")
|
||||
messages = conversation.get("messages", [])
|
||||
timestamp = conversation.get("timestamp", "Unknown")
|
||||
|
||||
# Build message content
|
||||
message_parts = []
|
||||
for message in messages:
|
||||
role = message.get("role", "mixed")
|
||||
content = message.get("content", "")
|
||||
msg_timestamp = message.get("timestamp", "")
|
||||
|
||||
if role == "user":
|
||||
prefix = "[You]"
|
||||
elif role == "assistant":
|
||||
prefix = "[Claude]"
|
||||
else:
|
||||
prefix = "[Message]"
|
||||
|
||||
# Add timestamp if available
|
||||
if msg_timestamp:
|
||||
prefix += f" ({msg_timestamp})"
|
||||
|
||||
message_parts.append(f"{prefix}: {content}")
|
||||
|
||||
concatenated_text = "\n\n".join(message_parts)
|
||||
|
||||
# Create final document content
|
||||
doc_content = f"""Conversation: {title}
|
||||
Date: {timestamp}
|
||||
Messages ({len(messages)} messages):
|
||||
|
||||
{concatenated_text}
|
||||
"""
|
||||
return doc_content
|
||||
|
||||
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Load Claude export data.
|
||||
|
||||
Args:
|
||||
input_dir: Directory containing Claude export files or path to specific file
|
||||
**load_kwargs:
|
||||
max_count (int): Maximum number of conversations to process
|
||||
claude_export_path (str): Specific path to Claude export file/directory
|
||||
include_metadata (bool): Whether to include metadata in documents
|
||||
"""
|
||||
docs: list[Document] = []
|
||||
max_count = load_kwargs.get("max_count", -1)
|
||||
claude_export_path = load_kwargs.get("claude_export_path", input_dir)
|
||||
include_metadata = load_kwargs.get("include_metadata", True)
|
||||
|
||||
if not claude_export_path:
|
||||
print("No Claude export path provided")
|
||||
return docs
|
||||
|
||||
export_path = Path(claude_export_path)
|
||||
|
||||
if not export_path.exists():
|
||||
print(f"Claude export path not found: {export_path}")
|
||||
return docs
|
||||
|
||||
json_contents = []
|
||||
|
||||
# Handle different input types
|
||||
if export_path.is_file():
|
||||
if export_path.suffix.lower() == ".zip":
|
||||
# Extract JSON from zip file
|
||||
json_contents = self._extract_json_from_zip(export_path)
|
||||
elif export_path.suffix.lower() == ".json":
|
||||
# Read JSON file directly
|
||||
try:
|
||||
with open(export_path, encoding="utf-8", errors="ignore") as f:
|
||||
json_contents.append(f.read())
|
||||
except Exception as e:
|
||||
print(f"Error reading JSON file {export_path}: {e}")
|
||||
return docs
|
||||
else:
|
||||
print(f"Unsupported file type: {export_path.suffix}")
|
||||
return docs
|
||||
|
||||
elif export_path.is_dir():
|
||||
# Look for JSON files in directory
|
||||
json_files = list(export_path.glob("*.json"))
|
||||
zip_files = list(export_path.glob("*.zip"))
|
||||
|
||||
if json_files:
|
||||
print(f"Found {len(json_files)} JSON files in directory")
|
||||
for json_file in json_files:
|
||||
try:
|
||||
with open(json_file, encoding="utf-8", errors="ignore") as f:
|
||||
json_contents.append(f.read())
|
||||
except Exception as e:
|
||||
print(f"Error reading JSON file {json_file}: {e}")
|
||||
continue
|
||||
|
||||
if zip_files:
|
||||
print(f"Found {len(zip_files)} ZIP files in directory")
|
||||
for zip_file in zip_files:
|
||||
zip_contents = self._extract_json_from_zip(zip_file)
|
||||
json_contents.extend(zip_contents)
|
||||
|
||||
if not json_files and not zip_files:
|
||||
print(f"No JSON or ZIP files found in {export_path}")
|
||||
return docs
|
||||
|
||||
if not json_contents:
|
||||
print("No JSON content found to process")
|
||||
return docs
|
||||
|
||||
# Parse conversations from JSON content
|
||||
print("Parsing Claude conversations from JSON...")
|
||||
all_conversations = []
|
||||
for json_content in json_contents:
|
||||
conversations = self._parse_claude_json(json_content)
|
||||
all_conversations.extend(conversations)
|
||||
|
||||
if not all_conversations:
|
||||
print("No conversations found in JSON content")
|
||||
return docs
|
||||
|
||||
print(f"Found {len(all_conversations)} conversations")
|
||||
|
||||
# Process conversations into documents
|
||||
count = 0
|
||||
for conversation in all_conversations:
|
||||
if max_count > 0 and count >= max_count:
|
||||
break
|
||||
|
||||
if self.concatenate_conversations:
|
||||
# Create one document per conversation with concatenated messages
|
||||
doc_content = self._create_concatenated_content(conversation)
|
||||
|
||||
metadata = {}
|
||||
if include_metadata:
|
||||
metadata = {
|
||||
"title": conversation.get("title", "Claude Conversation"),
|
||||
"timestamp": conversation.get("timestamp", "Unknown"),
|
||||
"message_count": len(conversation.get("messages", [])),
|
||||
"source": "Claude Export",
|
||||
}
|
||||
|
||||
doc = Document(text=doc_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
else:
|
||||
# Create separate documents for each message
|
||||
for message in conversation.get("messages", []):
|
||||
if max_count > 0 and count >= max_count:
|
||||
break
|
||||
|
||||
role = message.get("role", "mixed")
|
||||
content = message.get("content", "")
|
||||
msg_timestamp = message.get("timestamp", "")
|
||||
|
||||
if not content.strip():
|
||||
continue
|
||||
|
||||
# Create document content with context
|
||||
doc_content = f"""Conversation: {conversation.get("title", "Claude Conversation")}
|
||||
Role: {role}
|
||||
Timestamp: {msg_timestamp or conversation.get("timestamp", "Unknown")}
|
||||
Message: {content}
|
||||
"""
|
||||
|
||||
metadata = {}
|
||||
if include_metadata:
|
||||
metadata = {
|
||||
"conversation_title": conversation.get("title", "Claude Conversation"),
|
||||
"role": role,
|
||||
"timestamp": msg_timestamp or conversation.get("timestamp", "Unknown"),
|
||||
"source": "Claude Export",
|
||||
}
|
||||
|
||||
doc = Document(text=doc_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
print(f"Created {len(docs)} documents from Claude export")
|
||||
return docs
|
||||
190
apps/claude_rag.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""
|
||||
Claude RAG example using the unified interface.
|
||||
Supports Claude export data from JSON files.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample
|
||||
from chunking import create_text_chunks
|
||||
|
||||
from .claude_data.claude_reader import ClaudeReader
|
||||
|
||||
|
||||
class ClaudeRAG(BaseRAGExample):
|
||||
"""RAG example for Claude conversation data."""
|
||||
|
||||
def __init__(self):
|
||||
# Set default values BEFORE calling super().__init__
|
||||
self.max_items_default = -1 # Process all conversations by default
|
||||
self.embedding_model_default = (
|
||||
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
name="Claude",
|
||||
description="Process and query Claude conversation exports with LEANN",
|
||||
default_index_name="claude_conversations_index",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add Claude-specific arguments."""
|
||||
claude_group = parser.add_argument_group("Claude Parameters")
|
||||
claude_group.add_argument(
|
||||
"--export-path",
|
||||
type=str,
|
||||
default="./claude_export",
|
||||
help="Path to Claude export file (.json or .zip) or directory containing exports (default: ./claude_export)",
|
||||
)
|
||||
claude_group.add_argument(
|
||||
"--concatenate-conversations",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Concatenate messages within conversations for better context (default: True)",
|
||||
)
|
||||
claude_group.add_argument(
|
||||
"--separate-messages",
|
||||
action="store_true",
|
||||
help="Process each message as a separate document (overrides --concatenate-conversations)",
|
||||
)
|
||||
claude_group.add_argument(
|
||||
"--chunk-size", type=int, default=512, help="Text chunk size (default: 512)"
|
||||
)
|
||||
claude_group.add_argument(
|
||||
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
||||
)
|
||||
|
||||
def _find_claude_exports(self, export_path: Path) -> list[Path]:
|
||||
"""
|
||||
Find Claude export files in the given path.
|
||||
|
||||
Args:
|
||||
export_path: Path to search for exports
|
||||
|
||||
Returns:
|
||||
List of paths to Claude export files
|
||||
"""
|
||||
export_files = []
|
||||
|
||||
if export_path.is_file():
|
||||
if export_path.suffix.lower() in [".zip", ".json"]:
|
||||
export_files.append(export_path)
|
||||
elif export_path.is_dir():
|
||||
# Look for zip and json files
|
||||
export_files.extend(export_path.glob("*.zip"))
|
||||
export_files.extend(export_path.glob("*.json"))
|
||||
|
||||
return export_files
|
||||
|
||||
async def load_data(self, args) -> list[dict[str, Any]]:
|
||||
"""Load Claude export data and convert to text chunks."""
|
||||
export_path = Path(args.export_path)
|
||||
|
||||
if not export_path.exists():
|
||||
print(f"Claude export path not found: {export_path}")
|
||||
print(
|
||||
"Please ensure you have exported your Claude data and placed it in the correct location."
|
||||
)
|
||||
print("\nTo export your Claude data:")
|
||||
print("1. Open Claude in your browser")
|
||||
print("2. Look for export/download options in settings or conversation menu")
|
||||
print("3. Download the conversation data (usually in JSON format)")
|
||||
print("4. Place the file/directory at the specified path")
|
||||
print(
|
||||
"\nNote: Claude export methods may vary. Check Claude's help documentation for current instructions."
|
||||
)
|
||||
return []
|
||||
|
||||
# Find export files
|
||||
export_files = self._find_claude_exports(export_path)
|
||||
|
||||
if not export_files:
|
||||
print(f"No Claude export files (.json or .zip) found in: {export_path}")
|
||||
return []
|
||||
|
||||
print(f"Found {len(export_files)} Claude export files")
|
||||
|
||||
# Create reader with appropriate settings
|
||||
concatenate = args.concatenate_conversations and not args.separate_messages
|
||||
reader = ClaudeReader(concatenate_conversations=concatenate)
|
||||
|
||||
# Process each export file
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
for i, export_file in enumerate(export_files):
|
||||
print(f"\nProcessing export file {i + 1}/{len(export_files)}: {export_file.name}")
|
||||
|
||||
try:
|
||||
# Apply max_items limit per file
|
||||
max_per_file = -1
|
||||
if args.max_items > 0:
|
||||
remaining = args.max_items - total_processed
|
||||
if remaining <= 0:
|
||||
break
|
||||
max_per_file = remaining
|
||||
|
||||
# Load conversations
|
||||
documents = reader.load_data(
|
||||
claude_export_path=str(export_file),
|
||||
max_count=max_per_file,
|
||||
include_metadata=True,
|
||||
)
|
||||
|
||||
if documents:
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
print(f"Processed {len(documents)} conversations from this file")
|
||||
else:
|
||||
print(f"No conversations loaded from {export_file}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {export_file}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No conversations found to process!")
|
||||
print("\nTroubleshooting:")
|
||||
print("- Ensure the export file is a valid Claude export")
|
||||
print("- Check that the JSON file contains conversation data")
|
||||
print("- Try using a different export format or method")
|
||||
print("- Check Claude's documentation for current export procedures")
|
||||
return []
|
||||
|
||||
print(f"\nTotal conversations processed: {len(all_documents)}")
|
||||
print("Now starting to split into text chunks... this may take some time")
|
||||
|
||||
# Convert to text chunks
|
||||
all_texts = create_text_chunks(
|
||||
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||
)
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} conversations")
|
||||
return all_texts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# Example queries for Claude RAG
|
||||
print("\n🤖 Claude RAG Example")
|
||||
print("=" * 50)
|
||||
print("\nExample queries you can try:")
|
||||
print("- 'What did I ask Claude about Python programming?'")
|
||||
print("- 'Show me conversations about machine learning'")
|
||||
print("- 'Find discussions about code optimization'")
|
||||
print("- 'What advice did Claude give me about software design?'")
|
||||
print("- 'Search for conversations about debugging techniques'")
|
||||
print("\nTo get started:")
|
||||
print("1. Export your Claude conversation data")
|
||||
print("2. Place the JSON/ZIP file in ./claude_export/")
|
||||
print("3. Run this script to build your personal Claude knowledge base!")
|
||||
print("\nOr run without --query for interactive mode\n")
|
||||
|
||||
rag = ClaudeRAG()
|
||||
asyncio.run(rag.run())
|
||||
207
apps/code_rag.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""
|
||||
Code RAG example using AST-aware chunking for optimal code understanding.
|
||||
Specialized for code repositories with automatic language detection and
|
||||
optimized chunking parameters.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample
|
||||
from chunking import CODE_EXTENSIONS, create_text_chunks
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
|
||||
|
||||
class CodeRAG(BaseRAGExample):
|
||||
"""Specialized RAG example for code repositories with AST-aware chunking."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="Code",
|
||||
description="Process and query code repositories with AST-aware chunking",
|
||||
default_index_name="code_index",
|
||||
)
|
||||
# Override defaults for code-specific usage
|
||||
self.embedding_model_default = "facebook/contriever" # Good for code
|
||||
self.max_items_default = -1 # Process all code files by default
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add code-specific arguments."""
|
||||
code_group = parser.add_argument_group("Code Repository Parameters")
|
||||
|
||||
code_group.add_argument(
|
||||
"--repo-dir",
|
||||
type=str,
|
||||
default=".",
|
||||
help="Code repository directory to index (default: current directory)",
|
||||
)
|
||||
code_group.add_argument(
|
||||
"--include-extensions",
|
||||
nargs="+",
|
||||
default=list(CODE_EXTENSIONS.keys()),
|
||||
help="File extensions to include (default: supported code extensions)",
|
||||
)
|
||||
code_group.add_argument(
|
||||
"--exclude-dirs",
|
||||
nargs="+",
|
||||
default=[
|
||||
".git",
|
||||
"__pycache__",
|
||||
"node_modules",
|
||||
"venv",
|
||||
".venv",
|
||||
"build",
|
||||
"dist",
|
||||
"target",
|
||||
],
|
||||
help="Directories to exclude from indexing",
|
||||
)
|
||||
code_group.add_argument(
|
||||
"--max-file-size",
|
||||
type=int,
|
||||
default=1000000, # 1MB
|
||||
help="Maximum file size in bytes to process (default: 1MB)",
|
||||
)
|
||||
code_group.add_argument(
|
||||
"--include-comments",
|
||||
action="store_true",
|
||||
help="Include comments in chunking (useful for documentation)",
|
||||
)
|
||||
code_group.add_argument(
|
||||
"--preserve-imports",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Try to preserve import statements in chunks (default: True)",
|
||||
)
|
||||
|
||||
async def load_data(self, args) -> list[dict[str, Any]]:
|
||||
"""Load code files and convert to AST-aware chunks."""
|
||||
print(f"🔍 Scanning code repository: {args.repo_dir}")
|
||||
print(f"📁 Including extensions: {args.include_extensions}")
|
||||
print(f"🚫 Excluding directories: {args.exclude_dirs}")
|
||||
|
||||
# Check if repository directory exists
|
||||
repo_path = Path(args.repo_dir)
|
||||
if not repo_path.exists():
|
||||
raise ValueError(f"Repository directory not found: {args.repo_dir}")
|
||||
|
||||
# Create exclusion filter
|
||||
def file_filter(file_path: str) -> bool:
|
||||
"""Filter out unwanted files and directories."""
|
||||
path = Path(file_path)
|
||||
|
||||
# Check file size
|
||||
try:
|
||||
if path.stat().st_size > args.max_file_size:
|
||||
print(f"⚠️ Skipping large file: {path.name} ({path.stat().st_size} bytes)")
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
# Check if in excluded directory
|
||||
for exclude_dir in args.exclude_dirs:
|
||||
if exclude_dir in path.parts:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
try:
|
||||
# Load documents with file filtering
|
||||
documents = SimpleDirectoryReader(
|
||||
args.repo_dir,
|
||||
file_extractor=None,
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=args.include_extensions,
|
||||
exclude_hidden=True,
|
||||
).load_data(show_progress=True)
|
||||
|
||||
# Apply custom filtering
|
||||
filtered_docs = []
|
||||
for doc in documents:
|
||||
file_path = doc.metadata.get("file_path", "")
|
||||
if file_filter(file_path):
|
||||
filtered_docs.append(doc)
|
||||
|
||||
documents = filtered_docs
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error loading code files: {e}")
|
||||
return []
|
||||
|
||||
if not documents:
|
||||
print(
|
||||
f"❌ No code files found in {args.repo_dir} with extensions {args.include_extensions}"
|
||||
)
|
||||
return []
|
||||
|
||||
print(f"✅ Loaded {len(documents)} code files")
|
||||
|
||||
# Show breakdown by language/extension
|
||||
ext_counts = {}
|
||||
for doc in documents:
|
||||
file_path = doc.metadata.get("file_path", "")
|
||||
if file_path:
|
||||
ext = Path(file_path).suffix.lower()
|
||||
ext_counts[ext] = ext_counts.get(ext, 0) + 1
|
||||
|
||||
print("📊 Files by extension:")
|
||||
for ext, count in sorted(ext_counts.items()):
|
||||
print(f" {ext}: {count} files")
|
||||
|
||||
# Use AST-aware chunking by default for code
|
||||
print(
|
||||
f"🧠 Using AST-aware chunking (chunk_size: {args.ast_chunk_size}, overlap: {args.ast_chunk_overlap})"
|
||||
)
|
||||
|
||||
all_texts = create_text_chunks(
|
||||
documents,
|
||||
chunk_size=256, # Fallback for non-code files
|
||||
chunk_overlap=64,
|
||||
use_ast_chunking=True, # Always use AST for code RAG
|
||||
ast_chunk_size=args.ast_chunk_size,
|
||||
ast_chunk_overlap=args.ast_chunk_overlap,
|
||||
code_file_extensions=args.include_extensions,
|
||||
ast_fallback_traditional=True,
|
||||
)
|
||||
|
||||
# Apply max_items limit if specified
|
||||
if args.max_items > 0 and len(all_texts) > args.max_items:
|
||||
print(f"⏳ Limiting to {args.max_items} chunks (from {len(all_texts)})")
|
||||
all_texts = all_texts[: args.max_items]
|
||||
|
||||
print(f"✅ Generated {len(all_texts)} code chunks")
|
||||
return all_texts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# Example queries for code RAG
|
||||
print("\n💻 Code RAG Example")
|
||||
print("=" * 50)
|
||||
print("\nExample queries you can try:")
|
||||
print("- 'How does the embedding computation work?'")
|
||||
print("- 'What are the main classes in this codebase?'")
|
||||
print("- 'Show me the search implementation'")
|
||||
print("- 'How is error handling implemented?'")
|
||||
print("- 'What design patterns are used?'")
|
||||
print("- 'Explain the chunking logic'")
|
||||
print("\n🚀 Features:")
|
||||
print("- ✅ AST-aware chunking preserves code structure")
|
||||
print("- ✅ Automatic language detection")
|
||||
print("- ✅ Smart filtering of large files and common excludes")
|
||||
print("- ✅ Optimized for code understanding")
|
||||
print("\nUsage examples:")
|
||||
print(" python -m apps.code_rag --repo-dir ./my_project")
|
||||
print(
|
||||
" python -m apps.code_rag --include-extensions .py .js --query 'How does authentication work?'"
|
||||
)
|
||||
print("\nOr run without --query for interactive mode\n")
|
||||
|
||||
rag = CodeRAG()
|
||||
asyncio.run(rag.run())
|
||||
364
apps/colqwen_rag.py
Normal file
@@ -0,0 +1,364 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
ColQwen RAG - Easy-to-use multimodal PDF retrieval with ColQwen2/ColPali
|
||||
|
||||
Usage:
|
||||
python -m apps.colqwen_rag build --pdfs ./my_pdfs/ --index my_index
|
||||
python -m apps.colqwen_rag search my_index "How does attention work?"
|
||||
python -m apps.colqwen_rag ask my_index --interactive
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional, cast
|
||||
|
||||
# Add LEANN packages to path
|
||||
_repo_root = Path(__file__).resolve().parents[1]
|
||||
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
|
||||
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
|
||||
if str(_leann_core_src) not in sys.path:
|
||||
sys.path.append(str(_leann_core_src))
|
||||
if str(_leann_hnsw_pkg) not in sys.path:
|
||||
sys.path.append(str(_leann_hnsw_pkg))
|
||||
|
||||
import torch # noqa: E402
|
||||
from colpali_engine import ColPali, ColPaliProcessor, ColQwen2, ColQwen2Processor # noqa: E402
|
||||
from colpali_engine.utils.torch_utils import ListDataset # noqa: E402
|
||||
from pdf2image import convert_from_path # noqa: E402
|
||||
from PIL import Image # noqa: E402
|
||||
from torch.utils.data import DataLoader # noqa: E402
|
||||
from tqdm import tqdm # noqa: E402
|
||||
|
||||
# Import the existing multi-vector implementation
|
||||
sys.path.append(str(_repo_root / "apps" / "multimodal" / "vision-based-pdf-multi-vector"))
|
||||
from leann_multi_vector import LeannMultiVector # noqa: E402
|
||||
|
||||
|
||||
class ColQwenRAG:
|
||||
"""Easy-to-use ColQwen RAG system for multimodal PDF retrieval."""
|
||||
|
||||
def __init__(self, model_type: str = "colpali"):
|
||||
"""
|
||||
Initialize ColQwen RAG system.
|
||||
|
||||
Args:
|
||||
model_type: "colqwen2" or "colpali"
|
||||
"""
|
||||
self.model_type = model_type
|
||||
self.device = self._get_device()
|
||||
# Use float32 on MPS to avoid memory issues, float16 on CUDA, bfloat16 on CPU
|
||||
if self.device.type == "mps":
|
||||
self.dtype = torch.float32
|
||||
elif self.device.type == "cuda":
|
||||
self.dtype = torch.float16
|
||||
else:
|
||||
self.dtype = torch.bfloat16
|
||||
|
||||
print(f"🚀 Initializing {model_type.upper()} on {self.device} with {self.dtype}")
|
||||
|
||||
# Load model and processor with MPS-optimized settings
|
||||
try:
|
||||
if model_type == "colqwen2":
|
||||
self.model_name = "vidore/colqwen2-v1.0"
|
||||
if self.device.type == "mps":
|
||||
# For MPS, load on CPU first then move to avoid memory allocation issues
|
||||
self.model = ColQwen2.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=self.dtype,
|
||||
device_map="cpu",
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
self.model = self.model.to(self.device)
|
||||
else:
|
||||
self.model = ColQwen2.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=self.dtype,
|
||||
device_map=self.device,
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
self.processor = ColQwen2Processor.from_pretrained(self.model_name)
|
||||
else: # colpali
|
||||
self.model_name = "vidore/colpali-v1.2"
|
||||
if self.device.type == "mps":
|
||||
# For MPS, load on CPU first then move to avoid memory allocation issues
|
||||
self.model = ColPali.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=self.dtype,
|
||||
device_map="cpu",
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
self.model = self.model.to(self.device)
|
||||
else:
|
||||
self.model = ColPali.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=self.dtype,
|
||||
device_map=self.device,
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
self.processor = ColPaliProcessor.from_pretrained(self.model_name)
|
||||
except Exception as e:
|
||||
if "memory" in str(e).lower() or "offload" in str(e).lower():
|
||||
print(f"⚠️ Memory constraint on {self.device}, using CPU with optimizations...")
|
||||
self.device = torch.device("cpu")
|
||||
self.dtype = torch.float32
|
||||
|
||||
if model_type == "colqwen2":
|
||||
self.model = ColQwen2.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=self.dtype,
|
||||
device_map="cpu",
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
else:
|
||||
self.model = ColPali.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=self.dtype,
|
||||
device_map="cpu",
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
else:
|
||||
raise
|
||||
|
||||
def _get_device(self):
|
||||
"""Auto-select best available device."""
|
||||
if torch.cuda.is_available():
|
||||
return torch.device("cuda")
|
||||
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
return torch.device("mps")
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
def build_index(self, pdf_paths: list[str], index_name: str, pages_dir: Optional[str] = None):
|
||||
"""
|
||||
Build multimodal index from PDF files.
|
||||
|
||||
Args:
|
||||
pdf_paths: List of PDF file paths
|
||||
index_name: Name for the index
|
||||
pages_dir: Directory to save page images (optional)
|
||||
"""
|
||||
print(f"Building index '{index_name}' from {len(pdf_paths)} PDFs...")
|
||||
|
||||
# Convert PDFs to images
|
||||
all_images = []
|
||||
all_metadata = []
|
||||
|
||||
if pages_dir:
|
||||
os.makedirs(pages_dir, exist_ok=True)
|
||||
|
||||
for pdf_path in tqdm(pdf_paths, desc="Converting PDFs"):
|
||||
try:
|
||||
images = convert_from_path(pdf_path, dpi=150)
|
||||
pdf_name = Path(pdf_path).stem
|
||||
|
||||
for i, image in enumerate(images):
|
||||
# Save image if pages_dir specified
|
||||
if pages_dir:
|
||||
image_path = Path(pages_dir) / f"{pdf_name}_page_{i + 1}.png"
|
||||
image.save(image_path)
|
||||
|
||||
all_images.append(image)
|
||||
all_metadata.append(
|
||||
{
|
||||
"pdf_path": pdf_path,
|
||||
"pdf_name": pdf_name,
|
||||
"page_number": i + 1,
|
||||
"image_path": str(image_path) if pages_dir else None,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error processing {pdf_path}: {e}")
|
||||
continue
|
||||
|
||||
print(f"📄 Converted {len(all_images)} pages from {len(pdf_paths)} PDFs")
|
||||
print(f"All metadata: {all_metadata}")
|
||||
|
||||
# Generate embeddings
|
||||
print("🧠 Generating embeddings...")
|
||||
embeddings = self._embed_images(all_images)
|
||||
|
||||
# Build LEANN index
|
||||
print("🔍 Building LEANN index...")
|
||||
leann_mv = LeannMultiVector(
|
||||
index_path=index_name,
|
||||
dim=embeddings.shape[-1],
|
||||
embedding_model_name=self.model_type,
|
||||
)
|
||||
|
||||
# Create collection and insert data
|
||||
leann_mv.create_collection()
|
||||
for i, (embedding, metadata) in enumerate(zip(embeddings, all_metadata)):
|
||||
data = {
|
||||
"doc_id": i,
|
||||
"filepath": metadata.get("image_path", ""),
|
||||
"colbert_vecs": embedding.numpy(), # Convert tensor to numpy
|
||||
}
|
||||
leann_mv.insert(data)
|
||||
|
||||
# Build the index
|
||||
leann_mv.create_index()
|
||||
print(f"✅ Index '{index_name}' built successfully!")
|
||||
|
||||
return leann_mv
|
||||
|
||||
def search(self, index_name: str, query: str, top_k: int = 5):
|
||||
"""
|
||||
Search the index with a text query.
|
||||
|
||||
Args:
|
||||
index_name: Name of the index to search
|
||||
query: Text query
|
||||
top_k: Number of results to return
|
||||
"""
|
||||
print(f"🔍 Searching '{index_name}' for: '{query}'")
|
||||
|
||||
# Load index
|
||||
leann_mv = LeannMultiVector(
|
||||
index_path=index_name,
|
||||
dim=128, # Will be updated when loading
|
||||
embedding_model_name=self.model_type,
|
||||
)
|
||||
|
||||
# Generate query embedding
|
||||
query_embedding = self._embed_query(query)
|
||||
|
||||
# Search (returns list of (score, doc_id) tuples)
|
||||
search_results = leann_mv.search(query_embedding.numpy(), topk=top_k)
|
||||
|
||||
# Display results
|
||||
print(f"\n📋 Top {len(search_results)} results:")
|
||||
for i, (score, doc_id) in enumerate(search_results, 1):
|
||||
# Get metadata for this doc_id (we need to load the metadata)
|
||||
print(f"{i}. Score: {score:.3f} | Doc ID: {doc_id}")
|
||||
|
||||
return search_results
|
||||
|
||||
def ask(self, index_name: str, interactive: bool = False):
|
||||
"""
|
||||
Interactive Q&A with the indexed documents.
|
||||
|
||||
Args:
|
||||
index_name: Name of the index to query
|
||||
interactive: Whether to run in interactive mode
|
||||
"""
|
||||
print(f"💬 ColQwen Chat with '{index_name}'")
|
||||
|
||||
if interactive:
|
||||
print("Type 'quit' to exit, 'help' for commands")
|
||||
while True:
|
||||
try:
|
||||
query = input("\n🤔 Your question: ").strip()
|
||||
if query.lower() in ["quit", "exit", "q"]:
|
||||
break
|
||||
elif query.lower() == "help":
|
||||
print("Commands: quit/exit/q (exit), help (this message)")
|
||||
continue
|
||||
elif not query:
|
||||
continue
|
||||
|
||||
self.search(index_name, query, top_k=3)
|
||||
|
||||
# TODO: Add answer generation with Qwen-VL
|
||||
print("\n💡 For detailed answers, we can integrate Qwen-VL here!")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n👋 Goodbye!")
|
||||
break
|
||||
else:
|
||||
query = input("🤔 Your question: ").strip()
|
||||
if query:
|
||||
self.search(index_name, query)
|
||||
|
||||
def _embed_images(self, images: list[Image.Image]) -> torch.Tensor:
|
||||
"""Generate embeddings for a list of images."""
|
||||
dataset = ListDataset(images)
|
||||
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=lambda x: x)
|
||||
|
||||
embeddings = []
|
||||
with torch.no_grad():
|
||||
for batch in tqdm(dataloader, desc="Embedding images"):
|
||||
batch_images = cast(list, batch)
|
||||
batch_inputs = self.processor.process_images(batch_images).to(self.device)
|
||||
batch_embeddings = self.model(**batch_inputs)
|
||||
embeddings.append(batch_embeddings.cpu())
|
||||
|
||||
return torch.cat(embeddings, dim=0)
|
||||
|
||||
def _embed_query(self, query: str) -> torch.Tensor:
|
||||
"""Generate embedding for a text query."""
|
||||
with torch.no_grad():
|
||||
query_inputs = self.processor.process_queries([query]).to(self.device)
|
||||
query_embedding = self.model(**query_inputs)
|
||||
return query_embedding.cpu()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="ColQwen RAG - Easy multimodal PDF retrieval")
|
||||
subparsers = parser.add_subparsers(dest="command", help="Available commands")
|
||||
|
||||
# Build command
|
||||
build_parser = subparsers.add_parser("build", help="Build index from PDFs")
|
||||
build_parser.add_argument("--pdfs", required=True, help="Directory containing PDF files")
|
||||
build_parser.add_argument("--index", required=True, help="Index name")
|
||||
build_parser.add_argument(
|
||||
"--model", choices=["colqwen2", "colpali"], default="colqwen2", help="Model to use"
|
||||
)
|
||||
build_parser.add_argument("--pages-dir", help="Directory to save page images")
|
||||
|
||||
# Search command
|
||||
search_parser = subparsers.add_parser("search", help="Search the index")
|
||||
search_parser.add_argument("index", help="Index name")
|
||||
search_parser.add_argument("query", help="Search query")
|
||||
search_parser.add_argument("--top-k", type=int, default=5, help="Number of results")
|
||||
search_parser.add_argument(
|
||||
"--model", choices=["colqwen2", "colpali"], default="colqwen2", help="Model to use"
|
||||
)
|
||||
|
||||
# Ask command
|
||||
ask_parser = subparsers.add_parser("ask", help="Interactive Q&A")
|
||||
ask_parser.add_argument("index", help="Index name")
|
||||
ask_parser.add_argument("--interactive", action="store_true", help="Interactive mode")
|
||||
ask_parser.add_argument(
|
||||
"--model", choices=["colqwen2", "colpali"], default="colqwen2", help="Model to use"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.command:
|
||||
parser.print_help()
|
||||
return
|
||||
|
||||
# Initialize ColQwen RAG
|
||||
if args.command == "build":
|
||||
colqwen = ColQwenRAG(args.model)
|
||||
|
||||
# Get PDF files
|
||||
pdf_dir = Path(args.pdfs)
|
||||
if pdf_dir.is_file() and pdf_dir.suffix.lower() == ".pdf":
|
||||
pdf_paths = [str(pdf_dir)]
|
||||
elif pdf_dir.is_dir():
|
||||
pdf_paths = [str(p) for p in pdf_dir.glob("*.pdf")]
|
||||
else:
|
||||
print(f"❌ Invalid PDF path: {args.pdfs}")
|
||||
return
|
||||
|
||||
if not pdf_paths:
|
||||
print(f"❌ No PDF files found in {args.pdfs}")
|
||||
return
|
||||
|
||||
colqwen.build_index(pdf_paths, args.index, args.pages_dir)
|
||||
|
||||
elif args.command == "search":
|
||||
colqwen = ColQwenRAG(args.model)
|
||||
colqwen.search(args.index, args.query, args.top_k)
|
||||
|
||||
elif args.command == "ask":
|
||||
colqwen = ColQwenRAG(args.model)
|
||||
colqwen.ask(args.index, args.interactive)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -5,11 +5,13 @@ Supports PDF, TXT, MD, and other document formats.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample, create_text_chunks
|
||||
from base_rag_example import BaseRAGExample
|
||||
from chunking import create_text_chunks
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
|
||||
|
||||
@@ -44,8 +46,13 @@ class DocumentRAG(BaseRAGExample):
|
||||
doc_group.add_argument(
|
||||
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
||||
)
|
||||
doc_group.add_argument(
|
||||
"--enable-code-chunking",
|
||||
action="store_true",
|
||||
help="Enable AST-aware chunking for code files in the data directory",
|
||||
)
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
async def load_data(self, args) -> list[dict[str, Any]]:
|
||||
"""Load documents and convert to text chunks."""
|
||||
print(f"Loading documents from: {args.data_dir}")
|
||||
if args.file_types:
|
||||
@@ -59,16 +66,12 @@ class DocumentRAG(BaseRAGExample):
|
||||
raise ValueError(f"Data directory not found: {args.data_dir}")
|
||||
|
||||
# Load documents
|
||||
reader_kwargs = {
|
||||
"recursive": True,
|
||||
"encoding": "utf-8",
|
||||
}
|
||||
if args.file_types:
|
||||
reader_kwargs["required_exts"] = args.file_types
|
||||
|
||||
documents = SimpleDirectoryReader(args.data_dir, **reader_kwargs).load_data(
|
||||
show_progress=True
|
||||
)
|
||||
documents = SimpleDirectoryReader(
|
||||
args.data_dir,
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=args.file_types if args.file_types else None,
|
||||
).load_data(show_progress=True)
|
||||
|
||||
if not documents:
|
||||
print(f"No documents found in {args.data_dir} with extensions {args.file_types}")
|
||||
@@ -76,9 +79,22 @@ class DocumentRAG(BaseRAGExample):
|
||||
|
||||
print(f"Loaded {len(documents)} documents")
|
||||
|
||||
# Convert to text chunks
|
||||
# Determine chunking strategy
|
||||
use_ast = args.enable_code_chunking or getattr(args, "use_ast_chunking", False)
|
||||
|
||||
if use_ast:
|
||||
print("Using AST-aware chunking for code files")
|
||||
|
||||
# Convert to text chunks with optional AST support
|
||||
all_texts = create_text_chunks(
|
||||
documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||
documents,
|
||||
chunk_size=args.chunk_size,
|
||||
chunk_overlap=args.chunk_overlap,
|
||||
use_ast_chunking=use_ast,
|
||||
ast_chunk_size=getattr(args, "ast_chunk_size", 512),
|
||||
ast_chunk_overlap=getattr(args, "ast_chunk_overlap", 64),
|
||||
code_file_extensions=getattr(args, "code_file_extensions", None),
|
||||
ast_fallback_traditional=getattr(args, "ast_fallback_traditional", True),
|
||||
)
|
||||
|
||||
# Apply max_items limit if specified
|
||||
@@ -102,6 +118,10 @@ if __name__ == "__main__":
|
||||
print(
|
||||
"- 'What is the problem of developing pan gu model Huawei meets? (盘古大模型开发中遇到什么问题?)'"
|
||||
)
|
||||
print("\n🚀 NEW: Code-aware chunking available!")
|
||||
print("- Use --enable-code-chunking to enable AST-aware chunking for code files")
|
||||
print("- Supports Python, Java, C#, TypeScript files")
|
||||
print("- Better semantic understanding of code structure")
|
||||
print("\nOr run without --query for interactive mode\n")
|
||||
|
||||
rag = DocumentRAG()
|
||||
|
||||
@@ -127,11 +127,12 @@ class EmlxMboxReader(MboxReader):
|
||||
|
||||
def load_data(
|
||||
self,
|
||||
directory: Path,
|
||||
file: Path, # Note: for EmlxMboxReader, this is actually a directory
|
||||
extra_info: dict | None = None,
|
||||
fs: AbstractFileSystem | None = None,
|
||||
) -> list[Document]:
|
||||
"""Parse .emlx files from directory into strings using MboxReader logic."""
|
||||
directory = file # Rename for clarity - this is a directory of .emlx files
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
|
||||
@@ -5,11 +5,13 @@ Supports Apple Mail on macOS.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample, create_text_chunks
|
||||
from base_rag_example import BaseRAGExample
|
||||
from chunking import create_text_chunks
|
||||
|
||||
from .email_data.LEANN_email_reader import EmlxReader
|
||||
|
||||
@@ -63,7 +65,7 @@ class EmailRAG(BaseRAGExample):
|
||||
|
||||
return messages_dirs
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
async def load_data(self, args) -> list[dict[str, Any]]:
|
||||
"""Load emails and convert to text chunks."""
|
||||
# Determine mail directories
|
||||
if args.mail_path:
|
||||
|
||||
@@ -74,7 +74,7 @@ class ChromeHistoryReader(BaseReader):
|
||||
if count >= max_count and max_count > 0:
|
||||
break
|
||||
|
||||
last_visit, url, title, visit_count, typed_count, hidden = row
|
||||
last_visit, url, title, visit_count, typed_count, _hidden = row
|
||||
|
||||
# Create document content with metadata embedded in text
|
||||
doc_content = f"""
|
||||
|
||||
@@ -86,7 +86,7 @@ class WeChatHistoryReader(BaseReader):
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
return result.returncode == 0 and result.stdout.strip()
|
||||
return result.returncode == 0 and bool(result.stdout.strip())
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@@ -314,7 +314,9 @@ class WeChatHistoryReader(BaseReader):
|
||||
|
||||
return concatenated_groups
|
||||
|
||||
def _create_concatenated_content(self, message_group: dict, contact_name: str) -> str:
|
||||
def _create_concatenated_content(
|
||||
self, message_group: dict, contact_name: str
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Create concatenated content from a group of messages.
|
||||
|
||||
|
||||
219
apps/image_rag.py
Normal file
@@ -0,0 +1,219 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
CLIP Image RAG Application
|
||||
|
||||
This application enables RAG (Retrieval-Augmented Generation) on images using CLIP embeddings.
|
||||
You can index a directory of images and search them using text queries.
|
||||
|
||||
Usage:
|
||||
python -m apps.image_rag --image-dir ./my_images/ --query "a sunset over mountains"
|
||||
python -m apps.image_rag --image-dir ./my_images/ --interactive
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import pickle
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from tqdm import tqdm
|
||||
|
||||
from apps.base_rag_example import BaseRAGExample
|
||||
|
||||
|
||||
class ImageRAG(BaseRAGExample):
|
||||
"""
|
||||
RAG application for images using CLIP embeddings.
|
||||
|
||||
This class provides a complete RAG pipeline for image data, including
|
||||
CLIP embedding generation, indexing, and text-based image search.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="Image RAG",
|
||||
description="RAG application for images using CLIP embeddings",
|
||||
default_index_name="image_index",
|
||||
)
|
||||
# Override default embedding model to use CLIP
|
||||
self.embedding_model_default = "clip-ViT-L-14"
|
||||
self.embedding_mode_default = "sentence-transformers"
|
||||
self._image_data: list[dict] = []
|
||||
|
||||
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
|
||||
"""Add image-specific arguments."""
|
||||
image_group = parser.add_argument_group("Image Parameters")
|
||||
image_group.add_argument(
|
||||
"--image-dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Directory containing images to index",
|
||||
)
|
||||
image_group.add_argument(
|
||||
"--image-extensions",
|
||||
type=str,
|
||||
nargs="+",
|
||||
default=[".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"],
|
||||
help="Image file extensions to process (default: .jpg .jpeg .png .gif .bmp .webp)",
|
||||
)
|
||||
image_group.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Batch size for CLIP embedding generation (default: 32)",
|
||||
)
|
||||
|
||||
async def load_data(self, args) -> list[dict[str, Any]]:
|
||||
"""Load images, generate CLIP embeddings, and return text descriptions."""
|
||||
self._image_data = self._load_images_and_embeddings(args)
|
||||
return [entry["text"] for entry in self._image_data]
|
||||
|
||||
def _load_images_and_embeddings(self, args) -> list[dict]:
|
||||
"""Helper to process images and produce embeddings/metadata."""
|
||||
image_dir = Path(args.image_dir)
|
||||
if not image_dir.exists():
|
||||
raise ValueError(f"Image directory does not exist: {image_dir}")
|
||||
|
||||
print(f"📸 Loading images from {image_dir}...")
|
||||
|
||||
# Find all image files
|
||||
image_files = []
|
||||
for ext in args.image_extensions:
|
||||
image_files.extend(image_dir.rglob(f"*{ext}"))
|
||||
image_files.extend(image_dir.rglob(f"*{ext.upper()}"))
|
||||
|
||||
if not image_files:
|
||||
raise ValueError(
|
||||
f"No images found in {image_dir} with extensions {args.image_extensions}"
|
||||
)
|
||||
|
||||
print(f"✅ Found {len(image_files)} images")
|
||||
|
||||
# Limit if max_items is set
|
||||
if args.max_items > 0:
|
||||
image_files = image_files[: args.max_items]
|
||||
print(f"📊 Processing {len(image_files)} images (limited by --max-items)")
|
||||
|
||||
# Load CLIP model
|
||||
print("🔍 Loading CLIP model...")
|
||||
model = SentenceTransformer(self.embedding_model_default)
|
||||
|
||||
# Process images and generate embeddings
|
||||
print("🖼️ Processing images and generating embeddings...")
|
||||
image_data = []
|
||||
batch_images = []
|
||||
batch_paths = []
|
||||
|
||||
for image_path in tqdm(image_files, desc="Processing images"):
|
||||
try:
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
batch_images.append(image)
|
||||
batch_paths.append(image_path)
|
||||
|
||||
# Process in batches
|
||||
if len(batch_images) >= args.batch_size:
|
||||
embeddings = model.encode(
|
||||
batch_images,
|
||||
convert_to_numpy=True,
|
||||
normalize_embeddings=True,
|
||||
batch_size=args.batch_size,
|
||||
show_progress_bar=False,
|
||||
)
|
||||
|
||||
for img_path, embedding in zip(batch_paths, embeddings):
|
||||
image_data.append(
|
||||
{
|
||||
"text": f"Image: {img_path.name}\nPath: {img_path}",
|
||||
"metadata": {
|
||||
"image_path": str(img_path),
|
||||
"image_name": img_path.name,
|
||||
"image_dir": str(image_dir),
|
||||
},
|
||||
"embedding": embedding.astype(np.float32),
|
||||
}
|
||||
)
|
||||
|
||||
batch_images = []
|
||||
batch_paths = []
|
||||
|
||||
except Exception as e:
|
||||
print(f"⚠️ Failed to process {image_path}: {e}")
|
||||
continue
|
||||
|
||||
# Process remaining images
|
||||
if batch_images:
|
||||
embeddings = model.encode(
|
||||
batch_images,
|
||||
convert_to_numpy=True,
|
||||
normalize_embeddings=True,
|
||||
batch_size=len(batch_images),
|
||||
show_progress_bar=False,
|
||||
)
|
||||
|
||||
for img_path, embedding in zip(batch_paths, embeddings):
|
||||
image_data.append(
|
||||
{
|
||||
"text": f"Image: {img_path.name}\nPath: {img_path}",
|
||||
"metadata": {
|
||||
"image_path": str(img_path),
|
||||
"image_name": img_path.name,
|
||||
"image_dir": str(image_dir),
|
||||
},
|
||||
"embedding": embedding.astype(np.float32),
|
||||
}
|
||||
)
|
||||
|
||||
print(f"✅ Processed {len(image_data)} images")
|
||||
return image_data
|
||||
|
||||
async def build_index(self, args, texts: list[dict[str, Any]]) -> str:
|
||||
"""Build index using pre-computed CLIP embeddings."""
|
||||
from leann.api import LeannBuilder
|
||||
|
||||
if not self._image_data or len(self._image_data) != len(texts):
|
||||
raise RuntimeError("No image data found. Make sure load_data() ran successfully.")
|
||||
|
||||
print("🔨 Building LEANN index with CLIP embeddings...")
|
||||
builder = LeannBuilder(
|
||||
backend_name=args.backend_name,
|
||||
embedding_model=self.embedding_model_default,
|
||||
embedding_mode=self.embedding_mode_default,
|
||||
is_recompute=False,
|
||||
distance_metric="cosine",
|
||||
graph_degree=args.graph_degree,
|
||||
build_complexity=args.build_complexity,
|
||||
is_compact=not args.no_compact,
|
||||
)
|
||||
|
||||
for text, data in zip(texts, self._image_data):
|
||||
builder.add_text(text=text, metadata=data["metadata"])
|
||||
|
||||
ids = [str(i) for i in range(len(self._image_data))]
|
||||
embeddings = np.array([data["embedding"] for data in self._image_data], dtype=np.float32)
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="wb", suffix=".pkl", delete=False) as f:
|
||||
pickle.dump((ids, embeddings), f)
|
||||
pkl_path = f.name
|
||||
|
||||
try:
|
||||
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
|
||||
builder.build_index_from_embeddings(index_path, pkl_path)
|
||||
print(f"✅ Index built successfully at {index_path}")
|
||||
return index_path
|
||||
finally:
|
||||
Path(pkl_path).unlink()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point for the image RAG application."""
|
||||
import asyncio
|
||||
|
||||
app = ImageRAG()
|
||||
asyncio.run(app.run())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1
apps/imessage_data/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""iMessage data processing module."""
|
||||
342
apps/imessage_data/imessage_reader.py
Normal file
@@ -0,0 +1,342 @@
|
||||
"""
|
||||
iMessage data reader.
|
||||
|
||||
Reads and processes iMessage conversation data from the macOS Messages database.
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core import Document
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
|
||||
|
||||
class IMessageReader(BaseReader):
|
||||
"""
|
||||
iMessage data reader.
|
||||
|
||||
Reads iMessage conversation data from the macOS Messages database (chat.db).
|
||||
Processes conversations into structured documents with metadata.
|
||||
"""
|
||||
|
||||
def __init__(self, concatenate_conversations: bool = True) -> None:
|
||||
"""
|
||||
Initialize.
|
||||
|
||||
Args:
|
||||
concatenate_conversations: Whether to concatenate messages within conversations for better context
|
||||
"""
|
||||
self.concatenate_conversations = concatenate_conversations
|
||||
|
||||
def _get_default_chat_db_path(self) -> Path:
|
||||
"""
|
||||
Get the default path to the iMessage chat database.
|
||||
|
||||
Returns:
|
||||
Path to the chat.db file
|
||||
"""
|
||||
home = Path.home()
|
||||
return home / "Library" / "Messages" / "chat.db"
|
||||
|
||||
def _convert_cocoa_timestamp(self, cocoa_timestamp: int) -> str:
|
||||
"""
|
||||
Convert Cocoa timestamp to readable format.
|
||||
|
||||
Args:
|
||||
cocoa_timestamp: Timestamp in Cocoa format (nanoseconds since 2001-01-01)
|
||||
|
||||
Returns:
|
||||
Formatted timestamp string
|
||||
"""
|
||||
if cocoa_timestamp == 0:
|
||||
return "Unknown"
|
||||
|
||||
try:
|
||||
# Cocoa timestamp is nanoseconds since 2001-01-01 00:00:00 UTC
|
||||
# Convert to seconds and add to Unix epoch
|
||||
cocoa_epoch = datetime(2001, 1, 1)
|
||||
unix_timestamp = cocoa_timestamp / 1_000_000_000 # Convert nanoseconds to seconds
|
||||
message_time = cocoa_epoch.timestamp() + unix_timestamp
|
||||
return datetime.fromtimestamp(message_time).strftime("%Y-%m-%d %H:%M:%S")
|
||||
except (ValueError, OSError):
|
||||
return "Unknown"
|
||||
|
||||
def _get_contact_name(self, handle_id: str) -> str:
|
||||
"""
|
||||
Get a readable contact name from handle ID.
|
||||
|
||||
Args:
|
||||
handle_id: The handle ID (phone number or email)
|
||||
|
||||
Returns:
|
||||
Formatted contact name
|
||||
"""
|
||||
if not handle_id:
|
||||
return "Unknown"
|
||||
|
||||
# Clean up phone numbers and emails for display
|
||||
if "@" in handle_id:
|
||||
return handle_id # Email address
|
||||
elif handle_id.startswith("+"):
|
||||
return handle_id # International phone number
|
||||
else:
|
||||
# Try to format as phone number
|
||||
digits = "".join(filter(str.isdigit, handle_id))
|
||||
if len(digits) == 10:
|
||||
return f"({digits[:3]}) {digits[3:6]}-{digits[6:]}"
|
||||
elif len(digits) == 11 and digits[0] == "1":
|
||||
return f"+1 ({digits[1:4]}) {digits[4:7]}-{digits[7:]}"
|
||||
else:
|
||||
return handle_id
|
||||
|
||||
def _read_messages_from_db(self, db_path: Path) -> list[dict]:
|
||||
"""
|
||||
Read messages from the iMessage database.
|
||||
|
||||
Args:
|
||||
db_path: Path to the chat.db file
|
||||
|
||||
Returns:
|
||||
List of message dictionaries
|
||||
"""
|
||||
if not db_path.exists():
|
||||
print(f"iMessage database not found at: {db_path}")
|
||||
return []
|
||||
|
||||
try:
|
||||
# Connect to the database
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Query to get messages with chat and handle information
|
||||
query = """
|
||||
SELECT
|
||||
m.ROWID as message_id,
|
||||
m.text,
|
||||
m.date,
|
||||
m.is_from_me,
|
||||
m.service,
|
||||
c.chat_identifier,
|
||||
c.display_name as chat_display_name,
|
||||
h.id as handle_id,
|
||||
c.ROWID as chat_id
|
||||
FROM message m
|
||||
LEFT JOIN chat_message_join cmj ON m.ROWID = cmj.message_id
|
||||
LEFT JOIN chat c ON cmj.chat_id = c.ROWID
|
||||
LEFT JOIN handle h ON m.handle_id = h.ROWID
|
||||
WHERE m.text IS NOT NULL AND m.text != ''
|
||||
ORDER BY c.ROWID, m.date
|
||||
"""
|
||||
|
||||
cursor.execute(query)
|
||||
rows = cursor.fetchall()
|
||||
|
||||
messages = []
|
||||
for row in rows:
|
||||
(
|
||||
message_id,
|
||||
text,
|
||||
date,
|
||||
is_from_me,
|
||||
service,
|
||||
chat_identifier,
|
||||
chat_display_name,
|
||||
handle_id,
|
||||
chat_id,
|
||||
) = row
|
||||
|
||||
message = {
|
||||
"message_id": message_id,
|
||||
"text": text,
|
||||
"timestamp": self._convert_cocoa_timestamp(date),
|
||||
"is_from_me": bool(is_from_me),
|
||||
"service": service or "iMessage",
|
||||
"chat_identifier": chat_identifier or "Unknown",
|
||||
"chat_display_name": chat_display_name or "Unknown Chat",
|
||||
"handle_id": handle_id or "Unknown",
|
||||
"contact_name": self._get_contact_name(handle_id or ""),
|
||||
"chat_id": chat_id,
|
||||
}
|
||||
messages.append(message)
|
||||
|
||||
conn.close()
|
||||
print(f"Found {len(messages)} messages in database")
|
||||
return messages
|
||||
|
||||
except sqlite3.Error as e:
|
||||
print(f"Error reading iMessage database: {e}")
|
||||
return []
|
||||
except Exception as e:
|
||||
print(f"Unexpected error reading iMessage database: {e}")
|
||||
return []
|
||||
|
||||
def _group_messages_by_chat(self, messages: list[dict]) -> dict[int, list[dict]]:
|
||||
"""
|
||||
Group messages by chat ID.
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries
|
||||
|
||||
Returns:
|
||||
Dictionary mapping chat_id to list of messages
|
||||
"""
|
||||
chats = {}
|
||||
for message in messages:
|
||||
chat_id = message["chat_id"]
|
||||
if chat_id not in chats:
|
||||
chats[chat_id] = []
|
||||
chats[chat_id].append(message)
|
||||
|
||||
return chats
|
||||
|
||||
def _create_concatenated_content(self, chat_id: int, messages: list[dict]) -> str:
|
||||
"""
|
||||
Create concatenated content from chat messages.
|
||||
|
||||
Args:
|
||||
chat_id: The chat ID
|
||||
messages: List of messages in the chat
|
||||
|
||||
Returns:
|
||||
Concatenated text content
|
||||
"""
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
# Get chat info from first message
|
||||
first_msg = messages[0]
|
||||
chat_name = first_msg["chat_display_name"]
|
||||
chat_identifier = first_msg["chat_identifier"]
|
||||
|
||||
# Build message content
|
||||
message_parts = []
|
||||
for message in messages:
|
||||
timestamp = message["timestamp"]
|
||||
is_from_me = message["is_from_me"]
|
||||
text = message["text"]
|
||||
contact_name = message["contact_name"]
|
||||
|
||||
if is_from_me:
|
||||
prefix = "[You]"
|
||||
else:
|
||||
prefix = f"[{contact_name}]"
|
||||
|
||||
if timestamp != "Unknown":
|
||||
prefix += f" ({timestamp})"
|
||||
|
||||
message_parts.append(f"{prefix}: {text}")
|
||||
|
||||
concatenated_text = "\n\n".join(message_parts)
|
||||
|
||||
doc_content = f"""Chat: {chat_name}
|
||||
Identifier: {chat_identifier}
|
||||
Messages ({len(messages)} messages):
|
||||
|
||||
{concatenated_text}
|
||||
"""
|
||||
return doc_content
|
||||
|
||||
def _create_individual_content(self, message: dict) -> str:
|
||||
"""
|
||||
Create content for individual message.
|
||||
|
||||
Args:
|
||||
message: Message dictionary
|
||||
|
||||
Returns:
|
||||
Formatted message content
|
||||
"""
|
||||
timestamp = message["timestamp"]
|
||||
is_from_me = message["is_from_me"]
|
||||
text = message["text"]
|
||||
contact_name = message["contact_name"]
|
||||
chat_name = message["chat_display_name"]
|
||||
|
||||
sender = "You" if is_from_me else contact_name
|
||||
|
||||
return f"""Message from {sender} in chat "{chat_name}"
|
||||
Time: {timestamp}
|
||||
Content: {text}
|
||||
"""
|
||||
|
||||
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Load iMessage data and return as documents.
|
||||
|
||||
Args:
|
||||
input_dir: Optional path to directory containing chat.db file.
|
||||
If not provided, uses default macOS location.
|
||||
**load_kwargs: Additional arguments (unused)
|
||||
|
||||
Returns:
|
||||
List of Document objects containing iMessage data
|
||||
"""
|
||||
docs = []
|
||||
|
||||
# Determine database path
|
||||
if input_dir:
|
||||
db_path = Path(input_dir) / "chat.db"
|
||||
else:
|
||||
db_path = self._get_default_chat_db_path()
|
||||
|
||||
print(f"Reading iMessage database from: {db_path}")
|
||||
|
||||
# Read messages from database
|
||||
messages = self._read_messages_from_db(db_path)
|
||||
if not messages:
|
||||
return docs
|
||||
|
||||
if self.concatenate_conversations:
|
||||
# Group messages by chat and create concatenated documents
|
||||
chats = self._group_messages_by_chat(messages)
|
||||
|
||||
for chat_id, chat_messages in chats.items():
|
||||
if not chat_messages:
|
||||
continue
|
||||
|
||||
content = self._create_concatenated_content(chat_id, chat_messages)
|
||||
|
||||
# Create metadata
|
||||
first_msg = chat_messages[0]
|
||||
last_msg = chat_messages[-1]
|
||||
|
||||
metadata = {
|
||||
"source": "iMessage",
|
||||
"chat_id": chat_id,
|
||||
"chat_name": first_msg["chat_display_name"],
|
||||
"chat_identifier": first_msg["chat_identifier"],
|
||||
"message_count": len(chat_messages),
|
||||
"first_message_date": first_msg["timestamp"],
|
||||
"last_message_date": last_msg["timestamp"],
|
||||
"participants": list(
|
||||
{msg["contact_name"] for msg in chat_messages if not msg["is_from_me"]}
|
||||
),
|
||||
}
|
||||
|
||||
doc = Document(text=content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
||||
else:
|
||||
# Create individual documents for each message
|
||||
for message in messages:
|
||||
content = self._create_individual_content(message)
|
||||
|
||||
metadata = {
|
||||
"source": "iMessage",
|
||||
"message_id": message["message_id"],
|
||||
"chat_id": message["chat_id"],
|
||||
"chat_name": message["chat_display_name"],
|
||||
"chat_identifier": message["chat_identifier"],
|
||||
"timestamp": message["timestamp"],
|
||||
"is_from_me": message["is_from_me"],
|
||||
"contact_name": message["contact_name"],
|
||||
"service": message["service"],
|
||||
}
|
||||
|
||||
doc = Document(text=content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
||||
print(f"Created {len(docs)} documents from iMessage data")
|
||||
return docs
|
||||
126
apps/imessage_rag.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""
|
||||
iMessage RAG Example.
|
||||
|
||||
This example demonstrates how to build a RAG system on your iMessage conversation history.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from leann.chunking_utils import create_text_chunks
|
||||
|
||||
from apps.base_rag_example import BaseRAGExample
|
||||
from apps.imessage_data.imessage_reader import IMessageReader
|
||||
|
||||
|
||||
class IMessageRAG(BaseRAGExample):
|
||||
"""RAG example for iMessage conversation history."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="iMessage",
|
||||
description="RAG on your iMessage conversation history",
|
||||
default_index_name="imessage_index",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add iMessage-specific arguments."""
|
||||
imessage_group = parser.add_argument_group("iMessage Parameters")
|
||||
imessage_group.add_argument(
|
||||
"--db-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to iMessage chat.db file (default: ~/Library/Messages/chat.db)",
|
||||
)
|
||||
imessage_group.add_argument(
|
||||
"--concatenate-conversations",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Concatenate messages within conversations for better context (default: True)",
|
||||
)
|
||||
imessage_group.add_argument(
|
||||
"--no-concatenate-conversations",
|
||||
action="store_true",
|
||||
help="Process each message individually instead of concatenating by conversation",
|
||||
)
|
||||
imessage_group.add_argument(
|
||||
"--chunk-size",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Maximum characters per text chunk (default: 1000)",
|
||||
)
|
||||
imessage_group.add_argument(
|
||||
"--chunk-overlap",
|
||||
type=int,
|
||||
default=200,
|
||||
help="Overlap between text chunks (default: 200)",
|
||||
)
|
||||
|
||||
async def load_data(self, args) -> list[dict[str, Any]]:
|
||||
"""Load iMessage history and convert to text chunks."""
|
||||
print("Loading iMessage conversation history...")
|
||||
|
||||
# Determine concatenation setting
|
||||
concatenate = args.concatenate_conversations and not args.no_concatenate_conversations
|
||||
|
||||
# Initialize iMessage reader
|
||||
reader = IMessageReader(concatenate_conversations=concatenate)
|
||||
|
||||
# Load documents
|
||||
try:
|
||||
if args.db_path:
|
||||
# Use custom database path
|
||||
db_dir = str(Path(args.db_path).parent)
|
||||
documents = reader.load_data(input_dir=db_dir)
|
||||
else:
|
||||
# Use default macOS location
|
||||
documents = reader.load_data()
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading iMessage data: {e}")
|
||||
print("\nTroubleshooting tips:")
|
||||
print("1. Make sure you have granted Full Disk Access to your terminal/IDE")
|
||||
print("2. Check that the iMessage database exists at ~/Library/Messages/chat.db")
|
||||
print("3. Try specifying a custom path with --db-path if you have a backup")
|
||||
return []
|
||||
|
||||
if not documents:
|
||||
print("No iMessage conversations found!")
|
||||
return []
|
||||
|
||||
print(f"Loaded {len(documents)} iMessage documents")
|
||||
|
||||
# Show some statistics
|
||||
total_messages = sum(doc.metadata.get("message_count", 1) for doc in documents)
|
||||
print(f"Total messages: {total_messages}")
|
||||
|
||||
if concatenate:
|
||||
# Show chat statistics
|
||||
chat_names = [doc.metadata.get("chat_name", "Unknown") for doc in documents]
|
||||
unique_chats = len(set(chat_names))
|
||||
print(f"Unique conversations: {unique_chats}")
|
||||
|
||||
# Convert to text chunks
|
||||
all_texts = create_text_chunks(
|
||||
documents,
|
||||
chunk_size=args.chunk_size,
|
||||
chunk_overlap=args.chunk_overlap,
|
||||
)
|
||||
|
||||
# Apply max_items limit if specified
|
||||
if args.max_items > 0:
|
||||
all_texts = all_texts[: args.max_items]
|
||||
print(f"Limited to {len(all_texts)} text chunks (max_items={args.max_items})")
|
||||
|
||||
return all_texts
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main entry point."""
|
||||
app = IMessageRAG()
|
||||
await app.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
113
apps/multimodal/vision-based-pdf-multi-vector/README.md
Normal file
@@ -0,0 +1,113 @@
|
||||
## Vision-based PDF Multi-Vector Demos (macOS/MPS)
|
||||
|
||||
This folder contains two demos to index PDF pages as images and run multi-vector retrieval with ColPali/ColQwen2, plus optional similarity map visualization and answer generation.
|
||||
|
||||
### What you’ll run
|
||||
- `multi-vector-leann-paper-example.py`: local PDF → pages → embed → build HNSW index → search.
|
||||
- `multi-vector-leann-similarity-map.py`: HF dataset (default) or local pages → embed → index → retrieve → similarity maps → optional Qwen-VL answer.
|
||||
|
||||
## Prerequisites (macOS)
|
||||
|
||||
### 1) Homebrew poppler (for pdf2image)
|
||||
```bash
|
||||
brew install poppler
|
||||
which pdfinfo && pdfinfo -v
|
||||
```
|
||||
|
||||
### 2) Python environment
|
||||
Use uv (recommended) or pip. Python 3.9+.
|
||||
|
||||
Using uv:
|
||||
```bash
|
||||
uv pip install \
|
||||
colpali_engine \
|
||||
pdf2image \
|
||||
pillow \
|
||||
matplotlib qwen_vl_utils \
|
||||
einops \
|
||||
seaborn
|
||||
```
|
||||
|
||||
Notes:
|
||||
- On first run, models download from Hugging Face. Login/config if needed.
|
||||
- The scripts auto-select device: CUDA > MPS > CPU. Verify MPS:
|
||||
```bash
|
||||
python -c "import torch; print('MPS available:', bool(getattr(torch.backends, 'mps', None) and torch.backends.mps.is_available()))"
|
||||
```
|
||||
|
||||
## Run the demos
|
||||
|
||||
### A) Local PDF example
|
||||
Converts a local PDF into page images, embeds them, builds an index, and searches.
|
||||
|
||||
```bash
|
||||
cd apps/multimodal/vision-based-pdf-multi-vector
|
||||
# If you don't have the sample PDF locally, download it (ignored by Git)
|
||||
mkdir -p pdfs
|
||||
curl -L -o pdfs/2004.12832v2.pdf https://arxiv.org/pdf/2004.12832.pdf
|
||||
ls pdfs/2004.12832v2.pdf
|
||||
# Ensure output dir exists
|
||||
mkdir -p pages
|
||||
python multi-vector-leann-paper-example.py
|
||||
```
|
||||
Expected:
|
||||
- Page images in `pages/`.
|
||||
- Console prints like `Using device=mps, dtype=...` and retrieved file paths for queries.
|
||||
|
||||
To use your own PDF: edit `pdf_path` near the top of the script.
|
||||
|
||||
### B) Similarity map + answer demo
|
||||
Uses HF dataset `weaviate/arXiv-AI-papers-multi-vector` by default; can switch to local pages.
|
||||
|
||||
```bash
|
||||
cd apps/multimodal/vision-based-pdf-multi-vector
|
||||
python multi-vector-leann-similarity-map.py
|
||||
```
|
||||
Artifacts (when enabled):
|
||||
- Retrieved pages: `./figures/retrieved_page_rank{K}.png`
|
||||
- Similarity maps: `./figures/similarity_map_rank{K}.png`
|
||||
|
||||
Key knobs in the script (top of file):
|
||||
- `QUERY`: your question
|
||||
- `MODEL`: `"colqwen2"` or `"colpali"`
|
||||
- `USE_HF_DATASET`: set `False` to use local pages
|
||||
- `PDF`, `PAGES_DIR`: for local mode
|
||||
- `INDEX_PATH`, `TOPK`, `FIRST_STAGE_K`, `REBUILD_INDEX`
|
||||
- `SIMILARITY_MAP`, `SIM_TOKEN_IDX`, `SIM_OUTPUT`
|
||||
- `ANSWER`, `MAX_NEW_TOKENS` (Qwen-VL)
|
||||
|
||||
## Troubleshooting
|
||||
- pdf2image errors on macOS: ensure `brew install poppler` and `pdfinfo` works in terminal.
|
||||
- Slow or OOM on MPS: reduce dataset size (e.g., set `MAX_DOCS`) or switch to CPU.
|
||||
- NaNs on MPS: keep fp32 on MPS (default in similarity-map script); avoid fp16 there.
|
||||
- First-run model downloads can be large; ensure network access (HF mirrors if needed).
|
||||
|
||||
## Notes
|
||||
- Index files are under `./indexes/`. Delete or set `REBUILD_INDEX=True` to rebuild.
|
||||
- For local PDFs, page images go to `./pages/`.
|
||||
|
||||
|
||||
### Retrieval and Visualization Example
|
||||
|
||||
Example settings in `multi-vector-leann-similarity-map.py`:
|
||||
- `QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"`
|
||||
- `SIMILARITY_MAP = True` (to generate heatmaps)
|
||||
- `TOPK = 1` (save the top retrieved page and its similarity map)
|
||||
|
||||
Run:
|
||||
```bash
|
||||
cd apps/multimodal/vision-based-pdf-multi-vector
|
||||
python multi-vector-leann-similarity-map.py
|
||||
```
|
||||
|
||||
Outputs (by default):
|
||||
- Retrieved page: `./figures/retrieved_page_rank1.png`
|
||||
- Similarity map: `./figures/similarity_map_rank1.png`
|
||||
|
||||
Sample visualization (example result, and the query is "QUERY = "How does Vim model performance and efficiency compared to other models?"
|
||||
"):
|
||||

|
||||
|
||||
Notes:
|
||||
- Set `SIM_TOKEN_IDX` to visualize a specific token index; set `-1` to auto-select the most salient token.
|
||||
- If you change `SIM_OUTPUT` to a file path (e.g., `./figures/my_map.png`), multiple ranks are saved as `my_map_rank{K}.png`.
|
||||
132
apps/multimodal/vision-based-pdf-multi-vector/colqwen_forward.py
Executable file
@@ -0,0 +1,132 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Simple test script to test colqwen2 forward pass with a single image."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add the current directory to path to import leann_multi_vector
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
import torch
|
||||
from leann_multi_vector import _embed_images, _ensure_repo_paths_importable, _load_colvision
|
||||
from PIL import Image
|
||||
|
||||
# Ensure repo paths are importable
|
||||
_ensure_repo_paths_importable(__file__)
|
||||
|
||||
# Set environment variable
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
def create_test_image():
|
||||
"""Create a simple test image."""
|
||||
# Create a simple RGB image (800x600)
|
||||
img = Image.new("RGB", (800, 600), color="white")
|
||||
return img
|
||||
|
||||
|
||||
def load_test_image_from_file():
|
||||
"""Try to load an image from the indexes directory if available."""
|
||||
# Try to find an existing image in the indexes directory
|
||||
indexes_dir = Path(__file__).parent / "indexes"
|
||||
|
||||
# Look for images in common locations
|
||||
possible_paths = [
|
||||
indexes_dir / "vidore_fastplaid" / "images",
|
||||
indexes_dir / "colvision_large.leann.images",
|
||||
indexes_dir / "colvision.leann.images",
|
||||
]
|
||||
|
||||
for img_dir in possible_paths:
|
||||
if img_dir.exists():
|
||||
# Find first image file
|
||||
for ext in [".png", ".jpg", ".jpeg"]:
|
||||
for img_file in img_dir.glob(f"*{ext}"):
|
||||
print(f"Loading test image from: {img_file}")
|
||||
return Image.open(img_file)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def main():
|
||||
print("=" * 60)
|
||||
print("Testing ColQwen2 Forward Pass")
|
||||
print("=" * 60)
|
||||
|
||||
# Step 1: Load or create test image
|
||||
print("\n[Step 1] Loading test image...")
|
||||
test_image = load_test_image_from_file()
|
||||
if test_image is None:
|
||||
print("No existing image found, creating a simple test image...")
|
||||
test_image = create_test_image()
|
||||
else:
|
||||
print(f"✓ Loaded image: {test_image.size} ({test_image.mode})")
|
||||
|
||||
# Convert to RGB if needed
|
||||
if test_image.mode != "RGB":
|
||||
test_image = test_image.convert("RGB")
|
||||
print(f"✓ Converted to RGB: {test_image.size}")
|
||||
|
||||
# Step 2: Load model
|
||||
print("\n[Step 2] Loading ColQwen2 model...")
|
||||
try:
|
||||
model_name, model, processor, device_str, device, dtype = _load_colvision("colqwen2")
|
||||
print(f"✓ Model loaded: {model_name}")
|
||||
print(f"✓ Device: {device_str}, dtype: {dtype}")
|
||||
|
||||
# Print model info
|
||||
if hasattr(model, "device"):
|
||||
print(f"✓ Model device: {model.device}")
|
||||
if hasattr(model, "dtype"):
|
||||
print(f"✓ Model dtype: {model.dtype}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Error loading model: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
# Step 3: Test forward pass
|
||||
print("\n[Step 3] Running forward pass...")
|
||||
try:
|
||||
# Use the _embed_images function which handles batching and forward pass
|
||||
images = [test_image]
|
||||
print(f"Processing {len(images)} image(s)...")
|
||||
|
||||
doc_vecs = _embed_images(model, processor, images)
|
||||
|
||||
print("✓ Forward pass completed!")
|
||||
print(f"✓ Number of embeddings: {len(doc_vecs)}")
|
||||
|
||||
if len(doc_vecs) > 0:
|
||||
emb = doc_vecs[0]
|
||||
print(f"✓ Embedding shape: {emb.shape}")
|
||||
print(f"✓ Embedding dtype: {emb.dtype}")
|
||||
print("✓ Embedding stats:")
|
||||
print(f" - Min: {emb.min().item():.4f}")
|
||||
print(f" - Max: {emb.max().item():.4f}")
|
||||
print(f" - Mean: {emb.mean().item():.4f}")
|
||||
print(f" - Std: {emb.std().item():.4f}")
|
||||
|
||||
# Check for NaN or Inf
|
||||
if torch.isnan(emb).any():
|
||||
print("⚠ Warning: Embedding contains NaN values!")
|
||||
if torch.isinf(emb).any():
|
||||
print("⚠ Warning: Embedding contains Inf values!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Error during forward pass: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Test completed successfully!")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
BIN
apps/multimodal/vision-based-pdf-multi-vector/fig/image.png
Normal file
|
After Width: | Height: | Size: 166 KiB |
1452
apps/multimodal/vision-based-pdf-multi-vector/leann_multi_vector.py
Normal file
@@ -0,0 +1,113 @@
|
||||
# pip install pdf2image
|
||||
# pip install pymilvus
|
||||
# pip install colpali_engine
|
||||
# pip install tqdm
|
||||
# pip install pillow
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
# Ensure local leann packages are importable before importing them
|
||||
_repo_root = Path(__file__).resolve().parents[3]
|
||||
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
|
||||
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
|
||||
if str(_leann_core_src) not in sys.path:
|
||||
sys.path.insert(0, str(_leann_core_src))
|
||||
if str(_leann_hnsw_pkg) not in sys.path:
|
||||
sys.path.insert(0, str(_leann_hnsw_pkg))
|
||||
|
||||
from leann_multi_vector import LeannMultiVector
|
||||
|
||||
import torch
|
||||
from colpali_engine.models import ColPali
|
||||
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
|
||||
from colpali_engine.utils.torch_utils import ListDataset, get_torch_device
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# Auto-select device: CUDA > MPS (mac) > CPU
|
||||
_device_str = (
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
else (
|
||||
"mps"
|
||||
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
)
|
||||
device = get_torch_device(_device_str)
|
||||
# Prefer fp16 on GPU/MPS, bfloat16 on CPU
|
||||
_dtype = torch.float16 if _device_str in ("cuda", "mps") else torch.bfloat16
|
||||
model_name = "vidore/colpali-v1.2"
|
||||
|
||||
model = ColPali.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=_dtype,
|
||||
device_map=device,
|
||||
).eval()
|
||||
print(f"Using device={_device_str}, dtype={_dtype}")
|
||||
|
||||
queries = [
|
||||
"How to end-to-end retrieval with ColBert",
|
||||
"Where is ColBERT performance Table, including text representation results?",
|
||||
]
|
||||
|
||||
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset=ListDataset[str](queries),
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
collate_fn=lambda x: processor.process_queries(x),
|
||||
)
|
||||
|
||||
qs: list[torch.Tensor] = []
|
||||
for batch_query in dataloader:
|
||||
with torch.no_grad():
|
||||
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
|
||||
embeddings_query = model(**batch_query)
|
||||
qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
|
||||
print(qs[0].shape)
|
||||
# %%
|
||||
page_filenames = sorted(os.listdir("./pages"), key=lambda n: int(re.search(r"\d+", n).group()))
|
||||
images = [Image.open(os.path.join("./pages", name)) for name in page_filenames]
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset=ListDataset[str](images),
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
collate_fn=lambda x: processor.process_images(x),
|
||||
)
|
||||
|
||||
ds: list[torch.Tensor] = []
|
||||
for batch_doc in tqdm(dataloader):
|
||||
with torch.no_grad():
|
||||
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
|
||||
embeddings_doc = model(**batch_doc)
|
||||
ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
|
||||
|
||||
print(ds[0].shape)
|
||||
|
||||
# %%
|
||||
# Build HNSW index via LeannMultiVector primitives and run search
|
||||
index_path = "./indexes/colpali.leann"
|
||||
retriever = LeannMultiVector(index_path=index_path, dim=int(ds[0].shape[-1]))
|
||||
retriever.create_collection()
|
||||
filepaths = [os.path.join("./pages", name) for name in page_filenames]
|
||||
for i in range(len(filepaths)):
|
||||
data = {
|
||||
"colbert_vecs": ds[i].float().numpy(),
|
||||
"doc_id": i,
|
||||
"filepath": filepaths[i],
|
||||
}
|
||||
retriever.insert(data)
|
||||
retriever.create_index()
|
||||
for query in qs:
|
||||
query_np = query.float().numpy()
|
||||
result = retriever.search(query_np, topk=1)
|
||||
print(filepaths[result[0][1]])
|
||||
@@ -0,0 +1,713 @@
|
||||
## Jupyter-style notebook script
|
||||
# %%
|
||||
# uv pip install matplotlib qwen_vl_utils
|
||||
import argparse
|
||||
import faulthandler
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
# Enable faulthandler to get stack trace on segfault
|
||||
faulthandler.enable()
|
||||
|
||||
|
||||
from leann_multi_vector import ( # utility functions/classes
|
||||
_ensure_repo_paths_importable,
|
||||
_load_images_from_dir,
|
||||
_maybe_convert_pdf_to_images,
|
||||
_load_colvision,
|
||||
_embed_images,
|
||||
_embed_queries,
|
||||
_build_index,
|
||||
_load_retriever_if_index_exists,
|
||||
_generate_similarity_map,
|
||||
_build_fast_plaid_index,
|
||||
_load_fast_plaid_index_if_exists,
|
||||
_search_fast_plaid,
|
||||
_get_fast_plaid_image,
|
||||
_get_fast_plaid_metadata,
|
||||
QwenVL,
|
||||
)
|
||||
|
||||
_ensure_repo_paths_importable(__file__)
|
||||
|
||||
# %%
|
||||
# Config
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
QUERY = "The paper talk about the latent video generative model and data curation in the related work part?"
|
||||
MODEL: str = "colqwen2" # "colpali" or "colqwen2"
|
||||
|
||||
# Data source: set to True to use the Hugging Face dataset example (recommended)
|
||||
USE_HF_DATASET: bool = True
|
||||
# Single dataset name (used when DATASET_NAMES is None)
|
||||
DATASET_NAME: str = "weaviate/arXiv-AI-papers-multi-vector"
|
||||
# Multiple datasets to combine (if provided, DATASET_NAME is ignored)
|
||||
# Can be:
|
||||
# - List of strings: ["dataset1", "dataset2"]
|
||||
# - List of tuples: [("dataset1", "config1"), ("dataset2", None)] # None = no config needed
|
||||
# - Mixed: ["dataset1", ("dataset2", "config2")]
|
||||
#
|
||||
# Some potential datasets with images (may need IMAGE_FIELD_NAME adjustment):
|
||||
# - "weaviate/arXiv-AI-papers-multi-vector" (current, has "page_image" field)
|
||||
# - ("lmms-lab/DocVQA", "DocVQA") (has "image" field, document images, needs config)
|
||||
# - ("lmms-lab/DocVQA", "InfographicVQA") (has "image" field, infographic images)
|
||||
# - "pixparse/arxiv-papers" (if available, arXiv papers)
|
||||
# - "allenai/ai2d" (AI2D diagram dataset, has "image" field)
|
||||
# - "huggingface/document-images" (if available)
|
||||
# Note: Check dataset structure first - some may need IMAGE_FIELD_NAME specified
|
||||
# DATASET_NAMES: Optional[list[str | tuple[str, Optional[str]]]] = None
|
||||
DATASET_NAMES = [
|
||||
"weaviate/arXiv-AI-papers-multi-vector",
|
||||
# ("lmms-lab/DocVQA", "DocVQA"), # Specify config name for datasets with multiple configs
|
||||
]
|
||||
# Load multiple splits to get more data (e.g., ["train", "test", "validation"])
|
||||
# Set to None to try loading all available splits automatically
|
||||
DATASET_SPLITS: Optional[list[str]] = ["train", "test"] # None = auto-detect all splits
|
||||
# Image field name in the dataset (auto-detect if None)
|
||||
# Common names: "page_image", "image", "images", "img"
|
||||
IMAGE_FIELD_NAME: Optional[str] = None # None = auto-detect
|
||||
MAX_DOCS: Optional[int] = None # limit number of pages to index; None = all
|
||||
|
||||
# Local pages (used when USE_HF_DATASET == False)
|
||||
PDF: Optional[str] = None # e.g., "./pdfs/2004.12832v2.pdf"
|
||||
PAGES_DIR: str = "./pages"
|
||||
# Custom folder path (takes precedence over USE_HF_DATASET and PAGES_DIR)
|
||||
# If set, images will be loaded directly from this folder
|
||||
CUSTOM_FOLDER_PATH: Optional[str] = None # e.g., "/home/ubuntu/dr-tulu/agent/screenshots"
|
||||
# Whether to recursively search subdirectories when loading from custom folder
|
||||
CUSTOM_FOLDER_RECURSIVE: bool = False # Set to True to search subdirectories
|
||||
|
||||
# Index + retrieval settings
|
||||
# Use a different index path for larger dataset to avoid overwriting existing index
|
||||
INDEX_PATH: str = "./indexes/colvision_large.leann"
|
||||
# Fast-Plaid index settings (alternative to LEANN index)
|
||||
# These are now command-line arguments (see CLI overrides section)
|
||||
TOPK: int = 3
|
||||
FIRST_STAGE_K: int = 500
|
||||
REBUILD_INDEX: bool = False # Set to True to force rebuild even if index exists
|
||||
|
||||
# Artifacts
|
||||
SAVE_TOP_IMAGE: Optional[str] = "./figures/retrieved_page.png"
|
||||
SIMILARITY_MAP: bool = True
|
||||
SIM_TOKEN_IDX: int = 13 # -1 means auto-select the most salient token
|
||||
SIM_OUTPUT: str = "./figures/similarity_map.png"
|
||||
ANSWER: bool = True
|
||||
MAX_NEW_TOKENS: int = 1024
|
||||
|
||||
|
||||
# %%
|
||||
# CLI overrides
|
||||
parser = argparse.ArgumentParser(description="Multi-vector LEANN similarity map demo")
|
||||
parser.add_argument(
|
||||
"--search-method",
|
||||
type=str,
|
||||
choices=["ann", "exact", "exact-all"],
|
||||
default="ann",
|
||||
help="Which search method to use: 'ann' (fast ANN), 'exact' (ANN + exact rerank), or 'exact-all' (exact over all docs).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--query",
|
||||
type=str,
|
||||
default=QUERY,
|
||||
help=f"Query string to search for. Default: '{QUERY}'",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-fast-plaid",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Set to True to use fast-plaid instead of LEANN. Default: False",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fast-plaid-index-path",
|
||||
type=str,
|
||||
default="./indexes/colvision_fastplaid",
|
||||
help="Path to the Fast-Plaid index. Default: './indexes/colvision_fastplaid'",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--topk",
|
||||
type=int,
|
||||
default=TOPK,
|
||||
help=f"Number of top results to retrieve. Default: {TOPK}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--custom-folder",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to a custom folder containing images to search. Takes precedence over dataset loading. Default: None",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--recursive",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Recursively search subdirectories when loading images from custom folder. Default: False",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rebuild-index",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Force rebuild the index even if it already exists. Default: False (reuse existing index if available)",
|
||||
)
|
||||
cli_args, _unknown = parser.parse_known_args()
|
||||
SEARCH_METHOD: str = cli_args.search_method
|
||||
QUERY = cli_args.query # Override QUERY with CLI argument if provided
|
||||
USE_FAST_PLAID: bool = cli_args.use_fast_plaid
|
||||
FAST_PLAID_INDEX_PATH: str = cli_args.fast_plaid_index_path
|
||||
TOPK: int = cli_args.topk # Override TOPK with CLI argument if provided
|
||||
CUSTOM_FOLDER_PATH = cli_args.custom_folder if cli_args.custom_folder else CUSTOM_FOLDER_PATH # Override with CLI argument if provided
|
||||
CUSTOM_FOLDER_RECURSIVE = cli_args.recursive if cli_args.recursive else CUSTOM_FOLDER_RECURSIVE # Override with CLI argument if provided
|
||||
REBUILD_INDEX = cli_args.rebuild_index # Override REBUILD_INDEX with CLI argument
|
||||
|
||||
# %%
|
||||
|
||||
# Step 1: Check if we can skip data loading (index already exists)
|
||||
retriever: Optional[Any] = None
|
||||
fast_plaid_index: Optional[Any] = None
|
||||
need_to_build_index = REBUILD_INDEX
|
||||
|
||||
if USE_FAST_PLAID:
|
||||
# Fast-Plaid index handling
|
||||
if not REBUILD_INDEX:
|
||||
try:
|
||||
fast_plaid_index = _load_fast_plaid_index_if_exists(FAST_PLAID_INDEX_PATH)
|
||||
if fast_plaid_index is not None:
|
||||
print(f"✓ Fast-Plaid index found at {FAST_PLAID_INDEX_PATH}")
|
||||
need_to_build_index = False
|
||||
else:
|
||||
print(f"Fast-Plaid index not found, will build new index")
|
||||
need_to_build_index = True
|
||||
except Exception as e:
|
||||
# If loading fails (e.g., memory error, corrupted index), rebuild
|
||||
print(f"Warning: Failed to load Fast-Plaid index: {e}")
|
||||
print("Will rebuild the index...")
|
||||
need_to_build_index = True
|
||||
fast_plaid_index = None
|
||||
else:
|
||||
print(f"REBUILD_INDEX=True, will rebuild Fast-Plaid index")
|
||||
need_to_build_index = True
|
||||
else:
|
||||
# Original LEANN index handling
|
||||
if not REBUILD_INDEX:
|
||||
retriever = _load_retriever_if_index_exists(INDEX_PATH)
|
||||
if retriever is not None:
|
||||
print(f"✓ Index loaded from {INDEX_PATH}")
|
||||
print(f"✓ Images available at: {retriever._images_dir_path()}")
|
||||
need_to_build_index = False
|
||||
else:
|
||||
print(f"Index not found, will build new index")
|
||||
need_to_build_index = True
|
||||
else:
|
||||
print(f"REBUILD_INDEX=True, will rebuild index")
|
||||
need_to_build_index = True
|
||||
|
||||
# Step 2: Load data only if we need to build the index
|
||||
if need_to_build_index:
|
||||
print("Loading dataset...")
|
||||
# Check for custom folder path first (takes precedence)
|
||||
if CUSTOM_FOLDER_PATH:
|
||||
if not os.path.isdir(CUSTOM_FOLDER_PATH):
|
||||
raise RuntimeError(f"Custom folder path does not exist: {CUSTOM_FOLDER_PATH}")
|
||||
print(f"Loading images from custom folder: {CUSTOM_FOLDER_PATH}")
|
||||
if CUSTOM_FOLDER_RECURSIVE:
|
||||
print(" (recursive mode: searching subdirectories)")
|
||||
filepaths, images = _load_images_from_dir(CUSTOM_FOLDER_PATH, recursive=CUSTOM_FOLDER_RECURSIVE)
|
||||
print(f" Found {len(filepaths)} image files")
|
||||
if not images:
|
||||
raise RuntimeError(
|
||||
f"No images found in {CUSTOM_FOLDER_PATH}. Ensure the folder contains image files (.png, .jpg, .jpeg, .webp)."
|
||||
)
|
||||
print(f" Successfully loaded {len(images)} images")
|
||||
# Use filenames as identifiers instead of full paths for cleaner metadata
|
||||
filepaths = [os.path.basename(fp) for fp in filepaths]
|
||||
elif USE_HF_DATASET:
|
||||
from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset
|
||||
|
||||
# Determine which datasets to load
|
||||
if DATASET_NAMES is not None:
|
||||
dataset_names_to_load = DATASET_NAMES
|
||||
print(f"Loading {len(dataset_names_to_load)} datasets: {dataset_names_to_load}")
|
||||
else:
|
||||
dataset_names_to_load = [DATASET_NAME]
|
||||
print(f"Loading single dataset: {DATASET_NAME}")
|
||||
|
||||
# Load and combine datasets
|
||||
all_datasets_to_concat = []
|
||||
|
||||
for dataset_entry in dataset_names_to_load:
|
||||
# Handle both string and tuple formats
|
||||
if isinstance(dataset_entry, tuple):
|
||||
dataset_name, config_name = dataset_entry
|
||||
else:
|
||||
dataset_name = dataset_entry
|
||||
config_name = None
|
||||
|
||||
print(f"\nProcessing dataset: {dataset_name}" + (f" (config: {config_name})" if config_name else ""))
|
||||
|
||||
# Load dataset to check available splits
|
||||
# If config_name is provided, use it; otherwise try without config
|
||||
try:
|
||||
if config_name:
|
||||
dataset_dict = load_dataset(dataset_name, config_name)
|
||||
else:
|
||||
dataset_dict = load_dataset(dataset_name)
|
||||
except ValueError as e:
|
||||
if "Config name is missing" in str(e):
|
||||
# Try to get available configs and suggest
|
||||
from datasets import get_dataset_config_names
|
||||
try:
|
||||
available_configs = get_dataset_config_names(dataset_name)
|
||||
raise ValueError(
|
||||
f"Dataset '{dataset_name}' requires a config name. "
|
||||
f"Available configs: {available_configs}. "
|
||||
f"Please specify as: ('{dataset_name}', 'config_name')"
|
||||
) from e
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"Dataset '{dataset_name}' requires a config name. "
|
||||
f"Please specify as: ('{dataset_name}', 'config_name')"
|
||||
) from e
|
||||
raise
|
||||
|
||||
# Determine which splits to load
|
||||
if DATASET_SPLITS is None:
|
||||
# Auto-detect: try to load all available splits
|
||||
available_splits = list(dataset_dict.keys())
|
||||
print(f" Auto-detected splits: {available_splits}")
|
||||
splits_to_load = available_splits
|
||||
else:
|
||||
splits_to_load = DATASET_SPLITS
|
||||
|
||||
# Load and concatenate multiple splits for this dataset
|
||||
datasets_to_concat: list[Dataset] = []
|
||||
for split in splits_to_load:
|
||||
if split not in dataset_dict:
|
||||
print(f" Warning: Split '{split}' not found in dataset. Available splits: {list(dataset_dict.keys())}")
|
||||
continue
|
||||
split_dataset = cast(Dataset, dataset_dict[split])
|
||||
print(f" Loaded split '{split}': {len(split_dataset)} pages")
|
||||
datasets_to_concat.append(split_dataset)
|
||||
|
||||
if not datasets_to_concat:
|
||||
print(f" Warning: No valid splits found for {dataset_name}. Skipping.")
|
||||
continue
|
||||
|
||||
# Concatenate splits for this dataset
|
||||
if len(datasets_to_concat) > 1:
|
||||
combined_dataset = concatenate_datasets(datasets_to_concat)
|
||||
print(f" Concatenated {len(datasets_to_concat)} splits into {len(combined_dataset)} pages")
|
||||
else:
|
||||
combined_dataset = datasets_to_concat[0]
|
||||
|
||||
all_datasets_to_concat.append(combined_dataset)
|
||||
|
||||
if not all_datasets_to_concat:
|
||||
raise RuntimeError("No valid datasets or splits found.")
|
||||
|
||||
# Concatenate all datasets
|
||||
if len(all_datasets_to_concat) > 1:
|
||||
dataset = concatenate_datasets(all_datasets_to_concat)
|
||||
print(f"\nConcatenated {len(all_datasets_to_concat)} datasets into {len(dataset)} total pages")
|
||||
else:
|
||||
dataset = all_datasets_to_concat[0]
|
||||
|
||||
# Apply MAX_DOCS limit if specified
|
||||
N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
|
||||
if N < len(dataset):
|
||||
print(f"Limiting to {N} pages (from {len(dataset)} total)")
|
||||
dataset = dataset.select(range(N))
|
||||
|
||||
# Auto-detect image field name if not specified
|
||||
if IMAGE_FIELD_NAME is None:
|
||||
# Check multiple samples to find the most common image field
|
||||
# (useful when datasets are merged and may have different field names)
|
||||
possible_image_fields = ["page_image", "image", "images", "img", "page", "document_image"]
|
||||
field_counts = {}
|
||||
|
||||
# Check first few samples to find image fields
|
||||
num_samples_to_check = min(10, len(dataset))
|
||||
for sample_idx in range(num_samples_to_check):
|
||||
sample = dataset[sample_idx]
|
||||
for field in possible_image_fields:
|
||||
if field in sample and sample[field] is not None:
|
||||
value = sample[field]
|
||||
if isinstance(value, Image.Image) or (hasattr(value, 'size') and hasattr(value, 'mode')):
|
||||
field_counts[field] = field_counts.get(field, 0) + 1
|
||||
|
||||
# Choose the most common field, or first found if tied
|
||||
if field_counts:
|
||||
image_field = max(field_counts.items(), key=lambda x: x[1])[0]
|
||||
print(f"Auto-detected image field: '{image_field}' (found in {field_counts[image_field]}/{num_samples_to_check} samples)")
|
||||
else:
|
||||
# Fallback: check first sample only
|
||||
sample = dataset[0]
|
||||
image_field = None
|
||||
for field in possible_image_fields:
|
||||
if field in sample:
|
||||
value = sample[field]
|
||||
if isinstance(value, Image.Image) or (hasattr(value, 'size') and hasattr(value, 'mode')):
|
||||
image_field = field
|
||||
break
|
||||
if image_field is None:
|
||||
raise RuntimeError(
|
||||
f"Could not auto-detect image field. Available fields: {list(sample.keys())}. "
|
||||
f"Please specify IMAGE_FIELD_NAME manually."
|
||||
)
|
||||
print(f"Auto-detected image field: '{image_field}'")
|
||||
else:
|
||||
image_field = IMAGE_FIELD_NAME
|
||||
if image_field not in dataset[0]:
|
||||
raise RuntimeError(
|
||||
f"Image field '{image_field}' not found. Available fields: {list(dataset[0].keys())}"
|
||||
)
|
||||
|
||||
filepaths: list[str] = []
|
||||
images: list[Image.Image] = []
|
||||
for i in tqdm(range(len(dataset)), desc="Loading dataset", total=len(dataset)):
|
||||
p = dataset[i]
|
||||
# Try to compose a descriptive identifier
|
||||
# Handle different dataset structures
|
||||
identifier_parts = []
|
||||
|
||||
# Helper function to safely get field value
|
||||
def safe_get(field_name, default=None):
|
||||
if field_name in p and p[field_name] is not None:
|
||||
return p[field_name]
|
||||
return default
|
||||
|
||||
# Try to get various identifier fields
|
||||
if safe_get("paper_arxiv_id"):
|
||||
identifier_parts.append(f"arXiv:{p['paper_arxiv_id']}")
|
||||
if safe_get("paper_title"):
|
||||
identifier_parts.append(f"title:{p['paper_title']}")
|
||||
if safe_get("page_number") is not None:
|
||||
try:
|
||||
identifier_parts.append(f"page:{int(p['page_number'])}")
|
||||
except (ValueError, TypeError):
|
||||
# If conversion fails, use the raw value or skip
|
||||
if p['page_number']:
|
||||
identifier_parts.append(f"page:{p['page_number']}")
|
||||
if safe_get("page_id"):
|
||||
identifier_parts.append(f"id:{p['page_id']}")
|
||||
elif safe_get("questionId"):
|
||||
identifier_parts.append(f"qid:{p['questionId']}")
|
||||
elif safe_get("docId"):
|
||||
identifier_parts.append(f"docId:{p['docId']}")
|
||||
elif safe_get("id"):
|
||||
identifier_parts.append(f"id:{p['id']}")
|
||||
|
||||
# If no identifier parts found, create one from index
|
||||
if identifier_parts:
|
||||
identifier = "|".join(identifier_parts)
|
||||
else:
|
||||
# Create identifier from available fields or index
|
||||
fallback_parts = []
|
||||
# Try common fields that might exist
|
||||
for field in ["ucsf_document_id", "docId", "questionId", "id"]:
|
||||
if safe_get(field):
|
||||
fallback_parts.append(f"{field}:{p[field]}")
|
||||
break
|
||||
if fallback_parts:
|
||||
identifier = "|".join(fallback_parts) + f"|idx:{i}"
|
||||
else:
|
||||
identifier = f"doc_{i}"
|
||||
|
||||
filepaths.append(identifier)
|
||||
|
||||
# Get image - try detected field first, then fallback to other common fields
|
||||
img = None
|
||||
if image_field in p and p[image_field] is not None:
|
||||
img = p[image_field]
|
||||
else:
|
||||
# Fallback: try other common image field names
|
||||
for fallback_field in ["image", "page_image", "images", "img"]:
|
||||
if fallback_field in p and p[fallback_field] is not None:
|
||||
img = p[fallback_field]
|
||||
break
|
||||
|
||||
if img is None:
|
||||
raise RuntimeError(
|
||||
f"No image found for sample {i}. Available fields: {list(p.keys())}. "
|
||||
f"Expected field: {image_field}"
|
||||
)
|
||||
|
||||
# Ensure it's a PIL Image
|
||||
if not isinstance(img, Image.Image):
|
||||
if hasattr(img, 'convert'):
|
||||
img = img.convert('RGB')
|
||||
else:
|
||||
img = Image.fromarray(img) if hasattr(img, '__array__') else Image.open(img)
|
||||
images.append(img)
|
||||
else:
|
||||
_maybe_convert_pdf_to_images(PDF, PAGES_DIR)
|
||||
filepaths, images = _load_images_from_dir(PAGES_DIR)
|
||||
if not images:
|
||||
raise RuntimeError(
|
||||
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
|
||||
)
|
||||
print(f"Loaded {len(images)} images")
|
||||
|
||||
# Memory check before loading model
|
||||
try:
|
||||
import psutil
|
||||
import torch
|
||||
process = psutil.Process(os.getpid())
|
||||
mem_info = process.memory_info()
|
||||
print(f"Memory usage after loading images: {mem_info.rss / 1024 / 1024 / 1024:.2f} GB")
|
||||
if torch.cuda.is_available():
|
||||
print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
|
||||
print(f"GPU memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
print("Skipping dataset loading (using existing index)")
|
||||
filepaths = [] # Not needed when using existing index
|
||||
images = [] # Not needed when using existing index
|
||||
|
||||
|
||||
# %%
|
||||
# Step 3: Load model and processor (only if we need to build index or perform search)
|
||||
print("Step 3: Loading model and processor...")
|
||||
print(f" Model: {MODEL}")
|
||||
try:
|
||||
import sys
|
||||
print(f" Python version: {sys.version}")
|
||||
print(f" Python executable: {sys.executable}")
|
||||
|
||||
model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
|
||||
print(f"✓ Using model={model_name}, device={device_str}, dtype={dtype}")
|
||||
|
||||
# Memory check after loading model
|
||||
try:
|
||||
import psutil
|
||||
import torch
|
||||
process = psutil.Process(os.getpid())
|
||||
mem_info = process.memory_info()
|
||||
print(f" Memory usage after loading model: {mem_info.rss / 1024 / 1024 / 1024:.2f} GB")
|
||||
if torch.cuda.is_available():
|
||||
print(f" GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
|
||||
print(f" GPU memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"✗ Error loading model: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
|
||||
# %%
|
||||
|
||||
# %%
|
||||
# Step 4: Build index if needed
|
||||
if need_to_build_index:
|
||||
print("Step 4: Building index...")
|
||||
print(f" Number of images: {len(images)}")
|
||||
print(f" Number of filepaths: {len(filepaths)}")
|
||||
|
||||
try:
|
||||
print(" Embedding images...")
|
||||
doc_vecs = _embed_images(model, processor, images)
|
||||
print(f" Embedded {len(doc_vecs)} documents")
|
||||
print(f" First doc vec shape: {doc_vecs[0].shape if len(doc_vecs) > 0 else 'N/A'}")
|
||||
except Exception as e:
|
||||
print(f"Error embedding images: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
if USE_FAST_PLAID:
|
||||
# Build Fast-Plaid index
|
||||
print(" Building Fast-Plaid index...")
|
||||
try:
|
||||
fast_plaid_index, build_secs = _build_fast_plaid_index(
|
||||
FAST_PLAID_INDEX_PATH, doc_vecs, filepaths, images
|
||||
)
|
||||
from pathlib import Path
|
||||
print(f"✓ Fast-Plaid index built in {build_secs:.3f}s")
|
||||
print(f"✓ Index saved to: {FAST_PLAID_INDEX_PATH}")
|
||||
print(f"✓ Images saved to: {Path(FAST_PLAID_INDEX_PATH) / 'images'}")
|
||||
except Exception as e:
|
||||
print(f"Error building Fast-Plaid index: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise
|
||||
finally:
|
||||
# Clear memory
|
||||
print(" Clearing memory...")
|
||||
del images, filepaths, doc_vecs
|
||||
else:
|
||||
# Build original LEANN index
|
||||
try:
|
||||
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths, images)
|
||||
print(f"✓ Index built and images saved to: {retriever._images_dir_path()}")
|
||||
except Exception as e:
|
||||
print(f"Error building LEANN index: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise
|
||||
finally:
|
||||
# Clear memory
|
||||
print(" Clearing memory...")
|
||||
del images, filepaths, doc_vecs
|
||||
|
||||
# Note: Images are now stored separately, retriever/fast_plaid_index will reference them
|
||||
|
||||
|
||||
# %%
|
||||
# Step 5: Embed query and search
|
||||
_t0 = time.perf_counter()
|
||||
q_vec = _embed_queries(model, processor, [QUERY])[0]
|
||||
query_embed_secs = time.perf_counter() - _t0
|
||||
|
||||
print(f"[Search] Method: {SEARCH_METHOD}")
|
||||
print(f"[Timing] Query embedding: {query_embed_secs:.3f}s")
|
||||
|
||||
# Run the selected search method and time it
|
||||
if USE_FAST_PLAID:
|
||||
# Fast-Plaid search
|
||||
if fast_plaid_index is None:
|
||||
fast_plaid_index = _load_fast_plaid_index_if_exists(FAST_PLAID_INDEX_PATH)
|
||||
if fast_plaid_index is None:
|
||||
raise RuntimeError(f"Fast-Plaid index not found at {FAST_PLAID_INDEX_PATH}")
|
||||
|
||||
results, search_secs = _search_fast_plaid(fast_plaid_index, q_vec, TOPK)
|
||||
print(f"[Timing] Fast-Plaid Search: {search_secs:.3f}s")
|
||||
else:
|
||||
# Original LEANN search
|
||||
query_np = q_vec.float().numpy()
|
||||
|
||||
if SEARCH_METHOD == "ann":
|
||||
results = retriever.search(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K)
|
||||
search_secs = time.perf_counter() - _t0
|
||||
print(f"[Timing] Search (ANN): {search_secs:.3f}s (first_stage_k={FIRST_STAGE_K})")
|
||||
elif SEARCH_METHOD == "exact":
|
||||
results = retriever.search_exact(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K)
|
||||
search_secs = time.perf_counter() - _t0
|
||||
print(f"[Timing] Search (Exact rerank): {search_secs:.3f}s (first_stage_k={FIRST_STAGE_K})")
|
||||
elif SEARCH_METHOD == "exact-all":
|
||||
results = retriever.search_exact_all(query_np, topk=TOPK)
|
||||
search_secs = time.perf_counter() - _t0
|
||||
print(f"[Timing] Search (Exact all): {search_secs:.3f}s")
|
||||
else:
|
||||
results = []
|
||||
if not results:
|
||||
print("No results found.")
|
||||
else:
|
||||
print(f'Top {len(results)} results for query: "{QUERY}"')
|
||||
print("\n[DEBUG] Retrieval details:")
|
||||
top_images: list[Image.Image] = []
|
||||
image_hashes = {} # Track image hashes to detect duplicates
|
||||
|
||||
for rank, (score, doc_id) in enumerate(results, start=1):
|
||||
# Retrieve image and metadata based on index type
|
||||
if USE_FAST_PLAID:
|
||||
# Fast-Plaid: load image and get metadata
|
||||
image = _get_fast_plaid_image(FAST_PLAID_INDEX_PATH, doc_id)
|
||||
if image is None:
|
||||
print(f"Warning: Could not find image for doc_id {doc_id}")
|
||||
continue
|
||||
|
||||
metadata = _get_fast_plaid_metadata(FAST_PLAID_INDEX_PATH, doc_id)
|
||||
path = metadata.get("filepath", f"doc_{doc_id}") if metadata else f"doc_{doc_id}"
|
||||
top_images.append(image)
|
||||
else:
|
||||
# Original LEANN: retrieve from retriever
|
||||
image = retriever.get_image(doc_id)
|
||||
if image is None:
|
||||
print(f"Warning: Could not retrieve image for doc_id {doc_id}")
|
||||
continue
|
||||
|
||||
metadata = retriever.get_metadata(doc_id)
|
||||
path = metadata.get("filepath", "unknown") if metadata else "unknown"
|
||||
top_images.append(image)
|
||||
|
||||
# Calculate image hash to detect duplicates
|
||||
import hashlib
|
||||
import io
|
||||
# Convert image to bytes for hashing
|
||||
img_bytes = io.BytesIO()
|
||||
image.save(img_bytes, format='PNG')
|
||||
image_bytes = img_bytes.getvalue()
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()[:8]
|
||||
|
||||
# Check if this image was already seen
|
||||
duplicate_info = ""
|
||||
if image_hash in image_hashes:
|
||||
duplicate_info = f" [DUPLICATE of rank {image_hashes[image_hash]}]"
|
||||
else:
|
||||
image_hashes[image_hash] = rank
|
||||
|
||||
# Print detailed information
|
||||
print(f"{rank}) doc_id={doc_id}, MaxSim={score:.4f}, Page={path}, ImageHash={image_hash}{duplicate_info}")
|
||||
if metadata:
|
||||
print(f" Metadata: {metadata}")
|
||||
|
||||
if SAVE_TOP_IMAGE:
|
||||
from pathlib import Path as _Path
|
||||
|
||||
base = _Path(SAVE_TOP_IMAGE)
|
||||
base.parent.mkdir(parents=True, exist_ok=True)
|
||||
for rank, img in enumerate(top_images[:TOPK], start=1):
|
||||
if base.suffix:
|
||||
out_path = base.parent / f"{base.stem}_rank{rank}{base.suffix}"
|
||||
else:
|
||||
out_path = base / f"retrieved_page_rank{rank}.png"
|
||||
img.save(str(out_path))
|
||||
# Print the retrieval score (document-level MaxSim) alongside the saved path
|
||||
try:
|
||||
score, _doc_id = results[rank - 1]
|
||||
print(f"Saved retrieved page (rank {rank}) [MaxSim={score:.4f}] to: {out_path}")
|
||||
except Exception:
|
||||
print(f"Saved retrieved page (rank {rank}) to: {out_path}")
|
||||
|
||||
|
||||
# %%
|
||||
# Step 6: Similarity maps for top-K results
|
||||
if results and SIMILARITY_MAP:
|
||||
token_idx = None if SIM_TOKEN_IDX < 0 else int(SIM_TOKEN_IDX)
|
||||
from pathlib import Path as _Path
|
||||
|
||||
output_base = _Path(SIM_OUTPUT) if SIM_OUTPUT else None
|
||||
for rank, img in enumerate(top_images[:TOPK], start=1):
|
||||
if output_base:
|
||||
if output_base.suffix:
|
||||
out_dir = output_base.parent
|
||||
out_name = f"{output_base.stem}_rank{rank}{output_base.suffix}"
|
||||
out_path = str(out_dir / out_name)
|
||||
else:
|
||||
out_dir = output_base
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
out_path = str(out_dir / f"similarity_map_rank{rank}.png")
|
||||
else:
|
||||
out_path = None
|
||||
chosen_idx, max_sim = _generate_similarity_map(
|
||||
model=model,
|
||||
processor=processor,
|
||||
image=img,
|
||||
query=QUERY,
|
||||
token_idx=token_idx,
|
||||
output_path=out_path,
|
||||
)
|
||||
if out_path:
|
||||
print(
|
||||
f"Saved similarity map for rank {rank}, token #{chosen_idx} (max={max_sim:.2f}) to: {out_path}"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"Computed similarity map for rank {rank}, token #{chosen_idx} (max={max_sim:.2f})"
|
||||
)
|
||||
|
||||
|
||||
# %%
|
||||
# Step 7: Optional answer generation
|
||||
if results and ANSWER:
|
||||
qwen = QwenVL(device=device_str)
|
||||
_t0 = time.perf_counter()
|
||||
response = qwen.answer(QUERY, top_images[:TOPK], max_new_tokens=MAX_NEW_TOKENS)
|
||||
gen_secs = time.perf_counter() - _t0
|
||||
print(f"[Timing] Generation: {gen_secs:.3f}s")
|
||||
print("\nAnswer:")
|
||||
print(response)
|
||||
@@ -0,0 +1,451 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Modular script to reproduce NDCG results for ViDoRe v1 benchmark.
|
||||
|
||||
This script uses the interface from leann_multi_vector.py to:
|
||||
1. Download ViDoRe v1 datasets
|
||||
2. Build indexes (LEANN or Fast-Plaid)
|
||||
3. Perform retrieval
|
||||
4. Evaluate using NDCG metrics
|
||||
|
||||
Usage:
|
||||
# Evaluate all ViDoRe v1 tasks
|
||||
python vidore_v1_benchmark.py --model colqwen2 --tasks all
|
||||
|
||||
# Evaluate specific task
|
||||
python vidore_v1_benchmark.py --model colqwen2 --task VidoreArxivQARetrieval
|
||||
|
||||
# Use Fast-Plaid index
|
||||
python vidore_v1_benchmark.py --model colqwen2 --use-fast-plaid --fast-plaid-index-path ./indexes/vidore_fastplaid
|
||||
|
||||
# Rebuild index
|
||||
python vidore_v1_benchmark.py --model colqwen2 --rebuild-index
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from datasets import Dataset, load_dataset
|
||||
from leann_multi_vector import (
|
||||
ViDoReBenchmarkEvaluator,
|
||||
_ensure_repo_paths_importable,
|
||||
)
|
||||
|
||||
_ensure_repo_paths_importable(__file__)
|
||||
|
||||
# ViDoRe v1 task configurations
|
||||
# Prompts match MTEB task metadata prompts
|
||||
VIDORE_V1_TASKS = {
|
||||
"VidoreArxivQARetrieval": {
|
||||
"dataset_path": "vidore/arxivqa_test_subsampled_beir",
|
||||
"revision": "7d94d570960eac2408d3baa7a33f9de4822ae3e4",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreDocVQARetrieval": {
|
||||
"dataset_path": "vidore/docvqa_test_subsampled_beir",
|
||||
"revision": "162ba2fc1a8437eda8b6c37b240bc1c0f0deb092",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreInfoVQARetrieval": {
|
||||
"dataset_path": "vidore/infovqa_test_subsampled_beir",
|
||||
"revision": "b802cc5fd6c605df2d673a963667d74881d2c9a4",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreTabfquadRetrieval": {
|
||||
"dataset_path": "vidore/tabfquad_test_subsampled_beir",
|
||||
"revision": "61a2224bcd29b7b261a4892ff4c8bea353527a31",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreTatdqaRetrieval": {
|
||||
"dataset_path": "vidore/tatdqa_test_beir",
|
||||
"revision": "5feb5630fdff4d8d189ffedb2dba56862fdd45c0",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreShiftProjectRetrieval": {
|
||||
"dataset_path": "vidore/shiftproject_test_beir",
|
||||
"revision": "84a382e05c4473fed9cff2bbae95fe2379416117",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreSyntheticDocQAAIRetrieval": {
|
||||
"dataset_path": "vidore/syntheticDocQA_artificial_intelligence_test_beir",
|
||||
"revision": "2d9ebea5a1c6e9ef4a3b902a612f605dca11261c",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreSyntheticDocQAEnergyRetrieval": {
|
||||
"dataset_path": "vidore/syntheticDocQA_energy_test_beir",
|
||||
"revision": "9935aadbad5c8deec30910489db1b2c7133ae7a7",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreSyntheticDocQAGovernmentReportsRetrieval": {
|
||||
"dataset_path": "vidore/syntheticDocQA_government_reports_test_beir",
|
||||
"revision": "b4909afa930f81282fd20601e860668073ad02aa",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreSyntheticDocQAHealthcareIndustryRetrieval": {
|
||||
"dataset_path": "vidore/syntheticDocQA_healthcare_industry_test_beir",
|
||||
"revision": "f9e25d5b6e13e1ad9f5c3cce202565031b3ab164",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
}
|
||||
|
||||
# Task name aliases (short names -> full names)
|
||||
TASK_ALIASES = {
|
||||
"arxivqa": "VidoreArxivQARetrieval",
|
||||
"docvqa": "VidoreDocVQARetrieval",
|
||||
"infovqa": "VidoreInfoVQARetrieval",
|
||||
"tabfquad": "VidoreTabfquadRetrieval",
|
||||
"tatdqa": "VidoreTatdqaRetrieval",
|
||||
"shiftproject": "VidoreShiftProjectRetrieval",
|
||||
"syntheticdocqa_ai": "VidoreSyntheticDocQAAIRetrieval",
|
||||
"syntheticdocqa_energy": "VidoreSyntheticDocQAEnergyRetrieval",
|
||||
"syntheticdocqa_government": "VidoreSyntheticDocQAGovernmentReportsRetrieval",
|
||||
"syntheticdocqa_healthcare": "VidoreSyntheticDocQAHealthcareIndustryRetrieval",
|
||||
}
|
||||
|
||||
|
||||
def normalize_task_name(task_name: str) -> str:
|
||||
"""Normalize task name (handle aliases)."""
|
||||
task_name_lower = task_name.lower()
|
||||
if task_name in VIDORE_V1_TASKS:
|
||||
return task_name
|
||||
if task_name_lower in TASK_ALIASES:
|
||||
return TASK_ALIASES[task_name_lower]
|
||||
# Try partial match
|
||||
for alias, full_name in TASK_ALIASES.items():
|
||||
if alias in task_name_lower or task_name_lower in alias:
|
||||
return full_name
|
||||
return task_name
|
||||
|
||||
|
||||
def get_safe_model_name(model_name: str) -> str:
|
||||
"""Get a safe model name for use in file paths."""
|
||||
import hashlib
|
||||
import os
|
||||
|
||||
# If it's a path, use basename or hash
|
||||
if os.path.exists(model_name) and os.path.isdir(model_name):
|
||||
# Use basename if it's reasonable, otherwise use hash
|
||||
basename = os.path.basename(model_name.rstrip("/"))
|
||||
if basename and len(basename) < 100 and not basename.startswith("."):
|
||||
return basename
|
||||
# Use hash for very long or problematic paths
|
||||
return hashlib.md5(model_name.encode()).hexdigest()[:16]
|
||||
# For HuggingFace model names, replace / with _
|
||||
return model_name.replace("/", "_").replace(":", "_")
|
||||
|
||||
|
||||
def load_vidore_v1_data(
|
||||
dataset_path: str,
|
||||
revision: Optional[str] = None,
|
||||
split: str = "test",
|
||||
):
|
||||
"""
|
||||
Load ViDoRe v1 dataset.
|
||||
|
||||
Returns:
|
||||
corpus: dict mapping corpus_id to PIL Image
|
||||
queries: dict mapping query_id to query text
|
||||
qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
|
||||
"""
|
||||
print(f"Loading dataset: {dataset_path} (split={split})")
|
||||
|
||||
# Load queries - cast to Dataset since we know split returns Dataset not DatasetDict
|
||||
query_ds = cast(Dataset, load_dataset(dataset_path, "queries", split=split, revision=revision))
|
||||
|
||||
queries: dict[str, str] = {}
|
||||
for row in query_ds:
|
||||
row_dict = cast(dict[str, Any], row)
|
||||
query_id = f"query-{split}-{row_dict['query-id']}"
|
||||
queries[query_id] = row_dict["query"]
|
||||
|
||||
# Load corpus (images) - cast to Dataset
|
||||
corpus_ds = cast(Dataset, load_dataset(dataset_path, "corpus", split=split, revision=revision))
|
||||
|
||||
corpus: dict[str, Any] = {}
|
||||
for row in corpus_ds:
|
||||
row_dict = cast(dict[str, Any], row)
|
||||
corpus_id = f"corpus-{split}-{row_dict['corpus-id']}"
|
||||
# Extract image from the dataset row
|
||||
if "image" in row_dict:
|
||||
corpus[corpus_id] = row_dict["image"]
|
||||
elif "page_image" in row_dict:
|
||||
corpus[corpus_id] = row_dict["page_image"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"No image field found in corpus. Available fields: {list(row_dict.keys())}"
|
||||
)
|
||||
|
||||
# Load qrels (relevance judgments) - cast to Dataset
|
||||
qrels_ds = cast(Dataset, load_dataset(dataset_path, "qrels", split=split, revision=revision))
|
||||
|
||||
qrels: dict[str, dict[str, int]] = {}
|
||||
for row in qrels_ds:
|
||||
row_dict = cast(dict[str, Any], row)
|
||||
query_id = f"query-{split}-{row_dict['query-id']}"
|
||||
corpus_id = f"corpus-{split}-{row_dict['corpus-id']}"
|
||||
if query_id not in qrels:
|
||||
qrels[query_id] = {}
|
||||
qrels[query_id][corpus_id] = int(row_dict["score"])
|
||||
|
||||
print(
|
||||
f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
|
||||
)
|
||||
|
||||
# Filter qrels to only include queries that exist
|
||||
qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries}
|
||||
|
||||
# Filter out queries without any relevant documents (matching MTEB behavior)
|
||||
# This is important for correct NDCG calculation
|
||||
qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0}
|
||||
queries_filtered = {
|
||||
qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered
|
||||
}
|
||||
|
||||
print(
|
||||
f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings"
|
||||
)
|
||||
|
||||
return corpus, queries_filtered, qrels_filtered
|
||||
|
||||
|
||||
def evaluate_task(
|
||||
task_name: str,
|
||||
model_name: str,
|
||||
index_path: str,
|
||||
use_fast_plaid: bool = False,
|
||||
fast_plaid_index_path: Optional[str] = None,
|
||||
rebuild_index: bool = False,
|
||||
top_k: int = 1000,
|
||||
first_stage_k: int = 500,
|
||||
k_values: Optional[list[int]] = None,
|
||||
output_dir: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Evaluate a single ViDoRe v1 task.
|
||||
"""
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Evaluating task: {task_name}")
|
||||
print(f"{'=' * 80}")
|
||||
|
||||
# Normalize task name (handle aliases)
|
||||
task_name = normalize_task_name(task_name)
|
||||
|
||||
# Get task config
|
||||
if task_name not in VIDORE_V1_TASKS:
|
||||
raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V1_TASKS.keys())}")
|
||||
|
||||
task_config = VIDORE_V1_TASKS[task_name]
|
||||
dataset_path = str(task_config["dataset_path"])
|
||||
revision = str(task_config["revision"])
|
||||
|
||||
# Load data
|
||||
corpus, queries, qrels = load_vidore_v1_data(
|
||||
dataset_path=dataset_path,
|
||||
revision=revision,
|
||||
split="test",
|
||||
)
|
||||
|
||||
# Initialize k_values if not provided
|
||||
if k_values is None:
|
||||
k_values = [1, 3, 5, 10, 20, 100, 1000]
|
||||
|
||||
# Check if we have any queries
|
||||
if len(queries) == 0:
|
||||
print(f"\nWarning: No queries found for task {task_name}. Skipping evaluation.")
|
||||
# Return zero scores
|
||||
scores = {}
|
||||
for k in k_values:
|
||||
scores[f"ndcg_at_{k}"] = 0.0
|
||||
scores[f"map_at_{k}"] = 0.0
|
||||
scores[f"recall_at_{k}"] = 0.0
|
||||
scores[f"precision_at_{k}"] = 0.0
|
||||
scores[f"mrr_at_{k}"] = 0.0
|
||||
return scores
|
||||
|
||||
# Initialize evaluator
|
||||
evaluator = ViDoReBenchmarkEvaluator(
|
||||
model_name=model_name,
|
||||
use_fast_plaid=use_fast_plaid,
|
||||
top_k=top_k,
|
||||
first_stage_k=first_stage_k,
|
||||
k_values=k_values,
|
||||
)
|
||||
|
||||
# Build or load index
|
||||
# Use safe model name for index path (different models need different indexes)
|
||||
safe_model_name = get_safe_model_name(model_name)
|
||||
index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path
|
||||
if index_path_full is None:
|
||||
index_path_full = f"./indexes/{task_name}_{safe_model_name}"
|
||||
if use_fast_plaid:
|
||||
index_path_full = f"./indexes/{task_name}_{safe_model_name}_fastplaid"
|
||||
|
||||
index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus(
|
||||
corpus=corpus,
|
||||
index_path=index_path_full,
|
||||
rebuild=rebuild_index,
|
||||
)
|
||||
|
||||
# Search queries
|
||||
task_prompt = cast(Optional[dict[str, str]], task_config.get("prompt"))
|
||||
results = evaluator.search_queries(
|
||||
queries=queries,
|
||||
corpus_ids=corpus_ids_ordered,
|
||||
index_or_retriever=index_or_retriever,
|
||||
fast_plaid_index_path=fast_plaid_index_path,
|
||||
task_prompt=task_prompt,
|
||||
)
|
||||
|
||||
# Evaluate
|
||||
scores = evaluator.evaluate_results(results, qrels, k_values=k_values)
|
||||
|
||||
# Print results
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Results for {task_name}:")
|
||||
print(f"{'=' * 80}")
|
||||
for metric, value in scores.items():
|
||||
if isinstance(value, (int, float)):
|
||||
print(f" {metric}: {value:.5f}")
|
||||
|
||||
# Save results
|
||||
if output_dir:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
results_file = os.path.join(output_dir, f"{task_name}_results.json")
|
||||
scores_file = os.path.join(output_dir, f"{task_name}_scores.json")
|
||||
|
||||
with open(results_file, "w") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
print(f"\nSaved results to: {results_file}")
|
||||
|
||||
with open(scores_file, "w") as f:
|
||||
json.dump(scores, f, indent=2)
|
||||
print(f"Saved scores to: {scores_file}")
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Evaluate ViDoRe v1 benchmark using LEANN/Fast-Plaid indexing"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="colqwen2",
|
||||
help="Model to use: 'colqwen2', 'colpali', or path to a model directory (supports LoRA adapters)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Specific task to evaluate (or 'all' for all tasks)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tasks",
|
||||
type=str,
|
||||
default="all",
|
||||
help="Tasks to evaluate: 'all' or comma-separated list",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to LEANN index (auto-generated if not provided)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-fast-plaid",
|
||||
action="store_true",
|
||||
help="Use Fast-Plaid instead of LEANN",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fast-plaid-index-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to Fast-Plaid index (auto-generated if not provided)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rebuild-index",
|
||||
action="store_true",
|
||||
help="Rebuild index even if it exists",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Top-k results to retrieve (MTEB default is max(k_values)=1000)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--first-stage-k",
|
||||
type=int,
|
||||
default=500,
|
||||
help="First stage k for LEANN search",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--k-values",
|
||||
type=str,
|
||||
default="1,3,5,10,20,100,1000",
|
||||
help="Comma-separated k values for evaluation (e.g., '1,3,5,10,100')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default="./vidore_v1_results",
|
||||
help="Output directory for results",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Parse k_values
|
||||
k_values = [int(k.strip()) for k in args.k_values.split(",")]
|
||||
|
||||
# Determine tasks to evaluate
|
||||
if args.task:
|
||||
tasks_to_eval = [normalize_task_name(args.task)]
|
||||
elif args.tasks.lower() == "all":
|
||||
tasks_to_eval = list(VIDORE_V1_TASKS.keys())
|
||||
else:
|
||||
tasks_to_eval = [normalize_task_name(t.strip()) for t in args.tasks.split(",")]
|
||||
|
||||
print(f"Tasks to evaluate: {tasks_to_eval}")
|
||||
|
||||
# Evaluate each task
|
||||
all_scores = {}
|
||||
for task_name in tasks_to_eval:
|
||||
try:
|
||||
scores = evaluate_task(
|
||||
task_name=task_name,
|
||||
model_name=args.model,
|
||||
index_path=args.index_path,
|
||||
use_fast_plaid=args.use_fast_plaid,
|
||||
fast_plaid_index_path=args.fast_plaid_index_path,
|
||||
rebuild_index=args.rebuild_index,
|
||||
top_k=args.top_k,
|
||||
first_stage_k=args.first_stage_k,
|
||||
k_values=k_values,
|
||||
output_dir=args.output_dir,
|
||||
)
|
||||
all_scores[task_name] = scores
|
||||
except Exception as e:
|
||||
print(f"\nError evaluating {task_name}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
# Print summary
|
||||
if all_scores:
|
||||
print(f"\n{'=' * 80}")
|
||||
print("SUMMARY")
|
||||
print(f"{'=' * 80}")
|
||||
for task_name, scores in all_scores.items():
|
||||
print(f"\n{task_name}:")
|
||||
# Print main metrics
|
||||
for metric in ["ndcg_at_5", "ndcg_at_10", "ndcg_at_100", "map_at_10", "recall_at_10"]:
|
||||
if metric in scores:
|
||||
print(f" {metric}: {scores[metric]:.5f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,443 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Modular script to reproduce NDCG results for ViDoRe v2 benchmark.
|
||||
|
||||
This script uses the interface from leann_multi_vector.py to:
|
||||
1. Download ViDoRe v2 datasets
|
||||
2. Build indexes (LEANN or Fast-Plaid)
|
||||
3. Perform retrieval
|
||||
4. Evaluate using NDCG metrics
|
||||
|
||||
Usage:
|
||||
# Evaluate all ViDoRe v2 tasks
|
||||
python vidore_v2_benchmark.py --model colqwen2 --tasks all
|
||||
|
||||
# Evaluate specific task
|
||||
python vidore_v2_benchmark.py --model colqwen2 --task Vidore2ESGReportsRetrieval
|
||||
|
||||
# Use Fast-Plaid index
|
||||
python vidore_v2_benchmark.py --model colqwen2 --use-fast-plaid --fast-plaid-index-path ./indexes/vidore_fastplaid
|
||||
|
||||
# Rebuild index
|
||||
python vidore_v2_benchmark.py --model colqwen2 --rebuild-index
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from datasets import Dataset, load_dataset
|
||||
from leann_multi_vector import (
|
||||
ViDoReBenchmarkEvaluator,
|
||||
_ensure_repo_paths_importable,
|
||||
)
|
||||
|
||||
_ensure_repo_paths_importable(__file__)
|
||||
|
||||
# Language name to dataset language field value mapping
|
||||
# Dataset uses ISO 639-3 + ISO 15924 format (e.g., "eng-Latn")
|
||||
LANGUAGE_MAPPING = {
|
||||
"english": "eng-Latn",
|
||||
"french": "fra-Latn",
|
||||
"spanish": "spa-Latn",
|
||||
"german": "deu-Latn",
|
||||
}
|
||||
|
||||
# ViDoRe v2 task configurations
|
||||
# Prompts match MTEB task metadata prompts
|
||||
VIDORE_V2_TASKS = {
|
||||
"Vidore2ESGReportsRetrieval": {
|
||||
"dataset_path": "vidore/esg_reports_v2",
|
||||
"revision": "0542c0d03da0ec1c8cbc517c8d78e7e95c75d3d3",
|
||||
"languages": ["french", "spanish", "english", "german"],
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"Vidore2EconomicsReportsRetrieval": {
|
||||
"dataset_path": "vidore/economics_reports_v2",
|
||||
"revision": "b3e3a04b07fbbaffe79be49dabf92f691fbca252",
|
||||
"languages": ["french", "spanish", "english", "german"],
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"Vidore2BioMedicalLecturesRetrieval": {
|
||||
"dataset_path": "vidore/biomedical_lectures_v2",
|
||||
"revision": "a29202f0da409034d651614d87cd8938d254e2ea",
|
||||
"languages": ["french", "spanish", "english", "german"],
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"Vidore2ESGReportsHLRetrieval": {
|
||||
"dataset_path": "vidore/esg_reports_human_labeled_v2",
|
||||
"revision": "6d467dedb09a75144ede1421747e47cf036857dd",
|
||||
# Note: This dataset doesn't have language filtering - all queries are English
|
||||
"languages": None, # No language filtering needed
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def load_vidore_v2_data(
|
||||
dataset_path: str,
|
||||
revision: Optional[str] = None,
|
||||
split: str = "test",
|
||||
language: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Load ViDoRe v2 dataset.
|
||||
|
||||
Returns:
|
||||
corpus: dict mapping corpus_id to PIL Image
|
||||
queries: dict mapping query_id to query text
|
||||
qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
|
||||
"""
|
||||
print(f"Loading dataset: {dataset_path} (split={split}, language={language})")
|
||||
|
||||
# Load queries - cast to Dataset since we know split returns Dataset not DatasetDict
|
||||
query_ds = cast(Dataset, load_dataset(dataset_path, "queries", split=split, revision=revision))
|
||||
|
||||
# Check if dataset has language field before filtering
|
||||
has_language_field = len(query_ds) > 0 and "language" in query_ds.column_names
|
||||
|
||||
if language and has_language_field:
|
||||
# Map language name to dataset language field value (e.g., "english" -> "eng-Latn")
|
||||
dataset_language = LANGUAGE_MAPPING.get(language, language)
|
||||
query_ds_filtered = query_ds.filter(lambda x: x.get("language") == dataset_language)
|
||||
# Check if filtering resulted in empty dataset
|
||||
if len(query_ds_filtered) == 0:
|
||||
print(
|
||||
f"Warning: No queries found after filtering by language '{language}' (mapped to '{dataset_language}')."
|
||||
)
|
||||
# Try with original language value (dataset might use simple names like 'english')
|
||||
print(f"Trying with original language value '{language}'...")
|
||||
query_ds_filtered = query_ds.filter(lambda x: x.get("language") == language)
|
||||
if len(query_ds_filtered) == 0:
|
||||
# Try to get a sample to see actual language values
|
||||
try:
|
||||
sample_ds = cast(
|
||||
Dataset,
|
||||
load_dataset(dataset_path, "queries", split=split, revision=revision),
|
||||
)
|
||||
if len(sample_ds) > 0 and "language" in sample_ds.column_names:
|
||||
sample_langs = set(sample_ds["language"])
|
||||
print(f"Available language values in dataset: {sample_langs}")
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
print(
|
||||
f"Found {len(query_ds_filtered)} queries using original language value '{language}'"
|
||||
)
|
||||
query_ds = query_ds_filtered
|
||||
|
||||
queries: dict[str, str] = {}
|
||||
for row in query_ds:
|
||||
row_dict = cast(dict[str, Any], row)
|
||||
query_id = f"query-{split}-{row_dict['query-id']}"
|
||||
queries[query_id] = row_dict["query"]
|
||||
|
||||
# Load corpus (images) - cast to Dataset
|
||||
corpus_ds = cast(Dataset, load_dataset(dataset_path, "corpus", split=split, revision=revision))
|
||||
|
||||
corpus: dict[str, Any] = {}
|
||||
for row in corpus_ds:
|
||||
row_dict = cast(dict[str, Any], row)
|
||||
corpus_id = f"corpus-{split}-{row_dict['corpus-id']}"
|
||||
# Extract image from the dataset row
|
||||
if "image" in row_dict:
|
||||
corpus[corpus_id] = row_dict["image"]
|
||||
elif "page_image" in row_dict:
|
||||
corpus[corpus_id] = row_dict["page_image"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"No image field found in corpus. Available fields: {list(row_dict.keys())}"
|
||||
)
|
||||
|
||||
# Load qrels (relevance judgments) - cast to Dataset
|
||||
qrels_ds = cast(Dataset, load_dataset(dataset_path, "qrels", split=split, revision=revision))
|
||||
|
||||
qrels: dict[str, dict[str, int]] = {}
|
||||
for row in qrels_ds:
|
||||
row_dict = cast(dict[str, Any], row)
|
||||
query_id = f"query-{split}-{row_dict['query-id']}"
|
||||
corpus_id = f"corpus-{split}-{row_dict['corpus-id']}"
|
||||
if query_id not in qrels:
|
||||
qrels[query_id] = {}
|
||||
qrels[query_id][corpus_id] = int(row_dict["score"])
|
||||
|
||||
print(
|
||||
f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
|
||||
)
|
||||
|
||||
# Filter qrels to only include queries that exist
|
||||
qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries}
|
||||
|
||||
# Filter out queries without any relevant documents (matching MTEB behavior)
|
||||
# This is important for correct NDCG calculation
|
||||
qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0}
|
||||
queries_filtered = {
|
||||
qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered
|
||||
}
|
||||
|
||||
print(
|
||||
f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings"
|
||||
)
|
||||
|
||||
return corpus, queries_filtered, qrels_filtered
|
||||
|
||||
|
||||
def evaluate_task(
|
||||
task_name: str,
|
||||
model_name: str,
|
||||
index_path: str,
|
||||
use_fast_plaid: bool = False,
|
||||
fast_plaid_index_path: Optional[str] = None,
|
||||
language: Optional[str] = None,
|
||||
rebuild_index: bool = False,
|
||||
top_k: int = 100,
|
||||
first_stage_k: int = 500,
|
||||
k_values: Optional[list[int]] = None,
|
||||
output_dir: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Evaluate a single ViDoRe v2 task.
|
||||
"""
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Evaluating task: {task_name}")
|
||||
print(f"{'=' * 80}")
|
||||
|
||||
# Get task config
|
||||
if task_name not in VIDORE_V2_TASKS:
|
||||
raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V2_TASKS.keys())}")
|
||||
|
||||
task_config = VIDORE_V2_TASKS[task_name]
|
||||
dataset_path = str(task_config["dataset_path"])
|
||||
revision = str(task_config["revision"])
|
||||
|
||||
# Determine language
|
||||
if language is None:
|
||||
# Use first language if multiple available
|
||||
languages = cast(Optional[list[str]], task_config.get("languages"))
|
||||
if languages is None:
|
||||
# Task doesn't support language filtering (e.g., Vidore2ESGReportsHLRetrieval)
|
||||
language = None
|
||||
elif len(languages) == 1:
|
||||
language = languages[0]
|
||||
else:
|
||||
language = None
|
||||
|
||||
# Initialize k_values if not provided
|
||||
if k_values is None:
|
||||
k_values = [1, 3, 5, 10, 100]
|
||||
|
||||
# Load data
|
||||
corpus, queries, qrels = load_vidore_v2_data(
|
||||
dataset_path=dataset_path,
|
||||
revision=revision,
|
||||
split="test",
|
||||
language=language,
|
||||
)
|
||||
|
||||
# Check if we have any queries
|
||||
if len(queries) == 0:
|
||||
print(
|
||||
f"\nWarning: No queries found for task {task_name} with language {language}. Skipping evaluation."
|
||||
)
|
||||
# Return zero scores
|
||||
scores = {}
|
||||
for k in k_values:
|
||||
scores[f"ndcg_at_{k}"] = 0.0
|
||||
scores[f"map_at_{k}"] = 0.0
|
||||
scores[f"recall_at_{k}"] = 0.0
|
||||
scores[f"precision_at_{k}"] = 0.0
|
||||
scores[f"mrr_at_{k}"] = 0.0
|
||||
return scores
|
||||
|
||||
# Initialize evaluator
|
||||
evaluator = ViDoReBenchmarkEvaluator(
|
||||
model_name=model_name,
|
||||
use_fast_plaid=use_fast_plaid,
|
||||
top_k=top_k,
|
||||
first_stage_k=first_stage_k,
|
||||
k_values=k_values,
|
||||
)
|
||||
|
||||
# Build or load index
|
||||
index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path
|
||||
if index_path_full is None:
|
||||
index_path_full = f"./indexes/{task_name}_{model_name}"
|
||||
if use_fast_plaid:
|
||||
index_path_full = f"./indexes/{task_name}_{model_name}_fastplaid"
|
||||
|
||||
index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus(
|
||||
corpus=corpus,
|
||||
index_path=index_path_full,
|
||||
rebuild=rebuild_index,
|
||||
)
|
||||
|
||||
# Search queries
|
||||
task_prompt = cast(Optional[dict[str, str]], task_config.get("prompt"))
|
||||
results = evaluator.search_queries(
|
||||
queries=queries,
|
||||
corpus_ids=corpus_ids_ordered,
|
||||
index_or_retriever=index_or_retriever,
|
||||
fast_plaid_index_path=fast_plaid_index_path,
|
||||
task_prompt=task_prompt,
|
||||
)
|
||||
|
||||
# Evaluate
|
||||
scores = evaluator.evaluate_results(results, qrels, k_values=k_values)
|
||||
|
||||
# Print results
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Results for {task_name}:")
|
||||
print(f"{'=' * 80}")
|
||||
for metric, value in scores.items():
|
||||
if isinstance(value, (int, float)):
|
||||
print(f" {metric}: {value:.5f}")
|
||||
|
||||
# Save results
|
||||
if output_dir:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
results_file = os.path.join(output_dir, f"{task_name}_results.json")
|
||||
scores_file = os.path.join(output_dir, f"{task_name}_scores.json")
|
||||
|
||||
with open(results_file, "w") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
print(f"\nSaved results to: {results_file}")
|
||||
|
||||
with open(scores_file, "w") as f:
|
||||
json.dump(scores, f, indent=2)
|
||||
print(f"Saved scores to: {scores_file}")
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Evaluate ViDoRe v2 benchmark using LEANN/Fast-Plaid indexing"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="colqwen2",
|
||||
choices=["colqwen2", "colpali"],
|
||||
help="Model to use",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Specific task to evaluate (or 'all' for all tasks)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tasks",
|
||||
type=str,
|
||||
default="all",
|
||||
help="Tasks to evaluate: 'all' or comma-separated list",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to LEANN index (auto-generated if not provided)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-fast-plaid",
|
||||
action="store_true",
|
||||
help="Use Fast-Plaid instead of LEANN",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fast-plaid-index-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to Fast-Plaid index (auto-generated if not provided)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rebuild-index",
|
||||
action="store_true",
|
||||
help="Rebuild index even if it exists",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--language",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Language to evaluate (default: first available)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Top-k results to retrieve",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--first-stage-k",
|
||||
type=int,
|
||||
default=500,
|
||||
help="First stage k for LEANN search",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--k-values",
|
||||
type=str,
|
||||
default="1,3,5,10,100",
|
||||
help="Comma-separated k values for evaluation (e.g., '1,3,5,10,100')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default="./vidore_v2_results",
|
||||
help="Output directory for results",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Parse k_values
|
||||
k_values = [int(k.strip()) for k in args.k_values.split(",")]
|
||||
|
||||
# Determine tasks to evaluate
|
||||
if args.task:
|
||||
tasks_to_eval = [args.task]
|
||||
elif args.tasks.lower() == "all":
|
||||
tasks_to_eval = list(VIDORE_V2_TASKS.keys())
|
||||
else:
|
||||
tasks_to_eval = [t.strip() for t in args.tasks.split(",")]
|
||||
|
||||
print(f"Tasks to evaluate: {tasks_to_eval}")
|
||||
|
||||
# Evaluate each task
|
||||
all_scores = {}
|
||||
for task_name in tasks_to_eval:
|
||||
try:
|
||||
scores = evaluate_task(
|
||||
task_name=task_name,
|
||||
model_name=args.model,
|
||||
index_path=args.index_path,
|
||||
use_fast_plaid=args.use_fast_plaid,
|
||||
fast_plaid_index_path=args.fast_plaid_index_path,
|
||||
language=args.language,
|
||||
rebuild_index=args.rebuild_index,
|
||||
top_k=args.top_k,
|
||||
first_stage_k=args.first_stage_k,
|
||||
k_values=k_values,
|
||||
output_dir=args.output_dir,
|
||||
)
|
||||
all_scores[task_name] = scores
|
||||
except Exception as e:
|
||||
print(f"\nError evaluating {task_name}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
# Print summary
|
||||
if all_scores:
|
||||
print(f"\n{'=' * 80}")
|
||||
print("SUMMARY")
|
||||
print(f"{'=' * 80}")
|
||||
for task_name, scores in all_scores.items():
|
||||
print(f"\n{task_name}:")
|
||||
# Print main metrics
|
||||
for metric in ["ndcg_at_5", "ndcg_at_10", "ndcg_at_100", "map_at_10", "recall_at_10"]:
|
||||
if metric in scores:
|
||||
print(f" {metric}: {scores[metric]:.5f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
183
apps/semantic_file_search/leann-plus-temporal-search.py
Normal file
@@ -0,0 +1,183 @@
|
||||
#!/usr/bin/env python3
|
||||
import re
|
||||
import sys
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
from leann import LeannSearcher
|
||||
|
||||
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
|
||||
|
||||
|
||||
class TimeParser:
|
||||
def __init__(self):
|
||||
# Main pattern: captures optional fuzzy modifier, number, unit, and optional "ago"
|
||||
self.pattern = r"(?:(around|about|roughly|approximately)\s+)?(\d+)\s+(hour|day|week|month|year)s?(?:\s+ago)?"
|
||||
|
||||
# Compile for performance
|
||||
self.regex = re.compile(self.pattern, re.IGNORECASE)
|
||||
|
||||
# Stop words to remove before regex parsing
|
||||
self.stop_words = {
|
||||
"in",
|
||||
"at",
|
||||
"of",
|
||||
"by",
|
||||
"as",
|
||||
"me",
|
||||
"the",
|
||||
"a",
|
||||
"an",
|
||||
"and",
|
||||
"any",
|
||||
"find",
|
||||
"search",
|
||||
"list",
|
||||
"ago",
|
||||
"back",
|
||||
"past",
|
||||
"earlier",
|
||||
}
|
||||
|
||||
def clean_text(self, text):
|
||||
"""Remove stop words from text"""
|
||||
words = text.split()
|
||||
cleaned = " ".join(word for word in words if word.lower() not in self.stop_words)
|
||||
return cleaned
|
||||
|
||||
def parse(self, text):
|
||||
"""Extract all time expressions from text"""
|
||||
# Clean text first
|
||||
cleaned_text = self.clean_text(text)
|
||||
|
||||
matches = []
|
||||
for match in self.regex.finditer(cleaned_text):
|
||||
fuzzy = match.group(1) # "around", "about", etc.
|
||||
number = int(match.group(2))
|
||||
unit = match.group(3).lower()
|
||||
|
||||
matches.append(
|
||||
{
|
||||
"full_match": match.group(0),
|
||||
"fuzzy": bool(fuzzy),
|
||||
"number": number,
|
||||
"unit": unit,
|
||||
"range": self.calculate_range(number, unit, bool(fuzzy)),
|
||||
}
|
||||
)
|
||||
|
||||
return matches
|
||||
|
||||
def calculate_range(self, number, unit, is_fuzzy):
|
||||
"""Convert to actual datetime range and return ISO format strings"""
|
||||
units = {
|
||||
"hour": timedelta(hours=number),
|
||||
"day": timedelta(days=number),
|
||||
"week": timedelta(weeks=number),
|
||||
"month": timedelta(days=number * 30),
|
||||
"year": timedelta(days=number * 365),
|
||||
}
|
||||
|
||||
delta = units[unit]
|
||||
now = datetime.now()
|
||||
target = now - delta
|
||||
|
||||
if is_fuzzy:
|
||||
buffer = delta * 0.2 # 20% buffer for fuzzy
|
||||
start = (target - buffer).isoformat()
|
||||
end = (target + buffer).isoformat()
|
||||
else:
|
||||
start = target.isoformat()
|
||||
end = now.isoformat()
|
||||
|
||||
return (start, end)
|
||||
|
||||
|
||||
def search_files(query, top_k=15):
|
||||
"""Search the index and return results"""
|
||||
# Parse time expressions
|
||||
parser = TimeParser()
|
||||
time_matches = parser.parse(query)
|
||||
|
||||
# Remove time expressions from query for semantic search
|
||||
clean_query = query
|
||||
if time_matches:
|
||||
for match in time_matches:
|
||||
clean_query = clean_query.replace(match["full_match"], "").strip()
|
||||
|
||||
# Check if clean_query is less than 4 characters
|
||||
if len(clean_query) < 4:
|
||||
print("Error: add more input for accurate results.")
|
||||
return
|
||||
|
||||
# Single query to vector DB
|
||||
searcher = LeannSearcher(INDEX_PATH)
|
||||
results = searcher.search(
|
||||
clean_query if clean_query else query, top_k=top_k, recompute_embeddings=False
|
||||
)
|
||||
|
||||
# Filter by time if time expression found
|
||||
if time_matches:
|
||||
time_range = time_matches[0]["range"] # Use first time expression
|
||||
start_time, end_time = time_range
|
||||
|
||||
filtered_results = []
|
||||
for result in results:
|
||||
# Access metadata attribute directly (not .get())
|
||||
metadata = result.metadata if hasattr(result, "metadata") else {}
|
||||
|
||||
if metadata:
|
||||
# Check modification date first, fall back to creation date
|
||||
date_str = metadata.get("modification_date") or metadata.get("creation_date")
|
||||
|
||||
if date_str:
|
||||
# Convert strings to datetime objects for proper comparison
|
||||
try:
|
||||
file_date = datetime.fromisoformat(date_str)
|
||||
start_dt = datetime.fromisoformat(start_time)
|
||||
end_dt = datetime.fromisoformat(end_time)
|
||||
|
||||
# Compare dates properly
|
||||
if start_dt <= file_date <= end_dt:
|
||||
filtered_results.append(result)
|
||||
except (ValueError, TypeError):
|
||||
# Handle invalid date formats
|
||||
print(f"Warning: Invalid date format in metadata: {date_str}")
|
||||
continue
|
||||
|
||||
results = filtered_results
|
||||
|
||||
# Print results
|
||||
print(f"\nSearch results for: '{query}'")
|
||||
if time_matches:
|
||||
print(
|
||||
f"Time filter: {time_matches[0]['number']} {time_matches[0]['unit']}(s) {'(fuzzy)' if time_matches[0]['fuzzy'] else ''}"
|
||||
)
|
||||
print(
|
||||
f"Date range: {time_matches[0]['range'][0][:10]} to {time_matches[0]['range'][1][:10]}"
|
||||
)
|
||||
print("-" * 80)
|
||||
|
||||
for i, result in enumerate(results, 1):
|
||||
print(f"\n[{i}] Score: {result.score:.4f}")
|
||||
print(f"Content: {result.text}")
|
||||
|
||||
# Show metadata if present
|
||||
metadata = result.metadata if hasattr(result, "metadata") else None
|
||||
if metadata:
|
||||
if "creation_date" in metadata:
|
||||
print(f"Created: {metadata['creation_date']}")
|
||||
if "modification_date" in metadata:
|
||||
print(f"Modified: {metadata['modification_date']}")
|
||||
print("-" * 80)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) < 2:
|
||||
print('Usage: python search_index.py "<search query>" [top_k]')
|
||||
sys.exit(1)
|
||||
|
||||
query = sys.argv[1]
|
||||
top_k = int(sys.argv[2]) if len(sys.argv) > 2 else 15
|
||||
|
||||
search_files(query, top_k)
|
||||
82
apps/semantic_file_search/leann_index_builder.py
Normal file
@@ -0,0 +1,82 @@
|
||||
#!/usr/bin/env python3
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from leann import LeannBuilder
|
||||
|
||||
|
||||
def process_json_items(json_file_path):
|
||||
"""Load and process JSON file with metadata items"""
|
||||
|
||||
with open(json_file_path, encoding="utf-8") as f:
|
||||
items = json.load(f)
|
||||
|
||||
# Guard against empty JSON
|
||||
if not items:
|
||||
print("⚠️ No items found in the JSON file. Exiting gracefully.")
|
||||
return
|
||||
|
||||
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
|
||||
builder = LeannBuilder(backend_name="hnsw", is_recompute=False)
|
||||
|
||||
total_items = len(items)
|
||||
items_added = 0
|
||||
print(f"Processing {total_items} items...")
|
||||
|
||||
for idx, item in enumerate(items):
|
||||
try:
|
||||
# Create embedding text sentence
|
||||
embedding_text = f"{item.get('Name', 'unknown')} located at {item.get('Path', 'unknown')} and size {item.get('Size', 'unknown')} bytes with content type {item.get('ContentType', 'unknown')} and kind {item.get('Kind', 'unknown')}"
|
||||
|
||||
# Prepare metadata with dates
|
||||
metadata = {}
|
||||
if "CreationDate" in item:
|
||||
metadata["creation_date"] = item["CreationDate"]
|
||||
if "ContentChangeDate" in item:
|
||||
metadata["modification_date"] = item["ContentChangeDate"]
|
||||
|
||||
# Add to builder
|
||||
builder.add_text(embedding_text, metadata=metadata)
|
||||
items_added += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n⚠️ Warning: Failed to process item {idx}: {e}")
|
||||
continue
|
||||
|
||||
# Show progress
|
||||
progress = (idx + 1) / total_items * 100
|
||||
sys.stdout.write(f"\rProgress: {idx + 1}/{total_items} ({progress:.1f}%)")
|
||||
sys.stdout.flush()
|
||||
|
||||
print() # New line after progress
|
||||
|
||||
# Guard against no successfully added items
|
||||
if items_added == 0:
|
||||
print("⚠️ No items were successfully added to the index. Exiting gracefully.")
|
||||
return
|
||||
|
||||
print(f"\n✅ Successfully processed {items_added}/{total_items} items")
|
||||
print("Building index...")
|
||||
|
||||
try:
|
||||
builder.build_index(INDEX_PATH)
|
||||
print(f"✓ Index saved to {INDEX_PATH}")
|
||||
except ValueError as e:
|
||||
if "No chunks added" in str(e):
|
||||
print("⚠️ No chunks were added to the builder. Index not created.")
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 2:
|
||||
print("Usage: python build_index.py <json_file>")
|
||||
sys.exit(1)
|
||||
|
||||
json_file = sys.argv[1]
|
||||
if not Path(json_file).exists():
|
||||
print(f"Error: File {json_file} not found")
|
||||
sys.exit(1)
|
||||
|
||||
process_json_items(json_file)
|
||||
265
apps/semantic_file_search/spotlight_index_dump.py
Normal file
@@ -0,0 +1,265 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Spotlight Metadata Dumper for Vector DB
|
||||
Extracts only essential metadata for semantic search embeddings
|
||||
Output is optimized for vector database storage with minimal fields
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
# Check platform before importing macOS-specific modules
|
||||
if sys.platform != "darwin":
|
||||
print("This script requires macOS (uses Spotlight)")
|
||||
sys.exit(1)
|
||||
|
||||
from Foundation import NSDate, NSMetadataQuery, NSPredicate, NSRunLoop
|
||||
|
||||
# EDIT THIS LIST: Add or remove folders to search
|
||||
# Can be either:
|
||||
# - Folder names relative to home directory (e.g., "Desktop", "Downloads")
|
||||
# - Absolute paths (e.g., "/Applications", "/System/Library")
|
||||
SEARCH_FOLDERS = [
|
||||
"Desktop",
|
||||
"Downloads",
|
||||
"Documents",
|
||||
"Music",
|
||||
"Pictures",
|
||||
"Movies",
|
||||
# "Library", # Uncomment to include
|
||||
# "/Applications", # Absolute path example
|
||||
# "Code/Projects", # Subfolder example
|
||||
# Add any other folders here
|
||||
]
|
||||
|
||||
|
||||
def convert_to_serializable(obj):
|
||||
"""Convert NS objects to Python serializable types"""
|
||||
if obj is None:
|
||||
return None
|
||||
|
||||
# Handle NSDate
|
||||
if hasattr(obj, "timeIntervalSince1970"):
|
||||
return datetime.fromtimestamp(obj.timeIntervalSince1970()).isoformat()
|
||||
|
||||
# Handle NSArray
|
||||
if hasattr(obj, "count") and hasattr(obj, "objectAtIndex_"):
|
||||
return [convert_to_serializable(obj.objectAtIndex_(i)) for i in range(obj.count())]
|
||||
|
||||
# Convert to string
|
||||
try:
|
||||
return str(obj)
|
||||
except Exception:
|
||||
return repr(obj)
|
||||
|
||||
|
||||
def dump_spotlight_data(max_items=10, output_file="spotlight_dump.json"):
|
||||
"""
|
||||
Dump Spotlight data using public.item predicate
|
||||
"""
|
||||
# Build full paths from SEARCH_FOLDERS
|
||||
import os
|
||||
|
||||
home_dir = os.path.expanduser("~")
|
||||
search_paths = []
|
||||
|
||||
print("Search locations:")
|
||||
for folder in SEARCH_FOLDERS:
|
||||
# Check if it's an absolute path or relative
|
||||
if folder.startswith("/"):
|
||||
full_path = folder
|
||||
else:
|
||||
full_path = os.path.join(home_dir, folder)
|
||||
|
||||
if os.path.exists(full_path):
|
||||
search_paths.append(full_path)
|
||||
print(f" ✓ {full_path}")
|
||||
else:
|
||||
print(f" ✗ {full_path} (not found)")
|
||||
|
||||
if not search_paths:
|
||||
print("No valid search paths found!")
|
||||
return []
|
||||
|
||||
print(f"\nDumping {max_items} items from Spotlight (public.item)...")
|
||||
|
||||
# Create query with public.item predicate
|
||||
query = NSMetadataQuery.alloc().init()
|
||||
predicate = NSPredicate.predicateWithFormat_("kMDItemContentTypeTree CONTAINS 'public.item'")
|
||||
query.setPredicate_(predicate)
|
||||
|
||||
# Set search scopes to our specific folders
|
||||
query.setSearchScopes_(search_paths)
|
||||
|
||||
print("Starting query...")
|
||||
query.startQuery()
|
||||
|
||||
# Wait for gathering to complete
|
||||
run_loop = NSRunLoop.currentRunLoop()
|
||||
print("Gathering results...")
|
||||
|
||||
# Let it gather for a few seconds
|
||||
for i in range(50): # 5 seconds max
|
||||
run_loop.runMode_beforeDate_(
|
||||
"NSDefaultRunLoopMode", NSDate.dateWithTimeIntervalSinceNow_(0.1)
|
||||
)
|
||||
# Check gathering status periodically
|
||||
if i % 10 == 0:
|
||||
current_count = query.resultCount()
|
||||
if current_count > 0:
|
||||
print(f" Found {current_count} items so far...")
|
||||
|
||||
# Continue while still gathering (up to 2 more seconds)
|
||||
timeout = NSDate.dateWithTimeIntervalSinceNow_(2.0)
|
||||
while query.isGathering() and timeout.timeIntervalSinceNow() > 0:
|
||||
run_loop.runMode_beforeDate_(
|
||||
"NSDefaultRunLoopMode", NSDate.dateWithTimeIntervalSinceNow_(0.1)
|
||||
)
|
||||
|
||||
query.stopQuery()
|
||||
|
||||
total_results = query.resultCount()
|
||||
print(f"Found {total_results} total items")
|
||||
|
||||
if total_results == 0:
|
||||
print("No results found")
|
||||
return []
|
||||
|
||||
# Process items
|
||||
items_to_process = min(total_results, max_items)
|
||||
results = []
|
||||
|
||||
# ONLY relevant attributes for vector embeddings
|
||||
# These provide essential context for semantic search without bloat
|
||||
attributes = [
|
||||
"kMDItemPath", # Full path for file retrieval
|
||||
"kMDItemFSName", # Filename for display & embedding
|
||||
"kMDItemFSSize", # Size for filtering/ranking
|
||||
"kMDItemContentType", # File type for categorization
|
||||
"kMDItemKind", # Human-readable type for embedding
|
||||
"kMDItemFSCreationDate", # Temporal context
|
||||
"kMDItemFSContentChangeDate", # Recency for ranking
|
||||
]
|
||||
|
||||
print(f"Processing {items_to_process} items...")
|
||||
|
||||
for i in range(items_to_process):
|
||||
try:
|
||||
item = query.resultAtIndex_(i)
|
||||
metadata = {}
|
||||
|
||||
# Extract ONLY the relevant attributes
|
||||
for attr in attributes:
|
||||
try:
|
||||
value = item.valueForAttribute_(attr)
|
||||
if value is not None:
|
||||
# Keep the attribute name clean (remove kMDItem prefix for cleaner JSON)
|
||||
clean_key = attr.replace("kMDItem", "").replace("FS", "")
|
||||
metadata[clean_key] = convert_to_serializable(value)
|
||||
except (AttributeError, ValueError, TypeError):
|
||||
continue
|
||||
|
||||
# Only add if we have at least a path
|
||||
if metadata.get("Path"):
|
||||
results.append(metadata)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing item {i}: {e}")
|
||||
continue
|
||||
|
||||
# Save to JSON
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
json.dump(results, f, indent=2, ensure_ascii=False)
|
||||
|
||||
print(f"\n✓ Saved {len(results)} items to {output_file}")
|
||||
|
||||
# Show summary
|
||||
print("\nSample items:")
|
||||
import os
|
||||
|
||||
home_dir = os.path.expanduser("~")
|
||||
|
||||
for i, item in enumerate(results[:3]):
|
||||
print(f"\n[Item {i + 1}]")
|
||||
print(f" Path: {item.get('Path', 'N/A')}")
|
||||
print(f" Name: {item.get('Name', 'N/A')}")
|
||||
print(f" Type: {item.get('ContentType', 'N/A')}")
|
||||
print(f" Kind: {item.get('Kind', 'N/A')}")
|
||||
|
||||
# Handle size properly
|
||||
size = item.get("Size")
|
||||
if size:
|
||||
try:
|
||||
size_int = int(size)
|
||||
if size_int > 1024 * 1024:
|
||||
print(f" Size: {size_int / (1024 * 1024):.2f} MB")
|
||||
elif size_int > 1024:
|
||||
print(f" Size: {size_int / 1024:.2f} KB")
|
||||
else:
|
||||
print(f" Size: {size_int} bytes")
|
||||
except (ValueError, TypeError):
|
||||
print(f" Size: {size}")
|
||||
|
||||
# Show dates
|
||||
if "CreationDate" in item:
|
||||
print(f" Created: {item['CreationDate']}")
|
||||
if "ContentChangeDate" in item:
|
||||
print(f" Modified: {item['ContentChangeDate']}")
|
||||
|
||||
# Count by type
|
||||
type_counts = {}
|
||||
for item in results:
|
||||
content_type = item.get("ContentType", "unknown")
|
||||
type_counts[content_type] = type_counts.get(content_type, 0) + 1
|
||||
|
||||
print(f"\nTotal items saved: {len(results)}")
|
||||
|
||||
if type_counts:
|
||||
print("\nTop content types:")
|
||||
for ct, count in sorted(type_counts.items(), key=lambda x: x[1], reverse=True)[:5]:
|
||||
print(f" {ct}: {count} items")
|
||||
|
||||
# Count by folder
|
||||
folder_counts = {}
|
||||
for item in results:
|
||||
path = item.get("Path", "")
|
||||
for folder in SEARCH_FOLDERS:
|
||||
# Build the full folder path
|
||||
if folder.startswith("/"):
|
||||
folder_path = folder
|
||||
else:
|
||||
folder_path = os.path.join(home_dir, folder)
|
||||
|
||||
if path.startswith(folder_path):
|
||||
folder_counts[folder] = folder_counts.get(folder, 0) + 1
|
||||
break
|
||||
|
||||
if folder_counts:
|
||||
print("\nItems by location:")
|
||||
for folder, count in sorted(folder_counts.items(), key=lambda x: x[1], reverse=True):
|
||||
print(f" {folder}: {count} items")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def main():
|
||||
# Parse arguments
|
||||
if len(sys.argv) > 1:
|
||||
try:
|
||||
max_items = int(sys.argv[1])
|
||||
except ValueError:
|
||||
print("Usage: python spot.py [number_of_items]")
|
||||
print("Default: 10 items")
|
||||
sys.exit(1)
|
||||
else:
|
||||
max_items = 10
|
||||
|
||||
output_file = sys.argv[2] if len(sys.argv) > 2 else "spotlight_dump.json"
|
||||
|
||||
# Run dump
|
||||
dump_spotlight_data(max_items=max_items, output_file=output_file)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1
apps/slack_data/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Slack MCP data integration for LEANN
|
||||
516
apps/slack_data/slack_mcp_reader.py
Normal file
@@ -0,0 +1,516 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Slack MCP Reader for LEANN
|
||||
|
||||
This module provides functionality to connect to Slack MCP servers and fetch message data
|
||||
for indexing in LEANN. It supports various Slack MCP server implementations and provides
|
||||
flexible message processing options.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SlackMCPReader:
|
||||
"""
|
||||
Reader for Slack data via MCP (Model Context Protocol) servers.
|
||||
|
||||
This class connects to Slack MCP servers to fetch message data and convert it
|
||||
into a format suitable for LEANN indexing.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mcp_server_command: str,
|
||||
workspace_name: Optional[str] = None,
|
||||
concatenate_conversations: bool = True,
|
||||
max_messages_per_conversation: int = 100,
|
||||
max_retries: int = 5,
|
||||
retry_delay: float = 2.0,
|
||||
):
|
||||
"""
|
||||
Initialize the Slack MCP Reader.
|
||||
|
||||
Args:
|
||||
mcp_server_command: Command to start the MCP server (e.g., 'slack-mcp-server')
|
||||
workspace_name: Optional workspace name to filter messages
|
||||
concatenate_conversations: Whether to group messages by channel/thread
|
||||
max_messages_per_conversation: Maximum messages to include per conversation
|
||||
max_retries: Maximum number of retries for failed operations
|
||||
retry_delay: Initial delay between retries in seconds
|
||||
"""
|
||||
self.mcp_server_command = mcp_server_command
|
||||
self.workspace_name = workspace_name
|
||||
self.concatenate_conversations = concatenate_conversations
|
||||
self.max_messages_per_conversation = max_messages_per_conversation
|
||||
self.max_retries = max_retries
|
||||
self.retry_delay = retry_delay
|
||||
self.mcp_process = None
|
||||
|
||||
async def start_mcp_server(self):
|
||||
"""Start the MCP server process."""
|
||||
try:
|
||||
self.mcp_process = await asyncio.create_subprocess_exec(
|
||||
*self.mcp_server_command.split(),
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
logger.info(f"Started MCP server: {self.mcp_server_command}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start MCP server: {e}")
|
||||
raise
|
||||
|
||||
async def stop_mcp_server(self):
|
||||
"""Stop the MCP server process."""
|
||||
if self.mcp_process:
|
||||
self.mcp_process.terminate()
|
||||
await self.mcp_process.wait()
|
||||
logger.info("Stopped MCP server")
|
||||
|
||||
async def send_mcp_request(self, request: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Send a request to the MCP server and get response."""
|
||||
if not self.mcp_process:
|
||||
raise RuntimeError("MCP server not started")
|
||||
|
||||
request_json = json.dumps(request) + "\n"
|
||||
self.mcp_process.stdin.write(request_json.encode())
|
||||
await self.mcp_process.stdin.drain()
|
||||
|
||||
response_line = await self.mcp_process.stdout.readline()
|
||||
if not response_line:
|
||||
raise RuntimeError("No response from MCP server")
|
||||
|
||||
return json.loads(response_line.decode().strip())
|
||||
|
||||
async def initialize_mcp_connection(self):
|
||||
"""Initialize the MCP connection."""
|
||||
init_request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "leann-slack-reader", "version": "1.0.0"},
|
||||
},
|
||||
}
|
||||
|
||||
response = await self.send_mcp_request(init_request)
|
||||
if "error" in response:
|
||||
raise RuntimeError(f"MCP initialization failed: {response['error']}")
|
||||
|
||||
logger.info("MCP connection initialized successfully")
|
||||
|
||||
async def list_available_tools(self) -> list[dict[str, Any]]:
|
||||
"""List available tools from the MCP server."""
|
||||
list_request = {"jsonrpc": "2.0", "id": 2, "method": "tools/list", "params": {}}
|
||||
|
||||
response = await self.send_mcp_request(list_request)
|
||||
if "error" in response:
|
||||
raise RuntimeError(f"Failed to list tools: {response['error']}")
|
||||
|
||||
return response.get("result", {}).get("tools", [])
|
||||
|
||||
def _is_cache_sync_error(self, error: dict) -> bool:
|
||||
"""Check if the error is related to users cache not being ready."""
|
||||
if isinstance(error, dict):
|
||||
message = error.get("message", "").lower()
|
||||
return (
|
||||
"users cache is not ready" in message or "sync process is still running" in message
|
||||
)
|
||||
return False
|
||||
|
||||
async def _retry_with_backoff(self, func, *args, **kwargs):
|
||||
"""Retry a function with exponential backoff, especially for cache sync issues."""
|
||||
last_exception = None
|
||||
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
|
||||
# Check if this is a cache sync error
|
||||
error_dict = {}
|
||||
if hasattr(e, "args") and e.args and isinstance(e.args[0], dict):
|
||||
error_dict = e.args[0]
|
||||
elif "Failed to fetch messages" in str(e):
|
||||
# Try to extract error from the exception message
|
||||
import re
|
||||
|
||||
match = re.search(r"'error':\s*(\{[^}]+\})", str(e))
|
||||
if match:
|
||||
try:
|
||||
error_dict = ast.literal_eval(match.group(1))
|
||||
except (ValueError, SyntaxError):
|
||||
pass
|
||||
else:
|
||||
# Try alternative format
|
||||
match = re.search(r"Failed to fetch messages:\s*(\{[^}]+\})", str(e))
|
||||
if match:
|
||||
try:
|
||||
error_dict = ast.literal_eval(match.group(1))
|
||||
except (ValueError, SyntaxError):
|
||||
pass
|
||||
|
||||
if self._is_cache_sync_error(error_dict):
|
||||
if attempt < self.max_retries:
|
||||
delay = self.retry_delay * (2**attempt) # Exponential backoff
|
||||
logger.info(
|
||||
f"Cache sync not ready, waiting {delay:.1f}s before retry {attempt + 1}/{self.max_retries}"
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
continue
|
||||
else:
|
||||
logger.warning(
|
||||
f"Cache sync still not ready after {self.max_retries} retries, giving up"
|
||||
)
|
||||
break
|
||||
else:
|
||||
# Not a cache sync error, don't retry
|
||||
break
|
||||
|
||||
# If we get here, all retries failed or it's not a retryable error
|
||||
if last_exception is not None:
|
||||
raise last_exception
|
||||
raise RuntimeError("Unexpected error: no exception captured during retry loop")
|
||||
|
||||
async def fetch_slack_messages(
|
||||
self, channel: Optional[str] = None, limit: int = 100
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Fetch Slack messages using MCP tools with retry logic for cache sync issues.
|
||||
|
||||
Args:
|
||||
channel: Optional channel name to filter messages
|
||||
limit: Maximum number of messages to fetch
|
||||
|
||||
Returns:
|
||||
List of message dictionaries
|
||||
"""
|
||||
return await self._retry_with_backoff(self._fetch_slack_messages_impl, channel, limit)
|
||||
|
||||
async def _fetch_slack_messages_impl(
|
||||
self, channel: Optional[str] = None, limit: int = 100
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Internal implementation of fetch_slack_messages without retry logic.
|
||||
"""
|
||||
# This is a generic implementation - specific MCP servers may have different tool names
|
||||
# Common tool names might be: 'get_messages', 'list_messages', 'fetch_channel_history'
|
||||
|
||||
tools = await self.list_available_tools()
|
||||
logger.info(f"Available tools: {[tool.get('name') for tool in tools]}")
|
||||
message_tool = None
|
||||
|
||||
# Look for a tool that can fetch messages - prioritize conversations_history
|
||||
message_tool = None
|
||||
|
||||
# First, try to find conversations_history specifically
|
||||
for tool in tools:
|
||||
tool_name = tool.get("name", "").lower()
|
||||
if "conversations_history" in tool_name:
|
||||
message_tool = tool
|
||||
logger.info(f"Found conversations_history tool: {tool}")
|
||||
break
|
||||
|
||||
# If not found, look for other message-fetching tools
|
||||
if not message_tool:
|
||||
for tool in tools:
|
||||
tool_name = tool.get("name", "").lower()
|
||||
if any(
|
||||
keyword in tool_name
|
||||
for keyword in ["conversations_search", "message", "history"]
|
||||
):
|
||||
message_tool = tool
|
||||
break
|
||||
|
||||
if not message_tool:
|
||||
raise RuntimeError("No message fetching tool found in MCP server")
|
||||
|
||||
# Prepare tool call parameters
|
||||
tool_params = {"limit": "180d"} # Use 180 days to get older messages
|
||||
if channel:
|
||||
# For conversations_history, use channel_id parameter
|
||||
if message_tool["name"] == "conversations_history":
|
||||
tool_params["channel_id"] = channel
|
||||
else:
|
||||
# Try common parameter names for channel specification
|
||||
for param_name in ["channel", "channel_id", "channel_name"]:
|
||||
tool_params[param_name] = channel
|
||||
break
|
||||
|
||||
logger.info(f"Tool parameters: {tool_params}")
|
||||
|
||||
fetch_request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": 3,
|
||||
"method": "tools/call",
|
||||
"params": {"name": message_tool["name"], "arguments": tool_params},
|
||||
}
|
||||
|
||||
response = await self.send_mcp_request(fetch_request)
|
||||
if "error" in response:
|
||||
raise RuntimeError(f"Failed to fetch messages: {response['error']}")
|
||||
|
||||
# Extract messages from response - format may vary by MCP server
|
||||
result = response.get("result", {})
|
||||
if "content" in result and isinstance(result["content"], list):
|
||||
# Some MCP servers return content as a list
|
||||
content = result["content"][0] if result["content"] else {}
|
||||
if "text" in content:
|
||||
try:
|
||||
messages = json.loads(content["text"])
|
||||
except json.JSONDecodeError:
|
||||
# If not JSON, try to parse as CSV format (Slack MCP server format)
|
||||
text_content = content.get("text", "")
|
||||
messages = self._parse_csv_messages(
|
||||
text_content if text_content else "", channel or "unknown"
|
||||
)
|
||||
else:
|
||||
messages = result["content"]
|
||||
else:
|
||||
# Direct message format
|
||||
messages = result.get("messages", [result])
|
||||
|
||||
return messages if isinstance(messages, list) else [messages]
|
||||
|
||||
def _parse_csv_messages(self, csv_text: str, channel: str) -> list[dict[str, Any]]:
|
||||
"""Parse CSV format messages from Slack MCP server."""
|
||||
import csv
|
||||
import io
|
||||
|
||||
messages = []
|
||||
try:
|
||||
# Split by lines and process each line as a CSV row
|
||||
lines = csv_text.strip().split("\n")
|
||||
if not lines:
|
||||
return messages
|
||||
|
||||
# Skip header line if it exists
|
||||
start_idx = 0
|
||||
if lines[0].startswith("MsgID,UserID,UserName"):
|
||||
start_idx = 1
|
||||
|
||||
for line in lines[start_idx:]:
|
||||
if not line.strip():
|
||||
continue
|
||||
|
||||
# Parse CSV line
|
||||
reader = csv.reader(io.StringIO(line))
|
||||
try:
|
||||
row = next(reader)
|
||||
if len(row) >= 7: # Ensure we have enough columns
|
||||
message = {
|
||||
"ts": row[0],
|
||||
"user": row[1],
|
||||
"username": row[2],
|
||||
"real_name": row[3],
|
||||
"channel": row[4],
|
||||
"thread_ts": row[5],
|
||||
"text": row[6],
|
||||
"time": row[7] if len(row) > 7 else "",
|
||||
"reactions": row[8] if len(row) > 8 else "",
|
||||
"cursor": row[9] if len(row) > 9 else "",
|
||||
}
|
||||
messages.append(message)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse CSV line: {line[:100]}... Error: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse CSV messages: {e}")
|
||||
# Fallback: treat entire text as one message
|
||||
messages = [{"text": csv_text, "channel": channel or "unknown"}]
|
||||
|
||||
return messages
|
||||
|
||||
def _format_message(self, message: dict[str, Any]) -> str:
|
||||
"""Format a single message for indexing."""
|
||||
text = message.get("text", "")
|
||||
user = message.get("user", message.get("username", "Unknown"))
|
||||
channel = message.get("channel", message.get("channel_name", "Unknown"))
|
||||
timestamp = message.get("ts", message.get("timestamp", ""))
|
||||
|
||||
# Format timestamp if available
|
||||
formatted_time = ""
|
||||
if timestamp:
|
||||
try:
|
||||
import datetime
|
||||
|
||||
if isinstance(timestamp, str) and "." in timestamp:
|
||||
dt = datetime.datetime.fromtimestamp(float(timestamp))
|
||||
formatted_time = dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||
elif isinstance(timestamp, (int, float)):
|
||||
dt = datetime.datetime.fromtimestamp(timestamp)
|
||||
formatted_time = dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||
else:
|
||||
formatted_time = str(timestamp)
|
||||
except (ValueError, TypeError):
|
||||
formatted_time = str(timestamp)
|
||||
|
||||
# Build formatted message
|
||||
parts = []
|
||||
if channel:
|
||||
parts.append(f"Channel: #{channel}")
|
||||
if user:
|
||||
parts.append(f"User: {user}")
|
||||
if formatted_time:
|
||||
parts.append(f"Time: {formatted_time}")
|
||||
if text:
|
||||
parts.append(f"Message: {text}")
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
def _create_concatenated_content(self, messages: list[dict[str, Any]], channel: str) -> str:
|
||||
"""Create concatenated content from multiple messages in a channel."""
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
# Sort messages by timestamp if available
|
||||
try:
|
||||
messages.sort(key=lambda x: float(x.get("ts", x.get("timestamp", 0))))
|
||||
except (ValueError, TypeError):
|
||||
pass # Keep original order if timestamps aren't numeric
|
||||
|
||||
# Limit messages per conversation
|
||||
if len(messages) > self.max_messages_per_conversation:
|
||||
messages = messages[-self.max_messages_per_conversation :]
|
||||
|
||||
# Create header
|
||||
content_parts = [
|
||||
f"Slack Channel: #{channel}",
|
||||
f"Message Count: {len(messages)}",
|
||||
f"Workspace: {self.workspace_name or 'Unknown'}",
|
||||
"=" * 50,
|
||||
"",
|
||||
]
|
||||
|
||||
# Add messages
|
||||
for message in messages:
|
||||
formatted_msg = self._format_message(message)
|
||||
if formatted_msg.strip():
|
||||
content_parts.append(formatted_msg)
|
||||
content_parts.append("-" * 30)
|
||||
content_parts.append("")
|
||||
|
||||
return "\n".join(content_parts)
|
||||
|
||||
async def get_all_channels(self) -> list[str]:
|
||||
"""Get list of all available channels."""
|
||||
try:
|
||||
channels_list_request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": 4,
|
||||
"method": "tools/call",
|
||||
"params": {"name": "channels_list", "arguments": {}},
|
||||
}
|
||||
channels_response = await self.send_mcp_request(channels_list_request)
|
||||
if "result" in channels_response:
|
||||
result = channels_response["result"]
|
||||
if "content" in result and isinstance(result["content"], list):
|
||||
content = result["content"][0] if result["content"] else {}
|
||||
if "text" in content:
|
||||
# Parse the channels from the response
|
||||
channels = []
|
||||
lines = content["text"].split("\n")
|
||||
for line in lines:
|
||||
if line.strip() and ("#" in line or "C" in line[:10]):
|
||||
# Extract channel ID or name
|
||||
parts = line.split()
|
||||
for part in parts:
|
||||
if part.startswith("C") and len(part) > 5:
|
||||
channels.append(part)
|
||||
elif part.startswith("#"):
|
||||
channels.append(part[1:]) # Remove #
|
||||
logger.info(f"Found {len(channels)} channels: {channels}")
|
||||
return channels
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get channels list: {e}")
|
||||
return []
|
||||
|
||||
async def read_slack_data(self, channels: Optional[list[str]] = None) -> list[str]:
|
||||
"""
|
||||
Read Slack data and return formatted text chunks.
|
||||
|
||||
Args:
|
||||
channels: Optional list of channel names to fetch. If None, fetches from all available channels.
|
||||
|
||||
Returns:
|
||||
List of formatted text chunks ready for LEANN indexing
|
||||
"""
|
||||
try:
|
||||
await self.start_mcp_server()
|
||||
await self.initialize_mcp_connection()
|
||||
|
||||
all_texts = []
|
||||
|
||||
if channels:
|
||||
# Fetch specific channels
|
||||
for channel in channels:
|
||||
try:
|
||||
messages = await self.fetch_slack_messages(channel=channel, limit=1000)
|
||||
if messages:
|
||||
if self.concatenate_conversations:
|
||||
text_content = self._create_concatenated_content(messages, channel)
|
||||
if text_content.strip():
|
||||
all_texts.append(text_content)
|
||||
else:
|
||||
# Process individual messages
|
||||
for message in messages:
|
||||
formatted_msg = self._format_message(message)
|
||||
if formatted_msg.strip():
|
||||
all_texts.append(formatted_msg)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch messages from channel {channel}: {e}")
|
||||
continue
|
||||
else:
|
||||
# Fetch from all available channels
|
||||
logger.info("Fetching from all available channels...")
|
||||
all_channels = await self.get_all_channels()
|
||||
|
||||
if not all_channels:
|
||||
# Fallback to common channel names if we can't get the list
|
||||
all_channels = ["general", "random", "announcements", "C0GN5BX0F"]
|
||||
logger.info(f"Using fallback channels: {all_channels}")
|
||||
|
||||
for channel in all_channels:
|
||||
try:
|
||||
logger.info(f"Searching channel: {channel}")
|
||||
messages = await self.fetch_slack_messages(channel=channel, limit=1000)
|
||||
if messages:
|
||||
if self.concatenate_conversations:
|
||||
text_content = self._create_concatenated_content(messages, channel)
|
||||
if text_content.strip():
|
||||
all_texts.append(text_content)
|
||||
else:
|
||||
# Process individual messages
|
||||
for message in messages:
|
||||
formatted_msg = self._format_message(message)
|
||||
if formatted_msg.strip():
|
||||
all_texts.append(formatted_msg)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch messages from channel {channel}: {e}")
|
||||
continue
|
||||
|
||||
return all_texts
|
||||
|
||||
finally:
|
||||
await self.stop_mcp_server()
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Async context manager entry."""
|
||||
await self.start_mcp_server()
|
||||
await self.initialize_mcp_connection()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Async context manager exit."""
|
||||
await self.stop_mcp_server()
|
||||
229
apps/slack_rag.py
Normal file
@@ -0,0 +1,229 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Slack RAG Application with MCP Support
|
||||
|
||||
This application enables RAG (Retrieval-Augmented Generation) on Slack messages
|
||||
by connecting to Slack MCP servers to fetch live data and index it in LEANN.
|
||||
|
||||
Usage:
|
||||
python -m apps.slack_rag --mcp-server "slack-mcp-server" --query "What did the team discuss about the project?"
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from apps.base_rag_example import BaseRAGExample
|
||||
from apps.slack_data.slack_mcp_reader import SlackMCPReader
|
||||
|
||||
|
||||
class SlackMCPRAG(BaseRAGExample):
|
||||
"""
|
||||
RAG application for Slack messages via MCP servers.
|
||||
|
||||
This class provides a complete RAG pipeline for Slack data, including
|
||||
MCP server connection, data fetching, indexing, and interactive chat.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="Slack MCP RAG",
|
||||
description="RAG application for Slack messages via MCP servers",
|
||||
default_index_name="slack_messages",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
|
||||
"""Add Slack MCP-specific arguments."""
|
||||
parser.add_argument(
|
||||
"--mcp-server",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Command to start the Slack MCP server (e.g., 'slack-mcp-server' or 'npx slack-mcp-server')",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--workspace-name",
|
||||
type=str,
|
||||
help="Slack workspace name for better organization and filtering",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--channels",
|
||||
nargs="+",
|
||||
help="Specific Slack channels to index (e.g., general random). If not specified, fetches from all available channels",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--concatenate-conversations",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Group messages by channel/thread for better context (default: True)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no-concatenate-conversations",
|
||||
action="store_true",
|
||||
help="Process individual messages instead of grouping by channel",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-messages-per-channel",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Maximum number of messages to include per channel (default: 100)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--test-connection",
|
||||
action="store_true",
|
||||
help="Test MCP server connection and list available tools without indexing",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-retries",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Maximum number of retries for failed operations (default: 5)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--retry-delay",
|
||||
type=float,
|
||||
default=2.0,
|
||||
help="Initial delay between retries in seconds (default: 2.0)",
|
||||
)
|
||||
|
||||
async def test_mcp_connection(self, args) -> bool:
|
||||
"""Test the MCP server connection and display available tools."""
|
||||
print(f"Testing connection to MCP server: {args.mcp_server}")
|
||||
|
||||
try:
|
||||
reader = SlackMCPReader(
|
||||
mcp_server_command=args.mcp_server,
|
||||
workspace_name=args.workspace_name,
|
||||
concatenate_conversations=not args.no_concatenate_conversations,
|
||||
max_messages_per_conversation=args.max_messages_per_channel,
|
||||
max_retries=args.max_retries,
|
||||
retry_delay=args.retry_delay,
|
||||
)
|
||||
|
||||
async with reader:
|
||||
tools = await reader.list_available_tools()
|
||||
|
||||
print("Successfully connected to MCP server!")
|
||||
print(f"Available tools ({len(tools)}):")
|
||||
|
||||
for i, tool in enumerate(tools, 1):
|
||||
name = tool.get("name", "Unknown")
|
||||
description = tool.get("description", "No description available")
|
||||
print(f"\n{i}. {name}")
|
||||
print(
|
||||
f" Description: {description[:100]}{'...' if len(description) > 100 else ''}"
|
||||
)
|
||||
|
||||
# Show input schema if available
|
||||
schema = tool.get("inputSchema", {})
|
||||
if schema.get("properties"):
|
||||
props = list(schema["properties"].keys())[:3] # Show first 3 properties
|
||||
print(
|
||||
f" Parameters: {', '.join(props)}{'...' if len(schema['properties']) > 3 else ''}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to connect to MCP server: {e}")
|
||||
print("\nTroubleshooting tips:")
|
||||
print("1. Make sure the MCP server is installed and accessible")
|
||||
print("2. Check if the server command is correct")
|
||||
print("3. Ensure you have proper authentication/credentials configured")
|
||||
print("4. Try running the MCP server command directly to test it")
|
||||
return False
|
||||
|
||||
async def load_data(self, args) -> list[dict[str, Any]]:
|
||||
"""Load Slack messages via MCP server."""
|
||||
print(f"Connecting to Slack MCP server: {args.mcp_server}")
|
||||
|
||||
if args.workspace_name:
|
||||
print(f"Workspace: {args.workspace_name}")
|
||||
|
||||
# Filter out empty strings from channels
|
||||
channels = [ch for ch in args.channels if ch.strip()] if args.channels else None
|
||||
|
||||
if channels:
|
||||
print(f"Channels: {', '.join(channels)}")
|
||||
else:
|
||||
print("Fetching from all available channels")
|
||||
|
||||
concatenate = not args.no_concatenate_conversations
|
||||
print(
|
||||
f"Processing mode: {'Concatenated conversations' if concatenate else 'Individual messages'}"
|
||||
)
|
||||
|
||||
try:
|
||||
reader = SlackMCPReader(
|
||||
mcp_server_command=args.mcp_server,
|
||||
workspace_name=args.workspace_name,
|
||||
concatenate_conversations=concatenate,
|
||||
max_messages_per_conversation=args.max_messages_per_channel,
|
||||
max_retries=args.max_retries,
|
||||
retry_delay=args.retry_delay,
|
||||
)
|
||||
|
||||
texts = await reader.read_slack_data(channels=channels)
|
||||
|
||||
if not texts:
|
||||
print("No messages found! This could mean:")
|
||||
print("- The MCP server couldn't fetch messages")
|
||||
print("- The specified channels don't exist or are empty")
|
||||
print("- Authentication issues with the Slack workspace")
|
||||
return []
|
||||
|
||||
print(f"Successfully loaded {len(texts)} text chunks from Slack")
|
||||
|
||||
# Show sample of what was loaded
|
||||
if texts:
|
||||
sample_text = texts[0][:200] + "..." if len(texts[0]) > 200 else texts[0]
|
||||
print("\nSample content:")
|
||||
print("-" * 40)
|
||||
print(sample_text)
|
||||
print("-" * 40)
|
||||
|
||||
# Convert strings to dict format expected by base class
|
||||
return [{"text": text, "metadata": {"source": "slack"}} for text in texts]
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading Slack data: {e}")
|
||||
print("\nThis might be due to:")
|
||||
print("- MCP server connection issues")
|
||||
print("- Authentication problems")
|
||||
print("- Network connectivity issues")
|
||||
print("- Incorrect channel names")
|
||||
raise
|
||||
|
||||
async def run(self):
|
||||
"""Main entry point with MCP connection testing."""
|
||||
args = self.parser.parse_args()
|
||||
|
||||
# Test connection if requested
|
||||
if args.test_connection:
|
||||
success = await self.test_mcp_connection(args)
|
||||
if not success:
|
||||
return
|
||||
print(
|
||||
"MCP server is working! You can now run without --test-connection to start indexing."
|
||||
)
|
||||
return
|
||||
|
||||
# Run the standard RAG pipeline
|
||||
await super().run()
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main entry point for the Slack MCP RAG application."""
|
||||
app = SlackMCPRAG()
|
||||
await app.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
1
apps/twitter_data/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Twitter MCP data integration for LEANN
|
||||
295
apps/twitter_data/twitter_mcp_reader.py
Normal file
@@ -0,0 +1,295 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Twitter MCP Reader for LEANN
|
||||
|
||||
This module provides functionality to connect to Twitter MCP servers and fetch bookmark data
|
||||
for indexing in LEANN. It supports various Twitter MCP server implementations and provides
|
||||
flexible bookmark processing options.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TwitterMCPReader:
|
||||
"""
|
||||
Reader for Twitter bookmark data via MCP (Model Context Protocol) servers.
|
||||
|
||||
This class connects to Twitter MCP servers to fetch bookmark data and convert it
|
||||
into a format suitable for LEANN indexing.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mcp_server_command: str,
|
||||
username: Optional[str] = None,
|
||||
include_tweet_content: bool = True,
|
||||
include_metadata: bool = True,
|
||||
max_bookmarks: int = 1000,
|
||||
):
|
||||
"""
|
||||
Initialize the Twitter MCP Reader.
|
||||
|
||||
Args:
|
||||
mcp_server_command: Command to start the MCP server (e.g., 'twitter-mcp-server')
|
||||
username: Optional Twitter username to filter bookmarks
|
||||
include_tweet_content: Whether to include full tweet content
|
||||
include_metadata: Whether to include tweet metadata (likes, retweets, etc.)
|
||||
max_bookmarks: Maximum number of bookmarks to fetch
|
||||
"""
|
||||
self.mcp_server_command = mcp_server_command
|
||||
self.username = username
|
||||
self.include_tweet_content = include_tweet_content
|
||||
self.include_metadata = include_metadata
|
||||
self.max_bookmarks = max_bookmarks
|
||||
self.mcp_process = None
|
||||
|
||||
async def start_mcp_server(self):
|
||||
"""Start the MCP server process."""
|
||||
try:
|
||||
self.mcp_process = await asyncio.create_subprocess_exec(
|
||||
*self.mcp_server_command.split(),
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
logger.info(f"Started MCP server: {self.mcp_server_command}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start MCP server: {e}")
|
||||
raise
|
||||
|
||||
async def stop_mcp_server(self):
|
||||
"""Stop the MCP server process."""
|
||||
if self.mcp_process:
|
||||
self.mcp_process.terminate()
|
||||
await self.mcp_process.wait()
|
||||
logger.info("Stopped MCP server")
|
||||
|
||||
async def send_mcp_request(self, request: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Send a request to the MCP server and get response."""
|
||||
if not self.mcp_process:
|
||||
raise RuntimeError("MCP server not started")
|
||||
|
||||
request_json = json.dumps(request) + "\n"
|
||||
self.mcp_process.stdin.write(request_json.encode())
|
||||
await self.mcp_process.stdin.drain()
|
||||
|
||||
response_line = await self.mcp_process.stdout.readline()
|
||||
if not response_line:
|
||||
raise RuntimeError("No response from MCP server")
|
||||
|
||||
return json.loads(response_line.decode().strip())
|
||||
|
||||
async def initialize_mcp_connection(self):
|
||||
"""Initialize the MCP connection."""
|
||||
init_request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "leann-twitter-reader", "version": "1.0.0"},
|
||||
},
|
||||
}
|
||||
|
||||
response = await self.send_mcp_request(init_request)
|
||||
if "error" in response:
|
||||
raise RuntimeError(f"MCP initialization failed: {response['error']}")
|
||||
|
||||
logger.info("MCP connection initialized successfully")
|
||||
|
||||
async def list_available_tools(self) -> list[dict[str, Any]]:
|
||||
"""List available tools from the MCP server."""
|
||||
list_request = {"jsonrpc": "2.0", "id": 2, "method": "tools/list", "params": {}}
|
||||
|
||||
response = await self.send_mcp_request(list_request)
|
||||
if "error" in response:
|
||||
raise RuntimeError(f"Failed to list tools: {response['error']}")
|
||||
|
||||
return response.get("result", {}).get("tools", [])
|
||||
|
||||
async def fetch_twitter_bookmarks(self, limit: Optional[int] = None) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Fetch Twitter bookmarks using MCP tools.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of bookmarks to fetch
|
||||
|
||||
Returns:
|
||||
List of bookmark dictionaries
|
||||
"""
|
||||
tools = await self.list_available_tools()
|
||||
bookmark_tool = None
|
||||
|
||||
# Look for a tool that can fetch bookmarks
|
||||
for tool in tools:
|
||||
tool_name = tool.get("name", "").lower()
|
||||
if any(keyword in tool_name for keyword in ["bookmark", "saved", "favorite"]):
|
||||
bookmark_tool = tool
|
||||
break
|
||||
|
||||
if not bookmark_tool:
|
||||
raise RuntimeError("No bookmark fetching tool found in MCP server")
|
||||
|
||||
# Prepare tool call parameters
|
||||
tool_params = {}
|
||||
if limit or self.max_bookmarks:
|
||||
tool_params["limit"] = limit or self.max_bookmarks
|
||||
if self.username:
|
||||
tool_params["username"] = self.username
|
||||
|
||||
fetch_request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": 3,
|
||||
"method": "tools/call",
|
||||
"params": {"name": bookmark_tool["name"], "arguments": tool_params},
|
||||
}
|
||||
|
||||
response = await self.send_mcp_request(fetch_request)
|
||||
if "error" in response:
|
||||
raise RuntimeError(f"Failed to fetch bookmarks: {response['error']}")
|
||||
|
||||
# Extract bookmarks from response
|
||||
result = response.get("result", {})
|
||||
if "content" in result and isinstance(result["content"], list):
|
||||
content = result["content"][0] if result["content"] else {}
|
||||
if "text" in content:
|
||||
try:
|
||||
bookmarks = json.loads(content["text"])
|
||||
except json.JSONDecodeError:
|
||||
# If not JSON, treat as plain text
|
||||
bookmarks = [{"text": content["text"], "source": "twitter"}]
|
||||
else:
|
||||
bookmarks = result["content"]
|
||||
else:
|
||||
bookmarks = result.get("bookmarks", result.get("tweets", [result]))
|
||||
|
||||
return bookmarks if isinstance(bookmarks, list) else [bookmarks]
|
||||
|
||||
def _format_bookmark(self, bookmark: dict[str, Any]) -> str:
|
||||
"""Format a single bookmark for indexing."""
|
||||
# Extract tweet information
|
||||
text = bookmark.get("text", bookmark.get("content", ""))
|
||||
author = bookmark.get(
|
||||
"author", bookmark.get("username", bookmark.get("user", {}).get("username", "Unknown"))
|
||||
)
|
||||
timestamp = bookmark.get("created_at", bookmark.get("timestamp", ""))
|
||||
url = bookmark.get("url", bookmark.get("tweet_url", ""))
|
||||
|
||||
# Extract metadata if available
|
||||
likes = bookmark.get("likes", bookmark.get("favorite_count", 0))
|
||||
retweets = bookmark.get("retweets", bookmark.get("retweet_count", 0))
|
||||
replies = bookmark.get("replies", bookmark.get("reply_count", 0))
|
||||
|
||||
# Build formatted bookmark
|
||||
parts = []
|
||||
|
||||
# Header
|
||||
parts.append("=== Twitter Bookmark ===")
|
||||
|
||||
if author:
|
||||
parts.append(f"Author: @{author}")
|
||||
|
||||
if timestamp:
|
||||
# Format timestamp if it's a standard format
|
||||
try:
|
||||
import datetime
|
||||
|
||||
if "T" in str(timestamp): # ISO format
|
||||
dt = datetime.datetime.fromisoformat(timestamp.replace("Z", "+00:00"))
|
||||
formatted_time = dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||
else:
|
||||
formatted_time = str(timestamp)
|
||||
parts.append(f"Date: {formatted_time}")
|
||||
except (ValueError, TypeError):
|
||||
parts.append(f"Date: {timestamp}")
|
||||
|
||||
if url:
|
||||
parts.append(f"URL: {url}")
|
||||
|
||||
# Tweet content
|
||||
if text and self.include_tweet_content:
|
||||
parts.append("")
|
||||
parts.append("Content:")
|
||||
parts.append(text)
|
||||
|
||||
# Metadata
|
||||
if self.include_metadata and any([likes, retweets, replies]):
|
||||
parts.append("")
|
||||
parts.append("Engagement:")
|
||||
if likes:
|
||||
parts.append(f" Likes: {likes}")
|
||||
if retweets:
|
||||
parts.append(f" Retweets: {retweets}")
|
||||
if replies:
|
||||
parts.append(f" Replies: {replies}")
|
||||
|
||||
# Extract hashtags and mentions if available
|
||||
hashtags = bookmark.get("hashtags", [])
|
||||
mentions = bookmark.get("mentions", [])
|
||||
|
||||
if hashtags or mentions:
|
||||
parts.append("")
|
||||
if hashtags:
|
||||
parts.append(f"Hashtags: {', '.join(hashtags)}")
|
||||
if mentions:
|
||||
parts.append(f"Mentions: {', '.join(mentions)}")
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
async def read_twitter_bookmarks(self) -> list[str]:
|
||||
"""
|
||||
Read Twitter bookmark data and return formatted text chunks.
|
||||
|
||||
Returns:
|
||||
List of formatted text chunks ready for LEANN indexing
|
||||
"""
|
||||
try:
|
||||
await self.start_mcp_server()
|
||||
await self.initialize_mcp_connection()
|
||||
|
||||
print(f"Fetching up to {self.max_bookmarks} bookmarks...")
|
||||
if self.username:
|
||||
print(f"Filtering for user: @{self.username}")
|
||||
|
||||
bookmarks = await self.fetch_twitter_bookmarks()
|
||||
|
||||
if not bookmarks:
|
||||
print("No bookmarks found")
|
||||
return []
|
||||
|
||||
print(f"Processing {len(bookmarks)} bookmarks...")
|
||||
|
||||
all_texts = []
|
||||
processed_count = 0
|
||||
|
||||
for bookmark in bookmarks:
|
||||
try:
|
||||
formatted_bookmark = self._format_bookmark(bookmark)
|
||||
if formatted_bookmark.strip():
|
||||
all_texts.append(formatted_bookmark)
|
||||
processed_count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to format bookmark: {e}")
|
||||
continue
|
||||
|
||||
print(f"Successfully processed {processed_count} bookmarks")
|
||||
return all_texts
|
||||
|
||||
finally:
|
||||
await self.stop_mcp_server()
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Async context manager entry."""
|
||||
await self.start_mcp_server()
|
||||
await self.initialize_mcp_connection()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Async context manager exit."""
|
||||
await self.stop_mcp_server()
|
||||
197
apps/twitter_rag.py
Normal file
@@ -0,0 +1,197 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Twitter RAG Application with MCP Support
|
||||
|
||||
This application enables RAG (Retrieval-Augmented Generation) on Twitter bookmarks
|
||||
by connecting to Twitter MCP servers to fetch live data and index it in LEANN.
|
||||
|
||||
Usage:
|
||||
python -m apps.twitter_rag --mcp-server "twitter-mcp-server" --query "What articles did I bookmark about AI?"
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from apps.base_rag_example import BaseRAGExample
|
||||
from apps.twitter_data.twitter_mcp_reader import TwitterMCPReader
|
||||
|
||||
|
||||
class TwitterMCPRAG(BaseRAGExample):
|
||||
"""
|
||||
RAG application for Twitter bookmarks via MCP servers.
|
||||
|
||||
This class provides a complete RAG pipeline for Twitter bookmark data, including
|
||||
MCP server connection, data fetching, indexing, and interactive chat.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="Twitter MCP RAG",
|
||||
description="RAG application for Twitter bookmarks via MCP servers",
|
||||
default_index_name="twitter_bookmarks",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
|
||||
"""Add Twitter MCP-specific arguments."""
|
||||
parser.add_argument(
|
||||
"--mcp-server",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Command to start the Twitter MCP server (e.g., 'twitter-mcp-server' or 'npx twitter-mcp-server')",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--username", type=str, help="Twitter username to filter bookmarks (without @)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-bookmarks",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Maximum number of bookmarks to fetch (default: 1000)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no-tweet-content",
|
||||
action="store_true",
|
||||
help="Exclude tweet content, only include metadata",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no-metadata",
|
||||
action="store_true",
|
||||
help="Exclude engagement metadata (likes, retweets, etc.)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--test-connection",
|
||||
action="store_true",
|
||||
help="Test MCP server connection and list available tools without indexing",
|
||||
)
|
||||
|
||||
async def test_mcp_connection(self, args) -> bool:
|
||||
"""Test the MCP server connection and display available tools."""
|
||||
print(f"Testing connection to MCP server: {args.mcp_server}")
|
||||
|
||||
try:
|
||||
reader = TwitterMCPReader(
|
||||
mcp_server_command=args.mcp_server,
|
||||
username=args.username,
|
||||
include_tweet_content=not args.no_tweet_content,
|
||||
include_metadata=not args.no_metadata,
|
||||
max_bookmarks=args.max_bookmarks,
|
||||
)
|
||||
|
||||
async with reader:
|
||||
tools = await reader.list_available_tools()
|
||||
|
||||
print("\n✅ Successfully connected to MCP server!")
|
||||
print(f"Available tools ({len(tools)}):")
|
||||
|
||||
for i, tool in enumerate(tools, 1):
|
||||
name = tool.get("name", "Unknown")
|
||||
description = tool.get("description", "No description available")
|
||||
print(f"\n{i}. {name}")
|
||||
print(
|
||||
f" Description: {description[:100]}{'...' if len(description) > 100 else ''}"
|
||||
)
|
||||
|
||||
# Show input schema if available
|
||||
schema = tool.get("inputSchema", {})
|
||||
if schema.get("properties"):
|
||||
props = list(schema["properties"].keys())[:3] # Show first 3 properties
|
||||
print(
|
||||
f" Parameters: {', '.join(props)}{'...' if len(schema['properties']) > 3 else ''}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ Failed to connect to MCP server: {e}")
|
||||
print("\nTroubleshooting tips:")
|
||||
print("1. Make sure the Twitter MCP server is installed and accessible")
|
||||
print("2. Check if the server command is correct")
|
||||
print("3. Ensure you have proper Twitter API credentials configured")
|
||||
print("4. Verify your Twitter account has bookmarks to fetch")
|
||||
print("5. Try running the MCP server command directly to test it")
|
||||
return False
|
||||
|
||||
async def load_data(self, args) -> list[dict[str, Any]]:
|
||||
"""Load Twitter bookmarks via MCP server."""
|
||||
print(f"Connecting to Twitter MCP server: {args.mcp_server}")
|
||||
|
||||
if args.username:
|
||||
print(f"Username filter: @{args.username}")
|
||||
|
||||
print(f"Max bookmarks: {args.max_bookmarks}")
|
||||
print(f"Include tweet content: {not args.no_tweet_content}")
|
||||
print(f"Include metadata: {not args.no_metadata}")
|
||||
|
||||
try:
|
||||
reader = TwitterMCPReader(
|
||||
mcp_server_command=args.mcp_server,
|
||||
username=args.username,
|
||||
include_tweet_content=not args.no_tweet_content,
|
||||
include_metadata=not args.no_metadata,
|
||||
max_bookmarks=args.max_bookmarks,
|
||||
)
|
||||
|
||||
texts = await reader.read_twitter_bookmarks()
|
||||
|
||||
if not texts:
|
||||
print("❌ No bookmarks found! This could mean:")
|
||||
print("- You don't have any bookmarks on Twitter")
|
||||
print("- The MCP server couldn't access your bookmarks")
|
||||
print("- Authentication issues with Twitter API")
|
||||
print("- The username filter didn't match any bookmarks")
|
||||
return []
|
||||
|
||||
print(f"✅ Successfully loaded {len(texts)} bookmarks from Twitter")
|
||||
|
||||
# Show sample of what was loaded
|
||||
if texts:
|
||||
sample_text = texts[0][:300] + "..." if len(texts[0]) > 300 else texts[0]
|
||||
print("\nSample bookmark:")
|
||||
print("-" * 50)
|
||||
print(sample_text)
|
||||
print("-" * 50)
|
||||
|
||||
# Convert strings to dict format expected by base class
|
||||
return [{"text": text, "metadata": {"source": "twitter"}} for text in texts]
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error loading Twitter bookmarks: {e}")
|
||||
print("\nThis might be due to:")
|
||||
print("- MCP server connection issues")
|
||||
print("- Twitter API authentication problems")
|
||||
print("- Network connectivity issues")
|
||||
print("- Rate limiting from Twitter API")
|
||||
raise
|
||||
|
||||
async def run(self):
|
||||
"""Main entry point with MCP connection testing."""
|
||||
args = self.parser.parse_args()
|
||||
|
||||
# Test connection if requested
|
||||
if args.test_connection:
|
||||
success = await self.test_mcp_connection(args)
|
||||
if not success:
|
||||
return
|
||||
print(
|
||||
"\n🎉 MCP server is working! You can now run without --test-connection to start indexing."
|
||||
)
|
||||
return
|
||||
|
||||
# Run the standard RAG pipeline
|
||||
await super().run()
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main entry point for the Twitter MCP RAG application."""
|
||||
app = TwitterMCPRAG()
|
||||
await app.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -6,6 +6,7 @@ Supports WeChat chat history export and search.
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
@@ -91,7 +92,7 @@ class WeChatRAG(BaseRAGExample):
|
||||
print(f"Export error: {e}")
|
||||
return False
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
async def load_data(self, args) -> list[dict[str, Any]]:
|
||||
"""Load WeChat history and convert to text chunks."""
|
||||
# Initialize WeChat reader with export capabilities
|
||||
reader = WeChatHistoryReader()
|
||||
|
||||
BIN
assets/wechat_user_group.JPG
Normal file
|
After Width: | Height: | Size: 152 KiB |
0
benchmarks/__init__.py
Normal file
23
benchmarks/bm25_diskann_baselines/README.md
Normal file
@@ -0,0 +1,23 @@
|
||||
BM25 vs DiskANN Baselines
|
||||
|
||||
```bash
|
||||
aws s3 sync s3://powerrag-diskann-rpj-wiki-20250824-224037-194d640c/bm25_rpj_wiki/index_en_only/ benchmarks/data/indices/bm25_index/
|
||||
aws s3 sync s3://powerrag-diskann-rpj-wiki-20250824-224037-194d640c/diskann_rpj_wiki/ benchmarks/data/indices/diskann_rpj_wiki/
|
||||
```
|
||||
|
||||
- Dataset: `benchmarks/data/queries/nq_open.jsonl` (Natural Questions)
|
||||
- Machine-specific; results measured locally with the current repo.
|
||||
|
||||
DiskANN (NQ queries, search-only)
|
||||
- Command: `uv run --script benchmarks/bm25_diskann_baselines/run_diskann.py`
|
||||
- Settings: `recompute_embeddings=False`, embeddings precomputed (excluded from timing), batching off, caching off (`cache_mechanism=2`, `num_nodes_to_cache=0`)
|
||||
- Result: avg 0.011093 s/query, QPS 90.15 (p50 0.010731 s, p95 0.015000 s)
|
||||
|
||||
BM25
|
||||
- Command: `uv run --script benchmarks/bm25_diskann_baselines/run_bm25.py`
|
||||
- Settings: `k=10`, `k1=0.9`, `b=0.4`, queries=100
|
||||
- Result: avg 0.028589 s/query, QPS 34.97 (p50 0.026060 s, p90 0.043695 s, p95 0.053260 s, p99 0.055257 s)
|
||||
|
||||
Notes
|
||||
- DiskANN measures search-only latency on real NQ queries (embeddings computed beforehand and excluded from timing).
|
||||
- Use `benchmarks/bm25_diskann_baselines/run_diskann.py` for DiskANN; `benchmarks/bm25_diskann_baselines/run_bm25.py` for BM25.
|
||||
|
After Width: | Height: | Size: 1.3 KiB |
183
benchmarks/bm25_diskann_baselines/run_bm25.py
Normal file
@@ -0,0 +1,183 @@
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "pyserini"
|
||||
# ]
|
||||
# ///
|
||||
# sudo pacman -S jdk21-openjdk
|
||||
# export JAVA_HOME=/usr/lib/jvm/java-21-openjdk
|
||||
# sudo archlinux-java status
|
||||
# sudo archlinux-java set java-21-openjdk
|
||||
# set -Ux JAVA_HOME /usr/lib/jvm/java-21-openjdk
|
||||
# fish_add_path --global $JAVA_HOME/bin
|
||||
# set -Ux LD_LIBRARY_PATH $JAVA_HOME/lib/server $LD_LIBRARY_PATH
|
||||
# which javac # Should be /usr/lib/jvm/java-21-openjdk/bin/javac
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from statistics import mean
|
||||
|
||||
|
||||
def load_queries(path: str, limit: int | None) -> list[str]:
|
||||
queries: list[str] = []
|
||||
# Try JSONL with a 'query' or 'text' field; fallback to plain text (one query per line)
|
||||
_, ext = os.path.splitext(path)
|
||||
if ext.lower() in {".jsonl", ".json"}:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
obj = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
# Not strict JSONL? treat the whole line as the query
|
||||
queries.append(line)
|
||||
continue
|
||||
q = obj.get("query") or obj.get("text") or obj.get("question")
|
||||
if q:
|
||||
queries.append(str(q))
|
||||
else:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
s = line.strip()
|
||||
if s:
|
||||
queries.append(s)
|
||||
|
||||
if limit is not None and limit > 0:
|
||||
queries = queries[:limit]
|
||||
return queries
|
||||
|
||||
|
||||
def percentile(values: list[float], p: float) -> float:
|
||||
if not values:
|
||||
return 0.0
|
||||
s = sorted(values)
|
||||
k = (len(s) - 1) * (p / 100.0)
|
||||
f = int(k)
|
||||
c = min(f + 1, len(s) - 1)
|
||||
if f == c:
|
||||
return s[f]
|
||||
return s[f] + (s[c] - s[f]) * (k - f)
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser(description="Standalone BM25 latency benchmark (Pyserini)")
|
||||
ap.add_argument(
|
||||
"--bm25-index",
|
||||
default="benchmarks/data/indices/bm25_index",
|
||||
help="Path to Pyserini Lucene index directory",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--queries",
|
||||
default="benchmarks/data/queries/nq_open.jsonl",
|
||||
help="Path to queries file (JSONL with 'query'/'text' or plain txt one-per-line)",
|
||||
)
|
||||
ap.add_argument("--k", type=int, default=10, help="Top-k to retrieve (default: 10)")
|
||||
ap.add_argument("--k1", type=float, default=0.9, help="BM25 k1 (default: 0.9)")
|
||||
ap.add_argument("--b", type=float, default=0.4, help="BM25 b (default: 0.4)")
|
||||
ap.add_argument("--limit", type=int, default=100, help="Max queries to run (default: 100)")
|
||||
ap.add_argument(
|
||||
"--warmup", type=int, default=5, help="Warmup queries not counted in latency (default: 5)"
|
||||
)
|
||||
ap.add_argument(
|
||||
"--fetch-docs", action="store_true", help="Also fetch doc contents (slower; default: off)"
|
||||
)
|
||||
ap.add_argument("--report", type=str, default=None, help="Optional JSON report path")
|
||||
args = ap.parse_args()
|
||||
|
||||
try:
|
||||
from pyserini.search.lucene import LuceneSearcher
|
||||
except Exception:
|
||||
print("Pyserini not found. Install with: pip install pyserini", file=sys.stderr)
|
||||
raise
|
||||
|
||||
if not os.path.isdir(args.bm25_index):
|
||||
print(f"Index directory not found: {args.bm25_index}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
queries = load_queries(args.queries, args.limit)
|
||||
if not queries:
|
||||
print("No queries loaded.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Loaded {len(queries)} queries from {args.queries}")
|
||||
print(f"Opening BM25 index: {args.bm25_index}")
|
||||
searcher = LuceneSearcher(args.bm25_index)
|
||||
# Some builds of pyserini require explicit set_bm25; others ignore
|
||||
try:
|
||||
searcher.set_bm25(k1=args.k1, b=args.b)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
latencies: list[float] = []
|
||||
total_searches = 0
|
||||
|
||||
# Warmup
|
||||
for i in range(min(args.warmup, len(queries))):
|
||||
_ = searcher.search(queries[i], k=args.k)
|
||||
|
||||
t0 = time.time()
|
||||
for i, q in enumerate(queries):
|
||||
t1 = time.time()
|
||||
hits = searcher.search(q, k=args.k)
|
||||
t2 = time.time()
|
||||
latencies.append(t2 - t1)
|
||||
total_searches += 1
|
||||
|
||||
if args.fetch_docs:
|
||||
# Optional doc fetch to include I/O time
|
||||
for h in hits:
|
||||
try:
|
||||
_ = searcher.doc(h.docid)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if (i + 1) % 50 == 0:
|
||||
print(f"Processed {i + 1}/{len(queries)} queries")
|
||||
|
||||
t1 = time.time()
|
||||
total_time = t1 - t0
|
||||
|
||||
if latencies:
|
||||
avg = mean(latencies)
|
||||
p50 = percentile(latencies, 50)
|
||||
p90 = percentile(latencies, 90)
|
||||
p95 = percentile(latencies, 95)
|
||||
p99 = percentile(latencies, 99)
|
||||
qps = total_searches / total_time if total_time > 0 else 0.0
|
||||
else:
|
||||
avg = p50 = p90 = p95 = p99 = qps = 0.0
|
||||
|
||||
print("BM25 Latency Report")
|
||||
print(f" queries: {total_searches}")
|
||||
print(f" k: {args.k}, k1: {args.k1}, b: {args.b}")
|
||||
print(f" avg per query: {avg:.6f} s")
|
||||
print(f" p50/p90/p95/p99: {p50:.6f}/{p90:.6f}/{p95:.6f}/{p99:.6f} s")
|
||||
print(f" total time: {total_time:.3f} s, qps: {qps:.2f}")
|
||||
|
||||
if args.report:
|
||||
payload = {
|
||||
"queries": total_searches,
|
||||
"k": args.k,
|
||||
"k1": args.k1,
|
||||
"b": args.b,
|
||||
"avg_s": avg,
|
||||
"p50_s": p50,
|
||||
"p90_s": p90,
|
||||
"p95_s": p95,
|
||||
"p99_s": p99,
|
||||
"total_time_s": total_time,
|
||||
"qps": qps,
|
||||
"index_dir": os.path.abspath(args.bm25_index),
|
||||
"fetch_docs": bool(args.fetch_docs),
|
||||
}
|
||||
with open(args.report, "w", encoding="utf-8") as f:
|
||||
json.dump(payload, f, indent=2)
|
||||
print(f"Saved report to {args.report}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
124
benchmarks/bm25_diskann_baselines/run_diskann.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "leann-backend-diskann"
|
||||
# ]
|
||||
# ///
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def load_queries(path: Path, limit: int | None) -> list[str]:
|
||||
out: list[str] = []
|
||||
with open(path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
obj = json.loads(line)
|
||||
out.append(obj["query"])
|
||||
if limit and len(out) >= limit:
|
||||
break
|
||||
return out
|
||||
|
||||
|
||||
def main() -> None:
|
||||
ap = argparse.ArgumentParser(
|
||||
description="DiskANN baseline on real NQ queries (search-only timing)"
|
||||
)
|
||||
ap.add_argument(
|
||||
"--index-dir",
|
||||
default="benchmarks/data/indices/diskann_rpj_wiki",
|
||||
help="Directory containing DiskANN files",
|
||||
)
|
||||
ap.add_argument("--index-prefix", default="ann")
|
||||
ap.add_argument("--queries-file", default="benchmarks/data/queries/nq_open.jsonl")
|
||||
ap.add_argument("--num-queries", type=int, default=200)
|
||||
ap.add_argument("--top-k", type=int, default=10)
|
||||
ap.add_argument("--complexity", type=int, default=62)
|
||||
ap.add_argument("--threads", type=int, default=1)
|
||||
ap.add_argument("--beam-width", type=int, default=1)
|
||||
ap.add_argument("--cache-mechanism", type=int, default=2)
|
||||
ap.add_argument("--num-nodes-to-cache", type=int, default=0)
|
||||
args = ap.parse_args()
|
||||
|
||||
index_dir = Path(args.index_dir).resolve()
|
||||
if not index_dir.is_dir():
|
||||
raise SystemExit(f"Index dir not found: {index_dir}")
|
||||
|
||||
qpath = Path(args.queries_file).resolve()
|
||||
if not qpath.exists():
|
||||
raise SystemExit(f"Queries file not found: {qpath}")
|
||||
|
||||
queries = load_queries(qpath, args.num_queries)
|
||||
print(f"Loaded {len(queries)} queries from {qpath}")
|
||||
|
||||
# Compute embeddings once (exclude from timing)
|
||||
from leann.api import compute_embeddings as _compute
|
||||
|
||||
embs = _compute(
|
||||
queries,
|
||||
model_name="facebook/contriever-msmarco",
|
||||
mode="sentence-transformers",
|
||||
use_server=False,
|
||||
).astype(np.float32)
|
||||
if embs.ndim != 2:
|
||||
raise SystemExit("Embedding compute failed or returned wrong shape")
|
||||
|
||||
# Build searcher
|
||||
from leann_backend_diskann.diskann_backend import DiskannSearcher as _DiskannSearcher
|
||||
|
||||
index_prefix_path = str(index_dir / args.index_prefix)
|
||||
searcher = _DiskannSearcher(
|
||||
index_prefix_path,
|
||||
num_threads=int(args.threads),
|
||||
cache_mechanism=int(args.cache_mechanism),
|
||||
num_nodes_to_cache=int(args.num_nodes_to_cache),
|
||||
)
|
||||
|
||||
# Warmup (not timed)
|
||||
_ = searcher.search(
|
||||
embs[0:1],
|
||||
top_k=args.top_k,
|
||||
complexity=args.complexity,
|
||||
beam_width=args.beam_width,
|
||||
prune_ratio=0.0,
|
||||
recompute_embeddings=False,
|
||||
batch_recompute=False,
|
||||
dedup_node_dis=False,
|
||||
)
|
||||
|
||||
# Timed loop
|
||||
times: list[float] = []
|
||||
for i in range(embs.shape[0]):
|
||||
t0 = time.time()
|
||||
_ = searcher.search(
|
||||
embs[i : i + 1],
|
||||
top_k=args.top_k,
|
||||
complexity=args.complexity,
|
||||
beam_width=args.beam_width,
|
||||
prune_ratio=0.0,
|
||||
recompute_embeddings=False,
|
||||
batch_recompute=False,
|
||||
dedup_node_dis=False,
|
||||
)
|
||||
times.append(time.time() - t0)
|
||||
|
||||
times_sorted = sorted(times)
|
||||
avg = float(sum(times) / len(times))
|
||||
p50 = times_sorted[len(times) // 2]
|
||||
p95 = times_sorted[max(0, int(len(times) * 0.95) - 1)]
|
||||
|
||||
print("\nDiskANN (NQ, search-only) Report")
|
||||
print(f" queries: {len(times)}")
|
||||
print(
|
||||
f" k: {args.top_k}, complexity: {args.complexity}, beam_width: {args.beam_width}, threads: {args.threads}"
|
||||
)
|
||||
print(f" avg per query: {avg:.6f} s")
|
||||
print(f" p50/p95: {p50:.6f}/{p95:.6f} s")
|
||||
print(f" QPS: {1.0 / avg:.2f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
82
benchmarks/data/.gitattributes
vendored
@@ -1,82 +0,0 @@
|
||||
*.7z filter=lfs diff=lfs merge=lfs -text
|
||||
*.arrow filter=lfs diff=lfs merge=lfs -text
|
||||
*.bin filter=lfs diff=lfs merge=lfs -text
|
||||
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
||||
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
||||
*.ftz filter=lfs diff=lfs merge=lfs -text
|
||||
*.gz filter=lfs diff=lfs merge=lfs -text
|
||||
*.h5 filter=lfs diff=lfs merge=lfs -text
|
||||
*.joblib filter=lfs diff=lfs merge=lfs -text
|
||||
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.lz4 filter=lfs diff=lfs merge=lfs -text
|
||||
*.mds filter=lfs diff=lfs merge=lfs -text
|
||||
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
||||
*.model filter=lfs diff=lfs merge=lfs -text
|
||||
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
||||
*.npy filter=lfs diff=lfs merge=lfs -text
|
||||
*.npz filter=lfs diff=lfs merge=lfs -text
|
||||
*.onnx filter=lfs diff=lfs merge=lfs -text
|
||||
*.ot filter=lfs diff=lfs merge=lfs -text
|
||||
*.parquet filter=lfs diff=lfs merge=lfs -text
|
||||
*.pb filter=lfs diff=lfs merge=lfs -text
|
||||
*.pickle filter=lfs diff=lfs merge=lfs -text
|
||||
*.pkl filter=lfs diff=lfs merge=lfs -text
|
||||
*.pt filter=lfs diff=lfs merge=lfs -text
|
||||
*.pth filter=lfs diff=lfs merge=lfs -text
|
||||
*.rar filter=lfs diff=lfs merge=lfs -text
|
||||
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar filter=lfs diff=lfs merge=lfs -text
|
||||
*.tflite filter=lfs diff=lfs merge=lfs -text
|
||||
*.tgz filter=lfs diff=lfs merge=lfs -text
|
||||
*.wasm filter=lfs diff=lfs merge=lfs -text
|
||||
*.xz filter=lfs diff=lfs merge=lfs -text
|
||||
*.zip filter=lfs diff=lfs merge=lfs -text
|
||||
*.zst filter=lfs diff=lfs merge=lfs -text
|
||||
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
||||
# Audio files - uncompressed
|
||||
*.pcm filter=lfs diff=lfs merge=lfs -text
|
||||
*.sam filter=lfs diff=lfs merge=lfs -text
|
||||
*.raw filter=lfs diff=lfs merge=lfs -text
|
||||
# Audio files - compressed
|
||||
*.aac filter=lfs diff=lfs merge=lfs -text
|
||||
*.flac filter=lfs diff=lfs merge=lfs -text
|
||||
*.mp3 filter=lfs diff=lfs merge=lfs -text
|
||||
*.ogg filter=lfs diff=lfs merge=lfs -text
|
||||
*.wav filter=lfs diff=lfs merge=lfs -text
|
||||
# Image files - uncompressed
|
||||
*.bmp filter=lfs diff=lfs merge=lfs -text
|
||||
*.gif filter=lfs diff=lfs merge=lfs -text
|
||||
*.png filter=lfs diff=lfs merge=lfs -text
|
||||
*.tiff filter=lfs diff=lfs merge=lfs -text
|
||||
# Image files - compressed
|
||||
*.jpg filter=lfs diff=lfs merge=lfs -text
|
||||
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
||||
*.webp filter=lfs diff=lfs merge=lfs -text
|
||||
# Video files - compressed
|
||||
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
||||
*.webm filter=lfs diff=lfs merge=lfs -text
|
||||
ground_truth/dpr/id_map.json filter=lfs diff=lfs merge=lfs -text
|
||||
indices/dpr/dpr_diskann.passages.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/dpr/dpr_diskann.passages.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/dpr/dpr_diskann_disk.index filter=lfs diff=lfs merge=lfs -text
|
||||
indices/dpr/leann.labels.map filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/leann.labels.map filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.index filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.0.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.0.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.1.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.1.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.2.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.2.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.3.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.3.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.4.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.4.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.5.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.5.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.6.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.6.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.7.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.7.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
44
benchmarks/data/README.md
Executable file
@@ -0,0 +1,44 @@
|
||||
---
|
||||
license: mit
|
||||
---
|
||||
|
||||
# LEANN-RAG Evaluation Data
|
||||
|
||||
This repository contains the necessary data to run the recall evaluation scripts for the [LEANN-RAG](https://huggingface.co/LEANN-RAG) project.
|
||||
|
||||
## Dataset Components
|
||||
|
||||
This dataset is structured into three main parts:
|
||||
|
||||
1. **Pre-built LEANN Indices**:
|
||||
* `dpr/`: A pre-built index for the DPR dataset.
|
||||
* `rpj_wiki/`: A pre-built index for the RPJ-Wiki dataset.
|
||||
These indices were created using the `leann-core` library and are required by the `LeannSearcher`.
|
||||
|
||||
2. **Ground Truth Data**:
|
||||
* `ground_truth/`: Contains the ground truth files (`flat_results_nq_k3.json`) for both the DPR and RPJ-Wiki datasets. These files map queries to the original passage IDs from the Natural Questions benchmark, evaluated using the Contriever model.
|
||||
|
||||
3. **Queries**:
|
||||
* `queries/`: Contains the `nq_open.jsonl` file with the Natural Questions queries used for the evaluation.
|
||||
|
||||
## Usage
|
||||
|
||||
To use this data, you can download it locally using the `huggingface-hub` library. First, install the library:
|
||||
|
||||
```bash
|
||||
pip install huggingface-hub
|
||||
```
|
||||
|
||||
Then, you can download the entire dataset to a local directory (e.g., `data/`) with the following Python script:
|
||||
|
||||
```python
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
snapshot_download(
|
||||
repo_id="LEANN-RAG/leann-rag-evaluation-data",
|
||||
repo_type="dataset",
|
||||
local_dir="data"
|
||||
)
|
||||
```
|
||||
|
||||
This will download all the necessary files into a local `data` folder, preserving the repository structure. The evaluation scripts in the main [LEANN-RAG Space](https://huggingface.co/LEANN-RAG) are configured to work with this data structure.
|
||||
141
benchmarks/enron_emails/README.md
Normal file
@@ -0,0 +1,141 @@
|
||||
# Enron Emails Benchmark
|
||||
|
||||
A comprehensive RAG benchmark for evaluating LEANN search and generation on the Enron email corpus. It mirrors the structure and CLI of the existing FinanceBench and LAION benches, using stage-based evaluation with Recall@3 and generation timing.
|
||||
|
||||
- Dataset: Enron email CSV (e.g., Kaggle wcukierski/enron-email-dataset) for passages
|
||||
- Queries: corbt/enron_emails_sample_questions (filtered for realistic questions)
|
||||
- Metrics: Recall@3 vs FAISS Flat baseline + Generation evaluation with Qwen3-8B
|
||||
|
||||
## Layout
|
||||
|
||||
benchmarks/enron_emails/
|
||||
- setup_enron_emails.py: Prepare passages, build LEANN index, build FAISS baseline
|
||||
- evaluate_enron_emails.py: Evaluate retrieval recall (Stages 2-5) + generation with Qwen3-8B
|
||||
- data/: Generated passages, queries, embeddings-related files
|
||||
- baseline/: FAISS Flat baseline files
|
||||
- llm_utils.py: LLM utilities for Qwen3-8B generation (in parent directory)
|
||||
|
||||
## Quickstart
|
||||
|
||||
1) Prepare the data and index
|
||||
|
||||
cd benchmarks/enron_emails
|
||||
python setup_enron_emails.py --data-dir data
|
||||
|
||||
Notes:
|
||||
- If `--emails-csv` is omitted, the script attempts to download from Kaggle dataset `wcukierski/enron-email-dataset` using Kaggle API (requires `KAGGLE_USERNAME` and `KAGGLE_KEY`).
|
||||
Alternatively, pass a local path to `--emails-csv`.
|
||||
|
||||
Notes:
|
||||
- The script parses emails, chunks header/body into passages, builds a compact LEANN index, and then builds a FAISS Flat baseline from the same passages and embedding model.
|
||||
- Optionally, it will also create evaluation queries from HuggingFace dataset `corbt/enron_emails_sample_questions`.
|
||||
|
||||
2) Run recall evaluation (Stage 2)
|
||||
|
||||
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 2
|
||||
|
||||
3) Complexity sweep (Stage 3)
|
||||
|
||||
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 3 --target-recall 0.90 --max-queries 200
|
||||
|
||||
Stage 3 uses binary search over complexity to find the minimal value achieving the target Recall@3 (assumes recall is non-decreasing with complexity). The search expands the upper bound as needed and snaps complexity to multiples of 8.
|
||||
|
||||
4) Index comparison (Stage 4)
|
||||
|
||||
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 4 --complexity 88 --max-queries 100 --output results.json
|
||||
|
||||
5) Generation evaluation (Stage 5)
|
||||
|
||||
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 5 --complexity 88 --llm-backend hf --model-name Qwen/Qwen3-8B
|
||||
|
||||
6) Combined index + generation evaluation (Stages 4+5, recommended)
|
||||
|
||||
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 45 --complexity 88 --llm-backend hf
|
||||
|
||||
Notes:
|
||||
- Minimal CLI: you can run from repo root with only `--index`, defaults match financebench/laion patterns:
|
||||
- `--stage` defaults to `all` (runs 2, 3, 4, 5)
|
||||
- `--baseline-dir` defaults to `baseline`
|
||||
- `--queries` defaults to `data/evaluation_queries.jsonl` (or falls back to the index directory)
|
||||
- `--llm-backend` defaults to `hf` (HuggingFace), can use `vllm`
|
||||
- `--model-name` defaults to `Qwen/Qwen3-8B`
|
||||
- Fail-fast behavior: no silent fallbacks. If compact index cannot run with recompute, it errors out.
|
||||
- Stage 5 requires Stage 4 retrieval results. Use `--stage 45` to run both efficiently.
|
||||
|
||||
Optional flags:
|
||||
- --queries data/evaluation_queries.jsonl (custom queries file)
|
||||
- --baseline-dir baseline (where FAISS baseline lives)
|
||||
- --complexity 88 (LEANN complexity parameter, optimal for 90% recall)
|
||||
- --llm-backend hf|vllm (LLM backend for generation)
|
||||
- --model-name Qwen/Qwen3-8B (LLM model for generation)
|
||||
- --max-queries 1000 (limit number of queries for evaluation)
|
||||
|
||||
## Files Produced
|
||||
- data/enron_passages_preview.jsonl: Small preview of passages used (for inspection)
|
||||
- data/enron_index_hnsw.leann.*: LEANN index files
|
||||
- baseline/faiss_flat.index + baseline/metadata.pkl: FAISS baseline with passage IDs
|
||||
- data/evaluation_queries.jsonl: Query file (id + query; includes GT IDs for reference)
|
||||
|
||||
## Notes
|
||||
- Evaluates both retrieval Recall@3 and generation timing with Qwen3-8B thinking model.
|
||||
- The emails CSV must contain a column named "message" (raw RFC822 email) and a column named "file" for source identifier. Message-ID headers are parsed as canonical message IDs when present.
|
||||
- Qwen3-8B requires special handling for thinking models with chat templates and <think></think> tag processing.
|
||||
|
||||
## Stages Summary
|
||||
|
||||
- Stage 2 (Recall@3):
|
||||
- Compares LEANN vs FAISS Flat baseline on Recall@3.
|
||||
- Compact index runs with `recompute_embeddings=True`.
|
||||
|
||||
- Stage 3 (Binary Search for Complexity):
|
||||
- Builds a non-compact index (`<index>_noncompact.leann`) and runs binary search with `recompute_embeddings=False` to find the minimal complexity achieving target Recall@3 (default 90%).
|
||||
|
||||
- Stage 4 (Index Comparison):
|
||||
- Reports .index-only sizes for compact vs non-compact.
|
||||
- Measures timings on queries by default: non-compact (no recompute) vs compact (with recompute).
|
||||
- Stores retrieval results for Stage 5 generation evaluation.
|
||||
- Fails fast if compact recompute cannot run.
|
||||
- If `--complexity` is not provided, the script tries to use the best complexity from Stage 3:
|
||||
- First from the current run (when running `--stage all`), otherwise
|
||||
- From `enron_stage3_results.json` saved next to the index during the last Stage 3 run.
|
||||
- If neither exists, Stage 4 will error and ask you to run Stage 3 or pass `--complexity`.
|
||||
|
||||
- Stage 5 (Generation Evaluation):
|
||||
- Uses Qwen3-8B thinking model for RAG generation on retrieved documents from Stage 4.
|
||||
- Supports HuggingFace (`hf`) and vLLM (`vllm`) backends.
|
||||
- Measures generation timing separately from search timing.
|
||||
- Requires Stage 4 results (no additional searching performed).
|
||||
|
||||
## Example Results
|
||||
|
||||
These are sample results obtained on Enron data using all-mpnet-base-v2 and Qwen3-8B.
|
||||
|
||||
- Stage 3 (Binary Search):
|
||||
- Minimal complexity achieving 90% Recall@3: 88
|
||||
- Sampled points:
|
||||
- C=8 → 59.9% Recall@3
|
||||
- C=72 → 89.4% Recall@3
|
||||
- C=88 → 90.2% Recall@3
|
||||
- C=96 → 90.7% Recall@3
|
||||
- C=112 → 91.1% Recall@3
|
||||
- C=136 → 91.3% Recall@3
|
||||
- C=256 → 92.0% Recall@3
|
||||
|
||||
- Stage 4 (Index Sizes, .index only):
|
||||
- Compact: ~2.2 MB
|
||||
- Non-compact: ~82.0 MB
|
||||
- Storage saving by compact: ~97.3%
|
||||
|
||||
- Stage 4 (Search Timing, 988 queries, complexity=88):
|
||||
- Non-compact (no recompute): ~0.0075 s avg per query
|
||||
- Compact (with recompute): ~1.981 s avg per query
|
||||
- Speed ratio (non-compact/compact): ~0.0038x
|
||||
|
||||
- Stage 5 (RAG Generation, 988 queries, Qwen3-8B):
|
||||
- Average generation time: ~22.302 s per query
|
||||
- Total queries processed: 988
|
||||
- LLM backend: HuggingFace transformers
|
||||
- Model: Qwen/Qwen3-8B (thinking model with <think></think> processing)
|
||||
|
||||
Full JSON output is saved by the script (see `--output`), e.g.:
|
||||
`benchmarks/enron_emails/results_enron_stage45.json`.
|
||||
1
benchmarks/enron_emails/data/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
downloads/
|
||||
614
benchmarks/enron_emails/evaluate_enron_emails.py
Normal file
@@ -0,0 +1,614 @@
|
||||
"""
|
||||
Enron Emails Benchmark Evaluation - Retrieval Recall@3 (Stages 2/3/4)
|
||||
Follows the style of FinanceBench/LAION: Stage 2 recall vs FAISS baseline,
|
||||
Stage 3 complexity sweep to target recall, Stage 4 index comparison.
|
||||
On errors, fail fast without fallbacks.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from leann import LeannBuilder, LeannSearcher
|
||||
from leann_backend_hnsw import faiss
|
||||
|
||||
from ..llm_utils import generate_hf, generate_vllm, load_hf_model, load_vllm_model
|
||||
|
||||
# Setup logging to reduce verbose output
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
logging.getLogger("leann.api").setLevel(logging.WARNING)
|
||||
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
class RecallEvaluator:
|
||||
"""Stage 2: Evaluate Recall@3 (LEANN vs FAISS)"""
|
||||
|
||||
def __init__(self, index_path: str, baseline_dir: str):
|
||||
self.index_path = index_path
|
||||
self.baseline_dir = baseline_dir
|
||||
self.searcher = LeannSearcher(index_path)
|
||||
|
||||
baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
|
||||
metadata_path = os.path.join(baseline_dir, "metadata.pkl")
|
||||
|
||||
self.faiss_index = faiss.read_index(baseline_index_path)
|
||||
with open(metadata_path, "rb") as f:
|
||||
self.passage_ids = pickle.load(f)
|
||||
|
||||
print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} vectors")
|
||||
|
||||
# No fallbacks here; if embedding server is needed but fails, the caller will see the error.
|
||||
|
||||
def evaluate_recall_at_3(
|
||||
self, queries: list[str], complexity: int = 64, recompute_embeddings: bool = True
|
||||
) -> float:
|
||||
"""Evaluate recall@3 using FAISS Flat as ground truth"""
|
||||
from leann.api import compute_embeddings
|
||||
|
||||
recompute_str = "with recompute" if recompute_embeddings else "no recompute"
|
||||
print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")
|
||||
|
||||
total_recall = 0.0
|
||||
for i, query in enumerate(queries):
|
||||
# Compute query embedding with the same model/mode as the index
|
||||
q_emb = compute_embeddings(
|
||||
[query],
|
||||
self.searcher.embedding_model,
|
||||
mode=self.searcher.embedding_mode,
|
||||
use_server=False,
|
||||
).astype(np.float32)
|
||||
|
||||
# Search FAISS Flat ground truth
|
||||
n = q_emb.shape[0]
|
||||
k = 3
|
||||
distances = np.zeros((n, k), dtype=np.float32)
|
||||
labels = np.zeros((n, k), dtype=np.int64)
|
||||
self.faiss_index.search(
|
||||
n,
|
||||
faiss.swig_ptr(q_emb),
|
||||
k,
|
||||
faiss.swig_ptr(distances),
|
||||
faiss.swig_ptr(labels),
|
||||
)
|
||||
|
||||
baseline_ids = {self.passage_ids[idx] for idx in labels[0]}
|
||||
|
||||
# Search with LEANN (may require embedding server depending on index configuration)
|
||||
results = self.searcher.search(
|
||||
query,
|
||||
top_k=3,
|
||||
complexity=complexity,
|
||||
recompute_embeddings=recompute_embeddings,
|
||||
)
|
||||
test_ids = {r.id for r in results}
|
||||
|
||||
intersection = test_ids.intersection(baseline_ids)
|
||||
recall = len(intersection) / 3.0
|
||||
total_recall += recall
|
||||
|
||||
if i < 3:
|
||||
print(f" Q{i + 1}: '{query[:60]}...' -> Recall@3: {recall:.3f}")
|
||||
print(f" FAISS: {list(baseline_ids)}")
|
||||
print(f" LEANN: {list(test_ids)}")
|
||||
print(f" ∩: {list(intersection)}")
|
||||
|
||||
avg = total_recall / max(1, len(queries))
|
||||
print(f"📊 Average Recall@3: {avg:.3f} ({avg * 100:.1f}%)")
|
||||
return avg
|
||||
|
||||
def cleanup(self):
|
||||
if hasattr(self, "searcher"):
|
||||
self.searcher.cleanup()
|
||||
|
||||
|
||||
class EnronEvaluator:
|
||||
def __init__(self, index_path: str):
|
||||
self.index_path = index_path
|
||||
self.searcher = LeannSearcher(index_path)
|
||||
|
||||
def load_queries(self, queries_file: str) -> list[str]:
|
||||
queries: list[str] = []
|
||||
with open(queries_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if not line.strip():
|
||||
continue
|
||||
data = json.loads(line)
|
||||
if "query" in data:
|
||||
queries.append(data["query"])
|
||||
print(f"📊 Loaded {len(queries)} queries from {queries_file}")
|
||||
return queries
|
||||
|
||||
def cleanup(self):
|
||||
if self.searcher:
|
||||
self.searcher.cleanup()
|
||||
|
||||
def analyze_index_sizes(self) -> dict:
|
||||
"""Analyze index sizes (.index only), similar to LAION bench."""
|
||||
|
||||
print("📏 Analyzing index sizes (.index only)...")
|
||||
index_path = Path(self.index_path)
|
||||
index_dir = index_path.parent
|
||||
index_name = index_path.stem
|
||||
|
||||
sizes: dict[str, float] = {}
|
||||
index_file = index_dir / f"{index_name}.index"
|
||||
meta_file = index_dir / f"{index_path.name}.meta.json"
|
||||
passages_file = index_dir / f"{index_path.name}.passages.jsonl"
|
||||
passages_idx_file = index_dir / f"{index_path.name}.passages.idx"
|
||||
|
||||
sizes["index_only_mb"] = (
|
||||
index_file.stat().st_size / (1024 * 1024) if index_file.exists() else 0.0
|
||||
)
|
||||
sizes["metadata_mb"] = (
|
||||
meta_file.stat().st_size / (1024 * 1024) if meta_file.exists() else 0.0
|
||||
)
|
||||
sizes["passages_text_mb"] = (
|
||||
passages_file.stat().st_size / (1024 * 1024) if passages_file.exists() else 0.0
|
||||
)
|
||||
sizes["passages_index_mb"] = (
|
||||
passages_idx_file.stat().st_size / (1024 * 1024) if passages_idx_file.exists() else 0.0
|
||||
)
|
||||
|
||||
print(f" 📁 .index size: {sizes['index_only_mb']:.1f} MB")
|
||||
return sizes
|
||||
|
||||
def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
|
||||
"""Create a non-compact index for comparison using current passages and embeddings."""
|
||||
|
||||
current_index_path = Path(self.index_path)
|
||||
current_index_dir = current_index_path.parent
|
||||
current_index_name = current_index_path.name
|
||||
|
||||
# Read metadata to get passage source and embedding model
|
||||
meta_path = current_index_dir / f"{current_index_name}.meta.json"
|
||||
with open(meta_path, encoding="utf-8") as f:
|
||||
meta = json.load(f)
|
||||
|
||||
passage_source = meta["passage_sources"][0]
|
||||
passage_file = passage_source["path"]
|
||||
|
||||
# Convert relative path to absolute
|
||||
if not Path(passage_file).is_absolute():
|
||||
passage_file = current_index_dir / Path(passage_file).name
|
||||
|
||||
# Load all passages and ids
|
||||
ids: list[str] = []
|
||||
texts: list[str] = []
|
||||
with open(passage_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
ids.append(str(data["id"]))
|
||||
texts.append(data["text"])
|
||||
|
||||
# Compute embeddings using the same method as LEANN
|
||||
from leann.api import compute_embeddings
|
||||
|
||||
embeddings = compute_embeddings(
|
||||
texts,
|
||||
meta["embedding_model"],
|
||||
mode=meta.get("embedding_mode", "sentence-transformers"),
|
||||
use_server=False,
|
||||
).astype(np.float32)
|
||||
|
||||
# Build non-compact index with same passages and embeddings
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=meta["embedding_model"],
|
||||
embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
|
||||
is_recompute=False,
|
||||
is_compact=False,
|
||||
**{
|
||||
k: v
|
||||
for k, v in meta.get("backend_kwargs", {}).items()
|
||||
if k not in ["is_recompute", "is_compact"]
|
||||
},
|
||||
)
|
||||
|
||||
# Persist a pickle for build_index_from_embeddings
|
||||
pkl_path = current_index_dir / f"{Path(non_compact_index_path).stem}_embeddings.pkl"
|
||||
with open(pkl_path, "wb") as pf:
|
||||
pickle.dump((ids, embeddings), pf)
|
||||
|
||||
print(
|
||||
f"🔨 Building non-compact index at {non_compact_index_path} from precomputed embeddings..."
|
||||
)
|
||||
builder.build_index_from_embeddings(non_compact_index_path, str(pkl_path))
|
||||
|
||||
# Analyze the non-compact index size
|
||||
temp_evaluator = EnronEvaluator(non_compact_index_path)
|
||||
non_compact_sizes = temp_evaluator.analyze_index_sizes()
|
||||
non_compact_sizes["index_type"] = "non_compact"
|
||||
|
||||
return non_compact_sizes
|
||||
|
||||
def compare_index_performance(
|
||||
self, non_compact_path: str, compact_path: str, test_queries: list[str], complexity: int
|
||||
) -> dict:
|
||||
"""Compare search speed for non-compact vs compact indexes."""
|
||||
import time
|
||||
|
||||
results: dict = {
|
||||
"non_compact": {"search_times": []},
|
||||
"compact": {"search_times": []},
|
||||
"avg_search_times": {},
|
||||
"speed_ratio": 0.0,
|
||||
"retrieval_results": [], # Store retrieval results for Stage 5
|
||||
}
|
||||
|
||||
print("⚡ Comparing search performance between indexes...")
|
||||
# Non-compact (no recompute)
|
||||
print(" 🔍 Testing non-compact index (no recompute)...")
|
||||
non_compact_searcher = LeannSearcher(non_compact_path)
|
||||
for q in test_queries:
|
||||
t0 = time.time()
|
||||
_ = non_compact_searcher.search(
|
||||
q, top_k=3, complexity=complexity, recompute_embeddings=False
|
||||
)
|
||||
results["non_compact"]["search_times"].append(time.time() - t0)
|
||||
|
||||
# Compact (with recompute). Fail fast if it cannot run.
|
||||
print(" 🔍 Testing compact index (with recompute)...")
|
||||
compact_searcher = LeannSearcher(compact_path)
|
||||
for q in test_queries:
|
||||
t0 = time.time()
|
||||
docs = compact_searcher.search(
|
||||
q, top_k=3, complexity=complexity, recompute_embeddings=True
|
||||
)
|
||||
results["compact"]["search_times"].append(time.time() - t0)
|
||||
|
||||
# Store retrieval results for Stage 5
|
||||
results["retrieval_results"].append(
|
||||
{"query": q, "retrieved_docs": [{"id": doc.id, "text": doc.text} for doc in docs]}
|
||||
)
|
||||
compact_searcher.cleanup()
|
||||
|
||||
if results["non_compact"]["search_times"]:
|
||||
results["avg_search_times"]["non_compact"] = sum(
|
||||
results["non_compact"]["search_times"]
|
||||
) / len(results["non_compact"]["search_times"])
|
||||
if results["compact"]["search_times"]:
|
||||
results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
|
||||
results["compact"]["search_times"]
|
||||
)
|
||||
if results["avg_search_times"].get("compact", 0) > 0:
|
||||
results["speed_ratio"] = (
|
||||
results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
|
||||
)
|
||||
else:
|
||||
results["speed_ratio"] = 0.0
|
||||
|
||||
non_compact_searcher.cleanup()
|
||||
return results
|
||||
|
||||
def evaluate_complexity(
|
||||
self,
|
||||
recall_eval: "RecallEvaluator",
|
||||
queries: list[str],
|
||||
target: float = 0.90,
|
||||
c_min: int = 8,
|
||||
c_max: int = 256,
|
||||
max_iters: int = 10,
|
||||
recompute: bool = False,
|
||||
) -> dict:
|
||||
"""Binary search minimal complexity achieving target recall (monotonic assumption)."""
|
||||
|
||||
def round_c(x: int) -> int:
|
||||
# snap to multiple of 8 like other benches typically do
|
||||
return max(1, int((x + 7) // 8) * 8)
|
||||
|
||||
metrics: list[dict] = []
|
||||
|
||||
lo = round_c(c_min)
|
||||
hi = round_c(c_max)
|
||||
|
||||
print(
|
||||
f"🧪 Binary search complexity in [{lo}, {hi}] for target Recall@3>={int(target * 100)}%..."
|
||||
)
|
||||
|
||||
# Ensure upper bound can reach target; expand if needed (up to a cap)
|
||||
r_lo = recall_eval.evaluate_recall_at_3(
|
||||
queries, complexity=lo, recompute_embeddings=recompute
|
||||
)
|
||||
metrics.append({"complexity": lo, "recall_at_3": r_lo})
|
||||
r_hi = recall_eval.evaluate_recall_at_3(
|
||||
queries, complexity=hi, recompute_embeddings=recompute
|
||||
)
|
||||
metrics.append({"complexity": hi, "recall_at_3": r_hi})
|
||||
|
||||
cap = 1024
|
||||
while r_hi < target and hi < cap:
|
||||
lo = hi
|
||||
r_lo = r_hi
|
||||
hi = round_c(hi * 2)
|
||||
r_hi = recall_eval.evaluate_recall_at_3(
|
||||
queries, complexity=hi, recompute_embeddings=recompute
|
||||
)
|
||||
metrics.append({"complexity": hi, "recall_at_3": r_hi})
|
||||
|
||||
if r_hi < target:
|
||||
print(f"⚠️ Max complexity {hi} did not reach target recall {target:.2f}.")
|
||||
print("📈 Observations:")
|
||||
for m in metrics:
|
||||
print(f" C={m['complexity']:>4} -> Recall@3={m['recall_at_3'] * 100:.1f}%")
|
||||
return {"metrics": metrics, "best_complexity": None, "target_recall": target}
|
||||
|
||||
# Binary search within [lo, hi]
|
||||
best = hi
|
||||
iters = 0
|
||||
while lo < hi and iters < max_iters:
|
||||
mid = round_c((lo + hi) // 2)
|
||||
r_mid = recall_eval.evaluate_recall_at_3(
|
||||
queries, complexity=mid, recompute_embeddings=recompute
|
||||
)
|
||||
metrics.append({"complexity": mid, "recall_at_3": r_mid})
|
||||
if r_mid >= target:
|
||||
best = mid
|
||||
hi = mid
|
||||
else:
|
||||
lo = mid + 8 # move past mid, respecting multiple-of-8 step
|
||||
iters += 1
|
||||
|
||||
print("📈 Binary search results (sampled points):")
|
||||
# Print unique complexity entries ordered by complexity
|
||||
for m in sorted(
|
||||
{m["complexity"]: m for m in metrics}.values(), key=lambda x: x["complexity"]
|
||||
):
|
||||
print(f" C={m['complexity']:>4} -> Recall@3={m['recall_at_3'] * 100:.1f}%")
|
||||
print(f"✅ Minimal complexity achieving {int(target * 100)}% recall: {best}")
|
||||
return {"metrics": metrics, "best_complexity": best, "target_recall": target}
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Enron Emails Benchmark Evaluation")
|
||||
parser.add_argument("--index", required=True, help="Path to LEANN index")
|
||||
parser.add_argument(
|
||||
"--queries", default="data/evaluation_queries.jsonl", help="Path to evaluation queries"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stage",
|
||||
choices=["2", "3", "4", "5", "all", "45"],
|
||||
default="all",
|
||||
help="Which stage to run (2=recall, 3=complexity, 4=index comparison, 5=generation)",
|
||||
)
|
||||
parser.add_argument("--complexity", type=int, default=None, help="LEANN search complexity")
|
||||
parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
|
||||
parser.add_argument(
|
||||
"--max-queries", type=int, help="Limit number of queries to evaluate", default=1000
|
||||
)
|
||||
parser.add_argument(
|
||||
"--target-recall", type=float, default=0.90, help="Target Recall@3 for Stage 3"
|
||||
)
|
||||
parser.add_argument("--output", help="Save results to JSON file")
|
||||
parser.add_argument("--llm-backend", choices=["hf", "vllm"], default="hf", help="LLM backend")
|
||||
parser.add_argument("--model-name", default="Qwen/Qwen3-8B", help="Model name")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Resolve queries file: if default path not found, fall back to index's directory
|
||||
if not os.path.exists(args.queries):
|
||||
from pathlib import Path
|
||||
|
||||
idx_dir = Path(args.index).parent
|
||||
fallback_q = idx_dir / "evaluation_queries.jsonl"
|
||||
if fallback_q.exists():
|
||||
args.queries = str(fallback_q)
|
||||
|
||||
baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
|
||||
if not os.path.exists(baseline_index_path):
|
||||
print(f"❌ FAISS baseline not found at {baseline_index_path}")
|
||||
print("💡 Please run setup_enron_emails.py first to build the baseline")
|
||||
raise SystemExit(1)
|
||||
|
||||
results_out: dict = {}
|
||||
|
||||
if args.stage in ("2", "all"):
|
||||
print("🚀 Starting Stage 2: Recall@3 evaluation")
|
||||
evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||
|
||||
enron_eval = EnronEvaluator(args.index)
|
||||
queries = enron_eval.load_queries(args.queries)
|
||||
queries = queries[:10]
|
||||
print(f"🧪 Using first {len(queries)} queries")
|
||||
|
||||
complexity = args.complexity or 64
|
||||
r = evaluator.evaluate_recall_at_3(queries, complexity)
|
||||
results_out["stage2"] = {"complexity": complexity, "recall_at_3": r}
|
||||
evaluator.cleanup()
|
||||
enron_eval.cleanup()
|
||||
print("✅ Stage 2 completed!\n")
|
||||
|
||||
if args.stage in ("3", "all"):
|
||||
print("🚀 Starting Stage 3: Binary search for target recall (no recompute)")
|
||||
enron_eval = EnronEvaluator(args.index)
|
||||
queries = enron_eval.load_queries(args.queries)
|
||||
queries = queries[: args.max_queries]
|
||||
print(f"🧪 Using first {len(queries)} queries")
|
||||
|
||||
# Build non-compact index for fast binary search (recompute_embeddings=False)
|
||||
from pathlib import Path
|
||||
|
||||
index_path = Path(args.index)
|
||||
non_compact_index_path = str(index_path.parent / f"{index_path.stem}_noncompact.leann")
|
||||
enron_eval.create_non_compact_index_for_comparison(non_compact_index_path)
|
||||
|
||||
# Use non-compact evaluator for binary search with recompute=False
|
||||
evaluator_nc = RecallEvaluator(non_compact_index_path, args.baseline_dir)
|
||||
sweep = enron_eval.evaluate_complexity(
|
||||
evaluator_nc, queries, target=args.target_recall, recompute=False
|
||||
)
|
||||
results_out["stage3"] = sweep
|
||||
# Persist default stage 3 results near the index for Stage 4 auto-pickup
|
||||
from pathlib import Path
|
||||
|
||||
default_stage3_path = Path(args.index).parent / "enron_stage3_results.json"
|
||||
with open(default_stage3_path, "w", encoding="utf-8") as f:
|
||||
json.dump({"stage3": sweep}, f, indent=2)
|
||||
print(f"📝 Saved Stage 3 summary to {default_stage3_path}")
|
||||
evaluator_nc.cleanup()
|
||||
enron_eval.cleanup()
|
||||
print("✅ Stage 3 completed!\n")
|
||||
|
||||
if args.stage in ("4", "all", "45"):
|
||||
print("🚀 Starting Stage 4: Index size + performance comparison")
|
||||
evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||
enron_eval = EnronEvaluator(args.index)
|
||||
queries = enron_eval.load_queries(args.queries)
|
||||
test_q = queries[: min(args.max_queries, len(queries))]
|
||||
|
||||
current_sizes = enron_eval.analyze_index_sizes()
|
||||
# Build non-compact index for comparison (no fallback)
|
||||
from pathlib import Path
|
||||
|
||||
index_path = Path(args.index)
|
||||
non_compact_path = str(index_path.parent / f"{index_path.stem}_noncompact.leann")
|
||||
non_compact_sizes = enron_eval.create_non_compact_index_for_comparison(non_compact_path)
|
||||
nc_eval = EnronEvaluator(non_compact_path)
|
||||
|
||||
if (
|
||||
current_sizes.get("index_only_mb", 0) > 0
|
||||
and non_compact_sizes.get("index_only_mb", 0) > 0
|
||||
):
|
||||
storage_saving_percent = max(
|
||||
0.0,
|
||||
100.0 * (1.0 - current_sizes["index_only_mb"] / non_compact_sizes["index_only_mb"]),
|
||||
)
|
||||
else:
|
||||
storage_saving_percent = 0.0
|
||||
|
||||
if args.complexity is None:
|
||||
# Prefer in-session Stage 3 result
|
||||
if "stage3" in results_out and results_out["stage3"].get("best_complexity") is not None:
|
||||
complexity = results_out["stage3"]["best_complexity"]
|
||||
print(f"📥 Using best complexity from Stage 3 in-session: {complexity}")
|
||||
else:
|
||||
# Try to load last saved Stage 3 result near index
|
||||
default_stage3_path = Path(args.index).parent / "enron_stage3_results.json"
|
||||
if default_stage3_path.exists():
|
||||
with open(default_stage3_path, encoding="utf-8") as f:
|
||||
prev = json.load(f)
|
||||
complexity = prev.get("stage3", {}).get("best_complexity")
|
||||
if complexity is None:
|
||||
raise SystemExit(
|
||||
"❌ Stage 4: No --complexity and no best_complexity found in saved Stage 3 results"
|
||||
)
|
||||
print(f"📥 Using best complexity from saved Stage 3: {complexity}")
|
||||
else:
|
||||
raise SystemExit(
|
||||
"❌ Stage 4 requires --complexity if Stage 3 hasn't been run. Run stage 3 first or pass --complexity."
|
||||
)
|
||||
else:
|
||||
complexity = args.complexity
|
||||
|
||||
comp = enron_eval.compare_index_performance(
|
||||
non_compact_path, args.index, test_q, complexity=complexity
|
||||
)
|
||||
results_out["stage4"] = {
|
||||
"current_index": current_sizes,
|
||||
"non_compact_index": non_compact_sizes,
|
||||
"storage_saving_percent": storage_saving_percent,
|
||||
"performance_comparison": comp,
|
||||
}
|
||||
nc_eval.cleanup()
|
||||
evaluator.cleanup()
|
||||
enron_eval.cleanup()
|
||||
print("✅ Stage 4 completed!\n")
|
||||
|
||||
if args.stage in ("5", "all"):
|
||||
print("🚀 Starting Stage 5: Generation evaluation with Qwen3-8B")
|
||||
|
||||
# Check if Stage 4 results exist
|
||||
if "stage4" not in results_out or "performance_comparison" not in results_out["stage4"]:
|
||||
print("❌ Stage 5 requires Stage 4 retrieval results")
|
||||
print("💡 Run Stage 4 first or use --stage all")
|
||||
raise SystemExit(1)
|
||||
|
||||
retrieval_results = results_out["stage4"]["performance_comparison"]["retrieval_results"]
|
||||
if not retrieval_results:
|
||||
print("❌ No retrieval results found from Stage 4")
|
||||
raise SystemExit(1)
|
||||
|
||||
print(f"📁 Using {len(retrieval_results)} retrieval results from Stage 4")
|
||||
|
||||
# Load LLM
|
||||
try:
|
||||
if args.llm_backend == "hf":
|
||||
tokenizer, model = load_hf_model(args.model_name)
|
||||
|
||||
def llm_func(prompt):
|
||||
return generate_hf(tokenizer, model, prompt)
|
||||
else: # vllm
|
||||
llm, sampling_params = load_vllm_model(args.model_name)
|
||||
|
||||
def llm_func(prompt):
|
||||
return generate_vllm(llm, sampling_params, prompt)
|
||||
|
||||
# Run generation using stored retrieval results
|
||||
import time
|
||||
|
||||
from llm_utils import create_prompt
|
||||
|
||||
generation_times = []
|
||||
responses = []
|
||||
|
||||
print("🤖 Running generation on pre-retrieved results...")
|
||||
for i, item in enumerate(retrieval_results):
|
||||
query = item["query"]
|
||||
retrieved_docs = item["retrieved_docs"]
|
||||
|
||||
# Prepare context from retrieved docs
|
||||
context = "\n\n".join([doc["text"] for doc in retrieved_docs])
|
||||
prompt = create_prompt(context, query, "emails")
|
||||
|
||||
# Time generation only
|
||||
gen_start = time.time()
|
||||
response = llm_func(prompt)
|
||||
gen_time = time.time() - gen_start
|
||||
|
||||
generation_times.append(gen_time)
|
||||
responses.append(response)
|
||||
|
||||
if i < 3:
|
||||
print(f" Q{i + 1}: Gen={gen_time:.3f}s")
|
||||
|
||||
avg_gen_time = sum(generation_times) / len(generation_times)
|
||||
|
||||
print("\n📊 Generation Results:")
|
||||
print(f" Total Queries: {len(retrieval_results)}")
|
||||
print(f" Avg Generation Time: {avg_gen_time:.3f}s")
|
||||
print(" (Search time from Stage 4)")
|
||||
|
||||
results_out["stage5"] = {
|
||||
"total_queries": len(retrieval_results),
|
||||
"avg_generation_time": avg_gen_time,
|
||||
"generation_times": generation_times,
|
||||
"responses": responses,
|
||||
}
|
||||
|
||||
# Show sample results
|
||||
print("\n📝 Sample Results:")
|
||||
for i in range(min(3, len(retrieval_results))):
|
||||
query = retrieval_results[i]["query"]
|
||||
response = responses[i]
|
||||
print(f" Q{i + 1}: {query[:60]}...")
|
||||
print(f" A{i + 1}: {response[:100]}...")
|
||||
print()
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Generation evaluation failed: {e}")
|
||||
print("💡 Make sure transformers/vllm is installed and model is available")
|
||||
|
||||
print("✅ Stage 5 completed!\n")
|
||||
|
||||
if args.output and results_out:
|
||||
with open(args.output, "w", encoding="utf-8") as f:
|
||||
json.dump(results_out, f, indent=2)
|
||||
print(f"📝 Saved results to {args.output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
359
benchmarks/enron_emails/setup_enron_emails.py
Normal file
@@ -0,0 +1,359 @@
|
||||
"""
|
||||
Enron Emails Benchmark Setup Script
|
||||
Prepares passages from emails.csv, builds LEANN index, and FAISS Flat baseline
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from collections.abc import Iterable
|
||||
from email import message_from_string
|
||||
from email.policy import default
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from leann import LeannBuilder
|
||||
|
||||
|
||||
class EnronSetup:
|
||||
def __init__(self, data_dir: str = "data"):
|
||||
self.data_dir = Path(data_dir)
|
||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.passages_preview = self.data_dir / "enron_passages_preview.jsonl"
|
||||
self.index_path = self.data_dir / "enron_index_hnsw.leann"
|
||||
self.queries_file = self.data_dir / "evaluation_queries.jsonl"
|
||||
self.downloads_dir = self.data_dir / "downloads"
|
||||
self.downloads_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# ----------------------------
|
||||
# Dataset acquisition
|
||||
# ----------------------------
|
||||
def ensure_emails_csv(self, emails_csv: Optional[str]) -> str:
|
||||
"""Return a path to emails.csv, downloading from Kaggle if needed."""
|
||||
if emails_csv:
|
||||
p = Path(emails_csv)
|
||||
if not p.exists():
|
||||
raise FileNotFoundError(f"emails.csv not found: {emails_csv}")
|
||||
return str(p)
|
||||
|
||||
print(
|
||||
"📥 Trying to download Enron emails.csv from Kaggle (wcukierski/enron-email-dataset)..."
|
||||
)
|
||||
try:
|
||||
from kaggle.api.kaggle_api_extended import KaggleApi
|
||||
|
||||
api = KaggleApi()
|
||||
api.authenticate()
|
||||
api.dataset_download_files(
|
||||
"wcukierski/enron-email-dataset", path=str(self.downloads_dir), unzip=True
|
||||
)
|
||||
candidate = self.downloads_dir / "emails.csv"
|
||||
if candidate.exists():
|
||||
print(f"✅ Downloaded emails.csv: {candidate}")
|
||||
return str(candidate)
|
||||
else:
|
||||
raise FileNotFoundError(
|
||||
f"emails.csv was not found in {self.downloads_dir} after Kaggle download"
|
||||
)
|
||||
except Exception as e:
|
||||
print(
|
||||
"❌ Could not download via Kaggle automatically. Provide --emails-csv or configure Kaggle API."
|
||||
)
|
||||
print(
|
||||
" Set KAGGLE_USERNAME and KAGGLE_KEY env vars, or place emails.csv locally and pass --emails-csv."
|
||||
)
|
||||
raise e
|
||||
|
||||
# ----------------------------
|
||||
# Data preparation
|
||||
# ----------------------------
|
||||
@staticmethod
|
||||
def _extract_message_id(raw_email: str) -> str:
|
||||
msg = message_from_string(raw_email, policy=default)
|
||||
val = msg.get("Message-ID", "")
|
||||
if val.startswith("<") and val.endswith(">"):
|
||||
val = val[1:-1]
|
||||
return val or ""
|
||||
|
||||
@staticmethod
|
||||
def _split_header_body(raw_email: str) -> tuple[str, str]:
|
||||
parts = raw_email.split("\n\n", 1)
|
||||
if len(parts) == 2:
|
||||
return parts[0].strip(), parts[1].strip()
|
||||
# Heuristic fallback
|
||||
first_lines = raw_email.splitlines()
|
||||
if first_lines and ":" in first_lines[0]:
|
||||
return raw_email.strip(), ""
|
||||
return "", raw_email.strip()
|
||||
|
||||
@staticmethod
|
||||
def _split_fixed_words(text: str, chunk_words: int, keep_last: bool) -> list[str]:
|
||||
text = (text or "").strip()
|
||||
if not text:
|
||||
return []
|
||||
if chunk_words <= 0:
|
||||
return [text]
|
||||
words = text.split()
|
||||
if not words:
|
||||
return []
|
||||
limit = len(words)
|
||||
if not keep_last:
|
||||
limit = (len(words) // chunk_words) * chunk_words
|
||||
if limit == 0:
|
||||
return []
|
||||
chunks = [" ".join(words[i : i + chunk_words]) for i in range(0, limit, chunk_words)]
|
||||
return [c for c in (s.strip() for s in chunks) if c]
|
||||
|
||||
def _iter_passages_from_csv(
|
||||
self,
|
||||
emails_csv: Path,
|
||||
chunk_words: int = 256,
|
||||
keep_last_header: bool = True,
|
||||
keep_last_body: bool = True,
|
||||
max_emails: int | None = None,
|
||||
) -> Iterable[dict]:
|
||||
with open(emails_csv, encoding="utf-8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
count = 0
|
||||
for i, row in enumerate(reader):
|
||||
if max_emails is not None and count >= max_emails:
|
||||
break
|
||||
|
||||
raw_message = row.get("message", "")
|
||||
email_file_id = row.get("file", "")
|
||||
|
||||
if not raw_message.strip():
|
||||
continue
|
||||
|
||||
message_id = self._extract_message_id(raw_message)
|
||||
if not message_id:
|
||||
# Fallback ID based on CSV position and file path
|
||||
safe_file = re.sub(r"[^A-Za-z0-9_.-]", "_", email_file_id)
|
||||
message_id = f"enron_{i}_{safe_file}"
|
||||
|
||||
header, body = self._split_header_body(raw_message)
|
||||
|
||||
# Header chunks
|
||||
for chunk in self._split_fixed_words(header, chunk_words, keep_last_header):
|
||||
yield {
|
||||
"text": chunk,
|
||||
"metadata": {
|
||||
"message_id": message_id,
|
||||
"is_header": True,
|
||||
"email_file_id": email_file_id,
|
||||
},
|
||||
}
|
||||
|
||||
# Body chunks
|
||||
for chunk in self._split_fixed_words(body, chunk_words, keep_last_body):
|
||||
yield {
|
||||
"text": chunk,
|
||||
"metadata": {
|
||||
"message_id": message_id,
|
||||
"is_header": False,
|
||||
"email_file_id": email_file_id,
|
||||
},
|
||||
}
|
||||
|
||||
count += 1
|
||||
|
||||
# ----------------------------
|
||||
# Build LEANN index and FAISS baseline
|
||||
# ----------------------------
|
||||
def build_leann_index(
|
||||
self,
|
||||
emails_csv: Optional[str],
|
||||
backend: str = "hnsw",
|
||||
embedding_model: str = "sentence-transformers/all-mpnet-base-v2",
|
||||
chunk_words: int = 256,
|
||||
max_emails: int | None = None,
|
||||
) -> str:
|
||||
emails_csv_path = self.ensure_emails_csv(emails_csv)
|
||||
print(f"🏗️ Building LEANN index from {emails_csv_path}...")
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name=backend,
|
||||
embedding_model=embedding_model,
|
||||
embedding_mode="sentence-transformers",
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_recompute=True,
|
||||
is_compact=True,
|
||||
num_threads=4,
|
||||
)
|
||||
|
||||
# Stream passages and add to builder
|
||||
preview_written = 0
|
||||
with open(self.passages_preview, "w", encoding="utf-8") as preview_out:
|
||||
for p in self._iter_passages_from_csv(
|
||||
Path(emails_csv_path), chunk_words=chunk_words, max_emails=max_emails
|
||||
):
|
||||
builder.add_text(p["text"], metadata=p["metadata"])
|
||||
if preview_written < 200:
|
||||
preview_out.write(json.dumps({"text": p["text"][:200], **p["metadata"]}) + "\n")
|
||||
preview_written += 1
|
||||
|
||||
print(f"🔨 Building index at {self.index_path}...")
|
||||
builder.build_index(str(self.index_path))
|
||||
print("✅ LEANN index built!")
|
||||
return str(self.index_path)
|
||||
|
||||
def build_faiss_flat_baseline(self, index_path: str, output_dir: str = "baseline") -> str:
|
||||
print("🔨 Building FAISS Flat baseline from LEANN passages...")
|
||||
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
from leann.api import compute_embeddings
|
||||
from leann_backend_hnsw import faiss
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
baseline_path = os.path.join(output_dir, "faiss_flat.index")
|
||||
metadata_path = os.path.join(output_dir, "metadata.pkl")
|
||||
|
||||
if os.path.exists(baseline_path) and os.path.exists(metadata_path):
|
||||
print(f"✅ Baseline already exists at {baseline_path}")
|
||||
return baseline_path
|
||||
|
||||
# Read meta for passage source and embedding model
|
||||
meta_path = f"{index_path}.meta.json"
|
||||
with open(meta_path, encoding="utf-8") as f:
|
||||
meta = json.load(f)
|
||||
|
||||
embedding_model = meta["embedding_model"]
|
||||
passage_source = meta["passage_sources"][0]
|
||||
passage_file = passage_source["path"]
|
||||
|
||||
if not os.path.isabs(passage_file):
|
||||
index_dir = os.path.dirname(index_path)
|
||||
passage_file = os.path.join(index_dir, os.path.basename(passage_file))
|
||||
|
||||
# Load passages from builder output so IDs match LEANN
|
||||
passages: list[str] = []
|
||||
passage_ids: list[str] = []
|
||||
with open(passage_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if not line.strip():
|
||||
continue
|
||||
data = json.loads(line)
|
||||
passages.append(data["text"])
|
||||
passage_ids.append(data["id"]) # builder-assigned ID
|
||||
|
||||
print(f"📄 Loaded {len(passages)} passages for baseline")
|
||||
print(f"🤖 Embedding model: {embedding_model}")
|
||||
|
||||
embeddings = compute_embeddings(
|
||||
passages,
|
||||
embedding_model,
|
||||
mode="sentence-transformers",
|
||||
use_server=False,
|
||||
)
|
||||
|
||||
# Build FAISS IndexFlatIP
|
||||
dim = embeddings.shape[1]
|
||||
index = faiss.IndexFlatIP(dim)
|
||||
emb_f32 = embeddings.astype(np.float32)
|
||||
index.add(emb_f32.shape[0], faiss.swig_ptr(emb_f32))
|
||||
|
||||
faiss.write_index(index, baseline_path)
|
||||
with open(metadata_path, "wb") as pf:
|
||||
pickle.dump(passage_ids, pf)
|
||||
|
||||
print(f"✅ FAISS baseline saved: {baseline_path}")
|
||||
print(f"✅ Metadata saved: {metadata_path}")
|
||||
print(f"📊 Total vectors: {index.ntotal}")
|
||||
return baseline_path
|
||||
|
||||
# ----------------------------
|
||||
# Queries (optional): prepare evaluation queries file
|
||||
# ----------------------------
|
||||
def prepare_queries(self, min_realism: float = 0.85) -> Path:
|
||||
print(
|
||||
"📝 Preparing evaluation queries from HuggingFace dataset corbt/enron_emails_sample_questions ..."
|
||||
)
|
||||
try:
|
||||
from datasets import load_dataset
|
||||
|
||||
ds = load_dataset("corbt/enron_emails_sample_questions", split="train")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Failed to load dataset: {e}")
|
||||
return self.queries_file
|
||||
|
||||
kept = 0
|
||||
with open(self.queries_file, "w", encoding="utf-8") as out:
|
||||
for i, item in enumerate(ds):
|
||||
how_realistic = float(item.get("how_realistic", 0.0))
|
||||
if how_realistic < min_realism:
|
||||
continue
|
||||
qid = str(item.get("id", f"enron_q_{i}"))
|
||||
query = item.get("question", "")
|
||||
if not query:
|
||||
continue
|
||||
record = {
|
||||
"id": qid,
|
||||
"query": query,
|
||||
# For reference only, not used in recall metric below
|
||||
"gt_message_ids": item.get("message_ids", []),
|
||||
}
|
||||
out.write(json.dumps(record) + "\n")
|
||||
kept += 1
|
||||
print(f"✅ Wrote {kept} queries to {self.queries_file}")
|
||||
return self.queries_file
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Setup Enron Emails Benchmark")
|
||||
parser.add_argument(
|
||||
"--emails-csv",
|
||||
help="Path to emails.csv (Enron dataset). If omitted, attempt Kaggle download.",
|
||||
)
|
||||
parser.add_argument("--data-dir", default="data", help="Data directory")
|
||||
parser.add_argument("--backend", choices=["hnsw", "diskann"], default="hnsw")
|
||||
parser.add_argument(
|
||||
"--embedding-model",
|
||||
default="sentence-transformers/all-mpnet-base-v2",
|
||||
help="Embedding model for LEANN",
|
||||
)
|
||||
parser.add_argument("--chunk-words", type=int, default=256, help="Fixed word chunk size")
|
||||
parser.add_argument("--max-emails", type=int, help="Limit number of emails to process")
|
||||
parser.add_argument("--skip-queries", action="store_true", help="Skip creating queries file")
|
||||
parser.add_argument("--skip-build", action="store_true", help="Skip building LEANN index")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
setup = EnronSetup(args.data_dir)
|
||||
|
||||
# Build index
|
||||
if not args.skip_build:
|
||||
index_path = setup.build_leann_index(
|
||||
emails_csv=args.emails_csv,
|
||||
backend=args.backend,
|
||||
embedding_model=args.embedding_model,
|
||||
chunk_words=args.chunk_words,
|
||||
max_emails=args.max_emails,
|
||||
)
|
||||
|
||||
# Build FAISS baseline from the same passages & embeddings
|
||||
setup.build_faiss_flat_baseline(index_path)
|
||||
else:
|
||||
print("⏭️ Skipping LEANN index build and baseline")
|
||||
|
||||
# Queries file (optional)
|
||||
if not args.skip_queries:
|
||||
setup.prepare_queries()
|
||||
else:
|
||||
print("⏭️ Skipping query preparation")
|
||||
|
||||
print("\n🎉 Enron Emails setup completed!")
|
||||
print(f"📁 Data directory: {setup.data_dir.absolute()}")
|
||||
print("Next steps:")
|
||||
print(
|
||||
"1) Evaluate recall: python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 2"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
115
benchmarks/financebench/README.md
Normal file
@@ -0,0 +1,115 @@
|
||||
# FinanceBench Benchmark for LEANN-RAG
|
||||
|
||||
FinanceBench is a benchmark for evaluating retrieval-augmented generation (RAG) systems on financial document question-answering tasks.
|
||||
|
||||
## Dataset
|
||||
|
||||
- **Source**: [PatronusAI/financebench](https://huggingface.co/datasets/PatronusAI/financebench)
|
||||
- **Questions**: 150 financial Q&A examples
|
||||
- **Documents**: 368 PDF files (10-K, 10-Q, 8-K, earnings reports)
|
||||
- **Companies**: Major public companies (3M, Apple, Microsoft, Amazon, etc.)
|
||||
- **Paper**: [FinanceBench: A New Benchmark for Financial Question Answering](https://arxiv.org/abs/2311.11944)
|
||||
|
||||
## Structure
|
||||
|
||||
```
|
||||
benchmarks/financebench/
|
||||
├── setup_financebench.py # Downloads PDFs and builds index
|
||||
├── evaluate_financebench.py # Intelligent evaluation script
|
||||
├── data/
|
||||
│ ├── financebench_merged.jsonl # Q&A dataset
|
||||
│ ├── pdfs/ # Downloaded financial documents
|
||||
│ └── index/ # LEANN indexes
|
||||
│ └── financebench_full_hnsw.leann
|
||||
└── README.md
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### 1. Setup (Download & Build Index)
|
||||
|
||||
```bash
|
||||
cd benchmarks/financebench
|
||||
python setup_financebench.py
|
||||
```
|
||||
|
||||
This will:
|
||||
- Download the 150 Q&A examples
|
||||
- Download all 368 PDF documents (parallel processing)
|
||||
- Build a LEANN index from 53K+ text chunks
|
||||
- Verify setup with test query
|
||||
|
||||
### 2. Evaluation
|
||||
|
||||
```bash
|
||||
# Basic retrieval evaluation
|
||||
python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann
|
||||
|
||||
|
||||
# RAG generation evaluation with Qwen3-8B
|
||||
python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann --stage 4 --complexity 64 --llm-backend hf --model-name Qwen/Qwen3-8B --output results_qwen3.json
|
||||
```
|
||||
|
||||
## Evaluation Methods
|
||||
|
||||
### Retrieval Evaluation
|
||||
Uses intelligent matching with three strategies:
|
||||
1. **Exact text overlap** - Direct substring matches
|
||||
2. **Number matching** - Key financial figures ($1,577, 1.2B, etc.)
|
||||
3. **Semantic similarity** - Word overlap with 20% threshold
|
||||
|
||||
### QA Evaluation
|
||||
LLM-based answer evaluation using GPT-4o:
|
||||
- Handles numerical rounding and equivalent representations
|
||||
- Considers fractions, percentages, and decimal equivalents
|
||||
- Evaluates semantic meaning rather than exact text match
|
||||
|
||||
## Benchmark Results
|
||||
|
||||
### LEANN-RAG Performance (sentence-transformers/all-mpnet-base-v2)
|
||||
|
||||
**Retrieval Metrics:**
|
||||
- **Question Coverage**: 100.0% (all questions retrieve relevant docs)
|
||||
- **Exact Match Rate**: 0.7% (substring overlap with evidence)
|
||||
- **Number Match Rate**: 120.7% (key financial figures matched)*
|
||||
- **Semantic Match Rate**: 4.7% (word overlap ≥20%)
|
||||
- **Average Search Time**: 0.097s
|
||||
|
||||
**QA Metrics:**
|
||||
- **Accuracy**: 42.7% (LLM-evaluated answer correctness)
|
||||
- **Average QA Time**: 4.71s (end-to-end response time)
|
||||
|
||||
**System Performance:**
|
||||
- **Index Size**: 53,985 chunks from 368 PDFs
|
||||
- **Build Time**: ~5-10 minutes with sentence-transformers/all-mpnet-base-v2
|
||||
|
||||
*Note: Number match rate >100% indicates multiple retrieved documents contain the same financial figures, which is expected behavior for financial data appearing across multiple document sections.
|
||||
|
||||
### LEANN-RAG Generation Performance (Qwen3-8B)
|
||||
|
||||
- **Stage 4 (Index Comparison):**
|
||||
- Compact Index: 5.0 MB
|
||||
- Non-compact Index: 172.2 MB
|
||||
- **Storage Saving**: 97.1%
|
||||
- **Search Performance**:
|
||||
- Non-compact (no recompute): 0.009s avg per query
|
||||
- Compact (with recompute): 2.203s avg per query
|
||||
- Speed ratio: 0.004x
|
||||
|
||||
**Generation Evaluation (20 queries, complexity=64):**
|
||||
- **Average Search Time**: 1.638s per query
|
||||
- **Average Generation Time**: 45.957s per query
|
||||
- **LLM Backend**: HuggingFace transformers
|
||||
- **Model**: Qwen/Qwen3-8B (thinking model with <think></think> processing)
|
||||
- **Total Questions Processed**: 20
|
||||
|
||||
## Options
|
||||
|
||||
```bash
|
||||
# Use different backends
|
||||
python setup_financebench.py --backend diskann
|
||||
python evaluate_financebench.py --index data/index/financebench_full_diskann.leann
|
||||
|
||||
# Use different embedding models
|
||||
python setup_financebench.py --embedding-model facebook/contriever
|
||||
```
|
||||
923
benchmarks/financebench/evaluate_financebench.py
Executable file
@@ -0,0 +1,923 @@
|
||||
"""
|
||||
FinanceBench Evaluation Script - Modular Recall-based Evaluation
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import openai
|
||||
from leann import LeannChat, LeannSearcher
|
||||
from leann_backend_hnsw import faiss
|
||||
|
||||
from ..llm_utils import evaluate_rag, generate_hf, generate_vllm, load_hf_model, load_vllm_model
|
||||
|
||||
# Setup logging to reduce verbose output
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
logging.getLogger("leann.api").setLevel(logging.WARNING)
|
||||
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
class RecallEvaluator:
|
||||
"""Stage 2: Evaluate Recall@3 (searcher vs baseline)"""
|
||||
|
||||
def __init__(self, index_path: str, baseline_dir: str):
|
||||
self.index_path = index_path
|
||||
self.baseline_dir = baseline_dir
|
||||
self.searcher = LeannSearcher(index_path)
|
||||
|
||||
# Load FAISS flat baseline
|
||||
baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
|
||||
metadata_path = os.path.join(baseline_dir, "metadata.pkl")
|
||||
|
||||
self.faiss_index = faiss.read_index(baseline_index_path)
|
||||
with open(metadata_path, "rb") as f:
|
||||
self.passage_ids = pickle.load(f)
|
||||
print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} vectors")
|
||||
|
||||
def evaluate_recall_at_3(
|
||||
self, queries: list[str], complexity: int = 64, recompute_embeddings: bool = True
|
||||
) -> float:
|
||||
"""Evaluate recall@3 for given queries at specified complexity"""
|
||||
recompute_str = "with recompute" if recompute_embeddings else "no recompute"
|
||||
print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")
|
||||
|
||||
total_recall = 0.0
|
||||
num_queries = len(queries)
|
||||
|
||||
for i, query in enumerate(queries):
|
||||
# Get ground truth: search with FAISS flat
|
||||
from leann.api import compute_embeddings
|
||||
|
||||
query_embedding = compute_embeddings(
|
||||
[query],
|
||||
self.searcher.embedding_model,
|
||||
mode=self.searcher.embedding_mode,
|
||||
use_server=False,
|
||||
).astype(np.float32)
|
||||
|
||||
# Search FAISS flat for ground truth using LEANN's modified faiss API
|
||||
n = query_embedding.shape[0] # Number of queries
|
||||
k = 3 # Number of nearest neighbors
|
||||
distances = np.zeros((n, k), dtype=np.float32)
|
||||
labels = np.zeros((n, k), dtype=np.int64)
|
||||
|
||||
self.faiss_index.search(
|
||||
n,
|
||||
faiss.swig_ptr(query_embedding),
|
||||
k,
|
||||
faiss.swig_ptr(distances),
|
||||
faiss.swig_ptr(labels),
|
||||
)
|
||||
|
||||
# Extract the results
|
||||
baseline_ids = {self.passage_ids[idx] for idx in labels[0]}
|
||||
|
||||
# Search with LEANN at specified complexity
|
||||
test_results = self.searcher.search(
|
||||
query,
|
||||
top_k=3,
|
||||
complexity=complexity,
|
||||
recompute_embeddings=recompute_embeddings,
|
||||
)
|
||||
test_ids = {result.id for result in test_results}
|
||||
|
||||
# Calculate recall@3 = |intersection| / |ground_truth|
|
||||
intersection = test_ids.intersection(baseline_ids)
|
||||
recall = len(intersection) / 3.0 # Ground truth size is 3
|
||||
total_recall += recall
|
||||
|
||||
if i < 3: # Show first few examples
|
||||
print(f" Query {i + 1}: '{query[:50]}...' -> Recall@3: {recall:.3f}")
|
||||
print(f" FAISS ground truth: {list(baseline_ids)}")
|
||||
print(f" LEANN results (C={complexity}, {recompute_str}): {list(test_ids)}")
|
||||
print(f" Intersection: {list(intersection)}")
|
||||
|
||||
avg_recall = total_recall / num_queries
|
||||
print(f"📊 Average Recall@3: {avg_recall:.3f} ({avg_recall * 100:.1f}%)")
|
||||
return avg_recall
|
||||
|
||||
def cleanup(self):
|
||||
"""Cleanup resources"""
|
||||
if hasattr(self, "searcher"):
|
||||
self.searcher.cleanup()
|
||||
|
||||
|
||||
class FinanceBenchEvaluator:
|
||||
def __init__(self, index_path: str, openai_api_key: Optional[str] = None):
|
||||
self.index_path = index_path
|
||||
self.openai_client = openai.OpenAI(api_key=openai_api_key) if openai_api_key else None
|
||||
|
||||
self.searcher = LeannSearcher(index_path)
|
||||
self.chat = LeannChat(index_path) if openai_api_key else None
|
||||
|
||||
def load_dataset(self, dataset_path: str = "data/financebench_merged.jsonl"):
|
||||
"""Load FinanceBench dataset"""
|
||||
data = []
|
||||
with open(dataset_path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data.append(json.loads(line))
|
||||
|
||||
print(f"📊 Loaded {len(data)} FinanceBench examples")
|
||||
return data
|
||||
|
||||
def analyze_index_sizes(self) -> dict:
|
||||
"""Analyze index sizes with and without embeddings"""
|
||||
|
||||
print("📏 Analyzing index sizes...")
|
||||
|
||||
# Get all index-related files
|
||||
index_path = Path(self.index_path)
|
||||
index_dir = index_path.parent
|
||||
index_name = index_path.stem # Remove .leann extension
|
||||
|
||||
sizes = {}
|
||||
total_with_embeddings = 0
|
||||
|
||||
# Core index files
|
||||
index_file = index_dir / f"{index_name}.index"
|
||||
meta_file = index_dir / f"{index_path.name}.meta.json" # Keep .leann for meta file
|
||||
passages_file = index_dir / f"{index_path.name}.passages.jsonl" # Keep .leann for passages
|
||||
passages_idx_file = index_dir / f"{index_path.name}.passages.idx" # Keep .leann for idx
|
||||
|
||||
for file_path, name in [
|
||||
(index_file, "index"),
|
||||
(meta_file, "metadata"),
|
||||
(passages_file, "passages_text"),
|
||||
(passages_idx_file, "passages_index"),
|
||||
]:
|
||||
if file_path.exists():
|
||||
size_mb = file_path.stat().st_size / (1024 * 1024)
|
||||
sizes[name] = size_mb
|
||||
total_with_embeddings += size_mb
|
||||
|
||||
else:
|
||||
sizes[name] = 0
|
||||
|
||||
sizes["total_with_embeddings"] = total_with_embeddings
|
||||
sizes["index_only_mb"] = sizes["index"] # Just the .index file for fair comparison
|
||||
|
||||
print(f" 📁 Total index size: {total_with_embeddings:.1f} MB")
|
||||
print(f" 📁 Index file only: {sizes['index']:.1f} MB")
|
||||
|
||||
return sizes
|
||||
|
||||
def create_compact_index_for_comparison(self, compact_index_path: str) -> dict:
|
||||
"""Create a compact index for comparison purposes"""
|
||||
print("🏗️ Building compact index from existing passages...")
|
||||
|
||||
# Load existing passages from current index
|
||||
|
||||
from leann import LeannBuilder
|
||||
|
||||
current_index_path = Path(self.index_path)
|
||||
current_index_dir = current_index_path.parent
|
||||
current_index_name = current_index_path.name
|
||||
|
||||
# Read metadata to get passage source
|
||||
meta_path = current_index_dir / f"{current_index_name}.meta.json"
|
||||
with open(meta_path) as f:
|
||||
import json
|
||||
|
||||
meta = json.load(f)
|
||||
|
||||
passage_source = meta["passage_sources"][0]
|
||||
passage_file = passage_source["path"]
|
||||
|
||||
# Convert relative path to absolute
|
||||
if not Path(passage_file).is_absolute():
|
||||
passage_file = current_index_dir / Path(passage_file).name
|
||||
|
||||
print(f"📄 Loading passages from {passage_file}...")
|
||||
|
||||
# Build compact index with same passages
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=meta["embedding_model"],
|
||||
embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
|
||||
is_recompute=True, # Enable recompute (no stored embeddings)
|
||||
is_compact=True, # Enable compact storage
|
||||
**meta.get("backend_kwargs", {}),
|
||||
)
|
||||
|
||||
# Load all passages
|
||||
with open(passage_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
builder.add_text(data["text"], metadata=data.get("metadata", {}))
|
||||
|
||||
print(f"🔨 Building compact index at {compact_index_path}...")
|
||||
builder.build_index(compact_index_path)
|
||||
|
||||
# Analyze the compact index size
|
||||
temp_evaluator = FinanceBenchEvaluator(compact_index_path)
|
||||
compact_sizes = temp_evaluator.analyze_index_sizes()
|
||||
compact_sizes["index_type"] = "compact"
|
||||
|
||||
return compact_sizes
|
||||
|
||||
def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
|
||||
"""Create a non-compact index for comparison purposes"""
|
||||
print("🏗️ Building non-compact index from existing passages...")
|
||||
|
||||
# Load existing passages from current index
|
||||
|
||||
from leann import LeannBuilder
|
||||
|
||||
current_index_path = Path(self.index_path)
|
||||
current_index_dir = current_index_path.parent
|
||||
current_index_name = current_index_path.name
|
||||
|
||||
# Read metadata to get passage source
|
||||
meta_path = current_index_dir / f"{current_index_name}.meta.json"
|
||||
with open(meta_path) as f:
|
||||
import json
|
||||
|
||||
meta = json.load(f)
|
||||
|
||||
passage_source = meta["passage_sources"][0]
|
||||
passage_file = passage_source["path"]
|
||||
|
||||
# Convert relative path to absolute
|
||||
if not Path(passage_file).is_absolute():
|
||||
passage_file = current_index_dir / Path(passage_file).name
|
||||
|
||||
print(f"📄 Loading passages from {passage_file}...")
|
||||
|
||||
# Build non-compact index with same passages
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=meta["embedding_model"],
|
||||
embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
|
||||
is_recompute=False, # Disable recompute (store embeddings)
|
||||
is_compact=False, # Disable compact storage
|
||||
**{
|
||||
k: v
|
||||
for k, v in meta.get("backend_kwargs", {}).items()
|
||||
if k not in ["is_recompute", "is_compact"]
|
||||
},
|
||||
)
|
||||
|
||||
# Load all passages
|
||||
with open(passage_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
builder.add_text(data["text"], metadata=data.get("metadata", {}))
|
||||
|
||||
print(f"🔨 Building non-compact index at {non_compact_index_path}...")
|
||||
builder.build_index(non_compact_index_path)
|
||||
|
||||
# Analyze the non-compact index size
|
||||
temp_evaluator = FinanceBenchEvaluator(non_compact_index_path)
|
||||
non_compact_sizes = temp_evaluator.analyze_index_sizes()
|
||||
non_compact_sizes["index_type"] = "non_compact"
|
||||
|
||||
return non_compact_sizes
|
||||
|
||||
def compare_index_performance(
|
||||
self, non_compact_path: str, compact_path: str, test_data: list, complexity: int
|
||||
) -> dict:
|
||||
"""Compare performance between non-compact and compact indexes"""
|
||||
print("⚡ Comparing search performance between indexes...")
|
||||
|
||||
import time
|
||||
|
||||
from leann import LeannSearcher
|
||||
|
||||
# Test queries
|
||||
test_queries = [item["question"] for item in test_data[:5]]
|
||||
|
||||
results = {
|
||||
"non_compact": {"search_times": []},
|
||||
"compact": {"search_times": []},
|
||||
"avg_search_times": {},
|
||||
"speed_ratio": 0.0,
|
||||
}
|
||||
|
||||
# Test non-compact index (no recompute)
|
||||
print(" 🔍 Testing non-compact index (no recompute)...")
|
||||
non_compact_searcher = LeannSearcher(non_compact_path)
|
||||
|
||||
for query in test_queries:
|
||||
start_time = time.time()
|
||||
_ = non_compact_searcher.search(
|
||||
query, top_k=3, complexity=complexity, recompute_embeddings=False
|
||||
)
|
||||
search_time = time.time() - start_time
|
||||
results["non_compact"]["search_times"].append(search_time)
|
||||
|
||||
# Test compact index (with recompute)
|
||||
print(" 🔍 Testing compact index (with recompute)...")
|
||||
compact_searcher = LeannSearcher(compact_path)
|
||||
|
||||
for query in test_queries:
|
||||
start_time = time.time()
|
||||
_ = compact_searcher.search(
|
||||
query, top_k=3, complexity=complexity, recompute_embeddings=True
|
||||
)
|
||||
search_time = time.time() - start_time
|
||||
results["compact"]["search_times"].append(search_time)
|
||||
|
||||
# Calculate averages
|
||||
results["avg_search_times"]["non_compact"] = sum(
|
||||
results["non_compact"]["search_times"]
|
||||
) / len(results["non_compact"]["search_times"])
|
||||
results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
|
||||
results["compact"]["search_times"]
|
||||
)
|
||||
|
||||
# Performance ratio
|
||||
if results["avg_search_times"]["compact"] > 0:
|
||||
results["speed_ratio"] = (
|
||||
results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
|
||||
)
|
||||
else:
|
||||
results["speed_ratio"] = float("inf")
|
||||
|
||||
print(
|
||||
f" Non-compact (no recompute): {results['avg_search_times']['non_compact']:.3f}s avg"
|
||||
)
|
||||
print(f" Compact (with recompute): {results['avg_search_times']['compact']:.3f}s avg")
|
||||
print(f" Speed ratio: {results['speed_ratio']:.2f}x")
|
||||
|
||||
# Cleanup
|
||||
non_compact_searcher.cleanup()
|
||||
compact_searcher.cleanup()
|
||||
|
||||
return results
|
||||
|
||||
def evaluate_timing_breakdown(
|
||||
self, data: list[dict], max_samples: Optional[int] = None
|
||||
) -> dict:
|
||||
"""Evaluate timing breakdown and accuracy by hacking LeannChat.ask() for separated timing"""
|
||||
if not self.chat or not self.openai_client:
|
||||
print("⚠️ Skipping timing evaluation (no OpenAI API key provided)")
|
||||
return {
|
||||
"total_questions": 0,
|
||||
"avg_search_time": 0.0,
|
||||
"avg_generation_time": 0.0,
|
||||
"avg_total_time": 0.0,
|
||||
"accuracy": 0.0,
|
||||
}
|
||||
|
||||
print("🔍🤖 Evaluating timing breakdown and accuracy (search + generation)...")
|
||||
|
||||
if max_samples:
|
||||
data = data[:max_samples]
|
||||
print(f"📝 Using first {max_samples} samples for timing evaluation")
|
||||
|
||||
search_times = []
|
||||
generation_times = []
|
||||
total_times = []
|
||||
correct_answers = 0
|
||||
|
||||
for i, item in enumerate(data):
|
||||
question = item["question"]
|
||||
ground_truth = item["answer"]
|
||||
|
||||
try:
|
||||
# Hack: Monkey-patch the ask method to capture internal timing
|
||||
original_ask = self.chat.ask
|
||||
captured_search_time = None
|
||||
captured_generation_time = None
|
||||
|
||||
def patched_ask(*args, **kwargs):
|
||||
nonlocal captured_search_time, captured_generation_time
|
||||
|
||||
# Time the search part
|
||||
search_start = time.time()
|
||||
results = self.chat.searcher.search(args[0], top_k=3, complexity=64)
|
||||
captured_search_time = time.time() - search_start
|
||||
|
||||
# Time the generation part
|
||||
context = "\n\n".join([r.text for r in results])
|
||||
prompt = (
|
||||
"Here is some retrieved context that might help answer your question:\n\n"
|
||||
f"{context}\n\n"
|
||||
f"Question: {args[0]}\n\n"
|
||||
"Please provide the best answer you can based on this context and your knowledge."
|
||||
)
|
||||
|
||||
generation_start = time.time()
|
||||
answer = self.chat.llm.ask(prompt)
|
||||
captured_generation_time = time.time() - generation_start
|
||||
|
||||
return answer
|
||||
|
||||
# Apply the patch
|
||||
self.chat.ask = patched_ask
|
||||
|
||||
# Time the total QA
|
||||
total_start = time.time()
|
||||
generated_answer = self.chat.ask(question)
|
||||
total_time = time.time() - total_start
|
||||
|
||||
# Restore original method
|
||||
self.chat.ask = original_ask
|
||||
|
||||
# Store the timings
|
||||
search_times.append(captured_search_time)
|
||||
generation_times.append(captured_generation_time)
|
||||
total_times.append(total_time)
|
||||
|
||||
# Check accuracy using LLM as judge
|
||||
is_correct = self._check_answer_accuracy(generated_answer, ground_truth, question)
|
||||
if is_correct:
|
||||
correct_answers += 1
|
||||
|
||||
status = "✅" if is_correct else "❌"
|
||||
print(
|
||||
f"Question {i + 1}/{len(data)}: {status} Search={captured_search_time:.3f}s, Gen={captured_generation_time:.3f}s, Total={total_time:.3f}s"
|
||||
)
|
||||
print(f" GT: {ground_truth}")
|
||||
print(f" Gen: {generated_answer[:100]}...")
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ Error: {e}")
|
||||
search_times.append(0.0)
|
||||
generation_times.append(0.0)
|
||||
total_times.append(0.0)
|
||||
|
||||
accuracy = correct_answers / len(data) if data else 0.0
|
||||
|
||||
metrics = {
|
||||
"total_questions": len(data),
|
||||
"avg_search_time": sum(search_times) / len(search_times) if search_times else 0.0,
|
||||
"avg_generation_time": sum(generation_times) / len(generation_times)
|
||||
if generation_times
|
||||
else 0.0,
|
||||
"avg_total_time": sum(total_times) / len(total_times) if total_times else 0.0,
|
||||
"accuracy": accuracy,
|
||||
"correct_answers": correct_answers,
|
||||
"search_times": search_times,
|
||||
"generation_times": generation_times,
|
||||
"total_times": total_times,
|
||||
}
|
||||
|
||||
return metrics
|
||||
|
||||
def _check_answer_accuracy(
|
||||
self, generated_answer: str, ground_truth: str, question: str
|
||||
) -> bool:
|
||||
"""Check if generated answer matches ground truth using LLM as judge"""
|
||||
judge_prompt = f"""You are an expert judge evaluating financial question answering.
|
||||
|
||||
Question: {question}
|
||||
|
||||
Ground Truth Answer: {ground_truth}
|
||||
|
||||
Generated Answer: {generated_answer}
|
||||
|
||||
Task: Determine if the generated answer is factually correct compared to the ground truth. Focus on:
|
||||
1. Numerical accuracy (exact values, units, currency)
|
||||
2. Key financial concepts and terminology
|
||||
3. Overall factual correctness
|
||||
|
||||
For financial data, small formatting differences are OK (e.g., "$1,577" vs "1577 million" vs "$1.577 billion"), but the core numerical value must match.
|
||||
|
||||
Respond with exactly one word: "CORRECT" if the generated answer is factually accurate, or "INCORRECT" if it's wrong or significantly different."""
|
||||
|
||||
try:
|
||||
judge_response = self.openai_client.chat.completions.create(
|
||||
model="gpt-4o-mini",
|
||||
messages=[{"role": "user", "content": judge_prompt}],
|
||||
max_tokens=10,
|
||||
temperature=0,
|
||||
)
|
||||
judgment = judge_response.choices[0].message.content.strip().upper()
|
||||
return judgment == "CORRECT"
|
||||
except Exception as e:
|
||||
print(f" ⚠️ Judge error: {e}, falling back to string matching")
|
||||
# Fallback to simple string matching
|
||||
gen_clean = generated_answer.strip().lower().replace("$", "").replace(",", "")
|
||||
gt_clean = ground_truth.strip().lower().replace("$", "").replace(",", "")
|
||||
return gt_clean in gen_clean
|
||||
|
||||
def _print_results(self, timing_metrics: dict):
|
||||
"""Print evaluation results"""
|
||||
print("\n🎯 EVALUATION RESULTS")
|
||||
print("=" * 50)
|
||||
|
||||
# Index comparison analysis
|
||||
if "current_index" in timing_metrics and "non_compact_index" in timing_metrics:
|
||||
print("\n📏 Index Comparison Analysis:")
|
||||
current = timing_metrics["current_index"]
|
||||
non_compact = timing_metrics["non_compact_index"]
|
||||
|
||||
print(f" Compact index (current): {current.get('total_with_embeddings', 0):.1f} MB")
|
||||
print(
|
||||
f" Non-compact index (with embeddings): {non_compact.get('total_with_embeddings', 0):.1f} MB"
|
||||
)
|
||||
print(
|
||||
f" Storage saving by compact: {timing_metrics.get('storage_saving_percent', 0):.1f}%"
|
||||
)
|
||||
|
||||
print(" Component breakdown (non-compact):")
|
||||
print(f" - Main index: {non_compact.get('index', 0):.1f} MB")
|
||||
print(f" - Passages text: {non_compact.get('passages_text', 0):.1f} MB")
|
||||
print(f" - Passages index: {non_compact.get('passages_index', 0):.1f} MB")
|
||||
print(f" - Metadata: {non_compact.get('metadata', 0):.1f} MB")
|
||||
|
||||
# Performance comparison
|
||||
if "performance_comparison" in timing_metrics:
|
||||
perf = timing_metrics["performance_comparison"]
|
||||
print("\n⚡ Performance Comparison:")
|
||||
print(
|
||||
f" Non-compact (no recompute): {perf.get('avg_search_times', {}).get('non_compact', 0):.3f}s avg"
|
||||
)
|
||||
print(
|
||||
f" Compact (with recompute): {perf.get('avg_search_times', {}).get('compact', 0):.3f}s avg"
|
||||
)
|
||||
print(f" Speed ratio: {perf.get('speed_ratio', 0):.2f}x")
|
||||
|
||||
# Legacy single index analysis (fallback)
|
||||
if "total_with_embeddings" in timing_metrics and "current_index" not in timing_metrics:
|
||||
print("\n📏 Index Size Analysis:")
|
||||
print(f" Total index size: {timing_metrics.get('total_with_embeddings', 0):.1f} MB")
|
||||
|
||||
print("\n📊 Accuracy:")
|
||||
print(f" Accuracy: {timing_metrics.get('accuracy', 0) * 100:.1f}%")
|
||||
print(
|
||||
f" Correct Answers: {timing_metrics.get('correct_answers', 0)}/{timing_metrics.get('total_questions', 0)}"
|
||||
)
|
||||
|
||||
print("\n📊 Timing Breakdown:")
|
||||
print(f" Total Questions: {timing_metrics.get('total_questions', 0)}")
|
||||
print(f" Avg Search Time: {timing_metrics.get('avg_search_time', 0):.3f}s")
|
||||
print(f" Avg Generation Time: {timing_metrics.get('avg_generation_time', 0):.3f}s")
|
||||
print(f" Avg Total Time: {timing_metrics.get('avg_total_time', 0):.3f}s")
|
||||
|
||||
if timing_metrics.get("avg_total_time", 0) > 0:
|
||||
search_pct = (
|
||||
timing_metrics.get("avg_search_time", 0)
|
||||
/ timing_metrics.get("avg_total_time", 1)
|
||||
* 100
|
||||
)
|
||||
gen_pct = (
|
||||
timing_metrics.get("avg_generation_time", 0)
|
||||
/ timing_metrics.get("avg_total_time", 1)
|
||||
* 100
|
||||
)
|
||||
print("\n📈 Time Distribution:")
|
||||
print(f" Search: {search_pct:.1f}%")
|
||||
print(f" Generation: {gen_pct:.1f}%")
|
||||
|
||||
def cleanup(self):
|
||||
"""Cleanup resources"""
|
||||
if self.searcher:
|
||||
self.searcher.cleanup()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Modular FinanceBench Evaluation")
|
||||
parser.add_argument("--index", required=True, help="Path to LEANN index")
|
||||
parser.add_argument("--dataset", default="data/financebench_merged.jsonl", help="Dataset path")
|
||||
parser.add_argument(
|
||||
"--stage",
|
||||
choices=["2", "3", "4", "all"],
|
||||
default="all",
|
||||
help="Which stage to run (2=recall, 3=complexity, 4=generation)",
|
||||
)
|
||||
parser.add_argument("--complexity", type=int, default=None, help="Complexity for search")
|
||||
parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
|
||||
parser.add_argument("--openai-api-key", help="OpenAI API key for generation evaluation")
|
||||
parser.add_argument("--output", help="Save results to JSON file")
|
||||
parser.add_argument(
|
||||
"--llm-backend", choices=["openai", "hf", "vllm"], default="openai", help="LLM backend"
|
||||
)
|
||||
parser.add_argument("--model-name", default="Qwen3-8B", help="Model name for HF/vLLM")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
# Check if baseline exists
|
||||
baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
|
||||
if not os.path.exists(baseline_index_path):
|
||||
print(f"❌ FAISS baseline not found at {baseline_index_path}")
|
||||
print("💡 Please run setup_financebench.py first to build the baseline")
|
||||
exit(1)
|
||||
|
||||
if args.stage == "2" or args.stage == "all":
|
||||
# Stage 2: Recall@3 evaluation
|
||||
print("🚀 Starting Stage 2: Recall@3 evaluation")
|
||||
|
||||
evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||
|
||||
# Load FinanceBench queries for testing
|
||||
print("📖 Loading FinanceBench dataset...")
|
||||
queries = []
|
||||
with open(args.dataset, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
queries.append(data["question"])
|
||||
|
||||
# Test with more queries for robust measurement
|
||||
test_queries = queries[:2000]
|
||||
print(f"🧪 Testing with {len(test_queries)} queries")
|
||||
|
||||
# Test with complexity 64
|
||||
complexity = 64
|
||||
recall = evaluator.evaluate_recall_at_3(test_queries, complexity)
|
||||
print(f"📈 Recall@3 at complexity {complexity}: {recall * 100:.1f}%")
|
||||
|
||||
evaluator.cleanup()
|
||||
print("✅ Stage 2 completed!\n")
|
||||
|
||||
# Shared non-compact index path for Stage 3 and 4
|
||||
non_compact_index_path = args.index.replace(".leann", "_noncompact.leann")
|
||||
complexity = args.complexity
|
||||
|
||||
if args.stage == "3" or args.stage == "all":
|
||||
# Stage 3: Binary search for 90% recall complexity (using non-compact index for speed)
|
||||
print("🚀 Starting Stage 3: Binary search for 90% recall complexity")
|
||||
print(
|
||||
"💡 Creating non-compact index for fast binary search with recompute_embeddings=False"
|
||||
)
|
||||
|
||||
# Create non-compact index for binary search (will be reused in Stage 4)
|
||||
print("🏗️ Creating non-compact index for binary search...")
|
||||
evaluator = FinanceBenchEvaluator(args.index)
|
||||
evaluator.create_non_compact_index_for_comparison(non_compact_index_path)
|
||||
|
||||
# Use non-compact index for binary search
|
||||
binary_search_evaluator = RecallEvaluator(non_compact_index_path, args.baseline_dir)
|
||||
|
||||
# Load queries for testing
|
||||
print("📖 Loading FinanceBench dataset...")
|
||||
queries = []
|
||||
with open(args.dataset, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
queries.append(data["question"])
|
||||
|
||||
# Use more queries for robust measurement
|
||||
test_queries = queries[:200]
|
||||
print(f"🧪 Testing with {len(test_queries)} queries")
|
||||
|
||||
# Binary search for 90% recall complexity (without recompute for speed)
|
||||
target_recall = 0.9
|
||||
min_complexity, max_complexity = 1, 32
|
||||
|
||||
print(f"🔍 Binary search for {target_recall * 100}% recall complexity...")
|
||||
print(f"Search range: {min_complexity} to {max_complexity}")
|
||||
|
||||
best_complexity = None
|
||||
best_recall = 0.0
|
||||
|
||||
while min_complexity <= max_complexity:
|
||||
mid_complexity = (min_complexity + max_complexity) // 2
|
||||
|
||||
print(
|
||||
f"\n🧪 Testing complexity {mid_complexity} (no recompute, non-compact index)..."
|
||||
)
|
||||
# Use recompute_embeddings=False on non-compact index for fast binary search
|
||||
recall = binary_search_evaluator.evaluate_recall_at_3(
|
||||
test_queries, mid_complexity, recompute_embeddings=False
|
||||
)
|
||||
|
||||
print(
|
||||
f" Complexity {mid_complexity}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%)"
|
||||
)
|
||||
|
||||
if recall >= target_recall:
|
||||
best_complexity = mid_complexity
|
||||
best_recall = recall
|
||||
max_complexity = mid_complexity - 1
|
||||
print(" ✅ Target reached! Searching for lower complexity...")
|
||||
else:
|
||||
min_complexity = mid_complexity + 1
|
||||
print(" ❌ Below target. Searching for higher complexity...")
|
||||
|
||||
if best_complexity is not None:
|
||||
print("\n🎯 Optimal complexity found!")
|
||||
print(f" Complexity: {best_complexity}")
|
||||
print(f" Recall@3: {best_recall:.3f} ({best_recall * 100:.1f}%)")
|
||||
|
||||
# Test a few complexities around the optimal one for verification
|
||||
print("\n🔬 Verification test around optimal complexity:")
|
||||
verification_complexities = [
|
||||
max(1, best_complexity - 2),
|
||||
max(1, best_complexity - 1),
|
||||
best_complexity,
|
||||
best_complexity + 1,
|
||||
best_complexity + 2,
|
||||
]
|
||||
|
||||
for complexity in verification_complexities:
|
||||
if complexity <= 512: # reasonable upper bound
|
||||
recall = binary_search_evaluator.evaluate_recall_at_3(
|
||||
test_queries, complexity, recompute_embeddings=False
|
||||
)
|
||||
status = "✅" if recall >= target_recall else "❌"
|
||||
print(f" {status} Complexity {complexity:3d}: {recall * 100:5.1f}%")
|
||||
|
||||
# Now test the optimal complexity with compact index and recompute for comparison
|
||||
print(
|
||||
f"\n🔄 Testing optimal complexity {best_complexity} on compact index WITH recompute..."
|
||||
)
|
||||
compact_evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||
recall_with_recompute = compact_evaluator.evaluate_recall_at_3(
|
||||
test_queries[:10], best_complexity, recompute_embeddings=True
|
||||
)
|
||||
print(
|
||||
f" ✅ Complexity {best_complexity} (compact index with recompute): {recall_with_recompute * 100:.1f}%"
|
||||
)
|
||||
complexity = best_complexity
|
||||
print(
|
||||
f" 📊 Recall difference: {abs(best_recall - recall_with_recompute) * 100:.2f}%"
|
||||
)
|
||||
compact_evaluator.cleanup()
|
||||
else:
|
||||
print(f"\n❌ Could not find complexity achieving {target_recall * 100}% recall")
|
||||
print("All tested complexities were below target.")
|
||||
|
||||
# Cleanup evaluators (keep non-compact index for Stage 4)
|
||||
binary_search_evaluator.cleanup()
|
||||
evaluator.cleanup()
|
||||
|
||||
print("✅ Stage 3 completed! Non-compact index saved for Stage 4.\n")
|
||||
|
||||
if args.stage == "4" or args.stage == "all":
|
||||
# Stage 4: Comprehensive evaluation with dual index comparison
|
||||
print("🚀 Starting Stage 4: Comprehensive evaluation with dual index comparison")
|
||||
|
||||
# Use FinanceBench evaluator for QA evaluation
|
||||
evaluator = FinanceBenchEvaluator(
|
||||
args.index, args.openai_api_key if args.llm_backend == "openai" else None
|
||||
)
|
||||
|
||||
print("📖 Loading FinanceBench dataset...")
|
||||
data = evaluator.load_dataset(args.dataset)
|
||||
|
||||
# Step 1: Analyze current (compact) index
|
||||
print("\n📏 Analyzing current index (compact, pruned)...")
|
||||
compact_size_metrics = evaluator.analyze_index_sizes()
|
||||
compact_size_metrics["index_type"] = "compact"
|
||||
|
||||
# Step 2: Use existing non-compact index or create if needed
|
||||
from pathlib import Path
|
||||
|
||||
if Path(non_compact_index_path).exists():
|
||||
print(
|
||||
f"\n📁 Using existing non-compact index from Stage 3: {non_compact_index_path}"
|
||||
)
|
||||
temp_evaluator = FinanceBenchEvaluator(non_compact_index_path)
|
||||
non_compact_size_metrics = temp_evaluator.analyze_index_sizes()
|
||||
non_compact_size_metrics["index_type"] = "non_compact"
|
||||
else:
|
||||
print("\n🏗️ Creating non-compact index (with embeddings) for comparison...")
|
||||
non_compact_size_metrics = evaluator.create_non_compact_index_for_comparison(
|
||||
non_compact_index_path
|
||||
)
|
||||
|
||||
# Step 3: Compare index sizes
|
||||
print("\n📊 Index size comparison:")
|
||||
print(
|
||||
f" Compact index (current): {compact_size_metrics['total_with_embeddings']:.1f} MB"
|
||||
)
|
||||
print(
|
||||
f" Non-compact index: {non_compact_size_metrics['total_with_embeddings']:.1f} MB"
|
||||
)
|
||||
print("\n📊 Index-only size comparison (.index file only):")
|
||||
print(f" Compact index: {compact_size_metrics['index_only_mb']:.1f} MB")
|
||||
print(f" Non-compact index: {non_compact_size_metrics['index_only_mb']:.1f} MB")
|
||||
# Use index-only size for fair comparison (same as Enron emails)
|
||||
storage_saving = (
|
||||
(non_compact_size_metrics["index_only_mb"] - compact_size_metrics["index_only_mb"])
|
||||
/ non_compact_size_metrics["index_only_mb"]
|
||||
* 100
|
||||
)
|
||||
print(f" Storage saving by compact: {storage_saving:.1f}%")
|
||||
|
||||
# Step 4: Performance comparison between the two indexes
|
||||
if complexity is None:
|
||||
raise ValueError("Complexity is required for performance comparison")
|
||||
|
||||
print("\n⚡ Performance comparison between indexes...")
|
||||
performance_metrics = evaluator.compare_index_performance(
|
||||
non_compact_index_path, args.index, data[:10], complexity=complexity
|
||||
)
|
||||
|
||||
# Step 5: Generation evaluation
|
||||
test_samples = 20
|
||||
print(f"\n🧪 Testing with first {test_samples} samples for generation analysis")
|
||||
|
||||
if args.llm_backend == "openai" and args.openai_api_key:
|
||||
print("🔍🤖 Running OpenAI-based generation evaluation...")
|
||||
evaluation_start = time.time()
|
||||
timing_metrics = evaluator.evaluate_timing_breakdown(data[:test_samples])
|
||||
evaluation_time = time.time() - evaluation_start
|
||||
else:
|
||||
print(
|
||||
f"🔍🤖 Running {args.llm_backend} generation evaluation with {args.model_name}..."
|
||||
)
|
||||
try:
|
||||
# Load LLM
|
||||
if args.llm_backend == "hf":
|
||||
tokenizer, model = load_hf_model(args.model_name)
|
||||
|
||||
def llm_func(prompt):
|
||||
return generate_hf(tokenizer, model, prompt)
|
||||
else: # vllm
|
||||
llm, sampling_params = load_vllm_model(args.model_name)
|
||||
|
||||
def llm_func(prompt):
|
||||
return generate_vllm(llm, sampling_params, prompt)
|
||||
|
||||
# Simple generation evaluation
|
||||
queries = [item["question"] for item in data[:test_samples]]
|
||||
gen_results = evaluate_rag(
|
||||
evaluator.searcher,
|
||||
llm_func,
|
||||
queries,
|
||||
domain="finance",
|
||||
complexity=complexity,
|
||||
)
|
||||
|
||||
timing_metrics = {
|
||||
"total_questions": len(queries),
|
||||
"avg_search_time": gen_results["avg_search_time"],
|
||||
"avg_generation_time": gen_results["avg_generation_time"],
|
||||
"results": gen_results["results"],
|
||||
}
|
||||
evaluation_time = time.time()
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Generation evaluation failed: {e}")
|
||||
timing_metrics = {
|
||||
"total_questions": 0,
|
||||
"avg_search_time": 0,
|
||||
"avg_generation_time": 0,
|
||||
}
|
||||
evaluation_time = 0
|
||||
|
||||
# Combine all metrics
|
||||
combined_metrics = {
|
||||
**timing_metrics,
|
||||
"total_evaluation_time": evaluation_time,
|
||||
"current_index": compact_size_metrics,
|
||||
"non_compact_index": non_compact_size_metrics,
|
||||
"performance_comparison": performance_metrics,
|
||||
"storage_saving_percent": storage_saving,
|
||||
}
|
||||
|
||||
# Print results
|
||||
print("\n📊 Generation Results:")
|
||||
print(f" Total Questions: {timing_metrics.get('total_questions', 0)}")
|
||||
print(f" Avg Search Time: {timing_metrics.get('avg_search_time', 0):.3f}s")
|
||||
print(f" Avg Generation Time: {timing_metrics.get('avg_generation_time', 0):.3f}s")
|
||||
|
||||
# Save results if requested
|
||||
if args.output:
|
||||
print(f"\n💾 Saving results to {args.output}...")
|
||||
with open(args.output, "w") as f:
|
||||
json.dump(combined_metrics, f, indent=2, default=str)
|
||||
print(f"✅ Results saved to {args.output}")
|
||||
|
||||
evaluator.cleanup()
|
||||
print("✅ Stage 4 completed!\n")
|
||||
|
||||
if args.stage == "all":
|
||||
print("🎉 All evaluation stages completed successfully!")
|
||||
print("\n📋 Summary:")
|
||||
print(" Stage 2: ✅ Recall@3 evaluation completed")
|
||||
print(" Stage 3: ✅ Optimal complexity found")
|
||||
print(" Stage 4: ✅ Generation accuracy & timing evaluation completed")
|
||||
print("\n🔧 Recommended next steps:")
|
||||
print(" - Use optimal complexity for best speed/accuracy balance")
|
||||
print(" - Review accuracy and timing breakdown for performance optimization")
|
||||
print(" - Run full evaluation on complete dataset if needed")
|
||||
|
||||
# Clean up non-compact index after all stages complete
|
||||
print("\n🧹 Cleaning up temporary non-compact index...")
|
||||
from pathlib import Path
|
||||
|
||||
if Path(non_compact_index_path).exists():
|
||||
temp_index_dir = Path(non_compact_index_path).parent
|
||||
temp_index_name = Path(non_compact_index_path).name
|
||||
for temp_file in temp_index_dir.glob(f"{temp_index_name}*"):
|
||||
temp_file.unlink()
|
||||
print(f"✅ Cleaned up {non_compact_index_path}")
|
||||
else:
|
||||
print("📝 No temporary index to clean up")
|
||||
except KeyboardInterrupt:
|
||||
print("\n⚠️ Evaluation interrupted by user")
|
||||
exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n❌ Stage {args.stage} failed: {e}")
|
||||
exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
462
benchmarks/financebench/setup_financebench.py
Executable file
@@ -0,0 +1,462 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
FinanceBench Complete Setup Script
|
||||
Downloads all PDFs and builds full LEANN datastore
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
|
||||
import pymupdf
|
||||
import requests
|
||||
from leann import LeannBuilder, LeannSearcher
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class FinanceBenchSetup:
|
||||
def __init__(self, data_dir: str = "data"):
|
||||
self.base_dir = Path(__file__).parent # benchmarks/financebench/
|
||||
self.data_dir = self.base_dir / data_dir
|
||||
self.pdf_dir = self.data_dir / "pdfs"
|
||||
self.dataset_file = self.data_dir / "financebench_merged.jsonl"
|
||||
self.index_dir = self.data_dir / "index"
|
||||
self.download_lock = Lock()
|
||||
|
||||
def download_dataset(self):
|
||||
"""Download the main FinanceBench dataset"""
|
||||
print("📊 Downloading FinanceBench dataset...")
|
||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if self.dataset_file.exists():
|
||||
print(f"✅ Dataset already exists: {self.dataset_file}")
|
||||
return
|
||||
|
||||
url = "https://huggingface.co/datasets/PatronusAI/financebench/raw/main/financebench_merged.jsonl"
|
||||
response = requests.get(url, stream=True)
|
||||
response.raise_for_status()
|
||||
|
||||
with open(self.dataset_file, "wb") as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
|
||||
print(f"✅ Dataset downloaded: {self.dataset_file}")
|
||||
|
||||
def get_pdf_list(self):
|
||||
"""Get list of all PDF files from GitHub"""
|
||||
print("📋 Fetching PDF list from GitHub...")
|
||||
|
||||
response = requests.get(
|
||||
"https://api.github.com/repos/patronus-ai/financebench/contents/pdfs"
|
||||
)
|
||||
response.raise_for_status()
|
||||
pdf_files = response.json()
|
||||
|
||||
print(f"Found {len(pdf_files)} PDF files")
|
||||
return pdf_files
|
||||
|
||||
def download_single_pdf(self, pdf_info, position):
|
||||
"""Download a single PDF file"""
|
||||
pdf_name = pdf_info["name"]
|
||||
pdf_path = self.pdf_dir / pdf_name
|
||||
|
||||
# Skip if already downloaded
|
||||
if pdf_path.exists() and pdf_path.stat().st_size > 0:
|
||||
return f"✅ {pdf_name} (cached)"
|
||||
|
||||
try:
|
||||
# Download PDF
|
||||
response = requests.get(pdf_info["download_url"], timeout=60)
|
||||
response.raise_for_status()
|
||||
|
||||
# Write to file
|
||||
with self.download_lock:
|
||||
with open(pdf_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
return f"✅ {pdf_name} ({len(response.content) // 1024}KB)"
|
||||
|
||||
except Exception as e:
|
||||
return f"❌ {pdf_name}: {e!s}"
|
||||
|
||||
def download_all_pdfs(self, max_workers: int = 5):
|
||||
"""Download all PDF files with parallel processing"""
|
||||
self.pdf_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
pdf_files = self.get_pdf_list()
|
||||
|
||||
print(f"📥 Downloading {len(pdf_files)} PDFs with {max_workers} workers...")
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# Submit all download tasks
|
||||
future_to_pdf = {
|
||||
executor.submit(self.download_single_pdf, pdf_info, i): pdf_info["name"]
|
||||
for i, pdf_info in enumerate(pdf_files)
|
||||
}
|
||||
|
||||
# Process completed downloads with progress bar
|
||||
with tqdm(total=len(pdf_files), desc="Downloading PDFs") as pbar:
|
||||
for future in as_completed(future_to_pdf):
|
||||
result = future.result()
|
||||
pbar.set_postfix_str(result.split()[-1] if "✅" in result else "Error")
|
||||
pbar.update(1)
|
||||
|
||||
# Verify downloads
|
||||
downloaded_pdfs = list(self.pdf_dir.glob("*.pdf"))
|
||||
print(f"✅ Successfully downloaded {len(downloaded_pdfs)}/{len(pdf_files)} PDFs")
|
||||
|
||||
# Show any failures
|
||||
missing_pdfs = []
|
||||
for pdf_info in pdf_files:
|
||||
pdf_path = self.pdf_dir / pdf_info["name"]
|
||||
if not pdf_path.exists() or pdf_path.stat().st_size == 0:
|
||||
missing_pdfs.append(pdf_info["name"])
|
||||
|
||||
if missing_pdfs:
|
||||
print(f"⚠️ Failed to download {len(missing_pdfs)} PDFs:")
|
||||
for pdf in missing_pdfs[:5]: # Show first 5
|
||||
print(f" - {pdf}")
|
||||
if len(missing_pdfs) > 5:
|
||||
print(f" ... and {len(missing_pdfs) - 5} more")
|
||||
|
||||
def build_leann_index(
|
||||
self,
|
||||
backend: str = "hnsw",
|
||||
embedding_model: str = "sentence-transformers/all-mpnet-base-v2",
|
||||
):
|
||||
"""Build LEANN index from all PDFs"""
|
||||
print(f"🏗️ Building LEANN index with {backend} backend...")
|
||||
|
||||
# Check if we have PDFs
|
||||
pdf_files = list(self.pdf_dir.glob("*.pdf"))
|
||||
if not pdf_files:
|
||||
raise RuntimeError("No PDF files found! Run download first.")
|
||||
|
||||
print(f"Found {len(pdf_files)} PDF files to process")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Initialize builder with standard compact configuration
|
||||
builder = LeannBuilder(
|
||||
backend_name=backend,
|
||||
embedding_model=embedding_model,
|
||||
embedding_mode="sentence-transformers",
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_recompute=True, # Enable recompute (no stored embeddings)
|
||||
is_compact=True, # Enable compact storage (pruned)
|
||||
num_threads=4,
|
||||
)
|
||||
|
||||
# Process PDFs and extract text
|
||||
total_chunks = 0
|
||||
failed_pdfs = []
|
||||
|
||||
for pdf_path in tqdm(pdf_files, desc="Processing PDFs"):
|
||||
try:
|
||||
chunks = self.extract_pdf_text(pdf_path)
|
||||
for chunk in chunks:
|
||||
builder.add_text(chunk["text"], metadata=chunk["metadata"])
|
||||
total_chunks += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to process {pdf_path.name}: {e}")
|
||||
failed_pdfs.append(pdf_path.name)
|
||||
continue
|
||||
|
||||
# Build index in index directory
|
||||
self.index_dir.mkdir(parents=True, exist_ok=True)
|
||||
index_path = self.index_dir / f"financebench_full_{backend}.leann"
|
||||
print(f"🔨 Building index: {index_path}")
|
||||
builder.build_index(str(index_path))
|
||||
|
||||
build_time = time.time() - start_time
|
||||
|
||||
print("✅ Index built successfully!")
|
||||
print(f" 📁 Index path: {index_path}")
|
||||
print(f" 📊 Total chunks: {total_chunks:,}")
|
||||
print(f" 📄 Processed PDFs: {len(pdf_files) - len(failed_pdfs)}/{len(pdf_files)}")
|
||||
print(f" ⏱️ Build time: {build_time:.1f}s")
|
||||
|
||||
if failed_pdfs:
|
||||
print(f" ⚠️ Failed PDFs: {failed_pdfs}")
|
||||
|
||||
return str(index_path)
|
||||
|
||||
def build_faiss_flat_baseline(self, index_path: str, output_dir: str = "baseline"):
|
||||
"""Build FAISS flat baseline using the same embeddings as LEANN index"""
|
||||
print("🔨 Building FAISS Flat baseline...")
|
||||
|
||||
import os
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
from leann.api import compute_embeddings
|
||||
from leann_backend_hnsw import faiss
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
baseline_path = os.path.join(output_dir, "faiss_flat.index")
|
||||
metadata_path = os.path.join(output_dir, "metadata.pkl")
|
||||
|
||||
if os.path.exists(baseline_path) and os.path.exists(metadata_path):
|
||||
print(f"✅ Baseline already exists at {baseline_path}")
|
||||
return baseline_path
|
||||
|
||||
# Read metadata from the built index
|
||||
meta_path = f"{index_path}.meta.json"
|
||||
with open(meta_path) as f:
|
||||
import json
|
||||
|
||||
meta = json.loads(f.read())
|
||||
|
||||
embedding_model = meta["embedding_model"]
|
||||
passage_source = meta["passage_sources"][0]
|
||||
passage_file = passage_source["path"]
|
||||
|
||||
# Convert relative path to absolute
|
||||
if not os.path.isabs(passage_file):
|
||||
index_dir = os.path.dirname(index_path)
|
||||
passage_file = os.path.join(index_dir, os.path.basename(passage_file))
|
||||
|
||||
print(f"📊 Loading passages from {passage_file}...")
|
||||
print(f"🤖 Using embedding model: {embedding_model}")
|
||||
|
||||
# Load all passages for baseline
|
||||
passages = []
|
||||
passage_ids = []
|
||||
with open(passage_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
passages.append(data["text"])
|
||||
passage_ids.append(data["id"])
|
||||
|
||||
print(f"📄 Loaded {len(passages)} passages")
|
||||
|
||||
# Compute embeddings using the same method as LEANN
|
||||
print("🧮 Computing embeddings...")
|
||||
embeddings = compute_embeddings(
|
||||
passages,
|
||||
embedding_model,
|
||||
mode="sentence-transformers",
|
||||
use_server=False,
|
||||
)
|
||||
|
||||
print(f"📐 Embedding shape: {embeddings.shape}")
|
||||
|
||||
# Build FAISS flat index
|
||||
print("🏗️ Building FAISS IndexFlatIP...")
|
||||
dimension = embeddings.shape[1]
|
||||
index = faiss.IndexFlatIP(dimension)
|
||||
|
||||
# Add embeddings to flat index
|
||||
embeddings_f32 = embeddings.astype(np.float32)
|
||||
index.add(embeddings_f32.shape[0], faiss.swig_ptr(embeddings_f32))
|
||||
|
||||
# Save index and metadata
|
||||
faiss.write_index(index, baseline_path)
|
||||
with open(metadata_path, "wb") as f:
|
||||
pickle.dump(passage_ids, f)
|
||||
|
||||
print(f"✅ FAISS baseline saved to {baseline_path}")
|
||||
print(f"✅ Metadata saved to {metadata_path}")
|
||||
print(f"📊 Total vectors: {index.ntotal}")
|
||||
|
||||
return baseline_path
|
||||
|
||||
def extract_pdf_text(self, pdf_path: Path) -> list[dict]:
|
||||
"""Extract and chunk text from a PDF file"""
|
||||
chunks = []
|
||||
doc = pymupdf.open(pdf_path)
|
||||
|
||||
for page_num in range(len(doc)):
|
||||
page = doc[page_num]
|
||||
text = page.get_text() # type: ignore
|
||||
|
||||
if not text.strip():
|
||||
continue
|
||||
|
||||
# Create metadata
|
||||
metadata = {
|
||||
"source_file": pdf_path.name,
|
||||
"page_number": page_num + 1,
|
||||
"document_type": "10K" if "10K" in pdf_path.name else "10Q",
|
||||
"company": pdf_path.name.split("_")[0],
|
||||
"doc_period": self.extract_year_from_filename(pdf_path.name),
|
||||
}
|
||||
|
||||
# Use recursive character splitting like LangChain
|
||||
if len(text.split()) > 500:
|
||||
# Split by double newlines (paragraphs)
|
||||
paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
|
||||
|
||||
current_chunk = ""
|
||||
for para in paragraphs:
|
||||
# If adding this paragraph would make chunk too long, save current chunk
|
||||
if current_chunk and len((current_chunk + " " + para).split()) > 300:
|
||||
if current_chunk.strip():
|
||||
chunks.append(
|
||||
{
|
||||
"text": current_chunk.strip(),
|
||||
"metadata": {
|
||||
**metadata,
|
||||
"chunk_id": f"page_{page_num + 1}_chunk_{len(chunks)}",
|
||||
},
|
||||
}
|
||||
)
|
||||
current_chunk = para
|
||||
else:
|
||||
current_chunk = (current_chunk + " " + para).strip()
|
||||
|
||||
# Add the last chunk
|
||||
if current_chunk.strip():
|
||||
chunks.append(
|
||||
{
|
||||
"text": current_chunk.strip(),
|
||||
"metadata": {
|
||||
**metadata,
|
||||
"chunk_id": f"page_{page_num + 1}_chunk_{len(chunks)}",
|
||||
},
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Page is short enough, use as single chunk
|
||||
chunks.append(
|
||||
{
|
||||
"text": text.strip(),
|
||||
"metadata": {**metadata, "chunk_id": f"page_{page_num + 1}"},
|
||||
}
|
||||
)
|
||||
|
||||
doc.close()
|
||||
return chunks
|
||||
|
||||
def extract_year_from_filename(self, filename: str) -> str:
|
||||
"""Extract year from PDF filename"""
|
||||
# Try to find 4-digit year in filename
|
||||
|
||||
match = re.search(r"(\d{4})", filename)
|
||||
return match.group(1) if match else "unknown"
|
||||
|
||||
def verify_setup(self, index_path: str):
|
||||
"""Verify the setup by testing a simple query"""
|
||||
print("🧪 Verifying setup with test query...")
|
||||
|
||||
try:
|
||||
searcher = LeannSearcher(index_path)
|
||||
|
||||
# Test query
|
||||
test_query = "What is the capital expenditure for 3M in 2018?"
|
||||
results = searcher.search(test_query, top_k=3)
|
||||
|
||||
print(f"✅ Test query successful! Found {len(results)} results:")
|
||||
for i, result in enumerate(results, 1):
|
||||
company = result.metadata.get("company", "Unknown")
|
||||
year = result.metadata.get("doc_period", "Unknown")
|
||||
page = result.metadata.get("page_number", "Unknown")
|
||||
print(f" {i}. {company} {year} (page {page}) - Score: {result.score:.3f}")
|
||||
print(f" {result.text[:100]}...")
|
||||
|
||||
searcher.cleanup()
|
||||
print("✅ Setup verification completed successfully!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Setup verification failed: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Setup FinanceBench with full PDF datastore")
|
||||
parser.add_argument("--data-dir", default="data", help="Data directory")
|
||||
parser.add_argument(
|
||||
"--backend", choices=["hnsw", "diskann"], default="hnsw", help="LEANN backend"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-model",
|
||||
default="sentence-transformers/all-mpnet-base-v2",
|
||||
help="Embedding model",
|
||||
)
|
||||
parser.add_argument("--max-workers", type=int, default=5, help="Parallel download workers")
|
||||
parser.add_argument("--skip-download", action="store_true", help="Skip PDF download")
|
||||
parser.add_argument("--skip-build", action="store_true", help="Skip index building")
|
||||
parser.add_argument(
|
||||
"--build-baseline-only",
|
||||
action="store_true",
|
||||
help="Only build FAISS baseline from existing index",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print("🏦 FinanceBench Complete Setup")
|
||||
print("=" * 50)
|
||||
|
||||
setup = FinanceBenchSetup(args.data_dir)
|
||||
|
||||
try:
|
||||
if args.build_baseline_only:
|
||||
# Only build baseline from existing index
|
||||
index_path = setup.index_dir / f"financebench_full_{args.backend}"
|
||||
index_file = f"{index_path}.index"
|
||||
meta_file = f"{index_path}.leann.meta.json"
|
||||
|
||||
if not os.path.exists(index_file) or not os.path.exists(meta_file):
|
||||
print("❌ Index files not found:")
|
||||
print(f" Index: {index_file}")
|
||||
print(f" Meta: {meta_file}")
|
||||
print("💡 Run without --build-baseline-only to build the index first")
|
||||
exit(1)
|
||||
|
||||
print(f"🔨 Building baseline from existing index: {index_path}")
|
||||
baseline_path = setup.build_faiss_flat_baseline(str(index_path))
|
||||
print(f"✅ Baseline built at {baseline_path}")
|
||||
return
|
||||
|
||||
# Step 1: Download dataset
|
||||
setup.download_dataset()
|
||||
|
||||
# Step 2: Download PDFs
|
||||
if not args.skip_download:
|
||||
setup.download_all_pdfs(max_workers=args.max_workers)
|
||||
else:
|
||||
print("⏭️ Skipping PDF download")
|
||||
|
||||
# Step 3: Build LEANN index
|
||||
if not args.skip_build:
|
||||
index_path = setup.build_leann_index(
|
||||
backend=args.backend, embedding_model=args.embedding_model
|
||||
)
|
||||
|
||||
# Step 4: Build FAISS flat baseline
|
||||
print("\n🔨 Building FAISS flat baseline...")
|
||||
baseline_path = setup.build_faiss_flat_baseline(index_path)
|
||||
print(f"✅ Baseline built at {baseline_path}")
|
||||
|
||||
# Step 5: Verify setup
|
||||
setup.verify_setup(index_path)
|
||||
else:
|
||||
print("⏭️ Skipping index building")
|
||||
|
||||
print("\n🎉 FinanceBench setup completed!")
|
||||
print(f"📁 Data directory: {setup.data_dir.absolute()}")
|
||||
print("\nNext steps:")
|
||||
print(
|
||||
"1. Run evaluation: python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann"
|
||||
)
|
||||
print(
|
||||
"2. Or test manually: python -c \"from leann import LeannSearcher; s = LeannSearcher('data/index/financebench_full_hnsw.leann'); print(s.search('3M capital expenditure 2018'))\""
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n⚠️ Setup interrupted by user")
|
||||
exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n❌ Setup failed: {e}")
|
||||
exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
214
benchmarks/financebench/verify_recall.py
Normal file
@@ -0,0 +1,214 @@
|
||||
#!/usr/bin/env python3
|
||||
# /// script
|
||||
# requires-python = ">=3.9"
|
||||
# dependencies = [
|
||||
# "faiss-cpu",
|
||||
# "numpy",
|
||||
# "sentence-transformers",
|
||||
# "torch",
|
||||
# "tqdm",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
"""
|
||||
Independent recall verification script using standard FAISS.
|
||||
Creates two indexes (HNSW and Flat) and compares recall@3 at different complexities.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import faiss
|
||||
import numpy as np
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def compute_embeddings_direct(chunks: list[str], model_name: str) -> np.ndarray:
|
||||
"""
|
||||
Direct embedding computation using sentence-transformers.
|
||||
Copied logic to avoid dependency issues.
|
||||
"""
|
||||
print(f"Loading model: {model_name}")
|
||||
model = SentenceTransformer(model_name)
|
||||
|
||||
print(f"Computing embeddings for {len(chunks)} chunks...")
|
||||
embeddings = model.encode(
|
||||
chunks,
|
||||
show_progress_bar=True,
|
||||
batch_size=32,
|
||||
convert_to_numpy=True,
|
||||
normalize_embeddings=False,
|
||||
)
|
||||
|
||||
return embeddings.astype(np.float32)
|
||||
|
||||
|
||||
def load_financebench_queries(dataset_path: str, max_queries: int = 200) -> list[str]:
|
||||
"""Load FinanceBench queries from dataset"""
|
||||
queries = []
|
||||
with open(dataset_path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
queries.append(data["question"])
|
||||
if len(queries) >= max_queries:
|
||||
break
|
||||
return queries
|
||||
|
||||
|
||||
def load_passages_from_leann_index(index_path: str) -> tuple[list[str], list[str]]:
|
||||
"""Load passages from LEANN index structure"""
|
||||
meta_path = f"{index_path}.meta.json"
|
||||
with open(meta_path) as f:
|
||||
meta = json.load(f)
|
||||
|
||||
passage_source = meta["passage_sources"][0]
|
||||
passage_file = passage_source["path"]
|
||||
|
||||
# Convert relative path to absolute
|
||||
if not Path(passage_file).is_absolute():
|
||||
index_dir = Path(index_path).parent
|
||||
passage_file = index_dir / Path(passage_file).name
|
||||
|
||||
print(f"Loading passages from {passage_file}")
|
||||
|
||||
passages = []
|
||||
passage_ids = []
|
||||
with open(passage_file, encoding="utf-8") as f:
|
||||
for line in tqdm(f, desc="Loading passages"):
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
passages.append(data["text"])
|
||||
passage_ids.append(data["id"])
|
||||
|
||||
print(f"Loaded {len(passages)} passages")
|
||||
return passages, passage_ids
|
||||
|
||||
|
||||
def build_faiss_indexes(embeddings: np.ndarray) -> tuple[faiss.Index, faiss.Index]:
|
||||
"""Build FAISS indexes: Flat (ground truth) and HNSW"""
|
||||
dimension = embeddings.shape[1]
|
||||
|
||||
# Build Flat index (ground truth)
|
||||
print("Building FAISS IndexFlatIP (ground truth)...")
|
||||
flat_index = faiss.IndexFlatIP(dimension)
|
||||
flat_index.add(embeddings)
|
||||
|
||||
# Build HNSW index
|
||||
print("Building FAISS IndexHNSWFlat...")
|
||||
M = 32 # Same as LEANN default
|
||||
hnsw_index = faiss.IndexHNSWFlat(dimension, M, faiss.METRIC_INNER_PRODUCT)
|
||||
hnsw_index.hnsw.efConstruction = 200 # Same as LEANN default
|
||||
hnsw_index.add(embeddings)
|
||||
|
||||
print(f"Built indexes with {flat_index.ntotal} vectors, dimension {dimension}")
|
||||
return flat_index, hnsw_index
|
||||
|
||||
|
||||
def evaluate_recall_at_k(
|
||||
query_embeddings: np.ndarray,
|
||||
flat_index: faiss.Index,
|
||||
hnsw_index: faiss.Index,
|
||||
passage_ids: list[str],
|
||||
k: int = 3,
|
||||
ef_search: int = 64,
|
||||
) -> float:
|
||||
"""Evaluate recall@k comparing HNSW vs Flat"""
|
||||
|
||||
# Set search parameters for HNSW
|
||||
hnsw_index.hnsw.efSearch = ef_search
|
||||
|
||||
total_recall = 0.0
|
||||
num_queries = query_embeddings.shape[0]
|
||||
|
||||
for i in range(num_queries):
|
||||
query = query_embeddings[i : i + 1] # Keep 2D shape
|
||||
|
||||
# Get ground truth from Flat index (standard FAISS API)
|
||||
flat_distances, flat_indices = flat_index.search(query, k)
|
||||
ground_truth_ids = {passage_ids[idx] for idx in flat_indices[0]}
|
||||
|
||||
# Get results from HNSW index (standard FAISS API)
|
||||
hnsw_distances, hnsw_indices = hnsw_index.search(query, k)
|
||||
hnsw_ids = {passage_ids[idx] for idx in hnsw_indices[0]}
|
||||
|
||||
# Calculate recall
|
||||
intersection = ground_truth_ids.intersection(hnsw_ids)
|
||||
recall = len(intersection) / k
|
||||
total_recall += recall
|
||||
|
||||
if i < 3: # Show first few examples
|
||||
print(f" Query {i + 1}: Recall@{k} = {recall:.3f}")
|
||||
print(f" Flat: {list(ground_truth_ids)}")
|
||||
print(f" HNSW: {list(hnsw_ids)}")
|
||||
print(f" Intersection: {list(intersection)}")
|
||||
|
||||
avg_recall = total_recall / num_queries
|
||||
return avg_recall
|
||||
|
||||
|
||||
def main():
|
||||
# Configuration
|
||||
dataset_path = "data/financebench_merged.jsonl"
|
||||
index_path = "data/index/financebench_full_hnsw.leann"
|
||||
embedding_model = "sentence-transformers/all-mpnet-base-v2"
|
||||
|
||||
print("🔍 FAISS Recall Verification")
|
||||
print("=" * 50)
|
||||
|
||||
# Check if files exist
|
||||
if not Path(dataset_path).exists():
|
||||
print(f"❌ Dataset not found: {dataset_path}")
|
||||
return
|
||||
if not Path(f"{index_path}.meta.json").exists():
|
||||
print(f"❌ Index metadata not found: {index_path}.meta.json")
|
||||
return
|
||||
|
||||
# Load data
|
||||
print("📖 Loading FinanceBench queries...")
|
||||
queries = load_financebench_queries(dataset_path, max_queries=50)
|
||||
print(f"Loaded {len(queries)} queries")
|
||||
|
||||
print("📄 Loading passages from LEANN index...")
|
||||
passages, passage_ids = load_passages_from_leann_index(index_path)
|
||||
|
||||
# Compute embeddings
|
||||
print("🧮 Computing passage embeddings...")
|
||||
passage_embeddings = compute_embeddings_direct(passages, embedding_model)
|
||||
|
||||
print("🧮 Computing query embeddings...")
|
||||
query_embeddings = compute_embeddings_direct(queries, embedding_model)
|
||||
|
||||
# Build FAISS indexes
|
||||
print("🏗️ Building FAISS indexes...")
|
||||
flat_index, hnsw_index = build_faiss_indexes(passage_embeddings)
|
||||
|
||||
# Test different efSearch values (equivalent to LEANN complexity)
|
||||
print("\n📊 Evaluating Recall@3 at different efSearch values...")
|
||||
ef_search_values = [16, 32, 64, 128, 256]
|
||||
|
||||
for ef_search in ef_search_values:
|
||||
print(f"\n🧪 Testing efSearch = {ef_search}")
|
||||
start_time = time.time()
|
||||
|
||||
recall = evaluate_recall_at_k(
|
||||
query_embeddings, flat_index, hnsw_index, passage_ids, k=3, ef_search=ef_search
|
||||
)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
print(
|
||||
f"📈 efSearch {ef_search}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%) in {elapsed:.2f}s"
|
||||
)
|
||||
|
||||
print("\n✅ Verification completed!")
|
||||
print("\n📋 Summary:")
|
||||
print(" - Built independent FAISS Flat and HNSW indexes")
|
||||
print(" - Compared recall@3 at different efSearch values")
|
||||
print(" - Used same embedding model as LEANN")
|
||||
print(" - This validates LEANN's recall measurements")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1
benchmarks/laion/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
data/
|
||||
199
benchmarks/laion/README.md
Normal file
@@ -0,0 +1,199 @@
|
||||
# LAION Multimodal Benchmark
|
||||
|
||||
A multimodal benchmark for evaluating image retrieval and generation performance using LEANN with CLIP embeddings and Qwen2.5-VL for multimodal generation on LAION dataset subset.
|
||||
|
||||
## Overview
|
||||
|
||||
This benchmark evaluates:
|
||||
- **Image retrieval timing** using caption-based queries
|
||||
- **Recall@K performance** for image search
|
||||
- **Complexity analysis** across different search parameters
|
||||
- **Index size and storage efficiency**
|
||||
- **Multimodal generation** with Qwen2.5-VL for image understanding and description
|
||||
|
||||
## Dataset Configuration
|
||||
|
||||
- **Dataset**: LAION-400M subset (10,000 images)
|
||||
- **Embeddings**: Pre-computed CLIP ViT-B/32 (512 dimensions)
|
||||
- **Queries**: 200 random captions from the dataset
|
||||
- **Ground Truth**: Self-recall (query caption → original image)
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Setup the benchmark
|
||||
|
||||
```bash
|
||||
cd benchmarks/laion
|
||||
python setup_laion.py --num-samples 10000 --num-queries 200
|
||||
```
|
||||
|
||||
This will:
|
||||
- Create dummy LAION data (10K samples)
|
||||
- Generate CLIP embeddings (512-dim)
|
||||
- Build LEANN index with HNSW backend
|
||||
- Create 200 evaluation queries
|
||||
|
||||
### 2. Run evaluation
|
||||
|
||||
```bash
|
||||
# Run all evaluation stages
|
||||
python evaluate_laion.py --index data/laion_index.leann
|
||||
|
||||
# Run specific stages
|
||||
python evaluate_laion.py --index data/laion_index.leann --stage 2 # Recall evaluation
|
||||
python evaluate_laion.py --index data/laion_index.leann --stage 3 # Complexity analysis
|
||||
python evaluate_laion.py --index data/laion_index.leann --stage 4 # Index comparison
|
||||
python evaluate_laion.py --index data/laion_index.leann --stage 5 # Multimodal generation
|
||||
|
||||
# Multimodal generation with Qwen2.5-VL
|
||||
python evaluate_laion.py --index data/laion_index.leann --stage 5 --model-name Qwen/Qwen2.5-VL-7B-Instruct
|
||||
```
|
||||
|
||||
### 3. Save results
|
||||
|
||||
```bash
|
||||
python evaluate_laion.py --index data/laion_index.leann --output results.json
|
||||
```
|
||||
|
||||
## Configuration Options
|
||||
|
||||
### Setup Options
|
||||
```bash
|
||||
python setup_laion.py \
|
||||
--num-samples 10000 \
|
||||
--num-queries 200 \
|
||||
--index-path data/laion_index.leann \
|
||||
--backend hnsw
|
||||
```
|
||||
|
||||
### Evaluation Options
|
||||
```bash
|
||||
python evaluate_laion.py \
|
||||
--index data/laion_index.leann \
|
||||
--queries data/evaluation_queries.jsonl \
|
||||
--complexity 64 \
|
||||
--top-k 3 \
|
||||
--num-samples 100 \
|
||||
--stage all
|
||||
```
|
||||
|
||||
## Evaluation Stages
|
||||
|
||||
### Stage 2: Recall Evaluation
|
||||
- Evaluates Recall@3 for multimodal retrieval
|
||||
- Compares LEANN vs FAISS baseline performance
|
||||
- Self-recall: query caption should retrieve original image
|
||||
|
||||
### Stage 3: Complexity Analysis
|
||||
- Binary search for optimal complexity (90% recall target)
|
||||
- Tests performance across different complexity levels
|
||||
- Analyzes speed vs. accuracy tradeoffs
|
||||
|
||||
### Stage 4: Index Comparison
|
||||
- Compares compact vs non-compact index sizes
|
||||
- Measures search performance differences
|
||||
- Reports storage efficiency and speed ratios
|
||||
|
||||
### Stage 5: Multimodal Generation
|
||||
- Uses Qwen2.5-VL for image understanding and description
|
||||
- Retrieval-Augmented Generation (RAG) with multimodal context
|
||||
- Measures both search and generation timing
|
||||
|
||||
## Output Metrics
|
||||
|
||||
### Timing Metrics
|
||||
- Average/median/min/max search time
|
||||
- Standard deviation
|
||||
- Searches per second
|
||||
- Latency in milliseconds
|
||||
|
||||
### Recall Metrics
|
||||
- Recall@3 percentage for image retrieval
|
||||
- Number of queries with ground truth
|
||||
|
||||
### Index Metrics
|
||||
- Total index size (MB)
|
||||
- Component breakdown (index, passages, metadata)
|
||||
- Storage savings (compact vs non-compact)
|
||||
- Backend and embedding model info
|
||||
|
||||
### Generation Metrics (Stage 5)
|
||||
- Average search time per query
|
||||
- Average generation time per query
|
||||
- Time distribution (search vs generation)
|
||||
- Sample multimodal responses
|
||||
- Model: Qwen2.5-VL performance
|
||||
|
||||
## Benchmark Results
|
||||
|
||||
### LEANN-RAG Performance (CLIP ViT-L/14 + Qwen2.5-VL)
|
||||
|
||||
**Stage 3: Optimal Complexity Analysis**
|
||||
- **Optimal Complexity**: 85 (achieving 90% Recall@3)
|
||||
- **Binary Search Range**: 1-128
|
||||
- **Target Recall**: 90%
|
||||
- **Index Type**: Non-compact (for fast binary search)
|
||||
|
||||
**Stage 5: Multimodal Generation Performance (Qwen2.5-VL)**
|
||||
- **Total Queries**: 20
|
||||
- **Average Search Time**: 1.200s per query
|
||||
- **Average Generation Time**: 6.558s per query
|
||||
- **Time Distribution**: Search 15.5%, Generation 84.5%
|
||||
- **LLM Backend**: HuggingFace transformers
|
||||
- **Model**: Qwen/Qwen2.5-VL-7B-Instruct
|
||||
- **Optimal Complexity**: 85
|
||||
|
||||
**System Performance:**
|
||||
- **Index Size**: ~10,000 image embeddings from LAION subset
|
||||
- **Embedding Model**: CLIP ViT-L/14 (768 dimensions)
|
||||
- **Backend**: HNSW with cosine distance
|
||||
|
||||
### Example Results
|
||||
|
||||
```
|
||||
🎯 LAION MULTIMODAL BENCHMARK RESULTS
|
||||
============================================================
|
||||
|
||||
📊 Multimodal Generation Results:
|
||||
Total Queries: 20
|
||||
Avg Search Time: 1.200s
|
||||
Avg Generation Time: 6.558s
|
||||
Time Distribution: Search 15.5%, Generation 84.5%
|
||||
LLM Backend: HuggingFace transformers
|
||||
Model: Qwen/Qwen2.5-VL-7B-Instruct
|
||||
|
||||
⚙️ Optimal Complexity Analysis:
|
||||
Target Recall: 90%
|
||||
Optimal Complexity: 85
|
||||
Binary Search Range: 1-128
|
||||
Non-compact Index (fast search, no recompute)
|
||||
|
||||
🚀 Performance Summary:
|
||||
Multimodal RAG: 7.758s total per query
|
||||
Search: 15.5% of total time
|
||||
Generation: 84.5% of total time
|
||||
```
|
||||
|
||||
## Directory Structure
|
||||
|
||||
```
|
||||
benchmarks/laion/
|
||||
├── setup_laion.py # Setup script
|
||||
├── evaluate_laion.py # Evaluation script
|
||||
├── README.md # This file
|
||||
└── data/ # Generated data
|
||||
├── laion_images/ # Image files (placeholder)
|
||||
├── laion_metadata.jsonl # Image metadata
|
||||
├── laion_passages.jsonl # LEANN passages
|
||||
├── laion_embeddings.npy # CLIP embeddings
|
||||
├── evaluation_queries.jsonl # Evaluation queries
|
||||
└── laion_index.leann/ # LEANN index files
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- Current implementation uses dummy data for demonstration
|
||||
- For real LAION data, implement actual download logic in `setup_laion.py`
|
||||
- CLIP embeddings are randomly generated - replace with real CLIP model for production
|
||||
- Adjust `num_samples` and `num_queries` based on available resources
|
||||
- Consider using `--num-samples` during evaluation for faster testing
|
||||
725
benchmarks/laion/evaluate_laion.py
Normal file
@@ -0,0 +1,725 @@
|
||||
"""
|
||||
LAION Multimodal Benchmark Evaluation Script - Modular Recall-based Evaluation
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from leann import LeannSearcher
|
||||
from leann_backend_hnsw import faiss
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
from ..llm_utils import evaluate_multimodal_rag, load_qwen_vl_model
|
||||
|
||||
# Setup logging to reduce verbose output
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
logging.getLogger("leann.api").setLevel(logging.WARNING)
|
||||
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
class RecallEvaluator:
|
||||
"""Stage 2: Evaluate Recall@3 (LEANN vs FAISS baseline for multimodal retrieval)"""
|
||||
|
||||
def __init__(self, index_path: str, baseline_dir: str):
|
||||
self.index_path = index_path
|
||||
self.baseline_dir = baseline_dir
|
||||
self.searcher = LeannSearcher(index_path)
|
||||
|
||||
# Load FAISS flat baseline (image embeddings)
|
||||
baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
|
||||
metadata_path = os.path.join(baseline_dir, "metadata.pkl")
|
||||
|
||||
self.faiss_index = faiss.read_index(baseline_index_path)
|
||||
with open(metadata_path, "rb") as f:
|
||||
self.image_ids = pickle.load(f)
|
||||
print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} image vectors")
|
||||
|
||||
# Load sentence-transformers CLIP for text embedding (ViT-L/14)
|
||||
self.st_clip = SentenceTransformer("clip-ViT-L-14")
|
||||
|
||||
def evaluate_recall_at_3(
|
||||
self, captions: list[str], complexity: int = 64, recompute_embeddings: bool = True
|
||||
) -> float:
|
||||
"""Evaluate recall@3 for multimodal retrieval: caption queries -> image results"""
|
||||
recompute_str = "with recompute" if recompute_embeddings else "no recompute"
|
||||
print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")
|
||||
|
||||
total_recall = 0.0
|
||||
num_queries = len(captions)
|
||||
|
||||
for i, caption in enumerate(captions):
|
||||
# Get ground truth: search with FAISS flat using caption text embedding
|
||||
# Generate CLIP text embedding for caption via sentence-transformers (normalized)
|
||||
query_embedding = self.st_clip.encode(
|
||||
[caption], convert_to_numpy=True, normalize_embeddings=True, show_progress_bar=False
|
||||
).astype(np.float32)
|
||||
|
||||
# Search FAISS flat for ground truth using LEANN's modified faiss API
|
||||
n = query_embedding.shape[0] # Number of queries
|
||||
k = 3 # Number of nearest neighbors
|
||||
distances = np.zeros((n, k), dtype=np.float32)
|
||||
labels = np.zeros((n, k), dtype=np.int64)
|
||||
|
||||
self.faiss_index.search(
|
||||
n,
|
||||
faiss.swig_ptr(query_embedding),
|
||||
k,
|
||||
faiss.swig_ptr(distances),
|
||||
faiss.swig_ptr(labels),
|
||||
)
|
||||
|
||||
# Extract the results (image IDs from FAISS)
|
||||
baseline_ids = {self.image_ids[idx] for idx in labels[0]}
|
||||
|
||||
# Search with LEANN at specified complexity (using caption as text query)
|
||||
test_results = self.searcher.search(
|
||||
caption,
|
||||
top_k=3,
|
||||
complexity=complexity,
|
||||
recompute_embeddings=recompute_embeddings,
|
||||
)
|
||||
test_ids = {result.id for result in test_results}
|
||||
|
||||
# Calculate recall@3 = |intersection| / |ground_truth|
|
||||
intersection = test_ids.intersection(baseline_ids)
|
||||
recall = len(intersection) / 3.0 # Ground truth size is 3
|
||||
total_recall += recall
|
||||
|
||||
if i < 3: # Show first few examples
|
||||
print(f" Query {i + 1}: '{caption[:50]}...' -> Recall@3: {recall:.3f}")
|
||||
print(f" FAISS ground truth: {list(baseline_ids)}")
|
||||
print(f" LEANN results (C={complexity}, {recompute_str}): {list(test_ids)}")
|
||||
print(f" Intersection: {list(intersection)}")
|
||||
|
||||
avg_recall = total_recall / num_queries
|
||||
print(f"📊 Average Recall@3: {avg_recall:.3f} ({avg_recall * 100:.1f}%)")
|
||||
return avg_recall
|
||||
|
||||
def cleanup(self):
|
||||
"""Cleanup resources"""
|
||||
if hasattr(self, "searcher"):
|
||||
self.searcher.cleanup()
|
||||
|
||||
|
||||
class LAIONEvaluator:
|
||||
def __init__(self, index_path: str):
|
||||
self.index_path = index_path
|
||||
self.searcher = LeannSearcher(index_path)
|
||||
|
||||
def load_queries(self, queries_file: str) -> list[str]:
|
||||
"""Load caption queries from evaluation file"""
|
||||
captions = []
|
||||
with open(queries_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
query_data = json.loads(line)
|
||||
captions.append(query_data["query"])
|
||||
|
||||
print(f"📊 Loaded {len(captions)} caption queries")
|
||||
return captions
|
||||
|
||||
def analyze_index_sizes(self) -> dict:
|
||||
"""Analyze index sizes, emphasizing .index only (exclude passages)."""
|
||||
print("📏 Analyzing index sizes (.index only)...")
|
||||
|
||||
# Get all index-related files
|
||||
index_path = Path(self.index_path)
|
||||
index_dir = index_path.parent
|
||||
index_name = index_path.stem # Remove .leann extension
|
||||
|
||||
sizes: dict[str, float] = {}
|
||||
|
||||
# Core index files
|
||||
index_file = index_dir / f"{index_name}.index"
|
||||
meta_file = index_dir / f"{index_path.name}.meta.json" # Keep .leann for meta file
|
||||
passages_file = index_dir / f"{index_path.name}.passages.jsonl" # Keep .leann for passages
|
||||
passages_idx_file = index_dir / f"{index_path.name}.passages.idx" # Keep .leann for idx
|
||||
|
||||
# Core index size (.index only)
|
||||
index_mb = index_file.stat().st_size / (1024 * 1024) if index_file.exists() else 0.0
|
||||
sizes["index_only_mb"] = index_mb
|
||||
|
||||
# Other files for reference (not counted in index_only_mb)
|
||||
sizes["metadata_mb"] = (
|
||||
meta_file.stat().st_size / (1024 * 1024) if meta_file.exists() else 0.0
|
||||
)
|
||||
sizes["passages_text_mb"] = (
|
||||
passages_file.stat().st_size / (1024 * 1024) if passages_file.exists() else 0.0
|
||||
)
|
||||
sizes["passages_index_mb"] = (
|
||||
passages_idx_file.stat().st_size / (1024 * 1024) if passages_idx_file.exists() else 0.0
|
||||
)
|
||||
|
||||
print(f" 📁 .index size: {index_mb:.1f} MB")
|
||||
if sizes["metadata_mb"]:
|
||||
print(f" 🧾 metadata: {sizes['metadata_mb']:.3f} MB")
|
||||
if sizes["passages_text_mb"] or sizes["passages_index_mb"]:
|
||||
print(
|
||||
f" (passages excluded) text: {sizes['passages_text_mb']:.1f} MB, idx: {sizes['passages_index_mb']:.1f} MB"
|
||||
)
|
||||
|
||||
return sizes
|
||||
|
||||
def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
|
||||
"""Create a non-compact index for comparison purposes"""
|
||||
print("🏗️ Building non-compact index from existing passages...")
|
||||
|
||||
# Load existing passages from current index
|
||||
from leann import LeannBuilder
|
||||
|
||||
current_index_path = Path(self.index_path)
|
||||
current_index_dir = current_index_path.parent
|
||||
current_index_name = current_index_path.name
|
||||
|
||||
# Read metadata to get passage source
|
||||
meta_path = current_index_dir / f"{current_index_name}.meta.json"
|
||||
with open(meta_path) as f:
|
||||
meta = json.load(f)
|
||||
|
||||
passage_source = meta["passage_sources"][0]
|
||||
passage_file = passage_source["path"]
|
||||
|
||||
# Convert relative path to absolute
|
||||
if not Path(passage_file).is_absolute():
|
||||
passage_file = current_index_dir / Path(passage_file).name
|
||||
|
||||
print(f"📄 Loading passages from {passage_file}...")
|
||||
|
||||
# Load CLIP embeddings
|
||||
embeddings_file = current_index_dir / "clip_image_embeddings.npy"
|
||||
embeddings = np.load(embeddings_file)
|
||||
print(f"📐 Loaded embeddings shape: {embeddings.shape}")
|
||||
|
||||
# Build non-compact index with same passages and embeddings
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
# Use CLIP text encoder (ViT-L/14) to match image embeddings (768-dim)
|
||||
embedding_model="clip-ViT-L-14",
|
||||
embedding_mode="sentence-transformers",
|
||||
is_recompute=False, # Disable recompute (store embeddings)
|
||||
is_compact=False, # Disable compact storage
|
||||
distance_metric="cosine",
|
||||
**{
|
||||
k: v
|
||||
for k, v in meta.get("backend_kwargs", {}).items()
|
||||
if k not in ["is_recompute", "is_compact", "distance_metric"]
|
||||
},
|
||||
)
|
||||
|
||||
# Prepare ids and add passages
|
||||
ids: list[str] = []
|
||||
with open(passage_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
ids.append(str(data["id"]))
|
||||
# Ensure metadata contains the id used by the vector index
|
||||
metadata = {**data.get("metadata", {}), "id": data["id"]}
|
||||
builder.add_text(text=data["text"], metadata=metadata)
|
||||
|
||||
if len(ids) != embeddings.shape[0]:
|
||||
raise ValueError(
|
||||
f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
|
||||
)
|
||||
|
||||
# Persist a pickle for build_index_from_embeddings
|
||||
pkl_path = current_index_dir / "clip_image_embeddings.pkl"
|
||||
with open(pkl_path, "wb") as pf:
|
||||
pickle.dump((ids, embeddings.astype(np.float32)), pf)
|
||||
|
||||
print(
|
||||
f"🔨 Building non-compact index at {non_compact_index_path} from precomputed embeddings..."
|
||||
)
|
||||
builder.build_index_from_embeddings(non_compact_index_path, str(pkl_path))
|
||||
|
||||
# Analyze the non-compact index size
|
||||
temp_evaluator = LAIONEvaluator(non_compact_index_path)
|
||||
non_compact_sizes = temp_evaluator.analyze_index_sizes()
|
||||
non_compact_sizes["index_type"] = "non_compact"
|
||||
|
||||
return non_compact_sizes
|
||||
|
||||
def compare_index_performance(
|
||||
self, non_compact_path: str, compact_path: str, test_captions: list, complexity: int
|
||||
) -> dict:
|
||||
"""Compare performance between non-compact and compact indexes"""
|
||||
print("⚡ Comparing search performance between indexes...")
|
||||
|
||||
# Test queries
|
||||
test_queries = test_captions[:5]
|
||||
|
||||
results = {
|
||||
"non_compact": {"search_times": []},
|
||||
"compact": {"search_times": []},
|
||||
"avg_search_times": {},
|
||||
"speed_ratio": 0.0,
|
||||
}
|
||||
|
||||
# Test non-compact index (no recompute)
|
||||
print(" 🔍 Testing non-compact index (no recompute)...")
|
||||
non_compact_searcher = LeannSearcher(non_compact_path)
|
||||
|
||||
for caption in test_queries:
|
||||
start_time = time.time()
|
||||
_ = non_compact_searcher.search(
|
||||
caption, top_k=3, complexity=complexity, recompute_embeddings=False
|
||||
)
|
||||
search_time = time.time() - start_time
|
||||
results["non_compact"]["search_times"].append(search_time)
|
||||
|
||||
# Test compact index (with recompute)
|
||||
print(" 🔍 Testing compact index (with recompute)...")
|
||||
compact_searcher = LeannSearcher(compact_path)
|
||||
|
||||
for caption in test_queries:
|
||||
start_time = time.time()
|
||||
_ = compact_searcher.search(
|
||||
caption, top_k=3, complexity=complexity, recompute_embeddings=True
|
||||
)
|
||||
search_time = time.time() - start_time
|
||||
results["compact"]["search_times"].append(search_time)
|
||||
|
||||
# Calculate averages
|
||||
results["avg_search_times"]["non_compact"] = sum(
|
||||
results["non_compact"]["search_times"]
|
||||
) / len(results["non_compact"]["search_times"])
|
||||
results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
|
||||
results["compact"]["search_times"]
|
||||
)
|
||||
|
||||
# Performance ratio
|
||||
if results["avg_search_times"]["compact"] > 0:
|
||||
results["speed_ratio"] = (
|
||||
results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
|
||||
)
|
||||
else:
|
||||
results["speed_ratio"] = float("inf")
|
||||
|
||||
print(
|
||||
f" Non-compact (no recompute): {results['avg_search_times']['non_compact']:.3f}s avg"
|
||||
)
|
||||
print(f" Compact (with recompute): {results['avg_search_times']['compact']:.3f}s avg")
|
||||
print(f" Speed ratio: {results['speed_ratio']:.2f}x")
|
||||
|
||||
# Cleanup
|
||||
non_compact_searcher.cleanup()
|
||||
compact_searcher.cleanup()
|
||||
|
||||
return results
|
||||
|
||||
def _print_results(self, timing_metrics: dict):
|
||||
"""Print evaluation results"""
|
||||
print("\n🎯 LAION MULTIMODAL BENCHMARK RESULTS")
|
||||
print("=" * 60)
|
||||
|
||||
# Index comparison analysis (prefer .index-only view if present)
|
||||
if "current_index" in timing_metrics and "non_compact_index" in timing_metrics:
|
||||
current = timing_metrics["current_index"]
|
||||
non_compact = timing_metrics["non_compact_index"]
|
||||
|
||||
if "index_only_mb" in current and "index_only_mb" in non_compact:
|
||||
print("\n📏 Index Comparison Analysis (.index only):")
|
||||
print(f" Compact index (current): {current.get('index_only_mb', 0):.1f} MB")
|
||||
print(f" Non-compact index: {non_compact.get('index_only_mb', 0):.1f} MB")
|
||||
print(
|
||||
f" Storage saving by compact: {timing_metrics.get('storage_saving_percent', 0):.1f}%"
|
||||
)
|
||||
# Show excluded components for reference if available
|
||||
if any(
|
||||
k in non_compact
|
||||
for k in ("passages_text_mb", "passages_index_mb", "metadata_mb")
|
||||
):
|
||||
print(" (passages excluded in totals, shown for reference):")
|
||||
print(
|
||||
f" - Passages text: {non_compact.get('passages_text_mb', 0):.1f} MB, "
|
||||
f"Passages index: {non_compact.get('passages_index_mb', 0):.1f} MB, "
|
||||
f"Metadata: {non_compact.get('metadata_mb', 0):.3f} MB"
|
||||
)
|
||||
else:
|
||||
# Fallback to legacy totals if running with older metrics
|
||||
print("\n📏 Index Comparison Analysis:")
|
||||
print(
|
||||
f" Compact index (current): {current.get('total_with_embeddings', 0):.1f} MB"
|
||||
)
|
||||
print(
|
||||
f" Non-compact index (with embeddings): {non_compact.get('total_with_embeddings', 0):.1f} MB"
|
||||
)
|
||||
print(
|
||||
f" Storage saving by compact: {timing_metrics.get('storage_saving_percent', 0):.1f}%"
|
||||
)
|
||||
print(" Component breakdown (non-compact):")
|
||||
print(f" - Main index: {non_compact.get('index', 0):.1f} MB")
|
||||
print(f" - Passages text: {non_compact.get('passages_text', 0):.1f} MB")
|
||||
print(f" - Passages index: {non_compact.get('passages_index', 0):.1f} MB")
|
||||
print(f" - Metadata: {non_compact.get('metadata', 0):.1f} MB")
|
||||
|
||||
# Performance comparison
|
||||
if "performance_comparison" in timing_metrics:
|
||||
perf = timing_metrics["performance_comparison"]
|
||||
print("\n⚡ Performance Comparison:")
|
||||
print(
|
||||
f" Non-compact (no recompute): {perf.get('avg_search_times', {}).get('non_compact', 0):.3f}s avg"
|
||||
)
|
||||
print(
|
||||
f" Compact (with recompute): {perf.get('avg_search_times', {}).get('compact', 0):.3f}s avg"
|
||||
)
|
||||
print(f" Speed ratio: {perf.get('speed_ratio', 0):.2f}x")
|
||||
|
||||
# Legacy single index analysis (fallback)
|
||||
if "total_with_embeddings" in timing_metrics and "current_index" not in timing_metrics:
|
||||
print("\n📏 Index Size Analysis:")
|
||||
print(
|
||||
f" Index with embeddings: {timing_metrics.get('total_with_embeddings', 0):.1f} MB"
|
||||
)
|
||||
print(
|
||||
f" Estimated pruned index: {timing_metrics.get('total_without_embeddings', 0):.1f} MB"
|
||||
)
|
||||
print(f" Compression ratio: {timing_metrics.get('compression_ratio', 0):.2f}x")
|
||||
|
||||
def cleanup(self):
|
||||
"""Cleanup resources"""
|
||||
if self.searcher:
|
||||
self.searcher.cleanup()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="LAION Multimodal Benchmark Evaluation")
|
||||
parser.add_argument("--index", required=True, help="Path to LEANN index")
|
||||
parser.add_argument(
|
||||
"--queries", default="data/evaluation_queries.jsonl", help="Path to evaluation queries"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stage",
|
||||
choices=["2", "3", "4", "5", "all"],
|
||||
default="all",
|
||||
help="Which stage to run (2=recall, 3=complexity, 4=index comparison, 5=generation)",
|
||||
)
|
||||
parser.add_argument("--complexity", type=int, default=None, help="Complexity for search")
|
||||
parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
|
||||
parser.add_argument("--output", help="Save results to JSON file")
|
||||
parser.add_argument(
|
||||
"--llm-backend",
|
||||
choices=["hf"],
|
||||
default="hf",
|
||||
help="LLM backend (Qwen2.5-VL only supports HF)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-name", default="Qwen/Qwen2.5-VL-7B-Instruct", help="Multimodal model name"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
# Check if baseline exists
|
||||
baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
|
||||
if not os.path.exists(baseline_index_path):
|
||||
print(f"❌ FAISS baseline not found at {baseline_index_path}")
|
||||
print("💡 Please run setup_laion.py first to build the baseline")
|
||||
exit(1)
|
||||
|
||||
if args.stage == "2" or args.stage == "all":
|
||||
# Stage 2: Recall@3 evaluation
|
||||
print("🚀 Starting Stage 2: Recall@3 evaluation for multimodal retrieval")
|
||||
|
||||
evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||
|
||||
# Load caption queries for testing
|
||||
laion_evaluator = LAIONEvaluator(args.index)
|
||||
captions = laion_evaluator.load_queries(args.queries)
|
||||
|
||||
# Test with queries for robust measurement
|
||||
test_captions = captions[:100] # Use subset for speed
|
||||
print(f"🧪 Testing with {len(test_captions)} caption queries")
|
||||
|
||||
# Test with complexity 64
|
||||
complexity = 64
|
||||
recall = evaluator.evaluate_recall_at_3(test_captions, complexity)
|
||||
print(f"📈 Recall@3 at complexity {complexity}: {recall * 100:.1f}%")
|
||||
|
||||
evaluator.cleanup()
|
||||
print("✅ Stage 2 completed!\n")
|
||||
|
||||
# Shared non-compact index path for Stage 3 and 4
|
||||
non_compact_index_path = args.index.replace(".leann", "_noncompact.leann")
|
||||
complexity = args.complexity
|
||||
|
||||
if args.stage == "3" or args.stage == "all":
|
||||
# Stage 3: Binary search for 90% recall complexity
|
||||
print("🚀 Starting Stage 3: Binary search for 90% recall complexity")
|
||||
print(
|
||||
"💡 Creating non-compact index for fast binary search with recompute_embeddings=False"
|
||||
)
|
||||
|
||||
# Create non-compact index for binary search
|
||||
print("🏗️ Creating non-compact index for binary search...")
|
||||
evaluator = LAIONEvaluator(args.index)
|
||||
evaluator.create_non_compact_index_for_comparison(non_compact_index_path)
|
||||
|
||||
# Use non-compact index for binary search
|
||||
binary_search_evaluator = RecallEvaluator(non_compact_index_path, args.baseline_dir)
|
||||
|
||||
# Load caption queries for testing
|
||||
captions = evaluator.load_queries(args.queries)
|
||||
|
||||
# Use subset for robust measurement
|
||||
test_captions = captions[:50] # Smaller subset for binary search speed
|
||||
print(f"🧪 Testing with {len(test_captions)} caption queries")
|
||||
|
||||
# Binary search for 90% recall complexity
|
||||
target_recall = 0.9
|
||||
min_complexity, max_complexity = 1, 128
|
||||
|
||||
print(f"🔍 Binary search for {target_recall * 100}% recall complexity...")
|
||||
print(f"Search range: {min_complexity} to {max_complexity}")
|
||||
|
||||
best_complexity = None
|
||||
best_recall = 0.0
|
||||
|
||||
while min_complexity <= max_complexity:
|
||||
mid_complexity = (min_complexity + max_complexity) // 2
|
||||
|
||||
print(
|
||||
f"\n🧪 Testing complexity {mid_complexity} (no recompute, non-compact index)..."
|
||||
)
|
||||
# Use recompute_embeddings=False on non-compact index for fast binary search
|
||||
recall = binary_search_evaluator.evaluate_recall_at_3(
|
||||
test_captions, mid_complexity, recompute_embeddings=False
|
||||
)
|
||||
|
||||
print(
|
||||
f" Complexity {mid_complexity}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%)"
|
||||
)
|
||||
|
||||
if recall >= target_recall:
|
||||
best_complexity = mid_complexity
|
||||
best_recall = recall
|
||||
max_complexity = mid_complexity - 1
|
||||
print(" ✅ Target reached! Searching for lower complexity...")
|
||||
else:
|
||||
min_complexity = mid_complexity + 1
|
||||
print(" ❌ Below target. Searching for higher complexity...")
|
||||
|
||||
if best_complexity is not None:
|
||||
print("\n🎯 Optimal complexity found!")
|
||||
print(f" Complexity: {best_complexity}")
|
||||
print(f" Recall@3: {best_recall:.3f} ({best_recall * 100:.1f}%)")
|
||||
|
||||
# Test a few complexities around the optimal one for verification
|
||||
print("\n🔬 Verification test around optimal complexity:")
|
||||
verification_complexities = [
|
||||
max(1, best_complexity - 2),
|
||||
max(1, best_complexity - 1),
|
||||
best_complexity,
|
||||
best_complexity + 1,
|
||||
best_complexity + 2,
|
||||
]
|
||||
|
||||
for complexity in verification_complexities:
|
||||
if complexity <= 512: # reasonable upper bound
|
||||
recall = binary_search_evaluator.evaluate_recall_at_3(
|
||||
test_captions, complexity, recompute_embeddings=False
|
||||
)
|
||||
status = "✅" if recall >= target_recall else "❌"
|
||||
print(f" {status} Complexity {complexity:3d}: {recall * 100:5.1f}%")
|
||||
|
||||
# Now test the optimal complexity with compact index and recompute for comparison
|
||||
print(
|
||||
f"\n🔄 Testing optimal complexity {best_complexity} on compact index WITH recompute..."
|
||||
)
|
||||
compact_evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||
recall_with_recompute = compact_evaluator.evaluate_recall_at_3(
|
||||
test_captions[:10], best_complexity, recompute_embeddings=True
|
||||
)
|
||||
print(
|
||||
f" ✅ Complexity {best_complexity} (compact index with recompute): {recall_with_recompute * 100:.1f}%"
|
||||
)
|
||||
complexity = best_complexity
|
||||
print(
|
||||
f" 📊 Recall difference: {abs(best_recall - recall_with_recompute) * 100:.2f}%"
|
||||
)
|
||||
compact_evaluator.cleanup()
|
||||
else:
|
||||
print(f"\n❌ Could not find complexity achieving {target_recall * 100}% recall")
|
||||
print("All tested complexities were below target.")
|
||||
|
||||
# Cleanup evaluators (keep non-compact index for Stage 4)
|
||||
binary_search_evaluator.cleanup()
|
||||
evaluator.cleanup()
|
||||
|
||||
print("✅ Stage 3 completed! Non-compact index saved for Stage 4.\n")
|
||||
|
||||
if args.stage == "4" or args.stage == "all":
|
||||
# Stage 4: Index comparison (without LLM generation)
|
||||
print("🚀 Starting Stage 4: Index comparison analysis")
|
||||
|
||||
# Use LAION evaluator for index comparison
|
||||
evaluator = LAIONEvaluator(args.index)
|
||||
|
||||
# Load caption queries
|
||||
captions = evaluator.load_queries(args.queries)
|
||||
|
||||
# Step 1: Analyze current (compact) index
|
||||
print("\n📏 Analyzing current index (compact, pruned)...")
|
||||
compact_size_metrics = evaluator.analyze_index_sizes()
|
||||
compact_size_metrics["index_type"] = "compact"
|
||||
|
||||
# Step 2: Use existing non-compact index or create if needed
|
||||
if Path(non_compact_index_path).exists():
|
||||
print(
|
||||
f"\n📁 Using existing non-compact index from Stage 3: {non_compact_index_path}"
|
||||
)
|
||||
temp_evaluator = LAIONEvaluator(non_compact_index_path)
|
||||
non_compact_size_metrics = temp_evaluator.analyze_index_sizes()
|
||||
non_compact_size_metrics["index_type"] = "non_compact"
|
||||
else:
|
||||
print("\n🏗️ Creating non-compact index (with embeddings) for comparison...")
|
||||
non_compact_size_metrics = evaluator.create_non_compact_index_for_comparison(
|
||||
non_compact_index_path
|
||||
)
|
||||
|
||||
# Step 3: Compare index sizes (.index only)
|
||||
print("\n📊 Index size comparison (.index only):")
|
||||
print(
|
||||
f" Compact index (current): {compact_size_metrics.get('index_only_mb', 0):.1f} MB"
|
||||
)
|
||||
print(f" Non-compact index: {non_compact_size_metrics.get('index_only_mb', 0):.1f} MB")
|
||||
|
||||
storage_saving = 0.0
|
||||
if non_compact_size_metrics.get("index_only_mb", 0) > 0:
|
||||
storage_saving = (
|
||||
(
|
||||
non_compact_size_metrics.get("index_only_mb", 0)
|
||||
- compact_size_metrics.get("index_only_mb", 0)
|
||||
)
|
||||
/ non_compact_size_metrics.get("index_only_mb", 1)
|
||||
* 100
|
||||
)
|
||||
print(f" Storage saving by compact: {storage_saving:.1f}%")
|
||||
|
||||
# Step 4: Performance comparison between the two indexes
|
||||
if complexity is None:
|
||||
raise ValueError("Complexity is required for index comparison")
|
||||
|
||||
print("\n⚡ Performance comparison between indexes...")
|
||||
performance_metrics = evaluator.compare_index_performance(
|
||||
non_compact_index_path, args.index, captions[:10], complexity=complexity
|
||||
)
|
||||
|
||||
# Combine all metrics
|
||||
combined_metrics = {
|
||||
"current_index": compact_size_metrics,
|
||||
"non_compact_index": non_compact_size_metrics,
|
||||
"performance_comparison": performance_metrics,
|
||||
"storage_saving_percent": storage_saving,
|
||||
}
|
||||
|
||||
# Print comprehensive results
|
||||
evaluator._print_results(combined_metrics)
|
||||
|
||||
# Save results if requested
|
||||
if args.output:
|
||||
print(f"\n💾 Saving results to {args.output}...")
|
||||
with open(args.output, "w") as f:
|
||||
json.dump(combined_metrics, f, indent=2, default=str)
|
||||
print(f"✅ Results saved to {args.output}")
|
||||
|
||||
evaluator.cleanup()
|
||||
print("✅ Stage 4 completed!\n")
|
||||
|
||||
if args.stage in ("5", "all"):
|
||||
print("🚀 Starting Stage 5: Multimodal generation with Qwen2.5-VL")
|
||||
evaluator = LAIONEvaluator(args.index)
|
||||
captions = evaluator.load_queries(args.queries)
|
||||
test_captions = captions[: min(20, len(captions))] # Use subset for generation
|
||||
|
||||
print(f"🧪 Testing multimodal generation with {len(test_captions)} queries")
|
||||
|
||||
# Load Qwen2.5-VL model
|
||||
try:
|
||||
print("Loading Qwen2.5-VL model...")
|
||||
processor, model = load_qwen_vl_model(args.model_name)
|
||||
|
||||
# Run multimodal generation evaluation
|
||||
complexity = args.complexity or 64
|
||||
gen_results = evaluate_multimodal_rag(
|
||||
evaluator.searcher,
|
||||
test_captions,
|
||||
processor=processor,
|
||||
model=model,
|
||||
complexity=complexity,
|
||||
)
|
||||
|
||||
print("\n📊 Multimodal Generation Results:")
|
||||
print(f" Total Queries: {len(test_captions)}")
|
||||
print(f" Avg Search Time: {gen_results['avg_search_time']:.3f}s")
|
||||
print(f" Avg Generation Time: {gen_results['avg_generation_time']:.3f}s")
|
||||
total_time = gen_results["avg_search_time"] + gen_results["avg_generation_time"]
|
||||
search_pct = (gen_results["avg_search_time"] / total_time) * 100
|
||||
gen_pct = (gen_results["avg_generation_time"] / total_time) * 100
|
||||
print(f" Time Distribution: Search {search_pct:.1f}%, Generation {gen_pct:.1f}%")
|
||||
print(" LLM Backend: HuggingFace transformers")
|
||||
print(f" Model: {args.model_name}")
|
||||
|
||||
# Show sample results
|
||||
print("\n📝 Sample Multimodal Generations:")
|
||||
for i, response in enumerate(gen_results["results"][:3]):
|
||||
# Handle both string and dict formats for captions
|
||||
if isinstance(test_captions[i], dict):
|
||||
caption_text = test_captions[i].get("query", str(test_captions[i]))
|
||||
else:
|
||||
caption_text = str(test_captions[i])
|
||||
print(f" Query {i + 1}: {caption_text[:60]}...")
|
||||
print(f" Response {i + 1}: {response[:100]}...")
|
||||
print()
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Multimodal generation evaluation failed: {e}")
|
||||
print("💡 Make sure transformers and Qwen2.5-VL are installed")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
evaluator.cleanup()
|
||||
print("✅ Stage 5 completed!\n")
|
||||
|
||||
if args.stage == "all":
|
||||
print("🎉 All evaluation stages completed successfully!")
|
||||
print("\n📋 Summary:")
|
||||
print(" Stage 2: ✅ Multimodal Recall@3 evaluation completed")
|
||||
print(" Stage 3: ✅ Optimal complexity found")
|
||||
print(" Stage 4: ✅ Index comparison analysis completed")
|
||||
print(" Stage 5: ✅ Multimodal generation evaluation completed")
|
||||
print("\n🔧 Recommended next steps:")
|
||||
print(" - Use optimal complexity for best speed/accuracy balance")
|
||||
print(" - Review index comparison for storage vs performance tradeoffs")
|
||||
|
||||
# Clean up non-compact index after all stages complete
|
||||
print("\n🧹 Cleaning up temporary non-compact index...")
|
||||
if Path(non_compact_index_path).exists():
|
||||
temp_index_dir = Path(non_compact_index_path).parent
|
||||
temp_index_name = Path(non_compact_index_path).name
|
||||
for temp_file in temp_index_dir.glob(f"{temp_index_name}*"):
|
||||
temp_file.unlink()
|
||||
print(f"✅ Cleaned up {non_compact_index_path}")
|
||||
else:
|
||||
print("📝 No temporary index to clean up")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n⚠️ Evaluation interrupted by user")
|
||||
exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n❌ Stage {args.stage} failed: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
576
benchmarks/laion/setup_laion.py
Normal file
@@ -0,0 +1,576 @@
|
||||
"""
|
||||
LAION Multimodal Benchmark Setup Script
|
||||
Downloads LAION subset and builds LEANN index with sentence embeddings
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
from leann import LeannBuilder
|
||||
from PIL import Image
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class LAIONSetup:
|
||||
def __init__(self, data_dir: str = "data"):
|
||||
self.data_dir = Path(data_dir)
|
||||
self.images_dir = self.data_dir / "laion_images"
|
||||
self.metadata_file = self.data_dir / "laion_metadata.jsonl"
|
||||
|
||||
# Create directories
|
||||
self.data_dir.mkdir(exist_ok=True)
|
||||
self.images_dir.mkdir(exist_ok=True)
|
||||
|
||||
async def download_single_image(self, session, sample_data, semaphore, progress_bar):
|
||||
"""Download a single image asynchronously"""
|
||||
async with semaphore: # Limit concurrent downloads
|
||||
try:
|
||||
image_url = sample_data["url"]
|
||||
image_path = sample_data["image_path"]
|
||||
|
||||
# Skip if already exists
|
||||
if os.path.exists(image_path):
|
||||
progress_bar.update(1)
|
||||
return sample_data
|
||||
|
||||
async with session.get(image_url, timeout=10) as response:
|
||||
if response.status == 200:
|
||||
content = await response.read()
|
||||
|
||||
# Verify it's a valid image
|
||||
try:
|
||||
img = Image.open(io.BytesIO(content))
|
||||
img = img.convert("RGB")
|
||||
img.save(image_path, "JPEG")
|
||||
progress_bar.update(1)
|
||||
return sample_data
|
||||
except Exception:
|
||||
progress_bar.update(1)
|
||||
return None # Skip invalid images
|
||||
else:
|
||||
progress_bar.update(1)
|
||||
return None
|
||||
|
||||
except Exception:
|
||||
progress_bar.update(1)
|
||||
return None
|
||||
|
||||
def download_laion_subset(self, num_samples: int = 1000):
|
||||
"""Download LAION subset from HuggingFace datasets with async parallel downloading"""
|
||||
print(f"📥 Downloading LAION subset ({num_samples} samples)...")
|
||||
|
||||
# Load LAION-400M subset from HuggingFace
|
||||
print("🤗 Loading from HuggingFace datasets...")
|
||||
dataset = load_dataset("laion/laion400m", split="train", streaming=True)
|
||||
|
||||
# Collect sample metadata first (fast)
|
||||
print("📋 Collecting sample metadata...")
|
||||
candidates = []
|
||||
for sample in dataset:
|
||||
if len(candidates) >= num_samples * 3: # Get 3x more candidates in case some fail
|
||||
break
|
||||
|
||||
image_url = sample.get("url", "")
|
||||
caption = sample.get("caption", "")
|
||||
|
||||
if not image_url or not caption:
|
||||
continue
|
||||
|
||||
image_filename = f"laion_{len(candidates):06d}.jpg"
|
||||
image_path = self.images_dir / image_filename
|
||||
|
||||
candidate = {
|
||||
"id": f"laion_{len(candidates):06d}",
|
||||
"url": image_url,
|
||||
"caption": caption,
|
||||
"image_path": str(image_path),
|
||||
"width": sample.get("original_width", 512),
|
||||
"height": sample.get("original_height", 512),
|
||||
"similarity": sample.get("similarity", 0.0),
|
||||
}
|
||||
candidates.append(candidate)
|
||||
|
||||
print(
|
||||
f"📊 Collected {len(candidates)} candidates, downloading {num_samples} in parallel..."
|
||||
)
|
||||
|
||||
# Download images in parallel
|
||||
async def download_batch():
|
||||
semaphore = asyncio.Semaphore(20) # Limit to 20 concurrent downloads
|
||||
connector = aiohttp.TCPConnector(limit=100, limit_per_host=20)
|
||||
timeout = aiohttp.ClientTimeout(total=30)
|
||||
|
||||
progress_bar = tqdm(total=len(candidates[: num_samples * 2]), desc="Downloading images")
|
||||
|
||||
async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
|
||||
tasks = []
|
||||
for candidate in candidates[: num_samples * 2]: # Try 2x more than needed
|
||||
task = self.download_single_image(session, candidate, semaphore, progress_bar)
|
||||
tasks.append(task)
|
||||
|
||||
# Wait for all downloads
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
progress_bar.close()
|
||||
|
||||
# Filter successful downloads
|
||||
successful = [r for r in results if r is not None and not isinstance(r, Exception)]
|
||||
return successful[:num_samples]
|
||||
|
||||
# Run async download
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
samples = loop.run_until_complete(download_batch())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
# Save metadata
|
||||
with open(self.metadata_file, "w", encoding="utf-8") as f:
|
||||
for sample in samples:
|
||||
f.write(json.dumps(sample) + "\n")
|
||||
|
||||
print(f"✅ Downloaded {len(samples)} real LAION samples with async parallel downloading")
|
||||
return samples
|
||||
|
||||
def generate_clip_image_embeddings(self, samples: list[dict]):
|
||||
"""Generate CLIP image embeddings for downloaded images"""
|
||||
print("🔍 Generating CLIP image embeddings...")
|
||||
|
||||
# Load sentence-transformers CLIP (ViT-L/14, 768-dim) for image embeddings
|
||||
# This single model can encode both images and text.
|
||||
model = SentenceTransformer("clip-ViT-L-14")
|
||||
|
||||
embeddings = []
|
||||
valid_samples = []
|
||||
|
||||
for sample in tqdm(samples, desc="Processing images"):
|
||||
try:
|
||||
# Load image
|
||||
image_path = sample["image_path"]
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
|
||||
# Encode image to 768-dim embedding via sentence-transformers (normalized)
|
||||
vec = model.encode(
|
||||
[image],
|
||||
convert_to_numpy=True,
|
||||
normalize_embeddings=True,
|
||||
batch_size=1,
|
||||
show_progress_bar=False,
|
||||
)[0]
|
||||
embeddings.append(vec.astype(np.float32))
|
||||
valid_samples.append(sample)
|
||||
|
||||
except Exception as e:
|
||||
print(f" ⚠️ Failed to process {sample['id']}: {e}")
|
||||
# Skip invalid images
|
||||
|
||||
embeddings = np.array(embeddings, dtype=np.float32)
|
||||
|
||||
# Save embeddings
|
||||
embeddings_file = self.data_dir / "clip_image_embeddings.npy"
|
||||
np.save(embeddings_file, embeddings)
|
||||
print(f"✅ Generated {len(embeddings)} image embeddings, shape: {embeddings.shape}")
|
||||
|
||||
return embeddings, valid_samples
|
||||
|
||||
def build_faiss_baseline(
|
||||
self, embeddings: np.ndarray, samples: list[dict], output_dir: str = "baseline"
|
||||
):
|
||||
"""Build FAISS flat baseline using CLIP image embeddings"""
|
||||
print("🔨 Building FAISS Flat baseline...")
|
||||
|
||||
from leann_backend_hnsw import faiss
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
baseline_path = os.path.join(output_dir, "faiss_flat.index")
|
||||
metadata_path = os.path.join(output_dir, "metadata.pkl")
|
||||
|
||||
if os.path.exists(baseline_path) and os.path.exists(metadata_path):
|
||||
print(f"✅ Baseline already exists at {baseline_path}")
|
||||
return baseline_path
|
||||
|
||||
# Extract image IDs (must be present)
|
||||
if not samples or "id" not in samples[0]:
|
||||
raise KeyError("samples missing 'id' field for FAISS baseline")
|
||||
image_ids: list[str] = [str(sample["id"]) for sample in samples]
|
||||
|
||||
print(f"📐 Embedding shape: {embeddings.shape}")
|
||||
print(f"📄 Processing {len(image_ids)} images")
|
||||
|
||||
# Build FAISS flat index
|
||||
print("🏗️ Building FAISS IndexFlatIP...")
|
||||
dimension = embeddings.shape[1]
|
||||
index = faiss.IndexFlatIP(dimension)
|
||||
|
||||
# Add embeddings to flat index
|
||||
embeddings_f32 = embeddings.astype(np.float32)
|
||||
index.add(embeddings_f32.shape[0], faiss.swig_ptr(embeddings_f32))
|
||||
|
||||
# Save index and metadata
|
||||
faiss.write_index(index, baseline_path)
|
||||
with open(metadata_path, "wb") as f:
|
||||
pickle.dump(image_ids, f)
|
||||
|
||||
print(f"✅ FAISS baseline saved to {baseline_path}")
|
||||
print(f"✅ Metadata saved to {metadata_path}")
|
||||
print(f"📊 Total vectors: {index.ntotal}")
|
||||
|
||||
return baseline_path
|
||||
|
||||
def create_leann_passages(self, samples: list[dict]):
|
||||
"""Create LEANN-compatible passages from LAION data"""
|
||||
print("📝 Creating LEANN passages...")
|
||||
|
||||
passages_file = self.data_dir / "laion_passages.jsonl"
|
||||
|
||||
with open(passages_file, "w", encoding="utf-8") as f:
|
||||
for i, sample in enumerate(samples):
|
||||
passage = {
|
||||
"id": sample["id"],
|
||||
"text": sample["caption"], # Use caption as searchable text
|
||||
"metadata": {
|
||||
"image_url": sample["url"],
|
||||
"image_path": sample.get("image_path", ""),
|
||||
"width": sample["width"],
|
||||
"height": sample["height"],
|
||||
"similarity": sample["similarity"],
|
||||
"image_index": i, # Index for embedding lookup
|
||||
},
|
||||
}
|
||||
f.write(json.dumps(passage) + "\n")
|
||||
|
||||
print(f"✅ Created {len(samples)} passages")
|
||||
return passages_file
|
||||
|
||||
def build_compact_index(
|
||||
self, passages_file: Path, embeddings: np.ndarray, index_path: str, backend: str = "hnsw"
|
||||
):
|
||||
"""Build compact LEANN index with CLIP embeddings (recompute=True, compact=True)"""
|
||||
print(f"🏗️ Building compact LEANN index with {backend} backend...")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Save CLIP embeddings (npy) and also a pickle with (ids, embeddings)
|
||||
npy_path = self.data_dir / "clip_image_embeddings.npy"
|
||||
np.save(npy_path, embeddings)
|
||||
print(f"💾 Saved CLIP embeddings to {npy_path}")
|
||||
|
||||
# Prepare ids in the same order as passages_file (matches embeddings order)
|
||||
ids: list[str] = []
|
||||
with open(passages_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
rec = json.loads(line)
|
||||
ids.append(str(rec["id"]))
|
||||
|
||||
if len(ids) != embeddings.shape[0]:
|
||||
raise ValueError(
|
||||
f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
|
||||
)
|
||||
|
||||
pkl_path = self.data_dir / "clip_image_embeddings.pkl"
|
||||
with open(pkl_path, "wb") as pf:
|
||||
pickle.dump((ids, embeddings.astype(np.float32)), pf)
|
||||
print(f"💾 Saved (ids, embeddings) pickle to {pkl_path}")
|
||||
|
||||
# Initialize builder - compact with recompute
|
||||
# Note: For multimodal case, we need to handle embeddings differently
|
||||
# Let's try using sentence-transformers mode but with custom embeddings
|
||||
builder = LeannBuilder(
|
||||
backend_name=backend,
|
||||
# Use CLIP text encoder (ViT-L/14) to match image space (768-dim)
|
||||
embedding_model="clip-ViT-L-14",
|
||||
embedding_mode="sentence-transformers",
|
||||
# HNSW params (or forwarded to chosen backend)
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
# Compact/pruned with recompute at query time
|
||||
is_recompute=True,
|
||||
is_compact=True,
|
||||
distance_metric="cosine", # CLIP uses normalized vectors; cosine is appropriate
|
||||
num_threads=4,
|
||||
)
|
||||
|
||||
# Add passages (text + metadata)
|
||||
print("📚 Adding passages...")
|
||||
self._add_passages_with_embeddings(builder, passages_file, embeddings)
|
||||
|
||||
print(f"🔨 Building compact index at {index_path} from precomputed embeddings...")
|
||||
builder.build_index_from_embeddings(index_path, str(pkl_path))
|
||||
|
||||
build_time = time.time() - start_time
|
||||
print(f"✅ Compact index built in {build_time:.2f}s")
|
||||
|
||||
# Analyze index size
|
||||
self._analyze_index_size(index_path)
|
||||
|
||||
return index_path
|
||||
|
||||
def build_non_compact_index(
|
||||
self, passages_file: Path, embeddings: np.ndarray, index_path: str, backend: str = "hnsw"
|
||||
):
|
||||
"""Build non-compact LEANN index with CLIP embeddings (recompute=False, compact=False)"""
|
||||
print(f"🏗️ Building non-compact LEANN index with {backend} backend...")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Ensure embeddings are saved (npy + pickle)
|
||||
npy_path = self.data_dir / "clip_image_embeddings.npy"
|
||||
if not npy_path.exists():
|
||||
np.save(npy_path, embeddings)
|
||||
print(f"💾 Saved CLIP embeddings to {npy_path}")
|
||||
# Prepare ids in same order as passages_file
|
||||
ids: list[str] = []
|
||||
with open(passages_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
rec = json.loads(line)
|
||||
ids.append(str(rec["id"]))
|
||||
if len(ids) != embeddings.shape[0]:
|
||||
raise ValueError(
|
||||
f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
|
||||
)
|
||||
pkl_path = self.data_dir / "clip_image_embeddings.pkl"
|
||||
if not pkl_path.exists():
|
||||
with open(pkl_path, "wb") as pf:
|
||||
pickle.dump((ids, embeddings.astype(np.float32)), pf)
|
||||
print(f"💾 Saved (ids, embeddings) pickle to {pkl_path}")
|
||||
|
||||
# Initialize builder - non-compact without recompute
|
||||
builder = LeannBuilder(
|
||||
backend_name=backend,
|
||||
embedding_model="clip-ViT-L-14",
|
||||
embedding_mode="sentence-transformers",
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_recompute=False, # Store embeddings (no recompute needed)
|
||||
is_compact=False, # Store full index (not pruned)
|
||||
distance_metric="cosine",
|
||||
num_threads=4,
|
||||
)
|
||||
|
||||
# Add passages - embeddings will be loaded from file
|
||||
print("📚 Adding passages...")
|
||||
self._add_passages_with_embeddings(builder, passages_file, embeddings)
|
||||
|
||||
print(f"🔨 Building non-compact index at {index_path} from precomputed embeddings...")
|
||||
builder.build_index_from_embeddings(index_path, str(pkl_path))
|
||||
|
||||
build_time = time.time() - start_time
|
||||
print(f"✅ Non-compact index built in {build_time:.2f}s")
|
||||
|
||||
# Analyze index size
|
||||
self._analyze_index_size(index_path)
|
||||
|
||||
return index_path
|
||||
|
||||
def _add_passages_with_embeddings(self, builder, passages_file: Path, embeddings: np.ndarray):
|
||||
"""Helper to add passages with pre-computed CLIP embeddings"""
|
||||
with open(passages_file, encoding="utf-8") as f:
|
||||
for line in tqdm(f, desc="Adding passages"):
|
||||
if line.strip():
|
||||
passage = json.loads(line)
|
||||
|
||||
# Add image metadata - LEANN will handle embeddings separately
|
||||
# Note: We store image metadata and caption text for searchability
|
||||
# Important: ensure passage ID in metadata matches vector ID
|
||||
builder.add_text(
|
||||
text=passage["text"], # Image caption for searchability
|
||||
metadata={**passage["metadata"], "id": passage["id"]},
|
||||
)
|
||||
|
||||
def _analyze_index_size(self, index_path: str):
|
||||
"""Analyze index file sizes"""
|
||||
print("📏 Analyzing index sizes...")
|
||||
|
||||
index_path = Path(index_path)
|
||||
index_dir = index_path.parent
|
||||
index_name = index_path.name # e.g., laion_index.leann
|
||||
index_prefix = index_path.stem # e.g., laion_index
|
||||
|
||||
files = [
|
||||
(f"{index_prefix}.index", ".index", "core"),
|
||||
(f"{index_name}.meta.json", ".meta.json", "core"),
|
||||
(f"{index_name}.ids.txt", ".ids.txt", "core"),
|
||||
(f"{index_name}.passages.jsonl", ".passages.jsonl", "passages"),
|
||||
(f"{index_name}.passages.idx", ".passages.idx", "passages"),
|
||||
]
|
||||
|
||||
def _fmt_size(bytes_val: int) -> str:
|
||||
if bytes_val < 1024:
|
||||
return f"{bytes_val} B"
|
||||
kb = bytes_val / 1024
|
||||
if kb < 1024:
|
||||
return f"{kb:.1f} KB"
|
||||
mb = kb / 1024
|
||||
if mb < 1024:
|
||||
return f"{mb:.2f} MB"
|
||||
gb = mb / 1024
|
||||
return f"{gb:.2f} GB"
|
||||
|
||||
total_index_only_mb = 0.0
|
||||
total_all_mb = 0.0
|
||||
for filename, label, group in files:
|
||||
file_path = index_dir / filename
|
||||
if file_path.exists():
|
||||
size_bytes = file_path.stat().st_size
|
||||
print(f" {label}: {_fmt_size(size_bytes)}")
|
||||
size_mb = size_bytes / (1024 * 1024)
|
||||
total_all_mb += size_mb
|
||||
if group == "core":
|
||||
total_index_only_mb += size_mb
|
||||
else:
|
||||
print(f" {label}: (missing)")
|
||||
print(f" Total (index only, exclude passages): {total_index_only_mb:.2f} MB")
|
||||
print(f" Total (including passages): {total_all_mb:.2f} MB")
|
||||
|
||||
def create_evaluation_queries(self, samples: list[dict], num_queries: int = 200):
|
||||
"""Create evaluation queries from captions"""
|
||||
print(f"📝 Creating {num_queries} evaluation queries...")
|
||||
|
||||
# Sample random captions as queries
|
||||
import random
|
||||
|
||||
random.seed(42) # For reproducibility
|
||||
|
||||
query_samples = random.sample(samples, min(num_queries, len(samples)))
|
||||
|
||||
queries_file = self.data_dir / "evaluation_queries.jsonl"
|
||||
with open(queries_file, "w", encoding="utf-8") as f:
|
||||
for sample in query_samples:
|
||||
query = {
|
||||
"id": sample["id"],
|
||||
"query": sample["caption"],
|
||||
"ground_truth_id": sample["id"], # For potential recall evaluation
|
||||
}
|
||||
f.write(json.dumps(query) + "\n")
|
||||
|
||||
print(f"✅ Created {len(query_samples)} evaluation queries")
|
||||
return queries_file
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Setup LAION Multimodal Benchmark")
|
||||
parser.add_argument("--data-dir", default="data", help="Data directory")
|
||||
parser.add_argument("--num-samples", type=int, default=1000, help="Number of LAION samples")
|
||||
parser.add_argument("--num-queries", type=int, default=50, help="Number of evaluation queries")
|
||||
parser.add_argument("--index-path", default="data/laion_index.leann", help="Output index path")
|
||||
parser.add_argument(
|
||||
"--backend", default="hnsw", choices=["hnsw", "diskann"], help="LEANN backend"
|
||||
)
|
||||
parser.add_argument("--skip-download", action="store_true", help="Skip LAION dataset download")
|
||||
parser.add_argument("--skip-build", action="store_true", help="Skip index building")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print("🚀 Setting up LAION Multimodal Benchmark")
|
||||
print("=" * 50)
|
||||
|
||||
try:
|
||||
# Initialize setup
|
||||
setup = LAIONSetup(args.data_dir)
|
||||
|
||||
# Step 1: Download LAION subset
|
||||
if not args.skip_download:
|
||||
print("\n📦 Step 1: Download LAION subset")
|
||||
samples = setup.download_laion_subset(args.num_samples)
|
||||
|
||||
# Step 2: Generate CLIP image embeddings
|
||||
print("\n🔍 Step 2: Generate CLIP image embeddings")
|
||||
embeddings, valid_samples = setup.generate_clip_image_embeddings(samples)
|
||||
|
||||
# Step 3: Create LEANN passages (image metadata with embeddings)
|
||||
print("\n📝 Step 3: Create LEANN passages")
|
||||
passages_file = setup.create_leann_passages(valid_samples)
|
||||
else:
|
||||
print("⏭️ Skipping LAION dataset download")
|
||||
# Load existing data
|
||||
passages_file = setup.data_dir / "laion_passages.jsonl"
|
||||
embeddings_file = setup.data_dir / "clip_image_embeddings.npy"
|
||||
|
||||
if not passages_file.exists() or not embeddings_file.exists():
|
||||
raise FileNotFoundError(
|
||||
"Passages or embeddings file not found. Run without --skip-download first."
|
||||
)
|
||||
|
||||
embeddings = np.load(embeddings_file)
|
||||
print(f"📊 Loaded {len(embeddings)} embeddings from {embeddings_file}")
|
||||
|
||||
# Step 4: Build LEANN indexes (both compact and non-compact)
|
||||
if not args.skip_build:
|
||||
print("\n🏗️ Step 4: Build LEANN indexes with CLIP image embeddings")
|
||||
|
||||
# Build compact index (production mode - small, recompute required)
|
||||
compact_index_path = args.index_path
|
||||
print(f"Building compact index: {compact_index_path}")
|
||||
setup.build_compact_index(passages_file, embeddings, compact_index_path, args.backend)
|
||||
|
||||
# Build non-compact index (comparison mode - large, fast search)
|
||||
non_compact_index_path = args.index_path.replace(".leann", "_noncompact.leann")
|
||||
print(f"Building non-compact index: {non_compact_index_path}")
|
||||
setup.build_non_compact_index(
|
||||
passages_file, embeddings, non_compact_index_path, args.backend
|
||||
)
|
||||
|
||||
# Step 5: Build FAISS flat baseline
|
||||
print("\n🔨 Step 5: Build FAISS flat baseline")
|
||||
if not args.skip_download:
|
||||
baseline_path = setup.build_faiss_baseline(embeddings, valid_samples)
|
||||
else:
|
||||
# Load valid_samples from passages file for FAISS baseline
|
||||
valid_samples = []
|
||||
with open(passages_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
passage = json.loads(line)
|
||||
valid_samples.append({"id": passage["id"], "caption": passage["text"]})
|
||||
baseline_path = setup.build_faiss_baseline(embeddings, valid_samples)
|
||||
|
||||
# Step 6: Create evaluation queries
|
||||
print("\n📝 Step 6: Create evaluation queries")
|
||||
queries_file = setup.create_evaluation_queries(valid_samples, args.num_queries)
|
||||
else:
|
||||
print("⏭️ Skipping index building")
|
||||
baseline_path = "data/baseline/faiss_index.bin"
|
||||
queries_file = setup.data_dir / "evaluation_queries.jsonl"
|
||||
|
||||
print("\n🎉 Setup completed successfully!")
|
||||
print("📊 Summary:")
|
||||
if not args.skip_download:
|
||||
print(f" Downloaded samples: {len(samples)}")
|
||||
print(f" Valid samples with embeddings: {len(valid_samples)}")
|
||||
else:
|
||||
print(f" Loaded {len(embeddings)} embeddings")
|
||||
|
||||
if not args.skip_build:
|
||||
print(f" Compact index: {compact_index_path}")
|
||||
print(f" Non-compact index: {non_compact_index_path}")
|
||||
print(f" FAISS baseline: {baseline_path}")
|
||||
print(f" Queries: {queries_file}")
|
||||
|
||||
print("\n🔧 Next steps:")
|
||||
print(f" Run evaluation: python evaluate_laion.py --index {compact_index_path}")
|
||||
print(f" Or compare with: python evaluate_laion.py --index {non_compact_index_path}")
|
||||
else:
|
||||
print(" Skipped building indexes")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n⚠️ Setup interrupted by user")
|
||||
exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n❌ Setup failed: {e}")
|
||||
exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
342
benchmarks/llm_utils.py
Normal file
@@ -0,0 +1,342 @@
|
||||
"""
|
||||
LLM utils for RAG benchmarks with Qwen3-8B and Qwen2.5-VL (multimodal)
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
try:
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
HF_AVAILABLE = True
|
||||
except ImportError:
|
||||
HF_AVAILABLE = False
|
||||
|
||||
try:
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
VLLM_AVAILABLE = True
|
||||
except ImportError:
|
||||
VLLM_AVAILABLE = False
|
||||
|
||||
|
||||
def is_qwen3_model(model_name):
|
||||
"""Check if model is Qwen3"""
|
||||
return "Qwen3" in model_name or "qwen3" in model_name.lower()
|
||||
|
||||
|
||||
def is_qwen_vl_model(model_name):
|
||||
"""Check if model is Qwen2.5-VL"""
|
||||
return "Qwen2.5-VL" in model_name or "qwen2.5-vl" in model_name.lower()
|
||||
|
||||
|
||||
def apply_qwen3_chat_template(tokenizer, prompt):
|
||||
"""Apply Qwen3 chat template with thinking enabled"""
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
return tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=True,
|
||||
)
|
||||
|
||||
|
||||
def extract_thinking_answer(response):
|
||||
"""Extract final answer from Qwen3 thinking model response"""
|
||||
if "<think>" in response and "</think>" in response:
|
||||
try:
|
||||
think_end = response.index("</think>") + len("</think>")
|
||||
final_answer = response[think_end:].strip()
|
||||
return final_answer
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
|
||||
return response.strip()
|
||||
|
||||
|
||||
def load_hf_model(model_name="Qwen/Qwen3-8B", trust_remote_code=False):
|
||||
"""Load HuggingFace model
|
||||
|
||||
Args:
|
||||
model_name (str): Name of the model to load
|
||||
trust_remote_code (bool): Whether to allow execution of code from the model repository.
|
||||
Defaults to False for security. Only enable for trusted models.
|
||||
"""
|
||||
if not HF_AVAILABLE:
|
||||
raise ImportError("transformers not available")
|
||||
|
||||
if trust_remote_code:
|
||||
print(
|
||||
"⚠️ WARNING: Loading model with trust_remote_code=True. This can execute arbitrary code."
|
||||
)
|
||||
|
||||
print(f"Loading HF: {model_name}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
||||
device_map="auto",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
return tokenizer, model
|
||||
|
||||
|
||||
def load_vllm_model(model_name="Qwen/Qwen3-8B", trust_remote_code=False):
|
||||
"""Load vLLM model
|
||||
|
||||
Args:
|
||||
model_name (str): Name of the model to load
|
||||
trust_remote_code (bool): Whether to allow execution of code from the model repository.
|
||||
Defaults to False for security. Only enable for trusted models.
|
||||
"""
|
||||
if not VLLM_AVAILABLE:
|
||||
raise ImportError("vllm not available")
|
||||
|
||||
if trust_remote_code:
|
||||
print(
|
||||
"⚠️ WARNING: Loading model with trust_remote_code=True. This can execute arbitrary code."
|
||||
)
|
||||
|
||||
print(f"Loading vLLM: {model_name}")
|
||||
llm = LLM(model=model_name, trust_remote_code=trust_remote_code)
|
||||
|
||||
# Qwen3 specific config
|
||||
if is_qwen3_model(model_name):
|
||||
stop_tokens = ["<|im_end|>", "<|end_of_text|>"]
|
||||
max_tokens = 2048
|
||||
else:
|
||||
stop_tokens = None
|
||||
max_tokens = 1024
|
||||
|
||||
sampling_params = SamplingParams(temperature=0.7, max_tokens=max_tokens, stop=stop_tokens)
|
||||
return llm, sampling_params
|
||||
|
||||
|
||||
def generate_hf(tokenizer, model, prompt, max_tokens=None):
|
||||
"""Generate with HF - supports Qwen3 thinking models"""
|
||||
model_name = getattr(model, "name_or_path", "unknown")
|
||||
is_qwen3 = is_qwen3_model(model_name)
|
||||
|
||||
# Apply chat template for Qwen3
|
||||
if is_qwen3:
|
||||
prompt = apply_qwen3_chat_template(tokenizer, prompt)
|
||||
max_tokens = max_tokens or 2048
|
||||
else:
|
||||
max_tokens = max_tokens or 1024
|
||||
|
||||
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
||||
with torch.no_grad():
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=max_tokens,
|
||||
temperature=0.7,
|
||||
do_sample=True,
|
||||
pad_token_id=tokenizer.eos_token_id,
|
||||
)
|
||||
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
response = response[len(prompt) :].strip()
|
||||
|
||||
# Extract final answer for thinking models
|
||||
if is_qwen3:
|
||||
return extract_thinking_answer(response)
|
||||
return response
|
||||
|
||||
|
||||
def generate_vllm(llm, sampling_params, prompt):
|
||||
"""Generate with vLLM - supports Qwen3 thinking models"""
|
||||
outputs = llm.generate([prompt], sampling_params)
|
||||
response = outputs[0].outputs[0].text.strip()
|
||||
|
||||
# Extract final answer for Qwen3 thinking models
|
||||
model_name = str(llm.llm_engine.model_config.model)
|
||||
if is_qwen3_model(model_name):
|
||||
return extract_thinking_answer(response)
|
||||
return response
|
||||
|
||||
|
||||
def create_prompt(context, query, domain="default"):
|
||||
"""Create RAG prompt"""
|
||||
if domain == "emails":
|
||||
return f"Email content:\n{context}\n\nQuestion: {query}\n\nAnswer:"
|
||||
elif domain == "finance":
|
||||
return f"Financial content:\n{context}\n\nQuestion: {query}\n\nAnswer:"
|
||||
elif domain == "multimodal":
|
||||
return f"Image context:\n{context}\n\nQuestion: {query}\n\nAnswer:"
|
||||
else:
|
||||
return f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
|
||||
|
||||
|
||||
def evaluate_rag(searcher, llm_func, queries, domain="default", top_k=3, complexity=64):
|
||||
"""Simple RAG evaluation with timing"""
|
||||
search_times = []
|
||||
gen_times = []
|
||||
results = []
|
||||
|
||||
for i, query in enumerate(queries):
|
||||
# Search
|
||||
start = time.time()
|
||||
docs = searcher.search(query, top_k=top_k, complexity=complexity)
|
||||
search_time = time.time() - start
|
||||
|
||||
# Generate
|
||||
context = "\n\n".join([doc.text for doc in docs])
|
||||
prompt = create_prompt(context, query, domain)
|
||||
|
||||
start = time.time()
|
||||
response = llm_func(prompt)
|
||||
gen_time = time.time() - start
|
||||
|
||||
search_times.append(search_time)
|
||||
gen_times.append(gen_time)
|
||||
results.append(response)
|
||||
|
||||
if i < 3:
|
||||
print(f"Q{i + 1}: Search={search_time:.3f}s, Gen={gen_time:.3f}s")
|
||||
|
||||
return {
|
||||
"avg_search_time": sum(search_times) / len(search_times),
|
||||
"avg_generation_time": sum(gen_times) / len(gen_times),
|
||||
"results": results,
|
||||
}
|
||||
|
||||
|
||||
def load_qwen_vl_model(model_name="Qwen/Qwen2.5-VL-7B-Instruct", trust_remote_code=False):
|
||||
"""Load Qwen2.5-VL multimodal model
|
||||
|
||||
Args:
|
||||
model_name (str): Name of the model to load
|
||||
trust_remote_code (bool): Whether to allow execution of code from the model repository.
|
||||
Defaults to False for security. Only enable for trusted models.
|
||||
"""
|
||||
if not HF_AVAILABLE:
|
||||
raise ImportError("transformers not available")
|
||||
|
||||
if trust_remote_code:
|
||||
print(
|
||||
"⚠️ WARNING: Loading model with trust_remote_code=True. This can execute arbitrary code."
|
||||
)
|
||||
|
||||
print(f"Loading Qwen2.5-VL: {model_name}")
|
||||
|
||||
try:
|
||||
from transformers import AutoModelForVision2Seq, AutoProcessor
|
||||
|
||||
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=trust_remote_code)
|
||||
model = AutoModelForVision2Seq.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
return processor, model
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to load with AutoModelForVision2Seq, trying specific class: {e}")
|
||||
|
||||
# Fallback to specific class
|
||||
try:
|
||||
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
|
||||
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
model_name, trust_remote_code=trust_remote_code
|
||||
)
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
return processor, model
|
||||
|
||||
except Exception as e2:
|
||||
raise ImportError(f"Failed to load Qwen2.5-VL model: {e2}")
|
||||
|
||||
|
||||
def generate_qwen_vl(processor, model, prompt, image_path=None, max_tokens=512):
|
||||
"""Generate with Qwen2.5-VL multimodal model"""
|
||||
from PIL import Image
|
||||
|
||||
# Prepare inputs
|
||||
if image_path:
|
||||
image = Image.open(image_path)
|
||||
inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
|
||||
else:
|
||||
inputs = processor(text=prompt, return_tensors="pt").to(model.device)
|
||||
|
||||
# Generate
|
||||
with torch.no_grad():
|
||||
generated_ids = model.generate(
|
||||
**inputs, max_new_tokens=max_tokens, do_sample=False, temperature=0.1
|
||||
)
|
||||
|
||||
# Decode response
|
||||
generated_ids = generated_ids[:, inputs["input_ids"].shape[1] :]
|
||||
response = processor.decode(generated_ids[0], skip_special_tokens=True)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def create_multimodal_prompt(context, query, image_descriptions, task_type="images"):
|
||||
"""Create prompt for multimodal RAG"""
|
||||
if task_type == "images":
|
||||
return f"""Based on the retrieved images and their descriptions, answer the following question.
|
||||
|
||||
Retrieved Image Descriptions:
|
||||
{context}
|
||||
|
||||
Question: {query}
|
||||
|
||||
Provide a detailed answer based on the visual content described above."""
|
||||
|
||||
return f"Context: {context}\nQuestion: {query}\nAnswer:"
|
||||
|
||||
|
||||
def evaluate_multimodal_rag(searcher, queries, processor=None, model=None, complexity=64):
|
||||
"""Evaluate multimodal RAG with Qwen2.5-VL"""
|
||||
search_times = []
|
||||
gen_times = []
|
||||
results = []
|
||||
|
||||
for i, query_item in enumerate(queries):
|
||||
# Handle both string and dict formats for queries
|
||||
if isinstance(query_item, dict):
|
||||
query = query_item.get("query", "")
|
||||
image_path = query_item.get("image_path") # Optional reference image
|
||||
else:
|
||||
query = str(query_item)
|
||||
image_path = None
|
||||
|
||||
# Search
|
||||
start_time = time.time()
|
||||
search_results = searcher.search(query, top_k=3, complexity=complexity)
|
||||
search_time = time.time() - start_time
|
||||
search_times.append(search_time)
|
||||
|
||||
# Prepare context from search results
|
||||
context_parts = []
|
||||
for result in search_results:
|
||||
context_parts.append(f"- {result.text}")
|
||||
context = "\n".join(context_parts)
|
||||
|
||||
# Generate with multimodal model
|
||||
start_time = time.time()
|
||||
if processor and model:
|
||||
prompt = create_multimodal_prompt(context, query, context_parts)
|
||||
response = generate_qwen_vl(processor, model, prompt, image_path)
|
||||
else:
|
||||
response = f"Context: {context}"
|
||||
gen_time = time.time() - start_time
|
||||
|
||||
gen_times.append(gen_time)
|
||||
results.append(response)
|
||||
|
||||
if i < 3:
|
||||
print(f"Q{i + 1}: Search={search_time:.3f}s, Gen={gen_time:.3f}s")
|
||||
|
||||
return {
|
||||
"avg_search_time": sum(search_times) / len(search_times),
|
||||
"avg_generation_time": sum(gen_times) / len(gen_times),
|
||||
"results": results,
|
||||
}
|
||||
@@ -12,7 +12,7 @@ import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from leann.api import LeannBuilder, LeannSearcher
|
||||
from leann.api import LeannBuilder, LeannChat, LeannSearcher
|
||||
|
||||
|
||||
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
||||
@@ -53,7 +53,7 @@ def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
||||
print(
|
||||
"Error: huggingface_hub is not installed. Please install it to download the data:"
|
||||
)
|
||||
print("uv pip install -e '.[dev]'")
|
||||
print("uv sync --only-group dev")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"An error occurred during data download: {e}")
|
||||
@@ -197,6 +197,25 @@ def main():
|
||||
parser.add_argument(
|
||||
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Batch size for HNSW batched search (0 disables batching)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llm-type",
|
||||
type=str,
|
||||
choices=["ollama", "hf", "openai", "gemini", "simulated"],
|
||||
default="ollama",
|
||||
help="LLM backend type to optionally query during evaluation (default: ollama)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llm-model",
|
||||
type=str,
|
||||
default="qwen3:1.7b",
|
||||
help="LLM model identifier for the chosen backend (default: qwen3:1.7b)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# --- Path Configuration ---
|
||||
@@ -318,9 +337,24 @@ def main():
|
||||
|
||||
for i in range(num_eval_queries):
|
||||
start_time = time.time()
|
||||
new_results = searcher.search(queries[i], top_k=args.top_k, ef=args.ef_search)
|
||||
new_results = searcher.search(
|
||||
queries[i],
|
||||
top_k=args.top_k,
|
||||
complexity=args.ef_search,
|
||||
batch_size=args.batch_size,
|
||||
)
|
||||
search_times.append(time.time() - start_time)
|
||||
|
||||
# Optional: also call the LLM with configurable backend/model (does not affect recall)
|
||||
llm_config = {"type": args.llm_type, "model": args.llm_model}
|
||||
chat = LeannChat(args.index_path, llm_config=llm_config, searcher=searcher)
|
||||
answer = chat.ask(
|
||||
queries[i],
|
||||
top_k=args.top_k,
|
||||
complexity=args.ef_search,
|
||||
batch_size=args.batch_size,
|
||||
)
|
||||
print(f"Answer: {answer}")
|
||||
# Correct Recall Calculation: Based on TEXT content
|
||||
new_texts = {result.text for result in new_results}
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ except ImportError:
|
||||
|
||||
@dataclass
|
||||
class BenchmarkConfig:
|
||||
model_path: str = "facebook/contriever"
|
||||
model_path: str = "facebook/contriever-msmarco"
|
||||
batch_sizes: list[int] = None
|
||||
seq_length: int = 256
|
||||
num_runs: int = 5
|
||||
@@ -34,7 +34,7 @@ class BenchmarkConfig:
|
||||
|
||||
def __post_init__(self):
|
||||
if self.batch_sizes is None:
|
||||
self.batch_sizes = [1, 2, 4, 8, 16, 32, 64]
|
||||
self.batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
|
||||
|
||||
|
||||
class MLXBenchmark:
|
||||
@@ -179,10 +179,16 @@ class Benchmark:
|
||||
|
||||
def _run_inference(self, input_ids: torch.Tensor) -> float:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
# print shape of input_ids and attention_mask
|
||||
print(f"input_ids shape: {input_ids.shape}")
|
||||
print(f"attention_mask shape: {attention_mask.shape}")
|
||||
start_time = time.time()
|
||||
with torch.no_grad():
|
||||
self.model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
if torch.backends.mps.is_available():
|
||||
torch.mps.synchronize()
|
||||
end_time = time.time()
|
||||
|
||||
return end_time - start_time
|
||||
|
||||
143
benchmarks/update/README.md
Normal file
@@ -0,0 +1,143 @@
|
||||
# Update Benchmarks
|
||||
|
||||
This directory hosts two benchmark suites that exercise LEANN’s HNSW “update +
|
||||
search” pipeline under different assumptions:
|
||||
|
||||
1. **RNG recompute latency** – measure how random-neighbour pruning and cache
|
||||
settings influence incremental `add()` latency when embeddings are fetched
|
||||
over the ZMQ embedding server.
|
||||
2. **Update strategy comparison** – compare a fully sequential update pipeline
|
||||
against an offline approach that keeps the graph static and fuses results.
|
||||
|
||||
Both suites build a non-compact, `is_recompute=True` index so that new
|
||||
embeddings are pulled from the embedding server. Benchmark outputs are written
|
||||
under `.leann/bench/` by default and appended to CSV files for later plotting.
|
||||
|
||||
## Benchmarks
|
||||
|
||||
### 1. HNSW RNG Recompute Benchmark
|
||||
|
||||
`bench_hnsw_rng_recompute.py` evaluates incremental update latency under four
|
||||
random-neighbour (RNG) configurations. Each scenario uses the same dataset but
|
||||
changes the forward / reverse RNG pruning flags and whether the embedding cache
|
||||
is enabled:
|
||||
|
||||
| Scenario name | Forward RNG | Reverse RNG | ZMQ embedding cache |
|
||||
| ---------------------------------- | ----------- | ----------- | ------------------- |
|
||||
| `baseline` | Enabled | Enabled | Enabled |
|
||||
| `no_cache_baseline` | Enabled | Enabled | **Disabled** |
|
||||
| `disable_forward_rng` | **Disabled**| Enabled | Enabled |
|
||||
| `disable_forward_and_reverse_rng` | **Disabled**| **Disabled**| Enabled |
|
||||
|
||||
For each scenario the script:
|
||||
1. (Re)builds a `is_recompute=True` index and writes it to `.leann/bench/`.
|
||||
2. Starts `leann_backend_hnsw.hnsw_embedding_server` for remote embeddings.
|
||||
3. Appends the requested updates using the scenario’s RNG flags.
|
||||
4. Records total time, latency per passage, ZMQ fetch counts, and stage-level
|
||||
timings before appending a row to the CSV output.
|
||||
|
||||
**Run:**
|
||||
```bash
|
||||
LEANN_HNSW_LOG_PATH=.leann/bench/hnsw_server.log \
|
||||
LEANN_LOG_LEVEL=INFO \
|
||||
uv run -m benchmarks.update.bench_hnsw_rng_recompute \
|
||||
--runs 1 \
|
||||
--index-path .leann/bench/test.leann \
|
||||
--initial-files data/PrideandPrejudice.txt \
|
||||
--update-files data/huawei_pangu.md \
|
||||
--max-initial 300 \
|
||||
--max-updates 1 \
|
||||
--add-timeout 120
|
||||
```
|
||||
|
||||
**Output:**
|
||||
- `benchmarks/update/bench_results.csv` – per-scenario timing statistics
|
||||
(including ms/passage) for each run.
|
||||
- `.leann/bench/hnsw_server.log` – detailed ZMQ/server logs (path controlled by
|
||||
`LEANN_HNSW_LOG_PATH`).
|
||||
_The reference CSVs checked into this branch were generated on a workstation with an NVIDIA RTX 4090 GPU; throughput numbers will differ on other hardware._
|
||||
|
||||
### 2. Sequential vs. Offline Update Benchmark
|
||||
|
||||
`bench_update_vs_offline_search.py` compares two end-to-end strategies on the
|
||||
same dataset:
|
||||
|
||||
- **Scenario A – Sequential Update**
|
||||
- Start an embedding server.
|
||||
- Sequentially call `index.add()`; each call fetches embeddings via ZMQ and
|
||||
mutates the HNSW graph.
|
||||
- After all inserts, run a search on the updated graph.
|
||||
- Metrics recorded: update time (`add_total_s`), post-update search time
|
||||
(`search_time_s`), combined total (`total_time_s`), and per-passage
|
||||
latency.
|
||||
|
||||
- **Scenario B – Offline Embedding + Concurrent Search**
|
||||
- Stop Scenario A’s server and start a fresh embedding server.
|
||||
- Spawn two threads: one generates embeddings for the new passages offline
|
||||
(graph unchanged); the other computes the query embedding and searches the
|
||||
existing graph.
|
||||
- Merge offline similarities with the graph search results to emulate late
|
||||
fusion, then report the merged top‑k preview.
|
||||
- Metrics recorded: embedding time (`emb_time_s`), search time
|
||||
(`search_time_s`), concurrent makespan (`makespan_s`), and scenario total.
|
||||
|
||||
**Run (both scenarios):**
|
||||
```bash
|
||||
uv run -m benchmarks.update.bench_update_vs_offline_search \
|
||||
--index-path .leann/bench/offline_vs_update.leann \
|
||||
--max-initial 300 \
|
||||
--num-updates 1
|
||||
```
|
||||
|
||||
You can pass `--only A` or `--only B` to run a single scenario. The script will
|
||||
print timing summaries to stdout and append the results to CSV.
|
||||
|
||||
**Output:**
|
||||
- `benchmarks/update/offline_vs_update.csv` – per-scenario timing statistics for
|
||||
Scenario A and B.
|
||||
- Console output includes Scenario B’s merged top‑k preview for quick sanity
|
||||
checks.
|
||||
_The sample results committed here come from runs on an RTX 4090-equipped machine; expect variations if you benchmark on different GPUs._
|
||||
|
||||
### 3. Visualisation
|
||||
|
||||
`plot_bench_results.py` combines the RNG benchmark and the update strategy
|
||||
benchmark into a single two-panel plot.
|
||||
|
||||
**Run:**
|
||||
```bash
|
||||
uv run -m benchmarks.update.plot_bench_results \
|
||||
--csv benchmarks/update/bench_results.csv \
|
||||
--csv-right benchmarks/update/offline_vs_update.csv \
|
||||
--out benchmarks/update/bench_latency_from_csv.png
|
||||
```
|
||||
|
||||
**Options:**
|
||||
- `--broken-y` – Enable a broken Y-axis (default: true when appropriate).
|
||||
- `--csv` – RNG benchmark results CSV (left panel).
|
||||
- `--csv-right` – Update strategy results CSV (right panel).
|
||||
- `--out` – Output image path (PNG/PDF supported).
|
||||
|
||||
**Output:**
|
||||
- `benchmarks/update/bench_latency_from_csv.png` – visual comparison of the two
|
||||
suites.
|
||||
- `benchmarks/update/bench_latency_from_csv.pdf` – PDF version, suitable for
|
||||
slides/papers.
|
||||
|
||||
## Parameters & Environment
|
||||
|
||||
### Common CLI Flags
|
||||
- `--max-initial` – Number of initial passages used to seed the index.
|
||||
- `--max-updates` / `--num-updates` – Number of passages to treat as updates.
|
||||
- `--index-path` – Base path (without extension) where the LEANN index is stored.
|
||||
- `--runs` – Number of repetitions (RNG benchmark only).
|
||||
|
||||
### Environment Variables
|
||||
- `LEANN_HNSW_LOG_PATH` – File to receive embedding-server logs (optional).
|
||||
- `LEANN_LOG_LEVEL` – Logging verbosity (DEBUG/INFO/WARNING/ERROR).
|
||||
- `CUDA_VISIBLE_DEVICES` – Set to empty string if you want to force CPU
|
||||
execution of the embedding model.
|
||||
|
||||
With these scripts you can easily replicate LEANN’s update benchmarks, compare
|
||||
multiple RNG strategies, and evaluate whether sequential updates or offline
|
||||
fusion better match your latency/accuracy trade-offs.
|
||||
16
benchmarks/update/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Benchmarks for LEANN update workflows."""
|
||||
|
||||
# Expose helper to locate repository root for other modules that need it.
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def find_repo_root() -> Path:
|
||||
"""Return the project root containing pyproject.toml."""
|
||||
current = Path(__file__).resolve()
|
||||
for parent in current.parents:
|
||||
if (parent / "pyproject.toml").exists():
|
||||
return parent
|
||||
return current.parents[1]
|
||||
|
||||
|
||||
__all__ = ["find_repo_root"]
|
||||
804
benchmarks/update/bench_hnsw_rng_recompute.py
Normal file
@@ -0,0 +1,804 @@
|
||||
"""Benchmark incremental HNSW add() under different RNG pruning modes with real
|
||||
embedding recomputation.
|
||||
|
||||
This script clones the structure of ``examples/dynamic_update_no_recompute.py``
|
||||
so that we build a non-compact ``is_recompute=True`` index, spin up the
|
||||
standard HNSW embedding server, and measure how long incremental ``add`` takes
|
||||
when RNG pruning is fully enabled vs. partially/fully disabled.
|
||||
|
||||
Example usage (run from the repo root; downloads the model on first run)::
|
||||
|
||||
uv run -m benchmarks.update.bench_hnsw_rng_recompute \
|
||||
--index-path .leann/bench/leann-demo.leann \
|
||||
--runs 1
|
||||
|
||||
You can tweak the input documents with ``--initial-files`` / ``--update-files``
|
||||
if you want a larger or different workload, and change the embedding model via
|
||||
``--model-name``.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import msgpack
|
||||
import numpy as np
|
||||
import zmq
|
||||
from leann.api import LeannBuilder
|
||||
|
||||
if os.environ.get("LEANN_FORCE_CPU", "").lower() in ("1", "true", "yes"):
|
||||
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
|
||||
|
||||
from leann.embedding_compute import compute_embeddings
|
||||
from leann.embedding_server_manager import EmbeddingServerManager
|
||||
from leann.registry import register_project_directory
|
||||
from leann_backend_hnsw import faiss # type: ignore
|
||||
from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
if not logging.getLogger().handlers:
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
def _find_repo_root() -> Path:
|
||||
"""Locate project root by walking up until pyproject.toml is found."""
|
||||
current = Path(__file__).resolve()
|
||||
for parent in current.parents:
|
||||
if (parent / "pyproject.toml").exists():
|
||||
return parent
|
||||
# Fallback: assume repo is two levels up (../..)
|
||||
return current.parents[2]
|
||||
|
||||
|
||||
REPO_ROOT = _find_repo_root()
|
||||
if str(REPO_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(REPO_ROOT))
|
||||
|
||||
from apps.chunking import create_text_chunks # noqa: E402
|
||||
|
||||
DEFAULT_INITIAL_FILES = [
|
||||
REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
|
||||
REPO_ROOT / "data" / "huawei_pangu.md",
|
||||
]
|
||||
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
|
||||
|
||||
DEFAULT_HNSW_LOG = Path(".leann/bench/hnsw_server.log")
|
||||
|
||||
|
||||
def load_chunks_from_files(paths: list[Path], limit: int | None = None) -> list[str]:
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
|
||||
documents = []
|
||||
for path in paths:
|
||||
p = path.expanduser().resolve()
|
||||
if not p.exists():
|
||||
raise FileNotFoundError(f"Input path not found: {p}")
|
||||
if p.is_dir():
|
||||
reader = SimpleDirectoryReader(str(p), recursive=False)
|
||||
documents.extend(reader.load_data(show_progress=True))
|
||||
else:
|
||||
reader = SimpleDirectoryReader(input_files=[str(p)])
|
||||
documents.extend(reader.load_data(show_progress=True))
|
||||
|
||||
if not documents:
|
||||
return []
|
||||
|
||||
chunks = create_text_chunks(
|
||||
documents,
|
||||
chunk_size=512,
|
||||
chunk_overlap=128,
|
||||
use_ast_chunking=False,
|
||||
)
|
||||
cleaned = [c for c in chunks if isinstance(c, str) and c.strip()]
|
||||
if limit is not None:
|
||||
cleaned = cleaned[:limit]
|
||||
return cleaned
|
||||
|
||||
|
||||
def ensure_index_dir(index_path: Path) -> None:
|
||||
index_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def cleanup_index_files(index_path: Path) -> None:
|
||||
parent = index_path.parent
|
||||
if not parent.exists():
|
||||
return
|
||||
stem = index_path.stem
|
||||
for file in parent.glob(f"{stem}*"):
|
||||
if file.is_file():
|
||||
file.unlink()
|
||||
|
||||
|
||||
def build_initial_index(
|
||||
index_path: Path,
|
||||
paragraphs: list[str],
|
||||
model_name: str,
|
||||
embedding_mode: str,
|
||||
distance_metric: str,
|
||||
ef_construction: int,
|
||||
) -> None:
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=model_name,
|
||||
embedding_mode=embedding_mode,
|
||||
is_compact=False,
|
||||
is_recompute=True,
|
||||
distance_metric=distance_metric,
|
||||
backend_kwargs={
|
||||
"distance_metric": distance_metric,
|
||||
"is_compact": False,
|
||||
"is_recompute": True,
|
||||
"efConstruction": ef_construction,
|
||||
},
|
||||
)
|
||||
for idx, passage in enumerate(paragraphs):
|
||||
builder.add_text(passage, metadata={"id": str(idx)})
|
||||
builder.build_index(str(index_path))
|
||||
|
||||
|
||||
def prepare_new_chunks(paragraphs: list[str]) -> list[dict[str, Any]]:
|
||||
return [{"text": text, "metadata": {}} for text in paragraphs]
|
||||
|
||||
|
||||
def benchmark_update_with_mode(
|
||||
index_path: Path,
|
||||
new_chunks: list[dict[str, Any]],
|
||||
model_name: str,
|
||||
embedding_mode: str,
|
||||
distance_metric: str,
|
||||
disable_forward_rng: bool,
|
||||
disable_reverse_rng: bool,
|
||||
server_port: int,
|
||||
add_timeout: int,
|
||||
ef_construction: int,
|
||||
) -> tuple[float, float]:
|
||||
meta_path = index_path.parent / f"{index_path.name}.meta.json"
|
||||
passages_file = index_path.parent / f"{index_path.name}.passages.jsonl"
|
||||
offset_file = index_path.parent / f"{index_path.name}.passages.idx"
|
||||
index_file = index_path.parent / f"{index_path.stem}.index"
|
||||
|
||||
with open(meta_path, encoding="utf-8") as f:
|
||||
meta = json.load(f)
|
||||
|
||||
with open(offset_file, "rb") as f:
|
||||
offset_map: dict[str, int] = pickle.load(f)
|
||||
existing_ids = set(offset_map.keys())
|
||||
|
||||
valid_chunks: list[dict[str, Any]] = []
|
||||
for chunk in new_chunks:
|
||||
text = chunk.get("text", "")
|
||||
if not isinstance(text, str) or not text.strip():
|
||||
continue
|
||||
metadata = chunk.setdefault("metadata", {})
|
||||
passage_id = chunk.get("id") or metadata.get("id")
|
||||
if passage_id and passage_id in existing_ids:
|
||||
raise ValueError(f"Passage ID '{passage_id}' already exists in the index.")
|
||||
valid_chunks.append(chunk)
|
||||
|
||||
if not valid_chunks:
|
||||
raise ValueError("No valid chunks to append.")
|
||||
|
||||
texts_to_embed = [chunk["text"] for chunk in valid_chunks]
|
||||
embeddings = compute_embeddings(
|
||||
texts_to_embed,
|
||||
model_name,
|
||||
mode=embedding_mode,
|
||||
is_build=False,
|
||||
batch_size=16,
|
||||
)
|
||||
|
||||
embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||
if distance_metric == "cosine":
|
||||
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
|
||||
norms[norms == 0] = 1
|
||||
embeddings = embeddings / norms
|
||||
|
||||
index = faiss.read_index(str(index_file))
|
||||
index.is_recompute = True
|
||||
if getattr(index, "storage", None) is None:
|
||||
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
|
||||
storage_index = faiss.IndexFlatIP(index.d)
|
||||
else:
|
||||
storage_index = faiss.IndexFlatL2(index.d)
|
||||
index.storage = storage_index
|
||||
index.own_fields = True
|
||||
try:
|
||||
storage_index.ntotal = index.ntotal
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
index.hnsw.set_disable_rng_during_add(disable_forward_rng)
|
||||
index.hnsw.set_disable_reverse_prune(disable_reverse_rng)
|
||||
if ef_construction is not None:
|
||||
index.hnsw.efConstruction = ef_construction
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
applied_forward = getattr(index.hnsw, "disable_rng_during_add", None)
|
||||
applied_reverse = getattr(index.hnsw, "disable_reverse_prune", None)
|
||||
logger.info(
|
||||
"HNSW RNG config -> requested forward=%s, reverse=%s | applied forward=%s, reverse=%s",
|
||||
disable_forward_rng,
|
||||
disable_reverse_rng,
|
||||
applied_forward,
|
||||
applied_reverse,
|
||||
)
|
||||
|
||||
base_id = index.ntotal
|
||||
for offset, chunk in enumerate(valid_chunks):
|
||||
new_id = str(base_id + offset)
|
||||
chunk.setdefault("metadata", {})["id"] = new_id
|
||||
chunk["id"] = new_id
|
||||
|
||||
rollback_size = passages_file.stat().st_size if passages_file.exists() else 0
|
||||
offset_map_backup = offset_map.copy()
|
||||
|
||||
try:
|
||||
with open(passages_file, "a", encoding="utf-8") as f:
|
||||
for chunk in valid_chunks:
|
||||
offset = f.tell()
|
||||
json.dump(
|
||||
{
|
||||
"id": chunk["id"],
|
||||
"text": chunk["text"],
|
||||
"metadata": chunk.get("metadata", {}),
|
||||
},
|
||||
f,
|
||||
ensure_ascii=False,
|
||||
)
|
||||
f.write("\n")
|
||||
offset_map[chunk["id"]] = offset
|
||||
|
||||
with open(offset_file, "wb") as f:
|
||||
pickle.dump(offset_map, f)
|
||||
|
||||
server_manager = EmbeddingServerManager(
|
||||
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
|
||||
)
|
||||
server_started, actual_port = server_manager.start_server(
|
||||
port=server_port,
|
||||
model_name=model_name,
|
||||
embedding_mode=embedding_mode,
|
||||
passages_file=str(meta_path),
|
||||
distance_metric=distance_metric,
|
||||
)
|
||||
if not server_started:
|
||||
raise RuntimeError("Failed to start embedding server.")
|
||||
|
||||
if hasattr(index.hnsw, "set_zmq_port"):
|
||||
index.hnsw.set_zmq_port(actual_port)
|
||||
elif hasattr(index, "set_zmq_port"):
|
||||
index.set_zmq_port(actual_port)
|
||||
|
||||
_warmup_embedding_server(actual_port)
|
||||
|
||||
total_start = time.time()
|
||||
add_elapsed = 0.0
|
||||
|
||||
try:
|
||||
import signal
|
||||
|
||||
def _timeout_handler(signum, frame):
|
||||
raise TimeoutError("incremental add timed out")
|
||||
|
||||
if add_timeout > 0:
|
||||
signal.signal(signal.SIGALRM, _timeout_handler)
|
||||
signal.alarm(add_timeout)
|
||||
|
||||
add_start = time.time()
|
||||
for i in range(embeddings.shape[0]):
|
||||
index.add(1, faiss.swig_ptr(embeddings[i : i + 1]))
|
||||
add_elapsed = time.time() - add_start
|
||||
if add_timeout > 0:
|
||||
signal.alarm(0)
|
||||
faiss.write_index(index, str(index_file))
|
||||
finally:
|
||||
server_manager.stop_server()
|
||||
|
||||
except TimeoutError:
|
||||
raise
|
||||
except Exception:
|
||||
if passages_file.exists():
|
||||
with open(passages_file, "rb+") as f:
|
||||
f.truncate(rollback_size)
|
||||
with open(offset_file, "wb") as f:
|
||||
pickle.dump(offset_map_backup, f)
|
||||
raise
|
||||
|
||||
prune_hnsw_embeddings_inplace(str(index_file))
|
||||
|
||||
meta["total_passages"] = len(offset_map)
|
||||
with open(meta_path, "w", encoding="utf-8") as f:
|
||||
json.dump(meta, f, indent=2)
|
||||
|
||||
# Reset toggles so the index on disk returns to baseline behaviour.
|
||||
try:
|
||||
index.hnsw.set_disable_rng_during_add(False)
|
||||
index.hnsw.set_disable_reverse_prune(False)
|
||||
except AttributeError:
|
||||
pass
|
||||
faiss.write_index(index, str(index_file))
|
||||
|
||||
total_elapsed = time.time() - total_start
|
||||
|
||||
return total_elapsed, add_elapsed
|
||||
|
||||
|
||||
def _total_zmq_nodes(log_path: Path) -> int:
|
||||
if not log_path.exists():
|
||||
return 0
|
||||
with log_path.open("r", encoding="utf-8") as log_file:
|
||||
text = log_file.read()
|
||||
return sum(int(match) for match in re.findall(r"ZMQ received (\d+) node IDs", text))
|
||||
|
||||
|
||||
def _warmup_embedding_server(port: int) -> None:
|
||||
"""Send a dummy REQ so the embedding server loads its model."""
|
||||
ctx = zmq.Context()
|
||||
try:
|
||||
sock = ctx.socket(zmq.REQ)
|
||||
sock.setsockopt(zmq.LINGER, 0)
|
||||
sock.setsockopt(zmq.RCVTIMEO, 5000)
|
||||
sock.setsockopt(zmq.SNDTIMEO, 5000)
|
||||
sock.connect(f"tcp://127.0.0.1:{port}")
|
||||
payload = msgpack.packb(["__WARMUP__"], use_bin_type=True)
|
||||
sock.send(payload)
|
||||
try:
|
||||
sock.recv()
|
||||
except zmq.error.Again:
|
||||
pass
|
||||
finally:
|
||||
sock.close()
|
||||
ctx.term()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--index-path",
|
||||
type=Path,
|
||||
default=Path(".leann/bench/leann-demo.leann"),
|
||||
help="Output index base path (without extension).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--initial-files",
|
||||
nargs="*",
|
||||
type=Path,
|
||||
default=DEFAULT_INITIAL_FILES,
|
||||
help="Files used to build the initial index.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--update-files",
|
||||
nargs="*",
|
||||
type=Path,
|
||||
default=DEFAULT_UPDATE_FILES,
|
||||
help="Files appended during the benchmark.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--runs", type=int, default=1, help="How many times to repeat each scenario."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
default="sentence-transformers/all-MiniLM-L6-v2",
|
||||
help="Embedding model used for build/update.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-mode",
|
||||
default="sentence-transformers",
|
||||
help="Embedding mode passed to LeannBuilder/embedding server.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--distance-metric",
|
||||
default="mips",
|
||||
choices=["mips", "l2", "cosine"],
|
||||
help="Distance metric for HNSW backend.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ef-construction",
|
||||
type=int,
|
||||
default=200,
|
||||
help="efConstruction setting for initial build.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--server-port",
|
||||
type=int,
|
||||
default=5557,
|
||||
help="Port for the real embedding server.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-initial",
|
||||
type=int,
|
||||
default=300,
|
||||
help="Optional cap on initial passages (after chunking).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-updates",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Optional cap on update passages (after chunking).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--add-timeout",
|
||||
type=int,
|
||||
default=900,
|
||||
help="Timeout in seconds for the incremental add loop (0 = no timeout).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--plot-path",
|
||||
type=Path,
|
||||
default=Path("bench_latency.png"),
|
||||
help="Where to save the latency bar plot.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cap-y",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Cap Y-axis (ms). Bars above are hatched and annotated.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--broken-y",
|
||||
action="store_true",
|
||||
help="Use broken Y-axis (two stacked axes with gap). Overrides --cap-y unless both provided.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lower-cap-y",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Lower axes upper bound for broken Y (ms). Default=1.1x second-highest.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--upper-start-y",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Upper axes lower bound for broken Y (ms). Default=1.2x second-highest.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--csv-path",
|
||||
type=Path,
|
||||
default=Path("benchmarks/update/bench_results.csv"),
|
||||
help="Where to append per-scenario results as CSV.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
register_project_directory(REPO_ROOT)
|
||||
|
||||
initial_paragraphs = load_chunks_from_files(args.initial_files, args.max_initial)
|
||||
update_paragraphs = load_chunks_from_files(args.update_files, args.max_updates)
|
||||
if not update_paragraphs:
|
||||
raise ValueError("No update passages found; please provide --update-files with content.")
|
||||
|
||||
update_chunks = prepare_new_chunks(update_paragraphs)
|
||||
ensure_index_dir(args.index_path)
|
||||
|
||||
scenarios = [
|
||||
("baseline", False, False, True),
|
||||
("no_cache_baseline", False, False, False),
|
||||
("disable_forward_rng", True, False, True),
|
||||
("disable_forward_and_reverse_rng", True, True, True),
|
||||
]
|
||||
|
||||
log_path = Path(os.environ.get("LEANN_HNSW_LOG_PATH", DEFAULT_HNSW_LOG))
|
||||
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
os.environ["LEANN_HNSW_LOG_PATH"] = str(log_path.resolve())
|
||||
os.environ.setdefault("LEANN_LOG_LEVEL", "INFO")
|
||||
|
||||
results_total: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
||||
results_add: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
||||
results_zmq: dict[str, list[int]] = {name: [] for name, *_ in scenarios}
|
||||
results_stageA: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
||||
results_stageBC: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
||||
results_ms_per_passage: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
||||
|
||||
# CSV setup
|
||||
import csv
|
||||
|
||||
run_id = time.strftime("%Y%m%d-%H%M%S")
|
||||
csv_fields = [
|
||||
"run_id",
|
||||
"scenario",
|
||||
"cache_enabled",
|
||||
"ef_construction",
|
||||
"max_initial",
|
||||
"max_updates",
|
||||
"total_time_s",
|
||||
"add_only_s",
|
||||
"latency_ms_per_passage",
|
||||
"zmq_nodes",
|
||||
"stageA_time_s",
|
||||
"stageBC_time_s",
|
||||
"model_name",
|
||||
"embedding_mode",
|
||||
"distance_metric",
|
||||
]
|
||||
# Create CSV with header if missing
|
||||
if args.csv_path:
|
||||
args.csv_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if not args.csv_path.exists() or args.csv_path.stat().st_size == 0:
|
||||
with args.csv_path.open("w", newline="", encoding="utf-8") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||
writer.writeheader()
|
||||
|
||||
for run in range(args.runs):
|
||||
print(f"\n=== Benchmark run {run + 1}/{args.runs} ===")
|
||||
for name, disable_forward, disable_reverse, cache_enabled in scenarios:
|
||||
print(f"\nScenario: {name}")
|
||||
cleanup_index_files(args.index_path)
|
||||
if log_path.exists():
|
||||
try:
|
||||
log_path.unlink()
|
||||
except OSError:
|
||||
pass
|
||||
os.environ["LEANN_ZMQ_EMBED_CACHE"] = "1" if cache_enabled else "0"
|
||||
build_initial_index(
|
||||
args.index_path,
|
||||
initial_paragraphs,
|
||||
args.model_name,
|
||||
args.embedding_mode,
|
||||
args.distance_metric,
|
||||
args.ef_construction,
|
||||
)
|
||||
|
||||
prev_size = log_path.stat().st_size if log_path.exists() else 0
|
||||
|
||||
try:
|
||||
total_elapsed, add_elapsed = benchmark_update_with_mode(
|
||||
args.index_path,
|
||||
update_chunks,
|
||||
args.model_name,
|
||||
args.embedding_mode,
|
||||
args.distance_metric,
|
||||
disable_forward,
|
||||
disable_reverse,
|
||||
args.server_port,
|
||||
args.add_timeout,
|
||||
args.ef_construction,
|
||||
)
|
||||
except TimeoutError as exc:
|
||||
print(f"Scenario {name} timed out: {exc}")
|
||||
continue
|
||||
|
||||
curr_size = log_path.stat().st_size if log_path.exists() else 0
|
||||
if curr_size < prev_size:
|
||||
prev_size = 0
|
||||
zmq_count = 0
|
||||
if log_path.exists():
|
||||
with log_path.open("r", encoding="utf-8") as log_file:
|
||||
log_file.seek(prev_size)
|
||||
new_entries = log_file.read()
|
||||
zmq_count = sum(
|
||||
int(match) for match in re.findall(r"ZMQ received (\d+) node IDs", new_entries)
|
||||
)
|
||||
stageA = sum(
|
||||
float(x)
|
||||
for x in re.findall(r"Distance calculation E2E time: ([0-9.]+)s", new_entries)
|
||||
)
|
||||
stageBC = sum(
|
||||
float(x) for x in re.findall(r"ZMQ E2E time: ([0-9.]+)s", new_entries)
|
||||
)
|
||||
else:
|
||||
stageA = 0.0
|
||||
stageBC = 0.0
|
||||
|
||||
per_chunk = add_elapsed / len(update_chunks)
|
||||
print(
|
||||
f"Total time: {total_elapsed:.3f} s | add-only: {add_elapsed:.3f} s "
|
||||
f"for {len(update_chunks)} passages => {per_chunk * 1e3:.3f} ms/passage"
|
||||
)
|
||||
print(f"ZMQ node fetch total: {zmq_count}")
|
||||
results_total[name].append(total_elapsed)
|
||||
results_add[name].append(add_elapsed)
|
||||
results_zmq[name].append(zmq_count)
|
||||
results_ms_per_passage[name].append(per_chunk * 1e3)
|
||||
results_stageA[name].append(stageA)
|
||||
results_stageBC[name].append(stageBC)
|
||||
|
||||
# Append row to CSV
|
||||
if args.csv_path:
|
||||
row = {
|
||||
"run_id": run_id,
|
||||
"scenario": name,
|
||||
"cache_enabled": 1 if cache_enabled else 0,
|
||||
"ef_construction": args.ef_construction,
|
||||
"max_initial": args.max_initial,
|
||||
"max_updates": args.max_updates,
|
||||
"total_time_s": round(total_elapsed, 6),
|
||||
"add_only_s": round(add_elapsed, 6),
|
||||
"latency_ms_per_passage": round(per_chunk * 1e3, 6),
|
||||
"zmq_nodes": int(zmq_count),
|
||||
"stageA_time_s": round(stageA, 6),
|
||||
"stageBC_time_s": round(stageBC, 6),
|
||||
"model_name": args.model_name,
|
||||
"embedding_mode": args.embedding_mode,
|
||||
"distance_metric": args.distance_metric,
|
||||
}
|
||||
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||
writer.writerow(row)
|
||||
|
||||
print("\n=== Summary ===")
|
||||
for name in results_add:
|
||||
add_values = results_add[name]
|
||||
total_values = results_total[name]
|
||||
zmq_values = results_zmq[name]
|
||||
latency_values = results_ms_per_passage[name]
|
||||
if not add_values:
|
||||
print(f"{name}: no successful runs")
|
||||
continue
|
||||
avg_add = sum(add_values) / len(add_values)
|
||||
avg_total = sum(total_values) / len(total_values)
|
||||
avg_zmq = sum(zmq_values) / len(zmq_values) if zmq_values else 0.0
|
||||
avg_latency = sum(latency_values) / len(latency_values) if latency_values else 0.0
|
||||
runs = len(add_values)
|
||||
print(
|
||||
f"{name}: add-only avg {avg_add:.3f} s | total avg {avg_total:.3f} s "
|
||||
f"| ZMQ avg {avg_zmq:.1f} node fetches | latency {avg_latency:.2f} ms/passage over {runs} run(s)"
|
||||
)
|
||||
|
||||
if args.plot_path:
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
labels = [name for name, *_ in scenarios]
|
||||
values = [
|
||||
sum(results_ms_per_passage[name]) / len(results_ms_per_passage[name])
|
||||
if results_ms_per_passage[name]
|
||||
else 0.0
|
||||
for name in labels
|
||||
]
|
||||
|
||||
def _auto_cap(vals: list[float]) -> float | None:
|
||||
s = sorted(vals, reverse=True)
|
||||
if len(s) < 2:
|
||||
return None
|
||||
if s[1] > 0 and s[0] >= 2.5 * s[1]:
|
||||
return s[1] * 1.1
|
||||
return None
|
||||
|
||||
def _fmt_ms(v: float) -> str:
|
||||
return f"{v / 1000:.1f}k" if v >= 1000 else f"{v:.1f}"
|
||||
|
||||
colors = ["#4e79a7", "#f28e2c", "#e15759", "#76b7b2"]
|
||||
|
||||
if args.broken_y:
|
||||
s = sorted(values, reverse=True)
|
||||
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
|
||||
lower_cap = args.lower_cap_y if args.lower_cap_y is not None else second * 1.1
|
||||
upper_start = (
|
||||
args.upper_start_y
|
||||
if args.upper_start_y is not None
|
||||
else max(second * 1.2, lower_cap * 1.02)
|
||||
)
|
||||
ymax = max(values) * 1.10 if values else 1.0
|
||||
fig, (ax_top, ax_bottom) = plt.subplots(
|
||||
2,
|
||||
1,
|
||||
sharex=True,
|
||||
figsize=(7.4, 5.0),
|
||||
gridspec_kw={"height_ratios": [1, 3], "hspace": 0.05},
|
||||
)
|
||||
x = list(range(len(labels)))
|
||||
ax_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||
ax_top.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||
ax_bottom.set_ylim(0, lower_cap)
|
||||
ax_top.set_ylim(upper_start, ymax)
|
||||
for i, v in enumerate(values):
|
||||
if v <= lower_cap:
|
||||
ax_bottom.text(
|
||||
i,
|
||||
v + lower_cap * 0.02,
|
||||
_fmt_ms(v),
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=9,
|
||||
)
|
||||
else:
|
||||
ax_top.text(i, v, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
|
||||
ax_top.spines["bottom"].set_visible(False)
|
||||
ax_bottom.spines["top"].set_visible(False)
|
||||
ax_top.tick_params(labeltop=False)
|
||||
ax_bottom.xaxis.tick_bottom()
|
||||
d = 0.015
|
||||
kwargs = {"transform": ax_top.transAxes, "color": "k", "clip_on": False}
|
||||
ax_top.plot((-d, +d), (-d, +d), **kwargs)
|
||||
ax_top.plot((1 - d, 1 + d), (-d, +d), **kwargs)
|
||||
kwargs.update({"transform": ax_bottom.transAxes})
|
||||
ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)
|
||||
ax_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)
|
||||
ax_bottom.set_xticks(range(len(labels)))
|
||||
ax_bottom.set_xticklabels(labels)
|
||||
ax = ax_bottom
|
||||
else:
|
||||
cap = args.cap_y or _auto_cap(values)
|
||||
plt.figure(figsize=(7.2, 4.2))
|
||||
ax = plt.gca()
|
||||
if cap is not None:
|
||||
show_vals = [min(v, cap) for v in values]
|
||||
bars = []
|
||||
for i, (v, show) in enumerate(zip(values, show_vals)):
|
||||
b = ax.bar(i, show, color=colors[i], width=0.8)
|
||||
bars.append(b[0])
|
||||
if v > cap:
|
||||
bars[-1].set_hatch("//")
|
||||
ax.text(i, cap * 1.02, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
|
||||
else:
|
||||
ax.text(
|
||||
i,
|
||||
show + max(1.0, 0.01 * (cap or show)),
|
||||
_fmt_ms(v),
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=9,
|
||||
)
|
||||
ax.set_ylim(0, cap * 1.10)
|
||||
ax.plot(
|
||||
[0.02 - 0.02, 0.02 + 0.02],
|
||||
[0.98 + 0.02, 0.98 - 0.02],
|
||||
transform=ax.transAxes,
|
||||
color="k",
|
||||
lw=1,
|
||||
)
|
||||
ax.plot(
|
||||
[0.98 - 0.02, 0.98 + 0.02],
|
||||
[0.98 + 0.02, 0.98 - 0.02],
|
||||
transform=ax.transAxes,
|
||||
color="k",
|
||||
lw=1,
|
||||
)
|
||||
if any(v > cap for v in values):
|
||||
ax.legend(
|
||||
[bars[0]], ["capped"], fontsize=8, frameon=False, loc="upper right"
|
||||
)
|
||||
ax.set_xticks(range(len(labels)))
|
||||
ax.set_xticklabels(labels)
|
||||
else:
|
||||
ax.bar(labels, values, color=colors[: len(labels)])
|
||||
for idx, val in enumerate(values):
|
||||
ax.text(idx, val + 1.0, f"{val:.1f}", ha="center", va="bottom")
|
||||
|
||||
plt.ylabel("Average add latency (ms per passage)")
|
||||
plt.title(f"Initial passages {args.max_initial}, updates {args.max_updates}")
|
||||
plt.tight_layout()
|
||||
plt.savefig(args.plot_path)
|
||||
print(f"Saved latency bar plot to {args.plot_path}")
|
||||
# ZMQ time split (Stage A vs B/C)
|
||||
try:
|
||||
plt.figure(figsize=(6, 4))
|
||||
a_vals = [sum(results_stageA[n]) / max(1, len(results_stageA[n])) for n in labels]
|
||||
bc_vals = [
|
||||
sum(results_stageBC[n]) / max(1, len(results_stageBC[n])) for n in labels
|
||||
]
|
||||
ind = range(len(labels))
|
||||
plt.bar(ind, a_vals, color="#4e79a7", label="Stage A distance (s)")
|
||||
plt.bar(
|
||||
ind, bc_vals, bottom=a_vals, color="#e15759", label="Stage B/C embed-by-id (s)"
|
||||
)
|
||||
plt.xticks(list(ind), labels, rotation=10)
|
||||
plt.ylabel("Server ZMQ time (s)")
|
||||
plt.title(
|
||||
f"ZMQ time split (initial {args.max_initial}, updates {args.max_updates})"
|
||||
)
|
||||
plt.legend()
|
||||
out2 = args.plot_path.with_name(
|
||||
args.plot_path.stem + "_zmq_split" + args.plot_path.suffix
|
||||
)
|
||||
plt.tight_layout()
|
||||
plt.savefig(out2)
|
||||
print(f"Saved ZMQ time split plot to {out2}")
|
||||
except Exception as e:
|
||||
print("Failed to plot ZMQ split:", e)
|
||||
except ImportError:
|
||||
print("matplotlib not available; skipping plot generation")
|
||||
|
||||
# leave the last build on disk for inspection
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
5
benchmarks/update/bench_results.csv
Normal file
@@ -0,0 +1,5 @@
|
||||
run_id,scenario,cache_enabled,ef_construction,max_initial,max_updates,total_time_s,add_only_s,latency_ms_per_passage,zmq_nodes,stageA_time_s,stageBC_time_s,model_name,embedding_mode,distance_metric
|
||||
20251024-133101,baseline,1,200,300,1,3.391856,1.120359,1120.359421,126,0.507821,0.601608,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||
20251024-133101,no_cache_baseline,0,200,300,1,34.941514,32.91376,32913.760185,4033,0.506933,32.159928,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||
20251024-133101,disable_forward_rng,1,200,300,1,2.746756,0.8202,820.200443,66,0.474354,0.338454,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||
20251024-133101,disable_forward_and_reverse_rng,1,200,300,1,2.396566,0.521478,521.478415,1,0.508973,0.006938,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||
|
704
benchmarks/update/bench_update_vs_offline_search.py
Normal file
@@ -0,0 +1,704 @@
|
||||
"""
|
||||
Compare two latency models for small incremental updates vs. search:
|
||||
|
||||
Scenario A (sequential update then search):
|
||||
- Build initial HNSW (is_recompute=True)
|
||||
- Start embedding server (ZMQ) for recompute
|
||||
- Add N passages one-by-one (each triggers recompute over ZMQ)
|
||||
- Then run a search query on the updated index
|
||||
- Report total time = sum(add_i) + search_time, with breakdowns
|
||||
|
||||
Scenario B (offline embeds + concurrent search; no graph updates):
|
||||
- Do NOT insert the N passages into the graph
|
||||
- In parallel: (1) compute embeddings for the N passages; (2) compute query
|
||||
embedding and run a search on the existing index
|
||||
- After both finish, compute similarity between the query embedding and the N
|
||||
new passage embeddings, merge with the index search results by score, and
|
||||
report time = max(embed_time, search_time) (i.e., no blocking on updates)
|
||||
|
||||
This script reuses the model/data loading conventions of
|
||||
examples/bench_hnsw_rng_recompute.py but focuses on end-to-end latency
|
||||
comparison for the two execution strategies above.
|
||||
|
||||
Example (from the repository root):
|
||||
uv run -m benchmarks.update.bench_update_vs_offline_search \
|
||||
--index-path .leann/bench/offline_vs_update.leann \
|
||||
--max-initial 300 --num-updates 5 --k 10
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import psutil # type: ignore
|
||||
from leann.api import LeannBuilder
|
||||
|
||||
if os.environ.get("LEANN_FORCE_CPU", "").lower() in ("1", "true", "yes"):
|
||||
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
|
||||
|
||||
from leann.embedding_compute import compute_embeddings
|
||||
from leann.embedding_server_manager import EmbeddingServerManager
|
||||
from leann.registry import register_project_directory
|
||||
from leann_backend_hnsw import faiss # type: ignore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
if not logging.getLogger().handlers:
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
def _find_repo_root() -> Path:
|
||||
"""Locate project root by walking up until pyproject.toml is found."""
|
||||
current = Path(__file__).resolve()
|
||||
for parent in current.parents:
|
||||
if (parent / "pyproject.toml").exists():
|
||||
return parent
|
||||
# Fallback: assume repo is two levels up (../..)
|
||||
return current.parents[2]
|
||||
|
||||
|
||||
REPO_ROOT = _find_repo_root()
|
||||
if str(REPO_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(REPO_ROOT))
|
||||
|
||||
from apps.chunking import create_text_chunks # noqa: E402
|
||||
|
||||
DEFAULT_INITIAL_FILES = [
|
||||
REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
|
||||
REPO_ROOT / "data" / "huawei_pangu.md",
|
||||
]
|
||||
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
|
||||
|
||||
|
||||
def load_chunks_from_files(paths: list[Path], limit: int | None = None) -> list[str]:
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
|
||||
documents = []
|
||||
for path in paths:
|
||||
p = path.expanduser().resolve()
|
||||
if not p.exists():
|
||||
raise FileNotFoundError(f"Input path not found: {p}")
|
||||
if p.is_dir():
|
||||
reader = SimpleDirectoryReader(str(p), recursive=False)
|
||||
documents.extend(reader.load_data(show_progress=True))
|
||||
else:
|
||||
reader = SimpleDirectoryReader(input_files=[str(p)])
|
||||
documents.extend(reader.load_data(show_progress=True))
|
||||
|
||||
if not documents:
|
||||
return []
|
||||
|
||||
chunks = create_text_chunks(
|
||||
documents,
|
||||
chunk_size=512,
|
||||
chunk_overlap=128,
|
||||
use_ast_chunking=False,
|
||||
)
|
||||
cleaned = [c for c in chunks if isinstance(c, str) and c.strip()]
|
||||
if limit is not None:
|
||||
cleaned = cleaned[:limit]
|
||||
return cleaned
|
||||
|
||||
|
||||
def ensure_index_dir(index_path: Path) -> None:
|
||||
index_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def cleanup_index_files(index_path: Path) -> None:
|
||||
parent = index_path.parent
|
||||
if not parent.exists():
|
||||
return
|
||||
stem = index_path.stem
|
||||
for file in parent.glob(f"{stem}*"):
|
||||
if file.is_file():
|
||||
file.unlink()
|
||||
|
||||
|
||||
def build_initial_index(
|
||||
index_path: Path,
|
||||
paragraphs: list[str],
|
||||
model_name: str,
|
||||
embedding_mode: str,
|
||||
distance_metric: str,
|
||||
ef_construction: int,
|
||||
) -> None:
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=model_name,
|
||||
embedding_mode=embedding_mode,
|
||||
is_compact=False,
|
||||
is_recompute=True,
|
||||
distance_metric=distance_metric,
|
||||
backend_kwargs={
|
||||
"distance_metric": distance_metric,
|
||||
"is_compact": False,
|
||||
"is_recompute": True,
|
||||
"efConstruction": ef_construction,
|
||||
},
|
||||
)
|
||||
for idx, passage in enumerate(paragraphs):
|
||||
builder.add_text(passage, metadata={"id": str(idx)})
|
||||
builder.build_index(str(index_path))
|
||||
|
||||
|
||||
def _maybe_norm_cosine(vecs: np.ndarray, metric: str) -> np.ndarray:
|
||||
if metric == "cosine":
|
||||
vecs = np.ascontiguousarray(vecs, dtype=np.float32)
|
||||
norms = np.linalg.norm(vecs, axis=1, keepdims=True)
|
||||
norms[norms == 0] = 1
|
||||
vecs = vecs / norms
|
||||
return vecs
|
||||
|
||||
|
||||
def _read_index_for_search(index_path: Path) -> Any:
|
||||
index_file = index_path.parent / f"{index_path.stem}.index"
|
||||
# Force-disable experimental disk cache when loading the index so that
|
||||
# incremental benchmarks don't pick up stale top-degree bitmaps.
|
||||
cfg = faiss.HNSWIndexConfig()
|
||||
cfg.is_recompute = True
|
||||
if hasattr(cfg, "disk_cache_ratio"):
|
||||
cfg.disk_cache_ratio = 0.0
|
||||
if hasattr(cfg, "external_storage_path"):
|
||||
cfg.external_storage_path = None
|
||||
io_flags = getattr(faiss, "IO_FLAG_MMAP", 0)
|
||||
index = faiss.read_index(str(index_file), io_flags, cfg)
|
||||
# ensure recompute mode persists after reload
|
||||
try:
|
||||
index.is_recompute = True
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
actual_ntotal = index.hnsw.levels.size()
|
||||
except AttributeError:
|
||||
actual_ntotal = index.ntotal
|
||||
if actual_ntotal != index.ntotal:
|
||||
print(
|
||||
f"[bench_update_vs_offline_search] Correcting ntotal from {index.ntotal} to {actual_ntotal}",
|
||||
flush=True,
|
||||
)
|
||||
index.ntotal = actual_ntotal
|
||||
if getattr(index, "storage", None) is None:
|
||||
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
|
||||
storage_index = faiss.IndexFlatIP(index.d)
|
||||
else:
|
||||
storage_index = faiss.IndexFlatL2(index.d)
|
||||
index.storage = storage_index
|
||||
index.own_fields = True
|
||||
return index
|
||||
|
||||
|
||||
def _append_passages_for_updates(
|
||||
meta_path: Path,
|
||||
start_id: int,
|
||||
texts: list[str],
|
||||
) -> list[str]:
|
||||
"""Append update passages so the embedding server can serve recompute fetches."""
|
||||
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
index_dir = meta_path.parent
|
||||
meta_name = meta_path.name
|
||||
if not meta_name.endswith(".meta.json"):
|
||||
raise ValueError(f"Unexpected meta filename: {meta_path}")
|
||||
index_base = meta_name[: -len(".meta.json")]
|
||||
|
||||
passages_file = index_dir / f"{index_base}.passages.jsonl"
|
||||
offsets_file = index_dir / f"{index_base}.passages.idx"
|
||||
|
||||
if not passages_file.exists() or not offsets_file.exists():
|
||||
raise FileNotFoundError(
|
||||
"Passage store missing; cannot register update passages for recompute mode."
|
||||
)
|
||||
|
||||
with open(offsets_file, "rb") as f:
|
||||
offset_map: dict[str, int] = pickle.load(f)
|
||||
|
||||
assigned_ids: list[str] = []
|
||||
with open(passages_file, "a", encoding="utf-8") as f:
|
||||
for i, text in enumerate(texts):
|
||||
passage_id = str(start_id + i)
|
||||
offset = f.tell()
|
||||
json.dump({"id": passage_id, "text": text, "metadata": {}}, f, ensure_ascii=False)
|
||||
f.write("\n")
|
||||
offset_map[passage_id] = offset
|
||||
assigned_ids.append(passage_id)
|
||||
|
||||
with open(offsets_file, "wb") as f:
|
||||
pickle.dump(offset_map, f)
|
||||
|
||||
try:
|
||||
with open(meta_path, encoding="utf-8") as f:
|
||||
meta = json.load(f)
|
||||
except json.JSONDecodeError:
|
||||
meta = {}
|
||||
meta["total_passages"] = len(offset_map)
|
||||
with open(meta_path, "w", encoding="utf-8") as f:
|
||||
json.dump(meta, f, indent=2)
|
||||
|
||||
return assigned_ids
|
||||
|
||||
|
||||
def _search(index: Any, q: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]:
|
||||
q = np.ascontiguousarray(q, dtype=np.float32)
|
||||
distances = np.zeros((1, k), dtype=np.float32)
|
||||
indices = np.zeros((1, k), dtype=np.int64)
|
||||
index.search(
|
||||
1,
|
||||
faiss.swig_ptr(q),
|
||||
k,
|
||||
faiss.swig_ptr(distances),
|
||||
faiss.swig_ptr(indices),
|
||||
)
|
||||
return distances[0], indices[0]
|
||||
|
||||
|
||||
def _score_for_metric(dist: float, metric: str) -> float:
|
||||
# Convert FAISS distance to a "higher is better" score
|
||||
if metric in ("mips", "cosine"):
|
||||
return float(dist)
|
||||
# l2 distance (smaller better) -> negative distance as score
|
||||
return -float(dist)
|
||||
|
||||
|
||||
def _merge_results(
|
||||
index_results: tuple[np.ndarray, np.ndarray],
|
||||
offline_scores: list[tuple[int, float]],
|
||||
k: int,
|
||||
metric: str,
|
||||
) -> list[tuple[str, float]]:
|
||||
distances, indices = index_results
|
||||
merged: list[tuple[str, float]] = []
|
||||
for distance, idx in zip(distances.tolist(), indices.tolist()):
|
||||
merged.append((f"idx:{idx}", _score_for_metric(distance, metric)))
|
||||
for j, s in offline_scores:
|
||||
merged.append((f"offline:{j}", s))
|
||||
merged.sort(key=lambda x: x[1], reverse=True)
|
||||
return merged[:k]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScenarioResult:
|
||||
name: str
|
||||
update_total_s: float
|
||||
search_s: float
|
||||
overall_s: float
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--index-path",
|
||||
type=Path,
|
||||
default=Path(".leann/bench/offline-vs-update.leann"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--initial-files",
|
||||
nargs="*",
|
||||
type=Path,
|
||||
default=DEFAULT_INITIAL_FILES,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--update-files",
|
||||
nargs="*",
|
||||
type=Path,
|
||||
default=DEFAULT_UPDATE_FILES,
|
||||
)
|
||||
parser.add_argument("--max-initial", type=int, default=300)
|
||||
parser.add_argument("--num-updates", type=int, default=5)
|
||||
parser.add_argument("--k", type=int, default=10, help="Top-k for search/merge")
|
||||
parser.add_argument(
|
||||
"--query",
|
||||
type=str,
|
||||
default="neural network",
|
||||
help="Query text used for the search benchmark.",
|
||||
)
|
||||
parser.add_argument("--server-port", type=int, default=5557)
|
||||
parser.add_argument("--add-timeout", type=int, default=600)
|
||||
parser.add_argument("--model-name", default="sentence-transformers/all-MiniLM-L6-v2")
|
||||
parser.add_argument("--embedding-mode", default="sentence-transformers")
|
||||
parser.add_argument(
|
||||
"--distance-metric",
|
||||
default="mips",
|
||||
choices=["mips", "l2", "cosine"],
|
||||
)
|
||||
parser.add_argument("--ef-construction", type=int, default=200)
|
||||
parser.add_argument(
|
||||
"--only",
|
||||
choices=["A", "B", "both"],
|
||||
default="both",
|
||||
help="Run only Scenario A, Scenario B, or both",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--csv-path",
|
||||
type=Path,
|
||||
default=Path("benchmarks/update/offline_vs_update.csv"),
|
||||
help="Where to append results (CSV).",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
register_project_directory(REPO_ROOT)
|
||||
|
||||
# Load data
|
||||
initial_paragraphs = load_chunks_from_files(args.initial_files, args.max_initial)
|
||||
update_paragraphs = load_chunks_from_files(args.update_files, None)
|
||||
if not update_paragraphs:
|
||||
raise ValueError("No update passages loaded from --update-files")
|
||||
update_paragraphs = update_paragraphs[: args.num_updates]
|
||||
if len(update_paragraphs) < args.num_updates:
|
||||
raise ValueError(
|
||||
f"Not enough update passages ({len(update_paragraphs)}) for --num-updates={args.num_updates}"
|
||||
)
|
||||
|
||||
ensure_index_dir(args.index_path)
|
||||
cleanup_index_files(args.index_path)
|
||||
|
||||
# Build initial index
|
||||
build_initial_index(
|
||||
args.index_path,
|
||||
initial_paragraphs,
|
||||
args.model_name,
|
||||
args.embedding_mode,
|
||||
args.distance_metric,
|
||||
args.ef_construction,
|
||||
)
|
||||
|
||||
# Prepare index object and meta
|
||||
meta_path = args.index_path.parent / f"{args.index_path.name}.meta.json"
|
||||
index = _read_index_for_search(args.index_path)
|
||||
|
||||
# CSV setup
|
||||
run_id = time.strftime("%Y%m%d-%H%M%S")
|
||||
if args.csv_path:
|
||||
args.csv_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
csv_fields = [
|
||||
"run_id",
|
||||
"scenario",
|
||||
"max_initial",
|
||||
"num_updates",
|
||||
"k",
|
||||
"total_time_s",
|
||||
"add_total_s",
|
||||
"search_time_s",
|
||||
"emb_time_s",
|
||||
"makespan_s",
|
||||
"model_name",
|
||||
"embedding_mode",
|
||||
"distance_metric",
|
||||
]
|
||||
if not args.csv_path.exists() or args.csv_path.stat().st_size == 0:
|
||||
with args.csv_path.open("w", newline="", encoding="utf-8") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||
writer.writeheader()
|
||||
|
||||
# Debug: list existing HNSW server PIDs before starting
|
||||
try:
|
||||
existing = [
|
||||
p
|
||||
for p in psutil.process_iter(attrs=["pid", "cmdline"])
|
||||
if any(
|
||||
isinstance(arg, str) and "leann_backend_hnsw.hnsw_embedding_server" in arg
|
||||
for arg in (p.info.get("cmdline") or [])
|
||||
)
|
||||
]
|
||||
if existing:
|
||||
print("[debug] Found existing hnsw_embedding_server processes before run:")
|
||||
for p in existing:
|
||||
print(f"[debug] PID={p.info['pid']} cmd={' '.join(p.info.get('cmdline') or [])}")
|
||||
except Exception as _e:
|
||||
pass
|
||||
|
||||
add_total = 0.0
|
||||
search_after_add = 0.0
|
||||
total_seq = 0.0
|
||||
port_a = None
|
||||
if args.only in ("A", "both"):
|
||||
# Scenario A: sequential update then search
|
||||
start_id = index.ntotal
|
||||
assigned_ids = _append_passages_for_updates(meta_path, start_id, update_paragraphs)
|
||||
if assigned_ids:
|
||||
logger.debug(
|
||||
"Registered %d update passages starting at id %s",
|
||||
len(assigned_ids),
|
||||
assigned_ids[0],
|
||||
)
|
||||
server_manager = EmbeddingServerManager(
|
||||
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
|
||||
)
|
||||
ok, port = server_manager.start_server(
|
||||
port=args.server_port,
|
||||
model_name=args.model_name,
|
||||
embedding_mode=args.embedding_mode,
|
||||
passages_file=str(meta_path),
|
||||
distance_metric=args.distance_metric,
|
||||
)
|
||||
if not ok:
|
||||
raise RuntimeError("Failed to start embedding server")
|
||||
try:
|
||||
# Set ZMQ port for recompute mode
|
||||
if hasattr(index.hnsw, "set_zmq_port"):
|
||||
index.hnsw.set_zmq_port(port)
|
||||
elif hasattr(index, "set_zmq_port"):
|
||||
index.set_zmq_port(port)
|
||||
|
||||
# Start A overall timer BEFORE computing update embeddings
|
||||
t0 = time.time()
|
||||
|
||||
# Compute embeddings for updates (counted into A's overall)
|
||||
t_emb0 = time.time()
|
||||
upd_embs = compute_embeddings(
|
||||
update_paragraphs,
|
||||
args.model_name,
|
||||
mode=args.embedding_mode,
|
||||
is_build=False,
|
||||
batch_size=16,
|
||||
)
|
||||
emb_time_updates = time.time() - t_emb0
|
||||
upd_embs = np.asarray(upd_embs, dtype=np.float32)
|
||||
upd_embs = _maybe_norm_cosine(upd_embs, args.distance_metric)
|
||||
|
||||
# Perform sequential adds
|
||||
for i in range(upd_embs.shape[0]):
|
||||
t_add0 = time.time()
|
||||
index.add(1, faiss.swig_ptr(upd_embs[i : i + 1]))
|
||||
add_total += time.time() - t_add0
|
||||
# Don't persist index after adds to avoid contaminating Scenario B
|
||||
# index_file = args.index_path.parent / f"{args.index_path.stem}.index"
|
||||
# faiss.write_index(index, str(index_file))
|
||||
|
||||
# Search after updates
|
||||
q_emb = compute_embeddings(
|
||||
[args.query], args.model_name, mode=args.embedding_mode, is_build=False
|
||||
)
|
||||
q_emb = np.asarray(q_emb, dtype=np.float32)
|
||||
q_emb = _maybe_norm_cosine(q_emb, args.distance_metric)
|
||||
|
||||
# Warm up search with a dummy query first
|
||||
print("[DEBUG] Warming up search...")
|
||||
_ = _search(index, q_emb, 1)
|
||||
|
||||
t_s0 = time.time()
|
||||
D_upd, I_upd = _search(index, q_emb, args.k)
|
||||
search_after_add = time.time() - t_s0
|
||||
total_seq = time.time() - t0
|
||||
finally:
|
||||
server_manager.stop_server()
|
||||
port_a = port
|
||||
|
||||
print("\n=== Scenario A: update->search (sequential) ===")
|
||||
# emb_time_updates is defined only when A runs
|
||||
try:
|
||||
_emb_a = emb_time_updates
|
||||
except NameError:
|
||||
_emb_a = 0.0
|
||||
print(
|
||||
f"Adds: {args.num_updates} passages; embeds={_emb_a:.3f}s; add_total={add_total:.3f}s; "
|
||||
f"search={search_after_add:.3f}s; overall={total_seq:.3f}s"
|
||||
)
|
||||
# CSV row for A
|
||||
if args.csv_path:
|
||||
row_a = {
|
||||
"run_id": run_id,
|
||||
"scenario": "A",
|
||||
"max_initial": args.max_initial,
|
||||
"num_updates": args.num_updates,
|
||||
"k": args.k,
|
||||
"total_time_s": round(total_seq, 6),
|
||||
"add_total_s": round(add_total, 6),
|
||||
"search_time_s": round(search_after_add, 6),
|
||||
"emb_time_s": round(_emb_a, 6),
|
||||
"makespan_s": 0.0,
|
||||
"model_name": args.model_name,
|
||||
"embedding_mode": args.embedding_mode,
|
||||
"distance_metric": args.distance_metric,
|
||||
}
|
||||
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||
writer.writerow(row_a)
|
||||
|
||||
# Verify server cleanup
|
||||
try:
|
||||
# short sleep to allow signal handling to finish
|
||||
time.sleep(0.5)
|
||||
leftovers = [
|
||||
p
|
||||
for p in psutil.process_iter(attrs=["pid", "cmdline"])
|
||||
if any(
|
||||
isinstance(arg, str) and "leann_backend_hnsw.hnsw_embedding_server" in arg
|
||||
for arg in (p.info.get("cmdline") or [])
|
||||
)
|
||||
]
|
||||
if leftovers:
|
||||
print("[warn] hnsw_embedding_server process(es) still alive after A-stop:")
|
||||
for p in leftovers:
|
||||
print(
|
||||
f"[warn] PID={p.info['pid']} cmd={' '.join(p.info.get('cmdline') or [])}"
|
||||
)
|
||||
else:
|
||||
print("[debug] server cleanup confirmed: no hnsw_embedding_server found")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Scenario B: offline embeds + concurrent search (no graph updates)
|
||||
if args.only in ("B", "both"):
|
||||
# ensure a server is available for recompute search
|
||||
server_manager_b = EmbeddingServerManager(
|
||||
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
|
||||
)
|
||||
requested_port = args.server_port if port_a is None else port_a
|
||||
ok_b, port_b = server_manager_b.start_server(
|
||||
port=requested_port,
|
||||
model_name=args.model_name,
|
||||
embedding_mode=args.embedding_mode,
|
||||
passages_file=str(meta_path),
|
||||
distance_metric=args.distance_metric,
|
||||
)
|
||||
if not ok_b:
|
||||
raise RuntimeError("Failed to start embedding server for Scenario B")
|
||||
|
||||
# Wait for server to fully initialize
|
||||
print("[DEBUG] Waiting 2s for embedding server to fully initialize...")
|
||||
time.sleep(2)
|
||||
|
||||
try:
|
||||
# Read the index first
|
||||
index_no_update = _read_index_for_search(args.index_path) # unchanged index
|
||||
|
||||
# Then configure ZMQ port on the correct index object
|
||||
if hasattr(index_no_update.hnsw, "set_zmq_port"):
|
||||
index_no_update.hnsw.set_zmq_port(port_b)
|
||||
elif hasattr(index_no_update, "set_zmq_port"):
|
||||
index_no_update.set_zmq_port(port_b)
|
||||
|
||||
# Warmup the embedding model before benchmarking (do this for both --only B and --only both)
|
||||
# This ensures fair comparison as Scenario A has warmed up the model during update embeddings
|
||||
logger.info("Warming up embedding model for Scenario B...")
|
||||
_ = compute_embeddings(
|
||||
["warmup text"], args.model_name, mode=args.embedding_mode, is_build=False
|
||||
)
|
||||
|
||||
# Prepare worker A: compute embeddings for the same N passages
|
||||
emb_time = 0.0
|
||||
updates_embs_offline: np.ndarray | None = None
|
||||
|
||||
def _worker_emb():
|
||||
nonlocal emb_time, updates_embs_offline
|
||||
t = time.time()
|
||||
updates_embs_offline = compute_embeddings(
|
||||
update_paragraphs,
|
||||
args.model_name,
|
||||
mode=args.embedding_mode,
|
||||
is_build=False,
|
||||
batch_size=16,
|
||||
)
|
||||
emb_time = time.time() - t
|
||||
|
||||
# Pre-compute query embedding and warm up search outside of timed section.
|
||||
q_vec = compute_embeddings(
|
||||
[args.query], args.model_name, mode=args.embedding_mode, is_build=False
|
||||
)
|
||||
q_vec = np.asarray(q_vec, dtype=np.float32)
|
||||
q_vec = _maybe_norm_cosine(q_vec, args.distance_metric)
|
||||
print("[DEBUG B] Warming up search...")
|
||||
_ = _search(index_no_update, q_vec, 1)
|
||||
|
||||
# Worker B: timed search on the warmed index
|
||||
search_time = 0.0
|
||||
offline_elapsed = 0.0
|
||||
index_results: tuple[np.ndarray, np.ndarray] | None = None
|
||||
|
||||
def _worker_search():
|
||||
nonlocal search_time, index_results
|
||||
t = time.time()
|
||||
distances, indices = _search(index_no_update, q_vec, args.k)
|
||||
search_time = time.time() - t
|
||||
index_results = (distances, indices)
|
||||
|
||||
# Run two workers concurrently
|
||||
t0 = time.time()
|
||||
th1 = threading.Thread(target=_worker_emb)
|
||||
th2 = threading.Thread(target=_worker_search)
|
||||
th1.start()
|
||||
th2.start()
|
||||
th1.join()
|
||||
th2.join()
|
||||
offline_elapsed = time.time() - t0
|
||||
|
||||
# For mixing: compute query vs. offline update similarities (pure client-side)
|
||||
offline_scores: list[tuple[int, float]] = []
|
||||
if updates_embs_offline is not None:
|
||||
upd2 = np.asarray(updates_embs_offline, dtype=np.float32)
|
||||
upd2 = _maybe_norm_cosine(upd2, args.distance_metric)
|
||||
# For mips/cosine, score = dot; for l2, score = -||x-y||^2
|
||||
for j in range(upd2.shape[0]):
|
||||
if args.distance_metric in ("mips", "cosine"):
|
||||
s = float(np.dot(q_vec[0], upd2[j]))
|
||||
else:
|
||||
diff = q_vec[0] - upd2[j]
|
||||
s = -float(np.dot(diff, diff))
|
||||
offline_scores.append((j, s))
|
||||
|
||||
merged_topk = (
|
||||
_merge_results(index_results, offline_scores, args.k, args.distance_metric)
|
||||
if index_results
|
||||
else []
|
||||
)
|
||||
|
||||
print("\n=== Scenario B: offline embeds + concurrent search (no add) ===")
|
||||
print(
|
||||
f"embeddings({args.num_updates})={emb_time:.3f}s; search={search_time:.3f}s; makespan≈{offline_elapsed:.3f}s (≈max)"
|
||||
)
|
||||
if merged_topk:
|
||||
preview = ", ".join([f"{lab}:{score:.3f}" for lab, score in merged_topk[:5]])
|
||||
print(f"Merged top-5 preview: {preview}")
|
||||
# CSV row for B
|
||||
if args.csv_path:
|
||||
row_b = {
|
||||
"run_id": run_id,
|
||||
"scenario": "B",
|
||||
"max_initial": args.max_initial,
|
||||
"num_updates": args.num_updates,
|
||||
"k": args.k,
|
||||
"total_time_s": 0.0,
|
||||
"add_total_s": 0.0,
|
||||
"search_time_s": round(search_time, 6),
|
||||
"emb_time_s": round(emb_time, 6),
|
||||
"makespan_s": round(offline_elapsed, 6),
|
||||
"model_name": args.model_name,
|
||||
"embedding_mode": args.embedding_mode,
|
||||
"distance_metric": args.distance_metric,
|
||||
}
|
||||
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||
writer.writerow(row_b)
|
||||
|
||||
finally:
|
||||
server_manager_b.stop_server()
|
||||
|
||||
# Summary
|
||||
print("\n=== Summary ===")
|
||||
msg_a = (
|
||||
f"A: seq-add+search overall={total_seq:.3f}s (adds={add_total:.3f}s, search={search_after_add:.3f}s)"
|
||||
if args.only in ("A", "both")
|
||||
else "A: skipped"
|
||||
)
|
||||
msg_b = (
|
||||
f"B: offline+concurrent overall≈{offline_elapsed:.3f}s (emb={emb_time:.3f}s, search={search_time:.3f}s)"
|
||||
if args.only in ("B", "both")
|
||||
else "B: skipped"
|
||||
)
|
||||
print(msg_a + "\n" + msg_b)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
5
benchmarks/update/offline_vs_update.csv
Normal file
@@ -0,0 +1,5 @@
|
||||
run_id,scenario,max_initial,num_updates,k,total_time_s,add_total_s,search_time_s,emb_time_s,makespan_s,model_name,embedding_mode,distance_metric
|
||||
20251024-141607,A,300,1,10,3.273957,3.050168,0.097825,0.017339,0.0,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||
20251024-141607,B,300,1,10,0.0,0.0,0.111892,0.007869,0.112635,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||
20251025-160652,A,300,5,10,5.061945,4.805962,0.123271,0.015008,0.0,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||
20251025-160652,B,300,5,10,0.0,0.0,0.101809,0.008817,0.102447,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||
|
645
benchmarks/update/plot_bench_results.py
Normal file
@@ -0,0 +1,645 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Plot latency bars from the benchmark CSV produced by
|
||||
benchmarks/update/bench_hnsw_rng_recompute.py.
|
||||
|
||||
If you also provide an offline_vs_update.csv via --csv-right
|
||||
(from benchmarks/update/bench_update_vs_offline_search.py), this script will
|
||||
output a side-by-side figure:
|
||||
- Left: ms/passage bars (four RNG scenarios).
|
||||
- Right: seconds bars (Scenario A seq add+search vs Scenario B offline+search).
|
||||
|
||||
Usage:
|
||||
uv run python benchmarks/update/plot_bench_results.py \
|
||||
--csv benchmarks/update/bench_results.csv \
|
||||
--out benchmarks/update/bench_latency_from_csv.png
|
||||
|
||||
The script selects the latest run_id in the CSV and plots four bars for
|
||||
the default scenarios:
|
||||
- baseline
|
||||
- no_cache_baseline
|
||||
- disable_forward_rng
|
||||
- disable_forward_and_reverse_rng
|
||||
|
||||
If multiple rows exist per scenario for that run_id, the script averages
|
||||
their latency_ms_per_passage values.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
DEFAULT_SCENARIOS = [
|
||||
"no_cache_baseline",
|
||||
"baseline",
|
||||
"disable_forward_rng",
|
||||
"disable_forward_and_reverse_rng",
|
||||
]
|
||||
|
||||
SCENARIO_LABELS = {
|
||||
"baseline": "+ Cache",
|
||||
"no_cache_baseline": "Naive \n Recompute",
|
||||
"disable_forward_rng": "+ w/o \n Fwd RNG",
|
||||
"disable_forward_and_reverse_rng": "+ w/o \n Bwd RNG",
|
||||
}
|
||||
|
||||
# Paper-style colors and hatches for scenarios
|
||||
SCENARIO_STYLES = {
|
||||
"no_cache_baseline": {"edgecolor": "dimgrey", "hatch": "/////"},
|
||||
"baseline": {"edgecolor": "#63B8B6", "hatch": "xxxxx"},
|
||||
"disable_forward_rng": {"edgecolor": "green", "hatch": "....."},
|
||||
"disable_forward_and_reverse_rng": {"edgecolor": "tomato", "hatch": "\\\\\\\\\\"},
|
||||
}
|
||||
|
||||
|
||||
def load_latest_run(csv_path: Path):
|
||||
rows = []
|
||||
with csv_path.open("r", encoding="utf-8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
rows.append(row)
|
||||
if not rows:
|
||||
raise SystemExit("CSV is empty: no rows to plot")
|
||||
# Choose latest run_id lexicographically (YYYYMMDD-HHMMSS)
|
||||
run_ids = [r.get("run_id", "") for r in rows]
|
||||
latest = max(run_ids)
|
||||
latest_rows = [r for r in rows if r.get("run_id", "") == latest]
|
||||
if not latest_rows:
|
||||
# Fallback: take last 4 rows
|
||||
latest_rows = rows[-4:]
|
||||
latest = latest_rows[-1].get("run_id", "unknown")
|
||||
return latest, latest_rows
|
||||
|
||||
|
||||
def aggregate_latency(rows):
|
||||
acc = defaultdict(list)
|
||||
for r in rows:
|
||||
sc = r.get("scenario", "")
|
||||
try:
|
||||
val = float(r.get("latency_ms_per_passage", "nan"))
|
||||
except ValueError:
|
||||
continue
|
||||
acc[sc].append(val)
|
||||
avg = {k: (sum(v) / len(v) if v else 0.0) for k, v in acc.items()}
|
||||
return avg
|
||||
|
||||
|
||||
def _auto_cap(values: list[float]) -> float | None:
|
||||
if not values:
|
||||
return None
|
||||
sorted_vals = sorted(values, reverse=True)
|
||||
if len(sorted_vals) < 2:
|
||||
return None
|
||||
max_v, second = sorted_vals[0], sorted_vals[1]
|
||||
if second <= 0:
|
||||
return None
|
||||
# If the tallest bar dwarfs the second by 2.5x+, cap near the second
|
||||
if max_v >= 2.5 * second:
|
||||
return second * 1.1
|
||||
return None
|
||||
|
||||
|
||||
def _add_break_marker(ax, y, rel_x0=0.02, rel_x1=0.98, size=0.02):
|
||||
# Draw small diagonal ticks near left/right to signal cap
|
||||
x0, x1 = rel_x0, rel_x1
|
||||
ax.plot([x0 - size, x0 + size], [y + size, y - size], transform=ax.transAxes, color="k", lw=1)
|
||||
ax.plot([x1 - size, x1 + size], [y + size, y - size], transform=ax.transAxes, color="k", lw=1)
|
||||
|
||||
|
||||
def _fmt_ms(v: float) -> str:
|
||||
if v >= 1000:
|
||||
return f"{v / 1000:.1f}k"
|
||||
return f"{v:.1f}"
|
||||
|
||||
|
||||
def main():
|
||||
# Set LaTeX style for paper figures (matching paper_fig.py)
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
plt.rcParams["font.family"] = "Helvetica"
|
||||
plt.rcParams["ytick.direction"] = "in"
|
||||
plt.rcParams["hatch.linewidth"] = 1.5
|
||||
plt.rcParams["font.weight"] = "bold"
|
||||
plt.rcParams["axes.labelweight"] = "bold"
|
||||
plt.rcParams["text.usetex"] = True
|
||||
|
||||
ap = argparse.ArgumentParser(description=__doc__)
|
||||
ap.add_argument(
|
||||
"--csv",
|
||||
type=Path,
|
||||
default=Path("benchmarks/update/bench_results.csv"),
|
||||
help="Path to results CSV (defaults to bench_results.csv)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--out",
|
||||
type=Path,
|
||||
default=Path("add_ablation.pdf"),
|
||||
help="Output image path",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--csv-right",
|
||||
type=Path,
|
||||
default=Path("benchmarks/update/offline_vs_update.csv"),
|
||||
help="Optional: offline_vs_update.csv to render right subplot (A vs B)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--cap-y",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Cap Y-axis at this ms value; bars above are hatched and annotated.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--no-auto-cap",
|
||||
action="store_true",
|
||||
help="Disable auto-cap heuristic when --cap-y is not provided.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--broken-y",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Use a broken Y-axis (two stacked axes with a gap). Overrides --cap-y unless both provided.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--lower-cap-y",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Lower axes upper bound for broken Y (ms). Default = 1.1x second-highest.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--upper-start-y",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Upper axes lower bound for broken Y (ms). Default = 1.2x second-highest.",
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
latest_run, latest_rows = load_latest_run(args.csv)
|
||||
avg = aggregate_latency(latest_rows)
|
||||
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
except Exception as e:
|
||||
raise SystemExit(f"matplotlib not available: {e}")
|
||||
|
||||
scenarios = DEFAULT_SCENARIOS
|
||||
values = [avg.get(name, 0.0) for name in scenarios]
|
||||
labels = [SCENARIO_LABELS.get(name, name) for name in scenarios]
|
||||
colors = ["#4e79a7", "#f28e2c", "#e15759", "#76b7b2"]
|
||||
|
||||
# If right CSV is provided, build side-by-side figure
|
||||
if args.csv_right is not None:
|
||||
try:
|
||||
right_rows_all = []
|
||||
with args.csv_right.open("r", encoding="utf-8") as f:
|
||||
rreader = csv.DictReader(f)
|
||||
right_rows_all = list(rreader)
|
||||
if right_rows_all:
|
||||
r_latest = max(r.get("run_id", "") for r in right_rows_all)
|
||||
right_rows = [r for r in right_rows_all if r.get("run_id", "") == r_latest]
|
||||
else:
|
||||
r_latest = None
|
||||
right_rows = []
|
||||
except Exception:
|
||||
r_latest = None
|
||||
right_rows = []
|
||||
|
||||
a_total = 0.0
|
||||
b_makespan = 0.0
|
||||
for r in right_rows:
|
||||
sc = (r.get("scenario", "") or "").strip().upper()
|
||||
if sc == "A":
|
||||
try:
|
||||
a_total = float(r.get("total_time_s", 0.0))
|
||||
except Exception:
|
||||
pass
|
||||
elif sc == "B":
|
||||
try:
|
||||
b_makespan = float(r.get("makespan_s", 0.0))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib import gridspec
|
||||
|
||||
# Left subplot (reuse current style, with optional cap)
|
||||
cap = args.cap_y
|
||||
if cap is None and not args.no_auto_cap:
|
||||
cap = _auto_cap(values)
|
||||
x = list(range(len(labels)))
|
||||
|
||||
if args.broken_y:
|
||||
# Use broken axis for left subplot
|
||||
# Auto-adjust width ratios: left has 4 bars, right has 2 bars
|
||||
fig = plt.figure(figsize=(4.8, 1.8)) # Scaled down to 80%
|
||||
gs = gridspec.GridSpec(
|
||||
2, 2, height_ratios=[1, 3], width_ratios=[1.5, 1], hspace=0.08, wspace=0.35
|
||||
)
|
||||
ax_left_top = fig.add_subplot(gs[0, 0])
|
||||
ax_left_bottom = fig.add_subplot(gs[1, 0], sharex=ax_left_top)
|
||||
ax_right = fig.add_subplot(gs[:, 1])
|
||||
|
||||
# Determine break points
|
||||
s = sorted(values, reverse=True)
|
||||
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
|
||||
lower_cap = (
|
||||
args.lower_cap_y if args.lower_cap_y is not None else second * 1.4
|
||||
) # Increased to show more range
|
||||
upper_start = (
|
||||
args.upper_start_y
|
||||
if args.upper_start_y is not None
|
||||
else max(second * 1.5, lower_cap * 1.02)
|
||||
)
|
||||
ymax = (
|
||||
max(values) * 1.90 if values else 1.0
|
||||
) # Increase headroom to 1.90 for text label and tick range
|
||||
|
||||
# Draw bars on both axes
|
||||
ax_left_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||
ax_left_top.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||
|
||||
# Set limits
|
||||
ax_left_bottom.set_ylim(0, lower_cap)
|
||||
ax_left_top.set_ylim(upper_start, ymax)
|
||||
|
||||
# Annotate values (convert ms to s)
|
||||
values_s = [v / 1000.0 for v in values]
|
||||
lower_cap_s = lower_cap / 1000.0
|
||||
upper_start_s = upper_start / 1000.0
|
||||
ymax_s = ymax / 1000.0
|
||||
|
||||
ax_left_bottom.set_ylim(0, lower_cap_s)
|
||||
ax_left_top.set_ylim(upper_start_s, ymax_s)
|
||||
|
||||
# Redraw bars with s values (paper style: white fill + colored edge + hatch)
|
||||
ax_left_bottom.clear()
|
||||
ax_left_top.clear()
|
||||
bar_width = 0.50 # Reduced for wider spacing between bars
|
||||
for i, (scenario_name, v) in enumerate(zip(scenarios, values_s)):
|
||||
style = SCENARIO_STYLES.get(scenario_name, {"edgecolor": "black", "hatch": ""})
|
||||
# Draw in bottom axis for all bars
|
||||
ax_left_bottom.bar(
|
||||
i,
|
||||
v,
|
||||
width=bar_width,
|
||||
color="white",
|
||||
edgecolor=style["edgecolor"],
|
||||
hatch=style["hatch"],
|
||||
linewidth=1.2,
|
||||
)
|
||||
# Only draw in top axis if the bar is tall enough to reach the upper range
|
||||
if v > upper_start_s:
|
||||
ax_left_top.bar(
|
||||
i,
|
||||
v,
|
||||
width=bar_width,
|
||||
color="white",
|
||||
edgecolor=style["edgecolor"],
|
||||
hatch=style["hatch"],
|
||||
linewidth=1.2,
|
||||
)
|
||||
ax_left_bottom.set_ylim(0, lower_cap_s)
|
||||
ax_left_top.set_ylim(upper_start_s, ymax_s)
|
||||
|
||||
for i, v in enumerate(values_s):
|
||||
if v <= lower_cap_s:
|
||||
ax_left_bottom.text(
|
||||
i,
|
||||
v + lower_cap_s * 0.02,
|
||||
f"{v:.2f}",
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=8,
|
||||
fontweight="bold",
|
||||
)
|
||||
else:
|
||||
ax_left_top.text(
|
||||
i,
|
||||
v + (ymax_s - upper_start_s) * 0.02,
|
||||
f"{v:.2f}",
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=8,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Hide spines between axes
|
||||
ax_left_top.spines["bottom"].set_visible(False)
|
||||
ax_left_bottom.spines["top"].set_visible(False)
|
||||
ax_left_top.tick_params(
|
||||
labeltop=False, labelbottom=False, bottom=False
|
||||
) # Hide tick marks
|
||||
ax_left_bottom.xaxis.tick_bottom()
|
||||
ax_left_bottom.tick_params(top=False) # Hide top tick marks
|
||||
|
||||
# Draw break marks (matching paper_fig.py style)
|
||||
d = 0.015
|
||||
kwargs = {
|
||||
"transform": ax_left_top.transAxes,
|
||||
"color": "k",
|
||||
"clip_on": False,
|
||||
"linewidth": 0.8,
|
||||
"zorder": 10,
|
||||
}
|
||||
ax_left_top.plot((-d, +d), (-d, +d), **kwargs)
|
||||
ax_left_top.plot((1 - d, 1 + d), (-d, +d), **kwargs)
|
||||
kwargs.update({"transform": ax_left_bottom.transAxes})
|
||||
ax_left_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)
|
||||
ax_left_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)
|
||||
|
||||
ax_left_bottom.set_xticks(x)
|
||||
ax_left_bottom.set_xticklabels(labels, rotation=0, fontsize=7)
|
||||
# Don't set ylabel here - will use fig.text for alignment
|
||||
ax_left_bottom.tick_params(axis="y", labelsize=10)
|
||||
ax_left_top.tick_params(axis="y", labelsize=10)
|
||||
# Add subtle grid for better readability
|
||||
ax_left_bottom.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
|
||||
ax_left_top.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
|
||||
ax_left_top.set_title("Single Add Operation", fontsize=11, pad=10, fontweight="bold")
|
||||
|
||||
# Set x-axis limits to match bar width with right subplot
|
||||
ax_left_bottom.set_xlim(-0.6, 3.6)
|
||||
ax_left_top.set_xlim(-0.6, 3.6)
|
||||
|
||||
ax_left = ax_left_bottom # for compatibility
|
||||
else:
|
||||
# Regular side-by-side layout
|
||||
fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(8.4, 3.15))
|
||||
|
||||
if cap is not None:
|
||||
show_vals = [min(v, cap) for v in values]
|
||||
bars = ax_left.bar(x, show_vals, color=colors[: len(labels)], width=0.8)
|
||||
for i, (val, show) in enumerate(zip(values, show_vals)):
|
||||
if val > cap:
|
||||
bars[i].set_hatch("//")
|
||||
ax_left.text(
|
||||
i, cap * 1.02, _fmt_ms(val), ha="center", va="bottom", fontsize=9
|
||||
)
|
||||
else:
|
||||
ax_left.text(
|
||||
i,
|
||||
show + max(1.0, 0.01 * (cap or show)),
|
||||
_fmt_ms(val),
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=9,
|
||||
)
|
||||
ax_left.set_ylim(0, cap * 1.10)
|
||||
_add_break_marker(ax_left, y=0.98)
|
||||
ax_left.set_xticks(x)
|
||||
ax_left.set_xticklabels(labels, rotation=0, fontsize=10)
|
||||
else:
|
||||
ax_left.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||
for i, v in enumerate(values):
|
||||
ax_left.text(i, v + 1.0, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
|
||||
ax_left.set_xticks(x)
|
||||
ax_left.set_xticklabels(labels, rotation=0, fontsize=10)
|
||||
ax_left.set_ylabel("Latency (ms per passage)")
|
||||
max_initial = latest_rows[0].get("max_initial", "?")
|
||||
max_updates = latest_rows[0].get("max_updates", "?")
|
||||
ax_left.set_title(
|
||||
f"HNSW RNG (run {latest_run}) | init={max_initial}, upd={max_updates}"
|
||||
)
|
||||
|
||||
# Right subplot (A vs B, seconds) - paper style
|
||||
r_labels = ["Sequential", "Delayed \n Add+Search"]
|
||||
r_values = [a_total or 0.0, b_makespan or 0.0]
|
||||
r_styles = [
|
||||
{"edgecolor": "#59a14f", "hatch": "xxxxx"},
|
||||
{"edgecolor": "#edc948", "hatch": "/////"},
|
||||
]
|
||||
# 2 bars, centered with proper spacing
|
||||
xr = [0, 1]
|
||||
bar_width = 0.50 # Reduced for wider spacing between bars
|
||||
for i, (v, style) in enumerate(zip(r_values, r_styles)):
|
||||
ax_right.bar(
|
||||
xr[i],
|
||||
v,
|
||||
width=bar_width,
|
||||
color="white",
|
||||
edgecolor=style["edgecolor"],
|
||||
hatch=style["hatch"],
|
||||
linewidth=1.2,
|
||||
)
|
||||
for i, v in enumerate(r_values):
|
||||
max_v = max(r_values) if r_values else 1.0
|
||||
offset = max(0.0002, 0.02 * max_v)
|
||||
ax_right.text(
|
||||
xr[i],
|
||||
v + offset,
|
||||
f"{v:.2f}",
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=8,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax_right.set_xticks(xr)
|
||||
ax_right.set_xticklabels(r_labels, rotation=0, fontsize=7)
|
||||
# Don't set ylabel here - will use fig.text for alignment
|
||||
ax_right.tick_params(axis="y", labelsize=10)
|
||||
# Add subtle grid for better readability
|
||||
ax_right.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
|
||||
ax_right.set_title("Batched Add Operation", fontsize=11, pad=10, fontweight="bold")
|
||||
|
||||
# Set x-axis limits to match left subplot's bar width visually
|
||||
# Accounting for width_ratios=[1.5, 1]:
|
||||
# Left: 4 bars, xlim(-0.6, 3.6), range=4.2, physical_width=1.5*unit
|
||||
# bar_width_visual = 0.72 * (1.5*unit / 4.2)
|
||||
# Right: 2 bars, need same visual width
|
||||
# 0.72 * (1.0*unit / range_right) = 0.72 * (1.5*unit / 4.2)
|
||||
# range_right = 4.2 / 1.5 = 2.8
|
||||
# For bars at 0, 1: padding = (2.8 - 1) / 2 = 0.9
|
||||
ax_right.set_xlim(-0.9, 1.9)
|
||||
|
||||
# Set y-axis limit with headroom for text labels
|
||||
if r_values:
|
||||
max_v = max(r_values)
|
||||
ax_right.set_ylim(0, max_v * 1.15)
|
||||
|
||||
# Format y-axis to avoid scientific notation
|
||||
ax_right.ticklabel_format(style="plain", axis="y")
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
# Add aligned ylabels using fig.text (after tight_layout)
|
||||
# Get the vertical center of the entire figure
|
||||
fig_center_y = 0.5
|
||||
# Left ylabel - closer to left plot
|
||||
left_x = 0.05
|
||||
fig.text(
|
||||
left_x,
|
||||
fig_center_y,
|
||||
"Latency (s)",
|
||||
va="center",
|
||||
rotation="vertical",
|
||||
fontsize=11,
|
||||
fontweight="bold",
|
||||
)
|
||||
# Right ylabel - closer to right plot
|
||||
right_bbox = ax_right.get_position()
|
||||
right_x = right_bbox.x0 - 0.07
|
||||
fig.text(
|
||||
right_x,
|
||||
fig_center_y,
|
||||
"Latency (s)",
|
||||
va="center",
|
||||
rotation="vertical",
|
||||
fontsize=11,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
plt.savefig(args.out, bbox_inches="tight", pad_inches=0.05)
|
||||
# Also save PDF for paper
|
||||
pdf_out = args.out.with_suffix(".pdf")
|
||||
plt.savefig(pdf_out, bbox_inches="tight", pad_inches=0.05)
|
||||
print(f"Saved: {args.out}")
|
||||
print(f"Saved: {pdf_out}")
|
||||
return
|
||||
|
||||
# Broken-Y mode
|
||||
if args.broken_y:
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
fig, (ax_top, ax_bottom) = plt.subplots(
|
||||
2,
|
||||
1,
|
||||
sharex=True,
|
||||
figsize=(7.5, 6.75),
|
||||
gridspec_kw={"height_ratios": [1, 3], "hspace": 0.08},
|
||||
)
|
||||
|
||||
# Determine default breaks from second-highest
|
||||
s = sorted(values, reverse=True)
|
||||
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
|
||||
lower_cap = args.lower_cap_y if args.lower_cap_y is not None else second * 1.1
|
||||
upper_start = (
|
||||
args.upper_start_y
|
||||
if args.upper_start_y is not None
|
||||
else max(second * 1.2, lower_cap * 1.02)
|
||||
)
|
||||
ymax = max(values) * 1.10 if values else 1.0
|
||||
|
||||
x = list(range(len(labels)))
|
||||
ax_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||
ax_top.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||
|
||||
# Limits
|
||||
ax_bottom.set_ylim(0, lower_cap)
|
||||
ax_top.set_ylim(upper_start, ymax)
|
||||
|
||||
# Annotate values
|
||||
for i, v in enumerate(values):
|
||||
if v <= lower_cap:
|
||||
ax_bottom.text(
|
||||
i, v + lower_cap * 0.02, _fmt_ms(v), ha="center", va="bottom", fontsize=9
|
||||
)
|
||||
else:
|
||||
ax_top.text(i, v, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
|
||||
|
||||
# Hide spines between axes and draw diagonal break marks
|
||||
ax_top.spines["bottom"].set_visible(False)
|
||||
ax_bottom.spines["top"].set_visible(False)
|
||||
ax_top.tick_params(labeltop=False) # don't put tick labels at the top
|
||||
ax_bottom.xaxis.tick_bottom()
|
||||
|
||||
# Diagonal lines at the break (matching paper_fig.py style)
|
||||
d = 0.015
|
||||
kwargs = {
|
||||
"transform": ax_top.transAxes,
|
||||
"color": "k",
|
||||
"clip_on": False,
|
||||
"linewidth": 0.8,
|
||||
"zorder": 10,
|
||||
}
|
||||
ax_top.plot((-d, +d), (-d, +d), **kwargs) # top-left diagonal
|
||||
ax_top.plot((1 - d, 1 + d), (-d, +d), **kwargs) # top-right diagonal
|
||||
kwargs.update({"transform": ax_bottom.transAxes})
|
||||
ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs) # bottom-left diagonal
|
||||
ax_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs) # bottom-right diagonal
|
||||
|
||||
ax_bottom.set_xticks(x)
|
||||
ax_bottom.set_xticklabels(labels, rotation=0, fontsize=10)
|
||||
ax = ax_bottom # for labeling below
|
||||
else:
|
||||
cap = args.cap_y
|
||||
if cap is None and not args.no_auto_cap:
|
||||
cap = _auto_cap(values)
|
||||
|
||||
plt.figure(figsize=(5.4, 3.15))
|
||||
ax = plt.gca()
|
||||
|
||||
if cap is not None:
|
||||
show_vals = [min(v, cap) for v in values]
|
||||
bars = []
|
||||
for i, (_label, val, show) in enumerate(zip(labels, values, show_vals)):
|
||||
bar = ax.bar(i, show, color=colors[i], width=0.8)
|
||||
bars.append(bar[0])
|
||||
# Hatch and annotate when capped
|
||||
if val > cap:
|
||||
bars[-1].set_hatch("//")
|
||||
ax.text(i, cap * 1.02, f"{_fmt_ms(val)}", ha="center", va="bottom", fontsize=9)
|
||||
else:
|
||||
ax.text(
|
||||
i,
|
||||
show + max(1.0, 0.01 * (cap or show)),
|
||||
f"{_fmt_ms(val)}",
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=9,
|
||||
)
|
||||
ax.set_ylim(0, cap * 1.10)
|
||||
_add_break_marker(ax, y=0.98)
|
||||
ax.legend([bars[1]], ["capped"], fontsize=8, frameon=False, loc="upper right") if any(
|
||||
v > cap for v in values
|
||||
) else None
|
||||
ax.set_xticks(range(len(labels)))
|
||||
ax.set_xticklabels(labels, fontsize=11, fontweight="bold")
|
||||
else:
|
||||
ax.bar(labels, values, color=colors[: len(labels)])
|
||||
for idx, val in enumerate(values):
|
||||
ax.text(
|
||||
idx,
|
||||
val + 1.0,
|
||||
f"{_fmt_ms(val)}",
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=10,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax.set_xticklabels(labels, fontsize=11, fontweight="bold")
|
||||
# Try to extract some context for title
|
||||
max_initial = latest_rows[0].get("max_initial", "?")
|
||||
max_updates = latest_rows[0].get("max_updates", "?")
|
||||
|
||||
if args.broken_y:
|
||||
fig.text(
|
||||
0.02,
|
||||
0.5,
|
||||
"Latency (s)",
|
||||
va="center",
|
||||
rotation="vertical",
|
||||
fontsize=11,
|
||||
fontweight="bold",
|
||||
)
|
||||
fig.suptitle(
|
||||
"Add Operation Latency",
|
||||
fontsize=11,
|
||||
y=0.98,
|
||||
fontweight="bold",
|
||||
)
|
||||
plt.tight_layout(rect=(0.03, 0.04, 1, 0.96))
|
||||
else:
|
||||
plt.ylabel("Latency (s)", fontsize=11, fontweight="bold")
|
||||
plt.title("Add Operation Latency", fontsize=11, fontweight="bold")
|
||||
plt.tight_layout()
|
||||
|
||||
plt.savefig(args.out, bbox_inches="tight", pad_inches=0.05)
|
||||
# Also save PDF for paper
|
||||
pdf_out = args.out.with_suffix(".pdf")
|
||||
plt.savefig(pdf_out, bbox_inches="tight", pad_inches=0.05)
|
||||
print(f"Saved: {args.out}")
|
||||
print(f"Saved: {pdf_out}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
200
docs/COLQWEN_GUIDE.md
Normal file
@@ -0,0 +1,200 @@
|
||||
# ColQwen Integration Guide
|
||||
|
||||
Easy-to-use multimodal PDF retrieval with ColQwen2/ColPali models.
|
||||
|
||||
## Quick Start
|
||||
|
||||
> **🍎 Mac Users**: ColQwen is optimized for Apple Silicon with MPS acceleration for faster inference!
|
||||
|
||||
### 1. Install Dependencies
|
||||
```bash
|
||||
uv pip install colpali_engine pdf2image pillow matplotlib qwen_vl_utils einops seaborn
|
||||
brew install poppler # macOS only, for PDF processing
|
||||
```
|
||||
|
||||
### 2. Basic Usage
|
||||
```bash
|
||||
# Build index from PDFs
|
||||
python -m apps.colqwen_rag build --pdfs ./my_papers/ --index research_papers
|
||||
|
||||
# Search with text queries
|
||||
python -m apps.colqwen_rag search research_papers "How does attention mechanism work?"
|
||||
|
||||
# Interactive Q&A
|
||||
python -m apps.colqwen_rag ask research_papers --interactive
|
||||
```
|
||||
|
||||
## Commands
|
||||
|
||||
### Build Index
|
||||
```bash
|
||||
python -m apps.colqwen_rag build \
|
||||
--pdfs ./pdf_directory/ \
|
||||
--index my_index \
|
||||
--model colqwen2 \
|
||||
--pages-dir ./page_images/ # Optional: save page images
|
||||
```
|
||||
|
||||
**Options:**
|
||||
- `--pdfs`: Directory containing PDF files (or single PDF path)
|
||||
- `--index`: Name for the index (required)
|
||||
- `--model`: `colqwen2` (default) or `colpali`
|
||||
- `--pages-dir`: Directory to save page images (optional)
|
||||
|
||||
### Search Index
|
||||
```bash
|
||||
python -m apps.colqwen_rag search my_index "your question here" --top-k 5
|
||||
```
|
||||
|
||||
**Options:**
|
||||
- `--top-k`: Number of results to return (default: 5)
|
||||
- `--model`: Model used for search (should match build model)
|
||||
|
||||
### Interactive Q&A
|
||||
```bash
|
||||
python -m apps.colqwen_rag ask my_index --interactive
|
||||
```
|
||||
|
||||
**Commands in interactive mode:**
|
||||
- Type your questions naturally
|
||||
- `help`: Show available commands
|
||||
- `quit`/`exit`/`q`: Exit interactive mode
|
||||
|
||||
## 🧪 Test & Reproduce Results
|
||||
|
||||
Run the reproduction test for issue #119:
|
||||
```bash
|
||||
python test_colqwen_reproduction.py
|
||||
```
|
||||
|
||||
This will:
|
||||
1. ✅ Check dependencies
|
||||
2. 📥 Download sample PDF (Attention Is All You Need paper)
|
||||
3. 🏗️ Build test index
|
||||
4. 🔍 Run sample queries
|
||||
5. 📊 Show how to generate similarity maps
|
||||
|
||||
## 🎨 Advanced: Similarity Maps
|
||||
|
||||
For visual similarity analysis, use the existing advanced script:
|
||||
```bash
|
||||
cd apps/multimodal/vision-based-pdf-multi-vector/
|
||||
python multi-vector-leann-similarity-map.py
|
||||
```
|
||||
|
||||
Edit the script to customize:
|
||||
- `QUERY`: Your question
|
||||
- `MODEL`: "colqwen2" or "colpali"
|
||||
- `USE_HF_DATASET`: Use HuggingFace dataset or local PDFs
|
||||
- `SIMILARITY_MAP`: Generate heatmaps
|
||||
- `ANSWER`: Enable Qwen-VL answer generation
|
||||
|
||||
## 🔧 How It Works
|
||||
|
||||
### ColQwen2 vs ColPali
|
||||
- **ColQwen2** (`vidore/colqwen2-v1.0`): Latest vision-language model
|
||||
- **ColPali** (`vidore/colpali-v1.2`): Proven multimodal retriever
|
||||
|
||||
### Architecture
|
||||
1. **PDF → Images**: Convert PDF pages to images (150 DPI)
|
||||
2. **Vision Encoding**: Process images with ColQwen2/ColPali
|
||||
3. **Multi-Vector Index**: Build LEANN HNSW index with multiple embeddings per page
|
||||
4. **Query Processing**: Encode text queries with same model
|
||||
5. **Similarity Search**: Find most relevant pages/regions
|
||||
6. **Visual Maps**: Generate attention heatmaps (optional)
|
||||
|
||||
### Device Support
|
||||
- **CUDA**: Best performance with GPU acceleration
|
||||
- **MPS**: Apple Silicon Mac support
|
||||
- **CPU**: Fallback for any system (slower)
|
||||
|
||||
Auto-detection: CUDA > MPS > CPU
|
||||
|
||||
## 📊 Performance Tips
|
||||
|
||||
### For Best Performance:
|
||||
```bash
|
||||
# Use ColQwen2 for latest features
|
||||
--model colqwen2
|
||||
|
||||
# Save page images for reuse
|
||||
--pages-dir ./cached_pages/
|
||||
|
||||
# Adjust batch size based on GPU memory
|
||||
# (automatically handled)
|
||||
```
|
||||
|
||||
### For Large Document Sets:
|
||||
- Process PDFs in batches
|
||||
- Use SSD storage for index files
|
||||
- Consider using CUDA if available
|
||||
|
||||
## 🔗 Related Resources
|
||||
|
||||
- **Fast-PLAID**: https://github.com/lightonai/fast-plaid
|
||||
- **Pylate**: https://github.com/lightonai/pylate
|
||||
- **ColBERT**: https://github.com/stanford-futuredata/ColBERT
|
||||
- **ColPali Paper**: Vision-Language Models for Document Retrieval
|
||||
- **Issue #119**: https://github.com/yichuan-w/LEANN/issues/119
|
||||
|
||||
## 🐛 Troubleshooting
|
||||
|
||||
### PDF Conversion Issues (macOS)
|
||||
```bash
|
||||
# Install poppler
|
||||
brew install poppler
|
||||
which pdfinfo && pdfinfo -v
|
||||
```
|
||||
|
||||
### Memory Issues
|
||||
- Reduce batch size (automatically handled)
|
||||
- Use CPU instead of GPU: `export CUDA_VISIBLE_DEVICES=""`
|
||||
- Process fewer PDFs at once
|
||||
|
||||
### Model Download Issues
|
||||
- Ensure internet connection for first run
|
||||
- Models are cached after first download
|
||||
- Use HuggingFace mirrors if needed
|
||||
|
||||
### Import Errors
|
||||
```bash
|
||||
# Ensure all dependencies installed
|
||||
uv pip install colpali_engine pdf2image pillow matplotlib qwen_vl_utils einops seaborn
|
||||
|
||||
# Check PyTorch installation
|
||||
python -c "import torch; print(torch.__version__)"
|
||||
```
|
||||
|
||||
## 💡 Examples
|
||||
|
||||
### Research Paper Analysis
|
||||
```bash
|
||||
# Index your research papers
|
||||
python -m apps.colqwen_rag build --pdfs ~/Papers/AI/ --index ai_papers
|
||||
|
||||
# Ask research questions
|
||||
python -m apps.colqwen_rag search ai_papers "What are the limitations of transformer models?"
|
||||
python -m apps.colqwen_rag search ai_papers "How does BERT compare to GPT?"
|
||||
```
|
||||
|
||||
### Document Q&A
|
||||
```bash
|
||||
# Index business documents
|
||||
python -m apps.colqwen_rag build --pdfs ~/Documents/Reports/ --index reports
|
||||
|
||||
# Interactive analysis
|
||||
python -m apps.colqwen_rag ask reports --interactive
|
||||
```
|
||||
|
||||
### Visual Analysis
|
||||
```bash
|
||||
# Generate similarity maps for specific queries
|
||||
cd apps/multimodal/vision-based-pdf-multi-vector/
|
||||
# Edit multi-vector-leann-similarity-map.py with your query
|
||||
python multi-vector-leann-similarity-map.py
|
||||
# Check ./figures/ for generated heatmaps
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
**🎯 This integration makes ColQwen as easy to use as other LEANN features while maintaining the full power of multimodal document understanding!**
|
||||
@@ -53,9 +53,9 @@ We use pre-commit hooks to ensure code quality and consistency. This runs automa
|
||||
|
||||
### Setup Pre-commit
|
||||
|
||||
1. **Install pre-commit** (already included when you run `uv sync`):
|
||||
1. **Install pre-commit tools**:
|
||||
```bash
|
||||
uv pip install pre-commit
|
||||
uv sync lint
|
||||
```
|
||||
|
||||
2. **Install the git hooks**:
|
||||
@@ -65,7 +65,7 @@ We use pre-commit hooks to ensure code quality and consistency. This runs automa
|
||||
|
||||
3. **Run pre-commit manually** (optional):
|
||||
```bash
|
||||
pre-commit run --all-files
|
||||
uv run pre-commit run --all-files
|
||||
```
|
||||
|
||||
### Pre-commit Checks
|
||||
@@ -85,6 +85,9 @@ Our pre-commit configuration includes:
|
||||
### Running Tests
|
||||
|
||||
```bash
|
||||
# Install test tools only (no project runtime)
|
||||
uv sync --group test
|
||||
|
||||
# Run all tests
|
||||
uv run pytest
|
||||
|
||||
|
||||
143
docs/ast_chunking_guide.md
Normal file
@@ -0,0 +1,143 @@
|
||||
# AST-Aware Code chunking guide
|
||||
|
||||
## Overview
|
||||
|
||||
This guide covers best practices for using AST-aware code chunking in LEANN. AST chunking provides better semantic understanding of code structure compared to traditional text-based chunking.
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```bash
|
||||
# Enable AST chunking for mixed content (code + docs)
|
||||
python -m apps.document_rag --enable-code-chunking --data-dir ./my_project
|
||||
|
||||
# Specialized code repository indexing
|
||||
python -m apps.code_rag --repo-dir ./my_codebase
|
||||
|
||||
# Global CLI with AST support
|
||||
leann build my-code-index --docs ./src --use-ast-chunking
|
||||
```
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Install LEANN with AST chunking support
|
||||
uv pip install -e "."
|
||||
```
|
||||
|
||||
#### For normal users (PyPI install)
|
||||
- Use `pip install leann` or `uv pip install leann`.
|
||||
- `astchunk` is pulled automatically from PyPI as a dependency; no extra steps.
|
||||
|
||||
#### For developers (from source, editable)
|
||||
```bash
|
||||
git clone https://github.com/yichuan-w/LEANN.git leann
|
||||
cd leann
|
||||
git submodule update --init --recursive
|
||||
uv sync
|
||||
```
|
||||
- This repo vendors `astchunk` as a git submodule at `packages/astchunk-leann` (our fork).
|
||||
- `[tool.uv.sources]` maps the `astchunk` package to that path in editable mode.
|
||||
- You can edit code under `packages/astchunk-leann` and Python will use your changes immediately (no separate `pip install astchunk` needed).
|
||||
|
||||
## Best Practices
|
||||
|
||||
### When to Use AST Chunking
|
||||
|
||||
✅ **Recommended for:**
|
||||
- Code repositories with multiple languages
|
||||
- Mixed documentation and code content
|
||||
- Complex codebases with deep function/class hierarchies
|
||||
- When working with Claude Code for code assistance
|
||||
|
||||
❌ **Not recommended for:**
|
||||
- Pure text documents
|
||||
- Very large files (>1MB)
|
||||
- Languages not supported by tree-sitter
|
||||
|
||||
### Optimal Configuration
|
||||
|
||||
```bash
|
||||
# Recommended settings for most codebases
|
||||
python -m apps.code_rag \
|
||||
--repo-dir ./src \
|
||||
--ast-chunk-size 768 \
|
||||
--ast-chunk-overlap 96 \
|
||||
--exclude-dirs .git __pycache__ node_modules build dist
|
||||
```
|
||||
|
||||
### Supported Languages
|
||||
|
||||
| Extension | Language | Status |
|
||||
|-----------|----------|--------|
|
||||
| `.py` | Python | ✅ Full support |
|
||||
| `.java` | Java | ✅ Full support |
|
||||
| `.cs` | C# | ✅ Full support |
|
||||
| `.ts`, `.tsx` | TypeScript | ✅ Full support |
|
||||
| `.js`, `.jsx` | JavaScript | ✅ Via TypeScript parser |
|
||||
|
||||
## Integration Examples
|
||||
|
||||
### Document RAG with Code Support
|
||||
|
||||
```python
|
||||
# Enable code chunking in document RAG
|
||||
python -m apps.document_rag \
|
||||
--enable-code-chunking \
|
||||
--data-dir ./project \
|
||||
--query "How does authentication work in the codebase?"
|
||||
```
|
||||
|
||||
### Claude Code Integration
|
||||
|
||||
When using with Claude Code MCP server, AST chunking provides better context for:
|
||||
- Code completion and suggestions
|
||||
- Bug analysis and debugging
|
||||
- Architecture understanding
|
||||
- Refactoring assistance
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **Fallback to Traditional Chunking**
|
||||
- Normal behavior for unsupported languages
|
||||
- Check logs for specific language support
|
||||
|
||||
2. **Performance with Large Files**
|
||||
- Adjust `--max-file-size` parameter
|
||||
- Use `--exclude-dirs` to skip unnecessary directories
|
||||
|
||||
3. **Quality Issues**
|
||||
- Try different `--ast-chunk-size` values (512, 768, 1024)
|
||||
- Adjust overlap for better context preservation
|
||||
|
||||
### Debug Mode
|
||||
|
||||
```bash
|
||||
export LEANN_LOG_LEVEL=DEBUG
|
||||
python -m apps.code_rag --repo-dir ./my_code
|
||||
```
|
||||
|
||||
## Migration from Traditional Chunking
|
||||
|
||||
Existing workflows continue to work without changes. To enable AST chunking:
|
||||
|
||||
```bash
|
||||
# Before
|
||||
python -m apps.document_rag --chunk-size 256
|
||||
|
||||
# After (maintains traditional chunking for non-code files)
|
||||
python -m apps.document_rag --enable-code-chunking --chunk-size 256 --ast-chunk-size 768
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- [astchunk GitHub Repository](https://github.com/yilinjz/astchunk)
|
||||
- [LEANN MCP Integration](../packages/leann-mcp/README.md)
|
||||
- [Research Paper](https://arxiv.org/html/2506.15655v1)
|
||||
|
||||
---
|
||||
|
||||
**Note**: AST chunking maintains full backward compatibility while enhancing code understanding capabilities.
|
||||
@@ -83,6 +83,170 @@ ollama pull nomic-embed-text
|
||||
|
||||
</details>
|
||||
|
||||
## Local & Remote Inference Endpoints
|
||||
|
||||
> Applies to both LLMs (`leann ask`) and embeddings (`leann build`).
|
||||
|
||||
LEANN now treats Ollama, LM Studio, and other OpenAI-compatible runtimes as first-class providers. You can point LEANN at any compatible endpoint – either on the same machine or across the network – with a couple of flags or environment variables.
|
||||
|
||||
### One-Time Environment Setup
|
||||
|
||||
```bash
|
||||
# Works for OpenAI-compatible runtimes such as LM Studio, vLLM, SGLang, llamafile, etc.
|
||||
export OPENAI_API_KEY="your-key" # or leave unset for local servers that do not check keys
|
||||
export OPENAI_BASE_URL="http://localhost:1234/v1"
|
||||
|
||||
# Ollama-compatible runtimes (Ollama, Ollama on another host, llamacpp-server, etc.)
|
||||
export LEANN_OLLAMA_HOST="http://localhost:11434" # falls back to OLLAMA_HOST or LOCAL_LLM_ENDPOINT
|
||||
```
|
||||
|
||||
LEANN also recognises `LEANN_LOCAL_LLM_HOST` (highest priority), `LEANN_OPENAI_BASE_URL`, and `LOCAL_OPENAI_BASE_URL`, so existing scripts continue to work.
|
||||
|
||||
### Passing Hosts Per Command
|
||||
|
||||
```bash
|
||||
# Build an index with a remote embedding server
|
||||
leann build my-notes \
|
||||
--docs ./notes \
|
||||
--embedding-mode openai \
|
||||
--embedding-model text-embedding-qwen3-embedding-0.6b \
|
||||
--embedding-api-base http://192.168.1.50:1234/v1 \
|
||||
--embedding-api-key local-dev-key
|
||||
|
||||
# Query using a local LM Studio instance via OpenAI-compatible API
|
||||
leann ask my-notes \
|
||||
--llm openai \
|
||||
--llm-model qwen3-8b \
|
||||
--api-base http://localhost:1234/v1 \
|
||||
--api-key local-dev-key
|
||||
|
||||
# Query an Ollama instance running on another box
|
||||
leann ask my-notes \
|
||||
--llm ollama \
|
||||
--llm-model qwen3:14b \
|
||||
--host http://192.168.1.101:11434
|
||||
```
|
||||
|
||||
⚠️ **Make sure the endpoint is reachable**: when your inference server runs on a home/workstation and the index/search job runs in the cloud, the server must be able to reach the host you configured. Typical options include:
|
||||
|
||||
- Expose a public IP (and open the relevant port) on the machine that hosts LM Studio/Ollama.
|
||||
- Configure router or cloud provider port forwarding.
|
||||
- Tunnel traffic through tools like `tailscale`, `cloudflared`, or `ssh -R`.
|
||||
|
||||
When you set these options while building an index, LEANN stores them in `meta.json`. Any subsequent `leann ask` or searcher process automatically reuses the same provider settings – even when we spawn background embedding servers. This makes the “server without GPU talking to my local workstation” workflow from [issue #80](https://github.com/yichuan-w/LEANN/issues/80#issuecomment-2287230548) work out-of-the-box.
|
||||
|
||||
**Tip:** If your runtime does not require an API key (many local stacks don’t), leave `--api-key` unset. LEANN will skip injecting credentials.
|
||||
|
||||
### Python API Usage
|
||||
|
||||
You can pass the same configuration from Python:
|
||||
|
||||
```python
|
||||
from leann.api import LeannBuilder
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_mode="openai",
|
||||
embedding_model="text-embedding-qwen3-embedding-0.6b",
|
||||
embedding_options={
|
||||
"base_url": "http://192.168.1.50:1234/v1",
|
||||
"api_key": "local-dev-key",
|
||||
},
|
||||
)
|
||||
builder.build_index("./indexes/my-notes", chunks)
|
||||
```
|
||||
|
||||
`embedding_options` is persisted to the index `meta.json`, so subsequent `LeannSearcher` or `LeannChat` sessions automatically reuse the same provider settings (the embedding server manager forwards them to the provider for you).
|
||||
|
||||
## Optional Embedding Features
|
||||
|
||||
### Task-Specific Prompt Templates
|
||||
|
||||
Some embedding models are trained with task-specific prompts to differentiate between documents and queries. The most notable example is **Google's EmbeddingGemma**, which requires different prompts depending on the use case:
|
||||
|
||||
- **Indexing documents**: `"title: none | text: "`
|
||||
- **Search queries**: `"task: search result | query: "`
|
||||
|
||||
LEANN supports automatic prompt prepending via the `--embedding-prompt-template` flag:
|
||||
|
||||
```bash
|
||||
# Build index with EmbeddingGemma (via LM Studio or Ollama)
|
||||
leann build my-docs \
|
||||
--docs ./documents \
|
||||
--embedding-mode openai \
|
||||
--embedding-model text-embedding-embeddinggemma-300m-qat \
|
||||
--embedding-api-base http://localhost:1234/v1 \
|
||||
--embedding-prompt-template "title: none | text: " \
|
||||
--force
|
||||
|
||||
# Search with query-specific prompt
|
||||
leann search my-docs \
|
||||
--query "What is quantum computing?" \
|
||||
--embedding-prompt-template "task: search result | query: "
|
||||
```
|
||||
|
||||
**Important Notes:**
|
||||
- **Only use with compatible models**: EmbeddingGemma and similar task-specific models
|
||||
- **NOT for regular models**: Adding prompts to models like `nomic-embed-text`, `text-embedding-3-small`, or `bge-base-en-v1.5` will corrupt embeddings
|
||||
- **Template is saved**: Build-time templates are saved to `.meta.json` for reference
|
||||
- **Flexible prompts**: You can use any prompt string, or leave it empty (`""`)
|
||||
|
||||
**Python API:**
|
||||
```python
|
||||
from leann.api import LeannBuilder
|
||||
|
||||
builder = LeannBuilder(
|
||||
embedding_mode="openai",
|
||||
embedding_model="text-embedding-embeddinggemma-300m-qat",
|
||||
embedding_options={
|
||||
"base_url": "http://localhost:1234/v1",
|
||||
"api_key": "lm-studio",
|
||||
"prompt_template": "title: none | text: ",
|
||||
},
|
||||
)
|
||||
builder.build_index("./indexes/my-docs", chunks)
|
||||
```
|
||||
|
||||
**References:**
|
||||
- [HuggingFace Blog: EmbeddingGemma](https://huggingface.co/blog/embeddinggemma) - Technical details
|
||||
|
||||
### LM Studio Auto-Detection (Optional)
|
||||
|
||||
When using LM Studio with the OpenAI-compatible API, LEANN can optionally auto-detect model context lengths via the LM Studio SDK. This eliminates manual configuration for token limits.
|
||||
|
||||
**Prerequisites:**
|
||||
```bash
|
||||
# Install Node.js (if not already installed)
|
||||
# Then install the LM Studio SDK globally
|
||||
npm install -g @lmstudio/sdk
|
||||
```
|
||||
|
||||
**How it works:**
|
||||
1. LEANN detects LM Studio URLs (`:1234`, `lmstudio` in URL)
|
||||
2. Queries model metadata via Node.js subprocess
|
||||
3. Automatically unloads model after query (respects your JIT auto-evict settings)
|
||||
4. Falls back to static registry if SDK unavailable
|
||||
|
||||
**No configuration needed** - it works automatically when SDK is installed:
|
||||
|
||||
```bash
|
||||
leann build my-docs \
|
||||
--docs ./documents \
|
||||
--embedding-mode openai \
|
||||
--embedding-model text-embedding-nomic-embed-text-v1.5 \
|
||||
--embedding-api-base http://localhost:1234/v1
|
||||
# Context length auto-detected if SDK available
|
||||
# Falls back to registry (2048) if not
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
- ✅ Automatic token limit detection
|
||||
- ✅ Respects LM Studio JIT auto-evict settings
|
||||
- ✅ No manual registry maintenance
|
||||
- ✅ Graceful fallback if SDK unavailable
|
||||
|
||||
**Note:** This is completely optional. LEANN works perfectly fine without the SDK using the built-in token limit registry.
|
||||
|
||||
## Index Selection: Matching Your Scale
|
||||
|
||||
### HNSW (Hierarchical Navigable Small World)
|
||||
@@ -290,7 +454,7 @@ leann search my-index "your query" \
|
||||
|
||||
### 2) Run remote builds with SkyPilot (cloud GPU)
|
||||
|
||||
Offload embedding generation and index building to a GPU VM using [SkyPilot](https://skypilot.readthedocs.io/en/latest/). A template is provided at `sky/leann-build.yaml`.
|
||||
Offload embedding generation and index building to a GPU VM using [SkyPilot](https://docs.skypilot.co/en/latest/docs/index.html). A template is provided at `sky/leann-build.yaml`.
|
||||
|
||||
```bash
|
||||
# One-time: install and configure SkyPilot
|
||||
@@ -380,5 +544,5 @@ Conclusion:
|
||||
|
||||
- [Lessons Learned Developing LEANN](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
|
||||
- [LEANN Technical Paper](https://arxiv.org/abs/2506.08276)
|
||||
- [DiskANN Original Paper](https://papers.nips.cc/paper/2019/file/09853c7fb1d3f8ee67a61b6bf4a7f8e6-Paper.pdf)
|
||||
- [DiskANN Original Paper](https://suhasjs.github.io/files/diskann_neurips19.pdf)
|
||||
- [SSD-based Graph Partitioning](https://github.com/SonglinLife/SSD_BASED_PLAN)
|
||||
|
||||
48
docs/faq.md
@@ -8,3 +8,51 @@ You can speed up the process by using a lightweight embedding model. Add this to
|
||||
--embedding-model sentence-transformers/all-MiniLM-L6-v2
|
||||
```
|
||||
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)
|
||||
|
||||
## 2. When should I use prompt templates?
|
||||
|
||||
**Use prompt templates ONLY with task-specific embedding models** like Google's EmbeddingGemma. These models are specially trained to use different prompts for documents vs queries.
|
||||
|
||||
**DO NOT use with regular models** like `nomic-embed-text`, `text-embedding-3-small`, or `bge-base-en-v1.5` - adding prompts to these models will corrupt the embeddings.
|
||||
|
||||
**Example usage with EmbeddingGemma:**
|
||||
```bash
|
||||
# Build with document prompt
|
||||
leann build my-docs --embedding-prompt-template "title: none | text: "
|
||||
|
||||
# Search with query prompt
|
||||
leann search my-docs --query "your question" --embedding-prompt-template "task: search result | query: "
|
||||
```
|
||||
|
||||
See the [Configuration Guide: Task-Specific Prompt Templates](configuration-guide.md#task-specific-prompt-templates) for detailed usage.
|
||||
|
||||
## 3. Why is LM Studio loading multiple copies of my model?
|
||||
|
||||
This was fixed in recent versions. LEANN now properly unloads models after querying metadata, respecting your LM Studio JIT auto-evict settings.
|
||||
|
||||
**If you still see duplicates:**
|
||||
- Update to the latest LEANN version
|
||||
- Restart LM Studio to clear loaded models
|
||||
- Check that you have JIT auto-evict enabled in LM Studio settings
|
||||
|
||||
**How it works now:**
|
||||
1. LEANN loads model temporarily to get context length
|
||||
2. Immediately unloads after query
|
||||
3. LM Studio JIT loads model on-demand for actual embeddings
|
||||
4. Auto-evicts per your settings
|
||||
|
||||
## 4. Do I need Node.js and @lmstudio/sdk?
|
||||
|
||||
**No, it's completely optional.** LEANN works perfectly fine without them using a built-in token limit registry.
|
||||
|
||||
**Benefits if you install it:**
|
||||
- Automatic context length detection for LM Studio models
|
||||
- No manual registry maintenance
|
||||
- Always gets accurate token limits from the model itself
|
||||
|
||||
**To install (optional):**
|
||||
```bash
|
||||
npm install -g @lmstudio/sdk
|
||||
```
|
||||
|
||||
See [Configuration Guide: LM Studio Auto-Detection](configuration-guide.md#lm-studio-auto-detection-optional) for details.
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
## 🔥 Core Features
|
||||
|
||||
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage with dynamic computation using optimized ZMQ servers and highly optimized search paradigm (overlapping and batching) with highly optimized embedding engine
|
||||
- **🧠 AST-Aware Code Chunking** - Intelligent code chunking that preserves semantic boundaries (functions, classes, methods) for Python, Java, C#, and TypeScript files
|
||||
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
|
||||
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
|
||||
- **🏗️ Pluggable Backends** - HNSW/FAISS (default), with optional DiskANN for large-scale deployments
|
||||
|
||||
149
docs/grep_search.md
Normal file
@@ -0,0 +1,149 @@
|
||||
# LEANN Grep Search Usage Guide
|
||||
|
||||
## Overview
|
||||
|
||||
LEANN's grep search functionality provides exact text matching for finding specific code patterns, error messages, function names, or exact phrases in your indexed documents.
|
||||
|
||||
## Basic Usage
|
||||
|
||||
### Simple Grep Search
|
||||
|
||||
```python
|
||||
from leann.api import LeannSearcher
|
||||
|
||||
searcher = LeannSearcher("your_index_path")
|
||||
|
||||
# Exact text search
|
||||
results = searcher.search("def authenticate_user", use_grep=True, top_k=5)
|
||||
|
||||
for result in results:
|
||||
print(f"Score: {result.score}")
|
||||
print(f"Text: {result.text[:100]}...")
|
||||
print("-" * 40)
|
||||
```
|
||||
|
||||
### Comparison: Semantic vs Grep Search
|
||||
|
||||
```python
|
||||
# Semantic search - finds conceptually similar content
|
||||
semantic_results = searcher.search("machine learning algorithms", top_k=3)
|
||||
|
||||
# Grep search - finds exact text matches
|
||||
grep_results = searcher.search("def train_model", use_grep=True, top_k=3)
|
||||
```
|
||||
|
||||
## When to Use Grep Search
|
||||
|
||||
### Use Cases
|
||||
|
||||
- **Code Search**: Finding specific function definitions, class names, or variable references
|
||||
- **Error Debugging**: Locating exact error messages or stack traces
|
||||
- **Documentation**: Finding specific API endpoints or exact terminology
|
||||
|
||||
### Examples
|
||||
|
||||
```python
|
||||
# Find function definitions
|
||||
functions = searcher.search("def __init__", use_grep=True)
|
||||
|
||||
# Find import statements
|
||||
imports = searcher.search("from sklearn import", use_grep=True)
|
||||
|
||||
# Find specific error types
|
||||
errors = searcher.search("FileNotFoundError", use_grep=True)
|
||||
|
||||
# Find TODO comments
|
||||
todos = searcher.search("TODO:", use_grep=True)
|
||||
|
||||
# Find configuration entries
|
||||
configs = searcher.search("server_port=", use_grep=True)
|
||||
```
|
||||
|
||||
## Technical Details
|
||||
|
||||
### How It Works
|
||||
|
||||
1. **File Location**: Grep search operates on the raw text stored in `.jsonl` files
|
||||
2. **Command Execution**: Uses the system `grep` command with case-insensitive search
|
||||
3. **Result Processing**: Parses JSON lines and extracts text and metadata
|
||||
4. **Scoring**: Simple frequency-based scoring based on query term occurrences
|
||||
|
||||
### Search Process
|
||||
|
||||
```
|
||||
Query: "def train_model"
|
||||
↓
|
||||
grep -i -n "def train_model" documents.leann.passages.jsonl
|
||||
↓
|
||||
Parse matching JSON lines
|
||||
↓
|
||||
Calculate scores based on term frequency
|
||||
↓
|
||||
Return top_k results
|
||||
```
|
||||
|
||||
### Scoring Algorithm
|
||||
|
||||
```python
|
||||
# Term frequency in document
|
||||
score = text.lower().count(query.lower())
|
||||
```
|
||||
|
||||
Results are ranked by score (highest first), with higher scores indicating more occurrences of the search term.
|
||||
|
||||
## Error Handling
|
||||
|
||||
### Common Issues
|
||||
|
||||
#### Grep Command Not Found
|
||||
```
|
||||
RuntimeError: grep command not found. Please install grep or use semantic search.
|
||||
```
|
||||
|
||||
**Solution**: Install grep on your system:
|
||||
- **Ubuntu/Debian**: `sudo apt-get install grep`
|
||||
- **macOS**: grep is pre-installed
|
||||
- **Windows**: Use WSL or install grep via Git Bash/MSYS2
|
||||
|
||||
#### No Results Found
|
||||
```python
|
||||
# Check if your query exists in the raw data
|
||||
results = searcher.search("your_query", use_grep=True)
|
||||
if not results:
|
||||
print("No exact matches found. Try:")
|
||||
print("1. Check spelling and case")
|
||||
print("2. Use partial terms")
|
||||
print("3. Switch to semantic search")
|
||||
```
|
||||
|
||||
## Complete Example
|
||||
|
||||
```python
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Grep Search Example
|
||||
Demonstrates grep search for exact text matching.
|
||||
"""
|
||||
|
||||
from leann.api import LeannSearcher
|
||||
|
||||
def demonstrate_grep_search():
|
||||
# Initialize searcher
|
||||
searcher = LeannSearcher("my_index")
|
||||
|
||||
print("=== Function Search ===")
|
||||
functions = searcher.search("def __init__", use_grep=True, top_k=5)
|
||||
for i, result in enumerate(functions, 1):
|
||||
print(f"{i}. Score: {result.score}")
|
||||
print(f" Preview: {result.text[:60]}...")
|
||||
print()
|
||||
|
||||
print("=== Error Search ===")
|
||||
errors = searcher.search("FileNotFoundError", use_grep=True, top_k=3)
|
||||
for result in errors:
|
||||
print(f"Content: {result.text.strip()}")
|
||||
print("-" * 40)
|
||||
|
||||
if __name__ == "__main__":
|
||||
demonstrate_grep_search()
|
||||
```
|
||||
300
docs/metadata_filtering.md
Normal file
@@ -0,0 +1,300 @@
|
||||
# LEANN Metadata Filtering Usage Guide
|
||||
|
||||
## Overview
|
||||
|
||||
Leann possesses metadata filtering capabilities that allow you to filter search results based on arbitrary metadata fields set during chunking. This feature enables use cases like spoiler-free book search, document filtering by date/type, code search by file type, and potentially much more.
|
||||
|
||||
## Basic Usage
|
||||
|
||||
### Adding Metadata to Your Documents
|
||||
|
||||
When building your index, add metadata to each text chunk:
|
||||
|
||||
```python
|
||||
from leann.api import LeannBuilder
|
||||
|
||||
builder = LeannBuilder("hnsw")
|
||||
|
||||
# Add text with metadata
|
||||
builder.add_text(
|
||||
text="Chapter 1: Alice falls down the rabbit hole",
|
||||
metadata={
|
||||
"chapter": 1,
|
||||
"character": "Alice",
|
||||
"themes": ["adventure", "curiosity"],
|
||||
"word_count": 150
|
||||
}
|
||||
)
|
||||
|
||||
builder.build_index("alice_in_wonderland_index")
|
||||
```
|
||||
|
||||
### Searching with Metadata Filters
|
||||
|
||||
Use the `metadata_filters` parameter in search calls:
|
||||
|
||||
```python
|
||||
from leann.api import LeannSearcher
|
||||
|
||||
searcher = LeannSearcher("alice_in_wonderland_index")
|
||||
|
||||
# Search with filters
|
||||
results = searcher.search(
|
||||
query="What happens to Alice?",
|
||||
top_k=10,
|
||||
metadata_filters={
|
||||
"chapter": {"<=": 5}, # Only chapters 1-5
|
||||
"spoiler_level": {"!=": "high"} # No high spoilers
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
## Filter Syntax
|
||||
|
||||
### Basic Structure
|
||||
|
||||
```python
|
||||
metadata_filters = {
|
||||
"field_name": {"operator": value},
|
||||
"another_field": {"operator": value}
|
||||
}
|
||||
```
|
||||
|
||||
### Supported Operators
|
||||
|
||||
#### Comparison Operators
|
||||
- `"=="`: Equal to
|
||||
- `"!="`: Not equal to
|
||||
- `"<"`: Less than
|
||||
- `"<="`: Less than or equal
|
||||
- `">"`: Greater than
|
||||
- `">="`: Greater than or equal
|
||||
|
||||
```python
|
||||
# Examples
|
||||
{"chapter": {"==": 1}} # Exactly chapter 1
|
||||
{"page": {">": 100}} # Pages after 100
|
||||
{"rating": {">=": 4.0}} # Rating 4.0 or higher
|
||||
{"word_count": {"<": 500}} # Short passages
|
||||
```
|
||||
|
||||
#### Membership Operators
|
||||
- `"in"`: Value is in list
|
||||
- `"not_in"`: Value is not in list
|
||||
|
||||
```python
|
||||
# Examples
|
||||
{"character": {"in": ["Alice", "Bob"]}} # Alice OR Bob
|
||||
{"genre": {"not_in": ["horror", "thriller"]}} # Exclude genres
|
||||
{"tags": {"in": ["fiction", "adventure"]}} # Any of these tags
|
||||
```
|
||||
|
||||
#### String Operators
|
||||
- `"contains"`: String contains substring
|
||||
- `"starts_with"`: String starts with prefix
|
||||
- `"ends_with"`: String ends with suffix
|
||||
|
||||
```python
|
||||
# Examples
|
||||
{"title": {"contains": "alice"}} # Title contains "alice"
|
||||
{"filename": {"ends_with": ".py"}} # Python files
|
||||
{"author": {"starts_with": "Dr."}} # Authors with "Dr." prefix
|
||||
```
|
||||
|
||||
#### Boolean Operators
|
||||
- `"is_true"`: Field is truthy
|
||||
- `"is_false"`: Field is falsy
|
||||
|
||||
```python
|
||||
# Examples
|
||||
{"is_published": {"is_true": True}} # Published content
|
||||
{"is_draft": {"is_false": False}} # Not drafts
|
||||
```
|
||||
|
||||
### Multiple Operators on Same Field
|
||||
|
||||
You can apply multiple operators to the same field (AND logic):
|
||||
|
||||
```python
|
||||
metadata_filters = {
|
||||
"word_count": {
|
||||
">=": 100, # At least 100 words
|
||||
"<=": 500 # At most 500 words
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Compound Filters
|
||||
|
||||
Multiple fields are combined with AND logic:
|
||||
|
||||
```python
|
||||
metadata_filters = {
|
||||
"chapter": {"<=": 10}, # Up to chapter 10
|
||||
"character": {"==": "Alice"}, # About Alice
|
||||
"spoiler_level": {"!=": "high"} # No major spoilers
|
||||
}
|
||||
```
|
||||
|
||||
## Use Case Examples
|
||||
|
||||
### 1. Spoiler-Free Book Search
|
||||
|
||||
```python
|
||||
# Reader has only read up to chapter 5
|
||||
def search_spoiler_free(query, max_chapter):
|
||||
return searcher.search(
|
||||
query=query,
|
||||
metadata_filters={
|
||||
"chapter": {"<=": max_chapter},
|
||||
"spoiler_level": {"in": ["none", "low"]}
|
||||
}
|
||||
)
|
||||
|
||||
results = search_spoiler_free("What happens to Alice?", max_chapter=5)
|
||||
```
|
||||
|
||||
### 2. Document Management by Date
|
||||
|
||||
```python
|
||||
# Find recent documents
|
||||
recent_docs = searcher.search(
|
||||
query="project updates",
|
||||
metadata_filters={
|
||||
"date": {">=": "2024-01-01"},
|
||||
"document_type": {"==": "report"}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
### 3. Code Search by File Type
|
||||
|
||||
```python
|
||||
# Search only Python files
|
||||
python_code = searcher.search(
|
||||
query="authentication function",
|
||||
metadata_filters={
|
||||
"file_extension": {"==": ".py"},
|
||||
"lines_of_code": {"<": 100}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
### 4. Content Filtering by Audience
|
||||
|
||||
```python
|
||||
# Age-appropriate content
|
||||
family_content = searcher.search(
|
||||
query="adventure stories",
|
||||
metadata_filters={
|
||||
"age_rating": {"in": ["G", "PG"]},
|
||||
"content_warnings": {"not_in": ["violence", "adult_themes"]}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
### 5. Multi-Book Series Management
|
||||
|
||||
```python
|
||||
# Search across first 3 books only
|
||||
early_series = searcher.search(
|
||||
query="character development",
|
||||
metadata_filters={
|
||||
"series": {"==": "Harry Potter"},
|
||||
"book_number": {"<=": 3}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
## Running the Example
|
||||
|
||||
You can see metadata filtering in action with our spoiler-free book RAG example:
|
||||
|
||||
```bash
|
||||
# Don't forget to set up the environment
|
||||
uv venv
|
||||
source .venv/bin/activate
|
||||
|
||||
# Set your OpenAI API key (required for embeddings, but you can update the example locally and use ollama instead)
|
||||
export OPENAI_API_KEY="your-api-key-here"
|
||||
|
||||
# Run the spoiler-free book RAG example
|
||||
uv run examples/spoiler_free_book_rag.py
|
||||
```
|
||||
|
||||
This example demonstrates:
|
||||
- Building an index with metadata (chapter numbers, characters, themes, locations)
|
||||
- Searching with filters to avoid spoilers (e.g., only show results up to chapter 5)
|
||||
- Different scenarios for readers at various points in the book
|
||||
|
||||
The example uses Alice's Adventures in Wonderland as sample data and shows how you can search for information without revealing plot points from later chapters.
|
||||
|
||||
## Advanced Patterns
|
||||
|
||||
### Custom Chunking with metadata
|
||||
|
||||
```python
|
||||
def chunk_book_with_metadata(book_text, book_info):
|
||||
chunks = []
|
||||
|
||||
for chapter_num, chapter_text in parse_chapters(book_text):
|
||||
# Extract entities, themes, etc.
|
||||
characters = extract_characters(chapter_text)
|
||||
themes = classify_themes(chapter_text)
|
||||
spoiler_level = assess_spoiler_level(chapter_text, chapter_num)
|
||||
|
||||
# Create chunks with rich metadata
|
||||
for paragraph in split_paragraphs(chapter_text):
|
||||
chunks.append({
|
||||
"text": paragraph,
|
||||
"metadata": {
|
||||
"book_title": book_info["title"],
|
||||
"chapter": chapter_num,
|
||||
"characters": characters,
|
||||
"themes": themes,
|
||||
"spoiler_level": spoiler_level,
|
||||
"word_count": len(paragraph.split()),
|
||||
"reading_level": calculate_reading_level(paragraph)
|
||||
}
|
||||
})
|
||||
|
||||
return chunks
|
||||
```
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
### Efficient Filtering Strategies
|
||||
|
||||
1. **Post-search filtering**: Applies filters after vector search, which should be efficient for typical result sets (10-100 results).
|
||||
|
||||
2. **Metadata design**: Keep metadata fields simple and avoid deeply nested structures.
|
||||
|
||||
### Best Practices
|
||||
|
||||
1. **Consistent metadata schema**: Use consistent field names and value types across your documents.
|
||||
|
||||
2. **Reasonable metadata size**: Keep metadata reasonably sized to avoid storage overhead.
|
||||
|
||||
3. **Type consistency**: Use consistent data types for the same fields (e.g., always integers for chapter numbers).
|
||||
|
||||
4. **Index multiple granularities**: Consider chunking at different levels (paragraph, section, chapter) with appropriate metadata.
|
||||
|
||||
### Adding Metadata to Existing Indices
|
||||
|
||||
To add metadata filtering to existing indices, you'll need to rebuild them with metadata:
|
||||
|
||||
```python
|
||||
# Read existing passages and add metadata
|
||||
def add_metadata_to_existing_chunks(chunks):
|
||||
for chunk in chunks:
|
||||
# Extract or assign metadata based on content
|
||||
chunk["metadata"] = extract_metadata(chunk["text"])
|
||||
return chunks
|
||||
|
||||
# Rebuild index with metadata
|
||||
enhanced_chunks = add_metadata_to_existing_chunks(existing_chunks)
|
||||
builder = LeannBuilder("hnsw")
|
||||
for chunk in enhanced_chunks:
|
||||
builder.add_text(chunk["text"], chunk["metadata"])
|
||||
builder.build_index("enhanced_index")
|
||||
```
|
||||
395
docs/slack-setup-guide.md
Normal file
@@ -0,0 +1,395 @@
|
||||
# Slack Integration Setup Guide
|
||||
|
||||
This guide provides step-by-step instructions for setting up Slack integration with LEANN.
|
||||
|
||||
## Overview
|
||||
|
||||
LEANN's Slack integration uses MCP (Model Context Protocol) servers to fetch and index your Slack messages for RAG (Retrieval-Augmented Generation). This allows you to search through your Slack conversations using natural language queries.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
1. **Slack Workspace Access**: You need admin or owner permissions in your Slack workspace to create apps and configure OAuth tokens.
|
||||
|
||||
2. **Slack MCP Server**: Install a Slack MCP server (e.g., `slack-mcp-server` via npm)
|
||||
|
||||
3. **LEANN**: Ensure you have LEANN installed and working
|
||||
|
||||
## Step 1: Create a Slack App
|
||||
|
||||
### 1.1 Go to Slack API Dashboard
|
||||
|
||||
1. Visit [https://api.slack.com/apps](https://api.slack.com/apps)
|
||||
2. Click **"Create New App"**
|
||||
3. Choose **"From scratch"**
|
||||
4. Enter your app name (e.g., "LEANN Slack Integration")
|
||||
5. Select your workspace
|
||||
6. Click **"Create App"**
|
||||
|
||||
### 1.2 Configure App Permissions
|
||||
|
||||
#### Token Scopes
|
||||
|
||||
1. In your app dashboard, go to **"OAuth & Permissions"** in the left sidebar
|
||||
2. Scroll down to **"Scopes"** section
|
||||
3. Under **"Bot Token Scopes & OAuth Scope"**, click **"Add an OAuth Scope"**
|
||||
4. Add the following scopes:
|
||||
- `channels:read` - Read public channel information
|
||||
- `channels:history` - Read messages in public channels
|
||||
- `groups:read` - Read private channel information
|
||||
- `groups:history` - Read messages in private channels
|
||||
- `im:read` - Read direct message information
|
||||
- `im:history` - Read direct messages
|
||||
- `mpim:read` - Read group direct message information
|
||||
- `mpim:history` - Read group direct messages
|
||||
- `users:read` - Read user information
|
||||
- `team:read` - Read workspace information
|
||||
|
||||
#### App-Level Tokens (Optional)
|
||||
|
||||
Some MCP servers may require app-level tokens:
|
||||
|
||||
1. Go to **"Basic Information"** in the left sidebar
|
||||
2. Scroll down to **"App-Level Tokens"**
|
||||
3. Click **"Generate Token and Scopes"**
|
||||
4. Enter a name (e.g., "LEANN Integration")
|
||||
5. Add the `connections:write` scope
|
||||
6. Click **"Generate"**
|
||||
7. Copy the token (starts with `xapp-`)
|
||||
|
||||
### 1.3 Install App to Workspace
|
||||
|
||||
1. Go to **"OAuth & Permissions"** in the left sidebar
|
||||
2. Click **"Install to Workspace"**
|
||||
3. Review the permissions and click **"Allow"**
|
||||
4. Copy the **"Bot User OAuth Token"** (starts with `xoxb-`)
|
||||
5. Copy the **"User OAuth Token"** (starts with `xoxp-`)
|
||||
|
||||
## Step 2: Install Slack MCP Server
|
||||
|
||||
### Option A: Using npm (Recommended)
|
||||
|
||||
```bash
|
||||
# Install globally
|
||||
npm install -g slack-mcp-server
|
||||
|
||||
# Or install locally
|
||||
npm install slack-mcp-server
|
||||
```
|
||||
|
||||
### Option B: Using npx (No installation required)
|
||||
|
||||
```bash
|
||||
# Use directly without installation
|
||||
npx slack-mcp-server
|
||||
```
|
||||
|
||||
## Step 3: Install and Configure Ollama (for Real LLM Responses)
|
||||
|
||||
### 3.1 Install Ollama
|
||||
|
||||
```bash
|
||||
# Install Ollama using Homebrew (macOS)
|
||||
brew install ollama
|
||||
|
||||
# Or download from https://ollama.ai/
|
||||
```
|
||||
|
||||
### 3.2 Start Ollama Service
|
||||
|
||||
```bash
|
||||
# Start Ollama as a service
|
||||
brew services start ollama
|
||||
|
||||
# Or start manually
|
||||
ollama serve
|
||||
```
|
||||
|
||||
### 3.3 Pull a Model
|
||||
|
||||
```bash
|
||||
# Pull a lightweight model for testing
|
||||
ollama pull llama3.2:1b
|
||||
|
||||
# Verify the model is available
|
||||
ollama list
|
||||
```
|
||||
|
||||
## Step 4: Configure Environment Variables
|
||||
|
||||
Create a `.env` file or set environment variables:
|
||||
|
||||
```bash
|
||||
# Required: User OAuth Token
|
||||
SLACK_OAUTH_TOKEN=xoxp-your-user-oauth-token-here
|
||||
|
||||
# Optional: App-Level Token (if your MCP server requires it)
|
||||
SLACK_APP_TOKEN=xapp-your-app-token-here
|
||||
|
||||
# Optional: Workspace-specific settings
|
||||
SLACK_WORKSPACE_ID=T1234567890 # Your workspace ID (optional)
|
||||
```
|
||||
|
||||
## Step 5: Test the Setup
|
||||
|
||||
### 5.1 Test MCP Server Connection
|
||||
|
||||
```bash
|
||||
python -m apps.slack_rag \
|
||||
--mcp-server "slack-mcp-server" \
|
||||
--test-connection \
|
||||
--workspace-name "Your Workspace Name"
|
||||
```
|
||||
|
||||
This will test the connection and list available tools without indexing any data.
|
||||
|
||||
### 5.2 Index a Specific Channel
|
||||
|
||||
```bash
|
||||
python -m apps.slack_rag \
|
||||
--mcp-server "slack-mcp-server" \
|
||||
--workspace-name "Your Workspace Name" \
|
||||
--channels general \
|
||||
--query "What did we discuss about the project?"
|
||||
```
|
||||
|
||||
### 5.3 Real RAG Query Examples
|
||||
|
||||
This section demonstrates successful Slack RAG integration queries against the Sky Lab Computing workspace's "random" channel. The system successfully retrieves actual conversation messages and performs semantic search with high relevance scores, including finding specific research paper announcements and technical discussions.
|
||||
|
||||
### Example 1: Advisor Models Query
|
||||
|
||||
**Query:** "train black-box models to adopt to your personal data"
|
||||
|
||||
This query demonstrates the system's ability to find specific research announcements about training black-box models for personal data adaptation.
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
### Example 2: Barbarians at the Gate Query
|
||||
|
||||
**Query:** "AI-driven research systems ADRS"
|
||||
|
||||
This query demonstrates the system's ability to find specific research announcements about AI-driven research systems and algorithm discovery.
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Bot is installed in the Sky Lab Computing workspace and invited to the target channel (run `/invite @YourBotName` in the channel if needed)
|
||||
- Bot token available and exported in the same terminal session
|
||||
|
||||
### Commands
|
||||
|
||||
1) Set the workspace token for this shell
|
||||
|
||||
```bash
|
||||
export SLACK_MCP_XOXP_TOKEN="xoxp-***-redacted-***"
|
||||
```
|
||||
|
||||
2) Run queries against the "random" channel by channel ID (C0GN5BX0F)
|
||||
|
||||
**Advisor Models Query:**
|
||||
```bash
|
||||
python -m apps.slack_rag \
|
||||
--mcp-server "slack-mcp-server" \
|
||||
--workspace-name "Sky Lab Computing" \
|
||||
--channels C0GN5BX0F \
|
||||
--max-messages-per-channel 100000 \
|
||||
--query "train black-box models to adopt to your personal data" \
|
||||
--llm ollama \
|
||||
--llm-model "llama3.2:1b" \
|
||||
--llm-host "http://localhost:11434" \
|
||||
--no-concatenate-conversations
|
||||
```
|
||||
|
||||
**Barbarians at the Gate Query:**
|
||||
```bash
|
||||
python -m apps.slack_rag \
|
||||
--mcp-server "slack-mcp-server" \
|
||||
--workspace-name "Sky Lab Computing" \
|
||||
--channels C0GN5BX0F \
|
||||
--max-messages-per-channel 100000 \
|
||||
--query "AI-driven research systems ADRS" \
|
||||
--llm ollama \
|
||||
--llm-model "llama3.2:1b" \
|
||||
--llm-host "http://localhost:11434" \
|
||||
--no-concatenate-conversations
|
||||
```
|
||||
|
||||
These examples demonstrate the system's ability to find and retrieve specific research announcements and technical discussions from the conversation history, showcasing the power of semantic search in Slack data.
|
||||
|
||||
3) Optional: Ask a broader question
|
||||
|
||||
```bash
|
||||
python test_channel_by_id_or_name.py \
|
||||
--channel-id C0GN5BX0F \
|
||||
--workspace-name "Sky Lab Computing" \
|
||||
--query "What is LEANN about?"
|
||||
```
|
||||
|
||||
Notes:
|
||||
- If you see `not_in_channel`, invite the bot to the channel and re-run.
|
||||
- If you see `channel_not_found`, confirm the channel ID and workspace.
|
||||
- Deep search via server-side “search” tools may require additional Slack scopes; the example above performs client-side filtering over retrieved history.
|
||||
|
||||
## Common Issues and Solutions
|
||||
|
||||
### Issue 1: "users cache is not ready yet" Error
|
||||
|
||||
**Problem**: You see this warning:
|
||||
```
|
||||
WARNING - Failed to fetch messages from channel random: Failed to fetch messages: {'code': -32603, 'message': 'users cache is not ready yet, sync process is still running... please wait'}
|
||||
```
|
||||
|
||||
**Solution**: This is a common timing issue. The LEANN integration now includes automatic retry logic:
|
||||
|
||||
1. **Wait and Retry**: The system will automatically retry with exponential backoff (2s, 4s, 8s, etc.)
|
||||
2. **Increase Retry Parameters**: If needed, you can customize retry behavior:
|
||||
```bash
|
||||
python -m apps.slack_rag \
|
||||
--mcp-server "slack-mcp-server" \
|
||||
--max-retries 10 \
|
||||
--retry-delay 3.0 \
|
||||
--channels general \
|
||||
--query "Your query here"
|
||||
```
|
||||
3. **Keep MCP Server Running**: Start the MCP server separately and keep it running:
|
||||
```bash
|
||||
# Terminal 1: Start MCP server
|
||||
slack-mcp-server
|
||||
|
||||
# Terminal 2: Run LEANN (it will connect to the running server)
|
||||
python -m apps.slack_rag --mcp-server "slack-mcp-server" --channels general --query "test"
|
||||
```
|
||||
|
||||
### Issue 2: "No message fetching tool found"
|
||||
|
||||
**Problem**: The MCP server doesn't have the expected tools.
|
||||
|
||||
**Solution**:
|
||||
1. Check if your MCP server is properly installed and configured
|
||||
2. Verify your Slack tokens are correct
|
||||
3. Try a different MCP server implementation
|
||||
4. Check the MCP server documentation for required configuration
|
||||
|
||||
### Issue 3: Permission Denied Errors
|
||||
|
||||
**Problem**: You get permission errors when trying to access channels.
|
||||
|
||||
**Solutions**:
|
||||
1. **Check Bot Permissions**: Ensure your bot has been added to the channels you want to access
|
||||
2. **Verify Token Scopes**: Make sure you have all required scopes configured
|
||||
3. **Channel Access**: For private channels, the bot needs to be explicitly invited
|
||||
4. **Workspace Permissions**: Ensure your Slack app has the necessary workspace permissions
|
||||
|
||||
### Issue 4: Empty Results
|
||||
|
||||
**Problem**: No messages are returned even though the channel has messages.
|
||||
|
||||
**Solutions**:
|
||||
1. **Check Channel Names**: Ensure channel names are correct (without the # symbol)
|
||||
2. **Verify Bot Access**: Make sure the bot can access the channels
|
||||
3. **Check Date Ranges**: Some MCP servers have limitations on message history
|
||||
4. **Increase Message Limits**: Try increasing the message limit:
|
||||
```bash
|
||||
python -m apps.slack_rag \
|
||||
--mcp-server "slack-mcp-server" \
|
||||
--channels general \
|
||||
--max-messages-per-channel 1000 \
|
||||
--query "test"
|
||||
```
|
||||
|
||||
## Advanced Configuration
|
||||
|
||||
### Custom MCP Server Commands
|
||||
|
||||
If you need to pass additional parameters to your MCP server:
|
||||
|
||||
```bash
|
||||
python -m apps.slack_rag \
|
||||
--mcp-server "slack-mcp-server --token-file /path/to/tokens.json" \
|
||||
--workspace-name "Your Workspace" \
|
||||
--channels general \
|
||||
--query "Your query"
|
||||
```
|
||||
|
||||
### Multiple Workspaces
|
||||
|
||||
To work with multiple Slack workspaces, you can:
|
||||
|
||||
1. Create separate apps for each workspace
|
||||
2. Use different environment variables
|
||||
3. Run separate instances with different configurations
|
||||
|
||||
### Performance Optimization
|
||||
|
||||
For better performance with large workspaces:
|
||||
|
||||
```bash
|
||||
python -m apps.slack_rag \
|
||||
--mcp-server "slack-mcp-server" \
|
||||
--workspace-name "Your Workspace" \
|
||||
--max-messages-per-channel 500 \
|
||||
--no-concatenate-conversations \
|
||||
--query "Your query"
|
||||
```
|
||||
---
|
||||
|
||||
## Troubleshooting Checklist
|
||||
|
||||
- [ ] Slack app created with proper permissions
|
||||
- [ ] Bot token (xoxb-) copied correctly
|
||||
- [ ] App-level token (xapp-) created if needed
|
||||
- [ ] MCP server installed and accessible
|
||||
- [ ] Ollama installed and running (`brew services start ollama`)
|
||||
- [ ] Ollama model pulled (`ollama pull llama3.2:1b`)
|
||||
- [ ] Environment variables set correctly
|
||||
- [ ] Bot invited to relevant channels
|
||||
- [ ] Channel names specified without # symbol
|
||||
- [ ] Sufficient retry attempts configured
|
||||
- [ ] Network connectivity to Slack APIs
|
||||
|
||||
## Getting Help
|
||||
|
||||
If you continue to have issues:
|
||||
|
||||
1. **Check Logs**: Look for detailed error messages in the console output
|
||||
2. **Test MCP Server**: Use `--test-connection` to verify the MCP server is working
|
||||
3. **Verify Tokens**: Double-check that your Slack tokens are valid and have the right scopes
|
||||
4. **Check Ollama**: Ensure Ollama is running (`ollama serve`) and the model is available (`ollama list`)
|
||||
5. **Community Support**: Reach out to the LEANN community for help
|
||||
|
||||
## Example Commands
|
||||
|
||||
### Basic Usage
|
||||
```bash
|
||||
# Test connection
|
||||
python -m apps.slack_rag --mcp-server "slack-mcp-server" --test-connection
|
||||
|
||||
# Index specific channels
|
||||
python -m apps.slack_rag \
|
||||
--mcp-server "slack-mcp-server" \
|
||||
--workspace-name "My Company" \
|
||||
--channels general random \
|
||||
--query "What did we decide about the project timeline?"
|
||||
```
|
||||
|
||||
### Advanced Usage
|
||||
```bash
|
||||
# With custom retry settings
|
||||
python -m apps.slack_rag \
|
||||
--mcp-server "slack-mcp-server" \
|
||||
--workspace-name "My Company" \
|
||||
--channels general \
|
||||
--max-retries 10 \
|
||||
--retry-delay 5.0 \
|
||||
--max-messages-per-channel 2000 \
|
||||
--query "Show me all decisions made in the last month"
|
||||
```
|
||||
BIN
docs/videos/slack_integration_1.1.png
Normal file
|
After Width: | Height: | Size: 445 KiB |
BIN
docs/videos/slack_integration_1.2.png
Normal file
|
After Width: | Height: | Size: 508 KiB |
BIN
docs/videos/slack_integration_1.3.png
Normal file
|
After Width: | Height: | Size: 437 KiB |
BIN
docs/videos/slack_integration_2.1.png
Normal file
|
After Width: | Height: | Size: 474 KiB |
BIN
docs/videos/slack_integration_2.2.png
Normal file
|
After Width: | Height: | Size: 501 KiB |
BIN
docs/videos/slack_integration_2.3.png
Normal file
|
After Width: | Height: | Size: 454 KiB |
0
examples/__init__.py
Normal file
429
examples/dynamic_update_no_recompute.py
Normal file
@@ -0,0 +1,429 @@
|
||||
"""Dynamic HNSW update demo without compact storage.
|
||||
|
||||
This script reproduces the minimal scenario we used while debugging on-the-fly
|
||||
recompute:
|
||||
|
||||
1. Build a non-compact HNSW index from the first few paragraphs of a text file.
|
||||
2. Print the top results with `recompute_embeddings=True`.
|
||||
3. Append additional paragraphs with :meth:`LeannBuilder.update_index`.
|
||||
4. Run the same query again to show the newly inserted passages.
|
||||
|
||||
Run it with ``uv`` (optionally pointing LEANN_HNSW_LOG_PATH at a file to inspect
|
||||
ZMQ activity)::
|
||||
|
||||
LEANN_HNSW_LOG_PATH=embedding_fetch.log \
|
||||
uv run -m examples.dynamic_update_no_recompute \
|
||||
--index-path .leann/examples/leann-demo.leann
|
||||
|
||||
By default the script builds an index from ``data/2501.14312v1 (1).pdf`` and
|
||||
then updates it with LEANN-related material from ``data/2506.08276v1.pdf``.
|
||||
It issues the query "What's LEANN?" before and after the update to show how the
|
||||
new passages become immediately searchable. The script uses the
|
||||
``sentence-transformers/all-MiniLM-L6-v2`` model with ``is_recompute=True`` so
|
||||
Faiss pulls existing vectors on demand via the ZMQ embedding server, while
|
||||
freshly added passages are embedded locally just like the initial build.
|
||||
|
||||
To make storage comparisons easy, the script can also build a matching
|
||||
``is_recompute=False`` baseline (enabled by default) and report the index size
|
||||
delta after the update. Disable the baseline run with
|
||||
``--skip-compare-no-recompute`` if you only need the recompute flow.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from collections.abc import Iterable
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from leann.api import LeannBuilder, LeannSearcher
|
||||
from leann.registry import register_project_directory
|
||||
|
||||
from apps.chunking import create_text_chunks
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||
|
||||
DEFAULT_QUERY = "What's LEANN?"
|
||||
DEFAULT_INITIAL_FILES = [
|
||||
REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
|
||||
REPO_ROOT / "data" / "huawei_pangu.md",
|
||||
REPO_ROOT / "data" / "PrideandPrejudice.txt",
|
||||
]
|
||||
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
|
||||
|
||||
|
||||
def load_chunks_from_files(paths: list[Path]) -> list[str]:
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
|
||||
documents = []
|
||||
for path in paths:
|
||||
p = path.expanduser().resolve()
|
||||
if not p.exists():
|
||||
raise FileNotFoundError(f"Input path not found: {p}")
|
||||
if p.is_dir():
|
||||
reader = SimpleDirectoryReader(str(p), recursive=False)
|
||||
documents.extend(reader.load_data(show_progress=True))
|
||||
else:
|
||||
reader = SimpleDirectoryReader(input_files=[str(p)])
|
||||
documents.extend(reader.load_data(show_progress=True))
|
||||
|
||||
if not documents:
|
||||
return []
|
||||
|
||||
chunks = create_text_chunks(
|
||||
documents,
|
||||
chunk_size=512,
|
||||
chunk_overlap=128,
|
||||
use_ast_chunking=False,
|
||||
)
|
||||
return [c for c in chunks if isinstance(c, str) and c.strip()]
|
||||
|
||||
|
||||
def run_search(index_path: Path, query: str, top_k: int, *, recompute_embeddings: bool) -> list:
|
||||
searcher = LeannSearcher(str(index_path))
|
||||
try:
|
||||
return searcher.search(
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
recompute_embeddings=recompute_embeddings,
|
||||
batch_size=16,
|
||||
)
|
||||
finally:
|
||||
searcher.cleanup()
|
||||
|
||||
|
||||
def print_results(title: str, results: Iterable) -> None:
|
||||
print(f"\n=== {title} ===")
|
||||
res_list = list(results)
|
||||
print(f"results count: {len(res_list)}")
|
||||
print("passages:")
|
||||
if not res_list:
|
||||
print(" (no passages returned)")
|
||||
for res in res_list:
|
||||
snippet = res.text.replace("\n", " ")[:120]
|
||||
print(f" - {res.id}: {snippet}... (score={res.score:.4f})")
|
||||
|
||||
|
||||
def build_initial_index(
|
||||
index_path: Path,
|
||||
paragraphs: list[str],
|
||||
model_name: str,
|
||||
embedding_mode: str,
|
||||
is_recompute: bool,
|
||||
) -> None:
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=model_name,
|
||||
embedding_mode=embedding_mode,
|
||||
is_compact=False,
|
||||
is_recompute=is_recompute,
|
||||
)
|
||||
for idx, passage in enumerate(paragraphs):
|
||||
builder.add_text(passage, metadata={"id": str(idx)})
|
||||
builder.build_index(str(index_path))
|
||||
|
||||
|
||||
def update_index(
|
||||
index_path: Path,
|
||||
start_id: int,
|
||||
paragraphs: list[str],
|
||||
model_name: str,
|
||||
embedding_mode: str,
|
||||
is_recompute: bool,
|
||||
) -> None:
|
||||
updater = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=model_name,
|
||||
embedding_mode=embedding_mode,
|
||||
is_compact=False,
|
||||
is_recompute=is_recompute,
|
||||
)
|
||||
for offset, passage in enumerate(paragraphs, start=start_id):
|
||||
updater.add_text(passage, metadata={"id": str(offset)})
|
||||
updater.update_index(str(index_path))
|
||||
|
||||
|
||||
def ensure_index_dir(index_path: Path) -> None:
|
||||
index_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def cleanup_index_files(index_path: Path) -> None:
|
||||
"""Remove leftover index artifacts for a clean rebuild."""
|
||||
|
||||
parent = index_path.parent
|
||||
if not parent.exists():
|
||||
return
|
||||
stem = index_path.stem
|
||||
for file in parent.glob(f"{stem}*"):
|
||||
if file.is_file():
|
||||
file.unlink()
|
||||
|
||||
|
||||
def index_file_size(index_path: Path) -> int:
|
||||
"""Return the size of the primary .index file for the given index path."""
|
||||
|
||||
index_file = index_path.parent / f"{index_path.stem}.index"
|
||||
return index_file.stat().st_size if index_file.exists() else 0
|
||||
|
||||
|
||||
def load_metadata_snapshot(index_path: Path) -> dict[str, Any] | None:
|
||||
meta_path = index_path.parent / f"{index_path.name}.meta.json"
|
||||
if not meta_path.exists():
|
||||
return None
|
||||
try:
|
||||
return json.loads(meta_path.read_text())
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
|
||||
def run_workflow(
|
||||
*,
|
||||
label: str,
|
||||
index_path: Path,
|
||||
initial_paragraphs: list[str],
|
||||
update_paragraphs: list[str],
|
||||
model_name: str,
|
||||
embedding_mode: str,
|
||||
is_recompute: bool,
|
||||
query: str,
|
||||
top_k: int,
|
||||
skip_search: bool,
|
||||
) -> dict[str, Any]:
|
||||
prefix = f"[{label}] " if label else ""
|
||||
|
||||
ensure_index_dir(index_path)
|
||||
cleanup_index_files(index_path)
|
||||
|
||||
print(f"{prefix}Building initial index...")
|
||||
build_initial_index(
|
||||
index_path,
|
||||
initial_paragraphs,
|
||||
model_name,
|
||||
embedding_mode,
|
||||
is_recompute=is_recompute,
|
||||
)
|
||||
|
||||
initial_size = index_file_size(index_path)
|
||||
if not skip_search:
|
||||
before_results = run_search(
|
||||
index_path,
|
||||
query,
|
||||
top_k,
|
||||
recompute_embeddings=is_recompute,
|
||||
)
|
||||
else:
|
||||
before_results = None
|
||||
|
||||
print(f"\n{prefix}Updating index with additional passages...")
|
||||
update_index(
|
||||
index_path,
|
||||
start_id=len(initial_paragraphs),
|
||||
paragraphs=update_paragraphs,
|
||||
model_name=model_name,
|
||||
embedding_mode=embedding_mode,
|
||||
is_recompute=is_recompute,
|
||||
)
|
||||
|
||||
if not skip_search:
|
||||
after_results = run_search(
|
||||
index_path,
|
||||
query,
|
||||
top_k,
|
||||
recompute_embeddings=is_recompute,
|
||||
)
|
||||
else:
|
||||
after_results = None
|
||||
updated_size = index_file_size(index_path)
|
||||
|
||||
return {
|
||||
"initial_size": initial_size,
|
||||
"updated_size": updated_size,
|
||||
"delta": updated_size - initial_size,
|
||||
"before_results": before_results if not skip_search else None,
|
||||
"after_results": after_results if not skip_search else None,
|
||||
"metadata": load_metadata_snapshot(index_path),
|
||||
}
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--initial-files",
|
||||
type=Path,
|
||||
nargs="+",
|
||||
default=DEFAULT_INITIAL_FILES,
|
||||
help="Initial document files (PDF/TXT) used to build the base index",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index-path",
|
||||
type=Path,
|
||||
default=Path(".leann/examples/leann-demo.leann"),
|
||||
help="Destination index path (default: .leann/examples/leann-demo.leann)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--initial-count",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Number of chunks to use from the initial documents (default: 8)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--update-files",
|
||||
type=Path,
|
||||
nargs="*",
|
||||
default=DEFAULT_UPDATE_FILES,
|
||||
help="Additional documents to add during update (PDF/TXT)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--update-count",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of chunks to append from update documents (default: 4)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--update-text",
|
||||
type=str,
|
||||
default=(
|
||||
"LEANN (Lightweight Embedding ANN) is an indexing toolkit focused on "
|
||||
"recompute-aware HNSW graphs, allowing embeddings to be regenerated "
|
||||
"on demand to keep disk usage minimal."
|
||||
),
|
||||
help="Fallback text to append if --update-files is omitted",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of results to show for each search (default: 4)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--query",
|
||||
type=str,
|
||||
default=DEFAULT_QUERY,
|
||||
help="Query to run before/after the update",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-model",
|
||||
type=str,
|
||||
default="sentence-transformers/all-MiniLM-L6-v2",
|
||||
help="Embedding model name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-mode",
|
||||
type=str,
|
||||
default="sentence-transformers",
|
||||
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||
help="Embedding backend mode",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--compare-no-recompute",
|
||||
dest="compare_no_recompute",
|
||||
action="store_true",
|
||||
help="Also run a baseline with is_recompute=False and report its index growth.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-compare-no-recompute",
|
||||
dest="compare_no_recompute",
|
||||
action="store_false",
|
||||
help="Skip building the no-recompute baseline.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-search",
|
||||
dest="skip_search",
|
||||
action="store_true",
|
||||
help="Skip the search step.",
|
||||
)
|
||||
parser.set_defaults(compare_no_recompute=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
ensure_index_dir(args.index_path)
|
||||
register_project_directory(REPO_ROOT)
|
||||
|
||||
initial_chunks = load_chunks_from_files(list(args.initial_files))
|
||||
if not initial_chunks:
|
||||
raise ValueError("No text chunks extracted from the initial files.")
|
||||
|
||||
initial = initial_chunks[: args.initial_count]
|
||||
if not initial:
|
||||
raise ValueError("Initial chunk set is empty after applying --initial-count.")
|
||||
|
||||
if args.update_files:
|
||||
update_chunks = load_chunks_from_files(list(args.update_files))
|
||||
if not update_chunks:
|
||||
raise ValueError("No text chunks extracted from the update files.")
|
||||
to_add = update_chunks[: args.update_count]
|
||||
else:
|
||||
if not args.update_text:
|
||||
raise ValueError("Provide --update-files or --update-text for the update step.")
|
||||
to_add = [args.update_text]
|
||||
if not to_add:
|
||||
raise ValueError("Update chunk set is empty after applying --update-count.")
|
||||
|
||||
recompute_stats = run_workflow(
|
||||
label="recompute",
|
||||
index_path=args.index_path,
|
||||
initial_paragraphs=initial,
|
||||
update_paragraphs=to_add,
|
||||
model_name=args.embedding_model,
|
||||
embedding_mode=args.embedding_mode,
|
||||
is_recompute=True,
|
||||
query=args.query,
|
||||
top_k=args.top_k,
|
||||
skip_search=args.skip_search,
|
||||
)
|
||||
|
||||
if not args.skip_search:
|
||||
print_results("initial search", recompute_stats["before_results"])
|
||||
if not args.skip_search:
|
||||
print_results("after update", recompute_stats["after_results"])
|
||||
print(
|
||||
f"\n[recompute] Index file size change: {recompute_stats['initial_size']} -> {recompute_stats['updated_size']} bytes"
|
||||
f" (Δ {recompute_stats['delta']})"
|
||||
)
|
||||
|
||||
if recompute_stats["metadata"]:
|
||||
meta_view = {k: recompute_stats["metadata"].get(k) for k in ("is_compact", "is_pruned")}
|
||||
print("[recompute] metadata snapshot:")
|
||||
print(json.dumps(meta_view, indent=2))
|
||||
|
||||
if args.compare_no_recompute:
|
||||
baseline_path = (
|
||||
args.index_path.parent / f"{args.index_path.stem}-norecompute{args.index_path.suffix}"
|
||||
)
|
||||
baseline_stats = run_workflow(
|
||||
label="no-recompute",
|
||||
index_path=baseline_path,
|
||||
initial_paragraphs=initial,
|
||||
update_paragraphs=to_add,
|
||||
model_name=args.embedding_model,
|
||||
embedding_mode=args.embedding_mode,
|
||||
is_recompute=False,
|
||||
query=args.query,
|
||||
top_k=args.top_k,
|
||||
skip_search=args.skip_search,
|
||||
)
|
||||
|
||||
print(
|
||||
f"\n[no-recompute] Index file size change: {baseline_stats['initial_size']} -> {baseline_stats['updated_size']} bytes"
|
||||
f" (Δ {baseline_stats['delta']})"
|
||||
)
|
||||
|
||||
after_texts = (
|
||||
[res.text for res in recompute_stats["after_results"]] if not args.skip_search else None
|
||||
)
|
||||
baseline_after_texts = (
|
||||
[res.text for res in baseline_stats["after_results"]] if not args.skip_search else None
|
||||
)
|
||||
if after_texts == baseline_after_texts:
|
||||
print(
|
||||
"[no-recompute] Search results match recompute baseline; see above for the shared output."
|
||||
)
|
||||
else:
|
||||
print("[no-recompute] WARNING: search results differ from recompute baseline.")
|
||||
|
||||
if baseline_stats["metadata"]:
|
||||
meta_view = {k: baseline_stats["metadata"].get(k) for k in ("is_compact", "is_pruned")}
|
||||
print("[no-recompute] metadata snapshot:")
|
||||
print(json.dumps(meta_view, indent=2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
35
examples/grep_search_example.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
Grep Search Example
|
||||
|
||||
Shows how to use grep-based text search instead of semantic search.
|
||||
Useful when you need exact text matches rather than meaning-based results.
|
||||
"""
|
||||
|
||||
from leann import LeannSearcher
|
||||
|
||||
# Load your index
|
||||
searcher = LeannSearcher("my-documents.leann")
|
||||
|
||||
# Regular semantic search
|
||||
print("=== Semantic Search ===")
|
||||
results = searcher.search("machine learning algorithms", top_k=3)
|
||||
for result in results:
|
||||
print(f"Score: {result.score:.3f}")
|
||||
print(f"Text: {result.text[:80]}...")
|
||||
print()
|
||||
|
||||
# Grep-based search for exact text matches
|
||||
print("=== Grep Search ===")
|
||||
results = searcher.search("def train_model", top_k=3, use_grep=True)
|
||||
for result in results:
|
||||
print(f"Score: {result.score}")
|
||||
print(f"Text: {result.text[:80]}...")
|
||||
print()
|
||||
|
||||
# Find specific error messages
|
||||
error_results = searcher.search("FileNotFoundError", use_grep=True)
|
||||
print(f"Found {len(error_results)} files mentioning FileNotFoundError")
|
||||
|
||||
# Search for function definitions
|
||||
func_results = searcher.search("class SearchResult", use_grep=True, top_k=5)
|
||||
print(f"Found {len(func_results)} class definitions")
|
||||
178
examples/mcp_integration_demo.py
Normal file
@@ -0,0 +1,178 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
MCP Integration Examples for LEANN
|
||||
|
||||
This script demonstrates how to use LEANN with different MCP servers for
|
||||
RAG on various platforms like Slack and Twitter.
|
||||
|
||||
Examples:
|
||||
1. Slack message RAG via MCP
|
||||
2. Twitter bookmark RAG via MCP
|
||||
3. Testing MCP server connections
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add the parent directory to the path so we can import from apps
|
||||
sys.path.append(str(Path(__file__).parent.parent))
|
||||
|
||||
|
||||
async def demo_slack_mcp():
|
||||
"""Demonstrate Slack MCP integration."""
|
||||
print("=" * 60)
|
||||
print("🔥 Slack MCP RAG Demo")
|
||||
print("=" * 60)
|
||||
|
||||
print("\n1. Testing Slack MCP server connection...")
|
||||
|
||||
# This would typically use a real MCP server command
|
||||
# For demo purposes, we show what the command would look like
|
||||
# slack_app = SlackMCPRAG() # Would be used for actual testing
|
||||
|
||||
# Simulate command line arguments for testing
|
||||
class MockArgs:
|
||||
mcp_server = "slack-mcp-server" # This would be the actual MCP server command
|
||||
workspace_name = "my-workspace"
|
||||
channels = ["general", "random", "dev-team"]
|
||||
no_concatenate_conversations = False
|
||||
max_messages_per_channel = 50
|
||||
test_connection = True
|
||||
|
||||
print(f"MCP Server Command: {MockArgs.mcp_server}")
|
||||
print(f"Workspace: {MockArgs.workspace_name}")
|
||||
print(f"Channels: {', '.join(MockArgs.channels)}")
|
||||
|
||||
# In a real scenario, you would run:
|
||||
# success = await slack_app.test_mcp_connection(MockArgs)
|
||||
|
||||
print("\n📝 Example usage:")
|
||||
print("python -m apps.slack_rag \\")
|
||||
print(" --mcp-server 'slack-mcp-server' \\")
|
||||
print(" --workspace-name 'my-team' \\")
|
||||
print(" --channels general dev-team \\")
|
||||
print(" --test-connection")
|
||||
|
||||
print("\n🔍 After indexing, you could query:")
|
||||
print("- 'What did the team discuss about the project deadline?'")
|
||||
print("- 'Find messages about the new feature launch'")
|
||||
print("- 'Show me conversations about budget planning'")
|
||||
|
||||
|
||||
async def demo_twitter_mcp():
|
||||
"""Demonstrate Twitter MCP integration."""
|
||||
print("\n" + "=" * 60)
|
||||
print("🐦 Twitter MCP RAG Demo")
|
||||
print("=" * 60)
|
||||
|
||||
print("\n1. Testing Twitter MCP server connection...")
|
||||
|
||||
# twitter_app = TwitterMCPRAG() # Would be used for actual testing
|
||||
|
||||
class MockArgs:
|
||||
mcp_server = "twitter-mcp-server"
|
||||
username = None # Fetch all bookmarks
|
||||
max_bookmarks = 500
|
||||
no_tweet_content = False
|
||||
no_metadata = False
|
||||
test_connection = True
|
||||
|
||||
print(f"MCP Server Command: {MockArgs.mcp_server}")
|
||||
print(f"Max Bookmarks: {MockArgs.max_bookmarks}")
|
||||
print(f"Include Content: {not MockArgs.no_tweet_content}")
|
||||
print(f"Include Metadata: {not MockArgs.no_metadata}")
|
||||
|
||||
print("\n📝 Example usage:")
|
||||
print("python -m apps.twitter_rag \\")
|
||||
print(" --mcp-server 'twitter-mcp-server' \\")
|
||||
print(" --max-bookmarks 1000 \\")
|
||||
print(" --test-connection")
|
||||
|
||||
print("\n🔍 After indexing, you could query:")
|
||||
print("- 'What AI articles did I bookmark last month?'")
|
||||
print("- 'Find tweets about machine learning techniques'")
|
||||
print("- 'Show me bookmarked threads about startup advice'")
|
||||
|
||||
|
||||
async def show_mcp_server_setup():
|
||||
"""Show how to set up MCP servers."""
|
||||
print("\n" + "=" * 60)
|
||||
print("⚙️ MCP Server Setup Guide")
|
||||
print("=" * 60)
|
||||
|
||||
print("\n🔧 Setting up Slack MCP Server:")
|
||||
print("1. Install a Slack MCP server (example commands):")
|
||||
print(" npm install -g slack-mcp-server")
|
||||
print(" # OR")
|
||||
print(" pip install slack-mcp-server")
|
||||
|
||||
print("\n2. Configure Slack credentials:")
|
||||
print(" export SLACK_BOT_TOKEN='xoxb-your-bot-token'")
|
||||
print(" export SLACK_APP_TOKEN='xapp-your-app-token'")
|
||||
|
||||
print("\n3. Test the server:")
|
||||
print(" slack-mcp-server --help")
|
||||
|
||||
print("\n🔧 Setting up Twitter MCP Server:")
|
||||
print("1. Install a Twitter MCP server:")
|
||||
print(" npm install -g twitter-mcp-server")
|
||||
print(" # OR")
|
||||
print(" pip install twitter-mcp-server")
|
||||
|
||||
print("\n2. Configure Twitter API credentials:")
|
||||
print(" export TWITTER_API_KEY='your-api-key'")
|
||||
print(" export TWITTER_API_SECRET='your-api-secret'")
|
||||
print(" export TWITTER_ACCESS_TOKEN='your-access-token'")
|
||||
print(" export TWITTER_ACCESS_TOKEN_SECRET='your-access-token-secret'")
|
||||
|
||||
print("\n3. Test the server:")
|
||||
print(" twitter-mcp-server --help")
|
||||
|
||||
|
||||
async def show_integration_benefits():
|
||||
"""Show the benefits of MCP integration."""
|
||||
print("\n" + "=" * 60)
|
||||
print("🌟 Benefits of MCP Integration")
|
||||
print("=" * 60)
|
||||
|
||||
benefits = [
|
||||
("🔄 Live Data Access", "Fetch real-time data from platforms without manual exports"),
|
||||
("🔌 Standardized Protocol", "Use any MCP-compatible server with minimal code changes"),
|
||||
("🚀 Easy Extension", "Add new platforms by implementing MCP readers"),
|
||||
("🔒 Secure Access", "MCP servers handle authentication and API management"),
|
||||
("📊 Rich Metadata", "Access full platform metadata (timestamps, engagement, etc.)"),
|
||||
("⚡ Efficient Processing", "Stream data directly into LEANN without intermediate files"),
|
||||
]
|
||||
|
||||
for title, description in benefits:
|
||||
print(f"\n{title}")
|
||||
print(f" {description}")
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main demo function."""
|
||||
print("🎯 LEANN MCP Integration Examples")
|
||||
print("This demo shows how to integrate LEANN with MCP servers for various platforms.")
|
||||
|
||||
await demo_slack_mcp()
|
||||
await demo_twitter_mcp()
|
||||
await show_mcp_server_setup()
|
||||
await show_integration_benefits()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("✨ Next Steps")
|
||||
print("=" * 60)
|
||||
print("1. Install and configure MCP servers for your platforms")
|
||||
print("2. Test connections using --test-connection flag")
|
||||
print("3. Run indexing to build your RAG knowledge base")
|
||||
print("4. Start querying your personal data!")
|
||||
|
||||
print("\n📚 For more information:")
|
||||
print("- Check the README for detailed setup instructions")
|
||||
print("- Look at the apps/slack_rag.py and apps/twitter_rag.py for implementation details")
|
||||
print("- Explore other MCP servers for additional platforms")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
250
examples/spoiler_free_book_rag.py
Normal file
@@ -0,0 +1,250 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Spoiler-Free Book RAG Example using LEANN Metadata Filtering
|
||||
|
||||
This example demonstrates how to use LEANN's metadata filtering to create
|
||||
a spoiler-free book RAG system where users can search for information
|
||||
up to a specific chapter they've read.
|
||||
|
||||
Usage:
|
||||
python spoiler_free_book_rag.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Optional
|
||||
|
||||
# Add LEANN to path (adjust path as needed)
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../packages/leann-core/src"))
|
||||
|
||||
from leann.api import LeannBuilder, LeannSearcher
|
||||
|
||||
|
||||
def chunk_book_with_metadata(book_title: str = "Sample Book") -> list[dict[str, Any]]:
|
||||
"""
|
||||
Create sample book chunks with metadata for demonstration.
|
||||
|
||||
In a real implementation, this would parse actual book files (epub, txt, etc.)
|
||||
and extract chapter boundaries, character mentions, etc.
|
||||
|
||||
Args:
|
||||
book_title: Title of the book
|
||||
|
||||
Returns:
|
||||
List of chunk dictionaries with text and metadata
|
||||
"""
|
||||
# Sample book chunks with metadata
|
||||
# In practice, you'd use proper text processing libraries
|
||||
|
||||
sample_chunks = [
|
||||
{
|
||||
"text": "Alice was beginning to get very tired of sitting by her sister on the bank, and of having nothing to do.",
|
||||
"metadata": {
|
||||
"book": book_title,
|
||||
"chapter": 1,
|
||||
"page": 1,
|
||||
"characters": ["Alice", "Sister"],
|
||||
"themes": ["boredom", "curiosity"],
|
||||
"location": "riverbank",
|
||||
},
|
||||
},
|
||||
{
|
||||
"text": "So she was considering in her own mind (as well as she could, for the hot day made her feel very sleepy and stupid), whether the pleasure of making a daisy-chain would be worth the trouble of getting up and picking the daisies, when suddenly a White Rabbit with pink eyes ran close by her.",
|
||||
"metadata": {
|
||||
"book": book_title,
|
||||
"chapter": 1,
|
||||
"page": 2,
|
||||
"characters": ["Alice", "White Rabbit"],
|
||||
"themes": ["decision", "surprise", "magic"],
|
||||
"location": "riverbank",
|
||||
},
|
||||
},
|
||||
{
|
||||
"text": "Alice found herself falling down a very deep well. Either the well was very deep, or she fell very slowly, for she had plenty of time as she fell to look about her and to wonder what was going to happen next.",
|
||||
"metadata": {
|
||||
"book": book_title,
|
||||
"chapter": 2,
|
||||
"page": 15,
|
||||
"characters": ["Alice"],
|
||||
"themes": ["falling", "wonder", "transformation"],
|
||||
"location": "rabbit hole",
|
||||
},
|
||||
},
|
||||
{
|
||||
"text": "Alice meets the Cheshire Cat, who tells her that everyone in Wonderland is mad, including Alice herself.",
|
||||
"metadata": {
|
||||
"book": book_title,
|
||||
"chapter": 6,
|
||||
"page": 85,
|
||||
"characters": ["Alice", "Cheshire Cat"],
|
||||
"themes": ["madness", "philosophy", "identity"],
|
||||
"location": "Duchess's house",
|
||||
},
|
||||
},
|
||||
{
|
||||
"text": "At the Queen's croquet ground, Alice witnesses the absurd trial that reveals the arbitrary nature of Wonderland's justice system.",
|
||||
"metadata": {
|
||||
"book": book_title,
|
||||
"chapter": 8,
|
||||
"page": 120,
|
||||
"characters": ["Alice", "Queen of Hearts", "King of Hearts"],
|
||||
"themes": ["justice", "absurdity", "authority"],
|
||||
"location": "Queen's court",
|
||||
},
|
||||
},
|
||||
{
|
||||
"text": "Alice realizes that Wonderland was all a dream, even the Rabbit, as she wakes up on the riverbank next to her sister.",
|
||||
"metadata": {
|
||||
"book": book_title,
|
||||
"chapter": 12,
|
||||
"page": 180,
|
||||
"characters": ["Alice", "Sister", "Rabbit"],
|
||||
"themes": ["revelation", "reality", "growth"],
|
||||
"location": "riverbank",
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
return sample_chunks
|
||||
|
||||
|
||||
def build_spoiler_free_index(book_chunks: list[dict[str, Any]], index_name: str) -> str:
|
||||
"""
|
||||
Build a LEANN index with book chunks that include spoiler metadata.
|
||||
|
||||
Args:
|
||||
book_chunks: List of book chunks with metadata
|
||||
index_name: Name for the index
|
||||
|
||||
Returns:
|
||||
Path to the built index
|
||||
"""
|
||||
print(f"📚 Building spoiler-free book index: {index_name}")
|
||||
|
||||
# Initialize LEANN builder
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw", embedding_model="text-embedding-3-small", embedding_mode="openai"
|
||||
)
|
||||
|
||||
# Add each chunk with its metadata
|
||||
for chunk in book_chunks:
|
||||
builder.add_text(text=chunk["text"], metadata=chunk["metadata"])
|
||||
|
||||
# Build the index
|
||||
index_path = f"{index_name}_book_index"
|
||||
builder.build_index(index_path)
|
||||
|
||||
print(f"✅ Index built successfully: {index_path}")
|
||||
return index_path
|
||||
|
||||
|
||||
def spoiler_free_search(
|
||||
index_path: str,
|
||||
query: str,
|
||||
max_chapter: int,
|
||||
character_filter: Optional[list[str]] = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Perform a spoiler-free search on the book index.
|
||||
|
||||
Args:
|
||||
index_path: Path to the LEANN index
|
||||
query: Search query
|
||||
max_chapter: Maximum chapter number to include
|
||||
character_filter: Optional list of characters to focus on
|
||||
|
||||
Returns:
|
||||
List of search results safe for the reader
|
||||
"""
|
||||
print(f"🔍 Searching: '{query}' (up to chapter {max_chapter})")
|
||||
|
||||
searcher = LeannSearcher(index_path)
|
||||
|
||||
metadata_filters = {"chapter": {"<=": max_chapter}}
|
||||
|
||||
if character_filter:
|
||||
metadata_filters["characters"] = {"contains": character_filter[0]}
|
||||
|
||||
results = searcher.search(query=query, top_k=10, metadata_filters=metadata_filters)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def demo_spoiler_free_rag():
|
||||
"""
|
||||
Demonstrate the spoiler-free book RAG system.
|
||||
"""
|
||||
print("🎭 Spoiler-Free Book RAG Demo")
|
||||
print("=" * 40)
|
||||
|
||||
# Step 1: Prepare book data
|
||||
book_title = "Alice's Adventures in Wonderland"
|
||||
book_chunks = chunk_book_with_metadata(book_title)
|
||||
|
||||
print(f"📖 Loaded {len(book_chunks)} chunks from '{book_title}'")
|
||||
|
||||
# Step 2: Build the index (in practice, this would be done once)
|
||||
try:
|
||||
index_path = build_spoiler_free_index(book_chunks, "alice_wonderland")
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to build index (likely missing dependencies): {e}")
|
||||
print(
|
||||
"💡 This demo shows the filtering logic - actual indexing requires LEANN dependencies"
|
||||
)
|
||||
return
|
||||
|
||||
# Step 3: Demonstrate various spoiler-free searches
|
||||
search_scenarios = [
|
||||
{
|
||||
"description": "Reader who has only read Chapter 1",
|
||||
"query": "What can you tell me about the rabbit?",
|
||||
"max_chapter": 1,
|
||||
},
|
||||
{
|
||||
"description": "Reader who has read up to Chapter 5",
|
||||
"query": "Tell me about Alice's adventures",
|
||||
"max_chapter": 5,
|
||||
},
|
||||
{
|
||||
"description": "Reader who has read most of the book",
|
||||
"query": "What does the Cheshire Cat represent?",
|
||||
"max_chapter": 10,
|
||||
},
|
||||
{
|
||||
"description": "Reader who has read the whole book",
|
||||
"query": "What can you tell me about the rabbit?",
|
||||
"max_chapter": 12,
|
||||
},
|
||||
]
|
||||
|
||||
for scenario in search_scenarios:
|
||||
print(f"\n📚 Scenario: {scenario['description']}")
|
||||
print(f" Query: {scenario['query']}")
|
||||
|
||||
try:
|
||||
results = spoiler_free_search(
|
||||
index_path=index_path,
|
||||
query=scenario["query"],
|
||||
max_chapter=scenario["max_chapter"],
|
||||
)
|
||||
|
||||
print(f" 📄 Found {len(results)} results:")
|
||||
for i, result in enumerate(results[:3], 1): # Show top 3
|
||||
chapter = result.metadata.get("chapter", "?")
|
||||
location = result.metadata.get("location", "?")
|
||||
print(f" {i}. Chapter {chapter} ({location}): {result.text[:80]}...")
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ Search failed: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("📚 LEANN Spoiler-Free Book RAG Example")
|
||||
print("=====================================")
|
||||
|
||||
try:
|
||||
demo_spoiler_free_rag()
|
||||
except ImportError as e:
|
||||
print(f"❌ Cannot run demo due to missing dependencies: {e}")
|
||||
except Exception as e:
|
||||
print(f"❌ Error running demo: {e}")
|
||||