Compare commits
1 Commits
fix-arm64-... feat/diska...

| Author | SHA1 | Date |
|---|---|---|
|  | fcbcde1ea8 |  |
.gitattributes (vendored, new file, 1 change)
@@ -0,0 +1 @@
+paper_plot/data/big_graph_degree_data.npz filter=lfs diff=lfs merge=lfs -text
.github/workflows/build-and-publish.yml (vendored, 1 change)
@@ -5,7 +5,6 @@ on:
     branches: [ main ]
   pull_request:
     branches: [ main ]
-  workflow_dispatch:

 jobs:
   build:
.github/workflows/build-reusable.yml (vendored, 259 changes)
@@ -54,51 +54,20 @@ jobs:
             python: '3.12'
           - os: ubuntu-22.04
             python: '3.13'
-          # ARM64 Linux builds
-          - os: ubuntu-24.04-arm
+          - os: macos-latest
             python: '3.9'
-          - os: ubuntu-24.04-arm
+          - os: macos-latest
             python: '3.10'
-          - os: ubuntu-24.04-arm
+          - os: macos-latest
             python: '3.11'
-          - os: ubuntu-24.04-arm
+          - os: macos-latest
             python: '3.12'
-          - os: ubuntu-24.04-arm
+          - os: macos-latest
             python: '3.13'
-          - os: macos-14
-            python: '3.9'
-          - os: macos-14
-            python: '3.10'
-          - os: macos-14
-            python: '3.11'
-          - os: macos-14
-            python: '3.12'
-          - os: macos-14
-            python: '3.13'
-          - os: macos-15
-            python: '3.9'
-          - os: macos-15
-            python: '3.10'
-          - os: macos-15
-            python: '3.11'
-          - os: macos-15
-            python: '3.12'
-          - os: macos-15
-            python: '3.13'
-          - os: macos-13
-            python: '3.9'
-          - os: macos-13
-            python: '3.10'
-          - os: macos-13
-            python: '3.11'
-          - os: macos-13
-            python: '3.12'
-          # Note: macos-13 + Python 3.13 excluded due to PyTorch compatibility
-          # (PyTorch 2.5+ supports Python 3.13 but not Intel Mac x86_64)
     runs-on: ${{ matrix.os }}

     steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v4
         with:
           ref: ${{ inputs.ref }}
           submodules: recursive
@@ -109,56 +78,21 @@ jobs:
           python-version: ${{ matrix.python }}

       - name: Install uv
-        uses: astral-sh/setup-uv@v6
+        uses: astral-sh/setup-uv@v4

       - name: Install system dependencies (Ubuntu)
         if: runner.os == 'Linux'
         run: |
           sudo apt-get update
           sudo apt-get install -y libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
-            pkg-config libabsl-dev libaio-dev libprotobuf-dev \
-            patchelf
+            pkg-config libopenblas-dev patchelf libabsl-dev libaio-dev libprotobuf-dev

-          # Debug: Show system information
-          echo "🔍 System Information:"
-          echo "Architecture: $(uname -m)"
-          echo "OS: $(uname -a)"
-          echo "CPU info: $(lscpu | head -5)"
-
-          # Install math library based on architecture
-          ARCH=$(uname -m)
-          echo "🔍 Setting up math library for architecture: $ARCH"
-
-          if [[ "$ARCH" == "x86_64" ]]; then
-            # Install Intel MKL for DiskANN on x86_64
-            echo "📦 Installing Intel MKL for x86_64..."
-            wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
-            sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
-            source /opt/intel/oneapi/setvars.sh
-            echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
-            echo "LD_LIBRARY_PATH=/opt/intel/oneapi/compiler/latest/linux/compiler/lib/intel64_lin" >> $GITHUB_ENV
-            echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/mkl/latest/lib/intel64" >> $GITHUB_ENV
-            echo "✅ Intel MKL installed for x86_64"
-
-            # Debug: Check MKL installation
-            echo "🔍 MKL Installation Check:"
-            ls -la /opt/intel/oneapi/mkl/latest/ || echo "MKL directory not found"
-            ls -la /opt/intel/oneapi/mkl/latest/lib/ || echo "MKL lib directory not found"
-
-          elif [[ "$ARCH" == "aarch64" ]]; then
-            # Use OpenBLAS for ARM64 (MKL installer not compatible with ARM64)
-            echo "📦 Installing OpenBLAS for ARM64..."
-            sudo apt-get install -y libopenblas-dev liblapack-dev liblapacke-dev
-            echo "✅ OpenBLAS installed for ARM64"
-
-            # Debug: Check OpenBLAS installation
-            echo "🔍 OpenBLAS Installation Check:"
-            dpkg -l | grep openblas || echo "OpenBLAS package not found"
-            ls -la /usr/lib/aarch64-linux-gnu/openblas/ || echo "OpenBLAS directory not found"
-          fi
-
-          # Debug: Show final library paths
-          echo "🔍 Final LD_LIBRARY_PATH: $LD_LIBRARY_PATH"
+          # Install Intel MKL for DiskANN
+          wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
+          sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
+          source /opt/intel/oneapi/setvars.sh
+          echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
+          echo "LD_LIBRARY_PATH=/opt/intel/oneapi/mkl/latest/lib/intel64:$LD_LIBRARY_PATH" >> $GITHUB_ENV

       - name: Install system dependencies (macOS)
         if: runner.os == 'macOS'
@@ -175,73 +109,48 @@ jobs:
             uv pip install --system delocate
           fi

-      - name: Set macOS environment variables
-        if: runner.os == 'macOS'
-        run: |
-          # Use brew --prefix to automatically detect Homebrew installation path
-          HOMEBREW_PREFIX=$(brew --prefix)
-          echo "HOMEBREW_PREFIX=${HOMEBREW_PREFIX}" >> $GITHUB_ENV
-          echo "OpenMP_ROOT=${HOMEBREW_PREFIX}/opt/libomp" >> $GITHUB_ENV
-
-          # Set CMAKE_PREFIX_PATH to let CMake find all packages automatically
-          echo "CMAKE_PREFIX_PATH=${HOMEBREW_PREFIX}" >> $GITHUB_ENV
-
-          # Set compiler flags for OpenMP (required for both backends)
-          echo "LDFLAGS=-L${HOMEBREW_PREFIX}/opt/libomp/lib" >> $GITHUB_ENV
-          echo "CPPFLAGS=-I${HOMEBREW_PREFIX}/opt/libomp/include" >> $GITHUB_ENV
-
       - name: Build packages
         run: |
           # Build core (platform independent)
-          cd packages/leann-core
-          uv build
-          cd ../..
+          if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
+            cd packages/leann-core
+            uv build
+            cd ../..
+          fi

           # Build HNSW backend
           cd packages/leann-backend-hnsw
-          if [[ "${{ matrix.os }}" == macos-* ]]; then
-            # Use system clang for better compatibility
+          if [ "${{ matrix.os }}" == "macos-latest" ]; then
+            # Use system clang instead of homebrew LLVM for better compatibility
             export CC=clang
             export CXX=clang++
-            # Homebrew libraries on each macOS version require matching minimum version
-            if [[ "${{ matrix.os }}" == "macos-13" ]]; then
-              export MACOSX_DEPLOYMENT_TARGET=13.0
-            elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
-              export MACOSX_DEPLOYMENT_TARGET=14.0
-            elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
-              export MACOSX_DEPLOYMENT_TARGET=15.0
-            fi
-            uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
+            export MACOSX_DEPLOYMENT_TARGET=11.0
+            uv build --wheel --python python
           else
-            uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
+            uv build --wheel --python python
           fi
           cd ../..

           # Build DiskANN backend
           cd packages/leann-backend-diskann
-          if [[ "${{ matrix.os }}" == macos-* ]]; then
-            # Use system clang for better compatibility
+          if [ "${{ matrix.os }}" == "macos-latest" ]; then
+            # Use system clang instead of homebrew LLVM for better compatibility
             export CC=clang
             export CXX=clang++
             # DiskANN requires macOS 13.3+ for sgesdd_ LAPACK function
-            # But Homebrew libraries on each macOS version require matching minimum version
-            if [[ "${{ matrix.os }}" == "macos-13" ]]; then
-              export MACOSX_DEPLOYMENT_TARGET=13.3
-            elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
-              export MACOSX_DEPLOYMENT_TARGET=14.0
-            elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
-              export MACOSX_DEPLOYMENT_TARGET=15.0
-            fi
-            uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
+            export MACOSX_DEPLOYMENT_TARGET=13.3
+            uv build --wheel --python python
           else
-            uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
+            uv build --wheel --python python
           fi
           cd ../..

           # Build meta package (platform independent)
-          cd packages/leann
-          uv build
-          cd ../..
+          if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
+            cd packages/leann
+            uv build
+            cd ../..
+          fi

       - name: Repair wheels (Linux)
         if: runner.os == 'Linux'
@@ -267,24 +176,10 @@ jobs:
       - name: Repair wheels (macOS)
         if: runner.os == 'macOS'
         run: |
-          # Determine deployment target based on runner OS
-          # Must match the Homebrew libraries for each macOS version
-          if [[ "${{ matrix.os }}" == "macos-13" ]]; then
-            HNSW_TARGET="13.0"
-            DISKANN_TARGET="13.3"
-          elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
-            HNSW_TARGET="14.0"
-            DISKANN_TARGET="14.0"
-          elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
-            HNSW_TARGET="15.0"
-            DISKANN_TARGET="15.0"
-          fi
-
           # Repair HNSW wheel
           cd packages/leann-backend-hnsw
           if [ -d dist ]; then
-            export MACOSX_DEPLOYMENT_TARGET=$HNSW_TARGET
-            delocate-wheel -w dist_repaired -v --require-target-macos-version $HNSW_TARGET dist/*.whl
+            delocate-wheel -w dist_repaired -v dist/*.whl
             rm -rf dist
             mv dist_repaired dist
           fi
@@ -293,8 +188,7 @@ jobs:
           # Repair DiskANN wheel
           cd packages/leann-backend-diskann
           if [ -d dist ]; then
-            export MACOSX_DEPLOYMENT_TARGET=$DISKANN_TARGET
-            delocate-wheel -w dist_repaired -v --require-target-macos-version $DISKANN_TARGET dist/*.whl
+            delocate-wheel -w dist_repaired -v dist/*.whl
             rm -rf dist
             mv dist_repaired dist
           fi
@@ -305,34 +199,39 @@ jobs:
           echo "📦 Built packages:"
           find packages/*/dist -name "*.whl" -o -name "*.tar.gz" | sort


       - name: Install built packages for testing
         run: |
-          # Create a virtual environment with the correct Python version
-          uv venv --python ${{ matrix.python }}
+          # Create a virtual environment
+          uv venv
           source .venv/bin/activate || source .venv/Scripts/activate

-          # Install packages using --find-links to prioritize local builds
-          uv pip install --find-links packages/leann-core/dist --find-links packages/leann-backend-hnsw/dist --find-links packages/leann-backend-diskann/dist packages/leann-core/dist/*.whl || uv pip install --find-links packages/leann-core/dist packages/leann-core/dist/*.tar.gz
-          uv pip install --find-links packages/leann-core/dist packages/leann-backend-hnsw/dist/*.whl
-          uv pip install --find-links packages/leann-core/dist packages/leann-backend-diskann/dist/*.whl
-          uv pip install packages/leann/dist/*.whl || uv pip install packages/leann/dist/*.tar.gz
+          # Install the built wheels
+          # Use --find-links to let uv choose the correct wheel for the platform
+          if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
+            uv pip install leann-core --find-links packages/leann-core/dist
+            uv pip install leann --find-links packages/leann/dist
+          fi
+          uv pip install leann-backend-hnsw --find-links packages/leann-backend-hnsw/dist
+          uv pip install leann-backend-diskann --find-links packages/leann-backend-diskann/dist

           # Install test dependencies using extras
           uv pip install -e ".[test]"

       - name: Run tests with pytest
         env:
-          CI: true
+          CI: true # Mark as CI environment to skip memory-intensive tests
           OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
           HF_HUB_DISABLE_SYMLINKS: 1
           TOKENIZERS_PARALLELISM: false
-          PYTORCH_ENABLE_MPS_FALLBACK: 0
-          OMP_NUM_THREADS: 1
-          MKL_NUM_THREADS: 1
+          PYTORCH_ENABLE_MPS_FALLBACK: 0 # Disable MPS on macOS CI to avoid memory issues
+          OMP_NUM_THREADS: 1 # Disable OpenMP parallelism to avoid libomp crashes
+          MKL_NUM_THREADS: 1 # Single thread for MKL operations
         run: |
+          # Activate virtual environment
           source .venv/bin/activate || source .venv/Scripts/activate
-          pytest tests/ -v --tb=short
+
+          # Run all tests
+          pytest tests/

       - name: Run sanity checks (optional)
         run: |
@@ -350,53 +249,3 @@ jobs:
         with:
           name: packages-${{ matrix.os }}-py${{ matrix.python }}
           path: packages/*/dist/
-
-
-  arch-smoke:
-    name: Arch Linux smoke test (install & import)
-    needs: build
-    runs-on: ubuntu-latest
-    container:
-      image: archlinux:latest
-
-    steps:
-      - name: Prepare system
-        run: |
-          pacman -Syu --noconfirm
-          pacman -S --noconfirm python python-pip gcc git zlib openssl
-
-      - name: Download ALL wheel artifacts from this run
-        uses: actions/download-artifact@v5
-        with:
-          # Don't specify name, download all artifacts
-          path: ./wheels
-
-      - name: Install uv
-        uses: astral-sh/setup-uv@v6
-
-      - name: Create virtual environment and install wheels
-        run: |
-          uv venv
-          source .venv/bin/activate || source .venv/Scripts/activate
-          uv pip install --find-links wheels leann-core
-          uv pip install --find-links wheels leann-backend-hnsw
-          uv pip install --find-links wheels leann-backend-diskann
-          uv pip install --find-links wheels leann
-
-      - name: Import & tiny runtime check
-        env:
-          OMP_NUM_THREADS: 1
-          MKL_NUM_THREADS: 1
-        run: |
-          source .venv/bin/activate || source .venv/Scripts/activate
-          python - <<'PY'
-          import leann
-          import leann_backend_hnsw as h
-          import leann_backend_diskann as d
-          from leann import LeannBuilder, LeannSearcher
-          b = LeannBuilder(backend_name="hnsw")
-          b.add_text("hello arch")
-          b.build_index("arch_demo.leann")
-          s = LeannSearcher("arch_demo.leann")
-          print("search:", s.search("hello", top_k=1))
-          PY
.github/workflows/link-check.yml (vendored, 19 changes)
@@ -1,19 +0,0 @@
-name: Link Check
-
-on:
-  push:
-    branches: [ main, master ]
-  pull_request:
-  schedule:
-    - cron: "0 3 * * 1"
-
-jobs:
-  link-check:
-    runs-on: ubuntu-latest
-    steps:
-      - uses: actions/checkout@v4
-      - uses: lycheeverse/lychee-action@v2
-        with:
-          args: --no-progress --insecure --user-agent 'curl/7.68.0' README.md docs/ apps/ examples/ benchmarks/
-        env:
-          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
.gitignore (vendored, 22 changes)
@@ -18,7 +18,6 @@ demo/experiment_results/**/*.json
 *.eml
 *.emlx
 *.json
-!.vscode/*.json
 *.sh
 *.txt
 !CMakeLists.txt
@@ -35,15 +34,11 @@ build/
 nprobe_logs/
 micro/results
 micro/contriever-INT8
-data/*
-!data/2501.14312v1 (1).pdf
-!data/2506.08276v1.pdf
-!data/PrideandPrejudice.txt
-!data/huawei_pangu.md
-!data/ground_truth/
-!data/indices/
-!data/queries/
-!data/.gitattributes
+examples/data/*
+!examples/data/2501.14312v1 (1).pdf
+!examples/data/2506.08276v1.pdf
+!examples/data/PrideandPrejudice.txt
+!examples/data/README.md
 *.qdstrm
 benchmark_results/
 results/
@@ -93,10 +88,3 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/
 batchtest.py
 tests/__pytest_cache__/
 tests/__pycache__/
-paru-bin/
-
-CLAUDE.md
-CLAUDE.local.md
-.claude/*.local.*
-.claude/local/*
-benchmarks/data/
.pre-commit-config.yaml
@@ -1,6 +1,6 @@
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v5.0.0
+    rev: v4.5.0
     hooks:
       - id: trailing-whitespace
       - id: end-of-file-fixer
@@ -10,8 +10,7 @@ repos:
       - id: debug-statements

   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.12.7 # Fixed version to match pyproject.toml
+    rev: v0.2.1
     hooks:
       - id: ruff
-        args: [--fix, --exit-non-zero-on-fix]
       - id: ruff-format
.vscode/extensions.json (vendored, 5 changes)
@@ -1,5 +0,0 @@
-{
-  "recommendations": [
-    "charliermarsh.ruff",
-  ]
-}
.vscode/settings.json (vendored, 22 changes)
@@ -1,22 +0,0 @@
-{
-  "python.defaultInterpreterPath": ".venv/bin/python",
-  "python.terminal.activateEnvironment": true,
-  "[python]": {
-    "editor.defaultFormatter": "charliermarsh.ruff",
-    "editor.formatOnSave": true,
-    "editor.codeActionsOnSave": {
-      "source.organizeImports": "explicit",
-      "source.fixAll": "explicit"
-    },
-    "editor.insertSpaces": true,
-    "editor.tabSize": 4
-  },
-  "ruff.enable": true,
-  "files.watcherExclude": {
-    "**/.venv/**": true,
-    "**/__pycache__/**": true,
-    "**/*.egg-info/**": true,
-    "**/build/**": true,
-    "**/dist/**": true
-  }
-}
README.md (453 changes)
@@ -3,13 +3,9 @@
 </p>

 <p align="center">
-  <img src="https://img.shields.io/badge/Python-3.9%20%7C%203.10%20%7C%203.11%20%7C%203.12%20%7C%203.13-blue.svg" alt="Python Versions">
-  <img src="https://github.com/yichuan-w/LEANN/actions/workflows/build-and-publish.yml/badge.svg" alt="CI Status">
-  <img src="https://img.shields.io/badge/Platform-Ubuntu%20%26%20Arch%20%26%20WSL%20%7C%20macOS%20(ARM64%2FIntel)-lightgrey" alt="Platform">
+  <img src="https://img.shields.io/badge/Python-3.9%2B-blue.svg" alt="Python 3.9+">
   <img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
-  <img src="https://img.shields.io/badge/MCP-Native%20Integration-blue" alt="MCP Integration">
-  <a href="https://join.slack.com/t/leann-e2u9779/shared_invite/zt-3ckd2f6w1-OX08~NN4gkWhh10PRVBj1Q"><img src="https://img.shields.io/badge/Slack-Join-4A154B?logo=slack&logoColor=white" alt="Join Slack">
-  <a href="assets/wechat_user_group.JPG" title="Join WeChat group"><img src="https://img.shields.io/badge/WeChat-Join-2DC100?logo=wechat&logoColor=white" alt="Join WeChat group"></a>
+  <img src="https://img.shields.io/badge/Platform-Linux%20%7C%20macOS-lightgrey" alt="Platform">
 </p>

 <h2 align="center" tabindex="-1" class="heading-element" dir="auto">
@@ -20,10 +16,7 @@ LEANN is an innovative vector database that democratizes personal AI. Transform

 LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)

-**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)**, **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
+**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)**, or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.

-
-\* Claude Code only supports basic `grep`-style keyword search. **LEANN** is a drop-in **semantic search MCP service fully compatible with Claude Code**, unlocking intelligent retrieval without changing your workflow. 🔥 Check out [the easy setup →](packages/leann-mcp/README.md)


@@ -33,7 +26,7 @@ LEANN achieves this through *graph-based selective recomputation* with *high-deg
 <img src="assets/effects.png" alt="LEANN vs Traditional Vector DB Storage Comparison" width="70%">
 </p>

-> **The numbers speak for themselves:** Index 60 million text chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#-storage-comparison)
+> **The numbers speak for themselves:** Index 60 million Wikipedia chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#storage-comparison)


 🔒 **Privacy:** Your data never leaves your laptop. No OpenAI, no cloud, no "terms of service".
@@ -48,111 +41,64 @@ LEANN achieves this through *graph-based selective recomputation* with *high-deg

 ## Installation

-### 📦 Prerequisites: Install uv
+<details>
+<summary><strong>📦 Prerequisites: Install uv (if you don't have it)</strong></summary>

-[Install uv](https://docs.astral.sh/uv/getting-started/installation/#installation-methods) first if you don't have it. Typically, you can install it with:
+Install uv first if you don't have it:

 ```bash
 curl -LsSf https://astral.sh/uv/install.sh | sh
 ```

-### 🚀 Quick Install
+📖 [Detailed uv installation methods →](https://docs.astral.sh/uv/getting-started/installation/#installation-methods)

-Clone the repository to access all examples and try amazing applications,
+</details>

+LEANN provides two installation methods: **pip install** (quick and easy) and **build from source** (recommended for development).

+### 🚀 Quick Install (Recommended for most users)
+
+Clone the repository to access all examples and install LEANN from [PyPI](https://pypi.org/project/leann/) to run them immediately:

 ```bash
-git clone https://github.com/yichuan-w/LEANN.git leann
+git clone git@github.com:yichuan-w/LEANN.git leann
 cd leann
-```
-
-and install LEANN from [PyPI](https://pypi.org/project/leann/) to run them immediately:
-
-```bash
 uv venv
 source .venv/bin/activate
 uv pip install leann
 ```
-<!--
-> Low-resource? See “Low-resource setups” in the [Configuration Guide](docs/configuration-guide.md#low-resource-setups). -->

-<details>
-<summary>
-<strong>🔧 Build from Source (Recommended for development)</strong>
-</summary>

+### 🔧 Build from Source (Recommended for development)

 ```bash
-git clone https://github.com/yichuan-w/LEANN.git leann
+git clone git@github.com:yichuan-w/LEANN.git leann
 cd leann
 git submodule update --init --recursive
 ```

 **macOS:**

-Note: DiskANN requires MacOS 13.3 or later.
-
 ```bash
-brew install libomp boost protobuf zeromq pkgconf
-uv sync --extra diskann
+brew install llvm libomp boost protobuf zeromq pkgconf
+CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync
 ```

-**Linux (Ubuntu/Debian):**
+**Linux:**

-Note: On Ubuntu 20.04, you may need to build a newer Abseil and pin Protobuf (e.g., v3.20.x) for building DiskANN. See [Issue #30](https://github.com/yichuan-w/LEANN/issues/30) for a step-by-step note.
-
-You can manually install [Intel oneAPI MKL](https://www.intel.com/content/www/us/en/developer/tools/oneapi/onemkl.html) instead of `libmkl-full-dev` for DiskANN. You can also use `libopenblas-dev` for building HNSW only, by removing `--extra diskann` in the command below.
-
 ```bash
-sudo apt-get update && sudo apt-get install -y \
-  libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
-  pkg-config libabsl-dev libaio-dev libprotobuf-dev \
-  libmkl-full-dev
-
-uv sync --extra diskann
+sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
+uv sync
 ```

-**Linux (Arch Linux):**
-
-```bash
-sudo pacman -Syu && sudo pacman -S --needed base-devel cmake pkgconf git gcc \
-  boost boost-libs protobuf abseil-cpp libaio zeromq
-
-# For MKL in DiskANN
-sudo pacman -S --needed base-devel git
-git clone https://aur.archlinux.org/paru-bin.git
-cd paru-bin && makepkg -si
-paru -S intel-oneapi-mkl intel-oneapi-compiler
-source /opt/intel/oneapi/setvars.sh
-
-uv sync --extra diskann
-```
-
-**Linux (RHEL / CentOS Stream / Oracle / Rocky / AlmaLinux):**
-
-See [Issue #50](https://github.com/yichuan-w/LEANN/issues/50) for more details.
-
-```bash
-sudo dnf groupinstall -y "Development Tools"
-sudo dnf install -y libomp-devel boost-devel protobuf-compiler protobuf-devel \
-  abseil-cpp-devel libaio-devel zeromq-devel pkgconf-pkg-config
-
-# For MKL in DiskANN
-sudo dnf install -y intel-oneapi-mkl intel-oneapi-mkl-devel \
-  intel-oneapi-openmp || sudo dnf install -y intel-oneapi-compiler
-source /opt/intel/oneapi/setvars.sh
-
-uv sync --extra diskann
-```
-
-</details>
-

 ## Quick Start

 Our declarative API makes RAG as easy as writing a config file.

-Check out [demo.ipynb](demo.ipynb) or [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/yichuan-w/LEANN/blob/main/demo.ipynb)
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/yichuan-w/LEANN/blob/main/demo.ipynb) [Try in this ipynb file →](demo.ipynb)

 ```python
 from leann import LeannBuilder, LeannSearcher, LeannChat
@@ -176,13 +122,11 @@ response = chat.ask("How much storage does LEANN save?", top_k=1)

 ## RAG on Everything!

-LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`, `.md`), Apple Mail, Google Search History, WeChat, and more.
+LEANN supports RAG on various data sources including documents (.pdf, .txt, .md), Apple Mail, Google Search History, WeChat, and more.


-### Generation Model Setup
-
-LEANN supports multiple LLM providers for text generation (OpenAI API, HuggingFace, Ollama).
+> **Generation Model Setup**
+> LEANN supports multiple LLM providers for text generation (OpenAI API, HuggingFace, Ollama).

 <details>
 <summary><strong>🔑 OpenAI API Setup (Default)</strong></summary>
@@ -222,53 +166,7 @@ ollama pull llama3.2:1b

 </details>

-## ⭐ Flexible Configuration
-
-LEANN provides flexible parameters for embedding models, search strategies, and data processing to fit your specific needs.
-
-📚 **Need configuration best practices?** Check our [Configuration Guide](docs/configuration-guide.md) for detailed optimization tips, model selection advice, and solutions to common issues like slow embeddings or poor search quality.
-
-<details>
-<summary><strong>📋 Click to expand: Common Parameters (Available in All Examples)</strong></summary>
-
-All RAG examples share these common parameters. **Interactive mode** is available in all examples - simply run without `--query` to start a continuous Q&A session where you can ask multiple questions. Type 'quit' to exit.
-
-```bash
-# Core Parameters (General preprocessing for all examples)
---index-dir DIR # Directory to store the index (default: current directory)
---query "YOUR QUESTION" # Single query mode. Omit for interactive chat (type 'quit' to exit), and now you can play with your index interactively
---max-items N # Limit data preprocessing (default: -1, process all data)
---force-rebuild # Force rebuild index even if it exists
-
-# Embedding Parameters
---embedding-model MODEL # e.g., facebook/contriever, text-embedding-3-small, mlx-community/Qwen3-Embedding-0.6B-8bit or nomic-embed-text
---embedding-mode MODE # sentence-transformers, openai, mlx, or ollama
-
-# LLM Parameters (Text generation models)
---llm TYPE # LLM backend: openai, ollama, or hf (default: openai)
---llm-model MODEL # Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct
---thinking-budget LEVEL # Thinking budget for reasoning models: low/medium/high (supported by o3, o3-mini, GPT-Oss:20b, and other reasoning models)
-
-# Search Parameters
---top-k N # Number of results to retrieve (default: 20)
---search-complexity N # Search complexity for graph traversal (default: 32)
-
-# Chunking Parameters
---chunk-size N # Size of text chunks (default varies by source: 256 for most, 192 for WeChat)
---chunk-overlap N # Overlap between chunks (default varies: 25-128 depending on source)
-
-# Index Building Parameters
---backend-name NAME # Backend to use: hnsw or diskann (default: hnsw)
---graph-degree N # Graph degree for index construction (default: 32)
---build-complexity N # Build complexity for index construction (default: 64)
---compact / --no-compact # Use compact storage (default: true). Must be `no-compact` for `no-recompute` build.
---recompute / --no-recompute # Enable/disable embedding recomputation (default: enabled). Should not do a `no-recompute` search in a `recompute` build.
-```
-
-</details>
-
-### 📄 Personal Data Manager: Process Any Documents (`.pdf`, `.txt`, `.md`)!
+### 📄 Personal Data Manager: Process Any Documents (.pdf, .txt, .md)!

 Ask questions directly about your personal PDFs, documents, and any directory containing your files!

@@ -276,35 +174,25 @@ Ask questions directly about your personal PDFs, documents, and any directory co
 <img src="videos/paper_clear.gif" alt="LEANN Document Search Demo" width="600">
 </p>

-The example below asks a question about summarizing our paper (uses default data in `data/`, which is a directory with diverse data sources: two papers, Pride and Prejudice, and a Technical report about LLM in Huawei in Chinese), and this is the **easiest example** to run here:
+The example below asks a question about summarizing two papers (uses default data in `examples/data`) and this is the easiest example to run here:

 ```bash
-source .venv/bin/activate # Don't forget to activate the virtual environment
-python -m apps.document_rag --query "What are the main techniques LEANN explores?"
+source .venv/bin/activate
+python ./examples/main_cli_example.py
 ```

 <details>
-<summary><strong>📋 Click to expand: Document-Specific Arguments</strong></summary>
+<summary><strong>📋 Click to expand: User Configurable Arguments</strong></summary>

-#### Parameters
 ```bash
---data-dir DIR # Directory containing documents to process (default: data)
---file-types .ext .ext # Filter by specific file types (optional - all LlamaIndex supported types if omitted)
-```
+# Use custom index directory
+python examples/main_cli_example.py --index-dir "./my_custom_index"

-#### Example Commands
-```bash
-# Process all documents with larger chunks for academic papers
-python -m apps.document_rag --data-dir "~/Documents/Papers" --chunk-size 1024
+# Use custom data directory
+python examples/main_cli_example.py --data-dir "./my_documents"

-# Filter only markdown and Python files with smaller chunks
-python -m apps.document_rag --data-dir "./docs" --chunk-size 256 --file-types .md .py
+# Ask a specific question
+python examples/main_cli_example.py --query "What are the main findings in these papers?"

-# Enable AST-aware chunking for code files
-python -m apps.document_rag --enable-code-chunking --data-dir "./my_project"
-
-# Or use the specialized code RAG for better code understanding
-python -m apps.code_rag --repo-dir "./my_codebase" --query "How does authentication work?"
 ```

 </details>
@@ -318,29 +206,30 @@ python -m apps.code_rag --repo-dir "./my_codebase" --query "How does authenticat
 <img src="videos/mail_clear.gif" alt="LEANN Email Search Demo" width="600">
 </p>

-Before running the example below, you need to grant full disk access to your terminal/VS Code in System Preferences → Privacy & Security → Full Disk Access.
+**Note:** You need to grant full disk access to your terminal/VS Code in System Preferences → Privacy & Security → Full Disk Access.

 ```bash
-python -m apps.email_rag --query "What's the food I ordered by DoorDash or Uber Eats mostly?"
+python examples/mail_reader_leann.py --query "What's the food I ordered by DoorDash or Uber Eats mostly?"
 ```
 **780K email chunks → 78MB storage.** Finally, search your email like you search Google.

 <details>
-<summary><strong>📋 Click to expand: Email-Specific Arguments</strong></summary>
+<summary><strong>📋 Click to expand: User Configurable Arguments</strong></summary>

-#### Parameters
 ```bash
---mail-path PATH # Path to specific mail directory (auto-detects if omitted)
---include-html # Include HTML content in processing (useful for newsletters)
-```
+# Use default mail path (works for most macOS setups)
+python examples/mail_reader_leann.py

-#### Example Commands
-```bash
-# Search work emails from a specific account
-python -m apps.email_rag --mail-path "~/Library/Mail/V10/WORK_ACCOUNT"
+# Run with custom index directory
+python examples/mail_reader_leann.py --index-dir "./my_mail_index"

-# Find all receipts and order confirmations (includes HTML)
-python -m apps.email_rag --query "receipt order confirmation invoice" --include-html
+# Process all emails (may take time but indexes everything)
+python examples/mail_reader_leann.py --max-emails -1
+
+# Limit number of emails processed (useful for testing)
+python examples/mail_reader_leann.py --max-emails 1000
+
+# Run a single query
+python examples/mail_reader_leann.py --query "What did my boss say about deadlines?"
 ```

 </details>
@@ -361,25 +250,25 @@ Once the index is built, you can ask questions like:
 </p>

 ```bash
-python -m apps.browser_rag --query "Tell me my browser history about machine learning?"
+python examples/google_history_reader_leann.py --query "Tell me my browser history about machine learning?"
 ```
 **38K browser entries → 6MB storage.** Your browser history becomes your personal search engine.

 <details>
-<summary><strong>📋 Click to expand: Browser-Specific Arguments</strong></summary>
+<summary><strong>📋 Click to expand: User Configurable Arguments</strong></summary>

-#### Parameters
 ```bash
---chrome-profile PATH # Path to Chrome profile directory (auto-detects if omitted)
-```
+# Use default Chrome profile (auto-finds all profiles)
+python examples/google_history_reader_leann.py

-#### Example Commands
-```bash
-# Search academic research from your browsing history
-python -m apps.browser_rag --query "arxiv papers machine learning transformer architecture"
+# Run with custom index directory
+python examples/google_history_reader_leann.py --index-dir "./my_chrome_index"

-# Track competitor analysis across work profile
-python -m apps.browser_rag --chrome-profile "~/Library/Application Support/Google/Chrome/Work Profile" --max-items 5000
+# Limit number of history entries processed (useful for testing)
+python examples/google_history_reader_leann.py --max-entries 500
+
+# Run a single query
+python examples/google_history_reader_leann.py --query "What websites did I visit about machine learning?"
 ```

 </details>
@@ -419,7 +308,7 @@ Once the index is built, you can ask questions like:
 </p>

 ```bash
-python -m apps.wechat_rag --query "Show me all group chats about weekend plans"
+python examples/wechat_history_reader_leann.py --query "Show me all group chats about weekend plans"
 ```
 **400K messages → 64MB storage** Search years of chat history in any language.

@@ -427,13 +316,7 @@ python -m apps.wechat_rag --query "Show me all group chats about weekend plans"
 <details>
 <summary><strong>🔧 Click to expand: Installation Requirements</strong></summary>

-First, you need to install the [WeChat exporter](https://github.com/sunnyyoung/WeChatTweak-CLI),
-
-```bash
-brew install sunnyyoung/repo/wechattweak-cli
-```
-
-or install it manually (if you have issues with Homebrew):
+First, you need to install the WeChat exporter:

 ```bash
 sudo packages/wechat-exporter/wechattweak-cli install
@@ -442,28 +325,30 @@ sudo packages/wechat-exporter/wechattweak-cli install
 **Troubleshooting:**
 - **Installation issues**: Check the [WeChatTweak-CLI issues page](https://github.com/sunnyyoung/WeChatTweak-CLI/issues/41)
 - **Export errors**: If you encounter the error below, try restarting WeChat
-  ```bash
+  ```
   Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.
   Failed to find or export WeChat data. Exiting.
   ```
 </details>

 <details>
-<summary><strong>📋 Click to expand: WeChat-Specific Arguments</strong></summary>
+<summary><strong>📋 Click to expand: User Configurable Arguments</strong></summary>

-#### Parameters
 ```bash
---export-dir DIR # Directory to store exported WeChat data (default: wechat_export_direct)
---force-export # Force re-export even if data exists
-```
+# Use default settings (recommended for first run)
+python examples/wechat_history_reader_leann.py

-#### Example Commands
-```bash
-# Search for travel plans discussed in group chats
-python -m apps.wechat_rag --query "travel plans" --max-items 10000
+# Run with custom export directory and wehn we run the first time, LEANN will export all chat history automatically for you
+python examples/wechat_history_reader_leann.py --export-dir "./my_wechat_exports"

-# Re-export and search recent chats (useful after new messages)
-python -m apps.wechat_rag --force-export --query "work schedule"
+# Run with custom index directory
+python examples/wechat_history_reader_leann.py --index-dir "./my_wechat_index"
+
+# Limit number of chat entries processed (useful for testing)
+python examples/wechat_history_reader_leann.py --max-entries 1000
+
+# Run a single query
+python examples/wechat_history_reader_leann.py --query "Show me conversations about travel plans"
 ```

 </details>
@@ -477,68 +362,15 @@ Once the index is built, you can ask questions like:

 </details>

-### 🚀 Claude Code Integration: Transform Your Development Workflow!
-
-<details>
-<summary><strong>NEW!! AST‑Aware Code Chunking</strong></summary>
-
-LEANN features intelligent code chunking that preserves semantic boundaries (functions, classes, methods) for Python, Java, C#, and TypeScript, improving code understanding compared to text-based chunking.
-
-📖 Read the [AST Chunking Guide →](docs/ast_chunking_guide.md)
-
-</details>
-
-**The future of code assistance is here.** Transform your development workflow with LEANN's native MCP integration for Claude Code. Index your entire codebase and get intelligent code assistance directly in your IDE.
-
-**Key features:**
-- 🔍 **Semantic code search** across your entire project, fully local index and lightweight
-- 🧠 **AST-aware chunking** preserves code structure (functions, classes)
-- 📚 **Context-aware assistance** for debugging and development
-- 🚀 **Zero-config setup** with automatic language detection
-
-```bash
-# Install LEANN globally for MCP integration
-uv tool install leann-core --with leann
-claude mcp add --scope user leann-server -- leann_mcp
-# Setup is automatic - just start using Claude Code!
-```
-Try our fully agentic pipeline with auto query rewriting, semantic search planning, and more:
-
-![Claude Code Demo](videos/claude_code_demo.gif)
-
-**🔥 Ready to supercharge your coding?** [Complete Setup Guide →](packages/leann-mcp/README.md)
-
 ## 🖥️ Command Line Interface

 LEANN includes a powerful CLI for document processing and search. Perfect for quick document indexing and interactive chat.

-### Installation
-
-If you followed the Quick Start, `leann` is already installed in your virtual environment:
 ```bash
-source .venv/bin/activate
-leann --help
-```
-
-**To make it globally available:**
-```bash
-# Install the LEANN CLI globally using uv tool
-uv tool install leann-core --with leann
-
-
-# Now you can use leann from anywhere without activating venv
-leann --help
-```
-
-> **Note**: Global installation is required for Claude Code integration. The `leann_mcp` server depends on the globally available `leann` command.
-
-
-
-### Usage Examples
-
-```bash
-# build from a specific directory, and my_docs is the index name(Here you can also build from multiple dict or multiple files)
-leann build my-docs --docs ./your_documents
+# Build an index from documents
+leann build my-docs --docs ./documents

 # Search your documents
 leann search my-docs "machine learning concepts"
@@ -548,36 +380,30 @@ leann ask my-docs --interactive

 # List all your indexes
 leann list
-
-# Remove an index
-leann remove my-docs
 ```

 **Key CLI features:**
-- Auto-detects document formats (PDF, TXT, MD, DOCX, PPTX + code files)
-- **🧠 AST-aware chunking** for Python, Java, C#, TypeScript files
-- Smart text chunking with overlap for all other content
+- Auto-detects document formats (PDF, TXT, MD, DOCX)
+- Smart text chunking with overlap
 - Multiple LLM providers (Ollama, OpenAI, HuggingFace)
-- Organized index storage in `.leann/indexes/` (project-local)
+- Organized index storage in `~/.leann/indexes/`
 - Support for advanced search parameters

 <details>
 <summary><strong>📋 Click to expand: Complete CLI Reference</strong></summary>

-You can use `leann --help`, or `leann build --help`, `leann search --help`, `leann ask --help`, `leann list --help`, `leann remove --help` to get the complete CLI reference.
-
 **Build Command:**
 ```bash
-leann build INDEX_NAME --docs DIRECTORY|FILE [DIRECTORY|FILE ...] [OPTIONS]
+leann build INDEX_NAME --docs DIRECTORY [OPTIONS]

 Options:
   --backend {hnsw,diskann} Backend to use (default: hnsw)
   --embedding-model MODEL Embedding model (default: facebook/contriever)
   --graph-degree N Graph degree (default: 32)
   --complexity N Build complexity (default: 64)
   --force Force rebuild existing index
-  --compact / --no-compact Use compact storage (default: true). Must be `no-compact` for `no-recompute` build.
-  --recompute / --no-recompute Enable recomputation (default: true)
+  --compact Use compact storage (default: true)
+  --recompute Enable recomputation (default: true)
 ```

 **Search Command:**
@@ -585,9 +411,9 @@ Options:
 leann search INDEX_NAME QUERY [OPTIONS]

 Options:
   --top-k N Number of results (default: 5)
   --complexity N Search complexity (default: 64)
-  --recompute / --no-recompute Enable/disable embedding recomputation (default: enabled). Should not do a `no-recompute` search in a `recompute` build.
+  --recompute-embeddings Use recomputation for highest accuracy
   --pruning-strategy {global,local,proportional}
 ```

@@ -602,60 +428,8 @@ Options:
   --top-k N Retrieval count (default: 20)
 ```

-**List Command:**
-```bash
-leann list
-
-# Lists all indexes across all projects with status indicators:
-# ✅ - Index is complete and ready to use
-# ❌ - Index is incomplete or corrupted
-# 📁 - CLI-created index (in .leann/indexes/)
-# 📄 - App-created index (*.leann.meta.json files)
-```
-
-**Remove Command:**
-```bash
-leann remove INDEX_NAME [OPTIONS]
-
-Options:
-  --force, -f Force removal without confirmation
-
-# Smart removal: automatically finds and safely removes indexes
-# - Shows all matching indexes across projects
-# - Requires confirmation for cross-project removal
-# - Interactive selection when multiple matches found
-# - Supports both CLI and app-created indexes
-```
-
 </details>

-## 🚀 Advanced Features
-
-### 🎯 Metadata Filtering
-
-LEANN supports a simple metadata filtering system to enable sophisticated use cases like document filtering by date/type, code search by file extension, and content management based on custom criteria.
-
-```python
-# Add metadata during indexing
-builder.add_text(
-    "def authenticate_user(token): ...",
-    metadata={"file_extension": ".py", "lines_of_code": 25}
-)
-
-# Search with filters
-results = searcher.search(
-    query="authentication function",
-    metadata_filters={
-        "file_extension": {"==": ".py"},
-        "lines_of_code": {"<": 100}
-    }
-)
-```
-
-**Supported operators**: `==`, `!=`, `<`, `<=`, `>`, `>=`, `in`, `not_in`, `contains`, `starts_with`, `ends_with`, `is_true`, `is_false`
-
-📖 **[Complete Metadata filtering guide →](docs/metadata_filtering.md)**
-
 ## 🏗️ Architecture & How It Works

 <p align="center">
@@ -670,17 +444,13 @@ results = searcher.search(
 - **Dynamic batching:** Efficiently batch embedding computations for GPU utilization
 - **Two-level search:** Smart graph traversal that prioritizes promising nodes

-**Backends:**
-- **HNSW** (default): Ideal for most datasets with maximum storage savings through full recomputation
-- **DiskANN**: Advanced option with superior search performance, using PQ-based graph traversal with real-time reranking for the best speed-accuracy trade-off
+**Backends:** DiskANN or HNSW - pick what works for your data size.

 ## Benchmarks

-**[DiskANN vs HNSW Performance Comparison →](benchmarks/diskann_vs_hnsw_speed_comparison.py)** - Compare search performance between both backends
-
-**[Simple Example: Compare LEANN vs FAISS →](benchmarks/compare_faiss_vs_leann.py)** - See storage savings in action
+📊 **[Simple Example: Compare LEANN vs FAISS →](examples/compare_faiss_vs_leann.py)**

-### 📊 Storage Comparison
+### Storage Comparison

 | System | DPR (2.1M) | Wiki (60M) | Chat (400K) | Email (780K) | Browser (38K) |
 |--------|-------------|------------|-------------|--------------|---------------|
@@ -694,8 +464,8 @@ results = searcher.search(

 ```bash
 uv pip install -e ".[dev]" # Install dev dependencies
-python benchmarks/run_evaluation.py # Will auto-download evaluation data and run benchmarks
-python benchmarks/run_evaluation.py benchmarks/data/indices/rpj_wiki/rpj_wiki --num-queries 2000 # After downloading data, you can run the benchmark with our biggest index
+python examples/run_evaluation.py data/indices/dpr/dpr_diskann # DPR dataset
+python examples/run_evaluation.py data/indices/rpj_wiki/rpj_wiki.index # Wikipedia
 ```

 The evaluation script downloads data automatically on first run. The last three results were tested with partial personal data, and you can reproduce them with your own data!
@@ -733,18 +503,9 @@ MIT License - see [LICENSE](LICENSE) for details.

 ## 🙏 Acknowledgments

-Core Contributors: [Yichuan Wang](https://yichuan-w.github.io/) & [Zhifei Li](https://github.com/andylizf).
-
-Active Contributors: [Gabriel Dehan](https://github.com/gabriel-dehan)
-
-
-We welcome more contributors! Feel free to open issues or submit PRs.
+This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/)
+---

This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/).
|
|
||||||
|
|
||||||
## Star History
|
|
||||||
|
|
||||||
[](https://www.star-history.com/#yichuan-w/LEANN&Date)
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<strong>⭐ Star us on GitHub if Leann is useful for your research or applications!</strong>
|
<strong>⭐ Star us on GitHub if Leann is useful for your research or applications!</strong>
|
||||||
</p>
|
</p>
|
||||||
|
|||||||
@@ -1,342 +0,0 @@
|
|||||||
"""
|
|
||||||
Base class for unified RAG examples interface.
|
|
||||||
Provides common parameters and functionality for all RAG examples.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import dotenv
|
|
||||||
from leann.api import LeannBuilder, LeannChat
|
|
||||||
from leann.registry import register_project_directory
|
|
||||||
|
|
||||||
dotenv.load_dotenv()
|
|
||||||
|
|
||||||
|
|
||||||
class BaseRAGExample(ABC):
|
|
||||||
"""Base class for all RAG examples with unified interface."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
description: str,
|
|
||||||
default_index_name: str,
|
|
||||||
):
|
|
||||||
self.name = name
|
|
||||||
self.description = description
|
|
||||||
self.default_index_name = default_index_name
|
|
||||||
self.parser = self._create_parser()
|
|
||||||
|
|
||||||
def _create_parser(self) -> argparse.ArgumentParser:
|
|
||||||
"""Create argument parser with common parameters."""
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description=self.description, formatter_class=argparse.RawDescriptionHelpFormatter
|
|
||||||
)
|
|
||||||
|
|
||||||
# Core parameters (all examples share these)
|
|
||||||
core_group = parser.add_argument_group("Core Parameters")
|
|
||||||
core_group.add_argument(
|
|
||||||
"--index-dir",
|
|
||||||
type=str,
|
|
||||||
default=f"./{self.default_index_name}",
|
|
||||||
help=f"Directory to store the index (default: ./{self.default_index_name})",
|
|
||||||
)
|
|
||||||
core_group.add_argument(
|
|
||||||
"--query",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="Query to run (if not provided, will run in interactive mode)",
|
|
||||||
)
|
|
||||||
# Allow subclasses to override default max_items
|
|
||||||
max_items_default = getattr(self, "max_items_default", -1)
|
|
||||||
core_group.add_argument(
|
|
||||||
"--max-items",
|
|
||||||
type=int,
|
|
||||||
default=max_items_default,
|
|
||||||
help="Maximum number of items to process -1 for all, means index all documents, and you should set it to a reasonable number if you have a large dataset and try at the first time)",
|
|
||||||
)
|
|
||||||
core_group.add_argument(
|
|
||||||
"--force-rebuild", action="store_true", help="Force rebuild index even if it exists"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Embedding parameters
|
|
||||||
embedding_group = parser.add_argument_group("Embedding Parameters")
|
|
||||||
# Allow subclasses to override default embedding_model
|
|
||||||
embedding_model_default = getattr(self, "embedding_model_default", "facebook/contriever")
|
|
||||||
embedding_group.add_argument(
|
|
||||||
"--embedding-model",
|
|
||||||
type=str,
|
|
||||||
default=embedding_model_default,
|
|
||||||
help=f"Embedding model to use (default: {embedding_model_default}), we provide facebook/contriever, text-embedding-3-small,mlx-community/Qwen3-Embedding-0.6B-8bit or nomic-embed-text",
|
|
||||||
)
|
|
||||||
embedding_group.add_argument(
|
|
||||||
"--embedding-mode",
|
|
||||||
type=str,
|
|
||||||
default="sentence-transformers",
|
|
||||||
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
|
||||||
help="Embedding backend mode (default: sentence-transformers), we provide sentence-transformers, openai, mlx, or ollama",
|
|
||||||
)
|
|
||||||
|
|
||||||
# LLM parameters
|
|
||||||
llm_group = parser.add_argument_group("LLM Parameters")
|
|
||||||
llm_group.add_argument(
|
|
||||||
"--llm",
|
|
||||||
type=str,
|
|
||||||
default="openai",
|
|
||||||
choices=["openai", "ollama", "hf", "simulated"],
|
|
||||||
help="LLM backend: openai, ollama, or hf (default: openai)",
|
|
||||||
)
|
|
||||||
llm_group.add_argument(
|
|
||||||
"--llm-model",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct",
|
|
||||||
)
|
|
||||||
llm_group.add_argument(
|
|
||||||
"--llm-host",
|
|
||||||
type=str,
|
|
||||||
default="http://localhost:11434",
|
|
||||||
help="Host for Ollama API (default: http://localhost:11434)",
|
|
||||||
)
|
|
||||||
llm_group.add_argument(
|
|
||||||
"--thinking-budget",
|
|
||||||
type=str,
|
|
||||||
choices=["low", "medium", "high"],
|
|
||||||
default=None,
|
|
||||||
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
|
|
||||||
)
|
|
||||||
|
|
||||||
# AST Chunking parameters
|
|
||||||
ast_group = parser.add_argument_group("AST Chunking Parameters")
|
|
||||||
ast_group.add_argument(
|
|
||||||
"--use-ast-chunking",
|
|
||||||
action="store_true",
|
|
||||||
help="Enable AST-aware chunking for code files (requires astchunk)",
|
|
||||||
)
|
|
||||||
ast_group.add_argument(
|
|
||||||
"--ast-chunk-size",
|
|
||||||
type=int,
|
|
||||||
default=512,
|
|
||||||
help="Maximum characters per AST chunk (default: 512)",
|
|
||||||
)
|
|
||||||
ast_group.add_argument(
|
|
||||||
"--ast-chunk-overlap",
|
|
||||||
type=int,
|
|
||||||
default=64,
|
|
||||||
help="Overlap between AST chunks (default: 64)",
|
|
||||||
)
|
|
||||||
ast_group.add_argument(
|
|
||||||
"--code-file-extensions",
|
|
||||||
nargs="+",
|
|
||||||
default=None,
|
|
||||||
help="Additional code file extensions to process with AST chunking (e.g., .py .java .cs .ts)",
|
|
||||||
)
|
|
||||||
ast_group.add_argument(
|
|
||||||
"--ast-fallback-traditional",
|
|
||||||
action="store_true",
|
|
||||||
default=True,
|
|
||||||
help="Fall back to traditional chunking if AST chunking fails (default: True)",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Search parameters
|
|
||||||
search_group = parser.add_argument_group("Search Parameters")
|
|
||||||
search_group.add_argument(
|
|
||||||
"--top-k", type=int, default=20, help="Number of results to retrieve (default: 20)"
|
|
||||||
)
|
|
||||||
search_group.add_argument(
|
|
||||||
"--search-complexity",
|
|
||||||
type=int,
|
|
||||||
default=32,
|
|
||||||
help="Search complexity for graph traversal (default: 64)",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Index building parameters
|
|
||||||
index_group = parser.add_argument_group("Index Building Parameters")
|
|
||||||
index_group.add_argument(
|
|
||||||
"--backend-name",
|
|
||||||
type=str,
|
|
||||||
default="hnsw",
|
|
||||||
choices=["hnsw", "diskann"],
|
|
||||||
help="Backend to use for index (default: hnsw)",
|
|
||||||
)
|
|
||||||
index_group.add_argument(
|
|
||||||
"--graph-degree",
|
|
||||||
type=int,
|
|
||||||
default=32,
|
|
||||||
help="Graph degree for index construction (default: 32)",
|
|
||||||
)
|
|
||||||
index_group.add_argument(
|
|
||||||
"--build-complexity",
|
|
||||||
type=int,
|
|
||||||
default=64,
|
|
||||||
help="Build complexity for index construction (default: 64)",
|
|
||||||
)
|
|
||||||
index_group.add_argument(
|
|
||||||
"--no-compact",
|
|
||||||
action="store_true",
|
|
||||||
help="Disable compact index storage",
|
|
||||||
)
|
|
||||||
index_group.add_argument(
|
|
||||||
"--no-recompute",
|
|
||||||
action="store_true",
|
|
||||||
help="Disable embedding recomputation",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add source-specific parameters
|
|
||||||
self._add_specific_arguments(parser)
|
|
||||||
|
|
||||||
return parser
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
|
|
||||||
"""Add source-specific arguments. Override in subclasses."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def load_data(self, args) -> list[str]:
|
|
||||||
"""Load data from the source. Returns list of text chunks."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def get_llm_config(self, args) -> dict[str, Any]:
|
|
||||||
"""Get LLM configuration based on arguments."""
|
|
||||||
config = {"type": args.llm}
|
|
||||||
|
|
||||||
if args.llm == "openai":
|
|
||||||
config["model"] = args.llm_model or "gpt-4o"
|
|
||||||
elif args.llm == "ollama":
|
|
||||||
config["model"] = args.llm_model or "llama3.2:1b"
|
|
||||||
config["host"] = args.llm_host
|
|
||||||
elif args.llm == "hf":
|
|
||||||
config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
|
|
||||||
elif args.llm == "simulated":
|
|
||||||
# Simulated LLM doesn't need additional configuration
|
|
||||||
pass
|
|
||||||
|
|
||||||
return config
|
|
||||||
|
|
||||||
async def build_index(self, args, texts: list[str]) -> str:
|
|
||||||
"""Build LEANN index from texts."""
|
|
||||||
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
|
|
||||||
|
|
||||||
print(f"\n[Building Index] Creating {self.name} index...")
|
|
||||||
print(f"Total text chunks: {len(texts)}")
|
|
||||||
|
|
||||||
builder = LeannBuilder(
|
|
||||||
backend_name=args.backend_name,
|
|
||||||
embedding_model=args.embedding_model,
|
|
||||||
embedding_mode=args.embedding_mode,
|
|
||||||
graph_degree=args.graph_degree,
|
|
||||||
complexity=args.build_complexity,
|
|
||||||
is_compact=not args.no_compact,
|
|
||||||
is_recompute=not args.no_recompute,
|
|
||||||
num_threads=1, # Force single-threaded mode
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add texts in batches for better progress tracking
|
|
||||||
batch_size = 1000
|
|
||||||
for i in range(0, len(texts), batch_size):
|
|
||||||
batch = texts[i : i + batch_size]
|
|
||||||
for text in batch:
|
|
||||||
builder.add_text(text)
|
|
||||||
print(f"Added {min(i + batch_size, len(texts))}/{len(texts)} texts...")
|
|
||||||
|
|
||||||
print("Building index structure...")
|
|
||||||
builder.build_index(index_path)
|
|
||||||
print(f"Index saved to: {index_path}")
|
|
||||||
|
|
||||||
# Register project directory so leann list can discover this index
|
|
||||||
# The index is saved as args.index_dir/index_name.leann
|
|
||||||
# We want to register the current working directory where the app is run
|
|
||||||
register_project_directory(Path.cwd())
|
|
||||||
|
|
||||||
return index_path
|
|
||||||
|
|
||||||
async def run_interactive_chat(self, args, index_path: str):
|
|
||||||
"""Run interactive chat with the index."""
|
|
||||||
chat = LeannChat(
|
|
||||||
index_path,
|
|
||||||
llm_config=self.get_llm_config(args),
|
|
||||||
system_prompt=f"You are a helpful assistant that answers questions about {self.name} data.",
|
|
||||||
complexity=args.search_complexity,
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"\n[Interactive Mode] Chat with your {self.name} data!")
|
|
||||||
print("Type 'quit' or 'exit' to stop.\n")
|
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
query = input("You: ").strip()
|
|
||||||
if query.lower() in ["quit", "exit", "q"]:
|
|
||||||
print("Goodbye!")
|
|
||||||
break
|
|
||||||
|
|
||||||
if not query:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Prepare LLM kwargs with thinking budget if specified
|
|
||||||
llm_kwargs = {}
|
|
||||||
if hasattr(args, "thinking_budget") and args.thinking_budget:
|
|
||||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
|
||||||
|
|
||||||
response = chat.ask(
|
|
||||||
query,
|
|
||||||
top_k=args.top_k,
|
|
||||||
complexity=args.search_complexity,
|
|
||||||
llm_kwargs=llm_kwargs,
|
|
||||||
)
|
|
||||||
print(f"\nAssistant: {response}\n")
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print("\nGoodbye!")
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error: {e}")
|
|
||||||
|
|
||||||
async def run_single_query(self, args, index_path: str, query: str):
|
|
||||||
"""Run a single query against the index."""
|
|
||||||
chat = LeannChat(
|
|
||||||
index_path,
|
|
||||||
llm_config=self.get_llm_config(args),
|
|
||||||
complexity=args.search_complexity,
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"\n[Query]: \033[36m{query}\033[0m")
|
|
||||||
|
|
||||||
# Prepare LLM kwargs with thinking budget if specified
|
|
||||||
llm_kwargs = {}
|
|
||||||
if hasattr(args, "thinking_budget") and args.thinking_budget:
|
|
||||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
|
||||||
|
|
||||||
response = chat.ask(
|
|
||||||
query, top_k=args.top_k, complexity=args.search_complexity, llm_kwargs=llm_kwargs
|
|
||||||
)
|
|
||||||
print(f"\n[Response]: \033[36m{response}\033[0m")
|
|
||||||
|
|
||||||
async def run(self):
|
|
||||||
"""Main entry point for the example."""
|
|
||||||
args = self.parser.parse_args()
|
|
||||||
|
|
||||||
# Check if index exists
|
|
||||||
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
|
|
||||||
index_exists = Path(args.index_dir).exists()
|
|
||||||
|
|
||||||
if not index_exists or args.force_rebuild:
|
|
||||||
# Load data and build index
|
|
||||||
print(f"\n{'Rebuilding' if index_exists else 'Building'} index...")
|
|
||||||
texts = await self.load_data(args)
|
|
||||||
|
|
||||||
if not texts:
|
|
||||||
print("No data found to index!")
|
|
||||||
return
|
|
||||||
|
|
||||||
index_path = await self.build_index(args, texts)
|
|
||||||
else:
|
|
||||||
print(f"\nUsing existing index in {args.index_dir}")
|
|
||||||
|
|
||||||
# Run query or interactive mode
|
|
||||||
if args.query:
|
|
||||||
await self.run_single_query(args, index_path, args.query)
|
|
||||||
else:
|
|
||||||
await self.run_interactive_chat(args, index_path)
|
|
||||||
@@ -1,171 +0,0 @@
|
|||||||
"""
|
|
||||||
Browser History RAG example using the unified interface.
|
|
||||||
Supports Chrome browser history.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# Add parent directory to path for imports
|
|
||||||
sys.path.insert(0, str(Path(__file__).parent))
|
|
||||||
|
|
||||||
from base_rag_example import BaseRAGExample
|
|
||||||
from chunking import create_text_chunks
|
|
||||||
|
|
||||||
from .history_data.history import ChromeHistoryReader
|
|
||||||
|
|
||||||
|
|
||||||
class BrowserRAG(BaseRAGExample):
|
|
||||||
"""RAG example for Chrome browser history."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
# Set default values BEFORE calling super().__init__
|
|
||||||
self.embedding_model_default = (
|
|
||||||
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__init__(
|
|
||||||
name="Browser History",
|
|
||||||
description="Process and query Chrome browser history with LEANN",
|
|
||||||
default_index_name="google_history_index",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _add_specific_arguments(self, parser):
|
|
||||||
"""Add browser-specific arguments."""
|
|
||||||
browser_group = parser.add_argument_group("Browser Parameters")
|
|
||||||
browser_group.add_argument(
|
|
||||||
"--chrome-profile",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="Path to Chrome profile directory (auto-detected if not specified)",
|
|
||||||
)
|
|
||||||
browser_group.add_argument(
|
|
||||||
"--auto-find-profiles",
|
|
||||||
action="store_true",
|
|
||||||
default=True,
|
|
||||||
help="Automatically find all Chrome profiles (default: True)",
|
|
||||||
)
|
|
||||||
browser_group.add_argument(
|
|
||||||
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
|
|
||||||
)
|
|
||||||
browser_group.add_argument(
|
|
||||||
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_chrome_base_path(self) -> Path:
|
|
||||||
"""Get the base Chrome profile path based on OS."""
|
|
||||||
if sys.platform == "darwin":
|
|
||||||
return Path.home() / "Library" / "Application Support" / "Google" / "Chrome"
|
|
||||||
elif sys.platform.startswith("linux"):
|
|
||||||
return Path.home() / ".config" / "google-chrome"
|
|
||||||
elif sys.platform == "win32":
|
|
||||||
return Path(os.environ["LOCALAPPDATA"]) / "Google" / "Chrome" / "User Data"
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported platform: {sys.platform}")
|
|
||||||
|
|
||||||
def _find_chrome_profiles(self) -> list[Path]:
|
|
||||||
"""Auto-detect all Chrome profiles."""
|
|
||||||
base_path = self._get_chrome_base_path()
|
|
||||||
if not base_path.exists():
|
|
||||||
return []
|
|
||||||
|
|
||||||
profiles = []
|
|
||||||
|
|
||||||
# Check Default profile
|
|
||||||
default_profile = base_path / "Default"
|
|
||||||
if default_profile.exists() and (default_profile / "History").exists():
|
|
||||||
profiles.append(default_profile)
|
|
||||||
|
|
||||||
# Check numbered profiles
|
|
||||||
for item in base_path.iterdir():
|
|
||||||
if item.is_dir() and item.name.startswith("Profile "):
|
|
||||||
if (item / "History").exists():
|
|
||||||
profiles.append(item)
|
|
||||||
|
|
||||||
return profiles
|
|
||||||
|
|
||||||
async def load_data(self, args) -> list[str]:
|
|
||||||
"""Load browser history and convert to text chunks."""
|
|
||||||
# Determine Chrome profiles
|
|
||||||
if args.chrome_profile and not args.auto_find_profiles:
|
|
||||||
profile_dirs = [Path(args.chrome_profile)]
|
|
||||||
else:
|
|
||||||
print("Auto-detecting Chrome profiles...")
|
|
||||||
profile_dirs = self._find_chrome_profiles()
|
|
||||||
|
|
||||||
# If specific profile given, filter to just that one
|
|
||||||
if args.chrome_profile:
|
|
||||||
profile_path = Path(args.chrome_profile)
|
|
||||||
profile_dirs = [p for p in profile_dirs if p == profile_path]
|
|
||||||
|
|
||||||
if not profile_dirs:
|
|
||||||
print("No Chrome profiles found!")
|
|
||||||
print("Please specify --chrome-profile manually")
|
|
||||||
return []
|
|
||||||
|
|
||||||
print(f"Found {len(profile_dirs)} Chrome profiles")
|
|
||||||
|
|
||||||
# Create reader
|
|
||||||
reader = ChromeHistoryReader()
|
|
||||||
|
|
||||||
# Process each profile
|
|
||||||
all_documents = []
|
|
||||||
total_processed = 0
|
|
||||||
|
|
||||||
for i, profile_dir in enumerate(profile_dirs):
|
|
||||||
print(f"\nProcessing profile {i + 1}/{len(profile_dirs)}: {profile_dir.name}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Apply max_items limit per profile
|
|
||||||
max_per_profile = -1
|
|
||||||
if args.max_items > 0:
|
|
||||||
remaining = args.max_items - total_processed
|
|
||||||
if remaining <= 0:
|
|
||||||
break
|
|
||||||
max_per_profile = remaining
|
|
||||||
|
|
||||||
# Load history
|
|
||||||
documents = reader.load_data(
|
|
||||||
chrome_profile_path=str(profile_dir),
|
|
||||||
max_count=max_per_profile,
|
|
||||||
)
|
|
||||||
|
|
||||||
if documents:
|
|
||||||
all_documents.extend(documents)
|
|
||||||
total_processed += len(documents)
|
|
||||||
print(f"Processed {len(documents)} history entries from this profile")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error processing {profile_dir}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not all_documents:
|
|
||||||
print("No browser history found to process!")
|
|
||||||
return []
|
|
||||||
|
|
||||||
print(f"\nTotal history entries processed: {len(all_documents)}")
|
|
||||||
|
|
||||||
# Convert to text chunks
|
|
||||||
all_texts = create_text_chunks(
|
|
||||||
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
|
||||||
)
|
|
||||||
|
|
||||||
return all_texts
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
# Example queries for browser history RAG
|
|
||||||
print("\n🌐 Browser History RAG Example")
|
|
||||||
print("=" * 50)
|
|
||||||
print("\nExample queries you can try:")
|
|
||||||
print("- 'What websites did I visit about machine learning?'")
|
|
||||||
print("- 'Find my search history about programming'")
|
|
||||||
print("- 'What YouTube videos did I watch recently?'")
|
|
||||||
print("- 'Show me websites about travel planning'")
|
|
||||||
print("\nNote: Make sure Chrome is closed before running\n")
|
|
||||||
|
|
||||||
rag = BrowserRAG()
|
|
||||||
asyncio.run(rag.run())
|
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
"""
|
|
||||||
Chunking utilities for LEANN RAG applications.
|
|
||||||
Provides AST-aware and traditional text chunking functionality.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from .utils import (
|
|
||||||
CODE_EXTENSIONS,
|
|
||||||
create_ast_chunks,
|
|
||||||
create_text_chunks,
|
|
||||||
create_traditional_chunks,
|
|
||||||
detect_code_files,
|
|
||||||
get_language_from_extension,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"CODE_EXTENSIONS",
|
|
||||||
"create_ast_chunks",
|
|
||||||
"create_text_chunks",
|
|
||||||
"create_traditional_chunks",
|
|
||||||
"detect_code_files",
|
|
||||||
"get_language_from_extension",
|
|
||||||
]
|
|
||||||
@@ -1,320 +0,0 @@
|
|||||||
"""
|
|
||||||
Enhanced chunking utilities with AST-aware code chunking support.
|
|
||||||
Provides unified interface for both traditional and AST-based text chunking.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Code file extensions supported by astchunk
|
|
||||||
CODE_EXTENSIONS = {
|
|
||||||
".py": "python",
|
|
||||||
".java": "java",
|
|
||||||
".cs": "csharp",
|
|
||||||
".ts": "typescript",
|
|
||||||
".tsx": "typescript",
|
|
||||||
".js": "typescript",
|
|
||||||
".jsx": "typescript",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Default chunk parameters for different content types
|
|
||||||
DEFAULT_CHUNK_PARAMS = {
|
|
||||||
"code": {
|
|
||||||
"max_chunk_size": 512,
|
|
||||||
"chunk_overlap": 64,
|
|
||||||
},
|
|
||||||
"text": {
|
|
||||||
"chunk_size": 256,
|
|
||||||
"chunk_overlap": 128,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def detect_code_files(documents, code_extensions=None) -> tuple[list, list]:
|
|
||||||
"""
|
|
||||||
Separate documents into code files and regular text files.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
documents: List of LlamaIndex Document objects
|
|
||||||
code_extensions: Dict mapping file extensions to languages (defaults to CODE_EXTENSIONS)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (code_documents, text_documents)
|
|
||||||
"""
|
|
||||||
if code_extensions is None:
|
|
||||||
code_extensions = CODE_EXTENSIONS
|
|
||||||
|
|
||||||
code_docs = []
|
|
||||||
text_docs = []
|
|
||||||
|
|
||||||
for doc in documents:
|
|
||||||
# Get file path from metadata
|
|
||||||
file_path = doc.metadata.get("file_path", "")
|
|
||||||
if not file_path:
|
|
||||||
# Fallback to file_name
|
|
||||||
file_path = doc.metadata.get("file_name", "")
|
|
||||||
|
|
||||||
if file_path:
|
|
||||||
file_ext = Path(file_path).suffix.lower()
|
|
||||||
if file_ext in code_extensions:
|
|
||||||
# Add language info to metadata
|
|
||||||
doc.metadata["language"] = code_extensions[file_ext]
|
|
||||||
doc.metadata["is_code"] = True
|
|
||||||
code_docs.append(doc)
|
|
||||||
else:
|
|
||||||
doc.metadata["is_code"] = False
|
|
||||||
text_docs.append(doc)
|
|
||||||
else:
|
|
||||||
# If no file path, treat as text
|
|
||||||
doc.metadata["is_code"] = False
|
|
||||||
text_docs.append(doc)
|
|
||||||
|
|
||||||
logger.info(f"Detected {len(code_docs)} code files and {len(text_docs)} text files")
|
|
||||||
return code_docs, text_docs
|
|
||||||
|
|
||||||
|
|
||||||
def get_language_from_extension(file_path: str) -> Optional[str]:
|
|
||||||
"""Get the programming language from file extension."""
|
|
||||||
ext = Path(file_path).suffix.lower()
|
|
||||||
return CODE_EXTENSIONS.get(ext)
|
|
||||||
|
|
||||||
|
|
||||||
def create_ast_chunks(
|
|
||||||
documents,
|
|
||||||
max_chunk_size: int = 512,
|
|
||||||
chunk_overlap: int = 64,
|
|
||||||
metadata_template: str = "default",
|
|
||||||
) -> list[str]:
|
|
||||||
"""
|
|
||||||
Create AST-aware chunks from code documents using astchunk.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
documents: List of code documents
|
|
||||||
max_chunk_size: Maximum characters per chunk
|
|
||||||
chunk_overlap: Number of AST nodes to overlap between chunks
|
|
||||||
metadata_template: Template for chunk metadata
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of text chunks with preserved code structure
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
from astchunk import ASTChunkBuilder
|
|
||||||
except ImportError as e:
|
|
||||||
logger.error(f"astchunk not available: {e}")
|
|
||||||
logger.info("Falling back to traditional chunking for code files")
|
|
||||||
return create_traditional_chunks(documents, max_chunk_size, chunk_overlap)
|
|
||||||
|
|
||||||
all_chunks = []
|
|
||||||
|
|
||||||
for doc in documents:
|
|
||||||
# Get language from metadata (set by detect_code_files)
|
|
||||||
language = doc.metadata.get("language")
|
|
||||||
if not language:
|
|
||||||
logger.warning(
|
|
||||||
"No language detected for document, falling back to traditional chunking"
|
|
||||||
)
|
|
||||||
traditional_chunks = create_traditional_chunks([doc], max_chunk_size, chunk_overlap)
|
|
||||||
all_chunks.extend(traditional_chunks)
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Configure astchunk
|
|
||||||
configs = {
|
|
||||||
"max_chunk_size": max_chunk_size,
|
|
||||||
"language": language,
|
|
||||||
"metadata_template": metadata_template,
|
|
||||||
"chunk_overlap": chunk_overlap if chunk_overlap > 0 else 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Add repository-level metadata if available
|
|
||||||
repo_metadata = {
|
|
||||||
"file_path": doc.metadata.get("file_path", ""),
|
|
||||||
"file_name": doc.metadata.get("file_name", ""),
|
|
||||||
"creation_date": doc.metadata.get("creation_date", ""),
|
|
||||||
"last_modified_date": doc.metadata.get("last_modified_date", ""),
|
|
||||||
}
|
|
||||||
configs["repo_level_metadata"] = repo_metadata
|
|
||||||
|
|
||||||
# Create chunk builder and process
|
|
||||||
chunk_builder = ASTChunkBuilder(**configs)
|
|
||||||
code_content = doc.get_content()
|
|
||||||
|
|
||||||
if not code_content or not code_content.strip():
|
|
||||||
logger.warning("Empty code content, skipping")
|
|
||||||
continue
|
|
||||||
|
|
||||||
chunks = chunk_builder.chunkify(code_content)
|
|
||||||
|
|
||||||
# Extract text content from chunks
|
|
||||||
for chunk in chunks:
|
|
||||||
if hasattr(chunk, "text"):
|
|
||||||
chunk_text = chunk.text
|
|
||||||
elif isinstance(chunk, dict) and "text" in chunk:
|
|
||||||
chunk_text = chunk["text"]
|
|
||||||
elif isinstance(chunk, str):
|
|
||||||
chunk_text = chunk
|
|
||||||
else:
|
|
||||||
# Try to convert to string
|
|
||||||
chunk_text = str(chunk)
|
|
||||||
|
|
||||||
if chunk_text and chunk_text.strip():
|
|
||||||
all_chunks.append(chunk_text.strip())
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Created {len(chunks)} AST chunks from {language} file: {doc.metadata.get('file_name', 'unknown')}"
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"AST chunking failed for {language} file: {e}")
|
|
||||||
logger.info("Falling back to traditional chunking")
|
|
||||||
traditional_chunks = create_traditional_chunks([doc], max_chunk_size, chunk_overlap)
|
|
||||||
all_chunks.extend(traditional_chunks)
|
|
||||||
|
|
||||||
return all_chunks
|
|
||||||
|
|
||||||
|
|
||||||
def create_traditional_chunks(
|
|
||||||
documents, chunk_size: int = 256, chunk_overlap: int = 128
|
|
||||||
) -> list[str]:
|
|
||||||
"""
|
|
||||||
Create traditional text chunks using LlamaIndex SentenceSplitter.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
documents: List of documents to chunk
|
|
||||||
chunk_size: Size of each chunk in characters
|
|
||||||
chunk_overlap: Overlap between chunks
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of text chunks
|
|
||||||
"""
|
|
||||||
# Handle invalid chunk_size values
|
|
||||||
if chunk_size <= 0:
|
|
||||||
logger.warning(f"Invalid chunk_size={chunk_size}, using default value of 256")
|
|
||||||
chunk_size = 256
|
|
||||||
|
|
||||||
# Ensure chunk_overlap is not negative and not larger than chunk_size
|
|
||||||
if chunk_overlap < 0:
|
|
||||||
chunk_overlap = 0
|
|
||||||
if chunk_overlap >= chunk_size:
|
|
||||||
chunk_overlap = chunk_size // 2
|
|
||||||
|
|
||||||
node_parser = SentenceSplitter(
|
|
||||||
chunk_size=chunk_size,
|
|
||||||
chunk_overlap=chunk_overlap,
|
|
||||||
separator=" ",
|
|
||||||
paragraph_separator="\n\n",
|
|
||||||
)
|
|
||||||
|
|
||||||
all_texts = []
|
|
||||||
for doc in documents:
|
|
||||||
try:
|
|
||||||
nodes = node_parser.get_nodes_from_documents([doc])
|
|
||||||
if nodes:
|
|
||||||
chunk_texts = [node.get_content() for node in nodes]
|
|
||||||
all_texts.extend(chunk_texts)
|
|
||||||
logger.debug(f"Created {len(chunk_texts)} traditional chunks from document")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Traditional chunking failed for document: {e}")
|
|
||||||
# As last resort, add the raw content
|
|
||||||
content = doc.get_content()
|
|
||||||
if content and content.strip():
|
|
||||||
all_texts.append(content.strip())
|
|
||||||
|
|
||||||
return all_texts
|
|
||||||
|
|
||||||
|
|
||||||
def create_text_chunks(
|
|
||||||
documents,
|
|
||||||
chunk_size: int = 256,
|
|
||||||
chunk_overlap: int = 128,
|
|
||||||
use_ast_chunking: bool = False,
|
|
||||||
ast_chunk_size: int = 512,
|
|
||||||
ast_chunk_overlap: int = 64,
|
|
||||||
code_file_extensions: Optional[list[str]] = None,
|
|
||||||
ast_fallback_traditional: bool = True,
|
|
||||||
) -> list[str]:
|
|
||||||
"""
|
|
||||||
Create text chunks from documents with optional AST support for code files.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
documents: List of LlamaIndex Document objects
|
|
||||||
chunk_size: Size for traditional text chunks
|
|
||||||
chunk_overlap: Overlap for traditional text chunks
|
|
||||||
use_ast_chunking: Whether to use AST chunking for code files
|
|
||||||
ast_chunk_size: Size for AST chunks
|
|
||||||
ast_chunk_overlap: Overlap for AST chunks
|
|
||||||
code_file_extensions: Custom list of code file extensions
|
|
||||||
ast_fallback_traditional: Fall back to traditional chunking on AST errors
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of text chunks
|
|
||||||
"""
|
|
||||||
if not documents:
|
|
||||||
logger.warning("No documents provided for chunking")
|
|
||||||
return []
|
|
||||||
|
|
||||||
# Create a local copy of supported extensions for this function call
|
|
||||||
local_code_extensions = CODE_EXTENSIONS.copy()
|
|
||||||
|
|
||||||
# Update supported extensions if provided
|
|
||||||
if code_file_extensions:
|
|
||||||
# Map extensions to languages (simplified mapping)
|
|
||||||
ext_mapping = {
|
|
||||||
".py": "python",
|
|
||||||
".java": "java",
|
|
||||||
".cs": "c_sharp",
|
|
||||||
".ts": "typescript",
|
|
||||||
".tsx": "typescript",
|
|
||||||
}
|
|
||||||
for ext in code_file_extensions:
|
|
||||||
if ext.lower() not in local_code_extensions:
|
|
||||||
# Try to guess language from extension
|
|
||||||
if ext.lower() in ext_mapping:
|
|
||||||
local_code_extensions[ext.lower()] = ext_mapping[ext.lower()]
|
|
||||||
else:
|
|
||||||
logger.warning(f"Unsupported extension {ext}, will use traditional chunking")
|
|
||||||
|
|
||||||
all_chunks = []
|
|
||||||
|
|
||||||
if use_ast_chunking:
|
|
||||||
# Separate code and text documents using local extensions
|
|
||||||
code_docs, text_docs = detect_code_files(documents, local_code_extensions)
|
|
||||||
|
|
||||||
# Process code files with AST chunking
|
|
||||||
if code_docs:
|
|
||||||
logger.info(f"Processing {len(code_docs)} code files with AST chunking")
|
|
||||||
try:
|
|
||||||
ast_chunks = create_ast_chunks(
|
|
||||||
code_docs, max_chunk_size=ast_chunk_size, chunk_overlap=ast_chunk_overlap
|
|
||||||
)
|
|
||||||
all_chunks.extend(ast_chunks)
|
|
||||||
logger.info(f"Created {len(ast_chunks)} AST chunks from code files")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"AST chunking failed: {e}")
|
|
||||||
if ast_fallback_traditional:
|
|
||||||
logger.info("Falling back to traditional chunking for code files")
|
|
||||||
traditional_code_chunks = create_traditional_chunks(
|
|
||||||
code_docs, chunk_size, chunk_overlap
|
|
||||||
)
|
|
||||||
all_chunks.extend(traditional_code_chunks)
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
|
|
||||||
# Process text files with traditional chunking
|
|
||||||
if text_docs:
|
|
||||||
logger.info(f"Processing {len(text_docs)} text files with traditional chunking")
|
|
||||||
text_chunks = create_traditional_chunks(text_docs, chunk_size, chunk_overlap)
|
|
||||||
all_chunks.extend(text_chunks)
|
|
||||||
logger.info(f"Created {len(text_chunks)} traditional chunks from text files")
|
|
||||||
else:
|
|
||||||
# Use traditional chunking for all files
|
|
||||||
logger.info(f"Processing {len(documents)} documents with traditional chunking")
|
|
||||||
all_chunks = create_traditional_chunks(documents, chunk_size, chunk_overlap)
|
|
||||||
|
|
||||||
logger.info(f"Total chunks created: {len(all_chunks)}")
|
|
||||||
return all_chunks
|
|
||||||
211
apps/code_rag.py
211
apps/code_rag.py
@@ -1,211 +0,0 @@
|
|||||||
"""
|
|
||||||
Code RAG example using AST-aware chunking for optimal code understanding.
|
|
||||||
Specialized for code repositories with automatic language detection and
|
|
||||||
optimized chunking parameters.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# Add parent directory to path for imports
|
|
||||||
sys.path.insert(0, str(Path(__file__).parent))
|
|
||||||
|
|
||||||
from base_rag_example import BaseRAGExample
|
|
||||||
from chunking import CODE_EXTENSIONS, create_text_chunks
|
|
||||||
from llama_index.core import SimpleDirectoryReader
|
|
||||||
|
|
||||||
|
|
||||||
class CodeRAG(BaseRAGExample):
|
|
||||||
"""Specialized RAG example for code repositories with AST-aware chunking."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
name="Code",
|
|
||||||
description="Process and query code repositories with AST-aware chunking",
|
|
||||||
default_index_name="code_index",
|
|
||||||
)
|
|
||||||
# Override defaults for code-specific usage
|
|
||||||
self.embedding_model_default = "facebook/contriever" # Good for code
|
|
||||||
self.max_items_default = -1 # Process all code files by default
|
|
||||||
|
|
||||||
def _add_specific_arguments(self, parser):
|
|
||||||
"""Add code-specific arguments."""
|
|
||||||
code_group = parser.add_argument_group("Code Repository Parameters")
|
|
||||||
|
|
||||||
code_group.add_argument(
|
|
||||||
"--repo-dir",
|
|
||||||
type=str,
|
|
||||||
default=".",
|
|
||||||
help="Code repository directory to index (default: current directory)",
|
|
||||||
)
|
|
||||||
code_group.add_argument(
|
|
||||||
"--include-extensions",
|
|
||||||
nargs="+",
|
|
||||||
default=list(CODE_EXTENSIONS.keys()),
|
|
||||||
help="File extensions to include (default: supported code extensions)",
|
|
||||||
)
|
|
||||||
code_group.add_argument(
|
|
||||||
"--exclude-dirs",
|
|
||||||
nargs="+",
|
|
||||||
default=[
|
|
||||||
".git",
|
|
||||||
"__pycache__",
|
|
||||||
"node_modules",
|
|
||||||
"venv",
|
|
||||||
".venv",
|
|
||||||
"build",
|
|
||||||
"dist",
|
|
||||||
"target",
|
|
||||||
],
|
|
||||||
help="Directories to exclude from indexing",
|
|
||||||
)
|
|
||||||
code_group.add_argument(
|
|
||||||
"--max-file-size",
|
|
||||||
type=int,
|
|
||||||
default=1000000, # 1MB
|
|
||||||
help="Maximum file size in bytes to process (default: 1MB)",
|
|
||||||
)
|
|
||||||
code_group.add_argument(
|
|
||||||
"--include-comments",
|
|
||||||
action="store_true",
|
|
||||||
help="Include comments in chunking (useful for documentation)",
|
|
||||||
)
|
|
||||||
code_group.add_argument(
|
|
||||||
"--preserve-imports",
|
|
||||||
action="store_true",
|
|
||||||
default=True,
|
|
||||||
help="Try to preserve import statements in chunks (default: True)",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def load_data(self, args) -> list[str]:
|
|
||||||
"""Load code files and convert to AST-aware chunks."""
|
|
||||||
print(f"🔍 Scanning code repository: {args.repo_dir}")
|
|
||||||
print(f"📁 Including extensions: {args.include_extensions}")
|
|
||||||
print(f"🚫 Excluding directories: {args.exclude_dirs}")
|
|
||||||
|
|
||||||
# Check if repository directory exists
|
|
||||||
repo_path = Path(args.repo_dir)
|
|
||||||
if not repo_path.exists():
|
|
||||||
raise ValueError(f"Repository directory not found: {args.repo_dir}")
|
|
||||||
|
|
||||||
# Load code files with filtering
|
|
||||||
reader_kwargs = {
|
|
||||||
"recursive": True,
|
|
||||||
"encoding": "utf-8",
|
|
||||||
"required_exts": args.include_extensions,
|
|
||||||
"exclude_hidden": True,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Create exclusion filter
|
|
||||||
def file_filter(file_path: str) -> bool:
|
|
||||||
"""Filter out unwanted files and directories."""
|
|
||||||
path = Path(file_path)
|
|
||||||
|
|
||||||
# Check file size
|
|
||||||
try:
|
|
||||||
if path.stat().st_size > args.max_file_size:
|
|
||||||
print(f"⚠️ Skipping large file: {path.name} ({path.stat().st_size} bytes)")
|
|
||||||
return False
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Check if in excluded directory
|
|
||||||
for exclude_dir in args.exclude_dirs:
|
|
||||||
if exclude_dir in path.parts:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Load documents with file filtering
|
|
||||||
documents = SimpleDirectoryReader(
|
|
||||||
args.repo_dir,
|
|
||||||
file_extractor=None, # Use default extractors
|
|
||||||
**reader_kwargs,
|
|
||||||
).load_data(show_progress=True)
|
|
||||||
|
|
||||||
# Apply custom filtering
|
|
||||||
filtered_docs = []
|
|
||||||
for doc in documents:
|
|
||||||
file_path = doc.metadata.get("file_path", "")
|
|
||||||
if file_filter(file_path):
|
|
||||||
filtered_docs.append(doc)
|
|
||||||
|
|
||||||
documents = filtered_docs
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ Error loading code files: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
if not documents:
|
|
||||||
print(
|
|
||||||
f"❌ No code files found in {args.repo_dir} with extensions {args.include_extensions}"
|
|
||||||
)
|
|
||||||
return []
|
|
||||||
|
|
||||||
print(f"✅ Loaded {len(documents)} code files")
|
|
||||||
|
|
||||||
# Show breakdown by language/extension
|
|
||||||
ext_counts = {}
|
|
||||||
for doc in documents:
|
|
||||||
file_path = doc.metadata.get("file_path", "")
|
|
||||||
if file_path:
|
|
||||||
ext = Path(file_path).suffix.lower()
|
|
||||||
ext_counts[ext] = ext_counts.get(ext, 0) + 1
|
|
||||||
|
|
||||||
print("📊 Files by extension:")
|
|
||||||
for ext, count in sorted(ext_counts.items()):
|
|
||||||
print(f" {ext}: {count} files")
|
|
||||||
|
|
||||||
# Use AST-aware chunking by default for code
|
|
||||||
print(
|
|
||||||
f"🧠 Using AST-aware chunking (chunk_size: {args.ast_chunk_size}, overlap: {args.ast_chunk_overlap})"
|
|
||||||
)
|
|
||||||
|
|
||||||
all_texts = create_text_chunks(
|
|
||||||
documents,
|
|
||||||
chunk_size=256, # Fallback for non-code files
|
|
||||||
chunk_overlap=64,
|
|
||||||
use_ast_chunking=True, # Always use AST for code RAG
|
|
||||||
ast_chunk_size=args.ast_chunk_size,
|
|
||||||
ast_chunk_overlap=args.ast_chunk_overlap,
|
|
||||||
code_file_extensions=args.include_extensions,
|
|
||||||
ast_fallback_traditional=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Apply max_items limit if specified
|
|
||||||
if args.max_items > 0 and len(all_texts) > args.max_items:
|
|
||||||
print(f"⏳ Limiting to {args.max_items} chunks (from {len(all_texts)})")
|
|
||||||
all_texts = all_texts[: args.max_items]
|
|
||||||
|
|
||||||
print(f"✅ Generated {len(all_texts)} code chunks")
|
|
||||||
return all_texts
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
# Example queries for code RAG
|
|
||||||
print("\n💻 Code RAG Example")
|
|
||||||
print("=" * 50)
|
|
||||||
print("\nExample queries you can try:")
|
|
||||||
print("- 'How does the embedding computation work?'")
|
|
||||||
print("- 'What are the main classes in this codebase?'")
|
|
||||||
print("- 'Show me the search implementation'")
|
|
||||||
print("- 'How is error handling implemented?'")
|
|
||||||
print("- 'What design patterns are used?'")
|
|
||||||
print("- 'Explain the chunking logic'")
|
|
||||||
print("\n🚀 Features:")
|
|
||||||
print("- ✅ AST-aware chunking preserves code structure")
|
|
||||||
print("- ✅ Automatic language detection")
|
|
||||||
print("- ✅ Smart filtering of large files and common excludes")
|
|
||||||
print("- ✅ Optimized for code understanding")
|
|
||||||
print("\nUsage examples:")
|
|
||||||
print(" python -m apps.code_rag --repo-dir ./my_project")
|
|
||||||
print(
|
|
||||||
" python -m apps.code_rag --include-extensions .py .js --query 'How does authentication work?'"
|
|
||||||
)
|
|
||||||
print("\nOr run without --query for interactive mode\n")
|
|
||||||
|
|
||||||
rag = CodeRAG()
|
|
||||||
asyncio.run(rag.run())
|
|
||||||
@@ -1,131 +0,0 @@
|
|||||||
"""
|
|
||||||
Document RAG example using the unified interface.
|
|
||||||
Supports PDF, TXT, MD, and other document formats.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# Add parent directory to path for imports
|
|
||||||
sys.path.insert(0, str(Path(__file__).parent))
|
|
||||||
|
|
||||||
from base_rag_example import BaseRAGExample
|
|
||||||
from chunking import create_text_chunks
|
|
||||||
from llama_index.core import SimpleDirectoryReader
|
|
||||||
|
|
||||||
|
|
||||||
class DocumentRAG(BaseRAGExample):
|
|
||||||
"""RAG example for document processing (PDF, TXT, MD, etc.)."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
name="Document",
|
|
||||||
description="Process and query documents (PDF, TXT, MD, etc.) with LEANN",
|
|
||||||
default_index_name="test_doc_files",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _add_specific_arguments(self, parser):
|
|
||||||
"""Add document-specific arguments."""
|
|
||||||
doc_group = parser.add_argument_group("Document Parameters")
|
|
||||||
doc_group.add_argument(
|
|
||||||
"--data-dir",
|
|
||||||
type=str,
|
|
||||||
default="data",
|
|
||||||
help="Directory containing documents to index (default: data)",
|
|
||||||
)
|
|
||||||
doc_group.add_argument(
|
|
||||||
"--file-types",
|
|
||||||
nargs="+",
|
|
||||||
default=None,
|
|
||||||
help="Filter by file types (e.g., .pdf .txt .md). If not specified, all supported types are processed",
|
|
||||||
)
|
|
||||||
doc_group.add_argument(
|
|
||||||
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
|
|
||||||
)
|
|
||||||
doc_group.add_argument(
|
|
||||||
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
|
||||||
)
|
|
||||||
doc_group.add_argument(
|
|
||||||
"--enable-code-chunking",
|
|
||||||
action="store_true",
|
|
||||||
help="Enable AST-aware chunking for code files in the data directory",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def load_data(self, args) -> list[str]:
|
|
||||||
"""Load documents and convert to text chunks."""
|
|
||||||
print(f"Loading documents from: {args.data_dir}")
|
|
||||||
if args.file_types:
|
|
||||||
print(f"Filtering by file types: {args.file_types}")
|
|
||||||
else:
|
|
||||||
print("Processing all supported file types")
|
|
||||||
|
|
||||||
# Check if data directory exists
|
|
||||||
data_path = Path(args.data_dir)
|
|
||||||
if not data_path.exists():
|
|
||||||
raise ValueError(f"Data directory not found: {args.data_dir}")
|
|
||||||
|
|
||||||
# Load documents
|
|
||||||
reader_kwargs = {
|
|
||||||
"recursive": True,
|
|
||||||
"encoding": "utf-8",
|
|
||||||
}
|
|
||||||
if args.file_types:
|
|
||||||
reader_kwargs["required_exts"] = args.file_types
|
|
||||||
|
|
||||||
documents = SimpleDirectoryReader(args.data_dir, **reader_kwargs).load_data(
|
|
||||||
show_progress=True
|
|
||||||
)
|
|
||||||
|
|
||||||
if not documents:
|
|
||||||
print(f"No documents found in {args.data_dir} with extensions {args.file_types}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
print(f"Loaded {len(documents)} documents")
|
|
||||||
|
|
||||||
# Determine chunking strategy
|
|
||||||
use_ast = args.enable_code_chunking or getattr(args, "use_ast_chunking", False)
|
|
||||||
|
|
||||||
if use_ast:
|
|
||||||
print("Using AST-aware chunking for code files")
|
|
||||||
|
|
||||||
# Convert to text chunks with optional AST support
|
|
||||||
all_texts = create_text_chunks(
|
|
||||||
documents,
|
|
||||||
chunk_size=args.chunk_size,
|
|
||||||
chunk_overlap=args.chunk_overlap,
|
|
||||||
use_ast_chunking=use_ast,
|
|
||||||
ast_chunk_size=getattr(args, "ast_chunk_size", 512),
|
|
||||||
ast_chunk_overlap=getattr(args, "ast_chunk_overlap", 64),
|
|
||||||
code_file_extensions=getattr(args, "code_file_extensions", None),
|
|
||||||
ast_fallback_traditional=getattr(args, "ast_fallback_traditional", True),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Apply max_items limit if specified
|
|
||||||
if args.max_items > 0 and len(all_texts) > args.max_items:
|
|
||||||
print(f"Limiting to {args.max_items} chunks (from {len(all_texts)})")
|
|
||||||
all_texts = all_texts[: args.max_items]
|
|
||||||
|
|
||||||
return all_texts
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
# Example queries for document RAG
|
|
||||||
print("\n📄 Document RAG Example")
|
|
||||||
print("=" * 50)
|
|
||||||
print("\nExample queries you can try:")
|
|
||||||
print("- 'What are the main techniques LEANN uses?'")
|
|
||||||
print("- 'What is the technique DLPM?'")
|
|
||||||
print("- 'Who does Elizabeth Bennet marry?'")
|
|
||||||
print(
|
|
||||||
"- 'What is the problem of developing pan gu model Huawei meets? (盘古大模型开发中遇到什么问题?)'"
|
|
||||||
)
|
|
||||||
print("\n🚀 NEW: Code-aware chunking available!")
|
|
||||||
print("- Use --enable-code-chunking to enable AST-aware chunking for code files")
|
|
||||||
print("- Supports Python, Java, C#, TypeScript files")
|
|
||||||
print("- Better semantic understanding of code structure")
|
|
||||||
print("\nOr run without --query for interactive mode\n")
|
|
||||||
|
|
||||||
rag = DocumentRAG()
|
|
||||||
asyncio.run(rag.run())
|
|
||||||
@@ -1,157 +0,0 @@
|
|||||||
"""
|
|
||||||
Email RAG example using the unified interface.
|
|
||||||
Supports Apple Mail on macOS.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# Add parent directory to path for imports
|
|
||||||
sys.path.insert(0, str(Path(__file__).parent))
|
|
||||||
|
|
||||||
from base_rag_example import BaseRAGExample
|
|
||||||
from chunking import create_text_chunks
|
|
||||||
|
|
||||||
from .email_data.LEANN_email_reader import EmlxReader
|
|
||||||
|
|
||||||
|
|
||||||
class EmailRAG(BaseRAGExample):
|
|
||||||
"""RAG example for Apple Mail processing."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
# Set default values BEFORE calling super().__init__
|
|
||||||
self.max_items_default = -1 # Process all emails by default
|
|
||||||
self.embedding_model_default = (
|
|
||||||
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__init__(
|
|
||||||
name="Email",
|
|
||||||
description="Process and query Apple Mail emails with LEANN",
|
|
||||||
default_index_name="mail_index",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _add_specific_arguments(self, parser):
|
|
||||||
"""Add email-specific arguments."""
|
|
||||||
email_group = parser.add_argument_group("Email Parameters")
|
|
||||||
email_group.add_argument(
|
|
||||||
"--mail-path",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="Path to Apple Mail directory (auto-detected if not specified)",
|
|
||||||
)
|
|
||||||
email_group.add_argument(
|
|
||||||
"--include-html", action="store_true", help="Include HTML content in email processing"
|
|
||||||
)
|
|
||||||
email_group.add_argument(
|
|
||||||
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
|
|
||||||
)
|
|
||||||
email_group.add_argument(
|
|
||||||
"--chunk-overlap", type=int, default=25, help="Text chunk overlap (default: 25)"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _find_mail_directories(self) -> list[Path]:
|
|
||||||
"""Auto-detect all Apple Mail directories."""
|
|
||||||
mail_base = Path.home() / "Library" / "Mail"
|
|
||||||
if not mail_base.exists():
|
|
||||||
return []
|
|
||||||
|
|
||||||
# Find all Messages directories
|
|
||||||
messages_dirs = []
|
|
||||||
for item in mail_base.rglob("Messages"):
|
|
||||||
if item.is_dir():
|
|
||||||
messages_dirs.append(item)
|
|
||||||
|
|
||||||
return messages_dirs
|
|
||||||
|
|
||||||
async def load_data(self, args) -> list[str]:
|
|
||||||
"""Load emails and convert to text chunks."""
|
|
||||||
# Determine mail directories
|
|
||||||
if args.mail_path:
|
|
||||||
messages_dirs = [Path(args.mail_path)]
|
|
||||||
else:
|
|
||||||
print("Auto-detecting Apple Mail directories...")
|
|
||||||
messages_dirs = self._find_mail_directories()
|
|
||||||
|
|
||||||
if not messages_dirs:
|
|
||||||
print("No Apple Mail directories found!")
|
|
||||||
print("Please specify --mail-path manually")
|
|
||||||
return []
|
|
||||||
|
|
||||||
print(f"Found {len(messages_dirs)} mail directories")
|
|
||||||
|
|
||||||
# Create reader
|
|
||||||
reader = EmlxReader(include_html=args.include_html)
|
|
||||||
|
|
||||||
# Process each directory
|
|
||||||
all_documents = []
|
|
||||||
total_processed = 0
|
|
||||||
|
|
||||||
for i, messages_dir in enumerate(messages_dirs):
|
|
||||||
print(f"\nProcessing directory {i + 1}/{len(messages_dirs)}: {messages_dir}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Count emlx files
|
|
||||||
emlx_files = list(messages_dir.glob("*.emlx"))
|
|
||||||
print(f"Found {len(emlx_files)} email files")
|
|
||||||
|
|
||||||
# Apply max_items limit per directory
|
|
||||||
max_per_dir = -1 # Default to process all
|
|
||||||
if args.max_items > 0:
|
|
||||||
remaining = args.max_items - total_processed
|
|
||||||
if remaining <= 0:
|
|
||||||
break
|
|
||||||
max_per_dir = remaining
|
|
||||||
# If args.max_items == -1, max_per_dir stays -1 (process all)
|
|
||||||
|
|
||||||
# Load emails - fix the parameter passing
|
|
||||||
documents = reader.load_data(
|
|
||||||
input_dir=str(messages_dir),
|
|
||||||
max_count=max_per_dir,
|
|
||||||
)
|
|
||||||
|
|
||||||
if documents:
|
|
||||||
all_documents.extend(documents)
|
|
||||||
total_processed += len(documents)
|
|
||||||
print(f"Processed {len(documents)} emails from this directory")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error processing {messages_dir}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not all_documents:
|
|
||||||
print("No emails found to process!")
|
|
||||||
return []
|
|
||||||
|
|
||||||
print(f"\nTotal emails processed: {len(all_documents)}")
|
|
||||||
print("now starting to split into text chunks ... take some time")
|
|
||||||
|
|
||||||
# Convert to text chunks
|
|
||||||
# Email reader uses chunk_overlap=25 as in original
|
|
||||||
all_texts = create_text_chunks(
|
|
||||||
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
|
||||||
)
|
|
||||||
|
|
||||||
return all_texts
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
# Check platform
|
|
||||||
if sys.platform != "darwin":
|
|
||||||
print("\n⚠️ Warning: This example is designed for macOS (Apple Mail)")
|
|
||||||
print(" Windows/Linux support coming soon!\n")
|
|
||||||
|
|
||||||
# Example queries for email RAG
|
|
||||||
print("\n📧 Email RAG Example")
|
|
||||||
print("=" * 50)
|
|
||||||
print("\nExample queries you can try:")
|
|
||||||
print("- 'What did my boss say about deadlines?'")
|
|
||||||
print("- 'Find emails about travel expenses'")
|
|
||||||
print("- 'Show me emails from last month about the project'")
|
|
||||||
print("- 'What food did I order from DoorDash?'")
|
|
||||||
print("\nNote: You may need to grant Full Disk Access to your terminal\n")
|
|
||||||
|
|
||||||
rag = EmailRAG()
|
|
||||||
asyncio.run(rag.run())
|
|
||||||
@@ -1,189 +0,0 @@
|
|||||||
"""
|
|
||||||
WeChat History RAG example using the unified interface.
|
|
||||||
Supports WeChat chat history export and search.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# Add parent directory to path for imports
|
|
||||||
sys.path.insert(0, str(Path(__file__).parent))
|
|
||||||
|
|
||||||
from base_rag_example import BaseRAGExample
|
|
||||||
|
|
||||||
from .history_data.wechat_history import WeChatHistoryReader
|
|
||||||
|
|
||||||
|
|
||||||
class WeChatRAG(BaseRAGExample):
|
|
||||||
"""RAG example for WeChat chat history."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
# Set default values BEFORE calling super().__init__
|
|
||||||
self.max_items_default = -1 # Match original default
|
|
||||||
self.embedding_model_default = (
|
|
||||||
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__init__(
|
|
||||||
name="WeChat History",
|
|
||||||
description="Process and query WeChat chat history with LEANN",
|
|
||||||
default_index_name="wechat_history_magic_test_11Debug_new",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _add_specific_arguments(self, parser):
|
|
||||||
"""Add WeChat-specific arguments."""
|
|
||||||
wechat_group = parser.add_argument_group("WeChat Parameters")
|
|
||||||
wechat_group.add_argument(
|
|
||||||
"--export-dir",
|
|
||||||
type=str,
|
|
||||||
default="./wechat_export",
|
|
||||||
help="Directory to store WeChat exports (default: ./wechat_export)",
|
|
||||||
)
|
|
||||||
wechat_group.add_argument(
|
|
||||||
"--force-export",
|
|
||||||
action="store_true",
|
|
||||||
help="Force re-export of WeChat data even if exports exist",
|
|
||||||
)
|
|
||||||
wechat_group.add_argument(
|
|
||||||
"--chunk-size", type=int, default=192, help="Text chunk size (default: 192)"
|
|
||||||
)
|
|
||||||
wechat_group.add_argument(
|
|
||||||
"--chunk-overlap", type=int, default=64, help="Text chunk overlap (default: 64)"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _export_wechat_data(self, export_dir: Path) -> bool:
|
|
||||||
"""Export WeChat data using wechattweak-cli."""
|
|
||||||
print("Exporting WeChat data...")
|
|
||||||
|
|
||||||
# Check if WeChat is running
|
|
||||||
try:
|
|
||||||
result = subprocess.run(["pgrep", "WeChat"], capture_output=True, text=True)
|
|
||||||
if result.returncode != 0:
|
|
||||||
print("WeChat is not running. Please start WeChat first.")
|
|
||||||
return False
|
|
||||||
except Exception:
|
|
||||||
pass # pgrep might not be available on all systems
|
|
||||||
|
|
||||||
# Create export directory
|
|
||||||
export_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# Run export command
|
|
||||||
cmd = ["packages/wechat-exporter/wechattweak-cli", "export", str(export_dir)]
|
|
||||||
|
|
||||||
try:
|
|
||||||
print(f"Running: {' '.join(cmd)}")
|
|
||||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
|
||||||
|
|
||||||
if result.returncode == 0:
|
|
||||||
print("WeChat data exported successfully!")
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
print(f"Export failed: {result.stderr}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
except FileNotFoundError:
|
|
||||||
print("\nError: wechattweak-cli not found!")
|
|
||||||
print("Please install it first:")
|
|
||||||
print(" sudo packages/wechat-exporter/wechattweak-cli install")
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Export error: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def load_data(self, args) -> list[str]:
|
|
||||||
"""Load WeChat history and convert to text chunks."""
|
|
||||||
# Initialize WeChat reader with export capabilities
|
|
||||||
reader = WeChatHistoryReader()
|
|
||||||
|
|
||||||
# Find existing exports or create new ones using the centralized method
|
|
||||||
export_dirs = reader.find_or_export_wechat_data(args.export_dir)
|
|
||||||
if not export_dirs:
|
|
||||||
print("Failed to find or export WeChat data. Trying to find any existing exports...")
|
|
||||||
# Try to find any existing exports in common locations
|
|
||||||
export_dirs = reader.find_wechat_export_dirs()
|
|
||||||
if not export_dirs:
|
|
||||||
print("No WeChat data found. Please ensure WeChat exports exist.")
|
|
||||||
return []
|
|
||||||
|
|
||||||
# Load documents from all found export directories
|
|
||||||
all_documents = []
|
|
||||||
total_processed = 0
|
|
||||||
|
|
||||||
for i, export_dir in enumerate(export_dirs):
|
|
||||||
print(f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Apply max_items limit per export
|
|
||||||
max_per_export = -1
|
|
||||||
if args.max_items > 0:
|
|
||||||
remaining = args.max_items - total_processed
|
|
||||||
if remaining <= 0:
|
|
||||||
break
|
|
||||||
max_per_export = remaining
|
|
||||||
|
|
||||||
documents = reader.load_data(
|
|
||||||
wechat_export_dir=str(export_dir),
|
|
||||||
max_count=max_per_export,
|
|
||||||
concatenate_messages=True, # Enable message concatenation for better context
|
|
||||||
)
|
|
||||||
|
|
||||||
if documents:
|
|
||||||
print(f"Loaded {len(documents)} chat documents from {export_dir}")
|
|
||||||
all_documents.extend(documents)
|
|
||||||
total_processed += len(documents)
|
|
||||||
else:
|
|
||||||
print(f"No documents loaded from {export_dir}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error processing {export_dir}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not all_documents:
|
|
||||||
print("No documents loaded from any source. Exiting.")
|
|
||||||
return []
|
|
||||||
|
|
||||||
print(f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports")
|
|
||||||
print("now starting to split into text chunks ... take some time")
|
|
||||||
|
|
||||||
# Convert to text chunks with contact information
|
|
||||||
all_texts = []
|
|
||||||
for doc in all_documents:
|
|
||||||
# Split the document into chunks
|
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
|
||||||
|
|
||||||
text_splitter = SentenceSplitter(
|
|
||||||
chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
|
||||||
)
|
|
||||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
|
||||||
|
|
||||||
for node in nodes:
|
|
||||||
# Add contact information to each chunk
|
|
||||||
contact_name = doc.metadata.get("contact_name", "Unknown")
|
|
||||||
text = f"[Contact] means the message is from: {contact_name}\n" + node.get_content()
|
|
||||||
all_texts.append(text)
|
|
||||||
|
|
||||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
|
|
||||||
return all_texts
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
# Check platform
|
|
||||||
if sys.platform != "darwin":
|
|
||||||
print("\n⚠️ Warning: WeChat export is only supported on macOS")
|
|
||||||
print(" You can still query existing exports on other platforms\n")
|
|
||||||
|
|
||||||
# Example queries for WeChat RAG
|
|
||||||
print("\n💬 WeChat History RAG Example")
|
|
||||||
print("=" * 50)
|
|
||||||
print("\nExample queries you can try:")
|
|
||||||
print("- 'Show me conversations about travel plans'")
|
|
||||||
print("- 'Find group chats about weekend activities'")
|
|
||||||
print("- '我想买魔术师约翰逊的球衣,给我一些对应聊天记录?'")
|
|
||||||
print("- 'What did we discuss about the project last month?'")
|
|
||||||
print("\nNote: WeChat must be running for export to work\n")
|
|
||||||
|
|
||||||
rag = WeChatRAG()
|
|
||||||
asyncio.run(rag.run())
|
|
||||||
Binary file not shown.
Before Width: | Height: | Size: 73 KiB
Binary file not shown.
Before Width: | Height: | Size: 224 KiB
Binary file not shown.
Before Width: | Height: | Size: 152 KiB
@@ -1,148 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from leann import LeannBuilder, LeannSearcher
|
|
||||||
|
|
||||||
|
|
||||||
def _meta_exists(index_path: str) -> bool:
|
|
||||||
p = Path(index_path)
|
|
||||||
return (p.parent / f"{p.stem}.meta.json").exists()
|
|
||||||
|
|
||||||
|
|
||||||
def ensure_index(index_path: str, backend_name: str, num_docs: int, is_recompute: bool) -> None:
|
|
||||||
# if _meta_exists(index_path):
|
|
||||||
# return
|
|
||||||
kwargs = {}
|
|
||||||
if backend_name == "hnsw":
|
|
||||||
kwargs["is_compact"] = is_recompute
|
|
||||||
builder = LeannBuilder(
|
|
||||||
backend_name=backend_name,
|
|
||||||
embedding_model=os.getenv("LEANN_EMBED_MODEL", "facebook/contriever"),
|
|
||||||
embedding_mode=os.getenv("LEANN_EMBED_MODE", "sentence-transformers"),
|
|
||||||
graph_degree=32,
|
|
||||||
complexity=64,
|
|
||||||
is_recompute=is_recompute,
|
|
||||||
num_threads=4,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
for i in range(num_docs):
|
|
||||||
builder.add_text(
|
|
||||||
f"This is a test document number {i}. It contains some repeated text for benchmarking."
|
|
||||||
)
|
|
||||||
builder.build_index(index_path)
|
|
||||||
|
|
||||||
|
|
||||||
def _bench_group(
|
|
||||||
index_path: str,
|
|
||||||
recompute: bool,
|
|
||||||
query: str,
|
|
||||||
repeats: int,
|
|
||||||
complexity: int = 32,
|
|
||||||
top_k: int = 10,
|
|
||||||
) -> float:
|
|
||||||
# Independent searcher per group; fixed port when recompute
|
|
||||||
searcher = LeannSearcher(index_path=index_path)
|
|
||||||
|
|
||||||
# Warm-up once
|
|
||||||
_ = searcher.search(
|
|
||||||
query,
|
|
||||||
top_k=top_k,
|
|
||||||
complexity=complexity,
|
|
||||||
recompute_embeddings=recompute,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _once() -> float:
|
|
||||||
t0 = time.time()
|
|
||||||
_ = searcher.search(
|
|
||||||
query,
|
|
||||||
top_k=top_k,
|
|
||||||
complexity=complexity,
|
|
||||||
recompute_embeddings=recompute,
|
|
||||||
)
|
|
||||||
return time.time() - t0
|
|
||||||
|
|
||||||
if repeats <= 1:
|
|
||||||
t = _once()
|
|
||||||
else:
|
|
||||||
vals = [_once() for _ in range(repeats)]
|
|
||||||
vals.sort()
|
|
||||||
t = vals[len(vals) // 2]
|
|
||||||
|
|
||||||
searcher.cleanup()
|
|
||||||
return t
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument("--num-docs", type=int, default=5000)
|
|
||||||
parser.add_argument("--repeats", type=int, default=3)
|
|
||||||
parser.add_argument("--complexity", type=int, default=32)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
base = Path.cwd() / ".leann" / "indexes" / f"bench_n{args.num_docs}"
|
|
||||||
base.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
# ---------- Build HNSW variants ----------
|
|
||||||
hnsw_r = str(base / f"hnsw_recompute_n{args.num_docs}.leann")
|
|
||||||
hnsw_nr = str(base / f"hnsw_norecompute_n{args.num_docs}.leann")
|
|
||||||
ensure_index(hnsw_r, "hnsw", args.num_docs, True)
|
|
||||||
ensure_index(hnsw_nr, "hnsw", args.num_docs, False)
|
|
||||||
|
|
||||||
# ---------- Build DiskANN variants ----------
|
|
||||||
diskann_r = str(base / "diskann_r.leann")
|
|
||||||
diskann_nr = str(base / "diskann_nr.leann")
|
|
||||||
ensure_index(diskann_r, "diskann", args.num_docs, True)
|
|
||||||
ensure_index(diskann_nr, "diskann", args.num_docs, False)
|
|
||||||
|
|
||||||
# ---------- Helpers ----------
|
|
||||||
def _size_for(prefix: str) -> int:
|
|
||||||
p = Path(prefix)
|
|
||||||
base_dir = p.parent
|
|
||||||
stem = p.stem
|
|
||||||
total = 0
|
|
||||||
for f in base_dir.iterdir():
|
|
||||||
if f.is_file() and f.name.startswith(stem):
|
|
||||||
total += f.stat().st_size
|
|
||||||
return total
|
|
||||||
|
|
||||||
# ---------- HNSW benchmark ----------
|
|
||||||
t_hnsw_r = _bench_group(
|
|
||||||
hnsw_r, True, "test document number 42", repeats=args.repeats, complexity=args.complexity
|
|
||||||
)
|
|
||||||
t_hnsw_nr = _bench_group(
|
|
||||||
hnsw_nr, False, "test document number 42", repeats=args.repeats, complexity=args.complexity
|
|
||||||
)
|
|
||||||
size_hnsw_r = _size_for(hnsw_r)
|
|
||||||
size_hnsw_nr = _size_for(hnsw_nr)
|
|
||||||
|
|
||||||
print("Benchmark results (HNSW):")
|
|
||||||
print(f" recompute=True: search_time={t_hnsw_r:.3f}s, size={size_hnsw_r / 1024 / 1024:.1f}MB")
|
|
||||||
print(
|
|
||||||
f" recompute=False: search_time={t_hnsw_nr:.3f}s, size={size_hnsw_nr / 1024 / 1024:.1f}MB"
|
|
||||||
)
|
|
||||||
print(" Expectation: no-recompute should be faster but larger on disk.")
|
|
||||||
|
|
||||||
# ---------- DiskANN benchmark ----------
|
|
||||||
t_diskann_r = _bench_group(
|
|
||||||
diskann_r, True, "DiskANN R test doc 123", repeats=args.repeats, complexity=args.complexity
|
|
||||||
)
|
|
||||||
t_diskann_nr = _bench_group(
|
|
||||||
diskann_nr,
|
|
||||||
False,
|
|
||||||
"DiskANN NR test doc 123",
|
|
||||||
repeats=args.repeats,
|
|
||||||
complexity=args.complexity,
|
|
||||||
)
|
|
||||||
size_diskann_r = _size_for(diskann_r)
|
|
||||||
size_diskann_nr = _size_for(diskann_nr)
|
|
||||||
|
|
||||||
print("\nBenchmark results (DiskANN):")
|
|
||||||
print(f" build(recompute=True, partition): size={size_diskann_r / 1024 / 1024:.1f}MB")
|
|
||||||
print(f" build(recompute=False): size={size_diskann_nr / 1024 / 1024:.1f}MB")
|
|
||||||
print(f" search recompute=True (final rerank): {t_diskann_r:.3f}s")
|
|
||||||
print(f" search recompute=False (PQ only): {t_diskann_nr:.3f}s")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,286 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
DiskANN vs HNSW Search Performance Comparison
|
|
||||||
|
|
||||||
This benchmark compares search performance between DiskANN and HNSW backends:
|
|
||||||
- DiskANN: With graph partitioning enabled (is_recompute=True)
|
|
||||||
- HNSW: With recompute enabled (is_recompute=True)
|
|
||||||
- Tests performance across different dataset sizes
|
|
||||||
- Measures search latency, recall, and index size
|
|
||||||
"""
|
|
||||||
|
|
||||||
import gc
|
|
||||||
import multiprocessing as mp
|
|
||||||
import tempfile
|
|
||||||
import time
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
# Prefer 'fork' start method to avoid POSIX semaphore leaks on macOS
|
|
||||||
try:
|
|
||||||
mp.set_start_method("fork", force=True)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def create_test_texts(n_docs: int) -> list[str]:
|
|
||||||
"""Create synthetic test documents for benchmarking."""
|
|
||||||
np.random.seed(42)
|
|
||||||
topics = [
|
|
||||||
"machine learning and artificial intelligence",
|
|
||||||
"natural language processing and text analysis",
|
|
||||||
"computer vision and image recognition",
|
|
||||||
"data science and statistical analysis",
|
|
||||||
"deep learning and neural networks",
|
|
||||||
"information retrieval and search engines",
|
|
||||||
"database systems and data management",
|
|
||||||
"software engineering and programming",
|
|
||||||
"cybersecurity and network protection",
|
|
||||||
"cloud computing and distributed systems",
|
|
||||||
]
|
|
||||||
|
|
||||||
texts = []
|
|
||||||
for i in range(n_docs):
|
|
||||||
topic = topics[i % len(topics)]
|
|
||||||
variation = np.random.randint(1, 100)
|
|
||||||
text = (
|
|
||||||
f"This is document {i} about {topic}. Content variation {variation}. "
|
|
||||||
f"Additional information about {topic} with details and examples. "
|
|
||||||
f"Technical discussion of {topic} including implementation aspects."
|
|
||||||
)
|
|
||||||
texts.append(text)
|
|
||||||
|
|
||||||
return texts
|
|
||||||
|
|
||||||
|
|
||||||
def benchmark_backend(
|
|
||||||
backend_name: str, texts: list[str], test_queries: list[str], backend_kwargs: dict[str, Any]
|
|
||||||
) -> dict[str, float]:
|
|
||||||
"""Benchmark a specific backend with the given configuration."""
|
|
||||||
from leann.api import LeannBuilder, LeannSearcher
|
|
||||||
|
|
||||||
print(f"\n🔧 Testing {backend_name.upper()} backend...")
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as temp_dir:
|
|
||||||
index_path = str(Path(temp_dir) / f"benchmark_{backend_name}.leann")
|
|
||||||
|
|
||||||
# Build index
|
|
||||||
print(f"📦 Building {backend_name} index with {len(texts)} documents...")
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
builder = LeannBuilder(
|
|
||||||
backend_name=backend_name,
|
|
||||||
embedding_model="facebook/contriever",
|
|
||||||
embedding_mode="sentence-transformers",
|
|
||||||
**backend_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
for text in texts:
|
|
||||||
builder.add_text(text)
|
|
||||||
|
|
||||||
builder.build_index(index_path)
|
|
||||||
build_time = time.time() - start_time
|
|
||||||
|
|
||||||
# Measure index size
|
|
||||||
index_dir = Path(index_path).parent
|
|
||||||
index_files = list(index_dir.glob(f"{Path(index_path).stem}.*"))
|
|
||||||
total_size = sum(f.stat().st_size for f in index_files if f.is_file())
|
|
||||||
size_mb = total_size / (1024 * 1024)
|
|
||||||
|
|
||||||
print(f" ✅ Build completed in {build_time:.2f}s, index size: {size_mb:.1f}MB")
|
|
||||||
|
|
||||||
# Search benchmark
|
|
||||||
print("🔍 Running search benchmark...")
|
|
||||||
searcher = LeannSearcher(index_path)
|
|
||||||
|
|
||||||
search_times = []
|
|
||||||
all_results = []
|
|
||||||
|
|
||||||
for query in test_queries:
|
|
||||||
start_time = time.time()
|
|
||||||
results = searcher.search(query, top_k=5)
|
|
||||||
search_time = time.time() - start_time
|
|
||||||
search_times.append(search_time)
|
|
||||||
all_results.append(results)
|
|
||||||
|
|
||||||
avg_search_time = np.mean(search_times) * 1000 # Convert to ms
|
|
||||||
print(f" ✅ Average search time: {avg_search_time:.1f}ms")
|
|
||||||
|
|
||||||
# Check for valid scores (detect -inf issues)
|
|
||||||
all_scores = [
|
|
||||||
result.score
|
|
||||||
for results in all_results
|
|
||||||
for result in results
|
|
||||||
if result.score is not None
|
|
||||||
]
|
|
||||||
valid_scores = [
|
|
||||||
score for score in all_scores if score != float("-inf") and score != float("inf")
|
|
||||||
]
|
|
||||||
score_validity_rate = len(valid_scores) / len(all_scores) if all_scores else 0
|
|
||||||
|
|
||||||
# Clean up (ensure embedding server shutdown and object GC)
|
|
||||||
try:
|
|
||||||
if hasattr(searcher, "cleanup"):
|
|
||||||
searcher.cleanup()
|
|
||||||
del searcher
|
|
||||||
del builder
|
|
||||||
gc.collect()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"⚠️ Warning: Resource cleanup error: {e}")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"build_time": build_time,
|
|
||||||
"avg_search_time_ms": avg_search_time,
|
|
||||||
"index_size_mb": size_mb,
|
|
||||||
"score_validity_rate": score_validity_rate,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def run_comparison(n_docs: int = 500, n_queries: int = 10):
|
|
||||||
"""Run performance comparison between DiskANN and HNSW."""
|
|
||||||
print("🚀 Starting DiskANN vs HNSW Performance Comparison")
|
|
||||||
print(f"📊 Dataset: {n_docs} documents, {n_queries} test queries")
|
|
||||||
|
|
||||||
# Create test data
|
|
||||||
texts = create_test_texts(n_docs)
|
|
||||||
test_queries = [
|
|
||||||
"machine learning algorithms",
|
|
||||||
"natural language processing",
|
|
||||||
"computer vision techniques",
|
|
||||||
"data analysis methods",
|
|
||||||
"neural network architectures",
|
|
||||||
"database query optimization",
|
|
||||||
"software development practices",
|
|
||||||
"security vulnerabilities",
|
|
||||||
"cloud infrastructure",
|
|
||||||
"distributed computing",
|
|
||||||
][:n_queries]
|
|
||||||
|
|
||||||
# HNSW benchmark
|
|
||||||
hnsw_results = benchmark_backend(
|
|
||||||
backend_name="hnsw",
|
|
||||||
texts=texts,
|
|
||||||
test_queries=test_queries,
|
|
||||||
backend_kwargs={
|
|
||||||
"is_recompute": True, # Enable recompute for fair comparison
|
|
||||||
"M": 16,
|
|
||||||
"efConstruction": 200,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# DiskANN benchmark
|
|
||||||
diskann_results = benchmark_backend(
|
|
||||||
backend_name="diskann",
|
|
||||||
texts=texts,
|
|
||||||
test_queries=test_queries,
|
|
||||||
backend_kwargs={
|
|
||||||
"is_recompute": True, # Enable graph partitioning
|
|
||||||
"num_neighbors": 32,
|
|
||||||
"search_list_size": 50,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Performance comparison
|
|
||||||
print("\n📈 Performance Comparison Results")
|
|
||||||
print(f"{'=' * 60}")
|
|
||||||
print(f"{'Metric':<25} {'HNSW':<15} {'DiskANN':<15} {'Speedup':<10}")
|
|
||||||
print(f"{'-' * 60}")
|
|
||||||
|
|
||||||
# Build time comparison
|
|
||||||
build_speedup = hnsw_results["build_time"] / diskann_results["build_time"]
|
|
||||||
print(
|
|
||||||
f"{'Build Time (s)':<25} {hnsw_results['build_time']:<15.2f} {diskann_results['build_time']:<15.2f} {build_speedup:<10.2f}x"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Search time comparison
|
|
||||||
search_speedup = hnsw_results["avg_search_time_ms"] / diskann_results["avg_search_time_ms"]
|
|
||||||
print(
|
|
||||||
f"{'Search Time (ms)':<25} {hnsw_results['avg_search_time_ms']:<15.1f} {diskann_results['avg_search_time_ms']:<15.1f} {search_speedup:<10.2f}x"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Index size comparison
|
|
||||||
size_ratio = diskann_results["index_size_mb"] / hnsw_results["index_size_mb"]
|
|
||||||
print(
|
|
||||||
f"{'Index Size (MB)':<25} {hnsw_results['index_size_mb']:<15.1f} {diskann_results['index_size_mb']:<15.1f} {size_ratio:<10.2f}x"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Score validity
|
|
||||||
print(
|
|
||||||
f"{'Score Validity (%)':<25} {hnsw_results['score_validity_rate'] * 100:<15.1f} {diskann_results['score_validity_rate'] * 100:<15.1f}"
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"{'=' * 60}")
|
|
||||||
print("\n🎯 Summary:")
|
|
||||||
if search_speedup > 1:
|
|
||||||
print(f" DiskANN is {search_speedup:.2f}x faster than HNSW for search")
|
|
||||||
else:
|
|
||||||
print(f" HNSW is {1 / search_speedup:.2f}x faster than DiskANN for search")
|
|
||||||
|
|
||||||
if size_ratio > 1:
|
|
||||||
print(f" DiskANN uses {size_ratio:.2f}x more storage than HNSW")
|
|
||||||
else:
|
|
||||||
print(f" DiskANN uses {1 / size_ratio:.2f}x less storage than HNSW")
|
|
||||||
|
|
||||||
print(
|
|
||||||
f" Both backends achieved {min(hnsw_results['score_validity_rate'], diskann_results['score_validity_rate']) * 100:.1f}% score validity"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import sys
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Handle help request
|
|
||||||
if len(sys.argv) > 1 and sys.argv[1] in ["-h", "--help", "help"]:
|
|
||||||
print("DiskANN vs HNSW Performance Comparison")
|
|
||||||
print("=" * 50)
|
|
||||||
print(f"Usage: python {sys.argv[0]} [n_docs] [n_queries]")
|
|
||||||
print()
|
|
||||||
print("Arguments:")
|
|
||||||
print(" n_docs Number of documents to index (default: 500)")
|
|
||||||
print(" n_queries Number of test queries to run (default: 10)")
|
|
||||||
print()
|
|
||||||
print("Examples:")
|
|
||||||
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py")
|
|
||||||
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 1000")
|
|
||||||
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20")
|
|
||||||
sys.exit(0)
|
|
||||||
|
|
||||||
# Parse command line arguments
|
|
||||||
n_docs = int(sys.argv[1]) if len(sys.argv) > 1 else 500
|
|
||||||
n_queries = int(sys.argv[2]) if len(sys.argv) > 2 else 10
|
|
||||||
|
|
||||||
print("DiskANN vs HNSW Performance Comparison")
|
|
||||||
print("=" * 50)
|
|
||||||
print(f"Dataset: {n_docs} documents, {n_queries} queries")
|
|
||||||
print()
|
|
||||||
|
|
||||||
run_comparison(n_docs=n_docs, n_queries=n_queries)
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print("\n⚠️ Benchmark interrupted by user")
|
|
||||||
sys.exit(130)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"\n❌ Benchmark failed: {e}")
|
|
||||||
sys.exit(1)
|
|
||||||
finally:
|
|
||||||
# Ensure clean exit (forceful to prevent rare hangs from atexit/threads)
|
|
||||||
try:
|
|
||||||
gc.collect()
|
|
||||||
print("\n🧹 Cleanup completed")
|
|
||||||
# Flush stdio to ensure message is visible before hard-exit
|
|
||||||
try:
|
|
||||||
import sys as _sys
|
|
||||||
|
|
||||||
_sys.stdout.flush()
|
|
||||||
_sys.stderr.flush()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
# Use os._exit to bypass atexit handlers that may hang in rare cases
|
|
||||||
import os as _os
|
|
||||||
|
|
||||||
_os._exit(0)
|
|
||||||
82 data/.gitattributes vendored Normal file
@@ -0,0 +1,82 @@
*.7z filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.arrow filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.bin filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.ftz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.gz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.h5 filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.joblib filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.lz4 filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.mds filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.model filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.npy filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.npz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.onnx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.ot filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.parquet filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pb filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pickle filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pkl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pt filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pth filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.rar filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||||
|
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tar filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tflite filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tgz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.wasm filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.xz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.zip filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.zst filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
# Audio files - uncompressed
|
||||||
|
*.pcm filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.sam filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.raw filter=lfs diff=lfs merge=lfs -text
|
||||||
|
# Audio files - compressed
|
||||||
|
*.aac filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.flac filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.mp3 filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.ogg filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.wav filter=lfs diff=lfs merge=lfs -text
|
||||||
|
# Image files - uncompressed
|
||||||
|
*.bmp filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.gif filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.png filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tiff filter=lfs diff=lfs merge=lfs -text
|
||||||
|
# Image files - compressed
|
||||||
|
*.jpg filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.webp filter=lfs diff=lfs merge=lfs -text
|
||||||
|
# Video files - compressed
|
||||||
|
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.webm filter=lfs diff=lfs merge=lfs -text
|
||||||
|
ground_truth/dpr/id_map.json filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/dpr/dpr_diskann.passages.idx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/dpr/dpr_diskann.passages.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/dpr/dpr_diskann_disk.index filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/dpr/leann.labels.map filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/leann.labels.map filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.index filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.0.idx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.0.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.1.idx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.1.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.2.idx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.2.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.3.idx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.3.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.4.idx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.4.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.5.idx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.5.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.6.idx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.6.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.7.idx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
indices/rpj_wiki/rpj_wiki.passages.7.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||||
0 benchmarks/data/README.md → data/README.md (Executable file → Normal file)
@@ -1,123 +0,0 @@
# Thinking Budget Feature Implementation

## Overview

This document describes the implementation of the **thinking budget** feature for LEANN, which allows users to control the computational effort for reasoning models like GPT-Oss:20b.

## Feature Description

The thinking budget feature provides three levels of computational effort for reasoning models:
- **`low`**: Fast responses, basic reasoning (default for simple queries)
- **`medium`**: Balanced speed and reasoning depth
- **`high`**: Maximum reasoning effort, best for complex analytical questions

## Implementation Details

### 1. Command Line Interface

Added the `--thinking-budget` parameter to both the CLI and the RAG examples:

```bash
# LEANN CLI
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget high

# RAG Examples
python apps/email_rag.py --llm ollama --llm-model gpt-oss:20b --thinking-budget high
python apps/document_rag.py --llm openai --llm-model o3 --thinking-budget medium
```

### 2. LLM Backend Support

#### Ollama Backend (`packages/leann-core/src/leann/chat.py`)

```python
def ask(self, prompt: str, **kwargs) -> str:
    # Handle thinking budget for reasoning models
    options = kwargs.copy()
    thinking_budget = kwargs.get("thinking_budget")
    if thinking_budget:
        options.pop("thinking_budget", None)
        if thinking_budget in ["low", "medium", "high"]:
            options["reasoning"] = {"effort": thinking_budget, "exclude": False}
```

**API Format**: Uses Ollama's `reasoning` parameter with `effort` and `exclude` fields.

#### OpenAI Backend (`packages/leann-core/src/leann/chat.py`)

```python
def ask(self, prompt: str, **kwargs) -> str:
    # Handle thinking budget for reasoning models
    thinking_budget = kwargs.get("thinking_budget")
    if thinking_budget and thinking_budget in ["low", "medium", "high"]:
        # Check if this is an o-series model
        o_series_models = ["o3", "o3-mini", "o4-mini", "o1", "o3-pro", "o3-deep-research"]
        if any(model in self.model for model in o_series_models):
            params["reasoning_effort"] = thinking_budget
```

**API Format**: Uses OpenAI's `reasoning_effort` parameter for o-series models.

### 3. Parameter Propagation

The thinking budget parameter is propagated through the LEANN architecture:

1. **CLI** (`packages/leann-core/src/leann/cli.py`): Captures the `--thinking-budget` argument
2. **Base RAG** (`apps/base_rag_example.py`): Adds the parameter to the argument parser
3. **LeannChat** (`packages/leann-core/src/leann/api.py`): Passes `llm_kwargs` to the LLM
4. **LLM Interface**: Handles the parameter in backend-specific implementations
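
To make this flow concrete, here is a minimal sketch of the translation step both backends perform, written as a standalone helper. It is illustrative only: `map_thinking_budget` is a hypothetical function, not part of the LEANN codebase, and it simply restates the backend-specific mapping shown in the snippets above.

```python
def map_thinking_budget(backend: str, model: str, thinking_budget: str | None) -> dict:
    """Illustrative helper (not LEANN code): translate --thinking-budget into
    the request parameters each backend expects, per the snippets above."""
    if thinking_budget not in ("low", "medium", "high"):
        return {}
    if backend == "ollama":
        # Ollama: reasoning = {"effort": ..., "exclude": False}
        return {"reasoning": {"effort": thinking_budget, "exclude": False}}
    if backend == "openai":
        o_series = ("o1", "o3", "o3-mini", "o3-pro", "o4-mini")
        if any(m in model for m in o_series):
            # OpenAI o-series: reasoning_effort = low | medium | high
            return {"reasoning_effort": thinking_budget}
    return {}


print(map_thinking_budget("ollama", "gpt-oss:20b", "high"))
print(map_thinking_budget("openai", "o3-mini", "medium"))
```
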
## Files Modified

### Core Implementation
- `packages/leann-core/src/leann/chat.py`: Added thinking budget support to OllamaChat and OpenAIChat
- `packages/leann-core/src/leann/cli.py`: Added `--thinking-budget` argument
- `apps/base_rag_example.py`: Added thinking budget parameter to RAG examples

### Documentation
- `README.md`: Added thinking budget parameter to usage examples
- `docs/configuration-guide.md`: Added detailed documentation and usage guidelines

### Examples
- `examples/thinking_budget_demo.py`: Comprehensive demo script with usage examples

## Usage Examples

### Basic Usage
```bash
# High reasoning effort for complex questions
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget high

# Medium reasoning for balanced performance
leann ask my-index --llm openai --model gpt-4o --thinking-budget medium

# Low reasoning for fast responses
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget low
```

### RAG Examples
```bash
# Email RAG with high reasoning
python apps/email_rag.py --llm ollama --llm-model gpt-oss:20b --thinking-budget high

# Document RAG with medium reasoning
python apps/document_rag.py --llm openai --llm-model gpt-4o --thinking-budget medium
```

## Supported Models

### Ollama Models
- **GPT-Oss:20b**: Primary target model with reasoning capabilities
- **Other reasoning models**: Any Ollama model that supports the `reasoning` parameter

### OpenAI Models
- **o3, o3-mini, o4-mini, o1**: o-series reasoning models with the `reasoning_effort` parameter
- **GPT-OSS models**: Models that support reasoning capabilities

## Testing

The implementation includes comprehensive testing:
- Parameter handling verification
- Backend-specific API format validation
- CLI argument parsing tests
- Integration with existing LEANN architecture
@@ -1,128 +0,0 @@
# AST-Aware Code Chunking Guide

## Overview

This guide covers best practices for using AST-aware code chunking in LEANN. AST chunking provides a better semantic understanding of code structure than traditional text-based chunking.

## Quick Start

### Basic Usage

```bash
# Enable AST chunking for mixed content (code + docs)
python -m apps.document_rag --enable-code-chunking --data-dir ./my_project

# Specialized code repository indexing
python -m apps.code_rag --repo-dir ./my_codebase

# Global CLI with AST support
leann build my-code-index --docs ./src --use-ast-chunking
```

### Installation

```bash
# Install LEANN with AST chunking support
uv pip install -e "."
```

## Best Practices

### When to Use AST Chunking

✅ **Recommended for:**
- Code repositories with multiple languages
- Mixed documentation and code content
- Complex codebases with deep function/class hierarchies
- When working with Claude Code for code assistance

❌ **Not recommended for:**
- Pure text documents
- Very large files (>1MB)
- Languages not supported by tree-sitter

### Optimal Configuration

```bash
# Recommended settings for most codebases
python -m apps.code_rag \
    --repo-dir ./src \
    --ast-chunk-size 768 \
    --ast-chunk-overlap 96 \
    --exclude-dirs .git __pycache__ node_modules build dist
```

### Supported Languages

| Extension | Language | Status |
|-----------|----------|--------|
| `.py` | Python | ✅ Full support |
| `.java` | Java | ✅ Full support |
| `.cs` | C# | ✅ Full support |
| `.ts`, `.tsx` | TypeScript | ✅ Full support |
| `.js`, `.jsx` | JavaScript | ✅ Via TypeScript parser |
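
To illustrate the idea behind AST-aware chunking independently of the `astchunk` library LEANN builds on, the toy sketch below splits Python source at top-level function and class boundaries using only the standard `ast` module. It is an example of "chunk at semantic boundaries", not LEANN's implementation; the `chunk_python_source` helper and its 768-character default are assumptions made for illustration.

```python
import ast


def chunk_python_source(source: str, max_chars: int = 768) -> list[str]:
    """Toy AST-aware chunker: one piece per top-level def/class,
    greedily packed into chunks of at most max_chars characters."""
    tree = ast.parse(source)
    lines = source.splitlines()
    pieces = []
    for node in tree.body:
        start = node.lineno - 1                       # lineno is 1-indexed
        end = getattr(node, "end_lineno", node.lineno)
        pieces.append("\n".join(lines[start:end]))

    chunks, current = [], ""
    for piece in pieces:
        if current and len(current) + len(piece) > max_chars:
            chunks.append(current)
            current = ""
        current = f"{current}\n\n{piece}" if current else piece
    if current:
        chunks.append(current)
    return chunks


if __name__ == "__main__":
    sample = "def add(a, b):\n    return a + b\n\nclass Greeter:\n    def hi(self):\n        return 'hi'\n"
    for i, chunk in enumerate(chunk_python_source(sample, max_chars=60)):
        print(f"--- chunk {i} ---\n{chunk}")
```
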
## Integration Examples

### Document RAG with Code Support

```bash
# Enable code chunking in document RAG
python -m apps.document_rag \
    --enable-code-chunking \
    --data-dir ./project \
    --query "How does authentication work in the codebase?"
```

### Claude Code Integration

When used with the Claude Code MCP server, AST chunking provides better context for:
- Code completion and suggestions
- Bug analysis and debugging
- Architecture understanding
- Refactoring assistance

## Troubleshooting

### Common Issues

1. **Fallback to Traditional Chunking**
   - Normal behavior for unsupported languages
   - Check logs for specific language support

2. **Performance with Large Files**
   - Adjust the `--max-file-size` parameter
   - Use `--exclude-dirs` to skip unnecessary directories

3. **Quality Issues**
   - Try different `--ast-chunk-size` values (512, 768, 1024)
   - Adjust overlap for better context preservation

### Debug Mode

```bash
export LEANN_LOG_LEVEL=DEBUG
python -m apps.code_rag --repo-dir ./my_code
```

## Migration from Traditional Chunking

Existing workflows continue to work without changes. To enable AST chunking:

```bash
# Before
python -m apps.document_rag --chunk-size 256

# After (maintains traditional chunking for non-code files)
python -m apps.document_rag --enable-code-chunking --chunk-size 256 --ast-chunk-size 768
```

## References

- [astchunk GitHub Repository](https://github.com/yilinjz/astchunk)
- [LEANN MCP Integration](../packages/leann-mcp/README.md)
- [Research Paper](https://arxiv.org/html/2506.15655v1)

---

**Note**: AST chunking maintains full backward compatibility while enhancing code understanding capabilities.
@@ -1,384 +0,0 @@
# LEANN Configuration Guide

This guide helps you optimize LEANN for different use cases and understand the trade-offs between the various configuration options.

## Getting Started: Simple is Better

When first trying LEANN, start with a small dataset to quickly validate your approach:

**For document RAG**: The default `data/` directory works perfectly - it includes 2 AI research papers, Pride and Prejudice, and a technical report
```bash
python -m apps.document_rag --query "What techniques does LEANN use?"
```

**For other data sources**: Limit the dataset size for quick testing
```bash
# WeChat: Test with recent messages only
python -m apps.wechat_rag --max-items 100 --query "What did we discuss about the project timeline?"

# Browser history: Last few days
python -m apps.browser_rag --max-items 500 --query "Find documentation about vector databases"

# Email: Recent inbox
python -m apps.email_rag --max-items 200 --query "Who sent updates about the deployment status?"
```

Once validated, scale up gradually:
- 100 documents → 1,000 → 10,000 → full dataset (`--max-items -1`)
- This helps identify issues early, before committing to long processing times

## Embedding Model Selection: Understanding the Trade-offs

Based on our experience developing LEANN, embedding models fall into three categories:

### Small Models (< 100M parameters)
**Example**: `sentence-transformers/all-MiniLM-L6-v2` (22M params)
- **Pros**: Lightweight, fast for both indexing and inference
- **Cons**: Lower semantic understanding, may miss nuanced relationships
- **Use when**: Speed is critical, queries are simple, you are in interactive mode, or you are just experimenting with LEANN. If time is not a constraint, consider a larger/better embedding model

### Medium Models (100M-500M parameters)
**Example**: `facebook/contriever` (110M params), `BAAI/bge-base-en-v1.5` (110M params)
- **Pros**: Balanced performance, good multilingual support, reasonable speed
- **Cons**: Requires more compute than small models
- **Use when**: You need quality results without extreme compute requirements, e.g. general-purpose RAG applications

### Large Models (500M+ parameters)
**Example**: `Qwen/Qwen3-Embedding-0.6B` (600M params), `intfloat/multilingual-e5-large` (560M params)
- **Pros**: Best semantic understanding, captures complex relationships, excellent multilingual support. **Qwen3-Embedding-0.6B achieves nearly OpenAI API performance!**
- **Cons**: Slower inference, longer index build times
- **Use when**: Quality is paramount and you have sufficient compute resources. **Highly recommended** for production use
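
If you want to sanity-check these trade-offs on your own hardware before committing to a full index build, a small timing script along the following lines can help. It is a sketch, not part of LEANN: it assumes the `sentence-transformers` package is installed and that the two model names listed above can be downloaded locally.

```python
import time

from sentence_transformers import SentenceTransformer

# Candidate models referenced above; swap in any others you care about.
candidates = [
    "sentence-transformers/all-MiniLM-L6-v2",  # small, fast
    "Qwen/Qwen3-Embedding-0.6B",               # large, higher quality
]
texts = ["LEANN stores a pruned graph and recomputes embeddings on demand."] * 64

for name in candidates:
    model = SentenceTransformer(name)
    start = time.time()
    embeddings = model.encode(texts, batch_size=16, show_progress_bar=False)
    elapsed = time.time() - start
    print(f"{name}: dim={embeddings.shape[1]}, {len(texts) / elapsed:.1f} texts/sec")
```
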
### Quick Start: Cloud and Local Embedding Options

**OpenAI Embeddings (Fastest Setup)**
For immediate testing without local model downloads (also the right choice if you [do not have a GPU](https://github.com/yichuan-w/LEANN/issues/43) and are not worried about sending your documents to a third party; embeddings are computed and recomputed through the OpenAI API):
```bash
# Set OpenAI embeddings (requires OPENAI_API_KEY)
--embedding-mode openai --embedding-model text-embedding-3-small
```

**Ollama Embeddings (Privacy-Focused)**
For local embeddings with complete privacy:
```bash
# First, pull an embedding model
ollama pull nomic-embed-text

# Use Ollama embeddings
--embedding-mode ollama --embedding-model nomic-embed-text
```

<details>
<summary><strong>Cloud vs Local Trade-offs</strong></summary>

**OpenAI Embeddings** (`text-embedding-3-small/large`)
- **Pros**: No local compute needed, consistently fast, high quality
- **Cons**: Requires an API key, costs money, data leaves your system, [known limitations with certain languages](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
- **When to use**: Prototyping, non-sensitive data, need immediate results

**Local Embeddings**
- **Pros**: Complete privacy, no ongoing costs, full control, can sometimes outperform OpenAI embeddings
- **Cons**: Slower than cloud APIs, requires local compute resources
- **When to use**: Production systems, sensitive data, cost-sensitive applications

</details>

## Index Selection: Matching Your Scale

### HNSW (Hierarchical Navigable Small World)
**Best for**: Small to medium datasets (< 10M vectors) - **default, and recommended for extremely low storage**
- Full recomputation required
- High memory usage during the build phase
- Excellent recall (95%+)

```bash
# Optimal for most use cases
--backend-name hnsw --graph-degree 32 --build-complexity 64
```

### DiskANN
**Best for**: Large datasets, especially when you want `recompute=True`.

**Key advantages:**
- **Faster search** on large datasets (3x+ speedup vs HNSW in many cases)
- **Smart storage**: `recompute=True` enables automatic graph partitioning for smaller indexes
- **Better scaling**: Designed for 100k+ documents

**Recompute behavior:**
- `recompute=True` (recommended): pure PQ traversal + final reranking - faster, and enables partitioning
- `recompute=False`: PQ + partial real distances during traversal - slower but higher accuracy

```bash
# Recommended for most use cases
--backend-name diskann --graph-degree 32 --build-complexity 64
```

**Performance Benchmark**: Run `uv run benchmarks/diskann_vs_hnsw_speed_comparison.py` to compare DiskANN and HNSW on your system.
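
The same backend choice can be made from the Python API. The sketch below mirrors the `LeannBuilder`/`LeannSearcher` usage that appears in this repository's benchmark scripts; treat the exact keyword arguments as indicative rather than a complete reference, and adjust the index path for your setup.

```python
from leann import LeannBuilder, LeannSearcher

# Build a small index with the backend of your choice ("hnsw" or "diskann").
builder = LeannBuilder(
    backend_name="diskann",
    embedding_model="facebook/contriever",
    embedding_mode="sentence-transformers",
    graph_degree=32,
    complexity=64,
    is_recompute=True,  # for DiskANN this enables graph partitioning for smaller indexes
)
for i in range(1000):
    builder.add_text(f"Test document number {i} for backend comparison.")
builder.build_index("./.leann/indexes/demo.leann")

# Query the freshly built index.
searcher = LeannSearcher("./.leann/indexes/demo.leann")
results = searcher.search("document number 42", top_k=5, complexity=32)
print(results)
```
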
## LLM Selection: Engine and Model Comparison

### LLM Engines

**OpenAI** (`--llm openai`)
- **Pros**: Best quality, consistent performance, no local resources needed
- **Cons**: Costs money ($0.15-2.5 per million tokens), requires internet access, data privacy concerns
- **Models**: `gpt-4o-mini` (fast, cheap), `gpt-4o` (best quality), `o3` (reasoning), `o3-mini` (reasoning, cheaper)
- **Thinking Budget**: Use `--thinking-budget low/medium/high` for o-series reasoning models (o3, o3-mini, o4-mini)
- **Note**: Our current default, but we recommend switching to Ollama for most use cases

**Ollama** (`--llm ollama`)
- **Pros**: Fully local, free, privacy-preserving, good model variety
- **Cons**: Requires local GPU/CPU resources and is slower than cloud APIs; you also need to install the separate [ollama app](https://github.com/ollama/ollama?tab=readme-ov-file#ollama) and pre-download models with `ollama pull`
- **Models**: `qwen3:0.6b` (ultra-fast), `qwen3:1.7b` (balanced), `qwen3:4b` (good quality), `qwen3:7b` (high quality), `deepseek-r1:1.5b` (reasoning)
- **Thinking Budget**: Use `--thinking-budget low/medium/high` for reasoning models like GPT-Oss:20b

**HuggingFace** (`--llm hf`)
- **Pros**: Free tier available, huge model selection, direct model loading (vs Ollama's server-based approach)
- **Cons**: More complex initial setup
- **Models**: `Qwen/Qwen3-1.7B-FP8`

## Parameter Tuning Guide

### Search Complexity Parameters

**`--build-complexity`** (index building)
- Controls thoroughness during index construction
- Higher = better recall but slower build
- Recommendations:
  - 32: Quick prototyping
  - 64: Balanced (default)
  - 128: Production systems
  - 256: Maximum quality

**`--search-complexity`** (query time)
- Controls search thoroughness
- Higher = better results but slower
- Recommendations:
  - 16: Fast/interactive search
  - 32: High quality with diversity
  - 64+: Maximum accuracy

### Top-K Selection

**`--top-k`** (number of retrieved chunks)
- More chunks = better context but slower LLM processing
- Should always be smaller than `--search-complexity`
- Guidelines:
  - 10-20: General questions (default: 20)
  - 30+: Complex multi-hop reasoning requiring comprehensive context

**Trade-off formula** (a quick back-of-envelope sketch follows below):
- Retrieval time ∝ log(n) × search_complexity
- LLM processing time ∝ top_k × chunk_size
- Total context = top_k × chunk_size tokens
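
As a quick check of these proportions, the snippet below prints the approximate context budget handed to the LLM for a few `top_k`/chunk-size combinations (plain arithmetic, no LEANN imports):

```python
# Rough context-budget estimate: total context ≈ top_k × chunk_size tokens.
for top_k in (10, 20, 30):
    for chunk_size in (128, 256, 512):
        total_tokens = top_k * chunk_size
        print(f"top_k={top_k:>2}, chunk_size={chunk_size:>3} -> ~{total_tokens} context tokens")
```
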
### Thinking Budget for Reasoning Models
|
|
||||||
|
|
||||||
**`--thinking-budget`** (reasoning effort level)
|
|
||||||
- Controls the computational effort for reasoning models
|
|
||||||
- Options: `low`, `medium`, `high`
|
|
||||||
- Guidelines:
|
|
||||||
- `low`: Fast responses, basic reasoning (default for simple queries)
|
|
||||||
- `medium`: Balanced speed and reasoning depth
|
|
||||||
- `high`: Maximum reasoning effort, best for complex analytical questions
|
|
||||||
- **Supported Models**:
|
|
||||||
- **Ollama**: `gpt-oss:20b`, `gpt-oss:120b`
|
|
||||||
- **OpenAI**: `o3`, `o3-mini`, `o4-mini`, `o1` (o-series reasoning models)
|
|
||||||
- **Note**: Models without reasoning support will show a warning and proceed without reasoning parameters
|
|
||||||
- **Example**: `--thinking-budget high` for complex analytical questions
|
|
||||||
|
|
||||||
**📖 For detailed usage examples and implementation details, check out [Thinking Budget Documentation](THINKING_BUDGET_FEATURE.md)**
|
|
||||||
|
|
||||||
**💡 Quick Examples:**
|
|
||||||
```bash
|
|
||||||
# OpenAI o-series reasoning model
|
|
||||||
python apps/document_rag.py --query "What are the main techniques LEANN explores?" \
|
|
||||||
--index-dir hnswbuild --backend hnsw \
|
|
||||||
--llm openai --llm-model o3 --thinking-budget medium
|
|
||||||
|
|
||||||
# Ollama reasoning model
|
|
||||||
python apps/document_rag.py --query "What are the main techniques LEANN explores?" \
|
|
||||||
--index-dir hnswbuild --backend hnsw \
|
|
||||||
--llm ollama --llm-model gpt-oss:20b --thinking-budget high
|
|
||||||
```
|
|
||||||
|
|
||||||
### Graph Degree (HNSW/DiskANN)
|
|
||||||
|
|
||||||
**`--graph-degree`**
|
|
||||||
- Number of connections per node in the graph
|
|
||||||
- Higher = better recall but more memory
|
|
||||||
- HNSW: 16-32 (default: 32)
|
|
||||||
- DiskANN: 32-128 (default: 64)
|
|
||||||
|
|
||||||
|
|
||||||
## Performance Optimization Checklist
|
|
||||||
|
|
||||||
### If Embedding is Too Slow
|
|
||||||
|
|
||||||
1. **Switch to smaller model**:
|
|
||||||
```bash
|
|
||||||
# From large model
|
|
||||||
--embedding-model Qwen/Qwen3-Embedding-0.6B
|
|
||||||
# To small model
|
|
||||||
--embedding-model sentence-transformers/all-MiniLM-L6-v2
|
|
||||||
```
|
|
||||||
|
|
||||||
2. **Limit dataset size for testing**:
|
|
||||||
```bash
|
|
||||||
--max-items 1000 # Process first 1k items only
|
|
||||||
```
|
|
||||||
|
|
||||||
3. **Use MLX on Apple Silicon** (optional optimization):
|
|
||||||
```bash
|
|
||||||
--embedding-mode mlx --embedding-model mlx-community/Qwen3-Embedding-0.6B-8bit
|
|
||||||
```
|
|
||||||
MLX might not be the best choice, as we tested and found that it only offers 1.3x acceleration compared to HF, so maybe using ollama is a better choice for embedding generation
|
|
||||||
|
|
||||||
4. **Use Ollama**
|
|
||||||
```bash
|
|
||||||
--embedding-mode ollama --embedding-model nomic-embed-text
|
|
||||||
```
|
|
||||||
To discover additional embedding models in ollama, check out https://ollama.com/search?c=embedding or read more about embedding models at https://ollama.com/blog/embedding-models, please do check the model size that works best for you
|
|
||||||
### If Search Quality is Poor
|
|
||||||
|
|
||||||
1. **Increase retrieval count**:
|
|
||||||
```bash
|
|
||||||
--top-k 30 # Retrieve more candidates
|
|
||||||
```
|
|
||||||
|
|
||||||
2. **Upgrade embedding model**:
|
|
||||||
```bash
|
|
||||||
# For English
|
|
||||||
--embedding-model BAAI/bge-base-en-v1.5
|
|
||||||
# For multilingual
|
|
||||||
--embedding-model intfloat/multilingual-e5-large
|
|
||||||
```
|
|
||||||
|
|
||||||
## Understanding the Trade-offs
|
|
||||||
|
|
||||||
Every configuration choice involves trade-offs:
|
|
||||||
|
|
||||||
| Factor | Small/Fast | Large/Quality |
|
|
||||||
|--------|------------|---------------|
|
|
||||||
| Embedding Model | `all-MiniLM-L6-v2` | `Qwen/Qwen3-Embedding-0.6B` |
|
|
||||||
| Chunk Size | 512 tokens | 128 tokens |
|
|
||||||
| Index Type | HNSW | DiskANN |
|
|
||||||
| LLM | `qwen3:1.7b` | `gpt-4o` |
|
|
||||||
|
|
||||||
The key is finding the right balance for your specific use case. Start small and simple, measure performance, then scale up only where needed.
|
|
||||||
|
|
||||||
## Low-resource setups
|
|
||||||
|
|
||||||
If you don’t have a local GPU or builds/searches are too slow, use one or more of the options below.
|
|
||||||
|
|
||||||
### 1) Use OpenAI embeddings (no local compute)
|
|
||||||
|
|
||||||
Fastest path with zero local GPU requirements. Set your API key and use OpenAI embeddings during build and search:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
export OPENAI_API_KEY=sk-...
|
|
||||||
|
|
||||||
# Build with OpenAI embeddings
|
|
||||||
leann build my-index \
|
|
||||||
--embedding-mode openai \
|
|
||||||
--embedding-model text-embedding-3-small
|
|
||||||
|
|
||||||
# Search with OpenAI embeddings (recompute at query time)
|
|
||||||
leann search my-index "your query" \
|
|
||||||
--recompute
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2) Run remote builds with SkyPilot (cloud GPU)
|
|
||||||
|
|
||||||
Offload embedding generation and index building to a GPU VM using [SkyPilot](https://skypilot.readthedocs.io/en/latest/). A template is provided at `sky/leann-build.yaml`.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# One-time: install and configure SkyPilot
|
|
||||||
pip install skypilot
|
|
||||||
|
|
||||||
# Launch with defaults (L4:1) and mount ./data to ~/leann-data; the build runs automatically
|
|
||||||
sky launch -c leann-gpu sky/leann-build.yaml
|
|
||||||
|
|
||||||
# Override parameters via -e key=value (optional)
|
|
||||||
sky launch -c leann-gpu sky/leann-build.yaml \
|
|
||||||
-e index_name=my-index \
|
|
||||||
-e backend=hnsw \
|
|
||||||
-e embedding_mode=sentence-transformers \
|
|
||||||
-e embedding_model=Qwen/Qwen3-Embedding-0.6B
|
|
||||||
|
|
||||||
# Copy the built index back to your local .leann (use rsync)
|
|
||||||
rsync -Pavz leann-gpu:~/.leann/indexes/my-index ./.leann/indexes/
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3) Disable recomputation to trade storage for speed
|
|
||||||
|
|
||||||
If you need lower latency and have more storage/memory, disable recomputation. This stores full embeddings and avoids recomputing at search time.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Build without recomputation (HNSW requires non-compact in this mode)
|
|
||||||
leann build my-index --no-recompute --no-compact
|
|
||||||
|
|
||||||
# Search without recomputation
|
|
||||||
leann search my-index "your query" --no-recompute
|
|
||||||
```
|
|
||||||
|
|
||||||
When to use:
|
|
||||||
- Extreme low latency requirements (high QPS, interactive assistants)
|
|
||||||
- Read-heavy workloads where storage is cheaper than latency
|
|
||||||
- No always-available GPU
|
|
||||||
|
|
||||||
Constraints:
|
|
||||||
- HNSW: when `--no-recompute` is set, LEANN automatically disables compact mode during build
|
|
||||||
- DiskANN: supported; `--no-recompute` skips selective recompute during search
|
|
||||||
|
|
||||||
Storage impact:

- Storing N embeddings of dimension D in float32 requires approximately N × D × 4 bytes
- Example: 1,000,000 chunks × 768 dims × 4 bytes ≈ 3.07 GB (2.86 GiB), plus graph/metadata

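As a quick sanity check, the arithmetic above is easy to reproduce (a minimal sketch; the chunk count and dimension are just the example values):

```python
# Back-of-the-envelope storage estimate for stored float32 embeddings.
num_chunks = 1_000_000  # example corpus size
dim = 768               # embedding dimension
total_bytes = num_chunks * dim * 4  # float32 = 4 bytes per value

print(f"{total_bytes / 1e9:.2f} GB ({total_bytes / 2**30:.2f} GiB)")
# -> 3.07 GB (2.86 GiB), excluding graph and metadata
```
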
Converting an existing index (rebuild required):

```bash
# Rebuild in place (ensure you still have the original docs or can regenerate chunks)
leann build my-index --force --no-recompute --no-compact
```

Python API usage:

```python
from leann import LeannSearcher

searcher = LeannSearcher("/path/to/my-index.leann")
results = searcher.search("your query", top_k=10, recompute_embeddings=False)
```

Trade-offs:

- Lower latency and fewer network hops at query time
- Significantly higher storage (10–100× vs. selective recomputation)
- Slightly larger memory footprint during build and search

Quick benchmark results (`benchmarks/benchmark_no_recompute.py` with 5k texts, complexity=32):

- HNSW

  ```text
  recompute=True:  search_time=0.818s, size=1.1MB
  recompute=False: search_time=0.012s, size=16.6MB
  ```

- DiskANN

  ```text
  recompute=True:  search_time=0.041s, size=5.9MB
  recompute=False: search_time=0.013s, size=24.6MB
  ```

Conclusion:

- **HNSW**: `no-recompute` is significantly faster (no embedding recomputation) but requires much more storage, since all embeddings are stored
- **DiskANN**: `no-recompute` uses PQ plus partial real distances during traversal (slower but more accurate), while `recompute=True` uses pure PQ traversal with final reranking (faster traversal, and enables build-time partitioning for smaller storage)

## Further Reading

- [Lessons Learned Developing LEANN](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
- [LEANN Technical Paper](https://arxiv.org/abs/2506.08276)
- [DiskANN Original Paper](https://papers.nips.cc/paper/2019/file/09853c7fb1d3f8ee67a61b6bf4a7f8e6-Paper.pdf)
- [SSD-based Graph Partitioning](https://github.com/SonglinLife/SSD_BASED_PLAN)

@@ -3,10 +3,9 @@
 ## 🔥 Core Features

 - **🔄 Real-time Embeddings** - Eliminate heavy embedding storage with dynamic computation using optimized ZMQ servers and highly optimized search paradigm (overlapping and batching) with highly optimized embedding engine
-- **🧠 AST-Aware Code Chunking** - Intelligent code chunking that preserves semantic boundaries (functions, classes, methods) for Python, Java, C#, and TypeScript files
 - **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
 - **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
-- **🏗️ Pluggable Backends** - HNSW/FAISS (default), with optional DiskANN for large-scale deployments
+- **🏗️ Pluggable Backends** - DiskANN, HNSW/FAISS with unified API

 ## 🛠️ Technical Highlights
 - **🔄 Recompute Mode** - Highest accuracy scenarios while eliminating vector storage overhead
@@ -14,7 +13,7 @@
 - **🚀 High-throughput Embedding Pipeline** - Optimized batched processing for maximum efficiency
 - **🎯 Two-level Search** - Novel coarse-to-fine search overlap for accelerated query processing (optional)
 - **💾 Memory-mapped Indices** - Fast startup with raw text mapping to reduce memory overhead
-- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](../examples/mlx_demo.py))
+- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](test/build_mlx_index.py))

 ## 🎨 Developer Experience
@@ -1,300 +0,0 @@

# LEANN Metadata Filtering Usage Guide

## Overview

LEANN provides metadata filtering that lets you filter search results on arbitrary metadata fields set during chunking. This enables use cases such as spoiler-free book search, filtering documents by date or type, code search by file type, and much more.

## Basic Usage

### Adding Metadata to Your Documents

When building your index, add metadata to each text chunk:

```python
from leann.api import LeannBuilder

builder = LeannBuilder("hnsw")

# Add text with metadata
builder.add_text(
    text="Chapter 1: Alice falls down the rabbit hole",
    metadata={
        "chapter": 1,
        "character": "Alice",
        "themes": ["adventure", "curiosity"],
        "word_count": 150,
    },
)

builder.build_index("alice_in_wonderland_index")
```

### Searching with Metadata Filters

Use the `metadata_filters` parameter in search calls:

```python
from leann.api import LeannSearcher

searcher = LeannSearcher("alice_in_wonderland_index")

# Search with filters
results = searcher.search(
    query="What happens to Alice?",
    top_k=10,
    metadata_filters={
        "chapter": {"<=": 5},             # Only chapters 1-5
        "spoiler_level": {"!=": "high"},  # No high spoilers
    },
)
```

## Filter Syntax

### Basic Structure

```python
metadata_filters = {
    "field_name": {"operator": value},
    "another_field": {"operator": value},
}
```

### Supported Operators

#### Comparison Operators
- `"=="`: Equal to
- `"!="`: Not equal to
- `"<"`: Less than
- `"<="`: Less than or equal
- `">"`: Greater than
- `">="`: Greater than or equal

```python
# Examples
{"chapter": {"==": 1}}      # Exactly chapter 1
{"page": {">": 100}}        # Pages after 100
{"rating": {">=": 4.0}}     # Rating 4.0 or higher
{"word_count": {"<": 500}}  # Short passages
```

#### Membership Operators
- `"in"`: Value is in list
- `"not_in"`: Value is not in list

```python
# Examples
{"character": {"in": ["Alice", "Bob"]}}        # Alice OR Bob
{"genre": {"not_in": ["horror", "thriller"]}}  # Exclude genres
{"tags": {"in": ["fiction", "adventure"]}}     # Any of these tags
```

#### String Operators
- `"contains"`: String contains substring
- `"starts_with"`: String starts with prefix
- `"ends_with"`: String ends with suffix

```python
# Examples
{"title": {"contains": "alice"}}    # Title contains "alice"
{"filename": {"ends_with": ".py"}}  # Python files
{"author": {"starts_with": "Dr."}}  # Authors with "Dr." prefix
```

#### Boolean Operators
- `"is_true"`: Field is truthy
- `"is_false"`: Field is falsy

```python
# Examples
{"is_published": {"is_true": True}}  # Published content
{"is_draft": {"is_false": False}}    # Not drafts
```

### Multiple Operators on Same Field

You can apply multiple operators to the same field (AND logic):

```python
metadata_filters = {
    "word_count": {
        ">=": 100,  # At least 100 words
        "<=": 500,  # At most 500 words
    }
}
```

### Compound Filters

Multiple fields are combined with AND logic:

```python
metadata_filters = {
    "chapter": {"<=": 10},            # Up to chapter 10
    "character": {"==": "Alice"},     # About Alice
    "spoiler_level": {"!=": "high"},  # No major spoilers
}
```

## Use Case Examples

### 1. Spoiler-Free Book Search

```python
# Reader has only read up to chapter 5
def search_spoiler_free(query, max_chapter):
    return searcher.search(
        query=query,
        metadata_filters={
            "chapter": {"<=": max_chapter},
            "spoiler_level": {"in": ["none", "low"]},
        },
    )

results = search_spoiler_free("What happens to Alice?", max_chapter=5)
```

### 2. Document Management by Date

```python
# Find recent documents
recent_docs = searcher.search(
    query="project updates",
    metadata_filters={
        "date": {">=": "2024-01-01"},
        "document_type": {"==": "report"},
    },
)
```

### 3. Code Search by File Type

```python
# Search only Python files
python_code = searcher.search(
    query="authentication function",
    metadata_filters={
        "file_extension": {"==": ".py"},
        "lines_of_code": {"<": 100},
    },
)
```

### 4. Content Filtering by Audience

```python
# Age-appropriate content
family_content = searcher.search(
    query="adventure stories",
    metadata_filters={
        "age_rating": {"in": ["G", "PG"]},
        "content_warnings": {"not_in": ["violence", "adult_themes"]},
    },
)
```

### 5. Multi-Book Series Management

```python
# Search across first 3 books only
early_series = searcher.search(
    query="character development",
    metadata_filters={
        "series": {"==": "Harry Potter"},
        "book_number": {"<=": 3},
    },
)
```

## Running the Example

You can see metadata filtering in action with our spoiler-free book RAG example:

```bash
# Don't forget to set up the environment
uv venv
source .venv/bin/activate

# Set your OpenAI API key (required for embeddings, but you can update the example locally and use ollama instead)
export OPENAI_API_KEY="your-api-key-here"

# Run the spoiler-free book RAG example
uv run examples/spoiler_free_book_rag.py
```

This example demonstrates:
- Building an index with metadata (chapter numbers, characters, themes, locations)
- Searching with filters to avoid spoilers (e.g., only show results up to chapter 5)
- Different scenarios for readers at various points in the book

The example uses Alice's Adventures in Wonderland as sample data and shows how you can search for information without revealing plot points from later chapters.

## Advanced Patterns

### Custom Chunking with Metadata

```python
def chunk_book_with_metadata(book_text, book_info):
    # parse_chapters, extract_characters, classify_themes, assess_spoiler_level,
    # split_paragraphs, and calculate_reading_level are user-defined helpers.
    chunks = []

    for chapter_num, chapter_text in parse_chapters(book_text):
        # Extract entities, themes, etc.
        characters = extract_characters(chapter_text)
        themes = classify_themes(chapter_text)
        spoiler_level = assess_spoiler_level(chapter_text, chapter_num)

        # Create chunks with rich metadata
        for paragraph in split_paragraphs(chapter_text):
            chunks.append({
                "text": paragraph,
                "metadata": {
                    "book_title": book_info["title"],
                    "chapter": chapter_num,
                    "characters": characters,
                    "themes": themes,
                    "spoiler_level": spoiler_level,
                    "word_count": len(paragraph.split()),
                    "reading_level": calculate_reading_level(paragraph),
                },
            })

    return chunks
```

## Performance Considerations

### Efficient Filtering Strategies

1. **Post-search filtering**: Filters are applied after the vector search, which is efficient for typical result sets (10-100 results); a minimal sketch of this client-side approach follows below.

2. **Metadata design**: Keep metadata fields simple and avoid deeply nested structures.

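The sketch below shows one way such post-search filtering can work. It illustrates the operator semantics assumed from the tables above, not the library's internal implementation, and the example results are made up:

```python
# Hypothetical post-search filter evaluation (AND across fields and operators).
OPS = {
    "==": lambda v, x: v == x,
    "!=": lambda v, x: v != x,
    "<": lambda v, x: v < x,
    "<=": lambda v, x: v <= x,
    ">": lambda v, x: v > x,
    ">=": lambda v, x: v >= x,
    "in": lambda v, x: v in x,
    "not_in": lambda v, x: v not in x,
    "contains": lambda v, x: x in str(v),
    "starts_with": lambda v, x: str(v).startswith(x),
    "ends_with": lambda v, x: str(v).endswith(x),
    "is_true": lambda v, x: bool(v),
    "is_false": lambda v, x: not bool(v),
}

def matches(metadata: dict, filters: dict) -> bool:
    """Return True only if every field/operator condition passes."""
    for field, conditions in filters.items():
        value = metadata.get(field)
        for op, expected in conditions.items():
            if value is None or not OPS[op](value, expected):
                return False
    return True

example_results = [
    {"text": "Alice meets the Queen", "metadata": {"chapter": 8, "spoiler_level": "high"}},
    {"text": "Alice falls down the rabbit hole", "metadata": {"chapter": 1, "spoiler_level": "none"}},
]
filters = {"chapter": {"<=": 5}, "spoiler_level": {"!=": "high"}}

filtered = [r for r in example_results if matches(r["metadata"], filters)]
print(filtered)  # only the chapter-1 passage survives
```
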
### Best Practices

1. **Consistent metadata schema**: Use consistent field names and value types across your documents.

2. **Reasonable metadata size**: Keep metadata reasonably sized to avoid storage overhead.

3. **Type consistency**: Use consistent data types for the same fields (e.g., always integers for chapter numbers).

4. **Index multiple granularities**: Consider chunking at different levels (paragraph, section, chapter) with appropriate metadata.

### Adding Metadata to Existing Indices

To add metadata filtering to existing indices, you'll need to rebuild them with metadata:

```python
# Read existing passages and add metadata
def add_metadata_to_existing_chunks(chunks):
    for chunk in chunks:
        # Extract or assign metadata based on content
        chunk["metadata"] = extract_metadata(chunk["text"])
    return chunks

# Rebuild index with metadata
enhanced_chunks = add_metadata_to_existing_chunks(existing_chunks)
builder = LeannBuilder("hnsw")
for chunk in enhanced_chunks:
    builder.add_text(chunk["text"], chunk["metadata"])
builder.build_index("enhanced_index")
```
@@ -72,4 +72,4 @@ Using the wrong distance metric with normalized embeddings can lead to:
 - **Incorrect ranking** of search results
 - **Suboptimal performance** compared to using the correct metric

-For more details on why this happens, see our analysis in the [embedding detection code](../packages/leann-core/src/leann/api.py) which automatically handles normalized embeddings and MIPS distance metric issues.
+For more details on why this happens, see our analysis of [OpenAI embeddings with MIPS](../examples/main_cli_example.py).
@@ -2,8 +2,8 @@

 ## 🎯 Q2 2025

-- [X] HNSW backend integration
 - [X] DiskANN backend with MIPS/L2/Cosine support
+- [X] HNSW backend integration
 - [X] Real-time embedding pipeline
 - [X] Memory-efficient graph pruning
@@ -62,7 +62,7 @@ def test_faiss_hnsw():

     try:
         result = subprocess.run(
-            [sys.executable, "benchmarks/faiss_only.py"],
+            [sys.executable, "examples/faiss_only.py"],
             capture_output=True,
             text=True,
             timeout=300,
@@ -115,7 +115,7 @@ def test_leann_hnsw():

     # Load and parse documents
     documents = SimpleDirectoryReader(
-        "data",
+        "examples/data",
         recursive=True,
         encoding="utf-8",
         required_exts=[".pdf", ".txt", ".md"],
158 examples/document_search.py Normal file
@@ -0,0 +1,158 @@
#!/usr/bin/env python3
"""
Document search demo with recompute mode
"""

import shutil
import time
from pathlib import Path

# Import backend packages to trigger plugin registration
try:
    import leann_backend_diskann  # noqa: F401
    import leann_backend_hnsw  # noqa: F401

    print("INFO: Backend packages imported successfully.")
except ImportError as e:
    print(f"WARNING: Could not import backend packages. Error: {e}")

# Import upper-level API from leann-core
from leann.api import LeannBuilder, LeannChat, LeannSearcher


def load_sample_documents():
    """Create sample documents for demonstration"""
    docs = [
        {
            "title": "Intro to Python",
            "content": "Python is a high-level, interpreted language known for simplicity.",
        },
        {
            "title": "ML Basics",
            "content": "Machine learning builds systems that learn from data.",
        },
        {
            "title": "Data Structures",
            "content": "Data structures like arrays, lists, and graphs organize data.",
        },
    ]
    return docs


def main():
    print("==========================================================")
    print("=== Leann Document Search Demo (DiskANN + Recompute) ===")
    print("==========================================================")

    INDEX_DIR = Path("./test_indices")
    INDEX_PATH = str(INDEX_DIR / "documents.diskann")
    BACKEND_TO_TEST = "diskann"

    if INDEX_DIR.exists():
        print(f"--- Cleaning up old index directory: {INDEX_DIR} ---")
        shutil.rmtree(INDEX_DIR)

    # --- 1. Build index ---
    print(f"\n[PHASE 1] Building index using '{BACKEND_TO_TEST}' backend...")

    builder = LeannBuilder(backend_name=BACKEND_TO_TEST, graph_degree=32, complexity=64)

    documents = load_sample_documents()
    print(f"Loaded {len(documents)} sample documents.")
    for doc in documents:
        builder.add_text(doc["content"], metadata={"title": doc["title"]})

    builder.build_index(INDEX_PATH)
    print("\nIndex built!")

    # --- 2. Basic search demo ---
    print(f"\n[PHASE 2] Basic search using '{BACKEND_TO_TEST}' backend...")
    searcher = LeannSearcher(index_path=INDEX_PATH)

    query = "What is machine learning?"
    print(f"\nQuery: '{query}'")

    print("\n--- Basic search mode (PQ computation) ---")
    start_time = time.time()
    results = searcher.search(query, top_k=2)
    basic_time = time.time() - start_time

    print(f"⏱️ Basic search time: {basic_time:.3f} seconds")
    print(">>> Basic search results <<<")
    for i, res in enumerate(results, 1):
        print(
            f"  {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}"
        )

    # --- 3. Recompute search demo ---
    print("\n[PHASE 3] Recompute search using embedding server...")

    print("\n--- Recompute search mode (get real embeddings via network) ---")

    # Configure recompute parameters
    recompute_params = {
        "recompute_beighbor_embeddings": True,  # Enable network recomputation
        "USE_DEFERRED_FETCH": False,  # Don't use deferred fetch
        "skip_search_reorder": True,  # Skip search reordering
        "dedup_node_dis": True,  # Enable node distance deduplication
        "prune_ratio": 0.1,  # Pruning ratio 10%
        "batch_recompute": False,  # Don't use batch recomputation
        "global_pruning": False,  # Don't use global pruning
        "zmq_port": 5555,  # ZMQ port
        "embedding_model": "sentence-transformers/all-mpnet-base-v2",
    }

    print("Recompute parameter configuration:")
    for key, value in recompute_params.items():
        print(f"  {key}: {value}")

    print("\n🔄 Executing Recompute search...")
    try:
        start_time = time.time()
        recompute_results = searcher.search(query, top_k=2, **recompute_params)
        recompute_time = time.time() - start_time

        print(f"⏱️ Recompute search time: {recompute_time:.3f} seconds")
        print(">>> Recompute search results <<<")
        for i, res in enumerate(recompute_results, 1):
            print(
                f"  {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}"
            )

        # Compare results
        print("\n--- Result comparison ---")
        print(f"Basic search time: {basic_time:.3f} seconds")
        print(f"Recompute time: {recompute_time:.3f} seconds")

        print("\nBasic search vs Recompute results:")
        for i in range(min(len(results), len(recompute_results))):
            basic_score = results[i].score
            recompute_score = recompute_results[i].score
            score_diff = abs(basic_score - recompute_score)
            print(
                f"  Position {i + 1}: PQ={basic_score:.4f}, Recompute={recompute_score:.4f}, Difference={score_diff:.4f}"
            )

        if recompute_time > basic_time:
            print("✅ Recompute mode working correctly (more accurate but slower)")
        else:
            print("ℹ️ Recompute time is unusually fast, network recomputation may not be enabled")

    except Exception as e:
        print(f"❌ Recompute search failed: {e}")
        print("This usually indicates an embedding server connection issue")

    # --- 4. Chat demo ---
    print("\n[PHASE 4] Starting chat session...")
    chat = LeannChat(index_path=INDEX_PATH)
    chat_response = chat.ask(query)
    print(f"You: {query}")
    print(f"Leann: {chat_response}")

    print("\n==========================================================")
    print("✅ Demo finished successfully!")
    print("==========================================================")


if __name__ == "__main__":
    main()
@@ -52,11 +52,6 @@ class EmlxReader(BaseReader):
         docs: list[Document] = []
         max_count = load_kwargs.get("max_count", 1000)
         count = 0
-        total_files = 0
-        successful_files = 0
-        failed_files = 0
-
-        print(f"Starting to process directory: {input_dir}")

         # Walk through the directory recursively
         for dirpath, dirnames, filenames in os.walk(input_dir):
@@ -64,12 +59,10 @@ class EmlxReader(BaseReader):
             dirnames[:] = [d for d in dirnames if not d.startswith(".")]

             for filename in filenames:
-                # Check if we've reached the max count (skip if max_count == -1)
-                if max_count > 0 and count >= max_count:
+                if count >= max_count:
                     break

                 if filename.endswith(".emlx"):
-                    total_files += 1
                     filepath = os.path.join(dirpath, filename)
                     try:
                         # Read the .emlx file
@@ -105,26 +98,17 @@ class EmlxReader(BaseReader):
                                 and not self.include_html
                             ):
                                 continue
-                            try:
-                                payload = part.get_payload(decode=True)
-                                if payload:
-                                    body += payload.decode("utf-8", errors="ignore")
-                            except Exception as e:
-                                print(f"Error decoding payload: {e}")
-                                continue
+                            body += part.get_payload(decode=True).decode(
+                                "utf-8", errors="ignore"
+                            )
+                            # break
                     else:
-                        try:
-                            payload = msg.get_payload(decode=True)
-                            if payload:
-                                body = payload.decode("utf-8", errors="ignore")
-                        except Exception as e:
-                            print(f"Error decoding single part payload: {e}")
-                            body = ""
+                        body = msg.get_payload(decode=True).decode(
+                            "utf-8", errors="ignore"
+                        )

-                    # Only create document if we have some content
-                    if body.strip() or subject != "No Subject":
-                        # Create document content with metadata embedded in text
-                        doc_content = f"""
+                    # Create document content with metadata embedded in text
+                    doc_content = f"""
 [File]: {filename}
 [From]: {from_addr}
 [To]: {to_addr}
@@ -134,34 +118,18 @@ class EmlxReader(BaseReader):
 {body}
 """

                     # No separate metadata - everything is in the text
                     doc = Document(text=doc_content, metadata={})
                     docs.append(doc)
                     count += 1
-                        successful_files += 1
-
-                        # Print first few successful files for debugging
-                        if successful_files <= 3:
-                            print(
-                                f"Successfully loaded: {filename} - Subject: {subject[:50]}..."
-                            )
-
                     except Exception as e:
-                        failed_files += 1
-                        if failed_files <= 5:  # Only print first few errors
-                            print(f"Error parsing email from {filepath}: {e}")
+                        print(f"Error parsing email from {filepath}: {e}")
                         continue

                 except Exception as e:
-                    failed_files += 1
-                    if failed_files <= 5:  # Only print first few errors
-                        print(f"Error reading file {filepath}: {e}")
+                    print(f"Error reading file {filepath}: {e}")
                     continue

-        print("Processing summary:")
-        print(f"  Total .emlx files found: {total_files}")
-        print(f"  Successfully loaded: {successful_files}")
-        print(f"  Failed to load: {failed_files}")
-        print(f"  Final documents: {len(docs)}")
+        print(f"Loaded {len(docs)} email documents")

         return docs
@@ -65,7 +65,7 @@ def main():
     tracker.checkpoint("After Faiss index creation")

     documents = SimpleDirectoryReader(
-        "data",
+        "examples/data",
         recursive=True,
         encoding="utf-8",
         required_exts=[".pdf", ".txt", ".md"],
362 examples/google_history_reader_leann.py Normal file
@@ -0,0 +1,362 @@
import argparse
import asyncio
import os

try:
    import dotenv

    dotenv.load_dotenv()
except ModuleNotFoundError:
    # python-dotenv is not installed; skip loading environment variables
    dotenv = None
from pathlib import Path

from leann.api import LeannBuilder, LeannChat
from llama_index.core.node_parser import SentenceSplitter

# dotenv.load_dotenv()  # handled above if python-dotenv is available

# Default Chrome profile path
DEFAULT_CHROME_PROFILE = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")


def create_leann_index_from_multiple_chrome_profiles(
    profile_dirs: list[Path],
    index_path: str = "chrome_history_index.leann",
    max_count: int = -1,
    embedding_model: str = "facebook/contriever",
    embedding_mode: str = "sentence-transformers",
):
    """
    Create LEANN index from multiple Chrome profile data sources.

    Args:
        profile_dirs: List of Path objects pointing to Chrome profile directories
        index_path: Path to save the LEANN index
        max_count: Maximum number of history entries to process per profile
        embedding_model: The embedding model to use
        embedding_mode: The embedding backend mode
    """
    print("Creating LEANN index from multiple Chrome profile data sources...")

    # Load documents using ChromeHistoryReader from history_data
    from history_data.history import ChromeHistoryReader

    reader = ChromeHistoryReader()

    INDEX_DIR = Path(index_path).parent

    if not INDEX_DIR.exists():
        print("--- Index directory not found, building new index ---")
        all_documents = []
        total_processed = 0

        # Process each Chrome profile directory
        for i, profile_dir in enumerate(profile_dirs):
            print(f"\nProcessing Chrome profile {i + 1}/{len(profile_dirs)}: {profile_dir}")

            try:
                documents = reader.load_data(
                    chrome_profile_path=str(profile_dir), max_count=max_count
                )
                if documents:
                    print(f"Loaded {len(documents)} history documents from {profile_dir}")
                    all_documents.extend(documents)
                    total_processed += len(documents)

                    # Check if we've reached the max count
                    if max_count > 0 and total_processed >= max_count:
                        print(f"Reached max count of {max_count} documents")
                        break
                else:
                    print(f"No documents loaded from {profile_dir}")
            except Exception as e:
                print(f"Error processing {profile_dir}: {e}")
                continue

        if not all_documents:
            print("No documents loaded from any source. Exiting.")
            # highlight info that you need to close all chrome browser before running this script and high light the instruction!!
            print(
                "\033[91mYou need to close or quit all chrome browser before running this script\033[0m"
            )
            return None

        print(
            f"\nTotal loaded {len(all_documents)} history documents from {len(profile_dirs)} profiles"
        )

        # Create text splitter with 256 chunk size
        text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)

        # Convert Documents to text strings and chunk them
        all_texts = []
        for doc in all_documents:
            # Split the document into chunks
            nodes = text_splitter.get_nodes_from_documents([doc])
            for node in nodes:
                text = node.get_content()
                # text = '[Title] ' + doc.metadata["title"] + '\n' + text
                all_texts.append(text)

        print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")

        # Create LEANN index directory
        print("--- Index directory not found, building new index ---")
        INDEX_DIR.mkdir(exist_ok=True)

        print("--- Building new LEANN index ---")

        print("\n[PHASE 1] Building Leann index...")

        # Use HNSW backend for better macOS compatibility
        # LeannBuilder will automatically detect normalized embeddings and set appropriate distance metric
        builder = LeannBuilder(
            backend_name="hnsw",
            embedding_model=embedding_model,
            embedding_mode=embedding_mode,
            graph_degree=32,
            complexity=64,
            is_compact=True,
            is_recompute=True,
            num_threads=1,  # Force single-threaded mode
        )

        print(f"Adding {len(all_texts)} history chunks to index...")
        for chunk_text in all_texts:
            builder.add_text(chunk_text)

        builder.build_index(index_path)
        print(f"\nLEANN index built at {index_path}!")
    else:
        print(f"--- Using existing index at {INDEX_DIR} ---")

    return index_path


def create_leann_index(
    profile_path: str | None = None,
    index_path: str = "chrome_history_index.leann",
    max_count: int = 1000,
    embedding_model: str = "facebook/contriever",
    embedding_mode: str = "sentence-transformers",
):
    """
    Create LEANN index from Chrome history data.

    Args:
        profile_path: Path to the Chrome profile directory (optional, uses default if None)
        index_path: Path to save the LEANN index
        max_count: Maximum number of history entries to process
        embedding_model: The embedding model to use
        embedding_mode: The embedding backend mode
    """
    print("Creating LEANN index from Chrome history data...")
    INDEX_DIR = Path(index_path).parent

    if not INDEX_DIR.exists():
        print("--- Index directory not found, building new index ---")
        INDEX_DIR.mkdir(exist_ok=True)

        print("--- Building new LEANN index ---")

        print("\n[PHASE 1] Building Leann index...")

        # Load documents using ChromeHistoryReader from history_data
        from history_data.history import ChromeHistoryReader

        reader = ChromeHistoryReader()

        documents = reader.load_data(chrome_profile_path=profile_path, max_count=max_count)

        if not documents:
            print("No documents loaded. Exiting.")
            return None

        print(f"Loaded {len(documents)} history documents")

        # Create text splitter with 256 chunk size
        text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)

        # Convert Documents to text strings and chunk them
        all_texts = []
        for doc in documents:
            # Split the document into chunks
            nodes = text_splitter.get_nodes_from_documents([doc])
            for node in nodes:
                all_texts.append(node.get_content())

        print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")

        # Create LEANN index directory
        print("--- Index directory not found, building new index ---")
        INDEX_DIR.mkdir(exist_ok=True)

        print("--- Building new LEANN index ---")

        print("\n[PHASE 1] Building Leann index...")

        # Use HNSW backend for better macOS compatibility
        # LeannBuilder will automatically detect normalized embeddings and set appropriate distance metric
        builder = LeannBuilder(
            backend_name="hnsw",
            embedding_model=embedding_model,
            embedding_mode=embedding_mode,
            graph_degree=32,
            complexity=64,
            is_compact=True,
            is_recompute=True,
            num_threads=1,  # Force single-threaded mode
        )

        print(f"Adding {len(all_texts)} history chunks to index...")
        for chunk_text in all_texts:
            builder.add_text(chunk_text)

        builder.build_index(index_path)
        print(f"\nLEANN index built at {index_path}!")
    else:
        print(f"--- Using existing index at {INDEX_DIR} ---")

    return index_path


async def query_leann_index(index_path: str, query: str):
    """
    Query the LEANN index.

    Args:
        index_path: Path to the LEANN index
        query: The query string
    """
    print("\n[PHASE 2] Starting Leann chat session...")
    chat = LeannChat(index_path=index_path)

    print(f"You: {query}")
    chat_response = chat.ask(
        query,
        top_k=10,
        recompute_beighbor_embeddings=True,
        complexity=32,
        beam_width=1,
        llm_config={
            "type": "openai",
            "model": "gpt-4o",
            "api_key": os.getenv("OPENAI_API_KEY"),
        },
        llm_kwargs={"temperature": 0.0, "max_tokens": 1000},
    )

    print(f"Leann chat response: \033[36m{chat_response}\033[0m")


async def main():
    # Parse command line arguments
    parser = argparse.ArgumentParser(
        description="LEANN Chrome History Reader - Create and query browser history index"
    )
    parser.add_argument(
        "--chrome-profile",
        type=str,
        default=DEFAULT_CHROME_PROFILE,
        help=f"Path to Chrome profile directory (default: {DEFAULT_CHROME_PROFILE}), usually you dont need to change this",
    )
    parser.add_argument(
        "--index-dir",
        type=str,
        default="./google_history_index",
        help="Directory to store the LEANN index (default: ./chrome_history_index_leann_test)",
    )
    parser.add_argument(
        "--max-entries",
        type=int,
        default=1000,
        help="Maximum number of history entries to process (default: 1000)",
    )
    parser.add_argument(
        "--query",
        type=str,
        default=None,
        help="Single query to run (default: runs example queries)",
    )
    parser.add_argument(
        "--auto-find-profiles",
        action="store_true",
        default=True,
        help="Automatically find all Chrome profiles (default: True)",
    )
    parser.add_argument(
        "--embedding-model",
        type=str,
        default="facebook/contriever",
        help="The embedding model to use (e.g., 'facebook/contriever', 'text-embedding-3-small')",
    )
    parser.add_argument(
        "--embedding-mode",
        type=str,
        default="sentence-transformers",
        choices=["sentence-transformers", "openai", "mlx"],
        help="The embedding backend mode",
    )
    parser.add_argument(
        "--use-existing-index",
        action="store_true",
        help="Use existing index without rebuilding",
    )

    args = parser.parse_args()

    INDEX_DIR = Path(args.index_dir)
    INDEX_PATH = str(INDEX_DIR / "chrome_history.leann")

    print(f"Using Chrome profile: {args.chrome_profile}")
    print(f"Index directory: {INDEX_DIR}")
    print(f"Max entries: {args.max_entries}")

    if args.use_existing_index:
        # Use existing index without rebuilding
        if not Path(INDEX_PATH).exists():
            print(f"Error: Index file not found at {INDEX_PATH}")
            return
        print(f"Using existing index at {INDEX_PATH}")
        index_path = INDEX_PATH
    else:
        # Find Chrome profile directories
        from history_data.history import ChromeHistoryReader

        if args.auto_find_profiles:
            profile_dirs = ChromeHistoryReader.find_chrome_profiles()
            if not profile_dirs:
                print("No Chrome profiles found automatically. Exiting.")
                return
        else:
            # Use single specified profile
            profile_path = Path(args.chrome_profile)
            if not profile_path.exists():
                print(f"Chrome profile not found: {profile_path}")
                return
            profile_dirs = [profile_path]

        # Create or load the LEANN index from all sources
        index_path = create_leann_index_from_multiple_chrome_profiles(
            profile_dirs, INDEX_PATH, args.max_entries, args.embedding_model, args.embedding_mode
        )

    if index_path:
        if args.query:
            # Run single query
            await query_leann_index(index_path, args.query)
        else:
            # Example queries
            queries = [
                "What websites did I visit about machine learning?",
                "Find my search history about programming",
            ]

            for query in queries:
                print("\n" + "=" * 60)
                await query_leann_index(index_path, query)


if __name__ == "__main__":
    asyncio.run(main())
@@ -97,11 +97,6 @@ class ChromeHistoryReader(BaseReader):

         except Exception as e:
             print(f"Error reading Chrome history: {e}")
-            # add you may need to close your browser to make the database file available
-            # also highlight in red
-            print(
-                "\033[91mYou may need to close your browser to make the database file available\033[0m"
-            )
             return docs

         return docs
@@ -411,8 +411,8 @@ Messages ({len(messages)} messages, {message_group["total_length"]} chars):
         wechat_export_dir = load_kwargs.get("wechat_export_dir", None)
         include_non_text = load_kwargs.get("include_non_text", False)
         concatenate_messages = load_kwargs.get("concatenate_messages", False)
-        max_length = load_kwargs.get("max_length", 1000)
-        time_window_minutes = load_kwargs.get("time_window_minutes", 30)
+        load_kwargs.get("max_length", 1000)
+        load_kwargs.get("time_window_minutes", 30)

         # Default WeChat export path
         if wechat_export_dir is None:
@@ -460,9 +460,9 @@ Messages ({len(messages)} messages, {message_group["total_length"]} chars):
             # Concatenate messages based on rules
             message_groups = self._concatenate_messages(
                 readable_messages,
-                max_length=max_length,
-                time_window_minutes=time_window_minutes,
-                overlap_messages=0,  # No overlap between groups
+                max_length=-1,
+                time_window_minutes=-1,
+                overlap_messages=0,  # Keep 2 messages overlap between groups
             )

             # Create documents from concatenated groups
@@ -532,9 +532,7 @@ Message: {readable_text if readable_text else message_text}
 """

                 # Create document with embedded metadata
-                doc = Document(
-                    text=doc_content, metadata={"contact_name": contact_name}
-                )
+                doc = Document(text=doc_content, metadata={})
                 docs.append(doc)
                 count += 1

@@ -562,8 +560,8 @@ Message: {readable_text if readable_text else message_text}

     # Look for common export directory names
     possible_dirs = [
+        Path("./wechat_export_test"),
         Path("./wechat_export"),
-        Path("./wechat_export_direct"),
         Path("./wechat_chat_history"),
         Path("./chat_export"),
     ]
342 examples/mail_reader_leann.py Normal file
@@ -0,0 +1,342 @@
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import dotenv
|
||||||
|
|
||||||
|
# Add the project root to Python path so we can import from examples
|
||||||
|
project_root = Path(__file__).parent.parent
|
||||||
|
sys.path.insert(0, str(project_root))
|
||||||
|
|
||||||
|
from leann.api import LeannBuilder, LeannChat
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
# Auto-detect user's mail path
|
||||||
|
def get_mail_path():
|
||||||
|
"""Get the mail path for the current user"""
|
||||||
|
home_dir = os.path.expanduser("~")
|
||||||
|
return os.path.join(home_dir, "Library", "Mail")
|
||||||
|
|
||||||
|
|
||||||
|
# Default mail path for macOS
|
||||||
|
DEFAULT_MAIL_PATH = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data"
|
||||||
|
|
||||||
|
|
||||||
|
def create_leann_index_from_multiple_sources(
|
||||||
|
messages_dirs: list[Path],
|
||||||
|
index_path: str = "mail_index.leann",
|
||||||
|
max_count: int = -1,
|
||||||
|
include_html: bool = False,
|
||||||
|
embedding_model: str = "facebook/contriever",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create LEANN index from multiple mail data sources.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages_dirs: List of Path objects pointing to Messages directories
|
||||||
|
index_path: Path to save the LEANN index
|
||||||
|
max_count: Maximum number of emails to process per directory
|
||||||
|
include_html: Whether to include HTML content in email processing
|
||||||
|
"""
|
||||||
|
print("Creating LEANN index from multiple mail data sources...")
|
||||||
|
|
||||||
|
# Load documents using EmlxReader from LEANN_email_reader
|
||||||
|
from examples.email_data.LEANN_email_reader import EmlxReader
|
||||||
|
|
||||||
|
reader = EmlxReader(include_html=include_html)
|
||||||
|
# from email_data.email import EmlxMboxReader
|
||||||
|
# from pathlib import Path
|
||||||
|
# reader = EmlxMboxReader()
|
||||||
|
INDEX_DIR = Path(index_path).parent
|
||||||
|
|
||||||
|
if not INDEX_DIR.exists():
|
||||||
|
print("--- Index directory not found, building new index ---")
|
||||||
|
all_documents = []
|
||||||
|
total_processed = 0
|
||||||
|
|
||||||
|
# Process each Messages directory
|
||||||
|
for i, messages_dir in enumerate(messages_dirs):
|
||||||
|
print(f"\nProcessing Messages directory {i + 1}/{len(messages_dirs)}: {messages_dir}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
documents = reader.load_data(messages_dir)
|
||||||
|
if documents:
|
||||||
|
print(f"Loaded {len(documents)} email documents from {messages_dir}")
|
||||||
|
all_documents.extend(documents)
|
||||||
|
total_processed += len(documents)
|
||||||
|
|
||||||
|
# Check if we've reached the max count
|
||||||
|
if max_count > 0 and total_processed >= max_count:
|
||||||
|
print(f"Reached max count of {max_count} documents")
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
print(f"No documents loaded from {messages_dir}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing {messages_dir}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not all_documents:
|
||||||
|
print("No documents loaded from any source. Exiting.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories and starting to split them into chunks"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create text splitter with 256 chunk size
|
||||||
|
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||||
|
|
||||||
|
# Convert Documents to text strings and chunk them
|
||||||
|
all_texts = []
|
||||||
|
for doc in all_documents:
|
||||||
|
# Split the document into chunks
|
||||||
|
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||||
|
for node in nodes:
|
||||||
|
text = node.get_content()
|
||||||
|
# text = '[subject] ' + doc.metadata["subject"] + '\n' + text
|
||||||
|
all_texts.append(text)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Finished splitting {len(all_documents)} documents into {len(all_texts)} text chunks"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create LEANN index directory
|
||||||
|
|
||||||
|
print("--- Index directory not found, building new index ---")
|
||||||
|
INDEX_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
print("--- Building new LEANN index ---")
|
||||||
|
|
||||||
|
print("\n[PHASE 1] Building Leann index...")
|
||||||
|
|
||||||
|
# Use HNSW backend for better macOS compatibility
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model=embedding_model,
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
is_compact=True,
|
||||||
|
is_recompute=True,
|
||||||
|
num_threads=1, # Force single-threaded mode
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Adding {len(all_texts)} email chunks to index...")
|
||||||
|
for chunk_text in all_texts:
|
||||||
|
builder.add_text(chunk_text)
|
||||||
|
|
||||||
|
builder.build_index(index_path)
|
||||||
|
print(f"\nLEANN index built at {index_path}!")
|
||||||
|
else:
|
||||||
|
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||||
|
|
||||||
|
return index_path
|
||||||
|
|
||||||
|
|
||||||
|
def create_leann_index(
|
||||||
|
mail_path: str,
|
||||||
|
index_path: str = "mail_index.leann",
|
||||||
|
max_count: int = 1000,
|
||||||
|
include_html: bool = False,
|
||||||
|
embedding_model: str = "facebook/contriever",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create LEANN index from mail data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mail_path: Path to the mail directory
|
||||||
|
index_path: Path to save the LEANN index
|
||||||
|
max_count: Maximum number of emails to process
|
||||||
|
include_html: Whether to include HTML content in email processing
|
||||||
|
"""
|
||||||
|
print("Creating LEANN index from mail data...")
|
||||||
|
INDEX_DIR = Path(index_path).parent
|
||||||
|
|
||||||
|
if not INDEX_DIR.exists():
|
||||||
|
print("--- Index directory not found, building new index ---")
|
||||||
|
INDEX_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
print("--- Building new LEANN index ---")
|
||||||
|
|
||||||
|
print("\n[PHASE 1] Building Leann index...")
|
||||||
|
|
||||||
|
# Load documents using EmlxReader from LEANN_email_reader
|
||||||
|
from examples.email_data.LEANN_email_reader import EmlxReader
|
||||||
|
|
||||||
|
reader = EmlxReader(include_html=include_html)
|
||||||
|
# from email_data.email import EmlxMboxReader
|
||||||
|
# from pathlib import Path
|
||||||
|
# reader = EmlxMboxReader()
|
||||||
|
documents = reader.load_data(Path(mail_path))
|
||||||
|
|
||||||
|
if not documents:
|
||||||
|
print("No documents loaded. Exiting.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
print(f"Loaded {len(documents)} email documents")
|
||||||
|
|
||||||
|
# Create text splitter with 256 chunk size
|
||||||
|
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
|
||||||
|
|
||||||
|
# Convert Documents to text strings and chunk them
|
||||||
|
all_texts = []
|
||||||
|
for doc in documents:
|
||||||
|
# Split the document into chunks
|
||||||
|
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||||
|
for node in nodes:
|
||||||
|
all_texts.append(node.get_content())
|
||||||
|
|
||||||
|
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
|
||||||
|
|
||||||
|
# Create LEANN index directory
|
||||||
|
|
||||||
|
print("--- Index directory not found, building new index ---")
|
||||||
|
INDEX_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
print("--- Building new LEANN index ---")
|
||||||
|
|
||||||
|
print("\n[PHASE 1] Building Leann index...")
|
||||||
|
|
||||||
|
# Use HNSW backend for better macOS compatibility
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model=embedding_model,
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
is_compact=True,
|
||||||
|
is_recompute=True,
|
||||||
|
num_threads=1, # Force single-threaded mode
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Adding {len(all_texts)} email chunks to index...")
|
||||||
|
for chunk_text in all_texts:
|
||||||
|
builder.add_text(chunk_text)
|
||||||
|
|
||||||
|
builder.build_index(index_path)
|
||||||
|
print(f"\nLEANN index built at {index_path}!")
|
||||||
|
else:
|
||||||
|
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||||
|
|
||||||
|
return index_path
|
||||||
|
|
||||||
|
|
||||||
|
async def query_leann_index(index_path: str, query: str):
|
||||||
|
"""
|
||||||
|
Query the LEANN index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index_path: Path to the LEANN index
|
||||||
|
query: The query string
|
||||||
|
"""
|
||||||
|
print("\n[PHASE 2] Starting Leann chat session...")
|
||||||
|
chat = LeannChat(index_path=index_path, llm_config={"type": "openai", "model": "gpt-4o"})
|
||||||
|
|
||||||
|
print(f"You: {query}")
|
||||||
|
import time
|
||||||
|
|
||||||
|
time.time()
|
||||||
|
chat_response = chat.ask(
|
||||||
|
query,
|
||||||
|
top_k=20,
|
||||||
|
recompute_beighbor_embeddings=True,
|
||||||
|
complexity=32,
|
||||||
|
beam_width=1,
|
||||||
|
)
|
||||||
|
time.time()
|
||||||
|
# print(f"Time taken: {end_time - start_time} seconds")
|
||||||
|
# highlight the answer
|
||||||
|
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
# Parse command line arguments
|
||||||
|
parser = argparse.ArgumentParser(description="LEANN Mail Reader - Create and query email index")
|
||||||
|
# Remove --mail-path argument and auto-detect all Messages directories
|
||||||
|
# Remove DEFAULT_MAIL_PATH
|
||||||
|
parser.add_argument(
|
||||||
|
"--index-dir",
|
||||||
|
type=str,
|
||||||
|
default="./mail_index",
|
||||||
|
help="Directory to store the LEANN index (default: ./mail_index_leann_raw_text_all_dicts)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-emails",
|
||||||
|
type=int,
|
||||||
|
default=1000,
|
||||||
|
help="Maximum number of emails to process (-1 means all)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--query",
|
||||||
|
type=str,
|
||||||
|
default="Give me some funny advertisement about apple or other companies",
|
||||||
|
help="Single query to run (default: runs example queries)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--include-html",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Include HTML content in email processing (default: False)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding-model",
|
||||||
|
type=str,
|
||||||
|
default="facebook/contriever",
|
||||||
|
help="Embedding model to use (default: facebook/contriever)",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
print(f"args: {args}")
|
||||||
|
|
||||||
|
# Automatically find all Messages directories under the current user's Mail directory
|
||||||
|
from examples.email_data.LEANN_email_reader import find_all_messages_directories
|
||||||
|
|
||||||
|
mail_path = get_mail_path()
|
||||||
|
print(f"Searching for email data in: {mail_path}")
|
||||||
|
messages_dirs = find_all_messages_directories(mail_path)
|
||||||
|
# messages_dirs = find_all_messages_directories(DEFAULT_MAIL_PATH)
|
||||||
|
# messages_dirs = [DEFAULT_MAIL_PATH]
|
||||||
|
# messages_dirs = messages_dirs[:1]
|
||||||
|
|
||||||
|
print("len(messages_dirs): ", len(messages_dirs))
|
||||||
|
|
||||||
|
if not messages_dirs:
|
||||||
|
print("No Messages directories found. Exiting.")
|
||||||
|
return
|
||||||
|
|
||||||
|
INDEX_DIR = Path(args.index_dir)
|
||||||
|
INDEX_PATH = str(INDEX_DIR / "mail_documents.leann")
|
||||||
|
print(f"Index directory: {INDEX_DIR}")
|
||||||
|
print(f"Found {len(messages_dirs)} Messages directories.")
|
||||||
|
|
||||||
|
# Create or load the LEANN index from all sources
|
||||||
|
index_path = create_leann_index_from_multiple_sources(
|
||||||
|
messages_dirs,
|
||||||
|
INDEX_PATH,
|
||||||
|
args.max_emails,
|
||||||
|
args.include_html,
|
||||||
|
args.embedding_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
if index_path:
|
||||||
|
if args.query:
|
||||||
|
# Run single query
|
||||||
|
await query_leann_index(index_path, args.query)
|
||||||
|
else:
|
||||||
|
# Example queries
|
||||||
|
queries = [
|
||||||
|
"Hows Berkeley Graduate Student Instructor",
|
||||||
|
"how's the icloud related advertisement saying",
|
||||||
|
"Whats the number of class recommend to take per semester for incoming EECS students",
|
||||||
|
]
|
||||||
|
for query in queries:
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
await query_leann_index(index_path, query)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
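Once the index above exists, a retrieval-only sanity check is also possible without invoking the chat LLM. The snippet below is an illustrative sketch, not part of this change: it assumes the default --index-dir of ./mail_index and reuses the LeannSearcher API shown in the other examples in this diff (search() returns results carrying .score and .text).

# Sketch only: query the mail index directly with LeannSearcher (no LLM involved).
# Assumes the index was built with the default --index-dir of ./mail_index.
from leann.api import LeannSearcher

searcher = LeannSearcher("./mail_index/mail_documents.leann")
results = searcher.search("icloud related advertisement", top_k=5)
for rank, result in enumerate(results, start=1):
    print(f"{rank}. score={result.score:.4f}  {result.text[:80]}...")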
examples/mail_reader_llamaindex.py (new file, 135 lines)
@@ -0,0 +1,135 @@
|
|||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add the project root to Python path so we can import from examples
|
||||||
|
project_root = Path(__file__).parent.parent
|
||||||
|
sys.path.insert(0, str(project_root))
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from llama_index.core import StorageContext, VectorStoreIndex
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
|
# --- EMBEDDING MODEL ---
|
||||||
|
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||||
|
|
||||||
|
# --- END EMBEDDING MODEL ---
|
||||||
|
# Import EmlxReader from the new module
|
||||||
|
from examples.email_data.LEANN_email_reader import EmlxReader
|
||||||
|
|
||||||
|
|
||||||
|
def create_and_save_index(
|
||||||
|
mail_path: str,
|
||||||
|
save_dir: str = "mail_index_embedded",
|
||||||
|
max_count: int = 1000,
|
||||||
|
include_html: bool = False,
|
||||||
|
):
|
||||||
|
print("Creating index from mail data with embedded metadata...")
|
||||||
|
documents = EmlxReader(include_html=include_html).load_data(mail_path, max_count=max_count)
|
||||||
|
if not documents:
|
||||||
|
print("No documents loaded. Exiting.")
|
||||||
|
return None
|
||||||
|
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||||
|
# Use facebook/contriever as the embedder
|
||||||
|
embed_model = HuggingFaceEmbedding(model_name="facebook/contriever")
|
||||||
|
# set on device
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
embed_model._model.to("cuda")
|
||||||
|
# set mps
|
||||||
|
elif torch.backends.mps.is_available():
|
||||||
|
embed_model._model.to("mps")
|
||||||
|
else:
|
||||||
|
embed_model._model.to("cpu")
|
||||||
|
index = VectorStoreIndex.from_documents(
|
||||||
|
documents, transformations=[text_splitter], embed_model=embed_model
|
||||||
|
)
|
||||||
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
|
index.storage_context.persist(persist_dir=save_dir)
|
||||||
|
print(f"Index saved to {save_dir}")
|
||||||
|
return index
|
||||||
|
|
||||||
|
|
||||||
|
def load_index(save_dir: str = "mail_index_embedded"):
|
||||||
|
try:
|
||||||
|
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
|
||||||
|
index = VectorStoreIndex.from_vector_store(
|
||||||
|
storage_context.vector_store, storage_context=storage_context
|
||||||
|
)
|
||||||
|
print(f"Index loaded from {save_dir}")
|
||||||
|
return index
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error loading index: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def query_index(index, query: str):
|
||||||
|
if index is None:
|
||||||
|
print("No index available for querying.")
|
||||||
|
return
|
||||||
|
query_engine = index.as_query_engine()
|
||||||
|
response = query_engine.query(query)
|
||||||
|
print(f"Query: {query}")
|
||||||
|
print(f"Response: {response}")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Parse command line arguments
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="LlamaIndex Mail Reader - Create and query email index"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--mail-path",
|
||||||
|
type=str,
|
||||||
|
default="/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages",
|
||||||
|
help="Path to mail data directory",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--save-dir",
|
||||||
|
type=str,
|
||||||
|
default="mail_index_embedded",
|
||||||
|
help="Directory to store the index (default: mail_index_embedded)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-emails",
|
||||||
|
type=int,
|
||||||
|
default=10000,
|
||||||
|
help="Maximum number of emails to process",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--include-html",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Include HTML content in email processing (default: False)",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
mail_path = args.mail_path
|
||||||
|
save_dir = args.save_dir
|
||||||
|
|
||||||
|
if os.path.exists(save_dir) and os.path.exists(os.path.join(save_dir, "vector_store.json")):
|
||||||
|
print("Loading existing index...")
|
||||||
|
index = load_index(save_dir)
|
||||||
|
else:
|
||||||
|
print("Creating new index...")
|
||||||
|
index = create_and_save_index(
|
||||||
|
mail_path,
|
||||||
|
save_dir,
|
||||||
|
max_count=args.max_emails,
|
||||||
|
include_html=args.include_html,
|
||||||
|
)
|
||||||
|
if index:
|
||||||
|
queries = [
|
||||||
|
"Hows Berkeley Graduate Student Instructor",
|
||||||
|
"how's the icloud related advertisement saying",
|
||||||
|
"Whats the number of class recommend to take per semester for incoming EECS students",
|
||||||
|
]
|
||||||
|
for query in queries:
|
||||||
|
print("\n" + "=" * 50)
|
||||||
|
query_index(index, query)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
examples/main_cli_example.py (new file, 146 lines)
@@ -0,0 +1,146 @@
|
|||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import dotenv
|
||||||
|
from leann.api import LeannBuilder, LeannChat
|
||||||
|
from llama_index.core import SimpleDirectoryReader
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
async def main(args):
|
||||||
|
INDEX_DIR = Path(args.index_dir)
|
||||||
|
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
|
||||||
|
|
||||||
|
if not INDEX_DIR.exists():
|
||||||
|
node_parser = SentenceSplitter(
|
||||||
|
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Loading documents...")
|
||||||
|
documents = SimpleDirectoryReader(
|
||||||
|
args.data_dir,
|
||||||
|
recursive=True,
|
||||||
|
encoding="utf-8",
|
||||||
|
required_exts=[".pdf", ".txt", ".md"],
|
||||||
|
).load_data(show_progress=True)
|
||||||
|
print("Documents loaded.")
|
||||||
|
all_texts = []
|
||||||
|
for doc in documents:
|
||||||
|
nodes = node_parser.get_nodes_from_documents([doc])
|
||||||
|
if nodes:
|
||||||
|
all_texts.extend(node.get_content() for node in nodes)
|
||||||
|
|
||||||
|
print("--- Index directory not found, building new index ---")
|
||||||
|
|
||||||
|
print("\n[PHASE 1] Building Leann index...")
|
||||||
|
|
||||||
|
# LeannBuilder now automatically detects normalized embeddings and sets appropriate distance metric
|
||||||
|
print(f"Using {args.embedding_model} with {args.embedding_mode} mode")
|
||||||
|
|
||||||
|
# Use HNSW backend for better macOS compatibility
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model=args.embedding_model,
|
||||||
|
embedding_mode=args.embedding_mode,
|
||||||
|
# distance_metric is automatically set based on embedding model
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
is_compact=True,
|
||||||
|
is_recompute=True,
|
||||||
|
num_threads=1, # Force single-threaded mode
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Loaded {len(all_texts)} text chunks from documents.")
|
||||||
|
for chunk_text in all_texts:
|
||||||
|
builder.add_text(chunk_text)
|
||||||
|
|
||||||
|
builder.build_index(INDEX_PATH)
|
||||||
|
print(f"\nLeann index built at {INDEX_PATH}!")
|
||||||
|
else:
|
||||||
|
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||||
|
|
||||||
|
print("\n[PHASE 2] Starting Leann chat session...")
|
||||||
|
|
||||||
|
# Build llm_config based on command line arguments
|
||||||
|
if args.llm == "simulated":
|
||||||
|
llm_config = {"type": "simulated"}
|
||||||
|
elif args.llm == "ollama":
|
||||||
|
llm_config = {"type": "ollama", "model": args.model, "host": args.host}
|
||||||
|
elif args.llm == "hf":
|
||||||
|
llm_config = {"type": "hf", "model": args.model}
|
||||||
|
elif args.llm == "openai":
|
||||||
|
llm_config = {"type": "openai", "model": args.model}
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown LLM type: {args.llm}")
|
||||||
|
|
||||||
|
print(f"Using LLM: {args.llm} with model: {args.model if args.llm != 'simulated' else 'N/A'}")
|
||||||
|
|
||||||
|
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
|
||||||
|
# query = (
|
||||||
|
# "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
|
||||||
|
# )
|
||||||
|
query = args.query
|
||||||
|
|
||||||
|
print(f"You: {query}")
|
||||||
|
chat_response = chat.ask(query, top_k=20, recompute_embeddings=True, complexity=32)
|
||||||
|
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Run Leann Chat with various LLM backends.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--llm",
|
||||||
|
type=str,
|
||||||
|
default="openai",
|
||||||
|
choices=["simulated", "ollama", "hf", "openai"],
|
||||||
|
help="The LLM backend to use.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model",
|
||||||
|
type=str,
|
||||||
|
default="gpt-4o",
|
||||||
|
help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf, 'gpt-4o' for openai).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding-model",
|
||||||
|
type=str,
|
||||||
|
default="facebook/contriever",
|
||||||
|
help="The embedding model to use (e.g., 'facebook/contriever', 'text-embedding-3-small').",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding-mode",
|
||||||
|
type=str,
|
||||||
|
default="sentence-transformers",
|
||||||
|
choices=["sentence-transformers", "openai", "mlx"],
|
||||||
|
help="The embedding backend mode.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--host",
|
||||||
|
type=str,
|
||||||
|
default="http://localhost:11434",
|
||||||
|
help="The host for the Ollama API.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--index-dir",
|
||||||
|
type=str,
|
||||||
|
default="./test_doc_files",
|
||||||
|
help="Directory where the Leann index will be stored.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--data-dir",
|
||||||
|
type=str,
|
||||||
|
default="examples/data",
|
||||||
|
help="Directory containing documents to index (PDF, TXT, MD files).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--query",
|
||||||
|
type=str,
|
||||||
|
default="Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?",
|
||||||
|
help="The query to ask the Leann chat system.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
asyncio.run(main(args))
|
||||||
examples/multi_vector_aggregator.py (new file, 360 lines)
@@ -0,0 +1,360 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Multi-Vector Aggregator for Fat Embeddings
|
||||||
|
==========================================
|
||||||
|
|
||||||
|
This module implements aggregation strategies for multi-vector embeddings,
|
||||||
|
similar to ColPali's approach where multiple patch vectors represent a single document.
|
||||||
|
|
||||||
|
Key features:
|
||||||
|
- MaxSim aggregation (take maximum similarity across patches)
|
||||||
|
- Voting-based aggregation (count patch matches)
|
||||||
|
- Weighted aggregation (attention-score weighted)
|
||||||
|
- Spatial clustering of matching patches
|
||||||
|
- Document-level result consolidation
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PatchResult:
|
||||||
|
"""Represents a single patch search result."""
|
||||||
|
|
||||||
|
patch_id: int
|
||||||
|
image_name: str
|
||||||
|
image_path: str
|
||||||
|
coordinates: tuple[int, int, int, int] # (x1, y1, x2, y2)
|
||||||
|
score: float
|
||||||
|
attention_score: float
|
||||||
|
scale: float
|
||||||
|
metadata: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AggregatedResult:
|
||||||
|
"""Represents an aggregated document-level result."""
|
||||||
|
|
||||||
|
image_name: str
|
||||||
|
image_path: str
|
||||||
|
doc_score: float
|
||||||
|
patch_count: int
|
||||||
|
best_patch: PatchResult
|
||||||
|
all_patches: list[PatchResult]
|
||||||
|
aggregation_method: str
|
||||||
|
spatial_clusters: list[list[PatchResult]] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class MultiVectorAggregator:
|
||||||
|
"""
|
||||||
|
Aggregates multiple patch-level results into document-level results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
aggregation_method: str = "maxsim",
|
||||||
|
spatial_clustering: bool = True,
|
||||||
|
cluster_distance_threshold: float = 100.0,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the aggregator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
aggregation_method: "maxsim", "voting", "weighted", or "mean"
|
||||||
|
spatial_clustering: Whether to cluster spatially close patches
|
||||||
|
cluster_distance_threshold: Distance threshold for spatial clustering
|
||||||
|
"""
|
||||||
|
self.aggregation_method = aggregation_method
|
||||||
|
self.spatial_clustering = spatial_clustering
|
||||||
|
self.cluster_distance_threshold = cluster_distance_threshold
|
||||||
|
|
||||||
|
def aggregate_results(
|
||||||
|
self, search_results: list[dict[str, Any]], top_k: int = 10
|
||||||
|
) -> list[AggregatedResult]:
|
||||||
|
"""
|
||||||
|
Aggregate patch-level search results into document-level results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
search_results: List of search results from LeannSearcher
|
||||||
|
top_k: Number of top documents to return
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of aggregated document results
|
||||||
|
"""
|
||||||
|
# Group results by image
|
||||||
|
image_groups = defaultdict(list)
|
||||||
|
|
||||||
|
for result in search_results:
|
||||||
|
metadata = result.metadata
|
||||||
|
if "image_name" in metadata and "patch_id" in metadata:
|
||||||
|
patch_result = PatchResult(
|
||||||
|
patch_id=metadata["patch_id"],
|
||||||
|
image_name=metadata["image_name"],
|
||||||
|
image_path=metadata["image_path"],
|
||||||
|
coordinates=tuple(metadata["coordinates"]),
|
||||||
|
score=result.score,
|
||||||
|
attention_score=metadata.get("attention_score", 0.0),
|
||||||
|
scale=metadata.get("scale", 1.0),
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
image_groups[metadata["image_name"]].append(patch_result)
|
||||||
|
|
||||||
|
# Aggregate each image group
|
||||||
|
aggregated_results = []
|
||||||
|
for image_name, patches in image_groups.items():
|
||||||
|
if len(patches) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
agg_result = self._aggregate_image_patches(image_name, patches)
|
||||||
|
aggregated_results.append(agg_result)
|
||||||
|
|
||||||
|
# Sort by aggregated score and return top-k
|
||||||
|
aggregated_results.sort(key=lambda x: x.doc_score, reverse=True)
|
||||||
|
return aggregated_results[:top_k]
|
||||||
|
|
||||||
|
def _aggregate_image_patches(
|
||||||
|
self, image_name: str, patches: list[PatchResult]
|
||||||
|
) -> AggregatedResult:
|
||||||
|
"""Aggregate patches for a single image."""
|
||||||
|
|
||||||
|
if self.aggregation_method == "maxsim":
|
||||||
|
doc_score = max(patch.score for patch in patches)
|
||||||
|
best_patch = max(patches, key=lambda p: p.score)
|
||||||
|
|
||||||
|
elif self.aggregation_method == "voting":
|
||||||
|
# Count patches above threshold
|
||||||
|
threshold = np.percentile([p.score for p in patches], 75)
|
||||||
|
doc_score = sum(1 for patch in patches if patch.score >= threshold)
|
||||||
|
best_patch = max(patches, key=lambda p: p.score)
|
||||||
|
|
||||||
|
elif self.aggregation_method == "weighted":
|
||||||
|
# Weight by attention scores
|
||||||
|
total_weighted_score = sum(p.score * p.attention_score for p in patches)
|
||||||
|
total_weights = sum(p.attention_score for p in patches)
|
||||||
|
doc_score = total_weighted_score / max(total_weights, 1e-8)
|
||||||
|
best_patch = max(patches, key=lambda p: p.score * p.attention_score)
|
||||||
|
|
||||||
|
elif self.aggregation_method == "mean":
|
||||||
|
doc_score = np.mean([patch.score for patch in patches])
|
||||||
|
best_patch = max(patches, key=lambda p: p.score)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown aggregation method: {self.aggregation_method}")
|
||||||
|
|
||||||
|
# Spatial clustering if enabled
|
||||||
|
spatial_clusters = None
|
||||||
|
if self.spatial_clustering:
|
||||||
|
spatial_clusters = self._cluster_patches_spatially(patches)
|
||||||
|
|
||||||
|
return AggregatedResult(
|
||||||
|
image_name=image_name,
|
||||||
|
image_path=patches[0].image_path,
|
||||||
|
doc_score=float(doc_score),
|
||||||
|
patch_count=len(patches),
|
||||||
|
best_patch=best_patch,
|
||||||
|
all_patches=sorted(patches, key=lambda p: p.score, reverse=True),
|
||||||
|
aggregation_method=self.aggregation_method,
|
||||||
|
spatial_clusters=spatial_clusters,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _cluster_patches_spatially(self, patches: list[PatchResult]) -> list[list[PatchResult]]:
|
||||||
|
"""Cluster patches that are spatially close to each other."""
|
||||||
|
if len(patches) <= 1:
|
||||||
|
return [patches]
|
||||||
|
|
||||||
|
clusters = []
|
||||||
|
remaining_patches = patches.copy()
|
||||||
|
|
||||||
|
while remaining_patches:
|
||||||
|
# Start new cluster with highest scoring remaining patch
|
||||||
|
seed_patch = max(remaining_patches, key=lambda p: p.score)
|
||||||
|
current_cluster = [seed_patch]
|
||||||
|
remaining_patches.remove(seed_patch)
|
||||||
|
|
||||||
|
# Add nearby patches to cluster
|
||||||
|
added_to_cluster = True
|
||||||
|
while added_to_cluster:
|
||||||
|
added_to_cluster = False
|
||||||
|
for patch in remaining_patches.copy():
|
||||||
|
if self._is_patch_nearby(patch, current_cluster):
|
||||||
|
current_cluster.append(patch)
|
||||||
|
remaining_patches.remove(patch)
|
||||||
|
added_to_cluster = True
|
||||||
|
|
||||||
|
clusters.append(current_cluster)
|
||||||
|
|
||||||
|
return sorted(clusters, key=lambda cluster: max(p.score for p in cluster), reverse=True)
|
||||||
|
|
||||||
|
def _is_patch_nearby(self, patch: PatchResult, cluster: list[PatchResult]) -> bool:
|
||||||
|
"""Check if a patch is spatially close to any patch in the cluster."""
|
||||||
|
patch_center = self._get_patch_center(patch.coordinates)
|
||||||
|
|
||||||
|
for cluster_patch in cluster:
|
||||||
|
cluster_center = self._get_patch_center(cluster_patch.coordinates)
|
||||||
|
distance = np.sqrt(
|
||||||
|
(patch_center[0] - cluster_center[0]) ** 2
|
||||||
|
+ (patch_center[1] - cluster_center[1]) ** 2
|
||||||
|
)
|
||||||
|
|
||||||
|
if distance <= self.cluster_distance_threshold:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _get_patch_center(self, coordinates: tuple[int, int, int, int]) -> tuple[float, float]:
|
||||||
|
"""Get center point of a patch."""
|
||||||
|
x1, y1, x2, y2 = coordinates
|
||||||
|
return ((x1 + x2) / 2, (y1 + y2) / 2)
|
||||||
|
|
||||||
|
def print_aggregated_results(
|
||||||
|
self, results: list[AggregatedResult], max_patches_per_doc: int = 3
|
||||||
|
):
|
||||||
|
"""Pretty print aggregated results."""
|
||||||
|
print(f"\n🔍 Aggregated Results (method: {self.aggregation_method})")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
for i, result in enumerate(results):
|
||||||
|
print(f"\n{i + 1}. {result.image_name}")
|
||||||
|
print(f" Doc Score: {result.doc_score:.4f} | Patches: {result.patch_count}")
|
||||||
|
print(f" Path: {result.image_path}")
|
||||||
|
|
||||||
|
# Show best patch
|
||||||
|
best = result.best_patch
|
||||||
|
print(
|
||||||
|
f" 🌟 Best Patch: #{best.patch_id} at {best.coordinates} (score: {best.score:.4f})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Show top patches
|
||||||
|
print(" 📍 Top Patches:")
|
||||||
|
for j, patch in enumerate(result.all_patches[:max_patches_per_doc]):
|
||||||
|
print(
|
||||||
|
f" {j + 1}. Patch #{patch.patch_id}: {patch.score:.4f} at {patch.coordinates}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Show spatial clusters if available
|
||||||
|
if result.spatial_clusters and len(result.spatial_clusters) > 1:
|
||||||
|
print(f" 🗂️ Spatial Clusters: {len(result.spatial_clusters)}")
|
||||||
|
for j, cluster in enumerate(result.spatial_clusters[:2]): # Show top 2 clusters
|
||||||
|
cluster_score = max(p.score for p in cluster)
|
||||||
|
print(
|
||||||
|
f" Cluster {j + 1}: {len(cluster)} patches (best: {cluster_score:.4f})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def demo_aggregation():
|
||||||
|
"""Demonstrate the multi-vector aggregation functionality."""
|
||||||
|
print("=== Multi-Vector Aggregation Demo ===")
|
||||||
|
|
||||||
|
# Simulate some patch-level search results
|
||||||
|
# In real usage, these would come from LeannSearcher.search()
|
||||||
|
|
||||||
|
class MockResult:
|
||||||
|
def __init__(self, score, metadata):
|
||||||
|
self.score = score
|
||||||
|
self.metadata = metadata
|
||||||
|
|
||||||
|
# Simulate results for 2 images with multiple patches each
|
||||||
|
mock_results = [
|
||||||
|
# Image 1: cats_and_kitchen.jpg - 4 patches
|
||||||
|
MockResult(
|
||||||
|
0.85,
|
||||||
|
{
|
||||||
|
"image_name": "cats_and_kitchen.jpg",
|
||||||
|
"image_path": "/path/to/cats_and_kitchen.jpg",
|
||||||
|
"patch_id": 3,
|
||||||
|
"coordinates": [100, 50, 224, 174], # Kitchen area
|
||||||
|
"attention_score": 0.92,
|
||||||
|
"scale": 1.0,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
MockResult(
|
||||||
|
0.78,
|
||||||
|
{
|
||||||
|
"image_name": "cats_and_kitchen.jpg",
|
||||||
|
"image_path": "/path/to/cats_and_kitchen.jpg",
|
||||||
|
"patch_id": 7,
|
||||||
|
"coordinates": [200, 300, 324, 424], # Cat area
|
||||||
|
"attention_score": 0.88,
|
||||||
|
"scale": 1.0,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
MockResult(
|
||||||
|
0.72,
|
||||||
|
{
|
||||||
|
"image_name": "cats_and_kitchen.jpg",
|
||||||
|
"image_path": "/path/to/cats_and_kitchen.jpg",
|
||||||
|
"patch_id": 12,
|
||||||
|
"coordinates": [150, 100, 274, 224], # Appliances
|
||||||
|
"attention_score": 0.75,
|
||||||
|
"scale": 1.0,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
MockResult(
|
||||||
|
0.65,
|
||||||
|
{
|
||||||
|
"image_name": "cats_and_kitchen.jpg",
|
||||||
|
"image_path": "/path/to/cats_and_kitchen.jpg",
|
||||||
|
"patch_id": 15,
|
||||||
|
"coordinates": [50, 250, 174, 374], # Furniture
|
||||||
|
"attention_score": 0.70,
|
||||||
|
"scale": 1.0,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
# Image 2: city_street.jpg - 3 patches
|
||||||
|
MockResult(
|
||||||
|
0.68,
|
||||||
|
{
|
||||||
|
"image_name": "city_street.jpg",
|
||||||
|
"image_path": "/path/to/city_street.jpg",
|
||||||
|
"patch_id": 2,
|
||||||
|
"coordinates": [300, 100, 424, 224], # Buildings
|
||||||
|
"attention_score": 0.80,
|
||||||
|
"scale": 1.0,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
MockResult(
|
||||||
|
0.62,
|
||||||
|
{
|
||||||
|
"image_name": "city_street.jpg",
|
||||||
|
"image_path": "/path/to/city_street.jpg",
|
||||||
|
"patch_id": 8,
|
||||||
|
"coordinates": [100, 350, 224, 474], # Street level
|
||||||
|
"attention_score": 0.75,
|
||||||
|
"scale": 1.0,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
MockResult(
|
||||||
|
0.55,
|
||||||
|
{
|
||||||
|
"image_name": "city_street.jpg",
|
||||||
|
"image_path": "/path/to/city_street.jpg",
|
||||||
|
"patch_id": 11,
|
||||||
|
"coordinates": [400, 200, 524, 324], # Sky area
|
||||||
|
"attention_score": 0.60,
|
||||||
|
"scale": 1.0,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Test different aggregation methods
|
||||||
|
methods = ["maxsim", "voting", "weighted", "mean"]
|
||||||
|
|
||||||
|
for method in methods:
|
||||||
|
print(f"\n{'=' * 20} {method.upper()} AGGREGATION {'=' * 20}")
|
||||||
|
|
||||||
|
aggregator = MultiVectorAggregator(
|
||||||
|
aggregation_method=method,
|
||||||
|
spatial_clustering=True,
|
||||||
|
cluster_distance_threshold=100.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
aggregated = aggregator.aggregate_results(mock_results, top_k=5)
|
||||||
|
aggregator.print_aggregated_results(aggregated)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
demo_aggregation()
|
||||||
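In real use, the MockResult objects above would be replaced by patch-level hits from a LEANN index. The sketch below is illustrative and not part of this change: the index path is hypothetical, and the index is assumed to store the per-patch metadata (image_name, image_path, patch_id, coordinates, attention_score) that aggregate_results expects.

# Sketch only: aggregate patch-level LEANN hits into document-level results.
# "./fat_embeddings_index.leann" is a hypothetical path for a patch-level index.
from leann.api import LeannSearcher

from examples.multi_vector_aggregator import MultiVectorAggregator  # the module above

searcher = LeannSearcher("./fat_embeddings_index.leann")
patch_hits = searcher.search("a cat sitting in a kitchen", top_k=50)  # patch-level hits

aggregator = MultiVectorAggregator(aggregation_method="maxsim", spatial_clustering=True)
doc_hits = aggregator.aggregate_results(patch_hits, top_k=5)
aggregator.print_aggregated_results(doc_hits)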
examples/openai_hnsw_example.py (new file, 113 lines)
@@ -0,0 +1,113 @@
#!/usr/bin/env python3
"""
OpenAI Embedding Example

Complete example showing how to build and search with OpenAI embeddings using HNSW backend.
"""

import os
from pathlib import Path

import dotenv
from leann.api import LeannBuilder, LeannSearcher

# Load environment variables
dotenv.load_dotenv()


def main():
    # Check if OpenAI API key is available
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        print("ERROR: OPENAI_API_KEY environment variable not set")
        return False

    print(f"✅ OpenAI API key found: {api_key[:10]}...")

    # Sample texts
    sample_texts = [
        "Machine learning is a powerful technology that enables computers to learn from data.",
        "Natural language processing helps computers understand and generate human language.",
        "Deep learning uses neural networks with multiple layers to solve complex problems.",
        "Computer vision allows machines to interpret and understand visual information.",
        "Reinforcement learning trains agents to make decisions through trial and error.",
        "Data science combines statistics, math, and programming to extract insights from data.",
        "Artificial intelligence aims to create machines that can perform human-like tasks.",
        "Python is a popular programming language used extensively in data science and AI.",
        "Neural networks are inspired by the structure and function of the human brain.",
        "Big data refers to extremely large datasets that require special tools to process.",
    ]

    INDEX_DIR = Path("./simple_openai_test_index")
    INDEX_PATH = str(INDEX_DIR / "simple_test.leann")

    print("\n=== Building Index with OpenAI Embeddings ===")
    print(f"Index path: {INDEX_PATH}")

    try:
        # Use proper configuration for OpenAI embeddings
        builder = LeannBuilder(
            backend_name="hnsw",
            embedding_model="text-embedding-3-small",
            embedding_mode="openai",
            # HNSW settings for OpenAI embeddings
            M=16,  # Smaller graph degree
            efConstruction=64,  # Smaller construction complexity
            is_compact=True,  # Enable compact storage for recompute
            is_recompute=True,  # MUST enable for OpenAI embeddings
            num_threads=1,
        )

        print(f"Adding {len(sample_texts)} texts to the index...")
        for i, text in enumerate(sample_texts):
            metadata = {"id": f"doc_{i}", "topic": "AI"}
            builder.add_text(text, metadata)

        print("Building index...")
        builder.build_index(INDEX_PATH)
        print("✅ Index built successfully!")

    except Exception as e:
        print(f"❌ Error building index: {e}")
        import traceback

        traceback.print_exc()
        return False

    print("\n=== Testing Search ===")

    try:
        searcher = LeannSearcher(INDEX_PATH)

        test_queries = [
            "What is machine learning?",
            "How do neural networks work?",
            "Programming languages for data science",
        ]

        for query in test_queries:
            print(f"\n🔍 Query: '{query}'")
            results = searcher.search(query, top_k=3)

            print(f"   Found {len(results)} results:")
            for i, result in enumerate(results):
                print(f"   {i + 1}. Score: {result.score:.4f}")
                print(f"      Text: {result.text[:80]}...")

        print("\n✅ Search test completed successfully!")
        return True

    except Exception as e:
        print(f"❌ Error during search: {e}")
        import traceback

        traceback.print_exc()
        return False


if __name__ == "__main__":
    success = main()
    if success:
        print("\n🎉 Simple OpenAI index test completed successfully!")
    else:
        print("\n💥 Simple OpenAI index test failed!")
examples/resue_index.py (new file, 23 lines)
@@ -0,0 +1,23 @@
import asyncio
from pathlib import Path

from leann.api import LeannChat

INDEX_DIR = Path("./test_pdf_index_huawei")
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")


async def main():
    print("\n[PHASE 2] Starting Leann chat session...")
    chat = LeannChat(index_path=INDEX_PATH)
    query = "What is the main idea of RL and give me 5 examples of classic RL algorithms?"
    query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explores to achieve the fairness and efficiency trade-off?"
    # query = "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"  # (Chinese test query about the Pangu model)
    response = chat.ask(
        query, top_k=20, recompute_beighbor_embeddings=True, complexity=32, beam_width=1
    )
    print(f"\n[PHASE 2] Response: {response}")


if __name__ == "__main__":
    asyncio.run(main())
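This script presupposes an existing index at ./test_pdf_index_huawei. The sketch below is illustrative, not part of this change: it rebuilds such an index by reusing the HNSW builder settings from the other examples in this diff, and the chunk texts are placeholders.

# Sketch only: rebuild the index that resue_index.py expects. The chunk texts are
# placeholders; real usage would chunk the source PDFs first (see main_cli_example.py).
from leann.api import LeannBuilder

chunks = ["example chunk 1", "example chunk 2"]  # placeholder text chunks
builder = LeannBuilder(
    backend_name="hnsw",
    embedding_model="facebook/contriever",
    graph_degree=32,
    complexity=64,
    is_compact=True,
    is_recompute=True,
    num_threads=1,
)
for chunk in chunks:
    builder.add_text(chunk)
builder.build_index("./test_pdf_index_huawei/pdf_documents.leann")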
@@ -12,7 +12,7 @@ import time
 from pathlib import Path
 
 import numpy as np
-from leann.api import LeannBuilder, LeannChat, LeannSearcher
+from leann.api import LeannBuilder, LeannSearcher
 
 
 def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
@@ -197,32 +197,13 @@ def main():
     parser.add_argument(
         "--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
     )
-    parser.add_argument(
-        "--batch-size",
-        type=int,
-        default=0,
-        help="Batch size for HNSW batched search (0 disables batching)",
-    )
-    parser.add_argument(
-        "--llm-type",
-        type=str,
-        choices=["ollama", "hf", "openai", "gemini", "simulated"],
-        default="ollama",
-        help="LLM backend type to optionally query during evaluation (default: ollama)",
-    )
-    parser.add_argument(
-        "--llm-model",
-        type=str,
-        default="qwen3:1.7b",
-        help="LLM model identifier for the chosen backend (default: qwen3:1.7b)",
-    )
     args = parser.parse_args()
 
     # --- Path Configuration ---
-    # Assumes a project structure where the script is in 'benchmarks/'
-    # and evaluation data is in 'benchmarks/data/'.
-    script_dir = Path(__file__).resolve().parent
-    data_root = script_dir / "data"
+    # Assumes a project structure where the script is in 'examples/'
+    # and data is in 'data/' at the project root.
+    project_root = Path(__file__).resolve().parent.parent
+    data_root = project_root / "data"
 
     # Download data based on mode
     if args.mode == "build":
@@ -298,9 +279,7 @@ def main():
 
     if not args.index_path:
         print("No indices found. The data download should have included pre-built indices.")
-        print(
-            "Please check the benchmarks/data/indices/ directory or provide --index-path manually."
-        )
+        print("Please check the data/indices/ directory or provide --index-path manually.")
         sys.exit(1)
 
     # Detect dataset type from index path to select the correct ground truth
@@ -337,24 +316,9 @@ def main():
 
     for i in range(num_eval_queries):
         start_time = time.time()
-        new_results = searcher.search(
-            queries[i],
-            top_k=args.top_k,
-            complexity=args.ef_search,
-            batch_size=args.batch_size,
-        )
+        new_results = searcher.search(queries[i], top_k=args.top_k, ef=args.ef_search)
         search_times.append(time.time() - start_time)
 
-        # Optional: also call the LLM with configurable backend/model (does not affect recall)
-        llm_config = {"type": args.llm_type, "model": args.llm_model}
-        chat = LeannChat(args.index_path, llm_config=llm_config, searcher=searcher)
-        answer = chat.ask(
-            queries[i],
-            top_k=args.top_k,
-            complexity=args.ef_search,
-            batch_size=args.batch_size,
-        )
-        print(f"Answer: {answer}")
         # Correct Recall Calculation: Based on TEXT content
         new_texts = {result.text for result in new_results}
@@ -1,6 +1,6 @@
 """
 Simple demo showing basic leann usage
-Run: uv run python examples/basic_demo.py
+Run: uv run python examples/simple_demo.py
 """
 
 import argparse
@@ -81,7 +81,7 @@
     print()
 
     print("Demo completed! Try running:")
-    print("  uv run python apps/document_rag.py")
+    print("  uv run python examples/document_search.py")
 
 
 if __name__ == "__main__":
@@ -1,250 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Spoiler-Free Book RAG Example using LEANN Metadata Filtering
|
|
||||||
|
|
||||||
This example demonstrates how to use LEANN's metadata filtering to create
|
|
||||||
a spoiler-free book RAG system where users can search for information
|
|
||||||
up to a specific chapter they've read.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
python spoiler_free_book_rag.py
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
# Add LEANN to path (adjust path as needed)
|
|
||||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../packages/leann-core/src"))
|
|
||||||
|
|
||||||
from leann.api import LeannBuilder, LeannSearcher
|
|
||||||
|
|
||||||
|
|
||||||
def chunk_book_with_metadata(book_title: str = "Sample Book") -> list[dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Create sample book chunks with metadata for demonstration.
|
|
||||||
|
|
||||||
In a real implementation, this would parse actual book files (epub, txt, etc.)
|
|
||||||
and extract chapter boundaries, character mentions, etc.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
book_title: Title of the book
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of chunk dictionaries with text and metadata
|
|
||||||
"""
|
|
||||||
# Sample book chunks with metadata
|
|
||||||
# In practice, you'd use proper text processing libraries
|
|
||||||
|
|
||||||
sample_chunks = [
|
|
||||||
{
|
|
||||||
"text": "Alice was beginning to get very tired of sitting by her sister on the bank, and of having nothing to do.",
|
|
||||||
"metadata": {
|
|
||||||
"book": book_title,
|
|
||||||
"chapter": 1,
|
|
||||||
"page": 1,
|
|
||||||
"characters": ["Alice", "Sister"],
|
|
||||||
"themes": ["boredom", "curiosity"],
|
|
||||||
"location": "riverbank",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"text": "So she was considering in her own mind (as well as she could, for the hot day made her feel very sleepy and stupid), whether the pleasure of making a daisy-chain would be worth the trouble of getting up and picking the daisies, when suddenly a White Rabbit with pink eyes ran close by her.",
|
|
||||||
"metadata": {
|
|
||||||
"book": book_title,
|
|
||||||
"chapter": 1,
|
|
||||||
"page": 2,
|
|
||||||
"characters": ["Alice", "White Rabbit"],
|
|
||||||
"themes": ["decision", "surprise", "magic"],
|
|
||||||
"location": "riverbank",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"text": "Alice found herself falling down a very deep well. Either the well was very deep, or she fell very slowly, for she had plenty of time as she fell to look about her and to wonder what was going to happen next.",
|
|
||||||
"metadata": {
|
|
||||||
"book": book_title,
|
|
||||||
"chapter": 2,
|
|
||||||
"page": 15,
|
|
||||||
"characters": ["Alice"],
|
|
||||||
"themes": ["falling", "wonder", "transformation"],
|
|
||||||
"location": "rabbit hole",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"text": "Alice meets the Cheshire Cat, who tells her that everyone in Wonderland is mad, including Alice herself.",
|
|
||||||
"metadata": {
|
|
||||||
"book": book_title,
|
|
||||||
"chapter": 6,
|
|
||||||
"page": 85,
|
|
||||||
"characters": ["Alice", "Cheshire Cat"],
|
|
||||||
"themes": ["madness", "philosophy", "identity"],
|
|
||||||
"location": "Duchess's house",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"text": "At the Queen's croquet ground, Alice witnesses the absurd trial that reveals the arbitrary nature of Wonderland's justice system.",
|
|
||||||
"metadata": {
|
|
||||||
"book": book_title,
|
|
||||||
"chapter": 8,
|
|
||||||
"page": 120,
|
|
||||||
"characters": ["Alice", "Queen of Hearts", "King of Hearts"],
|
|
||||||
"themes": ["justice", "absurdity", "authority"],
|
|
||||||
"location": "Queen's court",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"text": "Alice realizes that Wonderland was all a dream, even the Rabbit, as she wakes up on the riverbank next to her sister.",
|
|
||||||
"metadata": {
|
|
||||||
"book": book_title,
|
|
||||||
"chapter": 12,
|
|
||||||
"page": 180,
|
|
||||||
"characters": ["Alice", "Sister", "Rabbit"],
|
|
||||||
"themes": ["revelation", "reality", "growth"],
|
|
||||||
"location": "riverbank",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
return sample_chunks
|
|
||||||
|
|
||||||
|
|
||||||
def build_spoiler_free_index(book_chunks: list[dict[str, Any]], index_name: str) -> str:
|
|
||||||
"""
|
|
||||||
Build a LEANN index with book chunks that include spoiler metadata.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
book_chunks: List of book chunks with metadata
|
|
||||||
index_name: Name for the index
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Path to the built index
|
|
||||||
"""
|
|
||||||
print(f"📚 Building spoiler-free book index: {index_name}")
|
|
||||||
|
|
||||||
# Initialize LEANN builder
|
|
||||||
builder = LeannBuilder(
|
|
||||||
backend_name="hnsw", embedding_model="text-embedding-3-small", embedding_mode="openai"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add each chunk with its metadata
|
|
||||||
for chunk in book_chunks:
|
|
||||||
builder.add_text(text=chunk["text"], metadata=chunk["metadata"])
|
|
||||||
|
|
||||||
# Build the index
|
|
||||||
index_path = f"{index_name}_book_index"
|
|
||||||
builder.build_index(index_path)
|
|
||||||
|
|
||||||
print(f"✅ Index built successfully: {index_path}")
|
|
||||||
return index_path
|
|
||||||
|
|
||||||
|
|
||||||
def spoiler_free_search(
|
|
||||||
index_path: str,
|
|
||||||
query: str,
|
|
||||||
max_chapter: int,
|
|
||||||
character_filter: Optional[list[str]] = None,
|
|
||||||
) -> list[dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Perform a spoiler-free search on the book index.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
index_path: Path to the LEANN index
|
|
||||||
query: Search query
|
|
||||||
max_chapter: Maximum chapter number to include
|
|
||||||
character_filter: Optional list of characters to focus on
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of search results safe for the reader
|
|
||||||
"""
|
|
||||||
print(f"🔍 Searching: '{query}' (up to chapter {max_chapter})")
|
|
||||||
|
|
||||||
searcher = LeannSearcher(index_path)
|
|
||||||
|
|
||||||
metadata_filters = {"chapter": {"<=": max_chapter}}
|
|
||||||
|
|
||||||
if character_filter:
|
|
||||||
metadata_filters["characters"] = {"contains": character_filter[0]}
|
|
||||||
|
|
||||||
results = searcher.search(query=query, top_k=10, metadata_filters=metadata_filters)
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def demo_spoiler_free_rag():
|
|
||||||
"""
|
|
||||||
Demonstrate the spoiler-free book RAG system.
|
|
||||||
"""
|
|
||||||
print("🎭 Spoiler-Free Book RAG Demo")
|
|
||||||
print("=" * 40)
|
|
||||||
|
|
||||||
# Step 1: Prepare book data
|
|
||||||
book_title = "Alice's Adventures in Wonderland"
|
|
||||||
book_chunks = chunk_book_with_metadata(book_title)
|
|
||||||
|
|
||||||
print(f"📖 Loaded {len(book_chunks)} chunks from '{book_title}'")
|
|
||||||
|
|
||||||
# Step 2: Build the index (in practice, this would be done once)
|
|
||||||
try:
|
|
||||||
index_path = build_spoiler_free_index(book_chunks, "alice_wonderland")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ Failed to build index (likely missing dependencies): {e}")
|
|
||||||
print(
|
|
||||||
"💡 This demo shows the filtering logic - actual indexing requires LEANN dependencies"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Step 3: Demonstrate various spoiler-free searches
|
|
||||||
search_scenarios = [
|
|
||||||
{
|
|
||||||
"description": "Reader who has only read Chapter 1",
|
|
||||||
"query": "What can you tell me about the rabbit?",
|
|
||||||
"max_chapter": 1,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"description": "Reader who has read up to Chapter 5",
|
|
||||||
"query": "Tell me about Alice's adventures",
|
|
||||||
"max_chapter": 5,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"description": "Reader who has read most of the book",
|
|
||||||
"query": "What does the Cheshire Cat represent?",
|
|
||||||
"max_chapter": 10,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"description": "Reader who has read the whole book",
|
|
||||||
"query": "What can you tell me about the rabbit?",
|
|
||||||
"max_chapter": 12,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
for scenario in search_scenarios:
|
|
||||||
print(f"\n📚 Scenario: {scenario['description']}")
|
|
||||||
print(f" Query: {scenario['query']}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
results = spoiler_free_search(
|
|
||||||
index_path=index_path,
|
|
||||||
query=scenario["query"],
|
|
||||||
max_chapter=scenario["max_chapter"],
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f" 📄 Found {len(results)} results:")
|
|
||||||
for i, result in enumerate(results[:3], 1): # Show top 3
|
|
||||||
chapter = result.metadata.get("chapter", "?")
|
|
||||||
location = result.metadata.get("location", "?")
|
|
||||||
print(f" {i}. Chapter {chapter} ({location}): {result.text[:80]}...")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f" ❌ Search failed: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
print("📚 LEANN Spoiler-Free Book RAG Example")
|
|
||||||
print("=====================================")
|
|
||||||
|
|
||||||
try:
|
|
||||||
demo_spoiler_free_rag()
|
|
||||||
except ImportError as e:
|
|
||||||
print(f"❌ Cannot run demo due to missing dependencies: {e}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ Error running demo: {e}")
|
|
||||||
examples/wechat_history_reader_leann.py (new file, 320 lines)
@@ -0,0 +1,320 @@
|
|||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import dotenv
|
||||||
|
from leann.api import LeannBuilder, LeannChat
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
|
# Default WeChat export directory
|
||||||
|
DEFAULT_WECHAT_EXPORT_DIR = "./wechat_export_direct"
|
||||||
|
|
||||||
|
|
||||||
|
def create_leann_index_from_multiple_wechat_exports(
|
||||||
|
export_dirs: list[Path],
|
||||||
|
index_path: str = "wechat_history_index.leann",
|
||||||
|
max_count: int = -1,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create LEANN index from multiple WeChat export data sources.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
export_dirs: List of Path objects pointing to WeChat export directories
|
||||||
|
index_path: Path to save the LEANN index
|
||||||
|
max_count: Maximum number of chat entries to process per export
|
||||||
|
"""
|
||||||
|
print("Creating LEANN index from multiple WeChat export data sources...")
|
||||||
|
|
||||||
|
# Load documents using WeChatHistoryReader from history_data
|
||||||
|
from history_data.wechat_history import WeChatHistoryReader
|
||||||
|
|
||||||
|
reader = WeChatHistoryReader()
|
||||||
|
|
||||||
|
INDEX_DIR = Path(index_path).parent
|
||||||
|
|
||||||
|
if not INDEX_DIR.exists():
|
||||||
|
print("--- Index directory not found, building new index ---")
|
||||||
|
all_documents = []
|
||||||
|
total_processed = 0
|
||||||
|
|
||||||
|
# Process each WeChat export directory
|
||||||
|
for i, export_dir in enumerate(export_dirs):
|
||||||
|
print(f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
documents = reader.load_data(
|
||||||
|
wechat_export_dir=str(export_dir),
|
||||||
|
max_count=max_count,
|
||||||
|
concatenate_messages=True,  # Concatenate nearby messages into combined documents
|
||||||
|
)
|
||||||
|
if documents:
|
||||||
|
print(f"Loaded {len(documents)} chat documents from {export_dir}")
|
||||||
|
all_documents.extend(documents)
|
||||||
|
total_processed += len(documents)
|
||||||
|
|
||||||
|
# Check if we've reached the max count
|
||||||
|
if max_count > 0 and total_processed >= max_count:
|
||||||
|
print(f"Reached max count of {max_count} documents")
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
print(f"No documents loaded from {export_dir}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing {export_dir}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not all_documents:
|
||||||
|
print("No documents loaded from any source. Exiting.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports and starting to split them into chunks"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create text splitter with 256 chunk size
|
||||||
|
text_splitter = SentenceSplitter(chunk_size=192, chunk_overlap=64)
|
||||||
|
|
||||||
|
# Convert Documents to text strings and chunk them
|
||||||
|
all_texts = []
|
||||||
|
for doc in all_documents:
|
||||||
|
# Split the document into chunks
|
||||||
|
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||||
|
for node in nodes:
|
||||||
|
text = (
|
||||||
|
"[Contact] means the message is from: "
|
||||||
|
+ doc.metadata["contact_name"]
|
||||||
|
+ "\n"
|
||||||
|
+ node.get_content()
|
||||||
|
)
|
||||||
|
all_texts.append(text)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Finished splitting {len(all_documents)} documents into {len(all_texts)} text chunks"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create LEANN index directory
|
||||||
|
print("--- Index directory not found, building new index ---")
|
||||||
|
INDEX_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
print("--- Building new LEANN index ---")
|
||||||
|
|
||||||
|
print("\n[PHASE 1] Building Leann index...")
|
||||||
|
|
||||||
|
# Use HNSW backend for better macOS compatibility
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="Qwen/Qwen3-Embedding-0.6B",
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
is_compact=True,
|
||||||
|
is_recompute=True,
|
||||||
|
num_threads=1, # Force single-threaded mode
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Adding {len(all_texts)} chat chunks to index...")
|
||||||
|
for chunk_text in all_texts:
|
||||||
|
builder.add_text(chunk_text)
|
||||||
|
|
||||||
|
builder.build_index(index_path)
|
||||||
|
print(f"\nLEANN index built at {index_path}!")
|
||||||
|
else:
|
||||||
|
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||||
|
|
||||||
|
return index_path
|
||||||
|
|
||||||
|
|
||||||
|
def create_leann_index(
|
||||||
|
export_dir: str | None = None,
|
||||||
|
index_path: str = "wechat_history_index.leann",
|
||||||
|
max_count: int = 1000,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create LEANN index from WeChat chat history data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
export_dir: Path to the WeChat export directory (optional, uses default if None)
|
||||||
|
index_path: Path to save the LEANN index
|
||||||
|
max_count: Maximum number of chat entries to process
|
||||||
|
"""
|
||||||
|
print("Creating LEANN index from WeChat chat history data...")
|
||||||
|
INDEX_DIR = Path(index_path).parent
|
||||||
|
|
||||||
|
if not INDEX_DIR.exists():
|
||||||
|
print("--- Index directory not found, building new index ---")
|
||||||
|
INDEX_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
print("--- Building new LEANN index ---")
|
||||||
|
|
||||||
|
print("\n[PHASE 1] Building Leann index...")
|
||||||
|
|
||||||
|
# Load documents using WeChatHistoryReader from history_data
from history_data.wechat_history import WeChatHistoryReader

reader = WeChatHistoryReader()

documents = reader.load_data(
wechat_export_dir=export_dir,
max_count=max_count,
concatenate_messages=False,  # Disable concatenation - one message per document
)

if not documents:
print("No documents loaded. Exiting.")
return None

print(f"Loaded {len(documents)} chat documents")

# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)

# Convert Documents to text strings and chunk them
all_texts = []
for doc in documents:
# Split the document into chunks
nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())

print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")

# Create LEANN index directory
print("--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True)

print("--- Building new LEANN index ---")

print("\n[PHASE 1] Building Leann index...")

# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ",  # MLX-optimized model
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1,  # Force single-threaded mode
)

print(f"Adding {len(all_texts)} chat chunks to index...")
for chunk_text in all_texts:
builder.add_text(chunk_text)

builder.build_index(index_path)
print(f"\nLEANN index built at {index_path}!")
else:
print(f"--- Using existing index at {INDEX_DIR} ---")

return index_path


async def query_leann_index(index_path: str, query: str):
"""
Query the LEANN index.

Args:
index_path: Path to the LEANN index
query: The query string
"""
print("\n[PHASE 2] Starting Leann chat session...")
chat = LeannChat(index_path=index_path)

print(f"You: {query}")
chat_response = chat.ask(
query,
top_k=20,
recompute_beighbor_embeddings=True,
complexity=16,
beam_width=1,
llm_config={
"type": "openai",
"model": "gpt-4o",
"api_key": os.getenv("OPENAI_API_KEY"),
},
llm_kwargs={"temperature": 0.0, "max_tokens": 1000},
)
print(f"Leann chat response: \033[36m{chat_response}\033[0m")


async def main():
"""Main function with integrated WeChat export functionality."""

# Parse command line arguments
parser = argparse.ArgumentParser(
description="LEANN WeChat History Reader - Create and query WeChat chat history index"
)
parser.add_argument(
"--export-dir",
type=str,
default=DEFAULT_WECHAT_EXPORT_DIR,
help=f"Directory to store WeChat exports (default: {DEFAULT_WECHAT_EXPORT_DIR})",
)
parser.add_argument(
"--index-dir",
type=str,
default="./wechat_history_magic_test_11Debug_new",
help="Directory to store the LEANN index (default: ./wechat_history_magic_test_11Debug_new)",
)
parser.add_argument(
"--max-entries",
type=int,
default=50,
help="Maximum number of chat entries to process (default: 50)",
)
parser.add_argument(
"--query",
type=str,
default=None,
help="Single query to run (default: runs example queries)",
)
parser.add_argument(
"--force-export",
action="store_true",
default=False,
help="Force re-export of WeChat data even if exports exist",
)

args = parser.parse_args()

INDEX_DIR = Path(args.index_dir)
INDEX_PATH = str(INDEX_DIR / "wechat_history.leann")

print(f"Using WeChat export directory: {args.export_dir}")
print(f"Index directory: {INDEX_DIR}")
print(f"Max entries: {args.max_entries}")

# Initialize WeChat reader with export capabilities
from history_data.wechat_history import WeChatHistoryReader

reader = WeChatHistoryReader()

# Find existing exports or create new ones using the centralized method
export_dirs = reader.find_or_export_wechat_data(args.export_dir)
if not export_dirs:
print("Failed to find or export WeChat data. Exiting.")
return

# Create or load the LEANN index from all sources
index_path = create_leann_index_from_multiple_wechat_exports(
export_dirs, INDEX_PATH, max_count=args.max_entries
)

if index_path:
if args.query:
# Run single query
await query_leann_index(index_path, args.query)
else:
# Example queries
queries = [
"我想买魔术师约翰逊的球衣,给我一些对应聊天记录?",  # "I want to buy a Magic Johnson jersey; show me the related chat records?"
]

for query in queries:
print("\n" + "=" * 60)
await query_leann_index(index_path, query)


if __name__ == "__main__":
asyncio.run(main())
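A quick way to exercise the example above once an index exists is to call query_leann_index directly instead of going through argparse. A minimal sketch, not part of the diff; the index path is just the script's default and OPENAI_API_KEY must be set in the environment:

import asyncio

# One-off query against an index previously built by this script (path is the script's default)
asyncio.run(
    query_leann_index(
        "./wechat_history_magic_test_11Debug_new/wechat_history.leann",
        "What did we say about jerseys?",
    )
)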
8
packages/leann-backend-diskann/CMakeLists.txt
Normal file
@@ -0,0 +1,8 @@
# packages/leann-backend-diskann/CMakeLists.txt (simplified version)

cmake_minimum_required(VERSION 3.20)
project(leann_backend_diskann_wrapper)

# Tell CMake to directly enter the DiskANN submodule and execute its own CMakeLists.txt
# DiskANN will handle everything itself, including compiling Python bindings
add_subdirectory(src/third_party/DiskANN)
@@ -1,7 +1 @@
from . import diskann_backend as diskann_backend
from . import graph_partition

# Export main classes and functions
from .graph_partition import GraphPartitioner, partition_graph

__all__ = ["GraphPartitioner", "diskann_backend", "graph_partition", "partition_graph"]
@@ -4,7 +4,7 @@ import os
import struct
import sys
from pathlib import Path
from typing import Any, Literal, Optional
from typing import Any, Literal

import numpy as np
import psutil
@@ -22,11 +22,6 @@ logger = logging.getLogger(__name__)
@contextlib.contextmanager
def suppress_cpp_output_if_needed():
"""Suppress C++ stdout/stderr based on LEANN_LOG_LEVEL"""
# In CI we avoid fiddling with low-level file descriptors to prevent aborts
if os.getenv("CI") == "true":
yield
return

log_level = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()

# Only suppress if log level is WARNING or higher (ERROR, CRITICAL)
@@ -142,71 +137,6 @@ class DiskannBuilder(LeannBackendBuilderInterface):
def __init__(self, **kwargs):
self.build_params = kwargs

def _safe_cleanup_after_partition(self, index_dir: Path, index_prefix: str):
"""
Safely cleanup files after partition.
In partition mode, C++ doesn't read _disk.index content,
so we can delete it if all derived files exist.
"""
disk_index_file = index_dir / f"{index_prefix}_disk.index"
beam_search_file = index_dir / f"{index_prefix}_disk_beam_search.index"

# Required files that C++ partition mode needs
# Note: C++ generates these with _disk.index suffix
disk_suffix = "_disk.index"
required_files = [
f"{index_prefix}{disk_suffix}_medoids.bin",  # Critical: assert fails if missing
# Note: _centroids.bin is not created in single-shot build - C++ handles this automatically
f"{index_prefix}_pq_pivots.bin",  # PQ table
f"{index_prefix}_pq_compressed.bin",  # PQ compressed vectors
]

# Check if all required files exist
missing_files = []
for filename in required_files:
file_path = index_dir / filename
if not file_path.exists():
missing_files.append(filename)

if missing_files:
logger.warning(
f"Cannot safely delete _disk.index - missing required files: {missing_files}"
)
logger.info("Keeping all original files for safety")
return

# Calculate space savings
space_saved = 0
files_to_delete = []

if disk_index_file.exists():
space_saved += disk_index_file.stat().st_size
files_to_delete.append(disk_index_file)

if beam_search_file.exists():
space_saved += beam_search_file.stat().st_size
files_to_delete.append(beam_search_file)

# Safe to delete!
for file_to_delete in files_to_delete:
try:
os.remove(file_to_delete)
logger.info(f"✅ Safely deleted: {file_to_delete.name}")
except Exception as e:
logger.warning(f"Failed to delete {file_to_delete.name}: {e}")

if space_saved > 0:
space_saved_mb = space_saved / (1024 * 1024)
logger.info(f"💾 Space saved: {space_saved_mb:.1f} MB")

# Show what files are kept
logger.info("📁 Kept essential files for partition mode:")
for filename in required_files:
file_path = index_dir / filename
if file_path.exists():
size_mb = file_path.stat().st_size / (1024 * 1024)
logger.info(f" - {filename} ({size_mb:.1f} MB)")

def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
path = Path(index_path)
index_dir = path.parent
@@ -221,17 +151,6 @@ class DiskannBuilder(LeannBackendBuilderInterface):
_write_vectors_to_bin(data, index_dir / data_filename)

build_kwargs = {**self.build_params, **kwargs}

# Extract is_recompute from nested backend_kwargs if needed
is_recompute = build_kwargs.get("is_recompute", False)
if not is_recompute and "backend_kwargs" in build_kwargs:
is_recompute = build_kwargs["backend_kwargs"].get("is_recompute", False)

# Flatten all backend_kwargs parameters to top level for compatibility
if "backend_kwargs" in build_kwargs:
nested_params = build_kwargs.pop("backend_kwargs")
build_kwargs.update(nested_params)

metric_enum = _get_diskann_metrics().get(
build_kwargs.get("distance_metric", "mips").lower()
)
@@ -266,30 +185,6 @@ class DiskannBuilder(LeannBackendBuilderInterface):
build_kwargs.get("pq_disk_bytes", 0),
"",
)

# Auto-partition if is_recompute is enabled
if build_kwargs.get("is_recompute", False):
logger.info("is_recompute=True, starting automatic graph partitioning...")
from .graph_partition import partition_graph

# Partition the index using absolute paths
# Convert to absolute paths to avoid issues with working directory changes
absolute_index_dir = Path(index_dir).resolve()
absolute_index_prefix_path = str(absolute_index_dir / index_prefix)
disk_graph_path, partition_bin_path = partition_graph(
index_prefix_path=absolute_index_prefix_path,
output_dir=str(absolute_index_dir),
partition_prefix=index_prefix,
)

# Safe cleanup: In partition mode, C++ doesn't read _disk.index content
# but still needs the derived files (_medoids.bin, _centroids.bin, etc.)
self._safe_cleanup_after_partition(index_dir, index_prefix)

logger.info("✅ Graph partitioning completed successfully!")
logger.info(f" - Disk graph: {disk_graph_path}")
logger.info(f" - Partition file: {partition_bin_path}")

finally:
temp_data_file = index_dir / data_filename
if temp_data_file.exists():
@@ -318,26 +213,7 @@ class DiskannSearcher(BaseSearcher):

# For DiskANN, we need to reinitialize the index when zmq_port changes
# Store the initialization parameters for later use
# Note: C++ load method expects the BASE path (without _disk.index suffix)
full_index_prefix = str(self.index_dir / self.index_path.stem)
# C++ internally constructs: index_prefix + "_disk.index"
index_name = self.index_path.stem  # "simple_test.leann" -> "simple_test"
diskann_index_prefix = str(self.index_dir / index_name)  # /path/to/simple_test
full_index_prefix = diskann_index_prefix  # /path/to/simple_test (base path)

# Auto-detect partition files and set partition_prefix
partition_graph_file = self.index_dir / f"{index_name}_disk_graph.index"
partition_bin_file = self.index_dir / f"{index_name}_partition.bin"

partition_prefix = ""
if partition_graph_file.exists() and partition_bin_file.exists():
# C++ expects full path prefix, not just filename
partition_prefix = str(self.index_dir / index_name)  # /path/to/simple_test
logger.info(
f"✅ Detected partition files, using partition_prefix='{partition_prefix}'"
)
else:
logger.debug("No partition files detected, using standard index files")

self._init_params = {
"metric_enum": metric_enum,
"full_index_prefix": full_index_prefix,
@@ -345,14 +221,8 @@ class DiskannSearcher(BaseSearcher):
"num_nodes_to_cache": kwargs.get("num_nodes_to_cache", 0),
"cache_mechanism": 1,
"pq_prefix": "",
"partition_prefix": partition_prefix,
"partition_prefix": "",
}

# Log partition configuration for debugging
if partition_prefix:
logger.info(
f"✅ Detected partition files, using partition_prefix='{partition_prefix}'"
)
self._diskannpy = diskannpy
self._current_zmq_port = None
self._index = None
@@ -389,7 +259,7 @@ class DiskannSearcher(BaseSearcher):
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: Optional[int] = None,
zmq_port: int | None = None,
batch_recompute: bool = False,
dedup_node_dis: bool = False,
**kwargs,
@@ -441,14 +311,7 @@ class DiskannSearcher(BaseSearcher):
else:  # "global"
use_global_pruning = True

# Strategy:
# Perform search with suppressed C++ output based on log level
# - Traversal always uses PQ distances
# - If recompute_embeddings=True, do a single final rerank via deferred fetch
# (fetch embeddings for the final candidate set only)
# - Do not recompute neighbor distances along the path
use_deferred_fetch = True if recompute_embeddings else False
recompute_neighors = False  # Expected typo. For backward compatibility.

with suppress_cpp_output_if_needed():
labels, distances = self._index.batch_search(
query,
@@ -457,9 +320,9 @@ class DiskannSearcher(BaseSearcher):
complexity,
beam_width,
self.num_threads,
use_deferred_fetch,
kwargs.get("USE_DEFERRED_FETCH", False),
kwargs.get("skip_search_reorder", False),
recompute_neighors,
recompute_embeddings,
dedup_node_dis,
prune_ratio,
batch_recompute,
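The "# Strategy:" comment in the hunk above describes graph traversal on cheap PQ distances followed by a single exact rerank of the final candidate set. A minimal sketch of that rerank step, outside the diff, assuming a caller-supplied fetch_embeddings(ids) helper (hypothetical name) that returns float32 vectors, for example via the ZMQ embedding server:

import numpy as np

def rerank_final_candidates(query, candidate_ids, fetch_embeddings):
    # Traversal already happened on PQ distances; only the final candidates get
    # exact embeddings fetched and re-scored (MIPS: larger dot product is better).
    embs = np.asarray(fetch_embeddings(candidate_ids), dtype=np.float32)  # (n, d)
    scores = embs @ np.asarray(query, dtype=np.float32)                   # (n,)
    order = np.argsort(-scores)
    return [candidate_ids[i] for i in order]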
@@ -10,7 +10,6 @@ import sys
import threading
import time
from pathlib import Path
from typing import Optional

import numpy as np
import zmq
@@ -33,7 +32,7 @@ if not logger.handlers:


def create_diskann_embedding_server(
passages_file: Optional[str] = None,
passages_file: str | None = None,
zmq_port: int = 5555,
model_name: str = "sentence-transformers/all-mpnet-base-v2",
embedding_mode: str = "sentence-transformers",
@@ -81,9 +80,10 @@ def create_diskann_embedding_server(
with open(passages_file) as f:
meta = json.load(f)

logger.info(f"Loading PassageManager with metadata_file_path: {passages_file}")
passages = PassageManager(meta["passage_sources"])
passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file)
logger.info(
logger.info(f"Loaded PassageManager with {len(passages)} passages from metadata")
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
)

# Import protobuf after ensuring the path is correct
try:
@@ -101,9 +101,8 @@ def create_diskann_embedding_server(
socket.bind(f"tcp://*:{zmq_port}")
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")

socket.setsockopt(zmq.RCVTIMEO, 1000)
socket.setsockopt(zmq.RCVTIMEO, 300000)
socket.setsockopt(zmq.SNDTIMEO, 1000)
socket.setsockopt(zmq.SNDTIMEO, 300000)
socket.setsockopt(zmq.LINGER, 0)

while True:
try:
@@ -220,217 +219,30 @@ def create_diskann_embedding_server(
traceback.print_exc()
raise

def zmq_server_thread_with_shutdown(shutdown_event):
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
"""ZMQ server thread that respects shutdown signal.

This creates its own REP socket, binds to zmq_port, and periodically
checks shutdown_event using recv timeouts to exit cleanly.
"""
logger.info("DiskANN ZMQ server thread started with shutdown support")

context = zmq.Context()
rep_socket = context.socket(zmq.REP)
rep_socket.bind(f"tcp://*:{zmq_port}")
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")

# Set receive timeout so we can check shutdown_event periodically
rep_socket.setsockopt(zmq.RCVTIMEO, 1000)  # 1 second timeout
rep_socket.setsockopt(zmq.SNDTIMEO, 1000)
rep_socket.setsockopt(zmq.LINGER, 0)

try:
while not shutdown_event.is_set():
try:
e2e_start = time.time()
# REP socket receives single-part messages
message = rep_socket.recv()

# Check for empty messages - REP socket requires response to every request
if not message:
logger.warning("Received empty message, sending empty response")
rep_socket.send(b"")
continue

# Try protobuf first (same logic as original)
texts = []
is_text_request = False

try:
req_proto = embedding_pb2.NodeEmbeddingRequest()
req_proto.ParseFromString(message)
node_ids = list(req_proto.node_ids)

# Look up texts by node IDs
for nid in node_ids:
try:
passage_data = passages.get_passage(str(nid))
txt = passage_data["text"]
if not txt:
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
texts.append(txt)
except KeyError:
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")

logger.info(f"ZMQ received protobuf request for {len(node_ids)} node IDs")
except Exception:
# Fallback to msgpack for text requests
try:
import msgpack

request = msgpack.unpackb(message)
if isinstance(request, list) and all(
isinstance(item, str) for item in request
):
texts = request
is_text_request = True
logger.info(
f"ZMQ received msgpack text request for {len(texts)} texts"
)
else:
raise ValueError("Not a valid msgpack text request")
except Exception:
logger.error("Both protobuf and msgpack parsing failed!")
# Send error response
resp_proto = embedding_pb2.NodeEmbeddingResponse()
rep_socket.send(resp_proto.SerializeToString())
continue

# Process the request
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
logger.info(f"Computed embeddings shape: {embeddings.shape}")

# Validation
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
logger.error("NaN or Inf detected in embeddings!")
# Send error response
if is_text_request:
import msgpack

response_data = msgpack.packb([])
else:
resp_proto = embedding_pb2.NodeEmbeddingResponse()
response_data = resp_proto.SerializeToString()
rep_socket.send(response_data)
continue

# Prepare response based on request type
if is_text_request:
# For direct text requests, return msgpack
import msgpack

response_data = msgpack.packb(embeddings.tolist())
else:
# For protobuf requests, return protobuf
resp_proto = embedding_pb2.NodeEmbeddingResponse()
hidden_contiguous = np.ascontiguousarray(embeddings, dtype=np.float32)

resp_proto.embeddings_data = hidden_contiguous.tobytes()
resp_proto.dimensions.append(hidden_contiguous.shape[0])
resp_proto.dimensions.append(hidden_contiguous.shape[1])

response_data = resp_proto.SerializeToString()

# Send response back to the client
rep_socket.send(response_data)

e2e_end = time.time()
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")

except zmq.Again:
# Timeout - check shutdown_event and continue
continue
except Exception as e:
if not shutdown_event.is_set():
logger.error(f"Error in ZMQ server loop: {e}")
try:
# Send error response for REP socket
resp_proto = embedding_pb2.NodeEmbeddingResponse()
rep_socket.send(resp_proto.SerializeToString())
except Exception:
pass
else:
logger.info("Shutdown in progress, ignoring ZMQ error")
break
finally:
try:
rep_socket.close(0)
except Exception:
pass
try:
context.term()
except Exception:
pass

logger.info("DiskANN ZMQ server thread exiting gracefully")

# Add shutdown coordination
shutdown_event = threading.Event()

def shutdown_zmq_server():
"""Gracefully shutdown ZMQ server."""
logger.info("Initiating graceful shutdown...")
shutdown_event.set()

if zmq_thread.is_alive():
logger.info("Waiting for ZMQ thread to finish...")
zmq_thread.join(timeout=5)
if zmq_thread.is_alive():
logger.warning("ZMQ thread did not finish in time")

# Clean up ZMQ resources
try:
# Note: socket and context are cleaned up by thread exit
logger.info("ZMQ resources cleaned up")
except Exception as e:
logger.warning(f"Error cleaning ZMQ resources: {e}")

# Clean up other resources
try:
import gc

gc.collect()
logger.info("Additional resources cleaned up")
except Exception as e:
logger.warning(f"Error cleaning additional resources: {e}")

logger.info("Graceful shutdown completed")
sys.exit(0)

# Register signal handlers within this function scope
import signal

def signal_handler(sig, frame):
logger.info(f"Received signal {sig}, shutting down gracefully...")
shutdown_zmq_server()

signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)

# Start ZMQ thread (NOT daemon!)
zmq_thread = threading.Thread(
target=lambda: zmq_server_thread_with_shutdown(shutdown_event),
daemon=False,  # Not daemon - we want to wait for it
)
zmq_thread.start()
logger.info(f"Started DiskANN ZMQ server thread on port {zmq_port}")

# Keep the main thread alive
try:
while not shutdown_event.is_set():
while True:
time.sleep(0.1)  # Check shutdown more frequently
time.sleep(1)
except KeyboardInterrupt:
logger.info("DiskANN Server shutting down...")
shutdown_zmq_server()
return

# If we reach here, shutdown was triggered by signal
logger.info("Main loop exited, process should be shutting down")


if __name__ == "__main__":
import signal
import sys

# Signal handlers are now registered within create_diskann_embedding_server
def signal_handler(sig, frame):
logger.info(f"Received signal {sig}, shutting down gracefully...")
sys.exit(0)

# Register signal handlers for graceful shutdown
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)

parser = argparse.ArgumentParser(description="DiskANN Embedding service")
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
@@ -449,7 +261,7 @@ if __name__ == "__main__":
"--embedding-mode",
type=str,
default="sentence-transformers",
choices=["sentence-transformers", "openai", "mlx", "ollama"],
choices=["sentence-transformers", "openai", "mlx"],
help="Embedding backend mode",
)
parser.add_argument(
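For orientation (not part of the diff): the server above accepts either a protobuf NodeEmbeddingRequest of node IDs or a plain msgpack-encoded list of strings on the same REP socket. A rough client-side sketch of the msgpack text path, with host and port as assumptions:

import msgpack
import zmq

def embed_texts(texts, port=5555):
    ctx = zmq.Context.instance()
    sock = ctx.socket(zmq.REQ)
    sock.connect(f"tcp://127.0.0.1:{port}")
    sock.send(msgpack.packb(texts))          # list[str] hits the server's msgpack fallback path
    vectors = msgpack.unpackb(sock.recv())   # list of embedding vectors (list[list[float]])
    sock.close(0)
    return vectors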
@@ -1,299 +0,0 @@
#!/usr/bin/env python3
"""
Graph Partition Module for LEANN DiskANN Backend

This module provides Python bindings for the graph partition functionality
of DiskANN, allowing users to partition disk-based indices for better
performance.
"""

import os
import shutil
import subprocess
import tempfile
from pathlib import Path
from typing import Optional


class GraphPartitioner:
"""
A Python interface for DiskANN's graph partition functionality.

This class provides methods to partition disk-based indices for improved
search performance and memory efficiency.
"""

def __init__(self, build_type: str = "release"):
"""
Initialize the GraphPartitioner.

Args:
build_type: Build type for the executables ("debug" or "release")
"""
self.build_type = build_type
self._ensure_executables()

def _get_executable_path(self, name: str) -> str:
"""Get the path to a graph partition executable."""
# Get the directory where this Python module is located
module_dir = Path(__file__).parent
# Navigate to the graph_partition directory
graph_partition_dir = module_dir.parent / "third_party" / "DiskANN" / "graph_partition"
executable_path = graph_partition_dir / "build" / self.build_type / "graph_partition" / name

if not executable_path.exists():
raise FileNotFoundError(f"Executable {name} not found at {executable_path}")

return str(executable_path)

def _ensure_executables(self):
"""Ensure that the required executables are built."""
try:
self._get_executable_path("partitioner")
self._get_executable_path("index_relayout")
except FileNotFoundError:
# Try to build the executables automatically
print("Executables not found, attempting to build them...")
self._build_executables()

def _build_executables(self):
"""Build the required executables."""
graph_partition_dir = (
Path(__file__).parent.parent / "third_party" / "DiskANN" / "graph_partition"
)
original_dir = os.getcwd()

try:
os.chdir(graph_partition_dir)

# Clean any existing build
if (graph_partition_dir / "build").exists():
shutil.rmtree(graph_partition_dir / "build")

# Run the build script
cmd = ["./build.sh", self.build_type, "split_graph", "/tmp/dummy"]
subprocess.run(cmd, capture_output=True, text=True, cwd=graph_partition_dir)

# Check if executables were created
partitioner_path = self._get_executable_path("partitioner")
relayout_path = self._get_executable_path("index_relayout")

print(f"✅ Built partitioner: {partitioner_path}")
print(f"✅ Built index_relayout: {relayout_path}")

except Exception as e:
raise RuntimeError(f"Failed to build executables: {e}")
finally:
os.chdir(original_dir)

def partition_graph(
self,
index_prefix_path: str,
output_dir: Optional[str] = None,
partition_prefix: Optional[str] = None,
**kwargs,
) -> tuple[str, str]:
"""
Partition a disk-based index for improved performance.

Args:
index_prefix_path: Path to the index prefix (e.g., "/path/to/index")
output_dir: Output directory for results (defaults to parent of index_prefix_path)
partition_prefix: Prefix for output files (defaults to basename of index_prefix_path)
**kwargs: Additional parameters for graph partitioning:
- gp_times: Number of LDG partition iterations (default: 10)
- lock_nums: Number of lock nodes (default: 10)
- cut: Cut adjacency list degree (default: 100)
- scale_factor: Scale factor (default: 1)
- data_type: Data type (default: "float")
- thread_nums: Number of threads (default: 10)

Returns:
Tuple of (disk_graph_index_path, partition_bin_path)

Raises:
RuntimeError: If the partitioning process fails
"""
# Set default parameters
params = {
"gp_times": 10,
"lock_nums": 10,
"cut": 100,
"scale_factor": 1,
"data_type": "float",
"thread_nums": 10,
**kwargs,
}

# Determine output directory
if output_dir is None:
output_dir = str(Path(index_prefix_path).parent)

# Create output directory if it doesn't exist
Path(output_dir).mkdir(parents=True, exist_ok=True)

# Determine partition prefix
if partition_prefix is None:
partition_prefix = Path(index_prefix_path).name

# Get executable paths
partitioner_path = self._get_executable_path("partitioner")
relayout_path = self._get_executable_path("index_relayout")

# Create temporary directory for processing
with tempfile.TemporaryDirectory() as temp_dir:
# Change to the graph_partition directory for temporary files
graph_partition_dir = (
Path(__file__).parent.parent / "third_party" / "DiskANN" / "graph_partition"
)
original_dir = os.getcwd()

try:
os.chdir(graph_partition_dir)

# Create temporary data directory
temp_data_dir = Path(temp_dir) / "data"
temp_data_dir.mkdir(parents=True, exist_ok=True)

# Set up paths for temporary files
graph_path = temp_data_dir / "starling" / "_M_R_L_B" / "GRAPH"
graph_gp_path = (
graph_path
/ f"GP_TIMES_{params['gp_times']}_LOCK_{params['lock_nums']}_GP_USE_FREQ0_CUT{params['cut']}_SCALE{params['scale_factor']}"
)
graph_gp_path.mkdir(parents=True, exist_ok=True)

# Find input index file
old_index_file = f"{index_prefix_path}_disk_beam_search.index"
if not os.path.exists(old_index_file):
old_index_file = f"{index_prefix_path}_disk.index"

if not os.path.exists(old_index_file):
raise RuntimeError(f"Index file not found: {old_index_file}")

# Run partitioner
gp_file_path = graph_gp_path / "_part.bin"
partitioner_cmd = [
partitioner_path,
"--index_file",
old_index_file,
"--data_type",
params["data_type"],
"--gp_file",
str(gp_file_path),
"-T",
str(params["thread_nums"]),
"--ldg_times",
str(params["gp_times"]),
"--scale",
str(params["scale_factor"]),
"--mode",
"1",
]

print(f"Running partitioner: {' '.join(partitioner_cmd)}")
result = subprocess.run(
partitioner_cmd, capture_output=True, text=True, cwd=graph_partition_dir
)

if result.returncode != 0:
raise RuntimeError(
f"Partitioner failed with return code {result.returncode}.\n"
f"stdout: {result.stdout}\n"
f"stderr: {result.stderr}"
)

# Run relayout
part_tmp_index = graph_gp_path / "_part_tmp.index"
relayout_cmd = [
relayout_path,
old_index_file,
str(gp_file_path),
params["data_type"],
"1",
]

print(f"Running relayout: {' '.join(relayout_cmd)}")
result = subprocess.run(
relayout_cmd, capture_output=True, text=True, cwd=graph_partition_dir
)

if result.returncode != 0:
raise RuntimeError(
f"Relayout failed with return code {result.returncode}.\n"
f"stdout: {result.stdout}\n"
f"stderr: {result.stderr}"
)

# Copy results to output directory
disk_graph_path = Path(output_dir) / f"{partition_prefix}_disk_graph.index"
partition_bin_path = Path(output_dir) / f"{partition_prefix}_partition.bin"

shutil.copy2(part_tmp_index, disk_graph_path)
shutil.copy2(gp_file_path, partition_bin_path)

print(f"Results copied to: {output_dir}")
return str(disk_graph_path), str(partition_bin_path)

finally:
os.chdir(original_dir)

def get_partition_info(self, partition_bin_path: str) -> dict:
"""
Get information about a partition file.

Args:
partition_bin_path: Path to the partition binary file

Returns:
Dictionary containing partition information
"""
if not os.path.exists(partition_bin_path):
raise FileNotFoundError(f"Partition file not found: {partition_bin_path}")

# For now, return basic file information
# In the future, this could parse the binary file for detailed info
stat = os.stat(partition_bin_path)
return {
"file_size": stat.st_size,
"file_path": partition_bin_path,
"modified_time": stat.st_mtime,
}


def partition_graph(
index_prefix_path: str,
output_dir: Optional[str] = None,
partition_prefix: Optional[str] = None,
build_type: str = "release",
**kwargs,
) -> tuple[str, str]:
"""
Convenience function to partition a graph index.

Args:
index_prefix_path: Path to the index prefix
output_dir: Output directory (defaults to parent of index_prefix_path)
partition_prefix: Prefix for output files (defaults to basename of index_prefix_path)
build_type: Build type for executables ("debug" or "release")
**kwargs: Additional parameters for graph partitioning

Returns:
Tuple of (disk_graph_index_path, partition_bin_path)
"""
partitioner = GraphPartitioner(build_type=build_type)
return partitioner.partition_graph(index_prefix_path, output_dir, partition_prefix, **kwargs)


# Example usage:
if __name__ == "__main__":
# Example: partition an index
try:
disk_graph_path, partition_bin_path = partition_graph(
"/path/to/your/index_prefix", gp_times=10, lock_nums=10, cut=100
)
print("Partitioning completed successfully!")
print(f"Disk graph index: {disk_graph_path}")
print(f"Partition binary: {partition_bin_path}")
except Exception as e:
print(f"Partitioning failed: {e}")
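Tying the module above to the searcher changes earlier in this diff: partition_graph writes two sibling files next to the index, and DiskannSearcher enables partition mode only when both are present. A minimal sketch of that contract, not from the diff itself; the paths are placeholders and the import path is an assumption:

from pathlib import Path

from leann_backend_diskann.graph_partition import partition_graph  # package path assumed

index_prefix = "/data/my_index/my_index"  # placeholder base path (no _disk.index suffix)
disk_graph, partition_bin = partition_graph(index_prefix)

# With default output_dir and partition_prefix, exactly these two siblings are produced,
# which the searcher looks for before setting partition_prefix to the same base path.
assert Path(disk_graph) == Path(f"{index_prefix}_disk_graph.index")
assert Path(partition_bin) == Path(f"{index_prefix}_partition.bin")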
@@ -4,8 +4,8 @@ build-backend = "scikit_build_core.build"

[project]
name = "leann-backend-diskann"
version = "0.3.2"
version = "0.1.16"
dependencies = ["leann-core==0.3.2", "numpy", "protobuf>=3.19.0"]
dependencies = ["leann-core==0.1.16", "numpy", "protobuf>=3.19.0"]

[tool.scikit-build]
# Key: simplified CMake path
@@ -17,5 +17,3 @@ editable.mode = "redirect"
cmake.build-type = "Release"
build.verbose = true
build.tool-args = ["-j8"]
# Let CMake find packages via Homebrew prefix
cmake.define = {CMAKE_PREFIX_PATH = {env = "CMAKE_PREFIX_PATH"}, OpenMP_ROOT = {env = "OpenMP_ROOT"}}
Submodule packages/leann-backend-diskann/third_party/DiskANN updated: 19f9603c72...67a2611ad1
@@ -5,20 +5,11 @@ set(CMAKE_CXX_COMPILER_WORKS 1)

# Set OpenMP path for macOS
if(APPLE)
# Detect Homebrew installation path (Apple Silicon vs Intel)
set(OpenMP_C_FLAGS "-Xpreprocessor -fopenmp -I/opt/homebrew/opt/libomp/include")
if(EXISTS "/opt/homebrew/opt/libomp")
set(OpenMP_CXX_FLAGS "-Xpreprocessor -fopenmp -I/opt/homebrew/opt/libomp/include")
set(HOMEBREW_PREFIX "/opt/homebrew")
elseif(EXISTS "/usr/local/opt/libomp")
set(HOMEBREW_PREFIX "/usr/local")
else()
message(FATAL_ERROR "Could not find libomp installation. Please install with: brew install libomp")
endif()

set(OpenMP_C_FLAGS "-Xpreprocessor -fopenmp -I${HOMEBREW_PREFIX}/opt/libomp/include")
set(OpenMP_CXX_FLAGS "-Xpreprocessor -fopenmp -I${HOMEBREW_PREFIX}/opt/libomp/include")
set(OpenMP_C_LIB_NAMES "omp")
set(OpenMP_CXX_LIB_NAMES "omp")
set(OpenMP_omp_LIBRARY "${HOMEBREW_PREFIX}/opt/libomp/lib/libomp.dylib")
set(OpenMP_omp_LIBRARY "/opt/homebrew/opt/libomp/lib/libomp.dylib")

# Force use of system libc++ to avoid version mismatch
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -stdlib=libc++")
@@ -49,28 +40,9 @@ set(BUILD_TESTING OFF CACHE BOOL "" FORCE)
set(FAISS_ENABLE_C_API OFF CACHE BOOL "" FORCE)
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)

# Disable x86-specific SIMD optimizations (important for ARM64 compatibility)
# Disable additional SIMD versions to speed up compilation
set(FAISS_ENABLE_AVX2 OFF CACHE BOOL "" FORCE)
set(FAISS_ENABLE_AVX512 OFF CACHE BOOL "" FORCE)
set(FAISS_ENABLE_SSE4_1 OFF CACHE BOOL "" FORCE)

# ARM64-specific configuration
if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64")
message(STATUS "Configuring Faiss for ARM64 architecture")

if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
# Use SVE optimization level for ARM64 Linux (as seen in Faiss conda build)
set(FAISS_OPT_LEVEL "sve" CACHE STRING "" FORCE)
message(STATUS "Setting FAISS_OPT_LEVEL to 'sve' for ARM64 Linux")
else()
# Use generic optimization for other ARM64 platforms (like macOS)
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
message(STATUS "Setting FAISS_OPT_LEVEL to 'generic' for ARM64 ${CMAKE_SYSTEM_NAME}")
endif()

# ARM64 compatibility: Faiss submodule has been modified to fix x86 header inclusion
message(STATUS "Using ARM64-compatible Faiss submodule")
endif()

# Additional optimization options from INSTALL.md
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "" FORCE)
@@ -1,6 +1,5 @@
import argparse
import gc  # Import garbage collector interface
import logging
import os
import struct
import sys
@@ -8,12 +7,6 @@ import time

import numpy as np

# Set up logging to avoid print buffer issues
logger = logging.getLogger(__name__)
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
logger.setLevel(log_level)

# --- FourCCs (add more if needed) ---
INDEX_HNSW_FLAT_FOURCC = int.from_bytes(b"IHNf", "little")
# Add other HNSW fourccs if you expect different storage types inside HNSW
@@ -250,8 +243,6 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
output_filename: Output CSR index file
prune_embeddings: Whether to prune embedding storage (write NULL storage marker)
"""
# Keep prints simple; rely on CI runner to flush output as needed

print(f"Starting conversion: {input_filename} -> {output_filename}")
start_time = time.time()
original_hnsw_data = {}
@@ -1,9 +1,8 @@
import logging
import os
import shutil
import time
from pathlib import Path
from typing import Any, Literal, Optional
from typing import Any, Literal

import numpy as np
from leann.interface import (
@@ -55,13 +54,12 @@ class HNSWBuilder(LeannBackendBuilderInterface):
self.efConstruction = self.build_params.setdefault("efConstruction", 200)
self.distance_metric = self.build_params.setdefault("distance_metric", "mips")
self.dimensions = self.build_params.get("dimensions")
if not self.is_recompute and self.is_compact:
if not self.is_recompute:
# Auto-correct: non-recompute requires non-compact storage for HNSW
if self.is_compact:
logger.warning(
# TODO: support this case @andy
"is_recompute=False requires non-compact HNSW. Forcing is_compact=False."
raise ValueError(
)
"is_recompute is False, but is_compact is True. This is not compatible now. change is compact to False and you can use the original HNSW index."
self.is_compact = False
)
self.build_params["is_compact"] = False

def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
from . import faiss  # type: ignore
@@ -154,7 +152,7 @@ class HNSWSearcher(BaseSearcher):
self,
query: np.ndarray,
top_k: int,
zmq_port: Optional[int] = None,
zmq_port: int | None = None,
complexity: int = 64,
beam_width: int = 1,
prune_ratio: float = 0.0,
@@ -186,11 +184,9 @@ class HNSWSearcher(BaseSearcher):
"""
from . import faiss  # type: ignore

if not recompute_embeddings and self.is_pruned:
if not recompute_embeddings:
raise RuntimeError(
if self.is_pruned:
"Recompute is required for pruned/compact HNSW index. "
raise RuntimeError("Recompute is required for pruned index.")
"Re-run search with --recompute, or rebuild with --no-recompute and --no-compact."
)
if recompute_embeddings:
if zmq_port is None:
raise ValueError("zmq_port must be provided if recompute_embeddings is True")
@@ -237,7 +233,6 @@ class HNSWSearcher(BaseSearcher):
distances = np.empty((batch_size_query, top_k), dtype=np.float32)
labels = np.empty((batch_size_query, top_k), dtype=np.int64)

search_time = time.time()
self._index.search(
query.shape[0],
faiss.swig_ptr(query),
@@ -246,8 +241,7 @@ class HNSWSearcher(BaseSearcher):
faiss.swig_ptr(labels),
params,
)
search_time = time.time() - search_time
logger.info(f" Search time in HNSWSearcher.search() backend: {search_time} seconds")
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]

return {"labels": string_labels, "distances": distances}
|||||||
@@ -10,7 +10,6 @@ import sys
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import msgpack
|
import msgpack
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -34,7 +33,7 @@ if not logger.handlers:
|
|||||||
|
|
||||||
|
|
||||||
def create_hnsw_embedding_server(
|
def create_hnsw_embedding_server(
|
||||||
passages_file: Optional[str] = None,
|
passages_file: str | None = None,
|
||||||
zmq_port: int = 5555,
|
zmq_port: int = 5555,
|
||||||
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
||||||
distance_metric: str = "mips",
|
distance_metric: str = "mips",
|
||||||
@@ -82,315 +81,199 @@ def create_hnsw_embedding_server(
|
|||||||
with open(passages_file) as f:
|
with open(passages_file) as f:
|
||||||
meta = json.load(f)
|
meta = json.load(f)
|
||||||
|
|
||||||
# Let PassageManager handle path resolution uniformly. It supports fallback order:
|
# Convert relative paths to absolute paths based on metadata file location
|
||||||
# 1) path/index_path; 2) *_relative; 3) standard siblings next to meta
|
metadata_dir = Path(passages_file).parent.parent # Go up one level from the metadata file
|
||||||
passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file)
|
passage_sources = []
|
||||||
# Dimension from metadata for shaping responses
|
for source in meta["passage_sources"]:
|
||||||
try:
|
source_copy = source.copy()
|
||||||
embedding_dim: int = int(meta.get("dimensions", 0))
|
# Convert relative paths to absolute paths
|
||||||
except Exception:
|
if not Path(source_copy["path"]).is_absolute():
|
||||||
embedding_dim = 0
|
source_copy["path"] = str(metadata_dir / source_copy["path"])
|
||||||
logger.info(f"Loaded PassageManager with {len(passages)} passages from metadata")
|
if not Path(source_copy["index_path"]).is_absolute():
|
||||||
|
source_copy["index_path"] = str(metadata_dir / source_copy["index_path"])
|
||||||
|
passage_sources.append(source_copy)
|
||||||
|
|
||||||
# (legacy ZMQ thread removed; using shutdown-capable server only)
|
passages = PassageManager(passage_sources)
|
||||||
|
logger.info(
|
||||||
def zmq_server_thread_with_shutdown(shutdown_event):
|
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
||||||
"""ZMQ server thread that respects shutdown signal.
|
)
|
||||||
|
|
||||||
Creates its own REP socket bound to zmq_port and polls with timeouts
|
|
||||||
to allow graceful shutdown.
|
|
||||||
"""
|
|
||||||
logger.info("ZMQ server thread started with shutdown support")
|
|
||||||
|
|
||||||
|
def zmq_server_thread():
|
||||||
|
"""ZMQ server thread"""
|
||||||
context = zmq.Context()
|
context = zmq.Context()
|
||||||
rep_socket = context.socket(zmq.REP)
|
socket = context.socket(zmq.REP)
|
||||||
rep_socket.bind(f"tcp://*:{zmq_port}")
|
socket.bind(f"tcp://*:{zmq_port}")
|
||||||
logger.info(f"HNSW ZMQ REP server listening on port {zmq_port}")
|
logger.info(f"HNSW ZMQ server listening on port {zmq_port}")
|
||||||
rep_socket.setsockopt(zmq.RCVTIMEO, 1000)
|
|
||||||
# Keep sends from blocking during shutdown; fail fast and drop on close
|
|
||||||
rep_socket.setsockopt(zmq.SNDTIMEO, 1000)
|
|
||||||
rep_socket.setsockopt(zmq.LINGER, 0)
|
|
||||||
|
|
||||||
# Track last request type/length for shape-correct fallbacks
|
socket.setsockopt(zmq.RCVTIMEO, 300000)
|
||||||
last_request_type = "unknown" # 'text' | 'distance' | 'embedding' | 'unknown'
|
socket.setsockopt(zmq.SNDTIMEO, 300000)
|
||||||
last_request_length = 0
|
|
||||||
|
|
||||||
try:
|
while True:
|
||||||
while not shutdown_event.is_set():
|
try:
|
||||||
try:
|
message_bytes = socket.recv()
|
||||||
e2e_start = time.time()
|
logger.debug(f"Received ZMQ request of size {len(message_bytes)} bytes")
|
||||||
logger.debug("🔍 Waiting for ZMQ message...")
|
|
||||||
request_bytes = rep_socket.recv()
|
|
||||||
|
|
||||||
# Rest of the processing logic (same as original)
|
e2e_start = time.time()
|
||||||
request = msgpack.unpackb(request_bytes)
|
request_payload = msgpack.unpackb(message_bytes)
|
||||||
|
|
||||||
if len(request) == 1 and request[0] == "__QUERY_MODEL__":
|
# Handle direct text embedding request
|
||||||
response_bytes = msgpack.packb([model_name])
|
if isinstance(request_payload, list) and len(request_payload) > 0:
|
||||||
rep_socket.send(response_bytes)
|
# Check if this is a direct text request (list of strings)
|
||||||
continue
|
if all(isinstance(item, str) for item in request_payload):
|
||||||
|
logger.info(
|
||||||
|
f"Processing direct text embedding request for {len(request_payload)} texts in {embedding_mode} mode"
|
||||||
|
)
|
||||||
|
|
||||||
# Handle direct text embedding request
|
# Use unified embedding computation (now with model caching)
|
||||||
if (
|
embeddings = compute_embeddings(
|
||||||
isinstance(request, list)
|
request_payload, model_name, mode=embedding_mode
|
||||||
and request
|
)
|
||||||
and all(isinstance(item, str) for item in request)
|
|
||||||
):
|
response = embeddings.tolist()
|
||||||
last_request_type = "text"
|
socket.send(msgpack.packb(response))
|
||||||
last_request_length = len(request)
|
|
||||||
embeddings = compute_embeddings(request, model_name, mode=embedding_mode)
|
|
||||||
rep_socket.send(msgpack.packb(embeddings.tolist()))
|
|
||||||
e2e_end = time.time()
|
e2e_end = time.time()
|
||||||
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
|
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Handle distance calculation request: [[ids], [query_vector]]
|
# Handle distance calculation requests
|
||||||
if (
|
if (
|
||||||
isinstance(request, list)
|
isinstance(request_payload, list)
|
||||||
and len(request) == 2
|
and len(request_payload) == 2
|
||||||
and isinstance(request[0], list)
|
and isinstance(request_payload[0], list)
|
||||||
and isinstance(request[1], list)
|
and isinstance(request_payload[1], list)
|
||||||
):
|
):
|
||||||
node_ids = request[0]
|
node_ids = request_payload[0]
|
||||||
# Handle nested [[ids]] shape defensively
|
query_vector = np.array(request_payload[1], dtype=np.float32)
|
||||||
if len(node_ids) == 1 and isinstance(node_ids[0], list):
|
|
||||||
node_ids = node_ids[0]
|
|
||||||
query_vector = np.array(request[1], dtype=np.float32)
|
|
||||||
last_request_type = "distance"
|
|
||||||
last_request_length = len(node_ids)
|
|
||||||
|
|
||||||
logger.debug("Distance calculation request received")
|
logger.debug("Distance calculation request received")
|
||||||
logger.debug(f" Node IDs: {node_ids}")
|
logger.debug(f" Node IDs: {node_ids}")
|
||||||
logger.debug(f" Query vector dim: {len(query_vector)}")
|
logger.debug(f" Query vector dim: {len(query_vector)}")
|
||||||
|
|
||||||
# Gather texts for found ids
|
# Get embeddings for node IDs
|
||||||
texts: list[str] = []
|
texts = []
|
||||||
found_indices: list[int] = []
|
for nid in node_ids:
|
||||||
for idx, nid in enumerate(node_ids):
|
|
||||||
try:
|
|
||||||
passage_data = passages.get_passage(str(nid))
|
|
||||||
txt = passage_data.get("text", "")
|
|
||||||
if isinstance(txt, str) and len(txt) > 0:
|
|
||||||
texts.append(txt)
|
|
||||||
found_indices.append(idx)
|
|
||||||
else:
|
|
||||||
logger.error(f"Empty text for passage ID {nid}")
|
|
||||||
except KeyError:
|
|
||||||
logger.error(f"Passage ID {nid} not found")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
|
||||||
|
|
||||||
# Prepare full-length response with large sentinel values
|
|
||||||
large_distance = 1e9
|
|
||||||
response_distances = [large_distance] * len(node_ids)
|
|
||||||
|
|
||||||
if texts:
|
|
||||||
try:
|
|
||||||
embeddings = compute_embeddings(
|
|
||||||
texts, model_name, mode=embedding_mode
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
|
||||||
)
|
|
||||||
if distance_metric == "l2":
|
|
||||||
partial = np.sum(
|
|
||||||
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
|
|
||||||
)
|
|
||||||
else: # mips or cosine
|
|
||||||
partial = -np.dot(embeddings, query_vector)
|
|
||||||
|
|
||||||
for pos, dval in zip(found_indices, partial.flatten().tolist()):
|
|
||||||
response_distances[pos] = float(dval)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Distance computation error, using sentinels: {e}")
|
|
||||||
|
|
||||||
# Send response in expected shape [[distances]]
|
|
||||||
rep_socket.send(msgpack.packb([response_distances], use_single_float=True))
|
|
||||||
e2e_end = time.time()
|
|
||||||
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Fallback: treat as embedding-by-id request
|
|
||||||
if (
|
|
||||||
isinstance(request, list)
|
|
||||||
and len(request) == 1
|
|
||||||
and isinstance(request[0], list)
|
|
||||||
):
|
|
||||||
node_ids = request[0]
|
|
||||||
elif isinstance(request, list):
|
|
||||||
node_ids = request
|
|
||||||
else:
|
|
||||||
node_ids = []
|
|
||||||
last_request_type = "embedding"
|
|
||||||
last_request_length = len(node_ids)
|
|
||||||
logger.info(f"ZMQ received {len(node_ids)} node IDs for embedding fetch")
|
|
||||||
|
|
||||||
# Preallocate zero-filled flat data for robustness
|
|
||||||
if embedding_dim <= 0:
|
|
||||||
dims = [0, 0]
|
|
||||||
flat_data: list[float] = []
|
|
||||||
else:
|
|
||||||
dims = [len(node_ids), embedding_dim]
|
|
||||||
flat_data = [0.0] * (dims[0] * dims[1])
|
|
||||||
|
|
||||||
# Collect texts for found ids
|
|
||||||
texts: list[str] = []
|
|
||||||
found_indices: list[int] = []
|
|
||||||
for idx, nid in enumerate(node_ids):
|
|
||||||
try:
|
try:
|
||||||
passage_data = passages.get_passage(str(nid))
|
passage_data = passages.get_passage(str(nid))
|
||||||
txt = passage_data.get("text", "")
|
txt = passage_data["text"]
|
||||||
if isinstance(txt, str) and len(txt) > 0:
|
texts.append(txt)
|
||||||
texts.append(txt)
|
|
||||||
found_indices.append(idx)
|
|
||||||
else:
|
|
||||||
logger.error(f"Empty text for passage ID {nid}")
|
|
||||||
except KeyError:
|
except KeyError:
|
||||||
logger.error(f"Passage with ID {nid} not found")
|
logger.error(f"Passage ID {nid} not found")
|
||||||
|
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
if texts:
|
# Process embeddings
|
||||||
try:
|
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
logger.info(
|
||||||
logger.info(
|
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
)
|
||||||
)
|
|
||||||
|
|
||||||
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
# Calculate distances
|
||||||
logger.error(
|
if distance_metric == "l2":
|
||||||
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
|
distances = np.sum(
|
||||||
)
|
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
|
||||||
dims = [0, embedding_dim]
|
)
|
||||||
flat_data = []
|
else: # mips or cosine
|
||||||
else:
|
distances = -np.dot(embeddings, query_vector)
|
||||||
emb_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
|
|
||||||
flat = emb_f32.flatten().tolist()
|
|
||||||
for j, pos in enumerate(found_indices):
|
|
||||||
start = pos * embedding_dim
|
|
||||||
end = start + embedding_dim
|
|
||||||
if end <= len(flat_data):
|
|
||||||
flat_data[start:end] = flat[
|
|
||||||
j * embedding_dim : (j + 1) * embedding_dim
|
|
||||||
]
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Embedding computation error, returning zeros: {e}")
|
|
||||||
|
|
||||||
response_payload = [dims, flat_data]
|
response_payload = distances.flatten().tolist()
|
||||||
response_bytes = msgpack.packb(response_payload, use_single_float=True)
|
response_bytes = msgpack.packb([response_payload], use_single_float=True)
|
||||||
|
logger.debug(f"Sending distance response with {len(distances)} distances")
|
||||||
|
|
||||||
rep_socket.send(response_bytes)
|
socket.send(response_bytes)
|
||||||
e2e_end = time.time()
|
e2e_end = time.time()
|
||||||
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
|
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
|
||||||
|
|
||||||
except zmq.Again:
|
|
||||||
# Timeout - check shutdown_event and continue
|
|
||||||
continue
|
continue
|
||||||
except Exception as e:
|
|
||||||
if not shutdown_event.is_set():
|
|
||||||
logger.error(f"Error in ZMQ server loop: {e}")
|
|
||||||
# Shape-correct fallback
|
|
||||||
try:
|
|
||||||
if last_request_type == "distance":
|
|
||||||
large_distance = 1e9
|
|
||||||
fallback_len = max(0, int(last_request_length))
|
|
||||||
safe = [[large_distance] * fallback_len]
|
|
||||||
elif last_request_type == "embedding":
|
|
||||||
bsz = max(0, int(last_request_length))
|
|
||||||
dim = max(0, int(embedding_dim))
|
|
||||||
safe = (
|
|
||||||
[[bsz, dim], [0.0] * (bsz * dim)] if dim > 0 else [[0, 0], []]
|
|
||||||
)
|
|
||||||
elif last_request_type == "text":
|
|
||||||
safe = [] # direct text embeddings expectation is a flat list
|
|
||||||
else:
|
|
||||||
safe = [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
|
|
||||||
rep_socket.send(msgpack.packb(safe, use_single_float=True))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
logger.info("Shutdown in progress, ignoring ZMQ error")
|
|
||||||
break
|
|
||||||
finally:
|
|
||||||
try:
|
|
||||||
rep_socket.close(0)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
context.term()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
logger.info("ZMQ server thread exiting gracefully")
|
# Standard embedding request (passage ID lookup)
|
||||||
|
if (
|
||||||
|
not isinstance(request_payload, list)
|
||||||
|
or len(request_payload) != 1
|
||||||
|
or not isinstance(request_payload[0], list)
|
||||||
|
):
|
||||||
|
logger.error(
|
||||||
|
f"Invalid MessagePack request format. Expected [[ids...]] or [texts...], got: {type(request_payload)}"
|
||||||
|
)
|
||||||
|
socket.send(msgpack.packb([[], []]))
|
||||||
|
continue
|
||||||
|
|
||||||
# Add shutdown coordination
|
node_ids = request_payload[0]
|
||||||
shutdown_event = threading.Event()
|
logger.debug(f"Request for {len(node_ids)} node embeddings")
|
||||||
|
|
||||||
def shutdown_zmq_server():
|
# Look up texts by node IDs
|
||||||
"""Gracefully shutdown ZMQ server."""
|
texts = []
|
||||||
logger.info("Initiating graceful shutdown...")
|
for nid in node_ids:
|
||||||
shutdown_event.set()
|
try:
|
||||||
|
passage_data = passages.get_passage(str(nid))
|
||||||
|
txt = passage_data["text"]
|
||||||
|
if not txt:
|
||||||
|
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
|
||||||
|
texts.append(txt)
|
||||||
|
except KeyError:
|
||||||
|
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
if zmq_thread.is_alive():
|
# Process embeddings
|
||||||
logger.info("Waiting for ZMQ thread to finish...")
|
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||||
zmq_thread.join(timeout=5)
|
logger.info(
|
||||||
if zmq_thread.is_alive():
|
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||||
logger.warning("ZMQ thread did not finish in time")
|
)
|
||||||
|
|
||||||
# Clean up ZMQ resources
|
# Serialization and response
|
||||||
try:
|
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||||
# Note: socket and context are cleaned up by thread exit
|
logger.error(
|
||||||
logger.info("ZMQ resources cleaned up")
|
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
|
||||||
except Exception as e:
|
)
|
||||||
logger.warning(f"Error cleaning ZMQ resources: {e}")
|
raise AssertionError()
|
||||||
|
|
||||||
# Clean up other resources
|
hidden_contiguous_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||||
try:
|
response_payload = [
|
||||||
import gc
|
list(hidden_contiguous_f32.shape),
|
||||||
|
hidden_contiguous_f32.flatten().tolist(),
|
||||||
|
]
|
||||||
|
response_bytes = msgpack.packb(response_payload, use_single_float=True)
|
||||||
|
|
||||||
gc.collect()
|
socket.send(response_bytes)
|
||||||
logger.info("Additional resources cleaned up")
|
e2e_end = time.time()
|
||||||
except Exception as e:
|
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
|
||||||
logger.warning(f"Error cleaning additional resources: {e}")
|
|
||||||
|
|
||||||
logger.info("Graceful shutdown completed")
|
except zmq.Again:
|
||||||
sys.exit(0)
|
logger.debug("ZMQ socket timeout, continuing to listen")
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in ZMQ server loop: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
# Register signal handlers within this function scope
|
traceback.print_exc()
|
||||||
import signal
|
socket.send(msgpack.packb([[], []]))
|
||||||
|
|
||||||
def signal_handler(sig, frame):
|
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
|
||||||
logger.info(f"Received signal {sig}, shutting down gracefully...")
|
|
||||||
shutdown_zmq_server()
|
|
||||||
|
|
||||||
signal.signal(signal.SIGTERM, signal_handler)
|
|
||||||
signal.signal(signal.SIGINT, signal_handler)
|
|
||||||
|
|
||||||
# Pass shutdown_event to ZMQ thread
|
|
||||||
zmq_thread = threading.Thread(
|
|
||||||
target=lambda: zmq_server_thread_with_shutdown(shutdown_event),
|
|
||||||
daemon=False, # Not daemon - we want to wait for it
|
|
||||||
)
|
|
||||||
zmq_thread.start()
|
zmq_thread.start()
|
||||||
logger.info(f"Started HNSW ZMQ server thread on port {zmq_port}")
|
logger.info(f"Started HNSW ZMQ server thread on port {zmq_port}")
|
||||||
|
|
||||||
# Keep the main thread alive
|
# Keep the main thread alive
|
||||||
try:
|
try:
|
||||||
while not shutdown_event.is_set():
|
while True:
|
||||||
time.sleep(0.1) # Check shutdown more frequently
|
time.sleep(1)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logger.info("HNSW Server shutting down...")
|
logger.info("HNSW Server shutting down...")
|
||||||
shutdown_zmq_server()
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# If we reach here, shutdown was triggered by signal
|
|
||||||
logger.info("Main loop exited, process should be shutting down")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
import signal
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
# Signal handlers are now registered within create_hnsw_embedding_server
|
def signal_handler(sig, frame):
|
||||||
|
logger.info(f"Received signal {sig}, shutting down gracefully...")
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
# Register signal handlers for graceful shutdown
|
||||||
|
signal.signal(signal.SIGTERM, signal_handler)
|
||||||
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="HNSW Embedding service")
|
parser = argparse.ArgumentParser(description="HNSW Embedding service")
|
||||||
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
||||||
@@ -412,7 +295,7 @@ if __name__ == "__main__":
|
|||||||
"--embedding-mode",
|
"--embedding-mode",
|
||||||
type=str,
|
type=str,
|
||||||
default="sentence-transformers",
|
default="sentence-transformers",
|
||||||
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
choices=["sentence-transformers", "openai", "mlx"],
|
||||||
help="Embedding backend mode",
|
help="Embedding backend mode",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -6,10 +6,10 @@ build-backend = "scikit_build_core.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann-backend-hnsw"
|
name = "leann-backend-hnsw"
|
||||||
version = "0.3.2"
|
version = "0.1.16"
|
||||||
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
|
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"leann-core==0.3.2",
|
"leann-core==0.1.16",
|
||||||
"numpy",
|
"numpy",
|
||||||
"pyzmq>=23.0.0",
|
"pyzmq>=23.0.0",
|
||||||
"msgpack>=1.0.0",
|
"msgpack>=1.0.0",
|
||||||
@@ -22,8 +22,6 @@ cmake.build-type = "Release"
|
|||||||
build.verbose = true
|
build.verbose = true
|
||||||
build.tool-args = ["-j8"]
|
build.tool-args = ["-j8"]
|
||||||
|
|
||||||
# CMake definitions to optimize compilation and find Homebrew packages
|
# CMake definitions to optimize compilation
|
||||||
[tool.scikit-build.cmake.define]
|
[tool.scikit-build.cmake.define]
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL = "8"
|
CMAKE_BUILD_PARALLEL_LEVEL = "8"
|
||||||
CMAKE_PREFIX_PATH = {env = "CMAKE_PREFIX_PATH"}
|
|
||||||
OpenMP_ROOT = {env = "OpenMP_ROOT"}
|
|
||||||
|
|||||||
Submodule packages/leann-backend-hnsw/third_party/faiss updated: ed96ff7dba...ff22e2c86b
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann-core"
|
name = "leann-core"
|
||||||
version = "0.3.2"
|
version = "0.1.16"
|
||||||
description = "Core API and plugin system for LEANN"
|
description = "Core API and plugin system for LEANN"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.9"
|
||||||
@@ -31,10 +31,8 @@ dependencies = [
|
|||||||
"PyPDF2>=3.0.0",
|
"PyPDF2>=3.0.0",
|
||||||
"pymupdf>=1.23.0",
|
"pymupdf>=1.23.0",
|
||||||
"pdfplumber>=0.10.0",
|
"pdfplumber>=0.10.0",
|
||||||
"nbconvert>=7.0.0", # For .ipynb file support
|
"mlx>=0.26.3; sys_platform == 'darwin'",
|
||||||
"gitignore-parser>=0.1.12", # For proper .gitignore handling
|
"mlx-lm>=0.26.0; sys_platform == 'darwin'",
|
||||||
"mlx>=0.26.3; sys_platform == 'darwin' and platform_machine == 'arm64'",
|
|
||||||
"mlx-lm>=0.26.0; sys_platform == 'darwin' and platform_machine == 'arm64'",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
@@ -46,7 +44,6 @@ colab = [
|
|||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
leann = "leann.cli:main"
|
leann = "leann.cli:main"
|
||||||
leann_mcp = "leann.mcp:main"
|
|
||||||
|
|
||||||
[tool.setuptools.packages.find]
|
[tool.setuptools.packages.find]
|
||||||
where = ["src"]
|
where = ["src"]
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import time
|
|||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Literal, Optional, Union
|
from typing import Any, Literal
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -18,7 +18,6 @@ from leann.interface import LeannBackendSearcherInterface
|
|||||||
|
|
||||||
from .chat import get_llm
|
from .chat import get_llm
|
||||||
from .interface import LeannBackendFactoryInterface
|
from .interface import LeannBackendFactoryInterface
|
||||||
from .metadata_filter import MetadataFilterEngine
|
|
||||||
from .registry import BACKEND_REGISTRY
|
from .registry import BACKEND_REGISTRY
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -34,7 +33,7 @@ def compute_embeddings(
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
mode: str = "sentence-transformers",
|
mode: str = "sentence-transformers",
|
||||||
use_server: bool = True,
|
use_server: bool = True,
|
||||||
port: Optional[int] = None,
|
port: int | None = None,
|
||||||
is_build=False,
|
is_build=False,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
@@ -47,7 +46,6 @@ def compute_embeddings(
|
|||||||
- "sentence-transformers": Use sentence-transformers library (default)
|
- "sentence-transformers": Use sentence-transformers library (default)
|
||||||
- "mlx": Use MLX backend for Apple Silicon
|
- "mlx": Use MLX backend for Apple Silicon
|
||||||
- "openai": Use OpenAI embedding API
|
- "openai": Use OpenAI embedding API
|
||||||
- "gemini": Use Google Gemini embedding API
|
|
||||||
use_server: Whether to use embedding server (True for search, False for build)
|
use_server: Whether to use embedding server (True for search, False for build)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -117,180 +115,54 @@ class SearchResult:
|
|||||||
|
|
||||||
|
|
||||||
class PassageManager:
|
class PassageManager:
|
||||||
def __init__(
|
def __init__(self, passage_sources: list[dict[str, Any]]):
|
||||||
self, passage_sources: list[dict[str, Any]], metadata_file_path: Optional[str] = None
|
self.offset_maps = {}
|
||||||
):
|
self.passage_files = {}
|
||||||
self.offset_maps: dict[str, dict[str, int]] = {}
|
self.global_offset_map = {} # Combined map for fast lookup
|
||||||
self.passage_files: dict[str, str] = {}
|
|
||||||
# Avoid materializing a single gigantic global map to reduce memory
|
|
||||||
# footprint on very large corpora (e.g., 60M+ passages). Instead, keep
|
|
||||||
# per-shard maps and do a lightweight per-shard lookup on demand.
|
|
||||||
self._total_count: int = 0
|
|
||||||
self.filter_engine = MetadataFilterEngine() # Initialize filter engine
|
|
||||||
|
|
||||||
# Derive index base name for standard sibling fallbacks, e.g., <index_name>.passages.*
|
|
||||||
index_name_base = None
|
|
||||||
if metadata_file_path:
|
|
||||||
meta_name = Path(metadata_file_path).name
|
|
||||||
if meta_name.endswith(".meta.json"):
|
|
||||||
index_name_base = meta_name[: -len(".meta.json")]
|
|
||||||
|
|
||||||
for source in passage_sources:
|
for source in passage_sources:
|
||||||
assert source["type"] == "jsonl", "only jsonl is supported"
|
assert source["type"] == "jsonl", "only jsonl is supported"
|
||||||
passage_file = source.get("path", "")
|
passage_file = source["path"]
|
||||||
index_file = source.get("index_path", "") # .idx file
|
index_file = source["index_path"] # .idx file
|
||||||
|
|
||||||
# Fix path resolution - relative paths should be relative to metadata file directory
|
# Fix path resolution for Colab and other environments
|
||||||
def _resolve_candidates(
|
if not Path(index_file).is_absolute():
|
||||||
primary: str,
|
# If relative path, try to resolve it properly
|
||||||
relative_key: str,
|
index_file = str(Path(index_file).resolve())
|
||||||
default_name: Optional[str],
|
|
||||||
source_dict: dict[str, Any],
|
|
||||||
) -> list[Path]:
|
|
||||||
"""
|
|
||||||
Build an ordered list of candidate paths. For relative paths specified in
|
|
||||||
metadata, prefer resolution relative to the metadata file directory first,
|
|
||||||
then fall back to CWD-based resolution, and finally to conventional
|
|
||||||
sibling defaults (e.g., <index_base>.passages.idx / .jsonl).
|
|
||||||
"""
|
|
||||||
candidates: list[Path] = []
|
|
||||||
# 1) Primary path
|
|
||||||
if primary:
|
|
||||||
p = Path(primary)
|
|
||||||
if p.is_absolute():
|
|
||||||
candidates.append(p)
|
|
||||||
else:
|
|
||||||
# Prefer metadata-relative resolution for relative paths
|
|
||||||
if metadata_file_path:
|
|
||||||
candidates.append(Path(metadata_file_path).parent / p)
|
|
||||||
# Also consider CWD-relative as a fallback for legacy layouts
|
|
||||||
candidates.append(Path.cwd() / p)
|
|
||||||
# 2) metadata-relative explicit relative key (if present)
|
|
||||||
if metadata_file_path and source_dict.get(relative_key):
|
|
||||||
candidates.append(Path(metadata_file_path).parent / source_dict[relative_key])
|
|
||||||
# 3) metadata-relative standard sibling filename
|
|
||||||
if metadata_file_path and default_name:
|
|
||||||
candidates.append(Path(metadata_file_path).parent / default_name)
|
|
||||||
return candidates
|
|
||||||
|
|
||||||
# Build candidate lists and pick first existing; otherwise keep last candidate for error message
|
|
||||||
idx_default = f"{index_name_base}.passages.idx" if index_name_base else None
|
|
||||||
idx_candidates = _resolve_candidates(
|
|
||||||
index_file, "index_path_relative", idx_default, source
|
|
||||||
)
|
|
||||||
pas_default = f"{index_name_base}.passages.jsonl" if index_name_base else None
|
|
||||||
pas_candidates = _resolve_candidates(passage_file, "path_relative", pas_default, source)
|
|
||||||
|
|
||||||
def _pick_existing(cands: list[Path]) -> str:
|
|
||||||
for c in cands:
|
|
||||||
if c.exists():
|
|
||||||
return str(c.resolve())
|
|
||||||
# Fallback to last candidate (best guess) even if not exists; will error below
|
|
||||||
return str(cands[-1].resolve()) if cands else ""
|
|
||||||
|
|
||||||
index_file = _pick_existing(idx_candidates)
|
|
||||||
passage_file = _pick_existing(pas_candidates)
|
|
||||||
|
|
||||||
if not Path(index_file).exists():
|
if not Path(index_file).exists():
|
||||||
raise FileNotFoundError(f"Passage index file not found: {index_file}")
|
raise FileNotFoundError(f"Passage index file not found: {index_file}")
|
||||||
|
|
||||||
with open(index_file, "rb") as f:
|
with open(index_file, "rb") as f:
|
||||||
offset_map: dict[str, int] = pickle.load(f)
|
offset_map = pickle.load(f)
|
||||||
self.offset_maps[passage_file] = offset_map
|
self.offset_maps[passage_file] = offset_map
|
||||||
self.passage_files[passage_file] = passage_file
|
self.passage_files[passage_file] = passage_file
|
||||||
self._total_count += len(offset_map)
|
|
||||||
|
# Build global map for O(1) lookup
|
||||||
|
for passage_id, offset in offset_map.items():
|
||||||
|
self.global_offset_map[passage_id] = (passage_file, offset)
|
||||||
|
|
||||||
def get_passage(self, passage_id: str) -> dict[str, Any]:
|
def get_passage(self, passage_id: str) -> dict[str, Any]:
|
||||||
# Fast path: check each shard map (there are typically few shards).
|
if passage_id in self.global_offset_map:
|
||||||
# This avoids building a massive combined dict while keeping lookups
|
passage_file, offset = self.global_offset_map[passage_id]
|
||||||
# bounded by the number of shards.
|
# Lazy file opening - only open when needed
|
||||||
for passage_file, offset_map in self.offset_maps.items():
|
with open(passage_file, encoding="utf-8") as f:
|
||||||
try:
|
f.seek(offset)
|
||||||
offset = offset_map[passage_id]
|
return json.loads(f.readline())
|
||||||
with open(passage_file, encoding="utf-8") as f:
|
|
||||||
f.seek(offset)
|
|
||||||
return json.loads(f.readline())
|
|
||||||
except KeyError:
|
|
||||||
continue
|
|
||||||
raise KeyError(f"Passage ID not found: {passage_id}")
|
raise KeyError(f"Passage ID not found: {passage_id}")
|
||||||
|
|
||||||
def filter_search_results(
|
|
||||||
self,
|
|
||||||
search_results: list[SearchResult],
|
|
||||||
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]],
|
|
||||||
) -> list[SearchResult]:
|
|
||||||
"""
|
|
||||||
Apply metadata filters to search results.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
search_results: List of SearchResult objects
|
|
||||||
metadata_filters: Filter specifications to apply
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Filtered list of SearchResult objects
|
|
||||||
"""
|
|
||||||
if not metadata_filters:
|
|
||||||
return search_results
|
|
||||||
|
|
||||||
logger.debug(f"Applying metadata filters to {len(search_results)} results")
|
|
||||||
|
|
||||||
# Convert SearchResult objects to dictionaries for the filter engine
|
|
||||||
result_dicts = []
|
|
||||||
for result in search_results:
|
|
||||||
result_dicts.append(
|
|
||||||
{
|
|
||||||
"id": result.id,
|
|
||||||
"score": result.score,
|
|
||||||
"text": result.text,
|
|
||||||
"metadata": result.metadata,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Apply filters using the filter engine
|
|
||||||
filtered_dicts = self.filter_engine.apply_filters(result_dicts, metadata_filters)
|
|
||||||
|
|
||||||
# Convert back to SearchResult objects
|
|
||||||
filtered_results = []
|
|
||||||
for result_dict in filtered_dicts:
|
|
||||||
filtered_results.append(
|
|
||||||
SearchResult(
|
|
||||||
id=result_dict["id"],
|
|
||||||
score=result_dict["score"],
|
|
||||||
text=result_dict["text"],
|
|
||||||
metadata=result_dict["metadata"],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(f"Filtered results: {len(filtered_results)} remaining")
|
|
||||||
return filtered_results
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
return self._total_count
|
|
||||||
|
|
||||||
|
|
||||||
class LeannBuilder:
|
class LeannBuilder:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
backend_name: str,
|
backend_name: str,
|
||||||
embedding_model: str = "facebook/contriever",
|
embedding_model: str = "facebook/contriever",
|
||||||
dimensions: Optional[int] = None,
|
dimensions: int | None = None,
|
||||||
embedding_mode: str = "sentence-transformers",
|
embedding_mode: str = "sentence-transformers",
|
||||||
**backend_kwargs,
|
**backend_kwargs,
|
||||||
):
|
):
|
||||||
self.backend_name = backend_name
|
self.backend_name = backend_name
|
||||||
# Normalize incompatible combinations early (for consistent metadata)
|
backend_factory: LeannBackendFactoryInterface | None = BACKEND_REGISTRY.get(backend_name)
|
||||||
if backend_name == "hnsw":
|
|
||||||
is_recompute = backend_kwargs.get("is_recompute", True)
|
|
||||||
is_compact = backend_kwargs.get("is_compact", True)
|
|
||||||
if is_recompute is False and is_compact is True:
|
|
||||||
warnings.warn(
|
|
||||||
"HNSW with is_recompute=False requires non-compact storage. Forcing is_compact=False.",
|
|
||||||
UserWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
backend_kwargs["is_compact"] = False
|
|
||||||
|
|
||||||
backend_factory: Optional[LeannBackendFactoryInterface] = BACKEND_REGISTRY.get(backend_name)
|
|
||||||
if backend_factory is None:
|
if backend_factory is None:
|
||||||
raise ValueError(f"Backend '{backend_name}' not found or not registered.")
|
raise ValueError(f"Backend '{backend_name}' not found or not registered.")
|
||||||
self.backend_factory = backend_factory
|
self.backend_factory = backend_factory
|
||||||
@@ -370,7 +242,7 @@ class LeannBuilder:
|
|||||||
self.backend_kwargs = backend_kwargs
|
self.backend_kwargs = backend_kwargs
|
||||||
self.chunks: list[dict[str, Any]] = []
|
self.chunks: list[dict[str, Any]] = []
|
||||||
|
|
||||||
def add_text(self, text: str, metadata: Optional[dict[str, Any]] = None):
|
def add_text(self, text: str, metadata: dict[str, Any] | None = None):
|
||||||
if metadata is None:
|
if metadata is None:
|
||||||
metadata = {}
|
metadata = {}
|
||||||
passage_id = metadata.get("id", str(len(self.chunks)))
|
passage_id = metadata.get("id", str(len(self.chunks)))
|
||||||
@@ -380,23 +252,6 @@ class LeannBuilder:
|
|||||||
def build_index(self, index_path: str):
|
def build_index(self, index_path: str):
|
||||||
if not self.chunks:
|
if not self.chunks:
|
||||||
raise ValueError("No chunks added.")
|
raise ValueError("No chunks added.")
|
||||||
|
|
||||||
# Filter out invalid/empty text chunks early to keep passage and embedding counts aligned
|
|
||||||
valid_chunks: list[dict[str, Any]] = []
|
|
||||||
skipped = 0
|
|
||||||
for chunk in self.chunks:
|
|
||||||
text = chunk.get("text", "")
|
|
||||||
if isinstance(text, str) and text.strip():
|
|
||||||
valid_chunks.append(chunk)
|
|
||||||
else:
|
|
||||||
skipped += 1
|
|
||||||
if skipped > 0:
|
|
||||||
print(
|
|
||||||
f"Warning: Skipping {skipped} empty/invalid text chunk(s). Processing {len(valid_chunks)} valid chunks"
|
|
||||||
)
|
|
||||||
self.chunks = valid_chunks
|
|
||||||
if not self.chunks:
|
|
||||||
raise ValueError("All provided chunks are empty or invalid. Nothing to index.")
|
|
||||||
if self.dimensions is None:
|
if self.dimensions is None:
|
||||||
self.dimensions = len(
|
self.dimensions = len(
|
||||||
compute_embeddings(
|
compute_embeddings(
|
||||||
@@ -459,12 +314,8 @@ class LeannBuilder:
|
|||||||
"passage_sources": [
|
"passage_sources": [
|
||||||
{
|
{
|
||||||
"type": "jsonl",
|
"type": "jsonl",
|
||||||
# Preserve existing relative file names (backward-compatible)
|
"path": str(passages_file),
|
||||||
"path": passages_file.name,
|
"index_path": str(offset_file),
|
||||||
"index_path": offset_file.name,
|
|
||||||
# Add optional redundant relative keys for remote build portability (non-breaking)
|
|
||||||
"path_relative": passages_file.name,
|
|
||||||
"index_path_relative": offset_file.name,
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
@@ -579,12 +430,8 @@ class LeannBuilder:
|
|||||||
"passage_sources": [
|
"passage_sources": [
|
||||||
{
|
{
|
||||||
"type": "jsonl",
|
"type": "jsonl",
|
||||||
# Preserve existing relative file names (backward-compatible)
|
"path": str(passages_file),
|
||||||
"path": passages_file.name,
|
"index_path": str(offset_file),
|
||||||
"index_path": offset_file.name,
|
|
||||||
# Add optional redundant relative keys for remote build portability (non-breaking)
|
|
||||||
"path_relative": passages_file.name,
|
|
||||||
"index_path_relative": offset_file.name,
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"built_from_precomputed_embeddings": True,
|
"built_from_precomputed_embeddings": True,
|
||||||
@@ -612,26 +459,14 @@ class LeannSearcher:
|
|||||||
|
|
||||||
self.meta_path_str = f"{index_path}.meta.json"
|
self.meta_path_str = f"{index_path}.meta.json"
|
||||||
if not Path(self.meta_path_str).exists():
|
if not Path(self.meta_path_str).exists():
|
||||||
parent_dir = Path(index_path).parent
|
raise FileNotFoundError(f"Leann metadata file not found at {self.meta_path_str}")
|
||||||
print(
|
|
||||||
f"Leann metadata file not found at {self.meta_path_str}, and you may need to rm -rf {parent_dir}"
|
|
||||||
)
|
|
||||||
# highlight in red the filenotfound error
|
|
||||||
raise FileNotFoundError(
|
|
||||||
f"Leann metadata file not found at {self.meta_path_str}, \033[91m you may need to rm -rf {parent_dir}\033[0m"
|
|
||||||
)
|
|
||||||
with open(self.meta_path_str, encoding="utf-8") as f:
|
with open(self.meta_path_str, encoding="utf-8") as f:
|
||||||
self.meta_data = json.load(f)
|
self.meta_data = json.load(f)
|
||||||
backend_name = self.meta_data["backend_name"]
|
backend_name = self.meta_data["backend_name"]
|
||||||
self.embedding_model = self.meta_data["embedding_model"]
|
self.embedding_model = self.meta_data["embedding_model"]
|
||||||
# Support both old and new format
|
# Support both old and new format
|
||||||
self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
|
self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
|
||||||
# Delegate portability handling to PassageManager
|
self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
|
||||||
self.passage_manager = PassageManager(
|
|
||||||
self.meta_data.get("passage_sources", []), metadata_file_path=self.meta_path_str
|
|
||||||
)
|
|
||||||
# Preserve backend name for conditional parameter forwarding
|
|
||||||
self.backend_name = backend_name
|
|
||||||
backend_factory = BACKEND_REGISTRY.get(backend_name)
|
backend_factory = BACKEND_REGISTRY.get(backend_name)
|
||||||
if backend_factory is None:
|
if backend_factory is None:
|
||||||
raise ValueError(f"Backend '{backend_name}' not found.")
|
raise ValueError(f"Backend '{backend_name}' not found.")
|
||||||
@@ -651,52 +486,13 @@ class LeannSearcher:
|
|||||||
recompute_embeddings: bool = True,
|
recompute_embeddings: bool = True,
|
||||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
expected_zmq_port: int = 5557,
|
expected_zmq_port: int = 5557,
|
||||||
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
|
|
||||||
batch_size: int = 0,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> list[SearchResult]:
|
) -> list[SearchResult]:
|
||||||
"""
|
|
||||||
Search for nearest neighbors with optional metadata filtering.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: Text query to search for
|
|
||||||
top_k: Number of nearest neighbors to return
|
|
||||||
complexity: Search complexity/candidate list size, higher = more accurate but slower
|
|
||||||
beam_width: Number of parallel search paths/IO requests per iteration
|
|
||||||
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
|
|
||||||
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored codes
|
|
||||||
pruning_strategy: Candidate selection strategy - "global" (default), "local", or "proportional"
|
|
||||||
expected_zmq_port: ZMQ port for embedding server communication
|
|
||||||
metadata_filters: Optional filters to apply to search results based on metadata.
|
|
||||||
Format: {"field_name": {"operator": value}}
|
|
||||||
Supported operators:
|
|
||||||
- Comparison: "==", "!=", "<", "<=", ">", ">="
|
|
||||||
- Membership: "in", "not_in"
|
|
||||||
- String: "contains", "starts_with", "ends_with"
|
|
||||||
Example: {"chapter": {"<=": 5}, "tags": {"in": ["fiction", "drama"]}}
|
|
||||||
**kwargs: Backend-specific parameters
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of SearchResult objects with text, metadata, and similarity scores
|
|
||||||
"""
|
|
||||||
logger.info("🔍 LeannSearcher.search() called:")
|
logger.info("🔍 LeannSearcher.search() called:")
|
||||||
logger.info(f" Query: '{query}'")
|
logger.info(f" Query: '{query}'")
|
||||||
logger.info(f" Top_k: {top_k}")
|
logger.info(f" Top_k: {top_k}")
|
||||||
logger.info(f" Metadata filters: {metadata_filters}")
|
|
||||||
logger.info(f" Additional kwargs: {kwargs}")
|
logger.info(f" Additional kwargs: {kwargs}")
|
||||||
|
|
||||||
# Smart top_k detection and adjustment
|
|
||||||
# Use PassageManager length (sum of shard sizes) to avoid
|
|
||||||
# depending on a massive combined map
|
|
||||||
total_docs = len(self.passage_manager)
|
|
||||||
original_top_k = top_k
|
|
||||||
if top_k > total_docs:
|
|
||||||
top_k = total_docs
|
|
||||||
logger.warning(
|
|
||||||
f" ⚠️ Requested top_k ({original_top_k}) exceeds total documents ({total_docs})"
|
|
||||||
)
|
|
||||||
logger.warning(f" ✅ Auto-adjusted top_k to {top_k} to match available documents")
|
|
||||||
|
|
||||||
zmq_port = None
|
zmq_port = None
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@@ -717,41 +513,31 @@ class LeannSearcher:
|
|||||||
use_server_if_available=recompute_embeddings,
|
use_server_if_available=recompute_embeddings,
|
||||||
zmq_port=zmq_port,
|
zmq_port=zmq_port,
|
||||||
)
|
)
|
||||||
logger.info(f" Generated embedding shape: {query_embedding.shape}")
|
# logger.info(f" Generated embedding shape: {query_embedding.shape}")
|
||||||
embedding_time = time.time() - start_time
|
time.time() - start_time
|
||||||
logger.info(f" Embedding time: {embedding_time} seconds")
|
# logger.info(f" Embedding time: {embedding_time} seconds")
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
backend_search_kwargs: dict[str, Any] = {
|
|
||||||
"complexity": complexity,
|
|
||||||
"beam_width": beam_width,
|
|
||||||
"prune_ratio": prune_ratio,
|
|
||||||
"recompute_embeddings": recompute_embeddings,
|
|
||||||
"pruning_strategy": pruning_strategy,
|
|
||||||
"zmq_port": zmq_port,
|
|
||||||
}
|
|
||||||
# Only HNSW supports batching; forward conditionally
|
|
||||||
if self.backend_name == "hnsw":
|
|
||||||
backend_search_kwargs["batch_size"] = batch_size
|
|
||||||
|
|
||||||
# Merge any extra kwargs last
|
|
||||||
backend_search_kwargs.update(kwargs)
|
|
||||||
|
|
||||||
results = self.backend_impl.search(
|
results = self.backend_impl.search(
|
||||||
query_embedding,
|
query_embedding,
|
||||||
top_k,
|
top_k,
|
||||||
**backend_search_kwargs,
|
complexity=complexity,
|
||||||
|
beam_width=beam_width,
|
||||||
|
prune_ratio=prune_ratio,
|
||||||
|
recompute_embeddings=recompute_embeddings,
|
||||||
|
pruning_strategy=pruning_strategy,
|
||||||
|
zmq_port=zmq_port,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
search_time = time.time() - start_time
|
time.time() - start_time
|
||||||
logger.info(f" Search time in search() LEANN searcher: {search_time} seconds")
|
# logger.info(f" Search time: {search_time} seconds")
|
||||||
logger.info(f" Backend returned: labels={len(results.get('labels', [[]])[0])} results")
|
logger.info(f" Backend returned: labels={len(results.get('labels', [[]])[0])} results")
|
||||||
|
|
||||||
enriched_results = []
|
enriched_results = []
|
||||||
if "labels" in results and "distances" in results:
|
if "labels" in results and "distances" in results:
|
||||||
logger.info(f" Processing {len(results['labels'][0])} passage IDs:")
|
logger.info(f" Processing {len(results['labels'][0])} passage IDs:")
|
||||||
# Python 3.9 does not support zip(strict=...); lengths are expected to match
|
|
||||||
for i, (string_id, dist) in enumerate(
|
for i, (string_id, dist) in enumerate(
|
||||||
zip(results["labels"][0], results["distances"][0])
|
zip(results["labels"][0], results["distances"][0], strict=False)
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
passage_data = self.passage_manager.get_passage(string_id)
|
passage_data = self.passage_manager.get_passage(string_id)
|
||||||
@@ -777,67 +563,23 @@ class LeannSearcher:
|
|||||||
)
|
)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
RED = "\033[91m"
|
RED = "\033[91m"
|
||||||
RESET = "\033[0m"
|
|
||||||
logger.error(
|
logger.error(
|
||||||
f" {RED}✗{RESET} [{i + 1:2d}] ID: '{string_id}' -> {RED}ERROR: Passage not found!{RESET}"
|
f" {RED}✗{RESET} [{i + 1:2d}] ID: '{string_id}' -> {RED}ERROR: Passage not found!{RESET}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply metadata filters if specified
|
|
||||||
if metadata_filters:
|
|
||||||
logger.info(f" 🔍 Applying metadata filters: {metadata_filters}")
|
|
||||||
enriched_results = self.passage_manager.filter_search_results(
|
|
||||||
enriched_results, metadata_filters
|
|
||||||
)
|
|
||||||
|
|
||||||
# Define color codes outside the loop for final message
|
|
||||||
GREEN = "\033[92m"
|
|
||||||
RESET = "\033[0m"
|
|
||||||
logger.info(f" {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}")
|
logger.info(f" {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}")
|
||||||
return enriched_results
|
return enriched_results
|
||||||
|
|
||||||
def cleanup(self):
|
|
||||||
"""Explicitly cleanup embedding server resources.
|
|
||||||
|
|
||||||
This method should be called after you're done using the searcher,
|
|
||||||
especially in test environments or batch processing scenarios.
|
|
||||||
"""
|
|
||||||
backend = getattr(self.backend_impl, "embedding_server_manager", None)
|
|
||||||
if backend is not None:
|
|
||||||
backend.stop_server()
|
|
||||||
|
|
||||||
# Enable automatic cleanup patterns
|
|
||||||
def __enter__(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc, tb):
|
|
||||||
try:
|
|
||||||
self.cleanup()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
try:
|
|
||||||
self.cleanup()
|
|
||||||
except Exception:
|
|
||||||
# Avoid noisy errors during interpreter shutdown
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class LeannChat:
|
class LeannChat:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
index_path: str,
|
index_path: str,
|
||||||
llm_config: Optional[dict[str, Any]] = None,
|
llm_config: dict[str, Any] | None = None,
|
||||||
enable_warmup: bool = False,
|
enable_warmup: bool = False,
|
||||||
searcher: Optional[LeannSearcher] = None,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if searcher is None:
|
self.searcher = LeannSearcher(index_path, enable_warmup=enable_warmup, **kwargs)
|
||||||
self.searcher = LeannSearcher(index_path, enable_warmup=enable_warmup, **kwargs)
|
|
||||||
self._owns_searcher = True
|
|
||||||
else:
|
|
||||||
self.searcher = searcher
|
|
||||||
self._owns_searcher = False
|
|
||||||
self.llm = get_llm(llm_config)
|
self.llm = get_llm(llm_config)
|
||||||
|
|
||||||
def ask(
|
def ask(
|
||||||
@@ -849,10 +591,8 @@ class LeannChat:
|
|||||||
prune_ratio: float = 0.0,
|
prune_ratio: float = 0.0,
|
||||||
recompute_embeddings: bool = True,
|
recompute_embeddings: bool = True,
|
||||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
llm_kwargs: Optional[dict[str, Any]] = None,
|
llm_kwargs: dict[str, Any] | None = None,
|
||||||
expected_zmq_port: int = 5557,
|
expected_zmq_port: int = 5557,
|
||||||
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
|
|
||||||
batch_size: int = 0,
|
|
||||||
**search_kwargs,
|
**search_kwargs,
|
||||||
):
|
):
|
||||||
if llm_kwargs is None:
|
if llm_kwargs is None:
|
||||||
@@ -867,12 +607,10 @@ class LeannChat:
|
|||||||
recompute_embeddings=recompute_embeddings,
|
recompute_embeddings=recompute_embeddings,
|
||||||
pruning_strategy=pruning_strategy,
|
pruning_strategy=pruning_strategy,
|
||||||
expected_zmq_port=expected_zmq_port,
|
expected_zmq_port=expected_zmq_port,
|
||||||
metadata_filters=metadata_filters,
|
|
||||||
batch_size=batch_size,
|
|
||||||
**search_kwargs,
|
**search_kwargs,
|
||||||
)
|
)
|
||||||
search_time = time.time() - search_time
|
search_time = time.time() - search_time
|
||||||
logger.info(f" Search time: {search_time} seconds")
|
# logger.info(f" Search time: {search_time} seconds")
|
||||||
context = "\n\n".join([r.text for r in results])
|
context = "\n\n".join([r.text for r in results])
|
||||||
prompt = (
|
prompt = (
|
||||||
"Here is some retrieved context that might help answer your question:\n\n"
|
"Here is some retrieved context that might help answer your question:\n\n"
|
||||||
@@ -881,10 +619,7 @@ class LeannChat:
|
|||||||
"Please provide the best answer you can based on this context and your knowledge."
|
"Please provide the best answer you can based on this context and your knowledge."
|
||||||
)
|
)
|
||||||
|
|
||||||
ask_time = time.time()
|
|
||||||
ans = self.llm.ask(prompt, **llm_kwargs)
|
ans = self.llm.ask(prompt, **llm_kwargs)
|
||||||
ask_time = time.time() - ask_time
|
|
||||||
logger.info(f" Ask time: {ask_time} seconds")
|
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
def start_interactive(self):
|
def start_interactive(self):
|
||||||
@@ -901,30 +636,3 @@ class LeannChat:
|
|||||||
except (KeyboardInterrupt, EOFError):
|
except (KeyboardInterrupt, EOFError):
|
||||||
print("\nGoodbye!")
|
print("\nGoodbye!")
|
||||||
break
|
break
|
||||||
|
|
||||||
def cleanup(self):
|
|
||||||
"""Explicitly cleanup embedding server resources.
|
|
||||||
|
|
||||||
This method should be called after you're done using the chat interface,
|
|
||||||
especially in test environments or batch processing scenarios.
|
|
||||||
"""
|
|
||||||
# Only stop the embedding server if this LeannChat instance created the searcher.
|
|
||||||
# When a shared searcher is passed in, avoid shutting down the server to enable reuse.
|
|
||||||
if getattr(self, "_owns_searcher", False) and hasattr(self.searcher, "cleanup"):
|
|
||||||
self.searcher.cleanup()
|
|
||||||
|
|
||||||
# Enable automatic cleanup patterns
|
|
||||||
def __enter__(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc, tb):
|
|
||||||
try:
|
|
||||||
self.cleanup()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
try:
|
|
||||||
self.cleanup()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import difflib
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -17,12 +17,12 @@ logging.basicConfig(level=logging.INFO)
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def check_ollama_models(host: str) -> list[str]:
|
def check_ollama_models() -> list[str]:
|
||||||
"""Check available Ollama models and return a list"""
|
"""Check available Ollama models and return a list"""
|
||||||
try:
|
try:
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
response = requests.get(f"{host}/api/tags", timeout=5)
|
response = requests.get("http://localhost:11434/api/tags", timeout=5)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
data = response.json()
|
data = response.json()
|
||||||
return [model["name"] for model in data.get("models", [])]
|
return [model["name"] for model in data.get("models", [])]
|
||||||
@@ -309,12 +309,10 @@ def search_hf_models(query: str, limit: int = 10) -> list[str]:
|
|||||||
return search_hf_models_fuzzy(query, limit)
|
return search_hf_models_fuzzy(query, limit)
|
||||||
|
|
||||||
|
|
||||||
def validate_model_and_suggest(
|
def validate_model_and_suggest(model_name: str, llm_type: str) -> str | None:
|
||||||
model_name: str, llm_type: str, host: str = "http://localhost:11434"
|
|
||||||
) -> Optional[str]:
|
|
||||||
"""Validate model name and provide suggestions if invalid"""
|
"""Validate model name and provide suggestions if invalid"""
|
||||||
if llm_type == "ollama":
|
if llm_type == "ollama":
|
||||||
available_models = check_ollama_models(host)
|
available_models = check_ollama_models()
|
||||||
if available_models and model_name not in available_models:
|
if available_models and model_name not in available_models:
|
||||||
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
|
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
|
||||||
|
|
||||||
@@ -360,11 +358,7 @@ def validate_model_and_suggest(
|
|||||||
error_msg += f"\n\nModel '{model_name}' was not found in Ollama's library."
|
error_msg += f"\n\nModel '{model_name}' was not found in Ollama's library."
|
||||||
|
|
||||||
if suggestions:
|
if suggestions:
|
||||||
error_msg += (
|
error_msg += "\n\nDid you mean one of these installed models?\n"
|
||||||
"\n\nDid you mean one of these installed models?\n"
|
|
||||||
+ "\nTry to use ollama pull to install the model you need\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
for i, suggestion in enumerate(suggestions, 1):
|
for i, suggestion in enumerate(suggestions, 1):
|
||||||
error_msg += f" {i}. {suggestion}\n"
|
error_msg += f" {i}. {suggestion}\n"
|
||||||
else:
|
else:
|
||||||
@@ -422,6 +416,7 @@ class LLMInterface(ABC):
|
|||||||
top_k=10,
|
top_k=10,
|
||||||
complexity=64,
|
complexity=64,
|
||||||
beam_width=8,
|
beam_width=8,
|
||||||
|
USE_DEFERRED_FETCH=True,
|
||||||
skip_search_reorder=True,
|
skip_search_reorder=True,
|
||||||
recompute_beighbor_embeddings=True,
|
recompute_beighbor_embeddings=True,
|
||||||
dedup_node_dis=True,
|
dedup_node_dis=True,
|
||||||
@@ -433,6 +428,7 @@ class LLMInterface(ABC):
|
|||||||
Supported kwargs:
|
Supported kwargs:
|
||||||
- complexity (int): Search complexity parameter (default: 32)
|
- complexity (int): Search complexity parameter (default: 32)
|
||||||
- beam_width (int): Beam width for search (default: 4)
|
- beam_width (int): Beam width for search (default: 4)
|
||||||
|
- USE_DEFERRED_FETCH (bool): Enable deferred fetch mode (default: False)
|
||||||
- skip_search_reorder (bool): Skip search reorder step (default: False)
|
- skip_search_reorder (bool): Skip search reorder step (default: False)
|
||||||
- recompute_beighbor_embeddings (bool): Enable ZMQ embedding server for neighbor recomputation (default: False)
|
- recompute_beighbor_embeddings (bool): Enable ZMQ embedding server for neighbor recomputation (default: False)
|
||||||
- dedup_node_dis (bool): Deduplicate nodes by distance (default: False)
|
- dedup_node_dis (bool): Deduplicate nodes by distance (default: False)
|
||||||
@@ -469,7 +465,7 @@ class OllamaChat(LLMInterface):
|
|||||||
requests.get(host)
|
requests.get(host)
|
||||||
|
|
||||||
# Pre-check model availability with helpful suggestions
|
# Pre-check model availability with helpful suggestions
|
||||||
model_error = validate_model_and_suggest(model, "ollama", host)
|
model_error = validate_model_and_suggest(model, "ollama")
|
||||||
if model_error:
|
if model_error:
|
||||||
raise ValueError(model_error)
|
raise ValueError(model_error)
|
||||||
|
|
||||||
@@ -489,35 +485,11 @@ class OllamaChat(LLMInterface):
|
|||||||
import requests
|
import requests
|
||||||
|
|
||||||
full_url = f"{self.host}/api/generate"
|
full_url = f"{self.host}/api/generate"
|
||||||
|
|
||||||
# Handle thinking budget for reasoning models
|
|
||||||
options = kwargs.copy()
|
|
||||||
thinking_budget = kwargs.get("thinking_budget")
|
|
||||||
if thinking_budget:
|
|
||||||
# Remove thinking_budget from options as it's not a standard Ollama option
|
|
||||||
options.pop("thinking_budget", None)
|
|
||||||
# Only apply reasoning parameters to models that support it
|
|
||||||
reasoning_supported_models = [
|
|
||||||
"gpt-oss:20b",
|
|
||||||
"gpt-oss:120b",
|
|
||||||
"deepseek-r1",
|
|
||||||
"deepseek-coder",
|
|
||||||
]
|
|
||||||
|
|
||||||
if thinking_budget in ["low", "medium", "high"]:
|
|
||||||
if any(model in self.model.lower() for model in reasoning_supported_models):
|
|
||||||
options["reasoning"] = {"effort": thinking_budget, "exclude": False}
|
|
||||||
logger.info(f"Applied reasoning effort={thinking_budget} to model {self.model}")
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
f"Thinking budget '{thinking_budget}' requested but model '{self.model}' may not support reasoning parameters. Proceeding without reasoning."
|
|
||||||
)
|
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
"stream": False, # Keep it simple for now
|
"stream": False, # Keep it simple for now
|
||||||
"options": options,
|
"options": kwargs,
|
||||||
}
|
}
|
||||||
logger.debug(f"Sending request to Ollama: {payload}")
|
logger.debug(f"Sending request to Ollama: {payload}")
|
||||||
try:
|
try:
|
||||||
@@ -680,64 +652,10 @@ class HFChat(LLMInterface):
|
|||||||
return response.strip()
|
return response.strip()
|
||||||
|
|
||||||
|
|
||||||
class GeminiChat(LLMInterface):
|
|
||||||
"""LLM interface for Google Gemini models."""
|
|
||||||
|
|
||||||
def __init__(self, model: str = "gemini-2.5-flash", api_key: Optional[str] = None):
|
|
||||||
self.model = model
|
|
||||||
self.api_key = api_key or os.getenv("GEMINI_API_KEY")
|
|
||||||
|
|
||||||
if not self.api_key:
|
|
||||||
raise ValueError(
|
|
||||||
"Gemini API key is required. Set GEMINI_API_KEY environment variable or pass api_key parameter."
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"Initializing Gemini Chat with model='{model}'")
|
|
||||||
|
|
||||||
try:
|
|
||||||
import google.genai as genai
|
|
||||||
|
|
||||||
self.client = genai.Client(api_key=self.api_key)
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError(
|
|
||||||
"The 'google-genai' library is required for Gemini models. Please install it with 'uv pip install google-genai'."
|
|
||||||
)
|
|
||||||
|
|
||||||
def ask(self, prompt: str, **kwargs) -> str:
|
|
||||||
logger.info(f"Sending request to Gemini with model {self.model}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
from google.genai.types import GenerateContentConfig
|
|
||||||
|
|
||||||
generation_config = GenerateContentConfig(
|
|
||||||
temperature=kwargs.get("temperature", 0.7),
|
|
||||||
max_output_tokens=kwargs.get("max_tokens", 1000),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Handle top_p parameter
|
|
||||||
if "top_p" in kwargs:
|
|
||||||
generation_config.top_p = kwargs["top_p"]
|
|
||||||
|
|
||||||
response = self.client.models.generate_content(
|
|
||||||
model=self.model,
|
|
||||||
contents=prompt,
|
|
||||||
config=generation_config,
|
|
||||||
)
|
|
||||||
# Handle potential None response text
|
|
||||||
response_text = response.text
|
|
||||||
if response_text is None:
|
|
||||||
logger.warning("Gemini returned None response text")
|
|
||||||
return ""
|
|
||||||
return response_text.strip()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error communicating with Gemini: {e}")
|
|
||||||
return f"Error: Could not get a response from Gemini. Details: {e}"
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAIChat(LLMInterface):
|
class OpenAIChat(LLMInterface):
|
||||||
"""LLM interface for OpenAI models."""
|
"""LLM interface for OpenAI models."""
|
||||||
|
|
||||||
def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None):
|
def __init__(self, model: str = "gpt-4o", api_key: str | None = None):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
||||||
|
|
||||||
@@ -762,38 +680,11 @@ class OpenAIChat(LLMInterface):
|
|||||||
params = {
|
params = {
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
"messages": [{"role": "user", "content": prompt}],
|
"messages": [{"role": "user", "content": prompt}],
|
||||||
|
"max_tokens": kwargs.get("max_tokens", 1000),
|
||||||
"temperature": kwargs.get("temperature", 0.7),
|
"temperature": kwargs.get("temperature", 0.7),
|
||||||
|
**{k: v for k, v in kwargs.items() if k not in ["max_tokens", "temperature"]},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Handle max_tokens vs max_completion_tokens based on model
|
|
||||||
max_tokens = kwargs.get("max_tokens", 1000)
|
|
||||||
if "o3" in self.model or "o4" in self.model or "o1" in self.model:
|
|
||||||
# o-series models use max_completion_tokens
|
|
||||||
params["max_completion_tokens"] = max_tokens
|
|
||||||
params["temperature"] = 1.0
|
|
||||||
else:
|
|
||||||
# Other models use max_tokens
|
|
||||||
params["max_tokens"] = max_tokens
|
|
||||||
|
|
||||||
# Handle thinking budget for reasoning models
|
|
||||||
thinking_budget = kwargs.get("thinking_budget")
|
|
||||||
if thinking_budget and thinking_budget in ["low", "medium", "high"]:
|
|
||||||
# Check if this is an o-series model (partial match for model names)
|
|
||||||
o_series_models = ["o3", "o3-mini", "o4-mini", "o1", "o3-pro", "o3-deep-research"]
|
|
||||||
if any(model in self.model for model in o_series_models):
|
|
||||||
# Use the correct OpenAI reasoning parameter format
|
|
||||||
params["reasoning_effort"] = thinking_budget
|
|
||||||
logger.info(f"Applied reasoning_effort={thinking_budget} to model {self.model}")
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
f"Thinking budget '{thinking_budget}' requested but model '{self.model}' may not support reasoning parameters. Proceeding without reasoning."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add other kwargs (excluding thinking_budget as it's handled above)
|
|
||||||
for k, v in kwargs.items():
|
|
||||||
if k not in ["max_tokens", "temperature", "thinking_budget"]:
|
|
||||||
params[k] = v
|
|
||||||
|
|
||||||
logger.info(f"Sending request to OpenAI with model {self.model}")
|
logger.info(f"Sending request to OpenAI with model {self.model}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -813,7 +704,7 @@ class SimulatedChat(LLMInterface):
|
|||||||
return "This is a simulated answer from the LLM based on the retrieved context."
|
return "This is a simulated answer from the LLM based on the retrieved context."
|
||||||
|
|
||||||
|
|
||||||
def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
|
def get_llm(llm_config: dict[str, Any] | None = None) -> LLMInterface:
|
||||||
"""
|
"""
|
||||||
Factory function to get an LLM interface based on configuration.
|
Factory function to get an LLM interface based on configuration.
|
||||||
|
|
||||||
@@ -847,8 +738,6 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
|
|||||||
return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
|
return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
|
||||||
elif llm_type == "openai":
|
elif llm_type == "openai":
|
||||||
return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key"))
|
return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key"))
|
||||||
elif llm_type == "gemini":
|
|
||||||
return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
|
|
||||||
elif llm_type == "simulated":
|
elif llm_type == "simulated":
|
||||||
return SimulatedChat()
|
return SimulatedChat()
|
||||||
else:
|
else:
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -6,7 +6,6 @@ Preserves all optimization parameters to ensure performance
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -29,8 +28,6 @@ def compute_embeddings(
|
|||||||
is_build: bool = False,
|
is_build: bool = False,
|
||||||
batch_size: int = 32,
|
batch_size: int = 32,
|
||||||
adaptive_optimization: bool = True,
|
adaptive_optimization: bool = True,
|
||||||
manual_tokenize: bool = False,
|
|
||||||
max_length: int = 512,
|
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Unified embedding computation entry point
|
Unified embedding computation entry point
|
||||||
@@ -38,7 +35,7 @@ def compute_embeddings(
|
|||||||
Args:
|
Args:
|
||||||
texts: List of texts to compute embeddings for
|
texts: List of texts to compute embeddings for
|
||||||
model_name: Model name
|
model_name: Model name
|
||||||
mode: Computation mode ('sentence-transformers', 'openai', 'mlx', 'ollama')
|
mode: Computation mode ('sentence-transformers', 'openai', 'mlx')
|
||||||
is_build: Whether this is a build operation (shows progress bar)
|
is_build: Whether this is a build operation (shows progress bar)
|
||||||
batch_size: Batch size for processing
|
batch_size: Batch size for processing
|
||||||
adaptive_optimization: Whether to use adaptive optimization based on batch size
|
adaptive_optimization: Whether to use adaptive optimization based on batch size
|
||||||
@@ -53,17 +50,11 @@ def compute_embeddings(
|
|||||||
is_build=is_build,
|
is_build=is_build,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
adaptive_optimization=adaptive_optimization,
|
adaptive_optimization=adaptive_optimization,
|
||||||
manual_tokenize=manual_tokenize,
|
|
||||||
max_length=max_length,
|
|
||||||
)
|
)
|
||||||
elif mode == "openai":
|
elif mode == "openai":
|
||||||
return compute_embeddings_openai(texts, model_name)
|
return compute_embeddings_openai(texts, model_name)
|
||||||
elif mode == "mlx":
|
elif mode == "mlx":
|
||||||
return compute_embeddings_mlx(texts, model_name)
|
return compute_embeddings_mlx(texts, model_name)
|
||||||
elif mode == "ollama":
|
|
||||||
return compute_embeddings_ollama(texts, model_name, is_build=is_build)
|
|
||||||
elif mode == "gemini":
|
|
||||||
return compute_embeddings_gemini(texts, model_name, is_build=is_build)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported embedding mode: {mode}")
|
raise ValueError(f"Unsupported embedding mode: {mode}")
|
||||||
|
|
||||||
@@ -76,8 +67,6 @@ def compute_embeddings_sentence_transformers(
|
|||||||
batch_size: int = 32,
|
batch_size: int = 32,
|
||||||
is_build: bool = False,
|
is_build: bool = False,
|
||||||
adaptive_optimization: bool = True,
|
adaptive_optimization: bool = True,
|
||||||
manual_tokenize: bool = False,
|
|
||||||
max_length: int = 512,
|
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Compute embeddings using SentenceTransformer with model caching and adaptive optimization
|
Compute embeddings using SentenceTransformer with model caching and adaptive optimization
|
||||||
@@ -221,130 +210,20 @@ def compute_embeddings_sentence_transformers(
|
|||||||
logger.info(f"Model cached: {cache_key}")
|
logger.info(f"Model cached: {cache_key}")
|
||||||
|
|
||||||
# Compute embeddings with optimized inference mode
|
# Compute embeddings with optimized inference mode
|
||||||
logger.info(
|
logger.info(f"Starting embedding computation... (batch_size: {batch_size})")
|
||||||
f"Starting embedding computation... (batch_size: {batch_size}, manual_tokenize={manual_tokenize})"
|
|
||||||
)
|
|
||||||
|
|
||||||
start_time = time.time()
|
# Use torch.inference_mode for optimal performance
|
||||||
if not manual_tokenize:
|
with torch.inference_mode():
|
||||||
# Use SentenceTransformer's optimized encode path (default)
|
embeddings = model.encode(
|
||||||
with torch.inference_mode():
|
texts,
|
||||||
embeddings = model.encode(
|
batch_size=batch_size,
|
||||||
texts,
|
show_progress_bar=is_build, # Don't show progress bar in server environment
|
||||||
batch_size=batch_size,
|
convert_to_numpy=True,
|
||||||
show_progress_bar=is_build, # Don't show progress bar in server environment
|
normalize_embeddings=False,
|
||||||
convert_to_numpy=True,
|
device=device,
|
||||||
normalize_embeddings=False,
|
)
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
# Synchronize if CUDA to measure accurate wall time
|
|
||||||
try:
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
# Manual tokenization + forward pass using HF AutoTokenizer/AutoModel
|
|
||||||
try:
|
|
||||||
from transformers import AutoModel, AutoTokenizer # type: ignore
|
|
||||||
except Exception as e:
|
|
||||||
raise ImportError(f"transformers is required for manual_tokenize=True: {e}")
|
|
||||||
|
|
||||||
# Cache tokenizer and model
|
|
||||||
tok_cache_key = f"hf_tokenizer_{model_name}"
|
|
||||||
mdl_cache_key = f"hf_model_{model_name}_{device}_{use_fp16}"
|
|
||||||
if tok_cache_key in _model_cache and mdl_cache_key in _model_cache:
|
|
||||||
hf_tokenizer = _model_cache[tok_cache_key]
|
|
||||||
hf_model = _model_cache[mdl_cache_key]
|
|
||||||
logger.info("Using cached HF tokenizer/model for manual path")
|
|
||||||
else:
|
|
||||||
logger.info("Loading HF tokenizer/model for manual tokenization path")
|
|
||||||
hf_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
|
||||||
torch_dtype = torch.float16 if (use_fp16 and device == "cuda") else torch.float32
|
|
||||||
hf_model = AutoModel.from_pretrained(model_name, torch_dtype=torch_dtype)
|
|
||||||
hf_model.to(device)
|
|
||||||
hf_model.eval()
|
|
||||||
# Optional compile on supported devices
|
|
||||||
if device in ["cuda", "mps"]:
|
|
||||||
try:
|
|
||||||
hf_model = torch.compile(hf_model, mode="reduce-overhead", dynamic=True) # type: ignore
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
_model_cache[tok_cache_key] = hf_tokenizer
|
|
||||||
_model_cache[mdl_cache_key] = hf_model
|
|
||||||
|
|
||||||
all_embeddings: list[np.ndarray] = []
|
|
||||||
# Progress bar when building or for large inputs
|
|
||||||
show_progress = is_build or len(texts) > 32
|
|
||||||
try:
|
|
||||||
if show_progress:
|
|
||||||
from tqdm import tqdm # type: ignore
|
|
||||||
|
|
||||||
batch_iter = tqdm(
|
|
||||||
range(0, len(texts), batch_size),
|
|
||||||
desc="Embedding (manual)",
|
|
||||||
unit="batch",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
batch_iter = range(0, len(texts), batch_size)
|
|
||||||
except Exception:
|
|
||||||
batch_iter = range(0, len(texts), batch_size)
|
|
||||||
|
|
||||||
start_time_manual = time.time()
|
|
||||||
with torch.inference_mode():
|
|
||||||
for start_index in batch_iter:
|
|
||||||
end_index = min(start_index + batch_size, len(texts))
|
|
||||||
batch_texts = texts[start_index:end_index]
|
|
||||||
tokenize_start_time = time.time()
|
|
||||||
inputs = hf_tokenizer(
|
|
||||||
batch_texts,
|
|
||||||
padding=True,
|
|
||||||
truncation=True,
|
|
||||||
max_length=max_length,
|
|
||||||
return_tensors="pt",
|
|
||||||
)
|
|
||||||
tokenize_end_time = time.time()
|
|
||||||
logger.info(
|
|
||||||
f"Tokenize time taken: {tokenize_end_time - tokenize_start_time} seconds"
|
|
||||||
)
|
|
||||||
# Print shapes of all input tensors for debugging
|
|
||||||
for k, v in inputs.items():
|
|
||||||
print(f"inputs[{k!r}] shape: {getattr(v, 'shape', type(v))}")
|
|
||||||
to_device_start_time = time.time()
|
|
||||||
inputs = {k: v.to(device) for k, v in inputs.items()}
|
|
||||||
to_device_end_time = time.time()
|
|
||||||
logger.info(
|
|
||||||
f"To device time taken: {to_device_end_time - to_device_start_time} seconds"
|
|
||||||
)
|
|
||||||
forward_start_time = time.time()
|
|
||||||
outputs = hf_model(**inputs)
|
|
||||||
forward_end_time = time.time()
|
|
||||||
logger.info(f"Forward time taken: {forward_end_time - forward_start_time} seconds")
|
|
||||||
last_hidden_state = outputs.last_hidden_state # (B, L, H)
|
|
||||||
attention_mask = inputs.get("attention_mask")
|
|
||||||
if attention_mask is None:
|
|
||||||
# Fallback: assume all tokens are valid
|
|
||||||
pooled = last_hidden_state.mean(dim=1)
|
|
||||||
else:
|
|
||||||
mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
|
|
||||||
masked = last_hidden_state * mask
|
|
||||||
lengths = mask.sum(dim=1).clamp(min=1)
|
|
||||||
pooled = masked.sum(dim=1) / lengths
|
|
||||||
# Move to CPU float32
|
|
||||||
batch_embeddings = pooled.detach().to("cpu").float().numpy()
|
|
||||||
all_embeddings.append(batch_embeddings)
|
|
||||||
|
|
||||||
embeddings = np.vstack(all_embeddings).astype(np.float32, copy=False)
|
|
||||||
try:
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
end_time = time.time()
|
|
||||||
logger.info(f"Manual tokenize time taken: {end_time - start_time_manual} seconds")
|
|
||||||
end_time = time.time()
|
|
||||||
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
|
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
|
||||||
logger.info(f"Time taken: {end_time - start_time} seconds")
|
|
||||||
|
|
||||||
# Validate results
|
# Validate results
|
||||||
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||||
@@ -363,16 +242,6 @@ def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
|
|||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise ImportError(f"OpenAI package not installed: {e}")
|
raise ImportError(f"OpenAI package not installed: {e}")
|
||||||
|
|
||||||
# Validate input list
|
|
||||||
if not texts:
|
|
||||||
raise ValueError("Cannot compute embeddings for empty text list")
|
|
||||||
# Extra validation: abort early if any item is empty/whitespace
|
|
||||||
invalid_count = sum(1 for t in texts if not isinstance(t, str) or not t.strip())
|
|
||||||
if invalid_count > 0:
|
|
||||||
raise ValueError(
|
|
||||||
f"Found {invalid_count} empty/invalid text(s) in input. Upstream should filter before calling OpenAI."
|
|
||||||
)
|
|
||||||
|
|
||||||
api_key = os.getenv("OPENAI_API_KEY")
|
api_key = os.getenv("OPENAI_API_KEY")
|
||||||
if not api_key:
|
if not api_key:
|
||||||
raise RuntimeError("OPENAI_API_KEY environment variable not set")
|
raise RuntimeError("OPENAI_API_KEY environment variable not set")
|
||||||
@@ -392,16 +261,8 @@ def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
|
|||||||
print(f"len of texts: {len(texts)}")
|
print(f"len of texts: {len(texts)}")
|
||||||
|
|
||||||
# OpenAI has limits on batch size and input length
|
# OpenAI has limits on batch size and input length
|
||||||
max_batch_size = 800 # Conservative batch size because the token limit is 300K
|
max_batch_size = 1000 # Conservative batch size
|
||||||
all_embeddings = []
|
all_embeddings = []
|
||||||
# get the avg len of texts
|
|
||||||
avg_len = sum(len(text) for text in texts) / len(texts)
|
|
||||||
print(f"avg len of texts: {avg_len}")
|
|
||||||
# if avg len is less than 1000, use the max batch size
|
|
||||||
if avg_len > 300:
|
|
||||||
max_batch_size = 500
|
|
||||||
|
|
||||||
# if avg len is less than 1000, use the max batch size
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@@ -504,366 +365,3 @@ def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int =
|
|||||||
|
|
||||||
# Stack numpy arrays
|
# Stack numpy arrays
|
||||||
return np.stack(all_embeddings)
|
return np.stack(all_embeddings)
|
||||||
|
|
||||||
|
|
||||||
def compute_embeddings_ollama(
|
|
||||||
texts: list[str], model_name: str, is_build: bool = False, host: str = "http://localhost:11434"
|
|
||||||
) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
Compute embeddings using Ollama API with simplified batch processing.
|
|
||||||
|
|
||||||
Uses batch size of 32 for MPS/CPU and 128 for CUDA to optimize performance.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
texts: List of texts to compute embeddings for
|
|
||||||
model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large")
|
|
||||||
is_build: Whether this is a build operation (shows progress bar)
|
|
||||||
host: Ollama host URL (default: http://localhost:11434)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
import requests
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError(
|
|
||||||
"The 'requests' library is required for Ollama embeddings. Install with: uv pip install requests"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not texts:
|
|
||||||
raise ValueError("Cannot compute embeddings for empty text list")
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}'"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check if Ollama is running
|
|
||||||
try:
|
|
||||||
response = requests.get(f"{host}/api/version", timeout=5)
|
|
||||||
response.raise_for_status()
|
|
||||||
except requests.exceptions.ConnectionError:
|
|
||||||
error_msg = (
|
|
||||||
f"❌ Could not connect to Ollama at {host}.\n\n"
|
|
||||||
"Please ensure Ollama is running:\n"
|
|
||||||
" • macOS/Linux: ollama serve\n"
|
|
||||||
" • Windows: Make sure Ollama is running in the system tray\n\n"
|
|
||||||
"Installation: https://ollama.com/download"
|
|
||||||
)
|
|
||||||
raise RuntimeError(error_msg)
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(f"Unexpected error connecting to Ollama: {e}")
|
|
||||||
|
|
||||||
# Check if model exists and provide helpful suggestions
|
|
||||||
try:
|
|
||||||
response = requests.get(f"{host}/api/tags", timeout=5)
|
|
||||||
response.raise_for_status()
|
|
||||||
models = response.json()
|
|
||||||
model_names = [model["name"] for model in models.get("models", [])]
|
|
||||||
|
|
||||||
# Filter for embedding models (models that support embeddings)
|
|
||||||
embedding_models = []
|
|
||||||
suggested_embedding_models = [
|
|
||||||
"nomic-embed-text",
|
|
||||||
"mxbai-embed-large",
|
|
||||||
"bge-m3",
|
|
||||||
"all-minilm",
|
|
||||||
"snowflake-arctic-embed",
|
|
||||||
]
|
|
||||||
|
|
||||||
for model in model_names:
|
|
||||||
# Check if it's an embedding model (by name patterns or known models)
|
|
||||||
base_name = model.split(":")[0]
|
|
||||||
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5"]):
|
|
||||||
embedding_models.append(model)
|
|
||||||
|
|
||||||
# Check if model exists (handle versioned names) and resolve to full name
|
|
||||||
resolved_model_name = None
|
|
||||||
for name in model_names:
|
|
||||||
# Exact match
|
|
||||||
if model_name == name:
|
|
||||||
resolved_model_name = name
|
|
||||||
break
|
|
||||||
# Match without version tag (use the versioned name)
|
|
||||||
elif model_name == name.split(":")[0]:
|
|
||||||
resolved_model_name = name
|
|
||||||
break
|
|
||||||
|
|
||||||
if not resolved_model_name:
|
|
||||||
error_msg = f"❌ Model '{model_name}' not found in local Ollama.\n\n"
|
|
||||||
|
|
||||||
# Suggest pulling the model
|
|
||||||
error_msg += "📦 To install this embedding model:\n"
|
|
||||||
error_msg += f" ollama pull {model_name}\n\n"
|
|
||||||
|
|
||||||
# Show available embedding models
|
|
||||||
if embedding_models:
|
|
||||||
error_msg += "✅ Available embedding models:\n"
|
|
||||||
for model in embedding_models[:5]:
|
|
||||||
error_msg += f" • {model}\n"
|
|
||||||
if len(embedding_models) > 5:
|
|
||||||
error_msg += f" ... and {len(embedding_models) - 5} more\n"
|
|
||||||
else:
|
|
||||||
error_msg += "💡 Popular embedding models to install:\n"
|
|
||||||
for model in suggested_embedding_models[:3]:
|
|
||||||
error_msg += f" • ollama pull {model}\n"
|
|
||||||
|
|
||||||
error_msg += "\n📚 Browse more: https://ollama.com/library"
|
|
||||||
raise ValueError(error_msg)
|
|
||||||
|
|
||||||
# Use the resolved model name for all subsequent operations
|
|
||||||
if resolved_model_name != model_name:
|
|
||||||
logger.info(f"Resolved model name '{model_name}' to '{resolved_model_name}'")
|
|
||||||
model_name = resolved_model_name
|
|
||||||
|
|
||||||
# Verify the model supports embeddings by testing it
|
|
||||||
try:
|
|
||||||
test_response = requests.post(
|
|
||||||
f"{host}/api/embeddings", json={"model": model_name, "prompt": "test"}, timeout=10
|
|
||||||
)
|
|
||||||
if test_response.status_code != 200:
|
|
||||||
error_msg = (
|
|
||||||
f"⚠️ Model '{model_name}' exists but may not support embeddings.\n\n"
|
|
||||||
f"Please use an embedding model like:\n"
|
|
||||||
)
|
|
||||||
for model in suggested_embedding_models[:3]:
|
|
||||||
error_msg += f" • {model}\n"
|
|
||||||
raise ValueError(error_msg)
|
|
||||||
except requests.exceptions.RequestException:
|
|
||||||
# If test fails, continue anyway - model might still work
|
|
||||||
pass
|
|
||||||
|
|
||||||
except requests.exceptions.RequestException as e:
|
|
||||||
logger.warning(f"Could not verify model existence: {e}")
|
|
||||||
|
|
||||||
# Determine batch size based on device availability
|
|
||||||
# Check for CUDA/MPS availability using torch if available
|
|
||||||
batch_size = 32 # Default for MPS/CPU
|
|
||||||
try:
|
|
||||||
import torch
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
batch_size = 128 # CUDA gets larger batch size
|
|
||||||
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
|
||||||
batch_size = 32 # MPS gets smaller batch size
|
|
||||||
except ImportError:
|
|
||||||
# If torch is not available, use conservative batch size
|
|
||||||
batch_size = 32
|
|
||||||
|
|
||||||
logger.info(f"Using batch size: {batch_size}")
|
|
||||||
|
|
||||||
def get_batch_embeddings(batch_texts):
|
|
||||||
"""Get embeddings for a batch of texts."""
|
|
||||||
all_embeddings = []
|
|
||||||
failed_indices = []
|
|
||||||
|
|
||||||
for i, text in enumerate(batch_texts):
|
|
||||||
max_retries = 3
|
|
||||||
retry_count = 0
|
|
||||||
|
|
||||||
# Truncate very long texts to avoid API issues
|
|
||||||
truncated_text = text[:8000] if len(text) > 8000 else text
|
|
||||||
while retry_count < max_retries:
|
|
||||||
try:
|
|
||||||
response = requests.post(
|
|
||||||
f"{host}/api/embeddings",
|
|
||||||
json={"model": model_name, "prompt": truncated_text},
|
|
||||||
timeout=30,
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
result = response.json()
|
|
||||||
embedding = result.get("embedding")
|
|
||||||
|
|
||||||
if embedding is None:
|
|
||||||
raise ValueError(f"No embedding returned for text {i}")
|
|
||||||
|
|
||||||
if not isinstance(embedding, list) or len(embedding) == 0:
|
|
||||||
raise ValueError(f"Invalid embedding format for text {i}")
|
|
||||||
|
|
||||||
all_embeddings.append(embedding)
|
|
||||||
break
|
|
||||||
|
|
||||||
except requests.exceptions.Timeout:
|
|
||||||
retry_count += 1
|
|
||||||
if retry_count >= max_retries:
|
|
||||||
logger.warning(f"Timeout for text {i} after {max_retries} retries")
|
|
||||||
failed_indices.append(i)
|
|
||||||
all_embeddings.append(None)
|
|
||||||
break
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
retry_count += 1
|
|
||||||
if retry_count >= max_retries:
|
|
||||||
logger.error(f"Failed to get embedding for text {i}: {e}")
|
|
||||||
failed_indices.append(i)
|
|
||||||
all_embeddings.append(None)
|
|
||||||
break
|
|
||||||
return all_embeddings, failed_indices
|
|
||||||
|
|
||||||
# Process texts in batches
|
|
||||||
all_embeddings = []
|
|
||||||
all_failed_indices = []
|
|
||||||
|
|
||||||
# Setup progress bar if needed
|
|
||||||
show_progress = is_build or len(texts) > 10
|
|
||||||
try:
|
|
||||||
if show_progress:
|
|
||||||
from tqdm import tqdm
|
|
||||||
except ImportError:
|
|
||||||
show_progress = False
|
|
||||||
|
|
||||||
# Process batches
|
|
||||||
num_batches = (len(texts) + batch_size - 1) // batch_size
|
|
||||||
|
|
||||||
if show_progress:
|
|
||||||
batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings")
|
|
||||||
else:
|
|
||||||
batch_iterator = range(num_batches)
|
|
||||||
|
|
||||||
for batch_idx in batch_iterator:
|
|
||||||
start_idx = batch_idx * batch_size
|
|
||||||
end_idx = min(start_idx + batch_size, len(texts))
|
|
||||||
batch_texts = texts[start_idx:end_idx]
|
|
||||||
|
|
||||||
batch_embeddings, batch_failed = get_batch_embeddings(batch_texts)
|
|
||||||
|
|
||||||
# Adjust failed indices to global indices
|
|
||||||
global_failed = [start_idx + idx for idx in batch_failed]
|
|
||||||
all_failed_indices.extend(global_failed)
|
|
||||||
all_embeddings.extend(batch_embeddings)
|
|
||||||
|
|
||||||
# Handle failed embeddings
|
|
||||||
if all_failed_indices:
|
|
||||||
if len(all_failed_indices) == len(texts):
|
|
||||||
raise RuntimeError("Failed to compute any embeddings")
|
|
||||||
|
|
||||||
logger.warning(
|
|
||||||
f"Failed to compute embeddings for {len(all_failed_indices)}/{len(texts)} texts"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use zero embeddings as fallback for failed ones
|
|
||||||
valid_embedding = next((e for e in all_embeddings if e is not None), None)
|
|
||||||
if valid_embedding:
|
|
||||||
embedding_dim = len(valid_embedding)
|
|
||||||
for i, embedding in enumerate(all_embeddings):
|
|
||||||
if embedding is None:
|
|
||||||
all_embeddings[i] = [0.0] * embedding_dim
|
|
||||||
|
|
||||||
# Remove None values
|
|
||||||
all_embeddings = [e for e in all_embeddings if e is not None]
|
|
||||||
|
|
||||||
if not all_embeddings:
|
|
||||||
raise RuntimeError("No valid embeddings were computed")
|
|
||||||
|
|
||||||
# Validate embedding dimensions
|
|
||||||
expected_dim = len(all_embeddings[0])
|
|
||||||
inconsistent_dims = []
|
|
||||||
for i, embedding in enumerate(all_embeddings):
|
|
||||||
if len(embedding) != expected_dim:
|
|
||||||
inconsistent_dims.append((i, len(embedding)))
|
|
||||||
|
|
||||||
if inconsistent_dims:
|
|
||||||
error_msg = f"Ollama returned inconsistent embedding dimensions. Expected {expected_dim}, but got:\n"
|
|
||||||
for idx, dim in inconsistent_dims[:10]: # Show first 10 inconsistent ones
|
|
||||||
error_msg += f" - Text {idx}: {dim} dimensions\n"
|
|
||||||
if len(inconsistent_dims) > 10:
|
|
||||||
error_msg += f" ... and {len(inconsistent_dims) - 10} more\n"
|
|
||||||
error_msg += f"\nThis is likely an Ollama API bug with model '{model_name}'. Please try:\n"
|
|
||||||
error_msg += "1. Restart Ollama service: 'ollama serve'\n"
|
|
||||||
error_msg += f"2. Re-pull the model: 'ollama pull {model_name}'\n"
|
|
||||||
error_msg += (
|
|
||||||
"3. Use sentence-transformers instead: --embedding-mode sentence-transformers\n"
|
|
||||||
)
|
|
||||||
error_msg += "4. Report this issue to Ollama: https://github.com/ollama/ollama/issues"
|
|
||||||
raise ValueError(error_msg)
|
|
||||||
|
|
||||||
# Convert to numpy array and normalize
|
|
||||||
embeddings = np.array(all_embeddings, dtype=np.float32)
|
|
||||||
|
|
||||||
# Normalize embeddings (L2 normalization)
|
|
||||||
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
|
|
||||||
embeddings = embeddings / (norms + 1e-8) # Add small epsilon to avoid division by zero
|
|
||||||
|
|
||||||
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
|
|
||||||
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
|
|
||||||
def compute_embeddings_gemini(
|
|
||||||
texts: list[str], model_name: str = "text-embedding-004", is_build: bool = False
|
|
||||||
) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
Compute embeddings using Google Gemini API.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
texts: List of texts to compute embeddings for
|
|
||||||
model_name: Gemini model name (default: "text-embedding-004")
|
|
||||||
is_build: Whether this is a build operation (shows progress bar)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Embeddings array, shape: (len(texts), embedding_dim)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
import os
|
|
||||||
|
|
||||||
import google.genai as genai
|
|
||||||
except ImportError as e:
|
|
||||||
raise ImportError(f"Google GenAI package not installed: {e}")
|
|
||||||
|
|
||||||
api_key = os.getenv("GEMINI_API_KEY")
|
|
||||||
if not api_key:
|
|
||||||
raise RuntimeError("GEMINI_API_KEY environment variable not set")
|
|
||||||
|
|
||||||
# Cache Gemini client
|
|
||||||
cache_key = "gemini_client"
|
|
||||||
if cache_key in _model_cache:
|
|
||||||
client = _model_cache[cache_key]
|
|
||||||
else:
|
|
||||||
client = genai.Client(api_key=api_key)
|
|
||||||
_model_cache[cache_key] = client
|
|
||||||
logger.info("Gemini client cached")
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Computing embeddings for {len(texts)} texts using Gemini API, model: '{model_name}'"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Gemini supports batch embedding
|
|
||||||
max_batch_size = 100 # Conservative batch size for Gemini
|
|
||||||
all_embeddings = []
|
|
||||||
|
|
||||||
try:
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
total_batches = (len(texts) + max_batch_size - 1) // max_batch_size
|
|
||||||
batch_range = range(0, len(texts), max_batch_size)
|
|
||||||
batch_iterator = tqdm(
|
|
||||||
batch_range, desc="Computing embeddings", unit="batch", total=total_batches
|
|
||||||
)
|
|
||||||
except ImportError:
|
|
||||||
# Fallback when tqdm is not available
|
|
||||||
batch_iterator = range(0, len(texts), max_batch_size)
|
|
||||||
|
|
||||||
for i in batch_iterator:
|
|
||||||
batch_texts = texts[i : i + max_batch_size]
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Use the embed_content method from the new Google GenAI SDK
|
|
||||||
response = client.models.embed_content(
|
|
||||||
model=model_name,
|
|
||||||
contents=batch_texts,
|
|
||||||
config=genai.types.EmbedContentConfig(
|
|
||||||
task_type="RETRIEVAL_DOCUMENT" # For document embedding
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract embeddings from response
|
|
||||||
for embedding_data in response.embeddings:
|
|
||||||
all_embeddings.append(embedding_data.values)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Batch {i} failed: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
embeddings = np.array(all_embeddings, dtype=np.float32)
|
|
||||||
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
|
|
||||||
|
|
||||||
return embeddings
|
|
||||||
|
|||||||
@@ -6,9 +6,8 @@ import subprocess
|
|||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
# Lightweight, self-contained server manager with no cross-process inspection
|
import psutil
|
||||||
|
|
||||||
# Set up logging based on environment variable
|
# Set up logging based on environment variable
|
||||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||||
@@ -43,7 +42,130 @@ def _check_port(port: int) -> bool:
|
|||||||
return s.connect_ex(("localhost", port)) == 0
|
return s.connect_ex(("localhost", port)) == 0
|
||||||
|
|
||||||
|
|
||||||
# Note: All cross-process scanning helpers removed for simplicity
|
def _check_process_matches_config(
|
||||||
|
port: int, expected_model: str, expected_passages_file: str
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the process using the port matches our expected model and passages file.
|
||||||
|
Returns True if matches, False otherwise.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
for proc in psutil.process_iter(["pid", "cmdline"]):
|
||||||
|
if not _is_process_listening_on_port(proc, port):
|
||||||
|
continue
|
||||||
|
|
||||||
|
cmdline = proc.info["cmdline"]
|
||||||
|
if not cmdline:
|
||||||
|
continue
|
||||||
|
|
||||||
|
return _check_cmdline_matches_config(
|
||||||
|
cmdline, port, expected_model, expected_passages_file
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"No process found listening on port {port}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Could not check process on port {port}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _is_process_listening_on_port(proc, port: int) -> bool:
|
||||||
|
"""Check if a process is listening on the given port."""
|
||||||
|
try:
|
||||||
|
connections = proc.net_connections()
|
||||||
|
for conn in connections:
|
||||||
|
if conn.laddr.port == port and conn.status == psutil.CONN_LISTEN:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _check_cmdline_matches_config(
|
||||||
|
cmdline: list, port: int, expected_model: str, expected_passages_file: str
|
||||||
|
) -> bool:
|
||||||
|
"""Check if command line matches our expected configuration."""
|
||||||
|
cmdline_str = " ".join(cmdline)
|
||||||
|
logger.debug(f"Found process on port {port}: {cmdline_str}")
|
||||||
|
|
||||||
|
# Check if it's our embedding server
|
||||||
|
is_embedding_server = any(
|
||||||
|
server_type in cmdline_str
|
||||||
|
for server_type in [
|
||||||
|
"embedding_server",
|
||||||
|
"leann_backend_diskann.embedding_server",
|
||||||
|
"leann_backend_hnsw.hnsw_embedding_server",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
if not is_embedding_server:
|
||||||
|
logger.debug(f"Process on port {port} is not our embedding server")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check model name
|
||||||
|
model_matches = _check_model_in_cmdline(cmdline, expected_model)
|
||||||
|
|
||||||
|
# Check passages file if provided
|
||||||
|
passages_matches = _check_passages_in_cmdline(cmdline, expected_passages_file)
|
||||||
|
|
||||||
|
result = model_matches and passages_matches
|
||||||
|
logger.debug(
|
||||||
|
f"model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}"
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _check_model_in_cmdline(cmdline: list, expected_model: str) -> bool:
|
||||||
|
"""Check if the command line contains the expected model."""
|
||||||
|
if "--model-name" not in cmdline:
|
||||||
|
return False
|
||||||
|
|
||||||
|
model_idx = cmdline.index("--model-name")
|
||||||
|
if model_idx + 1 >= len(cmdline):
|
||||||
|
return False
|
||||||
|
|
||||||
|
actual_model = cmdline[model_idx + 1]
|
||||||
|
return actual_model == expected_model
|
||||||
|
|
||||||
|
|
||||||
|
def _check_passages_in_cmdline(cmdline: list, expected_passages_file: str) -> bool:
|
||||||
|
"""Check if the command line contains the expected passages file."""
|
||||||
|
if "--passages-file" not in cmdline:
|
||||||
|
return False # Expected but not found
|
||||||
|
|
||||||
|
passages_idx = cmdline.index("--passages-file")
|
||||||
|
if passages_idx + 1 >= len(cmdline):
|
||||||
|
return False
|
||||||
|
|
||||||
|
actual_passages = cmdline[passages_idx + 1]
|
||||||
|
expected_path = Path(expected_passages_file).resolve()
|
||||||
|
actual_path = Path(actual_passages).resolve()
|
||||||
|
return actual_path == expected_path
|
||||||
|
|
||||||
|
|
||||||
|
def _find_compatible_port_or_next_available(
|
||||||
|
start_port: int, model_name: str, passages_file: str, max_attempts: int = 100
|
||||||
|
) -> tuple[int, bool]:
|
||||||
|
"""
|
||||||
|
Find a port that either has a compatible server or is available.
|
||||||
|
Returns (port, is_compatible) where is_compatible indicates if we found a matching server.
|
||||||
|
"""
|
||||||
|
for port in range(start_port, start_port + max_attempts):
|
||||||
|
if not _check_port(port):
|
||||||
|
# Port is available
|
||||||
|
return port, False
|
||||||
|
|
||||||
|
# Port is in use, check if it's compatible
|
||||||
|
if _check_process_matches_config(port, model_name, passages_file):
|
||||||
|
logger.info(f"Found compatible server on port {port}")
|
||||||
|
return port, True
|
||||||
|
else:
|
||||||
|
logger.info(f"Port {port} has incompatible server, trying next port...")
|
||||||
|
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Could not find compatible or available port in range {start_port}-{start_port + max_attempts}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingServerManager:
|
class EmbeddingServerManager:
|
||||||
@@ -60,18 +182,9 @@ class EmbeddingServerManager:
|
|||||||
e.g., "leann_backend_diskann.embedding_server"
|
e.g., "leann_backend_diskann.embedding_server"
|
||||||
"""
|
"""
|
||||||
self.backend_module_name = backend_module_name
|
self.backend_module_name = backend_module_name
|
||||||
self.server_process: Optional[subprocess.Popen] = None
|
self.server_process: subprocess.Popen | None = None
|
||||||
self.server_port: Optional[int] = None
|
self.server_port: int | None = None
|
||||||
# Track last-started config for in-process reuse only
|
|
||||||
self._server_config: Optional[dict] = None
|
|
||||||
self._atexit_registered = False
|
self._atexit_registered = False
|
||||||
# Also register a weakref finalizer to ensure cleanup when manager is GC'ed
|
|
||||||
try:
|
|
||||||
import weakref
|
|
||||||
|
|
||||||
self._finalizer = weakref.finalize(self, self._finalize_process)
|
|
||||||
except Exception:
|
|
||||||
self._finalizer = None
|
|
||||||
|
|
||||||
def start_server(
|
def start_server(
|
||||||
self,
|
self,
|
||||||
@@ -81,24 +194,26 @@ class EmbeddingServerManager:
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
) -> tuple[bool, int]:
|
) -> tuple[bool, int]:
|
||||||
"""Start the embedding server."""
|
"""Start the embedding server."""
|
||||||
# passages_file may be present in kwargs for server CLI, but we don't need it here
|
passages_file = kwargs.get("passages_file")
|
||||||
|
|
||||||
# If this manager already has a live server, just reuse it
|
# Check if we have a compatible server already running
|
||||||
if self.server_process and self.server_process.poll() is None and self.server_port:
|
if self._has_compatible_running_server(model_name, passages_file):
|
||||||
logger.info("Reusing in-process server")
|
logger.info("Found compatible running server!")
|
||||||
return True, self.server_port
|
return True, port
|
||||||
|
|
||||||
# For Colab environment, use a different strategy
|
# For Colab environment, use a different strategy
|
||||||
if _is_colab_environment():
|
if _is_colab_environment():
|
||||||
logger.info("Detected Colab environment, using alternative startup strategy")
|
logger.info("Detected Colab environment, using alternative startup strategy")
|
||||||
return self._start_server_colab(port, model_name, embedding_mode, **kwargs)
|
return self._start_server_colab(port, model_name, embedding_mode, **kwargs)
|
||||||
|
|
||||||
# Always pick a fresh available port
|
# Find a compatible port or next available
|
||||||
try:
|
actual_port, is_compatible = _find_compatible_port_or_next_available(
|
||||||
actual_port = _get_available_port(port)
|
port, model_name, passages_file
|
||||||
except RuntimeError:
|
)
|
||||||
logger.error("No available ports found")
|
|
||||||
return False, port
|
if is_compatible:
|
||||||
|
logger.info(f"Found compatible server on port {actual_port}")
|
||||||
|
return True, actual_port
|
||||||
|
|
||||||
# Start a new server
|
# Start a new server
|
||||||
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
|
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
|
||||||
@@ -131,7 +246,17 @@ class EmbeddingServerManager:
|
|||||||
logger.error(f"Failed to start embedding server in Colab: {e}")
|
logger.error(f"Failed to start embedding server in Colab: {e}")
|
||||||
return False, actual_port
|
return False, actual_port
|
||||||
|
|
||||||
# Note: No compatibility check needed; manager is per-searcher and configs are stable per instance
|
def _has_compatible_running_server(self, model_name: str, passages_file: str) -> bool:
|
||||||
|
"""Check if we have a compatible running server."""
|
||||||
|
if not (self.server_process and self.server_process.poll() is None and self.server_port):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if _check_process_matches_config(self.server_port, model_name, passages_file):
|
||||||
|
logger.info(f"Existing server process (PID {self.server_process.pid}) is compatible")
|
||||||
|
return True
|
||||||
|
|
||||||
|
logger.info("Existing server process is incompatible. Should start a new server.")
|
||||||
|
return False
|
||||||
|
|
||||||
def _start_new_server(
|
def _start_new_server(
|
||||||
self, port: int, model_name: str, embedding_mode: str, **kwargs
|
self, port: int, model_name: str, embedding_mode: str, **kwargs
|
||||||
@@ -178,62 +303,22 @@ class EmbeddingServerManager:
|
|||||||
project_root = Path(__file__).parent.parent.parent.parent.parent
|
project_root = Path(__file__).parent.parent.parent.parent.parent
|
||||||
logger.info(f"Command: {' '.join(command)}")
|
logger.info(f"Command: {' '.join(command)}")
|
||||||
|
|
||||||
# In CI environment, redirect stdout to avoid buffer deadlock but keep stderr for debugging
|
# Let server output go directly to console
|
||||||
# Embedding servers use many print statements that can fill stdout buffers
|
# The server will respect LEANN_LOG_LEVEL environment variable
|
||||||
is_ci = os.environ.get("CI") == "true"
|
|
||||||
if is_ci:
|
|
||||||
stdout_target = subprocess.DEVNULL
|
|
||||||
stderr_target = None # Keep stderr for error debugging in CI
|
|
||||||
logger.info(
|
|
||||||
"CI environment detected, redirecting embedding server stdout to DEVNULL, keeping stderr"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
stdout_target = None # Direct to console for visible logs
|
|
||||||
stderr_target = None # Direct to console for visible logs
|
|
||||||
|
|
||||||
# Start embedding server subprocess
|
|
||||||
logger.info(f"Starting server process with command: {' '.join(command)}")
|
|
||||||
self.server_process = subprocess.Popen(
|
self.server_process = subprocess.Popen(
|
||||||
command,
|
command,
|
||||||
cwd=project_root,
|
cwd=project_root,
|
||||||
stdout=stdout_target,
|
stdout=None, # Direct to console
|
||||||
stderr=stderr_target,
|
stderr=None, # Direct to console
|
||||||
)
|
)
|
||||||
self.server_port = port
|
self.server_port = port
|
||||||
# Record config for in-process reuse
|
|
||||||
try:
|
|
||||||
self._server_config = {
|
|
||||||
"model_name": command[command.index("--model-name") + 1]
|
|
||||||
if "--model-name" in command
|
|
||||||
else "",
|
|
||||||
"passages_file": command[command.index("--passages-file") + 1]
|
|
||||||
if "--passages-file" in command
|
|
||||||
else "",
|
|
||||||
"embedding_mode": command[command.index("--embedding-mode") + 1]
|
|
||||||
if "--embedding-mode" in command
|
|
||||||
else "sentence-transformers",
|
|
||||||
}
|
|
||||||
except Exception:
|
|
||||||
self._server_config = {
|
|
||||||
"model_name": "",
|
|
||||||
"passages_file": "",
|
|
||||||
"embedding_mode": "sentence-transformers",
|
|
||||||
}
|
|
||||||
logger.info(f"Server process started with PID: {self.server_process.pid}")
|
logger.info(f"Server process started with PID: {self.server_process.pid}")
|
||||||
|
|
||||||
# Register atexit callback only when we actually start a process
|
# Register atexit callback only when we actually start a process
|
||||||
if not self._atexit_registered:
|
if not self._atexit_registered:
|
||||||
# Always attempt best-effort finalize at interpreter exit
|
# Use a lambda to avoid issues with bound methods
|
||||||
atexit.register(self._finalize_process)
|
atexit.register(lambda: self.stop_server() if self.server_process else None)
|
||||||
self._atexit_registered = True
|
self._atexit_registered = True
|
||||||
# Touch finalizer so it knows there is a live process
|
|
||||||
if getattr(self, "_finalizer", None) is not None and not self._finalizer.alive:
|
|
||||||
try:
|
|
||||||
import weakref
|
|
||||||
|
|
||||||
self._finalizer = weakref.finalize(self, self._finalize_process)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _wait_for_server_ready(self, port: int) -> tuple[bool, int]:
|
def _wait_for_server_ready(self, port: int) -> tuple[bool, int]:
|
||||||
"""Wait for the server to be ready."""
|
"""Wait for the server to be ready."""
|
||||||
@@ -258,35 +343,24 @@ class EmbeddingServerManager:
|
|||||||
if not self.server_process:
|
if not self.server_process:
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.server_process and self.server_process.poll() is not None:
|
if self.server_process.poll() is not None:
|
||||||
# Process already terminated
|
# Process already terminated
|
||||||
self.server_process = None
|
self.server_process = None
|
||||||
self.server_port = None
|
|
||||||
self._server_config = None
|
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
|
f"Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
|
||||||
)
|
)
|
||||||
|
self.server_process.terminate()
|
||||||
# Use simple termination first; if the server installed signal handlers,
|
|
||||||
# it will exit cleanly. Otherwise escalate to kill after a short wait.
|
|
||||||
try:
|
|
||||||
self.server_process.terminate()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.server_process.wait(timeout=5) # Give more time for graceful shutdown
|
self.server_process.wait(timeout=3)
|
||||||
logger.info(f"Server process {self.server_process.pid} terminated gracefully.")
|
logger.info(f"Server process {self.server_process.pid} terminated.")
|
||||||
except subprocess.TimeoutExpired:
|
except subprocess.TimeoutExpired:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Server process {self.server_process.pid} did not terminate within 5 seconds, force killing..."
|
f"Server process {self.server_process.pid} did not terminate gracefully within 3 seconds, killing it."
|
||||||
)
|
)
|
||||||
try:
|
self.server_process.kill()
|
||||||
self.server_process.kill()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
try:
|
try:
|
||||||
self.server_process.wait(timeout=2)
|
self.server_process.wait(timeout=2)
|
||||||
logger.info(f"Server process {self.server_process.pid} killed successfully.")
|
logger.info(f"Server process {self.server_process.pid} killed successfully.")
|
||||||
@@ -294,33 +368,15 @@ class EmbeddingServerManager:
|
|||||||
logger.error(
|
logger.error(
|
||||||
f"Failed to kill server process {self.server_process.pid} - it may be hung"
|
f"Failed to kill server process {self.server_process.pid} - it may be hung"
|
||||||
)
|
)
|
||||||
|
# Don't hang indefinitely
|
||||||
|
|
||||||
# Clean up process resources with timeout to avoid CI hang
|
# Clean up process resources to prevent resource tracker warnings
|
||||||
try:
|
try:
|
||||||
# Use shorter timeout in CI environments
|
self.server_process.wait() # Ensure process is fully cleaned up
|
||||||
is_ci = os.environ.get("CI") == "true"
|
|
||||||
timeout = 3 if is_ci else 10
|
|
||||||
self.server_process.wait(timeout=timeout)
|
|
||||||
logger.info(f"Server process {self.server_process.pid} cleanup completed")
|
|
||||||
except subprocess.TimeoutExpired:
|
|
||||||
logger.warning(f"Process cleanup timeout after {timeout}s, proceeding anyway")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Error during process cleanup: {e}")
|
|
||||||
finally:
|
|
||||||
self.server_process = None
|
|
||||||
self.server_port = None
|
|
||||||
self._server_config = None
|
|
||||||
|
|
||||||
def _finalize_process(self) -> None:
|
|
||||||
"""Best-effort cleanup used by weakref.finalize/atexit."""
|
|
||||||
try:
|
|
||||||
self.stop_server()
|
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _adopt_existing_server(self, *args, **kwargs) -> None:
|
self.server_process = None
|
||||||
# Removed: cross-process adoption no longer supported
|
|
||||||
return
|
|
||||||
|
|
||||||
def _launch_server_process_colab(self, command: list, port: int) -> None:
|
def _launch_server_process_colab(self, command: list, port: int) -> None:
|
||||||
"""Launch the server process with Colab-specific settings."""
|
"""Launch the server process with Colab-specific settings."""
|
||||||
@@ -336,16 +392,10 @@ class EmbeddingServerManager:
|
|||||||
self.server_port = port
|
self.server_port = port
|
||||||
logger.info(f"Colab server process started with PID: {self.server_process.pid}")
|
logger.info(f"Colab server process started with PID: {self.server_process.pid}")
|
||||||
|
|
||||||
# Register atexit callback (unified)
|
# Register atexit callback
|
||||||
if not self._atexit_registered:
|
if not self._atexit_registered:
|
||||||
atexit.register(self._finalize_process)
|
atexit.register(lambda: self.stop_server() if self.server_process else None)
|
||||||
self._atexit_registered = True
|
self._atexit_registered = True
|
||||||
# Record config for in-process reuse is best-effort in Colab mode
|
|
||||||
self._server_config = {
|
|
||||||
"model_name": "",
|
|
||||||
"passages_file": "",
|
|
||||||
"embedding_mode": "sentence-transformers",
|
|
||||||
}
|
|
||||||
|
|
||||||
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
|
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
|
||||||
"""Wait for the server to be ready with Colab-specific timeout."""
|
"""Wait for the server to be ready with Colab-specific timeout."""
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Literal, Optional
|
from typing import Any, Literal
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -34,9 +34,7 @@ class LeannBackendSearcherInterface(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _ensure_server_running(
|
def _ensure_server_running(self, passages_source_file: str, port: int | None, **kwargs) -> int:
|
||||||
self, passages_source_file: str, port: Optional[int], **kwargs
|
|
||||||
) -> int:
|
|
||||||
"""Ensure server is running"""
|
"""Ensure server is running"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -50,7 +48,7 @@ class LeannBackendSearcherInterface(ABC):
|
|||||||
prune_ratio: float = 0.0,
|
prune_ratio: float = 0.0,
|
||||||
recompute_embeddings: bool = False,
|
recompute_embeddings: bool = False,
|
||||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
zmq_port: Optional[int] = None,
|
zmq_port: int | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Search for nearest neighbors
|
"""Search for nearest neighbors
|
||||||
@@ -76,7 +74,7 @@ class LeannBackendSearcherInterface(ABC):
|
|||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
use_server_if_available: bool = True,
|
use_server_if_available: bool = True,
|
||||||
zmq_port: Optional[int] = None,
|
zmq_port: int | None = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Compute embedding for a query string
|
"""Compute embedding for a query string
|
||||||
|
|
||||||
|
|||||||
@@ -1,154 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
|
|
||||||
import json
|
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
|
|
||||||
|
|
||||||
def handle_request(request):
|
|
||||||
if request.get("method") == "initialize":
|
|
||||||
return {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": request.get("id"),
|
|
||||||
"result": {
|
|
||||||
"capabilities": {"tools": {}},
|
|
||||||
"protocolVersion": "2024-11-05",
|
|
||||||
"serverInfo": {"name": "leann-mcp", "version": "1.0.0"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
elif request.get("method") == "tools/list":
|
|
||||||
return {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": request.get("id"),
|
|
||||||
"result": {
|
|
||||||
"tools": [
|
|
||||||
{
|
|
||||||
"name": "leann_search",
|
|
||||||
"description": """🔍 Search code using natural language - like having a coding assistant who knows your entire codebase!
|
|
||||||
|
|
||||||
🎯 **Perfect for**:
|
|
||||||
- "How does authentication work?" → finds auth-related code
|
|
||||||
- "Error handling patterns" → locates try-catch blocks and error logic
|
|
||||||
- "Database connection setup" → finds DB initialization code
|
|
||||||
- "API endpoint definitions" → locates route handlers
|
|
||||||
- "Configuration management" → finds config files and usage
|
|
||||||
|
|
||||||
💡 **Pro tip**: Use this before making any changes to understand existing patterns and conventions.""",
|
|
||||||
"inputSchema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"index_name": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Name of the LEANN index to search. Use 'leann_list' first to see available indexes.",
|
|
||||||
},
|
|
||||||
"query": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Search query - can be natural language (e.g., 'how to handle errors') or technical terms (e.g., 'async function definition')",
|
|
||||||
},
|
|
||||||
"top_k": {
|
|
||||||
"type": "integer",
|
|
||||||
"default": 5,
|
|
||||||
"minimum": 1,
|
|
||||||
"maximum": 20,
|
|
||||||
"description": "Number of search results to return. Use 5-10 for focused results, 15-20 for comprehensive exploration.",
|
|
||||||
},
|
|
||||||
"complexity": {
|
|
||||||
"type": "integer",
|
|
||||||
"default": 32,
|
|
||||||
"minimum": 16,
|
|
||||||
"maximum": 128,
|
|
||||||
"description": "Search complexity level. Use 16-32 for fast searches (recommended), 64+ for higher precision when needed.",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": ["index_name", "query"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "leann_list",
|
|
||||||
"description": "📋 Show all your indexed codebases - your personal code library! Use this to see what's available for search.",
|
|
||||||
"inputSchema": {"type": "object", "properties": {}},
|
|
||||||
},
|
|
||||||
]
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
elif request.get("method") == "tools/call":
|
|
||||||
tool_name = request["params"]["name"]
|
|
||||||
args = request["params"].get("arguments", {})
|
|
||||||
|
|
||||||
try:
|
|
||||||
if tool_name == "leann_search":
|
|
||||||
# Validate required parameters
|
|
||||||
if not args.get("index_name") or not args.get("query"):
|
|
||||||
return {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": request.get("id"),
|
|
||||||
"result": {
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": "Error: Both index_name and query are required",
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
# Build simplified command with non-interactive flag for MCP compatibility
|
|
||||||
cmd = [
|
|
||||||
"leann",
|
|
||||||
"search",
|
|
||||||
args["index_name"],
|
|
||||||
args["query"],
|
|
||||||
f"--top-k={args.get('top_k', 5)}",
|
|
||||||
f"--complexity={args.get('complexity', 32)}",
|
|
||||||
"--non-interactive",
|
|
||||||
]
|
|
||||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
|
||||||
|
|
||||||
elif tool_name == "leann_list":
|
|
||||||
result = subprocess.run(["leann", "list"], capture_output=True, text=True)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": request.get("id"),
|
|
||||||
"result": {
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": result.stdout
|
|
||||||
if result.returncode == 0
|
|
||||||
else f"Error: {result.stderr}",
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": request.get("id"),
|
|
||||||
"error": {"code": -1, "message": str(e)},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
for line in sys.stdin:
|
|
||||||
try:
|
|
||||||
request = json.loads(line.strip())
|
|
||||||
response = handle_request(request)
|
|
||||||
if response:
|
|
||||||
print(json.dumps(response))
|
|
||||||
sys.stdout.flush()
|
|
||||||
except Exception as e:
|
|
||||||
error_response = {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": None,
|
|
||||||
"error": {"code": -1, "message": str(e)},
|
|
||||||
}
|
|
||||||
print(json.dumps(error_response))
|
|
||||||
sys.stdout.flush()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,240 +0,0 @@
"""
Metadata filtering engine for LEANN search results.

This module provides generic metadata filtering capabilities that can be applied
to search results from any LEANN backend. The filtering supports various
operators for different data types including numbers, strings, booleans, and lists.
"""

import logging
from typing import Any, Union

logger = logging.getLogger(__name__)

# Type alias for filter specifications
FilterValue = Union[str, int, float, bool, list]
FilterSpec = dict[str, FilterValue]
MetadataFilters = dict[str, FilterSpec]


class MetadataFilterEngine:
    """
    Engine for evaluating metadata filters against search results.

    Supports various operators for filtering based on metadata fields:
    - Comparison: ==, !=, <, <=, >, >=
    - Membership: in, not_in
    - String operations: contains, starts_with, ends_with
    - Boolean operations: is_true, is_false
    """

    def __init__(self):
        """Initialize the filter engine with supported operators."""
        self.operators = {
            "==": self._equals,
            "!=": self._not_equals,
            "<": self._less_than,
            "<=": self._less_than_or_equal,
            ">": self._greater_than,
            ">=": self._greater_than_or_equal,
            "in": self._in,
            "not_in": self._not_in,
            "contains": self._contains,
            "starts_with": self._starts_with,
            "ends_with": self._ends_with,
            "is_true": self._is_true,
            "is_false": self._is_false,
        }

    def apply_filters(
        self, search_results: list[dict[str, Any]], metadata_filters: MetadataFilters
    ) -> list[dict[str, Any]]:
        """
        Apply metadata filters to a list of search results.

        Args:
            search_results: List of result dictionaries, each containing 'metadata' field
            metadata_filters: Dictionary of filter specifications
                Format: {"field_name": {"operator": value}}

        Returns:
            Filtered list of search results
        """
        if not metadata_filters:
            return search_results

        logger.debug(f"Applying filters: {metadata_filters}")
        logger.debug(f"Input results count: {len(search_results)}")

        filtered_results = []
        for result in search_results:
            if self._evaluate_filters(result, metadata_filters):
                filtered_results.append(result)

        logger.debug(f"Filtered results count: {len(filtered_results)}")
        return filtered_results

    def _evaluate_filters(self, result: dict[str, Any], filters: MetadataFilters) -> bool:
        """
        Evaluate all filters against a single search result.

        All filters must pass (AND logic) for the result to be included.

        Args:
            result: Full search result dictionary (including metadata, text, etc.)
            filters: Filter specifications to evaluate

        Returns:
            True if all filters pass, False otherwise
        """
        for field_name, filter_spec in filters.items():
            if not self._evaluate_field_filter(result, field_name, filter_spec):
                return False
        return True

    def _evaluate_field_filter(
        self, result: dict[str, Any], field_name: str, filter_spec: FilterSpec
    ) -> bool:
        """
        Evaluate a single field filter against a search result.

        Args:
            result: Full search result dictionary
            field_name: Name of the field to filter on
            filter_spec: Filter specification for this field

        Returns:
            True if the filter passes, False otherwise
        """
        # First check top-level fields, then check metadata
        field_value = result.get(field_name)
        if field_value is None:
            # Try to get from metadata if not found at top level
            metadata = result.get("metadata", {})
            field_value = metadata.get(field_name)

        # Handle missing fields - they fail all filters except existence checks
        if field_value is None:
            logger.debug(f"Field '{field_name}' not found in result or metadata")
            return False

        # Evaluate each operator in the filter spec
        for operator, expected_value in filter_spec.items():
            if operator not in self.operators:
                logger.warning(f"Unsupported operator: {operator}")
                return False

            try:
                if not self.operators[operator](field_value, expected_value):
                    logger.debug(
                        f"Filter failed: {field_name} {operator} {expected_value} "
                        f"(actual: {field_value})"
                    )
                    return False
            except Exception as e:
                logger.warning(
                    f"Error evaluating filter {field_name} {operator} {expected_value}: {e}"
                )
                return False

        return True

    # Comparison operators
    def _equals(self, field_value: Any, expected_value: Any) -> bool:
        """Check if field value equals expected value."""
        return field_value == expected_value

    def _not_equals(self, field_value: Any, expected_value: Any) -> bool:
        """Check if field value does not equal expected value."""
        return field_value != expected_value

    def _less_than(self, field_value: Any, expected_value: Any) -> bool:
        """Check if field value is less than expected value."""
        return self._numeric_compare(field_value, expected_value, lambda a, b: a < b)

    def _less_than_or_equal(self, field_value: Any, expected_value: Any) -> bool:
        """Check if field value is less than or equal to expected value."""
        return self._numeric_compare(field_value, expected_value, lambda a, b: a <= b)

    def _greater_than(self, field_value: Any, expected_value: Any) -> bool:
        """Check if field value is greater than expected value."""
        return self._numeric_compare(field_value, expected_value, lambda a, b: a > b)

    def _greater_than_or_equal(self, field_value: Any, expected_value: Any) -> bool:
        """Check if field value is greater than or equal to expected value."""
        return self._numeric_compare(field_value, expected_value, lambda a, b: a >= b)

    # Membership operators
    def _in(self, field_value: Any, expected_value: Any) -> bool:
        """Check if field value is in the expected list/collection."""
        if not isinstance(expected_value, (list, tuple, set)):
            raise ValueError("'in' operator requires a list, tuple, or set")
        return field_value in expected_value

    def _not_in(self, field_value: Any, expected_value: Any) -> bool:
        """Check if field value is not in the expected list/collection."""
        if not isinstance(expected_value, (list, tuple, set)):
            raise ValueError("'not_in' operator requires a list, tuple, or set")
        return field_value not in expected_value

    # String operators
    def _contains(self, field_value: Any, expected_value: Any) -> bool:
        """Check if field value contains the expected substring."""
        field_str = str(field_value)
        expected_str = str(expected_value)
        return expected_str in field_str

    def _starts_with(self, field_value: Any, expected_value: Any) -> bool:
        """Check if field value starts with the expected prefix."""
        field_str = str(field_value)
        expected_str = str(expected_value)
        return field_str.startswith(expected_str)

    def _ends_with(self, field_value: Any, expected_value: Any) -> bool:
        """Check if field value ends with the expected suffix."""
        field_str = str(field_value)
        expected_str = str(expected_value)
        return field_str.endswith(expected_str)

    # Boolean operators
    def _is_true(self, field_value: Any, expected_value: Any) -> bool:
        """Check if field value is truthy."""
        return bool(field_value)

    def _is_false(self, field_value: Any, expected_value: Any) -> bool:
        """Check if field value is falsy."""
        return not bool(field_value)

    # Helper methods
    def _numeric_compare(self, field_value: Any, expected_value: Any, compare_func) -> bool:
        """
        Helper for numeric comparisons with type coercion.

        Args:
            field_value: Value from metadata
            expected_value: Value to compare against
            compare_func: Comparison function to apply

        Returns:
            Result of comparison
        """
        try:
            # Try to convert both values to numbers for comparison
            if isinstance(field_value, str) and isinstance(expected_value, str):
                # String comparison if both are strings
                return compare_func(field_value, expected_value)

            # Numeric comparison - attempt to convert to float
            field_num = (
                float(field_value) if not isinstance(field_value, (int, float)) else field_value
            )
            expected_num = (
                float(expected_value)
                if not isinstance(expected_value, (int, float))
                else expected_value
            )

            return compare_func(field_num, expected_num)
        except (ValueError, TypeError):
            # Fall back to string comparison if numeric conversion fails
            return compare_func(str(field_value), str(expected_value))
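For reference, a minimal usage sketch of the engine removed above; the data and filter values are illustrative, not taken from the repository:

```python
# Filters follow {"field_name": {"operator": value}} and are combined with AND logic.
engine = MetadataFilterEngine()
results = [
    {"text": "intro to faiss", "metadata": {"lang": "en", "stars": 120}},
    {"text": "hnsw tuning notes", "metadata": {"lang": "de", "stars": 15}},
]
filtered = engine.apply_filters(results, {"lang": {"==": "en"}, "stars": {">=": 100}})
print([r["text"] for r in filtered])  # -> ['intro to faiss']
```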
@@ -2,17 +2,11 @@
 import importlib
 import importlib.metadata
-import json
-import logging
-from pathlib import Path
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING

 if TYPE_CHECKING:
     from leann.interface import LeannBackendFactoryInterface

-# Set up logger for this module
-logger = logging.getLogger(__name__)
-
 BACKEND_REGISTRY: dict[str, "LeannBackendFactoryInterface"] = {}

@@ -20,7 +14,7 @@ def register_backend(name: str):
     """A decorator to register a new backend class."""

     def decorator(cls):
-        logger.debug(f"Registering backend '{name}'")
+        print(f"INFO: Registering backend '{name}'")
         BACKEND_REGISTRY[name] = cls
         return cls

@@ -45,54 +39,3 @@ def autodiscover_backends():
             # print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
             pass
     # print("INFO: Backend auto-discovery finished.")
-
-
-def register_project_directory(project_dir: Optional[Union[str, Path]] = None):
-    """
-    Register a project directory in the global LEANN registry.
-
-    This allows `leann list` to discover indexes created by apps or other tools.
-
-    Args:
-        project_dir: Directory to register. If None, uses current working directory.
-    """
-    if project_dir is None:
-        project_dir = Path.cwd()
-    else:
-        project_dir = Path(project_dir)
-
-    # Only register directories that have some kind of LEANN content
-    # Either .leann/indexes/ (CLI format) or *.leann.meta.json files (apps format)
-    has_cli_indexes = (project_dir / ".leann" / "indexes").exists()
-    has_app_indexes = any(project_dir.rglob("*.leann.meta.json"))
-
-    if not (has_cli_indexes or has_app_indexes):
-        # Don't register if there are no LEANN indexes
-        return
-
-    global_registry = Path.home() / ".leann" / "projects.json"
-    global_registry.parent.mkdir(exist_ok=True)
-
-    project_str = str(project_dir.resolve())
-
-    # Load existing registry
-    projects = []
-    if global_registry.exists():
-        try:
-            with open(global_registry) as f:
-                projects = json.load(f)
-        except Exception:
-            logger.debug("Could not load existing project registry")
-            projects = []
-
-    # Add project if not already present
-    if project_str not in projects:
-        projects.append(project_str)
-
-        # Save updated registry
-        try:
-            with open(global_registry, "w") as f:
-                json.dump(projects, f, indent=2)
-            logger.debug(f"Registered project directory: {project_str}")
-        except Exception as e:
-            logger.warning(f"Could not save project registry: {e}")
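Likewise, a minimal sketch of how the removed `register_project_directory` helper would be called; the `leann.registry` import path is an assumption based on the surrounding code:

```python
from pathlib import Path

from leann.registry import register_project_directory  # assumed module path

# No-op unless the directory contains .leann/indexes/ or *.leann.meta.json;
# otherwise the resolved path is appended to ~/.leann/projects.json.
register_project_directory(Path.cwd())
```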
@@ -1,7 +1,7 @@
 import json
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any, Literal, Optional
+from typing import Any, Literal

 import numpy as np

@@ -169,7 +169,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
         prune_ratio: float = 0.0,
         recompute_embeddings: bool = False,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        zmq_port: Optional[int] = None,
+        zmq_port: int | None = None,
         **kwargs,
     ) -> dict[str, Any]:
         """
@@ -1,147 +0,0 @@
# 🔥 LEANN Claude Code Integration

Transform your development workflow with intelligent code assistance using LEANN's semantic search directly in Claude Code.

## Prerequisites

Install LEANN globally for MCP integration (with the default backend):

```bash
uv tool install leann-core --with leann
```

This installs the `leann` CLI into an isolated tool environment and includes both backends, so `leann build` works out of the box.

## 🚀 Quick Setup

Add the LEANN MCP server to Claude Code. Choose the scope based on how widely you want it available. Below is the command to install it globally; if you prefer a local install, skip this step:

```bash
# Global (recommended): available in all projects for your user
claude mcp add --scope user leann-server -- leann_mcp
```

- `leann-server`: the display name of the MCP server in Claude Code (you can change it).
- `leann_mcp`: the Python entry point installed with LEANN that starts the MCP server.

Verify it is registered globally:

```bash
claude mcp list | cat
```

## 🛠️ Available Tools

Once connected, you'll have access to these semantic search tools in Claude Code:

- **`leann_list`** - List all available indexes across your projects
- **`leann_search`** - Perform semantic searches across code and documents

## 🎯 Quick Start Example

```bash
# Add locally if you did not add it globally (current folder only; the default if --scope is omitted)
claude mcp add leann-server -- leann_mcp

# Build an index for your project (change to your actual path)
# See the advanced examples below for more ways to configure indexing
# Set the index name (replace 'my-project' with your own)
leann build my-project --docs $(git ls-files)

# Start Claude Code
claude
```

## 🚀 Advanced Usage Examples for Building the Index

### Index an Entire Git Repository
```bash
# Index all tracked files in your Git repository.
# Note: submodules are currently skipped; we can add them back if needed.
leann build my-repo --docs $(git ls-files) --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw

# Index only tracked Python files from Git.
leann build my-python-code --docs $(git ls-files "*.py") --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw

# If you encounter empty requests caused by empty files (e.g., __init__.py), exclude zero-byte files. Thanks @ww2283 for pointing [that](https://github.com/yichuan-w/LEANN/issues/48) out.
leann build leann-prospec-lig --docs $(find ./src -name "*.py" -not -empty) --embedding-mode openai --embedding-model text-embedding-3-small
```

### Multiple Directories and Files
```bash
# Index multiple directories
leann build my-codebase --docs ./src ./tests ./docs ./config --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw

# Mix files and directories
leann build my-project --docs ./README.md ./src/ ./package.json ./docs/ --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw

# Specific files only
leann build my-configs --docs ./tsconfig.json ./package.json ./webpack.config.js --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
```

### Advanced Git Integration
```bash
# Index recently modified files
leann build recent-changes --docs $(git diff --name-only HEAD~10..HEAD) --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw

# Index files matching a pattern
leann build frontend --docs $(git ls-files "*.tsx" "*.ts" "*.jsx" "*.js") --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw

# Index documentation and config files
leann build docs-and-configs --docs $(git ls-files "*.md" "*.yml" "*.yaml" "*.json" "*.toml") --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
```

## **Try this in Claude Code:**
```
Help me understand this codebase. List available indexes and search for authentication patterns.
```

<p align="center">
  <img src="../../assets/claude_code_leann.png" alt="LEANN in Claude Code" width="80%">
</p>

If you see a prompt asking whether to proceed with LEANN, you can now use it in your chat!

## 🧠 How It Works

The integration consists of three key components working together:

- **`leann`** - Core CLI tool for indexing and searching (installed globally via `uv tool install`)
- **`leann_mcp`** - MCP server that wraps `leann` commands for Claude Code integration
- **Claude Code** - Calls `leann_mcp`, which executes `leann` commands and returns intelligent results (a rough smoke test of this loop is sketched below)
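For a quick end-to-end check outside Claude Code, you can pipe a single request into `leann_mcp` yourself. This is an illustrative sketch only: the `tools/call` envelope below follows the generic MCP convention and is not documented here, so adjust it to whatever the server actually expects.

```python
import json
import subprocess

# Hypothetical smoke test: send one JSON-RPC request to the leann_mcp stdin/stdout loop.
request = {
    "jsonrpc": "2.0",
    "id": 1,
    "method": "tools/call",  # assumed MCP envelope, not taken from this README
    "params": {"name": "leann_list", "arguments": {}},
}
proc = subprocess.run(
    ["leann_mcp"], input=json.dumps(request) + "\n", capture_output=True, text=True
)
print(proc.stdout)  # one JSON-RPC response per line
```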
## 📁 File Support

LEANN understands **30+ file types**, including:
- **Programming**: Python, JavaScript, TypeScript, Java, Go, Rust, C++, C#
- **Data**: SQL, YAML, JSON, CSV, XML
- **Documentation**: Markdown, TXT, PDF
- **And many more!**

## 💾 Storage & Organization

- **Project indexes**: Stored in the `.leann/` directory (just like `.git`)
- **Global registry**: Project tracking at `~/.leann/projects.json`
- **Multi-project support**: Switch between different codebases seamlessly
- **Portable**: Transfer indexes between machines with minimal overhead

## 🗑️ Uninstalling

To remove the LEANN MCP server from Claude Code:

```bash
claude mcp remove leann-server
```

To remove LEANN:

```
uv pip uninstall leann leann-backend-hnsw leann-core
```

To globally remove LEANN (for a version update):

```
uv tool list | cat
uv tool uninstall leann-core
command -v leann || echo "leann gone"
command -v leann_mcp || echo "leann_mcp gone"
```
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "leann"
-version = "0.3.2"
+version = "0.1.16"
 description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
 readme = "README.md"
 requires-python = ">=3.9"

@@ -1 +0,0 @@
-__all__ = []

@@ -136,9 +136,5 @@ def export_sqlite(
     connection.commit()


-def main():
-    app()
-
-
 if __name__ == "__main__":
-    main()
+    app()

@@ -10,10 +10,11 @@ requires-python = ">=3.9"
 dependencies = [
     "leann-core",
     "leann-backend-hnsw",
-    "typer>=0.12.3",
     "numpy>=1.26.0",
     "torch",
     "tqdm",
+    "flask",
+    "flask_compress",
     "datasets>=2.15.0",
     "evaluate",
     "colorama",

@@ -31,7 +32,7 @@ dependencies = [
     "pypdfium2>=4.30.0",
     # LlamaIndex core and readers - updated versions
     "llama-index>=0.12.44",
     "llama-index-readers-file>=0.4.0",  # Essential for PDF parsing
     # "llama-index-readers-docling",  # Requires Python >= 3.10
     # "llama-index-node-parser-docling",  # Requires Python >= 3.10
     "llama-index-vector-stores-faiss>=0.4.0",

@@ -39,20 +40,9 @@ dependencies = [
     # Other dependencies
     "ipykernel==6.29.5",
     "msgpack>=1.1.1",
-    "mlx>=0.26.3; sys_platform == 'darwin' and platform_machine == 'arm64'",
-    "mlx-lm>=0.26.0; sys_platform == 'darwin' and platform_machine == 'arm64'",
+    "mlx>=0.26.3; sys_platform == 'darwin'",
+    "mlx-lm>=0.26.0; sys_platform == 'darwin'",
     "psutil>=5.8.0",
-    "pybind11>=3.0.0",
-    "pathspec>=0.12.1",
-    "nbconvert>=7.16.6",
-    "gitignore-parser>=0.1.12",
-    # AST-aware code chunking dependencies
-    "astchunk>=0.1.0",
-    "tree-sitter>=0.20.0",
-    "tree-sitter-python>=0.20.0",
-    "tree-sitter-java>=0.20.0",
-    "tree-sitter-c-sharp>=0.20.0",
-    "tree-sitter-typescript>=0.20.0",
 ]

 [project.optional-dependencies]

@@ -61,7 +51,7 @@ dev = [
     "pytest-cov>=4.0",
     "pytest-xdist>=3.0",  # For parallel test execution
     "black>=23.0",
-    "ruff==0.12.7",  # Fixed version to ensure consistent formatting across all environments
+    "ruff>=0.1.0",
     "matplotlib",
     "huggingface-hub>=0.20.0",
     "pre-commit>=3.5.0",

@@ -71,7 +61,9 @@ test = [
     "pytest>=7.0",
     "pytest-timeout>=2.0",
     "llama-index-core>=0.12.0",
+    "llama-index-readers-file>=0.4.0",
     "python-dotenv>=1.0.0",
+    "sentence-transformers>=2.2.0",
 ]

 diskann = [

@@ -88,11 +80,6 @@ documents = [

 [tool.setuptools]
 py-modules = []
-packages = ["wechat_exporter"]
-package-dir = { "wechat_exporter" = "packages/wechat-exporter" }
-
-[project.scripts]
-wechat-exporter = "wechat_exporter.main:main"

 [tool.uv.sources]

@@ -101,10 +88,15 @@ leann-backend-diskann = { path = "packages/leann-backend-diskann", editable = tr
 leann-backend-hnsw = { path = "packages/leann-backend-hnsw", editable = true }

 [tool.ruff]
-target-version = "py39"
+target-version = "py310"
 line-length = 100
-extend-exclude = ["third_party"]
+extend-exclude = [
+    "third_party",
+    "*.egg-info",
+    "__pycache__",
+    ".git",
+    ".venv",
+]

 [tool.ruff.lint]
 select = [

@@ -127,19 +119,20 @@ ignore = [
     "RUF012", # mutable class attributes should be annotated with typing.ClassVar
 ]

+[tool.ruff.lint.per-file-ignores]
+"test/**/*.py" = ["E402"]  # module level import not at top of file (common in tests)
+"examples/**/*.py" = ["E402"]  # module level import not at top of file (common in examples)
+
 [tool.ruff.format]
 quote-style = "double"
 indent-style = "space"
 skip-magic-trailing-comma = false
 line-ending = "auto"

-[tool.lychee]
-accept = ["200", "403", "429", "503"]
-timeout = 20
-max_retries = 2
-exclude = ["localhost", "127.0.0.1", "example.com"]
-exclude_path = [".git/", ".venv/", "__pycache__/", "third_party/"]
-scheme = ["https", "http"]
+[dependency-groups]
+dev = [
+    "ruff>=0.12.4",
+]

 [tool.pytest.ini_options]
 testpaths = ["tests"]

@@ -150,7 +143,7 @@ markers = [
     "slow: marks tests as slow (deselect with '-m \"not slow\"')",
     "openai: marks tests that require OpenAI API key",
 ]
-timeout = 300  # Reduced from 600s (10min) to 300s (5min) for CI safety
+timeout = 600
 addopts = [
     "-v",
     "--tb=short",
@@ -1,76 +0,0 @@
name: leann-build

resources:
  # Choose a GPU for fast embeddings (examples: L4, A10G, A100). CPU also works but is slower.
  accelerators: L4:1
  # Optionally pin a cloud, otherwise SkyPilot will auto-select
  # cloud: aws
  disk_size: 100

envs:
  # Build parameters (override with: sky launch -c leann-gpu sky/leann-build.yaml -e key=value)
  index_name: my-index
  docs: ./data
  backend: hnsw  # hnsw | diskann
  complexity: 64
  graph_degree: 32
  num_threads: 8
  # Embedding selection
  embedding_mode: sentence-transformers  # sentence-transformers | openai | mlx | ollama
  embedding_model: facebook/contriever
  # Storage/latency knobs
  recompute: true  # true => selective recomputation (recommended)
  compact: true  # for HNSW only
  # Optional pass-through
  extra_args: ""
  # Rebuild control
  force: true

# Sync local paths to the remote VM. Adjust as needed.
file_mounts:
  # Example: mount your local data directory used for building
  ~/leann-data: ${docs}

setup: |
  set -e
  # Install uv (package manager)
  curl -LsSf https://astral.sh/uv/install.sh | sh
  export PATH="$HOME/.local/bin:$PATH"

  # Ensure modern libstdc++ for FAISS (GLIBCXX >= 3.4.30)
  sudo apt-get update -y
  sudo apt-get install -y libstdc++6 libgomp1
  # Also upgrade conda's libstdc++ in the base env (SkyPilot images include conda)
  if command -v conda >/dev/null 2>&1; then
    conda install -y -n base -c conda-forge libstdcxx-ng
  fi

  # Install LEANN CLI and backends into the user environment
  uv pip install --upgrade pip
  uv pip install leann-core leann-backend-hnsw leann-backend-diskann

run: |
  export PATH="$HOME/.local/bin:$PATH"
  # Derive flags from env
  recompute_flag=""
  if [ "${recompute}" = "false" ] || [ "${recompute}" = "0" ]; then
    recompute_flag="--no-recompute"
  fi
  force_flag=""
  if [ "${force}" = "true" ] || [ "${force}" = "1" ]; then
    force_flag="--force"
  fi

  # Build command
  python -m leann.cli build ${index_name} \
    --docs ~/leann-data \
    --backend ${backend} \
    --complexity ${complexity} \
    --graph-degree ${graph_degree} \
    --num-threads ${num_threads} \
    --embedding-mode ${embedding_mode} \
    --embedding-model ${embedding_model} \
    ${recompute_flag} ${force_flag} ${extra_args}

  # Print where the index is stored for downstream rsync
  echo "INDEX_OUT_DIR=~/.leann/indexes/${index_name}"
161  test/mail_reader_llamaindex.py  Normal file
@@ -0,0 +1,161 @@
import email
import os
from typing import Any

from llama_index.core import Document, VectorStoreIndex
from llama_index.core.readers.base import BaseReader


class EmlxReader(BaseReader):
    """
    Apple Mail .emlx file reader.

    Reads individual .emlx files from Apple Mail's storage format.
    """

    def __init__(self) -> None:
        """Initialize."""
        pass

    def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
        """
        Load data from the input directory containing .emlx files.

        Args:
            input_dir: Directory containing .emlx files
            **load_kwargs:
                max_count (int): Maximum amount of messages to read.
        """
        docs: list[Document] = []
        max_count = load_kwargs.get("max_count", 1000)
        count = 0

        # Walk through the directory recursively
        for dirpath, dirnames, filenames in os.walk(input_dir):
            # Skip hidden directories
            dirnames[:] = [d for d in dirnames if not d.startswith(".")]

            for filename in filenames:
                if count >= max_count:
                    break

                if filename.endswith(".emlx"):
                    filepath = os.path.join(dirpath, filename)
                    try:
                        # Read the .emlx file
                        with open(filepath, encoding="utf-8", errors="ignore") as f:
                            content = f.read()

                        # .emlx files have a length prefix followed by the email content
                        # The first line contains the length, followed by the email
                        lines = content.split("\n", 1)
                        if len(lines) >= 2:
                            email_content = lines[1]

                            # Parse the email using Python's email module
                            try:
                                msg = email.message_from_string(email_content)

                                # Extract email metadata
                                subject = msg.get("Subject", "No Subject")
                                from_addr = msg.get("From", "Unknown")
                                to_addr = msg.get("To", "Unknown")
                                date = msg.get("Date", "Unknown")

                                # Extract email body
                                body = ""
                                if msg.is_multipart():
                                    for part in msg.walk():
                                        if (
                                            part.get_content_type() == "text/plain"
                                            or part.get_content_type() == "text/html"
                                        ):
                                            body += part.get_payload(decode=True).decode(
                                                "utf-8", errors="ignore"
                                            )
                                        # break
                                else:
                                    body = msg.get_payload(decode=True).decode(
                                        "utf-8", errors="ignore"
                                    )

                                # Create document content
                                doc_content = f"""
From: {from_addr}
To: {to_addr}
Subject: {subject}
Date: {date}

{body}
"""

                                # Create metadata
                                metadata = {
                                    "file_path": filepath,
                                    "subject": subject,
                                    "from": from_addr,
                                    "to": to_addr,
                                    "date": date,
                                    "filename": filename,
                                }
                                if count == 0:
                                    print("--------------------------------")
                                    print("dir path", dirpath)
                                    print(metadata)
                                    print(doc_content)
                                    print("--------------------------------")
                                    body = []
                                    if msg.is_multipart():
                                        for part in msg.walk():
                                            print(
                                                "-------------------------------- get content type -------------------------------"
                                            )
                                            print(part.get_content_type())
                                            print(part)
                                            # body.append(part.get_payload(decode=True).decode('utf-8', errors='ignore'))
                                            print(
                                                "-------------------------------- get content type -------------------------------"
                                            )
                                    else:
                                        body = msg.get_payload(decode=True).decode(
                                            "utf-8", errors="ignore"
                                        )
                                        print(body)

                                    print(body)
                                    print("--------------------------------")
                                doc = Document(text=doc_content, metadata=metadata)
                                docs.append(doc)
                                count += 1

                            except Exception as e:
                                print(f"!!!!!!! Error parsing email from {filepath}: {e} !!!!!!!!")
                                continue

                    except Exception as e:
                        print(f"!!!!!!! Error reading file !!!!!!!! {filepath}: {e}")
                        continue

        print(f"Loaded {len(docs)} email documents")
        return docs


# Use the custom EmlxReader instead of MboxReader
documents = EmlxReader().load_data(
    "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages",
    max_count=1000,
)  # Returns list of documents

# Configure the index with larger chunk size to handle long metadata
from llama_index.core.node_parser import SentenceSplitter

# Create a custom text splitter with larger chunk size
text_splitter = SentenceSplitter(chunk_size=2048, chunk_overlap=200)

index = VectorStoreIndex.from_documents(
    documents, transformations=[text_splitter]
)  # Initialize index with documents

query_engine = index.as_query_engine()
res = query_engine.query("Hows Berkeley Graduate Student Instructor")
print(res)
219  test/mail_reader_save_load.py  Normal file
@@ -0,0 +1,219 @@
|
|||||||
|
import email
|
||||||
|
import os
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from llama_index.core import Document, StorageContext, VectorStoreIndex
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
from llama_index.core.readers.base import BaseReader
|
||||||
|
|
||||||
|
|
||||||
|
class EmlxReader(BaseReader):
|
||||||
|
"""
|
||||||
|
Apple Mail .emlx file reader.
|
||||||
|
|
||||||
|
Reads individual .emlx files from Apple Mail's storage format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Initialize."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
|
||||||
|
"""
|
||||||
|
Load data from the input directory containing .emlx files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_dir: Directory containing .emlx files
|
||||||
|
**load_kwargs:
|
||||||
|
max_count (int): Maximum amount of messages to read.
|
||||||
|
"""
|
||||||
|
docs: list[Document] = []
|
||||||
|
max_count = load_kwargs.get("max_count", 1000)
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
# Walk through the directory recursively
|
||||||
|
for dirpath, dirnames, filenames in os.walk(input_dir):
|
||||||
|
# Skip hidden directories
|
||||||
|
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
||||||
|
|
||||||
|
for filename in filenames:
|
||||||
|
if count >= max_count:
|
||||||
|
break
|
||||||
|
|
||||||
|
if filename.endswith(".emlx"):
|
||||||
|
filepath = os.path.join(dirpath, filename)
|
||||||
|
try:
|
||||||
|
# Read the .emlx file
|
||||||
|
with open(filepath, encoding="utf-8", errors="ignore") as f:
|
||||||
|
content = f.read()
|
||||||
|
|
||||||
|
# .emlx files have a length prefix followed by the email content
|
||||||
|
# The first line contains the length, followed by the email
|
||||||
|
lines = content.split("\n", 1)
|
||||||
|
if len(lines) >= 2:
|
||||||
|
email_content = lines[1]
|
||||||
|
|
||||||
|
# Parse the email using Python's email module
|
||||||
|
try:
|
||||||
|
msg = email.message_from_string(email_content)
|
||||||
|
|
||||||
|
# Extract email metadata
|
||||||
|
subject = msg.get("Subject", "No Subject")
|
||||||
|
from_addr = msg.get("From", "Unknown")
|
||||||
|
to_addr = msg.get("To", "Unknown")
|
||||||
|
date = msg.get("Date", "Unknown")
|
||||||
|
|
||||||
|
# Extract email body
|
||||||
|
body = ""
|
||||||
|
if msg.is_multipart():
|
||||||
|
for part in msg.walk():
|
||||||
|
if part.get_content_type() == "text/plain":
|
||||||
|
body = part.get_payload(decode=True).decode(
|
||||||
|
"utf-8", errors="ignore"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
body = msg.get_payload(decode=True).decode(
|
||||||
|
"utf-8", errors="ignore"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create document content
|
||||||
|
doc_content = f"""
|
||||||
|
From: {from_addr}
|
||||||
|
To: {to_addr}
|
||||||
|
Subject: {subject}
|
||||||
|
Date: {date}
|
||||||
|
|
||||||
|
{body}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Create metadata
|
||||||
|
metadata = {
|
||||||
|
"file_path": filepath,
|
||||||
|
"subject": subject,
|
||||||
|
"from": from_addr,
|
||||||
|
"to": to_addr,
|
||||||
|
"date": date,
|
||||||
|
"filename": filename,
|
||||||
|
}
|
||||||
|
|
||||||
|
doc = Document(text=doc_content, metadata=metadata)
|
||||||
|
docs.append(doc)
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error parsing email from {filepath}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error reading file {filepath}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"Loaded {len(docs)} email documents")
|
||||||
|
return docs
|
||||||
|
|
||||||
|
|
||||||
|
def create_and_save_index(mail_path: str, save_dir: str = "mail_index", max_count: int = 1000):
|
||||||
|
"""
|
||||||
|
Create the index from mail data and save it to disk.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mail_path: Path to the mail directory
|
||||||
|
save_dir: Directory to save the index
|
||||||
|
max_count: Maximum number of emails to process
|
||||||
|
"""
|
||||||
|
print("Creating index from mail data...")
|
||||||
|
|
||||||
|
# Load documents
|
||||||
|
documents = EmlxReader().load_data(mail_path, max_count=max_count)
|
||||||
|
|
||||||
|
if not documents:
|
||||||
|
print("No documents loaded. Exiting.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Create text splitter
|
||||||
|
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=0)
|
||||||
|
|
||||||
|
# Create index
|
||||||
|
index = VectorStoreIndex.from_documents(documents, transformations=[text_splitter])
|
||||||
|
|
||||||
|
# Save the index
|
||||||
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
|
index.storage_context.persist(persist_dir=save_dir)
|
||||||
|
print(f"Index saved to {save_dir}")
|
||||||
|
|
||||||
|
return index
|
||||||
|
|
||||||
|
|
||||||
|
def load_index(save_dir: str = "mail_index"):
|
||||||
|
"""
|
||||||
|
Load the saved index from disk.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
save_dir: Directory where the index is saved
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Loaded index or None if loading fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Load storage context
|
||||||
|
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
|
||||||
|
|
||||||
|
# Load index
|
||||||
|
index = VectorStoreIndex.from_vector_store(
|
||||||
|
storage_context.vector_store, storage_context=storage_context
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Index loaded from {save_dir}")
|
||||||
|
return index
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error loading index: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def query_index(index, query: str):
|
||||||
|
"""
|
||||||
|
Query the loaded index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index: The loaded index
|
||||||
|
query: The query string
|
||||||
|
"""
|
||||||
|
if index is None:
|
||||||
|
print("No index available for querying.")
|
||||||
|
return
|
||||||
|
|
||||||
|
query_engine = index.as_query_engine()
|
||||||
|
response = query_engine.query(query)
|
||||||
|
print(f"Query: {query}")
|
||||||
|
print(f"Response: {response}")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
mail_path = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages"
|
||||||
|
save_dir = "mail_index"
|
||||||
|
|
||||||
|
# Check if index already exists
|
||||||
|
if os.path.exists(save_dir) and os.path.exists(os.path.join(save_dir, "vector_store.json")):
|
||||||
|
print("Loading existing index...")
|
||||||
|
index = load_index(save_dir)
|
||||||
|
else:
|
||||||
|
print("Creating new index...")
|
||||||
|
index = create_and_save_index(mail_path, save_dir, max_count=1000)
|
||||||
|
|
||||||
|
if index:
|
||||||
|
# Example queries
|
||||||
|
queries = [
|
||||||
|
"Hows Berkeley Graduate Student Instructor",
|
||||||
|
"What emails mention GSR appointments?",
|
||||||
|
"Find emails about deadlines",
|
||||||
|
]
|
||||||
|
|
||||||
|
for query in queries:
|
||||||
|
print("\n" + "=" * 50)
|
||||||
|
query_index(index, query)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
219  test/mail_reader_small_chunks.py  Normal file
@@ -0,0 +1,219 @@
|
|||||||
|
import email
|
||||||
|
import os
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from llama_index.core import Document, StorageContext, VectorStoreIndex
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
from llama_index.core.readers.base import BaseReader
|
||||||
|
|
||||||
|
|
||||||
|
class EmlxReader(BaseReader):
|
||||||
|
"""
|
||||||
|
Apple Mail .emlx file reader with reduced metadata.
|
||||||
|
|
||||||
|
Reads individual .emlx files from Apple Mail's storage format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Initialize."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
|
||||||
|
"""
|
||||||
|
Load data from the input directory containing .emlx files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_dir: Directory containing .emlx files
|
||||||
|
**load_kwargs:
|
||||||
|
max_count (int): Maximum amount of messages to read.
|
||||||
|
"""
|
||||||
|
docs: list[Document] = []
|
||||||
|
max_count = load_kwargs.get("max_count", 1000)
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
# Walk through the directory recursively
|
||||||
|
for dirpath, dirnames, filenames in os.walk(input_dir):
|
||||||
|
# Skip hidden directories
|
||||||
|
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
||||||
|
|
||||||
|
for filename in filenames:
|
||||||
|
if count >= max_count:
|
||||||
|
break
|
||||||
|
|
||||||
|
if filename.endswith(".emlx"):
|
||||||
|
filepath = os.path.join(dirpath, filename)
|
||||||
|
try:
|
||||||
|
# Read the .emlx file
|
||||||
|
with open(filepath, encoding="utf-8", errors="ignore") as f:
|
||||||
|
content = f.read()
|
||||||
|
|
||||||
|
# .emlx files have a length prefix followed by the email content
|
||||||
|
# The first line contains the length, followed by the email
|
||||||
|
lines = content.split("\n", 1)
|
||||||
|
if len(lines) >= 2:
|
||||||
|
email_content = lines[1]
|
||||||
|
|
||||||
|
# Parse the email using Python's email module
|
||||||
|
try:
|
||||||
|
msg = email.message_from_string(email_content)
|
||||||
|
|
||||||
|
# Extract email metadata
|
||||||
|
subject = msg.get("Subject", "No Subject")
|
||||||
|
from_addr = msg.get("From", "Unknown")
|
||||||
|
to_addr = msg.get("To", "Unknown")
|
||||||
|
date = msg.get("Date", "Unknown")
|
||||||
|
|
||||||
|
# Extract email body
|
||||||
|
body = ""
|
||||||
|
if msg.is_multipart():
|
||||||
|
for part in msg.walk():
|
||||||
|
if part.get_content_type() == "text/plain":
|
||||||
|
body = part.get_payload(decode=True).decode(
|
||||||
|
"utf-8", errors="ignore"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
body = msg.get_payload(decode=True).decode(
|
||||||
|
"utf-8", errors="ignore"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create document content with metadata embedded in text
|
||||||
|
doc_content = f"""
|
||||||
|
From: {from_addr}
|
||||||
|
To: {to_addr}
|
||||||
|
Subject: {subject}
|
||||||
|
Date: {date}
|
||||||
|
|
||||||
|
{body}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Create minimal metadata (only essential info)
|
||||||
|
metadata = {
|
||||||
|
"subject": subject[:50], # Truncate subject
|
||||||
|
"from": from_addr[:30], # Truncate from
|
||||||
|
"date": date[:20], # Truncate date
|
||||||
|
"filename": filename, # Keep filename
|
||||||
|
}
|
||||||
|
|
||||||
|
doc = Document(text=doc_content, metadata=metadata)
|
||||||
|
docs.append(doc)
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error parsing email from {filepath}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error reading file {filepath}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"Loaded {len(docs)} email documents")
|
||||||
|
return docs
|
||||||
|
|
||||||
|
|
||||||
|
def create_and_save_index(
|
||||||
|
mail_path: str, save_dir: str = "mail_index_small", max_count: int = 1000
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create the index from mail data and save it to disk.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mail_path: Path to the mail directory
|
||||||
|
save_dir: Directory to save the index
|
||||||
|
max_count: Maximum number of emails to process
|
||||||
|
"""
|
||||||
|
print("Creating index from mail data with small chunks...")
|
||||||
|
|
||||||
|
# Load documents
|
||||||
|
documents = EmlxReader().load_data(mail_path, max_count=max_count)
|
||||||
|
|
||||||
|
if not documents:
|
||||||
|
print("No documents loaded. Exiting.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Create text splitter with small chunk size
|
||||||
|
text_splitter = SentenceSplitter(chunk_size=512, chunk_overlap=50)
|
||||||
|
|
||||||
|
# Create index
|
||||||
|
index = VectorStoreIndex.from_documents(documents, transformations=[text_splitter])
|
||||||
|
|
||||||
|
# Save the index
|
||||||
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
|
index.storage_context.persist(persist_dir=save_dir)
|
||||||
|
print(f"Index saved to {save_dir}")
|
||||||
|
|
||||||
|
return index
|
||||||
|
|
||||||
|
|
||||||
|
def load_index(save_dir: str = "mail_index_small"):
|
||||||
|
"""
|
||||||
|
Load the saved index from disk.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
save_dir: Directory where the index is saved
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Loaded index or None if loading fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Load storage context
|
||||||
|
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
|
||||||
|
|
||||||
|
# Load index
|
||||||
|
index = VectorStoreIndex.from_vector_store(
|
||||||
|
storage_context.vector_store, storage_context=storage_context
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Index loaded from {save_dir}")
|
||||||
|
return index
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error loading index: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def query_index(index, query: str):
|
||||||
|
"""
|
||||||
|
Query the loaded index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index: The loaded index
|
||||||
|
query: The query string
|
||||||
|
"""
|
||||||
|
if index is None:
|
||||||
|
print("No index available for querying.")
|
||||||
|
return
|
||||||
|
|
||||||
|
query_engine = index.as_query_engine()
|
||||||
|
response = query_engine.query(query)
|
||||||
|
print(f"Query: {query}")
|
||||||
|
print(f"Response: {response}")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
mail_path = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages"
|
||||||
|
save_dir = "mail_index_small"
|
||||||
|
|
||||||
|
# Check if index already exists
|
||||||
|
if os.path.exists(save_dir) and os.path.exists(os.path.join(save_dir, "vector_store.json")):
|
||||||
|
print("Loading existing index...")
|
||||||
|
index = load_index(save_dir)
|
||||||
|
else:
|
||||||
|
print("Creating new index...")
|
||||||
|
index = create_and_save_index(mail_path, save_dir, max_count=1000)
|
||||||
|
|
||||||
|
if index:
|
||||||
|
# Example queries
|
||||||
|
queries = [
|
||||||
|
"Hows Berkeley Graduate Student Instructor",
|
||||||
|
"What emails mention GSR appointments?",
|
||||||
|
"Find emails about deadlines",
|
||||||
|
]
|
||||||
|
|
||||||
|
for query in queries:
|
||||||
|
print("\n" + "=" * 50)
|
||||||
|
query_index(index, query)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
154  test/mail_reader_test.py  Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
import email
|
||||||
|
import os
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from llama_index.core import Document, VectorStoreIndex
|
||||||
|
from llama_index.core.readers.base import BaseReader
|
||||||
|
|
||||||
|
|
||||||
|
class EmlxReader(BaseReader):
|
||||||
|
"""
|
||||||
|
Apple Mail .emlx file reader.
|
||||||
|
|
||||||
|
Reads individual .emlx files from Apple Mail's storage format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Initialize."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
|
||||||
|
"""
|
||||||
|
Load data from the input directory containing .emlx files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_dir: Directory containing .emlx files
|
||||||
|
**load_kwargs:
|
||||||
|
max_count (int): Maximum amount of messages to read.
|
||||||
|
"""
|
||||||
|
docs: list[Document] = []
|
||||||
|
max_count = load_kwargs.get("max_count", 1000)
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
# Check if directory exists and is accessible
|
||||||
|
if not os.path.exists(input_dir):
|
||||||
|
print(f"Error: Directory '{input_dir}' does not exist")
|
||||||
|
return docs
|
||||||
|
|
||||||
|
if not os.access(input_dir, os.R_OK):
|
||||||
|
print(f"Error: Directory '{input_dir}' is not accessible (permission denied)")
|
||||||
|
print("This is likely due to macOS security restrictions on Mail app data")
|
||||||
|
return docs
|
||||||
|
|
||||||
|
print(f"Scanning directory: {input_dir}")
|
||||||
|
|
||||||
|
# Walk through the directory recursively
|
||||||
|
for dirpath, dirnames, filenames in os.walk(input_dir):
|
||||||
|
# Skip hidden directories
|
||||||
|
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
||||||
|
|
||||||
|
for filename in filenames:
|
||||||
|
if count >= max_count:
|
||||||
|
break
|
||||||
|
|
||||||
|
if filename.endswith(".emlx"):
|
||||||
|
filepath = os.path.join(dirpath, filename)
|
||||||
|
print(f"Found .emlx file: {filepath}")
|
||||||
|
try:
|
||||||
|
# Read the .emlx file
|
||||||
|
with open(filepath, encoding="utf-8", errors="ignore") as f:
|
||||||
|
content = f.read()
|
||||||
|
|
||||||
|
# .emlx files have a length prefix followed by the email content
|
||||||
|
# The first line contains the length, followed by the email
|
||||||
|
lines = content.split("\n", 1)
|
||||||
|
if len(lines) >= 2:
|
||||||
|
email_content = lines[1]
|
||||||
|
|
||||||
|
# Parse the email using Python's email module
|
||||||
|
try:
|
||||||
|
msg = email.message_from_string(email_content)
|
||||||
|
|
||||||
|
# Extract email metadata
|
||||||
|
subject = msg.get("Subject", "No Subject")
|
||||||
|
from_addr = msg.get("From", "Unknown")
|
||||||
|
to_addr = msg.get("To", "Unknown")
|
||||||
|
date = msg.get("Date", "Unknown")
|
||||||
|
|
||||||
|
# Extract email body
|
||||||
|
body = ""
|
||||||
|
if msg.is_multipart():
|
||||||
|
for part in msg.walk():
|
||||||
|
if part.get_content_type() == "text/plain":
|
||||||
|
body = part.get_payload(decode=True).decode(
|
||||||
|
"utf-8", errors="ignore"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
body = msg.get_payload(decode=True).decode(
|
||||||
|
"utf-8", errors="ignore"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create document content
|
||||||
|
doc_content = f"""
|
||||||
|
From: {from_addr}
|
||||||
|
To: {to_addr}
|
||||||
|
Subject: {subject}
|
||||||
|
Date: {date}
|
||||||
|
|
||||||
|
{body}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Create metadata
|
||||||
|
metadata = {
|
||||||
|
"file_path": filepath,
|
||||||
|
"subject": subject,
|
||||||
|
"from": from_addr,
|
||||||
|
"to": to_addr,
|
||||||
|
"date": date,
|
||||||
|
"filename": filename,
|
||||||
|
}
|
||||||
|
|
||||||
|
doc = Document(text=doc_content, metadata=metadata)
|
||||||
|
docs.append(doc)
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error parsing email from {filepath}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error reading file {filepath}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"Loaded {len(docs)} email documents")
|
||||||
|
return docs
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Use the current directory where the sample.emlx file is located
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
print("Testing EmlxReader with sample .emlx file...")
|
||||||
|
print(f"Scanning directory: {current_dir}")
|
||||||
|
|
||||||
|
# Use the custom EmlxReader
|
||||||
|
documents = EmlxReader().load_data(current_dir, max_count=1000)
|
||||||
|
|
||||||
|
if not documents:
|
||||||
|
print("No documents loaded. Make sure sample.emlx exists in the examples directory.")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"\nSuccessfully loaded {len(documents)} document(s)")
|
||||||
|
|
||||||
|
# Initialize index with documents
|
||||||
|
index = VectorStoreIndex.from_documents(documents)
|
||||||
|
query_engine = index.as_query_engine()
|
||||||
|
|
||||||
|
print("\nTesting query: 'Hows Berkeley Graduate Student Instructor'")
|
||||||
|
res = query_engine.query("Hows Berkeley Graduate Student Instructor")
|
||||||
|
print(f"Response: {res}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
105
test/query_saved_index.py
Normal file
105
test/query_saved_index.py
Normal file
@@ -0,0 +1,105 @@
import os

from llama_index.core import StorageContext, VectorStoreIndex


def load_index(save_dir: str = "mail_index"):
    """
    Load the saved index from disk.

    Args:
        save_dir: Directory where the index is saved

    Returns:
        Loaded index or None if loading fails
    """
    try:
        # Load storage context
        storage_context = StorageContext.from_defaults(persist_dir=save_dir)

        # Load index
        index = VectorStoreIndex.from_vector_store(
            storage_context.vector_store, storage_context=storage_context
        )

        print(f"Index loaded from {save_dir}")
        return index

    except Exception as e:
        print(f"Error loading index: {e}")
        return None


def query_index(index, query: str):
    """
    Query the loaded index.

    Args:
        index: The loaded index
        query: The query string
    """
    if index is None:
        print("No index available for querying.")
        return

    query_engine = index.as_query_engine()
    response = query_engine.query(query)
    print(f"\nQuery: {query}")
    print(f"Response: {response}")


def main():
    save_dir = "mail_index"

    # Check if index exists
    if not os.path.exists(save_dir) or not os.path.exists(
        os.path.join(save_dir, "vector_store.json")
    ):
        print(f"Index not found in {save_dir}")
        print("Please run mail_reader_save_load.py first to create the index.")
        return

    # Load the index
    index = load_index(save_dir)

    if not index:
        print("Failed to load index.")
        return

    print("\n" + "=" * 60)
    print("Email Query Interface")
    print("=" * 60)
    print("Type 'quit' to exit")
    print("Type 'help' for example queries")
    print("=" * 60)

    # Interactive query loop
    while True:
        try:
            query = input("\nEnter your query: ").strip()

            if query.lower() == "quit":
                print("Goodbye!")
                break
            elif query.lower() == "help":
                print("\nExample queries:")
                print("- Hows Berkeley Graduate Student Instructor")
                print("- What emails mention GSR appointments?")
                print("- Find emails about deadlines")
                print("- Search for emails from specific sender")
                print("- Find emails about meetings")
                continue
            elif not query:
                continue

            query_index(index, query)

        except KeyboardInterrupt:
            print("\nGoodbye!")
            break
        except Exception as e:
            print(f"Error processing query: {e}")


if __name__ == "__main__":
    main()
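The script above only reads an index back; its save-side counterpart, `mail_reader_save_load.py`, is referenced but not shown in this diff. A minimal sketch of the persist step it would need is given below, assuming the standard LlamaIndex persistence API and documents already loaded (for example with the EmlxReader from the previous file); the function name is illustrative, not part of this change.

```python
from llama_index.core import VectorStoreIndex


def save_index(documents, save_dir: str = "mail_index"):
    """Build an index from already-loaded documents and persist it to disk (sketch)."""
    index = VectorStoreIndex.from_documents(documents)
    # Writes vector_store.json and the other storage files into save_dir,
    # which is exactly what load_index() in query_saved_index.py checks for.
    index.storage_context.persist(persist_dir=save_dir)
    return index
```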
@@ -1,24 +1,9 @@
# 🧪 LEANN Benchmarks & Testing
# 🧪 Leann Sanity Checks

This directory contains performance benchmarks and comprehensive tests for the LEANN system, including backend comparisons and sanity checks across different configurations.
This directory contains comprehensive sanity checks for the Leann system, ensuring all components work correctly across different configurations.

## 📁 Test Files

### `diskann_vs_hnsw_speed_comparison.py`
Performance comparison between DiskANN and HNSW backends:
- ✅ **Search latency** comparison with both backends using recompute
- ✅ **Index size** and **build time** measurements
- ✅ **Score validity** testing (ensures no -inf scores)
- ✅ **Configurable dataset sizes** for different scales

```bash
# Quick comparison with 500 docs, 10 queries
python benchmarks/diskann_vs_hnsw_speed_comparison.py

# Large-scale comparison with 2000 docs, 20 queries
python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20
```

### `test_distance_functions.py`
Tests all supported distance functions across DiskANN backend:
- ✅ **MIPS** (Maximum Inner Product Search)
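For context, the latency measurement described in that README amounts to timing repeated searches against prebuilt indexes. This is a minimal sketch, not the benchmark script itself; the import path and index paths are assumptions, while the `LeannSearcher` calls mirror the ones used in `test_basic.py` later in this diff.

```python
import time

from leann import LeannSearcher  # import path is an assumption


def time_backend(index_path: str, queries: list[str], top_k: int = 10) -> float:
    """Return the average per-query search latency in seconds for one prebuilt index."""
    searcher = LeannSearcher(index_path)
    start = time.time()
    for q in queries:
        searcher.search([q], top_k=top_k)
    elapsed = time.time() - start
    searcher.cleanup()  # avoid leaving background embedding servers running
    return elapsed / len(queries)


# Hypothetical index paths; each index would be built beforehand with the matching backend.
for name, path in [("hnsw", "indexes/hnsw.leann"), ("diskann", "indexes/diskann.leann")]:
    print(name, time_backend(path, ["example query"] * 10))
```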
117 test/sanity_checks/debug_zmq_issue.py Normal file
@@ -0,0 +1,117 @@
#!/usr/bin/env python3
"""
Debug script to test ZMQ communication with the exact same setup as main_cli_example.py
"""

import sys
import time

import zmq

sys.path.append("packages/leann-backend-diskann")
from leann_backend_diskann import embedding_pb2


def test_zmq_with_same_model():
    print("=== Testing ZMQ with same model as main_cli_example.py ===")

    # Test the exact same model that main_cli_example.py uses
    model_name = "sentence-transformers/all-mpnet-base-v2"

    # Start server with the same model
    import subprocess

    server_cmd = [
        sys.executable,
        "-m",
        "packages.leann-backend-diskann.leann_backend_diskann.embedding_server",
        "--zmq-port",
        "5556",  # Use different port to avoid conflicts
        "--model-name",
        model_name,
    ]

    print(f"Starting server with command: {' '.join(server_cmd)}")
    server_process = subprocess.Popen(
        server_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
    )

    # Wait for server to start
    print("Waiting for server to start...")
    time.sleep(10)

    # Check if server is running
    if server_process.poll() is not None:
        stdout, stderr = server_process.communicate()
        print(f"Server failed to start. stdout: {stdout}")
        print(f"Server failed to start. stderr: {stderr}")
        return False

    print(f"Server started with PID: {server_process.pid}")

    try:
        # Test client
        context = zmq.Context()
        socket = context.socket(zmq.REQ)
        socket.connect("tcp://127.0.0.1:5556")
        socket.setsockopt(zmq.RCVTIMEO, 30000)  # 30 second timeout like C++
        socket.setsockopt(zmq.SNDTIMEO, 30000)

        # Create request with same format as C++
        request = embedding_pb2.NodeEmbeddingRequest()
        request.node_ids.extend([0, 1, 2, 3, 4])  # Test with some node IDs

        print(f"Sending request with {len(request.node_ids)} node IDs...")
        start_time = time.time()

        # Send request
        socket.send(request.SerializeToString())

        # Receive response
        response_data = socket.recv()
        end_time = time.time()

        print(f"Received response in {end_time - start_time:.3f} seconds")
        print(f"Response size: {len(response_data)} bytes")

        # Parse response
        response = embedding_pb2.NodeEmbeddingResponse()
        response.ParseFromString(response_data)

        print(f"Response dimensions: {list(response.dimensions)}")
        print(f"Embeddings data size: {len(response.embeddings_data)} bytes")
        print(f"Missing IDs: {list(response.missing_ids)}")

        # Calculate expected size
        if len(response.dimensions) == 2:
            batch_size = response.dimensions[0]
            embedding_dim = response.dimensions[1]
            expected_bytes = batch_size * embedding_dim * 4  # 4 bytes per float
            print(f"Expected bytes: {expected_bytes}, Actual: {len(response.embeddings_data)}")

            if len(response.embeddings_data) == expected_bytes:
                print("✅ Response format is correct!")
                return True
            else:
                print("❌ Response format mismatch!")
                return False
        else:
            print("❌ Invalid response dimensions!")
            return False

    except Exception as e:
        print(f"❌ Error during ZMQ test: {e}")
        return False
    finally:
        # Clean up
        server_process.terminate()
        server_process.wait()
        print("Server terminated")


if __name__ == "__main__":
    success = test_zmq_with_same_model()
    if success:
        print("\n✅ ZMQ communication test passed!")
    else:
        print("\n❌ ZMQ communication test failed!")
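The response in this script carries the embeddings as a flat binary payload whose size it checks as batch_size * embedding_dim * 4 bytes. A small sketch of turning that payload into a matrix follows; it assumes NumPy is available and that the server emits native-endian float32, consistent with the 4-bytes-per-float check above.

```python
import numpy as np


def decode_embeddings(response) -> np.ndarray:
    """Reshape the raw float32 payload into a (batch_size, embedding_dim) matrix (sketch)."""
    batch_size, embedding_dim = response.dimensions  # e.g. [5, 768]
    flat = np.frombuffer(response.embeddings_data, dtype=np.float32)
    return flat.reshape(batch_size, embedding_dim)
```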
@@ -20,7 +20,7 @@ except ImportError:

@dataclass
class BenchmarkConfig:
    model_path: str = "facebook/contriever-msmarco"
    model_path: str = "facebook/contriever"
    batch_sizes: list[int] = None
    seq_length: int = 256
    num_runs: int = 5

@@ -34,7 +34,7 @@ class BenchmarkConfig:

    def __post_init__(self):
        if self.batch_sizes is None:
            self.batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
            self.batch_sizes = [1, 2, 4, 8, 16, 32, 64]


class MLXBenchmark:

@@ -179,16 +179,10 @@ class Benchmark:

    def _run_inference(self, input_ids: torch.Tensor) -> float:
        attention_mask = torch.ones_like(input_ids)
        # print shape of input_ids and attention_mask
        print(f"input_ids shape: {input_ids.shape}")
        print(f"attention_mask shape: {attention_mask.shape}")
        start_time = time.time()
        with torch.no_grad():
            self.model(input_ids=input_ids, attention_mask=attention_mask)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        if torch.backends.mps.is_available():
            torch.mps.synchronize()
        end_time = time.time()

        return end_time - start_time
@@ -6,11 +6,10 @@ This directory contains automated tests for the LEANN project using pytest.

### `test_readme_examples.py`
Tests the examples shown in README.md:
- The basic example code that users see first (parametrized for both HNSW and DiskANN backends)
- The basic example code that users see first
- Import statements work correctly
- Different backend options (HNSW, DiskANN)
- Different LLM configuration options (parametrized for both backends)
- Different LLM configuration options
- **All main README examples are tested with both HNSW and DiskANN backends using pytest parametrization**

### `test_basic.py`
Basic functionality tests that verify:

@@ -19,23 +18,13 @@ Basic functionality tests that verify:
- Basic index building and searching works for both HNSW and DiskANN backends
- Uses parametrized tests to test both backends

### `test_document_rag.py`
### `test_main_cli.py`
Tests the document RAG example functionality:
Tests the main CLI example functionality:
- Tests with facebook/contriever embeddings
- Tests with OpenAI embeddings (if API key is available)
- Tests error handling with invalid parameters
- Verifies that normalized embeddings are detected and cosine distance is used

### `test_diskann_partition.py`
Tests DiskANN graph partitioning functionality:
- Tests DiskANN index building without partitioning (baseline)
- Tests automatic graph partitioning with `is_recompute=True`
- Verifies that partition files are created and large files are cleaned up for storage saving
- Tests search functionality with partitioned indices
- Validates medoid and max_base_norm file generation and usage
- Includes performance comparison between DiskANN (with partition) and HNSW
- **Note**: These tests are skipped in CI due to hardware requirements and computation time

## Running Tests

### Install test dependencies:

@@ -65,23 +54,15 @@ pytest tests/ -m "not openai"

# Skip slow tests
pytest tests/ -m "not slow"

# Run DiskANN partition tests (requires local machine, not CI)
pytest tests/test_diskann_partition.py
```

### Run with specific backend:
```bash
# Test only HNSW backend
pytest tests/test_basic.py::test_backend_basic[hnsw]
pytest tests/test_readme_examples.py::test_readme_basic_example[hnsw]

# Test only DiskANN backend
pytest tests/test_basic.py::test_backend_basic[diskann]
pytest tests/test_readme_examples.py::test_readme_basic_example[diskann]

# All DiskANN tests (parametrized + specialized partition tests)
pytest tests/ -k diskann
```

## CI/CD Integration
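The "parametrized tests" mentioned in that README are ordinary pytest parametrization over the backend name, which is why test IDs like `test_backend_basic[hnsw]` and `test_backend_basic[diskann]` exist. A hedged sketch of the pattern is shown below; `build_test_index` is a hypothetical helper fixture standing in for whatever `tests/test_basic.py` actually uses to build a small index per backend, and the import path is an assumption. The search and cleanup calls mirror the test code in the hunks that follow.

```python
import pytest

from leann import LeannSearcher  # import path is an assumption


@pytest.mark.parametrize("backend_name", ["hnsw", "diskann"])
def test_backend_basic(backend_name, tmp_path, build_test_index):
    # build_test_index is a hypothetical fixture: builds a tiny index for the given backend.
    index_path = build_test_index(backend_name, tmp_path)

    searcher = LeannSearcher(index_path)
    results = searcher.search(["topic 2"], top_k=5)
    assert len(results[0]) > 0

    searcher.cleanup()  # avoid hanging background servers
```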
@@ -1,397 +0,0 @@
"""
Test suite for astchunk integration with LEANN.
Tests AST-aware chunking functionality, language detection, and fallback mechanisms.
"""

import os
import subprocess
import sys
import tempfile
from pathlib import Path
from unittest.mock import patch

import pytest

# Add apps directory to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent / "apps"))

from typing import Optional

from chunking import (
    create_ast_chunks,
    create_text_chunks,
    create_traditional_chunks,
    detect_code_files,
    get_language_from_extension,
)


class MockDocument:
    """Mock LlamaIndex Document for testing."""

    def __init__(self, content: str, file_path: str = "", metadata: Optional[dict] = None):
        self.content = content
        self.metadata = metadata or {}
        if file_path:
            self.metadata["file_path"] = file_path

    def get_content(self) -> str:
        return self.content


class TestCodeFileDetection:
    """Test code file detection and language mapping."""

    def test_detect_code_files_python(self):
        """Test detection of Python files."""
        docs = [
            MockDocument("print('hello')", "/path/to/file.py"),
            MockDocument("This is text", "/path/to/file.txt"),
        ]

        code_docs, text_docs = detect_code_files(docs)

        assert len(code_docs) == 1
        assert len(text_docs) == 1
        assert code_docs[0].metadata["language"] == "python"
        assert code_docs[0].metadata["is_code"] is True
        assert text_docs[0].metadata["is_code"] is False

    def test_detect_code_files_multiple_languages(self):
        """Test detection of multiple programming languages."""
        docs = [
            MockDocument("def func():", "/path/to/script.py"),
            MockDocument("public class Test {}", "/path/to/Test.java"),
            MockDocument("interface ITest {}", "/path/to/test.ts"),
            MockDocument("using System;", "/path/to/Program.cs"),
            MockDocument("Regular text content", "/path/to/document.txt"),
        ]

        code_docs, text_docs = detect_code_files(docs)

        assert len(code_docs) == 4
        assert len(text_docs) == 1

        languages = [doc.metadata["language"] for doc in code_docs]
        assert "python" in languages
        assert "java" in languages
        assert "typescript" in languages
        assert "csharp" in languages

    def test_detect_code_files_no_file_path(self):
        """Test handling of documents without file paths."""
        docs = [
            MockDocument("some content"),
            MockDocument("other content", metadata={"some_key": "value"}),
        ]

        code_docs, text_docs = detect_code_files(docs)

        assert len(code_docs) == 0
        assert len(text_docs) == 2
        for doc in text_docs:
            assert doc.metadata["is_code"] is False

    def test_get_language_from_extension(self):
        """Test language detection from file extensions."""
        assert get_language_from_extension("test.py") == "python"
        assert get_language_from_extension("Test.java") == "java"
        assert get_language_from_extension("component.tsx") == "typescript"
        assert get_language_from_extension("Program.cs") == "csharp"
        assert get_language_from_extension("document.txt") is None
        assert get_language_from_extension("") is None


class TestChunkingFunctions:
    """Test various chunking functionality."""

    def test_create_traditional_chunks(self):
        """Test traditional text chunking."""
        docs = [
            MockDocument(
                "This is a test document. It has multiple sentences. We want to test chunking."
            )
        ]

        chunks = create_traditional_chunks(docs, chunk_size=50, chunk_overlap=10)

        assert len(chunks) > 0
        assert all(isinstance(chunk, str) for chunk in chunks)
        assert all(len(chunk.strip()) > 0 for chunk in chunks)

    def test_create_traditional_chunks_empty_docs(self):
        """Test traditional chunking with empty documents."""
        chunks = create_traditional_chunks([], chunk_size=50, chunk_overlap=10)
        assert chunks == []

    @pytest.mark.skipif(
        os.environ.get("CI") == "true",
        reason="Skip astchunk tests in CI - dependency may not be available",
    )
    def test_create_ast_chunks_with_astchunk_available(self):
        """Test AST chunking when astchunk is available."""
        python_code = '''
def hello_world():
    """Print hello world message."""
    print("Hello, World!")

def add_numbers(a, b):
    """Add two numbers and return the result."""
    return a + b

class Calculator:
    """A simple calculator class."""

    def __init__(self):
        self.history = []

    def add(self, a, b):
        result = a + b
        self.history.append(f"{a} + {b} = {result}")
        return result
'''

        docs = [MockDocument(python_code, "/test/calculator.py", {"language": "python"})]

        try:
            chunks = create_ast_chunks(docs, max_chunk_size=200, chunk_overlap=50)

            # Should have multiple chunks due to different functions/classes
            assert len(chunks) > 0
            assert all(isinstance(chunk, str) for chunk in chunks)
            assert all(len(chunk.strip()) > 0 for chunk in chunks)

            # Check that code structure is somewhat preserved
            combined_content = " ".join(chunks)
            assert "def hello_world" in combined_content
            assert "class Calculator" in combined_content

        except ImportError:
            # astchunk not available, should fall back to traditional chunking
            chunks = create_ast_chunks(docs, max_chunk_size=200, chunk_overlap=50)
            assert len(chunks) > 0  # Should still get chunks from fallback

    def test_create_ast_chunks_fallback_to_traditional(self):
        """Test AST chunking falls back to traditional when astchunk is not available."""
        docs = [MockDocument("def test(): pass", "/test/script.py", {"language": "python"})]

        # Mock astchunk import to fail
        with patch("chunking.create_ast_chunks"):
            # First call (actual test) should import astchunk and potentially fail
            # Let's call the actual function to test the import error handling
            chunks = create_ast_chunks(docs)

            # Should return some chunks (either from astchunk or fallback)
            assert isinstance(chunks, list)

    def test_create_text_chunks_traditional_mode(self):
        """Test text chunking in traditional mode."""
        docs = [
            MockDocument("def test(): pass", "/test/script.py"),
            MockDocument("This is regular text.", "/test/doc.txt"),
        ]

        chunks = create_text_chunks(docs, use_ast_chunking=False, chunk_size=50, chunk_overlap=10)

        assert len(chunks) > 0
        assert all(isinstance(chunk, str) for chunk in chunks)

    def test_create_text_chunks_ast_mode(self):
        """Test text chunking in AST mode."""
        docs = [
            MockDocument("def test(): pass", "/test/script.py"),
            MockDocument("This is regular text.", "/test/doc.txt"),
        ]

        chunks = create_text_chunks(
            docs,
            use_ast_chunking=True,
            ast_chunk_size=100,
            ast_chunk_overlap=20,
            chunk_size=50,
            chunk_overlap=10,
        )

        assert len(chunks) > 0
        assert all(isinstance(chunk, str) for chunk in chunks)

    def test_create_text_chunks_custom_extensions(self):
        """Test text chunking with custom code file extensions."""
        docs = [
            MockDocument("function test() {}", "/test/script.js"),  # Not in default extensions
            MockDocument("Regular text", "/test/doc.txt"),
        ]

        # First without custom extensions - should treat .js as text
        chunks_without = create_text_chunks(docs, use_ast_chunking=True, code_file_extensions=None)

        # Then with custom extensions - should treat .js as code
        chunks_with = create_text_chunks(
            docs, use_ast_chunking=True, code_file_extensions=[".js", ".jsx"]
        )

        # Both should return chunks
        assert len(chunks_without) > 0
        assert len(chunks_with) > 0


class TestIntegrationWithDocumentRAG:
    """Integration tests with the document RAG system."""

    @pytest.fixture
    def temp_code_dir(self):
        """Create a temporary directory with sample code files."""
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = Path(temp_dir)

            # Create sample Python file
            python_file = temp_path / "example.py"
            python_file.write_text('''
def fibonacci(n):
    """Calculate fibonacci number."""
    if n <= 1:
        return n
    return fibonacci(n-1) + fibonacci(n-2)

class MathUtils:
    @staticmethod
    def factorial(n):
        if n <= 1:
            return 1
        return n * MathUtils.factorial(n-1)
''')

            # Create sample text file
            text_file = temp_path / "readme.txt"
            text_file.write_text("This is a sample text file for testing purposes.")

            yield temp_path

    @pytest.mark.skipif(
        os.environ.get("CI") == "true",
        reason="Skip integration tests in CI to avoid dependency issues",
    )
    def test_document_rag_with_ast_chunking(self, temp_code_dir):
        """Test document RAG with AST chunking enabled."""
        with tempfile.TemporaryDirectory() as index_dir:
            cmd = [
                sys.executable,
                "apps/document_rag.py",
                "--llm",
                "simulated",
                "--embedding-model",
                "facebook/contriever",
                "--embedding-mode",
                "sentence-transformers",
                "--index-dir",
                index_dir,
                "--data-dir",
                str(temp_code_dir),
                "--enable-code-chunking",
                "--query",
                "How does the fibonacci function work?",
            ]

            env = os.environ.copy()
            env["HF_HUB_DISABLE_SYMLINKS"] = "1"
            env["TOKENIZERS_PARALLELISM"] = "false"

            try:
                result = subprocess.run(
                    cmd,
                    capture_output=True,
                    text=True,
                    timeout=300,  # 5 minutes
                    env=env,
                )

                # Should succeed even if astchunk is not available (fallback)
                assert result.returncode == 0, f"Command failed: {result.stderr}"

                output = result.stdout + result.stderr
                assert "Index saved to" in output or "Using existing index" in output

            except subprocess.TimeoutExpired:
                pytest.skip("Test timed out - likely due to model download in CI")

    @pytest.mark.skipif(
        os.environ.get("CI") == "true",
        reason="Skip integration tests in CI to avoid dependency issues",
    )
    def test_code_rag_application(self, temp_code_dir):
        """Test the specialized code RAG application."""
        with tempfile.TemporaryDirectory() as index_dir:
            cmd = [
                sys.executable,
                "apps/code_rag.py",
                "--llm",
                "simulated",
                "--embedding-model",
                "facebook/contriever",
                "--index-dir",
                index_dir,
                "--repo-dir",
                str(temp_code_dir),
                "--query",
                "What classes are defined in this code?",
            ]

            env = os.environ.copy()
            env["HF_HUB_DISABLE_SYMLINKS"] = "1"
            env["TOKENIZERS_PARALLELISM"] = "false"

            try:
                result = subprocess.run(cmd, capture_output=True, text=True, timeout=300, env=env)

                # Should succeed
                assert result.returncode == 0, f"Command failed: {result.stderr}"

                output = result.stdout + result.stderr
                assert "Using AST-aware chunking" in output or "traditional chunking" in output

            except subprocess.TimeoutExpired:
                pytest.skip("Test timed out - likely due to model download in CI")


class TestErrorHandling:
    """Test error handling and edge cases."""

    def test_text_chunking_empty_documents(self):
        """Test text chunking with empty document list."""
        chunks = create_text_chunks([])
        assert chunks == []

    def test_text_chunking_invalid_parameters(self):
        """Test text chunking with invalid parameters."""
        docs = [MockDocument("test content")]

        # Should handle negative chunk sizes gracefully
        chunks = create_text_chunks(
            docs, chunk_size=0, chunk_overlap=0, ast_chunk_size=0, ast_chunk_overlap=0
        )

        # Should still return some result
        assert isinstance(chunks, list)

    def test_create_ast_chunks_no_language(self):
        """Test AST chunking with documents missing language metadata."""
        docs = [MockDocument("def test(): pass", "/test/script.py")]  # No language set

        chunks = create_ast_chunks(docs)

        # Should fall back to traditional chunking
        assert isinstance(chunks, list)
        assert len(chunks) >= 0  # May be empty if fallback also fails

    def test_create_ast_chunks_empty_content(self):
        """Test AST chunking with empty content."""
        docs = [MockDocument("", "/test/script.py", {"language": "python"})]

        chunks = create_ast_chunks(docs)

        # Should handle empty content gracefully
        assert isinstance(chunks, list)


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
@@ -64,9 +64,6 @@ def test_backend_basic(backend_name):
    assert isinstance(results[0], SearchResult)
    assert "topic 2" in results[0].text or "document" in results[0].text

    # Ensure cleanup to avoid hanging background servers
    searcher.cleanup()


@pytest.mark.skipif(
    os.environ.get("CI") == "true", reason="Skip model tests in CI to avoid MPS memory issues"

@@ -93,5 +90,3 @@ def test_large_index():
    searcher = LeannSearcher(index_path)
    results = searcher.search(["word10 word20"], top_k=10)
    assert len(results[0]) == 10
    # Cleanup
    searcher.cleanup()
@@ -20,7 +20,7 @@ def test_package_imports():
def test_cli_help():
    """Test that CLI example shows help."""
    result = subprocess.run(
        [sys.executable, "apps/document_rag.py", "--help"], capture_output=True, text=True
        [sys.executable, "examples/main_cli_example.py", "--help"], capture_output=True, text=True
    )

    assert result.returncode == 0
Some files were not shown because too many files have changed in this diff.