Compare commits
42 Commits
v0.2.8
...
feat/add-g
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
679848a3b7 | ||
|
|
da811061f4 | ||
|
|
e93c0dec6f | ||
|
|
c5a29f849a | ||
|
|
3b8dc6368e | ||
|
|
e309f292de | ||
|
|
0d9f92ea0f | ||
|
|
b0b353d279 | ||
|
|
4dffdfedbe | ||
|
|
d41e467df9 | ||
|
|
4ca0489cb1 | ||
|
|
e83a671918 | ||
|
|
4e5b73ce7b | ||
|
|
31b4973141 | ||
|
|
dde2221513 | ||
|
|
6d11e86e71 | ||
|
|
13bb561aad | ||
|
|
0174ba5571 | ||
|
|
03af82d695 | ||
|
|
738f1dbab8 | ||
|
|
37d990d51c | ||
|
|
a6f07a54f1 | ||
|
|
46905e0687 | ||
|
|
838ade231e | ||
|
|
da6540decd | ||
|
|
39e18a7c11 | ||
|
|
6bde28584b | ||
|
|
f62632c41f | ||
|
|
27708243ca | ||
|
|
9a1e4652ca | ||
|
|
14e84d9e2d | ||
|
|
2dcfca19ff | ||
|
|
bee2167ee3 | ||
|
|
ef980d70b3 | ||
|
|
db3c63c441 | ||
|
|
00eeadb9dd | ||
|
|
42c8370709 | ||
|
|
fafdf8fcbe | ||
|
|
21f7d8e031 | ||
|
|
46565b9249 | ||
|
|
3dad76126a | ||
|
|
18e28bda32 |
1
.gitattributes
vendored
1
.gitattributes
vendored
@@ -1 +0,0 @@
|
|||||||
paper_plot/data/big_graph_degree_data.npz filter=lfs diff=lfs merge=lfs -text
|
|
||||||
50
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
Normal file
50
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
name: Bug Report
|
||||||
|
description: Report a bug in LEANN
|
||||||
|
labels: ["bug"]
|
||||||
|
|
||||||
|
body:
|
||||||
|
- type: textarea
|
||||||
|
id: description
|
||||||
|
attributes:
|
||||||
|
label: What happened?
|
||||||
|
description: A clear description of the bug
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: reproduce
|
||||||
|
attributes:
|
||||||
|
label: How to reproduce
|
||||||
|
placeholder: |
|
||||||
|
1. Install with...
|
||||||
|
2. Run command...
|
||||||
|
3. See error
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: error
|
||||||
|
attributes:
|
||||||
|
label: Error message
|
||||||
|
description: Paste any error messages
|
||||||
|
render: shell
|
||||||
|
|
||||||
|
- type: input
|
||||||
|
id: version
|
||||||
|
attributes:
|
||||||
|
label: LEANN Version
|
||||||
|
placeholder: "0.1.0"
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: dropdown
|
||||||
|
id: os
|
||||||
|
attributes:
|
||||||
|
label: Operating System
|
||||||
|
options:
|
||||||
|
- macOS
|
||||||
|
- Linux
|
||||||
|
- Windows
|
||||||
|
- Docker
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
blank_issues_enabled: true
|
||||||
|
contact_links:
|
||||||
|
- name: Documentation
|
||||||
|
url: https://github.com/LEANN-RAG/LEANN-RAG/tree/main/docs
|
||||||
|
about: Read the docs first
|
||||||
|
- name: Discussions
|
||||||
|
url: https://github.com/LEANN-RAG/LEANN-RAG/discussions
|
||||||
|
about: Ask questions and share ideas
|
||||||
27
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
27
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
name: Feature Request
|
||||||
|
description: Suggest a new feature for LEANN
|
||||||
|
labels: ["enhancement"]
|
||||||
|
|
||||||
|
body:
|
||||||
|
- type: textarea
|
||||||
|
id: problem
|
||||||
|
attributes:
|
||||||
|
label: What problem does this solve?
|
||||||
|
description: Describe the problem or need
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: solution
|
||||||
|
attributes:
|
||||||
|
label: Proposed solution
|
||||||
|
description: How would you like this to work?
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: example
|
||||||
|
attributes:
|
||||||
|
label: Example usage
|
||||||
|
description: Show how the API might look
|
||||||
|
render: python
|
||||||
13
.github/pull_request_template.md
vendored
Normal file
13
.github/pull_request_template.md
vendored
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
## What does this PR do?
|
||||||
|
|
||||||
|
<!-- Brief description of your changes -->
|
||||||
|
|
||||||
|
## Related Issues
|
||||||
|
|
||||||
|
Fixes #
|
||||||
|
|
||||||
|
## Checklist
|
||||||
|
|
||||||
|
- [ ] Tests pass (`uv run pytest`)
|
||||||
|
- [ ] Code formatted (`ruff format` and `ruff check`)
|
||||||
|
- [ ] Pre-commit hooks pass (`pre-commit run --all-files`)
|
||||||
1
.github/workflows/build-and-publish.yml
vendored
1
.github/workflows/build-and-publish.yml
vendored
@@ -5,6 +5,7 @@ on:
|
|||||||
branches: [ main ]
|
branches: [ main ]
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [ main ]
|
branches: [ main ]
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build:
|
build:
|
||||||
|
|||||||
174
.github/workflows/build-reusable.yml
vendored
174
.github/workflows/build-reusable.yml
vendored
@@ -54,6 +54,17 @@ jobs:
|
|||||||
python: '3.12'
|
python: '3.12'
|
||||||
- os: ubuntu-22.04
|
- os: ubuntu-22.04
|
||||||
python: '3.13'
|
python: '3.13'
|
||||||
|
# ARM64 Linux builds
|
||||||
|
- os: ubuntu-24.04-arm
|
||||||
|
python: '3.9'
|
||||||
|
- os: ubuntu-24.04-arm
|
||||||
|
python: '3.10'
|
||||||
|
- os: ubuntu-24.04-arm
|
||||||
|
python: '3.11'
|
||||||
|
- os: ubuntu-24.04-arm
|
||||||
|
python: '3.12'
|
||||||
|
- os: ubuntu-24.04-arm
|
||||||
|
python: '3.13'
|
||||||
- os: macos-14
|
- os: macos-14
|
||||||
python: '3.9'
|
python: '3.9'
|
||||||
- os: macos-14
|
- os: macos-14
|
||||||
@@ -64,6 +75,16 @@ jobs:
|
|||||||
python: '3.12'
|
python: '3.12'
|
||||||
- os: macos-14
|
- os: macos-14
|
||||||
python: '3.13'
|
python: '3.13'
|
||||||
|
- os: macos-15
|
||||||
|
python: '3.9'
|
||||||
|
- os: macos-15
|
||||||
|
python: '3.10'
|
||||||
|
- os: macos-15
|
||||||
|
python: '3.11'
|
||||||
|
- os: macos-15
|
||||||
|
python: '3.12'
|
||||||
|
- os: macos-15
|
||||||
|
python: '3.13'
|
||||||
- os: macos-13
|
- os: macos-13
|
||||||
python: '3.9'
|
python: '3.9'
|
||||||
- os: macos-13
|
- os: macos-13
|
||||||
@@ -77,7 +98,7 @@ jobs:
|
|||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v5
|
||||||
with:
|
with:
|
||||||
ref: ${{ inputs.ref }}
|
ref: ${{ inputs.ref }}
|
||||||
submodules: recursive
|
submodules: recursive
|
||||||
@@ -88,21 +109,56 @@ jobs:
|
|||||||
python-version: ${{ matrix.python }}
|
python-version: ${{ matrix.python }}
|
||||||
|
|
||||||
- name: Install uv
|
- name: Install uv
|
||||||
uses: astral-sh/setup-uv@v4
|
uses: astral-sh/setup-uv@v6
|
||||||
|
|
||||||
- name: Install system dependencies (Ubuntu)
|
- name: Install system dependencies (Ubuntu)
|
||||||
if: runner.os == 'Linux'
|
if: runner.os == 'Linux'
|
||||||
run: |
|
run: |
|
||||||
sudo apt-get update
|
sudo apt-get update
|
||||||
sudo apt-get install -y libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
|
sudo apt-get install -y libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
|
||||||
pkg-config libopenblas-dev patchelf libabsl-dev libaio-dev libprotobuf-dev
|
pkg-config libabsl-dev libaio-dev libprotobuf-dev \
|
||||||
|
patchelf
|
||||||
|
|
||||||
# Install Intel MKL for DiskANN
|
# Debug: Show system information
|
||||||
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
|
echo "🔍 System Information:"
|
||||||
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
|
echo "Architecture: $(uname -m)"
|
||||||
source /opt/intel/oneapi/setvars.sh
|
echo "OS: $(uname -a)"
|
||||||
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
|
echo "CPU info: $(lscpu | head -5)"
|
||||||
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/mkl/latest/lib/intel64:$LD_LIBRARY_PATH" >> $GITHUB_ENV
|
|
||||||
|
# Install math library based on architecture
|
||||||
|
ARCH=$(uname -m)
|
||||||
|
echo "🔍 Setting up math library for architecture: $ARCH"
|
||||||
|
|
||||||
|
if [[ "$ARCH" == "x86_64" ]]; then
|
||||||
|
# Install Intel MKL for DiskANN on x86_64
|
||||||
|
echo "📦 Installing Intel MKL for x86_64..."
|
||||||
|
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
|
||||||
|
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
|
||||||
|
source /opt/intel/oneapi/setvars.sh
|
||||||
|
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
|
||||||
|
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/compiler/latest/linux/compiler/lib/intel64_lin" >> $GITHUB_ENV
|
||||||
|
echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/mkl/latest/lib/intel64" >> $GITHUB_ENV
|
||||||
|
echo "✅ Intel MKL installed for x86_64"
|
||||||
|
|
||||||
|
# Debug: Check MKL installation
|
||||||
|
echo "🔍 MKL Installation Check:"
|
||||||
|
ls -la /opt/intel/oneapi/mkl/latest/ || echo "MKL directory not found"
|
||||||
|
ls -la /opt/intel/oneapi/mkl/latest/lib/ || echo "MKL lib directory not found"
|
||||||
|
|
||||||
|
elif [[ "$ARCH" == "aarch64" ]]; then
|
||||||
|
# Use OpenBLAS for ARM64 (MKL installer not compatible with ARM64)
|
||||||
|
echo "📦 Installing OpenBLAS for ARM64..."
|
||||||
|
sudo apt-get install -y libopenblas-dev liblapack-dev liblapacke-dev
|
||||||
|
echo "✅ OpenBLAS installed for ARM64"
|
||||||
|
|
||||||
|
# Debug: Check OpenBLAS installation
|
||||||
|
echo "🔍 OpenBLAS Installation Check:"
|
||||||
|
dpkg -l | grep openblas || echo "OpenBLAS package not found"
|
||||||
|
ls -la /usr/lib/aarch64-linux-gnu/openblas/ || echo "OpenBLAS directory not found"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Debug: Show final library paths
|
||||||
|
echo "🔍 Final LD_LIBRARY_PATH: $LD_LIBRARY_PATH"
|
||||||
|
|
||||||
- name: Install system dependencies (macOS)
|
- name: Install system dependencies (macOS)
|
||||||
if: runner.os == 'macOS'
|
if: runner.os == 'macOS'
|
||||||
@@ -147,7 +203,14 @@ jobs:
|
|||||||
# Use system clang for better compatibility
|
# Use system clang for better compatibility
|
||||||
export CC=clang
|
export CC=clang
|
||||||
export CXX=clang++
|
export CXX=clang++
|
||||||
export MACOSX_DEPLOYMENT_TARGET=11.0
|
# Homebrew libraries on each macOS version require matching minimum version
|
||||||
|
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=13.0
|
||||||
|
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=14.0
|
||||||
|
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||||
|
fi
|
||||||
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
||||||
else
|
else
|
||||||
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
||||||
@@ -161,7 +224,14 @@ jobs:
|
|||||||
export CC=clang
|
export CC=clang
|
||||||
export CXX=clang++
|
export CXX=clang++
|
||||||
# DiskANN requires macOS 13.3+ for sgesdd_ LAPACK function
|
# DiskANN requires macOS 13.3+ for sgesdd_ LAPACK function
|
||||||
export MACOSX_DEPLOYMENT_TARGET=13.3
|
# But Homebrew libraries on each macOS version require matching minimum version
|
||||||
|
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=13.3
|
||||||
|
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=14.0
|
||||||
|
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||||
|
fi
|
||||||
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
||||||
else
|
else
|
||||||
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
||||||
@@ -197,10 +267,24 @@ jobs:
|
|||||||
- name: Repair wheels (macOS)
|
- name: Repair wheels (macOS)
|
||||||
if: runner.os == 'macOS'
|
if: runner.os == 'macOS'
|
||||||
run: |
|
run: |
|
||||||
|
# Determine deployment target based on runner OS
|
||||||
|
# Must match the Homebrew libraries for each macOS version
|
||||||
|
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
|
||||||
|
HNSW_TARGET="13.0"
|
||||||
|
DISKANN_TARGET="13.3"
|
||||||
|
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
|
||||||
|
HNSW_TARGET="14.0"
|
||||||
|
DISKANN_TARGET="14.0"
|
||||||
|
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
|
||||||
|
HNSW_TARGET="15.0"
|
||||||
|
DISKANN_TARGET="15.0"
|
||||||
|
fi
|
||||||
|
|
||||||
# Repair HNSW wheel
|
# Repair HNSW wheel
|
||||||
cd packages/leann-backend-hnsw
|
cd packages/leann-backend-hnsw
|
||||||
if [ -d dist ]; then
|
if [ -d dist ]; then
|
||||||
delocate-wheel -w dist_repaired -v dist/*.whl
|
export MACOSX_DEPLOYMENT_TARGET=$HNSW_TARGET
|
||||||
|
delocate-wheel -w dist_repaired -v --require-target-macos-version $HNSW_TARGET dist/*.whl
|
||||||
rm -rf dist
|
rm -rf dist
|
||||||
mv dist_repaired dist
|
mv dist_repaired dist
|
||||||
fi
|
fi
|
||||||
@@ -209,7 +293,8 @@ jobs:
|
|||||||
# Repair DiskANN wheel
|
# Repair DiskANN wheel
|
||||||
cd packages/leann-backend-diskann
|
cd packages/leann-backend-diskann
|
||||||
if [ -d dist ]; then
|
if [ -d dist ]; then
|
||||||
delocate-wheel -w dist_repaired -v dist/*.whl
|
export MACOSX_DEPLOYMENT_TARGET=$DISKANN_TARGET
|
||||||
|
delocate-wheel -w dist_repaired -v --require-target-macos-version $DISKANN_TARGET dist/*.whl
|
||||||
rm -rf dist
|
rm -rf dist
|
||||||
mv dist_repaired dist
|
mv dist_repaired dist
|
||||||
fi
|
fi
|
||||||
@@ -238,19 +323,16 @@ jobs:
|
|||||||
|
|
||||||
- name: Run tests with pytest
|
- name: Run tests with pytest
|
||||||
env:
|
env:
|
||||||
CI: true # Mark as CI environment to skip memory-intensive tests
|
CI: true
|
||||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||||
HF_HUB_DISABLE_SYMLINKS: 1
|
HF_HUB_DISABLE_SYMLINKS: 1
|
||||||
TOKENIZERS_PARALLELISM: false
|
TOKENIZERS_PARALLELISM: false
|
||||||
PYTORCH_ENABLE_MPS_FALLBACK: 0 # Disable MPS on macOS CI to avoid memory issues
|
PYTORCH_ENABLE_MPS_FALLBACK: 0
|
||||||
OMP_NUM_THREADS: 1 # Disable OpenMP parallelism to avoid libomp crashes
|
OMP_NUM_THREADS: 1
|
||||||
MKL_NUM_THREADS: 1 # Single thread for MKL operations
|
MKL_NUM_THREADS: 1
|
||||||
run: |
|
run: |
|
||||||
# Activate virtual environment
|
|
||||||
source .venv/bin/activate || source .venv/Scripts/activate
|
source .venv/bin/activate || source .venv/Scripts/activate
|
||||||
|
pytest tests/ -v --tb=short
|
||||||
# Run all tests
|
|
||||||
pytest tests/
|
|
||||||
|
|
||||||
- name: Run sanity checks (optional)
|
- name: Run sanity checks (optional)
|
||||||
run: |
|
run: |
|
||||||
@@ -268,3 +350,53 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
name: packages-${{ matrix.os }}-py${{ matrix.python }}
|
name: packages-${{ matrix.os }}-py${{ matrix.python }}
|
||||||
path: packages/*/dist/
|
path: packages/*/dist/
|
||||||
|
|
||||||
|
|
||||||
|
arch-smoke:
|
||||||
|
name: Arch Linux smoke test (install & import)
|
||||||
|
needs: build
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
container:
|
||||||
|
image: archlinux:latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Prepare system
|
||||||
|
run: |
|
||||||
|
pacman -Syu --noconfirm
|
||||||
|
pacman -S --noconfirm python python-pip gcc git zlib openssl
|
||||||
|
|
||||||
|
- name: Download ALL wheel artifacts from this run
|
||||||
|
uses: actions/download-artifact@v5
|
||||||
|
with:
|
||||||
|
# Don't specify name, download all artifacts
|
||||||
|
path: ./wheels
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v6
|
||||||
|
|
||||||
|
- name: Create virtual environment and install wheels
|
||||||
|
run: |
|
||||||
|
uv venv
|
||||||
|
source .venv/bin/activate || source .venv/Scripts/activate
|
||||||
|
uv pip install --find-links wheels leann-core
|
||||||
|
uv pip install --find-links wheels leann-backend-hnsw
|
||||||
|
uv pip install --find-links wheels leann-backend-diskann
|
||||||
|
uv pip install --find-links wheels leann
|
||||||
|
|
||||||
|
- name: Import & tiny runtime check
|
||||||
|
env:
|
||||||
|
OMP_NUM_THREADS: 1
|
||||||
|
MKL_NUM_THREADS: 1
|
||||||
|
run: |
|
||||||
|
source .venv/bin/activate || source .venv/Scripts/activate
|
||||||
|
python - <<'PY'
|
||||||
|
import leann
|
||||||
|
import leann_backend_hnsw as h
|
||||||
|
import leann_backend_diskann as d
|
||||||
|
from leann import LeannBuilder, LeannSearcher
|
||||||
|
b = LeannBuilder(backend_name="hnsw")
|
||||||
|
b.add_text("hello arch")
|
||||||
|
b.build_index("arch_demo.leann")
|
||||||
|
s = LeannSearcher("arch_demo.leann")
|
||||||
|
print("search:", s.search("hello", top_k=1))
|
||||||
|
PY
|
||||||
|
|||||||
2
.github/workflows/link-check.yml
vendored
2
.github/workflows/link-check.yml
vendored
@@ -14,6 +14,6 @@ jobs:
|
|||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- uses: lycheeverse/lychee-action@v2
|
- uses: lycheeverse/lychee-action@v2
|
||||||
with:
|
with:
|
||||||
args: --no-progress --insecure README.md docs/ apps/ examples/ benchmarks/
|
args: --no-progress --insecure --user-agent 'curl/7.68.0' README.md docs/ apps/ examples/ benchmarks/
|
||||||
env:
|
env:
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|||||||
9
.gitignore
vendored
9
.gitignore
vendored
@@ -18,9 +18,11 @@ demo/experiment_results/**/*.json
|
|||||||
*.eml
|
*.eml
|
||||||
*.emlx
|
*.emlx
|
||||||
*.json
|
*.json
|
||||||
|
!.vscode/*.json
|
||||||
*.sh
|
*.sh
|
||||||
*.txt
|
*.txt
|
||||||
!CMakeLists.txt
|
!CMakeLists.txt
|
||||||
|
!llms.txt
|
||||||
latency_breakdown*.json
|
latency_breakdown*.json
|
||||||
experiment_results/eval_results/diskann/*.json
|
experiment_results/eval_results/diskann/*.json
|
||||||
aws/
|
aws/
|
||||||
@@ -92,3 +94,10 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/
|
|||||||
batchtest.py
|
batchtest.py
|
||||||
tests/__pytest_cache__/
|
tests/__pytest_cache__/
|
||||||
tests/__pycache__/
|
tests/__pycache__/
|
||||||
|
paru-bin/
|
||||||
|
|
||||||
|
CLAUDE.md
|
||||||
|
CLAUDE.local.md
|
||||||
|
.claude/*.local.*
|
||||||
|
.claude/local/*
|
||||||
|
benchmarks/data/
|
||||||
|
|||||||
4
.gitmodules
vendored
4
.gitmodules
vendored
@@ -14,3 +14,7 @@
|
|||||||
[submodule "packages/leann-backend-hnsw/third_party/libzmq"]
|
[submodule "packages/leann-backend-hnsw/third_party/libzmq"]
|
||||||
path = packages/leann-backend-hnsw/third_party/libzmq
|
path = packages/leann-backend-hnsw/third_party/libzmq
|
||||||
url = https://github.com/zeromq/libzmq.git
|
url = https://github.com/zeromq/libzmq.git
|
||||||
|
[submodule "packages/astchunk-leann"]
|
||||||
|
path = packages/astchunk-leann
|
||||||
|
url = git@github.com:yichuan-w/astchunk-leann.git
|
||||||
|
branch = main
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
rev: v4.5.0
|
rev: v5.0.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: trailing-whitespace
|
- id: trailing-whitespace
|
||||||
- id: end-of-file-fixer
|
- id: end-of-file-fixer
|
||||||
@@ -10,7 +10,8 @@ repos:
|
|||||||
- id: debug-statements
|
- id: debug-statements
|
||||||
|
|
||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
rev: v0.2.1
|
rev: v0.12.7 # Fixed version to match pyproject.toml
|
||||||
hooks:
|
hooks:
|
||||||
- id: ruff
|
- id: ruff
|
||||||
|
args: [--fix, --exit-non-zero-on-fix]
|
||||||
- id: ruff-format
|
- id: ruff-format
|
||||||
|
|||||||
5
.vscode/extensions.json
vendored
Normal file
5
.vscode/extensions.json
vendored
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
{
|
||||||
|
"recommendations": [
|
||||||
|
"charliermarsh.ruff",
|
||||||
|
]
|
||||||
|
}
|
||||||
22
.vscode/settings.json
vendored
Normal file
22
.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
{
|
||||||
|
"python.defaultInterpreterPath": ".venv/bin/python",
|
||||||
|
"python.terminal.activateEnvironment": true,
|
||||||
|
"[python]": {
|
||||||
|
"editor.defaultFormatter": "charliermarsh.ruff",
|
||||||
|
"editor.formatOnSave": true,
|
||||||
|
"editor.codeActionsOnSave": {
|
||||||
|
"source.organizeImports": "explicit",
|
||||||
|
"source.fixAll": "explicit"
|
||||||
|
},
|
||||||
|
"editor.insertSpaces": true,
|
||||||
|
"editor.tabSize": 4
|
||||||
|
},
|
||||||
|
"ruff.enable": true,
|
||||||
|
"files.watcherExclude": {
|
||||||
|
"**/.venv/**": true,
|
||||||
|
"**/__pycache__/**": true,
|
||||||
|
"**/*.egg-info/**": true,
|
||||||
|
"**/build/**": true,
|
||||||
|
"**/dist/**": true
|
||||||
|
}
|
||||||
|
}
|
||||||
238
README.md
238
README.md
@@ -5,9 +5,11 @@
|
|||||||
<p align="center">
|
<p align="center">
|
||||||
<img src="https://img.shields.io/badge/Python-3.9%20%7C%203.10%20%7C%203.11%20%7C%203.12%20%7C%203.13-blue.svg" alt="Python Versions">
|
<img src="https://img.shields.io/badge/Python-3.9%20%7C%203.10%20%7C%203.11%20%7C%203.12%20%7C%203.13-blue.svg" alt="Python Versions">
|
||||||
<img src="https://github.com/yichuan-w/LEANN/actions/workflows/build-and-publish.yml/badge.svg" alt="CI Status">
|
<img src="https://github.com/yichuan-w/LEANN/actions/workflows/build-and-publish.yml/badge.svg" alt="CI Status">
|
||||||
<img src="https://img.shields.io/badge/Platform-Ubuntu%20%7C%20macOS%20(ARM64%2FIntel)-lightgrey" alt="Platform">
|
<img src="https://img.shields.io/badge/Platform-Ubuntu%20%26%20Arch%20%26%20WSL%20%7C%20macOS%20(ARM64%2FIntel)-lightgrey" alt="Platform">
|
||||||
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
|
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
|
||||||
<img src="https://img.shields.io/badge/MCP-Native%20Integration-blue" alt="MCP Integration">
|
<img src="https://img.shields.io/badge/MCP-Native%20Integration-blue" alt="MCP Integration">
|
||||||
|
<a href="https://join.slack.com/t/leann-e2u9779/shared_invite/zt-3ckd2f6w1-OX08~NN4gkWhh10PRVBj1Q"><img src="https://img.shields.io/badge/Slack-Join-4A154B?logo=slack&logoColor=white" alt="Join Slack">
|
||||||
|
<a href="assets/wechat_user_group.JPG" title="Join WeChat group"><img src="https://img.shields.io/badge/WeChat-Join-2DC100?logo=wechat&logoColor=white" alt="Join WeChat group"></a>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
<h2 align="center" tabindex="-1" class="heading-element" dir="auto">
|
<h2 align="center" tabindex="-1" class="heading-element" dir="auto">
|
||||||
@@ -31,7 +33,7 @@ LEANN achieves this through *graph-based selective recomputation* with *high-deg
|
|||||||
<img src="assets/effects.png" alt="LEANN vs Traditional Vector DB Storage Comparison" width="70%">
|
<img src="assets/effects.png" alt="LEANN vs Traditional Vector DB Storage Comparison" width="70%">
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
> **The numbers speak for themselves:** Index 60 million text chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#storage-comparison)
|
> **The numbers speak for themselves:** Index 60 million text chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#-storage-comparison)
|
||||||
|
|
||||||
|
|
||||||
🔒 **Privacy:** Your data never leaves your laptop. No OpenAI, no cloud, no "terms of service".
|
🔒 **Privacy:** Your data never leaves your laptop. No OpenAI, no cloud, no "terms of service".
|
||||||
@@ -70,6 +72,8 @@ uv venv
|
|||||||
source .venv/bin/activate
|
source .venv/bin/activate
|
||||||
uv pip install leann
|
uv pip install leann
|
||||||
```
|
```
|
||||||
|
<!--
|
||||||
|
> Low-resource? See “Low-resource setups” in the [Configuration Guide](docs/configuration-guide.md#low-resource-setups). -->
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary>
|
<summary>
|
||||||
@@ -85,15 +89,60 @@ git submodule update --init --recursive
|
|||||||
```
|
```
|
||||||
|
|
||||||
**macOS:**
|
**macOS:**
|
||||||
|
|
||||||
|
Note: DiskANN requires MacOS 13.3 or later.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
brew install llvm libomp boost protobuf zeromq pkgconf
|
brew install libomp boost protobuf zeromq pkgconf
|
||||||
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync
|
uv sync --extra diskann
|
||||||
```
|
```
|
||||||
|
|
||||||
**Linux:**
|
**Linux (Ubuntu/Debian):**
|
||||||
|
|
||||||
|
Note: On Ubuntu 20.04, you may need to build a newer Abseil and pin Protobuf (e.g., v3.20.x) for building DiskANN. See [Issue #30](https://github.com/yichuan-w/LEANN/issues/30) for a step-by-step note.
|
||||||
|
|
||||||
|
You can manually install [Intel oneAPI MKL](https://www.intel.com/content/www/us/en/developer/tools/oneapi/onemkl.html) instead of `libmkl-full-dev` for DiskANN. You can also use `libopenblas-dev` for building HNSW only, by removing `--extra diskann` in the command below.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
|
sudo apt-get update && sudo apt-get install -y \
|
||||||
uv sync
|
libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
|
||||||
|
pkg-config libabsl-dev libaio-dev libprotobuf-dev \
|
||||||
|
libmkl-full-dev
|
||||||
|
|
||||||
|
uv sync --extra diskann
|
||||||
|
```
|
||||||
|
|
||||||
|
**Linux (Arch Linux):**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
sudo pacman -Syu && sudo pacman -S --needed base-devel cmake pkgconf git gcc \
|
||||||
|
boost boost-libs protobuf abseil-cpp libaio zeromq
|
||||||
|
|
||||||
|
# For MKL in DiskANN
|
||||||
|
sudo pacman -S --needed base-devel git
|
||||||
|
git clone https://aur.archlinux.org/paru-bin.git
|
||||||
|
cd paru-bin && makepkg -si
|
||||||
|
paru -S intel-oneapi-mkl intel-oneapi-compiler
|
||||||
|
source /opt/intel/oneapi/setvars.sh
|
||||||
|
|
||||||
|
uv sync --extra diskann
|
||||||
|
```
|
||||||
|
|
||||||
|
**Linux (RHEL / CentOS Stream / Oracle / Rocky / AlmaLinux):**
|
||||||
|
|
||||||
|
See [Issue #50](https://github.com/yichuan-w/LEANN/issues/50) for more details.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
sudo dnf groupinstall -y "Development Tools"
|
||||||
|
sudo dnf install -y libomp-devel boost-devel protobuf-compiler protobuf-devel \
|
||||||
|
abseil-cpp-devel libaio-devel zeromq-devel pkgconf-pkg-config
|
||||||
|
|
||||||
|
# For MKL in DiskANN
|
||||||
|
sudo dnf install -y intel-oneapi-mkl intel-oneapi-mkl-devel \
|
||||||
|
intel-oneapi-openmp || sudo dnf install -y intel-oneapi-compiler
|
||||||
|
source /opt/intel/oneapi/setvars.sh
|
||||||
|
|
||||||
|
uv sync --extra diskann
|
||||||
```
|
```
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
@@ -129,6 +178,8 @@ response = chat.ask("How much storage does LEANN save?", top_k=1)
|
|||||||
|
|
||||||
LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`, `.md`), Apple Mail, Google Search History, WeChat, and more.
|
LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`, `.md`), Apple Mail, Google Search History, WeChat, and more.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
### Generation Model Setup
|
### Generation Model Setup
|
||||||
|
|
||||||
LEANN supports multiple LLM providers for text generation (OpenAI API, HuggingFace, Ollama).
|
LEANN supports multiple LLM providers for text generation (OpenAI API, HuggingFace, Ollama).
|
||||||
@@ -171,7 +222,8 @@ ollama pull llama3.2:1b
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
### ⭐ Flexible Configuration
|
|
||||||
|
## ⭐ Flexible Configuration
|
||||||
|
|
||||||
LEANN provides flexible parameters for embedding models, search strategies, and data processing to fit your specific needs.
|
LEANN provides flexible parameters for embedding models, search strategies, and data processing to fit your specific needs.
|
||||||
|
|
||||||
@@ -184,34 +236,34 @@ All RAG examples share these common parameters. **Interactive mode** is availabl
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Core Parameters (General preprocessing for all examples)
|
# Core Parameters (General preprocessing for all examples)
|
||||||
--index-dir DIR # Directory to store the index (default: current directory)
|
--index-dir DIR # Directory to store the index (default: current directory)
|
||||||
--query "YOUR QUESTION" # Single query mode. Omit for interactive chat (type 'quit' to exit), and now you can play with your index interactively
|
--query "YOUR QUESTION" # Single query mode. Omit for interactive chat (type 'quit' to exit), and now you can play with your index interactively
|
||||||
--max-items N # Limit data preprocessing (default: -1, process all data)
|
--max-items N # Limit data preprocessing (default: -1, process all data)
|
||||||
--force-rebuild # Force rebuild index even if it exists
|
--force-rebuild # Force rebuild index even if it exists
|
||||||
|
|
||||||
# Embedding Parameters
|
# Embedding Parameters
|
||||||
--embedding-model MODEL # e.g., facebook/contriever, text-embedding-3-small, nomic-embed-text,mlx-community/Qwen3-Embedding-0.6B-8bit or nomic-embed-text
|
--embedding-model MODEL # e.g., facebook/contriever, text-embedding-3-small, mlx-community/Qwen3-Embedding-0.6B-8bit or nomic-embed-text
|
||||||
--embedding-mode MODE # sentence-transformers, openai, mlx, or ollama
|
--embedding-mode MODE # sentence-transformers, openai, mlx, or ollama
|
||||||
|
|
||||||
# LLM Parameters (Text generation models)
|
# LLM Parameters (Text generation models)
|
||||||
--llm TYPE # LLM backend: openai, ollama, or hf (default: openai)
|
--llm TYPE # LLM backend: openai, ollama, or hf (default: openai)
|
||||||
--llm-model MODEL # Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct
|
--llm-model MODEL # Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct
|
||||||
--thinking-budget LEVEL # Thinking budget for reasoning models: low/medium/high (supported by o3, o3-mini, GPT-Oss:20b, and other reasoning models)
|
--thinking-budget LEVEL # Thinking budget for reasoning models: low/medium/high (supported by o3, o3-mini, GPT-Oss:20b, and other reasoning models)
|
||||||
|
|
||||||
# Search Parameters
|
# Search Parameters
|
||||||
--top-k N # Number of results to retrieve (default: 20)
|
--top-k N # Number of results to retrieve (default: 20)
|
||||||
--search-complexity N # Search complexity for graph traversal (default: 32)
|
--search-complexity N # Search complexity for graph traversal (default: 32)
|
||||||
|
|
||||||
# Chunking Parameters
|
# Chunking Parameters
|
||||||
--chunk-size N # Size of text chunks (default varies by source: 256 for most, 192 for WeChat)
|
--chunk-size N # Size of text chunks (default varies by source: 256 for most, 192 for WeChat)
|
||||||
--chunk-overlap N # Overlap between chunks (default varies: 25-128 depending on source)
|
--chunk-overlap N # Overlap between chunks (default varies: 25-128 depending on source)
|
||||||
|
|
||||||
# Index Building Parameters
|
# Index Building Parameters
|
||||||
--backend-name NAME # Backend to use: hnsw or diskann (default: hnsw)
|
--backend-name NAME # Backend to use: hnsw or diskann (default: hnsw)
|
||||||
--graph-degree N # Graph degree for index construction (default: 32)
|
--graph-degree N # Graph degree for index construction (default: 32)
|
||||||
--build-complexity N # Build complexity for index construction (default: 64)
|
--build-complexity N # Build complexity for index construction (default: 64)
|
||||||
--no-compact # Disable compact index storage (compact storage IS enabled to save storage by default)
|
--compact / --no-compact # Use compact storage (default: true). Must be `no-compact` for `no-recompute` build.
|
||||||
--no-recompute # Disable embedding recomputation (recomputation IS enabled to save storage by default)
|
--recompute / --no-recompute # Enable/disable embedding recomputation (default: enabled). Should not do a `no-recompute` search in a `recompute` build.
|
||||||
```
|
```
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
@@ -247,6 +299,12 @@ python -m apps.document_rag --data-dir "~/Documents/Papers" --chunk-size 1024
|
|||||||
|
|
||||||
# Filter only markdown and Python files with smaller chunks
|
# Filter only markdown and Python files with smaller chunks
|
||||||
python -m apps.document_rag --data-dir "./docs" --chunk-size 256 --file-types .md .py
|
python -m apps.document_rag --data-dir "./docs" --chunk-size 256 --file-types .md .py
|
||||||
|
|
||||||
|
# Enable AST-aware chunking for code files
|
||||||
|
python -m apps.document_rag --enable-code-chunking --data-dir "./my_project"
|
||||||
|
|
||||||
|
# Or use the specialized code RAG for better code understanding
|
||||||
|
python -m apps.code_rag --repo-dir "./my_codebase" --query "How does authentication work?"
|
||||||
```
|
```
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
@@ -421,24 +479,34 @@ Once the index is built, you can ask questions like:
|
|||||||
|
|
||||||
### 🚀 Claude Code Integration: Transform Your Development Workflow!
|
### 🚀 Claude Code Integration: Transform Your Development Workflow!
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>NEW!! AST‑Aware Code Chunking</strong></summary>
|
||||||
|
|
||||||
|
LEANN features intelligent code chunking that preserves semantic boundaries (functions, classes, methods) for Python, Java, C#, and TypeScript, improving code understanding compared to text-based chunking.
|
||||||
|
|
||||||
|
📖 Read the [AST Chunking Guide →](docs/ast_chunking_guide.md)
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
**The future of code assistance is here.** Transform your development workflow with LEANN's native MCP integration for Claude Code. Index your entire codebase and get intelligent code assistance directly in your IDE.
|
**The future of code assistance is here.** Transform your development workflow with LEANN's native MCP integration for Claude Code. Index your entire codebase and get intelligent code assistance directly in your IDE.
|
||||||
|
|
||||||
**Key features:**
|
**Key features:**
|
||||||
- 🔍 **Semantic code search** across your entire project
|
- 🔍 **Semantic code search** across your entire project, fully local index and lightweight
|
||||||
|
- 🧠 **AST-aware chunking** preserves code structure (functions, classes)
|
||||||
- 📚 **Context-aware assistance** for debugging and development
|
- 📚 **Context-aware assistance** for debugging and development
|
||||||
- 🚀 **Zero-config setup** with automatic language detection
|
- 🚀 **Zero-config setup** with automatic language detection
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Install LEANN globally for MCP integration
|
# Install LEANN globally for MCP integration
|
||||||
uv tool install leann-core
|
uv tool install leann-core --with leann
|
||||||
|
claude mcp add --scope user leann-server -- leann_mcp
|
||||||
# Setup is automatic - just start using Claude Code!
|
# Setup is automatic - just start using Claude Code!
|
||||||
```
|
```
|
||||||
Try our fully agentic pipeline with auto query rewriting, semantic search planning, and more:
|
Try our fully agentic pipeline with auto query rewriting, semantic search planning, and more:
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
**Ready to supercharge your coding?** [Complete Setup Guide →](packages/leann-mcp/README.md)
|
**🔥 Ready to supercharge your coding?** [Complete Setup Guide →](packages/leann-mcp/README.md)
|
||||||
|
|
||||||
## 🖥️ Command Line Interface
|
## 🖥️ Command Line Interface
|
||||||
|
|
||||||
@@ -455,7 +523,8 @@ leann --help
|
|||||||
**To make it globally available:**
|
**To make it globally available:**
|
||||||
```bash
|
```bash
|
||||||
# Install the LEANN CLI globally using uv tool
|
# Install the LEANN CLI globally using uv tool
|
||||||
uv tool install leann
|
uv tool install leann-core --with leann
|
||||||
|
|
||||||
|
|
||||||
# Now you can use leann from anywhere without activating venv
|
# Now you can use leann from anywhere without activating venv
|
||||||
leann --help
|
leann --help
|
||||||
@@ -479,30 +548,36 @@ leann ask my-docs --interactive
|
|||||||
|
|
||||||
# List all your indexes
|
# List all your indexes
|
||||||
leann list
|
leann list
|
||||||
|
|
||||||
|
# Remove an index
|
||||||
|
leann remove my-docs
|
||||||
```
|
```
|
||||||
|
|
||||||
**Key CLI features:**
|
**Key CLI features:**
|
||||||
- Auto-detects document formats (PDF, TXT, MD, DOCX)
|
- Auto-detects document formats (PDF, TXT, MD, DOCX, PPTX + code files)
|
||||||
- Smart text chunking with overlap
|
- **🧠 AST-aware chunking** for Python, Java, C#, TypeScript files
|
||||||
|
- Smart text chunking with overlap for all other content
|
||||||
- Multiple LLM providers (Ollama, OpenAI, HuggingFace)
|
- Multiple LLM providers (Ollama, OpenAI, HuggingFace)
|
||||||
- Organized index storage in `~/.leann/indexes/`
|
- Organized index storage in `.leann/indexes/` (project-local)
|
||||||
- Support for advanced search parameters
|
- Support for advanced search parameters
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary><strong>📋 Click to expand: Complete CLI Reference</strong></summary>
|
<summary><strong>📋 Click to expand: Complete CLI Reference</strong></summary>
|
||||||
|
|
||||||
|
You can use `leann --help`, or `leann build --help`, `leann search --help`, `leann ask --help`, `leann list --help`, `leann remove --help` to get the complete CLI reference.
|
||||||
|
|
||||||
**Build Command:**
|
**Build Command:**
|
||||||
```bash
|
```bash
|
||||||
leann build INDEX_NAME --docs DIRECTORY [OPTIONS]
|
leann build INDEX_NAME --docs DIRECTORY|FILE [DIRECTORY|FILE ...] [OPTIONS]
|
||||||
|
|
||||||
Options:
|
Options:
|
||||||
--backend {hnsw,diskann} Backend to use (default: hnsw)
|
--backend {hnsw,diskann} Backend to use (default: hnsw)
|
||||||
--embedding-model MODEL Embedding model (default: facebook/contriever)
|
--embedding-model MODEL Embedding model (default: facebook/contriever)
|
||||||
--graph-degree N Graph degree (default: 32)
|
--graph-degree N Graph degree (default: 32)
|
||||||
--complexity N Build complexity (default: 64)
|
--complexity N Build complexity (default: 64)
|
||||||
--force Force rebuild existing index
|
--force Force rebuild existing index
|
||||||
--compact Use compact storage (default: true)
|
--compact / --no-compact Use compact storage (default: true). Must be `no-compact` for `no-recompute` build.
|
||||||
--recompute Enable recomputation (default: true)
|
--recompute / --no-recompute Enable recomputation (default: true)
|
||||||
```
|
```
|
||||||
|
|
||||||
**Search Command:**
|
**Search Command:**
|
||||||
@@ -510,9 +585,9 @@ Options:
|
|||||||
leann search INDEX_NAME QUERY [OPTIONS]
|
leann search INDEX_NAME QUERY [OPTIONS]
|
||||||
|
|
||||||
Options:
|
Options:
|
||||||
--top-k N Number of results (default: 5)
|
--top-k N Number of results (default: 5)
|
||||||
--complexity N Search complexity (default: 64)
|
--complexity N Search complexity (default: 64)
|
||||||
--recompute-embeddings Use recomputation for highest accuracy
|
--recompute / --no-recompute Enable/disable embedding recomputation (default: enabled). Should not do a `no-recompute` search in a `recompute` build.
|
||||||
--pruning-strategy {global,local,proportional}
|
--pruning-strategy {global,local,proportional}
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -527,8 +602,73 @@ Options:
|
|||||||
--top-k N Retrieval count (default: 20)
|
--top-k N Retrieval count (default: 20)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**List Command:**
|
||||||
|
```bash
|
||||||
|
leann list
|
||||||
|
|
||||||
|
# Lists all indexes across all projects with status indicators:
|
||||||
|
# ✅ - Index is complete and ready to use
|
||||||
|
# ❌ - Index is incomplete or corrupted
|
||||||
|
# 📁 - CLI-created index (in .leann/indexes/)
|
||||||
|
# 📄 - App-created index (*.leann.meta.json files)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Remove Command:**
|
||||||
|
```bash
|
||||||
|
leann remove INDEX_NAME [OPTIONS]
|
||||||
|
|
||||||
|
Options:
|
||||||
|
--force, -f Force removal without confirmation
|
||||||
|
|
||||||
|
# Smart removal: automatically finds and safely removes indexes
|
||||||
|
# - Shows all matching indexes across projects
|
||||||
|
# - Requires confirmation for cross-project removal
|
||||||
|
# - Interactive selection when multiple matches found
|
||||||
|
# - Supports both CLI and app-created indexes
|
||||||
|
```
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
## 🚀 Advanced Features
|
||||||
|
|
||||||
|
### 🎯 Metadata Filtering
|
||||||
|
|
||||||
|
LEANN supports a simple metadata filtering system to enable sophisticated use cases like document filtering by date/type, code search by file extension, and content management based on custom criteria.
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Add metadata during indexing
|
||||||
|
builder.add_text(
|
||||||
|
"def authenticate_user(token): ...",
|
||||||
|
metadata={"file_extension": ".py", "lines_of_code": 25}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Search with filters
|
||||||
|
results = searcher.search(
|
||||||
|
query="authentication function",
|
||||||
|
metadata_filters={
|
||||||
|
"file_extension": {"==": ".py"},
|
||||||
|
"lines_of_code": {"<": 100}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Supported operators**: `==`, `!=`, `<`, `<=`, `>`, `>=`, `in`, `not_in`, `contains`, `starts_with`, `ends_with`, `is_true`, `is_false`
|
||||||
|
|
||||||
|
📖 **[Complete Metadata filtering guide →](docs/metadata_filtering.md)**
|
||||||
|
|
||||||
|
### 🔍 Grep Search
|
||||||
|
|
||||||
|
For exact text matching instead of semantic search, use the `use_grep` parameter:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Exact text search
|
||||||
|
results = searcher.search("banana‑crocodile", use_grep=True, top_k=1)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Use cases**: Finding specific code patterns, error messages, function names, or exact phrases where semantic similarity isn't needed.
|
||||||
|
|
||||||
|
📖 **[Complete grep search guide →](docs/grep_search.md)**
|
||||||
|
|
||||||
## 🏗️ Architecture & How It Works
|
## 🏗️ Architecture & How It Works
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
@@ -543,12 +683,16 @@ Options:
|
|||||||
- **Dynamic batching:** Efficiently batch embedding computations for GPU utilization
|
- **Dynamic batching:** Efficiently batch embedding computations for GPU utilization
|
||||||
- **Two-level search:** Smart graph traversal that prioritizes promising nodes
|
- **Two-level search:** Smart graph traversal that prioritizes promising nodes
|
||||||
|
|
||||||
**Backends:** HNSW (default) for most use cases, with optional DiskANN support for billion-scale datasets.
|
**Backends:**
|
||||||
|
- **HNSW** (default): Ideal for most datasets with maximum storage savings through full recomputation
|
||||||
|
- **DiskANN**: Advanced option with superior search performance, using PQ-based graph traversal with real-time reranking for the best speed-accuracy trade-off
|
||||||
|
|
||||||
## Benchmarks
|
## Benchmarks
|
||||||
|
|
||||||
|
**[DiskANN vs HNSW Performance Comparison →](benchmarks/diskann_vs_hnsw_speed_comparison.py)** - Compare search performance between both backends
|
||||||
|
|
||||||
|
**[Simple Example: Compare LEANN vs FAISS →](benchmarks/compare_faiss_vs_leann.py)** - See storage savings in action
|
||||||
|
|
||||||
**[Simple Example: Compare LEANN vs FAISS →](benchmarks/compare_faiss_vs_leann.py)**
|
|
||||||
### 📊 Storage Comparison
|
### 📊 Storage Comparison
|
||||||
|
|
||||||
| System | DPR (2.1M) | Wiki (60M) | Chat (400K) | Email (780K) | Browser (38K) |
|
| System | DPR (2.1M) | Wiki (60M) | Chat (400K) | Email (780K) | Browser (38K) |
|
||||||
@@ -564,6 +708,7 @@ Options:
|
|||||||
```bash
|
```bash
|
||||||
uv pip install -e ".[dev]" # Install dev dependencies
|
uv pip install -e ".[dev]" # Install dev dependencies
|
||||||
python benchmarks/run_evaluation.py # Will auto-download evaluation data and run benchmarks
|
python benchmarks/run_evaluation.py # Will auto-download evaluation data and run benchmarks
|
||||||
|
python benchmarks/run_evaluation.py benchmarks/data/indices/rpj_wiki/rpj_wiki --num-queries 2000 # After downloading data, you can run the benchmark with our biggest index
|
||||||
```
|
```
|
||||||
|
|
||||||
The evaluation script downloads data automatically on first run. The last three results were tested with partial personal data, and you can reproduce them with your own data!
|
The evaluation script downloads data automatically on first run. The last three results were tested with partial personal data, and you can reproduce them with your own data!
|
||||||
@@ -603,6 +748,9 @@ MIT License - see [LICENSE](LICENSE) for details.
|
|||||||
|
|
||||||
Core Contributors: [Yichuan Wang](https://yichuan-w.github.io/) & [Zhifei Li](https://github.com/andylizf).
|
Core Contributors: [Yichuan Wang](https://yichuan-w.github.io/) & [Zhifei Li](https://github.com/andylizf).
|
||||||
|
|
||||||
|
Active Contributors: [Gabriel Dehan](https://github.com/gabriel-dehan)
|
||||||
|
|
||||||
|
|
||||||
We welcome more contributors! Feel free to open issues or submit PRs.
|
We welcome more contributors! Feel free to open issues or submit PRs.
|
||||||
|
|
||||||
This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/).
|
This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/).
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from typing import Any
|
|||||||
|
|
||||||
import dotenv
|
import dotenv
|
||||||
from leann.api import LeannBuilder, LeannChat
|
from leann.api import LeannBuilder, LeannChat
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
from leann.registry import register_project_directory
|
||||||
|
|
||||||
dotenv.load_dotenv()
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
@@ -69,14 +69,14 @@ class BaseRAGExample(ABC):
|
|||||||
"--embedding-model",
|
"--embedding-model",
|
||||||
type=str,
|
type=str,
|
||||||
default=embedding_model_default,
|
default=embedding_model_default,
|
||||||
help=f"Embedding model to use (default: {embedding_model_default})",
|
help=f"Embedding model to use (default: {embedding_model_default}), we provide facebook/contriever, text-embedding-3-small,mlx-community/Qwen3-Embedding-0.6B-8bit or nomic-embed-text",
|
||||||
)
|
)
|
||||||
embedding_group.add_argument(
|
embedding_group.add_argument(
|
||||||
"--embedding-mode",
|
"--embedding-mode",
|
||||||
type=str,
|
type=str,
|
||||||
default="sentence-transformers",
|
default="sentence-transformers",
|
||||||
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||||
help="Embedding backend mode (default: sentence-transformers)",
|
help="Embedding backend mode (default: sentence-transformers), we provide sentence-transformers, openai, mlx, or ollama",
|
||||||
)
|
)
|
||||||
|
|
||||||
# LLM parameters
|
# LLM parameters
|
||||||
@@ -86,13 +86,13 @@ class BaseRAGExample(ABC):
|
|||||||
type=str,
|
type=str,
|
||||||
default="openai",
|
default="openai",
|
||||||
choices=["openai", "ollama", "hf", "simulated"],
|
choices=["openai", "ollama", "hf", "simulated"],
|
||||||
help="LLM backend to use (default: openai)",
|
help="LLM backend: openai, ollama, or hf (default: openai)",
|
||||||
)
|
)
|
||||||
llm_group.add_argument(
|
llm_group.add_argument(
|
||||||
"--llm-model",
|
"--llm-model",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="LLM model name (default: gpt-4o for openai, llama3.2:1b for ollama)",
|
help="Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct",
|
||||||
)
|
)
|
||||||
llm_group.add_argument(
|
llm_group.add_argument(
|
||||||
"--llm-host",
|
"--llm-host",
|
||||||
@@ -108,6 +108,38 @@ class BaseRAGExample(ABC):
|
|||||||
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
|
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# AST Chunking parameters
|
||||||
|
ast_group = parser.add_argument_group("AST Chunking Parameters")
|
||||||
|
ast_group.add_argument(
|
||||||
|
"--use-ast-chunking",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable AST-aware chunking for code files (requires astchunk)",
|
||||||
|
)
|
||||||
|
ast_group.add_argument(
|
||||||
|
"--ast-chunk-size",
|
||||||
|
type=int,
|
||||||
|
default=512,
|
||||||
|
help="Maximum characters per AST chunk (default: 512)",
|
||||||
|
)
|
||||||
|
ast_group.add_argument(
|
||||||
|
"--ast-chunk-overlap",
|
||||||
|
type=int,
|
||||||
|
default=64,
|
||||||
|
help="Overlap between AST chunks (default: 64)",
|
||||||
|
)
|
||||||
|
ast_group.add_argument(
|
||||||
|
"--code-file-extensions",
|
||||||
|
nargs="+",
|
||||||
|
default=None,
|
||||||
|
help="Additional code file extensions to process with AST chunking (e.g., .py .java .cs .ts)",
|
||||||
|
)
|
||||||
|
ast_group.add_argument(
|
||||||
|
"--ast-fallback-traditional",
|
||||||
|
action="store_true",
|
||||||
|
default=True,
|
||||||
|
help="Fall back to traditional chunking if AST chunking fails (default: True)",
|
||||||
|
)
|
||||||
|
|
||||||
# Search parameters
|
# Search parameters
|
||||||
search_group = parser.add_argument_group("Search Parameters")
|
search_group = parser.add_argument_group("Search Parameters")
|
||||||
search_group.add_argument(
|
search_group.add_argument(
|
||||||
@@ -178,6 +210,9 @@ class BaseRAGExample(ABC):
|
|||||||
config["host"] = args.llm_host
|
config["host"] = args.llm_host
|
||||||
elif args.llm == "hf":
|
elif args.llm == "hf":
|
||||||
config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
|
config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
|
||||||
|
elif args.llm == "simulated":
|
||||||
|
# Simulated LLM doesn't need additional configuration
|
||||||
|
pass
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
@@ -211,6 +246,11 @@ class BaseRAGExample(ABC):
|
|||||||
builder.build_index(index_path)
|
builder.build_index(index_path)
|
||||||
print(f"Index saved to: {index_path}")
|
print(f"Index saved to: {index_path}")
|
||||||
|
|
||||||
|
# Register project directory so leann list can discover this index
|
||||||
|
# The index is saved as args.index_dir/index_name.leann
|
||||||
|
# We want to register the current working directory where the app is run
|
||||||
|
register_project_directory(Path.cwd())
|
||||||
|
|
||||||
return index_path
|
return index_path
|
||||||
|
|
||||||
async def run_interactive_chat(self, args, index_path: str):
|
async def run_interactive_chat(self, args, index_path: str):
|
||||||
@@ -259,7 +299,6 @@ class BaseRAGExample(ABC):
|
|||||||
chat = LeannChat(
|
chat = LeannChat(
|
||||||
index_path,
|
index_path,
|
||||||
llm_config=self.get_llm_config(args),
|
llm_config=self.get_llm_config(args),
|
||||||
system_prompt=f"You are a helpful assistant that answers questions about {self.name} data.",
|
|
||||||
complexity=args.search_complexity,
|
complexity=args.search_complexity,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -301,21 +340,3 @@ class BaseRAGExample(ABC):
|
|||||||
await self.run_single_query(args, index_path, args.query)
|
await self.run_single_query(args, index_path, args.query)
|
||||||
else:
|
else:
|
||||||
await self.run_interactive_chat(args, index_path)
|
await self.run_interactive_chat(args, index_path)
|
||||||
|
|
||||||
|
|
||||||
def create_text_chunks(documents, chunk_size=256, chunk_overlap=25) -> list[str]:
|
|
||||||
"""Helper function to create text chunks from documents."""
|
|
||||||
node_parser = SentenceSplitter(
|
|
||||||
chunk_size=chunk_size,
|
|
||||||
chunk_overlap=chunk_overlap,
|
|
||||||
separator=" ",
|
|
||||||
paragraph_separator="\n\n",
|
|
||||||
)
|
|
||||||
|
|
||||||
all_texts = []
|
|
||||||
for doc in documents:
|
|
||||||
nodes = node_parser.get_nodes_from_documents([doc])
|
|
||||||
if nodes:
|
|
||||||
all_texts.extend(node.get_content() for node in nodes)
|
|
||||||
|
|
||||||
return all_texts
|
|
||||||
|
|||||||
@@ -10,7 +10,8 @@ from pathlib import Path
|
|||||||
# Add parent directory to path for imports
|
# Add parent directory to path for imports
|
||||||
sys.path.insert(0, str(Path(__file__).parent))
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
from base_rag_example import BaseRAGExample, create_text_chunks
|
from base_rag_example import BaseRAGExample
|
||||||
|
from chunking import create_text_chunks
|
||||||
|
|
||||||
from .history_data.history import ChromeHistoryReader
|
from .history_data.history import ChromeHistoryReader
|
||||||
|
|
||||||
|
|||||||
44
apps/chunking/__init__.py
Normal file
44
apps/chunking/__init__.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
"""Unified chunking utilities facade.
|
||||||
|
|
||||||
|
This module re-exports the packaged utilities from `leann.chunking_utils` so
|
||||||
|
that both repo apps (importing `chunking`) and installed wheels share one
|
||||||
|
single implementation. When running from the repo without installation, it
|
||||||
|
adds the `packages/leann-core/src` directory to `sys.path` as a fallback.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
try:
|
||||||
|
from leann.chunking_utils import (
|
||||||
|
CODE_EXTENSIONS,
|
||||||
|
create_ast_chunks,
|
||||||
|
create_text_chunks,
|
||||||
|
create_traditional_chunks,
|
||||||
|
detect_code_files,
|
||||||
|
get_language_from_extension,
|
||||||
|
)
|
||||||
|
except Exception: # pragma: no cover - best-effort fallback for dev environment
|
||||||
|
repo_root = Path(__file__).resolve().parents[2]
|
||||||
|
leann_src = repo_root / "packages" / "leann-core" / "src"
|
||||||
|
if leann_src.exists():
|
||||||
|
sys.path.insert(0, str(leann_src))
|
||||||
|
from leann.chunking_utils import (
|
||||||
|
CODE_EXTENSIONS,
|
||||||
|
create_ast_chunks,
|
||||||
|
create_text_chunks,
|
||||||
|
create_traditional_chunks,
|
||||||
|
detect_code_files,
|
||||||
|
get_language_from_extension,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"CODE_EXTENSIONS",
|
||||||
|
"create_ast_chunks",
|
||||||
|
"create_text_chunks",
|
||||||
|
"create_traditional_chunks",
|
||||||
|
"detect_code_files",
|
||||||
|
"get_language_from_extension",
|
||||||
|
]
|
||||||
211
apps/code_rag.py
Normal file
211
apps/code_rag.py
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
"""
|
||||||
|
Code RAG example using AST-aware chunking for optimal code understanding.
|
||||||
|
Specialized for code repositories with automatic language detection and
|
||||||
|
optimized chunking parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add parent directory to path for imports
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
|
from base_rag_example import BaseRAGExample
|
||||||
|
from chunking import CODE_EXTENSIONS, create_text_chunks
|
||||||
|
from llama_index.core import SimpleDirectoryReader
|
||||||
|
|
||||||
|
|
||||||
|
class CodeRAG(BaseRAGExample):
|
||||||
|
"""Specialized RAG example for code repositories with AST-aware chunking."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
name="Code",
|
||||||
|
description="Process and query code repositories with AST-aware chunking",
|
||||||
|
default_index_name="code_index",
|
||||||
|
)
|
||||||
|
# Override defaults for code-specific usage
|
||||||
|
self.embedding_model_default = "facebook/contriever" # Good for code
|
||||||
|
self.max_items_default = -1 # Process all code files by default
|
||||||
|
|
||||||
|
def _add_specific_arguments(self, parser):
|
||||||
|
"""Add code-specific arguments."""
|
||||||
|
code_group = parser.add_argument_group("Code Repository Parameters")
|
||||||
|
|
||||||
|
code_group.add_argument(
|
||||||
|
"--repo-dir",
|
||||||
|
type=str,
|
||||||
|
default=".",
|
||||||
|
help="Code repository directory to index (default: current directory)",
|
||||||
|
)
|
||||||
|
code_group.add_argument(
|
||||||
|
"--include-extensions",
|
||||||
|
nargs="+",
|
||||||
|
default=list(CODE_EXTENSIONS.keys()),
|
||||||
|
help="File extensions to include (default: supported code extensions)",
|
||||||
|
)
|
||||||
|
code_group.add_argument(
|
||||||
|
"--exclude-dirs",
|
||||||
|
nargs="+",
|
||||||
|
default=[
|
||||||
|
".git",
|
||||||
|
"__pycache__",
|
||||||
|
"node_modules",
|
||||||
|
"venv",
|
||||||
|
".venv",
|
||||||
|
"build",
|
||||||
|
"dist",
|
||||||
|
"target",
|
||||||
|
],
|
||||||
|
help="Directories to exclude from indexing",
|
||||||
|
)
|
||||||
|
code_group.add_argument(
|
||||||
|
"--max-file-size",
|
||||||
|
type=int,
|
||||||
|
default=1000000, # 1MB
|
||||||
|
help="Maximum file size in bytes to process (default: 1MB)",
|
||||||
|
)
|
||||||
|
code_group.add_argument(
|
||||||
|
"--include-comments",
|
||||||
|
action="store_true",
|
||||||
|
help="Include comments in chunking (useful for documentation)",
|
||||||
|
)
|
||||||
|
code_group.add_argument(
|
||||||
|
"--preserve-imports",
|
||||||
|
action="store_true",
|
||||||
|
default=True,
|
||||||
|
help="Try to preserve import statements in chunks (default: True)",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def load_data(self, args) -> list[str]:
|
||||||
|
"""Load code files and convert to AST-aware chunks."""
|
||||||
|
print(f"🔍 Scanning code repository: {args.repo_dir}")
|
||||||
|
print(f"📁 Including extensions: {args.include_extensions}")
|
||||||
|
print(f"🚫 Excluding directories: {args.exclude_dirs}")
|
||||||
|
|
||||||
|
# Check if repository directory exists
|
||||||
|
repo_path = Path(args.repo_dir)
|
||||||
|
if not repo_path.exists():
|
||||||
|
raise ValueError(f"Repository directory not found: {args.repo_dir}")
|
||||||
|
|
||||||
|
# Load code files with filtering
|
||||||
|
reader_kwargs = {
|
||||||
|
"recursive": True,
|
||||||
|
"encoding": "utf-8",
|
||||||
|
"required_exts": args.include_extensions,
|
||||||
|
"exclude_hidden": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create exclusion filter
|
||||||
|
def file_filter(file_path: str) -> bool:
|
||||||
|
"""Filter out unwanted files and directories."""
|
||||||
|
path = Path(file_path)
|
||||||
|
|
||||||
|
# Check file size
|
||||||
|
try:
|
||||||
|
if path.stat().st_size > args.max_file_size:
|
||||||
|
print(f"⚠️ Skipping large file: {path.name} ({path.stat().st_size} bytes)")
|
||||||
|
return False
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if in excluded directory
|
||||||
|
for exclude_dir in args.exclude_dirs:
|
||||||
|
if exclude_dir in path.parts:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Load documents with file filtering
|
||||||
|
documents = SimpleDirectoryReader(
|
||||||
|
args.repo_dir,
|
||||||
|
file_extractor=None, # Use default extractors
|
||||||
|
**reader_kwargs,
|
||||||
|
).load_data(show_progress=True)
|
||||||
|
|
||||||
|
# Apply custom filtering
|
||||||
|
filtered_docs = []
|
||||||
|
for doc in documents:
|
||||||
|
file_path = doc.metadata.get("file_path", "")
|
||||||
|
if file_filter(file_path):
|
||||||
|
filtered_docs.append(doc)
|
||||||
|
|
||||||
|
documents = filtered_docs
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Error loading code files: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
if not documents:
|
||||||
|
print(
|
||||||
|
f"❌ No code files found in {args.repo_dir} with extensions {args.include_extensions}"
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
|
||||||
|
print(f"✅ Loaded {len(documents)} code files")
|
||||||
|
|
||||||
|
# Show breakdown by language/extension
|
||||||
|
ext_counts = {}
|
||||||
|
for doc in documents:
|
||||||
|
file_path = doc.metadata.get("file_path", "")
|
||||||
|
if file_path:
|
||||||
|
ext = Path(file_path).suffix.lower()
|
||||||
|
ext_counts[ext] = ext_counts.get(ext, 0) + 1
|
||||||
|
|
||||||
|
print("📊 Files by extension:")
|
||||||
|
for ext, count in sorted(ext_counts.items()):
|
||||||
|
print(f" {ext}: {count} files")
|
||||||
|
|
||||||
|
# Use AST-aware chunking by default for code
|
||||||
|
print(
|
||||||
|
f"🧠 Using AST-aware chunking (chunk_size: {args.ast_chunk_size}, overlap: {args.ast_chunk_overlap})"
|
||||||
|
)
|
||||||
|
|
||||||
|
all_texts = create_text_chunks(
|
||||||
|
documents,
|
||||||
|
chunk_size=256, # Fallback for non-code files
|
||||||
|
chunk_overlap=64,
|
||||||
|
use_ast_chunking=True, # Always use AST for code RAG
|
||||||
|
ast_chunk_size=args.ast_chunk_size,
|
||||||
|
ast_chunk_overlap=args.ast_chunk_overlap,
|
||||||
|
code_file_extensions=args.include_extensions,
|
||||||
|
ast_fallback_traditional=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply max_items limit if specified
|
||||||
|
if args.max_items > 0 and len(all_texts) > args.max_items:
|
||||||
|
print(f"⏳ Limiting to {args.max_items} chunks (from {len(all_texts)})")
|
||||||
|
all_texts = all_texts[: args.max_items]
|
||||||
|
|
||||||
|
print(f"✅ Generated {len(all_texts)} code chunks")
|
||||||
|
return all_texts
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# Example queries for code RAG
|
||||||
|
print("\n💻 Code RAG Example")
|
||||||
|
print("=" * 50)
|
||||||
|
print("\nExample queries you can try:")
|
||||||
|
print("- 'How does the embedding computation work?'")
|
||||||
|
print("- 'What are the main classes in this codebase?'")
|
||||||
|
print("- 'Show me the search implementation'")
|
||||||
|
print("- 'How is error handling implemented?'")
|
||||||
|
print("- 'What design patterns are used?'")
|
||||||
|
print("- 'Explain the chunking logic'")
|
||||||
|
print("\n🚀 Features:")
|
||||||
|
print("- ✅ AST-aware chunking preserves code structure")
|
||||||
|
print("- ✅ Automatic language detection")
|
||||||
|
print("- ✅ Smart filtering of large files and common excludes")
|
||||||
|
print("- ✅ Optimized for code understanding")
|
||||||
|
print("\nUsage examples:")
|
||||||
|
print(" python -m apps.code_rag --repo-dir ./my_project")
|
||||||
|
print(
|
||||||
|
" python -m apps.code_rag --include-extensions .py .js --query 'How does authentication work?'"
|
||||||
|
)
|
||||||
|
print("\nOr run without --query for interactive mode\n")
|
||||||
|
|
||||||
|
rag = CodeRAG()
|
||||||
|
asyncio.run(rag.run())
|
||||||
@@ -9,7 +9,8 @@ from pathlib import Path
|
|||||||
# Add parent directory to path for imports
|
# Add parent directory to path for imports
|
||||||
sys.path.insert(0, str(Path(__file__).parent))
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
from base_rag_example import BaseRAGExample, create_text_chunks
|
from base_rag_example import BaseRAGExample
|
||||||
|
from chunking import create_text_chunks
|
||||||
from llama_index.core import SimpleDirectoryReader
|
from llama_index.core import SimpleDirectoryReader
|
||||||
|
|
||||||
|
|
||||||
@@ -44,6 +45,11 @@ class DocumentRAG(BaseRAGExample):
|
|||||||
doc_group.add_argument(
|
doc_group.add_argument(
|
||||||
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
||||||
)
|
)
|
||||||
|
doc_group.add_argument(
|
||||||
|
"--enable-code-chunking",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable AST-aware chunking for code files in the data directory",
|
||||||
|
)
|
||||||
|
|
||||||
async def load_data(self, args) -> list[str]:
|
async def load_data(self, args) -> list[str]:
|
||||||
"""Load documents and convert to text chunks."""
|
"""Load documents and convert to text chunks."""
|
||||||
@@ -76,9 +82,22 @@ class DocumentRAG(BaseRAGExample):
|
|||||||
|
|
||||||
print(f"Loaded {len(documents)} documents")
|
print(f"Loaded {len(documents)} documents")
|
||||||
|
|
||||||
# Convert to text chunks
|
# Determine chunking strategy
|
||||||
|
use_ast = args.enable_code_chunking or getattr(args, "use_ast_chunking", False)
|
||||||
|
|
||||||
|
if use_ast:
|
||||||
|
print("Using AST-aware chunking for code files")
|
||||||
|
|
||||||
|
# Convert to text chunks with optional AST support
|
||||||
all_texts = create_text_chunks(
|
all_texts = create_text_chunks(
|
||||||
documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
documents,
|
||||||
|
chunk_size=args.chunk_size,
|
||||||
|
chunk_overlap=args.chunk_overlap,
|
||||||
|
use_ast_chunking=use_ast,
|
||||||
|
ast_chunk_size=getattr(args, "ast_chunk_size", 512),
|
||||||
|
ast_chunk_overlap=getattr(args, "ast_chunk_overlap", 64),
|
||||||
|
code_file_extensions=getattr(args, "code_file_extensions", None),
|
||||||
|
ast_fallback_traditional=getattr(args, "ast_fallback_traditional", True),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply max_items limit if specified
|
# Apply max_items limit if specified
|
||||||
@@ -102,6 +121,10 @@ if __name__ == "__main__":
|
|||||||
print(
|
print(
|
||||||
"- 'What is the problem of developing pan gu model Huawei meets? (盘古大模型开发中遇到什么问题?)'"
|
"- 'What is the problem of developing pan gu model Huawei meets? (盘古大模型开发中遇到什么问题?)'"
|
||||||
)
|
)
|
||||||
|
print("\n🚀 NEW: Code-aware chunking available!")
|
||||||
|
print("- Use --enable-code-chunking to enable AST-aware chunking for code files")
|
||||||
|
print("- Supports Python, Java, C#, TypeScript files")
|
||||||
|
print("- Better semantic understanding of code structure")
|
||||||
print("\nOr run without --query for interactive mode\n")
|
print("\nOr run without --query for interactive mode\n")
|
||||||
|
|
||||||
rag = DocumentRAG()
|
rag = DocumentRAG()
|
||||||
|
|||||||
@@ -9,7 +9,8 @@ from pathlib import Path
|
|||||||
# Add parent directory to path for imports
|
# Add parent directory to path for imports
|
||||||
sys.path.insert(0, str(Path(__file__).parent))
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
from base_rag_example import BaseRAGExample, create_text_chunks
|
from base_rag_example import BaseRAGExample
|
||||||
|
from chunking import create_text_chunks
|
||||||
|
|
||||||
from .email_data.LEANN_email_reader import EmlxReader
|
from .email_data.LEANN_email_reader import EmlxReader
|
||||||
|
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ class ChromeHistoryReader(BaseReader):
|
|||||||
if count >= max_count and max_count > 0:
|
if count >= max_count and max_count > 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
last_visit, url, title, visit_count, typed_count, hidden = row
|
last_visit, url, title, visit_count, typed_count, _hidden = row
|
||||||
|
|
||||||
# Create document content with metadata embedded in text
|
# Create document content with metadata embedded in text
|
||||||
doc_content = f"""
|
doc_content = f"""
|
||||||
|
|||||||
BIN
assets/wechat_user_group.JPG
Normal file
BIN
assets/wechat_user_group.JPG
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 152 KiB |
@@ -1,9 +1,24 @@
|
|||||||
# 🧪 Leann Sanity Checks
|
# 🧪 LEANN Benchmarks & Testing
|
||||||
|
|
||||||
This directory contains comprehensive sanity checks for the Leann system, ensuring all components work correctly across different configurations.
|
This directory contains performance benchmarks and comprehensive tests for the LEANN system, including backend comparisons and sanity checks across different configurations.
|
||||||
|
|
||||||
## 📁 Test Files
|
## 📁 Test Files
|
||||||
|
|
||||||
|
### `diskann_vs_hnsw_speed_comparison.py`
|
||||||
|
Performance comparison between DiskANN and HNSW backends:
|
||||||
|
- ✅ **Search latency** comparison with both backends using recompute
|
||||||
|
- ✅ **Index size** and **build time** measurements
|
||||||
|
- ✅ **Score validity** testing (ensures no -inf scores)
|
||||||
|
- ✅ **Configurable dataset sizes** for different scales
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Quick comparison with 500 docs, 10 queries
|
||||||
|
python benchmarks/diskann_vs_hnsw_speed_comparison.py
|
||||||
|
|
||||||
|
# Large-scale comparison with 2000 docs, 20 queries
|
||||||
|
python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20
|
||||||
|
```
|
||||||
|
|
||||||
### `test_distance_functions.py`
|
### `test_distance_functions.py`
|
||||||
Tests all supported distance functions across DiskANN backend:
|
Tests all supported distance functions across DiskANN backend:
|
||||||
- ✅ **MIPS** (Maximum Inner Product Search)
|
- ✅ **MIPS** (Maximum Inner Product Search)
|
||||||
|
|||||||
148
benchmarks/benchmark_no_recompute.py
Normal file
148
benchmarks/benchmark_no_recompute.py
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from leann import LeannBuilder, LeannSearcher
|
||||||
|
|
||||||
|
|
||||||
|
def _meta_exists(index_path: str) -> bool:
|
||||||
|
p = Path(index_path)
|
||||||
|
return (p.parent / f"{p.stem}.meta.json").exists()
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_index(index_path: str, backend_name: str, num_docs: int, is_recompute: bool) -> None:
|
||||||
|
# if _meta_exists(index_path):
|
||||||
|
# return
|
||||||
|
kwargs = {}
|
||||||
|
if backend_name == "hnsw":
|
||||||
|
kwargs["is_compact"] = is_recompute
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name=backend_name,
|
||||||
|
embedding_model=os.getenv("LEANN_EMBED_MODEL", "facebook/contriever"),
|
||||||
|
embedding_mode=os.getenv("LEANN_EMBED_MODE", "sentence-transformers"),
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
is_recompute=is_recompute,
|
||||||
|
num_threads=4,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
for i in range(num_docs):
|
||||||
|
builder.add_text(
|
||||||
|
f"This is a test document number {i}. It contains some repeated text for benchmarking."
|
||||||
|
)
|
||||||
|
builder.build_index(index_path)
|
||||||
|
|
||||||
|
|
||||||
|
def _bench_group(
|
||||||
|
index_path: str,
|
||||||
|
recompute: bool,
|
||||||
|
query: str,
|
||||||
|
repeats: int,
|
||||||
|
complexity: int = 32,
|
||||||
|
top_k: int = 10,
|
||||||
|
) -> float:
|
||||||
|
# Independent searcher per group; fixed port when recompute
|
||||||
|
searcher = LeannSearcher(index_path=index_path)
|
||||||
|
|
||||||
|
# Warm-up once
|
||||||
|
_ = searcher.search(
|
||||||
|
query,
|
||||||
|
top_k=top_k,
|
||||||
|
complexity=complexity,
|
||||||
|
recompute_embeddings=recompute,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _once() -> float:
|
||||||
|
t0 = time.time()
|
||||||
|
_ = searcher.search(
|
||||||
|
query,
|
||||||
|
top_k=top_k,
|
||||||
|
complexity=complexity,
|
||||||
|
recompute_embeddings=recompute,
|
||||||
|
)
|
||||||
|
return time.time() - t0
|
||||||
|
|
||||||
|
if repeats <= 1:
|
||||||
|
t = _once()
|
||||||
|
else:
|
||||||
|
vals = [_once() for _ in range(repeats)]
|
||||||
|
vals.sort()
|
||||||
|
t = vals[len(vals) // 2]
|
||||||
|
|
||||||
|
searcher.cleanup()
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--num-docs", type=int, default=5000)
|
||||||
|
parser.add_argument("--repeats", type=int, default=3)
|
||||||
|
parser.add_argument("--complexity", type=int, default=32)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
base = Path.cwd() / ".leann" / "indexes" / f"bench_n{args.num_docs}"
|
||||||
|
base.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
# ---------- Build HNSW variants ----------
|
||||||
|
hnsw_r = str(base / f"hnsw_recompute_n{args.num_docs}.leann")
|
||||||
|
hnsw_nr = str(base / f"hnsw_norecompute_n{args.num_docs}.leann")
|
||||||
|
ensure_index(hnsw_r, "hnsw", args.num_docs, True)
|
||||||
|
ensure_index(hnsw_nr, "hnsw", args.num_docs, False)
|
||||||
|
|
||||||
|
# ---------- Build DiskANN variants ----------
|
||||||
|
diskann_r = str(base / "diskann_r.leann")
|
||||||
|
diskann_nr = str(base / "diskann_nr.leann")
|
||||||
|
ensure_index(diskann_r, "diskann", args.num_docs, True)
|
||||||
|
ensure_index(diskann_nr, "diskann", args.num_docs, False)
|
||||||
|
|
||||||
|
# ---------- Helpers ----------
|
||||||
|
def _size_for(prefix: str) -> int:
|
||||||
|
p = Path(prefix)
|
||||||
|
base_dir = p.parent
|
||||||
|
stem = p.stem
|
||||||
|
total = 0
|
||||||
|
for f in base_dir.iterdir():
|
||||||
|
if f.is_file() and f.name.startswith(stem):
|
||||||
|
total += f.stat().st_size
|
||||||
|
return total
|
||||||
|
|
||||||
|
# ---------- HNSW benchmark ----------
|
||||||
|
t_hnsw_r = _bench_group(
|
||||||
|
hnsw_r, True, "test document number 42", repeats=args.repeats, complexity=args.complexity
|
||||||
|
)
|
||||||
|
t_hnsw_nr = _bench_group(
|
||||||
|
hnsw_nr, False, "test document number 42", repeats=args.repeats, complexity=args.complexity
|
||||||
|
)
|
||||||
|
size_hnsw_r = _size_for(hnsw_r)
|
||||||
|
size_hnsw_nr = _size_for(hnsw_nr)
|
||||||
|
|
||||||
|
print("Benchmark results (HNSW):")
|
||||||
|
print(f" recompute=True: search_time={t_hnsw_r:.3f}s, size={size_hnsw_r / 1024 / 1024:.1f}MB")
|
||||||
|
print(
|
||||||
|
f" recompute=False: search_time={t_hnsw_nr:.3f}s, size={size_hnsw_nr / 1024 / 1024:.1f}MB"
|
||||||
|
)
|
||||||
|
print(" Expectation: no-recompute should be faster but larger on disk.")
|
||||||
|
|
||||||
|
# ---------- DiskANN benchmark ----------
|
||||||
|
t_diskann_r = _bench_group(
|
||||||
|
diskann_r, True, "DiskANN R test doc 123", repeats=args.repeats, complexity=args.complexity
|
||||||
|
)
|
||||||
|
t_diskann_nr = _bench_group(
|
||||||
|
diskann_nr,
|
||||||
|
False,
|
||||||
|
"DiskANN NR test doc 123",
|
||||||
|
repeats=args.repeats,
|
||||||
|
complexity=args.complexity,
|
||||||
|
)
|
||||||
|
size_diskann_r = _size_for(diskann_r)
|
||||||
|
size_diskann_nr = _size_for(diskann_nr)
|
||||||
|
|
||||||
|
print("\nBenchmark results (DiskANN):")
|
||||||
|
print(f" build(recompute=True, partition): size={size_diskann_r / 1024 / 1024:.1f}MB")
|
||||||
|
print(f" build(recompute=False): size={size_diskann_nr / 1024 / 1024:.1f}MB")
|
||||||
|
print(f" search recompute=True (final rerank): {t_diskann_r:.3f}s")
|
||||||
|
print(f" search recompute=False (PQ only): {t_diskann_nr:.3f}s")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
82
benchmarks/data/.gitattributes
vendored
82
benchmarks/data/.gitattributes
vendored
@@ -1,82 +0,0 @@
|
|||||||
*.7z filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.arrow filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.bin filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.ftz filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.gz filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.h5 filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.joblib filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.lz4 filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.mds filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.model filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.npy filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.npz filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.onnx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.ot filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.parquet filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.pb filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.pickle filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.pkl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.pt filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.pth filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.rar filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
|
||||||
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.tar filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.tflite filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.tgz filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.wasm filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.xz filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.zip filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.zst filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
||||||
# Audio files - uncompressed
|
|
||||||
*.pcm filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.sam filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.raw filter=lfs diff=lfs merge=lfs -text
|
|
||||||
# Audio files - compressed
|
|
||||||
*.aac filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.flac filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.mp3 filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.ogg filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.wav filter=lfs diff=lfs merge=lfs -text
|
|
||||||
# Image files - uncompressed
|
|
||||||
*.bmp filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.gif filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.png filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.tiff filter=lfs diff=lfs merge=lfs -text
|
|
||||||
# Image files - compressed
|
|
||||||
*.jpg filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.webp filter=lfs diff=lfs merge=lfs -text
|
|
||||||
# Video files - compressed
|
|
||||||
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.webm filter=lfs diff=lfs merge=lfs -text
|
|
||||||
ground_truth/dpr/id_map.json filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/dpr/dpr_diskann.passages.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/dpr/dpr_diskann.passages.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/dpr/dpr_diskann_disk.index filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/dpr/leann.labels.map filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/leann.labels.map filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.index filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.0.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.0.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.1.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.1.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.2.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.2.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.3.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.3.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.4.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.4.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.5.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.5.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.6.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.6.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.7.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.7.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
44
benchmarks/data/README.md
Executable file
44
benchmarks/data/README.md
Executable file
@@ -0,0 +1,44 @@
|
|||||||
|
---
|
||||||
|
license: mit
|
||||||
|
---
|
||||||
|
|
||||||
|
# LEANN-RAG Evaluation Data
|
||||||
|
|
||||||
|
This repository contains the necessary data to run the recall evaluation scripts for the [LEANN-RAG](https://huggingface.co/LEANN-RAG) project.
|
||||||
|
|
||||||
|
## Dataset Components
|
||||||
|
|
||||||
|
This dataset is structured into three main parts:
|
||||||
|
|
||||||
|
1. **Pre-built LEANN Indices**:
|
||||||
|
* `dpr/`: A pre-built index for the DPR dataset.
|
||||||
|
* `rpj_wiki/`: A pre-built index for the RPJ-Wiki dataset.
|
||||||
|
These indices were created using the `leann-core` library and are required by the `LeannSearcher`.
|
||||||
|
|
||||||
|
2. **Ground Truth Data**:
|
||||||
|
* `ground_truth/`: Contains the ground truth files (`flat_results_nq_k3.json`) for both the DPR and RPJ-Wiki datasets. These files map queries to the original passage IDs from the Natural Questions benchmark, evaluated using the Contriever model.
|
||||||
|
|
||||||
|
3. **Queries**:
|
||||||
|
* `queries/`: Contains the `nq_open.jsonl` file with the Natural Questions queries used for the evaluation.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
To use this data, you can download it locally using the `huggingface-hub` library. First, install the library:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install huggingface-hub
|
||||||
|
```
|
||||||
|
|
||||||
|
Then, you can download the entire dataset to a local directory (e.g., `data/`) with the following Python script:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
snapshot_download(
|
||||||
|
repo_id="LEANN-RAG/leann-rag-evaluation-data",
|
||||||
|
repo_type="dataset",
|
||||||
|
local_dir="data"
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
This will download all the necessary files into a local `data` folder, preserving the repository structure. The evaluation scripts in the main [LEANN-RAG Space](https://huggingface.co/LEANN-RAG) are configured to work with this data structure.
|
||||||
286
benchmarks/diskann_vs_hnsw_speed_comparison.py
Normal file
286
benchmarks/diskann_vs_hnsw_speed_comparison.py
Normal file
@@ -0,0 +1,286 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
DiskANN vs HNSW Search Performance Comparison
|
||||||
|
|
||||||
|
This benchmark compares search performance between DiskANN and HNSW backends:
|
||||||
|
- DiskANN: With graph partitioning enabled (is_recompute=True)
|
||||||
|
- HNSW: With recompute enabled (is_recompute=True)
|
||||||
|
- Tests performance across different dataset sizes
|
||||||
|
- Measures search latency, recall, and index size
|
||||||
|
"""
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import multiprocessing as mp
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# Prefer 'fork' start method to avoid POSIX semaphore leaks on macOS
|
||||||
|
try:
|
||||||
|
mp.set_start_method("fork", force=True)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def create_test_texts(n_docs: int) -> list[str]:
|
||||||
|
"""Create synthetic test documents for benchmarking."""
|
||||||
|
np.random.seed(42)
|
||||||
|
topics = [
|
||||||
|
"machine learning and artificial intelligence",
|
||||||
|
"natural language processing and text analysis",
|
||||||
|
"computer vision and image recognition",
|
||||||
|
"data science and statistical analysis",
|
||||||
|
"deep learning and neural networks",
|
||||||
|
"information retrieval and search engines",
|
||||||
|
"database systems and data management",
|
||||||
|
"software engineering and programming",
|
||||||
|
"cybersecurity and network protection",
|
||||||
|
"cloud computing and distributed systems",
|
||||||
|
]
|
||||||
|
|
||||||
|
texts = []
|
||||||
|
for i in range(n_docs):
|
||||||
|
topic = topics[i % len(topics)]
|
||||||
|
variation = np.random.randint(1, 100)
|
||||||
|
text = (
|
||||||
|
f"This is document {i} about {topic}. Content variation {variation}. "
|
||||||
|
f"Additional information about {topic} with details and examples. "
|
||||||
|
f"Technical discussion of {topic} including implementation aspects."
|
||||||
|
)
|
||||||
|
texts.append(text)
|
||||||
|
|
||||||
|
return texts
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_backend(
|
||||||
|
backend_name: str, texts: list[str], test_queries: list[str], backend_kwargs: dict[str, Any]
|
||||||
|
) -> dict[str, float]:
|
||||||
|
"""Benchmark a specific backend with the given configuration."""
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher
|
||||||
|
|
||||||
|
print(f"\n🔧 Testing {backend_name.upper()} backend...")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
index_path = str(Path(temp_dir) / f"benchmark_{backend_name}.leann")
|
||||||
|
|
||||||
|
# Build index
|
||||||
|
print(f"📦 Building {backend_name} index with {len(texts)} documents...")
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name=backend_name,
|
||||||
|
embedding_model="facebook/contriever",
|
||||||
|
embedding_mode="sentence-transformers",
|
||||||
|
**backend_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
for text in texts:
|
||||||
|
builder.add_text(text)
|
||||||
|
|
||||||
|
builder.build_index(index_path)
|
||||||
|
build_time = time.time() - start_time
|
||||||
|
|
||||||
|
# Measure index size
|
||||||
|
index_dir = Path(index_path).parent
|
||||||
|
index_files = list(index_dir.glob(f"{Path(index_path).stem}.*"))
|
||||||
|
total_size = sum(f.stat().st_size for f in index_files if f.is_file())
|
||||||
|
size_mb = total_size / (1024 * 1024)
|
||||||
|
|
||||||
|
print(f" ✅ Build completed in {build_time:.2f}s, index size: {size_mb:.1f}MB")
|
||||||
|
|
||||||
|
# Search benchmark
|
||||||
|
print("🔍 Running search benchmark...")
|
||||||
|
searcher = LeannSearcher(index_path)
|
||||||
|
|
||||||
|
search_times = []
|
||||||
|
all_results = []
|
||||||
|
|
||||||
|
for query in test_queries:
|
||||||
|
start_time = time.time()
|
||||||
|
results = searcher.search(query, top_k=5)
|
||||||
|
search_time = time.time() - start_time
|
||||||
|
search_times.append(search_time)
|
||||||
|
all_results.append(results)
|
||||||
|
|
||||||
|
avg_search_time = np.mean(search_times) * 1000 # Convert to ms
|
||||||
|
print(f" ✅ Average search time: {avg_search_time:.1f}ms")
|
||||||
|
|
||||||
|
# Check for valid scores (detect -inf issues)
|
||||||
|
all_scores = [
|
||||||
|
result.score
|
||||||
|
for results in all_results
|
||||||
|
for result in results
|
||||||
|
if result.score is not None
|
||||||
|
]
|
||||||
|
valid_scores = [
|
||||||
|
score for score in all_scores if score != float("-inf") and score != float("inf")
|
||||||
|
]
|
||||||
|
score_validity_rate = len(valid_scores) / len(all_scores) if all_scores else 0
|
||||||
|
|
||||||
|
# Clean up (ensure embedding server shutdown and object GC)
|
||||||
|
try:
|
||||||
|
if hasattr(searcher, "cleanup"):
|
||||||
|
searcher.cleanup()
|
||||||
|
del searcher
|
||||||
|
del builder
|
||||||
|
gc.collect()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"⚠️ Warning: Resource cleanup error: {e}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"build_time": build_time,
|
||||||
|
"avg_search_time_ms": avg_search_time,
|
||||||
|
"index_size_mb": size_mb,
|
||||||
|
"score_validity_rate": score_validity_rate,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def run_comparison(n_docs: int = 500, n_queries: int = 10):
|
||||||
|
"""Run performance comparison between DiskANN and HNSW."""
|
||||||
|
print("🚀 Starting DiskANN vs HNSW Performance Comparison")
|
||||||
|
print(f"📊 Dataset: {n_docs} documents, {n_queries} test queries")
|
||||||
|
|
||||||
|
# Create test data
|
||||||
|
texts = create_test_texts(n_docs)
|
||||||
|
test_queries = [
|
||||||
|
"machine learning algorithms",
|
||||||
|
"natural language processing",
|
||||||
|
"computer vision techniques",
|
||||||
|
"data analysis methods",
|
||||||
|
"neural network architectures",
|
||||||
|
"database query optimization",
|
||||||
|
"software development practices",
|
||||||
|
"security vulnerabilities",
|
||||||
|
"cloud infrastructure",
|
||||||
|
"distributed computing",
|
||||||
|
][:n_queries]
|
||||||
|
|
||||||
|
# HNSW benchmark
|
||||||
|
hnsw_results = benchmark_backend(
|
||||||
|
backend_name="hnsw",
|
||||||
|
texts=texts,
|
||||||
|
test_queries=test_queries,
|
||||||
|
backend_kwargs={
|
||||||
|
"is_recompute": True, # Enable recompute for fair comparison
|
||||||
|
"M": 16,
|
||||||
|
"efConstruction": 200,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# DiskANN benchmark
|
||||||
|
diskann_results = benchmark_backend(
|
||||||
|
backend_name="diskann",
|
||||||
|
texts=texts,
|
||||||
|
test_queries=test_queries,
|
||||||
|
backend_kwargs={
|
||||||
|
"is_recompute": True, # Enable graph partitioning
|
||||||
|
"num_neighbors": 32,
|
||||||
|
"search_list_size": 50,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Performance comparison
|
||||||
|
print("\n📈 Performance Comparison Results")
|
||||||
|
print(f"{'=' * 60}")
|
||||||
|
print(f"{'Metric':<25} {'HNSW':<15} {'DiskANN':<15} {'Speedup':<10}")
|
||||||
|
print(f"{'-' * 60}")
|
||||||
|
|
||||||
|
# Build time comparison
|
||||||
|
build_speedup = hnsw_results["build_time"] / diskann_results["build_time"]
|
||||||
|
print(
|
||||||
|
f"{'Build Time (s)':<25} {hnsw_results['build_time']:<15.2f} {diskann_results['build_time']:<15.2f} {build_speedup:<10.2f}x"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Search time comparison
|
||||||
|
search_speedup = hnsw_results["avg_search_time_ms"] / diskann_results["avg_search_time_ms"]
|
||||||
|
print(
|
||||||
|
f"{'Search Time (ms)':<25} {hnsw_results['avg_search_time_ms']:<15.1f} {diskann_results['avg_search_time_ms']:<15.1f} {search_speedup:<10.2f}x"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Index size comparison
|
||||||
|
size_ratio = diskann_results["index_size_mb"] / hnsw_results["index_size_mb"]
|
||||||
|
print(
|
||||||
|
f"{'Index Size (MB)':<25} {hnsw_results['index_size_mb']:<15.1f} {diskann_results['index_size_mb']:<15.1f} {size_ratio:<10.2f}x"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Score validity
|
||||||
|
print(
|
||||||
|
f"{'Score Validity (%)':<25} {hnsw_results['score_validity_rate'] * 100:<15.1f} {diskann_results['score_validity_rate'] * 100:<15.1f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"{'=' * 60}")
|
||||||
|
print("\n🎯 Summary:")
|
||||||
|
if search_speedup > 1:
|
||||||
|
print(f" DiskANN is {search_speedup:.2f}x faster than HNSW for search")
|
||||||
|
else:
|
||||||
|
print(f" HNSW is {1 / search_speedup:.2f}x faster than DiskANN for search")
|
||||||
|
|
||||||
|
if size_ratio > 1:
|
||||||
|
print(f" DiskANN uses {size_ratio:.2f}x more storage than HNSW")
|
||||||
|
else:
|
||||||
|
print(f" DiskANN uses {1 / size_ratio:.2f}x less storage than HNSW")
|
||||||
|
|
||||||
|
print(
|
||||||
|
f" Both backends achieved {min(hnsw_results['score_validity_rate'], diskann_results['score_validity_rate']) * 100:.1f}% score validity"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import sys
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Handle help request
|
||||||
|
if len(sys.argv) > 1 and sys.argv[1] in ["-h", "--help", "help"]:
|
||||||
|
print("DiskANN vs HNSW Performance Comparison")
|
||||||
|
print("=" * 50)
|
||||||
|
print(f"Usage: python {sys.argv[0]} [n_docs] [n_queries]")
|
||||||
|
print()
|
||||||
|
print("Arguments:")
|
||||||
|
print(" n_docs Number of documents to index (default: 500)")
|
||||||
|
print(" n_queries Number of test queries to run (default: 10)")
|
||||||
|
print()
|
||||||
|
print("Examples:")
|
||||||
|
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py")
|
||||||
|
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 1000")
|
||||||
|
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20")
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
# Parse command line arguments
|
||||||
|
n_docs = int(sys.argv[1]) if len(sys.argv) > 1 else 500
|
||||||
|
n_queries = int(sys.argv[2]) if len(sys.argv) > 2 else 10
|
||||||
|
|
||||||
|
print("DiskANN vs HNSW Performance Comparison")
|
||||||
|
print("=" * 50)
|
||||||
|
print(f"Dataset: {n_docs} documents, {n_queries} queries")
|
||||||
|
print()
|
||||||
|
|
||||||
|
run_comparison(n_docs=n_docs, n_queries=n_queries)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n⚠️ Benchmark interrupted by user")
|
||||||
|
sys.exit(130)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ Benchmark failed: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
finally:
|
||||||
|
# Ensure clean exit (forceful to prevent rare hangs from atexit/threads)
|
||||||
|
try:
|
||||||
|
gc.collect()
|
||||||
|
print("\n🧹 Cleanup completed")
|
||||||
|
# Flush stdio to ensure message is visible before hard-exit
|
||||||
|
try:
|
||||||
|
import sys as _sys
|
||||||
|
|
||||||
|
_sys.stdout.flush()
|
||||||
|
_sys.stderr.flush()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
# Use os._exit to bypass atexit handlers that may hang in rare cases
|
||||||
|
import os as _os
|
||||||
|
|
||||||
|
_os._exit(0)
|
||||||
@@ -12,7 +12,7 @@ import time
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from leann.api import LeannBuilder, LeannSearcher
|
from leann.api import LeannBuilder, LeannChat, LeannSearcher
|
||||||
|
|
||||||
|
|
||||||
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
||||||
@@ -197,6 +197,25 @@ def main():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
|
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch-size",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="Batch size for HNSW batched search (0 disables batching)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--llm-type",
|
||||||
|
type=str,
|
||||||
|
choices=["ollama", "hf", "openai", "gemini", "simulated"],
|
||||||
|
default="ollama",
|
||||||
|
help="LLM backend type to optionally query during evaluation (default: ollama)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--llm-model",
|
||||||
|
type=str,
|
||||||
|
default="qwen3:1.7b",
|
||||||
|
help="LLM model identifier for the chosen backend (default: qwen3:1.7b)",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# --- Path Configuration ---
|
# --- Path Configuration ---
|
||||||
@@ -318,9 +337,24 @@ def main():
|
|||||||
|
|
||||||
for i in range(num_eval_queries):
|
for i in range(num_eval_queries):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
new_results = searcher.search(queries[i], top_k=args.top_k, ef=args.ef_search)
|
new_results = searcher.search(
|
||||||
|
queries[i],
|
||||||
|
top_k=args.top_k,
|
||||||
|
complexity=args.ef_search,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
)
|
||||||
search_times.append(time.time() - start_time)
|
search_times.append(time.time() - start_time)
|
||||||
|
|
||||||
|
# Optional: also call the LLM with configurable backend/model (does not affect recall)
|
||||||
|
llm_config = {"type": args.llm_type, "model": args.llm_model}
|
||||||
|
chat = LeannChat(args.index_path, llm_config=llm_config, searcher=searcher)
|
||||||
|
answer = chat.ask(
|
||||||
|
queries[i],
|
||||||
|
top_k=args.top_k,
|
||||||
|
complexity=args.ef_search,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
)
|
||||||
|
print(f"Answer: {answer}")
|
||||||
# Correct Recall Calculation: Based on TEXT content
|
# Correct Recall Calculation: Based on TEXT content
|
||||||
new_texts = {result.text for result in new_results}
|
new_texts = {result.text for result in new_results}
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ except ImportError:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BenchmarkConfig:
|
class BenchmarkConfig:
|
||||||
model_path: str = "facebook/contriever"
|
model_path: str = "facebook/contriever-msmarco"
|
||||||
batch_sizes: list[int] = None
|
batch_sizes: list[int] = None
|
||||||
seq_length: int = 256
|
seq_length: int = 256
|
||||||
num_runs: int = 5
|
num_runs: int = 5
|
||||||
@@ -34,7 +34,7 @@ class BenchmarkConfig:
|
|||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.batch_sizes is None:
|
if self.batch_sizes is None:
|
||||||
self.batch_sizes = [1, 2, 4, 8, 16, 32, 64]
|
self.batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
|
||||||
|
|
||||||
|
|
||||||
class MLXBenchmark:
|
class MLXBenchmark:
|
||||||
@@ -179,10 +179,16 @@ class Benchmark:
|
|||||||
|
|
||||||
def _run_inference(self, input_ids: torch.Tensor) -> float:
|
def _run_inference(self, input_ids: torch.Tensor) -> float:
|
||||||
attention_mask = torch.ones_like(input_ids)
|
attention_mask = torch.ones_like(input_ids)
|
||||||
|
# print shape of input_ids and attention_mask
|
||||||
|
print(f"input_ids shape: {input_ids.shape}")
|
||||||
|
print(f"attention_mask shape: {attention_mask.shape}")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self.model(input_ids=input_ids, attention_mask=attention_mask)
|
self.model(input_ids=input_ids, attention_mask=attention_mask)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
torch.mps.synchronize()
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
|
|
||||||
return end_time - start_time
|
return end_time - start_time
|
||||||
|
|||||||
143
docs/ast_chunking_guide.md
Normal file
143
docs/ast_chunking_guide.md
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
# AST-Aware Code chunking guide
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
This guide covers best practices for using AST-aware code chunking in LEANN. AST chunking provides better semantic understanding of code structure compared to traditional text-based chunking.
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### Basic Usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Enable AST chunking for mixed content (code + docs)
|
||||||
|
python -m apps.document_rag --enable-code-chunking --data-dir ./my_project
|
||||||
|
|
||||||
|
# Specialized code repository indexing
|
||||||
|
python -m apps.code_rag --repo-dir ./my_codebase
|
||||||
|
|
||||||
|
# Global CLI with AST support
|
||||||
|
leann build my-code-index --docs ./src --use-ast-chunking
|
||||||
|
```
|
||||||
|
|
||||||
|
### Installation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Install LEANN with AST chunking support
|
||||||
|
uv pip install -e "."
|
||||||
|
```
|
||||||
|
|
||||||
|
#### For normal users (PyPI install)
|
||||||
|
- Use `pip install leann` or `uv pip install leann`.
|
||||||
|
- `astchunk` is pulled automatically from PyPI as a dependency; no extra steps.
|
||||||
|
|
||||||
|
#### For developers (from source, editable)
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/yichuan-w/LEANN.git leann
|
||||||
|
cd leann
|
||||||
|
git submodule update --init --recursive
|
||||||
|
uv sync
|
||||||
|
```
|
||||||
|
- This repo vendors `astchunk` as a git submodule at `packages/astchunk-leann` (our fork).
|
||||||
|
- `[tool.uv.sources]` maps the `astchunk` package to that path in editable mode.
|
||||||
|
- You can edit code under `packages/astchunk-leann` and Python will use your changes immediately (no separate `pip install astchunk` needed).
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
### When to Use AST Chunking
|
||||||
|
|
||||||
|
✅ **Recommended for:**
|
||||||
|
- Code repositories with multiple languages
|
||||||
|
- Mixed documentation and code content
|
||||||
|
- Complex codebases with deep function/class hierarchies
|
||||||
|
- When working with Claude Code for code assistance
|
||||||
|
|
||||||
|
❌ **Not recommended for:**
|
||||||
|
- Pure text documents
|
||||||
|
- Very large files (>1MB)
|
||||||
|
- Languages not supported by tree-sitter
|
||||||
|
|
||||||
|
### Optimal Configuration
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Recommended settings for most codebases
|
||||||
|
python -m apps.code_rag \
|
||||||
|
--repo-dir ./src \
|
||||||
|
--ast-chunk-size 768 \
|
||||||
|
--ast-chunk-overlap 96 \
|
||||||
|
--exclude-dirs .git __pycache__ node_modules build dist
|
||||||
|
```
|
||||||
|
|
||||||
|
### Supported Languages
|
||||||
|
|
||||||
|
| Extension | Language | Status |
|
||||||
|
|-----------|----------|--------|
|
||||||
|
| `.py` | Python | ✅ Full support |
|
||||||
|
| `.java` | Java | ✅ Full support |
|
||||||
|
| `.cs` | C# | ✅ Full support |
|
||||||
|
| `.ts`, `.tsx` | TypeScript | ✅ Full support |
|
||||||
|
| `.js`, `.jsx` | JavaScript | ✅ Via TypeScript parser |
|
||||||
|
|
||||||
|
## Integration Examples
|
||||||
|
|
||||||
|
### Document RAG with Code Support
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Enable code chunking in document RAG
|
||||||
|
python -m apps.document_rag \
|
||||||
|
--enable-code-chunking \
|
||||||
|
--data-dir ./project \
|
||||||
|
--query "How does authentication work in the codebase?"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Claude Code Integration
|
||||||
|
|
||||||
|
When using with Claude Code MCP server, AST chunking provides better context for:
|
||||||
|
- Code completion and suggestions
|
||||||
|
- Bug analysis and debugging
|
||||||
|
- Architecture understanding
|
||||||
|
- Refactoring assistance
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Common Issues
|
||||||
|
|
||||||
|
1. **Fallback to Traditional Chunking**
|
||||||
|
- Normal behavior for unsupported languages
|
||||||
|
- Check logs for specific language support
|
||||||
|
|
||||||
|
2. **Performance with Large Files**
|
||||||
|
- Adjust `--max-file-size` parameter
|
||||||
|
- Use `--exclude-dirs` to skip unnecessary directories
|
||||||
|
|
||||||
|
3. **Quality Issues**
|
||||||
|
- Try different `--ast-chunk-size` values (512, 768, 1024)
|
||||||
|
- Adjust overlap for better context preservation
|
||||||
|
|
||||||
|
### Debug Mode
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export LEANN_LOG_LEVEL=DEBUG
|
||||||
|
python -m apps.code_rag --repo-dir ./my_code
|
||||||
|
```
|
||||||
|
|
||||||
|
## Migration from Traditional Chunking
|
||||||
|
|
||||||
|
Existing workflows continue to work without changes. To enable AST chunking:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Before
|
||||||
|
python -m apps.document_rag --chunk-size 256
|
||||||
|
|
||||||
|
# After (maintains traditional chunking for non-code files)
|
||||||
|
python -m apps.document_rag --enable-code-chunking --chunk-size 256 --ast-chunk-size 768
|
||||||
|
```
|
||||||
|
|
||||||
|
## References
|
||||||
|
|
||||||
|
- [astchunk GitHub Repository](https://github.com/yilinjz/astchunk)
|
||||||
|
- [LEANN MCP Integration](../packages/leann-mcp/README.md)
|
||||||
|
- [Research Paper](https://arxiv.org/html/2506.15655v1)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**Note**: AST chunking maintains full backward compatibility while enhancing code understanding capabilities.
|
||||||
@@ -52,7 +52,7 @@ Based on our experience developing LEANN, embedding models fall into three categ
|
|||||||
### Quick Start: Cloud and Local Embedding Options
|
### Quick Start: Cloud and Local Embedding Options
|
||||||
|
|
||||||
**OpenAI Embeddings (Fastest Setup)**
|
**OpenAI Embeddings (Fastest Setup)**
|
||||||
For immediate testing without local model downloads:
|
For immediate testing without local model downloads(also if you [do not have GPU](https://github.com/yichuan-w/LEANN/issues/43) and do not care that much about your document leak, you should use this, we compute the embedding and recompute using openai API):
|
||||||
```bash
|
```bash
|
||||||
# Set OpenAI embeddings (requires OPENAI_API_KEY)
|
# Set OpenAI embeddings (requires OPENAI_API_KEY)
|
||||||
--embedding-mode openai --embedding-model text-embedding-3-small
|
--embedding-mode openai --embedding-model text-embedding-3-small
|
||||||
@@ -97,16 +97,24 @@ ollama pull nomic-embed-text
|
|||||||
```
|
```
|
||||||
|
|
||||||
### DiskANN
|
### DiskANN
|
||||||
**Best for**: Large datasets (> 10M vectors, 10GB+ index size) - **⚠️ Beta version, still in active development**
|
**Best for**: Large datasets, especially when you want `recompute=True`.
|
||||||
- Uses Product Quantization (PQ) for coarse filtering during graph traversal
|
|
||||||
- Novel approach: stores only PQ codes, performs rerank with exact computation in final step
|
**Key advantages:**
|
||||||
- Implements a corner case of double-queue: prunes all neighbors and recomputes at the end
|
- **Faster search** on large datasets (3x+ speedup vs HNSW in many cases)
|
||||||
|
- **Smart storage**: `recompute=True` enables automatic graph partitioning for smaller indexes
|
||||||
|
- **Better scaling**: Designed for 100k+ documents
|
||||||
|
|
||||||
|
**Recompute behavior:**
|
||||||
|
- `recompute=True` (recommended): Pure PQ traversal + final reranking - faster and enables partitioning
|
||||||
|
- `recompute=False`: PQ + partial real distances during traversal - slower but higher accuracy
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# For billion-scale deployments
|
# Recommended for most use cases
|
||||||
--backend-name diskann --graph-degree 64 --build-complexity 128
|
--backend-name diskann --graph-degree 32 --build-complexity 64
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**Performance Benchmark**: Run `uv run benchmarks/diskann_vs_hnsw_speed_comparison.py` to compare DiskANN and HNSW on your system.
|
||||||
|
|
||||||
## LLM Selection: Engine and Model Comparison
|
## LLM Selection: Engine and Model Comparison
|
||||||
|
|
||||||
### LLM Engines
|
### LLM Engines
|
||||||
@@ -259,27 +267,118 @@ Every configuration choice involves trade-offs:
|
|||||||
|
|
||||||
The key is finding the right balance for your specific use case. Start small and simple, measure performance, then scale up only where needed.
|
The key is finding the right balance for your specific use case. Start small and simple, measure performance, then scale up only where needed.
|
||||||
|
|
||||||
## Deep Dive: Critical Configuration Decisions
|
## Low-resource setups
|
||||||
|
|
||||||
### When to Disable Recomputation
|
If you don’t have a local GPU or builds/searches are too slow, use one or more of the options below.
|
||||||
|
|
||||||
LEANN's recomputation feature provides exact distance calculations but can be disabled for extreme QPS requirements:
|
### 1) Use OpenAI embeddings (no local compute)
|
||||||
|
|
||||||
|
Fastest path with zero local GPU requirements. Set your API key and use OpenAI embeddings during build and search:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
--no-recompute # Disable selective recomputation
|
export OPENAI_API_KEY=sk-...
|
||||||
|
|
||||||
|
# Build with OpenAI embeddings
|
||||||
|
leann build my-index \
|
||||||
|
--embedding-mode openai \
|
||||||
|
--embedding-model text-embedding-3-small
|
||||||
|
|
||||||
|
# Search with OpenAI embeddings (recompute at query time)
|
||||||
|
leann search my-index "your query" \
|
||||||
|
--recompute
|
||||||
```
|
```
|
||||||
|
|
||||||
**Trade-offs**:
|
### 2) Run remote builds with SkyPilot (cloud GPU)
|
||||||
- **With recomputation** (default): Exact distances, best quality, higher latency, minimal storage (only stores metadata, recomputes embeddings on-demand)
|
|
||||||
- **Without recomputation**: Must store full embeddings, significantly higher memory and storage usage (10-100x more), but faster search
|
Offload embedding generation and index building to a GPU VM using [SkyPilot](https://skypilot.readthedocs.io/en/latest/). A template is provided at `sky/leann-build.yaml`.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# One-time: install and configure SkyPilot
|
||||||
|
pip install skypilot
|
||||||
|
|
||||||
|
# Launch with defaults (L4:1) and mount ./data to ~/leann-data; the build runs automatically
|
||||||
|
sky launch -c leann-gpu sky/leann-build.yaml
|
||||||
|
|
||||||
|
# Override parameters via -e key=value (optional)
|
||||||
|
sky launch -c leann-gpu sky/leann-build.yaml \
|
||||||
|
-e index_name=my-index \
|
||||||
|
-e backend=hnsw \
|
||||||
|
-e embedding_mode=sentence-transformers \
|
||||||
|
-e embedding_model=Qwen/Qwen3-Embedding-0.6B
|
||||||
|
|
||||||
|
# Copy the built index back to your local .leann (use rsync)
|
||||||
|
rsync -Pavz leann-gpu:~/.leann/indexes/my-index ./.leann/indexes/
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3) Disable recomputation to trade storage for speed
|
||||||
|
|
||||||
|
If you need lower latency and have more storage/memory, disable recomputation. This stores full embeddings and avoids recomputing at search time.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Build without recomputation (HNSW requires non-compact in this mode)
|
||||||
|
leann build my-index --no-recompute --no-compact
|
||||||
|
|
||||||
|
# Search without recomputation
|
||||||
|
leann search my-index "your query" --no-recompute
|
||||||
|
```
|
||||||
|
|
||||||
|
When to use:
|
||||||
|
- Extreme low latency requirements (high QPS, interactive assistants)
|
||||||
|
- Read-heavy workloads where storage is cheaper than latency
|
||||||
|
- No always-available GPU
|
||||||
|
|
||||||
|
Constraints:
|
||||||
|
- HNSW: when `--no-recompute` is set, LEANN automatically disables compact mode during build
|
||||||
|
- DiskANN: supported; `--no-recompute` skips selective recompute during search
|
||||||
|
|
||||||
|
Storage impact:
|
||||||
|
- Storing N embeddings of dimension D with float32 requires approximately N × D × 4 bytes
|
||||||
|
- Example: 1,000,000 chunks × 768 dims × 4 bytes ≈ 2.86 GB (plus graph/metadata)
|
||||||
|
|
||||||
|
Converting an existing index (rebuild required):
|
||||||
|
```bash
|
||||||
|
# Rebuild in-place (ensure you still have original docs or can regenerate chunks)
|
||||||
|
leann build my-index --force --no-recompute --no-compact
|
||||||
|
```
|
||||||
|
|
||||||
|
Python API usage:
|
||||||
|
```python
|
||||||
|
from leann import LeannSearcher
|
||||||
|
|
||||||
|
searcher = LeannSearcher("/path/to/my-index.leann")
|
||||||
|
results = searcher.search("your query", top_k=10, recompute_embeddings=False)
|
||||||
|
```
|
||||||
|
|
||||||
|
Trade-offs:
|
||||||
|
- Lower latency and fewer network hops at query time
|
||||||
|
- Significantly higher storage (10–100× vs selective recomputation)
|
||||||
|
- Slightly larger memory footprint during build and search
|
||||||
|
|
||||||
|
Quick benchmark results (`benchmarks/benchmark_no_recompute.py` with 5k texts, complexity=32):
|
||||||
|
|
||||||
|
- HNSW
|
||||||
|
|
||||||
|
```text
|
||||||
|
recompute=True: search_time=0.818s, size=1.1MB
|
||||||
|
recompute=False: search_time=0.012s, size=16.6MB
|
||||||
|
```
|
||||||
|
|
||||||
|
- DiskANN
|
||||||
|
|
||||||
|
```text
|
||||||
|
recompute=True: search_time=0.041s, size=5.9MB
|
||||||
|
recompute=False: search_time=0.013s, size=24.6MB
|
||||||
|
```
|
||||||
|
|
||||||
|
Conclusion:
|
||||||
|
- **HNSW**: `no-recompute` is significantly faster (no embedding recomputation) but requires much more storage (stores all embeddings)
|
||||||
|
- **DiskANN**: `no-recompute` uses PQ + partial real distances during traversal (slower but higher accuracy), while `recompute=True` uses pure PQ traversal + final reranking (faster traversal, enables build-time partitioning for smaller storage)
|
||||||
|
|
||||||
|
|
||||||
**Disable when**:
|
|
||||||
- You have abundant storage and memory
|
|
||||||
- Need extremely low latency (< 100ms)
|
|
||||||
- Running a read-heavy workload where storage cost is acceptable
|
|
||||||
|
|
||||||
## Further Reading
|
## Further Reading
|
||||||
|
|
||||||
- [Lessons Learned Developing LEANN](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
|
- [Lessons Learned Developing LEANN](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
|
||||||
- [LEANN Technical Paper](https://arxiv.org/abs/2506.08276)
|
- [LEANN Technical Paper](https://arxiv.org/abs/2506.08276)
|
||||||
- [DiskANN Original Paper](https://papers.nips.cc/paper/2019/file/09853c7fb1d3f8ee67a61b6bf4a7f8e6-Paper.pdf)
|
- [DiskANN Original Paper](https://papers.nips.cc/paper/2019/file/09853c7fb1d3f8ee67a61b6bf4a7f8e6-Paper.pdf)
|
||||||
|
- [SSD-based Graph Partitioning](https://github.com/SonglinLife/SSD_BASED_PLAN)
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
## 🔥 Core Features
|
## 🔥 Core Features
|
||||||
|
|
||||||
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage with dynamic computation using optimized ZMQ servers and highly optimized search paradigm (overlapping and batching) with highly optimized embedding engine
|
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage with dynamic computation using optimized ZMQ servers and highly optimized search paradigm (overlapping and batching) with highly optimized embedding engine
|
||||||
|
- **🧠 AST-Aware Code Chunking** - Intelligent code chunking that preserves semantic boundaries (functions, classes, methods) for Python, Java, C#, and TypeScript files
|
||||||
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
|
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
|
||||||
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
|
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
|
||||||
- **🏗️ Pluggable Backends** - HNSW/FAISS (default), with optional DiskANN for large-scale deployments
|
- **🏗️ Pluggable Backends** - HNSW/FAISS (default), with optional DiskANN for large-scale deployments
|
||||||
|
|||||||
149
docs/grep_search.md
Normal file
149
docs/grep_search.md
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
# LEANN Grep Search Usage Guide
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
LEANN's grep search functionality provides exact text matching for finding specific code patterns, error messages, function names, or exact phrases in your indexed documents.
|
||||||
|
|
||||||
|
## Basic Usage
|
||||||
|
|
||||||
|
### Simple Grep Search
|
||||||
|
|
||||||
|
```python
|
||||||
|
from leann.api import LeannSearcher
|
||||||
|
|
||||||
|
searcher = LeannSearcher("your_index_path")
|
||||||
|
|
||||||
|
# Exact text search
|
||||||
|
results = searcher.search("def authenticate_user", use_grep=True, top_k=5)
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
print(f"Score: {result.score}")
|
||||||
|
print(f"Text: {result.text[:100]}...")
|
||||||
|
print("-" * 40)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Comparison: Semantic vs Grep Search
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Semantic search - finds conceptually similar content
|
||||||
|
semantic_results = searcher.search("machine learning algorithms", top_k=3)
|
||||||
|
|
||||||
|
# Grep search - finds exact text matches
|
||||||
|
grep_results = searcher.search("def train_model", use_grep=True, top_k=3)
|
||||||
|
```
|
||||||
|
|
||||||
|
## When to Use Grep Search
|
||||||
|
|
||||||
|
### Use Cases
|
||||||
|
|
||||||
|
- **Code Search**: Finding specific function definitions, class names, or variable references
|
||||||
|
- **Error Debugging**: Locating exact error messages or stack traces
|
||||||
|
- **Documentation**: Finding specific API endpoints or exact terminology
|
||||||
|
|
||||||
|
### Examples
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Find function definitions
|
||||||
|
functions = searcher.search("def __init__", use_grep=True)
|
||||||
|
|
||||||
|
# Find import statements
|
||||||
|
imports = searcher.search("from sklearn import", use_grep=True)
|
||||||
|
|
||||||
|
# Find specific error types
|
||||||
|
errors = searcher.search("FileNotFoundError", use_grep=True)
|
||||||
|
|
||||||
|
# Find TODO comments
|
||||||
|
todos = searcher.search("TODO:", use_grep=True)
|
||||||
|
|
||||||
|
# Find configuration entries
|
||||||
|
configs = searcher.search("server_port=", use_grep=True)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Technical Details
|
||||||
|
|
||||||
|
### How It Works
|
||||||
|
|
||||||
|
1. **File Location**: Grep search operates on the raw text stored in `.jsonl` files
|
||||||
|
2. **Command Execution**: Uses the system `grep` command with case-insensitive search
|
||||||
|
3. **Result Processing**: Parses JSON lines and extracts text and metadata
|
||||||
|
4. **Scoring**: Simple frequency-based scoring based on query term occurrences
|
||||||
|
|
||||||
|
### Search Process
|
||||||
|
|
||||||
|
```
|
||||||
|
Query: "def train_model"
|
||||||
|
↓
|
||||||
|
grep -i -n "def train_model" documents.leann.passages.jsonl
|
||||||
|
↓
|
||||||
|
Parse matching JSON lines
|
||||||
|
↓
|
||||||
|
Calculate scores based on term frequency
|
||||||
|
↓
|
||||||
|
Return top_k results
|
||||||
|
```
|
||||||
|
|
||||||
|
### Scoring Algorithm
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Term frequency in document
|
||||||
|
score = text.lower().count(query.lower())
|
||||||
|
```
|
||||||
|
|
||||||
|
Results are ranked by score (highest first), with higher scores indicating more occurrences of the search term.
|
||||||
|
|
||||||
|
## Error Handling
|
||||||
|
|
||||||
|
### Common Issues
|
||||||
|
|
||||||
|
#### Grep Command Not Found
|
||||||
|
```
|
||||||
|
RuntimeError: grep command not found. Please install grep or use semantic search.
|
||||||
|
```
|
||||||
|
|
||||||
|
**Solution**: Install grep on your system:
|
||||||
|
- **Ubuntu/Debian**: `sudo apt-get install grep`
|
||||||
|
- **macOS**: grep is pre-installed
|
||||||
|
- **Windows**: Use WSL or install grep via Git Bash/MSYS2
|
||||||
|
|
||||||
|
#### No Results Found
|
||||||
|
```python
|
||||||
|
# Check if your query exists in the raw data
|
||||||
|
results = searcher.search("your_query", use_grep=True)
|
||||||
|
if not results:
|
||||||
|
print("No exact matches found. Try:")
|
||||||
|
print("1. Check spelling and case")
|
||||||
|
print("2. Use partial terms")
|
||||||
|
print("3. Switch to semantic search")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Complete Example
|
||||||
|
|
||||||
|
```python
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Grep Search Example
|
||||||
|
Demonstrates grep search for exact text matching.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from leann.api import LeannSearcher
|
||||||
|
|
||||||
|
def demonstrate_grep_search():
|
||||||
|
# Initialize searcher
|
||||||
|
searcher = LeannSearcher("my_index")
|
||||||
|
|
||||||
|
print("=== Function Search ===")
|
||||||
|
functions = searcher.search("def __init__", use_grep=True, top_k=5)
|
||||||
|
for i, result in enumerate(functions, 1):
|
||||||
|
print(f"{i}. Score: {result.score}")
|
||||||
|
print(f" Preview: {result.text[:60]}...")
|
||||||
|
print()
|
||||||
|
|
||||||
|
print("=== Error Search ===")
|
||||||
|
errors = searcher.search("FileNotFoundError", use_grep=True, top_k=3)
|
||||||
|
for result in errors:
|
||||||
|
print(f"Content: {result.text.strip()}")
|
||||||
|
print("-" * 40)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
demonstrate_grep_search()
|
||||||
|
```
|
||||||
300
docs/metadata_filtering.md
Normal file
300
docs/metadata_filtering.md
Normal file
@@ -0,0 +1,300 @@
|
|||||||
|
# LEANN Metadata Filtering Usage Guide
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
Leann possesses metadata filtering capabilities that allow you to filter search results based on arbitrary metadata fields set during chunking. This feature enables use cases like spoiler-free book search, document filtering by date/type, code search by file type, and potentially much more.
|
||||||
|
|
||||||
|
## Basic Usage
|
||||||
|
|
||||||
|
### Adding Metadata to Your Documents
|
||||||
|
|
||||||
|
When building your index, add metadata to each text chunk:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from leann.api import LeannBuilder
|
||||||
|
|
||||||
|
builder = LeannBuilder("hnsw")
|
||||||
|
|
||||||
|
# Add text with metadata
|
||||||
|
builder.add_text(
|
||||||
|
text="Chapter 1: Alice falls down the rabbit hole",
|
||||||
|
metadata={
|
||||||
|
"chapter": 1,
|
||||||
|
"character": "Alice",
|
||||||
|
"themes": ["adventure", "curiosity"],
|
||||||
|
"word_count": 150
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
builder.build_index("alice_in_wonderland_index")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Searching with Metadata Filters
|
||||||
|
|
||||||
|
Use the `metadata_filters` parameter in search calls:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from leann.api import LeannSearcher
|
||||||
|
|
||||||
|
searcher = LeannSearcher("alice_in_wonderland_index")
|
||||||
|
|
||||||
|
# Search with filters
|
||||||
|
results = searcher.search(
|
||||||
|
query="What happens to Alice?",
|
||||||
|
top_k=10,
|
||||||
|
metadata_filters={
|
||||||
|
"chapter": {"<=": 5}, # Only chapters 1-5
|
||||||
|
"spoiler_level": {"!=": "high"} # No high spoilers
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Filter Syntax
|
||||||
|
|
||||||
|
### Basic Structure
|
||||||
|
|
||||||
|
```python
|
||||||
|
metadata_filters = {
|
||||||
|
"field_name": {"operator": value},
|
||||||
|
"another_field": {"operator": value}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Supported Operators
|
||||||
|
|
||||||
|
#### Comparison Operators
|
||||||
|
- `"=="`: Equal to
|
||||||
|
- `"!="`: Not equal to
|
||||||
|
- `"<"`: Less than
|
||||||
|
- `"<="`: Less than or equal
|
||||||
|
- `">"`: Greater than
|
||||||
|
- `">="`: Greater than or equal
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Examples
|
||||||
|
{"chapter": {"==": 1}} # Exactly chapter 1
|
||||||
|
{"page": {">": 100}} # Pages after 100
|
||||||
|
{"rating": {">=": 4.0}} # Rating 4.0 or higher
|
||||||
|
{"word_count": {"<": 500}} # Short passages
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Membership Operators
|
||||||
|
- `"in"`: Value is in list
|
||||||
|
- `"not_in"`: Value is not in list
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Examples
|
||||||
|
{"character": {"in": ["Alice", "Bob"]}} # Alice OR Bob
|
||||||
|
{"genre": {"not_in": ["horror", "thriller"]}} # Exclude genres
|
||||||
|
{"tags": {"in": ["fiction", "adventure"]}} # Any of these tags
|
||||||
|
```
|
||||||
|
|
||||||
|
#### String Operators
|
||||||
|
- `"contains"`: String contains substring
|
||||||
|
- `"starts_with"`: String starts with prefix
|
||||||
|
- `"ends_with"`: String ends with suffix
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Examples
|
||||||
|
{"title": {"contains": "alice"}} # Title contains "alice"
|
||||||
|
{"filename": {"ends_with": ".py"}} # Python files
|
||||||
|
{"author": {"starts_with": "Dr."}} # Authors with "Dr." prefix
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Boolean Operators
|
||||||
|
- `"is_true"`: Field is truthy
|
||||||
|
- `"is_false"`: Field is falsy
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Examples
|
||||||
|
{"is_published": {"is_true": True}} # Published content
|
||||||
|
{"is_draft": {"is_false": False}} # Not drafts
|
||||||
|
```
|
||||||
|
|
||||||
|
### Multiple Operators on Same Field
|
||||||
|
|
||||||
|
You can apply multiple operators to the same field (AND logic):
|
||||||
|
|
||||||
|
```python
|
||||||
|
metadata_filters = {
|
||||||
|
"word_count": {
|
||||||
|
">=": 100, # At least 100 words
|
||||||
|
"<=": 500 # At most 500 words
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Compound Filters
|
||||||
|
|
||||||
|
Multiple fields are combined with AND logic:
|
||||||
|
|
||||||
|
```python
|
||||||
|
metadata_filters = {
|
||||||
|
"chapter": {"<=": 10}, # Up to chapter 10
|
||||||
|
"character": {"==": "Alice"}, # About Alice
|
||||||
|
"spoiler_level": {"!=": "high"} # No major spoilers
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Use Case Examples
|
||||||
|
|
||||||
|
### 1. Spoiler-Free Book Search
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Reader has only read up to chapter 5
|
||||||
|
def search_spoiler_free(query, max_chapter):
|
||||||
|
return searcher.search(
|
||||||
|
query=query,
|
||||||
|
metadata_filters={
|
||||||
|
"chapter": {"<=": max_chapter},
|
||||||
|
"spoiler_level": {"in": ["none", "low"]}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
results = search_spoiler_free("What happens to Alice?", max_chapter=5)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Document Management by Date
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Find recent documents
|
||||||
|
recent_docs = searcher.search(
|
||||||
|
query="project updates",
|
||||||
|
metadata_filters={
|
||||||
|
"date": {">=": "2024-01-01"},
|
||||||
|
"document_type": {"==": "report"}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Code Search by File Type
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Search only Python files
|
||||||
|
python_code = searcher.search(
|
||||||
|
query="authentication function",
|
||||||
|
metadata_filters={
|
||||||
|
"file_extension": {"==": ".py"},
|
||||||
|
"lines_of_code": {"<": 100}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Content Filtering by Audience
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Age-appropriate content
|
||||||
|
family_content = searcher.search(
|
||||||
|
query="adventure stories",
|
||||||
|
metadata_filters={
|
||||||
|
"age_rating": {"in": ["G", "PG"]},
|
||||||
|
"content_warnings": {"not_in": ["violence", "adult_themes"]}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5. Multi-Book Series Management
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Search across first 3 books only
|
||||||
|
early_series = searcher.search(
|
||||||
|
query="character development",
|
||||||
|
metadata_filters={
|
||||||
|
"series": {"==": "Harry Potter"},
|
||||||
|
"book_number": {"<=": 3}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Running the Example
|
||||||
|
|
||||||
|
You can see metadata filtering in action with our spoiler-free book RAG example:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Don't forget to set up the environment
|
||||||
|
uv venv
|
||||||
|
source .venv/bin/activate
|
||||||
|
|
||||||
|
# Set your OpenAI API key (required for embeddings, but you can update the example locally and use ollama instead)
|
||||||
|
export OPENAI_API_KEY="your-api-key-here"
|
||||||
|
|
||||||
|
# Run the spoiler-free book RAG example
|
||||||
|
uv run examples/spoiler_free_book_rag.py
|
||||||
|
```
|
||||||
|
|
||||||
|
This example demonstrates:
|
||||||
|
- Building an index with metadata (chapter numbers, characters, themes, locations)
|
||||||
|
- Searching with filters to avoid spoilers (e.g., only show results up to chapter 5)
|
||||||
|
- Different scenarios for readers at various points in the book
|
||||||
|
|
||||||
|
The example uses Alice's Adventures in Wonderland as sample data and shows how you can search for information without revealing plot points from later chapters.
|
||||||
|
|
||||||
|
## Advanced Patterns
|
||||||
|
|
||||||
|
### Custom Chunking with metadata
|
||||||
|
|
||||||
|
```python
|
||||||
|
def chunk_book_with_metadata(book_text, book_info):
|
||||||
|
chunks = []
|
||||||
|
|
||||||
|
for chapter_num, chapter_text in parse_chapters(book_text):
|
||||||
|
# Extract entities, themes, etc.
|
||||||
|
characters = extract_characters(chapter_text)
|
||||||
|
themes = classify_themes(chapter_text)
|
||||||
|
spoiler_level = assess_spoiler_level(chapter_text, chapter_num)
|
||||||
|
|
||||||
|
# Create chunks with rich metadata
|
||||||
|
for paragraph in split_paragraphs(chapter_text):
|
||||||
|
chunks.append({
|
||||||
|
"text": paragraph,
|
||||||
|
"metadata": {
|
||||||
|
"book_title": book_info["title"],
|
||||||
|
"chapter": chapter_num,
|
||||||
|
"characters": characters,
|
||||||
|
"themes": themes,
|
||||||
|
"spoiler_level": spoiler_level,
|
||||||
|
"word_count": len(paragraph.split()),
|
||||||
|
"reading_level": calculate_reading_level(paragraph)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
```
|
||||||
|
|
||||||
|
## Performance Considerations
|
||||||
|
|
||||||
|
### Efficient Filtering Strategies
|
||||||
|
|
||||||
|
1. **Post-search filtering**: Applies filters after vector search, which should be efficient for typical result sets (10-100 results).
|
||||||
|
|
||||||
|
2. **Metadata design**: Keep metadata fields simple and avoid deeply nested structures.
|
||||||
|
|
||||||
|
### Best Practices
|
||||||
|
|
||||||
|
1. **Consistent metadata schema**: Use consistent field names and value types across your documents.
|
||||||
|
|
||||||
|
2. **Reasonable metadata size**: Keep metadata reasonably sized to avoid storage overhead.
|
||||||
|
|
||||||
|
3. **Type consistency**: Use consistent data types for the same fields (e.g., always integers for chapter numbers).
|
||||||
|
|
||||||
|
4. **Index multiple granularities**: Consider chunking at different levels (paragraph, section, chapter) with appropriate metadata.
|
||||||
|
|
||||||
|
### Adding Metadata to Existing Indices
|
||||||
|
|
||||||
|
To add metadata filtering to existing indices, you'll need to rebuild them with metadata:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Read existing passages and add metadata
|
||||||
|
def add_metadata_to_existing_chunks(chunks):
|
||||||
|
for chunk in chunks:
|
||||||
|
# Extract or assign metadata based on content
|
||||||
|
chunk["metadata"] = extract_metadata(chunk["text"])
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
# Rebuild index with metadata
|
||||||
|
enhanced_chunks = add_metadata_to_existing_chunks(existing_chunks)
|
||||||
|
builder = LeannBuilder("hnsw")
|
||||||
|
for chunk in enhanced_chunks:
|
||||||
|
builder.add_text(chunk["text"], chunk["metadata"])
|
||||||
|
builder.build_index("enhanced_index")
|
||||||
|
```
|
||||||
35
examples/grep_search_example.py
Normal file
35
examples/grep_search_example.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
"""
|
||||||
|
Grep Search Example
|
||||||
|
|
||||||
|
Shows how to use grep-based text search instead of semantic search.
|
||||||
|
Useful when you need exact text matches rather than meaning-based results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from leann import LeannSearcher
|
||||||
|
|
||||||
|
# Load your index
|
||||||
|
searcher = LeannSearcher("my-documents.leann")
|
||||||
|
|
||||||
|
# Regular semantic search
|
||||||
|
print("=== Semantic Search ===")
|
||||||
|
results = searcher.search("machine learning algorithms", top_k=3)
|
||||||
|
for result in results:
|
||||||
|
print(f"Score: {result.score:.3f}")
|
||||||
|
print(f"Text: {result.text[:80]}...")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Grep-based search for exact text matches
|
||||||
|
print("=== Grep Search ===")
|
||||||
|
results = searcher.search("def train_model", top_k=3, use_grep=True)
|
||||||
|
for result in results:
|
||||||
|
print(f"Score: {result.score}")
|
||||||
|
print(f"Text: {result.text[:80]}...")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Find specific error messages
|
||||||
|
error_results = searcher.search("FileNotFoundError", use_grep=True)
|
||||||
|
print(f"Found {len(error_results)} files mentioning FileNotFoundError")
|
||||||
|
|
||||||
|
# Search for function definitions
|
||||||
|
func_results = searcher.search("class SearchResult", use_grep=True, top_k=5)
|
||||||
|
print(f"Found {len(func_results)} class definitions")
|
||||||
250
examples/spoiler_free_book_rag.py
Normal file
250
examples/spoiler_free_book_rag.py
Normal file
@@ -0,0 +1,250 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Spoiler-Free Book RAG Example using LEANN Metadata Filtering
|
||||||
|
|
||||||
|
This example demonstrates how to use LEANN's metadata filtering to create
|
||||||
|
a spoiler-free book RAG system where users can search for information
|
||||||
|
up to a specific chapter they've read.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python spoiler_free_book_rag.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
# Add LEANN to path (adjust path as needed)
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../packages/leann-core/src"))
|
||||||
|
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_book_with_metadata(book_title: str = "Sample Book") -> list[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Create sample book chunks with metadata for demonstration.
|
||||||
|
|
||||||
|
In a real implementation, this would parse actual book files (epub, txt, etc.)
|
||||||
|
and extract chapter boundaries, character mentions, etc.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
book_title: Title of the book
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of chunk dictionaries with text and metadata
|
||||||
|
"""
|
||||||
|
# Sample book chunks with metadata
|
||||||
|
# In practice, you'd use proper text processing libraries
|
||||||
|
|
||||||
|
sample_chunks = [
|
||||||
|
{
|
||||||
|
"text": "Alice was beginning to get very tired of sitting by her sister on the bank, and of having nothing to do.",
|
||||||
|
"metadata": {
|
||||||
|
"book": book_title,
|
||||||
|
"chapter": 1,
|
||||||
|
"page": 1,
|
||||||
|
"characters": ["Alice", "Sister"],
|
||||||
|
"themes": ["boredom", "curiosity"],
|
||||||
|
"location": "riverbank",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": "So she was considering in her own mind (as well as she could, for the hot day made her feel very sleepy and stupid), whether the pleasure of making a daisy-chain would be worth the trouble of getting up and picking the daisies, when suddenly a White Rabbit with pink eyes ran close by her.",
|
||||||
|
"metadata": {
|
||||||
|
"book": book_title,
|
||||||
|
"chapter": 1,
|
||||||
|
"page": 2,
|
||||||
|
"characters": ["Alice", "White Rabbit"],
|
||||||
|
"themes": ["decision", "surprise", "magic"],
|
||||||
|
"location": "riverbank",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": "Alice found herself falling down a very deep well. Either the well was very deep, or she fell very slowly, for she had plenty of time as she fell to look about her and to wonder what was going to happen next.",
|
||||||
|
"metadata": {
|
||||||
|
"book": book_title,
|
||||||
|
"chapter": 2,
|
||||||
|
"page": 15,
|
||||||
|
"characters": ["Alice"],
|
||||||
|
"themes": ["falling", "wonder", "transformation"],
|
||||||
|
"location": "rabbit hole",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": "Alice meets the Cheshire Cat, who tells her that everyone in Wonderland is mad, including Alice herself.",
|
||||||
|
"metadata": {
|
||||||
|
"book": book_title,
|
||||||
|
"chapter": 6,
|
||||||
|
"page": 85,
|
||||||
|
"characters": ["Alice", "Cheshire Cat"],
|
||||||
|
"themes": ["madness", "philosophy", "identity"],
|
||||||
|
"location": "Duchess's house",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": "At the Queen's croquet ground, Alice witnesses the absurd trial that reveals the arbitrary nature of Wonderland's justice system.",
|
||||||
|
"metadata": {
|
||||||
|
"book": book_title,
|
||||||
|
"chapter": 8,
|
||||||
|
"page": 120,
|
||||||
|
"characters": ["Alice", "Queen of Hearts", "King of Hearts"],
|
||||||
|
"themes": ["justice", "absurdity", "authority"],
|
||||||
|
"location": "Queen's court",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": "Alice realizes that Wonderland was all a dream, even the Rabbit, as she wakes up on the riverbank next to her sister.",
|
||||||
|
"metadata": {
|
||||||
|
"book": book_title,
|
||||||
|
"chapter": 12,
|
||||||
|
"page": 180,
|
||||||
|
"characters": ["Alice", "Sister", "Rabbit"],
|
||||||
|
"themes": ["revelation", "reality", "growth"],
|
||||||
|
"location": "riverbank",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
return sample_chunks
|
||||||
|
|
||||||
|
|
||||||
|
def build_spoiler_free_index(book_chunks: list[dict[str, Any]], index_name: str) -> str:
|
||||||
|
"""
|
||||||
|
Build a LEANN index with book chunks that include spoiler metadata.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
book_chunks: List of book chunks with metadata
|
||||||
|
index_name: Name for the index
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to the built index
|
||||||
|
"""
|
||||||
|
print(f"📚 Building spoiler-free book index: {index_name}")
|
||||||
|
|
||||||
|
# Initialize LEANN builder
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw", embedding_model="text-embedding-3-small", embedding_mode="openai"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add each chunk with its metadata
|
||||||
|
for chunk in book_chunks:
|
||||||
|
builder.add_text(text=chunk["text"], metadata=chunk["metadata"])
|
||||||
|
|
||||||
|
# Build the index
|
||||||
|
index_path = f"{index_name}_book_index"
|
||||||
|
builder.build_index(index_path)
|
||||||
|
|
||||||
|
print(f"✅ Index built successfully: {index_path}")
|
||||||
|
return index_path
|
||||||
|
|
||||||
|
|
||||||
|
def spoiler_free_search(
|
||||||
|
index_path: str,
|
||||||
|
query: str,
|
||||||
|
max_chapter: int,
|
||||||
|
character_filter: Optional[list[str]] = None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Perform a spoiler-free search on the book index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index_path: Path to the LEANN index
|
||||||
|
query: Search query
|
||||||
|
max_chapter: Maximum chapter number to include
|
||||||
|
character_filter: Optional list of characters to focus on
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of search results safe for the reader
|
||||||
|
"""
|
||||||
|
print(f"🔍 Searching: '{query}' (up to chapter {max_chapter})")
|
||||||
|
|
||||||
|
searcher = LeannSearcher(index_path)
|
||||||
|
|
||||||
|
metadata_filters = {"chapter": {"<=": max_chapter}}
|
||||||
|
|
||||||
|
if character_filter:
|
||||||
|
metadata_filters["characters"] = {"contains": character_filter[0]}
|
||||||
|
|
||||||
|
results = searcher.search(query=query, top_k=10, metadata_filters=metadata_filters)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def demo_spoiler_free_rag():
|
||||||
|
"""
|
||||||
|
Demonstrate the spoiler-free book RAG system.
|
||||||
|
"""
|
||||||
|
print("🎭 Spoiler-Free Book RAG Demo")
|
||||||
|
print("=" * 40)
|
||||||
|
|
||||||
|
# Step 1: Prepare book data
|
||||||
|
book_title = "Alice's Adventures in Wonderland"
|
||||||
|
book_chunks = chunk_book_with_metadata(book_title)
|
||||||
|
|
||||||
|
print(f"📖 Loaded {len(book_chunks)} chunks from '{book_title}'")
|
||||||
|
|
||||||
|
# Step 2: Build the index (in practice, this would be done once)
|
||||||
|
try:
|
||||||
|
index_path = build_spoiler_free_index(book_chunks, "alice_wonderland")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Failed to build index (likely missing dependencies): {e}")
|
||||||
|
print(
|
||||||
|
"💡 This demo shows the filtering logic - actual indexing requires LEANN dependencies"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Step 3: Demonstrate various spoiler-free searches
|
||||||
|
search_scenarios = [
|
||||||
|
{
|
||||||
|
"description": "Reader who has only read Chapter 1",
|
||||||
|
"query": "What can you tell me about the rabbit?",
|
||||||
|
"max_chapter": 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"description": "Reader who has read up to Chapter 5",
|
||||||
|
"query": "Tell me about Alice's adventures",
|
||||||
|
"max_chapter": 5,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"description": "Reader who has read most of the book",
|
||||||
|
"query": "What does the Cheshire Cat represent?",
|
||||||
|
"max_chapter": 10,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"description": "Reader who has read the whole book",
|
||||||
|
"query": "What can you tell me about the rabbit?",
|
||||||
|
"max_chapter": 12,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
for scenario in search_scenarios:
|
||||||
|
print(f"\n📚 Scenario: {scenario['description']}")
|
||||||
|
print(f" Query: {scenario['query']}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
results = spoiler_free_search(
|
||||||
|
index_path=index_path,
|
||||||
|
query=scenario["query"],
|
||||||
|
max_chapter=scenario["max_chapter"],
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f" 📄 Found {len(results)} results:")
|
||||||
|
for i, result in enumerate(results[:3], 1): # Show top 3
|
||||||
|
chapter = result.metadata.get("chapter", "?")
|
||||||
|
location = result.metadata.get("location", "?")
|
||||||
|
print(f" {i}. Chapter {chapter} ({location}): {result.text[:80]}...")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ❌ Search failed: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("📚 LEANN Spoiler-Free Book RAG Example")
|
||||||
|
print("=====================================")
|
||||||
|
|
||||||
|
try:
|
||||||
|
demo_spoiler_free_rag()
|
||||||
|
except ImportError as e:
|
||||||
|
print(f"❌ Cannot run demo due to missing dependencies: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Error running demo: {e}")
|
||||||
28
llms.txt
Normal file
28
llms.txt
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
# llms.txt — LEANN MCP and Agent Integration
|
||||||
|
product: LEANN
|
||||||
|
homepage: https://github.com/yichuan-w/LEANN
|
||||||
|
contact: https://github.com/yichuan-w/LEANN/issues
|
||||||
|
|
||||||
|
# Installation
|
||||||
|
install: uv tool install leann-core --with leann
|
||||||
|
|
||||||
|
# MCP Server Entry Point
|
||||||
|
mcp.server: leann_mcp
|
||||||
|
mcp.protocol_version: 2024-11-05
|
||||||
|
|
||||||
|
# Tools
|
||||||
|
mcp.tools: leann_list, leann_search
|
||||||
|
|
||||||
|
mcp.tool.leann_list.description: List available LEANN indexes
|
||||||
|
mcp.tool.leann_list.input: {}
|
||||||
|
|
||||||
|
mcp.tool.leann_search.description: Semantic search across a named LEANN index
|
||||||
|
mcp.tool.leann_search.input.index_name: string, required
|
||||||
|
mcp.tool.leann_search.input.query: string, required
|
||||||
|
mcp.tool.leann_search.input.top_k: integer, optional, default=5, min=1, max=20
|
||||||
|
mcp.tool.leann_search.input.complexity: integer, optional, default=32, min=16, max=128
|
||||||
|
|
||||||
|
# Notes
|
||||||
|
note: Build indexes with `leann build <name> --docs <files...>` before searching.
|
||||||
|
example.add: claude mcp add --scope user leann-server -- leann_mcp
|
||||||
|
example.verify: claude mcp list | cat
|
||||||
1
packages/astchunk-leann
Submodule
1
packages/astchunk-leann
Submodule
Submodule packages/astchunk-leann added at a4537018a3
@@ -1 +1,7 @@
|
|||||||
from . import diskann_backend as diskann_backend
|
from . import diskann_backend as diskann_backend
|
||||||
|
from . import graph_partition
|
||||||
|
|
||||||
|
# Export main classes and functions
|
||||||
|
from .graph_partition import GraphPartitioner, partition_graph
|
||||||
|
|
||||||
|
__all__ = ["GraphPartitioner", "diskann_backend", "graph_partition", "partition_graph"]
|
||||||
|
|||||||
@@ -22,6 +22,11 @@ logger = logging.getLogger(__name__)
|
|||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def suppress_cpp_output_if_needed():
|
def suppress_cpp_output_if_needed():
|
||||||
"""Suppress C++ stdout/stderr based on LEANN_LOG_LEVEL"""
|
"""Suppress C++ stdout/stderr based on LEANN_LOG_LEVEL"""
|
||||||
|
# In CI we avoid fiddling with low-level file descriptors to prevent aborts
|
||||||
|
if os.getenv("CI") == "true":
|
||||||
|
yield
|
||||||
|
return
|
||||||
|
|
||||||
log_level = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
log_level = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||||
|
|
||||||
# Only suppress if log level is WARNING or higher (ERROR, CRITICAL)
|
# Only suppress if log level is WARNING or higher (ERROR, CRITICAL)
|
||||||
@@ -137,6 +142,71 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
|||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.build_params = kwargs
|
self.build_params = kwargs
|
||||||
|
|
||||||
|
def _safe_cleanup_after_partition(self, index_dir: Path, index_prefix: str):
|
||||||
|
"""
|
||||||
|
Safely cleanup files after partition.
|
||||||
|
In partition mode, C++ doesn't read _disk.index content,
|
||||||
|
so we can delete it if all derived files exist.
|
||||||
|
"""
|
||||||
|
disk_index_file = index_dir / f"{index_prefix}_disk.index"
|
||||||
|
beam_search_file = index_dir / f"{index_prefix}_disk_beam_search.index"
|
||||||
|
|
||||||
|
# Required files that C++ partition mode needs
|
||||||
|
# Note: C++ generates these with _disk.index suffix
|
||||||
|
disk_suffix = "_disk.index"
|
||||||
|
required_files = [
|
||||||
|
f"{index_prefix}{disk_suffix}_medoids.bin", # Critical: assert fails if missing
|
||||||
|
# Note: _centroids.bin is not created in single-shot build - C++ handles this automatically
|
||||||
|
f"{index_prefix}_pq_pivots.bin", # PQ table
|
||||||
|
f"{index_prefix}_pq_compressed.bin", # PQ compressed vectors
|
||||||
|
]
|
||||||
|
|
||||||
|
# Check if all required files exist
|
||||||
|
missing_files = []
|
||||||
|
for filename in required_files:
|
||||||
|
file_path = index_dir / filename
|
||||||
|
if not file_path.exists():
|
||||||
|
missing_files.append(filename)
|
||||||
|
|
||||||
|
if missing_files:
|
||||||
|
logger.warning(
|
||||||
|
f"Cannot safely delete _disk.index - missing required files: {missing_files}"
|
||||||
|
)
|
||||||
|
logger.info("Keeping all original files for safety")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Calculate space savings
|
||||||
|
space_saved = 0
|
||||||
|
files_to_delete = []
|
||||||
|
|
||||||
|
if disk_index_file.exists():
|
||||||
|
space_saved += disk_index_file.stat().st_size
|
||||||
|
files_to_delete.append(disk_index_file)
|
||||||
|
|
||||||
|
if beam_search_file.exists():
|
||||||
|
space_saved += beam_search_file.stat().st_size
|
||||||
|
files_to_delete.append(beam_search_file)
|
||||||
|
|
||||||
|
# Safe to delete!
|
||||||
|
for file_to_delete in files_to_delete:
|
||||||
|
try:
|
||||||
|
os.remove(file_to_delete)
|
||||||
|
logger.info(f"✅ Safely deleted: {file_to_delete.name}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to delete {file_to_delete.name}: {e}")
|
||||||
|
|
||||||
|
if space_saved > 0:
|
||||||
|
space_saved_mb = space_saved / (1024 * 1024)
|
||||||
|
logger.info(f"💾 Space saved: {space_saved_mb:.1f} MB")
|
||||||
|
|
||||||
|
# Show what files are kept
|
||||||
|
logger.info("📁 Kept essential files for partition mode:")
|
||||||
|
for filename in required_files:
|
||||||
|
file_path = index_dir / filename
|
||||||
|
if file_path.exists():
|
||||||
|
size_mb = file_path.stat().st_size / (1024 * 1024)
|
||||||
|
logger.info(f" - {filename} ({size_mb:.1f} MB)")
|
||||||
|
|
||||||
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
|
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
|
||||||
path = Path(index_path)
|
path = Path(index_path)
|
||||||
index_dir = path.parent
|
index_dir = path.parent
|
||||||
@@ -151,6 +221,17 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
|||||||
_write_vectors_to_bin(data, index_dir / data_filename)
|
_write_vectors_to_bin(data, index_dir / data_filename)
|
||||||
|
|
||||||
build_kwargs = {**self.build_params, **kwargs}
|
build_kwargs = {**self.build_params, **kwargs}
|
||||||
|
|
||||||
|
# Extract is_recompute from nested backend_kwargs if needed
|
||||||
|
is_recompute = build_kwargs.get("is_recompute", False)
|
||||||
|
if not is_recompute and "backend_kwargs" in build_kwargs:
|
||||||
|
is_recompute = build_kwargs["backend_kwargs"].get("is_recompute", False)
|
||||||
|
|
||||||
|
# Flatten all backend_kwargs parameters to top level for compatibility
|
||||||
|
if "backend_kwargs" in build_kwargs:
|
||||||
|
nested_params = build_kwargs.pop("backend_kwargs")
|
||||||
|
build_kwargs.update(nested_params)
|
||||||
|
|
||||||
metric_enum = _get_diskann_metrics().get(
|
metric_enum = _get_diskann_metrics().get(
|
||||||
build_kwargs.get("distance_metric", "mips").lower()
|
build_kwargs.get("distance_metric", "mips").lower()
|
||||||
)
|
)
|
||||||
@@ -185,6 +266,30 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
|||||||
build_kwargs.get("pq_disk_bytes", 0),
|
build_kwargs.get("pq_disk_bytes", 0),
|
||||||
"",
|
"",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Auto-partition if is_recompute is enabled
|
||||||
|
if build_kwargs.get("is_recompute", False):
|
||||||
|
logger.info("is_recompute=True, starting automatic graph partitioning...")
|
||||||
|
from .graph_partition import partition_graph
|
||||||
|
|
||||||
|
# Partition the index using absolute paths
|
||||||
|
# Convert to absolute paths to avoid issues with working directory changes
|
||||||
|
absolute_index_dir = Path(index_dir).resolve()
|
||||||
|
absolute_index_prefix_path = str(absolute_index_dir / index_prefix)
|
||||||
|
disk_graph_path, partition_bin_path = partition_graph(
|
||||||
|
index_prefix_path=absolute_index_prefix_path,
|
||||||
|
output_dir=str(absolute_index_dir),
|
||||||
|
partition_prefix=index_prefix,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Safe cleanup: In partition mode, C++ doesn't read _disk.index content
|
||||||
|
# but still needs the derived files (_medoids.bin, _centroids.bin, etc.)
|
||||||
|
self._safe_cleanup_after_partition(index_dir, index_prefix)
|
||||||
|
|
||||||
|
logger.info("✅ Graph partitioning completed successfully!")
|
||||||
|
logger.info(f" - Disk graph: {disk_graph_path}")
|
||||||
|
logger.info(f" - Partition file: {partition_bin_path}")
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
temp_data_file = index_dir / data_filename
|
temp_data_file = index_dir / data_filename
|
||||||
if temp_data_file.exists():
|
if temp_data_file.exists():
|
||||||
@@ -213,7 +318,26 @@ class DiskannSearcher(BaseSearcher):
|
|||||||
|
|
||||||
# For DiskANN, we need to reinitialize the index when zmq_port changes
|
# For DiskANN, we need to reinitialize the index when zmq_port changes
|
||||||
# Store the initialization parameters for later use
|
# Store the initialization parameters for later use
|
||||||
full_index_prefix = str(self.index_dir / self.index_path.stem)
|
# Note: C++ load method expects the BASE path (without _disk.index suffix)
|
||||||
|
# C++ internally constructs: index_prefix + "_disk.index"
|
||||||
|
index_name = self.index_path.stem # "simple_test.leann" -> "simple_test"
|
||||||
|
diskann_index_prefix = str(self.index_dir / index_name) # /path/to/simple_test
|
||||||
|
full_index_prefix = diskann_index_prefix # /path/to/simple_test (base path)
|
||||||
|
|
||||||
|
# Auto-detect partition files and set partition_prefix
|
||||||
|
partition_graph_file = self.index_dir / f"{index_name}_disk_graph.index"
|
||||||
|
partition_bin_file = self.index_dir / f"{index_name}_partition.bin"
|
||||||
|
|
||||||
|
partition_prefix = ""
|
||||||
|
if partition_graph_file.exists() and partition_bin_file.exists():
|
||||||
|
# C++ expects full path prefix, not just filename
|
||||||
|
partition_prefix = str(self.index_dir / index_name) # /path/to/simple_test
|
||||||
|
logger.info(
|
||||||
|
f"✅ Detected partition files, using partition_prefix='{partition_prefix}'"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.debug("No partition files detected, using standard index files")
|
||||||
|
|
||||||
self._init_params = {
|
self._init_params = {
|
||||||
"metric_enum": metric_enum,
|
"metric_enum": metric_enum,
|
||||||
"full_index_prefix": full_index_prefix,
|
"full_index_prefix": full_index_prefix,
|
||||||
@@ -221,8 +345,14 @@ class DiskannSearcher(BaseSearcher):
|
|||||||
"num_nodes_to_cache": kwargs.get("num_nodes_to_cache", 0),
|
"num_nodes_to_cache": kwargs.get("num_nodes_to_cache", 0),
|
||||||
"cache_mechanism": 1,
|
"cache_mechanism": 1,
|
||||||
"pq_prefix": "",
|
"pq_prefix": "",
|
||||||
"partition_prefix": "",
|
"partition_prefix": partition_prefix,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Log partition configuration for debugging
|
||||||
|
if partition_prefix:
|
||||||
|
logger.info(
|
||||||
|
f"✅ Detected partition files, using partition_prefix='{partition_prefix}'"
|
||||||
|
)
|
||||||
self._diskannpy = diskannpy
|
self._diskannpy = diskannpy
|
||||||
self._current_zmq_port = None
|
self._current_zmq_port = None
|
||||||
self._index = None
|
self._index = None
|
||||||
@@ -311,9 +441,14 @@ class DiskannSearcher(BaseSearcher):
|
|||||||
else: # "global"
|
else: # "global"
|
||||||
use_global_pruning = True
|
use_global_pruning = True
|
||||||
|
|
||||||
# Perform search with suppressed C++ output based on log level
|
# Strategy:
|
||||||
use_deferred_fetch = kwargs.get("USE_DEFERRED_FETCH", True)
|
# - Traversal always uses PQ distances
|
||||||
recompute_neighors = False
|
# - If recompute_embeddings=True, do a single final rerank via deferred fetch
|
||||||
|
# (fetch embeddings for the final candidate set only)
|
||||||
|
# - Do not recompute neighbor distances along the path
|
||||||
|
use_deferred_fetch = True if recompute_embeddings else False
|
||||||
|
recompute_neighors = False # Expected typo. For backward compatibility.
|
||||||
|
|
||||||
with suppress_cpp_output_if_needed():
|
with suppress_cpp_output_if_needed():
|
||||||
labels, distances = self._index.batch_search(
|
labels, distances = self._index.batch_search(
|
||||||
query,
|
query,
|
||||||
|
|||||||
@@ -81,10 +81,9 @@ def create_diskann_embedding_server(
|
|||||||
with open(passages_file) as f:
|
with open(passages_file) as f:
|
||||||
meta = json.load(f)
|
meta = json.load(f)
|
||||||
|
|
||||||
passages = PassageManager(meta["passage_sources"])
|
logger.info(f"Loading PassageManager with metadata_file_path: {passages_file}")
|
||||||
logger.info(
|
passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file)
|
||||||
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
logger.info(f"Loaded PassageManager with {len(passages)} passages from metadata")
|
||||||
)
|
|
||||||
|
|
||||||
# Import protobuf after ensuring the path is correct
|
# Import protobuf after ensuring the path is correct
|
||||||
try:
|
try:
|
||||||
@@ -102,8 +101,9 @@ def create_diskann_embedding_server(
|
|||||||
socket.bind(f"tcp://*:{zmq_port}")
|
socket.bind(f"tcp://*:{zmq_port}")
|
||||||
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
|
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
|
||||||
|
|
||||||
socket.setsockopt(zmq.RCVTIMEO, 300000)
|
socket.setsockopt(zmq.RCVTIMEO, 1000)
|
||||||
socket.setsockopt(zmq.SNDTIMEO, 300000)
|
socket.setsockopt(zmq.SNDTIMEO, 1000)
|
||||||
|
socket.setsockopt(zmq.LINGER, 0)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@@ -220,30 +220,217 @@ def create_diskann_embedding_server(
|
|||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
|
def zmq_server_thread_with_shutdown(shutdown_event):
|
||||||
|
"""ZMQ server thread that respects shutdown signal.
|
||||||
|
|
||||||
|
This creates its own REP socket, binds to zmq_port, and periodically
|
||||||
|
checks shutdown_event using recv timeouts to exit cleanly.
|
||||||
|
"""
|
||||||
|
logger.info("DiskANN ZMQ server thread started with shutdown support")
|
||||||
|
|
||||||
|
context = zmq.Context()
|
||||||
|
rep_socket = context.socket(zmq.REP)
|
||||||
|
rep_socket.bind(f"tcp://*:{zmq_port}")
|
||||||
|
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
|
||||||
|
|
||||||
|
# Set receive timeout so we can check shutdown_event periodically
|
||||||
|
rep_socket.setsockopt(zmq.RCVTIMEO, 1000) # 1 second timeout
|
||||||
|
rep_socket.setsockopt(zmq.SNDTIMEO, 1000)
|
||||||
|
rep_socket.setsockopt(zmq.LINGER, 0)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while not shutdown_event.is_set():
|
||||||
|
try:
|
||||||
|
e2e_start = time.time()
|
||||||
|
# REP socket receives single-part messages
|
||||||
|
message = rep_socket.recv()
|
||||||
|
|
||||||
|
# Check for empty messages - REP socket requires response to every request
|
||||||
|
if not message:
|
||||||
|
logger.warning("Received empty message, sending empty response")
|
||||||
|
rep_socket.send(b"")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Try protobuf first (same logic as original)
|
||||||
|
texts = []
|
||||||
|
is_text_request = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
req_proto = embedding_pb2.NodeEmbeddingRequest()
|
||||||
|
req_proto.ParseFromString(message)
|
||||||
|
node_ids = list(req_proto.node_ids)
|
||||||
|
|
||||||
|
# Look up texts by node IDs
|
||||||
|
for nid in node_ids:
|
||||||
|
try:
|
||||||
|
passage_data = passages.get_passage(str(nid))
|
||||||
|
txt = passage_data["text"]
|
||||||
|
if not txt:
|
||||||
|
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
|
||||||
|
texts.append(txt)
|
||||||
|
except KeyError:
|
||||||
|
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
|
||||||
|
|
||||||
|
logger.info(f"ZMQ received protobuf request for {len(node_ids)} node IDs")
|
||||||
|
except Exception:
|
||||||
|
# Fallback to msgpack for text requests
|
||||||
|
try:
|
||||||
|
import msgpack
|
||||||
|
|
||||||
|
request = msgpack.unpackb(message)
|
||||||
|
if isinstance(request, list) and all(
|
||||||
|
isinstance(item, str) for item in request
|
||||||
|
):
|
||||||
|
texts = request
|
||||||
|
is_text_request = True
|
||||||
|
logger.info(
|
||||||
|
f"ZMQ received msgpack text request for {len(texts)} texts"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("Not a valid msgpack text request")
|
||||||
|
except Exception:
|
||||||
|
logger.error("Both protobuf and msgpack parsing failed!")
|
||||||
|
# Send error response
|
||||||
|
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||||
|
rep_socket.send(resp_proto.SerializeToString())
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Process the request
|
||||||
|
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||||
|
logger.info(f"Computed embeddings shape: {embeddings.shape}")
|
||||||
|
|
||||||
|
# Validation
|
||||||
|
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||||
|
logger.error("NaN or Inf detected in embeddings!")
|
||||||
|
# Send error response
|
||||||
|
if is_text_request:
|
||||||
|
import msgpack
|
||||||
|
|
||||||
|
response_data = msgpack.packb([])
|
||||||
|
else:
|
||||||
|
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||||
|
response_data = resp_proto.SerializeToString()
|
||||||
|
rep_socket.send(response_data)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Prepare response based on request type
|
||||||
|
if is_text_request:
|
||||||
|
# For direct text requests, return msgpack
|
||||||
|
import msgpack
|
||||||
|
|
||||||
|
response_data = msgpack.packb(embeddings.tolist())
|
||||||
|
else:
|
||||||
|
# For protobuf requests, return protobuf
|
||||||
|
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||||
|
hidden_contiguous = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||||
|
|
||||||
|
resp_proto.embeddings_data = hidden_contiguous.tobytes()
|
||||||
|
resp_proto.dimensions.append(hidden_contiguous.shape[0])
|
||||||
|
resp_proto.dimensions.append(hidden_contiguous.shape[1])
|
||||||
|
|
||||||
|
response_data = resp_proto.SerializeToString()
|
||||||
|
|
||||||
|
# Send response back to the client
|
||||||
|
rep_socket.send(response_data)
|
||||||
|
|
||||||
|
e2e_end = time.time()
|
||||||
|
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
|
||||||
|
|
||||||
|
except zmq.Again:
|
||||||
|
# Timeout - check shutdown_event and continue
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
if not shutdown_event.is_set():
|
||||||
|
logger.error(f"Error in ZMQ server loop: {e}")
|
||||||
|
try:
|
||||||
|
# Send error response for REP socket
|
||||||
|
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||||
|
rep_socket.send(resp_proto.SerializeToString())
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
logger.info("Shutdown in progress, ignoring ZMQ error")
|
||||||
|
break
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
rep_socket.close(0)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
context.term()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
logger.info("DiskANN ZMQ server thread exiting gracefully")
|
||||||
|
|
||||||
|
# Add shutdown coordination
|
||||||
|
shutdown_event = threading.Event()
|
||||||
|
|
||||||
|
def shutdown_zmq_server():
|
||||||
|
"""Gracefully shutdown ZMQ server."""
|
||||||
|
logger.info("Initiating graceful shutdown...")
|
||||||
|
shutdown_event.set()
|
||||||
|
|
||||||
|
if zmq_thread.is_alive():
|
||||||
|
logger.info("Waiting for ZMQ thread to finish...")
|
||||||
|
zmq_thread.join(timeout=5)
|
||||||
|
if zmq_thread.is_alive():
|
||||||
|
logger.warning("ZMQ thread did not finish in time")
|
||||||
|
|
||||||
|
# Clean up ZMQ resources
|
||||||
|
try:
|
||||||
|
# Note: socket and context are cleaned up by thread exit
|
||||||
|
logger.info("ZMQ resources cleaned up")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error cleaning ZMQ resources: {e}")
|
||||||
|
|
||||||
|
# Clean up other resources
|
||||||
|
try:
|
||||||
|
import gc
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
logger.info("Additional resources cleaned up")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error cleaning additional resources: {e}")
|
||||||
|
|
||||||
|
logger.info("Graceful shutdown completed")
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
# Register signal handlers within this function scope
|
||||||
|
import signal
|
||||||
|
|
||||||
|
def signal_handler(sig, frame):
|
||||||
|
logger.info(f"Received signal {sig}, shutting down gracefully...")
|
||||||
|
shutdown_zmq_server()
|
||||||
|
|
||||||
|
signal.signal(signal.SIGTERM, signal_handler)
|
||||||
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
|
|
||||||
|
# Start ZMQ thread (NOT daemon!)
|
||||||
|
zmq_thread = threading.Thread(
|
||||||
|
target=lambda: zmq_server_thread_with_shutdown(shutdown_event),
|
||||||
|
daemon=False, # Not daemon - we want to wait for it
|
||||||
|
)
|
||||||
zmq_thread.start()
|
zmq_thread.start()
|
||||||
logger.info(f"Started DiskANN ZMQ server thread on port {zmq_port}")
|
logger.info(f"Started DiskANN ZMQ server thread on port {zmq_port}")
|
||||||
|
|
||||||
# Keep the main thread alive
|
# Keep the main thread alive
|
||||||
try:
|
try:
|
||||||
while True:
|
while not shutdown_event.is_set():
|
||||||
time.sleep(1)
|
time.sleep(0.1) # Check shutdown more frequently
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logger.info("DiskANN Server shutting down...")
|
logger.info("DiskANN Server shutting down...")
|
||||||
|
shutdown_zmq_server()
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# If we reach here, shutdown was triggered by signal
|
||||||
|
logger.info("Main loop exited, process should be shutting down")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import signal
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
def signal_handler(sig, frame):
|
# Signal handlers are now registered within create_diskann_embedding_server
|
||||||
logger.info(f"Received signal {sig}, shutting down gracefully...")
|
|
||||||
sys.exit(0)
|
|
||||||
|
|
||||||
# Register signal handlers for graceful shutdown
|
|
||||||
signal.signal(signal.SIGTERM, signal_handler)
|
|
||||||
signal.signal(signal.SIGINT, signal_handler)
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="DiskANN Embedding service")
|
parser = argparse.ArgumentParser(description="DiskANN Embedding service")
|
||||||
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
||||||
|
|||||||
@@ -0,0 +1,299 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Graph Partition Module for LEANN DiskANN Backend
|
||||||
|
|
||||||
|
This module provides Python bindings for the graph partition functionality
|
||||||
|
of DiskANN, allowing users to partition disk-based indices for better
|
||||||
|
performance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class GraphPartitioner:
|
||||||
|
"""
|
||||||
|
A Python interface for DiskANN's graph partition functionality.
|
||||||
|
|
||||||
|
This class provides methods to partition disk-based indices for improved
|
||||||
|
search performance and memory efficiency.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, build_type: str = "release"):
|
||||||
|
"""
|
||||||
|
Initialize the GraphPartitioner.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
build_type: Build type for the executables ("debug" or "release")
|
||||||
|
"""
|
||||||
|
self.build_type = build_type
|
||||||
|
self._ensure_executables()
|
||||||
|
|
||||||
|
def _get_executable_path(self, name: str) -> str:
|
||||||
|
"""Get the path to a graph partition executable."""
|
||||||
|
# Get the directory where this Python module is located
|
||||||
|
module_dir = Path(__file__).parent
|
||||||
|
# Navigate to the graph_partition directory
|
||||||
|
graph_partition_dir = module_dir.parent / "third_party" / "DiskANN" / "graph_partition"
|
||||||
|
executable_path = graph_partition_dir / "build" / self.build_type / "graph_partition" / name
|
||||||
|
|
||||||
|
if not executable_path.exists():
|
||||||
|
raise FileNotFoundError(f"Executable {name} not found at {executable_path}")
|
||||||
|
|
||||||
|
return str(executable_path)
|
||||||
|
|
||||||
|
def _ensure_executables(self):
|
||||||
|
"""Ensure that the required executables are built."""
|
||||||
|
try:
|
||||||
|
self._get_executable_path("partitioner")
|
||||||
|
self._get_executable_path("index_relayout")
|
||||||
|
except FileNotFoundError:
|
||||||
|
# Try to build the executables automatically
|
||||||
|
print("Executables not found, attempting to build them...")
|
||||||
|
self._build_executables()
|
||||||
|
|
||||||
|
def _build_executables(self):
|
||||||
|
"""Build the required executables."""
|
||||||
|
graph_partition_dir = (
|
||||||
|
Path(__file__).parent.parent / "third_party" / "DiskANN" / "graph_partition"
|
||||||
|
)
|
||||||
|
original_dir = os.getcwd()
|
||||||
|
|
||||||
|
try:
|
||||||
|
os.chdir(graph_partition_dir)
|
||||||
|
|
||||||
|
# Clean any existing build
|
||||||
|
if (graph_partition_dir / "build").exists():
|
||||||
|
shutil.rmtree(graph_partition_dir / "build")
|
||||||
|
|
||||||
|
# Run the build script
|
||||||
|
cmd = ["./build.sh", self.build_type, "split_graph", "/tmp/dummy"]
|
||||||
|
subprocess.run(cmd, capture_output=True, text=True, cwd=graph_partition_dir)
|
||||||
|
|
||||||
|
# Check if executables were created
|
||||||
|
partitioner_path = self._get_executable_path("partitioner")
|
||||||
|
relayout_path = self._get_executable_path("index_relayout")
|
||||||
|
|
||||||
|
print(f"✅ Built partitioner: {partitioner_path}")
|
||||||
|
print(f"✅ Built index_relayout: {relayout_path}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Failed to build executables: {e}")
|
||||||
|
finally:
|
||||||
|
os.chdir(original_dir)
|
||||||
|
|
||||||
|
def partition_graph(
|
||||||
|
self,
|
||||||
|
index_prefix_path: str,
|
||||||
|
output_dir: Optional[str] = None,
|
||||||
|
partition_prefix: Optional[str] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> tuple[str, str]:
|
||||||
|
"""
|
||||||
|
Partition a disk-based index for improved performance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index_prefix_path: Path to the index prefix (e.g., "/path/to/index")
|
||||||
|
output_dir: Output directory for results (defaults to parent of index_prefix_path)
|
||||||
|
partition_prefix: Prefix for output files (defaults to basename of index_prefix_path)
|
||||||
|
**kwargs: Additional parameters for graph partitioning:
|
||||||
|
- gp_times: Number of LDG partition iterations (default: 10)
|
||||||
|
- lock_nums: Number of lock nodes (default: 10)
|
||||||
|
- cut: Cut adjacency list degree (default: 100)
|
||||||
|
- scale_factor: Scale factor (default: 1)
|
||||||
|
- data_type: Data type (default: "float")
|
||||||
|
- thread_nums: Number of threads (default: 10)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (disk_graph_index_path, partition_bin_path)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If the partitioning process fails
|
||||||
|
"""
|
||||||
|
# Set default parameters
|
||||||
|
params = {
|
||||||
|
"gp_times": 10,
|
||||||
|
"lock_nums": 10,
|
||||||
|
"cut": 100,
|
||||||
|
"scale_factor": 1,
|
||||||
|
"data_type": "float",
|
||||||
|
"thread_nums": 10,
|
||||||
|
**kwargs,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Determine output directory
|
||||||
|
if output_dir is None:
|
||||||
|
output_dir = str(Path(index_prefix_path).parent)
|
||||||
|
|
||||||
|
# Create output directory if it doesn't exist
|
||||||
|
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Determine partition prefix
|
||||||
|
if partition_prefix is None:
|
||||||
|
partition_prefix = Path(index_prefix_path).name
|
||||||
|
|
||||||
|
# Get executable paths
|
||||||
|
partitioner_path = self._get_executable_path("partitioner")
|
||||||
|
relayout_path = self._get_executable_path("index_relayout")
|
||||||
|
|
||||||
|
# Create temporary directory for processing
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
# Change to the graph_partition directory for temporary files
|
||||||
|
graph_partition_dir = (
|
||||||
|
Path(__file__).parent.parent / "third_party" / "DiskANN" / "graph_partition"
|
||||||
|
)
|
||||||
|
original_dir = os.getcwd()
|
||||||
|
|
||||||
|
try:
|
||||||
|
os.chdir(graph_partition_dir)
|
||||||
|
|
||||||
|
# Create temporary data directory
|
||||||
|
temp_data_dir = Path(temp_dir) / "data"
|
||||||
|
temp_data_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Set up paths for temporary files
|
||||||
|
graph_path = temp_data_dir / "starling" / "_M_R_L_B" / "GRAPH"
|
||||||
|
graph_gp_path = (
|
||||||
|
graph_path
|
||||||
|
/ f"GP_TIMES_{params['gp_times']}_LOCK_{params['lock_nums']}_GP_USE_FREQ0_CUT{params['cut']}_SCALE{params['scale_factor']}"
|
||||||
|
)
|
||||||
|
graph_gp_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Find input index file
|
||||||
|
old_index_file = f"{index_prefix_path}_disk_beam_search.index"
|
||||||
|
if not os.path.exists(old_index_file):
|
||||||
|
old_index_file = f"{index_prefix_path}_disk.index"
|
||||||
|
|
||||||
|
if not os.path.exists(old_index_file):
|
||||||
|
raise RuntimeError(f"Index file not found: {old_index_file}")
|
||||||
|
|
||||||
|
# Run partitioner
|
||||||
|
gp_file_path = graph_gp_path / "_part.bin"
|
||||||
|
partitioner_cmd = [
|
||||||
|
partitioner_path,
|
||||||
|
"--index_file",
|
||||||
|
old_index_file,
|
||||||
|
"--data_type",
|
||||||
|
params["data_type"],
|
||||||
|
"--gp_file",
|
||||||
|
str(gp_file_path),
|
||||||
|
"-T",
|
||||||
|
str(params["thread_nums"]),
|
||||||
|
"--ldg_times",
|
||||||
|
str(params["gp_times"]),
|
||||||
|
"--scale",
|
||||||
|
str(params["scale_factor"]),
|
||||||
|
"--mode",
|
||||||
|
"1",
|
||||||
|
]
|
||||||
|
|
||||||
|
print(f"Running partitioner: {' '.join(partitioner_cmd)}")
|
||||||
|
result = subprocess.run(
|
||||||
|
partitioner_cmd, capture_output=True, text=True, cwd=graph_partition_dir
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.returncode != 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Partitioner failed with return code {result.returncode}.\n"
|
||||||
|
f"stdout: {result.stdout}\n"
|
||||||
|
f"stderr: {result.stderr}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run relayout
|
||||||
|
part_tmp_index = graph_gp_path / "_part_tmp.index"
|
||||||
|
relayout_cmd = [
|
||||||
|
relayout_path,
|
||||||
|
old_index_file,
|
||||||
|
str(gp_file_path),
|
||||||
|
params["data_type"],
|
||||||
|
"1",
|
||||||
|
]
|
||||||
|
|
||||||
|
print(f"Running relayout: {' '.join(relayout_cmd)}")
|
||||||
|
result = subprocess.run(
|
||||||
|
relayout_cmd, capture_output=True, text=True, cwd=graph_partition_dir
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.returncode != 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Relayout failed with return code {result.returncode}.\n"
|
||||||
|
f"stdout: {result.stdout}\n"
|
||||||
|
f"stderr: {result.stderr}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Copy results to output directory
|
||||||
|
disk_graph_path = Path(output_dir) / f"{partition_prefix}_disk_graph.index"
|
||||||
|
partition_bin_path = Path(output_dir) / f"{partition_prefix}_partition.bin"
|
||||||
|
|
||||||
|
shutil.copy2(part_tmp_index, disk_graph_path)
|
||||||
|
shutil.copy2(gp_file_path, partition_bin_path)
|
||||||
|
|
||||||
|
print(f"Results copied to: {output_dir}")
|
||||||
|
return str(disk_graph_path), str(partition_bin_path)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
os.chdir(original_dir)
|
||||||
|
|
||||||
|
def get_partition_info(self, partition_bin_path: str) -> dict:
|
||||||
|
"""
|
||||||
|
Get information about a partition file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
partition_bin_path: Path to the partition binary file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary containing partition information
|
||||||
|
"""
|
||||||
|
if not os.path.exists(partition_bin_path):
|
||||||
|
raise FileNotFoundError(f"Partition file not found: {partition_bin_path}")
|
||||||
|
|
||||||
|
# For now, return basic file information
|
||||||
|
# In the future, this could parse the binary file for detailed info
|
||||||
|
stat = os.stat(partition_bin_path)
|
||||||
|
return {
|
||||||
|
"file_size": stat.st_size,
|
||||||
|
"file_path": partition_bin_path,
|
||||||
|
"modified_time": stat.st_mtime,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def partition_graph(
|
||||||
|
index_prefix_path: str,
|
||||||
|
output_dir: Optional[str] = None,
|
||||||
|
partition_prefix: Optional[str] = None,
|
||||||
|
build_type: str = "release",
|
||||||
|
**kwargs,
|
||||||
|
) -> tuple[str, str]:
|
||||||
|
"""
|
||||||
|
Convenience function to partition a graph index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index_prefix_path: Path to the index prefix
|
||||||
|
output_dir: Output directory (defaults to parent of index_prefix_path)
|
||||||
|
partition_prefix: Prefix for output files (defaults to basename of index_prefix_path)
|
||||||
|
build_type: Build type for executables ("debug" or "release")
|
||||||
|
**kwargs: Additional parameters for graph partitioning
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (disk_graph_index_path, partition_bin_path)
|
||||||
|
"""
|
||||||
|
partitioner = GraphPartitioner(build_type=build_type)
|
||||||
|
return partitioner.partition_graph(index_prefix_path, output_dir, partition_prefix, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# Example usage:
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Example: partition an index
|
||||||
|
try:
|
||||||
|
disk_graph_path, partition_bin_path = partition_graph(
|
||||||
|
"/path/to/your/index_prefix", gp_times=10, lock_nums=10, cut=100
|
||||||
|
)
|
||||||
|
print("Partitioning completed successfully!")
|
||||||
|
print(f"Disk graph index: {disk_graph_path}")
|
||||||
|
print(f"Partition binary: {partition_bin_path}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Partitioning failed: {e}")
|
||||||
@@ -4,8 +4,8 @@ build-backend = "scikit_build_core.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann-backend-diskann"
|
name = "leann-backend-diskann"
|
||||||
version = "0.2.8"
|
version = "0.3.4"
|
||||||
dependencies = ["leann-core==0.2.8", "numpy", "protobuf>=3.19.0"]
|
dependencies = ["leann-core==0.3.4", "numpy", "protobuf>=3.19.0"]
|
||||||
|
|
||||||
[tool.scikit-build]
|
[tool.scikit-build]
|
||||||
# Key: simplified CMake path
|
# Key: simplified CMake path
|
||||||
|
|||||||
Submodule packages/leann-backend-diskann/third_party/DiskANN updated: 04048bb302...19f9603c72
@@ -49,9 +49,28 @@ set(BUILD_TESTING OFF CACHE BOOL "" FORCE)
|
|||||||
set(FAISS_ENABLE_C_API OFF CACHE BOOL "" FORCE)
|
set(FAISS_ENABLE_C_API OFF CACHE BOOL "" FORCE)
|
||||||
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
|
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
|
||||||
|
|
||||||
# Disable additional SIMD versions to speed up compilation
|
# Disable x86-specific SIMD optimizations (important for ARM64 compatibility)
|
||||||
set(FAISS_ENABLE_AVX2 OFF CACHE BOOL "" FORCE)
|
set(FAISS_ENABLE_AVX2 OFF CACHE BOOL "" FORCE)
|
||||||
set(FAISS_ENABLE_AVX512 OFF CACHE BOOL "" FORCE)
|
set(FAISS_ENABLE_AVX512 OFF CACHE BOOL "" FORCE)
|
||||||
|
set(FAISS_ENABLE_SSE4_1 OFF CACHE BOOL "" FORCE)
|
||||||
|
|
||||||
|
# ARM64-specific configuration
|
||||||
|
if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64")
|
||||||
|
message(STATUS "Configuring Faiss for ARM64 architecture")
|
||||||
|
|
||||||
|
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
|
||||||
|
# Use SVE optimization level for ARM64 Linux (as seen in Faiss conda build)
|
||||||
|
set(FAISS_OPT_LEVEL "sve" CACHE STRING "" FORCE)
|
||||||
|
message(STATUS "Setting FAISS_OPT_LEVEL to 'sve' for ARM64 Linux")
|
||||||
|
else()
|
||||||
|
# Use generic optimization for other ARM64 platforms (like macOS)
|
||||||
|
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
|
||||||
|
message(STATUS "Setting FAISS_OPT_LEVEL to 'generic' for ARM64 ${CMAKE_SYSTEM_NAME}")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# ARM64 compatibility: Faiss submodule has been modified to fix x86 header inclusion
|
||||||
|
message(STATUS "Using ARM64-compatible Faiss submodule")
|
||||||
|
endif()
|
||||||
|
|
||||||
# Additional optimization options from INSTALL.md
|
# Additional optimization options from INSTALL.md
|
||||||
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "" FORCE)
|
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "" FORCE)
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import gc # Import garbage collector interface
|
import gc # Import garbage collector interface
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import struct
|
import struct
|
||||||
import sys
|
import sys
|
||||||
@@ -7,6 +8,12 @@ import time
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
# Set up logging to avoid print buffer issues
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||||
|
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
|
||||||
|
logger.setLevel(log_level)
|
||||||
|
|
||||||
# --- FourCCs (add more if needed) ---
|
# --- FourCCs (add more if needed) ---
|
||||||
INDEX_HNSW_FLAT_FOURCC = int.from_bytes(b"IHNf", "little")
|
INDEX_HNSW_FLAT_FOURCC = int.from_bytes(b"IHNf", "little")
|
||||||
# Add other HNSW fourccs if you expect different storage types inside HNSW
|
# Add other HNSW fourccs if you expect different storage types inside HNSW
|
||||||
@@ -243,6 +250,8 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
|||||||
output_filename: Output CSR index file
|
output_filename: Output CSR index file
|
||||||
prune_embeddings: Whether to prune embedding storage (write NULL storage marker)
|
prune_embeddings: Whether to prune embedding storage (write NULL storage marker)
|
||||||
"""
|
"""
|
||||||
|
# Keep prints simple; rely on CI runner to flush output as needed
|
||||||
|
|
||||||
print(f"Starting conversion: {input_filename} -> {output_filename}")
|
print(f"Starting conversion: {input_filename} -> {output_filename}")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
original_hnsw_data = {}
|
original_hnsw_data = {}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Literal, Optional
|
from typing import Any, Literal, Optional
|
||||||
|
|
||||||
@@ -54,12 +55,13 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
|||||||
self.efConstruction = self.build_params.setdefault("efConstruction", 200)
|
self.efConstruction = self.build_params.setdefault("efConstruction", 200)
|
||||||
self.distance_metric = self.build_params.setdefault("distance_metric", "mips")
|
self.distance_metric = self.build_params.setdefault("distance_metric", "mips")
|
||||||
self.dimensions = self.build_params.get("dimensions")
|
self.dimensions = self.build_params.get("dimensions")
|
||||||
if not self.is_recompute:
|
if not self.is_recompute and self.is_compact:
|
||||||
if self.is_compact:
|
# Auto-correct: non-recompute requires non-compact storage for HNSW
|
||||||
# TODO: support this case @andy
|
logger.warning(
|
||||||
raise ValueError(
|
"is_recompute=False requires non-compact HNSW. Forcing is_compact=False."
|
||||||
"is_recompute is False, but is_compact is True. This is not compatible now. change is compact to False and you can use the original HNSW index."
|
)
|
||||||
)
|
self.is_compact = False
|
||||||
|
self.build_params["is_compact"] = False
|
||||||
|
|
||||||
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
|
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
|
||||||
from . import faiss # type: ignore
|
from . import faiss # type: ignore
|
||||||
@@ -184,9 +186,11 @@ class HNSWSearcher(BaseSearcher):
|
|||||||
"""
|
"""
|
||||||
from . import faiss # type: ignore
|
from . import faiss # type: ignore
|
||||||
|
|
||||||
if not recompute_embeddings:
|
if not recompute_embeddings and self.is_pruned:
|
||||||
if self.is_pruned:
|
raise RuntimeError(
|
||||||
raise RuntimeError("Recompute is required for pruned index.")
|
"Recompute is required for pruned/compact HNSW index. "
|
||||||
|
"Re-run search with --recompute, or rebuild with --no-recompute and --no-compact."
|
||||||
|
)
|
||||||
if recompute_embeddings:
|
if recompute_embeddings:
|
||||||
if zmq_port is None:
|
if zmq_port is None:
|
||||||
raise ValueError("zmq_port must be provided if recompute_embeddings is True")
|
raise ValueError("zmq_port must be provided if recompute_embeddings is True")
|
||||||
@@ -233,6 +237,7 @@ class HNSWSearcher(BaseSearcher):
|
|||||||
distances = np.empty((batch_size_query, top_k), dtype=np.float32)
|
distances = np.empty((batch_size_query, top_k), dtype=np.float32)
|
||||||
labels = np.empty((batch_size_query, top_k), dtype=np.int64)
|
labels = np.empty((batch_size_query, top_k), dtype=np.int64)
|
||||||
|
|
||||||
|
search_time = time.time()
|
||||||
self._index.search(
|
self._index.search(
|
||||||
query.shape[0],
|
query.shape[0],
|
||||||
faiss.swig_ptr(query),
|
faiss.swig_ptr(query),
|
||||||
@@ -241,7 +246,8 @@ class HNSWSearcher(BaseSearcher):
|
|||||||
faiss.swig_ptr(labels),
|
faiss.swig_ptr(labels),
|
||||||
params,
|
params,
|
||||||
)
|
)
|
||||||
|
search_time = time.time() - search_time
|
||||||
|
logger.info(f" Search time in HNSWSearcher.search() backend: {search_time} seconds")
|
||||||
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
|
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
|
||||||
|
|
||||||
return {"labels": string_labels, "distances": distances}
|
return {"labels": string_labels, "distances": distances}
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import sys
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Optional
|
||||||
|
|
||||||
import msgpack
|
import msgpack
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -34,7 +34,7 @@ if not logger.handlers:
|
|||||||
|
|
||||||
|
|
||||||
def create_hnsw_embedding_server(
|
def create_hnsw_embedding_server(
|
||||||
passages_file: Union[str, None] = None,
|
passages_file: Optional[str] = None,
|
||||||
zmq_port: int = 5555,
|
zmq_port: int = 5555,
|
||||||
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
||||||
distance_metric: str = "mips",
|
distance_metric: str = "mips",
|
||||||
@@ -82,199 +82,315 @@ def create_hnsw_embedding_server(
|
|||||||
with open(passages_file) as f:
|
with open(passages_file) as f:
|
||||||
meta = json.load(f)
|
meta = json.load(f)
|
||||||
|
|
||||||
# Convert relative paths to absolute paths based on metadata file location
|
# Let PassageManager handle path resolution uniformly. It supports fallback order:
|
||||||
metadata_dir = Path(passages_file).parent.parent # Go up one level from the metadata file
|
# 1) path/index_path; 2) *_relative; 3) standard siblings next to meta
|
||||||
passage_sources = []
|
passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file)
|
||||||
for source in meta["passage_sources"]:
|
# Dimension from metadata for shaping responses
|
||||||
source_copy = source.copy()
|
try:
|
||||||
# Convert relative paths to absolute paths
|
embedding_dim: int = int(meta.get("dimensions", 0))
|
||||||
if not Path(source_copy["path"]).is_absolute():
|
except Exception:
|
||||||
source_copy["path"] = str(metadata_dir / source_copy["path"])
|
embedding_dim = 0
|
||||||
if not Path(source_copy["index_path"]).is_absolute():
|
logger.info(f"Loaded PassageManager with {len(passages)} passages from metadata")
|
||||||
source_copy["index_path"] = str(metadata_dir / source_copy["index_path"])
|
|
||||||
passage_sources.append(source_copy)
|
|
||||||
|
|
||||||
passages = PassageManager(passage_sources)
|
# (legacy ZMQ thread removed; using shutdown-capable server only)
|
||||||
logger.info(
|
|
||||||
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
def zmq_server_thread_with_shutdown(shutdown_event):
|
||||||
)
|
"""ZMQ server thread that respects shutdown signal.
|
||||||
|
|
||||||
|
Creates its own REP socket bound to zmq_port and polls with timeouts
|
||||||
|
to allow graceful shutdown.
|
||||||
|
"""
|
||||||
|
logger.info("ZMQ server thread started with shutdown support")
|
||||||
|
|
||||||
def zmq_server_thread():
|
|
||||||
"""ZMQ server thread"""
|
|
||||||
context = zmq.Context()
|
context = zmq.Context()
|
||||||
socket = context.socket(zmq.REP)
|
rep_socket = context.socket(zmq.REP)
|
||||||
socket.bind(f"tcp://*:{zmq_port}")
|
rep_socket.bind(f"tcp://*:{zmq_port}")
|
||||||
logger.info(f"HNSW ZMQ server listening on port {zmq_port}")
|
logger.info(f"HNSW ZMQ REP server listening on port {zmq_port}")
|
||||||
|
rep_socket.setsockopt(zmq.RCVTIMEO, 1000)
|
||||||
|
# Keep sends from blocking during shutdown; fail fast and drop on close
|
||||||
|
rep_socket.setsockopt(zmq.SNDTIMEO, 1000)
|
||||||
|
rep_socket.setsockopt(zmq.LINGER, 0)
|
||||||
|
|
||||||
socket.setsockopt(zmq.RCVTIMEO, 300000)
|
# Track last request type/length for shape-correct fallbacks
|
||||||
socket.setsockopt(zmq.SNDTIMEO, 300000)
|
last_request_type = "unknown" # 'text' | 'distance' | 'embedding' | 'unknown'
|
||||||
|
last_request_length = 0
|
||||||
|
|
||||||
while True:
|
try:
|
||||||
try:
|
while not shutdown_event.is_set():
|
||||||
message_bytes = socket.recv()
|
try:
|
||||||
logger.debug(f"Received ZMQ request of size {len(message_bytes)} bytes")
|
e2e_start = time.time()
|
||||||
|
logger.debug("🔍 Waiting for ZMQ message...")
|
||||||
|
request_bytes = rep_socket.recv()
|
||||||
|
|
||||||
e2e_start = time.time()
|
# Rest of the processing logic (same as original)
|
||||||
request_payload = msgpack.unpackb(message_bytes)
|
request = msgpack.unpackb(request_bytes)
|
||||||
|
|
||||||
# Handle direct text embedding request
|
if len(request) == 1 and request[0] == "__QUERY_MODEL__":
|
||||||
if isinstance(request_payload, list) and len(request_payload) > 0:
|
response_bytes = msgpack.packb([model_name])
|
||||||
# Check if this is a direct text request (list of strings)
|
rep_socket.send(response_bytes)
|
||||||
if all(isinstance(item, str) for item in request_payload):
|
continue
|
||||||
logger.info(
|
|
||||||
f"Processing direct text embedding request for {len(request_payload)} texts in {embedding_mode} mode"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use unified embedding computation (now with model caching)
|
# Handle direct text embedding request
|
||||||
embeddings = compute_embeddings(
|
if (
|
||||||
request_payload, model_name, mode=embedding_mode
|
isinstance(request, list)
|
||||||
)
|
and request
|
||||||
|
and all(isinstance(item, str) for item in request)
|
||||||
response = embeddings.tolist()
|
):
|
||||||
socket.send(msgpack.packb(response))
|
last_request_type = "text"
|
||||||
|
last_request_length = len(request)
|
||||||
|
embeddings = compute_embeddings(request, model_name, mode=embedding_mode)
|
||||||
|
rep_socket.send(msgpack.packb(embeddings.tolist()))
|
||||||
e2e_end = time.time()
|
e2e_end = time.time()
|
||||||
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
|
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Handle distance calculation requests
|
# Handle distance calculation request: [[ids], [query_vector]]
|
||||||
if (
|
if (
|
||||||
isinstance(request_payload, list)
|
isinstance(request, list)
|
||||||
and len(request_payload) == 2
|
and len(request) == 2
|
||||||
and isinstance(request_payload[0], list)
|
and isinstance(request[0], list)
|
||||||
and isinstance(request_payload[1], list)
|
and isinstance(request[1], list)
|
||||||
):
|
):
|
||||||
node_ids = request_payload[0]
|
node_ids = request[0]
|
||||||
query_vector = np.array(request_payload[1], dtype=np.float32)
|
# Handle nested [[ids]] shape defensively
|
||||||
|
if len(node_ids) == 1 and isinstance(node_ids[0], list):
|
||||||
|
node_ids = node_ids[0]
|
||||||
|
query_vector = np.array(request[1], dtype=np.float32)
|
||||||
|
last_request_type = "distance"
|
||||||
|
last_request_length = len(node_ids)
|
||||||
|
|
||||||
logger.debug("Distance calculation request received")
|
logger.debug("Distance calculation request received")
|
||||||
logger.debug(f" Node IDs: {node_ids}")
|
logger.debug(f" Node IDs: {node_ids}")
|
||||||
logger.debug(f" Query vector dim: {len(query_vector)}")
|
logger.debug(f" Query vector dim: {len(query_vector)}")
|
||||||
|
|
||||||
# Get embeddings for node IDs
|
# Gather texts for found ids
|
||||||
texts = []
|
texts: list[str] = []
|
||||||
for nid in node_ids:
|
found_indices: list[int] = []
|
||||||
|
for idx, nid in enumerate(node_ids):
|
||||||
|
try:
|
||||||
|
passage_data = passages.get_passage(str(nid))
|
||||||
|
txt = passage_data.get("text", "")
|
||||||
|
if isinstance(txt, str) and len(txt) > 0:
|
||||||
|
texts.append(txt)
|
||||||
|
found_indices.append(idx)
|
||||||
|
else:
|
||||||
|
logger.error(f"Empty text for passage ID {nid}")
|
||||||
|
except KeyError:
|
||||||
|
logger.error(f"Passage ID {nid} not found")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
||||||
|
|
||||||
|
# Prepare full-length response with large sentinel values
|
||||||
|
large_distance = 1e9
|
||||||
|
response_distances = [large_distance] * len(node_ids)
|
||||||
|
|
||||||
|
if texts:
|
||||||
|
try:
|
||||||
|
embeddings = compute_embeddings(
|
||||||
|
texts, model_name, mode=embedding_mode
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||||
|
)
|
||||||
|
if distance_metric == "l2":
|
||||||
|
partial = np.sum(
|
||||||
|
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
|
||||||
|
)
|
||||||
|
else: # mips or cosine
|
||||||
|
partial = -np.dot(embeddings, query_vector)
|
||||||
|
|
||||||
|
for pos, dval in zip(found_indices, partial.flatten().tolist()):
|
||||||
|
response_distances[pos] = float(dval)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Distance computation error, using sentinels: {e}")
|
||||||
|
|
||||||
|
# Send response in expected shape [[distances]]
|
||||||
|
rep_socket.send(msgpack.packb([response_distances], use_single_float=True))
|
||||||
|
e2e_end = time.time()
|
||||||
|
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Fallback: treat as embedding-by-id request
|
||||||
|
if (
|
||||||
|
isinstance(request, list)
|
||||||
|
and len(request) == 1
|
||||||
|
and isinstance(request[0], list)
|
||||||
|
):
|
||||||
|
node_ids = request[0]
|
||||||
|
elif isinstance(request, list):
|
||||||
|
node_ids = request
|
||||||
|
else:
|
||||||
|
node_ids = []
|
||||||
|
last_request_type = "embedding"
|
||||||
|
last_request_length = len(node_ids)
|
||||||
|
logger.info(f"ZMQ received {len(node_ids)} node IDs for embedding fetch")
|
||||||
|
|
||||||
|
# Preallocate zero-filled flat data for robustness
|
||||||
|
if embedding_dim <= 0:
|
||||||
|
dims = [0, 0]
|
||||||
|
flat_data: list[float] = []
|
||||||
|
else:
|
||||||
|
dims = [len(node_ids), embedding_dim]
|
||||||
|
flat_data = [0.0] * (dims[0] * dims[1])
|
||||||
|
|
||||||
|
# Collect texts for found ids
|
||||||
|
texts: list[str] = []
|
||||||
|
found_indices: list[int] = []
|
||||||
|
for idx, nid in enumerate(node_ids):
|
||||||
try:
|
try:
|
||||||
passage_data = passages.get_passage(str(nid))
|
passage_data = passages.get_passage(str(nid))
|
||||||
txt = passage_data["text"]
|
txt = passage_data.get("text", "")
|
||||||
texts.append(txt)
|
if isinstance(txt, str) and len(txt) > 0:
|
||||||
|
texts.append(txt)
|
||||||
|
found_indices.append(idx)
|
||||||
|
else:
|
||||||
|
logger.error(f"Empty text for passage ID {nid}")
|
||||||
except KeyError:
|
except KeyError:
|
||||||
logger.error(f"Passage ID {nid} not found")
|
logger.error(f"Passage with ID {nid} not found")
|
||||||
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
||||||
raise
|
|
||||||
|
|
||||||
# Process embeddings
|
if texts:
|
||||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
try:
|
||||||
logger.info(
|
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
logger.info(
|
||||||
)
|
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
# Calculate distances
|
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||||
if distance_metric == "l2":
|
logger.error(
|
||||||
distances = np.sum(
|
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
|
||||||
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
|
)
|
||||||
)
|
dims = [0, embedding_dim]
|
||||||
else: # mips or cosine
|
flat_data = []
|
||||||
distances = -np.dot(embeddings, query_vector)
|
else:
|
||||||
|
emb_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||||
|
flat = emb_f32.flatten().tolist()
|
||||||
|
for j, pos in enumerate(found_indices):
|
||||||
|
start = pos * embedding_dim
|
||||||
|
end = start + embedding_dim
|
||||||
|
if end <= len(flat_data):
|
||||||
|
flat_data[start:end] = flat[
|
||||||
|
j * embedding_dim : (j + 1) * embedding_dim
|
||||||
|
]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Embedding computation error, returning zeros: {e}")
|
||||||
|
|
||||||
response_payload = distances.flatten().tolist()
|
response_payload = [dims, flat_data]
|
||||||
response_bytes = msgpack.packb([response_payload], use_single_float=True)
|
response_bytes = msgpack.packb(response_payload, use_single_float=True)
|
||||||
logger.debug(f"Sending distance response with {len(distances)} distances")
|
|
||||||
|
|
||||||
socket.send(response_bytes)
|
rep_socket.send(response_bytes)
|
||||||
e2e_end = time.time()
|
e2e_end = time.time()
|
||||||
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
|
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
|
||||||
|
|
||||||
|
except zmq.Again:
|
||||||
|
# Timeout - check shutdown_event and continue
|
||||||
continue
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
if not shutdown_event.is_set():
|
||||||
|
logger.error(f"Error in ZMQ server loop: {e}")
|
||||||
|
# Shape-correct fallback
|
||||||
|
try:
|
||||||
|
if last_request_type == "distance":
|
||||||
|
large_distance = 1e9
|
||||||
|
fallback_len = max(0, int(last_request_length))
|
||||||
|
safe = [[large_distance] * fallback_len]
|
||||||
|
elif last_request_type == "embedding":
|
||||||
|
bsz = max(0, int(last_request_length))
|
||||||
|
dim = max(0, int(embedding_dim))
|
||||||
|
safe = (
|
||||||
|
[[bsz, dim], [0.0] * (bsz * dim)] if dim > 0 else [[0, 0], []]
|
||||||
|
)
|
||||||
|
elif last_request_type == "text":
|
||||||
|
safe = [] # direct text embeddings expectation is a flat list
|
||||||
|
else:
|
||||||
|
safe = [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
|
||||||
|
rep_socket.send(msgpack.packb(safe, use_single_float=True))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
logger.info("Shutdown in progress, ignoring ZMQ error")
|
||||||
|
break
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
rep_socket.close(0)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
context.term()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
# Standard embedding request (passage ID lookup)
|
logger.info("ZMQ server thread exiting gracefully")
|
||||||
if (
|
|
||||||
not isinstance(request_payload, list)
|
|
||||||
or len(request_payload) != 1
|
|
||||||
or not isinstance(request_payload[0], list)
|
|
||||||
):
|
|
||||||
logger.error(
|
|
||||||
f"Invalid MessagePack request format. Expected [[ids...]] or [texts...], got: {type(request_payload)}"
|
|
||||||
)
|
|
||||||
socket.send(msgpack.packb([[], []]))
|
|
||||||
continue
|
|
||||||
|
|
||||||
node_ids = request_payload[0]
|
# Add shutdown coordination
|
||||||
logger.debug(f"Request for {len(node_ids)} node embeddings")
|
shutdown_event = threading.Event()
|
||||||
|
|
||||||
# Look up texts by node IDs
|
def shutdown_zmq_server():
|
||||||
texts = []
|
"""Gracefully shutdown ZMQ server."""
|
||||||
for nid in node_ids:
|
logger.info("Initiating graceful shutdown...")
|
||||||
try:
|
shutdown_event.set()
|
||||||
passage_data = passages.get_passage(str(nid))
|
|
||||||
txt = passage_data["text"]
|
|
||||||
if not txt:
|
|
||||||
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
|
|
||||||
texts.append(txt)
|
|
||||||
except KeyError:
|
|
||||||
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
# Process embeddings
|
if zmq_thread.is_alive():
|
||||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
logger.info("Waiting for ZMQ thread to finish...")
|
||||||
logger.info(
|
zmq_thread.join(timeout=5)
|
||||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
if zmq_thread.is_alive():
|
||||||
)
|
logger.warning("ZMQ thread did not finish in time")
|
||||||
|
|
||||||
# Serialization and response
|
# Clean up ZMQ resources
|
||||||
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
try:
|
||||||
logger.error(
|
# Note: socket and context are cleaned up by thread exit
|
||||||
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
|
logger.info("ZMQ resources cleaned up")
|
||||||
)
|
except Exception as e:
|
||||||
raise AssertionError()
|
logger.warning(f"Error cleaning ZMQ resources: {e}")
|
||||||
|
|
||||||
hidden_contiguous_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
|
# Clean up other resources
|
||||||
response_payload = [
|
try:
|
||||||
list(hidden_contiguous_f32.shape),
|
import gc
|
||||||
hidden_contiguous_f32.flatten().tolist(),
|
|
||||||
]
|
|
||||||
response_bytes = msgpack.packb(response_payload, use_single_float=True)
|
|
||||||
|
|
||||||
socket.send(response_bytes)
|
gc.collect()
|
||||||
e2e_end = time.time()
|
logger.info("Additional resources cleaned up")
|
||||||
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
|
except Exception as e:
|
||||||
|
logger.warning(f"Error cleaning additional resources: {e}")
|
||||||
|
|
||||||
except zmq.Again:
|
logger.info("Graceful shutdown completed")
|
||||||
logger.debug("ZMQ socket timeout, continuing to listen")
|
sys.exit(0)
|
||||||
continue
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in ZMQ server loop: {e}")
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
traceback.print_exc()
|
# Register signal handlers within this function scope
|
||||||
socket.send(msgpack.packb([[], []]))
|
import signal
|
||||||
|
|
||||||
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
|
def signal_handler(sig, frame):
|
||||||
|
logger.info(f"Received signal {sig}, shutting down gracefully...")
|
||||||
|
shutdown_zmq_server()
|
||||||
|
|
||||||
|
signal.signal(signal.SIGTERM, signal_handler)
|
||||||
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
|
|
||||||
|
# Pass shutdown_event to ZMQ thread
|
||||||
|
zmq_thread = threading.Thread(
|
||||||
|
target=lambda: zmq_server_thread_with_shutdown(shutdown_event),
|
||||||
|
daemon=False, # Not daemon - we want to wait for it
|
||||||
|
)
|
||||||
zmq_thread.start()
|
zmq_thread.start()
|
||||||
logger.info(f"Started HNSW ZMQ server thread on port {zmq_port}")
|
logger.info(f"Started HNSW ZMQ server thread on port {zmq_port}")
|
||||||
|
|
||||||
# Keep the main thread alive
|
# Keep the main thread alive
|
||||||
try:
|
try:
|
||||||
while True:
|
while not shutdown_event.is_set():
|
||||||
time.sleep(1)
|
time.sleep(0.1) # Check shutdown more frequently
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logger.info("HNSW Server shutting down...")
|
logger.info("HNSW Server shutting down...")
|
||||||
|
shutdown_zmq_server()
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# If we reach here, shutdown was triggered by signal
|
||||||
|
logger.info("Main loop exited, process should be shutting down")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import signal
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
def signal_handler(sig, frame):
|
# Signal handlers are now registered within create_hnsw_embedding_server
|
||||||
logger.info(f"Received signal {sig}, shutting down gracefully...")
|
|
||||||
sys.exit(0)
|
|
||||||
|
|
||||||
# Register signal handlers for graceful shutdown
|
|
||||||
signal.signal(signal.SIGTERM, signal_handler)
|
|
||||||
signal.signal(signal.SIGINT, signal_handler)
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="HNSW Embedding service")
|
parser = argparse.ArgumentParser(description="HNSW Embedding service")
|
||||||
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
||||||
|
|||||||
@@ -6,10 +6,10 @@ build-backend = "scikit_build_core.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann-backend-hnsw"
|
name = "leann-backend-hnsw"
|
||||||
version = "0.2.8"
|
version = "0.3.4"
|
||||||
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
|
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"leann-core==0.2.8",
|
"leann-core==0.3.4",
|
||||||
"numpy",
|
"numpy",
|
||||||
"pyzmq>=23.0.0",
|
"pyzmq>=23.0.0",
|
||||||
"msgpack>=1.0.0",
|
"msgpack>=1.0.0",
|
||||||
|
|||||||
Submodule packages/leann-backend-hnsw/third_party/faiss updated: 4a2c0d67d3...ed96ff7dba
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann-core"
|
name = "leann-core"
|
||||||
version = "0.2.8"
|
version = "0.3.4"
|
||||||
description = "Core API and plugin system for LEANN"
|
description = "Core API and plugin system for LEANN"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.9"
|
||||||
|
|||||||
@@ -6,11 +6,13 @@ with the correct, original embedding logic from the user's reference code.
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import pickle
|
import pickle
|
||||||
|
import re
|
||||||
|
import subprocess
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Literal, Optional
|
from typing import Any, Literal, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -18,6 +20,7 @@ from leann.interface import LeannBackendSearcherInterface
|
|||||||
|
|
||||||
from .chat import get_llm
|
from .chat import get_llm
|
||||||
from .interface import LeannBackendFactoryInterface
|
from .interface import LeannBackendFactoryInterface
|
||||||
|
from .metadata_filter import MetadataFilterEngine
|
||||||
from .registry import BACKEND_REGISTRY
|
from .registry import BACKEND_REGISTRY
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -46,6 +49,7 @@ def compute_embeddings(
|
|||||||
- "sentence-transformers": Use sentence-transformers library (default)
|
- "sentence-transformers": Use sentence-transformers library (default)
|
||||||
- "mlx": Use MLX backend for Apple Silicon
|
- "mlx": Use MLX backend for Apple Silicon
|
||||||
- "openai": Use OpenAI embedding API
|
- "openai": Use OpenAI embedding API
|
||||||
|
- "gemini": Use Google Gemini embedding API
|
||||||
use_server: Whether to use embedding server (True for search, False for build)
|
use_server: Whether to use embedding server (True for search, False for build)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -115,42 +119,156 @@ class SearchResult:
|
|||||||
|
|
||||||
|
|
||||||
class PassageManager:
|
class PassageManager:
|
||||||
def __init__(self, passage_sources: list[dict[str, Any]]):
|
def __init__(
|
||||||
self.offset_maps = {}
|
self, passage_sources: list[dict[str, Any]], metadata_file_path: Optional[str] = None
|
||||||
self.passage_files = {}
|
):
|
||||||
self.global_offset_map = {} # Combined map for fast lookup
|
self.offset_maps: dict[str, dict[str, int]] = {}
|
||||||
|
self.passage_files: dict[str, str] = {}
|
||||||
|
# Avoid materializing a single gigantic global map to reduce memory
|
||||||
|
# footprint on very large corpora (e.g., 60M+ passages). Instead, keep
|
||||||
|
# per-shard maps and do a lightweight per-shard lookup on demand.
|
||||||
|
self._total_count: int = 0
|
||||||
|
self.filter_engine = MetadataFilterEngine() # Initialize filter engine
|
||||||
|
|
||||||
|
# Derive index base name for standard sibling fallbacks, e.g., <index_name>.passages.*
|
||||||
|
index_name_base = None
|
||||||
|
if metadata_file_path:
|
||||||
|
meta_name = Path(metadata_file_path).name
|
||||||
|
if meta_name.endswith(".meta.json"):
|
||||||
|
index_name_base = meta_name[: -len(".meta.json")]
|
||||||
|
|
||||||
for source in passage_sources:
|
for source in passage_sources:
|
||||||
assert source["type"] == "jsonl", "only jsonl is supported"
|
assert source["type"] == "jsonl", "only jsonl is supported"
|
||||||
passage_file = source["path"]
|
passage_file = source.get("path", "")
|
||||||
index_file = source["index_path"] # .idx file
|
index_file = source.get("index_path", "") # .idx file
|
||||||
|
|
||||||
# Fix path resolution for Colab and other environments
|
# Fix path resolution - relative paths should be relative to metadata file directory
|
||||||
if not Path(index_file).is_absolute():
|
def _resolve_candidates(
|
||||||
# If relative path, try to resolve it properly
|
primary: str,
|
||||||
index_file = str(Path(index_file).resolve())
|
relative_key: str,
|
||||||
|
default_name: Optional[str],
|
||||||
|
source_dict: dict[str, Any],
|
||||||
|
) -> list[Path]:
|
||||||
|
"""
|
||||||
|
Build an ordered list of candidate paths. For relative paths specified in
|
||||||
|
metadata, prefer resolution relative to the metadata file directory first,
|
||||||
|
then fall back to CWD-based resolution, and finally to conventional
|
||||||
|
sibling defaults (e.g., <index_base>.passages.idx / .jsonl).
|
||||||
|
"""
|
||||||
|
candidates: list[Path] = []
|
||||||
|
# 1) Primary path
|
||||||
|
if primary:
|
||||||
|
p = Path(primary)
|
||||||
|
if p.is_absolute():
|
||||||
|
candidates.append(p)
|
||||||
|
else:
|
||||||
|
# Prefer metadata-relative resolution for relative paths
|
||||||
|
if metadata_file_path:
|
||||||
|
candidates.append(Path(metadata_file_path).parent / p)
|
||||||
|
# Also consider CWD-relative as a fallback for legacy layouts
|
||||||
|
candidates.append(Path.cwd() / p)
|
||||||
|
# 2) metadata-relative explicit relative key (if present)
|
||||||
|
if metadata_file_path and source_dict.get(relative_key):
|
||||||
|
candidates.append(Path(metadata_file_path).parent / source_dict[relative_key])
|
||||||
|
# 3) metadata-relative standard sibling filename
|
||||||
|
if metadata_file_path and default_name:
|
||||||
|
candidates.append(Path(metadata_file_path).parent / default_name)
|
||||||
|
return candidates
|
||||||
|
|
||||||
|
# Build candidate lists and pick first existing; otherwise keep last candidate for error message
|
||||||
|
idx_default = f"{index_name_base}.passages.idx" if index_name_base else None
|
||||||
|
idx_candidates = _resolve_candidates(
|
||||||
|
index_file, "index_path_relative", idx_default, source
|
||||||
|
)
|
||||||
|
pas_default = f"{index_name_base}.passages.jsonl" if index_name_base else None
|
||||||
|
pas_candidates = _resolve_candidates(passage_file, "path_relative", pas_default, source)
|
||||||
|
|
||||||
|
def _pick_existing(cands: list[Path]) -> str:
|
||||||
|
for c in cands:
|
||||||
|
if c.exists():
|
||||||
|
return str(c.resolve())
|
||||||
|
# Fallback to last candidate (best guess) even if not exists; will error below
|
||||||
|
return str(cands[-1].resolve()) if cands else ""
|
||||||
|
|
||||||
|
index_file = _pick_existing(idx_candidates)
|
||||||
|
passage_file = _pick_existing(pas_candidates)
|
||||||
|
|
||||||
if not Path(index_file).exists():
|
if not Path(index_file).exists():
|
||||||
raise FileNotFoundError(f"Passage index file not found: {index_file}")
|
raise FileNotFoundError(f"Passage index file not found: {index_file}")
|
||||||
|
|
||||||
with open(index_file, "rb") as f:
|
with open(index_file, "rb") as f:
|
||||||
offset_map = pickle.load(f)
|
offset_map: dict[str, int] = pickle.load(f)
|
||||||
self.offset_maps[passage_file] = offset_map
|
self.offset_maps[passage_file] = offset_map
|
||||||
self.passage_files[passage_file] = passage_file
|
self.passage_files[passage_file] = passage_file
|
||||||
|
self._total_count += len(offset_map)
|
||||||
# Build global map for O(1) lookup
|
|
||||||
for passage_id, offset in offset_map.items():
|
|
||||||
self.global_offset_map[passage_id] = (passage_file, offset)
|
|
||||||
|
|
||||||
def get_passage(self, passage_id: str) -> dict[str, Any]:
|
def get_passage(self, passage_id: str) -> dict[str, Any]:
|
||||||
if passage_id in self.global_offset_map:
|
# Fast path: check each shard map (there are typically few shards).
|
||||||
passage_file, offset = self.global_offset_map[passage_id]
|
# This avoids building a massive combined dict while keeping lookups
|
||||||
# Lazy file opening - only open when needed
|
# bounded by the number of shards.
|
||||||
with open(passage_file, encoding="utf-8") as f:
|
for passage_file, offset_map in self.offset_maps.items():
|
||||||
f.seek(offset)
|
try:
|
||||||
return json.loads(f.readline())
|
offset = offset_map[passage_id]
|
||||||
|
with open(passage_file, encoding="utf-8") as f:
|
||||||
|
f.seek(offset)
|
||||||
|
return json.loads(f.readline())
|
||||||
|
except KeyError:
|
||||||
|
continue
|
||||||
raise KeyError(f"Passage ID not found: {passage_id}")
|
raise KeyError(f"Passage ID not found: {passage_id}")
|
||||||
|
|
||||||
|
def filter_search_results(
|
||||||
|
self,
|
||||||
|
search_results: list[SearchResult],
|
||||||
|
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]],
|
||||||
|
) -> list[SearchResult]:
|
||||||
|
"""
|
||||||
|
Apply metadata filters to search results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
search_results: List of SearchResult objects
|
||||||
|
metadata_filters: Filter specifications to apply
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Filtered list of SearchResult objects
|
||||||
|
"""
|
||||||
|
if not metadata_filters:
|
||||||
|
return search_results
|
||||||
|
|
||||||
|
logger.debug(f"Applying metadata filters to {len(search_results)} results")
|
||||||
|
|
||||||
|
# Convert SearchResult objects to dictionaries for the filter engine
|
||||||
|
result_dicts = []
|
||||||
|
for result in search_results:
|
||||||
|
result_dicts.append(
|
||||||
|
{
|
||||||
|
"id": result.id,
|
||||||
|
"score": result.score,
|
||||||
|
"text": result.text,
|
||||||
|
"metadata": result.metadata,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply filters using the filter engine
|
||||||
|
filtered_dicts = self.filter_engine.apply_filters(result_dicts, metadata_filters)
|
||||||
|
|
||||||
|
# Convert back to SearchResult objects
|
||||||
|
filtered_results = []
|
||||||
|
for result_dict in filtered_dicts:
|
||||||
|
filtered_results.append(
|
||||||
|
SearchResult(
|
||||||
|
id=result_dict["id"],
|
||||||
|
score=result_dict["score"],
|
||||||
|
text=result_dict["text"],
|
||||||
|
metadata=result_dict["metadata"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"Filtered results: {len(filtered_results)} remaining")
|
||||||
|
return filtered_results
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return self._total_count
|
||||||
|
|
||||||
|
|
||||||
class LeannBuilder:
|
class LeannBuilder:
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -162,6 +280,18 @@ class LeannBuilder:
|
|||||||
**backend_kwargs,
|
**backend_kwargs,
|
||||||
):
|
):
|
||||||
self.backend_name = backend_name
|
self.backend_name = backend_name
|
||||||
|
# Normalize incompatible combinations early (for consistent metadata)
|
||||||
|
if backend_name == "hnsw":
|
||||||
|
is_recompute = backend_kwargs.get("is_recompute", True)
|
||||||
|
is_compact = backend_kwargs.get("is_compact", True)
|
||||||
|
if is_recompute is False and is_compact is True:
|
||||||
|
warnings.warn(
|
||||||
|
"HNSW with is_recompute=False requires non-compact storage. Forcing is_compact=False.",
|
||||||
|
UserWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
backend_kwargs["is_compact"] = False
|
||||||
|
|
||||||
backend_factory: Optional[LeannBackendFactoryInterface] = BACKEND_REGISTRY.get(backend_name)
|
backend_factory: Optional[LeannBackendFactoryInterface] = BACKEND_REGISTRY.get(backend_name)
|
||||||
if backend_factory is None:
|
if backend_factory is None:
|
||||||
raise ValueError(f"Backend '{backend_name}' not found or not registered.")
|
raise ValueError(f"Backend '{backend_name}' not found or not registered.")
|
||||||
@@ -252,6 +382,23 @@ class LeannBuilder:
|
|||||||
def build_index(self, index_path: str):
|
def build_index(self, index_path: str):
|
||||||
if not self.chunks:
|
if not self.chunks:
|
||||||
raise ValueError("No chunks added.")
|
raise ValueError("No chunks added.")
|
||||||
|
|
||||||
|
# Filter out invalid/empty text chunks early to keep passage and embedding counts aligned
|
||||||
|
valid_chunks: list[dict[str, Any]] = []
|
||||||
|
skipped = 0
|
||||||
|
for chunk in self.chunks:
|
||||||
|
text = chunk.get("text", "")
|
||||||
|
if isinstance(text, str) and text.strip():
|
||||||
|
valid_chunks.append(chunk)
|
||||||
|
else:
|
||||||
|
skipped += 1
|
||||||
|
if skipped > 0:
|
||||||
|
print(
|
||||||
|
f"Warning: Skipping {skipped} empty/invalid text chunk(s). Processing {len(valid_chunks)} valid chunks"
|
||||||
|
)
|
||||||
|
self.chunks = valid_chunks
|
||||||
|
if not self.chunks:
|
||||||
|
raise ValueError("All provided chunks are empty or invalid. Nothing to index.")
|
||||||
if self.dimensions is None:
|
if self.dimensions is None:
|
||||||
self.dimensions = len(
|
self.dimensions = len(
|
||||||
compute_embeddings(
|
compute_embeddings(
|
||||||
@@ -314,8 +461,12 @@ class LeannBuilder:
|
|||||||
"passage_sources": [
|
"passage_sources": [
|
||||||
{
|
{
|
||||||
"type": "jsonl",
|
"type": "jsonl",
|
||||||
"path": str(passages_file),
|
# Preserve existing relative file names (backward-compatible)
|
||||||
"index_path": str(offset_file),
|
"path": passages_file.name,
|
||||||
|
"index_path": offset_file.name,
|
||||||
|
# Add optional redundant relative keys for remote build portability (non-breaking)
|
||||||
|
"path_relative": passages_file.name,
|
||||||
|
"index_path_relative": offset_file.name,
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
@@ -430,8 +581,12 @@ class LeannBuilder:
|
|||||||
"passage_sources": [
|
"passage_sources": [
|
||||||
{
|
{
|
||||||
"type": "jsonl",
|
"type": "jsonl",
|
||||||
"path": str(passages_file),
|
# Preserve existing relative file names (backward-compatible)
|
||||||
"index_path": str(offset_file),
|
"path": passages_file.name,
|
||||||
|
"index_path": offset_file.name,
|
||||||
|
# Add optional redundant relative keys for remote build portability (non-breaking)
|
||||||
|
"path_relative": passages_file.name,
|
||||||
|
"index_path_relative": offset_file.name,
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"built_from_precomputed_embeddings": True,
|
"built_from_precomputed_embeddings": True,
|
||||||
@@ -473,7 +628,12 @@ class LeannSearcher:
|
|||||||
self.embedding_model = self.meta_data["embedding_model"]
|
self.embedding_model = self.meta_data["embedding_model"]
|
||||||
# Support both old and new format
|
# Support both old and new format
|
||||||
self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
|
self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
|
||||||
self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
|
# Delegate portability handling to PassageManager
|
||||||
|
self.passage_manager = PassageManager(
|
||||||
|
self.meta_data.get("passage_sources", []), metadata_file_path=self.meta_path_str
|
||||||
|
)
|
||||||
|
# Preserve backend name for conditional parameter forwarding
|
||||||
|
self.backend_name = backend_name
|
||||||
backend_factory = BACKEND_REGISTRY.get(backend_name)
|
backend_factory = BACKEND_REGISTRY.get(backend_name)
|
||||||
if backend_factory is None:
|
if backend_factory is None:
|
||||||
raise ValueError(f"Backend '{backend_name}' not found.")
|
raise ValueError(f"Backend '{backend_name}' not found.")
|
||||||
@@ -493,15 +653,49 @@ class LeannSearcher:
|
|||||||
recompute_embeddings: bool = True,
|
recompute_embeddings: bool = True,
|
||||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
expected_zmq_port: int = 5557,
|
expected_zmq_port: int = 5557,
|
||||||
|
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
|
||||||
|
batch_size: int = 0,
|
||||||
|
use_grep: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> list[SearchResult]:
|
) -> list[SearchResult]:
|
||||||
|
"""
|
||||||
|
Search for nearest neighbors with optional metadata filtering.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Text query to search for
|
||||||
|
top_k: Number of nearest neighbors to return
|
||||||
|
complexity: Search complexity/candidate list size, higher = more accurate but slower
|
||||||
|
beam_width: Number of parallel search paths/IO requests per iteration
|
||||||
|
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
|
||||||
|
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored codes
|
||||||
|
pruning_strategy: Candidate selection strategy - "global" (default), "local", or "proportional"
|
||||||
|
expected_zmq_port: ZMQ port for embedding server communication
|
||||||
|
metadata_filters: Optional filters to apply to search results based on metadata.
|
||||||
|
Format: {"field_name": {"operator": value}}
|
||||||
|
Supported operators:
|
||||||
|
- Comparison: "==", "!=", "<", "<=", ">", ">="
|
||||||
|
- Membership: "in", "not_in"
|
||||||
|
- String: "contains", "starts_with", "ends_with"
|
||||||
|
Example: {"chapter": {"<=": 5}, "tags": {"in": ["fiction", "drama"]}}
|
||||||
|
**kwargs: Backend-specific parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of SearchResult objects with text, metadata, and similarity scores
|
||||||
|
"""
|
||||||
|
# Handle grep search
|
||||||
|
if use_grep:
|
||||||
|
return self._grep_search(query, top_k)
|
||||||
|
|
||||||
logger.info("🔍 LeannSearcher.search() called:")
|
logger.info("🔍 LeannSearcher.search() called:")
|
||||||
logger.info(f" Query: '{query}'")
|
logger.info(f" Query: '{query}'")
|
||||||
logger.info(f" Top_k: {top_k}")
|
logger.info(f" Top_k: {top_k}")
|
||||||
|
logger.info(f" Metadata filters: {metadata_filters}")
|
||||||
logger.info(f" Additional kwargs: {kwargs}")
|
logger.info(f" Additional kwargs: {kwargs}")
|
||||||
|
|
||||||
# Smart top_k detection and adjustment
|
# Smart top_k detection and adjustment
|
||||||
total_docs = len(self.passage_manager.global_offset_map)
|
# Use PassageManager length (sum of shard sizes) to avoid
|
||||||
|
# depending on a massive combined map
|
||||||
|
total_docs = len(self.passage_manager)
|
||||||
original_top_k = top_k
|
original_top_k = top_k
|
||||||
if top_k > total_docs:
|
if top_k > total_docs:
|
||||||
top_k = total_docs
|
top_k = total_docs
|
||||||
@@ -530,29 +724,39 @@ class LeannSearcher:
|
|||||||
use_server_if_available=recompute_embeddings,
|
use_server_if_available=recompute_embeddings,
|
||||||
zmq_port=zmq_port,
|
zmq_port=zmq_port,
|
||||||
)
|
)
|
||||||
# logger.info(f" Generated embedding shape: {query_embedding.shape}")
|
logger.info(f" Generated embedding shape: {query_embedding.shape}")
|
||||||
time.time() - start_time
|
embedding_time = time.time() - start_time
|
||||||
# logger.info(f" Embedding time: {embedding_time} seconds")
|
logger.info(f" Embedding time: {embedding_time} seconds")
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
backend_search_kwargs: dict[str, Any] = {
|
||||||
|
"complexity": complexity,
|
||||||
|
"beam_width": beam_width,
|
||||||
|
"prune_ratio": prune_ratio,
|
||||||
|
"recompute_embeddings": recompute_embeddings,
|
||||||
|
"pruning_strategy": pruning_strategy,
|
||||||
|
"zmq_port": zmq_port,
|
||||||
|
}
|
||||||
|
# Only HNSW supports batching; forward conditionally
|
||||||
|
if self.backend_name == "hnsw":
|
||||||
|
backend_search_kwargs["batch_size"] = batch_size
|
||||||
|
|
||||||
|
# Merge any extra kwargs last
|
||||||
|
backend_search_kwargs.update(kwargs)
|
||||||
|
|
||||||
results = self.backend_impl.search(
|
results = self.backend_impl.search(
|
||||||
query_embedding,
|
query_embedding,
|
||||||
top_k,
|
top_k,
|
||||||
complexity=complexity,
|
**backend_search_kwargs,
|
||||||
beam_width=beam_width,
|
|
||||||
prune_ratio=prune_ratio,
|
|
||||||
recompute_embeddings=recompute_embeddings,
|
|
||||||
pruning_strategy=pruning_strategy,
|
|
||||||
zmq_port=zmq_port,
|
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
time.time() - start_time
|
search_time = time.time() - start_time
|
||||||
# logger.info(f" Search time: {search_time} seconds")
|
logger.info(f" Search time in search() LEANN searcher: {search_time} seconds")
|
||||||
logger.info(f" Backend returned: labels={len(results.get('labels', [[]])[0])} results")
|
logger.info(f" Backend returned: labels={len(results.get('labels', [[]])[0])} results")
|
||||||
|
|
||||||
enriched_results = []
|
enriched_results = []
|
||||||
if "labels" in results and "distances" in results:
|
if "labels" in results and "distances" in results:
|
||||||
logger.info(f" Processing {len(results['labels'][0])} passage IDs:")
|
logger.info(f" Processing {len(results['labels'][0])} passage IDs:")
|
||||||
|
# Python 3.9 does not support zip(strict=...); lengths are expected to match
|
||||||
for i, (string_id, dist) in enumerate(
|
for i, (string_id, dist) in enumerate(
|
||||||
zip(results["labels"][0], results["distances"][0])
|
zip(results["labels"][0], results["distances"][0])
|
||||||
):
|
):
|
||||||
@@ -580,13 +784,138 @@ class LeannSearcher:
|
|||||||
)
|
)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
RED = "\033[91m"
|
RED = "\033[91m"
|
||||||
|
RESET = "\033[0m"
|
||||||
logger.error(
|
logger.error(
|
||||||
f" {RED}✗{RESET} [{i + 1:2d}] ID: '{string_id}' -> {RED}ERROR: Passage not found!{RESET}"
|
f" {RED}✗{RESET} [{i + 1:2d}] ID: '{string_id}' -> {RED}ERROR: Passage not found!{RESET}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Apply metadata filters if specified
|
||||||
|
if metadata_filters:
|
||||||
|
logger.info(f" 🔍 Applying metadata filters: {metadata_filters}")
|
||||||
|
enriched_results = self.passage_manager.filter_search_results(
|
||||||
|
enriched_results, metadata_filters
|
||||||
|
)
|
||||||
|
|
||||||
|
# Define color codes outside the loop for final message
|
||||||
|
GREEN = "\033[92m"
|
||||||
|
RESET = "\033[0m"
|
||||||
logger.info(f" {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}")
|
logger.info(f" {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}")
|
||||||
return enriched_results
|
return enriched_results
|
||||||
|
|
||||||
|
def _find_jsonl_file(self) -> Optional[str]:
|
||||||
|
"""Find the .jsonl file containing raw passages for grep search"""
|
||||||
|
index_path = Path(self.meta_path_str).parent
|
||||||
|
potential_files = [
|
||||||
|
index_path / "documents.leann.passages.jsonl",
|
||||||
|
index_path.parent / "documents.leann.passages.jsonl",
|
||||||
|
]
|
||||||
|
|
||||||
|
for file_path in potential_files:
|
||||||
|
if file_path.exists():
|
||||||
|
return str(file_path)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _grep_search(self, query: str, top_k: int = 5) -> list[SearchResult]:
|
||||||
|
"""Perform grep-based search on raw passages"""
|
||||||
|
jsonl_file = self._find_jsonl_file()
|
||||||
|
if not jsonl_file:
|
||||||
|
raise FileNotFoundError("No .jsonl passages file found for grep search")
|
||||||
|
|
||||||
|
try:
|
||||||
|
cmd = ["grep", "-i", "-n", query, jsonl_file]
|
||||||
|
result = subprocess.run(cmd, capture_output=True, text=True, check=False)
|
||||||
|
|
||||||
|
if result.returncode == 1:
|
||||||
|
return []
|
||||||
|
elif result.returncode != 0:
|
||||||
|
raise RuntimeError(f"Grep failed: {result.stderr}")
|
||||||
|
|
||||||
|
matches = []
|
||||||
|
for line in result.stdout.strip().split("\n"):
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
parts = line.split(":", 1)
|
||||||
|
if len(parts) != 2:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = json.loads(parts[1])
|
||||||
|
text = data.get("text", "")
|
||||||
|
score = text.lower().count(query.lower())
|
||||||
|
|
||||||
|
matches.append(
|
||||||
|
SearchResult(
|
||||||
|
id=data.get("id", parts[0]),
|
||||||
|
text=text,
|
||||||
|
metadata=data.get("metadata", {}),
|
||||||
|
score=float(score),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
matches.sort(key=lambda x: x.score, reverse=True)
|
||||||
|
return matches[:top_k]
|
||||||
|
|
||||||
|
except FileNotFoundError:
|
||||||
|
raise RuntimeError(
|
||||||
|
"grep command not found. Please install grep or use semantic search."
|
||||||
|
)
|
||||||
|
|
||||||
|
def _python_regex_search(self, query: str, top_k: int = 5) -> list[SearchResult]:
|
||||||
|
"""Fallback regex search"""
|
||||||
|
jsonl_file = self._find_jsonl_file()
|
||||||
|
if not jsonl_file:
|
||||||
|
raise FileNotFoundError("No .jsonl file found")
|
||||||
|
|
||||||
|
pattern = re.compile(re.escape(query), re.IGNORECASE)
|
||||||
|
matches = []
|
||||||
|
|
||||||
|
with open(jsonl_file, encoding="utf-8") as f:
|
||||||
|
for line_num, line in enumerate(f, 1):
|
||||||
|
if pattern.search(line):
|
||||||
|
try:
|
||||||
|
data = json.loads(line.strip())
|
||||||
|
matches.append(
|
||||||
|
SearchResult(
|
||||||
|
id=data.get("id", str(line_num)),
|
||||||
|
text=data.get("text", ""),
|
||||||
|
metadata=data.get("metadata", {}),
|
||||||
|
score=float(len(pattern.findall(data.get("text", "")))),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
matches.sort(key=lambda x: x.score, reverse=True)
|
||||||
|
return matches[:top_k]
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""Explicitly cleanup embedding server resources.
|
||||||
|
This method should be called after you're done using the searcher,
|
||||||
|
especially in test environments or batch processing scenarios.
|
||||||
|
"""
|
||||||
|
backend = getattr(self.backend_impl, "embedding_server_manager", None)
|
||||||
|
if backend is not None:
|
||||||
|
backend.stop_server()
|
||||||
|
|
||||||
|
# Enable automatic cleanup patterns
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc, tb):
|
||||||
|
try:
|
||||||
|
self.cleanup()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
try:
|
||||||
|
self.cleanup()
|
||||||
|
except Exception:
|
||||||
|
# Avoid noisy errors during interpreter shutdown
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class LeannChat:
|
class LeannChat:
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -594,9 +923,15 @@ class LeannChat:
|
|||||||
index_path: str,
|
index_path: str,
|
||||||
llm_config: Optional[dict[str, Any]] = None,
|
llm_config: Optional[dict[str, Any]] = None,
|
||||||
enable_warmup: bool = False,
|
enable_warmup: bool = False,
|
||||||
|
searcher: Optional[LeannSearcher] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.searcher = LeannSearcher(index_path, enable_warmup=enable_warmup, **kwargs)
|
if searcher is None:
|
||||||
|
self.searcher = LeannSearcher(index_path, enable_warmup=enable_warmup, **kwargs)
|
||||||
|
self._owns_searcher = True
|
||||||
|
else:
|
||||||
|
self.searcher = searcher
|
||||||
|
self._owns_searcher = False
|
||||||
self.llm = get_llm(llm_config)
|
self.llm = get_llm(llm_config)
|
||||||
|
|
||||||
def ask(
|
def ask(
|
||||||
@@ -610,6 +945,9 @@ class LeannChat:
|
|||||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
llm_kwargs: Optional[dict[str, Any]] = None,
|
llm_kwargs: Optional[dict[str, Any]] = None,
|
||||||
expected_zmq_port: int = 5557,
|
expected_zmq_port: int = 5557,
|
||||||
|
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
|
||||||
|
batch_size: int = 0,
|
||||||
|
use_grep: bool = False,
|
||||||
**search_kwargs,
|
**search_kwargs,
|
||||||
):
|
):
|
||||||
if llm_kwargs is None:
|
if llm_kwargs is None:
|
||||||
@@ -624,10 +962,12 @@ class LeannChat:
|
|||||||
recompute_embeddings=recompute_embeddings,
|
recompute_embeddings=recompute_embeddings,
|
||||||
pruning_strategy=pruning_strategy,
|
pruning_strategy=pruning_strategy,
|
||||||
expected_zmq_port=expected_zmq_port,
|
expected_zmq_port=expected_zmq_port,
|
||||||
|
metadata_filters=metadata_filters,
|
||||||
|
batch_size=batch_size,
|
||||||
**search_kwargs,
|
**search_kwargs,
|
||||||
)
|
)
|
||||||
search_time = time.time() - search_time
|
search_time = time.time() - search_time
|
||||||
# logger.info(f" Search time: {search_time} seconds")
|
logger.info(f" Search time: {search_time} seconds")
|
||||||
context = "\n\n".join([r.text for r in results])
|
context = "\n\n".join([r.text for r in results])
|
||||||
prompt = (
|
prompt = (
|
||||||
"Here is some retrieved context that might help answer your question:\n\n"
|
"Here is some retrieved context that might help answer your question:\n\n"
|
||||||
@@ -656,3 +996,30 @@ class LeannChat:
|
|||||||
except (KeyboardInterrupt, EOFError):
|
except (KeyboardInterrupt, EOFError):
|
||||||
print("\nGoodbye!")
|
print("\nGoodbye!")
|
||||||
break
|
break
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""Explicitly cleanup embedding server resources.
|
||||||
|
|
||||||
|
This method should be called after you're done using the chat interface,
|
||||||
|
especially in test environments or batch processing scenarios.
|
||||||
|
"""
|
||||||
|
# Only stop the embedding server if this LeannChat instance created the searcher.
|
||||||
|
# When a shared searcher is passed in, avoid shutting down the server to enable reuse.
|
||||||
|
if getattr(self, "_owns_searcher", False) and hasattr(self.searcher, "cleanup"):
|
||||||
|
self.searcher.cleanup()
|
||||||
|
|
||||||
|
# Enable automatic cleanup patterns
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc, tb):
|
||||||
|
try:
|
||||||
|
self.cleanup()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
try:
|
||||||
|
self.cleanup()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|||||||
@@ -422,7 +422,6 @@ class LLMInterface(ABC):
|
|||||||
top_k=10,
|
top_k=10,
|
||||||
complexity=64,
|
complexity=64,
|
||||||
beam_width=8,
|
beam_width=8,
|
||||||
USE_DEFERRED_FETCH=True,
|
|
||||||
skip_search_reorder=True,
|
skip_search_reorder=True,
|
||||||
recompute_beighbor_embeddings=True,
|
recompute_beighbor_embeddings=True,
|
||||||
dedup_node_dis=True,
|
dedup_node_dis=True,
|
||||||
@@ -434,7 +433,6 @@ class LLMInterface(ABC):
|
|||||||
Supported kwargs:
|
Supported kwargs:
|
||||||
- complexity (int): Search complexity parameter (default: 32)
|
- complexity (int): Search complexity parameter (default: 32)
|
||||||
- beam_width (int): Beam width for search (default: 4)
|
- beam_width (int): Beam width for search (default: 4)
|
||||||
- USE_DEFERRED_FETCH (bool): Enable deferred fetch mode (default: False)
|
|
||||||
- skip_search_reorder (bool): Skip search reorder step (default: False)
|
- skip_search_reorder (bool): Skip search reorder step (default: False)
|
||||||
- recompute_beighbor_embeddings (bool): Enable ZMQ embedding server for neighbor recomputation (default: False)
|
- recompute_beighbor_embeddings (bool): Enable ZMQ embedding server for neighbor recomputation (default: False)
|
||||||
- dedup_node_dis (bool): Deduplicate nodes by distance (default: False)
|
- dedup_node_dis (bool): Deduplicate nodes by distance (default: False)
|
||||||
@@ -682,6 +680,60 @@ class HFChat(LLMInterface):
|
|||||||
return response.strip()
|
return response.strip()
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiChat(LLMInterface):
|
||||||
|
"""LLM interface for Google Gemini models."""
|
||||||
|
|
||||||
|
def __init__(self, model: str = "gemini-2.5-flash", api_key: Optional[str] = None):
|
||||||
|
self.model = model
|
||||||
|
self.api_key = api_key or os.getenv("GEMINI_API_KEY")
|
||||||
|
|
||||||
|
if not self.api_key:
|
||||||
|
raise ValueError(
|
||||||
|
"Gemini API key is required. Set GEMINI_API_KEY environment variable or pass api_key parameter."
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Initializing Gemini Chat with model='{model}'")
|
||||||
|
|
||||||
|
try:
|
||||||
|
import google.genai as genai
|
||||||
|
|
||||||
|
self.client = genai.Client(api_key=self.api_key)
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"The 'google-genai' library is required for Gemini models. Please install it with 'uv pip install google-genai'."
|
||||||
|
)
|
||||||
|
|
||||||
|
def ask(self, prompt: str, **kwargs) -> str:
|
||||||
|
logger.info(f"Sending request to Gemini with model {self.model}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from google.genai.types import GenerateContentConfig
|
||||||
|
|
||||||
|
generation_config = GenerateContentConfig(
|
||||||
|
temperature=kwargs.get("temperature", 0.7),
|
||||||
|
max_output_tokens=kwargs.get("max_tokens", 1000),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle top_p parameter
|
||||||
|
if "top_p" in kwargs:
|
||||||
|
generation_config.top_p = kwargs["top_p"]
|
||||||
|
|
||||||
|
response = self.client.models.generate_content(
|
||||||
|
model=self.model,
|
||||||
|
contents=prompt,
|
||||||
|
config=generation_config,
|
||||||
|
)
|
||||||
|
# Handle potential None response text
|
||||||
|
response_text = response.text
|
||||||
|
if response_text is None:
|
||||||
|
logger.warning("Gemini returned None response text")
|
||||||
|
return ""
|
||||||
|
return response_text.strip()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error communicating with Gemini: {e}")
|
||||||
|
return f"Error: Could not get a response from Gemini. Details: {e}"
|
||||||
|
|
||||||
|
|
||||||
class OpenAIChat(LLMInterface):
|
class OpenAIChat(LLMInterface):
|
||||||
"""LLM interface for OpenAI models."""
|
"""LLM interface for OpenAI models."""
|
||||||
|
|
||||||
@@ -795,6 +847,8 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
|
|||||||
return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
|
return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
|
||||||
elif llm_type == "openai":
|
elif llm_type == "openai":
|
||||||
return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key"))
|
return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key"))
|
||||||
|
elif llm_type == "gemini":
|
||||||
|
return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
|
||||||
elif llm_type == "simulated":
|
elif llm_type == "simulated":
|
||||||
return SimulatedChat()
|
return SimulatedChat()
|
||||||
else:
|
else:
|
||||||
|
|||||||
220
packages/leann-core/src/leann/chunking_utils.py
Normal file
220
packages/leann-core/src/leann/chunking_utils.py
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
"""
|
||||||
|
Enhanced chunking utilities with AST-aware code chunking support.
|
||||||
|
Packaged within leann-core so installed wheels can import it reliably.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Code file extensions supported by astchunk
|
||||||
|
CODE_EXTENSIONS = {
|
||||||
|
".py": "python",
|
||||||
|
".java": "java",
|
||||||
|
".cs": "csharp",
|
||||||
|
".ts": "typescript",
|
||||||
|
".tsx": "typescript",
|
||||||
|
".js": "typescript",
|
||||||
|
".jsx": "typescript",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def detect_code_files(documents, code_extensions=None) -> tuple[list, list]:
|
||||||
|
"""Separate documents into code files and regular text files."""
|
||||||
|
if code_extensions is None:
|
||||||
|
code_extensions = CODE_EXTENSIONS
|
||||||
|
|
||||||
|
code_docs = []
|
||||||
|
text_docs = []
|
||||||
|
|
||||||
|
for doc in documents:
|
||||||
|
file_path = doc.metadata.get("file_path", "") or doc.metadata.get("file_name", "")
|
||||||
|
if file_path:
|
||||||
|
file_ext = Path(file_path).suffix.lower()
|
||||||
|
if file_ext in code_extensions:
|
||||||
|
doc.metadata["language"] = code_extensions[file_ext]
|
||||||
|
doc.metadata["is_code"] = True
|
||||||
|
code_docs.append(doc)
|
||||||
|
else:
|
||||||
|
doc.metadata["is_code"] = False
|
||||||
|
text_docs.append(doc)
|
||||||
|
else:
|
||||||
|
doc.metadata["is_code"] = False
|
||||||
|
text_docs.append(doc)
|
||||||
|
|
||||||
|
logger.info(f"Detected {len(code_docs)} code files and {len(text_docs)} text files")
|
||||||
|
return code_docs, text_docs
|
||||||
|
|
||||||
|
|
||||||
|
def get_language_from_extension(file_path: str) -> Optional[str]:
|
||||||
|
"""Return language string from a filename/extension using CODE_EXTENSIONS."""
|
||||||
|
ext = Path(file_path).suffix.lower()
|
||||||
|
return CODE_EXTENSIONS.get(ext)
|
||||||
|
|
||||||
|
|
||||||
|
def create_ast_chunks(
|
||||||
|
documents,
|
||||||
|
max_chunk_size: int = 512,
|
||||||
|
chunk_overlap: int = 64,
|
||||||
|
metadata_template: str = "default",
|
||||||
|
) -> list[str]:
|
||||||
|
"""Create AST-aware chunks from code documents using astchunk.
|
||||||
|
|
||||||
|
Falls back to traditional chunking if astchunk is unavailable.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from astchunk import ASTChunkBuilder # optional dependency
|
||||||
|
except ImportError as e:
|
||||||
|
logger.error(f"astchunk not available: {e}")
|
||||||
|
logger.info("Falling back to traditional chunking for code files")
|
||||||
|
return create_traditional_chunks(documents, max_chunk_size, chunk_overlap)
|
||||||
|
|
||||||
|
all_chunks = []
|
||||||
|
for doc in documents:
|
||||||
|
language = doc.metadata.get("language")
|
||||||
|
if not language:
|
||||||
|
logger.warning("No language detected; falling back to traditional chunking")
|
||||||
|
all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
configs = {
|
||||||
|
"max_chunk_size": max_chunk_size,
|
||||||
|
"language": language,
|
||||||
|
"metadata_template": metadata_template,
|
||||||
|
"chunk_overlap": chunk_overlap if chunk_overlap > 0 else 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
repo_metadata = {
|
||||||
|
"file_path": doc.metadata.get("file_path", ""),
|
||||||
|
"file_name": doc.metadata.get("file_name", ""),
|
||||||
|
"creation_date": doc.metadata.get("creation_date", ""),
|
||||||
|
"last_modified_date": doc.metadata.get("last_modified_date", ""),
|
||||||
|
}
|
||||||
|
configs["repo_level_metadata"] = repo_metadata
|
||||||
|
|
||||||
|
chunk_builder = ASTChunkBuilder(**configs)
|
||||||
|
code_content = doc.get_content()
|
||||||
|
if not code_content or not code_content.strip():
|
||||||
|
logger.warning("Empty code content, skipping")
|
||||||
|
continue
|
||||||
|
|
||||||
|
chunks = chunk_builder.chunkify(code_content)
|
||||||
|
for chunk in chunks:
|
||||||
|
if hasattr(chunk, "text"):
|
||||||
|
chunk_text = chunk.text
|
||||||
|
elif isinstance(chunk, dict) and "text" in chunk:
|
||||||
|
chunk_text = chunk["text"]
|
||||||
|
elif isinstance(chunk, str):
|
||||||
|
chunk_text = chunk
|
||||||
|
else:
|
||||||
|
chunk_text = str(chunk)
|
||||||
|
|
||||||
|
if chunk_text and chunk_text.strip():
|
||||||
|
all_chunks.append(chunk_text.strip())
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Created {len(chunks)} AST chunks from {language} file: {doc.metadata.get('file_name', 'unknown')}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"AST chunking failed for {language} file: {e}")
|
||||||
|
logger.info("Falling back to traditional chunking")
|
||||||
|
all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
|
||||||
|
|
||||||
|
return all_chunks
|
||||||
|
|
||||||
|
|
||||||
|
def create_traditional_chunks(
|
||||||
|
documents, chunk_size: int = 256, chunk_overlap: int = 128
|
||||||
|
) -> list[str]:
|
||||||
|
"""Create traditional text chunks using LlamaIndex SentenceSplitter."""
|
||||||
|
if chunk_size <= 0:
|
||||||
|
logger.warning(f"Invalid chunk_size={chunk_size}, using default value of 256")
|
||||||
|
chunk_size = 256
|
||||||
|
if chunk_overlap < 0:
|
||||||
|
chunk_overlap = 0
|
||||||
|
if chunk_overlap >= chunk_size:
|
||||||
|
chunk_overlap = chunk_size // 2
|
||||||
|
|
||||||
|
node_parser = SentenceSplitter(
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
chunk_overlap=chunk_overlap,
|
||||||
|
separator=" ",
|
||||||
|
paragraph_separator="\n\n",
|
||||||
|
)
|
||||||
|
|
||||||
|
all_texts = []
|
||||||
|
for doc in documents:
|
||||||
|
try:
|
||||||
|
nodes = node_parser.get_nodes_from_documents([doc])
|
||||||
|
if nodes:
|
||||||
|
all_texts.extend(node.get_content() for node in nodes)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Traditional chunking failed for document: {e}")
|
||||||
|
content = doc.get_content()
|
||||||
|
if content and content.strip():
|
||||||
|
all_texts.append(content.strip())
|
||||||
|
|
||||||
|
return all_texts
|
||||||
|
|
||||||
|
|
||||||
|
def create_text_chunks(
|
||||||
|
documents,
|
||||||
|
chunk_size: int = 256,
|
||||||
|
chunk_overlap: int = 128,
|
||||||
|
use_ast_chunking: bool = False,
|
||||||
|
ast_chunk_size: int = 512,
|
||||||
|
ast_chunk_overlap: int = 64,
|
||||||
|
code_file_extensions: Optional[list[str]] = None,
|
||||||
|
ast_fallback_traditional: bool = True,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Create text chunks from documents with optional AST support for code files."""
|
||||||
|
if not documents:
|
||||||
|
logger.warning("No documents provided for chunking")
|
||||||
|
return []
|
||||||
|
|
||||||
|
local_code_extensions = CODE_EXTENSIONS.copy()
|
||||||
|
if code_file_extensions:
|
||||||
|
ext_mapping = {
|
||||||
|
".py": "python",
|
||||||
|
".java": "java",
|
||||||
|
".cs": "c_sharp",
|
||||||
|
".ts": "typescript",
|
||||||
|
".tsx": "typescript",
|
||||||
|
}
|
||||||
|
for ext in code_file_extensions:
|
||||||
|
if ext.lower() not in local_code_extensions:
|
||||||
|
if ext.lower() in ext_mapping:
|
||||||
|
local_code_extensions[ext.lower()] = ext_mapping[ext.lower()]
|
||||||
|
else:
|
||||||
|
logger.warning(f"Unsupported extension {ext}, will use traditional chunking")
|
||||||
|
|
||||||
|
all_chunks = []
|
||||||
|
if use_ast_chunking:
|
||||||
|
code_docs, text_docs = detect_code_files(documents, local_code_extensions)
|
||||||
|
if code_docs:
|
||||||
|
try:
|
||||||
|
all_chunks.extend(
|
||||||
|
create_ast_chunks(
|
||||||
|
code_docs, max_chunk_size=ast_chunk_size, chunk_overlap=ast_chunk_overlap
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"AST chunking failed: {e}")
|
||||||
|
if ast_fallback_traditional:
|
||||||
|
all_chunks.extend(
|
||||||
|
create_traditional_chunks(code_docs, chunk_size, chunk_overlap)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
if text_docs:
|
||||||
|
all_chunks.extend(create_traditional_chunks(text_docs, chunk_size, chunk_overlap))
|
||||||
|
else:
|
||||||
|
all_chunks = create_traditional_chunks(documents, chunk_size, chunk_overlap)
|
||||||
|
|
||||||
|
logger.info(f"Total chunks created: {len(all_chunks)}")
|
||||||
|
return all_chunks
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -6,6 +6,7 @@ Preserves all optimization parameters to ensure performance
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -28,6 +29,8 @@ def compute_embeddings(
|
|||||||
is_build: bool = False,
|
is_build: bool = False,
|
||||||
batch_size: int = 32,
|
batch_size: int = 32,
|
||||||
adaptive_optimization: bool = True,
|
adaptive_optimization: bool = True,
|
||||||
|
manual_tokenize: bool = False,
|
||||||
|
max_length: int = 512,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Unified embedding computation entry point
|
Unified embedding computation entry point
|
||||||
@@ -50,6 +53,8 @@ def compute_embeddings(
|
|||||||
is_build=is_build,
|
is_build=is_build,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
adaptive_optimization=adaptive_optimization,
|
adaptive_optimization=adaptive_optimization,
|
||||||
|
manual_tokenize=manual_tokenize,
|
||||||
|
max_length=max_length,
|
||||||
)
|
)
|
||||||
elif mode == "openai":
|
elif mode == "openai":
|
||||||
return compute_embeddings_openai(texts, model_name)
|
return compute_embeddings_openai(texts, model_name)
|
||||||
@@ -57,6 +62,8 @@ def compute_embeddings(
|
|||||||
return compute_embeddings_mlx(texts, model_name)
|
return compute_embeddings_mlx(texts, model_name)
|
||||||
elif mode == "ollama":
|
elif mode == "ollama":
|
||||||
return compute_embeddings_ollama(texts, model_name, is_build=is_build)
|
return compute_embeddings_ollama(texts, model_name, is_build=is_build)
|
||||||
|
elif mode == "gemini":
|
||||||
|
return compute_embeddings_gemini(texts, model_name, is_build=is_build)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported embedding mode: {mode}")
|
raise ValueError(f"Unsupported embedding mode: {mode}")
|
||||||
|
|
||||||
@@ -69,6 +76,8 @@ def compute_embeddings_sentence_transformers(
|
|||||||
batch_size: int = 32,
|
batch_size: int = 32,
|
||||||
is_build: bool = False,
|
is_build: bool = False,
|
||||||
adaptive_optimization: bool = True,
|
adaptive_optimization: bool = True,
|
||||||
|
manual_tokenize: bool = False,
|
||||||
|
max_length: int = 512,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Compute embeddings using SentenceTransformer with model caching and adaptive optimization
|
Compute embeddings using SentenceTransformer with model caching and adaptive optimization
|
||||||
@@ -212,20 +221,130 @@ def compute_embeddings_sentence_transformers(
|
|||||||
logger.info(f"Model cached: {cache_key}")
|
logger.info(f"Model cached: {cache_key}")
|
||||||
|
|
||||||
# Compute embeddings with optimized inference mode
|
# Compute embeddings with optimized inference mode
|
||||||
logger.info(f"Starting embedding computation... (batch_size: {batch_size})")
|
logger.info(
|
||||||
|
f"Starting embedding computation... (batch_size: {batch_size}, manual_tokenize={manual_tokenize})"
|
||||||
|
)
|
||||||
|
|
||||||
# Use torch.inference_mode for optimal performance
|
start_time = time.time()
|
||||||
with torch.inference_mode():
|
if not manual_tokenize:
|
||||||
embeddings = model.encode(
|
# Use SentenceTransformer's optimized encode path (default)
|
||||||
texts,
|
with torch.inference_mode():
|
||||||
batch_size=batch_size,
|
embeddings = model.encode(
|
||||||
show_progress_bar=is_build, # Don't show progress bar in server environment
|
texts,
|
||||||
convert_to_numpy=True,
|
batch_size=batch_size,
|
||||||
normalize_embeddings=False,
|
show_progress_bar=is_build, # Don't show progress bar in server environment
|
||||||
device=device,
|
convert_to_numpy=True,
|
||||||
)
|
normalize_embeddings=False,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
# Synchronize if CUDA to measure accurate wall time
|
||||||
|
try:
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# Manual tokenization + forward pass using HF AutoTokenizer/AutoModel
|
||||||
|
try:
|
||||||
|
from transformers import AutoModel, AutoTokenizer # type: ignore
|
||||||
|
except Exception as e:
|
||||||
|
raise ImportError(f"transformers is required for manual_tokenize=True: {e}")
|
||||||
|
|
||||||
|
# Cache tokenizer and model
|
||||||
|
tok_cache_key = f"hf_tokenizer_{model_name}"
|
||||||
|
mdl_cache_key = f"hf_model_{model_name}_{device}_{use_fp16}"
|
||||||
|
if tok_cache_key in _model_cache and mdl_cache_key in _model_cache:
|
||||||
|
hf_tokenizer = _model_cache[tok_cache_key]
|
||||||
|
hf_model = _model_cache[mdl_cache_key]
|
||||||
|
logger.info("Using cached HF tokenizer/model for manual path")
|
||||||
|
else:
|
||||||
|
logger.info("Loading HF tokenizer/model for manual tokenization path")
|
||||||
|
hf_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
||||||
|
torch_dtype = torch.float16 if (use_fp16 and device == "cuda") else torch.float32
|
||||||
|
hf_model = AutoModel.from_pretrained(model_name, torch_dtype=torch_dtype)
|
||||||
|
hf_model.to(device)
|
||||||
|
hf_model.eval()
|
||||||
|
# Optional compile on supported devices
|
||||||
|
if device in ["cuda", "mps"]:
|
||||||
|
try:
|
||||||
|
hf_model = torch.compile(hf_model, mode="reduce-overhead", dynamic=True) # type: ignore
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
_model_cache[tok_cache_key] = hf_tokenizer
|
||||||
|
_model_cache[mdl_cache_key] = hf_model
|
||||||
|
|
||||||
|
all_embeddings: list[np.ndarray] = []
|
||||||
|
# Progress bar when building or for large inputs
|
||||||
|
show_progress = is_build or len(texts) > 32
|
||||||
|
try:
|
||||||
|
if show_progress:
|
||||||
|
from tqdm import tqdm # type: ignore
|
||||||
|
|
||||||
|
batch_iter = tqdm(
|
||||||
|
range(0, len(texts), batch_size),
|
||||||
|
desc="Embedding (manual)",
|
||||||
|
unit="batch",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
batch_iter = range(0, len(texts), batch_size)
|
||||||
|
except Exception:
|
||||||
|
batch_iter = range(0, len(texts), batch_size)
|
||||||
|
|
||||||
|
start_time_manual = time.time()
|
||||||
|
with torch.inference_mode():
|
||||||
|
for start_index in batch_iter:
|
||||||
|
end_index = min(start_index + batch_size, len(texts))
|
||||||
|
batch_texts = texts[start_index:end_index]
|
||||||
|
tokenize_start_time = time.time()
|
||||||
|
inputs = hf_tokenizer(
|
||||||
|
batch_texts,
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
max_length=max_length,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
tokenize_end_time = time.time()
|
||||||
|
logger.info(
|
||||||
|
f"Tokenize time taken: {tokenize_end_time - tokenize_start_time} seconds"
|
||||||
|
)
|
||||||
|
# Print shapes of all input tensors for debugging
|
||||||
|
for k, v in inputs.items():
|
||||||
|
print(f"inputs[{k!r}] shape: {getattr(v, 'shape', type(v))}")
|
||||||
|
to_device_start_time = time.time()
|
||||||
|
inputs = {k: v.to(device) for k, v in inputs.items()}
|
||||||
|
to_device_end_time = time.time()
|
||||||
|
logger.info(
|
||||||
|
f"To device time taken: {to_device_end_time - to_device_start_time} seconds"
|
||||||
|
)
|
||||||
|
forward_start_time = time.time()
|
||||||
|
outputs = hf_model(**inputs)
|
||||||
|
forward_end_time = time.time()
|
||||||
|
logger.info(f"Forward time taken: {forward_end_time - forward_start_time} seconds")
|
||||||
|
last_hidden_state = outputs.last_hidden_state # (B, L, H)
|
||||||
|
attention_mask = inputs.get("attention_mask")
|
||||||
|
if attention_mask is None:
|
||||||
|
# Fallback: assume all tokens are valid
|
||||||
|
pooled = last_hidden_state.mean(dim=1)
|
||||||
|
else:
|
||||||
|
mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
|
||||||
|
masked = last_hidden_state * mask
|
||||||
|
lengths = mask.sum(dim=1).clamp(min=1)
|
||||||
|
pooled = masked.sum(dim=1) / lengths
|
||||||
|
# Move to CPU float32
|
||||||
|
batch_embeddings = pooled.detach().to("cpu").float().numpy()
|
||||||
|
all_embeddings.append(batch_embeddings)
|
||||||
|
|
||||||
|
embeddings = np.vstack(all_embeddings).astype(np.float32, copy=False)
|
||||||
|
try:
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
end_time = time.time()
|
||||||
|
logger.info(f"Manual tokenize time taken: {end_time - start_time_manual} seconds")
|
||||||
|
end_time = time.time()
|
||||||
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
|
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
|
||||||
|
logger.info(f"Time taken: {end_time - start_time} seconds")
|
||||||
|
|
||||||
# Validate results
|
# Validate results
|
||||||
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||||
@@ -244,6 +363,16 @@ def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
|
|||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise ImportError(f"OpenAI package not installed: {e}")
|
raise ImportError(f"OpenAI package not installed: {e}")
|
||||||
|
|
||||||
|
# Validate input list
|
||||||
|
if not texts:
|
||||||
|
raise ValueError("Cannot compute embeddings for empty text list")
|
||||||
|
# Extra validation: abort early if any item is empty/whitespace
|
||||||
|
invalid_count = sum(1 for t in texts if not isinstance(t, str) or not t.strip())
|
||||||
|
if invalid_count > 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Found {invalid_count} empty/invalid text(s) in input. Upstream should filter before calling OpenAI."
|
||||||
|
)
|
||||||
|
|
||||||
api_key = os.getenv("OPENAI_API_KEY")
|
api_key = os.getenv("OPENAI_API_KEY")
|
||||||
if not api_key:
|
if not api_key:
|
||||||
raise RuntimeError("OPENAI_API_KEY environment variable not set")
|
raise RuntimeError("OPENAI_API_KEY environment variable not set")
|
||||||
@@ -263,8 +392,16 @@ def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
|
|||||||
print(f"len of texts: {len(texts)}")
|
print(f"len of texts: {len(texts)}")
|
||||||
|
|
||||||
# OpenAI has limits on batch size and input length
|
# OpenAI has limits on batch size and input length
|
||||||
max_batch_size = 1000 # Conservative batch size
|
max_batch_size = 800 # Conservative batch size because the token limit is 300K
|
||||||
all_embeddings = []
|
all_embeddings = []
|
||||||
|
# get the avg len of texts
|
||||||
|
avg_len = sum(len(text) for text in texts) / len(texts)
|
||||||
|
print(f"avg len of texts: {avg_len}")
|
||||||
|
# if avg len is less than 1000, use the max batch size
|
||||||
|
if avg_len > 300:
|
||||||
|
max_batch_size = 500
|
||||||
|
|
||||||
|
# if avg len is less than 1000, use the max batch size
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@@ -650,3 +787,83 @@ def compute_embeddings_ollama(
|
|||||||
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
|
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
|
||||||
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
|
def compute_embeddings_gemini(
|
||||||
|
texts: list[str], model_name: str = "text-embedding-004", is_build: bool = False
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Compute embeddings using Google Gemini API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: List of texts to compute embeddings for
|
||||||
|
model_name: Gemini model name (default: "text-embedding-004")
|
||||||
|
is_build: Whether this is a build operation (shows progress bar)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Embeddings array, shape: (len(texts), embedding_dim)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import os
|
||||||
|
|
||||||
|
import google.genai as genai
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(f"Google GenAI package not installed: {e}")
|
||||||
|
|
||||||
|
api_key = os.getenv("GEMINI_API_KEY")
|
||||||
|
if not api_key:
|
||||||
|
raise RuntimeError("GEMINI_API_KEY environment variable not set")
|
||||||
|
|
||||||
|
# Cache Gemini client
|
||||||
|
cache_key = "gemini_client"
|
||||||
|
if cache_key in _model_cache:
|
||||||
|
client = _model_cache[cache_key]
|
||||||
|
else:
|
||||||
|
client = genai.Client(api_key=api_key)
|
||||||
|
_model_cache[cache_key] = client
|
||||||
|
logger.info("Gemini client cached")
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Computing embeddings for {len(texts)} texts using Gemini API, model: '{model_name}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Gemini supports batch embedding
|
||||||
|
max_batch_size = 100 # Conservative batch size for Gemini
|
||||||
|
all_embeddings = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
total_batches = (len(texts) + max_batch_size - 1) // max_batch_size
|
||||||
|
batch_range = range(0, len(texts), max_batch_size)
|
||||||
|
batch_iterator = tqdm(
|
||||||
|
batch_range, desc="Computing embeddings", unit="batch", total=total_batches
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
# Fallback when tqdm is not available
|
||||||
|
batch_iterator = range(0, len(texts), max_batch_size)
|
||||||
|
|
||||||
|
for i in batch_iterator:
|
||||||
|
batch_texts = texts[i : i + max_batch_size]
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Use the embed_content method from the new Google GenAI SDK
|
||||||
|
response = client.models.embed_content(
|
||||||
|
model=model_name,
|
||||||
|
contents=batch_texts,
|
||||||
|
config=genai.types.EmbedContentConfig(
|
||||||
|
task_type="RETRIEVAL_DOCUMENT" # For document embedding
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract embeddings from response
|
||||||
|
for embedding_data in response.embeddings:
|
||||||
|
all_embeddings.append(embedding_data.values)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Batch {i} failed: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
embeddings = np.array(all_embeddings, dtype=np.float32)
|
||||||
|
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
|
||||||
|
|
||||||
|
return embeddings
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import time
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import psutil
|
# Lightweight, self-contained server manager with no cross-process inspection
|
||||||
|
|
||||||
# Set up logging based on environment variable
|
# Set up logging based on environment variable
|
||||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||||
@@ -43,130 +43,7 @@ def _check_port(port: int) -> bool:
|
|||||||
return s.connect_ex(("localhost", port)) == 0
|
return s.connect_ex(("localhost", port)) == 0
|
||||||
|
|
||||||
|
|
||||||
def _check_process_matches_config(
|
# Note: All cross-process scanning helpers removed for simplicity
|
||||||
port: int, expected_model: str, expected_passages_file: str
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
Check if the process using the port matches our expected model and passages file.
|
|
||||||
Returns True if matches, False otherwise.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
for proc in psutil.process_iter(["pid", "cmdline"]):
|
|
||||||
if not _is_process_listening_on_port(proc, port):
|
|
||||||
continue
|
|
||||||
|
|
||||||
cmdline = proc.info["cmdline"]
|
|
||||||
if not cmdline:
|
|
||||||
continue
|
|
||||||
|
|
||||||
return _check_cmdline_matches_config(
|
|
||||||
cmdline, port, expected_model, expected_passages_file
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(f"No process found listening on port {port}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Could not check process on port {port}: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _is_process_listening_on_port(proc, port: int) -> bool:
|
|
||||||
"""Check if a process is listening on the given port."""
|
|
||||||
try:
|
|
||||||
connections = proc.net_connections()
|
|
||||||
for conn in connections:
|
|
||||||
if conn.laddr.port == port and conn.status == psutil.CONN_LISTEN:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _check_cmdline_matches_config(
|
|
||||||
cmdline: list, port: int, expected_model: str, expected_passages_file: str
|
|
||||||
) -> bool:
|
|
||||||
"""Check if command line matches our expected configuration."""
|
|
||||||
cmdline_str = " ".join(cmdline)
|
|
||||||
logger.debug(f"Found process on port {port}: {cmdline_str}")
|
|
||||||
|
|
||||||
# Check if it's our embedding server
|
|
||||||
is_embedding_server = any(
|
|
||||||
server_type in cmdline_str
|
|
||||||
for server_type in [
|
|
||||||
"embedding_server",
|
|
||||||
"leann_backend_diskann.embedding_server",
|
|
||||||
"leann_backend_hnsw.hnsw_embedding_server",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
if not is_embedding_server:
|
|
||||||
logger.debug(f"Process on port {port} is not our embedding server")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Check model name
|
|
||||||
model_matches = _check_model_in_cmdline(cmdline, expected_model)
|
|
||||||
|
|
||||||
# Check passages file if provided
|
|
||||||
passages_matches = _check_passages_in_cmdline(cmdline, expected_passages_file)
|
|
||||||
|
|
||||||
result = model_matches and passages_matches
|
|
||||||
logger.debug(
|
|
||||||
f"model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}"
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def _check_model_in_cmdline(cmdline: list, expected_model: str) -> bool:
|
|
||||||
"""Check if the command line contains the expected model."""
|
|
||||||
if "--model-name" not in cmdline:
|
|
||||||
return False
|
|
||||||
|
|
||||||
model_idx = cmdline.index("--model-name")
|
|
||||||
if model_idx + 1 >= len(cmdline):
|
|
||||||
return False
|
|
||||||
|
|
||||||
actual_model = cmdline[model_idx + 1]
|
|
||||||
return actual_model == expected_model
|
|
||||||
|
|
||||||
|
|
||||||
def _check_passages_in_cmdline(cmdline: list, expected_passages_file: str) -> bool:
|
|
||||||
"""Check if the command line contains the expected passages file."""
|
|
||||||
if "--passages-file" not in cmdline:
|
|
||||||
return False # Expected but not found
|
|
||||||
|
|
||||||
passages_idx = cmdline.index("--passages-file")
|
|
||||||
if passages_idx + 1 >= len(cmdline):
|
|
||||||
return False
|
|
||||||
|
|
||||||
actual_passages = cmdline[passages_idx + 1]
|
|
||||||
expected_path = Path(expected_passages_file).resolve()
|
|
||||||
actual_path = Path(actual_passages).resolve()
|
|
||||||
return actual_path == expected_path
|
|
||||||
|
|
||||||
|
|
||||||
def _find_compatible_port_or_next_available(
|
|
||||||
start_port: int, model_name: str, passages_file: str, max_attempts: int = 100
|
|
||||||
) -> tuple[int, bool]:
|
|
||||||
"""
|
|
||||||
Find a port that either has a compatible server or is available.
|
|
||||||
Returns (port, is_compatible) where is_compatible indicates if we found a matching server.
|
|
||||||
"""
|
|
||||||
for port in range(start_port, start_port + max_attempts):
|
|
||||||
if not _check_port(port):
|
|
||||||
# Port is available
|
|
||||||
return port, False
|
|
||||||
|
|
||||||
# Port is in use, check if it's compatible
|
|
||||||
if _check_process_matches_config(port, model_name, passages_file):
|
|
||||||
logger.info(f"Found compatible server on port {port}")
|
|
||||||
return port, True
|
|
||||||
else:
|
|
||||||
logger.info(f"Port {port} has incompatible server, trying next port...")
|
|
||||||
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Could not find compatible or available port in range {start_port}-{start_port + max_attempts}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingServerManager:
|
class EmbeddingServerManager:
|
||||||
@@ -185,7 +62,16 @@ class EmbeddingServerManager:
|
|||||||
self.backend_module_name = backend_module_name
|
self.backend_module_name = backend_module_name
|
||||||
self.server_process: Optional[subprocess.Popen] = None
|
self.server_process: Optional[subprocess.Popen] = None
|
||||||
self.server_port: Optional[int] = None
|
self.server_port: Optional[int] = None
|
||||||
|
# Track last-started config for in-process reuse only
|
||||||
|
self._server_config: Optional[dict] = None
|
||||||
self._atexit_registered = False
|
self._atexit_registered = False
|
||||||
|
# Also register a weakref finalizer to ensure cleanup when manager is GC'ed
|
||||||
|
try:
|
||||||
|
import weakref
|
||||||
|
|
||||||
|
self._finalizer = weakref.finalize(self, self._finalize_process)
|
||||||
|
except Exception:
|
||||||
|
self._finalizer = None
|
||||||
|
|
||||||
def start_server(
|
def start_server(
|
||||||
self,
|
self,
|
||||||
@@ -195,26 +81,24 @@ class EmbeddingServerManager:
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
) -> tuple[bool, int]:
|
) -> tuple[bool, int]:
|
||||||
"""Start the embedding server."""
|
"""Start the embedding server."""
|
||||||
passages_file = kwargs.get("passages_file")
|
# passages_file may be present in kwargs for server CLI, but we don't need it here
|
||||||
|
|
||||||
# Check if we have a compatible server already running
|
# If this manager already has a live server, just reuse it
|
||||||
if self._has_compatible_running_server(model_name, passages_file):
|
if self.server_process and self.server_process.poll() is None and self.server_port:
|
||||||
logger.info("Found compatible running server!")
|
logger.info("Reusing in-process server")
|
||||||
return True, port
|
return True, self.server_port
|
||||||
|
|
||||||
# For Colab environment, use a different strategy
|
# For Colab environment, use a different strategy
|
||||||
if _is_colab_environment():
|
if _is_colab_environment():
|
||||||
logger.info("Detected Colab environment, using alternative startup strategy")
|
logger.info("Detected Colab environment, using alternative startup strategy")
|
||||||
return self._start_server_colab(port, model_name, embedding_mode, **kwargs)
|
return self._start_server_colab(port, model_name, embedding_mode, **kwargs)
|
||||||
|
|
||||||
# Find a compatible port or next available
|
# Always pick a fresh available port
|
||||||
actual_port, is_compatible = _find_compatible_port_or_next_available(
|
try:
|
||||||
port, model_name, passages_file
|
actual_port = _get_available_port(port)
|
||||||
)
|
except RuntimeError:
|
||||||
|
logger.error("No available ports found")
|
||||||
if is_compatible:
|
return False, port
|
||||||
logger.info(f"Found compatible server on port {actual_port}")
|
|
||||||
return True, actual_port
|
|
||||||
|
|
||||||
# Start a new server
|
# Start a new server
|
||||||
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
|
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
|
||||||
@@ -247,17 +131,7 @@ class EmbeddingServerManager:
|
|||||||
logger.error(f"Failed to start embedding server in Colab: {e}")
|
logger.error(f"Failed to start embedding server in Colab: {e}")
|
||||||
return False, actual_port
|
return False, actual_port
|
||||||
|
|
||||||
def _has_compatible_running_server(self, model_name: str, passages_file: str) -> bool:
|
# Note: No compatibility check needed; manager is per-searcher and configs are stable per instance
|
||||||
"""Check if we have a compatible running server."""
|
|
||||||
if not (self.server_process and self.server_process.poll() is None and self.server_port):
|
|
||||||
return False
|
|
||||||
|
|
||||||
if _check_process_matches_config(self.server_port, model_name, passages_file):
|
|
||||||
logger.info(f"Existing server process (PID {self.server_process.pid}) is compatible")
|
|
||||||
return True
|
|
||||||
|
|
||||||
logger.info("Existing server process is incompatible. Should start a new server.")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _start_new_server(
|
def _start_new_server(
|
||||||
self, port: int, model_name: str, embedding_mode: str, **kwargs
|
self, port: int, model_name: str, embedding_mode: str, **kwargs
|
||||||
@@ -304,22 +178,62 @@ class EmbeddingServerManager:
|
|||||||
project_root = Path(__file__).parent.parent.parent.parent.parent
|
project_root = Path(__file__).parent.parent.parent.parent.parent
|
||||||
logger.info(f"Command: {' '.join(command)}")
|
logger.info(f"Command: {' '.join(command)}")
|
||||||
|
|
||||||
# Let server output go directly to console
|
# In CI environment, redirect stdout to avoid buffer deadlock but keep stderr for debugging
|
||||||
# The server will respect LEANN_LOG_LEVEL environment variable
|
# Embedding servers use many print statements that can fill stdout buffers
|
||||||
|
is_ci = os.environ.get("CI") == "true"
|
||||||
|
if is_ci:
|
||||||
|
stdout_target = subprocess.DEVNULL
|
||||||
|
stderr_target = None # Keep stderr for error debugging in CI
|
||||||
|
logger.info(
|
||||||
|
"CI environment detected, redirecting embedding server stdout to DEVNULL, keeping stderr"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
stdout_target = None # Direct to console for visible logs
|
||||||
|
stderr_target = None # Direct to console for visible logs
|
||||||
|
|
||||||
|
# Start embedding server subprocess
|
||||||
|
logger.info(f"Starting server process with command: {' '.join(command)}")
|
||||||
self.server_process = subprocess.Popen(
|
self.server_process = subprocess.Popen(
|
||||||
command,
|
command,
|
||||||
cwd=project_root,
|
cwd=project_root,
|
||||||
stdout=None, # Direct to console
|
stdout=stdout_target,
|
||||||
stderr=None, # Direct to console
|
stderr=stderr_target,
|
||||||
)
|
)
|
||||||
self.server_port = port
|
self.server_port = port
|
||||||
|
# Record config for in-process reuse
|
||||||
|
try:
|
||||||
|
self._server_config = {
|
||||||
|
"model_name": command[command.index("--model-name") + 1]
|
||||||
|
if "--model-name" in command
|
||||||
|
else "",
|
||||||
|
"passages_file": command[command.index("--passages-file") + 1]
|
||||||
|
if "--passages-file" in command
|
||||||
|
else "",
|
||||||
|
"embedding_mode": command[command.index("--embedding-mode") + 1]
|
||||||
|
if "--embedding-mode" in command
|
||||||
|
else "sentence-transformers",
|
||||||
|
}
|
||||||
|
except Exception:
|
||||||
|
self._server_config = {
|
||||||
|
"model_name": "",
|
||||||
|
"passages_file": "",
|
||||||
|
"embedding_mode": "sentence-transformers",
|
||||||
|
}
|
||||||
logger.info(f"Server process started with PID: {self.server_process.pid}")
|
logger.info(f"Server process started with PID: {self.server_process.pid}")
|
||||||
|
|
||||||
# Register atexit callback only when we actually start a process
|
# Register atexit callback only when we actually start a process
|
||||||
if not self._atexit_registered:
|
if not self._atexit_registered:
|
||||||
# Use a lambda to avoid issues with bound methods
|
# Always attempt best-effort finalize at interpreter exit
|
||||||
atexit.register(lambda: self.stop_server() if self.server_process else None)
|
atexit.register(self._finalize_process)
|
||||||
self._atexit_registered = True
|
self._atexit_registered = True
|
||||||
|
# Touch finalizer so it knows there is a live process
|
||||||
|
if getattr(self, "_finalizer", None) is not None and not self._finalizer.alive:
|
||||||
|
try:
|
||||||
|
import weakref
|
||||||
|
|
||||||
|
self._finalizer = weakref.finalize(self, self._finalize_process)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
def _wait_for_server_ready(self, port: int) -> tuple[bool, int]:
|
def _wait_for_server_ready(self, port: int) -> tuple[bool, int]:
|
||||||
"""Wait for the server to be ready."""
|
"""Wait for the server to be ready."""
|
||||||
@@ -344,24 +258,35 @@ class EmbeddingServerManager:
|
|||||||
if not self.server_process:
|
if not self.server_process:
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.server_process.poll() is not None:
|
if self.server_process and self.server_process.poll() is not None:
|
||||||
# Process already terminated
|
# Process already terminated
|
||||||
self.server_process = None
|
self.server_process = None
|
||||||
|
self.server_port = None
|
||||||
|
self._server_config = None
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
|
f"Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
|
||||||
)
|
)
|
||||||
self.server_process.terminate()
|
|
||||||
|
# Use simple termination first; if the server installed signal handlers,
|
||||||
|
# it will exit cleanly. Otherwise escalate to kill after a short wait.
|
||||||
|
try:
|
||||||
|
self.server_process.terminate()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.server_process.wait(timeout=3)
|
self.server_process.wait(timeout=5) # Give more time for graceful shutdown
|
||||||
logger.info(f"Server process {self.server_process.pid} terminated.")
|
logger.info(f"Server process {self.server_process.pid} terminated gracefully.")
|
||||||
except subprocess.TimeoutExpired:
|
except subprocess.TimeoutExpired:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Server process {self.server_process.pid} did not terminate gracefully within 3 seconds, killing it."
|
f"Server process {self.server_process.pid} did not terminate within 5 seconds, force killing..."
|
||||||
)
|
)
|
||||||
self.server_process.kill()
|
try:
|
||||||
|
self.server_process.kill()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
try:
|
try:
|
||||||
self.server_process.wait(timeout=2)
|
self.server_process.wait(timeout=2)
|
||||||
logger.info(f"Server process {self.server_process.pid} killed successfully.")
|
logger.info(f"Server process {self.server_process.pid} killed successfully.")
|
||||||
@@ -369,15 +294,33 @@ class EmbeddingServerManager:
|
|||||||
logger.error(
|
logger.error(
|
||||||
f"Failed to kill server process {self.server_process.pid} - it may be hung"
|
f"Failed to kill server process {self.server_process.pid} - it may be hung"
|
||||||
)
|
)
|
||||||
# Don't hang indefinitely
|
|
||||||
|
|
||||||
# Clean up process resources to prevent resource tracker warnings
|
# Clean up process resources with timeout to avoid CI hang
|
||||||
try:
|
try:
|
||||||
self.server_process.wait() # Ensure process is fully cleaned up
|
# Use shorter timeout in CI environments
|
||||||
|
is_ci = os.environ.get("CI") == "true"
|
||||||
|
timeout = 3 if is_ci else 10
|
||||||
|
self.server_process.wait(timeout=timeout)
|
||||||
|
logger.info(f"Server process {self.server_process.pid} cleanup completed")
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
logger.warning(f"Process cleanup timeout after {timeout}s, proceeding anyway")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error during process cleanup: {e}")
|
||||||
|
finally:
|
||||||
|
self.server_process = None
|
||||||
|
self.server_port = None
|
||||||
|
self._server_config = None
|
||||||
|
|
||||||
|
def _finalize_process(self) -> None:
|
||||||
|
"""Best-effort cleanup used by weakref.finalize/atexit."""
|
||||||
|
try:
|
||||||
|
self.stop_server()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
self.server_process = None
|
def _adopt_existing_server(self, *args, **kwargs) -> None:
|
||||||
|
# Removed: cross-process adoption no longer supported
|
||||||
|
return
|
||||||
|
|
||||||
def _launch_server_process_colab(self, command: list, port: int) -> None:
|
def _launch_server_process_colab(self, command: list, port: int) -> None:
|
||||||
"""Launch the server process with Colab-specific settings."""
|
"""Launch the server process with Colab-specific settings."""
|
||||||
@@ -393,10 +336,16 @@ class EmbeddingServerManager:
|
|||||||
self.server_port = port
|
self.server_port = port
|
||||||
logger.info(f"Colab server process started with PID: {self.server_process.pid}")
|
logger.info(f"Colab server process started with PID: {self.server_process.pid}")
|
||||||
|
|
||||||
# Register atexit callback
|
# Register atexit callback (unified)
|
||||||
if not self._atexit_registered:
|
if not self._atexit_registered:
|
||||||
atexit.register(lambda: self.stop_server() if self.server_process else None)
|
atexit.register(self._finalize_process)
|
||||||
self._atexit_registered = True
|
self._atexit_registered = True
|
||||||
|
# Record config for in-process reuse is best-effort in Colab mode
|
||||||
|
self._server_config = {
|
||||||
|
"model_name": "",
|
||||||
|
"passages_file": "",
|
||||||
|
"embedding_mode": "sentence-transformers",
|
||||||
|
}
|
||||||
|
|
||||||
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
|
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
|
||||||
"""Wait for the server to be ready with Colab-specific timeout."""
|
"""Wait for the server to be ready with Colab-specific timeout."""
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Literal, Union
|
from typing import Any, Literal, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -35,7 +35,7 @@ class LeannBackendSearcherInterface(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _ensure_server_running(
|
def _ensure_server_running(
|
||||||
self, passages_source_file: str, port: Union[int, None], **kwargs
|
self, passages_source_file: str, port: Optional[int], **kwargs
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Ensure server is running"""
|
"""Ensure server is running"""
|
||||||
pass
|
pass
|
||||||
@@ -50,7 +50,7 @@ class LeannBackendSearcherInterface(ABC):
|
|||||||
prune_ratio: float = 0.0,
|
prune_ratio: float = 0.0,
|
||||||
recompute_embeddings: bool = False,
|
recompute_embeddings: bool = False,
|
||||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
zmq_port: Union[int, None] = None,
|
zmq_port: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Search for nearest neighbors
|
"""Search for nearest neighbors
|
||||||
@@ -76,7 +76,7 @@ class LeannBackendSearcherInterface(ABC):
|
|||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
use_server_if_available: bool = True,
|
use_server_if_available: bool = True,
|
||||||
zmq_port: Union[int, None] = None,
|
zmq_port: Optional[int] = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Compute embedding for a query string
|
"""Compute embedding for a query string
|
||||||
|
|
||||||
|
|||||||
@@ -64,19 +64,6 @@ def handle_request(request):
|
|||||||
"required": ["index_name", "query"],
|
"required": ["index_name", "query"],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"name": "leann_status",
|
|
||||||
"description": "📊 Check the health and stats of your code indexes - like a medical checkup for your codebase knowledge!",
|
|
||||||
"inputSchema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"index_name": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Optional: Name of specific index to check. If not provided, shows status of all indexes.",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"name": "leann_list",
|
"name": "leann_list",
|
||||||
"description": "📋 Show all your indexed codebases - your personal code library! Use this to see what's available for search.",
|
"description": "📋 Show all your indexed codebases - your personal code library! Use this to see what's available for search.",
|
||||||
@@ -107,7 +94,7 @@ def handle_request(request):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Build simplified command
|
# Build simplified command with non-interactive flag for MCP compatibility
|
||||||
cmd = [
|
cmd = [
|
||||||
"leann",
|
"leann",
|
||||||
"search",
|
"search",
|
||||||
@@ -115,19 +102,10 @@ def handle_request(request):
|
|||||||
args["query"],
|
args["query"],
|
||||||
f"--top-k={args.get('top_k', 5)}",
|
f"--top-k={args.get('top_k', 5)}",
|
||||||
f"--complexity={args.get('complexity', 32)}",
|
f"--complexity={args.get('complexity', 32)}",
|
||||||
|
"--non-interactive",
|
||||||
]
|
]
|
||||||
|
|
||||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||||
|
|
||||||
elif tool_name == "leann_status":
|
|
||||||
if args.get("index_name"):
|
|
||||||
# Check specific index status - for now, we'll use leann list and filter
|
|
||||||
result = subprocess.run(["leann", "list"], capture_output=True, text=True)
|
|
||||||
# We could enhance this to show more detailed status per index
|
|
||||||
else:
|
|
||||||
# Show all indexes status
|
|
||||||
result = subprocess.run(["leann", "list"], capture_output=True, text=True)
|
|
||||||
|
|
||||||
elif tool_name == "leann_list":
|
elif tool_name == "leann_list":
|
||||||
result = subprocess.run(["leann", "list"], capture_output=True, text=True)
|
result = subprocess.run(["leann", "list"], capture_output=True, text=True)
|
||||||
|
|
||||||
|
|||||||
240
packages/leann-core/src/leann/metadata_filter.py
Normal file
240
packages/leann-core/src/leann/metadata_filter.py
Normal file
@@ -0,0 +1,240 @@
|
|||||||
|
"""
|
||||||
|
Metadata filtering engine for LEANN search results.
|
||||||
|
|
||||||
|
This module provides generic metadata filtering capabilities that can be applied
|
||||||
|
to search results from any LEANN backend. The filtering supports various
|
||||||
|
operators for different data types including numbers, strings, booleans, and lists.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Type alias for filter specifications
|
||||||
|
FilterValue = Union[str, int, float, bool, list]
|
||||||
|
FilterSpec = dict[str, FilterValue]
|
||||||
|
MetadataFilters = dict[str, FilterSpec]
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataFilterEngine:
|
||||||
|
"""
|
||||||
|
Engine for evaluating metadata filters against search results.
|
||||||
|
|
||||||
|
Supports various operators for filtering based on metadata fields:
|
||||||
|
- Comparison: ==, !=, <, <=, >, >=
|
||||||
|
- Membership: in, not_in
|
||||||
|
- String operations: contains, starts_with, ends_with
|
||||||
|
- Boolean operations: is_true, is_false
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize the filter engine with supported operators."""
|
||||||
|
self.operators = {
|
||||||
|
"==": self._equals,
|
||||||
|
"!=": self._not_equals,
|
||||||
|
"<": self._less_than,
|
||||||
|
"<=": self._less_than_or_equal,
|
||||||
|
">": self._greater_than,
|
||||||
|
">=": self._greater_than_or_equal,
|
||||||
|
"in": self._in,
|
||||||
|
"not_in": self._not_in,
|
||||||
|
"contains": self._contains,
|
||||||
|
"starts_with": self._starts_with,
|
||||||
|
"ends_with": self._ends_with,
|
||||||
|
"is_true": self._is_true,
|
||||||
|
"is_false": self._is_false,
|
||||||
|
}
|
||||||
|
|
||||||
|
def apply_filters(
|
||||||
|
self, search_results: list[dict[str, Any]], metadata_filters: MetadataFilters
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Apply metadata filters to a list of search results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
search_results: List of result dictionaries, each containing 'metadata' field
|
||||||
|
metadata_filters: Dictionary of filter specifications
|
||||||
|
Format: {"field_name": {"operator": value}}
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Filtered list of search results
|
||||||
|
"""
|
||||||
|
if not metadata_filters:
|
||||||
|
return search_results
|
||||||
|
|
||||||
|
logger.debug(f"Applying filters: {metadata_filters}")
|
||||||
|
logger.debug(f"Input results count: {len(search_results)}")
|
||||||
|
|
||||||
|
filtered_results = []
|
||||||
|
for result in search_results:
|
||||||
|
if self._evaluate_filters(result, metadata_filters):
|
||||||
|
filtered_results.append(result)
|
||||||
|
|
||||||
|
logger.debug(f"Filtered results count: {len(filtered_results)}")
|
||||||
|
return filtered_results
|
||||||
|
|
||||||
|
def _evaluate_filters(self, result: dict[str, Any], filters: MetadataFilters) -> bool:
|
||||||
|
"""
|
||||||
|
Evaluate all filters against a single search result.
|
||||||
|
|
||||||
|
All filters must pass (AND logic) for the result to be included.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
result: Full search result dictionary (including metadata, text, etc.)
|
||||||
|
filters: Filter specifications to evaluate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if all filters pass, False otherwise
|
||||||
|
"""
|
||||||
|
for field_name, filter_spec in filters.items():
|
||||||
|
if not self._evaluate_field_filter(result, field_name, filter_spec):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _evaluate_field_filter(
|
||||||
|
self, result: dict[str, Any], field_name: str, filter_spec: FilterSpec
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Evaluate a single field filter against a search result.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
result: Full search result dictionary
|
||||||
|
field_name: Name of the field to filter on
|
||||||
|
filter_spec: Filter specification for this field
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the filter passes, False otherwise
|
||||||
|
"""
|
||||||
|
# First check top-level fields, then check metadata
|
||||||
|
field_value = result.get(field_name)
|
||||||
|
if field_value is None:
|
||||||
|
# Try to get from metadata if not found at top level
|
||||||
|
metadata = result.get("metadata", {})
|
||||||
|
field_value = metadata.get(field_name)
|
||||||
|
|
||||||
|
# Handle missing fields - they fail all filters except existence checks
|
||||||
|
if field_value is None:
|
||||||
|
logger.debug(f"Field '{field_name}' not found in result or metadata")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Evaluate each operator in the filter spec
|
||||||
|
for operator, expected_value in filter_spec.items():
|
||||||
|
if operator not in self.operators:
|
||||||
|
logger.warning(f"Unsupported operator: {operator}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not self.operators[operator](field_value, expected_value):
|
||||||
|
logger.debug(
|
||||||
|
f"Filter failed: {field_name} {operator} {expected_value} "
|
||||||
|
f"(actual: {field_value})"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Error evaluating filter {field_name} {operator} {expected_value}: {e}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Comparison operators
|
||||||
|
def _equals(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value equals expected value."""
|
||||||
|
return field_value == expected_value
|
||||||
|
|
||||||
|
def _not_equals(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value does not equal expected value."""
|
||||||
|
return field_value != expected_value
|
||||||
|
|
||||||
|
def _less_than(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value is less than expected value."""
|
||||||
|
return self._numeric_compare(field_value, expected_value, lambda a, b: a < b)
|
||||||
|
|
||||||
|
def _less_than_or_equal(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value is less than or equal to expected value."""
|
||||||
|
return self._numeric_compare(field_value, expected_value, lambda a, b: a <= b)
|
||||||
|
|
||||||
|
def _greater_than(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value is greater than expected value."""
|
||||||
|
return self._numeric_compare(field_value, expected_value, lambda a, b: a > b)
|
||||||
|
|
||||||
|
def _greater_than_or_equal(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value is greater than or equal to expected value."""
|
||||||
|
return self._numeric_compare(field_value, expected_value, lambda a, b: a >= b)
|
||||||
|
|
||||||
|
# Membership operators
|
||||||
|
def _in(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value is in the expected list/collection."""
|
||||||
|
if not isinstance(expected_value, (list, tuple, set)):
|
||||||
|
raise ValueError("'in' operator requires a list, tuple, or set")
|
||||||
|
return field_value in expected_value
|
||||||
|
|
||||||
|
def _not_in(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value is not in the expected list/collection."""
|
||||||
|
if not isinstance(expected_value, (list, tuple, set)):
|
||||||
|
raise ValueError("'not_in' operator requires a list, tuple, or set")
|
||||||
|
return field_value not in expected_value
|
||||||
|
|
||||||
|
# String operators
|
||||||
|
def _contains(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value contains the expected substring."""
|
||||||
|
field_str = str(field_value)
|
||||||
|
expected_str = str(expected_value)
|
||||||
|
return expected_str in field_str
|
||||||
|
|
||||||
|
def _starts_with(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value starts with the expected prefix."""
|
||||||
|
field_str = str(field_value)
|
||||||
|
expected_str = str(expected_value)
|
||||||
|
return field_str.startswith(expected_str)
|
||||||
|
|
||||||
|
def _ends_with(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value ends with the expected suffix."""
|
||||||
|
field_str = str(field_value)
|
||||||
|
expected_str = str(expected_value)
|
||||||
|
return field_str.endswith(expected_str)
|
||||||
|
|
||||||
|
# Boolean operators
|
||||||
|
def _is_true(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value is truthy."""
|
||||||
|
return bool(field_value)
|
||||||
|
|
||||||
|
def _is_false(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value is falsy."""
|
||||||
|
return not bool(field_value)
|
||||||
|
|
||||||
|
# Helper methods
|
||||||
|
def _numeric_compare(self, field_value: Any, expected_value: Any, compare_func) -> bool:
|
||||||
|
"""
|
||||||
|
Helper for numeric comparisons with type coercion.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field_value: Value from metadata
|
||||||
|
expected_value: Value to compare against
|
||||||
|
compare_func: Comparison function to apply
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Result of comparison
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Try to convert both values to numbers for comparison
|
||||||
|
if isinstance(field_value, str) and isinstance(expected_value, str):
|
||||||
|
# String comparison if both are strings
|
||||||
|
return compare_func(field_value, expected_value)
|
||||||
|
|
||||||
|
# Numeric comparison - attempt to convert to float
|
||||||
|
field_num = (
|
||||||
|
float(field_value) if not isinstance(field_value, (int, float)) else field_value
|
||||||
|
)
|
||||||
|
expected_num = (
|
||||||
|
float(expected_value)
|
||||||
|
if not isinstance(expected_value, (int, float))
|
||||||
|
else expected_value
|
||||||
|
)
|
||||||
|
|
||||||
|
return compare_func(field_num, expected_num)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
# Fall back to string comparison if numeric conversion fails
|
||||||
|
return compare_func(str(field_value), str(expected_value))
|
||||||
@@ -2,11 +2,17 @@
|
|||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import importlib.metadata
|
import importlib.metadata
|
||||||
from typing import TYPE_CHECKING
|
import json
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Optional, Union
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from leann.interface import LeannBackendFactoryInterface
|
from leann.interface import LeannBackendFactoryInterface
|
||||||
|
|
||||||
|
# Set up logger for this module
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
BACKEND_REGISTRY: dict[str, "LeannBackendFactoryInterface"] = {}
|
BACKEND_REGISTRY: dict[str, "LeannBackendFactoryInterface"] = {}
|
||||||
|
|
||||||
|
|
||||||
@@ -14,7 +20,7 @@ def register_backend(name: str):
|
|||||||
"""A decorator to register a new backend class."""
|
"""A decorator to register a new backend class."""
|
||||||
|
|
||||||
def decorator(cls):
|
def decorator(cls):
|
||||||
print(f"INFO: Registering backend '{name}'")
|
logger.debug(f"Registering backend '{name}'")
|
||||||
BACKEND_REGISTRY[name] = cls
|
BACKEND_REGISTRY[name] = cls
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
@@ -39,3 +45,54 @@ def autodiscover_backends():
|
|||||||
# print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
|
# print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
|
||||||
pass
|
pass
|
||||||
# print("INFO: Backend auto-discovery finished.")
|
# print("INFO: Backend auto-discovery finished.")
|
||||||
|
|
||||||
|
|
||||||
|
def register_project_directory(project_dir: Optional[Union[str, Path]] = None):
|
||||||
|
"""
|
||||||
|
Register a project directory in the global LEANN registry.
|
||||||
|
|
||||||
|
This allows `leann list` to discover indexes created by apps or other tools.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_dir: Directory to register. If None, uses current working directory.
|
||||||
|
"""
|
||||||
|
if project_dir is None:
|
||||||
|
project_dir = Path.cwd()
|
||||||
|
else:
|
||||||
|
project_dir = Path(project_dir)
|
||||||
|
|
||||||
|
# Only register directories that have some kind of LEANN content
|
||||||
|
# Either .leann/indexes/ (CLI format) or *.leann.meta.json files (apps format)
|
||||||
|
has_cli_indexes = (project_dir / ".leann" / "indexes").exists()
|
||||||
|
has_app_indexes = any(project_dir.rglob("*.leann.meta.json"))
|
||||||
|
|
||||||
|
if not (has_cli_indexes or has_app_indexes):
|
||||||
|
# Don't register if there are no LEANN indexes
|
||||||
|
return
|
||||||
|
|
||||||
|
global_registry = Path.home() / ".leann" / "projects.json"
|
||||||
|
global_registry.parent.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
project_str = str(project_dir.resolve())
|
||||||
|
|
||||||
|
# Load existing registry
|
||||||
|
projects = []
|
||||||
|
if global_registry.exists():
|
||||||
|
try:
|
||||||
|
with open(global_registry) as f:
|
||||||
|
projects = json.load(f)
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Could not load existing project registry")
|
||||||
|
projects = []
|
||||||
|
|
||||||
|
# Add project if not already present
|
||||||
|
if project_str not in projects:
|
||||||
|
projects.append(project_str)
|
||||||
|
|
||||||
|
# Save updated registry
|
||||||
|
try:
|
||||||
|
with open(global_registry, "w") as f:
|
||||||
|
json.dump(projects, f, indent=2)
|
||||||
|
logger.debug(f"Registered project directory: {project_str}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Could not save project registry: {e}")
|
||||||
|
|||||||
@@ -2,29 +2,33 @@
|
|||||||
|
|
||||||
Transform your development workflow with intelligent code assistance using LEANN's semantic search directly in Claude Code.
|
Transform your development workflow with intelligent code assistance using LEANN's semantic search directly in Claude Code.
|
||||||
|
|
||||||
|
For agent-facing discovery details, see `llms.txt` in the repository root.
|
||||||
|
|
||||||
## Prerequisites
|
## Prerequisites
|
||||||
|
|
||||||
**Step 1:** First, complete the basic LEANN installation following the [📦 Installation guide](../../README.md#installation) in the root README:
|
Install LEANN globally for MCP integration (with default backend):
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
uv venv
|
uv tool install leann-core --with leann
|
||||||
source .venv/bin/activate
|
|
||||||
uv pip install leann
|
|
||||||
```
|
```
|
||||||
|
This installs the `leann` CLI into an isolated tool environment and includes both backends so `leann build` works out-of-the-box.
|
||||||
**Step 2:** Install LEANN globally for MCP integration:
|
|
||||||
```bash
|
|
||||||
uv tool install leann-core
|
|
||||||
```
|
|
||||||
|
|
||||||
This makes the `leann` command available system-wide, which `leann_mcp` requires.
|
|
||||||
|
|
||||||
## 🚀 Quick Setup
|
## 🚀 Quick Setup
|
||||||
|
|
||||||
Add the LEANN MCP server to Claude Code:
|
Add the LEANN MCP server to Claude Code. Choose the scope based on how widely you want it available. Below is the command to install it globally; if you prefer a local install, skip this step:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
claude mcp add leann-server -- leann_mcp
|
# Global (recommended): available in all projects for your user
|
||||||
|
claude mcp add --scope user leann-server -- leann_mcp
|
||||||
|
```
|
||||||
|
|
||||||
|
- `leann-server`: the display name of the MCP server in Claude Code (you can change it).
|
||||||
|
- `leann_mcp`: the Python entry point installed with LEANN that starts the MCP server.
|
||||||
|
|
||||||
|
Verify it is registered globally:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
claude mcp list | cat
|
||||||
```
|
```
|
||||||
|
|
||||||
## 🛠️ Available Tools
|
## 🛠️ Available Tools
|
||||||
@@ -33,27 +37,36 @@ Once connected, you'll have access to these powerful semantic search tools in Cl
|
|||||||
|
|
||||||
- **`leann_list`** - List all available indexes across your projects
|
- **`leann_list`** - List all available indexes across your projects
|
||||||
- **`leann_search`** - Perform semantic searches across code and documents
|
- **`leann_search`** - Perform semantic searches across code and documents
|
||||||
- **`leann_ask`** - Ask natural language questions and get AI-powered answers from your codebase
|
|
||||||
|
|
||||||
## 🎯 Quick Start Example
|
## 🎯 Quick Start Example
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
# Add locally if you did not add it globally (current folder only; default if --scope is omitted)
|
||||||
|
claude mcp add leann-server -- leann_mcp
|
||||||
|
|
||||||
# Build an index for your project (change to your actual path)
|
# Build an index for your project (change to your actual path)
|
||||||
leann build my-project --docs ./
|
# See the advanced examples below for more ways to configure indexing
|
||||||
|
# Set the index name (replace 'my-project' with your own)
|
||||||
|
leann build my-project --docs $(git ls-files)
|
||||||
|
|
||||||
# Start Claude Code
|
# Start Claude Code
|
||||||
claude
|
claude
|
||||||
```
|
```
|
||||||
|
|
||||||
## 🚀 Advanced Usage Examples
|
## 🚀 Advanced Usage Examples to build the index
|
||||||
|
|
||||||
### Index Entire Git Repository
|
### Index Entire Git Repository
|
||||||
```bash
|
```bash
|
||||||
# Index all tracked files in your git repository, note right now we will skip submodules, but we can add it back easily if you want
|
# Index all tracked files in your Git repository.
|
||||||
|
# Note: submodules are currently skipped; we can add them back if needed.
|
||||||
leann build my-repo --docs $(git ls-files) --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
|
leann build my-repo --docs $(git ls-files) --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
|
||||||
|
|
||||||
# Index only specific file types from git
|
# Index only tracked Python files from Git.
|
||||||
leann build my-python-code --docs $(git ls-files "*.py") --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
|
leann build my-python-code --docs $(git ls-files "*.py") --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
|
||||||
|
|
||||||
|
# If you encounter empty requests caused by empty files (e.g., __init__.py), exclude zero-byte files. Thanks @ww2283 for pointing [that](https://github.com/yichuan-w/LEANN/issues/48) out
|
||||||
|
leann build leann-prospec-lig --docs $(find ./src -name "*.py" -not -empty) --embedding-mode openai --embedding-model text-embedding-3-small
|
||||||
```
|
```
|
||||||
|
|
||||||
### Multiple Directories and Files
|
### Multiple Directories and Files
|
||||||
@@ -81,7 +94,7 @@ leann build docs-and-configs --docs $(git ls-files "*.md" "*.yml" "*.yaml" "*.js
|
|||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
**Try this in Claude Code:**
|
## **Try this in Claude Code:**
|
||||||
```
|
```
|
||||||
Help me understand this codebase. List available indexes and search for authentication patterns.
|
Help me understand this codebase. List available indexes and search for authentication patterns.
|
||||||
```
|
```
|
||||||
@@ -90,6 +103,7 @@ Help me understand this codebase. List available indexes and search for authenti
|
|||||||
<img src="../../assets/claude_code_leann.png" alt="LEANN in Claude Code" width="80%">
|
<img src="../../assets/claude_code_leann.png" alt="LEANN in Claude Code" width="80%">
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
|
If you see a prompt asking whether to proceed with LEANN, you can now use it in your chat!
|
||||||
|
|
||||||
## 🧠 How It Works
|
## 🧠 How It Works
|
||||||
|
|
||||||
@@ -125,3 +139,11 @@ To remove LEANN
|
|||||||
```
|
```
|
||||||
uv pip uninstall leann leann-backend-hnsw leann-core
|
uv pip uninstall leann leann-backend-hnsw leann-core
|
||||||
```
|
```
|
||||||
|
|
||||||
|
To globally remove LEANN (for version update)
|
||||||
|
```
|
||||||
|
uv tool list | cat
|
||||||
|
uv tool uninstall leann-core
|
||||||
|
command -v leann || echo "leann gone"
|
||||||
|
command -v leann_mcp || echo "leann_mcp gone"
|
||||||
|
```
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann"
|
name = "leann"
|
||||||
version = "0.2.8"
|
version = "0.3.4"
|
||||||
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
|
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.9"
|
||||||
|
|||||||
1
packages/wechat-exporter/__init__.py
Normal file
1
packages/wechat-exporter/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
__all__ = []
|
||||||
@@ -136,5 +136,9 @@ def export_sqlite(
|
|||||||
connection.commit()
|
connection.commit()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def main():
|
||||||
app()
|
app()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|||||||
@@ -10,11 +10,10 @@ requires-python = ">=3.9"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"leann-core",
|
"leann-core",
|
||||||
"leann-backend-hnsw",
|
"leann-backend-hnsw",
|
||||||
|
"typer>=0.12.3",
|
||||||
"numpy>=1.26.0",
|
"numpy>=1.26.0",
|
||||||
"torch",
|
"torch",
|
||||||
"tqdm",
|
"tqdm",
|
||||||
"flask",
|
|
||||||
"flask_compress",
|
|
||||||
"datasets>=2.15.0",
|
"datasets>=2.15.0",
|
||||||
"evaluate",
|
"evaluate",
|
||||||
"colorama",
|
"colorama",
|
||||||
@@ -43,9 +42,17 @@ dependencies = [
|
|||||||
"mlx>=0.26.3; sys_platform == 'darwin' and platform_machine == 'arm64'",
|
"mlx>=0.26.3; sys_platform == 'darwin' and platform_machine == 'arm64'",
|
||||||
"mlx-lm>=0.26.0; sys_platform == 'darwin' and platform_machine == 'arm64'",
|
"mlx-lm>=0.26.0; sys_platform == 'darwin' and platform_machine == 'arm64'",
|
||||||
"psutil>=5.8.0",
|
"psutil>=5.8.0",
|
||||||
|
"pybind11>=3.0.0",
|
||||||
"pathspec>=0.12.1",
|
"pathspec>=0.12.1",
|
||||||
"nbconvert>=7.16.6",
|
"nbconvert>=7.16.6",
|
||||||
"gitignore-parser>=0.1.12",
|
"gitignore-parser>=0.1.12",
|
||||||
|
# AST-aware code chunking dependencies
|
||||||
|
"astchunk>=0.1.0",
|
||||||
|
"tree-sitter>=0.20.0",
|
||||||
|
"tree-sitter-python>=0.20.0",
|
||||||
|
"tree-sitter-java>=0.20.0",
|
||||||
|
"tree-sitter-c-sharp>=0.20.0",
|
||||||
|
"tree-sitter-typescript>=0.20.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
@@ -54,7 +61,7 @@ dev = [
|
|||||||
"pytest-cov>=4.0",
|
"pytest-cov>=4.0",
|
||||||
"pytest-xdist>=3.0", # For parallel test execution
|
"pytest-xdist>=3.0", # For parallel test execution
|
||||||
"black>=23.0",
|
"black>=23.0",
|
||||||
"ruff>=0.1.0",
|
"ruff==0.12.7", # Fixed version to ensure consistent formatting across all environments
|
||||||
"matplotlib",
|
"matplotlib",
|
||||||
"huggingface-hub>=0.20.0",
|
"huggingface-hub>=0.20.0",
|
||||||
"pre-commit>=3.5.0",
|
"pre-commit>=3.5.0",
|
||||||
@@ -64,9 +71,7 @@ test = [
|
|||||||
"pytest>=7.0",
|
"pytest>=7.0",
|
||||||
"pytest-timeout>=2.0",
|
"pytest-timeout>=2.0",
|
||||||
"llama-index-core>=0.12.0",
|
"llama-index-core>=0.12.0",
|
||||||
"llama-index-readers-file>=0.4.0",
|
|
||||||
"python-dotenv>=1.0.0",
|
"python-dotenv>=1.0.0",
|
||||||
"sentence-transformers>=2.2.0",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
diskann = [
|
diskann = [
|
||||||
@@ -83,23 +88,24 @@ documents = [
|
|||||||
|
|
||||||
[tool.setuptools]
|
[tool.setuptools]
|
||||||
py-modules = []
|
py-modules = []
|
||||||
|
packages = ["wechat_exporter"]
|
||||||
|
package-dir = { "wechat_exporter" = "packages/wechat-exporter" }
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
wechat-exporter = "wechat_exporter.main:main"
|
||||||
|
|
||||||
|
|
||||||
[tool.uv.sources]
|
[tool.uv.sources]
|
||||||
leann-core = { path = "packages/leann-core", editable = true }
|
leann-core = { path = "packages/leann-core", editable = true }
|
||||||
leann-backend-diskann = { path = "packages/leann-backend-diskann", editable = true }
|
leann-backend-diskann = { path = "packages/leann-backend-diskann", editable = true }
|
||||||
leann-backend-hnsw = { path = "packages/leann-backend-hnsw", editable = true }
|
leann-backend-hnsw = { path = "packages/leann-backend-hnsw", editable = true }
|
||||||
|
astchunk = { path = "packages/astchunk-leann", editable = true }
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
target-version = "py39"
|
target-version = "py39"
|
||||||
line-length = 100
|
line-length = 100
|
||||||
extend-exclude = [
|
extend-exclude = ["third_party"]
|
||||||
"third_party",
|
|
||||||
"*.egg-info",
|
|
||||||
"__pycache__",
|
|
||||||
".git",
|
|
||||||
".venv",
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
select = [
|
select = [
|
||||||
@@ -122,21 +128,12 @@ ignore = [
|
|||||||
"RUF012", # mutable class attributes should be annotated with typing.ClassVar
|
"RUF012", # mutable class attributes should be annotated with typing.ClassVar
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.ruff.lint.per-file-ignores]
|
|
||||||
"test/**/*.py" = ["E402"] # module level import not at top of file (common in tests)
|
|
||||||
"examples/**/*.py" = ["E402"] # module level import not at top of file (common in examples)
|
|
||||||
|
|
||||||
[tool.ruff.format]
|
[tool.ruff.format]
|
||||||
quote-style = "double"
|
quote-style = "double"
|
||||||
indent-style = "space"
|
indent-style = "space"
|
||||||
skip-magic-trailing-comma = false
|
skip-magic-trailing-comma = false
|
||||||
line-ending = "auto"
|
line-ending = "auto"
|
||||||
|
|
||||||
[dependency-groups]
|
|
||||||
dev = [
|
|
||||||
"ruff>=0.12.4",
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.lychee]
|
[tool.lychee]
|
||||||
accept = ["200", "403", "429", "503"]
|
accept = ["200", "403", "429", "503"]
|
||||||
timeout = 20
|
timeout = 20
|
||||||
@@ -154,7 +151,7 @@ markers = [
|
|||||||
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
|
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
|
||||||
"openai: marks tests that require OpenAI API key",
|
"openai: marks tests that require OpenAI API key",
|
||||||
]
|
]
|
||||||
timeout = 600
|
timeout = 300 # Reduced from 600s (10min) to 300s (5min) for CI safety
|
||||||
addopts = [
|
addopts = [
|
||||||
"-v",
|
"-v",
|
||||||
"--tb=short",
|
"--tb=short",
|
||||||
|
|||||||
76
sky/leann-build.yaml
Normal file
76
sky/leann-build.yaml
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
name: leann-build
|
||||||
|
|
||||||
|
resources:
|
||||||
|
# Choose a GPU for fast embeddings (examples: L4, A10G, A100). CPU also works but is slower.
|
||||||
|
accelerators: L4:1
|
||||||
|
# Optionally pin a cloud, otherwise SkyPilot will auto-select
|
||||||
|
# cloud: aws
|
||||||
|
disk_size: 100
|
||||||
|
|
||||||
|
envs:
|
||||||
|
# Build parameters (override with: sky launch -c leann-gpu sky/leann-build.yaml -e key=value)
|
||||||
|
index_name: my-index
|
||||||
|
docs: ./data
|
||||||
|
backend: hnsw # hnsw | diskann
|
||||||
|
complexity: 64
|
||||||
|
graph_degree: 32
|
||||||
|
num_threads: 8
|
||||||
|
# Embedding selection
|
||||||
|
embedding_mode: sentence-transformers # sentence-transformers | openai | mlx | ollama
|
||||||
|
embedding_model: facebook/contriever
|
||||||
|
# Storage/latency knobs
|
||||||
|
recompute: true # true => selective recomputation (recommended)
|
||||||
|
compact: true # for HNSW only
|
||||||
|
# Optional pass-through
|
||||||
|
extra_args: ""
|
||||||
|
# Rebuild control
|
||||||
|
force: true
|
||||||
|
|
||||||
|
# Sync local paths to the remote VM. Adjust as needed.
|
||||||
|
file_mounts:
|
||||||
|
# Example: mount your local data directory used for building
|
||||||
|
~/leann-data: ${docs}
|
||||||
|
|
||||||
|
setup: |
|
||||||
|
set -e
|
||||||
|
# Install uv (package manager)
|
||||||
|
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
|
export PATH="$HOME/.local/bin:$PATH"
|
||||||
|
|
||||||
|
# Ensure modern libstdc++ for FAISS (GLIBCXX >= 3.4.30)
|
||||||
|
sudo apt-get update -y
|
||||||
|
sudo apt-get install -y libstdc++6 libgomp1
|
||||||
|
# Also upgrade conda's libstdc++ in base env (Skypilot images include conda)
|
||||||
|
if command -v conda >/dev/null 2>&1; then
|
||||||
|
conda install -y -n base -c conda-forge libstdcxx-ng
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Install LEANN CLI and backends into the user environment
|
||||||
|
uv pip install --upgrade pip
|
||||||
|
uv pip install leann-core leann-backend-hnsw leann-backend-diskann
|
||||||
|
|
||||||
|
run: |
|
||||||
|
export PATH="$HOME/.local/bin:$PATH"
|
||||||
|
# Derive flags from env
|
||||||
|
recompute_flag=""
|
||||||
|
if [ "${recompute}" = "false" ] || [ "${recompute}" = "0" ]; then
|
||||||
|
recompute_flag="--no-recompute"
|
||||||
|
fi
|
||||||
|
force_flag=""
|
||||||
|
if [ "${force}" = "true" ] || [ "${force}" = "1" ]; then
|
||||||
|
force_flag="--force"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Build command
|
||||||
|
python -m leann.cli build ${index_name} \
|
||||||
|
--docs ~/leann-data \
|
||||||
|
--backend ${backend} \
|
||||||
|
--complexity ${complexity} \
|
||||||
|
--graph-degree ${graph_degree} \
|
||||||
|
--num-threads ${num_threads} \
|
||||||
|
--embedding-mode ${embedding_mode} \
|
||||||
|
--embedding-model ${embedding_model} \
|
||||||
|
${recompute_flag} ${force_flag} ${extra_args}
|
||||||
|
|
||||||
|
# Print where the index is stored for downstream rsync
|
||||||
|
echo "INDEX_OUT_DIR=~/.leann/indexes/${index_name}"
|
||||||
@@ -6,10 +6,11 @@ This directory contains automated tests for the LEANN project using pytest.
|
|||||||
|
|
||||||
### `test_readme_examples.py`
|
### `test_readme_examples.py`
|
||||||
Tests the examples shown in README.md:
|
Tests the examples shown in README.md:
|
||||||
- The basic example code that users see first
|
- The basic example code that users see first (parametrized for both HNSW and DiskANN backends)
|
||||||
- Import statements work correctly
|
- Import statements work correctly
|
||||||
- Different backend options (HNSW, DiskANN)
|
- Different backend options (HNSW, DiskANN)
|
||||||
- Different LLM configuration options
|
- Different LLM configuration options (parametrized for both backends)
|
||||||
|
- **All main README examples are tested with both HNSW and DiskANN backends using pytest parametrization**
|
||||||
|
|
||||||
### `test_basic.py`
|
### `test_basic.py`
|
||||||
Basic functionality tests that verify:
|
Basic functionality tests that verify:
|
||||||
@@ -25,6 +26,16 @@ Tests the document RAG example functionality:
|
|||||||
- Tests error handling with invalid parameters
|
- Tests error handling with invalid parameters
|
||||||
- Verifies that normalized embeddings are detected and cosine distance is used
|
- Verifies that normalized embeddings are detected and cosine distance is used
|
||||||
|
|
||||||
|
### `test_diskann_partition.py`
|
||||||
|
Tests DiskANN graph partitioning functionality:
|
||||||
|
- Tests DiskANN index building without partitioning (baseline)
|
||||||
|
- Tests automatic graph partitioning with `is_recompute=True`
|
||||||
|
- Verifies that partition files are created and large files are cleaned up for storage saving
|
||||||
|
- Tests search functionality with partitioned indices
|
||||||
|
- Validates medoid and max_base_norm file generation and usage
|
||||||
|
- Includes performance comparison between DiskANN (with partition) and HNSW
|
||||||
|
- **Note**: These tests are skipped in CI due to hardware requirements and computation time
|
||||||
|
|
||||||
## Running Tests
|
## Running Tests
|
||||||
|
|
||||||
### Install test dependencies:
|
### Install test dependencies:
|
||||||
@@ -54,15 +65,23 @@ pytest tests/ -m "not openai"
|
|||||||
|
|
||||||
# Skip slow tests
|
# Skip slow tests
|
||||||
pytest tests/ -m "not slow"
|
pytest tests/ -m "not slow"
|
||||||
|
|
||||||
|
# Run DiskANN partition tests (requires local machine, not CI)
|
||||||
|
pytest tests/test_diskann_partition.py
|
||||||
```
|
```
|
||||||
|
|
||||||
### Run with specific backend:
|
### Run with specific backend:
|
||||||
```bash
|
```bash
|
||||||
# Test only HNSW backend
|
# Test only HNSW backend
|
||||||
pytest tests/test_basic.py::test_backend_basic[hnsw]
|
pytest tests/test_basic.py::test_backend_basic[hnsw]
|
||||||
|
pytest tests/test_readme_examples.py::test_readme_basic_example[hnsw]
|
||||||
|
|
||||||
# Test only DiskANN backend
|
# Test only DiskANN backend
|
||||||
pytest tests/test_basic.py::test_backend_basic[diskann]
|
pytest tests/test_basic.py::test_backend_basic[diskann]
|
||||||
|
pytest tests/test_readme_examples.py::test_readme_basic_example[diskann]
|
||||||
|
|
||||||
|
# All DiskANN tests (parametrized + specialized partition tests)
|
||||||
|
pytest tests/ -k diskann
|
||||||
```
|
```
|
||||||
|
|
||||||
## CI/CD Integration
|
## CI/CD Integration
|
||||||
|
|||||||
397
tests/test_astchunk_integration.py
Normal file
397
tests/test_astchunk_integration.py
Normal file
@@ -0,0 +1,397 @@
|
|||||||
|
"""
|
||||||
|
Test suite for astchunk integration with LEANN.
|
||||||
|
Tests AST-aware chunking functionality, language detection, and fallback mechanisms.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Add apps directory to path for imports
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent / "apps"))
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from chunking import (
|
||||||
|
create_ast_chunks,
|
||||||
|
create_text_chunks,
|
||||||
|
create_traditional_chunks,
|
||||||
|
detect_code_files,
|
||||||
|
get_language_from_extension,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MockDocument:
|
||||||
|
"""Mock LlamaIndex Document for testing."""
|
||||||
|
|
||||||
|
def __init__(self, content: str, file_path: str = "", metadata: Optional[dict] = None):
|
||||||
|
self.content = content
|
||||||
|
self.metadata = metadata or {}
|
||||||
|
if file_path:
|
||||||
|
self.metadata["file_path"] = file_path
|
||||||
|
|
||||||
|
def get_content(self) -> str:
|
||||||
|
return self.content
|
||||||
|
|
||||||
|
|
||||||
|
class TestCodeFileDetection:
|
||||||
|
"""Test code file detection and language mapping."""
|
||||||
|
|
||||||
|
def test_detect_code_files_python(self):
|
||||||
|
"""Test detection of Python files."""
|
||||||
|
docs = [
|
||||||
|
MockDocument("print('hello')", "/path/to/file.py"),
|
||||||
|
MockDocument("This is text", "/path/to/file.txt"),
|
||||||
|
]
|
||||||
|
|
||||||
|
code_docs, text_docs = detect_code_files(docs)
|
||||||
|
|
||||||
|
assert len(code_docs) == 1
|
||||||
|
assert len(text_docs) == 1
|
||||||
|
assert code_docs[0].metadata["language"] == "python"
|
||||||
|
assert code_docs[0].metadata["is_code"] is True
|
||||||
|
assert text_docs[0].metadata["is_code"] is False
|
||||||
|
|
||||||
|
def test_detect_code_files_multiple_languages(self):
|
||||||
|
"""Test detection of multiple programming languages."""
|
||||||
|
docs = [
|
||||||
|
MockDocument("def func():", "/path/to/script.py"),
|
||||||
|
MockDocument("public class Test {}", "/path/to/Test.java"),
|
||||||
|
MockDocument("interface ITest {}", "/path/to/test.ts"),
|
||||||
|
MockDocument("using System;", "/path/to/Program.cs"),
|
||||||
|
MockDocument("Regular text content", "/path/to/document.txt"),
|
||||||
|
]
|
||||||
|
|
||||||
|
code_docs, text_docs = detect_code_files(docs)
|
||||||
|
|
||||||
|
assert len(code_docs) == 4
|
||||||
|
assert len(text_docs) == 1
|
||||||
|
|
||||||
|
languages = [doc.metadata["language"] for doc in code_docs]
|
||||||
|
assert "python" in languages
|
||||||
|
assert "java" in languages
|
||||||
|
assert "typescript" in languages
|
||||||
|
assert "csharp" in languages
|
||||||
|
|
||||||
|
def test_detect_code_files_no_file_path(self):
|
||||||
|
"""Test handling of documents without file paths."""
|
||||||
|
docs = [
|
||||||
|
MockDocument("some content"),
|
||||||
|
MockDocument("other content", metadata={"some_key": "value"}),
|
||||||
|
]
|
||||||
|
|
||||||
|
code_docs, text_docs = detect_code_files(docs)
|
||||||
|
|
||||||
|
assert len(code_docs) == 0
|
||||||
|
assert len(text_docs) == 2
|
||||||
|
for doc in text_docs:
|
||||||
|
assert doc.metadata["is_code"] is False
|
||||||
|
|
||||||
|
def test_get_language_from_extension(self):
|
||||||
|
"""Test language detection from file extensions."""
|
||||||
|
assert get_language_from_extension("test.py") == "python"
|
||||||
|
assert get_language_from_extension("Test.java") == "java"
|
||||||
|
assert get_language_from_extension("component.tsx") == "typescript"
|
||||||
|
assert get_language_from_extension("Program.cs") == "csharp"
|
||||||
|
assert get_language_from_extension("document.txt") is None
|
||||||
|
assert get_language_from_extension("") is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestChunkingFunctions:
|
||||||
|
"""Test various chunking functionality."""
|
||||||
|
|
||||||
|
def test_create_traditional_chunks(self):
|
||||||
|
"""Test traditional text chunking."""
|
||||||
|
docs = [
|
||||||
|
MockDocument(
|
||||||
|
"This is a test document. It has multiple sentences. We want to test chunking."
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
chunks = create_traditional_chunks(docs, chunk_size=50, chunk_overlap=10)
|
||||||
|
|
||||||
|
assert len(chunks) > 0
|
||||||
|
assert all(isinstance(chunk, str) for chunk in chunks)
|
||||||
|
assert all(len(chunk.strip()) > 0 for chunk in chunks)
|
||||||
|
|
||||||
|
def test_create_traditional_chunks_empty_docs(self):
|
||||||
|
"""Test traditional chunking with empty documents."""
|
||||||
|
chunks = create_traditional_chunks([], chunk_size=50, chunk_overlap=10)
|
||||||
|
assert chunks == []
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
os.environ.get("CI") == "true",
|
||||||
|
reason="Skip astchunk tests in CI - dependency may not be available",
|
||||||
|
)
|
||||||
|
def test_create_ast_chunks_with_astchunk_available(self):
|
||||||
|
"""Test AST chunking when astchunk is available."""
|
||||||
|
python_code = '''
|
||||||
|
def hello_world():
|
||||||
|
"""Print hello world message."""
|
||||||
|
print("Hello, World!")
|
||||||
|
|
||||||
|
def add_numbers(a, b):
|
||||||
|
"""Add two numbers and return the result."""
|
||||||
|
return a + b
|
||||||
|
|
||||||
|
class Calculator:
|
||||||
|
"""A simple calculator class."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.history = []
|
||||||
|
|
||||||
|
def add(self, a, b):
|
||||||
|
result = a + b
|
||||||
|
self.history.append(f"{a} + {b} = {result}")
|
||||||
|
return result
|
||||||
|
'''
|
||||||
|
|
||||||
|
docs = [MockDocument(python_code, "/test/calculator.py", {"language": "python"})]
|
||||||
|
|
||||||
|
try:
|
||||||
|
chunks = create_ast_chunks(docs, max_chunk_size=200, chunk_overlap=50)
|
||||||
|
|
||||||
|
# Should have multiple chunks due to different functions/classes
|
||||||
|
assert len(chunks) > 0
|
||||||
|
assert all(isinstance(chunk, str) for chunk in chunks)
|
||||||
|
assert all(len(chunk.strip()) > 0 for chunk in chunks)
|
||||||
|
|
||||||
|
# Check that code structure is somewhat preserved
|
||||||
|
combined_content = " ".join(chunks)
|
||||||
|
assert "def hello_world" in combined_content
|
||||||
|
assert "class Calculator" in combined_content
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
# astchunk not available, should fall back to traditional chunking
|
||||||
|
chunks = create_ast_chunks(docs, max_chunk_size=200, chunk_overlap=50)
|
||||||
|
assert len(chunks) > 0 # Should still get chunks from fallback
|
||||||
|
|
||||||
|
def test_create_ast_chunks_fallback_to_traditional(self):
|
||||||
|
"""Test AST chunking falls back to traditional when astchunk is not available."""
|
||||||
|
docs = [MockDocument("def test(): pass", "/test/script.py", {"language": "python"})]
|
||||||
|
|
||||||
|
# Mock astchunk import to fail
|
||||||
|
with patch("chunking.create_ast_chunks"):
|
||||||
|
# First call (actual test) should import astchunk and potentially fail
|
||||||
|
# Let's call the actual function to test the import error handling
|
||||||
|
chunks = create_ast_chunks(docs)
|
||||||
|
|
||||||
|
# Should return some chunks (either from astchunk or fallback)
|
||||||
|
assert isinstance(chunks, list)
|
||||||
|
|
||||||
|
def test_create_text_chunks_traditional_mode(self):
|
||||||
|
"""Test text chunking in traditional mode."""
|
||||||
|
docs = [
|
||||||
|
MockDocument("def test(): pass", "/test/script.py"),
|
||||||
|
MockDocument("This is regular text.", "/test/doc.txt"),
|
||||||
|
]
|
||||||
|
|
||||||
|
chunks = create_text_chunks(docs, use_ast_chunking=False, chunk_size=50, chunk_overlap=10)
|
||||||
|
|
||||||
|
assert len(chunks) > 0
|
||||||
|
assert all(isinstance(chunk, str) for chunk in chunks)
|
||||||
|
|
||||||
|
def test_create_text_chunks_ast_mode(self):
|
||||||
|
"""Test text chunking in AST mode."""
|
||||||
|
docs = [
|
||||||
|
MockDocument("def test(): pass", "/test/script.py"),
|
||||||
|
MockDocument("This is regular text.", "/test/doc.txt"),
|
||||||
|
]
|
||||||
|
|
||||||
|
chunks = create_text_chunks(
|
||||||
|
docs,
|
||||||
|
use_ast_chunking=True,
|
||||||
|
ast_chunk_size=100,
|
||||||
|
ast_chunk_overlap=20,
|
||||||
|
chunk_size=50,
|
||||||
|
chunk_overlap=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(chunks) > 0
|
||||||
|
assert all(isinstance(chunk, str) for chunk in chunks)
|
||||||
|
|
||||||
|
def test_create_text_chunks_custom_extensions(self):
|
||||||
|
"""Test text chunking with custom code file extensions."""
|
||||||
|
docs = [
|
||||||
|
MockDocument("function test() {}", "/test/script.js"), # Not in default extensions
|
||||||
|
MockDocument("Regular text", "/test/doc.txt"),
|
||||||
|
]
|
||||||
|
|
||||||
|
# First without custom extensions - should treat .js as text
|
||||||
|
chunks_without = create_text_chunks(docs, use_ast_chunking=True, code_file_extensions=None)
|
||||||
|
|
||||||
|
# Then with custom extensions - should treat .js as code
|
||||||
|
chunks_with = create_text_chunks(
|
||||||
|
docs, use_ast_chunking=True, code_file_extensions=[".js", ".jsx"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Both should return chunks
|
||||||
|
assert len(chunks_without) > 0
|
||||||
|
assert len(chunks_with) > 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestIntegrationWithDocumentRAG:
|
||||||
|
"""Integration tests with the document RAG system."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_code_dir(self):
|
||||||
|
"""Create a temporary directory with sample code files."""
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
temp_path = Path(temp_dir)
|
||||||
|
|
||||||
|
# Create sample Python file
|
||||||
|
python_file = temp_path / "example.py"
|
||||||
|
python_file.write_text('''
|
||||||
|
def fibonacci(n):
|
||||||
|
"""Calculate fibonacci number."""
|
||||||
|
if n <= 1:
|
||||||
|
return n
|
||||||
|
return fibonacci(n-1) + fibonacci(n-2)
|
||||||
|
|
||||||
|
class MathUtils:
|
||||||
|
@staticmethod
|
||||||
|
def factorial(n):
|
||||||
|
if n <= 1:
|
||||||
|
return 1
|
||||||
|
return n * MathUtils.factorial(n-1)
|
||||||
|
''')
|
||||||
|
|
||||||
|
# Create sample text file
|
||||||
|
text_file = temp_path / "readme.txt"
|
||||||
|
text_file.write_text("This is a sample text file for testing purposes.")
|
||||||
|
|
||||||
|
yield temp_path
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
os.environ.get("CI") == "true",
|
||||||
|
reason="Skip integration tests in CI to avoid dependency issues",
|
||||||
|
)
|
||||||
|
def test_document_rag_with_ast_chunking(self, temp_code_dir):
|
||||||
|
"""Test document RAG with AST chunking enabled."""
|
||||||
|
with tempfile.TemporaryDirectory() as index_dir:
|
||||||
|
cmd = [
|
||||||
|
sys.executable,
|
||||||
|
"apps/document_rag.py",
|
||||||
|
"--llm",
|
||||||
|
"simulated",
|
||||||
|
"--embedding-model",
|
||||||
|
"facebook/contriever",
|
||||||
|
"--embedding-mode",
|
||||||
|
"sentence-transformers",
|
||||||
|
"--index-dir",
|
||||||
|
index_dir,
|
||||||
|
"--data-dir",
|
||||||
|
str(temp_code_dir),
|
||||||
|
"--enable-code-chunking",
|
||||||
|
"--query",
|
||||||
|
"How does the fibonacci function work?",
|
||||||
|
]
|
||||||
|
|
||||||
|
env = os.environ.copy()
|
||||||
|
env["HF_HUB_DISABLE_SYMLINKS"] = "1"
|
||||||
|
env["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=300, # 5 minutes
|
||||||
|
env=env,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should succeed even if astchunk is not available (fallback)
|
||||||
|
assert result.returncode == 0, f"Command failed: {result.stderr}"
|
||||||
|
|
||||||
|
output = result.stdout + result.stderr
|
||||||
|
assert "Index saved to" in output or "Using existing index" in output
|
||||||
|
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
pytest.skip("Test timed out - likely due to model download in CI")
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
os.environ.get("CI") == "true",
|
||||||
|
reason="Skip integration tests in CI to avoid dependency issues",
|
||||||
|
)
|
||||||
|
def test_code_rag_application(self, temp_code_dir):
|
||||||
|
"""Test the specialized code RAG application."""
|
||||||
|
with tempfile.TemporaryDirectory() as index_dir:
|
||||||
|
cmd = [
|
||||||
|
sys.executable,
|
||||||
|
"apps/code_rag.py",
|
||||||
|
"--llm",
|
||||||
|
"simulated",
|
||||||
|
"--embedding-model",
|
||||||
|
"facebook/contriever",
|
||||||
|
"--index-dir",
|
||||||
|
index_dir,
|
||||||
|
"--repo-dir",
|
||||||
|
str(temp_code_dir),
|
||||||
|
"--query",
|
||||||
|
"What classes are defined in this code?",
|
||||||
|
]
|
||||||
|
|
||||||
|
env = os.environ.copy()
|
||||||
|
env["HF_HUB_DISABLE_SYMLINKS"] = "1"
|
||||||
|
env["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = subprocess.run(cmd, capture_output=True, text=True, timeout=300, env=env)
|
||||||
|
|
||||||
|
# Should succeed
|
||||||
|
assert result.returncode == 0, f"Command failed: {result.stderr}"
|
||||||
|
|
||||||
|
output = result.stdout + result.stderr
|
||||||
|
assert "Using AST-aware chunking" in output or "traditional chunking" in output
|
||||||
|
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
pytest.skip("Test timed out - likely due to model download in CI")
|
||||||
|
|
||||||
|
|
||||||
|
class TestErrorHandling:
|
||||||
|
"""Test error handling and edge cases."""
|
||||||
|
|
||||||
|
def test_text_chunking_empty_documents(self):
|
||||||
|
"""Test text chunking with empty document list."""
|
||||||
|
chunks = create_text_chunks([])
|
||||||
|
assert chunks == []
|
||||||
|
|
||||||
|
def test_text_chunking_invalid_parameters(self):
|
||||||
|
"""Test text chunking with invalid parameters."""
|
||||||
|
docs = [MockDocument("test content")]
|
||||||
|
|
||||||
|
# Should handle negative chunk sizes gracefully
|
||||||
|
chunks = create_text_chunks(
|
||||||
|
docs, chunk_size=0, chunk_overlap=0, ast_chunk_size=0, ast_chunk_overlap=0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should still return some result
|
||||||
|
assert isinstance(chunks, list)
|
||||||
|
|
||||||
|
def test_create_ast_chunks_no_language(self):
|
||||||
|
"""Test AST chunking with documents missing language metadata."""
|
||||||
|
docs = [MockDocument("def test(): pass", "/test/script.py")] # No language set
|
||||||
|
|
||||||
|
chunks = create_ast_chunks(docs)
|
||||||
|
|
||||||
|
# Should fall back to traditional chunking
|
||||||
|
assert isinstance(chunks, list)
|
||||||
|
assert len(chunks) >= 0 # May be empty if fallback also fails
|
||||||
|
|
||||||
|
def test_create_ast_chunks_empty_content(self):
|
||||||
|
"""Test AST chunking with empty content."""
|
||||||
|
docs = [MockDocument("", "/test/script.py", {"language": "python"})]
|
||||||
|
|
||||||
|
chunks = create_ast_chunks(docs)
|
||||||
|
|
||||||
|
# Should handle empty content gracefully
|
||||||
|
assert isinstance(chunks, list)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "-v"])
|
||||||
@@ -64,6 +64,9 @@ def test_backend_basic(backend_name):
|
|||||||
assert isinstance(results[0], SearchResult)
|
assert isinstance(results[0], SearchResult)
|
||||||
assert "topic 2" in results[0].text or "document" in results[0].text
|
assert "topic 2" in results[0].text or "document" in results[0].text
|
||||||
|
|
||||||
|
# Ensure cleanup to avoid hanging background servers
|
||||||
|
searcher.cleanup()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
os.environ.get("CI") == "true", reason="Skip model tests in CI to avoid MPS memory issues"
|
os.environ.get("CI") == "true", reason="Skip model tests in CI to avoid MPS memory issues"
|
||||||
@@ -90,3 +93,5 @@ def test_large_index():
|
|||||||
searcher = LeannSearcher(index_path)
|
searcher = LeannSearcher(index_path)
|
||||||
results = searcher.search(["word10 word20"], top_k=10)
|
results = searcher.search(["word10 word20"], top_k=10)
|
||||||
assert len(results[0]) == 10
|
assert len(results[0]) == 10
|
||||||
|
# Cleanup
|
||||||
|
searcher.cleanup()
|
||||||
|
|||||||
369
tests/test_diskann_partition.py
Normal file
369
tests/test_diskann_partition.py
Normal file
@@ -0,0 +1,369 @@
|
|||||||
|
"""
|
||||||
|
Test DiskANN graph partitioning functionality.
|
||||||
|
|
||||||
|
Tests the automatic graph partitioning feature that was implemented to save
|
||||||
|
storage space by partitioning large DiskANN indices and safely deleting
|
||||||
|
redundant files while maintaining search functionality.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
os.environ.get("CI") == "true",
|
||||||
|
reason="Skip DiskANN partition tests in CI - requires specific hardware and large memory",
|
||||||
|
)
|
||||||
|
def test_diskann_without_partition():
|
||||||
|
"""Test DiskANN index building without partition (baseline)."""
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
index_path = str(Path(temp_dir) / "test_no_partition.leann")
|
||||||
|
|
||||||
|
# Test data - enough to trigger index building
|
||||||
|
texts = [
|
||||||
|
f"Document {i} discusses topic {i % 10} with detailed analysis of subject {i // 10}."
|
||||||
|
for i in range(500)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Build without partition (is_recompute=False)
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="diskann",
|
||||||
|
embedding_model="facebook/contriever",
|
||||||
|
embedding_mode="sentence-transformers",
|
||||||
|
num_neighbors=32,
|
||||||
|
search_list_size=50,
|
||||||
|
is_recompute=False, # No partition
|
||||||
|
)
|
||||||
|
|
||||||
|
for text in texts:
|
||||||
|
builder.add_text(text)
|
||||||
|
|
||||||
|
builder.build_index(index_path)
|
||||||
|
|
||||||
|
# Verify index was created
|
||||||
|
index_dir = Path(index_path).parent
|
||||||
|
assert index_dir.exists()
|
||||||
|
|
||||||
|
# Check that traditional DiskANN files exist
|
||||||
|
index_prefix = Path(index_path).stem
|
||||||
|
# Core DiskANN files (beam search index may not be created for small datasets)
|
||||||
|
required_files = [
|
||||||
|
f"{index_prefix}_disk.index",
|
||||||
|
f"{index_prefix}_pq_compressed.bin",
|
||||||
|
f"{index_prefix}_pq_pivots.bin",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Check all generated files first for debugging
|
||||||
|
generated_files = [f.name for f in index_dir.glob(f"{index_prefix}*")]
|
||||||
|
print(f"Generated files: {generated_files}")
|
||||||
|
|
||||||
|
for required_file in required_files:
|
||||||
|
file_path = index_dir / required_file
|
||||||
|
assert file_path.exists(), f"Required file {required_file} not found"
|
||||||
|
|
||||||
|
# Ensure no partition files exist in non-partition mode
|
||||||
|
partition_files = [f"{index_prefix}_disk_graph.index", f"{index_prefix}_partition.bin"]
|
||||||
|
|
||||||
|
for partition_file in partition_files:
|
||||||
|
file_path = index_dir / partition_file
|
||||||
|
assert not file_path.exists(), (
|
||||||
|
f"Partition file {partition_file} should not exist in non-partition mode"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test search functionality
|
||||||
|
searcher = LeannSearcher(index_path)
|
||||||
|
results = searcher.search("topic 3 analysis", top_k=3)
|
||||||
|
|
||||||
|
assert len(results) > 0
|
||||||
|
assert all(result.score is not None and result.score != float("-inf") for result in results)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
os.environ.get("CI") == "true",
|
||||||
|
reason="Skip DiskANN partition tests in CI - requires specific hardware and large memory",
|
||||||
|
)
|
||||||
|
def test_diskann_with_partition():
|
||||||
|
"""Test DiskANN index building with automatic graph partitioning."""
|
||||||
|
from leann.api import LeannBuilder
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
index_path = str(Path(temp_dir) / "test_with_partition.leann")
|
||||||
|
|
||||||
|
# Test data - enough to trigger partitioning
|
||||||
|
texts = [
|
||||||
|
f"Document {i} explores subject {i % 15} with comprehensive coverage of area {i // 15}."
|
||||||
|
for i in range(500)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Build with partition (is_recompute=True)
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="diskann",
|
||||||
|
embedding_model="facebook/contriever",
|
||||||
|
embedding_mode="sentence-transformers",
|
||||||
|
num_neighbors=32,
|
||||||
|
search_list_size=50,
|
||||||
|
is_recompute=True, # Enable automatic partitioning
|
||||||
|
)
|
||||||
|
|
||||||
|
for text in texts:
|
||||||
|
builder.add_text(text)
|
||||||
|
|
||||||
|
builder.build_index(index_path)
|
||||||
|
|
||||||
|
# Verify index was created
|
||||||
|
index_dir = Path(index_path).parent
|
||||||
|
assert index_dir.exists()
|
||||||
|
|
||||||
|
# Check that partition files exist
|
||||||
|
index_prefix = Path(index_path).stem
|
||||||
|
partition_files = [
|
||||||
|
f"{index_prefix}_disk_graph.index", # Partitioned graph
|
||||||
|
f"{index_prefix}_partition.bin", # Partition metadata
|
||||||
|
f"{index_prefix}_pq_compressed.bin",
|
||||||
|
f"{index_prefix}_pq_pivots.bin",
|
||||||
|
]
|
||||||
|
|
||||||
|
for partition_file in partition_files:
|
||||||
|
file_path = index_dir / partition_file
|
||||||
|
assert file_path.exists(), f"Expected partition file {partition_file} not found"
|
||||||
|
|
||||||
|
# Check that large files were cleaned up (storage saving goal)
|
||||||
|
large_files = [f"{index_prefix}_disk.index", f"{index_prefix}_disk_beam_search.index"]
|
||||||
|
|
||||||
|
for large_file in large_files:
|
||||||
|
file_path = index_dir / large_file
|
||||||
|
assert not file_path.exists(), (
|
||||||
|
f"Large file {large_file} should have been deleted for storage saving"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify required auxiliary files for partition mode exist
|
||||||
|
required_files = [
|
||||||
|
f"{index_prefix}_disk.index_medoids.bin",
|
||||||
|
f"{index_prefix}_disk.index_max_base_norm.bin",
|
||||||
|
]
|
||||||
|
|
||||||
|
for req_file in required_files:
|
||||||
|
file_path = index_dir / req_file
|
||||||
|
assert file_path.exists(), (
|
||||||
|
f"Required auxiliary file {req_file} missing for partition mode"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
os.environ.get("CI") == "true",
|
||||||
|
reason="Skip DiskANN partition tests in CI - requires specific hardware and large memory",
|
||||||
|
)
|
||||||
|
def test_diskann_partition_search_functionality():
|
||||||
|
"""Test that search works correctly with partitioned indices."""
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
index_path = str(Path(temp_dir) / "test_partition_search.leann")
|
||||||
|
|
||||||
|
# Create diverse test data
|
||||||
|
texts = [
|
||||||
|
"LEANN is a storage-efficient approximate nearest neighbor search system.",
|
||||||
|
"Graph partitioning helps reduce memory usage in large scale vector search.",
|
||||||
|
"DiskANN provides high-performance disk-based approximate nearest neighbor search.",
|
||||||
|
"Vector embeddings enable semantic search over unstructured text data.",
|
||||||
|
"Approximate nearest neighbor algorithms trade accuracy for speed and storage.",
|
||||||
|
] * 100 # Repeat to get enough data
|
||||||
|
|
||||||
|
# Build with partitioning
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="diskann",
|
||||||
|
embedding_model="facebook/contriever",
|
||||||
|
embedding_mode="sentence-transformers",
|
||||||
|
is_recompute=True, # Enable partitioning
|
||||||
|
)
|
||||||
|
|
||||||
|
for text in texts:
|
||||||
|
builder.add_text(text)
|
||||||
|
|
||||||
|
builder.build_index(index_path)
|
||||||
|
|
||||||
|
# Test search with partitioned index
|
||||||
|
searcher = LeannSearcher(index_path)
|
||||||
|
|
||||||
|
# Test various queries
|
||||||
|
test_queries = [
|
||||||
|
("vector search algorithms", 5),
|
||||||
|
("LEANN storage efficiency", 3),
|
||||||
|
("graph partitioning memory", 4),
|
||||||
|
("approximate nearest neighbor", 7),
|
||||||
|
]
|
||||||
|
|
||||||
|
for query, top_k in test_queries:
|
||||||
|
results = searcher.search(query, top_k=top_k)
|
||||||
|
|
||||||
|
# Verify search results
|
||||||
|
assert len(results) == top_k, f"Expected {top_k} results for query '{query}'"
|
||||||
|
assert all(result.score is not None for result in results), (
|
||||||
|
"All results should have scores"
|
||||||
|
)
|
||||||
|
assert all(result.score != float("-inf") for result in results), (
|
||||||
|
"No result should have -inf score"
|
||||||
|
)
|
||||||
|
assert all(result.text is not None for result in results), (
|
||||||
|
"All results should have text"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Scores should be in descending order (higher similarity first)
|
||||||
|
scores = [result.score for result in results]
|
||||||
|
assert scores == sorted(scores, reverse=True), (
|
||||||
|
"Results should be sorted by score descending"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
os.environ.get("CI") == "true",
|
||||||
|
reason="Skip DiskANN partition tests in CI - requires specific hardware and large memory",
|
||||||
|
)
|
||||||
|
def test_diskann_medoid_and_norm_files():
|
||||||
|
"""Test that medoid and max_base_norm files are correctly generated and used."""
|
||||||
|
import struct
|
||||||
|
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
index_path = str(Path(temp_dir) / "test_medoid_norm.leann")
|
||||||
|
|
||||||
|
# Small but sufficient dataset
|
||||||
|
texts = [f"Test document {i} with content about subject {i % 10}." for i in range(200)]
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="diskann",
|
||||||
|
embedding_model="facebook/contriever",
|
||||||
|
embedding_mode="sentence-transformers",
|
||||||
|
is_recompute=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
for text in texts:
|
||||||
|
builder.add_text(text)
|
||||||
|
|
||||||
|
builder.build_index(index_path)
|
||||||
|
|
||||||
|
index_dir = Path(index_path).parent
|
||||||
|
index_prefix = Path(index_path).stem
|
||||||
|
|
||||||
|
# Test medoids file
|
||||||
|
medoids_file = index_dir / f"{index_prefix}_disk.index_medoids.bin"
|
||||||
|
assert medoids_file.exists(), "Medoids file should be generated"
|
||||||
|
|
||||||
|
# Read and validate medoids file format
|
||||||
|
with open(medoids_file, "rb") as f:
|
||||||
|
nshards = struct.unpack("<I", f.read(4))[0]
|
||||||
|
one_val = struct.unpack("<I", f.read(4))[0]
|
||||||
|
medoid_id = struct.unpack("<I", f.read(4))[0]
|
||||||
|
|
||||||
|
assert nshards == 1, "Single-shot build should have 1 shard"
|
||||||
|
assert one_val == 1, "Expected value should be 1"
|
||||||
|
assert medoid_id >= 0, "Medoid ID should be valid (not hardcoded 0)"
|
||||||
|
|
||||||
|
# Test max_base_norm file
|
||||||
|
norm_file = index_dir / f"{index_prefix}_disk.index_max_base_norm.bin"
|
||||||
|
assert norm_file.exists(), "Max base norm file should be generated"
|
||||||
|
|
||||||
|
# Read and validate norm file
|
||||||
|
with open(norm_file, "rb") as f:
|
||||||
|
npts = struct.unpack("<I", f.read(4))[0]
|
||||||
|
ndims = struct.unpack("<I", f.read(4))[0]
|
||||||
|
norm_val = struct.unpack("<f", f.read(4))[0]
|
||||||
|
|
||||||
|
assert npts == 1, "Should have 1 norm point"
|
||||||
|
assert ndims == 1, "Should have 1 dimension"
|
||||||
|
assert norm_val > 0, "Norm value should be positive"
|
||||||
|
assert norm_val != float("inf"), "Norm value should be finite"
|
||||||
|
|
||||||
|
# Test that search works with these files
|
||||||
|
searcher = LeannSearcher(index_path)
|
||||||
|
results = searcher.search("test subject", top_k=3)
|
||||||
|
|
||||||
|
# Verify that scores are not -inf (which indicates norm file was loaded correctly)
|
||||||
|
assert len(results) > 0
|
||||||
|
assert all(result.score != float("-inf") for result in results), (
|
||||||
|
"Scores should not be -inf when norm file is correct"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
os.environ.get("CI") == "true",
|
||||||
|
reason="Skip performance comparison in CI - requires significant compute time",
|
||||||
|
)
|
||||||
|
def test_diskann_vs_hnsw_performance():
|
||||||
|
"""Compare DiskANN (with partition) vs HNSW performance."""
|
||||||
|
import time
|
||||||
|
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
# Test data
|
||||||
|
texts = [
|
||||||
|
f"Performance test document {i} covering topic {i % 20} in detail." for i in range(1000)
|
||||||
|
]
|
||||||
|
query = "performance topic test"
|
||||||
|
|
||||||
|
# Test DiskANN with partitioning
|
||||||
|
diskann_path = str(Path(temp_dir) / "perf_diskann.leann")
|
||||||
|
diskann_builder = LeannBuilder(
|
||||||
|
backend_name="diskann",
|
||||||
|
embedding_model="facebook/contriever",
|
||||||
|
embedding_mode="sentence-transformers",
|
||||||
|
is_recompute=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
for text in texts:
|
||||||
|
diskann_builder.add_text(text)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
diskann_builder.build_index(diskann_path)
|
||||||
|
|
||||||
|
# Test HNSW
|
||||||
|
hnsw_path = str(Path(temp_dir) / "perf_hnsw.leann")
|
||||||
|
hnsw_builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="facebook/contriever",
|
||||||
|
embedding_mode="sentence-transformers",
|
||||||
|
is_recompute=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
for text in texts:
|
||||||
|
hnsw_builder.add_text(text)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
hnsw_builder.build_index(hnsw_path)
|
||||||
|
|
||||||
|
# Compare search performance
|
||||||
|
diskann_searcher = LeannSearcher(diskann_path)
|
||||||
|
hnsw_searcher = LeannSearcher(hnsw_path)
|
||||||
|
|
||||||
|
# Warm up searches
|
||||||
|
diskann_searcher.search(query, top_k=5)
|
||||||
|
hnsw_searcher.search(query, top_k=5)
|
||||||
|
|
||||||
|
# Timed searches
|
||||||
|
start_time = time.time()
|
||||||
|
diskann_results = diskann_searcher.search(query, top_k=10)
|
||||||
|
diskann_search_time = time.time() - start_time
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
hnsw_results = hnsw_searcher.search(query, top_k=10)
|
||||||
|
hnsw_search_time = time.time() - start_time
|
||||||
|
|
||||||
|
# Basic assertions
|
||||||
|
assert len(diskann_results) == 10
|
||||||
|
assert len(hnsw_results) == 10
|
||||||
|
assert all(r.score != float("-inf") for r in diskann_results)
|
||||||
|
assert all(r.score != float("-inf") for r in hnsw_results)
|
||||||
|
|
||||||
|
# Performance ratio (informational)
|
||||||
|
if hnsw_search_time > 0:
|
||||||
|
speed_ratio = hnsw_search_time / diskann_search_time
|
||||||
|
print(f"DiskANN search time: {diskann_search_time:.4f}s")
|
||||||
|
print(f"HNSW search time: {hnsw_search_time:.4f}s")
|
||||||
|
print(f"DiskANN is {speed_ratio:.2f}x faster than HNSW")
|
||||||
@@ -57,7 +57,55 @@ def test_document_rag_simulated(test_data_dir):
|
|||||||
assert "This is a simulated answer" in output
|
assert "This is a simulated answer" in output
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
os.environ.get("CI") == "true",
|
||||||
|
reason="Skip AST chunking tests in CI to avoid dependency issues",
|
||||||
|
)
|
||||||
|
def test_document_rag_with_ast_chunking(test_data_dir):
|
||||||
|
"""Test document_rag with AST-aware chunking enabled."""
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
# Use a subdirectory that doesn't exist yet to force index creation
|
||||||
|
index_dir = Path(temp_dir) / "test_ast_index"
|
||||||
|
cmd = [
|
||||||
|
sys.executable,
|
||||||
|
"apps/document_rag.py",
|
||||||
|
"--llm",
|
||||||
|
"simulated",
|
||||||
|
"--embedding-model",
|
||||||
|
"facebook/contriever",
|
||||||
|
"--embedding-mode",
|
||||||
|
"sentence-transformers",
|
||||||
|
"--index-dir",
|
||||||
|
str(index_dir),
|
||||||
|
"--data-dir",
|
||||||
|
str(test_data_dir),
|
||||||
|
"--enable-code-chunking", # Enable AST chunking
|
||||||
|
"--query",
|
||||||
|
"What is Pride and Prejudice about?",
|
||||||
|
]
|
||||||
|
|
||||||
|
env = os.environ.copy()
|
||||||
|
env["HF_HUB_DISABLE_SYMLINKS"] = "1"
|
||||||
|
env["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
result = subprocess.run(cmd, capture_output=True, text=True, timeout=600, env=env)
|
||||||
|
|
||||||
|
# Check return code
|
||||||
|
assert result.returncode == 0, f"Command failed: {result.stderr}"
|
||||||
|
|
||||||
|
# Verify output
|
||||||
|
output = result.stdout + result.stderr
|
||||||
|
assert "Index saved to" in output or "Using existing index" in output
|
||||||
|
assert "This is a simulated answer" in output
|
||||||
|
|
||||||
|
# Should mention AST chunking if code files are present
|
||||||
|
# (might not be relevant for the test data, but command should succeed)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OpenAI API key not available")
|
@pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OpenAI API key not available")
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
os.environ.get("CI") == "true", reason="Skip OpenAI tests in CI to avoid API costs"
|
||||||
|
)
|
||||||
def test_document_rag_openai(test_data_dir):
|
def test_document_rag_openai(test_data_dir):
|
||||||
"""Test document_rag with OpenAI embeddings."""
|
"""Test document_rag with OpenAI embeddings."""
|
||||||
with tempfile.TemporaryDirectory() as temp_dir:
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
|||||||
365
tests/test_metadata_filtering.py
Normal file
365
tests/test_metadata_filtering.py
Normal file
@@ -0,0 +1,365 @@
|
|||||||
|
"""
|
||||||
|
Comprehensive tests for metadata filtering functionality.
|
||||||
|
|
||||||
|
This module tests the MetadataFilterEngine class and its integration
|
||||||
|
with the LEANN search system.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Import the modules we're testing
|
||||||
|
import sys
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../packages/leann-core/src"))
|
||||||
|
|
||||||
|
from leann.api import PassageManager, SearchResult
|
||||||
|
from leann.metadata_filter import MetadataFilterEngine
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetadataFilterEngine:
|
||||||
|
"""Test suite for the MetadataFilterEngine class."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Setup test fixtures."""
|
||||||
|
self.engine = MetadataFilterEngine()
|
||||||
|
|
||||||
|
# Sample search results for testing
|
||||||
|
self.sample_results = [
|
||||||
|
{
|
||||||
|
"id": "doc1",
|
||||||
|
"score": 0.95,
|
||||||
|
"text": "This is chapter 1 content",
|
||||||
|
"metadata": {
|
||||||
|
"chapter": 1,
|
||||||
|
"character": "Alice",
|
||||||
|
"tags": ["adventure", "fantasy"],
|
||||||
|
"word_count": 150,
|
||||||
|
"is_published": True,
|
||||||
|
"genre": "fiction",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "doc2",
|
||||||
|
"score": 0.87,
|
||||||
|
"text": "This is chapter 3 content",
|
||||||
|
"metadata": {
|
||||||
|
"chapter": 3,
|
||||||
|
"character": "Bob",
|
||||||
|
"tags": ["mystery", "thriller"],
|
||||||
|
"word_count": 250,
|
||||||
|
"is_published": True,
|
||||||
|
"genre": "fiction",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "doc3",
|
||||||
|
"score": 0.82,
|
||||||
|
"text": "This is chapter 5 content",
|
||||||
|
"metadata": {
|
||||||
|
"chapter": 5,
|
||||||
|
"character": "Alice",
|
||||||
|
"tags": ["romance", "drama"],
|
||||||
|
"word_count": 300,
|
||||||
|
"is_published": False,
|
||||||
|
"genre": "non-fiction",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "doc4",
|
||||||
|
"score": 0.78,
|
||||||
|
"text": "This is chapter 10 content",
|
||||||
|
"metadata": {
|
||||||
|
"chapter": 10,
|
||||||
|
"character": "Charlie",
|
||||||
|
"tags": ["action", "adventure"],
|
||||||
|
"word_count": 400,
|
||||||
|
"is_published": True,
|
||||||
|
"genre": "fiction",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
def test_engine_initialization(self):
|
||||||
|
"""Test that the filter engine initializes correctly."""
|
||||||
|
assert self.engine is not None
|
||||||
|
assert len(self.engine.operators) > 0
|
||||||
|
assert "==" in self.engine.operators
|
||||||
|
assert "contains" in self.engine.operators
|
||||||
|
assert "in" in self.engine.operators
|
||||||
|
|
||||||
|
def test_direct_instantiation(self):
|
||||||
|
"""Test direct instantiation of the engine."""
|
||||||
|
engine = MetadataFilterEngine()
|
||||||
|
assert isinstance(engine, MetadataFilterEngine)
|
||||||
|
|
||||||
|
def test_no_filters_returns_all_results(self):
|
||||||
|
"""Test that passing None or empty filters returns all results."""
|
||||||
|
# Test with None
|
||||||
|
result = self.engine.apply_filters(self.sample_results, None)
|
||||||
|
assert len(result) == len(self.sample_results)
|
||||||
|
|
||||||
|
# Test with empty dict
|
||||||
|
result = self.engine.apply_filters(self.sample_results, {})
|
||||||
|
assert len(result) == len(self.sample_results)
|
||||||
|
|
||||||
|
# Test comparison operators
|
||||||
|
def test_equals_filter(self):
|
||||||
|
"""Test equals (==) filter."""
|
||||||
|
filters = {"chapter": {"==": 1}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["id"] == "doc1"
|
||||||
|
|
||||||
|
def test_not_equals_filter(self):
|
||||||
|
"""Test not equals (!=) filter."""
|
||||||
|
filters = {"genre": {"!=": "fiction"}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["metadata"]["genre"] == "non-fiction"
|
||||||
|
|
||||||
|
def test_less_than_filter(self):
|
||||||
|
"""Test less than (<) filter."""
|
||||||
|
filters = {"chapter": {"<": 5}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 2
|
||||||
|
chapters = [r["metadata"]["chapter"] for r in result]
|
||||||
|
assert all(ch < 5 for ch in chapters)
|
||||||
|
|
||||||
|
def test_less_than_or_equal_filter(self):
|
||||||
|
"""Test less than or equal (<=) filter."""
|
||||||
|
filters = {"chapter": {"<=": 5}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 3
|
||||||
|
chapters = [r["metadata"]["chapter"] for r in result]
|
||||||
|
assert all(ch <= 5 for ch in chapters)
|
||||||
|
|
||||||
|
def test_greater_than_filter(self):
|
||||||
|
"""Test greater than (>) filter."""
|
||||||
|
filters = {"word_count": {">": 200}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 3 # Documents with word_count 250, 300, 400
|
||||||
|
word_counts = [r["metadata"]["word_count"] for r in result]
|
||||||
|
assert all(wc > 200 for wc in word_counts)
|
||||||
|
|
||||||
|
def test_greater_than_or_equal_filter(self):
|
||||||
|
"""Test greater than or equal (>=) filter."""
|
||||||
|
filters = {"word_count": {">=": 250}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 3
|
||||||
|
word_counts = [r["metadata"]["word_count"] for r in result]
|
||||||
|
assert all(wc >= 250 for wc in word_counts)
|
||||||
|
|
||||||
|
# Test membership operators
|
||||||
|
def test_in_filter(self):
|
||||||
|
"""Test in filter."""
|
||||||
|
filters = {"character": {"in": ["Alice", "Bob"]}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 3
|
||||||
|
characters = [r["metadata"]["character"] for r in result]
|
||||||
|
assert all(ch in ["Alice", "Bob"] for ch in characters)
|
||||||
|
|
||||||
|
def test_not_in_filter(self):
|
||||||
|
"""Test not_in filter."""
|
||||||
|
filters = {"character": {"not_in": ["Alice", "Bob"]}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["metadata"]["character"] == "Charlie"
|
||||||
|
|
||||||
|
# Test string operators
|
||||||
|
def test_contains_filter(self):
|
||||||
|
"""Test contains filter."""
|
||||||
|
filters = {"genre": {"contains": "fiction"}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 4 # Both "fiction" and "non-fiction"
|
||||||
|
|
||||||
|
def test_starts_with_filter(self):
|
||||||
|
"""Test starts_with filter."""
|
||||||
|
filters = {"genre": {"starts_with": "non"}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["metadata"]["genre"] == "non-fiction"
|
||||||
|
|
||||||
|
def test_ends_with_filter(self):
|
||||||
|
"""Test ends_with filter."""
|
||||||
|
filters = {"text": {"ends_with": "content"}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 4 # All sample texts end with "content"
|
||||||
|
|
||||||
|
# Test boolean operators
|
||||||
|
def test_is_true_filter(self):
|
||||||
|
"""Test is_true filter."""
|
||||||
|
filters = {"is_published": {"is_true": True}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 3
|
||||||
|
assert all(r["metadata"]["is_published"] for r in result)
|
||||||
|
|
||||||
|
def test_is_false_filter(self):
|
||||||
|
"""Test is_false filter."""
|
||||||
|
filters = {"is_published": {"is_false": False}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert not result[0]["metadata"]["is_published"]
|
||||||
|
|
||||||
|
# Test compound filters (AND logic)
|
||||||
|
def test_compound_filters(self):
|
||||||
|
"""Test multiple filters applied together (AND logic)."""
|
||||||
|
filters = {"genre": {"==": "fiction"}, "chapter": {"<=": 5}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 2
|
||||||
|
for r in result:
|
||||||
|
assert r["metadata"]["genre"] == "fiction"
|
||||||
|
assert r["metadata"]["chapter"] <= 5
|
||||||
|
|
||||||
|
def test_multiple_operators_same_field(self):
|
||||||
|
"""Test multiple operators on the same field."""
|
||||||
|
filters = {"word_count": {">=": 200, "<=": 350}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 2
|
||||||
|
for r in result:
|
||||||
|
wc = r["metadata"]["word_count"]
|
||||||
|
assert 200 <= wc <= 350
|
||||||
|
|
||||||
|
# Test edge cases
|
||||||
|
def test_missing_field_fails_filter(self):
|
||||||
|
"""Test that missing metadata fields fail filters."""
|
||||||
|
filters = {"nonexistent_field": {"==": "value"}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 0
|
||||||
|
|
||||||
|
def test_invalid_operator(self):
|
||||||
|
"""Test that invalid operators are handled gracefully."""
|
||||||
|
filters = {"chapter": {"invalid_op": 1}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 0 # Should filter out all results
|
||||||
|
|
||||||
|
def test_type_coercion_numeric(self):
|
||||||
|
"""Test numeric type coercion in comparisons."""
|
||||||
|
# Add a result with string chapter number
|
||||||
|
test_results = [
|
||||||
|
*self.sample_results,
|
||||||
|
{
|
||||||
|
"id": "doc5",
|
||||||
|
"score": 0.75,
|
||||||
|
"text": "String chapter test",
|
||||||
|
"metadata": {"chapter": "2", "genre": "test"},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
filters = {"chapter": {"<": 3}}
|
||||||
|
result = self.engine.apply_filters(test_results, filters)
|
||||||
|
# Should include doc1 (chapter=1) and doc5 (chapter="2")
|
||||||
|
assert len(result) == 2
|
||||||
|
ids = [r["id"] for r in result]
|
||||||
|
assert "doc1" in ids
|
||||||
|
assert "doc5" in ids
|
||||||
|
|
||||||
|
def test_list_membership_with_nested_tags(self):
|
||||||
|
"""Test membership operations with list metadata."""
|
||||||
|
# Note: This tests the metadata structure, not list field filtering
|
||||||
|
# For list field filtering, we'd need to modify the test data
|
||||||
|
filters = {"character": {"in": ["Alice"]}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 2
|
||||||
|
assert all(r["metadata"]["character"] == "Alice" for r in result)
|
||||||
|
|
||||||
|
def test_empty_results_list(self):
|
||||||
|
"""Test filtering on empty results list."""
|
||||||
|
filters = {"chapter": {"==": 1}}
|
||||||
|
result = self.engine.apply_filters([], filters)
|
||||||
|
assert len(result) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestPassageManagerFiltering:
|
||||||
|
"""Test suite for PassageManager filtering integration."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Setup test fixtures."""
|
||||||
|
# Mock the passage manager without actual file I/O
|
||||||
|
self.passage_manager = Mock(spec=PassageManager)
|
||||||
|
self.passage_manager.filter_engine = MetadataFilterEngine()
|
||||||
|
|
||||||
|
# Sample SearchResult objects
|
||||||
|
self.search_results = [
|
||||||
|
SearchResult(
|
||||||
|
id="doc1",
|
||||||
|
score=0.95,
|
||||||
|
text="Chapter 1 content",
|
||||||
|
metadata={"chapter": 1, "character": "Alice"},
|
||||||
|
),
|
||||||
|
SearchResult(
|
||||||
|
id="doc2",
|
||||||
|
score=0.87,
|
||||||
|
text="Chapter 5 content",
|
||||||
|
metadata={"chapter": 5, "character": "Bob"},
|
||||||
|
),
|
||||||
|
SearchResult(
|
||||||
|
id="doc3",
|
||||||
|
score=0.82,
|
||||||
|
text="Chapter 10 content",
|
||||||
|
metadata={"chapter": 10, "character": "Alice"},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
def test_search_result_filtering(self):
|
||||||
|
"""Test filtering SearchResult objects."""
|
||||||
|
# Create a real PassageManager instance just for the filtering method
|
||||||
|
# We'll mock the file operations
|
||||||
|
with patch("builtins.open"), patch("json.loads"), patch("pickle.load"):
|
||||||
|
pm = PassageManager([{"type": "jsonl", "path": "test.jsonl"}])
|
||||||
|
|
||||||
|
filters = {"chapter": {"<=": 5}}
|
||||||
|
result = pm.filter_search_results(self.search_results, filters)
|
||||||
|
|
||||||
|
assert len(result) == 2
|
||||||
|
chapters = [r.metadata["chapter"] for r in result]
|
||||||
|
assert all(ch <= 5 for ch in chapters)
|
||||||
|
|
||||||
|
def test_filter_search_results_no_filters(self):
|
||||||
|
"""Test that None filters return all results."""
|
||||||
|
with patch("builtins.open"), patch("json.loads"), patch("pickle.load"):
|
||||||
|
pm = PassageManager([{"type": "jsonl", "path": "test.jsonl"}])
|
||||||
|
|
||||||
|
result = pm.filter_search_results(self.search_results, None)
|
||||||
|
assert len(result) == len(self.search_results)
|
||||||
|
|
||||||
|
def test_filter_maintains_search_result_type(self):
|
||||||
|
"""Test that filtering returns SearchResult objects."""
|
||||||
|
with patch("builtins.open"), patch("json.loads"), patch("pickle.load"):
|
||||||
|
pm = PassageManager([{"type": "jsonl", "path": "test.jsonl"}])
|
||||||
|
|
||||||
|
filters = {"character": {"==": "Alice"}}
|
||||||
|
result = pm.filter_search_results(self.search_results, filters)
|
||||||
|
|
||||||
|
assert len(result) == 2
|
||||||
|
for r in result:
|
||||||
|
assert isinstance(r, SearchResult)
|
||||||
|
assert r.metadata["character"] == "Alice"
|
||||||
|
|
||||||
|
|
||||||
|
# Integration tests would go here, but they require actual LEANN backend setup
|
||||||
|
# These would test the full pipeline from LeannSearcher.search() with metadata_filters
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Run basic smoke tests
|
||||||
|
engine = MetadataFilterEngine()
|
||||||
|
|
||||||
|
sample_data = [
|
||||||
|
{
|
||||||
|
"id": "test1",
|
||||||
|
"score": 0.9,
|
||||||
|
"text": "Test content",
|
||||||
|
"metadata": {"chapter": 1, "published": True},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Test basic filtering
|
||||||
|
result = engine.apply_filters(sample_data, {"chapter": {"==": 1}})
|
||||||
|
assert len(result) == 1
|
||||||
|
print("✅ Basic filtering test passed")
|
||||||
|
|
||||||
|
result = engine.apply_filters(sample_data, {"chapter": {"==": 2}})
|
||||||
|
assert len(result) == 0
|
||||||
|
print("✅ No match filtering test passed")
|
||||||
|
|
||||||
|
print("🎉 All smoke tests passed!")
|
||||||
@@ -10,29 +10,33 @@ from pathlib import Path
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
def test_readme_basic_example():
|
@pytest.mark.parametrize("backend_name", ["hnsw", "diskann"])
|
||||||
"""Test the basic example from README.md."""
|
def test_readme_basic_example(backend_name):
|
||||||
|
"""Test the basic example from README.md with both backends."""
|
||||||
# Skip on macOS CI due to MPS environment issues with all-MiniLM-L6-v2
|
# Skip on macOS CI due to MPS environment issues with all-MiniLM-L6-v2
|
||||||
if os.environ.get("CI") == "true" and platform.system() == "Darwin":
|
if os.environ.get("CI") == "true" and platform.system() == "Darwin":
|
||||||
pytest.skip("Skipping on macOS CI due to MPS environment issues with all-MiniLM-L6-v2")
|
pytest.skip("Skipping on macOS CI due to MPS environment issues with all-MiniLM-L6-v2")
|
||||||
|
# Skip DiskANN on CI (Linux runners) due to C++ extension memory/hardware constraints
|
||||||
|
if os.environ.get("CI") == "true" and backend_name == "diskann":
|
||||||
|
pytest.skip("Skip DiskANN tests in CI due to resource constraints and instability")
|
||||||
|
|
||||||
# This is the exact code from README (with smaller model for CI)
|
# This is the exact code from README (with smaller model for CI)
|
||||||
from leann import LeannBuilder, LeannChat, LeannSearcher
|
from leann import LeannBuilder, LeannChat, LeannSearcher
|
||||||
from leann.api import SearchResult
|
from leann.api import SearchResult
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as temp_dir:
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
INDEX_PATH = str(Path(temp_dir) / "demo.leann")
|
INDEX_PATH = str(Path(temp_dir) / f"demo_{backend_name}.leann")
|
||||||
|
|
||||||
# Build an index
|
# Build an index
|
||||||
# In CI, use a smaller model to avoid memory issues
|
# In CI, use a smaller model to avoid memory issues
|
||||||
if os.environ.get("CI") == "true":
|
if os.environ.get("CI") == "true":
|
||||||
builder = LeannBuilder(
|
builder = LeannBuilder(
|
||||||
backend_name="hnsw",
|
backend_name=backend_name,
|
||||||
embedding_model="sentence-transformers/all-MiniLM-L6-v2", # Smaller model
|
embedding_model="sentence-transformers/all-MiniLM-L6-v2", # Smaller model
|
||||||
dimensions=384, # Smaller dimensions
|
dimensions=384, # Smaller dimensions
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
builder = LeannBuilder(backend_name="hnsw")
|
builder = LeannBuilder(backend_name=backend_name)
|
||||||
builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
|
builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
|
||||||
builder.add_text("Tung Tung Tung Sahur called—they need their banana-crocodile hybrid back")
|
builder.add_text("Tung Tung Tung Sahur called—they need their banana-crocodile hybrid back")
|
||||||
builder.build_index(INDEX_PATH)
|
builder.build_index(INDEX_PATH)
|
||||||
@@ -52,9 +56,15 @@ def test_readme_basic_example():
|
|||||||
# Verify search results
|
# Verify search results
|
||||||
assert len(results) > 0
|
assert len(results) > 0
|
||||||
assert isinstance(results[0], SearchResult)
|
assert isinstance(results[0], SearchResult)
|
||||||
|
assert results[0].score != float("-inf"), (
|
||||||
|
f"should return valid scores, got {results[0].score}"
|
||||||
|
)
|
||||||
# The second text about banana-crocodile should be more relevant
|
# The second text about banana-crocodile should be more relevant
|
||||||
assert "banana" in results[0].text or "crocodile" in results[0].text
|
assert "banana" in results[0].text or "crocodile" in results[0].text
|
||||||
|
|
||||||
|
# Ensure we cleanup background embedding server
|
||||||
|
searcher.cleanup()
|
||||||
|
|
||||||
# Chat with your data (using simulated LLM to avoid external dependencies)
|
# Chat with your data (using simulated LLM to avoid external dependencies)
|
||||||
chat = LeannChat(INDEX_PATH, llm_config={"type": "simulated"})
|
chat = LeannChat(INDEX_PATH, llm_config={"type": "simulated"})
|
||||||
response = chat.ask("How much storage does LEANN save?", top_k=1)
|
response = chat.ask("How much storage does LEANN save?", top_k=1)
|
||||||
@@ -62,6 +72,8 @@ def test_readme_basic_example():
|
|||||||
# Verify chat works
|
# Verify chat works
|
||||||
assert isinstance(response, str)
|
assert isinstance(response, str)
|
||||||
assert len(response) > 0
|
assert len(response) > 0
|
||||||
|
# Cleanup chat resources
|
||||||
|
chat.cleanup()
|
||||||
|
|
||||||
|
|
||||||
def test_readme_imports():
|
def test_readme_imports():
|
||||||
@@ -110,26 +122,31 @@ def test_backend_options():
|
|||||||
assert len(list(Path(diskann_path).parent.glob(f"{Path(diskann_path).stem}.*"))) > 0
|
assert len(list(Path(diskann_path).parent.glob(f"{Path(diskann_path).stem}.*"))) > 0
|
||||||
|
|
||||||
|
|
||||||
def test_llm_config_simulated():
|
@pytest.mark.parametrize("backend_name", ["hnsw", "diskann"])
|
||||||
"""Test simulated LLM configuration option."""
|
def test_llm_config_simulated(backend_name):
|
||||||
|
"""Test simulated LLM configuration option with both backends."""
|
||||||
# Skip on macOS CI due to MPS environment issues with all-MiniLM-L6-v2
|
# Skip on macOS CI due to MPS environment issues with all-MiniLM-L6-v2
|
||||||
if os.environ.get("CI") == "true" and platform.system() == "Darwin":
|
if os.environ.get("CI") == "true" and platform.system() == "Darwin":
|
||||||
pytest.skip("Skipping on macOS CI due to MPS environment issues with all-MiniLM-L6-v2")
|
pytest.skip("Skipping on macOS CI due to MPS environment issues with all-MiniLM-L6-v2")
|
||||||
|
|
||||||
|
# Skip DiskANN tests in CI due to hardware requirements
|
||||||
|
if os.environ.get("CI") == "true" and backend_name == "diskann":
|
||||||
|
pytest.skip("Skip DiskANN tests in CI - requires specific hardware and large memory")
|
||||||
|
|
||||||
from leann import LeannBuilder, LeannChat
|
from leann import LeannBuilder, LeannChat
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as temp_dir:
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
# Build a simple index
|
# Build a simple index
|
||||||
index_path = str(Path(temp_dir) / "test.leann")
|
index_path = str(Path(temp_dir) / f"test_{backend_name}.leann")
|
||||||
# Use smaller model in CI to avoid memory issues
|
# Use smaller model in CI to avoid memory issues
|
||||||
if os.environ.get("CI") == "true":
|
if os.environ.get("CI") == "true":
|
||||||
builder = LeannBuilder(
|
builder = LeannBuilder(
|
||||||
backend_name="hnsw",
|
backend_name=backend_name,
|
||||||
embedding_model="sentence-transformers/all-MiniLM-L6-v2",
|
embedding_model="sentence-transformers/all-MiniLM-L6-v2",
|
||||||
dimensions=384,
|
dimensions=384,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
builder = LeannBuilder(backend_name="hnsw")
|
builder = LeannBuilder(backend_name=backend_name)
|
||||||
builder.add_text("Test document for LLM testing")
|
builder.add_text("Test document for LLM testing")
|
||||||
builder.build_index(index_path)
|
builder.build_index(index_path)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user