Compare commits

..

42 Commits

Author SHA1 Message Date
Andy Lee
80330f8d97 fix: remove whitespace from blank line
🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-08-11 00:04:32 +00:00
Andy Lee
4772a5bb18 feat: add process group management to prevent hanging subprocesses
- Add start_new_session=True to subprocess.Popen for better isolation
- Use os.killpg() to terminate entire process groups instead of single processes
- Import signal module for SIGTERM/SIGKILL handling
- This ensures child processes of embedding servers are also cleaned up

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-08-10 22:11:12 +00:00
Andy Lee
3d67205670 fix: remove Chinese comments to pass ruff check
🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-08-10 08:31:21 +00:00
Andy Lee
4de709ad4b feat: add ZMQ timeout configurations to prevent hanging
- Add RCVTIMEO (300s) to prevent recv operations from hanging indefinitely
- Add SNDTIMEO (300s) to prevent send operations from hanging indefinitely
- Add IMMEDIATE mode to avoid message queue blocking
- Applied to both api.py and searcher_base.py ZMQ socket connections

This ensures ZMQ operations timeout gracefully instead of hanging the process
when embedding servers become unresponsive.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-08-10 08:30:02 +00:00
Andy Lee
48c82ee3e3 fix: remove strict parameter from zip() for Python 3.9 compatibility
The strict parameter for zip() was added in Python 3.10.
Remove it to support Python 3.9.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-08-10 00:58:42 +00:00
Andy Lee
6d1ac4a503 fix: use Python 3.9 compatible builtin generics
- Convert List[str] to list[str], Dict[str, Any] to dict[str, Any], etc.
- Use ruff --unsafe-fixes to automatically apply all type annotation updates
- Remove deprecated typing imports (List, Dict, Tuple) where no longer needed
- Keep Optional[str] syntax (union operator | not supported in Python 3.9)

Now all type annotations are Python 3.9 compatible with modern builtin generics.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-08-10 00:38:33 +00:00
Andy Lee
ffba435252 fix: Python 3.9 compatibility - replace union types and builtin generics
- Replace 'str | None' with 'Optional[str]'
- Replace 'list[str]' with 'List[str]'
- Replace 'dict[' with 'Dict['
- Replace 'tuple[' with 'Tuple['
- Add missing typing imports (List, Dict, Tuple)

Fixes TypeError: unsupported operand type(s) for |: 'type' and 'NoneType'

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-08-10 00:29:46 +00:00
Andy Lee
728fa42ad5 style: run ruff format on modified files
- Format diskann_backend.py and conftest.py according to ruff standards

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-08-10 00:11:16 +00:00
Andy Lee
bce8aca3fa fix: ensure newline at end of conftest.py for ruff compliance 2025-08-09 23:56:18 +00:00
Andy Lee
f4e41e4353 style: fix ruff formatting issues in conftest.py
- Fix import sorting and organization
- Remove trailing whitespace
- Add proper newline at end of file

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-08-09 23:53:31 +00:00
Andy Lee
75c7b047d7 Merge branch 'main' into fix/clean-hang-solution 2025-08-09 16:49:51 -07:00
Andy Lee
490329dc66 fix: clean and simple hang prevention solution
This commit provides a minimal, focused fix for CI hanging issues by addressing the root causes:

**Key Changes:**

1. **ZMQ Resource Management:**
   - Remove `context.term()` calls that were causing hangs
   - Add `socket.setsockopt(zmq.LINGER, 0)` to prevent blocking on close
   - Keep socket operations simple with default timeouts (no artificial limits)

2. **Process Cleanup:**
   - Add timeout (1s) to final `process.wait()` in embedding server manager
   - Prevent infinite waiting that was causing CI hangs

3. **Resource Cleanup Methods:**
   - Add simple `cleanup()` methods to searchers and API classes
   - Focus on C++ object destruction for DiskANN backend
   - Avoid complex cleanup logic that could introduce new issues

4. **Basic Test Safety:**
   - Simple pytest-timeout configuration (300s)
   - Basic test session cleanup using psutil
   - Minimal conftest.py without complex logic

**Philosophy:**
This solution avoids the complex multi-layered fixes from the previous PR chain.
Instead, it targets the specific root causes:
- ZMQ context termination blocking
- Process wait() without timeout
- C++ resource leaks in backends

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-08-09 23:45:18 +00:00
Andy Lee
575b354976 style: organize imports per ruff; finish py39 Optional changes
- Fix import ordering in embedding servers and graph_partition_simple
- Remove duplicate Optional import
- Complete Optional[...] replacements
2025-08-07 15:06:25 -07:00
Andy Lee
65bbff1d93 fix(py39): replace union type syntax in chat.py
- validate_model_and_suggest: str | None -> Optional[str]
- OpenAIChat.__init__: api_key: str | None -> Optional[str]
- get_llm: dict[str, Any] | None -> Optional[dict[str, Any]]

Ensures Python 3.9 compatibility for CI macOS 3.9.
2025-08-07 15:01:09 -07:00
Andy Lee
df798d350d ci(macOS): set MACOSX_DEPLOYMENT_TARGET back to 13.3
- Fix build failure: 'sgesdd_' only available on macOS 13.3+
- Keep other CI improvements (local builds, find-links installs)
2025-08-07 14:38:32 -07:00
Andy Lee
3fa6b2aa17 ci: allow resolving third-party deps from index; still prefer local wheels for our packages
- Remove --no-index so numpy/scipy/etc can be resolved on Python 3.13
- Keep --find-links to force our packages from local dist

Fixes: dependency resolution failure on Ubuntu Python 3.13 (numpy missing)
2025-08-07 13:29:30 -07:00
Andy Lee
ba95554fe7 ci: build all packages on all platforms; install from local wheels only
- Build leann-core and leann on macOS too
- Install all packages via --find-links and --no-index across platforms
- Lower macOS MACOSX_DEPLOYMENT_TARGET to 12.0 for wider compatibility

This ensures consistency and avoids PyPI drift while improving macOS compatibility.
2025-08-07 13:00:11 -07:00
Andy Lee
677eb0bae3 fix: Python 3.9 compatibility - replace Union type syntax
- Replace 'int | None' with 'Optional[int]' everywhere
- Replace 'subprocess.Popen | None' with 'Optional[subprocess.Popen]'
- Add Optional import to all affected files
- Update ruff target-version from py310 to py39
- The '|' syntax for Union types was introduced in Python 3.10 (PEP 604)

Fixes TypeError: unsupported operand type(s) for |: 'type' and 'NoneType'
2025-08-07 12:54:16 -07:00
Andy Lee
9cdfcec331 fix: resolve dependency issues in CI package installation
- Ubuntu: Install all packages from local builds with --no-index
- macOS: Install core packages from PyPI, backends from local builds
- Remove --no-index for macOS backend installation to allow dependency resolution
- Pin versions when installing from PyPI to ensure consistency

Fixes error: 'leann-core was not found in the provided package locations'
2025-08-07 12:20:42 -07:00
Andy Lee
f30d1a2530 fix: ensure venv uses correct Python version from matrix
- Explicitly specify Python version when creating venv with uv
- Prevents mismatch between build Python (e.g., 3.10) and test Python
- Fixes: _diskannpy.cpython-310-x86_64-linux-gnu.so in Python 3.11 error

The issue: uv venv was defaulting to Python 3.11 regardless of matrix version
2025-08-07 12:01:11 -07:00
Andy Lee
df69a49123 fix: ensure CI installs correct Python version wheel packages
- Use --find-links with --no-index to let uv select correct wheel
- Prevents installing wrong Python version wheel (e.g., cp310 for Python 3.11)
- Fixes ImportError: _diskannpy.cpython-310-x86_64-linux-gnu.so in Python 3.11

The issue was that *.whl glob matched all Python versions, causing
uv to potentially install a cp310 wheel in a Python 3.11 environment.
2025-08-07 11:31:25 -07:00
Andy Lee
65b54ff905 fix: remove invalid --plat argument from auditwheel repair
- Remove '--plat linux_x86_64' which is not a valid platform tag
- Let auditwheel automatically determine the correct platform
- Based on CI output, it will use manylinux_2_35_x86_64

This was causing auditwheel repair to fail, preventing proper wheel repair
2025-08-07 11:04:34 -07:00
Andy Lee
4db3e94f35 debug: add more CI diagnostics for DiskANN module import issue
- Check wheel contents before and after auditwheel repair
- Verify _diskannpy module installation after pip install
- List installed package directory structure
- Add explicit platform tag for auditwheel repair

This helps diagnose why ImportError: cannot import name '_diskannpy' occurs
2025-08-07 10:55:09 -07:00
Andy Lee
a2568f3ddc fix: force install local wheels in CI to prevent PyPI version conflicts
- Change from --find-links to direct wheel installation with --force-reinstall
- This ensures CI uses locally built packages with latest source code
- Prevents uv from using PyPI packages with same version number but old code
- Fixes CI test failures where old code (without metadata_file_path) was used

Root cause: CI was installing leann-backend-diskann v0.2.1 from PyPI
instead of the locally built wheel with same version number.
2025-08-07 00:36:07 -07:00
Andy Lee
45bdad4fa7 debug: add detailed logging for CI path resolution debugging
- Add logging in DiskANN embedding server to show metadata_file_path
- Add debug logging in PassageManager to trace path resolution
- This will help identify why CI fails to find passage files
2025-08-07 00:00:12 -07:00
Andy Lee
8b538d1ef9 fix: use uv tool install for ruff instead of uv pip install
- uv tool install is the correct way to install CLI tools like ruff
- uv pip install --system is for Python packages, not tools
2025-08-06 22:57:18 -07:00
Andy Lee
ada8bcbc70 fix: pin ruff version to 0.12.7 across all environments
- Pin ruff==0.12.7 in pyproject.toml dev dependencies
- Update CI to use exact ruff version instead of latest
- Add comments explaining version pinning rationale
- Ensures consistent formatting across local, CI, and pre-commit
2025-08-06 22:56:32 -07:00
Andy Lee
6061e8f2de fix: format test files with latest ruff version for CI compatibility 2025-08-06 22:53:40 -07:00
Andy Lee
9842ad8330 fix: update pre-commit ruff version and format compliance 2025-08-06 22:33:15 -07:00
Andy Lee
7d920f9071 docs: add ldg-times parameter for diskann graph locality optimization 2025-08-06 22:23:02 -07:00
Andy Lee
f28f15000c docs: highlight diskann readiness and add performance comparison 2025-08-06 22:10:56 -07:00
Andy Lee
1d657fd9f6 tests: diskann and partition 2025-08-06 21:59:51 -07:00
Andy Lee
d217adbe40 fix: diskann building and partitioning 2025-08-06 21:32:03 -07:00
Andy Lee
f790ec634f chore: more data 2025-08-06 21:28:14 -07:00
Andy Lee
b8da9d7b12 docs: tool cli install 2025-08-06 21:28:05 -07:00
Andy Lee
0cb0463929 fix: always use relative path in metadata 2025-08-06 21:27:43 -07:00
yichuan520030910320
b982241249 add a path related fix 2025-08-05 23:35:48 -07:00
yichuan520030910320
c66f197e1d ruff 2025-08-05 23:24:55 -07:00
yichuan520030910320
4a1353761a merge 2025-08-05 23:23:07 -07:00
yichuan520030910320
a72090d2ab merge 2025-08-05 23:22:48 -07:00
yichuan520030910320
669e622430 chore: Update DiskANN submodule to latest with graph partition tools
- Update DiskANN submodule to commit b2dc4ea
- Includes graph partition tools and CMake integration
- Enables graph partitioning functionality in DiskANN backend
2025-08-05 23:14:19 -07:00
yichuan520030910320
77d7b60a61 feat: Add graph partition support for DiskANN backend
- Add GraphPartitioner class for advanced graph partitioning
- Add partition_graph_simple function for easy-to-use partitioning
- Add pybind11 dependency for C++ executable building
- Update __init__.py to export partition functions
- Include test scripts for partition functionality

The partition functionality allows optimizing disk-based indices
for better search performance and memory efficiency.
2025-08-05 23:11:09 -07:00
36 changed files with 4457 additions and 5731 deletions

View File

@@ -5,7 +5,6 @@ on:
branches: [ main ] branches: [ main ]
pull_request: pull_request:
branches: [ main ] branches: [ main ]
workflow_dispatch:
jobs: jobs:
build: build:

View File

@@ -28,7 +28,7 @@ jobs:
- name: Install ruff - name: Install ruff
run: | run: |
uv tool install ruff uv tool install ruff==0.12.7
- name: Run ruff check - name: Run ruff check
run: | run: |
@@ -54,36 +54,16 @@ jobs:
python: '3.12' python: '3.12'
- os: ubuntu-22.04 - os: ubuntu-22.04
python: '3.13' python: '3.13'
- os: macos-14 - os: macos-latest
python: '3.9' python: '3.9'
- os: macos-14 - os: macos-latest
python: '3.10' python: '3.10'
- os: macos-14 - os: macos-latest
python: '3.11' python: '3.11'
- os: macos-14 - os: macos-latest
python: '3.12' python: '3.12'
- os: macos-14 - os: macos-latest
python: '3.13' python: '3.13'
- os: macos-15
python: '3.9'
- os: macos-15
python: '3.10'
- os: macos-15
python: '3.11'
- os: macos-15
python: '3.12'
- os: macos-15
python: '3.13'
- os: macos-13
python: '3.9'
- os: macos-13
python: '3.10'
- os: macos-13
python: '3.11'
- os: macos-13
python: '3.12'
# Note: macos-13 + Python 3.13 excluded due to PyTorch compatibility
# (PyTorch 2.5+ supports Python 3.13 but not Intel Mac x86_64)
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
steps: steps:
@@ -129,70 +109,41 @@ jobs:
uv pip install --system delocate uv pip install --system delocate
fi fi
- name: Set macOS environment variables
if: runner.os == 'macOS'
run: |
# Use brew --prefix to automatically detect Homebrew installation path
HOMEBREW_PREFIX=$(brew --prefix)
echo "HOMEBREW_PREFIX=${HOMEBREW_PREFIX}" >> $GITHUB_ENV
echo "OpenMP_ROOT=${HOMEBREW_PREFIX}/opt/libomp" >> $GITHUB_ENV
# Set CMAKE_PREFIX_PATH to let CMake find all packages automatically
echo "CMAKE_PREFIX_PATH=${HOMEBREW_PREFIX}" >> $GITHUB_ENV
# Set compiler flags for OpenMP (required for both backends)
echo "LDFLAGS=-L${HOMEBREW_PREFIX}/opt/libomp/lib" >> $GITHUB_ENV
echo "CPPFLAGS=-I${HOMEBREW_PREFIX}/opt/libomp/include" >> $GITHUB_ENV
- name: Build packages - name: Build packages
run: | run: |
# Build core (platform independent) # Build core (platform independent) on all platforms for consistency
cd packages/leann-core cd packages/leann-core
uv build uv build
cd ../.. cd ../..
# Build HNSW backend # Build HNSW backend
cd packages/leann-backend-hnsw cd packages/leann-backend-hnsw
if [[ "${{ matrix.os }}" == macos-* ]]; then if [ "${{ matrix.os }}" == "macos-latest" ]; then
# Use system clang for better compatibility # Use system clang instead of homebrew LLVM for better compatibility
export CC=clang export CC=clang
export CXX=clang++ export CXX=clang++
# Homebrew libraries on each macOS version require matching minimum version export MACOSX_DEPLOYMENT_TARGET=11.0
if [[ "${{ matrix.os }}" == "macos-13" ]]; then uv build --wheel --python python
export MACOSX_DEPLOYMENT_TARGET=13.0
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
export MACOSX_DEPLOYMENT_TARGET=14.0
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
export MACOSX_DEPLOYMENT_TARGET=15.0
fi
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
else else
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist uv build --wheel --python python
fi fi
cd ../.. cd ../..
# Build DiskANN backend # Build DiskANN backend
cd packages/leann-backend-diskann cd packages/leann-backend-diskann
if [[ "${{ matrix.os }}" == macos-* ]]; then if [ "${{ matrix.os }}" == "macos-latest" ]; then
# Use system clang for better compatibility # Use system clang instead of homebrew LLVM for better compatibility
export CC=clang export CC=clang
export CXX=clang++ export CXX=clang++
# DiskANN requires macOS 13.3+ for sgesdd_ LAPACK function # sgesdd_ is only available on macOS 13.3+
# But Homebrew libraries on each macOS version require matching minimum version export MACOSX_DEPLOYMENT_TARGET=13.3
if [[ "${{ matrix.os }}" == "macos-13" ]]; then uv build --wheel --python python
export MACOSX_DEPLOYMENT_TARGET=13.3
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
export MACOSX_DEPLOYMENT_TARGET=14.0
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
export MACOSX_DEPLOYMENT_TARGET=15.0
fi
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
else else
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist uv build --wheel --python python
fi fi
cd ../.. cd ../..
# Build meta package (platform independent) # Build meta package (platform independent) on all platforms
cd packages/leann cd packages/leann
uv build uv build
cd ../.. cd ../..
@@ -209,10 +160,15 @@ jobs:
fi fi
cd ../.. cd ../..
# Repair DiskANN wheel # Repair DiskANN wheel - use show first to debug
cd packages/leann-backend-diskann cd packages/leann-backend-diskann
if [ -d dist ]; then if [ -d dist ]; then
echo "Checking DiskANN wheel contents before repair:"
unzip -l dist/*.whl | grep -E "\.so|\.pyd|_diskannpy" || echo "No .so files found"
auditwheel show dist/*.whl || echo "auditwheel show failed"
auditwheel repair dist/*.whl -w dist_repaired auditwheel repair dist/*.whl -w dist_repaired
echo "Checking DiskANN wheel contents after repair:"
unzip -l dist_repaired/*.whl | grep -E "\.so|\.pyd|_diskannpy" || echo "No .so files found after repair"
rm -rf dist rm -rf dist
mv dist_repaired dist mv dist_repaired dist
fi fi
@@ -221,24 +177,10 @@ jobs:
- name: Repair wheels (macOS) - name: Repair wheels (macOS)
if: runner.os == 'macOS' if: runner.os == 'macOS'
run: | run: |
# Determine deployment target based on runner OS
# Must match the Homebrew libraries for each macOS version
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
HNSW_TARGET="13.0"
DISKANN_TARGET="13.3"
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
HNSW_TARGET="14.0"
DISKANN_TARGET="14.0"
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
HNSW_TARGET="15.0"
DISKANN_TARGET="15.0"
fi
# Repair HNSW wheel # Repair HNSW wheel
cd packages/leann-backend-hnsw cd packages/leann-backend-hnsw
if [ -d dist ]; then if [ -d dist ]; then
export MACOSX_DEPLOYMENT_TARGET=$HNSW_TARGET delocate-wheel -w dist_repaired -v dist/*.whl
delocate-wheel -w dist_repaired -v --require-target-macos-version $HNSW_TARGET dist/*.whl
rm -rf dist rm -rf dist
mv dist_repaired dist mv dist_repaired dist
fi fi
@@ -247,8 +189,7 @@ jobs:
# Repair DiskANN wheel # Repair DiskANN wheel
cd packages/leann-backend-diskann cd packages/leann-backend-diskann
if [ -d dist ]; then if [ -d dist ]; then
export MACOSX_DEPLOYMENT_TARGET=$DISKANN_TARGET delocate-wheel -w dist_repaired -v dist/*.whl
delocate-wheel -w dist_repaired -v --require-target-macos-version $DISKANN_TARGET dist/*.whl
rm -rf dist rm -rf dist
mv dist_repaired dist mv dist_repaired dist
fi fi
@@ -259,34 +200,44 @@ jobs:
echo "📦 Built packages:" echo "📦 Built packages:"
find packages/*/dist -name "*.whl" -o -name "*.tar.gz" | sort find packages/*/dist -name "*.whl" -o -name "*.tar.gz" | sort
- name: Install built packages for testing - name: Install built packages for testing
run: | run: |
# Create a virtual environment with the correct Python version # Create a virtual environment with the correct Python version
uv venv --python ${{ matrix.python }} uv venv --python python${{ matrix.python }}
source .venv/bin/activate || source .venv/Scripts/activate source .venv/bin/activate || source .venv/Scripts/activate
# Install packages using --find-links to prioritize local builds # Install the built wheels directly to ensure we use locally built packages
uv pip install --find-links packages/leann-core/dist --find-links packages/leann-backend-hnsw/dist --find-links packages/leann-backend-diskann/dist packages/leann-core/dist/*.whl || uv pip install --find-links packages/leann-core/dist packages/leann-core/dist/*.tar.gz # Use only locally built wheels on all platforms for full consistency
uv pip install --find-links packages/leann-core/dist packages/leann-backend-hnsw/dist/*.whl FIND_LINKS="--find-links packages/leann-core/dist --find-links packages/leann/dist"
uv pip install --find-links packages/leann-core/dist packages/leann-backend-diskann/dist/*.whl FIND_LINKS="$FIND_LINKS --find-links packages/leann-backend-hnsw/dist --find-links packages/leann-backend-diskann/dist"
uv pip install packages/leann/dist/*.whl || uv pip install packages/leann/dist/*.tar.gz
uv pip install leann-core leann leann-backend-hnsw leann-backend-diskann \
$FIND_LINKS --force-reinstall
# Install test dependencies using extras # Install test dependencies using extras
uv pip install -e ".[test]" uv pip install -e ".[test]"
# Debug: Check if _diskannpy module is installed correctly
echo "Checking installed DiskANN module structure:"
python -c "import leann_backend_diskann; print('leann_backend_diskann location:', leann_backend_diskann.__file__)" || echo "Failed to import leann_backend_diskann"
python -c "from leann_backend_diskann import _diskannpy; print('_diskannpy imported successfully')" || echo "Failed to import _diskannpy"
ls -la $(python -c "import leann_backend_diskann; import os; print(os.path.dirname(leann_backend_diskann.__file__))" 2>/dev/null) 2>/dev/null || echo "Failed to list module directory"
- name: Run tests with pytest - name: Run tests with pytest
env: env:
CI: true CI: true # Mark as CI environment to skip memory-intensive tests
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
HF_HUB_DISABLE_SYMLINKS: 1 HF_HUB_DISABLE_SYMLINKS: 1
TOKENIZERS_PARALLELISM: false TOKENIZERS_PARALLELISM: false
PYTORCH_ENABLE_MPS_FALLBACK: 0 PYTORCH_ENABLE_MPS_FALLBACK: 0 # Disable MPS on macOS CI to avoid memory issues
OMP_NUM_THREADS: 1 OMP_NUM_THREADS: 1 # Disable OpenMP parallelism to avoid libomp crashes
MKL_NUM_THREADS: 1 MKL_NUM_THREADS: 1 # Single thread for MKL operations
run: | run: |
# Activate virtual environment
source .venv/bin/activate || source .venv/Scripts/activate source .venv/bin/activate || source .venv/Scripts/activate
pytest tests/ -v --tb=short
# Run all tests
pytest tests/
- name: Run sanity checks (optional) - name: Run sanity checks (optional)
run: | run: |

View File

@@ -3,11 +3,10 @@
</p> </p>
<p align="center"> <p align="center">
<img src="https://img.shields.io/badge/Python-3.9%20%7C%203.10%20%7C%203.11%20%7C%203.12%20%7C%203.13-blue.svg" alt="Python Versions"> <img src="https://img.shields.io/badge/Python-3.9%2B-blue.svg" alt="Python 3.9+">
<img src="https://github.com/yichuan-w/LEANN/actions/workflows/build-and-publish.yml/badge.svg" alt="CI Status">
<img src="https://img.shields.io/badge/Platform-Ubuntu%20%7C%20macOS%20(ARM64%2FIntel)-lightgrey" alt="Platform">
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License"> <img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
<img src="https://img.shields.io/badge/MCP-Native%20Integration-blue" alt="MCP Integration"> <img src="https://img.shields.io/badge/Platform-Linux%20%7C%20macOS-lightgrey" alt="Platform">
<img src="https://img.shields.io/badge/MCP-Native%20Integration-blue?style=flat-square" alt="MCP Integration">
</p> </p>
<h2 align="center" tabindex="-1" class="heading-element" dir="auto"> <h2 align="center" tabindex="-1" class="heading-element" dir="auto">
@@ -71,8 +70,6 @@ source .venv/bin/activate
uv pip install leann uv pip install leann
``` ```
> Low-resource? See “Low-resource setups” in the [Configuration Guide](docs/configuration-guide.md#low-resource-setups).
<details> <details>
<summary> <summary>
<strong>🔧 Build from Source (Recommended for development)</strong> <strong>🔧 Build from Source (Recommended for development)</strong>
@@ -100,7 +97,6 @@ uv sync
</details> </details>
## Quick Start ## Quick Start
Our declarative API makes RAG as easy as writing a config file. Our declarative API makes RAG as easy as writing a config file.
@@ -186,34 +182,34 @@ All RAG examples share these common parameters. **Interactive mode** is availabl
```bash ```bash
# Core Parameters (General preprocessing for all examples) # Core Parameters (General preprocessing for all examples)
--index-dir DIR # Directory to store the index (default: current directory) --index-dir DIR # Directory to store the index (default: current directory)
--query "YOUR QUESTION" # Single query mode. Omit for interactive chat (type 'quit' to exit), and now you can play with your index interactively --query "YOUR QUESTION" # Single query mode. Omit for interactive chat (type 'quit' to exit), and now you can play with your index interactively
--max-items N # Limit data preprocessing (default: -1, process all data) --max-items N # Limit data preprocessing (default: -1, process all data)
--force-rebuild # Force rebuild index even if it exists --force-rebuild # Force rebuild index even if it exists
# Embedding Parameters # Embedding Parameters
--embedding-model MODEL # e.g., facebook/contriever, text-embedding-3-small, mlx-community/Qwen3-Embedding-0.6B-8bit or nomic-embed-text --embedding-model MODEL # e.g., facebook/contriever, text-embedding-3-small, nomic-embed-text, or mlx-community/multilingual-e5-base-mlx
--embedding-mode MODE # sentence-transformers, openai, mlx, or ollama --embedding-mode MODE # sentence-transformers, openai, mlx, or ollama
# LLM Parameters (Text generation models) # LLM Parameters (Text generation models)
--llm TYPE # LLM backend: openai, ollama, or hf (default: openai) --llm TYPE # LLM backend: openai, ollama, or hf (default: openai)
--llm-model MODEL # Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct --llm-model MODEL # Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct
--thinking-budget LEVEL # Thinking budget for reasoning models: low/medium/high (supported by o3, o3-mini, GPT-Oss:20b, and other reasoning models) --thinking-budget LEVEL # Thinking budget for reasoning models: low/medium/high (supported by o3, o3-mini, GPT-Oss:20b, and other reasoning models)
# Search Parameters # Search Parameters
--top-k N # Number of results to retrieve (default: 20) --top-k N # Number of results to retrieve (default: 20)
--search-complexity N # Search complexity for graph traversal (default: 32) --search-complexity N # Search complexity for graph traversal (default: 32)
# Chunking Parameters # Chunking Parameters
--chunk-size N # Size of text chunks (default varies by source: 256 for most, 192 for WeChat) --chunk-size N # Size of text chunks (default varies by source: 256 for most, 192 for WeChat)
--chunk-overlap N # Overlap between chunks (default varies: 25-128 depending on source) --chunk-overlap N # Overlap between chunks (default varies: 25-128 depending on source)
# Index Building Parameters # Index Building Parameters
--backend-name NAME # Backend to use: hnsw or diskann (default: hnsw) --backend-name NAME # Backend to use: hnsw or diskann (default: hnsw)
--graph-degree N # Graph degree for index construction (default: 32) --graph-degree N # Graph degree for index construction (default: 32)
--build-complexity N # Build complexity for index construction (default: 64) --build-complexity N # Build complexity for index construction (default: 64)
--compact / --no-compact # Use compact storage (default: true). Must be `no-compact` for `no-recompute` build. --no-compact # Disable compact index storage (compact storage IS enabled to save storage by default)
--recompute / --no-recompute # Enable/disable embedding recomputation (default: enabled). Should not do a `no-recompute` search in a `recompute` build. --no-recompute # Disable embedding recomputation (recomputation IS enabled to save storage by default)
``` ```
</details> </details>
@@ -470,7 +466,7 @@ leann --help
### Usage Examples ### Usage Examples
```bash ```bash
# build from a specific directory, and my_docs is the index name(Here you can also build from multiple dict or multiple files) # build from a specific directory, and my_docs is the index name
leann build my-docs --docs ./your_documents leann build my-docs --docs ./your_documents
# Search your documents # Search your documents
@@ -484,29 +480,27 @@ leann list
``` ```
**Key CLI features:** **Key CLI features:**
- Auto-detects document formats (PDF, TXT, MD, DOCX, PPTX + code files) - Auto-detects document formats (PDF, TXT, MD, DOCX)
- Smart text chunking with overlap - Smart text chunking with overlap
- Multiple LLM providers (Ollama, OpenAI, HuggingFace) - Multiple LLM providers (Ollama, OpenAI, HuggingFace)
- Organized index storage in `.leann/indexes/` (project-local) - Organized index storage in `~/.leann/indexes/`
- Support for advanced search parameters - Support for advanced search parameters
<details> <details>
<summary><strong>📋 Click to expand: Complete CLI Reference</strong></summary> <summary><strong>📋 Click to expand: Complete CLI Reference</strong></summary>
You can use `leann --help`, or `leann build --help`, `leann search --help`, `leann ask --help` to get the complete CLI reference.
**Build Command:** **Build Command:**
```bash ```bash
leann build INDEX_NAME --docs DIRECTORY|FILE [DIRECTORY|FILE ...] [OPTIONS] leann build INDEX_NAME --docs DIRECTORY [OPTIONS]
Options: Options:
--backend {hnsw,diskann} Backend to use (default: hnsw) --backend {hnsw,diskann} Backend to use (default: hnsw)
--embedding-model MODEL Embedding model (default: facebook/contriever) --embedding-model MODEL Embedding model (default: facebook/contriever)
--graph-degree N Graph degree (default: 32) --graph-degree N Graph degree (default: 32)
--complexity N Build complexity (default: 64) --complexity N Build complexity (default: 64)
--force Force rebuild existing index --force Force rebuild existing index
--compact / --no-compact Use compact storage (default: true). Must be `no-compact` for `no-recompute` build. --compact Use compact storage (default: true)
--recompute / --no-recompute Enable recomputation (default: true) --recompute Enable recomputation (default: true)
``` ```
**Search Command:** **Search Command:**
@@ -514,9 +508,9 @@ Options:
leann search INDEX_NAME QUERY [OPTIONS] leann search INDEX_NAME QUERY [OPTIONS]
Options: Options:
--top-k N Number of results (default: 5) --top-k N Number of results (default: 5)
--complexity N Search complexity (default: 64) --complexity N Search complexity (default: 64)
--recompute / --no-recompute Enable/disable embedding recomputation (default: enabled). Should not do a `no-recompute` search in a `recompute` build. --recompute-embeddings Use recomputation for highest accuracy
--pruning-strategy {global,local,proportional} --pruning-strategy {global,local,proportional}
``` ```
@@ -615,9 +609,8 @@ We welcome more contributors! Feel free to open issues or submit PRs.
This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/). This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/).
## Star History ---
[![Star History Chart](https://api.star-history.com/svg?repos=yichuan-w/LEANN&type=Date)](https://www.star-history.com/#yichuan-w/LEANN&Date)
<p align="center"> <p align="center">
<strong>⭐ Star us on GitHub if Leann is useful for your research or applications!</strong> <strong>⭐ Star us on GitHub if Leann is useful for your research or applications!</strong>
</p> </p>

View File

@@ -69,14 +69,14 @@ class BaseRAGExample(ABC):
"--embedding-model", "--embedding-model",
type=str, type=str,
default=embedding_model_default, default=embedding_model_default,
help=f"Embedding model to use (default: {embedding_model_default}), we provide facebook/contriever, text-embedding-3-small,mlx-community/Qwen3-Embedding-0.6B-8bit or nomic-embed-text", help=f"Embedding model to use (default: {embedding_model_default})",
) )
embedding_group.add_argument( embedding_group.add_argument(
"--embedding-mode", "--embedding-mode",
type=str, type=str,
default="sentence-transformers", default="sentence-transformers",
choices=["sentence-transformers", "openai", "mlx", "ollama"], choices=["sentence-transformers", "openai", "mlx", "ollama"],
help="Embedding backend mode (default: sentence-transformers), we provide sentence-transformers, openai, mlx, or ollama", help="Embedding backend mode (default: sentence-transformers)",
) )
# LLM parameters # LLM parameters
@@ -86,13 +86,13 @@ class BaseRAGExample(ABC):
type=str, type=str,
default="openai", default="openai",
choices=["openai", "ollama", "hf", "simulated"], choices=["openai", "ollama", "hf", "simulated"],
help="LLM backend: openai, ollama, or hf (default: openai)", help="LLM backend to use (default: openai)",
) )
llm_group.add_argument( llm_group.add_argument(
"--llm-model", "--llm-model",
type=str, type=str,
default=None, default=None,
help="Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct", help="LLM model name (default: gpt-4o for openai, llama3.2:1b for ollama)",
) )
llm_group.add_argument( llm_group.add_argument(
"--llm-host", "--llm-host",
@@ -178,9 +178,6 @@ class BaseRAGExample(ABC):
config["host"] = args.llm_host config["host"] = args.llm_host
elif args.llm == "hf": elif args.llm == "hf":
config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct" config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
elif args.llm == "simulated":
# Simulated LLM doesn't need additional configuration
pass
return config return config

View File

@@ -1,148 +0,0 @@
import argparse
import os
import time
from pathlib import Path
from leann import LeannBuilder, LeannSearcher
def _meta_exists(index_path: str) -> bool:
p = Path(index_path)
return (p.parent / f"{p.stem}.meta.json").exists()
def ensure_index(index_path: str, backend_name: str, num_docs: int, is_recompute: bool) -> None:
# if _meta_exists(index_path):
# return
kwargs = {}
if backend_name == "hnsw":
kwargs["is_compact"] = is_recompute
builder = LeannBuilder(
backend_name=backend_name,
embedding_model=os.getenv("LEANN_EMBED_MODEL", "facebook/contriever"),
embedding_mode=os.getenv("LEANN_EMBED_MODE", "sentence-transformers"),
graph_degree=32,
complexity=64,
is_recompute=is_recompute,
num_threads=4,
**kwargs,
)
for i in range(num_docs):
builder.add_text(
f"This is a test document number {i}. It contains some repeated text for benchmarking."
)
builder.build_index(index_path)
def _bench_group(
index_path: str,
recompute: bool,
query: str,
repeats: int,
complexity: int = 32,
top_k: int = 10,
) -> float:
# Independent searcher per group; fixed port when recompute
searcher = LeannSearcher(index_path=index_path)
# Warm-up once
_ = searcher.search(
query,
top_k=top_k,
complexity=complexity,
recompute_embeddings=recompute,
)
def _once() -> float:
t0 = time.time()
_ = searcher.search(
query,
top_k=top_k,
complexity=complexity,
recompute_embeddings=recompute,
)
return time.time() - t0
if repeats <= 1:
t = _once()
else:
vals = [_once() for _ in range(repeats)]
vals.sort()
t = vals[len(vals) // 2]
searcher.cleanup()
return t
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--num-docs", type=int, default=5000)
parser.add_argument("--repeats", type=int, default=3)
parser.add_argument("--complexity", type=int, default=32)
args = parser.parse_args()
base = Path.cwd() / ".leann" / "indexes" / f"bench_n{args.num_docs}"
base.parent.mkdir(parents=True, exist_ok=True)
# ---------- Build HNSW variants ----------
hnsw_r = str(base / f"hnsw_recompute_n{args.num_docs}.leann")
hnsw_nr = str(base / f"hnsw_norecompute_n{args.num_docs}.leann")
ensure_index(hnsw_r, "hnsw", args.num_docs, True)
ensure_index(hnsw_nr, "hnsw", args.num_docs, False)
# ---------- Build DiskANN variants ----------
diskann_r = str(base / "diskann_r.leann")
diskann_nr = str(base / "diskann_nr.leann")
ensure_index(diskann_r, "diskann", args.num_docs, True)
ensure_index(diskann_nr, "diskann", args.num_docs, False)
# ---------- Helpers ----------
def _size_for(prefix: str) -> int:
p = Path(prefix)
base_dir = p.parent
stem = p.stem
total = 0
for f in base_dir.iterdir():
if f.is_file() and f.name.startswith(stem):
total += f.stat().st_size
return total
# ---------- HNSW benchmark ----------
t_hnsw_r = _bench_group(
hnsw_r, True, "test document number 42", repeats=args.repeats, complexity=args.complexity
)
t_hnsw_nr = _bench_group(
hnsw_nr, False, "test document number 42", repeats=args.repeats, complexity=args.complexity
)
size_hnsw_r = _size_for(hnsw_r)
size_hnsw_nr = _size_for(hnsw_nr)
print("Benchmark results (HNSW):")
print(f" recompute=True: search_time={t_hnsw_r:.3f}s, size={size_hnsw_r / 1024 / 1024:.1f}MB")
print(
f" recompute=False: search_time={t_hnsw_nr:.3f}s, size={size_hnsw_nr / 1024 / 1024:.1f}MB"
)
print(" Expectation: no-recompute should be faster but larger on disk.")
# ---------- DiskANN benchmark ----------
t_diskann_r = _bench_group(
diskann_r, True, "DiskANN R test doc 123", repeats=args.repeats, complexity=args.complexity
)
t_diskann_nr = _bench_group(
diskann_nr,
False,
"DiskANN NR test doc 123",
repeats=args.repeats,
complexity=args.complexity,
)
size_diskann_r = _size_for(diskann_r)
size_diskann_nr = _size_for(diskann_nr)
print("\nBenchmark results (DiskANN):")
print(f" build(recompute=True, partition): size={size_diskann_r / 1024 / 1024:.1f}MB")
print(f" build(recompute=False): size={size_diskann_nr / 1024 / 1024:.1f}MB")
print(f" search recompute=True (final rerank): {t_diskann_r:.3f}s")
print(f" search recompute=False (PQ only): {t_diskann_nr:.3f}s")
if __name__ == "__main__":
main()

View File

@@ -10,7 +10,6 @@ This benchmark compares search performance between DiskANN and HNSW backends:
""" """
import gc import gc
import multiprocessing as mp
import tempfile import tempfile
import time import time
from pathlib import Path from pathlib import Path
@@ -18,12 +17,6 @@ from typing import Any
import numpy as np import numpy as np
# Prefer 'fork' start method to avoid POSIX semaphore leaks on macOS
try:
mp.set_start_method("fork", force=True)
except Exception:
pass
def create_test_texts(n_docs: int) -> list[str]: def create_test_texts(n_docs: int) -> list[str]:
"""Create synthetic test documents for benchmarking.""" """Create synthetic test documents for benchmarking."""
@@ -120,10 +113,10 @@ def benchmark_backend(
] ]
score_validity_rate = len(valid_scores) / len(all_scores) if all_scores else 0 score_validity_rate = len(valid_scores) / len(all_scores) if all_scores else 0
# Clean up (ensure embedding server shutdown and object GC) # Clean up
try: try:
if hasattr(searcher, "cleanup"): if hasattr(searcher, "__del__"):
searcher.cleanup() searcher.__del__()
del searcher del searcher
del builder del builder
gc.collect() gc.collect()
@@ -266,21 +259,10 @@ if __name__ == "__main__":
print(f"\n❌ Benchmark failed: {e}") print(f"\n❌ Benchmark failed: {e}")
sys.exit(1) sys.exit(1)
finally: finally:
# Ensure clean exit (forceful to prevent rare hangs from atexit/threads) # Ensure clean exit
try: try:
gc.collect() gc.collect()
print("\n🧹 Cleanup completed") print("\n🧹 Cleanup completed")
# Flush stdio to ensure message is visible before hard-exit
try:
import sys as _sys
_sys.stdout.flush()
_sys.stderr.flush()
except Exception:
pass
except Exception: except Exception:
pass pass
# Use os._exit to bypass atexit handlers that may hang in rare cases sys.exit(0)
import os as _os
_os._exit(0)

View File

@@ -52,7 +52,7 @@ Based on our experience developing LEANN, embedding models fall into three categ
### Quick Start: Cloud and Local Embedding Options ### Quick Start: Cloud and Local Embedding Options
**OpenAI Embeddings (Fastest Setup)** **OpenAI Embeddings (Fastest Setup)**
For immediate testing without local model downloads(also if you [do not have GPU](https://github.com/yichuan-w/LEANN/issues/43) and do not care that much about your document leak, you should use this, we compute the embedding and recompute using openai API): For immediate testing without local model downloads:
```bash ```bash
# Set OpenAI embeddings (requires OPENAI_API_KEY) # Set OpenAI embeddings (requires OPENAI_API_KEY)
--embedding-mode openai --embedding-model text-embedding-3-small --embedding-mode openai --embedding-model text-embedding-3-small
@@ -97,23 +97,29 @@ ollama pull nomic-embed-text
``` ```
### DiskANN ### DiskANN
**Best for**: Large datasets, especially when you want `recompute=True`. **Best for**: Performance-critical applications and large datasets - **Production-ready with automatic graph partitioning**
**Key advantages:** **How it works:**
- **Faster search** on large datasets (3x+ speedup vs HNSW in many cases) - **Product Quantization (PQ) + Real-time Reranking**: Uses compressed PQ codes for fast graph traversal, then recomputes exact embeddings for final candidates
- **Smart storage**: `recompute=True` enables automatic graph partitioning for smaller indexes - **Automatic Graph Partitioning**: When `is_recompute=True`, automatically partitions large indices and safely removes redundant files to save storage
- **Better scaling**: Designed for 100k+ documents - **Superior Speed-Accuracy Trade-off**: Faster search than HNSW while maintaining high accuracy
**Recompute behavior:** **Trade-offs compared to HNSW:**
- `recompute=True` (recommended): Pure PQ traversal + final reranking - faster and enables partitioning - **Faster search latency** (typically 2-8x speedup)
- `recompute=False`: PQ + partial real distances during traversal - slower but higher accuracy - **Better scaling** for large datasets
-**Smart storage management** with automatic partitioning
-**Better graph locality** with `--ldg-times` parameter for SSD optimization
- ⚠️ **Slightly larger index size** due to PQ tables and graph metadata
```bash ```bash
# Recommended for most use cases # Recommended for most use cases
--backend-name diskann --graph-degree 32 --build-complexity 64 --backend-name diskann --graph-degree 32 --build-complexity 64
# For large-scale deployments
--backend-name diskann --graph-degree 64 --build-complexity 128
``` ```
**Performance Benchmark**: Run `uv run benchmarks/diskann_vs_hnsw_speed_comparison.py` to compare DiskANN and HNSW on your system. **Performance Benchmark**: Run `python benchmarks/diskann_vs_hnsw_speed_comparison.py` to compare DiskANN and HNSW on your system.
## LLM Selection: Engine and Model Comparison ## LLM Selection: Engine and Model Comparison
@@ -230,15 +236,9 @@ python apps/document_rag.py --query "What are the main techniques LEANN explores
3. **Use MLX on Apple Silicon** (optional optimization): 3. **Use MLX on Apple Silicon** (optional optimization):
```bash ```bash
--embedding-mode mlx --embedding-model mlx-community/Qwen3-Embedding-0.6B-8bit --embedding-mode mlx --embedding-model mlx-community/multilingual-e5-base-mlx
``` ```
MLX might not be the best choice, as we tested and found that it only offers 1.3x acceleration compared to HF, so maybe using ollama is a better choice for embedding generation
4. **Use Ollama**
```bash
--embedding-mode ollama --embedding-model nomic-embed-text
```
To discover additional embedding models in ollama, check out https://ollama.com/search?c=embedding or read more about embedding models at https://ollama.com/blog/embedding-models, please do check the model size that works best for you
### If Search Quality is Poor ### If Search Quality is Poor
1. **Increase retrieval count**: 1. **Increase retrieval count**:
@@ -267,114 +267,24 @@ Every configuration choice involves trade-offs:
The key is finding the right balance for your specific use case. Start small and simple, measure performance, then scale up only where needed. The key is finding the right balance for your specific use case. Start small and simple, measure performance, then scale up only where needed.
## Low-resource setups ## Deep Dive: Critical Configuration Decisions
If you dont have a local GPU or builds/searches are too slow, use one or more of the options below. ### When to Disable Recomputation
### 1) Use OpenAI embeddings (no local compute) LEANN's recomputation feature provides exact distance calculations but can be disabled for extreme QPS requirements:
Fastest path with zero local GPU requirements. Set your API key and use OpenAI embeddings during build and search:
```bash ```bash
export OPENAI_API_KEY=sk-... --no-recompute # Disable selective recomputation
# Build with OpenAI embeddings
leann build my-index \
--embedding-mode openai \
--embedding-model text-embedding-3-small
# Search with OpenAI embeddings (recompute at query time)
leann search my-index "your query" \
--recompute
``` ```
### 2) Run remote builds with SkyPilot (cloud GPU) **Trade-offs**:
- **With recomputation** (default): Exact distances, best quality, higher latency, minimal storage (only stores metadata, recomputes embeddings on-demand)
Offload embedding generation and index building to a GPU VM using [SkyPilot](https://skypilot.readthedocs.io/en/latest/). A template is provided at `sky/leann-build.yaml`. - **Without recomputation**: Must store full embeddings, significantly higher memory and storage usage (10-100x more), but faster search
```bash
# One-time: install and configure SkyPilot
pip install skypilot
# Launch with defaults (L4:1) and mount ./data to ~/leann-data; the build runs automatically
sky launch -c leann-gpu sky/leann-build.yaml
# Override parameters via -e key=value (optional)
sky launch -c leann-gpu sky/leann-build.yaml \
-e index_name=my-index \
-e backend=hnsw \
-e embedding_mode=sentence-transformers \
-e embedding_model=Qwen/Qwen3-Embedding-0.6B
# Copy the built index back to your local .leann (use rsync)
rsync -Pavz leann-gpu:~/.leann/indexes/my-index ./.leann/indexes/
```
### 3) Disable recomputation to trade storage for speed
If you need lower latency and have more storage/memory, disable recomputation. This stores full embeddings and avoids recomputing at search time.
```bash
# Build without recomputation (HNSW requires non-compact in this mode)
leann build my-index --no-recompute --no-compact
# Search without recomputation
leann search my-index "your query" --no-recompute
```
When to use:
- Extreme low latency requirements (high QPS, interactive assistants)
- Read-heavy workloads where storage is cheaper than latency
- No always-available GPU
Constraints:
- HNSW: when `--no-recompute` is set, LEANN automatically disables compact mode during build
- DiskANN: supported; `--no-recompute` skips selective recompute during search
Storage impact:
- Storing N embeddings of dimension D with float32 requires approximately N × D × 4 bytes
- Example: 1,000,000 chunks × 768 dims × 4 bytes ≈ 2.86 GB (plus graph/metadata)
Converting an existing index (rebuild required):
```bash
# Rebuild in-place (ensure you still have original docs or can regenerate chunks)
leann build my-index --force --no-recompute --no-compact
```
Python API usage:
```python
from leann import LeannSearcher
searcher = LeannSearcher("/path/to/my-index.leann")
results = searcher.search("your query", top_k=10, recompute_embeddings=False)
```
Trade-offs:
- Lower latency and fewer network hops at query time
- Significantly higher storage (10100× vs selective recomputation)
- Slightly larger memory footprint during build and search
Quick benchmark results (`benchmarks/benchmark_no_recompute.py` with 5k texts, complexity=32):
- HNSW
```text
recompute=True: search_time=0.818s, size=1.1MB
recompute=False: search_time=0.012s, size=16.6MB
```
- DiskANN
```text
recompute=True: search_time=0.041s, size=5.9MB
recompute=False: search_time=0.013s, size=24.6MB
```
Conclusion:
- **HNSW**: `no-recompute` is significantly faster (no embedding recomputation) but requires much more storage (stores all embeddings)
- **DiskANN**: `no-recompute` uses PQ + partial real distances during traversal (slower but higher accuracy), while `recompute=True` uses pure PQ traversal + final reranking (faster traversal, enables build-time partitioning for smaller storage)
**Disable when**:
- You have abundant storage and memory
- Need extremely low latency (< 100ms)
- Running a read-heavy workload where storage cost is acceptable
## Further Reading ## Further Reading

View File

@@ -0,0 +1,8 @@
# packages/leann-backend-diskann/CMakeLists.txt (simplified version)
cmake_minimum_required(VERSION 3.20)
project(leann_backend_diskann_wrapper)
# Tell CMake to directly enter the DiskANN submodule and execute its own CMakeLists.txt
# DiskANN will handle everything itself, including compiling Python bindings
add_subdirectory(src/third_party/DiskANN)

View File

@@ -22,11 +22,6 @@ logger = logging.getLogger(__name__)
@contextlib.contextmanager @contextlib.contextmanager
def suppress_cpp_output_if_needed(): def suppress_cpp_output_if_needed():
"""Suppress C++ stdout/stderr based on LEANN_LOG_LEVEL""" """Suppress C++ stdout/stderr based on LEANN_LOG_LEVEL"""
# In CI we avoid fiddling with low-level file descriptors to prevent aborts
if os.getenv("CI") == "true":
yield
return
log_level = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper() log_level = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
# Only suppress if log level is WARNING or higher (ERROR, CRITICAL) # Only suppress if log level is WARNING or higher (ERROR, CRITICAL)
@@ -441,14 +436,9 @@ class DiskannSearcher(BaseSearcher):
else: # "global" else: # "global"
use_global_pruning = True use_global_pruning = True
# Strategy: # Perform search with suppressed C++ output based on log level
# - Traversal always uses PQ distances use_deferred_fetch = kwargs.get("USE_DEFERRED_FETCH", True)
# - If recompute_embeddings=True, do a single final rerank via deferred fetch recompute_neighors = False
# (fetch embeddings for the final candidate set only)
# - Do not recompute neighbor distances along the path
use_deferred_fetch = True if recompute_embeddings else False
recompute_neighors = False # Expected typo. For backward compatibility.
with suppress_cpp_output_if_needed(): with suppress_cpp_output_if_needed():
labels, distances = self._index.batch_search( labels, distances = self._index.batch_search(
query, query,
@@ -469,3 +459,25 @@ class DiskannSearcher(BaseSearcher):
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels] string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
return {"labels": string_labels, "distances": distances} return {"labels": string_labels, "distances": distances}
def cleanup(self):
"""Cleanup DiskANN-specific resources including C++ index."""
# Call parent cleanup first
super().cleanup()
# Delete the C++ index to trigger destructors
try:
if hasattr(self, "_index") and self._index is not None:
del self._index
self._index = None
self._current_zmq_port = None
except Exception:
pass
# Force garbage collection to ensure C++ objects are destroyed
try:
import gc
gc.collect()
except Exception:
pass

View File

@@ -103,9 +103,8 @@ def create_diskann_embedding_server(
socket.bind(f"tcp://*:{zmq_port}") socket.bind(f"tcp://*:{zmq_port}")
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}") logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
socket.setsockopt(zmq.RCVTIMEO, 1000) socket.setsockopt(zmq.RCVTIMEO, 300000)
socket.setsockopt(zmq.SNDTIMEO, 1000) socket.setsockopt(zmq.SNDTIMEO, 300000)
socket.setsockopt(zmq.LINGER, 0)
while True: while True:
try: try:
@@ -222,217 +221,30 @@ def create_diskann_embedding_server(
traceback.print_exc() traceback.print_exc()
raise raise
def zmq_server_thread_with_shutdown(shutdown_event): zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
"""ZMQ server thread that respects shutdown signal.
This creates its own REP socket, binds to zmq_port, and periodically
checks shutdown_event using recv timeouts to exit cleanly.
"""
logger.info("DiskANN ZMQ server thread started with shutdown support")
context = zmq.Context()
rep_socket = context.socket(zmq.REP)
rep_socket.bind(f"tcp://*:{zmq_port}")
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
# Set receive timeout so we can check shutdown_event periodically
rep_socket.setsockopt(zmq.RCVTIMEO, 1000) # 1 second timeout
rep_socket.setsockopt(zmq.SNDTIMEO, 1000)
rep_socket.setsockopt(zmq.LINGER, 0)
try:
while not shutdown_event.is_set():
try:
e2e_start = time.time()
# REP socket receives single-part messages
message = rep_socket.recv()
# Check for empty messages - REP socket requires response to every request
if not message:
logger.warning("Received empty message, sending empty response")
rep_socket.send(b"")
continue
# Try protobuf first (same logic as original)
texts = []
is_text_request = False
try:
req_proto = embedding_pb2.NodeEmbeddingRequest()
req_proto.ParseFromString(message)
node_ids = list(req_proto.node_ids)
# Look up texts by node IDs
for nid in node_ids:
try:
passage_data = passages.get_passage(str(nid))
txt = passage_data["text"]
if not txt:
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
texts.append(txt)
except KeyError:
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
logger.info(f"ZMQ received protobuf request for {len(node_ids)} node IDs")
except Exception:
# Fallback to msgpack for text requests
try:
import msgpack
request = msgpack.unpackb(message)
if isinstance(request, list) and all(
isinstance(item, str) for item in request
):
texts = request
is_text_request = True
logger.info(
f"ZMQ received msgpack text request for {len(texts)} texts"
)
else:
raise ValueError("Not a valid msgpack text request")
except Exception:
logger.error("Both protobuf and msgpack parsing failed!")
# Send error response
resp_proto = embedding_pb2.NodeEmbeddingResponse()
rep_socket.send(resp_proto.SerializeToString())
continue
# Process the request
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
logger.info(f"Computed embeddings shape: {embeddings.shape}")
# Validation
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
logger.error("NaN or Inf detected in embeddings!")
# Send error response
if is_text_request:
import msgpack
response_data = msgpack.packb([])
else:
resp_proto = embedding_pb2.NodeEmbeddingResponse()
response_data = resp_proto.SerializeToString()
rep_socket.send(response_data)
continue
# Prepare response based on request type
if is_text_request:
# For direct text requests, return msgpack
import msgpack
response_data = msgpack.packb(embeddings.tolist())
else:
# For protobuf requests, return protobuf
resp_proto = embedding_pb2.NodeEmbeddingResponse()
hidden_contiguous = np.ascontiguousarray(embeddings, dtype=np.float32)
resp_proto.embeddings_data = hidden_contiguous.tobytes()
resp_proto.dimensions.append(hidden_contiguous.shape[0])
resp_proto.dimensions.append(hidden_contiguous.shape[1])
response_data = resp_proto.SerializeToString()
# Send response back to the client
rep_socket.send(response_data)
e2e_end = time.time()
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
except zmq.Again:
# Timeout - check shutdown_event and continue
continue
except Exception as e:
if not shutdown_event.is_set():
logger.error(f"Error in ZMQ server loop: {e}")
try:
# Send error response for REP socket
resp_proto = embedding_pb2.NodeEmbeddingResponse()
rep_socket.send(resp_proto.SerializeToString())
except Exception:
pass
else:
logger.info("Shutdown in progress, ignoring ZMQ error")
break
finally:
try:
rep_socket.close(0)
except Exception:
pass
try:
context.term()
except Exception:
pass
logger.info("DiskANN ZMQ server thread exiting gracefully")
# Add shutdown coordination
shutdown_event = threading.Event()
def shutdown_zmq_server():
"""Gracefully shutdown ZMQ server."""
logger.info("Initiating graceful shutdown...")
shutdown_event.set()
if zmq_thread.is_alive():
logger.info("Waiting for ZMQ thread to finish...")
zmq_thread.join(timeout=5)
if zmq_thread.is_alive():
logger.warning("ZMQ thread did not finish in time")
# Clean up ZMQ resources
try:
# Note: socket and context are cleaned up by thread exit
logger.info("ZMQ resources cleaned up")
except Exception as e:
logger.warning(f"Error cleaning ZMQ resources: {e}")
# Clean up other resources
try:
import gc
gc.collect()
logger.info("Additional resources cleaned up")
except Exception as e:
logger.warning(f"Error cleaning additional resources: {e}")
logger.info("Graceful shutdown completed")
sys.exit(0)
# Register signal handlers within this function scope
import signal
def signal_handler(sig, frame):
logger.info(f"Received signal {sig}, shutting down gracefully...")
shutdown_zmq_server()
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
# Start ZMQ thread (NOT daemon!)
zmq_thread = threading.Thread(
target=lambda: zmq_server_thread_with_shutdown(shutdown_event),
daemon=False, # Not daemon - we want to wait for it
)
zmq_thread.start() zmq_thread.start()
logger.info(f"Started DiskANN ZMQ server thread on port {zmq_port}") logger.info(f"Started DiskANN ZMQ server thread on port {zmq_port}")
# Keep the main thread alive # Keep the main thread alive
try: try:
while not shutdown_event.is_set(): while True:
time.sleep(0.1) # Check shutdown more frequently time.sleep(1)
except KeyboardInterrupt: except KeyboardInterrupt:
logger.info("DiskANN Server shutting down...") logger.info("DiskANN Server shutting down...")
shutdown_zmq_server()
return return
# If we reach here, shutdown was triggered by signal
logger.info("Main loop exited, process should be shutting down")
if __name__ == "__main__": if __name__ == "__main__":
import signal
import sys import sys
# Signal handlers are now registered within create_diskann_embedding_server def signal_handler(sig, frame):
logger.info(f"Received signal {sig}, shutting down gracefully...")
sys.exit(0)
# Register signal handlers for graceful shutdown
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
parser = argparse.ArgumentParser(description="DiskANN Embedding service") parser = argparse.ArgumentParser(description="DiskANN Embedding service")
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on") parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")

View File

@@ -0,0 +1,137 @@
#!/usr/bin/env python3
"""
Simplified Graph Partition Module for LEANN DiskANN Backend
This module provides a simple Python interface for graph partitioning
that directly calls the existing executables.
"""
import os
import subprocess
import tempfile
from pathlib import Path
from typing import Optional
def partition_graph_simple(
index_prefix_path: str, output_dir: Optional[str] = None, **kwargs
) -> tuple[str, str]:
"""
Simple function to partition a graph index.
Args:
index_prefix_path: Path to the index prefix (e.g., "/path/to/index")
output_dir: Output directory (defaults to parent of index_prefix_path)
**kwargs: Additional parameters for graph partitioning
Returns:
Tuple of (disk_graph_index_path, partition_bin_path)
"""
# Set default parameters
params = {
"gp_times": 10,
"lock_nums": 10,
"cut": 100,
"scale_factor": 1,
"data_type": "float",
"thread_nums": 10,
**kwargs,
}
# Determine output directory
if output_dir is None:
output_dir = str(Path(index_prefix_path).parent)
# Find the graph_partition directory
current_file = Path(__file__)
graph_partition_dir = current_file.parent.parent / "third_party" / "DiskANN" / "graph_partition"
if not graph_partition_dir.exists():
raise RuntimeError(f"Graph partition directory not found: {graph_partition_dir}")
# Find input index file
old_index_file = f"{index_prefix_path}_disk_beam_search.index"
if not os.path.exists(old_index_file):
old_index_file = f"{index_prefix_path}_disk.index"
if not os.path.exists(old_index_file):
raise RuntimeError(f"Index file not found: {old_index_file}")
# Create temporary directory for processing
with tempfile.TemporaryDirectory() as temp_dir:
temp_data_dir = Path(temp_dir) / "data"
temp_data_dir.mkdir(parents=True, exist_ok=True)
# Set up paths for temporary files
graph_path = temp_data_dir / "starling" / "_M_R_L_B" / "GRAPH"
graph_gp_path = (
graph_path
/ f"GP_TIMES_{params['gp_times']}_LOCK_{params['lock_nums']}_GP_USE_FREQ0_CUT{params['cut']}_SCALE{params['scale_factor']}"
)
graph_gp_path.mkdir(parents=True, exist_ok=True)
# Run the build script with our parameters
cmd = [str(graph_partition_dir / "build.sh"), "release", "split_graph", index_prefix_path]
# Set environment variables for parameters
env = os.environ.copy()
env.update(
{
"GP_TIMES": str(params["gp_times"]),
"GP_LOCK_NUMS": str(params["lock_nums"]),
"GP_CUT": str(params["cut"]),
"GP_SCALE_F": str(params["scale_factor"]),
"DATA_TYPE": params["data_type"],
"GP_T": str(params["thread_nums"]),
}
)
print(f"Running graph partition with command: {' '.join(cmd)}")
print(f"Working directory: {graph_partition_dir}")
# Run the command
result = subprocess.run(
cmd, env=env, capture_output=True, text=True, cwd=graph_partition_dir
)
if result.returncode != 0:
print(f"Command failed with return code {result.returncode}")
print(f"stdout: {result.stdout}")
print(f"stderr: {result.stderr}")
raise RuntimeError(
f"Graph partitioning failed with return code {result.returncode}.\n"
f"stdout: {result.stdout}\n"
f"stderr: {result.stderr}"
)
# Check if output files were created
disk_graph_path = Path(output_dir) / "_disk_graph.index"
partition_bin_path = Path(output_dir) / "_partition.bin"
if not disk_graph_path.exists():
raise RuntimeError(f"Expected output file not found: {disk_graph_path}")
if not partition_bin_path.exists():
raise RuntimeError(f"Expected output file not found: {partition_bin_path}")
print("✅ Partitioning completed successfully!")
print(f" Disk graph index: {disk_graph_path}")
print(f" Partition binary: {partition_bin_path}")
return str(disk_graph_path), str(partition_bin_path)
# Example usage
if __name__ == "__main__":
try:
disk_graph_path, partition_bin_path = partition_graph_simple(
"/Users/yichuan/Desktop/release2/leann/diskannbuild/test_doc_files",
gp_times=5,
lock_nums=5,
cut=50,
)
print("Success! Output files:")
print(f" - {disk_graph_path}")
print(f" - {partition_bin_path}")
except Exception as e:
print(f"Error: {e}")

View File

@@ -4,8 +4,8 @@ build-backend = "scikit_build_core.build"
[project] [project]
name = "leann-backend-diskann" name = "leann-backend-diskann"
version = "0.2.9" version = "0.2.5"
dependencies = ["leann-core==0.2.9", "numpy", "protobuf>=3.19.0"] dependencies = ["leann-core==0.2.5", "numpy", "protobuf>=3.19.0"]
[tool.scikit-build] [tool.scikit-build]
# Key: simplified CMake path # Key: simplified CMake path
@@ -17,5 +17,3 @@ editable.mode = "redirect"
cmake.build-type = "Release" cmake.build-type = "Release"
build.verbose = true build.verbose = true
build.tool-args = ["-j8"] build.tool-args = ["-j8"]
# Let CMake find packages via Homebrew prefix
cmake.define = {CMAKE_PREFIX_PATH = {env = "CMAKE_PREFIX_PATH"}, OpenMP_ROOT = {env = "OpenMP_ROOT"}}

View File

@@ -5,20 +5,11 @@ set(CMAKE_CXX_COMPILER_WORKS 1)
# Set OpenMP path for macOS # Set OpenMP path for macOS
if(APPLE) if(APPLE)
# Detect Homebrew installation path (Apple Silicon vs Intel) set(OpenMP_C_FLAGS "-Xpreprocessor -fopenmp -I/opt/homebrew/opt/libomp/include")
if(EXISTS "/opt/homebrew/opt/libomp") set(OpenMP_CXX_FLAGS "-Xpreprocessor -fopenmp -I/opt/homebrew/opt/libomp/include")
set(HOMEBREW_PREFIX "/opt/homebrew")
elseif(EXISTS "/usr/local/opt/libomp")
set(HOMEBREW_PREFIX "/usr/local")
else()
message(FATAL_ERROR "Could not find libomp installation. Please install with: brew install libomp")
endif()
set(OpenMP_C_FLAGS "-Xpreprocessor -fopenmp -I${HOMEBREW_PREFIX}/opt/libomp/include")
set(OpenMP_CXX_FLAGS "-Xpreprocessor -fopenmp -I${HOMEBREW_PREFIX}/opt/libomp/include")
set(OpenMP_C_LIB_NAMES "omp") set(OpenMP_C_LIB_NAMES "omp")
set(OpenMP_CXX_LIB_NAMES "omp") set(OpenMP_CXX_LIB_NAMES "omp")
set(OpenMP_omp_LIBRARY "${HOMEBREW_PREFIX}/opt/libomp/lib/libomp.dylib") set(OpenMP_omp_LIBRARY "/opt/homebrew/opt/libomp/lib/libomp.dylib")
# Force use of system libc++ to avoid version mismatch # Force use of system libc++ to avoid version mismatch
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -stdlib=libc++") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -stdlib=libc++")

View File

@@ -1,6 +1,5 @@
import argparse import argparse
import gc # Import garbage collector interface import gc # Import garbage collector interface
import logging
import os import os
import struct import struct
import sys import sys
@@ -8,12 +7,6 @@ import time
import numpy as np import numpy as np
# Set up logging to avoid print buffer issues
logger = logging.getLogger(__name__)
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
logger.setLevel(log_level)
# --- FourCCs (add more if needed) --- # --- FourCCs (add more if needed) ---
INDEX_HNSW_FLAT_FOURCC = int.from_bytes(b"IHNf", "little") INDEX_HNSW_FLAT_FOURCC = int.from_bytes(b"IHNf", "little")
# Add other HNSW fourccs if you expect different storage types inside HNSW # Add other HNSW fourccs if you expect different storage types inside HNSW
@@ -250,8 +243,6 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
output_filename: Output CSR index file output_filename: Output CSR index file
prune_embeddings: Whether to prune embedding storage (write NULL storage marker) prune_embeddings: Whether to prune embedding storage (write NULL storage marker)
""" """
# Keep prints simple; rely on CI runner to flush output as needed
print(f"Starting conversion: {input_filename} -> {output_filename}") print(f"Starting conversion: {input_filename} -> {output_filename}")
start_time = time.time() start_time = time.time()
original_hnsw_data = {} original_hnsw_data = {}

View File

@@ -54,13 +54,12 @@ class HNSWBuilder(LeannBackendBuilderInterface):
self.efConstruction = self.build_params.setdefault("efConstruction", 200) self.efConstruction = self.build_params.setdefault("efConstruction", 200)
self.distance_metric = self.build_params.setdefault("distance_metric", "mips") self.distance_metric = self.build_params.setdefault("distance_metric", "mips")
self.dimensions = self.build_params.get("dimensions") self.dimensions = self.build_params.get("dimensions")
if not self.is_recompute and self.is_compact: if not self.is_recompute:
# Auto-correct: non-recompute requires non-compact storage for HNSW if self.is_compact:
logger.warning( # TODO: support this case @andy
"is_recompute=False requires non-compact HNSW. Forcing is_compact=False." raise ValueError(
) "is_recompute is False, but is_compact is True. This is not compatible now. change is compact to False and you can use the original HNSW index."
self.is_compact = False )
self.build_params["is_compact"] = False
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs): def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
from . import faiss # type: ignore from . import faiss # type: ignore
@@ -185,11 +184,9 @@ class HNSWSearcher(BaseSearcher):
""" """
from . import faiss # type: ignore from . import faiss # type: ignore
if not recompute_embeddings and self.is_pruned: if not recompute_embeddings:
raise RuntimeError( if self.is_pruned:
"Recompute is required for pruned/compact HNSW index. " raise RuntimeError("Recompute is required for pruned index.")
"Re-run search with --recompute, or rebuild with --no-recompute and --no-compact."
)
if recompute_embeddings: if recompute_embeddings:
if zmq_port is None: if zmq_port is None:
raise ValueError("zmq_port must be provided if recompute_embeddings is True") raise ValueError("zmq_port must be provided if recompute_embeddings is True")

View File

@@ -82,317 +82,188 @@ def create_hnsw_embedding_server(
with open(passages_file) as f: with open(passages_file) as f:
meta = json.load(f) meta = json.load(f)
# Let PassageManager handle path resolution uniformly. It supports fallback order: # Let PassageManager handle path resolution uniformly
# 1) path/index_path; 2) *_relative; 3) standard siblings next to meta
passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file) passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file)
# Dimension from metadata for shaping responses
try:
embedding_dim: int = int(meta.get("dimensions", 0))
except Exception:
embedding_dim = 0
logger.info( logger.info(
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata" f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
) )
# (legacy ZMQ thread removed; using shutdown-capable server only) def zmq_server_thread():
"""ZMQ server thread"""
def zmq_server_thread_with_shutdown(shutdown_event):
"""ZMQ server thread that respects shutdown signal.
Creates its own REP socket bound to zmq_port and polls with timeouts
to allow graceful shutdown.
"""
logger.info("ZMQ server thread started with shutdown support")
context = zmq.Context() context = zmq.Context()
rep_socket = context.socket(zmq.REP) socket = context.socket(zmq.REP)
rep_socket.bind(f"tcp://*:{zmq_port}") socket.bind(f"tcp://*:{zmq_port}")
logger.info(f"HNSW ZMQ REP server listening on port {zmq_port}") logger.info(f"HNSW ZMQ server listening on port {zmq_port}")
rep_socket.setsockopt(zmq.RCVTIMEO, 1000)
# Keep sends from blocking during shutdown; fail fast and drop on close
rep_socket.setsockopt(zmq.SNDTIMEO, 1000)
rep_socket.setsockopt(zmq.LINGER, 0)
# Track last request type/length for shape-correct fallbacks socket.setsockopt(zmq.RCVTIMEO, 300000)
last_request_type = "unknown" # 'text' | 'distance' | 'embedding' | 'unknown' socket.setsockopt(zmq.SNDTIMEO, 300000)
last_request_length = 0
try: while True:
while not shutdown_event.is_set(): try:
try: message_bytes = socket.recv()
e2e_start = time.time() logger.debug(f"Received ZMQ request of size {len(message_bytes)} bytes")
logger.debug("🔍 Waiting for ZMQ message...")
request_bytes = rep_socket.recv()
# Rest of the processing logic (same as original) e2e_start = time.time()
request = msgpack.unpackb(request_bytes) request_payload = msgpack.unpackb(message_bytes)
if len(request) == 1 and request[0] == "__QUERY_MODEL__": # Handle direct text embedding request
response_bytes = msgpack.packb([model_name]) if isinstance(request_payload, list) and len(request_payload) > 0:
rep_socket.send(response_bytes) # Check if this is a direct text request (list of strings)
continue if all(isinstance(item, str) for item in request_payload):
logger.info(
f"Processing direct text embedding request for {len(request_payload)} texts in {embedding_mode} mode"
)
# Handle direct text embedding request # Use unified embedding computation (now with model caching)
if ( embeddings = compute_embeddings(
isinstance(request, list) request_payload, model_name, mode=embedding_mode
and request )
and all(isinstance(item, str) for item in request)
): response = embeddings.tolist()
last_request_type = "text" socket.send(msgpack.packb(response))
last_request_length = len(request)
embeddings = compute_embeddings(request, model_name, mode=embedding_mode)
rep_socket.send(msgpack.packb(embeddings.tolist()))
e2e_end = time.time() e2e_end = time.time()
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s") logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
continue continue
# Handle distance calculation request: [[ids], [query_vector]] # Handle distance calculation requests
if ( if (
isinstance(request, list) isinstance(request_payload, list)
and len(request) == 2 and len(request_payload) == 2
and isinstance(request[0], list) and isinstance(request_payload[0], list)
and isinstance(request[1], list) and isinstance(request_payload[1], list)
): ):
node_ids = request[0] node_ids = request_payload[0]
# Handle nested [[ids]] shape defensively query_vector = np.array(request_payload[1], dtype=np.float32)
if len(node_ids) == 1 and isinstance(node_ids[0], list):
node_ids = node_ids[0]
query_vector = np.array(request[1], dtype=np.float32)
last_request_type = "distance"
last_request_length = len(node_ids)
logger.debug("Distance calculation request received") logger.debug("Distance calculation request received")
logger.debug(f" Node IDs: {node_ids}") logger.debug(f" Node IDs: {node_ids}")
logger.debug(f" Query vector dim: {len(query_vector)}") logger.debug(f" Query vector dim: {len(query_vector)}")
# Gather texts for found ids # Get embeddings for node IDs
texts: list[str] = [] texts = []
found_indices: list[int] = [] for nid in node_ids:
for idx, nid in enumerate(node_ids):
try:
passage_data = passages.get_passage(str(nid))
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {nid}")
except KeyError:
logger.error(f"Passage ID {nid} not found")
except Exception as e:
logger.error(f"Exception looking up passage ID {nid}: {e}")
# Prepare full-length response with large sentinel values
large_distance = 1e9
response_distances = [large_distance] * len(node_ids)
if texts:
try:
embeddings = compute_embeddings(
texts, model_name, mode=embedding_mode
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
if distance_metric == "l2":
partial = np.sum(
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
)
else: # mips or cosine
partial = -np.dot(embeddings, query_vector)
for pos, dval in zip(found_indices, partial.flatten().tolist()):
response_distances[pos] = float(dval)
except Exception as e:
logger.error(f"Distance computation error, using sentinels: {e}")
# Send response in expected shape [[distances]]
rep_socket.send(msgpack.packb([response_distances], use_single_float=True))
e2e_end = time.time()
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
continue
# Fallback: treat as embedding-by-id request
if (
isinstance(request, list)
and len(request) == 1
and isinstance(request[0], list)
):
node_ids = request[0]
elif isinstance(request, list):
node_ids = request
else:
node_ids = []
last_request_type = "embedding"
last_request_length = len(node_ids)
logger.info(f"ZMQ received {len(node_ids)} node IDs for embedding fetch")
# Preallocate zero-filled flat data for robustness
if embedding_dim <= 0:
dims = [0, 0]
flat_data: list[float] = []
else:
dims = [len(node_ids), embedding_dim]
flat_data = [0.0] * (dims[0] * dims[1])
# Collect texts for found ids
texts: list[str] = []
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
try: try:
passage_data = passages.get_passage(str(nid)) passage_data = passages.get_passage(str(nid))
txt = passage_data.get("text", "") txt = passage_data["text"]
if isinstance(txt, str) and len(txt) > 0: texts.append(txt)
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {nid}")
except KeyError: except KeyError:
logger.error(f"Passage with ID {nid} not found") logger.error(f"Passage ID {nid} not found")
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
except Exception as e: except Exception as e:
logger.error(f"Exception looking up passage ID {nid}: {e}") logger.error(f"Exception looking up passage ID {nid}: {e}")
raise
if texts: # Process embeddings
try: embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode) logger.info(
logger.info( f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}" )
)
if np.isnan(embeddings).any() or np.isinf(embeddings).any(): # Calculate distances
logger.error( if distance_metric == "l2":
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..." distances = np.sum(
) np.square(embeddings - query_vector.reshape(1, -1)), axis=1
dims = [0, embedding_dim] )
flat_data = [] else: # mips or cosine
else: distances = -np.dot(embeddings, query_vector)
emb_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
flat = emb_f32.flatten().tolist()
for j, pos in enumerate(found_indices):
start = pos * embedding_dim
end = start + embedding_dim
if end <= len(flat_data):
flat_data[start:end] = flat[
j * embedding_dim : (j + 1) * embedding_dim
]
except Exception as e:
logger.error(f"Embedding computation error, returning zeros: {e}")
response_payload = [dims, flat_data] response_payload = distances.flatten().tolist()
response_bytes = msgpack.packb(response_payload, use_single_float=True) response_bytes = msgpack.packb([response_payload], use_single_float=True)
logger.debug(f"Sending distance response with {len(distances)} distances")
rep_socket.send(response_bytes) socket.send(response_bytes)
e2e_end = time.time() e2e_end = time.time()
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s") logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
except zmq.Again:
# Timeout - check shutdown_event and continue
continue continue
except Exception as e:
if not shutdown_event.is_set():
logger.error(f"Error in ZMQ server loop: {e}")
# Shape-correct fallback
try:
if last_request_type == "distance":
large_distance = 1e9
fallback_len = max(0, int(last_request_length))
safe = [[large_distance] * fallback_len]
elif last_request_type == "embedding":
bsz = max(0, int(last_request_length))
dim = max(0, int(embedding_dim))
safe = (
[[bsz, dim], [0.0] * (bsz * dim)] if dim > 0 else [[0, 0], []]
)
elif last_request_type == "text":
safe = [] # direct text embeddings expectation is a flat list
else:
safe = [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
rep_socket.send(msgpack.packb(safe, use_single_float=True))
except Exception:
pass
else:
logger.info("Shutdown in progress, ignoring ZMQ error")
break
finally:
try:
rep_socket.close(0)
except Exception:
pass
try:
context.term()
except Exception:
pass
logger.info("ZMQ server thread exiting gracefully") # Standard embedding request (passage ID lookup)
if (
not isinstance(request_payload, list)
or len(request_payload) != 1
or not isinstance(request_payload[0], list)
):
logger.error(
f"Invalid MessagePack request format. Expected [[ids...]] or [texts...], got: {type(request_payload)}"
)
socket.send(msgpack.packb([[], []]))
continue
# Add shutdown coordination node_ids = request_payload[0]
shutdown_event = threading.Event() logger.debug(f"Request for {len(node_ids)} node embeddings")
def shutdown_zmq_server(): # Look up texts by node IDs
"""Gracefully shutdown ZMQ server.""" texts = []
logger.info("Initiating graceful shutdown...") for nid in node_ids:
shutdown_event.set() try:
passage_data = passages.get_passage(str(nid))
txt = passage_data["text"]
if not txt:
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
texts.append(txt)
except KeyError:
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
except Exception as e:
logger.error(f"Exception looking up passage ID {nid}: {e}")
raise
if zmq_thread.is_alive(): # Process embeddings
logger.info("Waiting for ZMQ thread to finish...") embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
zmq_thread.join(timeout=5) logger.info(
if zmq_thread.is_alive(): f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
logger.warning("ZMQ thread did not finish in time") )
# Clean up ZMQ resources # Serialization and response
try: if np.isnan(embeddings).any() or np.isinf(embeddings).any():
# Note: socket and context are cleaned up by thread exit logger.error(
logger.info("ZMQ resources cleaned up") f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
except Exception as e: )
logger.warning(f"Error cleaning ZMQ resources: {e}") raise AssertionError()
# Clean up other resources hidden_contiguous_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
try: response_payload = [
import gc list(hidden_contiguous_f32.shape),
hidden_contiguous_f32.flatten().tolist(),
]
response_bytes = msgpack.packb(response_payload, use_single_float=True)
gc.collect() socket.send(response_bytes)
logger.info("Additional resources cleaned up") e2e_end = time.time()
except Exception as e: logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
logger.warning(f"Error cleaning additional resources: {e}")
logger.info("Graceful shutdown completed") except zmq.Again:
sys.exit(0) logger.debug("ZMQ socket timeout, continuing to listen")
continue
except Exception as e:
logger.error(f"Error in ZMQ server loop: {e}")
import traceback
# Register signal handlers within this function scope traceback.print_exc()
import signal socket.send(msgpack.packb([[], []]))
def signal_handler(sig, frame): zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
logger.info(f"Received signal {sig}, shutting down gracefully...")
shutdown_zmq_server()
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
# Pass shutdown_event to ZMQ thread
zmq_thread = threading.Thread(
target=lambda: zmq_server_thread_with_shutdown(shutdown_event),
daemon=False, # Not daemon - we want to wait for it
)
zmq_thread.start() zmq_thread.start()
logger.info(f"Started HNSW ZMQ server thread on port {zmq_port}") logger.info(f"Started HNSW ZMQ server thread on port {zmq_port}")
# Keep the main thread alive # Keep the main thread alive
try: try:
while not shutdown_event.is_set(): while True:
time.sleep(0.1) # Check shutdown more frequently time.sleep(1)
except KeyboardInterrupt: except KeyboardInterrupt:
logger.info("HNSW Server shutting down...") logger.info("HNSW Server shutting down...")
shutdown_zmq_server()
return return
# If we reach here, shutdown was triggered by signal
logger.info("Main loop exited, process should be shutting down")
if __name__ == "__main__": if __name__ == "__main__":
import signal
import sys import sys
# Signal handlers are now registered within create_hnsw_embedding_server def signal_handler(sig, frame):
logger.info(f"Received signal {sig}, shutting down gracefully...")
sys.exit(0)
# Register signal handlers for graceful shutdown
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
parser = argparse.ArgumentParser(description="HNSW Embedding service") parser = argparse.ArgumentParser(description="HNSW Embedding service")
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on") parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")

View File

@@ -6,10 +6,10 @@ build-backend = "scikit_build_core.build"
[project] [project]
name = "leann-backend-hnsw" name = "leann-backend-hnsw"
version = "0.2.9" version = "0.2.5"
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit." description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
dependencies = [ dependencies = [
"leann-core==0.2.9", "leann-core==0.2.5",
"numpy", "numpy",
"pyzmq>=23.0.0", "pyzmq>=23.0.0",
"msgpack>=1.0.0", "msgpack>=1.0.0",
@@ -22,8 +22,6 @@ cmake.build-type = "Release"
build.verbose = true build.verbose = true
build.tool-args = ["-j8"] build.tool-args = ["-j8"]
# CMake definitions to optimize compilation and find Homebrew packages # CMake definitions to optimize compilation
[tool.scikit-build.cmake.define] [tool.scikit-build.cmake.define]
CMAKE_BUILD_PARALLEL_LEVEL = "8" CMAKE_BUILD_PARALLEL_LEVEL = "8"
CMAKE_PREFIX_PATH = {env = "CMAKE_PREFIX_PATH"}
OpenMP_ROOT = {env = "OpenMP_ROOT"}

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "leann-core" name = "leann-core"
version = "0.2.9" version = "0.2.5"
description = "Core API and plugin system for LEANN" description = "Core API and plugin system for LEANN"
readme = "README.md" readme = "README.md"
requires-python = ">=3.9" requires-python = ">=3.9"
@@ -31,10 +31,8 @@ dependencies = [
"PyPDF2>=3.0.0", "PyPDF2>=3.0.0",
"pymupdf>=1.23.0", "pymupdf>=1.23.0",
"pdfplumber>=0.10.0", "pdfplumber>=0.10.0",
"nbconvert>=7.0.0", # For .ipynb file support "mlx>=0.26.3; sys_platform == 'darwin'",
"gitignore-parser>=0.1.12", # For proper .gitignore handling "mlx-lm>=0.26.0; sys_platform == 'darwin'",
"mlx>=0.26.3; sys_platform == 'darwin' and platform_machine == 'arm64'",
"mlx-lm>=0.26.0; sys_platform == 'darwin' and platform_machine == 'arm64'",
] ]
[project.optional-dependencies] [project.optional-dependencies]

View File

@@ -87,21 +87,26 @@ def compute_embeddings_via_server(chunks: list[str], model_name: str, port: int)
# Connect to embedding server # Connect to embedding server
context = zmq.Context() context = zmq.Context()
socket = context.socket(zmq.REQ) socket = context.socket(zmq.REQ)
socket.setsockopt(zmq.LINGER, 0) # Don't block on close
socket.setsockopt(zmq.RCVTIMEO, 300000)
socket.setsockopt(zmq.SNDTIMEO, 300000)
socket.setsockopt(zmq.IMMEDIATE, 1)
socket.connect(f"tcp://localhost:{port}") socket.connect(f"tcp://localhost:{port}")
# Send chunks to server for embedding computation try:
request = chunks # Send chunks to server for embedding computation
socket.send(msgpack.packb(request)) request = chunks
socket.send(msgpack.packb(request))
# Receive embeddings from server # Receive embeddings from server
response = socket.recv() response = socket.recv()
embeddings_list = msgpack.unpackb(response) embeddings_list = msgpack.unpackb(response)
# Convert back to numpy array # Convert back to numpy array
embeddings = np.array(embeddings_list, dtype=np.float32) embeddings = np.array(embeddings_list, dtype=np.float32)
finally:
socket.close() socket.close()
context.term() # Don't call context.term() - this was causing hangs
return embeddings return embeddings
@@ -122,55 +127,31 @@ class PassageManager:
self.passage_files = {} self.passage_files = {}
self.global_offset_map = {} # Combined map for fast lookup self.global_offset_map = {} # Combined map for fast lookup
# Derive index base name for standard sibling fallbacks, e.g., <index_name>.passages.*
index_name_base = None
if metadata_file_path:
meta_name = Path(metadata_file_path).name
if meta_name.endswith(".meta.json"):
index_name_base = meta_name[: -len(".meta.json")]
for source in passage_sources: for source in passage_sources:
assert source["type"] == "jsonl", "only jsonl is supported" assert source["type"] == "jsonl", "only jsonl is supported"
passage_file = source.get("path", "") passage_file = source["path"]
index_file = source.get("index_path", "") # .idx file index_file = source["index_path"] # .idx file
# Fix path resolution - relative paths should be relative to metadata file directory # Fix path resolution - relative paths should be relative to metadata file directory
def _resolve_candidates( if not Path(index_file).is_absolute():
primary: str, if metadata_file_path:
relative_key: str, # Resolve relative to metadata file directory
default_name: Optional[str], metadata_dir = Path(metadata_file_path).parent
source_dict: dict[str, Any], logger.debug(
) -> list[Path]: f"PassageManager: Resolving relative paths from metadata_dir: {metadata_dir}"
candidates: list[Path] = [] )
# 1) Primary as-is (absolute or relative) index_file = str((metadata_dir / index_file).resolve())
if primary: passage_file = str((metadata_dir / passage_file).resolve())
p = Path(primary) logger.debug(f"PassageManager: Resolved index_file: {index_file}")
candidates.append(p if p.is_absolute() else (Path.cwd() / p)) else:
# 2) metadata-relative explicit relative key # Fallback to current directory resolution (legacy behavior)
if metadata_file_path and source_dict.get(relative_key): logger.warning(
candidates.append(Path(metadata_file_path).parent / source_dict[relative_key]) "PassageManager: No metadata_file_path provided, using fallback resolution from cwd"
# 3) metadata-relative standard sibling filename )
if metadata_file_path and default_name: logger.debug(f"PassageManager: Current working directory: {Path.cwd()}")
candidates.append(Path(metadata_file_path).parent / default_name) index_file = str(Path(index_file).resolve())
return candidates passage_file = str(Path(passage_file).resolve())
logger.debug(f"PassageManager: Fallback resolved index_file: {index_file}")
# Build candidate lists and pick first existing; otherwise keep last candidate for error message
idx_default = f"{index_name_base}.passages.idx" if index_name_base else None
idx_candidates = _resolve_candidates(
index_file, "index_path_relative", idx_default, source
)
pas_default = f"{index_name_base}.passages.jsonl" if index_name_base else None
pas_candidates = _resolve_candidates(passage_file, "path_relative", pas_default, source)
def _pick_existing(cands: list[Path]) -> str:
for c in cands:
if c.exists():
return str(c.resolve())
# Fallback to last candidate (best guess) even if not exists; will error below
return str(cands[-1].resolve()) if cands else ""
index_file = _pick_existing(idx_candidates)
passage_file = _pick_existing(pas_candidates)
if not Path(index_file).exists(): if not Path(index_file).exists():
raise FileNotFoundError(f"Passage index file not found: {index_file}") raise FileNotFoundError(f"Passage index file not found: {index_file}")
@@ -204,18 +185,6 @@ class LeannBuilder:
**backend_kwargs, **backend_kwargs,
): ):
self.backend_name = backend_name self.backend_name = backend_name
# Normalize incompatible combinations early (for consistent metadata)
if backend_name == "hnsw":
is_recompute = backend_kwargs.get("is_recompute", True)
is_compact = backend_kwargs.get("is_compact", True)
if is_recompute is False and is_compact is True:
warnings.warn(
"HNSW with is_recompute=False requires non-compact storage. Forcing is_compact=False.",
UserWarning,
stacklevel=2,
)
backend_kwargs["is_compact"] = False
backend_factory: Optional[LeannBackendFactoryInterface] = BACKEND_REGISTRY.get(backend_name) backend_factory: Optional[LeannBackendFactoryInterface] = BACKEND_REGISTRY.get(backend_name)
if backend_factory is None: if backend_factory is None:
raise ValueError(f"Backend '{backend_name}' not found or not registered.") raise ValueError(f"Backend '{backend_name}' not found or not registered.")
@@ -368,12 +337,8 @@ class LeannBuilder:
"passage_sources": [ "passage_sources": [
{ {
"type": "jsonl", "type": "jsonl",
# Preserve existing relative file names (backward-compatible) "path": passages_file.name, # Use relative path (just filename)
"path": passages_file.name, "index_path": offset_file.name, # Use relative path (just filename)
"index_path": offset_file.name,
# Add optional redundant relative keys for remote build portability (non-breaking)
"path_relative": passages_file.name,
"index_path_relative": offset_file.name,
} }
], ],
} }
@@ -488,12 +453,8 @@ class LeannBuilder:
"passage_sources": [ "passage_sources": [
{ {
"type": "jsonl", "type": "jsonl",
# Preserve existing relative file names (backward-compatible) "path": passages_file.name, # Use relative path (just filename)
"path": passages_file.name, "index_path": offset_file.name, # Use relative path (just filename)
"index_path": offset_file.name,
# Add optional redundant relative keys for remote build portability (non-breaking)
"path_relative": passages_file.name,
"index_path_relative": offset_file.name,
} }
], ],
"built_from_precomputed_embeddings": True, "built_from_precomputed_embeddings": True,
@@ -535,7 +496,6 @@ class LeannSearcher:
self.embedding_model = self.meta_data["embedding_model"] self.embedding_model = self.meta_data["embedding_model"]
# Support both old and new format # Support both old and new format
self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers") self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
# Delegate portability handling to PassageManager
self.passage_manager = PassageManager( self.passage_manager = PassageManager(
self.meta_data.get("passage_sources", []), metadata_file_path=self.meta_path_str self.meta_data.get("passage_sources", []), metadata_file_path=self.meta_path_str
) )
@@ -617,7 +577,6 @@ class LeannSearcher:
enriched_results = [] enriched_results = []
if "labels" in results and "distances" in results: if "labels" in results and "distances" in results:
logger.info(f" Processing {len(results['labels'][0])} passage IDs:") logger.info(f" Processing {len(results['labels'][0])} passage IDs:")
# Python 3.9 does not support zip(strict=...); lengths are expected to match
for i, (string_id, dist) in enumerate( for i, (string_id, dist) in enumerate(
zip(results["labels"][0], results["distances"][0]) zip(results["labels"][0], results["distances"][0])
): ):
@@ -645,42 +604,17 @@ class LeannSearcher:
) )
except KeyError: except KeyError:
RED = "\033[91m" RED = "\033[91m"
RESET = "\033[0m"
logger.error( logger.error(
f" {RED}{RESET} [{i + 1:2d}] ID: '{string_id}' -> {RED}ERROR: Passage not found!{RESET}" f" {RED}{RESET} [{i + 1:2d}] ID: '{string_id}' -> {RED}ERROR: Passage not found!{RESET}"
) )
# Define color codes outside the loop for final message
GREEN = "\033[92m"
RESET = "\033[0m"
logger.info(f" {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}") logger.info(f" {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}")
return enriched_results return enriched_results
def cleanup(self): def cleanup(self):
"""Explicitly cleanup embedding server resources. """Cleanup embedding server and other resources."""
if hasattr(self.backend_impl, "cleanup"):
This method should be called after you're done using the searcher, self.backend_impl.cleanup()
especially in test environments or batch processing scenarios.
"""
if hasattr(self.backend_impl, "embedding_server_manager"):
self.backend_impl.embedding_server_manager.stop_server()
# Enable automatic cleanup patterns
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
try:
self.cleanup()
except Exception:
pass
def __del__(self):
try:
self.cleanup()
except Exception:
# Avoid noisy errors during interpreter shutdown
pass
class LeannChat: class LeannChat:
@@ -751,28 +685,3 @@ class LeannChat:
except (KeyboardInterrupt, EOFError): except (KeyboardInterrupt, EOFError):
print("\nGoodbye!") print("\nGoodbye!")
break break
def cleanup(self):
"""Explicitly cleanup embedding server resources.
This method should be called after you're done using the chat interface,
especially in test environments or batch processing scenarios.
"""
if hasattr(self.searcher, "cleanup"):
self.searcher.cleanup()
# Enable automatic cleanup patterns
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
try:
self.cleanup()
except Exception:
pass
def __del__(self):
try:
self.cleanup()
except Exception:
pass

View File

@@ -422,6 +422,7 @@ class LLMInterface(ABC):
top_k=10, top_k=10,
complexity=64, complexity=64,
beam_width=8, beam_width=8,
USE_DEFERRED_FETCH=True,
skip_search_reorder=True, skip_search_reorder=True,
recompute_beighbor_embeddings=True, recompute_beighbor_embeddings=True,
dedup_node_dis=True, dedup_node_dis=True,
@@ -433,6 +434,7 @@ class LLMInterface(ABC):
Supported kwargs: Supported kwargs:
- complexity (int): Search complexity parameter (default: 32) - complexity (int): Search complexity parameter (default: 32)
- beam_width (int): Beam width for search (default: 4) - beam_width (int): Beam width for search (default: 4)
- USE_DEFERRED_FETCH (bool): Enable deferred fetch mode (default: False)
- skip_search_reorder (bool): Skip search reorder step (default: False) - skip_search_reorder (bool): Skip search reorder step (default: False)
- recompute_beighbor_embeddings (bool): Enable ZMQ embedding server for neighbor recomputation (default: False) - recompute_beighbor_embeddings (bool): Enable ZMQ embedding server for neighbor recomputation (default: False)
- dedup_node_dis (bool): Deduplicate nodes by distance (default: False) - dedup_node_dis (bool): Deduplicate nodes by distance (default: False)

View File

@@ -1,11 +1,10 @@
import argparse import argparse
import asyncio import asyncio
from pathlib import Path from pathlib import Path
from typing import Union from typing import Optional
from llama_index.core import SimpleDirectoryReader from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter from llama_index.core.node_parser import SentenceSplitter
from tqdm import tqdm
from .api import LeannBuilder, LeannChat, LeannSearcher from .api import LeannBuilder, LeannChat, LeannSearcher
@@ -72,18 +71,15 @@ class LeannCLI:
def create_parser(self) -> argparse.ArgumentParser: def create_parser(self) -> argparse.ArgumentParser:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
prog="leann", prog="leann",
description="The smallest vector index in the world. RAG Everything with LEANN!", description="LEANN - Local Enhanced AI Navigation",
formatter_class=argparse.RawDescriptionHelpFormatter, formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=""" epilog="""
Examples: Examples:
leann build my-docs --docs ./documents # Build index from directory leann build my-docs --docs ./documents # Build index named my-docs
leann build my-code --docs ./src ./tests ./config # Build index from multiple directories leann build my-ppts --docs ./ --file-types .pptx,.pdf # Index only PowerPoint and PDF files
leann build my-files --docs ./file1.py ./file2.txt ./docs/ # Build index from files and directories leann search my-docs "query" # Search in my-docs index
leann build my-mixed --docs ./readme.md ./src/ ./config.json # Build index from mixed files/dirs leann ask my-docs "question" # Ask my-docs index
leann build my-ppts --docs ./ --file-types .pptx,.pdf # Index only PowerPoint and PDF files leann list # List all stored indexes
leann search my-docs "query" # Search in my-docs index
leann ask my-docs "question" # Ask my-docs index
leann list # List all stored indexes
""", """,
) )
@@ -91,29 +87,14 @@ Examples:
# Build command # Build command
build_parser = subparsers.add_parser("build", help="Build document index") build_parser = subparsers.add_parser("build", help="Build document index")
build_parser.add_argument("index_name", help="Index name")
build_parser.add_argument( build_parser.add_argument(
"index_name", nargs="?", help="Index name (default: current directory name)" "--docs", type=str, default=".", help="Documents directory (default: current directory)"
) )
build_parser.add_argument( build_parser.add_argument(
"--docs", "--backend", type=str, default="hnsw", choices=["hnsw", "diskann"]
type=str,
nargs="+",
default=["."],
help="Documents directories and/or files (default: current directory)",
)
build_parser.add_argument(
"--backend",
type=str,
default="hnsw",
choices=["hnsw", "diskann"],
help="Backend to use (default: hnsw)",
)
build_parser.add_argument(
"--embedding-model",
type=str,
default="facebook/contriever",
help="Embedding model (default: facebook/contriever)",
) )
build_parser.add_argument("--embedding-model", type=str, default="facebook/contriever")
build_parser.add_argument( build_parser.add_argument(
"--embedding-mode", "--embedding-mode",
type=str, type=str,
@@ -121,28 +102,12 @@ Examples:
choices=["sentence-transformers", "openai", "mlx", "ollama"], choices=["sentence-transformers", "openai", "mlx", "ollama"],
help="Embedding backend mode (default: sentence-transformers)", help="Embedding backend mode (default: sentence-transformers)",
) )
build_parser.add_argument( build_parser.add_argument("--force", "-f", action="store_true", help="Force rebuild")
"--force", "-f", action="store_true", help="Force rebuild existing index" build_parser.add_argument("--graph-degree", type=int, default=32)
) build_parser.add_argument("--complexity", type=int, default=64)
build_parser.add_argument(
"--graph-degree", type=int, default=32, help="Graph degree (default: 32)"
)
build_parser.add_argument(
"--complexity", type=int, default=64, help="Build complexity (default: 64)"
)
build_parser.add_argument("--num-threads", type=int, default=1) build_parser.add_argument("--num-threads", type=int, default=1)
build_parser.add_argument( build_parser.add_argument("--compact", action="store_true", default=True)
"--compact", build_parser.add_argument("--recompute", action="store_true", default=True)
action=argparse.BooleanOptionalAction,
default=True,
help="Use compact storage (default: true). Must be `no-compact` for `no-recompute` build.",
)
build_parser.add_argument(
"--recompute",
action=argparse.BooleanOptionalAction,
default=True,
help="Enable recomputation (default: true)",
)
build_parser.add_argument( build_parser.add_argument(
"--file-types", "--file-types",
type=str, type=str,
@@ -153,26 +118,20 @@ Examples:
search_parser = subparsers.add_parser("search", help="Search documents") search_parser = subparsers.add_parser("search", help="Search documents")
search_parser.add_argument("index_name", help="Index name") search_parser.add_argument("index_name", help="Index name")
search_parser.add_argument("query", help="Search query") search_parser.add_argument("query", help="Search query")
search_parser.add_argument( search_parser.add_argument("--top-k", type=int, default=5)
"--top-k", type=int, default=5, help="Number of results (default: 5)" search_parser.add_argument("--complexity", type=int, default=64)
)
search_parser.add_argument(
"--complexity", type=int, default=64, help="Search complexity (default: 64)"
)
search_parser.add_argument("--beam-width", type=int, default=1) search_parser.add_argument("--beam-width", type=int, default=1)
search_parser.add_argument("--prune-ratio", type=float, default=0.0) search_parser.add_argument("--prune-ratio", type=float, default=0.0)
search_parser.add_argument( search_parser.add_argument(
"--recompute", "--recompute-embeddings",
dest="recompute_embeddings", action="store_true",
action=argparse.BooleanOptionalAction,
default=True, default=True,
help="Enable/disable embedding recomputation (default: enabled). Should not do a `no-recompute` search in a `recompute` build.", help="Recompute embeddings (default: True)",
) )
search_parser.add_argument( search_parser.add_argument(
"--pruning-strategy", "--pruning-strategy",
choices=["global", "local", "proportional"], choices=["global", "local", "proportional"],
default="global", default="global",
help="Pruning strategy (default: global)",
) )
# Ask command # Ask command
@@ -183,27 +142,19 @@ Examples:
type=str, type=str,
default="ollama", default="ollama",
choices=["simulated", "ollama", "hf", "openai"], choices=["simulated", "ollama", "hf", "openai"],
help="LLM provider (default: ollama)",
)
ask_parser.add_argument(
"--model", type=str, default="qwen3:8b", help="Model name (default: qwen3:8b)"
) )
ask_parser.add_argument("--model", type=str, default="qwen3:8b")
ask_parser.add_argument("--host", type=str, default="http://localhost:11434") ask_parser.add_argument("--host", type=str, default="http://localhost:11434")
ask_parser.add_argument( ask_parser.add_argument("--interactive", "-i", action="store_true")
"--interactive", "-i", action="store_true", help="Interactive chat mode" ask_parser.add_argument("--top-k", type=int, default=20)
)
ask_parser.add_argument(
"--top-k", type=int, default=20, help="Retrieval count (default: 20)"
)
ask_parser.add_argument("--complexity", type=int, default=32) ask_parser.add_argument("--complexity", type=int, default=32)
ask_parser.add_argument("--beam-width", type=int, default=1) ask_parser.add_argument("--beam-width", type=int, default=1)
ask_parser.add_argument("--prune-ratio", type=float, default=0.0) ask_parser.add_argument("--prune-ratio", type=float, default=0.0)
ask_parser.add_argument( ask_parser.add_argument(
"--recompute", "--recompute-embeddings",
dest="recompute_embeddings", action="store_true",
action=argparse.BooleanOptionalAction,
default=True, default=True,
help="Enable/disable embedding recomputation during ask (default: enabled)", help="Recompute embeddings (default: True)",
) )
ask_parser.add_argument( ask_parser.add_argument(
"--pruning-strategy", "--pruning-strategy",
@@ -251,63 +202,6 @@ Examples:
with open(global_registry, "w") as f: with open(global_registry, "w") as f:
json.dump(projects, f, indent=2) json.dump(projects, f, indent=2)
def _build_gitignore_parser(self, docs_dir: str):
"""Build gitignore parser using gitignore-parser library."""
from gitignore_parser import parse_gitignore
# Try to parse the root .gitignore
gitignore_path = Path(docs_dir) / ".gitignore"
if gitignore_path.exists():
try:
# gitignore-parser automatically handles all subdirectory .gitignore files!
matches = parse_gitignore(str(gitignore_path))
print(f"📋 Loaded .gitignore from {docs_dir} (includes all subdirectories)")
return matches
except Exception as e:
print(f"Warning: Could not parse .gitignore: {e}")
else:
print("📋 No .gitignore found")
# Fallback: basic pattern matching for essential files
essential_patterns = {".git", ".DS_Store", "__pycache__", "node_modules", ".venv", "venv"}
def basic_matches(file_path):
path_parts = Path(file_path).parts
return any(part in essential_patterns for part in path_parts)
return basic_matches
def _should_exclude_file(self, relative_path: Path, gitignore_matches) -> bool:
"""Check if a file should be excluded using gitignore parser."""
return gitignore_matches(str(relative_path))
def _is_git_submodule(self, path: Path) -> bool:
"""Check if a path is a git submodule."""
try:
# Find the git repo root
current_dir = Path.cwd()
while current_dir != current_dir.parent:
if (current_dir / ".git").exists():
gitmodules_path = current_dir / ".gitmodules"
if gitmodules_path.exists():
# Read .gitmodules to check if this path is a submodule
gitmodules_content = gitmodules_path.read_text()
# Convert path to relative to git root
try:
relative_path = path.resolve().relative_to(current_dir)
# Check if this path appears in .gitmodules
return f"path = {relative_path}" in gitmodules_content
except ValueError:
# Path is not under git root
return False
break
current_dir = current_dir.parent
return False
except Exception:
# If anything goes wrong, assume it's not a submodule
return False
def list_indexes(self): def list_indexes(self):
print("Stored LEANN indexes:") print("Stored LEANN indexes:")
@@ -337,9 +231,7 @@ Examples:
valid_projects.append(current_path) valid_projects.append(current_path)
if not valid_projects: if not valid_projects:
print( print("No indexes found. Use 'leann build <name> --docs <dir>' to create one.")
"No indexes found. Use 'leann build <name> --docs <dir> [<dir2> ...]' to create one."
)
return return
total_indexes = 0 total_indexes = 0
@@ -386,88 +278,41 @@ Examples:
print(f' leann search {example_name} "your query"') print(f' leann search {example_name} "your query"')
print(f" leann ask {example_name} --interactive") print(f" leann ask {example_name} --interactive")
def load_documents( def load_documents(self, docs_dir: str, custom_file_types: Optional[str] = None):
self, docs_paths: Union[str, list], custom_file_types: Union[str, None] = None print(f"Loading documents from {docs_dir}...")
):
# Handle both single path (string) and multiple paths (list) for backward compatibility
if isinstance(docs_paths, str):
docs_paths = [docs_paths]
# Separate files and directories
files = []
directories = []
for path in docs_paths:
path_obj = Path(path)
if path_obj.is_file():
files.append(str(path_obj))
elif path_obj.is_dir():
# Check if this is a git submodule - if so, skip it
if self._is_git_submodule(path_obj):
print(f"⚠️ Skipping git submodule: {path}")
continue
directories.append(str(path_obj))
else:
print(f"⚠️ Warning: Path '{path}' does not exist, skipping...")
continue
# Print summary of what we're processing
total_items = len(files) + len(directories)
items_desc = []
if files:
items_desc.append(f"{len(files)} file{'s' if len(files) > 1 else ''}")
if directories:
items_desc.append(
f"{len(directories)} director{'ies' if len(directories) > 1 else 'y'}"
)
print(f"Loading documents from {' and '.join(items_desc)} ({total_items} total):")
if files:
print(f" 📄 Files: {', '.join([Path(f).name for f in files])}")
if directories:
print(f" 📁 Directories: {', '.join(directories)}")
if custom_file_types: if custom_file_types:
print(f"Using custom file types: {custom_file_types}") print(f"Using custom file types: {custom_file_types}")
all_documents = [] # Try to use better PDF parsers first
documents = []
docs_path = Path(docs_dir)
# First, process individual files if any for file_path in docs_path.rglob("*.pdf"):
if files: print(f"Processing PDF: {file_path}")
print(f"\n🔄 Processing {len(files)} individual file{'s' if len(files) > 1 else ''}...")
# Load individual files using SimpleDirectoryReader with input_files # Try PyMuPDF first (best quality)
# Note: We skip gitignore filtering for explicitly specified files text = extract_pdf_text_with_pymupdf(str(file_path))
try: if text is None:
# Group files by their parent directory for efficient loading # Try pdfplumber
from collections import defaultdict text = extract_pdf_text_with_pdfplumber(str(file_path))
files_by_dir = defaultdict(list) if text:
for file_path in files: # Create a simple document structure
parent_dir = str(Path(file_path).parent) from llama_index.core import Document
files_by_dir[parent_dir].append(file_path)
# Load files from each parent directory doc = Document(text=text, metadata={"source": str(file_path)})
for parent_dir, file_list in files_by_dir.items(): documents.append(doc)
print( else:
f" Loading {len(file_list)} file{'s' if len(file_list) > 1 else ''} from {parent_dir}" # Fallback to default reader
) print(f"Using default reader for {file_path}")
try: default_docs = SimpleDirectoryReader(
file_docs = SimpleDirectoryReader( str(file_path.parent),
parent_dir, filename_as_id=True,
input_files=file_list, required_exts=[file_path.suffix],
filename_as_id=True, ).load_data()
).load_data() documents.extend(default_docs)
all_documents.extend(file_docs)
print(
f" ✅ Loaded {len(file_docs)} document{'s' if len(file_docs) > 1 else ''}"
)
except Exception as e:
print(f" ❌ Warning: Could not load files from {parent_dir}: {e}")
except Exception as e: # Load other file types with default reader
print(f"❌ Error processing individual files: {e}")
# Define file extensions to process
if custom_file_types: if custom_file_types:
# Parse custom file types from comma-separated string # Parse custom file types from comma-separated string
code_extensions = [ext.strip() for ext in custom_file_types.split(",") if ext.strip()] code_extensions = [ext.strip() for ext in custom_file_types.split(",") if ext.strip()]
@@ -529,106 +374,20 @@ Examples:
".py", ".py",
".jl", ".jl",
] ]
# Try to load other file types, but don't fail if none are found
# Process each directory try:
if directories: other_docs = SimpleDirectoryReader(
print( docs_dir,
f"\n🔄 Processing {len(directories)} director{'ies' if len(directories) > 1 else 'y'}..." recursive=True,
) encoding="utf-8",
required_exts=code_extensions,
for docs_dir in directories: ).load_data(show_progress=True)
print(f"Processing directory: {docs_dir}") documents.extend(other_docs)
# Build gitignore parser for each directory except ValueError as e:
gitignore_matches = self._build_gitignore_parser(docs_dir) if "No files found" in str(e):
print("No additional files found for other supported types.")
# Try to use better PDF parsers first, but only if PDFs are requested else:
documents = [] raise e
docs_path = Path(docs_dir)
# Check if we should process PDFs
should_process_pdfs = custom_file_types is None or ".pdf" in custom_file_types
if should_process_pdfs:
for file_path in docs_path.rglob("*.pdf"):
# Check if file matches any exclude pattern
try:
relative_path = file_path.relative_to(docs_path)
if self._should_exclude_file(relative_path, gitignore_matches):
continue
except ValueError:
# Skip files that can't be made relative to docs_path
print(f"⚠️ Skipping file outside directory scope: {file_path}")
continue
print(f"Processing PDF: {file_path}")
# Try PyMuPDF first (best quality)
text = extract_pdf_text_with_pymupdf(str(file_path))
if text is None:
# Try pdfplumber
text = extract_pdf_text_with_pdfplumber(str(file_path))
if text:
# Create a simple document structure
from llama_index.core import Document
doc = Document(text=text, metadata={"source": str(file_path)})
documents.append(doc)
else:
# Fallback to default reader
print(f"Using default reader for {file_path}")
try:
default_docs = SimpleDirectoryReader(
str(file_path.parent),
filename_as_id=True,
required_exts=[file_path.suffix],
).load_data()
documents.extend(default_docs)
except Exception as e:
print(f"Warning: Could not process {file_path}: {e}")
# Load other file types with default reader
try:
# Create a custom file filter function using our PathSpec
def file_filter(
file_path: str, docs_dir=docs_dir, gitignore_matches=gitignore_matches
) -> bool:
"""Return True if file should be included (not excluded)"""
try:
docs_path_obj = Path(docs_dir)
file_path_obj = Path(file_path)
relative_path = file_path_obj.relative_to(docs_path_obj)
return not self._should_exclude_file(relative_path, gitignore_matches)
except (ValueError, OSError):
return True # Include files that can't be processed
other_docs = SimpleDirectoryReader(
docs_dir,
recursive=True,
encoding="utf-8",
required_exts=code_extensions,
file_extractor={}, # Use default extractors
filename_as_id=True,
).load_data(show_progress=True)
# Filter documents after loading based on gitignore rules
filtered_docs = []
for doc in other_docs:
file_path = doc.metadata.get("file_path", "")
if file_filter(file_path):
filtered_docs.append(doc)
documents.extend(filtered_docs)
except ValueError as e:
if "No files found" in str(e):
print(f"No additional files found for other supported types in {docs_dir}.")
else:
raise e
all_documents.extend(documents)
print(f"Loaded {len(documents)} documents from {docs_dir}")
documents = all_documents
all_texts = [] all_texts = []
@@ -679,9 +438,7 @@ Examples:
".jl", ".jl",
} }
print("start chunking documents") for doc in documents:
# Add progress bar for document chunking
for doc in tqdm(documents, desc="Chunking documents", unit="doc"):
# Check if this is a code file based on source path # Check if this is a code file based on source path
source_path = doc.metadata.get("source", "") source_path = doc.metadata.get("source", "")
is_code_file = any(source_path.endswith(ext) for ext in code_file_exts) is_code_file = any(source_path.endswith(ext) for ext in code_file_exts)
@@ -697,36 +454,18 @@ Examples:
return all_texts return all_texts
async def build_index(self, args): async def build_index(self, args):
docs_paths = args.docs docs_dir = args.docs
# Use current directory name if index_name not provided index_name = args.index_name
if args.index_name:
index_name = args.index_name
else:
index_name = Path.cwd().name
print(f"Using current directory name as index: '{index_name}'")
index_dir = self.indexes_dir / index_name index_dir = self.indexes_dir / index_name
index_path = self.get_index_path(index_name) index_path = self.get_index_path(index_name)
# Display all paths being indexed with file/directory distinction print(f"📂 Indexing: {Path(docs_dir).resolve()}")
files = [p for p in docs_paths if Path(p).is_file()]
directories = [p for p in docs_paths if Path(p).is_dir()]
print(f"📂 Indexing {len(docs_paths)} path{'s' if len(docs_paths) > 1 else ''}:")
if files:
print(f" 📄 Files ({len(files)}):")
for i, file_path in enumerate(files, 1):
print(f" {i}. {Path(file_path).resolve()}")
if directories:
print(f" 📁 Directories ({len(directories)}):")
for i, dir_path in enumerate(directories, 1):
print(f" {i}. {Path(dir_path).resolve()}")
if index_dir.exists() and not args.force: if index_dir.exists() and not args.force:
print(f"Index '{index_name}' already exists. Use --force to rebuild.") print(f"Index '{index_name}' already exists. Use --force to rebuild.")
return return
all_texts = self.load_documents(docs_paths, args.file_types) all_texts = self.load_documents(docs_dir, args.file_types)
if not all_texts: if not all_texts:
print("No documents found") print("No documents found")
return return
@@ -762,7 +501,7 @@ Examples:
if not self.index_exists(index_name): if not self.index_exists(index_name):
print( print(
f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir> [<dir2> ...]' to create it." f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it."
) )
return return
@@ -789,7 +528,7 @@ Examples:
if not self.index_exists(index_name): if not self.index_exists(index_name):
print( print(
f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir> [<dir2> ...]' to create it." f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it."
) )
return return

View File

@@ -6,6 +6,7 @@ Preserves all optimization parameters to ensure performance
import logging import logging
import os import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any from typing import Any
import numpy as np import numpy as np
@@ -373,9 +374,7 @@ def compute_embeddings_ollama(
texts: list[str], model_name: str, is_build: bool = False, host: str = "http://localhost:11434" texts: list[str], model_name: str, is_build: bool = False, host: str = "http://localhost:11434"
) -> np.ndarray: ) -> np.ndarray:
""" """
Compute embeddings using Ollama API with simplified batch processing. Compute embeddings using Ollama API.
Uses batch size of 32 for MPS/CPU and 128 for CUDA to optimize performance.
Args: Args:
texts: List of texts to compute embeddings for texts: List of texts to compute embeddings for
@@ -439,19 +438,12 @@ def compute_embeddings_ollama(
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5"]): if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5"]):
embedding_models.append(model) embedding_models.append(model)
# Check if model exists (handle versioned names) and resolve to full name # Check if model exists (handle versioned names)
resolved_model_name = None model_found = any(
for name in model_names: model_name == name.split(":")[0] or model_name == name for name in model_names
# Exact match )
if model_name == name:
resolved_model_name = name
break
# Match without version tag (use the versioned name)
elif model_name == name.split(":")[0]:
resolved_model_name = name
break
if not resolved_model_name: if not model_found:
error_msg = f"❌ Model '{model_name}' not found in local Ollama.\n\n" error_msg = f"❌ Model '{model_name}' not found in local Ollama.\n\n"
# Suggest pulling the model # Suggest pulling the model
@@ -473,11 +465,6 @@ def compute_embeddings_ollama(
error_msg += "\n📚 Browse more: https://ollama.com/library" error_msg += "\n📚 Browse more: https://ollama.com/library"
raise ValueError(error_msg) raise ValueError(error_msg)
# Use the resolved model name for all subsequent operations
if resolved_model_name != model_name:
logger.info(f"Resolved model name '{model_name}' to '{resolved_model_name}'")
model_name = resolved_model_name
# Verify the model supports embeddings by testing it # Verify the model supports embeddings by testing it
try: try:
test_response = requests.post( test_response = requests.post(
@@ -498,148 +485,138 @@ def compute_embeddings_ollama(
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
logger.warning(f"Could not verify model existence: {e}") logger.warning(f"Could not verify model existence: {e}")
# Determine batch size based on device availability # Process embeddings with optimized concurrent processing
# Check for CUDA/MPS availability using torch if available import requests
batch_size = 32 # Default for MPS/CPU
try:
import torch
if torch.cuda.is_available(): def get_single_embedding(text_idx_tuple):
batch_size = 128 # CUDA gets larger batch size """Helper function to get embedding for a single text."""
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): text, idx = text_idx_tuple
batch_size = 32 # MPS gets smaller batch size max_retries = 3
except ImportError: retry_count = 0
# If torch is not available, use conservative batch size
batch_size = 32
logger.info(f"Using batch size: {batch_size}") # Truncate very long texts to avoid API issues
truncated_text = text[:8000] if len(text) > 8000 else text
def get_batch_embeddings(batch_texts): while retry_count < max_retries:
"""Get embeddings for a batch of texts.""" try:
all_embeddings = [] response = requests.post(
failed_indices = [] f"{host}/api/embeddings",
json={"model": model_name, "prompt": truncated_text},
timeout=30,
)
response.raise_for_status()
for i, text in enumerate(batch_texts): result = response.json()
max_retries = 3 embedding = result.get("embedding")
retry_count = 0
# Truncate very long texts to avoid API issues if embedding is None:
truncated_text = text[:8000] if len(text) > 8000 else text raise ValueError(f"No embedding returned for text {idx}")
while retry_count < max_retries:
try: return idx, embedding
response = requests.post(
f"{host}/api/embeddings", except requests.exceptions.Timeout:
json={"model": model_name, "prompt": truncated_text}, retry_count += 1
timeout=30, if retry_count >= max_retries:
logger.warning(f"Timeout for text {idx} after {max_retries} retries")
return idx, None
except Exception as e:
if retry_count >= max_retries - 1:
logger.error(f"Failed to get embedding for text {idx}: {e}")
return idx, None
retry_count += 1
return idx, None
# Determine if we should use concurrent processing
use_concurrent = (
len(texts) > 5 and not is_build
) # Don't use concurrent in build mode to avoid overwhelming
max_workers = min(4, len(texts)) # Limit concurrent requests to avoid overwhelming Ollama
all_embeddings = [None] * len(texts) # Pre-allocate list to maintain order
failed_indices = []
if use_concurrent:
logger.info(
f"Using concurrent processing with {max_workers} workers for {len(texts)} texts"
)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit all tasks
future_to_idx = {
executor.submit(get_single_embedding, (text, idx)): idx
for idx, text in enumerate(texts)
}
# Add progress bar for concurrent processing
try:
if is_build or len(texts) > 10:
from tqdm import tqdm
futures_iterator = tqdm(
as_completed(future_to_idx),
total=len(texts),
desc="Computing Ollama embeddings",
) )
response.raise_for_status() else:
futures_iterator = as_completed(future_to_idx)
result = response.json() except ImportError:
embedding = result.get("embedding") futures_iterator = as_completed(future_to_idx)
if embedding is None:
raise ValueError(f"No embedding returned for text {i}")
if not isinstance(embedding, list) or len(embedding) == 0:
raise ValueError(f"Invalid embedding format for text {i}")
all_embeddings.append(embedding)
break
except requests.exceptions.Timeout:
retry_count += 1
if retry_count >= max_retries:
logger.warning(f"Timeout for text {i} after {max_retries} retries")
failed_indices.append(i)
all_embeddings.append(None)
break
# Collect results as they complete
for future in futures_iterator:
try:
idx, embedding = future.result()
if embedding is not None:
all_embeddings[idx] = embedding
else:
failed_indices.append(idx)
except Exception as e: except Exception as e:
retry_count += 1 idx = future_to_idx[future]
if retry_count >= max_retries: logger.error(f"Exception for text {idx}: {e}")
logger.error(f"Failed to get embedding for text {i}: {e}") failed_indices.append(idx)
failed_indices.append(i)
all_embeddings.append(None)
break
return all_embeddings, failed_indices
# Process texts in batches
all_embeddings = []
all_failed_indices = []
# Setup progress bar if needed
show_progress = is_build or len(texts) > 10
try:
if show_progress:
from tqdm import tqdm
except ImportError:
show_progress = False
# Process batches
num_batches = (len(texts) + batch_size - 1) // batch_size
if show_progress:
batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings")
else: else:
batch_iterator = range(num_batches) # Sequential processing with progress bar
show_progress = is_build or len(texts) > 10
for batch_idx in batch_iterator: try:
start_idx = batch_idx * batch_size if show_progress:
end_idx = min(start_idx + batch_size, len(texts)) from tqdm import tqdm
batch_texts = texts[start_idx:end_idx]
batch_embeddings, batch_failed = get_batch_embeddings(batch_texts) iterator = tqdm(
enumerate(texts), total=len(texts), desc="Computing Ollama embeddings"
)
else:
iterator = enumerate(texts)
except ImportError:
iterator = enumerate(texts)
# Adjust failed indices to global indices for idx, text in iterator:
global_failed = [start_idx + idx for idx in batch_failed] result_idx, embedding = get_single_embedding((text, idx))
all_failed_indices.extend(global_failed) if embedding is not None:
all_embeddings.extend(batch_embeddings) all_embeddings[idx] = embedding
else:
failed_indices.append(idx)
# Handle failed embeddings # Handle failed embeddings
if all_failed_indices: if failed_indices:
if len(all_failed_indices) == len(texts): if len(failed_indices) == len(texts):
raise RuntimeError("Failed to compute any embeddings") raise RuntimeError("Failed to compute any embeddings")
logger.warning( logger.warning(f"Failed to compute embeddings for {len(failed_indices)}/{len(texts)} texts")
f"Failed to compute embeddings for {len(all_failed_indices)}/{len(texts)} texts"
)
# Use zero embeddings as fallback for failed ones # Use zero embeddings as fallback for failed ones
valid_embedding = next((e for e in all_embeddings if e is not None), None) valid_embedding = next((e for e in all_embeddings if e is not None), None)
if valid_embedding: if valid_embedding:
embedding_dim = len(valid_embedding) embedding_dim = len(valid_embedding)
for i, embedding in enumerate(all_embeddings): for idx in failed_indices:
if embedding is None: all_embeddings[idx] = [0.0] * embedding_dim
all_embeddings[i] = [0.0] * embedding_dim
# Remove None values # Remove None values and convert to numpy array
all_embeddings = [e for e in all_embeddings if e is not None] all_embeddings = [e for e in all_embeddings if e is not None]
if not all_embeddings:
raise RuntimeError("No valid embeddings were computed")
# Validate embedding dimensions
expected_dim = len(all_embeddings[0])
inconsistent_dims = []
for i, embedding in enumerate(all_embeddings):
if len(embedding) != expected_dim:
inconsistent_dims.append((i, len(embedding)))
if inconsistent_dims:
error_msg = f"Ollama returned inconsistent embedding dimensions. Expected {expected_dim}, but got:\n"
for idx, dim in inconsistent_dims[:10]: # Show first 10 inconsistent ones
error_msg += f" - Text {idx}: {dim} dimensions\n"
if len(inconsistent_dims) > 10:
error_msg += f" ... and {len(inconsistent_dims) - 10} more\n"
error_msg += f"\nThis is likely an Ollama API bug with model '{model_name}'. Please try:\n"
error_msg += "1. Restart Ollama service: 'ollama serve'\n"
error_msg += f"2. Re-pull the model: 'ollama pull {model_name}'\n"
error_msg += (
"3. Use sentence-transformers instead: --embedding-mode sentence-transformers\n"
)
error_msg += "4. Report this issue to Ollama: https://github.com/ollama/ollama/issues"
raise ValueError(error_msg)
# Convert to numpy array and normalize # Convert to numpy array and normalize
embeddings = np.array(all_embeddings, dtype=np.float32) embeddings = np.array(all_embeddings, dtype=np.float32)

View File

@@ -1,6 +1,7 @@
import atexit import atexit
import logging import logging
import os import os
import signal
import socket import socket
import subprocess import subprocess
import sys import sys
@@ -8,7 +9,7 @@ import time
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
# Lightweight, self-contained server manager with no cross-process inspection import psutil
# Set up logging based on environment variable # Set up logging based on environment variable
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper() LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
@@ -43,7 +44,130 @@ def _check_port(port: int) -> bool:
return s.connect_ex(("localhost", port)) == 0 return s.connect_ex(("localhost", port)) == 0
# Note: All cross-process scanning helpers removed for simplicity def _check_process_matches_config(
port: int, expected_model: str, expected_passages_file: str
) -> bool:
"""
Check if the process using the port matches our expected model and passages file.
Returns True if matches, False otherwise.
"""
try:
for proc in psutil.process_iter(["pid", "cmdline"]):
if not _is_process_listening_on_port(proc, port):
continue
cmdline = proc.info["cmdline"]
if not cmdline:
continue
return _check_cmdline_matches_config(
cmdline, port, expected_model, expected_passages_file
)
logger.debug(f"No process found listening on port {port}")
return False
except Exception as e:
logger.warning(f"Could not check process on port {port}: {e}")
return False
def _is_process_listening_on_port(proc, port: int) -> bool:
"""Check if a process is listening on the given port."""
try:
connections = proc.net_connections()
for conn in connections:
if conn.laddr.port == port and conn.status == psutil.CONN_LISTEN:
return True
return False
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
return False
def _check_cmdline_matches_config(
cmdline: list, port: int, expected_model: str, expected_passages_file: str
) -> bool:
"""Check if command line matches our expected configuration."""
cmdline_str = " ".join(cmdline)
logger.debug(f"Found process on port {port}: {cmdline_str}")
# Check if it's our embedding server
is_embedding_server = any(
server_type in cmdline_str
for server_type in [
"embedding_server",
"leann_backend_diskann.embedding_server",
"leann_backend_hnsw.hnsw_embedding_server",
]
)
if not is_embedding_server:
logger.debug(f"Process on port {port} is not our embedding server")
return False
# Check model name
model_matches = _check_model_in_cmdline(cmdline, expected_model)
# Check passages file if provided
passages_matches = _check_passages_in_cmdline(cmdline, expected_passages_file)
result = model_matches and passages_matches
logger.debug(
f"model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}"
)
return result
def _check_model_in_cmdline(cmdline: list, expected_model: str) -> bool:
"""Check if the command line contains the expected model."""
if "--model-name" not in cmdline:
return False
model_idx = cmdline.index("--model-name")
if model_idx + 1 >= len(cmdline):
return False
actual_model = cmdline[model_idx + 1]
return actual_model == expected_model
def _check_passages_in_cmdline(cmdline: list, expected_passages_file: str) -> bool:
"""Check if the command line contains the expected passages file."""
if "--passages-file" not in cmdline:
return False # Expected but not found
passages_idx = cmdline.index("--passages-file")
if passages_idx + 1 >= len(cmdline):
return False
actual_passages = cmdline[passages_idx + 1]
expected_path = Path(expected_passages_file).resolve()
actual_path = Path(actual_passages).resolve()
return actual_path == expected_path
def _find_compatible_port_or_next_available(
start_port: int, model_name: str, passages_file: str, max_attempts: int = 100
) -> tuple[int, bool]:
"""
Find a port that either has a compatible server or is available.
Returns (port, is_compatible) where is_compatible indicates if we found a matching server.
"""
for port in range(start_port, start_port + max_attempts):
if not _check_port(port):
# Port is available
return port, False
# Port is in use, check if it's compatible
if _check_process_matches_config(port, model_name, passages_file):
logger.info(f"Found compatible server on port {port}")
return port, True
else:
logger.info(f"Port {port} has incompatible server, trying next port...")
raise RuntimeError(
f"Could not find compatible or available port in range {start_port}-{start_port + max_attempts}"
)
class EmbeddingServerManager: class EmbeddingServerManager:
@@ -62,16 +186,7 @@ class EmbeddingServerManager:
self.backend_module_name = backend_module_name self.backend_module_name = backend_module_name
self.server_process: Optional[subprocess.Popen] = None self.server_process: Optional[subprocess.Popen] = None
self.server_port: Optional[int] = None self.server_port: Optional[int] = None
# Track last-started config for in-process reuse only
self._server_config: Optional[dict] = None
self._atexit_registered = False self._atexit_registered = False
# Also register a weakref finalizer to ensure cleanup when manager is GC'ed
try:
import weakref
self._finalizer = weakref.finalize(self, self._finalize_process)
except Exception:
self._finalizer = None
def start_server( def start_server(
self, self,
@@ -81,24 +196,26 @@ class EmbeddingServerManager:
**kwargs, **kwargs,
) -> tuple[bool, int]: ) -> tuple[bool, int]:
"""Start the embedding server.""" """Start the embedding server."""
# passages_file may be present in kwargs for server CLI, but we don't need it here passages_file = kwargs.get("passages_file")
# If this manager already has a live server, just reuse it # Check if we have a compatible server already running
if self.server_process and self.server_process.poll() is None and self.server_port: if self._has_compatible_running_server(model_name, passages_file):
logger.info("Reusing in-process server") logger.info("Found compatible running server!")
return True, self.server_port return True, port
# For Colab environment, use a different strategy # For Colab environment, use a different strategy
if _is_colab_environment(): if _is_colab_environment():
logger.info("Detected Colab environment, using alternative startup strategy") logger.info("Detected Colab environment, using alternative startup strategy")
return self._start_server_colab(port, model_name, embedding_mode, **kwargs) return self._start_server_colab(port, model_name, embedding_mode, **kwargs)
# Always pick a fresh available port # Find a compatible port or next available
try: actual_port, is_compatible = _find_compatible_port_or_next_available(
actual_port = _get_available_port(port) port, model_name, passages_file
except RuntimeError: )
logger.error("No available ports found")
return False, port if is_compatible:
logger.info(f"Found compatible server on port {actual_port}")
return True, actual_port
# Start a new server # Start a new server
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs) return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
@@ -131,7 +248,17 @@ class EmbeddingServerManager:
logger.error(f"Failed to start embedding server in Colab: {e}") logger.error(f"Failed to start embedding server in Colab: {e}")
return False, actual_port return False, actual_port
# Note: No compatibility check needed; manager is per-searcher and configs are stable per instance def _has_compatible_running_server(self, model_name: str, passages_file: str) -> bool:
"""Check if we have a compatible running server."""
if not (self.server_process and self.server_process.poll() is None and self.server_port):
return False
if _check_process_matches_config(self.server_port, model_name, passages_file):
logger.info(f"Existing server process (PID {self.server_process.pid}) is compatible")
return True
logger.info("Existing server process is incompatible. Should start a new server.")
return False
def _start_new_server( def _start_new_server(
self, port: int, model_name: str, embedding_mode: str, **kwargs self, port: int, model_name: str, embedding_mode: str, **kwargs
@@ -178,61 +305,23 @@ class EmbeddingServerManager:
project_root = Path(__file__).parent.parent.parent.parent.parent project_root = Path(__file__).parent.parent.parent.parent.parent
logger.info(f"Command: {' '.join(command)}") logger.info(f"Command: {' '.join(command)}")
# In CI environment, redirect stdout to avoid buffer deadlock but keep stderr for debugging # Let server output go directly to console
# Embedding servers use many print statements that can fill stdout buffers # The server will respect LEANN_LOG_LEVEL environment variable
is_ci = os.environ.get("CI") == "true"
if is_ci:
stdout_target = subprocess.DEVNULL
stderr_target = None # Keep stderr for error debugging in CI
logger.info(
"CI environment detected, redirecting embedding server stdout to DEVNULL, keeping stderr"
)
else:
stdout_target = None # Direct to console for visible logs
stderr_target = None # Direct to console for visible logs
# Start embedding server subprocess
self.server_process = subprocess.Popen( self.server_process = subprocess.Popen(
command, command,
cwd=project_root, cwd=project_root,
stdout=stdout_target, stdout=None, # Direct to console
stderr=stderr_target, stderr=None, # Direct to console
start_new_session=True, # Create new process group for better cleanup
) )
self.server_port = port self.server_port = port
# Record config for in-process reuse
try:
self._server_config = {
"model_name": command[command.index("--model-name") + 1]
if "--model-name" in command
else "",
"passages_file": command[command.index("--passages-file") + 1]
if "--passages-file" in command
else "",
"embedding_mode": command[command.index("--embedding-mode") + 1]
if "--embedding-mode" in command
else "sentence-transformers",
}
except Exception:
self._server_config = {
"model_name": "",
"passages_file": "",
"embedding_mode": "sentence-transformers",
}
logger.info(f"Server process started with PID: {self.server_process.pid}") logger.info(f"Server process started with PID: {self.server_process.pid}")
# Register atexit callback only when we actually start a process # Register atexit callback only when we actually start a process
if not self._atexit_registered: if not self._atexit_registered:
# Always attempt best-effort finalize at interpreter exit # Use a lambda to avoid issues with bound methods
atexit.register(self._finalize_process) atexit.register(lambda: self.stop_server() if self.server_process else None)
self._atexit_registered = True self._atexit_registered = True
# Touch finalizer so it knows there is a live process
if getattr(self, "_finalizer", None) is not None and not self._finalizer.alive:
try:
import weakref
self._finalizer = weakref.finalize(self, self._finalize_process)
except Exception:
pass
def _wait_for_server_ready(self, port: int) -> tuple[bool, int]: def _wait_for_server_ready(self, port: int) -> tuple[bool, int]:
"""Wait for the server to be ready.""" """Wait for the server to be ready."""
@@ -257,35 +346,37 @@ class EmbeddingServerManager:
if not self.server_process: if not self.server_process:
return return
if self.server_process and self.server_process.poll() is not None: if self.server_process.poll() is not None:
# Process already terminated # Process already terminated
self.server_process = None self.server_process = None
self.server_port = None
self._server_config = None
return return
logger.info( logger.info(
f"Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..." f"Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
) )
# Use simple termination first; if the server installed signal handlers, # Try terminating the whole process group first
# it will exit cleanly. Otherwise escalate to kill after a short wait.
try: try:
self.server_process.terminate() pgid = os.getpgid(self.server_process.pid)
os.killpg(pgid, signal.SIGTERM)
except Exception: except Exception:
pass # Fallback to terminating just the process
self.server_process.terminate()
try: try:
self.server_process.wait(timeout=5) # Give more time for graceful shutdown self.server_process.wait(timeout=3)
logger.info(f"Server process {self.server_process.pid} terminated gracefully.") logger.info(f"Server process {self.server_process.pid} terminated.")
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
logger.warning( logger.warning(
f"Server process {self.server_process.pid} did not terminate within 5 seconds, force killing..." f"Server process {self.server_process.pid} did not terminate gracefully within 3 seconds, killing it."
) )
# Try killing the whole process group
try: try:
self.server_process.kill() pgid = os.getpgid(self.server_process.pid)
os.killpg(pgid, signal.SIGKILL)
except Exception: except Exception:
pass # Fallback to killing just the process
self.server_process.kill()
try: try:
self.server_process.wait(timeout=2) self.server_process.wait(timeout=2)
logger.info(f"Server process {self.server_process.pid} killed successfully.") logger.info(f"Server process {self.server_process.pid} killed successfully.")
@@ -293,33 +384,20 @@ class EmbeddingServerManager:
logger.error( logger.error(
f"Failed to kill server process {self.server_process.pid} - it may be hung" f"Failed to kill server process {self.server_process.pid} - it may be hung"
) )
# Don't hang indefinitely
# Clean up process resources with timeout to avoid CI hang # Clean up process resources to prevent resource tracker warnings
try: try:
# Use shorter timeout in CI environments self.server_process.wait(timeout=1) # Give it one final chance with timeout
is_ci = os.environ.get("CI") == "true"
timeout = 3 if is_ci else 10
self.server_process.wait(timeout=timeout)
logger.info(f"Server process {self.server_process.pid} cleanup completed")
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
logger.warning(f"Process cleanup timeout after {timeout}s, proceeding anyway") logger.warning(
except Exception as e: f"Process {self.server_process.pid} still hanging after all kill attempts"
logger.warning(f"Error during process cleanup: {e}") )
finally: # Don't wait indefinitely - just abandon it
self.server_process = None
self.server_port = None
self._server_config = None
def _finalize_process(self) -> None:
"""Best-effort cleanup used by weakref.finalize/atexit."""
try:
self.stop_server()
except Exception: except Exception:
pass pass
def _adopt_existing_server(self, *args, **kwargs) -> None: self.server_process = None
# Removed: cross-process adoption no longer supported
return
def _launch_server_process_colab(self, command: list, port: int) -> None: def _launch_server_process_colab(self, command: list, port: int) -> None:
"""Launch the server process with Colab-specific settings.""" """Launch the server process with Colab-specific settings."""
@@ -335,16 +413,10 @@ class EmbeddingServerManager:
self.server_port = port self.server_port = port
logger.info(f"Colab server process started with PID: {self.server_process.pid}") logger.info(f"Colab server process started with PID: {self.server_process.pid}")
# Register atexit callback (unified) # Register atexit callback
if not self._atexit_registered: if not self._atexit_registered:
atexit.register(self._finalize_process) atexit.register(lambda: self.stop_server() if self.server_process else None)
self._atexit_registered = True self._atexit_registered = True
# Record config for in-process reuse is best-effort in Colab mode
self._server_config = {
"model_name": "",
"passages_file": "",
"embedding_mode": "sentence-transformers",
}
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]: def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
"""Wait for the server to be ready with Colab-specific timeout.""" """Wait for the server to be ready with Colab-specific timeout."""

View File

@@ -25,61 +25,32 @@ def handle_request(request):
"tools": [ "tools": [
{ {
"name": "leann_search", "name": "leann_search",
"description": """🔍 Search code using natural language - like having a coding assistant who knows your entire codebase! "description": "Search LEANN index",
🎯 **Perfect for**:
- "How does authentication work?" → finds auth-related code
- "Error handling patterns" → locates try-catch blocks and error logic
- "Database connection setup" → finds DB initialization code
- "API endpoint definitions" → locates route handlers
- "Configuration management" → finds config files and usage
💡 **Pro tip**: Use this before making any changes to understand existing patterns and conventions.""",
"inputSchema": { "inputSchema": {
"type": "object", "type": "object",
"properties": { "properties": {
"index_name": { "index_name": {"type": "string"},
"type": "string", "query": {"type": "string"},
"description": "Name of the LEANN index to search. Use 'leann_list' first to see available indexes.", "top_k": {"type": "integer", "default": 5},
},
"query": {
"type": "string",
"description": "Search query - can be natural language (e.g., 'how to handle errors') or technical terms (e.g., 'async function definition')",
},
"top_k": {
"type": "integer",
"default": 5,
"minimum": 1,
"maximum": 20,
"description": "Number of search results to return. Use 5-10 for focused results, 15-20 for comprehensive exploration.",
},
"complexity": {
"type": "integer",
"default": 32,
"minimum": 16,
"maximum": 128,
"description": "Search complexity level. Use 16-32 for fast searches (recommended), 64+ for higher precision when needed.",
},
}, },
"required": ["index_name", "query"], "required": ["index_name", "query"],
}, },
}, },
{ {
"name": "leann_status", "name": "leann_ask",
"description": "📊 Check the health and stats of your code indexes - like a medical checkup for your codebase knowledge!", "description": "Ask question using LEANN RAG",
"inputSchema": { "inputSchema": {
"type": "object", "type": "object",
"properties": { "properties": {
"index_name": { "index_name": {"type": "string"},
"type": "string", "question": {"type": "string"},
"description": "Optional: Name of specific index to check. If not provided, shows status of all indexes.",
}
}, },
"required": ["index_name", "question"],
}, },
}, },
{ {
"name": "leann_list", "name": "leann_list",
"description": "📋 Show all your indexed codebases - your personal code library! Use this to see what's available for search.", "description": "List all LEANN indexes",
"inputSchema": {"type": "object", "properties": {}}, "inputSchema": {"type": "object", "properties": {}},
}, },
] ]
@@ -92,40 +63,19 @@ def handle_request(request):
try: try:
if tool_name == "leann_search": if tool_name == "leann_search":
# Validate required parameters
if not args.get("index_name") or not args.get("query"):
return {
"jsonrpc": "2.0",
"id": request.get("id"),
"result": {
"content": [
{
"type": "text",
"text": "Error: Both index_name and query are required",
}
]
},
}
# Build simplified command
cmd = [ cmd = [
"leann", "leann",
"search", "search",
args["index_name"], args["index_name"],
args["query"], args["query"],
"--recompute-embeddings",
f"--top-k={args.get('top_k', 5)}", f"--top-k={args.get('top_k', 5)}",
f"--complexity={args.get('complexity', 32)}",
] ]
result = subprocess.run(cmd, capture_output=True, text=True) result = subprocess.run(cmd, capture_output=True, text=True)
elif tool_name == "leann_status": elif tool_name == "leann_ask":
if args.get("index_name"): cmd = f'echo "{args["question"]}" | leann ask {args["index_name"]} --recompute-embeddings --llm ollama --model qwen3:8b'
# Check specific index status - for now, we'll use leann list and filter result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
result = subprocess.run(["leann", "list"], capture_output=True, text=True)
# We could enhance this to show more detailed status per index
else:
# Show all indexes status
result = subprocess.run(["leann", "list"], capture_output=True, text=True)
elif tool_name == "leann_list": elif tool_name == "leann_list":
result = subprocess.run(["leann", "list"], capture_output=True, text=True) result = subprocess.run(["leann", "list"], capture_output=True, text=True)

View File

@@ -132,10 +132,15 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
import msgpack import msgpack
import zmq import zmq
context = None
socket = None
try: try:
context = zmq.Context() context = zmq.Context()
socket = context.socket(zmq.REQ) socket = context.socket(zmq.REQ)
socket.setsockopt(zmq.RCVTIMEO, 30000) # 30 second timeout socket.setsockopt(zmq.LINGER, 0) # Don't block on close
socket.setsockopt(zmq.RCVTIMEO, 300000)
socket.setsockopt(zmq.SNDTIMEO, 300000)
socket.setsockopt(zmq.IMMEDIATE, 1)
socket.connect(f"tcp://localhost:{zmq_port}") socket.connect(f"tcp://localhost:{zmq_port}")
# Send embedding request # Send embedding request
@@ -147,9 +152,6 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
response_bytes = socket.recv() response_bytes = socket.recv()
response = msgpack.unpackb(response_bytes) response = msgpack.unpackb(response_bytes)
socket.close()
context.term()
# Convert response to numpy array # Convert response to numpy array
if isinstance(response, list) and len(response) > 0: if isinstance(response, list) and len(response) > 0:
return np.array(response, dtype=np.float32) return np.array(response, dtype=np.float32)
@@ -158,6 +160,10 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to compute embeddings via server: {e}") raise RuntimeError(f"Failed to compute embeddings via server: {e}")
finally:
if socket:
socket.close()
# Don't call context.term() - this was causing hangs
@abstractmethod @abstractmethod
def search( def search(
@@ -191,7 +197,15 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
""" """
pass pass
def __del__(self): def cleanup(self):
"""Ensures the embedding server is stopped when the searcher is destroyed.""" """Cleanup resources including embedding server."""
if hasattr(self, "embedding_server_manager"): if hasattr(self, "embedding_server_manager"):
self.embedding_server_manager.stop_server() self.embedding_server_manager.stop_server()
def __del__(self):
"""Ensures resources are cleaned up when the searcher is destroyed."""
try:
self.cleanup()
except Exception:
# Ignore errors during destruction
pass

View File

@@ -4,12 +4,20 @@ Transform your development workflow with intelligent code assistance using LEANN
## Prerequisites ## Prerequisites
Install LEANN globally for MCP integration (with default backend): **Step 1:** First, complete the basic LEANN installation following the [📦 Installation guide](../../README.md#installation) in the root README:
```bash ```bash
uv tool install leann-core --with leann uv venv
source .venv/bin/activate
uv pip install leann
``` ```
This installs the `leann` CLI into an isolated tool environment and includes both backends so `leann build` works out-of-the-box.
**Step 2:** Install LEANN globally for MCP integration:
```bash
uv tool install leann-core
```
This makes the `leann` command available system-wide, which `leann_mcp` requires.
## 🚀 Quick Setup ## 🚀 Quick Setup
@@ -37,42 +45,6 @@ leann build my-project --docs ./
claude claude
``` ```
## 🚀 Advanced Usage Examples
### Index Entire Git Repository
```bash
# Index all tracked files in your git repository, note right now we will skip submodules, but we can add it back easily if you want
leann build my-repo --docs $(git ls-files) --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
# Index only specific file types from git
leann build my-python-code --docs $(git ls-files "*.py") --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
```
### Multiple Directories and Files
```bash
# Index multiple directories
leann build my-codebase --docs ./src ./tests ./docs ./config --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
# Mix files and directories
leann build my-project --docs ./README.md ./src/ ./package.json ./docs/ --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
# Specific files only
leann build my-configs --docs ./tsconfig.json ./package.json ./webpack.config.js --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
```
### Advanced Git Integration
```bash
# Index recently modified files
leann build recent-changes --docs $(git diff --name-only HEAD~10..HEAD) --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
# Index files matching pattern
leann build frontend --docs $(git ls-files "*.tsx" "*.ts" "*.jsx" "*.js") --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
# Index documentation and config files
leann build docs-and-configs --docs $(git ls-files "*.md" "*.yml" "*.yaml" "*.json" "*.toml") --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
```
**Try this in Claude Code:** **Try this in Claude Code:**
``` ```
Help me understand this codebase. List available indexes and search for authentication patterns. Help me understand this codebase. List available indexes and search for authentication patterns.

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "leann" name = "leann"
version = "0.2.9" version = "0.2.5"
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!" description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
readme = "README.md" readme = "README.md"
requires-python = ">=3.9" requires-python = ">=3.9"

View File

@@ -40,13 +40,10 @@ dependencies = [
# Other dependencies # Other dependencies
"ipykernel==6.29.5", "ipykernel==6.29.5",
"msgpack>=1.1.1", "msgpack>=1.1.1",
"mlx>=0.26.3; sys_platform == 'darwin' and platform_machine == 'arm64'", "mlx>=0.26.3; sys_platform == 'darwin'",
"mlx-lm>=0.26.0; sys_platform == 'darwin' and platform_machine == 'arm64'", "mlx-lm>=0.26.0; sys_platform == 'darwin'",
"psutil>=5.8.0", "psutil>=5.8.0",
"pybind11>=3.0.0", "pybind11>=3.0.0",
"pathspec>=0.12.1",
"nbconvert>=7.16.6",
"gitignore-parser>=0.1.12",
] ]
[project.optional-dependencies] [project.optional-dependencies]
@@ -63,7 +60,7 @@ dev = [
test = [ test = [
"pytest>=7.0", "pytest>=7.0",
"pytest-timeout>=2.0", "pytest-timeout>=2.0", # Simple timeout protection for CI
"llama-index-core>=0.12.0", "llama-index-core>=0.12.0",
"llama-index-readers-file>=0.4.0", "llama-index-readers-file>=0.4.0",
"python-dotenv>=1.0.0", "python-dotenv>=1.0.0",
@@ -155,7 +152,7 @@ markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')", "slow: marks tests as slow (deselect with '-m \"not slow\"')",
"openai: marks tests that require OpenAI API key", "openai: marks tests that require OpenAI API key",
] ]
timeout = 300 # Reduced from 600s (10min) to 300s (5min) for CI safety timeout = 300 # Simple timeout for CI safety (5 minutes)
addopts = [ addopts = [
"-v", "-v",
"--tb=short", "--tb=short",

View File

@@ -1,76 +0,0 @@
name: leann-build
resources:
# Choose a GPU for fast embeddings (examples: L4, A10G, A100). CPU also works but is slower.
accelerators: L4:1
# Optionally pin a cloud, otherwise SkyPilot will auto-select
# cloud: aws
disk_size: 100
envs:
# Build parameters (override with: sky launch -c leann-gpu sky/leann-build.yaml -e key=value)
index_name: my-index
docs: ./data
backend: hnsw # hnsw | diskann
complexity: 64
graph_degree: 32
num_threads: 8
# Embedding selection
embedding_mode: sentence-transformers # sentence-transformers | openai | mlx | ollama
embedding_model: facebook/contriever
# Storage/latency knobs
recompute: true # true => selective recomputation (recommended)
compact: true # for HNSW only
# Optional pass-through
extra_args: ""
# Rebuild control
force: true
# Sync local paths to the remote VM. Adjust as needed.
file_mounts:
# Example: mount your local data directory used for building
~/leann-data: ${docs}
setup: |
set -e
# Install uv (package manager)
curl -LsSf https://astral.sh/uv/install.sh | sh
export PATH="$HOME/.local/bin:$PATH"
# Ensure modern libstdc++ for FAISS (GLIBCXX >= 3.4.30)
sudo apt-get update -y
sudo apt-get install -y libstdc++6 libgomp1
# Also upgrade conda's libstdc++ in base env (Skypilot images include conda)
if command -v conda >/dev/null 2>&1; then
conda install -y -n base -c conda-forge libstdcxx-ng
fi
# Install LEANN CLI and backends into the user environment
uv pip install --upgrade pip
uv pip install leann-core leann-backend-hnsw leann-backend-diskann
run: |
export PATH="$HOME/.local/bin:$PATH"
# Derive flags from env
recompute_flag=""
if [ "${recompute}" = "false" ] || [ "${recompute}" = "0" ]; then
recompute_flag="--no-recompute"
fi
force_flag=""
if [ "${force}" = "true" ] || [ "${force}" = "1" ]; then
force_flag="--force"
fi
# Build command
python -m leann.cli build ${index_name} \
--docs ~/leann-data \
--backend ${backend} \
--complexity ${complexity} \
--graph-degree ${graph_degree} \
--num-threads ${num_threads} \
--embedding-mode ${embedding_mode} \
--embedding-model ${embedding_model} \
${recompute_flag} ${force_flag} ${extra_args}
# Print where the index is stored for downstream rsync
echo "INDEX_OUT_DIR=~/.leann/indexes/${index_name}"

41
tests/conftest.py Normal file
View File

@@ -0,0 +1,41 @@
"""Pytest configuration and fixtures for LEANN tests."""
import os
import pytest
@pytest.fixture(autouse=True)
def test_environment():
"""Set up test environment variables."""
# Mark as test environment to skip memory-intensive operations
os.environ["CI"] = "true"
yield
@pytest.fixture(scope="session", autouse=True)
def cleanup_session():
"""Session-level cleanup to ensure no hanging processes."""
yield
# Basic cleanup after all tests
try:
import os
import psutil
current_process = psutil.Process(os.getpid())
children = current_process.children(recursive=True)
for child in children:
try:
child.terminate()
except psutil.NoSuchProcess:
pass
# Give them time to terminate gracefully
psutil.wait_procs(children, timeout=3)
except Exception:
# Don't fail tests due to cleanup errors
pass

View File

@@ -64,9 +64,6 @@ def test_backend_basic(backend_name):
assert isinstance(results[0], SearchResult) assert isinstance(results[0], SearchResult)
assert "topic 2" in results[0].text or "document" in results[0].text assert "topic 2" in results[0].text or "document" in results[0].text
# Ensure cleanup to avoid hanging background servers
searcher.cleanup()
@pytest.mark.skipif( @pytest.mark.skipif(
os.environ.get("CI") == "true", reason="Skip model tests in CI to avoid MPS memory issues" os.environ.get("CI") == "true", reason="Skip model tests in CI to avoid MPS memory issues"
@@ -93,5 +90,3 @@ def test_large_index():
searcher = LeannSearcher(index_path) searcher = LeannSearcher(index_path)
results = searcher.search(["word10 word20"], top_k=10) results = searcher.search(["word10 word20"], top_k=10)
assert len(results[0]) == 10 assert len(results[0]) == 10
# Cleanup
searcher.cleanup()

View File

@@ -58,9 +58,6 @@ def test_document_rag_simulated(test_data_dir):
@pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OpenAI API key not available") @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OpenAI API key not available")
@pytest.mark.skipif(
os.environ.get("CI") == "true", reason="Skip OpenAI tests in CI to avoid API costs"
)
def test_document_rag_openai(test_data_dir): def test_document_rag_openai(test_data_dir):
"""Test document_rag with OpenAI embeddings.""" """Test document_rag with OpenAI embeddings."""
with tempfile.TemporaryDirectory() as temp_dir: with tempfile.TemporaryDirectory() as temp_dir:

View File

@@ -16,9 +16,6 @@ def test_readme_basic_example(backend_name):
# Skip on macOS CI due to MPS environment issues with all-MiniLM-L6-v2 # Skip on macOS CI due to MPS environment issues with all-MiniLM-L6-v2
if os.environ.get("CI") == "true" and platform.system() == "Darwin": if os.environ.get("CI") == "true" and platform.system() == "Darwin":
pytest.skip("Skipping on macOS CI due to MPS environment issues with all-MiniLM-L6-v2") pytest.skip("Skipping on macOS CI due to MPS environment issues with all-MiniLM-L6-v2")
# Skip DiskANN on CI (Linux runners) due to C++ extension memory/hardware constraints
if os.environ.get("CI") == "true" and backend_name == "diskann":
pytest.skip("Skip DiskANN tests in CI due to resource constraints and instability")
# This is the exact code from README (with smaller model for CI) # This is the exact code from README (with smaller model for CI)
from leann import LeannBuilder, LeannChat, LeannSearcher from leann import LeannBuilder, LeannChat, LeannSearcher
@@ -62,9 +59,6 @@ def test_readme_basic_example(backend_name):
# The second text about banana-crocodile should be more relevant # The second text about banana-crocodile should be more relevant
assert "banana" in results[0].text or "crocodile" in results[0].text assert "banana" in results[0].text or "crocodile" in results[0].text
# Ensure we cleanup background embedding server
searcher.cleanup()
# Chat with your data (using simulated LLM to avoid external dependencies) # Chat with your data (using simulated LLM to avoid external dependencies)
chat = LeannChat(INDEX_PATH, llm_config={"type": "simulated"}) chat = LeannChat(INDEX_PATH, llm_config={"type": "simulated"})
response = chat.ask("How much storage does LEANN save?", top_k=1) response = chat.ask("How much storage does LEANN save?", top_k=1)
@@ -72,8 +66,6 @@ def test_readme_basic_example(backend_name):
# Verify chat works # Verify chat works
assert isinstance(response, str) assert isinstance(response, str)
assert len(response) > 0 assert len(response) > 0
# Cleanup chat resources
chat.cleanup()
def test_readme_imports(): def test_readme_imports():

7338
uv.lock generated
View File

File diff suppressed because it is too large Load Diff