Compare commits
115 Commits
feature/cl
...
gen-time
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bc6c53edf0 | ||
|
|
abc12d5069 | ||
|
|
9ba0ecac15 | ||
|
|
4e9e2f3da0 | ||
|
|
ed167f43b0 | ||
|
|
f9746d3fe2 | ||
|
|
a090a3444a | ||
|
|
aaaba27a4f | ||
|
|
f40f539456 | ||
|
|
576a2dcb49 | ||
|
|
ad8ab84675 | ||
|
|
58b96b64d8 | ||
|
|
a76c3cdac4 | ||
|
|
520619deab | ||
|
|
dea08c95b4 | ||
|
|
ec889f7ef4 | ||
|
|
322e5c162d | ||
|
|
edde0cdeb2 | ||
|
|
db7ba27ff6 | ||
|
|
5f7806e16f | ||
|
|
d034e2195b | ||
|
|
43894ff605 | ||
|
|
10311cc611 | ||
|
|
ad0d2faabc | ||
|
|
e93c0dec6f | ||
|
|
c5a29f849a | ||
|
|
3357d5765e | ||
|
|
9dbd0c64cc | ||
|
|
9c400acd7e | ||
|
|
ac560964f5 | ||
|
|
07e4f176e1 | ||
|
|
b1daf021e0 | ||
|
|
3578680cb6 | ||
|
|
a0d6857faa | ||
|
|
3b8dc6368e | ||
|
|
e309f292de | ||
|
|
0d9f92ea0f | ||
|
|
b0b353d279 | ||
|
|
4dffdfedbe | ||
|
|
d41e467df9 | ||
|
|
4ca0489cb1 | ||
|
|
e83a671918 | ||
|
|
d7011bbea0 | ||
|
|
ef4c69d128 | ||
|
|
75c8aeee5f | ||
|
|
3d79741f9c | ||
|
|
df34c84bd3 | ||
|
|
8dfd2f015c | ||
|
|
ed72232bab | ||
|
|
26d961bfc5 | ||
|
|
722bda4ebb | ||
|
|
a7c7e8801d | ||
|
|
069bce558b | ||
|
|
4e5b73ce7b | ||
|
|
772894012e | ||
|
|
31b4973141 | ||
|
|
dde2221513 | ||
|
|
6d11e86e71 | ||
|
|
13bb561aad | ||
|
|
0174ba5571 | ||
|
|
03af82d695 | ||
|
|
738f1dbab8 | ||
|
|
37d990d51c | ||
|
|
5c163737c4 | ||
|
|
6d1d67ead7 | ||
|
|
a6f07a54f1 | ||
|
|
ed27ea6990 | ||
|
|
baf2d76e0e | ||
|
|
46905e0687 | ||
|
|
838ade231e | ||
|
|
da6540decd | ||
|
|
39e18a7c11 | ||
|
|
6bde28584b | ||
|
|
f62632c41f | ||
|
|
27708243ca | ||
|
|
9a1e4652ca | ||
|
|
14e84d9e2d | ||
|
|
2dcfca19ff | ||
|
|
bee2167ee3 | ||
|
|
ef980d70b3 | ||
|
|
db3c63c441 | ||
|
|
00eeadb9dd | ||
|
|
42c8370709 | ||
|
|
fafdf8fcbe | ||
|
|
21f7d8e031 | ||
|
|
46565b9249 | ||
|
|
3dad76126a | ||
|
|
18e28bda32 | ||
|
|
609fa62fd5 | ||
|
|
eab13434ef | ||
|
|
b2390ccc14 | ||
|
|
e8fca2c84a | ||
|
|
790ae14f69 | ||
|
|
ac363072e6 | ||
|
|
93465af46c | ||
|
|
792ece67dc | ||
|
|
239e35e2e6 | ||
|
|
2fac0c6fbf | ||
|
|
9801aa581b | ||
|
|
5e97916608 | ||
|
|
8b9c2be8c9 | ||
|
|
3ff5aac8e0 | ||
|
|
67fef60466 | ||
|
|
b6ab6f1993 | ||
|
|
9f2e82a838 | ||
|
|
0b2b799d5a | ||
|
|
0f790fbbd9 | ||
|
|
387ae21eba | ||
|
|
3cc329c3e7 | ||
|
|
5567302316 | ||
|
|
075d4bd167 | ||
|
|
e4bcc76f88 | ||
|
|
710e83b1fd | ||
|
|
c96d653072 | ||
|
|
8b22d2b5d3 |
1
.gitattributes
vendored
1
.gitattributes
vendored
@@ -1 +0,0 @@
|
|||||||
paper_plot/data/big_graph_degree_data.npz filter=lfs diff=lfs merge=lfs -text
|
|
||||||
50
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
Normal file
50
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
name: Bug Report
|
||||||
|
description: Report a bug in LEANN
|
||||||
|
labels: ["bug"]
|
||||||
|
|
||||||
|
body:
|
||||||
|
- type: textarea
|
||||||
|
id: description
|
||||||
|
attributes:
|
||||||
|
label: What happened?
|
||||||
|
description: A clear description of the bug
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: reproduce
|
||||||
|
attributes:
|
||||||
|
label: How to reproduce
|
||||||
|
placeholder: |
|
||||||
|
1. Install with...
|
||||||
|
2. Run command...
|
||||||
|
3. See error
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: error
|
||||||
|
attributes:
|
||||||
|
label: Error message
|
||||||
|
description: Paste any error messages
|
||||||
|
render: shell
|
||||||
|
|
||||||
|
- type: input
|
||||||
|
id: version
|
||||||
|
attributes:
|
||||||
|
label: LEANN Version
|
||||||
|
placeholder: "0.1.0"
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: dropdown
|
||||||
|
id: os
|
||||||
|
attributes:
|
||||||
|
label: Operating System
|
||||||
|
options:
|
||||||
|
- macOS
|
||||||
|
- Linux
|
||||||
|
- Windows
|
||||||
|
- Docker
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
blank_issues_enabled: true
|
||||||
|
contact_links:
|
||||||
|
- name: Documentation
|
||||||
|
url: https://github.com/LEANN-RAG/LEANN-RAG/tree/main/docs
|
||||||
|
about: Read the docs first
|
||||||
|
- name: Discussions
|
||||||
|
url: https://github.com/LEANN-RAG/LEANN-RAG/discussions
|
||||||
|
about: Ask questions and share ideas
|
||||||
27
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
27
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
name: Feature Request
|
||||||
|
description: Suggest a new feature for LEANN
|
||||||
|
labels: ["enhancement"]
|
||||||
|
|
||||||
|
body:
|
||||||
|
- type: textarea
|
||||||
|
id: problem
|
||||||
|
attributes:
|
||||||
|
label: What problem does this solve?
|
||||||
|
description: Describe the problem or need
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: solution
|
||||||
|
attributes:
|
||||||
|
label: Proposed solution
|
||||||
|
description: How would you like this to work?
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: example
|
||||||
|
attributes:
|
||||||
|
label: Example usage
|
||||||
|
description: Show how the API might look
|
||||||
|
render: python
|
||||||
13
.github/pull_request_template.md
vendored
Normal file
13
.github/pull_request_template.md
vendored
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
## What does this PR do?
|
||||||
|
|
||||||
|
<!-- Brief description of your changes -->
|
||||||
|
|
||||||
|
## Related Issues
|
||||||
|
|
||||||
|
Fixes #
|
||||||
|
|
||||||
|
## Checklist
|
||||||
|
|
||||||
|
- [ ] Tests pass (`uv run pytest`)
|
||||||
|
- [ ] Code formatted (`ruff format` and `ruff check`)
|
||||||
|
- [ ] Pre-commit hooks pass (`pre-commit run --all-files`)
|
||||||
1
.github/workflows/build-and-publish.yml
vendored
1
.github/workflows/build-and-publish.yml
vendored
@@ -5,6 +5,7 @@ on:
|
|||||||
branches: [ main ]
|
branches: [ main ]
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [ main ]
|
branches: [ main ]
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build:
|
build:
|
||||||
|
|||||||
354
.github/workflows/build-reusable.yml
vendored
354
.github/workflows/build-reusable.yml
vendored
@@ -17,26 +17,17 @@ jobs:
|
|||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
ref: ${{ inputs.ref }}
|
ref: ${{ inputs.ref }}
|
||||||
|
submodules: recursive
|
||||||
|
|
||||||
- name: Setup Python
|
- name: Install uv and Python
|
||||||
uses: actions/setup-python@v5
|
uses: astral-sh/setup-uv@v6
|
||||||
with:
|
with:
|
||||||
python-version: '3.11'
|
python-version: '3.11'
|
||||||
|
|
||||||
- name: Install uv
|
- name: Run pre-commit with only lint group (no project deps)
|
||||||
uses: astral-sh/setup-uv@v4
|
|
||||||
|
|
||||||
- name: Install ruff
|
|
||||||
run: |
|
run: |
|
||||||
uv tool install ruff
|
uv run --only-group lint pre-commit run --all-files --show-diff-on-failure
|
||||||
|
|
||||||
- name: Run ruff check
|
|
||||||
run: |
|
|
||||||
ruff check .
|
|
||||||
|
|
||||||
- name: Run ruff format check
|
|
||||||
run: |
|
|
||||||
ruff format --check .
|
|
||||||
|
|
||||||
build:
|
build:
|
||||||
needs: lint
|
needs: lint
|
||||||
@@ -54,45 +45,108 @@ jobs:
|
|||||||
python: '3.12'
|
python: '3.12'
|
||||||
- os: ubuntu-22.04
|
- os: ubuntu-22.04
|
||||||
python: '3.13'
|
python: '3.13'
|
||||||
- os: macos-latest
|
# ARM64 Linux builds
|
||||||
|
- os: ubuntu-24.04-arm
|
||||||
python: '3.9'
|
python: '3.9'
|
||||||
- os: macos-latest
|
- os: ubuntu-24.04-arm
|
||||||
python: '3.10'
|
python: '3.10'
|
||||||
- os: macos-latest
|
- os: ubuntu-24.04-arm
|
||||||
python: '3.11'
|
python: '3.11'
|
||||||
- os: macos-latest
|
- os: ubuntu-24.04-arm
|
||||||
python: '3.12'
|
python: '3.12'
|
||||||
- os: macos-latest
|
- os: ubuntu-24.04-arm
|
||||||
python: '3.13'
|
python: '3.13'
|
||||||
|
- os: macos-14
|
||||||
|
python: '3.9'
|
||||||
|
- os: macos-14
|
||||||
|
python: '3.10'
|
||||||
|
- os: macos-14
|
||||||
|
python: '3.11'
|
||||||
|
- os: macos-14
|
||||||
|
python: '3.12'
|
||||||
|
- os: macos-14
|
||||||
|
python: '3.13'
|
||||||
|
- os: macos-15
|
||||||
|
python: '3.9'
|
||||||
|
- os: macos-15
|
||||||
|
python: '3.10'
|
||||||
|
- os: macos-15
|
||||||
|
python: '3.11'
|
||||||
|
- os: macos-15
|
||||||
|
python: '3.12'
|
||||||
|
- os: macos-15
|
||||||
|
python: '3.13'
|
||||||
|
- os: macos-13
|
||||||
|
python: '3.9'
|
||||||
|
- os: macos-13
|
||||||
|
python: '3.10'
|
||||||
|
- os: macos-13
|
||||||
|
python: '3.11'
|
||||||
|
- os: macos-13
|
||||||
|
python: '3.12'
|
||||||
|
# Note: macos-13 + Python 3.13 excluded due to PyTorch compatibility
|
||||||
|
# (PyTorch 2.5+ supports Python 3.13 but not Intel Mac x86_64)
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v5
|
||||||
with:
|
with:
|
||||||
ref: ${{ inputs.ref }}
|
ref: ${{ inputs.ref }}
|
||||||
submodules: recursive
|
submodules: recursive
|
||||||
|
|
||||||
- name: Setup Python
|
- name: Install uv and Python
|
||||||
uses: actions/setup-python@v5
|
uses: astral-sh/setup-uv@v6
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python }}
|
python-version: ${{ matrix.python }}
|
||||||
|
|
||||||
- name: Install uv
|
|
||||||
uses: astral-sh/setup-uv@v4
|
|
||||||
|
|
||||||
- name: Install system dependencies (Ubuntu)
|
- name: Install system dependencies (Ubuntu)
|
||||||
if: runner.os == 'Linux'
|
if: runner.os == 'Linux'
|
||||||
run: |
|
run: |
|
||||||
sudo apt-get update
|
sudo apt-get update
|
||||||
sudo apt-get install -y libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
|
sudo apt-get install -y libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
|
||||||
pkg-config libopenblas-dev patchelf libabsl-dev libaio-dev libprotobuf-dev
|
pkg-config libabsl-dev libaio-dev libprotobuf-dev \
|
||||||
|
patchelf
|
||||||
|
|
||||||
# Install Intel MKL for DiskANN
|
# Debug: Show system information
|
||||||
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
|
echo "🔍 System Information:"
|
||||||
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
|
echo "Architecture: $(uname -m)"
|
||||||
source /opt/intel/oneapi/setvars.sh
|
echo "OS: $(uname -a)"
|
||||||
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
|
echo "CPU info: $(lscpu | head -5)"
|
||||||
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/mkl/latest/lib/intel64:$LD_LIBRARY_PATH" >> $GITHUB_ENV
|
|
||||||
|
# Install math library based on architecture
|
||||||
|
ARCH=$(uname -m)
|
||||||
|
echo "🔍 Setting up math library for architecture: $ARCH"
|
||||||
|
|
||||||
|
if [[ "$ARCH" == "x86_64" ]]; then
|
||||||
|
# Install Intel MKL for DiskANN on x86_64
|
||||||
|
echo "📦 Installing Intel MKL for x86_64..."
|
||||||
|
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
|
||||||
|
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
|
||||||
|
source /opt/intel/oneapi/setvars.sh
|
||||||
|
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
|
||||||
|
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/compiler/latest/linux/compiler/lib/intel64_lin" >> $GITHUB_ENV
|
||||||
|
echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/mkl/latest/lib/intel64" >> $GITHUB_ENV
|
||||||
|
echo "✅ Intel MKL installed for x86_64"
|
||||||
|
|
||||||
|
# Debug: Check MKL installation
|
||||||
|
echo "🔍 MKL Installation Check:"
|
||||||
|
ls -la /opt/intel/oneapi/mkl/latest/ || echo "MKL directory not found"
|
||||||
|
ls -la /opt/intel/oneapi/mkl/latest/lib/ || echo "MKL lib directory not found"
|
||||||
|
|
||||||
|
elif [[ "$ARCH" == "aarch64" ]]; then
|
||||||
|
# Use OpenBLAS for ARM64 (MKL installer not compatible with ARM64)
|
||||||
|
echo "📦 Installing OpenBLAS for ARM64..."
|
||||||
|
sudo apt-get install -y libopenblas-dev liblapack-dev liblapacke-dev
|
||||||
|
echo "✅ OpenBLAS installed for ARM64"
|
||||||
|
|
||||||
|
# Debug: Check OpenBLAS installation
|
||||||
|
echo "🔍 OpenBLAS Installation Check:"
|
||||||
|
dpkg -l | grep openblas || echo "OpenBLAS package not found"
|
||||||
|
ls -la /usr/lib/aarch64-linux-gnu/openblas/ || echo "OpenBLAS directory not found"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Debug: Show final library paths
|
||||||
|
echo "🔍 Final LD_LIBRARY_PATH: $LD_LIBRARY_PATH"
|
||||||
|
|
||||||
- name: Install system dependencies (macOS)
|
- name: Install system dependencies (macOS)
|
||||||
if: runner.os == 'macOS'
|
if: runner.os == 'macOS'
|
||||||
@@ -102,55 +156,93 @@ jobs:
|
|||||||
|
|
||||||
- name: Install build dependencies
|
- name: Install build dependencies
|
||||||
run: |
|
run: |
|
||||||
uv pip install --system scikit-build-core numpy swig Cython pybind11
|
uv python install ${{ matrix.python }}
|
||||||
if [[ "$RUNNER_OS" == "Linux" ]]; then
|
uv venv --python ${{ matrix.python }} .uv-build
|
||||||
uv pip install --system auditwheel
|
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
||||||
|
BUILD_PY=".uv-build\\Scripts\\python.exe"
|
||||||
else
|
else
|
||||||
uv pip install --system delocate
|
BUILD_PY=".uv-build/bin/python"
|
||||||
fi
|
fi
|
||||||
|
uv pip install --python "$BUILD_PY" scikit-build-core numpy swig Cython pybind11
|
||||||
|
if [[ "$RUNNER_OS" == "Linux" ]]; then
|
||||||
|
uv pip install --python "$BUILD_PY" auditwheel
|
||||||
|
else
|
||||||
|
uv pip install --python "$BUILD_PY" delocate
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
||||||
|
echo "$(pwd)\\.uv-build\\Scripts" >> $GITHUB_PATH
|
||||||
|
else
|
||||||
|
echo "$(pwd)/.uv-build/bin" >> $GITHUB_PATH
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Set macOS environment variables
|
||||||
|
if: runner.os == 'macOS'
|
||||||
|
run: |
|
||||||
|
# Use brew --prefix to automatically detect Homebrew installation path
|
||||||
|
HOMEBREW_PREFIX=$(brew --prefix)
|
||||||
|
echo "HOMEBREW_PREFIX=${HOMEBREW_PREFIX}" >> $GITHUB_ENV
|
||||||
|
echo "OpenMP_ROOT=${HOMEBREW_PREFIX}/opt/libomp" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
# Set CMAKE_PREFIX_PATH to let CMake find all packages automatically
|
||||||
|
echo "CMAKE_PREFIX_PATH=${HOMEBREW_PREFIX}" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
# Set compiler flags for OpenMP (required for both backends)
|
||||||
|
echo "LDFLAGS=-L${HOMEBREW_PREFIX}/opt/libomp/lib" >> $GITHUB_ENV
|
||||||
|
echo "CPPFLAGS=-I${HOMEBREW_PREFIX}/opt/libomp/include" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Build packages
|
- name: Build packages
|
||||||
run: |
|
run: |
|
||||||
# Build core (platform independent)
|
# Build core (platform independent)
|
||||||
if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
|
cd packages/leann-core
|
||||||
cd packages/leann-core
|
uv build
|
||||||
uv build
|
cd ../..
|
||||||
cd ../..
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Build HNSW backend
|
# Build HNSW backend
|
||||||
cd packages/leann-backend-hnsw
|
cd packages/leann-backend-hnsw
|
||||||
if [ "${{ matrix.os }}" == "macos-latest" ]; then
|
if [[ "${{ matrix.os }}" == macos-* ]]; then
|
||||||
# Use system clang instead of homebrew LLVM for better compatibility
|
# Use system clang for better compatibility
|
||||||
export CC=clang
|
export CC=clang
|
||||||
export CXX=clang++
|
export CXX=clang++
|
||||||
export MACOSX_DEPLOYMENT_TARGET=11.0
|
# Homebrew libraries on each macOS version require matching minimum version
|
||||||
uv build --wheel --python python
|
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=13.0
|
||||||
|
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=14.0
|
||||||
|
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||||
|
fi
|
||||||
|
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
||||||
else
|
else
|
||||||
uv build --wheel --python python
|
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
||||||
fi
|
fi
|
||||||
cd ../..
|
cd ../..
|
||||||
|
|
||||||
# Build DiskANN backend
|
# Build DiskANN backend
|
||||||
cd packages/leann-backend-diskann
|
cd packages/leann-backend-diskann
|
||||||
if [ "${{ matrix.os }}" == "macos-latest" ]; then
|
if [[ "${{ matrix.os }}" == macos-* ]]; then
|
||||||
# Use system clang instead of homebrew LLVM for better compatibility
|
# Use system clang for better compatibility
|
||||||
export CC=clang
|
export CC=clang
|
||||||
export CXX=clang++
|
export CXX=clang++
|
||||||
# DiskANN requires macOS 13.3+ for sgesdd_ LAPACK function
|
# DiskANN requires macOS 13.3+ for sgesdd_ LAPACK function
|
||||||
export MACOSX_DEPLOYMENT_TARGET=13.3
|
# But Homebrew libraries on each macOS version require matching minimum version
|
||||||
uv build --wheel --python python
|
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=13.3
|
||||||
|
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=14.0
|
||||||
|
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||||
|
fi
|
||||||
|
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
||||||
else
|
else
|
||||||
uv build --wheel --python python
|
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
||||||
fi
|
fi
|
||||||
cd ../..
|
cd ../..
|
||||||
|
|
||||||
# Build meta package (platform independent)
|
# Build meta package (platform independent)
|
||||||
if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
|
cd packages/leann
|
||||||
cd packages/leann
|
uv build
|
||||||
uv build
|
cd ../..
|
||||||
cd ../..
|
|
||||||
fi
|
|
||||||
|
|
||||||
- name: Repair wheels (Linux)
|
- name: Repair wheels (Linux)
|
||||||
if: runner.os == 'Linux'
|
if: runner.os == 'Linux'
|
||||||
@@ -176,10 +268,24 @@ jobs:
|
|||||||
- name: Repair wheels (macOS)
|
- name: Repair wheels (macOS)
|
||||||
if: runner.os == 'macOS'
|
if: runner.os == 'macOS'
|
||||||
run: |
|
run: |
|
||||||
|
# Determine deployment target based on runner OS
|
||||||
|
# Must match the Homebrew libraries for each macOS version
|
||||||
|
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
|
||||||
|
HNSW_TARGET="13.0"
|
||||||
|
DISKANN_TARGET="13.3"
|
||||||
|
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
|
||||||
|
HNSW_TARGET="14.0"
|
||||||
|
DISKANN_TARGET="14.0"
|
||||||
|
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
|
||||||
|
HNSW_TARGET="15.0"
|
||||||
|
DISKANN_TARGET="15.0"
|
||||||
|
fi
|
||||||
|
|
||||||
# Repair HNSW wheel
|
# Repair HNSW wheel
|
||||||
cd packages/leann-backend-hnsw
|
cd packages/leann-backend-hnsw
|
||||||
if [ -d dist ]; then
|
if [ -d dist ]; then
|
||||||
delocate-wheel -w dist_repaired -v dist/*.whl
|
export MACOSX_DEPLOYMENT_TARGET=$HNSW_TARGET
|
||||||
|
delocate-wheel -w dist_repaired -v --require-target-macos-version $HNSW_TARGET dist/*.whl
|
||||||
rm -rf dist
|
rm -rf dist
|
||||||
mv dist_repaired dist
|
mv dist_repaired dist
|
||||||
fi
|
fi
|
||||||
@@ -188,7 +294,8 @@ jobs:
|
|||||||
# Repair DiskANN wheel
|
# Repair DiskANN wheel
|
||||||
cd packages/leann-backend-diskann
|
cd packages/leann-backend-diskann
|
||||||
if [ -d dist ]; then
|
if [ -d dist ]; then
|
||||||
delocate-wheel -w dist_repaired -v dist/*.whl
|
export MACOSX_DEPLOYMENT_TARGET=$DISKANN_TARGET
|
||||||
|
delocate-wheel -w dist_repaired -v --require-target-macos-version $DISKANN_TARGET dist/*.whl
|
||||||
rm -rf dist
|
rm -rf dist
|
||||||
mv dist_repaired dist
|
mv dist_repaired dist
|
||||||
fi
|
fi
|
||||||
@@ -199,39 +306,82 @@ jobs:
|
|||||||
echo "📦 Built packages:"
|
echo "📦 Built packages:"
|
||||||
find packages/*/dist -name "*.whl" -o -name "*.tar.gz" | sort
|
find packages/*/dist -name "*.whl" -o -name "*.tar.gz" | sort
|
||||||
|
|
||||||
|
|
||||||
- name: Install built packages for testing
|
- name: Install built packages for testing
|
||||||
run: |
|
run: |
|
||||||
# Create a virtual environment
|
# Create uv-managed virtual environment with the requested interpreter
|
||||||
uv venv
|
uv python install ${{ matrix.python }}
|
||||||
|
uv venv --python ${{ matrix.python }}
|
||||||
source .venv/bin/activate || source .venv/Scripts/activate
|
source .venv/bin/activate || source .venv/Scripts/activate
|
||||||
|
|
||||||
# Install the built wheels
|
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
||||||
# Use --find-links to let uv choose the correct wheel for the platform
|
UV_PY=".venv\\Scripts\\python.exe"
|
||||||
if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
|
else
|
||||||
uv pip install leann-core --find-links packages/leann-core/dist
|
UV_PY=".venv/bin/python"
|
||||||
uv pip install leann --find-links packages/leann/dist
|
|
||||||
fi
|
fi
|
||||||
uv pip install leann-backend-hnsw --find-links packages/leann-backend-hnsw/dist
|
|
||||||
uv pip install leann-backend-diskann --find-links packages/leann-backend-diskann/dist
|
|
||||||
|
|
||||||
# Install test dependencies using extras
|
# Install test dependency group only (avoids reinstalling project package)
|
||||||
uv pip install -e ".[test]"
|
uv pip install --python "$UV_PY" --group test
|
||||||
|
|
||||||
|
# Install core wheel built in this job
|
||||||
|
CORE_WHL=$(find packages/leann-core/dist -maxdepth 1 -name "*.whl" -print -quit)
|
||||||
|
if [[ -n "$CORE_WHL" ]]; then
|
||||||
|
uv pip install --python "$UV_PY" "$CORE_WHL"
|
||||||
|
else
|
||||||
|
uv pip install --python "$UV_PY" packages/leann-core/dist/*.tar.gz
|
||||||
|
fi
|
||||||
|
|
||||||
|
PY_TAG=$($UV_PY -c "import sys; print(f'cp{sys.version_info[0]}{sys.version_info[1]}')")
|
||||||
|
|
||||||
|
if [[ "$RUNNER_OS" == "macOS" ]]; then
|
||||||
|
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=13.3
|
||||||
|
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=14.0
|
||||||
|
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
HNSW_WHL=$(find packages/leann-backend-hnsw/dist -maxdepth 1 -name "*-${PY_TAG}-*.whl" -print -quit)
|
||||||
|
if [[ -z "$HNSW_WHL" ]]; then
|
||||||
|
HNSW_WHL=$(find packages/leann-backend-hnsw/dist -maxdepth 1 -name "*-py3-*.whl" -print -quit)
|
||||||
|
fi
|
||||||
|
if [[ -n "$HNSW_WHL" ]]; then
|
||||||
|
uv pip install --python "$UV_PY" "$HNSW_WHL"
|
||||||
|
else
|
||||||
|
uv pip install --python "$UV_PY" ./packages/leann-backend-hnsw
|
||||||
|
fi
|
||||||
|
|
||||||
|
DISKANN_WHL=$(find packages/leann-backend-diskann/dist -maxdepth 1 -name "*-${PY_TAG}-*.whl" -print -quit)
|
||||||
|
if [[ -z "$DISKANN_WHL" ]]; then
|
||||||
|
DISKANN_WHL=$(find packages/leann-backend-diskann/dist -maxdepth 1 -name "*-py3-*.whl" -print -quit)
|
||||||
|
fi
|
||||||
|
if [[ -n "$DISKANN_WHL" ]]; then
|
||||||
|
uv pip install --python "$UV_PY" "$DISKANN_WHL"
|
||||||
|
else
|
||||||
|
uv pip install --python "$UV_PY" ./packages/leann-backend-diskann
|
||||||
|
fi
|
||||||
|
|
||||||
|
LEANN_WHL=$(find packages/leann/dist -maxdepth 1 -name "*.whl" -print -quit)
|
||||||
|
if [[ -n "$LEANN_WHL" ]]; then
|
||||||
|
uv pip install --python "$UV_PY" "$LEANN_WHL"
|
||||||
|
else
|
||||||
|
uv pip install --python "$UV_PY" packages/leann/dist/*.tar.gz
|
||||||
|
fi
|
||||||
|
|
||||||
- name: Run tests with pytest
|
- name: Run tests with pytest
|
||||||
env:
|
env:
|
||||||
CI: true # Mark as CI environment to skip memory-intensive tests
|
CI: true
|
||||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||||
HF_HUB_DISABLE_SYMLINKS: 1
|
HF_HUB_DISABLE_SYMLINKS: 1
|
||||||
TOKENIZERS_PARALLELISM: false
|
TOKENIZERS_PARALLELISM: false
|
||||||
PYTORCH_ENABLE_MPS_FALLBACK: 0 # Disable MPS on macOS CI to avoid memory issues
|
PYTORCH_ENABLE_MPS_FALLBACK: 0
|
||||||
OMP_NUM_THREADS: 1 # Disable OpenMP parallelism to avoid libomp crashes
|
OMP_NUM_THREADS: 1
|
||||||
MKL_NUM_THREADS: 1 # Single thread for MKL operations
|
MKL_NUM_THREADS: 1
|
||||||
run: |
|
run: |
|
||||||
# Activate virtual environment
|
|
||||||
source .venv/bin/activate || source .venv/Scripts/activate
|
source .venv/bin/activate || source .venv/Scripts/activate
|
||||||
|
pytest tests/ -v --tb=short
|
||||||
# Run all tests
|
|
||||||
pytest tests/
|
|
||||||
|
|
||||||
- name: Run sanity checks (optional)
|
- name: Run sanity checks (optional)
|
||||||
run: |
|
run: |
|
||||||
@@ -249,3 +399,53 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
name: packages-${{ matrix.os }}-py${{ matrix.python }}
|
name: packages-${{ matrix.os }}-py${{ matrix.python }}
|
||||||
path: packages/*/dist/
|
path: packages/*/dist/
|
||||||
|
|
||||||
|
|
||||||
|
arch-smoke:
|
||||||
|
name: Arch Linux smoke test (install & import)
|
||||||
|
needs: build
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
container:
|
||||||
|
image: archlinux:latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Prepare system
|
||||||
|
run: |
|
||||||
|
pacman -Syu --noconfirm
|
||||||
|
pacman -S --noconfirm python python-pip gcc git zlib openssl
|
||||||
|
|
||||||
|
- name: Download ALL wheel artifacts from this run
|
||||||
|
uses: actions/download-artifact@v5
|
||||||
|
with:
|
||||||
|
# Don't specify name, download all artifacts
|
||||||
|
path: ./wheels
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v6
|
||||||
|
|
||||||
|
- name: Create virtual environment and install wheels
|
||||||
|
run: |
|
||||||
|
uv venv
|
||||||
|
source .venv/bin/activate || source .venv/Scripts/activate
|
||||||
|
uv pip install --find-links wheels leann-core
|
||||||
|
uv pip install --find-links wheels leann-backend-hnsw
|
||||||
|
uv pip install --find-links wheels leann-backend-diskann
|
||||||
|
uv pip install --find-links wheels leann
|
||||||
|
|
||||||
|
- name: Import & tiny runtime check
|
||||||
|
env:
|
||||||
|
OMP_NUM_THREADS: 1
|
||||||
|
MKL_NUM_THREADS: 1
|
||||||
|
run: |
|
||||||
|
source .venv/bin/activate || source .venv/Scripts/activate
|
||||||
|
python - <<'PY'
|
||||||
|
import leann
|
||||||
|
import leann_backend_hnsw as h
|
||||||
|
import leann_backend_diskann as d
|
||||||
|
from leann import LeannBuilder, LeannSearcher
|
||||||
|
b = LeannBuilder(backend_name="hnsw")
|
||||||
|
b.add_text("hello arch")
|
||||||
|
b.build_index("arch_demo.leann")
|
||||||
|
s = LeannSearcher("arch_demo.leann")
|
||||||
|
print("search:", s.search("hello", top_k=1))
|
||||||
|
PY
|
||||||
|
|||||||
2
.github/workflows/link-check.yml
vendored
2
.github/workflows/link-check.yml
vendored
@@ -14,6 +14,6 @@ jobs:
|
|||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- uses: lycheeverse/lychee-action@v2
|
- uses: lycheeverse/lychee-action@v2
|
||||||
with:
|
with:
|
||||||
args: --no-progress --insecure README.md docs/ apps/ examples/ benchmarks/
|
args: --no-progress --insecure --user-agent 'curl/7.68.0' README.md docs/ apps/ examples/ benchmarks/
|
||||||
env:
|
env:
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|||||||
7
.gitignore
vendored
7
.gitignore
vendored
@@ -18,9 +18,12 @@ demo/experiment_results/**/*.json
|
|||||||
*.eml
|
*.eml
|
||||||
*.emlx
|
*.emlx
|
||||||
*.json
|
*.json
|
||||||
|
*.png
|
||||||
|
!.vscode/*.json
|
||||||
*.sh
|
*.sh
|
||||||
*.txt
|
*.txt
|
||||||
!CMakeLists.txt
|
!CMakeLists.txt
|
||||||
|
!llms.txt
|
||||||
latency_breakdown*.json
|
latency_breakdown*.json
|
||||||
experiment_results/eval_results/diskann/*.json
|
experiment_results/eval_results/diskann/*.json
|
||||||
aws/
|
aws/
|
||||||
@@ -92,3 +95,7 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/
|
|||||||
batchtest.py
|
batchtest.py
|
||||||
tests/__pytest_cache__/
|
tests/__pytest_cache__/
|
||||||
tests/__pycache__/
|
tests/__pycache__/
|
||||||
|
benchmarks/data/
|
||||||
|
|
||||||
|
## multi vector
|
||||||
|
apps/multimodal/vision-based-pdf-multi-vector/multi-vector-colpali-native-weaviate.py
|
||||||
|
|||||||
3
.gitmodules
vendored
3
.gitmodules
vendored
@@ -14,3 +14,6 @@
|
|||||||
[submodule "packages/leann-backend-hnsw/third_party/libzmq"]
|
[submodule "packages/leann-backend-hnsw/third_party/libzmq"]
|
||||||
path = packages/leann-backend-hnsw/third_party/libzmq
|
path = packages/leann-backend-hnsw/third_party/libzmq
|
||||||
url = https://github.com/zeromq/libzmq.git
|
url = https://github.com/zeromq/libzmq.git
|
||||||
|
[submodule "packages/astchunk-leann"]
|
||||||
|
path = packages/astchunk-leann
|
||||||
|
url = https://github.com/yichuan-w/astchunk-leann.git
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
rev: v4.5.0
|
rev: v5.0.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: trailing-whitespace
|
- id: trailing-whitespace
|
||||||
- id: end-of-file-fixer
|
- id: end-of-file-fixer
|
||||||
@@ -10,7 +10,8 @@ repos:
|
|||||||
- id: debug-statements
|
- id: debug-statements
|
||||||
|
|
||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
rev: v0.2.1
|
rev: v0.12.7 # Fixed version to match pyproject.toml
|
||||||
hooks:
|
hooks:
|
||||||
- id: ruff
|
- id: ruff
|
||||||
|
args: [--fix, --exit-non-zero-on-fix]
|
||||||
- id: ruff-format
|
- id: ruff-format
|
||||||
|
|||||||
5
.vscode/extensions.json
vendored
Normal file
5
.vscode/extensions.json
vendored
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
{
|
||||||
|
"recommendations": [
|
||||||
|
"charliermarsh.ruff",
|
||||||
|
]
|
||||||
|
}
|
||||||
22
.vscode/settings.json
vendored
Normal file
22
.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
{
|
||||||
|
"python.defaultInterpreterPath": ".venv/bin/python",
|
||||||
|
"python.terminal.activateEnvironment": true,
|
||||||
|
"[python]": {
|
||||||
|
"editor.defaultFormatter": "charliermarsh.ruff",
|
||||||
|
"editor.formatOnSave": true,
|
||||||
|
"editor.codeActionsOnSave": {
|
||||||
|
"source.organizeImports": "explicit",
|
||||||
|
"source.fixAll": "explicit"
|
||||||
|
},
|
||||||
|
"editor.insertSpaces": true,
|
||||||
|
"editor.tabSize": 4
|
||||||
|
},
|
||||||
|
"ruff.enable": true,
|
||||||
|
"files.watcherExclude": {
|
||||||
|
"**/.venv/**": true,
|
||||||
|
"**/__pycache__/**": true,
|
||||||
|
"**/*.egg-info/**": true,
|
||||||
|
"**/build/**": true,
|
||||||
|
"**/dist/**": true
|
||||||
|
}
|
||||||
|
}
|
||||||
344
README.md
344
README.md
@@ -3,9 +3,13 @@
|
|||||||
</p>
|
</p>
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<img src="https://img.shields.io/badge/Python-3.9%2B-blue.svg" alt="Python 3.9+">
|
<img src="https://img.shields.io/badge/Python-3.9%20%7C%203.10%20%7C%203.11%20%7C%203.12%20%7C%203.13-blue.svg" alt="Python Versions">
|
||||||
|
<img src="https://github.com/yichuan-w/LEANN/actions/workflows/build-and-publish.yml/badge.svg" alt="CI Status">
|
||||||
|
<img src="https://img.shields.io/badge/Platform-Ubuntu%20%26%20Arch%20%26%20WSL%20%7C%20macOS%20(ARM64%2FIntel)-lightgrey" alt="Platform">
|
||||||
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
|
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
|
||||||
<img src="https://img.shields.io/badge/Platform-Linux%20%7C%20macOS-lightgrey" alt="Platform">
|
<img src="https://img.shields.io/badge/MCP-Native%20Integration-blue" alt="MCP Integration">
|
||||||
|
<a href="https://join.slack.com/t/leann-e2u9779/shared_invite/zt-3ckd2f6w1-OX08~NN4gkWhh10PRVBj1Q"><img src="https://img.shields.io/badge/Slack-Join-4A154B?logo=slack&logoColor=white" alt="Join Slack">
|
||||||
|
<a href="assets/wechat_user_group.JPG" title="Join WeChat group"><img src="https://img.shields.io/badge/WeChat-Join-2DC100?logo=wechat&logoColor=white" alt="Join WeChat group"></a>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
<h2 align="center" tabindex="-1" class="heading-element" dir="auto">
|
<h2 align="center" tabindex="-1" class="heading-element" dir="auto">
|
||||||
@@ -16,7 +20,10 @@ LEANN is an innovative vector database that democratizes personal AI. Transform
|
|||||||
|
|
||||||
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
|
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
|
||||||
|
|
||||||
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)**, or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)**, **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
||||||
|
|
||||||
|
|
||||||
|
\* Claude Code only supports basic `grep`-style keyword search. **LEANN** is a drop-in **semantic search MCP service fully compatible with Claude Code**, unlocking intelligent retrieval without changing your workflow. 🔥 Check out [the easy setup →](packages/leann-mcp/README.md)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -26,7 +33,7 @@ LEANN achieves this through *graph-based selective recomputation* with *high-deg
|
|||||||
<img src="assets/effects.png" alt="LEANN vs Traditional Vector DB Storage Comparison" width="70%">
|
<img src="assets/effects.png" alt="LEANN vs Traditional Vector DB Storage Comparison" width="70%">
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
> **The numbers speak for themselves:** Index 60 million Wikipedia chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#storage-comparison)
|
> **The numbers speak for themselves:** Index 60 million text chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#-storage-comparison)
|
||||||
|
|
||||||
|
|
||||||
🔒 **Privacy:** Your data never leaves your laptop. No OpenAI, no cloud, no "terms of service".
|
🔒 **Privacy:** Your data never leaves your laptop. No OpenAI, no cloud, no "terms of service".
|
||||||
@@ -65,6 +72,8 @@ uv venv
|
|||||||
source .venv/bin/activate
|
source .venv/bin/activate
|
||||||
uv pip install leann
|
uv pip install leann
|
||||||
```
|
```
|
||||||
|
<!--
|
||||||
|
> Low-resource? See “Low-resource setups” in the [Configuration Guide](docs/configuration-guide.md#low-resource-setups). -->
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary>
|
<summary>
|
||||||
@@ -80,15 +89,60 @@ git submodule update --init --recursive
|
|||||||
```
|
```
|
||||||
|
|
||||||
**macOS:**
|
**macOS:**
|
||||||
|
|
||||||
|
Note: DiskANN requires MacOS 13.3 or later.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
brew install llvm libomp boost protobuf zeromq pkgconf
|
brew install libomp boost protobuf zeromq pkgconf
|
||||||
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync
|
uv sync --extra diskann
|
||||||
```
|
```
|
||||||
|
|
||||||
**Linux:**
|
**Linux (Ubuntu/Debian):**
|
||||||
|
|
||||||
|
Note: On Ubuntu 20.04, you may need to build a newer Abseil and pin Protobuf (e.g., v3.20.x) for building DiskANN. See [Issue #30](https://github.com/yichuan-w/LEANN/issues/30) for a step-by-step note.
|
||||||
|
|
||||||
|
You can manually install [Intel oneAPI MKL](https://www.intel.com/content/www/us/en/developer/tools/oneapi/onemkl.html) instead of `libmkl-full-dev` for DiskANN. You can also use `libopenblas-dev` for building HNSW only, by removing `--extra diskann` in the command below.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
|
sudo apt-get update && sudo apt-get install -y \
|
||||||
uv sync
|
libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
|
||||||
|
pkg-config libabsl-dev libaio-dev libprotobuf-dev \
|
||||||
|
libmkl-full-dev
|
||||||
|
|
||||||
|
uv sync --extra diskann
|
||||||
|
```
|
||||||
|
|
||||||
|
**Linux (Arch Linux):**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
sudo pacman -Syu && sudo pacman -S --needed base-devel cmake pkgconf git gcc \
|
||||||
|
boost boost-libs protobuf abseil-cpp libaio zeromq
|
||||||
|
|
||||||
|
# For MKL in DiskANN
|
||||||
|
sudo pacman -S --needed base-devel git
|
||||||
|
git clone https://aur.archlinux.org/paru-bin.git
|
||||||
|
cd paru-bin && makepkg -si
|
||||||
|
paru -S intel-oneapi-mkl intel-oneapi-compiler
|
||||||
|
source /opt/intel/oneapi/setvars.sh
|
||||||
|
|
||||||
|
uv sync --extra diskann
|
||||||
|
```
|
||||||
|
|
||||||
|
**Linux (RHEL / CentOS Stream / Oracle / Rocky / AlmaLinux):**
|
||||||
|
|
||||||
|
See [Issue #50](https://github.com/yichuan-w/LEANN/issues/50) for more details.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
sudo dnf groupinstall -y "Development Tools"
|
||||||
|
sudo dnf install -y libomp-devel boost-devel protobuf-compiler protobuf-devel \
|
||||||
|
abseil-cpp-devel libaio-devel zeromq-devel pkgconf-pkg-config
|
||||||
|
|
||||||
|
# For MKL in DiskANN
|
||||||
|
sudo dnf install -y intel-oneapi-mkl intel-oneapi-mkl-devel \
|
||||||
|
intel-oneapi-openmp || sudo dnf install -y intel-oneapi-compiler
|
||||||
|
source /opt/intel/oneapi/setvars.sh
|
||||||
|
|
||||||
|
uv sync --extra diskann
|
||||||
```
|
```
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
@@ -124,9 +178,14 @@ response = chat.ask("How much storage does LEANN save?", top_k=1)
|
|||||||
|
|
||||||
LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`, `.md`), Apple Mail, Google Search History, WeChat, and more.
|
LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`, `.md`), Apple Mail, Google Search History, WeChat, and more.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
### Generation Model Setup
|
### Generation Model Setup
|
||||||
|
|
||||||
LEANN supports multiple LLM providers for text generation (OpenAI API, HuggingFace, Ollama).
|
#### LLM Backend
|
||||||
|
|
||||||
|
LEANN supports many LLM providers for text generation (HuggingFace, Ollama, and Any OpenAI compatible API).
|
||||||
|
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary><strong>🔑 OpenAI API Setup (Default)</strong></summary>
|
<summary><strong>🔑 OpenAI API Setup (Default)</strong></summary>
|
||||||
@@ -137,6 +196,68 @@ Set your OpenAI API key as an environment variable:
|
|||||||
export OPENAI_API_KEY="your-api-key-here"
|
export OPENAI_API_KEY="your-api-key-here"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Make sure to use `--llm openai` flag when using the CLI.
|
||||||
|
You can also specify the model name with `--llm-model <model-name>` flag.
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>🛠️ Supported LLM & Embedding Providers (via OpenAI Compatibility)</strong></summary>
|
||||||
|
|
||||||
|
Thanks to the widespread adoption of the OpenAI API format, LEANN is compatible out-of-the-box with a vast array of LLM and embedding providers. Simply set the `OPENAI_BASE_URL` and `OPENAI_API_KEY` environment variables to connect to your preferred service.
|
||||||
|
|
||||||
|
```sh
|
||||||
|
export OPENAI_API_KEY="xxx"
|
||||||
|
export OPENAI_BASE_URL="http://localhost:1234/v1" # base url of the provider
|
||||||
|
```
|
||||||
|
|
||||||
|
To use OpenAI compatible endpoint with the CLI interface:
|
||||||
|
|
||||||
|
If you are using it for text generation, make sure to use `--llm openai` flag and specify the model name with `--llm-model <model-name>` flag.
|
||||||
|
|
||||||
|
If you are using it for embedding, set the `--embedding-mode openai` flag and specify the model name with `--embedding-model <MODEL>`.
|
||||||
|
|
||||||
|
-----
|
||||||
|
|
||||||
|
|
||||||
|
Below is a list of base URLs for common providers to get you started.
|
||||||
|
|
||||||
|
|
||||||
|
### 🖥️ Local Inference Engines (Recommended for full privacy)
|
||||||
|
|
||||||
|
| Provider | Sample Base URL |
|
||||||
|
| ---------------- | --------------------------- |
|
||||||
|
| **Ollama** | `http://localhost:11434/v1` |
|
||||||
|
| **LM Studio** | `http://localhost:1234/v1` |
|
||||||
|
| **vLLM** | `http://localhost:8000/v1` |
|
||||||
|
| **llama.cpp** | `http://localhost:8080/v1` |
|
||||||
|
| **SGLang** | `http://localhost:30000/v1` |
|
||||||
|
| **LiteLLM** | `http://localhost:4000` |
|
||||||
|
|
||||||
|
-----
|
||||||
|
|
||||||
|
### ☁️ Cloud Providers
|
||||||
|
|
||||||
|
> **🚨 A Note on Privacy:** Before choosing a cloud provider, carefully review their privacy and data retention policies. Depending on their terms, your data may be used for their own purposes, including but not limited to human reviews and model training, which can lead to serious consequences if not handled properly.
|
||||||
|
|
||||||
|
|
||||||
|
| Provider | Base URL |
|
||||||
|
| ---------------- | ---------------------------------------------------------- |
|
||||||
|
| **OpenAI** | `https://api.openai.com/v1` |
|
||||||
|
| **OpenRouter** | `https://openrouter.ai/api/v1` |
|
||||||
|
| **Gemini** | `https://generativelanguage.googleapis.com/v1beta/openai/` |
|
||||||
|
| **x.AI (Grok)** | `https://api.x.ai/v1` |
|
||||||
|
| **Groq AI** | `https://api.groq.com/openai/v1` |
|
||||||
|
| **DeepSeek** | `https://api.deepseek.com/v1` |
|
||||||
|
| **SiliconFlow** | `https://api.siliconflow.cn/v1` |
|
||||||
|
| **Zhipu (BigModel)** | `https://open.bigmodel.cn/api/paas/v4/` |
|
||||||
|
| **Mistral AI** | `https://api.mistral.ai/v1` |
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
If your provider isn't on this list, don't worry! Check their documentation for an OpenAI-compatible endpoint—chances are, it's OpenAI Compatible too!
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
@@ -166,7 +287,8 @@ ollama pull llama3.2:1b
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
### ⭐ Flexible Configuration
|
|
||||||
|
## ⭐ Flexible Configuration
|
||||||
|
|
||||||
LEANN provides flexible parameters for embedding models, search strategies, and data processing to fit your specific needs.
|
LEANN provides flexible parameters for embedding models, search strategies, and data processing to fit your specific needs.
|
||||||
|
|
||||||
@@ -179,34 +301,34 @@ All RAG examples share these common parameters. **Interactive mode** is availabl
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Core Parameters (General preprocessing for all examples)
|
# Core Parameters (General preprocessing for all examples)
|
||||||
--index-dir DIR # Directory to store the index (default: current directory)
|
--index-dir DIR # Directory to store the index (default: current directory)
|
||||||
--query "YOUR QUESTION" # Single query mode. Omit for interactive chat (type 'quit' to exit), and now you can play with your index interactively
|
--query "YOUR QUESTION" # Single query mode. Omit for interactive chat (type 'quit' to exit), and now you can play with your index interactively
|
||||||
--max-items N # Limit data preprocessing (default: -1, process all data)
|
--max-items N # Limit data preprocessing (default: -1, process all data)
|
||||||
--force-rebuild # Force rebuild index even if it exists
|
--force-rebuild # Force rebuild index even if it exists
|
||||||
|
|
||||||
# Embedding Parameters
|
# Embedding Parameters
|
||||||
--embedding-model MODEL # e.g., facebook/contriever, text-embedding-3-small or mlx-community/multilingual-e5-base-mlx
|
--embedding-model MODEL # e.g., facebook/contriever, text-embedding-3-small, mlx-community/Qwen3-Embedding-0.6B-8bit or nomic-embed-text
|
||||||
--embedding-mode MODE # sentence-transformers, openai, or mlx
|
--embedding-mode MODE # sentence-transformers, openai, mlx, or ollama
|
||||||
|
|
||||||
# LLM Parameters (Text generation models)
|
# LLM Parameters (Text generation models)
|
||||||
--llm TYPE # LLM backend: openai, ollama, or hf (default: openai)
|
--llm TYPE # LLM backend: openai, ollama, or hf (default: openai)
|
||||||
--llm-model MODEL # Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct
|
--llm-model MODEL # Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct
|
||||||
--thinking-budget LEVEL # Thinking budget for reasoning models: low/medium/high (supported by o3, o3-mini, GPT-Oss:20b, and other reasoning models)
|
--thinking-budget LEVEL # Thinking budget for reasoning models: low/medium/high (supported by o3, o3-mini, GPT-Oss:20b, and other reasoning models)
|
||||||
|
|
||||||
# Search Parameters
|
# Search Parameters
|
||||||
--top-k N # Number of results to retrieve (default: 20)
|
--top-k N # Number of results to retrieve (default: 20)
|
||||||
--search-complexity N # Search complexity for graph traversal (default: 32)
|
--search-complexity N # Search complexity for graph traversal (default: 32)
|
||||||
|
|
||||||
# Chunking Parameters
|
# Chunking Parameters
|
||||||
--chunk-size N # Size of text chunks (default varies by source: 256 for most, 192 for WeChat)
|
--chunk-size N # Size of text chunks (default varies by source: 256 for most, 192 for WeChat)
|
||||||
--chunk-overlap N # Overlap between chunks (default varies: 25-128 depending on source)
|
--chunk-overlap N # Overlap between chunks (default varies: 25-128 depending on source)
|
||||||
|
|
||||||
# Index Building Parameters
|
# Index Building Parameters
|
||||||
--backend-name NAME # Backend to use: hnsw or diskann (default: hnsw)
|
--backend-name NAME # Backend to use: hnsw or diskann (default: hnsw)
|
||||||
--graph-degree N # Graph degree for index construction (default: 32)
|
--graph-degree N # Graph degree for index construction (default: 32)
|
||||||
--build-complexity N # Build complexity for index construction (default: 64)
|
--build-complexity N # Build complexity for index construction (default: 64)
|
||||||
--no-compact # Disable compact index storage (compact storage IS enabled to save storage by default)
|
--compact / --no-compact # Use compact storage (default: true). Must be `no-compact` for `no-recompute` build.
|
||||||
--no-recompute # Disable embedding recomputation (recomputation IS enabled to save storage by default)
|
--recompute / --no-recompute # Enable/disable embedding recomputation (default: enabled). Should not do a `no-recompute` search in a `recompute` build.
|
||||||
```
|
```
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
@@ -219,7 +341,7 @@ Ask questions directly about your personal PDFs, documents, and any directory co
|
|||||||
<img src="videos/paper_clear.gif" alt="LEANN Document Search Demo" width="600">
|
<img src="videos/paper_clear.gif" alt="LEANN Document Search Demo" width="600">
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
The example below asks a question about summarizing our paper (uses default data in `data/`, which is a directory with diverse data sources: two papers, Pride and Prejudice, and a README in Chinese) and this is the **easiest example** to run here:
|
The example below asks a question about summarizing our paper (uses default data in `data/`, which is a directory with diverse data sources: two papers, Pride and Prejudice, and a Technical report about LLM in Huawei in Chinese), and this is the **easiest example** to run here:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
source .venv/bin/activate # Don't forget to activate the virtual environment
|
source .venv/bin/activate # Don't forget to activate the virtual environment
|
||||||
@@ -242,6 +364,12 @@ python -m apps.document_rag --data-dir "~/Documents/Papers" --chunk-size 1024
|
|||||||
|
|
||||||
# Filter only markdown and Python files with smaller chunks
|
# Filter only markdown and Python files with smaller chunks
|
||||||
python -m apps.document_rag --data-dir "./docs" --chunk-size 256 --file-types .md .py
|
python -m apps.document_rag --data-dir "./docs" --chunk-size 256 --file-types .md .py
|
||||||
|
|
||||||
|
# Enable AST-aware chunking for code files
|
||||||
|
python -m apps.document_rag --enable-code-chunking --data-dir "./my_project"
|
||||||
|
|
||||||
|
# Or use the specialized code RAG for better code understanding
|
||||||
|
python -m apps.code_rag --repo-dir "./my_codebase" --query "How does authentication work?"
|
||||||
```
|
```
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
@@ -414,7 +542,36 @@ Once the index is built, you can ask questions like:
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
### 🚀 Claude Code Integration: Transform Your Development Workflow!
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>NEW!! AST‑Aware Code Chunking</strong></summary>
|
||||||
|
|
||||||
|
LEANN features intelligent code chunking that preserves semantic boundaries (functions, classes, methods) for Python, Java, C#, and TypeScript, improving code understanding compared to text-based chunking.
|
||||||
|
|
||||||
|
📖 Read the [AST Chunking Guide →](docs/ast_chunking_guide.md)
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
**The future of code assistance is here.** Transform your development workflow with LEANN's native MCP integration for Claude Code. Index your entire codebase and get intelligent code assistance directly in your IDE.
|
||||||
|
|
||||||
|
**Key features:**
|
||||||
|
- 🔍 **Semantic code search** across your entire project, fully local index and lightweight
|
||||||
|
- 🧠 **AST-aware chunking** preserves code structure (functions, classes)
|
||||||
|
- 📚 **Context-aware assistance** for debugging and development
|
||||||
|
- 🚀 **Zero-config setup** with automatic language detection
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Install LEANN globally for MCP integration
|
||||||
|
uv tool install leann-core --with leann
|
||||||
|
claude mcp add --scope user leann-server -- leann_mcp
|
||||||
|
# Setup is automatic - just start using Claude Code!
|
||||||
|
```
|
||||||
|
Try our fully agentic pipeline with auto query rewriting, semantic search planning, and more:
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
**🔥 Ready to supercharge your coding?** [Complete Setup Guide →](packages/leann-mcp/README.md)
|
||||||
|
|
||||||
## 🖥️ Command Line Interface
|
## 🖥️ Command Line Interface
|
||||||
|
|
||||||
@@ -428,22 +585,25 @@ source .venv/bin/activate
|
|||||||
leann --help
|
leann --help
|
||||||
```
|
```
|
||||||
|
|
||||||
**To make it globally available (recommended for daily use):**
|
**To make it globally available:**
|
||||||
```bash
|
```bash
|
||||||
# Install the LEANN CLI globally using uv tool
|
# Install the LEANN CLI globally using uv tool
|
||||||
uv tool install leann
|
uv tool install leann-core --with leann
|
||||||
|
|
||||||
|
|
||||||
# Now you can use leann from anywhere without activating venv
|
# Now you can use leann from anywhere without activating venv
|
||||||
leann --help
|
leann --help
|
||||||
```
|
```
|
||||||
|
|
||||||
|
> **Note**: Global installation is required for Claude Code integration. The `leann_mcp` server depends on the globally available `leann` command.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
### Usage Examples
|
### Usage Examples
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Build an index from documents
|
# build from a specific directory, and my_docs is the index name(Here you can also build from multiple dict or multiple files)
|
||||||
leann build my-docs --docs ./documents
|
leann build my-docs --docs ./your_documents
|
||||||
|
|
||||||
# Search your documents
|
# Search your documents
|
||||||
leann search my-docs "machine learning concepts"
|
leann search my-docs "machine learning concepts"
|
||||||
@@ -451,32 +611,41 @@ leann search my-docs "machine learning concepts"
|
|||||||
# Interactive chat with your documents
|
# Interactive chat with your documents
|
||||||
leann ask my-docs --interactive
|
leann ask my-docs --interactive
|
||||||
|
|
||||||
|
# Ask a single question (non-interactive)
|
||||||
|
leann ask my-docs "Where are prompts configured?"
|
||||||
|
|
||||||
# List all your indexes
|
# List all your indexes
|
||||||
leann list
|
leann list
|
||||||
|
|
||||||
|
# Remove an index
|
||||||
|
leann remove my-docs
|
||||||
```
|
```
|
||||||
|
|
||||||
**Key CLI features:**
|
**Key CLI features:**
|
||||||
- Auto-detects document formats (PDF, TXT, MD, DOCX)
|
- Auto-detects document formats (PDF, TXT, MD, DOCX, PPTX + code files)
|
||||||
- Smart text chunking with overlap
|
- **🧠 AST-aware chunking** for Python, Java, C#, TypeScript files
|
||||||
|
- Smart text chunking with overlap for all other content
|
||||||
- Multiple LLM providers (Ollama, OpenAI, HuggingFace)
|
- Multiple LLM providers (Ollama, OpenAI, HuggingFace)
|
||||||
- Organized index storage in `~/.leann/indexes/`
|
- Organized index storage in `.leann/indexes/` (project-local)
|
||||||
- Support for advanced search parameters
|
- Support for advanced search parameters
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary><strong>📋 Click to expand: Complete CLI Reference</strong></summary>
|
<summary><strong>📋 Click to expand: Complete CLI Reference</strong></summary>
|
||||||
|
|
||||||
|
You can use `leann --help`, or `leann build --help`, `leann search --help`, `leann ask --help`, `leann list --help`, `leann remove --help` to get the complete CLI reference.
|
||||||
|
|
||||||
**Build Command:**
|
**Build Command:**
|
||||||
```bash
|
```bash
|
||||||
leann build INDEX_NAME --docs DIRECTORY [OPTIONS]
|
leann build INDEX_NAME --docs DIRECTORY|FILE [DIRECTORY|FILE ...] [OPTIONS]
|
||||||
|
|
||||||
Options:
|
Options:
|
||||||
--backend {hnsw,diskann} Backend to use (default: hnsw)
|
--backend {hnsw,diskann} Backend to use (default: hnsw)
|
||||||
--embedding-model MODEL Embedding model (default: facebook/contriever)
|
--embedding-model MODEL Embedding model (default: facebook/contriever)
|
||||||
--graph-degree N Graph degree (default: 32)
|
--graph-degree N Graph degree (default: 32)
|
||||||
--complexity N Build complexity (default: 64)
|
--complexity N Build complexity (default: 64)
|
||||||
--force Force rebuild existing index
|
--force Force rebuild existing index
|
||||||
--compact Use compact storage (default: true)
|
--compact / --no-compact Use compact storage (default: true). Must be `no-compact` for `no-recompute` build.
|
||||||
--recompute Enable recomputation (default: true)
|
--recompute / --no-recompute Enable recomputation (default: true)
|
||||||
```
|
```
|
||||||
|
|
||||||
**Search Command:**
|
**Search Command:**
|
||||||
@@ -484,9 +653,9 @@ Options:
|
|||||||
leann search INDEX_NAME QUERY [OPTIONS]
|
leann search INDEX_NAME QUERY [OPTIONS]
|
||||||
|
|
||||||
Options:
|
Options:
|
||||||
--top-k N Number of results (default: 5)
|
--top-k N Number of results (default: 5)
|
||||||
--complexity N Search complexity (default: 64)
|
--complexity N Search complexity (default: 64)
|
||||||
--recompute-embeddings Use recomputation for highest accuracy
|
--recompute / --no-recompute Enable/disable embedding recomputation (default: enabled). Should not do a `no-recompute` search in a `recompute` build.
|
||||||
--pruning-strategy {global,local,proportional}
|
--pruning-strategy {global,local,proportional}
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -501,8 +670,73 @@ Options:
|
|||||||
--top-k N Retrieval count (default: 20)
|
--top-k N Retrieval count (default: 20)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**List Command:**
|
||||||
|
```bash
|
||||||
|
leann list
|
||||||
|
|
||||||
|
# Lists all indexes across all projects with status indicators:
|
||||||
|
# ✅ - Index is complete and ready to use
|
||||||
|
# ❌ - Index is incomplete or corrupted
|
||||||
|
# 📁 - CLI-created index (in .leann/indexes/)
|
||||||
|
# 📄 - App-created index (*.leann.meta.json files)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Remove Command:**
|
||||||
|
```bash
|
||||||
|
leann remove INDEX_NAME [OPTIONS]
|
||||||
|
|
||||||
|
Options:
|
||||||
|
--force, -f Force removal without confirmation
|
||||||
|
|
||||||
|
# Smart removal: automatically finds and safely removes indexes
|
||||||
|
# - Shows all matching indexes across projects
|
||||||
|
# - Requires confirmation for cross-project removal
|
||||||
|
# - Interactive selection when multiple matches found
|
||||||
|
# - Supports both CLI and app-created indexes
|
||||||
|
```
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
## 🚀 Advanced Features
|
||||||
|
|
||||||
|
### 🎯 Metadata Filtering
|
||||||
|
|
||||||
|
LEANN supports a simple metadata filtering system to enable sophisticated use cases like document filtering by date/type, code search by file extension, and content management based on custom criteria.
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Add metadata during indexing
|
||||||
|
builder.add_text(
|
||||||
|
"def authenticate_user(token): ...",
|
||||||
|
metadata={"file_extension": ".py", "lines_of_code": 25}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Search with filters
|
||||||
|
results = searcher.search(
|
||||||
|
query="authentication function",
|
||||||
|
metadata_filters={
|
||||||
|
"file_extension": {"==": ".py"},
|
||||||
|
"lines_of_code": {"<": 100}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Supported operators**: `==`, `!=`, `<`, `<=`, `>`, `>=`, `in`, `not_in`, `contains`, `starts_with`, `ends_with`, `is_true`, `is_false`
|
||||||
|
|
||||||
|
📖 **[Complete Metadata filtering guide →](docs/metadata_filtering.md)**
|
||||||
|
|
||||||
|
### 🔍 Grep Search
|
||||||
|
|
||||||
|
For exact text matching instead of semantic search, use the `use_grep` parameter:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Exact text search
|
||||||
|
results = searcher.search("banana‑crocodile", use_grep=True, top_k=1)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Use cases**: Finding specific code patterns, error messages, function names, or exact phrases where semantic similarity isn't needed.
|
||||||
|
|
||||||
|
📖 **[Complete grep search guide →](docs/grep_search.md)**
|
||||||
|
|
||||||
## 🏗️ Architecture & How It Works
|
## 🏗️ Architecture & How It Works
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
@@ -517,12 +751,16 @@ Options:
|
|||||||
- **Dynamic batching:** Efficiently batch embedding computations for GPU utilization
|
- **Dynamic batching:** Efficiently batch embedding computations for GPU utilization
|
||||||
- **Two-level search:** Smart graph traversal that prioritizes promising nodes
|
- **Two-level search:** Smart graph traversal that prioritizes promising nodes
|
||||||
|
|
||||||
**Backends:** HNSW (default) for most use cases, with optional DiskANN support for billion-scale datasets.
|
**Backends:**
|
||||||
|
- **HNSW** (default): Ideal for most datasets with maximum storage savings through full recomputation
|
||||||
|
- **DiskANN**: Advanced option with superior search performance, using PQ-based graph traversal with real-time reranking for the best speed-accuracy trade-off
|
||||||
|
|
||||||
## Benchmarks
|
## Benchmarks
|
||||||
|
|
||||||
|
**[DiskANN vs HNSW Performance Comparison →](benchmarks/diskann_vs_hnsw_speed_comparison.py)** - Compare search performance between both backends
|
||||||
|
|
||||||
|
**[Simple Example: Compare LEANN vs FAISS →](benchmarks/compare_faiss_vs_leann.py)** - See storage savings in action
|
||||||
|
|
||||||
**[Simple Example: Compare LEANN vs FAISS →](benchmarks/compare_faiss_vs_leann.py)**
|
|
||||||
### 📊 Storage Comparison
|
### 📊 Storage Comparison
|
||||||
|
|
||||||
| System | DPR (2.1M) | Wiki (60M) | Chat (400K) | Email (780K) | Browser (38K) |
|
| System | DPR (2.1M) | Wiki (60M) | Chat (400K) | Email (780K) | Browser (38K) |
|
||||||
@@ -536,8 +774,8 @@ Options:
|
|||||||
## Reproduce Our Results
|
## Reproduce Our Results
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
uv pip install -e ".[dev]" # Install dev dependencies
|
uv run benchmarks/run_evaluation.py # Will auto-download evaluation data and run benchmarks
|
||||||
python benchmarks/run_evaluation.py # Will auto-download evaluation data and run benchmarks
|
uv run benchmarks/run_evaluation.py benchmarks/data/indices/rpj_wiki/rpj_wiki --num-queries 2000 # After downloading data, you can run the benchmark with our biggest index
|
||||||
```
|
```
|
||||||
|
|
||||||
The evaluation script downloads data automatically on first run. The last three results were tested with partial personal data, and you can reproduce them with your own data!
|
The evaluation script downloads data automatically on first run. The last three results were tested with partial personal data, and you can reproduce them with your own data!
|
||||||
@@ -577,12 +815,16 @@ MIT License - see [LICENSE](LICENSE) for details.
|
|||||||
|
|
||||||
Core Contributors: [Yichuan Wang](https://yichuan-w.github.io/) & [Zhifei Li](https://github.com/andylizf).
|
Core Contributors: [Yichuan Wang](https://yichuan-w.github.io/) & [Zhifei Li](https://github.com/andylizf).
|
||||||
|
|
||||||
|
Active Contributors: [Gabriel Dehan](https://github.com/gabriel-dehan)
|
||||||
|
|
||||||
|
|
||||||
We welcome more contributors! Feel free to open issues or submit PRs.
|
We welcome more contributors! Feel free to open issues or submit PRs.
|
||||||
|
|
||||||
This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/).
|
This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/).
|
||||||
|
|
||||||
---
|
## Star History
|
||||||
|
|
||||||
|
[](https://www.star-history.com/#yichuan-w/LEANN&Date)
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<strong>⭐ Star us on GitHub if Leann is useful for your research or applications!</strong>
|
<strong>⭐ Star us on GitHub if Leann is useful for your research or applications!</strong>
|
||||||
</p>
|
</p>
|
||||||
|
|||||||
@@ -10,7 +10,8 @@ from typing import Any
|
|||||||
|
|
||||||
import dotenv
|
import dotenv
|
||||||
from leann.api import LeannBuilder, LeannChat
|
from leann.api import LeannBuilder, LeannChat
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
from leann.registry import register_project_directory
|
||||||
|
from leann.settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||||
|
|
||||||
dotenv.load_dotenv()
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
@@ -69,14 +70,32 @@ class BaseRAGExample(ABC):
|
|||||||
"--embedding-model",
|
"--embedding-model",
|
||||||
type=str,
|
type=str,
|
||||||
default=embedding_model_default,
|
default=embedding_model_default,
|
||||||
help=f"Embedding model to use (default: {embedding_model_default})",
|
help=f"Embedding model to use (default: {embedding_model_default}), we provide facebook/contriever, text-embedding-3-small,mlx-community/Qwen3-Embedding-0.6B-8bit or nomic-embed-text",
|
||||||
)
|
)
|
||||||
embedding_group.add_argument(
|
embedding_group.add_argument(
|
||||||
"--embedding-mode",
|
"--embedding-mode",
|
||||||
type=str,
|
type=str,
|
||||||
default="sentence-transformers",
|
default="sentence-transformers",
|
||||||
choices=["sentence-transformers", "openai", "mlx"],
|
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||||
help="Embedding backend mode (default: sentence-transformers)",
|
help="Embedding backend mode (default: sentence-transformers), we provide sentence-transformers, openai, mlx, or ollama",
|
||||||
|
)
|
||||||
|
embedding_group.add_argument(
|
||||||
|
"--embedding-host",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Override Ollama-compatible embedding host",
|
||||||
|
)
|
||||||
|
embedding_group.add_argument(
|
||||||
|
"--embedding-api-base",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Base URL for OpenAI-compatible embedding services",
|
||||||
|
)
|
||||||
|
embedding_group.add_argument(
|
||||||
|
"--embedding-api-key",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="API key for embedding service (defaults to OPENAI_API_KEY)",
|
||||||
)
|
)
|
||||||
|
|
||||||
# LLM parameters
|
# LLM parameters
|
||||||
@@ -85,20 +104,20 @@ class BaseRAGExample(ABC):
|
|||||||
"--llm",
|
"--llm",
|
||||||
type=str,
|
type=str,
|
||||||
default="openai",
|
default="openai",
|
||||||
choices=["openai", "ollama", "hf"],
|
choices=["openai", "ollama", "hf", "simulated"],
|
||||||
help="LLM backend to use (default: openai)",
|
help="LLM backend: openai, ollama, or hf (default: openai)",
|
||||||
)
|
)
|
||||||
llm_group.add_argument(
|
llm_group.add_argument(
|
||||||
"--llm-model",
|
"--llm-model",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="LLM model name (default: gpt-4o for openai, llama3.2:1b for ollama)",
|
help="Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct",
|
||||||
)
|
)
|
||||||
llm_group.add_argument(
|
llm_group.add_argument(
|
||||||
"--llm-host",
|
"--llm-host",
|
||||||
type=str,
|
type=str,
|
||||||
default="http://localhost:11434",
|
default=None,
|
||||||
help="Host for Ollama API (default: http://localhost:11434)",
|
help="Host for Ollama-compatible APIs (defaults to LEANN_OLLAMA_HOST/OLLAMA_HOST)",
|
||||||
)
|
)
|
||||||
llm_group.add_argument(
|
llm_group.add_argument(
|
||||||
"--thinking-budget",
|
"--thinking-budget",
|
||||||
@@ -107,6 +126,50 @@ class BaseRAGExample(ABC):
|
|||||||
default=None,
|
default=None,
|
||||||
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
|
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
|
||||||
)
|
)
|
||||||
|
llm_group.add_argument(
|
||||||
|
"--llm-api-base",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Base URL for OpenAI-compatible APIs",
|
||||||
|
)
|
||||||
|
llm_group.add_argument(
|
||||||
|
"--llm-api-key",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)",
|
||||||
|
)
|
||||||
|
|
||||||
|
# AST Chunking parameters
|
||||||
|
ast_group = parser.add_argument_group("AST Chunking Parameters")
|
||||||
|
ast_group.add_argument(
|
||||||
|
"--use-ast-chunking",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable AST-aware chunking for code files (requires astchunk)",
|
||||||
|
)
|
||||||
|
ast_group.add_argument(
|
||||||
|
"--ast-chunk-size",
|
||||||
|
type=int,
|
||||||
|
default=512,
|
||||||
|
help="Maximum characters per AST chunk (default: 512)",
|
||||||
|
)
|
||||||
|
ast_group.add_argument(
|
||||||
|
"--ast-chunk-overlap",
|
||||||
|
type=int,
|
||||||
|
default=64,
|
||||||
|
help="Overlap between AST chunks (default: 64)",
|
||||||
|
)
|
||||||
|
ast_group.add_argument(
|
||||||
|
"--code-file-extensions",
|
||||||
|
nargs="+",
|
||||||
|
default=None,
|
||||||
|
help="Additional code file extensions to process with AST chunking (e.g., .py .java .cs .ts)",
|
||||||
|
)
|
||||||
|
ast_group.add_argument(
|
||||||
|
"--ast-fallback-traditional",
|
||||||
|
action="store_true",
|
||||||
|
default=True,
|
||||||
|
help="Fall back to traditional chunking if AST chunking fails (default: True)",
|
||||||
|
)
|
||||||
|
|
||||||
# Search parameters
|
# Search parameters
|
||||||
search_group = parser.add_argument_group("Search Parameters")
|
search_group = parser.add_argument_group("Search Parameters")
|
||||||
@@ -173,11 +236,18 @@ class BaseRAGExample(ABC):
|
|||||||
|
|
||||||
if args.llm == "openai":
|
if args.llm == "openai":
|
||||||
config["model"] = args.llm_model or "gpt-4o"
|
config["model"] = args.llm_model or "gpt-4o"
|
||||||
|
config["base_url"] = resolve_openai_base_url(args.llm_api_base)
|
||||||
|
resolved_key = resolve_openai_api_key(args.llm_api_key)
|
||||||
|
if resolved_key:
|
||||||
|
config["api_key"] = resolved_key
|
||||||
elif args.llm == "ollama":
|
elif args.llm == "ollama":
|
||||||
config["model"] = args.llm_model or "llama3.2:1b"
|
config["model"] = args.llm_model or "llama3.2:1b"
|
||||||
config["host"] = args.llm_host
|
config["host"] = resolve_ollama_host(args.llm_host)
|
||||||
elif args.llm == "hf":
|
elif args.llm == "hf":
|
||||||
config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
|
config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
|
||||||
|
elif args.llm == "simulated":
|
||||||
|
# Simulated LLM doesn't need additional configuration
|
||||||
|
pass
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
@@ -188,10 +258,20 @@ class BaseRAGExample(ABC):
|
|||||||
print(f"\n[Building Index] Creating {self.name} index...")
|
print(f"\n[Building Index] Creating {self.name} index...")
|
||||||
print(f"Total text chunks: {len(texts)}")
|
print(f"Total text chunks: {len(texts)}")
|
||||||
|
|
||||||
|
embedding_options: dict[str, Any] = {}
|
||||||
|
if args.embedding_mode == "ollama":
|
||||||
|
embedding_options["host"] = resolve_ollama_host(args.embedding_host)
|
||||||
|
elif args.embedding_mode == "openai":
|
||||||
|
embedding_options["base_url"] = resolve_openai_base_url(args.embedding_api_base)
|
||||||
|
resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
|
||||||
|
if resolved_embedding_key:
|
||||||
|
embedding_options["api_key"] = resolved_embedding_key
|
||||||
|
|
||||||
builder = LeannBuilder(
|
builder = LeannBuilder(
|
||||||
backend_name=args.backend_name,
|
backend_name=args.backend_name,
|
||||||
embedding_model=args.embedding_model,
|
embedding_model=args.embedding_model,
|
||||||
embedding_mode=args.embedding_mode,
|
embedding_mode=args.embedding_mode,
|
||||||
|
embedding_options=embedding_options or None,
|
||||||
graph_degree=args.graph_degree,
|
graph_degree=args.graph_degree,
|
||||||
complexity=args.build_complexity,
|
complexity=args.build_complexity,
|
||||||
is_compact=not args.no_compact,
|
is_compact=not args.no_compact,
|
||||||
@@ -211,6 +291,11 @@ class BaseRAGExample(ABC):
|
|||||||
builder.build_index(index_path)
|
builder.build_index(index_path)
|
||||||
print(f"Index saved to: {index_path}")
|
print(f"Index saved to: {index_path}")
|
||||||
|
|
||||||
|
# Register project directory so leann list can discover this index
|
||||||
|
# The index is saved as args.index_dir/index_name.leann
|
||||||
|
# We want to register the current working directory where the app is run
|
||||||
|
register_project_directory(Path.cwd())
|
||||||
|
|
||||||
return index_path
|
return index_path
|
||||||
|
|
||||||
async def run_interactive_chat(self, args, index_path: str):
|
async def run_interactive_chat(self, args, index_path: str):
|
||||||
@@ -259,7 +344,6 @@ class BaseRAGExample(ABC):
|
|||||||
chat = LeannChat(
|
chat = LeannChat(
|
||||||
index_path,
|
index_path,
|
||||||
llm_config=self.get_llm_config(args),
|
llm_config=self.get_llm_config(args),
|
||||||
system_prompt=f"You are a helpful assistant that answers questions about {self.name} data.",
|
|
||||||
complexity=args.search_complexity,
|
complexity=args.search_complexity,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -301,21 +385,3 @@ class BaseRAGExample(ABC):
|
|||||||
await self.run_single_query(args, index_path, args.query)
|
await self.run_single_query(args, index_path, args.query)
|
||||||
else:
|
else:
|
||||||
await self.run_interactive_chat(args, index_path)
|
await self.run_interactive_chat(args, index_path)
|
||||||
|
|
||||||
|
|
||||||
def create_text_chunks(documents, chunk_size=256, chunk_overlap=25) -> list[str]:
|
|
||||||
"""Helper function to create text chunks from documents."""
|
|
||||||
node_parser = SentenceSplitter(
|
|
||||||
chunk_size=chunk_size,
|
|
||||||
chunk_overlap=chunk_overlap,
|
|
||||||
separator=" ",
|
|
||||||
paragraph_separator="\n\n",
|
|
||||||
)
|
|
||||||
|
|
||||||
all_texts = []
|
|
||||||
for doc in documents:
|
|
||||||
nodes = node_parser.get_nodes_from_documents([doc])
|
|
||||||
if nodes:
|
|
||||||
all_texts.extend(node.get_content() for node in nodes)
|
|
||||||
|
|
||||||
return all_texts
|
|
||||||
|
|||||||
@@ -10,7 +10,8 @@ from pathlib import Path
|
|||||||
# Add parent directory to path for imports
|
# Add parent directory to path for imports
|
||||||
sys.path.insert(0, str(Path(__file__).parent))
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
from base_rag_example import BaseRAGExample, create_text_chunks
|
from base_rag_example import BaseRAGExample
|
||||||
|
from chunking import create_text_chunks
|
||||||
|
|
||||||
from .history_data.history import ChromeHistoryReader
|
from .history_data.history import ChromeHistoryReader
|
||||||
|
|
||||||
|
|||||||
44
apps/chunking/__init__.py
Normal file
44
apps/chunking/__init__.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
"""Unified chunking utilities facade.
|
||||||
|
|
||||||
|
This module re-exports the packaged utilities from `leann.chunking_utils` so
|
||||||
|
that both repo apps (importing `chunking`) and installed wheels share one
|
||||||
|
single implementation. When running from the repo without installation, it
|
||||||
|
adds the `packages/leann-core/src` directory to `sys.path` as a fallback.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
try:
|
||||||
|
from leann.chunking_utils import (
|
||||||
|
CODE_EXTENSIONS,
|
||||||
|
create_ast_chunks,
|
||||||
|
create_text_chunks,
|
||||||
|
create_traditional_chunks,
|
||||||
|
detect_code_files,
|
||||||
|
get_language_from_extension,
|
||||||
|
)
|
||||||
|
except Exception: # pragma: no cover - best-effort fallback for dev environment
|
||||||
|
repo_root = Path(__file__).resolve().parents[2]
|
||||||
|
leann_src = repo_root / "packages" / "leann-core" / "src"
|
||||||
|
if leann_src.exists():
|
||||||
|
sys.path.insert(0, str(leann_src))
|
||||||
|
from leann.chunking_utils import (
|
||||||
|
CODE_EXTENSIONS,
|
||||||
|
create_ast_chunks,
|
||||||
|
create_text_chunks,
|
||||||
|
create_traditional_chunks,
|
||||||
|
detect_code_files,
|
||||||
|
get_language_from_extension,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"CODE_EXTENSIONS",
|
||||||
|
"create_ast_chunks",
|
||||||
|
"create_text_chunks",
|
||||||
|
"create_traditional_chunks",
|
||||||
|
"detect_code_files",
|
||||||
|
"get_language_from_extension",
|
||||||
|
]
|
||||||
211
apps/code_rag.py
Normal file
211
apps/code_rag.py
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
"""
|
||||||
|
Code RAG example using AST-aware chunking for optimal code understanding.
|
||||||
|
Specialized for code repositories with automatic language detection and
|
||||||
|
optimized chunking parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add parent directory to path for imports
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
|
from base_rag_example import BaseRAGExample
|
||||||
|
from chunking import CODE_EXTENSIONS, create_text_chunks
|
||||||
|
from llama_index.core import SimpleDirectoryReader
|
||||||
|
|
||||||
|
|
||||||
|
class CodeRAG(BaseRAGExample):
|
||||||
|
"""Specialized RAG example for code repositories with AST-aware chunking."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
name="Code",
|
||||||
|
description="Process and query code repositories with AST-aware chunking",
|
||||||
|
default_index_name="code_index",
|
||||||
|
)
|
||||||
|
# Override defaults for code-specific usage
|
||||||
|
self.embedding_model_default = "facebook/contriever" # Good for code
|
||||||
|
self.max_items_default = -1 # Process all code files by default
|
||||||
|
|
||||||
|
def _add_specific_arguments(self, parser):
|
||||||
|
"""Add code-specific arguments."""
|
||||||
|
code_group = parser.add_argument_group("Code Repository Parameters")
|
||||||
|
|
||||||
|
code_group.add_argument(
|
||||||
|
"--repo-dir",
|
||||||
|
type=str,
|
||||||
|
default=".",
|
||||||
|
help="Code repository directory to index (default: current directory)",
|
||||||
|
)
|
||||||
|
code_group.add_argument(
|
||||||
|
"--include-extensions",
|
||||||
|
nargs="+",
|
||||||
|
default=list(CODE_EXTENSIONS.keys()),
|
||||||
|
help="File extensions to include (default: supported code extensions)",
|
||||||
|
)
|
||||||
|
code_group.add_argument(
|
||||||
|
"--exclude-dirs",
|
||||||
|
nargs="+",
|
||||||
|
default=[
|
||||||
|
".git",
|
||||||
|
"__pycache__",
|
||||||
|
"node_modules",
|
||||||
|
"venv",
|
||||||
|
".venv",
|
||||||
|
"build",
|
||||||
|
"dist",
|
||||||
|
"target",
|
||||||
|
],
|
||||||
|
help="Directories to exclude from indexing",
|
||||||
|
)
|
||||||
|
code_group.add_argument(
|
||||||
|
"--max-file-size",
|
||||||
|
type=int,
|
||||||
|
default=1000000, # 1MB
|
||||||
|
help="Maximum file size in bytes to process (default: 1MB)",
|
||||||
|
)
|
||||||
|
code_group.add_argument(
|
||||||
|
"--include-comments",
|
||||||
|
action="store_true",
|
||||||
|
help="Include comments in chunking (useful for documentation)",
|
||||||
|
)
|
||||||
|
code_group.add_argument(
|
||||||
|
"--preserve-imports",
|
||||||
|
action="store_true",
|
||||||
|
default=True,
|
||||||
|
help="Try to preserve import statements in chunks (default: True)",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def load_data(self, args) -> list[str]:
|
||||||
|
"""Load code files and convert to AST-aware chunks."""
|
||||||
|
print(f"🔍 Scanning code repository: {args.repo_dir}")
|
||||||
|
print(f"📁 Including extensions: {args.include_extensions}")
|
||||||
|
print(f"🚫 Excluding directories: {args.exclude_dirs}")
|
||||||
|
|
||||||
|
# Check if repository directory exists
|
||||||
|
repo_path = Path(args.repo_dir)
|
||||||
|
if not repo_path.exists():
|
||||||
|
raise ValueError(f"Repository directory not found: {args.repo_dir}")
|
||||||
|
|
||||||
|
# Load code files with filtering
|
||||||
|
reader_kwargs = {
|
||||||
|
"recursive": True,
|
||||||
|
"encoding": "utf-8",
|
||||||
|
"required_exts": args.include_extensions,
|
||||||
|
"exclude_hidden": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create exclusion filter
|
||||||
|
def file_filter(file_path: str) -> bool:
|
||||||
|
"""Filter out unwanted files and directories."""
|
||||||
|
path = Path(file_path)
|
||||||
|
|
||||||
|
# Check file size
|
||||||
|
try:
|
||||||
|
if path.stat().st_size > args.max_file_size:
|
||||||
|
print(f"⚠️ Skipping large file: {path.name} ({path.stat().st_size} bytes)")
|
||||||
|
return False
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if in excluded directory
|
||||||
|
for exclude_dir in args.exclude_dirs:
|
||||||
|
if exclude_dir in path.parts:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Load documents with file filtering
|
||||||
|
documents = SimpleDirectoryReader(
|
||||||
|
args.repo_dir,
|
||||||
|
file_extractor=None, # Use default extractors
|
||||||
|
**reader_kwargs,
|
||||||
|
).load_data(show_progress=True)
|
||||||
|
|
||||||
|
# Apply custom filtering
|
||||||
|
filtered_docs = []
|
||||||
|
for doc in documents:
|
||||||
|
file_path = doc.metadata.get("file_path", "")
|
||||||
|
if file_filter(file_path):
|
||||||
|
filtered_docs.append(doc)
|
||||||
|
|
||||||
|
documents = filtered_docs
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Error loading code files: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
if not documents:
|
||||||
|
print(
|
||||||
|
f"❌ No code files found in {args.repo_dir} with extensions {args.include_extensions}"
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
|
||||||
|
print(f"✅ Loaded {len(documents)} code files")
|
||||||
|
|
||||||
|
# Show breakdown by language/extension
|
||||||
|
ext_counts = {}
|
||||||
|
for doc in documents:
|
||||||
|
file_path = doc.metadata.get("file_path", "")
|
||||||
|
if file_path:
|
||||||
|
ext = Path(file_path).suffix.lower()
|
||||||
|
ext_counts[ext] = ext_counts.get(ext, 0) + 1
|
||||||
|
|
||||||
|
print("📊 Files by extension:")
|
||||||
|
for ext, count in sorted(ext_counts.items()):
|
||||||
|
print(f" {ext}: {count} files")
|
||||||
|
|
||||||
|
# Use AST-aware chunking by default for code
|
||||||
|
print(
|
||||||
|
f"🧠 Using AST-aware chunking (chunk_size: {args.ast_chunk_size}, overlap: {args.ast_chunk_overlap})"
|
||||||
|
)
|
||||||
|
|
||||||
|
all_texts = create_text_chunks(
|
||||||
|
documents,
|
||||||
|
chunk_size=256, # Fallback for non-code files
|
||||||
|
chunk_overlap=64,
|
||||||
|
use_ast_chunking=True, # Always use AST for code RAG
|
||||||
|
ast_chunk_size=args.ast_chunk_size,
|
||||||
|
ast_chunk_overlap=args.ast_chunk_overlap,
|
||||||
|
code_file_extensions=args.include_extensions,
|
||||||
|
ast_fallback_traditional=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply max_items limit if specified
|
||||||
|
if args.max_items > 0 and len(all_texts) > args.max_items:
|
||||||
|
print(f"⏳ Limiting to {args.max_items} chunks (from {len(all_texts)})")
|
||||||
|
all_texts = all_texts[: args.max_items]
|
||||||
|
|
||||||
|
print(f"✅ Generated {len(all_texts)} code chunks")
|
||||||
|
return all_texts
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# Example queries for code RAG
|
||||||
|
print("\n💻 Code RAG Example")
|
||||||
|
print("=" * 50)
|
||||||
|
print("\nExample queries you can try:")
|
||||||
|
print("- 'How does the embedding computation work?'")
|
||||||
|
print("- 'What are the main classes in this codebase?'")
|
||||||
|
print("- 'Show me the search implementation'")
|
||||||
|
print("- 'How is error handling implemented?'")
|
||||||
|
print("- 'What design patterns are used?'")
|
||||||
|
print("- 'Explain the chunking logic'")
|
||||||
|
print("\n🚀 Features:")
|
||||||
|
print("- ✅ AST-aware chunking preserves code structure")
|
||||||
|
print("- ✅ Automatic language detection")
|
||||||
|
print("- ✅ Smart filtering of large files and common excludes")
|
||||||
|
print("- ✅ Optimized for code understanding")
|
||||||
|
print("\nUsage examples:")
|
||||||
|
print(" python -m apps.code_rag --repo-dir ./my_project")
|
||||||
|
print(
|
||||||
|
" python -m apps.code_rag --include-extensions .py .js --query 'How does authentication work?'"
|
||||||
|
)
|
||||||
|
print("\nOr run without --query for interactive mode\n")
|
||||||
|
|
||||||
|
rag = CodeRAG()
|
||||||
|
asyncio.run(rag.run())
|
||||||
@@ -9,7 +9,8 @@ from pathlib import Path
|
|||||||
# Add parent directory to path for imports
|
# Add parent directory to path for imports
|
||||||
sys.path.insert(0, str(Path(__file__).parent))
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
from base_rag_example import BaseRAGExample, create_text_chunks
|
from base_rag_example import BaseRAGExample
|
||||||
|
from chunking import create_text_chunks
|
||||||
from llama_index.core import SimpleDirectoryReader
|
from llama_index.core import SimpleDirectoryReader
|
||||||
|
|
||||||
|
|
||||||
@@ -44,6 +45,11 @@ class DocumentRAG(BaseRAGExample):
|
|||||||
doc_group.add_argument(
|
doc_group.add_argument(
|
||||||
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
||||||
)
|
)
|
||||||
|
doc_group.add_argument(
|
||||||
|
"--enable-code-chunking",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable AST-aware chunking for code files in the data directory",
|
||||||
|
)
|
||||||
|
|
||||||
async def load_data(self, args) -> list[str]:
|
async def load_data(self, args) -> list[str]:
|
||||||
"""Load documents and convert to text chunks."""
|
"""Load documents and convert to text chunks."""
|
||||||
@@ -76,9 +82,22 @@ class DocumentRAG(BaseRAGExample):
|
|||||||
|
|
||||||
print(f"Loaded {len(documents)} documents")
|
print(f"Loaded {len(documents)} documents")
|
||||||
|
|
||||||
# Convert to text chunks
|
# Determine chunking strategy
|
||||||
|
use_ast = args.enable_code_chunking or getattr(args, "use_ast_chunking", False)
|
||||||
|
|
||||||
|
if use_ast:
|
||||||
|
print("Using AST-aware chunking for code files")
|
||||||
|
|
||||||
|
# Convert to text chunks with optional AST support
|
||||||
all_texts = create_text_chunks(
|
all_texts = create_text_chunks(
|
||||||
documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
documents,
|
||||||
|
chunk_size=args.chunk_size,
|
||||||
|
chunk_overlap=args.chunk_overlap,
|
||||||
|
use_ast_chunking=use_ast,
|
||||||
|
ast_chunk_size=getattr(args, "ast_chunk_size", 512),
|
||||||
|
ast_chunk_overlap=getattr(args, "ast_chunk_overlap", 64),
|
||||||
|
code_file_extensions=getattr(args, "code_file_extensions", None),
|
||||||
|
ast_fallback_traditional=getattr(args, "ast_fallback_traditional", True),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply max_items limit if specified
|
# Apply max_items limit if specified
|
||||||
@@ -102,6 +121,10 @@ if __name__ == "__main__":
|
|||||||
print(
|
print(
|
||||||
"- 'What is the problem of developing pan gu model Huawei meets? (盘古大模型开发中遇到什么问题?)'"
|
"- 'What is the problem of developing pan gu model Huawei meets? (盘古大模型开发中遇到什么问题?)'"
|
||||||
)
|
)
|
||||||
|
print("\n🚀 NEW: Code-aware chunking available!")
|
||||||
|
print("- Use --enable-code-chunking to enable AST-aware chunking for code files")
|
||||||
|
print("- Supports Python, Java, C#, TypeScript files")
|
||||||
|
print("- Better semantic understanding of code structure")
|
||||||
print("\nOr run without --query for interactive mode\n")
|
print("\nOr run without --query for interactive mode\n")
|
||||||
|
|
||||||
rag = DocumentRAG()
|
rag = DocumentRAG()
|
||||||
|
|||||||
@@ -9,7 +9,8 @@ from pathlib import Path
|
|||||||
# Add parent directory to path for imports
|
# Add parent directory to path for imports
|
||||||
sys.path.insert(0, str(Path(__file__).parent))
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
from base_rag_example import BaseRAGExample, create_text_chunks
|
from base_rag_example import BaseRAGExample
|
||||||
|
from chunking import create_text_chunks
|
||||||
|
|
||||||
from .email_data.LEANN_email_reader import EmlxReader
|
from .email_data.LEANN_email_reader import EmlxReader
|
||||||
|
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ class ChromeHistoryReader(BaseReader):
|
|||||||
if count >= max_count and max_count > 0:
|
if count >= max_count and max_count > 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
last_visit, url, title, visit_count, typed_count, hidden = row
|
last_visit, url, title, visit_count, typed_count, _hidden = row
|
||||||
|
|
||||||
# Create document content with metadata embedded in text
|
# Create document content with metadata embedded in text
|
||||||
doc_content = f"""
|
doc_content = f"""
|
||||||
|
|||||||
@@ -0,0 +1,182 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_repo_paths_importable(current_file: str) -> None:
|
||||||
|
_repo_root = Path(current_file).resolve().parents[3]
|
||||||
|
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
|
||||||
|
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
|
||||||
|
if str(_leann_core_src) not in sys.path:
|
||||||
|
sys.path.append(str(_leann_core_src))
|
||||||
|
if str(_leann_hnsw_pkg) not in sys.path:
|
||||||
|
sys.path.append(str(_leann_hnsw_pkg))
|
||||||
|
|
||||||
|
|
||||||
|
_ensure_repo_paths_importable(__file__)
|
||||||
|
|
||||||
|
from leann_backend_hnsw.hnsw_backend import HNSWBuilder, HNSWSearcher # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
class LeannMultiVector:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
index_path: str,
|
||||||
|
dim: int = 128,
|
||||||
|
distance_metric: str = "mips",
|
||||||
|
m: int = 16,
|
||||||
|
ef_construction: int = 500,
|
||||||
|
is_compact: bool = False,
|
||||||
|
is_recompute: bool = False,
|
||||||
|
embedding_model_name: str = "colvision",
|
||||||
|
) -> None:
|
||||||
|
self.index_path = index_path
|
||||||
|
self.dim = dim
|
||||||
|
self.embedding_model_name = embedding_model_name
|
||||||
|
self._pending_items: list[dict] = []
|
||||||
|
self._backend_kwargs = {
|
||||||
|
"distance_metric": distance_metric,
|
||||||
|
"M": m,
|
||||||
|
"efConstruction": ef_construction,
|
||||||
|
"is_compact": is_compact,
|
||||||
|
"is_recompute": is_recompute,
|
||||||
|
}
|
||||||
|
self._labels_meta: list[dict] = []
|
||||||
|
|
||||||
|
def _meta_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"version": "1.0",
|
||||||
|
"backend_name": "hnsw",
|
||||||
|
"embedding_model": self.embedding_model_name,
|
||||||
|
"embedding_mode": "custom",
|
||||||
|
"dimensions": self.dim,
|
||||||
|
"backend_kwargs": self._backend_kwargs,
|
||||||
|
"is_compact": self._backend_kwargs.get("is_compact", True),
|
||||||
|
"is_pruned": self._backend_kwargs.get("is_compact", True)
|
||||||
|
and self._backend_kwargs.get("is_recompute", True),
|
||||||
|
}
|
||||||
|
|
||||||
|
def create_collection(self) -> None:
|
||||||
|
path = Path(self.index_path)
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
def insert(self, data: dict) -> None:
|
||||||
|
self._pending_items.append(
|
||||||
|
{
|
||||||
|
"doc_id": int(data["doc_id"]),
|
||||||
|
"filepath": data.get("filepath", ""),
|
||||||
|
"colbert_vecs": [np.asarray(v, dtype=np.float32) for v in data["colbert_vecs"]],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def _labels_path(self) -> Path:
|
||||||
|
index_path_obj = Path(self.index_path)
|
||||||
|
return index_path_obj.parent / f"{index_path_obj.name}.labels.json"
|
||||||
|
|
||||||
|
def _meta_path(self) -> Path:
|
||||||
|
index_path_obj = Path(self.index_path)
|
||||||
|
return index_path_obj.parent / f"{index_path_obj.name}.meta.json"
|
||||||
|
|
||||||
|
def create_index(self) -> None:
|
||||||
|
if not self._pending_items:
|
||||||
|
return
|
||||||
|
|
||||||
|
embeddings: list[np.ndarray] = []
|
||||||
|
labels_meta: list[dict] = []
|
||||||
|
|
||||||
|
for item in self._pending_items:
|
||||||
|
doc_id = int(item["doc_id"])
|
||||||
|
filepath = item.get("filepath", "")
|
||||||
|
colbert_vecs = item["colbert_vecs"]
|
||||||
|
for seq_id, vec in enumerate(colbert_vecs):
|
||||||
|
vec_np = np.asarray(vec, dtype=np.float32)
|
||||||
|
embeddings.append(vec_np)
|
||||||
|
labels_meta.append(
|
||||||
|
{
|
||||||
|
"id": f"{doc_id}:{seq_id}",
|
||||||
|
"doc_id": doc_id,
|
||||||
|
"seq_id": int(seq_id),
|
||||||
|
"filepath": filepath,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if not embeddings:
|
||||||
|
return
|
||||||
|
|
||||||
|
embeddings_np = np.vstack(embeddings).astype(np.float32)
|
||||||
|
# print shape of embeddings_np
|
||||||
|
print(embeddings_np.shape)
|
||||||
|
|
||||||
|
builder = HNSWBuilder(**{**self._backend_kwargs, "dimensions": self.dim})
|
||||||
|
ids = [str(i) for i in range(embeddings_np.shape[0])]
|
||||||
|
builder.build(embeddings_np, ids, self.index_path)
|
||||||
|
|
||||||
|
import json as _json
|
||||||
|
|
||||||
|
with open(self._meta_path(), "w", encoding="utf-8") as f:
|
||||||
|
_json.dump(self._meta_dict(), f, indent=2)
|
||||||
|
with open(self._labels_path(), "w", encoding="utf-8") as f:
|
||||||
|
_json.dump(labels_meta, f)
|
||||||
|
|
||||||
|
self._labels_meta = labels_meta
|
||||||
|
|
||||||
|
def _load_labels_meta_if_needed(self) -> None:
|
||||||
|
if self._labels_meta:
|
||||||
|
return
|
||||||
|
labels_path = self._labels_path()
|
||||||
|
if labels_path.exists():
|
||||||
|
import json as _json
|
||||||
|
|
||||||
|
with open(labels_path, encoding="utf-8") as f:
|
||||||
|
self._labels_meta = _json.load(f)
|
||||||
|
|
||||||
|
def search(
|
||||||
|
self, data: np.ndarray, topk: int, first_stage_k: int = 50
|
||||||
|
) -> list[tuple[float, int]]:
|
||||||
|
if data.ndim == 1:
|
||||||
|
data = data.reshape(1, -1)
|
||||||
|
if data.dtype != np.float32:
|
||||||
|
data = data.astype(np.float32)
|
||||||
|
|
||||||
|
self._load_labels_meta_if_needed()
|
||||||
|
|
||||||
|
searcher = HNSWSearcher(self.index_path, meta=self._meta_dict())
|
||||||
|
raw = searcher.search(
|
||||||
|
data,
|
||||||
|
first_stage_k,
|
||||||
|
recompute_embeddings=False,
|
||||||
|
complexity=128,
|
||||||
|
beam_width=1,
|
||||||
|
prune_ratio=0.0,
|
||||||
|
batch_size=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
labels = raw.get("labels")
|
||||||
|
distances = raw.get("distances")
|
||||||
|
if labels is None or distances is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
doc_scores: dict[int, float] = {}
|
||||||
|
B = len(labels)
|
||||||
|
for b in range(B):
|
||||||
|
per_doc_best: dict[int, float] = {}
|
||||||
|
for k, sid in enumerate(labels[b]):
|
||||||
|
try:
|
||||||
|
idx = int(sid)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
if 0 <= idx < len(self._labels_meta):
|
||||||
|
doc_id = int(self._labels_meta[idx]["doc_id"]) # type: ignore[index]
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
score = float(distances[b][k])
|
||||||
|
if (doc_id not in per_doc_best) or (score > per_doc_best[doc_id]):
|
||||||
|
per_doc_best[doc_id] = score
|
||||||
|
for doc_id, best_score in per_doc_best.items():
|
||||||
|
doc_scores[doc_id] = doc_scores.get(doc_id, 0.0) + best_score
|
||||||
|
|
||||||
|
scores = sorted(((v, k) for k, v in doc_scores.items()), key=lambda x: x[0], reverse=True)
|
||||||
|
return scores[:topk] if len(scores) >= topk else scores
|
||||||
@@ -0,0 +1,477 @@
|
|||||||
|
## Jupyter-style notebook script
|
||||||
|
# %%
|
||||||
|
# uv pip install matplotlib qwen_vl_utils
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_repo_paths_importable(current_file: str) -> None:
|
||||||
|
"""Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
|
||||||
|
_repo_root = Path(current_file).resolve().parents[3]
|
||||||
|
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
|
||||||
|
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
|
||||||
|
if str(_leann_core_src) not in sys.path:
|
||||||
|
sys.path.append(str(_leann_core_src))
|
||||||
|
if str(_leann_hnsw_pkg) not in sys.path:
|
||||||
|
sys.path.append(str(_leann_hnsw_pkg))
|
||||||
|
|
||||||
|
|
||||||
|
_ensure_repo_paths_importable(__file__)
|
||||||
|
|
||||||
|
from leann_multi_vector import LeannMultiVector # noqa: E402
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Config
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"
|
||||||
|
MODEL: str = "colqwen2" # "colpali" or "colqwen2"
|
||||||
|
|
||||||
|
# Data source: set to True to use the Hugging Face dataset example (recommended)
|
||||||
|
USE_HF_DATASET: bool = True
|
||||||
|
DATASET_NAME: str = "weaviate/arXiv-AI-papers-multi-vector"
|
||||||
|
DATASET_SPLIT: str = "train"
|
||||||
|
MAX_DOCS: Optional[int] = None # limit number of pages to index; None = all
|
||||||
|
|
||||||
|
# Local pages (used when USE_HF_DATASET == False)
|
||||||
|
PDF: Optional[str] = None # e.g., "./pdfs/2004.12832v2.pdf"
|
||||||
|
PAGES_DIR: str = "./pages"
|
||||||
|
|
||||||
|
# Index + retrieval settings
|
||||||
|
INDEX_PATH: str = "./indexes/colvision.leann"
|
||||||
|
TOPK: int = 1
|
||||||
|
FIRST_STAGE_K: int = 500
|
||||||
|
REBUILD_INDEX: bool = False
|
||||||
|
|
||||||
|
# Artifacts
|
||||||
|
SAVE_TOP_IMAGE: Optional[str] = "./figures/retrieved_page.png"
|
||||||
|
SIMILARITY_MAP: bool = True
|
||||||
|
SIM_TOKEN_IDX: int = 13 # -1 means auto-select the most salient token
|
||||||
|
SIM_OUTPUT: str = "./figures/similarity_map.png"
|
||||||
|
ANSWER: bool = True
|
||||||
|
MAX_NEW_TOKENS: int = 128
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Helpers
|
||||||
|
def _natural_sort_key(name: str) -> int:
|
||||||
|
m = re.search(r"\d+", name)
|
||||||
|
return int(m.group()) if m else 0
|
||||||
|
|
||||||
|
|
||||||
|
def _load_images_from_dir(pages_dir: str) -> tuple[list[str], list[Image.Image]]:
|
||||||
|
filenames = [n for n in os.listdir(pages_dir) if n.lower().endswith((".png", ".jpg", ".jpeg"))]
|
||||||
|
filenames = sorted(filenames, key=_natural_sort_key)
|
||||||
|
filepaths = [os.path.join(pages_dir, n) for n in filenames]
|
||||||
|
images = [Image.open(p) for p in filepaths]
|
||||||
|
return filepaths, images
|
||||||
|
|
||||||
|
|
||||||
|
def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None:
|
||||||
|
if not pdf_path:
|
||||||
|
return
|
||||||
|
os.makedirs(pages_dir, exist_ok=True)
|
||||||
|
try:
|
||||||
|
from pdf2image import convert_from_path
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
"pdf2image is required to convert PDF to images. Install via pip install pdf2image"
|
||||||
|
) from e
|
||||||
|
images = convert_from_path(pdf_path, dpi=dpi)
|
||||||
|
for i, image in enumerate(images):
|
||||||
|
image.save(os.path.join(pages_dir, f"page_{i + 1}.png"), "PNG")
|
||||||
|
|
||||||
|
|
||||||
|
def _select_device_and_dtype():
|
||||||
|
import torch
|
||||||
|
from colpali_engine.utils.torch_utils import get_torch_device
|
||||||
|
|
||||||
|
device_str = (
|
||||||
|
"cuda"
|
||||||
|
if torch.cuda.is_available()
|
||||||
|
else (
|
||||||
|
"mps"
|
||||||
|
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
|
||||||
|
else "cpu"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
device = get_torch_device(device_str)
|
||||||
|
# Stable dtype selection to avoid NaNs:
|
||||||
|
# - CUDA: prefer bfloat16 if supported, else float16
|
||||||
|
# - MPS: use float32 (fp16 on MPS can produce NaNs in some ops)
|
||||||
|
# - CPU: float32
|
||||||
|
if device_str == "cuda":
|
||||||
|
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
||||||
|
try:
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True # Better stability/perf on Ampere+
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
elif device_str == "mps":
|
||||||
|
dtype = torch.float32
|
||||||
|
else:
|
||||||
|
dtype = torch.float32
|
||||||
|
return device_str, device, dtype
|
||||||
|
|
||||||
|
|
||||||
|
def _load_colvision(model_choice: str):
|
||||||
|
import torch
|
||||||
|
from colpali_engine.models import ColPali, ColQwen2, ColQwen2Processor
|
||||||
|
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
|
||||||
|
from transformers.utils.import_utils import is_flash_attn_2_available
|
||||||
|
|
||||||
|
device_str, device, dtype = _select_device_and_dtype()
|
||||||
|
|
||||||
|
if model_choice == "colqwen2":
|
||||||
|
model_name = "vidore/colqwen2-v1.0"
|
||||||
|
# On CPU/MPS we must avoid flash-attn and stay eager; on CUDA prefer flash-attn if available
|
||||||
|
attn_implementation = (
|
||||||
|
"flash_attention_2"
|
||||||
|
if (device_str == "cuda" and is_flash_attn_2_available())
|
||||||
|
else "eager"
|
||||||
|
)
|
||||||
|
model = ColQwen2.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device_map=device,
|
||||||
|
attn_implementation=attn_implementation,
|
||||||
|
).eval()
|
||||||
|
processor = ColQwen2Processor.from_pretrained(model_name)
|
||||||
|
else:
|
||||||
|
model_name = "vidore/colpali-v1.2"
|
||||||
|
model = ColPali.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device_map=device,
|
||||||
|
).eval()
|
||||||
|
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
|
||||||
|
|
||||||
|
return model_name, model, processor, device_str, device, dtype
|
||||||
|
|
||||||
|
|
||||||
|
def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
|
||||||
|
import torch
|
||||||
|
from colpali_engine.utils.torch_utils import ListDataset
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
# Ensure deterministic eval and autocast for stability
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset=ListDataset[Image.Image](images),
|
||||||
|
batch_size=1,
|
||||||
|
shuffle=False,
|
||||||
|
collate_fn=lambda x: processor.process_images(x),
|
||||||
|
)
|
||||||
|
|
||||||
|
doc_vecs: list[Any] = []
|
||||||
|
for batch_doc in dataloader:
|
||||||
|
with torch.no_grad():
|
||||||
|
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
|
||||||
|
# autocast on CUDA for bf16/fp16; on CPU/MPS stay in fp32
|
||||||
|
if model.device.type == "cuda":
|
||||||
|
with torch.autocast(
|
||||||
|
device_type="cuda",
|
||||||
|
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
|
||||||
|
):
|
||||||
|
embeddings_doc = model(**batch_doc)
|
||||||
|
else:
|
||||||
|
embeddings_doc = model(**batch_doc)
|
||||||
|
doc_vecs.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
|
||||||
|
return doc_vecs
|
||||||
|
|
||||||
|
|
||||||
|
def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
|
||||||
|
import torch
|
||||||
|
from colpali_engine.utils.torch_utils import ListDataset
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset=ListDataset[str](queries),
|
||||||
|
batch_size=1,
|
||||||
|
shuffle=False,
|
||||||
|
collate_fn=lambda x: processor.process_queries(x),
|
||||||
|
)
|
||||||
|
|
||||||
|
q_vecs: list[Any] = []
|
||||||
|
for batch_query in dataloader:
|
||||||
|
with torch.no_grad():
|
||||||
|
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
|
||||||
|
if model.device.type == "cuda":
|
||||||
|
with torch.autocast(
|
||||||
|
device_type="cuda",
|
||||||
|
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
|
||||||
|
):
|
||||||
|
embeddings_query = model(**batch_query)
|
||||||
|
else:
|
||||||
|
embeddings_query = model(**batch_query)
|
||||||
|
q_vecs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
|
||||||
|
return q_vecs
|
||||||
|
|
||||||
|
|
||||||
|
def _build_index(index_path: str, doc_vecs: list[Any], filepaths: list[str]) -> LeannMultiVector:
|
||||||
|
dim = int(doc_vecs[0].shape[-1])
|
||||||
|
retriever = LeannMultiVector(index_path=index_path, dim=dim)
|
||||||
|
retriever.create_collection()
|
||||||
|
for i, vec in enumerate(doc_vecs):
|
||||||
|
data = {
|
||||||
|
"colbert_vecs": vec.float().numpy(),
|
||||||
|
"doc_id": i,
|
||||||
|
"filepath": filepaths[i],
|
||||||
|
}
|
||||||
|
retriever.insert(data)
|
||||||
|
retriever.create_index()
|
||||||
|
return retriever
|
||||||
|
|
||||||
|
|
||||||
|
def _load_retriever_if_index_exists(index_path: str, dim: int) -> Optional[LeannMultiVector]:
|
||||||
|
index_base = Path(index_path)
|
||||||
|
# Rough heuristic: index dir exists AND meta+labels files exist
|
||||||
|
meta = index_base.parent / f"{index_base.name}.meta.json"
|
||||||
|
labels = index_base.parent / f"{index_base.name}.labels.json"
|
||||||
|
if index_base.exists() and meta.exists() and labels.exists():
|
||||||
|
return LeannMultiVector(index_path=index_path, dim=dim)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_similarity_map(
|
||||||
|
model,
|
||||||
|
processor,
|
||||||
|
image: Image.Image,
|
||||||
|
query: str,
|
||||||
|
token_idx: Optional[int] = None,
|
||||||
|
output_path: Optional[str] = None,
|
||||||
|
) -> tuple[int, float]:
|
||||||
|
import torch
|
||||||
|
from colpali_engine.interpretability import (
|
||||||
|
get_similarity_maps_from_embeddings,
|
||||||
|
plot_similarity_map,
|
||||||
|
)
|
||||||
|
|
||||||
|
batch_images = processor.process_images([image]).to(model.device)
|
||||||
|
batch_queries = processor.process_queries([query]).to(model.device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
image_embeddings = model.forward(**batch_images)
|
||||||
|
query_embeddings = model.forward(**batch_queries)
|
||||||
|
|
||||||
|
n_patches = processor.get_n_patches(
|
||||||
|
image_size=image.size,
|
||||||
|
spatial_merge_size=getattr(model, "spatial_merge_size", None),
|
||||||
|
)
|
||||||
|
image_mask = processor.get_image_mask(batch_images)
|
||||||
|
|
||||||
|
batched_similarity_maps = get_similarity_maps_from_embeddings(
|
||||||
|
image_embeddings=image_embeddings,
|
||||||
|
query_embeddings=query_embeddings,
|
||||||
|
n_patches=n_patches,
|
||||||
|
image_mask=image_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
similarity_maps = batched_similarity_maps[0]
|
||||||
|
|
||||||
|
# Determine token index if not provided: choose the token with highest max score
|
||||||
|
if token_idx is None:
|
||||||
|
per_token_max = similarity_maps.view(similarity_maps.shape[0], -1).max(dim=1).values
|
||||||
|
token_idx = int(per_token_max.argmax().item())
|
||||||
|
|
||||||
|
max_sim_score = similarity_maps[token_idx, :, :].max().item()
|
||||||
|
|
||||||
|
if output_path:
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
fig, ax = plot_similarity_map(
|
||||||
|
image=image,
|
||||||
|
similarity_map=similarity_maps[token_idx],
|
||||||
|
figsize=(14, 14),
|
||||||
|
show_colorbar=False,
|
||||||
|
)
|
||||||
|
ax.set_title(f"Token #{token_idx}. MaxSim score: {max_sim_score:.2f}", fontsize=12)
|
||||||
|
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||||
|
plt.savefig(output_path, bbox_inches="tight")
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
return token_idx, float(max_sim_score)
|
||||||
|
|
||||||
|
|
||||||
|
class QwenVL:
|
||||||
|
def __init__(self, device: str):
|
||||||
|
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
|
||||||
|
from transformers.utils.import_utils import is_flash_attn_2_available
|
||||||
|
|
||||||
|
attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "eager"
|
||||||
|
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||||
|
"Qwen/Qwen2.5-VL-3B-Instruct",
|
||||||
|
torch_dtype="auto",
|
||||||
|
device_map=device,
|
||||||
|
attn_implementation=attn_implementation,
|
||||||
|
)
|
||||||
|
|
||||||
|
min_pixels = 256 * 28 * 28
|
||||||
|
max_pixels = 1280 * 28 * 28
|
||||||
|
self.processor = AutoProcessor.from_pretrained(
|
||||||
|
"Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
|
||||||
|
)
|
||||||
|
|
||||||
|
def answer(self, query: str, images: list[Image.Image], max_new_tokens: int = 128) -> str:
|
||||||
|
import base64
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
from qwen_vl_utils import process_vision_info
|
||||||
|
|
||||||
|
content = []
|
||||||
|
for img in images:
|
||||||
|
buffer = BytesIO()
|
||||||
|
img.save(buffer, format="jpeg")
|
||||||
|
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||||
|
content.append({"type": "image", "image": f"data:image;base64,{img_base64}"})
|
||||||
|
content.append({"type": "text", "text": query})
|
||||||
|
messages = [{"role": "user", "content": content}]
|
||||||
|
|
||||||
|
text = self.processor.apply_chat_template(
|
||||||
|
messages, tokenize=False, add_generation_prompt=True
|
||||||
|
)
|
||||||
|
image_inputs, video_inputs = process_vision_info(messages)
|
||||||
|
inputs = self.processor(
|
||||||
|
text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
|
||||||
|
)
|
||||||
|
inputs = inputs.to(self.model.device)
|
||||||
|
|
||||||
|
generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
|
||||||
|
generated_ids_trimmed = [
|
||||||
|
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||||
|
]
|
||||||
|
return self.processor.batch_decode(
|
||||||
|
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
|
||||||
|
# Step 1: Prepare data
|
||||||
|
if USE_HF_DATASET:
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
|
||||||
|
N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
|
||||||
|
filepaths: list[str] = []
|
||||||
|
images: list[Image.Image] = []
|
||||||
|
for i in tqdm(range(N), desc="Loading dataset"):
|
||||||
|
p = dataset[i]
|
||||||
|
# Compose a descriptive identifier for printing later
|
||||||
|
identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"
|
||||||
|
print(identifier)
|
||||||
|
filepaths.append(identifier)
|
||||||
|
images.append(p["page_image"]) # PIL Image
|
||||||
|
else:
|
||||||
|
_maybe_convert_pdf_to_images(PDF, PAGES_DIR)
|
||||||
|
filepaths, images = _load_images_from_dir(PAGES_DIR)
|
||||||
|
if not images:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Step 2: Load model and processor
|
||||||
|
model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
|
||||||
|
print(f"Using model={model_name}, device={device_str}, dtype={dtype}")
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Step 3: Build or load index
|
||||||
|
retriever: Optional[LeannMultiVector] = None
|
||||||
|
if not REBUILD_INDEX:
|
||||||
|
try:
|
||||||
|
one_vec = _embed_images(model, processor, [images[0]])[0]
|
||||||
|
retriever = _load_retriever_if_index_exists(INDEX_PATH, dim=int(one_vec.shape[-1]))
|
||||||
|
except Exception:
|
||||||
|
retriever = None
|
||||||
|
|
||||||
|
if retriever is None:
|
||||||
|
doc_vecs = _embed_images(model, processor, images)
|
||||||
|
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths)
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Step 4: Embed query and search
|
||||||
|
q_vec = _embed_queries(model, processor, [QUERY])[0]
|
||||||
|
results = retriever.search(q_vec.float().numpy(), topk=TOPK, first_stage_k=FIRST_STAGE_K)
|
||||||
|
if not results:
|
||||||
|
print("No results found.")
|
||||||
|
else:
|
||||||
|
print(f'Top {len(results)} results for query: "{QUERY}"')
|
||||||
|
top_images: list[Image.Image] = []
|
||||||
|
for rank, (score, doc_id) in enumerate(results, start=1):
|
||||||
|
path = filepaths[doc_id]
|
||||||
|
# For HF dataset, path is a descriptive identifier, not a real file path
|
||||||
|
print(f"{rank}) MaxSim: {score:.4f}, Page: {path}")
|
||||||
|
top_images.append(images[doc_id])
|
||||||
|
|
||||||
|
if SAVE_TOP_IMAGE:
|
||||||
|
from pathlib import Path as _Path
|
||||||
|
|
||||||
|
base = _Path(SAVE_TOP_IMAGE)
|
||||||
|
base.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
for rank, img in enumerate(top_images[:TOPK], start=1):
|
||||||
|
if base.suffix:
|
||||||
|
out_path = base.parent / f"{base.stem}_rank{rank}{base.suffix}"
|
||||||
|
else:
|
||||||
|
out_path = base / f"retrieved_page_rank{rank}.png"
|
||||||
|
img.save(str(out_path))
|
||||||
|
print(f"Saved retrieved page (rank {rank}) to: {out_path}")
|
||||||
|
|
||||||
|
## TODO stange results of second page of DeepSeek-V2 rather than the first page
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Step 5: Similarity maps for top-K results
|
||||||
|
if results and SIMILARITY_MAP:
|
||||||
|
token_idx = None if SIM_TOKEN_IDX < 0 else int(SIM_TOKEN_IDX)
|
||||||
|
from pathlib import Path as _Path
|
||||||
|
|
||||||
|
output_base = _Path(SIM_OUTPUT) if SIM_OUTPUT else None
|
||||||
|
for rank, img in enumerate(top_images[:TOPK], start=1):
|
||||||
|
if output_base:
|
||||||
|
if output_base.suffix:
|
||||||
|
out_dir = output_base.parent
|
||||||
|
out_name = f"{output_base.stem}_rank{rank}{output_base.suffix}"
|
||||||
|
out_path = str(out_dir / out_name)
|
||||||
|
else:
|
||||||
|
out_dir = output_base
|
||||||
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
out_path = str(out_dir / f"similarity_map_rank{rank}.png")
|
||||||
|
else:
|
||||||
|
out_path = None
|
||||||
|
chosen_idx, max_sim = _generate_similarity_map(
|
||||||
|
model=model,
|
||||||
|
processor=processor,
|
||||||
|
image=img,
|
||||||
|
query=QUERY,
|
||||||
|
token_idx=token_idx,
|
||||||
|
output_path=out_path,
|
||||||
|
)
|
||||||
|
if out_path:
|
||||||
|
print(
|
||||||
|
f"Saved similarity map for rank {rank}, token #{chosen_idx} (max={max_sim:.2f}) to: {out_path}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"Computed similarity map for rank {rank}, token #{chosen_idx} (max={max_sim:.2f})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Step 6: Optional answer generation
|
||||||
|
if results and ANSWER:
|
||||||
|
qwen = QwenVL(device=device_str)
|
||||||
|
response = qwen.answer(QUERY, top_images[:TOPK], max_new_tokens=MAX_NEW_TOKENS)
|
||||||
|
print("\nAnswer:")
|
||||||
|
print(response)
|
||||||
@@ -0,0 +1,134 @@
|
|||||||
|
# pip install pdf2image
|
||||||
|
# pip install pymilvus
|
||||||
|
# pip install colpali_engine
|
||||||
|
# pip install tqdm
|
||||||
|
# pip install pillow
|
||||||
|
|
||||||
|
# %%
|
||||||
|
from pdf2image import convert_from_path
|
||||||
|
|
||||||
|
pdf_path = "pdfs/2004.12832v2.pdf"
|
||||||
|
images = convert_from_path(pdf_path)
|
||||||
|
|
||||||
|
for i, image in enumerate(images):
|
||||||
|
image.save(f"pages/page_{i + 1}.png", "PNG")
|
||||||
|
|
||||||
|
# %%
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Make local leann packages importable without installing
|
||||||
|
_repo_root = Path(__file__).resolve().parents[3]
|
||||||
|
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
|
||||||
|
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
|
||||||
|
import sys
|
||||||
|
|
||||||
|
if str(_leann_core_src) not in sys.path:
|
||||||
|
sys.path.append(str(_leann_core_src))
|
||||||
|
if str(_leann_hnsw_pkg) not in sys.path:
|
||||||
|
sys.path.append(str(_leann_hnsw_pkg))
|
||||||
|
|
||||||
|
from leann_multi_vector import LeannMultiVector
|
||||||
|
|
||||||
|
|
||||||
|
class LeannRetriever(LeannMultiVector):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from colpali_engine.models import ColPali
|
||||||
|
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
|
||||||
|
from colpali_engine.utils.torch_utils import ListDataset, get_torch_device
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
# Auto-select device: CUDA > MPS (mac) > CPU
|
||||||
|
_device_str = (
|
||||||
|
"cuda"
|
||||||
|
if torch.cuda.is_available()
|
||||||
|
else (
|
||||||
|
"mps"
|
||||||
|
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
|
||||||
|
else "cpu"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
device = get_torch_device(_device_str)
|
||||||
|
# Prefer fp16 on GPU/MPS, bfloat16 on CPU
|
||||||
|
_dtype = torch.float16 if _device_str in ("cuda", "mps") else torch.bfloat16
|
||||||
|
model_name = "vidore/colpali-v1.2"
|
||||||
|
|
||||||
|
model = ColPali.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
torch_dtype=_dtype,
|
||||||
|
device_map=device,
|
||||||
|
).eval()
|
||||||
|
print(f"Using device={_device_str}, dtype={_dtype}")
|
||||||
|
|
||||||
|
queries = [
|
||||||
|
"How to end-to-end retrieval with ColBert",
|
||||||
|
"Where is ColBERT performance Table, including text representation results?",
|
||||||
|
]
|
||||||
|
|
||||||
|
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
|
||||||
|
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset=ListDataset[str](queries),
|
||||||
|
batch_size=1,
|
||||||
|
shuffle=False,
|
||||||
|
collate_fn=lambda x: processor.process_queries(x),
|
||||||
|
)
|
||||||
|
|
||||||
|
qs: list[torch.Tensor] = []
|
||||||
|
for batch_query in dataloader:
|
||||||
|
with torch.no_grad():
|
||||||
|
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
|
||||||
|
embeddings_query = model(**batch_query)
|
||||||
|
qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
|
||||||
|
print(qs[0].shape)
|
||||||
|
# %%
|
||||||
|
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
page_filenames = sorted(os.listdir("./pages"), key=lambda n: int(re.search(r"\d+", n).group()))
|
||||||
|
images = [Image.open(os.path.join("./pages", name)) for name in page_filenames]
|
||||||
|
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset=ListDataset[str](images),
|
||||||
|
batch_size=1,
|
||||||
|
shuffle=False,
|
||||||
|
collate_fn=lambda x: processor.process_images(x),
|
||||||
|
)
|
||||||
|
|
||||||
|
ds: list[torch.Tensor] = []
|
||||||
|
for batch_doc in tqdm(dataloader):
|
||||||
|
with torch.no_grad():
|
||||||
|
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
|
||||||
|
embeddings_doc = model(**batch_doc)
|
||||||
|
ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
|
||||||
|
|
||||||
|
print(ds[0].shape)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Build HNSW index via LeannRetriever primitives and run search
|
||||||
|
index_path = "./indexes/colpali.leann"
|
||||||
|
retriever = LeannRetriever(index_path=index_path, dim=int(ds[0].shape[-1]))
|
||||||
|
retriever.create_collection()
|
||||||
|
filepaths = [os.path.join("./pages", name) for name in page_filenames]
|
||||||
|
for i in range(len(filepaths)):
|
||||||
|
data = {
|
||||||
|
"colbert_vecs": ds[i].float().numpy(),
|
||||||
|
"doc_id": i,
|
||||||
|
"filepath": filepaths[i],
|
||||||
|
}
|
||||||
|
retriever.insert(data)
|
||||||
|
retriever.create_index()
|
||||||
|
for query in qs:
|
||||||
|
query_np = query.float().numpy()
|
||||||
|
result = retriever.search(query_np, topk=1)
|
||||||
|
print(filepaths[result[0][1]])
|
||||||
BIN
assets/claude_code_leann.png
Normal file
BIN
assets/claude_code_leann.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 73 KiB |
BIN
assets/mcp_leann.png
Normal file
BIN
assets/mcp_leann.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 224 KiB |
BIN
assets/wechat_user_group.JPG
Normal file
BIN
assets/wechat_user_group.JPG
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 152 KiB |
@@ -1,9 +1,24 @@
|
|||||||
# 🧪 Leann Sanity Checks
|
# 🧪 LEANN Benchmarks & Testing
|
||||||
|
|
||||||
This directory contains comprehensive sanity checks for the Leann system, ensuring all components work correctly across different configurations.
|
This directory contains performance benchmarks and comprehensive tests for the LEANN system, including backend comparisons and sanity checks across different configurations.
|
||||||
|
|
||||||
## 📁 Test Files
|
## 📁 Test Files
|
||||||
|
|
||||||
|
### `diskann_vs_hnsw_speed_comparison.py`
|
||||||
|
Performance comparison between DiskANN and HNSW backends:
|
||||||
|
- ✅ **Search latency** comparison with both backends using recompute
|
||||||
|
- ✅ **Index size** and **build time** measurements
|
||||||
|
- ✅ **Score validity** testing (ensures no -inf scores)
|
||||||
|
- ✅ **Configurable dataset sizes** for different scales
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Quick comparison with 500 docs, 10 queries
|
||||||
|
python benchmarks/diskann_vs_hnsw_speed_comparison.py
|
||||||
|
|
||||||
|
# Large-scale comparison with 2000 docs, 20 queries
|
||||||
|
python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20
|
||||||
|
```
|
||||||
|
|
||||||
### `test_distance_functions.py`
|
### `test_distance_functions.py`
|
||||||
Tests all supported distance functions across DiskANN backend:
|
Tests all supported distance functions across DiskANN backend:
|
||||||
- ✅ **MIPS** (Maximum Inner Product Search)
|
- ✅ **MIPS** (Maximum Inner Product Search)
|
||||||
|
|||||||
0
benchmarks/__init__.py
Normal file
0
benchmarks/__init__.py
Normal file
148
benchmarks/benchmark_no_recompute.py
Normal file
148
benchmarks/benchmark_no_recompute.py
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from leann import LeannBuilder, LeannSearcher
|
||||||
|
|
||||||
|
|
||||||
|
def _meta_exists(index_path: str) -> bool:
|
||||||
|
p = Path(index_path)
|
||||||
|
return (p.parent / f"{p.stem}.meta.json").exists()
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_index(index_path: str, backend_name: str, num_docs: int, is_recompute: bool) -> None:
|
||||||
|
# if _meta_exists(index_path):
|
||||||
|
# return
|
||||||
|
kwargs = {}
|
||||||
|
if backend_name == "hnsw":
|
||||||
|
kwargs["is_compact"] = is_recompute
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name=backend_name,
|
||||||
|
embedding_model=os.getenv("LEANN_EMBED_MODEL", "facebook/contriever"),
|
||||||
|
embedding_mode=os.getenv("LEANN_EMBED_MODE", "sentence-transformers"),
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
is_recompute=is_recompute,
|
||||||
|
num_threads=4,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
for i in range(num_docs):
|
||||||
|
builder.add_text(
|
||||||
|
f"This is a test document number {i}. It contains some repeated text for benchmarking."
|
||||||
|
)
|
||||||
|
builder.build_index(index_path)
|
||||||
|
|
||||||
|
|
||||||
|
def _bench_group(
|
||||||
|
index_path: str,
|
||||||
|
recompute: bool,
|
||||||
|
query: str,
|
||||||
|
repeats: int,
|
||||||
|
complexity: int = 32,
|
||||||
|
top_k: int = 10,
|
||||||
|
) -> float:
|
||||||
|
# Independent searcher per group; fixed port when recompute
|
||||||
|
searcher = LeannSearcher(index_path=index_path)
|
||||||
|
|
||||||
|
# Warm-up once
|
||||||
|
_ = searcher.search(
|
||||||
|
query,
|
||||||
|
top_k=top_k,
|
||||||
|
complexity=complexity,
|
||||||
|
recompute_embeddings=recompute,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _once() -> float:
|
||||||
|
t0 = time.time()
|
||||||
|
_ = searcher.search(
|
||||||
|
query,
|
||||||
|
top_k=top_k,
|
||||||
|
complexity=complexity,
|
||||||
|
recompute_embeddings=recompute,
|
||||||
|
)
|
||||||
|
return time.time() - t0
|
||||||
|
|
||||||
|
if repeats <= 1:
|
||||||
|
t = _once()
|
||||||
|
else:
|
||||||
|
vals = [_once() for _ in range(repeats)]
|
||||||
|
vals.sort()
|
||||||
|
t = vals[len(vals) // 2]
|
||||||
|
|
||||||
|
searcher.cleanup()
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--num-docs", type=int, default=5000)
|
||||||
|
parser.add_argument("--repeats", type=int, default=3)
|
||||||
|
parser.add_argument("--complexity", type=int, default=32)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
base = Path.cwd() / ".leann" / "indexes" / f"bench_n{args.num_docs}"
|
||||||
|
base.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
# ---------- Build HNSW variants ----------
|
||||||
|
hnsw_r = str(base / f"hnsw_recompute_n{args.num_docs}.leann")
|
||||||
|
hnsw_nr = str(base / f"hnsw_norecompute_n{args.num_docs}.leann")
|
||||||
|
ensure_index(hnsw_r, "hnsw", args.num_docs, True)
|
||||||
|
ensure_index(hnsw_nr, "hnsw", args.num_docs, False)
|
||||||
|
|
||||||
|
# ---------- Build DiskANN variants ----------
|
||||||
|
diskann_r = str(base / "diskann_r.leann")
|
||||||
|
diskann_nr = str(base / "diskann_nr.leann")
|
||||||
|
ensure_index(diskann_r, "diskann", args.num_docs, True)
|
||||||
|
ensure_index(diskann_nr, "diskann", args.num_docs, False)
|
||||||
|
|
||||||
|
# ---------- Helpers ----------
|
||||||
|
def _size_for(prefix: str) -> int:
|
||||||
|
p = Path(prefix)
|
||||||
|
base_dir = p.parent
|
||||||
|
stem = p.stem
|
||||||
|
total = 0
|
||||||
|
for f in base_dir.iterdir():
|
||||||
|
if f.is_file() and f.name.startswith(stem):
|
||||||
|
total += f.stat().st_size
|
||||||
|
return total
|
||||||
|
|
||||||
|
# ---------- HNSW benchmark ----------
|
||||||
|
t_hnsw_r = _bench_group(
|
||||||
|
hnsw_r, True, "test document number 42", repeats=args.repeats, complexity=args.complexity
|
||||||
|
)
|
||||||
|
t_hnsw_nr = _bench_group(
|
||||||
|
hnsw_nr, False, "test document number 42", repeats=args.repeats, complexity=args.complexity
|
||||||
|
)
|
||||||
|
size_hnsw_r = _size_for(hnsw_r)
|
||||||
|
size_hnsw_nr = _size_for(hnsw_nr)
|
||||||
|
|
||||||
|
print("Benchmark results (HNSW):")
|
||||||
|
print(f" recompute=True: search_time={t_hnsw_r:.3f}s, size={size_hnsw_r / 1024 / 1024:.1f}MB")
|
||||||
|
print(
|
||||||
|
f" recompute=False: search_time={t_hnsw_nr:.3f}s, size={size_hnsw_nr / 1024 / 1024:.1f}MB"
|
||||||
|
)
|
||||||
|
print(" Expectation: no-recompute should be faster but larger on disk.")
|
||||||
|
|
||||||
|
# ---------- DiskANN benchmark ----------
|
||||||
|
t_diskann_r = _bench_group(
|
||||||
|
diskann_r, True, "DiskANN R test doc 123", repeats=args.repeats, complexity=args.complexity
|
||||||
|
)
|
||||||
|
t_diskann_nr = _bench_group(
|
||||||
|
diskann_nr,
|
||||||
|
False,
|
||||||
|
"DiskANN NR test doc 123",
|
||||||
|
repeats=args.repeats,
|
||||||
|
complexity=args.complexity,
|
||||||
|
)
|
||||||
|
size_diskann_r = _size_for(diskann_r)
|
||||||
|
size_diskann_nr = _size_for(diskann_nr)
|
||||||
|
|
||||||
|
print("\nBenchmark results (DiskANN):")
|
||||||
|
print(f" build(recompute=True, partition): size={size_diskann_r / 1024 / 1024:.1f}MB")
|
||||||
|
print(f" build(recompute=False): size={size_diskann_nr / 1024 / 1024:.1f}MB")
|
||||||
|
print(f" search recompute=True (final rerank): {t_diskann_r:.3f}s")
|
||||||
|
print(f" search recompute=False (PQ only): {t_diskann_nr:.3f}s")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
23
benchmarks/bm25_diskann_baselines/README.md
Normal file
23
benchmarks/bm25_diskann_baselines/README.md
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
BM25 vs DiskANN Baselines
|
||||||
|
|
||||||
|
```bash
|
||||||
|
aws s3 sync s3://powerrag-diskann-rpj-wiki-20250824-224037-194d640c/bm25_rpj_wiki/index_en_only/ benchmarks/data/indices/bm25_index/
|
||||||
|
aws s3 sync s3://powerrag-diskann-rpj-wiki-20250824-224037-194d640c/diskann_rpj_wiki/ benchmarks/data/indices/diskann_rpj_wiki/
|
||||||
|
```
|
||||||
|
|
||||||
|
- Dataset: `benchmarks/data/queries/nq_open.jsonl` (Natural Questions)
|
||||||
|
- Machine-specific; results measured locally with the current repo.
|
||||||
|
|
||||||
|
DiskANN (NQ queries, search-only)
|
||||||
|
- Command: `uv run --script benchmarks/bm25_diskann_baselines/run_diskann.py`
|
||||||
|
- Settings: `recompute_embeddings=False`, embeddings precomputed (excluded from timing), batching off, caching off (`cache_mechanism=2`, `num_nodes_to_cache=0`)
|
||||||
|
- Result: avg 0.011093 s/query, QPS 90.15 (p50 0.010731 s, p95 0.015000 s)
|
||||||
|
|
||||||
|
BM25
|
||||||
|
- Command: `uv run --script benchmarks/bm25_diskann_baselines/run_bm25.py`
|
||||||
|
- Settings: `k=10`, `k1=0.9`, `b=0.4`, queries=100
|
||||||
|
- Result: avg 0.028589 s/query, QPS 34.97 (p50 0.026060 s, p90 0.043695 s, p95 0.053260 s, p99 0.055257 s)
|
||||||
|
|
||||||
|
Notes
|
||||||
|
- DiskANN measures search-only latency on real NQ queries (embeddings computed beforehand and excluded from timing).
|
||||||
|
- Use `benchmarks/bm25_diskann_baselines/run_diskann.py` for DiskANN; `benchmarks/bm25_diskann_baselines/run_bm25.py` for BM25.
|
||||||
|
After Width: | Height: | Size: 1.3 KiB |
183
benchmarks/bm25_diskann_baselines/run_bm25.py
Normal file
183
benchmarks/bm25_diskann_baselines/run_bm25.py
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
# /// script
|
||||||
|
# dependencies = [
|
||||||
|
# "pyserini"
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
# sudo pacman -S jdk21-openjdk
|
||||||
|
# export JAVA_HOME=/usr/lib/jvm/java-21-openjdk
|
||||||
|
# sudo archlinux-java status
|
||||||
|
# sudo archlinux-java set java-21-openjdk
|
||||||
|
# set -Ux JAVA_HOME /usr/lib/jvm/java-21-openjdk
|
||||||
|
# fish_add_path --global $JAVA_HOME/bin
|
||||||
|
# set -Ux LD_LIBRARY_PATH $JAVA_HOME/lib/server $LD_LIBRARY_PATH
|
||||||
|
# which javac # Should be /usr/lib/jvm/java-21-openjdk/bin/javac
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from statistics import mean
|
||||||
|
|
||||||
|
|
||||||
|
def load_queries(path: str, limit: int | None) -> list[str]:
|
||||||
|
queries: list[str] = []
|
||||||
|
# Try JSONL with a 'query' or 'text' field; fallback to plain text (one query per line)
|
||||||
|
_, ext = os.path.splitext(path)
|
||||||
|
if ext.lower() in {".jsonl", ".json"}:
|
||||||
|
with open(path, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
obj = json.loads(line)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# Not strict JSONL? treat the whole line as the query
|
||||||
|
queries.append(line)
|
||||||
|
continue
|
||||||
|
q = obj.get("query") or obj.get("text") or obj.get("question")
|
||||||
|
if q:
|
||||||
|
queries.append(str(q))
|
||||||
|
else:
|
||||||
|
with open(path, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
s = line.strip()
|
||||||
|
if s:
|
||||||
|
queries.append(s)
|
||||||
|
|
||||||
|
if limit is not None and limit > 0:
|
||||||
|
queries = queries[:limit]
|
||||||
|
return queries
|
||||||
|
|
||||||
|
|
||||||
|
def percentile(values: list[float], p: float) -> float:
|
||||||
|
if not values:
|
||||||
|
return 0.0
|
||||||
|
s = sorted(values)
|
||||||
|
k = (len(s) - 1) * (p / 100.0)
|
||||||
|
f = int(k)
|
||||||
|
c = min(f + 1, len(s) - 1)
|
||||||
|
if f == c:
|
||||||
|
return s[f]
|
||||||
|
return s[f] + (s[c] - s[f]) * (k - f)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
ap = argparse.ArgumentParser(description="Standalone BM25 latency benchmark (Pyserini)")
|
||||||
|
ap.add_argument(
|
||||||
|
"--bm25-index",
|
||||||
|
default="benchmarks/data/indices/bm25_index",
|
||||||
|
help="Path to Pyserini Lucene index directory",
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--queries",
|
||||||
|
default="benchmarks/data/queries/nq_open.jsonl",
|
||||||
|
help="Path to queries file (JSONL with 'query'/'text' or plain txt one-per-line)",
|
||||||
|
)
|
||||||
|
ap.add_argument("--k", type=int, default=10, help="Top-k to retrieve (default: 10)")
|
||||||
|
ap.add_argument("--k1", type=float, default=0.9, help="BM25 k1 (default: 0.9)")
|
||||||
|
ap.add_argument("--b", type=float, default=0.4, help="BM25 b (default: 0.4)")
|
||||||
|
ap.add_argument("--limit", type=int, default=100, help="Max queries to run (default: 100)")
|
||||||
|
ap.add_argument(
|
||||||
|
"--warmup", type=int, default=5, help="Warmup queries not counted in latency (default: 5)"
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--fetch-docs", action="store_true", help="Also fetch doc contents (slower; default: off)"
|
||||||
|
)
|
||||||
|
ap.add_argument("--report", type=str, default=None, help="Optional JSON report path")
|
||||||
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pyserini.search.lucene import LuceneSearcher
|
||||||
|
except Exception:
|
||||||
|
print("Pyserini not found. Install with: pip install pyserini", file=sys.stderr)
|
||||||
|
raise
|
||||||
|
|
||||||
|
if not os.path.isdir(args.bm25_index):
|
||||||
|
print(f"Index directory not found: {args.bm25_index}", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
queries = load_queries(args.queries, args.limit)
|
||||||
|
if not queries:
|
||||||
|
print("No queries loaded.", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
print(f"Loaded {len(queries)} queries from {args.queries}")
|
||||||
|
print(f"Opening BM25 index: {args.bm25_index}")
|
||||||
|
searcher = LuceneSearcher(args.bm25_index)
|
||||||
|
# Some builds of pyserini require explicit set_bm25; others ignore
|
||||||
|
try:
|
||||||
|
searcher.set_bm25(k1=args.k1, b=args.b)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
latencies: list[float] = []
|
||||||
|
total_searches = 0
|
||||||
|
|
||||||
|
# Warmup
|
||||||
|
for i in range(min(args.warmup, len(queries))):
|
||||||
|
_ = searcher.search(queries[i], k=args.k)
|
||||||
|
|
||||||
|
t0 = time.time()
|
||||||
|
for i, q in enumerate(queries):
|
||||||
|
t1 = time.time()
|
||||||
|
hits = searcher.search(q, k=args.k)
|
||||||
|
t2 = time.time()
|
||||||
|
latencies.append(t2 - t1)
|
||||||
|
total_searches += 1
|
||||||
|
|
||||||
|
if args.fetch_docs:
|
||||||
|
# Optional doc fetch to include I/O time
|
||||||
|
for h in hits:
|
||||||
|
try:
|
||||||
|
_ = searcher.doc(h.docid)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if (i + 1) % 50 == 0:
|
||||||
|
print(f"Processed {i + 1}/{len(queries)} queries")
|
||||||
|
|
||||||
|
t1 = time.time()
|
||||||
|
total_time = t1 - t0
|
||||||
|
|
||||||
|
if latencies:
|
||||||
|
avg = mean(latencies)
|
||||||
|
p50 = percentile(latencies, 50)
|
||||||
|
p90 = percentile(latencies, 90)
|
||||||
|
p95 = percentile(latencies, 95)
|
||||||
|
p99 = percentile(latencies, 99)
|
||||||
|
qps = total_searches / total_time if total_time > 0 else 0.0
|
||||||
|
else:
|
||||||
|
avg = p50 = p90 = p95 = p99 = qps = 0.0
|
||||||
|
|
||||||
|
print("BM25 Latency Report")
|
||||||
|
print(f" queries: {total_searches}")
|
||||||
|
print(f" k: {args.k}, k1: {args.k1}, b: {args.b}")
|
||||||
|
print(f" avg per query: {avg:.6f} s")
|
||||||
|
print(f" p50/p90/p95/p99: {p50:.6f}/{p90:.6f}/{p95:.6f}/{p99:.6f} s")
|
||||||
|
print(f" total time: {total_time:.3f} s, qps: {qps:.2f}")
|
||||||
|
|
||||||
|
if args.report:
|
||||||
|
payload = {
|
||||||
|
"queries": total_searches,
|
||||||
|
"k": args.k,
|
||||||
|
"k1": args.k1,
|
||||||
|
"b": args.b,
|
||||||
|
"avg_s": avg,
|
||||||
|
"p50_s": p50,
|
||||||
|
"p90_s": p90,
|
||||||
|
"p95_s": p95,
|
||||||
|
"p99_s": p99,
|
||||||
|
"total_time_s": total_time,
|
||||||
|
"qps": qps,
|
||||||
|
"index_dir": os.path.abspath(args.bm25_index),
|
||||||
|
"fetch_docs": bool(args.fetch_docs),
|
||||||
|
}
|
||||||
|
with open(args.report, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(payload, f, indent=2)
|
||||||
|
print(f"Saved report to {args.report}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
124
benchmarks/bm25_diskann_baselines/run_diskann.py
Normal file
124
benchmarks/bm25_diskann_baselines/run_diskann.py
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
# /// script
|
||||||
|
# dependencies = [
|
||||||
|
# "leann-backend-diskann"
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def load_queries(path: Path, limit: int | None) -> list[str]:
|
||||||
|
out: list[str] = []
|
||||||
|
with open(path, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
obj = json.loads(line)
|
||||||
|
out.append(obj["query"])
|
||||||
|
if limit and len(out) >= limit:
|
||||||
|
break
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
ap = argparse.ArgumentParser(
|
||||||
|
description="DiskANN baseline on real NQ queries (search-only timing)"
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--index-dir",
|
||||||
|
default="benchmarks/data/indices/diskann_rpj_wiki",
|
||||||
|
help="Directory containing DiskANN files",
|
||||||
|
)
|
||||||
|
ap.add_argument("--index-prefix", default="ann")
|
||||||
|
ap.add_argument("--queries-file", default="benchmarks/data/queries/nq_open.jsonl")
|
||||||
|
ap.add_argument("--num-queries", type=int, default=200)
|
||||||
|
ap.add_argument("--top-k", type=int, default=10)
|
||||||
|
ap.add_argument("--complexity", type=int, default=62)
|
||||||
|
ap.add_argument("--threads", type=int, default=1)
|
||||||
|
ap.add_argument("--beam-width", type=int, default=1)
|
||||||
|
ap.add_argument("--cache-mechanism", type=int, default=2)
|
||||||
|
ap.add_argument("--num-nodes-to-cache", type=int, default=0)
|
||||||
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
index_dir = Path(args.index_dir).resolve()
|
||||||
|
if not index_dir.is_dir():
|
||||||
|
raise SystemExit(f"Index dir not found: {index_dir}")
|
||||||
|
|
||||||
|
qpath = Path(args.queries_file).resolve()
|
||||||
|
if not qpath.exists():
|
||||||
|
raise SystemExit(f"Queries file not found: {qpath}")
|
||||||
|
|
||||||
|
queries = load_queries(qpath, args.num_queries)
|
||||||
|
print(f"Loaded {len(queries)} queries from {qpath}")
|
||||||
|
|
||||||
|
# Compute embeddings once (exclude from timing)
|
||||||
|
from leann.api import compute_embeddings as _compute
|
||||||
|
|
||||||
|
embs = _compute(
|
||||||
|
queries,
|
||||||
|
model_name="facebook/contriever-msmarco",
|
||||||
|
mode="sentence-transformers",
|
||||||
|
use_server=False,
|
||||||
|
).astype(np.float32)
|
||||||
|
if embs.ndim != 2:
|
||||||
|
raise SystemExit("Embedding compute failed or returned wrong shape")
|
||||||
|
|
||||||
|
# Build searcher
|
||||||
|
from leann_backend_diskann.diskann_backend import DiskannSearcher as _DiskannSearcher
|
||||||
|
|
||||||
|
index_prefix_path = str(index_dir / args.index_prefix)
|
||||||
|
searcher = _DiskannSearcher(
|
||||||
|
index_prefix_path,
|
||||||
|
num_threads=int(args.threads),
|
||||||
|
cache_mechanism=int(args.cache_mechanism),
|
||||||
|
num_nodes_to_cache=int(args.num_nodes_to_cache),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Warmup (not timed)
|
||||||
|
_ = searcher.search(
|
||||||
|
embs[0:1],
|
||||||
|
top_k=args.top_k,
|
||||||
|
complexity=args.complexity,
|
||||||
|
beam_width=args.beam_width,
|
||||||
|
prune_ratio=0.0,
|
||||||
|
recompute_embeddings=False,
|
||||||
|
batch_recompute=False,
|
||||||
|
dedup_node_dis=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Timed loop
|
||||||
|
times: list[float] = []
|
||||||
|
for i in range(embs.shape[0]):
|
||||||
|
t0 = time.time()
|
||||||
|
_ = searcher.search(
|
||||||
|
embs[i : i + 1],
|
||||||
|
top_k=args.top_k,
|
||||||
|
complexity=args.complexity,
|
||||||
|
beam_width=args.beam_width,
|
||||||
|
prune_ratio=0.0,
|
||||||
|
recompute_embeddings=False,
|
||||||
|
batch_recompute=False,
|
||||||
|
dedup_node_dis=False,
|
||||||
|
)
|
||||||
|
times.append(time.time() - t0)
|
||||||
|
|
||||||
|
times_sorted = sorted(times)
|
||||||
|
avg = float(sum(times) / len(times))
|
||||||
|
p50 = times_sorted[len(times) // 2]
|
||||||
|
p95 = times_sorted[max(0, int(len(times) * 0.95) - 1)]
|
||||||
|
|
||||||
|
print("\nDiskANN (NQ, search-only) Report")
|
||||||
|
print(f" queries: {len(times)}")
|
||||||
|
print(
|
||||||
|
f" k: {args.top_k}, complexity: {args.complexity}, beam_width: {args.beam_width}, threads: {args.threads}"
|
||||||
|
)
|
||||||
|
print(f" avg per query: {avg:.6f} s")
|
||||||
|
print(f" p50/p95: {p50:.6f}/{p95:.6f} s")
|
||||||
|
print(f" QPS: {1.0 / avg:.2f}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
82
benchmarks/data/.gitattributes
vendored
82
benchmarks/data/.gitattributes
vendored
@@ -1,82 +0,0 @@
|
|||||||
*.7z filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.arrow filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.bin filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.ftz filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.gz filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.h5 filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.joblib filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.lz4 filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.mds filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.model filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.npy filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.npz filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.onnx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.ot filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.parquet filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.pb filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.pickle filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.pkl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.pt filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.pth filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.rar filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
|
||||||
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.tar filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.tflite filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.tgz filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.wasm filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.xz filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.zip filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.zst filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
||||||
# Audio files - uncompressed
|
|
||||||
*.pcm filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.sam filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.raw filter=lfs diff=lfs merge=lfs -text
|
|
||||||
# Audio files - compressed
|
|
||||||
*.aac filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.flac filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.mp3 filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.ogg filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.wav filter=lfs diff=lfs merge=lfs -text
|
|
||||||
# Image files - uncompressed
|
|
||||||
*.bmp filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.gif filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.png filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.tiff filter=lfs diff=lfs merge=lfs -text
|
|
||||||
# Image files - compressed
|
|
||||||
*.jpg filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.webp filter=lfs diff=lfs merge=lfs -text
|
|
||||||
# Video files - compressed
|
|
||||||
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.webm filter=lfs diff=lfs merge=lfs -text
|
|
||||||
ground_truth/dpr/id_map.json filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/dpr/dpr_diskann.passages.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/dpr/dpr_diskann.passages.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/dpr/dpr_diskann_disk.index filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/dpr/leann.labels.map filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/leann.labels.map filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.index filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.0.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.0.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.1.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.1.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.2.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.2.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.3.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.3.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.4.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.4.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.5.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.5.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.6.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.6.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.7.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.7.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
44
benchmarks/data/README.md
Executable file
44
benchmarks/data/README.md
Executable file
@@ -0,0 +1,44 @@
|
|||||||
|
---
|
||||||
|
license: mit
|
||||||
|
---
|
||||||
|
|
||||||
|
# LEANN-RAG Evaluation Data
|
||||||
|
|
||||||
|
This repository contains the necessary data to run the recall evaluation scripts for the [LEANN-RAG](https://huggingface.co/LEANN-RAG) project.
|
||||||
|
|
||||||
|
## Dataset Components
|
||||||
|
|
||||||
|
This dataset is structured into three main parts:
|
||||||
|
|
||||||
|
1. **Pre-built LEANN Indices**:
|
||||||
|
* `dpr/`: A pre-built index for the DPR dataset.
|
||||||
|
* `rpj_wiki/`: A pre-built index for the RPJ-Wiki dataset.
|
||||||
|
These indices were created using the `leann-core` library and are required by the `LeannSearcher`.
|
||||||
|
|
||||||
|
2. **Ground Truth Data**:
|
||||||
|
* `ground_truth/`: Contains the ground truth files (`flat_results_nq_k3.json`) for both the DPR and RPJ-Wiki datasets. These files map queries to the original passage IDs from the Natural Questions benchmark, evaluated using the Contriever model.
|
||||||
|
|
||||||
|
3. **Queries**:
|
||||||
|
* `queries/`: Contains the `nq_open.jsonl` file with the Natural Questions queries used for the evaluation.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
To use this data, you can download it locally using the `huggingface-hub` library. First, install the library:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install huggingface-hub
|
||||||
|
```
|
||||||
|
|
||||||
|
Then, you can download the entire dataset to a local directory (e.g., `data/`) with the following Python script:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
snapshot_download(
|
||||||
|
repo_id="LEANN-RAG/leann-rag-evaluation-data",
|
||||||
|
repo_type="dataset",
|
||||||
|
local_dir="data"
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
This will download all the necessary files into a local `data` folder, preserving the repository structure. The evaluation scripts in the main [LEANN-RAG Space](https://huggingface.co/LEANN-RAG) are configured to work with this data structure.
|
||||||
286
benchmarks/diskann_vs_hnsw_speed_comparison.py
Normal file
286
benchmarks/diskann_vs_hnsw_speed_comparison.py
Normal file
@@ -0,0 +1,286 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
DiskANN vs HNSW Search Performance Comparison
|
||||||
|
|
||||||
|
This benchmark compares search performance between DiskANN and HNSW backends:
|
||||||
|
- DiskANN: With graph partitioning enabled (is_recompute=True)
|
||||||
|
- HNSW: With recompute enabled (is_recompute=True)
|
||||||
|
- Tests performance across different dataset sizes
|
||||||
|
- Measures search latency, recall, and index size
|
||||||
|
"""
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import multiprocessing as mp
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# Prefer 'fork' start method to avoid POSIX semaphore leaks on macOS
|
||||||
|
try:
|
||||||
|
mp.set_start_method("fork", force=True)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def create_test_texts(n_docs: int) -> list[str]:
|
||||||
|
"""Create synthetic test documents for benchmarking."""
|
||||||
|
np.random.seed(42)
|
||||||
|
topics = [
|
||||||
|
"machine learning and artificial intelligence",
|
||||||
|
"natural language processing and text analysis",
|
||||||
|
"computer vision and image recognition",
|
||||||
|
"data science and statistical analysis",
|
||||||
|
"deep learning and neural networks",
|
||||||
|
"information retrieval and search engines",
|
||||||
|
"database systems and data management",
|
||||||
|
"software engineering and programming",
|
||||||
|
"cybersecurity and network protection",
|
||||||
|
"cloud computing and distributed systems",
|
||||||
|
]
|
||||||
|
|
||||||
|
texts = []
|
||||||
|
for i in range(n_docs):
|
||||||
|
topic = topics[i % len(topics)]
|
||||||
|
variation = np.random.randint(1, 100)
|
||||||
|
text = (
|
||||||
|
f"This is document {i} about {topic}. Content variation {variation}. "
|
||||||
|
f"Additional information about {topic} with details and examples. "
|
||||||
|
f"Technical discussion of {topic} including implementation aspects."
|
||||||
|
)
|
||||||
|
texts.append(text)
|
||||||
|
|
||||||
|
return texts
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_backend(
|
||||||
|
backend_name: str, texts: list[str], test_queries: list[str], backend_kwargs: dict[str, Any]
|
||||||
|
) -> dict[str, float]:
|
||||||
|
"""Benchmark a specific backend with the given configuration."""
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher
|
||||||
|
|
||||||
|
print(f"\n🔧 Testing {backend_name.upper()} backend...")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
index_path = str(Path(temp_dir) / f"benchmark_{backend_name}.leann")
|
||||||
|
|
||||||
|
# Build index
|
||||||
|
print(f"📦 Building {backend_name} index with {len(texts)} documents...")
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name=backend_name,
|
||||||
|
embedding_model="facebook/contriever",
|
||||||
|
embedding_mode="sentence-transformers",
|
||||||
|
**backend_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
for text in texts:
|
||||||
|
builder.add_text(text)
|
||||||
|
|
||||||
|
builder.build_index(index_path)
|
||||||
|
build_time = time.time() - start_time
|
||||||
|
|
||||||
|
# Measure index size
|
||||||
|
index_dir = Path(index_path).parent
|
||||||
|
index_files = list(index_dir.glob(f"{Path(index_path).stem}.*"))
|
||||||
|
total_size = sum(f.stat().st_size for f in index_files if f.is_file())
|
||||||
|
size_mb = total_size / (1024 * 1024)
|
||||||
|
|
||||||
|
print(f" ✅ Build completed in {build_time:.2f}s, index size: {size_mb:.1f}MB")
|
||||||
|
|
||||||
|
# Search benchmark
|
||||||
|
print("🔍 Running search benchmark...")
|
||||||
|
searcher = LeannSearcher(index_path)
|
||||||
|
|
||||||
|
search_times = []
|
||||||
|
all_results = []
|
||||||
|
|
||||||
|
for query in test_queries:
|
||||||
|
start_time = time.time()
|
||||||
|
results = searcher.search(query, top_k=5)
|
||||||
|
search_time = time.time() - start_time
|
||||||
|
search_times.append(search_time)
|
||||||
|
all_results.append(results)
|
||||||
|
|
||||||
|
avg_search_time = np.mean(search_times) * 1000 # Convert to ms
|
||||||
|
print(f" ✅ Average search time: {avg_search_time:.1f}ms")
|
||||||
|
|
||||||
|
# Check for valid scores (detect -inf issues)
|
||||||
|
all_scores = [
|
||||||
|
result.score
|
||||||
|
for results in all_results
|
||||||
|
for result in results
|
||||||
|
if result.score is not None
|
||||||
|
]
|
||||||
|
valid_scores = [
|
||||||
|
score for score in all_scores if score != float("-inf") and score != float("inf")
|
||||||
|
]
|
||||||
|
score_validity_rate = len(valid_scores) / len(all_scores) if all_scores else 0
|
||||||
|
|
||||||
|
# Clean up (ensure embedding server shutdown and object GC)
|
||||||
|
try:
|
||||||
|
if hasattr(searcher, "cleanup"):
|
||||||
|
searcher.cleanup()
|
||||||
|
del searcher
|
||||||
|
del builder
|
||||||
|
gc.collect()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"⚠️ Warning: Resource cleanup error: {e}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"build_time": build_time,
|
||||||
|
"avg_search_time_ms": avg_search_time,
|
||||||
|
"index_size_mb": size_mb,
|
||||||
|
"score_validity_rate": score_validity_rate,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def run_comparison(n_docs: int = 500, n_queries: int = 10):
|
||||||
|
"""Run performance comparison between DiskANN and HNSW."""
|
||||||
|
print("🚀 Starting DiskANN vs HNSW Performance Comparison")
|
||||||
|
print(f"📊 Dataset: {n_docs} documents, {n_queries} test queries")
|
||||||
|
|
||||||
|
# Create test data
|
||||||
|
texts = create_test_texts(n_docs)
|
||||||
|
test_queries = [
|
||||||
|
"machine learning algorithms",
|
||||||
|
"natural language processing",
|
||||||
|
"computer vision techniques",
|
||||||
|
"data analysis methods",
|
||||||
|
"neural network architectures",
|
||||||
|
"database query optimization",
|
||||||
|
"software development practices",
|
||||||
|
"security vulnerabilities",
|
||||||
|
"cloud infrastructure",
|
||||||
|
"distributed computing",
|
||||||
|
][:n_queries]
|
||||||
|
|
||||||
|
# HNSW benchmark
|
||||||
|
hnsw_results = benchmark_backend(
|
||||||
|
backend_name="hnsw",
|
||||||
|
texts=texts,
|
||||||
|
test_queries=test_queries,
|
||||||
|
backend_kwargs={
|
||||||
|
"is_recompute": True, # Enable recompute for fair comparison
|
||||||
|
"M": 16,
|
||||||
|
"efConstruction": 200,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# DiskANN benchmark
|
||||||
|
diskann_results = benchmark_backend(
|
||||||
|
backend_name="diskann",
|
||||||
|
texts=texts,
|
||||||
|
test_queries=test_queries,
|
||||||
|
backend_kwargs={
|
||||||
|
"is_recompute": True, # Enable graph partitioning
|
||||||
|
"num_neighbors": 32,
|
||||||
|
"search_list_size": 50,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Performance comparison
|
||||||
|
print("\n📈 Performance Comparison Results")
|
||||||
|
print(f"{'=' * 60}")
|
||||||
|
print(f"{'Metric':<25} {'HNSW':<15} {'DiskANN':<15} {'Speedup':<10}")
|
||||||
|
print(f"{'-' * 60}")
|
||||||
|
|
||||||
|
# Build time comparison
|
||||||
|
build_speedup = hnsw_results["build_time"] / diskann_results["build_time"]
|
||||||
|
print(
|
||||||
|
f"{'Build Time (s)':<25} {hnsw_results['build_time']:<15.2f} {diskann_results['build_time']:<15.2f} {build_speedup:<10.2f}x"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Search time comparison
|
||||||
|
search_speedup = hnsw_results["avg_search_time_ms"] / diskann_results["avg_search_time_ms"]
|
||||||
|
print(
|
||||||
|
f"{'Search Time (ms)':<25} {hnsw_results['avg_search_time_ms']:<15.1f} {diskann_results['avg_search_time_ms']:<15.1f} {search_speedup:<10.2f}x"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Index size comparison
|
||||||
|
size_ratio = diskann_results["index_size_mb"] / hnsw_results["index_size_mb"]
|
||||||
|
print(
|
||||||
|
f"{'Index Size (MB)':<25} {hnsw_results['index_size_mb']:<15.1f} {diskann_results['index_size_mb']:<15.1f} {size_ratio:<10.2f}x"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Score validity
|
||||||
|
print(
|
||||||
|
f"{'Score Validity (%)':<25} {hnsw_results['score_validity_rate'] * 100:<15.1f} {diskann_results['score_validity_rate'] * 100:<15.1f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"{'=' * 60}")
|
||||||
|
print("\n🎯 Summary:")
|
||||||
|
if search_speedup > 1:
|
||||||
|
print(f" DiskANN is {search_speedup:.2f}x faster than HNSW for search")
|
||||||
|
else:
|
||||||
|
print(f" HNSW is {1 / search_speedup:.2f}x faster than DiskANN for search")
|
||||||
|
|
||||||
|
if size_ratio > 1:
|
||||||
|
print(f" DiskANN uses {size_ratio:.2f}x more storage than HNSW")
|
||||||
|
else:
|
||||||
|
print(f" DiskANN uses {1 / size_ratio:.2f}x less storage than HNSW")
|
||||||
|
|
||||||
|
print(
|
||||||
|
f" Both backends achieved {min(hnsw_results['score_validity_rate'], diskann_results['score_validity_rate']) * 100:.1f}% score validity"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import sys
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Handle help request
|
||||||
|
if len(sys.argv) > 1 and sys.argv[1] in ["-h", "--help", "help"]:
|
||||||
|
print("DiskANN vs HNSW Performance Comparison")
|
||||||
|
print("=" * 50)
|
||||||
|
print(f"Usage: python {sys.argv[0]} [n_docs] [n_queries]")
|
||||||
|
print()
|
||||||
|
print("Arguments:")
|
||||||
|
print(" n_docs Number of documents to index (default: 500)")
|
||||||
|
print(" n_queries Number of test queries to run (default: 10)")
|
||||||
|
print()
|
||||||
|
print("Examples:")
|
||||||
|
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py")
|
||||||
|
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 1000")
|
||||||
|
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20")
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
# Parse command line arguments
|
||||||
|
n_docs = int(sys.argv[1]) if len(sys.argv) > 1 else 500
|
||||||
|
n_queries = int(sys.argv[2]) if len(sys.argv) > 2 else 10
|
||||||
|
|
||||||
|
print("DiskANN vs HNSW Performance Comparison")
|
||||||
|
print("=" * 50)
|
||||||
|
print(f"Dataset: {n_docs} documents, {n_queries} queries")
|
||||||
|
print()
|
||||||
|
|
||||||
|
run_comparison(n_docs=n_docs, n_queries=n_queries)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n⚠️ Benchmark interrupted by user")
|
||||||
|
sys.exit(130)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ Benchmark failed: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
finally:
|
||||||
|
# Ensure clean exit (forceful to prevent rare hangs from atexit/threads)
|
||||||
|
try:
|
||||||
|
gc.collect()
|
||||||
|
print("\n🧹 Cleanup completed")
|
||||||
|
# Flush stdio to ensure message is visible before hard-exit
|
||||||
|
try:
|
||||||
|
import sys as _sys
|
||||||
|
|
||||||
|
_sys.stdout.flush()
|
||||||
|
_sys.stderr.flush()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
# Use os._exit to bypass atexit handlers that may hang in rare cases
|
||||||
|
import os as _os
|
||||||
|
|
||||||
|
_os._exit(0)
|
||||||
141
benchmarks/enron_emails/README.md
Normal file
141
benchmarks/enron_emails/README.md
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
# Enron Emails Benchmark
|
||||||
|
|
||||||
|
A comprehensive RAG benchmark for evaluating LEANN search and generation on the Enron email corpus. It mirrors the structure and CLI of the existing FinanceBench and LAION benches, using stage-based evaluation with Recall@3 and generation timing.
|
||||||
|
|
||||||
|
- Dataset: Enron email CSV (e.g., Kaggle wcukierski/enron-email-dataset) for passages
|
||||||
|
- Queries: corbt/enron_emails_sample_questions (filtered for realistic questions)
|
||||||
|
- Metrics: Recall@3 vs FAISS Flat baseline + Generation evaluation with Qwen3-8B
|
||||||
|
|
||||||
|
## Layout
|
||||||
|
|
||||||
|
benchmarks/enron_emails/
|
||||||
|
- setup_enron_emails.py: Prepare passages, build LEANN index, build FAISS baseline
|
||||||
|
- evaluate_enron_emails.py: Evaluate retrieval recall (Stages 2-5) + generation with Qwen3-8B
|
||||||
|
- data/: Generated passages, queries, embeddings-related files
|
||||||
|
- baseline/: FAISS Flat baseline files
|
||||||
|
- llm_utils.py: LLM utilities for Qwen3-8B generation (in parent directory)
|
||||||
|
|
||||||
|
## Quickstart
|
||||||
|
|
||||||
|
1) Prepare the data and index
|
||||||
|
|
||||||
|
cd benchmarks/enron_emails
|
||||||
|
python setup_enron_emails.py --data-dir data
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- If `--emails-csv` is omitted, the script attempts to download from Kaggle dataset `wcukierski/enron-email-dataset` using Kaggle API (requires `KAGGLE_USERNAME` and `KAGGLE_KEY`).
|
||||||
|
Alternatively, pass a local path to `--emails-csv`.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- The script parses emails, chunks header/body into passages, builds a compact LEANN index, and then builds a FAISS Flat baseline from the same passages and embedding model.
|
||||||
|
- Optionally, it will also create evaluation queries from HuggingFace dataset `corbt/enron_emails_sample_questions`.
|
||||||
|
|
||||||
|
2) Run recall evaluation (Stage 2)
|
||||||
|
|
||||||
|
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 2
|
||||||
|
|
||||||
|
3) Complexity sweep (Stage 3)
|
||||||
|
|
||||||
|
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 3 --target-recall 0.90 --max-queries 200
|
||||||
|
|
||||||
|
Stage 3 uses binary search over complexity to find the minimal value achieving the target Recall@3 (assumes recall is non-decreasing with complexity). The search expands the upper bound as needed and snaps complexity to multiples of 8.
|
||||||
|
|
||||||
|
4) Index comparison (Stage 4)
|
||||||
|
|
||||||
|
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 4 --complexity 88 --max-queries 100 --output results.json
|
||||||
|
|
||||||
|
5) Generation evaluation (Stage 5)
|
||||||
|
|
||||||
|
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 5 --complexity 88 --llm-backend hf --model-name Qwen/Qwen3-8B
|
||||||
|
|
||||||
|
6) Combined index + generation evaluation (Stages 4+5, recommended)
|
||||||
|
|
||||||
|
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 45 --complexity 88 --llm-backend hf
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- Minimal CLI: you can run from repo root with only `--index`, defaults match financebench/laion patterns:
|
||||||
|
- `--stage` defaults to `all` (runs 2, 3, 4, 5)
|
||||||
|
- `--baseline-dir` defaults to `baseline`
|
||||||
|
- `--queries` defaults to `data/evaluation_queries.jsonl` (or falls back to the index directory)
|
||||||
|
- `--llm-backend` defaults to `hf` (HuggingFace), can use `vllm`
|
||||||
|
- `--model-name` defaults to `Qwen/Qwen3-8B`
|
||||||
|
- Fail-fast behavior: no silent fallbacks. If compact index cannot run with recompute, it errors out.
|
||||||
|
- Stage 5 requires Stage 4 retrieval results. Use `--stage 45` to run both efficiently.
|
||||||
|
|
||||||
|
Optional flags:
|
||||||
|
- --queries data/evaluation_queries.jsonl (custom queries file)
|
||||||
|
- --baseline-dir baseline (where FAISS baseline lives)
|
||||||
|
- --complexity 88 (LEANN complexity parameter, optimal for 90% recall)
|
||||||
|
- --llm-backend hf|vllm (LLM backend for generation)
|
||||||
|
- --model-name Qwen/Qwen3-8B (LLM model for generation)
|
||||||
|
- --max-queries 1000 (limit number of queries for evaluation)
|
||||||
|
|
||||||
|
## Files Produced
|
||||||
|
- data/enron_passages_preview.jsonl: Small preview of passages used (for inspection)
|
||||||
|
- data/enron_index_hnsw.leann.*: LEANN index files
|
||||||
|
- baseline/faiss_flat.index + baseline/metadata.pkl: FAISS baseline with passage IDs
|
||||||
|
- data/evaluation_queries.jsonl: Query file (id + query; includes GT IDs for reference)
|
||||||
|
|
||||||
|
## Notes
|
||||||
|
- Evaluates both retrieval Recall@3 and generation timing with Qwen3-8B thinking model.
|
||||||
|
- The emails CSV must contain a column named "message" (raw RFC822 email) and a column named "file" for source identifier. Message-ID headers are parsed as canonical message IDs when present.
|
||||||
|
- Qwen3-8B requires special handling for thinking models with chat templates and <think></think> tag processing.
|
||||||
|
|
||||||
|
## Stages Summary
|
||||||
|
|
||||||
|
- Stage 2 (Recall@3):
|
||||||
|
- Compares LEANN vs FAISS Flat baseline on Recall@3.
|
||||||
|
- Compact index runs with `recompute_embeddings=True`.
|
||||||
|
|
||||||
|
- Stage 3 (Binary Search for Complexity):
|
||||||
|
- Builds a non-compact index (`<index>_noncompact.leann`) and runs binary search with `recompute_embeddings=False` to find the minimal complexity achieving target Recall@3 (default 90%).
|
||||||
|
|
||||||
|
- Stage 4 (Index Comparison):
|
||||||
|
- Reports .index-only sizes for compact vs non-compact.
|
||||||
|
- Measures timings on queries by default: non-compact (no recompute) vs compact (with recompute).
|
||||||
|
- Stores retrieval results for Stage 5 generation evaluation.
|
||||||
|
- Fails fast if compact recompute cannot run.
|
||||||
|
- If `--complexity` is not provided, the script tries to use the best complexity from Stage 3:
|
||||||
|
- First from the current run (when running `--stage all`), otherwise
|
||||||
|
- From `enron_stage3_results.json` saved next to the index during the last Stage 3 run.
|
||||||
|
- If neither exists, Stage 4 will error and ask you to run Stage 3 or pass `--complexity`.
|
||||||
|
|
||||||
|
- Stage 5 (Generation Evaluation):
|
||||||
|
- Uses Qwen3-8B thinking model for RAG generation on retrieved documents from Stage 4.
|
||||||
|
- Supports HuggingFace (`hf`) and vLLM (`vllm`) backends.
|
||||||
|
- Measures generation timing separately from search timing.
|
||||||
|
- Requires Stage 4 results (no additional searching performed).
|
||||||
|
|
||||||
|
## Example Results
|
||||||
|
|
||||||
|
These are sample results obtained on Enron data using all-mpnet-base-v2 and Qwen3-8B.
|
||||||
|
|
||||||
|
- Stage 3 (Binary Search):
|
||||||
|
- Minimal complexity achieving 90% Recall@3: 88
|
||||||
|
- Sampled points:
|
||||||
|
- C=8 → 59.9% Recall@3
|
||||||
|
- C=72 → 89.4% Recall@3
|
||||||
|
- C=88 → 90.2% Recall@3
|
||||||
|
- C=96 → 90.7% Recall@3
|
||||||
|
- C=112 → 91.1% Recall@3
|
||||||
|
- C=136 → 91.3% Recall@3
|
||||||
|
- C=256 → 92.0% Recall@3
|
||||||
|
|
||||||
|
- Stage 4 (Index Sizes, .index only):
|
||||||
|
- Compact: ~2.2 MB
|
||||||
|
- Non-compact: ~82.0 MB
|
||||||
|
- Storage saving by compact: ~97.3%
|
||||||
|
|
||||||
|
- Stage 4 (Search Timing, 988 queries, complexity=88):
|
||||||
|
- Non-compact (no recompute): ~0.0075 s avg per query
|
||||||
|
- Compact (with recompute): ~1.981 s avg per query
|
||||||
|
- Speed ratio (non-compact/compact): ~0.0038x
|
||||||
|
|
||||||
|
- Stage 5 (RAG Generation, 988 queries, Qwen3-8B):
|
||||||
|
- Average generation time: ~22.302 s per query
|
||||||
|
- Total queries processed: 988
|
||||||
|
- LLM backend: HuggingFace transformers
|
||||||
|
- Model: Qwen/Qwen3-8B (thinking model with <think></think> processing)
|
||||||
|
|
||||||
|
Full JSON output is saved by the script (see `--output`), e.g.:
|
||||||
|
`benchmarks/enron_emails/results_enron_stage45.json`.
|
||||||
1
benchmarks/enron_emails/data/.gitignore
vendored
Normal file
1
benchmarks/enron_emails/data/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
downloads/
|
||||||
614
benchmarks/enron_emails/evaluate_enron_emails.py
Normal file
614
benchmarks/enron_emails/evaluate_enron_emails.py
Normal file
@@ -0,0 +1,614 @@
|
|||||||
|
"""
|
||||||
|
Enron Emails Benchmark Evaluation - Retrieval Recall@3 (Stages 2/3/4)
|
||||||
|
Follows the style of FinanceBench/LAION: Stage 2 recall vs FAISS baseline,
|
||||||
|
Stage 3 complexity sweep to target recall, Stage 4 index comparison.
|
||||||
|
On errors, fail fast without fallbacks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from leann import LeannBuilder, LeannSearcher
|
||||||
|
from leann_backend_hnsw import faiss
|
||||||
|
|
||||||
|
from ..llm_utils import generate_hf, generate_vllm, load_hf_model, load_vllm_model
|
||||||
|
|
||||||
|
# Setup logging to reduce verbose output
|
||||||
|
logging.basicConfig(level=logging.WARNING)
|
||||||
|
logging.getLogger("leann.api").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
|
||||||
|
class RecallEvaluator:
|
||||||
|
"""Stage 2: Evaluate Recall@3 (LEANN vs FAISS)"""
|
||||||
|
|
||||||
|
def __init__(self, index_path: str, baseline_dir: str):
|
||||||
|
self.index_path = index_path
|
||||||
|
self.baseline_dir = baseline_dir
|
||||||
|
self.searcher = LeannSearcher(index_path)
|
||||||
|
|
||||||
|
baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
|
||||||
|
metadata_path = os.path.join(baseline_dir, "metadata.pkl")
|
||||||
|
|
||||||
|
self.faiss_index = faiss.read_index(baseline_index_path)
|
||||||
|
with open(metadata_path, "rb") as f:
|
||||||
|
self.passage_ids = pickle.load(f)
|
||||||
|
|
||||||
|
print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} vectors")
|
||||||
|
|
||||||
|
# No fallbacks here; if embedding server is needed but fails, the caller will see the error.
|
||||||
|
|
||||||
|
def evaluate_recall_at_3(
|
||||||
|
self, queries: list[str], complexity: int = 64, recompute_embeddings: bool = True
|
||||||
|
) -> float:
|
||||||
|
"""Evaluate recall@3 using FAISS Flat as ground truth"""
|
||||||
|
from leann.api import compute_embeddings
|
||||||
|
|
||||||
|
recompute_str = "with recompute" if recompute_embeddings else "no recompute"
|
||||||
|
print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")
|
||||||
|
|
||||||
|
total_recall = 0.0
|
||||||
|
for i, query in enumerate(queries):
|
||||||
|
# Compute query embedding with the same model/mode as the index
|
||||||
|
q_emb = compute_embeddings(
|
||||||
|
[query],
|
||||||
|
self.searcher.embedding_model,
|
||||||
|
mode=self.searcher.embedding_mode,
|
||||||
|
use_server=False,
|
||||||
|
).astype(np.float32)
|
||||||
|
|
||||||
|
# Search FAISS Flat ground truth
|
||||||
|
n = q_emb.shape[0]
|
||||||
|
k = 3
|
||||||
|
distances = np.zeros((n, k), dtype=np.float32)
|
||||||
|
labels = np.zeros((n, k), dtype=np.int64)
|
||||||
|
self.faiss_index.search(
|
||||||
|
n,
|
||||||
|
faiss.swig_ptr(q_emb),
|
||||||
|
k,
|
||||||
|
faiss.swig_ptr(distances),
|
||||||
|
faiss.swig_ptr(labels),
|
||||||
|
)
|
||||||
|
|
||||||
|
baseline_ids = {self.passage_ids[idx] for idx in labels[0]}
|
||||||
|
|
||||||
|
# Search with LEANN (may require embedding server depending on index configuration)
|
||||||
|
results = self.searcher.search(
|
||||||
|
query,
|
||||||
|
top_k=3,
|
||||||
|
complexity=complexity,
|
||||||
|
recompute_embeddings=recompute_embeddings,
|
||||||
|
)
|
||||||
|
test_ids = {r.id for r in results}
|
||||||
|
|
||||||
|
intersection = test_ids.intersection(baseline_ids)
|
||||||
|
recall = len(intersection) / 3.0
|
||||||
|
total_recall += recall
|
||||||
|
|
||||||
|
if i < 3:
|
||||||
|
print(f" Q{i + 1}: '{query[:60]}...' -> Recall@3: {recall:.3f}")
|
||||||
|
print(f" FAISS: {list(baseline_ids)}")
|
||||||
|
print(f" LEANN: {list(test_ids)}")
|
||||||
|
print(f" ∩: {list(intersection)}")
|
||||||
|
|
||||||
|
avg = total_recall / max(1, len(queries))
|
||||||
|
print(f"📊 Average Recall@3: {avg:.3f} ({avg * 100:.1f}%)")
|
||||||
|
return avg
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
if hasattr(self, "searcher"):
|
||||||
|
self.searcher.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
class EnronEvaluator:
|
||||||
|
def __init__(self, index_path: str):
|
||||||
|
self.index_path = index_path
|
||||||
|
self.searcher = LeannSearcher(index_path)
|
||||||
|
|
||||||
|
def load_queries(self, queries_file: str) -> list[str]:
|
||||||
|
queries: list[str] = []
|
||||||
|
with open(queries_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if not line.strip():
|
||||||
|
continue
|
||||||
|
data = json.loads(line)
|
||||||
|
if "query" in data:
|
||||||
|
queries.append(data["query"])
|
||||||
|
print(f"📊 Loaded {len(queries)} queries from {queries_file}")
|
||||||
|
return queries
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
if self.searcher:
|
||||||
|
self.searcher.cleanup()
|
||||||
|
|
||||||
|
def analyze_index_sizes(self) -> dict:
|
||||||
|
"""Analyze index sizes (.index only), similar to LAION bench."""
|
||||||
|
|
||||||
|
print("📏 Analyzing index sizes (.index only)...")
|
||||||
|
index_path = Path(self.index_path)
|
||||||
|
index_dir = index_path.parent
|
||||||
|
index_name = index_path.stem
|
||||||
|
|
||||||
|
sizes: dict[str, float] = {}
|
||||||
|
index_file = index_dir / f"{index_name}.index"
|
||||||
|
meta_file = index_dir / f"{index_path.name}.meta.json"
|
||||||
|
passages_file = index_dir / f"{index_path.name}.passages.jsonl"
|
||||||
|
passages_idx_file = index_dir / f"{index_path.name}.passages.idx"
|
||||||
|
|
||||||
|
sizes["index_only_mb"] = (
|
||||||
|
index_file.stat().st_size / (1024 * 1024) if index_file.exists() else 0.0
|
||||||
|
)
|
||||||
|
sizes["metadata_mb"] = (
|
||||||
|
meta_file.stat().st_size / (1024 * 1024) if meta_file.exists() else 0.0
|
||||||
|
)
|
||||||
|
sizes["passages_text_mb"] = (
|
||||||
|
passages_file.stat().st_size / (1024 * 1024) if passages_file.exists() else 0.0
|
||||||
|
)
|
||||||
|
sizes["passages_index_mb"] = (
|
||||||
|
passages_idx_file.stat().st_size / (1024 * 1024) if passages_idx_file.exists() else 0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f" 📁 .index size: {sizes['index_only_mb']:.1f} MB")
|
||||||
|
return sizes
|
||||||
|
|
||||||
|
def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
|
||||||
|
"""Create a non-compact index for comparison using current passages and embeddings."""
|
||||||
|
|
||||||
|
current_index_path = Path(self.index_path)
|
||||||
|
current_index_dir = current_index_path.parent
|
||||||
|
current_index_name = current_index_path.name
|
||||||
|
|
||||||
|
# Read metadata to get passage source and embedding model
|
||||||
|
meta_path = current_index_dir / f"{current_index_name}.meta.json"
|
||||||
|
with open(meta_path, encoding="utf-8") as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
|
||||||
|
passage_source = meta["passage_sources"][0]
|
||||||
|
passage_file = passage_source["path"]
|
||||||
|
|
||||||
|
# Convert relative path to absolute
|
||||||
|
if not Path(passage_file).is_absolute():
|
||||||
|
passage_file = current_index_dir / Path(passage_file).name
|
||||||
|
|
||||||
|
# Load all passages and ids
|
||||||
|
ids: list[str] = []
|
||||||
|
texts: list[str] = []
|
||||||
|
with open(passage_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
data = json.loads(line)
|
||||||
|
ids.append(str(data["id"]))
|
||||||
|
texts.append(data["text"])
|
||||||
|
|
||||||
|
# Compute embeddings using the same method as LEANN
|
||||||
|
from leann.api import compute_embeddings
|
||||||
|
|
||||||
|
embeddings = compute_embeddings(
|
||||||
|
texts,
|
||||||
|
meta["embedding_model"],
|
||||||
|
mode=meta.get("embedding_mode", "sentence-transformers"),
|
||||||
|
use_server=False,
|
||||||
|
).astype(np.float32)
|
||||||
|
|
||||||
|
# Build non-compact index with same passages and embeddings
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model=meta["embedding_model"],
|
||||||
|
embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
|
||||||
|
is_recompute=False,
|
||||||
|
is_compact=False,
|
||||||
|
**{
|
||||||
|
k: v
|
||||||
|
for k, v in meta.get("backend_kwargs", {}).items()
|
||||||
|
if k not in ["is_recompute", "is_compact"]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Persist a pickle for build_index_from_embeddings
|
||||||
|
pkl_path = current_index_dir / f"{Path(non_compact_index_path).stem}_embeddings.pkl"
|
||||||
|
with open(pkl_path, "wb") as pf:
|
||||||
|
pickle.dump((ids, embeddings), pf)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"🔨 Building non-compact index at {non_compact_index_path} from precomputed embeddings..."
|
||||||
|
)
|
||||||
|
builder.build_index_from_embeddings(non_compact_index_path, str(pkl_path))
|
||||||
|
|
||||||
|
# Analyze the non-compact index size
|
||||||
|
temp_evaluator = EnronEvaluator(non_compact_index_path)
|
||||||
|
non_compact_sizes = temp_evaluator.analyze_index_sizes()
|
||||||
|
non_compact_sizes["index_type"] = "non_compact"
|
||||||
|
|
||||||
|
return non_compact_sizes
|
||||||
|
|
||||||
|
def compare_index_performance(
|
||||||
|
self, non_compact_path: str, compact_path: str, test_queries: list[str], complexity: int
|
||||||
|
) -> dict:
|
||||||
|
"""Compare search speed for non-compact vs compact indexes."""
|
||||||
|
import time
|
||||||
|
|
||||||
|
results: dict = {
|
||||||
|
"non_compact": {"search_times": []},
|
||||||
|
"compact": {"search_times": []},
|
||||||
|
"avg_search_times": {},
|
||||||
|
"speed_ratio": 0.0,
|
||||||
|
"retrieval_results": [], # Store retrieval results for Stage 5
|
||||||
|
}
|
||||||
|
|
||||||
|
print("⚡ Comparing search performance between indexes...")
|
||||||
|
# Non-compact (no recompute)
|
||||||
|
print(" 🔍 Testing non-compact index (no recompute)...")
|
||||||
|
non_compact_searcher = LeannSearcher(non_compact_path)
|
||||||
|
for q in test_queries:
|
||||||
|
t0 = time.time()
|
||||||
|
_ = non_compact_searcher.search(
|
||||||
|
q, top_k=3, complexity=complexity, recompute_embeddings=False
|
||||||
|
)
|
||||||
|
results["non_compact"]["search_times"].append(time.time() - t0)
|
||||||
|
|
||||||
|
# Compact (with recompute). Fail fast if it cannot run.
|
||||||
|
print(" 🔍 Testing compact index (with recompute)...")
|
||||||
|
compact_searcher = LeannSearcher(compact_path)
|
||||||
|
for q in test_queries:
|
||||||
|
t0 = time.time()
|
||||||
|
docs = compact_searcher.search(
|
||||||
|
q, top_k=3, complexity=complexity, recompute_embeddings=True
|
||||||
|
)
|
||||||
|
results["compact"]["search_times"].append(time.time() - t0)
|
||||||
|
|
||||||
|
# Store retrieval results for Stage 5
|
||||||
|
results["retrieval_results"].append(
|
||||||
|
{"query": q, "retrieved_docs": [{"id": doc.id, "text": doc.text} for doc in docs]}
|
||||||
|
)
|
||||||
|
compact_searcher.cleanup()
|
||||||
|
|
||||||
|
if results["non_compact"]["search_times"]:
|
||||||
|
results["avg_search_times"]["non_compact"] = sum(
|
||||||
|
results["non_compact"]["search_times"]
|
||||||
|
) / len(results["non_compact"]["search_times"])
|
||||||
|
if results["compact"]["search_times"]:
|
||||||
|
results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
|
||||||
|
results["compact"]["search_times"]
|
||||||
|
)
|
||||||
|
if results["avg_search_times"].get("compact", 0) > 0:
|
||||||
|
results["speed_ratio"] = (
|
||||||
|
results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
results["speed_ratio"] = 0.0
|
||||||
|
|
||||||
|
non_compact_searcher.cleanup()
|
||||||
|
return results
|
||||||
|
|
||||||
|
def evaluate_complexity(
|
||||||
|
self,
|
||||||
|
recall_eval: "RecallEvaluator",
|
||||||
|
queries: list[str],
|
||||||
|
target: float = 0.90,
|
||||||
|
c_min: int = 8,
|
||||||
|
c_max: int = 256,
|
||||||
|
max_iters: int = 10,
|
||||||
|
recompute: bool = False,
|
||||||
|
) -> dict:
|
||||||
|
"""Binary search minimal complexity achieving target recall (monotonic assumption)."""
|
||||||
|
|
||||||
|
def round_c(x: int) -> int:
|
||||||
|
# snap to multiple of 8 like other benches typically do
|
||||||
|
return max(1, int((x + 7) // 8) * 8)
|
||||||
|
|
||||||
|
metrics: list[dict] = []
|
||||||
|
|
||||||
|
lo = round_c(c_min)
|
||||||
|
hi = round_c(c_max)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"🧪 Binary search complexity in [{lo}, {hi}] for target Recall@3>={int(target * 100)}%..."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure upper bound can reach target; expand if needed (up to a cap)
|
||||||
|
r_lo = recall_eval.evaluate_recall_at_3(
|
||||||
|
queries, complexity=lo, recompute_embeddings=recompute
|
||||||
|
)
|
||||||
|
metrics.append({"complexity": lo, "recall_at_3": r_lo})
|
||||||
|
r_hi = recall_eval.evaluate_recall_at_3(
|
||||||
|
queries, complexity=hi, recompute_embeddings=recompute
|
||||||
|
)
|
||||||
|
metrics.append({"complexity": hi, "recall_at_3": r_hi})
|
||||||
|
|
||||||
|
cap = 1024
|
||||||
|
while r_hi < target and hi < cap:
|
||||||
|
lo = hi
|
||||||
|
r_lo = r_hi
|
||||||
|
hi = round_c(hi * 2)
|
||||||
|
r_hi = recall_eval.evaluate_recall_at_3(
|
||||||
|
queries, complexity=hi, recompute_embeddings=recompute
|
||||||
|
)
|
||||||
|
metrics.append({"complexity": hi, "recall_at_3": r_hi})
|
||||||
|
|
||||||
|
if r_hi < target:
|
||||||
|
print(f"⚠️ Max complexity {hi} did not reach target recall {target:.2f}.")
|
||||||
|
print("📈 Observations:")
|
||||||
|
for m in metrics:
|
||||||
|
print(f" C={m['complexity']:>4} -> Recall@3={m['recall_at_3'] * 100:.1f}%")
|
||||||
|
return {"metrics": metrics, "best_complexity": None, "target_recall": target}
|
||||||
|
|
||||||
|
# Binary search within [lo, hi]
|
||||||
|
best = hi
|
||||||
|
iters = 0
|
||||||
|
while lo < hi and iters < max_iters:
|
||||||
|
mid = round_c((lo + hi) // 2)
|
||||||
|
r_mid = recall_eval.evaluate_recall_at_3(
|
||||||
|
queries, complexity=mid, recompute_embeddings=recompute
|
||||||
|
)
|
||||||
|
metrics.append({"complexity": mid, "recall_at_3": r_mid})
|
||||||
|
if r_mid >= target:
|
||||||
|
best = mid
|
||||||
|
hi = mid
|
||||||
|
else:
|
||||||
|
lo = mid + 8 # move past mid, respecting multiple-of-8 step
|
||||||
|
iters += 1
|
||||||
|
|
||||||
|
print("📈 Binary search results (sampled points):")
|
||||||
|
# Print unique complexity entries ordered by complexity
|
||||||
|
for m in sorted(
|
||||||
|
{m["complexity"]: m for m in metrics}.values(), key=lambda x: x["complexity"]
|
||||||
|
):
|
||||||
|
print(f" C={m['complexity']:>4} -> Recall@3={m['recall_at_3'] * 100:.1f}%")
|
||||||
|
print(f"✅ Minimal complexity achieving {int(target * 100)}% recall: {best}")
|
||||||
|
return {"metrics": metrics, "best_complexity": best, "target_recall": target}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Enron Emails Benchmark Evaluation")
|
||||||
|
parser.add_argument("--index", required=True, help="Path to LEANN index")
|
||||||
|
parser.add_argument(
|
||||||
|
"--queries", default="data/evaluation_queries.jsonl", help="Path to evaluation queries"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--stage",
|
||||||
|
choices=["2", "3", "4", "5", "all", "45"],
|
||||||
|
default="all",
|
||||||
|
help="Which stage to run (2=recall, 3=complexity, 4=index comparison, 5=generation)",
|
||||||
|
)
|
||||||
|
parser.add_argument("--complexity", type=int, default=None, help="LEANN search complexity")
|
||||||
|
parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-queries", type=int, help="Limit number of queries to evaluate", default=1000
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--target-recall", type=float, default=0.90, help="Target Recall@3 for Stage 3"
|
||||||
|
)
|
||||||
|
parser.add_argument("--output", help="Save results to JSON file")
|
||||||
|
parser.add_argument("--llm-backend", choices=["hf", "vllm"], default="hf", help="LLM backend")
|
||||||
|
parser.add_argument("--model-name", default="Qwen/Qwen3-8B", help="Model name")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Resolve queries file: if default path not found, fall back to index's directory
|
||||||
|
if not os.path.exists(args.queries):
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
idx_dir = Path(args.index).parent
|
||||||
|
fallback_q = idx_dir / "evaluation_queries.jsonl"
|
||||||
|
if fallback_q.exists():
|
||||||
|
args.queries = str(fallback_q)
|
||||||
|
|
||||||
|
baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
|
||||||
|
if not os.path.exists(baseline_index_path):
|
||||||
|
print(f"❌ FAISS baseline not found at {baseline_index_path}")
|
||||||
|
print("💡 Please run setup_enron_emails.py first to build the baseline")
|
||||||
|
raise SystemExit(1)
|
||||||
|
|
||||||
|
results_out: dict = {}
|
||||||
|
|
||||||
|
if args.stage in ("2", "all"):
|
||||||
|
print("🚀 Starting Stage 2: Recall@3 evaluation")
|
||||||
|
evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||||
|
|
||||||
|
enron_eval = EnronEvaluator(args.index)
|
||||||
|
queries = enron_eval.load_queries(args.queries)
|
||||||
|
queries = queries[:10]
|
||||||
|
print(f"🧪 Using first {len(queries)} queries")
|
||||||
|
|
||||||
|
complexity = args.complexity or 64
|
||||||
|
r = evaluator.evaluate_recall_at_3(queries, complexity)
|
||||||
|
results_out["stage2"] = {"complexity": complexity, "recall_at_3": r}
|
||||||
|
evaluator.cleanup()
|
||||||
|
enron_eval.cleanup()
|
||||||
|
print("✅ Stage 2 completed!\n")
|
||||||
|
|
||||||
|
if args.stage in ("3", "all"):
|
||||||
|
print("🚀 Starting Stage 3: Binary search for target recall (no recompute)")
|
||||||
|
enron_eval = EnronEvaluator(args.index)
|
||||||
|
queries = enron_eval.load_queries(args.queries)
|
||||||
|
queries = queries[: args.max_queries]
|
||||||
|
print(f"🧪 Using first {len(queries)} queries")
|
||||||
|
|
||||||
|
# Build non-compact index for fast binary search (recompute_embeddings=False)
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
index_path = Path(args.index)
|
||||||
|
non_compact_index_path = str(index_path.parent / f"{index_path.stem}_noncompact.leann")
|
||||||
|
enron_eval.create_non_compact_index_for_comparison(non_compact_index_path)
|
||||||
|
|
||||||
|
# Use non-compact evaluator for binary search with recompute=False
|
||||||
|
evaluator_nc = RecallEvaluator(non_compact_index_path, args.baseline_dir)
|
||||||
|
sweep = enron_eval.evaluate_complexity(
|
||||||
|
evaluator_nc, queries, target=args.target_recall, recompute=False
|
||||||
|
)
|
||||||
|
results_out["stage3"] = sweep
|
||||||
|
# Persist default stage 3 results near the index for Stage 4 auto-pickup
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
default_stage3_path = Path(args.index).parent / "enron_stage3_results.json"
|
||||||
|
with open(default_stage3_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump({"stage3": sweep}, f, indent=2)
|
||||||
|
print(f"📝 Saved Stage 3 summary to {default_stage3_path}")
|
||||||
|
evaluator_nc.cleanup()
|
||||||
|
enron_eval.cleanup()
|
||||||
|
print("✅ Stage 3 completed!\n")
|
||||||
|
|
||||||
|
if args.stage in ("4", "all", "45"):
|
||||||
|
print("🚀 Starting Stage 4: Index size + performance comparison")
|
||||||
|
evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||||
|
enron_eval = EnronEvaluator(args.index)
|
||||||
|
queries = enron_eval.load_queries(args.queries)
|
||||||
|
test_q = queries[: min(args.max_queries, len(queries))]
|
||||||
|
|
||||||
|
current_sizes = enron_eval.analyze_index_sizes()
|
||||||
|
# Build non-compact index for comparison (no fallback)
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
index_path = Path(args.index)
|
||||||
|
non_compact_path = str(index_path.parent / f"{index_path.stem}_noncompact.leann")
|
||||||
|
non_compact_sizes = enron_eval.create_non_compact_index_for_comparison(non_compact_path)
|
||||||
|
nc_eval = EnronEvaluator(non_compact_path)
|
||||||
|
|
||||||
|
if (
|
||||||
|
current_sizes.get("index_only_mb", 0) > 0
|
||||||
|
and non_compact_sizes.get("index_only_mb", 0) > 0
|
||||||
|
):
|
||||||
|
storage_saving_percent = max(
|
||||||
|
0.0,
|
||||||
|
100.0 * (1.0 - current_sizes["index_only_mb"] / non_compact_sizes["index_only_mb"]),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
storage_saving_percent = 0.0
|
||||||
|
|
||||||
|
if args.complexity is None:
|
||||||
|
# Prefer in-session Stage 3 result
|
||||||
|
if "stage3" in results_out and results_out["stage3"].get("best_complexity") is not None:
|
||||||
|
complexity = results_out["stage3"]["best_complexity"]
|
||||||
|
print(f"📥 Using best complexity from Stage 3 in-session: {complexity}")
|
||||||
|
else:
|
||||||
|
# Try to load last saved Stage 3 result near index
|
||||||
|
default_stage3_path = Path(args.index).parent / "enron_stage3_results.json"
|
||||||
|
if default_stage3_path.exists():
|
||||||
|
with open(default_stage3_path, encoding="utf-8") as f:
|
||||||
|
prev = json.load(f)
|
||||||
|
complexity = prev.get("stage3", {}).get("best_complexity")
|
||||||
|
if complexity is None:
|
||||||
|
raise SystemExit(
|
||||||
|
"❌ Stage 4: No --complexity and no best_complexity found in saved Stage 3 results"
|
||||||
|
)
|
||||||
|
print(f"📥 Using best complexity from saved Stage 3: {complexity}")
|
||||||
|
else:
|
||||||
|
raise SystemExit(
|
||||||
|
"❌ Stage 4 requires --complexity if Stage 3 hasn't been run. Run stage 3 first or pass --complexity."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
complexity = args.complexity
|
||||||
|
|
||||||
|
comp = enron_eval.compare_index_performance(
|
||||||
|
non_compact_path, args.index, test_q, complexity=complexity
|
||||||
|
)
|
||||||
|
results_out["stage4"] = {
|
||||||
|
"current_index": current_sizes,
|
||||||
|
"non_compact_index": non_compact_sizes,
|
||||||
|
"storage_saving_percent": storage_saving_percent,
|
||||||
|
"performance_comparison": comp,
|
||||||
|
}
|
||||||
|
nc_eval.cleanup()
|
||||||
|
evaluator.cleanup()
|
||||||
|
enron_eval.cleanup()
|
||||||
|
print("✅ Stage 4 completed!\n")
|
||||||
|
|
||||||
|
if args.stage in ("5", "all"):
|
||||||
|
print("🚀 Starting Stage 5: Generation evaluation with Qwen3-8B")
|
||||||
|
|
||||||
|
# Check if Stage 4 results exist
|
||||||
|
if "stage4" not in results_out or "performance_comparison" not in results_out["stage4"]:
|
||||||
|
print("❌ Stage 5 requires Stage 4 retrieval results")
|
||||||
|
print("💡 Run Stage 4 first or use --stage all")
|
||||||
|
raise SystemExit(1)
|
||||||
|
|
||||||
|
retrieval_results = results_out["stage4"]["performance_comparison"]["retrieval_results"]
|
||||||
|
if not retrieval_results:
|
||||||
|
print("❌ No retrieval results found from Stage 4")
|
||||||
|
raise SystemExit(1)
|
||||||
|
|
||||||
|
print(f"📁 Using {len(retrieval_results)} retrieval results from Stage 4")
|
||||||
|
|
||||||
|
# Load LLM
|
||||||
|
try:
|
||||||
|
if args.llm_backend == "hf":
|
||||||
|
tokenizer, model = load_hf_model(args.model_name)
|
||||||
|
|
||||||
|
def llm_func(prompt):
|
||||||
|
return generate_hf(tokenizer, model, prompt)
|
||||||
|
else: # vllm
|
||||||
|
llm, sampling_params = load_vllm_model(args.model_name)
|
||||||
|
|
||||||
|
def llm_func(prompt):
|
||||||
|
return generate_vllm(llm, sampling_params, prompt)
|
||||||
|
|
||||||
|
# Run generation using stored retrieval results
|
||||||
|
import time
|
||||||
|
|
||||||
|
from llm_utils import create_prompt
|
||||||
|
|
||||||
|
generation_times = []
|
||||||
|
responses = []
|
||||||
|
|
||||||
|
print("🤖 Running generation on pre-retrieved results...")
|
||||||
|
for i, item in enumerate(retrieval_results):
|
||||||
|
query = item["query"]
|
||||||
|
retrieved_docs = item["retrieved_docs"]
|
||||||
|
|
||||||
|
# Prepare context from retrieved docs
|
||||||
|
context = "\n\n".join([doc["text"] for doc in retrieved_docs])
|
||||||
|
prompt = create_prompt(context, query, "emails")
|
||||||
|
|
||||||
|
# Time generation only
|
||||||
|
gen_start = time.time()
|
||||||
|
response = llm_func(prompt)
|
||||||
|
gen_time = time.time() - gen_start
|
||||||
|
|
||||||
|
generation_times.append(gen_time)
|
||||||
|
responses.append(response)
|
||||||
|
|
||||||
|
if i < 3:
|
||||||
|
print(f" Q{i + 1}: Gen={gen_time:.3f}s")
|
||||||
|
|
||||||
|
avg_gen_time = sum(generation_times) / len(generation_times)
|
||||||
|
|
||||||
|
print("\n📊 Generation Results:")
|
||||||
|
print(f" Total Queries: {len(retrieval_results)}")
|
||||||
|
print(f" Avg Generation Time: {avg_gen_time:.3f}s")
|
||||||
|
print(" (Search time from Stage 4)")
|
||||||
|
|
||||||
|
results_out["stage5"] = {
|
||||||
|
"total_queries": len(retrieval_results),
|
||||||
|
"avg_generation_time": avg_gen_time,
|
||||||
|
"generation_times": generation_times,
|
||||||
|
"responses": responses,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Show sample results
|
||||||
|
print("\n📝 Sample Results:")
|
||||||
|
for i in range(min(3, len(retrieval_results))):
|
||||||
|
query = retrieval_results[i]["query"]
|
||||||
|
response = responses[i]
|
||||||
|
print(f" Q{i + 1}: {query[:60]}...")
|
||||||
|
print(f" A{i + 1}: {response[:100]}...")
|
||||||
|
print()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Generation evaluation failed: {e}")
|
||||||
|
print("💡 Make sure transformers/vllm is installed and model is available")
|
||||||
|
|
||||||
|
print("✅ Stage 5 completed!\n")
|
||||||
|
|
||||||
|
if args.output and results_out:
|
||||||
|
with open(args.output, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(results_out, f, indent=2)
|
||||||
|
print(f"📝 Saved results to {args.output}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
359
benchmarks/enron_emails/setup_enron_emails.py
Normal file
359
benchmarks/enron_emails/setup_enron_emails.py
Normal file
@@ -0,0 +1,359 @@
|
|||||||
|
"""
|
||||||
|
Enron Emails Benchmark Setup Script
|
||||||
|
Prepares passages from emails.csv, builds LEANN index, and FAISS Flat baseline
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import csv
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from email import message_from_string
|
||||||
|
from email.policy import default
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from leann import LeannBuilder
|
||||||
|
|
||||||
|
|
||||||
|
class EnronSetup:
|
||||||
|
def __init__(self, data_dir: str = "data"):
|
||||||
|
self.data_dir = Path(data_dir)
|
||||||
|
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
self.passages_preview = self.data_dir / "enron_passages_preview.jsonl"
|
||||||
|
self.index_path = self.data_dir / "enron_index_hnsw.leann"
|
||||||
|
self.queries_file = self.data_dir / "evaluation_queries.jsonl"
|
||||||
|
self.downloads_dir = self.data_dir / "downloads"
|
||||||
|
self.downloads_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# Dataset acquisition
|
||||||
|
# ----------------------------
|
||||||
|
def ensure_emails_csv(self, emails_csv: Optional[str]) -> str:
|
||||||
|
"""Return a path to emails.csv, downloading from Kaggle if needed."""
|
||||||
|
if emails_csv:
|
||||||
|
p = Path(emails_csv)
|
||||||
|
if not p.exists():
|
||||||
|
raise FileNotFoundError(f"emails.csv not found: {emails_csv}")
|
||||||
|
return str(p)
|
||||||
|
|
||||||
|
print(
|
||||||
|
"📥 Trying to download Enron emails.csv from Kaggle (wcukierski/enron-email-dataset)..."
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
from kaggle.api.kaggle_api_extended import KaggleApi
|
||||||
|
|
||||||
|
api = KaggleApi()
|
||||||
|
api.authenticate()
|
||||||
|
api.dataset_download_files(
|
||||||
|
"wcukierski/enron-email-dataset", path=str(self.downloads_dir), unzip=True
|
||||||
|
)
|
||||||
|
candidate = self.downloads_dir / "emails.csv"
|
||||||
|
if candidate.exists():
|
||||||
|
print(f"✅ Downloaded emails.csv: {candidate}")
|
||||||
|
return str(candidate)
|
||||||
|
else:
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"emails.csv was not found in {self.downloads_dir} after Kaggle download"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(
|
||||||
|
"❌ Could not download via Kaggle automatically. Provide --emails-csv or configure Kaggle API."
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
" Set KAGGLE_USERNAME and KAGGLE_KEY env vars, or place emails.csv locally and pass --emails-csv."
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# Data preparation
|
||||||
|
# ----------------------------
|
||||||
|
@staticmethod
|
||||||
|
def _extract_message_id(raw_email: str) -> str:
|
||||||
|
msg = message_from_string(raw_email, policy=default)
|
||||||
|
val = msg.get("Message-ID", "")
|
||||||
|
if val.startswith("<") and val.endswith(">"):
|
||||||
|
val = val[1:-1]
|
||||||
|
return val or ""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _split_header_body(raw_email: str) -> tuple[str, str]:
|
||||||
|
parts = raw_email.split("\n\n", 1)
|
||||||
|
if len(parts) == 2:
|
||||||
|
return parts[0].strip(), parts[1].strip()
|
||||||
|
# Heuristic fallback
|
||||||
|
first_lines = raw_email.splitlines()
|
||||||
|
if first_lines and ":" in first_lines[0]:
|
||||||
|
return raw_email.strip(), ""
|
||||||
|
return "", raw_email.strip()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _split_fixed_words(text: str, chunk_words: int, keep_last: bool) -> list[str]:
|
||||||
|
text = (text or "").strip()
|
||||||
|
if not text:
|
||||||
|
return []
|
||||||
|
if chunk_words <= 0:
|
||||||
|
return [text]
|
||||||
|
words = text.split()
|
||||||
|
if not words:
|
||||||
|
return []
|
||||||
|
limit = len(words)
|
||||||
|
if not keep_last:
|
||||||
|
limit = (len(words) // chunk_words) * chunk_words
|
||||||
|
if limit == 0:
|
||||||
|
return []
|
||||||
|
chunks = [" ".join(words[i : i + chunk_words]) for i in range(0, limit, chunk_words)]
|
||||||
|
return [c for c in (s.strip() for s in chunks) if c]
|
||||||
|
|
||||||
|
def _iter_passages_from_csv(
|
||||||
|
self,
|
||||||
|
emails_csv: Path,
|
||||||
|
chunk_words: int = 256,
|
||||||
|
keep_last_header: bool = True,
|
||||||
|
keep_last_body: bool = True,
|
||||||
|
max_emails: int | None = None,
|
||||||
|
) -> Iterable[dict]:
|
||||||
|
with open(emails_csv, encoding="utf-8") as f:
|
||||||
|
reader = csv.DictReader(f)
|
||||||
|
count = 0
|
||||||
|
for i, row in enumerate(reader):
|
||||||
|
if max_emails is not None and count >= max_emails:
|
||||||
|
break
|
||||||
|
|
||||||
|
raw_message = row.get("message", "")
|
||||||
|
email_file_id = row.get("file", "")
|
||||||
|
|
||||||
|
if not raw_message.strip():
|
||||||
|
continue
|
||||||
|
|
||||||
|
message_id = self._extract_message_id(raw_message)
|
||||||
|
if not message_id:
|
||||||
|
# Fallback ID based on CSV position and file path
|
||||||
|
safe_file = re.sub(r"[^A-Za-z0-9_.-]", "_", email_file_id)
|
||||||
|
message_id = f"enron_{i}_{safe_file}"
|
||||||
|
|
||||||
|
header, body = self._split_header_body(raw_message)
|
||||||
|
|
||||||
|
# Header chunks
|
||||||
|
for chunk in self._split_fixed_words(header, chunk_words, keep_last_header):
|
||||||
|
yield {
|
||||||
|
"text": chunk,
|
||||||
|
"metadata": {
|
||||||
|
"message_id": message_id,
|
||||||
|
"is_header": True,
|
||||||
|
"email_file_id": email_file_id,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Body chunks
|
||||||
|
for chunk in self._split_fixed_words(body, chunk_words, keep_last_body):
|
||||||
|
yield {
|
||||||
|
"text": chunk,
|
||||||
|
"metadata": {
|
||||||
|
"message_id": message_id,
|
||||||
|
"is_header": False,
|
||||||
|
"email_file_id": email_file_id,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# Build LEANN index and FAISS baseline
|
||||||
|
# ----------------------------
|
||||||
|
def build_leann_index(
|
||||||
|
self,
|
||||||
|
emails_csv: Optional[str],
|
||||||
|
backend: str = "hnsw",
|
||||||
|
embedding_model: str = "sentence-transformers/all-mpnet-base-v2",
|
||||||
|
chunk_words: int = 256,
|
||||||
|
max_emails: int | None = None,
|
||||||
|
) -> str:
|
||||||
|
emails_csv_path = self.ensure_emails_csv(emails_csv)
|
||||||
|
print(f"🏗️ Building LEANN index from {emails_csv_path}...")
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name=backend,
|
||||||
|
embedding_model=embedding_model,
|
||||||
|
embedding_mode="sentence-transformers",
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
is_recompute=True,
|
||||||
|
is_compact=True,
|
||||||
|
num_threads=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Stream passages and add to builder
|
||||||
|
preview_written = 0
|
||||||
|
with open(self.passages_preview, "w", encoding="utf-8") as preview_out:
|
||||||
|
for p in self._iter_passages_from_csv(
|
||||||
|
Path(emails_csv_path), chunk_words=chunk_words, max_emails=max_emails
|
||||||
|
):
|
||||||
|
builder.add_text(p["text"], metadata=p["metadata"])
|
||||||
|
if preview_written < 200:
|
||||||
|
preview_out.write(json.dumps({"text": p["text"][:200], **p["metadata"]}) + "\n")
|
||||||
|
preview_written += 1
|
||||||
|
|
||||||
|
print(f"🔨 Building index at {self.index_path}...")
|
||||||
|
builder.build_index(str(self.index_path))
|
||||||
|
print("✅ LEANN index built!")
|
||||||
|
return str(self.index_path)
|
||||||
|
|
||||||
|
def build_faiss_flat_baseline(self, index_path: str, output_dir: str = "baseline") -> str:
|
||||||
|
print("🔨 Building FAISS Flat baseline from LEANN passages...")
|
||||||
|
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from leann.api import compute_embeddings
|
||||||
|
from leann_backend_hnsw import faiss
|
||||||
|
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
baseline_path = os.path.join(output_dir, "faiss_flat.index")
|
||||||
|
metadata_path = os.path.join(output_dir, "metadata.pkl")
|
||||||
|
|
||||||
|
if os.path.exists(baseline_path) and os.path.exists(metadata_path):
|
||||||
|
print(f"✅ Baseline already exists at {baseline_path}")
|
||||||
|
return baseline_path
|
||||||
|
|
||||||
|
# Read meta for passage source and embedding model
|
||||||
|
meta_path = f"{index_path}.meta.json"
|
||||||
|
with open(meta_path, encoding="utf-8") as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
|
||||||
|
embedding_model = meta["embedding_model"]
|
||||||
|
passage_source = meta["passage_sources"][0]
|
||||||
|
passage_file = passage_source["path"]
|
||||||
|
|
||||||
|
if not os.path.isabs(passage_file):
|
||||||
|
index_dir = os.path.dirname(index_path)
|
||||||
|
passage_file = os.path.join(index_dir, os.path.basename(passage_file))
|
||||||
|
|
||||||
|
# Load passages from builder output so IDs match LEANN
|
||||||
|
passages: list[str] = []
|
||||||
|
passage_ids: list[str] = []
|
||||||
|
with open(passage_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if not line.strip():
|
||||||
|
continue
|
||||||
|
data = json.loads(line)
|
||||||
|
passages.append(data["text"])
|
||||||
|
passage_ids.append(data["id"]) # builder-assigned ID
|
||||||
|
|
||||||
|
print(f"📄 Loaded {len(passages)} passages for baseline")
|
||||||
|
print(f"🤖 Embedding model: {embedding_model}")
|
||||||
|
|
||||||
|
embeddings = compute_embeddings(
|
||||||
|
passages,
|
||||||
|
embedding_model,
|
||||||
|
mode="sentence-transformers",
|
||||||
|
use_server=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build FAISS IndexFlatIP
|
||||||
|
dim = embeddings.shape[1]
|
||||||
|
index = faiss.IndexFlatIP(dim)
|
||||||
|
emb_f32 = embeddings.astype(np.float32)
|
||||||
|
index.add(emb_f32.shape[0], faiss.swig_ptr(emb_f32))
|
||||||
|
|
||||||
|
faiss.write_index(index, baseline_path)
|
||||||
|
with open(metadata_path, "wb") as pf:
|
||||||
|
pickle.dump(passage_ids, pf)
|
||||||
|
|
||||||
|
print(f"✅ FAISS baseline saved: {baseline_path}")
|
||||||
|
print(f"✅ Metadata saved: {metadata_path}")
|
||||||
|
print(f"📊 Total vectors: {index.ntotal}")
|
||||||
|
return baseline_path
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# Queries (optional): prepare evaluation queries file
|
||||||
|
# ----------------------------
|
||||||
|
def prepare_queries(self, min_realism: float = 0.85) -> Path:
|
||||||
|
print(
|
||||||
|
"📝 Preparing evaluation queries from HuggingFace dataset corbt/enron_emails_sample_questions ..."
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
ds = load_dataset("corbt/enron_emails_sample_questions", split="train")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"⚠️ Failed to load dataset: {e}")
|
||||||
|
return self.queries_file
|
||||||
|
|
||||||
|
kept = 0
|
||||||
|
with open(self.queries_file, "w", encoding="utf-8") as out:
|
||||||
|
for i, item in enumerate(ds):
|
||||||
|
how_realistic = float(item.get("how_realistic", 0.0))
|
||||||
|
if how_realistic < min_realism:
|
||||||
|
continue
|
||||||
|
qid = str(item.get("id", f"enron_q_{i}"))
|
||||||
|
query = item.get("question", "")
|
||||||
|
if not query:
|
||||||
|
continue
|
||||||
|
record = {
|
||||||
|
"id": qid,
|
||||||
|
"query": query,
|
||||||
|
# For reference only, not used in recall metric below
|
||||||
|
"gt_message_ids": item.get("message_ids", []),
|
||||||
|
}
|
||||||
|
out.write(json.dumps(record) + "\n")
|
||||||
|
kept += 1
|
||||||
|
print(f"✅ Wrote {kept} queries to {self.queries_file}")
|
||||||
|
return self.queries_file
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Setup Enron Emails Benchmark")
|
||||||
|
parser.add_argument(
|
||||||
|
"--emails-csv",
|
||||||
|
help="Path to emails.csv (Enron dataset). If omitted, attempt Kaggle download.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--data-dir", default="data", help="Data directory")
|
||||||
|
parser.add_argument("--backend", choices=["hnsw", "diskann"], default="hnsw")
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding-model",
|
||||||
|
default="sentence-transformers/all-mpnet-base-v2",
|
||||||
|
help="Embedding model for LEANN",
|
||||||
|
)
|
||||||
|
parser.add_argument("--chunk-words", type=int, default=256, help="Fixed word chunk size")
|
||||||
|
parser.add_argument("--max-emails", type=int, help="Limit number of emails to process")
|
||||||
|
parser.add_argument("--skip-queries", action="store_true", help="Skip creating queries file")
|
||||||
|
parser.add_argument("--skip-build", action="store_true", help="Skip building LEANN index")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
setup = EnronSetup(args.data_dir)
|
||||||
|
|
||||||
|
# Build index
|
||||||
|
if not args.skip_build:
|
||||||
|
index_path = setup.build_leann_index(
|
||||||
|
emails_csv=args.emails_csv,
|
||||||
|
backend=args.backend,
|
||||||
|
embedding_model=args.embedding_model,
|
||||||
|
chunk_words=args.chunk_words,
|
||||||
|
max_emails=args.max_emails,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build FAISS baseline from the same passages & embeddings
|
||||||
|
setup.build_faiss_flat_baseline(index_path)
|
||||||
|
else:
|
||||||
|
print("⏭️ Skipping LEANN index build and baseline")
|
||||||
|
|
||||||
|
# Queries file (optional)
|
||||||
|
if not args.skip_queries:
|
||||||
|
setup.prepare_queries()
|
||||||
|
else:
|
||||||
|
print("⏭️ Skipping query preparation")
|
||||||
|
|
||||||
|
print("\n🎉 Enron Emails setup completed!")
|
||||||
|
print(f"📁 Data directory: {setup.data_dir.absolute()}")
|
||||||
|
print("Next steps:")
|
||||||
|
print(
|
||||||
|
"1) Evaluate recall: python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 2"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
115
benchmarks/financebench/README.md
Normal file
115
benchmarks/financebench/README.md
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
# FinanceBench Benchmark for LEANN-RAG
|
||||||
|
|
||||||
|
FinanceBench is a benchmark for evaluating retrieval-augmented generation (RAG) systems on financial document question-answering tasks.
|
||||||
|
|
||||||
|
## Dataset
|
||||||
|
|
||||||
|
- **Source**: [PatronusAI/financebench](https://huggingface.co/datasets/PatronusAI/financebench)
|
||||||
|
- **Questions**: 150 financial Q&A examples
|
||||||
|
- **Documents**: 368 PDF files (10-K, 10-Q, 8-K, earnings reports)
|
||||||
|
- **Companies**: Major public companies (3M, Apple, Microsoft, Amazon, etc.)
|
||||||
|
- **Paper**: [FinanceBench: A New Benchmark for Financial Question Answering](https://arxiv.org/abs/2311.11944)
|
||||||
|
|
||||||
|
## Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
benchmarks/financebench/
|
||||||
|
├── setup_financebench.py # Downloads PDFs and builds index
|
||||||
|
├── evaluate_financebench.py # Intelligent evaluation script
|
||||||
|
├── data/
|
||||||
|
│ ├── financebench_merged.jsonl # Q&A dataset
|
||||||
|
│ ├── pdfs/ # Downloaded financial documents
|
||||||
|
│ └── index/ # LEANN indexes
|
||||||
|
│ └── financebench_full_hnsw.leann
|
||||||
|
└── README.md
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
### 1. Setup (Download & Build Index)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd benchmarks/financebench
|
||||||
|
python setup_financebench.py
|
||||||
|
```
|
||||||
|
|
||||||
|
This will:
|
||||||
|
- Download the 150 Q&A examples
|
||||||
|
- Download all 368 PDF documents (parallel processing)
|
||||||
|
- Build a LEANN index from 53K+ text chunks
|
||||||
|
- Verify setup with test query
|
||||||
|
|
||||||
|
### 2. Evaluation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Basic retrieval evaluation
|
||||||
|
python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann
|
||||||
|
|
||||||
|
|
||||||
|
# RAG generation evaluation with Qwen3-8B
|
||||||
|
python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann --stage 4 --complexity 64 --llm-backend hf --model-name Qwen/Qwen3-8B --output results_qwen3.json
|
||||||
|
```
|
||||||
|
|
||||||
|
## Evaluation Methods
|
||||||
|
|
||||||
|
### Retrieval Evaluation
|
||||||
|
Uses intelligent matching with three strategies:
|
||||||
|
1. **Exact text overlap** - Direct substring matches
|
||||||
|
2. **Number matching** - Key financial figures ($1,577, 1.2B, etc.)
|
||||||
|
3. **Semantic similarity** - Word overlap with 20% threshold
|
||||||
|
|
||||||
|
### QA Evaluation
|
||||||
|
LLM-based answer evaluation using GPT-4o:
|
||||||
|
- Handles numerical rounding and equivalent representations
|
||||||
|
- Considers fractions, percentages, and decimal equivalents
|
||||||
|
- Evaluates semantic meaning rather than exact text match
|
||||||
|
|
||||||
|
## Benchmark Results
|
||||||
|
|
||||||
|
### LEANN-RAG Performance (sentence-transformers/all-mpnet-base-v2)
|
||||||
|
|
||||||
|
**Retrieval Metrics:**
|
||||||
|
- **Question Coverage**: 100.0% (all questions retrieve relevant docs)
|
||||||
|
- **Exact Match Rate**: 0.7% (substring overlap with evidence)
|
||||||
|
- **Number Match Rate**: 120.7% (key financial figures matched)*
|
||||||
|
- **Semantic Match Rate**: 4.7% (word overlap ≥20%)
|
||||||
|
- **Average Search Time**: 0.097s
|
||||||
|
|
||||||
|
**QA Metrics:**
|
||||||
|
- **Accuracy**: 42.7% (LLM-evaluated answer correctness)
|
||||||
|
- **Average QA Time**: 4.71s (end-to-end response time)
|
||||||
|
|
||||||
|
**System Performance:**
|
||||||
|
- **Index Size**: 53,985 chunks from 368 PDFs
|
||||||
|
- **Build Time**: ~5-10 minutes with sentence-transformers/all-mpnet-base-v2
|
||||||
|
|
||||||
|
*Note: Number match rate >100% indicates multiple retrieved documents contain the same financial figures, which is expected behavior for financial data appearing across multiple document sections.
|
||||||
|
|
||||||
|
### LEANN-RAG Generation Performance (Qwen3-8B)
|
||||||
|
|
||||||
|
- **Stage 4 (Index Comparison):**
|
||||||
|
- Compact Index: 5.0 MB
|
||||||
|
- Non-compact Index: 172.2 MB
|
||||||
|
- **Storage Saving**: 97.1%
|
||||||
|
- **Search Performance**:
|
||||||
|
- Non-compact (no recompute): 0.009s avg per query
|
||||||
|
- Compact (with recompute): 2.203s avg per query
|
||||||
|
- Speed ratio: 0.004x
|
||||||
|
|
||||||
|
**Generation Evaluation (20 queries, complexity=64):**
|
||||||
|
- **Average Search Time**: 1.638s per query
|
||||||
|
- **Average Generation Time**: 45.957s per query
|
||||||
|
- **LLM Backend**: HuggingFace transformers
|
||||||
|
- **Model**: Qwen/Qwen3-8B (thinking model with <think></think> processing)
|
||||||
|
- **Total Questions Processed**: 20
|
||||||
|
|
||||||
|
## Options
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Use different backends
|
||||||
|
python setup_financebench.py --backend diskann
|
||||||
|
python evaluate_financebench.py --index data/index/financebench_full_diskann.leann
|
||||||
|
|
||||||
|
# Use different embedding models
|
||||||
|
python setup_financebench.py --embedding-model facebook/contriever
|
||||||
|
```
|
||||||
923
benchmarks/financebench/evaluate_financebench.py
Executable file
923
benchmarks/financebench/evaluate_financebench.py
Executable file
@@ -0,0 +1,923 @@
|
|||||||
|
"""
|
||||||
|
FinanceBench Evaluation Script - Modular Recall-based Evaluation
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import openai
|
||||||
|
from leann import LeannChat, LeannSearcher
|
||||||
|
from leann_backend_hnsw import faiss
|
||||||
|
|
||||||
|
from ..llm_utils import evaluate_rag, generate_hf, generate_vllm, load_hf_model, load_vllm_model
|
||||||
|
|
||||||
|
# Setup logging to reduce verbose output
|
||||||
|
logging.basicConfig(level=logging.WARNING)
|
||||||
|
logging.getLogger("leann.api").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
|
||||||
|
class RecallEvaluator:
|
||||||
|
"""Stage 2: Evaluate Recall@3 (searcher vs baseline)"""
|
||||||
|
|
||||||
|
def __init__(self, index_path: str, baseline_dir: str):
|
||||||
|
self.index_path = index_path
|
||||||
|
self.baseline_dir = baseline_dir
|
||||||
|
self.searcher = LeannSearcher(index_path)
|
||||||
|
|
||||||
|
# Load FAISS flat baseline
|
||||||
|
baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
|
||||||
|
metadata_path = os.path.join(baseline_dir, "metadata.pkl")
|
||||||
|
|
||||||
|
self.faiss_index = faiss.read_index(baseline_index_path)
|
||||||
|
with open(metadata_path, "rb") as f:
|
||||||
|
self.passage_ids = pickle.load(f)
|
||||||
|
print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} vectors")
|
||||||
|
|
||||||
|
def evaluate_recall_at_3(
|
||||||
|
self, queries: list[str], complexity: int = 64, recompute_embeddings: bool = True
|
||||||
|
) -> float:
|
||||||
|
"""Evaluate recall@3 for given queries at specified complexity"""
|
||||||
|
recompute_str = "with recompute" if recompute_embeddings else "no recompute"
|
||||||
|
print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")
|
||||||
|
|
||||||
|
total_recall = 0.0
|
||||||
|
num_queries = len(queries)
|
||||||
|
|
||||||
|
for i, query in enumerate(queries):
|
||||||
|
# Get ground truth: search with FAISS flat
|
||||||
|
from leann.api import compute_embeddings
|
||||||
|
|
||||||
|
query_embedding = compute_embeddings(
|
||||||
|
[query],
|
||||||
|
self.searcher.embedding_model,
|
||||||
|
mode=self.searcher.embedding_mode,
|
||||||
|
use_server=False,
|
||||||
|
).astype(np.float32)
|
||||||
|
|
||||||
|
# Search FAISS flat for ground truth using LEANN's modified faiss API
|
||||||
|
n = query_embedding.shape[0] # Number of queries
|
||||||
|
k = 3 # Number of nearest neighbors
|
||||||
|
distances = np.zeros((n, k), dtype=np.float32)
|
||||||
|
labels = np.zeros((n, k), dtype=np.int64)
|
||||||
|
|
||||||
|
self.faiss_index.search(
|
||||||
|
n,
|
||||||
|
faiss.swig_ptr(query_embedding),
|
||||||
|
k,
|
||||||
|
faiss.swig_ptr(distances),
|
||||||
|
faiss.swig_ptr(labels),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract the results
|
||||||
|
baseline_ids = {self.passage_ids[idx] for idx in labels[0]}
|
||||||
|
|
||||||
|
# Search with LEANN at specified complexity
|
||||||
|
test_results = self.searcher.search(
|
||||||
|
query,
|
||||||
|
top_k=3,
|
||||||
|
complexity=complexity,
|
||||||
|
recompute_embeddings=recompute_embeddings,
|
||||||
|
)
|
||||||
|
test_ids = {result.id for result in test_results}
|
||||||
|
|
||||||
|
# Calculate recall@3 = |intersection| / |ground_truth|
|
||||||
|
intersection = test_ids.intersection(baseline_ids)
|
||||||
|
recall = len(intersection) / 3.0 # Ground truth size is 3
|
||||||
|
total_recall += recall
|
||||||
|
|
||||||
|
if i < 3: # Show first few examples
|
||||||
|
print(f" Query {i + 1}: '{query[:50]}...' -> Recall@3: {recall:.3f}")
|
||||||
|
print(f" FAISS ground truth: {list(baseline_ids)}")
|
||||||
|
print(f" LEANN results (C={complexity}, {recompute_str}): {list(test_ids)}")
|
||||||
|
print(f" Intersection: {list(intersection)}")
|
||||||
|
|
||||||
|
avg_recall = total_recall / num_queries
|
||||||
|
print(f"📊 Average Recall@3: {avg_recall:.3f} ({avg_recall * 100:.1f}%)")
|
||||||
|
return avg_recall
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""Cleanup resources"""
|
||||||
|
if hasattr(self, "searcher"):
|
||||||
|
self.searcher.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
class FinanceBenchEvaluator:
|
||||||
|
def __init__(self, index_path: str, openai_api_key: Optional[str] = None):
|
||||||
|
self.index_path = index_path
|
||||||
|
self.openai_client = openai.OpenAI(api_key=openai_api_key) if openai_api_key else None
|
||||||
|
|
||||||
|
self.searcher = LeannSearcher(index_path)
|
||||||
|
self.chat = LeannChat(index_path) if openai_api_key else None
|
||||||
|
|
||||||
|
def load_dataset(self, dataset_path: str = "data/financebench_merged.jsonl"):
|
||||||
|
"""Load FinanceBench dataset"""
|
||||||
|
data = []
|
||||||
|
with open(dataset_path, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
data.append(json.loads(line))
|
||||||
|
|
||||||
|
print(f"📊 Loaded {len(data)} FinanceBench examples")
|
||||||
|
return data
|
||||||
|
|
||||||
|
def analyze_index_sizes(self) -> dict:
|
||||||
|
"""Analyze index sizes with and without embeddings"""
|
||||||
|
|
||||||
|
print("📏 Analyzing index sizes...")
|
||||||
|
|
||||||
|
# Get all index-related files
|
||||||
|
index_path = Path(self.index_path)
|
||||||
|
index_dir = index_path.parent
|
||||||
|
index_name = index_path.stem # Remove .leann extension
|
||||||
|
|
||||||
|
sizes = {}
|
||||||
|
total_with_embeddings = 0
|
||||||
|
|
||||||
|
# Core index files
|
||||||
|
index_file = index_dir / f"{index_name}.index"
|
||||||
|
meta_file = index_dir / f"{index_path.name}.meta.json" # Keep .leann for meta file
|
||||||
|
passages_file = index_dir / f"{index_path.name}.passages.jsonl" # Keep .leann for passages
|
||||||
|
passages_idx_file = index_dir / f"{index_path.name}.passages.idx" # Keep .leann for idx
|
||||||
|
|
||||||
|
for file_path, name in [
|
||||||
|
(index_file, "index"),
|
||||||
|
(meta_file, "metadata"),
|
||||||
|
(passages_file, "passages_text"),
|
||||||
|
(passages_idx_file, "passages_index"),
|
||||||
|
]:
|
||||||
|
if file_path.exists():
|
||||||
|
size_mb = file_path.stat().st_size / (1024 * 1024)
|
||||||
|
sizes[name] = size_mb
|
||||||
|
total_with_embeddings += size_mb
|
||||||
|
|
||||||
|
else:
|
||||||
|
sizes[name] = 0
|
||||||
|
|
||||||
|
sizes["total_with_embeddings"] = total_with_embeddings
|
||||||
|
sizes["index_only_mb"] = sizes["index"] # Just the .index file for fair comparison
|
||||||
|
|
||||||
|
print(f" 📁 Total index size: {total_with_embeddings:.1f} MB")
|
||||||
|
print(f" 📁 Index file only: {sizes['index']:.1f} MB")
|
||||||
|
|
||||||
|
return sizes
|
||||||
|
|
||||||
|
def create_compact_index_for_comparison(self, compact_index_path: str) -> dict:
|
||||||
|
"""Create a compact index for comparison purposes"""
|
||||||
|
print("🏗️ Building compact index from existing passages...")
|
||||||
|
|
||||||
|
# Load existing passages from current index
|
||||||
|
|
||||||
|
from leann import LeannBuilder
|
||||||
|
|
||||||
|
current_index_path = Path(self.index_path)
|
||||||
|
current_index_dir = current_index_path.parent
|
||||||
|
current_index_name = current_index_path.name
|
||||||
|
|
||||||
|
# Read metadata to get passage source
|
||||||
|
meta_path = current_index_dir / f"{current_index_name}.meta.json"
|
||||||
|
with open(meta_path) as f:
|
||||||
|
import json
|
||||||
|
|
||||||
|
meta = json.load(f)
|
||||||
|
|
||||||
|
passage_source = meta["passage_sources"][0]
|
||||||
|
passage_file = passage_source["path"]
|
||||||
|
|
||||||
|
# Convert relative path to absolute
|
||||||
|
if not Path(passage_file).is_absolute():
|
||||||
|
passage_file = current_index_dir / Path(passage_file).name
|
||||||
|
|
||||||
|
print(f"📄 Loading passages from {passage_file}...")
|
||||||
|
|
||||||
|
# Build compact index with same passages
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model=meta["embedding_model"],
|
||||||
|
embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
|
||||||
|
is_recompute=True, # Enable recompute (no stored embeddings)
|
||||||
|
is_compact=True, # Enable compact storage
|
||||||
|
**meta.get("backend_kwargs", {}),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load all passages
|
||||||
|
with open(passage_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
data = json.loads(line)
|
||||||
|
builder.add_text(data["text"], metadata=data.get("metadata", {}))
|
||||||
|
|
||||||
|
print(f"🔨 Building compact index at {compact_index_path}...")
|
||||||
|
builder.build_index(compact_index_path)
|
||||||
|
|
||||||
|
# Analyze the compact index size
|
||||||
|
temp_evaluator = FinanceBenchEvaluator(compact_index_path)
|
||||||
|
compact_sizes = temp_evaluator.analyze_index_sizes()
|
||||||
|
compact_sizes["index_type"] = "compact"
|
||||||
|
|
||||||
|
return compact_sizes
|
||||||
|
|
||||||
|
def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
|
||||||
|
"""Create a non-compact index for comparison purposes"""
|
||||||
|
print("🏗️ Building non-compact index from existing passages...")
|
||||||
|
|
||||||
|
# Load existing passages from current index
|
||||||
|
|
||||||
|
from leann import LeannBuilder
|
||||||
|
|
||||||
|
current_index_path = Path(self.index_path)
|
||||||
|
current_index_dir = current_index_path.parent
|
||||||
|
current_index_name = current_index_path.name
|
||||||
|
|
||||||
|
# Read metadata to get passage source
|
||||||
|
meta_path = current_index_dir / f"{current_index_name}.meta.json"
|
||||||
|
with open(meta_path) as f:
|
||||||
|
import json
|
||||||
|
|
||||||
|
meta = json.load(f)
|
||||||
|
|
||||||
|
passage_source = meta["passage_sources"][0]
|
||||||
|
passage_file = passage_source["path"]
|
||||||
|
|
||||||
|
# Convert relative path to absolute
|
||||||
|
if not Path(passage_file).is_absolute():
|
||||||
|
passage_file = current_index_dir / Path(passage_file).name
|
||||||
|
|
||||||
|
print(f"📄 Loading passages from {passage_file}...")
|
||||||
|
|
||||||
|
# Build non-compact index with same passages
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model=meta["embedding_model"],
|
||||||
|
embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
|
||||||
|
is_recompute=False, # Disable recompute (store embeddings)
|
||||||
|
is_compact=False, # Disable compact storage
|
||||||
|
**{
|
||||||
|
k: v
|
||||||
|
for k, v in meta.get("backend_kwargs", {}).items()
|
||||||
|
if k not in ["is_recompute", "is_compact"]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load all passages
|
||||||
|
with open(passage_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
data = json.loads(line)
|
||||||
|
builder.add_text(data["text"], metadata=data.get("metadata", {}))
|
||||||
|
|
||||||
|
print(f"🔨 Building non-compact index at {non_compact_index_path}...")
|
||||||
|
builder.build_index(non_compact_index_path)
|
||||||
|
|
||||||
|
# Analyze the non-compact index size
|
||||||
|
temp_evaluator = FinanceBenchEvaluator(non_compact_index_path)
|
||||||
|
non_compact_sizes = temp_evaluator.analyze_index_sizes()
|
||||||
|
non_compact_sizes["index_type"] = "non_compact"
|
||||||
|
|
||||||
|
return non_compact_sizes
|
||||||
|
|
||||||
|
def compare_index_performance(
|
||||||
|
self, non_compact_path: str, compact_path: str, test_data: list, complexity: int
|
||||||
|
) -> dict:
|
||||||
|
"""Compare performance between non-compact and compact indexes"""
|
||||||
|
print("⚡ Comparing search performance between indexes...")
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
from leann import LeannSearcher
|
||||||
|
|
||||||
|
# Test queries
|
||||||
|
test_queries = [item["question"] for item in test_data[:5]]
|
||||||
|
|
||||||
|
results = {
|
||||||
|
"non_compact": {"search_times": []},
|
||||||
|
"compact": {"search_times": []},
|
||||||
|
"avg_search_times": {},
|
||||||
|
"speed_ratio": 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Test non-compact index (no recompute)
|
||||||
|
print(" 🔍 Testing non-compact index (no recompute)...")
|
||||||
|
non_compact_searcher = LeannSearcher(non_compact_path)
|
||||||
|
|
||||||
|
for query in test_queries:
|
||||||
|
start_time = time.time()
|
||||||
|
_ = non_compact_searcher.search(
|
||||||
|
query, top_k=3, complexity=complexity, recompute_embeddings=False
|
||||||
|
)
|
||||||
|
search_time = time.time() - start_time
|
||||||
|
results["non_compact"]["search_times"].append(search_time)
|
||||||
|
|
||||||
|
# Test compact index (with recompute)
|
||||||
|
print(" 🔍 Testing compact index (with recompute)...")
|
||||||
|
compact_searcher = LeannSearcher(compact_path)
|
||||||
|
|
||||||
|
for query in test_queries:
|
||||||
|
start_time = time.time()
|
||||||
|
_ = compact_searcher.search(
|
||||||
|
query, top_k=3, complexity=complexity, recompute_embeddings=True
|
||||||
|
)
|
||||||
|
search_time = time.time() - start_time
|
||||||
|
results["compact"]["search_times"].append(search_time)
|
||||||
|
|
||||||
|
# Calculate averages
|
||||||
|
results["avg_search_times"]["non_compact"] = sum(
|
||||||
|
results["non_compact"]["search_times"]
|
||||||
|
) / len(results["non_compact"]["search_times"])
|
||||||
|
results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
|
||||||
|
results["compact"]["search_times"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Performance ratio
|
||||||
|
if results["avg_search_times"]["compact"] > 0:
|
||||||
|
results["speed_ratio"] = (
|
||||||
|
results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
results["speed_ratio"] = float("inf")
|
||||||
|
|
||||||
|
print(
|
||||||
|
f" Non-compact (no recompute): {results['avg_search_times']['non_compact']:.3f}s avg"
|
||||||
|
)
|
||||||
|
print(f" Compact (with recompute): {results['avg_search_times']['compact']:.3f}s avg")
|
||||||
|
print(f" Speed ratio: {results['speed_ratio']:.2f}x")
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
non_compact_searcher.cleanup()
|
||||||
|
compact_searcher.cleanup()
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def evaluate_timing_breakdown(
|
||||||
|
self, data: list[dict], max_samples: Optional[int] = None
|
||||||
|
) -> dict:
|
||||||
|
"""Evaluate timing breakdown and accuracy by hacking LeannChat.ask() for separated timing"""
|
||||||
|
if not self.chat or not self.openai_client:
|
||||||
|
print("⚠️ Skipping timing evaluation (no OpenAI API key provided)")
|
||||||
|
return {
|
||||||
|
"total_questions": 0,
|
||||||
|
"avg_search_time": 0.0,
|
||||||
|
"avg_generation_time": 0.0,
|
||||||
|
"avg_total_time": 0.0,
|
||||||
|
"accuracy": 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
print("🔍🤖 Evaluating timing breakdown and accuracy (search + generation)...")
|
||||||
|
|
||||||
|
if max_samples:
|
||||||
|
data = data[:max_samples]
|
||||||
|
print(f"📝 Using first {max_samples} samples for timing evaluation")
|
||||||
|
|
||||||
|
search_times = []
|
||||||
|
generation_times = []
|
||||||
|
total_times = []
|
||||||
|
correct_answers = 0
|
||||||
|
|
||||||
|
for i, item in enumerate(data):
|
||||||
|
question = item["question"]
|
||||||
|
ground_truth = item["answer"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Hack: Monkey-patch the ask method to capture internal timing
|
||||||
|
original_ask = self.chat.ask
|
||||||
|
captured_search_time = None
|
||||||
|
captured_generation_time = None
|
||||||
|
|
||||||
|
def patched_ask(*args, **kwargs):
|
||||||
|
nonlocal captured_search_time, captured_generation_time
|
||||||
|
|
||||||
|
# Time the search part
|
||||||
|
search_start = time.time()
|
||||||
|
results = self.chat.searcher.search(args[0], top_k=3, complexity=64)
|
||||||
|
captured_search_time = time.time() - search_start
|
||||||
|
|
||||||
|
# Time the generation part
|
||||||
|
context = "\n\n".join([r.text for r in results])
|
||||||
|
prompt = (
|
||||||
|
"Here is some retrieved context that might help answer your question:\n\n"
|
||||||
|
f"{context}\n\n"
|
||||||
|
f"Question: {args[0]}\n\n"
|
||||||
|
"Please provide the best answer you can based on this context and your knowledge."
|
||||||
|
)
|
||||||
|
|
||||||
|
generation_start = time.time()
|
||||||
|
answer = self.chat.llm.ask(prompt)
|
||||||
|
captured_generation_time = time.time() - generation_start
|
||||||
|
|
||||||
|
return answer
|
||||||
|
|
||||||
|
# Apply the patch
|
||||||
|
self.chat.ask = patched_ask
|
||||||
|
|
||||||
|
# Time the total QA
|
||||||
|
total_start = time.time()
|
||||||
|
generated_answer = self.chat.ask(question)
|
||||||
|
total_time = time.time() - total_start
|
||||||
|
|
||||||
|
# Restore original method
|
||||||
|
self.chat.ask = original_ask
|
||||||
|
|
||||||
|
# Store the timings
|
||||||
|
search_times.append(captured_search_time)
|
||||||
|
generation_times.append(captured_generation_time)
|
||||||
|
total_times.append(total_time)
|
||||||
|
|
||||||
|
# Check accuracy using LLM as judge
|
||||||
|
is_correct = self._check_answer_accuracy(generated_answer, ground_truth, question)
|
||||||
|
if is_correct:
|
||||||
|
correct_answers += 1
|
||||||
|
|
||||||
|
status = "✅" if is_correct else "❌"
|
||||||
|
print(
|
||||||
|
f"Question {i + 1}/{len(data)}: {status} Search={captured_search_time:.3f}s, Gen={captured_generation_time:.3f}s, Total={total_time:.3f}s"
|
||||||
|
)
|
||||||
|
print(f" GT: {ground_truth}")
|
||||||
|
print(f" Gen: {generated_answer[:100]}...")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ❌ Error: {e}")
|
||||||
|
search_times.append(0.0)
|
||||||
|
generation_times.append(0.0)
|
||||||
|
total_times.append(0.0)
|
||||||
|
|
||||||
|
accuracy = correct_answers / len(data) if data else 0.0
|
||||||
|
|
||||||
|
metrics = {
|
||||||
|
"total_questions": len(data),
|
||||||
|
"avg_search_time": sum(search_times) / len(search_times) if search_times else 0.0,
|
||||||
|
"avg_generation_time": sum(generation_times) / len(generation_times)
|
||||||
|
if generation_times
|
||||||
|
else 0.0,
|
||||||
|
"avg_total_time": sum(total_times) / len(total_times) if total_times else 0.0,
|
||||||
|
"accuracy": accuracy,
|
||||||
|
"correct_answers": correct_answers,
|
||||||
|
"search_times": search_times,
|
||||||
|
"generation_times": generation_times,
|
||||||
|
"total_times": total_times,
|
||||||
|
}
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
def _check_answer_accuracy(
|
||||||
|
self, generated_answer: str, ground_truth: str, question: str
|
||||||
|
) -> bool:
|
||||||
|
"""Check if generated answer matches ground truth using LLM as judge"""
|
||||||
|
judge_prompt = f"""You are an expert judge evaluating financial question answering.
|
||||||
|
|
||||||
|
Question: {question}
|
||||||
|
|
||||||
|
Ground Truth Answer: {ground_truth}
|
||||||
|
|
||||||
|
Generated Answer: {generated_answer}
|
||||||
|
|
||||||
|
Task: Determine if the generated answer is factually correct compared to the ground truth. Focus on:
|
||||||
|
1. Numerical accuracy (exact values, units, currency)
|
||||||
|
2. Key financial concepts and terminology
|
||||||
|
3. Overall factual correctness
|
||||||
|
|
||||||
|
For financial data, small formatting differences are OK (e.g., "$1,577" vs "1577 million" vs "$1.577 billion"), but the core numerical value must match.
|
||||||
|
|
||||||
|
Respond with exactly one word: "CORRECT" if the generated answer is factually accurate, or "INCORRECT" if it's wrong or significantly different."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
judge_response = self.openai_client.chat.completions.create(
|
||||||
|
model="gpt-4o-mini",
|
||||||
|
messages=[{"role": "user", "content": judge_prompt}],
|
||||||
|
max_tokens=10,
|
||||||
|
temperature=0,
|
||||||
|
)
|
||||||
|
judgment = judge_response.choices[0].message.content.strip().upper()
|
||||||
|
return judgment == "CORRECT"
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ⚠️ Judge error: {e}, falling back to string matching")
|
||||||
|
# Fallback to simple string matching
|
||||||
|
gen_clean = generated_answer.strip().lower().replace("$", "").replace(",", "")
|
||||||
|
gt_clean = ground_truth.strip().lower().replace("$", "").replace(",", "")
|
||||||
|
return gt_clean in gen_clean
|
||||||
|
|
||||||
|
def _print_results(self, timing_metrics: dict):
|
||||||
|
"""Print evaluation results"""
|
||||||
|
print("\n🎯 EVALUATION RESULTS")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
# Index comparison analysis
|
||||||
|
if "current_index" in timing_metrics and "non_compact_index" in timing_metrics:
|
||||||
|
print("\n📏 Index Comparison Analysis:")
|
||||||
|
current = timing_metrics["current_index"]
|
||||||
|
non_compact = timing_metrics["non_compact_index"]
|
||||||
|
|
||||||
|
print(f" Compact index (current): {current.get('total_with_embeddings', 0):.1f} MB")
|
||||||
|
print(
|
||||||
|
f" Non-compact index (with embeddings): {non_compact.get('total_with_embeddings', 0):.1f} MB"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Storage saving by compact: {timing_metrics.get('storage_saving_percent', 0):.1f}%"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(" Component breakdown (non-compact):")
|
||||||
|
print(f" - Main index: {non_compact.get('index', 0):.1f} MB")
|
||||||
|
print(f" - Passages text: {non_compact.get('passages_text', 0):.1f} MB")
|
||||||
|
print(f" - Passages index: {non_compact.get('passages_index', 0):.1f} MB")
|
||||||
|
print(f" - Metadata: {non_compact.get('metadata', 0):.1f} MB")
|
||||||
|
|
||||||
|
# Performance comparison
|
||||||
|
if "performance_comparison" in timing_metrics:
|
||||||
|
perf = timing_metrics["performance_comparison"]
|
||||||
|
print("\n⚡ Performance Comparison:")
|
||||||
|
print(
|
||||||
|
f" Non-compact (no recompute): {perf.get('avg_search_times', {}).get('non_compact', 0):.3f}s avg"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Compact (with recompute): {perf.get('avg_search_times', {}).get('compact', 0):.3f}s avg"
|
||||||
|
)
|
||||||
|
print(f" Speed ratio: {perf.get('speed_ratio', 0):.2f}x")
|
||||||
|
|
||||||
|
# Legacy single index analysis (fallback)
|
||||||
|
if "total_with_embeddings" in timing_metrics and "current_index" not in timing_metrics:
|
||||||
|
print("\n📏 Index Size Analysis:")
|
||||||
|
print(f" Total index size: {timing_metrics.get('total_with_embeddings', 0):.1f} MB")
|
||||||
|
|
||||||
|
print("\n📊 Accuracy:")
|
||||||
|
print(f" Accuracy: {timing_metrics.get('accuracy', 0) * 100:.1f}%")
|
||||||
|
print(
|
||||||
|
f" Correct Answers: {timing_metrics.get('correct_answers', 0)}/{timing_metrics.get('total_questions', 0)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n📊 Timing Breakdown:")
|
||||||
|
print(f" Total Questions: {timing_metrics.get('total_questions', 0)}")
|
||||||
|
print(f" Avg Search Time: {timing_metrics.get('avg_search_time', 0):.3f}s")
|
||||||
|
print(f" Avg Generation Time: {timing_metrics.get('avg_generation_time', 0):.3f}s")
|
||||||
|
print(f" Avg Total Time: {timing_metrics.get('avg_total_time', 0):.3f}s")
|
||||||
|
|
||||||
|
if timing_metrics.get("avg_total_time", 0) > 0:
|
||||||
|
search_pct = (
|
||||||
|
timing_metrics.get("avg_search_time", 0)
|
||||||
|
/ timing_metrics.get("avg_total_time", 1)
|
||||||
|
* 100
|
||||||
|
)
|
||||||
|
gen_pct = (
|
||||||
|
timing_metrics.get("avg_generation_time", 0)
|
||||||
|
/ timing_metrics.get("avg_total_time", 1)
|
||||||
|
* 100
|
||||||
|
)
|
||||||
|
print("\n📈 Time Distribution:")
|
||||||
|
print(f" Search: {search_pct:.1f}%")
|
||||||
|
print(f" Generation: {gen_pct:.1f}%")
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""Cleanup resources"""
|
||||||
|
if self.searcher:
|
||||||
|
self.searcher.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Modular FinanceBench Evaluation")
|
||||||
|
parser.add_argument("--index", required=True, help="Path to LEANN index")
|
||||||
|
parser.add_argument("--dataset", default="data/financebench_merged.jsonl", help="Dataset path")
|
||||||
|
parser.add_argument(
|
||||||
|
"--stage",
|
||||||
|
choices=["2", "3", "4", "all"],
|
||||||
|
default="all",
|
||||||
|
help="Which stage to run (2=recall, 3=complexity, 4=generation)",
|
||||||
|
)
|
||||||
|
parser.add_argument("--complexity", type=int, default=None, help="Complexity for search")
|
||||||
|
parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
|
||||||
|
parser.add_argument("--openai-api-key", help="OpenAI API key for generation evaluation")
|
||||||
|
parser.add_argument("--output", help="Save results to JSON file")
|
||||||
|
parser.add_argument(
|
||||||
|
"--llm-backend", choices=["openai", "hf", "vllm"], default="openai", help="LLM backend"
|
||||||
|
)
|
||||||
|
parser.add_argument("--model-name", default="Qwen3-8B", help="Model name for HF/vLLM")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Check if baseline exists
|
||||||
|
baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
|
||||||
|
if not os.path.exists(baseline_index_path):
|
||||||
|
print(f"❌ FAISS baseline not found at {baseline_index_path}")
|
||||||
|
print("💡 Please run setup_financebench.py first to build the baseline")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
if args.stage == "2" or args.stage == "all":
|
||||||
|
# Stage 2: Recall@3 evaluation
|
||||||
|
print("🚀 Starting Stage 2: Recall@3 evaluation")
|
||||||
|
|
||||||
|
evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||||
|
|
||||||
|
# Load FinanceBench queries for testing
|
||||||
|
print("📖 Loading FinanceBench dataset...")
|
||||||
|
queries = []
|
||||||
|
with open(args.dataset, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
data = json.loads(line)
|
||||||
|
queries.append(data["question"])
|
||||||
|
|
||||||
|
# Test with more queries for robust measurement
|
||||||
|
test_queries = queries[:2000]
|
||||||
|
print(f"🧪 Testing with {len(test_queries)} queries")
|
||||||
|
|
||||||
|
# Test with complexity 64
|
||||||
|
complexity = 64
|
||||||
|
recall = evaluator.evaluate_recall_at_3(test_queries, complexity)
|
||||||
|
print(f"📈 Recall@3 at complexity {complexity}: {recall * 100:.1f}%")
|
||||||
|
|
||||||
|
evaluator.cleanup()
|
||||||
|
print("✅ Stage 2 completed!\n")
|
||||||
|
|
||||||
|
# Shared non-compact index path for Stage 3 and 4
|
||||||
|
non_compact_index_path = args.index.replace(".leann", "_noncompact.leann")
|
||||||
|
complexity = args.complexity
|
||||||
|
|
||||||
|
if args.stage == "3" or args.stage == "all":
|
||||||
|
# Stage 3: Binary search for 90% recall complexity (using non-compact index for speed)
|
||||||
|
print("🚀 Starting Stage 3: Binary search for 90% recall complexity")
|
||||||
|
print(
|
||||||
|
"💡 Creating non-compact index for fast binary search with recompute_embeddings=False"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create non-compact index for binary search (will be reused in Stage 4)
|
||||||
|
print("🏗️ Creating non-compact index for binary search...")
|
||||||
|
evaluator = FinanceBenchEvaluator(args.index)
|
||||||
|
evaluator.create_non_compact_index_for_comparison(non_compact_index_path)
|
||||||
|
|
||||||
|
# Use non-compact index for binary search
|
||||||
|
binary_search_evaluator = RecallEvaluator(non_compact_index_path, args.baseline_dir)
|
||||||
|
|
||||||
|
# Load queries for testing
|
||||||
|
print("📖 Loading FinanceBench dataset...")
|
||||||
|
queries = []
|
||||||
|
with open(args.dataset, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
data = json.loads(line)
|
||||||
|
queries.append(data["question"])
|
||||||
|
|
||||||
|
# Use more queries for robust measurement
|
||||||
|
test_queries = queries[:200]
|
||||||
|
print(f"🧪 Testing with {len(test_queries)} queries")
|
||||||
|
|
||||||
|
# Binary search for 90% recall complexity (without recompute for speed)
|
||||||
|
target_recall = 0.9
|
||||||
|
min_complexity, max_complexity = 1, 32
|
||||||
|
|
||||||
|
print(f"🔍 Binary search for {target_recall * 100}% recall complexity...")
|
||||||
|
print(f"Search range: {min_complexity} to {max_complexity}")
|
||||||
|
|
||||||
|
best_complexity = None
|
||||||
|
best_recall = 0.0
|
||||||
|
|
||||||
|
while min_complexity <= max_complexity:
|
||||||
|
mid_complexity = (min_complexity + max_complexity) // 2
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"\n🧪 Testing complexity {mid_complexity} (no recompute, non-compact index)..."
|
||||||
|
)
|
||||||
|
# Use recompute_embeddings=False on non-compact index for fast binary search
|
||||||
|
recall = binary_search_evaluator.evaluate_recall_at_3(
|
||||||
|
test_queries, mid_complexity, recompute_embeddings=False
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f" Complexity {mid_complexity}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%)"
|
||||||
|
)
|
||||||
|
|
||||||
|
if recall >= target_recall:
|
||||||
|
best_complexity = mid_complexity
|
||||||
|
best_recall = recall
|
||||||
|
max_complexity = mid_complexity - 1
|
||||||
|
print(" ✅ Target reached! Searching for lower complexity...")
|
||||||
|
else:
|
||||||
|
min_complexity = mid_complexity + 1
|
||||||
|
print(" ❌ Below target. Searching for higher complexity...")
|
||||||
|
|
||||||
|
if best_complexity is not None:
|
||||||
|
print("\n🎯 Optimal complexity found!")
|
||||||
|
print(f" Complexity: {best_complexity}")
|
||||||
|
print(f" Recall@3: {best_recall:.3f} ({best_recall * 100:.1f}%)")
|
||||||
|
|
||||||
|
# Test a few complexities around the optimal one for verification
|
||||||
|
print("\n🔬 Verification test around optimal complexity:")
|
||||||
|
verification_complexities = [
|
||||||
|
max(1, best_complexity - 2),
|
||||||
|
max(1, best_complexity - 1),
|
||||||
|
best_complexity,
|
||||||
|
best_complexity + 1,
|
||||||
|
best_complexity + 2,
|
||||||
|
]
|
||||||
|
|
||||||
|
for complexity in verification_complexities:
|
||||||
|
if complexity <= 512: # reasonable upper bound
|
||||||
|
recall = binary_search_evaluator.evaluate_recall_at_3(
|
||||||
|
test_queries, complexity, recompute_embeddings=False
|
||||||
|
)
|
||||||
|
status = "✅" if recall >= target_recall else "❌"
|
||||||
|
print(f" {status} Complexity {complexity:3d}: {recall * 100:5.1f}%")
|
||||||
|
|
||||||
|
# Now test the optimal complexity with compact index and recompute for comparison
|
||||||
|
print(
|
||||||
|
f"\n🔄 Testing optimal complexity {best_complexity} on compact index WITH recompute..."
|
||||||
|
)
|
||||||
|
compact_evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||||
|
recall_with_recompute = compact_evaluator.evaluate_recall_at_3(
|
||||||
|
test_queries[:10], best_complexity, recompute_embeddings=True
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" ✅ Complexity {best_complexity} (compact index with recompute): {recall_with_recompute * 100:.1f}%"
|
||||||
|
)
|
||||||
|
complexity = best_complexity
|
||||||
|
print(
|
||||||
|
f" 📊 Recall difference: {abs(best_recall - recall_with_recompute) * 100:.2f}%"
|
||||||
|
)
|
||||||
|
compact_evaluator.cleanup()
|
||||||
|
else:
|
||||||
|
print(f"\n❌ Could not find complexity achieving {target_recall * 100}% recall")
|
||||||
|
print("All tested complexities were below target.")
|
||||||
|
|
||||||
|
# Cleanup evaluators (keep non-compact index for Stage 4)
|
||||||
|
binary_search_evaluator.cleanup()
|
||||||
|
evaluator.cleanup()
|
||||||
|
|
||||||
|
print("✅ Stage 3 completed! Non-compact index saved for Stage 4.\n")
|
||||||
|
|
||||||
|
if args.stage == "4" or args.stage == "all":
|
||||||
|
# Stage 4: Comprehensive evaluation with dual index comparison
|
||||||
|
print("🚀 Starting Stage 4: Comprehensive evaluation with dual index comparison")
|
||||||
|
|
||||||
|
# Use FinanceBench evaluator for QA evaluation
|
||||||
|
evaluator = FinanceBenchEvaluator(
|
||||||
|
args.index, args.openai_api_key if args.llm_backend == "openai" else None
|
||||||
|
)
|
||||||
|
|
||||||
|
print("📖 Loading FinanceBench dataset...")
|
||||||
|
data = evaluator.load_dataset(args.dataset)
|
||||||
|
|
||||||
|
# Step 1: Analyze current (compact) index
|
||||||
|
print("\n📏 Analyzing current index (compact, pruned)...")
|
||||||
|
compact_size_metrics = evaluator.analyze_index_sizes()
|
||||||
|
compact_size_metrics["index_type"] = "compact"
|
||||||
|
|
||||||
|
# Step 2: Use existing non-compact index or create if needed
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
if Path(non_compact_index_path).exists():
|
||||||
|
print(
|
||||||
|
f"\n📁 Using existing non-compact index from Stage 3: {non_compact_index_path}"
|
||||||
|
)
|
||||||
|
temp_evaluator = FinanceBenchEvaluator(non_compact_index_path)
|
||||||
|
non_compact_size_metrics = temp_evaluator.analyze_index_sizes()
|
||||||
|
non_compact_size_metrics["index_type"] = "non_compact"
|
||||||
|
else:
|
||||||
|
print("\n🏗️ Creating non-compact index (with embeddings) for comparison...")
|
||||||
|
non_compact_size_metrics = evaluator.create_non_compact_index_for_comparison(
|
||||||
|
non_compact_index_path
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 3: Compare index sizes
|
||||||
|
print("\n📊 Index size comparison:")
|
||||||
|
print(
|
||||||
|
f" Compact index (current): {compact_size_metrics['total_with_embeddings']:.1f} MB"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Non-compact index: {non_compact_size_metrics['total_with_embeddings']:.1f} MB"
|
||||||
|
)
|
||||||
|
print("\n📊 Index-only size comparison (.index file only):")
|
||||||
|
print(f" Compact index: {compact_size_metrics['index_only_mb']:.1f} MB")
|
||||||
|
print(f" Non-compact index: {non_compact_size_metrics['index_only_mb']:.1f} MB")
|
||||||
|
# Use index-only size for fair comparison (same as Enron emails)
|
||||||
|
storage_saving = (
|
||||||
|
(non_compact_size_metrics["index_only_mb"] - compact_size_metrics["index_only_mb"])
|
||||||
|
/ non_compact_size_metrics["index_only_mb"]
|
||||||
|
* 100
|
||||||
|
)
|
||||||
|
print(f" Storage saving by compact: {storage_saving:.1f}%")
|
||||||
|
|
||||||
|
# Step 4: Performance comparison between the two indexes
|
||||||
|
if complexity is None:
|
||||||
|
raise ValueError("Complexity is required for performance comparison")
|
||||||
|
|
||||||
|
print("\n⚡ Performance comparison between indexes...")
|
||||||
|
performance_metrics = evaluator.compare_index_performance(
|
||||||
|
non_compact_index_path, args.index, data[:10], complexity=complexity
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 5: Generation evaluation
|
||||||
|
test_samples = 20
|
||||||
|
print(f"\n🧪 Testing with first {test_samples} samples for generation analysis")
|
||||||
|
|
||||||
|
if args.llm_backend == "openai" and args.openai_api_key:
|
||||||
|
print("🔍🤖 Running OpenAI-based generation evaluation...")
|
||||||
|
evaluation_start = time.time()
|
||||||
|
timing_metrics = evaluator.evaluate_timing_breakdown(data[:test_samples])
|
||||||
|
evaluation_time = time.time() - evaluation_start
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"🔍🤖 Running {args.llm_backend} generation evaluation with {args.model_name}..."
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
# Load LLM
|
||||||
|
if args.llm_backend == "hf":
|
||||||
|
tokenizer, model = load_hf_model(args.model_name)
|
||||||
|
|
||||||
|
def llm_func(prompt):
|
||||||
|
return generate_hf(tokenizer, model, prompt)
|
||||||
|
else: # vllm
|
||||||
|
llm, sampling_params = load_vllm_model(args.model_name)
|
||||||
|
|
||||||
|
def llm_func(prompt):
|
||||||
|
return generate_vllm(llm, sampling_params, prompt)
|
||||||
|
|
||||||
|
# Simple generation evaluation
|
||||||
|
queries = [item["question"] for item in data[:test_samples]]
|
||||||
|
gen_results = evaluate_rag(
|
||||||
|
evaluator.searcher,
|
||||||
|
llm_func,
|
||||||
|
queries,
|
||||||
|
domain="finance",
|
||||||
|
complexity=complexity,
|
||||||
|
)
|
||||||
|
|
||||||
|
timing_metrics = {
|
||||||
|
"total_questions": len(queries),
|
||||||
|
"avg_search_time": gen_results["avg_search_time"],
|
||||||
|
"avg_generation_time": gen_results["avg_generation_time"],
|
||||||
|
"results": gen_results["results"],
|
||||||
|
}
|
||||||
|
evaluation_time = time.time()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Generation evaluation failed: {e}")
|
||||||
|
timing_metrics = {
|
||||||
|
"total_questions": 0,
|
||||||
|
"avg_search_time": 0,
|
||||||
|
"avg_generation_time": 0,
|
||||||
|
}
|
||||||
|
evaluation_time = 0
|
||||||
|
|
||||||
|
# Combine all metrics
|
||||||
|
combined_metrics = {
|
||||||
|
**timing_metrics,
|
||||||
|
"total_evaluation_time": evaluation_time,
|
||||||
|
"current_index": compact_size_metrics,
|
||||||
|
"non_compact_index": non_compact_size_metrics,
|
||||||
|
"performance_comparison": performance_metrics,
|
||||||
|
"storage_saving_percent": storage_saving,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Print results
|
||||||
|
print("\n📊 Generation Results:")
|
||||||
|
print(f" Total Questions: {timing_metrics.get('total_questions', 0)}")
|
||||||
|
print(f" Avg Search Time: {timing_metrics.get('avg_search_time', 0):.3f}s")
|
||||||
|
print(f" Avg Generation Time: {timing_metrics.get('avg_generation_time', 0):.3f}s")
|
||||||
|
|
||||||
|
# Save results if requested
|
||||||
|
if args.output:
|
||||||
|
print(f"\n💾 Saving results to {args.output}...")
|
||||||
|
with open(args.output, "w") as f:
|
||||||
|
json.dump(combined_metrics, f, indent=2, default=str)
|
||||||
|
print(f"✅ Results saved to {args.output}")
|
||||||
|
|
||||||
|
evaluator.cleanup()
|
||||||
|
print("✅ Stage 4 completed!\n")
|
||||||
|
|
||||||
|
if args.stage == "all":
|
||||||
|
print("🎉 All evaluation stages completed successfully!")
|
||||||
|
print("\n📋 Summary:")
|
||||||
|
print(" Stage 2: ✅ Recall@3 evaluation completed")
|
||||||
|
print(" Stage 3: ✅ Optimal complexity found")
|
||||||
|
print(" Stage 4: ✅ Generation accuracy & timing evaluation completed")
|
||||||
|
print("\n🔧 Recommended next steps:")
|
||||||
|
print(" - Use optimal complexity for best speed/accuracy balance")
|
||||||
|
print(" - Review accuracy and timing breakdown for performance optimization")
|
||||||
|
print(" - Run full evaluation on complete dataset if needed")
|
||||||
|
|
||||||
|
# Clean up non-compact index after all stages complete
|
||||||
|
print("\n🧹 Cleaning up temporary non-compact index...")
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
if Path(non_compact_index_path).exists():
|
||||||
|
temp_index_dir = Path(non_compact_index_path).parent
|
||||||
|
temp_index_name = Path(non_compact_index_path).name
|
||||||
|
for temp_file in temp_index_dir.glob(f"{temp_index_name}*"):
|
||||||
|
temp_file.unlink()
|
||||||
|
print(f"✅ Cleaned up {non_compact_index_path}")
|
||||||
|
else:
|
||||||
|
print("📝 No temporary index to clean up")
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n⚠️ Evaluation interrupted by user")
|
||||||
|
exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ Stage {args.stage} failed: {e}")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
462
benchmarks/financebench/setup_financebench.py
Executable file
462
benchmarks/financebench/setup_financebench.py
Executable file
@@ -0,0 +1,462 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
FinanceBench Complete Setup Script
|
||||||
|
Downloads all PDFs and builds full LEANN datastore
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from pathlib import Path
|
||||||
|
from threading import Lock
|
||||||
|
|
||||||
|
import pymupdf
|
||||||
|
import requests
|
||||||
|
from leann import LeannBuilder, LeannSearcher
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
class FinanceBenchSetup:
|
||||||
|
def __init__(self, data_dir: str = "data"):
|
||||||
|
self.base_dir = Path(__file__).parent # benchmarks/financebench/
|
||||||
|
self.data_dir = self.base_dir / data_dir
|
||||||
|
self.pdf_dir = self.data_dir / "pdfs"
|
||||||
|
self.dataset_file = self.data_dir / "financebench_merged.jsonl"
|
||||||
|
self.index_dir = self.data_dir / "index"
|
||||||
|
self.download_lock = Lock()
|
||||||
|
|
||||||
|
def download_dataset(self):
|
||||||
|
"""Download the main FinanceBench dataset"""
|
||||||
|
print("📊 Downloading FinanceBench dataset...")
|
||||||
|
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
if self.dataset_file.exists():
|
||||||
|
print(f"✅ Dataset already exists: {self.dataset_file}")
|
||||||
|
return
|
||||||
|
|
||||||
|
url = "https://huggingface.co/datasets/PatronusAI/financebench/raw/main/financebench_merged.jsonl"
|
||||||
|
response = requests.get(url, stream=True)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
with open(self.dataset_file, "wb") as f:
|
||||||
|
for chunk in response.iter_content(chunk_size=8192):
|
||||||
|
f.write(chunk)
|
||||||
|
|
||||||
|
print(f"✅ Dataset downloaded: {self.dataset_file}")
|
||||||
|
|
||||||
|
def get_pdf_list(self):
|
||||||
|
"""Get list of all PDF files from GitHub"""
|
||||||
|
print("📋 Fetching PDF list from GitHub...")
|
||||||
|
|
||||||
|
response = requests.get(
|
||||||
|
"https://api.github.com/repos/patronus-ai/financebench/contents/pdfs"
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
pdf_files = response.json()
|
||||||
|
|
||||||
|
print(f"Found {len(pdf_files)} PDF files")
|
||||||
|
return pdf_files
|
||||||
|
|
||||||
|
def download_single_pdf(self, pdf_info, position):
|
||||||
|
"""Download a single PDF file"""
|
||||||
|
pdf_name = pdf_info["name"]
|
||||||
|
pdf_path = self.pdf_dir / pdf_name
|
||||||
|
|
||||||
|
# Skip if already downloaded
|
||||||
|
if pdf_path.exists() and pdf_path.stat().st_size > 0:
|
||||||
|
return f"✅ {pdf_name} (cached)"
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Download PDF
|
||||||
|
response = requests.get(pdf_info["download_url"], timeout=60)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
# Write to file
|
||||||
|
with self.download_lock:
|
||||||
|
with open(pdf_path, "wb") as f:
|
||||||
|
f.write(response.content)
|
||||||
|
|
||||||
|
return f"✅ {pdf_name} ({len(response.content) // 1024}KB)"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return f"❌ {pdf_name}: {e!s}"
|
||||||
|
|
||||||
|
def download_all_pdfs(self, max_workers: int = 5):
|
||||||
|
"""Download all PDF files with parallel processing"""
|
||||||
|
self.pdf_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
pdf_files = self.get_pdf_list()
|
||||||
|
|
||||||
|
print(f"📥 Downloading {len(pdf_files)} PDFs with {max_workers} workers...")
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||||
|
# Submit all download tasks
|
||||||
|
future_to_pdf = {
|
||||||
|
executor.submit(self.download_single_pdf, pdf_info, i): pdf_info["name"]
|
||||||
|
for i, pdf_info in enumerate(pdf_files)
|
||||||
|
}
|
||||||
|
|
||||||
|
# Process completed downloads with progress bar
|
||||||
|
with tqdm(total=len(pdf_files), desc="Downloading PDFs") as pbar:
|
||||||
|
for future in as_completed(future_to_pdf):
|
||||||
|
result = future.result()
|
||||||
|
pbar.set_postfix_str(result.split()[-1] if "✅" in result else "Error")
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
# Verify downloads
|
||||||
|
downloaded_pdfs = list(self.pdf_dir.glob("*.pdf"))
|
||||||
|
print(f"✅ Successfully downloaded {len(downloaded_pdfs)}/{len(pdf_files)} PDFs")
|
||||||
|
|
||||||
|
# Show any failures
|
||||||
|
missing_pdfs = []
|
||||||
|
for pdf_info in pdf_files:
|
||||||
|
pdf_path = self.pdf_dir / pdf_info["name"]
|
||||||
|
if not pdf_path.exists() or pdf_path.stat().st_size == 0:
|
||||||
|
missing_pdfs.append(pdf_info["name"])
|
||||||
|
|
||||||
|
if missing_pdfs:
|
||||||
|
print(f"⚠️ Failed to download {len(missing_pdfs)} PDFs:")
|
||||||
|
for pdf in missing_pdfs[:5]: # Show first 5
|
||||||
|
print(f" - {pdf}")
|
||||||
|
if len(missing_pdfs) > 5:
|
||||||
|
print(f" ... and {len(missing_pdfs) - 5} more")
|
||||||
|
|
||||||
|
def build_leann_index(
|
||||||
|
self,
|
||||||
|
backend: str = "hnsw",
|
||||||
|
embedding_model: str = "sentence-transformers/all-mpnet-base-v2",
|
||||||
|
):
|
||||||
|
"""Build LEANN index from all PDFs"""
|
||||||
|
print(f"🏗️ Building LEANN index with {backend} backend...")
|
||||||
|
|
||||||
|
# Check if we have PDFs
|
||||||
|
pdf_files = list(self.pdf_dir.glob("*.pdf"))
|
||||||
|
if not pdf_files:
|
||||||
|
raise RuntimeError("No PDF files found! Run download first.")
|
||||||
|
|
||||||
|
print(f"Found {len(pdf_files)} PDF files to process")
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Initialize builder with standard compact configuration
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name=backend,
|
||||||
|
embedding_model=embedding_model,
|
||||||
|
embedding_mode="sentence-transformers",
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
is_recompute=True, # Enable recompute (no stored embeddings)
|
||||||
|
is_compact=True, # Enable compact storage (pruned)
|
||||||
|
num_threads=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process PDFs and extract text
|
||||||
|
total_chunks = 0
|
||||||
|
failed_pdfs = []
|
||||||
|
|
||||||
|
for pdf_path in tqdm(pdf_files, desc="Processing PDFs"):
|
||||||
|
try:
|
||||||
|
chunks = self.extract_pdf_text(pdf_path)
|
||||||
|
for chunk in chunks:
|
||||||
|
builder.add_text(chunk["text"], metadata=chunk["metadata"])
|
||||||
|
total_chunks += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Failed to process {pdf_path.name}: {e}")
|
||||||
|
failed_pdfs.append(pdf_path.name)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Build index in index directory
|
||||||
|
self.index_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
index_path = self.index_dir / f"financebench_full_{backend}.leann"
|
||||||
|
print(f"🔨 Building index: {index_path}")
|
||||||
|
builder.build_index(str(index_path))
|
||||||
|
|
||||||
|
build_time = time.time() - start_time
|
||||||
|
|
||||||
|
print("✅ Index built successfully!")
|
||||||
|
print(f" 📁 Index path: {index_path}")
|
||||||
|
print(f" 📊 Total chunks: {total_chunks:,}")
|
||||||
|
print(f" 📄 Processed PDFs: {len(pdf_files) - len(failed_pdfs)}/{len(pdf_files)}")
|
||||||
|
print(f" ⏱️ Build time: {build_time:.1f}s")
|
||||||
|
|
||||||
|
if failed_pdfs:
|
||||||
|
print(f" ⚠️ Failed PDFs: {failed_pdfs}")
|
||||||
|
|
||||||
|
return str(index_path)
|
||||||
|
|
||||||
|
def build_faiss_flat_baseline(self, index_path: str, output_dir: str = "baseline"):
|
||||||
|
"""Build FAISS flat baseline using the same embeddings as LEANN index"""
|
||||||
|
print("🔨 Building FAISS Flat baseline...")
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from leann.api import compute_embeddings
|
||||||
|
from leann_backend_hnsw import faiss
|
||||||
|
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
baseline_path = os.path.join(output_dir, "faiss_flat.index")
|
||||||
|
metadata_path = os.path.join(output_dir, "metadata.pkl")
|
||||||
|
|
||||||
|
if os.path.exists(baseline_path) and os.path.exists(metadata_path):
|
||||||
|
print(f"✅ Baseline already exists at {baseline_path}")
|
||||||
|
return baseline_path
|
||||||
|
|
||||||
|
# Read metadata from the built index
|
||||||
|
meta_path = f"{index_path}.meta.json"
|
||||||
|
with open(meta_path) as f:
|
||||||
|
import json
|
||||||
|
|
||||||
|
meta = json.loads(f.read())
|
||||||
|
|
||||||
|
embedding_model = meta["embedding_model"]
|
||||||
|
passage_source = meta["passage_sources"][0]
|
||||||
|
passage_file = passage_source["path"]
|
||||||
|
|
||||||
|
# Convert relative path to absolute
|
||||||
|
if not os.path.isabs(passage_file):
|
||||||
|
index_dir = os.path.dirname(index_path)
|
||||||
|
passage_file = os.path.join(index_dir, os.path.basename(passage_file))
|
||||||
|
|
||||||
|
print(f"📊 Loading passages from {passage_file}...")
|
||||||
|
print(f"🤖 Using embedding model: {embedding_model}")
|
||||||
|
|
||||||
|
# Load all passages for baseline
|
||||||
|
passages = []
|
||||||
|
passage_ids = []
|
||||||
|
with open(passage_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
data = json.loads(line)
|
||||||
|
passages.append(data["text"])
|
||||||
|
passage_ids.append(data["id"])
|
||||||
|
|
||||||
|
print(f"📄 Loaded {len(passages)} passages")
|
||||||
|
|
||||||
|
# Compute embeddings using the same method as LEANN
|
||||||
|
print("🧮 Computing embeddings...")
|
||||||
|
embeddings = compute_embeddings(
|
||||||
|
passages,
|
||||||
|
embedding_model,
|
||||||
|
mode="sentence-transformers",
|
||||||
|
use_server=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"📐 Embedding shape: {embeddings.shape}")
|
||||||
|
|
||||||
|
# Build FAISS flat index
|
||||||
|
print("🏗️ Building FAISS IndexFlatIP...")
|
||||||
|
dimension = embeddings.shape[1]
|
||||||
|
index = faiss.IndexFlatIP(dimension)
|
||||||
|
|
||||||
|
# Add embeddings to flat index
|
||||||
|
embeddings_f32 = embeddings.astype(np.float32)
|
||||||
|
index.add(embeddings_f32.shape[0], faiss.swig_ptr(embeddings_f32))
|
||||||
|
|
||||||
|
# Save index and metadata
|
||||||
|
faiss.write_index(index, baseline_path)
|
||||||
|
with open(metadata_path, "wb") as f:
|
||||||
|
pickle.dump(passage_ids, f)
|
||||||
|
|
||||||
|
print(f"✅ FAISS baseline saved to {baseline_path}")
|
||||||
|
print(f"✅ Metadata saved to {metadata_path}")
|
||||||
|
print(f"📊 Total vectors: {index.ntotal}")
|
||||||
|
|
||||||
|
return baseline_path
|
||||||
|
|
||||||
|
def extract_pdf_text(self, pdf_path: Path) -> list[dict]:
|
||||||
|
"""Extract and chunk text from a PDF file"""
|
||||||
|
chunks = []
|
||||||
|
doc = pymupdf.open(pdf_path)
|
||||||
|
|
||||||
|
for page_num in range(len(doc)):
|
||||||
|
page = doc[page_num]
|
||||||
|
text = page.get_text() # type: ignore
|
||||||
|
|
||||||
|
if not text.strip():
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Create metadata
|
||||||
|
metadata = {
|
||||||
|
"source_file": pdf_path.name,
|
||||||
|
"page_number": page_num + 1,
|
||||||
|
"document_type": "10K" if "10K" in pdf_path.name else "10Q",
|
||||||
|
"company": pdf_path.name.split("_")[0],
|
||||||
|
"doc_period": self.extract_year_from_filename(pdf_path.name),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Use recursive character splitting like LangChain
|
||||||
|
if len(text.split()) > 500:
|
||||||
|
# Split by double newlines (paragraphs)
|
||||||
|
paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
|
||||||
|
|
||||||
|
current_chunk = ""
|
||||||
|
for para in paragraphs:
|
||||||
|
# If adding this paragraph would make chunk too long, save current chunk
|
||||||
|
if current_chunk and len((current_chunk + " " + para).split()) > 300:
|
||||||
|
if current_chunk.strip():
|
||||||
|
chunks.append(
|
||||||
|
{
|
||||||
|
"text": current_chunk.strip(),
|
||||||
|
"metadata": {
|
||||||
|
**metadata,
|
||||||
|
"chunk_id": f"page_{page_num + 1}_chunk_{len(chunks)}",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
current_chunk = para
|
||||||
|
else:
|
||||||
|
current_chunk = (current_chunk + " " + para).strip()
|
||||||
|
|
||||||
|
# Add the last chunk
|
||||||
|
if current_chunk.strip():
|
||||||
|
chunks.append(
|
||||||
|
{
|
||||||
|
"text": current_chunk.strip(),
|
||||||
|
"metadata": {
|
||||||
|
**metadata,
|
||||||
|
"chunk_id": f"page_{page_num + 1}_chunk_{len(chunks)}",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Page is short enough, use as single chunk
|
||||||
|
chunks.append(
|
||||||
|
{
|
||||||
|
"text": text.strip(),
|
||||||
|
"metadata": {**metadata, "chunk_id": f"page_{page_num + 1}"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
doc.close()
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
def extract_year_from_filename(self, filename: str) -> str:
|
||||||
|
"""Extract year from PDF filename"""
|
||||||
|
# Try to find 4-digit year in filename
|
||||||
|
|
||||||
|
match = re.search(r"(\d{4})", filename)
|
||||||
|
return match.group(1) if match else "unknown"
|
||||||
|
|
||||||
|
def verify_setup(self, index_path: str):
|
||||||
|
"""Verify the setup by testing a simple query"""
|
||||||
|
print("🧪 Verifying setup with test query...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
searcher = LeannSearcher(index_path)
|
||||||
|
|
||||||
|
# Test query
|
||||||
|
test_query = "What is the capital expenditure for 3M in 2018?"
|
||||||
|
results = searcher.search(test_query, top_k=3)
|
||||||
|
|
||||||
|
print(f"✅ Test query successful! Found {len(results)} results:")
|
||||||
|
for i, result in enumerate(results, 1):
|
||||||
|
company = result.metadata.get("company", "Unknown")
|
||||||
|
year = result.metadata.get("doc_period", "Unknown")
|
||||||
|
page = result.metadata.get("page_number", "Unknown")
|
||||||
|
print(f" {i}. {company} {year} (page {page}) - Score: {result.score:.3f}")
|
||||||
|
print(f" {result.text[:100]}...")
|
||||||
|
|
||||||
|
searcher.cleanup()
|
||||||
|
print("✅ Setup verification completed successfully!")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Setup verification failed: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Setup FinanceBench with full PDF datastore")
|
||||||
|
parser.add_argument("--data-dir", default="data", help="Data directory")
|
||||||
|
parser.add_argument(
|
||||||
|
"--backend", choices=["hnsw", "diskann"], default="hnsw", help="LEANN backend"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding-model",
|
||||||
|
default="sentence-transformers/all-mpnet-base-v2",
|
||||||
|
help="Embedding model",
|
||||||
|
)
|
||||||
|
parser.add_argument("--max-workers", type=int, default=5, help="Parallel download workers")
|
||||||
|
parser.add_argument("--skip-download", action="store_true", help="Skip PDF download")
|
||||||
|
parser.add_argument("--skip-build", action="store_true", help="Skip index building")
|
||||||
|
parser.add_argument(
|
||||||
|
"--build-baseline-only",
|
||||||
|
action="store_true",
|
||||||
|
help="Only build FAISS baseline from existing index",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
print("🏦 FinanceBench Complete Setup")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
setup = FinanceBenchSetup(args.data_dir)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if args.build_baseline_only:
|
||||||
|
# Only build baseline from existing index
|
||||||
|
index_path = setup.index_dir / f"financebench_full_{args.backend}"
|
||||||
|
index_file = f"{index_path}.index"
|
||||||
|
meta_file = f"{index_path}.leann.meta.json"
|
||||||
|
|
||||||
|
if not os.path.exists(index_file) or not os.path.exists(meta_file):
|
||||||
|
print("❌ Index files not found:")
|
||||||
|
print(f" Index: {index_file}")
|
||||||
|
print(f" Meta: {meta_file}")
|
||||||
|
print("💡 Run without --build-baseline-only to build the index first")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
print(f"🔨 Building baseline from existing index: {index_path}")
|
||||||
|
baseline_path = setup.build_faiss_flat_baseline(str(index_path))
|
||||||
|
print(f"✅ Baseline built at {baseline_path}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Step 1: Download dataset
|
||||||
|
setup.download_dataset()
|
||||||
|
|
||||||
|
# Step 2: Download PDFs
|
||||||
|
if not args.skip_download:
|
||||||
|
setup.download_all_pdfs(max_workers=args.max_workers)
|
||||||
|
else:
|
||||||
|
print("⏭️ Skipping PDF download")
|
||||||
|
|
||||||
|
# Step 3: Build LEANN index
|
||||||
|
if not args.skip_build:
|
||||||
|
index_path = setup.build_leann_index(
|
||||||
|
backend=args.backend, embedding_model=args.embedding_model
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 4: Build FAISS flat baseline
|
||||||
|
print("\n🔨 Building FAISS flat baseline...")
|
||||||
|
baseline_path = setup.build_faiss_flat_baseline(index_path)
|
||||||
|
print(f"✅ Baseline built at {baseline_path}")
|
||||||
|
|
||||||
|
# Step 5: Verify setup
|
||||||
|
setup.verify_setup(index_path)
|
||||||
|
else:
|
||||||
|
print("⏭️ Skipping index building")
|
||||||
|
|
||||||
|
print("\n🎉 FinanceBench setup completed!")
|
||||||
|
print(f"📁 Data directory: {setup.data_dir.absolute()}")
|
||||||
|
print("\nNext steps:")
|
||||||
|
print(
|
||||||
|
"1. Run evaluation: python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
"2. Or test manually: python -c \"from leann import LeannSearcher; s = LeannSearcher('data/index/financebench_full_hnsw.leann'); print(s.search('3M capital expenditure 2018'))\""
|
||||||
|
)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n⚠️ Setup interrupted by user")
|
||||||
|
exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ Setup failed: {e}")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
214
benchmarks/financebench/verify_recall.py
Normal file
214
benchmarks/financebench/verify_recall.py
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# /// script
|
||||||
|
# requires-python = ">=3.9"
|
||||||
|
# dependencies = [
|
||||||
|
# "faiss-cpu",
|
||||||
|
# "numpy",
|
||||||
|
# "sentence-transformers",
|
||||||
|
# "torch",
|
||||||
|
# "tqdm",
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
|
||||||
|
"""
|
||||||
|
Independent recall verification script using standard FAISS.
|
||||||
|
Creates two indexes (HNSW and Flat) and compares recall@3 at different complexities.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import faiss
|
||||||
|
import numpy as np
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
def compute_embeddings_direct(chunks: list[str], model_name: str) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Direct embedding computation using sentence-transformers.
|
||||||
|
Copied logic to avoid dependency issues.
|
||||||
|
"""
|
||||||
|
print(f"Loading model: {model_name}")
|
||||||
|
model = SentenceTransformer(model_name)
|
||||||
|
|
||||||
|
print(f"Computing embeddings for {len(chunks)} chunks...")
|
||||||
|
embeddings = model.encode(
|
||||||
|
chunks,
|
||||||
|
show_progress_bar=True,
|
||||||
|
batch_size=32,
|
||||||
|
convert_to_numpy=True,
|
||||||
|
normalize_embeddings=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return embeddings.astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def load_financebench_queries(dataset_path: str, max_queries: int = 200) -> list[str]:
|
||||||
|
"""Load FinanceBench queries from dataset"""
|
||||||
|
queries = []
|
||||||
|
with open(dataset_path, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
data = json.loads(line)
|
||||||
|
queries.append(data["question"])
|
||||||
|
if len(queries) >= max_queries:
|
||||||
|
break
|
||||||
|
return queries
|
||||||
|
|
||||||
|
|
||||||
|
def load_passages_from_leann_index(index_path: str) -> tuple[list[str], list[str]]:
|
||||||
|
"""Load passages from LEANN index structure"""
|
||||||
|
meta_path = f"{index_path}.meta.json"
|
||||||
|
with open(meta_path) as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
|
||||||
|
passage_source = meta["passage_sources"][0]
|
||||||
|
passage_file = passage_source["path"]
|
||||||
|
|
||||||
|
# Convert relative path to absolute
|
||||||
|
if not Path(passage_file).is_absolute():
|
||||||
|
index_dir = Path(index_path).parent
|
||||||
|
passage_file = index_dir / Path(passage_file).name
|
||||||
|
|
||||||
|
print(f"Loading passages from {passage_file}")
|
||||||
|
|
||||||
|
passages = []
|
||||||
|
passage_ids = []
|
||||||
|
with open(passage_file, encoding="utf-8") as f:
|
||||||
|
for line in tqdm(f, desc="Loading passages"):
|
||||||
|
if line.strip():
|
||||||
|
data = json.loads(line)
|
||||||
|
passages.append(data["text"])
|
||||||
|
passage_ids.append(data["id"])
|
||||||
|
|
||||||
|
print(f"Loaded {len(passages)} passages")
|
||||||
|
return passages, passage_ids
|
||||||
|
|
||||||
|
|
||||||
|
def build_faiss_indexes(embeddings: np.ndarray) -> tuple[faiss.Index, faiss.Index]:
|
||||||
|
"""Build FAISS indexes: Flat (ground truth) and HNSW"""
|
||||||
|
dimension = embeddings.shape[1]
|
||||||
|
|
||||||
|
# Build Flat index (ground truth)
|
||||||
|
print("Building FAISS IndexFlatIP (ground truth)...")
|
||||||
|
flat_index = faiss.IndexFlatIP(dimension)
|
||||||
|
flat_index.add(embeddings)
|
||||||
|
|
||||||
|
# Build HNSW index
|
||||||
|
print("Building FAISS IndexHNSWFlat...")
|
||||||
|
M = 32 # Same as LEANN default
|
||||||
|
hnsw_index = faiss.IndexHNSWFlat(dimension, M, faiss.METRIC_INNER_PRODUCT)
|
||||||
|
hnsw_index.hnsw.efConstruction = 200 # Same as LEANN default
|
||||||
|
hnsw_index.add(embeddings)
|
||||||
|
|
||||||
|
print(f"Built indexes with {flat_index.ntotal} vectors, dimension {dimension}")
|
||||||
|
return flat_index, hnsw_index
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_recall_at_k(
|
||||||
|
query_embeddings: np.ndarray,
|
||||||
|
flat_index: faiss.Index,
|
||||||
|
hnsw_index: faiss.Index,
|
||||||
|
passage_ids: list[str],
|
||||||
|
k: int = 3,
|
||||||
|
ef_search: int = 64,
|
||||||
|
) -> float:
|
||||||
|
"""Evaluate recall@k comparing HNSW vs Flat"""
|
||||||
|
|
||||||
|
# Set search parameters for HNSW
|
||||||
|
hnsw_index.hnsw.efSearch = ef_search
|
||||||
|
|
||||||
|
total_recall = 0.0
|
||||||
|
num_queries = query_embeddings.shape[0]
|
||||||
|
|
||||||
|
for i in range(num_queries):
|
||||||
|
query = query_embeddings[i : i + 1] # Keep 2D shape
|
||||||
|
|
||||||
|
# Get ground truth from Flat index (standard FAISS API)
|
||||||
|
flat_distances, flat_indices = flat_index.search(query, k)
|
||||||
|
ground_truth_ids = {passage_ids[idx] for idx in flat_indices[0]}
|
||||||
|
|
||||||
|
# Get results from HNSW index (standard FAISS API)
|
||||||
|
hnsw_distances, hnsw_indices = hnsw_index.search(query, k)
|
||||||
|
hnsw_ids = {passage_ids[idx] for idx in hnsw_indices[0]}
|
||||||
|
|
||||||
|
# Calculate recall
|
||||||
|
intersection = ground_truth_ids.intersection(hnsw_ids)
|
||||||
|
recall = len(intersection) / k
|
||||||
|
total_recall += recall
|
||||||
|
|
||||||
|
if i < 3: # Show first few examples
|
||||||
|
print(f" Query {i + 1}: Recall@{k} = {recall:.3f}")
|
||||||
|
print(f" Flat: {list(ground_truth_ids)}")
|
||||||
|
print(f" HNSW: {list(hnsw_ids)}")
|
||||||
|
print(f" Intersection: {list(intersection)}")
|
||||||
|
|
||||||
|
avg_recall = total_recall / num_queries
|
||||||
|
return avg_recall
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Configuration
|
||||||
|
dataset_path = "data/financebench_merged.jsonl"
|
||||||
|
index_path = "data/index/financebench_full_hnsw.leann"
|
||||||
|
embedding_model = "sentence-transformers/all-mpnet-base-v2"
|
||||||
|
|
||||||
|
print("🔍 FAISS Recall Verification")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
# Check if files exist
|
||||||
|
if not Path(dataset_path).exists():
|
||||||
|
print(f"❌ Dataset not found: {dataset_path}")
|
||||||
|
return
|
||||||
|
if not Path(f"{index_path}.meta.json").exists():
|
||||||
|
print(f"❌ Index metadata not found: {index_path}.meta.json")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Load data
|
||||||
|
print("📖 Loading FinanceBench queries...")
|
||||||
|
queries = load_financebench_queries(dataset_path, max_queries=50)
|
||||||
|
print(f"Loaded {len(queries)} queries")
|
||||||
|
|
||||||
|
print("📄 Loading passages from LEANN index...")
|
||||||
|
passages, passage_ids = load_passages_from_leann_index(index_path)
|
||||||
|
|
||||||
|
# Compute embeddings
|
||||||
|
print("🧮 Computing passage embeddings...")
|
||||||
|
passage_embeddings = compute_embeddings_direct(passages, embedding_model)
|
||||||
|
|
||||||
|
print("🧮 Computing query embeddings...")
|
||||||
|
query_embeddings = compute_embeddings_direct(queries, embedding_model)
|
||||||
|
|
||||||
|
# Build FAISS indexes
|
||||||
|
print("🏗️ Building FAISS indexes...")
|
||||||
|
flat_index, hnsw_index = build_faiss_indexes(passage_embeddings)
|
||||||
|
|
||||||
|
# Test different efSearch values (equivalent to LEANN complexity)
|
||||||
|
print("\n📊 Evaluating Recall@3 at different efSearch values...")
|
||||||
|
ef_search_values = [16, 32, 64, 128, 256]
|
||||||
|
|
||||||
|
for ef_search in ef_search_values:
|
||||||
|
print(f"\n🧪 Testing efSearch = {ef_search}")
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
recall = evaluate_recall_at_k(
|
||||||
|
query_embeddings, flat_index, hnsw_index, passage_ids, k=3, ef_search=ef_search
|
||||||
|
)
|
||||||
|
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
print(
|
||||||
|
f"📈 efSearch {ef_search}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%) in {elapsed:.2f}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n✅ Verification completed!")
|
||||||
|
print("\n📋 Summary:")
|
||||||
|
print(" - Built independent FAISS Flat and HNSW indexes")
|
||||||
|
print(" - Compared recall@3 at different efSearch values")
|
||||||
|
print(" - Used same embedding model as LEANN")
|
||||||
|
print(" - This validates LEANN's recall measurements")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
1
benchmarks/laion/.gitignore
vendored
Normal file
1
benchmarks/laion/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
data/
|
||||||
199
benchmarks/laion/README.md
Normal file
199
benchmarks/laion/README.md
Normal file
@@ -0,0 +1,199 @@
|
|||||||
|
# LAION Multimodal Benchmark
|
||||||
|
|
||||||
|
A multimodal benchmark for evaluating image retrieval and generation performance using LEANN with CLIP embeddings and Qwen2.5-VL for multimodal generation on LAION dataset subset.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
This benchmark evaluates:
|
||||||
|
- **Image retrieval timing** using caption-based queries
|
||||||
|
- **Recall@K performance** for image search
|
||||||
|
- **Complexity analysis** across different search parameters
|
||||||
|
- **Index size and storage efficiency**
|
||||||
|
- **Multimodal generation** with Qwen2.5-VL for image understanding and description
|
||||||
|
|
||||||
|
## Dataset Configuration
|
||||||
|
|
||||||
|
- **Dataset**: LAION-400M subset (10,000 images)
|
||||||
|
- **Embeddings**: Pre-computed CLIP ViT-B/32 (512 dimensions)
|
||||||
|
- **Queries**: 200 random captions from the dataset
|
||||||
|
- **Ground Truth**: Self-recall (query caption → original image)
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### 1. Setup the benchmark
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd benchmarks/laion
|
||||||
|
python setup_laion.py --num-samples 10000 --num-queries 200
|
||||||
|
```
|
||||||
|
|
||||||
|
This will:
|
||||||
|
- Create dummy LAION data (10K samples)
|
||||||
|
- Generate CLIP embeddings (512-dim)
|
||||||
|
- Build LEANN index with HNSW backend
|
||||||
|
- Create 200 evaluation queries
|
||||||
|
|
||||||
|
### 2. Run evaluation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run all evaluation stages
|
||||||
|
python evaluate_laion.py --index data/laion_index.leann
|
||||||
|
|
||||||
|
# Run specific stages
|
||||||
|
python evaluate_laion.py --index data/laion_index.leann --stage 2 # Recall evaluation
|
||||||
|
python evaluate_laion.py --index data/laion_index.leann --stage 3 # Complexity analysis
|
||||||
|
python evaluate_laion.py --index data/laion_index.leann --stage 4 # Index comparison
|
||||||
|
python evaluate_laion.py --index data/laion_index.leann --stage 5 # Multimodal generation
|
||||||
|
|
||||||
|
# Multimodal generation with Qwen2.5-VL
|
||||||
|
python evaluate_laion.py --index data/laion_index.leann --stage 5 --model-name Qwen/Qwen2.5-VL-7B-Instruct
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Save results
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python evaluate_laion.py --index data/laion_index.leann --output results.json
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration Options
|
||||||
|
|
||||||
|
### Setup Options
|
||||||
|
```bash
|
||||||
|
python setup_laion.py \
|
||||||
|
--num-samples 10000 \
|
||||||
|
--num-queries 200 \
|
||||||
|
--index-path data/laion_index.leann \
|
||||||
|
--backend hnsw
|
||||||
|
```
|
||||||
|
|
||||||
|
### Evaluation Options
|
||||||
|
```bash
|
||||||
|
python evaluate_laion.py \
|
||||||
|
--index data/laion_index.leann \
|
||||||
|
--queries data/evaluation_queries.jsonl \
|
||||||
|
--complexity 64 \
|
||||||
|
--top-k 3 \
|
||||||
|
--num-samples 100 \
|
||||||
|
--stage all
|
||||||
|
```
|
||||||
|
|
||||||
|
## Evaluation Stages
|
||||||
|
|
||||||
|
### Stage 2: Recall Evaluation
|
||||||
|
- Evaluates Recall@3 for multimodal retrieval
|
||||||
|
- Compares LEANN vs FAISS baseline performance
|
||||||
|
- Self-recall: query caption should retrieve original image
|
||||||
|
|
||||||
|
### Stage 3: Complexity Analysis
|
||||||
|
- Binary search for optimal complexity (90% recall target)
|
||||||
|
- Tests performance across different complexity levels
|
||||||
|
- Analyzes speed vs. accuracy tradeoffs
|
||||||
|
|
||||||
|
### Stage 4: Index Comparison
|
||||||
|
- Compares compact vs non-compact index sizes
|
||||||
|
- Measures search performance differences
|
||||||
|
- Reports storage efficiency and speed ratios
|
||||||
|
|
||||||
|
### Stage 5: Multimodal Generation
|
||||||
|
- Uses Qwen2.5-VL for image understanding and description
|
||||||
|
- Retrieval-Augmented Generation (RAG) with multimodal context
|
||||||
|
- Measures both search and generation timing
|
||||||
|
|
||||||
|
## Output Metrics
|
||||||
|
|
||||||
|
### Timing Metrics
|
||||||
|
- Average/median/min/max search time
|
||||||
|
- Standard deviation
|
||||||
|
- Searches per second
|
||||||
|
- Latency in milliseconds
|
||||||
|
|
||||||
|
### Recall Metrics
|
||||||
|
- Recall@3 percentage for image retrieval
|
||||||
|
- Number of queries with ground truth
|
||||||
|
|
||||||
|
### Index Metrics
|
||||||
|
- Total index size (MB)
|
||||||
|
- Component breakdown (index, passages, metadata)
|
||||||
|
- Storage savings (compact vs non-compact)
|
||||||
|
- Backend and embedding model info
|
||||||
|
|
||||||
|
### Generation Metrics (Stage 5)
|
||||||
|
- Average search time per query
|
||||||
|
- Average generation time per query
|
||||||
|
- Time distribution (search vs generation)
|
||||||
|
- Sample multimodal responses
|
||||||
|
- Model: Qwen2.5-VL performance
|
||||||
|
|
||||||
|
## Benchmark Results
|
||||||
|
|
||||||
|
### LEANN-RAG Performance (CLIP ViT-L/14 + Qwen2.5-VL)
|
||||||
|
|
||||||
|
**Stage 3: Optimal Complexity Analysis**
|
||||||
|
- **Optimal Complexity**: 85 (achieving 90% Recall@3)
|
||||||
|
- **Binary Search Range**: 1-128
|
||||||
|
- **Target Recall**: 90%
|
||||||
|
- **Index Type**: Non-compact (for fast binary search)
|
||||||
|
|
||||||
|
**Stage 5: Multimodal Generation Performance (Qwen2.5-VL)**
|
||||||
|
- **Total Queries**: 20
|
||||||
|
- **Average Search Time**: 1.200s per query
|
||||||
|
- **Average Generation Time**: 6.558s per query
|
||||||
|
- **Time Distribution**: Search 15.5%, Generation 84.5%
|
||||||
|
- **LLM Backend**: HuggingFace transformers
|
||||||
|
- **Model**: Qwen/Qwen2.5-VL-7B-Instruct
|
||||||
|
- **Optimal Complexity**: 85
|
||||||
|
|
||||||
|
**System Performance:**
|
||||||
|
- **Index Size**: ~10,000 image embeddings from LAION subset
|
||||||
|
- **Embedding Model**: CLIP ViT-L/14 (768 dimensions)
|
||||||
|
- **Backend**: HNSW with cosine distance
|
||||||
|
|
||||||
|
### Example Results
|
||||||
|
|
||||||
|
```
|
||||||
|
🎯 LAION MULTIMODAL BENCHMARK RESULTS
|
||||||
|
============================================================
|
||||||
|
|
||||||
|
📊 Multimodal Generation Results:
|
||||||
|
Total Queries: 20
|
||||||
|
Avg Search Time: 1.200s
|
||||||
|
Avg Generation Time: 6.558s
|
||||||
|
Time Distribution: Search 15.5%, Generation 84.5%
|
||||||
|
LLM Backend: HuggingFace transformers
|
||||||
|
Model: Qwen/Qwen2.5-VL-7B-Instruct
|
||||||
|
|
||||||
|
⚙️ Optimal Complexity Analysis:
|
||||||
|
Target Recall: 90%
|
||||||
|
Optimal Complexity: 85
|
||||||
|
Binary Search Range: 1-128
|
||||||
|
Non-compact Index (fast search, no recompute)
|
||||||
|
|
||||||
|
🚀 Performance Summary:
|
||||||
|
Multimodal RAG: 7.758s total per query
|
||||||
|
Search: 15.5% of total time
|
||||||
|
Generation: 84.5% of total time
|
||||||
|
```
|
||||||
|
|
||||||
|
## Directory Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
benchmarks/laion/
|
||||||
|
├── setup_laion.py # Setup script
|
||||||
|
├── evaluate_laion.py # Evaluation script
|
||||||
|
├── README.md # This file
|
||||||
|
└── data/ # Generated data
|
||||||
|
├── laion_images/ # Image files (placeholder)
|
||||||
|
├── laion_metadata.jsonl # Image metadata
|
||||||
|
├── laion_passages.jsonl # LEANN passages
|
||||||
|
├── laion_embeddings.npy # CLIP embeddings
|
||||||
|
├── evaluation_queries.jsonl # Evaluation queries
|
||||||
|
└── laion_index.leann/ # LEANN index files
|
||||||
|
```
|
||||||
|
|
||||||
|
## Notes
|
||||||
|
|
||||||
|
- Current implementation uses dummy data for demonstration
|
||||||
|
- For real LAION data, implement actual download logic in `setup_laion.py`
|
||||||
|
- CLIP embeddings are randomly generated - replace with real CLIP model for production
|
||||||
|
- Adjust `num_samples` and `num_queries` based on available resources
|
||||||
|
- Consider using `--num-samples` during evaluation for faster testing
|
||||||
725
benchmarks/laion/evaluate_laion.py
Normal file
725
benchmarks/laion/evaluate_laion.py
Normal file
@@ -0,0 +1,725 @@
|
|||||||
|
"""
|
||||||
|
LAION Multimodal Benchmark Evaluation Script - Modular Recall-based Evaluation
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from leann import LeannSearcher
|
||||||
|
from leann_backend_hnsw import faiss
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
from ..llm_utils import evaluate_multimodal_rag, load_qwen_vl_model
|
||||||
|
|
||||||
|
# Setup logging to reduce verbose output
|
||||||
|
logging.basicConfig(level=logging.WARNING)
|
||||||
|
logging.getLogger("leann.api").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
|
||||||
|
class RecallEvaluator:
|
||||||
|
"""Stage 2: Evaluate Recall@3 (LEANN vs FAISS baseline for multimodal retrieval)"""
|
||||||
|
|
||||||
|
def __init__(self, index_path: str, baseline_dir: str):
|
||||||
|
self.index_path = index_path
|
||||||
|
self.baseline_dir = baseline_dir
|
||||||
|
self.searcher = LeannSearcher(index_path)
|
||||||
|
|
||||||
|
# Load FAISS flat baseline (image embeddings)
|
||||||
|
baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
|
||||||
|
metadata_path = os.path.join(baseline_dir, "metadata.pkl")
|
||||||
|
|
||||||
|
self.faiss_index = faiss.read_index(baseline_index_path)
|
||||||
|
with open(metadata_path, "rb") as f:
|
||||||
|
self.image_ids = pickle.load(f)
|
||||||
|
print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} image vectors")
|
||||||
|
|
||||||
|
# Load sentence-transformers CLIP for text embedding (ViT-L/14)
|
||||||
|
self.st_clip = SentenceTransformer("clip-ViT-L-14")
|
||||||
|
|
||||||
|
def evaluate_recall_at_3(
|
||||||
|
self, captions: list[str], complexity: int = 64, recompute_embeddings: bool = True
|
||||||
|
) -> float:
|
||||||
|
"""Evaluate recall@3 for multimodal retrieval: caption queries -> image results"""
|
||||||
|
recompute_str = "with recompute" if recompute_embeddings else "no recompute"
|
||||||
|
print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")
|
||||||
|
|
||||||
|
total_recall = 0.0
|
||||||
|
num_queries = len(captions)
|
||||||
|
|
||||||
|
for i, caption in enumerate(captions):
|
||||||
|
# Get ground truth: search with FAISS flat using caption text embedding
|
||||||
|
# Generate CLIP text embedding for caption via sentence-transformers (normalized)
|
||||||
|
query_embedding = self.st_clip.encode(
|
||||||
|
[caption], convert_to_numpy=True, normalize_embeddings=True, show_progress_bar=False
|
||||||
|
).astype(np.float32)
|
||||||
|
|
||||||
|
# Search FAISS flat for ground truth using LEANN's modified faiss API
|
||||||
|
n = query_embedding.shape[0] # Number of queries
|
||||||
|
k = 3 # Number of nearest neighbors
|
||||||
|
distances = np.zeros((n, k), dtype=np.float32)
|
||||||
|
labels = np.zeros((n, k), dtype=np.int64)
|
||||||
|
|
||||||
|
self.faiss_index.search(
|
||||||
|
n,
|
||||||
|
faiss.swig_ptr(query_embedding),
|
||||||
|
k,
|
||||||
|
faiss.swig_ptr(distances),
|
||||||
|
faiss.swig_ptr(labels),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract the results (image IDs from FAISS)
|
||||||
|
baseline_ids = {self.image_ids[idx] for idx in labels[0]}
|
||||||
|
|
||||||
|
# Search with LEANN at specified complexity (using caption as text query)
|
||||||
|
test_results = self.searcher.search(
|
||||||
|
caption,
|
||||||
|
top_k=3,
|
||||||
|
complexity=complexity,
|
||||||
|
recompute_embeddings=recompute_embeddings,
|
||||||
|
)
|
||||||
|
test_ids = {result.id for result in test_results}
|
||||||
|
|
||||||
|
# Calculate recall@3 = |intersection| / |ground_truth|
|
||||||
|
intersection = test_ids.intersection(baseline_ids)
|
||||||
|
recall = len(intersection) / 3.0 # Ground truth size is 3
|
||||||
|
total_recall += recall
|
||||||
|
|
||||||
|
if i < 3: # Show first few examples
|
||||||
|
print(f" Query {i + 1}: '{caption[:50]}...' -> Recall@3: {recall:.3f}")
|
||||||
|
print(f" FAISS ground truth: {list(baseline_ids)}")
|
||||||
|
print(f" LEANN results (C={complexity}, {recompute_str}): {list(test_ids)}")
|
||||||
|
print(f" Intersection: {list(intersection)}")
|
||||||
|
|
||||||
|
avg_recall = total_recall / num_queries
|
||||||
|
print(f"📊 Average Recall@3: {avg_recall:.3f} ({avg_recall * 100:.1f}%)")
|
||||||
|
return avg_recall
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""Cleanup resources"""
|
||||||
|
if hasattr(self, "searcher"):
|
||||||
|
self.searcher.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
class LAIONEvaluator:
|
||||||
|
def __init__(self, index_path: str):
|
||||||
|
self.index_path = index_path
|
||||||
|
self.searcher = LeannSearcher(index_path)
|
||||||
|
|
||||||
|
def load_queries(self, queries_file: str) -> list[str]:
|
||||||
|
"""Load caption queries from evaluation file"""
|
||||||
|
captions = []
|
||||||
|
with open(queries_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
query_data = json.loads(line)
|
||||||
|
captions.append(query_data["query"])
|
||||||
|
|
||||||
|
print(f"📊 Loaded {len(captions)} caption queries")
|
||||||
|
return captions
|
||||||
|
|
||||||
|
def analyze_index_sizes(self) -> dict:
|
||||||
|
"""Analyze index sizes, emphasizing .index only (exclude passages)."""
|
||||||
|
print("📏 Analyzing index sizes (.index only)...")
|
||||||
|
|
||||||
|
# Get all index-related files
|
||||||
|
index_path = Path(self.index_path)
|
||||||
|
index_dir = index_path.parent
|
||||||
|
index_name = index_path.stem # Remove .leann extension
|
||||||
|
|
||||||
|
sizes: dict[str, float] = {}
|
||||||
|
|
||||||
|
# Core index files
|
||||||
|
index_file = index_dir / f"{index_name}.index"
|
||||||
|
meta_file = index_dir / f"{index_path.name}.meta.json" # Keep .leann for meta file
|
||||||
|
passages_file = index_dir / f"{index_path.name}.passages.jsonl" # Keep .leann for passages
|
||||||
|
passages_idx_file = index_dir / f"{index_path.name}.passages.idx" # Keep .leann for idx
|
||||||
|
|
||||||
|
# Core index size (.index only)
|
||||||
|
index_mb = index_file.stat().st_size / (1024 * 1024) if index_file.exists() else 0.0
|
||||||
|
sizes["index_only_mb"] = index_mb
|
||||||
|
|
||||||
|
# Other files for reference (not counted in index_only_mb)
|
||||||
|
sizes["metadata_mb"] = (
|
||||||
|
meta_file.stat().st_size / (1024 * 1024) if meta_file.exists() else 0.0
|
||||||
|
)
|
||||||
|
sizes["passages_text_mb"] = (
|
||||||
|
passages_file.stat().st_size / (1024 * 1024) if passages_file.exists() else 0.0
|
||||||
|
)
|
||||||
|
sizes["passages_index_mb"] = (
|
||||||
|
passages_idx_file.stat().st_size / (1024 * 1024) if passages_idx_file.exists() else 0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f" 📁 .index size: {index_mb:.1f} MB")
|
||||||
|
if sizes["metadata_mb"]:
|
||||||
|
print(f" 🧾 metadata: {sizes['metadata_mb']:.3f} MB")
|
||||||
|
if sizes["passages_text_mb"] or sizes["passages_index_mb"]:
|
||||||
|
print(
|
||||||
|
f" (passages excluded) text: {sizes['passages_text_mb']:.1f} MB, idx: {sizes['passages_index_mb']:.1f} MB"
|
||||||
|
)
|
||||||
|
|
||||||
|
return sizes
|
||||||
|
|
||||||
|
def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
|
||||||
|
"""Create a non-compact index for comparison purposes"""
|
||||||
|
print("🏗️ Building non-compact index from existing passages...")
|
||||||
|
|
||||||
|
# Load existing passages from current index
|
||||||
|
from leann import LeannBuilder
|
||||||
|
|
||||||
|
current_index_path = Path(self.index_path)
|
||||||
|
current_index_dir = current_index_path.parent
|
||||||
|
current_index_name = current_index_path.name
|
||||||
|
|
||||||
|
# Read metadata to get passage source
|
||||||
|
meta_path = current_index_dir / f"{current_index_name}.meta.json"
|
||||||
|
with open(meta_path) as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
|
||||||
|
passage_source = meta["passage_sources"][0]
|
||||||
|
passage_file = passage_source["path"]
|
||||||
|
|
||||||
|
# Convert relative path to absolute
|
||||||
|
if not Path(passage_file).is_absolute():
|
||||||
|
passage_file = current_index_dir / Path(passage_file).name
|
||||||
|
|
||||||
|
print(f"📄 Loading passages from {passage_file}...")
|
||||||
|
|
||||||
|
# Load CLIP embeddings
|
||||||
|
embeddings_file = current_index_dir / "clip_image_embeddings.npy"
|
||||||
|
embeddings = np.load(embeddings_file)
|
||||||
|
print(f"📐 Loaded embeddings shape: {embeddings.shape}")
|
||||||
|
|
||||||
|
# Build non-compact index with same passages and embeddings
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
# Use CLIP text encoder (ViT-L/14) to match image embeddings (768-dim)
|
||||||
|
embedding_model="clip-ViT-L-14",
|
||||||
|
embedding_mode="sentence-transformers",
|
||||||
|
is_recompute=False, # Disable recompute (store embeddings)
|
||||||
|
is_compact=False, # Disable compact storage
|
||||||
|
distance_metric="cosine",
|
||||||
|
**{
|
||||||
|
k: v
|
||||||
|
for k, v in meta.get("backend_kwargs", {}).items()
|
||||||
|
if k not in ["is_recompute", "is_compact", "distance_metric"]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare ids and add passages
|
||||||
|
ids: list[str] = []
|
||||||
|
with open(passage_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
data = json.loads(line)
|
||||||
|
ids.append(str(data["id"]))
|
||||||
|
# Ensure metadata contains the id used by the vector index
|
||||||
|
metadata = {**data.get("metadata", {}), "id": data["id"]}
|
||||||
|
builder.add_text(text=data["text"], metadata=metadata)
|
||||||
|
|
||||||
|
if len(ids) != embeddings.shape[0]:
|
||||||
|
raise ValueError(
|
||||||
|
f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Persist a pickle for build_index_from_embeddings
|
||||||
|
pkl_path = current_index_dir / "clip_image_embeddings.pkl"
|
||||||
|
with open(pkl_path, "wb") as pf:
|
||||||
|
pickle.dump((ids, embeddings.astype(np.float32)), pf)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"🔨 Building non-compact index at {non_compact_index_path} from precomputed embeddings..."
|
||||||
|
)
|
||||||
|
builder.build_index_from_embeddings(non_compact_index_path, str(pkl_path))
|
||||||
|
|
||||||
|
# Analyze the non-compact index size
|
||||||
|
temp_evaluator = LAIONEvaluator(non_compact_index_path)
|
||||||
|
non_compact_sizes = temp_evaluator.analyze_index_sizes()
|
||||||
|
non_compact_sizes["index_type"] = "non_compact"
|
||||||
|
|
||||||
|
return non_compact_sizes
|
||||||
|
|
||||||
|
def compare_index_performance(
|
||||||
|
self, non_compact_path: str, compact_path: str, test_captions: list, complexity: int
|
||||||
|
) -> dict:
|
||||||
|
"""Compare performance between non-compact and compact indexes"""
|
||||||
|
print("⚡ Comparing search performance between indexes...")
|
||||||
|
|
||||||
|
# Test queries
|
||||||
|
test_queries = test_captions[:5]
|
||||||
|
|
||||||
|
results = {
|
||||||
|
"non_compact": {"search_times": []},
|
||||||
|
"compact": {"search_times": []},
|
||||||
|
"avg_search_times": {},
|
||||||
|
"speed_ratio": 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Test non-compact index (no recompute)
|
||||||
|
print(" 🔍 Testing non-compact index (no recompute)...")
|
||||||
|
non_compact_searcher = LeannSearcher(non_compact_path)
|
||||||
|
|
||||||
|
for caption in test_queries:
|
||||||
|
start_time = time.time()
|
||||||
|
_ = non_compact_searcher.search(
|
||||||
|
caption, top_k=3, complexity=complexity, recompute_embeddings=False
|
||||||
|
)
|
||||||
|
search_time = time.time() - start_time
|
||||||
|
results["non_compact"]["search_times"].append(search_time)
|
||||||
|
|
||||||
|
# Test compact index (with recompute)
|
||||||
|
print(" 🔍 Testing compact index (with recompute)...")
|
||||||
|
compact_searcher = LeannSearcher(compact_path)
|
||||||
|
|
||||||
|
for caption in test_queries:
|
||||||
|
start_time = time.time()
|
||||||
|
_ = compact_searcher.search(
|
||||||
|
caption, top_k=3, complexity=complexity, recompute_embeddings=True
|
||||||
|
)
|
||||||
|
search_time = time.time() - start_time
|
||||||
|
results["compact"]["search_times"].append(search_time)
|
||||||
|
|
||||||
|
# Calculate averages
|
||||||
|
results["avg_search_times"]["non_compact"] = sum(
|
||||||
|
results["non_compact"]["search_times"]
|
||||||
|
) / len(results["non_compact"]["search_times"])
|
||||||
|
results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
|
||||||
|
results["compact"]["search_times"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Performance ratio
|
||||||
|
if results["avg_search_times"]["compact"] > 0:
|
||||||
|
results["speed_ratio"] = (
|
||||||
|
results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
results["speed_ratio"] = float("inf")
|
||||||
|
|
||||||
|
print(
|
||||||
|
f" Non-compact (no recompute): {results['avg_search_times']['non_compact']:.3f}s avg"
|
||||||
|
)
|
||||||
|
print(f" Compact (with recompute): {results['avg_search_times']['compact']:.3f}s avg")
|
||||||
|
print(f" Speed ratio: {results['speed_ratio']:.2f}x")
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
non_compact_searcher.cleanup()
|
||||||
|
compact_searcher.cleanup()
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _print_results(self, timing_metrics: dict):
|
||||||
|
"""Print evaluation results"""
|
||||||
|
print("\n🎯 LAION MULTIMODAL BENCHMARK RESULTS")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Index comparison analysis (prefer .index-only view if present)
|
||||||
|
if "current_index" in timing_metrics and "non_compact_index" in timing_metrics:
|
||||||
|
current = timing_metrics["current_index"]
|
||||||
|
non_compact = timing_metrics["non_compact_index"]
|
||||||
|
|
||||||
|
if "index_only_mb" in current and "index_only_mb" in non_compact:
|
||||||
|
print("\n📏 Index Comparison Analysis (.index only):")
|
||||||
|
print(f" Compact index (current): {current.get('index_only_mb', 0):.1f} MB")
|
||||||
|
print(f" Non-compact index: {non_compact.get('index_only_mb', 0):.1f} MB")
|
||||||
|
print(
|
||||||
|
f" Storage saving by compact: {timing_metrics.get('storage_saving_percent', 0):.1f}%"
|
||||||
|
)
|
||||||
|
# Show excluded components for reference if available
|
||||||
|
if any(
|
||||||
|
k in non_compact
|
||||||
|
for k in ("passages_text_mb", "passages_index_mb", "metadata_mb")
|
||||||
|
):
|
||||||
|
print(" (passages excluded in totals, shown for reference):")
|
||||||
|
print(
|
||||||
|
f" - Passages text: {non_compact.get('passages_text_mb', 0):.1f} MB, "
|
||||||
|
f"Passages index: {non_compact.get('passages_index_mb', 0):.1f} MB, "
|
||||||
|
f"Metadata: {non_compact.get('metadata_mb', 0):.3f} MB"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Fallback to legacy totals if running with older metrics
|
||||||
|
print("\n📏 Index Comparison Analysis:")
|
||||||
|
print(
|
||||||
|
f" Compact index (current): {current.get('total_with_embeddings', 0):.1f} MB"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Non-compact index (with embeddings): {non_compact.get('total_with_embeddings', 0):.1f} MB"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Storage saving by compact: {timing_metrics.get('storage_saving_percent', 0):.1f}%"
|
||||||
|
)
|
||||||
|
print(" Component breakdown (non-compact):")
|
||||||
|
print(f" - Main index: {non_compact.get('index', 0):.1f} MB")
|
||||||
|
print(f" - Passages text: {non_compact.get('passages_text', 0):.1f} MB")
|
||||||
|
print(f" - Passages index: {non_compact.get('passages_index', 0):.1f} MB")
|
||||||
|
print(f" - Metadata: {non_compact.get('metadata', 0):.1f} MB")
|
||||||
|
|
||||||
|
# Performance comparison
|
||||||
|
if "performance_comparison" in timing_metrics:
|
||||||
|
perf = timing_metrics["performance_comparison"]
|
||||||
|
print("\n⚡ Performance Comparison:")
|
||||||
|
print(
|
||||||
|
f" Non-compact (no recompute): {perf.get('avg_search_times', {}).get('non_compact', 0):.3f}s avg"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Compact (with recompute): {perf.get('avg_search_times', {}).get('compact', 0):.3f}s avg"
|
||||||
|
)
|
||||||
|
print(f" Speed ratio: {perf.get('speed_ratio', 0):.2f}x")
|
||||||
|
|
||||||
|
# Legacy single index analysis (fallback)
|
||||||
|
if "total_with_embeddings" in timing_metrics and "current_index" not in timing_metrics:
|
||||||
|
print("\n📏 Index Size Analysis:")
|
||||||
|
print(
|
||||||
|
f" Index with embeddings: {timing_metrics.get('total_with_embeddings', 0):.1f} MB"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Estimated pruned index: {timing_metrics.get('total_without_embeddings', 0):.1f} MB"
|
||||||
|
)
|
||||||
|
print(f" Compression ratio: {timing_metrics.get('compression_ratio', 0):.2f}x")
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""Cleanup resources"""
|
||||||
|
if self.searcher:
|
||||||
|
self.searcher.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="LAION Multimodal Benchmark Evaluation")
|
||||||
|
parser.add_argument("--index", required=True, help="Path to LEANN index")
|
||||||
|
parser.add_argument(
|
||||||
|
"--queries", default="data/evaluation_queries.jsonl", help="Path to evaluation queries"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--stage",
|
||||||
|
choices=["2", "3", "4", "5", "all"],
|
||||||
|
default="all",
|
||||||
|
help="Which stage to run (2=recall, 3=complexity, 4=index comparison, 5=generation)",
|
||||||
|
)
|
||||||
|
parser.add_argument("--complexity", type=int, default=None, help="Complexity for search")
|
||||||
|
parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
|
||||||
|
parser.add_argument("--output", help="Save results to JSON file")
|
||||||
|
parser.add_argument(
|
||||||
|
"--llm-backend",
|
||||||
|
choices=["hf"],
|
||||||
|
default="hf",
|
||||||
|
help="LLM backend (Qwen2.5-VL only supports HF)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-name", default="Qwen/Qwen2.5-VL-7B-Instruct", help="Multimodal model name"
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Check if baseline exists
|
||||||
|
baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
|
||||||
|
if not os.path.exists(baseline_index_path):
|
||||||
|
print(f"❌ FAISS baseline not found at {baseline_index_path}")
|
||||||
|
print("💡 Please run setup_laion.py first to build the baseline")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
if args.stage == "2" or args.stage == "all":
|
||||||
|
# Stage 2: Recall@3 evaluation
|
||||||
|
print("🚀 Starting Stage 2: Recall@3 evaluation for multimodal retrieval")
|
||||||
|
|
||||||
|
evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||||
|
|
||||||
|
# Load caption queries for testing
|
||||||
|
laion_evaluator = LAIONEvaluator(args.index)
|
||||||
|
captions = laion_evaluator.load_queries(args.queries)
|
||||||
|
|
||||||
|
# Test with queries for robust measurement
|
||||||
|
test_captions = captions[:100] # Use subset for speed
|
||||||
|
print(f"🧪 Testing with {len(test_captions)} caption queries")
|
||||||
|
|
||||||
|
# Test with complexity 64
|
||||||
|
complexity = 64
|
||||||
|
recall = evaluator.evaluate_recall_at_3(test_captions, complexity)
|
||||||
|
print(f"📈 Recall@3 at complexity {complexity}: {recall * 100:.1f}%")
|
||||||
|
|
||||||
|
evaluator.cleanup()
|
||||||
|
print("✅ Stage 2 completed!\n")
|
||||||
|
|
||||||
|
# Shared non-compact index path for Stage 3 and 4
|
||||||
|
non_compact_index_path = args.index.replace(".leann", "_noncompact.leann")
|
||||||
|
complexity = args.complexity
|
||||||
|
|
||||||
|
if args.stage == "3" or args.stage == "all":
|
||||||
|
# Stage 3: Binary search for 90% recall complexity
|
||||||
|
print("🚀 Starting Stage 3: Binary search for 90% recall complexity")
|
||||||
|
print(
|
||||||
|
"💡 Creating non-compact index for fast binary search with recompute_embeddings=False"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create non-compact index for binary search
|
||||||
|
print("🏗️ Creating non-compact index for binary search...")
|
||||||
|
evaluator = LAIONEvaluator(args.index)
|
||||||
|
evaluator.create_non_compact_index_for_comparison(non_compact_index_path)
|
||||||
|
|
||||||
|
# Use non-compact index for binary search
|
||||||
|
binary_search_evaluator = RecallEvaluator(non_compact_index_path, args.baseline_dir)
|
||||||
|
|
||||||
|
# Load caption queries for testing
|
||||||
|
captions = evaluator.load_queries(args.queries)
|
||||||
|
|
||||||
|
# Use subset for robust measurement
|
||||||
|
test_captions = captions[:50] # Smaller subset for binary search speed
|
||||||
|
print(f"🧪 Testing with {len(test_captions)} caption queries")
|
||||||
|
|
||||||
|
# Binary search for 90% recall complexity
|
||||||
|
target_recall = 0.9
|
||||||
|
min_complexity, max_complexity = 1, 128
|
||||||
|
|
||||||
|
print(f"🔍 Binary search for {target_recall * 100}% recall complexity...")
|
||||||
|
print(f"Search range: {min_complexity} to {max_complexity}")
|
||||||
|
|
||||||
|
best_complexity = None
|
||||||
|
best_recall = 0.0
|
||||||
|
|
||||||
|
while min_complexity <= max_complexity:
|
||||||
|
mid_complexity = (min_complexity + max_complexity) // 2
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"\n🧪 Testing complexity {mid_complexity} (no recompute, non-compact index)..."
|
||||||
|
)
|
||||||
|
# Use recompute_embeddings=False on non-compact index for fast binary search
|
||||||
|
recall = binary_search_evaluator.evaluate_recall_at_3(
|
||||||
|
test_captions, mid_complexity, recompute_embeddings=False
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f" Complexity {mid_complexity}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%)"
|
||||||
|
)
|
||||||
|
|
||||||
|
if recall >= target_recall:
|
||||||
|
best_complexity = mid_complexity
|
||||||
|
best_recall = recall
|
||||||
|
max_complexity = mid_complexity - 1
|
||||||
|
print(" ✅ Target reached! Searching for lower complexity...")
|
||||||
|
else:
|
||||||
|
min_complexity = mid_complexity + 1
|
||||||
|
print(" ❌ Below target. Searching for higher complexity...")
|
||||||
|
|
||||||
|
if best_complexity is not None:
|
||||||
|
print("\n🎯 Optimal complexity found!")
|
||||||
|
print(f" Complexity: {best_complexity}")
|
||||||
|
print(f" Recall@3: {best_recall:.3f} ({best_recall * 100:.1f}%)")
|
||||||
|
|
||||||
|
# Test a few complexities around the optimal one for verification
|
||||||
|
print("\n🔬 Verification test around optimal complexity:")
|
||||||
|
verification_complexities = [
|
||||||
|
max(1, best_complexity - 2),
|
||||||
|
max(1, best_complexity - 1),
|
||||||
|
best_complexity,
|
||||||
|
best_complexity + 1,
|
||||||
|
best_complexity + 2,
|
||||||
|
]
|
||||||
|
|
||||||
|
for complexity in verification_complexities:
|
||||||
|
if complexity <= 512: # reasonable upper bound
|
||||||
|
recall = binary_search_evaluator.evaluate_recall_at_3(
|
||||||
|
test_captions, complexity, recompute_embeddings=False
|
||||||
|
)
|
||||||
|
status = "✅" if recall >= target_recall else "❌"
|
||||||
|
print(f" {status} Complexity {complexity:3d}: {recall * 100:5.1f}%")
|
||||||
|
|
||||||
|
# Now test the optimal complexity with compact index and recompute for comparison
|
||||||
|
print(
|
||||||
|
f"\n🔄 Testing optimal complexity {best_complexity} on compact index WITH recompute..."
|
||||||
|
)
|
||||||
|
compact_evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||||
|
recall_with_recompute = compact_evaluator.evaluate_recall_at_3(
|
||||||
|
test_captions[:10], best_complexity, recompute_embeddings=True
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" ✅ Complexity {best_complexity} (compact index with recompute): {recall_with_recompute * 100:.1f}%"
|
||||||
|
)
|
||||||
|
complexity = best_complexity
|
||||||
|
print(
|
||||||
|
f" 📊 Recall difference: {abs(best_recall - recall_with_recompute) * 100:.2f}%"
|
||||||
|
)
|
||||||
|
compact_evaluator.cleanup()
|
||||||
|
else:
|
||||||
|
print(f"\n❌ Could not find complexity achieving {target_recall * 100}% recall")
|
||||||
|
print("All tested complexities were below target.")
|
||||||
|
|
||||||
|
# Cleanup evaluators (keep non-compact index for Stage 4)
|
||||||
|
binary_search_evaluator.cleanup()
|
||||||
|
evaluator.cleanup()
|
||||||
|
|
||||||
|
print("✅ Stage 3 completed! Non-compact index saved for Stage 4.\n")
|
||||||
|
|
||||||
|
if args.stage == "4" or args.stage == "all":
|
||||||
|
# Stage 4: Index comparison (without LLM generation)
|
||||||
|
print("🚀 Starting Stage 4: Index comparison analysis")
|
||||||
|
|
||||||
|
# Use LAION evaluator for index comparison
|
||||||
|
evaluator = LAIONEvaluator(args.index)
|
||||||
|
|
||||||
|
# Load caption queries
|
||||||
|
captions = evaluator.load_queries(args.queries)
|
||||||
|
|
||||||
|
# Step 1: Analyze current (compact) index
|
||||||
|
print("\n📏 Analyzing current index (compact, pruned)...")
|
||||||
|
compact_size_metrics = evaluator.analyze_index_sizes()
|
||||||
|
compact_size_metrics["index_type"] = "compact"
|
||||||
|
|
||||||
|
# Step 2: Use existing non-compact index or create if needed
|
||||||
|
if Path(non_compact_index_path).exists():
|
||||||
|
print(
|
||||||
|
f"\n📁 Using existing non-compact index from Stage 3: {non_compact_index_path}"
|
||||||
|
)
|
||||||
|
temp_evaluator = LAIONEvaluator(non_compact_index_path)
|
||||||
|
non_compact_size_metrics = temp_evaluator.analyze_index_sizes()
|
||||||
|
non_compact_size_metrics["index_type"] = "non_compact"
|
||||||
|
else:
|
||||||
|
print("\n🏗️ Creating non-compact index (with embeddings) for comparison...")
|
||||||
|
non_compact_size_metrics = evaluator.create_non_compact_index_for_comparison(
|
||||||
|
non_compact_index_path
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 3: Compare index sizes (.index only)
|
||||||
|
print("\n📊 Index size comparison (.index only):")
|
||||||
|
print(
|
||||||
|
f" Compact index (current): {compact_size_metrics.get('index_only_mb', 0):.1f} MB"
|
||||||
|
)
|
||||||
|
print(f" Non-compact index: {non_compact_size_metrics.get('index_only_mb', 0):.1f} MB")
|
||||||
|
|
||||||
|
storage_saving = 0.0
|
||||||
|
if non_compact_size_metrics.get("index_only_mb", 0) > 0:
|
||||||
|
storage_saving = (
|
||||||
|
(
|
||||||
|
non_compact_size_metrics.get("index_only_mb", 0)
|
||||||
|
- compact_size_metrics.get("index_only_mb", 0)
|
||||||
|
)
|
||||||
|
/ non_compact_size_metrics.get("index_only_mb", 1)
|
||||||
|
* 100
|
||||||
|
)
|
||||||
|
print(f" Storage saving by compact: {storage_saving:.1f}%")
|
||||||
|
|
||||||
|
# Step 4: Performance comparison between the two indexes
|
||||||
|
if complexity is None:
|
||||||
|
raise ValueError("Complexity is required for index comparison")
|
||||||
|
|
||||||
|
print("\n⚡ Performance comparison between indexes...")
|
||||||
|
performance_metrics = evaluator.compare_index_performance(
|
||||||
|
non_compact_index_path, args.index, captions[:10], complexity=complexity
|
||||||
|
)
|
||||||
|
|
||||||
|
# Combine all metrics
|
||||||
|
combined_metrics = {
|
||||||
|
"current_index": compact_size_metrics,
|
||||||
|
"non_compact_index": non_compact_size_metrics,
|
||||||
|
"performance_comparison": performance_metrics,
|
||||||
|
"storage_saving_percent": storage_saving,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Print comprehensive results
|
||||||
|
evaluator._print_results(combined_metrics)
|
||||||
|
|
||||||
|
# Save results if requested
|
||||||
|
if args.output:
|
||||||
|
print(f"\n💾 Saving results to {args.output}...")
|
||||||
|
with open(args.output, "w") as f:
|
||||||
|
json.dump(combined_metrics, f, indent=2, default=str)
|
||||||
|
print(f"✅ Results saved to {args.output}")
|
||||||
|
|
||||||
|
evaluator.cleanup()
|
||||||
|
print("✅ Stage 4 completed!\n")
|
||||||
|
|
||||||
|
if args.stage in ("5", "all"):
|
||||||
|
print("🚀 Starting Stage 5: Multimodal generation with Qwen2.5-VL")
|
||||||
|
evaluator = LAIONEvaluator(args.index)
|
||||||
|
captions = evaluator.load_queries(args.queries)
|
||||||
|
test_captions = captions[: min(20, len(captions))] # Use subset for generation
|
||||||
|
|
||||||
|
print(f"🧪 Testing multimodal generation with {len(test_captions)} queries")
|
||||||
|
|
||||||
|
# Load Qwen2.5-VL model
|
||||||
|
try:
|
||||||
|
print("Loading Qwen2.5-VL model...")
|
||||||
|
processor, model = load_qwen_vl_model(args.model_name)
|
||||||
|
|
||||||
|
# Run multimodal generation evaluation
|
||||||
|
complexity = args.complexity or 64
|
||||||
|
gen_results = evaluate_multimodal_rag(
|
||||||
|
evaluator.searcher,
|
||||||
|
test_captions,
|
||||||
|
processor=processor,
|
||||||
|
model=model,
|
||||||
|
complexity=complexity,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n📊 Multimodal Generation Results:")
|
||||||
|
print(f" Total Queries: {len(test_captions)}")
|
||||||
|
print(f" Avg Search Time: {gen_results['avg_search_time']:.3f}s")
|
||||||
|
print(f" Avg Generation Time: {gen_results['avg_generation_time']:.3f}s")
|
||||||
|
total_time = gen_results["avg_search_time"] + gen_results["avg_generation_time"]
|
||||||
|
search_pct = (gen_results["avg_search_time"] / total_time) * 100
|
||||||
|
gen_pct = (gen_results["avg_generation_time"] / total_time) * 100
|
||||||
|
print(f" Time Distribution: Search {search_pct:.1f}%, Generation {gen_pct:.1f}%")
|
||||||
|
print(" LLM Backend: HuggingFace transformers")
|
||||||
|
print(f" Model: {args.model_name}")
|
||||||
|
|
||||||
|
# Show sample results
|
||||||
|
print("\n📝 Sample Multimodal Generations:")
|
||||||
|
for i, response in enumerate(gen_results["results"][:3]):
|
||||||
|
# Handle both string and dict formats for captions
|
||||||
|
if isinstance(test_captions[i], dict):
|
||||||
|
caption_text = test_captions[i].get("query", str(test_captions[i]))
|
||||||
|
else:
|
||||||
|
caption_text = str(test_captions[i])
|
||||||
|
print(f" Query {i + 1}: {caption_text[:60]}...")
|
||||||
|
print(f" Response {i + 1}: {response[:100]}...")
|
||||||
|
print()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Multimodal generation evaluation failed: {e}")
|
||||||
|
print("💡 Make sure transformers and Qwen2.5-VL are installed")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
evaluator.cleanup()
|
||||||
|
print("✅ Stage 5 completed!\n")
|
||||||
|
|
||||||
|
if args.stage == "all":
|
||||||
|
print("🎉 All evaluation stages completed successfully!")
|
||||||
|
print("\n📋 Summary:")
|
||||||
|
print(" Stage 2: ✅ Multimodal Recall@3 evaluation completed")
|
||||||
|
print(" Stage 3: ✅ Optimal complexity found")
|
||||||
|
print(" Stage 4: ✅ Index comparison analysis completed")
|
||||||
|
print(" Stage 5: ✅ Multimodal generation evaluation completed")
|
||||||
|
print("\n🔧 Recommended next steps:")
|
||||||
|
print(" - Use optimal complexity for best speed/accuracy balance")
|
||||||
|
print(" - Review index comparison for storage vs performance tradeoffs")
|
||||||
|
|
||||||
|
# Clean up non-compact index after all stages complete
|
||||||
|
print("\n🧹 Cleaning up temporary non-compact index...")
|
||||||
|
if Path(non_compact_index_path).exists():
|
||||||
|
temp_index_dir = Path(non_compact_index_path).parent
|
||||||
|
temp_index_name = Path(non_compact_index_path).name
|
||||||
|
for temp_file in temp_index_dir.glob(f"{temp_index_name}*"):
|
||||||
|
temp_file.unlink()
|
||||||
|
print(f"✅ Cleaned up {non_compact_index_path}")
|
||||||
|
else:
|
||||||
|
print("📝 No temporary index to clean up")
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n⚠️ Evaluation interrupted by user")
|
||||||
|
exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ Stage {args.stage} failed: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
576
benchmarks/laion/setup_laion.py
Normal file
576
benchmarks/laion/setup_laion.py
Normal file
@@ -0,0 +1,576 @@
|
|||||||
|
"""
|
||||||
|
LAION Multimodal Benchmark Setup Script
|
||||||
|
Downloads LAION subset and builds LEANN index with sentence embeddings
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import numpy as np
|
||||||
|
from datasets import load_dataset
|
||||||
|
from leann import LeannBuilder
|
||||||
|
from PIL import Image
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
class LAIONSetup:
|
||||||
|
def __init__(self, data_dir: str = "data"):
|
||||||
|
self.data_dir = Path(data_dir)
|
||||||
|
self.images_dir = self.data_dir / "laion_images"
|
||||||
|
self.metadata_file = self.data_dir / "laion_metadata.jsonl"
|
||||||
|
|
||||||
|
# Create directories
|
||||||
|
self.data_dir.mkdir(exist_ok=True)
|
||||||
|
self.images_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
async def download_single_image(self, session, sample_data, semaphore, progress_bar):
|
||||||
|
"""Download a single image asynchronously"""
|
||||||
|
async with semaphore: # Limit concurrent downloads
|
||||||
|
try:
|
||||||
|
image_url = sample_data["url"]
|
||||||
|
image_path = sample_data["image_path"]
|
||||||
|
|
||||||
|
# Skip if already exists
|
||||||
|
if os.path.exists(image_path):
|
||||||
|
progress_bar.update(1)
|
||||||
|
return sample_data
|
||||||
|
|
||||||
|
async with session.get(image_url, timeout=10) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
content = await response.read()
|
||||||
|
|
||||||
|
# Verify it's a valid image
|
||||||
|
try:
|
||||||
|
img = Image.open(io.BytesIO(content))
|
||||||
|
img = img.convert("RGB")
|
||||||
|
img.save(image_path, "JPEG")
|
||||||
|
progress_bar.update(1)
|
||||||
|
return sample_data
|
||||||
|
except Exception:
|
||||||
|
progress_bar.update(1)
|
||||||
|
return None # Skip invalid images
|
||||||
|
else:
|
||||||
|
progress_bar.update(1)
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
progress_bar.update(1)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def download_laion_subset(self, num_samples: int = 1000):
|
||||||
|
"""Download LAION subset from HuggingFace datasets with async parallel downloading"""
|
||||||
|
print(f"📥 Downloading LAION subset ({num_samples} samples)...")
|
||||||
|
|
||||||
|
# Load LAION-400M subset from HuggingFace
|
||||||
|
print("🤗 Loading from HuggingFace datasets...")
|
||||||
|
dataset = load_dataset("laion/laion400m", split="train", streaming=True)
|
||||||
|
|
||||||
|
# Collect sample metadata first (fast)
|
||||||
|
print("📋 Collecting sample metadata...")
|
||||||
|
candidates = []
|
||||||
|
for sample in dataset:
|
||||||
|
if len(candidates) >= num_samples * 3: # Get 3x more candidates in case some fail
|
||||||
|
break
|
||||||
|
|
||||||
|
image_url = sample.get("url", "")
|
||||||
|
caption = sample.get("caption", "")
|
||||||
|
|
||||||
|
if not image_url or not caption:
|
||||||
|
continue
|
||||||
|
|
||||||
|
image_filename = f"laion_{len(candidates):06d}.jpg"
|
||||||
|
image_path = self.images_dir / image_filename
|
||||||
|
|
||||||
|
candidate = {
|
||||||
|
"id": f"laion_{len(candidates):06d}",
|
||||||
|
"url": image_url,
|
||||||
|
"caption": caption,
|
||||||
|
"image_path": str(image_path),
|
||||||
|
"width": sample.get("original_width", 512),
|
||||||
|
"height": sample.get("original_height", 512),
|
||||||
|
"similarity": sample.get("similarity", 0.0),
|
||||||
|
}
|
||||||
|
candidates.append(candidate)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"📊 Collected {len(candidates)} candidates, downloading {num_samples} in parallel..."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Download images in parallel
|
||||||
|
async def download_batch():
|
||||||
|
semaphore = asyncio.Semaphore(20) # Limit to 20 concurrent downloads
|
||||||
|
connector = aiohttp.TCPConnector(limit=100, limit_per_host=20)
|
||||||
|
timeout = aiohttp.ClientTimeout(total=30)
|
||||||
|
|
||||||
|
progress_bar = tqdm(total=len(candidates[: num_samples * 2]), desc="Downloading images")
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
|
||||||
|
tasks = []
|
||||||
|
for candidate in candidates[: num_samples * 2]: # Try 2x more than needed
|
||||||
|
task = self.download_single_image(session, candidate, semaphore, progress_bar)
|
||||||
|
tasks.append(task)
|
||||||
|
|
||||||
|
# Wait for all downloads
|
||||||
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
progress_bar.close()
|
||||||
|
|
||||||
|
# Filter successful downloads
|
||||||
|
successful = [r for r in results if r is not None and not isinstance(r, Exception)]
|
||||||
|
return successful[:num_samples]
|
||||||
|
|
||||||
|
# Run async download
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
try:
|
||||||
|
samples = loop.run_until_complete(download_batch())
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
# Save metadata
|
||||||
|
with open(self.metadata_file, "w", encoding="utf-8") as f:
|
||||||
|
for sample in samples:
|
||||||
|
f.write(json.dumps(sample) + "\n")
|
||||||
|
|
||||||
|
print(f"✅ Downloaded {len(samples)} real LAION samples with async parallel downloading")
|
||||||
|
return samples
|
||||||
|
|
||||||
|
def generate_clip_image_embeddings(self, samples: list[dict]):
|
||||||
|
"""Generate CLIP image embeddings for downloaded images"""
|
||||||
|
print("🔍 Generating CLIP image embeddings...")
|
||||||
|
|
||||||
|
# Load sentence-transformers CLIP (ViT-L/14, 768-dim) for image embeddings
|
||||||
|
# This single model can encode both images and text.
|
||||||
|
model = SentenceTransformer("clip-ViT-L-14")
|
||||||
|
|
||||||
|
embeddings = []
|
||||||
|
valid_samples = []
|
||||||
|
|
||||||
|
for sample in tqdm(samples, desc="Processing images"):
|
||||||
|
try:
|
||||||
|
# Load image
|
||||||
|
image_path = sample["image_path"]
|
||||||
|
image = Image.open(image_path).convert("RGB")
|
||||||
|
|
||||||
|
# Encode image to 768-dim embedding via sentence-transformers (normalized)
|
||||||
|
vec = model.encode(
|
||||||
|
[image],
|
||||||
|
convert_to_numpy=True,
|
||||||
|
normalize_embeddings=True,
|
||||||
|
batch_size=1,
|
||||||
|
show_progress_bar=False,
|
||||||
|
)[0]
|
||||||
|
embeddings.append(vec.astype(np.float32))
|
||||||
|
valid_samples.append(sample)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ⚠️ Failed to process {sample['id']}: {e}")
|
||||||
|
# Skip invalid images
|
||||||
|
|
||||||
|
embeddings = np.array(embeddings, dtype=np.float32)
|
||||||
|
|
||||||
|
# Save embeddings
|
||||||
|
embeddings_file = self.data_dir / "clip_image_embeddings.npy"
|
||||||
|
np.save(embeddings_file, embeddings)
|
||||||
|
print(f"✅ Generated {len(embeddings)} image embeddings, shape: {embeddings.shape}")
|
||||||
|
|
||||||
|
return embeddings, valid_samples
|
||||||
|
|
||||||
|
def build_faiss_baseline(
|
||||||
|
self, embeddings: np.ndarray, samples: list[dict], output_dir: str = "baseline"
|
||||||
|
):
|
||||||
|
"""Build FAISS flat baseline using CLIP image embeddings"""
|
||||||
|
print("🔨 Building FAISS Flat baseline...")
|
||||||
|
|
||||||
|
from leann_backend_hnsw import faiss
|
||||||
|
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
baseline_path = os.path.join(output_dir, "faiss_flat.index")
|
||||||
|
metadata_path = os.path.join(output_dir, "metadata.pkl")
|
||||||
|
|
||||||
|
if os.path.exists(baseline_path) and os.path.exists(metadata_path):
|
||||||
|
print(f"✅ Baseline already exists at {baseline_path}")
|
||||||
|
return baseline_path
|
||||||
|
|
||||||
|
# Extract image IDs (must be present)
|
||||||
|
if not samples or "id" not in samples[0]:
|
||||||
|
raise KeyError("samples missing 'id' field for FAISS baseline")
|
||||||
|
image_ids: list[str] = [str(sample["id"]) for sample in samples]
|
||||||
|
|
||||||
|
print(f"📐 Embedding shape: {embeddings.shape}")
|
||||||
|
print(f"📄 Processing {len(image_ids)} images")
|
||||||
|
|
||||||
|
# Build FAISS flat index
|
||||||
|
print("🏗️ Building FAISS IndexFlatIP...")
|
||||||
|
dimension = embeddings.shape[1]
|
||||||
|
index = faiss.IndexFlatIP(dimension)
|
||||||
|
|
||||||
|
# Add embeddings to flat index
|
||||||
|
embeddings_f32 = embeddings.astype(np.float32)
|
||||||
|
index.add(embeddings_f32.shape[0], faiss.swig_ptr(embeddings_f32))
|
||||||
|
|
||||||
|
# Save index and metadata
|
||||||
|
faiss.write_index(index, baseline_path)
|
||||||
|
with open(metadata_path, "wb") as f:
|
||||||
|
pickle.dump(image_ids, f)
|
||||||
|
|
||||||
|
print(f"✅ FAISS baseline saved to {baseline_path}")
|
||||||
|
print(f"✅ Metadata saved to {metadata_path}")
|
||||||
|
print(f"📊 Total vectors: {index.ntotal}")
|
||||||
|
|
||||||
|
return baseline_path
|
||||||
|
|
||||||
|
def create_leann_passages(self, samples: list[dict]):
|
||||||
|
"""Create LEANN-compatible passages from LAION data"""
|
||||||
|
print("📝 Creating LEANN passages...")
|
||||||
|
|
||||||
|
passages_file = self.data_dir / "laion_passages.jsonl"
|
||||||
|
|
||||||
|
with open(passages_file, "w", encoding="utf-8") as f:
|
||||||
|
for i, sample in enumerate(samples):
|
||||||
|
passage = {
|
||||||
|
"id": sample["id"],
|
||||||
|
"text": sample["caption"], # Use caption as searchable text
|
||||||
|
"metadata": {
|
||||||
|
"image_url": sample["url"],
|
||||||
|
"image_path": sample.get("image_path", ""),
|
||||||
|
"width": sample["width"],
|
||||||
|
"height": sample["height"],
|
||||||
|
"similarity": sample["similarity"],
|
||||||
|
"image_index": i, # Index for embedding lookup
|
||||||
|
},
|
||||||
|
}
|
||||||
|
f.write(json.dumps(passage) + "\n")
|
||||||
|
|
||||||
|
print(f"✅ Created {len(samples)} passages")
|
||||||
|
return passages_file
|
||||||
|
|
||||||
|
def build_compact_index(
|
||||||
|
self, passages_file: Path, embeddings: np.ndarray, index_path: str, backend: str = "hnsw"
|
||||||
|
):
|
||||||
|
"""Build compact LEANN index with CLIP embeddings (recompute=True, compact=True)"""
|
||||||
|
print(f"🏗️ Building compact LEANN index with {backend} backend...")
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Save CLIP embeddings (npy) and also a pickle with (ids, embeddings)
|
||||||
|
npy_path = self.data_dir / "clip_image_embeddings.npy"
|
||||||
|
np.save(npy_path, embeddings)
|
||||||
|
print(f"💾 Saved CLIP embeddings to {npy_path}")
|
||||||
|
|
||||||
|
# Prepare ids in the same order as passages_file (matches embeddings order)
|
||||||
|
ids: list[str] = []
|
||||||
|
with open(passages_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
rec = json.loads(line)
|
||||||
|
ids.append(str(rec["id"]))
|
||||||
|
|
||||||
|
if len(ids) != embeddings.shape[0]:
|
||||||
|
raise ValueError(
|
||||||
|
f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
|
||||||
|
)
|
||||||
|
|
||||||
|
pkl_path = self.data_dir / "clip_image_embeddings.pkl"
|
||||||
|
with open(pkl_path, "wb") as pf:
|
||||||
|
pickle.dump((ids, embeddings.astype(np.float32)), pf)
|
||||||
|
print(f"💾 Saved (ids, embeddings) pickle to {pkl_path}")
|
||||||
|
|
||||||
|
# Initialize builder - compact with recompute
|
||||||
|
# Note: For multimodal case, we need to handle embeddings differently
|
||||||
|
# Let's try using sentence-transformers mode but with custom embeddings
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name=backend,
|
||||||
|
# Use CLIP text encoder (ViT-L/14) to match image space (768-dim)
|
||||||
|
embedding_model="clip-ViT-L-14",
|
||||||
|
embedding_mode="sentence-transformers",
|
||||||
|
# HNSW params (or forwarded to chosen backend)
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
# Compact/pruned with recompute at query time
|
||||||
|
is_recompute=True,
|
||||||
|
is_compact=True,
|
||||||
|
distance_metric="cosine", # CLIP uses normalized vectors; cosine is appropriate
|
||||||
|
num_threads=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add passages (text + metadata)
|
||||||
|
print("📚 Adding passages...")
|
||||||
|
self._add_passages_with_embeddings(builder, passages_file, embeddings)
|
||||||
|
|
||||||
|
print(f"🔨 Building compact index at {index_path} from precomputed embeddings...")
|
||||||
|
builder.build_index_from_embeddings(index_path, str(pkl_path))
|
||||||
|
|
||||||
|
build_time = time.time() - start_time
|
||||||
|
print(f"✅ Compact index built in {build_time:.2f}s")
|
||||||
|
|
||||||
|
# Analyze index size
|
||||||
|
self._analyze_index_size(index_path)
|
||||||
|
|
||||||
|
return index_path
|
||||||
|
|
||||||
|
def build_non_compact_index(
|
||||||
|
self, passages_file: Path, embeddings: np.ndarray, index_path: str, backend: str = "hnsw"
|
||||||
|
):
|
||||||
|
"""Build non-compact LEANN index with CLIP embeddings (recompute=False, compact=False)"""
|
||||||
|
print(f"🏗️ Building non-compact LEANN index with {backend} backend...")
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Ensure embeddings are saved (npy + pickle)
|
||||||
|
npy_path = self.data_dir / "clip_image_embeddings.npy"
|
||||||
|
if not npy_path.exists():
|
||||||
|
np.save(npy_path, embeddings)
|
||||||
|
print(f"💾 Saved CLIP embeddings to {npy_path}")
|
||||||
|
# Prepare ids in same order as passages_file
|
||||||
|
ids: list[str] = []
|
||||||
|
with open(passages_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
rec = json.loads(line)
|
||||||
|
ids.append(str(rec["id"]))
|
||||||
|
if len(ids) != embeddings.shape[0]:
|
||||||
|
raise ValueError(
|
||||||
|
f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
|
||||||
|
)
|
||||||
|
pkl_path = self.data_dir / "clip_image_embeddings.pkl"
|
||||||
|
if not pkl_path.exists():
|
||||||
|
with open(pkl_path, "wb") as pf:
|
||||||
|
pickle.dump((ids, embeddings.astype(np.float32)), pf)
|
||||||
|
print(f"💾 Saved (ids, embeddings) pickle to {pkl_path}")
|
||||||
|
|
||||||
|
# Initialize builder - non-compact without recompute
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name=backend,
|
||||||
|
embedding_model="clip-ViT-L-14",
|
||||||
|
embedding_mode="sentence-transformers",
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
is_recompute=False, # Store embeddings (no recompute needed)
|
||||||
|
is_compact=False, # Store full index (not pruned)
|
||||||
|
distance_metric="cosine",
|
||||||
|
num_threads=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add passages - embeddings will be loaded from file
|
||||||
|
print("📚 Adding passages...")
|
||||||
|
self._add_passages_with_embeddings(builder, passages_file, embeddings)
|
||||||
|
|
||||||
|
print(f"🔨 Building non-compact index at {index_path} from precomputed embeddings...")
|
||||||
|
builder.build_index_from_embeddings(index_path, str(pkl_path))
|
||||||
|
|
||||||
|
build_time = time.time() - start_time
|
||||||
|
print(f"✅ Non-compact index built in {build_time:.2f}s")
|
||||||
|
|
||||||
|
# Analyze index size
|
||||||
|
self._analyze_index_size(index_path)
|
||||||
|
|
||||||
|
return index_path
|
||||||
|
|
||||||
|
def _add_passages_with_embeddings(self, builder, passages_file: Path, embeddings: np.ndarray):
|
||||||
|
"""Helper to add passages with pre-computed CLIP embeddings"""
|
||||||
|
with open(passages_file, encoding="utf-8") as f:
|
||||||
|
for line in tqdm(f, desc="Adding passages"):
|
||||||
|
if line.strip():
|
||||||
|
passage = json.loads(line)
|
||||||
|
|
||||||
|
# Add image metadata - LEANN will handle embeddings separately
|
||||||
|
# Note: We store image metadata and caption text for searchability
|
||||||
|
# Important: ensure passage ID in metadata matches vector ID
|
||||||
|
builder.add_text(
|
||||||
|
text=passage["text"], # Image caption for searchability
|
||||||
|
metadata={**passage["metadata"], "id": passage["id"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
def _analyze_index_size(self, index_path: str):
|
||||||
|
"""Analyze index file sizes"""
|
||||||
|
print("📏 Analyzing index sizes...")
|
||||||
|
|
||||||
|
index_path = Path(index_path)
|
||||||
|
index_dir = index_path.parent
|
||||||
|
index_name = index_path.name # e.g., laion_index.leann
|
||||||
|
index_prefix = index_path.stem # e.g., laion_index
|
||||||
|
|
||||||
|
files = [
|
||||||
|
(f"{index_prefix}.index", ".index", "core"),
|
||||||
|
(f"{index_name}.meta.json", ".meta.json", "core"),
|
||||||
|
(f"{index_name}.ids.txt", ".ids.txt", "core"),
|
||||||
|
(f"{index_name}.passages.jsonl", ".passages.jsonl", "passages"),
|
||||||
|
(f"{index_name}.passages.idx", ".passages.idx", "passages"),
|
||||||
|
]
|
||||||
|
|
||||||
|
def _fmt_size(bytes_val: int) -> str:
|
||||||
|
if bytes_val < 1024:
|
||||||
|
return f"{bytes_val} B"
|
||||||
|
kb = bytes_val / 1024
|
||||||
|
if kb < 1024:
|
||||||
|
return f"{kb:.1f} KB"
|
||||||
|
mb = kb / 1024
|
||||||
|
if mb < 1024:
|
||||||
|
return f"{mb:.2f} MB"
|
||||||
|
gb = mb / 1024
|
||||||
|
return f"{gb:.2f} GB"
|
||||||
|
|
||||||
|
total_index_only_mb = 0.0
|
||||||
|
total_all_mb = 0.0
|
||||||
|
for filename, label, group in files:
|
||||||
|
file_path = index_dir / filename
|
||||||
|
if file_path.exists():
|
||||||
|
size_bytes = file_path.stat().st_size
|
||||||
|
print(f" {label}: {_fmt_size(size_bytes)}")
|
||||||
|
size_mb = size_bytes / (1024 * 1024)
|
||||||
|
total_all_mb += size_mb
|
||||||
|
if group == "core":
|
||||||
|
total_index_only_mb += size_mb
|
||||||
|
else:
|
||||||
|
print(f" {label}: (missing)")
|
||||||
|
print(f" Total (index only, exclude passages): {total_index_only_mb:.2f} MB")
|
||||||
|
print(f" Total (including passages): {total_all_mb:.2f} MB")
|
||||||
|
|
||||||
|
def create_evaluation_queries(self, samples: list[dict], num_queries: int = 200):
|
||||||
|
"""Create evaluation queries from captions"""
|
||||||
|
print(f"📝 Creating {num_queries} evaluation queries...")
|
||||||
|
|
||||||
|
# Sample random captions as queries
|
||||||
|
import random
|
||||||
|
|
||||||
|
random.seed(42) # For reproducibility
|
||||||
|
|
||||||
|
query_samples = random.sample(samples, min(num_queries, len(samples)))
|
||||||
|
|
||||||
|
queries_file = self.data_dir / "evaluation_queries.jsonl"
|
||||||
|
with open(queries_file, "w", encoding="utf-8") as f:
|
||||||
|
for sample in query_samples:
|
||||||
|
query = {
|
||||||
|
"id": sample["id"],
|
||||||
|
"query": sample["caption"],
|
||||||
|
"ground_truth_id": sample["id"], # For potential recall evaluation
|
||||||
|
}
|
||||||
|
f.write(json.dumps(query) + "\n")
|
||||||
|
|
||||||
|
print(f"✅ Created {len(query_samples)} evaluation queries")
|
||||||
|
return queries_file
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Setup LAION Multimodal Benchmark")
|
||||||
|
parser.add_argument("--data-dir", default="data", help="Data directory")
|
||||||
|
parser.add_argument("--num-samples", type=int, default=1000, help="Number of LAION samples")
|
||||||
|
parser.add_argument("--num-queries", type=int, default=50, help="Number of evaluation queries")
|
||||||
|
parser.add_argument("--index-path", default="data/laion_index.leann", help="Output index path")
|
||||||
|
parser.add_argument(
|
||||||
|
"--backend", default="hnsw", choices=["hnsw", "diskann"], help="LEANN backend"
|
||||||
|
)
|
||||||
|
parser.add_argument("--skip-download", action="store_true", help="Skip LAION dataset download")
|
||||||
|
parser.add_argument("--skip-build", action="store_true", help="Skip index building")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
print("🚀 Setting up LAION Multimodal Benchmark")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Initialize setup
|
||||||
|
setup = LAIONSetup(args.data_dir)
|
||||||
|
|
||||||
|
# Step 1: Download LAION subset
|
||||||
|
if not args.skip_download:
|
||||||
|
print("\n📦 Step 1: Download LAION subset")
|
||||||
|
samples = setup.download_laion_subset(args.num_samples)
|
||||||
|
|
||||||
|
# Step 2: Generate CLIP image embeddings
|
||||||
|
print("\n🔍 Step 2: Generate CLIP image embeddings")
|
||||||
|
embeddings, valid_samples = setup.generate_clip_image_embeddings(samples)
|
||||||
|
|
||||||
|
# Step 3: Create LEANN passages (image metadata with embeddings)
|
||||||
|
print("\n📝 Step 3: Create LEANN passages")
|
||||||
|
passages_file = setup.create_leann_passages(valid_samples)
|
||||||
|
else:
|
||||||
|
print("⏭️ Skipping LAION dataset download")
|
||||||
|
# Load existing data
|
||||||
|
passages_file = setup.data_dir / "laion_passages.jsonl"
|
||||||
|
embeddings_file = setup.data_dir / "clip_image_embeddings.npy"
|
||||||
|
|
||||||
|
if not passages_file.exists() or not embeddings_file.exists():
|
||||||
|
raise FileNotFoundError(
|
||||||
|
"Passages or embeddings file not found. Run without --skip-download first."
|
||||||
|
)
|
||||||
|
|
||||||
|
embeddings = np.load(embeddings_file)
|
||||||
|
print(f"📊 Loaded {len(embeddings)} embeddings from {embeddings_file}")
|
||||||
|
|
||||||
|
# Step 4: Build LEANN indexes (both compact and non-compact)
|
||||||
|
if not args.skip_build:
|
||||||
|
print("\n🏗️ Step 4: Build LEANN indexes with CLIP image embeddings")
|
||||||
|
|
||||||
|
# Build compact index (production mode - small, recompute required)
|
||||||
|
compact_index_path = args.index_path
|
||||||
|
print(f"Building compact index: {compact_index_path}")
|
||||||
|
setup.build_compact_index(passages_file, embeddings, compact_index_path, args.backend)
|
||||||
|
|
||||||
|
# Build non-compact index (comparison mode - large, fast search)
|
||||||
|
non_compact_index_path = args.index_path.replace(".leann", "_noncompact.leann")
|
||||||
|
print(f"Building non-compact index: {non_compact_index_path}")
|
||||||
|
setup.build_non_compact_index(
|
||||||
|
passages_file, embeddings, non_compact_index_path, args.backend
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 5: Build FAISS flat baseline
|
||||||
|
print("\n🔨 Step 5: Build FAISS flat baseline")
|
||||||
|
if not args.skip_download:
|
||||||
|
baseline_path = setup.build_faiss_baseline(embeddings, valid_samples)
|
||||||
|
else:
|
||||||
|
# Load valid_samples from passages file for FAISS baseline
|
||||||
|
valid_samples = []
|
||||||
|
with open(passages_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
passage = json.loads(line)
|
||||||
|
valid_samples.append({"id": passage["id"], "caption": passage["text"]})
|
||||||
|
baseline_path = setup.build_faiss_baseline(embeddings, valid_samples)
|
||||||
|
|
||||||
|
# Step 6: Create evaluation queries
|
||||||
|
print("\n📝 Step 6: Create evaluation queries")
|
||||||
|
queries_file = setup.create_evaluation_queries(valid_samples, args.num_queries)
|
||||||
|
else:
|
||||||
|
print("⏭️ Skipping index building")
|
||||||
|
baseline_path = "data/baseline/faiss_index.bin"
|
||||||
|
queries_file = setup.data_dir / "evaluation_queries.jsonl"
|
||||||
|
|
||||||
|
print("\n🎉 Setup completed successfully!")
|
||||||
|
print("📊 Summary:")
|
||||||
|
if not args.skip_download:
|
||||||
|
print(f" Downloaded samples: {len(samples)}")
|
||||||
|
print(f" Valid samples with embeddings: {len(valid_samples)}")
|
||||||
|
else:
|
||||||
|
print(f" Loaded {len(embeddings)} embeddings")
|
||||||
|
|
||||||
|
if not args.skip_build:
|
||||||
|
print(f" Compact index: {compact_index_path}")
|
||||||
|
print(f" Non-compact index: {non_compact_index_path}")
|
||||||
|
print(f" FAISS baseline: {baseline_path}")
|
||||||
|
print(f" Queries: {queries_file}")
|
||||||
|
|
||||||
|
print("\n🔧 Next steps:")
|
||||||
|
print(f" Run evaluation: python evaluate_laion.py --index {compact_index_path}")
|
||||||
|
print(f" Or compare with: python evaluate_laion.py --index {non_compact_index_path}")
|
||||||
|
else:
|
||||||
|
print(" Skipped building indexes")
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n⚠️ Setup interrupted by user")
|
||||||
|
exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ Setup failed: {e}")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
301
benchmarks/llm_utils.py
Normal file
301
benchmarks/llm_utils.py
Normal file
@@ -0,0 +1,301 @@
|
|||||||
|
"""
|
||||||
|
LLM utils for RAG benchmarks with Qwen3-8B and Qwen2.5-VL (multimodal)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
HF_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
HF_AVAILABLE = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
|
||||||
|
VLLM_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
VLLM_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
|
def is_qwen3_model(model_name):
|
||||||
|
"""Check if model is Qwen3"""
|
||||||
|
return "Qwen3" in model_name or "qwen3" in model_name.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def is_qwen_vl_model(model_name):
|
||||||
|
"""Check if model is Qwen2.5-VL"""
|
||||||
|
return "Qwen2.5-VL" in model_name or "qwen2.5-vl" in model_name.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def apply_qwen3_chat_template(tokenizer, prompt):
|
||||||
|
"""Apply Qwen3 chat template with thinking enabled"""
|
||||||
|
messages = [{"role": "user", "content": prompt}]
|
||||||
|
return tokenizer.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
tokenize=False,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
enable_thinking=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_thinking_answer(response):
|
||||||
|
"""Extract final answer from Qwen3 thinking model response"""
|
||||||
|
if "<think>" in response and "</think>" in response:
|
||||||
|
try:
|
||||||
|
think_end = response.index("</think>") + len("</think>")
|
||||||
|
final_answer = response[think_end:].strip()
|
||||||
|
return final_answer
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
return response.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def load_hf_model(model_name="Qwen/Qwen3-8B"):
|
||||||
|
"""Load HuggingFace model"""
|
||||||
|
if not HF_AVAILABLE:
|
||||||
|
raise ImportError("transformers not available")
|
||||||
|
|
||||||
|
print(f"Loading HF: {model_name}")
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
||||||
|
device_map="auto",
|
||||||
|
trust_remote_code=True,
|
||||||
|
)
|
||||||
|
return tokenizer, model
|
||||||
|
|
||||||
|
|
||||||
|
def load_vllm_model(model_name="Qwen/Qwen3-8B"):
|
||||||
|
"""Load vLLM model"""
|
||||||
|
if not VLLM_AVAILABLE:
|
||||||
|
raise ImportError("vllm not available")
|
||||||
|
|
||||||
|
print(f"Loading vLLM: {model_name}")
|
||||||
|
llm = LLM(model=model_name, trust_remote_code=True)
|
||||||
|
|
||||||
|
# Qwen3 specific config
|
||||||
|
if is_qwen3_model(model_name):
|
||||||
|
stop_tokens = ["<|im_end|>", "<|end_of_text|>"]
|
||||||
|
max_tokens = 2048
|
||||||
|
else:
|
||||||
|
stop_tokens = None
|
||||||
|
max_tokens = 1024
|
||||||
|
|
||||||
|
sampling_params = SamplingParams(temperature=0.7, max_tokens=max_tokens, stop=stop_tokens)
|
||||||
|
return llm, sampling_params
|
||||||
|
|
||||||
|
|
||||||
|
def generate_hf(tokenizer, model, prompt, max_tokens=None):
|
||||||
|
"""Generate with HF - supports Qwen3 thinking models"""
|
||||||
|
model_name = getattr(model, "name_or_path", "unknown")
|
||||||
|
is_qwen3 = is_qwen3_model(model_name)
|
||||||
|
|
||||||
|
# Apply chat template for Qwen3
|
||||||
|
if is_qwen3:
|
||||||
|
prompt = apply_qwen3_chat_template(tokenizer, prompt)
|
||||||
|
max_tokens = max_tokens or 2048
|
||||||
|
else:
|
||||||
|
max_tokens = max_tokens or 1024
|
||||||
|
|
||||||
|
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model.generate(
|
||||||
|
**inputs,
|
||||||
|
max_new_tokens=max_tokens,
|
||||||
|
temperature=0.7,
|
||||||
|
do_sample=True,
|
||||||
|
pad_token_id=tokenizer.eos_token_id,
|
||||||
|
)
|
||||||
|
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||||
|
response = response[len(prompt) :].strip()
|
||||||
|
|
||||||
|
# Extract final answer for thinking models
|
||||||
|
if is_qwen3:
|
||||||
|
return extract_thinking_answer(response)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def generate_vllm(llm, sampling_params, prompt):
|
||||||
|
"""Generate with vLLM - supports Qwen3 thinking models"""
|
||||||
|
outputs = llm.generate([prompt], sampling_params)
|
||||||
|
response = outputs[0].outputs[0].text.strip()
|
||||||
|
|
||||||
|
# Extract final answer for Qwen3 thinking models
|
||||||
|
model_name = str(llm.llm_engine.model_config.model)
|
||||||
|
if is_qwen3_model(model_name):
|
||||||
|
return extract_thinking_answer(response)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def create_prompt(context, query, domain="default"):
|
||||||
|
"""Create RAG prompt"""
|
||||||
|
if domain == "emails":
|
||||||
|
return f"Email content:\n{context}\n\nQuestion: {query}\n\nAnswer:"
|
||||||
|
elif domain == "finance":
|
||||||
|
return f"Financial content:\n{context}\n\nQuestion: {query}\n\nAnswer:"
|
||||||
|
elif domain == "multimodal":
|
||||||
|
return f"Image context:\n{context}\n\nQuestion: {query}\n\nAnswer:"
|
||||||
|
else:
|
||||||
|
return f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_rag(searcher, llm_func, queries, domain="default", top_k=3, complexity=64):
|
||||||
|
"""Simple RAG evaluation with timing"""
|
||||||
|
search_times = []
|
||||||
|
gen_times = []
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for i, query in enumerate(queries):
|
||||||
|
# Search
|
||||||
|
start = time.time()
|
||||||
|
docs = searcher.search(query, top_k=top_k, complexity=complexity)
|
||||||
|
search_time = time.time() - start
|
||||||
|
|
||||||
|
# Generate
|
||||||
|
context = "\n\n".join([doc.text for doc in docs])
|
||||||
|
prompt = create_prompt(context, query, domain)
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
response = llm_func(prompt)
|
||||||
|
gen_time = time.time() - start
|
||||||
|
|
||||||
|
search_times.append(search_time)
|
||||||
|
gen_times.append(gen_time)
|
||||||
|
results.append(response)
|
||||||
|
|
||||||
|
if i < 3:
|
||||||
|
print(f"Q{i + 1}: Search={search_time:.3f}s, Gen={gen_time:.3f}s")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"avg_search_time": sum(search_times) / len(search_times),
|
||||||
|
"avg_generation_time": sum(gen_times) / len(gen_times),
|
||||||
|
"results": results,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def load_qwen_vl_model(model_name="Qwen/Qwen2.5-VL-7B-Instruct"):
|
||||||
|
"""Load Qwen2.5-VL multimodal model"""
|
||||||
|
if not HF_AVAILABLE:
|
||||||
|
raise ImportError("transformers not available")
|
||||||
|
|
||||||
|
print(f"Loading Qwen2.5-VL: {model_name}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from transformers import AutoModelForVision2Seq, AutoProcessor
|
||||||
|
|
||||||
|
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
|
||||||
|
model = AutoModelForVision2Seq.from_pretrained(
|
||||||
|
model_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
|
||||||
|
)
|
||||||
|
|
||||||
|
return processor, model
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to load with AutoModelForVision2Seq, trying specific class: {e}")
|
||||||
|
|
||||||
|
# Fallback to specific class
|
||||||
|
try:
|
||||||
|
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
|
||||||
|
|
||||||
|
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
|
||||||
|
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||||
|
model_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
|
||||||
|
)
|
||||||
|
|
||||||
|
return processor, model
|
||||||
|
|
||||||
|
except Exception as e2:
|
||||||
|
raise ImportError(f"Failed to load Qwen2.5-VL model: {e2}")
|
||||||
|
|
||||||
|
|
||||||
|
def generate_qwen_vl(processor, model, prompt, image_path=None, max_tokens=512):
|
||||||
|
"""Generate with Qwen2.5-VL multimodal model"""
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
# Prepare inputs
|
||||||
|
if image_path:
|
||||||
|
image = Image.open(image_path)
|
||||||
|
inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
|
||||||
|
else:
|
||||||
|
inputs = processor(text=prompt, return_tensors="pt").to(model.device)
|
||||||
|
|
||||||
|
# Generate
|
||||||
|
with torch.no_grad():
|
||||||
|
generated_ids = model.generate(
|
||||||
|
**inputs, max_new_tokens=max_tokens, do_sample=False, temperature=0.1
|
||||||
|
)
|
||||||
|
|
||||||
|
# Decode response
|
||||||
|
generated_ids = generated_ids[:, inputs["input_ids"].shape[1] :]
|
||||||
|
response = processor.decode(generated_ids[0], skip_special_tokens=True)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def create_multimodal_prompt(context, query, image_descriptions, task_type="images"):
|
||||||
|
"""Create prompt for multimodal RAG"""
|
||||||
|
if task_type == "images":
|
||||||
|
return f"""Based on the retrieved images and their descriptions, answer the following question.
|
||||||
|
|
||||||
|
Retrieved Image Descriptions:
|
||||||
|
{context}
|
||||||
|
|
||||||
|
Question: {query}
|
||||||
|
|
||||||
|
Provide a detailed answer based on the visual content described above."""
|
||||||
|
|
||||||
|
return f"Context: {context}\nQuestion: {query}\nAnswer:"
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_multimodal_rag(searcher, queries, processor=None, model=None, complexity=64):
|
||||||
|
"""Evaluate multimodal RAG with Qwen2.5-VL"""
|
||||||
|
search_times = []
|
||||||
|
gen_times = []
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for i, query_item in enumerate(queries):
|
||||||
|
# Handle both string and dict formats for queries
|
||||||
|
if isinstance(query_item, dict):
|
||||||
|
query = query_item.get("query", "")
|
||||||
|
image_path = query_item.get("image_path") # Optional reference image
|
||||||
|
else:
|
||||||
|
query = str(query_item)
|
||||||
|
image_path = None
|
||||||
|
|
||||||
|
# Search
|
||||||
|
start_time = time.time()
|
||||||
|
search_results = searcher.search(query, top_k=3, complexity=complexity)
|
||||||
|
search_time = time.time() - start_time
|
||||||
|
search_times.append(search_time)
|
||||||
|
|
||||||
|
# Prepare context from search results
|
||||||
|
context_parts = []
|
||||||
|
for result in search_results:
|
||||||
|
context_parts.append(f"- {result.text}")
|
||||||
|
context = "\n".join(context_parts)
|
||||||
|
|
||||||
|
# Generate with multimodal model
|
||||||
|
start_time = time.time()
|
||||||
|
if processor and model:
|
||||||
|
prompt = create_multimodal_prompt(context, query, context_parts)
|
||||||
|
response = generate_qwen_vl(processor, model, prompt, image_path)
|
||||||
|
else:
|
||||||
|
response = f"Context: {context}"
|
||||||
|
gen_time = time.time() - start_time
|
||||||
|
|
||||||
|
gen_times.append(gen_time)
|
||||||
|
results.append(response)
|
||||||
|
|
||||||
|
if i < 3:
|
||||||
|
print(f"Q{i + 1}: Search={search_time:.3f}s, Gen={gen_time:.3f}s")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"avg_search_time": sum(search_times) / len(search_times),
|
||||||
|
"avg_generation_time": sum(gen_times) / len(gen_times),
|
||||||
|
"results": results,
|
||||||
|
}
|
||||||
@@ -12,7 +12,7 @@ import time
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from leann.api import LeannBuilder, LeannSearcher
|
from leann.api import LeannBuilder, LeannChat, LeannSearcher
|
||||||
|
|
||||||
|
|
||||||
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
||||||
@@ -53,7 +53,7 @@ def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
|||||||
print(
|
print(
|
||||||
"Error: huggingface_hub is not installed. Please install it to download the data:"
|
"Error: huggingface_hub is not installed. Please install it to download the data:"
|
||||||
)
|
)
|
||||||
print("uv pip install -e '.[dev]'")
|
print("uv sync --only-group dev")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"An error occurred during data download: {e}")
|
print(f"An error occurred during data download: {e}")
|
||||||
@@ -197,6 +197,25 @@ def main():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
|
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch-size",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="Batch size for HNSW batched search (0 disables batching)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--llm-type",
|
||||||
|
type=str,
|
||||||
|
choices=["ollama", "hf", "openai", "gemini", "simulated"],
|
||||||
|
default="ollama",
|
||||||
|
help="LLM backend type to optionally query during evaluation (default: ollama)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--llm-model",
|
||||||
|
type=str,
|
||||||
|
default="qwen3:1.7b",
|
||||||
|
help="LLM model identifier for the chosen backend (default: qwen3:1.7b)",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# --- Path Configuration ---
|
# --- Path Configuration ---
|
||||||
@@ -318,9 +337,24 @@ def main():
|
|||||||
|
|
||||||
for i in range(num_eval_queries):
|
for i in range(num_eval_queries):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
new_results = searcher.search(queries[i], top_k=args.top_k, ef=args.ef_search)
|
new_results = searcher.search(
|
||||||
|
queries[i],
|
||||||
|
top_k=args.top_k,
|
||||||
|
complexity=args.ef_search,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
)
|
||||||
search_times.append(time.time() - start_time)
|
search_times.append(time.time() - start_time)
|
||||||
|
|
||||||
|
# Optional: also call the LLM with configurable backend/model (does not affect recall)
|
||||||
|
llm_config = {"type": args.llm_type, "model": args.llm_model}
|
||||||
|
chat = LeannChat(args.index_path, llm_config=llm_config, searcher=searcher)
|
||||||
|
answer = chat.ask(
|
||||||
|
queries[i],
|
||||||
|
top_k=args.top_k,
|
||||||
|
complexity=args.ef_search,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
)
|
||||||
|
print(f"Answer: {answer}")
|
||||||
# Correct Recall Calculation: Based on TEXT content
|
# Correct Recall Calculation: Based on TEXT content
|
||||||
new_texts = {result.text for result in new_results}
|
new_texts = {result.text for result in new_results}
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ except ImportError:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BenchmarkConfig:
|
class BenchmarkConfig:
|
||||||
model_path: str = "facebook/contriever"
|
model_path: str = "facebook/contriever-msmarco"
|
||||||
batch_sizes: list[int] = None
|
batch_sizes: list[int] = None
|
||||||
seq_length: int = 256
|
seq_length: int = 256
|
||||||
num_runs: int = 5
|
num_runs: int = 5
|
||||||
@@ -34,7 +34,7 @@ class BenchmarkConfig:
|
|||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.batch_sizes is None:
|
if self.batch_sizes is None:
|
||||||
self.batch_sizes = [1, 2, 4, 8, 16, 32, 64]
|
self.batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
|
||||||
|
|
||||||
|
|
||||||
class MLXBenchmark:
|
class MLXBenchmark:
|
||||||
@@ -179,10 +179,16 @@ class Benchmark:
|
|||||||
|
|
||||||
def _run_inference(self, input_ids: torch.Tensor) -> float:
|
def _run_inference(self, input_ids: torch.Tensor) -> float:
|
||||||
attention_mask = torch.ones_like(input_ids)
|
attention_mask = torch.ones_like(input_ids)
|
||||||
|
# print shape of input_ids and attention_mask
|
||||||
|
print(f"input_ids shape: {input_ids.shape}")
|
||||||
|
print(f"attention_mask shape: {attention_mask.shape}")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self.model(input_ids=input_ids, attention_mask=attention_mask)
|
self.model(input_ids=input_ids, attention_mask=attention_mask)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
torch.mps.synchronize()
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
|
|
||||||
return end_time - start_time
|
return end_time - start_time
|
||||||
|
|||||||
82
data/huawei_pangu.md
Normal file
82
data/huawei_pangu.md
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
# 盘古之殇:华为诺亚盘古大模型研发历程的心酸与黑暗
|
||||||
|
|
||||||
|
各位好,
|
||||||
|
|
||||||
|
我是一名盘古大模型团队,华为诺亚方舟实验室的员工。
|
||||||
|
|
||||||
|
首先为自证身份,列举一些细节:
|
||||||
|
|
||||||
|
1. 现诺亚主任,前算法应用部部长,后改名为小模型实验室的主任王云鹤。前诺亚主任:姚骏(大家称姚老师)。几个实验室主任:唐睿明(明哥,明队,已离职),尚利峰,张维(维哥),郝建业(郝老师),刘武龙(称呼为武龙所)等。其他骨干成员和专家陆续有很多人离职。
|
||||||
|
2. 我们隶属于“四野”这个组织。四野下属有许多纵队,基础语言大模型是四纵。王云鹤的小模型是十六纵队。我们参加过苏州的集结,有各种月份的时间节点。在苏州攻关会颁发任务令,需要在节点前达成目标。苏州集结会把各地的人员都集中在苏州研究所,平常住宾馆,比如在甪直的酒店,与家人孩子天各一方。
|
||||||
|
3. 在苏州集结的时候周六默认上班,非常辛苦,不过周六有下午茶,有一次还有小龙虾。在苏州研究所的工位搬迁过一次,从一栋楼换到了另一栋。苏州研究所楼栋都是欧式装修,门口有大坡,里面景色很不错。去苏州集结一般至少要去一周,甚至更久,多的人甚至一两个月都回不了家。
|
||||||
|
4. 诺亚曾经传说是研究型的,但是来了之后因为在四野做大模型项目,项目成员完全变成了交付型的,且充满了例会,评审,汇报。很多时候做实验都要申请。团队需要对接终端小艺,华为云,ICT等诸多业务线,交付压力不小。
|
||||||
|
5. 诺亚研发的盘古模型早期内部代号叫做“盘古智子”,一开始只有内部需要申请试用的网页版,到后续迫于压力在welink上接入和公测开放。
|
||||||
|
|
||||||
|
这些天发生关于质疑盘古大模型抄袭千问的事情闹的沸沸扬扬。作为一个盘古团队的成员,我最近夜夜辗转反侧,难以入眠。盘古的品牌受到如此大的影响,一方面,我自私的为我的职业发展担忧,也为自己过去的努力工作感到不值。另一方面,由于有人开始揭露这些事情我内心又感到大快人心。在多少个日日夜夜,我们对内部某些人一次次靠着造假而又获得了无数利益的行为咬牙切齿而又无能为力。这种压抑和羞辱也逐渐消磨了我对华为的感情,让我在这里的时日逐渐浑浑噩噩,迷茫无措,时常怀疑自己的人生和自我价值。
|
||||||
|
|
||||||
|
我承认我是一个懦弱的人,作为一个小小的打工人,我不仅不敢和王云鹤等内部手眼通天的人做对,更不敢和华为这样的庞然大物做对。我很怕失去我的工作,毕竟我也有家人和孩子,所以我打心眼里很佩服揭露者。但是,看到内部还在试图洗地掩盖事实,蒙蔽公众的时候,我实在不能容忍了。我也希望勇敢一次,顺从自己本心。就算自损八百,我也希望能伤敌一千。我决定把我在这里的所见所闻(部分来自于同事口述)公布出来,关于盘古大模型的“传奇故事”:
|
||||||
|
|
||||||
|
华为确实主要在昇腾卡上训练大模型(小模型实验室有不少英伟达的卡,他们之前也会用来训练,后面转移到昇腾)。曾经我被华为“打造世界第二选择”的决心而折服,我本身也曾经对华为有深厚的感情。我们陪着昇腾一步步摸爬滚打,从充满bug到现在能训出模型,付出了巨大的心血和代价。
|
||||||
|
|
||||||
|
最初我们的算力非常有限,在910A上训练模型。那会只支持fp16,训练的稳定性远不如bf16。盘古的moe开始很早,23年就主要是训练38Bmoe模型和后续的71B dense模型。71B的dense模型通过扩增变成了第一代的135Bdense模型,后面主力模型也逐渐在910B上训练。
|
||||||
|
|
||||||
|
71B和135B模型都有一个巨大的硬伤就是tokenizer。当时使用的tokenizer编码效率极低,每个单个的符号,数字,空格,乃至汉字都会占用一个token。可想而知这会非常浪费算力,且使得模型的效果很差。这时候小模型实验室正好有个自己训的词表。姚老师当时怀疑是不是模型的tokenizer不好(虽然事后来看,他的怀疑是无疑正确的),于是就决定,让71B和135B换tokenizer,因为小模型实验室曾经尝试过。团队缝合了两个tokenizer,开始了tokenizer的更换。71B模型的更换失败了,而135B因为采用了更精细的embedding初始化策略,续训了至少1T的数据后词表总算更换成功,但可想而知,效果并不会变好。
|
||||||
|
|
||||||
|
于此同期,阿里和智谱等国内其他公司在GPU上训练,且已经摸索出了正确的方法,盘古和竞品的差距越来越大。内部一个230B从头训练的dense模型又因为各种原因训练失败,导致项目的状况几乎陷入绝境。面临几个节点的压力以及内部对盘古的强烈质疑时,团队的士气低迷到了极点。团队在算力极其有限的时候,做出了很多努力和挣扎。比如,团队偶然发现当时的38B moe并没有预期moe的效果。于是去掉了moe参数,还原为了13B的dense模型。由于38B的moe源自很早的pangu alpha 13B,架构相对落后,团队进行了一系列的操作,比如切换绝对位置编码到rope,去掉bias,切换为rmsnorm。同时鉴于tokenizer的一些失败和换词表的经验,这个模型的词表也更换为了王云鹤的小模型实验室7B模型所使用的词表。后面这个13B模型进行了扩增续训,变成了第二代38B dense模型(在几个月内这个模型都是主要的盘古中档位模型),曾经具有一定的竞争力。但是,由于更大的135B模型架构落后,且更换词表模型损伤巨大(后续分析发现当时更换的缝合词表有更严重的bug),续训后也与千问等当时国内领先模型存在很大差距。这时由于内部的质疑声和领导的压力也越来越大。团队的状态几乎陷入了绝境。
|
||||||
|
|
||||||
|
在这种情况下,王云鹤和他的小模型实验室出手了。他们声称是从旧的135B参数继承改造而来,通过训练短短的几百B数据,各项指标平均提升了十个点左右。实际上,这就是他们套壳应用到大模型的第一次杰作。华为的外行领导内行,使得领导完全对于这种扯淡的事情没有概念,他们只会觉得肯定是有什么算法创新。经过内部的分析,他们实际上是使用Qwen 1.5 110B续训而来,通过加层,扩增ffn维度,添加盘古pi论文的一些机制得来,凑够了大概135B的参数。实际上,旧的135B有107层,而这个模型只有82层,各种配置也都不一样。新的来路不明的135B训练完很多参数的分布也和Qwen 110B几乎一模一样。连模型代码的类名当时都是Qwen,甚至懒得改名。后续这个模型就是所谓的135B V2。而这个模型当时也提供给了很多下游,甚至包括外部客户。
|
||||||
|
|
||||||
|
这件事对于我们这些认真诚实做事的同事们带来了巨大的冲击,内部很多人其实都知道这件事,甚至包括终端和华为云。我们都戏称以后别叫盘古模型了,叫千古吧。当时团队成员就想向bcg举报了,毕竟这已经是重大的业务造假了。但是后面据说被领导拦了下来,因为更高级别的领导(比如姚老师,以及可能熊总和查老)其实后面也知道了,但是并不管,因为通过套壳拿出好的结果,对他们也是有利的。这件事使得当时团队几位最强的同事开始心灰意冷,离职跑路也逐渐成为挂在嘴边的事。
|
||||||
|
|
||||||
|
此时,盘古似乎迎来了转机。由于前面所述的这些盘古模型基本都是续训和改造而来,当时诺亚完全没有掌握从头训练的技术,何况还是在昇腾的NPU上进行训练。在当时团队的核心成员的极力争取下,盘古开始了第三代模型的训练,付出了巨大的努力后,在数据架构和训练算法方面都与业界逐渐接轨,而这其中的艰辛和小模型实验室的人一点关系都没有。
|
||||||
|
|
||||||
|
一开始团队成员毫无信心,只从一个13B的模型开始训练,但是后面发现效果还不错,于是这个模型后续再次进行了一次参数扩增,变成了第三代的38B,代号38B V3。想必很多产品线的兄弟都对这个模型很熟悉。当时这个模型的tokenizer是基于llama的词表进行扩展的(也是业界常见的做法)。而当时王云鹤的实验室做出来了另一个词表(也就是后续pangu系列的词表)。当时两个词表还被迫进行了一次赛马,最终没有明显的好坏结论。于是,领导当即决定,应该统一词表,使用王云鹤他们的。于是,在后续从头训练的135B V3(也就是对外的Pangu Ultra),便是采用了这个tokenizer。这也解释了很多使用我们模型的兄弟的疑惑,为什么当时同为V3代的两个不同档位的模型,会使用不同的tokenizer。
|
||||||
|
|
||||||
|
|
||||||
|
我们打心眼里觉得,135B V3是我们四纵团队当时的骄傲。这是第一个真正意义上的,华为全栈自研,正经从头训练的千亿级别的模型,且效果与24年同期竞品可比的。写到这里我已经热泪盈眶,太不容易了。当时为了稳定训练,团队做了大量实验对比,并且多次在模型梯度出现异常的时候进行及时回退重启。这个模型真正做到了后面技术报告所说的训练全程没有一个loss spike。我们克服了不知道多少困难,我们做到了,我们愿用生命和荣誉保证这个模型训练的真实性。多少个凌晨,我们为了它的训练而不眠。在被内部心声骂的一文不值的时候,我们有多么不甘,有多少的委屈,我们挺住了。
|
||||||
|
|
||||||
|
我们这帮人是真的在为打磨国产算力底座燃烧自己的青春啊……客居他乡,我们放弃了家庭,放弃了假期,放弃了健康,放弃了娱乐,抛头颅洒热血,其中的艰辛与困苦,寥寥数笔不足以概括其万一。在各种动员大会上,当时口号中喊出的盘古必胜,华为必胜,我们心里是真的深深被感动。
|
||||||
|
|
||||||
|
然而,我们的所有辛苦的成果,经常被小模型实验室轻飘飘的拿走了。数据,直接要走。代码,直接要走,还要求我们配合适配到能一键运行。我们当时戏称小模型实验室为点鼠标实验室。我们付出辛苦,他们取得荣耀。果然应了那句话,你在负重前行是因为有人替你岁月静好。在这种情况下,越来越多的战友再也坚持不下去了,选择了离开。看到身边那些优秀的同事一个个离职,我的内心又感叹又难过。在这种作战一样的环境下,我们比起同事来说更像是战友。他们在技术上也有无数值得我学习的地方,堪称良师。看到他们去了诸如字节Seed,Deepseek,月之暗面,腾讯和快手等等很多出色的团队,我打心眼里为他们高兴和祝福,脱离了这个辛苦却肮脏的地方。我至今还对一位离职同事的话记忆犹新,ta说:“来这里是我技术生涯中的耻辱,在这里再呆每一天都是浪费生命”。话虽难听却让我无言以对。我担心我自己技术方面的积累不足,以及没法适应互联网公司高淘汰的环境,让我多次想离职的心始终没有迈出这一步。
|
||||||
|
|
||||||
|
盘古除了dense模型,后续也启动了moe的探索。一开始训练的是一个224B的moe模型。而与之平行的,小模型实验室也开启了第二次主要的套壳行动(次要的插曲可能还包括一些别的模型,比如math模型),即这次流传甚广的pangu pro moe 72B。这个模型内部自称是从小模型实验室的7B扩增上来的(就算如此,这也与技术报告不符,何况是套壳qwen 2.5的14b续训)。还记得他们训了没几天,内部的评测就立刻追上了当时的38B V3。AI系统实验室很多兄弟因为需要适配模型,都知道他们的套壳行动,只是迫于各种原因,无法伸张正义。实际上,对于后续训了很久很久的这个模型,Honestagi能够分析出这个量级的相似性我已经很诧异了,因为这个模型为了续训洗参数,所付出的算力甚至早就足够从头训一个同档位的模型了。听同事说他们为了洗掉千问的水印,采取了不少办法,甚至包括故意训了脏数据。这也为学术界研究模型血缘提供了一个前所未有的特殊模范吧。以后新的血缘方法提出可以拿出来溜溜。
|
||||||
|
|
||||||
|
24年底和25年初,在Deepseek v3和r1发布之后,由于其惊艳的技术水平,团队受到了巨大的冲击,也受到了更大的质疑。于是为了紧跟潮流,盘古模仿Deepseek的模型尺寸,开启了718B moe的训练。这个时候,小模型实验室再次出手了。他们选择了套壳Deepseekv3续训。他们通过冻住Deepseek加载的参数,进行训练。连任务加载ckpt的目录都是deepseekv3,改都不改,何其嚣张?与之相反,一些有真正技术信仰的同事,在从头训练另一个718B的moe。但其中出现了各种各样的问题。但是很显然,这个模型怎么可能比直接套壳的好呢?如果不是团队leader坚持,早就被叫停了。
|
||||||
|
|
||||||
|
华为的流程管理之繁重,严重拖累了大模型的研发节奏,例如版本管理,模型血缘,各种流程化,各种可追溯。讽刺的是,小模型实验室的模型似乎从来不受这些流程的约束,想套壳就套壳,想续训就续训,算力源源不断的伸手拿走。这种强烈到近乎魔幻的对比,说明了当前流程管理的情况:只许州官放火,不许百姓点灯。何其可笑?何其可悲?何其可恶?何其可耻!
|
||||||
|
|
||||||
|
HonestAGI的事情出来后,内部让大家不停的研讨分析,如何公关和“回应”。诚然,这个原文的分析也许不够有力,给了王云鹤与小模型实验室他们狡辩和颠倒黑白的机会。为此,这两天我内心感到作呕,时时怀疑自己的人生意义以及苍天无眼。我不奉陪了,我要离职了,同时我也在申请从盘古部分技术报告的作者名单中移除。曾经在这些技术报告上署名是我一生都无法抹除的污点。当时我没想到,他们竟然猖狂到敢开源。我没想到,他们敢如此愚弄世人,大肆宣发。当时,我也许是存了侥幸心理,没有拒绝署名。我相信很多扎实做事的战友,也只是被迫上了贼船,或者不知情。但这件事已经无法挽回,我希望我的余生能够坚持扎实做真正有意义的事,为我当时的软弱和不坚定赎罪。
|
||||||
|
|
||||||
|
深夜写到这里,我已经泪流满面,泣不成声。还记得一些出色的同事离职时,我苦笑问他们要不要发个长长的心声惯例帖,揭露一下现状。对方说:不了,浪费时间,而且我也怕揭露出来你们过的更糟。我当时一下黯然神伤,因为曾经共同为了理想奋斗过的战友已经彻底对华为彻底灰心了。当时大家调侃,我们用着当年共产党的小米加步枪,组织却有着堪比当年国民党的作风。
|
||||||
|
|
||||||
|
曾几何时,我为我们用着小米加步枪打败洋枪洋炮而自豪。
|
||||||
|
|
||||||
|
现在,我累了,我想投降。
|
||||||
|
|
||||||
|
其实时至今日,我还是真心希望华为能认真吸取教训,能做好盘古,把盘古做到世界一流,把昇腾变成英伟达的水平。内部的劣币驱逐良币,使得诺亚乃至华为在短时间内急剧流失了大量出色的大模型人才。相信他们也正在如Deepseek等各个团队闪耀着,施展着他们的抱负才华,为中美在AI的激烈竞赛中奉献力量。我时常感叹,华为不是没有人才,而是根本不知道怎么留住人才。如果给这些人合适的环境,合适的资源,更少的枷锁,更少的政治斗争,盘古何愁不成?
|
||||||
|
|
||||||
|
最后:我以生命,人格和荣誉发誓,我写的以上所有内容均为真实(至少在我有限的认知范围内)。我没有那么高的技术水平以及机会去做详尽扎实的分析,也不敢直接用内部记录举证,怕因为信息安全抓到。但是我相信我很多曾经的战友,会为我作证。在华为内部的兄弟,包括我们曾经服务过的产品线兄弟们,相信本文的无数细节能和你们的印象对照,印证我的说法。你们可能也曾经被蒙骗,但这些残酷的真相不会被尘封。我们奋战过的痕迹,也不应该被扭曲和埋葬。
|
||||||
|
|
||||||
|
写了这么多,某些人肯定想把我找出来,抹杀掉。公司搞不好也想让我噤声乃至追责。如果真的这样,我,乃至我的家人的人身乃至生命安全可能都会受到威胁。为了自我保护,我近期每天会跟大家报平安。
|
||||||
|
|
||||||
|
如果我消失了,就当是我为了真理和理想,为了华为乃至中国能够更好地发展算力和AI而牺牲了吧,我愿埋葬于那片曾经奋斗过的地方。
|
||||||
|
|
||||||
|
诺亚,再见
|
||||||
|
|
||||||
|
2025年7月6日凌晨 写于深圳
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
各位好,
|
||||||
|
|
||||||
|
感谢大家的关心与祝福。我目前暂时安全,但公司应该在进行排查与某些名单收集,后续情况未知。
|
||||||
|
|
||||||
|
我补充一些细节,以免某些人继续颠倒黑白。
|
||||||
|
|
||||||
|
关于135B V2,小模型实验室在迅速地完成套壳并拿完所有套壳带来的好处后(比如任务令表彰和及时激励),因为不想继续支撑下游应用和模型迭代,又把这个烫手山芋甩给了四纵。确实技高一筹,直接把四纵的兄弟们拉下水。同事提供过去一个老旧的模型,最终拿回了一个当时一个魔改的先进的千问。做大模型的人,自己做的模型就像自己孩子一样熟悉,不要把别人都当傻子。就像自家儿子出门一趟,回来个别人家孩子。
|
||||||
|
|
||||||
|
盘古report的署名是不符合学术规范的。例如,135B V3有不少有技术贡献的人,因为作者名额数量限制,劳动成果没有得到应有的回报,团队内曾经有不小的意见。这个模型当时是大家智慧和汗水的结晶,甚至是团队当时的精神支柱,支撑着不少兄弟们继续留在诺亚。所谓的名额限制,以及挂名了一些毫无技术贡献的人(如一些小模型实验室的人),让兄弟们何其心寒。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
暂时平安。另外,支持我勇于说出真相的战友们 https://github.com/HW-whistleblower/True-Story-of-Pangu/issues/317
|
||||||
@@ -53,9 +53,9 @@ We use pre-commit hooks to ensure code quality and consistency. This runs automa
|
|||||||
|
|
||||||
### Setup Pre-commit
|
### Setup Pre-commit
|
||||||
|
|
||||||
1. **Install pre-commit** (already included when you run `uv sync`):
|
1. **Install pre-commit tools**:
|
||||||
```bash
|
```bash
|
||||||
uv pip install pre-commit
|
uv sync lint
|
||||||
```
|
```
|
||||||
|
|
||||||
2. **Install the git hooks**:
|
2. **Install the git hooks**:
|
||||||
@@ -65,7 +65,7 @@ We use pre-commit hooks to ensure code quality and consistency. This runs automa
|
|||||||
|
|
||||||
3. **Run pre-commit manually** (optional):
|
3. **Run pre-commit manually** (optional):
|
||||||
```bash
|
```bash
|
||||||
pre-commit run --all-files
|
uv run pre-commit run --all-files
|
||||||
```
|
```
|
||||||
|
|
||||||
### Pre-commit Checks
|
### Pre-commit Checks
|
||||||
@@ -85,6 +85,9 @@ Our pre-commit configuration includes:
|
|||||||
### Running Tests
|
### Running Tests
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
# Install test tools only (no project runtime)
|
||||||
|
uv sync --group test
|
||||||
|
|
||||||
# Run all tests
|
# Run all tests
|
||||||
uv run pytest
|
uv run pytest
|
||||||
|
|
||||||
|
|||||||
143
docs/ast_chunking_guide.md
Normal file
143
docs/ast_chunking_guide.md
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
# AST-Aware Code chunking guide
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
This guide covers best practices for using AST-aware code chunking in LEANN. AST chunking provides better semantic understanding of code structure compared to traditional text-based chunking.
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### Basic Usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Enable AST chunking for mixed content (code + docs)
|
||||||
|
python -m apps.document_rag --enable-code-chunking --data-dir ./my_project
|
||||||
|
|
||||||
|
# Specialized code repository indexing
|
||||||
|
python -m apps.code_rag --repo-dir ./my_codebase
|
||||||
|
|
||||||
|
# Global CLI with AST support
|
||||||
|
leann build my-code-index --docs ./src --use-ast-chunking
|
||||||
|
```
|
||||||
|
|
||||||
|
### Installation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Install LEANN with AST chunking support
|
||||||
|
uv pip install -e "."
|
||||||
|
```
|
||||||
|
|
||||||
|
#### For normal users (PyPI install)
|
||||||
|
- Use `pip install leann` or `uv pip install leann`.
|
||||||
|
- `astchunk` is pulled automatically from PyPI as a dependency; no extra steps.
|
||||||
|
|
||||||
|
#### For developers (from source, editable)
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/yichuan-w/LEANN.git leann
|
||||||
|
cd leann
|
||||||
|
git submodule update --init --recursive
|
||||||
|
uv sync
|
||||||
|
```
|
||||||
|
- This repo vendors `astchunk` as a git submodule at `packages/astchunk-leann` (our fork).
|
||||||
|
- `[tool.uv.sources]` maps the `astchunk` package to that path in editable mode.
|
||||||
|
- You can edit code under `packages/astchunk-leann` and Python will use your changes immediately (no separate `pip install astchunk` needed).
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
### When to Use AST Chunking
|
||||||
|
|
||||||
|
✅ **Recommended for:**
|
||||||
|
- Code repositories with multiple languages
|
||||||
|
- Mixed documentation and code content
|
||||||
|
- Complex codebases with deep function/class hierarchies
|
||||||
|
- When working with Claude Code for code assistance
|
||||||
|
|
||||||
|
❌ **Not recommended for:**
|
||||||
|
- Pure text documents
|
||||||
|
- Very large files (>1MB)
|
||||||
|
- Languages not supported by tree-sitter
|
||||||
|
|
||||||
|
### Optimal Configuration
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Recommended settings for most codebases
|
||||||
|
python -m apps.code_rag \
|
||||||
|
--repo-dir ./src \
|
||||||
|
--ast-chunk-size 768 \
|
||||||
|
--ast-chunk-overlap 96 \
|
||||||
|
--exclude-dirs .git __pycache__ node_modules build dist
|
||||||
|
```
|
||||||
|
|
||||||
|
### Supported Languages
|
||||||
|
|
||||||
|
| Extension | Language | Status |
|
||||||
|
|-----------|----------|--------|
|
||||||
|
| `.py` | Python | ✅ Full support |
|
||||||
|
| `.java` | Java | ✅ Full support |
|
||||||
|
| `.cs` | C# | ✅ Full support |
|
||||||
|
| `.ts`, `.tsx` | TypeScript | ✅ Full support |
|
||||||
|
| `.js`, `.jsx` | JavaScript | ✅ Via TypeScript parser |
|
||||||
|
|
||||||
|
## Integration Examples
|
||||||
|
|
||||||
|
### Document RAG with Code Support
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Enable code chunking in document RAG
|
||||||
|
python -m apps.document_rag \
|
||||||
|
--enable-code-chunking \
|
||||||
|
--data-dir ./project \
|
||||||
|
--query "How does authentication work in the codebase?"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Claude Code Integration
|
||||||
|
|
||||||
|
When using with Claude Code MCP server, AST chunking provides better context for:
|
||||||
|
- Code completion and suggestions
|
||||||
|
- Bug analysis and debugging
|
||||||
|
- Architecture understanding
|
||||||
|
- Refactoring assistance
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Common Issues
|
||||||
|
|
||||||
|
1. **Fallback to Traditional Chunking**
|
||||||
|
- Normal behavior for unsupported languages
|
||||||
|
- Check logs for specific language support
|
||||||
|
|
||||||
|
2. **Performance with Large Files**
|
||||||
|
- Adjust `--max-file-size` parameter
|
||||||
|
- Use `--exclude-dirs` to skip unnecessary directories
|
||||||
|
|
||||||
|
3. **Quality Issues**
|
||||||
|
- Try different `--ast-chunk-size` values (512, 768, 1024)
|
||||||
|
- Adjust overlap for better context preservation
|
||||||
|
|
||||||
|
### Debug Mode
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export LEANN_LOG_LEVEL=DEBUG
|
||||||
|
python -m apps.code_rag --repo-dir ./my_code
|
||||||
|
```
|
||||||
|
|
||||||
|
## Migration from Traditional Chunking
|
||||||
|
|
||||||
|
Existing workflows continue to work without changes. To enable AST chunking:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Before
|
||||||
|
python -m apps.document_rag --chunk-size 256
|
||||||
|
|
||||||
|
# After (maintains traditional chunking for non-code files)
|
||||||
|
python -m apps.document_rag --enable-code-chunking --chunk-size 256 --ast-chunk-size 768
|
||||||
|
```
|
||||||
|
|
||||||
|
## References
|
||||||
|
|
||||||
|
- [astchunk GitHub Repository](https://github.com/yilinjz/astchunk)
|
||||||
|
- [LEANN MCP Integration](../packages/leann-mcp/README.md)
|
||||||
|
- [Research Paper](https://arxiv.org/html/2506.15655v1)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**Note**: AST chunking maintains full backward compatibility while enhancing code understanding capabilities.
|
||||||
@@ -49,14 +49,25 @@ Based on our experience developing LEANN, embedding models fall into three categ
|
|||||||
- **Cons**: Slower inference, longer index build times
|
- **Cons**: Slower inference, longer index build times
|
||||||
- **Use when**: Quality is paramount and you have sufficient compute resources. **Highly recommended** for production use
|
- **Use when**: Quality is paramount and you have sufficient compute resources. **Highly recommended** for production use
|
||||||
|
|
||||||
### Quick Start: OpenAI Embeddings (Fastest Setup)
|
### Quick Start: Cloud and Local Embedding Options
|
||||||
|
|
||||||
For immediate testing without local model downloads:
|
**OpenAI Embeddings (Fastest Setup)**
|
||||||
|
For immediate testing without local model downloads(also if you [do not have GPU](https://github.com/yichuan-w/LEANN/issues/43) and do not care that much about your document leak, you should use this, we compute the embedding and recompute using openai API):
|
||||||
```bash
|
```bash
|
||||||
# Set OpenAI embeddings (requires OPENAI_API_KEY)
|
# Set OpenAI embeddings (requires OPENAI_API_KEY)
|
||||||
--embedding-mode openai --embedding-model text-embedding-3-small
|
--embedding-mode openai --embedding-model text-embedding-3-small
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**Ollama Embeddings (Privacy-Focused)**
|
||||||
|
For local embeddings with complete privacy:
|
||||||
|
```bash
|
||||||
|
# First, pull an embedding model
|
||||||
|
ollama pull nomic-embed-text
|
||||||
|
|
||||||
|
# Use Ollama embeddings
|
||||||
|
--embedding-mode ollama --embedding-model nomic-embed-text
|
||||||
|
```
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary><strong>Cloud vs Local Trade-offs</strong></summary>
|
<summary><strong>Cloud vs Local Trade-offs</strong></summary>
|
||||||
|
|
||||||
@@ -72,6 +83,81 @@ For immediate testing without local model downloads:
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
## Local & Remote Inference Endpoints
|
||||||
|
|
||||||
|
> Applies to both LLMs (`leann ask`) and embeddings (`leann build`).
|
||||||
|
|
||||||
|
LEANN now treats Ollama, LM Studio, and other OpenAI-compatible runtimes as first-class providers. You can point LEANN at any compatible endpoint – either on the same machine or across the network – with a couple of flags or environment variables.
|
||||||
|
|
||||||
|
### One-Time Environment Setup
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Works for OpenAI-compatible runtimes such as LM Studio, vLLM, SGLang, llamafile, etc.
|
||||||
|
export OPENAI_API_KEY="your-key" # or leave unset for local servers that do not check keys
|
||||||
|
export OPENAI_BASE_URL="http://localhost:1234/v1"
|
||||||
|
|
||||||
|
# Ollama-compatible runtimes (Ollama, Ollama on another host, llamacpp-server, etc.)
|
||||||
|
export LEANN_OLLAMA_HOST="http://localhost:11434" # falls back to OLLAMA_HOST or LOCAL_LLM_ENDPOINT
|
||||||
|
```
|
||||||
|
|
||||||
|
LEANN also recognises `LEANN_LOCAL_LLM_HOST` (highest priority), `LEANN_OPENAI_BASE_URL`, and `LOCAL_OPENAI_BASE_URL`, so existing scripts continue to work.
|
||||||
|
|
||||||
|
### Passing Hosts Per Command
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Build an index with a remote embedding server
|
||||||
|
leann build my-notes \
|
||||||
|
--docs ./notes \
|
||||||
|
--embedding-mode openai \
|
||||||
|
--embedding-model text-embedding-qwen3-embedding-0.6b \
|
||||||
|
--embedding-api-base http://192.168.1.50:1234/v1 \
|
||||||
|
--embedding-api-key local-dev-key
|
||||||
|
|
||||||
|
# Query using a local LM Studio instance via OpenAI-compatible API
|
||||||
|
leann ask my-notes \
|
||||||
|
--llm openai \
|
||||||
|
--llm-model qwen3-8b \
|
||||||
|
--api-base http://localhost:1234/v1 \
|
||||||
|
--api-key local-dev-key
|
||||||
|
|
||||||
|
# Query an Ollama instance running on another box
|
||||||
|
leann ask my-notes \
|
||||||
|
--llm ollama \
|
||||||
|
--llm-model qwen3:14b \
|
||||||
|
--host http://192.168.1.101:11434
|
||||||
|
```
|
||||||
|
|
||||||
|
⚠️ **Make sure the endpoint is reachable**: when your inference server runs on a home/workstation and the index/search job runs in the cloud, the server must be able to reach the host you configured. Typical options include:
|
||||||
|
|
||||||
|
- Expose a public IP (and open the relevant port) on the machine that hosts LM Studio/Ollama.
|
||||||
|
- Configure router or cloud provider port forwarding.
|
||||||
|
- Tunnel traffic through tools like `tailscale`, `cloudflared`, or `ssh -R`.
|
||||||
|
|
||||||
|
When you set these options while building an index, LEANN stores them in `meta.json`. Any subsequent `leann ask` or searcher process automatically reuses the same provider settings – even when we spawn background embedding servers. This makes the “server without GPU talking to my local workstation” workflow from [issue #80](https://github.com/yichuan-w/LEANN/issues/80#issuecomment-2287230548) work out-of-the-box.
|
||||||
|
|
||||||
|
**Tip:** If your runtime does not require an API key (many local stacks don’t), leave `--api-key` unset. LEANN will skip injecting credentials.
|
||||||
|
|
||||||
|
### Python API Usage
|
||||||
|
|
||||||
|
You can pass the same configuration from Python:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from leann.api import LeannBuilder
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_mode="openai",
|
||||||
|
embedding_model="text-embedding-qwen3-embedding-0.6b",
|
||||||
|
embedding_options={
|
||||||
|
"base_url": "http://192.168.1.50:1234/v1",
|
||||||
|
"api_key": "local-dev-key",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
builder.build_index("./indexes/my-notes", chunks)
|
||||||
|
```
|
||||||
|
|
||||||
|
`embedding_options` is persisted to the index `meta.json`, so subsequent `LeannSearcher` or `LeannChat` sessions automatically reuse the same provider settings (the embedding server manager forwards them to the provider for you).
|
||||||
|
|
||||||
## Index Selection: Matching Your Scale
|
## Index Selection: Matching Your Scale
|
||||||
|
|
||||||
### HNSW (Hierarchical Navigable Small World)
|
### HNSW (Hierarchical Navigable Small World)
|
||||||
@@ -86,16 +172,24 @@ For immediate testing without local model downloads:
|
|||||||
```
|
```
|
||||||
|
|
||||||
### DiskANN
|
### DiskANN
|
||||||
**Best for**: Large datasets (> 10M vectors, 10GB+ index size) - **⚠️ Beta version, still in active development**
|
**Best for**: Large datasets, especially when you want `recompute=True`.
|
||||||
- Uses Product Quantization (PQ) for coarse filtering during graph traversal
|
|
||||||
- Novel approach: stores only PQ codes, performs rerank with exact computation in final step
|
**Key advantages:**
|
||||||
- Implements a corner case of double-queue: prunes all neighbors and recomputes at the end
|
- **Faster search** on large datasets (3x+ speedup vs HNSW in many cases)
|
||||||
|
- **Smart storage**: `recompute=True` enables automatic graph partitioning for smaller indexes
|
||||||
|
- **Better scaling**: Designed for 100k+ documents
|
||||||
|
|
||||||
|
**Recompute behavior:**
|
||||||
|
- `recompute=True` (recommended): Pure PQ traversal + final reranking - faster and enables partitioning
|
||||||
|
- `recompute=False`: PQ + partial real distances during traversal - slower but higher accuracy
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# For billion-scale deployments
|
# Recommended for most use cases
|
||||||
--backend-name diskann --graph-degree 64 --build-complexity 128
|
--backend-name diskann --graph-degree 32 --build-complexity 64
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**Performance Benchmark**: Run `uv run benchmarks/diskann_vs_hnsw_speed_comparison.py` to compare DiskANN and HNSW on your system.
|
||||||
|
|
||||||
## LLM Selection: Engine and Model Comparison
|
## LLM Selection: Engine and Model Comparison
|
||||||
|
|
||||||
### LLM Engines
|
### LLM Engines
|
||||||
@@ -211,9 +305,15 @@ python apps/document_rag.py --query "What are the main techniques LEANN explores
|
|||||||
|
|
||||||
3. **Use MLX on Apple Silicon** (optional optimization):
|
3. **Use MLX on Apple Silicon** (optional optimization):
|
||||||
```bash
|
```bash
|
||||||
--embedding-mode mlx --embedding-model mlx-community/multilingual-e5-base-mlx
|
--embedding-mode mlx --embedding-model mlx-community/Qwen3-Embedding-0.6B-8bit
|
||||||
```
|
```
|
||||||
|
MLX might not be the best choice, as we tested and found that it only offers 1.3x acceleration compared to HF, so maybe using ollama is a better choice for embedding generation
|
||||||
|
|
||||||
|
4. **Use Ollama**
|
||||||
|
```bash
|
||||||
|
--embedding-mode ollama --embedding-model nomic-embed-text
|
||||||
|
```
|
||||||
|
To discover additional embedding models in ollama, check out https://ollama.com/search?c=embedding or read more about embedding models at https://ollama.com/blog/embedding-models, please do check the model size that works best for you
|
||||||
### If Search Quality is Poor
|
### If Search Quality is Poor
|
||||||
|
|
||||||
1. **Increase retrieval count**:
|
1. **Increase retrieval count**:
|
||||||
@@ -242,27 +342,118 @@ Every configuration choice involves trade-offs:
|
|||||||
|
|
||||||
The key is finding the right balance for your specific use case. Start small and simple, measure performance, then scale up only where needed.
|
The key is finding the right balance for your specific use case. Start small and simple, measure performance, then scale up only where needed.
|
||||||
|
|
||||||
## Deep Dive: Critical Configuration Decisions
|
## Low-resource setups
|
||||||
|
|
||||||
### When to Disable Recomputation
|
If you don’t have a local GPU or builds/searches are too slow, use one or more of the options below.
|
||||||
|
|
||||||
LEANN's recomputation feature provides exact distance calculations but can be disabled for extreme QPS requirements:
|
### 1) Use OpenAI embeddings (no local compute)
|
||||||
|
|
||||||
|
Fastest path with zero local GPU requirements. Set your API key and use OpenAI embeddings during build and search:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
--no-recompute # Disable selective recomputation
|
export OPENAI_API_KEY=sk-...
|
||||||
|
|
||||||
|
# Build with OpenAI embeddings
|
||||||
|
leann build my-index \
|
||||||
|
--embedding-mode openai \
|
||||||
|
--embedding-model text-embedding-3-small
|
||||||
|
|
||||||
|
# Search with OpenAI embeddings (recompute at query time)
|
||||||
|
leann search my-index "your query" \
|
||||||
|
--recompute
|
||||||
```
|
```
|
||||||
|
|
||||||
**Trade-offs**:
|
### 2) Run remote builds with SkyPilot (cloud GPU)
|
||||||
- **With recomputation** (default): Exact distances, best quality, higher latency, minimal storage (only stores metadata, recomputes embeddings on-demand)
|
|
||||||
- **Without recomputation**: Must store full embeddings, significantly higher memory and storage usage (10-100x more), but faster search
|
Offload embedding generation and index building to a GPU VM using [SkyPilot](https://skypilot.readthedocs.io/en/latest/). A template is provided at `sky/leann-build.yaml`.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# One-time: install and configure SkyPilot
|
||||||
|
pip install skypilot
|
||||||
|
|
||||||
|
# Launch with defaults (L4:1) and mount ./data to ~/leann-data; the build runs automatically
|
||||||
|
sky launch -c leann-gpu sky/leann-build.yaml
|
||||||
|
|
||||||
|
# Override parameters via -e key=value (optional)
|
||||||
|
sky launch -c leann-gpu sky/leann-build.yaml \
|
||||||
|
-e index_name=my-index \
|
||||||
|
-e backend=hnsw \
|
||||||
|
-e embedding_mode=sentence-transformers \
|
||||||
|
-e embedding_model=Qwen/Qwen3-Embedding-0.6B
|
||||||
|
|
||||||
|
# Copy the built index back to your local .leann (use rsync)
|
||||||
|
rsync -Pavz leann-gpu:~/.leann/indexes/my-index ./.leann/indexes/
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3) Disable recomputation to trade storage for speed
|
||||||
|
|
||||||
|
If you need lower latency and have more storage/memory, disable recomputation. This stores full embeddings and avoids recomputing at search time.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Build without recomputation (HNSW requires non-compact in this mode)
|
||||||
|
leann build my-index --no-recompute --no-compact
|
||||||
|
|
||||||
|
# Search without recomputation
|
||||||
|
leann search my-index "your query" --no-recompute
|
||||||
|
```
|
||||||
|
|
||||||
|
When to use:
|
||||||
|
- Extreme low latency requirements (high QPS, interactive assistants)
|
||||||
|
- Read-heavy workloads where storage is cheaper than latency
|
||||||
|
- No always-available GPU
|
||||||
|
|
||||||
|
Constraints:
|
||||||
|
- HNSW: when `--no-recompute` is set, LEANN automatically disables compact mode during build
|
||||||
|
- DiskANN: supported; `--no-recompute` skips selective recompute during search
|
||||||
|
|
||||||
|
Storage impact:
|
||||||
|
- Storing N embeddings of dimension D with float32 requires approximately N × D × 4 bytes
|
||||||
|
- Example: 1,000,000 chunks × 768 dims × 4 bytes ≈ 2.86 GB (plus graph/metadata)
|
||||||
|
|
||||||
|
Converting an existing index (rebuild required):
|
||||||
|
```bash
|
||||||
|
# Rebuild in-place (ensure you still have original docs or can regenerate chunks)
|
||||||
|
leann build my-index --force --no-recompute --no-compact
|
||||||
|
```
|
||||||
|
|
||||||
|
Python API usage:
|
||||||
|
```python
|
||||||
|
from leann import LeannSearcher
|
||||||
|
|
||||||
|
searcher = LeannSearcher("/path/to/my-index.leann")
|
||||||
|
results = searcher.search("your query", top_k=10, recompute_embeddings=False)
|
||||||
|
```
|
||||||
|
|
||||||
|
Trade-offs:
|
||||||
|
- Lower latency and fewer network hops at query time
|
||||||
|
- Significantly higher storage (10–100× vs selective recomputation)
|
||||||
|
- Slightly larger memory footprint during build and search
|
||||||
|
|
||||||
|
Quick benchmark results (`benchmarks/benchmark_no_recompute.py` with 5k texts, complexity=32):
|
||||||
|
|
||||||
|
- HNSW
|
||||||
|
|
||||||
|
```text
|
||||||
|
recompute=True: search_time=0.818s, size=1.1MB
|
||||||
|
recompute=False: search_time=0.012s, size=16.6MB
|
||||||
|
```
|
||||||
|
|
||||||
|
- DiskANN
|
||||||
|
|
||||||
|
```text
|
||||||
|
recompute=True: search_time=0.041s, size=5.9MB
|
||||||
|
recompute=False: search_time=0.013s, size=24.6MB
|
||||||
|
```
|
||||||
|
|
||||||
|
Conclusion:
|
||||||
|
- **HNSW**: `no-recompute` is significantly faster (no embedding recomputation) but requires much more storage (stores all embeddings)
|
||||||
|
- **DiskANN**: `no-recompute` uses PQ + partial real distances during traversal (slower but higher accuracy), while `recompute=True` uses pure PQ traversal + final reranking (faster traversal, enables build-time partitioning for smaller storage)
|
||||||
|
|
||||||
|
|
||||||
**Disable when**:
|
|
||||||
- You have abundant storage and memory
|
|
||||||
- Need extremely low latency (< 100ms)
|
|
||||||
- Running a read-heavy workload where storage cost is acceptable
|
|
||||||
|
|
||||||
## Further Reading
|
## Further Reading
|
||||||
|
|
||||||
- [Lessons Learned Developing LEANN](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
|
- [Lessons Learned Developing LEANN](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
|
||||||
- [LEANN Technical Paper](https://arxiv.org/abs/2506.08276)
|
- [LEANN Technical Paper](https://arxiv.org/abs/2506.08276)
|
||||||
- [DiskANN Original Paper](https://papers.nips.cc/paper/2019/file/09853c7fb1d3f8ee67a61b6bf4a7f8e6-Paper.pdf)
|
- [DiskANN Original Paper](https://papers.nips.cc/paper/2019/file/09853c7fb1d3f8ee67a61b6bf4a7f8e6-Paper.pdf)
|
||||||
|
- [SSD-based Graph Partitioning](https://github.com/SonglinLife/SSD_BASED_PLAN)
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
## 🔥 Core Features
|
## 🔥 Core Features
|
||||||
|
|
||||||
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage with dynamic computation using optimized ZMQ servers and highly optimized search paradigm (overlapping and batching) with highly optimized embedding engine
|
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage with dynamic computation using optimized ZMQ servers and highly optimized search paradigm (overlapping and batching) with highly optimized embedding engine
|
||||||
|
- **🧠 AST-Aware Code Chunking** - Intelligent code chunking that preserves semantic boundaries (functions, classes, methods) for Python, Java, C#, and TypeScript files
|
||||||
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
|
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
|
||||||
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
|
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
|
||||||
- **🏗️ Pluggable Backends** - HNSW/FAISS (default), with optional DiskANN for large-scale deployments
|
- **🏗️ Pluggable Backends** - HNSW/FAISS (default), with optional DiskANN for large-scale deployments
|
||||||
|
|||||||
149
docs/grep_search.md
Normal file
149
docs/grep_search.md
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
# LEANN Grep Search Usage Guide
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
LEANN's grep search functionality provides exact text matching for finding specific code patterns, error messages, function names, or exact phrases in your indexed documents.
|
||||||
|
|
||||||
|
## Basic Usage
|
||||||
|
|
||||||
|
### Simple Grep Search
|
||||||
|
|
||||||
|
```python
|
||||||
|
from leann.api import LeannSearcher
|
||||||
|
|
||||||
|
searcher = LeannSearcher("your_index_path")
|
||||||
|
|
||||||
|
# Exact text search
|
||||||
|
results = searcher.search("def authenticate_user", use_grep=True, top_k=5)
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
print(f"Score: {result.score}")
|
||||||
|
print(f"Text: {result.text[:100]}...")
|
||||||
|
print("-" * 40)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Comparison: Semantic vs Grep Search
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Semantic search - finds conceptually similar content
|
||||||
|
semantic_results = searcher.search("machine learning algorithms", top_k=3)
|
||||||
|
|
||||||
|
# Grep search - finds exact text matches
|
||||||
|
grep_results = searcher.search("def train_model", use_grep=True, top_k=3)
|
||||||
|
```
|
||||||
|
|
||||||
|
## When to Use Grep Search
|
||||||
|
|
||||||
|
### Use Cases
|
||||||
|
|
||||||
|
- **Code Search**: Finding specific function definitions, class names, or variable references
|
||||||
|
- **Error Debugging**: Locating exact error messages or stack traces
|
||||||
|
- **Documentation**: Finding specific API endpoints or exact terminology
|
||||||
|
|
||||||
|
### Examples
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Find function definitions
|
||||||
|
functions = searcher.search("def __init__", use_grep=True)
|
||||||
|
|
||||||
|
# Find import statements
|
||||||
|
imports = searcher.search("from sklearn import", use_grep=True)
|
||||||
|
|
||||||
|
# Find specific error types
|
||||||
|
errors = searcher.search("FileNotFoundError", use_grep=True)
|
||||||
|
|
||||||
|
# Find TODO comments
|
||||||
|
todos = searcher.search("TODO:", use_grep=True)
|
||||||
|
|
||||||
|
# Find configuration entries
|
||||||
|
configs = searcher.search("server_port=", use_grep=True)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Technical Details
|
||||||
|
|
||||||
|
### How It Works
|
||||||
|
|
||||||
|
1. **File Location**: Grep search operates on the raw text stored in `.jsonl` files
|
||||||
|
2. **Command Execution**: Uses the system `grep` command with case-insensitive search
|
||||||
|
3. **Result Processing**: Parses JSON lines and extracts text and metadata
|
||||||
|
4. **Scoring**: Simple frequency-based scoring based on query term occurrences
|
||||||
|
|
||||||
|
### Search Process
|
||||||
|
|
||||||
|
```
|
||||||
|
Query: "def train_model"
|
||||||
|
↓
|
||||||
|
grep -i -n "def train_model" documents.leann.passages.jsonl
|
||||||
|
↓
|
||||||
|
Parse matching JSON lines
|
||||||
|
↓
|
||||||
|
Calculate scores based on term frequency
|
||||||
|
↓
|
||||||
|
Return top_k results
|
||||||
|
```
|
||||||
|
|
||||||
|
### Scoring Algorithm
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Term frequency in document
|
||||||
|
score = text.lower().count(query.lower())
|
||||||
|
```
|
||||||
|
|
||||||
|
Results are ranked by score (highest first), with higher scores indicating more occurrences of the search term.
|
||||||
|
|
||||||
|
## Error Handling
|
||||||
|
|
||||||
|
### Common Issues
|
||||||
|
|
||||||
|
#### Grep Command Not Found
|
||||||
|
```
|
||||||
|
RuntimeError: grep command not found. Please install grep or use semantic search.
|
||||||
|
```
|
||||||
|
|
||||||
|
**Solution**: Install grep on your system:
|
||||||
|
- **Ubuntu/Debian**: `sudo apt-get install grep`
|
||||||
|
- **macOS**: grep is pre-installed
|
||||||
|
- **Windows**: Use WSL or install grep via Git Bash/MSYS2
|
||||||
|
|
||||||
|
#### No Results Found
|
||||||
|
```python
|
||||||
|
# Check if your query exists in the raw data
|
||||||
|
results = searcher.search("your_query", use_grep=True)
|
||||||
|
if not results:
|
||||||
|
print("No exact matches found. Try:")
|
||||||
|
print("1. Check spelling and case")
|
||||||
|
print("2. Use partial terms")
|
||||||
|
print("3. Switch to semantic search")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Complete Example
|
||||||
|
|
||||||
|
```python
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Grep Search Example
|
||||||
|
Demonstrates grep search for exact text matching.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from leann.api import LeannSearcher
|
||||||
|
|
||||||
|
def demonstrate_grep_search():
|
||||||
|
# Initialize searcher
|
||||||
|
searcher = LeannSearcher("my_index")
|
||||||
|
|
||||||
|
print("=== Function Search ===")
|
||||||
|
functions = searcher.search("def __init__", use_grep=True, top_k=5)
|
||||||
|
for i, result in enumerate(functions, 1):
|
||||||
|
print(f"{i}. Score: {result.score}")
|
||||||
|
print(f" Preview: {result.text[:60]}...")
|
||||||
|
print()
|
||||||
|
|
||||||
|
print("=== Error Search ===")
|
||||||
|
errors = searcher.search("FileNotFoundError", use_grep=True, top_k=3)
|
||||||
|
for result in errors:
|
||||||
|
print(f"Content: {result.text.strip()}")
|
||||||
|
print("-" * 40)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
demonstrate_grep_search()
|
||||||
|
```
|
||||||
300
docs/metadata_filtering.md
Normal file
300
docs/metadata_filtering.md
Normal file
@@ -0,0 +1,300 @@
|
|||||||
|
# LEANN Metadata Filtering Usage Guide
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
Leann possesses metadata filtering capabilities that allow you to filter search results based on arbitrary metadata fields set during chunking. This feature enables use cases like spoiler-free book search, document filtering by date/type, code search by file type, and potentially much more.
|
||||||
|
|
||||||
|
## Basic Usage
|
||||||
|
|
||||||
|
### Adding Metadata to Your Documents
|
||||||
|
|
||||||
|
When building your index, add metadata to each text chunk:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from leann.api import LeannBuilder
|
||||||
|
|
||||||
|
builder = LeannBuilder("hnsw")
|
||||||
|
|
||||||
|
# Add text with metadata
|
||||||
|
builder.add_text(
|
||||||
|
text="Chapter 1: Alice falls down the rabbit hole",
|
||||||
|
metadata={
|
||||||
|
"chapter": 1,
|
||||||
|
"character": "Alice",
|
||||||
|
"themes": ["adventure", "curiosity"],
|
||||||
|
"word_count": 150
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
builder.build_index("alice_in_wonderland_index")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Searching with Metadata Filters
|
||||||
|
|
||||||
|
Use the `metadata_filters` parameter in search calls:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from leann.api import LeannSearcher
|
||||||
|
|
||||||
|
searcher = LeannSearcher("alice_in_wonderland_index")
|
||||||
|
|
||||||
|
# Search with filters
|
||||||
|
results = searcher.search(
|
||||||
|
query="What happens to Alice?",
|
||||||
|
top_k=10,
|
||||||
|
metadata_filters={
|
||||||
|
"chapter": {"<=": 5}, # Only chapters 1-5
|
||||||
|
"spoiler_level": {"!=": "high"} # No high spoilers
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Filter Syntax
|
||||||
|
|
||||||
|
### Basic Structure
|
||||||
|
|
||||||
|
```python
|
||||||
|
metadata_filters = {
|
||||||
|
"field_name": {"operator": value},
|
||||||
|
"another_field": {"operator": value}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Supported Operators
|
||||||
|
|
||||||
|
#### Comparison Operators
|
||||||
|
- `"=="`: Equal to
|
||||||
|
- `"!="`: Not equal to
|
||||||
|
- `"<"`: Less than
|
||||||
|
- `"<="`: Less than or equal
|
||||||
|
- `">"`: Greater than
|
||||||
|
- `">="`: Greater than or equal
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Examples
|
||||||
|
{"chapter": {"==": 1}} # Exactly chapter 1
|
||||||
|
{"page": {">": 100}} # Pages after 100
|
||||||
|
{"rating": {">=": 4.0}} # Rating 4.0 or higher
|
||||||
|
{"word_count": {"<": 500}} # Short passages
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Membership Operators
|
||||||
|
- `"in"`: Value is in list
|
||||||
|
- `"not_in"`: Value is not in list
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Examples
|
||||||
|
{"character": {"in": ["Alice", "Bob"]}} # Alice OR Bob
|
||||||
|
{"genre": {"not_in": ["horror", "thriller"]}} # Exclude genres
|
||||||
|
{"tags": {"in": ["fiction", "adventure"]}} # Any of these tags
|
||||||
|
```
|
||||||
|
|
||||||
|
#### String Operators
|
||||||
|
- `"contains"`: String contains substring
|
||||||
|
- `"starts_with"`: String starts with prefix
|
||||||
|
- `"ends_with"`: String ends with suffix
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Examples
|
||||||
|
{"title": {"contains": "alice"}} # Title contains "alice"
|
||||||
|
{"filename": {"ends_with": ".py"}} # Python files
|
||||||
|
{"author": {"starts_with": "Dr."}} # Authors with "Dr." prefix
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Boolean Operators
|
||||||
|
- `"is_true"`: Field is truthy
|
||||||
|
- `"is_false"`: Field is falsy
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Examples
|
||||||
|
{"is_published": {"is_true": True}} # Published content
|
||||||
|
{"is_draft": {"is_false": False}} # Not drafts
|
||||||
|
```
|
||||||
|
|
||||||
|
### Multiple Operators on Same Field
|
||||||
|
|
||||||
|
You can apply multiple operators to the same field (AND logic):
|
||||||
|
|
||||||
|
```python
|
||||||
|
metadata_filters = {
|
||||||
|
"word_count": {
|
||||||
|
">=": 100, # At least 100 words
|
||||||
|
"<=": 500 # At most 500 words
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Compound Filters
|
||||||
|
|
||||||
|
Multiple fields are combined with AND logic:
|
||||||
|
|
||||||
|
```python
|
||||||
|
metadata_filters = {
|
||||||
|
"chapter": {"<=": 10}, # Up to chapter 10
|
||||||
|
"character": {"==": "Alice"}, # About Alice
|
||||||
|
"spoiler_level": {"!=": "high"} # No major spoilers
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Use Case Examples
|
||||||
|
|
||||||
|
### 1. Spoiler-Free Book Search
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Reader has only read up to chapter 5
|
||||||
|
def search_spoiler_free(query, max_chapter):
|
||||||
|
return searcher.search(
|
||||||
|
query=query,
|
||||||
|
metadata_filters={
|
||||||
|
"chapter": {"<=": max_chapter},
|
||||||
|
"spoiler_level": {"in": ["none", "low"]}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
results = search_spoiler_free("What happens to Alice?", max_chapter=5)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Document Management by Date
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Find recent documents
|
||||||
|
recent_docs = searcher.search(
|
||||||
|
query="project updates",
|
||||||
|
metadata_filters={
|
||||||
|
"date": {">=": "2024-01-01"},
|
||||||
|
"document_type": {"==": "report"}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Code Search by File Type
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Search only Python files
|
||||||
|
python_code = searcher.search(
|
||||||
|
query="authentication function",
|
||||||
|
metadata_filters={
|
||||||
|
"file_extension": {"==": ".py"},
|
||||||
|
"lines_of_code": {"<": 100}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Content Filtering by Audience
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Age-appropriate content
|
||||||
|
family_content = searcher.search(
|
||||||
|
query="adventure stories",
|
||||||
|
metadata_filters={
|
||||||
|
"age_rating": {"in": ["G", "PG"]},
|
||||||
|
"content_warnings": {"not_in": ["violence", "adult_themes"]}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5. Multi-Book Series Management
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Search across first 3 books only
|
||||||
|
early_series = searcher.search(
|
||||||
|
query="character development",
|
||||||
|
metadata_filters={
|
||||||
|
"series": {"==": "Harry Potter"},
|
||||||
|
"book_number": {"<=": 3}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Running the Example
|
||||||
|
|
||||||
|
You can see metadata filtering in action with our spoiler-free book RAG example:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Don't forget to set up the environment
|
||||||
|
uv venv
|
||||||
|
source .venv/bin/activate
|
||||||
|
|
||||||
|
# Set your OpenAI API key (required for embeddings, but you can update the example locally and use ollama instead)
|
||||||
|
export OPENAI_API_KEY="your-api-key-here"
|
||||||
|
|
||||||
|
# Run the spoiler-free book RAG example
|
||||||
|
uv run examples/spoiler_free_book_rag.py
|
||||||
|
```
|
||||||
|
|
||||||
|
This example demonstrates:
|
||||||
|
- Building an index with metadata (chapter numbers, characters, themes, locations)
|
||||||
|
- Searching with filters to avoid spoilers (e.g., only show results up to chapter 5)
|
||||||
|
- Different scenarios for readers at various points in the book
|
||||||
|
|
||||||
|
The example uses Alice's Adventures in Wonderland as sample data and shows how you can search for information without revealing plot points from later chapters.
|
||||||
|
|
||||||
|
## Advanced Patterns
|
||||||
|
|
||||||
|
### Custom Chunking with metadata
|
||||||
|
|
||||||
|
```python
|
||||||
|
def chunk_book_with_metadata(book_text, book_info):
|
||||||
|
chunks = []
|
||||||
|
|
||||||
|
for chapter_num, chapter_text in parse_chapters(book_text):
|
||||||
|
# Extract entities, themes, etc.
|
||||||
|
characters = extract_characters(chapter_text)
|
||||||
|
themes = classify_themes(chapter_text)
|
||||||
|
spoiler_level = assess_spoiler_level(chapter_text, chapter_num)
|
||||||
|
|
||||||
|
# Create chunks with rich metadata
|
||||||
|
for paragraph in split_paragraphs(chapter_text):
|
||||||
|
chunks.append({
|
||||||
|
"text": paragraph,
|
||||||
|
"metadata": {
|
||||||
|
"book_title": book_info["title"],
|
||||||
|
"chapter": chapter_num,
|
||||||
|
"characters": characters,
|
||||||
|
"themes": themes,
|
||||||
|
"spoiler_level": spoiler_level,
|
||||||
|
"word_count": len(paragraph.split()),
|
||||||
|
"reading_level": calculate_reading_level(paragraph)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
```
|
||||||
|
|
||||||
|
## Performance Considerations
|
||||||
|
|
||||||
|
### Efficient Filtering Strategies
|
||||||
|
|
||||||
|
1. **Post-search filtering**: Applies filters after vector search, which should be efficient for typical result sets (10-100 results).
|
||||||
|
|
||||||
|
2. **Metadata design**: Keep metadata fields simple and avoid deeply nested structures.
|
||||||
|
|
||||||
|
### Best Practices
|
||||||
|
|
||||||
|
1. **Consistent metadata schema**: Use consistent field names and value types across your documents.
|
||||||
|
|
||||||
|
2. **Reasonable metadata size**: Keep metadata reasonably sized to avoid storage overhead.
|
||||||
|
|
||||||
|
3. **Type consistency**: Use consistent data types for the same fields (e.g., always integers for chapter numbers).
|
||||||
|
|
||||||
|
4. **Index multiple granularities**: Consider chunking at different levels (paragraph, section, chapter) with appropriate metadata.
|
||||||
|
|
||||||
|
### Adding Metadata to Existing Indices
|
||||||
|
|
||||||
|
To add metadata filtering to existing indices, you'll need to rebuild them with metadata:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Read existing passages and add metadata
|
||||||
|
def add_metadata_to_existing_chunks(chunks):
|
||||||
|
for chunk in chunks:
|
||||||
|
# Extract or assign metadata based on content
|
||||||
|
chunk["metadata"] = extract_metadata(chunk["text"])
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
# Rebuild index with metadata
|
||||||
|
enhanced_chunks = add_metadata_to_existing_chunks(existing_chunks)
|
||||||
|
builder = LeannBuilder("hnsw")
|
||||||
|
for chunk in enhanced_chunks:
|
||||||
|
builder.add_text(chunk["text"], chunk["metadata"])
|
||||||
|
builder.build_index("enhanced_index")
|
||||||
|
```
|
||||||
0
examples/__init__.py
Normal file
0
examples/__init__.py
Normal file
404
examples/dynamic_update_no_recompute.py
Normal file
404
examples/dynamic_update_no_recompute.py
Normal file
@@ -0,0 +1,404 @@
|
|||||||
|
"""Dynamic HNSW update demo without compact storage.
|
||||||
|
|
||||||
|
This script reproduces the minimal scenario we used while debugging on-the-fly
|
||||||
|
recompute:
|
||||||
|
|
||||||
|
1. Build a non-compact HNSW index from the first few paragraphs of a text file.
|
||||||
|
2. Print the top results with `recompute_embeddings=True`.
|
||||||
|
3. Append additional paragraphs with :meth:`LeannBuilder.update_index`.
|
||||||
|
4. Run the same query again to show the newly inserted passages.
|
||||||
|
|
||||||
|
Run it with ``uv`` (optionally pointing LEANN_HNSW_LOG_PATH at a file to inspect
|
||||||
|
ZMQ activity)::
|
||||||
|
|
||||||
|
LEANN_HNSW_LOG_PATH=embedding_fetch.log \
|
||||||
|
uv run -m examples.dynamic_update_no_recompute \
|
||||||
|
--index-path .leann/examples/leann-demo.leann
|
||||||
|
|
||||||
|
By default the script builds an index from ``data/2501.14312v1 (1).pdf`` and
|
||||||
|
then updates it with LEANN-related material from ``data/2506.08276v1.pdf``.
|
||||||
|
It issues the query "What's LEANN?" before and after the update to show how the
|
||||||
|
new passages become immediately searchable. The script uses the
|
||||||
|
``sentence-transformers/all-MiniLM-L6-v2`` model with ``is_recompute=True`` so
|
||||||
|
Faiss pulls existing vectors on demand via the ZMQ embedding server, while
|
||||||
|
freshly added passages are embedded locally just like the initial build.
|
||||||
|
|
||||||
|
To make storage comparisons easy, the script can also build a matching
|
||||||
|
``is_recompute=False`` baseline (enabled by default) and report the index size
|
||||||
|
delta after the update. Disable the baseline run with
|
||||||
|
``--skip-compare-no-recompute`` if you only need the recompute flow.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher
|
||||||
|
from leann.registry import register_project_directory
|
||||||
|
|
||||||
|
from apps.chunking import create_text_chunks
|
||||||
|
|
||||||
|
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
|
||||||
|
DEFAULT_QUERY = "What's LEANN?"
|
||||||
|
DEFAULT_INITIAL_FILES = [REPO_ROOT / "data" / "2501.14312v1 (1).pdf"]
|
||||||
|
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
|
||||||
|
|
||||||
|
|
||||||
|
def load_chunks_from_files(paths: list[Path]) -> list[str]:
|
||||||
|
from llama_index.core import SimpleDirectoryReader
|
||||||
|
|
||||||
|
documents = []
|
||||||
|
for path in paths:
|
||||||
|
p = path.expanduser().resolve()
|
||||||
|
if not p.exists():
|
||||||
|
raise FileNotFoundError(f"Input path not found: {p}")
|
||||||
|
if p.is_dir():
|
||||||
|
reader = SimpleDirectoryReader(str(p), recursive=False)
|
||||||
|
documents.extend(reader.load_data(show_progress=True))
|
||||||
|
else:
|
||||||
|
reader = SimpleDirectoryReader(input_files=[str(p)])
|
||||||
|
documents.extend(reader.load_data(show_progress=True))
|
||||||
|
|
||||||
|
if not documents:
|
||||||
|
return []
|
||||||
|
|
||||||
|
chunks = create_text_chunks(
|
||||||
|
documents,
|
||||||
|
chunk_size=512,
|
||||||
|
chunk_overlap=128,
|
||||||
|
use_ast_chunking=False,
|
||||||
|
)
|
||||||
|
return [c for c in chunks if isinstance(c, str) and c.strip()]
|
||||||
|
|
||||||
|
|
||||||
|
def run_search(index_path: Path, query: str, top_k: int, *, recompute_embeddings: bool) -> list:
|
||||||
|
searcher = LeannSearcher(str(index_path))
|
||||||
|
try:
|
||||||
|
return searcher.search(
|
||||||
|
query=query,
|
||||||
|
top_k=top_k,
|
||||||
|
recompute_embeddings=recompute_embeddings,
|
||||||
|
batch_size=16,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
searcher.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
def print_results(title: str, results: Iterable) -> None:
|
||||||
|
print(f"\n=== {title} ===")
|
||||||
|
res_list = list(results)
|
||||||
|
print(f"results count: {len(res_list)}")
|
||||||
|
print("passages:")
|
||||||
|
if not res_list:
|
||||||
|
print(" (no passages returned)")
|
||||||
|
for res in res_list:
|
||||||
|
snippet = res.text.replace("\n", " ")[:120]
|
||||||
|
print(f" - {res.id}: {snippet}... (score={res.score:.4f})")
|
||||||
|
|
||||||
|
|
||||||
|
def build_initial_index(
|
||||||
|
index_path: Path,
|
||||||
|
paragraphs: list[str],
|
||||||
|
model_name: str,
|
||||||
|
embedding_mode: str,
|
||||||
|
is_recompute: bool,
|
||||||
|
) -> None:
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model=model_name,
|
||||||
|
embedding_mode=embedding_mode,
|
||||||
|
is_compact=False,
|
||||||
|
is_recompute=is_recompute,
|
||||||
|
)
|
||||||
|
for idx, passage in enumerate(paragraphs):
|
||||||
|
builder.add_text(passage, metadata={"id": str(idx)})
|
||||||
|
builder.build_index(str(index_path))
|
||||||
|
|
||||||
|
|
||||||
|
def update_index(
|
||||||
|
index_path: Path,
|
||||||
|
start_id: int,
|
||||||
|
paragraphs: list[str],
|
||||||
|
model_name: str,
|
||||||
|
embedding_mode: str,
|
||||||
|
is_recompute: bool,
|
||||||
|
) -> None:
|
||||||
|
updater = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model=model_name,
|
||||||
|
embedding_mode=embedding_mode,
|
||||||
|
is_compact=False,
|
||||||
|
is_recompute=is_recompute,
|
||||||
|
)
|
||||||
|
for offset, passage in enumerate(paragraphs, start=start_id):
|
||||||
|
updater.add_text(passage, metadata={"id": str(offset)})
|
||||||
|
updater.update_index(str(index_path))
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_index_dir(index_path: Path) -> None:
|
||||||
|
index_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_index_files(index_path: Path) -> None:
|
||||||
|
"""Remove leftover index artifacts for a clean rebuild."""
|
||||||
|
|
||||||
|
parent = index_path.parent
|
||||||
|
if not parent.exists():
|
||||||
|
return
|
||||||
|
stem = index_path.stem
|
||||||
|
for file in parent.glob(f"{stem}*"):
|
||||||
|
if file.is_file():
|
||||||
|
file.unlink()
|
||||||
|
|
||||||
|
|
||||||
|
def index_file_size(index_path: Path) -> int:
|
||||||
|
"""Return the size of the primary .index file for the given index path."""
|
||||||
|
|
||||||
|
index_file = index_path.parent / f"{index_path.stem}.index"
|
||||||
|
return index_file.stat().st_size if index_file.exists() else 0
|
||||||
|
|
||||||
|
|
||||||
|
def load_metadata_snapshot(index_path: Path) -> dict[str, Any] | None:
|
||||||
|
meta_path = index_path.parent / f"{index_path.name}.meta.json"
|
||||||
|
if not meta_path.exists():
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return json.loads(meta_path.read_text())
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def run_workflow(
|
||||||
|
*,
|
||||||
|
label: str,
|
||||||
|
index_path: Path,
|
||||||
|
initial_paragraphs: list[str],
|
||||||
|
update_paragraphs: list[str],
|
||||||
|
model_name: str,
|
||||||
|
embedding_mode: str,
|
||||||
|
is_recompute: bool,
|
||||||
|
query: str,
|
||||||
|
top_k: int,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
prefix = f"[{label}] " if label else ""
|
||||||
|
|
||||||
|
ensure_index_dir(index_path)
|
||||||
|
cleanup_index_files(index_path)
|
||||||
|
|
||||||
|
print(f"{prefix}Building initial index...")
|
||||||
|
build_initial_index(
|
||||||
|
index_path,
|
||||||
|
initial_paragraphs,
|
||||||
|
model_name,
|
||||||
|
embedding_mode,
|
||||||
|
is_recompute=is_recompute,
|
||||||
|
)
|
||||||
|
|
||||||
|
initial_size = index_file_size(index_path)
|
||||||
|
before_results = run_search(
|
||||||
|
index_path,
|
||||||
|
query,
|
||||||
|
top_k,
|
||||||
|
recompute_embeddings=is_recompute,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\n{prefix}Updating index with additional passages...")
|
||||||
|
update_index(
|
||||||
|
index_path,
|
||||||
|
start_id=len(initial_paragraphs),
|
||||||
|
paragraphs=update_paragraphs,
|
||||||
|
model_name=model_name,
|
||||||
|
embedding_mode=embedding_mode,
|
||||||
|
is_recompute=is_recompute,
|
||||||
|
)
|
||||||
|
|
||||||
|
after_results = run_search(
|
||||||
|
index_path,
|
||||||
|
query,
|
||||||
|
top_k,
|
||||||
|
recompute_embeddings=is_recompute,
|
||||||
|
)
|
||||||
|
updated_size = index_file_size(index_path)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"initial_size": initial_size,
|
||||||
|
"updated_size": updated_size,
|
||||||
|
"delta": updated_size - initial_size,
|
||||||
|
"before_results": before_results,
|
||||||
|
"after_results": after_results,
|
||||||
|
"metadata": load_metadata_snapshot(index_path),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser(description=__doc__)
|
||||||
|
parser.add_argument(
|
||||||
|
"--initial-files",
|
||||||
|
type=Path,
|
||||||
|
nargs="+",
|
||||||
|
default=DEFAULT_INITIAL_FILES,
|
||||||
|
help="Initial document files (PDF/TXT) used to build the base index",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--index-path",
|
||||||
|
type=Path,
|
||||||
|
default=Path(".leann/examples/leann-demo.leann"),
|
||||||
|
help="Destination index path (default: .leann/examples/leann-demo.leann)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--initial-count",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="Number of chunks to use from the initial documents (default: 8)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--update-files",
|
||||||
|
type=Path,
|
||||||
|
nargs="*",
|
||||||
|
default=DEFAULT_UPDATE_FILES,
|
||||||
|
help="Additional documents to add during update (PDF/TXT)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--update-count",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="Number of chunks to append from update documents (default: 4)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--update-text",
|
||||||
|
type=str,
|
||||||
|
default=(
|
||||||
|
"LEANN (Lightweight Embedding ANN) is an indexing toolkit focused on "
|
||||||
|
"recompute-aware HNSW graphs, allowing embeddings to be regenerated "
|
||||||
|
"on demand to keep disk usage minimal."
|
||||||
|
),
|
||||||
|
help="Fallback text to append if --update-files is omitted",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--top-k",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="Number of results to show for each search (default: 4)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--query",
|
||||||
|
type=str,
|
||||||
|
default=DEFAULT_QUERY,
|
||||||
|
help="Query to run before/after the update",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding-model",
|
||||||
|
type=str,
|
||||||
|
default="sentence-transformers/all-MiniLM-L6-v2",
|
||||||
|
help="Embedding model name",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding-mode",
|
||||||
|
type=str,
|
||||||
|
default="sentence-transformers",
|
||||||
|
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||||
|
help="Embedding backend mode",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--compare-no-recompute",
|
||||||
|
dest="compare_no_recompute",
|
||||||
|
action="store_true",
|
||||||
|
help="Also run a baseline with is_recompute=False and report its index growth.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--skip-compare-no-recompute",
|
||||||
|
dest="compare_no_recompute",
|
||||||
|
action="store_false",
|
||||||
|
help="Skip building the no-recompute baseline.",
|
||||||
|
)
|
||||||
|
parser.set_defaults(compare_no_recompute=True)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
ensure_index_dir(args.index_path)
|
||||||
|
register_project_directory(REPO_ROOT)
|
||||||
|
|
||||||
|
initial_chunks = load_chunks_from_files(list(args.initial_files))
|
||||||
|
if not initial_chunks:
|
||||||
|
raise ValueError("No text chunks extracted from the initial files.")
|
||||||
|
|
||||||
|
initial = initial_chunks[: args.initial_count]
|
||||||
|
if not initial:
|
||||||
|
raise ValueError("Initial chunk set is empty after applying --initial-count.")
|
||||||
|
|
||||||
|
if args.update_files:
|
||||||
|
update_chunks = load_chunks_from_files(list(args.update_files))
|
||||||
|
if not update_chunks:
|
||||||
|
raise ValueError("No text chunks extracted from the update files.")
|
||||||
|
to_add = update_chunks[: args.update_count]
|
||||||
|
else:
|
||||||
|
if not args.update_text:
|
||||||
|
raise ValueError("Provide --update-files or --update-text for the update step.")
|
||||||
|
to_add = [args.update_text]
|
||||||
|
if not to_add:
|
||||||
|
raise ValueError("Update chunk set is empty after applying --update-count.")
|
||||||
|
|
||||||
|
recompute_stats = run_workflow(
|
||||||
|
label="recompute",
|
||||||
|
index_path=args.index_path,
|
||||||
|
initial_paragraphs=initial,
|
||||||
|
update_paragraphs=to_add,
|
||||||
|
model_name=args.embedding_model,
|
||||||
|
embedding_mode=args.embedding_mode,
|
||||||
|
is_recompute=True,
|
||||||
|
query=args.query,
|
||||||
|
top_k=args.top_k,
|
||||||
|
)
|
||||||
|
|
||||||
|
print_results("initial search", recompute_stats["before_results"])
|
||||||
|
print_results("after update", recompute_stats["after_results"])
|
||||||
|
print(
|
||||||
|
f"\n[recompute] Index file size change: {recompute_stats['initial_size']} -> {recompute_stats['updated_size']} bytes"
|
||||||
|
f" (Δ {recompute_stats['delta']})"
|
||||||
|
)
|
||||||
|
|
||||||
|
if recompute_stats["metadata"]:
|
||||||
|
meta_view = {k: recompute_stats["metadata"].get(k) for k in ("is_compact", "is_pruned")}
|
||||||
|
print("[recompute] metadata snapshot:")
|
||||||
|
print(json.dumps(meta_view, indent=2))
|
||||||
|
|
||||||
|
if args.compare_no_recompute:
|
||||||
|
baseline_path = (
|
||||||
|
args.index_path.parent / f"{args.index_path.stem}-norecompute{args.index_path.suffix}"
|
||||||
|
)
|
||||||
|
baseline_stats = run_workflow(
|
||||||
|
label="no-recompute",
|
||||||
|
index_path=baseline_path,
|
||||||
|
initial_paragraphs=initial,
|
||||||
|
update_paragraphs=to_add,
|
||||||
|
model_name=args.embedding_model,
|
||||||
|
embedding_mode=args.embedding_mode,
|
||||||
|
is_recompute=False,
|
||||||
|
query=args.query,
|
||||||
|
top_k=args.top_k,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"\n[no-recompute] Index file size change: {baseline_stats['initial_size']} -> {baseline_stats['updated_size']} bytes"
|
||||||
|
f" (Δ {baseline_stats['delta']})"
|
||||||
|
)
|
||||||
|
|
||||||
|
after_texts = [res.text for res in recompute_stats["after_results"]]
|
||||||
|
baseline_after_texts = [res.text for res in baseline_stats["after_results"]]
|
||||||
|
if after_texts == baseline_after_texts:
|
||||||
|
print(
|
||||||
|
"[no-recompute] Search results match recompute baseline; see above for the shared output."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print("[no-recompute] WARNING: search results differ from recompute baseline.")
|
||||||
|
|
||||||
|
if baseline_stats["metadata"]:
|
||||||
|
meta_view = {k: baseline_stats["metadata"].get(k) for k in ("is_compact", "is_pruned")}
|
||||||
|
print("[no-recompute] metadata snapshot:")
|
||||||
|
print(json.dumps(meta_view, indent=2))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
35
examples/grep_search_example.py
Normal file
35
examples/grep_search_example.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
"""
|
||||||
|
Grep Search Example
|
||||||
|
|
||||||
|
Shows how to use grep-based text search instead of semantic search.
|
||||||
|
Useful when you need exact text matches rather than meaning-based results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from leann import LeannSearcher
|
||||||
|
|
||||||
|
# Load your index
|
||||||
|
searcher = LeannSearcher("my-documents.leann")
|
||||||
|
|
||||||
|
# Regular semantic search
|
||||||
|
print("=== Semantic Search ===")
|
||||||
|
results = searcher.search("machine learning algorithms", top_k=3)
|
||||||
|
for result in results:
|
||||||
|
print(f"Score: {result.score:.3f}")
|
||||||
|
print(f"Text: {result.text[:80]}...")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Grep-based search for exact text matches
|
||||||
|
print("=== Grep Search ===")
|
||||||
|
results = searcher.search("def train_model", top_k=3, use_grep=True)
|
||||||
|
for result in results:
|
||||||
|
print(f"Score: {result.score}")
|
||||||
|
print(f"Text: {result.text[:80]}...")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Find specific error messages
|
||||||
|
error_results = searcher.search("FileNotFoundError", use_grep=True)
|
||||||
|
print(f"Found {len(error_results)} files mentioning FileNotFoundError")
|
||||||
|
|
||||||
|
# Search for function definitions
|
||||||
|
func_results = searcher.search("class SearchResult", use_grep=True, top_k=5)
|
||||||
|
print(f"Found {len(func_results)} class definitions")
|
||||||
250
examples/spoiler_free_book_rag.py
Normal file
250
examples/spoiler_free_book_rag.py
Normal file
@@ -0,0 +1,250 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Spoiler-Free Book RAG Example using LEANN Metadata Filtering
|
||||||
|
|
||||||
|
This example demonstrates how to use LEANN's metadata filtering to create
|
||||||
|
a spoiler-free book RAG system where users can search for information
|
||||||
|
up to a specific chapter they've read.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python spoiler_free_book_rag.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
# Add LEANN to path (adjust path as needed)
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../packages/leann-core/src"))
|
||||||
|
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_book_with_metadata(book_title: str = "Sample Book") -> list[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Create sample book chunks with metadata for demonstration.
|
||||||
|
|
||||||
|
In a real implementation, this would parse actual book files (epub, txt, etc.)
|
||||||
|
and extract chapter boundaries, character mentions, etc.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
book_title: Title of the book
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of chunk dictionaries with text and metadata
|
||||||
|
"""
|
||||||
|
# Sample book chunks with metadata
|
||||||
|
# In practice, you'd use proper text processing libraries
|
||||||
|
|
||||||
|
sample_chunks = [
|
||||||
|
{
|
||||||
|
"text": "Alice was beginning to get very tired of sitting by her sister on the bank, and of having nothing to do.",
|
||||||
|
"metadata": {
|
||||||
|
"book": book_title,
|
||||||
|
"chapter": 1,
|
||||||
|
"page": 1,
|
||||||
|
"characters": ["Alice", "Sister"],
|
||||||
|
"themes": ["boredom", "curiosity"],
|
||||||
|
"location": "riverbank",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": "So she was considering in her own mind (as well as she could, for the hot day made her feel very sleepy and stupid), whether the pleasure of making a daisy-chain would be worth the trouble of getting up and picking the daisies, when suddenly a White Rabbit with pink eyes ran close by her.",
|
||||||
|
"metadata": {
|
||||||
|
"book": book_title,
|
||||||
|
"chapter": 1,
|
||||||
|
"page": 2,
|
||||||
|
"characters": ["Alice", "White Rabbit"],
|
||||||
|
"themes": ["decision", "surprise", "magic"],
|
||||||
|
"location": "riverbank",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": "Alice found herself falling down a very deep well. Either the well was very deep, or she fell very slowly, for she had plenty of time as she fell to look about her and to wonder what was going to happen next.",
|
||||||
|
"metadata": {
|
||||||
|
"book": book_title,
|
||||||
|
"chapter": 2,
|
||||||
|
"page": 15,
|
||||||
|
"characters": ["Alice"],
|
||||||
|
"themes": ["falling", "wonder", "transformation"],
|
||||||
|
"location": "rabbit hole",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": "Alice meets the Cheshire Cat, who tells her that everyone in Wonderland is mad, including Alice herself.",
|
||||||
|
"metadata": {
|
||||||
|
"book": book_title,
|
||||||
|
"chapter": 6,
|
||||||
|
"page": 85,
|
||||||
|
"characters": ["Alice", "Cheshire Cat"],
|
||||||
|
"themes": ["madness", "philosophy", "identity"],
|
||||||
|
"location": "Duchess's house",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": "At the Queen's croquet ground, Alice witnesses the absurd trial that reveals the arbitrary nature of Wonderland's justice system.",
|
||||||
|
"metadata": {
|
||||||
|
"book": book_title,
|
||||||
|
"chapter": 8,
|
||||||
|
"page": 120,
|
||||||
|
"characters": ["Alice", "Queen of Hearts", "King of Hearts"],
|
||||||
|
"themes": ["justice", "absurdity", "authority"],
|
||||||
|
"location": "Queen's court",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": "Alice realizes that Wonderland was all a dream, even the Rabbit, as she wakes up on the riverbank next to her sister.",
|
||||||
|
"metadata": {
|
||||||
|
"book": book_title,
|
||||||
|
"chapter": 12,
|
||||||
|
"page": 180,
|
||||||
|
"characters": ["Alice", "Sister", "Rabbit"],
|
||||||
|
"themes": ["revelation", "reality", "growth"],
|
||||||
|
"location": "riverbank",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
return sample_chunks
|
||||||
|
|
||||||
|
|
||||||
|
def build_spoiler_free_index(book_chunks: list[dict[str, Any]], index_name: str) -> str:
|
||||||
|
"""
|
||||||
|
Build a LEANN index with book chunks that include spoiler metadata.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
book_chunks: List of book chunks with metadata
|
||||||
|
index_name: Name for the index
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to the built index
|
||||||
|
"""
|
||||||
|
print(f"📚 Building spoiler-free book index: {index_name}")
|
||||||
|
|
||||||
|
# Initialize LEANN builder
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw", embedding_model="text-embedding-3-small", embedding_mode="openai"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add each chunk with its metadata
|
||||||
|
for chunk in book_chunks:
|
||||||
|
builder.add_text(text=chunk["text"], metadata=chunk["metadata"])
|
||||||
|
|
||||||
|
# Build the index
|
||||||
|
index_path = f"{index_name}_book_index"
|
||||||
|
builder.build_index(index_path)
|
||||||
|
|
||||||
|
print(f"✅ Index built successfully: {index_path}")
|
||||||
|
return index_path
|
||||||
|
|
||||||
|
|
||||||
|
def spoiler_free_search(
|
||||||
|
index_path: str,
|
||||||
|
query: str,
|
||||||
|
max_chapter: int,
|
||||||
|
character_filter: Optional[list[str]] = None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Perform a spoiler-free search on the book index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index_path: Path to the LEANN index
|
||||||
|
query: Search query
|
||||||
|
max_chapter: Maximum chapter number to include
|
||||||
|
character_filter: Optional list of characters to focus on
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of search results safe for the reader
|
||||||
|
"""
|
||||||
|
print(f"🔍 Searching: '{query}' (up to chapter {max_chapter})")
|
||||||
|
|
||||||
|
searcher = LeannSearcher(index_path)
|
||||||
|
|
||||||
|
metadata_filters = {"chapter": {"<=": max_chapter}}
|
||||||
|
|
||||||
|
if character_filter:
|
||||||
|
metadata_filters["characters"] = {"contains": character_filter[0]}
|
||||||
|
|
||||||
|
results = searcher.search(query=query, top_k=10, metadata_filters=metadata_filters)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def demo_spoiler_free_rag():
|
||||||
|
"""
|
||||||
|
Demonstrate the spoiler-free book RAG system.
|
||||||
|
"""
|
||||||
|
print("🎭 Spoiler-Free Book RAG Demo")
|
||||||
|
print("=" * 40)
|
||||||
|
|
||||||
|
# Step 1: Prepare book data
|
||||||
|
book_title = "Alice's Adventures in Wonderland"
|
||||||
|
book_chunks = chunk_book_with_metadata(book_title)
|
||||||
|
|
||||||
|
print(f"📖 Loaded {len(book_chunks)} chunks from '{book_title}'")
|
||||||
|
|
||||||
|
# Step 2: Build the index (in practice, this would be done once)
|
||||||
|
try:
|
||||||
|
index_path = build_spoiler_free_index(book_chunks, "alice_wonderland")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Failed to build index (likely missing dependencies): {e}")
|
||||||
|
print(
|
||||||
|
"💡 This demo shows the filtering logic - actual indexing requires LEANN dependencies"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Step 3: Demonstrate various spoiler-free searches
|
||||||
|
search_scenarios = [
|
||||||
|
{
|
||||||
|
"description": "Reader who has only read Chapter 1",
|
||||||
|
"query": "What can you tell me about the rabbit?",
|
||||||
|
"max_chapter": 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"description": "Reader who has read up to Chapter 5",
|
||||||
|
"query": "Tell me about Alice's adventures",
|
||||||
|
"max_chapter": 5,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"description": "Reader who has read most of the book",
|
||||||
|
"query": "What does the Cheshire Cat represent?",
|
||||||
|
"max_chapter": 10,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"description": "Reader who has read the whole book",
|
||||||
|
"query": "What can you tell me about the rabbit?",
|
||||||
|
"max_chapter": 12,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
for scenario in search_scenarios:
|
||||||
|
print(f"\n📚 Scenario: {scenario['description']}")
|
||||||
|
print(f" Query: {scenario['query']}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
results = spoiler_free_search(
|
||||||
|
index_path=index_path,
|
||||||
|
query=scenario["query"],
|
||||||
|
max_chapter=scenario["max_chapter"],
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f" 📄 Found {len(results)} results:")
|
||||||
|
for i, result in enumerate(results[:3], 1): # Show top 3
|
||||||
|
chapter = result.metadata.get("chapter", "?")
|
||||||
|
location = result.metadata.get("location", "?")
|
||||||
|
print(f" {i}. Chapter {chapter} ({location}): {result.text[:80]}...")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ❌ Search failed: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("📚 LEANN Spoiler-Free Book RAG Example")
|
||||||
|
print("=====================================")
|
||||||
|
|
||||||
|
try:
|
||||||
|
demo_spoiler_free_rag()
|
||||||
|
except ImportError as e:
|
||||||
|
print(f"❌ Cannot run demo due to missing dependencies: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Error running demo: {e}")
|
||||||
28
llms.txt
Normal file
28
llms.txt
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
# llms.txt — LEANN MCP and Agent Integration
|
||||||
|
product: LEANN
|
||||||
|
homepage: https://github.com/yichuan-w/LEANN
|
||||||
|
contact: https://github.com/yichuan-w/LEANN/issues
|
||||||
|
|
||||||
|
# Installation
|
||||||
|
install: uv tool install leann-core --with leann
|
||||||
|
|
||||||
|
# MCP Server Entry Point
|
||||||
|
mcp.server: leann_mcp
|
||||||
|
mcp.protocol_version: 2024-11-05
|
||||||
|
|
||||||
|
# Tools
|
||||||
|
mcp.tools: leann_list, leann_search
|
||||||
|
|
||||||
|
mcp.tool.leann_list.description: List available LEANN indexes
|
||||||
|
mcp.tool.leann_list.input: {}
|
||||||
|
|
||||||
|
mcp.tool.leann_search.description: Semantic search across a named LEANN index
|
||||||
|
mcp.tool.leann_search.input.index_name: string, required
|
||||||
|
mcp.tool.leann_search.input.query: string, required
|
||||||
|
mcp.tool.leann_search.input.top_k: integer, optional, default=5, min=1, max=20
|
||||||
|
mcp.tool.leann_search.input.complexity: integer, optional, default=32, min=16, max=128
|
||||||
|
|
||||||
|
# Notes
|
||||||
|
note: Build indexes with `leann build <name> --docs <files...>` before searching.
|
||||||
|
example.add: claude mcp add --scope user leann-server -- leann_mcp
|
||||||
|
example.verify: claude mcp list | cat
|
||||||
1
packages/astchunk-leann
Submodule
1
packages/astchunk-leann
Submodule
Submodule packages/astchunk-leann added at ad9afa07b9
@@ -1,8 +0,0 @@
|
|||||||
# packages/leann-backend-diskann/CMakeLists.txt (simplified version)
|
|
||||||
|
|
||||||
cmake_minimum_required(VERSION 3.20)
|
|
||||||
project(leann_backend_diskann_wrapper)
|
|
||||||
|
|
||||||
# Tell CMake to directly enter the DiskANN submodule and execute its own CMakeLists.txt
|
|
||||||
# DiskANN will handle everything itself, including compiling Python bindings
|
|
||||||
add_subdirectory(src/third_party/DiskANN)
|
|
||||||
@@ -1 +1,7 @@
|
|||||||
from . import diskann_backend as diskann_backend
|
from . import diskann_backend as diskann_backend
|
||||||
|
from . import graph_partition
|
||||||
|
|
||||||
|
# Export main classes and functions
|
||||||
|
from .graph_partition import GraphPartitioner, partition_graph
|
||||||
|
|
||||||
|
__all__ = ["GraphPartitioner", "diskann_backend", "graph_partition", "partition_graph"]
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import os
|
|||||||
import struct
|
import struct
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import psutil
|
import psutil
|
||||||
@@ -22,6 +22,11 @@ logger = logging.getLogger(__name__)
|
|||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def suppress_cpp_output_if_needed():
|
def suppress_cpp_output_if_needed():
|
||||||
"""Suppress C++ stdout/stderr based on LEANN_LOG_LEVEL"""
|
"""Suppress C++ stdout/stderr based on LEANN_LOG_LEVEL"""
|
||||||
|
# In CI we avoid fiddling with low-level file descriptors to prevent aborts
|
||||||
|
if os.getenv("CI") == "true":
|
||||||
|
yield
|
||||||
|
return
|
||||||
|
|
||||||
log_level = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
log_level = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||||
|
|
||||||
# Only suppress if log level is WARNING or higher (ERROR, CRITICAL)
|
# Only suppress if log level is WARNING or higher (ERROR, CRITICAL)
|
||||||
@@ -137,6 +142,71 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
|||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.build_params = kwargs
|
self.build_params = kwargs
|
||||||
|
|
||||||
|
def _safe_cleanup_after_partition(self, index_dir: Path, index_prefix: str):
|
||||||
|
"""
|
||||||
|
Safely cleanup files after partition.
|
||||||
|
In partition mode, C++ doesn't read _disk.index content,
|
||||||
|
so we can delete it if all derived files exist.
|
||||||
|
"""
|
||||||
|
disk_index_file = index_dir / f"{index_prefix}_disk.index"
|
||||||
|
beam_search_file = index_dir / f"{index_prefix}_disk_beam_search.index"
|
||||||
|
|
||||||
|
# Required files that C++ partition mode needs
|
||||||
|
# Note: C++ generates these with _disk.index suffix
|
||||||
|
disk_suffix = "_disk.index"
|
||||||
|
required_files = [
|
||||||
|
f"{index_prefix}{disk_suffix}_medoids.bin", # Critical: assert fails if missing
|
||||||
|
# Note: _centroids.bin is not created in single-shot build - C++ handles this automatically
|
||||||
|
f"{index_prefix}_pq_pivots.bin", # PQ table
|
||||||
|
f"{index_prefix}_pq_compressed.bin", # PQ compressed vectors
|
||||||
|
]
|
||||||
|
|
||||||
|
# Check if all required files exist
|
||||||
|
missing_files = []
|
||||||
|
for filename in required_files:
|
||||||
|
file_path = index_dir / filename
|
||||||
|
if not file_path.exists():
|
||||||
|
missing_files.append(filename)
|
||||||
|
|
||||||
|
if missing_files:
|
||||||
|
logger.warning(
|
||||||
|
f"Cannot safely delete _disk.index - missing required files: {missing_files}"
|
||||||
|
)
|
||||||
|
logger.info("Keeping all original files for safety")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Calculate space savings
|
||||||
|
space_saved = 0
|
||||||
|
files_to_delete = []
|
||||||
|
|
||||||
|
if disk_index_file.exists():
|
||||||
|
space_saved += disk_index_file.stat().st_size
|
||||||
|
files_to_delete.append(disk_index_file)
|
||||||
|
|
||||||
|
if beam_search_file.exists():
|
||||||
|
space_saved += beam_search_file.stat().st_size
|
||||||
|
files_to_delete.append(beam_search_file)
|
||||||
|
|
||||||
|
# Safe to delete!
|
||||||
|
for file_to_delete in files_to_delete:
|
||||||
|
try:
|
||||||
|
os.remove(file_to_delete)
|
||||||
|
logger.info(f"✅ Safely deleted: {file_to_delete.name}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to delete {file_to_delete.name}: {e}")
|
||||||
|
|
||||||
|
if space_saved > 0:
|
||||||
|
space_saved_mb = space_saved / (1024 * 1024)
|
||||||
|
logger.info(f"💾 Space saved: {space_saved_mb:.1f} MB")
|
||||||
|
|
||||||
|
# Show what files are kept
|
||||||
|
logger.info("📁 Kept essential files for partition mode:")
|
||||||
|
for filename in required_files:
|
||||||
|
file_path = index_dir / filename
|
||||||
|
if file_path.exists():
|
||||||
|
size_mb = file_path.stat().st_size / (1024 * 1024)
|
||||||
|
logger.info(f" - {filename} ({size_mb:.1f} MB)")
|
||||||
|
|
||||||
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
|
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
|
||||||
path = Path(index_path)
|
path = Path(index_path)
|
||||||
index_dir = path.parent
|
index_dir = path.parent
|
||||||
@@ -151,6 +221,17 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
|||||||
_write_vectors_to_bin(data, index_dir / data_filename)
|
_write_vectors_to_bin(data, index_dir / data_filename)
|
||||||
|
|
||||||
build_kwargs = {**self.build_params, **kwargs}
|
build_kwargs = {**self.build_params, **kwargs}
|
||||||
|
|
||||||
|
# Extract is_recompute from nested backend_kwargs if needed
|
||||||
|
is_recompute = build_kwargs.get("is_recompute", False)
|
||||||
|
if not is_recompute and "backend_kwargs" in build_kwargs:
|
||||||
|
is_recompute = build_kwargs["backend_kwargs"].get("is_recompute", False)
|
||||||
|
|
||||||
|
# Flatten all backend_kwargs parameters to top level for compatibility
|
||||||
|
if "backend_kwargs" in build_kwargs:
|
||||||
|
nested_params = build_kwargs.pop("backend_kwargs")
|
||||||
|
build_kwargs.update(nested_params)
|
||||||
|
|
||||||
metric_enum = _get_diskann_metrics().get(
|
metric_enum = _get_diskann_metrics().get(
|
||||||
build_kwargs.get("distance_metric", "mips").lower()
|
build_kwargs.get("distance_metric", "mips").lower()
|
||||||
)
|
)
|
||||||
@@ -185,6 +266,30 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
|||||||
build_kwargs.get("pq_disk_bytes", 0),
|
build_kwargs.get("pq_disk_bytes", 0),
|
||||||
"",
|
"",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Auto-partition if is_recompute is enabled
|
||||||
|
if build_kwargs.get("is_recompute", False):
|
||||||
|
logger.info("is_recompute=True, starting automatic graph partitioning...")
|
||||||
|
from .graph_partition import partition_graph
|
||||||
|
|
||||||
|
# Partition the index using absolute paths
|
||||||
|
# Convert to absolute paths to avoid issues with working directory changes
|
||||||
|
absolute_index_dir = Path(index_dir).resolve()
|
||||||
|
absolute_index_prefix_path = str(absolute_index_dir / index_prefix)
|
||||||
|
disk_graph_path, partition_bin_path = partition_graph(
|
||||||
|
index_prefix_path=absolute_index_prefix_path,
|
||||||
|
output_dir=str(absolute_index_dir),
|
||||||
|
partition_prefix=index_prefix,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Safe cleanup: In partition mode, C++ doesn't read _disk.index content
|
||||||
|
# but still needs the derived files (_medoids.bin, _centroids.bin, etc.)
|
||||||
|
self._safe_cleanup_after_partition(index_dir, index_prefix)
|
||||||
|
|
||||||
|
logger.info("✅ Graph partitioning completed successfully!")
|
||||||
|
logger.info(f" - Disk graph: {disk_graph_path}")
|
||||||
|
logger.info(f" - Partition file: {partition_bin_path}")
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
temp_data_file = index_dir / data_filename
|
temp_data_file = index_dir / data_filename
|
||||||
if temp_data_file.exists():
|
if temp_data_file.exists():
|
||||||
@@ -213,16 +318,42 @@ class DiskannSearcher(BaseSearcher):
|
|||||||
|
|
||||||
# For DiskANN, we need to reinitialize the index when zmq_port changes
|
# For DiskANN, we need to reinitialize the index when zmq_port changes
|
||||||
# Store the initialization parameters for later use
|
# Store the initialization parameters for later use
|
||||||
full_index_prefix = str(self.index_dir / self.index_path.stem)
|
# Note: C++ load method expects the BASE path (without _disk.index suffix)
|
||||||
|
# C++ internally constructs: index_prefix + "_disk.index"
|
||||||
|
index_name = self.index_path.stem # "simple_test.leann" -> "simple_test"
|
||||||
|
diskann_index_prefix = str(self.index_dir / index_name) # /path/to/simple_test
|
||||||
|
full_index_prefix = diskann_index_prefix # /path/to/simple_test (base path)
|
||||||
|
|
||||||
|
# Auto-detect partition files and set partition_prefix
|
||||||
|
partition_graph_file = self.index_dir / f"{index_name}_disk_graph.index"
|
||||||
|
partition_bin_file = self.index_dir / f"{index_name}_partition.bin"
|
||||||
|
|
||||||
|
partition_prefix = ""
|
||||||
|
if partition_graph_file.exists() and partition_bin_file.exists():
|
||||||
|
# C++ expects full path prefix, not just filename
|
||||||
|
partition_prefix = str(self.index_dir / index_name) # /path/to/simple_test
|
||||||
|
logger.info(
|
||||||
|
f"✅ Detected partition files, using partition_prefix='{partition_prefix}'"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.debug("No partition files detected, using standard index files")
|
||||||
|
|
||||||
self._init_params = {
|
self._init_params = {
|
||||||
"metric_enum": metric_enum,
|
"metric_enum": metric_enum,
|
||||||
"full_index_prefix": full_index_prefix,
|
"full_index_prefix": full_index_prefix,
|
||||||
"num_threads": self.num_threads,
|
"num_threads": self.num_threads,
|
||||||
"num_nodes_to_cache": kwargs.get("num_nodes_to_cache", 0),
|
"num_nodes_to_cache": kwargs.get("num_nodes_to_cache", 0),
|
||||||
"cache_mechanism": 1,
|
# 1 -> initialize cache using sample_data; 2 -> ready cache without init; others disable cache
|
||||||
|
"cache_mechanism": kwargs.get("cache_mechanism", 1),
|
||||||
"pq_prefix": "",
|
"pq_prefix": "",
|
||||||
"partition_prefix": "",
|
"partition_prefix": partition_prefix,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Log partition configuration for debugging
|
||||||
|
if partition_prefix:
|
||||||
|
logger.info(
|
||||||
|
f"✅ Detected partition files, using partition_prefix='{partition_prefix}'"
|
||||||
|
)
|
||||||
self._diskannpy = diskannpy
|
self._diskannpy = diskannpy
|
||||||
self._current_zmq_port = None
|
self._current_zmq_port = None
|
||||||
self._index = None
|
self._index = None
|
||||||
@@ -259,7 +390,7 @@ class DiskannSearcher(BaseSearcher):
|
|||||||
prune_ratio: float = 0.0,
|
prune_ratio: float = 0.0,
|
||||||
recompute_embeddings: bool = False,
|
recompute_embeddings: bool = False,
|
||||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
zmq_port: int | None = None,
|
zmq_port: Optional[int] = None,
|
||||||
batch_recompute: bool = False,
|
batch_recompute: bool = False,
|
||||||
dedup_node_dis: bool = False,
|
dedup_node_dis: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -311,9 +442,14 @@ class DiskannSearcher(BaseSearcher):
|
|||||||
else: # "global"
|
else: # "global"
|
||||||
use_global_pruning = True
|
use_global_pruning = True
|
||||||
|
|
||||||
# Perform search with suppressed C++ output based on log level
|
# Strategy:
|
||||||
use_deferred_fetch = kwargs.get("USE_DEFERRED_FETCH", True)
|
# - Traversal always uses PQ distances
|
||||||
recompute_neighors = False
|
# - If recompute_embeddings=True, do a single final rerank via deferred fetch
|
||||||
|
# (fetch embeddings for the final candidate set only)
|
||||||
|
# - Do not recompute neighbor distances along the path
|
||||||
|
use_deferred_fetch = True if recompute_embeddings else False
|
||||||
|
recompute_neighors = False # Expected typo. For backward compatibility.
|
||||||
|
|
||||||
with suppress_cpp_output_if_needed():
|
with suppress_cpp_output_if_needed():
|
||||||
labels, distances = self._index.batch_search(
|
labels, distances = self._index.batch_search(
|
||||||
query,
|
query,
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import sys
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import zmq
|
import zmq
|
||||||
@@ -31,8 +32,18 @@ if not logger.handlers:
|
|||||||
logger.propagate = False
|
logger.propagate = False
|
||||||
|
|
||||||
|
|
||||||
|
_RAW_PROVIDER_OPTIONS = os.getenv("LEANN_EMBEDDING_OPTIONS")
|
||||||
|
try:
|
||||||
|
PROVIDER_OPTIONS: dict[str, Any] = (
|
||||||
|
json.loads(_RAW_PROVIDER_OPTIONS) if _RAW_PROVIDER_OPTIONS else {}
|
||||||
|
)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning("Failed to parse LEANN_EMBEDDING_OPTIONS; ignoring provider options")
|
||||||
|
PROVIDER_OPTIONS = {}
|
||||||
|
|
||||||
|
|
||||||
def create_diskann_embedding_server(
|
def create_diskann_embedding_server(
|
||||||
passages_file: str | None = None,
|
passages_file: Optional[str] = None,
|
||||||
zmq_port: int = 5555,
|
zmq_port: int = 5555,
|
||||||
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
||||||
embedding_mode: str = "sentence-transformers",
|
embedding_mode: str = "sentence-transformers",
|
||||||
@@ -80,10 +91,9 @@ def create_diskann_embedding_server(
|
|||||||
with open(passages_file) as f:
|
with open(passages_file) as f:
|
||||||
meta = json.load(f)
|
meta = json.load(f)
|
||||||
|
|
||||||
passages = PassageManager(meta["passage_sources"])
|
logger.info(f"Loading PassageManager with metadata_file_path: {passages_file}")
|
||||||
logger.info(
|
passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file)
|
||||||
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
logger.info(f"Loaded PassageManager with {len(passages)} passages from metadata")
|
||||||
)
|
|
||||||
|
|
||||||
# Import protobuf after ensuring the path is correct
|
# Import protobuf after ensuring the path is correct
|
||||||
try:
|
try:
|
||||||
@@ -101,8 +111,9 @@ def create_diskann_embedding_server(
|
|||||||
socket.bind(f"tcp://*:{zmq_port}")
|
socket.bind(f"tcp://*:{zmq_port}")
|
||||||
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
|
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
|
||||||
|
|
||||||
socket.setsockopt(zmq.RCVTIMEO, 300000)
|
socket.setsockopt(zmq.RCVTIMEO, 1000)
|
||||||
socket.setsockopt(zmq.SNDTIMEO, 300000)
|
socket.setsockopt(zmq.SNDTIMEO, 1000)
|
||||||
|
socket.setsockopt(zmq.LINGER, 0)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@@ -180,7 +191,12 @@ def create_diskann_embedding_server(
|
|||||||
logger.debug(f"Text lengths: {[len(t) for t in texts[:5]]}") # Show first 5
|
logger.debug(f"Text lengths: {[len(t) for t in texts[:5]]}") # Show first 5
|
||||||
|
|
||||||
# Process embeddings using unified computation
|
# Process embeddings using unified computation
|
||||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
embeddings = compute_embeddings(
|
||||||
|
texts,
|
||||||
|
model_name,
|
||||||
|
mode=embedding_mode,
|
||||||
|
provider_options=PROVIDER_OPTIONS,
|
||||||
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||||
)
|
)
|
||||||
@@ -219,30 +235,222 @@ def create_diskann_embedding_server(
|
|||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
|
def zmq_server_thread_with_shutdown(shutdown_event):
|
||||||
|
"""ZMQ server thread that respects shutdown signal.
|
||||||
|
|
||||||
|
This creates its own REP socket, binds to zmq_port, and periodically
|
||||||
|
checks shutdown_event using recv timeouts to exit cleanly.
|
||||||
|
"""
|
||||||
|
logger.info("DiskANN ZMQ server thread started with shutdown support")
|
||||||
|
|
||||||
|
context = zmq.Context()
|
||||||
|
rep_socket = context.socket(zmq.REP)
|
||||||
|
rep_socket.bind(f"tcp://*:{zmq_port}")
|
||||||
|
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
|
||||||
|
|
||||||
|
# Set receive timeout so we can check shutdown_event periodically
|
||||||
|
rep_socket.setsockopt(zmq.RCVTIMEO, 1000) # 1 second timeout
|
||||||
|
rep_socket.setsockopt(zmq.SNDTIMEO, 1000)
|
||||||
|
rep_socket.setsockopt(zmq.LINGER, 0)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while not shutdown_event.is_set():
|
||||||
|
try:
|
||||||
|
e2e_start = time.time()
|
||||||
|
# REP socket receives single-part messages
|
||||||
|
message = rep_socket.recv()
|
||||||
|
|
||||||
|
# Check for empty messages - REP socket requires response to every request
|
||||||
|
if not message:
|
||||||
|
logger.warning("Received empty message, sending empty response")
|
||||||
|
rep_socket.send(b"")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Try protobuf first (same logic as original)
|
||||||
|
texts = []
|
||||||
|
is_text_request = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
req_proto = embedding_pb2.NodeEmbeddingRequest()
|
||||||
|
req_proto.ParseFromString(message)
|
||||||
|
node_ids = list(req_proto.node_ids)
|
||||||
|
|
||||||
|
# Look up texts by node IDs
|
||||||
|
for nid in node_ids:
|
||||||
|
try:
|
||||||
|
passage_data = passages.get_passage(str(nid))
|
||||||
|
txt = passage_data["text"]
|
||||||
|
if not txt:
|
||||||
|
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
|
||||||
|
texts.append(txt)
|
||||||
|
except KeyError:
|
||||||
|
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
|
||||||
|
|
||||||
|
logger.info(f"ZMQ received protobuf request for {len(node_ids)} node IDs")
|
||||||
|
except Exception:
|
||||||
|
# Fallback to msgpack for text requests
|
||||||
|
try:
|
||||||
|
import msgpack
|
||||||
|
|
||||||
|
request = msgpack.unpackb(message)
|
||||||
|
if isinstance(request, list) and all(
|
||||||
|
isinstance(item, str) for item in request
|
||||||
|
):
|
||||||
|
texts = request
|
||||||
|
is_text_request = True
|
||||||
|
logger.info(
|
||||||
|
f"ZMQ received msgpack text request for {len(texts)} texts"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("Not a valid msgpack text request")
|
||||||
|
except Exception:
|
||||||
|
logger.error("Both protobuf and msgpack parsing failed!")
|
||||||
|
# Send error response
|
||||||
|
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||||
|
rep_socket.send(resp_proto.SerializeToString())
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Process the request
|
||||||
|
embeddings = compute_embeddings(
|
||||||
|
texts,
|
||||||
|
model_name,
|
||||||
|
mode=embedding_mode,
|
||||||
|
provider_options=PROVIDER_OPTIONS,
|
||||||
|
)
|
||||||
|
logger.info(f"Computed embeddings shape: {embeddings.shape}")
|
||||||
|
|
||||||
|
# Validation
|
||||||
|
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||||
|
logger.error("NaN or Inf detected in embeddings!")
|
||||||
|
# Send error response
|
||||||
|
if is_text_request:
|
||||||
|
import msgpack
|
||||||
|
|
||||||
|
response_data = msgpack.packb([])
|
||||||
|
else:
|
||||||
|
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||||
|
response_data = resp_proto.SerializeToString()
|
||||||
|
rep_socket.send(response_data)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Prepare response based on request type
|
||||||
|
if is_text_request:
|
||||||
|
# For direct text requests, return msgpack
|
||||||
|
import msgpack
|
||||||
|
|
||||||
|
response_data = msgpack.packb(embeddings.tolist())
|
||||||
|
else:
|
||||||
|
# For protobuf requests, return protobuf
|
||||||
|
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||||
|
hidden_contiguous = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||||
|
|
||||||
|
resp_proto.embeddings_data = hidden_contiguous.tobytes()
|
||||||
|
resp_proto.dimensions.append(hidden_contiguous.shape[0])
|
||||||
|
resp_proto.dimensions.append(hidden_contiguous.shape[1])
|
||||||
|
|
||||||
|
response_data = resp_proto.SerializeToString()
|
||||||
|
|
||||||
|
# Send response back to the client
|
||||||
|
rep_socket.send(response_data)
|
||||||
|
|
||||||
|
e2e_end = time.time()
|
||||||
|
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
|
||||||
|
|
||||||
|
except zmq.Again:
|
||||||
|
# Timeout - check shutdown_event and continue
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
if not shutdown_event.is_set():
|
||||||
|
logger.error(f"Error in ZMQ server loop: {e}")
|
||||||
|
try:
|
||||||
|
# Send error response for REP socket
|
||||||
|
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||||
|
rep_socket.send(resp_proto.SerializeToString())
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
logger.info("Shutdown in progress, ignoring ZMQ error")
|
||||||
|
break
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
rep_socket.close(0)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
context.term()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
logger.info("DiskANN ZMQ server thread exiting gracefully")
|
||||||
|
|
||||||
|
# Add shutdown coordination
|
||||||
|
shutdown_event = threading.Event()
|
||||||
|
|
||||||
|
def shutdown_zmq_server():
|
||||||
|
"""Gracefully shutdown ZMQ server."""
|
||||||
|
logger.info("Initiating graceful shutdown...")
|
||||||
|
shutdown_event.set()
|
||||||
|
|
||||||
|
if zmq_thread.is_alive():
|
||||||
|
logger.info("Waiting for ZMQ thread to finish...")
|
||||||
|
zmq_thread.join(timeout=5)
|
||||||
|
if zmq_thread.is_alive():
|
||||||
|
logger.warning("ZMQ thread did not finish in time")
|
||||||
|
|
||||||
|
# Clean up ZMQ resources
|
||||||
|
try:
|
||||||
|
# Note: socket and context are cleaned up by thread exit
|
||||||
|
logger.info("ZMQ resources cleaned up")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error cleaning ZMQ resources: {e}")
|
||||||
|
|
||||||
|
# Clean up other resources
|
||||||
|
try:
|
||||||
|
import gc
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
logger.info("Additional resources cleaned up")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error cleaning additional resources: {e}")
|
||||||
|
|
||||||
|
logger.info("Graceful shutdown completed")
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
# Register signal handlers within this function scope
|
||||||
|
import signal
|
||||||
|
|
||||||
|
def signal_handler(sig, frame):
|
||||||
|
logger.info(f"Received signal {sig}, shutting down gracefully...")
|
||||||
|
shutdown_zmq_server()
|
||||||
|
|
||||||
|
signal.signal(signal.SIGTERM, signal_handler)
|
||||||
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
|
|
||||||
|
# Start ZMQ thread (NOT daemon!)
|
||||||
|
zmq_thread = threading.Thread(
|
||||||
|
target=lambda: zmq_server_thread_with_shutdown(shutdown_event),
|
||||||
|
daemon=False, # Not daemon - we want to wait for it
|
||||||
|
)
|
||||||
zmq_thread.start()
|
zmq_thread.start()
|
||||||
logger.info(f"Started DiskANN ZMQ server thread on port {zmq_port}")
|
logger.info(f"Started DiskANN ZMQ server thread on port {zmq_port}")
|
||||||
|
|
||||||
# Keep the main thread alive
|
# Keep the main thread alive
|
||||||
try:
|
try:
|
||||||
while True:
|
while not shutdown_event.is_set():
|
||||||
time.sleep(1)
|
time.sleep(0.1) # Check shutdown more frequently
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logger.info("DiskANN Server shutting down...")
|
logger.info("DiskANN Server shutting down...")
|
||||||
|
shutdown_zmq_server()
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# If we reach here, shutdown was triggered by signal
|
||||||
|
logger.info("Main loop exited, process should be shutting down")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import signal
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
def signal_handler(sig, frame):
|
# Signal handlers are now registered within create_diskann_embedding_server
|
||||||
logger.info(f"Received signal {sig}, shutting down gracefully...")
|
|
||||||
sys.exit(0)
|
|
||||||
|
|
||||||
# Register signal handlers for graceful shutdown
|
|
||||||
signal.signal(signal.SIGTERM, signal_handler)
|
|
||||||
signal.signal(signal.SIGINT, signal_handler)
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="DiskANN Embedding service")
|
parser = argparse.ArgumentParser(description="DiskANN Embedding service")
|
||||||
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
||||||
@@ -261,7 +469,7 @@ if __name__ == "__main__":
|
|||||||
"--embedding-mode",
|
"--embedding-mode",
|
||||||
type=str,
|
type=str,
|
||||||
default="sentence-transformers",
|
default="sentence-transformers",
|
||||||
choices=["sentence-transformers", "openai", "mlx"],
|
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||||
help="Embedding backend mode",
|
help="Embedding backend mode",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
@@ -0,0 +1,299 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Graph Partition Module for LEANN DiskANN Backend
|
||||||
|
|
||||||
|
This module provides Python bindings for the graph partition functionality
|
||||||
|
of DiskANN, allowing users to partition disk-based indices for better
|
||||||
|
performance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class GraphPartitioner:
|
||||||
|
"""
|
||||||
|
A Python interface for DiskANN's graph partition functionality.
|
||||||
|
|
||||||
|
This class provides methods to partition disk-based indices for improved
|
||||||
|
search performance and memory efficiency.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, build_type: str = "release"):
|
||||||
|
"""
|
||||||
|
Initialize the GraphPartitioner.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
build_type: Build type for the executables ("debug" or "release")
|
||||||
|
"""
|
||||||
|
self.build_type = build_type
|
||||||
|
self._ensure_executables()
|
||||||
|
|
||||||
|
def _get_executable_path(self, name: str) -> str:
|
||||||
|
"""Get the path to a graph partition executable."""
|
||||||
|
# Get the directory where this Python module is located
|
||||||
|
module_dir = Path(__file__).parent
|
||||||
|
# Navigate to the graph_partition directory
|
||||||
|
graph_partition_dir = module_dir.parent / "third_party" / "DiskANN" / "graph_partition"
|
||||||
|
executable_path = graph_partition_dir / "build" / self.build_type / "graph_partition" / name
|
||||||
|
|
||||||
|
if not executable_path.exists():
|
||||||
|
raise FileNotFoundError(f"Executable {name} not found at {executable_path}")
|
||||||
|
|
||||||
|
return str(executable_path)
|
||||||
|
|
||||||
|
def _ensure_executables(self):
|
||||||
|
"""Ensure that the required executables are built."""
|
||||||
|
try:
|
||||||
|
self._get_executable_path("partitioner")
|
||||||
|
self._get_executable_path("index_relayout")
|
||||||
|
except FileNotFoundError:
|
||||||
|
# Try to build the executables automatically
|
||||||
|
print("Executables not found, attempting to build them...")
|
||||||
|
self._build_executables()
|
||||||
|
|
||||||
|
def _build_executables(self):
|
||||||
|
"""Build the required executables."""
|
||||||
|
graph_partition_dir = (
|
||||||
|
Path(__file__).parent.parent / "third_party" / "DiskANN" / "graph_partition"
|
||||||
|
)
|
||||||
|
original_dir = os.getcwd()
|
||||||
|
|
||||||
|
try:
|
||||||
|
os.chdir(graph_partition_dir)
|
||||||
|
|
||||||
|
# Clean any existing build
|
||||||
|
if (graph_partition_dir / "build").exists():
|
||||||
|
shutil.rmtree(graph_partition_dir / "build")
|
||||||
|
|
||||||
|
# Run the build script
|
||||||
|
cmd = ["./build.sh", self.build_type, "split_graph", "/tmp/dummy"]
|
||||||
|
subprocess.run(cmd, capture_output=True, text=True, cwd=graph_partition_dir)
|
||||||
|
|
||||||
|
# Check if executables were created
|
||||||
|
partitioner_path = self._get_executable_path("partitioner")
|
||||||
|
relayout_path = self._get_executable_path("index_relayout")
|
||||||
|
|
||||||
|
print(f"✅ Built partitioner: {partitioner_path}")
|
||||||
|
print(f"✅ Built index_relayout: {relayout_path}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Failed to build executables: {e}")
|
||||||
|
finally:
|
||||||
|
os.chdir(original_dir)
|
||||||
|
|
||||||
|
def partition_graph(
|
||||||
|
self,
|
||||||
|
index_prefix_path: str,
|
||||||
|
output_dir: Optional[str] = None,
|
||||||
|
partition_prefix: Optional[str] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> tuple[str, str]:
|
||||||
|
"""
|
||||||
|
Partition a disk-based index for improved performance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index_prefix_path: Path to the index prefix (e.g., "/path/to/index")
|
||||||
|
output_dir: Output directory for results (defaults to parent of index_prefix_path)
|
||||||
|
partition_prefix: Prefix for output files (defaults to basename of index_prefix_path)
|
||||||
|
**kwargs: Additional parameters for graph partitioning:
|
||||||
|
- gp_times: Number of LDG partition iterations (default: 10)
|
||||||
|
- lock_nums: Number of lock nodes (default: 10)
|
||||||
|
- cut: Cut adjacency list degree (default: 100)
|
||||||
|
- scale_factor: Scale factor (default: 1)
|
||||||
|
- data_type: Data type (default: "float")
|
||||||
|
- thread_nums: Number of threads (default: 10)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (disk_graph_index_path, partition_bin_path)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If the partitioning process fails
|
||||||
|
"""
|
||||||
|
# Set default parameters
|
||||||
|
params = {
|
||||||
|
"gp_times": 10,
|
||||||
|
"lock_nums": 10,
|
||||||
|
"cut": 100,
|
||||||
|
"scale_factor": 1,
|
||||||
|
"data_type": "float",
|
||||||
|
"thread_nums": 10,
|
||||||
|
**kwargs,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Determine output directory
|
||||||
|
if output_dir is None:
|
||||||
|
output_dir = str(Path(index_prefix_path).parent)
|
||||||
|
|
||||||
|
# Create output directory if it doesn't exist
|
||||||
|
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Determine partition prefix
|
||||||
|
if partition_prefix is None:
|
||||||
|
partition_prefix = Path(index_prefix_path).name
|
||||||
|
|
||||||
|
# Get executable paths
|
||||||
|
partitioner_path = self._get_executable_path("partitioner")
|
||||||
|
relayout_path = self._get_executable_path("index_relayout")
|
||||||
|
|
||||||
|
# Create temporary directory for processing
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
# Change to the graph_partition directory for temporary files
|
||||||
|
graph_partition_dir = (
|
||||||
|
Path(__file__).parent.parent / "third_party" / "DiskANN" / "graph_partition"
|
||||||
|
)
|
||||||
|
original_dir = os.getcwd()
|
||||||
|
|
||||||
|
try:
|
||||||
|
os.chdir(graph_partition_dir)
|
||||||
|
|
||||||
|
# Create temporary data directory
|
||||||
|
temp_data_dir = Path(temp_dir) / "data"
|
||||||
|
temp_data_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Set up paths for temporary files
|
||||||
|
graph_path = temp_data_dir / "starling" / "_M_R_L_B" / "GRAPH"
|
||||||
|
graph_gp_path = (
|
||||||
|
graph_path
|
||||||
|
/ f"GP_TIMES_{params['gp_times']}_LOCK_{params['lock_nums']}_GP_USE_FREQ0_CUT{params['cut']}_SCALE{params['scale_factor']}"
|
||||||
|
)
|
||||||
|
graph_gp_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Find input index file
|
||||||
|
old_index_file = f"{index_prefix_path}_disk_beam_search.index"
|
||||||
|
if not os.path.exists(old_index_file):
|
||||||
|
old_index_file = f"{index_prefix_path}_disk.index"
|
||||||
|
|
||||||
|
if not os.path.exists(old_index_file):
|
||||||
|
raise RuntimeError(f"Index file not found: {old_index_file}")
|
||||||
|
|
||||||
|
# Run partitioner
|
||||||
|
gp_file_path = graph_gp_path / "_part.bin"
|
||||||
|
partitioner_cmd = [
|
||||||
|
partitioner_path,
|
||||||
|
"--index_file",
|
||||||
|
old_index_file,
|
||||||
|
"--data_type",
|
||||||
|
params["data_type"],
|
||||||
|
"--gp_file",
|
||||||
|
str(gp_file_path),
|
||||||
|
"-T",
|
||||||
|
str(params["thread_nums"]),
|
||||||
|
"--ldg_times",
|
||||||
|
str(params["gp_times"]),
|
||||||
|
"--scale",
|
||||||
|
str(params["scale_factor"]),
|
||||||
|
"--mode",
|
||||||
|
"1",
|
||||||
|
]
|
||||||
|
|
||||||
|
print(f"Running partitioner: {' '.join(partitioner_cmd)}")
|
||||||
|
result = subprocess.run(
|
||||||
|
partitioner_cmd, capture_output=True, text=True, cwd=graph_partition_dir
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.returncode != 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Partitioner failed with return code {result.returncode}.\n"
|
||||||
|
f"stdout: {result.stdout}\n"
|
||||||
|
f"stderr: {result.stderr}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run relayout
|
||||||
|
part_tmp_index = graph_gp_path / "_part_tmp.index"
|
||||||
|
relayout_cmd = [
|
||||||
|
relayout_path,
|
||||||
|
old_index_file,
|
||||||
|
str(gp_file_path),
|
||||||
|
params["data_type"],
|
||||||
|
"1",
|
||||||
|
]
|
||||||
|
|
||||||
|
print(f"Running relayout: {' '.join(relayout_cmd)}")
|
||||||
|
result = subprocess.run(
|
||||||
|
relayout_cmd, capture_output=True, text=True, cwd=graph_partition_dir
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.returncode != 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Relayout failed with return code {result.returncode}.\n"
|
||||||
|
f"stdout: {result.stdout}\n"
|
||||||
|
f"stderr: {result.stderr}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Copy results to output directory
|
||||||
|
disk_graph_path = Path(output_dir) / f"{partition_prefix}_disk_graph.index"
|
||||||
|
partition_bin_path = Path(output_dir) / f"{partition_prefix}_partition.bin"
|
||||||
|
|
||||||
|
shutil.copy2(part_tmp_index, disk_graph_path)
|
||||||
|
shutil.copy2(gp_file_path, partition_bin_path)
|
||||||
|
|
||||||
|
print(f"Results copied to: {output_dir}")
|
||||||
|
return str(disk_graph_path), str(partition_bin_path)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
os.chdir(original_dir)
|
||||||
|
|
||||||
|
def get_partition_info(self, partition_bin_path: str) -> dict:
|
||||||
|
"""
|
||||||
|
Get information about a partition file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
partition_bin_path: Path to the partition binary file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary containing partition information
|
||||||
|
"""
|
||||||
|
if not os.path.exists(partition_bin_path):
|
||||||
|
raise FileNotFoundError(f"Partition file not found: {partition_bin_path}")
|
||||||
|
|
||||||
|
# For now, return basic file information
|
||||||
|
# In the future, this could parse the binary file for detailed info
|
||||||
|
stat = os.stat(partition_bin_path)
|
||||||
|
return {
|
||||||
|
"file_size": stat.st_size,
|
||||||
|
"file_path": partition_bin_path,
|
||||||
|
"modified_time": stat.st_mtime,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def partition_graph(
|
||||||
|
index_prefix_path: str,
|
||||||
|
output_dir: Optional[str] = None,
|
||||||
|
partition_prefix: Optional[str] = None,
|
||||||
|
build_type: str = "release",
|
||||||
|
**kwargs,
|
||||||
|
) -> tuple[str, str]:
|
||||||
|
"""
|
||||||
|
Convenience function to partition a graph index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index_prefix_path: Path to the index prefix
|
||||||
|
output_dir: Output directory (defaults to parent of index_prefix_path)
|
||||||
|
partition_prefix: Prefix for output files (defaults to basename of index_prefix_path)
|
||||||
|
build_type: Build type for executables ("debug" or "release")
|
||||||
|
**kwargs: Additional parameters for graph partitioning
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (disk_graph_index_path, partition_bin_path)
|
||||||
|
"""
|
||||||
|
partitioner = GraphPartitioner(build_type=build_type)
|
||||||
|
return partitioner.partition_graph(index_prefix_path, output_dir, partition_prefix, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# Example usage:
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Example: partition an index
|
||||||
|
try:
|
||||||
|
disk_graph_path, partition_bin_path = partition_graph(
|
||||||
|
"/path/to/your/index_prefix", gp_times=10, lock_nums=10, cut=100
|
||||||
|
)
|
||||||
|
print("Partitioning completed successfully!")
|
||||||
|
print(f"Disk graph index: {disk_graph_path}")
|
||||||
|
print(f"Partition binary: {partition_bin_path}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Partitioning failed: {e}")
|
||||||
@@ -1,11 +1,11 @@
|
|||||||
[build-system]
|
[build-system]
|
||||||
requires = ["scikit-build-core>=0.10", "pybind11>=2.12.0", "numpy"]
|
requires = ["scikit-build-core>=0.10", "pybind11>=2.12.0", "numpy", "cmake>=3.30"]
|
||||||
build-backend = "scikit_build_core.build"
|
build-backend = "scikit_build_core.build"
|
||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann-backend-diskann"
|
name = "leann-backend-diskann"
|
||||||
version = "0.2.1"
|
version = "0.3.4"
|
||||||
dependencies = ["leann-core==0.2.1", "numpy", "protobuf>=3.19.0"]
|
dependencies = ["leann-core==0.3.4", "numpy", "protobuf>=3.19.0"]
|
||||||
|
|
||||||
[tool.scikit-build]
|
[tool.scikit-build]
|
||||||
# Key: simplified CMake path
|
# Key: simplified CMake path
|
||||||
@@ -17,3 +17,5 @@ editable.mode = "redirect"
|
|||||||
cmake.build-type = "Release"
|
cmake.build-type = "Release"
|
||||||
build.verbose = true
|
build.verbose = true
|
||||||
build.tool-args = ["-j8"]
|
build.tool-args = ["-j8"]
|
||||||
|
# Let CMake find packages via Homebrew prefix
|
||||||
|
cmake.define = {CMAKE_PREFIX_PATH = {env = "CMAKE_PREFIX_PATH"}, OpenMP_ROOT = {env = "OpenMP_ROOT"}}
|
||||||
|
|||||||
Submodule packages/leann-backend-diskann/third_party/DiskANN updated: af2a26481e...19f9603c72
@@ -5,11 +5,20 @@ set(CMAKE_CXX_COMPILER_WORKS 1)
|
|||||||
|
|
||||||
# Set OpenMP path for macOS
|
# Set OpenMP path for macOS
|
||||||
if(APPLE)
|
if(APPLE)
|
||||||
set(OpenMP_C_FLAGS "-Xpreprocessor -fopenmp -I/opt/homebrew/opt/libomp/include")
|
# Detect Homebrew installation path (Apple Silicon vs Intel)
|
||||||
set(OpenMP_CXX_FLAGS "-Xpreprocessor -fopenmp -I/opt/homebrew/opt/libomp/include")
|
if(EXISTS "/opt/homebrew/opt/libomp")
|
||||||
|
set(HOMEBREW_PREFIX "/opt/homebrew")
|
||||||
|
elseif(EXISTS "/usr/local/opt/libomp")
|
||||||
|
set(HOMEBREW_PREFIX "/usr/local")
|
||||||
|
else()
|
||||||
|
message(FATAL_ERROR "Could not find libomp installation. Please install with: brew install libomp")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
set(OpenMP_C_FLAGS "-Xpreprocessor -fopenmp -I${HOMEBREW_PREFIX}/opt/libomp/include")
|
||||||
|
set(OpenMP_CXX_FLAGS "-Xpreprocessor -fopenmp -I${HOMEBREW_PREFIX}/opt/libomp/include")
|
||||||
set(OpenMP_C_LIB_NAMES "omp")
|
set(OpenMP_C_LIB_NAMES "omp")
|
||||||
set(OpenMP_CXX_LIB_NAMES "omp")
|
set(OpenMP_CXX_LIB_NAMES "omp")
|
||||||
set(OpenMP_omp_LIBRARY "/opt/homebrew/opt/libomp/lib/libomp.dylib")
|
set(OpenMP_omp_LIBRARY "${HOMEBREW_PREFIX}/opt/libomp/lib/libomp.dylib")
|
||||||
|
|
||||||
# Force use of system libc++ to avoid version mismatch
|
# Force use of system libc++ to avoid version mismatch
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -stdlib=libc++")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -stdlib=libc++")
|
||||||
@@ -40,9 +49,28 @@ set(BUILD_TESTING OFF CACHE BOOL "" FORCE)
|
|||||||
set(FAISS_ENABLE_C_API OFF CACHE BOOL "" FORCE)
|
set(FAISS_ENABLE_C_API OFF CACHE BOOL "" FORCE)
|
||||||
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
|
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
|
||||||
|
|
||||||
# Disable additional SIMD versions to speed up compilation
|
# Disable x86-specific SIMD optimizations (important for ARM64 compatibility)
|
||||||
set(FAISS_ENABLE_AVX2 OFF CACHE BOOL "" FORCE)
|
set(FAISS_ENABLE_AVX2 OFF CACHE BOOL "" FORCE)
|
||||||
set(FAISS_ENABLE_AVX512 OFF CACHE BOOL "" FORCE)
|
set(FAISS_ENABLE_AVX512 OFF CACHE BOOL "" FORCE)
|
||||||
|
set(FAISS_ENABLE_SSE4_1 OFF CACHE BOOL "" FORCE)
|
||||||
|
|
||||||
|
# ARM64-specific configuration
|
||||||
|
if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64")
|
||||||
|
message(STATUS "Configuring Faiss for ARM64 architecture")
|
||||||
|
|
||||||
|
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
|
||||||
|
# Use SVE optimization level for ARM64 Linux (as seen in Faiss conda build)
|
||||||
|
set(FAISS_OPT_LEVEL "sve" CACHE STRING "" FORCE)
|
||||||
|
message(STATUS "Setting FAISS_OPT_LEVEL to 'sve' for ARM64 Linux")
|
||||||
|
else()
|
||||||
|
# Use generic optimization for other ARM64 platforms (like macOS)
|
||||||
|
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
|
||||||
|
message(STATUS "Setting FAISS_OPT_LEVEL to 'generic' for ARM64 ${CMAKE_SYSTEM_NAME}")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# ARM64 compatibility: Faiss submodule has been modified to fix x86 header inclusion
|
||||||
|
message(STATUS "Using ARM64-compatible Faiss submodule")
|
||||||
|
endif()
|
||||||
|
|
||||||
# Additional optimization options from INSTALL.md
|
# Additional optimization options from INSTALL.md
|
||||||
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "" FORCE)
|
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "" FORCE)
|
||||||
|
|||||||
@@ -1,12 +1,21 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import gc # Import garbage collector interface
|
import gc # Import garbage collector interface
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import struct
|
import struct
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
# Set up logging to avoid print buffer issues
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||||
|
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
|
||||||
|
logger.setLevel(log_level)
|
||||||
|
|
||||||
# --- FourCCs (add more if needed) ---
|
# --- FourCCs (add more if needed) ---
|
||||||
INDEX_HNSW_FLAT_FOURCC = int.from_bytes(b"IHNf", "little")
|
INDEX_HNSW_FLAT_FOURCC = int.from_bytes(b"IHNf", "little")
|
||||||
# Add other HNSW fourccs if you expect different storage types inside HNSW
|
# Add other HNSW fourccs if you expect different storage types inside HNSW
|
||||||
@@ -230,6 +239,288 @@ def write_compact_format(
|
|||||||
f_out.write(storage_data)
|
f_out.write(storage_data)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HNSWComponents:
|
||||||
|
original_hnsw_data: dict[str, Any]
|
||||||
|
assign_probas_np: np.ndarray
|
||||||
|
cum_nneighbor_per_level_np: np.ndarray
|
||||||
|
levels_np: np.ndarray
|
||||||
|
is_compact: bool
|
||||||
|
compact_level_ptr: Optional[np.ndarray] = None
|
||||||
|
compact_node_offsets_np: Optional[np.ndarray] = None
|
||||||
|
compact_neighbors_data: Optional[list[int]] = None
|
||||||
|
offsets_np: Optional[np.ndarray] = None
|
||||||
|
neighbors_np: Optional[np.ndarray] = None
|
||||||
|
storage_fourcc: int = NULL_INDEX_FOURCC
|
||||||
|
storage_data: bytes = b""
|
||||||
|
|
||||||
|
|
||||||
|
def _read_hnsw_structure(f) -> HNSWComponents:
|
||||||
|
original_hnsw_data: dict[str, Any] = {}
|
||||||
|
|
||||||
|
hnsw_index_fourcc = read_struct(f, "<I")
|
||||||
|
if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unexpected HNSW FourCC: {hnsw_index_fourcc:08x}. Expected one of {EXPECTED_HNSW_FOURCCS}."
|
||||||
|
)
|
||||||
|
|
||||||
|
original_hnsw_data["index_fourcc"] = hnsw_index_fourcc
|
||||||
|
original_hnsw_data["d"] = read_struct(f, "<i")
|
||||||
|
original_hnsw_data["ntotal"] = read_struct(f, "<q")
|
||||||
|
original_hnsw_data["dummy1"] = read_struct(f, "<q")
|
||||||
|
original_hnsw_data["dummy2"] = read_struct(f, "<q")
|
||||||
|
original_hnsw_data["is_trained"] = read_struct(f, "?")
|
||||||
|
original_hnsw_data["metric_type"] = read_struct(f, "<i")
|
||||||
|
original_hnsw_data["metric_arg"] = 0.0
|
||||||
|
if original_hnsw_data["metric_type"] > 1:
|
||||||
|
original_hnsw_data["metric_arg"] = read_struct(f, "<f")
|
||||||
|
|
||||||
|
assign_probas_np = read_numpy_vector(f, np.float64, "d")
|
||||||
|
cum_nneighbor_per_level_np = read_numpy_vector(f, np.int32, "i")
|
||||||
|
levels_np = read_numpy_vector(f, np.int32, "i")
|
||||||
|
|
||||||
|
ntotal = len(levels_np)
|
||||||
|
if ntotal != original_hnsw_data["ntotal"]:
|
||||||
|
original_hnsw_data["ntotal"] = ntotal
|
||||||
|
|
||||||
|
pos_before_compact = f.tell()
|
||||||
|
is_compact_flag = None
|
||||||
|
try:
|
||||||
|
is_compact_flag = read_struct(f, "<?")
|
||||||
|
except EOFError:
|
||||||
|
is_compact_flag = None
|
||||||
|
|
||||||
|
if is_compact_flag:
|
||||||
|
compact_level_ptr = read_numpy_vector(f, np.uint64, "Q")
|
||||||
|
compact_node_offsets_np = read_numpy_vector(f, np.uint64, "Q")
|
||||||
|
|
||||||
|
original_hnsw_data["entry_point"] = read_struct(f, "<i")
|
||||||
|
original_hnsw_data["max_level"] = read_struct(f, "<i")
|
||||||
|
original_hnsw_data["efConstruction"] = read_struct(f, "<i")
|
||||||
|
original_hnsw_data["efSearch"] = read_struct(f, "<i")
|
||||||
|
original_hnsw_data["dummy_upper_beam"] = read_struct(f, "<i")
|
||||||
|
|
||||||
|
storage_fourcc = read_struct(f, "<I")
|
||||||
|
compact_neighbors_data_np = read_numpy_vector(f, np.int32, "i")
|
||||||
|
compact_neighbors_data = compact_neighbors_data_np.tolist()
|
||||||
|
storage_data = f.read()
|
||||||
|
|
||||||
|
return HNSWComponents(
|
||||||
|
original_hnsw_data=original_hnsw_data,
|
||||||
|
assign_probas_np=assign_probas_np,
|
||||||
|
cum_nneighbor_per_level_np=cum_nneighbor_per_level_np,
|
||||||
|
levels_np=levels_np,
|
||||||
|
is_compact=True,
|
||||||
|
compact_level_ptr=compact_level_ptr,
|
||||||
|
compact_node_offsets_np=compact_node_offsets_np,
|
||||||
|
compact_neighbors_data=compact_neighbors_data,
|
||||||
|
storage_fourcc=storage_fourcc,
|
||||||
|
storage_data=storage_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Non-compact case
|
||||||
|
f.seek(pos_before_compact)
|
||||||
|
|
||||||
|
pos_before_probe = f.tell()
|
||||||
|
try:
|
||||||
|
suspected_flag = read_struct(f, "<B")
|
||||||
|
if suspected_flag != 0x00:
|
||||||
|
f.seek(pos_before_probe)
|
||||||
|
except EOFError:
|
||||||
|
f.seek(pos_before_probe)
|
||||||
|
|
||||||
|
offsets_np = read_numpy_vector(f, np.uint64, "Q")
|
||||||
|
neighbors_np = read_numpy_vector(f, np.int32, "i")
|
||||||
|
|
||||||
|
original_hnsw_data["entry_point"] = read_struct(f, "<i")
|
||||||
|
original_hnsw_data["max_level"] = read_struct(f, "<i")
|
||||||
|
original_hnsw_data["efConstruction"] = read_struct(f, "<i")
|
||||||
|
original_hnsw_data["efSearch"] = read_struct(f, "<i")
|
||||||
|
original_hnsw_data["dummy_upper_beam"] = read_struct(f, "<i")
|
||||||
|
|
||||||
|
storage_fourcc = NULL_INDEX_FOURCC
|
||||||
|
storage_data = b""
|
||||||
|
try:
|
||||||
|
storage_fourcc = read_struct(f, "<I")
|
||||||
|
storage_data = f.read()
|
||||||
|
except EOFError:
|
||||||
|
storage_fourcc = NULL_INDEX_FOURCC
|
||||||
|
|
||||||
|
return HNSWComponents(
|
||||||
|
original_hnsw_data=original_hnsw_data,
|
||||||
|
assign_probas_np=assign_probas_np,
|
||||||
|
cum_nneighbor_per_level_np=cum_nneighbor_per_level_np,
|
||||||
|
levels_np=levels_np,
|
||||||
|
is_compact=False,
|
||||||
|
offsets_np=offsets_np,
|
||||||
|
neighbors_np=neighbors_np,
|
||||||
|
storage_fourcc=storage_fourcc,
|
||||||
|
storage_data=storage_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _read_hnsw_structure_from_file(path: str) -> HNSWComponents:
|
||||||
|
with open(path, "rb") as f:
|
||||||
|
return _read_hnsw_structure(f)
|
||||||
|
|
||||||
|
|
||||||
|
def write_original_format(
|
||||||
|
f_out,
|
||||||
|
original_hnsw_data,
|
||||||
|
assign_probas_np,
|
||||||
|
cum_nneighbor_per_level_np,
|
||||||
|
levels_np,
|
||||||
|
offsets_np,
|
||||||
|
neighbors_np,
|
||||||
|
storage_fourcc,
|
||||||
|
storage_data,
|
||||||
|
):
|
||||||
|
"""Write non-compact HNSW data in original FAISS order."""
|
||||||
|
|
||||||
|
f_out.write(struct.pack("<I", original_hnsw_data["index_fourcc"]))
|
||||||
|
f_out.write(struct.pack("<i", original_hnsw_data["d"]))
|
||||||
|
f_out.write(struct.pack("<q", original_hnsw_data["ntotal"]))
|
||||||
|
f_out.write(struct.pack("<q", original_hnsw_data["dummy1"]))
|
||||||
|
f_out.write(struct.pack("<q", original_hnsw_data["dummy2"]))
|
||||||
|
f_out.write(struct.pack("<?", original_hnsw_data["is_trained"]))
|
||||||
|
f_out.write(struct.pack("<i", original_hnsw_data["metric_type"]))
|
||||||
|
if original_hnsw_data["metric_type"] > 1:
|
||||||
|
f_out.write(struct.pack("<f", original_hnsw_data["metric_arg"]))
|
||||||
|
|
||||||
|
write_numpy_vector(f_out, assign_probas_np, "d")
|
||||||
|
write_numpy_vector(f_out, cum_nneighbor_per_level_np, "i")
|
||||||
|
write_numpy_vector(f_out, levels_np, "i")
|
||||||
|
|
||||||
|
write_numpy_vector(f_out, offsets_np, "Q")
|
||||||
|
write_numpy_vector(f_out, neighbors_np, "i")
|
||||||
|
|
||||||
|
f_out.write(struct.pack("<i", original_hnsw_data["entry_point"]))
|
||||||
|
f_out.write(struct.pack("<i", original_hnsw_data["max_level"]))
|
||||||
|
f_out.write(struct.pack("<i", original_hnsw_data["efConstruction"]))
|
||||||
|
f_out.write(struct.pack("<i", original_hnsw_data["efSearch"]))
|
||||||
|
f_out.write(struct.pack("<i", original_hnsw_data["dummy_upper_beam"]))
|
||||||
|
|
||||||
|
f_out.write(struct.pack("<I", storage_fourcc))
|
||||||
|
if storage_fourcc != NULL_INDEX_FOURCC and storage_data:
|
||||||
|
f_out.write(storage_data)
|
||||||
|
|
||||||
|
|
||||||
|
def prune_hnsw_embeddings(input_filename: str, output_filename: str) -> bool:
|
||||||
|
"""Rewrite an HNSW index while dropping the embedded storage section."""
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
with open(input_filename, "rb") as f_in, open(output_filename, "wb") as f_out:
|
||||||
|
original_hnsw_data: dict[str, Any] = {}
|
||||||
|
|
||||||
|
hnsw_index_fourcc = read_struct(f_in, "<I")
|
||||||
|
if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS:
|
||||||
|
print(
|
||||||
|
f"Error: Expected HNSW Index FourCC ({list(EXPECTED_HNSW_FOURCCS)}), got {hnsw_index_fourcc:08x}.",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
original_hnsw_data["index_fourcc"] = hnsw_index_fourcc
|
||||||
|
original_hnsw_data["d"] = read_struct(f_in, "<i")
|
||||||
|
original_hnsw_data["ntotal"] = read_struct(f_in, "<q")
|
||||||
|
original_hnsw_data["dummy1"] = read_struct(f_in, "<q")
|
||||||
|
original_hnsw_data["dummy2"] = read_struct(f_in, "<q")
|
||||||
|
original_hnsw_data["is_trained"] = read_struct(f_in, "?")
|
||||||
|
original_hnsw_data["metric_type"] = read_struct(f_in, "<i")
|
||||||
|
original_hnsw_data["metric_arg"] = 0.0
|
||||||
|
if original_hnsw_data["metric_type"] > 1:
|
||||||
|
original_hnsw_data["metric_arg"] = read_struct(f_in, "<f")
|
||||||
|
|
||||||
|
assign_probas_np = read_numpy_vector(f_in, np.float64, "d")
|
||||||
|
cum_nneighbor_per_level_np = read_numpy_vector(f_in, np.int32, "i")
|
||||||
|
levels_np = read_numpy_vector(f_in, np.int32, "i")
|
||||||
|
|
||||||
|
ntotal = len(levels_np)
|
||||||
|
if ntotal != original_hnsw_data["ntotal"]:
|
||||||
|
original_hnsw_data["ntotal"] = ntotal
|
||||||
|
|
||||||
|
pos_before_compact = f_in.tell()
|
||||||
|
is_compact_flag = None
|
||||||
|
try:
|
||||||
|
is_compact_flag = read_struct(f_in, "<?")
|
||||||
|
except EOFError:
|
||||||
|
is_compact_flag = None
|
||||||
|
|
||||||
|
if is_compact_flag:
|
||||||
|
compact_level_ptr = read_numpy_vector(f_in, np.uint64, "Q")
|
||||||
|
compact_node_offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
|
||||||
|
|
||||||
|
original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
|
||||||
|
original_hnsw_data["max_level"] = read_struct(f_in, "<i")
|
||||||
|
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
|
||||||
|
original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
|
||||||
|
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
|
||||||
|
|
||||||
|
_storage_fourcc = read_struct(f_in, "<I")
|
||||||
|
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i")
|
||||||
|
compact_neighbors_data = compact_neighbors_data_np.tolist()
|
||||||
|
_storage_data = f_in.read()
|
||||||
|
|
||||||
|
write_compact_format(
|
||||||
|
f_out,
|
||||||
|
original_hnsw_data,
|
||||||
|
assign_probas_np,
|
||||||
|
cum_nneighbor_per_level_np,
|
||||||
|
levels_np,
|
||||||
|
compact_level_ptr,
|
||||||
|
compact_node_offsets_np,
|
||||||
|
compact_neighbors_data,
|
||||||
|
NULL_INDEX_FOURCC,
|
||||||
|
b"",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
f_in.seek(pos_before_compact)
|
||||||
|
|
||||||
|
pos_before_probe = f_in.tell()
|
||||||
|
try:
|
||||||
|
suspected_flag = read_struct(f_in, "<B")
|
||||||
|
if suspected_flag != 0x00:
|
||||||
|
f_in.seek(pos_before_probe)
|
||||||
|
except EOFError:
|
||||||
|
f_in.seek(pos_before_probe)
|
||||||
|
|
||||||
|
offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
|
||||||
|
neighbors_np = read_numpy_vector(f_in, np.int32, "i")
|
||||||
|
|
||||||
|
original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
|
||||||
|
original_hnsw_data["max_level"] = read_struct(f_in, "<i")
|
||||||
|
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
|
||||||
|
original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
|
||||||
|
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
|
||||||
|
|
||||||
|
_storage_fourcc = None
|
||||||
|
_storage_data = b""
|
||||||
|
try:
|
||||||
|
_storage_fourcc = read_struct(f_in, "<I")
|
||||||
|
_storage_data = f_in.read()
|
||||||
|
except EOFError:
|
||||||
|
_storage_fourcc = NULL_INDEX_FOURCC
|
||||||
|
|
||||||
|
write_original_format(
|
||||||
|
f_out,
|
||||||
|
original_hnsw_data,
|
||||||
|
assign_probas_np,
|
||||||
|
cum_nneighbor_per_level_np,
|
||||||
|
levels_np,
|
||||||
|
offsets_np,
|
||||||
|
neighbors_np,
|
||||||
|
NULL_INDEX_FOURCC,
|
||||||
|
b"",
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"[{time.time() - start_time:.2f}s] Pruned embeddings from {input_filename}")
|
||||||
|
return True
|
||||||
|
except Exception as exc:
|
||||||
|
print(f"Failed to prune embeddings: {exc}", file=sys.stderr)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
# --- Main Conversion Logic ---
|
# --- Main Conversion Logic ---
|
||||||
|
|
||||||
|
|
||||||
@@ -243,6 +534,8 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
|||||||
output_filename: Output CSR index file
|
output_filename: Output CSR index file
|
||||||
prune_embeddings: Whether to prune embedding storage (write NULL storage marker)
|
prune_embeddings: Whether to prune embedding storage (write NULL storage marker)
|
||||||
"""
|
"""
|
||||||
|
# Keep prints simple; rely on CI runner to flush output as needed
|
||||||
|
|
||||||
print(f"Starting conversion: {input_filename} -> {output_filename}")
|
print(f"Starting conversion: {input_filename} -> {output_filename}")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
original_hnsw_data = {}
|
original_hnsw_data = {}
|
||||||
@@ -691,6 +984,29 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def prune_hnsw_embeddings_inplace(index_filename: str) -> bool:
|
||||||
|
"""Convenience wrapper to prune embeddings in-place."""
|
||||||
|
|
||||||
|
temp_path = f"{index_filename}.prune.tmp"
|
||||||
|
success = prune_hnsw_embeddings(index_filename, temp_path)
|
||||||
|
if success:
|
||||||
|
try:
|
||||||
|
os.replace(temp_path, index_filename)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive
|
||||||
|
logger.error(f"Failed to replace original index with pruned version: {exc}")
|
||||||
|
try:
|
||||||
|
os.remove(temp_path)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
os.remove(temp_path)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
return success
|
||||||
|
|
||||||
|
|
||||||
# --- Script Execution ---
|
# --- Script Execution ---
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from leann.interface import (
|
from leann.interface import (
|
||||||
@@ -13,7 +14,7 @@ from leann.interface import (
|
|||||||
from leann.registry import register_backend
|
from leann.registry import register_backend
|
||||||
from leann.searcher_base import BaseSearcher
|
from leann.searcher_base import BaseSearcher
|
||||||
|
|
||||||
from .convert_to_csr import convert_hnsw_graph_to_csr
|
from .convert_to_csr import convert_hnsw_graph_to_csr, prune_hnsw_embeddings_inplace
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -54,12 +55,13 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
|||||||
self.efConstruction = self.build_params.setdefault("efConstruction", 200)
|
self.efConstruction = self.build_params.setdefault("efConstruction", 200)
|
||||||
self.distance_metric = self.build_params.setdefault("distance_metric", "mips")
|
self.distance_metric = self.build_params.setdefault("distance_metric", "mips")
|
||||||
self.dimensions = self.build_params.get("dimensions")
|
self.dimensions = self.build_params.get("dimensions")
|
||||||
if not self.is_recompute:
|
if not self.is_recompute and self.is_compact:
|
||||||
if self.is_compact:
|
# Auto-correct: non-recompute requires non-compact storage for HNSW
|
||||||
# TODO: support this case @andy
|
logger.warning(
|
||||||
raise ValueError(
|
"is_recompute=False requires non-compact HNSW. Forcing is_compact=False."
|
||||||
"is_recompute is False, but is_compact is True. This is not compatible now. change is compact to False and you can use the original HNSW index."
|
)
|
||||||
)
|
self.is_compact = False
|
||||||
|
self.build_params["is_compact"] = False
|
||||||
|
|
||||||
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
|
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
|
||||||
from . import faiss # type: ignore
|
from . import faiss # type: ignore
|
||||||
@@ -88,8 +90,19 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
|||||||
index_file = index_dir / f"{index_prefix}.index"
|
index_file = index_dir / f"{index_prefix}.index"
|
||||||
faiss.write_index(index, str(index_file))
|
faiss.write_index(index, str(index_file))
|
||||||
|
|
||||||
|
# Persist ID map so searcher can map FAISS integer labels back to passage IDs
|
||||||
|
try:
|
||||||
|
idmap_file = index_dir / f"{index_prefix}.ids.txt"
|
||||||
|
with open(idmap_file, "w", encoding="utf-8") as f:
|
||||||
|
for id_str in ids:
|
||||||
|
f.write(str(id_str) + "\n")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to write ID map: {e}")
|
||||||
|
|
||||||
if self.is_compact:
|
if self.is_compact:
|
||||||
self._convert_to_csr(index_file)
|
self._convert_to_csr(index_file)
|
||||||
|
elif self.is_recompute:
|
||||||
|
prune_hnsw_embeddings_inplace(str(index_file))
|
||||||
|
|
||||||
def _convert_to_csr(self, index_file: Path):
|
def _convert_to_csr(self, index_file: Path):
|
||||||
"""Convert built index to CSR format"""
|
"""Convert built index to CSR format"""
|
||||||
@@ -131,10 +144,10 @@ class HNSWSearcher(BaseSearcher):
|
|||||||
if metric_enum is None:
|
if metric_enum is None:
|
||||||
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
|
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
|
||||||
|
|
||||||
self.is_compact, self.is_pruned = (
|
backend_meta_kwargs = self.meta.get("backend_kwargs", {})
|
||||||
self.meta.get("is_compact", True),
|
self.is_compact = self.meta.get("is_compact", backend_meta_kwargs.get("is_compact", True))
|
||||||
self.meta.get("is_pruned", True),
|
default_pruned = backend_meta_kwargs.get("is_recompute", self.is_compact)
|
||||||
)
|
self.is_pruned = bool(self.meta.get("is_pruned", default_pruned))
|
||||||
|
|
||||||
index_file = self.index_dir / f"{self.index_path.stem}.index"
|
index_file = self.index_dir / f"{self.index_path.stem}.index"
|
||||||
if not index_file.exists():
|
if not index_file.exists():
|
||||||
@@ -148,11 +161,21 @@ class HNSWSearcher(BaseSearcher):
|
|||||||
|
|
||||||
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
|
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
|
||||||
|
|
||||||
|
# Load ID map if available
|
||||||
|
self._id_map: list[str] = []
|
||||||
|
try:
|
||||||
|
idmap_file = self.index_dir / f"{self.index_path.stem}.ids.txt"
|
||||||
|
if idmap_file.exists():
|
||||||
|
with open(idmap_file, encoding="utf-8") as f:
|
||||||
|
self._id_map = [line.rstrip("\n") for line in f]
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to load ID map: {e}")
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
query: np.ndarray,
|
query: np.ndarray,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
zmq_port: int | None = None,
|
zmq_port: Optional[int] = None,
|
||||||
complexity: int = 64,
|
complexity: int = 64,
|
||||||
beam_width: int = 1,
|
beam_width: int = 1,
|
||||||
prune_ratio: float = 0.0,
|
prune_ratio: float = 0.0,
|
||||||
@@ -184,9 +207,11 @@ class HNSWSearcher(BaseSearcher):
|
|||||||
"""
|
"""
|
||||||
from . import faiss # type: ignore
|
from . import faiss # type: ignore
|
||||||
|
|
||||||
if not recompute_embeddings:
|
if not recompute_embeddings and self.is_pruned:
|
||||||
if self.is_pruned:
|
raise RuntimeError(
|
||||||
raise RuntimeError("Recompute is required for pruned index.")
|
"Recompute is required for pruned/compact HNSW index. "
|
||||||
|
"Re-run search with --recompute, or rebuild with --no-recompute and --no-compact."
|
||||||
|
)
|
||||||
if recompute_embeddings:
|
if recompute_embeddings:
|
||||||
if zmq_port is None:
|
if zmq_port is None:
|
||||||
raise ValueError("zmq_port must be provided if recompute_embeddings is True")
|
raise ValueError("zmq_port must be provided if recompute_embeddings is True")
|
||||||
@@ -233,6 +258,7 @@ class HNSWSearcher(BaseSearcher):
|
|||||||
distances = np.empty((batch_size_query, top_k), dtype=np.float32)
|
distances = np.empty((batch_size_query, top_k), dtype=np.float32)
|
||||||
labels = np.empty((batch_size_query, top_k), dtype=np.int64)
|
labels = np.empty((batch_size_query, top_k), dtype=np.int64)
|
||||||
|
|
||||||
|
search_time = time.time()
|
||||||
self._index.search(
|
self._index.search(
|
||||||
query.shape[0],
|
query.shape[0],
|
||||||
faiss.swig_ptr(query),
|
faiss.swig_ptr(query),
|
||||||
@@ -241,7 +267,21 @@ class HNSWSearcher(BaseSearcher):
|
|||||||
faiss.swig_ptr(labels),
|
faiss.swig_ptr(labels),
|
||||||
params,
|
params,
|
||||||
)
|
)
|
||||||
|
search_time = time.time() - search_time
|
||||||
|
logger.info(f" Search time in HNSWSearcher.search() backend: {search_time} seconds")
|
||||||
|
if self._id_map:
|
||||||
|
|
||||||
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
|
def map_label(x: int) -> str:
|
||||||
|
if 0 <= x < len(self._id_map):
|
||||||
|
return self._id_map[x]
|
||||||
|
return str(x)
|
||||||
|
|
||||||
|
string_labels = [
|
||||||
|
[map_label(int(label)) for label in batch_labels] for batch_labels in labels
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
string_labels = [
|
||||||
|
[str(int_label) for int_label in batch_labels] for batch_labels in labels
|
||||||
|
]
|
||||||
|
|
||||||
return {"labels": string_labels, "distances": distances}
|
return {"labels": string_labels, "distances": distances}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import sys
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
import msgpack
|
import msgpack
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -23,17 +24,39 @@ logger = logging.getLogger(__name__)
|
|||||||
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
|
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
|
||||||
logger.setLevel(log_level)
|
logger.setLevel(log_level)
|
||||||
|
|
||||||
# Ensure we have a handler if none exists
|
# Ensure we have handlers if none exist
|
||||||
if not logger.handlers:
|
if not logger.handlers:
|
||||||
handler = logging.StreamHandler()
|
stream_handler = logging.StreamHandler()
|
||||||
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
||||||
handler.setFormatter(formatter)
|
stream_handler.setFormatter(formatter)
|
||||||
logger.addHandler(handler)
|
logger.addHandler(stream_handler)
|
||||||
logger.propagate = False
|
|
||||||
|
log_path = os.getenv("LEANN_HNSW_LOG_PATH")
|
||||||
|
if log_path:
|
||||||
|
try:
|
||||||
|
file_handler = logging.FileHandler(log_path, mode="a", encoding="utf-8")
|
||||||
|
file_formatter = logging.Formatter(
|
||||||
|
"%(asctime)s - %(levelname)s - [pid=%(process)d] %(message)s"
|
||||||
|
)
|
||||||
|
file_handler.setFormatter(file_formatter)
|
||||||
|
logger.addHandler(file_handler)
|
||||||
|
except Exception as exc: # pragma: no cover - best effort logging
|
||||||
|
logger.warning(f"Failed to attach file handler for log path {log_path}: {exc}")
|
||||||
|
|
||||||
|
logger.propagate = False
|
||||||
|
|
||||||
|
_RAW_PROVIDER_OPTIONS = os.getenv("LEANN_EMBEDDING_OPTIONS")
|
||||||
|
try:
|
||||||
|
PROVIDER_OPTIONS: dict[str, Any] = (
|
||||||
|
json.loads(_RAW_PROVIDER_OPTIONS) if _RAW_PROVIDER_OPTIONS else {}
|
||||||
|
)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning("Failed to parse LEANN_EMBEDDING_OPTIONS; ignoring provider options")
|
||||||
|
PROVIDER_OPTIONS = {}
|
||||||
|
|
||||||
|
|
||||||
def create_hnsw_embedding_server(
|
def create_hnsw_embedding_server(
|
||||||
passages_file: str | None = None,
|
passages_file: Optional[str] = None,
|
||||||
zmq_port: int = 5555,
|
zmq_port: int = 5555,
|
||||||
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
||||||
distance_metric: str = "mips",
|
distance_metric: str = "mips",
|
||||||
@@ -81,199 +104,359 @@ def create_hnsw_embedding_server(
|
|||||||
with open(passages_file) as f:
|
with open(passages_file) as f:
|
||||||
meta = json.load(f)
|
meta = json.load(f)
|
||||||
|
|
||||||
# Convert relative paths to absolute paths based on metadata file location
|
# Let PassageManager handle path resolution uniformly. It supports fallback order:
|
||||||
metadata_dir = Path(passages_file).parent.parent # Go up one level from the metadata file
|
# 1) path/index_path; 2) *_relative; 3) standard siblings next to meta
|
||||||
passage_sources = []
|
passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file)
|
||||||
for source in meta["passage_sources"]:
|
# Dimension from metadata for shaping responses
|
||||||
source_copy = source.copy()
|
try:
|
||||||
# Convert relative paths to absolute paths
|
embedding_dim: int = int(meta.get("dimensions", 0))
|
||||||
if not Path(source_copy["path"]).is_absolute():
|
except Exception:
|
||||||
source_copy["path"] = str(metadata_dir / source_copy["path"])
|
embedding_dim = 0
|
||||||
if not Path(source_copy["index_path"]).is_absolute():
|
logger.info(f"Loaded PassageManager with {len(passages)} passages from metadata")
|
||||||
source_copy["index_path"] = str(metadata_dir / source_copy["index_path"])
|
|
||||||
passage_sources.append(source_copy)
|
|
||||||
|
|
||||||
passages = PassageManager(passage_sources)
|
# Attempt to load ID map (maps FAISS integer labels -> passage IDs)
|
||||||
logger.info(
|
id_map: list[str] = []
|
||||||
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
try:
|
||||||
)
|
meta_path = Path(passages_file)
|
||||||
|
base = meta_path.name
|
||||||
|
if base.endswith(".meta.json"):
|
||||||
|
base = base[: -len(".meta.json")] # e.g., laion_index.leann
|
||||||
|
if base.endswith(".leann"):
|
||||||
|
base = base[: -len(".leann")] # e.g., laion_index
|
||||||
|
idmap_file = meta_path.parent / f"{base}.ids.txt"
|
||||||
|
if idmap_file.exists():
|
||||||
|
with open(idmap_file, encoding="utf-8") as f:
|
||||||
|
id_map = [line.rstrip("\n") for line in f]
|
||||||
|
logger.info(f"Loaded ID map with {len(id_map)} entries from {idmap_file}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"ID map file not found at {idmap_file}; will use raw labels")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to load ID map: {e}")
|
||||||
|
|
||||||
|
def _map_node_id(nid) -> str:
|
||||||
|
try:
|
||||||
|
if id_map is not None and len(id_map) > 0 and isinstance(nid, (int, np.integer)):
|
||||||
|
idx = int(nid)
|
||||||
|
if 0 <= idx < len(id_map):
|
||||||
|
return id_map[idx]
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return str(nid)
|
||||||
|
|
||||||
|
# (legacy ZMQ thread removed; using shutdown-capable server only)
|
||||||
|
|
||||||
|
def zmq_server_thread_with_shutdown(shutdown_event):
|
||||||
|
"""ZMQ server thread that respects shutdown signal.
|
||||||
|
|
||||||
|
Creates its own REP socket bound to zmq_port and polls with timeouts
|
||||||
|
to allow graceful shutdown.
|
||||||
|
"""
|
||||||
|
logger.info("ZMQ server thread started with shutdown support")
|
||||||
|
|
||||||
def zmq_server_thread():
|
|
||||||
"""ZMQ server thread"""
|
|
||||||
context = zmq.Context()
|
context = zmq.Context()
|
||||||
socket = context.socket(zmq.REP)
|
rep_socket = context.socket(zmq.REP)
|
||||||
socket.bind(f"tcp://*:{zmq_port}")
|
rep_socket.bind(f"tcp://*:{zmq_port}")
|
||||||
logger.info(f"HNSW ZMQ server listening on port {zmq_port}")
|
logger.info(f"HNSW ZMQ REP server listening on port {zmq_port}")
|
||||||
|
rep_socket.setsockopt(zmq.RCVTIMEO, 1000)
|
||||||
|
# Keep sends from blocking during shutdown; fail fast and drop on close
|
||||||
|
rep_socket.setsockopt(zmq.SNDTIMEO, 1000)
|
||||||
|
rep_socket.setsockopt(zmq.LINGER, 0)
|
||||||
|
|
||||||
socket.setsockopt(zmq.RCVTIMEO, 300000)
|
# Track last request type/length for shape-correct fallbacks
|
||||||
socket.setsockopt(zmq.SNDTIMEO, 300000)
|
last_request_type = "unknown" # 'text' | 'distance' | 'embedding' | 'unknown'
|
||||||
|
last_request_length = 0
|
||||||
|
|
||||||
while True:
|
try:
|
||||||
try:
|
while not shutdown_event.is_set():
|
||||||
message_bytes = socket.recv()
|
try:
|
||||||
logger.debug(f"Received ZMQ request of size {len(message_bytes)} bytes")
|
e2e_start = time.time()
|
||||||
|
logger.debug("🔍 Waiting for ZMQ message...")
|
||||||
|
request_bytes = rep_socket.recv()
|
||||||
|
|
||||||
e2e_start = time.time()
|
# Rest of the processing logic (same as original)
|
||||||
request_payload = msgpack.unpackb(message_bytes)
|
request = msgpack.unpackb(request_bytes)
|
||||||
|
|
||||||
# Handle direct text embedding request
|
if len(request) == 1 and request[0] == "__QUERY_MODEL__":
|
||||||
if isinstance(request_payload, list) and len(request_payload) > 0:
|
response_bytes = msgpack.packb([model_name])
|
||||||
# Check if this is a direct text request (list of strings)
|
rep_socket.send(response_bytes)
|
||||||
if all(isinstance(item, str) for item in request_payload):
|
continue
|
||||||
logger.info(
|
|
||||||
f"Processing direct text embedding request for {len(request_payload)} texts in {embedding_mode} mode"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use unified embedding computation (now with model caching)
|
# Handle direct text embedding request
|
||||||
|
if (
|
||||||
|
isinstance(request, list)
|
||||||
|
and request
|
||||||
|
and all(isinstance(item, str) for item in request)
|
||||||
|
):
|
||||||
|
last_request_type = "text"
|
||||||
|
last_request_length = len(request)
|
||||||
embeddings = compute_embeddings(
|
embeddings = compute_embeddings(
|
||||||
request_payload, model_name, mode=embedding_mode
|
request,
|
||||||
|
model_name,
|
||||||
|
mode=embedding_mode,
|
||||||
|
provider_options=PROVIDER_OPTIONS,
|
||||||
)
|
)
|
||||||
|
rep_socket.send(msgpack.packb(embeddings.tolist()))
|
||||||
response = embeddings.tolist()
|
|
||||||
socket.send(msgpack.packb(response))
|
|
||||||
e2e_end = time.time()
|
e2e_end = time.time()
|
||||||
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
|
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Handle distance calculation requests
|
# Handle distance calculation request: [[ids], [query_vector]]
|
||||||
if (
|
if (
|
||||||
isinstance(request_payload, list)
|
isinstance(request, list)
|
||||||
and len(request_payload) == 2
|
and len(request) == 2
|
||||||
and isinstance(request_payload[0], list)
|
and isinstance(request[0], list)
|
||||||
and isinstance(request_payload[1], list)
|
and isinstance(request[1], list)
|
||||||
):
|
):
|
||||||
node_ids = request_payload[0]
|
node_ids = request[0]
|
||||||
query_vector = np.array(request_payload[1], dtype=np.float32)
|
# Handle nested [[ids]] shape defensively
|
||||||
|
if len(node_ids) == 1 and isinstance(node_ids[0], list):
|
||||||
|
node_ids = node_ids[0]
|
||||||
|
query_vector = np.array(request[1], dtype=np.float32)
|
||||||
|
last_request_type = "distance"
|
||||||
|
last_request_length = len(node_ids)
|
||||||
|
|
||||||
logger.debug("Distance calculation request received")
|
logger.debug("Distance calculation request received")
|
||||||
logger.debug(f" Node IDs: {node_ids}")
|
logger.debug(f" Node IDs: {node_ids}")
|
||||||
logger.debug(f" Query vector dim: {len(query_vector)}")
|
logger.debug(f" Query vector dim: {len(query_vector)}")
|
||||||
|
|
||||||
# Get embeddings for node IDs
|
# Gather texts for found ids
|
||||||
texts = []
|
texts: list[str] = []
|
||||||
for nid in node_ids:
|
found_indices: list[int] = []
|
||||||
|
for idx, nid in enumerate(node_ids):
|
||||||
|
try:
|
||||||
|
passage_id = _map_node_id(nid)
|
||||||
|
passage_data = passages.get_passage(passage_id)
|
||||||
|
txt = passage_data.get("text", "")
|
||||||
|
if isinstance(txt, str) and len(txt) > 0:
|
||||||
|
texts.append(txt)
|
||||||
|
found_indices.append(idx)
|
||||||
|
else:
|
||||||
|
logger.error(f"Empty text for passage ID {passage_id}")
|
||||||
|
except KeyError:
|
||||||
|
logger.error(f"Passage ID {nid} not found")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
||||||
|
|
||||||
|
# Prepare full-length response with large sentinel values
|
||||||
|
large_distance = 1e9
|
||||||
|
response_distances = [large_distance] * len(node_ids)
|
||||||
|
|
||||||
|
if texts:
|
||||||
|
try:
|
||||||
|
embeddings = compute_embeddings(
|
||||||
|
texts,
|
||||||
|
model_name,
|
||||||
|
mode=embedding_mode,
|
||||||
|
provider_options=PROVIDER_OPTIONS,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||||
|
)
|
||||||
|
if distance_metric == "l2":
|
||||||
|
partial = np.sum(
|
||||||
|
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
|
||||||
|
)
|
||||||
|
else: # mips or cosine
|
||||||
|
partial = -np.dot(embeddings, query_vector)
|
||||||
|
|
||||||
|
for pos, dval in zip(found_indices, partial.flatten().tolist()):
|
||||||
|
response_distances[pos] = float(dval)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Distance computation error, using sentinels: {e}")
|
||||||
|
|
||||||
|
# Send response in expected shape [[distances]]
|
||||||
|
rep_socket.send(msgpack.packb([response_distances], use_single_float=True))
|
||||||
|
e2e_end = time.time()
|
||||||
|
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Fallback: treat as embedding-by-id request
|
||||||
|
if (
|
||||||
|
isinstance(request, list)
|
||||||
|
and len(request) == 1
|
||||||
|
and isinstance(request[0], list)
|
||||||
|
):
|
||||||
|
node_ids = request[0]
|
||||||
|
elif isinstance(request, list):
|
||||||
|
node_ids = request
|
||||||
|
else:
|
||||||
|
node_ids = []
|
||||||
|
last_request_type = "embedding"
|
||||||
|
last_request_length = len(node_ids)
|
||||||
|
logger.info(f"ZMQ received {len(node_ids)} node IDs for embedding fetch")
|
||||||
|
|
||||||
|
# Preallocate zero-filled flat data for robustness
|
||||||
|
if embedding_dim <= 0:
|
||||||
|
dims = [0, 0]
|
||||||
|
flat_data: list[float] = []
|
||||||
|
else:
|
||||||
|
dims = [len(node_ids), embedding_dim]
|
||||||
|
flat_data = [0.0] * (dims[0] * dims[1])
|
||||||
|
|
||||||
|
# Collect texts for found ids
|
||||||
|
texts: list[str] = []
|
||||||
|
found_indices: list[int] = []
|
||||||
|
for idx, nid in enumerate(node_ids):
|
||||||
try:
|
try:
|
||||||
passage_data = passages.get_passage(str(nid))
|
passage_id = _map_node_id(nid)
|
||||||
txt = passage_data["text"]
|
passage_data = passages.get_passage(passage_id)
|
||||||
texts.append(txt)
|
txt = passage_data.get("text", "")
|
||||||
|
if isinstance(txt, str) and len(txt) > 0:
|
||||||
|
texts.append(txt)
|
||||||
|
found_indices.append(idx)
|
||||||
|
else:
|
||||||
|
logger.error(f"Empty text for passage ID {passage_id}")
|
||||||
except KeyError:
|
except KeyError:
|
||||||
logger.error(f"Passage ID {nid} not found")
|
logger.error(f"Passage with ID {nid} not found")
|
||||||
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
||||||
raise
|
|
||||||
|
|
||||||
# Process embeddings
|
if texts:
|
||||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
try:
|
||||||
logger.info(
|
embeddings = compute_embeddings(
|
||||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
texts,
|
||||||
)
|
model_name,
|
||||||
|
mode=embedding_mode,
|
||||||
|
provider_options=PROVIDER_OPTIONS,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
# Calculate distances
|
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||||
if distance_metric == "l2":
|
logger.error(
|
||||||
distances = np.sum(
|
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
|
||||||
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
|
)
|
||||||
)
|
dims = [0, embedding_dim]
|
||||||
else: # mips or cosine
|
flat_data = []
|
||||||
distances = -np.dot(embeddings, query_vector)
|
else:
|
||||||
|
emb_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||||
|
flat = emb_f32.flatten().tolist()
|
||||||
|
for j, pos in enumerate(found_indices):
|
||||||
|
start = pos * embedding_dim
|
||||||
|
end = start + embedding_dim
|
||||||
|
if end <= len(flat_data):
|
||||||
|
flat_data[start:end] = flat[
|
||||||
|
j * embedding_dim : (j + 1) * embedding_dim
|
||||||
|
]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Embedding computation error, returning zeros: {e}")
|
||||||
|
|
||||||
response_payload = distances.flatten().tolist()
|
response_payload = [dims, flat_data]
|
||||||
response_bytes = msgpack.packb([response_payload], use_single_float=True)
|
response_bytes = msgpack.packb(response_payload, use_single_float=True)
|
||||||
logger.debug(f"Sending distance response with {len(distances)} distances")
|
|
||||||
|
|
||||||
socket.send(response_bytes)
|
rep_socket.send(response_bytes)
|
||||||
e2e_end = time.time()
|
e2e_end = time.time()
|
||||||
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
|
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
|
||||||
|
|
||||||
|
except zmq.Again:
|
||||||
|
# Timeout - check shutdown_event and continue
|
||||||
continue
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
if not shutdown_event.is_set():
|
||||||
|
logger.error(f"Error in ZMQ server loop: {e}")
|
||||||
|
# Shape-correct fallback
|
||||||
|
try:
|
||||||
|
if last_request_type == "distance":
|
||||||
|
large_distance = 1e9
|
||||||
|
fallback_len = max(0, int(last_request_length))
|
||||||
|
safe = [[large_distance] * fallback_len]
|
||||||
|
elif last_request_type == "embedding":
|
||||||
|
bsz = max(0, int(last_request_length))
|
||||||
|
dim = max(0, int(embedding_dim))
|
||||||
|
safe = (
|
||||||
|
[[bsz, dim], [0.0] * (bsz * dim)] if dim > 0 else [[0, 0], []]
|
||||||
|
)
|
||||||
|
elif last_request_type == "text":
|
||||||
|
safe = [] # direct text embeddings expectation is a flat list
|
||||||
|
else:
|
||||||
|
safe = [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
|
||||||
|
rep_socket.send(msgpack.packb(safe, use_single_float=True))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
logger.info("Shutdown in progress, ignoring ZMQ error")
|
||||||
|
break
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
rep_socket.close(0)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
context.term()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
# Standard embedding request (passage ID lookup)
|
logger.info("ZMQ server thread exiting gracefully")
|
||||||
if (
|
|
||||||
not isinstance(request_payload, list)
|
|
||||||
or len(request_payload) != 1
|
|
||||||
or not isinstance(request_payload[0], list)
|
|
||||||
):
|
|
||||||
logger.error(
|
|
||||||
f"Invalid MessagePack request format. Expected [[ids...]] or [texts...], got: {type(request_payload)}"
|
|
||||||
)
|
|
||||||
socket.send(msgpack.packb([[], []]))
|
|
||||||
continue
|
|
||||||
|
|
||||||
node_ids = request_payload[0]
|
# Add shutdown coordination
|
||||||
logger.debug(f"Request for {len(node_ids)} node embeddings")
|
shutdown_event = threading.Event()
|
||||||
|
|
||||||
# Look up texts by node IDs
|
def shutdown_zmq_server():
|
||||||
texts = []
|
"""Gracefully shutdown ZMQ server."""
|
||||||
for nid in node_ids:
|
logger.info("Initiating graceful shutdown...")
|
||||||
try:
|
shutdown_event.set()
|
||||||
passage_data = passages.get_passage(str(nid))
|
|
||||||
txt = passage_data["text"]
|
|
||||||
if not txt:
|
|
||||||
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
|
|
||||||
texts.append(txt)
|
|
||||||
except KeyError:
|
|
||||||
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
# Process embeddings
|
if zmq_thread.is_alive():
|
||||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
logger.info("Waiting for ZMQ thread to finish...")
|
||||||
logger.info(
|
zmq_thread.join(timeout=5)
|
||||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
if zmq_thread.is_alive():
|
||||||
)
|
logger.warning("ZMQ thread did not finish in time")
|
||||||
|
|
||||||
# Serialization and response
|
# Clean up ZMQ resources
|
||||||
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
try:
|
||||||
logger.error(
|
# Note: socket and context are cleaned up by thread exit
|
||||||
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
|
logger.info("ZMQ resources cleaned up")
|
||||||
)
|
except Exception as e:
|
||||||
raise AssertionError()
|
logger.warning(f"Error cleaning ZMQ resources: {e}")
|
||||||
|
|
||||||
hidden_contiguous_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
|
# Clean up other resources
|
||||||
response_payload = [
|
try:
|
||||||
list(hidden_contiguous_f32.shape),
|
import gc
|
||||||
hidden_contiguous_f32.flatten().tolist(),
|
|
||||||
]
|
|
||||||
response_bytes = msgpack.packb(response_payload, use_single_float=True)
|
|
||||||
|
|
||||||
socket.send(response_bytes)
|
gc.collect()
|
||||||
e2e_end = time.time()
|
logger.info("Additional resources cleaned up")
|
||||||
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
|
except Exception as e:
|
||||||
|
logger.warning(f"Error cleaning additional resources: {e}")
|
||||||
|
|
||||||
except zmq.Again:
|
logger.info("Graceful shutdown completed")
|
||||||
logger.debug("ZMQ socket timeout, continuing to listen")
|
sys.exit(0)
|
||||||
continue
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in ZMQ server loop: {e}")
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
traceback.print_exc()
|
# Register signal handlers within this function scope
|
||||||
socket.send(msgpack.packb([[], []]))
|
import signal
|
||||||
|
|
||||||
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
|
def signal_handler(sig, frame):
|
||||||
|
logger.info(f"Received signal {sig}, shutting down gracefully...")
|
||||||
|
shutdown_zmq_server()
|
||||||
|
|
||||||
|
signal.signal(signal.SIGTERM, signal_handler)
|
||||||
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
|
|
||||||
|
# Pass shutdown_event to ZMQ thread
|
||||||
|
zmq_thread = threading.Thread(
|
||||||
|
target=lambda: zmq_server_thread_with_shutdown(shutdown_event),
|
||||||
|
daemon=False, # Not daemon - we want to wait for it
|
||||||
|
)
|
||||||
zmq_thread.start()
|
zmq_thread.start()
|
||||||
logger.info(f"Started HNSW ZMQ server thread on port {zmq_port}")
|
logger.info(f"Started HNSW ZMQ server thread on port {zmq_port}")
|
||||||
|
|
||||||
# Keep the main thread alive
|
# Keep the main thread alive
|
||||||
try:
|
try:
|
||||||
while True:
|
while not shutdown_event.is_set():
|
||||||
time.sleep(1)
|
time.sleep(0.1) # Check shutdown more frequently
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logger.info("HNSW Server shutting down...")
|
logger.info("HNSW Server shutting down...")
|
||||||
|
shutdown_zmq_server()
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# If we reach here, shutdown was triggered by signal
|
||||||
|
logger.info("Main loop exited, process should be shutting down")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import signal
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
def signal_handler(sig, frame):
|
# Signal handlers are now registered within create_hnsw_embedding_server
|
||||||
logger.info(f"Received signal {sig}, shutting down gracefully...")
|
|
||||||
sys.exit(0)
|
|
||||||
|
|
||||||
# Register signal handlers for graceful shutdown
|
|
||||||
signal.signal(signal.SIGTERM, signal_handler)
|
|
||||||
signal.signal(signal.SIGINT, signal_handler)
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="HNSW Embedding service")
|
parser = argparse.ArgumentParser(description="HNSW Embedding service")
|
||||||
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
||||||
@@ -295,7 +478,7 @@ if __name__ == "__main__":
|
|||||||
"--embedding-mode",
|
"--embedding-mode",
|
||||||
type=str,
|
type=str,
|
||||||
default="sentence-transformers",
|
default="sentence-transformers",
|
||||||
choices=["sentence-transformers", "openai", "mlx"],
|
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||||
help="Embedding backend mode",
|
help="Embedding backend mode",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -6,10 +6,10 @@ build-backend = "scikit_build_core.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann-backend-hnsw"
|
name = "leann-backend-hnsw"
|
||||||
version = "0.2.1"
|
version = "0.3.4"
|
||||||
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
|
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"leann-core==0.2.1",
|
"leann-core==0.3.4",
|
||||||
"numpy",
|
"numpy",
|
||||||
"pyzmq>=23.0.0",
|
"pyzmq>=23.0.0",
|
||||||
"msgpack>=1.0.0",
|
"msgpack>=1.0.0",
|
||||||
@@ -22,6 +22,8 @@ cmake.build-type = "Release"
|
|||||||
build.verbose = true
|
build.verbose = true
|
||||||
build.tool-args = ["-j8"]
|
build.tool-args = ["-j8"]
|
||||||
|
|
||||||
# CMake definitions to optimize compilation
|
# CMake definitions to optimize compilation and find Homebrew packages
|
||||||
[tool.scikit-build.cmake.define]
|
[tool.scikit-build.cmake.define]
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL = "8"
|
CMAKE_BUILD_PARALLEL_LEVEL = "8"
|
||||||
|
CMAKE_PREFIX_PATH = {env = "CMAKE_PREFIX_PATH"}
|
||||||
|
OpenMP_ROOT = {env = "OpenMP_ROOT"}
|
||||||
|
|||||||
Submodule packages/leann-backend-hnsw/third_party/faiss updated: ff22e2c86b...c69511a99c
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann-core"
|
name = "leann-core"
|
||||||
version = "0.2.1"
|
version = "0.3.4"
|
||||||
description = "Core API and plugin system for LEANN"
|
description = "Core API and plugin system for LEANN"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.9"
|
||||||
@@ -31,8 +31,10 @@ dependencies = [
|
|||||||
"PyPDF2>=3.0.0",
|
"PyPDF2>=3.0.0",
|
||||||
"pymupdf>=1.23.0",
|
"pymupdf>=1.23.0",
|
||||||
"pdfplumber>=0.10.0",
|
"pdfplumber>=0.10.0",
|
||||||
"mlx>=0.26.3; sys_platform == 'darwin'",
|
"nbconvert>=7.0.0", # For .ipynb file support
|
||||||
"mlx-lm>=0.26.0; sys_platform == 'darwin'",
|
"gitignore-parser>=0.1.12", # For proper .gitignore handling
|
||||||
|
"mlx>=0.26.3; sys_platform == 'darwin' and platform_machine == 'arm64'",
|
||||||
|
"mlx-lm>=0.26.0; sys_platform == 'darwin' and platform_machine == 'arm64'",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
@@ -44,6 +46,7 @@ colab = [
|
|||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
leann = "leann.cli:main"
|
leann = "leann.cli:main"
|
||||||
|
leann_mcp = "leann.mcp:main"
|
||||||
|
|
||||||
[tool.setuptools.packages.find]
|
[tool.setuptools.packages.find]
|
||||||
where = ["src"]
|
where = ["src"]
|
||||||
|
|||||||
@@ -6,18 +6,22 @@ with the correct, original embedding logic from the user's reference code.
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import pickle
|
import pickle
|
||||||
|
import re
|
||||||
|
import subprocess
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
|
||||||
|
|
||||||
from leann.interface import LeannBackendSearcherInterface
|
from leann.interface import LeannBackendSearcherInterface
|
||||||
|
|
||||||
from .chat import get_llm
|
from .chat import get_llm
|
||||||
from .interface import LeannBackendFactoryInterface
|
from .interface import LeannBackendFactoryInterface
|
||||||
|
from .metadata_filter import MetadataFilterEngine
|
||||||
from .registry import BACKEND_REGISTRY
|
from .registry import BACKEND_REGISTRY
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -33,8 +37,9 @@ def compute_embeddings(
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
mode: str = "sentence-transformers",
|
mode: str = "sentence-transformers",
|
||||||
use_server: bool = True,
|
use_server: bool = True,
|
||||||
port: int | None = None,
|
port: Optional[int] = None,
|
||||||
is_build=False,
|
is_build=False,
|
||||||
|
provider_options: Optional[dict[str, Any]] = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Computes embeddings using different backends.
|
Computes embeddings using different backends.
|
||||||
@@ -46,6 +51,7 @@ def compute_embeddings(
|
|||||||
- "sentence-transformers": Use sentence-transformers library (default)
|
- "sentence-transformers": Use sentence-transformers library (default)
|
||||||
- "mlx": Use MLX backend for Apple Silicon
|
- "mlx": Use MLX backend for Apple Silicon
|
||||||
- "openai": Use OpenAI embedding API
|
- "openai": Use OpenAI embedding API
|
||||||
|
- "gemini": Use Google Gemini embedding API
|
||||||
use_server: Whether to use embedding server (True for search, False for build)
|
use_server: Whether to use embedding server (True for search, False for build)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -67,6 +73,7 @@ def compute_embeddings(
|
|||||||
model_name,
|
model_name,
|
||||||
mode=mode,
|
mode=mode,
|
||||||
is_build=is_build,
|
is_build=is_build,
|
||||||
|
provider_options=provider_options,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -115,60 +122,188 @@ class SearchResult:
|
|||||||
|
|
||||||
|
|
||||||
class PassageManager:
|
class PassageManager:
|
||||||
def __init__(self, passage_sources: list[dict[str, Any]]):
|
def __init__(
|
||||||
self.offset_maps = {}
|
self, passage_sources: list[dict[str, Any]], metadata_file_path: Optional[str] = None
|
||||||
self.passage_files = {}
|
):
|
||||||
self.global_offset_map = {} # Combined map for fast lookup
|
self.offset_maps: dict[str, dict[str, int]] = {}
|
||||||
|
self.passage_files: dict[str, str] = {}
|
||||||
|
# Avoid materializing a single gigantic global map to reduce memory
|
||||||
|
# footprint on very large corpora (e.g., 60M+ passages). Instead, keep
|
||||||
|
# per-shard maps and do a lightweight per-shard lookup on demand.
|
||||||
|
self._total_count: int = 0
|
||||||
|
self.filter_engine = MetadataFilterEngine() # Initialize filter engine
|
||||||
|
|
||||||
|
# Derive index base name for standard sibling fallbacks, e.g., <index_name>.passages.*
|
||||||
|
index_name_base = None
|
||||||
|
if metadata_file_path:
|
||||||
|
meta_name = Path(metadata_file_path).name
|
||||||
|
if meta_name.endswith(".meta.json"):
|
||||||
|
index_name_base = meta_name[: -len(".meta.json")]
|
||||||
|
|
||||||
for source in passage_sources:
|
for source in passage_sources:
|
||||||
assert source["type"] == "jsonl", "only jsonl is supported"
|
assert source["type"] == "jsonl", "only jsonl is supported"
|
||||||
passage_file = source["path"]
|
passage_file = source.get("path", "")
|
||||||
index_file = source["index_path"] # .idx file
|
index_file = source.get("index_path", "") # .idx file
|
||||||
|
|
||||||
# Fix path resolution for Colab and other environments
|
# Fix path resolution - relative paths should be relative to metadata file directory
|
||||||
if not Path(index_file).is_absolute():
|
def _resolve_candidates(
|
||||||
# If relative path, try to resolve it properly
|
primary: str,
|
||||||
index_file = str(Path(index_file).resolve())
|
relative_key: str,
|
||||||
|
default_name: Optional[str],
|
||||||
|
source_dict: dict[str, Any],
|
||||||
|
) -> list[Path]:
|
||||||
|
"""
|
||||||
|
Build an ordered list of candidate paths. For relative paths specified in
|
||||||
|
metadata, prefer resolution relative to the metadata file directory first,
|
||||||
|
then fall back to CWD-based resolution, and finally to conventional
|
||||||
|
sibling defaults (e.g., <index_base>.passages.idx / .jsonl).
|
||||||
|
"""
|
||||||
|
candidates: list[Path] = []
|
||||||
|
# 1) Primary path
|
||||||
|
if primary:
|
||||||
|
p = Path(primary)
|
||||||
|
if p.is_absolute():
|
||||||
|
candidates.append(p)
|
||||||
|
else:
|
||||||
|
# Prefer metadata-relative resolution for relative paths
|
||||||
|
if metadata_file_path:
|
||||||
|
candidates.append(Path(metadata_file_path).parent / p)
|
||||||
|
# Also consider CWD-relative as a fallback for legacy layouts
|
||||||
|
candidates.append(Path.cwd() / p)
|
||||||
|
# 2) metadata-relative explicit relative key (if present)
|
||||||
|
if metadata_file_path and source_dict.get(relative_key):
|
||||||
|
candidates.append(Path(metadata_file_path).parent / source_dict[relative_key])
|
||||||
|
# 3) metadata-relative standard sibling filename
|
||||||
|
if metadata_file_path and default_name:
|
||||||
|
candidates.append(Path(metadata_file_path).parent / default_name)
|
||||||
|
return candidates
|
||||||
|
|
||||||
|
# Build candidate lists and pick first existing; otherwise keep last candidate for error message
|
||||||
|
idx_default = f"{index_name_base}.passages.idx" if index_name_base else None
|
||||||
|
idx_candidates = _resolve_candidates(
|
||||||
|
index_file, "index_path_relative", idx_default, source
|
||||||
|
)
|
||||||
|
pas_default = f"{index_name_base}.passages.jsonl" if index_name_base else None
|
||||||
|
pas_candidates = _resolve_candidates(passage_file, "path_relative", pas_default, source)
|
||||||
|
|
||||||
|
def _pick_existing(cands: list[Path]) -> str:
|
||||||
|
for c in cands:
|
||||||
|
if c.exists():
|
||||||
|
return str(c.resolve())
|
||||||
|
# Fallback to last candidate (best guess) even if not exists; will error below
|
||||||
|
return str(cands[-1].resolve()) if cands else ""
|
||||||
|
|
||||||
|
index_file = _pick_existing(idx_candidates)
|
||||||
|
passage_file = _pick_existing(pas_candidates)
|
||||||
|
|
||||||
if not Path(index_file).exists():
|
if not Path(index_file).exists():
|
||||||
raise FileNotFoundError(f"Passage index file not found: {index_file}")
|
raise FileNotFoundError(f"Passage index file not found: {index_file}")
|
||||||
|
|
||||||
with open(index_file, "rb") as f:
|
with open(index_file, "rb") as f:
|
||||||
offset_map = pickle.load(f)
|
offset_map: dict[str, int] = pickle.load(f)
|
||||||
self.offset_maps[passage_file] = offset_map
|
self.offset_maps[passage_file] = offset_map
|
||||||
self.passage_files[passage_file] = passage_file
|
self.passage_files[passage_file] = passage_file
|
||||||
|
self._total_count += len(offset_map)
|
||||||
# Build global map for O(1) lookup
|
|
||||||
for passage_id, offset in offset_map.items():
|
|
||||||
self.global_offset_map[passage_id] = (passage_file, offset)
|
|
||||||
|
|
||||||
def get_passage(self, passage_id: str) -> dict[str, Any]:
|
def get_passage(self, passage_id: str) -> dict[str, Any]:
|
||||||
if passage_id in self.global_offset_map:
|
# Fast path: check each shard map (there are typically few shards).
|
||||||
passage_file, offset = self.global_offset_map[passage_id]
|
# This avoids building a massive combined dict while keeping lookups
|
||||||
# Lazy file opening - only open when needed
|
# bounded by the number of shards.
|
||||||
with open(passage_file, encoding="utf-8") as f:
|
for passage_file, offset_map in self.offset_maps.items():
|
||||||
f.seek(offset)
|
try:
|
||||||
return json.loads(f.readline())
|
offset = offset_map[passage_id]
|
||||||
|
with open(passage_file, encoding="utf-8") as f:
|
||||||
|
f.seek(offset)
|
||||||
|
return json.loads(f.readline())
|
||||||
|
except KeyError:
|
||||||
|
continue
|
||||||
raise KeyError(f"Passage ID not found: {passage_id}")
|
raise KeyError(f"Passage ID not found: {passage_id}")
|
||||||
|
|
||||||
|
def filter_search_results(
|
||||||
|
self,
|
||||||
|
search_results: list[SearchResult],
|
||||||
|
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]],
|
||||||
|
) -> list[SearchResult]:
|
||||||
|
"""
|
||||||
|
Apply metadata filters to search results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
search_results: List of SearchResult objects
|
||||||
|
metadata_filters: Filter specifications to apply
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Filtered list of SearchResult objects
|
||||||
|
"""
|
||||||
|
if not metadata_filters:
|
||||||
|
return search_results
|
||||||
|
|
||||||
|
logger.debug(f"Applying metadata filters to {len(search_results)} results")
|
||||||
|
|
||||||
|
# Convert SearchResult objects to dictionaries for the filter engine
|
||||||
|
result_dicts = []
|
||||||
|
for result in search_results:
|
||||||
|
result_dicts.append(
|
||||||
|
{
|
||||||
|
"id": result.id,
|
||||||
|
"score": result.score,
|
||||||
|
"text": result.text,
|
||||||
|
"metadata": result.metadata,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply filters using the filter engine
|
||||||
|
filtered_dicts = self.filter_engine.apply_filters(result_dicts, metadata_filters)
|
||||||
|
|
||||||
|
# Convert back to SearchResult objects
|
||||||
|
filtered_results = []
|
||||||
|
for result_dict in filtered_dicts:
|
||||||
|
filtered_results.append(
|
||||||
|
SearchResult(
|
||||||
|
id=result_dict["id"],
|
||||||
|
score=result_dict["score"],
|
||||||
|
text=result_dict["text"],
|
||||||
|
metadata=result_dict["metadata"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"Filtered results: {len(filtered_results)} remaining")
|
||||||
|
return filtered_results
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return self._total_count
|
||||||
|
|
||||||
|
|
||||||
class LeannBuilder:
|
class LeannBuilder:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
backend_name: str,
|
backend_name: str,
|
||||||
embedding_model: str = "facebook/contriever",
|
embedding_model: str = "facebook/contriever",
|
||||||
dimensions: int | None = None,
|
dimensions: Optional[int] = None,
|
||||||
embedding_mode: str = "sentence-transformers",
|
embedding_mode: str = "sentence-transformers",
|
||||||
|
embedding_options: Optional[dict[str, Any]] = None,
|
||||||
**backend_kwargs,
|
**backend_kwargs,
|
||||||
):
|
):
|
||||||
self.backend_name = backend_name
|
self.backend_name = backend_name
|
||||||
backend_factory: LeannBackendFactoryInterface | None = BACKEND_REGISTRY.get(backend_name)
|
# Normalize incompatible combinations early (for consistent metadata)
|
||||||
|
if backend_name == "hnsw":
|
||||||
|
is_recompute = backend_kwargs.get("is_recompute", True)
|
||||||
|
is_compact = backend_kwargs.get("is_compact", True)
|
||||||
|
if is_recompute is False and is_compact is True:
|
||||||
|
warnings.warn(
|
||||||
|
"HNSW with is_recompute=False requires non-compact storage. Forcing is_compact=False.",
|
||||||
|
UserWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
backend_kwargs["is_compact"] = False
|
||||||
|
|
||||||
|
backend_factory: Optional[LeannBackendFactoryInterface] = BACKEND_REGISTRY.get(backend_name)
|
||||||
if backend_factory is None:
|
if backend_factory is None:
|
||||||
raise ValueError(f"Backend '{backend_name}' not found or not registered.")
|
raise ValueError(f"Backend '{backend_name}' not found or not registered.")
|
||||||
self.backend_factory = backend_factory
|
self.backend_factory = backend_factory
|
||||||
self.embedding_model = embedding_model
|
self.embedding_model = embedding_model
|
||||||
self.dimensions = dimensions
|
self.dimensions = dimensions
|
||||||
self.embedding_mode = embedding_mode
|
self.embedding_mode = embedding_mode
|
||||||
|
self.embedding_options = embedding_options or {}
|
||||||
|
|
||||||
# Check if we need to use cosine distance for normalized embeddings
|
# Check if we need to use cosine distance for normalized embeddings
|
||||||
normalized_embeddings_models = {
|
normalized_embeddings_models = {
|
||||||
@@ -242,7 +377,7 @@ class LeannBuilder:
|
|||||||
self.backend_kwargs = backend_kwargs
|
self.backend_kwargs = backend_kwargs
|
||||||
self.chunks: list[dict[str, Any]] = []
|
self.chunks: list[dict[str, Any]] = []
|
||||||
|
|
||||||
def add_text(self, text: str, metadata: dict[str, Any] | None = None):
|
def add_text(self, text: str, metadata: Optional[dict[str, Any]] = None):
|
||||||
if metadata is None:
|
if metadata is None:
|
||||||
metadata = {}
|
metadata = {}
|
||||||
passage_id = metadata.get("id", str(len(self.chunks)))
|
passage_id = metadata.get("id", str(len(self.chunks)))
|
||||||
@@ -252,6 +387,23 @@ class LeannBuilder:
|
|||||||
def build_index(self, index_path: str):
|
def build_index(self, index_path: str):
|
||||||
if not self.chunks:
|
if not self.chunks:
|
||||||
raise ValueError("No chunks added.")
|
raise ValueError("No chunks added.")
|
||||||
|
|
||||||
|
# Filter out invalid/empty text chunks early to keep passage and embedding counts aligned
|
||||||
|
valid_chunks: list[dict[str, Any]] = []
|
||||||
|
skipped = 0
|
||||||
|
for chunk in self.chunks:
|
||||||
|
text = chunk.get("text", "")
|
||||||
|
if isinstance(text, str) and text.strip():
|
||||||
|
valid_chunks.append(chunk)
|
||||||
|
else:
|
||||||
|
skipped += 1
|
||||||
|
if skipped > 0:
|
||||||
|
print(
|
||||||
|
f"Warning: Skipping {skipped} empty/invalid text chunk(s). Processing {len(valid_chunks)} valid chunks"
|
||||||
|
)
|
||||||
|
self.chunks = valid_chunks
|
||||||
|
if not self.chunks:
|
||||||
|
raise ValueError("All provided chunks are empty or invalid. Nothing to index.")
|
||||||
if self.dimensions is None:
|
if self.dimensions is None:
|
||||||
self.dimensions = len(
|
self.dimensions = len(
|
||||||
compute_embeddings(
|
compute_embeddings(
|
||||||
@@ -259,6 +411,7 @@ class LeannBuilder:
|
|||||||
self.embedding_model,
|
self.embedding_model,
|
||||||
self.embedding_mode,
|
self.embedding_mode,
|
||||||
use_server=False,
|
use_server=False,
|
||||||
|
provider_options=self.embedding_options,
|
||||||
)[0]
|
)[0]
|
||||||
)
|
)
|
||||||
path = Path(index_path)
|
path = Path(index_path)
|
||||||
@@ -298,8 +451,20 @@ class LeannBuilder:
|
|||||||
self.embedding_mode,
|
self.embedding_mode,
|
||||||
use_server=False,
|
use_server=False,
|
||||||
is_build=True,
|
is_build=True,
|
||||||
|
provider_options=self.embedding_options,
|
||||||
)
|
)
|
||||||
string_ids = [chunk["id"] for chunk in self.chunks]
|
string_ids = [chunk["id"] for chunk in self.chunks]
|
||||||
|
# Persist ID map alongside index so backends that return integer labels can remap to passage IDs
|
||||||
|
try:
|
||||||
|
idmap_file = (
|
||||||
|
index_dir
|
||||||
|
/ f"{index_name[: -len('.leann')] if index_name.endswith('.leann') else index_name}.ids.txt"
|
||||||
|
)
|
||||||
|
with open(idmap_file, "w", encoding="utf-8") as f:
|
||||||
|
for sid in string_ids:
|
||||||
|
f.write(str(sid) + "\n")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
|
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
|
||||||
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
|
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
|
||||||
builder_instance.build(embeddings, string_ids, index_path, **current_backend_kwargs)
|
builder_instance.build(embeddings, string_ids, index_path, **current_backend_kwargs)
|
||||||
@@ -314,20 +479,25 @@ class LeannBuilder:
|
|||||||
"passage_sources": [
|
"passage_sources": [
|
||||||
{
|
{
|
||||||
"type": "jsonl",
|
"type": "jsonl",
|
||||||
"path": str(passages_file),
|
# Preserve existing relative file names (backward-compatible)
|
||||||
"index_path": str(offset_file),
|
"path": passages_file.name,
|
||||||
|
"index_path": offset_file.name,
|
||||||
|
# Add optional redundant relative keys for remote build portability (non-breaking)
|
||||||
|
"path_relative": passages_file.name,
|
||||||
|
"index_path_relative": offset_file.name,
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if self.embedding_options:
|
||||||
|
meta_data["embedding_options"] = self.embedding_options
|
||||||
|
|
||||||
# Add storage status flags for HNSW backend
|
# Add storage status flags for HNSW backend
|
||||||
if self.backend_name == "hnsw":
|
if self.backend_name == "hnsw":
|
||||||
is_compact = self.backend_kwargs.get("is_compact", True)
|
is_compact = self.backend_kwargs.get("is_compact", True)
|
||||||
is_recompute = self.backend_kwargs.get("is_recompute", True)
|
is_recompute = self.backend_kwargs.get("is_recompute", True)
|
||||||
meta_data["is_compact"] = is_compact
|
meta_data["is_compact"] = is_compact
|
||||||
meta_data["is_pruned"] = (
|
meta_data["is_pruned"] = bool(is_recompute)
|
||||||
is_compact and is_recompute
|
|
||||||
) # Pruned only if compact and recompute
|
|
||||||
with open(leann_meta_path, "w", encoding="utf-8") as f:
|
with open(leann_meta_path, "w", encoding="utf-8") as f:
|
||||||
json.dump(meta_data, f, indent=2)
|
json.dump(meta_data, f, indent=2)
|
||||||
|
|
||||||
@@ -414,6 +584,17 @@ class LeannBuilder:
|
|||||||
|
|
||||||
# Build the vector index using precomputed embeddings
|
# Build the vector index using precomputed embeddings
|
||||||
string_ids = [str(id_val) for id_val in ids]
|
string_ids = [str(id_val) for id_val in ids]
|
||||||
|
# Persist ID map (order == embeddings order)
|
||||||
|
try:
|
||||||
|
idmap_file = (
|
||||||
|
index_dir
|
||||||
|
/ f"{index_name[: -len('.leann')] if index_name.endswith('.leann') else index_name}.ids.txt"
|
||||||
|
)
|
||||||
|
with open(idmap_file, "w", encoding="utf-8") as f:
|
||||||
|
for sid in string_ids:
|
||||||
|
f.write(str(sid) + "\n")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
|
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
|
||||||
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
|
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
|
||||||
builder_instance.build(embeddings, string_ids, index_path)
|
builder_instance.build(embeddings, string_ids, index_path)
|
||||||
@@ -430,26 +611,178 @@ class LeannBuilder:
|
|||||||
"passage_sources": [
|
"passage_sources": [
|
||||||
{
|
{
|
||||||
"type": "jsonl",
|
"type": "jsonl",
|
||||||
"path": str(passages_file),
|
# Preserve existing relative file names (backward-compatible)
|
||||||
"index_path": str(offset_file),
|
"path": passages_file.name,
|
||||||
|
"index_path": offset_file.name,
|
||||||
|
# Add optional redundant relative keys for remote build portability (non-breaking)
|
||||||
|
"path_relative": passages_file.name,
|
||||||
|
"index_path_relative": offset_file.name,
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"built_from_precomputed_embeddings": True,
|
"built_from_precomputed_embeddings": True,
|
||||||
"embeddings_source": str(embeddings_file),
|
"embeddings_source": str(embeddings_file),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if self.embedding_options:
|
||||||
|
meta_data["embedding_options"] = self.embedding_options
|
||||||
|
|
||||||
# Add storage status flags for HNSW backend
|
# Add storage status flags for HNSW backend
|
||||||
if self.backend_name == "hnsw":
|
if self.backend_name == "hnsw":
|
||||||
is_compact = self.backend_kwargs.get("is_compact", True)
|
is_compact = self.backend_kwargs.get("is_compact", True)
|
||||||
is_recompute = self.backend_kwargs.get("is_recompute", True)
|
is_recompute = self.backend_kwargs.get("is_recompute", True)
|
||||||
meta_data["is_compact"] = is_compact
|
meta_data["is_compact"] = is_compact
|
||||||
meta_data["is_pruned"] = is_compact and is_recompute
|
meta_data["is_pruned"] = bool(is_recompute)
|
||||||
|
|
||||||
with open(leann_meta_path, "w", encoding="utf-8") as f:
|
with open(leann_meta_path, "w", encoding="utf-8") as f:
|
||||||
json.dump(meta_data, f, indent=2)
|
json.dump(meta_data, f, indent=2)
|
||||||
|
|
||||||
logger.info(f"Index built successfully from precomputed embeddings: {index_path}")
|
logger.info(f"Index built successfully from precomputed embeddings: {index_path}")
|
||||||
|
|
||||||
|
def update_index(self, index_path: str):
|
||||||
|
"""Append new passages and vectors to an existing HNSW index."""
|
||||||
|
if not self.chunks:
|
||||||
|
raise ValueError("No new chunks provided for update.")
|
||||||
|
|
||||||
|
path = Path(index_path)
|
||||||
|
index_dir = path.parent
|
||||||
|
index_name = path.name
|
||||||
|
index_prefix = path.stem
|
||||||
|
|
||||||
|
meta_path = index_dir / f"{index_name}.meta.json"
|
||||||
|
passages_file = index_dir / f"{index_name}.passages.jsonl"
|
||||||
|
offset_file = index_dir / f"{index_name}.passages.idx"
|
||||||
|
index_file = index_dir / f"{index_prefix}.index"
|
||||||
|
|
||||||
|
if not meta_path.exists() or not passages_file.exists() or not offset_file.exists():
|
||||||
|
raise FileNotFoundError("Index metadata or passage files are missing; cannot update.")
|
||||||
|
if not index_file.exists():
|
||||||
|
raise FileNotFoundError(f"HNSW index file not found: {index_file}")
|
||||||
|
|
||||||
|
with open(meta_path, encoding="utf-8") as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
backend_name = meta.get("backend_name")
|
||||||
|
if backend_name != self.backend_name:
|
||||||
|
raise ValueError(
|
||||||
|
f"Index was built with backend '{backend_name}', cannot update with '{self.backend_name}'."
|
||||||
|
)
|
||||||
|
|
||||||
|
meta_backend_kwargs = meta.get("backend_kwargs", {})
|
||||||
|
index_is_compact = meta.get("is_compact", meta_backend_kwargs.get("is_compact", True))
|
||||||
|
if index_is_compact:
|
||||||
|
raise ValueError(
|
||||||
|
"Compact HNSW indices do not support in-place updates. Rebuild required."
|
||||||
|
)
|
||||||
|
|
||||||
|
distance_metric = meta_backend_kwargs.get(
|
||||||
|
"distance_metric", self.backend_kwargs.get("distance_metric", "mips")
|
||||||
|
).lower()
|
||||||
|
needs_recompute = bool(
|
||||||
|
meta.get("is_pruned")
|
||||||
|
or meta_backend_kwargs.get("is_recompute")
|
||||||
|
or self.backend_kwargs.get("is_recompute")
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(offset_file, "rb") as f:
|
||||||
|
offset_map: dict[str, int] = pickle.load(f)
|
||||||
|
existing_ids = set(offset_map.keys())
|
||||||
|
|
||||||
|
valid_chunks: list[dict[str, Any]] = []
|
||||||
|
for chunk in self.chunks:
|
||||||
|
text = chunk.get("text", "")
|
||||||
|
if not isinstance(text, str) or not text.strip():
|
||||||
|
continue
|
||||||
|
metadata = chunk.setdefault("metadata", {})
|
||||||
|
passage_id = chunk.get("id") or metadata.get("id")
|
||||||
|
if passage_id and passage_id in existing_ids:
|
||||||
|
raise ValueError(f"Passage ID '{passage_id}' already exists in the index.")
|
||||||
|
valid_chunks.append(chunk)
|
||||||
|
|
||||||
|
if not valid_chunks:
|
||||||
|
raise ValueError("No valid chunks to append.")
|
||||||
|
|
||||||
|
texts_to_embed = [chunk["text"] for chunk in valid_chunks]
|
||||||
|
embeddings = compute_embeddings(
|
||||||
|
texts_to_embed,
|
||||||
|
self.embedding_model,
|
||||||
|
self.embedding_mode,
|
||||||
|
use_server=False,
|
||||||
|
is_build=True,
|
||||||
|
provider_options=self.embedding_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
embedding_dim = embeddings.shape[1]
|
||||||
|
expected_dim = meta.get("dimensions")
|
||||||
|
if expected_dim is not None and expected_dim != embedding_dim:
|
||||||
|
raise ValueError(
|
||||||
|
f"Dimension mismatch during update: existing index uses {expected_dim}, got {embedding_dim}."
|
||||||
|
)
|
||||||
|
|
||||||
|
from leann_backend_hnsw import faiss # type: ignore
|
||||||
|
|
||||||
|
embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||||
|
if distance_metric == "cosine":
|
||||||
|
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
|
||||||
|
norms[norms == 0] = 1
|
||||||
|
embeddings = embeddings / norms
|
||||||
|
|
||||||
|
index = faiss.read_index(str(index_file))
|
||||||
|
if hasattr(index, "is_recompute"):
|
||||||
|
index.is_recompute = needs_recompute
|
||||||
|
if getattr(index, "storage", None) is None:
|
||||||
|
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
|
||||||
|
storage_index = faiss.IndexFlatIP(index.d)
|
||||||
|
else:
|
||||||
|
storage_index = faiss.IndexFlatL2(index.d)
|
||||||
|
index.storage = storage_index
|
||||||
|
index.own_fields = True
|
||||||
|
if index.d != embedding_dim:
|
||||||
|
raise ValueError(
|
||||||
|
f"Existing index dimension ({index.d}) does not match new embeddings ({embedding_dim})."
|
||||||
|
)
|
||||||
|
|
||||||
|
base_id = index.ntotal
|
||||||
|
for offset, chunk in enumerate(valid_chunks):
|
||||||
|
new_id = str(base_id + offset)
|
||||||
|
chunk.setdefault("metadata", {})["id"] = new_id
|
||||||
|
chunk["id"] = new_id
|
||||||
|
|
||||||
|
index.add(embeddings.shape[0], faiss.swig_ptr(embeddings))
|
||||||
|
faiss.write_index(index, str(index_file))
|
||||||
|
|
||||||
|
with open(passages_file, "a", encoding="utf-8") as f:
|
||||||
|
for chunk in valid_chunks:
|
||||||
|
offset = f.tell()
|
||||||
|
json.dump(
|
||||||
|
{
|
||||||
|
"id": chunk["id"],
|
||||||
|
"text": chunk["text"],
|
||||||
|
"metadata": chunk.get("metadata", {}),
|
||||||
|
},
|
||||||
|
f,
|
||||||
|
ensure_ascii=False,
|
||||||
|
)
|
||||||
|
f.write("\n")
|
||||||
|
offset_map[chunk["id"]] = offset
|
||||||
|
|
||||||
|
with open(offset_file, "wb") as f:
|
||||||
|
pickle.dump(offset_map, f)
|
||||||
|
|
||||||
|
meta["total_passages"] = len(offset_map)
|
||||||
|
with open(meta_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(meta, f, indent=2)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Appended %d passages to index '%s'. New total: %d",
|
||||||
|
len(valid_chunks),
|
||||||
|
index_path,
|
||||||
|
len(offset_map),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.chunks.clear()
|
||||||
|
|
||||||
|
if needs_recompute:
|
||||||
|
prune_hnsw_embeddings_inplace(str(index_file))
|
||||||
|
|
||||||
|
|
||||||
class LeannSearcher:
|
class LeannSearcher:
|
||||||
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
|
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
|
||||||
@@ -473,12 +806,20 @@ class LeannSearcher:
|
|||||||
self.embedding_model = self.meta_data["embedding_model"]
|
self.embedding_model = self.meta_data["embedding_model"]
|
||||||
# Support both old and new format
|
# Support both old and new format
|
||||||
self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
|
self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
|
||||||
self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
|
self.embedding_options = self.meta_data.get("embedding_options", {})
|
||||||
|
# Delegate portability handling to PassageManager
|
||||||
|
self.passage_manager = PassageManager(
|
||||||
|
self.meta_data.get("passage_sources", []), metadata_file_path=self.meta_path_str
|
||||||
|
)
|
||||||
|
# Preserve backend name for conditional parameter forwarding
|
||||||
|
self.backend_name = backend_name
|
||||||
backend_factory = BACKEND_REGISTRY.get(backend_name)
|
backend_factory = BACKEND_REGISTRY.get(backend_name)
|
||||||
if backend_factory is None:
|
if backend_factory is None:
|
||||||
raise ValueError(f"Backend '{backend_name}' not found.")
|
raise ValueError(f"Backend '{backend_name}' not found.")
|
||||||
final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
|
final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
|
||||||
final_kwargs["enable_warmup"] = enable_warmup
|
final_kwargs["enable_warmup"] = enable_warmup
|
||||||
|
if self.embedding_options:
|
||||||
|
final_kwargs.setdefault("embedding_options", self.embedding_options)
|
||||||
self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
|
self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
|
||||||
index_path, **final_kwargs
|
index_path, **final_kwargs
|
||||||
)
|
)
|
||||||
@@ -493,15 +834,49 @@ class LeannSearcher:
|
|||||||
recompute_embeddings: bool = True,
|
recompute_embeddings: bool = True,
|
||||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
expected_zmq_port: int = 5557,
|
expected_zmq_port: int = 5557,
|
||||||
|
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
|
||||||
|
batch_size: int = 0,
|
||||||
|
use_grep: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> list[SearchResult]:
|
) -> list[SearchResult]:
|
||||||
|
"""
|
||||||
|
Search for nearest neighbors with optional metadata filtering.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Text query to search for
|
||||||
|
top_k: Number of nearest neighbors to return
|
||||||
|
complexity: Search complexity/candidate list size, higher = more accurate but slower
|
||||||
|
beam_width: Number of parallel search paths/IO requests per iteration
|
||||||
|
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
|
||||||
|
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored codes
|
||||||
|
pruning_strategy: Candidate selection strategy - "global" (default), "local", or "proportional"
|
||||||
|
expected_zmq_port: ZMQ port for embedding server communication
|
||||||
|
metadata_filters: Optional filters to apply to search results based on metadata.
|
||||||
|
Format: {"field_name": {"operator": value}}
|
||||||
|
Supported operators:
|
||||||
|
- Comparison: "==", "!=", "<", "<=", ">", ">="
|
||||||
|
- Membership: "in", "not_in"
|
||||||
|
- String: "contains", "starts_with", "ends_with"
|
||||||
|
Example: {"chapter": {"<=": 5}, "tags": {"in": ["fiction", "drama"]}}
|
||||||
|
**kwargs: Backend-specific parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of SearchResult objects with text, metadata, and similarity scores
|
||||||
|
"""
|
||||||
|
# Handle grep search
|
||||||
|
if use_grep:
|
||||||
|
return self._grep_search(query, top_k)
|
||||||
|
|
||||||
logger.info("🔍 LeannSearcher.search() called:")
|
logger.info("🔍 LeannSearcher.search() called:")
|
||||||
logger.info(f" Query: '{query}'")
|
logger.info(f" Query: '{query}'")
|
||||||
logger.info(f" Top_k: {top_k}")
|
logger.info(f" Top_k: {top_k}")
|
||||||
|
logger.info(f" Metadata filters: {metadata_filters}")
|
||||||
logger.info(f" Additional kwargs: {kwargs}")
|
logger.info(f" Additional kwargs: {kwargs}")
|
||||||
|
|
||||||
# Smart top_k detection and adjustment
|
# Smart top_k detection and adjustment
|
||||||
total_docs = len(self.passage_manager.global_offset_map)
|
# Use PassageManager length (sum of shard sizes) to avoid
|
||||||
|
# depending on a massive combined map
|
||||||
|
total_docs = len(self.passage_manager)
|
||||||
original_top_k = top_k
|
original_top_k = top_k
|
||||||
if top_k > total_docs:
|
if top_k > total_docs:
|
||||||
top_k = total_docs
|
top_k = total_docs
|
||||||
@@ -530,31 +905,41 @@ class LeannSearcher:
|
|||||||
use_server_if_available=recompute_embeddings,
|
use_server_if_available=recompute_embeddings,
|
||||||
zmq_port=zmq_port,
|
zmq_port=zmq_port,
|
||||||
)
|
)
|
||||||
# logger.info(f" Generated embedding shape: {query_embedding.shape}")
|
logger.info(f" Generated embedding shape: {query_embedding.shape}")
|
||||||
time.time() - start_time
|
embedding_time = time.time() - start_time
|
||||||
# logger.info(f" Embedding time: {embedding_time} seconds")
|
logger.info(f" Embedding time: {embedding_time} seconds")
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
backend_search_kwargs: dict[str, Any] = {
|
||||||
|
"complexity": complexity,
|
||||||
|
"beam_width": beam_width,
|
||||||
|
"prune_ratio": prune_ratio,
|
||||||
|
"recompute_embeddings": recompute_embeddings,
|
||||||
|
"pruning_strategy": pruning_strategy,
|
||||||
|
"zmq_port": zmq_port,
|
||||||
|
}
|
||||||
|
# Only HNSW supports batching; forward conditionally
|
||||||
|
if self.backend_name == "hnsw":
|
||||||
|
backend_search_kwargs["batch_size"] = batch_size
|
||||||
|
|
||||||
|
# Merge any extra kwargs last
|
||||||
|
backend_search_kwargs.update(kwargs)
|
||||||
|
|
||||||
results = self.backend_impl.search(
|
results = self.backend_impl.search(
|
||||||
query_embedding,
|
query_embedding,
|
||||||
top_k,
|
top_k,
|
||||||
complexity=complexity,
|
**backend_search_kwargs,
|
||||||
beam_width=beam_width,
|
|
||||||
prune_ratio=prune_ratio,
|
|
||||||
recompute_embeddings=recompute_embeddings,
|
|
||||||
pruning_strategy=pruning_strategy,
|
|
||||||
zmq_port=zmq_port,
|
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
time.time() - start_time
|
search_time = time.time() - start_time
|
||||||
# logger.info(f" Search time: {search_time} seconds")
|
logger.info(f" Search time in search() LEANN searcher: {search_time} seconds")
|
||||||
logger.info(f" Backend returned: labels={len(results.get('labels', [[]])[0])} results")
|
logger.info(f" Backend returned: labels={len(results.get('labels', [[]])[0])} results")
|
||||||
|
|
||||||
enriched_results = []
|
enriched_results = []
|
||||||
if "labels" in results and "distances" in results:
|
if "labels" in results and "distances" in results:
|
||||||
logger.info(f" Processing {len(results['labels'][0])} passage IDs:")
|
logger.info(f" Processing {len(results['labels'][0])} passage IDs:")
|
||||||
|
# Python 3.9 does not support zip(strict=...); lengths are expected to match
|
||||||
for i, (string_id, dist) in enumerate(
|
for i, (string_id, dist) in enumerate(
|
||||||
zip(results["labels"][0], results["distances"][0], strict=False)
|
zip(results["labels"][0], results["distances"][0])
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
passage_data = self.passage_manager.get_passage(string_id)
|
passage_data = self.passage_manager.get_passage(string_id)
|
||||||
@@ -580,23 +965,154 @@ class LeannSearcher:
|
|||||||
)
|
)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
RED = "\033[91m"
|
RED = "\033[91m"
|
||||||
|
RESET = "\033[0m"
|
||||||
logger.error(
|
logger.error(
|
||||||
f" {RED}✗{RESET} [{i + 1:2d}] ID: '{string_id}' -> {RED}ERROR: Passage not found!{RESET}"
|
f" {RED}✗{RESET} [{i + 1:2d}] ID: '{string_id}' -> {RED}ERROR: Passage not found!{RESET}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Apply metadata filters if specified
|
||||||
|
if metadata_filters:
|
||||||
|
logger.info(f" 🔍 Applying metadata filters: {metadata_filters}")
|
||||||
|
enriched_results = self.passage_manager.filter_search_results(
|
||||||
|
enriched_results, metadata_filters
|
||||||
|
)
|
||||||
|
|
||||||
|
# Define color codes outside the loop for final message
|
||||||
|
GREEN = "\033[92m"
|
||||||
|
RESET = "\033[0m"
|
||||||
logger.info(f" {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}")
|
logger.info(f" {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}")
|
||||||
return enriched_results
|
return enriched_results
|
||||||
|
|
||||||
|
def _find_jsonl_file(self) -> Optional[str]:
|
||||||
|
"""Find the .jsonl file containing raw passages for grep search"""
|
||||||
|
index_path = Path(self.meta_path_str).parent
|
||||||
|
potential_files = [
|
||||||
|
index_path / "documents.leann.passages.jsonl",
|
||||||
|
index_path.parent / "documents.leann.passages.jsonl",
|
||||||
|
]
|
||||||
|
|
||||||
|
for file_path in potential_files:
|
||||||
|
if file_path.exists():
|
||||||
|
return str(file_path)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _grep_search(self, query: str, top_k: int = 5) -> list[SearchResult]:
|
||||||
|
"""Perform grep-based search on raw passages"""
|
||||||
|
jsonl_file = self._find_jsonl_file()
|
||||||
|
if not jsonl_file:
|
||||||
|
raise FileNotFoundError("No .jsonl passages file found for grep search")
|
||||||
|
|
||||||
|
try:
|
||||||
|
cmd = ["grep", "-i", "-n", query, jsonl_file]
|
||||||
|
result = subprocess.run(cmd, capture_output=True, text=True, check=False)
|
||||||
|
|
||||||
|
if result.returncode == 1:
|
||||||
|
return []
|
||||||
|
elif result.returncode != 0:
|
||||||
|
raise RuntimeError(f"Grep failed: {result.stderr}")
|
||||||
|
|
||||||
|
matches = []
|
||||||
|
for line in result.stdout.strip().split("\n"):
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
parts = line.split(":", 1)
|
||||||
|
if len(parts) != 2:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = json.loads(parts[1])
|
||||||
|
text = data.get("text", "")
|
||||||
|
score = text.lower().count(query.lower())
|
||||||
|
|
||||||
|
matches.append(
|
||||||
|
SearchResult(
|
||||||
|
id=data.get("id", parts[0]),
|
||||||
|
text=text,
|
||||||
|
metadata=data.get("metadata", {}),
|
||||||
|
score=float(score),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
matches.sort(key=lambda x: x.score, reverse=True)
|
||||||
|
return matches[:top_k]
|
||||||
|
|
||||||
|
except FileNotFoundError:
|
||||||
|
raise RuntimeError(
|
||||||
|
"grep command not found. Please install grep or use semantic search."
|
||||||
|
)
|
||||||
|
|
||||||
|
def _python_regex_search(self, query: str, top_k: int = 5) -> list[SearchResult]:
|
||||||
|
"""Fallback regex search"""
|
||||||
|
jsonl_file = self._find_jsonl_file()
|
||||||
|
if not jsonl_file:
|
||||||
|
raise FileNotFoundError("No .jsonl file found")
|
||||||
|
|
||||||
|
pattern = re.compile(re.escape(query), re.IGNORECASE)
|
||||||
|
matches = []
|
||||||
|
|
||||||
|
with open(jsonl_file, encoding="utf-8") as f:
|
||||||
|
for line_num, line in enumerate(f, 1):
|
||||||
|
if pattern.search(line):
|
||||||
|
try:
|
||||||
|
data = json.loads(line.strip())
|
||||||
|
matches.append(
|
||||||
|
SearchResult(
|
||||||
|
id=data.get("id", str(line_num)),
|
||||||
|
text=data.get("text", ""),
|
||||||
|
metadata=data.get("metadata", {}),
|
||||||
|
score=float(len(pattern.findall(data.get("text", "")))),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
matches.sort(key=lambda x: x.score, reverse=True)
|
||||||
|
return matches[:top_k]
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""Explicitly cleanup embedding server resources.
|
||||||
|
This method should be called after you're done using the searcher,
|
||||||
|
especially in test environments or batch processing scenarios.
|
||||||
|
"""
|
||||||
|
backend = getattr(self.backend_impl, "embedding_server_manager", None)
|
||||||
|
if backend is not None:
|
||||||
|
backend.stop_server()
|
||||||
|
|
||||||
|
# Enable automatic cleanup patterns
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc, tb):
|
||||||
|
try:
|
||||||
|
self.cleanup()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
try:
|
||||||
|
self.cleanup()
|
||||||
|
except Exception:
|
||||||
|
# Avoid noisy errors during interpreter shutdown
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class LeannChat:
|
class LeannChat:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
index_path: str,
|
index_path: str,
|
||||||
llm_config: dict[str, Any] | None = None,
|
llm_config: Optional[dict[str, Any]] = None,
|
||||||
enable_warmup: bool = False,
|
enable_warmup: bool = False,
|
||||||
|
searcher: Optional[LeannSearcher] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.searcher = LeannSearcher(index_path, enable_warmup=enable_warmup, **kwargs)
|
if searcher is None:
|
||||||
|
self.searcher = LeannSearcher(index_path, enable_warmup=enable_warmup, **kwargs)
|
||||||
|
self._owns_searcher = True
|
||||||
|
else:
|
||||||
|
self.searcher = searcher
|
||||||
|
self._owns_searcher = False
|
||||||
self.llm = get_llm(llm_config)
|
self.llm = get_llm(llm_config)
|
||||||
|
|
||||||
def ask(
|
def ask(
|
||||||
@@ -608,8 +1124,11 @@ class LeannChat:
|
|||||||
prune_ratio: float = 0.0,
|
prune_ratio: float = 0.0,
|
||||||
recompute_embeddings: bool = True,
|
recompute_embeddings: bool = True,
|
||||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
llm_kwargs: dict[str, Any] | None = None,
|
llm_kwargs: Optional[dict[str, Any]] = None,
|
||||||
expected_zmq_port: int = 5557,
|
expected_zmq_port: int = 5557,
|
||||||
|
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
|
||||||
|
batch_size: int = 0,
|
||||||
|
use_grep: bool = False,
|
||||||
**search_kwargs,
|
**search_kwargs,
|
||||||
):
|
):
|
||||||
if llm_kwargs is None:
|
if llm_kwargs is None:
|
||||||
@@ -624,10 +1143,12 @@ class LeannChat:
|
|||||||
recompute_embeddings=recompute_embeddings,
|
recompute_embeddings=recompute_embeddings,
|
||||||
pruning_strategy=pruning_strategy,
|
pruning_strategy=pruning_strategy,
|
||||||
expected_zmq_port=expected_zmq_port,
|
expected_zmq_port=expected_zmq_port,
|
||||||
|
metadata_filters=metadata_filters,
|
||||||
|
batch_size=batch_size,
|
||||||
**search_kwargs,
|
**search_kwargs,
|
||||||
)
|
)
|
||||||
search_time = time.time() - search_time
|
search_time = time.time() - search_time
|
||||||
# logger.info(f" Search time: {search_time} seconds")
|
logger.info(f" Search time: {search_time} seconds")
|
||||||
context = "\n\n".join([r.text for r in results])
|
context = "\n\n".join([r.text for r in results])
|
||||||
prompt = (
|
prompt = (
|
||||||
"Here is some retrieved context that might help answer your question:\n\n"
|
"Here is some retrieved context that might help answer your question:\n\n"
|
||||||
@@ -656,3 +1177,30 @@ class LeannChat:
|
|||||||
except (KeyboardInterrupt, EOFError):
|
except (KeyboardInterrupt, EOFError):
|
||||||
print("\nGoodbye!")
|
print("\nGoodbye!")
|
||||||
break
|
break
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""Explicitly cleanup embedding server resources.
|
||||||
|
|
||||||
|
This method should be called after you're done using the chat interface,
|
||||||
|
especially in test environments or batch processing scenarios.
|
||||||
|
"""
|
||||||
|
# Only stop the embedding server if this LeannChat instance created the searcher.
|
||||||
|
# When a shared searcher is passed in, avoid shutting down the server to enable reuse.
|
||||||
|
if getattr(self, "_owns_searcher", False) and hasattr(self.searcher, "cleanup"):
|
||||||
|
self.searcher.cleanup()
|
||||||
|
|
||||||
|
# Enable automatic cleanup patterns
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc, tb):
|
||||||
|
try:
|
||||||
|
self.cleanup()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
try:
|
||||||
|
self.cleanup()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import difflib
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any
|
from typing import Any, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -17,12 +17,12 @@ logging.basicConfig(level=logging.INFO)
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def check_ollama_models() -> list[str]:
|
def check_ollama_models(host: str) -> list[str]:
|
||||||
"""Check available Ollama models and return a list"""
|
"""Check available Ollama models and return a list"""
|
||||||
try:
|
try:
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
response = requests.get("http://localhost:11434/api/tags", timeout=5)
|
response = requests.get(f"{host}/api/tags", timeout=5)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
data = response.json()
|
data = response.json()
|
||||||
return [model["name"] for model in data.get("models", [])]
|
return [model["name"] for model in data.get("models", [])]
|
||||||
@@ -309,10 +309,12 @@ def search_hf_models(query: str, limit: int = 10) -> list[str]:
|
|||||||
return search_hf_models_fuzzy(query, limit)
|
return search_hf_models_fuzzy(query, limit)
|
||||||
|
|
||||||
|
|
||||||
def validate_model_and_suggest(model_name: str, llm_type: str) -> str | None:
|
def validate_model_and_suggest(
|
||||||
|
model_name: str, llm_type: str, host: str = "http://localhost:11434"
|
||||||
|
) -> Optional[str]:
|
||||||
"""Validate model name and provide suggestions if invalid"""
|
"""Validate model name and provide suggestions if invalid"""
|
||||||
if llm_type == "ollama":
|
if llm_type == "ollama":
|
||||||
available_models = check_ollama_models()
|
available_models = check_ollama_models(host)
|
||||||
if available_models and model_name not in available_models:
|
if available_models and model_name not in available_models:
|
||||||
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
|
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
|
||||||
|
|
||||||
@@ -420,7 +422,6 @@ class LLMInterface(ABC):
|
|||||||
top_k=10,
|
top_k=10,
|
||||||
complexity=64,
|
complexity=64,
|
||||||
beam_width=8,
|
beam_width=8,
|
||||||
USE_DEFERRED_FETCH=True,
|
|
||||||
skip_search_reorder=True,
|
skip_search_reorder=True,
|
||||||
recompute_beighbor_embeddings=True,
|
recompute_beighbor_embeddings=True,
|
||||||
dedup_node_dis=True,
|
dedup_node_dis=True,
|
||||||
@@ -432,7 +433,6 @@ class LLMInterface(ABC):
|
|||||||
Supported kwargs:
|
Supported kwargs:
|
||||||
- complexity (int): Search complexity parameter (default: 32)
|
- complexity (int): Search complexity parameter (default: 32)
|
||||||
- beam_width (int): Beam width for search (default: 4)
|
- beam_width (int): Beam width for search (default: 4)
|
||||||
- USE_DEFERRED_FETCH (bool): Enable deferred fetch mode (default: False)
|
|
||||||
- skip_search_reorder (bool): Skip search reorder step (default: False)
|
- skip_search_reorder (bool): Skip search reorder step (default: False)
|
||||||
- recompute_beighbor_embeddings (bool): Enable ZMQ embedding server for neighbor recomputation (default: False)
|
- recompute_beighbor_embeddings (bool): Enable ZMQ embedding server for neighbor recomputation (default: False)
|
||||||
- dedup_node_dis (bool): Deduplicate nodes by distance (default: False)
|
- dedup_node_dis (bool): Deduplicate nodes by distance (default: False)
|
||||||
@@ -469,7 +469,7 @@ class OllamaChat(LLMInterface):
|
|||||||
requests.get(host)
|
requests.get(host)
|
||||||
|
|
||||||
# Pre-check model availability with helpful suggestions
|
# Pre-check model availability with helpful suggestions
|
||||||
model_error = validate_model_and_suggest(model, "ollama")
|
model_error = validate_model_and_suggest(model, "ollama", host)
|
||||||
if model_error:
|
if model_error:
|
||||||
raise ValueError(model_error)
|
raise ValueError(model_error)
|
||||||
|
|
||||||
@@ -577,18 +577,33 @@ class HFChat(LLMInterface):
|
|||||||
def timeout_handler(signum, frame):
|
def timeout_handler(signum, frame):
|
||||||
raise TimeoutError("Model download/loading timed out")
|
raise TimeoutError("Model download/loading timed out")
|
||||||
|
|
||||||
# Set timeout for model loading (60 seconds)
|
# Set timeout for model loading (increase to 300s for large models)
|
||||||
old_handler = signal.signal(signal.SIGALRM, timeout_handler)
|
old_handler = signal.signal(signal.SIGALRM, timeout_handler)
|
||||||
signal.alarm(60)
|
signal.alarm(300)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info(f"Loading tokenizer for {model_name}...")
|
logger.info(f"Loading tokenizer for {model_name}...")
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
|
||||||
logger.info(f"Loading model {model_name}...")
|
logger.info(f"Loading model {model_name}...")
|
||||||
|
# Choose a numerically stable dtype per device
|
||||||
|
if self.device == "cuda":
|
||||||
|
# Prefer bfloat16 when available; otherwise fall back to float16
|
||||||
|
try:
|
||||||
|
bf16_ok = torch.cuda.is_bf16_supported()
|
||||||
|
except Exception:
|
||||||
|
bf16_ok = False
|
||||||
|
load_dtype = torch.bfloat16 if bf16_ok else torch.float16
|
||||||
|
elif self.device == "mps":
|
||||||
|
# On Apple MPS, float16 often causes NaNs/INFs during sampling.
|
||||||
|
# Use float32 for stability, even if it increases memory.
|
||||||
|
load_dtype = torch.float32
|
||||||
|
else:
|
||||||
|
load_dtype = torch.float32
|
||||||
|
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(
|
self.model = AutoModelForCausalLM.from_pretrained(
|
||||||
model_name,
|
model_name,
|
||||||
torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
|
torch_dtype=load_dtype,
|
||||||
device_map="auto" if self.device != "cpu" else None,
|
device_map="auto" if self.device != "cpu" else None,
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
)
|
)
|
||||||
@@ -606,8 +621,12 @@ class HFChat(LLMInterface):
|
|||||||
logger.error(f"Failed to load model {model_name}: {e}")
|
logger.error(f"Failed to load model {model_name}: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# Move model to device if not using device_map
|
# Move model to device only if not managed by accelerate (no device_map)
|
||||||
if self.device != "cpu" and "device_map" not in str(self.model):
|
try:
|
||||||
|
has_device_map = getattr(self.model, "hf_device_map", None) is not None
|
||||||
|
except Exception:
|
||||||
|
has_device_map = False
|
||||||
|
if self.device != "cpu" and not has_device_map:
|
||||||
self.model = self.model.to(self.device)
|
self.model = self.model.to(self.device)
|
||||||
|
|
||||||
# Set pad token if not present
|
# Set pad token if not present
|
||||||
@@ -639,13 +658,15 @@ class HFChat(LLMInterface):
|
|||||||
# Fallback for models without chat template
|
# Fallback for models without chat template
|
||||||
formatted_prompt = prompt
|
formatted_prompt = prompt
|
||||||
|
|
||||||
# Tokenize input
|
# Tokenize input (respect model context length when available)
|
||||||
inputs = self.tokenizer(
|
inputs = self.tokenizer(
|
||||||
formatted_prompt,
|
formatted_prompt,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
padding=True,
|
padding=True,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
max_length=2048,
|
max_length=getattr(
|
||||||
|
getattr(self.model, "config", None), "max_position_embeddings", 2048
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Move inputs to device
|
# Move inputs to device
|
||||||
@@ -660,6 +681,8 @@ class HFChat(LLMInterface):
|
|||||||
"do_sample": kwargs.get("temperature", 0.7) > 0,
|
"do_sample": kwargs.get("temperature", 0.7) > 0,
|
||||||
"pad_token_id": self.tokenizer.eos_token_id,
|
"pad_token_id": self.tokenizer.eos_token_id,
|
||||||
"eos_token_id": self.tokenizer.eos_token_id,
|
"eos_token_id": self.tokenizer.eos_token_id,
|
||||||
|
# Helps avoid numerical issues in sampling when logits processors are used
|
||||||
|
"renormalize_logits": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Handle temperature=0 for greedy decoding
|
# Handle temperature=0 for greedy decoding
|
||||||
@@ -669,21 +692,103 @@ class HFChat(LLMInterface):
|
|||||||
|
|
||||||
logger.info(f"Generating with HuggingFace model, config: {generation_config}")
|
logger.info(f"Generating with HuggingFace model, config: {generation_config}")
|
||||||
|
|
||||||
# Generate
|
# Streaming support (optional)
|
||||||
|
stream = bool(kwargs.get("stream", False))
|
||||||
|
if stream:
|
||||||
|
try:
|
||||||
|
from threading import Thread
|
||||||
|
|
||||||
|
from transformers import TextIteratorStreamer
|
||||||
|
|
||||||
|
streamer = TextIteratorStreamer(
|
||||||
|
self.tokenizer, skip_prompt=True, skip_special_tokens=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def _gen():
|
||||||
|
with torch.no_grad():
|
||||||
|
self.model.generate(**inputs, **generation_config, streamer=streamer)
|
||||||
|
|
||||||
|
t = Thread(target=_gen)
|
||||||
|
t.start()
|
||||||
|
|
||||||
|
pieces = []
|
||||||
|
for new_text in streamer:
|
||||||
|
print(new_text, end="", flush=True)
|
||||||
|
pieces.append(new_text)
|
||||||
|
t.join()
|
||||||
|
print("") # newline after streaming
|
||||||
|
return ("".join(pieces)).strip()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Streaming failed, falling back to non-streaming: {e}")
|
||||||
|
|
||||||
|
# Non-streaming path
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = self.model.generate(**inputs, **generation_config)
|
outputs = self.model.generate(**inputs, **generation_config)
|
||||||
|
|
||||||
# Decode response
|
|
||||||
generated_tokens = outputs[0][inputs["input_ids"].shape[1] :]
|
generated_tokens = outputs[0][inputs["input_ids"].shape[1] :]
|
||||||
response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
||||||
|
|
||||||
return response.strip()
|
return response.strip()
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiChat(LLMInterface):
|
||||||
|
"""LLM interface for Google Gemini models."""
|
||||||
|
|
||||||
|
def __init__(self, model: str = "gemini-2.5-flash", api_key: Optional[str] = None):
|
||||||
|
self.model = model
|
||||||
|
self.api_key = api_key or os.getenv("GEMINI_API_KEY")
|
||||||
|
|
||||||
|
if not self.api_key:
|
||||||
|
raise ValueError(
|
||||||
|
"Gemini API key is required. Set GEMINI_API_KEY environment variable or pass api_key parameter."
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Initializing Gemini Chat with model='{model}'")
|
||||||
|
|
||||||
|
try:
|
||||||
|
import google.genai as genai
|
||||||
|
|
||||||
|
self.client = genai.Client(api_key=self.api_key)
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"The 'google-genai' library is required for Gemini models. Please install it with 'uv pip install google-genai'."
|
||||||
|
)
|
||||||
|
|
||||||
|
def ask(self, prompt: str, **kwargs) -> str:
|
||||||
|
logger.info(f"Sending request to Gemini with model {self.model}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from google.genai.types import GenerateContentConfig
|
||||||
|
|
||||||
|
generation_config = GenerateContentConfig(
|
||||||
|
temperature=kwargs.get("temperature", 0.7),
|
||||||
|
max_output_tokens=kwargs.get("max_tokens", 1000),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle top_p parameter
|
||||||
|
if "top_p" in kwargs:
|
||||||
|
generation_config.top_p = kwargs["top_p"]
|
||||||
|
|
||||||
|
response = self.client.models.generate_content(
|
||||||
|
model=self.model,
|
||||||
|
contents=prompt,
|
||||||
|
config=generation_config,
|
||||||
|
)
|
||||||
|
# Handle potential None response text
|
||||||
|
response_text = response.text
|
||||||
|
if response_text is None:
|
||||||
|
logger.warning("Gemini returned None response text")
|
||||||
|
return ""
|
||||||
|
return response_text.strip()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error communicating with Gemini: {e}")
|
||||||
|
return f"Error: Could not get a response from Gemini. Details: {e}"
|
||||||
|
|
||||||
|
|
||||||
class OpenAIChat(LLMInterface):
|
class OpenAIChat(LLMInterface):
|
||||||
"""LLM interface for OpenAI models."""
|
"""LLM interface for OpenAI models."""
|
||||||
|
|
||||||
def __init__(self, model: str = "gpt-4o", api_key: str | None = None):
|
def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
||||||
|
|
||||||
@@ -759,7 +864,7 @@ class SimulatedChat(LLMInterface):
|
|||||||
return "This is a simulated answer from the LLM based on the retrieved context."
|
return "This is a simulated answer from the LLM based on the retrieved context."
|
||||||
|
|
||||||
|
|
||||||
def get_llm(llm_config: dict[str, Any] | None = None) -> LLMInterface:
|
def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
|
||||||
"""
|
"""
|
||||||
Factory function to get an LLM interface based on configuration.
|
Factory function to get an LLM interface based on configuration.
|
||||||
|
|
||||||
@@ -793,6 +898,8 @@ def get_llm(llm_config: dict[str, Any] | None = None) -> LLMInterface:
|
|||||||
return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
|
return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
|
||||||
elif llm_type == "openai":
|
elif llm_type == "openai":
|
||||||
return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key"))
|
return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key"))
|
||||||
|
elif llm_type == "gemini":
|
||||||
|
return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
|
||||||
elif llm_type == "simulated":
|
elif llm_type == "simulated":
|
||||||
return SimulatedChat()
|
return SimulatedChat()
|
||||||
else:
|
else:
|
||||||
|
|||||||
220
packages/leann-core/src/leann/chunking_utils.py
Normal file
220
packages/leann-core/src/leann/chunking_utils.py
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
"""
|
||||||
|
Enhanced chunking utilities with AST-aware code chunking support.
|
||||||
|
Packaged within leann-core so installed wheels can import it reliably.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Code file extensions supported by astchunk
|
||||||
|
CODE_EXTENSIONS = {
|
||||||
|
".py": "python",
|
||||||
|
".java": "java",
|
||||||
|
".cs": "csharp",
|
||||||
|
".ts": "typescript",
|
||||||
|
".tsx": "typescript",
|
||||||
|
".js": "typescript",
|
||||||
|
".jsx": "typescript",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def detect_code_files(documents, code_extensions=None) -> tuple[list, list]:
|
||||||
|
"""Separate documents into code files and regular text files."""
|
||||||
|
if code_extensions is None:
|
||||||
|
code_extensions = CODE_EXTENSIONS
|
||||||
|
|
||||||
|
code_docs = []
|
||||||
|
text_docs = []
|
||||||
|
|
||||||
|
for doc in documents:
|
||||||
|
file_path = doc.metadata.get("file_path", "") or doc.metadata.get("file_name", "")
|
||||||
|
if file_path:
|
||||||
|
file_ext = Path(file_path).suffix.lower()
|
||||||
|
if file_ext in code_extensions:
|
||||||
|
doc.metadata["language"] = code_extensions[file_ext]
|
||||||
|
doc.metadata["is_code"] = True
|
||||||
|
code_docs.append(doc)
|
||||||
|
else:
|
||||||
|
doc.metadata["is_code"] = False
|
||||||
|
text_docs.append(doc)
|
||||||
|
else:
|
||||||
|
doc.metadata["is_code"] = False
|
||||||
|
text_docs.append(doc)
|
||||||
|
|
||||||
|
logger.info(f"Detected {len(code_docs)} code files and {len(text_docs)} text files")
|
||||||
|
return code_docs, text_docs
|
||||||
|
|
||||||
|
|
||||||
|
def get_language_from_extension(file_path: str) -> Optional[str]:
|
||||||
|
"""Return language string from a filename/extension using CODE_EXTENSIONS."""
|
||||||
|
ext = Path(file_path).suffix.lower()
|
||||||
|
return CODE_EXTENSIONS.get(ext)
|
||||||
|
|
||||||
|
|
||||||
|
def create_ast_chunks(
|
||||||
|
documents,
|
||||||
|
max_chunk_size: int = 512,
|
||||||
|
chunk_overlap: int = 64,
|
||||||
|
metadata_template: str = "default",
|
||||||
|
) -> list[str]:
|
||||||
|
"""Create AST-aware chunks from code documents using astchunk.
|
||||||
|
|
||||||
|
Falls back to traditional chunking if astchunk is unavailable.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from astchunk import ASTChunkBuilder # optional dependency
|
||||||
|
except ImportError as e:
|
||||||
|
logger.error(f"astchunk not available: {e}")
|
||||||
|
logger.info("Falling back to traditional chunking for code files")
|
||||||
|
return create_traditional_chunks(documents, max_chunk_size, chunk_overlap)
|
||||||
|
|
||||||
|
all_chunks = []
|
||||||
|
for doc in documents:
|
||||||
|
language = doc.metadata.get("language")
|
||||||
|
if not language:
|
||||||
|
logger.warning("No language detected; falling back to traditional chunking")
|
||||||
|
all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
configs = {
|
||||||
|
"max_chunk_size": max_chunk_size,
|
||||||
|
"language": language,
|
||||||
|
"metadata_template": metadata_template,
|
||||||
|
"chunk_overlap": chunk_overlap if chunk_overlap > 0 else 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
repo_metadata = {
|
||||||
|
"file_path": doc.metadata.get("file_path", ""),
|
||||||
|
"file_name": doc.metadata.get("file_name", ""),
|
||||||
|
"creation_date": doc.metadata.get("creation_date", ""),
|
||||||
|
"last_modified_date": doc.metadata.get("last_modified_date", ""),
|
||||||
|
}
|
||||||
|
configs["repo_level_metadata"] = repo_metadata
|
||||||
|
|
||||||
|
chunk_builder = ASTChunkBuilder(**configs)
|
||||||
|
code_content = doc.get_content()
|
||||||
|
if not code_content or not code_content.strip():
|
||||||
|
logger.warning("Empty code content, skipping")
|
||||||
|
continue
|
||||||
|
|
||||||
|
chunks = chunk_builder.chunkify(code_content)
|
||||||
|
for chunk in chunks:
|
||||||
|
if hasattr(chunk, "text"):
|
||||||
|
chunk_text = chunk.text
|
||||||
|
elif isinstance(chunk, dict) and "text" in chunk:
|
||||||
|
chunk_text = chunk["text"]
|
||||||
|
elif isinstance(chunk, str):
|
||||||
|
chunk_text = chunk
|
||||||
|
else:
|
||||||
|
chunk_text = str(chunk)
|
||||||
|
|
||||||
|
if chunk_text and chunk_text.strip():
|
||||||
|
all_chunks.append(chunk_text.strip())
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Created {len(chunks)} AST chunks from {language} file: {doc.metadata.get('file_name', 'unknown')}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"AST chunking failed for {language} file: {e}")
|
||||||
|
logger.info("Falling back to traditional chunking")
|
||||||
|
all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
|
||||||
|
|
||||||
|
return all_chunks
|
||||||
|
|
||||||
|
|
||||||
|
def create_traditional_chunks(
|
||||||
|
documents, chunk_size: int = 256, chunk_overlap: int = 128
|
||||||
|
) -> list[str]:
|
||||||
|
"""Create traditional text chunks using LlamaIndex SentenceSplitter."""
|
||||||
|
if chunk_size <= 0:
|
||||||
|
logger.warning(f"Invalid chunk_size={chunk_size}, using default value of 256")
|
||||||
|
chunk_size = 256
|
||||||
|
if chunk_overlap < 0:
|
||||||
|
chunk_overlap = 0
|
||||||
|
if chunk_overlap >= chunk_size:
|
||||||
|
chunk_overlap = chunk_size // 2
|
||||||
|
|
||||||
|
node_parser = SentenceSplitter(
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
chunk_overlap=chunk_overlap,
|
||||||
|
separator=" ",
|
||||||
|
paragraph_separator="\n\n",
|
||||||
|
)
|
||||||
|
|
||||||
|
all_texts = []
|
||||||
|
for doc in documents:
|
||||||
|
try:
|
||||||
|
nodes = node_parser.get_nodes_from_documents([doc])
|
||||||
|
if nodes:
|
||||||
|
all_texts.extend(node.get_content() for node in nodes)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Traditional chunking failed for document: {e}")
|
||||||
|
content = doc.get_content()
|
||||||
|
if content and content.strip():
|
||||||
|
all_texts.append(content.strip())
|
||||||
|
|
||||||
|
return all_texts
|
||||||
|
|
||||||
|
|
||||||
|
def create_text_chunks(
|
||||||
|
documents,
|
||||||
|
chunk_size: int = 256,
|
||||||
|
chunk_overlap: int = 128,
|
||||||
|
use_ast_chunking: bool = False,
|
||||||
|
ast_chunk_size: int = 512,
|
||||||
|
ast_chunk_overlap: int = 64,
|
||||||
|
code_file_extensions: Optional[list[str]] = None,
|
||||||
|
ast_fallback_traditional: bool = True,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Create text chunks from documents with optional AST support for code files."""
|
||||||
|
if not documents:
|
||||||
|
logger.warning("No documents provided for chunking")
|
||||||
|
return []
|
||||||
|
|
||||||
|
local_code_extensions = CODE_EXTENSIONS.copy()
|
||||||
|
if code_file_extensions:
|
||||||
|
ext_mapping = {
|
||||||
|
".py": "python",
|
||||||
|
".java": "java",
|
||||||
|
".cs": "c_sharp",
|
||||||
|
".ts": "typescript",
|
||||||
|
".tsx": "typescript",
|
||||||
|
}
|
||||||
|
for ext in code_file_extensions:
|
||||||
|
if ext.lower() not in local_code_extensions:
|
||||||
|
if ext.lower() in ext_mapping:
|
||||||
|
local_code_extensions[ext.lower()] = ext_mapping[ext.lower()]
|
||||||
|
else:
|
||||||
|
logger.warning(f"Unsupported extension {ext}, will use traditional chunking")
|
||||||
|
|
||||||
|
all_chunks = []
|
||||||
|
if use_ast_chunking:
|
||||||
|
code_docs, text_docs = detect_code_files(documents, local_code_extensions)
|
||||||
|
if code_docs:
|
||||||
|
try:
|
||||||
|
all_chunks.extend(
|
||||||
|
create_ast_chunks(
|
||||||
|
code_docs, max_chunk_size=ast_chunk_size, chunk_overlap=ast_chunk_overlap
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"AST chunking failed: {e}")
|
||||||
|
if ast_fallback_traditional:
|
||||||
|
all_chunks.extend(
|
||||||
|
create_traditional_chunks(code_docs, chunk_size, chunk_overlap)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
if text_docs:
|
||||||
|
all_chunks.extend(create_traditional_chunks(text_docs, chunk_size, chunk_overlap))
|
||||||
|
else:
|
||||||
|
all_chunks = create_traditional_chunks(documents, chunk_size, chunk_overlap)
|
||||||
|
|
||||||
|
logger.info(f"Total chunks created: {len(all_chunks)}")
|
||||||
|
return all_chunks
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -6,11 +6,14 @@ Preserves all optimization parameters to ensure performance
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Any
|
import time
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||||
|
|
||||||
# Set up logger with proper level
|
# Set up logger with proper level
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||||
@@ -28,6 +31,9 @@ def compute_embeddings(
|
|||||||
is_build: bool = False,
|
is_build: bool = False,
|
||||||
batch_size: int = 32,
|
batch_size: int = 32,
|
||||||
adaptive_optimization: bool = True,
|
adaptive_optimization: bool = True,
|
||||||
|
manual_tokenize: bool = False,
|
||||||
|
max_length: int = 512,
|
||||||
|
provider_options: Optional[dict[str, Any]] = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Unified embedding computation entry point
|
Unified embedding computation entry point
|
||||||
@@ -35,7 +41,7 @@ def compute_embeddings(
|
|||||||
Args:
|
Args:
|
||||||
texts: List of texts to compute embeddings for
|
texts: List of texts to compute embeddings for
|
||||||
model_name: Model name
|
model_name: Model name
|
||||||
mode: Computation mode ('sentence-transformers', 'openai', 'mlx')
|
mode: Computation mode ('sentence-transformers', 'openai', 'mlx', 'ollama')
|
||||||
is_build: Whether this is a build operation (shows progress bar)
|
is_build: Whether this is a build operation (shows progress bar)
|
||||||
batch_size: Batch size for processing
|
batch_size: Batch size for processing
|
||||||
adaptive_optimization: Whether to use adaptive optimization based on batch size
|
adaptive_optimization: Whether to use adaptive optimization based on batch size
|
||||||
@@ -43,6 +49,8 @@ def compute_embeddings(
|
|||||||
Returns:
|
Returns:
|
||||||
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
||||||
"""
|
"""
|
||||||
|
provider_options = provider_options or {}
|
||||||
|
|
||||||
if mode == "sentence-transformers":
|
if mode == "sentence-transformers":
|
||||||
return compute_embeddings_sentence_transformers(
|
return compute_embeddings_sentence_transformers(
|
||||||
texts,
|
texts,
|
||||||
@@ -50,11 +58,27 @@ def compute_embeddings(
|
|||||||
is_build=is_build,
|
is_build=is_build,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
adaptive_optimization=adaptive_optimization,
|
adaptive_optimization=adaptive_optimization,
|
||||||
|
manual_tokenize=manual_tokenize,
|
||||||
|
max_length=max_length,
|
||||||
)
|
)
|
||||||
elif mode == "openai":
|
elif mode == "openai":
|
||||||
return compute_embeddings_openai(texts, model_name)
|
return compute_embeddings_openai(
|
||||||
|
texts,
|
||||||
|
model_name,
|
||||||
|
base_url=provider_options.get("base_url"),
|
||||||
|
api_key=provider_options.get("api_key"),
|
||||||
|
)
|
||||||
elif mode == "mlx":
|
elif mode == "mlx":
|
||||||
return compute_embeddings_mlx(texts, model_name)
|
return compute_embeddings_mlx(texts, model_name)
|
||||||
|
elif mode == "ollama":
|
||||||
|
return compute_embeddings_ollama(
|
||||||
|
texts,
|
||||||
|
model_name,
|
||||||
|
is_build=is_build,
|
||||||
|
host=provider_options.get("host"),
|
||||||
|
)
|
||||||
|
elif mode == "gemini":
|
||||||
|
return compute_embeddings_gemini(texts, model_name, is_build=is_build)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported embedding mode: {mode}")
|
raise ValueError(f"Unsupported embedding mode: {mode}")
|
||||||
|
|
||||||
@@ -67,6 +91,8 @@ def compute_embeddings_sentence_transformers(
|
|||||||
batch_size: int = 32,
|
batch_size: int = 32,
|
||||||
is_build: bool = False,
|
is_build: bool = False,
|
||||||
adaptive_optimization: bool = True,
|
adaptive_optimization: bool = True,
|
||||||
|
manual_tokenize: bool = False,
|
||||||
|
max_length: int = 512,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Compute embeddings using SentenceTransformer with model caching and adaptive optimization
|
Compute embeddings using SentenceTransformer with model caching and adaptive optimization
|
||||||
@@ -210,20 +236,130 @@ def compute_embeddings_sentence_transformers(
|
|||||||
logger.info(f"Model cached: {cache_key}")
|
logger.info(f"Model cached: {cache_key}")
|
||||||
|
|
||||||
# Compute embeddings with optimized inference mode
|
# Compute embeddings with optimized inference mode
|
||||||
logger.info(f"Starting embedding computation... (batch_size: {batch_size})")
|
logger.info(
|
||||||
|
f"Starting embedding computation... (batch_size: {batch_size}, manual_tokenize={manual_tokenize})"
|
||||||
|
)
|
||||||
|
|
||||||
# Use torch.inference_mode for optimal performance
|
start_time = time.time()
|
||||||
with torch.inference_mode():
|
if not manual_tokenize:
|
||||||
embeddings = model.encode(
|
# Use SentenceTransformer's optimized encode path (default)
|
||||||
texts,
|
with torch.inference_mode():
|
||||||
batch_size=batch_size,
|
embeddings = model.encode(
|
||||||
show_progress_bar=is_build, # Don't show progress bar in server environment
|
texts,
|
||||||
convert_to_numpy=True,
|
batch_size=batch_size,
|
||||||
normalize_embeddings=False,
|
show_progress_bar=is_build, # Don't show progress bar in server environment
|
||||||
device=device,
|
convert_to_numpy=True,
|
||||||
)
|
normalize_embeddings=False,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
# Synchronize if CUDA to measure accurate wall time
|
||||||
|
try:
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# Manual tokenization + forward pass using HF AutoTokenizer/AutoModel
|
||||||
|
try:
|
||||||
|
from transformers import AutoModel, AutoTokenizer # type: ignore
|
||||||
|
except Exception as e:
|
||||||
|
raise ImportError(f"transformers is required for manual_tokenize=True: {e}")
|
||||||
|
|
||||||
|
# Cache tokenizer and model
|
||||||
|
tok_cache_key = f"hf_tokenizer_{model_name}"
|
||||||
|
mdl_cache_key = f"hf_model_{model_name}_{device}_{use_fp16}"
|
||||||
|
if tok_cache_key in _model_cache and mdl_cache_key in _model_cache:
|
||||||
|
hf_tokenizer = _model_cache[tok_cache_key]
|
||||||
|
hf_model = _model_cache[mdl_cache_key]
|
||||||
|
logger.info("Using cached HF tokenizer/model for manual path")
|
||||||
|
else:
|
||||||
|
logger.info("Loading HF tokenizer/model for manual tokenization path")
|
||||||
|
hf_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
||||||
|
torch_dtype = torch.float16 if (use_fp16 and device == "cuda") else torch.float32
|
||||||
|
hf_model = AutoModel.from_pretrained(model_name, torch_dtype=torch_dtype)
|
||||||
|
hf_model.to(device)
|
||||||
|
hf_model.eval()
|
||||||
|
# Optional compile on supported devices
|
||||||
|
if device in ["cuda", "mps"]:
|
||||||
|
try:
|
||||||
|
hf_model = torch.compile(hf_model, mode="reduce-overhead", dynamic=True) # type: ignore
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
_model_cache[tok_cache_key] = hf_tokenizer
|
||||||
|
_model_cache[mdl_cache_key] = hf_model
|
||||||
|
|
||||||
|
all_embeddings: list[np.ndarray] = []
|
||||||
|
# Progress bar when building or for large inputs
|
||||||
|
show_progress = is_build or len(texts) > 32
|
||||||
|
try:
|
||||||
|
if show_progress:
|
||||||
|
from tqdm import tqdm # type: ignore
|
||||||
|
|
||||||
|
batch_iter = tqdm(
|
||||||
|
range(0, len(texts), batch_size),
|
||||||
|
desc="Embedding (manual)",
|
||||||
|
unit="batch",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
batch_iter = range(0, len(texts), batch_size)
|
||||||
|
except Exception:
|
||||||
|
batch_iter = range(0, len(texts), batch_size)
|
||||||
|
|
||||||
|
start_time_manual = time.time()
|
||||||
|
with torch.inference_mode():
|
||||||
|
for start_index in batch_iter:
|
||||||
|
end_index = min(start_index + batch_size, len(texts))
|
||||||
|
batch_texts = texts[start_index:end_index]
|
||||||
|
tokenize_start_time = time.time()
|
||||||
|
inputs = hf_tokenizer(
|
||||||
|
batch_texts,
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
max_length=max_length,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
tokenize_end_time = time.time()
|
||||||
|
logger.info(
|
||||||
|
f"Tokenize time taken: {tokenize_end_time - tokenize_start_time} seconds"
|
||||||
|
)
|
||||||
|
# Print shapes of all input tensors for debugging
|
||||||
|
for k, v in inputs.items():
|
||||||
|
print(f"inputs[{k!r}] shape: {getattr(v, 'shape', type(v))}")
|
||||||
|
to_device_start_time = time.time()
|
||||||
|
inputs = {k: v.to(device) for k, v in inputs.items()}
|
||||||
|
to_device_end_time = time.time()
|
||||||
|
logger.info(
|
||||||
|
f"To device time taken: {to_device_end_time - to_device_start_time} seconds"
|
||||||
|
)
|
||||||
|
forward_start_time = time.time()
|
||||||
|
outputs = hf_model(**inputs)
|
||||||
|
forward_end_time = time.time()
|
||||||
|
logger.info(f"Forward time taken: {forward_end_time - forward_start_time} seconds")
|
||||||
|
last_hidden_state = outputs.last_hidden_state # (B, L, H)
|
||||||
|
attention_mask = inputs.get("attention_mask")
|
||||||
|
if attention_mask is None:
|
||||||
|
# Fallback: assume all tokens are valid
|
||||||
|
pooled = last_hidden_state.mean(dim=1)
|
||||||
|
else:
|
||||||
|
mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
|
||||||
|
masked = last_hidden_state * mask
|
||||||
|
lengths = mask.sum(dim=1).clamp(min=1)
|
||||||
|
pooled = masked.sum(dim=1) / lengths
|
||||||
|
# Move to CPU float32
|
||||||
|
batch_embeddings = pooled.detach().to("cpu").float().numpy()
|
||||||
|
all_embeddings.append(batch_embeddings)
|
||||||
|
|
||||||
|
embeddings = np.vstack(all_embeddings).astype(np.float32, copy=False)
|
||||||
|
try:
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
end_time = time.time()
|
||||||
|
logger.info(f"Manual tokenize time taken: {end_time - start_time_manual} seconds")
|
||||||
|
end_time = time.time()
|
||||||
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
|
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
|
||||||
|
logger.info(f"Time taken: {end_time - start_time} seconds")
|
||||||
|
|
||||||
# Validate results
|
# Validate results
|
||||||
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||||
@@ -232,26 +368,41 @@ def compute_embeddings_sentence_transformers(
|
|||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
|
def compute_embeddings_openai(
|
||||||
|
texts: list[str],
|
||||||
|
model_name: str,
|
||||||
|
base_url: Optional[str] = None,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
) -> np.ndarray:
|
||||||
# TODO: @yichuan-w add progress bar only in build mode
|
# TODO: @yichuan-w add progress bar only in build mode
|
||||||
"""Compute embeddings using OpenAI API"""
|
"""Compute embeddings using OpenAI API"""
|
||||||
try:
|
try:
|
||||||
import os
|
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise ImportError(f"OpenAI package not installed: {e}")
|
raise ImportError(f"OpenAI package not installed: {e}")
|
||||||
|
|
||||||
api_key = os.getenv("OPENAI_API_KEY")
|
# Validate input list
|
||||||
if not api_key:
|
if not texts:
|
||||||
|
raise ValueError("Cannot compute embeddings for empty text list")
|
||||||
|
# Extra validation: abort early if any item is empty/whitespace
|
||||||
|
invalid_count = sum(1 for t in texts if not isinstance(t, str) or not t.strip())
|
||||||
|
if invalid_count > 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Found {invalid_count} empty/invalid text(s) in input. Upstream should filter before calling OpenAI."
|
||||||
|
)
|
||||||
|
|
||||||
|
resolved_base_url = resolve_openai_base_url(base_url)
|
||||||
|
resolved_api_key = resolve_openai_api_key(api_key)
|
||||||
|
|
||||||
|
if not resolved_api_key:
|
||||||
raise RuntimeError("OPENAI_API_KEY environment variable not set")
|
raise RuntimeError("OPENAI_API_KEY environment variable not set")
|
||||||
|
|
||||||
# Cache OpenAI client
|
# Cache OpenAI client
|
||||||
cache_key = "openai_client"
|
cache_key = f"openai_client::{resolved_base_url}"
|
||||||
if cache_key in _model_cache:
|
if cache_key in _model_cache:
|
||||||
client = _model_cache[cache_key]
|
client = _model_cache[cache_key]
|
||||||
else:
|
else:
|
||||||
client = openai.OpenAI(api_key=api_key)
|
client = openai.OpenAI(api_key=resolved_api_key, base_url=resolved_base_url)
|
||||||
_model_cache[cache_key] = client
|
_model_cache[cache_key] = client
|
||||||
logger.info("OpenAI client cached")
|
logger.info("OpenAI client cached")
|
||||||
|
|
||||||
@@ -261,8 +412,16 @@ def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
|
|||||||
print(f"len of texts: {len(texts)}")
|
print(f"len of texts: {len(texts)}")
|
||||||
|
|
||||||
# OpenAI has limits on batch size and input length
|
# OpenAI has limits on batch size and input length
|
||||||
max_batch_size = 1000 # Conservative batch size
|
max_batch_size = 800 # Conservative batch size because the token limit is 300K
|
||||||
all_embeddings = []
|
all_embeddings = []
|
||||||
|
# get the avg len of texts
|
||||||
|
avg_len = sum(len(text) for text in texts) / len(texts)
|
||||||
|
print(f"avg len of texts: {avg_len}")
|
||||||
|
# if avg len is less than 1000, use the max batch size
|
||||||
|
if avg_len > 300:
|
||||||
|
max_batch_size = 500
|
||||||
|
|
||||||
|
# if avg len is less than 1000, use the max batch size
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@@ -365,3 +524,373 @@ def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int =
|
|||||||
|
|
||||||
# Stack numpy arrays
|
# Stack numpy arrays
|
||||||
return np.stack(all_embeddings)
|
return np.stack(all_embeddings)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_embeddings_ollama(
|
||||||
|
texts: list[str],
|
||||||
|
model_name: str,
|
||||||
|
is_build: bool = False,
|
||||||
|
host: Optional[str] = None,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Compute embeddings using Ollama API with simplified batch processing.
|
||||||
|
|
||||||
|
Uses batch size of 32 for MPS/CPU and 128 for CUDA to optimize performance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: List of texts to compute embeddings for
|
||||||
|
model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large")
|
||||||
|
is_build: Whether this is a build operation (shows progress bar)
|
||||||
|
host: Ollama host URL (defaults to environment or http://localhost:11434)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import requests
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"The 'requests' library is required for Ollama embeddings. Install with: uv pip install requests"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not texts:
|
||||||
|
raise ValueError("Cannot compute embeddings for empty text list")
|
||||||
|
|
||||||
|
resolved_host = resolve_ollama_host(host)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}', host: '{resolved_host}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if Ollama is running
|
||||||
|
try:
|
||||||
|
response = requests.get(f"{resolved_host}/api/version", timeout=5)
|
||||||
|
response.raise_for_status()
|
||||||
|
except requests.exceptions.ConnectionError:
|
||||||
|
error_msg = (
|
||||||
|
f"❌ Could not connect to Ollama at {resolved_host}.\n\n"
|
||||||
|
"Please ensure Ollama is running:\n"
|
||||||
|
" • macOS/Linux: ollama serve\n"
|
||||||
|
" • Windows: Make sure Ollama is running in the system tray\n\n"
|
||||||
|
"Installation: https://ollama.com/download"
|
||||||
|
)
|
||||||
|
raise RuntimeError(error_msg)
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Unexpected error connecting to Ollama: {e}")
|
||||||
|
|
||||||
|
# Check if model exists and provide helpful suggestions
|
||||||
|
try:
|
||||||
|
response = requests.get(f"{resolved_host}/api/tags", timeout=5)
|
||||||
|
response.raise_for_status()
|
||||||
|
models = response.json()
|
||||||
|
model_names = [model["name"] for model in models.get("models", [])]
|
||||||
|
|
||||||
|
# Filter for embedding models (models that support embeddings)
|
||||||
|
embedding_models = []
|
||||||
|
suggested_embedding_models = [
|
||||||
|
"nomic-embed-text",
|
||||||
|
"mxbai-embed-large",
|
||||||
|
"bge-m3",
|
||||||
|
"all-minilm",
|
||||||
|
"snowflake-arctic-embed",
|
||||||
|
]
|
||||||
|
|
||||||
|
for model in model_names:
|
||||||
|
# Check if it's an embedding model (by name patterns or known models)
|
||||||
|
base_name = model.split(":")[0]
|
||||||
|
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5"]):
|
||||||
|
embedding_models.append(model)
|
||||||
|
|
||||||
|
# Check if model exists (handle versioned names) and resolve to full name
|
||||||
|
resolved_model_name = None
|
||||||
|
for name in model_names:
|
||||||
|
# Exact match
|
||||||
|
if model_name == name:
|
||||||
|
resolved_model_name = name
|
||||||
|
break
|
||||||
|
# Match without version tag (use the versioned name)
|
||||||
|
elif model_name == name.split(":")[0]:
|
||||||
|
resolved_model_name = name
|
||||||
|
break
|
||||||
|
|
||||||
|
if not resolved_model_name:
|
||||||
|
error_msg = f"❌ Model '{model_name}' not found in local Ollama.\n\n"
|
||||||
|
|
||||||
|
# Suggest pulling the model
|
||||||
|
error_msg += "📦 To install this embedding model:\n"
|
||||||
|
error_msg += f" ollama pull {model_name}\n\n"
|
||||||
|
|
||||||
|
# Show available embedding models
|
||||||
|
if embedding_models:
|
||||||
|
error_msg += "✅ Available embedding models:\n"
|
||||||
|
for model in embedding_models[:5]:
|
||||||
|
error_msg += f" • {model}\n"
|
||||||
|
if len(embedding_models) > 5:
|
||||||
|
error_msg += f" ... and {len(embedding_models) - 5} more\n"
|
||||||
|
else:
|
||||||
|
error_msg += "💡 Popular embedding models to install:\n"
|
||||||
|
for model in suggested_embedding_models[:3]:
|
||||||
|
error_msg += f" • ollama pull {model}\n"
|
||||||
|
|
||||||
|
error_msg += "\n📚 Browse more: https://ollama.com/library"
|
||||||
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
|
# Use the resolved model name for all subsequent operations
|
||||||
|
if resolved_model_name != model_name:
|
||||||
|
logger.info(f"Resolved model name '{model_name}' to '{resolved_model_name}'")
|
||||||
|
model_name = resolved_model_name
|
||||||
|
|
||||||
|
# Verify the model supports embeddings by testing it
|
||||||
|
try:
|
||||||
|
test_response = requests.post(
|
||||||
|
f"{resolved_host}/api/embeddings",
|
||||||
|
json={"model": model_name, "prompt": "test"},
|
||||||
|
timeout=10,
|
||||||
|
)
|
||||||
|
if test_response.status_code != 200:
|
||||||
|
error_msg = (
|
||||||
|
f"⚠️ Model '{model_name}' exists but may not support embeddings.\n\n"
|
||||||
|
f"Please use an embedding model like:\n"
|
||||||
|
)
|
||||||
|
for model in suggested_embedding_models[:3]:
|
||||||
|
error_msg += f" • {model}\n"
|
||||||
|
raise ValueError(error_msg)
|
||||||
|
except requests.exceptions.RequestException:
|
||||||
|
# If test fails, continue anyway - model might still work
|
||||||
|
pass
|
||||||
|
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
logger.warning(f"Could not verify model existence: {e}")
|
||||||
|
|
||||||
|
# Determine batch size based on device availability
|
||||||
|
# Check for CUDA/MPS availability using torch if available
|
||||||
|
batch_size = 32 # Default for MPS/CPU
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
batch_size = 128 # CUDA gets larger batch size
|
||||||
|
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||||
|
batch_size = 32 # MPS gets smaller batch size
|
||||||
|
except ImportError:
|
||||||
|
# If torch is not available, use conservative batch size
|
||||||
|
batch_size = 32
|
||||||
|
|
||||||
|
logger.info(f"Using batch size: {batch_size}")
|
||||||
|
|
||||||
|
def get_batch_embeddings(batch_texts):
|
||||||
|
"""Get embeddings for a batch of texts."""
|
||||||
|
all_embeddings = []
|
||||||
|
failed_indices = []
|
||||||
|
|
||||||
|
for i, text in enumerate(batch_texts):
|
||||||
|
max_retries = 3
|
||||||
|
retry_count = 0
|
||||||
|
|
||||||
|
# Truncate very long texts to avoid API issues
|
||||||
|
truncated_text = text[:8000] if len(text) > 8000 else text
|
||||||
|
while retry_count < max_retries:
|
||||||
|
try:
|
||||||
|
response = requests.post(
|
||||||
|
f"{resolved_host}/api/embeddings",
|
||||||
|
json={"model": model_name, "prompt": truncated_text},
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
result = response.json()
|
||||||
|
embedding = result.get("embedding")
|
||||||
|
|
||||||
|
if embedding is None:
|
||||||
|
raise ValueError(f"No embedding returned for text {i}")
|
||||||
|
|
||||||
|
if not isinstance(embedding, list) or len(embedding) == 0:
|
||||||
|
raise ValueError(f"Invalid embedding format for text {i}")
|
||||||
|
|
||||||
|
all_embeddings.append(embedding)
|
||||||
|
break
|
||||||
|
|
||||||
|
except requests.exceptions.Timeout:
|
||||||
|
retry_count += 1
|
||||||
|
if retry_count >= max_retries:
|
||||||
|
logger.warning(f"Timeout for text {i} after {max_retries} retries")
|
||||||
|
failed_indices.append(i)
|
||||||
|
all_embeddings.append(None)
|
||||||
|
break
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
retry_count += 1
|
||||||
|
if retry_count >= max_retries:
|
||||||
|
logger.error(f"Failed to get embedding for text {i}: {e}")
|
||||||
|
failed_indices.append(i)
|
||||||
|
all_embeddings.append(None)
|
||||||
|
break
|
||||||
|
return all_embeddings, failed_indices
|
||||||
|
|
||||||
|
# Process texts in batches
|
||||||
|
all_embeddings = []
|
||||||
|
all_failed_indices = []
|
||||||
|
|
||||||
|
# Setup progress bar if needed
|
||||||
|
show_progress = is_build or len(texts) > 10
|
||||||
|
try:
|
||||||
|
if show_progress:
|
||||||
|
from tqdm import tqdm
|
||||||
|
except ImportError:
|
||||||
|
show_progress = False
|
||||||
|
|
||||||
|
# Process batches
|
||||||
|
num_batches = (len(texts) + batch_size - 1) // batch_size
|
||||||
|
|
||||||
|
if show_progress:
|
||||||
|
batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings")
|
||||||
|
else:
|
||||||
|
batch_iterator = range(num_batches)
|
||||||
|
|
||||||
|
for batch_idx in batch_iterator:
|
||||||
|
start_idx = batch_idx * batch_size
|
||||||
|
end_idx = min(start_idx + batch_size, len(texts))
|
||||||
|
batch_texts = texts[start_idx:end_idx]
|
||||||
|
|
||||||
|
batch_embeddings, batch_failed = get_batch_embeddings(batch_texts)
|
||||||
|
|
||||||
|
# Adjust failed indices to global indices
|
||||||
|
global_failed = [start_idx + idx for idx in batch_failed]
|
||||||
|
all_failed_indices.extend(global_failed)
|
||||||
|
all_embeddings.extend(batch_embeddings)
|
||||||
|
|
||||||
|
# Handle failed embeddings
|
||||||
|
if all_failed_indices:
|
||||||
|
if len(all_failed_indices) == len(texts):
|
||||||
|
raise RuntimeError("Failed to compute any embeddings")
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to compute embeddings for {len(all_failed_indices)}/{len(texts)} texts"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use zero embeddings as fallback for failed ones
|
||||||
|
valid_embedding = next((e for e in all_embeddings if e is not None), None)
|
||||||
|
if valid_embedding:
|
||||||
|
embedding_dim = len(valid_embedding)
|
||||||
|
for i, embedding in enumerate(all_embeddings):
|
||||||
|
if embedding is None:
|
||||||
|
all_embeddings[i] = [0.0] * embedding_dim
|
||||||
|
|
||||||
|
# Remove None values
|
||||||
|
all_embeddings = [e for e in all_embeddings if e is not None]
|
||||||
|
|
||||||
|
if not all_embeddings:
|
||||||
|
raise RuntimeError("No valid embeddings were computed")
|
||||||
|
|
||||||
|
# Validate embedding dimensions
|
||||||
|
expected_dim = len(all_embeddings[0])
|
||||||
|
inconsistent_dims = []
|
||||||
|
for i, embedding in enumerate(all_embeddings):
|
||||||
|
if len(embedding) != expected_dim:
|
||||||
|
inconsistent_dims.append((i, len(embedding)))
|
||||||
|
|
||||||
|
if inconsistent_dims:
|
||||||
|
error_msg = f"Ollama returned inconsistent embedding dimensions. Expected {expected_dim}, but got:\n"
|
||||||
|
for idx, dim in inconsistent_dims[:10]: # Show first 10 inconsistent ones
|
||||||
|
error_msg += f" - Text {idx}: {dim} dimensions\n"
|
||||||
|
if len(inconsistent_dims) > 10:
|
||||||
|
error_msg += f" ... and {len(inconsistent_dims) - 10} more\n"
|
||||||
|
error_msg += f"\nThis is likely an Ollama API bug with model '{model_name}'. Please try:\n"
|
||||||
|
error_msg += "1. Restart Ollama service: 'ollama serve'\n"
|
||||||
|
error_msg += f"2. Re-pull the model: 'ollama pull {model_name}'\n"
|
||||||
|
error_msg += (
|
||||||
|
"3. Use sentence-transformers instead: --embedding-mode sentence-transformers\n"
|
||||||
|
)
|
||||||
|
error_msg += "4. Report this issue to Ollama: https://github.com/ollama/ollama/issues"
|
||||||
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
|
# Convert to numpy array and normalize
|
||||||
|
embeddings = np.array(all_embeddings, dtype=np.float32)
|
||||||
|
|
||||||
|
# Normalize embeddings (L2 normalization)
|
||||||
|
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
|
||||||
|
embeddings = embeddings / (norms + 1e-8) # Add small epsilon to avoid division by zero
|
||||||
|
|
||||||
|
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
|
||||||
|
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
|
def compute_embeddings_gemini(
|
||||||
|
texts: list[str], model_name: str = "text-embedding-004", is_build: bool = False
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Compute embeddings using Google Gemini API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: List of texts to compute embeddings for
|
||||||
|
model_name: Gemini model name (default: "text-embedding-004")
|
||||||
|
is_build: Whether this is a build operation (shows progress bar)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Embeddings array, shape: (len(texts), embedding_dim)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import os
|
||||||
|
|
||||||
|
import google.genai as genai
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(f"Google GenAI package not installed: {e}")
|
||||||
|
|
||||||
|
api_key = os.getenv("GEMINI_API_KEY")
|
||||||
|
if not api_key:
|
||||||
|
raise RuntimeError("GEMINI_API_KEY environment variable not set")
|
||||||
|
|
||||||
|
# Cache Gemini client
|
||||||
|
cache_key = "gemini_client"
|
||||||
|
if cache_key in _model_cache:
|
||||||
|
client = _model_cache[cache_key]
|
||||||
|
else:
|
||||||
|
client = genai.Client(api_key=api_key)
|
||||||
|
_model_cache[cache_key] = client
|
||||||
|
logger.info("Gemini client cached")
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Computing embeddings for {len(texts)} texts using Gemini API, model: '{model_name}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Gemini supports batch embedding
|
||||||
|
max_batch_size = 100 # Conservative batch size for Gemini
|
||||||
|
all_embeddings = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
total_batches = (len(texts) + max_batch_size - 1) // max_batch_size
|
||||||
|
batch_range = range(0, len(texts), max_batch_size)
|
||||||
|
batch_iterator = tqdm(
|
||||||
|
batch_range, desc="Computing embeddings", unit="batch", total=total_batches
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
# Fallback when tqdm is not available
|
||||||
|
batch_iterator = range(0, len(texts), max_batch_size)
|
||||||
|
|
||||||
|
for i in batch_iterator:
|
||||||
|
batch_texts = texts[i : i + max_batch_size]
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Use the embed_content method from the new Google GenAI SDK
|
||||||
|
response = client.models.embed_content(
|
||||||
|
model=model_name,
|
||||||
|
contents=batch_texts,
|
||||||
|
config=genai.types.EmbedContentConfig(
|
||||||
|
task_type="RETRIEVAL_DOCUMENT" # For document embedding
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract embeddings from response
|
||||||
|
for embedding_data in response.embeddings:
|
||||||
|
all_embeddings.append(embedding_data.values)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Batch {i} failed: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
embeddings = np.array(all_embeddings, dtype=np.float32)
|
||||||
|
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
|
||||||
|
|
||||||
|
return embeddings
|
||||||
|
|||||||
@@ -6,8 +6,11 @@ import subprocess
|
|||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import psutil
|
from .settings import encode_provider_options
|
||||||
|
|
||||||
|
# Lightweight, self-contained server manager with no cross-process inspection
|
||||||
|
|
||||||
# Set up logging based on environment variable
|
# Set up logging based on environment variable
|
||||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||||
@@ -42,130 +45,7 @@ def _check_port(port: int) -> bool:
|
|||||||
return s.connect_ex(("localhost", port)) == 0
|
return s.connect_ex(("localhost", port)) == 0
|
||||||
|
|
||||||
|
|
||||||
def _check_process_matches_config(
|
# Note: All cross-process scanning helpers removed for simplicity
|
||||||
port: int, expected_model: str, expected_passages_file: str
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
Check if the process using the port matches our expected model and passages file.
|
|
||||||
Returns True if matches, False otherwise.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
for proc in psutil.process_iter(["pid", "cmdline"]):
|
|
||||||
if not _is_process_listening_on_port(proc, port):
|
|
||||||
continue
|
|
||||||
|
|
||||||
cmdline = proc.info["cmdline"]
|
|
||||||
if not cmdline:
|
|
||||||
continue
|
|
||||||
|
|
||||||
return _check_cmdline_matches_config(
|
|
||||||
cmdline, port, expected_model, expected_passages_file
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(f"No process found listening on port {port}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Could not check process on port {port}: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _is_process_listening_on_port(proc, port: int) -> bool:
|
|
||||||
"""Check if a process is listening on the given port."""
|
|
||||||
try:
|
|
||||||
connections = proc.net_connections()
|
|
||||||
for conn in connections:
|
|
||||||
if conn.laddr.port == port and conn.status == psutil.CONN_LISTEN:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _check_cmdline_matches_config(
|
|
||||||
cmdline: list, port: int, expected_model: str, expected_passages_file: str
|
|
||||||
) -> bool:
|
|
||||||
"""Check if command line matches our expected configuration."""
|
|
||||||
cmdline_str = " ".join(cmdline)
|
|
||||||
logger.debug(f"Found process on port {port}: {cmdline_str}")
|
|
||||||
|
|
||||||
# Check if it's our embedding server
|
|
||||||
is_embedding_server = any(
|
|
||||||
server_type in cmdline_str
|
|
||||||
for server_type in [
|
|
||||||
"embedding_server",
|
|
||||||
"leann_backend_diskann.embedding_server",
|
|
||||||
"leann_backend_hnsw.hnsw_embedding_server",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
if not is_embedding_server:
|
|
||||||
logger.debug(f"Process on port {port} is not our embedding server")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Check model name
|
|
||||||
model_matches = _check_model_in_cmdline(cmdline, expected_model)
|
|
||||||
|
|
||||||
# Check passages file if provided
|
|
||||||
passages_matches = _check_passages_in_cmdline(cmdline, expected_passages_file)
|
|
||||||
|
|
||||||
result = model_matches and passages_matches
|
|
||||||
logger.debug(
|
|
||||||
f"model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}"
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def _check_model_in_cmdline(cmdline: list, expected_model: str) -> bool:
|
|
||||||
"""Check if the command line contains the expected model."""
|
|
||||||
if "--model-name" not in cmdline:
|
|
||||||
return False
|
|
||||||
|
|
||||||
model_idx = cmdline.index("--model-name")
|
|
||||||
if model_idx + 1 >= len(cmdline):
|
|
||||||
return False
|
|
||||||
|
|
||||||
actual_model = cmdline[model_idx + 1]
|
|
||||||
return actual_model == expected_model
|
|
||||||
|
|
||||||
|
|
||||||
def _check_passages_in_cmdline(cmdline: list, expected_passages_file: str) -> bool:
|
|
||||||
"""Check if the command line contains the expected passages file."""
|
|
||||||
if "--passages-file" not in cmdline:
|
|
||||||
return False # Expected but not found
|
|
||||||
|
|
||||||
passages_idx = cmdline.index("--passages-file")
|
|
||||||
if passages_idx + 1 >= len(cmdline):
|
|
||||||
return False
|
|
||||||
|
|
||||||
actual_passages = cmdline[passages_idx + 1]
|
|
||||||
expected_path = Path(expected_passages_file).resolve()
|
|
||||||
actual_path = Path(actual_passages).resolve()
|
|
||||||
return actual_path == expected_path
|
|
||||||
|
|
||||||
|
|
||||||
def _find_compatible_port_or_next_available(
|
|
||||||
start_port: int, model_name: str, passages_file: str, max_attempts: int = 100
|
|
||||||
) -> tuple[int, bool]:
|
|
||||||
"""
|
|
||||||
Find a port that either has a compatible server or is available.
|
|
||||||
Returns (port, is_compatible) where is_compatible indicates if we found a matching server.
|
|
||||||
"""
|
|
||||||
for port in range(start_port, start_port + max_attempts):
|
|
||||||
if not _check_port(port):
|
|
||||||
# Port is available
|
|
||||||
return port, False
|
|
||||||
|
|
||||||
# Port is in use, check if it's compatible
|
|
||||||
if _check_process_matches_config(port, model_name, passages_file):
|
|
||||||
logger.info(f"Found compatible server on port {port}")
|
|
||||||
return port, True
|
|
||||||
else:
|
|
||||||
logger.info(f"Port {port} has incompatible server, trying next port...")
|
|
||||||
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Could not find compatible or available port in range {start_port}-{start_port + max_attempts}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingServerManager:
|
class EmbeddingServerManager:
|
||||||
@@ -182,9 +62,18 @@ class EmbeddingServerManager:
|
|||||||
e.g., "leann_backend_diskann.embedding_server"
|
e.g., "leann_backend_diskann.embedding_server"
|
||||||
"""
|
"""
|
||||||
self.backend_module_name = backend_module_name
|
self.backend_module_name = backend_module_name
|
||||||
self.server_process: subprocess.Popen | None = None
|
self.server_process: Optional[subprocess.Popen] = None
|
||||||
self.server_port: int | None = None
|
self.server_port: Optional[int] = None
|
||||||
|
# Track last-started config for in-process reuse only
|
||||||
|
self._server_config: Optional[dict] = None
|
||||||
self._atexit_registered = False
|
self._atexit_registered = False
|
||||||
|
# Also register a weakref finalizer to ensure cleanup when manager is GC'ed
|
||||||
|
try:
|
||||||
|
import weakref
|
||||||
|
|
||||||
|
self._finalizer = weakref.finalize(self, self._finalize_process)
|
||||||
|
except Exception:
|
||||||
|
self._finalizer = None
|
||||||
|
|
||||||
def start_server(
|
def start_server(
|
||||||
self,
|
self,
|
||||||
@@ -194,35 +83,65 @@ class EmbeddingServerManager:
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
) -> tuple[bool, int]:
|
) -> tuple[bool, int]:
|
||||||
"""Start the embedding server."""
|
"""Start the embedding server."""
|
||||||
passages_file = kwargs.get("passages_file")
|
# passages_file may be present in kwargs for server CLI, but we don't need it here
|
||||||
|
provider_options = kwargs.pop("provider_options", None)
|
||||||
|
|
||||||
# Check if we have a compatible server already running
|
config_signature = {
|
||||||
if self._has_compatible_running_server(model_name, passages_file):
|
"model_name": model_name,
|
||||||
logger.info("Found compatible running server!")
|
"passages_file": kwargs.get("passages_file", ""),
|
||||||
return True, port
|
"embedding_mode": embedding_mode,
|
||||||
|
"provider_options": provider_options or {},
|
||||||
|
}
|
||||||
|
|
||||||
|
# If this manager already has a live server, just reuse it
|
||||||
|
if (
|
||||||
|
self.server_process
|
||||||
|
and self.server_process.poll() is None
|
||||||
|
and self.server_port
|
||||||
|
and self._server_config == config_signature
|
||||||
|
):
|
||||||
|
logger.info("Reusing in-process server")
|
||||||
|
return True, self.server_port
|
||||||
|
|
||||||
|
# Configuration changed, stop existing server before starting a new one
|
||||||
|
if self.server_process and self.server_process.poll() is None:
|
||||||
|
logger.info("Existing server configuration differs; restarting embedding server")
|
||||||
|
self.stop_server()
|
||||||
|
|
||||||
# For Colab environment, use a different strategy
|
# For Colab environment, use a different strategy
|
||||||
if _is_colab_environment():
|
if _is_colab_environment():
|
||||||
logger.info("Detected Colab environment, using alternative startup strategy")
|
logger.info("Detected Colab environment, using alternative startup strategy")
|
||||||
return self._start_server_colab(port, model_name, embedding_mode, **kwargs)
|
return self._start_server_colab(
|
||||||
|
port,
|
||||||
|
model_name,
|
||||||
|
embedding_mode,
|
||||||
|
provider_options=provider_options,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
# Find a compatible port or next available
|
# Always pick a fresh available port
|
||||||
actual_port, is_compatible = _find_compatible_port_or_next_available(
|
try:
|
||||||
port, model_name, passages_file
|
actual_port = _get_available_port(port)
|
||||||
)
|
except RuntimeError:
|
||||||
|
logger.error("No available ports found")
|
||||||
if is_compatible:
|
return False, port
|
||||||
logger.info(f"Found compatible server on port {actual_port}")
|
|
||||||
return True, actual_port
|
|
||||||
|
|
||||||
# Start a new server
|
# Start a new server
|
||||||
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
|
return self._start_new_server(
|
||||||
|
actual_port,
|
||||||
|
model_name,
|
||||||
|
embedding_mode,
|
||||||
|
provider_options=provider_options,
|
||||||
|
config_signature=config_signature,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
def _start_server_colab(
|
def _start_server_colab(
|
||||||
self,
|
self,
|
||||||
port: int,
|
port: int,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
embedding_mode: str = "sentence-transformers",
|
embedding_mode: str = "sentence-transformers",
|
||||||
|
provider_options: Optional[dict] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> tuple[bool, int]:
|
) -> tuple[bool, int]:
|
||||||
"""Start server with Colab-specific configuration."""
|
"""Start server with Colab-specific configuration."""
|
||||||
@@ -240,26 +159,34 @@ class EmbeddingServerManager:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# In Colab, we'll use a more direct approach
|
# In Colab, we'll use a more direct approach
|
||||||
self._launch_server_process_colab(command, actual_port)
|
self._launch_server_process_colab(
|
||||||
return self._wait_for_server_ready_colab(actual_port)
|
command,
|
||||||
|
actual_port,
|
||||||
|
provider_options=provider_options,
|
||||||
|
)
|
||||||
|
started, ready_port = self._wait_for_server_ready_colab(actual_port)
|
||||||
|
if started:
|
||||||
|
self._server_config = {
|
||||||
|
"model_name": model_name,
|
||||||
|
"passages_file": kwargs.get("passages_file", ""),
|
||||||
|
"embedding_mode": embedding_mode,
|
||||||
|
"provider_options": provider_options or {},
|
||||||
|
}
|
||||||
|
return started, ready_port
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to start embedding server in Colab: {e}")
|
logger.error(f"Failed to start embedding server in Colab: {e}")
|
||||||
return False, actual_port
|
return False, actual_port
|
||||||
|
|
||||||
def _has_compatible_running_server(self, model_name: str, passages_file: str) -> bool:
|
# Note: No compatibility check needed; manager is per-searcher and configs are stable per instance
|
||||||
"""Check if we have a compatible running server."""
|
|
||||||
if not (self.server_process and self.server_process.poll() is None and self.server_port):
|
|
||||||
return False
|
|
||||||
|
|
||||||
if _check_process_matches_config(self.server_port, model_name, passages_file):
|
|
||||||
logger.info(f"Existing server process (PID {self.server_process.pid}) is compatible")
|
|
||||||
return True
|
|
||||||
|
|
||||||
logger.info("Existing server process is incompatible. Should start a new server.")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _start_new_server(
|
def _start_new_server(
|
||||||
self, port: int, model_name: str, embedding_mode: str, **kwargs
|
self,
|
||||||
|
port: int,
|
||||||
|
model_name: str,
|
||||||
|
embedding_mode: str,
|
||||||
|
provider_options: Optional[dict] = None,
|
||||||
|
config_signature: Optional[dict] = None,
|
||||||
|
**kwargs,
|
||||||
) -> tuple[bool, int]:
|
) -> tuple[bool, int]:
|
||||||
"""Start a new embedding server on the given port."""
|
"""Start a new embedding server on the given port."""
|
||||||
logger.info(f"Starting embedding server on port {port}...")
|
logger.info(f"Starting embedding server on port {port}...")
|
||||||
@@ -267,8 +194,20 @@ class EmbeddingServerManager:
|
|||||||
command = self._build_server_command(port, model_name, embedding_mode, **kwargs)
|
command = self._build_server_command(port, model_name, embedding_mode, **kwargs)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._launch_server_process(command, port)
|
self._launch_server_process(
|
||||||
return self._wait_for_server_ready(port)
|
command,
|
||||||
|
port,
|
||||||
|
provider_options=provider_options,
|
||||||
|
)
|
||||||
|
started, ready_port = self._wait_for_server_ready(port)
|
||||||
|
if started:
|
||||||
|
self._server_config = config_signature or {
|
||||||
|
"model_name": model_name,
|
||||||
|
"passages_file": kwargs.get("passages_file", ""),
|
||||||
|
"embedding_mode": embedding_mode,
|
||||||
|
"provider_options": provider_options or {},
|
||||||
|
}
|
||||||
|
return started, ready_port
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to start embedding server: {e}")
|
logger.error(f"Failed to start embedding server: {e}")
|
||||||
return False, port
|
return False, port
|
||||||
@@ -298,27 +237,80 @@ class EmbeddingServerManager:
|
|||||||
|
|
||||||
return command
|
return command
|
||||||
|
|
||||||
def _launch_server_process(self, command: list, port: int) -> None:
|
def _launch_server_process(
|
||||||
|
self,
|
||||||
|
command: list,
|
||||||
|
port: int,
|
||||||
|
provider_options: Optional[dict] = None,
|
||||||
|
) -> None:
|
||||||
"""Launch the server process."""
|
"""Launch the server process."""
|
||||||
project_root = Path(__file__).parent.parent.parent.parent.parent
|
project_root = Path(__file__).parent.parent.parent.parent.parent
|
||||||
logger.info(f"Command: {' '.join(command)}")
|
logger.info(f"Command: {' '.join(command)}")
|
||||||
|
|
||||||
# Let server output go directly to console
|
# In CI environment, redirect stdout to avoid buffer deadlock but keep stderr for debugging
|
||||||
# The server will respect LEANN_LOG_LEVEL environment variable
|
# Embedding servers use many print statements that can fill stdout buffers
|
||||||
|
is_ci = os.environ.get("CI") == "true"
|
||||||
|
if is_ci:
|
||||||
|
stdout_target = subprocess.DEVNULL
|
||||||
|
stderr_target = None # Keep stderr for error debugging in CI
|
||||||
|
logger.info(
|
||||||
|
"CI environment detected, redirecting embedding server stdout to DEVNULL, keeping stderr"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
stdout_target = None # Direct to console for visible logs
|
||||||
|
stderr_target = None # Direct to console for visible logs
|
||||||
|
|
||||||
|
# Start embedding server subprocess
|
||||||
|
logger.info(f"Starting server process with command: {' '.join(command)}")
|
||||||
|
env = os.environ.copy()
|
||||||
|
encoded_options = encode_provider_options(provider_options)
|
||||||
|
if encoded_options:
|
||||||
|
env["LEANN_EMBEDDING_OPTIONS"] = encoded_options
|
||||||
|
|
||||||
self.server_process = subprocess.Popen(
|
self.server_process = subprocess.Popen(
|
||||||
command,
|
command,
|
||||||
cwd=project_root,
|
cwd=project_root,
|
||||||
stdout=None, # Direct to console
|
stdout=stdout_target,
|
||||||
stderr=None, # Direct to console
|
stderr=stderr_target,
|
||||||
|
env=env,
|
||||||
)
|
)
|
||||||
self.server_port = port
|
self.server_port = port
|
||||||
|
# Record config for in-process reuse (best effort; refined later when ready)
|
||||||
|
try:
|
||||||
|
self._server_config = {
|
||||||
|
"model_name": command[command.index("--model-name") + 1]
|
||||||
|
if "--model-name" in command
|
||||||
|
else "",
|
||||||
|
"passages_file": command[command.index("--passages-file") + 1]
|
||||||
|
if "--passages-file" in command
|
||||||
|
else "",
|
||||||
|
"embedding_mode": command[command.index("--embedding-mode") + 1]
|
||||||
|
if "--embedding-mode" in command
|
||||||
|
else "sentence-transformers",
|
||||||
|
"provider_options": provider_options or {},
|
||||||
|
}
|
||||||
|
except Exception:
|
||||||
|
self._server_config = {
|
||||||
|
"model_name": "",
|
||||||
|
"passages_file": "",
|
||||||
|
"embedding_mode": "sentence-transformers",
|
||||||
|
"provider_options": provider_options or {},
|
||||||
|
}
|
||||||
logger.info(f"Server process started with PID: {self.server_process.pid}")
|
logger.info(f"Server process started with PID: {self.server_process.pid}")
|
||||||
|
|
||||||
# Register atexit callback only when we actually start a process
|
# Register atexit callback only when we actually start a process
|
||||||
if not self._atexit_registered:
|
if not self._atexit_registered:
|
||||||
# Use a lambda to avoid issues with bound methods
|
# Always attempt best-effort finalize at interpreter exit
|
||||||
atexit.register(lambda: self.stop_server() if self.server_process else None)
|
atexit.register(self._finalize_process)
|
||||||
self._atexit_registered = True
|
self._atexit_registered = True
|
||||||
|
# Touch finalizer so it knows there is a live process
|
||||||
|
if getattr(self, "_finalizer", None) is not None and not self._finalizer.alive:
|
||||||
|
try:
|
||||||
|
import weakref
|
||||||
|
|
||||||
|
self._finalizer = weakref.finalize(self, self._finalize_process)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
def _wait_for_server_ready(self, port: int) -> tuple[bool, int]:
|
def _wait_for_server_ready(self, port: int) -> tuple[bool, int]:
|
||||||
"""Wait for the server to be ready."""
|
"""Wait for the server to be ready."""
|
||||||
@@ -343,24 +335,35 @@ class EmbeddingServerManager:
|
|||||||
if not self.server_process:
|
if not self.server_process:
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.server_process.poll() is not None:
|
if self.server_process and self.server_process.poll() is not None:
|
||||||
# Process already terminated
|
# Process already terminated
|
||||||
self.server_process = None
|
self.server_process = None
|
||||||
|
self.server_port = None
|
||||||
|
self._server_config = None
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
|
f"Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
|
||||||
)
|
)
|
||||||
self.server_process.terminate()
|
|
||||||
|
# Use simple termination first; if the server installed signal handlers,
|
||||||
|
# it will exit cleanly. Otherwise escalate to kill after a short wait.
|
||||||
|
try:
|
||||||
|
self.server_process.terminate()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.server_process.wait(timeout=3)
|
self.server_process.wait(timeout=5) # Give more time for graceful shutdown
|
||||||
logger.info(f"Server process {self.server_process.pid} terminated.")
|
logger.info(f"Server process {self.server_process.pid} terminated gracefully.")
|
||||||
except subprocess.TimeoutExpired:
|
except subprocess.TimeoutExpired:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Server process {self.server_process.pid} did not terminate gracefully within 3 seconds, killing it."
|
f"Server process {self.server_process.pid} did not terminate within 5 seconds, force killing..."
|
||||||
)
|
)
|
||||||
self.server_process.kill()
|
try:
|
||||||
|
self.server_process.kill()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
try:
|
try:
|
||||||
self.server_process.wait(timeout=2)
|
self.server_process.wait(timeout=2)
|
||||||
logger.info(f"Server process {self.server_process.pid} killed successfully.")
|
logger.info(f"Server process {self.server_process.pid} killed successfully.")
|
||||||
@@ -368,34 +371,70 @@ class EmbeddingServerManager:
|
|||||||
logger.error(
|
logger.error(
|
||||||
f"Failed to kill server process {self.server_process.pid} - it may be hung"
|
f"Failed to kill server process {self.server_process.pid} - it may be hung"
|
||||||
)
|
)
|
||||||
# Don't hang indefinitely
|
|
||||||
|
|
||||||
# Clean up process resources to prevent resource tracker warnings
|
# Clean up process resources with timeout to avoid CI hang
|
||||||
try:
|
try:
|
||||||
self.server_process.wait() # Ensure process is fully cleaned up
|
# Use shorter timeout in CI environments
|
||||||
|
is_ci = os.environ.get("CI") == "true"
|
||||||
|
timeout = 3 if is_ci else 10
|
||||||
|
self.server_process.wait(timeout=timeout)
|
||||||
|
logger.info(f"Server process {self.server_process.pid} cleanup completed")
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
logger.warning(f"Process cleanup timeout after {timeout}s, proceeding anyway")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error during process cleanup: {e}")
|
||||||
|
finally:
|
||||||
|
self.server_process = None
|
||||||
|
self.server_port = None
|
||||||
|
self._server_config = None
|
||||||
|
|
||||||
|
def _finalize_process(self) -> None:
|
||||||
|
"""Best-effort cleanup used by weakref.finalize/atexit."""
|
||||||
|
try:
|
||||||
|
self.stop_server()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
self.server_process = None
|
def _adopt_existing_server(self, *args, **kwargs) -> None:
|
||||||
|
# Removed: cross-process adoption no longer supported
|
||||||
|
return
|
||||||
|
|
||||||
def _launch_server_process_colab(self, command: list, port: int) -> None:
|
def _launch_server_process_colab(
|
||||||
|
self,
|
||||||
|
command: list,
|
||||||
|
port: int,
|
||||||
|
provider_options: Optional[dict] = None,
|
||||||
|
) -> None:
|
||||||
"""Launch the server process with Colab-specific settings."""
|
"""Launch the server process with Colab-specific settings."""
|
||||||
logger.info(f"Colab Command: {' '.join(command)}")
|
logger.info(f"Colab Command: {' '.join(command)}")
|
||||||
|
|
||||||
# In Colab, we need to be more careful about process management
|
# In Colab, we need to be more careful about process management
|
||||||
|
env = os.environ.copy()
|
||||||
|
encoded_options = encode_provider_options(provider_options)
|
||||||
|
if encoded_options:
|
||||||
|
env["LEANN_EMBEDDING_OPTIONS"] = encoded_options
|
||||||
|
|
||||||
self.server_process = subprocess.Popen(
|
self.server_process = subprocess.Popen(
|
||||||
command,
|
command,
|
||||||
stdout=subprocess.PIPE,
|
stdout=subprocess.PIPE,
|
||||||
stderr=subprocess.PIPE,
|
stderr=subprocess.PIPE,
|
||||||
text=True,
|
text=True,
|
||||||
|
env=env,
|
||||||
)
|
)
|
||||||
self.server_port = port
|
self.server_port = port
|
||||||
logger.info(f"Colab server process started with PID: {self.server_process.pid}")
|
logger.info(f"Colab server process started with PID: {self.server_process.pid}")
|
||||||
|
|
||||||
# Register atexit callback
|
# Register atexit callback (unified)
|
||||||
if not self._atexit_registered:
|
if not self._atexit_registered:
|
||||||
atexit.register(lambda: self.stop_server() if self.server_process else None)
|
atexit.register(self._finalize_process)
|
||||||
self._atexit_registered = True
|
self._atexit_registered = True
|
||||||
|
# Record config for in-process reuse is best-effort in Colab mode
|
||||||
|
self._server_config = {
|
||||||
|
"model_name": "",
|
||||||
|
"passages_file": "",
|
||||||
|
"embedding_mode": "sentence-transformers",
|
||||||
|
"provider_options": provider_options or {},
|
||||||
|
}
|
||||||
|
|
||||||
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
|
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
|
||||||
"""Wait for the server to be ready with Colab-specific timeout."""
|
"""Wait for the server to be ready with Colab-specific timeout."""
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -34,7 +34,9 @@ class LeannBackendSearcherInterface(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _ensure_server_running(self, passages_source_file: str, port: int | None, **kwargs) -> int:
|
def _ensure_server_running(
|
||||||
|
self, passages_source_file: str, port: Optional[int], **kwargs
|
||||||
|
) -> int:
|
||||||
"""Ensure server is running"""
|
"""Ensure server is running"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -48,7 +50,7 @@ class LeannBackendSearcherInterface(ABC):
|
|||||||
prune_ratio: float = 0.0,
|
prune_ratio: float = 0.0,
|
||||||
recompute_embeddings: bool = False,
|
recompute_embeddings: bool = False,
|
||||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
zmq_port: int | None = None,
|
zmq_port: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Search for nearest neighbors
|
"""Search for nearest neighbors
|
||||||
@@ -74,7 +76,7 @@ class LeannBackendSearcherInterface(ABC):
|
|||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
use_server_if_available: bool = True,
|
use_server_if_available: bool = True,
|
||||||
zmq_port: int | None = None,
|
zmq_port: Optional[int] = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Compute embedding for a query string
|
"""Compute embedding for a query string
|
||||||
|
|
||||||
|
|||||||
154
packages/leann-core/src/leann/mcp.py
Executable file
154
packages/leann-core/src/leann/mcp.py
Executable file
@@ -0,0 +1,154 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import json
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
def handle_request(request):
|
||||||
|
if request.get("method") == "initialize":
|
||||||
|
return {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": request.get("id"),
|
||||||
|
"result": {
|
||||||
|
"capabilities": {"tools": {}},
|
||||||
|
"protocolVersion": "2024-11-05",
|
||||||
|
"serverInfo": {"name": "leann-mcp", "version": "1.0.0"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
elif request.get("method") == "tools/list":
|
||||||
|
return {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": request.get("id"),
|
||||||
|
"result": {
|
||||||
|
"tools": [
|
||||||
|
{
|
||||||
|
"name": "leann_search",
|
||||||
|
"description": """🔍 Search code using natural language - like having a coding assistant who knows your entire codebase!
|
||||||
|
|
||||||
|
🎯 **Perfect for**:
|
||||||
|
- "How does authentication work?" → finds auth-related code
|
||||||
|
- "Error handling patterns" → locates try-catch blocks and error logic
|
||||||
|
- "Database connection setup" → finds DB initialization code
|
||||||
|
- "API endpoint definitions" → locates route handlers
|
||||||
|
- "Configuration management" → finds config files and usage
|
||||||
|
|
||||||
|
💡 **Pro tip**: Use this before making any changes to understand existing patterns and conventions.""",
|
||||||
|
"inputSchema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"index_name": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Name of the LEANN index to search. Use 'leann_list' first to see available indexes.",
|
||||||
|
},
|
||||||
|
"query": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Search query - can be natural language (e.g., 'how to handle errors') or technical terms (e.g., 'async function definition')",
|
||||||
|
},
|
||||||
|
"top_k": {
|
||||||
|
"type": "integer",
|
||||||
|
"default": 5,
|
||||||
|
"minimum": 1,
|
||||||
|
"maximum": 20,
|
||||||
|
"description": "Number of search results to return. Use 5-10 for focused results, 15-20 for comprehensive exploration.",
|
||||||
|
},
|
||||||
|
"complexity": {
|
||||||
|
"type": "integer",
|
||||||
|
"default": 32,
|
||||||
|
"minimum": 16,
|
||||||
|
"maximum": 128,
|
||||||
|
"description": "Search complexity level. Use 16-32 for fast searches (recommended), 64+ for higher precision when needed.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["index_name", "query"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "leann_list",
|
||||||
|
"description": "📋 Show all your indexed codebases - your personal code library! Use this to see what's available for search.",
|
||||||
|
"inputSchema": {"type": "object", "properties": {}},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
elif request.get("method") == "tools/call":
|
||||||
|
tool_name = request["params"]["name"]
|
||||||
|
args = request["params"].get("arguments", {})
|
||||||
|
|
||||||
|
try:
|
||||||
|
if tool_name == "leann_search":
|
||||||
|
# Validate required parameters
|
||||||
|
if not args.get("index_name") or not args.get("query"):
|
||||||
|
return {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": request.get("id"),
|
||||||
|
"result": {
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "Error: Both index_name and query are required",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Build simplified command with non-interactive flag for MCP compatibility
|
||||||
|
cmd = [
|
||||||
|
"leann",
|
||||||
|
"search",
|
||||||
|
args["index_name"],
|
||||||
|
args["query"],
|
||||||
|
f"--top-k={args.get('top_k', 5)}",
|
||||||
|
f"--complexity={args.get('complexity', 32)}",
|
||||||
|
"--non-interactive",
|
||||||
|
]
|
||||||
|
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||||
|
|
||||||
|
elif tool_name == "leann_list":
|
||||||
|
result = subprocess.run(["leann", "list"], capture_output=True, text=True)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": request.get("id"),
|
||||||
|
"result": {
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": result.stdout
|
||||||
|
if result.returncode == 0
|
||||||
|
else f"Error: {result.stderr}",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": request.get("id"),
|
||||||
|
"error": {"code": -1, "message": str(e)},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
for line in sys.stdin:
|
||||||
|
try:
|
||||||
|
request = json.loads(line.strip())
|
||||||
|
response = handle_request(request)
|
||||||
|
if response:
|
||||||
|
print(json.dumps(response))
|
||||||
|
sys.stdout.flush()
|
||||||
|
except Exception as e:
|
||||||
|
error_response = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": None,
|
||||||
|
"error": {"code": -1, "message": str(e)},
|
||||||
|
}
|
||||||
|
print(json.dumps(error_response))
|
||||||
|
sys.stdout.flush()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
240
packages/leann-core/src/leann/metadata_filter.py
Normal file
240
packages/leann-core/src/leann/metadata_filter.py
Normal file
@@ -0,0 +1,240 @@
|
|||||||
|
"""
|
||||||
|
Metadata filtering engine for LEANN search results.
|
||||||
|
|
||||||
|
This module provides generic metadata filtering capabilities that can be applied
|
||||||
|
to search results from any LEANN backend. The filtering supports various
|
||||||
|
operators for different data types including numbers, strings, booleans, and lists.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Type alias for filter specifications
|
||||||
|
FilterValue = Union[str, int, float, bool, list]
|
||||||
|
FilterSpec = dict[str, FilterValue]
|
||||||
|
MetadataFilters = dict[str, FilterSpec]
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataFilterEngine:
|
||||||
|
"""
|
||||||
|
Engine for evaluating metadata filters against search results.
|
||||||
|
|
||||||
|
Supports various operators for filtering based on metadata fields:
|
||||||
|
- Comparison: ==, !=, <, <=, >, >=
|
||||||
|
- Membership: in, not_in
|
||||||
|
- String operations: contains, starts_with, ends_with
|
||||||
|
- Boolean operations: is_true, is_false
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize the filter engine with supported operators."""
|
||||||
|
self.operators = {
|
||||||
|
"==": self._equals,
|
||||||
|
"!=": self._not_equals,
|
||||||
|
"<": self._less_than,
|
||||||
|
"<=": self._less_than_or_equal,
|
||||||
|
">": self._greater_than,
|
||||||
|
">=": self._greater_than_or_equal,
|
||||||
|
"in": self._in,
|
||||||
|
"not_in": self._not_in,
|
||||||
|
"contains": self._contains,
|
||||||
|
"starts_with": self._starts_with,
|
||||||
|
"ends_with": self._ends_with,
|
||||||
|
"is_true": self._is_true,
|
||||||
|
"is_false": self._is_false,
|
||||||
|
}
|
||||||
|
|
||||||
|
def apply_filters(
|
||||||
|
self, search_results: list[dict[str, Any]], metadata_filters: MetadataFilters
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Apply metadata filters to a list of search results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
search_results: List of result dictionaries, each containing 'metadata' field
|
||||||
|
metadata_filters: Dictionary of filter specifications
|
||||||
|
Format: {"field_name": {"operator": value}}
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Filtered list of search results
|
||||||
|
"""
|
||||||
|
if not metadata_filters:
|
||||||
|
return search_results
|
||||||
|
|
||||||
|
logger.debug(f"Applying filters: {metadata_filters}")
|
||||||
|
logger.debug(f"Input results count: {len(search_results)}")
|
||||||
|
|
||||||
|
filtered_results = []
|
||||||
|
for result in search_results:
|
||||||
|
if self._evaluate_filters(result, metadata_filters):
|
||||||
|
filtered_results.append(result)
|
||||||
|
|
||||||
|
logger.debug(f"Filtered results count: {len(filtered_results)}")
|
||||||
|
return filtered_results
|
||||||
|
|
||||||
|
def _evaluate_filters(self, result: dict[str, Any], filters: MetadataFilters) -> bool:
|
||||||
|
"""
|
||||||
|
Evaluate all filters against a single search result.
|
||||||
|
|
||||||
|
All filters must pass (AND logic) for the result to be included.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
result: Full search result dictionary (including metadata, text, etc.)
|
||||||
|
filters: Filter specifications to evaluate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if all filters pass, False otherwise
|
||||||
|
"""
|
||||||
|
for field_name, filter_spec in filters.items():
|
||||||
|
if not self._evaluate_field_filter(result, field_name, filter_spec):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _evaluate_field_filter(
|
||||||
|
self, result: dict[str, Any], field_name: str, filter_spec: FilterSpec
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Evaluate a single field filter against a search result.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
result: Full search result dictionary
|
||||||
|
field_name: Name of the field to filter on
|
||||||
|
filter_spec: Filter specification for this field
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the filter passes, False otherwise
|
||||||
|
"""
|
||||||
|
# First check top-level fields, then check metadata
|
||||||
|
field_value = result.get(field_name)
|
||||||
|
if field_value is None:
|
||||||
|
# Try to get from metadata if not found at top level
|
||||||
|
metadata = result.get("metadata", {})
|
||||||
|
field_value = metadata.get(field_name)
|
||||||
|
|
||||||
|
# Handle missing fields - they fail all filters except existence checks
|
||||||
|
if field_value is None:
|
||||||
|
logger.debug(f"Field '{field_name}' not found in result or metadata")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Evaluate each operator in the filter spec
|
||||||
|
for operator, expected_value in filter_spec.items():
|
||||||
|
if operator not in self.operators:
|
||||||
|
logger.warning(f"Unsupported operator: {operator}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not self.operators[operator](field_value, expected_value):
|
||||||
|
logger.debug(
|
||||||
|
f"Filter failed: {field_name} {operator} {expected_value} "
|
||||||
|
f"(actual: {field_value})"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Error evaluating filter {field_name} {operator} {expected_value}: {e}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Comparison operators
|
||||||
|
def _equals(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value equals expected value."""
|
||||||
|
return field_value == expected_value
|
||||||
|
|
||||||
|
def _not_equals(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value does not equal expected value."""
|
||||||
|
return field_value != expected_value
|
||||||
|
|
||||||
|
def _less_than(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value is less than expected value."""
|
||||||
|
return self._numeric_compare(field_value, expected_value, lambda a, b: a < b)
|
||||||
|
|
||||||
|
def _less_than_or_equal(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value is less than or equal to expected value."""
|
||||||
|
return self._numeric_compare(field_value, expected_value, lambda a, b: a <= b)
|
||||||
|
|
||||||
|
def _greater_than(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value is greater than expected value."""
|
||||||
|
return self._numeric_compare(field_value, expected_value, lambda a, b: a > b)
|
||||||
|
|
||||||
|
def _greater_than_or_equal(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value is greater than or equal to expected value."""
|
||||||
|
return self._numeric_compare(field_value, expected_value, lambda a, b: a >= b)
|
||||||
|
|
||||||
|
# Membership operators
|
||||||
|
def _in(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value is in the expected list/collection."""
|
||||||
|
if not isinstance(expected_value, (list, tuple, set)):
|
||||||
|
raise ValueError("'in' operator requires a list, tuple, or set")
|
||||||
|
return field_value in expected_value
|
||||||
|
|
||||||
|
def _not_in(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value is not in the expected list/collection."""
|
||||||
|
if not isinstance(expected_value, (list, tuple, set)):
|
||||||
|
raise ValueError("'not_in' operator requires a list, tuple, or set")
|
||||||
|
return field_value not in expected_value
|
||||||
|
|
||||||
|
# String operators
|
||||||
|
def _contains(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value contains the expected substring."""
|
||||||
|
field_str = str(field_value)
|
||||||
|
expected_str = str(expected_value)
|
||||||
|
return expected_str in field_str
|
||||||
|
|
||||||
|
def _starts_with(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value starts with the expected prefix."""
|
||||||
|
field_str = str(field_value)
|
||||||
|
expected_str = str(expected_value)
|
||||||
|
return field_str.startswith(expected_str)
|
||||||
|
|
||||||
|
def _ends_with(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value ends with the expected suffix."""
|
||||||
|
field_str = str(field_value)
|
||||||
|
expected_str = str(expected_value)
|
||||||
|
return field_str.endswith(expected_str)
|
||||||
|
|
||||||
|
# Boolean operators
|
||||||
|
def _is_true(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value is truthy."""
|
||||||
|
return bool(field_value)
|
||||||
|
|
||||||
|
def _is_false(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value is falsy."""
|
||||||
|
return not bool(field_value)
|
||||||
|
|
||||||
|
# Helper methods
|
||||||
|
def _numeric_compare(self, field_value: Any, expected_value: Any, compare_func) -> bool:
|
||||||
|
"""
|
||||||
|
Helper for numeric comparisons with type coercion.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field_value: Value from metadata
|
||||||
|
expected_value: Value to compare against
|
||||||
|
compare_func: Comparison function to apply
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Result of comparison
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Try to convert both values to numbers for comparison
|
||||||
|
if isinstance(field_value, str) and isinstance(expected_value, str):
|
||||||
|
# String comparison if both are strings
|
||||||
|
return compare_func(field_value, expected_value)
|
||||||
|
|
||||||
|
# Numeric comparison - attempt to convert to float
|
||||||
|
field_num = (
|
||||||
|
float(field_value) if not isinstance(field_value, (int, float)) else field_value
|
||||||
|
)
|
||||||
|
expected_num = (
|
||||||
|
float(expected_value)
|
||||||
|
if not isinstance(expected_value, (int, float))
|
||||||
|
else expected_value
|
||||||
|
)
|
||||||
|
|
||||||
|
return compare_func(field_num, expected_num)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
# Fall back to string comparison if numeric conversion fails
|
||||||
|
return compare_func(str(field_value), str(expected_value))
|
||||||
@@ -2,11 +2,17 @@
|
|||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import importlib.metadata
|
import importlib.metadata
|
||||||
from typing import TYPE_CHECKING
|
import json
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Optional, Union
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from leann.interface import LeannBackendFactoryInterface
|
from leann.interface import LeannBackendFactoryInterface
|
||||||
|
|
||||||
|
# Set up logger for this module
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
BACKEND_REGISTRY: dict[str, "LeannBackendFactoryInterface"] = {}
|
BACKEND_REGISTRY: dict[str, "LeannBackendFactoryInterface"] = {}
|
||||||
|
|
||||||
|
|
||||||
@@ -14,7 +20,7 @@ def register_backend(name: str):
|
|||||||
"""A decorator to register a new backend class."""
|
"""A decorator to register a new backend class."""
|
||||||
|
|
||||||
def decorator(cls):
|
def decorator(cls):
|
||||||
print(f"INFO: Registering backend '{name}'")
|
logger.debug(f"Registering backend '{name}'")
|
||||||
BACKEND_REGISTRY[name] = cls
|
BACKEND_REGISTRY[name] = cls
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
@@ -39,3 +45,54 @@ def autodiscover_backends():
|
|||||||
# print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
|
# print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
|
||||||
pass
|
pass
|
||||||
# print("INFO: Backend auto-discovery finished.")
|
# print("INFO: Backend auto-discovery finished.")
|
||||||
|
|
||||||
|
|
||||||
|
def register_project_directory(project_dir: Optional[Union[str, Path]] = None):
|
||||||
|
"""
|
||||||
|
Register a project directory in the global LEANN registry.
|
||||||
|
|
||||||
|
This allows `leann list` to discover indexes created by apps or other tools.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_dir: Directory to register. If None, uses current working directory.
|
||||||
|
"""
|
||||||
|
if project_dir is None:
|
||||||
|
project_dir = Path.cwd()
|
||||||
|
else:
|
||||||
|
project_dir = Path(project_dir)
|
||||||
|
|
||||||
|
# Only register directories that have some kind of LEANN content
|
||||||
|
# Either .leann/indexes/ (CLI format) or *.leann.meta.json files (apps format)
|
||||||
|
has_cli_indexes = (project_dir / ".leann" / "indexes").exists()
|
||||||
|
has_app_indexes = any(project_dir.rglob("*.leann.meta.json"))
|
||||||
|
|
||||||
|
if not (has_cli_indexes or has_app_indexes):
|
||||||
|
# Don't register if there are no LEANN indexes
|
||||||
|
return
|
||||||
|
|
||||||
|
global_registry = Path.home() / ".leann" / "projects.json"
|
||||||
|
global_registry.parent.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
project_str = str(project_dir.resolve())
|
||||||
|
|
||||||
|
# Load existing registry
|
||||||
|
projects = []
|
||||||
|
if global_registry.exists():
|
||||||
|
try:
|
||||||
|
with open(global_registry) as f:
|
||||||
|
projects = json.load(f)
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Could not load existing project registry")
|
||||||
|
projects = []
|
||||||
|
|
||||||
|
# Add project if not already present
|
||||||
|
if project_str not in projects:
|
||||||
|
projects.append(project_str)
|
||||||
|
|
||||||
|
# Save updated registry
|
||||||
|
try:
|
||||||
|
with open(global_registry, "w") as f:
|
||||||
|
json.dump(projects, f, indent=2)
|
||||||
|
logger.debug(f"Registered project directory: {project_str}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Could not save project registry: {e}")
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -41,6 +41,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
|||||||
print("WARNING: embedding_model not found in meta.json. Recompute will fail.")
|
print("WARNING: embedding_model not found in meta.json. Recompute will fail.")
|
||||||
|
|
||||||
self.embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
|
self.embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
|
||||||
|
self.embedding_options = self.meta.get("embedding_options", {})
|
||||||
|
|
||||||
self.embedding_server_manager = EmbeddingServerManager(
|
self.embedding_server_manager = EmbeddingServerManager(
|
||||||
backend_module_name=backend_module_name,
|
backend_module_name=backend_module_name,
|
||||||
@@ -77,6 +78,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
|||||||
passages_file=passages_source_file,
|
passages_file=passages_source_file,
|
||||||
distance_metric=distance_metric,
|
distance_metric=distance_metric,
|
||||||
enable_warmup=kwargs.get("enable_warmup", False),
|
enable_warmup=kwargs.get("enable_warmup", False),
|
||||||
|
provider_options=self.embedding_options,
|
||||||
)
|
)
|
||||||
if not server_started:
|
if not server_started:
|
||||||
raise RuntimeError(f"Failed to start embedding server on port {actual_port}")
|
raise RuntimeError(f"Failed to start embedding server on port {actual_port}")
|
||||||
@@ -125,7 +127,12 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
|||||||
from .embedding_compute import compute_embeddings
|
from .embedding_compute import compute_embeddings
|
||||||
|
|
||||||
embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
|
embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
|
||||||
return compute_embeddings([query], self.embedding_model, embedding_mode)
|
return compute_embeddings(
|
||||||
|
[query],
|
||||||
|
self.embedding_model,
|
||||||
|
embedding_mode,
|
||||||
|
provider_options=self.embedding_options,
|
||||||
|
)
|
||||||
|
|
||||||
def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray:
|
def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray:
|
||||||
"""Compute embeddings using the ZMQ embedding server."""
|
"""Compute embeddings using the ZMQ embedding server."""
|
||||||
@@ -169,7 +176,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
|||||||
prune_ratio: float = 0.0,
|
prune_ratio: float = 0.0,
|
||||||
recompute_embeddings: bool = False,
|
recompute_embeddings: bool = False,
|
||||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
zmq_port: int | None = None,
|
zmq_port: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
74
packages/leann-core/src/leann/settings.py
Normal file
74
packages/leann-core/src/leann/settings.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
"""Runtime configuration helpers for LEANN."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
# Default fallbacks to preserve current behaviour while keeping them in one place.
|
||||||
|
_DEFAULT_OLLAMA_HOST = "http://localhost:11434"
|
||||||
|
_DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1"
|
||||||
|
|
||||||
|
|
||||||
|
def _clean_url(value: str) -> str:
|
||||||
|
"""Normalize URL strings by stripping trailing slashes."""
|
||||||
|
|
||||||
|
return value.rstrip("/") if value else value
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_ollama_host(explicit: str | None = None) -> str:
|
||||||
|
"""Resolve the Ollama-compatible endpoint to use."""
|
||||||
|
|
||||||
|
candidates = (
|
||||||
|
explicit,
|
||||||
|
os.getenv("LEANN_LOCAL_LLM_HOST"),
|
||||||
|
os.getenv("LEANN_OLLAMA_HOST"),
|
||||||
|
os.getenv("OLLAMA_HOST"),
|
||||||
|
os.getenv("LOCAL_LLM_ENDPOINT"),
|
||||||
|
)
|
||||||
|
|
||||||
|
for candidate in candidates:
|
||||||
|
if candidate:
|
||||||
|
return _clean_url(candidate)
|
||||||
|
|
||||||
|
return _clean_url(_DEFAULT_OLLAMA_HOST)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_openai_base_url(explicit: str | None = None) -> str:
|
||||||
|
"""Resolve the base URL for OpenAI-compatible services."""
|
||||||
|
|
||||||
|
candidates = (
|
||||||
|
explicit,
|
||||||
|
os.getenv("LEANN_OPENAI_BASE_URL"),
|
||||||
|
os.getenv("OPENAI_BASE_URL"),
|
||||||
|
os.getenv("LOCAL_OPENAI_BASE_URL"),
|
||||||
|
)
|
||||||
|
|
||||||
|
for candidate in candidates:
|
||||||
|
if candidate:
|
||||||
|
return _clean_url(candidate)
|
||||||
|
|
||||||
|
return _clean_url(_DEFAULT_OPENAI_BASE_URL)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_openai_api_key(explicit: str | None = None) -> str | None:
|
||||||
|
"""Resolve the API key for OpenAI-compatible services."""
|
||||||
|
|
||||||
|
if explicit:
|
||||||
|
return explicit
|
||||||
|
|
||||||
|
return os.getenv("OPENAI_API_KEY")
|
||||||
|
|
||||||
|
|
||||||
|
def encode_provider_options(options: dict[str, Any] | None) -> str | None:
|
||||||
|
"""Serialize provider options for child processes."""
|
||||||
|
|
||||||
|
if not options:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
return json.dumps(options)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
# Fall back to empty payload if serialization fails
|
||||||
|
return None
|
||||||
149
packages/leann-mcp/README.md
Normal file
149
packages/leann-mcp/README.md
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
# 🔥 LEANN Claude Code Integration
|
||||||
|
|
||||||
|
Transform your development workflow with intelligent code assistance using LEANN's semantic search directly in Claude Code.
|
||||||
|
|
||||||
|
For agent-facing discovery details, see `llms.txt` in the repository root.
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
|
||||||
|
Install LEANN globally for MCP integration (with default backend):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv tool install leann-core --with leann
|
||||||
|
```
|
||||||
|
This installs the `leann` CLI into an isolated tool environment and includes both backends so `leann build` works out-of-the-box.
|
||||||
|
|
||||||
|
## 🚀 Quick Setup
|
||||||
|
|
||||||
|
Add the LEANN MCP server to Claude Code. Choose the scope based on how widely you want it available. Below is the command to install it globally; if you prefer a local install, skip this step:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Global (recommended): available in all projects for your user
|
||||||
|
claude mcp add --scope user leann-server -- leann_mcp
|
||||||
|
```
|
||||||
|
|
||||||
|
- `leann-server`: the display name of the MCP server in Claude Code (you can change it).
|
||||||
|
- `leann_mcp`: the Python entry point installed with LEANN that starts the MCP server.
|
||||||
|
|
||||||
|
Verify it is registered globally:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
claude mcp list | cat
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🛠️ Available Tools
|
||||||
|
|
||||||
|
Once connected, you'll have access to these powerful semantic search tools in Claude Code:
|
||||||
|
|
||||||
|
- **`leann_list`** - List all available indexes across your projects
|
||||||
|
- **`leann_search`** - Perform semantic searches across code and documents
|
||||||
|
|
||||||
|
|
||||||
|
## 🎯 Quick Start Example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Add locally if you did not add it globally (current folder only; default if --scope is omitted)
|
||||||
|
claude mcp add leann-server -- leann_mcp
|
||||||
|
|
||||||
|
# Build an index for your project (change to your actual path)
|
||||||
|
# See the advanced examples below for more ways to configure indexing
|
||||||
|
# Set the index name (replace 'my-project' with your own)
|
||||||
|
leann build my-project --docs $(git ls-files)
|
||||||
|
|
||||||
|
# Start Claude Code
|
||||||
|
claude
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🚀 Advanced Usage Examples to build the index
|
||||||
|
|
||||||
|
### Index Entire Git Repository
|
||||||
|
```bash
|
||||||
|
# Index all tracked files in your Git repository.
|
||||||
|
# Note: submodules are currently skipped; we can add them back if needed.
|
||||||
|
leann build my-repo --docs $(git ls-files) --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
|
||||||
|
|
||||||
|
# Index only tracked Python files from Git.
|
||||||
|
leann build my-python-code --docs $(git ls-files "*.py") --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
|
||||||
|
|
||||||
|
# If you encounter empty requests caused by empty files (e.g., __init__.py), exclude zero-byte files. Thanks @ww2283 for pointing [that](https://github.com/yichuan-w/LEANN/issues/48) out
|
||||||
|
leann build leann-prospec-lig --docs $(find ./src -name "*.py" -not -empty) --embedding-mode openai --embedding-model text-embedding-3-small
|
||||||
|
```
|
||||||
|
|
||||||
|
### Multiple Directories and Files
|
||||||
|
```bash
|
||||||
|
# Index multiple directories
|
||||||
|
leann build my-codebase --docs ./src ./tests ./docs ./config --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
|
||||||
|
|
||||||
|
# Mix files and directories
|
||||||
|
leann build my-project --docs ./README.md ./src/ ./package.json ./docs/ --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
|
||||||
|
|
||||||
|
# Specific files only
|
||||||
|
leann build my-configs --docs ./tsconfig.json ./package.json ./webpack.config.js --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
|
||||||
|
```
|
||||||
|
|
||||||
|
### Advanced Git Integration
|
||||||
|
```bash
|
||||||
|
# Index recently modified files
|
||||||
|
leann build recent-changes --docs $(git diff --name-only HEAD~10..HEAD) --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
|
||||||
|
|
||||||
|
# Index files matching pattern
|
||||||
|
leann build frontend --docs $(git ls-files "*.tsx" "*.ts" "*.jsx" "*.js") --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
|
||||||
|
|
||||||
|
# Index documentation and config files
|
||||||
|
leann build docs-and-configs --docs $(git ls-files "*.md" "*.yml" "*.yaml" "*.json" "*.toml") --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## **Try this in Claude Code:**
|
||||||
|
```
|
||||||
|
Help me understand this codebase. List available indexes and search for authentication patterns.
|
||||||
|
```
|
||||||
|
|
||||||
|
<p align="center">
|
||||||
|
<img src="../../assets/claude_code_leann.png" alt="LEANN in Claude Code" width="80%">
|
||||||
|
</p>
|
||||||
|
|
||||||
|
If you see a prompt asking whether to proceed with LEANN, you can now use it in your chat!
|
||||||
|
|
||||||
|
## 🧠 How It Works
|
||||||
|
|
||||||
|
The integration consists of three key components working seamlessly together:
|
||||||
|
|
||||||
|
- **`leann`** - Core CLI tool for indexing and searching (installed globally via `uv tool install`)
|
||||||
|
- **`leann_mcp`** - MCP server that wraps `leann` commands for Claude Code integration
|
||||||
|
- **Claude Code** - Calls `leann_mcp`, which executes `leann` commands and returns intelligent results
|
||||||
|
|
||||||
|
## 📁 File Support
|
||||||
|
|
||||||
|
LEANN understands **30+ file types** including:
|
||||||
|
- **Programming**: Python, JavaScript, TypeScript, Java, Go, Rust, C++, C#
|
||||||
|
- **Data**: SQL, YAML, JSON, CSV, XML
|
||||||
|
- **Documentation**: Markdown, TXT, PDF
|
||||||
|
- **And many more!**
|
||||||
|
|
||||||
|
## 💾 Storage & Organization
|
||||||
|
|
||||||
|
- **Project indexes**: Stored in `.leann/` directory (just like `.git`)
|
||||||
|
- **Global registry**: Project tracking at `~/.leann/projects.json`
|
||||||
|
- **Multi-project support**: Switch between different codebases seamlessly
|
||||||
|
- **Portable**: Transfer indexes between machines with minimal overhead
|
||||||
|
|
||||||
|
## 🗑️ Uninstalling
|
||||||
|
|
||||||
|
To remove the LEANN MCP server from Claude Code:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
claude mcp remove leann-server
|
||||||
|
```
|
||||||
|
To remove LEANN
|
||||||
|
```
|
||||||
|
uv pip uninstall leann leann-backend-hnsw leann-core
|
||||||
|
```
|
||||||
|
|
||||||
|
To globally remove LEANN (for version update)
|
||||||
|
```
|
||||||
|
uv tool list | cat
|
||||||
|
uv tool uninstall leann-core
|
||||||
|
command -v leann || echo "leann gone"
|
||||||
|
command -v leann_mcp || echo "leann_mcp gone"
|
||||||
|
```
|
||||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann"
|
name = "leann"
|
||||||
version = "0.2.1"
|
version = "0.3.4"
|
||||||
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
|
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.9"
|
||||||
|
|||||||
1
packages/wechat-exporter/__init__.py
Normal file
1
packages/wechat-exporter/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
__all__ = []
|
||||||
@@ -136,5 +136,9 @@ def export_sqlite(
|
|||||||
connection.commit()
|
connection.commit()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def main():
|
||||||
app()
|
app()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|||||||
@@ -10,11 +10,10 @@ requires-python = ">=3.9"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"leann-core",
|
"leann-core",
|
||||||
"leann-backend-hnsw",
|
"leann-backend-hnsw",
|
||||||
|
"typer>=0.12.3",
|
||||||
"numpy>=1.26.0",
|
"numpy>=1.26.0",
|
||||||
"torch",
|
"torch",
|
||||||
"tqdm",
|
"tqdm",
|
||||||
"flask",
|
|
||||||
"flask_compress",
|
|
||||||
"datasets>=2.15.0",
|
"datasets>=2.15.0",
|
||||||
"evaluate",
|
"evaluate",
|
||||||
"colorama",
|
"colorama",
|
||||||
@@ -32,7 +31,7 @@ dependencies = [
|
|||||||
"pypdfium2>=4.30.0",
|
"pypdfium2>=4.30.0",
|
||||||
# LlamaIndex core and readers - updated versions
|
# LlamaIndex core and readers - updated versions
|
||||||
"llama-index>=0.12.44",
|
"llama-index>=0.12.44",
|
||||||
"llama-index-readers-file>=0.4.0", # Essential for PDF parsing
|
"llama-index-readers-file>=0.4.0", # Essential for PDF parsing
|
||||||
# "llama-index-readers-docling", # Requires Python >= 3.10
|
# "llama-index-readers-docling", # Requires Python >= 3.10
|
||||||
# "llama-index-node-parser-docling", # Requires Python >= 3.10
|
# "llama-index-node-parser-docling", # Requires Python >= 3.10
|
||||||
"llama-index-vector-stores-faiss>=0.4.0",
|
"llama-index-vector-stores-faiss>=0.4.0",
|
||||||
@@ -40,32 +39,24 @@ dependencies = [
|
|||||||
# Other dependencies
|
# Other dependencies
|
||||||
"ipykernel==6.29.5",
|
"ipykernel==6.29.5",
|
||||||
"msgpack>=1.1.1",
|
"msgpack>=1.1.1",
|
||||||
"mlx>=0.26.3; sys_platform == 'darwin'",
|
"mlx>=0.26.3; sys_platform == 'darwin' and platform_machine == 'arm64'",
|
||||||
"mlx-lm>=0.26.0; sys_platform == 'darwin'",
|
"mlx-lm>=0.26.0; sys_platform == 'darwin' and platform_machine == 'arm64'",
|
||||||
"psutil>=5.8.0",
|
"psutil>=5.8.0",
|
||||||
|
"pybind11>=3.0.0",
|
||||||
|
"pathspec>=0.12.1",
|
||||||
|
"nbconvert>=7.16.6",
|
||||||
|
"gitignore-parser>=0.1.12",
|
||||||
|
# AST-aware code chunking dependencies
|
||||||
|
"astchunk>=0.1.0",
|
||||||
|
"tree-sitter>=0.20.0",
|
||||||
|
"tree-sitter-python>=0.20.0",
|
||||||
|
"tree-sitter-java>=0.20.0",
|
||||||
|
"tree-sitter-c-sharp>=0.20.0",
|
||||||
|
"tree-sitter-typescript>=0.20.0",
|
||||||
|
"torchvision>=0.23.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
dev = [
|
|
||||||
"pytest>=7.0",
|
|
||||||
"pytest-cov>=4.0",
|
|
||||||
"pytest-xdist>=3.0", # For parallel test execution
|
|
||||||
"black>=23.0",
|
|
||||||
"ruff>=0.1.0",
|
|
||||||
"matplotlib",
|
|
||||||
"huggingface-hub>=0.20.0",
|
|
||||||
"pre-commit>=3.5.0",
|
|
||||||
]
|
|
||||||
|
|
||||||
test = [
|
|
||||||
"pytest>=7.0",
|
|
||||||
"pytest-timeout>=2.0",
|
|
||||||
"llama-index-core>=0.12.0",
|
|
||||||
"llama-index-readers-file>=0.4.0",
|
|
||||||
"python-dotenv>=1.0.0",
|
|
||||||
"sentence-transformers>=2.2.0",
|
|
||||||
]
|
|
||||||
|
|
||||||
diskann = [
|
diskann = [
|
||||||
"leann-backend-diskann",
|
"leann-backend-diskann",
|
||||||
]
|
]
|
||||||
@@ -80,24 +71,51 @@ documents = [
|
|||||||
|
|
||||||
[tool.setuptools]
|
[tool.setuptools]
|
||||||
py-modules = []
|
py-modules = []
|
||||||
|
packages = ["wechat_exporter"]
|
||||||
|
package-dir = { "wechat_exporter" = "packages/wechat-exporter" }
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
wechat-exporter = "wechat_exporter.main:main"
|
||||||
|
|
||||||
|
|
||||||
[tool.uv.sources]
|
[tool.uv.sources]
|
||||||
leann-core = { path = "packages/leann-core", editable = true }
|
leann-core = { path = "packages/leann-core", editable = true }
|
||||||
leann-backend-diskann = { path = "packages/leann-backend-diskann", editable = true }
|
leann-backend-diskann = { path = "packages/leann-backend-diskann", editable = true }
|
||||||
leann-backend-hnsw = { path = "packages/leann-backend-hnsw", editable = true }
|
leann-backend-hnsw = { path = "packages/leann-backend-hnsw", editable = true }
|
||||||
|
astchunk = { path = "packages/astchunk-leann", editable = true }
|
||||||
|
|
||||||
|
[dependency-groups]
|
||||||
|
# Minimal lint toolchain for CI and local hooks
|
||||||
|
lint = [
|
||||||
|
"pre-commit>=3.5.0",
|
||||||
|
"ruff==0.12.7", # Fixed version to ensure consistent formatting across all environments
|
||||||
|
]
|
||||||
|
|
||||||
|
# Test toolchain (no heavy project runtime deps)
|
||||||
|
test = [
|
||||||
|
"pytest>=7.0",
|
||||||
|
"pytest-cov>=4.0",
|
||||||
|
"pytest-xdist>=3.0",
|
||||||
|
"pytest-timeout>=2.0",
|
||||||
|
"python-dotenv>=1.0.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
# dependencies by apps/ should list here
|
||||||
|
dev = [
|
||||||
|
"matplotlib",
|
||||||
|
"huggingface-hub>=0.20.0",
|
||||||
|
]
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
target-version = "py310"
|
target-version = "py39"
|
||||||
line-length = 100
|
line-length = 100
|
||||||
extend-exclude = [
|
extend-exclude = [
|
||||||
"third_party",
|
"third_party",
|
||||||
"*.egg-info",
|
"apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann.py",
|
||||||
"__pycache__",
|
"apps/multimodal/vision-based-pdf-multi-vector/multi-vector-leann-similarity-map.py"
|
||||||
".git",
|
|
||||||
".venv",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
select = [
|
select = [
|
||||||
"E", # pycodestyle errors
|
"E", # pycodestyle errors
|
||||||
@@ -119,21 +137,12 @@ ignore = [
|
|||||||
"RUF012", # mutable class attributes should be annotated with typing.ClassVar
|
"RUF012", # mutable class attributes should be annotated with typing.ClassVar
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.ruff.lint.per-file-ignores]
|
|
||||||
"test/**/*.py" = ["E402"] # module level import not at top of file (common in tests)
|
|
||||||
"examples/**/*.py" = ["E402"] # module level import not at top of file (common in examples)
|
|
||||||
|
|
||||||
[tool.ruff.format]
|
[tool.ruff.format]
|
||||||
quote-style = "double"
|
quote-style = "double"
|
||||||
indent-style = "space"
|
indent-style = "space"
|
||||||
skip-magic-trailing-comma = false
|
skip-magic-trailing-comma = false
|
||||||
line-ending = "auto"
|
line-ending = "auto"
|
||||||
|
|
||||||
[dependency-groups]
|
|
||||||
dev = [
|
|
||||||
"ruff>=0.12.4",
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.lychee]
|
[tool.lychee]
|
||||||
accept = ["200", "403", "429", "503"]
|
accept = ["200", "403", "429", "503"]
|
||||||
timeout = 20
|
timeout = 20
|
||||||
@@ -151,7 +160,7 @@ markers = [
|
|||||||
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
|
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
|
||||||
"openai: marks tests that require OpenAI API key",
|
"openai: marks tests that require OpenAI API key",
|
||||||
]
|
]
|
||||||
timeout = 600
|
timeout = 300 # Reduced from 600s (10min) to 300s (5min) for CI safety
|
||||||
addopts = [
|
addopts = [
|
||||||
"-v",
|
"-v",
|
||||||
"--tb=short",
|
"--tb=short",
|
||||||
|
|||||||
121
scripts/hf_upload.py
Normal file
121
scripts/hf_upload.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Upload local evaluation data to Hugging Face Hub, excluding diskann_rpj_wiki.
|
||||||
|
|
||||||
|
Defaults:
|
||||||
|
- repo_id: LEANN-RAG/leann-rag-evaluation-data (dataset)
|
||||||
|
- folder_path: benchmarks/data
|
||||||
|
- ignore_patterns: diskann_rpj_wiki/** and .cache/**
|
||||||
|
|
||||||
|
Requires authentication via `huggingface-cli login` or HF_TOKEN env var.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
|
try:
|
||||||
|
from huggingface_hub import HfApi
|
||||||
|
except Exception as e:
|
||||||
|
raise SystemExit(
|
||||||
|
"huggingface_hub is required. Install with: pip install huggingface_hub hf_transfer"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
def _enable_transfer_accel_if_available() -> None:
|
||||||
|
"""Best-effort enabling of accelerated transfers across hub versions.
|
||||||
|
|
||||||
|
Tries the public util if present; otherwise, falls back to env flag when
|
||||||
|
hf_transfer is installed. Silently no-ops if unavailable.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Newer huggingface_hub exposes this under utils
|
||||||
|
from huggingface_hub.utils import hf_hub_enable_hf_transfer # type: ignore
|
||||||
|
|
||||||
|
hf_hub_enable_hf_transfer()
|
||||||
|
return
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
# If hf_transfer is installed, set env flag recognized by the hub
|
||||||
|
import hf_transfer # noqa: F401
|
||||||
|
|
||||||
|
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
|
||||||
|
except Exception:
|
||||||
|
# Acceleration not available; proceed without it
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> argparse.Namespace:
|
||||||
|
p = argparse.ArgumentParser(description="Upload local data to HF, excluding diskann_rpj_wiki")
|
||||||
|
p.add_argument(
|
||||||
|
"--repo-id",
|
||||||
|
default="LEANN-RAG/leann-rag-evaluation-data",
|
||||||
|
help="Target dataset repo id (namespace/name)",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--folder-path",
|
||||||
|
default="benchmarks/data",
|
||||||
|
help="Local folder to upload (default: benchmarks/data)",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--ignore",
|
||||||
|
default=["diskann_rpj_wiki/**", ".cache/**"],
|
||||||
|
nargs="+",
|
||||||
|
help="Glob patterns to ignore (space-separated)",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--allow",
|
||||||
|
default=["**"],
|
||||||
|
nargs="+",
|
||||||
|
help="Glob patterns to allow (space-separated). Defaults to everything.",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--message",
|
||||||
|
default="sync local data (exclude diskann_rpj_wiki)",
|
||||||
|
help="Commit message",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--no-transfer-accel",
|
||||||
|
action="store_true",
|
||||||
|
help="Disable hf_transfer accelerated uploads",
|
||||||
|
)
|
||||||
|
return p.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
if not args.no_transfer_accel:
|
||||||
|
_enable_transfer_accel_if_available()
|
||||||
|
|
||||||
|
if not os.path.isdir(args.folder_path):
|
||||||
|
raise SystemExit(f"Folder not found: {args.folder_path}")
|
||||||
|
|
||||||
|
print("Uploading to Hugging Face Hub:")
|
||||||
|
print(f" repo_id: {args.repo_id}")
|
||||||
|
print(" repo_type: dataset")
|
||||||
|
print(f" folder_path: {args.folder_path}")
|
||||||
|
print(f" allow_patterns: {args.allow}")
|
||||||
|
print(f" ignore_patterns:{args.ignore}")
|
||||||
|
|
||||||
|
api = HfApi()
|
||||||
|
|
||||||
|
# Perform upload. This skips unchanged files by content hash.
|
||||||
|
api.upload_folder(
|
||||||
|
repo_id=args.repo_id,
|
||||||
|
repo_type="dataset",
|
||||||
|
folder_path=args.folder_path,
|
||||||
|
path_in_repo=".",
|
||||||
|
allow_patterns=args.allow,
|
||||||
|
ignore_patterns=args.ignore,
|
||||||
|
commit_message=args.message,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Upload completed (unchanged files were skipped by the Hub).")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
324
scripts/measure_generation_times.py
Executable file
324
scripts/measure_generation_times.py
Executable file
@@ -0,0 +1,324 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Measure generation latency of a HuggingFace/OpenAI-compatible model over prompt files."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import contextlib
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from leann.chat import get_llm
|
||||||
|
|
||||||
|
PROMPT_PREFIX = "PROMPT #"
|
||||||
|
logging.getLogger("leann.chat").setLevel(logging.ERROR)
|
||||||
|
|
||||||
|
|
||||||
|
def load_prompts(path: Path) -> list[str]:
|
||||||
|
prompts: list[str] = []
|
||||||
|
buffer: list[str] = []
|
||||||
|
collecting = False
|
||||||
|
|
||||||
|
with path.open("r", encoding="utf-8") as handle:
|
||||||
|
for line in handle:
|
||||||
|
if line.startswith(PROMPT_PREFIX):
|
||||||
|
if buffer:
|
||||||
|
prompts.append("".join(buffer).strip())
|
||||||
|
buffer.clear()
|
||||||
|
collecting = True
|
||||||
|
continue
|
||||||
|
|
||||||
|
if collecting:
|
||||||
|
buffer.append(line)
|
||||||
|
|
||||||
|
if buffer:
|
||||||
|
prompts.append("".join(buffer).strip())
|
||||||
|
|
||||||
|
return prompts
|
||||||
|
|
||||||
|
|
||||||
|
def measure_generation_times(
|
||||||
|
prompts: list[str],
|
||||||
|
llm,
|
||||||
|
generation_kwargs: dict[str, object],
|
||||||
|
allow_truncation: bool,
|
||||||
|
enable_qwen_thinking: bool,
|
||||||
|
verbose: bool,
|
||||||
|
per_call_timeout: int | None,
|
||||||
|
):
|
||||||
|
timings: list[float] = []
|
||||||
|
tokenizer = getattr(llm, "tokenizer", None)
|
||||||
|
max_positions = None
|
||||||
|
if hasattr(llm, "model") and hasattr(llm.model, "config"):
|
||||||
|
max_positions = getattr(llm.model.config, "max_position_embeddings", None)
|
||||||
|
|
||||||
|
requested_new_tokens = None
|
||||||
|
if max_positions is not None:
|
||||||
|
if "max_new_tokens" in generation_kwargs:
|
||||||
|
requested_new_tokens = generation_kwargs["max_new_tokens"]
|
||||||
|
elif "max_tokens" in generation_kwargs:
|
||||||
|
requested_new_tokens = generation_kwargs["max_tokens"]
|
||||||
|
|
||||||
|
context_max_length = max_positions
|
||||||
|
if max_positions is not None and requested_new_tokens is not None:
|
||||||
|
if requested_new_tokens >= max_positions:
|
||||||
|
requested_new_tokens = max_positions - 1
|
||||||
|
context_max_length = max(max_positions - requested_new_tokens, 1)
|
||||||
|
|
||||||
|
suppress_buffer = io.StringIO()
|
||||||
|
# Log base config
|
||||||
|
if verbose:
|
||||||
|
device = getattr(llm, "device", None)
|
||||||
|
try:
|
||||||
|
dtype = getattr(getattr(llm, "model", None), "dtype", None)
|
||||||
|
except Exception:
|
||||||
|
dtype = None
|
||||||
|
print(
|
||||||
|
f"[dbg] device={device} dtype={dtype} max_positions={max_positions} requested_new_tokens={requested_new_tokens} context_max_length={context_max_length}"
|
||||||
|
)
|
||||||
|
total = len(prompts)
|
||||||
|
for idx, prompt in enumerate(prompts, start=1):
|
||||||
|
prompt_for_llm = prompt
|
||||||
|
if (
|
||||||
|
enable_qwen_thinking
|
||||||
|
and "/think" not in prompt_for_llm
|
||||||
|
and "/no_think" not in prompt_for_llm
|
||||||
|
):
|
||||||
|
prompt_for_llm = f"{prompt_for_llm}\n/think"
|
||||||
|
|
||||||
|
if allow_truncation and tokenizer is not None and max_positions is not None:
|
||||||
|
tokenized = tokenizer(
|
||||||
|
prompt_for_llm,
|
||||||
|
truncation=True,
|
||||||
|
max_length=context_max_length,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
prompt_for_llm = tokenizer.decode(tokenized["input_ids"][0], skip_special_tokens=True)
|
||||||
|
|
||||||
|
per_call_kwargs = dict(generation_kwargs)
|
||||||
|
if requested_new_tokens is not None:
|
||||||
|
per_call_kwargs["max_new_tokens"] = requested_new_tokens
|
||||||
|
# Enable streaming if requested (HF backend will print tokens)
|
||||||
|
if verbose:
|
||||||
|
# When verbose (or --stream propagated), enable streaming in HF backend
|
||||||
|
per_call_kwargs["stream"] = True
|
||||||
|
|
||||||
|
# Extra debug info about token lengths
|
||||||
|
if verbose and tokenizer is not None:
|
||||||
|
try:
|
||||||
|
toks = tokenizer(prompt_for_llm, return_tensors=None, truncation=False)
|
||||||
|
in_len = (
|
||||||
|
len(toks["input_ids"])
|
||||||
|
if isinstance(toks["input_ids"], list)
|
||||||
|
else len(toks["input_ids"][0])
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
in_len = None
|
||||||
|
print(f"[dbg] prompt {idx}/{total} tokens={in_len}")
|
||||||
|
print(
|
||||||
|
f"[dbg] gen_cfg={{max_new_tokens:{per_call_kwargs.get('max_new_tokens')}, temp:{per_call_kwargs.get('temperature')}, top_p:{per_call_kwargs.get('top_p')}}}"
|
||||||
|
)
|
||||||
|
|
||||||
|
start = time.perf_counter()
|
||||||
|
# Optional per-call timeout using signal alarm
|
||||||
|
timeout_handler_installed = False
|
||||||
|
if per_call_timeout is not None:
|
||||||
|
import signal
|
||||||
|
|
||||||
|
def _timeout_handler(signum, frame):
|
||||||
|
raise TimeoutError("generation timed out")
|
||||||
|
|
||||||
|
old_handler = signal.signal(signal.SIGALRM, _timeout_handler)
|
||||||
|
signal.alarm(int(per_call_timeout))
|
||||||
|
timeout_handler_installed = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
if verbose:
|
||||||
|
print("[dbg] generation_start")
|
||||||
|
llm.ask(prompt_for_llm, **per_call_kwargs)
|
||||||
|
print("[dbg] generation_done")
|
||||||
|
else:
|
||||||
|
with contextlib.redirect_stdout(suppress_buffer):
|
||||||
|
llm.ask(prompt_for_llm, **per_call_kwargs)
|
||||||
|
except TimeoutError:
|
||||||
|
if verbose:
|
||||||
|
print("[dbg] generation_timeout")
|
||||||
|
finally:
|
||||||
|
if timeout_handler_installed:
|
||||||
|
import signal
|
||||||
|
|
||||||
|
signal.alarm(0)
|
||||||
|
signal.signal(signal.SIGALRM, old_handler)
|
||||||
|
end = time.perf_counter()
|
||||||
|
timings.append(end - start)
|
||||||
|
suppress_buffer.seek(0)
|
||||||
|
suppress_buffer.truncate(0)
|
||||||
|
|
||||||
|
return timings
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(description="Measure generation timing for prompt files")
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-prompts",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Optional limit on number of prompts to evaluate per file",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--allow-truncation",
|
||||||
|
action="store_true",
|
||||||
|
help="Allow truncating prompt context to respect model's max context",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model",
|
||||||
|
type=str,
|
||||||
|
default="sshleifer/tiny-gpt2",
|
||||||
|
help="LLM model identifier (default: sshleifer/tiny-gpt2)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--llm-type",
|
||||||
|
type=str,
|
||||||
|
default="hf",
|
||||||
|
choices=["hf", "openai", "ollama", "gemini", "simulated"],
|
||||||
|
help="LLM backend type (default: hf)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--device",
|
||||||
|
type=str,
|
||||||
|
default="cpu",
|
||||||
|
choices=["cpu", "auto"],
|
||||||
|
help="Device override for HF models (default: cpu)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-new-tokens",
|
||||||
|
type=int,
|
||||||
|
default=16,
|
||||||
|
help="Max new tokens per generation (default: 16)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--temperature",
|
||||||
|
type=float,
|
||||||
|
default=0.2,
|
||||||
|
help="Sampling temperature (default: 0.2)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--top-p",
|
||||||
|
type=float,
|
||||||
|
default=0.8,
|
||||||
|
help="Nucleus sampling top-p (default: 0.8)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--qwen-thinking",
|
||||||
|
action="store_true",
|
||||||
|
help="Append /think to prompts for Qwen models",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--no-max-new-tokens",
|
||||||
|
action="store_true",
|
||||||
|
help="Do not set max_new_tokens in generation kwargs",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--per-call-timeout",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Optional timeout (seconds) per generation call; if hit, moves to next prompt",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--stream",
|
||||||
|
action="store_true",
|
||||||
|
help="Stream generated text to stdout during generation",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--datasets",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help=(
|
||||||
|
"Comma-separated subset of datasets to run. Options: gpqa_bm25,gpqa_diskann,gpqa_hnsw. "
|
||||||
|
"Default: all"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--verbose",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable debug logging and show generation progress",
|
||||||
|
)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = parse_args()
|
||||||
|
dataset_map = {
|
||||||
|
# "gpqa_bm25": Path("prompt_dump_gpqa_bm25.txt"),
|
||||||
|
# "gpqa_diskann": Path("prompt_dump_gpqa_diskann.txt"),
|
||||||
|
# "gpqa_hnsw": Path("prompt_dump_gpqa_hnsw.txt"),
|
||||||
|
# "nq_bm25": Path("prompt_dump_nq_bm25.txt"),
|
||||||
|
# # "nq_diskann": Path("prompt_dump_nq_diskann.txt"),
|
||||||
|
# "nq_hnsw": Path("prompt_dump_nq_hnsw.txt"),
|
||||||
|
"gpqa_bm25": Path("prompt_dump_hotpot_bm25.txt"),
|
||||||
|
"gpqa_diskann": Path("prompt_dump_hotpot_diskann.txt"),
|
||||||
|
# "gpqa_hnsw": Path("prompt_dump_hotpot_hnsw.txt"),
|
||||||
|
# "gpqa_bm25": Path("prompt_dump_trivia_bm25.txt"),
|
||||||
|
# "gpqa_diskann": Path("prompt_dump_trivia_diskann.txt"),
|
||||||
|
}
|
||||||
|
if args.datasets:
|
||||||
|
selected = [k.strip() for k in args.datasets.split(",") if k.strip()]
|
||||||
|
invalid = [k for k in selected if k not in dataset_map]
|
||||||
|
if invalid:
|
||||||
|
raise SystemExit(f"Invalid dataset names: {invalid}. Valid: {list(dataset_map)}")
|
||||||
|
dataset_files = [dataset_map[k] for k in selected]
|
||||||
|
else:
|
||||||
|
dataset_files = list(dataset_map.values())
|
||||||
|
|
||||||
|
generation_kwargs = {
|
||||||
|
"temperature": args.temperature,
|
||||||
|
"top_p": args.top_p,
|
||||||
|
}
|
||||||
|
if not args.no_max_new_tokens:
|
||||||
|
generation_kwargs["max_new_tokens"] = args.max_new_tokens
|
||||||
|
|
||||||
|
results: dict[str, dict[str, float | int]] = {}
|
||||||
|
|
||||||
|
llm_config = {"type": args.llm_type, "model": args.model}
|
||||||
|
try:
|
||||||
|
llm = get_llm(llm_config)
|
||||||
|
except Exception as exc:
|
||||||
|
print(f"Failed to initialize LLM: {exc}")
|
||||||
|
raise SystemExit(1) from exc
|
||||||
|
|
||||||
|
if args.llm_type == "hf" and hasattr(llm, "model") and args.device == "cpu":
|
||||||
|
llm.model = llm.model.to("cpu")
|
||||||
|
if hasattr(llm, "device"):
|
||||||
|
llm.device = "cpu"
|
||||||
|
|
||||||
|
for dataset_path in dataset_files:
|
||||||
|
print(f"Processing {dataset_path.name}...")
|
||||||
|
prompts = load_prompts(dataset_path)
|
||||||
|
if args.max_prompts is not None:
|
||||||
|
prompts = prompts[: args.max_prompts]
|
||||||
|
if args.verbose:
|
||||||
|
print(f"[dbg] loaded_prompts={len(prompts)} (showing up to --max-prompts)")
|
||||||
|
timings = measure_generation_times(
|
||||||
|
prompts,
|
||||||
|
llm,
|
||||||
|
generation_kwargs,
|
||||||
|
args.allow_truncation,
|
||||||
|
args.qwen_thinking,
|
||||||
|
args.verbose or args.stream,
|
||||||
|
args.per_call_timeout,
|
||||||
|
)
|
||||||
|
total_time = sum(timings)
|
||||||
|
count = len(timings)
|
||||||
|
average_time = total_time / count if count else 0.0
|
||||||
|
results[str(dataset_path.name)] = {
|
||||||
|
"total_prompts": count,
|
||||||
|
"total_time_seconds": total_time,
|
||||||
|
"average_time_seconds": average_time,
|
||||||
|
}
|
||||||
|
|
||||||
|
print(json.dumps(results, indent=2))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
76
sky/leann-build.yaml
Normal file
76
sky/leann-build.yaml
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
name: leann-build
|
||||||
|
|
||||||
|
resources:
|
||||||
|
# Choose a GPU for fast embeddings (examples: L4, A10G, A100). CPU also works but is slower.
|
||||||
|
accelerators: L4:1
|
||||||
|
# Optionally pin a cloud, otherwise SkyPilot will auto-select
|
||||||
|
# cloud: aws
|
||||||
|
disk_size: 100
|
||||||
|
|
||||||
|
envs:
|
||||||
|
# Build parameters (override with: sky launch -c leann-gpu sky/leann-build.yaml -e key=value)
|
||||||
|
index_name: my-index
|
||||||
|
docs: ./data
|
||||||
|
backend: hnsw # hnsw | diskann
|
||||||
|
complexity: 64
|
||||||
|
graph_degree: 32
|
||||||
|
num_threads: 8
|
||||||
|
# Embedding selection
|
||||||
|
embedding_mode: sentence-transformers # sentence-transformers | openai | mlx | ollama
|
||||||
|
embedding_model: facebook/contriever
|
||||||
|
# Storage/latency knobs
|
||||||
|
recompute: true # true => selective recomputation (recommended)
|
||||||
|
compact: true # for HNSW only
|
||||||
|
# Optional pass-through
|
||||||
|
extra_args: ""
|
||||||
|
# Rebuild control
|
||||||
|
force: true
|
||||||
|
|
||||||
|
# Sync local paths to the remote VM. Adjust as needed.
|
||||||
|
file_mounts:
|
||||||
|
# Example: mount your local data directory used for building
|
||||||
|
~/leann-data: ${docs}
|
||||||
|
|
||||||
|
setup: |
|
||||||
|
set -e
|
||||||
|
# Install uv (package manager)
|
||||||
|
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
|
export PATH="$HOME/.local/bin:$PATH"
|
||||||
|
|
||||||
|
# Ensure modern libstdc++ for FAISS (GLIBCXX >= 3.4.30)
|
||||||
|
sudo apt-get update -y
|
||||||
|
sudo apt-get install -y libstdc++6 libgomp1
|
||||||
|
# Also upgrade conda's libstdc++ in base env (Skypilot images include conda)
|
||||||
|
if command -v conda >/dev/null 2>&1; then
|
||||||
|
conda install -y -n base -c conda-forge libstdcxx-ng
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Install LEANN CLI and backends into the user environment
|
||||||
|
uv pip install --upgrade pip
|
||||||
|
uv pip install leann-core leann-backend-hnsw leann-backend-diskann
|
||||||
|
|
||||||
|
run: |
|
||||||
|
export PATH="$HOME/.local/bin:$PATH"
|
||||||
|
# Derive flags from env
|
||||||
|
recompute_flag=""
|
||||||
|
if [ "${recompute}" = "false" ] || [ "${recompute}" = "0" ]; then
|
||||||
|
recompute_flag="--no-recompute"
|
||||||
|
fi
|
||||||
|
force_flag=""
|
||||||
|
if [ "${force}" = "true" ] || [ "${force}" = "1" ]; then
|
||||||
|
force_flag="--force"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Build command
|
||||||
|
python -m leann.cli build ${index_name} \
|
||||||
|
--docs ~/leann-data \
|
||||||
|
--backend ${backend} \
|
||||||
|
--complexity ${complexity} \
|
||||||
|
--graph-degree ${graph_degree} \
|
||||||
|
--num-threads ${num_threads} \
|
||||||
|
--embedding-mode ${embedding_mode} \
|
||||||
|
--embedding-model ${embedding_model} \
|
||||||
|
${recompute_flag} ${force_flag} ${extra_args}
|
||||||
|
|
||||||
|
# Print where the index is stored for downstream rsync
|
||||||
|
echo "INDEX_OUT_DIR=~/.leann/indexes/${index_name}"
|
||||||
@@ -6,10 +6,11 @@ This directory contains automated tests for the LEANN project using pytest.
|
|||||||
|
|
||||||
### `test_readme_examples.py`
|
### `test_readme_examples.py`
|
||||||
Tests the examples shown in README.md:
|
Tests the examples shown in README.md:
|
||||||
- The basic example code that users see first
|
- The basic example code that users see first (parametrized for both HNSW and DiskANN backends)
|
||||||
- Import statements work correctly
|
- Import statements work correctly
|
||||||
- Different backend options (HNSW, DiskANN)
|
- Different backend options (HNSW, DiskANN)
|
||||||
- Different LLM configuration options
|
- Different LLM configuration options (parametrized for both backends)
|
||||||
|
- **All main README examples are tested with both HNSW and DiskANN backends using pytest parametrization**
|
||||||
|
|
||||||
### `test_basic.py`
|
### `test_basic.py`
|
||||||
Basic functionality tests that verify:
|
Basic functionality tests that verify:
|
||||||
@@ -25,12 +26,22 @@ Tests the document RAG example functionality:
|
|||||||
- Tests error handling with invalid parameters
|
- Tests error handling with invalid parameters
|
||||||
- Verifies that normalized embeddings are detected and cosine distance is used
|
- Verifies that normalized embeddings are detected and cosine distance is used
|
||||||
|
|
||||||
|
### `test_diskann_partition.py`
|
||||||
|
Tests DiskANN graph partitioning functionality:
|
||||||
|
- Tests DiskANN index building without partitioning (baseline)
|
||||||
|
- Tests automatic graph partitioning with `is_recompute=True`
|
||||||
|
- Verifies that partition files are created and large files are cleaned up for storage saving
|
||||||
|
- Tests search functionality with partitioned indices
|
||||||
|
- Validates medoid and max_base_norm file generation and usage
|
||||||
|
- Includes performance comparison between DiskANN (with partition) and HNSW
|
||||||
|
- **Note**: These tests are skipped in CI due to hardware requirements and computation time
|
||||||
|
|
||||||
## Running Tests
|
## Running Tests
|
||||||
|
|
||||||
### Install test dependencies:
|
### Install test dependencies:
|
||||||
```bash
|
```bash
|
||||||
# Using extras
|
# Using uv dependency groups (tools only)
|
||||||
uv pip install -e ".[test]"
|
uv sync --only-group test
|
||||||
```
|
```
|
||||||
|
|
||||||
### Run all tests:
|
### Run all tests:
|
||||||
@@ -54,15 +65,23 @@ pytest tests/ -m "not openai"
|
|||||||
|
|
||||||
# Skip slow tests
|
# Skip slow tests
|
||||||
pytest tests/ -m "not slow"
|
pytest tests/ -m "not slow"
|
||||||
|
|
||||||
|
# Run DiskANN partition tests (requires local machine, not CI)
|
||||||
|
pytest tests/test_diskann_partition.py
|
||||||
```
|
```
|
||||||
|
|
||||||
### Run with specific backend:
|
### Run with specific backend:
|
||||||
```bash
|
```bash
|
||||||
# Test only HNSW backend
|
# Test only HNSW backend
|
||||||
pytest tests/test_basic.py::test_backend_basic[hnsw]
|
pytest tests/test_basic.py::test_backend_basic[hnsw]
|
||||||
|
pytest tests/test_readme_examples.py::test_readme_basic_example[hnsw]
|
||||||
|
|
||||||
# Test only DiskANN backend
|
# Test only DiskANN backend
|
||||||
pytest tests/test_basic.py::test_backend_basic[diskann]
|
pytest tests/test_basic.py::test_backend_basic[diskann]
|
||||||
|
pytest tests/test_readme_examples.py::test_readme_basic_example[diskann]
|
||||||
|
|
||||||
|
# All DiskANN tests (parametrized + specialized partition tests)
|
||||||
|
pytest tests/ -k diskann
|
||||||
```
|
```
|
||||||
|
|
||||||
## CI/CD Integration
|
## CI/CD Integration
|
||||||
|
|||||||
397
tests/test_astchunk_integration.py
Normal file
397
tests/test_astchunk_integration.py
Normal file
@@ -0,0 +1,397 @@
|
|||||||
|
"""
|
||||||
|
Test suite for astchunk integration with LEANN.
|
||||||
|
Tests AST-aware chunking functionality, language detection, and fallback mechanisms.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Add apps directory to path for imports
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent / "apps"))
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from chunking import (
|
||||||
|
create_ast_chunks,
|
||||||
|
create_text_chunks,
|
||||||
|
create_traditional_chunks,
|
||||||
|
detect_code_files,
|
||||||
|
get_language_from_extension,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MockDocument:
|
||||||
|
"""Mock LlamaIndex Document for testing."""
|
||||||
|
|
||||||
|
def __init__(self, content: str, file_path: str = "", metadata: Optional[dict] = None):
|
||||||
|
self.content = content
|
||||||
|
self.metadata = metadata or {}
|
||||||
|
if file_path:
|
||||||
|
self.metadata["file_path"] = file_path
|
||||||
|
|
||||||
|
def get_content(self) -> str:
|
||||||
|
return self.content
|
||||||
|
|
||||||
|
|
||||||
|
class TestCodeFileDetection:
|
||||||
|
"""Test code file detection and language mapping."""
|
||||||
|
|
||||||
|
def test_detect_code_files_python(self):
|
||||||
|
"""Test detection of Python files."""
|
||||||
|
docs = [
|
||||||
|
MockDocument("print('hello')", "/path/to/file.py"),
|
||||||
|
MockDocument("This is text", "/path/to/file.txt"),
|
||||||
|
]
|
||||||
|
|
||||||
|
code_docs, text_docs = detect_code_files(docs)
|
||||||
|
|
||||||
|
assert len(code_docs) == 1
|
||||||
|
assert len(text_docs) == 1
|
||||||
|
assert code_docs[0].metadata["language"] == "python"
|
||||||
|
assert code_docs[0].metadata["is_code"] is True
|
||||||
|
assert text_docs[0].metadata["is_code"] is False
|
||||||
|
|
||||||
|
def test_detect_code_files_multiple_languages(self):
|
||||||
|
"""Test detection of multiple programming languages."""
|
||||||
|
docs = [
|
||||||
|
MockDocument("def func():", "/path/to/script.py"),
|
||||||
|
MockDocument("public class Test {}", "/path/to/Test.java"),
|
||||||
|
MockDocument("interface ITest {}", "/path/to/test.ts"),
|
||||||
|
MockDocument("using System;", "/path/to/Program.cs"),
|
||||||
|
MockDocument("Regular text content", "/path/to/document.txt"),
|
||||||
|
]
|
||||||
|
|
||||||
|
code_docs, text_docs = detect_code_files(docs)
|
||||||
|
|
||||||
|
assert len(code_docs) == 4
|
||||||
|
assert len(text_docs) == 1
|
||||||
|
|
||||||
|
languages = [doc.metadata["language"] for doc in code_docs]
|
||||||
|
assert "python" in languages
|
||||||
|
assert "java" in languages
|
||||||
|
assert "typescript" in languages
|
||||||
|
assert "csharp" in languages
|
||||||
|
|
||||||
|
def test_detect_code_files_no_file_path(self):
|
||||||
|
"""Test handling of documents without file paths."""
|
||||||
|
docs = [
|
||||||
|
MockDocument("some content"),
|
||||||
|
MockDocument("other content", metadata={"some_key": "value"}),
|
||||||
|
]
|
||||||
|
|
||||||
|
code_docs, text_docs = detect_code_files(docs)
|
||||||
|
|
||||||
|
assert len(code_docs) == 0
|
||||||
|
assert len(text_docs) == 2
|
||||||
|
for doc in text_docs:
|
||||||
|
assert doc.metadata["is_code"] is False
|
||||||
|
|
||||||
|
def test_get_language_from_extension(self):
|
||||||
|
"""Test language detection from file extensions."""
|
||||||
|
assert get_language_from_extension("test.py") == "python"
|
||||||
|
assert get_language_from_extension("Test.java") == "java"
|
||||||
|
assert get_language_from_extension("component.tsx") == "typescript"
|
||||||
|
assert get_language_from_extension("Program.cs") == "csharp"
|
||||||
|
assert get_language_from_extension("document.txt") is None
|
||||||
|
assert get_language_from_extension("") is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestChunkingFunctions:
|
||||||
|
"""Test various chunking functionality."""
|
||||||
|
|
||||||
|
def test_create_traditional_chunks(self):
|
||||||
|
"""Test traditional text chunking."""
|
||||||
|
docs = [
|
||||||
|
MockDocument(
|
||||||
|
"This is a test document. It has multiple sentences. We want to test chunking."
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
chunks = create_traditional_chunks(docs, chunk_size=50, chunk_overlap=10)
|
||||||
|
|
||||||
|
assert len(chunks) > 0
|
||||||
|
assert all(isinstance(chunk, str) for chunk in chunks)
|
||||||
|
assert all(len(chunk.strip()) > 0 for chunk in chunks)
|
||||||
|
|
||||||
|
def test_create_traditional_chunks_empty_docs(self):
|
||||||
|
"""Test traditional chunking with empty documents."""
|
||||||
|
chunks = create_traditional_chunks([], chunk_size=50, chunk_overlap=10)
|
||||||
|
assert chunks == []
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
os.environ.get("CI") == "true",
|
||||||
|
reason="Skip astchunk tests in CI - dependency may not be available",
|
||||||
|
)
|
||||||
|
def test_create_ast_chunks_with_astchunk_available(self):
|
||||||
|
"""Test AST chunking when astchunk is available."""
|
||||||
|
python_code = '''
|
||||||
|
def hello_world():
|
||||||
|
"""Print hello world message."""
|
||||||
|
print("Hello, World!")
|
||||||
|
|
||||||
|
def add_numbers(a, b):
|
||||||
|
"""Add two numbers and return the result."""
|
||||||
|
return a + b
|
||||||
|
|
||||||
|
class Calculator:
|
||||||
|
"""A simple calculator class."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.history = []
|
||||||
|
|
||||||
|
def add(self, a, b):
|
||||||
|
result = a + b
|
||||||
|
self.history.append(f"{a} + {b} = {result}")
|
||||||
|
return result
|
||||||
|
'''
|
||||||
|
|
||||||
|
docs = [MockDocument(python_code, "/test/calculator.py", {"language": "python"})]
|
||||||
|
|
||||||
|
try:
|
||||||
|
chunks = create_ast_chunks(docs, max_chunk_size=200, chunk_overlap=50)
|
||||||
|
|
||||||
|
# Should have multiple chunks due to different functions/classes
|
||||||
|
assert len(chunks) > 0
|
||||||
|
assert all(isinstance(chunk, str) for chunk in chunks)
|
||||||
|
assert all(len(chunk.strip()) > 0 for chunk in chunks)
|
||||||
|
|
||||||
|
# Check that code structure is somewhat preserved
|
||||||
|
combined_content = " ".join(chunks)
|
||||||
|
assert "def hello_world" in combined_content
|
||||||
|
assert "class Calculator" in combined_content
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
# astchunk not available, should fall back to traditional chunking
|
||||||
|
chunks = create_ast_chunks(docs, max_chunk_size=200, chunk_overlap=50)
|
||||||
|
assert len(chunks) > 0 # Should still get chunks from fallback
|
||||||
|
|
||||||
|
def test_create_ast_chunks_fallback_to_traditional(self):
|
||||||
|
"""Test AST chunking falls back to traditional when astchunk is not available."""
|
||||||
|
docs = [MockDocument("def test(): pass", "/test/script.py", {"language": "python"})]
|
||||||
|
|
||||||
|
# Mock astchunk import to fail
|
||||||
|
with patch("chunking.create_ast_chunks"):
|
||||||
|
# First call (actual test) should import astchunk and potentially fail
|
||||||
|
# Let's call the actual function to test the import error handling
|
||||||
|
chunks = create_ast_chunks(docs)
|
||||||
|
|
||||||
|
# Should return some chunks (either from astchunk or fallback)
|
||||||
|
assert isinstance(chunks, list)
|
||||||
|
|
||||||
|
def test_create_text_chunks_traditional_mode(self):
|
||||||
|
"""Test text chunking in traditional mode."""
|
||||||
|
docs = [
|
||||||
|
MockDocument("def test(): pass", "/test/script.py"),
|
||||||
|
MockDocument("This is regular text.", "/test/doc.txt"),
|
||||||
|
]
|
||||||
|
|
||||||
|
chunks = create_text_chunks(docs, use_ast_chunking=False, chunk_size=50, chunk_overlap=10)
|
||||||
|
|
||||||
|
assert len(chunks) > 0
|
||||||
|
assert all(isinstance(chunk, str) for chunk in chunks)
|
||||||
|
|
||||||
|
def test_create_text_chunks_ast_mode(self):
|
||||||
|
"""Test text chunking in AST mode."""
|
||||||
|
docs = [
|
||||||
|
MockDocument("def test(): pass", "/test/script.py"),
|
||||||
|
MockDocument("This is regular text.", "/test/doc.txt"),
|
||||||
|
]
|
||||||
|
|
||||||
|
chunks = create_text_chunks(
|
||||||
|
docs,
|
||||||
|
use_ast_chunking=True,
|
||||||
|
ast_chunk_size=100,
|
||||||
|
ast_chunk_overlap=20,
|
||||||
|
chunk_size=50,
|
||||||
|
chunk_overlap=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(chunks) > 0
|
||||||
|
assert all(isinstance(chunk, str) for chunk in chunks)
|
||||||
|
|
||||||
|
def test_create_text_chunks_custom_extensions(self):
|
||||||
|
"""Test text chunking with custom code file extensions."""
|
||||||
|
docs = [
|
||||||
|
MockDocument("function test() {}", "/test/script.js"), # Not in default extensions
|
||||||
|
MockDocument("Regular text", "/test/doc.txt"),
|
||||||
|
]
|
||||||
|
|
||||||
|
# First without custom extensions - should treat .js as text
|
||||||
|
chunks_without = create_text_chunks(docs, use_ast_chunking=True, code_file_extensions=None)
|
||||||
|
|
||||||
|
# Then with custom extensions - should treat .js as code
|
||||||
|
chunks_with = create_text_chunks(
|
||||||
|
docs, use_ast_chunking=True, code_file_extensions=[".js", ".jsx"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Both should return chunks
|
||||||
|
assert len(chunks_without) > 0
|
||||||
|
assert len(chunks_with) > 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestIntegrationWithDocumentRAG:
|
||||||
|
"""Integration tests with the document RAG system."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_code_dir(self):
|
||||||
|
"""Create a temporary directory with sample code files."""
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
temp_path = Path(temp_dir)
|
||||||
|
|
||||||
|
# Create sample Python file
|
||||||
|
python_file = temp_path / "example.py"
|
||||||
|
python_file.write_text('''
|
||||||
|
def fibonacci(n):
|
||||||
|
"""Calculate fibonacci number."""
|
||||||
|
if n <= 1:
|
||||||
|
return n
|
||||||
|
return fibonacci(n-1) + fibonacci(n-2)
|
||||||
|
|
||||||
|
class MathUtils:
|
||||||
|
@staticmethod
|
||||||
|
def factorial(n):
|
||||||
|
if n <= 1:
|
||||||
|
return 1
|
||||||
|
return n * MathUtils.factorial(n-1)
|
||||||
|
''')
|
||||||
|
|
||||||
|
# Create sample text file
|
||||||
|
text_file = temp_path / "readme.txt"
|
||||||
|
text_file.write_text("This is a sample text file for testing purposes.")
|
||||||
|
|
||||||
|
yield temp_path
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
os.environ.get("CI") == "true",
|
||||||
|
reason="Skip integration tests in CI to avoid dependency issues",
|
||||||
|
)
|
||||||
|
def test_document_rag_with_ast_chunking(self, temp_code_dir):
|
||||||
|
"""Test document RAG with AST chunking enabled."""
|
||||||
|
with tempfile.TemporaryDirectory() as index_dir:
|
||||||
|
cmd = [
|
||||||
|
sys.executable,
|
||||||
|
"apps/document_rag.py",
|
||||||
|
"--llm",
|
||||||
|
"simulated",
|
||||||
|
"--embedding-model",
|
||||||
|
"facebook/contriever",
|
||||||
|
"--embedding-mode",
|
||||||
|
"sentence-transformers",
|
||||||
|
"--index-dir",
|
||||||
|
index_dir,
|
||||||
|
"--data-dir",
|
||||||
|
str(temp_code_dir),
|
||||||
|
"--enable-code-chunking",
|
||||||
|
"--query",
|
||||||
|
"How does the fibonacci function work?",
|
||||||
|
]
|
||||||
|
|
||||||
|
env = os.environ.copy()
|
||||||
|
env["HF_HUB_DISABLE_SYMLINKS"] = "1"
|
||||||
|
env["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=300, # 5 minutes
|
||||||
|
env=env,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should succeed even if astchunk is not available (fallback)
|
||||||
|
assert result.returncode == 0, f"Command failed: {result.stderr}"
|
||||||
|
|
||||||
|
output = result.stdout + result.stderr
|
||||||
|
assert "Index saved to" in output or "Using existing index" in output
|
||||||
|
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
pytest.skip("Test timed out - likely due to model download in CI")
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
os.environ.get("CI") == "true",
|
||||||
|
reason="Skip integration tests in CI to avoid dependency issues",
|
||||||
|
)
|
||||||
|
def test_code_rag_application(self, temp_code_dir):
|
||||||
|
"""Test the specialized code RAG application."""
|
||||||
|
with tempfile.TemporaryDirectory() as index_dir:
|
||||||
|
cmd = [
|
||||||
|
sys.executable,
|
||||||
|
"apps/code_rag.py",
|
||||||
|
"--llm",
|
||||||
|
"simulated",
|
||||||
|
"--embedding-model",
|
||||||
|
"facebook/contriever",
|
||||||
|
"--index-dir",
|
||||||
|
index_dir,
|
||||||
|
"--repo-dir",
|
||||||
|
str(temp_code_dir),
|
||||||
|
"--query",
|
||||||
|
"What classes are defined in this code?",
|
||||||
|
]
|
||||||
|
|
||||||
|
env = os.environ.copy()
|
||||||
|
env["HF_HUB_DISABLE_SYMLINKS"] = "1"
|
||||||
|
env["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = subprocess.run(cmd, capture_output=True, text=True, timeout=300, env=env)
|
||||||
|
|
||||||
|
# Should succeed
|
||||||
|
assert result.returncode == 0, f"Command failed: {result.stderr}"
|
||||||
|
|
||||||
|
output = result.stdout + result.stderr
|
||||||
|
assert "Using AST-aware chunking" in output or "traditional chunking" in output
|
||||||
|
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
pytest.skip("Test timed out - likely due to model download in CI")
|
||||||
|
|
||||||
|
|
||||||
|
class TestErrorHandling:
|
||||||
|
"""Test error handling and edge cases."""
|
||||||
|
|
||||||
|
def test_text_chunking_empty_documents(self):
|
||||||
|
"""Test text chunking with empty document list."""
|
||||||
|
chunks = create_text_chunks([])
|
||||||
|
assert chunks == []
|
||||||
|
|
||||||
|
def test_text_chunking_invalid_parameters(self):
|
||||||
|
"""Test text chunking with invalid parameters."""
|
||||||
|
docs = [MockDocument("test content")]
|
||||||
|
|
||||||
|
# Should handle negative chunk sizes gracefully
|
||||||
|
chunks = create_text_chunks(
|
||||||
|
docs, chunk_size=0, chunk_overlap=0, ast_chunk_size=0, ast_chunk_overlap=0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should still return some result
|
||||||
|
assert isinstance(chunks, list)
|
||||||
|
|
||||||
|
def test_create_ast_chunks_no_language(self):
|
||||||
|
"""Test AST chunking with documents missing language metadata."""
|
||||||
|
docs = [MockDocument("def test(): pass", "/test/script.py")] # No language set
|
||||||
|
|
||||||
|
chunks = create_ast_chunks(docs)
|
||||||
|
|
||||||
|
# Should fall back to traditional chunking
|
||||||
|
assert isinstance(chunks, list)
|
||||||
|
assert len(chunks) >= 0 # May be empty if fallback also fails
|
||||||
|
|
||||||
|
def test_create_ast_chunks_empty_content(self):
|
||||||
|
"""Test AST chunking with empty content."""
|
||||||
|
docs = [MockDocument("", "/test/script.py", {"language": "python"})]
|
||||||
|
|
||||||
|
chunks = create_ast_chunks(docs)
|
||||||
|
|
||||||
|
# Should handle empty content gracefully
|
||||||
|
assert isinstance(chunks, list)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "-v"])
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user