Compare commits

..

22 Commits

Author SHA1 Message Date
Andy Lee
971653fa1a upgrade: switch from manylinux2014 to manylinux_2_35
- Use manylinux_2_35 (GLIBC 2.35) instead of manylinux2014 (GLIBC 2.17)
- Still compatible with Google Colab (requires ≤2.35)
- Benefits: newer toolchain, better performance, modern C++ features
- Switch from yum to dnf package manager
- Remove pyzmq version cap as manylinux_2_35 has newer ZeroMQ
- Update documentation to reflect the change
2025-07-25 13:40:07 -07:00
Andy Lee
02672c040d update: DiskANN submodule to support both build environments
- Point to fix/python-finding-compatibility branch
- Support both Development and Development.Module for Python finding
- Fixes compatibility with both standard and manylinux builds
2025-07-25 13:33:36 -07:00
Andy Lee
f55108feda fix: resolve CI conflicts and Python finding issues
- Remove duplicate ci-cibuildwheel.yml workflow to avoid confusion
- Fix DiskANN CMakeLists.txt to support both standard and manylinux builds
- Try Development.Module first (manylinux), fallback to Development (standard)
- Keep build-reusable.yml and build-cibuildwheel.yml as separate workflows
- They build different wheel types and are selected at release time
2025-07-25 13:21:22 -07:00
Andy Lee
74d485c908 fix: remove Chinese comments from code
- Replace all Chinese comments with English
- Ensure code is internationalization-friendly
2025-07-25 12:55:28 -07:00
Andy Lee
d1fefb6378 fix: correct DiskANN subdirectory path
- Change from src/third_party/DiskANN to third_party/DiskANN
- This was causing CMake to fail finding the subdirectory
2025-07-25 12:51:47 -07:00
Andy Lee
732384f4f8 fix: improve Python finding for DiskANN in manylinux environment
- Add explicit Python finding logic to DiskANN CMakeLists.txt
- Change Development to Development.Module to avoid Embed requirement
- Pass Python variables through CMake cache to submodules
2025-07-25 12:22:27 -07:00
Andy Lee
ae38e10d1b chore: focus on Linux builds for Colab compatibility
- Remove macOS from build matrix to simplify CI
- Keep macOS configurations for future reference but they won't be used
- This speeds up CI and focuses on the primary target (Colab/Linux)
2025-07-25 12:20:32 -07:00
Andy Lee
ca0fd88934 chore: clean up temporary test scripts 2025-07-25 11:53:23 -07:00
Andy Lee
3c8d32f156 fix: prevent pyzmq compilation during tests
- Cap pyzmq version to <27 for manylinux2014 compatibility
- Pre-install pyzmq binary wheel before tests using CIBW_BEFORE_TEST
- Force pip to use only binary wheels with --only-binary :all:
2025-07-25 11:35:31 -07:00
Andy Lee
b8ff00fc6a fix: address macOS deployment target and pyzmq compilation issues
- Set MACOSX_DEPLOYMENT_TARGET=11.0 for macOS builds
- Add pyzmq to test requirements to use pre-built wheels
- Configure deployment target in both workflow and pyproject.toml
- Skip ARM64 tests on GitHub Actions to avoid cross-compilation issues
2025-07-25 11:13:51 -07:00
Andy Lee
3c836766f8 fix: update faiss submodule to fix-python-finding-manylinux branch 2025-07-25 10:55:28 -07:00
Andy Lee
b4a1dfb9c7 fix: update faiss submodule pointer with Python finding fixes 2025-07-25 10:53:05 -07:00
Andy Lee
a4d66e95d8 fix: improve Python finding for Faiss in manylinux environment
- Use Development.Module instead of Development component
- Find Python in main Faiss CMakeLists.txt before python subdirectory
- Add debug output to trace Python variable passing
- Set Python_FIND_VIRTUALENV=ONLY for Faiss
2025-07-25 10:52:34 -07:00
Andy Lee
cf58b3e31b fix: re-enable Faiss Python bindings and improve Python finding
- Re-enable FAISS_ENABLE_PYTHON since we need the Python bindings
- Use Development.Module component for better compatibility
- Pass Python information to Faiss through CMake cache variables
- Add CMAKE_PREFIX_PATH= to help CMake find Python in manylinux
2025-07-25 10:45:07 -07:00
Andy Lee
e9c2ca7936 fix: remove cmake.verbose from scikit-build config
- cmake.verbose is not allowed when minimum-version is set to 0.10 or higher
- This was causing build failures in cibuildwheel
2025-07-25 10:37:06 -07:00
Andy Lee
dab154a77b fix: disable Faiss Python bindings to avoid CMake Python finding issues
- Set FAISS_ENABLE_PYTHON to OFF since we use our own Cython bindings
- This avoids the CMake Python finding issues in manylinux environments
- Simplify CMakeLists.txt by removing unnecessary Python finding logic
- Keep swig installation for other potential uses
2025-07-25 10:33:31 -07:00
Andy Lee
13413dfae5 fix: improve Python detection in manylinux environment
- Modify faiss CMakeLists.txt to try both FindPython and FindPython3
- Add scikit-build configuration to help with Python detection
- Simplify Linux environment variables in cibuildwheel
- Add CMake Python detection before faiss configuration
2025-07-25 10:27:04 -07:00
Andy Lee
0543cc9816 fix: add missing dependencies for CI builds
- Add libomp for macOS to fix OpenMP linking error
- Add python3-devel for Linux (though may not be needed in manylinux)
- Install numpy before building to satisfy faiss CMake requirements
- Add before-build configuration to cibuildwheel
2025-07-25 10:20:02 -07:00
Andy Lee
fb53ed9a0e fix: use manylinux2014 for Colab compatibility
- Switch from manylinux_2_28 to manylinux2014 (provides manylinux_2_17)
- This should produce wheels compatible with manylinux_2_35_x86_64 requirement
- Update package manager from dnf to yum for CentOS 7
- Use cmake3 with symlink for compatibility
2025-07-25 10:15:05 -07:00
Andy Lee
015f43733a fix: adjust configuration for Colab compatibility
- Remove Windows support as not needed
- Move Python finding hints to cibuildwheel environment variables
- Keep pyproject.toml clean to avoid breaking normal builds
- Target manylinux_2_28 for better Colab compatibility
2025-07-25 10:09:21 -07:00
Andy Lee
2957c8bf5a docs: add manylinux build strategy documentation 2025-07-25 10:03:57 -07:00
Andy Lee
a73194c3f6 feat: simplify build using cibuildwheel with standard configuration
- Add Python_FIND_VIRTUALENV hints to pyproject.toml for CMake
- Create standardized cibuildwheel workflow using manylinux_2_28
- Simplify system dependency installation
- Add global cibuildwheel configuration in root pyproject.toml
- Create streamlined test workflow for manylinux compatibility
2025-07-25 10:03:23 -07:00
133 changed files with 9490 additions and 17438 deletions

1
.gitattributes vendored Normal file
View File

@@ -0,0 +1 @@
paper_plot/data/big_graph_degree_data.npz filter=lfs diff=lfs merge=lfs -text

View File

@@ -5,8 +5,7 @@ on:
branches: [ main ]
pull_request:
branches: [ main ]
workflow_dispatch:
jobs:
build:
uses: ./.github/workflows/build-reusable.yml
uses: ./.github/workflows/build-reusable.yml

171
.github/workflows/build-cibuildwheel.yml vendored Normal file
View File

@@ -0,0 +1,171 @@
name: Build with cibuildwheel
on:
workflow_call:
inputs:
ref:
description: 'Git ref to build'
required: false
type: string
default: ''
jobs:
build_wheels:
name: Build wheels on ${{ matrix.os }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest] # Focus on Linux/manylinux for Colab compatibility
steps:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.ref }}
submodules: recursive
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
# Build pure Python packages separately
- name: Build pure Python packages (leann-core, leann)
if: matrix.os == 'ubuntu-latest' # Only build once
run: |
python -m pip install --upgrade pip build
python -m build packages/leann-core --outdir wheelhouse/
python -m build packages/leann --outdir wheelhouse/
- name: Build leann-backend-hnsw wheels
uses: pypa/cibuildwheel@v2.20.0
with:
package-dir: packages/leann-backend-hnsw
output-dir: wheelhouse
env:
CIBW_BUILD: cp39-* cp310-* cp311-* cp312-* cp313-*
CIBW_SKIP: "*-win32 *-manylinux_i686 pp* *musllinux*"
# Use manylinux_2_35 for Colab compatibility with modern features
CIBW_MANYLINUX_X86_64_IMAGE: manylinux_2_35
CIBW_MANYLINUX_AARCH64_IMAGE: manylinux_2_35
# Linux dependencies - using dnf for manylinux_2_35 (based on AlmaLinux 9)
CIBW_BEFORE_ALL_LINUX: |
dnf install -y epel-release
dnf install -y gcc-c++ boost-devel zeromq-devel openblas-devel cmake python3-devel
# Install numpy before building
CIBW_BEFORE_BUILD: |
pip install numpy
pip install --upgrade pip setuptools wheel
CIBW_BEFORE_BUILD_LINUX: |
pip install numpy
pip install --upgrade pip setuptools wheel swig
CIBW_BEFORE_ALL_MACOS: |
brew install boost zeromq openblas cmake libomp
# Pre-install test dependencies to avoid compilation
CIBW_BEFORE_TEST: |
pip install --only-binary :all: "pyzmq>=23.0.0"
# Test command to verify the wheel works
CIBW_TEST_COMMAND: |
python -c "import leann_backend_hnsw; print('HNSW backend imported successfully')"
# Skip problematic configurations
CIBW_TEST_SKIP: "*-macosx_arm64" # Skip ARM64 tests on GitHub Actions
# Test dependencies
CIBW_TEST_REQUIRES: "pytest numpy"
# Environment variables for build
CIBW_ENVIRONMENT: |
CMAKE_BUILD_PARALLEL_LEVEL=8
Python_FIND_VIRTUALENV=ONLY
Python3_FIND_VIRTUALENV=ONLY
# Linux-specific environment variables
CIBW_ENVIRONMENT_LINUX: |
CMAKE_BUILD_PARALLEL_LEVEL=8
# macOS-specific environment variables
CIBW_ENVIRONMENT_MACOS: |
CMAKE_BUILD_PARALLEL_LEVEL=8
MACOSX_DEPLOYMENT_TARGET=11.0
CMAKE_OSX_DEPLOYMENT_TARGET=11.0
Python_FIND_VIRTUALENV=ONLY
Python3_FIND_VIRTUALENV=ONLY
- name: Build leann-backend-diskann wheels
uses: pypa/cibuildwheel@v2.20.0
with:
package-dir: packages/leann-backend-diskann
output-dir: wheelhouse
env:
CIBW_BUILD: cp39-* cp310-* cp311-* cp312-* cp313-*
CIBW_SKIP: "*-win32 *-manylinux_i686 pp* *musllinux*"
CIBW_MANYLINUX_X86_64_IMAGE: manylinux_2_35
CIBW_MANYLINUX_AARCH64_IMAGE: manylinux_2_35
CIBW_BEFORE_ALL_LINUX: |
dnf install -y epel-release
dnf install -y gcc-c++ boost-devel zeromq-devel openblas-devel cmake python3-devel
# Install numpy before building
CIBW_BEFORE_BUILD: |
pip install numpy
pip install --upgrade pip setuptools wheel
CIBW_BEFORE_BUILD_LINUX: |
pip install numpy
pip install --upgrade pip setuptools wheel swig
CIBW_BEFORE_ALL_MACOS: |
brew install boost zeromq openblas cmake libomp
# Pre-install test dependencies to avoid compilation
CIBW_BEFORE_TEST: |
pip install --only-binary :all: "pyzmq>=23.0.0"
# Test command to verify the wheel works
CIBW_TEST_COMMAND: |
python -c "import leann_backend_diskann; print('DiskANN backend imported successfully')"
# Skip problematic configurations
CIBW_TEST_SKIP: "*-macosx_arm64" # Skip ARM64 tests on GitHub Actions
# Test dependencies - avoid pyzmq due to manylinux2014 compatibility issues
CIBW_TEST_REQUIRES: "pytest numpy"
CIBW_ENVIRONMENT: |
CMAKE_BUILD_PARALLEL_LEVEL=8
Python_FIND_VIRTUALENV=ONLY
Python3_FIND_VIRTUALENV=ONLY
# Linux-specific environment variables
CIBW_ENVIRONMENT_LINUX: |
CMAKE_BUILD_PARALLEL_LEVEL=8
CMAKE_PREFIX_PATH=$VIRTUAL_ENV
Python_FIND_VIRTUALENV=ONLY
Python3_FIND_VIRTUALENV=ONLY
Python_FIND_STRATEGY=LOCATION
Python3_FIND_STRATEGY=LOCATION
Python_EXECUTABLE=$VIRTUAL_ENV/bin/python
Python3_EXECUTABLE=$VIRTUAL_ENV/bin/python
# macOS-specific environment variables
CIBW_ENVIRONMENT_MACOS: |
CMAKE_BUILD_PARALLEL_LEVEL=8
MACOSX_DEPLOYMENT_TARGET=11.0
CMAKE_OSX_DEPLOYMENT_TARGET=11.0
Python_FIND_VIRTUALENV=ONLY
Python3_FIND_VIRTUALENV=ONLY
- uses: actions/upload-artifact@v4
with:
name: wheels-${{ matrix.os }}
path: ./wheelhouse/*.whl

View File

@@ -10,393 +10,304 @@ on:
default: ''
jobs:
lint:
name: Lint and Format Check
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.ref }}
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
- name: Install uv
uses: astral-sh/setup-uv@v4
- name: Install ruff
run: |
uv tool install ruff
- name: Run ruff check
run: |
ruff check .
- name: Run ruff format check
run: |
ruff format --check .
build:
needs: lint
name: Build ${{ matrix.os }} Python ${{ matrix.python }}
strategy:
fail-fast: false
matrix:
include:
- os: ubuntu-22.04
- os: ubuntu-latest
python: '3.9'
- os: ubuntu-22.04
container: 'quay.io/pypa/manylinux2014_x86_64'
- os: ubuntu-latest
python: '3.10'
- os: ubuntu-22.04
container: 'quay.io/pypa/manylinux2014_x86_64'
- os: ubuntu-latest
python: '3.11'
- os: ubuntu-22.04
container: 'quay.io/pypa/manylinux2014_x86_64'
- os: ubuntu-latest
python: '3.12'
- os: ubuntu-22.04
container: 'quay.io/pypa/manylinux2014_x86_64'
- os: ubuntu-latest
python: '3.13'
# ARM64 Linux builds
- os: ubuntu-24.04-arm
container: 'quay.io/pypa/manylinux2014_x86_64'
- os: macos-latest
python: '3.9'
- os: ubuntu-24.04-arm
container: ''
- os: macos-latest
python: '3.10'
- os: ubuntu-24.04-arm
container: ''
- os: macos-latest
python: '3.11'
- os: ubuntu-24.04-arm
container: ''
- os: macos-latest
python: '3.12'
- os: ubuntu-24.04-arm
container: ''
- os: macos-latest
python: '3.13'
- os: macos-14
python: '3.9'
- os: macos-14
python: '3.10'
- os: macos-14
python: '3.11'
- os: macos-14
python: '3.12'
- os: macos-14
python: '3.13'
- os: macos-15
python: '3.9'
- os: macos-15
python: '3.10'
- os: macos-15
python: '3.11'
- os: macos-15
python: '3.12'
- os: macos-15
python: '3.13'
- os: macos-13
python: '3.9'
- os: macos-13
python: '3.10'
- os: macos-13
python: '3.11'
- os: macos-13
python: '3.12'
# Note: macos-13 + Python 3.13 excluded due to PyTorch compatibility
# (PyTorch 2.5+ supports Python 3.13 but not Intel Mac x86_64)
container: ''
runs-on: ${{ matrix.os }}
container: ${{ matrix.container }}
steps:
- uses: actions/checkout@v5
# For manylinux2014 compatibility, we'll handle checkout differently
- uses: actions/checkout@v4
if: matrix.container == ''
with:
ref: ${{ inputs.ref }}
submodules: recursive
- name: Setup Python
# Manual checkout for containers to avoid Node.js compatibility issues
- name: Manual checkout in container
if: matrix.container != ''
run: |
# Install git if not available
yum install -y git || true
# Configure git to handle the directory ownership issue
git config --global --add safe.directory ${GITHUB_WORKSPACE}
git config --global --add safe.directory /__w/LEANN/LEANN
git config --global --add safe.directory /github/workspace
git config --global --add safe.directory $(pwd)
# Clone the repository manually in the container
git init
git remote add origin https://github.com/${GITHUB_REPOSITORY}.git
# Fetch the appropriate ref
if [ -n "${{ inputs.ref }}" ]; then
git fetch --depth=1 origin ${{ inputs.ref }}
else
git fetch --depth=1 origin ${GITHUB_SHA}
fi
git checkout FETCH_HEAD
# Initialize submodules
git submodule update --init --recursive
- name: Setup Python (macOS and regular Ubuntu)
if: matrix.container == ''
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python }}
- name: Install uv
uses: astral-sh/setup-uv@v6
- name: Install system dependencies (Ubuntu)
if: runner.os == 'Linux'
- name: Setup Python (manylinux container)
if: matrix.container != ''
run: |
# Use the pre-installed Python version in manylinux container
# Convert Python version format (3.9 -> 39, 3.10 -> 310, etc.)
PY_VER=$(echo "${{ matrix.python }}" | sed 's/\.//g')
/opt/python/cp${PY_VER}-*/bin/python -m pip install --upgrade pip
# Create symlinks for convenience
ln -sf /opt/python/cp${PY_VER}-*/bin/python /usr/local/bin/python
ln -sf /opt/python/cp${PY_VER}-*/bin/pip /usr/local/bin/pip
- name: Install uv (macOS and regular Ubuntu)
if: matrix.container == ''
uses: astral-sh/setup-uv@v4
- name: Install uv (manylinux container)
if: matrix.container != ''
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
echo "$HOME/.cargo/bin" >> $GITHUB_PATH
- name: Install system dependencies (Ubuntu - regular)
if: runner.os == 'Linux' && matrix.container == ''
run: |
sudo apt-get update
sudo apt-get install -y libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
pkg-config libabsl-dev libaio-dev libprotobuf-dev \
patchelf
# Debug: Show system information
echo "🔍 System Information:"
echo "Architecture: $(uname -m)"
echo "OS: $(uname -a)"
echo "CPU info: $(lscpu | head -5)"
# Install math library based on architecture
ARCH=$(uname -m)
echo "🔍 Setting up math library for architecture: $ARCH"
if [[ "$ARCH" == "x86_64" ]]; then
# Install Intel MKL for DiskANN on x86_64
echo "📦 Installing Intel MKL for x86_64..."
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
source /opt/intel/oneapi/setvars.sh
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/compiler/latest/linux/compiler/lib/intel64_lin" >> $GITHUB_ENV
echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/mkl/latest/lib/intel64" >> $GITHUB_ENV
echo "✅ Intel MKL installed for x86_64"
# Debug: Check MKL installation
echo "🔍 MKL Installation Check:"
ls -la /opt/intel/oneapi/mkl/latest/ || echo "MKL directory not found"
ls -la /opt/intel/oneapi/mkl/latest/lib/ || echo "MKL lib directory not found"
elif [[ "$ARCH" == "aarch64" ]]; then
# Use OpenBLAS for ARM64 (MKL installer not compatible with ARM64)
echo "📦 Installing OpenBLAS for ARM64..."
sudo apt-get install -y libopenblas-dev liblapack-dev liblapacke-dev
echo "✅ OpenBLAS installed for ARM64"
# Debug: Check OpenBLAS installation
echo "🔍 OpenBLAS Installation Check:"
dpkg -l | grep openblas || echo "OpenBLAS package not found"
ls -la /usr/lib/aarch64-linux-gnu/openblas/ || echo "OpenBLAS directory not found"
pkg-config libopenblas-dev patchelf libabsl-dev libaio-dev libprotobuf-dev
# Install Intel MKL for DiskANN
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
source /opt/intel/oneapi/setvars.sh
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/mkl/latest/lib/intel64:$LD_LIBRARY_PATH" >> $GITHUB_ENV
- name: Install system dependencies (manylinux container)
if: runner.os == 'Linux' && matrix.container != ''
run: |
# manylinux2014 uses yum instead of apt
# Update yum cache first
yum clean all
yum makecache
# Install EPEL repository
yum install -y epel-release || true
# Update cache again after EPEL
yum makecache || true
# Install development packages
# Note: Some packages might have different names in CentOS 7
yum install -y \
gcc-c++ \
boost-devel \
protobuf-compiler \
protobuf-devel \
zeromq-devel \
pkgconfig \
openblas-devel \
cmake || {
echo "Some packages failed to install, trying alternatives..."
# Try alternative package names
yum install -y libzmq3-devel || true
yum install -y libzmq-devel || true
}
# Install optional packages that might not be available
yum install -y libaio-devel || echo "libaio-devel not available, continuing..."
# Verify zmq installation and create pkg-config file if needed
if [ ! -f /usr/lib64/pkgconfig/libzmq.pc ] && [ ! -f /usr/lib/pkgconfig/libzmq.pc ]; then
echo "Creating libzmq.pc file..."
mkdir -p /usr/lib64/pkgconfig
cat > /usr/lib64/pkgconfig/libzmq.pc << 'EOF'
prefix=/usr
exec_prefix=${prefix}
libdir=${exec_prefix}/lib64
includedir=${prefix}/include
Name: libzmq
Description: ZeroMQ library
Version: 4.1.4
Libs: -L${libdir} -lzmq
Cflags: -I${includedir}
EOF
fi
# Debug: Show final library paths
echo "🔍 Final LD_LIBRARY_PATH: $LD_LIBRARY_PATH"
# Update PKG_CONFIG_PATH
echo "PKG_CONFIG_PATH=/usr/lib64/pkgconfig:/usr/lib/pkgconfig:$PKG_CONFIG_PATH" >> $GITHUB_ENV
# Build tools are pre-installed in manylinux
# MKL is more complex in container, skip for now and use OpenBLAS
- name: Install system dependencies (macOS)
if: runner.os == 'macOS'
run: |
# Don't install LLVM, use system clang for better compatibility
brew install libomp boost protobuf zeromq
brew install llvm libomp boost protobuf zeromq
- name: Install build dependencies
run: |
uv pip install --system scikit-build-core numpy swig Cython pybind11
if [[ "$RUNNER_OS" == "Linux" ]]; then
uv pip install --system auditwheel
if [[ -n "${{ matrix.container }}" ]]; then
# In manylinux container, use regular pip
pip install scikit-build-core numpy swig Cython pybind11 auditwheel
else
uv pip install --system delocate
# Regular environment, use uv
uv pip install --system scikit-build-core numpy swig Cython pybind11
if [[ "$RUNNER_OS" == "Linux" ]]; then
uv pip install --system auditwheel
else
uv pip install --system delocate
fi
fi
- name: Set macOS environment variables
if: runner.os == 'macOS'
run: |
# Use brew --prefix to automatically detect Homebrew installation path
HOMEBREW_PREFIX=$(brew --prefix)
echo "HOMEBREW_PREFIX=${HOMEBREW_PREFIX}" >> $GITHUB_ENV
echo "OpenMP_ROOT=${HOMEBREW_PREFIX}/opt/libomp" >> $GITHUB_ENV
# Set CMAKE_PREFIX_PATH to let CMake find all packages automatically
echo "CMAKE_PREFIX_PATH=${HOMEBREW_PREFIX}" >> $GITHUB_ENV
# Set compiler flags for OpenMP (required for both backends)
echo "LDFLAGS=-L${HOMEBREW_PREFIX}/opt/libomp/lib" >> $GITHUB_ENV
echo "CPPFLAGS=-I${HOMEBREW_PREFIX}/opt/libomp/include" >> $GITHUB_ENV
- name: Build packages
run: |
# Choose build command based on environment
if [[ -n "${{ matrix.container }}" ]]; then
BUILD_CMD="pip wheel . --no-deps -w dist"
else
BUILD_CMD="uv build --wheel --python python"
fi
# Build core (platform independent)
cd packages/leann-core
uv build
cd ../..
if [ "${{ matrix.os }}" == "ubuntu-latest" ]; then
cd packages/leann-core
if [[ -n "${{ matrix.container }}" ]]; then
pip wheel . --no-deps -w dist
else
uv build
fi
cd ../..
fi
# Build HNSW backend
cd packages/leann-backend-hnsw
if [[ "${{ matrix.os }}" == macos-* ]]; then
# Use system clang for better compatibility
export CC=clang
export CXX=clang++
# Homebrew libraries on each macOS version require matching minimum version
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
export MACOSX_DEPLOYMENT_TARGET=13.0
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
export MACOSX_DEPLOYMENT_TARGET=14.0
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
export MACOSX_DEPLOYMENT_TARGET=15.0
fi
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
if [ "${{ matrix.os }}" == "macos-latest" ]; then
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ $BUILD_CMD
else
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
eval $BUILD_CMD
fi
cd ../..
# Build DiskANN backend
cd packages/leann-backend-diskann
if [[ "${{ matrix.os }}" == macos-* ]]; then
# Use system clang for better compatibility
export CC=clang
export CXX=clang++
# DiskANN requires macOS 13.3+ for sgesdd_ LAPACK function
# But Homebrew libraries on each macOS version require matching minimum version
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
export MACOSX_DEPLOYMENT_TARGET=13.3
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
export MACOSX_DEPLOYMENT_TARGET=14.0
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
export MACOSX_DEPLOYMENT_TARGET=15.0
fi
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
if [ "${{ matrix.os }}" == "macos-latest" ]; then
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ $BUILD_CMD
else
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
eval $BUILD_CMD
fi
cd ../..
# Build meta package (platform independent)
cd packages/leann
uv build
cd ../..
if [ "${{ matrix.os }}" == "ubuntu-latest" ]; then
cd packages/leann
if [[ -n "${{ matrix.container }}" ]]; then
pip wheel . --no-deps -w dist
else
uv build
fi
cd ../..
fi
- name: Repair wheels (Linux)
if: runner.os == 'Linux'
run: |
# Repair HNSW wheel
cd packages/leann-backend-hnsw
if [ -d dist ]; then
# Show what platform auditwheel will use
auditwheel show dist/*.whl || true
# Let auditwheel auto-detect the appropriate manylinux tag
auditwheel repair dist/*.whl -w dist_repaired
rm -rf dist
mv dist_repaired dist
fi
cd ../..
# Repair DiskANN wheel
cd packages/leann-backend-diskann
if [ -d dist ]; then
# Show what platform auditwheel will use
auditwheel show dist/*.whl || true
# Let auditwheel auto-detect the appropriate manylinux tag
auditwheel repair dist/*.whl -w dist_repaired
rm -rf dist
mv dist_repaired dist
fi
cd ../..
- name: Repair wheels (macOS)
if: runner.os == 'macOS'
run: |
# Determine deployment target based on runner OS
# Must match the Homebrew libraries for each macOS version
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
HNSW_TARGET="13.0"
DISKANN_TARGET="13.3"
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
HNSW_TARGET="14.0"
DISKANN_TARGET="14.0"
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
HNSW_TARGET="15.0"
DISKANN_TARGET="15.0"
fi
# Repair HNSW wheel
cd packages/leann-backend-hnsw
if [ -d dist ]; then
export MACOSX_DEPLOYMENT_TARGET=$HNSW_TARGET
delocate-wheel -w dist_repaired -v --require-target-macos-version $HNSW_TARGET dist/*.whl
delocate-wheel -w dist_repaired -v dist/*.whl
rm -rf dist
mv dist_repaired dist
fi
cd ../..
# Repair DiskANN wheel
cd packages/leann-backend-diskann
if [ -d dist ]; then
export MACOSX_DEPLOYMENT_TARGET=$DISKANN_TARGET
delocate-wheel -w dist_repaired -v --require-target-macos-version $DISKANN_TARGET dist/*.whl
delocate-wheel -w dist_repaired -v dist/*.whl
rm -rf dist
mv dist_repaired dist
fi
cd ../..
- name: List built packages
run: |
echo "📦 Built packages:"
find packages/*/dist -name "*.whl" -o -name "*.tar.gz" | sort
- name: Install built packages for testing
run: |
# Create a virtual environment with the correct Python version
uv venv --python ${{ matrix.python }}
source .venv/bin/activate || source .venv/Scripts/activate
# Install packages using --find-links to prioritize local builds
uv pip install --find-links packages/leann-core/dist --find-links packages/leann-backend-hnsw/dist --find-links packages/leann-backend-diskann/dist packages/leann-core/dist/*.whl || uv pip install --find-links packages/leann-core/dist packages/leann-core/dist/*.tar.gz
uv pip install --find-links packages/leann-core/dist packages/leann-backend-hnsw/dist/*.whl
uv pip install --find-links packages/leann-core/dist packages/leann-backend-diskann/dist/*.whl
uv pip install packages/leann/dist/*.whl || uv pip install packages/leann/dist/*.tar.gz
# Install test dependencies using extras
uv pip install -e ".[test]"
- name: Run tests with pytest
env:
CI: true
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
HF_HUB_DISABLE_SYMLINKS: 1
TOKENIZERS_PARALLELISM: false
PYTORCH_ENABLE_MPS_FALLBACK: 0
OMP_NUM_THREADS: 1
MKL_NUM_THREADS: 1
run: |
source .venv/bin/activate || source .venv/Scripts/activate
pytest tests/ -v --tb=short
- name: Run sanity checks (optional)
run: |
# Activate virtual environment
source .venv/bin/activate || source .venv/Scripts/activate
# Run distance function tests if available
if [ -f test/sanity_checks/test_distance_functions.py ]; then
echo "Running distance function sanity checks..."
python test/sanity_checks/test_distance_functions.py || echo "⚠️ Distance function test failed, continuing..."
fi
- name: Upload artifacts
uses: actions/upload-artifact@v4
with:
name: packages-${{ matrix.os }}-py${{ matrix.python }}
path: packages/*/dist/
arch-smoke:
name: Arch Linux smoke test (install & import)
needs: build
runs-on: ubuntu-latest
container:
image: archlinux:latest
steps:
- name: Prepare system
run: |
pacman -Syu --noconfirm
pacman -S --noconfirm python python-pip gcc git zlib openssl
- name: Download ALL wheel artifacts from this run
uses: actions/download-artifact@v5
with:
# Don't specify name, download all artifacts
path: ./wheels
- name: Install uv
uses: astral-sh/setup-uv@v6
- name: Create virtual environment and install wheels
run: |
uv venv
source .venv/bin/activate || source .venv/Scripts/activate
uv pip install --find-links wheels leann-core
uv pip install --find-links wheels leann-backend-hnsw
uv pip install --find-links wheels leann-backend-diskann
uv pip install --find-links wheels leann
- name: Import & tiny runtime check
env:
OMP_NUM_THREADS: 1
MKL_NUM_THREADS: 1
run: |
source .venv/bin/activate || source .venv/Scripts/activate
python - <<'PY'
import leann
import leann_backend_hnsw as h
import leann_backend_diskann as d
from leann import LeannBuilder, LeannSearcher
b = LeannBuilder(backend_name="hnsw")
b.add_text("hello arch")
b.build_index("arch_demo.leann")
s = LeannSearcher("arch_demo.leann")
print("search:", s.search("hello", top_k=1))
PY
name: packages-${{ matrix.os }}-py${{ matrix.python }}${{ matrix.container && '-manylinux' || '' }}
path: packages/*/dist/

View File

@@ -1,19 +0,0 @@
name: Link Check
on:
push:
branches: [ main, master ]
pull_request:
schedule:
- cron: "0 3 * * 1"
jobs:
link-check:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: lycheeverse/lychee-action@v2
with:
args: --no-progress --insecure --user-agent 'curl/7.68.0' README.md docs/ apps/ examples/ benchmarks/
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -7,6 +7,11 @@ on:
description: 'Version to release (e.g., 0.1.2)'
required: true
type: string
use_cibuildwheel:
description: 'Use cibuildwheel for better compatibility (recommended for Colab)'
required: false
type: boolean
default: false
jobs:
update-version:
@@ -16,79 +21,75 @@ jobs:
contents: write
outputs:
commit-sha: ${{ steps.push.outputs.commit-sha }}
steps:
- uses: actions/checkout@v4
- name: Validate version
run: |
# Remove 'v' prefix if present for validation
VERSION_CLEAN="${{ inputs.version }}"
VERSION_CLEAN="${VERSION_CLEAN#v}"
if ! [[ "$VERSION_CLEAN" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
echo "❌ Invalid version format. Expected format: X.Y.Z or vX.Y.Z"
if ! [[ "${{ inputs.version }}" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
echo "❌ Invalid version format"
exit 1
fi
echo "✅ Version format valid: ${{ inputs.version }}"
echo "✅ Version format valid"
- name: Update versions and push
id: push
run: |
# Check current version
CURRENT_VERSION=$(grep "^version" packages/leann-core/pyproject.toml | cut -d'"' -f2)
echo "Current version: $CURRENT_VERSION"
echo "Target version: ${{ inputs.version }}"
if [ "$CURRENT_VERSION" = "${{ inputs.version }}" ]; then
echo "⚠️ Version is already ${{ inputs.version }}, skipping update"
COMMIT_SHA=$(git rev-parse HEAD)
else
./scripts/bump_version.sh ${{ inputs.version }}
git config user.name "GitHub Actions"
git config user.email "actions@github.com"
git add packages/*/pyproject.toml
git commit -m "chore: release v${{ inputs.version }}"
git push origin main
COMMIT_SHA=$(git rev-parse HEAD)
echo "✅ Pushed version update: $COMMIT_SHA"
fi
./scripts/bump_version.sh ${{ inputs.version }}
git config user.name "GitHub Actions"
git config user.email "actions@github.com"
git add packages/*/pyproject.toml
git commit -m "chore: release v${{ inputs.version }}"
git push origin main
COMMIT_SHA=$(git rev-parse HEAD)
echo "commit-sha=$COMMIT_SHA" >> $GITHUB_OUTPUT
echo "✅ Pushed version update: $COMMIT_SHA"
build-packages:
name: Build packages
build-packages-reusable:
name: Build packages (Standard)
needs: update-version
if: ${{ !inputs.use_cibuildwheel }}
uses: ./.github/workflows/build-reusable.yml
with:
ref: 'main'
ref: ${{ needs.update-version.outputs.commit-sha }}
build-packages-cibuildwheel:
name: Build packages (cibuildwheel)
needs: update-version
if: ${{ inputs.use_cibuildwheel }}
uses: ./.github/workflows/build-cibuildwheel.yml
with:
ref: ${{ needs.update-version.outputs.commit-sha }}
publish:
name: Publish and Release
needs: [update-version, build-packages]
if: always() && needs.update-version.result == 'success' && needs.build-packages.result == 'success'
needs: [update-version, build-packages-reusable, build-packages-cibuildwheel]
if: always() && needs.update-version.result == 'success' && (needs.build-packages-reusable.result == 'success' || needs.build-packages-cibuildwheel.result == 'success')
runs-on: ubuntu-latest
permissions:
contents: write
steps:
- uses: actions/checkout@v4
with:
ref: 'main'
ref: ${{ needs.update-version.outputs.commit-sha }}
- name: Download all artifacts
uses: actions/download-artifact@v4
with:
path: dist-artifacts
- name: Collect packages
run: |
mkdir -p dist
find dist-artifacts -name "*.whl" -exec cp {} dist/ \;
find dist-artifacts -name "*.tar.gz" -exec cp {} dist/ \;
echo "📦 Packages to publish:"
ls -la dist/
- name: Publish to PyPI
env:
TWINE_USERNAME: __token__
@@ -98,32 +99,20 @@ jobs:
echo "❌ PYPI_API_TOKEN not configured!"
exit 1
fi
pip install twine
twine upload dist/* --skip-existing --verbose
echo "✅ Published to PyPI!"
- name: Create release
run: |
# Check if tag already exists
if git rev-parse "v${{ inputs.version }}" >/dev/null 2>&1; then
echo "⚠️ Tag v${{ inputs.version }} already exists, skipping tag creation"
else
git tag "v${{ inputs.version }}"
git push origin "v${{ inputs.version }}"
echo "✅ Created and pushed tag v${{ inputs.version }}"
fi
# Check if release already exists
if gh release view "v${{ inputs.version }}" >/dev/null 2>&1; then
echo "⚠️ Release v${{ inputs.version }} already exists, skipping release creation"
else
gh release create "v${{ inputs.version }}" \
--title "Release v${{ inputs.version }}" \
--notes "🚀 Released to PyPI: https://pypi.org/project/leann/${{ inputs.version }}/" \
--latest
echo "✅ Created GitHub release v${{ inputs.version }}"
fi
git tag "v${{ inputs.version }}"
git push origin "v${{ inputs.version }}"
gh release create "v${{ inputs.version }}" \
--title "Release v${{ inputs.version }}" \
--notes "🚀 Released to PyPI: https://pypi.org/project/leann/${{ inputs.version }}/" \
--latest
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}

60
.github/workflows/test-manylinux.yml vendored Normal file
View File

@@ -0,0 +1,60 @@
name: Test Manylinux Build
on:
workflow_dispatch:
pull_request:
branches: [ main ]
paths:
- '.github/workflows/**'
- 'packages/**'
- 'pyproject.toml'
push:
branches:
- 'fix/manylinux-*'
- 'test/build-*'
jobs:
build:
uses: ./.github/workflows/build-cibuildwheel.yml
test-install:
needs: build
runs-on: ubuntu-22.04 # Simulating Colab environment
strategy:
matrix:
python-version: ['3.10', '3.11', '3.12']
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Download artifacts
uses: actions/download-artifact@v4
with:
pattern: wheels-*
path: dist
merge-multiple: true
- name: Test installation
run: |
python -m pip install --upgrade pip
# Find and install the appropriate wheels
pip install dist/leann_core-*.whl
pip install dist/leann_backend_hnsw-*manylinux*.whl
pip install dist/leann-*.whl
- name: Test import
run: |
python -c "
import leann
from leann import LeannBuilder, LeannSearcher
print('Successfully imported leann modules')
# Quick functionality test
builder = LeannBuilder(backend_name='hnsw')
builder.add_text('Test document')
print('LeannBuilder created and used successfully')
"

29
.gitignore vendored
View File

@@ -9,7 +9,7 @@ demo/indices/
outputs/
*.pkl
*.pdf
*.idx
*.idx
*.map
.history/
lm_eval.egg-info/
@@ -18,11 +18,9 @@ demo/experiment_results/**/*.json
*.eml
*.emlx
*.json
!.vscode/*.json
*.sh
*.txt
!CMakeLists.txt
!llms.txt
latency_breakdown*.json
experiment_results/eval_results/diskann/*.json
aws/
@@ -36,15 +34,11 @@ build/
nprobe_logs/
micro/results
micro/contriever-INT8
data/*
!data/2501.14312v1 (1).pdf
!data/2506.08276v1.pdf
!data/PrideandPrejudice.txt
!data/huawei_pangu.md
!data/ground_truth/
!data/indices/
!data/queries/
!data/.gitattributes
examples/data/*
!examples/data/2501.14312v1 (1).pdf
!examples/data/2506.08276v1.pdf
!examples/data/PrideandPrejudice.txt
!examples/data/README.md
*.qdstrm
benchmark_results/
results/
@@ -91,13 +85,4 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/
*.meta.json
*.passages.json
batchtest.py
tests/__pytest_cache__/
tests/__pycache__/
paru-bin/
CLAUDE.md
CLAUDE.local.md
.claude/*.local.*
.claude/local/*
benchmarks/data/
batchtest.py

4
.gitmodules vendored
View File

@@ -14,7 +14,3 @@
[submodule "packages/leann-backend-hnsw/third_party/libzmq"]
path = packages/leann-backend-hnsw/third_party/libzmq
url = https://github.com/zeromq/libzmq.git
[submodule "packages/astchunk-leann"]
path = packages/astchunk-leann
url = git@github.com:yichuan-w/astchunk-leann.git
branch = main

View File

@@ -1,17 +0,0 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- id: check-merge-conflict
- id: debug-statements
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.12.7 # Fixed version to match pyproject.toml
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- id: ruff-format

View File

@@ -1,5 +0,0 @@
{
"recommendations": [
"charliermarsh.ruff",
]
}

22
.vscode/settings.json vendored
View File

@@ -1,22 +0,0 @@
{
"python.defaultInterpreterPath": ".venv/bin/python",
"python.terminal.activateEnvironment": true,
"[python]": {
"editor.defaultFormatter": "charliermarsh.ruff",
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.organizeImports": "explicit",
"source.fixAll": "explicit"
},
"editor.insertSpaces": true,
"editor.tabSize": 4
},
"ruff.enable": true,
"files.watcherExclude": {
"**/.venv/**": true,
"**/__pycache__/**": true,
"**/*.egg-info/**": true,
"**/build/**": true,
"**/dist/**": true
}
}

View File

@@ -0,0 +1,50 @@
# Manylinux Build Strategy
## Problem
Google Colab requires wheels compatible with `manylinux_2_35_x86_64` or earlier. Our previous builds were producing `manylinux_2_39_x86_64` wheels, which are incompatible.
## Solution
We're using `cibuildwheel` with `manylinux_2_35` images to build wheels that are compatible with Google Colab while maintaining modern toolchain features.
### Key Changes
1. **cibuildwheel Configuration**
- Using `manylinux2014` images (provides `manylinux_2_17` compatibility)
- Using `yum` package manager (CentOS 7 based)
- Installing `cmake3` and creating symlink for compatibility
2. **Build Matrix**
- Python versions: 3.9, 3.10, 3.11, 3.12, 3.13
- Platforms: Linux (x86_64), macOS
- No Windows support (not required)
3. **Dependencies**
- Linux: gcc-c++, boost-devel, zeromq-devel, openblas-devel, cmake3
- macOS: boost, zeromq, openblas, cmake (via Homebrew)
4. **Environment Variables**
- `CMAKE_BUILD_PARALLEL_LEVEL=8`: Speed up builds
- `Python_FIND_VIRTUALENV=ONLY`: Help CMake find Python in cibuildwheel env
- `Python3_FIND_VIRTUALENV=ONLY`: Alternative variable for compatibility
## Testing Strategy
1. **CI Pipeline**: `test-manylinux.yml`
- Triggers on PR to main, manual dispatch, or push to `fix/manylinux-*` branches
- Builds wheels using cibuildwheel
- Tests installation on Ubuntu 22.04 (simulating Colab)
2. **Local Testing**
```bash
# Download built wheels
# Test in fresh environment
python -m venv test_env
source test_env/bin/activate
pip install leann_core-*.whl leann_backend_hnsw-*manylinux*.whl leann-*.whl
python -c "from leann import LeannBuilder; print('Success!')"
```
## References
- [cibuildwheel documentation](https://cibuildwheel.readthedocs.io/)
- [manylinux standards](https://github.com/pypa/manylinux)
- [PEP 599 - manylinux2014](https://peps.python.org/pep-0599/)

563
README.md
View File

@@ -3,27 +3,20 @@
</p>
<p align="center">
<img src="https://img.shields.io/badge/Python-3.9%20%7C%203.10%20%7C%203.11%20%7C%203.12%20%7C%203.13-blue.svg" alt="Python Versions">
<img src="https://github.com/yichuan-w/LEANN/actions/workflows/build-and-publish.yml/badge.svg" alt="CI Status">
<img src="https://img.shields.io/badge/Platform-Ubuntu%20%26%20Arch%20%26%20WSL%20%7C%20macOS%20(ARM64%2FIntel)-lightgrey" alt="Platform">
<img src="https://img.shields.io/badge/Python-3.9%2B-blue.svg" alt="Python 3.9+">
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
<img src="https://img.shields.io/badge/MCP-Native%20Integration-blue" alt="MCP Integration">
<a href="https://join.slack.com/t/leann-e2u9779/shared_invite/zt-3ckd2f6w1-OX08~NN4gkWhh10PRVBj1Q"><img src="https://img.shields.io/badge/Slack-Join-4A154B?logo=slack&logoColor=white" alt="Join Slack">
<a href="assets/wechat_user_group.JPG" title="Join WeChat group"><img src="https://img.shields.io/badge/WeChat-Join-2DC100?logo=wechat&logoColor=white" alt="Join WeChat group"></a>
<img src="https://img.shields.io/badge/Platform-Linux%20%7C%20macOS-lightgrey" alt="Platform">
</p>
<h2 align="center" tabindex="-1" class="heading-element" dir="auto">
The smallest vector index in the world. RAG Everything with LEANN!
</h2>
LEANN is an innovative vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.
LEANN is a revolutionary vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)**, **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
\* Claude Code only supports basic `grep`-style keyword search. **LEANN** is a drop-in **semantic search MCP service fully compatible with Claude Code**, unlocking intelligent retrieval without changing your workflow. 🔥 Check out [the easy setup →](packages/leann-mcp/README.md)
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)**, or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
@@ -33,170 +26,49 @@ LEANN achieves this through *graph-based selective recomputation* with *high-deg
<img src="assets/effects.png" alt="LEANN vs Traditional Vector DB Storage Comparison" width="70%">
</p>
> **The numbers speak for themselves:** Index 60 million text chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#-storage-comparison)
> **The numbers speak for themselves:** Index 60 million Wikipedia chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#storage-comparison)
🔒 **Privacy:** Your data never leaves your laptop. No OpenAI, no cloud, no "terms of service".
🪶 **Lightweight:** Graph-based recomputation eliminates heavy embedding storage, while smart graph pruning and CSR format minimize graph storage overhead. Always less storage, less memory usage!
📦 **Portable:** Transfer your entire knowledge base between devices (even with others) with minimal cost - your personal AI memory travels with you.
📈 **Scalability:** Handle messy personal data that would crash traditional vector DBs, easily managing your growing personalized data and agent generated memory!
**No Accuracy Loss:** Maintain the same search quality as heavyweight solutions while using 97% less storage.
## Installation
### 📦 Prerequisites: Install uv
[Install uv](https://docs.astral.sh/uv/getting-started/installation/#installation-methods) first if you don't have it. Typically, you can install it with:
> `pip leann` coming soon!
```bash
curl -LsSf https://astral.sh/uv/install.sh | sh
```
### 🚀 Quick Install
Clone the repository to access all examples and try amazing applications,
```bash
git clone https://github.com/yichuan-w/LEANN.git leann
cd leann
```
and install LEANN from [PyPI](https://pypi.org/project/leann/) to run them immediately:
```bash
uv venv
source .venv/bin/activate
uv pip install leann
```
<!--
> Low-resource? See “Low-resource setups” in the [Configuration Guide](docs/configuration-guide.md#low-resource-setups). -->
<details>
<summary>
<strong>🔧 Build from Source (Recommended for development)</strong>
</summary>
```bash
git clone https://github.com/yichuan-w/LEANN.git leann
git clone git@github.com:yichuan-w/LEANN.git leann
cd leann
git submodule update --init --recursive
```
**macOS:**
Note: DiskANN requires MacOS 13.3 or later.
```bash
brew install libomp boost protobuf zeromq pkgconf
uv sync --extra diskann
brew install llvm libomp boost protobuf zeromq pkgconf
# Install with HNSW backend (default, recommended for most users)
# Install uv first if you don't have it:
# curl -LsSf https://astral.sh/uv/install.sh | sh
# See: https://docs.astral.sh/uv/getting-started/installation/#installation-methods
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync
```
**Linux (Ubuntu/Debian):**
Note: On Ubuntu 20.04, you may need to build a newer Abseil and pin Protobuf (e.g., v3.20.x) for building DiskANN. See [Issue #30](https://github.com/yichuan-w/LEANN/issues/30) for a step-by-step note.
You can manually install [Intel oneAPI MKL](https://www.intel.com/content/www/us/en/developer/tools/oneapi/onemkl.html) instead of `libmkl-full-dev` for DiskANN. You can also use `libopenblas-dev` for building HNSW only, by removing `--extra diskann` in the command below.
**Linux:**
```bash
sudo apt-get update && sudo apt-get install -y \
libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
pkg-config libabsl-dev libaio-dev libprotobuf-dev \
libmkl-full-dev
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
uv sync --extra diskann
# Install with HNSW backend (default, recommended for most users)
uv sync
```
**Linux (Arch Linux):**
```bash
sudo pacman -Syu && sudo pacman -S --needed base-devel cmake pkgconf git gcc \
boost boost-libs protobuf abseil-cpp libaio zeromq
**Ollama Setup (Recommended for full privacy):**
# For MKL in DiskANN
sudo pacman -S --needed base-devel git
git clone https://aur.archlinux.org/paru-bin.git
cd paru-bin && makepkg -si
paru -S intel-oneapi-mkl intel-oneapi-compiler
source /opt/intel/oneapi/setvars.sh
> *You can skip this installation if you only want to use OpenAI API for generation.*
uv sync --extra diskann
```
**Linux (RHEL / CentOS Stream / Oracle / Rocky / AlmaLinux):**
See [Issue #50](https://github.com/yichuan-w/LEANN/issues/50) for more details.
```bash
sudo dnf groupinstall -y "Development Tools"
sudo dnf install -y libomp-devel boost-devel protobuf-compiler protobuf-devel \
abseil-cpp-devel libaio-devel zeromq-devel pkgconf-pkg-config
# For MKL in DiskANN
sudo dnf install -y intel-oneapi-mkl intel-oneapi-mkl-devel \
intel-oneapi-openmp || sudo dnf install -y intel-oneapi-compiler
source /opt/intel/oneapi/setvars.sh
uv sync --extra diskann
```
</details>
## Quick Start
Our declarative API makes RAG as easy as writing a config file.
Check out [demo.ipynb](demo.ipynb) or [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/yichuan-w/LEANN/blob/main/demo.ipynb)
```python
from leann import LeannBuilder, LeannSearcher, LeannChat
from pathlib import Path
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
# Build an index
builder = LeannBuilder(backend_name="hnsw")
builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
builder.add_text("Tung Tung Tung Sahur called—they need their bananacrocodile hybrid back")
builder.build_index(INDEX_PATH)
# Search
searcher = LeannSearcher(INDEX_PATH)
results = searcher.search("fantastical AI-generated creatures", top_k=1)
# Chat with your data
chat = LeannChat(INDEX_PATH, llm_config={"type": "hf", "model": "Qwen/Qwen3-0.6B"})
response = chat.ask("How much storage does LEANN save?", top_k=1)
```
## RAG on Everything!
LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`, `.md`), Apple Mail, Google Search History, WeChat, and more.
### Generation Model Setup
LEANN supports multiple LLM providers for text generation (OpenAI API, HuggingFace, Ollama).
<details>
<summary><strong>🔑 OpenAI API Setup (Default)</strong></summary>
Set your OpenAI API key as an environment variable:
```bash
export OPENAI_API_KEY="your-api-key-here"
```
</details>
<details>
<summary><strong>🔧 Ollama Setup (Recommended for full privacy)</strong></summary>
**macOS:**
@@ -208,7 +80,6 @@ ollama pull llama3.2:1b
```
**Linux:**
```bash
# Install Ollama
curl -fsSL https://ollama.ai/install.sh | sh
@@ -220,55 +91,45 @@ ollama serve &
ollama pull llama3.2:1b
```
</details>
## Quick Start in 30s
Our declarative API makes RAG as easy as writing a config file.
[Try in this ipynb file →](demo.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/yichuan-w/LEANN/blob/main/demo.ipynb)
## ⭐ Flexible Configuration
```python
from leann.api import LeannBuilder, LeannSearcher, LeannChat
LEANN provides flexible parameters for embedding models, search strategies, and data processing to fit your specific needs.
# 1. Build the index (no embeddings stored!)
builder = LeannBuilder(backend_name="hnsw")
builder.add_text("C# is a powerful programming language")
builder.add_text("Python is a powerful programming language and it is very popular")
builder.add_text("Machine learning transforms industries")
builder.add_text("Neural networks process complex data")
builder.add_text("Leann is a great storage saving engine for RAG on your MacBook")
builder.build_index("knowledge.leann")
📚 **Need configuration best practices?** Check our [Configuration Guide](docs/configuration-guide.md) for detailed optimization tips, model selection advice, and solutions to common issues like slow embeddings or poor search quality.
# 2. Search with real-time embeddings
searcher = LeannSearcher("knowledge.leann")
results = searcher.search("programming languages", top_k=2)
<details>
<summary><strong>📋 Click to expand: Common Parameters (Available in All Examples)</strong></summary>
# 3. Chat with LEANN using retrieved results
llm_config = {
"type": "ollama",
"model": "llama3.2:1b"
}
All RAG examples share these common parameters. **Interactive mode** is available in all examples - simply run without `--query` to start a continuous Q&A session where you can ask multiple questions. Type 'quit' to exit.
```bash
# Core Parameters (General preprocessing for all examples)
--index-dir DIR # Directory to store the index (default: current directory)
--query "YOUR QUESTION" # Single query mode. Omit for interactive chat (type 'quit' to exit), and now you can play with your index interactively
--max-items N # Limit data preprocessing (default: -1, process all data)
--force-rebuild # Force rebuild index even if it exists
# Embedding Parameters
--embedding-model MODEL # e.g., facebook/contriever, text-embedding-3-small, mlx-community/Qwen3-Embedding-0.6B-8bit or nomic-embed-text
--embedding-mode MODE # sentence-transformers, openai, mlx, or ollama
# LLM Parameters (Text generation models)
--llm TYPE # LLM backend: openai, ollama, or hf (default: openai)
--llm-model MODEL # Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct
--thinking-budget LEVEL # Thinking budget for reasoning models: low/medium/high (supported by o3, o3-mini, GPT-Oss:20b, and other reasoning models)
# Search Parameters
--top-k N # Number of results to retrieve (default: 20)
--search-complexity N # Search complexity for graph traversal (default: 32)
# Chunking Parameters
--chunk-size N # Size of text chunks (default varies by source: 256 for most, 192 for WeChat)
--chunk-overlap N # Overlap between chunks (default varies: 25-128 depending on source)
# Index Building Parameters
--backend-name NAME # Backend to use: hnsw or diskann (default: hnsw)
--graph-degree N # Graph degree for index construction (default: 32)
--build-complexity N # Build complexity for index construction (default: 64)
--compact / --no-compact # Use compact storage (default: true). Must be `no-compact` for `no-recompute` build.
--recompute / --no-recompute # Enable/disable embedding recomputation (default: enabled). Should not do a `no-recompute` search in a `recompute` build.
chat = LeannChat(index_path="knowledge.leann", llm_config=llm_config)
response = chat.ask(
"Compare the two retrieved programming languages and say which one is more popular today.",
top_k=2,
)
```
</details>
## RAG on Everything!
### 📄 Personal Data Manager: Process Any Documents (`.pdf`, `.txt`, `.md`)!
LEANN supports RAG on various data sources including documents (.pdf, .txt, .md), Apple Mail, Google Search History, WeChat, and more.
### 📄 Personal Data Manager: Process Any Documents (.pdf, .txt, .md)!
Ask questions directly about your personal PDFs, documents, and any directory containing your files!
@@ -276,71 +137,51 @@ Ask questions directly about your personal PDFs, documents, and any directory co
<img src="videos/paper_clear.gif" alt="LEANN Document Search Demo" width="600">
</p>
The example below asks a question about summarizing our paper (uses default data in `data/`, which is a directory with diverse data sources: two papers, Pride and Prejudice, and a Technical report about LLM in Huawei in Chinese), and this is the **easiest example** to run here:
The example below asks a question about summarizing two papers (uses default data in `examples/data`):
```bash
source .venv/bin/activate # Don't forget to activate the virtual environment
python -m apps.document_rag --query "What are the main techniques LEANN explores?"
# Drop your PDFs, .txt, .md files into examples/data/
uv run ./examples/main_cli_example.py
```
<details>
<summary><strong>📋 Click to expand: Document-Specific Arguments</strong></summary>
#### Parameters
```bash
--data-dir DIR # Directory containing documents to process (default: data)
--file-types .ext .ext # Filter by specific file types (optional - all LlamaIndex supported types if omitted)
```
# Or use python directly
source .venv/bin/activate
python ./examples/main_cli_example.py
```
#### Example Commands
```bash
# Process all documents with larger chunks for academic papers
python -m apps.document_rag --data-dir "~/Documents/Papers" --chunk-size 1024
# Filter only markdown and Python files with smaller chunks
python -m apps.document_rag --data-dir "./docs" --chunk-size 256 --file-types .md .py
# Enable AST-aware chunking for code files
python -m apps.document_rag --enable-code-chunking --data-dir "./my_project"
# Or use the specialized code RAG for better code understanding
python -m apps.code_rag --repo-dir "./my_codebase" --query "How does authentication work?"
```
</details>
### 📧 Your Personal Email Secretary: RAG on Apple Mail!
> **Note:** The examples below currently support macOS only. Windows support coming soon.
<p align="center">
<img src="videos/mail_clear.gif" alt="LEANN Email Search Demo" width="600">
</p>
Before running the example below, you need to grant full disk access to your terminal/VS Code in System Preferences → Privacy & Security → Full Disk Access.
**Note:** You need to grant full disk access to your terminal/VS Code in System Preferences → Privacy & Security → Full Disk Access.
```bash
python -m apps.email_rag --query "What's the food I ordered by DoorDash or Uber Eats mostly?"
python examples/mail_reader_leann.py --query "What's the food I ordered by doordash or Uber eat mostly?"
```
**780K email chunks → 78MB storage.** Finally, search your email like you search Google.
**780K email chunks → 78MB storage** Finally, search your email like you search Google.
<details>
<summary><strong>📋 Click to expand: Email-Specific Arguments</strong></summary>
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
#### Parameters
```bash
--mail-path PATH # Path to specific mail directory (auto-detects if omitted)
--include-html # Include HTML content in processing (useful for newsletters)
```
# Use default mail path (works for most macOS setups)
python examples/mail_reader_leann.py
#### Example Commands
```bash
# Search work emails from a specific account
python -m apps.email_rag --mail-path "~/Library/Mail/V10/WORK_ACCOUNT"
# Run with custom index directory
python examples/mail_reader_leann.py --index-dir "./my_mail_index"
# Find all receipts and order confirmations (includes HTML)
python -m apps.email_rag --query "receipt order confirmation invoice" --include-html
# Process all emails (may take time but indexes everything)
python examples/mail_reader_leann.py --max-emails -1
# Limit number of emails processed (useful for testing)
python examples/mail_reader_leann.py --max-emails 1000
# Run a single query
python examples/mail_reader_leann.py --query "What did my boss say about deadlines?"
```
</details>
@@ -354,32 +195,32 @@ Once the index is built, you can ask questions like:
- "Show me emails about travel expenses"
</details>
### 🔍 Time Machine for the Web: RAG Your Entire Chrome Browser History!
### 🔍 Time Machine for the Web: RAG Your Entire Google Browser History!
<p align="center">
<img src="videos/google_clear.gif" alt="LEANN Browser History Search Demo" width="600">
</p>
```bash
python -m apps.browser_rag --query "Tell me my browser history about machine learning?"
python examples/google_history_reader_leann.py --query "Tell me my browser history about machine learning?"
```
**38K browser entries → 6MB storage.** Your browser history becomes your personal search engine.
<details>
<summary><strong>📋 Click to expand: Browser-Specific Arguments</strong></summary>
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
#### Parameters
```bash
--chrome-profile PATH # Path to Chrome profile directory (auto-detects if omitted)
```
# Use default Chrome profile (auto-finds all profiles)
python examples/google_history_reader_leann.py
#### Example Commands
```bash
# Search academic research from your browsing history
python -m apps.browser_rag --query "arxiv papers machine learning transformer architecture"
# Run with custom index directory
python examples/google_history_reader_leann.py --index-dir "./my_chrome_index"
# Track competitor analysis across work profile
python -m apps.browser_rag --chrome-profile "~/Library/Application Support/Google/Chrome/Work Profile" --max-items 5000
# Limit number of history entries processed (useful for testing)
python examples/google_history_reader_leann.py --max-entries 500
# Run a single query
python examples/google_history_reader_leann.py --query "What websites did I visit about machine learning?"
```
</details>
@@ -419,7 +260,7 @@ Once the index is built, you can ask questions like:
</p>
```bash
python -m apps.wechat_rag --query "Show me all group chats about weekend plans"
python examples/wechat_history_reader_leann.py --query "Show me all group chats about weekend plans"
```
**400K messages → 64MB storage** Search years of chat history in any language.
@@ -427,13 +268,7 @@ python -m apps.wechat_rag --query "Show me all group chats about weekend plans"
<details>
<summary><strong>🔧 Click to expand: Installation Requirements</strong></summary>
First, you need to install the [WeChat exporter](https://github.com/sunnyyoung/WeChatTweak-CLI),
```bash
brew install sunnyyoung/repo/wechattweak-cli
```
or install it manually (if you have issues with Homebrew):
First, you need to install the WeChat exporter:
```bash
sudo packages/wechat-exporter/wechattweak-cli install
@@ -442,28 +277,30 @@ sudo packages/wechat-exporter/wechattweak-cli install
**Troubleshooting:**
- **Installation issues**: Check the [WeChatTweak-CLI issues page](https://github.com/sunnyyoung/WeChatTweak-CLI/issues/41)
- **Export errors**: If you encounter the error below, try restarting WeChat
```bash
Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.
Failed to find or export WeChat data. Exiting.
```
```
Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.
Failed to find or export WeChat data. Exiting.
```
</details>
<details>
<summary><strong>📋 Click to expand: WeChat-Specific Arguments</strong></summary>
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
#### Parameters
```bash
--export-dir DIR # Directory to store exported WeChat data (default: wechat_export_direct)
--force-export # Force re-export even if data exists
```
# Use default settings (recommended for first run)
python examples/wechat_history_reader_leann.py
#### Example Commands
```bash
# Search for travel plans discussed in group chats
python -m apps.wechat_rag --query "travel plans" --max-items 10000
# Run with custom export directory and wehn we run the first time, LEANN will export all chat history automatically for you
python examples/wechat_history_reader_leann.py --export-dir "./my_wechat_exports"
# Re-export and search recent chats (useful after new messages)
python -m apps.wechat_rag --force-export --query "work schedule"
# Run with custom index directory
python examples/wechat_history_reader_leann.py --index-dir "./my_wechat_index"
# Limit number of chat entries processed (useful for testing)
python examples/wechat_history_reader_leann.py --max-entries 1000
# Run a single query
python examples/wechat_history_reader_leann.py --query "Show me conversations about travel plans"
```
</details>
@@ -477,70 +314,17 @@ Once the index is built, you can ask questions like:
</details>
### 🚀 Claude Code Integration: Transform Your Development Workflow!
<details>
<summary><strong>NEW!! ASTAware Code Chunking</strong></summary>
LEANN features intelligent code chunking that preserves semantic boundaries (functions, classes, methods) for Python, Java, C#, and TypeScript, improving code understanding compared to text-based chunking.
📖 Read the [AST Chunking Guide →](docs/ast_chunking_guide.md)
</details>
**The future of code assistance is here.** Transform your development workflow with LEANN's native MCP integration for Claude Code. Index your entire codebase and get intelligent code assistance directly in your IDE.
**Key features:**
- 🔍 **Semantic code search** across your entire project, fully local index and lightweight
- 🧠 **AST-aware chunking** preserves code structure (functions, classes)
- 📚 **Context-aware assistance** for debugging and development
- 🚀 **Zero-config setup** with automatic language detection
```bash
# Install LEANN globally for MCP integration
uv tool install leann-core --with leann
claude mcp add --scope user leann-server -- leann_mcp
# Setup is automatic - just start using Claude Code!
```
Try our fully agentic pipeline with auto query rewriting, semantic search planning, and more:
![LEANN MCP Integration](assets/mcp_leann.png)
**🔥 Ready to supercharge your coding?** [Complete Setup Guide →](packages/leann-mcp/README.md)
## 🖥️ Command Line Interface
LEANN includes a powerful CLI for document processing and search. Perfect for quick document indexing and interactive chat.
### Installation
If you followed the Quick Start, `leann` is already installed in your virtual environment:
```bash
source .venv/bin/activate
leann --help
```
# Build an index from documents
leann build my-docs --docs ./documents
**To make it globally available:**
```bash
# Install the LEANN CLI globally using uv tool
uv tool install leann-core --with leann
# Now you can use leann from anywhere without activating venv
leann --help
```
> **Note**: Global installation is required for Claude Code integration. The `leann_mcp` server depends on the globally available `leann` command.
### Usage Examples
```bash
# build from a specific directory, and my_docs is the index name(Here you can also build from multiple dict or multiple files)
leann build my-docs --docs ./your_documents
# Search your documents
# Search your documents
leann search my-docs "machine learning concepts"
# Interactive chat with your documents
@@ -548,36 +332,30 @@ leann ask my-docs --interactive
# List all your indexes
leann list
# Remove an index
leann remove my-docs
```
**Key CLI features:**
- Auto-detects document formats (PDF, TXT, MD, DOCX, PPTX + code files)
- **🧠 AST-aware chunking** for Python, Java, C#, TypeScript files
- Smart text chunking with overlap for all other content
- Auto-detects document formats (PDF, TXT, MD, DOCX)
- Smart text chunking with overlap
- Multiple LLM providers (Ollama, OpenAI, HuggingFace)
- Organized index storage in `.leann/indexes/` (project-local)
- Organized index storage in `~/.leann/indexes/`
- Support for advanced search parameters
<details>
<summary><strong>📋 Click to expand: Complete CLI Reference</strong></summary>
You can use `leann --help`, or `leann build --help`, `leann search --help`, `leann ask --help`, `leann list --help`, `leann remove --help` to get the complete CLI reference.
**Build Command:**
```bash
leann build INDEX_NAME --docs DIRECTORY|FILE [DIRECTORY|FILE ...] [OPTIONS]
leann build INDEX_NAME --docs DIRECTORY [OPTIONS]
Options:
--backend {hnsw,diskann} Backend to use (default: hnsw)
--embedding-model MODEL Embedding model (default: facebook/contriever)
--graph-degree N Graph degree (default: 32)
--complexity N Build complexity (default: 64)
--force Force rebuild existing index
--compact / --no-compact Use compact storage (default: true). Must be `no-compact` for `no-recompute` build.
--recompute / --no-recompute Enable recomputation (default: true)
--graph-degree N Graph degree (default: 32)
--complexity N Build complexity (default: 64)
--force Force rebuild existing index
--compact Use compact storage (default: true)
--recompute Enable recomputation (default: true)
```
**Search Command:**
@@ -585,9 +363,9 @@ Options:
leann search INDEX_NAME QUERY [OPTIONS]
Options:
--top-k N Number of results (default: 5)
--complexity N Search complexity (default: 64)
--recompute / --no-recompute Enable/disable embedding recomputation (default: enabled). Should not do a `no-recompute` search in a `recompute` build.
--top-k N Number of results (default: 5)
--complexity N Search complexity (default: 64)
--recompute-embeddings Use recomputation for highest accuracy
--pruning-strategy {global,local,proportional}
```
@@ -602,73 +380,8 @@ Options:
--top-k N Retrieval count (default: 20)
```
**List Command:**
```bash
leann list
# Lists all indexes across all projects with status indicators:
# ✅ - Index is complete and ready to use
# ❌ - Index is incomplete or corrupted
# 📁 - CLI-created index (in .leann/indexes/)
# 📄 - App-created index (*.leann.meta.json files)
```
**Remove Command:**
```bash
leann remove INDEX_NAME [OPTIONS]
Options:
--force, -f Force removal without confirmation
# Smart removal: automatically finds and safely removes indexes
# - Shows all matching indexes across projects
# - Requires confirmation for cross-project removal
# - Interactive selection when multiple matches found
# - Supports both CLI and app-created indexes
```
</details>
## 🚀 Advanced Features
### 🎯 Metadata Filtering
LEANN supports a simple metadata filtering system to enable sophisticated use cases like document filtering by date/type, code search by file extension, and content management based on custom criteria.
```python
# Add metadata during indexing
builder.add_text(
"def authenticate_user(token): ...",
metadata={"file_extension": ".py", "lines_of_code": 25}
)
# Search with filters
results = searcher.search(
query="authentication function",
metadata_filters={
"file_extension": {"==": ".py"},
"lines_of_code": {"<": 100}
}
)
```
**Supported operators**: `==`, `!=`, `<`, `<=`, `>`, `>=`, `in`, `not_in`, `contains`, `starts_with`, `ends_with`, `is_true`, `is_false`
📖 **[Complete Metadata filtering guide →](docs/metadata_filtering.md)**
### 🔍 Grep Search
For exact text matching instead of semantic search, use the `use_grep` parameter:
```python
# Exact text search
results = searcher.search("bananacrocodile", use_grep=True, top_k=1)
```
**Use cases**: Finding specific code patterns, error messages, function names, or exact phrases where semantic similarity isn't needed.
📖 **[Complete grep search guide →](docs/grep_search.md)**
## 🏗️ Architecture & How It Works
<p align="center">
@@ -679,21 +392,17 @@ results = searcher.search("bananacrocodile", use_grep=True, top_k=1)
**Core techniques:**
- **Graph-based selective recomputation:** Only compute embeddings for nodes in the search path
- **High-degree preserving pruning:** Keep important "hub" nodes while removing redundant connections
- **High-degree preserving pruning:** Keep important "hub" nodes while removing redundant connections
- **Dynamic batching:** Efficiently batch embedding computations for GPU utilization
- **Two-level search:** Smart graph traversal that prioritizes promising nodes
**Backends:**
- **HNSW** (default): Ideal for most datasets with maximum storage savings through full recomputation
- **DiskANN**: Advanced option with superior search performance, using PQ-based graph traversal with real-time reranking for the best speed-accuracy trade-off
**Backends:** DiskANN or HNSW - pick what works for your data size.
## Benchmarks
**[DiskANN vs HNSW Performance Comparison →](benchmarks/diskann_vs_hnsw_speed_comparison.py)** - Compare search performance between both backends
**[Simple Example: Compare LEANN vs FAISS →](benchmarks/compare_faiss_vs_leann.py)** - See storage savings in action
### 📊 Storage Comparison
📊 **[Simple Example: Compare LEANN vs FAISS →](examples/compare_faiss_vs_leann.py)**
### Storage Comparison
| System | DPR (2.1M) | Wiki (60M) | Chat (400K) | Email (780K) | Browser (38K) |
|--------|-------------|------------|-------------|--------------|---------------|
@@ -707,8 +416,8 @@ results = searcher.search("bananacrocodile", use_grep=True, top_k=1)
```bash
uv pip install -e ".[dev]" # Install dev dependencies
python benchmarks/run_evaluation.py # Will auto-download evaluation data and run benchmarks
python benchmarks/run_evaluation.py benchmarks/data/indices/rpj_wiki/rpj_wiki --num-queries 2000 # After downloading data, you can run the benchmark with our biggest index
python examples/run_evaluation.py data/indices/dpr/dpr_diskann # DPR dataset
python examples/run_evaluation.py data/indices/rpj_wiki/rpj_wiki.index # Wikipedia
```
The evaluation script downloads data automatically on first run. The last three results were tested with partial personal data, and you can reproduce them with your own data!
@@ -720,22 +429,22 @@ If you find Leann useful, please cite:
```bibtex
@misc{wang2025leannlowstoragevectorindex,
title={LEANN: A Low-Storage Vector Index},
title={LEANN: A Low-Storage Vector Index},
author={Yichuan Wang and Shu Liu and Zhifei Li and Yongji Wu and Ziming Mao and Yilong Zhao and Xiao Yan and Zhiying Xu and Yang Zhou and Ion Stoica and Sewon Min and Matei Zaharia and Joseph E. Gonzalez},
year={2025},
eprint={2506.08276},
archivePrefix={arXiv},
primaryClass={cs.DB},
url={https://arxiv.org/abs/2506.08276},
url={https://arxiv.org/abs/2506.08276},
}
```
## ✨ [Detailed Features →](docs/features.md)
## 🤝 [CONTRIBUTING →](docs/CONTRIBUTING.md)
## 🤝 [Contributing →](docs/contributing.md)
## [FAQ →](docs/faq.md)
## [FAQ →](docs/faq.md)
## 📈 [Roadmap →](docs/roadmap.md)
@@ -746,18 +455,9 @@ MIT License - see [LICENSE](LICENSE) for details.
## 🙏 Acknowledgments
Core Contributors: [Yichuan Wang](https://yichuan-w.github.io/) & [Zhifei Li](https://github.com/andylizf).
This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/)
---
Active Contributors: [Gabriel Dehan](https://github.com/gabriel-dehan)
We welcome more contributors! Feel free to open issues or submit PRs.
This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/).
## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=yichuan-w/LEANN&type=Date)](https://www.star-history.com/#yichuan-w/LEANN&Date)
<p align="center">
<strong>⭐ Star us on GitHub if Leann is useful for your research or applications!</strong>
</p>
@@ -765,3 +465,4 @@ This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.ed
<p align="center">
Made with ❤️ by the Leann team
</p>

View File

View File

@@ -1,342 +0,0 @@
"""
Base class for unified RAG examples interface.
Provides common parameters and functionality for all RAG examples.
"""
import argparse
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any
import dotenv
from leann.api import LeannBuilder, LeannChat
from leann.registry import register_project_directory
dotenv.load_dotenv()
class BaseRAGExample(ABC):
"""Base class for all RAG examples with unified interface."""
def __init__(
self,
name: str,
description: str,
default_index_name: str,
):
self.name = name
self.description = description
self.default_index_name = default_index_name
self.parser = self._create_parser()
def _create_parser(self) -> argparse.ArgumentParser:
"""Create argument parser with common parameters."""
parser = argparse.ArgumentParser(
description=self.description, formatter_class=argparse.RawDescriptionHelpFormatter
)
# Core parameters (all examples share these)
core_group = parser.add_argument_group("Core Parameters")
core_group.add_argument(
"--index-dir",
type=str,
default=f"./{self.default_index_name}",
help=f"Directory to store the index (default: ./{self.default_index_name})",
)
core_group.add_argument(
"--query",
type=str,
default=None,
help="Query to run (if not provided, will run in interactive mode)",
)
# Allow subclasses to override default max_items
max_items_default = getattr(self, "max_items_default", -1)
core_group.add_argument(
"--max-items",
type=int,
default=max_items_default,
help="Maximum number of items to process -1 for all, means index all documents, and you should set it to a reasonable number if you have a large dataset and try at the first time)",
)
core_group.add_argument(
"--force-rebuild", action="store_true", help="Force rebuild index even if it exists"
)
# Embedding parameters
embedding_group = parser.add_argument_group("Embedding Parameters")
# Allow subclasses to override default embedding_model
embedding_model_default = getattr(self, "embedding_model_default", "facebook/contriever")
embedding_group.add_argument(
"--embedding-model",
type=str,
default=embedding_model_default,
help=f"Embedding model to use (default: {embedding_model_default}), we provide facebook/contriever, text-embedding-3-small,mlx-community/Qwen3-Embedding-0.6B-8bit or nomic-embed-text",
)
embedding_group.add_argument(
"--embedding-mode",
type=str,
default="sentence-transformers",
choices=["sentence-transformers", "openai", "mlx", "ollama"],
help="Embedding backend mode (default: sentence-transformers), we provide sentence-transformers, openai, mlx, or ollama",
)
# LLM parameters
llm_group = parser.add_argument_group("LLM Parameters")
llm_group.add_argument(
"--llm",
type=str,
default="openai",
choices=["openai", "ollama", "hf", "simulated"],
help="LLM backend: openai, ollama, or hf (default: openai)",
)
llm_group.add_argument(
"--llm-model",
type=str,
default=None,
help="Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct",
)
llm_group.add_argument(
"--llm-host",
type=str,
default="http://localhost:11434",
help="Host for Ollama API (default: http://localhost:11434)",
)
llm_group.add_argument(
"--thinking-budget",
type=str,
choices=["low", "medium", "high"],
default=None,
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
)
# AST Chunking parameters
ast_group = parser.add_argument_group("AST Chunking Parameters")
ast_group.add_argument(
"--use-ast-chunking",
action="store_true",
help="Enable AST-aware chunking for code files (requires astchunk)",
)
ast_group.add_argument(
"--ast-chunk-size",
type=int,
default=512,
help="Maximum characters per AST chunk (default: 512)",
)
ast_group.add_argument(
"--ast-chunk-overlap",
type=int,
default=64,
help="Overlap between AST chunks (default: 64)",
)
ast_group.add_argument(
"--code-file-extensions",
nargs="+",
default=None,
help="Additional code file extensions to process with AST chunking (e.g., .py .java .cs .ts)",
)
ast_group.add_argument(
"--ast-fallback-traditional",
action="store_true",
default=True,
help="Fall back to traditional chunking if AST chunking fails (default: True)",
)
# Search parameters
search_group = parser.add_argument_group("Search Parameters")
search_group.add_argument(
"--top-k", type=int, default=20, help="Number of results to retrieve (default: 20)"
)
search_group.add_argument(
"--search-complexity",
type=int,
default=32,
help="Search complexity for graph traversal (default: 64)",
)
# Index building parameters
index_group = parser.add_argument_group("Index Building Parameters")
index_group.add_argument(
"--backend-name",
type=str,
default="hnsw",
choices=["hnsw", "diskann"],
help="Backend to use for index (default: hnsw)",
)
index_group.add_argument(
"--graph-degree",
type=int,
default=32,
help="Graph degree for index construction (default: 32)",
)
index_group.add_argument(
"--build-complexity",
type=int,
default=64,
help="Build complexity for index construction (default: 64)",
)
index_group.add_argument(
"--no-compact",
action="store_true",
help="Disable compact index storage",
)
index_group.add_argument(
"--no-recompute",
action="store_true",
help="Disable embedding recomputation",
)
# Add source-specific parameters
self._add_specific_arguments(parser)
return parser
@abstractmethod
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
"""Add source-specific arguments. Override in subclasses."""
pass
@abstractmethod
async def load_data(self, args) -> list[str]:
"""Load data from the source. Returns list of text chunks."""
pass
def get_llm_config(self, args) -> dict[str, Any]:
"""Get LLM configuration based on arguments."""
config = {"type": args.llm}
if args.llm == "openai":
config["model"] = args.llm_model or "gpt-4o"
elif args.llm == "ollama":
config["model"] = args.llm_model or "llama3.2:1b"
config["host"] = args.llm_host
elif args.llm == "hf":
config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
elif args.llm == "simulated":
# Simulated LLM doesn't need additional configuration
pass
return config
async def build_index(self, args, texts: list[str]) -> str:
"""Build LEANN index from texts."""
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
print(f"\n[Building Index] Creating {self.name} index...")
print(f"Total text chunks: {len(texts)}")
builder = LeannBuilder(
backend_name=args.backend_name,
embedding_model=args.embedding_model,
embedding_mode=args.embedding_mode,
graph_degree=args.graph_degree,
complexity=args.build_complexity,
is_compact=not args.no_compact,
is_recompute=not args.no_recompute,
num_threads=1, # Force single-threaded mode
)
# Add texts in batches for better progress tracking
batch_size = 1000
for i in range(0, len(texts), batch_size):
batch = texts[i : i + batch_size]
for text in batch:
builder.add_text(text)
print(f"Added {min(i + batch_size, len(texts))}/{len(texts)} texts...")
print("Building index structure...")
builder.build_index(index_path)
print(f"Index saved to: {index_path}")
# Register project directory so leann list can discover this index
# The index is saved as args.index_dir/index_name.leann
# We want to register the current working directory where the app is run
register_project_directory(Path.cwd())
return index_path
async def run_interactive_chat(self, args, index_path: str):
"""Run interactive chat with the index."""
chat = LeannChat(
index_path,
llm_config=self.get_llm_config(args),
system_prompt=f"You are a helpful assistant that answers questions about {self.name} data.",
complexity=args.search_complexity,
)
print(f"\n[Interactive Mode] Chat with your {self.name} data!")
print("Type 'quit' or 'exit' to stop.\n")
while True:
try:
query = input("You: ").strip()
if query.lower() in ["quit", "exit", "q"]:
print("Goodbye!")
break
if not query:
continue
# Prepare LLM kwargs with thinking budget if specified
llm_kwargs = {}
if hasattr(args, "thinking_budget") and args.thinking_budget:
llm_kwargs["thinking_budget"] = args.thinking_budget
response = chat.ask(
query,
top_k=args.top_k,
complexity=args.search_complexity,
llm_kwargs=llm_kwargs,
)
print(f"\nAssistant: {response}\n")
except KeyboardInterrupt:
print("\nGoodbye!")
break
except Exception as e:
print(f"Error: {e}")
async def run_single_query(self, args, index_path: str, query: str):
"""Run a single query against the index."""
chat = LeannChat(
index_path,
llm_config=self.get_llm_config(args),
complexity=args.search_complexity,
)
print(f"\n[Query]: \033[36m{query}\033[0m")
# Prepare LLM kwargs with thinking budget if specified
llm_kwargs = {}
if hasattr(args, "thinking_budget") and args.thinking_budget:
llm_kwargs["thinking_budget"] = args.thinking_budget
response = chat.ask(
query, top_k=args.top_k, complexity=args.search_complexity, llm_kwargs=llm_kwargs
)
print(f"\n[Response]: \033[36m{response}\033[0m")
async def run(self):
"""Main entry point for the example."""
args = self.parser.parse_args()
# Check if index exists
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
index_exists = Path(args.index_dir).exists()
if not index_exists or args.force_rebuild:
# Load data and build index
print(f"\n{'Rebuilding' if index_exists else 'Building'} index...")
texts = await self.load_data(args)
if not texts:
print("No data found to index!")
return
index_path = await self.build_index(args, texts)
else:
print(f"\nUsing existing index in {args.index_dir}")
# Run query or interactive mode
if args.query:
await self.run_single_query(args, index_path, args.query)
else:
await self.run_interactive_chat(args, index_path)

View File

@@ -1,171 +0,0 @@
"""
Browser History RAG example using the unified interface.
Supports Chrome browser history.
"""
import os
import sys
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from base_rag_example import BaseRAGExample
from chunking import create_text_chunks
from .history_data.history import ChromeHistoryReader
class BrowserRAG(BaseRAGExample):
"""RAG example for Chrome browser history."""
def __init__(self):
# Set default values BEFORE calling super().__init__
self.embedding_model_default = (
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
)
super().__init__(
name="Browser History",
description="Process and query Chrome browser history with LEANN",
default_index_name="google_history_index",
)
def _add_specific_arguments(self, parser):
"""Add browser-specific arguments."""
browser_group = parser.add_argument_group("Browser Parameters")
browser_group.add_argument(
"--chrome-profile",
type=str,
default=None,
help="Path to Chrome profile directory (auto-detected if not specified)",
)
browser_group.add_argument(
"--auto-find-profiles",
action="store_true",
default=True,
help="Automatically find all Chrome profiles (default: True)",
)
browser_group.add_argument(
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
)
browser_group.add_argument(
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
)
def _get_chrome_base_path(self) -> Path:
"""Get the base Chrome profile path based on OS."""
if sys.platform == "darwin":
return Path.home() / "Library" / "Application Support" / "Google" / "Chrome"
elif sys.platform.startswith("linux"):
return Path.home() / ".config" / "google-chrome"
elif sys.platform == "win32":
return Path(os.environ["LOCALAPPDATA"]) / "Google" / "Chrome" / "User Data"
else:
raise ValueError(f"Unsupported platform: {sys.platform}")
def _find_chrome_profiles(self) -> list[Path]:
"""Auto-detect all Chrome profiles."""
base_path = self._get_chrome_base_path()
if not base_path.exists():
return []
profiles = []
# Check Default profile
default_profile = base_path / "Default"
if default_profile.exists() and (default_profile / "History").exists():
profiles.append(default_profile)
# Check numbered profiles
for item in base_path.iterdir():
if item.is_dir() and item.name.startswith("Profile "):
if (item / "History").exists():
profiles.append(item)
return profiles
async def load_data(self, args) -> list[str]:
"""Load browser history and convert to text chunks."""
# Determine Chrome profiles
if args.chrome_profile and not args.auto_find_profiles:
profile_dirs = [Path(args.chrome_profile)]
else:
print("Auto-detecting Chrome profiles...")
profile_dirs = self._find_chrome_profiles()
# If specific profile given, filter to just that one
if args.chrome_profile:
profile_path = Path(args.chrome_profile)
profile_dirs = [p for p in profile_dirs if p == profile_path]
if not profile_dirs:
print("No Chrome profiles found!")
print("Please specify --chrome-profile manually")
return []
print(f"Found {len(profile_dirs)} Chrome profiles")
# Create reader
reader = ChromeHistoryReader()
# Process each profile
all_documents = []
total_processed = 0
for i, profile_dir in enumerate(profile_dirs):
print(f"\nProcessing profile {i + 1}/{len(profile_dirs)}: {profile_dir.name}")
try:
# Apply max_items limit per profile
max_per_profile = -1
if args.max_items > 0:
remaining = args.max_items - total_processed
if remaining <= 0:
break
max_per_profile = remaining
# Load history
documents = reader.load_data(
chrome_profile_path=str(profile_dir),
max_count=max_per_profile,
)
if documents:
all_documents.extend(documents)
total_processed += len(documents)
print(f"Processed {len(documents)} history entries from this profile")
except Exception as e:
print(f"Error processing {profile_dir}: {e}")
continue
if not all_documents:
print("No browser history found to process!")
return []
print(f"\nTotal history entries processed: {len(all_documents)}")
# Convert to text chunks
all_texts = create_text_chunks(
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
)
return all_texts
if __name__ == "__main__":
import asyncio
# Example queries for browser history RAG
print("\n🌐 Browser History RAG Example")
print("=" * 50)
print("\nExample queries you can try:")
print("- 'What websites did I visit about machine learning?'")
print("- 'Find my search history about programming'")
print("- 'What YouTube videos did I watch recently?'")
print("- 'Show me websites about travel planning'")
print("\nNote: Make sure Chrome is closed before running\n")
rag = BrowserRAG()
asyncio.run(rag.run())

View File

@@ -1,22 +0,0 @@
"""
Chunking utilities for LEANN RAG applications.
Provides AST-aware and traditional text chunking functionality.
"""
from .utils import (
CODE_EXTENSIONS,
create_ast_chunks,
create_text_chunks,
create_traditional_chunks,
detect_code_files,
get_language_from_extension,
)
__all__ = [
"CODE_EXTENSIONS",
"create_ast_chunks",
"create_text_chunks",
"create_traditional_chunks",
"detect_code_files",
"get_language_from_extension",
]

View File

@@ -1,320 +0,0 @@
"""
Enhanced chunking utilities with AST-aware code chunking support.
Provides unified interface for both traditional and AST-based text chunking.
"""
import logging
from pathlib import Path
from typing import Optional
from llama_index.core.node_parser import SentenceSplitter
logger = logging.getLogger(__name__)
# Code file extensions supported by astchunk
CODE_EXTENSIONS = {
".py": "python",
".java": "java",
".cs": "csharp",
".ts": "typescript",
".tsx": "typescript",
".js": "typescript",
".jsx": "typescript",
}
# Default chunk parameters for different content types
DEFAULT_CHUNK_PARAMS = {
"code": {
"max_chunk_size": 512,
"chunk_overlap": 64,
},
"text": {
"chunk_size": 256,
"chunk_overlap": 128,
},
}
def detect_code_files(documents, code_extensions=None) -> tuple[list, list]:
"""
Separate documents into code files and regular text files.
Args:
documents: List of LlamaIndex Document objects
code_extensions: Dict mapping file extensions to languages (defaults to CODE_EXTENSIONS)
Returns:
Tuple of (code_documents, text_documents)
"""
if code_extensions is None:
code_extensions = CODE_EXTENSIONS
code_docs = []
text_docs = []
for doc in documents:
# Get file path from metadata
file_path = doc.metadata.get("file_path", "")
if not file_path:
# Fallback to file_name
file_path = doc.metadata.get("file_name", "")
if file_path:
file_ext = Path(file_path).suffix.lower()
if file_ext in code_extensions:
# Add language info to metadata
doc.metadata["language"] = code_extensions[file_ext]
doc.metadata["is_code"] = True
code_docs.append(doc)
else:
doc.metadata["is_code"] = False
text_docs.append(doc)
else:
# If no file path, treat as text
doc.metadata["is_code"] = False
text_docs.append(doc)
logger.info(f"Detected {len(code_docs)} code files and {len(text_docs)} text files")
return code_docs, text_docs
def get_language_from_extension(file_path: str) -> Optional[str]:
"""Get the programming language from file extension."""
ext = Path(file_path).suffix.lower()
return CODE_EXTENSIONS.get(ext)
def create_ast_chunks(
documents,
max_chunk_size: int = 512,
chunk_overlap: int = 64,
metadata_template: str = "default",
) -> list[str]:
"""
Create AST-aware chunks from code documents using astchunk.
Args:
documents: List of code documents
max_chunk_size: Maximum characters per chunk
chunk_overlap: Number of AST nodes to overlap between chunks
metadata_template: Template for chunk metadata
Returns:
List of text chunks with preserved code structure
"""
try:
from astchunk import ASTChunkBuilder
except ImportError as e:
logger.error(f"astchunk not available: {e}")
logger.info("Falling back to traditional chunking for code files")
return create_traditional_chunks(documents, max_chunk_size, chunk_overlap)
all_chunks = []
for doc in documents:
# Get language from metadata (set by detect_code_files)
language = doc.metadata.get("language")
if not language:
logger.warning(
"No language detected for document, falling back to traditional chunking"
)
traditional_chunks = create_traditional_chunks([doc], max_chunk_size, chunk_overlap)
all_chunks.extend(traditional_chunks)
continue
try:
# Configure astchunk
configs = {
"max_chunk_size": max_chunk_size,
"language": language,
"metadata_template": metadata_template,
"chunk_overlap": chunk_overlap if chunk_overlap > 0 else 0,
}
# Add repository-level metadata if available
repo_metadata = {
"file_path": doc.metadata.get("file_path", ""),
"file_name": doc.metadata.get("file_name", ""),
"creation_date": doc.metadata.get("creation_date", ""),
"last_modified_date": doc.metadata.get("last_modified_date", ""),
}
configs["repo_level_metadata"] = repo_metadata
# Create chunk builder and process
chunk_builder = ASTChunkBuilder(**configs)
code_content = doc.get_content()
if not code_content or not code_content.strip():
logger.warning("Empty code content, skipping")
continue
chunks = chunk_builder.chunkify(code_content)
# Extract text content from chunks
for chunk in chunks:
if hasattr(chunk, "text"):
chunk_text = chunk.text
elif isinstance(chunk, dict) and "text" in chunk:
chunk_text = chunk["text"]
elif isinstance(chunk, str):
chunk_text = chunk
else:
# Try to convert to string
chunk_text = str(chunk)
if chunk_text and chunk_text.strip():
all_chunks.append(chunk_text.strip())
logger.info(
f"Created {len(chunks)} AST chunks from {language} file: {doc.metadata.get('file_name', 'unknown')}"
)
except Exception as e:
logger.warning(f"AST chunking failed for {language} file: {e}")
logger.info("Falling back to traditional chunking")
traditional_chunks = create_traditional_chunks([doc], max_chunk_size, chunk_overlap)
all_chunks.extend(traditional_chunks)
return all_chunks
def create_traditional_chunks(
documents, chunk_size: int = 256, chunk_overlap: int = 128
) -> list[str]:
"""
Create traditional text chunks using LlamaIndex SentenceSplitter.
Args:
documents: List of documents to chunk
chunk_size: Size of each chunk in characters
chunk_overlap: Overlap between chunks
Returns:
List of text chunks
"""
# Handle invalid chunk_size values
if chunk_size <= 0:
logger.warning(f"Invalid chunk_size={chunk_size}, using default value of 256")
chunk_size = 256
# Ensure chunk_overlap is not negative and not larger than chunk_size
if chunk_overlap < 0:
chunk_overlap = 0
if chunk_overlap >= chunk_size:
chunk_overlap = chunk_size // 2
node_parser = SentenceSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
separator=" ",
paragraph_separator="\n\n",
)
all_texts = []
for doc in documents:
try:
nodes = node_parser.get_nodes_from_documents([doc])
if nodes:
chunk_texts = [node.get_content() for node in nodes]
all_texts.extend(chunk_texts)
logger.debug(f"Created {len(chunk_texts)} traditional chunks from document")
except Exception as e:
logger.error(f"Traditional chunking failed for document: {e}")
# As last resort, add the raw content
content = doc.get_content()
if content and content.strip():
all_texts.append(content.strip())
return all_texts
def create_text_chunks(
documents,
chunk_size: int = 256,
chunk_overlap: int = 128,
use_ast_chunking: bool = False,
ast_chunk_size: int = 512,
ast_chunk_overlap: int = 64,
code_file_extensions: Optional[list[str]] = None,
ast_fallback_traditional: bool = True,
) -> list[str]:
"""
Create text chunks from documents with optional AST support for code files.
Args:
documents: List of LlamaIndex Document objects
chunk_size: Size for traditional text chunks
chunk_overlap: Overlap for traditional text chunks
use_ast_chunking: Whether to use AST chunking for code files
ast_chunk_size: Size for AST chunks
ast_chunk_overlap: Overlap for AST chunks
code_file_extensions: Custom list of code file extensions
ast_fallback_traditional: Fall back to traditional chunking on AST errors
Returns:
List of text chunks
"""
if not documents:
logger.warning("No documents provided for chunking")
return []
# Create a local copy of supported extensions for this function call
local_code_extensions = CODE_EXTENSIONS.copy()
# Update supported extensions if provided
if code_file_extensions:
# Map extensions to languages (simplified mapping)
ext_mapping = {
".py": "python",
".java": "java",
".cs": "c_sharp",
".ts": "typescript",
".tsx": "typescript",
}
for ext in code_file_extensions:
if ext.lower() not in local_code_extensions:
# Try to guess language from extension
if ext.lower() in ext_mapping:
local_code_extensions[ext.lower()] = ext_mapping[ext.lower()]
else:
logger.warning(f"Unsupported extension {ext}, will use traditional chunking")
all_chunks = []
if use_ast_chunking:
# Separate code and text documents using local extensions
code_docs, text_docs = detect_code_files(documents, local_code_extensions)
# Process code files with AST chunking
if code_docs:
logger.info(f"Processing {len(code_docs)} code files with AST chunking")
try:
ast_chunks = create_ast_chunks(
code_docs, max_chunk_size=ast_chunk_size, chunk_overlap=ast_chunk_overlap
)
all_chunks.extend(ast_chunks)
logger.info(f"Created {len(ast_chunks)} AST chunks from code files")
except Exception as e:
logger.error(f"AST chunking failed: {e}")
if ast_fallback_traditional:
logger.info("Falling back to traditional chunking for code files")
traditional_code_chunks = create_traditional_chunks(
code_docs, chunk_size, chunk_overlap
)
all_chunks.extend(traditional_code_chunks)
else:
raise
# Process text files with traditional chunking
if text_docs:
logger.info(f"Processing {len(text_docs)} text files with traditional chunking")
text_chunks = create_traditional_chunks(text_docs, chunk_size, chunk_overlap)
all_chunks.extend(text_chunks)
logger.info(f"Created {len(text_chunks)} traditional chunks from text files")
else:
# Use traditional chunking for all files
logger.info(f"Processing {len(documents)} documents with traditional chunking")
all_chunks = create_traditional_chunks(documents, chunk_size, chunk_overlap)
logger.info(f"Total chunks created: {len(all_chunks)}")
return all_chunks

View File

@@ -1,211 +0,0 @@
"""
Code RAG example using AST-aware chunking for optimal code understanding.
Specialized for code repositories with automatic language detection and
optimized chunking parameters.
"""
import sys
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from base_rag_example import BaseRAGExample
from chunking import CODE_EXTENSIONS, create_text_chunks
from llama_index.core import SimpleDirectoryReader
class CodeRAG(BaseRAGExample):
"""Specialized RAG example for code repositories with AST-aware chunking."""
def __init__(self):
super().__init__(
name="Code",
description="Process and query code repositories with AST-aware chunking",
default_index_name="code_index",
)
# Override defaults for code-specific usage
self.embedding_model_default = "facebook/contriever" # Good for code
self.max_items_default = -1 # Process all code files by default
def _add_specific_arguments(self, parser):
"""Add code-specific arguments."""
code_group = parser.add_argument_group("Code Repository Parameters")
code_group.add_argument(
"--repo-dir",
type=str,
default=".",
help="Code repository directory to index (default: current directory)",
)
code_group.add_argument(
"--include-extensions",
nargs="+",
default=list(CODE_EXTENSIONS.keys()),
help="File extensions to include (default: supported code extensions)",
)
code_group.add_argument(
"--exclude-dirs",
nargs="+",
default=[
".git",
"__pycache__",
"node_modules",
"venv",
".venv",
"build",
"dist",
"target",
],
help="Directories to exclude from indexing",
)
code_group.add_argument(
"--max-file-size",
type=int,
default=1000000, # 1MB
help="Maximum file size in bytes to process (default: 1MB)",
)
code_group.add_argument(
"--include-comments",
action="store_true",
help="Include comments in chunking (useful for documentation)",
)
code_group.add_argument(
"--preserve-imports",
action="store_true",
default=True,
help="Try to preserve import statements in chunks (default: True)",
)
async def load_data(self, args) -> list[str]:
"""Load code files and convert to AST-aware chunks."""
print(f"🔍 Scanning code repository: {args.repo_dir}")
print(f"📁 Including extensions: {args.include_extensions}")
print(f"🚫 Excluding directories: {args.exclude_dirs}")
# Check if repository directory exists
repo_path = Path(args.repo_dir)
if not repo_path.exists():
raise ValueError(f"Repository directory not found: {args.repo_dir}")
# Load code files with filtering
reader_kwargs = {
"recursive": True,
"encoding": "utf-8",
"required_exts": args.include_extensions,
"exclude_hidden": True,
}
# Create exclusion filter
def file_filter(file_path: str) -> bool:
"""Filter out unwanted files and directories."""
path = Path(file_path)
# Check file size
try:
if path.stat().st_size > args.max_file_size:
print(f"⚠️ Skipping large file: {path.name} ({path.stat().st_size} bytes)")
return False
except Exception:
return False
# Check if in excluded directory
for exclude_dir in args.exclude_dirs:
if exclude_dir in path.parts:
return False
return True
try:
# Load documents with file filtering
documents = SimpleDirectoryReader(
args.repo_dir,
file_extractor=None, # Use default extractors
**reader_kwargs,
).load_data(show_progress=True)
# Apply custom filtering
filtered_docs = []
for doc in documents:
file_path = doc.metadata.get("file_path", "")
if file_filter(file_path):
filtered_docs.append(doc)
documents = filtered_docs
except Exception as e:
print(f"❌ Error loading code files: {e}")
return []
if not documents:
print(
f"❌ No code files found in {args.repo_dir} with extensions {args.include_extensions}"
)
return []
print(f"✅ Loaded {len(documents)} code files")
# Show breakdown by language/extension
ext_counts = {}
for doc in documents:
file_path = doc.metadata.get("file_path", "")
if file_path:
ext = Path(file_path).suffix.lower()
ext_counts[ext] = ext_counts.get(ext, 0) + 1
print("📊 Files by extension:")
for ext, count in sorted(ext_counts.items()):
print(f" {ext}: {count} files")
# Use AST-aware chunking by default for code
print(
f"🧠 Using AST-aware chunking (chunk_size: {args.ast_chunk_size}, overlap: {args.ast_chunk_overlap})"
)
all_texts = create_text_chunks(
documents,
chunk_size=256, # Fallback for non-code files
chunk_overlap=64,
use_ast_chunking=True, # Always use AST for code RAG
ast_chunk_size=args.ast_chunk_size,
ast_chunk_overlap=args.ast_chunk_overlap,
code_file_extensions=args.include_extensions,
ast_fallback_traditional=True,
)
# Apply max_items limit if specified
if args.max_items > 0 and len(all_texts) > args.max_items:
print(f"⏳ Limiting to {args.max_items} chunks (from {len(all_texts)})")
all_texts = all_texts[: args.max_items]
print(f"✅ Generated {len(all_texts)} code chunks")
return all_texts
if __name__ == "__main__":
import asyncio
# Example queries for code RAG
print("\n💻 Code RAG Example")
print("=" * 50)
print("\nExample queries you can try:")
print("- 'How does the embedding computation work?'")
print("- 'What are the main classes in this codebase?'")
print("- 'Show me the search implementation'")
print("- 'How is error handling implemented?'")
print("- 'What design patterns are used?'")
print("- 'Explain the chunking logic'")
print("\n🚀 Features:")
print("- ✅ AST-aware chunking preserves code structure")
print("- ✅ Automatic language detection")
print("- ✅ Smart filtering of large files and common excludes")
print("- ✅ Optimized for code understanding")
print("\nUsage examples:")
print(" python -m apps.code_rag --repo-dir ./my_project")
print(
" python -m apps.code_rag --include-extensions .py .js --query 'How does authentication work?'"
)
print("\nOr run without --query for interactive mode\n")
rag = CodeRAG()
asyncio.run(rag.run())

View File

@@ -1,131 +0,0 @@
"""
Document RAG example using the unified interface.
Supports PDF, TXT, MD, and other document formats.
"""
import sys
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from base_rag_example import BaseRAGExample
from chunking import create_text_chunks
from llama_index.core import SimpleDirectoryReader
class DocumentRAG(BaseRAGExample):
"""RAG example for document processing (PDF, TXT, MD, etc.)."""
def __init__(self):
super().__init__(
name="Document",
description="Process and query documents (PDF, TXT, MD, etc.) with LEANN",
default_index_name="test_doc_files",
)
def _add_specific_arguments(self, parser):
"""Add document-specific arguments."""
doc_group = parser.add_argument_group("Document Parameters")
doc_group.add_argument(
"--data-dir",
type=str,
default="data",
help="Directory containing documents to index (default: data)",
)
doc_group.add_argument(
"--file-types",
nargs="+",
default=None,
help="Filter by file types (e.g., .pdf .txt .md). If not specified, all supported types are processed",
)
doc_group.add_argument(
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
)
doc_group.add_argument(
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
)
doc_group.add_argument(
"--enable-code-chunking",
action="store_true",
help="Enable AST-aware chunking for code files in the data directory",
)
async def load_data(self, args) -> list[str]:
"""Load documents and convert to text chunks."""
print(f"Loading documents from: {args.data_dir}")
if args.file_types:
print(f"Filtering by file types: {args.file_types}")
else:
print("Processing all supported file types")
# Check if data directory exists
data_path = Path(args.data_dir)
if not data_path.exists():
raise ValueError(f"Data directory not found: {args.data_dir}")
# Load documents
reader_kwargs = {
"recursive": True,
"encoding": "utf-8",
}
if args.file_types:
reader_kwargs["required_exts"] = args.file_types
documents = SimpleDirectoryReader(args.data_dir, **reader_kwargs).load_data(
show_progress=True
)
if not documents:
print(f"No documents found in {args.data_dir} with extensions {args.file_types}")
return []
print(f"Loaded {len(documents)} documents")
# Determine chunking strategy
use_ast = args.enable_code_chunking or getattr(args, "use_ast_chunking", False)
if use_ast:
print("Using AST-aware chunking for code files")
# Convert to text chunks with optional AST support
all_texts = create_text_chunks(
documents,
chunk_size=args.chunk_size,
chunk_overlap=args.chunk_overlap,
use_ast_chunking=use_ast,
ast_chunk_size=getattr(args, "ast_chunk_size", 512),
ast_chunk_overlap=getattr(args, "ast_chunk_overlap", 64),
code_file_extensions=getattr(args, "code_file_extensions", None),
ast_fallback_traditional=getattr(args, "ast_fallback_traditional", True),
)
# Apply max_items limit if specified
if args.max_items > 0 and len(all_texts) > args.max_items:
print(f"Limiting to {args.max_items} chunks (from {len(all_texts)})")
all_texts = all_texts[: args.max_items]
return all_texts
if __name__ == "__main__":
import asyncio
# Example queries for document RAG
print("\n📄 Document RAG Example")
print("=" * 50)
print("\nExample queries you can try:")
print("- 'What are the main techniques LEANN uses?'")
print("- 'What is the technique DLPM?'")
print("- 'Who does Elizabeth Bennet marry?'")
print(
"- 'What is the problem of developing pan gu model Huawei meets? (盘古大模型开发中遇到什么问题?)'"
)
print("\n🚀 NEW: Code-aware chunking available!")
print("- Use --enable-code-chunking to enable AST-aware chunking for code files")
print("- Supports Python, Java, C#, TypeScript files")
print("- Better semantic understanding of code structure")
print("\nOr run without --query for interactive mode\n")
rag = DocumentRAG()
asyncio.run(rag.run())

View File

@@ -1,167 +0,0 @@
import email
import os
from pathlib import Path
from typing import Any
from llama_index.core import Document
from llama_index.core.readers.base import BaseReader
def find_all_messages_directories(root: str | None = None) -> list[Path]:
"""
Recursively find all 'Messages' directories under the given root.
Returns a list of Path objects.
"""
if root is None:
# Auto-detect user's mail path
home_dir = os.path.expanduser("~")
root = os.path.join(home_dir, "Library", "Mail")
messages_dirs = []
for dirpath, _dirnames, _filenames in os.walk(root):
if os.path.basename(dirpath) == "Messages":
messages_dirs.append(Path(dirpath))
return messages_dirs
class EmlxReader(BaseReader):
"""
Apple Mail .emlx file reader with embedded metadata.
Reads individual .emlx files from Apple Mail's storage format.
"""
def __init__(self, include_html: bool = False) -> None:
"""
Initialize.
Args:
include_html: Whether to include HTML content in the email body (default: False)
"""
self.include_html = include_html
def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
"""
Load data from the input directory containing .emlx files.
Args:
input_dir: Directory containing .emlx files
**load_kwargs:
max_count (int): Maximum amount of messages to read.
"""
docs: list[Document] = []
max_count = load_kwargs.get("max_count", 1000)
count = 0
total_files = 0
successful_files = 0
failed_files = 0
print(f"Starting to process directory: {input_dir}")
# Walk through the directory recursively
for dirpath, dirnames, filenames in os.walk(input_dir):
# Skip hidden directories
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
for filename in filenames:
# Check if we've reached the max count (skip if max_count == -1)
if max_count > 0 and count >= max_count:
break
if filename.endswith(".emlx"):
total_files += 1
filepath = os.path.join(dirpath, filename)
try:
# Read the .emlx file
with open(filepath, encoding="utf-8", errors="ignore") as f:
content = f.read()
# .emlx files have a length prefix followed by the email content
# The first line contains the length, followed by the email
lines = content.split("\n", 1)
if len(lines) >= 2:
email_content = lines[1]
# Parse the email using Python's email module
try:
msg = email.message_from_string(email_content)
# Extract email metadata
subject = msg.get("Subject", "No Subject")
from_addr = msg.get("From", "Unknown")
to_addr = msg.get("To", "Unknown")
date = msg.get("Date", "Unknown")
# Extract email body
body = ""
if msg.is_multipart():
for part in msg.walk():
if (
part.get_content_type() == "text/plain"
or part.get_content_type() == "text/html"
):
if (
part.get_content_type() == "text/html"
and not self.include_html
):
continue
try:
payload = part.get_payload(decode=True)
if payload:
body += payload.decode("utf-8", errors="ignore")
except Exception as e:
print(f"Error decoding payload: {e}")
continue
else:
try:
payload = msg.get_payload(decode=True)
if payload:
body = payload.decode("utf-8", errors="ignore")
except Exception as e:
print(f"Error decoding single part payload: {e}")
body = ""
# Only create document if we have some content
if body.strip() or subject != "No Subject":
# Create document content with metadata embedded in text
doc_content = f"""
[File]: {filename}
[From]: {from_addr}
[To]: {to_addr}
[Subject]: {subject}
[Date]: {date}
[EMAIL BODY Start]:
{body}
"""
# No separate metadata - everything is in the text
doc = Document(text=doc_content, metadata={})
docs.append(doc)
count += 1
successful_files += 1
# Print first few successful files for debugging
if successful_files <= 3:
print(
f"Successfully loaded: {filename} - Subject: {subject[:50]}..."
)
except Exception as e:
failed_files += 1
if failed_files <= 5: # Only print first few errors
print(f"Error parsing email from {filepath}: {e}")
continue
except Exception as e:
failed_files += 1
if failed_files <= 5: # Only print first few errors
print(f"Error reading file {filepath}: {e}")
continue
print("Processing summary:")
print(f" Total .emlx files found: {total_files}")
print(f" Successfully loaded: {successful_files}")
print(f" Failed to load: {failed_files}")
print(f" Final documents: {len(docs)}")
return docs

View File

@@ -1,157 +0,0 @@
"""
Email RAG example using the unified interface.
Supports Apple Mail on macOS.
"""
import sys
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from base_rag_example import BaseRAGExample
from chunking import create_text_chunks
from .email_data.LEANN_email_reader import EmlxReader
class EmailRAG(BaseRAGExample):
"""RAG example for Apple Mail processing."""
def __init__(self):
# Set default values BEFORE calling super().__init__
self.max_items_default = -1 # Process all emails by default
self.embedding_model_default = (
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
)
super().__init__(
name="Email",
description="Process and query Apple Mail emails with LEANN",
default_index_name="mail_index",
)
def _add_specific_arguments(self, parser):
"""Add email-specific arguments."""
email_group = parser.add_argument_group("Email Parameters")
email_group.add_argument(
"--mail-path",
type=str,
default=None,
help="Path to Apple Mail directory (auto-detected if not specified)",
)
email_group.add_argument(
"--include-html", action="store_true", help="Include HTML content in email processing"
)
email_group.add_argument(
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
)
email_group.add_argument(
"--chunk-overlap", type=int, default=25, help="Text chunk overlap (default: 25)"
)
def _find_mail_directories(self) -> list[Path]:
"""Auto-detect all Apple Mail directories."""
mail_base = Path.home() / "Library" / "Mail"
if not mail_base.exists():
return []
# Find all Messages directories
messages_dirs = []
for item in mail_base.rglob("Messages"):
if item.is_dir():
messages_dirs.append(item)
return messages_dirs
async def load_data(self, args) -> list[str]:
"""Load emails and convert to text chunks."""
# Determine mail directories
if args.mail_path:
messages_dirs = [Path(args.mail_path)]
else:
print("Auto-detecting Apple Mail directories...")
messages_dirs = self._find_mail_directories()
if not messages_dirs:
print("No Apple Mail directories found!")
print("Please specify --mail-path manually")
return []
print(f"Found {len(messages_dirs)} mail directories")
# Create reader
reader = EmlxReader(include_html=args.include_html)
# Process each directory
all_documents = []
total_processed = 0
for i, messages_dir in enumerate(messages_dirs):
print(f"\nProcessing directory {i + 1}/{len(messages_dirs)}: {messages_dir}")
try:
# Count emlx files
emlx_files = list(messages_dir.glob("*.emlx"))
print(f"Found {len(emlx_files)} email files")
# Apply max_items limit per directory
max_per_dir = -1 # Default to process all
if args.max_items > 0:
remaining = args.max_items - total_processed
if remaining <= 0:
break
max_per_dir = remaining
# If args.max_items == -1, max_per_dir stays -1 (process all)
# Load emails - fix the parameter passing
documents = reader.load_data(
input_dir=str(messages_dir),
max_count=max_per_dir,
)
if documents:
all_documents.extend(documents)
total_processed += len(documents)
print(f"Processed {len(documents)} emails from this directory")
except Exception as e:
print(f"Error processing {messages_dir}: {e}")
continue
if not all_documents:
print("No emails found to process!")
return []
print(f"\nTotal emails processed: {len(all_documents)}")
print("now starting to split into text chunks ... take some time")
# Convert to text chunks
# Email reader uses chunk_overlap=25 as in original
all_texts = create_text_chunks(
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
)
return all_texts
if __name__ == "__main__":
import asyncio
# Check platform
if sys.platform != "darwin":
print("\n⚠️ Warning: This example is designed for macOS (Apple Mail)")
print(" Windows/Linux support coming soon!\n")
# Example queries for email RAG
print("\n📧 Email RAG Example")
print("=" * 50)
print("\nExample queries you can try:")
print("- 'What did my boss say about deadlines?'")
print("- 'Find emails about travel expenses'")
print("- 'Show me emails from last month about the project'")
print("- 'What food did I order from DoorDash?'")
print("\nNote: You may need to grant Full Disk Access to your terminal\n")
rag = EmailRAG()
asyncio.run(rag.run())

View File

@@ -1,189 +0,0 @@
"""
WeChat History RAG example using the unified interface.
Supports WeChat chat history export and search.
"""
import subprocess
import sys
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from base_rag_example import BaseRAGExample
from .history_data.wechat_history import WeChatHistoryReader
class WeChatRAG(BaseRAGExample):
"""RAG example for WeChat chat history."""
def __init__(self):
# Set default values BEFORE calling super().__init__
self.max_items_default = -1 # Match original default
self.embedding_model_default = (
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
)
super().__init__(
name="WeChat History",
description="Process and query WeChat chat history with LEANN",
default_index_name="wechat_history_magic_test_11Debug_new",
)
def _add_specific_arguments(self, parser):
"""Add WeChat-specific arguments."""
wechat_group = parser.add_argument_group("WeChat Parameters")
wechat_group.add_argument(
"--export-dir",
type=str,
default="./wechat_export",
help="Directory to store WeChat exports (default: ./wechat_export)",
)
wechat_group.add_argument(
"--force-export",
action="store_true",
help="Force re-export of WeChat data even if exports exist",
)
wechat_group.add_argument(
"--chunk-size", type=int, default=192, help="Text chunk size (default: 192)"
)
wechat_group.add_argument(
"--chunk-overlap", type=int, default=64, help="Text chunk overlap (default: 64)"
)
def _export_wechat_data(self, export_dir: Path) -> bool:
"""Export WeChat data using wechattweak-cli."""
print("Exporting WeChat data...")
# Check if WeChat is running
try:
result = subprocess.run(["pgrep", "WeChat"], capture_output=True, text=True)
if result.returncode != 0:
print("WeChat is not running. Please start WeChat first.")
return False
except Exception:
pass # pgrep might not be available on all systems
# Create export directory
export_dir.mkdir(parents=True, exist_ok=True)
# Run export command
cmd = ["packages/wechat-exporter/wechattweak-cli", "export", str(export_dir)]
try:
print(f"Running: {' '.join(cmd)}")
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode == 0:
print("WeChat data exported successfully!")
return True
else:
print(f"Export failed: {result.stderr}")
return False
except FileNotFoundError:
print("\nError: wechattweak-cli not found!")
print("Please install it first:")
print(" sudo packages/wechat-exporter/wechattweak-cli install")
return False
except Exception as e:
print(f"Export error: {e}")
return False
async def load_data(self, args) -> list[str]:
"""Load WeChat history and convert to text chunks."""
# Initialize WeChat reader with export capabilities
reader = WeChatHistoryReader()
# Find existing exports or create new ones using the centralized method
export_dirs = reader.find_or_export_wechat_data(args.export_dir)
if not export_dirs:
print("Failed to find or export WeChat data. Trying to find any existing exports...")
# Try to find any existing exports in common locations
export_dirs = reader.find_wechat_export_dirs()
if not export_dirs:
print("No WeChat data found. Please ensure WeChat exports exist.")
return []
# Load documents from all found export directories
all_documents = []
total_processed = 0
for i, export_dir in enumerate(export_dirs):
print(f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}")
try:
# Apply max_items limit per export
max_per_export = -1
if args.max_items > 0:
remaining = args.max_items - total_processed
if remaining <= 0:
break
max_per_export = remaining
documents = reader.load_data(
wechat_export_dir=str(export_dir),
max_count=max_per_export,
concatenate_messages=True, # Enable message concatenation for better context
)
if documents:
print(f"Loaded {len(documents)} chat documents from {export_dir}")
all_documents.extend(documents)
total_processed += len(documents)
else:
print(f"No documents loaded from {export_dir}")
except Exception as e:
print(f"Error processing {export_dir}: {e}")
continue
if not all_documents:
print("No documents loaded from any source. Exiting.")
return []
print(f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports")
print("now starting to split into text chunks ... take some time")
# Convert to text chunks with contact information
all_texts = []
for doc in all_documents:
# Split the document into chunks
from llama_index.core.node_parser import SentenceSplitter
text_splitter = SentenceSplitter(
chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
)
nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes:
# Add contact information to each chunk
contact_name = doc.metadata.get("contact_name", "Unknown")
text = f"[Contact] means the message is from: {contact_name}\n" + node.get_content()
all_texts.append(text)
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
return all_texts
if __name__ == "__main__":
import asyncio
# Check platform
if sys.platform != "darwin":
print("\n⚠️ Warning: WeChat export is only supported on macOS")
print(" You can still query existing exports on other platforms\n")
# Example queries for WeChat RAG
print("\n💬 WeChat History RAG Example")
print("=" * 50)
print("\nExample queries you can try:")
print("- 'Show me conversations about travel plans'")
print("- 'Find group chats about weekend activities'")
print("- '我想买魔术师约翰逊的球衣,给我一些对应聊天记录?'")
print("- 'What did we discuss about the project last month?'")
print("\nNote: WeChat must be running for export to work\n")
rag = WeChatRAG()
asyncio.run(rag.run())

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 73 KiB

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 224 KiB

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 152 KiB

View File

@@ -1,148 +0,0 @@
import argparse
import os
import time
from pathlib import Path
from leann import LeannBuilder, LeannSearcher
def _meta_exists(index_path: str) -> bool:
p = Path(index_path)
return (p.parent / f"{p.stem}.meta.json").exists()
def ensure_index(index_path: str, backend_name: str, num_docs: int, is_recompute: bool) -> None:
# if _meta_exists(index_path):
# return
kwargs = {}
if backend_name == "hnsw":
kwargs["is_compact"] = is_recompute
builder = LeannBuilder(
backend_name=backend_name,
embedding_model=os.getenv("LEANN_EMBED_MODEL", "facebook/contriever"),
embedding_mode=os.getenv("LEANN_EMBED_MODE", "sentence-transformers"),
graph_degree=32,
complexity=64,
is_recompute=is_recompute,
num_threads=4,
**kwargs,
)
for i in range(num_docs):
builder.add_text(
f"This is a test document number {i}. It contains some repeated text for benchmarking."
)
builder.build_index(index_path)
def _bench_group(
index_path: str,
recompute: bool,
query: str,
repeats: int,
complexity: int = 32,
top_k: int = 10,
) -> float:
# Independent searcher per group; fixed port when recompute
searcher = LeannSearcher(index_path=index_path)
# Warm-up once
_ = searcher.search(
query,
top_k=top_k,
complexity=complexity,
recompute_embeddings=recompute,
)
def _once() -> float:
t0 = time.time()
_ = searcher.search(
query,
top_k=top_k,
complexity=complexity,
recompute_embeddings=recompute,
)
return time.time() - t0
if repeats <= 1:
t = _once()
else:
vals = [_once() for _ in range(repeats)]
vals.sort()
t = vals[len(vals) // 2]
searcher.cleanup()
return t
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--num-docs", type=int, default=5000)
parser.add_argument("--repeats", type=int, default=3)
parser.add_argument("--complexity", type=int, default=32)
args = parser.parse_args()
base = Path.cwd() / ".leann" / "indexes" / f"bench_n{args.num_docs}"
base.parent.mkdir(parents=True, exist_ok=True)
# ---------- Build HNSW variants ----------
hnsw_r = str(base / f"hnsw_recompute_n{args.num_docs}.leann")
hnsw_nr = str(base / f"hnsw_norecompute_n{args.num_docs}.leann")
ensure_index(hnsw_r, "hnsw", args.num_docs, True)
ensure_index(hnsw_nr, "hnsw", args.num_docs, False)
# ---------- Build DiskANN variants ----------
diskann_r = str(base / "diskann_r.leann")
diskann_nr = str(base / "diskann_nr.leann")
ensure_index(diskann_r, "diskann", args.num_docs, True)
ensure_index(diskann_nr, "diskann", args.num_docs, False)
# ---------- Helpers ----------
def _size_for(prefix: str) -> int:
p = Path(prefix)
base_dir = p.parent
stem = p.stem
total = 0
for f in base_dir.iterdir():
if f.is_file() and f.name.startswith(stem):
total += f.stat().st_size
return total
# ---------- HNSW benchmark ----------
t_hnsw_r = _bench_group(
hnsw_r, True, "test document number 42", repeats=args.repeats, complexity=args.complexity
)
t_hnsw_nr = _bench_group(
hnsw_nr, False, "test document number 42", repeats=args.repeats, complexity=args.complexity
)
size_hnsw_r = _size_for(hnsw_r)
size_hnsw_nr = _size_for(hnsw_nr)
print("Benchmark results (HNSW):")
print(f" recompute=True: search_time={t_hnsw_r:.3f}s, size={size_hnsw_r / 1024 / 1024:.1f}MB")
print(
f" recompute=False: search_time={t_hnsw_nr:.3f}s, size={size_hnsw_nr / 1024 / 1024:.1f}MB"
)
print(" Expectation: no-recompute should be faster but larger on disk.")
# ---------- DiskANN benchmark ----------
t_diskann_r = _bench_group(
diskann_r, True, "DiskANN R test doc 123", repeats=args.repeats, complexity=args.complexity
)
t_diskann_nr = _bench_group(
diskann_nr,
False,
"DiskANN NR test doc 123",
repeats=args.repeats,
complexity=args.complexity,
)
size_diskann_r = _size_for(diskann_r)
size_diskann_nr = _size_for(diskann_nr)
print("\nBenchmark results (DiskANN):")
print(f" build(recompute=True, partition): size={size_diskann_r / 1024 / 1024:.1f}MB")
print(f" build(recompute=False): size={size_diskann_nr / 1024 / 1024:.1f}MB")
print(f" search recompute=True (final rerank): {t_diskann_r:.3f}s")
print(f" search recompute=False (PQ only): {t_diskann_nr:.3f}s")
if __name__ == "__main__":
main()

View File

@@ -1,286 +0,0 @@
#!/usr/bin/env python3
"""
DiskANN vs HNSW Search Performance Comparison
This benchmark compares search performance between DiskANN and HNSW backends:
- DiskANN: With graph partitioning enabled (is_recompute=True)
- HNSW: With recompute enabled (is_recompute=True)
- Tests performance across different dataset sizes
- Measures search latency, recall, and index size
"""
import gc
import multiprocessing as mp
import tempfile
import time
from pathlib import Path
from typing import Any
import numpy as np
# Prefer 'fork' start method to avoid POSIX semaphore leaks on macOS
try:
mp.set_start_method("fork", force=True)
except Exception:
pass
def create_test_texts(n_docs: int) -> list[str]:
"""Create synthetic test documents for benchmarking."""
np.random.seed(42)
topics = [
"machine learning and artificial intelligence",
"natural language processing and text analysis",
"computer vision and image recognition",
"data science and statistical analysis",
"deep learning and neural networks",
"information retrieval and search engines",
"database systems and data management",
"software engineering and programming",
"cybersecurity and network protection",
"cloud computing and distributed systems",
]
texts = []
for i in range(n_docs):
topic = topics[i % len(topics)]
variation = np.random.randint(1, 100)
text = (
f"This is document {i} about {topic}. Content variation {variation}. "
f"Additional information about {topic} with details and examples. "
f"Technical discussion of {topic} including implementation aspects."
)
texts.append(text)
return texts
def benchmark_backend(
backend_name: str, texts: list[str], test_queries: list[str], backend_kwargs: dict[str, Any]
) -> dict[str, float]:
"""Benchmark a specific backend with the given configuration."""
from leann.api import LeannBuilder, LeannSearcher
print(f"\n🔧 Testing {backend_name.upper()} backend...")
with tempfile.TemporaryDirectory() as temp_dir:
index_path = str(Path(temp_dir) / f"benchmark_{backend_name}.leann")
# Build index
print(f"📦 Building {backend_name} index with {len(texts)} documents...")
start_time = time.time()
builder = LeannBuilder(
backend_name=backend_name,
embedding_model="facebook/contriever",
embedding_mode="sentence-transformers",
**backend_kwargs,
)
for text in texts:
builder.add_text(text)
builder.build_index(index_path)
build_time = time.time() - start_time
# Measure index size
index_dir = Path(index_path).parent
index_files = list(index_dir.glob(f"{Path(index_path).stem}.*"))
total_size = sum(f.stat().st_size for f in index_files if f.is_file())
size_mb = total_size / (1024 * 1024)
print(f" ✅ Build completed in {build_time:.2f}s, index size: {size_mb:.1f}MB")
# Search benchmark
print("🔍 Running search benchmark...")
searcher = LeannSearcher(index_path)
search_times = []
all_results = []
for query in test_queries:
start_time = time.time()
results = searcher.search(query, top_k=5)
search_time = time.time() - start_time
search_times.append(search_time)
all_results.append(results)
avg_search_time = np.mean(search_times) * 1000 # Convert to ms
print(f" ✅ Average search time: {avg_search_time:.1f}ms")
# Check for valid scores (detect -inf issues)
all_scores = [
result.score
for results in all_results
for result in results
if result.score is not None
]
valid_scores = [
score for score in all_scores if score != float("-inf") and score != float("inf")
]
score_validity_rate = len(valid_scores) / len(all_scores) if all_scores else 0
# Clean up (ensure embedding server shutdown and object GC)
try:
if hasattr(searcher, "cleanup"):
searcher.cleanup()
del searcher
del builder
gc.collect()
except Exception as e:
print(f"⚠️ Warning: Resource cleanup error: {e}")
return {
"build_time": build_time,
"avg_search_time_ms": avg_search_time,
"index_size_mb": size_mb,
"score_validity_rate": score_validity_rate,
}
def run_comparison(n_docs: int = 500, n_queries: int = 10):
"""Run performance comparison between DiskANN and HNSW."""
print("🚀 Starting DiskANN vs HNSW Performance Comparison")
print(f"📊 Dataset: {n_docs} documents, {n_queries} test queries")
# Create test data
texts = create_test_texts(n_docs)
test_queries = [
"machine learning algorithms",
"natural language processing",
"computer vision techniques",
"data analysis methods",
"neural network architectures",
"database query optimization",
"software development practices",
"security vulnerabilities",
"cloud infrastructure",
"distributed computing",
][:n_queries]
# HNSW benchmark
hnsw_results = benchmark_backend(
backend_name="hnsw",
texts=texts,
test_queries=test_queries,
backend_kwargs={
"is_recompute": True, # Enable recompute for fair comparison
"M": 16,
"efConstruction": 200,
},
)
# DiskANN benchmark
diskann_results = benchmark_backend(
backend_name="diskann",
texts=texts,
test_queries=test_queries,
backend_kwargs={
"is_recompute": True, # Enable graph partitioning
"num_neighbors": 32,
"search_list_size": 50,
},
)
# Performance comparison
print("\n📈 Performance Comparison Results")
print(f"{'=' * 60}")
print(f"{'Metric':<25} {'HNSW':<15} {'DiskANN':<15} {'Speedup':<10}")
print(f"{'-' * 60}")
# Build time comparison
build_speedup = hnsw_results["build_time"] / diskann_results["build_time"]
print(
f"{'Build Time (s)':<25} {hnsw_results['build_time']:<15.2f} {diskann_results['build_time']:<15.2f} {build_speedup:<10.2f}x"
)
# Search time comparison
search_speedup = hnsw_results["avg_search_time_ms"] / diskann_results["avg_search_time_ms"]
print(
f"{'Search Time (ms)':<25} {hnsw_results['avg_search_time_ms']:<15.1f} {diskann_results['avg_search_time_ms']:<15.1f} {search_speedup:<10.2f}x"
)
# Index size comparison
size_ratio = diskann_results["index_size_mb"] / hnsw_results["index_size_mb"]
print(
f"{'Index Size (MB)':<25} {hnsw_results['index_size_mb']:<15.1f} {diskann_results['index_size_mb']:<15.1f} {size_ratio:<10.2f}x"
)
# Score validity
print(
f"{'Score Validity (%)':<25} {hnsw_results['score_validity_rate'] * 100:<15.1f} {diskann_results['score_validity_rate'] * 100:<15.1f}"
)
print(f"{'=' * 60}")
print("\n🎯 Summary:")
if search_speedup > 1:
print(f" DiskANN is {search_speedup:.2f}x faster than HNSW for search")
else:
print(f" HNSW is {1 / search_speedup:.2f}x faster than DiskANN for search")
if size_ratio > 1:
print(f" DiskANN uses {size_ratio:.2f}x more storage than HNSW")
else:
print(f" DiskANN uses {1 / size_ratio:.2f}x less storage than HNSW")
print(
f" Both backends achieved {min(hnsw_results['score_validity_rate'], diskann_results['score_validity_rate']) * 100:.1f}% score validity"
)
if __name__ == "__main__":
import sys
try:
# Handle help request
if len(sys.argv) > 1 and sys.argv[1] in ["-h", "--help", "help"]:
print("DiskANN vs HNSW Performance Comparison")
print("=" * 50)
print(f"Usage: python {sys.argv[0]} [n_docs] [n_queries]")
print()
print("Arguments:")
print(" n_docs Number of documents to index (default: 500)")
print(" n_queries Number of test queries to run (default: 10)")
print()
print("Examples:")
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py")
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 1000")
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20")
sys.exit(0)
# Parse command line arguments
n_docs = int(sys.argv[1]) if len(sys.argv) > 1 else 500
n_queries = int(sys.argv[2]) if len(sys.argv) > 2 else 10
print("DiskANN vs HNSW Performance Comparison")
print("=" * 50)
print(f"Dataset: {n_docs} documents, {n_queries} queries")
print()
run_comparison(n_docs=n_docs, n_queries=n_queries)
except KeyboardInterrupt:
print("\n⚠️ Benchmark interrupted by user")
sys.exit(130)
except Exception as e:
print(f"\n❌ Benchmark failed: {e}")
sys.exit(1)
finally:
# Ensure clean exit (forceful to prevent rare hangs from atexit/threads)
try:
gc.collect()
print("\n🧹 Cleanup completed")
# Flush stdio to ensure message is visible before hard-exit
try:
import sys as _sys
_sys.stdout.flush()
_sys.stderr.flush()
except Exception:
pass
except Exception:
pass
# Use os._exit to bypass atexit handlers that may hang in rare cases
import os as _os
_os._exit(0)

82
data/.gitattributes vendored Normal file
View File

@@ -0,0 +1,82 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.lz4 filter=lfs diff=lfs merge=lfs -text
*.mds filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
# Audio files - uncompressed
*.pcm filter=lfs diff=lfs merge=lfs -text
*.sam filter=lfs diff=lfs merge=lfs -text
*.raw filter=lfs diff=lfs merge=lfs -text
# Audio files - compressed
*.aac filter=lfs diff=lfs merge=lfs -text
*.flac filter=lfs diff=lfs merge=lfs -text
*.mp3 filter=lfs diff=lfs merge=lfs -text
*.ogg filter=lfs diff=lfs merge=lfs -text
*.wav filter=lfs diff=lfs merge=lfs -text
# Image files - uncompressed
*.bmp filter=lfs diff=lfs merge=lfs -text
*.gif filter=lfs diff=lfs merge=lfs -text
*.png filter=lfs diff=lfs merge=lfs -text
*.tiff filter=lfs diff=lfs merge=lfs -text
# Image files - compressed
*.jpg filter=lfs diff=lfs merge=lfs -text
*.jpeg filter=lfs diff=lfs merge=lfs -text
*.webp filter=lfs diff=lfs merge=lfs -text
# Video files - compressed
*.mp4 filter=lfs diff=lfs merge=lfs -text
*.webm filter=lfs diff=lfs merge=lfs -text
ground_truth/dpr/id_map.json filter=lfs diff=lfs merge=lfs -text
indices/dpr/dpr_diskann.passages.idx filter=lfs diff=lfs merge=lfs -text
indices/dpr/dpr_diskann.passages.jsonl filter=lfs diff=lfs merge=lfs -text
indices/dpr/dpr_diskann_disk.index filter=lfs diff=lfs merge=lfs -text
indices/dpr/leann.labels.map filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/leann.labels.map filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.index filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.0.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.0.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.1.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.1.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.2.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.2.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.3.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.3.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.4.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.4.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.5.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.5.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.6.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.6.jsonl filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.7.idx filter=lfs diff=lfs merge=lfs -text
indices/rpj_wiki/rpj_wiki.passages.7.jsonl filter=lfs diff=lfs merge=lfs -text

0
benchmarks/data/README.md → data/README.md Executable file → Normal file
View File

View File

@@ -4,11 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Quick Start \n",
"\n",
"**Home GitHub Repository:** [LEANN on GitHub](https://github.com/yichuan-w/LEANN)\n",
"\n",
"**Important for Colab users:** Set your runtime type to T4 GPU for optimal performance. Go to Runtime → Change runtime type → Hardware accelerator → T4 GPU."
"# Quick Start in 30s"
]
},
{
@@ -17,25 +13,8 @@
"metadata": {},
"outputs": [],
"source": [
"# install this if you are using colab\n",
"! uv pip install leann-core leann-backend-hnsw --no-deps\n",
"! uv pip install leann --no-deps\n",
"# For Colab environment, we need to set some environment variables\n",
"import os\n",
"\n",
"os.environ[\"LEANN_LOG_LEVEL\"] = \"INFO\" # Enable more detailed logging"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"INDEX_DIR = Path(\"./\").resolve()\n",
"INDEX_PATH = str(INDEX_DIR / \"demo.leann\")"
"# install this if you areusing colab\n",
"! pip install leann"
]
},
{
@@ -47,21 +26,91 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO: Registering backend 'hnsw'\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/yichuan/Desktop/code/LEANN/leann/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: facebook/contriever\n",
"WARNING:sentence_transformers.SentenceTransformer:No sentence-transformers model found with name facebook/contriever. Creating a new one with mean pooling.\n",
"Writing passages: 100%|██████████| 5/5 [00:00<00:00, 27887.66chunk/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 13.51it/s]\n",
"WARNING:leann_backend_hnsw.hnsw_backend:Converting data to float32, shape: (5, 768)\n",
"INFO:leann_backend_hnsw.hnsw_backend:INFO: Converting HNSW index to CSR-pruned format...\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"M: 64 for level: 0\n",
"Starting conversion: knowledge.index -> knowledge.csr.tmp\n",
"[0.00s] Reading Index HNSW header...\n",
"[0.00s] Header read: d=768, ntotal=5\n",
"[0.00s] Reading HNSW struct vectors...\n",
" Reading vector (dtype=<class 'numpy.float64'>, fmt='d')... Count=6, Bytes=48\n",
"[0.00s] Read assign_probas (6)\n",
" Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=7, Bytes=28\n",
"[0.11s] Read cum_nneighbor_per_level (7)\n",
" Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=5, Bytes=20\n",
"[0.21s] Read levels (5)\n",
"[0.30s] Probing for compact storage flag...\n",
"[0.30s] Found compact flag: False\n",
"[0.30s] Compact flag is False, reading original format...\n",
"[0.30s] Probing for potential extra byte before non-compact offsets...\n",
"[0.30s] Found and consumed an unexpected 0x00 byte.\n",
" Reading vector (dtype=<class 'numpy.uint64'>, fmt='Q')... Count=6, Bytes=48\n",
"[0.30s] Read offsets (6)\n",
"[0.40s] Attempting to read neighbors vector...\n",
" Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=320, Bytes=1280\n",
"[0.40s] Read neighbors (320)\n",
"[0.50s] Read scalar params (ep=4, max_lvl=0)\n",
"[0.50s] Checking for storage data...\n",
"[0.50s] Found storage fourcc: 49467849.\n",
"[0.50s] Converting to CSR format...\n",
"[0.50s] Conversion loop finished. \n",
"[0.50s] Running validation checks...\n",
" Checking total valid neighbor count...\n",
" OK: Total valid neighbors = 20\n",
" Checking final pointer indices...\n",
" OK: Final pointers match data size.\n",
"[0.50s] Deleting original neighbors and offsets arrays...\n",
" CSR Stats: |data|=20, |level_ptr|=10\n",
"[0.59s] Writing CSR HNSW graph data in FAISS-compatible order...\n",
" Pruning embeddings: Writing NULL storage marker.\n",
"[0.69s] Conversion complete.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:leann_backend_hnsw.hnsw_backend:✅ CSR conversion successful.\n",
"INFO:leann_backend_hnsw.hnsw_backend:INFO: Replaced original index with CSR-pruned version at 'knowledge.index'\n"
]
}
],
"source": [
"from leann.api import LeannBuilder\n",
"\n",
"builder = LeannBuilder(backend_name=\"hnsw\")\n",
"builder.add_text(\"C# is a powerful programming language and it is good at game development\")\n",
"builder.add_text(\n",
" \"Python is a powerful programming language and it is good at machine learning tasks\"\n",
")\n",
"builder.add_text(\"Python is a powerful programming language and it is good at machine learning tasks\")\n",
"builder.add_text(\"Machine learning transforms industries\")\n",
"builder.add_text(\"Neural networks process complex data\")\n",
"builder.add_text(\"Leann is a great storage saving engine for RAG on your MacBook\")\n",
"builder.build_index(INDEX_PATH)"
"builder.build_index(\"knowledge.leann\")"
]
},
{
@@ -73,13 +122,97 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:leann.api:🔍 LeannSearcher.search() called:\n",
"INFO:leann.api: Query: 'programming languages'\n",
"INFO:leann.api: Top_k: 2\n",
"INFO:leann.api: Additional kwargs: {}\n",
"INFO:leann.embedding_server_manager:Port 5557 has incompatible server, trying next port...\n",
"INFO:leann.embedding_server_manager:Port 5558 has incompatible server, trying next port...\n",
"INFO:leann.embedding_server_manager:Port 5559 has incompatible server, trying next port...\n",
"INFO:leann.embedding_server_manager:Using port 5560 instead of 5557\n",
"INFO:leann.embedding_server_manager:Starting embedding server on port 5560...\n",
"INFO:leann.embedding_server_manager:Command: /Users/yichuan/Desktop/code/LEANN/leann/.venv/bin/python -m leann_backend_hnsw.hnsw_embedding_server --zmq-port 5560 --model-name facebook/contriever --passages-file knowledge.leann.meta.json\n",
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
"INFO:leann.embedding_server_manager:Server process started with PID: 4574\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[read_HNSW - CSR NL v4] Reading metadata & CSR indices (manual offset)...\n",
"[read_HNSW NL v4] Read levels vector, size: 5\n",
"[read_HNSW NL v4] Reading Compact Storage format indices...\n",
"[read_HNSW NL v4] Read compact_level_ptr, size: 10\n",
"[read_HNSW NL v4] Read compact_node_offsets, size: 6\n",
"[read_HNSW NL v4] Read entry_point: 4, max_level: 0\n",
"[read_HNSW NL v4] Read storage fourcc: 0x6c6c756e\n",
"[read_HNSW NL v4 FIX] Detected FileIOReader. Neighbors size field offset: 326\n",
"[read_HNSW NL v4] Reading neighbors data into memory.\n",
"[read_HNSW NL v4] Read neighbors data, size: 20\n",
"[read_HNSW NL v4] Finished reading metadata and CSR indices.\n",
"INFO: Skipping external storage loading, since is_recompute is true.\n",
"INFO: Registering backend 'hnsw'\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:leann.embedding_server_manager:Embedding server is ready!\n",
"INFO:leann.api: Launching server time: 1.078078269958496 seconds\n",
"INFO:leann.embedding_server_manager:Existing server process (PID 4574) is compatible\n",
"INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: facebook/contriever\n",
"WARNING:sentence_transformers.SentenceTransformer:No sentence-transformers model found with name facebook/contriever. Creating a new one with mean pooling.\n",
"INFO:leann.api: Generated embedding shape: (1, 768)\n",
"INFO:leann.api: Embedding time: 2.9307072162628174 seconds\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"ZmqDistanceComputer initialized: d=768, metric=0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:leann.api: Search time: 0.27327895164489746 seconds\n",
"INFO:leann.api: Backend returned: labels=2 results\n",
"INFO:leann.api: Processing 2 passage IDs:\n",
"INFO:leann.api: 1. passage_id='0' -> SUCCESS: C# is a powerful programming language and it is good at game development...\n",
"INFO:leann.api: 2. passage_id='1' -> SUCCESS: Python is a powerful programming language and it is good at machine learning tasks...\n",
"INFO:leann.api: Final enriched results: 2 passages\n"
]
},
{
"data": {
"text/plain": [
"[SearchResult(id='0', score=np.float32(0.9874103), text='C# is a powerful programming language and it is good at game development', metadata={}),\n",
" SearchResult(id='1', score=np.float32(0.8922168), text='Python is a powerful programming language and it is good at machine learning tasks', metadata={})]"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from leann.api import LeannSearcher\n",
"\n",
"searcher = LeannSearcher(INDEX_PATH)\n",
"searcher = LeannSearcher(\"knowledge.leann\")\n",
"results = searcher.search(\"programming languages\", top_k=2)\n",
"results"
]
@@ -95,7 +228,79 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:leann.chat:Attempting to create LLM of type='hf' with model='Qwen/Qwen3-0.6B'\n",
"INFO:leann.chat:Initializing HFChat with model='Qwen/Qwen3-0.6B'\n",
"INFO:leann.chat:MPS is available. Using Apple Silicon GPU.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[read_HNSW - CSR NL v4] Reading metadata & CSR indices (manual offset)...\n",
"[read_HNSW NL v4] Read levels vector, size: 5\n",
"[read_HNSW NL v4] Reading Compact Storage format indices...\n",
"[read_HNSW NL v4] Read compact_level_ptr, size: 10\n",
"[read_HNSW NL v4] Read compact_node_offsets, size: 6\n",
"[read_HNSW NL v4] Read entry_point: 4, max_level: 0\n",
"[read_HNSW NL v4] Read storage fourcc: 0x6c6c756e\n",
"[read_HNSW NL v4 FIX] Detected FileIOReader. Neighbors size field offset: 326\n",
"[read_HNSW NL v4] Reading neighbors data into memory.\n",
"[read_HNSW NL v4] Read neighbors data, size: 20\n",
"[read_HNSW NL v4] Finished reading metadata and CSR indices.\n",
"INFO: Skipping external storage loading, since is_recompute is true.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:leann.api:🔍 LeannSearcher.search() called:\n",
"INFO:leann.api: Query: 'Compare the two retrieved programming languages and tell me their advantages.'\n",
"INFO:leann.api: Top_k: 2\n",
"INFO:leann.api: Additional kwargs: {}\n",
"INFO:leann.embedding_server_manager:Port 5557 has incompatible server, trying next port...\n",
"INFO:leann.embedding_server_manager:Port 5558 has incompatible server, trying next port...\n",
"INFO:leann.embedding_server_manager:Port 5559 has incompatible server, trying next port...\n",
"INFO:leann.embedding_server_manager:Found compatible server on port 5560\n",
"INFO:leann.embedding_server_manager:Using existing compatible server on port 5560\n",
"INFO:leann.api: Launching server time: 0.04932403564453125 seconds\n",
"INFO:leann.embedding_server_manager:Found compatible server on port 5560\n",
"INFO:leann.embedding_server_manager:Using existing compatible server on port 5560\n",
"INFO:leann.api: Generated embedding shape: (1, 768)\n",
"INFO:leann.api: Embedding time: 0.06902289390563965 seconds\n",
"INFO:leann.api: Search time: 0.026793241500854492 seconds\n",
"INFO:leann.api: Backend returned: labels=2 results\n",
"INFO:leann.api: Processing 2 passage IDs:\n",
"INFO:leann.api: 1. passage_id='0' -> SUCCESS: C# is a powerful programming language and it is good at game development...\n",
"INFO:leann.api: 2. passage_id='1' -> SUCCESS: Python is a powerful programming language and it is good at machine learning tasks...\n",
"INFO:leann.api: Final enriched results: 2 passages\n",
"INFO:leann.chat:Generating with HuggingFace model, config: {'max_new_tokens': 128, 'temperature': 0.7, 'top_p': 0.9, 'do_sample': True, 'pad_token_id': 151645, 'eos_token_id': 151645}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"ZmqDistanceComputer initialized: d=768, metric=0\n"
]
},
{
"data": {
"text/plain": [
"\"<think>\\n\\n</think>\\n\\nBased on the context provided, here's a comparison of the two retrieved programming languages:\\n\\n**C#** is known for being a powerful programming language and is well-suited for game development. It is often used in game development and is popular among developers working on Windows applications.\\n\\n**Python**, on the other hand, is also a powerful language and is well-suited for machine learning tasks. It is widely used for data analysis, scientific computing, and other applications that require handling large datasets or performing complex calculations.\\n\\n**Advantages**:\\n- C#: Strong for game development and cross-platform compatibility.\\n- Python: Strong for\""
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from leann.api import LeannChat\n",
"\n",
@@ -104,11 +309,11 @@
" \"model\": \"Qwen/Qwen3-0.6B\",\n",
"}\n",
"\n",
"chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)\n",
"chat = LeannChat(index_path=\"knowledge.leann\", llm_config=llm_config)\n",
"response = chat.ask(\n",
" \"Compare the two retrieved programming languages and tell me their advantages.\",\n",
" top_k=2,\n",
" llm_kwargs={\"max_tokens\": 128},\n",
" llm_kwargs={\"max_tokens\": 128}\n",
")\n",
"response"
]

View File

@@ -1,220 +0,0 @@
# 🤝 Contributing
We welcome contributions! Leann is built by the community, for the community.
## Ways to Contribute
- 🐛 **Bug Reports**: Found an issue? Let us know!
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
- 🔧 **Code Contributions**: PRs welcome for all skill levels
- 📖 **Documentation**: Help make Leann more accessible
- 🧪 **Benchmarks**: Share your performance results
## 🚀 Development Setup
### Prerequisites
1. **Install uv** (fast Python package installer):
```bash
curl -LsSf https://astral.sh/uv/install.sh | sh
```
2. **Clone the repository**:
```bash
git clone https://github.com/LEANN-RAG/LEANN-RAG.git
cd LEANN-RAG
```
3. **Install system dependencies**:
**macOS:**
```bash
brew install llvm libomp boost protobuf zeromq pkgconf
```
**Ubuntu/Debian:**
```bash
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler \
libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
```
4. **Build from source**:
```bash
# macOS
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync
# Ubuntu/Debian
uv sync
```
## 🔨 Pre-commit Hooks
We use pre-commit hooks to ensure code quality and consistency. This runs automatically before each commit.
### Setup Pre-commit
1. **Install pre-commit** (already included when you run `uv sync`):
```bash
uv pip install pre-commit
```
2. **Install the git hooks**:
```bash
pre-commit install
```
3. **Run pre-commit manually** (optional):
```bash
pre-commit run --all-files
```
### Pre-commit Checks
Our pre-commit configuration includes:
- **Trailing whitespace removal**
- **End-of-file fixing**
- **YAML validation**
- **Large file prevention**
- **Merge conflict detection**
- **Debug statement detection**
- **Code formatting with ruff**
- **Code linting with ruff**
## 🧪 Testing
### Running Tests
```bash
# Run all tests
uv run pytest
# Run specific test file
uv run pytest test/test_filename.py
# Run with coverage
uv run pytest --cov=leann
```
### Writing Tests
- Place tests in the `test/` directory
- Follow the naming convention `test_*.py`
- Use descriptive test names that explain what's being tested
- Include both positive and negative test cases
## 📝 Code Style
We use `ruff` for both linting and formatting to ensure consistent code style.
### Format Your Code
```bash
# Format all files
ruff format
# Check formatting without changing files
ruff format --check
```
### Lint Your Code
```bash
# Run linter with auto-fix
ruff check --fix
# Just check without fixing
ruff check
```
### Style Guidelines
- Follow PEP 8 conventions
- Use descriptive variable names
- Add type hints where appropriate
- Write docstrings for all public functions and classes
- Keep functions focused and single-purpose
## 🚦 CI/CD
Our CI pipeline runs automatically on all pull requests. It includes:
1. **Linting and Formatting**: Ensures code follows our style guidelines
2. **Multi-platform builds**: Tests on Ubuntu and macOS
3. **Python version matrix**: Tests on Python 3.9-3.13
4. **Wheel building**: Ensures packages can be built and distributed
### CI Commands
The CI uses the same commands as pre-commit to ensure consistency:
```bash
# Linting
ruff check .
# Format checking
ruff format --check .
```
Make sure your code passes these checks locally before pushing!
## 🔄 Pull Request Process
1. **Fork the repository** and create your branch from `main`:
```bash
git checkout -b feature/your-feature-name
```
2. **Make your changes**:
- Write clean, documented code
- Add tests for new functionality
- Update documentation as needed
3. **Run pre-commit checks**:
```bash
pre-commit run --all-files
```
4. **Test your changes**:
```bash
uv run pytest
```
5. **Commit with descriptive messages**:
```bash
git commit -m "feat: add new search algorithm"
```
Follow [Conventional Commits](https://www.conventionalcommits.org/):
- `feat:` for new features
- `fix:` for bug fixes
- `docs:` for documentation changes
- `test:` for test additions/changes
- `refactor:` for code refactoring
- `perf:` for performance improvements
6. **Push and create a pull request**:
- Provide a clear description of your changes
- Reference any related issues
- Include examples or screenshots if applicable
## 📚 Documentation
When adding new features or making significant changes:
1. Update relevant documentation in `/docs`
2. Add docstrings to new functions/classes
3. Update README.md if needed
4. Include usage examples
## 🤔 Getting Help
- **Discord**: Join our community for discussions
- **Issues**: Check existing issues or create a new one
- **Discussions**: For general questions and ideas
## 📄 License
By contributing, you agree that your contributions will be licensed under the same license as the project (MIT).
---
Thank you for contributing to LEANN! Every contribution, no matter how small, helps make the project better for everyone. 🌟

View File

@@ -19,4 +19,4 @@ That's it! The workflow will automatically:
- ✅ Publish to PyPI
- ✅ Create GitHub tag and release
Check progress: https://github.com/yichuan-w/LEANN/actions
Check progress: https://github.com/yichuan-w/LEANN/actions

View File

@@ -1,123 +0,0 @@
# Thinking Budget Feature Implementation
## Overview
This document describes the implementation of the **thinking budget** feature for LEANN, which allows users to control the computational effort for reasoning models like GPT-Oss:20b.
## Feature Description
The thinking budget feature provides three levels of computational effort for reasoning models:
- **`low`**: Fast responses, basic reasoning (default for simple queries)
- **`medium`**: Balanced speed and reasoning depth
- **`high`**: Maximum reasoning effort, best for complex analytical questions
## Implementation Details
### 1. Command Line Interface
Added `--thinking-budget` parameter to both CLI and RAG examples:
```bash
# LEANN CLI
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget high
# RAG Examples
python apps/email_rag.py --llm ollama --llm-model gpt-oss:20b --thinking-budget high
python apps/document_rag.py --llm openai --llm-model o3 --thinking-budget medium
```
### 2. LLM Backend Support
#### Ollama Backend (`packages/leann-core/src/leann/chat.py`)
```python
def ask(self, prompt: str, **kwargs) -> str:
# Handle thinking budget for reasoning models
options = kwargs.copy()
thinking_budget = kwargs.get("thinking_budget")
if thinking_budget:
options.pop("thinking_budget", None)
if thinking_budget in ["low", "medium", "high"]:
options["reasoning"] = {"effort": thinking_budget, "exclude": False}
```
**API Format**: Uses Ollama's `reasoning` parameter with `effort` and `exclude` fields.
#### OpenAI Backend (`packages/leann-core/src/leann/chat.py`)
```python
def ask(self, prompt: str, **kwargs) -> str:
# Handle thinking budget for reasoning models
thinking_budget = kwargs.get("thinking_budget")
if thinking_budget and thinking_budget in ["low", "medium", "high"]:
# Check if this is an o-series model
o_series_models = ["o3", "o3-mini", "o4-mini", "o1", "o3-pro", "o3-deep-research"]
if any(model in self.model for model in o_series_models):
params["reasoning_effort"] = thinking_budget
```
**API Format**: Uses OpenAI's `reasoning_effort` parameter for o-series models.
### 3. Parameter Propagation
The thinking budget parameter is properly propagated through the LEANN architecture:
1. **CLI** (`packages/leann-core/src/leann/cli.py`): Captures `--thinking-budget` argument
2. **Base RAG** (`apps/base_rag_example.py`): Adds parameter to argument parser
3. **LeannChat** (`packages/leann-core/src/leann/api.py`): Passes `llm_kwargs` to LLM
4. **LLM Interface**: Handles the parameter in backend-specific implementations
## Files Modified
### Core Implementation
- `packages/leann-core/src/leann/chat.py`: Added thinking budget support to OllamaChat and OpenAIChat
- `packages/leann-core/src/leann/cli.py`: Added `--thinking-budget` argument
- `apps/base_rag_example.py`: Added thinking budget parameter to RAG examples
### Documentation
- `README.md`: Added thinking budget parameter to usage examples
- `docs/configuration-guide.md`: Added detailed documentation and usage guidelines
### Examples
- `examples/thinking_budget_demo.py`: Comprehensive demo script with usage examples
## Usage Examples
### Basic Usage
```bash
# High reasoning effort for complex questions
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget high
# Medium reasoning for balanced performance
leann ask my-index --llm openai --model gpt-4o --thinking-budget medium
# Low reasoning for fast responses
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget low
```
### RAG Examples
```bash
# Email RAG with high reasoning
python apps/email_rag.py --llm ollama --llm-model gpt-oss:20b --thinking-budget high
# Document RAG with medium reasoning
python apps/document_rag.py --llm openai --llm-model gpt-4o --thinking-budget medium
```
## Supported Models
### Ollama Models
- **GPT-Oss:20b**: Primary target model with reasoning capabilities
- **Other reasoning models**: Any Ollama model that supports the `reasoning` parameter
### OpenAI Models
- **o3, o3-mini, o4-mini, o1**: o-series reasoning models with `reasoning_effort` parameter
- **GPT-OSS models**: Models that support reasoning capabilities
## Testing
The implementation includes comprehensive testing:
- Parameter handling verification
- Backend-specific API format validation
- CLI argument parsing tests
- Integration with existing LEANN architecture

View File

@@ -1,143 +0,0 @@
# AST-Aware Code chunking guide
## Overview
This guide covers best practices for using AST-aware code chunking in LEANN. AST chunking provides better semantic understanding of code structure compared to traditional text-based chunking.
## Quick Start
### Basic Usage
```bash
# Enable AST chunking for mixed content (code + docs)
python -m apps.document_rag --enable-code-chunking --data-dir ./my_project
# Specialized code repository indexing
python -m apps.code_rag --repo-dir ./my_codebase
# Global CLI with AST support
leann build my-code-index --docs ./src --use-ast-chunking
```
### Installation
```bash
# Install LEANN with AST chunking support
uv pip install -e "."
```
#### For normal users (PyPI install)
- Use `pip install leann` or `uv pip install leann`.
- `astchunk` is pulled automatically from PyPI as a dependency; no extra steps.
#### For developers (from source, editable)
```bash
git clone https://github.com/yichuan-w/LEANN.git leann
cd leann
git submodule update --init --recursive
uv sync
```
- This repo vendors `astchunk` as a git submodule at `packages/astchunk-leann` (our fork).
- `[tool.uv.sources]` maps the `astchunk` package to that path in editable mode.
- You can edit code under `packages/astchunk-leann` and Python will use your changes immediately (no separate `pip install astchunk` needed).
## Best Practices
### When to Use AST Chunking
**Recommended for:**
- Code repositories with multiple languages
- Mixed documentation and code content
- Complex codebases with deep function/class hierarchies
- When working with Claude Code for code assistance
**Not recommended for:**
- Pure text documents
- Very large files (>1MB)
- Languages not supported by tree-sitter
### Optimal Configuration
```bash
# Recommended settings for most codebases
python -m apps.code_rag \
--repo-dir ./src \
--ast-chunk-size 768 \
--ast-chunk-overlap 96 \
--exclude-dirs .git __pycache__ node_modules build dist
```
### Supported Languages
| Extension | Language | Status |
|-----------|----------|--------|
| `.py` | Python | ✅ Full support |
| `.java` | Java | ✅ Full support |
| `.cs` | C# | ✅ Full support |
| `.ts`, `.tsx` | TypeScript | ✅ Full support |
| `.js`, `.jsx` | JavaScript | ✅ Via TypeScript parser |
## Integration Examples
### Document RAG with Code Support
```python
# Enable code chunking in document RAG
python -m apps.document_rag \
--enable-code-chunking \
--data-dir ./project \
--query "How does authentication work in the codebase?"
```
### Claude Code Integration
When using with Claude Code MCP server, AST chunking provides better context for:
- Code completion and suggestions
- Bug analysis and debugging
- Architecture understanding
- Refactoring assistance
## Troubleshooting
### Common Issues
1. **Fallback to Traditional Chunking**
- Normal behavior for unsupported languages
- Check logs for specific language support
2. **Performance with Large Files**
- Adjust `--max-file-size` parameter
- Use `--exclude-dirs` to skip unnecessary directories
3. **Quality Issues**
- Try different `--ast-chunk-size` values (512, 768, 1024)
- Adjust overlap for better context preservation
### Debug Mode
```bash
export LEANN_LOG_LEVEL=DEBUG
python -m apps.code_rag --repo-dir ./my_code
```
## Migration from Traditional Chunking
Existing workflows continue to work without changes. To enable AST chunking:
```bash
# Before
python -m apps.document_rag --chunk-size 256
# After (maintains traditional chunking for non-code files)
python -m apps.document_rag --enable-code-chunking --chunk-size 256 --ast-chunk-size 768
```
## References
- [astchunk GitHub Repository](https://github.com/yilinjz/astchunk)
- [LEANN MCP Integration](../packages/leann-mcp/README.md)
- [Research Paper](https://arxiv.org/html/2506.15655v1)
---
**Note**: AST chunking maintains full backward compatibility while enhancing code understanding capabilities.

View File

@@ -1,98 +0,0 @@
"""
Comparison between Sentence Transformers and OpenAI embeddings
This example shows how different embedding models handle complex queries
and demonstrates the differences between local and API-based embeddings.
"""
import numpy as np
from leann.embedding_compute import compute_embeddings
# OpenAI API key should be set as environment variable
# export OPENAI_API_KEY="your-api-key-here"
# Test data
conference_text = "[Title]: COLING 2025 Conference\n[URL]: https://coling2025.org/"
browser_text = "[Title]: Browser Use Tool\n[URL]: https://github.com/browser-use"
# Two queries with same intent but different wording
query1 = "Tell me my browser history about some conference i often visit"
query2 = "browser history about conference I often visit"
texts = [query1, query2, conference_text, browser_text]
def cosine_similarity(a, b):
return np.dot(a, b) # Already normalized
def analyze_embeddings(embeddings, model_name):
print(f"\n=== {model_name} Results ===")
# Results for Query 1
sim1_conf = cosine_similarity(embeddings[0], embeddings[2])
sim1_browser = cosine_similarity(embeddings[0], embeddings[3])
print(f"Query 1: '{query1}'")
print(f" → Conference similarity: {sim1_conf:.4f} {'' if sim1_conf > sim1_browser else ''}")
print(
f" → Browser similarity: {sim1_browser:.4f} {'' if sim1_browser > sim1_conf else ''}"
)
print(f" Winner: {'Conference' if sim1_conf > sim1_browser else 'Browser'}")
# Results for Query 2
sim2_conf = cosine_similarity(embeddings[1], embeddings[2])
sim2_browser = cosine_similarity(embeddings[1], embeddings[3])
print(f"\nQuery 2: '{query2}'")
print(f" → Conference similarity: {sim2_conf:.4f} {'' if sim2_conf > sim2_browser else ''}")
print(
f" → Browser similarity: {sim2_browser:.4f} {'' if sim2_browser > sim2_conf else ''}"
)
print(f" Winner: {'Conference' if sim2_conf > sim2_browser else 'Browser'}")
# Show the impact
print("\n=== Impact Analysis ===")
print(f"Conference similarity change: {sim2_conf - sim1_conf:+.4f}")
print(f"Browser similarity change: {sim2_browser - sim1_browser:+.4f}")
if sim1_conf > sim1_browser and sim2_browser > sim2_conf:
print("❌ FLIP: Adding 'browser history' flips winner from Conference to Browser!")
elif sim1_conf > sim1_browser and sim2_conf > sim2_browser:
print("✅ STABLE: Conference remains winner in both queries")
elif sim1_browser > sim1_conf and sim2_browser > sim2_conf:
print("✅ STABLE: Browser remains winner in both queries")
else:
print("🔄 MIXED: Results vary between queries")
return {
"query1_conf": sim1_conf,
"query1_browser": sim1_browser,
"query2_conf": sim2_conf,
"query2_browser": sim2_browser,
}
# Test Sentence Transformers
print("Testing Sentence Transformers (facebook/contriever)...")
try:
st_embeddings = compute_embeddings(texts, "facebook/contriever", mode="sentence-transformers")
st_results = analyze_embeddings(st_embeddings, "Sentence Transformers (facebook/contriever)")
except Exception as e:
print(f"❌ Sentence Transformers failed: {e}")
st_results = None
# Test OpenAI
print("\n" + "=" * 60)
print("Testing OpenAI (text-embedding-3-small)...")
try:
openai_embeddings = compute_embeddings(texts, "text-embedding-3-small", mode="openai")
openai_results = analyze_embeddings(openai_embeddings, "OpenAI (text-embedding-3-small)")
except Exception as e:
print(f"❌ OpenAI failed: {e}")
openai_results = None
# Compare results
if st_results and openai_results:
print("\n" + "=" * 60)
print("=== COMPARISON SUMMARY ===")

View File

@@ -1,384 +0,0 @@
# LEANN Configuration Guide
This guide helps you optimize LEANN for different use cases and understand the trade-offs between various configuration options.
## Getting Started: Simple is Better
When first trying LEANN, start with a small dataset to quickly validate your approach:
**For document RAG**: The default `data/` directory works perfectly - includes 2 AI research papers, Pride and Prejudice literature, and a technical report
```bash
python -m apps.document_rag --query "What techniques does LEANN use?"
```
**For other data sources**: Limit the dataset size for quick testing
```bash
# WeChat: Test with recent messages only
python -m apps.wechat_rag --max-items 100 --query "What did we discuss about the project timeline?"
# Browser history: Last few days
python -m apps.browser_rag --max-items 500 --query "Find documentation about vector databases"
# Email: Recent inbox
python -m apps.email_rag --max-items 200 --query "Who sent updates about the deployment status?"
```
Once validated, scale up gradually:
- 100 documents → 1,000 → 10,000 → full dataset (`--max-items -1`)
- This helps identify issues early before committing to long processing times
## Embedding Model Selection: Understanding the Trade-offs
Based on our experience developing LEANN, embedding models fall into three categories:
### Small Models (< 100M parameters)
**Example**: `sentence-transformers/all-MiniLM-L6-v2` (22M params)
- **Pros**: Lightweight, fast for both indexing and inference
- **Cons**: Lower semantic understanding, may miss nuanced relationships
- **Use when**: Speed is critical, handling simple queries, interactive mode, or just experimenting with LEANN. If time is not a constraint, consider using a larger/better embedding model
### Medium Models (100M-500M parameters)
**Example**: `facebook/contriever` (110M params), `BAAI/bge-base-en-v1.5` (110M params)
- **Pros**: Balanced performance, good multilingual support, reasonable speed
- **Cons**: Requires more compute than small models
- **Use when**: Need quality results without extreme compute requirements, general-purpose RAG applications
### Large Models (500M+ parameters)
**Example**: `Qwen/Qwen3-Embedding-0.6B` (600M params), `intfloat/multilingual-e5-large` (560M params)
- **Pros**: Best semantic understanding, captures complex relationships, excellent multilingual support. **Qwen3-Embedding-0.6B achieves nearly OpenAI API performance!**
- **Cons**: Slower inference, longer index build times
- **Use when**: Quality is paramount and you have sufficient compute resources. **Highly recommended** for production use
### Quick Start: Cloud and Local Embedding Options
**OpenAI Embeddings (Fastest Setup)**
For immediate testing without local model downloads(also if you [do not have GPU](https://github.com/yichuan-w/LEANN/issues/43) and do not care that much about your document leak, you should use this, we compute the embedding and recompute using openai API):
```bash
# Set OpenAI embeddings (requires OPENAI_API_KEY)
--embedding-mode openai --embedding-model text-embedding-3-small
```
**Ollama Embeddings (Privacy-Focused)**
For local embeddings with complete privacy:
```bash
# First, pull an embedding model
ollama pull nomic-embed-text
# Use Ollama embeddings
--embedding-mode ollama --embedding-model nomic-embed-text
```
<details>
<summary><strong>Cloud vs Local Trade-offs</strong></summary>
**OpenAI Embeddings** (`text-embedding-3-small/large`)
- **Pros**: No local compute needed, consistently fast, high quality
- **Cons**: Requires API key, costs money, data leaves your system, [known limitations with certain languages](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
- **When to use**: Prototyping, non-sensitive data, need immediate results
**Local Embeddings**
- **Pros**: Complete privacy, no ongoing costs, full control, can sometimes outperform OpenAI embeddings
- **Cons**: Slower than cloud APIs, requires local compute resources
- **When to use**: Production systems, sensitive data, cost-sensitive applications
</details>
## Index Selection: Matching Your Scale
### HNSW (Hierarchical Navigable Small World)
**Best for**: Small to medium datasets (< 10M vectors) - **Default and recommended for extreme low storage**
- Full recomputation required
- High memory usage during build phase
- Excellent recall (95%+)
```bash
# Optimal for most use cases
--backend-name hnsw --graph-degree 32 --build-complexity 64
```
### DiskANN
**Best for**: Large datasets, especially when you want `recompute=True`.
**Key advantages:**
- **Faster search** on large datasets (3x+ speedup vs HNSW in many cases)
- **Smart storage**: `recompute=True` enables automatic graph partitioning for smaller indexes
- **Better scaling**: Designed for 100k+ documents
**Recompute behavior:**
- `recompute=True` (recommended): Pure PQ traversal + final reranking - faster and enables partitioning
- `recompute=False`: PQ + partial real distances during traversal - slower but higher accuracy
```bash
# Recommended for most use cases
--backend-name diskann --graph-degree 32 --build-complexity 64
```
**Performance Benchmark**: Run `uv run benchmarks/diskann_vs_hnsw_speed_comparison.py` to compare DiskANN and HNSW on your system.
## LLM Selection: Engine and Model Comparison
### LLM Engines
**OpenAI** (`--llm openai`)
- **Pros**: Best quality, consistent performance, no local resources needed
- **Cons**: Costs money ($0.15-2.5 per million tokens), requires internet, data privacy concerns
- **Models**: `gpt-4o-mini` (fast, cheap), `gpt-4o` (best quality), `o3` (reasoning), `o3-mini` (reasoning, cheaper)
- **Thinking Budget**: Use `--thinking-budget low/medium/high` for o-series reasoning models (o3, o3-mini, o4-mini)
- **Note**: Our current default, but we recommend switching to Ollama for most use cases
**Ollama** (`--llm ollama`)
- **Pros**: Fully local, free, privacy-preserving, good model variety
- **Cons**: Requires local GPU/CPU resources, slower than cloud APIs, need to install extra [ollama app](https://github.com/ollama/ollama?tab=readme-ov-file#ollama) and pre-download models by `ollama pull`
- **Models**: `qwen3:0.6b` (ultra-fast), `qwen3:1.7b` (balanced), `qwen3:4b` (good quality), `qwen3:7b` (high quality), `deepseek-r1:1.5b` (reasoning)
- **Thinking Budget**: Use `--thinking-budget low/medium/high` for reasoning models like GPT-Oss:20b
**HuggingFace** (`--llm hf`)
- **Pros**: Free tier available, huge model selection, direct model loading (vs Ollama's server-based approach)
- **Cons**: More complex initial setup
- **Models**: `Qwen/Qwen3-1.7B-FP8`
## Parameter Tuning Guide
### Search Complexity Parameters
**`--build-complexity`** (index building)
- Controls thoroughness during index construction
- Higher = better recall but slower build
- Recommendations:
- 32: Quick prototyping
- 64: Balanced (default)
- 128: Production systems
- 256: Maximum quality
**`--search-complexity`** (query time)
- Controls search thoroughness
- Higher = better results but slower
- Recommendations:
- 16: Fast/Interactive search
- 32: High quality with diversity
- 64+: Maximum accuracy
### Top-K Selection
**`--top-k`** (number of retrieved chunks)
- More chunks = better context but slower LLM processing
- Should be always smaller than `--search-complexity`
- Guidelines:
- 10-20: General questions (default: 20)
- 30+: Complex multi-hop reasoning requiring comprehensive context
**Trade-off formula**:
- Retrieval time ∝ log(n) × search_complexity
- LLM processing time ∝ top_k × chunk_size
- Total context = top_k × chunk_size tokens
### Thinking Budget for Reasoning Models
**`--thinking-budget`** (reasoning effort level)
- Controls the computational effort for reasoning models
- Options: `low`, `medium`, `high`
- Guidelines:
- `low`: Fast responses, basic reasoning (default for simple queries)
- `medium`: Balanced speed and reasoning depth
- `high`: Maximum reasoning effort, best for complex analytical questions
- **Supported Models**:
- **Ollama**: `gpt-oss:20b`, `gpt-oss:120b`
- **OpenAI**: `o3`, `o3-mini`, `o4-mini`, `o1` (o-series reasoning models)
- **Note**: Models without reasoning support will show a warning and proceed without reasoning parameters
- **Example**: `--thinking-budget high` for complex analytical questions
**📖 For detailed usage examples and implementation details, check out [Thinking Budget Documentation](THINKING_BUDGET_FEATURE.md)**
**💡 Quick Examples:**
```bash
# OpenAI o-series reasoning model
python apps/document_rag.py --query "What are the main techniques LEANN explores?" \
--index-dir hnswbuild --backend hnsw \
--llm openai --llm-model o3 --thinking-budget medium
# Ollama reasoning model
python apps/document_rag.py --query "What are the main techniques LEANN explores?" \
--index-dir hnswbuild --backend hnsw \
--llm ollama --llm-model gpt-oss:20b --thinking-budget high
```
### Graph Degree (HNSW/DiskANN)
**`--graph-degree`**
- Number of connections per node in the graph
- Higher = better recall but more memory
- HNSW: 16-32 (default: 32)
- DiskANN: 32-128 (default: 64)
## Performance Optimization Checklist
### If Embedding is Too Slow
1. **Switch to smaller model**:
```bash
# From large model
--embedding-model Qwen/Qwen3-Embedding-0.6B
# To small model
--embedding-model sentence-transformers/all-MiniLM-L6-v2
```
2. **Limit dataset size for testing**:
```bash
--max-items 1000 # Process first 1k items only
```
3. **Use MLX on Apple Silicon** (optional optimization):
```bash
--embedding-mode mlx --embedding-model mlx-community/Qwen3-Embedding-0.6B-8bit
```
MLX might not be the best choice, as we tested and found that it only offers 1.3x acceleration compared to HF, so maybe using ollama is a better choice for embedding generation
4. **Use Ollama**
```bash
--embedding-mode ollama --embedding-model nomic-embed-text
```
To discover additional embedding models in ollama, check out https://ollama.com/search?c=embedding or read more about embedding models at https://ollama.com/blog/embedding-models, please do check the model size that works best for you
### If Search Quality is Poor
1. **Increase retrieval count**:
```bash
--top-k 30 # Retrieve more candidates
```
2. **Upgrade embedding model**:
```bash
# For English
--embedding-model BAAI/bge-base-en-v1.5
# For multilingual
--embedding-model intfloat/multilingual-e5-large
```
## Understanding the Trade-offs
Every configuration choice involves trade-offs:
| Factor | Small/Fast | Large/Quality |
|--------|------------|---------------|
| Embedding Model | `all-MiniLM-L6-v2` | `Qwen/Qwen3-Embedding-0.6B` |
| Chunk Size | 512 tokens | 128 tokens |
| Index Type | HNSW | DiskANN |
| LLM | `qwen3:1.7b` | `gpt-4o` |
The key is finding the right balance for your specific use case. Start small and simple, measure performance, then scale up only where needed.
## Low-resource setups
If you dont have a local GPU or builds/searches are too slow, use one or more of the options below.
### 1) Use OpenAI embeddings (no local compute)
Fastest path with zero local GPU requirements. Set your API key and use OpenAI embeddings during build and search:
```bash
export OPENAI_API_KEY=sk-...
# Build with OpenAI embeddings
leann build my-index \
--embedding-mode openai \
--embedding-model text-embedding-3-small
# Search with OpenAI embeddings (recompute at query time)
leann search my-index "your query" \
--recompute
```
### 2) Run remote builds with SkyPilot (cloud GPU)
Offload embedding generation and index building to a GPU VM using [SkyPilot](https://skypilot.readthedocs.io/en/latest/). A template is provided at `sky/leann-build.yaml`.
```bash
# One-time: install and configure SkyPilot
pip install skypilot
# Launch with defaults (L4:1) and mount ./data to ~/leann-data; the build runs automatically
sky launch -c leann-gpu sky/leann-build.yaml
# Override parameters via -e key=value (optional)
sky launch -c leann-gpu sky/leann-build.yaml \
-e index_name=my-index \
-e backend=hnsw \
-e embedding_mode=sentence-transformers \
-e embedding_model=Qwen/Qwen3-Embedding-0.6B
# Copy the built index back to your local .leann (use rsync)
rsync -Pavz leann-gpu:~/.leann/indexes/my-index ./.leann/indexes/
```
### 3) Disable recomputation to trade storage for speed
If you need lower latency and have more storage/memory, disable recomputation. This stores full embeddings and avoids recomputing at search time.
```bash
# Build without recomputation (HNSW requires non-compact in this mode)
leann build my-index --no-recompute --no-compact
# Search without recomputation
leann search my-index "your query" --no-recompute
```
When to use:
- Extreme low latency requirements (high QPS, interactive assistants)
- Read-heavy workloads where storage is cheaper than latency
- No always-available GPU
Constraints:
- HNSW: when `--no-recompute` is set, LEANN automatically disables compact mode during build
- DiskANN: supported; `--no-recompute` skips selective recompute during search
Storage impact:
- Storing N embeddings of dimension D with float32 requires approximately N × D × 4 bytes
- Example: 1,000,000 chunks × 768 dims × 4 bytes ≈ 2.86 GB (plus graph/metadata)
Converting an existing index (rebuild required):
```bash
# Rebuild in-place (ensure you still have original docs or can regenerate chunks)
leann build my-index --force --no-recompute --no-compact
```
Python API usage:
```python
from leann import LeannSearcher
searcher = LeannSearcher("/path/to/my-index.leann")
results = searcher.search("your query", top_k=10, recompute_embeddings=False)
```
Trade-offs:
- Lower latency and fewer network hops at query time
- Significantly higher storage (10100× vs selective recomputation)
- Slightly larger memory footprint during build and search
Quick benchmark results (`benchmarks/benchmark_no_recompute.py` with 5k texts, complexity=32):
- HNSW
```text
recompute=True: search_time=0.818s, size=1.1MB
recompute=False: search_time=0.012s, size=16.6MB
```
- DiskANN
```text
recompute=True: search_time=0.041s, size=5.9MB
recompute=False: search_time=0.013s, size=24.6MB
```
Conclusion:
- **HNSW**: `no-recompute` is significantly faster (no embedding recomputation) but requires much more storage (stores all embeddings)
- **DiskANN**: `no-recompute` uses PQ + partial real distances during traversal (slower but higher accuracy), while `recompute=True` uses pure PQ traversal + final reranking (faster traversal, enables build-time partitioning for smaller storage)
## Further Reading
- [Lessons Learned Developing LEANN](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
- [LEANN Technical Paper](https://arxiv.org/abs/2506.08276)
- [DiskANN Original Paper](https://papers.nips.cc/paper/2019/file/09853c7fb1d3f8ee67a61b6bf4a7f8e6-Paper.pdf)
- [SSD-based Graph Partitioning](https://github.com/SonglinLife/SSD_BASED_PLAN)

11
docs/contributing.md Normal file
View File

@@ -0,0 +1,11 @@
# 🤝 Contributing
We welcome contributions! Leann is built by the community, for the community.
## Ways to Contribute
- 🐛 **Bug Reports**: Found an issue? Let us know!
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
- 🔧 **Code Contributions**: PRs welcome for all skill levels
- 📖 **Documentation**: Help make Leann more accessible
- 🧪 **Benchmarks**: Share your performance results

View File

@@ -7,4 +7,4 @@ You can speed up the process by using a lightweight embedding model. Add this to
```bash
--embedding-model sentence-transformers/all-MiniLM-L6-v2
```
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)

View File

@@ -3,10 +3,9 @@
## 🔥 Core Features
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage with dynamic computation using optimized ZMQ servers and highly optimized search paradigm (overlapping and batching) with highly optimized embedding engine
- **🧠 AST-Aware Code Chunking** - Intelligent code chunking that preserves semantic boundaries (functions, classes, methods) for Python, Java, C#, and TypeScript files
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
- **🏗️ Pluggable Backends** - HNSW/FAISS (default), with optional DiskANN for large-scale deployments
- **🏗️ Pluggable Backends** - DiskANN, HNSW/FAISS with unified API
## 🛠️ Technical Highlights
- **🔄 Recompute Mode** - Highest accuracy scenarios while eliminating vector storage overhead
@@ -14,10 +13,10 @@
- **🚀 High-throughput Embedding Pipeline** - Optimized batched processing for maximum efficiency
- **🎯 Two-level Search** - Novel coarse-to-fine search overlap for accelerated query processing (optional)
- **💾 Memory-mapped Indices** - Fast startup with raw text mapping to reduce memory overhead
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](../examples/mlx_demo.py))
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](test/build_mlx_index.py))
## 🎨 Developer Experience
- **Simple Python API** - Get started in minutes
- **Extensible backend system** - Easy to add new algorithms
- **Comprehensive examples** - From basic usage to production deployment
- **Comprehensive examples** - From basic usage to production deployment

View File

@@ -1,149 +0,0 @@
# LEANN Grep Search Usage Guide
## Overview
LEANN's grep search functionality provides exact text matching for finding specific code patterns, error messages, function names, or exact phrases in your indexed documents.
## Basic Usage
### Simple Grep Search
```python
from leann.api import LeannSearcher
searcher = LeannSearcher("your_index_path")
# Exact text search
results = searcher.search("def authenticate_user", use_grep=True, top_k=5)
for result in results:
print(f"Score: {result.score}")
print(f"Text: {result.text[:100]}...")
print("-" * 40)
```
### Comparison: Semantic vs Grep Search
```python
# Semantic search - finds conceptually similar content
semantic_results = searcher.search("machine learning algorithms", top_k=3)
# Grep search - finds exact text matches
grep_results = searcher.search("def train_model", use_grep=True, top_k=3)
```
## When to Use Grep Search
### Use Cases
- **Code Search**: Finding specific function definitions, class names, or variable references
- **Error Debugging**: Locating exact error messages or stack traces
- **Documentation**: Finding specific API endpoints or exact terminology
### Examples
```python
# Find function definitions
functions = searcher.search("def __init__", use_grep=True)
# Find import statements
imports = searcher.search("from sklearn import", use_grep=True)
# Find specific error types
errors = searcher.search("FileNotFoundError", use_grep=True)
# Find TODO comments
todos = searcher.search("TODO:", use_grep=True)
# Find configuration entries
configs = searcher.search("server_port=", use_grep=True)
```
## Technical Details
### How It Works
1. **File Location**: Grep search operates on the raw text stored in `.jsonl` files
2. **Command Execution**: Uses the system `grep` command with case-insensitive search
3. **Result Processing**: Parses JSON lines and extracts text and metadata
4. **Scoring**: Simple frequency-based scoring based on query term occurrences
### Search Process
```
Query: "def train_model"
grep -i -n "def train_model" documents.leann.passages.jsonl
Parse matching JSON lines
Calculate scores based on term frequency
Return top_k results
```
### Scoring Algorithm
```python
# Term frequency in document
score = text.lower().count(query.lower())
```
Results are ranked by score (highest first), with higher scores indicating more occurrences of the search term.
## Error Handling
### Common Issues
#### Grep Command Not Found
```
RuntimeError: grep command not found. Please install grep or use semantic search.
```
**Solution**: Install grep on your system:
- **Ubuntu/Debian**: `sudo apt-get install grep`
- **macOS**: grep is pre-installed
- **Windows**: Use WSL or install grep via Git Bash/MSYS2
#### No Results Found
```python
# Check if your query exists in the raw data
results = searcher.search("your_query", use_grep=True)
if not results:
print("No exact matches found. Try:")
print("1. Check spelling and case")
print("2. Use partial terms")
print("3. Switch to semantic search")
```
## Complete Example
```python
#!/usr/bin/env python3
"""
Grep Search Example
Demonstrates grep search for exact text matching.
"""
from leann.api import LeannSearcher
def demonstrate_grep_search():
# Initialize searcher
searcher = LeannSearcher("my_index")
print("=== Function Search ===")
functions = searcher.search("def __init__", use_grep=True, top_k=5)
for i, result in enumerate(functions, 1):
print(f"{i}. Score: {result.score}")
print(f" Preview: {result.text[:60]}...")
print()
print("=== Error Search ===")
errors = searcher.search("FileNotFoundError", use_grep=True, top_k=3)
for result in errors:
print(f"Content: {result.text.strip()}")
print("-" * 40)
if __name__ == "__main__":
demonstrate_grep_search()
```

View File

@@ -1,300 +0,0 @@
# LEANN Metadata Filtering Usage Guide
## Overview
Leann possesses metadata filtering capabilities that allow you to filter search results based on arbitrary metadata fields set during chunking. This feature enables use cases like spoiler-free book search, document filtering by date/type, code search by file type, and potentially much more.
## Basic Usage
### Adding Metadata to Your Documents
When building your index, add metadata to each text chunk:
```python
from leann.api import LeannBuilder
builder = LeannBuilder("hnsw")
# Add text with metadata
builder.add_text(
text="Chapter 1: Alice falls down the rabbit hole",
metadata={
"chapter": 1,
"character": "Alice",
"themes": ["adventure", "curiosity"],
"word_count": 150
}
)
builder.build_index("alice_in_wonderland_index")
```
### Searching with Metadata Filters
Use the `metadata_filters` parameter in search calls:
```python
from leann.api import LeannSearcher
searcher = LeannSearcher("alice_in_wonderland_index")
# Search with filters
results = searcher.search(
query="What happens to Alice?",
top_k=10,
metadata_filters={
"chapter": {"<=": 5}, # Only chapters 1-5
"spoiler_level": {"!=": "high"} # No high spoilers
}
)
```
## Filter Syntax
### Basic Structure
```python
metadata_filters = {
"field_name": {"operator": value},
"another_field": {"operator": value}
}
```
### Supported Operators
#### Comparison Operators
- `"=="`: Equal to
- `"!="`: Not equal to
- `"<"`: Less than
- `"<="`: Less than or equal
- `">"`: Greater than
- `">="`: Greater than or equal
```python
# Examples
{"chapter": {"==": 1}} # Exactly chapter 1
{"page": {">": 100}} # Pages after 100
{"rating": {">=": 4.0}} # Rating 4.0 or higher
{"word_count": {"<": 500}} # Short passages
```
#### Membership Operators
- `"in"`: Value is in list
- `"not_in"`: Value is not in list
```python
# Examples
{"character": {"in": ["Alice", "Bob"]}} # Alice OR Bob
{"genre": {"not_in": ["horror", "thriller"]}} # Exclude genres
{"tags": {"in": ["fiction", "adventure"]}} # Any of these tags
```
#### String Operators
- `"contains"`: String contains substring
- `"starts_with"`: String starts with prefix
- `"ends_with"`: String ends with suffix
```python
# Examples
{"title": {"contains": "alice"}} # Title contains "alice"
{"filename": {"ends_with": ".py"}} # Python files
{"author": {"starts_with": "Dr."}} # Authors with "Dr." prefix
```
#### Boolean Operators
- `"is_true"`: Field is truthy
- `"is_false"`: Field is falsy
```python
# Examples
{"is_published": {"is_true": True}} # Published content
{"is_draft": {"is_false": False}} # Not drafts
```
### Multiple Operators on Same Field
You can apply multiple operators to the same field (AND logic):
```python
metadata_filters = {
"word_count": {
">=": 100, # At least 100 words
"<=": 500 # At most 500 words
}
}
```
### Compound Filters
Multiple fields are combined with AND logic:
```python
metadata_filters = {
"chapter": {"<=": 10}, # Up to chapter 10
"character": {"==": "Alice"}, # About Alice
"spoiler_level": {"!=": "high"} # No major spoilers
}
```
## Use Case Examples
### 1. Spoiler-Free Book Search
```python
# Reader has only read up to chapter 5
def search_spoiler_free(query, max_chapter):
return searcher.search(
query=query,
metadata_filters={
"chapter": {"<=": max_chapter},
"spoiler_level": {"in": ["none", "low"]}
}
)
results = search_spoiler_free("What happens to Alice?", max_chapter=5)
```
### 2. Document Management by Date
```python
# Find recent documents
recent_docs = searcher.search(
query="project updates",
metadata_filters={
"date": {">=": "2024-01-01"},
"document_type": {"==": "report"}
}
)
```
### 3. Code Search by File Type
```python
# Search only Python files
python_code = searcher.search(
query="authentication function",
metadata_filters={
"file_extension": {"==": ".py"},
"lines_of_code": {"<": 100}
}
)
```
### 4. Content Filtering by Audience
```python
# Age-appropriate content
family_content = searcher.search(
query="adventure stories",
metadata_filters={
"age_rating": {"in": ["G", "PG"]},
"content_warnings": {"not_in": ["violence", "adult_themes"]}
}
)
```
### 5. Multi-Book Series Management
```python
# Search across first 3 books only
early_series = searcher.search(
query="character development",
metadata_filters={
"series": {"==": "Harry Potter"},
"book_number": {"<=": 3}
}
)
```
## Running the Example
You can see metadata filtering in action with our spoiler-free book RAG example:
```bash
# Don't forget to set up the environment
uv venv
source .venv/bin/activate
# Set your OpenAI API key (required for embeddings, but you can update the example locally and use ollama instead)
export OPENAI_API_KEY="your-api-key-here"
# Run the spoiler-free book RAG example
uv run examples/spoiler_free_book_rag.py
```
This example demonstrates:
- Building an index with metadata (chapter numbers, characters, themes, locations)
- Searching with filters to avoid spoilers (e.g., only show results up to chapter 5)
- Different scenarios for readers at various points in the book
The example uses Alice's Adventures in Wonderland as sample data and shows how you can search for information without revealing plot points from later chapters.
## Advanced Patterns
### Custom Chunking with metadata
```python
def chunk_book_with_metadata(book_text, book_info):
chunks = []
for chapter_num, chapter_text in parse_chapters(book_text):
# Extract entities, themes, etc.
characters = extract_characters(chapter_text)
themes = classify_themes(chapter_text)
spoiler_level = assess_spoiler_level(chapter_text, chapter_num)
# Create chunks with rich metadata
for paragraph in split_paragraphs(chapter_text):
chunks.append({
"text": paragraph,
"metadata": {
"book_title": book_info["title"],
"chapter": chapter_num,
"characters": characters,
"themes": themes,
"spoiler_level": spoiler_level,
"word_count": len(paragraph.split()),
"reading_level": calculate_reading_level(paragraph)
}
})
return chunks
```
## Performance Considerations
### Efficient Filtering Strategies
1. **Post-search filtering**: Applies filters after vector search, which should be efficient for typical result sets (10-100 results).
2. **Metadata design**: Keep metadata fields simple and avoid deeply nested structures.
### Best Practices
1. **Consistent metadata schema**: Use consistent field names and value types across your documents.
2. **Reasonable metadata size**: Keep metadata reasonably sized to avoid storage overhead.
3. **Type consistency**: Use consistent data types for the same fields (e.g., always integers for chapter numbers).
4. **Index multiple granularities**: Consider chunking at different levels (paragraph, section, chapter) with appropriate metadata.
### Adding Metadata to Existing Indices
To add metadata filtering to existing indices, you'll need to rebuild them with metadata:
```python
# Read existing passages and add metadata
def add_metadata_to_existing_chunks(chunks):
for chunk in chunks:
# Extract or assign metadata based on content
chunk["metadata"] = extract_metadata(chunk["text"])
return chunks
# Rebuild index with metadata
enhanced_chunks = add_metadata_to_existing_chunks(existing_chunks)
builder = LeannBuilder("hnsw")
for chunk in enhanced_chunks:
builder.add_text(chunk["text"], chunk["metadata"])
builder.build_index("enhanced_index")
```

View File

@@ -1,75 +0,0 @@
# Normalized Embeddings Support in LEANN
LEANN now automatically detects normalized embedding models and sets the appropriate distance metric for optimal performance.
## What are Normalized Embeddings?
Normalized embeddings are vectors with L2 norm = 1 (unit vectors). These embeddings are optimized for cosine similarity rather than Maximum Inner Product Search (MIPS).
## Automatic Detection
When you create a `LeannBuilder` instance with a normalized embedding model, LEANN will:
1. **Automatically set `distance_metric="cosine"`** if not specified
2. **Show a warning** if you manually specify a different distance metric
3. **Provide optimal search performance** with the correct metric
## Supported Normalized Embedding Models
### OpenAI
All OpenAI text embedding models are normalized:
- `text-embedding-ada-002`
- `text-embedding-3-small`
- `text-embedding-3-large`
### Voyage AI
All Voyage AI embedding models are normalized:
- `voyage-2`
- `voyage-3`
- `voyage-large-2`
- `voyage-multilingual-2`
- `voyage-code-2`
### Cohere
All Cohere embedding models are normalized:
- `embed-english-v3.0`
- `embed-multilingual-v3.0`
- `embed-english-light-v3.0`
- `embed-multilingual-light-v3.0`
## Example Usage
```python
from leann.api import LeannBuilder
# Automatic detection - will use cosine distance
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="text-embedding-3-small",
embedding_mode="openai"
)
# Warning: Detected normalized embeddings model 'text-embedding-3-small'...
# Automatically setting distance_metric='cosine'
# Manual override (not recommended)
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="text-embedding-3-small",
embedding_mode="openai",
distance_metric="mips" # Will show warning
)
# Warning: Using 'mips' distance metric with normalized embeddings...
```
## Non-Normalized Embeddings
Models like `facebook/contriever` and other sentence-transformers models that are not normalized will continue to use MIPS by default, which is optimal for them.
## Why This Matters
Using the wrong distance metric with normalized embeddings can lead to:
- **Poor search quality** due to HNSW's early termination with narrow score ranges
- **Incorrect ranking** of search results
- **Suboptimal performance** compared to using the correct metric
For more details on why this happens, see our analysis in the [embedding detection code](../packages/leann-core/src/leann/api.py) which automatically handles normalized embeddings and MIPS distance metric issues.

View File

@@ -2,8 +2,8 @@
## 🎯 Q2 2025
- [X] HNSW backend integration
- [X] DiskANN backend with MIPS/L2/Cosine support
- [X] HNSW backend integration
- [X] Real-time embedding pipeline
- [X] Memory-efficient graph pruning
@@ -18,4 +18,4 @@
- [ ] Integration with LangChain/LlamaIndex
- [ ] Visual similarity search
- [ ] Query rewrtiting, rerank and expansion
- [ ] Query rewrtiting, rerank and expansion

View File

@@ -3,15 +3,14 @@
Memory comparison between Faiss HNSW and LEANN HNSW backend
"""
import gc
import logging
import os
import subprocess
import sys
import time
from pathlib import Path
import psutil
import gc
import subprocess
from pathlib import Path
from llama_index.core.node_parser import SentenceSplitter
# Setup logging
@@ -62,7 +61,7 @@ def test_faiss_hnsw():
try:
result = subprocess.run(
[sys.executable, "benchmarks/faiss_only.py"],
[sys.executable, "examples/faiss_only.py"],
capture_output=True,
text=True,
timeout=300,
@@ -84,7 +83,9 @@ def test_faiss_hnsw():
for line in lines:
if "Peak Memory:" in line:
peak_memory = float(line.split("Peak Memory:")[1].split("MB")[0].strip())
peak_memory = float(
line.split("Peak Memory:")[1].split("MB")[0].strip()
)
return {"peak_memory": peak_memory}
@@ -110,12 +111,13 @@ def test_leann_hnsw():
tracker.checkpoint("After imports")
from leann.api import LeannBuilder
from llama_index.core import SimpleDirectoryReader
from leann.api import LeannBuilder, LeannSearcher
# Load and parse documents
documents = SimpleDirectoryReader(
"data",
"examples/data",
recursive=True,
encoding="utf-8",
required_exts=[".pdf", ".txt", ".md"],
@@ -195,14 +197,16 @@ def test_leann_hnsw():
runtime_start_mem = get_memory_usage()
print(f"Before load memory: {runtime_start_mem:.1f} MB")
tracker.checkpoint("Before load memory")
# Load searcher
searcher = LeannSearcher(index_path)
tracker.checkpoint("After searcher loading")
print("Running search queries...")
queries = [
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面任务令一般在什么城市颁发",
"What is LEANN and how does it work?",
"华为诺亚方舟实验室的主要研究内容",
]
@@ -300,15 +304,21 @@ def main():
print("\nLEANN vs Faiss Performance:")
memory_saving = faiss_results["peak_memory"] - leann_results["peak_memory"]
print(f" Search Memory: {memory_ratio:.1f}x less ({memory_saving:.1f} MB saved)")
print(
f" Search Memory: {memory_ratio:.1f}x less ({memory_saving:.1f} MB saved)"
)
# Storage comparison
if leann_storage_size > faiss_storage_size:
storage_ratio = leann_storage_size / faiss_storage_size
print(f" Storage Size: {storage_ratio:.1f}x larger (LEANN uses more storage)")
print(
f" Storage Size: {storage_ratio:.1f}x larger (LEANN uses more storage)"
)
elif faiss_storage_size > leann_storage_size:
storage_ratio = faiss_storage_size / leann_storage_size
print(f" Storage Size: {storage_ratio:.1f}x smaller (LEANN uses less storage)")
print(
f" Storage Size: {storage_ratio:.1f}x smaller (LEANN uses less storage)"
)
else:
print(" Storage Size: similar")
else:

View File

@@ -1,5 +1,5 @@
The Project Gutenberg eBook of Pride and Prejudice
This ebook is for the use of anyone anywhere in the United States and
most other parts of the world at no cost and with almost no restrictions
whatsoever. You may copy it, give it away or re-use it under the terms
@@ -14557,7 +14557,7 @@ her into Derbyshire, had been the means of uniting them.
*** END OF THE PROJECT GUTENBERG EBOOK PRIDE AND PREJUDICE ***
Updated editions will replace the previous one—the old editions will
be renamed.
@@ -14662,7 +14662,7 @@ performed, viewed, copied or distributed:
at www.gutenberg.org. If you
are not located in the United States, you will have to check the laws
of the country where you are located before using this eBook.
1.E.2. If an individual Project Gutenberg™ electronic work is
derived from texts not protected by U.S. copyright law (does not
contain a notice indicating that it is posted with permission of the
@@ -14724,7 +14724,7 @@ provided that:
Gutenberg Literary Archive Foundation at the address specified in
Section 4, “Information about donations to the Project Gutenberg
Literary Archive Foundation.”
• You provide a full refund of any money paid by a user who notifies
you in writing (or by e-mail) within 30 days of receipt that s/he
does not agree to the terms of the full Project Gutenberg™
@@ -14732,15 +14732,15 @@ provided that:
copies of the works possessed in a physical medium and discontinue
all use of and all access to other copies of Project Gutenberg™
works.
• You provide, in accordance with paragraph 1.F.3, a full refund of
any money paid for a work or a replacement copy, if a defect in the
electronic work is discovered and reported to you within 90 days of
receipt of the work.
• You comply with all other terms of this agreement for free
distribution of Project Gutenberg™ works.
1.E.9. If you wish to charge a fee or distribute a Project
Gutenberg™ electronic work or group of works on different terms than
@@ -14903,3 +14903,5 @@ This website includes information about Project Gutenberg™,
including how to make donations to the Project Gutenberg Literary
Archive Foundation, how to help produce our new eBooks, and how to
subscribe to our email newsletter to hear about new eBooks.

146
examples/document_search.py Normal file
View File

@@ -0,0 +1,146 @@
#!/usr/bin/env python3
"""
Document search demo with recompute mode
"""
import os
from pathlib import Path
import shutil
import time
# Import backend packages to trigger plugin registration
try:
import leann_backend_diskann
import leann_backend_hnsw
print("INFO: Backend packages imported successfully.")
except ImportError as e:
print(f"WARNING: Could not import backend packages. Error: {e}")
# Import upper-level API from leann-core
from leann.api import LeannBuilder, LeannSearcher, LeannChat
def load_sample_documents():
"""Create sample documents for demonstration"""
docs = [
{"title": "Intro to Python", "content": "Python is a high-level, interpreted language known for simplicity."},
{"title": "ML Basics", "content": "Machine learning builds systems that learn from data."},
{"title": "Data Structures", "content": "Data structures like arrays, lists, and graphs organize data."},
]
return docs
def main():
print("==========================================================")
print("=== Leann Document Search Demo (DiskANN + Recompute) ===")
print("==========================================================")
INDEX_DIR = Path("./test_indices")
INDEX_PATH = str(INDEX_DIR / "documents.diskann")
BACKEND_TO_TEST = "diskann"
if INDEX_DIR.exists():
print(f"--- Cleaning up old index directory: {INDEX_DIR} ---")
shutil.rmtree(INDEX_DIR)
# --- 1. Build index ---
print(f"\n[PHASE 1] Building index using '{BACKEND_TO_TEST}' backend...")
builder = LeannBuilder(
backend_name=BACKEND_TO_TEST,
graph_degree=32,
complexity=64
)
documents = load_sample_documents()
print(f"Loaded {len(documents)} sample documents.")
for doc in documents:
builder.add_text(doc["content"], metadata={"title": doc["title"]})
builder.build_index(INDEX_PATH)
print(f"\nIndex built!")
# --- 2. Basic search demo ---
print(f"\n[PHASE 2] Basic search using '{BACKEND_TO_TEST}' backend...")
searcher = LeannSearcher(index_path=INDEX_PATH)
query = "What is machine learning?"
print(f"\nQuery: '{query}'")
print("\n--- Basic search mode (PQ computation) ---")
start_time = time.time()
results = searcher.search(query, top_k=2)
basic_time = time.time() - start_time
print(f"⏱️ Basic search time: {basic_time:.3f} seconds")
print(">>> Basic search results <<<")
for i, res in enumerate(results, 1):
print(f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}")
# --- 3. Recompute search demo ---
print(f"\n[PHASE 3] Recompute search using embedding server...")
print("\n--- Recompute search mode (get real embeddings via network) ---")
# Configure recompute parameters
recompute_params = {
"recompute_beighbor_embeddings": True, # Enable network recomputation
"USE_DEFERRED_FETCH": False, # Don't use deferred fetch
"skip_search_reorder": True, # Skip search reordering
"dedup_node_dis": True, # Enable node distance deduplication
"prune_ratio": 0.1, # Pruning ratio 10%
"batch_recompute": False, # Don't use batch recomputation
"global_pruning": False, # Don't use global pruning
"zmq_port": 5555, # ZMQ port
"embedding_model": "sentence-transformers/all-mpnet-base-v2"
}
print("Recompute parameter configuration:")
for key, value in recompute_params.items():
print(f" {key}: {value}")
print(f"\n🔄 Executing Recompute search...")
try:
start_time = time.time()
recompute_results = searcher.search(query, top_k=2, **recompute_params)
recompute_time = time.time() - start_time
print(f"⏱️ Recompute search time: {recompute_time:.3f} seconds")
print(">>> Recompute search results <<<")
for i, res in enumerate(recompute_results, 1):
print(f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}")
# Compare results
print(f"\n--- Result comparison ---")
print(f"Basic search time: {basic_time:.3f} seconds")
print(f"Recompute time: {recompute_time:.3f} seconds")
print("\nBasic search vs Recompute results:")
for i in range(min(len(results), len(recompute_results))):
basic_score = results[i].score
recompute_score = recompute_results[i].score
score_diff = abs(basic_score - recompute_score)
print(f" Position {i+1}: PQ={basic_score:.4f}, Recompute={recompute_score:.4f}, Difference={score_diff:.4f}")
if recompute_time > basic_time:
print(f"✅ Recompute mode working correctly (more accurate but slower)")
else:
print(f" Recompute time is unusually fast, network recomputation may not be enabled")
except Exception as e:
print(f"❌ Recompute search failed: {e}")
print("This usually indicates an embedding server connection issue")
# --- 4. Chat demo ---
print(f"\n[PHASE 4] Starting chat session...")
chat = LeannChat(index_path=INDEX_PATH)
chat_response = chat.ask(query)
print(f"You: {query}")
print(f"Leann: {chat_response}")
print("\n==========================================================")
print("✅ Demo finished successfully!")
print("==========================================================")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,122 @@
import os
import email
from pathlib import Path
from typing import List, Any
from llama_index.core import Document
from llama_index.core.readers.base import BaseReader
def find_all_messages_directories(root: str = None) -> List[Path]:
"""
Recursively find all 'Messages' directories under the given root.
Returns a list of Path objects.
"""
if root is None:
# Auto-detect user's mail path
home_dir = os.path.expanduser("~")
root = os.path.join(home_dir, "Library", "Mail")
messages_dirs = []
for dirpath, dirnames, filenames in os.walk(root):
if os.path.basename(dirpath) == "Messages":
messages_dirs.append(Path(dirpath))
return messages_dirs
class EmlxReader(BaseReader):
"""
Apple Mail .emlx file reader with embedded metadata.
Reads individual .emlx files from Apple Mail's storage format.
"""
def __init__(self, include_html: bool = False) -> None:
"""
Initialize.
Args:
include_html: Whether to include HTML content in the email body (default: False)
"""
self.include_html = include_html
def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]:
"""
Load data from the input directory containing .emlx files.
Args:
input_dir: Directory containing .emlx files
**load_kwargs:
max_count (int): Maximum amount of messages to read.
"""
docs: List[Document] = []
max_count = load_kwargs.get('max_count', 1000)
count = 0
# Walk through the directory recursively
for dirpath, dirnames, filenames in os.walk(input_dir):
# Skip hidden directories
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
for filename in filenames:
if count >= max_count:
break
if filename.endswith(".emlx"):
filepath = os.path.join(dirpath, filename)
try:
# Read the .emlx file
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read()
# .emlx files have a length prefix followed by the email content
# The first line contains the length, followed by the email
lines = content.split('\n', 1)
if len(lines) >= 2:
email_content = lines[1]
# Parse the email using Python's email module
try:
msg = email.message_from_string(email_content)
# Extract email metadata
subject = msg.get('Subject', 'No Subject')
from_addr = msg.get('From', 'Unknown')
to_addr = msg.get('To', 'Unknown')
date = msg.get('Date', 'Unknown')
# Extract email body
body = ""
if msg.is_multipart():
for part in msg.walk():
if part.get_content_type() == "text/plain" or part.get_content_type() == "text/html":
if part.get_content_type() == "text/html" and not self.include_html:
continue
body += part.get_payload(decode=True).decode('utf-8', errors='ignore')
# break
else:
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore')
# Create document content with metadata embedded in text
doc_content = f"""
[File]: {filename}
[From]: {from_addr}
[To]: {to_addr}
[Subject]: {subject}
[Date]: {date}
[EMAIL BODY Start]:
{body}
"""
# No separate metadata - everything is in the text
doc = Document(text=doc_content, metadata={})
docs.append(doc)
count += 1
except Exception as e:
print(f"Error parsing email from {filepath}: {e}")
continue
except Exception as e:
print(f"Error reading file {filepath}: {e}")
continue
print(f"Loaded {len(docs)} email documents")
return docs

View File

@@ -7,9 +7,9 @@ Contains simple parser for mbox files.
import logging
from pathlib import Path
from typing import Any
from typing import Any, Dict, List, Optional
from fsspec import AbstractFileSystem
from llama_index.core.readers.base import BaseReader
from llama_index.core.schema import Document
@@ -27,7 +27,11 @@ class MboxReader(BaseReader):
"""
DEFAULT_MESSAGE_FORMAT: str = (
"Date: {_date}\nFrom: {_from}\nTo: {_to}\nSubject: {_subject}\nContent: {_content}"
"Date: {_date}\n"
"From: {_from}\n"
"To: {_to}\n"
"Subject: {_subject}\n"
"Content: {_content}"
)
def __init__(
@@ -41,7 +45,9 @@ class MboxReader(BaseReader):
try:
from bs4 import BeautifulSoup # noqa
except ImportError:
raise ImportError("`beautifulsoup4` package not found: `pip install beautifulsoup4`")
raise ImportError(
"`beautifulsoup4` package not found: `pip install beautifulsoup4`"
)
super().__init__(*args, **kwargs)
self.max_count = max_count
@@ -50,9 +56,9 @@ class MboxReader(BaseReader):
def load_data(
self,
file: Path,
extra_info: dict | None = None,
fs: AbstractFileSystem | None = None,
) -> list[Document]:
extra_info: Optional[Dict] = None,
fs: Optional[AbstractFileSystem] = None,
) -> List[Document]:
"""Parse file into string."""
# Import required libraries
import mailbox
@@ -68,7 +74,7 @@ class MboxReader(BaseReader):
)
i = 0
results: list[str] = []
results: List[str] = []
# Load file using mailbox
bytes_parser = BytesParser(policy=default).parse
mbox = mailbox.mbox(file, factory=bytes_parser) # type: ignore
@@ -118,7 +124,7 @@ class MboxReader(BaseReader):
class EmlxMboxReader(MboxReader):
"""
EmlxMboxReader - Modified MboxReader that handles directories of .emlx files.
Extends MboxReader to work with Apple Mail's .emlx format by:
1. Reading .emlx files from a directory
2. Converting them to mbox format in memory
@@ -128,13 +134,13 @@ class EmlxMboxReader(MboxReader):
def load_data(
self,
directory: Path,
extra_info: dict | None = None,
fs: AbstractFileSystem | None = None,
) -> list[Document]:
extra_info: Optional[Dict] = None,
fs: Optional[AbstractFileSystem] = None,
) -> List[Document]:
"""Parse .emlx files from directory into strings using MboxReader logic."""
import os
import tempfile
import os
if fs:
logger.warning(
"fs was specified but EmlxMboxReader doesn't support loading "
@@ -144,37 +150,37 @@ class EmlxMboxReader(MboxReader):
# Find all .emlx files in the directory
emlx_files = list(directory.glob("*.emlx"))
logger.info(f"Found {len(emlx_files)} .emlx files in {directory}")
if not emlx_files:
logger.warning(f"No .emlx files found in {directory}")
return []
# Create a temporary mbox file
with tempfile.NamedTemporaryFile(mode="w", suffix=".mbox", delete=False) as temp_mbox:
with tempfile.NamedTemporaryFile(mode='w', suffix='.mbox', delete=False) as temp_mbox:
temp_mbox_path = temp_mbox.name
# Convert .emlx files to mbox format
for emlx_file in emlx_files:
try:
# Read the .emlx file
with open(emlx_file, encoding="utf-8", errors="ignore") as f:
with open(emlx_file, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read()
# .emlx format: first line is length, rest is email content
lines = content.split("\n", 1)
lines = content.split('\n', 1)
if len(lines) >= 2:
email_content = lines[1] # Skip the length line
# Write to mbox format (each message starts with "From " and ends with blank line)
temp_mbox.write(f"From {emlx_file.name} {email_content}\n\n")
except Exception as e:
logger.warning(f"Failed to process {emlx_file}: {e}")
continue
# Close the temporary file so MboxReader can read it
temp_mbox.close()
try:
# Use the parent MboxReader's logic to parse the mbox file
return super().load_data(Path(temp_mbox_path), extra_info, fs)
@@ -182,5 +188,5 @@ class EmlxMboxReader(MboxReader):
# Clean up temporary file
try:
os.unlink(temp_mbox_path)
except OSError:
pass
except:
pass

View File

@@ -1,11 +1,11 @@
#!/usr/bin/env python3
"""Test only Faiss HNSW"""
import os
import sys
import time
import psutil
import gc
import os
def get_memory_usage():
@@ -37,20 +37,20 @@ def main():
import faiss
except ImportError:
print("Faiss is not installed.")
print(
"Please install it with `uv pip install faiss-cpu` and you can then run this script again"
)
print("Please install it with `uv pip install faiss-cpu` and you can then run this script again")
sys.exit(1)
from llama_index.core import (
Settings,
SimpleDirectoryReader,
StorageContext,
VectorStoreIndex,
StorageContext,
Settings,
node_parser,
Document,
)
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
tracker = MemoryTracker("Faiss HNSW")
tracker.checkpoint("Initial")
@@ -65,7 +65,7 @@ def main():
tracker.checkpoint("After Faiss index creation")
documents = SimpleDirectoryReader(
"data",
"examples/data",
recursive=True,
encoding="utf-8",
required_exts=[".pdf", ".txt", ".md"],
@@ -90,9 +90,8 @@ def main():
vector_store=vector_store, persist_dir="./storage_faiss"
)
from llama_index.core import load_index_from_storage
index = load_index_from_storage(storage_context=storage_context)
print("Index loaded from ./storage_faiss")
print(f"Index loaded from ./storage_faiss")
tracker.checkpoint("After loading existing index")
index_loaded = True
except Exception as e:
@@ -100,18 +99,19 @@ def main():
print("Cleaning up corrupted index and building new one...")
# Clean up corrupted index
import shutil
if os.path.exists("./storage_faiss"):
shutil.rmtree("./storage_faiss")
if not index_loaded:
print("Building new Faiss HNSW index...")
# Use the correct Faiss building pattern from the example
vector_store = FaissVectorStore(faiss_index=faiss_index)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = VectorStoreIndex.from_documents(
documents, storage_context=storage_context, transformations=[node_parser]
documents,
storage_context=storage_context,
transformations=[node_parser]
)
tracker.checkpoint("After index building")
@@ -124,10 +124,10 @@ def main():
runtime_start_mem = get_memory_usage()
print(f"Before load memory: {runtime_start_mem:.1f} MB")
tracker.checkpoint("Before load memory")
query_engine = index.as_query_engine(similarity_top_k=20)
queries = [
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面任务令一般在什么城市颁发",
"What is LEANN and how does it work?",
"华为诺亚方舟实验室的主要研究内容",
]
@@ -141,7 +141,7 @@ def main():
runtime_end_mem = get_memory_usage()
runtime_overhead = runtime_end_mem - runtime_start_mem
peak_memory = tracker.summary()
print(f"Peak Memory: {peak_memory:.1f} MB")
print(f"Runtime Memory Overhead: {runtime_overhead:.1f} MB")

View File

@@ -0,0 +1,286 @@
import os
import asyncio
import argparse
try:
import dotenv
dotenv.load_dotenv()
except ModuleNotFoundError:
# python-dotenv is not installed; skip loading environment variables
dotenv = None
from pathlib import Path
from typing import List, Any
from leann.api import LeannBuilder, LeannSearcher, LeannChat
from llama_index.core.node_parser import SentenceSplitter
# dotenv.load_dotenv() # handled above if python-dotenv is available
# Default Chrome profile path
DEFAULT_CHROME_PROFILE = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
def create_leann_index_from_multiple_chrome_profiles(profile_dirs: List[Path], index_path: str = "chrome_history_index.leann", max_count: int = -1):
"""
Create LEANN index from multiple Chrome profile data sources.
Args:
profile_dirs: List of Path objects pointing to Chrome profile directories
index_path: Path to save the LEANN index
max_count: Maximum number of history entries to process per profile
"""
print("Creating LEANN index from multiple Chrome profile data sources...")
# Load documents using ChromeHistoryReader from history_data
from history_data.history import ChromeHistoryReader
reader = ChromeHistoryReader()
INDEX_DIR = Path(index_path).parent
if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---")
all_documents = []
total_processed = 0
# Process each Chrome profile directory
for i, profile_dir in enumerate(profile_dirs):
print(f"\nProcessing Chrome profile {i+1}/{len(profile_dirs)}: {profile_dir}")
try:
documents = reader.load_data(
chrome_profile_path=str(profile_dir),
max_count=max_count
)
if documents:
print(f"Loaded {len(documents)} history documents from {profile_dir}")
all_documents.extend(documents)
total_processed += len(documents)
# Check if we've reached the max count
if max_count > 0 and total_processed >= max_count:
print(f"Reached max count of {max_count} documents")
break
else:
print(f"No documents loaded from {profile_dir}")
except Exception as e:
print(f"Error processing {profile_dir}: {e}")
continue
if not all_documents:
print("No documents loaded from any source. Exiting.")
# highlight info that you need to close all chrome browser before running this script and high light the instruction!!
print("\033[91mYou need to close or quit all chrome browser before running this script\033[0m")
return None
print(f"\nTotal loaded {len(all_documents)} history documents from {len(profile_dirs)} profiles")
# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
# Convert Documents to text strings and chunk them
all_texts = []
for doc in all_documents:
# Split the document into chunks
nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes:
text = node.get_content()
# text = '[Title] ' + doc.metadata["title"] + '\n' + text
all_texts.append(text)
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
# Create LEANN index directory
print(f"--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="facebook/contriever",
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1 # Force single-threaded mode
)
print(f"Adding {len(all_texts)} history chunks to index...")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(index_path)
print(f"\nLEANN index built at {index_path}!")
else:
print(f"--- Using existing index at {INDEX_DIR} ---")
return index_path
def create_leann_index(profile_path: str = None, index_path: str = "chrome_history_index.leann", max_count: int = 1000):
"""
Create LEANN index from Chrome history data.
Args:
profile_path: Path to the Chrome profile directory (optional, uses default if None)
index_path: Path to save the LEANN index
max_count: Maximum number of history entries to process
"""
print("Creating LEANN index from Chrome history data...")
INDEX_DIR = Path(index_path).parent
if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...")
# Load documents using ChromeHistoryReader from history_data
from history_data.history import ChromeHistoryReader
reader = ChromeHistoryReader()
documents = reader.load_data(
chrome_profile_path=profile_path,
max_count=max_count
)
if not documents:
print("No documents loaded. Exiting.")
return None
print(f"Loaded {len(documents)} history documents")
# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
# Convert Documents to text strings and chunk them
all_texts = []
for doc in documents:
# Split the document into chunks
nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
# Create LEANN index directory
print(f"--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="facebook/contriever",
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1 # Force single-threaded mode
)
print(f"Adding {len(all_texts)} history chunks to index...")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(index_path)
print(f"\nLEANN index built at {index_path}!")
else:
print(f"--- Using existing index at {INDEX_DIR} ---")
return index_path
async def query_leann_index(index_path: str, query: str):
"""
Query the LEANN index.
Args:
index_path: Path to the LEANN index
query: The query string
"""
print(f"\n[PHASE 2] Starting Leann chat session...")
chat = LeannChat(index_path=index_path)
print(f"You: {query}")
chat_response = chat.ask(
query,
top_k=10,
recompute_beighbor_embeddings=True,
complexity=32,
beam_width=1,
llm_config={
"type": "openai",
"model": "gpt-4o",
"api_key": os.getenv("OPENAI_API_KEY"),
},
llm_kwargs={
"temperature": 0.0,
"max_tokens": 1000
}
)
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
async def main():
# Parse command line arguments
parser = argparse.ArgumentParser(description='LEANN Chrome History Reader - Create and query browser history index')
parser.add_argument('--chrome-profile', type=str, default=DEFAULT_CHROME_PROFILE,
help=f'Path to Chrome profile directory (default: {DEFAULT_CHROME_PROFILE}), usually you dont need to change this')
parser.add_argument('--index-dir', type=str, default="./google_history_index",
help='Directory to store the LEANN index (default: ./chrome_history_index_leann_test)')
parser.add_argument('--max-entries', type=int, default=1000,
help='Maximum number of history entries to process (default: 1000)')
parser.add_argument('--query', type=str, default=None,
help='Single query to run (default: runs example queries)')
parser.add_argument('--auto-find-profiles', action='store_true', default=True,
help='Automatically find all Chrome profiles (default: True)')
args = parser.parse_args()
INDEX_DIR = Path(args.index_dir)
INDEX_PATH = str(INDEX_DIR / "chrome_history.leann")
print(f"Using Chrome profile: {args.chrome_profile}")
print(f"Index directory: {INDEX_DIR}")
print(f"Max entries: {args.max_entries}")
# Find Chrome profile directories
from history_data.history import ChromeHistoryReader
if args.auto_find_profiles:
profile_dirs = ChromeHistoryReader.find_chrome_profiles()
if not profile_dirs:
print("No Chrome profiles found automatically. Exiting.")
return
else:
# Use single specified profile
profile_path = Path(args.chrome_profile)
if not profile_path.exists():
print(f"Chrome profile not found: {profile_path}")
return
profile_dirs = [profile_path]
# Create or load the LEANN index from all sources
index_path = create_leann_index_from_multiple_chrome_profiles(profile_dirs, INDEX_PATH, args.max_entries)
if index_path:
if args.query:
# Run single query
await query_leann_index(index_path, args.query)
else:
# Example queries
queries = [
"What websites did I visit about machine learning?",
"Find my search history about programming"
]
for query in queries:
print("\n" + "="*60)
await query_leann_index(index_path, query)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,35 +0,0 @@
"""
Grep Search Example
Shows how to use grep-based text search instead of semantic search.
Useful when you need exact text matches rather than meaning-based results.
"""
from leann import LeannSearcher
# Load your index
searcher = LeannSearcher("my-documents.leann")
# Regular semantic search
print("=== Semantic Search ===")
results = searcher.search("machine learning algorithms", top_k=3)
for result in results:
print(f"Score: {result.score:.3f}")
print(f"Text: {result.text[:80]}...")
print()
# Grep-based search for exact text matches
print("=== Grep Search ===")
results = searcher.search("def train_model", top_k=3, use_grep=True)
for result in results:
print(f"Score: {result.score}")
print(f"Text: {result.text[:80]}...")
print()
# Find specific error messages
error_results = searcher.search("FileNotFoundError", use_grep=True)
print(f"Found {len(error_results)} files mentioning FileNotFoundError")
# Search for function definitions
func_results = searcher.search("class SearchResult", use_grep=True, top_k=5)
print(f"Found {len(func_results)} class definitions")

View File

@@ -1,3 +1,3 @@
from .history import ChromeHistoryReader
__all__ = ["ChromeHistoryReader"]
__all__ = ['ChromeHistoryReader']

View File

@@ -1,81 +1,77 @@
import os
import sqlite3
import os
from pathlib import Path
from typing import Any
from typing import List, Any
from llama_index.core import Document
from llama_index.core.readers.base import BaseReader
class ChromeHistoryReader(BaseReader):
"""
Chrome browser history reader that extracts browsing data from SQLite database.
Reads Chrome history from the default Chrome profile location and creates documents
with embedded metadata similar to the email reader structure.
"""
def __init__(self) -> None:
"""Initialize."""
pass
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
"""
Load Chrome history data from the default Chrome profile location.
Args:
input_dir: Not used for Chrome history (kept for compatibility)
**load_kwargs:
max_count (int): Maximum amount of history entries to read.
chrome_profile_path (str): Custom path to Chrome profile directory.
"""
docs: list[Document] = []
max_count = load_kwargs.get("max_count", 1000)
chrome_profile_path = load_kwargs.get("chrome_profile_path", None)
docs: List[Document] = []
max_count = load_kwargs.get('max_count', 1000)
chrome_profile_path = load_kwargs.get('chrome_profile_path', None)
# Default Chrome profile path on macOS
if chrome_profile_path is None:
chrome_profile_path = os.path.expanduser(
"~/Library/Application Support/Google/Chrome/Default"
)
chrome_profile_path = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
history_db_path = os.path.join(chrome_profile_path, "History")
if not os.path.exists(history_db_path):
print(f"Chrome history database not found at: {history_db_path}")
return docs
try:
# Connect to the Chrome history database
print(f"Connecting to database: {history_db_path}")
conn = sqlite3.connect(history_db_path)
cursor = conn.cursor()
# Query to get browsing history with metadata (removed created_time column)
query = """
SELECT
SELECT
datetime(last_visit_time/1000000-11644473600,'unixepoch','localtime') as last_visit,
url,
title,
visit_count,
typed_count,
url,
title,
visit_count,
typed_count,
hidden
FROM urls
FROM urls
ORDER BY last_visit_time DESC
"""
print(f"Executing query on database: {history_db_path}")
cursor.execute(query)
rows = cursor.fetchall()
print(f"Query returned {len(rows)} rows")
count = 0
for row in rows:
if count >= max_count and max_count > 0:
break
last_visit, url, title, visit_count, typed_count, hidden = row
# Create document content with metadata embedded in text
doc_content = f"""
[Title]: {title}
@@ -84,43 +80,38 @@ class ChromeHistoryReader(BaseReader):
[Visit times]: {visit_count}
[Typed times]: {typed_count}
"""
# Create document with embedded metadata
doc = Document(text=doc_content, metadata={"title": title[0:150]})
doc = Document(text=doc_content, metadata={ "title": title[0:150]})
# if len(title) > 150:
# print(f"Title is too long: {title}")
docs.append(doc)
count += 1
conn.close()
print(f"Loaded {len(docs)} Chrome history documents")
except Exception as e:
print(f"Error reading Chrome history: {e}")
# add you may need to close your browser to make the database file available
# also highlight in red
print(
"\033[91mYou may need to close your browser to make the database file available\033[0m"
)
return docs
return docs
@staticmethod
def find_chrome_profiles() -> list[Path]:
def find_chrome_profiles() -> List[Path]:
"""
Find all Chrome profile directories.
Returns:
List of Path objects pointing to Chrome profile directories
"""
chrome_base_path = Path(os.path.expanduser("~/Library/Application Support/Google/Chrome"))
profile_dirs = []
if not chrome_base_path.exists():
print(f"Chrome directory not found at: {chrome_base_path}")
return profile_dirs
# Find all profile directories
for profile_dir in chrome_base_path.iterdir():
if profile_dir.is_dir() and profile_dir.name != "System Profile":
@@ -128,59 +119,53 @@ class ChromeHistoryReader(BaseReader):
if history_path.exists():
profile_dirs.append(profile_dir)
print(f"Found Chrome profile: {profile_dir}")
print(f"Found {len(profile_dirs)} Chrome profiles")
return profile_dirs
@staticmethod
def export_history_to_file(
output_file: str = "chrome_history_export.txt", max_count: int = 1000
):
def export_history_to_file(output_file: str = "chrome_history_export.txt", max_count: int = 1000):
"""
Export Chrome history to a text file using the same SQL query format.
Args:
output_file: Path to the output file
max_count: Maximum number of entries to export
"""
chrome_profile_path = os.path.expanduser(
"~/Library/Application Support/Google/Chrome/Default"
)
chrome_profile_path = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
history_db_path = os.path.join(chrome_profile_path, "History")
if not os.path.exists(history_db_path):
print(f"Chrome history database not found at: {history_db_path}")
return
try:
conn = sqlite3.connect(history_db_path)
cursor = conn.cursor()
query = """
SELECT
SELECT
datetime(last_visit_time/1000000-11644473600,'unixepoch','localtime') as last_visit,
url,
title,
visit_count,
typed_count,
url,
title,
visit_count,
typed_count,
hidden
FROM urls
FROM urls
ORDER BY last_visit_time DESC
LIMIT ?
"""
cursor.execute(query, (max_count,))
rows = cursor.fetchall()
with open(output_file, "w", encoding="utf-8") as f:
with open(output_file, 'w', encoding='utf-8') as f:
for row in rows:
last_visit, url, title, visit_count, typed_count, hidden = row
f.write(
f"{last_visit}\t{url}\t{title}\t{visit_count}\t{typed_count}\t{hidden}\n"
)
f.write(f"{last_visit}\t{url}\t{title}\t{visit_count}\t{typed_count}\t{hidden}\n")
conn.close()
print(f"Exported {len(rows)} history entries to {output_file}")
except Exception as e:
print(f"Error exporting Chrome history: {e}")
print(f"Error exporting Chrome history: {e}")

View File

@@ -2,31 +2,30 @@ import json
import os
import re
import subprocess
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Any
from typing import List, Any, Dict, Optional
from llama_index.core import Document
from llama_index.core.readers.base import BaseReader
from datetime import datetime
class WeChatHistoryReader(BaseReader):
"""
WeChat chat history reader that extracts chat data from exported JSON files.
Reads WeChat chat history from exported JSON files (from wechat-exporter tool)
and creates documents with embedded metadata similar to the Chrome history reader structure.
Also includes utilities for automatic WeChat chat history export.
"""
def __init__(self) -> None:
"""Initialize."""
self.packages_dir = Path(__file__).parent.parent.parent / "packages"
self.wechat_exporter_dir = self.packages_dir / "wechat-exporter"
self.wechat_decipher_dir = self.packages_dir / "wechat-decipher-macos"
def check_wechat_running(self) -> bool:
"""Check if WeChat is currently running."""
try:
@@ -34,30 +33,24 @@ class WeChatHistoryReader(BaseReader):
return result.returncode == 0
except Exception:
return False
def install_wechattweak(self) -> bool:
"""Install WeChatTweak CLI tool."""
try:
# Create wechat-exporter directory if it doesn't exist
self.wechat_exporter_dir.mkdir(parents=True, exist_ok=True)
wechattweak_path = self.wechat_exporter_dir / "wechattweak-cli"
if not wechattweak_path.exists():
print("Downloading WeChatTweak CLI...")
subprocess.run(
[
"curl",
"-L",
"-o",
str(wechattweak_path),
"https://github.com/JettChenT/WeChatTweak-CLI/releases/latest/download/wechattweak-cli",
],
check=True,
)
subprocess.run([
"curl", "-L", "-o", str(wechattweak_path),
"https://github.com/JettChenT/WeChatTweak-CLI/releases/latest/download/wechattweak-cli"
], check=True)
# Make executable
wechattweak_path.chmod(0o755)
# Install WeChatTweak
print("Installing WeChatTweak...")
subprocess.run(["sudo", str(wechattweak_path), "install"], check=True)
@@ -65,7 +58,7 @@ class WeChatHistoryReader(BaseReader):
except Exception as e:
print(f"Error installing WeChatTweak: {e}")
return False
def restart_wechat(self):
"""Restart WeChat to apply WeChatTweak."""
try:
@@ -76,325 +69,302 @@ class WeChatHistoryReader(BaseReader):
time.sleep(5) # Wait for WeChat to start
except Exception as e:
print(f"Error restarting WeChat: {e}")
def check_api_available(self) -> bool:
"""Check if WeChatTweak API is available."""
try:
result = subprocess.run(
["curl", "-s", "http://localhost:48065/wechat/allcontacts"],
capture_output=True,
text=True,
timeout=5,
)
result = subprocess.run([
"curl", "-s", "http://localhost:48065/wechat/allcontacts"
], capture_output=True, text=True, timeout=5)
return result.returncode == 0 and result.stdout.strip()
except Exception:
return False
def _extract_readable_text(self, content: str) -> str:
"""
Extract readable text from message content, removing XML and system messages.
Args:
content: The raw message content (can be string or dict)
Returns:
Cleaned, readable text
"""
if not content:
return ""
# Handle dictionary content (like quoted messages)
if isinstance(content, dict):
# Extract text from dictionary structure
text_parts = []
if "title" in content:
text_parts.append(str(content["title"]))
if "quoted" in content:
text_parts.append(str(content["quoted"]))
if "content" in content:
text_parts.append(str(content["content"]))
if "text" in content:
text_parts.append(str(content["text"]))
if 'title' in content:
text_parts.append(str(content['title']))
if 'quoted' in content:
text_parts.append(str(content['quoted']))
if 'content' in content:
text_parts.append(str(content['content']))
if 'text' in content:
text_parts.append(str(content['text']))
if text_parts:
return " | ".join(text_parts)
else:
# If we can't extract meaningful text from dict, return empty
return ""
# Handle string content
if not isinstance(content, str):
return ""
# Remove common prefixes like "wxid_xxx:\n"
clean_content = re.sub(r"^wxid_[^:]+:\s*", "", content)
clean_content = re.sub(r"^[^:]+:\s*", "", clean_content)
clean_content = re.sub(r'^wxid_[^:]+:\s*', '', content)
clean_content = re.sub(r'^[^:]+:\s*', '', clean_content)
# If it's just XML or system message, return empty
if clean_content.strip().startswith("<") or "recalled a message" in clean_content:
if clean_content.strip().startswith('<') or 'recalled a message' in clean_content:
return ""
return clean_content.strip()
def _is_text_message(self, content: str) -> bool:
"""
Check if a message contains readable text content.
Args:
content: The message content (can be string or dict)
Returns:
True if the message contains readable text, False otherwise
"""
if not content:
return False
# Handle dictionary content
if isinstance(content, dict):
# Check if dict has any readable text fields
text_fields = ["title", "quoted", "content", "text"]
text_fields = ['title', 'quoted', 'content', 'text']
for field in text_fields:
if content.get(field):
if field in content and content[field]:
return True
return False
# Handle string content
if not isinstance(content, str):
return False
# Skip image messages (contain XML with img tags)
if "<img" in content and "cdnurl" in content:
if '<img' in content and 'cdnurl' in content:
return False
# Skip emoji messages (contain emoji XML tags)
if "<emoji" in content and "productid" in content:
if '<emoji' in content and 'productid' in content:
return False
# Skip voice messages
if "<voice" in content:
if '<voice' in content:
return False
# Skip video messages
if "<video" in content:
if '<video' in content:
return False
# Skip file messages
if "<appmsg" in content and "appid" in content:
if '<appmsg' in content and 'appid' in content:
return False
# Skip system messages (like "recalled a message")
if "recalled a message" in content:
if 'recalled a message' in content:
return False
# Check if there's actual readable text (not just XML or system messages)
# Remove common prefixes like "wxid_xxx:\n" and check for actual content
clean_content = re.sub(r"^wxid_[^:]+:\s*", "", content)
clean_content = re.sub(r"^[^:]+:\s*", "", clean_content)
clean_content = re.sub(r'^wxid_[^:]+:\s*', '', content)
clean_content = re.sub(r'^[^:]+:\s*', '', clean_content)
# If after cleaning we have meaningful text, consider it readable
if len(clean_content.strip()) > 0 and not clean_content.strip().startswith("<"):
if len(clean_content.strip()) > 0 and not clean_content.strip().startswith('<'):
return True
return False
def _concatenate_messages(
self,
messages: list[dict],
max_length: int = 128,
time_window_minutes: int = 30,
overlap_messages: int = 0,
) -> list[dict]:
def _concatenate_messages(self, messages: List[Dict], max_length: int = 128,
time_window_minutes: int = 30, overlap_messages: int = 0) -> List[Dict]:
"""
Concatenate messages based on length and time rules.
Args:
messages: List of message dictionaries
max_length: Maximum length for concatenated message groups. Use -1 to disable length constraint.
time_window_minutes: Time window in minutes to group messages together. Use -1 to disable time constraint.
overlap_messages: Number of messages to overlap between consecutive groups
Returns:
List of concatenated message groups
"""
if not messages:
return []
concatenated_groups = []
current_group = []
current_length = 0
last_timestamp = None
for message in messages:
# Extract message info
content = message.get("content", "")
message_text = message.get("message", "")
create_time = message.get("createTime", 0)
message.get("fromUser", "")
message.get("toUser", "")
message.get("isSentFromSelf", False)
content = message.get('content', '')
message_text = message.get('message', '')
create_time = message.get('createTime', 0)
from_user = message.get('fromUser', '')
to_user = message.get('toUser', '')
is_sent_from_self = message.get('isSentFromSelf', False)
# Extract readable text
readable_text = self._extract_readable_text(content)
if not readable_text:
readable_text = message_text
# Skip empty messages
if not readable_text.strip():
continue
# Check time window constraint (only if time_window_minutes != -1)
if time_window_minutes != -1 and last_timestamp is not None and create_time > 0:
time_diff_minutes = (create_time - last_timestamp) / 60
if time_diff_minutes > time_window_minutes:
# Time gap too large, start new group
if current_group:
concatenated_groups.append(
{
"messages": current_group,
"total_length": current_length,
"start_time": current_group[0].get("createTime", 0),
"end_time": current_group[-1].get("createTime", 0),
}
)
concatenated_groups.append({
'messages': current_group,
'total_length': current_length,
'start_time': current_group[0].get('createTime', 0),
'end_time': current_group[-1].get('createTime', 0)
})
# Keep last few messages for overlap
if overlap_messages > 0 and len(current_group) > overlap_messages:
current_group = current_group[-overlap_messages:]
current_length = sum(
len(
self._extract_readable_text(msg.get("content", ""))
or msg.get("message", "")
)
for msg in current_group
)
current_length = sum(len(self._extract_readable_text(msg.get('content', '')) or msg.get('message', '')) for msg in current_group)
else:
current_group = []
current_length = 0
# Check length constraint (only if max_length != -1)
message_length = len(readable_text)
if max_length != -1 and current_length + message_length > max_length and current_group:
# Current group would exceed max length, save it and start new
concatenated_groups.append(
{
"messages": current_group,
"total_length": current_length,
"start_time": current_group[0].get("createTime", 0),
"end_time": current_group[-1].get("createTime", 0),
}
)
concatenated_groups.append({
'messages': current_group,
'total_length': current_length,
'start_time': current_group[0].get('createTime', 0),
'end_time': current_group[-1].get('createTime', 0)
})
# Keep last few messages for overlap
if overlap_messages > 0 and len(current_group) > overlap_messages:
current_group = current_group[-overlap_messages:]
current_length = sum(
len(
self._extract_readable_text(msg.get("content", ""))
or msg.get("message", "")
)
for msg in current_group
)
current_length = sum(len(self._extract_readable_text(msg.get('content', '')) or msg.get('message', '')) for msg in current_group)
else:
current_group = []
current_length = 0
# Add message to current group
current_group.append(message)
current_length += message_length
last_timestamp = create_time
# Add the last group if it exists
if current_group:
concatenated_groups.append(
{
"messages": current_group,
"total_length": current_length,
"start_time": current_group[0].get("createTime", 0),
"end_time": current_group[-1].get("createTime", 0),
}
)
concatenated_groups.append({
'messages': current_group,
'total_length': current_length,
'start_time': current_group[0].get('createTime', 0),
'end_time': current_group[-1].get('createTime', 0)
})
return concatenated_groups
def _create_concatenated_content(self, message_group: dict, contact_name: str) -> str:
def _create_concatenated_content(self, message_group: Dict, contact_name: str) -> str:
"""
Create concatenated content from a group of messages.
Args:
message_group: Dictionary containing messages and metadata
contact_name: Name of the contact
Returns:
Formatted concatenated content
"""
messages = message_group["messages"]
start_time = message_group["start_time"]
end_time = message_group["end_time"]
messages = message_group['messages']
start_time = message_group['start_time']
end_time = message_group['end_time']
# Format timestamps
if start_time:
try:
start_timestamp = datetime.fromtimestamp(start_time)
start_time_str = start_timestamp.strftime("%Y-%m-%d %H:%M:%S")
except (ValueError, OSError):
start_time_str = start_timestamp.strftime('%Y-%m-%d %H:%M:%S')
except:
start_time_str = str(start_time)
else:
start_time_str = "Unknown"
if end_time:
try:
end_timestamp = datetime.fromtimestamp(end_time)
end_time_str = end_timestamp.strftime("%Y-%m-%d %H:%M:%S")
except (ValueError, OSError):
end_time_str = end_timestamp.strftime('%Y-%m-%d %H:%M:%S')
except:
end_time_str = str(end_time)
else:
end_time_str = "Unknown"
# Build concatenated message content
message_parts = []
for message in messages:
content = message.get("content", "")
message_text = message.get("message", "")
create_time = message.get("createTime", 0)
is_sent_from_self = message.get("isSentFromSelf", False)
content = message.get('content', '')
message_text = message.get('message', '')
create_time = message.get('createTime', 0)
is_sent_from_self = message.get('isSentFromSelf', False)
# Extract readable text
readable_text = self._extract_readable_text(content)
if not readable_text:
readable_text = message_text
# Format individual message
if create_time:
try:
timestamp = datetime.fromtimestamp(create_time)
# change to YYYY-MM-DD HH:MM:SS
time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
except (ValueError, OSError):
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
except:
time_str = str(create_time)
else:
time_str = "Unknown"
sender = "[Me]" if is_sent_from_self else "[Contact]"
message_parts.append(f"({time_str}) {sender}: {readable_text}")
concatenated_text = "\n".join(message_parts)
# Create final document content
doc_content = f"""
Contact: {contact_name}
Time Range: {start_time_str} - {end_time_str}
Messages ({len(messages)} messages, {message_group["total_length"]} chars):
Messages ({len(messages)} messages, {message_group['total_length']} chars):
{concatenated_text}
"""
# TODO @yichuan give better format and rich info here!
# TODO @yichuan give better format and rich info here!
doc_content = f"""
{concatenated_text}
"""
return doc_content, contact_name
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
"""
Load WeChat chat history data from exported JSON files.
Args:
input_dir: Directory containing exported WeChat JSON files
**load_kwargs:
@@ -406,104 +376,97 @@ Messages ({len(messages)} messages, {message_group["total_length"]} chars):
time_window_minutes (int): Time window in minutes to group messages together (default: 30).
overlap_messages (int): Number of messages to overlap between consecutive groups (default: 2).
"""
docs: list[Document] = []
max_count = load_kwargs.get("max_count", 1000)
wechat_export_dir = load_kwargs.get("wechat_export_dir", None)
include_non_text = load_kwargs.get("include_non_text", False)
concatenate_messages = load_kwargs.get("concatenate_messages", False)
max_length = load_kwargs.get("max_length", 1000)
time_window_minutes = load_kwargs.get("time_window_minutes", 30)
docs: List[Document] = []
max_count = load_kwargs.get('max_count', 1000)
wechat_export_dir = load_kwargs.get('wechat_export_dir', None)
include_non_text = load_kwargs.get('include_non_text', False)
concatenate_messages = load_kwargs.get('concatenate_messages', False)
max_length = load_kwargs.get('max_length', 1000)
time_window_minutes = load_kwargs.get('time_window_minutes', 30)
# Default WeChat export path
if wechat_export_dir is None:
wechat_export_dir = "./wechat_export_test"
if not os.path.exists(wechat_export_dir):
print(f"WeChat export directory not found at: {wechat_export_dir}")
return docs
try:
# Find all JSON files in the export directory
json_files = list(Path(wechat_export_dir).glob("*.json"))
print(f"Found {len(json_files)} WeChat chat history files")
count = 0
for json_file in json_files:
if count >= max_count and max_count > 0:
break
try:
with open(json_file, encoding="utf-8") as f:
with open(json_file, 'r', encoding='utf-8') as f:
chat_data = json.load(f)
# Extract contact name from filename
contact_name = json_file.stem
if concatenate_messages:
# Filter messages to only include readable text messages
readable_messages = []
for message in chat_data:
try:
content = message.get("content", "")
content = message.get('content', '')
if not include_non_text and not self._is_text_message(content):
continue
readable_text = self._extract_readable_text(content)
if not readable_text and not include_non_text:
continue
readable_messages.append(message)
except Exception as e:
print(f"Error processing message in {json_file}: {e}")
continue
# Concatenate messages based on rules
message_groups = self._concatenate_messages(
readable_messages,
max_length=max_length,
time_window_minutes=time_window_minutes,
overlap_messages=0, # No overlap between groups
readable_messages,
max_length=-1,
time_window_minutes=-1,
overlap_messages=0 # Keep 2 messages overlap between groups
)
# Create documents from concatenated groups
for message_group in message_groups:
if count >= max_count and max_count > 0:
break
doc_content, contact_name = self._create_concatenated_content(
message_group, contact_name
)
doc = Document(
text=doc_content,
metadata={"contact_name": contact_name},
)
doc_content, contact_name = self._create_concatenated_content(message_group, contact_name)
doc = Document(text=doc_content, metadata={"contact_name": contact_name})
docs.append(doc)
count += 1
print(
f"Created {len(message_groups)} concatenated message groups for {contact_name}"
)
print(f"Created {len(message_groups)} concatenated message groups for {contact_name}")
else:
# Original single-message processing
for message in chat_data:
if count >= max_count and max_count > 0:
break
# Extract message information
message.get("fromUser", "")
message.get("toUser", "")
content = message.get("content", "")
message_text = message.get("message", "")
create_time = message.get("createTime", 0)
is_sent_from_self = message.get("isSentFromSelf", False)
from_user = message.get('fromUser', '')
to_user = message.get('toUser', '')
content = message.get('content', '')
message_text = message.get('message', '')
create_time = message.get('createTime', 0)
is_sent_from_self = message.get('isSentFromSelf', False)
# Handle content that might be dict or string
try:
# Check if this is a readable text message
if not include_non_text and not self._is_text_message(content):
continue
# Extract readable text
readable_text = self._extract_readable_text(content)
if not readable_text and not include_non_text:
@@ -512,17 +475,17 @@ Messages ({len(messages)} messages, {message_group["total_length"]} chars):
# Skip messages that cause processing errors
print(f"Error processing message in {json_file}: {e}")
continue
# Convert timestamp to readable format
if create_time:
try:
timestamp = datetime.fromtimestamp(create_time)
time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
except (ValueError, OSError):
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
except:
time_str = str(create_time)
else:
time_str = "Unknown"
# Create document content with metadata header and contact info
doc_content = f"""
Contact: {contact_name}
@@ -530,66 +493,57 @@ Is sent from self: {is_sent_from_self}
Time: {time_str}
Message: {readable_text if readable_text else message_text}
"""
# Create document with embedded metadata
doc = Document(
text=doc_content, metadata={"contact_name": contact_name}
)
doc = Document(text=doc_content, metadata={})
docs.append(doc)
count += 1
except Exception as e:
print(f"Error reading {json_file}: {e}")
continue
print(f"Loaded {len(docs)} WeChat chat documents")
except Exception as e:
print(f"Error reading WeChat history: {e}")
return docs
return docs
@staticmethod
def find_wechat_export_dirs() -> list[Path]:
def find_wechat_export_dirs() -> List[Path]:
"""
Find all WeChat export directories.
Returns:
List of Path objects pointing to WeChat export directories
"""
export_dirs = []
# Look for common export directory names
possible_dirs = [
Path("./wechat_export_test"),
Path("./wechat_export"),
Path("./wechat_export_direct"),
Path("./wechat_chat_history"),
Path("./chat_export"),
Path("./chat_export")
]
for export_dir in possible_dirs:
if export_dir.exists() and export_dir.is_dir():
json_files = list(export_dir.glob("*.json"))
if json_files:
export_dirs.append(export_dir)
print(
f"Found WeChat export directory: {export_dir} with {len(json_files)} files"
)
print(f"Found WeChat export directory: {export_dir} with {len(json_files)} files")
print(f"Found {len(export_dirs)} WeChat export directories")
return export_dirs
@staticmethod
def export_chat_to_file(
output_file: str = "wechat_chat_export.txt",
max_count: int = 1000,
export_dir: str | None = None,
include_non_text: bool = False,
):
def export_chat_to_file(output_file: str = "wechat_chat_export.txt", max_count: int = 1000, export_dir: str = None, include_non_text: bool = False):
"""
Export WeChat chat history to a text file.
Args:
output_file: Path to the output file
max_count: Maximum number of entries to export
@@ -598,36 +552,36 @@ Message: {readable_text if readable_text else message_text}
"""
if export_dir is None:
export_dir = "./wechat_export_test"
if not os.path.exists(export_dir):
print(f"WeChat export directory not found at: {export_dir}")
return
try:
json_files = list(Path(export_dir).glob("*.json"))
with open(output_file, "w", encoding="utf-8") as f:
with open(output_file, 'w', encoding='utf-8') as f:
count = 0
for json_file in json_files:
if count >= max_count and max_count > 0:
break
try:
with open(json_file, encoding="utf-8") as json_f:
with open(json_file, 'r', encoding='utf-8') as json_f:
chat_data = json.load(json_f)
contact_name = json_file.stem
f.write(f"\n=== Chat with {contact_name} ===\n")
for message in chat_data:
if count >= max_count and max_count > 0:
break
from_user = message.get("fromUser", "")
content = message.get("content", "")
message_text = message.get("message", "")
create_time = message.get("createTime", 0)
from_user = message.get('fromUser', '')
content = message.get('content', '')
message_text = message.get('message', '')
create_time = message.get('createTime', 0)
# Skip non-text messages unless requested
if not include_non_text:
reader = WeChatHistoryReader()
@@ -637,90 +591,83 @@ Message: {readable_text if readable_text else message_text}
if not readable_text:
continue
message_text = readable_text
if create_time:
try:
timestamp = datetime.fromtimestamp(create_time)
time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
except (ValueError, OSError):
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
except:
time_str = str(create_time)
else:
time_str = "Unknown"
f.write(f"[{time_str}] {from_user}: {message_text}\n")
count += 1
except Exception as e:
print(f"Error processing {json_file}: {e}")
continue
print(f"Exported {count} chat entries to {output_file}")
except Exception as e:
print(f"Error exporting WeChat chat history: {e}")
def export_wechat_chat_history(self, export_dir: str = "./wechat_export_direct") -> Path | None:
def export_wechat_chat_history(self, export_dir: str = "./wechat_export_direct") -> Optional[Path]:
"""
Export WeChat chat history using wechat-exporter tool.
Args:
export_dir: Directory to save exported chat history
Returns:
Path to export directory if successful, None otherwise
"""
try:
import subprocess
import sys
# Create export directory
export_path = Path(export_dir)
export_path.mkdir(exist_ok=True)
print(f"Exporting WeChat chat history to {export_path}...")
# Check if wechat-exporter directory exists
if not self.wechat_exporter_dir.exists():
print(f"wechat-exporter directory not found at: {self.wechat_exporter_dir}")
return None
# Install requirements if needed
requirements_file = self.wechat_exporter_dir / "requirements.txt"
if requirements_file.exists():
print("Installing wechat-exporter requirements...")
subprocess.run(["uv", "pip", "install", "-r", str(requirements_file)], check=True)
subprocess.run([
"uv", "pip", "install", "-r", str(requirements_file)
], check=True)
# Run the export command
print("Running wechat-exporter...")
result = subprocess.run(
[
sys.executable,
str(self.wechat_exporter_dir / "main.py"),
"export-all",
str(export_path),
],
capture_output=True,
text=True,
check=True,
)
result = subprocess.run([
sys.executable, str(self.wechat_exporter_dir / "main.py"),
"export-all", str(export_path)
], capture_output=True, text=True, check=True)
print("Export command output:")
print(result.stdout)
if result.stderr:
print("Export errors:")
print(result.stderr)
# Check if export was successful
if export_path.exists() and any(export_path.glob("*.json")):
json_files = list(export_path.glob("*.json"))
print(
f"Successfully exported {len(json_files)} chat history files to {export_path}"
)
print(f"Successfully exported {len(json_files)} chat history files to {export_path}")
return export_path
else:
print("Export completed but no JSON files found")
return None
except subprocess.CalledProcessError as e:
print(f"Export command failed: {e}")
print(f"Command output: {e.stdout}")
@@ -731,18 +678,18 @@ Message: {readable_text if readable_text else message_text}
print("Please ensure WeChat is running and WeChatTweak is installed.")
return None
def find_or_export_wechat_data(self, export_dir: str = "./wechat_export_direct") -> list[Path]:
def find_or_export_wechat_data(self, export_dir: str = "./wechat_export_direct") -> List[Path]:
"""
Find existing WeChat exports or create new ones.
Args:
export_dir: Directory to save exported chat history if needed
Returns:
List of Path objects pointing to WeChat export directories
"""
export_dirs = []
# Look for existing exports in common locations
possible_export_dirs = [
Path("./wechat_database_export"),
@@ -750,25 +697,23 @@ Message: {readable_text if readable_text else message_text}
Path("./wechat_export"),
Path("./wechat_export_direct"),
Path("./wechat_chat_history"),
Path("./chat_export"),
Path("./chat_export")
]
for export_dir_path in possible_export_dirs:
if export_dir_path.exists() and any(export_dir_path.glob("*.json")):
export_dirs.append(export_dir_path)
print(f"Found existing export: {export_dir_path}")
# If no existing exports, try to export automatically
if not export_dirs:
print("No existing WeChat exports found. Starting direct export...")
# Try to export using wechat-exporter
exported_path = self.export_wechat_chat_history(export_dir)
if exported_path:
export_dirs = [exported_path]
else:
print(
"Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed."
)
return export_dirs
print("Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.")
return export_dirs

View File

@@ -0,0 +1,291 @@
import os
import sys
import asyncio
import dotenv
import argparse
from pathlib import Path
from typing import List, Any
# Add the project root to Python path so we can import from examples
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from leann.api import LeannBuilder, LeannSearcher, LeannChat
from llama_index.core.node_parser import SentenceSplitter
dotenv.load_dotenv()
# Auto-detect user's mail path
def get_mail_path():
"""Get the mail path for the current user"""
home_dir = os.path.expanduser("~")
return os.path.join(home_dir, "Library", "Mail")
# Default mail path for macOS
DEFAULT_MAIL_PATH = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data"
def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_path: str = "mail_index.leann", max_count: int = -1, include_html: bool = False, embedding_model: str = "facebook/contriever"):
"""
Create LEANN index from multiple mail data sources.
Args:
messages_dirs: List of Path objects pointing to Messages directories
index_path: Path to save the LEANN index
max_count: Maximum number of emails to process per directory
include_html: Whether to include HTML content in email processing
"""
print("Creating LEANN index from multiple mail data sources...")
# Load documents using EmlxReader from LEANN_email_reader
from examples.email_data.LEANN_email_reader import EmlxReader
reader = EmlxReader(include_html=include_html)
# from email_data.email import EmlxMboxReader
# from pathlib import Path
# reader = EmlxMboxReader()
INDEX_DIR = Path(index_path).parent
if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---")
all_documents = []
total_processed = 0
# Process each Messages directory
for i, messages_dir in enumerate(messages_dirs):
print(f"\nProcessing Messages directory {i+1}/{len(messages_dirs)}: {messages_dir}")
try:
documents = reader.load_data(messages_dir)
if documents:
print(f"Loaded {len(documents)} email documents from {messages_dir}")
all_documents.extend(documents)
total_processed += len(documents)
# Check if we've reached the max count
if max_count > 0 and total_processed >= max_count:
print(f"Reached max count of {max_count} documents")
break
else:
print(f"No documents loaded from {messages_dir}")
except Exception as e:
print(f"Error processing {messages_dir}: {e}")
continue
if not all_documents:
print("No documents loaded from any source. Exiting.")
return None
print(f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories and starting to split them into chunks")
# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
# Convert Documents to text strings and chunk them
all_texts = []
for doc in all_documents:
# Split the document into chunks
nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes:
text = node.get_content()
# text = '[subject] ' + doc.metadata["subject"] + '\n' + text
all_texts.append(text)
print(f"Finished splitting {len(all_documents)} documents into {len(all_texts)} text chunks")
# Create LEANN index directory
print(f"--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=embedding_model,
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1 # Force single-threaded mode
)
print(f"Adding {len(all_texts)} email chunks to index...")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(index_path)
print(f"\nLEANN index built at {index_path}!")
else:
print(f"--- Using existing index at {INDEX_DIR} ---")
return index_path
def create_leann_index(mail_path: str, index_path: str = "mail_index.leann", max_count: int = 1000, include_html: bool = False, embedding_model: str = "facebook/contriever"):
"""
Create LEANN index from mail data.
Args:
mail_path: Path to the mail directory
index_path: Path to save the LEANN index
max_count: Maximum number of emails to process
include_html: Whether to include HTML content in email processing
"""
print("Creating LEANN index from mail data...")
INDEX_DIR = Path(index_path).parent
if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...")
# Load documents using EmlxReader from LEANN_email_reader
from examples.email_data.LEANN_email_reader import EmlxReader
reader = EmlxReader(include_html=include_html)
# from email_data.email import EmlxMboxReader
# from pathlib import Path
# reader = EmlxMboxReader()
documents = reader.load_data(Path(mail_path))
if not documents:
print("No documents loaded. Exiting.")
return None
print(f"Loaded {len(documents)} email documents")
# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
# Convert Documents to text strings and chunk them
all_texts = []
for doc in documents:
# Split the document into chunks
nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
# Create LEANN index directory
print(f"--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=embedding_model,
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1 # Force single-threaded mode
)
print(f"Adding {len(all_texts)} email chunks to index...")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(index_path)
print(f"\nLEANN index built at {index_path}!")
else:
print(f"--- Using existing index at {INDEX_DIR} ---")
return index_path
async def query_leann_index(index_path: str, query: str):
"""
Query the LEANN index.
Args:
index_path: Path to the LEANN index
query: The query string
"""
print(f"\n[PHASE 2] Starting Leann chat session...")
chat = LeannChat(index_path=index_path,
llm_config={"type": "openai", "model": "gpt-4o"})
print(f"You: {query}")
import time
start_time = time.time()
chat_response = chat.ask(
query,
top_k=20,
recompute_beighbor_embeddings=True,
complexity=32,
beam_width=1,
)
end_time = time.time()
# print(f"Time taken: {end_time - start_time} seconds")
# highlight the answer
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
async def main():
# Parse command line arguments
parser = argparse.ArgumentParser(description='LEANN Mail Reader - Create and query email index')
# Remove --mail-path argument and auto-detect all Messages directories
# Remove DEFAULT_MAIL_PATH
parser.add_argument('--index-dir', type=str, default="./mail_index",
help='Directory to store the LEANN index (default: ./mail_index_leann_raw_text_all_dicts)')
parser.add_argument('--max-emails', type=int, default=1000,
help='Maximum number of emails to process (-1 means all)')
parser.add_argument('--query', type=str, default="Give me some funny advertisement about apple or other companies",
help='Single query to run (default: runs example queries)')
parser.add_argument('--include-html', action='store_true', default=False,
help='Include HTML content in email processing (default: False)')
parser.add_argument('--embedding-model', type=str, default="facebook/contriever",
help='Embedding model to use (default: facebook/contriever)')
args = parser.parse_args()
print(f"args: {args}")
# Automatically find all Messages directories under the current user's Mail directory
from examples.email_data.LEANN_email_reader import find_all_messages_directories
mail_path = get_mail_path()
print(f"Searching for email data in: {mail_path}")
messages_dirs = find_all_messages_directories(mail_path)
# messages_dirs = find_all_messages_directories(DEFAULT_MAIL_PATH)
# messages_dirs = [DEFAULT_MAIL_PATH]
# messages_dirs = messages_dirs[:1]
print('len(messages_dirs): ', len(messages_dirs))
if not messages_dirs:
print("No Messages directories found. Exiting.")
return
INDEX_DIR = Path(args.index_dir)
INDEX_PATH = str(INDEX_DIR / "mail_documents.leann")
print(f"Index directory: {INDEX_DIR}")
print(f"Found {len(messages_dirs)} Messages directories.")
# Create or load the LEANN index from all sources
index_path = create_leann_index_from_multiple_sources(messages_dirs, INDEX_PATH, args.max_emails, args.include_html, args.embedding_model)
if index_path:
if args.query:
# Run single query
await query_leann_index(index_path, args.query)
else:
# Example queries
queries = [
"Hows Berkeley Graduate Student Instructor",
"how's the icloud related advertisement saying",
"Whats the number of class recommend to take per semester for incoming EECS students"
]
for query in queries:
print("\n" + "="*60)
await query_leann_index(index_path, query)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,108 @@
import os
import sys
import argparse
from pathlib import Path
from typing import List, Any
# Add the project root to Python path so we can import from examples
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from llama_index.core import VectorStoreIndex, StorageContext
from llama_index.core.node_parser import SentenceSplitter
# --- EMBEDDING MODEL ---
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
import torch
# --- END EMBEDDING MODEL ---
# Import EmlxReader from the new module
from examples.email_data.LEANN_email_reader import EmlxReader
def create_and_save_index(mail_path: str, save_dir: str = "mail_index_embedded", max_count: int = 1000, include_html: bool = False):
print("Creating index from mail data with embedded metadata...")
documents = EmlxReader(include_html=include_html).load_data(mail_path, max_count=max_count)
if not documents:
print("No documents loaded. Exiting.")
return None
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
# Use facebook/contriever as the embedder
embed_model = HuggingFaceEmbedding(model_name="facebook/contriever")
# set on device
import torch
if torch.cuda.is_available():
embed_model._model.to("cuda")
# set mps
elif torch.backends.mps.is_available():
embed_model._model.to("mps")
else:
embed_model._model.to("cpu")
index = VectorStoreIndex.from_documents(
documents,
transformations=[text_splitter],
embed_model=embed_model
)
os.makedirs(save_dir, exist_ok=True)
index.storage_context.persist(persist_dir=save_dir)
print(f"Index saved to {save_dir}")
return index
def load_index(save_dir: str = "mail_index_embedded"):
try:
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
index = VectorStoreIndex.from_vector_store(
storage_context.vector_store,
storage_context=storage_context
)
print(f"Index loaded from {save_dir}")
return index
except Exception as e:
print(f"Error loading index: {e}")
return None
def query_index(index, query: str):
if index is None:
print("No index available for querying.")
return
query_engine = index.as_query_engine()
response = query_engine.query(query)
print(f"Query: {query}")
print(f"Response: {response}")
def main():
# Parse command line arguments
parser = argparse.ArgumentParser(description='LlamaIndex Mail Reader - Create and query email index')
parser.add_argument('--mail-path', type=str,
default="/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages",
help='Path to mail data directory')
parser.add_argument('--save-dir', type=str, default="mail_index_embedded",
help='Directory to store the index (default: mail_index_embedded)')
parser.add_argument('--max-emails', type=int, default=10000,
help='Maximum number of emails to process')
parser.add_argument('--include-html', action='store_true', default=False,
help='Include HTML content in email processing (default: False)')
args = parser.parse_args()
mail_path = args.mail_path
save_dir = args.save_dir
if os.path.exists(save_dir) and os.path.exists(os.path.join(save_dir, "vector_store.json")):
print("Loading existing index...")
index = load_index(save_dir)
else:
print("Creating new index...")
index = create_and_save_index(mail_path, save_dir, max_count=args.max_emails, include_html=args.include_html)
if index:
queries = [
"Hows Berkeley Graduate Student Instructor",
"how's the icloud related advertisement saying",
"Whats the number of class recommend to take per semester for incoming EECS students"
]
for query in queries:
print("\n" + "="*50)
query_index(index, query)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,119 @@
import argparse
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
import asyncio
import dotenv
from leann.api import LeannBuilder, LeannChat
from pathlib import Path
dotenv.load_dotenv()
async def main(args):
INDEX_DIR = Path(args.index_dir)
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
if not INDEX_DIR.exists():
node_parser = SentenceSplitter(
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
)
print("Loading documents...")
documents = SimpleDirectoryReader(
args.data_dir,
recursive=True,
encoding="utf-8",
required_exts=[".pdf", ".txt", ".md"],
).load_data(show_progress=True)
print("Documents loaded.")
all_texts = []
for doc in documents:
nodes = node_parser.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
print("--- Index directory not found, building new index ---")
print("\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="facebook/contriever",
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1, # Force single-threaded mode
)
print(f"Loaded {len(all_texts)} text chunks from documents.")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(INDEX_PATH)
print(f"\nLeann index built at {INDEX_PATH}!")
else:
print(f"--- Using existing index at {INDEX_DIR} ---")
print(f"\n[PHASE 2] Starting Leann chat session...")
llm_config = {"type": "hf", "model": "Qwen/Qwen3-4B"}
llm_config = {"type": "ollama", "model": "qwen3:8b"}
llm_config = {"type": "openai", "model": "gpt-4o"}
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
# query = (
# "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
# )
query = args.query
print(f"You: {query}")
chat_response = chat.ask(query, top_k=20, recompute_embeddings=True, complexity=32)
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Run Leann Chat with various LLM backends."
)
parser.add_argument(
"--llm",
type=str,
default="hf",
choices=["simulated", "ollama", "hf", "openai"],
help="The LLM backend to use.",
)
parser.add_argument(
"--model",
type=str,
default="Qwen/Qwen3-0.6B",
help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf, 'gpt-4o' for openai).",
)
parser.add_argument(
"--host",
type=str,
default="http://localhost:11434",
help="The host for the Ollama API.",
)
parser.add_argument(
"--index-dir",
type=str,
default="./test_doc_files",
help="Directory where the Leann index will be stored.",
)
parser.add_argument(
"--data-dir",
type=str,
default="examples/data",
help="Directory containing documents to index (PDF, TXT, MD files).",
)
parser.add_argument(
"--query",
type=str,
default="Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?",
help="The query to ask the Leann chat system.",
)
args = parser.parse_args()
asyncio.run(main(args))

View File

@@ -0,0 +1,319 @@
#!/usr/bin/env python3
"""
Multi-Vector Aggregator for Fat Embeddings
==========================================
This module implements aggregation strategies for multi-vector embeddings,
similar to ColPali's approach where multiple patch vectors represent a single document.
Key features:
- MaxSim aggregation (take maximum similarity across patches)
- Voting-based aggregation (count patch matches)
- Weighted aggregation (attention-score weighted)
- Spatial clustering of matching patches
- Document-level result consolidation
"""
import numpy as np
from typing import List, Dict, Any, Tuple, Optional
from dataclasses import dataclass
from collections import defaultdict
import json
@dataclass
class PatchResult:
"""Represents a single patch search result."""
patch_id: int
image_name: str
image_path: str
coordinates: Tuple[int, int, int, int] # (x1, y1, x2, y2)
score: float
attention_score: float
scale: float
metadata: Dict[str, Any]
@dataclass
class AggregatedResult:
"""Represents an aggregated document-level result."""
image_name: str
image_path: str
doc_score: float
patch_count: int
best_patch: PatchResult
all_patches: List[PatchResult]
aggregation_method: str
spatial_clusters: Optional[List[List[PatchResult]]] = None
class MultiVectorAggregator:
"""
Aggregates multiple patch-level results into document-level results.
"""
def __init__(self,
aggregation_method: str = "maxsim",
spatial_clustering: bool = True,
cluster_distance_threshold: float = 100.0):
"""
Initialize the aggregator.
Args:
aggregation_method: "maxsim", "voting", "weighted", or "mean"
spatial_clustering: Whether to cluster spatially close patches
cluster_distance_threshold: Distance threshold for spatial clustering
"""
self.aggregation_method = aggregation_method
self.spatial_clustering = spatial_clustering
self.cluster_distance_threshold = cluster_distance_threshold
def aggregate_results(self,
search_results: List[Dict[str, Any]],
top_k: int = 10) -> List[AggregatedResult]:
"""
Aggregate patch-level search results into document-level results.
Args:
search_results: List of search results from LeannSearcher
top_k: Number of top documents to return
Returns:
List of aggregated document results
"""
# Group results by image
image_groups = defaultdict(list)
for result in search_results:
metadata = result.metadata
if "image_name" in metadata and "patch_id" in metadata:
patch_result = PatchResult(
patch_id=metadata["patch_id"],
image_name=metadata["image_name"],
image_path=metadata["image_path"],
coordinates=tuple(metadata["coordinates"]),
score=result.score,
attention_score=metadata.get("attention_score", 0.0),
scale=metadata.get("scale", 1.0),
metadata=metadata
)
image_groups[metadata["image_name"]].append(patch_result)
# Aggregate each image group
aggregated_results = []
for image_name, patches in image_groups.items():
if len(patches) == 0:
continue
agg_result = self._aggregate_image_patches(image_name, patches)
aggregated_results.append(agg_result)
# Sort by aggregated score and return top-k
aggregated_results.sort(key=lambda x: x.doc_score, reverse=True)
return aggregated_results[:top_k]
def _aggregate_image_patches(self, image_name: str, patches: List[PatchResult]) -> AggregatedResult:
"""Aggregate patches for a single image."""
if self.aggregation_method == "maxsim":
doc_score = max(patch.score for patch in patches)
best_patch = max(patches, key=lambda p: p.score)
elif self.aggregation_method == "voting":
# Count patches above threshold
threshold = np.percentile([p.score for p in patches], 75)
doc_score = sum(1 for patch in patches if patch.score >= threshold)
best_patch = max(patches, key=lambda p: p.score)
elif self.aggregation_method == "weighted":
# Weight by attention scores
total_weighted_score = sum(p.score * p.attention_score for p in patches)
total_weights = sum(p.attention_score for p in patches)
doc_score = total_weighted_score / max(total_weights, 1e-8)
best_patch = max(patches, key=lambda p: p.score * p.attention_score)
elif self.aggregation_method == "mean":
doc_score = np.mean([patch.score for patch in patches])
best_patch = max(patches, key=lambda p: p.score)
else:
raise ValueError(f"Unknown aggregation method: {self.aggregation_method}")
# Spatial clustering if enabled
spatial_clusters = None
if self.spatial_clustering:
spatial_clusters = self._cluster_patches_spatially(patches)
return AggregatedResult(
image_name=image_name,
image_path=patches[0].image_path,
doc_score=float(doc_score),
patch_count=len(patches),
best_patch=best_patch,
all_patches=sorted(patches, key=lambda p: p.score, reverse=True),
aggregation_method=self.aggregation_method,
spatial_clusters=spatial_clusters
)
def _cluster_patches_spatially(self, patches: List[PatchResult]) -> List[List[PatchResult]]:
"""Cluster patches that are spatially close to each other."""
if len(patches) <= 1:
return [patches]
clusters = []
remaining_patches = patches.copy()
while remaining_patches:
# Start new cluster with highest scoring remaining patch
seed_patch = max(remaining_patches, key=lambda p: p.score)
current_cluster = [seed_patch]
remaining_patches.remove(seed_patch)
# Add nearby patches to cluster
added_to_cluster = True
while added_to_cluster:
added_to_cluster = False
for patch in remaining_patches.copy():
if self._is_patch_nearby(patch, current_cluster):
current_cluster.append(patch)
remaining_patches.remove(patch)
added_to_cluster = True
clusters.append(current_cluster)
return sorted(clusters, key=lambda cluster: max(p.score for p in cluster), reverse=True)
def _is_patch_nearby(self, patch: PatchResult, cluster: List[PatchResult]) -> bool:
"""Check if a patch is spatially close to any patch in the cluster."""
patch_center = self._get_patch_center(patch.coordinates)
for cluster_patch in cluster:
cluster_center = self._get_patch_center(cluster_patch.coordinates)
distance = np.sqrt((patch_center[0] - cluster_center[0])**2 +
(patch_center[1] - cluster_center[1])**2)
if distance <= self.cluster_distance_threshold:
return True
return False
def _get_patch_center(self, coordinates: Tuple[int, int, int, int]) -> Tuple[float, float]:
"""Get center point of a patch."""
x1, y1, x2, y2 = coordinates
return ((x1 + x2) / 2, (y1 + y2) / 2)
def print_aggregated_results(self, results: List[AggregatedResult], max_patches_per_doc: int = 3):
"""Pretty print aggregated results."""
print(f"\n🔍 Aggregated Results (method: {self.aggregation_method})")
print("=" * 80)
for i, result in enumerate(results):
print(f"\n{i+1}. {result.image_name}")
print(f" Doc Score: {result.doc_score:.4f} | Patches: {result.patch_count}")
print(f" Path: {result.image_path}")
# Show best patch
best = result.best_patch
print(f" 🌟 Best Patch: #{best.patch_id} at {best.coordinates} (score: {best.score:.4f})")
# Show top patches
print(f" 📍 Top Patches:")
for j, patch in enumerate(result.all_patches[:max_patches_per_doc]):
print(f" {j+1}. Patch #{patch.patch_id}: {patch.score:.4f} at {patch.coordinates}")
# Show spatial clusters if available
if result.spatial_clusters and len(result.spatial_clusters) > 1:
print(f" 🗂️ Spatial Clusters: {len(result.spatial_clusters)}")
for j, cluster in enumerate(result.spatial_clusters[:2]): # Show top 2 clusters
cluster_score = max(p.score for p in cluster)
print(f" Cluster {j+1}: {len(cluster)} patches (best: {cluster_score:.4f})")
def demo_aggregation():
"""Demonstrate the multi-vector aggregation functionality."""
print("=== Multi-Vector Aggregation Demo ===")
# Simulate some patch-level search results
# In real usage, these would come from LeannSearcher.search()
class MockResult:
def __init__(self, score, metadata):
self.score = score
self.metadata = metadata
# Simulate results for 2 images with multiple patches each
mock_results = [
# Image 1: cats_and_kitchen.jpg - 4 patches
MockResult(0.85, {
"image_name": "cats_and_kitchen.jpg",
"image_path": "/path/to/cats_and_kitchen.jpg",
"patch_id": 3,
"coordinates": [100, 50, 224, 174], # Kitchen area
"attention_score": 0.92,
"scale": 1.0
}),
MockResult(0.78, {
"image_name": "cats_and_kitchen.jpg",
"image_path": "/path/to/cats_and_kitchen.jpg",
"patch_id": 7,
"coordinates": [200, 300, 324, 424], # Cat area
"attention_score": 0.88,
"scale": 1.0
}),
MockResult(0.72, {
"image_name": "cats_and_kitchen.jpg",
"image_path": "/path/to/cats_and_kitchen.jpg",
"patch_id": 12,
"coordinates": [150, 100, 274, 224], # Appliances
"attention_score": 0.75,
"scale": 1.0
}),
MockResult(0.65, {
"image_name": "cats_and_kitchen.jpg",
"image_path": "/path/to/cats_and_kitchen.jpg",
"patch_id": 15,
"coordinates": [50, 250, 174, 374], # Furniture
"attention_score": 0.70,
"scale": 1.0
}),
# Image 2: city_street.jpg - 3 patches
MockResult(0.68, {
"image_name": "city_street.jpg",
"image_path": "/path/to/city_street.jpg",
"patch_id": 2,
"coordinates": [300, 100, 424, 224], # Buildings
"attention_score": 0.80,
"scale": 1.0
}),
MockResult(0.62, {
"image_name": "city_street.jpg",
"image_path": "/path/to/city_street.jpg",
"patch_id": 8,
"coordinates": [100, 350, 224, 474], # Street level
"attention_score": 0.75,
"scale": 1.0
}),
MockResult(0.55, {
"image_name": "city_street.jpg",
"image_path": "/path/to/city_street.jpg",
"patch_id": 11,
"coordinates": [400, 200, 524, 324], # Sky area
"attention_score": 0.60,
"scale": 1.0
}),
]
# Test different aggregation methods
methods = ["maxsim", "voting", "weighted", "mean"]
for method in methods:
print(f"\n{'='*20} {method.upper()} AGGREGATION {'='*20}")
aggregator = MultiVectorAggregator(
aggregation_method=method,
spatial_clustering=True,
cluster_distance_threshold=100.0
)
aggregated = aggregator.aggregate_results(mock_results, top_k=5)
aggregator.print_aggregated_results(aggregated)
if __name__ == "__main__":
demo_aggregation()

View File

@@ -0,0 +1,108 @@
#!/usr/bin/env python3
"""
OpenAI Embedding Example
Complete example showing how to build and search with OpenAI embeddings using HNSW backend.
"""
import os
import dotenv
from pathlib import Path
from leann.api import LeannBuilder, LeannSearcher
# Load environment variables
dotenv.load_dotenv()
def main():
# Check if OpenAI API key is available
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
print("ERROR: OPENAI_API_KEY environment variable not set")
return False
print(f"✅ OpenAI API key found: {api_key[:10]}...")
# Sample texts
sample_texts = [
"Machine learning is a powerful technology that enables computers to learn from data.",
"Natural language processing helps computers understand and generate human language.",
"Deep learning uses neural networks with multiple layers to solve complex problems.",
"Computer vision allows machines to interpret and understand visual information.",
"Reinforcement learning trains agents to make decisions through trial and error.",
"Data science combines statistics, math, and programming to extract insights from data.",
"Artificial intelligence aims to create machines that can perform human-like tasks.",
"Python is a popular programming language used extensively in data science and AI.",
"Neural networks are inspired by the structure and function of the human brain.",
"Big data refers to extremely large datasets that require special tools to process."
]
INDEX_DIR = Path("./simple_openai_test_index")
INDEX_PATH = str(INDEX_DIR / "simple_test.leann")
print(f"\n=== Building Index with OpenAI Embeddings ===")
print(f"Index path: {INDEX_PATH}")
try:
# Use proper configuration for OpenAI embeddings
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="text-embedding-3-small",
embedding_mode="openai",
# HNSW settings for OpenAI embeddings
M=16, # Smaller graph degree
efConstruction=64, # Smaller construction complexity
is_compact=True, # Enable compact storage for recompute
is_recompute=True, # MUST enable for OpenAI embeddings
num_threads=1,
)
print(f"Adding {len(sample_texts)} texts to the index...")
for i, text in enumerate(sample_texts):
metadata = {"id": f"doc_{i}", "topic": "AI"}
builder.add_text(text, metadata)
print("Building index...")
builder.build_index(INDEX_PATH)
print(f"✅ Index built successfully!")
except Exception as e:
print(f"❌ Error building index: {e}")
import traceback
traceback.print_exc()
return False
print(f"\n=== Testing Search ===")
try:
searcher = LeannSearcher(INDEX_PATH)
test_queries = [
"What is machine learning?",
"How do neural networks work?",
"Programming languages for data science"
]
for query in test_queries:
print(f"\n🔍 Query: '{query}'")
results = searcher.search(query, top_k=3)
print(f" Found {len(results)} results:")
for i, result in enumerate(results):
print(f" {i+1}. Score: {result.score:.4f}")
print(f" Text: {result.text[:80]}...")
print(f"\n✅ Search test completed successfully!")
return True
except Exception as e:
print(f"❌ Error during search: {e}")
import traceback
traceback.print_exc()
return False
if __name__ == "__main__":
success = main()
if success:
print(f"\n🎉 Simple OpenAI index test completed successfully!")
else:
print(f"\n💥 Simple OpenAI index test failed!")

18
examples/resue_index.py Normal file
View File

@@ -0,0 +1,18 @@
import asyncio
from leann.api import LeannChat
from pathlib import Path
INDEX_DIR = Path("./test_pdf_index_huawei")
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
async def main():
print(f"\n[PHASE 2] Starting Leann chat session...")
chat = LeannChat(index_path=INDEX_PATH)
query = "What is the main idea of RL and give me 5 exapmle of classic RL algorithms?"
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
# query = "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
response = chat.ask(query,top_k=20,recompute_beighbor_embeddings=True,complexity=32,beam_width=1)
print(f"\n[PHASE 2] Response: {response}")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -5,21 +5,24 @@ It correctly compares results by fetching the text content for both the new sear
results and the golden standard results, making the comparison robust to ID changes.
"""
import argparse
import json
import sys
import argparse
import time
from pathlib import Path
import sys
import numpy as np
from leann.api import LeannBuilder, LeannChat, LeannSearcher
from typing import List
from leann.api import LeannSearcher, LeannBuilder
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
"""Checks if the data directory exists, and if not, downloads it from HF Hub."""
if not data_root.exists():
print(f"Data directory '{data_root}' not found.")
print("Downloading evaluation data from Hugging Face Hub... (this may take a moment)")
print(
"Downloading evaluation data from Hugging Face Hub... (this may take a moment)"
)
try:
from huggingface_hub import snapshot_download
@@ -60,7 +63,7 @@ def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
sys.exit(1)
def download_embeddings_if_needed(data_root: Path, dataset_type: str | None = None):
def download_embeddings_if_needed(data_root: Path, dataset_type: str = None):
"""Download embeddings files specifically."""
embeddings_dir = data_root / "embeddings"
@@ -98,7 +101,7 @@ def download_embeddings_if_needed(data_root: Path, dataset_type: str | None = No
# --- Helper Function to get Golden Passages ---
def get_golden_texts(searcher: LeannSearcher, golden_ids: list[int]) -> set:
def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set:
"""
Retrieves the text for golden passage IDs directly from the LeannSearcher's
passage manager.
@@ -110,20 +113,24 @@ def get_golden_texts(searcher: LeannSearcher, golden_ids: list[int]) -> set:
passage_data = searcher.passage_manager.get_passage(str(gid))
golden_texts.add(passage_data["text"])
except KeyError:
print(f"Warning: Golden passage ID '{gid}' not found in the index's passage data.")
print(
f"Warning: Golden passage ID '{gid}' not found in the index's passage data."
)
return golden_texts
def load_queries(file_path: Path) -> list[str]:
def load_queries(file_path: Path) -> List[str]:
queries = []
with open(file_path, encoding="utf-8") as f:
with open(file_path, "r", encoding="utf-8") as f:
for line in f:
data = json.loads(line)
queries.append(data["query"])
return queries
def build_index_from_embeddings(embeddings_file: str, output_path: str, backend: str = "hnsw"):
def build_index_from_embeddings(
embeddings_file: str, output_path: str, backend: str = "hnsw"
):
"""
Build a LEANN index from pre-computed embeddings.
@@ -166,7 +173,9 @@ def build_index_from_embeddings(embeddings_file: str, output_path: str, backend:
def main():
parser = argparse.ArgumentParser(description="Run recall evaluation on a LEANN index.")
parser = argparse.ArgumentParser(
description="Run recall evaluation on a LEANN index."
)
parser.add_argument(
"index_path",
type=str,
@@ -193,41 +202,26 @@ def main():
parser.add_argument(
"--num-queries", type=int, default=10, help="Number of queries to evaluate."
)
parser.add_argument("--top-k", type=int, default=3, help="The 'k' value for recall@k.")
parser.add_argument(
"--top-k", type=int, default=3, help="The 'k' value for recall@k."
)
parser.add_argument(
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
)
parser.add_argument(
"--batch-size",
type=int,
default=0,
help="Batch size for HNSW batched search (0 disables batching)",
)
parser.add_argument(
"--llm-type",
type=str,
choices=["ollama", "hf", "openai", "gemini", "simulated"],
default="ollama",
help="LLM backend type to optionally query during evaluation (default: ollama)",
)
parser.add_argument(
"--llm-model",
type=str,
default="qwen3:1.7b",
help="LLM model identifier for the chosen backend (default: qwen3:1.7b)",
)
args = parser.parse_args()
# --- Path Configuration ---
# Assumes a project structure where the script is in 'benchmarks/'
# and evaluation data is in 'benchmarks/data/'.
script_dir = Path(__file__).resolve().parent
data_root = script_dir / "data"
# Assumes a project structure where the script is in 'examples/'
# and data is in 'data/' at the project root.
project_root = Path(__file__).resolve().parent.parent
data_root = project_root / "data"
# Download data based on mode
if args.mode == "build":
# For building mode, we need embeddings
download_data_if_needed(data_root, download_embeddings=False) # Basic data first
download_data_if_needed(
data_root, download_embeddings=False
) # Basic data first
# Auto-detect dataset type and download embeddings
if args.embeddings_file:
@@ -268,7 +262,9 @@ def main():
print(f"Index built successfully: {built_index_path}")
# Ask if user wants to run evaluation
eval_response = input("Run evaluation on the built index? (y/n): ").strip().lower()
eval_response = (
input("Run evaluation on the built index? (y/n): ").strip().lower()
)
if eval_response != "y":
print("Index building complete. Exiting.")
return
@@ -297,9 +293,11 @@ def main():
break
if not args.index_path:
print("No indices found. The data download should have included pre-built indices.")
print(
"Please check the benchmarks/data/indices/ directory or provide --index-path manually."
"No indices found. The data download should have included pre-built indices."
)
print(
"Please check the data/indices/ directory or provide --index-path manually."
)
sys.exit(1)
@@ -312,10 +310,14 @@ def main():
else:
# Fallback: try to infer from the index directory name
dataset_type = Path(args.index_path).name
print(f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'.")
print(
f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'."
)
queries_file = data_root / "queries" / "nq_open.jsonl"
golden_results_file = data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
golden_results_file = (
data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
)
print(f"INFO: Detected dataset type: {dataset_type}")
print(f"INFO: Using queries file: {queries_file}")
@@ -325,7 +327,7 @@ def main():
searcher = LeannSearcher(args.index_path)
queries = load_queries(queries_file)
with open(golden_results_file) as f:
with open(golden_results_file, "r") as f:
golden_results_data = json.load(f)
num_eval_queries = min(args.num_queries, len(queries))
@@ -338,23 +340,10 @@ def main():
for i in range(num_eval_queries):
start_time = time.time()
new_results = searcher.search(
queries[i],
top_k=args.top_k,
complexity=args.ef_search,
batch_size=args.batch_size,
queries[i], top_k=args.top_k, ef=args.ef_search
)
search_times.append(time.time() - start_time)
# Optional: also call the LLM with configurable backend/model (does not affect recall)
llm_config = {"type": args.llm_type, "model": args.llm_model}
chat = LeannChat(args.index_path, llm_config=llm_config, searcher=searcher)
answer = chat.ask(
queries[i],
top_k=args.top_k,
complexity=args.ef_search,
batch_size=args.batch_size,
)
print(f"Answer: {answer}")
# Correct Recall Calculation: Based on TEXT content
new_texts = {result.text for result in new_results}

View File

@@ -1,28 +1,21 @@
"""
Simple demo showing basic leann usage
Run: uv run python examples/basic_demo.py
Run: uv run python examples/simple_demo.py
"""
import argparse
from leann import LeannBuilder, LeannChat, LeannSearcher
from leann import LeannBuilder, LeannSearcher, LeannChat
def main():
parser = argparse.ArgumentParser(
description="Simple demo of Leann with selectable embedding models."
)
parser.add_argument(
"--embedding_model",
type=str,
default="sentence-transformers/all-mpnet-base-v2",
help="The embedding model to use, e.g., 'sentence-transformers/all-mpnet-base-v2' or 'text-embedding-ada-002'.",
)
parser = argparse.ArgumentParser(description="Simple demo of Leann with selectable embedding models.")
parser.add_argument("--embedding_model", type=str, default="sentence-transformers/all-mpnet-base-v2",
help="The embedding model to use, e.g., 'sentence-transformers/all-mpnet-base-v2' or 'text-embedding-ada-002'.")
args = parser.parse_args()
print(f"=== Leann Simple Demo with {args.embedding_model} ===")
print()
# Sample knowledge base
chunks = [
"Machine learning is a subset of artificial intelligence that enables computers to learn without being explicitly programmed.",
@@ -34,7 +27,7 @@ def main():
"Big data refers to extremely large datasets that require special tools and techniques to process.",
"Cloud computing provides on-demand access to computing resources over the internet.",
]
print("1. Building index (no embeddings stored)...")
builder = LeannBuilder(
embedding_model=args.embedding_model,
@@ -44,45 +37,45 @@ def main():
builder.add_text(chunk)
builder.build_index("demo_knowledge.leann")
print()
print("2. Searching with real-time embeddings...")
searcher = LeannSearcher("demo_knowledge.leann")
queries = [
"What is machine learning?",
"How does neural network work?",
"How does neural network work?",
"Tell me about data processing",
]
for query in queries:
print(f"Query: {query}")
results = searcher.search(query, top_k=2)
for i, result in enumerate(results, 1):
print(f" {i}. Score: {result.score:.3f}")
print(f" Text: {result.text[:100]}...")
print()
print("3. Interactive chat demo:")
print(" (Note: Requires OpenAI API key for real responses)")
chat = LeannChat("demo_knowledge.leann")
# Demo questions
demo_questions: list[str] = [
"What is the difference between machine learning and deep learning?",
"How is data science related to big data?",
]
for question in demo_questions:
print(f" Q: {question}")
response = chat.ask(question)
print(f" A: {response}")
print()
print("Demo completed! Try running:")
print(" uv run python apps/document_rag.py")
print(" uv run python examples/document_search.py")
if __name__ == "__main__":
main()
main()

View File

@@ -1,250 +0,0 @@
#!/usr/bin/env python3
"""
Spoiler-Free Book RAG Example using LEANN Metadata Filtering
This example demonstrates how to use LEANN's metadata filtering to create
a spoiler-free book RAG system where users can search for information
up to a specific chapter they've read.
Usage:
python spoiler_free_book_rag.py
"""
import os
import sys
from typing import Any, Optional
# Add LEANN to path (adjust path as needed)
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../packages/leann-core/src"))
from leann.api import LeannBuilder, LeannSearcher
def chunk_book_with_metadata(book_title: str = "Sample Book") -> list[dict[str, Any]]:
"""
Create sample book chunks with metadata for demonstration.
In a real implementation, this would parse actual book files (epub, txt, etc.)
and extract chapter boundaries, character mentions, etc.
Args:
book_title: Title of the book
Returns:
List of chunk dictionaries with text and metadata
"""
# Sample book chunks with metadata
# In practice, you'd use proper text processing libraries
sample_chunks = [
{
"text": "Alice was beginning to get very tired of sitting by her sister on the bank, and of having nothing to do.",
"metadata": {
"book": book_title,
"chapter": 1,
"page": 1,
"characters": ["Alice", "Sister"],
"themes": ["boredom", "curiosity"],
"location": "riverbank",
},
},
{
"text": "So she was considering in her own mind (as well as she could, for the hot day made her feel very sleepy and stupid), whether the pleasure of making a daisy-chain would be worth the trouble of getting up and picking the daisies, when suddenly a White Rabbit with pink eyes ran close by her.",
"metadata": {
"book": book_title,
"chapter": 1,
"page": 2,
"characters": ["Alice", "White Rabbit"],
"themes": ["decision", "surprise", "magic"],
"location": "riverbank",
},
},
{
"text": "Alice found herself falling down a very deep well. Either the well was very deep, or she fell very slowly, for she had plenty of time as she fell to look about her and to wonder what was going to happen next.",
"metadata": {
"book": book_title,
"chapter": 2,
"page": 15,
"characters": ["Alice"],
"themes": ["falling", "wonder", "transformation"],
"location": "rabbit hole",
},
},
{
"text": "Alice meets the Cheshire Cat, who tells her that everyone in Wonderland is mad, including Alice herself.",
"metadata": {
"book": book_title,
"chapter": 6,
"page": 85,
"characters": ["Alice", "Cheshire Cat"],
"themes": ["madness", "philosophy", "identity"],
"location": "Duchess's house",
},
},
{
"text": "At the Queen's croquet ground, Alice witnesses the absurd trial that reveals the arbitrary nature of Wonderland's justice system.",
"metadata": {
"book": book_title,
"chapter": 8,
"page": 120,
"characters": ["Alice", "Queen of Hearts", "King of Hearts"],
"themes": ["justice", "absurdity", "authority"],
"location": "Queen's court",
},
},
{
"text": "Alice realizes that Wonderland was all a dream, even the Rabbit, as she wakes up on the riverbank next to her sister.",
"metadata": {
"book": book_title,
"chapter": 12,
"page": 180,
"characters": ["Alice", "Sister", "Rabbit"],
"themes": ["revelation", "reality", "growth"],
"location": "riverbank",
},
},
]
return sample_chunks
def build_spoiler_free_index(book_chunks: list[dict[str, Any]], index_name: str) -> str:
"""
Build a LEANN index with book chunks that include spoiler metadata.
Args:
book_chunks: List of book chunks with metadata
index_name: Name for the index
Returns:
Path to the built index
"""
print(f"📚 Building spoiler-free book index: {index_name}")
# Initialize LEANN builder
builder = LeannBuilder(
backend_name="hnsw", embedding_model="text-embedding-3-small", embedding_mode="openai"
)
# Add each chunk with its metadata
for chunk in book_chunks:
builder.add_text(text=chunk["text"], metadata=chunk["metadata"])
# Build the index
index_path = f"{index_name}_book_index"
builder.build_index(index_path)
print(f"✅ Index built successfully: {index_path}")
return index_path
def spoiler_free_search(
index_path: str,
query: str,
max_chapter: int,
character_filter: Optional[list[str]] = None,
) -> list[dict[str, Any]]:
"""
Perform a spoiler-free search on the book index.
Args:
index_path: Path to the LEANN index
query: Search query
max_chapter: Maximum chapter number to include
character_filter: Optional list of characters to focus on
Returns:
List of search results safe for the reader
"""
print(f"🔍 Searching: '{query}' (up to chapter {max_chapter})")
searcher = LeannSearcher(index_path)
metadata_filters = {"chapter": {"<=": max_chapter}}
if character_filter:
metadata_filters["characters"] = {"contains": character_filter[0]}
results = searcher.search(query=query, top_k=10, metadata_filters=metadata_filters)
return results
def demo_spoiler_free_rag():
"""
Demonstrate the spoiler-free book RAG system.
"""
print("🎭 Spoiler-Free Book RAG Demo")
print("=" * 40)
# Step 1: Prepare book data
book_title = "Alice's Adventures in Wonderland"
book_chunks = chunk_book_with_metadata(book_title)
print(f"📖 Loaded {len(book_chunks)} chunks from '{book_title}'")
# Step 2: Build the index (in practice, this would be done once)
try:
index_path = build_spoiler_free_index(book_chunks, "alice_wonderland")
except Exception as e:
print(f"❌ Failed to build index (likely missing dependencies): {e}")
print(
"💡 This demo shows the filtering logic - actual indexing requires LEANN dependencies"
)
return
# Step 3: Demonstrate various spoiler-free searches
search_scenarios = [
{
"description": "Reader who has only read Chapter 1",
"query": "What can you tell me about the rabbit?",
"max_chapter": 1,
},
{
"description": "Reader who has read up to Chapter 5",
"query": "Tell me about Alice's adventures",
"max_chapter": 5,
},
{
"description": "Reader who has read most of the book",
"query": "What does the Cheshire Cat represent?",
"max_chapter": 10,
},
{
"description": "Reader who has read the whole book",
"query": "What can you tell me about the rabbit?",
"max_chapter": 12,
},
]
for scenario in search_scenarios:
print(f"\n📚 Scenario: {scenario['description']}")
print(f" Query: {scenario['query']}")
try:
results = spoiler_free_search(
index_path=index_path,
query=scenario["query"],
max_chapter=scenario["max_chapter"],
)
print(f" 📄 Found {len(results)} results:")
for i, result in enumerate(results[:3], 1): # Show top 3
chapter = result.metadata.get("chapter", "?")
location = result.metadata.get("location", "?")
print(f" {i}. Chapter {chapter} ({location}): {result.text[:80]}...")
except Exception as e:
print(f" ❌ Search failed: {e}")
if __name__ == "__main__":
print("📚 LEANN Spoiler-Free Book RAG Example")
print("=====================================")
try:
demo_spoiler_free_rag()
except ImportError as e:
print(f"❌ Cannot run demo due to missing dependencies: {e}")
except Exception as e:
print(f"❌ Error running demo: {e}")

View File

@@ -0,0 +1,319 @@
import os
import asyncio
import dotenv
import argparse
from pathlib import Path
from typing import List, Any, Optional
from leann.api import LeannBuilder, LeannSearcher, LeannChat
from llama_index.core.node_parser import SentenceSplitter
import requests
import time
dotenv.load_dotenv()
# Default WeChat export directory
DEFAULT_WECHAT_EXPORT_DIR = "./wechat_export_direct"
def create_leann_index_from_multiple_wechat_exports(
export_dirs: List[Path],
index_path: str = "wechat_history_index.leann",
max_count: int = -1,
):
"""
Create LEANN index from multiple WeChat export data sources.
Args:
export_dirs: List of Path objects pointing to WeChat export directories
index_path: Path to save the LEANN index
max_count: Maximum number of chat entries to process per export
"""
print("Creating LEANN index from multiple WeChat export data sources...")
# Load documents using WeChatHistoryReader from history_data
from history_data.wechat_history import WeChatHistoryReader
reader = WeChatHistoryReader()
INDEX_DIR = Path(index_path).parent
if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---")
all_documents = []
total_processed = 0
# Process each WeChat export directory
for i, export_dir in enumerate(export_dirs):
print(
f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}"
)
try:
documents = reader.load_data(
wechat_export_dir=str(export_dir),
max_count=max_count,
concatenate_messages=True, # Disable concatenation - one message per document
)
if documents:
print(f"Loaded {len(documents)} chat documents from {export_dir}")
all_documents.extend(documents)
total_processed += len(documents)
# Check if we've reached the max count
if max_count > 0 and total_processed >= max_count:
print(f"Reached max count of {max_count} documents")
break
else:
print(f"No documents loaded from {export_dir}")
except Exception as e:
print(f"Error processing {export_dir}: {e}")
continue
if not all_documents:
print("No documents loaded from any source. Exiting.")
return None
print(
f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports and starting to split them into chunks"
)
# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=192, chunk_overlap=64)
# Convert Documents to text strings and chunk them
all_texts = []
for doc in all_documents:
# Split the document into chunks
nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes:
text = '[Contact] means the message is from: ' + doc.metadata["contact_name"] + '\n' + node.get_content()
all_texts.append(text)
print(
f"Finished splitting {len(all_documents)} documents into {len(all_texts)} text chunks"
)
# Create LEANN index directory
print(f"--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="Qwen/Qwen3-Embedding-0.6B",
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1, # Force single-threaded mode
)
print(f"Adding {len(all_texts)} chat chunks to index...")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(index_path)
print(f"\nLEANN index built at {index_path}!")
else:
print(f"--- Using existing index at {INDEX_DIR} ---")
return index_path
def create_leann_index(
export_dir: str = None,
index_path: str = "wechat_history_index.leann",
max_count: int = 1000,
):
"""
Create LEANN index from WeChat chat history data.
Args:
export_dir: Path to the WeChat export directory (optional, uses default if None)
index_path: Path to save the LEANN index
max_count: Maximum number of chat entries to process
"""
print("Creating LEANN index from WeChat chat history data...")
INDEX_DIR = Path(index_path).parent
if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...")
# Load documents using WeChatHistoryReader from history_data
from history_data.wechat_history import WeChatHistoryReader
reader = WeChatHistoryReader()
documents = reader.load_data(
wechat_export_dir=export_dir,
max_count=max_count,
concatenate_messages=False, # Disable concatenation - one message per document
)
if not documents:
print("No documents loaded. Exiting.")
return None
print(f"Loaded {len(documents)} chat documents")
# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
# Convert Documents to text strings and chunk them
all_texts = []
for doc in documents:
# Split the document into chunks
nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
# Create LEANN index directory
print(f"--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ", # MLX-optimized model
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1, # Force single-threaded mode
)
print(f"Adding {len(all_texts)} chat chunks to index...")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(index_path)
print(f"\nLEANN index built at {index_path}!")
else:
print(f"--- Using existing index at {INDEX_DIR} ---")
return index_path
async def query_leann_index(index_path: str, query: str):
"""
Query the LEANN index.
Args:
index_path: Path to the LEANN index
query: The query string
"""
print(f"\n[PHASE 2] Starting Leann chat session...")
chat = LeannChat(index_path=index_path)
print(f"You: {query}")
chat_response = chat.ask(
query,
top_k=20,
recompute_beighbor_embeddings=True,
complexity=16,
beam_width=1,
llm_config={
"type": "openai",
"model": "gpt-4o",
"api_key": os.getenv("OPENAI_API_KEY"),
},
llm_kwargs={"temperature": 0.0, "max_tokens": 1000},
)
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
async def main():
"""Main function with integrated WeChat export functionality."""
# Parse command line arguments
parser = argparse.ArgumentParser(
description="LEANN WeChat History Reader - Create and query WeChat chat history index"
)
parser.add_argument(
"--export-dir",
type=str,
default=DEFAULT_WECHAT_EXPORT_DIR,
help=f"Directory to store WeChat exports (default: {DEFAULT_WECHAT_EXPORT_DIR})",
)
parser.add_argument(
"--index-dir",
type=str,
default="./wechat_history_magic_test_11Debug_new",
help="Directory to store the LEANN index (default: ./wechat_history_index_leann_test)",
)
parser.add_argument(
"--max-entries",
type=int,
default=50,
help="Maximum number of chat entries to process (default: 5000)",
)
parser.add_argument(
"--query",
type=str,
default=None,
help="Single query to run (default: runs example queries)",
)
parser.add_argument(
"--force-export",
action="store_true",
default=False,
help="Force re-export of WeChat data even if exports exist",
)
args = parser.parse_args()
INDEX_DIR = Path(args.index_dir)
INDEX_PATH = str(INDEX_DIR / "wechat_history.leann")
print(f"Using WeChat export directory: {args.export_dir}")
print(f"Index directory: {INDEX_DIR}")
print(f"Max entries: {args.max_entries}")
# Initialize WeChat reader with export capabilities
from history_data.wechat_history import WeChatHistoryReader
reader = WeChatHistoryReader()
# Find existing exports or create new ones using the centralized method
export_dirs = reader.find_or_export_wechat_data(args.export_dir)
if not export_dirs:
print("Failed to find or export WeChat data. Exiting.")
return
# Create or load the LEANN index from all sources
index_path = create_leann_index_from_multiple_wechat_exports(
export_dirs, INDEX_PATH, max_count=args.max_entries
)
if index_path:
if args.query:
# Run single query
await query_leann_index(index_path, args.query)
else:
# Example queries
queries = [
"我想买魔术师约翰逊的球衣,给我一些对应聊天记录?",
]
for query in queries:
print("\n" + "=" * 60)
await query_leann_index(index_path, query)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,28 +0,0 @@
# llms.txt — LEANN MCP and Agent Integration
product: LEANN
homepage: https://github.com/yichuan-w/LEANN
contact: https://github.com/yichuan-w/LEANN/issues
# Installation
install: uv tool install leann-core --with leann
# MCP Server Entry Point
mcp.server: leann_mcp
mcp.protocol_version: 2024-11-05
# Tools
mcp.tools: leann_list, leann_search
mcp.tool.leann_list.description: List available LEANN indexes
mcp.tool.leann_list.input: {}
mcp.tool.leann_search.description: Semantic search across a named LEANN index
mcp.tool.leann_search.input.index_name: string, required
mcp.tool.leann_search.input.query: string, required
mcp.tool.leann_search.input.top_k: integer, optional, default=5, min=1, max=20
mcp.tool.leann_search.input.complexity: integer, optional, default=32, min=16, max=128
# Notes
note: Build indexes with `leann build <name> --docs <files...>` before searching.
example.add: claude mcp add --scope user leann-server -- leann_mcp
example.verify: claude mcp list | cat

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,36 @@
# packages/leann-backend-diskann/CMakeLists.txt (simplified version)
cmake_minimum_required(VERSION 3.20)
project(leann_backend_diskann_wrapper)
# Find Python - scikit-build-core should provide this
find_package(Python REQUIRED COMPONENTS Interpreter Development.Module)
# Print Python information for debugging
message(STATUS "Python_FOUND: ${Python_FOUND}")
message(STATUS "Python_VERSION: ${Python_VERSION}")
message(STATUS "Python_EXECUTABLE: ${Python_EXECUTABLE}")
message(STATUS "Python_INCLUDE_DIRS: ${Python_INCLUDE_DIRS}")
# Pass Python information to DiskANN through cache variables
set(Python_EXECUTABLE ${Python_EXECUTABLE} CACHE FILEPATH "Python executable" FORCE)
set(Python_INCLUDE_DIRS ${Python_INCLUDE_DIRS} CACHE PATH "Python include dirs" FORCE)
set(Python_LIBRARIES ${Python_LIBRARIES} CACHE FILEPATH "Python libraries" FORCE)
set(Python_VERSION ${Python_VERSION} CACHE STRING "Python version" FORCE)
set(Python_FOUND ${Python_FOUND} CACHE BOOL "Python found" FORCE)
# Also set Python3 variables for compatibility
set(Python3_EXECUTABLE ${Python_EXECUTABLE} CACHE FILEPATH "Python3 executable" FORCE)
set(Python3_INCLUDE_DIRS ${Python_INCLUDE_DIRS} CACHE PATH "Python3 include dirs" FORCE)
set(Python3_LIBRARIES ${Python_LIBRARIES} CACHE FILEPATH "Python3 libraries" FORCE)
set(Python3_VERSION ${Python_VERSION} CACHE STRING "Python3 version" FORCE)
set(Python3_FOUND ${Python_FOUND} CACHE BOOL "Python3 found" FORCE)
set(Python3_Development_FOUND TRUE CACHE BOOL "Python3 development found" FORCE)
# Set Python finding strategy
set(Python_FIND_VIRTUALENV ONLY CACHE STRING "" FORCE)
set(Python3_FIND_VIRTUALENV ONLY CACHE STRING "" FORCE)
# Tell CMake to directly enter the DiskANN submodule and execute its own CMakeLists.txt
# DiskANN will handle everything itself, including compiling Python bindings
add_subdirectory(third_party/DiskANN)

View File

@@ -1 +1 @@
# This file makes the directory a Python package
# This file makes the directory a Python package

View File

@@ -1,7 +1 @@
from . import diskann_backend as diskann_backend
from . import graph_partition
# Export main classes and functions
from .graph_partition import GraphPartitioner, partition_graph
__all__ = ["GraphPartitioner", "diskann_backend", "graph_partition", "partition_graph"]
from . import diskann_backend

View File

@@ -1,20 +1,20 @@
import contextlib
import logging
import numpy as np
import os
import struct
import sys
from pathlib import Path
from typing import Any, Literal, Optional
from typing import Dict, Any, List, Literal, Optional
import contextlib
import numpy as np
import psutil
import logging
from leann.searcher_base import BaseSearcher
from leann.registry import register_backend
from leann.interface import (
LeannBackendBuilderInterface,
LeannBackendFactoryInterface,
LeannBackendBuilderInterface,
LeannBackendSearcherInterface,
)
from leann.registry import register_backend
from leann.searcher_base import BaseSearcher
logger = logging.getLogger(__name__)
@@ -22,11 +22,6 @@ logger = logging.getLogger(__name__)
@contextlib.contextmanager
def suppress_cpp_output_if_needed():
"""Suppress C++ stdout/stderr based on LEANN_LOG_LEVEL"""
# In CI we avoid fiddling with low-level file descriptors to prevent aborts
if os.getenv("CI") == "true":
yield
return
log_level = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
# Only suppress if log level is WARNING or higher (ERROR, CRITICAL)
@@ -90,43 +85,6 @@ def _write_vectors_to_bin(data: np.ndarray, file_path: Path):
f.write(data.tobytes())
def _calculate_smart_memory_config(data: np.ndarray) -> tuple[float, float]:
"""
Calculate smart memory configuration for DiskANN based on data size and system specs.
Args:
data: The embedding data array
Returns:
tuple: (search_memory_maximum, build_memory_maximum) in GB
"""
num_vectors, dim = data.shape
# Calculate embedding storage size
embedding_size_bytes = num_vectors * dim * 4 # float32 = 4 bytes
embedding_size_gb = embedding_size_bytes / (1024**3)
# search_memory_maximum: 1/10 of embedding size for optimal PQ compression
# This controls Product Quantization size - smaller means more compression
search_memory_gb = max(0.1, embedding_size_gb / 10) # At least 100MB
# build_memory_maximum: Based on available system RAM for sharding control
# This controls how much memory DiskANN uses during index construction
available_memory_gb = psutil.virtual_memory().available / (1024**3)
total_memory_gb = psutil.virtual_memory().total / (1024**3)
# Use 50% of available memory, but at least 2GB and at most 75% of total
build_memory_gb = max(2.0, min(available_memory_gb * 0.5, total_memory_gb * 0.75))
logger.info(
f"Smart memory config - Data: {embedding_size_gb:.2f}GB, "
f"Search mem: {search_memory_gb:.2f}GB (PQ control), "
f"Build mem: {build_memory_gb:.2f}GB (sharding control)"
)
return search_memory_gb, build_memory_gb
@register_backend("diskann")
class DiskannBackend(LeannBackendFactoryInterface):
@staticmethod
@@ -142,72 +100,7 @@ class DiskannBuilder(LeannBackendBuilderInterface):
def __init__(self, **kwargs):
self.build_params = kwargs
def _safe_cleanup_after_partition(self, index_dir: Path, index_prefix: str):
"""
Safely cleanup files after partition.
In partition mode, C++ doesn't read _disk.index content,
so we can delete it if all derived files exist.
"""
disk_index_file = index_dir / f"{index_prefix}_disk.index"
beam_search_file = index_dir / f"{index_prefix}_disk_beam_search.index"
# Required files that C++ partition mode needs
# Note: C++ generates these with _disk.index suffix
disk_suffix = "_disk.index"
required_files = [
f"{index_prefix}{disk_suffix}_medoids.bin", # Critical: assert fails if missing
# Note: _centroids.bin is not created in single-shot build - C++ handles this automatically
f"{index_prefix}_pq_pivots.bin", # PQ table
f"{index_prefix}_pq_compressed.bin", # PQ compressed vectors
]
# Check if all required files exist
missing_files = []
for filename in required_files:
file_path = index_dir / filename
if not file_path.exists():
missing_files.append(filename)
if missing_files:
logger.warning(
f"Cannot safely delete _disk.index - missing required files: {missing_files}"
)
logger.info("Keeping all original files for safety")
return
# Calculate space savings
space_saved = 0
files_to_delete = []
if disk_index_file.exists():
space_saved += disk_index_file.stat().st_size
files_to_delete.append(disk_index_file)
if beam_search_file.exists():
space_saved += beam_search_file.stat().st_size
files_to_delete.append(beam_search_file)
# Safe to delete!
for file_to_delete in files_to_delete:
try:
os.remove(file_to_delete)
logger.info(f"✅ Safely deleted: {file_to_delete.name}")
except Exception as e:
logger.warning(f"Failed to delete {file_to_delete.name}: {e}")
if space_saved > 0:
space_saved_mb = space_saved / (1024 * 1024)
logger.info(f"💾 Space saved: {space_saved_mb:.1f} MB")
# Show what files are kept
logger.info("📁 Kept essential files for partition mode:")
for filename in required_files:
file_path = index_dir / filename
if file_path.exists():
size_mb = file_path.stat().st_size / (1024 * 1024)
logger.info(f" - {filename} ({size_mb:.1f} MB)")
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
path = Path(index_path)
index_dir = path.parent
index_prefix = path.stem
@@ -221,17 +114,6 @@ class DiskannBuilder(LeannBackendBuilderInterface):
_write_vectors_to_bin(data, index_dir / data_filename)
build_kwargs = {**self.build_params, **kwargs}
# Extract is_recompute from nested backend_kwargs if needed
is_recompute = build_kwargs.get("is_recompute", False)
if not is_recompute and "backend_kwargs" in build_kwargs:
is_recompute = build_kwargs["backend_kwargs"].get("is_recompute", False)
# Flatten all backend_kwargs parameters to top level for compatibility
if "backend_kwargs" in build_kwargs:
nested_params = build_kwargs.pop("backend_kwargs")
build_kwargs.update(nested_params)
metric_enum = _get_diskann_metrics().get(
build_kwargs.get("distance_metric", "mips").lower()
)
@@ -240,16 +122,6 @@ class DiskannBuilder(LeannBackendBuilderInterface):
f"Unsupported distance_metric '{build_kwargs.get('distance_metric', 'unknown')}'."
)
# Calculate smart memory configuration if not explicitly provided
if (
"search_memory_maximum" not in build_kwargs
or "build_memory_maximum" not in build_kwargs
):
smart_search_mem, smart_build_mem = _calculate_smart_memory_config(data)
else:
smart_search_mem = build_kwargs.get("search_memory_maximum", 4.0)
smart_build_mem = build_kwargs.get("build_memory_maximum", 8.0)
try:
from . import _diskannpy as diskannpy # type: ignore
@@ -260,36 +132,12 @@ class DiskannBuilder(LeannBackendBuilderInterface):
index_prefix,
build_kwargs.get("complexity", 64),
build_kwargs.get("graph_degree", 32),
build_kwargs.get("search_memory_maximum", smart_search_mem),
build_kwargs.get("build_memory_maximum", smart_build_mem),
build_kwargs.get("search_memory_maximum", 4.0),
build_kwargs.get("build_memory_maximum", 8.0),
build_kwargs.get("num_threads", 8),
build_kwargs.get("pq_disk_bytes", 0),
"",
)
# Auto-partition if is_recompute is enabled
if build_kwargs.get("is_recompute", False):
logger.info("is_recompute=True, starting automatic graph partitioning...")
from .graph_partition import partition_graph
# Partition the index using absolute paths
# Convert to absolute paths to avoid issues with working directory changes
absolute_index_dir = Path(index_dir).resolve()
absolute_index_prefix_path = str(absolute_index_dir / index_prefix)
disk_graph_path, partition_bin_path = partition_graph(
index_prefix_path=absolute_index_prefix_path,
output_dir=str(absolute_index_dir),
partition_prefix=index_prefix,
)
# Safe cleanup: In partition mode, C++ doesn't read _disk.index content
# but still needs the derived files (_medoids.bin, _centroids.bin, etc.)
self._safe_cleanup_after_partition(index_dir, index_prefix)
logger.info("✅ Graph partitioning completed successfully!")
logger.info(f" - Disk graph: {disk_graph_path}")
logger.info(f" - Partition file: {partition_bin_path}")
finally:
temp_data_file = index_dir / data_filename
if temp_data_file.exists():
@@ -316,69 +164,18 @@ class DiskannSearcher(BaseSearcher):
self.num_threads = kwargs.get("num_threads", 8)
# For DiskANN, we need to reinitialize the index when zmq_port changes
# Store the initialization parameters for later use
# Note: C++ load method expects the BASE path (without _disk.index suffix)
# C++ internally constructs: index_prefix + "_disk.index"
index_name = self.index_path.stem # "simple_test.leann" -> "simple_test"
diskann_index_prefix = str(self.index_dir / index_name) # /path/to/simple_test
full_index_prefix = diskann_index_prefix # /path/to/simple_test (base path)
# Auto-detect partition files and set partition_prefix
partition_graph_file = self.index_dir / f"{index_name}_disk_graph.index"
partition_bin_file = self.index_dir / f"{index_name}_partition.bin"
partition_prefix = ""
if partition_graph_file.exists() and partition_bin_file.exists():
# C++ expects full path prefix, not just filename
partition_prefix = str(self.index_dir / index_name) # /path/to/simple_test
logger.info(
f"✅ Detected partition files, using partition_prefix='{partition_prefix}'"
)
else:
logger.debug("No partition files detected, using standard index files")
self._init_params = {
"metric_enum": metric_enum,
"full_index_prefix": full_index_prefix,
"num_threads": self.num_threads,
"num_nodes_to_cache": kwargs.get("num_nodes_to_cache", 0),
"cache_mechanism": 1,
"pq_prefix": "",
"partition_prefix": partition_prefix,
}
# Log partition configuration for debugging
if partition_prefix:
logger.info(
f"✅ Detected partition files, using partition_prefix='{partition_prefix}'"
)
self._diskannpy = diskannpy
self._current_zmq_port = None
self._index = None
logger.debug("DiskANN searcher initialized (index will be loaded on first search)")
def _ensure_index_loaded(self, zmq_port: int):
"""Ensure the index is loaded with the correct zmq_port."""
if self._index is None or self._current_zmq_port != zmq_port:
# Need to (re)load the index with the correct zmq_port
with suppress_cpp_output_if_needed():
if self._index is not None:
logger.debug(f"Reloading DiskANN index with new zmq_port: {zmq_port}")
else:
logger.debug(f"Loading DiskANN index with zmq_port: {zmq_port}")
self._index = self._diskannpy.StaticDiskFloatIndex(
self._init_params["metric_enum"],
self._init_params["full_index_prefix"],
self._init_params["num_threads"],
self._init_params["num_nodes_to_cache"],
self._init_params["cache_mechanism"],
zmq_port,
self._init_params["pq_prefix"],
self._init_params["partition_prefix"],
)
self._current_zmq_port = zmq_port
fake_zmq_port = 6666
full_index_prefix = str(self.index_dir / self.index_path.stem)
self._index = diskannpy.StaticDiskFloatIndex(
metric_enum,
full_index_prefix,
self.num_threads,
kwargs.get("num_nodes_to_cache", 0),
1,
fake_zmq_port, # Initial port, can be updated at runtime
"",
"",
)
def search(
self,
@@ -393,7 +190,7 @@ class DiskannSearcher(BaseSearcher):
batch_recompute: bool = False,
dedup_node_dis: bool = False,
**kwargs,
) -> dict[str, Any]:
) -> Dict[str, Any]:
"""
Search for nearest neighbors using DiskANN index.
@@ -416,15 +213,18 @@ class DiskannSearcher(BaseSearcher):
Returns:
Dict with 'labels' (list of lists) and 'distances' (ndarray)
"""
# Handle zmq_port compatibility: Ensure index is loaded with correct port
# Handle zmq_port compatibility: DiskANN can now update port at runtime
if recompute_embeddings:
if zmq_port is None:
raise ValueError("zmq_port must be provided if recompute_embeddings is True")
self._ensure_index_loaded(zmq_port)
else:
# If not recomputing, we still need an index, use a default port
if self._index is None:
self._ensure_index_loaded(6666) # Default port when not recomputing
raise ValueError(
"zmq_port must be provided if recompute_embeddings is True"
)
current_port = self._index.get_zmq_port()
if zmq_port != current_port:
logger.debug(
f"Updating DiskANN zmq_port from {current_port} to {zmq_port}"
)
self._index.set_zmq_port(zmq_port)
# DiskANN doesn't support "proportional" strategy
if pruning_strategy == "proportional":
@@ -441,14 +241,7 @@ class DiskannSearcher(BaseSearcher):
else: # "global"
use_global_pruning = True
# Strategy:
# - Traversal always uses PQ distances
# - If recompute_embeddings=True, do a single final rerank via deferred fetch
# (fetch embeddings for the final candidate set only)
# - Do not recompute neighbor distances along the path
use_deferred_fetch = True if recompute_embeddings else False
recompute_neighors = False # Expected typo. For backward compatibility.
# Perform search with suppressed C++ output based on log level
with suppress_cpp_output_if_needed():
labels, distances = self._index.batch_search(
query,
@@ -457,15 +250,17 @@ class DiskannSearcher(BaseSearcher):
complexity,
beam_width,
self.num_threads,
use_deferred_fetch,
kwargs.get("USE_DEFERRED_FETCH", False),
kwargs.get("skip_search_reorder", False),
recompute_neighors,
recompute_embeddings,
dedup_node_dis,
prune_ratio,
batch_recompute,
use_global_pruning,
)
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
string_labels = [
[str(int_label) for int_label in batch_labels] for batch_labels in labels
]
return {"labels": string_labels, "distances": distances}

View File

@@ -3,17 +3,16 @@ DiskANN-specific embedding server
"""
import argparse
import json
import logging
import os
import sys
import threading
import time
import os
import zmq
import numpy as np
import json
from pathlib import Path
from typing import Optional
import numpy as np
import zmq
import sys
import logging
# Set up logging based on environment variable
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
@@ -37,7 +36,6 @@ def create_diskann_embedding_server(
zmq_port: int = 5555,
model_name: str = "sentence-transformers/all-mpnet-base-v2",
embedding_mode: str = "sentence-transformers",
distance_metric: str = "l2",
):
"""
Create and start a ZMQ-based embedding server for DiskANN backend.
@@ -52,8 +50,8 @@ def create_diskann_embedding_server(
sys.path.insert(0, str(leann_core_path))
try:
from leann.api import PassageManager
from leann.embedding_compute import compute_embeddings
from leann.api import PassageManager
logger.info("Successfully imported unified embedding computation module")
except ImportError as e:
@@ -78,12 +76,13 @@ def create_diskann_embedding_server(
raise ValueError("Only metadata files (.meta.json) are supported")
# Load metadata to get passage sources
with open(passages_file) as f:
with open(passages_file, "r") as f:
meta = json.load(f)
logger.info(f"Loading PassageManager with metadata_file_path: {passages_file}")
passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file)
logger.info(f"Loaded PassageManager with {len(passages)} passages from metadata")
passages = PassageManager(meta["passage_sources"])
logger.info(
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
)
# Import protobuf after ensuring the path is correct
try:
@@ -101,9 +100,8 @@ def create_diskann_embedding_server(
socket.bind(f"tcp://*:{zmq_port}")
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
socket.setsockopt(zmq.RCVTIMEO, 1000)
socket.setsockopt(zmq.SNDTIMEO, 1000)
socket.setsockopt(zmq.LINGER, 0)
socket.setsockopt(zmq.RCVTIMEO, 300000)
socket.setsockopt(zmq.SNDTIMEO, 300000)
while True:
try:
@@ -152,7 +150,9 @@ def create_diskann_embedding_server(
):
texts = request
is_text_request = True
logger.info(f"✅ MSGPACK: Direct text request for {len(texts)} texts")
logger.info(
f"✅ MSGPACK: Direct text request for {len(texts)} texts"
)
else:
raise ValueError("Not a valid msgpack text request")
except Exception as msgpack_error:
@@ -167,7 +167,9 @@ def create_diskann_embedding_server(
passage_data = passages.get_passage(str(nid))
txt = passage_data["text"]
if not txt:
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
raise RuntimeError(
f"FATAL: Empty text for passage ID {nid}"
)
texts.append(txt)
except KeyError as e:
logger.error(f"Passage ID {nid} not found: {e}")
@@ -178,7 +180,9 @@ def create_diskann_embedding_server(
# Debug logging
logger.debug(f"Processing {len(texts)} texts")
logger.debug(f"Text lengths: {[len(t) for t in texts[:5]]}") # Show first 5
logger.debug(
f"Text lengths: {[len(t) for t in texts[:5]]}"
) # Show first 5
# Process embeddings using unified computation
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
@@ -195,7 +199,9 @@ def create_diskann_embedding_server(
else:
# For DiskANN C++ compatibility: return protobuf format
resp_proto = embedding_pb2.NodeEmbeddingResponse()
hidden_contiguous = np.ascontiguousarray(embeddings, dtype=np.float32)
hidden_contiguous = np.ascontiguousarray(
embeddings, dtype=np.float32
)
# Serialize embeddings data
resp_proto.embeddings_data = hidden_contiguous.tobytes()
@@ -220,217 +226,30 @@ def create_diskann_embedding_server(
traceback.print_exc()
raise
def zmq_server_thread_with_shutdown(shutdown_event):
"""ZMQ server thread that respects shutdown signal.
This creates its own REP socket, binds to zmq_port, and periodically
checks shutdown_event using recv timeouts to exit cleanly.
"""
logger.info("DiskANN ZMQ server thread started with shutdown support")
context = zmq.Context()
rep_socket = context.socket(zmq.REP)
rep_socket.bind(f"tcp://*:{zmq_port}")
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
# Set receive timeout so we can check shutdown_event periodically
rep_socket.setsockopt(zmq.RCVTIMEO, 1000) # 1 second timeout
rep_socket.setsockopt(zmq.SNDTIMEO, 1000)
rep_socket.setsockopt(zmq.LINGER, 0)
try:
while not shutdown_event.is_set():
try:
e2e_start = time.time()
# REP socket receives single-part messages
message = rep_socket.recv()
# Check for empty messages - REP socket requires response to every request
if not message:
logger.warning("Received empty message, sending empty response")
rep_socket.send(b"")
continue
# Try protobuf first (same logic as original)
texts = []
is_text_request = False
try:
req_proto = embedding_pb2.NodeEmbeddingRequest()
req_proto.ParseFromString(message)
node_ids = list(req_proto.node_ids)
# Look up texts by node IDs
for nid in node_ids:
try:
passage_data = passages.get_passage(str(nid))
txt = passage_data["text"]
if not txt:
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
texts.append(txt)
except KeyError:
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
logger.info(f"ZMQ received protobuf request for {len(node_ids)} node IDs")
except Exception:
# Fallback to msgpack for text requests
try:
import msgpack
request = msgpack.unpackb(message)
if isinstance(request, list) and all(
isinstance(item, str) for item in request
):
texts = request
is_text_request = True
logger.info(
f"ZMQ received msgpack text request for {len(texts)} texts"
)
else:
raise ValueError("Not a valid msgpack text request")
except Exception:
logger.error("Both protobuf and msgpack parsing failed!")
# Send error response
resp_proto = embedding_pb2.NodeEmbeddingResponse()
rep_socket.send(resp_proto.SerializeToString())
continue
# Process the request
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
logger.info(f"Computed embeddings shape: {embeddings.shape}")
# Validation
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
logger.error("NaN or Inf detected in embeddings!")
# Send error response
if is_text_request:
import msgpack
response_data = msgpack.packb([])
else:
resp_proto = embedding_pb2.NodeEmbeddingResponse()
response_data = resp_proto.SerializeToString()
rep_socket.send(response_data)
continue
# Prepare response based on request type
if is_text_request:
# For direct text requests, return msgpack
import msgpack
response_data = msgpack.packb(embeddings.tolist())
else:
# For protobuf requests, return protobuf
resp_proto = embedding_pb2.NodeEmbeddingResponse()
hidden_contiguous = np.ascontiguousarray(embeddings, dtype=np.float32)
resp_proto.embeddings_data = hidden_contiguous.tobytes()
resp_proto.dimensions.append(hidden_contiguous.shape[0])
resp_proto.dimensions.append(hidden_contiguous.shape[1])
response_data = resp_proto.SerializeToString()
# Send response back to the client
rep_socket.send(response_data)
e2e_end = time.time()
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
except zmq.Again:
# Timeout - check shutdown_event and continue
continue
except Exception as e:
if not shutdown_event.is_set():
logger.error(f"Error in ZMQ server loop: {e}")
try:
# Send error response for REP socket
resp_proto = embedding_pb2.NodeEmbeddingResponse()
rep_socket.send(resp_proto.SerializeToString())
except Exception:
pass
else:
logger.info("Shutdown in progress, ignoring ZMQ error")
break
finally:
try:
rep_socket.close(0)
except Exception:
pass
try:
context.term()
except Exception:
pass
logger.info("DiskANN ZMQ server thread exiting gracefully")
# Add shutdown coordination
shutdown_event = threading.Event()
def shutdown_zmq_server():
"""Gracefully shutdown ZMQ server."""
logger.info("Initiating graceful shutdown...")
shutdown_event.set()
if zmq_thread.is_alive():
logger.info("Waiting for ZMQ thread to finish...")
zmq_thread.join(timeout=5)
if zmq_thread.is_alive():
logger.warning("ZMQ thread did not finish in time")
# Clean up ZMQ resources
try:
# Note: socket and context are cleaned up by thread exit
logger.info("ZMQ resources cleaned up")
except Exception as e:
logger.warning(f"Error cleaning ZMQ resources: {e}")
# Clean up other resources
try:
import gc
gc.collect()
logger.info("Additional resources cleaned up")
except Exception as e:
logger.warning(f"Error cleaning additional resources: {e}")
logger.info("Graceful shutdown completed")
sys.exit(0)
# Register signal handlers within this function scope
import signal
def signal_handler(sig, frame):
logger.info(f"Received signal {sig}, shutting down gracefully...")
shutdown_zmq_server()
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
# Start ZMQ thread (NOT daemon!)
zmq_thread = threading.Thread(
target=lambda: zmq_server_thread_with_shutdown(shutdown_event),
daemon=False, # Not daemon - we want to wait for it
)
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
zmq_thread.start()
logger.info(f"Started DiskANN ZMQ server thread on port {zmq_port}")
# Keep the main thread alive
try:
while not shutdown_event.is_set():
time.sleep(0.1) # Check shutdown more frequently
while True:
time.sleep(1)
except KeyboardInterrupt:
logger.info("DiskANN Server shutting down...")
shutdown_zmq_server()
return
# If we reach here, shutdown was triggered by signal
logger.info("Main loop exited, process should be shutting down")
if __name__ == "__main__":
import signal
import sys
# Signal handlers are now registered within create_diskann_embedding_server
def signal_handler(sig, frame):
logger.info(f"Received signal {sig}, shutting down gracefully...")
sys.exit(0)
# Register signal handlers for graceful shutdown
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
parser = argparse.ArgumentParser(description="DiskANN Embedding service")
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
@@ -449,16 +268,9 @@ if __name__ == "__main__":
"--embedding-mode",
type=str,
default="sentence-transformers",
choices=["sentence-transformers", "openai", "mlx", "ollama"],
choices=["sentence-transformers", "openai", "mlx"],
help="Embedding backend mode",
)
parser.add_argument(
"--distance-metric",
type=str,
default="l2",
choices=["l2", "mips", "cosine"],
help="Distance metric for similarity computation",
)
args = parser.parse_args()
@@ -468,5 +280,4 @@ if __name__ == "__main__":
zmq_port=args.zmq_port,
model_name=args.model_name,
embedding_mode=args.embedding_mode,
distance_metric=args.distance_metric,
)

View File

@@ -1,28 +1,27 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: embedding.proto
# ruff: noqa
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x0f\x65mbedding.proto\x12\x0eprotoembedding"(\n\x14NodeEmbeddingRequest\x12\x10\n\x08node_ids\x18\x01 \x03(\r"Y\n\x15NodeEmbeddingResponse\x12\x17\n\x0f\x65mbeddings_data\x18\x01 \x01(\x0c\x12\x12\n\ndimensions\x18\x02 \x03(\x05\x12\x13\n\x0bmissing_ids\x18\x03 \x03(\rb\x06proto3'
)
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0f\x65mbedding.proto\x12\x0eprotoembedding\"(\n\x14NodeEmbeddingRequest\x12\x10\n\x08node_ids\x18\x01 \x03(\r\"Y\n\x15NodeEmbeddingResponse\x12\x17\n\x0f\x65mbeddings_data\x18\x01 \x01(\x0c\x12\x12\n\ndimensions\x18\x02 \x03(\x05\x12\x13\n\x0bmissing_ids\x18\x03 \x03(\rb\x06proto3')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "embedding_pb2", globals())
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._options = None
_NODEEMBEDDINGREQUEST._serialized_start = 35
_NODEEMBEDDINGREQUEST._serialized_end = 75
_NODEEMBEDDINGRESPONSE._serialized_start = 77
_NODEEMBEDDINGRESPONSE._serialized_end = 166
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'embedding_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_NODEEMBEDDINGREQUEST._serialized_start=35
_NODEEMBEDDINGREQUEST._serialized_end=75
_NODEEMBEDDINGRESPONSE._serialized_start=77
_NODEEMBEDDINGRESPONSE._serialized_end=166
# @@protoc_insertion_point(module_scope)

View File

@@ -1,299 +0,0 @@
#!/usr/bin/env python3
"""
Graph Partition Module for LEANN DiskANN Backend
This module provides Python bindings for the graph partition functionality
of DiskANN, allowing users to partition disk-based indices for better
performance.
"""
import os
import shutil
import subprocess
import tempfile
from pathlib import Path
from typing import Optional
class GraphPartitioner:
"""
A Python interface for DiskANN's graph partition functionality.
This class provides methods to partition disk-based indices for improved
search performance and memory efficiency.
"""
def __init__(self, build_type: str = "release"):
"""
Initialize the GraphPartitioner.
Args:
build_type: Build type for the executables ("debug" or "release")
"""
self.build_type = build_type
self._ensure_executables()
def _get_executable_path(self, name: str) -> str:
"""Get the path to a graph partition executable."""
# Get the directory where this Python module is located
module_dir = Path(__file__).parent
# Navigate to the graph_partition directory
graph_partition_dir = module_dir.parent / "third_party" / "DiskANN" / "graph_partition"
executable_path = graph_partition_dir / "build" / self.build_type / "graph_partition" / name
if not executable_path.exists():
raise FileNotFoundError(f"Executable {name} not found at {executable_path}")
return str(executable_path)
def _ensure_executables(self):
"""Ensure that the required executables are built."""
try:
self._get_executable_path("partitioner")
self._get_executable_path("index_relayout")
except FileNotFoundError:
# Try to build the executables automatically
print("Executables not found, attempting to build them...")
self._build_executables()
def _build_executables(self):
"""Build the required executables."""
graph_partition_dir = (
Path(__file__).parent.parent / "third_party" / "DiskANN" / "graph_partition"
)
original_dir = os.getcwd()
try:
os.chdir(graph_partition_dir)
# Clean any existing build
if (graph_partition_dir / "build").exists():
shutil.rmtree(graph_partition_dir / "build")
# Run the build script
cmd = ["./build.sh", self.build_type, "split_graph", "/tmp/dummy"]
subprocess.run(cmd, capture_output=True, text=True, cwd=graph_partition_dir)
# Check if executables were created
partitioner_path = self._get_executable_path("partitioner")
relayout_path = self._get_executable_path("index_relayout")
print(f"✅ Built partitioner: {partitioner_path}")
print(f"✅ Built index_relayout: {relayout_path}")
except Exception as e:
raise RuntimeError(f"Failed to build executables: {e}")
finally:
os.chdir(original_dir)
def partition_graph(
self,
index_prefix_path: str,
output_dir: Optional[str] = None,
partition_prefix: Optional[str] = None,
**kwargs,
) -> tuple[str, str]:
"""
Partition a disk-based index for improved performance.
Args:
index_prefix_path: Path to the index prefix (e.g., "/path/to/index")
output_dir: Output directory for results (defaults to parent of index_prefix_path)
partition_prefix: Prefix for output files (defaults to basename of index_prefix_path)
**kwargs: Additional parameters for graph partitioning:
- gp_times: Number of LDG partition iterations (default: 10)
- lock_nums: Number of lock nodes (default: 10)
- cut: Cut adjacency list degree (default: 100)
- scale_factor: Scale factor (default: 1)
- data_type: Data type (default: "float")
- thread_nums: Number of threads (default: 10)
Returns:
Tuple of (disk_graph_index_path, partition_bin_path)
Raises:
RuntimeError: If the partitioning process fails
"""
# Set default parameters
params = {
"gp_times": 10,
"lock_nums": 10,
"cut": 100,
"scale_factor": 1,
"data_type": "float",
"thread_nums": 10,
**kwargs,
}
# Determine output directory
if output_dir is None:
output_dir = str(Path(index_prefix_path).parent)
# Create output directory if it doesn't exist
Path(output_dir).mkdir(parents=True, exist_ok=True)
# Determine partition prefix
if partition_prefix is None:
partition_prefix = Path(index_prefix_path).name
# Get executable paths
partitioner_path = self._get_executable_path("partitioner")
relayout_path = self._get_executable_path("index_relayout")
# Create temporary directory for processing
with tempfile.TemporaryDirectory() as temp_dir:
# Change to the graph_partition directory for temporary files
graph_partition_dir = (
Path(__file__).parent.parent / "third_party" / "DiskANN" / "graph_partition"
)
original_dir = os.getcwd()
try:
os.chdir(graph_partition_dir)
# Create temporary data directory
temp_data_dir = Path(temp_dir) / "data"
temp_data_dir.mkdir(parents=True, exist_ok=True)
# Set up paths for temporary files
graph_path = temp_data_dir / "starling" / "_M_R_L_B" / "GRAPH"
graph_gp_path = (
graph_path
/ f"GP_TIMES_{params['gp_times']}_LOCK_{params['lock_nums']}_GP_USE_FREQ0_CUT{params['cut']}_SCALE{params['scale_factor']}"
)
graph_gp_path.mkdir(parents=True, exist_ok=True)
# Find input index file
old_index_file = f"{index_prefix_path}_disk_beam_search.index"
if not os.path.exists(old_index_file):
old_index_file = f"{index_prefix_path}_disk.index"
if not os.path.exists(old_index_file):
raise RuntimeError(f"Index file not found: {old_index_file}")
# Run partitioner
gp_file_path = graph_gp_path / "_part.bin"
partitioner_cmd = [
partitioner_path,
"--index_file",
old_index_file,
"--data_type",
params["data_type"],
"--gp_file",
str(gp_file_path),
"-T",
str(params["thread_nums"]),
"--ldg_times",
str(params["gp_times"]),
"--scale",
str(params["scale_factor"]),
"--mode",
"1",
]
print(f"Running partitioner: {' '.join(partitioner_cmd)}")
result = subprocess.run(
partitioner_cmd, capture_output=True, text=True, cwd=graph_partition_dir
)
if result.returncode != 0:
raise RuntimeError(
f"Partitioner failed with return code {result.returncode}.\n"
f"stdout: {result.stdout}\n"
f"stderr: {result.stderr}"
)
# Run relayout
part_tmp_index = graph_gp_path / "_part_tmp.index"
relayout_cmd = [
relayout_path,
old_index_file,
str(gp_file_path),
params["data_type"],
"1",
]
print(f"Running relayout: {' '.join(relayout_cmd)}")
result = subprocess.run(
relayout_cmd, capture_output=True, text=True, cwd=graph_partition_dir
)
if result.returncode != 0:
raise RuntimeError(
f"Relayout failed with return code {result.returncode}.\n"
f"stdout: {result.stdout}\n"
f"stderr: {result.stderr}"
)
# Copy results to output directory
disk_graph_path = Path(output_dir) / f"{partition_prefix}_disk_graph.index"
partition_bin_path = Path(output_dir) / f"{partition_prefix}_partition.bin"
shutil.copy2(part_tmp_index, disk_graph_path)
shutil.copy2(gp_file_path, partition_bin_path)
print(f"Results copied to: {output_dir}")
return str(disk_graph_path), str(partition_bin_path)
finally:
os.chdir(original_dir)
def get_partition_info(self, partition_bin_path: str) -> dict:
"""
Get information about a partition file.
Args:
partition_bin_path: Path to the partition binary file
Returns:
Dictionary containing partition information
"""
if not os.path.exists(partition_bin_path):
raise FileNotFoundError(f"Partition file not found: {partition_bin_path}")
# For now, return basic file information
# In the future, this could parse the binary file for detailed info
stat = os.stat(partition_bin_path)
return {
"file_size": stat.st_size,
"file_path": partition_bin_path,
"modified_time": stat.st_mtime,
}
def partition_graph(
index_prefix_path: str,
output_dir: Optional[str] = None,
partition_prefix: Optional[str] = None,
build_type: str = "release",
**kwargs,
) -> tuple[str, str]:
"""
Convenience function to partition a graph index.
Args:
index_prefix_path: Path to the index prefix
output_dir: Output directory (defaults to parent of index_prefix_path)
partition_prefix: Prefix for output files (defaults to basename of index_prefix_path)
build_type: Build type for executables ("debug" or "release")
**kwargs: Additional parameters for graph partitioning
Returns:
Tuple of (disk_graph_index_path, partition_bin_path)
"""
partitioner = GraphPartitioner(build_type=build_type)
return partitioner.partition_graph(index_prefix_path, output_dir, partition_prefix, **kwargs)
# Example usage:
if __name__ == "__main__":
# Example: partition an index
try:
disk_graph_path, partition_bin_path = partition_graph(
"/path/to/your/index_prefix", gp_times=10, lock_nums=10, cut=100
)
print("Partitioning completed successfully!")
print(f"Disk graph index: {disk_graph_path}")
print(f"Partition binary: {partition_bin_path}")
except Exception as e:
print(f"Partitioning failed: {e}")

View File

@@ -4,8 +4,8 @@ build-backend = "scikit_build_core.build"
[project]
name = "leann-backend-diskann"
version = "0.3.3"
dependencies = ["leann-core==0.3.3", "numpy", "protobuf>=3.19.0"]
version = "0.1.8"
dependencies = ["leann-core==0.1.8", "numpy"]
[tool.scikit-build]
# Key: simplified CMake path
@@ -17,5 +17,8 @@ editable.mode = "redirect"
cmake.build-type = "Release"
build.verbose = true
build.tool-args = ["-j8"]
# Let CMake find packages via Homebrew prefix
cmake.define = {CMAKE_PREFIX_PATH = {env = "CMAKE_PREFIX_PATH"}, OpenMP_ROOT = {env = "OpenMP_ROOT"}}
wheel.exclude = ["CMakeLists.txt", "src", "third_party/**", "*.o", "*.so"]
sdist.include = ["CMakeLists.txt", "src", "third_party", "leann_backend_diskann/*.txt"]
[tool.scikit-build.cmake.define]
CMAKE_BUILD_PARALLEL_LEVEL = "8"

View File

@@ -2,12 +2,12 @@ syntax = "proto3";
package protoembedding;
message NodeEmbeddingRequest {
repeated uint32 node_ids = 1;
message NodeEmbeddingRequest {
repeated uint32 node_ids = 1;
}
message NodeEmbeddingResponse {
bytes embeddings_data = 1; // All embedded binary datas
repeated int32 dimensions = 2; // Shape [batch_size, embedding_dim]
repeated uint32 missing_ids = 3; // Missing node ids
}
}

View File

@@ -5,28 +5,11 @@ set(CMAKE_CXX_COMPILER_WORKS 1)
# Set OpenMP path for macOS
if(APPLE)
# Detect Homebrew installation path (Apple Silicon vs Intel)
if(EXISTS "/opt/homebrew/opt/libomp")
set(HOMEBREW_PREFIX "/opt/homebrew")
elseif(EXISTS "/usr/local/opt/libomp")
set(HOMEBREW_PREFIX "/usr/local")
else()
message(FATAL_ERROR "Could not find libomp installation. Please install with: brew install libomp")
endif()
set(OpenMP_C_FLAGS "-Xpreprocessor -fopenmp -I${HOMEBREW_PREFIX}/opt/libomp/include")
set(OpenMP_CXX_FLAGS "-Xpreprocessor -fopenmp -I${HOMEBREW_PREFIX}/opt/libomp/include")
set(OpenMP_C_FLAGS "-Xpreprocessor -fopenmp -I/opt/homebrew/opt/libomp/include")
set(OpenMP_CXX_FLAGS "-Xpreprocessor -fopenmp -I/opt/homebrew/opt/libomp/include")
set(OpenMP_C_LIB_NAMES "omp")
set(OpenMP_CXX_LIB_NAMES "omp")
set(OpenMP_omp_LIBRARY "${HOMEBREW_PREFIX}/opt/libomp/lib/libomp.dylib")
# Force use of system libc++ to avoid version mismatch
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -stdlib=libc++")
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -stdlib=libc++")
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -stdlib=libc++")
# Set minimum macOS version for better compatibility
set(CMAKE_OSX_DEPLOYMENT_TARGET "11.0" CACHE STRING "Minimum macOS version")
set(OpenMP_omp_LIBRARY "/opt/homebrew/opt/libomp/lib/libomp.dylib")
endif()
# Use system ZeroMQ instead of building from source
@@ -41,6 +24,38 @@ set(MSGPACK_USE_BOOST OFF CACHE BOOL "" FORCE)
add_compile_definitions(MSGPACK_NO_BOOST)
include_directories(third_party/msgpack-c/include)
# Find Python for our own use (not for Faiss)
if(DEFINED SKBUILD)
message(STATUS "Building with scikit-build")
# scikit-build-core provides Python information
endif()
# Find Python - scikit-build-core should provide this
find_package(Python REQUIRED COMPONENTS Interpreter Development.Module NumPy)
# Print Python information for debugging
message(STATUS "Python_FOUND: ${Python_FOUND}")
message(STATUS "Python_VERSION: ${Python_VERSION}")
message(STATUS "Python_EXECUTABLE: ${Python_EXECUTABLE}")
message(STATUS "Python_INCLUDE_DIRS: ${Python_INCLUDE_DIRS}")
message(STATUS "Python_NumPy_INCLUDE_DIRS: ${Python_NumPy_INCLUDE_DIRS}")
# Pass Python information to faiss through cache variables
set(Python_EXECUTABLE ${Python_EXECUTABLE} CACHE FILEPATH "Python executable" FORCE)
set(Python_INCLUDE_DIRS ${Python_INCLUDE_DIRS} CACHE PATH "Python include dirs" FORCE)
set(Python_NumPy_INCLUDE_DIRS ${Python_NumPy_INCLUDE_DIRS} CACHE PATH "NumPy include dirs" FORCE)
set(Python_VERSION ${Python_VERSION} CACHE STRING "Python version" FORCE)
set(Python_FOUND ${Python_FOUND} CACHE BOOL "Python found" FORCE)
# Also set Python3 variables for compatibility
set(Python3_EXECUTABLE ${Python_EXECUTABLE} CACHE FILEPATH "Python3 executable" FORCE)
set(Python3_INCLUDE_DIRS ${Python_INCLUDE_DIRS} CACHE PATH "Python3 include dirs" FORCE)
set(Python3_NumPy_INCLUDE_DIRS ${Python_NumPy_INCLUDE_DIRS} CACHE PATH "NumPy include dirs" FORCE)
set(Python3_VERSION ${Python_VERSION} CACHE STRING "Python3 version" FORCE)
set(Python3_FOUND ${Python_FOUND} CACHE BOOL "Python3 found" FORCE)
set(Python3_Development_FOUND TRUE CACHE BOOL "Python3 development found" FORCE)
set(Python3_NumPy_FOUND TRUE CACHE BOOL "Python3 NumPy found" FORCE)
# Faiss configuration - streamlined build
set(FAISS_ENABLE_PYTHON ON CACHE BOOL "" FORCE)
set(FAISS_ENABLE_GPU OFF CACHE BOOL "" FORCE)
@@ -49,28 +64,9 @@ set(BUILD_TESTING OFF CACHE BOOL "" FORCE)
set(FAISS_ENABLE_C_API OFF CACHE BOOL "" FORCE)
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
# Disable x86-specific SIMD optimizations (important for ARM64 compatibility)
# Disable additional SIMD versions to speed up compilation
set(FAISS_ENABLE_AVX2 OFF CACHE BOOL "" FORCE)
set(FAISS_ENABLE_AVX512 OFF CACHE BOOL "" FORCE)
set(FAISS_ENABLE_SSE4_1 OFF CACHE BOOL "" FORCE)
# ARM64-specific configuration
if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64")
message(STATUS "Configuring Faiss for ARM64 architecture")
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
# Use SVE optimization level for ARM64 Linux (as seen in Faiss conda build)
set(FAISS_OPT_LEVEL "sve" CACHE STRING "" FORCE)
message(STATUS "Setting FAISS_OPT_LEVEL to 'sve' for ARM64 Linux")
else()
# Use generic optimization for other ARM64 platforms (like macOS)
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
message(STATUS "Setting FAISS_OPT_LEVEL to 'generic' for ARM64 ${CMAKE_SYSTEM_NAME}")
endif()
# ARM64 compatibility: Faiss submodule has been modified to fix x86 header inclusion
message(STATUS "Using ARM64-compatible Faiss submodule")
endif()
# Additional optimization options from INSTALL.md
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "" FORCE)
@@ -88,4 +84,8 @@ set(FAISS_BUILD_AVX512 OFF CACHE BOOL "" FORCE)
# IMPORTANT: Disable building AVX versions to speed up compilation
set(FAISS_BUILD_AVX_VERSIONS OFF CACHE BOOL "" FORCE)
add_subdirectory(third_party/faiss)
# Force Faiss to use our Python settings
set(Python_FIND_VIRTUALENV ONLY CACHE STRING "" FORCE)
set(Python3_FIND_VIRTUALENV ONLY CACHE STRING "" FORCE)
add_subdirectory(third_party/faiss)

View File

@@ -1 +1 @@
from . import hnsw_backend as hnsw_backend
from . import hnsw_backend

View File

@@ -1,122 +1,87 @@
import argparse
import gc # Import garbage collector interface
import logging
import os
import struct
import sys
import time
import numpy as np
# Set up logging to avoid print buffer issues
logger = logging.getLogger(__name__)
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
logger.setLevel(log_level)
import os
import argparse
import gc # Import garbage collector interface
import time
# --- FourCCs (add more if needed) ---
INDEX_HNSW_FLAT_FOURCC = int.from_bytes(b"IHNf", "little")
INDEX_HNSW_FLAT_FOURCC = int.from_bytes(b'IHNf', 'little')
# Add other HNSW fourccs if you expect different storage types inside HNSW
# INDEX_HNSW_PQ_FOURCC = int.from_bytes(b'IHNp', 'little')
# INDEX_HNSW_SQ_FOURCC = int.from_bytes(b'IHNs', 'little')
# INDEX_HNSW_CAGRA_FOURCC = int.from_bytes(b'IHNc', 'little') # Example
EXPECTED_HNSW_FOURCCS = {INDEX_HNSW_FLAT_FOURCC} # Modify if needed
NULL_INDEX_FOURCC = int.from_bytes(b"null", "little")
EXPECTED_HNSW_FOURCCS = {INDEX_HNSW_FLAT_FOURCC} # Modify if needed
NULL_INDEX_FOURCC = int.from_bytes(b'null', 'little')
# --- Helper functions for reading/writing binary data ---
def read_struct(f, fmt):
"""Reads data according to the struct format."""
size = struct.calcsize(fmt)
data = f.read(size)
if len(data) != size:
raise EOFError(
f"File ended unexpectedly reading struct fmt '{fmt}'. Expected {size} bytes, got {len(data)}."
)
raise EOFError(f"File ended unexpectedly reading struct fmt '{fmt}'. Expected {size} bytes, got {len(data)}.")
return struct.unpack(fmt, data)[0]
def read_vector_raw(f, element_fmt_char):
"""Reads a vector (size followed by data), returns count and raw bytes."""
count = -1 # Initialize count
total_bytes = -1 # Initialize total_bytes
count = -1 # Initialize count
total_bytes = -1 # Initialize total_bytes
try:
count = read_struct(f, "<Q") # size_t usually 64-bit unsigned
count = read_struct(f, '<Q') # size_t usually 64-bit unsigned
element_size = struct.calcsize(element_fmt_char)
# --- FIX for MemoryError: Check for unreasonably large count ---
max_reasonable_count = 10 * (10**9) # ~10 billion elements limit
max_reasonable_count = 10 * (10**9) # ~10 billion elements limit
if count > max_reasonable_count or count < 0:
raise MemoryError(
f"Vector count {count} seems unreasonably large, possibly due to file corruption or incorrect format read."
)
raise MemoryError(f"Vector count {count} seems unreasonably large, possibly due to file corruption or incorrect format read.")
total_bytes = count * element_size
# --- FIX for MemoryError: Check for huge byte size before allocation ---
max_reasonable_bytes = 50 * (1024**3) # ~50 GB limit
if total_bytes > max_reasonable_bytes or total_bytes < 0: # Check for overflow
raise MemoryError(
f"Attempting to read {total_bytes} bytes ({count} elements * {element_size} bytes/element), which exceeds the safety limit. File might be corrupted or format mismatch."
)
max_reasonable_bytes = 50 * (1024**3) # ~50 GB limit
if total_bytes > max_reasonable_bytes or total_bytes < 0: # Check for overflow
raise MemoryError(f"Attempting to read {total_bytes} bytes ({count} elements * {element_size} bytes/element), which exceeds the safety limit. File might be corrupted or format mismatch.")
data_bytes = f.read(total_bytes)
if len(data_bytes) != total_bytes:
raise EOFError(
f"File ended unexpectedly reading vector data. Expected {total_bytes} bytes, got {len(data_bytes)}."
)
raise EOFError(f"File ended unexpectedly reading vector data. Expected {total_bytes} bytes, got {len(data_bytes)}.")
return count, data_bytes
except (MemoryError, OverflowError) as e:
# Add context to the error message
print(
f"\nError during raw vector read (element_fmt='{element_fmt_char}', count={count}, total_bytes={total_bytes}): {e}",
file=sys.stderr,
)
raise e # Re-raise the original error type
# Add context to the error message
print(f"\nError during raw vector read (element_fmt='{element_fmt_char}', count={count}, total_bytes={total_bytes}): {e}", file=sys.stderr)
raise e # Re-raise the original error type
def read_numpy_vector(f, np_dtype, struct_fmt_char):
"""Reads a vector into a NumPy array."""
count = -1 # Initialize count for robust error handling
print(
f" Reading vector (dtype={np_dtype}, fmt='{struct_fmt_char}')... ",
end="",
flush=True,
)
count = -1 # Initialize count for robust error handling
print(f" Reading vector (dtype={np_dtype}, fmt='{struct_fmt_char}')... ", end='', flush=True)
try:
count, data_bytes = read_vector_raw(f, struct_fmt_char)
print(f"Count={count}, Bytes={len(data_bytes)}")
if count > 0 and len(data_bytes) > 0:
arr = np.frombuffer(data_bytes, dtype=np_dtype)
if arr.size != count:
raise ValueError(
f"Inconsistent array size after reading. Expected {count}, got {arr.size}"
)
raise ValueError(f"Inconsistent array size after reading. Expected {count}, got {arr.size}")
return arr
elif count == 0:
return np.array([], dtype=np_dtype)
return np.array([], dtype=np_dtype)
else:
raise ValueError("Read zero bytes but count > 0.")
raise ValueError("Read zero bytes but count > 0.")
except MemoryError as e:
# Now count should be defined (or -1 if error was in read_struct)
print(
f"\nMemoryError creating NumPy array (dtype={np_dtype}, count={count}). {e}",
file=sys.stderr,
)
print(f"\nMemoryError creating NumPy array (dtype={np_dtype}, count={count}). {e}", file=sys.stderr)
raise e
except Exception as e: # Catch other potential errors like ValueError
print(
f"\nError reading numpy vector (dtype={np_dtype}, fmt='{struct_fmt_char}', count={count}): {e}",
file=sys.stderr,
)
except Exception as e: # Catch other potential errors like ValueError
print(f"\nError reading numpy vector (dtype={np_dtype}, fmt='{struct_fmt_char}', count={count}): {e}", file=sys.stderr)
raise e
def write_numpy_vector(f, arr, struct_fmt_char):
"""Writes a NumPy array as a vector (size followed by data)."""
count = arr.size
f.write(struct.pack("<Q", count))
f.write(struct.pack('<Q', count))
try:
expected_dtype = np.dtype(struct_fmt_char)
if arr.dtype != expected_dtype:
@@ -124,30 +89,23 @@ def write_numpy_vector(f, arr, struct_fmt_char):
else:
data_to_write = arr.tobytes()
f.write(data_to_write)
del data_to_write # Hint GC
del data_to_write # Hint GC
except MemoryError as e:
print(
f"\nMemoryError converting NumPy array to bytes for writing (size={count}, dtype={arr.dtype}). {e}",
file=sys.stderr,
)
raise e
print(f"\nMemoryError converting NumPy array to bytes for writing (size={count}, dtype={arr.dtype}). {e}", file=sys.stderr)
raise e
def write_list_vector(f, lst, struct_fmt_char):
"""Writes a Python list as a vector iteratively."""
count = len(lst)
f.write(struct.pack("<Q", count))
fmt = "<" + struct_fmt_char
f.write(struct.pack('<Q', count))
fmt = '<' + struct_fmt_char
chunk_size = 1024 * 1024
element_size = struct.calcsize(fmt)
# Allocate buffer outside the loop if possible, or handle MemoryError during allocation
try:
buffer = bytearray(chunk_size * element_size)
except MemoryError:
print(
f"MemoryError: Cannot allocate buffer for writing list vector chunk (size {chunk_size * element_size} bytes).",
file=sys.stderr,
)
print(f"MemoryError: Cannot allocate buffer for writing list vector chunk (size {chunk_size * element_size} bytes).", file=sys.stderr)
raise
buffer_count = 0
@@ -158,80 +116,66 @@ def write_list_vector(f, lst, struct_fmt_char):
buffer_count += 1
if buffer_count == chunk_size or i == count - 1:
f.write(buffer[: buffer_count * element_size])
f.write(buffer[:buffer_count * element_size])
buffer_count = 0
except struct.error as e:
print(
f"\nStruct packing error for item {item} at index {i} with format '{fmt}'. {e}",
file=sys.stderr,
)
print(f"\nStruct packing error for item {item} at index {i} with format '{fmt}'. {e}", file=sys.stderr)
raise e
def get_cum_neighbors(cum_nneighbor_per_level_np, level):
"""Helper to get cumulative neighbors count, matching C++ logic."""
if level < 0:
return 0
if level < 0: return 0
if level < len(cum_nneighbor_per_level_np):
return cum_nneighbor_per_level_np[level]
else:
return cum_nneighbor_per_level_np[-1] if len(cum_nneighbor_per_level_np) > 0 else 0
def write_compact_format(
f_out,
original_hnsw_data,
assign_probas_np,
cum_nneighbor_per_level_np,
levels_np,
compact_level_ptr,
compact_node_offsets_np,
compact_neighbors_data,
storage_fourcc,
storage_data,
):
def write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np,
levels_np, compact_level_ptr, compact_node_offsets_np,
compact_neighbors_data, storage_fourcc, storage_data):
"""Write HNSW data in compact format following C++ read order exactly."""
# Write IndexHNSW Header
f_out.write(struct.pack("<I", original_hnsw_data["index_fourcc"]))
f_out.write(struct.pack("<i", original_hnsw_data["d"]))
f_out.write(struct.pack("<q", original_hnsw_data["ntotal"]))
f_out.write(struct.pack("<q", original_hnsw_data["dummy1"]))
f_out.write(struct.pack("<q", original_hnsw_data["dummy2"]))
f_out.write(struct.pack("<?", original_hnsw_data["is_trained"]))
f_out.write(struct.pack("<i", original_hnsw_data["metric_type"]))
if original_hnsw_data["metric_type"] > 1:
f_out.write(struct.pack("<f", original_hnsw_data["metric_arg"]))
f_out.write(struct.pack('<I', original_hnsw_data['index_fourcc']))
f_out.write(struct.pack('<i', original_hnsw_data['d']))
f_out.write(struct.pack('<q', original_hnsw_data['ntotal']))
f_out.write(struct.pack('<q', original_hnsw_data['dummy1']))
f_out.write(struct.pack('<q', original_hnsw_data['dummy2']))
f_out.write(struct.pack('<?', original_hnsw_data['is_trained']))
f_out.write(struct.pack('<i', original_hnsw_data['metric_type']))
if original_hnsw_data['metric_type'] > 1:
f_out.write(struct.pack('<f', original_hnsw_data['metric_arg']))
# Write HNSW struct parts (standard order)
write_numpy_vector(f_out, assign_probas_np, "d")
write_numpy_vector(f_out, cum_nneighbor_per_level_np, "i")
write_numpy_vector(f_out, levels_np, "i")
write_numpy_vector(f_out, assign_probas_np, 'd')
write_numpy_vector(f_out, cum_nneighbor_per_level_np, 'i')
write_numpy_vector(f_out, levels_np, 'i')
# Write compact format flag
f_out.write(struct.pack("<?", True)) # storage_is_compact = True
f_out.write(struct.pack('<?', True)) # storage_is_compact = True
# Write compact data in CORRECT C++ read order: level_ptr, node_offsets FIRST
if isinstance(compact_level_ptr, np.ndarray):
write_numpy_vector(f_out, compact_level_ptr, "Q")
write_numpy_vector(f_out, compact_level_ptr, 'Q')
else:
write_list_vector(f_out, compact_level_ptr, "Q")
write_numpy_vector(f_out, compact_node_offsets_np, "Q")
write_list_vector(f_out, compact_level_ptr, 'Q')
write_numpy_vector(f_out, compact_node_offsets_np, 'Q')
# Write HNSW scalar parameters
f_out.write(struct.pack("<i", original_hnsw_data["entry_point"]))
f_out.write(struct.pack("<i", original_hnsw_data["max_level"]))
f_out.write(struct.pack("<i", original_hnsw_data["efConstruction"]))
f_out.write(struct.pack("<i", original_hnsw_data["efSearch"]))
f_out.write(struct.pack("<i", original_hnsw_data["dummy_upper_beam"]))
f_out.write(struct.pack('<i', original_hnsw_data['entry_point']))
f_out.write(struct.pack('<i', original_hnsw_data['max_level']))
f_out.write(struct.pack('<i', original_hnsw_data['efConstruction']))
f_out.write(struct.pack('<i', original_hnsw_data['efSearch']))
f_out.write(struct.pack('<i', original_hnsw_data['dummy_upper_beam']))
# Write storage fourcc (this determines how to read what follows)
f_out.write(struct.pack("<I", storage_fourcc))
f_out.write(struct.pack('<I', storage_fourcc))
# Write compact neighbors data AFTER storage fourcc
write_list_vector(f_out, compact_neighbors_data, "i")
write_list_vector(f_out, compact_neighbors_data, 'i')
# Write storage data if not NULL (only after neighbors)
if storage_fourcc != NULL_INDEX_FOURCC and storage_data:
f_out.write(storage_data)
@@ -239,244 +183,185 @@ def write_compact_format(
# --- Main Conversion Logic ---
def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=True):
"""
Converts an HNSW graph file to the CSR format.
Supports both original and already-compact formats (backward compatibility).
Args:
input_filename: Input HNSW index file
output_filename: Output CSR index file
prune_embeddings: Whether to prune embedding storage (write NULL storage marker)
"""
# Keep prints simple; rely on CI runner to flush output as needed
print(f"Starting conversion: {input_filename} -> {output_filename}")
start_time = time.time()
original_hnsw_data = {}
neighbors_np = None # Initialize to allow check in finally block
neighbors_np = None # Initialize to allow check in finally block
try:
with open(input_filename, "rb") as f_in, open(output_filename, "wb") as f_out:
with open(input_filename, 'rb') as f_in, open(output_filename, 'wb') as f_out:
# --- Read IndexHNSW FourCC and Header ---
print(f"[{time.time() - start_time:.2f}s] Reading Index HNSW header...")
# ... (Keep the header reading logic as before) ...
hnsw_index_fourcc = read_struct(f_in, "<I")
hnsw_index_fourcc = read_struct(f_in, '<I')
if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS:
print(
f"Error: Expected HNSW Index FourCC ({list(EXPECTED_HNSW_FOURCCS)}), got {hnsw_index_fourcc:08x}.",
file=sys.stderr,
)
return False
original_hnsw_data["index_fourcc"] = hnsw_index_fourcc
original_hnsw_data["d"] = read_struct(f_in, "<i")
original_hnsw_data["ntotal"] = read_struct(f_in, "<q")
original_hnsw_data["dummy1"] = read_struct(f_in, "<q")
original_hnsw_data["dummy2"] = read_struct(f_in, "<q")
original_hnsw_data["is_trained"] = read_struct(f_in, "?")
original_hnsw_data["metric_type"] = read_struct(f_in, "<i")
original_hnsw_data["metric_arg"] = 0.0
if original_hnsw_data["metric_type"] > 1:
original_hnsw_data["metric_arg"] = read_struct(f_in, "<f")
print(
f"[{time.time() - start_time:.2f}s] Header read: d={original_hnsw_data['d']}, ntotal={original_hnsw_data['ntotal']}"
)
print(f"Error: Expected HNSW Index FourCC ({list(EXPECTED_HNSW_FOURCCS)}), got {hnsw_index_fourcc:08x}.", file=sys.stderr)
return False
original_hnsw_data['index_fourcc'] = hnsw_index_fourcc
original_hnsw_data['d'] = read_struct(f_in, '<i')
original_hnsw_data['ntotal'] = read_struct(f_in, '<q')
original_hnsw_data['dummy1'] = read_struct(f_in, '<q')
original_hnsw_data['dummy2'] = read_struct(f_in, '<q')
original_hnsw_data['is_trained'] = read_struct(f_in, '?')
original_hnsw_data['metric_type'] = read_struct(f_in, '<i')
original_hnsw_data['metric_arg'] = 0.0
if original_hnsw_data['metric_type'] > 1:
original_hnsw_data['metric_arg'] = read_struct(f_in, '<f')
print(f"[{time.time() - start_time:.2f}s] Header read: d={original_hnsw_data['d']}, ntotal={original_hnsw_data['ntotal']}")
# --- Read original HNSW struct data ---
print(f"[{time.time() - start_time:.2f}s] Reading HNSW struct vectors...")
assign_probas_np = read_numpy_vector(f_in, np.float64, "d")
print(
f"[{time.time() - start_time:.2f}s] Read assign_probas ({assign_probas_np.size})"
)
assign_probas_np = read_numpy_vector(f_in, np.float64, 'd')
print(f"[{time.time() - start_time:.2f}s] Read assign_probas ({assign_probas_np.size})")
gc.collect()
cum_nneighbor_per_level_np = read_numpy_vector(f_in, np.int32, "i")
print(
f"[{time.time() - start_time:.2f}s] Read cum_nneighbor_per_level ({cum_nneighbor_per_level_np.size})"
)
cum_nneighbor_per_level_np = read_numpy_vector(f_in, np.int32, 'i')
print(f"[{time.time() - start_time:.2f}s] Read cum_nneighbor_per_level ({cum_nneighbor_per_level_np.size})")
gc.collect()
levels_np = read_numpy_vector(f_in, np.int32, "i")
levels_np = read_numpy_vector(f_in, np.int32, 'i')
print(f"[{time.time() - start_time:.2f}s] Read levels ({levels_np.size})")
gc.collect()
ntotal = len(levels_np)
if ntotal != original_hnsw_data["ntotal"]:
print(
f"Warning: ntotal mismatch! Header says {original_hnsw_data['ntotal']}, levels vector size is {ntotal}. Using levels vector size.",
file=sys.stderr,
)
original_hnsw_data["ntotal"] = ntotal
if ntotal != original_hnsw_data['ntotal']:
print(f"Warning: ntotal mismatch! Header says {original_hnsw_data['ntotal']}, levels vector size is {ntotal}. Using levels vector size.", file=sys.stderr)
original_hnsw_data['ntotal'] = ntotal
# --- Check for compact format flag ---
print(f"[{time.time() - start_time:.2f}s] Probing for compact storage flag...")
pos_before_compact = f_in.tell()
try:
is_compact_flag = read_struct(f_in, "<?")
is_compact_flag = read_struct(f_in, '<?')
print(f"[{time.time() - start_time:.2f}s] Found compact flag: {is_compact_flag}")
if is_compact_flag:
# Input is already in compact format - read compact data
print(
f"[{time.time() - start_time:.2f}s] Input is already in compact format, reading compact data..."
)
compact_level_ptr = read_numpy_vector(f_in, np.uint64, "Q")
print(
f"[{time.time() - start_time:.2f}s] Read compact_level_ptr ({compact_level_ptr.size})"
)
compact_node_offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
print(
f"[{time.time() - start_time:.2f}s] Read compact_node_offsets ({compact_node_offsets_np.size})"
)
print(f"[{time.time() - start_time:.2f}s] Input is already in compact format, reading compact data...")
compact_level_ptr = read_numpy_vector(f_in, np.uint64, 'Q')
print(f"[{time.time() - start_time:.2f}s] Read compact_level_ptr ({compact_level_ptr.size})")
compact_node_offsets_np = read_numpy_vector(f_in, np.uint64, 'Q')
print(f"[{time.time() - start_time:.2f}s] Read compact_node_offsets ({compact_node_offsets_np.size})")
# Read scalar parameters
original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
original_hnsw_data["max_level"] = read_struct(f_in, "<i")
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
print(
f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})"
)
original_hnsw_data['entry_point'] = read_struct(f_in, '<i')
original_hnsw_data['max_level'] = read_struct(f_in, '<i')
original_hnsw_data['efConstruction'] = read_struct(f_in, '<i')
original_hnsw_data['efSearch'] = read_struct(f_in, '<i')
original_hnsw_data['dummy_upper_beam'] = read_struct(f_in, '<i')
print(f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})")
# Read storage fourcc
storage_fourcc = read_struct(f_in, "<I")
print(
f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}"
)
storage_fourcc = read_struct(f_in, '<I')
print(f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}")
if prune_embeddings and storage_fourcc != NULL_INDEX_FOURCC:
# Read compact neighbors data
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i")
print(
f"[{time.time() - start_time:.2f}s] Read compact neighbors data ({compact_neighbors_data_np.size})"
)
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, 'i')
print(f"[{time.time() - start_time:.2f}s] Read compact neighbors data ({compact_neighbors_data_np.size})")
compact_neighbors_data = compact_neighbors_data_np.tolist()
del compact_neighbors_data_np
# Skip storage data and write with NULL marker
print(
f"[{time.time() - start_time:.2f}s] Pruning embeddings: Writing NULL storage marker."
)
print(f"[{time.time() - start_time:.2f}s] Pruning embeddings: Writing NULL storage marker.")
storage_fourcc = NULL_INDEX_FOURCC
elif not prune_embeddings:
# Read and preserve compact neighbors and storage
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i")
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, 'i')
compact_neighbors_data = compact_neighbors_data_np.tolist()
del compact_neighbors_data_np
# Read remaining storage data
storage_data = f_in.read()
else:
# Already pruned (NULL storage)
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i")
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, 'i')
compact_neighbors_data = compact_neighbors_data_np.tolist()
del compact_neighbors_data_np
storage_data = b""
storage_data = b''
# Write the updated compact format
print(f"[{time.time() - start_time:.2f}s] Writing updated compact format...")
write_compact_format(
f_out,
original_hnsw_data,
assign_probas_np,
cum_nneighbor_per_level_np,
levels_np,
compact_level_ptr,
compact_node_offsets_np,
compact_neighbors_data,
storage_fourcc,
storage_data if not prune_embeddings else b"",
)
write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np,
levels_np, compact_level_ptr, compact_node_offsets_np,
compact_neighbors_data, storage_fourcc, storage_data if not prune_embeddings else b'')
print(f"[{time.time() - start_time:.2f}s] Conversion complete.")
return True
else:
# is_compact=False, rewind and read original format
f_in.seek(pos_before_compact)
print(
f"[{time.time() - start_time:.2f}s] Compact flag is False, reading original format..."
)
print(f"[{time.time() - start_time:.2f}s] Compact flag is False, reading original format...")
except EOFError:
# No compact flag found, assume original format
f_in.seek(pos_before_compact)
print(
f"[{time.time() - start_time:.2f}s] No compact flag found, assuming original format..."
)
print(f"[{time.time() - start_time:.2f}s] No compact flag found, assuming original format...")
# --- Handle potential extra byte in original format (like C++ code) ---
print(
f"[{time.time() - start_time:.2f}s] Probing for potential extra byte before non-compact offsets..."
)
print(f"[{time.time() - start_time:.2f}s] Probing for potential extra byte before non-compact offsets...")
pos_before_probe = f_in.tell()
try:
suspected_flag = read_struct(f_in, "<B") # Read 1 byte
suspected_flag = read_struct(f_in, '<B') # Read 1 byte
if suspected_flag == 0x00:
print(
f"[{time.time() - start_time:.2f}s] Found and consumed an unexpected 0x00 byte."
)
print(f"[{time.time() - start_time:.2f}s] Found and consumed an unexpected 0x00 byte.")
elif suspected_flag == 0x01:
print(
f"[{time.time() - start_time:.2f}s] ERROR: Found 0x01 but is_compact should be False"
)
print(f"[{time.time() - start_time:.2f}s] ERROR: Found 0x01 but is_compact should be False")
raise ValueError("Inconsistent compact flag state")
else:
# Rewind - this byte is part of offsets data
f_in.seek(pos_before_probe)
print(
f"[{time.time() - start_time:.2f}s] Rewound to original position (byte was 0x{suspected_flag:02x})"
)
print(f"[{time.time() - start_time:.2f}s] Rewound to original position (byte was 0x{suspected_flag:02x})")
except EOFError:
f_in.seek(pos_before_probe)
print(
f"[{time.time() - start_time:.2f}s] No extra byte found (EOF), proceeding with offsets read"
)
print(f"[{time.time() - start_time:.2f}s] No extra byte found (EOF), proceeding with offsets read")
# --- Read original format data ---
offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
offsets_np = read_numpy_vector(f_in, np.uint64, 'Q')
print(f"[{time.time() - start_time:.2f}s] Read offsets ({offsets_np.size})")
if len(offsets_np) != ntotal + 1:
raise ValueError(
f"Inconsistent offsets size: len(levels)={ntotal} but len(offsets)={len(offsets_np)}"
)
raise ValueError(f"Inconsistent offsets size: len(levels)={ntotal} but len(offsets)={len(offsets_np)}")
gc.collect()
print(f"[{time.time() - start_time:.2f}s] Attempting to read neighbors vector...")
neighbors_np = read_numpy_vector(f_in, np.int32, "i")
neighbors_np = read_numpy_vector(f_in, np.int32, 'i')
print(f"[{time.time() - start_time:.2f}s] Read neighbors ({neighbors_np.size})")
expected_neighbors_size = offsets_np[-1] if ntotal > 0 else 0
if neighbors_np.size != expected_neighbors_size:
print(
f"Warning: neighbors vector size mismatch. Expected {expected_neighbors_size} based on offsets, got {neighbors_np.size}."
)
print(f"Warning: neighbors vector size mismatch. Expected {expected_neighbors_size} based on offsets, got {neighbors_np.size}.")
gc.collect()
original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
original_hnsw_data["max_level"] = read_struct(f_in, "<i")
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
print(
f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})"
)
original_hnsw_data['entry_point'] = read_struct(f_in, '<i')
original_hnsw_data['max_level'] = read_struct(f_in, '<i')
original_hnsw_data['efConstruction'] = read_struct(f_in, '<i')
original_hnsw_data['efSearch'] = read_struct(f_in, '<i')
original_hnsw_data['dummy_upper_beam'] = read_struct(f_in, '<i')
print(f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})")
print(f"[{time.time() - start_time:.2f}s] Checking for storage data...")
storage_fourcc = None
try:
storage_fourcc = read_struct(f_in, "<I")
print(
f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}."
)
storage_fourcc = read_struct(f_in, '<I')
print(f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}.")
except EOFError:
print(f"[{time.time() - start_time:.2f}s] No storage data found (EOF).")
print(f"[{time.time() - start_time:.2f}s] No storage data found (EOF).")
except Exception as e:
print(
f"[{time.time() - start_time:.2f}s] Error reading potential storage data: {e}"
)
print(f"[{time.time() - start_time:.2f}s] Error reading potential storage data: {e}")
# --- Perform Conversion ---
print(f"[{time.time() - start_time:.2f}s] Converting to CSR format...")
@@ -488,21 +373,17 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
current_level_ptr_idx = 0
current_data_idx = 0
total_valid_neighbors_counted = 0 # For validation
total_valid_neighbors_counted = 0 # For validation
# Optimize calculation by getting slices once per node if possible
for i in range(ntotal):
if i > 0 and i % (ntotal // 100 or 1) == 0: # Log progress roughly every 1%
if i > 0 and i % (ntotal // 100 or 1) == 0: # Log progress roughly every 1%
progress = (i / ntotal) * 100
elapsed = time.time() - start_time
print(
f"\r[{elapsed:.2f}s] Converting node {i}/{ntotal} ({progress:.1f}%)...",
end="",
)
print(f"\r[{elapsed:.2f}s] Converting node {i}/{ntotal} ({progress:.1f}%)...", end="")
node_max_level = levels_np[i] - 1
if node_max_level < -1:
node_max_level = -1
if node_max_level < -1: node_max_level = -1
node_ptr_start_index = current_level_ptr_idx
compact_node_offsets_np[i] = node_ptr_start_index
@@ -513,17 +394,13 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
for level in range(node_max_level + 1):
compact_level_ptr.append(current_data_idx)
begin_orig_np = original_offset_start + get_cum_neighbors(
cum_nneighbor_per_level_np, level
)
end_orig_np = original_offset_start + get_cum_neighbors(
cum_nneighbor_per_level_np, level + 1
)
begin_orig_np = original_offset_start + get_cum_neighbors(cum_nneighbor_per_level_np, level)
end_orig_np = original_offset_start + get_cum_neighbors(cum_nneighbor_per_level_np, level + 1)
begin_orig = int(begin_orig_np)
end_orig = int(end_orig_np)
neighbors_len = len(neighbors_np) # Cache length
neighbors_len = len(neighbors_np) # Cache length
begin_orig = min(max(0, begin_orig), neighbors_len)
end_orig = min(max(begin_orig, end_orig), neighbors_len)
@@ -536,117 +413,83 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
if num_valid > 0:
# Append valid neighbors
compact_neighbors_data.extend(
level_neighbors_slice[valid_neighbors_mask]
)
compact_neighbors_data.extend(level_neighbors_slice[valid_neighbors_mask])
current_data_idx += num_valid
total_valid_neighbors_counted += num_valid
compact_level_ptr.append(current_data_idx)
current_level_ptr_idx += num_pointers_expected
compact_node_offsets_np[ntotal] = current_level_ptr_idx
print(
f"\r[{time.time() - start_time:.2f}s] Conversion loop finished. "
) # Clear progress line
print(f"\r[{time.time() - start_time:.2f}s] Conversion loop finished. ") # Clear progress line
# --- Validation Checks ---
print(f"[{time.time() - start_time:.2f}s] Running validation checks...")
valid_check_passed = True
# Check 1: Total valid neighbors count
print(" Checking total valid neighbor count...")
print(f" Checking total valid neighbor count...")
expected_valid_count = np.sum(neighbors_np >= 0)
if total_valid_neighbors_counted != len(compact_neighbors_data):
print(
f"Error: Mismatch between counted valid neighbors ({total_valid_neighbors_counted}) and final compact_data size ({len(compact_neighbors_data)})!",
file=sys.stderr,
)
valid_check_passed = False
print(f"Error: Mismatch between counted valid neighbors ({total_valid_neighbors_counted}) and final compact_data size ({len(compact_neighbors_data)})!", file=sys.stderr)
valid_check_passed = False
if expected_valid_count != len(compact_neighbors_data):
print(
f"Error: Mismatch between NumPy count of valid neighbors ({expected_valid_count}) and final compact_data size ({len(compact_neighbors_data)})!",
file=sys.stderr,
)
valid_check_passed = False
print(f"Error: Mismatch between NumPy count of valid neighbors ({expected_valid_count}) and final compact_data size ({len(compact_neighbors_data)})!", file=sys.stderr)
valid_check_passed = False
else:
print(f" OK: Total valid neighbors = {len(compact_neighbors_data)}")
print(f" OK: Total valid neighbors = {len(compact_neighbors_data)}")
# Check 2: Final pointer indices consistency
print(" Checking final pointer indices...")
print(f" Checking final pointer indices...")
if compact_node_offsets_np[ntotal] != len(compact_level_ptr):
print(
f"Error: Final node offset ({compact_node_offsets_np[ntotal]}) doesn't match level_ptr size ({len(compact_level_ptr)})!",
file=sys.stderr,
)
valid_check_passed = False
if (
len(compact_level_ptr) > 0 and compact_level_ptr[-1] != len(compact_neighbors_data)
) or (len(compact_level_ptr) == 0 and len(compact_neighbors_data) != 0):
last_ptr = compact_level_ptr[-1] if len(compact_level_ptr) > 0 else -1
print(
f"Error: Last level pointer ({last_ptr}) doesn't match compact_data size ({len(compact_neighbors_data)})!",
file=sys.stderr,
)
valid_check_passed = False
print(f"Error: Final node offset ({compact_node_offsets_np[ntotal]}) doesn't match level_ptr size ({len(compact_level_ptr)})!", file=sys.stderr)
valid_check_passed = False
if (len(compact_level_ptr) > 0 and compact_level_ptr[-1] != len(compact_neighbors_data)) or \
(len(compact_level_ptr) == 0 and len(compact_neighbors_data) != 0):
last_ptr = compact_level_ptr[-1] if len(compact_level_ptr) > 0 else -1
print(f"Error: Last level pointer ({last_ptr}) doesn't match compact_data size ({len(compact_neighbors_data)})!", file=sys.stderr)
valid_check_passed = False
else:
print(" OK: Final pointers match data size.")
print(f" OK: Final pointers match data size.")
if not valid_check_passed:
print(
"Error: Validation checks failed. Output file might be incorrect.",
file=sys.stderr,
)
print("Error: Validation checks failed. Output file might be incorrect.", file=sys.stderr)
# Optional: Exit here if validation fails
# return False
# --- Explicitly delete large intermediate arrays ---
print(
f"[{time.time() - start_time:.2f}s] Deleting original neighbors and offsets arrays..."
)
print(f"[{time.time() - start_time:.2f}s] Deleting original neighbors and offsets arrays...")
del neighbors_np
del offsets_np
gc.collect()
print(
f" CSR Stats: |data|={len(compact_neighbors_data)}, |level_ptr|={len(compact_level_ptr)}"
)
print(f" CSR Stats: |data|={len(compact_neighbors_data)}, |level_ptr|={len(compact_level_ptr)}")
# --- Write CSR HNSW graph data using unified function ---
print(
f"[{time.time() - start_time:.2f}s] Writing CSR HNSW graph data in FAISS-compatible order..."
)
print(f"[{time.time() - start_time:.2f}s] Writing CSR HNSW graph data in FAISS-compatible order...")
# Determine storage fourcc and data based on prune_embeddings
if prune_embeddings:
print(" Pruning embeddings: Writing NULL storage marker.")
print(f" Pruning embeddings: Writing NULL storage marker.")
output_storage_fourcc = NULL_INDEX_FOURCC
storage_data = b""
storage_data = b''
else:
# Keep embeddings - read and preserve original storage data
if storage_fourcc and storage_fourcc != NULL_INDEX_FOURCC:
print(" Preserving embeddings: Reading original storage data...")
print(f" Preserving embeddings: Reading original storage data...")
storage_data = f_in.read() # Read remaining storage data
output_storage_fourcc = storage_fourcc
print(f" Read {len(storage_data)} bytes of storage data")
else:
print(" No embeddings found in original file (NULL storage)")
print(f" No embeddings found in original file (NULL storage)")
output_storage_fourcc = NULL_INDEX_FOURCC
storage_data = b""
storage_data = b''
# Use the unified write function
write_compact_format(
f_out,
original_hnsw_data,
assign_probas_np,
cum_nneighbor_per_level_np,
levels_np,
compact_level_ptr,
compact_node_offsets_np,
compact_neighbors_data,
output_storage_fourcc,
storage_data,
)
write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np,
levels_np, compact_level_ptr, compact_node_offsets_np,
compact_neighbors_data, output_storage_fourcc, storage_data)
# Clean up memory
del assign_probas_np, cum_nneighbor_per_level_np, levels_np
del compact_neighbors_data, compact_level_ptr, compact_node_offsets_np
@@ -660,66 +503,40 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
print(f"Error: Input file not found: {input_filename}", file=sys.stderr)
return False
except MemoryError as e:
print(
f"\nFatal MemoryError during conversion: {e}. Insufficient RAM.",
file=sys.stderr,
)
# Clean up potentially partially written output file?
try:
os.remove(output_filename)
except OSError:
pass
return False
print(f"\nFatal MemoryError during conversion: {e}. Insufficient RAM.", file=sys.stderr)
# Clean up potentially partially written output file?
try: os.remove(output_filename)
except OSError: pass
return False
except EOFError as e:
print(
f"Error: Reached end of file unexpectedly reading {input_filename}. {e}",
file=sys.stderr,
)
try:
os.remove(output_filename)
except OSError:
pass
print(f"Error: Reached end of file unexpectedly reading {input_filename}. {e}", file=sys.stderr)
try: os.remove(output_filename)
except OSError: pass
return False
except Exception as e:
print(f"An unexpected error occurred during conversion: {e}", file=sys.stderr)
import traceback
traceback.print_exc()
try:
os.remove(output_filename)
except OSError:
pass
except OSError: pass
return False
# Ensure neighbors_np is deleted even if an error occurs after its allocation
finally:
try:
if "neighbors_np" in locals() and neighbors_np is not None:
del neighbors_np
gc.collect()
except NameError:
pass
if 'neighbors_np' in locals() and neighbors_np is not None:
del neighbors_np
gc.collect()
# --- Script Execution ---
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Convert a Faiss IndexHNSWFlat file to a CSR-based HNSW graph file."
)
parser = argparse.ArgumentParser(description="Convert a Faiss IndexHNSWFlat file to a CSR-based HNSW graph file.")
parser.add_argument("input_index_file", help="Path to the input IndexHNSWFlat file")
parser.add_argument(
"output_csr_graph_file", help="Path to write the output CSR HNSW graph file"
)
parser.add_argument(
"--prune-embeddings",
action="store_true",
default=True,
help="Prune embedding storage (write NULL storage marker)",
)
parser.add_argument(
"--keep-embeddings",
action="store_true",
help="Keep embedding storage (overrides --prune-embeddings)",
)
parser.add_argument("output_csr_graph_file", help="Path to write the output CSR HNSW graph file")
parser.add_argument("--prune-embeddings", action="store_true", default=True,
help="Prune embedding storage (write NULL storage marker)")
parser.add_argument("--keep-embeddings", action="store_true",
help="Keep embedding storage (overrides --prune-embeddings)")
args = parser.parse_args()
@@ -728,12 +545,10 @@ if __name__ == "__main__":
sys.exit(1)
if os.path.abspath(args.input_index_file) == os.path.abspath(args.output_csr_graph_file):
print("Error: Input and output filenames cannot be the same.", file=sys.stderr)
sys.exit(1)
print(f"Error: Input and output filenames cannot be the same.", file=sys.stderr)
sys.exit(1)
prune_embeddings = args.prune_embeddings and not args.keep_embeddings
success = convert_hnsw_graph_to_csr(
args.input_index_file, args.output_csr_graph_file, prune_embeddings
)
success = convert_hnsw_graph_to_csr(args.input_index_file, args.output_csr_graph_file, prune_embeddings)
if not success:
sys.exit(1)
sys.exit(1)

View File

@@ -1,20 +1,19 @@
import logging
import os
import shutil
import time
from pathlib import Path
from typing import Any, Literal, Optional
import numpy as np
import os
from pathlib import Path
from typing import Dict, Any, List, Literal, Optional
import shutil
import logging
from leann.searcher_base import BaseSearcher
from .convert_to_csr import convert_hnsw_graph_to_csr
from leann.registry import register_backend
from leann.interface import (
LeannBackendBuilderInterface,
LeannBackendFactoryInterface,
LeannBackendBuilderInterface,
LeannBackendSearcherInterface,
)
from leann.registry import register_backend
from leann.searcher_base import BaseSearcher
from .convert_to_csr import convert_hnsw_graph_to_csr
logger = logging.getLogger(__name__)
@@ -29,12 +28,6 @@ def get_metric_map():
}
def normalize_l2(data: np.ndarray) -> np.ndarray:
norms = np.linalg.norm(data, axis=1, keepdims=True)
norms[norms == 0] = 1 # Avoid division by zero
return data / norms
@register_backend("hnsw")
class HNSWBackend(LeannBackendFactoryInterface):
@staticmethod
@@ -55,15 +48,8 @@ class HNSWBuilder(LeannBackendBuilderInterface):
self.efConstruction = self.build_params.setdefault("efConstruction", 200)
self.distance_metric = self.build_params.setdefault("distance_metric", "mips")
self.dimensions = self.build_params.get("dimensions")
if not self.is_recompute and self.is_compact:
# Auto-correct: non-recompute requires non-compact storage for HNSW
logger.warning(
"is_recompute=False requires non-compact HNSW. Forcing is_compact=False."
)
self.is_compact = False
self.build_params["is_compact"] = False
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
from . import faiss # type: ignore
path = Path(index_path)
@@ -84,7 +70,7 @@ class HNSWBuilder(LeannBackendBuilderInterface):
index.hnsw.efConstruction = self.efConstruction
if self.distance_metric.lower() == "cosine":
data = normalize_l2(data)
faiss.normalize_L2(data)
index.add(data.shape[0], faiss.swig_ptr(data))
index_file = index_dir / f"{index_prefix}.index"
@@ -109,12 +95,16 @@ class HNSWBuilder(LeannBackendBuilderInterface):
# index_file_old = index_file.with_suffix(".old")
# shutil.move(str(index_file), str(index_file_old))
shutil.move(str(csr_temp_file), str(index_file))
logger.info(f"INFO: Replaced original index with {mode_str} version at '{index_file}'")
logger.info(
f"INFO: Replaced original index with {mode_str} version at '{index_file}'"
)
else:
# Clean up and fail fast
if csr_temp_file.exists():
os.remove(csr_temp_file)
raise RuntimeError("CSR conversion failed - cannot proceed with compact format")
raise RuntimeError(
"CSR conversion failed - cannot proceed with compact format"
)
class HNSWSearcher(BaseSearcher):
@@ -126,9 +116,7 @@ class HNSWSearcher(BaseSearcher):
)
from . import faiss # type: ignore
self.distance_metric = (
self.meta.get("backend_kwargs", {}).get("distance_metric", "mips").lower()
)
self.distance_metric = self.meta.get("distance_metric", "mips").lower()
metric_enum = get_metric_map().get(self.distance_metric)
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
@@ -162,7 +150,7 @@ class HNSWSearcher(BaseSearcher):
pruning_strategy: Literal["global", "local", "proportional"] = "global",
batch_size: int = 0,
**kwargs,
) -> dict[str, Any]:
) -> Dict[str, Any]:
"""
Search for nearest neighbors using HNSW index.
@@ -186,36 +174,28 @@ class HNSWSearcher(BaseSearcher):
"""
from . import faiss # type: ignore
if not recompute_embeddings and self.is_pruned:
raise RuntimeError(
"Recompute is required for pruned/compact HNSW index. "
"Re-run search with --recompute, or rebuild with --no-recompute and --no-compact."
)
if not recompute_embeddings:
if self.is_pruned:
raise RuntimeError("Recompute is required for pruned index.")
if recompute_embeddings:
if zmq_port is None:
raise ValueError("zmq_port must be provided if recompute_embeddings is True")
raise ValueError(
"zmq_port must be provided if recompute_embeddings is True"
)
if query.dtype != np.float32:
query = query.astype(np.float32)
if self.distance_metric == "cosine":
query = normalize_l2(query)
faiss.normalize_L2(query)
params = faiss.SearchParametersHNSW()
if zmq_port is not None:
params.zmq_port = zmq_port # C++ code won't use this if recompute_embeddings is False
params.zmq_port = (
zmq_port # C++ code won't use this if recompute_embeddings is False
)
params.efSearch = complexity
params.beam_size = beam_width
# For OpenAI embeddings with cosine distance, disable relative distance check
# This prevents early termination when all scores are in a narrow range
embedding_model = self.meta.get("embedding_model", "").lower()
if self.distance_metric == "cosine" and any(
openai_model in embedding_model for openai_model in ["text-embedding", "openai"]
):
params.check_relative_distance = False
else:
params.check_relative_distance = True
# PQ pruning: direct mapping to HNSW's pq_pruning_ratio
params.pq_pruning_ratio = prune_ratio
@@ -225,7 +205,9 @@ class HNSWSearcher(BaseSearcher):
params.send_neigh_times_ratio = 0.0
elif pruning_strategy == "proportional":
params.local_prune = False
params.send_neigh_times_ratio = 1.0 # Any value > 1e-6 triggers proportional mode
params.send_neigh_times_ratio = (
1.0 # Any value > 1e-6 triggers proportional mode
)
else: # "global"
params.local_prune = False
params.send_neigh_times_ratio = 0.0
@@ -237,7 +219,6 @@ class HNSWSearcher(BaseSearcher):
distances = np.empty((batch_size_query, top_k), dtype=np.float32)
labels = np.empty((batch_size_query, top_k), dtype=np.int64)
search_time = time.time()
self._index.search(
query.shape[0],
faiss.swig_ptr(query),
@@ -246,8 +227,9 @@ class HNSWSearcher(BaseSearcher):
faiss.swig_ptr(labels),
params,
)
search_time = time.time() - search_time
logger.info(f" Search time in HNSWSearcher.search() backend: {search_time} seconds")
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
string_labels = [
[str(int_label) for int_label in batch_labels] for batch_labels in labels
]
return {"labels": string_labels, "distances": distances}

View File

@@ -3,18 +3,17 @@ HNSW-specific embedding server
"""
import argparse
import json
import logging
import os
import sys
import threading
import time
import os
import zmq
import numpy as np
import msgpack
import json
from pathlib import Path
from typing import Optional
import msgpack
import numpy as np
import zmq
import sys
import logging
# Set up logging based on environment variable
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
@@ -53,8 +52,8 @@ def create_hnsw_embedding_server(
sys.path.insert(0, str(leann_core_path))
try:
from leann.api import PassageManager
from leann.embedding_compute import compute_embeddings
from leann.api import PassageManager
logger.info("Successfully imported unified embedding computation module")
except ImportError as e:
@@ -79,319 +78,207 @@ def create_hnsw_embedding_server(
raise ValueError("Only metadata files (.meta.json) are supported")
# Load metadata to get passage sources
with open(passages_file) as f:
with open(passages_file, "r") as f:
meta = json.load(f)
# Let PassageManager handle path resolution uniformly. It supports fallback order:
# 1) path/index_path; 2) *_relative; 3) standard siblings next to meta
passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file)
# Dimension from metadata for shaping responses
try:
embedding_dim: int = int(meta.get("dimensions", 0))
except Exception:
embedding_dim = 0
logger.info(f"Loaded PassageManager with {len(passages)} passages from metadata")
# (legacy ZMQ thread removed; using shutdown-capable server only)
def zmq_server_thread_with_shutdown(shutdown_event):
"""ZMQ server thread that respects shutdown signal.
Creates its own REP socket bound to zmq_port and polls with timeouts
to allow graceful shutdown.
"""
logger.info("ZMQ server thread started with shutdown support")
passages = PassageManager(meta["passage_sources"])
logger.info(
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
)
def zmq_server_thread():
"""ZMQ server thread"""
context = zmq.Context()
rep_socket = context.socket(zmq.REP)
rep_socket.bind(f"tcp://*:{zmq_port}")
logger.info(f"HNSW ZMQ REP server listening on port {zmq_port}")
rep_socket.setsockopt(zmq.RCVTIMEO, 1000)
# Keep sends from blocking during shutdown; fail fast and drop on close
rep_socket.setsockopt(zmq.SNDTIMEO, 1000)
rep_socket.setsockopt(zmq.LINGER, 0)
socket = context.socket(zmq.REP)
socket.bind(f"tcp://*:{zmq_port}")
logger.info(f"HNSW ZMQ server listening on port {zmq_port}")
# Track last request type/length for shape-correct fallbacks
last_request_type = "unknown" # 'text' | 'distance' | 'embedding' | 'unknown'
last_request_length = 0
socket.setsockopt(zmq.RCVTIMEO, 300000)
socket.setsockopt(zmq.SNDTIMEO, 300000)
try:
while not shutdown_event.is_set():
try:
e2e_start = time.time()
logger.debug("🔍 Waiting for ZMQ message...")
request_bytes = rep_socket.recv()
while True:
try:
message_bytes = socket.recv()
logger.debug(f"Received ZMQ request of size {len(message_bytes)} bytes")
# Rest of the processing logic (same as original)
request = msgpack.unpackb(request_bytes)
e2e_start = time.time()
request_payload = msgpack.unpackb(message_bytes)
if len(request) == 1 and request[0] == "__QUERY_MODEL__":
response_bytes = msgpack.packb([model_name])
rep_socket.send(response_bytes)
continue
# Handle direct text embedding request
if isinstance(request_payload, list) and len(request_payload) > 0:
# Check if this is a direct text request (list of strings)
if all(isinstance(item, str) for item in request_payload):
logger.info(
f"Processing direct text embedding request for {len(request_payload)} texts in {embedding_mode} mode"
)
# Handle direct text embedding request
if (
isinstance(request, list)
and request
and all(isinstance(item, str) for item in request)
):
last_request_type = "text"
last_request_length = len(request)
embeddings = compute_embeddings(request, model_name, mode=embedding_mode)
rep_socket.send(msgpack.packb(embeddings.tolist()))
# Use unified embedding computation (now with model caching)
embeddings = compute_embeddings(
request_payload, model_name, mode=embedding_mode
)
response = embeddings.tolist()
socket.send(msgpack.packb(response))
e2e_end = time.time()
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
logger.info(
f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s"
)
continue
# Handle distance calculation request: [[ids], [query_vector]]
if (
isinstance(request, list)
and len(request) == 2
and isinstance(request[0], list)
and isinstance(request[1], list)
):
node_ids = request[0]
# Handle nested [[ids]] shape defensively
if len(node_ids) == 1 and isinstance(node_ids[0], list):
node_ids = node_ids[0]
query_vector = np.array(request[1], dtype=np.float32)
last_request_type = "distance"
last_request_length = len(node_ids)
# Handle distance calculation requests
if (
isinstance(request_payload, list)
and len(request_payload) == 2
and isinstance(request_payload[0], list)
and isinstance(request_payload[1], list)
):
node_ids = request_payload[0]
query_vector = np.array(request_payload[1], dtype=np.float32)
logger.debug("Distance calculation request received")
logger.debug(f" Node IDs: {node_ids}")
logger.debug(f" Query vector dim: {len(query_vector)}")
logger.debug("Distance calculation request received")
logger.debug(f" Node IDs: {node_ids}")
logger.debug(f" Query vector dim: {len(query_vector)}")
# Gather texts for found ids
texts: list[str] = []
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
try:
passage_data = passages.get_passage(str(nid))
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {nid}")
except KeyError:
logger.error(f"Passage ID {nid} not found")
except Exception as e:
logger.error(f"Exception looking up passage ID {nid}: {e}")
# Prepare full-length response with large sentinel values
large_distance = 1e9
response_distances = [large_distance] * len(node_ids)
if texts:
try:
embeddings = compute_embeddings(
texts, model_name, mode=embedding_mode
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
if distance_metric == "l2":
partial = np.sum(
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
)
else: # mips or cosine
partial = -np.dot(embeddings, query_vector)
for pos, dval in zip(found_indices, partial.flatten().tolist()):
response_distances[pos] = float(dval)
except Exception as e:
logger.error(f"Distance computation error, using sentinels: {e}")
# Send response in expected shape [[distances]]
rep_socket.send(msgpack.packb([response_distances], use_single_float=True))
e2e_end = time.time()
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
continue
# Fallback: treat as embedding-by-id request
if (
isinstance(request, list)
and len(request) == 1
and isinstance(request[0], list)
):
node_ids = request[0]
elif isinstance(request, list):
node_ids = request
else:
node_ids = []
last_request_type = "embedding"
last_request_length = len(node_ids)
logger.info(f"ZMQ received {len(node_ids)} node IDs for embedding fetch")
# Preallocate zero-filled flat data for robustness
if embedding_dim <= 0:
dims = [0, 0]
flat_data: list[float] = []
else:
dims = [len(node_ids), embedding_dim]
flat_data = [0.0] * (dims[0] * dims[1])
# Collect texts for found ids
texts: list[str] = []
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
# Get embeddings for node IDs
texts = []
for nid in node_ids:
try:
passage_data = passages.get_passage(str(nid))
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {nid}")
txt = passage_data["text"]
texts.append(txt)
except KeyError:
logger.error(f"Passage with ID {nid} not found")
logger.error(f"Passage ID {nid} not found")
raise RuntimeError(
f"FATAL: Passage with ID {nid} not found"
)
except Exception as e:
logger.error(f"Exception looking up passage ID {nid}: {e}")
raise
if texts:
try:
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
# Process embeddings
embeddings = compute_embeddings(
texts, model_name, mode=embedding_mode
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
logger.error(
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
)
dims = [0, embedding_dim]
flat_data = []
else:
emb_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
flat = emb_f32.flatten().tolist()
for j, pos in enumerate(found_indices):
start = pos * embedding_dim
end = start + embedding_dim
if end <= len(flat_data):
flat_data[start:end] = flat[
j * embedding_dim : (j + 1) * embedding_dim
]
except Exception as e:
logger.error(f"Embedding computation error, returning zeros: {e}")
# Calculate distances
if distance_metric == "l2":
distances = np.sum(
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
)
else: # mips or cosine
distances = -np.dot(embeddings, query_vector)
response_payload = [dims, flat_data]
response_bytes = msgpack.packb(response_payload, use_single_float=True)
response_payload = distances.flatten().tolist()
response_bytes = msgpack.packb(
[response_payload], use_single_float=True
)
logger.debug(
f"Sending distance response with {len(distances)} distances"
)
rep_socket.send(response_bytes)
socket.send(response_bytes)
e2e_end = time.time()
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
except zmq.Again:
# Timeout - check shutdown_event and continue
logger.info(
f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s"
)
continue
except Exception as e:
if not shutdown_event.is_set():
logger.error(f"Error in ZMQ server loop: {e}")
# Shape-correct fallback
try:
if last_request_type == "distance":
large_distance = 1e9
fallback_len = max(0, int(last_request_length))
safe = [[large_distance] * fallback_len]
elif last_request_type == "embedding":
bsz = max(0, int(last_request_length))
dim = max(0, int(embedding_dim))
safe = (
[[bsz, dim], [0.0] * (bsz * dim)] if dim > 0 else [[0, 0], []]
)
elif last_request_type == "text":
safe = [] # direct text embeddings expectation is a flat list
else:
safe = [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
rep_socket.send(msgpack.packb(safe, use_single_float=True))
except Exception:
pass
else:
logger.info("Shutdown in progress, ignoring ZMQ error")
break
finally:
try:
rep_socket.close(0)
except Exception:
pass
try:
context.term()
except Exception:
pass
logger.info("ZMQ server thread exiting gracefully")
# Standard embedding request (passage ID lookup)
if (
not isinstance(request_payload, list)
or len(request_payload) != 1
or not isinstance(request_payload[0], list)
):
logger.error(
f"Invalid MessagePack request format. Expected [[ids...]] or [texts...], got: {type(request_payload)}"
)
socket.send(msgpack.packb([[], []]))
continue
# Add shutdown coordination
shutdown_event = threading.Event()
node_ids = request_payload[0]
logger.debug(f"Request for {len(node_ids)} node embeddings")
def shutdown_zmq_server():
"""Gracefully shutdown ZMQ server."""
logger.info("Initiating graceful shutdown...")
shutdown_event.set()
# Look up texts by node IDs
texts = []
for nid in node_ids:
try:
passage_data = passages.get_passage(str(nid))
txt = passage_data["text"]
if not txt:
raise RuntimeError(
f"FATAL: Empty text for passage ID {nid}"
)
texts.append(txt)
except KeyError:
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
except Exception as e:
logger.error(f"Exception looking up passage ID {nid}: {e}")
raise
if zmq_thread.is_alive():
logger.info("Waiting for ZMQ thread to finish...")
zmq_thread.join(timeout=5)
if zmq_thread.is_alive():
logger.warning("ZMQ thread did not finish in time")
# Process embeddings
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
# Clean up ZMQ resources
try:
# Note: socket and context are cleaned up by thread exit
logger.info("ZMQ resources cleaned up")
except Exception as e:
logger.warning(f"Error cleaning ZMQ resources: {e}")
# Serialization and response
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
logger.error(
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
)
assert False
# Clean up other resources
try:
import gc
hidden_contiguous_f32 = np.ascontiguousarray(
embeddings, dtype=np.float32
)
response_payload = [
list(hidden_contiguous_f32.shape),
hidden_contiguous_f32.flatten().tolist(),
]
response_bytes = msgpack.packb(response_payload, use_single_float=True)
gc.collect()
logger.info("Additional resources cleaned up")
except Exception as e:
logger.warning(f"Error cleaning additional resources: {e}")
socket.send(response_bytes)
e2e_end = time.time()
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
logger.info("Graceful shutdown completed")
sys.exit(0)
except zmq.Again:
logger.debug("ZMQ socket timeout, continuing to listen")
continue
except Exception as e:
logger.error(f"Error in ZMQ server loop: {e}")
import traceback
# Register signal handlers within this function scope
import signal
traceback.print_exc()
socket.send(msgpack.packb([[], []]))
def signal_handler(sig, frame):
logger.info(f"Received signal {sig}, shutting down gracefully...")
shutdown_zmq_server()
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
# Pass shutdown_event to ZMQ thread
zmq_thread = threading.Thread(
target=lambda: zmq_server_thread_with_shutdown(shutdown_event),
daemon=False, # Not daemon - we want to wait for it
)
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
zmq_thread.start()
logger.info(f"Started HNSW ZMQ server thread on port {zmq_port}")
# Keep the main thread alive
try:
while not shutdown_event.is_set():
time.sleep(0.1) # Check shutdown more frequently
while True:
time.sleep(1)
except KeyboardInterrupt:
logger.info("HNSW Server shutting down...")
shutdown_zmq_server()
return
# If we reach here, shutdown was triggered by signal
logger.info("Main loop exited, process should be shutting down")
if __name__ == "__main__":
import signal
import sys
# Signal handlers are now registered within create_hnsw_embedding_server
def signal_handler(sig, frame):
logger.info(f"Received signal {sig}, shutting down gracefully...")
sys.exit(0)
# Register signal handlers for graceful shutdown
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
parser = argparse.ArgumentParser(description="HNSW Embedding service")
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
parser.add_argument(
@@ -412,7 +299,7 @@ if __name__ == "__main__":
"--embedding-mode",
type=str,
default="sentence-transformers",
choices=["sentence-transformers", "openai", "mlx", "ollama"],
choices=["sentence-transformers", "openai", "mlx"],
help="Embedding backend mode",
)

View File

@@ -6,24 +6,24 @@ build-backend = "scikit_build_core.build"
[project]
name = "leann-backend-hnsw"
version = "0.3.3"
version = "0.1.8"
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
dependencies = [
"leann-core==0.3.3",
"leann-core==0.1.8",
"numpy",
"pyzmq>=23.0.0",
"msgpack>=1.0.0",
]
[tool.scikit-build]
wheel.packages = ["leann_backend_hnsw"]
editable.mode = "redirect"
cmake.build-type = "Release"
build.verbose = true
build.tool-args = ["-j8"]
wheel.exclude = ["CMakeLists.txt", "src", "third_party"]
sdist.include = ["CMakeLists.txt", "src", "third_party", "leann_backend_hnsw/*.txt"]
cmake.args = ["-DCMAKE_BUILD_TYPE=Release"]
# Ensure CMake can find system libraries
build-dir = "build/{cache_tag}"
minimum-version = "build-system.requires"
# CMake definitions to optimize compilation and find Homebrew packages
# CMake definitions to optimize compilation
[tool.scikit-build.cmake.define]
CMAKE_BUILD_PARALLEL_LEVEL = "8"
CMAKE_PREFIX_PATH = {env = "CMAKE_PREFIX_PATH"}
OpenMP_ROOT = {env = "OpenMP_ROOT"}
SKBUILD_SOABI = "YES"

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "leann-core"
version = "0.3.3"
version = "0.1.8"
description = "Core API and plugin system for LEANN"
readme = "README.md"
requires-python = ">=3.9"
@@ -15,38 +15,16 @@ dependencies = [
"numpy>=1.20.0",
"tqdm>=4.60.0",
"psutil>=5.8.0",
"pyzmq>=23.0.0",
"pyzmq>=23.0.0,<27", # Cap at 26.x for manylinux2014 compatibility
"msgpack>=1.0.0",
"torch>=2.0.0",
"sentence-transformers>=2.2.0",
"llama-index-core>=0.12.0",
"llama-index-readers-file>=0.4.0", # Essential for document reading
"llama-index-embeddings-huggingface>=0.5.5", # For embeddings
"python-dotenv>=1.0.0",
"openai>=1.0.0",
"huggingface-hub>=0.20.0",
"transformers>=4.30.0",
"requests>=2.25.0",
"accelerate>=0.20.0",
"PyPDF2>=3.0.0",
"pymupdf>=1.23.0",
"pdfplumber>=0.10.0",
"nbconvert>=7.0.0", # For .ipynb file support
"gitignore-parser>=0.1.12", # For proper .gitignore handling
"mlx>=0.26.3; sys_platform == 'darwin' and platform_machine == 'arm64'",
"mlx-lm>=0.26.0; sys_platform == 'darwin' and platform_machine == 'arm64'",
]
[project.optional-dependencies]
colab = [
"torch>=2.0.0,<3.0.0", # Limit torch version to avoid conflicts
"transformers>=4.30.0,<5.0.0", # Limit transformers version
"accelerate>=0.20.0,<1.0.0", # Limit accelerate version
]
[project.scripts]
leann = "leann.cli:main"
leann_mcp = "leann.mcp:main"
[tool.setuptools.packages.find]
where = ["src"]
where = ["src"]

View File

@@ -8,14 +8,10 @@ if platform.system() == "Darwin":
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
os.environ["KMP_BLOCKTIME"] = "0"
# Additional fixes for PyTorch/sentence-transformers on macOS ARM64 only in CI
if os.environ.get("CI") == "true":
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "0"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from .api import LeannBuilder, LeannChat, LeannSearcher
from .registry import BACKEND_REGISTRY, autodiscover_backends
autodiscover_backends()
__all__ = ["BACKEND_REGISTRY", "LeannBuilder", "LeannChat", "LeannSearcher"]
__all__ = ["LeannBuilder", "LeannSearcher", "LeannChat", "BACKEND_REGISTRY"]

View File

@@ -4,35 +4,23 @@ with the correct, original embedding logic from the user's reference code.
"""
import json
import logging
import pickle
import re
import subprocess
import time
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal, Optional, Union
import numpy as np
from leann.interface import LeannBackendSearcherInterface
from .chat import get_llm
from .interface import LeannBackendFactoryInterface
from .metadata_filter import MetadataFilterEngine
import numpy as np
import time
from pathlib import Path
from typing import List, Dict, Any, Optional, Literal
from dataclasses import dataclass, field
from .registry import BACKEND_REGISTRY
from .interface import LeannBackendFactoryInterface
from .chat import get_llm
import logging
logger = logging.getLogger(__name__)
def get_registered_backends() -> list[str]:
"""Get list of registered backend names."""
return list(BACKEND_REGISTRY.keys())
def compute_embeddings(
chunks: list[str],
chunks: List[str],
model_name: str,
mode: str = "sentence-transformers",
use_server: bool = True,
@@ -49,7 +37,6 @@ def compute_embeddings(
- "sentence-transformers": Use sentence-transformers library (default)
- "mlx": Use MLX backend for Apple Silicon
- "openai": Use OpenAI embedding API
- "gemini": Use Google Gemini embedding API
use_server: Whether to use embedding server (True for search, False for build)
Returns:
@@ -74,7 +61,9 @@ def compute_embeddings(
)
def compute_embeddings_via_server(chunks: list[str], model_name: str, port: int) -> np.ndarray:
def compute_embeddings_via_server(
chunks: List[str], model_name: str, port: int
) -> np.ndarray:
"""Computes embeddings using sentence-transformers.
Args:
@@ -84,9 +73,9 @@ def compute_embeddings_via_server(chunks: list[str], model_name: str, port: int)
logger.info(
f"Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
)
import zmq
import msgpack
import numpy as np
import zmq
# Connect to embedding server
context = zmq.Context()
@@ -115,160 +104,39 @@ class SearchResult:
id: str
score: float
text: str
metadata: dict[str, Any] = field(default_factory=dict)
metadata: Dict[str, Any] = field(default_factory=dict)
class PassageManager:
def __init__(
self, passage_sources: list[dict[str, Any]], metadata_file_path: Optional[str] = None
):
self.offset_maps: dict[str, dict[str, int]] = {}
self.passage_files: dict[str, str] = {}
# Avoid materializing a single gigantic global map to reduce memory
# footprint on very large corpora (e.g., 60M+ passages). Instead, keep
# per-shard maps and do a lightweight per-shard lookup on demand.
self._total_count: int = 0
self.filter_engine = MetadataFilterEngine() # Initialize filter engine
# Derive index base name for standard sibling fallbacks, e.g., <index_name>.passages.*
index_name_base = None
if metadata_file_path:
meta_name = Path(metadata_file_path).name
if meta_name.endswith(".meta.json"):
index_name_base = meta_name[: -len(".meta.json")]
def __init__(self, passage_sources: List[Dict[str, Any]]):
self.offset_maps = {}
self.passage_files = {}
self.global_offset_map = {} # Combined map for fast lookup
for source in passage_sources:
assert source["type"] == "jsonl", "only jsonl is supported"
passage_file = source.get("path", "")
index_file = source.get("index_path", "") # .idx file
# Fix path resolution - relative paths should be relative to metadata file directory
def _resolve_candidates(
primary: str,
relative_key: str,
default_name: Optional[str],
source_dict: dict[str, Any],
) -> list[Path]:
"""
Build an ordered list of candidate paths. For relative paths specified in
metadata, prefer resolution relative to the metadata file directory first,
then fall back to CWD-based resolution, and finally to conventional
sibling defaults (e.g., <index_base>.passages.idx / .jsonl).
"""
candidates: list[Path] = []
# 1) Primary path
if primary:
p = Path(primary)
if p.is_absolute():
candidates.append(p)
else:
# Prefer metadata-relative resolution for relative paths
if metadata_file_path:
candidates.append(Path(metadata_file_path).parent / p)
# Also consider CWD-relative as a fallback for legacy layouts
candidates.append(Path.cwd() / p)
# 2) metadata-relative explicit relative key (if present)
if metadata_file_path and source_dict.get(relative_key):
candidates.append(Path(metadata_file_path).parent / source_dict[relative_key])
# 3) metadata-relative standard sibling filename
if metadata_file_path and default_name:
candidates.append(Path(metadata_file_path).parent / default_name)
return candidates
# Build candidate lists and pick first existing; otherwise keep last candidate for error message
idx_default = f"{index_name_base}.passages.idx" if index_name_base else None
idx_candidates = _resolve_candidates(
index_file, "index_path_relative", idx_default, source
)
pas_default = f"{index_name_base}.passages.jsonl" if index_name_base else None
pas_candidates = _resolve_candidates(passage_file, "path_relative", pas_default, source)
def _pick_existing(cands: list[Path]) -> str:
for c in cands:
if c.exists():
return str(c.resolve())
# Fallback to last candidate (best guess) even if not exists; will error below
return str(cands[-1].resolve()) if cands else ""
index_file = _pick_existing(idx_candidates)
passage_file = _pick_existing(pas_candidates)
passage_file = source["path"]
index_file = source["index_path"] # .idx file
if not Path(index_file).exists():
raise FileNotFoundError(f"Passage index file not found: {index_file}")
with open(index_file, "rb") as f:
offset_map: dict[str, int] = pickle.load(f)
offset_map = pickle.load(f)
self.offset_maps[passage_file] = offset_map
self.passage_files[passage_file] = passage_file
self._total_count += len(offset_map)
def get_passage(self, passage_id: str) -> dict[str, Any]:
# Fast path: check each shard map (there are typically few shards).
# This avoids building a massive combined dict while keeping lookups
# bounded by the number of shards.
for passage_file, offset_map in self.offset_maps.items():
try:
offset = offset_map[passage_id]
with open(passage_file, encoding="utf-8") as f:
f.seek(offset)
return json.loads(f.readline())
except KeyError:
continue
# Build global map for O(1) lookup
for passage_id, offset in offset_map.items():
self.global_offset_map[passage_id] = (passage_file, offset)
def get_passage(self, passage_id: str) -> Dict[str, Any]:
if passage_id in self.global_offset_map:
passage_file, offset = self.global_offset_map[passage_id]
# Lazy file opening - only open when needed
with open(passage_file, "r", encoding="utf-8") as f:
f.seek(offset)
return json.loads(f.readline())
raise KeyError(f"Passage ID not found: {passage_id}")
def filter_search_results(
self,
search_results: list[SearchResult],
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]],
) -> list[SearchResult]:
"""
Apply metadata filters to search results.
Args:
search_results: List of SearchResult objects
metadata_filters: Filter specifications to apply
Returns:
Filtered list of SearchResult objects
"""
if not metadata_filters:
return search_results
logger.debug(f"Applying metadata filters to {len(search_results)} results")
# Convert SearchResult objects to dictionaries for the filter engine
result_dicts = []
for result in search_results:
result_dicts.append(
{
"id": result.id,
"score": result.score,
"text": result.text,
"metadata": result.metadata,
}
)
# Apply filters using the filter engine
filtered_dicts = self.filter_engine.apply_filters(result_dicts, metadata_filters)
# Convert back to SearchResult objects
filtered_results = []
for result_dict in filtered_dicts:
filtered_results.append(
SearchResult(
id=result_dict["id"],
score=result_dict["score"],
text=result_dict["text"],
metadata=result_dict["metadata"],
)
)
logger.debug(f"Filtered results: {len(filtered_results)} remaining")
return filtered_results
def __len__(self) -> int:
return self._total_count
class LeannBuilder:
def __init__(
@@ -280,99 +148,19 @@ class LeannBuilder:
**backend_kwargs,
):
self.backend_name = backend_name
# Normalize incompatible combinations early (for consistent metadata)
if backend_name == "hnsw":
is_recompute = backend_kwargs.get("is_recompute", True)
is_compact = backend_kwargs.get("is_compact", True)
if is_recompute is False and is_compact is True:
warnings.warn(
"HNSW with is_recompute=False requires non-compact storage. Forcing is_compact=False.",
UserWarning,
stacklevel=2,
)
backend_kwargs["is_compact"] = False
backend_factory: Optional[LeannBackendFactoryInterface] = BACKEND_REGISTRY.get(backend_name)
backend_factory: LeannBackendFactoryInterface | None = BACKEND_REGISTRY.get(
backend_name
)
if backend_factory is None:
raise ValueError(f"Backend '{backend_name}' not found or not registered.")
self.backend_factory = backend_factory
self.embedding_model = embedding_model
self.dimensions = dimensions
self.embedding_mode = embedding_mode
# Check if we need to use cosine distance for normalized embeddings
normalized_embeddings_models = {
# OpenAI models
("openai", "text-embedding-ada-002"),
("openai", "text-embedding-3-small"),
("openai", "text-embedding-3-large"),
# Voyage AI models
("voyage", "voyage-2"),
("voyage", "voyage-3"),
("voyage", "voyage-large-2"),
("voyage", "voyage-multilingual-2"),
("voyage", "voyage-code-2"),
# Cohere models
("cohere", "embed-english-v3.0"),
("cohere", "embed-multilingual-v3.0"),
("cohere", "embed-english-light-v3.0"),
("cohere", "embed-multilingual-light-v3.0"),
}
# Also check for patterns in model names
is_normalized = False
current_model_lower = embedding_model.lower()
current_mode_lower = embedding_mode.lower()
# Check exact matches
for mode, model in normalized_embeddings_models:
if (current_mode_lower == mode and current_model_lower == model) or (
mode in current_mode_lower and model in current_model_lower
):
is_normalized = True
break
# Check patterns
if not is_normalized:
# OpenAI patterns
if "openai" in current_mode_lower or "openai" in current_model_lower:
if any(
pattern in current_model_lower
for pattern in ["text-embedding", "ada", "3-small", "3-large"]
):
is_normalized = True
# Voyage patterns
elif "voyage" in current_mode_lower or "voyage" in current_model_lower:
is_normalized = True
# Cohere patterns
elif "cohere" in current_mode_lower or "cohere" in current_model_lower:
if "embed" in current_model_lower:
is_normalized = True
# Handle distance metric
if is_normalized and "distance_metric" not in backend_kwargs:
backend_kwargs["distance_metric"] = "cosine"
warnings.warn(
f"Detected normalized embeddings model '{embedding_model}' with mode '{embedding_mode}'. "
f"Automatically setting distance_metric='cosine' for optimal performance. "
f"Normalized embeddings (L2 norm = 1) should use cosine similarity instead of MIPS.",
UserWarning,
stacklevel=2,
)
elif is_normalized and backend_kwargs.get("distance_metric", "").lower() != "cosine":
current_metric = backend_kwargs.get("distance_metric", "mips")
warnings.warn(
f"Warning: Using '{current_metric}' distance metric with normalized embeddings model "
f"'{embedding_model}' may lead to suboptimal search results. "
f"Consider using 'cosine' distance metric for better performance.",
UserWarning,
stacklevel=2,
)
self.backend_kwargs = backend_kwargs
self.chunks: list[dict[str, Any]] = []
self.chunks: List[Dict[str, Any]] = []
def add_text(self, text: str, metadata: Optional[dict[str, Any]] = None):
def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
if metadata is None:
metadata = {}
passage_id = metadata.get("id", str(len(self.chunks)))
@@ -382,23 +170,6 @@ class LeannBuilder:
def build_index(self, index_path: str):
if not self.chunks:
raise ValueError("No chunks added.")
# Filter out invalid/empty text chunks early to keep passage and embedding counts aligned
valid_chunks: list[dict[str, Any]] = []
skipped = 0
for chunk in self.chunks:
text = chunk.get("text", "")
if isinstance(text, str) and text.strip():
valid_chunks.append(chunk)
else:
skipped += 1
if skipped > 0:
print(
f"Warning: Skipping {skipped} empty/invalid text chunk(s). Processing {len(valid_chunks)} valid chunks"
)
self.chunks = valid_chunks
if not self.chunks:
raise ValueError("All provided chunks are empty or invalid. Nothing to index.")
if self.dimensions is None:
self.dimensions = len(
compute_embeddings(
@@ -419,7 +190,9 @@ class LeannBuilder:
try:
from tqdm import tqdm
chunk_iterator = tqdm(self.chunks, desc="Writing passages", unit="chunk")
chunk_iterator = tqdm(
self.chunks, desc="Writing passages", unit="chunk"
)
except ImportError:
chunk_iterator = self.chunks
@@ -449,7 +222,9 @@ class LeannBuilder:
string_ids = [chunk["id"] for chunk in self.chunks]
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
builder_instance.build(embeddings, string_ids, index_path, **current_backend_kwargs)
builder_instance.build(
embeddings, string_ids, index_path, **current_backend_kwargs
)
leann_meta_path = index_dir / f"{index_name}.meta.json"
meta_data = {
"version": "1.0",
@@ -461,12 +236,8 @@ class LeannBuilder:
"passage_sources": [
{
"type": "jsonl",
# Preserve existing relative file names (backward-compatible)
"path": passages_file.name,
"index_path": offset_file.name,
# Add optional redundant relative keys for remote build portability (non-breaking)
"path_relative": passages_file.name,
"index_path_relative": offset_file.name,
"path": str(passages_file),
"index_path": str(offset_file),
}
],
}
@@ -502,7 +273,9 @@ class LeannBuilder:
ids, embeddings = data
if not isinstance(embeddings, np.ndarray):
raise ValueError(f"Expected embeddings to be numpy array, got {type(embeddings)}")
raise ValueError(
f"Expected embeddings to be numpy array, got {type(embeddings)}"
)
if len(ids) != embeddings.shape[0]:
raise ValueError(
@@ -514,7 +287,9 @@ class LeannBuilder:
if self.dimensions is None:
self.dimensions = embedding_dim
elif self.dimensions != embedding_dim:
raise ValueError(f"Dimension mismatch: expected {self.dimensions}, got {embedding_dim}")
raise ValueError(
f"Dimension mismatch: expected {self.dimensions}, got {embedding_dim}"
)
logger.info(
f"Building index from precomputed embeddings: {len(ids)} items, {embedding_dim} dimensions"
@@ -581,12 +356,8 @@ class LeannBuilder:
"passage_sources": [
{
"type": "jsonl",
# Preserve existing relative file names (backward-compatible)
"path": passages_file.name,
"index_path": offset_file.name,
# Add optional redundant relative keys for remote build portability (non-breaking)
"path_relative": passages_file.name,
"index_path_relative": offset_file.name,
"path": str(passages_file),
"index_path": str(offset_file),
}
],
"built_from_precomputed_embeddings": True,
@@ -603,37 +374,27 @@ class LeannBuilder:
with open(leann_meta_path, "w", encoding="utf-8") as f:
json.dump(meta_data, f, indent=2)
logger.info(f"Index built successfully from precomputed embeddings: {index_path}")
logger.info(
f"Index built successfully from precomputed embeddings: {index_path}"
)
class LeannSearcher:
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
# Fix path resolution for Colab and other environments
if not Path(index_path).is_absolute():
index_path = str(Path(index_path).resolve())
self.meta_path_str = f"{index_path}.meta.json"
if not Path(self.meta_path_str).exists():
parent_dir = Path(index_path).parent
print(
f"Leann metadata file not found at {self.meta_path_str}, and you may need to rm -rf {parent_dir}"
)
# highlight in red the filenotfound error
raise FileNotFoundError(
f"Leann metadata file not found at {self.meta_path_str}, \033[91m you may need to rm -rf {parent_dir}\033[0m"
f"Leann metadata file not found at {self.meta_path_str}"
)
with open(self.meta_path_str, encoding="utf-8") as f:
with open(self.meta_path_str, "r", encoding="utf-8") as f:
self.meta_data = json.load(f)
backend_name = self.meta_data["backend_name"]
self.embedding_model = self.meta_data["embedding_model"]
# Support both old and new format
self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
# Delegate portability handling to PassageManager
self.passage_manager = PassageManager(
self.meta_data.get("passage_sources", []), metadata_file_path=self.meta_path_str
self.embedding_mode = self.meta_data.get(
"embedding_mode", "sentence-transformers"
)
# Preserve backend name for conditional parameter forwarding
self.backend_name = backend_name
self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
backend_factory = BACKEND_REGISTRY.get(backend_name)
if backend_factory is None:
raise ValueError(f"Backend '{backend_name}' not found.")
@@ -653,57 +414,13 @@ class LeannSearcher:
recompute_embeddings: bool = True,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
expected_zmq_port: int = 5557,
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
batch_size: int = 0,
use_grep: bool = False,
**kwargs,
) -> list[SearchResult]:
"""
Search for nearest neighbors with optional metadata filtering.
Args:
query: Text query to search for
top_k: Number of nearest neighbors to return
complexity: Search complexity/candidate list size, higher = more accurate but slower
beam_width: Number of parallel search paths/IO requests per iteration
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored codes
pruning_strategy: Candidate selection strategy - "global" (default), "local", or "proportional"
expected_zmq_port: ZMQ port for embedding server communication
metadata_filters: Optional filters to apply to search results based on metadata.
Format: {"field_name": {"operator": value}}
Supported operators:
- Comparison: "==", "!=", "<", "<=", ">", ">="
- Membership: "in", "not_in"
- String: "contains", "starts_with", "ends_with"
Example: {"chapter": {"<=": 5}, "tags": {"in": ["fiction", "drama"]}}
**kwargs: Backend-specific parameters
Returns:
List of SearchResult objects with text, metadata, and similarity scores
"""
# Handle grep search
if use_grep:
return self._grep_search(query, top_k)
) -> List[SearchResult]:
logger.info("🔍 LeannSearcher.search() called:")
logger.info(f" Query: '{query}'")
logger.info(f" Top_k: {top_k}")
logger.info(f" Metadata filters: {metadata_filters}")
logger.info(f" Additional kwargs: {kwargs}")
# Smart top_k detection and adjustment
# Use PassageManager length (sum of shard sizes) to avoid
# depending on a massive combined map
total_docs = len(self.passage_manager)
original_top_k = top_k
if top_k > total_docs:
top_k = total_docs
logger.warning(
f" ⚠️ Requested top_k ({original_top_k}) exceeds total documents ({total_docs})"
)
logger.warning(f" ✅ Auto-adjusted top_k to {top_k} to match available documents")
zmq_port = None
start_time = time.time()
@@ -724,39 +441,31 @@ class LeannSearcher:
use_server_if_available=recompute_embeddings,
zmq_port=zmq_port,
)
logger.info(f" Generated embedding shape: {query_embedding.shape}")
# logger.info(f" Generated embedding shape: {query_embedding.shape}")
embedding_time = time.time() - start_time
logger.info(f" Embedding time: {embedding_time} seconds")
# logger.info(f" Embedding time: {embedding_time} seconds")
start_time = time.time()
backend_search_kwargs: dict[str, Any] = {
"complexity": complexity,
"beam_width": beam_width,
"prune_ratio": prune_ratio,
"recompute_embeddings": recompute_embeddings,
"pruning_strategy": pruning_strategy,
"zmq_port": zmq_port,
}
# Only HNSW supports batching; forward conditionally
if self.backend_name == "hnsw":
backend_search_kwargs["batch_size"] = batch_size
# Merge any extra kwargs last
backend_search_kwargs.update(kwargs)
results = self.backend_impl.search(
query_embedding,
top_k,
**backend_search_kwargs,
complexity=complexity,
beam_width=beam_width,
prune_ratio=prune_ratio,
recompute_embeddings=recompute_embeddings,
pruning_strategy=pruning_strategy,
zmq_port=zmq_port,
**kwargs,
)
search_time = time.time() - start_time
logger.info(f" Search time in search() LEANN searcher: {search_time} seconds")
logger.info(f" Backend returned: labels={len(results.get('labels', [[]])[0])} results")
# logger.info(f" Search time: {search_time} seconds")
logger.info(
f" Backend returned: labels={len(results.get('labels', [[]])[0])} results"
)
enriched_results = []
if "labels" in results and "distances" in results:
logger.info(f" Processing {len(results['labels'][0])} passage IDs:")
# Python 3.9 does not support zip(strict=...); lengths are expected to match
for i, (string_id, dist) in enumerate(
zip(results["labels"][0], results["distances"][0])
):
@@ -770,168 +479,37 @@ class LeannSearcher:
metadata=passage_data.get("metadata", {}),
)
)
# Color codes for better logging
GREEN = "\033[92m"
BLUE = "\033[94m"
YELLOW = "\033[93m"
RESET = "\033[0m"
# Truncate text for display (first 100 chars)
display_text = passage_data["text"]
display_text = passage_data['text']
logger.info(
f" {GREEN}{RESET} {BLUE}[{i + 1:2d}]{RESET} {YELLOW}ID:{RESET} '{string_id}' {YELLOW}Score:{RESET} {dist:.4f} {YELLOW}Text:{RESET} {display_text}"
)
except KeyError:
RED = "\033[91m"
RESET = "\033[0m"
logger.error(
f" {RED}{RESET} [{i + 1:2d}] ID: '{string_id}' -> {RED}ERROR: Passage not found!{RESET}"
)
# Apply metadata filters if specified
if metadata_filters:
logger.info(f" 🔍 Applying metadata filters: {metadata_filters}")
enriched_results = self.passage_manager.filter_search_results(
enriched_results, metadata_filters
)
# Define color codes outside the loop for final message
GREEN = "\033[92m"
RESET = "\033[0m"
logger.info(f" {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}")
return enriched_results
def _find_jsonl_file(self) -> Optional[str]:
"""Find the .jsonl file containing raw passages for grep search"""
index_path = Path(self.meta_path_str).parent
potential_files = [
index_path / "documents.leann.passages.jsonl",
index_path.parent / "documents.leann.passages.jsonl",
]
for file_path in potential_files:
if file_path.exists():
return str(file_path)
return None
def _grep_search(self, query: str, top_k: int = 5) -> list[SearchResult]:
"""Perform grep-based search on raw passages"""
jsonl_file = self._find_jsonl_file()
if not jsonl_file:
raise FileNotFoundError("No .jsonl passages file found for grep search")
try:
cmd = ["grep", "-i", "-n", query, jsonl_file]
result = subprocess.run(cmd, capture_output=True, text=True, check=False)
if result.returncode == 1:
return []
elif result.returncode != 0:
raise RuntimeError(f"Grep failed: {result.stderr}")
matches = []
for line in result.stdout.strip().split("\n"):
if not line:
continue
parts = line.split(":", 1)
if len(parts) != 2:
continue
try:
data = json.loads(parts[1])
text = data.get("text", "")
score = text.lower().count(query.lower())
matches.append(
SearchResult(
id=data.get("id", parts[0]),
text=text,
metadata=data.get("metadata", {}),
score=float(score),
)
)
except json.JSONDecodeError:
continue
matches.sort(key=lambda x: x.score, reverse=True)
return matches[:top_k]
except FileNotFoundError:
raise RuntimeError(
"grep command not found. Please install grep or use semantic search."
)
def _python_regex_search(self, query: str, top_k: int = 5) -> list[SearchResult]:
"""Fallback regex search"""
jsonl_file = self._find_jsonl_file()
if not jsonl_file:
raise FileNotFoundError("No .jsonl file found")
pattern = re.compile(re.escape(query), re.IGNORECASE)
matches = []
with open(jsonl_file, encoding="utf-8") as f:
for line_num, line in enumerate(f, 1):
if pattern.search(line):
try:
data = json.loads(line.strip())
matches.append(
SearchResult(
id=data.get("id", str(line_num)),
text=data.get("text", ""),
metadata=data.get("metadata", {}),
score=float(len(pattern.findall(data.get("text", "")))),
)
)
except json.JSONDecodeError:
continue
matches.sort(key=lambda x: x.score, reverse=True)
return matches[:top_k]
def cleanup(self):
"""Explicitly cleanup embedding server resources.
This method should be called after you're done using the searcher,
especially in test environments or batch processing scenarios.
"""
backend = getattr(self.backend_impl, "embedding_server_manager", None)
if backend is not None:
backend.stop_server()
# Enable automatic cleanup patterns
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
try:
self.cleanup()
except Exception:
pass
def __del__(self):
try:
self.cleanup()
except Exception:
# Avoid noisy errors during interpreter shutdown
pass
class LeannChat:
def __init__(
self,
index_path: str,
llm_config: Optional[dict[str, Any]] = None,
llm_config: Optional[Dict[str, Any]] = None,
enable_warmup: bool = False,
searcher: Optional[LeannSearcher] = None,
**kwargs,
):
if searcher is None:
self.searcher = LeannSearcher(index_path, enable_warmup=enable_warmup, **kwargs)
self._owns_searcher = True
else:
self.searcher = searcher
self._owns_searcher = False
self.searcher = LeannSearcher(index_path, enable_warmup=enable_warmup, **kwargs)
self.llm = get_llm(llm_config)
def ask(
@@ -943,11 +521,8 @@ class LeannChat:
prune_ratio: float = 0.0,
recompute_embeddings: bool = True,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
llm_kwargs: Optional[dict[str, Any]] = None,
llm_kwargs: Optional[Dict[str, Any]] = None,
expected_zmq_port: int = 5557,
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
batch_size: int = 0,
use_grep: bool = False,
**search_kwargs,
):
if llm_kwargs is None:
@@ -962,12 +537,10 @@ class LeannChat:
recompute_embeddings=recompute_embeddings,
pruning_strategy=pruning_strategy,
expected_zmq_port=expected_zmq_port,
metadata_filters=metadata_filters,
batch_size=batch_size,
**search_kwargs,
)
search_time = time.time() - search_time
logger.info(f" Search time: {search_time} seconds")
# logger.info(f" Search time: {search_time} seconds")
context = "\n\n".join([r.text for r in results])
prompt = (
"Here is some retrieved context that might help answer your question:\n\n"
@@ -976,10 +549,7 @@ class LeannChat:
"Please provide the best answer you can based on this context and your knowledge."
)
ask_time = time.time()
ans = self.llm.ask(prompt, **llm_kwargs)
ask_time = time.time() - ask_time
logger.info(f" Ask time: {ask_time} seconds")
return ans
def start_interactive(self):
@@ -996,30 +566,3 @@ class LeannChat:
except (KeyboardInterrupt, EOFError):
print("\nGoodbye!")
break
def cleanup(self):
"""Explicitly cleanup embedding server resources.
This method should be called after you're done using the chat interface,
especially in test environments or batch processing scenarios.
"""
# Only stop the embedding server if this LeannChat instance created the searcher.
# When a shared searcher is passed in, avoid shutting down the server to enable reuse.
if getattr(self, "_owns_searcher", False) and hasattr(self.searcher, "cleanup"):
self.searcher.cleanup()
# Enable automatic cleanup patterns
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
try:
self.cleanup()
except Exception:
pass
def __del__(self):
try:
self.cleanup()
except Exception:
pass

View File

@@ -4,12 +4,11 @@ This file contains the chat generation logic for the LEANN project,
supporting different backends like Ollama, Hugging Face Transformers, and a simulation mode.
"""
import difflib
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, List
import logging
import os
from abc import ABC, abstractmethod
from typing import Any, Optional
import difflib
import torch
# Configure logging
@@ -17,12 +16,11 @@ logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def check_ollama_models(host: str) -> list[str]:
def check_ollama_models() -> List[str]:
"""Check available Ollama models and return a list"""
try:
import requests
response = requests.get(f"{host}/api/tags", timeout=5)
response = requests.get("http://localhost:11434/api/tags", timeout=5)
if response.status_code == 200:
data = response.json()
return [model["name"] for model in data.get("models", [])]
@@ -33,52 +31,51 @@ def check_ollama_models(host: str) -> list[str]:
def check_ollama_model_exists_remotely(model_name: str) -> tuple[bool, list[str]]:
"""Check if a model exists in Ollama's remote library and return available tags
Returns:
(model_exists, available_tags): bool and list of matching tags
"""
try:
import re
import requests
import re
# Split model name and tag
if ":" in model_name:
base_model, requested_tag = model_name.split(":", 1)
if ':' in model_name:
base_model, requested_tag = model_name.split(':', 1)
else:
base_model, requested_tag = model_name, None
# First check if base model exists in library
library_response = requests.get("https://ollama.com/library", timeout=8)
if library_response.status_code != 200:
return True, [] # Assume exists if can't check
# Extract model names from library page
models_in_library = re.findall(r'href="/library/([^"]+)"', library_response.text)
if base_model not in models_in_library:
return False, [] # Base model doesn't exist
# If base model exists, get available tags
tags_response = requests.get(f"https://ollama.com/library/{base_model}/tags", timeout=8)
if tags_response.status_code != 200:
return True, [] # Base model exists but can't get tags
# Extract tags for this model - be more specific to avoid HTML artifacts
tag_pattern = rf"{re.escape(base_model)}:[a-zA-Z0-9\.\-_]+"
tag_pattern = rf'{re.escape(base_model)}:[a-zA-Z0-9\.\-_]+'
raw_tags = re.findall(tag_pattern, tags_response.text)
# Clean up tags - remove HTML artifacts and duplicates
available_tags = []
seen = set()
for tag in raw_tags:
# Skip if it looks like HTML (contains < or >)
if "<" in tag or ">" in tag:
if '<' in tag or '>' in tag:
continue
if tag not in seen:
seen.add(tag)
available_tags.append(tag)
# Check if exact model exists
if requested_tag is None:
# User just requested base model, suggest tags
@@ -86,80 +83,76 @@ def check_ollama_model_exists_remotely(model_name: str) -> tuple[bool, list[str]
else:
exact_match = model_name in available_tags
return exact_match, available_tags[:10]
except Exception:
pass
# If scraping fails, assume model might exist (don't block user)
return True, []
def search_ollama_models_fuzzy(query: str, available_models: list[str]) -> list[str]:
def search_ollama_models_fuzzy(query: str, available_models: List[str]) -> List[str]:
"""Use intelligent fuzzy search for Ollama models"""
if not available_models:
return []
query_lower = query.lower()
suggestions = []
# 1. Exact matches first
exact_matches = [m for m in available_models if query_lower == m.lower()]
suggestions.extend(exact_matches)
# 2. Starts with query
starts_with = [
m for m in available_models if m.lower().startswith(query_lower) and m not in suggestions
]
starts_with = [m for m in available_models if m.lower().startswith(query_lower) and m not in suggestions]
suggestions.extend(starts_with)
# 3. Contains query
contains = [m for m in available_models if query_lower in m.lower() and m not in suggestions]
suggestions.extend(contains)
# 4. Base model name matching (remove version numbers)
def get_base_name(model_name: str) -> str:
"""Extract base name without version (e.g., 'llama3:8b' -> 'llama3')"""
return model_name.split(":")[0].split("-")[0]
return model_name.split(':')[0].split('-')[0]
query_base = get_base_name(query_lower)
base_matches = [
m
for m in available_models
m for m in available_models
if get_base_name(m.lower()) == query_base and m not in suggestions
]
suggestions.extend(base_matches)
# 5. Family/variant matching
model_families = {
"llama": ["llama2", "llama3", "alpaca", "vicuna", "codellama"],
"qwen": ["qwen", "qwen2", "qwen3"],
"gemma": ["gemma", "gemma2"],
"phi": ["phi", "phi2", "phi3"],
"mistral": ["mistral", "mixtral", "openhermes"],
"dolphin": ["dolphin", "openchat"],
"deepseek": ["deepseek", "deepseek-coder"],
'llama': ['llama2', 'llama3', 'alpaca', 'vicuna', 'codellama'],
'qwen': ['qwen', 'qwen2', 'qwen3'],
'gemma': ['gemma', 'gemma2'],
'phi': ['phi', 'phi2', 'phi3'],
'mistral': ['mistral', 'mixtral', 'openhermes'],
'dolphin': ['dolphin', 'openchat'],
'deepseek': ['deepseek', 'deepseek-coder']
}
query_family = None
for family, variants in model_families.items():
if any(variant in query_lower for variant in variants):
query_family = family
break
if query_family:
family_variants = model_families[query_family]
family_matches = [
m
for m in available_models
m for m in available_models
if any(variant in m.lower() for variant in family_variants) and m not in suggestions
]
suggestions.extend(family_matches)
# 6. Use difflib for remaining fuzzy matches
remaining_models = [m for m in available_models if m not in suggestions]
difflib_matches = difflib.get_close_matches(query_lower, remaining_models, n=3, cutoff=0.4)
suggestions.extend(difflib_matches)
return suggestions[:8] # Return top 8 suggestions
@@ -169,13 +162,15 @@ def search_ollama_models_fuzzy(query: str, available_models: list[str]) -> list[
# Remove this too - no need for fallback
def suggest_similar_models(invalid_model: str, available_models: list[str]) -> list[str]:
def suggest_similar_models(invalid_model: str, available_models: List[str]) -> List[str]:
"""Use difflib to find similar model names"""
if not available_models:
return []
# Get close matches using fuzzy matching
suggestions = difflib.get_close_matches(invalid_model, available_models, n=3, cutoff=0.3)
suggestions = difflib.get_close_matches(
invalid_model, available_models, n=3, cutoff=0.3
)
return suggestions
@@ -183,50 +178,49 @@ def check_hf_model_exists(model_name: str) -> bool:
"""Quick check if HuggingFace model exists without downloading"""
try:
from huggingface_hub import model_info
model_info(model_name)
return True
except Exception:
return False
def get_popular_hf_models() -> list[str]:
def get_popular_hf_models() -> List[str]:
"""Return a list of popular HuggingFace models for suggestions"""
try:
from huggingface_hub import list_models
# Get popular text-generation models, sorted by downloads
models = list_models(
filter="text-generation",
sort="downloads",
direction=-1,
limit=20, # Get top 20 most downloaded
limit=20 # Get top 20 most downloaded
)
# Extract model names and filter for chat/conversation models
model_names = []
chat_keywords = ["chat", "instruct", "dialog", "conversation", "assistant"]
chat_keywords = ['chat', 'instruct', 'dialog', 'conversation', 'assistant']
for model in models:
model_name = model.id if hasattr(model, "id") else str(model)
model_name = model.id if hasattr(model, 'id') else str(model)
# Prioritize models with chat-related keywords
if any(keyword in model_name.lower() for keyword in chat_keywords):
model_names.append(model_name)
elif len(model_names) < 10: # Fill up with other popular models
model_names.append(model_name)
return model_names[:10] if model_names else _get_fallback_hf_models()
except Exception:
# Fallback to static list if API call fails
return _get_fallback_hf_models()
def _get_fallback_hf_models() -> list[str]:
def _get_fallback_hf_models() -> List[str]:
"""Fallback list of popular HuggingFace models"""
return [
"microsoft/DialoGPT-medium",
"microsoft/DialoGPT-large",
"microsoft/DialoGPT-large",
"facebook/blenderbot-400M-distill",
"microsoft/phi-2",
"deepseek-ai/deepseek-llm-7b-chat",
@@ -234,44 +228,44 @@ def _get_fallback_hf_models() -> list[str]:
"facebook/blenderbot_small-90M",
"microsoft/phi-1_5",
"facebook/opt-350m",
"EleutherAI/gpt-neo-1.3B",
"EleutherAI/gpt-neo-1.3B"
]
def search_hf_models_fuzzy(query: str, limit: int = 10) -> list[str]:
def search_hf_models_fuzzy(query: str, limit: int = 10) -> List[str]:
"""Use HuggingFace Hub's native fuzzy search for model suggestions"""
try:
from huggingface_hub import list_models
# HF Hub's search is already fuzzy! It handles typos and partial matches
models = list_models(
search=query,
filter="text-generation",
sort="downloads",
sort="downloads",
direction=-1,
limit=limit,
limit=limit
)
model_names = [model.id if hasattr(model, "id") else str(model) for model in models]
model_names = [model.id if hasattr(model, 'id') else str(model) for model in models]
# If direct search doesn't return enough results, try some variations
if len(model_names) < 3:
# Try searching for partial matches or common variations
variations = []
# Extract base name (e.g., "gpt3" from "gpt-3.5")
base_query = query.lower().replace("-", "").replace(".", "").replace("_", "")
base_query = query.lower().replace('-', '').replace('.', '').replace('_', '')
if base_query != query.lower():
variations.append(base_query)
# Try common model name patterns
if "gpt" in query.lower():
variations.extend(["gpt2", "gpt-neo", "gpt-j", "dialoGPT"])
elif "llama" in query.lower():
variations.extend(["llama2", "alpaca", "vicuna"])
elif "bert" in query.lower():
variations.extend(["roberta", "distilbert", "albert"])
if 'gpt' in query.lower():
variations.extend(['gpt2', 'gpt-neo', 'gpt-j', 'dialoGPT'])
elif 'llama' in query.lower():
variations.extend(['llama2', 'alpaca', 'vicuna'])
elif 'bert' in query.lower():
variations.extend(['roberta', 'distilbert', 'albert'])
# Search with variations
for var in variations[:2]: # Limit to 2 variations to avoid too many API calls
try:
@@ -280,15 +274,13 @@ def search_hf_models_fuzzy(query: str, limit: int = 10) -> list[str]:
filter="text-generation",
sort="downloads",
direction=-1,
limit=3,
limit=3
)
var_names = [
model.id if hasattr(model, "id") else str(model) for model in var_models
]
var_names = [model.id if hasattr(model, 'id') else str(model) for model in var_models]
model_names.extend(var_names)
except Exception:
except:
continue
# Remove duplicates while preserving order
seen = set()
unique_models = []
@@ -296,75 +288,67 @@ def search_hf_models_fuzzy(query: str, limit: int = 10) -> list[str]:
if model not in seen:
seen.add(model)
unique_models.append(model)
return unique_models[:limit]
except Exception:
# If search fails, return empty list
return []
def search_hf_models(query: str, limit: int = 10) -> list[str]:
def search_hf_models(query: str, limit: int = 10) -> List[str]:
"""Simple search for HuggingFace models based on query (kept for backward compatibility)"""
return search_hf_models_fuzzy(query, limit)
def validate_model_and_suggest(
model_name: str, llm_type: str, host: str = "http://localhost:11434"
) -> Optional[str]:
def validate_model_and_suggest(model_name: str, llm_type: str) -> Optional[str]:
"""Validate model name and provide suggestions if invalid"""
if llm_type == "ollama":
available_models = check_ollama_models(host)
available_models = check_ollama_models()
if available_models and model_name not in available_models:
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
# Check if the model exists remotely and get available tags
model_exists_remotely, available_tags = check_ollama_model_exists_remotely(model_name)
if model_exists_remotely and model_name in available_tags:
# Exact model exists remotely - suggest pulling it
error_msg += "\n\nTo install the requested model:\n"
error_msg += f"\n\nTo install the requested model:\n"
error_msg += f" ollama pull {model_name}\n"
# Show local alternatives
suggestions = search_ollama_models_fuzzy(model_name, available_models)
if suggestions:
error_msg += "\nOr use one of these similar installed models:\n"
for i, suggestion in enumerate(suggestions, 1):
error_msg += f" {i}. {suggestion}\n"
elif model_exists_remotely and available_tags:
# Base model exists but requested tag doesn't - suggest correct tags
base_model = model_name.split(":")[0]
requested_tag = model_name.split(":", 1)[1] if ":" in model_name else None
error_msg += (
f"\n\nModel '{base_model}' exists, but tag '{requested_tag}' is not available."
)
base_model = model_name.split(':')[0]
requested_tag = model_name.split(':', 1)[1] if ':' in model_name else None
error_msg += f"\n\nModel '{base_model}' exists, but tag '{requested_tag}' is not available."
error_msg += f"\n\nAvailable {base_model} models you can install:\n"
for i, tag in enumerate(available_tags[:8], 1):
error_msg += f" {i}. ollama pull {tag}\n"
if len(available_tags) > 8:
error_msg += f" ... and {len(available_tags) - 8} more variants\n"
# Also show local alternatives
suggestions = search_ollama_models_fuzzy(model_name, available_models)
if suggestions:
error_msg += "\nOr use one of these similar installed models:\n"
for i, suggestion in enumerate(suggestions, 1):
error_msg += f" {i}. {suggestion}\n"
else:
# Model doesn't exist remotely - show fuzzy suggestions
suggestions = search_ollama_models_fuzzy(model_name, available_models)
error_msg += f"\n\nModel '{model_name}' was not found in Ollama's library."
if suggestions:
error_msg += (
"\n\nDid you mean one of these installed models?\n"
+ "\nTry to use ollama pull to install the model you need\n"
)
error_msg += "\n\nDid you mean one of these installed models?\n"
for i, suggestion in enumerate(suggestions, 1):
error_msg += f" {i}. {suggestion}\n"
else:
@@ -373,25 +357,23 @@ def validate_model_and_suggest(
error_msg += f" {i}. {model}\n"
if len(available_models) > 8:
error_msg += f" ... and {len(available_models) - 8} more\n"
error_msg += "\n\nCommands:"
error_msg += "\n ollama list # List installed models"
if model_exists_remotely and available_tags:
if model_name in available_tags:
error_msg += f"\n ollama pull {model_name} # Install requested model"
else:
error_msg += (
f"\n ollama pull {available_tags[0]} # Install recommended variant"
)
error_msg += f"\n ollama pull {available_tags[0]} # Install recommended variant"
error_msg += "\n https://ollama.com/library # Browse available models"
return error_msg
elif llm_type == "hf":
# For HF models, we can do a quick existence check
if not check_hf_model_exists(model_name):
# Use HF Hub's native fuzzy search directly
search_suggestions = search_hf_models_fuzzy(model_name, limit=8)
error_msg = f"Model '{model_name}' not found on HuggingFace Hub."
if search_suggestions:
error_msg += "\n\nDid you mean one of these?\n"
@@ -403,10 +385,10 @@ def validate_model_and_suggest(
error_msg += "\n\nPopular chat models:\n"
for i, model in enumerate(popular_models[:5], 1):
error_msg += f" {i}. {model}\n"
error_msg += f"\nSearch more: https://huggingface.co/models?search={model_name}&pipeline_tag=text-generation"
return error_msg
return None # Model is valid or we can't check
@@ -422,6 +404,7 @@ class LLMInterface(ABC):
top_k=10,
complexity=64,
beam_width=8,
USE_DEFERRED_FETCH=True,
skip_search_reorder=True,
recompute_beighbor_embeddings=True,
dedup_node_dis=True,
@@ -433,6 +416,7 @@ class LLMInterface(ABC):
Supported kwargs:
- complexity (int): Search complexity parameter (default: 32)
- beam_width (int): Beam width for search (default: 4)
- USE_DEFERRED_FETCH (bool): Enable deferred fetch mode (default: False)
- skip_search_reorder (bool): Skip search reorder step (default: False)
- recompute_beighbor_embeddings (bool): Enable ZMQ embedding server for neighbor recomputation (default: False)
- dedup_node_dis (bool): Deduplicate nodes by distance (default: False)
@@ -467,61 +451,38 @@ class OllamaChat(LLMInterface):
# Check if the Ollama server is responsive
if host:
requests.get(host)
# Pre-check model availability with helpful suggestions
model_error = validate_model_and_suggest(model, "ollama", host)
model_error = validate_model_and_suggest(model, "ollama")
if model_error:
raise ValueError(model_error)
except ImportError:
raise ImportError(
"The 'requests' library is required for Ollama. Please install it with 'pip install requests'."
)
except requests.exceptions.ConnectionError:
logger.error(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.")
logger.error(
f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
)
raise ConnectionError(
f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
)
def ask(self, prompt: str, **kwargs) -> str:
import requests
import json
import requests
full_url = f"{self.host}/api/generate"
# Handle thinking budget for reasoning models
options = kwargs.copy()
thinking_budget = kwargs.get("thinking_budget")
if thinking_budget:
# Remove thinking_budget from options as it's not a standard Ollama option
options.pop("thinking_budget", None)
# Only apply reasoning parameters to models that support it
reasoning_supported_models = [
"gpt-oss:20b",
"gpt-oss:120b",
"deepseek-r1",
"deepseek-coder",
]
if thinking_budget in ["low", "medium", "high"]:
if any(model in self.model.lower() for model in reasoning_supported_models):
options["reasoning"] = {"effort": thinking_budget, "exclude": False}
logger.info(f"Applied reasoning effort={thinking_budget} to model {self.model}")
else:
logger.warning(
f"Thinking budget '{thinking_budget}' requested but model '{self.model}' may not support reasoning parameters. Proceeding without reasoning."
)
payload = {
"model": self.model,
"prompt": prompt,
"stream": False, # Keep it simple for now
"options": options,
"options": kwargs,
}
logger.debug(f"Sending request to Ollama: {payload}")
try:
logger.info("Sending request to Ollama and waiting for response...")
logger.info(f"Sending request to Ollama and waiting for response...")
response = requests.post(full_url, data=json.dumps(payload))
response.raise_for_status()
@@ -545,15 +506,15 @@ class HFChat(LLMInterface):
def __init__(self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat"):
logger.info(f"Initializing HFChat with model='{model_name}'")
# Pre-check model availability with helpful suggestions
model_error = validate_model_and_suggest(model_name, "hf")
if model_error:
raise ValueError(model_error)
try:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
except ImportError:
raise ImportError(
"The 'transformers' and 'torch' libraries are required for Hugging Face models. Please install them with 'pip install transformers torch'."
@@ -570,67 +531,42 @@ class HFChat(LLMInterface):
self.device = "cpu"
logger.info("No GPU detected. Using CPU.")
# Load tokenizer and model with timeout protection
try:
import signal
def timeout_handler(signum, frame):
raise TimeoutError("Model download/loading timed out")
# Set timeout for model loading (60 seconds)
old_handler = signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(60)
try:
logger.info(f"Loading tokenizer for {model_name}...")
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
logger.info(f"Loading model {model_name}...")
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
device_map="auto" if self.device != "cpu" else None,
trust_remote_code=True,
)
logger.info(f"Successfully loaded {model_name}")
finally:
signal.alarm(0) # Cancel the alarm
signal.signal(signal.SIGALRM, old_handler) # Restore old handler
except TimeoutError:
logger.error(f"Model loading timed out for {model_name}")
raise RuntimeError(
f"Model loading timed out for {model_name}. Please check your internet connection or try a smaller model."
)
except Exception as e:
logger.error(f"Failed to load model {model_name}: {e}")
raise
# Load tokenizer and model
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
device_map="auto" if self.device != "cpu" else None,
trust_remote_code=True
)
# Move model to device if not using device_map
if self.device != "cpu" and "device_map" not in str(self.model):
self.model = self.model.to(self.device)
# Set pad token if not present
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
def ask(self, prompt: str, **kwargs) -> str:
print("kwargs in HF: ", kwargs)
print('kwargs in HF: ', kwargs)
# Check if this is a Qwen model and add /no_think by default
is_qwen_model = "qwen" in self.model.config._name_or_path.lower()
# For Qwen models, automatically add /no_think to the prompt
if is_qwen_model and "/no_think" not in prompt and "/think" not in prompt:
prompt = prompt + " /no_think"
# Prepare chat template
messages = [{"role": "user", "content": prompt}]
# Apply chat template if available
if hasattr(self.tokenizer, "apply_chat_template"):
try:
formatted_prompt = self.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
messages,
tokenize=False,
add_generation_prompt=True
)
except Exception as e:
logger.warning(f"Chat template failed, using raw prompt: {e}")
@@ -641,13 +577,13 @@ class HFChat(LLMInterface):
# Tokenize input
inputs = self.tokenizer(
formatted_prompt,
return_tensors="pt",
formatted_prompt,
return_tensors="pt",
padding=True,
truncation=True,
max_length=2048,
max_length=2048
)
# Move inputs to device
if self.device != "cpu":
inputs = {k: v.to(self.device) for k, v in inputs.items()}
@@ -661,79 +597,28 @@ class HFChat(LLMInterface):
"pad_token_id": self.tokenizer.eos_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
}
# Handle temperature=0 for greedy decoding
if generation_config["temperature"] == 0.0:
generation_config["do_sample"] = False
generation_config.pop("temperature")
logger.info(f"Generating with HuggingFace model, config: {generation_config}")
# Generate
with torch.no_grad():
outputs = self.model.generate(**inputs, **generation_config)
outputs = self.model.generate(
**inputs,
**generation_config
)
# Decode response
generated_tokens = outputs[0][inputs["input_ids"].shape[1] :]
generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
return response.strip()
class GeminiChat(LLMInterface):
"""LLM interface for Google Gemini models."""
def __init__(self, model: str = "gemini-2.5-flash", api_key: Optional[str] = None):
self.model = model
self.api_key = api_key or os.getenv("GEMINI_API_KEY")
if not self.api_key:
raise ValueError(
"Gemini API key is required. Set GEMINI_API_KEY environment variable or pass api_key parameter."
)
logger.info(f"Initializing Gemini Chat with model='{model}'")
try:
import google.genai as genai
self.client = genai.Client(api_key=self.api_key)
except ImportError:
raise ImportError(
"The 'google-genai' library is required for Gemini models. Please install it with 'uv pip install google-genai'."
)
def ask(self, prompt: str, **kwargs) -> str:
logger.info(f"Sending request to Gemini with model {self.model}")
try:
from google.genai.types import GenerateContentConfig
generation_config = GenerateContentConfig(
temperature=kwargs.get("temperature", 0.7),
max_output_tokens=kwargs.get("max_tokens", 1000),
)
# Handle top_p parameter
if "top_p" in kwargs:
generation_config.top_p = kwargs["top_p"]
response = self.client.models.generate_content(
model=self.model,
contents=prompt,
config=generation_config,
)
# Handle potential None response text
response_text = response.text
if response_text is None:
logger.warning("Gemini returned None response text")
return ""
return response_text.strip()
except Exception as e:
logger.error(f"Error communicating with Gemini: {e}")
return f"Error: Could not get a response from Gemini. Details: {e}"
class OpenAIChat(LLMInterface):
"""LLM interface for OpenAI models."""
@@ -762,38 +647,15 @@ class OpenAIChat(LLMInterface):
params = {
"model": self.model,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": kwargs.get("max_tokens", 1000),
"temperature": kwargs.get("temperature", 0.7),
**{
k: v
for k, v in kwargs.items()
if k not in ["max_tokens", "temperature"]
},
}
# Handle max_tokens vs max_completion_tokens based on model
max_tokens = kwargs.get("max_tokens", 1000)
if "o3" in self.model or "o4" in self.model or "o1" in self.model:
# o-series models use max_completion_tokens
params["max_completion_tokens"] = max_tokens
params["temperature"] = 1.0
else:
# Other models use max_tokens
params["max_tokens"] = max_tokens
# Handle thinking budget for reasoning models
thinking_budget = kwargs.get("thinking_budget")
if thinking_budget and thinking_budget in ["low", "medium", "high"]:
# Check if this is an o-series model (partial match for model names)
o_series_models = ["o3", "o3-mini", "o4-mini", "o1", "o3-pro", "o3-deep-research"]
if any(model in self.model for model in o_series_models):
# Use the correct OpenAI reasoning parameter format
params["reasoning_effort"] = thinking_budget
logger.info(f"Applied reasoning_effort={thinking_budget} to model {self.model}")
else:
logger.warning(
f"Thinking budget '{thinking_budget}' requested but model '{self.model}' may not support reasoning parameters. Proceeding without reasoning."
)
# Add other kwargs (excluding thinking_budget as it's handled above)
for k, v in kwargs.items():
if k not in ["max_tokens", "temperature", "thinking_budget"]:
params[k] = v
logger.info(f"Sending request to OpenAI with model {self.model}")
try:
@@ -813,7 +675,7 @@ class SimulatedChat(LLMInterface):
return "This is a simulated answer from the LLM based on the retrieved context."
def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
def get_llm(llm_config: Optional[Dict[str, Any]] = None) -> LLMInterface:
"""
Factory function to get an LLM interface based on configuration.
@@ -847,8 +709,6 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
elif llm_type == "openai":
return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key"))
elif llm_type == "gemini":
return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
elif llm_type == "simulated":
return SimulatedChat()
else:

View File

File diff suppressed because it is too large Load Diff

View File

@@ -4,13 +4,11 @@ Consolidates all embedding computation logic using SentenceTransformer
Preserves all optimization parameters to ensure performance
"""
import logging
import os
import time
from typing import Any
import numpy as np
import torch
from typing import List, Dict, Any
import logging
import os
# Set up logger with proper level
logger = logging.getLogger(__name__)
@@ -19,18 +17,16 @@ log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
logger.setLevel(log_level)
# Global model cache to avoid repeated loading
_model_cache: dict[str, Any] = {}
_model_cache: Dict[str, Any] = {}
def compute_embeddings(
texts: list[str],
texts: List[str],
model_name: str,
mode: str = "sentence-transformers",
is_build: bool = False,
batch_size: int = 32,
adaptive_optimization: bool = True,
manual_tokenize: bool = False,
max_length: int = 512,
) -> np.ndarray:
"""
Unified embedding computation entry point
@@ -38,7 +34,7 @@ def compute_embeddings(
Args:
texts: List of texts to compute embeddings for
model_name: Model name
mode: Computation mode ('sentence-transformers', 'openai', 'mlx', 'ollama')
mode: Computation mode ('sentence-transformers', 'openai', 'mlx')
is_build: Whether this is a build operation (shows progress bar)
batch_size: Batch size for processing
adaptive_optimization: Whether to use adaptive optimization based on batch size
@@ -53,31 +49,23 @@ def compute_embeddings(
is_build=is_build,
batch_size=batch_size,
adaptive_optimization=adaptive_optimization,
manual_tokenize=manual_tokenize,
max_length=max_length,
)
elif mode == "openai":
return compute_embeddings_openai(texts, model_name)
elif mode == "mlx":
return compute_embeddings_mlx(texts, model_name)
elif mode == "ollama":
return compute_embeddings_ollama(texts, model_name, is_build=is_build)
elif mode == "gemini":
return compute_embeddings_gemini(texts, model_name, is_build=is_build)
else:
raise ValueError(f"Unsupported embedding mode: {mode}")
def compute_embeddings_sentence_transformers(
texts: list[str],
texts: List[str],
model_name: str,
use_fp16: bool = True,
device: str = "auto",
batch_size: int = 32,
is_build: bool = False,
adaptive_optimization: bool = True,
manual_tokenize: bool = False,
max_length: int = 512,
) -> np.ndarray:
"""
Compute embeddings using SentenceTransformer with model caching and adaptive optimization
@@ -126,7 +114,9 @@ def compute_embeddings_sentence_transformers(
logger.info(f"Using cached optimized model: {model_name}")
model = _model_cache[cache_key]
else:
logger.info(f"Loading and caching optimized SentenceTransformer model: {model_name}")
logger.info(
f"Loading and caching optimized SentenceTransformer model: {model_name}"
)
from sentence_transformers import SentenceTransformer
logger.info(f"Using device: {device}")
@@ -144,7 +134,9 @@ def compute_embeddings_sentence_transformers(
if hasattr(torch.mps, "set_per_process_memory_fraction"):
torch.mps.set_per_process_memory_fraction(0.9)
except AttributeError:
logger.warning("Some MPS optimizations not available in this PyTorch version")
logger.warning(
"Some MPS optimizations not available in this PyTorch version"
)
elif device == "cpu":
# TODO: Haven't tested this yet
torch.set_num_threads(min(8, os.cpu_count() or 4))
@@ -221,158 +213,41 @@ def compute_embeddings_sentence_transformers(
logger.info(f"Model cached: {cache_key}")
# Compute embeddings with optimized inference mode
logger.info(f"Starting embedding computation... (batch_size: {batch_size})")
# Use torch.inference_mode for optimal performance
with torch.inference_mode():
embeddings = model.encode(
texts,
batch_size=batch_size,
show_progress_bar=is_build, # Don't show progress bar in server environment
convert_to_numpy=True,
normalize_embeddings=False,
device=device,
)
logger.info(
f"Starting embedding computation... (batch_size: {batch_size}, manual_tokenize={manual_tokenize})"
f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
)
start_time = time.time()
if not manual_tokenize:
# Use SentenceTransformer's optimized encode path (default)
with torch.inference_mode():
embeddings = model.encode(
texts,
batch_size=batch_size,
show_progress_bar=is_build, # Don't show progress bar in server environment
convert_to_numpy=True,
normalize_embeddings=False,
device=device,
)
# Synchronize if CUDA to measure accurate wall time
try:
if torch.cuda.is_available():
torch.cuda.synchronize()
except Exception:
pass
else:
# Manual tokenization + forward pass using HF AutoTokenizer/AutoModel
try:
from transformers import AutoModel, AutoTokenizer # type: ignore
except Exception as e:
raise ImportError(f"transformers is required for manual_tokenize=True: {e}")
# Cache tokenizer and model
tok_cache_key = f"hf_tokenizer_{model_name}"
mdl_cache_key = f"hf_model_{model_name}_{device}_{use_fp16}"
if tok_cache_key in _model_cache and mdl_cache_key in _model_cache:
hf_tokenizer = _model_cache[tok_cache_key]
hf_model = _model_cache[mdl_cache_key]
logger.info("Using cached HF tokenizer/model for manual path")
else:
logger.info("Loading HF tokenizer/model for manual tokenization path")
hf_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
torch_dtype = torch.float16 if (use_fp16 and device == "cuda") else torch.float32
hf_model = AutoModel.from_pretrained(model_name, torch_dtype=torch_dtype)
hf_model.to(device)
hf_model.eval()
# Optional compile on supported devices
if device in ["cuda", "mps"]:
try:
hf_model = torch.compile(hf_model, mode="reduce-overhead", dynamic=True) # type: ignore
except Exception:
pass
_model_cache[tok_cache_key] = hf_tokenizer
_model_cache[mdl_cache_key] = hf_model
all_embeddings: list[np.ndarray] = []
# Progress bar when building or for large inputs
show_progress = is_build or len(texts) > 32
try:
if show_progress:
from tqdm import tqdm # type: ignore
batch_iter = tqdm(
range(0, len(texts), batch_size),
desc="Embedding (manual)",
unit="batch",
)
else:
batch_iter = range(0, len(texts), batch_size)
except Exception:
batch_iter = range(0, len(texts), batch_size)
start_time_manual = time.time()
with torch.inference_mode():
for start_index in batch_iter:
end_index = min(start_index + batch_size, len(texts))
batch_texts = texts[start_index:end_index]
tokenize_start_time = time.time()
inputs = hf_tokenizer(
batch_texts,
padding=True,
truncation=True,
max_length=max_length,
return_tensors="pt",
)
tokenize_end_time = time.time()
logger.info(
f"Tokenize time taken: {tokenize_end_time - tokenize_start_time} seconds"
)
# Print shapes of all input tensors for debugging
for k, v in inputs.items():
print(f"inputs[{k!r}] shape: {getattr(v, 'shape', type(v))}")
to_device_start_time = time.time()
inputs = {k: v.to(device) for k, v in inputs.items()}
to_device_end_time = time.time()
logger.info(
f"To device time taken: {to_device_end_time - to_device_start_time} seconds"
)
forward_start_time = time.time()
outputs = hf_model(**inputs)
forward_end_time = time.time()
logger.info(f"Forward time taken: {forward_end_time - forward_start_time} seconds")
last_hidden_state = outputs.last_hidden_state # (B, L, H)
attention_mask = inputs.get("attention_mask")
if attention_mask is None:
# Fallback: assume all tokens are valid
pooled = last_hidden_state.mean(dim=1)
else:
mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
masked = last_hidden_state * mask
lengths = mask.sum(dim=1).clamp(min=1)
pooled = masked.sum(dim=1) / lengths
# Move to CPU float32
batch_embeddings = pooled.detach().to("cpu").float().numpy()
all_embeddings.append(batch_embeddings)
embeddings = np.vstack(all_embeddings).astype(np.float32, copy=False)
try:
if torch.cuda.is_available():
torch.cuda.synchronize()
except Exception:
pass
end_time = time.time()
logger.info(f"Manual tokenize time taken: {end_time - start_time_manual} seconds")
end_time = time.time()
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
logger.info(f"Time taken: {end_time - start_time} seconds")
# Validate results
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
raise RuntimeError(f"Detected NaN or Inf values in embeddings, model: {model_name}")
raise RuntimeError(
f"Detected NaN or Inf values in embeddings, model: {model_name}"
)
return embeddings
def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
# TODO: @yichuan-w add progress bar only in build mode
"""Compute embeddings using OpenAI API"""
try:
import os
import openai
import os
except ImportError as e:
raise ImportError(f"OpenAI package not installed: {e}")
# Validate input list
if not texts:
raise ValueError("Cannot compute embeddings for empty text list")
# Extra validation: abort early if any item is empty/whitespace
invalid_count = sum(1 for t in texts if not isinstance(t, str) or not t.strip())
if invalid_count > 0:
raise ValueError(
f"Found {invalid_count} empty/invalid text(s) in input. Upstream should filter before calling OpenAI."
)
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
raise RuntimeError("OPENAI_API_KEY environment variable not set")
@@ -389,19 +264,10 @@ def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
logger.info(
f"Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
)
print(f"len of texts: {len(texts)}")
# OpenAI has limits on batch size and input length
max_batch_size = 800 # Conservative batch size because the token limit is 300K
max_batch_size = 100 # Conservative batch size
all_embeddings = []
# get the avg len of texts
avg_len = sum(len(text) for text in texts) / len(texts)
print(f"avg len of texts: {avg_len}")
# if avg len is less than 1000, use the max batch size
if avg_len > 300:
max_batch_size = 500
# if avg len is less than 1000, use the max batch size
try:
from tqdm import tqdm
@@ -427,12 +293,15 @@ def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
raise
embeddings = np.array(all_embeddings, dtype=np.float32)
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
print(f"len of embeddings: {len(embeddings)}")
logger.info(
f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
)
return embeddings
def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int = 16) -> np.ndarray:
def compute_embeddings_mlx(
chunks: List[str], model_name: str, batch_size: int = 16
) -> np.ndarray:
# TODO: @yichuan-w add progress bar only in build mode
"""Computes embeddings using an MLX model."""
try:
@@ -504,366 +373,3 @@ def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int =
# Stack numpy arrays
return np.stack(all_embeddings)
def compute_embeddings_ollama(
texts: list[str], model_name: str, is_build: bool = False, host: str = "http://localhost:11434"
) -> np.ndarray:
"""
Compute embeddings using Ollama API with simplified batch processing.
Uses batch size of 32 for MPS/CPU and 128 for CUDA to optimize performance.
Args:
texts: List of texts to compute embeddings for
model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large")
is_build: Whether this is a build operation (shows progress bar)
host: Ollama host URL (default: http://localhost:11434)
Returns:
Normalized embeddings array, shape: (len(texts), embedding_dim)
"""
try:
import requests
except ImportError:
raise ImportError(
"The 'requests' library is required for Ollama embeddings. Install with: uv pip install requests"
)
if not texts:
raise ValueError("Cannot compute embeddings for empty text list")
logger.info(
f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}'"
)
# Check if Ollama is running
try:
response = requests.get(f"{host}/api/version", timeout=5)
response.raise_for_status()
except requests.exceptions.ConnectionError:
error_msg = (
f"❌ Could not connect to Ollama at {host}.\n\n"
"Please ensure Ollama is running:\n"
" • macOS/Linux: ollama serve\n"
" • Windows: Make sure Ollama is running in the system tray\n\n"
"Installation: https://ollama.com/download"
)
raise RuntimeError(error_msg)
except Exception as e:
raise RuntimeError(f"Unexpected error connecting to Ollama: {e}")
# Check if model exists and provide helpful suggestions
try:
response = requests.get(f"{host}/api/tags", timeout=5)
response.raise_for_status()
models = response.json()
model_names = [model["name"] for model in models.get("models", [])]
# Filter for embedding models (models that support embeddings)
embedding_models = []
suggested_embedding_models = [
"nomic-embed-text",
"mxbai-embed-large",
"bge-m3",
"all-minilm",
"snowflake-arctic-embed",
]
for model in model_names:
# Check if it's an embedding model (by name patterns or known models)
base_name = model.split(":")[0]
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5"]):
embedding_models.append(model)
# Check if model exists (handle versioned names) and resolve to full name
resolved_model_name = None
for name in model_names:
# Exact match
if model_name == name:
resolved_model_name = name
break
# Match without version tag (use the versioned name)
elif model_name == name.split(":")[0]:
resolved_model_name = name
break
if not resolved_model_name:
error_msg = f"❌ Model '{model_name}' not found in local Ollama.\n\n"
# Suggest pulling the model
error_msg += "📦 To install this embedding model:\n"
error_msg += f" ollama pull {model_name}\n\n"
# Show available embedding models
if embedding_models:
error_msg += "✅ Available embedding models:\n"
for model in embedding_models[:5]:
error_msg += f"{model}\n"
if len(embedding_models) > 5:
error_msg += f" ... and {len(embedding_models) - 5} more\n"
else:
error_msg += "💡 Popular embedding models to install:\n"
for model in suggested_embedding_models[:3]:
error_msg += f" • ollama pull {model}\n"
error_msg += "\n📚 Browse more: https://ollama.com/library"
raise ValueError(error_msg)
# Use the resolved model name for all subsequent operations
if resolved_model_name != model_name:
logger.info(f"Resolved model name '{model_name}' to '{resolved_model_name}'")
model_name = resolved_model_name
# Verify the model supports embeddings by testing it
try:
test_response = requests.post(
f"{host}/api/embeddings", json={"model": model_name, "prompt": "test"}, timeout=10
)
if test_response.status_code != 200:
error_msg = (
f"⚠️ Model '{model_name}' exists but may not support embeddings.\n\n"
f"Please use an embedding model like:\n"
)
for model in suggested_embedding_models[:3]:
error_msg += f"{model}\n"
raise ValueError(error_msg)
except requests.exceptions.RequestException:
# If test fails, continue anyway - model might still work
pass
except requests.exceptions.RequestException as e:
logger.warning(f"Could not verify model existence: {e}")
# Determine batch size based on device availability
# Check for CUDA/MPS availability using torch if available
batch_size = 32 # Default for MPS/CPU
try:
import torch
if torch.cuda.is_available():
batch_size = 128 # CUDA gets larger batch size
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
batch_size = 32 # MPS gets smaller batch size
except ImportError:
# If torch is not available, use conservative batch size
batch_size = 32
logger.info(f"Using batch size: {batch_size}")
def get_batch_embeddings(batch_texts):
"""Get embeddings for a batch of texts."""
all_embeddings = []
failed_indices = []
for i, text in enumerate(batch_texts):
max_retries = 3
retry_count = 0
# Truncate very long texts to avoid API issues
truncated_text = text[:8000] if len(text) > 8000 else text
while retry_count < max_retries:
try:
response = requests.post(
f"{host}/api/embeddings",
json={"model": model_name, "prompt": truncated_text},
timeout=30,
)
response.raise_for_status()
result = response.json()
embedding = result.get("embedding")
if embedding is None:
raise ValueError(f"No embedding returned for text {i}")
if not isinstance(embedding, list) or len(embedding) == 0:
raise ValueError(f"Invalid embedding format for text {i}")
all_embeddings.append(embedding)
break
except requests.exceptions.Timeout:
retry_count += 1
if retry_count >= max_retries:
logger.warning(f"Timeout for text {i} after {max_retries} retries")
failed_indices.append(i)
all_embeddings.append(None)
break
except Exception as e:
retry_count += 1
if retry_count >= max_retries:
logger.error(f"Failed to get embedding for text {i}: {e}")
failed_indices.append(i)
all_embeddings.append(None)
break
return all_embeddings, failed_indices
# Process texts in batches
all_embeddings = []
all_failed_indices = []
# Setup progress bar if needed
show_progress = is_build or len(texts) > 10
try:
if show_progress:
from tqdm import tqdm
except ImportError:
show_progress = False
# Process batches
num_batches = (len(texts) + batch_size - 1) // batch_size
if show_progress:
batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings")
else:
batch_iterator = range(num_batches)
for batch_idx in batch_iterator:
start_idx = batch_idx * batch_size
end_idx = min(start_idx + batch_size, len(texts))
batch_texts = texts[start_idx:end_idx]
batch_embeddings, batch_failed = get_batch_embeddings(batch_texts)
# Adjust failed indices to global indices
global_failed = [start_idx + idx for idx in batch_failed]
all_failed_indices.extend(global_failed)
all_embeddings.extend(batch_embeddings)
# Handle failed embeddings
if all_failed_indices:
if len(all_failed_indices) == len(texts):
raise RuntimeError("Failed to compute any embeddings")
logger.warning(
f"Failed to compute embeddings for {len(all_failed_indices)}/{len(texts)} texts"
)
# Use zero embeddings as fallback for failed ones
valid_embedding = next((e for e in all_embeddings if e is not None), None)
if valid_embedding:
embedding_dim = len(valid_embedding)
for i, embedding in enumerate(all_embeddings):
if embedding is None:
all_embeddings[i] = [0.0] * embedding_dim
# Remove None values
all_embeddings = [e for e in all_embeddings if e is not None]
if not all_embeddings:
raise RuntimeError("No valid embeddings were computed")
# Validate embedding dimensions
expected_dim = len(all_embeddings[0])
inconsistent_dims = []
for i, embedding in enumerate(all_embeddings):
if len(embedding) != expected_dim:
inconsistent_dims.append((i, len(embedding)))
if inconsistent_dims:
error_msg = f"Ollama returned inconsistent embedding dimensions. Expected {expected_dim}, but got:\n"
for idx, dim in inconsistent_dims[:10]: # Show first 10 inconsistent ones
error_msg += f" - Text {idx}: {dim} dimensions\n"
if len(inconsistent_dims) > 10:
error_msg += f" ... and {len(inconsistent_dims) - 10} more\n"
error_msg += f"\nThis is likely an Ollama API bug with model '{model_name}'. Please try:\n"
error_msg += "1. Restart Ollama service: 'ollama serve'\n"
error_msg += f"2. Re-pull the model: 'ollama pull {model_name}'\n"
error_msg += (
"3. Use sentence-transformers instead: --embedding-mode sentence-transformers\n"
)
error_msg += "4. Report this issue to Ollama: https://github.com/ollama/ollama/issues"
raise ValueError(error_msg)
# Convert to numpy array and normalize
embeddings = np.array(all_embeddings, dtype=np.float32)
# Normalize embeddings (L2 normalization)
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
embeddings = embeddings / (norms + 1e-8) # Add small epsilon to avoid division by zero
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
return embeddings
def compute_embeddings_gemini(
texts: list[str], model_name: str = "text-embedding-004", is_build: bool = False
) -> np.ndarray:
"""
Compute embeddings using Google Gemini API.
Args:
texts: List of texts to compute embeddings for
model_name: Gemini model name (default: "text-embedding-004")
is_build: Whether this is a build operation (shows progress bar)
Returns:
Embeddings array, shape: (len(texts), embedding_dim)
"""
try:
import os
import google.genai as genai
except ImportError as e:
raise ImportError(f"Google GenAI package not installed: {e}")
api_key = os.getenv("GEMINI_API_KEY")
if not api_key:
raise RuntimeError("GEMINI_API_KEY environment variable not set")
# Cache Gemini client
cache_key = "gemini_client"
if cache_key in _model_cache:
client = _model_cache[cache_key]
else:
client = genai.Client(api_key=api_key)
_model_cache[cache_key] = client
logger.info("Gemini client cached")
logger.info(
f"Computing embeddings for {len(texts)} texts using Gemini API, model: '{model_name}'"
)
# Gemini supports batch embedding
max_batch_size = 100 # Conservative batch size for Gemini
all_embeddings = []
try:
from tqdm import tqdm
total_batches = (len(texts) + max_batch_size - 1) // max_batch_size
batch_range = range(0, len(texts), max_batch_size)
batch_iterator = tqdm(
batch_range, desc="Computing embeddings", unit="batch", total=total_batches
)
except ImportError:
# Fallback when tqdm is not available
batch_iterator = range(0, len(texts), max_batch_size)
for i in batch_iterator:
batch_texts = texts[i : i + max_batch_size]
try:
# Use the embed_content method from the new Google GenAI SDK
response = client.models.embed_content(
model=model_name,
contents=batch_texts,
config=genai.types.EmbedContentConfig(
task_type="RETRIEVAL_DOCUMENT" # For document embedding
),
)
# Extract embeddings from response
for embedding_data in response.embeddings:
all_embeddings.append(embedding_data.values)
except Exception as e:
logger.error(f"Batch {i} failed: {e}")
raise
embeddings = np.array(all_embeddings, dtype=np.float32)
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
return embeddings

View File

@@ -1,14 +1,13 @@
import time
import atexit
import logging
import os
import socket
import subprocess
import sys
import time
import os
import logging
from pathlib import Path
from typing import Optional
# Lightweight, self-contained server manager with no cross-process inspection
import psutil
# Set up logging based on environment variable
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
@@ -19,31 +18,136 @@ logging.basicConfig(
logger = logging.getLogger(__name__)
def _is_colab_environment() -> bool:
"""Check if we're running in Google Colab environment."""
return "COLAB_GPU" in os.environ or "COLAB_TPU" in os.environ
def _get_available_port(start_port: int = 5557) -> int:
"""Get an available port starting from start_port."""
port = start_port
while port < start_port + 100: # Try up to 100 ports
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("localhost", port))
return port
except OSError:
port += 1
raise RuntimeError(f"No available ports found in range {start_port}-{start_port + 100}")
def _check_port(port: int) -> bool:
"""Check if a port is in use"""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(("localhost", port)) == 0
# Note: All cross-process scanning helpers removed for simplicity
def _check_process_matches_config(
port: int, expected_model: str, expected_passages_file: str
) -> bool:
"""
Check if the process using the port matches our expected model and passages file.
Returns True if matches, False otherwise.
"""
try:
for proc in psutil.process_iter(["pid", "cmdline"]):
if not _is_process_listening_on_port(proc, port):
continue
cmdline = proc.info["cmdline"]
if not cmdline:
continue
return _check_cmdline_matches_config(
cmdline, port, expected_model, expected_passages_file
)
logger.debug(f"No process found listening on port {port}")
return False
except Exception as e:
logger.warning(f"Could not check process on port {port}: {e}")
return False
def _is_process_listening_on_port(proc, port: int) -> bool:
"""Check if a process is listening on the given port."""
try:
connections = proc.net_connections()
for conn in connections:
if conn.laddr.port == port and conn.status == psutil.CONN_LISTEN:
return True
return False
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
return False
def _check_cmdline_matches_config(
cmdline: list, port: int, expected_model: str, expected_passages_file: str
) -> bool:
"""Check if command line matches our expected configuration."""
cmdline_str = " ".join(cmdline)
logger.debug(f"Found process on port {port}: {cmdline_str}")
# Check if it's our embedding server
is_embedding_server = any(
server_type in cmdline_str
for server_type in [
"embedding_server",
"leann_backend_diskann.embedding_server",
"leann_backend_hnsw.hnsw_embedding_server",
]
)
if not is_embedding_server:
logger.debug(f"Process on port {port} is not our embedding server")
return False
# Check model name
model_matches = _check_model_in_cmdline(cmdline, expected_model)
# Check passages file if provided
passages_matches = _check_passages_in_cmdline(cmdline, expected_passages_file)
result = model_matches and passages_matches
logger.debug(
f"model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}"
)
return result
def _check_model_in_cmdline(cmdline: list, expected_model: str) -> bool:
"""Check if the command line contains the expected model."""
if "--model-name" not in cmdline:
return False
model_idx = cmdline.index("--model-name")
if model_idx + 1 >= len(cmdline):
return False
actual_model = cmdline[model_idx + 1]
return actual_model == expected_model
def _check_passages_in_cmdline(cmdline: list, expected_passages_file: str) -> bool:
"""Check if the command line contains the expected passages file."""
if "--passages-file" not in cmdline:
return False # Expected but not found
passages_idx = cmdline.index("--passages-file")
if passages_idx + 1 >= len(cmdline):
return False
actual_passages = cmdline[passages_idx + 1]
expected_path = Path(expected_passages_file).resolve()
actual_path = Path(actual_passages).resolve()
return actual_path == expected_path
def _find_compatible_port_or_next_available(
start_port: int, model_name: str, passages_file: str, max_attempts: int = 100
) -> tuple[int, bool]:
"""
Find a port that either has a compatible server or is available.
Returns (port, is_compatible) where is_compatible indicates if we found a matching server.
"""
for port in range(start_port, start_port + max_attempts):
if not _check_port(port):
# Port is available
return port, False
# Port is in use, check if it's compatible
if _check_process_matches_config(port, model_name, passages_file):
logger.info(f"Found compatible server on port {port}")
return port, True
else:
logger.info(f"Port {port} has incompatible server, trying next port...")
raise RuntimeError(
f"Could not find compatible or available port in range {start_port}-{start_port + max_attempts}"
)
class EmbeddingServerManager:
@@ -62,16 +166,7 @@ class EmbeddingServerManager:
self.backend_module_name = backend_module_name
self.server_process: Optional[subprocess.Popen] = None
self.server_port: Optional[int] = None
# Track last-started config for in-process reuse only
self._server_config: Optional[dict] = None
self._atexit_registered = False
# Also register a weakref finalizer to ensure cleanup when manager is GC'ed
try:
import weakref
self._finalizer = weakref.finalize(self, self._finalize_process)
except Exception:
self._finalizer = None
def start_server(
self,
@@ -80,58 +175,69 @@ class EmbeddingServerManager:
embedding_mode: str = "sentence-transformers",
**kwargs,
) -> tuple[bool, int]:
"""Start the embedding server."""
# passages_file may be present in kwargs for server CLI, but we don't need it here
"""
Starts the embedding server process.
# If this manager already has a live server, just reuse it
if self.server_process and self.server_process.poll() is None and self.server_port:
logger.info("Reusing in-process server")
Args:
port (int): The preferred ZMQ port for the server.
model_name (str): The name of the embedding model to use.
**kwargs: Additional arguments for the server.
Returns:
tuple[bool, int]: (success, actual_port_used)
"""
passages_file = kwargs.get("passages_file")
assert isinstance(passages_file, str), "passages_file must be a string"
# Check if we have a compatible running server
if self._has_compatible_running_server(model_name, passages_file):
assert self.server_port is not None, (
"a compatible running server should set server_port"
)
return True, self.server_port
# For Colab environment, use a different strategy
if _is_colab_environment():
logger.info("Detected Colab environment, using alternative startup strategy")
return self._start_server_colab(port, model_name, embedding_mode, **kwargs)
# Always pick a fresh available port
# Find available port (compatible or free)
try:
actual_port = _get_available_port(port)
except RuntimeError:
logger.error("No available ports found")
actual_port, is_compatible = _find_compatible_port_or_next_available(
port, model_name, passages_file
)
except RuntimeError as e:
logger.error(str(e))
return False, port
# Start a new server
if is_compatible:
logger.info(f"Using existing compatible server on port {actual_port}")
self.server_port = actual_port
self.server_process = None # We don't own this process
return True, actual_port
if actual_port != port:
logger.info(f"Using port {actual_port} instead of {port}")
# Start new server
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
def _start_server_colab(
self,
port: int,
model_name: str,
embedding_mode: str = "sentence-transformers",
**kwargs,
) -> tuple[bool, int]:
"""Start server with Colab-specific configuration."""
# Try to find an available port
try:
actual_port = _get_available_port(port)
except RuntimeError:
logger.error("No available ports found")
return False, port
def _has_compatible_running_server(
self, model_name: str, passages_file: str
) -> bool:
"""Check if we have a compatible running server."""
if not (
self.server_process
and self.server_process.poll() is None
and self.server_port
):
return False
logger.info(f"Starting server on port {actual_port} for Colab environment")
if _check_process_matches_config(self.server_port, model_name, passages_file):
logger.info(
f"Existing server process (PID {self.server_process.pid}) is compatible"
)
return True
# Use a simpler startup strategy for Colab
command = self._build_server_command(actual_port, model_name, embedding_mode, **kwargs)
try:
# In Colab, we'll use a more direct approach
self._launch_server_process_colab(command, actual_port)
return self._wait_for_server_ready_colab(actual_port)
except Exception as e:
logger.error(f"Failed to start embedding server in Colab: {e}")
return False, actual_port
# Note: No compatibility check needed; manager is per-searcher and configs are stable per instance
logger.info(
"Existing server process is incompatible. Should start a new server."
)
return False
def _start_new_server(
self, port: int, model_name: str, embedding_mode: str, **kwargs
@@ -163,13 +269,9 @@ class EmbeddingServerManager:
]
if kwargs.get("passages_file"):
# Convert to absolute path to ensure subprocess can find the file
passages_file = Path(kwargs["passages_file"]).resolve()
command.extend(["--passages-file", str(passages_file)])
command.extend(["--passages-file", str(kwargs["passages_file"])])
if embedding_mode != "sentence-transformers":
command.extend(["--embedding-mode", embedding_mode])
if kwargs.get("distance_metric"):
command.extend(["--distance-metric", kwargs["distance_metric"]])
return command
@@ -178,62 +280,22 @@ class EmbeddingServerManager:
project_root = Path(__file__).parent.parent.parent.parent.parent
logger.info(f"Command: {' '.join(command)}")
# In CI environment, redirect stdout to avoid buffer deadlock but keep stderr for debugging
# Embedding servers use many print statements that can fill stdout buffers
is_ci = os.environ.get("CI") == "true"
if is_ci:
stdout_target = subprocess.DEVNULL
stderr_target = None # Keep stderr for error debugging in CI
logger.info(
"CI environment detected, redirecting embedding server stdout to DEVNULL, keeping stderr"
)
else:
stdout_target = None # Direct to console for visible logs
stderr_target = None # Direct to console for visible logs
# Start embedding server subprocess
logger.info(f"Starting server process with command: {' '.join(command)}")
# Let server output go directly to console
# The server will respect LEANN_LOG_LEVEL environment variable
self.server_process = subprocess.Popen(
command,
cwd=project_root,
stdout=stdout_target,
stderr=stderr_target,
stdout=None, # Direct to console
stderr=None, # Direct to console
)
self.server_port = port
# Record config for in-process reuse
try:
self._server_config = {
"model_name": command[command.index("--model-name") + 1]
if "--model-name" in command
else "",
"passages_file": command[command.index("--passages-file") + 1]
if "--passages-file" in command
else "",
"embedding_mode": command[command.index("--embedding-mode") + 1]
if "--embedding-mode" in command
else "sentence-transformers",
}
except Exception:
self._server_config = {
"model_name": "",
"passages_file": "",
"embedding_mode": "sentence-transformers",
}
logger.info(f"Server process started with PID: {self.server_process.pid}")
# Register atexit callback only when we actually start a process
if not self._atexit_registered:
# Always attempt best-effort finalize at interpreter exit
atexit.register(self._finalize_process)
# Use a lambda to avoid issues with bound methods
atexit.register(lambda: self.stop_server() if self.server_process else None)
self._atexit_registered = True
# Touch finalizer so it knows there is a live process
if getattr(self, "_finalizer", None) is not None and not self._finalizer.alive:
try:
import weakref
self._finalizer = weakref.finalize(self, self._finalize_process)
except Exception:
pass
def _wait_for_server_ready(self, port: int) -> tuple[bool, int]:
"""Wait for the server to be ready."""
@@ -258,114 +320,29 @@ class EmbeddingServerManager:
if not self.server_process:
return
if self.server_process and self.server_process.poll() is not None:
if self.server_process.poll() is not None:
# Process already terminated
self.server_process = None
self.server_port = None
self._server_config = None
return
logger.info(
f"Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
)
# Use simple termination first; if the server installed signal handlers,
# it will exit cleanly. Otherwise escalate to kill after a short wait.
try:
self.server_process.terminate()
except Exception:
pass
self.server_process.terminate()
try:
self.server_process.wait(timeout=5) # Give more time for graceful shutdown
logger.info(f"Server process {self.server_process.pid} terminated gracefully.")
self.server_process.wait(timeout=5)
logger.info(f"Server process {self.server_process.pid} terminated.")
except subprocess.TimeoutExpired:
logger.warning(
f"Server process {self.server_process.pid} did not terminate within 5 seconds, force killing..."
f"Server process {self.server_process.pid} did not terminate gracefully, killing it."
)
try:
self.server_process.kill()
except Exception:
pass
try:
self.server_process.wait(timeout=2)
logger.info(f"Server process {self.server_process.pid} killed successfully.")
except subprocess.TimeoutExpired:
logger.error(
f"Failed to kill server process {self.server_process.pid} - it may be hung"
)
self.server_process.kill()
# Clean up process resources with timeout to avoid CI hang
# Clean up process resources to prevent resource tracker warnings
try:
# Use shorter timeout in CI environments
is_ci = os.environ.get("CI") == "true"
timeout = 3 if is_ci else 10
self.server_process.wait(timeout=timeout)
logger.info(f"Server process {self.server_process.pid} cleanup completed")
except subprocess.TimeoutExpired:
logger.warning(f"Process cleanup timeout after {timeout}s, proceeding anyway")
except Exception as e:
logger.warning(f"Error during process cleanup: {e}")
finally:
self.server_process = None
self.server_port = None
self._server_config = None
def _finalize_process(self) -> None:
"""Best-effort cleanup used by weakref.finalize/atexit."""
try:
self.stop_server()
self.server_process.wait() # Ensure process is fully cleaned up
except Exception:
pass
def _adopt_existing_server(self, *args, **kwargs) -> None:
# Removed: cross-process adoption no longer supported
return
def _launch_server_process_colab(self, command: list, port: int) -> None:
"""Launch the server process with Colab-specific settings."""
logger.info(f"Colab Command: {' '.join(command)}")
# In Colab, we need to be more careful about process management
self.server_process = subprocess.Popen(
command,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
self.server_port = port
logger.info(f"Colab server process started with PID: {self.server_process.pid}")
# Register atexit callback (unified)
if not self._atexit_registered:
atexit.register(self._finalize_process)
self._atexit_registered = True
# Record config for in-process reuse is best-effort in Colab mode
self._server_config = {
"model_name": "",
"passages_file": "",
"embedding_mode": "sentence-transformers",
}
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
"""Wait for the server to be ready with Colab-specific timeout."""
max_wait, wait_interval = 30, 0.5 # Shorter timeout for Colab
for _ in range(int(max_wait / wait_interval)):
if _check_port(port):
logger.info("Colab embedding server is ready!")
return True, port
if self.server_process and self.server_process.poll() is not None:
# Check for error output
stdout, stderr = self.server_process.communicate()
logger.error("Colab server terminated during startup.")
logger.error(f"stdout: {stdout}")
logger.error(f"stderr: {stderr}")
return False, port
time.sleep(wait_interval)
logger.error(f"Colab server failed to start within {max_wait} seconds.")
self.stop_server()
return False, port
self.server_process = None

View File

@@ -1,14 +1,15 @@
from abc import ABC, abstractmethod
from typing import Any, Literal, Optional
import numpy as np
from typing import Dict, Any, List, Literal, Optional
class LeannBackendBuilderInterface(ABC):
"""Backend interface for building indexes"""
@abstractmethod
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs) -> None:
def build(
self, data: np.ndarray, ids: List[str], index_path: str, **kwargs
) -> None:
"""Build index
Args:
@@ -52,7 +53,7 @@ class LeannBackendSearcherInterface(ABC):
pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: Optional[int] = None,
**kwargs,
) -> dict[str, Any]:
) -> Dict[str, Any]:
"""Search for nearest neighbors
Args:

View File

@@ -1,154 +0,0 @@
#!/usr/bin/env python3
import json
import subprocess
import sys
def handle_request(request):
if request.get("method") == "initialize":
return {
"jsonrpc": "2.0",
"id": request.get("id"),
"result": {
"capabilities": {"tools": {}},
"protocolVersion": "2024-11-05",
"serverInfo": {"name": "leann-mcp", "version": "1.0.0"},
},
}
elif request.get("method") == "tools/list":
return {
"jsonrpc": "2.0",
"id": request.get("id"),
"result": {
"tools": [
{
"name": "leann_search",
"description": """🔍 Search code using natural language - like having a coding assistant who knows your entire codebase!
🎯 **Perfect for**:
- "How does authentication work?" → finds auth-related code
- "Error handling patterns" → locates try-catch blocks and error logic
- "Database connection setup" → finds DB initialization code
- "API endpoint definitions" → locates route handlers
- "Configuration management" → finds config files and usage
💡 **Pro tip**: Use this before making any changes to understand existing patterns and conventions.""",
"inputSchema": {
"type": "object",
"properties": {
"index_name": {
"type": "string",
"description": "Name of the LEANN index to search. Use 'leann_list' first to see available indexes.",
},
"query": {
"type": "string",
"description": "Search query - can be natural language (e.g., 'how to handle errors') or technical terms (e.g., 'async function definition')",
},
"top_k": {
"type": "integer",
"default": 5,
"minimum": 1,
"maximum": 20,
"description": "Number of search results to return. Use 5-10 for focused results, 15-20 for comprehensive exploration.",
},
"complexity": {
"type": "integer",
"default": 32,
"minimum": 16,
"maximum": 128,
"description": "Search complexity level. Use 16-32 for fast searches (recommended), 64+ for higher precision when needed.",
},
},
"required": ["index_name", "query"],
},
},
{
"name": "leann_list",
"description": "📋 Show all your indexed codebases - your personal code library! Use this to see what's available for search.",
"inputSchema": {"type": "object", "properties": {}},
},
]
},
}
elif request.get("method") == "tools/call":
tool_name = request["params"]["name"]
args = request["params"].get("arguments", {})
try:
if tool_name == "leann_search":
# Validate required parameters
if not args.get("index_name") or not args.get("query"):
return {
"jsonrpc": "2.0",
"id": request.get("id"),
"result": {
"content": [
{
"type": "text",
"text": "Error: Both index_name and query are required",
}
]
},
}
# Build simplified command with non-interactive flag for MCP compatibility
cmd = [
"leann",
"search",
args["index_name"],
args["query"],
f"--top-k={args.get('top_k', 5)}",
f"--complexity={args.get('complexity', 32)}",
"--non-interactive",
]
result = subprocess.run(cmd, capture_output=True, text=True)
elif tool_name == "leann_list":
result = subprocess.run(["leann", "list"], capture_output=True, text=True)
return {
"jsonrpc": "2.0",
"id": request.get("id"),
"result": {
"content": [
{
"type": "text",
"text": result.stdout
if result.returncode == 0
else f"Error: {result.stderr}",
}
]
},
}
except Exception as e:
return {
"jsonrpc": "2.0",
"id": request.get("id"),
"error": {"code": -1, "message": str(e)},
}
def main():
for line in sys.stdin:
try:
request = json.loads(line.strip())
response = handle_request(request)
if response:
print(json.dumps(response))
sys.stdout.flush()
except Exception as e:
error_response = {
"jsonrpc": "2.0",
"id": None,
"error": {"code": -1, "message": str(e)},
}
print(json.dumps(error_response))
sys.stdout.flush()
if __name__ == "__main__":
main()

View File

@@ -1,240 +0,0 @@
"""
Metadata filtering engine for LEANN search results.
This module provides generic metadata filtering capabilities that can be applied
to search results from any LEANN backend. The filtering supports various
operators for different data types including numbers, strings, booleans, and lists.
"""
import logging
from typing import Any, Union
logger = logging.getLogger(__name__)
# Type alias for filter specifications
FilterValue = Union[str, int, float, bool, list]
FilterSpec = dict[str, FilterValue]
MetadataFilters = dict[str, FilterSpec]
class MetadataFilterEngine:
"""
Engine for evaluating metadata filters against search results.
Supports various operators for filtering based on metadata fields:
- Comparison: ==, !=, <, <=, >, >=
- Membership: in, not_in
- String operations: contains, starts_with, ends_with
- Boolean operations: is_true, is_false
"""
def __init__(self):
"""Initialize the filter engine with supported operators."""
self.operators = {
"==": self._equals,
"!=": self._not_equals,
"<": self._less_than,
"<=": self._less_than_or_equal,
">": self._greater_than,
">=": self._greater_than_or_equal,
"in": self._in,
"not_in": self._not_in,
"contains": self._contains,
"starts_with": self._starts_with,
"ends_with": self._ends_with,
"is_true": self._is_true,
"is_false": self._is_false,
}
def apply_filters(
self, search_results: list[dict[str, Any]], metadata_filters: MetadataFilters
) -> list[dict[str, Any]]:
"""
Apply metadata filters to a list of search results.
Args:
search_results: List of result dictionaries, each containing 'metadata' field
metadata_filters: Dictionary of filter specifications
Format: {"field_name": {"operator": value}}
Returns:
Filtered list of search results
"""
if not metadata_filters:
return search_results
logger.debug(f"Applying filters: {metadata_filters}")
logger.debug(f"Input results count: {len(search_results)}")
filtered_results = []
for result in search_results:
if self._evaluate_filters(result, metadata_filters):
filtered_results.append(result)
logger.debug(f"Filtered results count: {len(filtered_results)}")
return filtered_results
def _evaluate_filters(self, result: dict[str, Any], filters: MetadataFilters) -> bool:
"""
Evaluate all filters against a single search result.
All filters must pass (AND logic) for the result to be included.
Args:
result: Full search result dictionary (including metadata, text, etc.)
filters: Filter specifications to evaluate
Returns:
True if all filters pass, False otherwise
"""
for field_name, filter_spec in filters.items():
if not self._evaluate_field_filter(result, field_name, filter_spec):
return False
return True
def _evaluate_field_filter(
self, result: dict[str, Any], field_name: str, filter_spec: FilterSpec
) -> bool:
"""
Evaluate a single field filter against a search result.
Args:
result: Full search result dictionary
field_name: Name of the field to filter on
filter_spec: Filter specification for this field
Returns:
True if the filter passes, False otherwise
"""
# First check top-level fields, then check metadata
field_value = result.get(field_name)
if field_value is None:
# Try to get from metadata if not found at top level
metadata = result.get("metadata", {})
field_value = metadata.get(field_name)
# Handle missing fields - they fail all filters except existence checks
if field_value is None:
logger.debug(f"Field '{field_name}' not found in result or metadata")
return False
# Evaluate each operator in the filter spec
for operator, expected_value in filter_spec.items():
if operator not in self.operators:
logger.warning(f"Unsupported operator: {operator}")
return False
try:
if not self.operators[operator](field_value, expected_value):
logger.debug(
f"Filter failed: {field_name} {operator} {expected_value} "
f"(actual: {field_value})"
)
return False
except Exception as e:
logger.warning(
f"Error evaluating filter {field_name} {operator} {expected_value}: {e}"
)
return False
return True
# Comparison operators
def _equals(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value equals expected value."""
return field_value == expected_value
def _not_equals(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value does not equal expected value."""
return field_value != expected_value
def _less_than(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is less than expected value."""
return self._numeric_compare(field_value, expected_value, lambda a, b: a < b)
def _less_than_or_equal(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is less than or equal to expected value."""
return self._numeric_compare(field_value, expected_value, lambda a, b: a <= b)
def _greater_than(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is greater than expected value."""
return self._numeric_compare(field_value, expected_value, lambda a, b: a > b)
def _greater_than_or_equal(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is greater than or equal to expected value."""
return self._numeric_compare(field_value, expected_value, lambda a, b: a >= b)
# Membership operators
def _in(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is in the expected list/collection."""
if not isinstance(expected_value, (list, tuple, set)):
raise ValueError("'in' operator requires a list, tuple, or set")
return field_value in expected_value
def _not_in(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is not in the expected list/collection."""
if not isinstance(expected_value, (list, tuple, set)):
raise ValueError("'not_in' operator requires a list, tuple, or set")
return field_value not in expected_value
# String operators
def _contains(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value contains the expected substring."""
field_str = str(field_value)
expected_str = str(expected_value)
return expected_str in field_str
def _starts_with(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value starts with the expected prefix."""
field_str = str(field_value)
expected_str = str(expected_value)
return field_str.startswith(expected_str)
def _ends_with(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value ends with the expected suffix."""
field_str = str(field_value)
expected_str = str(expected_value)
return field_str.endswith(expected_str)
# Boolean operators
def _is_true(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is truthy."""
return bool(field_value)
def _is_false(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is falsy."""
return not bool(field_value)
# Helper methods
def _numeric_compare(self, field_value: Any, expected_value: Any, compare_func) -> bool:
"""
Helper for numeric comparisons with type coercion.
Args:
field_value: Value from metadata
expected_value: Value to compare against
compare_func: Comparison function to apply
Returns:
Result of comparison
"""
try:
# Try to convert both values to numbers for comparison
if isinstance(field_value, str) and isinstance(expected_value, str):
# String comparison if both are strings
return compare_func(field_value, expected_value)
# Numeric comparison - attempt to convert to float
field_num = (
float(field_value) if not isinstance(field_value, (int, float)) else field_value
)
expected_num = (
float(expected_value)
if not isinstance(expected_value, (int, float))
else expected_value
)
return compare_func(field_num, expected_num)
except (ValueError, TypeError):
# Fall back to string comparison if numeric conversion fails
return compare_func(str(field_value), str(expected_value))

View File

@@ -1,26 +1,20 @@
# packages/leann-core/src/leann/registry.py
from typing import Dict, TYPE_CHECKING
import importlib
import importlib.metadata
import json
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Union
if TYPE_CHECKING:
from leann.interface import LeannBackendFactoryInterface
# Set up logger for this module
logger = logging.getLogger(__name__)
BACKEND_REGISTRY: dict[str, "LeannBackendFactoryInterface"] = {}
BACKEND_REGISTRY: Dict[str, "LeannBackendFactoryInterface"] = {}
def register_backend(name: str):
"""A decorator to register a new backend class."""
def decorator(cls):
logger.debug(f"Registering backend '{name}'")
print(f"INFO: Registering backend '{name}'")
BACKEND_REGISTRY[name] = cls
return cls
@@ -37,62 +31,13 @@ def autodiscover_backends():
backend_module_name = dist_name.replace("-", "_")
discovered_backends.append(backend_module_name)
for backend_module_name in sorted(discovered_backends): # sort for deterministic loading
for backend_module_name in sorted(
discovered_backends
): # sort for deterministic loading
try:
importlib.import_module(backend_module_name)
# Registration message is printed by the decorator
except ImportError:
except ImportError as e:
# print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
pass
# print("INFO: Backend auto-discovery finished.")
def register_project_directory(project_dir: Optional[Union[str, Path]] = None):
"""
Register a project directory in the global LEANN registry.
This allows `leann list` to discover indexes created by apps or other tools.
Args:
project_dir: Directory to register. If None, uses current working directory.
"""
if project_dir is None:
project_dir = Path.cwd()
else:
project_dir = Path(project_dir)
# Only register directories that have some kind of LEANN content
# Either .leann/indexes/ (CLI format) or *.leann.meta.json files (apps format)
has_cli_indexes = (project_dir / ".leann" / "indexes").exists()
has_app_indexes = any(project_dir.rglob("*.leann.meta.json"))
if not (has_cli_indexes or has_app_indexes):
# Don't register if there are no LEANN indexes
return
global_registry = Path.home() / ".leann" / "projects.json"
global_registry.parent.mkdir(exist_ok=True)
project_str = str(project_dir.resolve())
# Load existing registry
projects = []
if global_registry.exists():
try:
with open(global_registry) as f:
projects = json.load(f)
except Exception:
logger.debug("Could not load existing project registry")
projects = []
# Add project if not already present
if project_str not in projects:
projects.append(project_str)
# Save updated registry
try:
with open(global_registry, "w") as f:
json.dump(projects, f, indent=2)
logger.debug(f"Registered project directory: {project_str}")
except Exception as e:
logger.warning(f"Could not save project registry: {e}")

Some files were not shown because too many files have changed in this diff Show More