Compare commits
94 Commits
v0.1.15
...
feature/gr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d9e5d5d6aa | ||
|
|
239e35e2e6 | ||
|
|
2fac0c6fbf | ||
|
|
9801aa581b | ||
|
|
5e97916608 | ||
|
|
8b9c2be8c9 | ||
|
|
a437f558a3 | ||
|
|
742c9baabc | ||
|
|
60eef4b440 | ||
|
|
f2c5355c73 | ||
|
|
439debbd3f | ||
|
|
3ff5aac8e0 | ||
|
|
a35bfb0354 | ||
|
|
a6dad47280 | ||
|
|
67fef60466 | ||
|
|
131f10b286 | ||
|
|
e3762458fc | ||
|
|
b6ab6f1993 | ||
|
|
9f2e82a838 | ||
|
|
05e1efa00a | ||
|
|
6363fc5f83 | ||
|
|
319dc34a24 | ||
|
|
72a5993f02 | ||
|
|
250272a3be | ||
|
|
042da1fe09 | ||
|
|
2d9c183ebb | ||
|
|
0b2b799d5a | ||
|
|
0f790fbbd9 | ||
|
|
387ae21eba | ||
|
|
3cc329c3e7 | ||
|
|
a8421c0475 | ||
|
|
0ec00e1a60 | ||
|
|
777b5fed01 | ||
|
|
440ad6e816 | ||
|
|
5567302316 | ||
|
|
8714472cd8 | ||
|
|
075d4bd167 | ||
|
|
e4bcc76f88 | ||
|
|
710e83b1fd | ||
|
|
c799d61a5a | ||
|
|
c96d653072 | ||
|
|
e409933149 | ||
|
|
bc31876a9f | ||
|
|
e421c44b8b | ||
|
|
af69aa0508 | ||
|
|
575b354976 | ||
|
|
65bbff1d93 | ||
|
|
df798d350d | ||
|
|
3fa6b2aa17 | ||
|
|
ba95554fe7 | ||
|
|
677eb0bae3 | ||
|
|
9cdfcec331 | ||
|
|
f30d1a2530 | ||
|
|
df69a49123 | ||
|
|
65b54ff905 | ||
|
|
4db3e94f35 | ||
|
|
a2568f3ddc | ||
|
|
45bdad4fa7 | ||
|
|
8b538d1ef9 | ||
|
|
ada8bcbc70 | ||
|
|
6061e8f2de | ||
|
|
9842ad8330 | ||
|
|
7d920f9071 | ||
|
|
f28f15000c | ||
|
|
1d657fd9f6 | ||
|
|
d217adbe40 | ||
|
|
f790ec634f | ||
|
|
b8da9d7b12 | ||
|
|
0cb0463929 | ||
|
|
b982241249 | ||
|
|
c66f197e1d | ||
|
|
4a1353761a | ||
|
|
a72090d2ab | ||
|
|
669e622430 | ||
|
|
77d7b60a61 | ||
|
|
8b22d2b5d3 | ||
|
|
4cb544ee38 | ||
|
|
f94ce63d51 | ||
|
|
4271ff9d84 | ||
|
|
0d448c4a41 | ||
|
|
af5599e33c | ||
|
|
efdf6d917a | ||
|
|
dd71ac8d71 | ||
|
|
8bee1d4100 | ||
|
|
33521d6d00 | ||
|
|
8899734952 | ||
|
|
54df6310c5 | ||
|
|
19bcc07814 | ||
|
|
8356e3c668 | ||
|
|
08eac5c821 | ||
|
|
4671ed9b36 | ||
|
|
055c086398 | ||
|
|
d505dcc5e3 | ||
|
|
261006c36a |
9
.github/workflows/build-and-publish.yml
vendored
9
.github/workflows/build-and-publish.yml
vendored
@@ -5,7 +5,16 @@ on:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
debug_enabled:
|
||||
type: boolean
|
||||
description: 'Run with tmate debugging enabled (SSH access to runner)'
|
||||
required: false
|
||||
default: false
|
||||
|
||||
jobs:
|
||||
build:
|
||||
uses: ./.github/workflows/build-reusable.yml
|
||||
with:
|
||||
debug_enabled: ${{ github.event_name == 'workflow_dispatch' && inputs.debug_enabled || false }}
|
||||
|
||||
298
.github/workflows/build-reusable.yml
vendored
298
.github/workflows/build-reusable.yml
vendored
@@ -8,6 +8,11 @@ on:
|
||||
required: false
|
||||
type: string
|
||||
default: ''
|
||||
debug_enabled:
|
||||
description: 'Enable tmate debugging session for troubleshooting'
|
||||
required: false
|
||||
type: boolean
|
||||
default: false
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
@@ -28,7 +33,7 @@ jobs:
|
||||
|
||||
- name: Install ruff
|
||||
run: |
|
||||
uv tool install ruff
|
||||
uv tool install ruff==0.12.7
|
||||
|
||||
- name: Run ruff check
|
||||
run: |
|
||||
@@ -97,7 +102,8 @@ jobs:
|
||||
- name: Install system dependencies (macOS)
|
||||
if: runner.os == 'macOS'
|
||||
run: |
|
||||
brew install llvm libomp boost protobuf zeromq
|
||||
# Don't install LLVM, use system clang for better compatibility
|
||||
brew install libomp boost protobuf zeromq
|
||||
|
||||
- name: Install build dependencies
|
||||
run: |
|
||||
@@ -110,17 +116,19 @@ jobs:
|
||||
|
||||
- name: Build packages
|
||||
run: |
|
||||
# Build core (platform independent)
|
||||
if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
|
||||
cd packages/leann-core
|
||||
uv build
|
||||
cd ../..
|
||||
fi
|
||||
# Build core (platform independent) on all platforms for consistency
|
||||
cd packages/leann-core
|
||||
uv build
|
||||
cd ../..
|
||||
|
||||
# Build HNSW backend
|
||||
cd packages/leann-backend-hnsw
|
||||
if [ "${{ matrix.os }}" == "macos-latest" ]; then
|
||||
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv build --wheel --python python
|
||||
# Use system clang instead of homebrew LLVM for better compatibility
|
||||
export CC=clang
|
||||
export CXX=clang++
|
||||
export MACOSX_DEPLOYMENT_TARGET=11.0
|
||||
uv build --wheel --python python
|
||||
else
|
||||
uv build --wheel --python python
|
||||
fi
|
||||
@@ -129,18 +137,21 @@ jobs:
|
||||
# Build DiskANN backend
|
||||
cd packages/leann-backend-diskann
|
||||
if [ "${{ matrix.os }}" == "macos-latest" ]; then
|
||||
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv build --wheel --python python
|
||||
# Use system clang instead of homebrew LLVM for better compatibility
|
||||
export CC=clang
|
||||
export CXX=clang++
|
||||
# sgesdd_ is only available on macOS 13.3+
|
||||
export MACOSX_DEPLOYMENT_TARGET=13.3
|
||||
uv build --wheel --python python
|
||||
else
|
||||
uv build --wheel --python python
|
||||
fi
|
||||
cd ../..
|
||||
|
||||
# Build meta package (platform independent)
|
||||
if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
|
||||
cd packages/leann
|
||||
uv build
|
||||
cd ../..
|
||||
fi
|
||||
# Build meta package (platform independent) on all platforms
|
||||
cd packages/leann
|
||||
uv build
|
||||
cd ../..
|
||||
|
||||
- name: Repair wheels (Linux)
|
||||
if: runner.os == 'Linux'
|
||||
@@ -154,10 +165,15 @@ jobs:
|
||||
fi
|
||||
cd ../..
|
||||
|
||||
# Repair DiskANN wheel
|
||||
# Repair DiskANN wheel - use show first to debug
|
||||
cd packages/leann-backend-diskann
|
||||
if [ -d dist ]; then
|
||||
echo "Checking DiskANN wheel contents before repair:"
|
||||
unzip -l dist/*.whl | grep -E "\.so|\.pyd|_diskannpy" || echo "No .so files found"
|
||||
auditwheel show dist/*.whl || echo "auditwheel show failed"
|
||||
auditwheel repair dist/*.whl -w dist_repaired
|
||||
echo "Checking DiskANN wheel contents after repair:"
|
||||
unzip -l dist_repaired/*.whl | grep -E "\.so|\.pyd|_diskannpy" || echo "No .so files found after repair"
|
||||
rm -rf dist
|
||||
mv dist_repaired dist
|
||||
fi
|
||||
@@ -189,6 +205,254 @@ jobs:
|
||||
echo "📦 Built packages:"
|
||||
find packages/*/dist -name "*.whl" -o -name "*.tar.gz" | sort
|
||||
|
||||
- name: Install built packages for testing
|
||||
run: |
|
||||
# Create a virtual environment with the correct Python version
|
||||
uv venv --python python${{ matrix.python }}
|
||||
source .venv/bin/activate || source .venv/Scripts/activate
|
||||
|
||||
# Install the built wheels directly to ensure we use locally built packages
|
||||
# Use only locally built wheels on all platforms for full consistency
|
||||
FIND_LINKS="--find-links packages/leann-core/dist --find-links packages/leann/dist"
|
||||
FIND_LINKS="$FIND_LINKS --find-links packages/leann-backend-hnsw/dist --find-links packages/leann-backend-diskann/dist"
|
||||
|
||||
uv pip install leann-core leann leann-backend-hnsw leann-backend-diskann \
|
||||
$FIND_LINKS --force-reinstall
|
||||
|
||||
# Install test dependencies using extras
|
||||
uv pip install -e ".[test]"
|
||||
|
||||
# Debug: Check if _diskannpy module is installed correctly
|
||||
echo "Checking installed DiskANN module structure:"
|
||||
python -c "import leann_backend_diskann; print('leann_backend_diskann location:', leann_backend_diskann.__file__)" || echo "Failed to import leann_backend_diskann"
|
||||
python -c "from leann_backend_diskann import _diskannpy; print('_diskannpy imported successfully')" || echo "Failed to import _diskannpy"
|
||||
ls -la $(python -c "import leann_backend_diskann; import os; print(os.path.dirname(leann_backend_diskann.__file__))" 2>/dev/null) 2>/dev/null || echo "Failed to list module directory"
|
||||
|
||||
# Extra debugging for Python 3.13
|
||||
if [[ "${{ matrix.python }}" == "3.13" ]]; then
|
||||
echo "=== Python 3.13 Debug Info ==="
|
||||
echo "Python version details:"
|
||||
python --version
|
||||
python -c "import sys; print(f'sys.version_info: {sys.version_info}')"
|
||||
|
||||
echo "Pytest version:"
|
||||
python -m pytest --version
|
||||
|
||||
echo "Testing basic pytest collection:"
|
||||
if [[ "$RUNNER_OS" == "Linux" ]]; then
|
||||
timeout --signal=INT 10 python -m pytest --collect-only tests/test_ci_minimal.py -v || echo "Collection timed out or failed"
|
||||
else
|
||||
# No timeout on macOS/Windows
|
||||
python -m pytest --collect-only tests/test_ci_minimal.py -v || echo "Collection failed"
|
||||
fi
|
||||
|
||||
echo "Testing single simple test:"
|
||||
if [[ "$RUNNER_OS" == "Linux" ]]; then
|
||||
timeout --signal=INT 10 python -m pytest tests/test_ci_minimal.py::test_package_imports --full-trace -v || echo "Simple test timed out or failed"
|
||||
else
|
||||
# No timeout on macOS/Windows
|
||||
python -m pytest tests/test_ci_minimal.py::test_package_imports --full-trace -v || echo "Simple test failed"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Enable tmate debugging session if requested
|
||||
- name: Setup tmate session for debugging
|
||||
if: ${{ inputs.debug_enabled }}
|
||||
uses: mxschmitt/action-tmate@v3
|
||||
with:
|
||||
detached: true
|
||||
timeout-minutes: 30
|
||||
limit-access-to-actor: true
|
||||
|
||||
- name: Run tests with pytest
|
||||
# Timeout hierarchy:
|
||||
# 1. Individual test timeout: 20s (see pyproject.toml markers)
|
||||
# 2. Pytest session timeout: 300s (see pyproject.toml [tool.pytest.ini_options])
|
||||
# 3. Outer shell timeout: 360s (300s + 60s buffer for cleanup)
|
||||
# 4. GitHub Actions job timeout: 6 hours (default)
|
||||
env:
|
||||
CI: true # Mark as CI environment to skip memory-intensive tests
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
HF_HUB_DISABLE_SYMLINKS: 1
|
||||
TOKENIZERS_PARALLELISM: false
|
||||
PYTORCH_ENABLE_MPS_FALLBACK: 0 # Disable MPS on macOS CI to avoid memory issues
|
||||
OMP_NUM_THREADS: 1 # Disable OpenMP parallelism to avoid libomp crashes
|
||||
MKL_NUM_THREADS: 1 # Single thread for MKL operations
|
||||
run: |
|
||||
# Activate virtual environment
|
||||
source .venv/bin/activate || source .venv/Scripts/activate
|
||||
|
||||
# Define comprehensive diagnostic function
|
||||
diag() {
|
||||
echo "===== COMPREHENSIVE DIAGNOSTICS BEGIN ====="
|
||||
date
|
||||
echo ""
|
||||
echo "### Current Shell Info ###"
|
||||
echo "Shell PID: $$"
|
||||
echo "Shell PPID: $PPID"
|
||||
echo "Current directory: $(pwd)"
|
||||
echo ""
|
||||
|
||||
echo "### Process Tree (full) ###"
|
||||
pstree -ap 2>/dev/null || ps auxf || true
|
||||
echo ""
|
||||
|
||||
echo "### All Python/Pytest Processes ###"
|
||||
ps -ef | grep -E 'python|pytest' | grep -v grep || true
|
||||
echo ""
|
||||
|
||||
echo "### Embedding Server Processes ###"
|
||||
ps -ef | grep -E 'embedding|zmq|diskann' | grep -v grep || true
|
||||
echo ""
|
||||
|
||||
echo "### Network Listeners ###"
|
||||
ss -ltnp 2>/dev/null || netstat -ltn 2>/dev/null || true
|
||||
echo ""
|
||||
|
||||
echo "### Open File Descriptors (lsof) ###"
|
||||
lsof -p $$ 2>/dev/null | head -20 || true
|
||||
echo ""
|
||||
|
||||
echo "### Zombie Processes ###"
|
||||
ps aux | grep '<defunct>' || echo "No zombie processes"
|
||||
echo ""
|
||||
|
||||
echo "### Current Jobs ###"
|
||||
jobs -l || true
|
||||
echo ""
|
||||
|
||||
echo "### /proc/PID/fd for current shell ###"
|
||||
ls -la /proc/$$/fd 2>/dev/null || true
|
||||
echo ""
|
||||
|
||||
echo "===== COMPREHENSIVE DIAGNOSTICS END ====="
|
||||
}
|
||||
|
||||
# Enable verbose logging for debugging
|
||||
export PYTHONUNBUFFERED=1
|
||||
export PYTEST_CURRENT_TEST=1
|
||||
|
||||
# Run all tests with extensive logging
|
||||
if [[ "$RUNNER_OS" == "Linux" ]]; then
|
||||
echo "🚀 Starting Linux test execution with timeout..."
|
||||
echo "Current time: $(date)"
|
||||
echo "Shell PID: $$"
|
||||
echo "Python: $(python --version)"
|
||||
echo "Pytest: $(pytest --version)"
|
||||
|
||||
# Show environment variables for debugging
|
||||
echo "📦 Environment variables:"
|
||||
env | grep -E "PYTHON|PYTEST|CI|RUNNER" | sort
|
||||
|
||||
# Set trap for diagnostics
|
||||
trap diag INT TERM EXIT
|
||||
|
||||
echo "📋 Pre-test diagnostics:"
|
||||
ps -ef | grep -E 'python|pytest' | grep -v grep || echo "No python/pytest processes before test"
|
||||
|
||||
# Check for any listening ports before test
|
||||
echo "🔌 Pre-test network state:"
|
||||
ss -ltn 2>/dev/null | grep -E "555[0-9]|556[0-9]" || echo "No embedding server ports open"
|
||||
|
||||
# Set timeouts - outer must be larger than pytest's internal timeout
|
||||
# IMPORTANT: Keep PYTEST_TIMEOUT_SEC in sync with pyproject.toml [tool.pytest.ini_options] timeout
|
||||
PYTEST_TIMEOUT_SEC=${PYTEST_TIMEOUT_SEC:-300} # Default 300s, matches pyproject.toml
|
||||
BUFFER_SEC=${TIMEOUT_BUFFER_SEC:-60} # Buffer for cleanup after pytest timeout
|
||||
OUTER_TIMEOUT_SEC=${OUTER_TIMEOUT_SEC:-$((PYTEST_TIMEOUT_SEC + BUFFER_SEC))}
|
||||
|
||||
echo "⏰ Timeout configuration:"
|
||||
echo " - Pytest internal timeout: ${PYTEST_TIMEOUT_SEC}s (from pyproject.toml)"
|
||||
echo " - Cleanup buffer: ${BUFFER_SEC}s"
|
||||
echo " - Outer shell timeout: ${OUTER_TIMEOUT_SEC}s (${PYTEST_TIMEOUT_SEC}s + ${BUFFER_SEC}s buffer)"
|
||||
echo " - This ensures pytest can complete its own timeout handling and cleanup"
|
||||
|
||||
echo "🏃 Running pytest with ${OUTER_TIMEOUT_SEC}s outer timeout..."
|
||||
|
||||
# Export for inner shell
|
||||
export PYTEST_TIMEOUT_SEC OUTER_TIMEOUT_SEC BUFFER_SEC
|
||||
|
||||
timeout --preserve-status --signal=INT --kill-after=10 ${OUTER_TIMEOUT_SEC} bash -c '
|
||||
echo "⏱️ Pytest starting at: $(date)"
|
||||
echo "Running command: pytest tests/ -vv --maxfail=3 --tb=short --capture=no"
|
||||
|
||||
# Run pytest with maximum verbosity and no output capture
|
||||
pytest tests/ -vv --maxfail=3 --tb=short --capture=no --log-cli-level=DEBUG 2>&1 | tee pytest.log
|
||||
PYTEST_EXIT=${PIPESTATUS[0]}
|
||||
|
||||
echo "✅ Pytest finished at: $(date) with exit code: $PYTEST_EXIT"
|
||||
echo "Last 20 lines of pytest output:"
|
||||
tail -20 pytest.log || true
|
||||
|
||||
# Immediately check for leftover processes
|
||||
echo "🔍 Post-pytest process check:"
|
||||
ps -ef | grep -E "python|pytest|embedding" | grep -v grep || echo "No leftover processes"
|
||||
|
||||
# Clean up any children before exit
|
||||
echo "🧹 Cleaning up child processes..."
|
||||
pkill -TERM -P $$ 2>/dev/null || true
|
||||
sleep 0.5
|
||||
pkill -KILL -P $$ 2>/dev/null || true
|
||||
|
||||
echo "📊 Final check before exit:"
|
||||
ps -ef | grep -E "python|pytest|embedding" | grep -v grep || echo "All clean"
|
||||
|
||||
exit $PYTEST_EXIT
|
||||
'
|
||||
|
||||
EXIT_CODE=$?
|
||||
echo "🔚 Timeout command exited with code: $EXIT_CODE"
|
||||
|
||||
if [ $EXIT_CODE -eq 124 ]; then
|
||||
echo "⚠️ TIMEOUT TRIGGERED - Tests took more than ${OUTER_TIMEOUT_SEC} seconds!"
|
||||
echo "📸 Capturing full diagnostics..."
|
||||
diag
|
||||
|
||||
# Run diagnostic script if available
|
||||
if [ -f scripts/diagnose_hang.sh ]; then
|
||||
echo "🔍 Running diagnostic script..."
|
||||
bash scripts/diagnose_hang.sh || true
|
||||
fi
|
||||
|
||||
# More aggressive cleanup
|
||||
echo "💀 Killing all Python processes owned by runner..."
|
||||
pkill -9 -u runner python || true
|
||||
pkill -9 -u runner pytest || true
|
||||
elif [ $EXIT_CODE -ne 0 ]; then
|
||||
echo "❌ Tests failed with exit code: $EXIT_CODE"
|
||||
else
|
||||
echo "✅ All tests passed!"
|
||||
fi
|
||||
|
||||
# Always show final state
|
||||
echo "📍 Final state check:"
|
||||
ps -ef | grep -E 'python|pytest|embedding' | grep -v grep || echo "No Python processes remaining"
|
||||
|
||||
exit $EXIT_CODE
|
||||
else
|
||||
# For macOS/Windows, run without GNU timeout
|
||||
echo "🚀 Running tests on $RUNNER_OS..."
|
||||
pytest tests/ -vv --maxfail=3 --tb=short --capture=no --log-cli-level=INFO
|
||||
fi
|
||||
|
||||
# Provide tmate session on test failure for debugging
|
||||
- name: Setup tmate session on failure
|
||||
if: ${{ failure() && (inputs.debug_enabled || contains(github.event.head_commit.message, '[debug]')) }}
|
||||
uses: mxschmitt/action-tmate@v3
|
||||
with:
|
||||
timeout-minutes: 30
|
||||
limit-access-to-actor: true
|
||||
|
||||
- name: Run sanity checks (optional)
|
||||
run: |
|
||||
# Activate virtual environment
|
||||
source .venv/bin/activate || source .venv/Scripts/activate
|
||||
|
||||
# Run distance function tests if available
|
||||
if [ -f test/sanity_checks/test_distance_functions.py ]; then
|
||||
echo "Running distance function sanity checks..."
|
||||
python test/sanity_checks/test_distance_functions.py || echo "⚠️ Distance function test failed, continuing..."
|
||||
fi
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
|
||||
19
.github/workflows/link-check.yml
vendored
Normal file
19
.github/workflows/link-check.yml
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
name: Link Check
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master ]
|
||||
pull_request:
|
||||
schedule:
|
||||
- cron: "0 3 * * 1"
|
||||
|
||||
jobs:
|
||||
link-check:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: lycheeverse/lychee-action@v2
|
||||
with:
|
||||
args: --no-progress --insecure README.md docs/ apps/ examples/ benchmarks/
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
16
.gitignore
vendored
16
.gitignore
vendored
@@ -34,11 +34,15 @@ build/
|
||||
nprobe_logs/
|
||||
micro/results
|
||||
micro/contriever-INT8
|
||||
examples/data/*
|
||||
!examples/data/2501.14312v1 (1).pdf
|
||||
!examples/data/2506.08276v1.pdf
|
||||
!examples/data/PrideandPrejudice.txt
|
||||
!examples/data/README.md
|
||||
data/*
|
||||
!data/2501.14312v1 (1).pdf
|
||||
!data/2506.08276v1.pdf
|
||||
!data/PrideandPrejudice.txt
|
||||
!data/huawei_pangu.md
|
||||
!data/ground_truth/
|
||||
!data/indices/
|
||||
!data/queries/
|
||||
!data/.gitattributes
|
||||
*.qdstrm
|
||||
benchmark_results/
|
||||
results/
|
||||
@@ -86,3 +90,5 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/
|
||||
*.passages.json
|
||||
|
||||
batchtest.py
|
||||
tests/__pytest_cache__/
|
||||
tests/__pycache__/
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.5.0
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
- id: end-of-file-fixer
|
||||
@@ -10,7 +10,7 @@ repos:
|
||||
- id: debug-statements
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.2.1
|
||||
rev: v0.12.7 # Fixed version to match pyproject.toml
|
||||
hooks:
|
||||
- id: ruff
|
||||
- id: ruff-format
|
||||
|
||||
297
README.md
297
README.md
@@ -6,6 +6,7 @@
|
||||
<img src="https://img.shields.io/badge/Python-3.9%2B-blue.svg" alt="Python 3.9+">
|
||||
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
|
||||
<img src="https://img.shields.io/badge/Platform-Linux%20%7C%20macOS-lightgrey" alt="Platform">
|
||||
<img src="https://img.shields.io/badge/MCP-Native%20Integration-blue?style=flat-square" alt="MCP Integration">
|
||||
</p>
|
||||
|
||||
<h2 align="center" tabindex="-1" class="heading-element" dir="auto">
|
||||
@@ -16,7 +17,10 @@ LEANN is an innovative vector database that democratizes personal AI. Transform
|
||||
|
||||
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
|
||||
|
||||
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)**, or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
||||
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)**, **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
||||
|
||||
|
||||
\* Claude Code only supports basic `grep`-style keyword search. **LEANN** is a drop-in **semantic search MCP service fully compatible with Claude Code**, unlocking intelligent retrieval without changing your workflow. 🔥 Check out [the easy setup →](packages/leann-mcp/README.md)
|
||||
|
||||
|
||||
|
||||
@@ -26,7 +30,7 @@ LEANN achieves this through *graph-based selective recomputation* with *high-deg
|
||||
<img src="assets/effects.png" alt="LEANN vs Traditional Vector DB Storage Comparison" width="70%">
|
||||
</p>
|
||||
|
||||
> **The numbers speak for themselves:** Index 60 million Wikipedia chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#storage-comparison)
|
||||
> **The numbers speak for themselves:** Index 60 million text chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#storage-comparison)
|
||||
|
||||
|
||||
🔒 **Privacy:** Your data never leaves your laptop. No OpenAI, no cloud, no "terms of service".
|
||||
@@ -41,40 +45,40 @@ LEANN achieves this through *graph-based selective recomputation* with *high-deg
|
||||
|
||||
## Installation
|
||||
|
||||
<details>
|
||||
<summary><strong>📦 Prerequisites: Install uv (if you don't have it)</strong></summary>
|
||||
### 📦 Prerequisites: Install uv
|
||||
|
||||
Install uv first if you don't have it:
|
||||
[Install uv](https://docs.astral.sh/uv/getting-started/installation/#installation-methods) first if you don't have it. Typically, you can install it with:
|
||||
|
||||
```bash
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
```
|
||||
|
||||
📖 [Detailed uv installation methods →](https://docs.astral.sh/uv/getting-started/installation/#installation-methods)
|
||||
### 🚀 Quick Install
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
LEANN provides two installation methods: **pip install** (quick and easy) and **build from source** (recommended for development).
|
||||
|
||||
|
||||
|
||||
### 🚀 Quick Install (Recommended for most users)
|
||||
|
||||
Clone the repository to access all examples and install LEANN from [PyPI](https://pypi.org/project/leann/) to run them immediately:
|
||||
Clone the repository to access all examples and try amazing applications,
|
||||
|
||||
```bash
|
||||
git clone git@github.com:yichuan-w/LEANN.git leann
|
||||
git clone https://github.com/yichuan-w/LEANN.git leann
|
||||
cd leann
|
||||
```
|
||||
|
||||
and install LEANN from [PyPI](https://pypi.org/project/leann/) to run them immediately:
|
||||
|
||||
```bash
|
||||
uv venv
|
||||
source .venv/bin/activate
|
||||
uv pip install leann
|
||||
```
|
||||
|
||||
### 🔧 Build from Source (Recommended for development)
|
||||
<details>
|
||||
<summary>
|
||||
<strong>🔧 Build from Source (Recommended for development)</strong>
|
||||
</summary>
|
||||
|
||||
|
||||
|
||||
```bash
|
||||
git clone git@github.com:yichuan-w/LEANN.git leann
|
||||
git clone https://github.com/yichuan-w/LEANN.git leann
|
||||
cd leann
|
||||
git submodule update --init --recursive
|
||||
```
|
||||
@@ -91,17 +95,17 @@ sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev l
|
||||
uv sync
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
## Quick Star
|
||||
## Quick Start
|
||||
|
||||
Our declarative API makes RAG as easy as writing a config file.
|
||||
|
||||
[](https://colab.research.google.com/github/yichuan-w/LEANN/blob/main/demo.ipynb) [Try in this ipynb file →](demo.ipynb)
|
||||
Check out [demo.ipynb](demo.ipynb) or [](https://colab.research.google.com/github/yichuan-w/LEANN/blob/main/demo.ipynb)
|
||||
|
||||
```python
|
||||
from leann import LeannBuilder, LeannSearcher, LeannCha
|
||||
from leann import LeannBuilder, LeannSearcher, LeannChat
|
||||
from pathlib import Path
|
||||
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
|
||||
|
||||
@@ -122,11 +126,11 @@ response = chat.ask("How much storage does LEANN save?", top_k=1)
|
||||
|
||||
## RAG on Everything!
|
||||
|
||||
LEANN supports RAG on various data sources including documents (.pdf, .txt, .md), Apple Mail, Google Search History, WeChat, and more.
|
||||
LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`, `.md`), Apple Mail, Google Search History, WeChat, and more.
|
||||
|
||||
### Generation Model Setup
|
||||
|
||||
> **Generation Model Setup**
|
||||
> LEANN supports multiple LLM providers for text generation (OpenAI API, HuggingFace, Ollama).
|
||||
LEANN supports multiple LLM providers for text generation (OpenAI API, HuggingFace, Ollama).
|
||||
|
||||
<details>
|
||||
<summary><strong>🔑 OpenAI API Setup (Default)</strong></summary>
|
||||
@@ -166,7 +170,52 @@ ollama pull llama3.2:1b
|
||||
|
||||
</details>
|
||||
|
||||
### 📄 Personal Data Manager: Process Any Documents (.pdf, .txt, .md)!
|
||||
### ⭐ Flexible Configuration
|
||||
|
||||
LEANN provides flexible parameters for embedding models, search strategies, and data processing to fit your specific needs.
|
||||
|
||||
📚 **Need configuration best practices?** Check our [Configuration Guide](docs/configuration-guide.md) for detailed optimization tips, model selection advice, and solutions to common issues like slow embeddings or poor search quality.
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Common Parameters (Available in All Examples)</strong></summary>
|
||||
|
||||
All RAG examples share these common parameters. **Interactive mode** is available in all examples - simply run without `--query` to start a continuous Q&A session where you can ask multiple questions. Type 'quit' to exit.
|
||||
|
||||
```bash
|
||||
# Core Parameters (General preprocessing for all examples)
|
||||
--index-dir DIR # Directory to store the index (default: current directory)
|
||||
--query "YOUR QUESTION" # Single query mode. Omit for interactive chat (type 'quit' to exit), and now you can play with your index interactively
|
||||
--max-items N # Limit data preprocessing (default: -1, process all data)
|
||||
--force-rebuild # Force rebuild index even if it exists
|
||||
|
||||
# Embedding Parameters
|
||||
--embedding-model MODEL # e.g., facebook/contriever, text-embedding-3-small, nomic-embed-text, mlx-community/Qwen3-Embedding-0.6B-8bit or nomic-embed-text
|
||||
--embedding-mode MODE # sentence-transformers, openai, mlx, or ollama
|
||||
|
||||
# LLM Parameters (Text generation models)
|
||||
--llm TYPE # LLM backend: openai, ollama, or hf (default: openai)
|
||||
--llm-model MODEL # Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct
|
||||
--thinking-budget LEVEL # Thinking budget for reasoning models: low/medium/high (supported by o3, o3-mini, GPT-Oss:20b, and other reasoning models)
|
||||
|
||||
# Search Parameters
|
||||
--top-k N # Number of results to retrieve (default: 20)
|
||||
--search-complexity N # Search complexity for graph traversal (default: 32)
|
||||
|
||||
# Chunking Parameters
|
||||
--chunk-size N # Size of text chunks (default varies by source: 256 for most, 192 for WeChat)
|
||||
--chunk-overlap N # Overlap between chunks (default varies: 25-128 depending on source)
|
||||
|
||||
# Index Building Parameters
|
||||
--backend-name NAME # Backend to use: hnsw or diskann (default: hnsw)
|
||||
--graph-degree N # Graph degree for index construction (default: 32)
|
||||
--build-complexity N # Build complexity for index construction (default: 64)
|
||||
--no-compact # Disable compact index storage (compact storage IS enabled to save storage by default)
|
||||
--no-recompute # Disable embedding recomputation (recomputation IS enabled to save storage by default)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
### 📄 Personal Data Manager: Process Any Documents (`.pdf`, `.txt`, `.md`)!
|
||||
|
||||
Ask questions directly about your personal PDFs, documents, and any directory containing your files!
|
||||
|
||||
@@ -174,15 +223,32 @@ Ask questions directly about your personal PDFs, documents, and any directory co
|
||||
<img src="videos/paper_clear.gif" alt="LEANN Document Search Demo" width="600">
|
||||
</p>
|
||||
|
||||
The example below asks a question about summarizing two papers (uses default data in `examples/data`):
|
||||
The example below asks a question about summarizing our paper (uses default data in `data/`, which is a directory with diverse data sources: two papers, Pride and Prejudice, and a Technical report about LLM in Huawei in Chinese), and this is the **easiest example** to run here:
|
||||
|
||||
```
|
||||
# Or use python directly
|
||||
source .venv/bin/activate
|
||||
python ./examples/main_cli_example.py
|
||||
```bash
|
||||
source .venv/bin/activate # Don't forget to activate the virtual environment
|
||||
python -m apps.document_rag --query "What are the main techniques LEANN explores?"
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Document-Specific Arguments</strong></summary>
|
||||
|
||||
#### Parameters
|
||||
```bash
|
||||
--data-dir DIR # Directory containing documents to process (default: data)
|
||||
--file-types .ext .ext # Filter by specific file types (optional - all LlamaIndex supported types if omitted)
|
||||
```
|
||||
|
||||
#### Example Commands
|
||||
```bash
|
||||
# Process all documents with larger chunks for academic papers
|
||||
python -m apps.document_rag --data-dir "~/Documents/Papers" --chunk-size 1024
|
||||
|
||||
# Filter only markdown and Python files with smaller chunks
|
||||
python -m apps.document_rag --data-dir "./docs" --chunk-size 256 --file-types .md .py
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
### 📧 Your Personal Email Secretary: RAG on Apple Mail!
|
||||
|
||||
@@ -193,30 +259,29 @@ python ./examples/main_cli_example.py
|
||||
<img src="videos/mail_clear.gif" alt="LEANN Email Search Demo" width="600">
|
||||
</p>
|
||||
|
||||
**Note:** You need to grant full disk access to your terminal/VS Code in System Preferences → Privacy & Security → Full Disk Access.
|
||||
Before running the example below, you need to grant full disk access to your terminal/VS Code in System Preferences → Privacy & Security → Full Disk Access.
|
||||
|
||||
```bash
|
||||
python examples/mail_reader_leann.py --query "What's the food I ordered by doordash or Uber eat mostly?"
|
||||
python -m apps.email_rag --query "What's the food I ordered by DoorDash or Uber Eats mostly?"
|
||||
```
|
||||
**780K email chunks → 78MB storage** Finally, search your email like you search Google.
|
||||
**780K email chunks → 78MB storage.** Finally, search your email like you search Google.
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
|
||||
<summary><strong>📋 Click to expand: Email-Specific Arguments</strong></summary>
|
||||
|
||||
#### Parameters
|
||||
```bash
|
||||
# Use default mail path (works for most macOS setups)
|
||||
python examples/mail_reader_leann.py
|
||||
--mail-path PATH # Path to specific mail directory (auto-detects if omitted)
|
||||
--include-html # Include HTML content in processing (useful for newsletters)
|
||||
```
|
||||
|
||||
# Run with custom index directory
|
||||
python examples/mail_reader_leann.py --index-dir "./my_mail_index"
|
||||
#### Example Commands
|
||||
```bash
|
||||
# Search work emails from a specific account
|
||||
python -m apps.email_rag --mail-path "~/Library/Mail/V10/WORK_ACCOUNT"
|
||||
|
||||
# Process all emails (may take time but indexes everything)
|
||||
python examples/mail_reader_leann.py --max-emails -1
|
||||
|
||||
# Limit number of emails processed (useful for testing)
|
||||
python examples/mail_reader_leann.py --max-emails 1000
|
||||
|
||||
# Run a single query
|
||||
python examples/mail_reader_leann.py --query "What did my boss say about deadlines?"
|
||||
# Find all receipts and order confirmations (includes HTML)
|
||||
python -m apps.email_rag --query "receipt order confirmation invoice" --include-html
|
||||
```
|
||||
|
||||
</details>
|
||||
@@ -237,25 +302,25 @@ Once the index is built, you can ask questions like:
|
||||
</p>
|
||||
|
||||
```bash
|
||||
python examples/google_history_reader_leann.py --query "Tell me my browser history about machine learning?"
|
||||
python -m apps.browser_rag --query "Tell me my browser history about machine learning?"
|
||||
```
|
||||
**38K browser entries → 6MB storage.** Your browser history becomes your personal search engine.
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
|
||||
<summary><strong>📋 Click to expand: Browser-Specific Arguments</strong></summary>
|
||||
|
||||
#### Parameters
|
||||
```bash
|
||||
# Use default Chrome profile (auto-finds all profiles)
|
||||
python examples/google_history_reader_leann.py
|
||||
--chrome-profile PATH # Path to Chrome profile directory (auto-detects if omitted)
|
||||
```
|
||||
|
||||
# Run with custom index directory
|
||||
python examples/google_history_reader_leann.py --index-dir "./my_chrome_index"
|
||||
#### Example Commands
|
||||
```bash
|
||||
# Search academic research from your browsing history
|
||||
python -m apps.browser_rag --query "arxiv papers machine learning transformer architecture"
|
||||
|
||||
# Limit number of history entries processed (useful for testing)
|
||||
python examples/google_history_reader_leann.py --max-entries 500
|
||||
|
||||
# Run a single query
|
||||
python examples/google_history_reader_leann.py --query "What websites did I visit about machine learning?"
|
||||
# Track competitor analysis across work profile
|
||||
python -m apps.browser_rag --chrome-profile "~/Library/Application Support/Google/Chrome/Work Profile" --max-items 5000
|
||||
```
|
||||
|
||||
</details>
|
||||
@@ -268,7 +333,7 @@ The default Chrome profile path is configured for a typical macOS setup. If you
|
||||
1. Open Terminal
|
||||
2. Run: `ls ~/Library/Application\ Support/Google/Chrome/`
|
||||
3. Look for folders like "Default", "Profile 1", "Profile 2", etc.
|
||||
4. Use the full path as your `--chrome-profile` argumen
|
||||
4. Use the full path as your `--chrome-profile` argument
|
||||
|
||||
**Common Chrome profile locations:**
|
||||
- macOS: `~/Library/Application Support/Google/Chrome/Default`
|
||||
@@ -295,7 +360,7 @@ Once the index is built, you can ask questions like:
|
||||
</p>
|
||||
|
||||
```bash
|
||||
python examples/wechat_history_reader_leann.py --query "Show me all group chats about weekend plans"
|
||||
python -m apps.wechat_rag --query "Show me all group chats about weekend plans"
|
||||
```
|
||||
**400K messages → 64MB storage** Search years of chat history in any language.
|
||||
|
||||
@@ -303,7 +368,13 @@ python examples/wechat_history_reader_leann.py --query "Show me all group chats
|
||||
<details>
|
||||
<summary><strong>🔧 Click to expand: Installation Requirements</strong></summary>
|
||||
|
||||
First, you need to install the WeChat exporter:
|
||||
First, you need to install the [WeChat exporter](https://github.com/sunnyyoung/WeChatTweak-CLI),
|
||||
|
||||
```bash
|
||||
brew install sunnyyoung/repo/wechattweak-cli
|
||||
```
|
||||
|
||||
or install it manually (if you have issues with Homebrew):
|
||||
|
||||
```bash
|
||||
sudo packages/wechat-exporter/wechattweak-cli install
|
||||
@@ -311,31 +382,29 @@ sudo packages/wechat-exporter/wechattweak-cli install
|
||||
|
||||
**Troubleshooting:**
|
||||
- **Installation issues**: Check the [WeChatTweak-CLI issues page](https://github.com/sunnyyoung/WeChatTweak-CLI/issues/41)
|
||||
- **Export errors**: If you encounter the error below, try restarting WeCha
|
||||
```
|
||||
Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.
|
||||
Failed to find or export WeChat data. Exiting.
|
||||
```
|
||||
- **Export errors**: If you encounter the error below, try restarting WeChat
|
||||
```bash
|
||||
Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.
|
||||
Failed to find or export WeChat data. Exiting.
|
||||
```
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
|
||||
<summary><strong>📋 Click to expand: WeChat-Specific Arguments</strong></summary>
|
||||
|
||||
#### Parameters
|
||||
```bash
|
||||
# Use default settings (recommended for first run)
|
||||
python examples/wechat_history_reader_leann.py
|
||||
--export-dir DIR # Directory to store exported WeChat data (default: wechat_export_direct)
|
||||
--force-export # Force re-export even if data exists
|
||||
```
|
||||
|
||||
# Run with custom export directory and wehn we run the first time, LEANN will export all chat history automatically for you
|
||||
python examples/wechat_history_reader_leann.py --export-dir "./my_wechat_exports"
|
||||
#### Example Commands
|
||||
```bash
|
||||
# Search for travel plans discussed in group chats
|
||||
python -m apps.wechat_rag --query "travel plans" --max-items 10000
|
||||
|
||||
# Run with custom index directory
|
||||
python examples/wechat_history_reader_leann.py --index-dir "./my_wechat_index"
|
||||
|
||||
# Limit number of chat entries processed (useful for testing)
|
||||
python examples/wechat_history_reader_leann.py --max-entries 1000
|
||||
|
||||
# Run a single query
|
||||
python examples/wechat_history_reader_leann.py --query "Show me conversations about travel plans"
|
||||
# Re-export and search recent chats (useful after new messages)
|
||||
python -m apps.wechat_rag --force-export --query "work schedule"
|
||||
```
|
||||
|
||||
</details>
|
||||
@@ -349,15 +418,57 @@ Once the index is built, you can ask questions like:
|
||||
|
||||
</details>
|
||||
|
||||
### 🚀 Claude Code Integration: Transform Your Development Workflow!
|
||||
|
||||
**The future of code assistance is here.** Transform your development workflow with LEANN's native MCP integration for Claude Code. Index your entire codebase and get intelligent code assistance directly in your IDE.
|
||||
|
||||
**Key features:**
|
||||
- 🔍 **Semantic code search** across your entire project
|
||||
- 📚 **Context-aware assistance** for debugging and development
|
||||
- 🚀 **Zero-config setup** with automatic language detection
|
||||
|
||||
```bash
|
||||
# Install LEANN globally for MCP integration
|
||||
uv tool install leann-core
|
||||
|
||||
# Setup is automatic - just start using Claude Code!
|
||||
```
|
||||
Try our fully agentic pipeline with auto query rewriting, semantic search planning, and more:
|
||||
|
||||

|
||||
|
||||
**Ready to supercharge your coding?** [Complete Setup Guide →](packages/leann-mcp/README.md)
|
||||
|
||||
## 🖥️ Command Line Interface
|
||||
|
||||
LEANN includes a powerful CLI for document processing and search. Perfect for quick document indexing and interactive chat.
|
||||
|
||||
### Installation
|
||||
|
||||
If you followed the Quick Start, `leann` is already installed in your virtual environment:
|
||||
```bash
|
||||
# Build an index from documents
|
||||
leann build my-docs --docs ./documents
|
||||
source .venv/bin/activate
|
||||
leann --help
|
||||
```
|
||||
|
||||
**To make it globally available:**
|
||||
```bash
|
||||
# Install the LEANN CLI globally using uv tool
|
||||
uv tool install leann-core
|
||||
|
||||
# Now you can use leann from anywhere without activating venv
|
||||
leann --help
|
||||
```
|
||||
|
||||
> **Note**: Global installation is required for Claude Code integration. The `leann_mcp` server depends on the globally available `leann` command.
|
||||
|
||||
|
||||
|
||||
### Usage Examples
|
||||
|
||||
```bash
|
||||
# build from a specific directory, and my_docs is the index name
|
||||
leann build my-docs --docs ./your_documents
|
||||
|
||||
# Search your documents
|
||||
leann search my-docs "machine learning concepts"
|
||||
@@ -366,7 +477,7 @@ leann search my-docs "machine learning concepts"
|
||||
leann ask my-docs --interactive
|
||||
|
||||
# List all your indexes
|
||||
leann lis
|
||||
leann list
|
||||
```
|
||||
|
||||
**Key CLI features:**
|
||||
@@ -431,13 +542,17 @@ Options:
|
||||
- **Dynamic batching:** Efficiently batch embedding computations for GPU utilization
|
||||
- **Two-level search:** Smart graph traversal that prioritizes promising nodes
|
||||
|
||||
**Backends:** DiskANN or HNSW - pick what works for your data size.
|
||||
**Backends:**
|
||||
- **HNSW** (default): Ideal for most datasets with maximum storage savings through full recomputation
|
||||
- **DiskANN**: Advanced option with superior search performance, using PQ-based graph traversal with real-time reranking for the best speed-accuracy trade-off
|
||||
|
||||
## Benchmarks
|
||||
|
||||
**[DiskANN vs HNSW Performance Comparison →](benchmarks/diskann_vs_hnsw_speed_comparison.py)** - Compare search performance between both backends
|
||||
|
||||
📊 **[Simple Example: Compare LEANN vs FAISS →](examples/compare_faiss_vs_leann.py)**
|
||||
### Storage Comparison
|
||||
**[Simple Example: Compare LEANN vs FAISS →](benchmarks/compare_faiss_vs_leann.py)** - See storage savings in action
|
||||
|
||||
### 📊 Storage Comparison
|
||||
|
||||
| System | DPR (2.1M) | Wiki (60M) | Chat (400K) | Email (780K) | Browser (38K) |
|
||||
|--------|-------------|------------|-------------|--------------|---------------|
|
||||
@@ -451,8 +566,7 @@ Options:
|
||||
|
||||
```bash
|
||||
uv pip install -e ".[dev]" # Install dev dependencies
|
||||
python examples/run_evaluation.py data/indices/dpr/dpr_diskann # DPR datase
|
||||
python examples/run_evaluation.py data/indices/rpj_wiki/rpj_wiki.index # Wikipedia
|
||||
python benchmarks/run_evaluation.py # Will auto-download evaluation data and run benchmarks
|
||||
```
|
||||
|
||||
The evaluation script downloads data automatically on first run. The last three results were tested with partial personal data, and you can reproduce them with your own data!
|
||||
@@ -490,7 +604,12 @@ MIT License - see [LICENSE](LICENSE) for details.
|
||||
|
||||
## 🙏 Acknowledgments
|
||||
|
||||
This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/)
|
||||
Core Contributors: [Yichuan Wang](https://yichuan-w.github.io/) & [Zhifei Li](https://github.com/andylizf).
|
||||
|
||||
We welcome more contributors! Feel free to open issues or submit PRs.
|
||||
|
||||
This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/).
|
||||
|
||||
---
|
||||
|
||||
<p align="center">
|
||||
|
||||
0
apps/__init__.py
Normal file
0
apps/__init__.py
Normal file
324
apps/base_rag_example.py
Normal file
324
apps/base_rag_example.py
Normal file
@@ -0,0 +1,324 @@
|
||||
"""
|
||||
Base class for unified RAG examples interface.
|
||||
Provides common parameters and functionality for all RAG examples.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import dotenv
|
||||
from leann.api import LeannBuilder, LeannChat
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
|
||||
class BaseRAGExample(ABC):
|
||||
"""Base class for all RAG examples with unified interface."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
default_index_name: str,
|
||||
):
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.default_index_name = default_index_name
|
||||
self.parser = self._create_parser()
|
||||
|
||||
def _create_parser(self) -> argparse.ArgumentParser:
|
||||
"""Create argument parser with common parameters."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description=self.description, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
|
||||
# Core parameters (all examples share these)
|
||||
core_group = parser.add_argument_group("Core Parameters")
|
||||
core_group.add_argument(
|
||||
"--index-dir",
|
||||
type=str,
|
||||
default=f"./{self.default_index_name}",
|
||||
help=f"Directory to store the index (default: ./{self.default_index_name})",
|
||||
)
|
||||
core_group.add_argument(
|
||||
"--query",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Query to run (if not provided, will run in interactive mode)",
|
||||
)
|
||||
# Allow subclasses to override default max_items
|
||||
max_items_default = getattr(self, "max_items_default", -1)
|
||||
core_group.add_argument(
|
||||
"--max-items",
|
||||
type=int,
|
||||
default=max_items_default,
|
||||
help="Maximum number of items to process -1 for all, means index all documents, and you should set it to a reasonable number if you have a large dataset and try at the first time)",
|
||||
)
|
||||
core_group.add_argument(
|
||||
"--force-rebuild", action="store_true", help="Force rebuild index even if it exists"
|
||||
)
|
||||
|
||||
# Embedding parameters
|
||||
embedding_group = parser.add_argument_group("Embedding Parameters")
|
||||
# Allow subclasses to override default embedding_model
|
||||
embedding_model_default = getattr(self, "embedding_model_default", "facebook/contriever")
|
||||
embedding_group.add_argument(
|
||||
"--embedding-model",
|
||||
type=str,
|
||||
default=embedding_model_default,
|
||||
help=f"Embedding model to use (default: {embedding_model_default})",
|
||||
)
|
||||
embedding_group.add_argument(
|
||||
"--embedding-mode",
|
||||
type=str,
|
||||
default="sentence-transformers",
|
||||
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||
help="Embedding backend mode (default: sentence-transformers)",
|
||||
)
|
||||
|
||||
# LLM parameters
|
||||
llm_group = parser.add_argument_group("LLM Parameters")
|
||||
llm_group.add_argument(
|
||||
"--llm",
|
||||
type=str,
|
||||
default="openai",
|
||||
choices=["openai", "ollama", "hf", "simulated"],
|
||||
help="LLM backend to use (default: openai)",
|
||||
)
|
||||
llm_group.add_argument(
|
||||
"--llm-model",
|
||||
type=str,
|
||||
default=None,
|
||||
help="LLM model name (default: gpt-4o for openai, llama3.2:1b for ollama)",
|
||||
)
|
||||
llm_group.add_argument(
|
||||
"--llm-host",
|
||||
type=str,
|
||||
default="http://localhost:11434",
|
||||
help="Host for Ollama API (default: http://localhost:11434)",
|
||||
)
|
||||
llm_group.add_argument(
|
||||
"--thinking-budget",
|
||||
type=str,
|
||||
choices=["low", "medium", "high"],
|
||||
default=None,
|
||||
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
|
||||
)
|
||||
|
||||
# Search parameters
|
||||
search_group = parser.add_argument_group("Search Parameters")
|
||||
search_group.add_argument(
|
||||
"--top-k", type=int, default=20, help="Number of results to retrieve (default: 20)"
|
||||
)
|
||||
search_group.add_argument(
|
||||
"--search-complexity",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Search complexity for graph traversal (default: 64)",
|
||||
)
|
||||
|
||||
# Index building parameters
|
||||
index_group = parser.add_argument_group("Index Building Parameters")
|
||||
index_group.add_argument(
|
||||
"--backend-name",
|
||||
type=str,
|
||||
default="hnsw",
|
||||
choices=["hnsw", "diskann"],
|
||||
help="Backend to use for index (default: hnsw)",
|
||||
)
|
||||
index_group.add_argument(
|
||||
"--graph-degree",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Graph degree for index construction (default: 32)",
|
||||
)
|
||||
index_group.add_argument(
|
||||
"--build-complexity",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Build complexity for index construction (default: 64)",
|
||||
)
|
||||
index_group.add_argument(
|
||||
"--no-compact",
|
||||
action="store_true",
|
||||
help="Disable compact index storage",
|
||||
)
|
||||
index_group.add_argument(
|
||||
"--no-recompute",
|
||||
action="store_true",
|
||||
help="Disable embedding recomputation",
|
||||
)
|
||||
|
||||
# Add source-specific parameters
|
||||
self._add_specific_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
@abstractmethod
|
||||
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
|
||||
"""Add source-specific arguments. Override in subclasses."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load data from the source. Returns list of text chunks."""
|
||||
pass
|
||||
|
||||
def get_llm_config(self, args) -> dict[str, Any]:
|
||||
"""Get LLM configuration based on arguments."""
|
||||
config = {"type": args.llm}
|
||||
|
||||
if args.llm == "openai":
|
||||
config["model"] = args.llm_model or "gpt-4o"
|
||||
elif args.llm == "ollama":
|
||||
config["model"] = args.llm_model or "llama3.2:1b"
|
||||
config["host"] = args.llm_host
|
||||
elif args.llm == "hf":
|
||||
config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
elif args.llm == "simulated":
|
||||
# Simulated LLM doesn't need additional configuration
|
||||
pass
|
||||
|
||||
return config
|
||||
|
||||
async def build_index(self, args, texts: list[str]) -> str:
|
||||
"""Build LEANN index from texts."""
|
||||
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
|
||||
|
||||
print(f"\n[Building Index] Creating {self.name} index...")
|
||||
print(f"Total text chunks: {len(texts)}")
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name=args.backend_name,
|
||||
embedding_model=args.embedding_model,
|
||||
embedding_mode=args.embedding_mode,
|
||||
graph_degree=args.graph_degree,
|
||||
complexity=args.build_complexity,
|
||||
is_compact=not args.no_compact,
|
||||
is_recompute=not args.no_recompute,
|
||||
num_threads=1, # Force single-threaded mode
|
||||
)
|
||||
|
||||
# Add texts in batches for better progress tracking
|
||||
batch_size = 1000
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i : i + batch_size]
|
||||
for text in batch:
|
||||
builder.add_text(text)
|
||||
print(f"Added {min(i + batch_size, len(texts))}/{len(texts)} texts...")
|
||||
|
||||
print("Building index structure...")
|
||||
builder.build_index(index_path)
|
||||
print(f"Index saved to: {index_path}")
|
||||
|
||||
return index_path
|
||||
|
||||
async def run_interactive_chat(self, args, index_path: str):
|
||||
"""Run interactive chat with the index."""
|
||||
chat = LeannChat(
|
||||
index_path,
|
||||
llm_config=self.get_llm_config(args),
|
||||
system_prompt=f"You are a helpful assistant that answers questions about {self.name} data.",
|
||||
complexity=args.search_complexity,
|
||||
)
|
||||
|
||||
print(f"\n[Interactive Mode] Chat with your {self.name} data!")
|
||||
print("Type 'quit' or 'exit' to stop.\n")
|
||||
|
||||
while True:
|
||||
try:
|
||||
query = input("You: ").strip()
|
||||
if query.lower() in ["quit", "exit", "q"]:
|
||||
print("Goodbye!")
|
||||
break
|
||||
|
||||
if not query:
|
||||
continue
|
||||
|
||||
# Prepare LLM kwargs with thinking budget if specified
|
||||
llm_kwargs = {}
|
||||
if hasattr(args, "thinking_budget") and args.thinking_budget:
|
||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||
|
||||
response = chat.ask(
|
||||
query,
|
||||
top_k=args.top_k,
|
||||
complexity=args.search_complexity,
|
||||
llm_kwargs=llm_kwargs,
|
||||
)
|
||||
print(f"\nAssistant: {response}\n")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nGoodbye!")
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
async def run_single_query(self, args, index_path: str, query: str):
|
||||
"""Run a single query against the index."""
|
||||
chat = LeannChat(
|
||||
index_path,
|
||||
llm_config=self.get_llm_config(args),
|
||||
system_prompt=f"You are a helpful assistant that answers questions about {self.name} data.",
|
||||
complexity=args.search_complexity,
|
||||
)
|
||||
|
||||
print(f"\n[Query]: \033[36m{query}\033[0m")
|
||||
|
||||
# Prepare LLM kwargs with thinking budget if specified
|
||||
llm_kwargs = {}
|
||||
if hasattr(args, "thinking_budget") and args.thinking_budget:
|
||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||
|
||||
response = chat.ask(
|
||||
query, top_k=args.top_k, complexity=args.search_complexity, llm_kwargs=llm_kwargs
|
||||
)
|
||||
print(f"\n[Response]: \033[36m{response}\033[0m")
|
||||
|
||||
async def run(self):
|
||||
"""Main entry point for the example."""
|
||||
args = self.parser.parse_args()
|
||||
|
||||
# Check if index exists
|
||||
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
|
||||
index_exists = Path(args.index_dir).exists()
|
||||
|
||||
if not index_exists or args.force_rebuild:
|
||||
# Load data and build index
|
||||
print(f"\n{'Rebuilding' if index_exists else 'Building'} index...")
|
||||
texts = await self.load_data(args)
|
||||
|
||||
if not texts:
|
||||
print("No data found to index!")
|
||||
return
|
||||
|
||||
index_path = await self.build_index(args, texts)
|
||||
else:
|
||||
print(f"\nUsing existing index in {args.index_dir}")
|
||||
|
||||
# Run query or interactive mode
|
||||
if args.query:
|
||||
await self.run_single_query(args, index_path, args.query)
|
||||
else:
|
||||
await self.run_interactive_chat(args, index_path)
|
||||
|
||||
|
||||
def create_text_chunks(documents, chunk_size=256, chunk_overlap=25) -> list[str]:
|
||||
"""Helper function to create text chunks from documents."""
|
||||
node_parser = SentenceSplitter(
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
separator=" ",
|
||||
paragraph_separator="\n\n",
|
||||
)
|
||||
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
nodes = node_parser.get_nodes_from_documents([doc])
|
||||
if nodes:
|
||||
all_texts.extend(node.get_content() for node in nodes)
|
||||
|
||||
return all_texts
|
||||
170
apps/browser_rag.py
Normal file
170
apps/browser_rag.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
Browser History RAG example using the unified interface.
|
||||
Supports Chrome browser history.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample, create_text_chunks
|
||||
|
||||
from .history_data.history import ChromeHistoryReader
|
||||
|
||||
|
||||
class BrowserRAG(BaseRAGExample):
|
||||
"""RAG example for Chrome browser history."""
|
||||
|
||||
def __init__(self):
|
||||
# Set default values BEFORE calling super().__init__
|
||||
self.embedding_model_default = (
|
||||
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
name="Browser History",
|
||||
description="Process and query Chrome browser history with LEANN",
|
||||
default_index_name="google_history_index",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add browser-specific arguments."""
|
||||
browser_group = parser.add_argument_group("Browser Parameters")
|
||||
browser_group.add_argument(
|
||||
"--chrome-profile",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to Chrome profile directory (auto-detected if not specified)",
|
||||
)
|
||||
browser_group.add_argument(
|
||||
"--auto-find-profiles",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Automatically find all Chrome profiles (default: True)",
|
||||
)
|
||||
browser_group.add_argument(
|
||||
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
|
||||
)
|
||||
browser_group.add_argument(
|
||||
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
||||
)
|
||||
|
||||
def _get_chrome_base_path(self) -> Path:
|
||||
"""Get the base Chrome profile path based on OS."""
|
||||
if sys.platform == "darwin":
|
||||
return Path.home() / "Library" / "Application Support" / "Google" / "Chrome"
|
||||
elif sys.platform.startswith("linux"):
|
||||
return Path.home() / ".config" / "google-chrome"
|
||||
elif sys.platform == "win32":
|
||||
return Path(os.environ["LOCALAPPDATA"]) / "Google" / "Chrome" / "User Data"
|
||||
else:
|
||||
raise ValueError(f"Unsupported platform: {sys.platform}")
|
||||
|
||||
def _find_chrome_profiles(self) -> list[Path]:
|
||||
"""Auto-detect all Chrome profiles."""
|
||||
base_path = self._get_chrome_base_path()
|
||||
if not base_path.exists():
|
||||
return []
|
||||
|
||||
profiles = []
|
||||
|
||||
# Check Default profile
|
||||
default_profile = base_path / "Default"
|
||||
if default_profile.exists() and (default_profile / "History").exists():
|
||||
profiles.append(default_profile)
|
||||
|
||||
# Check numbered profiles
|
||||
for item in base_path.iterdir():
|
||||
if item.is_dir() and item.name.startswith("Profile "):
|
||||
if (item / "History").exists():
|
||||
profiles.append(item)
|
||||
|
||||
return profiles
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load browser history and convert to text chunks."""
|
||||
# Determine Chrome profiles
|
||||
if args.chrome_profile and not args.auto_find_profiles:
|
||||
profile_dirs = [Path(args.chrome_profile)]
|
||||
else:
|
||||
print("Auto-detecting Chrome profiles...")
|
||||
profile_dirs = self._find_chrome_profiles()
|
||||
|
||||
# If specific profile given, filter to just that one
|
||||
if args.chrome_profile:
|
||||
profile_path = Path(args.chrome_profile)
|
||||
profile_dirs = [p for p in profile_dirs if p == profile_path]
|
||||
|
||||
if not profile_dirs:
|
||||
print("No Chrome profiles found!")
|
||||
print("Please specify --chrome-profile manually")
|
||||
return []
|
||||
|
||||
print(f"Found {len(profile_dirs)} Chrome profiles")
|
||||
|
||||
# Create reader
|
||||
reader = ChromeHistoryReader()
|
||||
|
||||
# Process each profile
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
for i, profile_dir in enumerate(profile_dirs):
|
||||
print(f"\nProcessing profile {i + 1}/{len(profile_dirs)}: {profile_dir.name}")
|
||||
|
||||
try:
|
||||
# Apply max_items limit per profile
|
||||
max_per_profile = -1
|
||||
if args.max_items > 0:
|
||||
remaining = args.max_items - total_processed
|
||||
if remaining <= 0:
|
||||
break
|
||||
max_per_profile = remaining
|
||||
|
||||
# Load history
|
||||
documents = reader.load_data(
|
||||
chrome_profile_path=str(profile_dir),
|
||||
max_count=max_per_profile,
|
||||
)
|
||||
|
||||
if documents:
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
print(f"Processed {len(documents)} history entries from this profile")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {profile_dir}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No browser history found to process!")
|
||||
return []
|
||||
|
||||
print(f"\nTotal history entries processed: {len(all_documents)}")
|
||||
|
||||
# Convert to text chunks
|
||||
all_texts = create_text_chunks(
|
||||
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||
)
|
||||
|
||||
return all_texts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# Example queries for browser history RAG
|
||||
print("\n🌐 Browser History RAG Example")
|
||||
print("=" * 50)
|
||||
print("\nExample queries you can try:")
|
||||
print("- 'What websites did I visit about machine learning?'")
|
||||
print("- 'Find my search history about programming'")
|
||||
print("- 'What YouTube videos did I watch recently?'")
|
||||
print("- 'Show me websites about travel planning'")
|
||||
print("\nNote: Make sure Chrome is closed before running\n")
|
||||
|
||||
rag = BrowserRAG()
|
||||
asyncio.run(rag.run())
|
||||
108
apps/document_rag.py
Normal file
108
apps/document_rag.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""
|
||||
Document RAG example using the unified interface.
|
||||
Supports PDF, TXT, MD, and other document formats.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample, create_text_chunks
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
|
||||
|
||||
class DocumentRAG(BaseRAGExample):
|
||||
"""RAG example for document processing (PDF, TXT, MD, etc.)."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="Document",
|
||||
description="Process and query documents (PDF, TXT, MD, etc.) with LEANN",
|
||||
default_index_name="test_doc_files",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add document-specific arguments."""
|
||||
doc_group = parser.add_argument_group("Document Parameters")
|
||||
doc_group.add_argument(
|
||||
"--data-dir",
|
||||
type=str,
|
||||
default="data",
|
||||
help="Directory containing documents to index (default: data)",
|
||||
)
|
||||
doc_group.add_argument(
|
||||
"--file-types",
|
||||
nargs="+",
|
||||
default=None,
|
||||
help="Filter by file types (e.g., .pdf .txt .md). If not specified, all supported types are processed",
|
||||
)
|
||||
doc_group.add_argument(
|
||||
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
|
||||
)
|
||||
doc_group.add_argument(
|
||||
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
||||
)
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load documents and convert to text chunks."""
|
||||
print(f"Loading documents from: {args.data_dir}")
|
||||
if args.file_types:
|
||||
print(f"Filtering by file types: {args.file_types}")
|
||||
else:
|
||||
print("Processing all supported file types")
|
||||
|
||||
# Check if data directory exists
|
||||
data_path = Path(args.data_dir)
|
||||
if not data_path.exists():
|
||||
raise ValueError(f"Data directory not found: {args.data_dir}")
|
||||
|
||||
# Load documents
|
||||
reader_kwargs = {
|
||||
"recursive": True,
|
||||
"encoding": "utf-8",
|
||||
}
|
||||
if args.file_types:
|
||||
reader_kwargs["required_exts"] = args.file_types
|
||||
|
||||
documents = SimpleDirectoryReader(args.data_dir, **reader_kwargs).load_data(
|
||||
show_progress=True
|
||||
)
|
||||
|
||||
if not documents:
|
||||
print(f"No documents found in {args.data_dir} with extensions {args.file_types}")
|
||||
return []
|
||||
|
||||
print(f"Loaded {len(documents)} documents")
|
||||
|
||||
# Convert to text chunks
|
||||
all_texts = create_text_chunks(
|
||||
documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||
)
|
||||
|
||||
# Apply max_items limit if specified
|
||||
if args.max_items > 0 and len(all_texts) > args.max_items:
|
||||
print(f"Limiting to {args.max_items} chunks (from {len(all_texts)})")
|
||||
all_texts = all_texts[: args.max_items]
|
||||
|
||||
return all_texts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# Example queries for document RAG
|
||||
print("\n📄 Document RAG Example")
|
||||
print("=" * 50)
|
||||
print("\nExample queries you can try:")
|
||||
print("- 'What are the main techniques LEANN uses?'")
|
||||
print("- 'What is the technique DLPM?'")
|
||||
print("- 'Who does Elizabeth Bennet marry?'")
|
||||
print(
|
||||
"- 'What is the problem of developing pan gu model Huawei meets? (盘古大模型开发中遇到什么问题?)'"
|
||||
)
|
||||
print("\nOr run without --query for interactive mode\n")
|
||||
|
||||
rag = DocumentRAG()
|
||||
asyncio.run(rag.run())
|
||||
@@ -52,6 +52,11 @@ class EmlxReader(BaseReader):
|
||||
docs: list[Document] = []
|
||||
max_count = load_kwargs.get("max_count", 1000)
|
||||
count = 0
|
||||
total_files = 0
|
||||
successful_files = 0
|
||||
failed_files = 0
|
||||
|
||||
print(f"Starting to process directory: {input_dir}")
|
||||
|
||||
# Walk through the directory recursively
|
||||
for dirpath, dirnames, filenames in os.walk(input_dir):
|
||||
@@ -59,10 +64,12 @@ class EmlxReader(BaseReader):
|
||||
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
||||
|
||||
for filename in filenames:
|
||||
if count >= max_count:
|
||||
# Check if we've reached the max count (skip if max_count == -1)
|
||||
if max_count > 0 and count >= max_count:
|
||||
break
|
||||
|
||||
if filename.endswith(".emlx"):
|
||||
total_files += 1
|
||||
filepath = os.path.join(dirpath, filename)
|
||||
try:
|
||||
# Read the .emlx file
|
||||
@@ -98,17 +105,26 @@ class EmlxReader(BaseReader):
|
||||
and not self.include_html
|
||||
):
|
||||
continue
|
||||
body += part.get_payload(decode=True).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
# break
|
||||
try:
|
||||
payload = part.get_payload(decode=True)
|
||||
if payload:
|
||||
body += payload.decode("utf-8", errors="ignore")
|
||||
except Exception as e:
|
||||
print(f"Error decoding payload: {e}")
|
||||
continue
|
||||
else:
|
||||
body = msg.get_payload(decode=True).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
try:
|
||||
payload = msg.get_payload(decode=True)
|
||||
if payload:
|
||||
body = payload.decode("utf-8", errors="ignore")
|
||||
except Exception as e:
|
||||
print(f"Error decoding single part payload: {e}")
|
||||
body = ""
|
||||
|
||||
# Create document content with metadata embedded in text
|
||||
doc_content = f"""
|
||||
# Only create document if we have some content
|
||||
if body.strip() or subject != "No Subject":
|
||||
# Create document content with metadata embedded in text
|
||||
doc_content = f"""
|
||||
[File]: {filename}
|
||||
[From]: {from_addr}
|
||||
[To]: {to_addr}
|
||||
@@ -118,18 +134,34 @@ class EmlxReader(BaseReader):
|
||||
{body}
|
||||
"""
|
||||
|
||||
# No separate metadata - everything is in the text
|
||||
doc = Document(text=doc_content, metadata={})
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
# No separate metadata - everything is in the text
|
||||
doc = Document(text=doc_content, metadata={})
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
successful_files += 1
|
||||
|
||||
# Print first few successful files for debugging
|
||||
if successful_files <= 3:
|
||||
print(
|
||||
f"Successfully loaded: {filename} - Subject: {subject[:50]}..."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error parsing email from {filepath}: {e}")
|
||||
failed_files += 1
|
||||
if failed_files <= 5: # Only print first few errors
|
||||
print(f"Error parsing email from {filepath}: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error reading file {filepath}: {e}")
|
||||
failed_files += 1
|
||||
if failed_files <= 5: # Only print first few errors
|
||||
print(f"Error reading file {filepath}: {e}")
|
||||
continue
|
||||
|
||||
print(f"Loaded {len(docs)} email documents")
|
||||
print("Processing summary:")
|
||||
print(f" Total .emlx files found: {total_files}")
|
||||
print(f" Successfully loaded: {successful_files}")
|
||||
print(f" Failed to load: {failed_files}")
|
||||
print(f" Final documents: {len(docs)}")
|
||||
|
||||
return docs
|
||||
156
apps/email_rag.py
Normal file
156
apps/email_rag.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
Email RAG example using the unified interface.
|
||||
Supports Apple Mail on macOS.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample, create_text_chunks
|
||||
|
||||
from .email_data.LEANN_email_reader import EmlxReader
|
||||
|
||||
|
||||
class EmailRAG(BaseRAGExample):
|
||||
"""RAG example for Apple Mail processing."""
|
||||
|
||||
def __init__(self):
|
||||
# Set default values BEFORE calling super().__init__
|
||||
self.max_items_default = -1 # Process all emails by default
|
||||
self.embedding_model_default = (
|
||||
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
name="Email",
|
||||
description="Process and query Apple Mail emails with LEANN",
|
||||
default_index_name="mail_index",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add email-specific arguments."""
|
||||
email_group = parser.add_argument_group("Email Parameters")
|
||||
email_group.add_argument(
|
||||
"--mail-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to Apple Mail directory (auto-detected if not specified)",
|
||||
)
|
||||
email_group.add_argument(
|
||||
"--include-html", action="store_true", help="Include HTML content in email processing"
|
||||
)
|
||||
email_group.add_argument(
|
||||
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
|
||||
)
|
||||
email_group.add_argument(
|
||||
"--chunk-overlap", type=int, default=25, help="Text chunk overlap (default: 25)"
|
||||
)
|
||||
|
||||
def _find_mail_directories(self) -> list[Path]:
|
||||
"""Auto-detect all Apple Mail directories."""
|
||||
mail_base = Path.home() / "Library" / "Mail"
|
||||
if not mail_base.exists():
|
||||
return []
|
||||
|
||||
# Find all Messages directories
|
||||
messages_dirs = []
|
||||
for item in mail_base.rglob("Messages"):
|
||||
if item.is_dir():
|
||||
messages_dirs.append(item)
|
||||
|
||||
return messages_dirs
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load emails and convert to text chunks."""
|
||||
# Determine mail directories
|
||||
if args.mail_path:
|
||||
messages_dirs = [Path(args.mail_path)]
|
||||
else:
|
||||
print("Auto-detecting Apple Mail directories...")
|
||||
messages_dirs = self._find_mail_directories()
|
||||
|
||||
if not messages_dirs:
|
||||
print("No Apple Mail directories found!")
|
||||
print("Please specify --mail-path manually")
|
||||
return []
|
||||
|
||||
print(f"Found {len(messages_dirs)} mail directories")
|
||||
|
||||
# Create reader
|
||||
reader = EmlxReader(include_html=args.include_html)
|
||||
|
||||
# Process each directory
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
for i, messages_dir in enumerate(messages_dirs):
|
||||
print(f"\nProcessing directory {i + 1}/{len(messages_dirs)}: {messages_dir}")
|
||||
|
||||
try:
|
||||
# Count emlx files
|
||||
emlx_files = list(messages_dir.glob("*.emlx"))
|
||||
print(f"Found {len(emlx_files)} email files")
|
||||
|
||||
# Apply max_items limit per directory
|
||||
max_per_dir = -1 # Default to process all
|
||||
if args.max_items > 0:
|
||||
remaining = args.max_items - total_processed
|
||||
if remaining <= 0:
|
||||
break
|
||||
max_per_dir = remaining
|
||||
# If args.max_items == -1, max_per_dir stays -1 (process all)
|
||||
|
||||
# Load emails - fix the parameter passing
|
||||
documents = reader.load_data(
|
||||
input_dir=str(messages_dir),
|
||||
max_count=max_per_dir,
|
||||
)
|
||||
|
||||
if documents:
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
print(f"Processed {len(documents)} emails from this directory")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {messages_dir}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No emails found to process!")
|
||||
return []
|
||||
|
||||
print(f"\nTotal emails processed: {len(all_documents)}")
|
||||
print("now starting to split into text chunks ... take some time")
|
||||
|
||||
# Convert to text chunks
|
||||
# Email reader uses chunk_overlap=25 as in original
|
||||
all_texts = create_text_chunks(
|
||||
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||
)
|
||||
|
||||
return all_texts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# Check platform
|
||||
if sys.platform != "darwin":
|
||||
print("\n⚠️ Warning: This example is designed for macOS (Apple Mail)")
|
||||
print(" Windows/Linux support coming soon!\n")
|
||||
|
||||
# Example queries for email RAG
|
||||
print("\n📧 Email RAG Example")
|
||||
print("=" * 50)
|
||||
print("\nExample queries you can try:")
|
||||
print("- 'What did my boss say about deadlines?'")
|
||||
print("- 'Find emails about travel expenses'")
|
||||
print("- 'Show me emails from last month about the project'")
|
||||
print("- 'What food did I order from DoorDash?'")
|
||||
print("\nNote: You may need to grant Full Disk Access to your terminal\n")
|
||||
|
||||
rag = EmailRAG()
|
||||
asyncio.run(rag.run())
|
||||
@@ -97,6 +97,11 @@ class ChromeHistoryReader(BaseReader):
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error reading Chrome history: {e}")
|
||||
# add you may need to close your browser to make the database file available
|
||||
# also highlight in red
|
||||
print(
|
||||
"\033[91mYou may need to close your browser to make the database file available\033[0m"
|
||||
)
|
||||
return docs
|
||||
|
||||
return docs
|
||||
@@ -411,8 +411,8 @@ Messages ({len(messages)} messages, {message_group["total_length"]} chars):
|
||||
wechat_export_dir = load_kwargs.get("wechat_export_dir", None)
|
||||
include_non_text = load_kwargs.get("include_non_text", False)
|
||||
concatenate_messages = load_kwargs.get("concatenate_messages", False)
|
||||
load_kwargs.get("max_length", 1000)
|
||||
load_kwargs.get("time_window_minutes", 30)
|
||||
max_length = load_kwargs.get("max_length", 1000)
|
||||
time_window_minutes = load_kwargs.get("time_window_minutes", 30)
|
||||
|
||||
# Default WeChat export path
|
||||
if wechat_export_dir is None:
|
||||
@@ -460,9 +460,9 @@ Messages ({len(messages)} messages, {message_group["total_length"]} chars):
|
||||
# Concatenate messages based on rules
|
||||
message_groups = self._concatenate_messages(
|
||||
readable_messages,
|
||||
max_length=-1,
|
||||
time_window_minutes=-1,
|
||||
overlap_messages=0, # Keep 2 messages overlap between groups
|
||||
max_length=max_length,
|
||||
time_window_minutes=time_window_minutes,
|
||||
overlap_messages=0, # No overlap between groups
|
||||
)
|
||||
|
||||
# Create documents from concatenated groups
|
||||
@@ -532,7 +532,9 @@ Message: {readable_text if readable_text else message_text}
|
||||
"""
|
||||
|
||||
# Create document with embedded metadata
|
||||
doc = Document(text=doc_content, metadata={})
|
||||
doc = Document(
|
||||
text=doc_content, metadata={"contact_name": contact_name}
|
||||
)
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
@@ -560,8 +562,8 @@ Message: {readable_text if readable_text else message_text}
|
||||
|
||||
# Look for common export directory names
|
||||
possible_dirs = [
|
||||
Path("./wechat_export_test"),
|
||||
Path("./wechat_export"),
|
||||
Path("./wechat_export_direct"),
|
||||
Path("./wechat_chat_history"),
|
||||
Path("./chat_export"),
|
||||
]
|
||||
189
apps/wechat_rag.py
Normal file
189
apps/wechat_rag.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""
|
||||
WeChat History RAG example using the unified interface.
|
||||
Supports WeChat chat history export and search.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample
|
||||
|
||||
from .history_data.wechat_history import WeChatHistoryReader
|
||||
|
||||
|
||||
class WeChatRAG(BaseRAGExample):
|
||||
"""RAG example for WeChat chat history."""
|
||||
|
||||
def __init__(self):
|
||||
# Set default values BEFORE calling super().__init__
|
||||
self.max_items_default = -1 # Match original default
|
||||
self.embedding_model_default = (
|
||||
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
name="WeChat History",
|
||||
description="Process and query WeChat chat history with LEANN",
|
||||
default_index_name="wechat_history_magic_test_11Debug_new",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add WeChat-specific arguments."""
|
||||
wechat_group = parser.add_argument_group("WeChat Parameters")
|
||||
wechat_group.add_argument(
|
||||
"--export-dir",
|
||||
type=str,
|
||||
default="./wechat_export",
|
||||
help="Directory to store WeChat exports (default: ./wechat_export)",
|
||||
)
|
||||
wechat_group.add_argument(
|
||||
"--force-export",
|
||||
action="store_true",
|
||||
help="Force re-export of WeChat data even if exports exist",
|
||||
)
|
||||
wechat_group.add_argument(
|
||||
"--chunk-size", type=int, default=192, help="Text chunk size (default: 192)"
|
||||
)
|
||||
wechat_group.add_argument(
|
||||
"--chunk-overlap", type=int, default=64, help="Text chunk overlap (default: 64)"
|
||||
)
|
||||
|
||||
def _export_wechat_data(self, export_dir: Path) -> bool:
|
||||
"""Export WeChat data using wechattweak-cli."""
|
||||
print("Exporting WeChat data...")
|
||||
|
||||
# Check if WeChat is running
|
||||
try:
|
||||
result = subprocess.run(["pgrep", "WeChat"], capture_output=True, text=True)
|
||||
if result.returncode != 0:
|
||||
print("WeChat is not running. Please start WeChat first.")
|
||||
return False
|
||||
except Exception:
|
||||
pass # pgrep might not be available on all systems
|
||||
|
||||
# Create export directory
|
||||
export_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Run export command
|
||||
cmd = ["packages/wechat-exporter/wechattweak-cli", "export", str(export_dir)]
|
||||
|
||||
try:
|
||||
print(f"Running: {' '.join(cmd)}")
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
|
||||
if result.returncode == 0:
|
||||
print("WeChat data exported successfully!")
|
||||
return True
|
||||
else:
|
||||
print(f"Export failed: {result.stderr}")
|
||||
return False
|
||||
|
||||
except FileNotFoundError:
|
||||
print("\nError: wechattweak-cli not found!")
|
||||
print("Please install it first:")
|
||||
print(" sudo packages/wechat-exporter/wechattweak-cli install")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"Export error: {e}")
|
||||
return False
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load WeChat history and convert to text chunks."""
|
||||
# Initialize WeChat reader with export capabilities
|
||||
reader = WeChatHistoryReader()
|
||||
|
||||
# Find existing exports or create new ones using the centralized method
|
||||
export_dirs = reader.find_or_export_wechat_data(args.export_dir)
|
||||
if not export_dirs:
|
||||
print("Failed to find or export WeChat data. Trying to find any existing exports...")
|
||||
# Try to find any existing exports in common locations
|
||||
export_dirs = reader.find_wechat_export_dirs()
|
||||
if not export_dirs:
|
||||
print("No WeChat data found. Please ensure WeChat exports exist.")
|
||||
return []
|
||||
|
||||
# Load documents from all found export directories
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
for i, export_dir in enumerate(export_dirs):
|
||||
print(f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}")
|
||||
|
||||
try:
|
||||
# Apply max_items limit per export
|
||||
max_per_export = -1
|
||||
if args.max_items > 0:
|
||||
remaining = args.max_items - total_processed
|
||||
if remaining <= 0:
|
||||
break
|
||||
max_per_export = remaining
|
||||
|
||||
documents = reader.load_data(
|
||||
wechat_export_dir=str(export_dir),
|
||||
max_count=max_per_export,
|
||||
concatenate_messages=True, # Enable message concatenation for better context
|
||||
)
|
||||
|
||||
if documents:
|
||||
print(f"Loaded {len(documents)} chat documents from {export_dir}")
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
else:
|
||||
print(f"No documents loaded from {export_dir}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {export_dir}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No documents loaded from any source. Exiting.")
|
||||
return []
|
||||
|
||||
print(f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports")
|
||||
print("now starting to split into text chunks ... take some time")
|
||||
|
||||
# Convert to text chunks with contact information
|
||||
all_texts = []
|
||||
for doc in all_documents:
|
||||
# Split the document into chunks
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
text_splitter = SentenceSplitter(
|
||||
chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||
)
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
|
||||
for node in nodes:
|
||||
# Add contact information to each chunk
|
||||
contact_name = doc.metadata.get("contact_name", "Unknown")
|
||||
text = f"[Contact] means the message is from: {contact_name}\n" + node.get_content()
|
||||
all_texts.append(text)
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
|
||||
return all_texts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# Check platform
|
||||
if sys.platform != "darwin":
|
||||
print("\n⚠️ Warning: WeChat export is only supported on macOS")
|
||||
print(" You can still query existing exports on other platforms\n")
|
||||
|
||||
# Example queries for WeChat RAG
|
||||
print("\n💬 WeChat History RAG Example")
|
||||
print("=" * 50)
|
||||
print("\nExample queries you can try:")
|
||||
print("- 'Show me conversations about travel plans'")
|
||||
print("- 'Find group chats about weekend activities'")
|
||||
print("- '我想买魔术师约翰逊的球衣,给我一些对应聊天记录?'")
|
||||
print("- 'What did we discuss about the project last month?'")
|
||||
print("\nNote: WeChat must be running for export to work\n")
|
||||
|
||||
rag = WeChatRAG()
|
||||
asyncio.run(rag.run())
|
||||
BIN
assets/claude_code_leann.png
Normal file
BIN
assets/claude_code_leann.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 73 KiB |
BIN
assets/mcp_leann.png
Normal file
BIN
assets/mcp_leann.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 224 KiB |
@@ -1,9 +1,24 @@
|
||||
# 🧪 Leann Sanity Checks
|
||||
# 🧪 LEANN Benchmarks & Testing
|
||||
|
||||
This directory contains comprehensive sanity checks for the Leann system, ensuring all components work correctly across different configurations.
|
||||
This directory contains performance benchmarks and comprehensive tests for the LEANN system, including backend comparisons and sanity checks across different configurations.
|
||||
|
||||
## 📁 Test Files
|
||||
|
||||
### `diskann_vs_hnsw_speed_comparison.py`
|
||||
Performance comparison between DiskANN and HNSW backends:
|
||||
- ✅ **Search latency** comparison with both backends using recompute
|
||||
- ✅ **Index size** and **build time** measurements
|
||||
- ✅ **Score validity** testing (ensures no -inf scores)
|
||||
- ✅ **Configurable dataset sizes** for different scales
|
||||
|
||||
```bash
|
||||
# Quick comparison with 500 docs, 10 queries
|
||||
python benchmarks/diskann_vs_hnsw_speed_comparison.py
|
||||
|
||||
# Large-scale comparison with 2000 docs, 20 queries
|
||||
python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20
|
||||
```
|
||||
|
||||
### `test_distance_functions.py`
|
||||
Tests all supported distance functions across DiskANN backend:
|
||||
- ✅ **MIPS** (Maximum Inner Product Search)
|
||||
@@ -62,7 +62,7 @@ def test_faiss_hnsw():
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[sys.executable, "examples/faiss_only.py"],
|
||||
[sys.executable, "benchmarks/faiss_only.py"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300,
|
||||
@@ -115,7 +115,7 @@ def test_leann_hnsw():
|
||||
|
||||
# Load and parse documents
|
||||
documents = SimpleDirectoryReader(
|
||||
"examples/data",
|
||||
"data",
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=[".pdf", ".txt", ".md"],
|
||||
268
benchmarks/diskann_vs_hnsw_speed_comparison.py
Normal file
268
benchmarks/diskann_vs_hnsw_speed_comparison.py
Normal file
@@ -0,0 +1,268 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
DiskANN vs HNSW Search Performance Comparison
|
||||
|
||||
This benchmark compares search performance between DiskANN and HNSW backends:
|
||||
- DiskANN: With graph partitioning enabled (is_recompute=True)
|
||||
- HNSW: With recompute enabled (is_recompute=True)
|
||||
- Tests performance across different dataset sizes
|
||||
- Measures search latency, recall, and index size
|
||||
"""
|
||||
|
||||
import gc
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def create_test_texts(n_docs: int) -> list[str]:
|
||||
"""Create synthetic test documents for benchmarking."""
|
||||
np.random.seed(42)
|
||||
topics = [
|
||||
"machine learning and artificial intelligence",
|
||||
"natural language processing and text analysis",
|
||||
"computer vision and image recognition",
|
||||
"data science and statistical analysis",
|
||||
"deep learning and neural networks",
|
||||
"information retrieval and search engines",
|
||||
"database systems and data management",
|
||||
"software engineering and programming",
|
||||
"cybersecurity and network protection",
|
||||
"cloud computing and distributed systems",
|
||||
]
|
||||
|
||||
texts = []
|
||||
for i in range(n_docs):
|
||||
topic = topics[i % len(topics)]
|
||||
variation = np.random.randint(1, 100)
|
||||
text = (
|
||||
f"This is document {i} about {topic}. Content variation {variation}. "
|
||||
f"Additional information about {topic} with details and examples. "
|
||||
f"Technical discussion of {topic} including implementation aspects."
|
||||
)
|
||||
texts.append(text)
|
||||
|
||||
return texts
|
||||
|
||||
|
||||
def benchmark_backend(
|
||||
backend_name: str, texts: list[str], test_queries: list[str], backend_kwargs: dict[str, Any]
|
||||
) -> dict[str, float]:
|
||||
"""Benchmark a specific backend with the given configuration."""
|
||||
from leann.api import LeannBuilder, LeannSearcher
|
||||
|
||||
print(f"\n🔧 Testing {backend_name.upper()} backend...")
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
index_path = str(Path(temp_dir) / f"benchmark_{backend_name}.leann")
|
||||
|
||||
# Build index
|
||||
print(f"📦 Building {backend_name} index with {len(texts)} documents...")
|
||||
start_time = time.time()
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name=backend_name,
|
||||
embedding_model="facebook/contriever",
|
||||
embedding_mode="sentence-transformers",
|
||||
**backend_kwargs,
|
||||
)
|
||||
|
||||
for text in texts:
|
||||
builder.add_text(text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
build_time = time.time() - start_time
|
||||
|
||||
# Measure index size
|
||||
index_dir = Path(index_path).parent
|
||||
index_files = list(index_dir.glob(f"{Path(index_path).stem}.*"))
|
||||
total_size = sum(f.stat().st_size for f in index_files if f.is_file())
|
||||
size_mb = total_size / (1024 * 1024)
|
||||
|
||||
print(f" ✅ Build completed in {build_time:.2f}s, index size: {size_mb:.1f}MB")
|
||||
|
||||
# Search benchmark
|
||||
print("🔍 Running search benchmark...")
|
||||
searcher = LeannSearcher(index_path)
|
||||
|
||||
search_times = []
|
||||
all_results = []
|
||||
|
||||
for query in test_queries:
|
||||
start_time = time.time()
|
||||
results = searcher.search(query, top_k=5)
|
||||
search_time = time.time() - start_time
|
||||
search_times.append(search_time)
|
||||
all_results.append(results)
|
||||
|
||||
avg_search_time = np.mean(search_times) * 1000 # Convert to ms
|
||||
print(f" ✅ Average search time: {avg_search_time:.1f}ms")
|
||||
|
||||
# Check for valid scores (detect -inf issues)
|
||||
all_scores = [
|
||||
result.score
|
||||
for results in all_results
|
||||
for result in results
|
||||
if result.score is not None
|
||||
]
|
||||
valid_scores = [
|
||||
score for score in all_scores if score != float("-inf") and score != float("inf")
|
||||
]
|
||||
score_validity_rate = len(valid_scores) / len(all_scores) if all_scores else 0
|
||||
|
||||
# Clean up
|
||||
try:
|
||||
if hasattr(searcher, "__del__"):
|
||||
searcher.__del__()
|
||||
del searcher
|
||||
del builder
|
||||
gc.collect()
|
||||
except Exception as e:
|
||||
print(f"⚠️ Warning: Resource cleanup error: {e}")
|
||||
|
||||
return {
|
||||
"build_time": build_time,
|
||||
"avg_search_time_ms": avg_search_time,
|
||||
"index_size_mb": size_mb,
|
||||
"score_validity_rate": score_validity_rate,
|
||||
}
|
||||
|
||||
|
||||
def run_comparison(n_docs: int = 500, n_queries: int = 10):
|
||||
"""Run performance comparison between DiskANN and HNSW."""
|
||||
print("🚀 Starting DiskANN vs HNSW Performance Comparison")
|
||||
print(f"📊 Dataset: {n_docs} documents, {n_queries} test queries")
|
||||
|
||||
# Create test data
|
||||
texts = create_test_texts(n_docs)
|
||||
test_queries = [
|
||||
"machine learning algorithms",
|
||||
"natural language processing",
|
||||
"computer vision techniques",
|
||||
"data analysis methods",
|
||||
"neural network architectures",
|
||||
"database query optimization",
|
||||
"software development practices",
|
||||
"security vulnerabilities",
|
||||
"cloud infrastructure",
|
||||
"distributed computing",
|
||||
][:n_queries]
|
||||
|
||||
# HNSW benchmark
|
||||
hnsw_results = benchmark_backend(
|
||||
backend_name="hnsw",
|
||||
texts=texts,
|
||||
test_queries=test_queries,
|
||||
backend_kwargs={
|
||||
"is_recompute": True, # Enable recompute for fair comparison
|
||||
"M": 16,
|
||||
"efConstruction": 200,
|
||||
},
|
||||
)
|
||||
|
||||
# DiskANN benchmark
|
||||
diskann_results = benchmark_backend(
|
||||
backend_name="diskann",
|
||||
texts=texts,
|
||||
test_queries=test_queries,
|
||||
backend_kwargs={
|
||||
"is_recompute": True, # Enable graph partitioning
|
||||
"num_neighbors": 32,
|
||||
"search_list_size": 50,
|
||||
},
|
||||
)
|
||||
|
||||
# Performance comparison
|
||||
print("\n📈 Performance Comparison Results")
|
||||
print(f"{'=' * 60}")
|
||||
print(f"{'Metric':<25} {'HNSW':<15} {'DiskANN':<15} {'Speedup':<10}")
|
||||
print(f"{'-' * 60}")
|
||||
|
||||
# Build time comparison
|
||||
build_speedup = hnsw_results["build_time"] / diskann_results["build_time"]
|
||||
print(
|
||||
f"{'Build Time (s)':<25} {hnsw_results['build_time']:<15.2f} {diskann_results['build_time']:<15.2f} {build_speedup:<10.2f}x"
|
||||
)
|
||||
|
||||
# Search time comparison
|
||||
search_speedup = hnsw_results["avg_search_time_ms"] / diskann_results["avg_search_time_ms"]
|
||||
print(
|
||||
f"{'Search Time (ms)':<25} {hnsw_results['avg_search_time_ms']:<15.1f} {diskann_results['avg_search_time_ms']:<15.1f} {search_speedup:<10.2f}x"
|
||||
)
|
||||
|
||||
# Index size comparison
|
||||
size_ratio = diskann_results["index_size_mb"] / hnsw_results["index_size_mb"]
|
||||
print(
|
||||
f"{'Index Size (MB)':<25} {hnsw_results['index_size_mb']:<15.1f} {diskann_results['index_size_mb']:<15.1f} {size_ratio:<10.2f}x"
|
||||
)
|
||||
|
||||
# Score validity
|
||||
print(
|
||||
f"{'Score Validity (%)':<25} {hnsw_results['score_validity_rate'] * 100:<15.1f} {diskann_results['score_validity_rate'] * 100:<15.1f}"
|
||||
)
|
||||
|
||||
print(f"{'=' * 60}")
|
||||
print("\n🎯 Summary:")
|
||||
if search_speedup > 1:
|
||||
print(f" DiskANN is {search_speedup:.2f}x faster than HNSW for search")
|
||||
else:
|
||||
print(f" HNSW is {1 / search_speedup:.2f}x faster than DiskANN for search")
|
||||
|
||||
if size_ratio > 1:
|
||||
print(f" DiskANN uses {size_ratio:.2f}x more storage than HNSW")
|
||||
else:
|
||||
print(f" DiskANN uses {1 / size_ratio:.2f}x less storage than HNSW")
|
||||
|
||||
print(
|
||||
f" Both backends achieved {min(hnsw_results['score_validity_rate'], diskann_results['score_validity_rate']) * 100:.1f}% score validity"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
try:
|
||||
# Handle help request
|
||||
if len(sys.argv) > 1 and sys.argv[1] in ["-h", "--help", "help"]:
|
||||
print("DiskANN vs HNSW Performance Comparison")
|
||||
print("=" * 50)
|
||||
print(f"Usage: python {sys.argv[0]} [n_docs] [n_queries]")
|
||||
print()
|
||||
print("Arguments:")
|
||||
print(" n_docs Number of documents to index (default: 500)")
|
||||
print(" n_queries Number of test queries to run (default: 10)")
|
||||
print()
|
||||
print("Examples:")
|
||||
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py")
|
||||
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 1000")
|
||||
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20")
|
||||
sys.exit(0)
|
||||
|
||||
# Parse command line arguments
|
||||
n_docs = int(sys.argv[1]) if len(sys.argv) > 1 else 500
|
||||
n_queries = int(sys.argv[2]) if len(sys.argv) > 2 else 10
|
||||
|
||||
print("DiskANN vs HNSW Performance Comparison")
|
||||
print("=" * 50)
|
||||
print(f"Dataset: {n_docs} documents, {n_queries} queries")
|
||||
print()
|
||||
|
||||
run_comparison(n_docs=n_docs, n_queries=n_queries)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n⚠️ Benchmark interrupted by user")
|
||||
sys.exit(130)
|
||||
except Exception as e:
|
||||
print(f"\n❌ Benchmark failed: {e}")
|
||||
sys.exit(1)
|
||||
finally:
|
||||
# Ensure clean exit
|
||||
try:
|
||||
gc.collect()
|
||||
print("\n🧹 Cleanup completed")
|
||||
except Exception:
|
||||
pass
|
||||
sys.exit(0)
|
||||
@@ -65,7 +65,7 @@ def main():
|
||||
tracker.checkpoint("After Faiss index creation")
|
||||
|
||||
documents = SimpleDirectoryReader(
|
||||
"examples/data",
|
||||
"data",
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=[".pdf", ".txt", ".md"],
|
||||
@@ -200,10 +200,10 @@ def main():
|
||||
args = parser.parse_args()
|
||||
|
||||
# --- Path Configuration ---
|
||||
# Assumes a project structure where the script is in 'examples/'
|
||||
# and data is in 'data/' at the project root.
|
||||
project_root = Path(__file__).resolve().parent.parent
|
||||
data_root = project_root / "data"
|
||||
# Assumes a project structure where the script is in 'benchmarks/'
|
||||
# and evaluation data is in 'benchmarks/data/'.
|
||||
script_dir = Path(__file__).resolve().parent
|
||||
data_root = script_dir / "data"
|
||||
|
||||
# Download data based on mode
|
||||
if args.mode == "build":
|
||||
@@ -279,7 +279,9 @@ def main():
|
||||
|
||||
if not args.index_path:
|
||||
print("No indices found. The data download should have included pre-built indices.")
|
||||
print("Please check the data/indices/ directory or provide --index-path manually.")
|
||||
print(
|
||||
"Please check the benchmarks/data/indices/ directory or provide --index-path manually."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Detect dataset type from index path to select the correct ground truth
|
||||
@@ -1,44 +0,0 @@
|
||||
---
|
||||
license: mit
|
||||
---
|
||||
|
||||
# LEANN-RAG Evaluation Data
|
||||
|
||||
This repository contains the necessary data to run the recall evaluation scripts for the [LEANN-RAG](https://huggingface.co/LEANN-RAG) project.
|
||||
|
||||
## Dataset Components
|
||||
|
||||
This dataset is structured into three main parts:
|
||||
|
||||
1. **Pre-built LEANN Indices**:
|
||||
* `dpr/`: A pre-built index for the DPR dataset.
|
||||
* `rpj_wiki/`: A pre-built index for the RPJ-Wiki dataset.
|
||||
These indices were created using the `leann-core` library and are required by the `LeannSearcher`.
|
||||
|
||||
2. **Ground Truth Data**:
|
||||
* `ground_truth/`: Contains the ground truth files (`flat_results_nq_k3.json`) for both the DPR and RPJ-Wiki datasets. These files map queries to the original passage IDs from the Natural Questions benchmark, evaluated using the Contriever model.
|
||||
|
||||
3. **Queries**:
|
||||
* `queries/`: Contains the `nq_open.jsonl` file with the Natural Questions queries used for the evaluation.
|
||||
|
||||
## Usage
|
||||
|
||||
To use this data, you can download it locally using the `huggingface-hub` library. First, install the library:
|
||||
|
||||
```bash
|
||||
pip install huggingface-hub
|
||||
```
|
||||
|
||||
Then, you can download the entire dataset to a local directory (e.g., `data/`) with the following Python script:
|
||||
|
||||
```python
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
snapshot_download(
|
||||
repo_id="LEANN-RAG/leann-rag-evaluation-data",
|
||||
repo_type="dataset",
|
||||
local_dir="data"
|
||||
)
|
||||
```
|
||||
|
||||
This will download all the necessary files into a local `data` folder, preserving the repository structure. The evaluation scripts in the main [LEANN-RAG Space](https://huggingface.co/LEANN-RAG) are configured to work with this data structure.
|
||||
123
docs/THINKING_BUDGET_FEATURE.md
Normal file
123
docs/THINKING_BUDGET_FEATURE.md
Normal file
@@ -0,0 +1,123 @@
|
||||
# Thinking Budget Feature Implementation
|
||||
|
||||
## Overview
|
||||
|
||||
This document describes the implementation of the **thinking budget** feature for LEANN, which allows users to control the computational effort for reasoning models like GPT-Oss:20b.
|
||||
|
||||
## Feature Description
|
||||
|
||||
The thinking budget feature provides three levels of computational effort for reasoning models:
|
||||
- **`low`**: Fast responses, basic reasoning (default for simple queries)
|
||||
- **`medium`**: Balanced speed and reasoning depth
|
||||
- **`high`**: Maximum reasoning effort, best for complex analytical questions
|
||||
|
||||
## Implementation Details
|
||||
|
||||
### 1. Command Line Interface
|
||||
|
||||
Added `--thinking-budget` parameter to both CLI and RAG examples:
|
||||
|
||||
```bash
|
||||
# LEANN CLI
|
||||
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget high
|
||||
|
||||
# RAG Examples
|
||||
python apps/email_rag.py --llm ollama --llm-model gpt-oss:20b --thinking-budget high
|
||||
python apps/document_rag.py --llm openai --llm-model o3 --thinking-budget medium
|
||||
```
|
||||
|
||||
### 2. LLM Backend Support
|
||||
|
||||
#### Ollama Backend (`packages/leann-core/src/leann/chat.py`)
|
||||
|
||||
```python
|
||||
def ask(self, prompt: str, **kwargs) -> str:
|
||||
# Handle thinking budget for reasoning models
|
||||
options = kwargs.copy()
|
||||
thinking_budget = kwargs.get("thinking_budget")
|
||||
if thinking_budget:
|
||||
options.pop("thinking_budget", None)
|
||||
if thinking_budget in ["low", "medium", "high"]:
|
||||
options["reasoning"] = {"effort": thinking_budget, "exclude": False}
|
||||
```
|
||||
|
||||
**API Format**: Uses Ollama's `reasoning` parameter with `effort` and `exclude` fields.
|
||||
|
||||
#### OpenAI Backend (`packages/leann-core/src/leann/chat.py`)
|
||||
|
||||
```python
|
||||
def ask(self, prompt: str, **kwargs) -> str:
|
||||
# Handle thinking budget for reasoning models
|
||||
thinking_budget = kwargs.get("thinking_budget")
|
||||
if thinking_budget and thinking_budget in ["low", "medium", "high"]:
|
||||
# Check if this is an o-series model
|
||||
o_series_models = ["o3", "o3-mini", "o4-mini", "o1", "o3-pro", "o3-deep-research"]
|
||||
if any(model in self.model for model in o_series_models):
|
||||
params["reasoning_effort"] = thinking_budget
|
||||
```
|
||||
|
||||
**API Format**: Uses OpenAI's `reasoning_effort` parameter for o-series models.
|
||||
|
||||
### 3. Parameter Propagation
|
||||
|
||||
The thinking budget parameter is properly propagated through the LEANN architecture:
|
||||
|
||||
1. **CLI** (`packages/leann-core/src/leann/cli.py`): Captures `--thinking-budget` argument
|
||||
2. **Base RAG** (`apps/base_rag_example.py`): Adds parameter to argument parser
|
||||
3. **LeannChat** (`packages/leann-core/src/leann/api.py`): Passes `llm_kwargs` to LLM
|
||||
4. **LLM Interface**: Handles the parameter in backend-specific implementations
|
||||
|
||||
## Files Modified
|
||||
|
||||
### Core Implementation
|
||||
- `packages/leann-core/src/leann/chat.py`: Added thinking budget support to OllamaChat and OpenAIChat
|
||||
- `packages/leann-core/src/leann/cli.py`: Added `--thinking-budget` argument
|
||||
- `apps/base_rag_example.py`: Added thinking budget parameter to RAG examples
|
||||
|
||||
### Documentation
|
||||
- `README.md`: Added thinking budget parameter to usage examples
|
||||
- `docs/configuration-guide.md`: Added detailed documentation and usage guidelines
|
||||
|
||||
### Examples
|
||||
- `examples/thinking_budget_demo.py`: Comprehensive demo script with usage examples
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Basic Usage
|
||||
```bash
|
||||
# High reasoning effort for complex questions
|
||||
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget high
|
||||
|
||||
# Medium reasoning for balanced performance
|
||||
leann ask my-index --llm openai --model gpt-4o --thinking-budget medium
|
||||
|
||||
# Low reasoning for fast responses
|
||||
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget low
|
||||
```
|
||||
|
||||
### RAG Examples
|
||||
```bash
|
||||
# Email RAG with high reasoning
|
||||
python apps/email_rag.py --llm ollama --llm-model gpt-oss:20b --thinking-budget high
|
||||
|
||||
# Document RAG with medium reasoning
|
||||
python apps/document_rag.py --llm openai --llm-model gpt-4o --thinking-budget medium
|
||||
```
|
||||
|
||||
## Supported Models
|
||||
|
||||
### Ollama Models
|
||||
- **GPT-Oss:20b**: Primary target model with reasoning capabilities
|
||||
- **Other reasoning models**: Any Ollama model that supports the `reasoning` parameter
|
||||
|
||||
### OpenAI Models
|
||||
- **o3, o3-mini, o4-mini, o1**: o-series reasoning models with `reasoning_effort` parameter
|
||||
- **GPT-OSS models**: Models that support reasoning capabilities
|
||||
|
||||
## Testing
|
||||
|
||||
The implementation includes comprehensive testing:
|
||||
- Parameter handling verification
|
||||
- Backend-specific API format validation
|
||||
- CLI argument parsing tests
|
||||
- Integration with existing LEANN architecture
|
||||
98
docs/code/embedding_model_compare.py
Normal file
98
docs/code/embedding_model_compare.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""
|
||||
Comparison between Sentence Transformers and OpenAI embeddings
|
||||
|
||||
This example shows how different embedding models handle complex queries
|
||||
and demonstrates the differences between local and API-based embeddings.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from leann.embedding_compute import compute_embeddings
|
||||
|
||||
# OpenAI API key should be set as environment variable
|
||||
# export OPENAI_API_KEY="your-api-key-here"
|
||||
|
||||
# Test data
|
||||
conference_text = "[Title]: COLING 2025 Conference\n[URL]: https://coling2025.org/"
|
||||
browser_text = "[Title]: Browser Use Tool\n[URL]: https://github.com/browser-use"
|
||||
|
||||
# Two queries with same intent but different wording
|
||||
query1 = "Tell me my browser history about some conference i often visit"
|
||||
query2 = "browser history about conference I often visit"
|
||||
|
||||
texts = [query1, query2, conference_text, browser_text]
|
||||
|
||||
|
||||
def cosine_similarity(a, b):
|
||||
return np.dot(a, b) # Already normalized
|
||||
|
||||
|
||||
def analyze_embeddings(embeddings, model_name):
|
||||
print(f"\n=== {model_name} Results ===")
|
||||
|
||||
# Results for Query 1
|
||||
sim1_conf = cosine_similarity(embeddings[0], embeddings[2])
|
||||
sim1_browser = cosine_similarity(embeddings[0], embeddings[3])
|
||||
|
||||
print(f"Query 1: '{query1}'")
|
||||
print(f" → Conference similarity: {sim1_conf:.4f} {'✓' if sim1_conf > sim1_browser else ''}")
|
||||
print(
|
||||
f" → Browser similarity: {sim1_browser:.4f} {'✓' if sim1_browser > sim1_conf else ''}"
|
||||
)
|
||||
print(f" Winner: {'Conference' if sim1_conf > sim1_browser else 'Browser'}")
|
||||
|
||||
# Results for Query 2
|
||||
sim2_conf = cosine_similarity(embeddings[1], embeddings[2])
|
||||
sim2_browser = cosine_similarity(embeddings[1], embeddings[3])
|
||||
|
||||
print(f"\nQuery 2: '{query2}'")
|
||||
print(f" → Conference similarity: {sim2_conf:.4f} {'✓' if sim2_conf > sim2_browser else ''}")
|
||||
print(
|
||||
f" → Browser similarity: {sim2_browser:.4f} {'✓' if sim2_browser > sim2_conf else ''}"
|
||||
)
|
||||
print(f" Winner: {'Conference' if sim2_conf > sim2_browser else 'Browser'}")
|
||||
|
||||
# Show the impact
|
||||
print("\n=== Impact Analysis ===")
|
||||
print(f"Conference similarity change: {sim2_conf - sim1_conf:+.4f}")
|
||||
print(f"Browser similarity change: {sim2_browser - sim1_browser:+.4f}")
|
||||
|
||||
if sim1_conf > sim1_browser and sim2_browser > sim2_conf:
|
||||
print("❌ FLIP: Adding 'browser history' flips winner from Conference to Browser!")
|
||||
elif sim1_conf > sim1_browser and sim2_conf > sim2_browser:
|
||||
print("✅ STABLE: Conference remains winner in both queries")
|
||||
elif sim1_browser > sim1_conf and sim2_browser > sim2_conf:
|
||||
print("✅ STABLE: Browser remains winner in both queries")
|
||||
else:
|
||||
print("🔄 MIXED: Results vary between queries")
|
||||
|
||||
return {
|
||||
"query1_conf": sim1_conf,
|
||||
"query1_browser": sim1_browser,
|
||||
"query2_conf": sim2_conf,
|
||||
"query2_browser": sim2_browser,
|
||||
}
|
||||
|
||||
|
||||
# Test Sentence Transformers
|
||||
print("Testing Sentence Transformers (facebook/contriever)...")
|
||||
try:
|
||||
st_embeddings = compute_embeddings(texts, "facebook/contriever", mode="sentence-transformers")
|
||||
st_results = analyze_embeddings(st_embeddings, "Sentence Transformers (facebook/contriever)")
|
||||
except Exception as e:
|
||||
print(f"❌ Sentence Transformers failed: {e}")
|
||||
st_results = None
|
||||
|
||||
# Test OpenAI
|
||||
print("\n" + "=" * 60)
|
||||
print("Testing OpenAI (text-embedding-3-small)...")
|
||||
try:
|
||||
openai_embeddings = compute_embeddings(texts, "text-embedding-3-small", mode="openai")
|
||||
openai_results = analyze_embeddings(openai_embeddings, "OpenAI (text-embedding-3-small)")
|
||||
except Exception as e:
|
||||
print(f"❌ OpenAI failed: {e}")
|
||||
openai_results = None
|
||||
|
||||
# Compare results
|
||||
if st_results and openai_results:
|
||||
print("\n" + "=" * 60)
|
||||
print("=== COMPARISON SUMMARY ===")
|
||||
300
docs/configuration-guide.md
Normal file
300
docs/configuration-guide.md
Normal file
@@ -0,0 +1,300 @@
|
||||
# LEANN Configuration Guide
|
||||
|
||||
This guide helps you optimize LEANN for different use cases and understand the trade-offs between various configuration options.
|
||||
|
||||
## Getting Started: Simple is Better
|
||||
|
||||
When first trying LEANN, start with a small dataset to quickly validate your approach:
|
||||
|
||||
**For document RAG**: The default `data/` directory works perfectly - includes 2 AI research papers, Pride and Prejudice literature, and a technical report
|
||||
```bash
|
||||
python -m apps.document_rag --query "What techniques does LEANN use?"
|
||||
```
|
||||
|
||||
**For other data sources**: Limit the dataset size for quick testing
|
||||
```bash
|
||||
# WeChat: Test with recent messages only
|
||||
python -m apps.wechat_rag --max-items 100 --query "What did we discuss about the project timeline?"
|
||||
|
||||
# Browser history: Last few days
|
||||
python -m apps.browser_rag --max-items 500 --query "Find documentation about vector databases"
|
||||
|
||||
# Email: Recent inbox
|
||||
python -m apps.email_rag --max-items 200 --query "Who sent updates about the deployment status?"
|
||||
```
|
||||
|
||||
Once validated, scale up gradually:
|
||||
- 100 documents → 1,000 → 10,000 → full dataset (`--max-items -1`)
|
||||
- This helps identify issues early before committing to long processing times
|
||||
|
||||
## Embedding Model Selection: Understanding the Trade-offs
|
||||
|
||||
Based on our experience developing LEANN, embedding models fall into three categories:
|
||||
|
||||
### Small Models (< 100M parameters)
|
||||
**Example**: `sentence-transformers/all-MiniLM-L6-v2` (22M params)
|
||||
- **Pros**: Lightweight, fast for both indexing and inference
|
||||
- **Cons**: Lower semantic understanding, may miss nuanced relationships
|
||||
- **Use when**: Speed is critical, handling simple queries, interactive mode, or just experimenting with LEANN. If time is not a constraint, consider using a larger/better embedding model
|
||||
|
||||
### Medium Models (100M-500M parameters)
|
||||
**Example**: `facebook/contriever` (110M params), `BAAI/bge-base-en-v1.5` (110M params)
|
||||
- **Pros**: Balanced performance, good multilingual support, reasonable speed
|
||||
- **Cons**: Requires more compute than small models
|
||||
- **Use when**: Need quality results without extreme compute requirements, general-purpose RAG applications
|
||||
|
||||
### Large Models (500M+ parameters)
|
||||
**Example**: `Qwen/Qwen3-Embedding-0.6B` (600M params), `intfloat/multilingual-e5-large` (560M params)
|
||||
- **Pros**: Best semantic understanding, captures complex relationships, excellent multilingual support. **Qwen3-Embedding-0.6B achieves nearly OpenAI API performance!**
|
||||
- **Cons**: Slower inference, longer index build times
|
||||
- **Use when**: Quality is paramount and you have sufficient compute resources. **Highly recommended** for production use
|
||||
|
||||
### Quick Start: Cloud and Local Embedding Options
|
||||
|
||||
**OpenAI Embeddings (Fastest Setup)**
|
||||
For immediate testing without local model downloads:
|
||||
```bash
|
||||
# Set OpenAI embeddings (requires OPENAI_API_KEY)
|
||||
--embedding-mode openai --embedding-model text-embedding-3-small
|
||||
```
|
||||
|
||||
**Ollama Embeddings (Privacy-Focused)**
|
||||
For local embeddings with complete privacy:
|
||||
```bash
|
||||
# First, pull an embedding model
|
||||
ollama pull nomic-embed-text
|
||||
|
||||
# Use Ollama embeddings
|
||||
--embedding-mode ollama --embedding-model nomic-embed-text
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary><strong>Cloud vs Local Trade-offs</strong></summary>
|
||||
|
||||
**OpenAI Embeddings** (`text-embedding-3-small/large`)
|
||||
- **Pros**: No local compute needed, consistently fast, high quality
|
||||
- **Cons**: Requires API key, costs money, data leaves your system, [known limitations with certain languages](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
|
||||
- **When to use**: Prototyping, non-sensitive data, need immediate results
|
||||
|
||||
**Local Embeddings**
|
||||
- **Pros**: Complete privacy, no ongoing costs, full control, can sometimes outperform OpenAI embeddings
|
||||
- **Cons**: Slower than cloud APIs, requires local compute resources
|
||||
- **When to use**: Production systems, sensitive data, cost-sensitive applications
|
||||
|
||||
</details>
|
||||
|
||||
## Index Selection: Matching Your Scale
|
||||
|
||||
### HNSW (Hierarchical Navigable Small World)
|
||||
**Best for**: Small to medium datasets (< 10M vectors) - **Default and recommended for extreme low storage**
|
||||
- Full recomputation required
|
||||
- High memory usage during build phase
|
||||
- Excellent recall (95%+)
|
||||
|
||||
```bash
|
||||
# Optimal for most use cases
|
||||
--backend-name hnsw --graph-degree 32 --build-complexity 64
|
||||
```
|
||||
|
||||
### DiskANN
|
||||
**Best for**: Performance-critical applications and large datasets - **Production-ready with automatic graph partitioning**
|
||||
|
||||
**How it works:**
|
||||
- **Product Quantization (PQ) + Real-time Reranking**: Uses compressed PQ codes for fast graph traversal, then recomputes exact embeddings for final candidates
|
||||
- **Automatic Graph Partitioning**: When `is_recompute=True`, automatically partitions large indices and safely removes redundant files to save storage
|
||||
- **Superior Speed-Accuracy Trade-off**: Faster search than HNSW while maintaining high accuracy
|
||||
|
||||
**Trade-offs compared to HNSW:**
|
||||
- ✅ **Faster search latency** (typically 2-8x speedup)
|
||||
- ✅ **Better scaling** for large datasets
|
||||
- ✅ **Smart storage management** with automatic partitioning
|
||||
- ✅ **Better graph locality** with `--ldg-times` parameter for SSD optimization
|
||||
- ⚠️ **Slightly larger index size** due to PQ tables and graph metadata
|
||||
|
||||
```bash
|
||||
# Recommended for most use cases
|
||||
--backend-name diskann --graph-degree 32 --build-complexity 64
|
||||
|
||||
# For large-scale deployments
|
||||
--backend-name diskann --graph-degree 64 --build-complexity 128
|
||||
```
|
||||
|
||||
**Performance Benchmark**: Run `python benchmarks/diskann_vs_hnsw_speed_comparison.py` to compare DiskANN and HNSW on your system.
|
||||
|
||||
## LLM Selection: Engine and Model Comparison
|
||||
|
||||
### LLM Engines
|
||||
|
||||
**OpenAI** (`--llm openai`)
|
||||
- **Pros**: Best quality, consistent performance, no local resources needed
|
||||
- **Cons**: Costs money ($0.15-2.5 per million tokens), requires internet, data privacy concerns
|
||||
- **Models**: `gpt-4o-mini` (fast, cheap), `gpt-4o` (best quality), `o3` (reasoning), `o3-mini` (reasoning, cheaper)
|
||||
- **Thinking Budget**: Use `--thinking-budget low/medium/high` for o-series reasoning models (o3, o3-mini, o4-mini)
|
||||
- **Note**: Our current default, but we recommend switching to Ollama for most use cases
|
||||
|
||||
**Ollama** (`--llm ollama`)
|
||||
- **Pros**: Fully local, free, privacy-preserving, good model variety
|
||||
- **Cons**: Requires local GPU/CPU resources, slower than cloud APIs, need to install extra [ollama app](https://github.com/ollama/ollama?tab=readme-ov-file#ollama) and pre-download models by `ollama pull`
|
||||
- **Models**: `qwen3:0.6b` (ultra-fast), `qwen3:1.7b` (balanced), `qwen3:4b` (good quality), `qwen3:7b` (high quality), `deepseek-r1:1.5b` (reasoning)
|
||||
- **Thinking Budget**: Use `--thinking-budget low/medium/high` for reasoning models like GPT-Oss:20b
|
||||
|
||||
**HuggingFace** (`--llm hf`)
|
||||
- **Pros**: Free tier available, huge model selection, direct model loading (vs Ollama's server-based approach)
|
||||
- **Cons**: More complex initial setup
|
||||
- **Models**: `Qwen/Qwen3-1.7B-FP8`
|
||||
|
||||
## Parameter Tuning Guide
|
||||
|
||||
### Search Complexity Parameters
|
||||
|
||||
**`--build-complexity`** (index building)
|
||||
- Controls thoroughness during index construction
|
||||
- Higher = better recall but slower build
|
||||
- Recommendations:
|
||||
- 32: Quick prototyping
|
||||
- 64: Balanced (default)
|
||||
- 128: Production systems
|
||||
- 256: Maximum quality
|
||||
|
||||
**`--search-complexity`** (query time)
|
||||
- Controls search thoroughness
|
||||
- Higher = better results but slower
|
||||
- Recommendations:
|
||||
- 16: Fast/Interactive search
|
||||
- 32: High quality with diversity
|
||||
- 64+: Maximum accuracy
|
||||
|
||||
### Top-K Selection
|
||||
|
||||
**`--top-k`** (number of retrieved chunks)
|
||||
- More chunks = better context but slower LLM processing
|
||||
- Should be always smaller than `--search-complexity`
|
||||
- Guidelines:
|
||||
- 10-20: General questions (default: 20)
|
||||
- 30+: Complex multi-hop reasoning requiring comprehensive context
|
||||
|
||||
**Trade-off formula**:
|
||||
- Retrieval time ∝ log(n) × search_complexity
|
||||
- LLM processing time ∝ top_k × chunk_size
|
||||
- Total context = top_k × chunk_size tokens
|
||||
|
||||
### Thinking Budget for Reasoning Models
|
||||
|
||||
**`--thinking-budget`** (reasoning effort level)
|
||||
- Controls the computational effort for reasoning models
|
||||
- Options: `low`, `medium`, `high`
|
||||
- Guidelines:
|
||||
- `low`: Fast responses, basic reasoning (default for simple queries)
|
||||
- `medium`: Balanced speed and reasoning depth
|
||||
- `high`: Maximum reasoning effort, best for complex analytical questions
|
||||
- **Supported Models**:
|
||||
- **Ollama**: `gpt-oss:20b`, `gpt-oss:120b`
|
||||
- **OpenAI**: `o3`, `o3-mini`, `o4-mini`, `o1` (o-series reasoning models)
|
||||
- **Note**: Models without reasoning support will show a warning and proceed without reasoning parameters
|
||||
- **Example**: `--thinking-budget high` for complex analytical questions
|
||||
|
||||
**📖 For detailed usage examples and implementation details, check out [Thinking Budget Documentation](THINKING_BUDGET_FEATURE.md)**
|
||||
|
||||
**💡 Quick Examples:**
|
||||
```bash
|
||||
# OpenAI o-series reasoning model
|
||||
python apps/document_rag.py --query "What are the main techniques LEANN explores?" \
|
||||
--index-dir hnswbuild --backend hnsw \
|
||||
--llm openai --llm-model o3 --thinking-budget medium
|
||||
|
||||
# Ollama reasoning model
|
||||
python apps/document_rag.py --query "What are the main techniques LEANN explores?" \
|
||||
--index-dir hnswbuild --backend hnsw \
|
||||
--llm ollama --llm-model gpt-oss:20b --thinking-budget high
|
||||
```
|
||||
|
||||
### Graph Degree (HNSW/DiskANN)
|
||||
|
||||
**`--graph-degree`**
|
||||
- Number of connections per node in the graph
|
||||
- Higher = better recall but more memory
|
||||
- HNSW: 16-32 (default: 32)
|
||||
- DiskANN: 32-128 (default: 64)
|
||||
|
||||
|
||||
## Performance Optimization Checklist
|
||||
|
||||
### If Embedding is Too Slow
|
||||
|
||||
1. **Switch to smaller model**:
|
||||
```bash
|
||||
# From large model
|
||||
--embedding-model Qwen/Qwen3-Embedding-0.6B
|
||||
# To small model
|
||||
--embedding-model sentence-transformers/all-MiniLM-L6-v2
|
||||
```
|
||||
|
||||
2. **Limit dataset size for testing**:
|
||||
```bash
|
||||
--max-items 1000 # Process first 1k items only
|
||||
```
|
||||
|
||||
3. **Use MLX on Apple Silicon** (optional optimization):
|
||||
```bash
|
||||
--embedding-mode mlx --embedding-model mlx-community/Qwen3-Embedding-0.6B-8bit
|
||||
```
|
||||
MLX might not be the best choice, as we tested and found that it only offers 1.3x acceleration compared to HF, so maybe using ollama is a better choice for embedding generation
|
||||
|
||||
4. **Use Ollama**
|
||||
```bash
|
||||
--embedding-mode ollama --embedding-model nomic-embed-text
|
||||
```
|
||||
To discover additional embedding models in ollama, check out https://ollama.com/search?c=embedding or read more about embedding models at https://ollama.com/blog/embedding-models, please do check the model size that works best for you
|
||||
### If Search Quality is Poor
|
||||
|
||||
1. **Increase retrieval count**:
|
||||
```bash
|
||||
--top-k 30 # Retrieve more candidates
|
||||
```
|
||||
|
||||
2. **Upgrade embedding model**:
|
||||
```bash
|
||||
# For English
|
||||
--embedding-model BAAI/bge-base-en-v1.5
|
||||
# For multilingual
|
||||
--embedding-model intfloat/multilingual-e5-large
|
||||
```
|
||||
|
||||
## Understanding the Trade-offs
|
||||
|
||||
Every configuration choice involves trade-offs:
|
||||
|
||||
| Factor | Small/Fast | Large/Quality |
|
||||
|--------|------------|---------------|
|
||||
| Embedding Model | `all-MiniLM-L6-v2` | `Qwen/Qwen3-Embedding-0.6B` |
|
||||
| Chunk Size | 512 tokens | 128 tokens |
|
||||
| Index Type | HNSW | DiskANN |
|
||||
| LLM | `qwen3:1.7b` | `gpt-4o` |
|
||||
|
||||
The key is finding the right balance for your specific use case. Start small and simple, measure performance, then scale up only where needed.
|
||||
|
||||
## Deep Dive: Critical Configuration Decisions
|
||||
|
||||
### When to Disable Recomputation
|
||||
|
||||
LEANN's recomputation feature provides exact distance calculations but can be disabled for extreme QPS requirements:
|
||||
|
||||
```bash
|
||||
--no-recompute # Disable selective recomputation
|
||||
```
|
||||
|
||||
**Trade-offs**:
|
||||
- **With recomputation** (default): Exact distances, best quality, higher latency, minimal storage (only stores metadata, recomputes embeddings on-demand)
|
||||
- **Without recomputation**: Must store full embeddings, significantly higher memory and storage usage (10-100x more), but faster search
|
||||
|
||||
**Disable when**:
|
||||
- You have abundant storage and memory
|
||||
- Need extremely low latency (< 100ms)
|
||||
- Running a read-heavy workload where storage cost is acceptable
|
||||
|
||||
## Further Reading
|
||||
|
||||
- [Lessons Learned Developing LEANN](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
|
||||
- [LEANN Technical Paper](https://arxiv.org/abs/2506.08276)
|
||||
- [DiskANN Original Paper](https://papers.nips.cc/paper/2019/file/09853c7fb1d3f8ee67a61b6bf4a7f8e6-Paper.pdf)
|
||||
- [SSD-based Graph Partitioning](https://github.com/SonglinLife/SSD_BASED_PLAN)
|
||||
@@ -5,7 +5,7 @@
|
||||
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage with dynamic computation using optimized ZMQ servers and highly optimized search paradigm (overlapping and batching) with highly optimized embedding engine
|
||||
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
|
||||
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
|
||||
- **🏗️ Pluggable Backends** - DiskANN, HNSW/FAISS with unified API
|
||||
- **🏗️ Pluggable Backends** - HNSW/FAISS (default), with optional DiskANN for large-scale deployments
|
||||
|
||||
## 🛠️ Technical Highlights
|
||||
- **🔄 Recompute Mode** - Highest accuracy scenarios while eliminating vector storage overhead
|
||||
@@ -13,7 +13,7 @@
|
||||
- **🚀 High-throughput Embedding Pipeline** - Optimized batched processing for maximum efficiency
|
||||
- **🎯 Two-level Search** - Novel coarse-to-fine search overlap for accelerated query processing (optional)
|
||||
- **💾 Memory-mapped Indices** - Fast startup with raw text mapping to reduce memory overhead
|
||||
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](test/build_mlx_index.py))
|
||||
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](../examples/mlx_demo.py))
|
||||
|
||||
## 🎨 Developer Experience
|
||||
|
||||
|
||||
@@ -72,4 +72,4 @@ Using the wrong distance metric with normalized embeddings can lead to:
|
||||
- **Incorrect ranking** of search results
|
||||
- **Suboptimal performance** compared to using the correct metric
|
||||
|
||||
For more details on why this happens, see our analysis of [OpenAI embeddings with MIPS](../examples/main_cli_example.py).
|
||||
For more details on why this happens, see our analysis in the [embedding detection code](../packages/leann-core/src/leann/api.py) which automatically handles normalized embeddings and MIPS distance metric issues.
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
|
||||
## 🎯 Q2 2025
|
||||
|
||||
- [X] DiskANN backend with MIPS/L2/Cosine support
|
||||
- [X] HNSW backend integration
|
||||
- [X] DiskANN backend with MIPS/L2/Cosine support
|
||||
- [X] Real-time embedding pipeline
|
||||
- [X] Memory-efficient graph pruning
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""
|
||||
Simple demo showing basic leann usage
|
||||
Run: uv run python examples/simple_demo.py
|
||||
Run: uv run python examples/basic_demo.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@@ -81,7 +81,7 @@ def main():
|
||||
print()
|
||||
|
||||
print("Demo completed! Try running:")
|
||||
print(" uv run python examples/document_search.py")
|
||||
print(" uv run python apps/document_rag.py")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -1,158 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Document search demo with recompute mode
|
||||
"""
|
||||
|
||||
import shutil
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
# Import backend packages to trigger plugin registration
|
||||
try:
|
||||
import leann_backend_diskann # noqa: F401
|
||||
import leann_backend_hnsw # noqa: F401
|
||||
|
||||
print("INFO: Backend packages imported successfully.")
|
||||
except ImportError as e:
|
||||
print(f"WARNING: Could not import backend packages. Error: {e}")
|
||||
|
||||
# Import upper-level API from leann-core
|
||||
from leann.api import LeannBuilder, LeannChat, LeannSearcher
|
||||
|
||||
|
||||
def load_sample_documents():
|
||||
"""Create sample documents for demonstration"""
|
||||
docs = [
|
||||
{
|
||||
"title": "Intro to Python",
|
||||
"content": "Python is a high-level, interpreted language known for simplicity.",
|
||||
},
|
||||
{
|
||||
"title": "ML Basics",
|
||||
"content": "Machine learning builds systems that learn from data.",
|
||||
},
|
||||
{
|
||||
"title": "Data Structures",
|
||||
"content": "Data structures like arrays, lists, and graphs organize data.",
|
||||
},
|
||||
]
|
||||
return docs
|
||||
|
||||
|
||||
def main():
|
||||
print("==========================================================")
|
||||
print("=== Leann Document Search Demo (DiskANN + Recompute) ===")
|
||||
print("==========================================================")
|
||||
|
||||
INDEX_DIR = Path("./test_indices")
|
||||
INDEX_PATH = str(INDEX_DIR / "documents.diskann")
|
||||
BACKEND_TO_TEST = "diskann"
|
||||
|
||||
if INDEX_DIR.exists():
|
||||
print(f"--- Cleaning up old index directory: {INDEX_DIR} ---")
|
||||
shutil.rmtree(INDEX_DIR)
|
||||
|
||||
# --- 1. Build index ---
|
||||
print(f"\n[PHASE 1] Building index using '{BACKEND_TO_TEST}' backend...")
|
||||
|
||||
builder = LeannBuilder(backend_name=BACKEND_TO_TEST, graph_degree=32, complexity=64)
|
||||
|
||||
documents = load_sample_documents()
|
||||
print(f"Loaded {len(documents)} sample documents.")
|
||||
for doc in documents:
|
||||
builder.add_text(doc["content"], metadata={"title": doc["title"]})
|
||||
|
||||
builder.build_index(INDEX_PATH)
|
||||
print("\nIndex built!")
|
||||
|
||||
# --- 2. Basic search demo ---
|
||||
print(f"\n[PHASE 2] Basic search using '{BACKEND_TO_TEST}' backend...")
|
||||
searcher = LeannSearcher(index_path=INDEX_PATH)
|
||||
|
||||
query = "What is machine learning?"
|
||||
print(f"\nQuery: '{query}'")
|
||||
|
||||
print("\n--- Basic search mode (PQ computation) ---")
|
||||
start_time = time.time()
|
||||
results = searcher.search(query, top_k=2)
|
||||
basic_time = time.time() - start_time
|
||||
|
||||
print(f"⏱️ Basic search time: {basic_time:.3f} seconds")
|
||||
print(">>> Basic search results <<<")
|
||||
for i, res in enumerate(results, 1):
|
||||
print(
|
||||
f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}"
|
||||
)
|
||||
|
||||
# --- 3. Recompute search demo ---
|
||||
print("\n[PHASE 3] Recompute search using embedding server...")
|
||||
|
||||
print("\n--- Recompute search mode (get real embeddings via network) ---")
|
||||
|
||||
# Configure recompute parameters
|
||||
recompute_params = {
|
||||
"recompute_beighbor_embeddings": True, # Enable network recomputation
|
||||
"USE_DEFERRED_FETCH": False, # Don't use deferred fetch
|
||||
"skip_search_reorder": True, # Skip search reordering
|
||||
"dedup_node_dis": True, # Enable node distance deduplication
|
||||
"prune_ratio": 0.1, # Pruning ratio 10%
|
||||
"batch_recompute": False, # Don't use batch recomputation
|
||||
"global_pruning": False, # Don't use global pruning
|
||||
"zmq_port": 5555, # ZMQ port
|
||||
"embedding_model": "sentence-transformers/all-mpnet-base-v2",
|
||||
}
|
||||
|
||||
print("Recompute parameter configuration:")
|
||||
for key, value in recompute_params.items():
|
||||
print(f" {key}: {value}")
|
||||
|
||||
print("\n🔄 Executing Recompute search...")
|
||||
try:
|
||||
start_time = time.time()
|
||||
recompute_results = searcher.search(query, top_k=2, **recompute_params)
|
||||
recompute_time = time.time() - start_time
|
||||
|
||||
print(f"⏱️ Recompute search time: {recompute_time:.3f} seconds")
|
||||
print(">>> Recompute search results <<<")
|
||||
for i, res in enumerate(recompute_results, 1):
|
||||
print(
|
||||
f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}"
|
||||
)
|
||||
|
||||
# Compare results
|
||||
print("\n--- Result comparison ---")
|
||||
print(f"Basic search time: {basic_time:.3f} seconds")
|
||||
print(f"Recompute time: {recompute_time:.3f} seconds")
|
||||
|
||||
print("\nBasic search vs Recompute results:")
|
||||
for i in range(min(len(results), len(recompute_results))):
|
||||
basic_score = results[i].score
|
||||
recompute_score = recompute_results[i].score
|
||||
score_diff = abs(basic_score - recompute_score)
|
||||
print(
|
||||
f" Position {i + 1}: PQ={basic_score:.4f}, Recompute={recompute_score:.4f}, Difference={score_diff:.4f}"
|
||||
)
|
||||
|
||||
if recompute_time > basic_time:
|
||||
print("✅ Recompute mode working correctly (more accurate but slower)")
|
||||
else:
|
||||
print("i️ Recompute time is unusually fast, network recomputation may not be enabled")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Recompute search failed: {e}")
|
||||
print("This usually indicates an embedding server connection issue")
|
||||
|
||||
# --- 4. Chat demo ---
|
||||
print("\n[PHASE 4] Starting chat session...")
|
||||
chat = LeannChat(index_path=INDEX_PATH)
|
||||
chat_response = chat.ask(query)
|
||||
print(f"You: {query}")
|
||||
print(f"Leann: {chat_response}")
|
||||
|
||||
print("\n==========================================================")
|
||||
print("✅ Demo finished successfully!")
|
||||
print("==========================================================")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,324 +0,0 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
try:
|
||||
import dotenv
|
||||
|
||||
dotenv.load_dotenv()
|
||||
except ModuleNotFoundError:
|
||||
# python-dotenv is not installed; skip loading environment variables
|
||||
dotenv = None
|
||||
from pathlib import Path
|
||||
|
||||
from leann.api import LeannBuilder, LeannChat
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
# dotenv.load_dotenv() # handled above if python-dotenv is available
|
||||
|
||||
# Default Chrome profile path
|
||||
DEFAULT_CHROME_PROFILE = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
|
||||
|
||||
|
||||
def create_leann_index_from_multiple_chrome_profiles(
|
||||
profile_dirs: list[Path],
|
||||
index_path: str = "chrome_history_index.leann",
|
||||
max_count: int = -1,
|
||||
):
|
||||
"""
|
||||
Create LEANN index from multiple Chrome profile data sources.
|
||||
|
||||
Args:
|
||||
profile_dirs: List of Path objects pointing to Chrome profile directories
|
||||
index_path: Path to save the LEANN index
|
||||
max_count: Maximum number of history entries to process per profile
|
||||
"""
|
||||
print("Creating LEANN index from multiple Chrome profile data sources...")
|
||||
|
||||
# Load documents using ChromeHistoryReader from history_data
|
||||
from history_data.history import ChromeHistoryReader
|
||||
|
||||
reader = ChromeHistoryReader()
|
||||
|
||||
INDEX_DIR = Path(index_path).parent
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print("--- Index directory not found, building new index ---")
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
# Process each Chrome profile directory
|
||||
for i, profile_dir in enumerate(profile_dirs):
|
||||
print(f"\nProcessing Chrome profile {i + 1}/{len(profile_dirs)}: {profile_dir}")
|
||||
|
||||
try:
|
||||
documents = reader.load_data(
|
||||
chrome_profile_path=str(profile_dir), max_count=max_count
|
||||
)
|
||||
if documents:
|
||||
print(f"Loaded {len(documents)} history documents from {profile_dir}")
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
|
||||
# Check if we've reached the max count
|
||||
if max_count > 0 and total_processed >= max_count:
|
||||
print(f"Reached max count of {max_count} documents")
|
||||
break
|
||||
else:
|
||||
print(f"No documents loaded from {profile_dir}")
|
||||
except Exception as e:
|
||||
print(f"Error processing {profile_dir}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No documents loaded from any source. Exiting.")
|
||||
# highlight info that you need to close all chrome browser before running this script and high light the instruction!!
|
||||
print(
|
||||
"\033[91mYou need to close or quit all chrome browser before running this script\033[0m"
|
||||
)
|
||||
return None
|
||||
|
||||
print(
|
||||
f"\nTotal loaded {len(all_documents)} history documents from {len(profile_dirs)} profiles"
|
||||
)
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
for doc in all_documents:
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
text = node.get_content()
|
||||
# text = '[Title] ' + doc.metadata["title"] + '\n' + text
|
||||
all_texts.append(text)
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
|
||||
|
||||
# Create LEANN index directory
|
||||
print("--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print("--- Building new LEANN index ---")
|
||||
|
||||
print("\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="facebook/contriever",
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1, # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Adding {len(all_texts)} history chunks to index...")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"\nLEANN index built at {index_path}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
return index_path
|
||||
|
||||
|
||||
def create_leann_index(
|
||||
profile_path: str | None = None,
|
||||
index_path: str = "chrome_history_index.leann",
|
||||
max_count: int = 1000,
|
||||
):
|
||||
"""
|
||||
Create LEANN index from Chrome history data.
|
||||
|
||||
Args:
|
||||
profile_path: Path to the Chrome profile directory (optional, uses default if None)
|
||||
index_path: Path to save the LEANN index
|
||||
max_count: Maximum number of history entries to process
|
||||
"""
|
||||
print("Creating LEANN index from Chrome history data...")
|
||||
INDEX_DIR = Path(index_path).parent
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print("--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print("--- Building new LEANN index ---")
|
||||
|
||||
print("\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Load documents using ChromeHistoryReader from history_data
|
||||
from history_data.history import ChromeHistoryReader
|
||||
|
||||
reader = ChromeHistoryReader()
|
||||
|
||||
documents = reader.load_data(chrome_profile_path=profile_path, max_count=max_count)
|
||||
|
||||
if not documents:
|
||||
print("No documents loaded. Exiting.")
|
||||
return None
|
||||
|
||||
print(f"Loaded {len(documents)} history documents")
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
|
||||
|
||||
# Create LEANN index directory
|
||||
print("--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print("--- Building new LEANN index ---")
|
||||
|
||||
print("\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="facebook/contriever",
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1, # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Adding {len(all_texts)} history chunks to index...")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"\nLEANN index built at {index_path}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
return index_path
|
||||
|
||||
|
||||
async def query_leann_index(index_path: str, query: str):
|
||||
"""
|
||||
Query the LEANN index.
|
||||
|
||||
Args:
|
||||
index_path: Path to the LEANN index
|
||||
query: The query string
|
||||
"""
|
||||
print("\n[PHASE 2] Starting Leann chat session...")
|
||||
chat = LeannChat(index_path=index_path)
|
||||
|
||||
print(f"You: {query}")
|
||||
chat_response = chat.ask(
|
||||
query,
|
||||
top_k=10,
|
||||
recompute_beighbor_embeddings=True,
|
||||
complexity=32,
|
||||
beam_width=1,
|
||||
llm_config={
|
||||
"type": "openai",
|
||||
"model": "gpt-4o",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
},
|
||||
llm_kwargs={"temperature": 0.0, "max_tokens": 1000},
|
||||
)
|
||||
|
||||
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
|
||||
|
||||
|
||||
async def main():
|
||||
# Parse command line arguments
|
||||
parser = argparse.ArgumentParser(
|
||||
description="LEANN Chrome History Reader - Create and query browser history index"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chrome-profile",
|
||||
type=str,
|
||||
default=DEFAULT_CHROME_PROFILE,
|
||||
help=f"Path to Chrome profile directory (default: {DEFAULT_CHROME_PROFILE}), usually you dont need to change this",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index-dir",
|
||||
type=str,
|
||||
default="./google_history_index",
|
||||
help="Directory to store the LEANN index (default: ./chrome_history_index_leann_test)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-entries",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Maximum number of history entries to process (default: 1000)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--query",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Single query to run (default: runs example queries)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--auto-find-profiles",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Automatically find all Chrome profiles (default: True)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
INDEX_DIR = Path(args.index_dir)
|
||||
INDEX_PATH = str(INDEX_DIR / "chrome_history.leann")
|
||||
|
||||
print(f"Using Chrome profile: {args.chrome_profile}")
|
||||
print(f"Index directory: {INDEX_DIR}")
|
||||
print(f"Max entries: {args.max_entries}")
|
||||
|
||||
# Find Chrome profile directories
|
||||
from history_data.history import ChromeHistoryReader
|
||||
|
||||
if args.auto_find_profiles:
|
||||
profile_dirs = ChromeHistoryReader.find_chrome_profiles()
|
||||
if not profile_dirs:
|
||||
print("No Chrome profiles found automatically. Exiting.")
|
||||
return
|
||||
else:
|
||||
# Use single specified profile
|
||||
profile_path = Path(args.chrome_profile)
|
||||
if not profile_path.exists():
|
||||
print(f"Chrome profile not found: {profile_path}")
|
||||
return
|
||||
profile_dirs = [profile_path]
|
||||
|
||||
# Create or load the LEANN index from all sources
|
||||
index_path = create_leann_index_from_multiple_chrome_profiles(
|
||||
profile_dirs, INDEX_PATH, args.max_entries
|
||||
)
|
||||
|
||||
if index_path:
|
||||
if args.query:
|
||||
# Run single query
|
||||
await query_leann_index(index_path, args.query)
|
||||
else:
|
||||
# Example queries
|
||||
queries = [
|
||||
"What websites did I visit about machine learning?",
|
||||
"Find my search history about programming",
|
||||
]
|
||||
|
||||
for query in queries:
|
||||
print("\n" + "=" * 60)
|
||||
await query_leann_index(index_path, query)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,342 +0,0 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import dotenv
|
||||
|
||||
# Add the project root to Python path so we can import from examples
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from leann.api import LeannBuilder, LeannChat
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
|
||||
# Auto-detect user's mail path
|
||||
def get_mail_path():
|
||||
"""Get the mail path for the current user"""
|
||||
home_dir = os.path.expanduser("~")
|
||||
return os.path.join(home_dir, "Library", "Mail")
|
||||
|
||||
|
||||
# Default mail path for macOS
|
||||
DEFAULT_MAIL_PATH = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data"
|
||||
|
||||
|
||||
def create_leann_index_from_multiple_sources(
|
||||
messages_dirs: list[Path],
|
||||
index_path: str = "mail_index.leann",
|
||||
max_count: int = -1,
|
||||
include_html: bool = False,
|
||||
embedding_model: str = "facebook/contriever",
|
||||
):
|
||||
"""
|
||||
Create LEANN index from multiple mail data sources.
|
||||
|
||||
Args:
|
||||
messages_dirs: List of Path objects pointing to Messages directories
|
||||
index_path: Path to save the LEANN index
|
||||
max_count: Maximum number of emails to process per directory
|
||||
include_html: Whether to include HTML content in email processing
|
||||
"""
|
||||
print("Creating LEANN index from multiple mail data sources...")
|
||||
|
||||
# Load documents using EmlxReader from LEANN_email_reader
|
||||
from examples.email_data.LEANN_email_reader import EmlxReader
|
||||
|
||||
reader = EmlxReader(include_html=include_html)
|
||||
# from email_data.email import EmlxMboxReader
|
||||
# from pathlib import Path
|
||||
# reader = EmlxMboxReader()
|
||||
INDEX_DIR = Path(index_path).parent
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print("--- Index directory not found, building new index ---")
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
# Process each Messages directory
|
||||
for i, messages_dir in enumerate(messages_dirs):
|
||||
print(f"\nProcessing Messages directory {i + 1}/{len(messages_dirs)}: {messages_dir}")
|
||||
|
||||
try:
|
||||
documents = reader.load_data(messages_dir)
|
||||
if documents:
|
||||
print(f"Loaded {len(documents)} email documents from {messages_dir}")
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
|
||||
# Check if we've reached the max count
|
||||
if max_count > 0 and total_processed >= max_count:
|
||||
print(f"Reached max count of {max_count} documents")
|
||||
break
|
||||
else:
|
||||
print(f"No documents loaded from {messages_dir}")
|
||||
except Exception as e:
|
||||
print(f"Error processing {messages_dir}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No documents loaded from any source. Exiting.")
|
||||
return None
|
||||
|
||||
print(
|
||||
f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories and starting to split them into chunks"
|
||||
)
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
for doc in all_documents:
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
text = node.get_content()
|
||||
# text = '[subject] ' + doc.metadata["subject"] + '\n' + text
|
||||
all_texts.append(text)
|
||||
|
||||
print(
|
||||
f"Finished splitting {len(all_documents)} documents into {len(all_texts)} text chunks"
|
||||
)
|
||||
|
||||
# Create LEANN index directory
|
||||
|
||||
print("--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print("--- Building new LEANN index ---")
|
||||
|
||||
print("\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=embedding_model,
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1, # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Adding {len(all_texts)} email chunks to index...")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"\nLEANN index built at {index_path}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
return index_path
|
||||
|
||||
|
||||
def create_leann_index(
|
||||
mail_path: str,
|
||||
index_path: str = "mail_index.leann",
|
||||
max_count: int = 1000,
|
||||
include_html: bool = False,
|
||||
embedding_model: str = "facebook/contriever",
|
||||
):
|
||||
"""
|
||||
Create LEANN index from mail data.
|
||||
|
||||
Args:
|
||||
mail_path: Path to the mail directory
|
||||
index_path: Path to save the LEANN index
|
||||
max_count: Maximum number of emails to process
|
||||
include_html: Whether to include HTML content in email processing
|
||||
"""
|
||||
print("Creating LEANN index from mail data...")
|
||||
INDEX_DIR = Path(index_path).parent
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print("--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print("--- Building new LEANN index ---")
|
||||
|
||||
print("\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Load documents using EmlxReader from LEANN_email_reader
|
||||
from examples.email_data.LEANN_email_reader import EmlxReader
|
||||
|
||||
reader = EmlxReader(include_html=include_html)
|
||||
# from email_data.email import EmlxMboxReader
|
||||
# from pathlib import Path
|
||||
# reader = EmlxMboxReader()
|
||||
documents = reader.load_data(Path(mail_path))
|
||||
|
||||
if not documents:
|
||||
print("No documents loaded. Exiting.")
|
||||
return None
|
||||
|
||||
print(f"Loaded {len(documents)} email documents")
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
|
||||
|
||||
# Create LEANN index directory
|
||||
|
||||
print("--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print("--- Building new LEANN index ---")
|
||||
|
||||
print("\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=embedding_model,
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1, # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Adding {len(all_texts)} email chunks to index...")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"\nLEANN index built at {index_path}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
return index_path
|
||||
|
||||
|
||||
async def query_leann_index(index_path: str, query: str):
|
||||
"""
|
||||
Query the LEANN index.
|
||||
|
||||
Args:
|
||||
index_path: Path to the LEANN index
|
||||
query: The query string
|
||||
"""
|
||||
print("\n[PHASE 2] Starting Leann chat session...")
|
||||
chat = LeannChat(index_path=index_path, llm_config={"type": "openai", "model": "gpt-4o"})
|
||||
|
||||
print(f"You: {query}")
|
||||
import time
|
||||
|
||||
time.time()
|
||||
chat_response = chat.ask(
|
||||
query,
|
||||
top_k=20,
|
||||
recompute_beighbor_embeddings=True,
|
||||
complexity=32,
|
||||
beam_width=1,
|
||||
)
|
||||
time.time()
|
||||
# print(f"Time taken: {end_time - start_time} seconds")
|
||||
# highlight the answer
|
||||
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
|
||||
|
||||
|
||||
async def main():
|
||||
# Parse command line arguments
|
||||
parser = argparse.ArgumentParser(description="LEANN Mail Reader - Create and query email index")
|
||||
# Remove --mail-path argument and auto-detect all Messages directories
|
||||
# Remove DEFAULT_MAIL_PATH
|
||||
parser.add_argument(
|
||||
"--index-dir",
|
||||
type=str,
|
||||
default="./mail_index",
|
||||
help="Directory to store the LEANN index (default: ./mail_index_leann_raw_text_all_dicts)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-emails",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Maximum number of emails to process (-1 means all)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--query",
|
||||
type=str,
|
||||
default="Give me some funny advertisement about apple or other companies",
|
||||
help="Single query to run (default: runs example queries)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--include-html",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Include HTML content in email processing (default: False)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-model",
|
||||
type=str,
|
||||
default="facebook/contriever",
|
||||
help="Embedding model to use (default: facebook/contriever)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"args: {args}")
|
||||
|
||||
# Automatically find all Messages directories under the current user's Mail directory
|
||||
from examples.email_data.LEANN_email_reader import find_all_messages_directories
|
||||
|
||||
mail_path = get_mail_path()
|
||||
print(f"Searching for email data in: {mail_path}")
|
||||
messages_dirs = find_all_messages_directories(mail_path)
|
||||
# messages_dirs = find_all_messages_directories(DEFAULT_MAIL_PATH)
|
||||
# messages_dirs = [DEFAULT_MAIL_PATH]
|
||||
# messages_dirs = messages_dirs[:1]
|
||||
|
||||
print("len(messages_dirs): ", len(messages_dirs))
|
||||
|
||||
if not messages_dirs:
|
||||
print("No Messages directories found. Exiting.")
|
||||
return
|
||||
|
||||
INDEX_DIR = Path(args.index_dir)
|
||||
INDEX_PATH = str(INDEX_DIR / "mail_documents.leann")
|
||||
print(f"Index directory: {INDEX_DIR}")
|
||||
print(f"Found {len(messages_dirs)} Messages directories.")
|
||||
|
||||
# Create or load the LEANN index from all sources
|
||||
index_path = create_leann_index_from_multiple_sources(
|
||||
messages_dirs,
|
||||
INDEX_PATH,
|
||||
args.max_emails,
|
||||
args.include_html,
|
||||
args.embedding_model,
|
||||
)
|
||||
|
||||
if index_path:
|
||||
if args.query:
|
||||
# Run single query
|
||||
await query_leann_index(index_path, args.query)
|
||||
else:
|
||||
# Example queries
|
||||
queries = [
|
||||
"Hows Berkeley Graduate Student Instructor",
|
||||
"how's the icloud related advertisement saying",
|
||||
"Whats the number of class recommend to take per semester for incoming EECS students",
|
||||
]
|
||||
for query in queries:
|
||||
print("\n" + "=" * 60)
|
||||
await query_leann_index(index_path, query)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,135 +0,0 @@
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add the project root to Python path so we can import from examples
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
import torch
|
||||
from llama_index.core import StorageContext, VectorStoreIndex
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
# --- EMBEDDING MODEL ---
|
||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||
|
||||
# --- END EMBEDDING MODEL ---
|
||||
# Import EmlxReader from the new module
|
||||
from examples.email_data.LEANN_email_reader import EmlxReader
|
||||
|
||||
|
||||
def create_and_save_index(
|
||||
mail_path: str,
|
||||
save_dir: str = "mail_index_embedded",
|
||||
max_count: int = 1000,
|
||||
include_html: bool = False,
|
||||
):
|
||||
print("Creating index from mail data with embedded metadata...")
|
||||
documents = EmlxReader(include_html=include_html).load_data(mail_path, max_count=max_count)
|
||||
if not documents:
|
||||
print("No documents loaded. Exiting.")
|
||||
return None
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||
# Use facebook/contriever as the embedder
|
||||
embed_model = HuggingFaceEmbedding(model_name="facebook/contriever")
|
||||
# set on device
|
||||
|
||||
if torch.cuda.is_available():
|
||||
embed_model._model.to("cuda")
|
||||
# set mps
|
||||
elif torch.backends.mps.is_available():
|
||||
embed_model._model.to("mps")
|
||||
else:
|
||||
embed_model._model.to("cpu")
|
||||
index = VectorStoreIndex.from_documents(
|
||||
documents, transformations=[text_splitter], embed_model=embed_model
|
||||
)
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
index.storage_context.persist(persist_dir=save_dir)
|
||||
print(f"Index saved to {save_dir}")
|
||||
return index
|
||||
|
||||
|
||||
def load_index(save_dir: str = "mail_index_embedded"):
|
||||
try:
|
||||
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
|
||||
index = VectorStoreIndex.from_vector_store(
|
||||
storage_context.vector_store, storage_context=storage_context
|
||||
)
|
||||
print(f"Index loaded from {save_dir}")
|
||||
return index
|
||||
except Exception as e:
|
||||
print(f"Error loading index: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def query_index(index, query: str):
|
||||
if index is None:
|
||||
print("No index available for querying.")
|
||||
return
|
||||
query_engine = index.as_query_engine()
|
||||
response = query_engine.query(query)
|
||||
print(f"Query: {query}")
|
||||
print(f"Response: {response}")
|
||||
|
||||
|
||||
def main():
|
||||
# Parse command line arguments
|
||||
parser = argparse.ArgumentParser(
|
||||
description="LlamaIndex Mail Reader - Create and query email index"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mail-path",
|
||||
type=str,
|
||||
default="/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages",
|
||||
help="Path to mail data directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-dir",
|
||||
type=str,
|
||||
default="mail_index_embedded",
|
||||
help="Directory to store the index (default: mail_index_embedded)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-emails",
|
||||
type=int,
|
||||
default=10000,
|
||||
help="Maximum number of emails to process",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--include-html",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Include HTML content in email processing (default: False)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
mail_path = args.mail_path
|
||||
save_dir = args.save_dir
|
||||
|
||||
if os.path.exists(save_dir) and os.path.exists(os.path.join(save_dir, "vector_store.json")):
|
||||
print("Loading existing index...")
|
||||
index = load_index(save_dir)
|
||||
else:
|
||||
print("Creating new index...")
|
||||
index = create_and_save_index(
|
||||
mail_path,
|
||||
save_dir,
|
||||
max_count=args.max_emails,
|
||||
include_html=args.include_html,
|
||||
)
|
||||
if index:
|
||||
queries = [
|
||||
"Hows Berkeley Graduate Student Instructor",
|
||||
"how's the icloud related advertisement saying",
|
||||
"Whats the number of class recommend to take per semester for incoming EECS students",
|
||||
]
|
||||
for query in queries:
|
||||
print("\n" + "=" * 50)
|
||||
query_index(index, query)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,136 +0,0 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
import dotenv
|
||||
from leann.api import LeannBuilder, LeannChat
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
|
||||
async def main(args):
|
||||
INDEX_DIR = Path(args.index_dir)
|
||||
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
node_parser = SentenceSplitter(
|
||||
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
|
||||
)
|
||||
|
||||
print("Loading documents...")
|
||||
documents = SimpleDirectoryReader(
|
||||
args.data_dir,
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=[".pdf", ".txt", ".md"],
|
||||
).load_data(show_progress=True)
|
||||
print("Documents loaded.")
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
nodes = node_parser.get_nodes_from_documents([doc])
|
||||
if nodes:
|
||||
all_texts.extend(node.get_content() for node in nodes)
|
||||
|
||||
print("--- Index directory not found, building new index ---")
|
||||
|
||||
print("\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# LeannBuilder now automatically detects normalized embeddings and sets appropriate distance metric
|
||||
print(f"Using {args.embedding_model} with {args.embedding_mode} mode")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=args.embedding_model,
|
||||
embedding_mode=args.embedding_mode,
|
||||
# distance_metric is automatically set based on embedding model
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1, # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Loaded {len(all_texts)} text chunks from documents.")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(INDEX_PATH)
|
||||
print(f"\nLeann index built at {INDEX_PATH}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
print("\n[PHASE 2] Starting Leann chat session...")
|
||||
|
||||
llm_config = {"type": "hf", "model": "Qwen/Qwen3-4B"}
|
||||
llm_config = {"type": "ollama", "model": "qwen3:8b"}
|
||||
llm_config = {"type": "openai", "model": "gpt-4o"}
|
||||
|
||||
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
|
||||
# query = (
|
||||
# "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
|
||||
# )
|
||||
query = args.query
|
||||
|
||||
print(f"You: {query}")
|
||||
chat_response = chat.ask(query, top_k=20, recompute_embeddings=True, complexity=32)
|
||||
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run Leann Chat with various LLM backends.")
|
||||
parser.add_argument(
|
||||
"--llm",
|
||||
type=str,
|
||||
default="hf",
|
||||
choices=["simulated", "ollama", "hf", "openai"],
|
||||
help="The LLM backend to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="Qwen/Qwen3-0.6B",
|
||||
help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf, 'gpt-4o' for openai).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-model",
|
||||
type=str,
|
||||
default="facebook/contriever",
|
||||
help="The embedding model to use (e.g., 'facebook/contriever', 'text-embedding-3-small').",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-mode",
|
||||
type=str,
|
||||
default="sentence-transformers",
|
||||
choices=["sentence-transformers", "openai", "mlx"],
|
||||
help="The embedding backend mode.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default="http://localhost:11434",
|
||||
help="The host for the Ollama API.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index-dir",
|
||||
type=str,
|
||||
default="./test_doc_files",
|
||||
help="Directory where the Leann index will be stored.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data-dir",
|
||||
type=str,
|
||||
default="examples/data",
|
||||
help="Directory containing documents to index (PDF, TXT, MD files).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--query",
|
||||
type=str,
|
||||
default="Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?",
|
||||
help="The query to ask the Leann chat system.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
asyncio.run(main(args))
|
||||
@@ -1,360 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Multi-Vector Aggregator for Fat Embeddings
|
||||
==========================================
|
||||
|
||||
This module implements aggregation strategies for multi-vector embeddings,
|
||||
similar to ColPali's approach where multiple patch vectors represent a single document.
|
||||
|
||||
Key features:
|
||||
- MaxSim aggregation (take maximum similarity across patches)
|
||||
- Voting-based aggregation (count patch matches)
|
||||
- Weighted aggregation (attention-score weighted)
|
||||
- Spatial clustering of matching patches
|
||||
- Document-level result consolidation
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass
|
||||
class PatchResult:
|
||||
"""Represents a single patch search result."""
|
||||
|
||||
patch_id: int
|
||||
image_name: str
|
||||
image_path: str
|
||||
coordinates: tuple[int, int, int, int] # (x1, y1, x2, y2)
|
||||
score: float
|
||||
attention_score: float
|
||||
scale: float
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class AggregatedResult:
|
||||
"""Represents an aggregated document-level result."""
|
||||
|
||||
image_name: str
|
||||
image_path: str
|
||||
doc_score: float
|
||||
patch_count: int
|
||||
best_patch: PatchResult
|
||||
all_patches: list[PatchResult]
|
||||
aggregation_method: str
|
||||
spatial_clusters: list[list[PatchResult]] | None = None
|
||||
|
||||
|
||||
class MultiVectorAggregator:
|
||||
"""
|
||||
Aggregates multiple patch-level results into document-level results.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
aggregation_method: str = "maxsim",
|
||||
spatial_clustering: bool = True,
|
||||
cluster_distance_threshold: float = 100.0,
|
||||
):
|
||||
"""
|
||||
Initialize the aggregator.
|
||||
|
||||
Args:
|
||||
aggregation_method: "maxsim", "voting", "weighted", or "mean"
|
||||
spatial_clustering: Whether to cluster spatially close patches
|
||||
cluster_distance_threshold: Distance threshold for spatial clustering
|
||||
"""
|
||||
self.aggregation_method = aggregation_method
|
||||
self.spatial_clustering = spatial_clustering
|
||||
self.cluster_distance_threshold = cluster_distance_threshold
|
||||
|
||||
def aggregate_results(
|
||||
self, search_results: list[dict[str, Any]], top_k: int = 10
|
||||
) -> list[AggregatedResult]:
|
||||
"""
|
||||
Aggregate patch-level search results into document-level results.
|
||||
|
||||
Args:
|
||||
search_results: List of search results from LeannSearcher
|
||||
top_k: Number of top documents to return
|
||||
|
||||
Returns:
|
||||
List of aggregated document results
|
||||
"""
|
||||
# Group results by image
|
||||
image_groups = defaultdict(list)
|
||||
|
||||
for result in search_results:
|
||||
metadata = result.metadata
|
||||
if "image_name" in metadata and "patch_id" in metadata:
|
||||
patch_result = PatchResult(
|
||||
patch_id=metadata["patch_id"],
|
||||
image_name=metadata["image_name"],
|
||||
image_path=metadata["image_path"],
|
||||
coordinates=tuple(metadata["coordinates"]),
|
||||
score=result.score,
|
||||
attention_score=metadata.get("attention_score", 0.0),
|
||||
scale=metadata.get("scale", 1.0),
|
||||
metadata=metadata,
|
||||
)
|
||||
image_groups[metadata["image_name"]].append(patch_result)
|
||||
|
||||
# Aggregate each image group
|
||||
aggregated_results = []
|
||||
for image_name, patches in image_groups.items():
|
||||
if len(patches) == 0:
|
||||
continue
|
||||
|
||||
agg_result = self._aggregate_image_patches(image_name, patches)
|
||||
aggregated_results.append(agg_result)
|
||||
|
||||
# Sort by aggregated score and return top-k
|
||||
aggregated_results.sort(key=lambda x: x.doc_score, reverse=True)
|
||||
return aggregated_results[:top_k]
|
||||
|
||||
def _aggregate_image_patches(
|
||||
self, image_name: str, patches: list[PatchResult]
|
||||
) -> AggregatedResult:
|
||||
"""Aggregate patches for a single image."""
|
||||
|
||||
if self.aggregation_method == "maxsim":
|
||||
doc_score = max(patch.score for patch in patches)
|
||||
best_patch = max(patches, key=lambda p: p.score)
|
||||
|
||||
elif self.aggregation_method == "voting":
|
||||
# Count patches above threshold
|
||||
threshold = np.percentile([p.score for p in patches], 75)
|
||||
doc_score = sum(1 for patch in patches if patch.score >= threshold)
|
||||
best_patch = max(patches, key=lambda p: p.score)
|
||||
|
||||
elif self.aggregation_method == "weighted":
|
||||
# Weight by attention scores
|
||||
total_weighted_score = sum(p.score * p.attention_score for p in patches)
|
||||
total_weights = sum(p.attention_score for p in patches)
|
||||
doc_score = total_weighted_score / max(total_weights, 1e-8)
|
||||
best_patch = max(patches, key=lambda p: p.score * p.attention_score)
|
||||
|
||||
elif self.aggregation_method == "mean":
|
||||
doc_score = np.mean([patch.score for patch in patches])
|
||||
best_patch = max(patches, key=lambda p: p.score)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown aggregation method: {self.aggregation_method}")
|
||||
|
||||
# Spatial clustering if enabled
|
||||
spatial_clusters = None
|
||||
if self.spatial_clustering:
|
||||
spatial_clusters = self._cluster_patches_spatially(patches)
|
||||
|
||||
return AggregatedResult(
|
||||
image_name=image_name,
|
||||
image_path=patches[0].image_path,
|
||||
doc_score=float(doc_score),
|
||||
patch_count=len(patches),
|
||||
best_patch=best_patch,
|
||||
all_patches=sorted(patches, key=lambda p: p.score, reverse=True),
|
||||
aggregation_method=self.aggregation_method,
|
||||
spatial_clusters=spatial_clusters,
|
||||
)
|
||||
|
||||
def _cluster_patches_spatially(self, patches: list[PatchResult]) -> list[list[PatchResult]]:
|
||||
"""Cluster patches that are spatially close to each other."""
|
||||
if len(patches) <= 1:
|
||||
return [patches]
|
||||
|
||||
clusters = []
|
||||
remaining_patches = patches.copy()
|
||||
|
||||
while remaining_patches:
|
||||
# Start new cluster with highest scoring remaining patch
|
||||
seed_patch = max(remaining_patches, key=lambda p: p.score)
|
||||
current_cluster = [seed_patch]
|
||||
remaining_patches.remove(seed_patch)
|
||||
|
||||
# Add nearby patches to cluster
|
||||
added_to_cluster = True
|
||||
while added_to_cluster:
|
||||
added_to_cluster = False
|
||||
for patch in remaining_patches.copy():
|
||||
if self._is_patch_nearby(patch, current_cluster):
|
||||
current_cluster.append(patch)
|
||||
remaining_patches.remove(patch)
|
||||
added_to_cluster = True
|
||||
|
||||
clusters.append(current_cluster)
|
||||
|
||||
return sorted(clusters, key=lambda cluster: max(p.score for p in cluster), reverse=True)
|
||||
|
||||
def _is_patch_nearby(self, patch: PatchResult, cluster: list[PatchResult]) -> bool:
|
||||
"""Check if a patch is spatially close to any patch in the cluster."""
|
||||
patch_center = self._get_patch_center(patch.coordinates)
|
||||
|
||||
for cluster_patch in cluster:
|
||||
cluster_center = self._get_patch_center(cluster_patch.coordinates)
|
||||
distance = np.sqrt(
|
||||
(patch_center[0] - cluster_center[0]) ** 2
|
||||
+ (patch_center[1] - cluster_center[1]) ** 2
|
||||
)
|
||||
|
||||
if distance <= self.cluster_distance_threshold:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _get_patch_center(self, coordinates: tuple[int, int, int, int]) -> tuple[float, float]:
|
||||
"""Get center point of a patch."""
|
||||
x1, y1, x2, y2 = coordinates
|
||||
return ((x1 + x2) / 2, (y1 + y2) / 2)
|
||||
|
||||
def print_aggregated_results(
|
||||
self, results: list[AggregatedResult], max_patches_per_doc: int = 3
|
||||
):
|
||||
"""Pretty print aggregated results."""
|
||||
print(f"\n🔍 Aggregated Results (method: {self.aggregation_method})")
|
||||
print("=" * 80)
|
||||
|
||||
for i, result in enumerate(results):
|
||||
print(f"\n{i + 1}. {result.image_name}")
|
||||
print(f" Doc Score: {result.doc_score:.4f} | Patches: {result.patch_count}")
|
||||
print(f" Path: {result.image_path}")
|
||||
|
||||
# Show best patch
|
||||
best = result.best_patch
|
||||
print(
|
||||
f" 🌟 Best Patch: #{best.patch_id} at {best.coordinates} (score: {best.score:.4f})"
|
||||
)
|
||||
|
||||
# Show top patches
|
||||
print(" 📍 Top Patches:")
|
||||
for j, patch in enumerate(result.all_patches[:max_patches_per_doc]):
|
||||
print(
|
||||
f" {j + 1}. Patch #{patch.patch_id}: {patch.score:.4f} at {patch.coordinates}"
|
||||
)
|
||||
|
||||
# Show spatial clusters if available
|
||||
if result.spatial_clusters and len(result.spatial_clusters) > 1:
|
||||
print(f" 🗂️ Spatial Clusters: {len(result.spatial_clusters)}")
|
||||
for j, cluster in enumerate(result.spatial_clusters[:2]): # Show top 2 clusters
|
||||
cluster_score = max(p.score for p in cluster)
|
||||
print(
|
||||
f" Cluster {j + 1}: {len(cluster)} patches (best: {cluster_score:.4f})"
|
||||
)
|
||||
|
||||
|
||||
def demo_aggregation():
|
||||
"""Demonstrate the multi-vector aggregation functionality."""
|
||||
print("=== Multi-Vector Aggregation Demo ===")
|
||||
|
||||
# Simulate some patch-level search results
|
||||
# In real usage, these would come from LeannSearcher.search()
|
||||
|
||||
class MockResult:
|
||||
def __init__(self, score, metadata):
|
||||
self.score = score
|
||||
self.metadata = metadata
|
||||
|
||||
# Simulate results for 2 images with multiple patches each
|
||||
mock_results = [
|
||||
# Image 1: cats_and_kitchen.jpg - 4 patches
|
||||
MockResult(
|
||||
0.85,
|
||||
{
|
||||
"image_name": "cats_and_kitchen.jpg",
|
||||
"image_path": "/path/to/cats_and_kitchen.jpg",
|
||||
"patch_id": 3,
|
||||
"coordinates": [100, 50, 224, 174], # Kitchen area
|
||||
"attention_score": 0.92,
|
||||
"scale": 1.0,
|
||||
},
|
||||
),
|
||||
MockResult(
|
||||
0.78,
|
||||
{
|
||||
"image_name": "cats_and_kitchen.jpg",
|
||||
"image_path": "/path/to/cats_and_kitchen.jpg",
|
||||
"patch_id": 7,
|
||||
"coordinates": [200, 300, 324, 424], # Cat area
|
||||
"attention_score": 0.88,
|
||||
"scale": 1.0,
|
||||
},
|
||||
),
|
||||
MockResult(
|
||||
0.72,
|
||||
{
|
||||
"image_name": "cats_and_kitchen.jpg",
|
||||
"image_path": "/path/to/cats_and_kitchen.jpg",
|
||||
"patch_id": 12,
|
||||
"coordinates": [150, 100, 274, 224], # Appliances
|
||||
"attention_score": 0.75,
|
||||
"scale": 1.0,
|
||||
},
|
||||
),
|
||||
MockResult(
|
||||
0.65,
|
||||
{
|
||||
"image_name": "cats_and_kitchen.jpg",
|
||||
"image_path": "/path/to/cats_and_kitchen.jpg",
|
||||
"patch_id": 15,
|
||||
"coordinates": [50, 250, 174, 374], # Furniture
|
||||
"attention_score": 0.70,
|
||||
"scale": 1.0,
|
||||
},
|
||||
),
|
||||
# Image 2: city_street.jpg - 3 patches
|
||||
MockResult(
|
||||
0.68,
|
||||
{
|
||||
"image_name": "city_street.jpg",
|
||||
"image_path": "/path/to/city_street.jpg",
|
||||
"patch_id": 2,
|
||||
"coordinates": [300, 100, 424, 224], # Buildings
|
||||
"attention_score": 0.80,
|
||||
"scale": 1.0,
|
||||
},
|
||||
),
|
||||
MockResult(
|
||||
0.62,
|
||||
{
|
||||
"image_name": "city_street.jpg",
|
||||
"image_path": "/path/to/city_street.jpg",
|
||||
"patch_id": 8,
|
||||
"coordinates": [100, 350, 224, 474], # Street level
|
||||
"attention_score": 0.75,
|
||||
"scale": 1.0,
|
||||
},
|
||||
),
|
||||
MockResult(
|
||||
0.55,
|
||||
{
|
||||
"image_name": "city_street.jpg",
|
||||
"image_path": "/path/to/city_street.jpg",
|
||||
"patch_id": 11,
|
||||
"coordinates": [400, 200, 524, 324], # Sky area
|
||||
"attention_score": 0.60,
|
||||
"scale": 1.0,
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
# Test different aggregation methods
|
||||
methods = ["maxsim", "voting", "weighted", "mean"]
|
||||
|
||||
for method in methods:
|
||||
print(f"\n{'=' * 20} {method.upper()} AGGREGATION {'=' * 20}")
|
||||
|
||||
aggregator = MultiVectorAggregator(
|
||||
aggregation_method=method,
|
||||
spatial_clustering=True,
|
||||
cluster_distance_threshold=100.0,
|
||||
)
|
||||
|
||||
aggregated = aggregator.aggregate_results(mock_results, top_k=5)
|
||||
aggregator.print_aggregated_results(aggregated)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo_aggregation()
|
||||
@@ -1,113 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
OpenAI Embedding Example
|
||||
|
||||
Complete example showing how to build and search with OpenAI embeddings using HNSW backend.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import dotenv
|
||||
from leann.api import LeannBuilder, LeannSearcher
|
||||
|
||||
# Load environment variables
|
||||
dotenv.load_dotenv()
|
||||
|
||||
|
||||
def main():
|
||||
# Check if OpenAI API key is available
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
print("ERROR: OPENAI_API_KEY environment variable not set")
|
||||
return False
|
||||
|
||||
print(f"✅ OpenAI API key found: {api_key[:10]}...")
|
||||
|
||||
# Sample texts
|
||||
sample_texts = [
|
||||
"Machine learning is a powerful technology that enables computers to learn from data.",
|
||||
"Natural language processing helps computers understand and generate human language.",
|
||||
"Deep learning uses neural networks with multiple layers to solve complex problems.",
|
||||
"Computer vision allows machines to interpret and understand visual information.",
|
||||
"Reinforcement learning trains agents to make decisions through trial and error.",
|
||||
"Data science combines statistics, math, and programming to extract insights from data.",
|
||||
"Artificial intelligence aims to create machines that can perform human-like tasks.",
|
||||
"Python is a popular programming language used extensively in data science and AI.",
|
||||
"Neural networks are inspired by the structure and function of the human brain.",
|
||||
"Big data refers to extremely large datasets that require special tools to process.",
|
||||
]
|
||||
|
||||
INDEX_DIR = Path("./simple_openai_test_index")
|
||||
INDEX_PATH = str(INDEX_DIR / "simple_test.leann")
|
||||
|
||||
print("\n=== Building Index with OpenAI Embeddings ===")
|
||||
print(f"Index path: {INDEX_PATH}")
|
||||
|
||||
try:
|
||||
# Use proper configuration for OpenAI embeddings
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="text-embedding-3-small",
|
||||
embedding_mode="openai",
|
||||
# HNSW settings for OpenAI embeddings
|
||||
M=16, # Smaller graph degree
|
||||
efConstruction=64, # Smaller construction complexity
|
||||
is_compact=True, # Enable compact storage for recompute
|
||||
is_recompute=True, # MUST enable for OpenAI embeddings
|
||||
num_threads=1,
|
||||
)
|
||||
|
||||
print(f"Adding {len(sample_texts)} texts to the index...")
|
||||
for i, text in enumerate(sample_texts):
|
||||
metadata = {"id": f"doc_{i}", "topic": "AI"}
|
||||
builder.add_text(text, metadata)
|
||||
|
||||
print("Building index...")
|
||||
builder.build_index(INDEX_PATH)
|
||||
print("✅ Index built successfully!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error building index: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
print("\n=== Testing Search ===")
|
||||
|
||||
try:
|
||||
searcher = LeannSearcher(INDEX_PATH)
|
||||
|
||||
test_queries = [
|
||||
"What is machine learning?",
|
||||
"How do neural networks work?",
|
||||
"Programming languages for data science",
|
||||
]
|
||||
|
||||
for query in test_queries:
|
||||
print(f"\n🔍 Query: '{query}'")
|
||||
results = searcher.search(query, top_k=3)
|
||||
|
||||
print(f" Found {len(results)} results:")
|
||||
for i, result in enumerate(results):
|
||||
print(f" {i + 1}. Score: {result.score:.4f}")
|
||||
print(f" Text: {result.text[:80]}...")
|
||||
|
||||
print("\n✅ Search test completed successfully!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error during search: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = main()
|
||||
if success:
|
||||
print("\n🎉 Simple OpenAI index test completed successfully!")
|
||||
else:
|
||||
print("\n💥 Simple OpenAI index test failed!")
|
||||
@@ -1,23 +0,0 @@
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
from leann.api import LeannChat
|
||||
|
||||
INDEX_DIR = Path("./test_pdf_index_huawei")
|
||||
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
|
||||
|
||||
|
||||
async def main():
|
||||
print("\n[PHASE 2] Starting Leann chat session...")
|
||||
chat = LeannChat(index_path=INDEX_PATH)
|
||||
query = "What is the main idea of RL and give me 5 exapmle of classic RL algorithms?"
|
||||
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
|
||||
# query = "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
|
||||
response = chat.ask(
|
||||
query, top_k=20, recompute_beighbor_embeddings=True, complexity=32, beam_width=1
|
||||
)
|
||||
print(f"\n[PHASE 2] Response: {response}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,320 +0,0 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import dotenv
|
||||
from leann.api import LeannBuilder, LeannChat
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
# Default WeChat export directory
|
||||
DEFAULT_WECHAT_EXPORT_DIR = "./wechat_export_direct"
|
||||
|
||||
|
||||
def create_leann_index_from_multiple_wechat_exports(
|
||||
export_dirs: list[Path],
|
||||
index_path: str = "wechat_history_index.leann",
|
||||
max_count: int = -1,
|
||||
):
|
||||
"""
|
||||
Create LEANN index from multiple WeChat export data sources.
|
||||
|
||||
Args:
|
||||
export_dirs: List of Path objects pointing to WeChat export directories
|
||||
index_path: Path to save the LEANN index
|
||||
max_count: Maximum number of chat entries to process per export
|
||||
"""
|
||||
print("Creating LEANN index from multiple WeChat export data sources...")
|
||||
|
||||
# Load documents using WeChatHistoryReader from history_data
|
||||
from history_data.wechat_history import WeChatHistoryReader
|
||||
|
||||
reader = WeChatHistoryReader()
|
||||
|
||||
INDEX_DIR = Path(index_path).parent
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print("--- Index directory not found, building new index ---")
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
# Process each WeChat export directory
|
||||
for i, export_dir in enumerate(export_dirs):
|
||||
print(f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}")
|
||||
|
||||
try:
|
||||
documents = reader.load_data(
|
||||
wechat_export_dir=str(export_dir),
|
||||
max_count=max_count,
|
||||
concatenate_messages=True, # Disable concatenation - one message per document
|
||||
)
|
||||
if documents:
|
||||
print(f"Loaded {len(documents)} chat documents from {export_dir}")
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
|
||||
# Check if we've reached the max count
|
||||
if max_count > 0 and total_processed >= max_count:
|
||||
print(f"Reached max count of {max_count} documents")
|
||||
break
|
||||
else:
|
||||
print(f"No documents loaded from {export_dir}")
|
||||
except Exception as e:
|
||||
print(f"Error processing {export_dir}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No documents loaded from any source. Exiting.")
|
||||
return None
|
||||
|
||||
print(
|
||||
f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports and starting to split them into chunks"
|
||||
)
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=192, chunk_overlap=64)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
for doc in all_documents:
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
text = (
|
||||
"[Contact] means the message is from: "
|
||||
+ doc.metadata["contact_name"]
|
||||
+ "\n"
|
||||
+ node.get_content()
|
||||
)
|
||||
all_texts.append(text)
|
||||
|
||||
print(
|
||||
f"Finished splitting {len(all_documents)} documents into {len(all_texts)} text chunks"
|
||||
)
|
||||
|
||||
# Create LEANN index directory
|
||||
print("--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print("--- Building new LEANN index ---")
|
||||
|
||||
print("\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="Qwen/Qwen3-Embedding-0.6B",
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1, # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Adding {len(all_texts)} chat chunks to index...")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"\nLEANN index built at {index_path}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
return index_path
|
||||
|
||||
|
||||
def create_leann_index(
|
||||
export_dir: str | None = None,
|
||||
index_path: str = "wechat_history_index.leann",
|
||||
max_count: int = 1000,
|
||||
):
|
||||
"""
|
||||
Create LEANN index from WeChat chat history data.
|
||||
|
||||
Args:
|
||||
export_dir: Path to the WeChat export directory (optional, uses default if None)
|
||||
index_path: Path to save the LEANN index
|
||||
max_count: Maximum number of chat entries to process
|
||||
"""
|
||||
print("Creating LEANN index from WeChat chat history data...")
|
||||
INDEX_DIR = Path(index_path).parent
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print("--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print("--- Building new LEANN index ---")
|
||||
|
||||
print("\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Load documents using WeChatHistoryReader from history_data
|
||||
from history_data.wechat_history import WeChatHistoryReader
|
||||
|
||||
reader = WeChatHistoryReader()
|
||||
|
||||
documents = reader.load_data(
|
||||
wechat_export_dir=export_dir,
|
||||
max_count=max_count,
|
||||
concatenate_messages=False, # Disable concatenation - one message per document
|
||||
)
|
||||
|
||||
if not documents:
|
||||
print("No documents loaded. Exiting.")
|
||||
return None
|
||||
|
||||
print(f"Loaded {len(documents)} chat documents")
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
|
||||
|
||||
# Create LEANN index directory
|
||||
print("--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print("--- Building new LEANN index ---")
|
||||
|
||||
print("\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ", # MLX-optimized model
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1, # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Adding {len(all_texts)} chat chunks to index...")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"\nLEANN index built at {index_path}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
return index_path
|
||||
|
||||
|
||||
async def query_leann_index(index_path: str, query: str):
|
||||
"""
|
||||
Query the LEANN index.
|
||||
|
||||
Args:
|
||||
index_path: Path to the LEANN index
|
||||
query: The query string
|
||||
"""
|
||||
print("\n[PHASE 2] Starting Leann chat session...")
|
||||
chat = LeannChat(index_path=index_path)
|
||||
|
||||
print(f"You: {query}")
|
||||
chat_response = chat.ask(
|
||||
query,
|
||||
top_k=20,
|
||||
recompute_beighbor_embeddings=True,
|
||||
complexity=16,
|
||||
beam_width=1,
|
||||
llm_config={
|
||||
"type": "openai",
|
||||
"model": "gpt-4o",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
},
|
||||
llm_kwargs={"temperature": 0.0, "max_tokens": 1000},
|
||||
)
|
||||
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main function with integrated WeChat export functionality."""
|
||||
|
||||
# Parse command line arguments
|
||||
parser = argparse.ArgumentParser(
|
||||
description="LEANN WeChat History Reader - Create and query WeChat chat history index"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--export-dir",
|
||||
type=str,
|
||||
default=DEFAULT_WECHAT_EXPORT_DIR,
|
||||
help=f"Directory to store WeChat exports (default: {DEFAULT_WECHAT_EXPORT_DIR})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index-dir",
|
||||
type=str,
|
||||
default="./wechat_history_magic_test_11Debug_new",
|
||||
help="Directory to store the LEANN index (default: ./wechat_history_index_leann_test)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-entries",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Maximum number of chat entries to process (default: 5000)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--query",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Single query to run (default: runs example queries)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force-export",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Force re-export of WeChat data even if exports exist",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
INDEX_DIR = Path(args.index_dir)
|
||||
INDEX_PATH = str(INDEX_DIR / "wechat_history.leann")
|
||||
|
||||
print(f"Using WeChat export directory: {args.export_dir}")
|
||||
print(f"Index directory: {INDEX_DIR}")
|
||||
print(f"Max entries: {args.max_entries}")
|
||||
|
||||
# Initialize WeChat reader with export capabilities
|
||||
from history_data.wechat_history import WeChatHistoryReader
|
||||
|
||||
reader = WeChatHistoryReader()
|
||||
|
||||
# Find existing exports or create new ones using the centralized method
|
||||
export_dirs = reader.find_or_export_wechat_data(args.export_dir)
|
||||
if not export_dirs:
|
||||
print("Failed to find or export WeChat data. Exiting.")
|
||||
return
|
||||
|
||||
# Create or load the LEANN index from all sources
|
||||
index_path = create_leann_index_from_multiple_wechat_exports(
|
||||
export_dirs, INDEX_PATH, max_count=args.max_entries
|
||||
)
|
||||
|
||||
if index_path:
|
||||
if args.query:
|
||||
# Run single query
|
||||
await query_leann_index(index_path, args.query)
|
||||
else:
|
||||
# Example queries
|
||||
queries = [
|
||||
"我想买魔术师约翰逊的球衣,给我一些对应聊天记录?",
|
||||
]
|
||||
|
||||
for query in queries:
|
||||
print("\n" + "=" * 60)
|
||||
await query_leann_index(index_path, query)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1 +1,7 @@
|
||||
from . import diskann_backend as diskann_backend
|
||||
from . import graph_partition
|
||||
|
||||
# Export main classes and functions
|
||||
from .graph_partition import GraphPartitioner, partition_graph
|
||||
|
||||
__all__ = ["GraphPartitioner", "diskann_backend", "graph_partition", "partition_graph"]
|
||||
|
||||
@@ -4,9 +4,10 @@ import os
|
||||
import struct
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
import numpy as np
|
||||
import psutil
|
||||
from leann.interface import (
|
||||
LeannBackendBuilderInterface,
|
||||
LeannBackendFactoryInterface,
|
||||
@@ -84,6 +85,43 @@ def _write_vectors_to_bin(data: np.ndarray, file_path: Path):
|
||||
f.write(data.tobytes())
|
||||
|
||||
|
||||
def _calculate_smart_memory_config(data: np.ndarray) -> tuple[float, float]:
|
||||
"""
|
||||
Calculate smart memory configuration for DiskANN based on data size and system specs.
|
||||
|
||||
Args:
|
||||
data: The embedding data array
|
||||
|
||||
Returns:
|
||||
tuple: (search_memory_maximum, build_memory_maximum) in GB
|
||||
"""
|
||||
num_vectors, dim = data.shape
|
||||
|
||||
# Calculate embedding storage size
|
||||
embedding_size_bytes = num_vectors * dim * 4 # float32 = 4 bytes
|
||||
embedding_size_gb = embedding_size_bytes / (1024**3)
|
||||
|
||||
# search_memory_maximum: 1/10 of embedding size for optimal PQ compression
|
||||
# This controls Product Quantization size - smaller means more compression
|
||||
search_memory_gb = max(0.1, embedding_size_gb / 10) # At least 100MB
|
||||
|
||||
# build_memory_maximum: Based on available system RAM for sharding control
|
||||
# This controls how much memory DiskANN uses during index construction
|
||||
available_memory_gb = psutil.virtual_memory().available / (1024**3)
|
||||
total_memory_gb = psutil.virtual_memory().total / (1024**3)
|
||||
|
||||
# Use 50% of available memory, but at least 2GB and at most 75% of total
|
||||
build_memory_gb = max(2.0, min(available_memory_gb * 0.5, total_memory_gb * 0.75))
|
||||
|
||||
logger.info(
|
||||
f"Smart memory config - Data: {embedding_size_gb:.2f}GB, "
|
||||
f"Search mem: {search_memory_gb:.2f}GB (PQ control), "
|
||||
f"Build mem: {build_memory_gb:.2f}GB (sharding control)"
|
||||
)
|
||||
|
||||
return search_memory_gb, build_memory_gb
|
||||
|
||||
|
||||
@register_backend("diskann")
|
||||
class DiskannBackend(LeannBackendFactoryInterface):
|
||||
@staticmethod
|
||||
@@ -99,6 +137,71 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
||||
def __init__(self, **kwargs):
|
||||
self.build_params = kwargs
|
||||
|
||||
def _safe_cleanup_after_partition(self, index_dir: Path, index_prefix: str):
|
||||
"""
|
||||
Safely cleanup files after partition.
|
||||
In partition mode, C++ doesn't read _disk.index content,
|
||||
so we can delete it if all derived files exist.
|
||||
"""
|
||||
disk_index_file = index_dir / f"{index_prefix}_disk.index"
|
||||
beam_search_file = index_dir / f"{index_prefix}_disk_beam_search.index"
|
||||
|
||||
# Required files that C++ partition mode needs
|
||||
# Note: C++ generates these with _disk.index suffix
|
||||
disk_suffix = "_disk.index"
|
||||
required_files = [
|
||||
f"{index_prefix}{disk_suffix}_medoids.bin", # Critical: assert fails if missing
|
||||
# Note: _centroids.bin is not created in single-shot build - C++ handles this automatically
|
||||
f"{index_prefix}_pq_pivots.bin", # PQ table
|
||||
f"{index_prefix}_pq_compressed.bin", # PQ compressed vectors
|
||||
]
|
||||
|
||||
# Check if all required files exist
|
||||
missing_files = []
|
||||
for filename in required_files:
|
||||
file_path = index_dir / filename
|
||||
if not file_path.exists():
|
||||
missing_files.append(filename)
|
||||
|
||||
if missing_files:
|
||||
logger.warning(
|
||||
f"Cannot safely delete _disk.index - missing required files: {missing_files}"
|
||||
)
|
||||
logger.info("Keeping all original files for safety")
|
||||
return
|
||||
|
||||
# Calculate space savings
|
||||
space_saved = 0
|
||||
files_to_delete = []
|
||||
|
||||
if disk_index_file.exists():
|
||||
space_saved += disk_index_file.stat().st_size
|
||||
files_to_delete.append(disk_index_file)
|
||||
|
||||
if beam_search_file.exists():
|
||||
space_saved += beam_search_file.stat().st_size
|
||||
files_to_delete.append(beam_search_file)
|
||||
|
||||
# Safe to delete!
|
||||
for file_to_delete in files_to_delete:
|
||||
try:
|
||||
os.remove(file_to_delete)
|
||||
logger.info(f"✅ Safely deleted: {file_to_delete.name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete {file_to_delete.name}: {e}")
|
||||
|
||||
if space_saved > 0:
|
||||
space_saved_mb = space_saved / (1024 * 1024)
|
||||
logger.info(f"💾 Space saved: {space_saved_mb:.1f} MB")
|
||||
|
||||
# Show what files are kept
|
||||
logger.info("📁 Kept essential files for partition mode:")
|
||||
for filename in required_files:
|
||||
file_path = index_dir / filename
|
||||
if file_path.exists():
|
||||
size_mb = file_path.stat().st_size / (1024 * 1024)
|
||||
logger.info(f" - {filename} ({size_mb:.1f} MB)")
|
||||
|
||||
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
|
||||
path = Path(index_path)
|
||||
index_dir = path.parent
|
||||
@@ -113,6 +216,17 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
||||
_write_vectors_to_bin(data, index_dir / data_filename)
|
||||
|
||||
build_kwargs = {**self.build_params, **kwargs}
|
||||
|
||||
# Extract is_recompute from nested backend_kwargs if needed
|
||||
is_recompute = build_kwargs.get("is_recompute", False)
|
||||
if not is_recompute and "backend_kwargs" in build_kwargs:
|
||||
is_recompute = build_kwargs["backend_kwargs"].get("is_recompute", False)
|
||||
|
||||
# Flatten all backend_kwargs parameters to top level for compatibility
|
||||
if "backend_kwargs" in build_kwargs:
|
||||
nested_params = build_kwargs.pop("backend_kwargs")
|
||||
build_kwargs.update(nested_params)
|
||||
|
||||
metric_enum = _get_diskann_metrics().get(
|
||||
build_kwargs.get("distance_metric", "mips").lower()
|
||||
)
|
||||
@@ -121,6 +235,16 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
||||
f"Unsupported distance_metric '{build_kwargs.get('distance_metric', 'unknown')}'."
|
||||
)
|
||||
|
||||
# Calculate smart memory configuration if not explicitly provided
|
||||
if (
|
||||
"search_memory_maximum" not in build_kwargs
|
||||
or "build_memory_maximum" not in build_kwargs
|
||||
):
|
||||
smart_search_mem, smart_build_mem = _calculate_smart_memory_config(data)
|
||||
else:
|
||||
smart_search_mem = build_kwargs.get("search_memory_maximum", 4.0)
|
||||
smart_build_mem = build_kwargs.get("build_memory_maximum", 8.0)
|
||||
|
||||
try:
|
||||
from . import _diskannpy as diskannpy # type: ignore
|
||||
|
||||
@@ -131,12 +255,36 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
||||
index_prefix,
|
||||
build_kwargs.get("complexity", 64),
|
||||
build_kwargs.get("graph_degree", 32),
|
||||
build_kwargs.get("search_memory_maximum", 4.0),
|
||||
build_kwargs.get("build_memory_maximum", 8.0),
|
||||
build_kwargs.get("search_memory_maximum", smart_search_mem),
|
||||
build_kwargs.get("build_memory_maximum", smart_build_mem),
|
||||
build_kwargs.get("num_threads", 8),
|
||||
build_kwargs.get("pq_disk_bytes", 0),
|
||||
"",
|
||||
)
|
||||
|
||||
# Auto-partition if is_recompute is enabled
|
||||
if build_kwargs.get("is_recompute", False):
|
||||
logger.info("is_recompute=True, starting automatic graph partitioning...")
|
||||
from .graph_partition import partition_graph
|
||||
|
||||
# Partition the index using absolute paths
|
||||
# Convert to absolute paths to avoid issues with working directory changes
|
||||
absolute_index_dir = Path(index_dir).resolve()
|
||||
absolute_index_prefix_path = str(absolute_index_dir / index_prefix)
|
||||
disk_graph_path, partition_bin_path = partition_graph(
|
||||
index_prefix_path=absolute_index_prefix_path,
|
||||
output_dir=str(absolute_index_dir),
|
||||
partition_prefix=index_prefix,
|
||||
)
|
||||
|
||||
# Safe cleanup: In partition mode, C++ doesn't read _disk.index content
|
||||
# but still needs the derived files (_medoids.bin, _centroids.bin, etc.)
|
||||
self._safe_cleanup_after_partition(index_dir, index_prefix)
|
||||
|
||||
logger.info("✅ Graph partitioning completed successfully!")
|
||||
logger.info(f" - Disk graph: {disk_graph_path}")
|
||||
logger.info(f" - Partition file: {partition_bin_path}")
|
||||
|
||||
finally:
|
||||
temp_data_file = index_dir / data_filename
|
||||
if temp_data_file.exists():
|
||||
@@ -165,7 +313,26 @@ class DiskannSearcher(BaseSearcher):
|
||||
|
||||
# For DiskANN, we need to reinitialize the index when zmq_port changes
|
||||
# Store the initialization parameters for later use
|
||||
full_index_prefix = str(self.index_dir / self.index_path.stem)
|
||||
# Note: C++ load method expects the BASE path (without _disk.index suffix)
|
||||
# C++ internally constructs: index_prefix + "_disk.index"
|
||||
index_name = self.index_path.stem # "simple_test.leann" -> "simple_test"
|
||||
diskann_index_prefix = str(self.index_dir / index_name) # /path/to/simple_test
|
||||
full_index_prefix = diskann_index_prefix # /path/to/simple_test (base path)
|
||||
|
||||
# Auto-detect partition files and set partition_prefix
|
||||
partition_graph_file = self.index_dir / f"{index_name}_disk_graph.index"
|
||||
partition_bin_file = self.index_dir / f"{index_name}_partition.bin"
|
||||
|
||||
partition_prefix = ""
|
||||
if partition_graph_file.exists() and partition_bin_file.exists():
|
||||
# C++ expects full path prefix, not just filename
|
||||
partition_prefix = str(self.index_dir / index_name) # /path/to/simple_test
|
||||
logger.info(
|
||||
f"✅ Detected partition files, using partition_prefix='{partition_prefix}'"
|
||||
)
|
||||
else:
|
||||
logger.debug("No partition files detected, using standard index files")
|
||||
|
||||
self._init_params = {
|
||||
"metric_enum": metric_enum,
|
||||
"full_index_prefix": full_index_prefix,
|
||||
@@ -173,8 +340,14 @@ class DiskannSearcher(BaseSearcher):
|
||||
"num_nodes_to_cache": kwargs.get("num_nodes_to_cache", 0),
|
||||
"cache_mechanism": 1,
|
||||
"pq_prefix": "",
|
||||
"partition_prefix": "",
|
||||
"partition_prefix": partition_prefix,
|
||||
}
|
||||
|
||||
# Log partition configuration for debugging
|
||||
if partition_prefix:
|
||||
logger.info(
|
||||
f"✅ Detected partition files, using partition_prefix='{partition_prefix}'"
|
||||
)
|
||||
self._diskannpy = diskannpy
|
||||
self._current_zmq_port = None
|
||||
self._index = None
|
||||
@@ -211,7 +384,7 @@ class DiskannSearcher(BaseSearcher):
|
||||
prune_ratio: float = 0.0,
|
||||
recompute_embeddings: bool = False,
|
||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||
zmq_port: int | None = None,
|
||||
zmq_port: Optional[int] = None,
|
||||
batch_recompute: bool = False,
|
||||
dedup_node_dis: bool = False,
|
||||
**kwargs,
|
||||
@@ -264,6 +437,8 @@ class DiskannSearcher(BaseSearcher):
|
||||
use_global_pruning = True
|
||||
|
||||
# Perform search with suppressed C++ output based on log level
|
||||
use_deferred_fetch = kwargs.get("USE_DEFERRED_FETCH", True)
|
||||
recompute_neighors = False
|
||||
with suppress_cpp_output_if_needed():
|
||||
labels, distances = self._index.batch_search(
|
||||
query,
|
||||
@@ -272,9 +447,9 @@ class DiskannSearcher(BaseSearcher):
|
||||
complexity,
|
||||
beam_width,
|
||||
self.num_threads,
|
||||
kwargs.get("USE_DEFERRED_FETCH", False),
|
||||
use_deferred_fetch,
|
||||
kwargs.get("skip_search_reorder", False),
|
||||
recompute_embeddings,
|
||||
recompute_neighors,
|
||||
dedup_node_dis,
|
||||
prune_ratio,
|
||||
batch_recompute,
|
||||
@@ -284,3 +459,25 @@ class DiskannSearcher(BaseSearcher):
|
||||
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
|
||||
|
||||
return {"labels": string_labels, "distances": distances}
|
||||
|
||||
def cleanup(self):
|
||||
"""Cleanup DiskANN-specific resources including C++ index."""
|
||||
# Call parent cleanup first
|
||||
super().cleanup()
|
||||
|
||||
# Delete the C++ index to trigger destructors
|
||||
try:
|
||||
if hasattr(self, "_index") and self._index is not None:
|
||||
del self._index
|
||||
self._index = None
|
||||
self._current_zmq_port = None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Force garbage collection to ensure C++ objects are destroyed
|
||||
try:
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -10,6 +10,7 @@ import sys
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import zmq
|
||||
@@ -32,10 +33,11 @@ if not logger.handlers:
|
||||
|
||||
|
||||
def create_diskann_embedding_server(
|
||||
passages_file: str | None = None,
|
||||
passages_file: Optional[str] = None,
|
||||
zmq_port: int = 5555,
|
||||
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
||||
embedding_mode: str = "sentence-transformers",
|
||||
distance_metric: str = "l2",
|
||||
):
|
||||
"""
|
||||
Create and start a ZMQ-based embedding server for DiskANN backend.
|
||||
@@ -79,7 +81,8 @@ def create_diskann_embedding_server(
|
||||
with open(passages_file) as f:
|
||||
meta = json.load(f)
|
||||
|
||||
passages = PassageManager(meta["passage_sources"])
|
||||
logger.info(f"Loading PassageManager with metadata_file_path: {passages_file}")
|
||||
passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file)
|
||||
logger.info(
|
||||
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
||||
)
|
||||
@@ -97,6 +100,7 @@ def create_diskann_embedding_server(
|
||||
socket = context.socket(
|
||||
zmq.REP
|
||||
) # REP socket for both BaseSearcher and DiskANN C++ REQ clients
|
||||
socket.setsockopt(zmq.LINGER, 0) # Don't block on close
|
||||
socket.bind(f"tcp://*:{zmq_port}")
|
||||
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
|
||||
|
||||
@@ -260,9 +264,16 @@ if __name__ == "__main__":
|
||||
"--embedding-mode",
|
||||
type=str,
|
||||
default="sentence-transformers",
|
||||
choices=["sentence-transformers", "openai", "mlx"],
|
||||
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||
help="Embedding backend mode",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--distance-metric",
|
||||
type=str,
|
||||
default="l2",
|
||||
choices=["l2", "mips", "cosine"],
|
||||
help="Distance metric for similarity computation",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -272,4 +283,5 @@ if __name__ == "__main__":
|
||||
zmq_port=args.zmq_port,
|
||||
model_name=args.model_name,
|
||||
embedding_mode=args.embedding_mode,
|
||||
distance_metric=args.distance_metric,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,299 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Graph Partition Module for LEANN DiskANN Backend
|
||||
|
||||
This module provides Python bindings for the graph partition functionality
|
||||
of DiskANN, allowing users to partition disk-based indices for better
|
||||
performance.
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class GraphPartitioner:
|
||||
"""
|
||||
A Python interface for DiskANN's graph partition functionality.
|
||||
|
||||
This class provides methods to partition disk-based indices for improved
|
||||
search performance and memory efficiency.
|
||||
"""
|
||||
|
||||
def __init__(self, build_type: str = "release"):
|
||||
"""
|
||||
Initialize the GraphPartitioner.
|
||||
|
||||
Args:
|
||||
build_type: Build type for the executables ("debug" or "release")
|
||||
"""
|
||||
self.build_type = build_type
|
||||
self._ensure_executables()
|
||||
|
||||
def _get_executable_path(self, name: str) -> str:
|
||||
"""Get the path to a graph partition executable."""
|
||||
# Get the directory where this Python module is located
|
||||
module_dir = Path(__file__).parent
|
||||
# Navigate to the graph_partition directory
|
||||
graph_partition_dir = module_dir.parent / "third_party" / "DiskANN" / "graph_partition"
|
||||
executable_path = graph_partition_dir / "build" / self.build_type / "graph_partition" / name
|
||||
|
||||
if not executable_path.exists():
|
||||
raise FileNotFoundError(f"Executable {name} not found at {executable_path}")
|
||||
|
||||
return str(executable_path)
|
||||
|
||||
def _ensure_executables(self):
|
||||
"""Ensure that the required executables are built."""
|
||||
try:
|
||||
self._get_executable_path("partitioner")
|
||||
self._get_executable_path("index_relayout")
|
||||
except FileNotFoundError:
|
||||
# Try to build the executables automatically
|
||||
print("Executables not found, attempting to build them...")
|
||||
self._build_executables()
|
||||
|
||||
def _build_executables(self):
|
||||
"""Build the required executables."""
|
||||
graph_partition_dir = (
|
||||
Path(__file__).parent.parent / "third_party" / "DiskANN" / "graph_partition"
|
||||
)
|
||||
original_dir = os.getcwd()
|
||||
|
||||
try:
|
||||
os.chdir(graph_partition_dir)
|
||||
|
||||
# Clean any existing build
|
||||
if (graph_partition_dir / "build").exists():
|
||||
shutil.rmtree(graph_partition_dir / "build")
|
||||
|
||||
# Run the build script
|
||||
cmd = ["./build.sh", self.build_type, "split_graph", "/tmp/dummy"]
|
||||
subprocess.run(cmd, capture_output=True, text=True, cwd=graph_partition_dir)
|
||||
|
||||
# Check if executables were created
|
||||
partitioner_path = self._get_executable_path("partitioner")
|
||||
relayout_path = self._get_executable_path("index_relayout")
|
||||
|
||||
print(f"✅ Built partitioner: {partitioner_path}")
|
||||
print(f"✅ Built index_relayout: {relayout_path}")
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to build executables: {e}")
|
||||
finally:
|
||||
os.chdir(original_dir)
|
||||
|
||||
def partition_graph(
|
||||
self,
|
||||
index_prefix_path: str,
|
||||
output_dir: Optional[str] = None,
|
||||
partition_prefix: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Partition a disk-based index for improved performance.
|
||||
|
||||
Args:
|
||||
index_prefix_path: Path to the index prefix (e.g., "/path/to/index")
|
||||
output_dir: Output directory for results (defaults to parent of index_prefix_path)
|
||||
partition_prefix: Prefix for output files (defaults to basename of index_prefix_path)
|
||||
**kwargs: Additional parameters for graph partitioning:
|
||||
- gp_times: Number of LDG partition iterations (default: 10)
|
||||
- lock_nums: Number of lock nodes (default: 10)
|
||||
- cut: Cut adjacency list degree (default: 100)
|
||||
- scale_factor: Scale factor (default: 1)
|
||||
- data_type: Data type (default: "float")
|
||||
- thread_nums: Number of threads (default: 10)
|
||||
|
||||
Returns:
|
||||
Tuple of (disk_graph_index_path, partition_bin_path)
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the partitioning process fails
|
||||
"""
|
||||
# Set default parameters
|
||||
params = {
|
||||
"gp_times": 10,
|
||||
"lock_nums": 10,
|
||||
"cut": 100,
|
||||
"scale_factor": 1,
|
||||
"data_type": "float",
|
||||
"thread_nums": 10,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
# Determine output directory
|
||||
if output_dir is None:
|
||||
output_dir = str(Path(index_prefix_path).parent)
|
||||
|
||||
# Create output directory if it doesn't exist
|
||||
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Determine partition prefix
|
||||
if partition_prefix is None:
|
||||
partition_prefix = Path(index_prefix_path).name
|
||||
|
||||
# Get executable paths
|
||||
partitioner_path = self._get_executable_path("partitioner")
|
||||
relayout_path = self._get_executable_path("index_relayout")
|
||||
|
||||
# Create temporary directory for processing
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Change to the graph_partition directory for temporary files
|
||||
graph_partition_dir = (
|
||||
Path(__file__).parent.parent / "third_party" / "DiskANN" / "graph_partition"
|
||||
)
|
||||
original_dir = os.getcwd()
|
||||
|
||||
try:
|
||||
os.chdir(graph_partition_dir)
|
||||
|
||||
# Create temporary data directory
|
||||
temp_data_dir = Path(temp_dir) / "data"
|
||||
temp_data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Set up paths for temporary files
|
||||
graph_path = temp_data_dir / "starling" / "_M_R_L_B" / "GRAPH"
|
||||
graph_gp_path = (
|
||||
graph_path
|
||||
/ f"GP_TIMES_{params['gp_times']}_LOCK_{params['lock_nums']}_GP_USE_FREQ0_CUT{params['cut']}_SCALE{params['scale_factor']}"
|
||||
)
|
||||
graph_gp_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Find input index file
|
||||
old_index_file = f"{index_prefix_path}_disk_beam_search.index"
|
||||
if not os.path.exists(old_index_file):
|
||||
old_index_file = f"{index_prefix_path}_disk.index"
|
||||
|
||||
if not os.path.exists(old_index_file):
|
||||
raise RuntimeError(f"Index file not found: {old_index_file}")
|
||||
|
||||
# Run partitioner
|
||||
gp_file_path = graph_gp_path / "_part.bin"
|
||||
partitioner_cmd = [
|
||||
partitioner_path,
|
||||
"--index_file",
|
||||
old_index_file,
|
||||
"--data_type",
|
||||
params["data_type"],
|
||||
"--gp_file",
|
||||
str(gp_file_path),
|
||||
"-T",
|
||||
str(params["thread_nums"]),
|
||||
"--ldg_times",
|
||||
str(params["gp_times"]),
|
||||
"--scale",
|
||||
str(params["scale_factor"]),
|
||||
"--mode",
|
||||
"1",
|
||||
]
|
||||
|
||||
print(f"Running partitioner: {' '.join(partitioner_cmd)}")
|
||||
result = subprocess.run(
|
||||
partitioner_cmd, capture_output=True, text=True, cwd=graph_partition_dir
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(
|
||||
f"Partitioner failed with return code {result.returncode}.\n"
|
||||
f"stdout: {result.stdout}\n"
|
||||
f"stderr: {result.stderr}"
|
||||
)
|
||||
|
||||
# Run relayout
|
||||
part_tmp_index = graph_gp_path / "_part_tmp.index"
|
||||
relayout_cmd = [
|
||||
relayout_path,
|
||||
old_index_file,
|
||||
str(gp_file_path),
|
||||
params["data_type"],
|
||||
"1",
|
||||
]
|
||||
|
||||
print(f"Running relayout: {' '.join(relayout_cmd)}")
|
||||
result = subprocess.run(
|
||||
relayout_cmd, capture_output=True, text=True, cwd=graph_partition_dir
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(
|
||||
f"Relayout failed with return code {result.returncode}.\n"
|
||||
f"stdout: {result.stdout}\n"
|
||||
f"stderr: {result.stderr}"
|
||||
)
|
||||
|
||||
# Copy results to output directory
|
||||
disk_graph_path = Path(output_dir) / f"{partition_prefix}_disk_graph.index"
|
||||
partition_bin_path = Path(output_dir) / f"{partition_prefix}_partition.bin"
|
||||
|
||||
shutil.copy2(part_tmp_index, disk_graph_path)
|
||||
shutil.copy2(gp_file_path, partition_bin_path)
|
||||
|
||||
print(f"Results copied to: {output_dir}")
|
||||
return str(disk_graph_path), str(partition_bin_path)
|
||||
|
||||
finally:
|
||||
os.chdir(original_dir)
|
||||
|
||||
def get_partition_info(self, partition_bin_path: str) -> dict:
|
||||
"""
|
||||
Get information about a partition file.
|
||||
|
||||
Args:
|
||||
partition_bin_path: Path to the partition binary file
|
||||
|
||||
Returns:
|
||||
Dictionary containing partition information
|
||||
"""
|
||||
if not os.path.exists(partition_bin_path):
|
||||
raise FileNotFoundError(f"Partition file not found: {partition_bin_path}")
|
||||
|
||||
# For now, return basic file information
|
||||
# In the future, this could parse the binary file for detailed info
|
||||
stat = os.stat(partition_bin_path)
|
||||
return {
|
||||
"file_size": stat.st_size,
|
||||
"file_path": partition_bin_path,
|
||||
"modified_time": stat.st_mtime,
|
||||
}
|
||||
|
||||
|
||||
def partition_graph(
|
||||
index_prefix_path: str,
|
||||
output_dir: Optional[str] = None,
|
||||
partition_prefix: Optional[str] = None,
|
||||
build_type: str = "release",
|
||||
**kwargs,
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Convenience function to partition a graph index.
|
||||
|
||||
Args:
|
||||
index_prefix_path: Path to the index prefix
|
||||
output_dir: Output directory (defaults to parent of index_prefix_path)
|
||||
partition_prefix: Prefix for output files (defaults to basename of index_prefix_path)
|
||||
build_type: Build type for executables ("debug" or "release")
|
||||
**kwargs: Additional parameters for graph partitioning
|
||||
|
||||
Returns:
|
||||
Tuple of (disk_graph_index_path, partition_bin_path)
|
||||
"""
|
||||
partitioner = GraphPartitioner(build_type=build_type)
|
||||
return partitioner.partition_graph(index_prefix_path, output_dir, partition_prefix, **kwargs)
|
||||
|
||||
|
||||
# Example usage:
|
||||
if __name__ == "__main__":
|
||||
# Example: partition an index
|
||||
try:
|
||||
disk_graph_path, partition_bin_path = partition_graph(
|
||||
"/path/to/your/index_prefix", gp_times=10, lock_nums=10, cut=100
|
||||
)
|
||||
print("Partitioning completed successfully!")
|
||||
print(f"Disk graph index: {disk_graph_path}")
|
||||
print(f"Partition binary: {partition_bin_path}")
|
||||
except Exception as e:
|
||||
print(f"Partitioning failed: {e}")
|
||||
@@ -0,0 +1,137 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Simplified Graph Partition Module for LEANN DiskANN Backend
|
||||
|
||||
This module provides a simple Python interface for graph partitioning
|
||||
that directly calls the existing executables.
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def partition_graph_simple(
|
||||
index_prefix_path: str, output_dir: Optional[str] = None, **kwargs
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Simple function to partition a graph index.
|
||||
|
||||
Args:
|
||||
index_prefix_path: Path to the index prefix (e.g., "/path/to/index")
|
||||
output_dir: Output directory (defaults to parent of index_prefix_path)
|
||||
**kwargs: Additional parameters for graph partitioning
|
||||
|
||||
Returns:
|
||||
Tuple of (disk_graph_index_path, partition_bin_path)
|
||||
"""
|
||||
# Set default parameters
|
||||
params = {
|
||||
"gp_times": 10,
|
||||
"lock_nums": 10,
|
||||
"cut": 100,
|
||||
"scale_factor": 1,
|
||||
"data_type": "float",
|
||||
"thread_nums": 10,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
# Determine output directory
|
||||
if output_dir is None:
|
||||
output_dir = str(Path(index_prefix_path).parent)
|
||||
|
||||
# Find the graph_partition directory
|
||||
current_file = Path(__file__)
|
||||
graph_partition_dir = current_file.parent.parent / "third_party" / "DiskANN" / "graph_partition"
|
||||
|
||||
if not graph_partition_dir.exists():
|
||||
raise RuntimeError(f"Graph partition directory not found: {graph_partition_dir}")
|
||||
|
||||
# Find input index file
|
||||
old_index_file = f"{index_prefix_path}_disk_beam_search.index"
|
||||
if not os.path.exists(old_index_file):
|
||||
old_index_file = f"{index_prefix_path}_disk.index"
|
||||
|
||||
if not os.path.exists(old_index_file):
|
||||
raise RuntimeError(f"Index file not found: {old_index_file}")
|
||||
|
||||
# Create temporary directory for processing
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_data_dir = Path(temp_dir) / "data"
|
||||
temp_data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Set up paths for temporary files
|
||||
graph_path = temp_data_dir / "starling" / "_M_R_L_B" / "GRAPH"
|
||||
graph_gp_path = (
|
||||
graph_path
|
||||
/ f"GP_TIMES_{params['gp_times']}_LOCK_{params['lock_nums']}_GP_USE_FREQ0_CUT{params['cut']}_SCALE{params['scale_factor']}"
|
||||
)
|
||||
graph_gp_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Run the build script with our parameters
|
||||
cmd = [str(graph_partition_dir / "build.sh"), "release", "split_graph", index_prefix_path]
|
||||
|
||||
# Set environment variables for parameters
|
||||
env = os.environ.copy()
|
||||
env.update(
|
||||
{
|
||||
"GP_TIMES": str(params["gp_times"]),
|
||||
"GP_LOCK_NUMS": str(params["lock_nums"]),
|
||||
"GP_CUT": str(params["cut"]),
|
||||
"GP_SCALE_F": str(params["scale_factor"]),
|
||||
"DATA_TYPE": params["data_type"],
|
||||
"GP_T": str(params["thread_nums"]),
|
||||
}
|
||||
)
|
||||
|
||||
print(f"Running graph partition with command: {' '.join(cmd)}")
|
||||
print(f"Working directory: {graph_partition_dir}")
|
||||
|
||||
# Run the command
|
||||
result = subprocess.run(
|
||||
cmd, env=env, capture_output=True, text=True, cwd=graph_partition_dir
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
print(f"Command failed with return code {result.returncode}")
|
||||
print(f"stdout: {result.stdout}")
|
||||
print(f"stderr: {result.stderr}")
|
||||
raise RuntimeError(
|
||||
f"Graph partitioning failed with return code {result.returncode}.\n"
|
||||
f"stdout: {result.stdout}\n"
|
||||
f"stderr: {result.stderr}"
|
||||
)
|
||||
|
||||
# Check if output files were created
|
||||
disk_graph_path = Path(output_dir) / "_disk_graph.index"
|
||||
partition_bin_path = Path(output_dir) / "_partition.bin"
|
||||
|
||||
if not disk_graph_path.exists():
|
||||
raise RuntimeError(f"Expected output file not found: {disk_graph_path}")
|
||||
|
||||
if not partition_bin_path.exists():
|
||||
raise RuntimeError(f"Expected output file not found: {partition_bin_path}")
|
||||
|
||||
print("✅ Partitioning completed successfully!")
|
||||
print(f" Disk graph index: {disk_graph_path}")
|
||||
print(f" Partition binary: {partition_bin_path}")
|
||||
|
||||
return str(disk_graph_path), str(partition_bin_path)
|
||||
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
disk_graph_path, partition_bin_path = partition_graph_simple(
|
||||
"/Users/yichuan/Desktop/release2/leann/diskannbuild/test_doc_files",
|
||||
gp_times=5,
|
||||
lock_nums=5,
|
||||
cut=50,
|
||||
)
|
||||
print("Success! Output files:")
|
||||
print(f" - {disk_graph_path}")
|
||||
print(f" - {partition_bin_path}")
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
@@ -4,8 +4,8 @@ build-backend = "scikit_build_core.build"
|
||||
|
||||
[project]
|
||||
name = "leann-backend-diskann"
|
||||
version = "0.1.15"
|
||||
dependencies = ["leann-core==0.1.15", "numpy", "protobuf>=3.19.0"]
|
||||
version = "0.2.7"
|
||||
dependencies = ["leann-core==0.2.7", "numpy", "protobuf>=3.19.0"]
|
||||
|
||||
[tool.scikit-build]
|
||||
# Key: simplified CMake path
|
||||
|
||||
Submodule packages/leann-backend-diskann/third_party/DiskANN updated: af2a26481e...b2dc4ea2c7
@@ -10,6 +10,14 @@ if(APPLE)
|
||||
set(OpenMP_C_LIB_NAMES "omp")
|
||||
set(OpenMP_CXX_LIB_NAMES "omp")
|
||||
set(OpenMP_omp_LIBRARY "/opt/homebrew/opt/libomp/lib/libomp.dylib")
|
||||
|
||||
# Force use of system libc++ to avoid version mismatch
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -stdlib=libc++")
|
||||
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -stdlib=libc++")
|
||||
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -stdlib=libc++")
|
||||
|
||||
# Set minimum macOS version for better compatibility
|
||||
set(CMAKE_OSX_DEPLOYMENT_TARGET "11.0" CACHE STRING "Minimum macOS version")
|
||||
endif()
|
||||
|
||||
# Use system ZeroMQ instead of building from source
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import argparse
|
||||
import gc # Import garbage collector interface
|
||||
import logging
|
||||
import os
|
||||
import struct
|
||||
import sys
|
||||
@@ -7,6 +8,12 @@ import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Set up logging to avoid print buffer issues
|
||||
logger = logging.getLogger(__name__)
|
||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
|
||||
logger.setLevel(log_level)
|
||||
|
||||
# --- FourCCs (add more if needed) ---
|
||||
INDEX_HNSW_FLAT_FOURCC = int.from_bytes(b"IHNf", "little")
|
||||
# Add other HNSW fourccs if you expect different storage types inside HNSW
|
||||
@@ -243,6 +250,12 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
||||
output_filename: Output CSR index file
|
||||
prune_embeddings: Whether to prune embedding storage (write NULL storage marker)
|
||||
"""
|
||||
# Disable buffering for print statements to avoid deadlock in CI/pytest
|
||||
import functools
|
||||
|
||||
global print
|
||||
print = functools.partial(print, flush=True)
|
||||
|
||||
print(f"Starting conversion: {input_filename} -> {output_filename}")
|
||||
start_time = time.time()
|
||||
original_hnsw_data = {}
|
||||
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
import numpy as np
|
||||
from leann.interface import (
|
||||
@@ -124,7 +124,9 @@ class HNSWSearcher(BaseSearcher):
|
||||
)
|
||||
from . import faiss # type: ignore
|
||||
|
||||
self.distance_metric = self.meta.get("distance_metric", "mips").lower()
|
||||
self.distance_metric = (
|
||||
self.meta.get("backend_kwargs", {}).get("distance_metric", "mips").lower()
|
||||
)
|
||||
metric_enum = get_metric_map().get(self.distance_metric)
|
||||
if metric_enum is None:
|
||||
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
|
||||
@@ -150,7 +152,7 @@ class HNSWSearcher(BaseSearcher):
|
||||
self,
|
||||
query: np.ndarray,
|
||||
top_k: int,
|
||||
zmq_port: int | None = None,
|
||||
zmq_port: Optional[int] = None,
|
||||
complexity: int = 64,
|
||||
beam_width: int = 1,
|
||||
prune_ratio: float = 0.0,
|
||||
@@ -200,6 +202,16 @@ class HNSWSearcher(BaseSearcher):
|
||||
params.efSearch = complexity
|
||||
params.beam_size = beam_width
|
||||
|
||||
# For OpenAI embeddings with cosine distance, disable relative distance check
|
||||
# This prevents early termination when all scores are in a narrow range
|
||||
embedding_model = self.meta.get("embedding_model", "").lower()
|
||||
if self.distance_metric == "cosine" and any(
|
||||
openai_model in embedding_model for openai_model in ["text-embedding", "openai"]
|
||||
):
|
||||
params.check_relative_distance = False
|
||||
else:
|
||||
params.check_relative_distance = True
|
||||
|
||||
# PQ pruning: direct mapping to HNSW's pq_pruning_ratio
|
||||
params.pq_pruning_ratio = prune_ratio
|
||||
|
||||
@@ -233,3 +245,25 @@ class HNSWSearcher(BaseSearcher):
|
||||
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
|
||||
|
||||
return {"labels": string_labels, "distances": distances}
|
||||
|
||||
def cleanup(self):
|
||||
"""Cleanup HNSW-specific resources including C++ ZMQ connections."""
|
||||
# Call parent cleanup first
|
||||
super().cleanup()
|
||||
|
||||
# Additional cleanup for C++ side ZMQ connections
|
||||
# The ZmqDistanceComputer in C++ uses ZMQ connections that need cleanup
|
||||
try:
|
||||
# Delete the index to trigger C++ destructors
|
||||
if hasattr(self, "index"):
|
||||
del self.index
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Force garbage collection to ensure C++ objects are destroyed
|
||||
try:
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -10,6 +10,7 @@ import sys
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import msgpack
|
||||
import numpy as np
|
||||
@@ -33,7 +34,7 @@ if not logger.handlers:
|
||||
|
||||
|
||||
def create_hnsw_embedding_server(
|
||||
passages_file: str | None = None,
|
||||
passages_file: Optional[str] = None,
|
||||
zmq_port: int = 5555,
|
||||
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
||||
distance_metric: str = "mips",
|
||||
@@ -81,19 +82,8 @@ def create_hnsw_embedding_server(
|
||||
with open(passages_file) as f:
|
||||
meta = json.load(f)
|
||||
|
||||
# Convert relative paths to absolute paths based on metadata file location
|
||||
metadata_dir = Path(passages_file).parent.parent # Go up one level from the metadata file
|
||||
passage_sources = []
|
||||
for source in meta["passage_sources"]:
|
||||
source_copy = source.copy()
|
||||
# Convert relative paths to absolute paths
|
||||
if not Path(source_copy["path"]).is_absolute():
|
||||
source_copy["path"] = str(metadata_dir / source_copy["path"])
|
||||
if not Path(source_copy["index_path"]).is_absolute():
|
||||
source_copy["index_path"] = str(metadata_dir / source_copy["index_path"])
|
||||
passage_sources.append(source_copy)
|
||||
|
||||
passages = PassageManager(passage_sources)
|
||||
# Let PassageManager handle path resolution uniformly
|
||||
passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file)
|
||||
logger.info(
|
||||
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
||||
)
|
||||
@@ -102,6 +92,7 @@ def create_hnsw_embedding_server(
|
||||
"""ZMQ server thread"""
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REP)
|
||||
socket.setsockopt(zmq.LINGER, 0) # Don't block on close
|
||||
socket.bind(f"tcp://*:{zmq_port}")
|
||||
logger.info(f"HNSW ZMQ server listening on port {zmq_port}")
|
||||
|
||||
@@ -295,7 +286,7 @@ if __name__ == "__main__":
|
||||
"--embedding-mode",
|
||||
type=str,
|
||||
default="sentence-transformers",
|
||||
choices=["sentence-transformers", "openai", "mlx"],
|
||||
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||
help="Embedding backend mode",
|
||||
)
|
||||
|
||||
|
||||
@@ -6,10 +6,10 @@ build-backend = "scikit_build_core.build"
|
||||
|
||||
[project]
|
||||
name = "leann-backend-hnsw"
|
||||
version = "0.1.15"
|
||||
version = "0.2.7"
|
||||
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
|
||||
dependencies = [
|
||||
"leann-core==0.1.15",
|
||||
"leann-core==0.2.7",
|
||||
"numpy",
|
||||
"pyzmq>=23.0.0",
|
||||
"msgpack>=1.0.0",
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "leann-core"
|
||||
version = "0.1.15"
|
||||
version = "0.2.7"
|
||||
description = "Core API and plugin system for LEANN"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.9"
|
||||
@@ -31,6 +31,8 @@ dependencies = [
|
||||
"PyPDF2>=3.0.0",
|
||||
"pymupdf>=1.23.0",
|
||||
"pdfplumber>=0.10.0",
|
||||
"nbconvert>=7.0.0", # For .ipynb file support
|
||||
"gitignore-parser>=0.1.12", # For proper .gitignore handling
|
||||
"mlx>=0.26.3; sys_platform == 'darwin'",
|
||||
"mlx-lm>=0.26.0; sys_platform == 'darwin'",
|
||||
]
|
||||
@@ -44,6 +46,7 @@ colab = [
|
||||
|
||||
[project.scripts]
|
||||
leann = "leann.cli:main"
|
||||
leann_mcp = "leann.mcp:main"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
|
||||
@@ -8,6 +8,10 @@ if platform.system() == "Darwin":
|
||||
os.environ["MKL_NUM_THREADS"] = "1"
|
||||
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
||||
os.environ["KMP_BLOCKTIME"] = "0"
|
||||
# Additional fixes for PyTorch/sentence-transformers on macOS ARM64 only in CI
|
||||
if os.environ.get("CI") == "true":
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "0"
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
from .api import LeannBuilder, LeannChat, LeannSearcher
|
||||
from .registry import BACKEND_REGISTRY, autodiscover_backends
|
||||
|
||||
@@ -10,7 +10,7 @@ import time
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -23,12 +23,17 @@ from .registry import BACKEND_REGISTRY
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_registered_backends() -> list[str]:
|
||||
"""Get list of registered backend names."""
|
||||
return list(BACKEND_REGISTRY.keys())
|
||||
|
||||
|
||||
def compute_embeddings(
|
||||
chunks: list[str],
|
||||
model_name: str,
|
||||
mode: str = "sentence-transformers",
|
||||
use_server: bool = True,
|
||||
port: int | None = None,
|
||||
port: Optional[int] = None,
|
||||
is_build=False,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
@@ -82,21 +87,26 @@ def compute_embeddings_via_server(chunks: list[str], model_name: str, port: int)
|
||||
# Connect to embedding server
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.setsockopt(zmq.LINGER, 0) # Don't block on close
|
||||
socket.setsockopt(zmq.RCVTIMEO, 1000) # 1s timeout on receive
|
||||
socket.setsockopt(zmq.SNDTIMEO, 1000) # 1s timeout on send
|
||||
socket.setsockopt(zmq.IMMEDIATE, 1) # Don't wait for connection
|
||||
socket.connect(f"tcp://localhost:{port}")
|
||||
|
||||
# Send chunks to server for embedding computation
|
||||
request = chunks
|
||||
socket.send(msgpack.packb(request))
|
||||
try:
|
||||
# Send chunks to server for embedding computation
|
||||
request = chunks
|
||||
socket.send(msgpack.packb(request))
|
||||
|
||||
# Receive embeddings from server
|
||||
response = socket.recv()
|
||||
embeddings_list = msgpack.unpackb(response)
|
||||
# Receive embeddings from server
|
||||
response = socket.recv()
|
||||
embeddings_list = msgpack.unpackb(response)
|
||||
|
||||
# Convert back to numpy array
|
||||
embeddings = np.array(embeddings_list, dtype=np.float32)
|
||||
|
||||
socket.close()
|
||||
context.term()
|
||||
# Convert back to numpy array
|
||||
embeddings = np.array(embeddings_list, dtype=np.float32)
|
||||
finally:
|
||||
socket.close(linger=0)
|
||||
context.term()
|
||||
|
||||
return embeddings
|
||||
|
||||
@@ -110,7 +120,9 @@ class SearchResult:
|
||||
|
||||
|
||||
class PassageManager:
|
||||
def __init__(self, passage_sources: list[dict[str, Any]]):
|
||||
def __init__(
|
||||
self, passage_sources: list[dict[str, Any]], metadata_file_path: Optional[str] = None
|
||||
):
|
||||
self.offset_maps = {}
|
||||
self.passage_files = {}
|
||||
self.global_offset_map = {} # Combined map for fast lookup
|
||||
@@ -120,10 +132,26 @@ class PassageManager:
|
||||
passage_file = source["path"]
|
||||
index_file = source["index_path"] # .idx file
|
||||
|
||||
# Fix path resolution for Colab and other environments
|
||||
# Fix path resolution - relative paths should be relative to metadata file directory
|
||||
if not Path(index_file).is_absolute():
|
||||
# If relative path, try to resolve it properly
|
||||
index_file = str(Path(index_file).resolve())
|
||||
if metadata_file_path:
|
||||
# Resolve relative to metadata file directory
|
||||
metadata_dir = Path(metadata_file_path).parent
|
||||
logger.debug(
|
||||
f"PassageManager: Resolving relative paths from metadata_dir: {metadata_dir}"
|
||||
)
|
||||
index_file = str((metadata_dir / index_file).resolve())
|
||||
passage_file = str((metadata_dir / passage_file).resolve())
|
||||
logger.debug(f"PassageManager: Resolved index_file: {index_file}")
|
||||
else:
|
||||
# Fallback to current directory resolution (legacy behavior)
|
||||
logger.warning(
|
||||
"PassageManager: No metadata_file_path provided, using fallback resolution from cwd"
|
||||
)
|
||||
logger.debug(f"PassageManager: Current working directory: {Path.cwd()}")
|
||||
index_file = str(Path(index_file).resolve())
|
||||
passage_file = str(Path(passage_file).resolve())
|
||||
logger.debug(f"PassageManager: Fallback resolved index_file: {index_file}")
|
||||
|
||||
if not Path(index_file).exists():
|
||||
raise FileNotFoundError(f"Passage index file not found: {index_file}")
|
||||
@@ -152,12 +180,12 @@ class LeannBuilder:
|
||||
self,
|
||||
backend_name: str,
|
||||
embedding_model: str = "facebook/contriever",
|
||||
dimensions: int | None = None,
|
||||
dimensions: Optional[int] = None,
|
||||
embedding_mode: str = "sentence-transformers",
|
||||
**backend_kwargs,
|
||||
):
|
||||
self.backend_name = backend_name
|
||||
backend_factory: LeannBackendFactoryInterface | None = BACKEND_REGISTRY.get(backend_name)
|
||||
backend_factory: Optional[LeannBackendFactoryInterface] = BACKEND_REGISTRY.get(backend_name)
|
||||
if backend_factory is None:
|
||||
raise ValueError(f"Backend '{backend_name}' not found or not registered.")
|
||||
self.backend_factory = backend_factory
|
||||
@@ -237,7 +265,7 @@ class LeannBuilder:
|
||||
self.backend_kwargs = backend_kwargs
|
||||
self.chunks: list[dict[str, Any]] = []
|
||||
|
||||
def add_text(self, text: str, metadata: dict[str, Any] | None = None):
|
||||
def add_text(self, text: str, metadata: Optional[dict[str, Any]] = None):
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
passage_id = metadata.get("id", str(len(self.chunks)))
|
||||
@@ -309,8 +337,8 @@ class LeannBuilder:
|
||||
"passage_sources": [
|
||||
{
|
||||
"type": "jsonl",
|
||||
"path": str(passages_file),
|
||||
"index_path": str(offset_file),
|
||||
"path": passages_file.name, # Use relative path (just filename)
|
||||
"index_path": offset_file.name, # Use relative path (just filename)
|
||||
}
|
||||
],
|
||||
}
|
||||
@@ -425,8 +453,8 @@ class LeannBuilder:
|
||||
"passage_sources": [
|
||||
{
|
||||
"type": "jsonl",
|
||||
"path": str(passages_file),
|
||||
"index_path": str(offset_file),
|
||||
"path": passages_file.name, # Use relative path (just filename)
|
||||
"index_path": offset_file.name, # Use relative path (just filename)
|
||||
}
|
||||
],
|
||||
"built_from_precomputed_embeddings": True,
|
||||
@@ -454,14 +482,23 @@ class LeannSearcher:
|
||||
|
||||
self.meta_path_str = f"{index_path}.meta.json"
|
||||
if not Path(self.meta_path_str).exists():
|
||||
raise FileNotFoundError(f"Leann metadata file not found at {self.meta_path_str}")
|
||||
parent_dir = Path(index_path).parent
|
||||
print(
|
||||
f"Leann metadata file not found at {self.meta_path_str}, and you may need to rm -rf {parent_dir}"
|
||||
)
|
||||
# highlight in red the filenotfound error
|
||||
raise FileNotFoundError(
|
||||
f"Leann metadata file not found at {self.meta_path_str}, \033[91m you may need to rm -rf {parent_dir}\033[0m"
|
||||
)
|
||||
with open(self.meta_path_str, encoding="utf-8") as f:
|
||||
self.meta_data = json.load(f)
|
||||
backend_name = self.meta_data["backend_name"]
|
||||
self.embedding_model = self.meta_data["embedding_model"]
|
||||
# Support both old and new format
|
||||
self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
|
||||
self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
|
||||
self.passage_manager = PassageManager(
|
||||
self.meta_data.get("passage_sources", []), metadata_file_path=self.meta_path_str
|
||||
)
|
||||
backend_factory = BACKEND_REGISTRY.get(backend_name)
|
||||
if backend_factory is None:
|
||||
raise ValueError(f"Backend '{backend_name}' not found.")
|
||||
@@ -488,6 +525,16 @@ class LeannSearcher:
|
||||
logger.info(f" Top_k: {top_k}")
|
||||
logger.info(f" Additional kwargs: {kwargs}")
|
||||
|
||||
# Smart top_k detection and adjustment
|
||||
total_docs = len(self.passage_manager.global_offset_map)
|
||||
original_top_k = top_k
|
||||
if top_k > total_docs:
|
||||
top_k = total_docs
|
||||
logger.warning(
|
||||
f" ⚠️ Requested top_k ({original_top_k}) exceeds total documents ({total_docs})"
|
||||
)
|
||||
logger.warning(f" ✅ Auto-adjusted top_k to {top_k} to match available documents")
|
||||
|
||||
zmq_port = None
|
||||
|
||||
start_time = time.time()
|
||||
@@ -524,15 +571,15 @@ class LeannSearcher:
|
||||
zmq_port=zmq_port,
|
||||
**kwargs,
|
||||
)
|
||||
time.time() - start_time
|
||||
# logger.info(f" Search time: {search_time} seconds")
|
||||
logger.info(f" Backend returned: labels={len(results.get('labels', [[]])[0])} results")
|
||||
|
||||
enriched_results = []
|
||||
if "labels" in results and "distances" in results:
|
||||
logger.info(f" Processing {len(results['labels'][0])} passage IDs:")
|
||||
# Python 3.9 does not support zip(strict=...); lengths are expected to match
|
||||
for i, (string_id, dist) in enumerate(
|
||||
zip(results["labels"][0], results["distances"][0], strict=False)
|
||||
zip(results["labels"][0], results["distances"][0])
|
||||
):
|
||||
try:
|
||||
passage_data = self.passage_manager.get_passage(string_id)
|
||||
@@ -558,19 +605,45 @@ class LeannSearcher:
|
||||
)
|
||||
except KeyError:
|
||||
RED = "\033[91m"
|
||||
RESET = "\033[0m"
|
||||
logger.error(
|
||||
f" {RED}✗{RESET} [{i + 1:2d}] ID: '{string_id}' -> {RED}ERROR: Passage not found!{RESET}"
|
||||
)
|
||||
|
||||
# Define color codes outside the loop for final message
|
||||
GREEN = "\033[92m"
|
||||
RESET = "\033[0m"
|
||||
logger.info(f" {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}")
|
||||
return enriched_results
|
||||
|
||||
def cleanup(self):
|
||||
"""Explicitly cleanup embedding server and ZMQ resources.
|
||||
|
||||
This method should be called after you're done using the searcher,
|
||||
especially in test environments or batch processing scenarios.
|
||||
"""
|
||||
# Stop embedding server
|
||||
if hasattr(self.backend_impl, "embedding_server_manager"):
|
||||
self.backend_impl.embedding_server_manager.stop_server()
|
||||
|
||||
# Set ZMQ linger but don't terminate global context
|
||||
try:
|
||||
import zmq
|
||||
|
||||
# Just set linger on the global instance
|
||||
ctx = zmq.Context.instance()
|
||||
ctx.linger = 0
|
||||
# NEVER call ctx.term() or destroy() on the global instance
|
||||
# That would block waiting for all sockets to close
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class LeannChat:
|
||||
def __init__(
|
||||
self,
|
||||
index_path: str,
|
||||
llm_config: dict[str, Any] | None = None,
|
||||
llm_config: Optional[dict[str, Any]] = None,
|
||||
enable_warmup: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -586,7 +659,7 @@ class LeannChat:
|
||||
prune_ratio: float = 0.0,
|
||||
recompute_embeddings: bool = True,
|
||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||
llm_kwargs: dict[str, Any] | None = None,
|
||||
llm_kwargs: Optional[dict[str, Any]] = None,
|
||||
expected_zmq_port: int = 5557,
|
||||
**search_kwargs,
|
||||
):
|
||||
@@ -614,7 +687,10 @@ class LeannChat:
|
||||
"Please provide the best answer you can based on this context and your knowledge."
|
||||
)
|
||||
|
||||
ask_time = time.time()
|
||||
ans = self.llm.ask(prompt, **llm_kwargs)
|
||||
ask_time = time.time() - ask_time
|
||||
logger.info(f" Ask time: {ask_time} seconds")
|
||||
return ans
|
||||
|
||||
def start_interactive(self):
|
||||
@@ -631,3 +707,12 @@ class LeannChat:
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print("\nGoodbye!")
|
||||
break
|
||||
|
||||
def cleanup(self):
|
||||
"""Explicitly cleanup embedding server resources.
|
||||
|
||||
This method should be called after you're done using the chat interface,
|
||||
especially in test environments or batch processing scenarios.
|
||||
"""
|
||||
if hasattr(self.searcher, "cleanup"):
|
||||
self.searcher.cleanup()
|
||||
|
||||
@@ -8,7 +8,7 @@ import difflib
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -17,12 +17,12 @@ logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def check_ollama_models() -> list[str]:
|
||||
def check_ollama_models(host: str) -> list[str]:
|
||||
"""Check available Ollama models and return a list"""
|
||||
try:
|
||||
import requests
|
||||
|
||||
response = requests.get("http://localhost:11434/api/tags", timeout=5)
|
||||
response = requests.get(f"{host}/api/tags", timeout=5)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
return [model["name"] for model in data.get("models", [])]
|
||||
@@ -309,10 +309,12 @@ def search_hf_models(query: str, limit: int = 10) -> list[str]:
|
||||
return search_hf_models_fuzzy(query, limit)
|
||||
|
||||
|
||||
def validate_model_and_suggest(model_name: str, llm_type: str) -> str | None:
|
||||
def validate_model_and_suggest(
|
||||
model_name: str, llm_type: str, host: str = "http://localhost:11434"
|
||||
) -> Optional[str]:
|
||||
"""Validate model name and provide suggestions if invalid"""
|
||||
if llm_type == "ollama":
|
||||
available_models = check_ollama_models()
|
||||
available_models = check_ollama_models(host)
|
||||
if available_models and model_name not in available_models:
|
||||
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
|
||||
|
||||
@@ -358,7 +360,11 @@ def validate_model_and_suggest(model_name: str, llm_type: str) -> str | None:
|
||||
error_msg += f"\n\nModel '{model_name}' was not found in Ollama's library."
|
||||
|
||||
if suggestions:
|
||||
error_msg += "\n\nDid you mean one of these installed models?\n"
|
||||
error_msg += (
|
||||
"\n\nDid you mean one of these installed models?\n"
|
||||
+ "\nTry to use ollama pull to install the model you need\n"
|
||||
)
|
||||
|
||||
for i, suggestion in enumerate(suggestions, 1):
|
||||
error_msg += f" {i}. {suggestion}\n"
|
||||
else:
|
||||
@@ -465,7 +471,7 @@ class OllamaChat(LLMInterface):
|
||||
requests.get(host)
|
||||
|
||||
# Pre-check model availability with helpful suggestions
|
||||
model_error = validate_model_and_suggest(model, "ollama")
|
||||
model_error = validate_model_and_suggest(model, "ollama", host)
|
||||
if model_error:
|
||||
raise ValueError(model_error)
|
||||
|
||||
@@ -485,11 +491,35 @@ class OllamaChat(LLMInterface):
|
||||
import requests
|
||||
|
||||
full_url = f"{self.host}/api/generate"
|
||||
|
||||
# Handle thinking budget for reasoning models
|
||||
options = kwargs.copy()
|
||||
thinking_budget = kwargs.get("thinking_budget")
|
||||
if thinking_budget:
|
||||
# Remove thinking_budget from options as it's not a standard Ollama option
|
||||
options.pop("thinking_budget", None)
|
||||
# Only apply reasoning parameters to models that support it
|
||||
reasoning_supported_models = [
|
||||
"gpt-oss:20b",
|
||||
"gpt-oss:120b",
|
||||
"deepseek-r1",
|
||||
"deepseek-coder",
|
||||
]
|
||||
|
||||
if thinking_budget in ["low", "medium", "high"]:
|
||||
if any(model in self.model.lower() for model in reasoning_supported_models):
|
||||
options["reasoning"] = {"effort": thinking_budget, "exclude": False}
|
||||
logger.info(f"Applied reasoning effort={thinking_budget} to model {self.model}")
|
||||
else:
|
||||
logger.warning(
|
||||
f"Thinking budget '{thinking_budget}' requested but model '{self.model}' may not support reasoning parameters. Proceeding without reasoning."
|
||||
)
|
||||
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"prompt": prompt,
|
||||
"stream": False, # Keep it simple for now
|
||||
"options": kwargs,
|
||||
"options": options,
|
||||
}
|
||||
logger.debug(f"Sending request to Ollama: {payload}")
|
||||
try:
|
||||
@@ -542,14 +572,41 @@ class HFChat(LLMInterface):
|
||||
self.device = "cpu"
|
||||
logger.info("No GPU detected. Using CPU.")
|
||||
|
||||
# Load tokenizer and model
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
|
||||
device_map="auto" if self.device != "cpu" else None,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
# Load tokenizer and model with timeout protection
|
||||
try:
|
||||
import signal
|
||||
|
||||
def timeout_handler(signum, frame):
|
||||
raise TimeoutError("Model download/loading timed out")
|
||||
|
||||
# Set timeout for model loading (60 seconds)
|
||||
old_handler = signal.signal(signal.SIGALRM, timeout_handler)
|
||||
signal.alarm(60)
|
||||
|
||||
try:
|
||||
logger.info(f"Loading tokenizer for {model_name}...")
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
logger.info(f"Loading model {model_name}...")
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
|
||||
device_map="auto" if self.device != "cpu" else None,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
logger.info(f"Successfully loaded {model_name}")
|
||||
finally:
|
||||
signal.alarm(0) # Cancel the alarm
|
||||
signal.signal(signal.SIGALRM, old_handler) # Restore old handler
|
||||
|
||||
except TimeoutError:
|
||||
logger.error(f"Model loading timed out for {model_name}")
|
||||
raise RuntimeError(
|
||||
f"Model loading timed out for {model_name}. Please check your internet connection or try a smaller model."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load model {model_name}: {e}")
|
||||
raise
|
||||
|
||||
# Move model to device if not using device_map
|
||||
if self.device != "cpu" and "device_map" not in str(self.model):
|
||||
@@ -628,7 +685,7 @@ class HFChat(LLMInterface):
|
||||
class OpenAIChat(LLMInterface):
|
||||
"""LLM interface for OpenAI models."""
|
||||
|
||||
def __init__(self, model: str = "gpt-4o", api_key: str | None = None):
|
||||
def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None):
|
||||
self.model = model
|
||||
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
||||
|
||||
@@ -653,11 +710,38 @@ class OpenAIChat(LLMInterface):
|
||||
params = {
|
||||
"model": self.model,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"max_tokens": kwargs.get("max_tokens", 1000),
|
||||
"temperature": kwargs.get("temperature", 0.7),
|
||||
**{k: v for k, v in kwargs.items() if k not in ["max_tokens", "temperature"]},
|
||||
}
|
||||
|
||||
# Handle max_tokens vs max_completion_tokens based on model
|
||||
max_tokens = kwargs.get("max_tokens", 1000)
|
||||
if "o3" in self.model or "o4" in self.model or "o1" in self.model:
|
||||
# o-series models use max_completion_tokens
|
||||
params["max_completion_tokens"] = max_tokens
|
||||
params["temperature"] = 1.0
|
||||
else:
|
||||
# Other models use max_tokens
|
||||
params["max_tokens"] = max_tokens
|
||||
|
||||
# Handle thinking budget for reasoning models
|
||||
thinking_budget = kwargs.get("thinking_budget")
|
||||
if thinking_budget and thinking_budget in ["low", "medium", "high"]:
|
||||
# Check if this is an o-series model (partial match for model names)
|
||||
o_series_models = ["o3", "o3-mini", "o4-mini", "o1", "o3-pro", "o3-deep-research"]
|
||||
if any(model in self.model for model in o_series_models):
|
||||
# Use the correct OpenAI reasoning parameter format
|
||||
params["reasoning_effort"] = thinking_budget
|
||||
logger.info(f"Applied reasoning_effort={thinking_budget} to model {self.model}")
|
||||
else:
|
||||
logger.warning(
|
||||
f"Thinking budget '{thinking_budget}' requested but model '{self.model}' may not support reasoning parameters. Proceeding without reasoning."
|
||||
)
|
||||
|
||||
# Add other kwargs (excluding thinking_budget as it's handled above)
|
||||
for k, v in kwargs.items():
|
||||
if k not in ["max_tokens", "temperature", "thinking_budget"]:
|
||||
params[k] = v
|
||||
|
||||
logger.info(f"Sending request to OpenAI with model {self.model}")
|
||||
|
||||
try:
|
||||
@@ -677,7 +761,7 @@ class SimulatedChat(LLMInterface):
|
||||
return "This is a simulated answer from the LLM based on the retrieved context."
|
||||
|
||||
|
||||
def get_llm(llm_config: dict[str, Any] | None = None) -> LLMInterface:
|
||||
def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
|
||||
"""
|
||||
Factory function to get an LLM interface based on configuration.
|
||||
|
||||
|
||||
@@ -41,13 +41,23 @@ def extract_pdf_text_with_pdfplumber(file_path: str) -> str:
|
||||
|
||||
class LeannCLI:
|
||||
def __init__(self):
|
||||
self.indexes_dir = Path.home() / ".leann" / "indexes"
|
||||
# Always use project-local .leann directory (like .git)
|
||||
self.indexes_dir = Path.cwd() / ".leann" / "indexes"
|
||||
self.indexes_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Default parser for documents
|
||||
self.node_parser = SentenceSplitter(
|
||||
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
|
||||
)
|
||||
|
||||
# Code-optimized parser
|
||||
self.code_parser = SentenceSplitter(
|
||||
chunk_size=512, # Larger chunks for code context
|
||||
chunk_overlap=50, # Less overlap to preserve function boundaries
|
||||
separator="\n", # Split by lines for code
|
||||
paragraph_separator="\n\n", # Preserve logical code blocks
|
||||
)
|
||||
|
||||
def get_index_path(self, index_name: str) -> str:
|
||||
index_dir = self.indexes_dir / index_name
|
||||
return str(index_dir / "documents.leann")
|
||||
@@ -64,10 +74,11 @@ class LeannCLI:
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
leann build my-docs --docs ./documents # Build index named my-docs
|
||||
leann search my-docs "query" # Search in my-docs index
|
||||
leann ask my-docs "question" # Ask my-docs index
|
||||
leann list # List all stored indexes
|
||||
leann build my-docs --docs ./documents # Build index named my-docs
|
||||
leann build my-ppts --docs ./ --file-types .pptx,.pdf # Index only PowerPoint and PDF files
|
||||
leann search my-docs "query" # Search in my-docs index
|
||||
leann ask my-docs "question" # Ask my-docs index
|
||||
leann list # List all stored indexes
|
||||
""",
|
||||
)
|
||||
|
||||
@@ -75,18 +86,34 @@ Examples:
|
||||
|
||||
# Build command
|
||||
build_parser = subparsers.add_parser("build", help="Build document index")
|
||||
build_parser.add_argument("index_name", help="Index name")
|
||||
build_parser.add_argument("--docs", type=str, required=True, help="Documents directory")
|
||||
build_parser.add_argument(
|
||||
"index_name", nargs="?", help="Index name (default: current directory name)"
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--docs", type=str, default=".", help="Documents directory (default: current directory)"
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--backend", type=str, default="hnsw", choices=["hnsw", "diskann"]
|
||||
)
|
||||
build_parser.add_argument("--embedding-model", type=str, default="facebook/contriever")
|
||||
build_parser.add_argument(
|
||||
"--embedding-mode",
|
||||
type=str,
|
||||
default="sentence-transformers",
|
||||
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||
help="Embedding backend mode (default: sentence-transformers)",
|
||||
)
|
||||
build_parser.add_argument("--force", "-f", action="store_true", help="Force rebuild")
|
||||
build_parser.add_argument("--graph-degree", type=int, default=32)
|
||||
build_parser.add_argument("--complexity", type=int, default=64)
|
||||
build_parser.add_argument("--num-threads", type=int, default=1)
|
||||
build_parser.add_argument("--compact", action="store_true", default=True)
|
||||
build_parser.add_argument("--recompute", action="store_true", default=True)
|
||||
build_parser.add_argument(
|
||||
"--file-types",
|
||||
type=str,
|
||||
help="Comma-separated list of file extensions to include (e.g., '.txt,.pdf,.pptx'). If not specified, uses default supported types.",
|
||||
)
|
||||
|
||||
# Search command
|
||||
search_parser = subparsers.add_parser("search", help="Search documents")
|
||||
@@ -96,7 +123,12 @@ Examples:
|
||||
search_parser.add_argument("--complexity", type=int, default=64)
|
||||
search_parser.add_argument("--beam-width", type=int, default=1)
|
||||
search_parser.add_argument("--prune-ratio", type=float, default=0.0)
|
||||
search_parser.add_argument("--recompute-embeddings", action="store_true")
|
||||
search_parser.add_argument(
|
||||
"--recompute-embeddings",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Recompute embeddings (default: True)",
|
||||
)
|
||||
search_parser.add_argument(
|
||||
"--pruning-strategy",
|
||||
choices=["global", "local", "proportional"],
|
||||
@@ -119,94 +151,370 @@ Examples:
|
||||
ask_parser.add_argument("--complexity", type=int, default=32)
|
||||
ask_parser.add_argument("--beam-width", type=int, default=1)
|
||||
ask_parser.add_argument("--prune-ratio", type=float, default=0.0)
|
||||
ask_parser.add_argument("--recompute-embeddings", action="store_true")
|
||||
ask_parser.add_argument(
|
||||
"--recompute-embeddings",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Recompute embeddings (default: True)",
|
||||
)
|
||||
ask_parser.add_argument(
|
||||
"--pruning-strategy",
|
||||
choices=["global", "local", "proportional"],
|
||||
default="global",
|
||||
)
|
||||
ask_parser.add_argument(
|
||||
"--thinking-budget",
|
||||
type=str,
|
||||
choices=["low", "medium", "high"],
|
||||
default=None,
|
||||
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
|
||||
)
|
||||
|
||||
# List command
|
||||
subparsers.add_parser("list", help="List all indexes")
|
||||
|
||||
return parser
|
||||
|
||||
def register_project_dir(self):
|
||||
"""Register current project directory in global registry"""
|
||||
global_registry = Path.home() / ".leann" / "projects.json"
|
||||
global_registry.parent.mkdir(exist_ok=True)
|
||||
|
||||
current_dir = str(Path.cwd())
|
||||
|
||||
# Load existing registry
|
||||
projects = []
|
||||
if global_registry.exists():
|
||||
try:
|
||||
import json
|
||||
|
||||
with open(global_registry) as f:
|
||||
projects = json.load(f)
|
||||
except Exception:
|
||||
projects = []
|
||||
|
||||
# Add current directory if not already present
|
||||
if current_dir not in projects:
|
||||
projects.append(current_dir)
|
||||
|
||||
# Save registry
|
||||
import json
|
||||
|
||||
with open(global_registry, "w") as f:
|
||||
json.dump(projects, f, indent=2)
|
||||
|
||||
def _build_gitignore_parser(self, docs_dir: str):
|
||||
"""Build gitignore parser using gitignore-parser library."""
|
||||
from gitignore_parser import parse_gitignore
|
||||
|
||||
# Try to parse the root .gitignore
|
||||
gitignore_path = Path(docs_dir) / ".gitignore"
|
||||
|
||||
if gitignore_path.exists():
|
||||
try:
|
||||
# gitignore-parser automatically handles all subdirectory .gitignore files!
|
||||
matches = parse_gitignore(str(gitignore_path))
|
||||
print(f"📋 Loaded .gitignore from {docs_dir} (includes all subdirectories)")
|
||||
return matches
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not parse .gitignore: {e}")
|
||||
else:
|
||||
print("📋 No .gitignore found")
|
||||
|
||||
# Fallback: basic pattern matching for essential files
|
||||
essential_patterns = {".git", ".DS_Store", "__pycache__", "node_modules", ".venv", "venv"}
|
||||
|
||||
def basic_matches(file_path):
|
||||
path_parts = Path(file_path).parts
|
||||
return any(part in essential_patterns for part in path_parts)
|
||||
|
||||
return basic_matches
|
||||
|
||||
def _should_exclude_file(self, relative_path: Path, gitignore_matches) -> bool:
|
||||
"""Check if a file should be excluded using gitignore parser."""
|
||||
return gitignore_matches(str(relative_path))
|
||||
|
||||
def list_indexes(self):
|
||||
print("Stored LEANN indexes:")
|
||||
|
||||
if not self.indexes_dir.exists():
|
||||
# Get all project directories with .leann
|
||||
global_registry = Path.home() / ".leann" / "projects.json"
|
||||
all_projects = []
|
||||
|
||||
if global_registry.exists():
|
||||
try:
|
||||
import json
|
||||
|
||||
with open(global_registry) as f:
|
||||
all_projects = json.load(f)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Filter to only existing directories with .leann
|
||||
valid_projects = []
|
||||
for project_dir in all_projects:
|
||||
project_path = Path(project_dir)
|
||||
if project_path.exists() and (project_path / ".leann" / "indexes").exists():
|
||||
valid_projects.append(project_path)
|
||||
|
||||
# Add current project if it has .leann but not in registry
|
||||
current_path = Path.cwd()
|
||||
if (current_path / ".leann" / "indexes").exists() and current_path not in valid_projects:
|
||||
valid_projects.append(current_path)
|
||||
|
||||
if not valid_projects:
|
||||
print("No indexes found. Use 'leann build <name> --docs <dir>' to create one.")
|
||||
return
|
||||
|
||||
index_dirs = [d for d in self.indexes_dir.iterdir() if d.is_dir()]
|
||||
total_indexes = 0
|
||||
current_dir = Path.cwd()
|
||||
|
||||
if not index_dirs:
|
||||
print("No indexes found. Use 'leann build <name> --docs <dir>' to create one.")
|
||||
return
|
||||
for project_path in valid_projects:
|
||||
indexes_dir = project_path / ".leann" / "indexes"
|
||||
if not indexes_dir.exists():
|
||||
continue
|
||||
|
||||
print(f"Found {len(index_dirs)} indexes:")
|
||||
for i, index_dir in enumerate(index_dirs, 1):
|
||||
index_name = index_dir.name
|
||||
status = "✓" if self.index_exists(index_name) else "✗"
|
||||
index_dirs = [d for d in indexes_dir.iterdir() if d.is_dir()]
|
||||
if not index_dirs:
|
||||
continue
|
||||
|
||||
print(f" {i}. {index_name} [{status}]")
|
||||
if self.index_exists(index_name):
|
||||
index_dir / "documents.leann.meta.json"
|
||||
size_mb = sum(f.stat().st_size for f in index_dir.iterdir() if f.is_file()) / (
|
||||
1024 * 1024
|
||||
)
|
||||
print(f" Size: {size_mb:.1f} MB")
|
||||
# Show project header
|
||||
if project_path == current_dir:
|
||||
print(f"\n📁 Current project ({project_path}):")
|
||||
else:
|
||||
print(f"\n📂 {project_path}:")
|
||||
|
||||
if index_dirs:
|
||||
example_name = index_dirs[0].name
|
||||
print("\nUsage:")
|
||||
print(f' leann search {example_name} "your query"')
|
||||
print(f" leann ask {example_name} --interactive")
|
||||
for index_dir in index_dirs:
|
||||
total_indexes += 1
|
||||
index_name = index_dir.name
|
||||
meta_file = index_dir / "documents.leann.meta.json"
|
||||
status = "✓" if meta_file.exists() else "✗"
|
||||
|
||||
def load_documents(self, docs_dir: str):
|
||||
print(f" {total_indexes}. {index_name} [{status}]")
|
||||
if status == "✓":
|
||||
size_mb = sum(f.stat().st_size for f in index_dir.iterdir() if f.is_file()) / (
|
||||
1024 * 1024
|
||||
)
|
||||
print(f" Size: {size_mb:.1f} MB")
|
||||
|
||||
if total_indexes > 0:
|
||||
print(f"\nTotal: {total_indexes} indexes across {len(valid_projects)} projects")
|
||||
print("\nUsage (current project only):")
|
||||
|
||||
# Show example from current project
|
||||
current_indexes_dir = current_dir / ".leann" / "indexes"
|
||||
if current_indexes_dir.exists():
|
||||
current_index_dirs = [d for d in current_indexes_dir.iterdir() if d.is_dir()]
|
||||
if current_index_dirs:
|
||||
example_name = current_index_dirs[0].name
|
||||
print(f' leann search {example_name} "your query"')
|
||||
print(f" leann ask {example_name} --interactive")
|
||||
|
||||
def load_documents(self, docs_dir: str, custom_file_types: str | None = None):
|
||||
print(f"Loading documents from {docs_dir}...")
|
||||
if custom_file_types:
|
||||
print(f"Using custom file types: {custom_file_types}")
|
||||
|
||||
# Try to use better PDF parsers first
|
||||
# Build gitignore parser
|
||||
gitignore_matches = self._build_gitignore_parser(docs_dir)
|
||||
|
||||
# Try to use better PDF parsers first, but only if PDFs are requested
|
||||
documents = []
|
||||
docs_path = Path(docs_dir)
|
||||
|
||||
for file_path in docs_path.rglob("*.pdf"):
|
||||
print(f"Processing PDF: {file_path}")
|
||||
# Check if we should process PDFs
|
||||
should_process_pdfs = custom_file_types is None or ".pdf" in custom_file_types
|
||||
|
||||
# Try PyMuPDF first (best quality)
|
||||
text = extract_pdf_text_with_pymupdf(str(file_path))
|
||||
if text is None:
|
||||
# Try pdfplumber
|
||||
text = extract_pdf_text_with_pdfplumber(str(file_path))
|
||||
if should_process_pdfs:
|
||||
for file_path in docs_path.rglob("*.pdf"):
|
||||
# Check if file matches any exclude pattern
|
||||
relative_path = file_path.relative_to(docs_path)
|
||||
if self._should_exclude_file(relative_path, gitignore_matches):
|
||||
continue
|
||||
|
||||
if text:
|
||||
# Create a simple document structure
|
||||
from llama_index.core import Document
|
||||
print(f"Processing PDF: {file_path}")
|
||||
|
||||
doc = Document(text=text, metadata={"source": str(file_path)})
|
||||
documents.append(doc)
|
||||
else:
|
||||
# Fallback to default reader
|
||||
print(f"Using default reader for {file_path}")
|
||||
default_docs = SimpleDirectoryReader(
|
||||
str(file_path.parent),
|
||||
filename_as_id=True,
|
||||
required_exts=[file_path.suffix],
|
||||
).load_data()
|
||||
documents.extend(default_docs)
|
||||
# Try PyMuPDF first (best quality)
|
||||
text = extract_pdf_text_with_pymupdf(str(file_path))
|
||||
if text is None:
|
||||
# Try pdfplumber
|
||||
text = extract_pdf_text_with_pdfplumber(str(file_path))
|
||||
|
||||
if text:
|
||||
# Create a simple document structure
|
||||
from llama_index.core import Document
|
||||
|
||||
doc = Document(text=text, metadata={"source": str(file_path)})
|
||||
documents.append(doc)
|
||||
else:
|
||||
# Fallback to default reader
|
||||
print(f"Using default reader for {file_path}")
|
||||
try:
|
||||
default_docs = SimpleDirectoryReader(
|
||||
str(file_path.parent),
|
||||
filename_as_id=True,
|
||||
required_exts=[file_path.suffix],
|
||||
).load_data()
|
||||
documents.extend(default_docs)
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not process {file_path}: {e}")
|
||||
|
||||
# Load other file types with default reader
|
||||
other_docs = SimpleDirectoryReader(
|
||||
docs_dir,
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=[".txt", ".md", ".docx"],
|
||||
).load_data(show_progress=True)
|
||||
documents.extend(other_docs)
|
||||
if custom_file_types:
|
||||
# Parse custom file types from comma-separated string
|
||||
code_extensions = [ext.strip() for ext in custom_file_types.split(",") if ext.strip()]
|
||||
# Ensure extensions start with a dot
|
||||
code_extensions = [ext if ext.startswith(".") else f".{ext}" for ext in code_extensions]
|
||||
else:
|
||||
# Use default supported file types
|
||||
code_extensions = [
|
||||
# Original document types
|
||||
".txt",
|
||||
".md",
|
||||
".docx",
|
||||
".pptx",
|
||||
# Code files for Claude Code integration
|
||||
".py",
|
||||
".js",
|
||||
".ts",
|
||||
".jsx",
|
||||
".tsx",
|
||||
".java",
|
||||
".cpp",
|
||||
".c",
|
||||
".h",
|
||||
".hpp",
|
||||
".cs",
|
||||
".go",
|
||||
".rs",
|
||||
".rb",
|
||||
".php",
|
||||
".swift",
|
||||
".kt",
|
||||
".scala",
|
||||
".r",
|
||||
".sql",
|
||||
".sh",
|
||||
".bash",
|
||||
".zsh",
|
||||
".fish",
|
||||
".ps1",
|
||||
".bat",
|
||||
# Config and markup files
|
||||
".json",
|
||||
".yaml",
|
||||
".yml",
|
||||
".xml",
|
||||
".toml",
|
||||
".ini",
|
||||
".cfg",
|
||||
".conf",
|
||||
".html",
|
||||
".css",
|
||||
".scss",
|
||||
".less",
|
||||
".vue",
|
||||
".svelte",
|
||||
# Data science
|
||||
".ipynb",
|
||||
".R",
|
||||
".py",
|
||||
".jl",
|
||||
]
|
||||
# Try to load other file types, but don't fail if none are found
|
||||
try:
|
||||
# Create a custom file filter function using our PathSpec
|
||||
def file_filter(file_path: str) -> bool:
|
||||
"""Return True if file should be included (not excluded)"""
|
||||
try:
|
||||
docs_path_obj = Path(docs_dir)
|
||||
file_path_obj = Path(file_path)
|
||||
relative_path = file_path_obj.relative_to(docs_path_obj)
|
||||
return not self._should_exclude_file(relative_path, gitignore_matches)
|
||||
except (ValueError, OSError):
|
||||
return True # Include files that can't be processed
|
||||
|
||||
other_docs = SimpleDirectoryReader(
|
||||
docs_dir,
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=code_extensions,
|
||||
file_extractor={}, # Use default extractors
|
||||
filename_as_id=True,
|
||||
).load_data(show_progress=True)
|
||||
|
||||
# Filter documents after loading based on gitignore rules
|
||||
filtered_docs = []
|
||||
for doc in other_docs:
|
||||
file_path = doc.metadata.get("file_path", "")
|
||||
if file_filter(file_path):
|
||||
filtered_docs.append(doc)
|
||||
|
||||
documents.extend(filtered_docs)
|
||||
except ValueError as e:
|
||||
if "No files found" in str(e):
|
||||
print("No additional files found for other supported types.")
|
||||
else:
|
||||
raise e
|
||||
|
||||
all_texts = []
|
||||
|
||||
# Define code file extensions for intelligent chunking
|
||||
code_file_exts = {
|
||||
".py",
|
||||
".js",
|
||||
".ts",
|
||||
".jsx",
|
||||
".tsx",
|
||||
".java",
|
||||
".cpp",
|
||||
".c",
|
||||
".h",
|
||||
".hpp",
|
||||
".cs",
|
||||
".go",
|
||||
".rs",
|
||||
".rb",
|
||||
".php",
|
||||
".swift",
|
||||
".kt",
|
||||
".scala",
|
||||
".r",
|
||||
".sql",
|
||||
".sh",
|
||||
".bash",
|
||||
".zsh",
|
||||
".fish",
|
||||
".ps1",
|
||||
".bat",
|
||||
".json",
|
||||
".yaml",
|
||||
".yml",
|
||||
".xml",
|
||||
".toml",
|
||||
".ini",
|
||||
".cfg",
|
||||
".conf",
|
||||
".html",
|
||||
".css",
|
||||
".scss",
|
||||
".less",
|
||||
".vue",
|
||||
".svelte",
|
||||
".ipynb",
|
||||
".R",
|
||||
".jl",
|
||||
}
|
||||
|
||||
for doc in documents:
|
||||
nodes = self.node_parser.get_nodes_from_documents([doc])
|
||||
# Check if this is a code file based on source path
|
||||
source_path = doc.metadata.get("source", "")
|
||||
is_code_file = any(source_path.endswith(ext) for ext in code_file_exts)
|
||||
|
||||
# Use appropriate parser based on file type
|
||||
parser = self.code_parser if is_code_file else self.node_parser
|
||||
nodes = parser.get_nodes_from_documents([doc])
|
||||
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
@@ -215,15 +523,23 @@ Examples:
|
||||
|
||||
async def build_index(self, args):
|
||||
docs_dir = args.docs
|
||||
index_name = args.index_name
|
||||
# Use current directory name if index_name not provided
|
||||
if args.index_name:
|
||||
index_name = args.index_name
|
||||
else:
|
||||
index_name = Path.cwd().name
|
||||
print(f"Using current directory name as index: '{index_name}'")
|
||||
|
||||
index_dir = self.indexes_dir / index_name
|
||||
index_path = self.get_index_path(index_name)
|
||||
|
||||
print(f"📂 Indexing: {Path(docs_dir).resolve()}")
|
||||
|
||||
if index_dir.exists() and not args.force:
|
||||
print(f"Index '{index_name}' already exists. Use --force to rebuild.")
|
||||
return
|
||||
|
||||
all_texts = self.load_documents(docs_dir)
|
||||
all_texts = self.load_documents(docs_dir, args.file_types)
|
||||
if not all_texts:
|
||||
print("No documents found")
|
||||
return
|
||||
@@ -235,6 +551,7 @@ Examples:
|
||||
builder = LeannBuilder(
|
||||
backend_name=args.backend,
|
||||
embedding_model=args.embedding_model,
|
||||
embedding_mode=args.embedding_mode,
|
||||
graph_degree=args.graph_degree,
|
||||
complexity=args.complexity,
|
||||
is_compact=args.compact,
|
||||
@@ -248,6 +565,9 @@ Examples:
|
||||
builder.build_index(index_path)
|
||||
print(f"Index built at {index_path}")
|
||||
|
||||
# Register this project directory in global registry
|
||||
self.register_project_dir()
|
||||
|
||||
async def search_documents(self, args):
|
||||
index_name = args.index_name
|
||||
query = args.query
|
||||
@@ -308,6 +628,11 @@ Examples:
|
||||
if not user_input:
|
||||
continue
|
||||
|
||||
# Prepare LLM kwargs with thinking budget if specified
|
||||
llm_kwargs = {}
|
||||
if args.thinking_budget:
|
||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||
|
||||
response = chat.ask(
|
||||
user_input,
|
||||
top_k=args.top_k,
|
||||
@@ -316,11 +641,17 @@ Examples:
|
||||
prune_ratio=args.prune_ratio,
|
||||
recompute_embeddings=args.recompute_embeddings,
|
||||
pruning_strategy=args.pruning_strategy,
|
||||
llm_kwargs=llm_kwargs,
|
||||
)
|
||||
print(f"LEANN: {response}")
|
||||
else:
|
||||
query = input("Enter your question: ").strip()
|
||||
if query:
|
||||
# Prepare LLM kwargs with thinking budget if specified
|
||||
llm_kwargs = {}
|
||||
if args.thinking_budget:
|
||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||
|
||||
response = chat.ask(
|
||||
query,
|
||||
top_k=args.top_k,
|
||||
@@ -329,6 +660,7 @@ Examples:
|
||||
prune_ratio=args.prune_ratio,
|
||||
recompute_embeddings=args.recompute_embeddings,
|
||||
pruning_strategy=args.pruning_strategy,
|
||||
llm_kwargs=llm_kwargs,
|
||||
)
|
||||
print(f"LEANN: {response}")
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ Preserves all optimization parameters to ensure performance
|
||||
|
||||
import logging
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
@@ -35,7 +36,7 @@ def compute_embeddings(
|
||||
Args:
|
||||
texts: List of texts to compute embeddings for
|
||||
model_name: Model name
|
||||
mode: Computation mode ('sentence-transformers', 'openai', 'mlx')
|
||||
mode: Computation mode ('sentence-transformers', 'openai', 'mlx', 'ollama')
|
||||
is_build: Whether this is a build operation (shows progress bar)
|
||||
batch_size: Batch size for processing
|
||||
adaptive_optimization: Whether to use adaptive optimization based on batch size
|
||||
@@ -55,6 +56,8 @@ def compute_embeddings(
|
||||
return compute_embeddings_openai(texts, model_name)
|
||||
elif mode == "mlx":
|
||||
return compute_embeddings_mlx(texts, model_name)
|
||||
elif mode == "ollama":
|
||||
return compute_embeddings_ollama(texts, model_name, is_build=is_build)
|
||||
else:
|
||||
raise ValueError(f"Unsupported embedding mode: {mode}")
|
||||
|
||||
@@ -365,3 +368,262 @@ def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int =
|
||||
|
||||
# Stack numpy arrays
|
||||
return np.stack(all_embeddings)
|
||||
|
||||
|
||||
def compute_embeddings_ollama(
|
||||
texts: list[str], model_name: str, is_build: bool = False, host: str = "http://localhost:11434"
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Compute embeddings using Ollama API.
|
||||
|
||||
Args:
|
||||
texts: List of texts to compute embeddings for
|
||||
model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large")
|
||||
is_build: Whether this is a build operation (shows progress bar)
|
||||
host: Ollama host URL (default: http://localhost:11434)
|
||||
|
||||
Returns:
|
||||
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
||||
"""
|
||||
try:
|
||||
import requests
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The 'requests' library is required for Ollama embeddings. Install with: uv pip install requests"
|
||||
)
|
||||
|
||||
if not texts:
|
||||
raise ValueError("Cannot compute embeddings for empty text list")
|
||||
|
||||
logger.info(
|
||||
f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}'"
|
||||
)
|
||||
|
||||
# Check if Ollama is running
|
||||
try:
|
||||
response = requests.get(f"{host}/api/version", timeout=5)
|
||||
response.raise_for_status()
|
||||
except requests.exceptions.ConnectionError:
|
||||
error_msg = (
|
||||
f"❌ Could not connect to Ollama at {host}.\n\n"
|
||||
"Please ensure Ollama is running:\n"
|
||||
" • macOS/Linux: ollama serve\n"
|
||||
" • Windows: Make sure Ollama is running in the system tray\n\n"
|
||||
"Installation: https://ollama.com/download"
|
||||
)
|
||||
raise RuntimeError(error_msg)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Unexpected error connecting to Ollama: {e}")
|
||||
|
||||
# Check if model exists and provide helpful suggestions
|
||||
try:
|
||||
response = requests.get(f"{host}/api/tags", timeout=5)
|
||||
response.raise_for_status()
|
||||
models = response.json()
|
||||
model_names = [model["name"] for model in models.get("models", [])]
|
||||
|
||||
# Filter for embedding models (models that support embeddings)
|
||||
embedding_models = []
|
||||
suggested_embedding_models = [
|
||||
"nomic-embed-text",
|
||||
"mxbai-embed-large",
|
||||
"bge-m3",
|
||||
"all-minilm",
|
||||
"snowflake-arctic-embed",
|
||||
]
|
||||
|
||||
for model in model_names:
|
||||
# Check if it's an embedding model (by name patterns or known models)
|
||||
base_name = model.split(":")[0]
|
||||
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5"]):
|
||||
embedding_models.append(model)
|
||||
|
||||
# Check if model exists (handle versioned names)
|
||||
model_found = any(
|
||||
model_name == name.split(":")[0] or model_name == name for name in model_names
|
||||
)
|
||||
|
||||
if not model_found:
|
||||
error_msg = f"❌ Model '{model_name}' not found in local Ollama.\n\n"
|
||||
|
||||
# Suggest pulling the model
|
||||
error_msg += "📦 To install this embedding model:\n"
|
||||
error_msg += f" ollama pull {model_name}\n\n"
|
||||
|
||||
# Show available embedding models
|
||||
if embedding_models:
|
||||
error_msg += "✅ Available embedding models:\n"
|
||||
for model in embedding_models[:5]:
|
||||
error_msg += f" • {model}\n"
|
||||
if len(embedding_models) > 5:
|
||||
error_msg += f" ... and {len(embedding_models) - 5} more\n"
|
||||
else:
|
||||
error_msg += "💡 Popular embedding models to install:\n"
|
||||
for model in suggested_embedding_models[:3]:
|
||||
error_msg += f" • ollama pull {model}\n"
|
||||
|
||||
error_msg += "\n📚 Browse more: https://ollama.com/library"
|
||||
raise ValueError(error_msg)
|
||||
|
||||
# Verify the model supports embeddings by testing it
|
||||
try:
|
||||
test_response = requests.post(
|
||||
f"{host}/api/embeddings", json={"model": model_name, "prompt": "test"}, timeout=10
|
||||
)
|
||||
if test_response.status_code != 200:
|
||||
error_msg = (
|
||||
f"⚠️ Model '{model_name}' exists but may not support embeddings.\n\n"
|
||||
f"Please use an embedding model like:\n"
|
||||
)
|
||||
for model in suggested_embedding_models[:3]:
|
||||
error_msg += f" • {model}\n"
|
||||
raise ValueError(error_msg)
|
||||
except requests.exceptions.RequestException:
|
||||
# If test fails, continue anyway - model might still work
|
||||
pass
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.warning(f"Could not verify model existence: {e}")
|
||||
|
||||
# Process embeddings with optimized concurrent processing
|
||||
import requests
|
||||
|
||||
def get_single_embedding(text_idx_tuple):
|
||||
"""Helper function to get embedding for a single text."""
|
||||
text, idx = text_idx_tuple
|
||||
max_retries = 3
|
||||
retry_count = 0
|
||||
|
||||
# Truncate very long texts to avoid API issues
|
||||
truncated_text = text[:8000] if len(text) > 8000 else text
|
||||
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{host}/api/embeddings",
|
||||
json={"model": model_name, "prompt": truncated_text},
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
embedding = result.get("embedding")
|
||||
|
||||
if embedding is None:
|
||||
raise ValueError(f"No embedding returned for text {idx}")
|
||||
|
||||
return idx, embedding
|
||||
|
||||
except requests.exceptions.Timeout:
|
||||
retry_count += 1
|
||||
if retry_count >= max_retries:
|
||||
logger.warning(f"Timeout for text {idx} after {max_retries} retries")
|
||||
return idx, None
|
||||
|
||||
except Exception as e:
|
||||
if retry_count >= max_retries - 1:
|
||||
logger.error(f"Failed to get embedding for text {idx}: {e}")
|
||||
return idx, None
|
||||
retry_count += 1
|
||||
|
||||
return idx, None
|
||||
|
||||
# Determine if we should use concurrent processing
|
||||
use_concurrent = (
|
||||
len(texts) > 5 and not is_build
|
||||
) # Don't use concurrent in build mode to avoid overwhelming
|
||||
max_workers = min(4, len(texts)) # Limit concurrent requests to avoid overwhelming Ollama
|
||||
|
||||
all_embeddings = [None] * len(texts) # Pre-allocate list to maintain order
|
||||
failed_indices = []
|
||||
|
||||
if use_concurrent:
|
||||
logger.info(
|
||||
f"Using concurrent processing with {max_workers} workers for {len(texts)} texts"
|
||||
)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# Submit all tasks
|
||||
future_to_idx = {
|
||||
executor.submit(get_single_embedding, (text, idx)): idx
|
||||
for idx, text in enumerate(texts)
|
||||
}
|
||||
|
||||
# Add progress bar for concurrent processing
|
||||
try:
|
||||
if is_build or len(texts) > 10:
|
||||
from tqdm import tqdm
|
||||
|
||||
futures_iterator = tqdm(
|
||||
as_completed(future_to_idx),
|
||||
total=len(texts),
|
||||
desc="Computing Ollama embeddings",
|
||||
)
|
||||
else:
|
||||
futures_iterator = as_completed(future_to_idx)
|
||||
except ImportError:
|
||||
futures_iterator = as_completed(future_to_idx)
|
||||
|
||||
# Collect results as they complete
|
||||
for future in futures_iterator:
|
||||
try:
|
||||
idx, embedding = future.result()
|
||||
if embedding is not None:
|
||||
all_embeddings[idx] = embedding
|
||||
else:
|
||||
failed_indices.append(idx)
|
||||
except Exception as e:
|
||||
idx = future_to_idx[future]
|
||||
logger.error(f"Exception for text {idx}: {e}")
|
||||
failed_indices.append(idx)
|
||||
|
||||
else:
|
||||
# Sequential processing with progress bar
|
||||
show_progress = is_build or len(texts) > 10
|
||||
|
||||
try:
|
||||
if show_progress:
|
||||
from tqdm import tqdm
|
||||
|
||||
iterator = tqdm(
|
||||
enumerate(texts), total=len(texts), desc="Computing Ollama embeddings"
|
||||
)
|
||||
else:
|
||||
iterator = enumerate(texts)
|
||||
except ImportError:
|
||||
iterator = enumerate(texts)
|
||||
|
||||
for idx, text in iterator:
|
||||
result_idx, embedding = get_single_embedding((text, idx))
|
||||
if embedding is not None:
|
||||
all_embeddings[idx] = embedding
|
||||
else:
|
||||
failed_indices.append(idx)
|
||||
|
||||
# Handle failed embeddings
|
||||
if failed_indices:
|
||||
if len(failed_indices) == len(texts):
|
||||
raise RuntimeError("Failed to compute any embeddings")
|
||||
|
||||
logger.warning(f"Failed to compute embeddings for {len(failed_indices)}/{len(texts)} texts")
|
||||
|
||||
# Use zero embeddings as fallback for failed ones
|
||||
valid_embedding = next((e for e in all_embeddings if e is not None), None)
|
||||
if valid_embedding:
|
||||
embedding_dim = len(valid_embedding)
|
||||
for idx in failed_indices:
|
||||
all_embeddings[idx] = [0.0] * embedding_dim
|
||||
|
||||
# Remove None values and convert to numpy array
|
||||
all_embeddings = [e for e in all_embeddings if e is not None]
|
||||
|
||||
# Convert to numpy array and normalize
|
||||
embeddings = np.array(all_embeddings, dtype=np.float32)
|
||||
|
||||
# Normalize embeddings (L2 normalization)
|
||||
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
|
||||
embeddings = embeddings / (norms + 1e-8) # Add small epsilon to avoid division by zero
|
||||
|
||||
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
|
||||
|
||||
return embeddings
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import atexit
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import psutil
|
||||
|
||||
@@ -182,8 +184,8 @@ class EmbeddingServerManager:
|
||||
e.g., "leann_backend_diskann.embedding_server"
|
||||
"""
|
||||
self.backend_module_name = backend_module_name
|
||||
self.server_process: subprocess.Popen | None = None
|
||||
self.server_port: int | None = None
|
||||
self.server_process: Optional[subprocess.Popen] = None
|
||||
self.server_port: Optional[int] = None
|
||||
self._atexit_registered = False
|
||||
|
||||
def start_server(
|
||||
@@ -293,6 +295,8 @@ class EmbeddingServerManager:
|
||||
command.extend(["--passages-file", str(passages_file)])
|
||||
if embedding_mode != "sentence-transformers":
|
||||
command.extend(["--embedding-mode", embedding_mode])
|
||||
if kwargs.get("distance_metric"):
|
||||
command.extend(["--distance-metric", kwargs["distance_metric"]])
|
||||
|
||||
return command
|
||||
|
||||
@@ -301,13 +305,24 @@ class EmbeddingServerManager:
|
||||
project_root = Path(__file__).parent.parent.parent.parent.parent
|
||||
logger.info(f"Command: {' '.join(command)}")
|
||||
|
||||
# Let server output go directly to console
|
||||
# The server will respect LEANN_LOG_LEVEL environment variable
|
||||
# In CI environment, redirect output to avoid buffer deadlock
|
||||
# Embedding servers use many print statements that can fill buffers
|
||||
is_ci = os.environ.get("CI") == "true"
|
||||
if is_ci:
|
||||
stdout_target = subprocess.DEVNULL
|
||||
stderr_target = subprocess.DEVNULL
|
||||
logger.info("CI environment detected, redirecting embedding server output to DEVNULL")
|
||||
else:
|
||||
stdout_target = None # Direct to console for visible logs
|
||||
stderr_target = None # Direct to console for visible logs
|
||||
|
||||
# IMPORTANT: Use a new session so we can manage the whole process group reliably
|
||||
self.server_process = subprocess.Popen(
|
||||
command,
|
||||
cwd=project_root,
|
||||
stdout=None, # Direct to console
|
||||
stderr=None, # Direct to console
|
||||
stdout=stdout_target,
|
||||
stderr=stderr_target,
|
||||
start_new_session=True,
|
||||
)
|
||||
self.server_port = port
|
||||
logger.info(f"Server process started with PID: {self.server_process.pid}")
|
||||
@@ -349,34 +364,50 @@ class EmbeddingServerManager:
|
||||
logger.info(
|
||||
f"Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
|
||||
)
|
||||
self.server_process.terminate()
|
||||
# Try terminating the whole process group first (POSIX)
|
||||
try:
|
||||
pgid = os.getpgid(self.server_process.pid)
|
||||
os.killpg(pgid, signal.SIGTERM)
|
||||
except Exception:
|
||||
# Fallback to terminating just the process
|
||||
self.server_process.terminate()
|
||||
|
||||
try:
|
||||
self.server_process.wait(timeout=5)
|
||||
self.server_process.wait(timeout=3)
|
||||
logger.info(f"Server process {self.server_process.pid} terminated.")
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning(
|
||||
f"Server process {self.server_process.pid} did not terminate gracefully, killing it."
|
||||
f"Server process {self.server_process.pid} did not terminate gracefully within 3 seconds, killing it."
|
||||
)
|
||||
self.server_process.kill()
|
||||
|
||||
# Clean up process resources to prevent resource tracker warnings
|
||||
try:
|
||||
self.server_process.wait() # Ensure process is fully cleaned up
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
pgid = os.getpgid(self.server_process.pid)
|
||||
os.killpg(pgid, signal.SIGKILL)
|
||||
except Exception:
|
||||
self.server_process.kill()
|
||||
try:
|
||||
self.server_process.wait(timeout=2)
|
||||
logger.info(f"Server process {self.server_process.pid} killed successfully.")
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error(
|
||||
f"Failed to kill server process {self.server_process.pid} - it may be hung"
|
||||
)
|
||||
# Don't hang indefinitely
|
||||
|
||||
# Clean up process resources without waiting
|
||||
# The process should already be terminated/killed above
|
||||
# Don't wait here as it can hang CI indefinitely
|
||||
self.server_process = None
|
||||
|
||||
def _launch_server_process_colab(self, command: list, port: int) -> None:
|
||||
"""Launch the server process with Colab-specific settings."""
|
||||
logger.info(f"Colab Command: {' '.join(command)}")
|
||||
|
||||
# In Colab, we need to be more careful about process management
|
||||
# In Colab, redirect to DEVNULL to avoid pipe blocking
|
||||
# PIPE without reading can cause hangs
|
||||
self.server_process = subprocess.Popen(
|
||||
command,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
text=True,
|
||||
)
|
||||
self.server_port = port
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -34,7 +34,9 @@ class LeannBackendSearcherInterface(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _ensure_server_running(self, passages_source_file: str, port: int | None, **kwargs) -> int:
|
||||
def _ensure_server_running(
|
||||
self, passages_source_file: str, port: Optional[int], **kwargs
|
||||
) -> int:
|
||||
"""Ensure server is running"""
|
||||
pass
|
||||
|
||||
@@ -48,7 +50,7 @@ class LeannBackendSearcherInterface(ABC):
|
||||
prune_ratio: float = 0.0,
|
||||
recompute_embeddings: bool = False,
|
||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||
zmq_port: int | None = None,
|
||||
zmq_port: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> dict[str, Any]:
|
||||
"""Search for nearest neighbors
|
||||
@@ -74,7 +76,7 @@ class LeannBackendSearcherInterface(ABC):
|
||||
self,
|
||||
query: str,
|
||||
use_server_if_available: bool = True,
|
||||
zmq_port: int | None = None,
|
||||
zmq_port: Optional[int] = None,
|
||||
) -> np.ndarray:
|
||||
"""Compute embedding for a query string
|
||||
|
||||
|
||||
176
packages/leann-core/src/leann/mcp.py
Executable file
176
packages/leann-core/src/leann/mcp.py
Executable file
@@ -0,0 +1,176 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
def handle_request(request):
|
||||
if request.get("method") == "initialize":
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request.get("id"),
|
||||
"result": {
|
||||
"capabilities": {"tools": {}},
|
||||
"protocolVersion": "2024-11-05",
|
||||
"serverInfo": {"name": "leann-mcp", "version": "1.0.0"},
|
||||
},
|
||||
}
|
||||
|
||||
elif request.get("method") == "tools/list":
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request.get("id"),
|
||||
"result": {
|
||||
"tools": [
|
||||
{
|
||||
"name": "leann_search",
|
||||
"description": """🔍 Search code using natural language - like having a coding assistant who knows your entire codebase!
|
||||
|
||||
🎯 **Perfect for**:
|
||||
- "How does authentication work?" → finds auth-related code
|
||||
- "Error handling patterns" → locates try-catch blocks and error logic
|
||||
- "Database connection setup" → finds DB initialization code
|
||||
- "API endpoint definitions" → locates route handlers
|
||||
- "Configuration management" → finds config files and usage
|
||||
|
||||
💡 **Pro tip**: Use this before making any changes to understand existing patterns and conventions.""",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"index_name": {
|
||||
"type": "string",
|
||||
"description": "Name of the LEANN index to search. Use 'leann_list' first to see available indexes.",
|
||||
},
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query - can be natural language (e.g., 'how to handle errors') or technical terms (e.g., 'async function definition')",
|
||||
},
|
||||
"top_k": {
|
||||
"type": "integer",
|
||||
"default": 5,
|
||||
"minimum": 1,
|
||||
"maximum": 20,
|
||||
"description": "Number of search results to return. Use 5-10 for focused results, 15-20 for comprehensive exploration.",
|
||||
},
|
||||
"complexity": {
|
||||
"type": "integer",
|
||||
"default": 32,
|
||||
"minimum": 16,
|
||||
"maximum": 128,
|
||||
"description": "Search complexity level. Use 16-32 for fast searches (recommended), 64+ for higher precision when needed.",
|
||||
},
|
||||
},
|
||||
"required": ["index_name", "query"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "leann_status",
|
||||
"description": "📊 Check the health and stats of your code indexes - like a medical checkup for your codebase knowledge!",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"index_name": {
|
||||
"type": "string",
|
||||
"description": "Optional: Name of specific index to check. If not provided, shows status of all indexes.",
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "leann_list",
|
||||
"description": "📋 Show all your indexed codebases - your personal code library! Use this to see what's available for search.",
|
||||
"inputSchema": {"type": "object", "properties": {}},
|
||||
},
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
elif request.get("method") == "tools/call":
|
||||
tool_name = request["params"]["name"]
|
||||
args = request["params"].get("arguments", {})
|
||||
|
||||
try:
|
||||
if tool_name == "leann_search":
|
||||
# Validate required parameters
|
||||
if not args.get("index_name") or not args.get("query"):
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request.get("id"),
|
||||
"result": {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Error: Both index_name and query are required",
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
# Build simplified command
|
||||
cmd = [
|
||||
"leann",
|
||||
"search",
|
||||
args["index_name"],
|
||||
args["query"],
|
||||
f"--top-k={args.get('top_k', 5)}",
|
||||
f"--complexity={args.get('complexity', 32)}",
|
||||
]
|
||||
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
|
||||
elif tool_name == "leann_status":
|
||||
if args.get("index_name"):
|
||||
# Check specific index status - for now, we'll use leann list and filter
|
||||
result = subprocess.run(["leann", "list"], capture_output=True, text=True)
|
||||
# We could enhance this to show more detailed status per index
|
||||
else:
|
||||
# Show all indexes status
|
||||
result = subprocess.run(["leann", "list"], capture_output=True, text=True)
|
||||
|
||||
elif tool_name == "leann_list":
|
||||
result = subprocess.run(["leann", "list"], capture_output=True, text=True)
|
||||
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request.get("id"),
|
||||
"result": {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": result.stdout
|
||||
if result.returncode == 0
|
||||
else f"Error: {result.stderr}",
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request.get("id"),
|
||||
"error": {"code": -1, "message": str(e)},
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
for line in sys.stdin:
|
||||
try:
|
||||
request = json.loads(line.strip())
|
||||
response = handle_request(request)
|
||||
if response:
|
||||
print(json.dumps(response))
|
||||
sys.stdout.flush()
|
||||
except Exception as e:
|
||||
error_response = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": None,
|
||||
"error": {"code": -1, "message": str(e)},
|
||||
}
|
||||
print(json.dumps(error_response))
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -63,12 +63,19 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
if not self.embedding_model:
|
||||
raise ValueError("Cannot use recompute mode without 'embedding_model' in meta.json.")
|
||||
|
||||
# Get distance_metric from meta if not provided in kwargs
|
||||
distance_metric = (
|
||||
kwargs.get("distance_metric")
|
||||
or self.meta.get("backend_kwargs", {}).get("distance_metric")
|
||||
or "mips"
|
||||
)
|
||||
|
||||
server_started, actual_port = self.embedding_server_manager.start_server(
|
||||
port=port,
|
||||
model_name=self.embedding_model,
|
||||
embedding_mode=self.embedding_mode,
|
||||
passages_file=passages_source_file,
|
||||
distance_metric=kwargs.get("distance_metric"),
|
||||
distance_metric=distance_metric,
|
||||
enable_warmup=kwargs.get("enable_warmup", False),
|
||||
)
|
||||
if not server_started:
|
||||
@@ -125,10 +132,15 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
import msgpack
|
||||
import zmq
|
||||
|
||||
context = None
|
||||
socket = None
|
||||
try:
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.setsockopt(zmq.RCVTIMEO, 30000) # 30 second timeout
|
||||
socket.setsockopt(zmq.LINGER, 0) # Don't block on close
|
||||
socket.setsockopt(zmq.RCVTIMEO, 5000) # 5 second timeout
|
||||
socket.setsockopt(zmq.SNDTIMEO, 5000) # 5 second timeout
|
||||
socket.setsockopt(zmq.IMMEDIATE, 1) # Don't wait for connection
|
||||
socket.connect(f"tcp://localhost:{zmq_port}")
|
||||
|
||||
# Send embedding request
|
||||
@@ -140,9 +152,6 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
response_bytes = socket.recv()
|
||||
response = msgpack.unpackb(response_bytes)
|
||||
|
||||
socket.close()
|
||||
context.term()
|
||||
|
||||
# Convert response to numpy array
|
||||
if isinstance(response, list) and len(response) > 0:
|
||||
return np.array(response, dtype=np.float32)
|
||||
@@ -151,6 +160,11 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to compute embeddings via server: {e}")
|
||||
finally:
|
||||
if socket:
|
||||
socket.close(linger=0)
|
||||
if context:
|
||||
context.term()
|
||||
|
||||
@abstractmethod
|
||||
def search(
|
||||
@@ -162,7 +176,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
prune_ratio: float = 0.0,
|
||||
recompute_embeddings: bool = False,
|
||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||
zmq_port: int | None = None,
|
||||
zmq_port: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
@@ -184,7 +198,27 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def __del__(self):
|
||||
"""Ensures the embedding server is stopped when the searcher is destroyed."""
|
||||
def cleanup(self):
|
||||
"""Cleanup resources including embedding server and ZMQ connections."""
|
||||
# Stop embedding server
|
||||
if hasattr(self, "embedding_server_manager"):
|
||||
self.embedding_server_manager.stop_server()
|
||||
|
||||
# Set ZMQ linger but don't terminate global context
|
||||
try:
|
||||
import zmq
|
||||
|
||||
# Just set linger on the global instance
|
||||
ctx = zmq.Context.instance()
|
||||
ctx.linger = 0
|
||||
# NEVER call ctx.term() on the global instance
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def __del__(self):
|
||||
"""Ensures resources are cleaned up when the searcher is destroyed."""
|
||||
try:
|
||||
self.cleanup()
|
||||
except Exception:
|
||||
# Ignore errors during destruction
|
||||
pass
|
||||
|
||||
91
packages/leann-mcp/README.md
Normal file
91
packages/leann-mcp/README.md
Normal file
@@ -0,0 +1,91 @@
|
||||
# 🔥 LEANN Claude Code Integration
|
||||
|
||||
Transform your development workflow with intelligent code assistance using LEANN's semantic search directly in Claude Code.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
**Step 1:** First, complete the basic LEANN installation following the [📦 Installation guide](../../README.md#installation) in the root README:
|
||||
|
||||
```bash
|
||||
uv venv
|
||||
source .venv/bin/activate
|
||||
uv pip install leann
|
||||
```
|
||||
|
||||
**Step 2:** Install LEANN globally for MCP integration:
|
||||
```bash
|
||||
uv tool install leann-core
|
||||
```
|
||||
|
||||
This makes the `leann` command available system-wide, which `leann_mcp` requires.
|
||||
|
||||
## 🚀 Quick Setup
|
||||
|
||||
Add the LEANN MCP server to Claude Code:
|
||||
|
||||
```bash
|
||||
claude mcp add leann-server -- leann_mcp
|
||||
```
|
||||
|
||||
## 🛠️ Available Tools
|
||||
|
||||
Once connected, you'll have access to these powerful semantic search tools in Claude Code:
|
||||
|
||||
- **`leann_list`** - List all available indexes across your projects
|
||||
- **`leann_search`** - Perform semantic searches across code and documents
|
||||
- **`leann_ask`** - Ask natural language questions and get AI-powered answers from your codebase
|
||||
|
||||
## 🎯 Quick Start Example
|
||||
|
||||
```bash
|
||||
# Build an index for your project (change to your actual path)
|
||||
leann build my-project --docs ./
|
||||
|
||||
# Start Claude Code
|
||||
claude
|
||||
```
|
||||
|
||||
**Try this in Claude Code:**
|
||||
```
|
||||
Help me understand this codebase. List available indexes and search for authentication patterns.
|
||||
```
|
||||
|
||||
<p align="center">
|
||||
<img src="../../assets/claude_code_leann.png" alt="LEANN in Claude Code" width="80%">
|
||||
</p>
|
||||
|
||||
|
||||
## 🧠 How It Works
|
||||
|
||||
The integration consists of three key components working seamlessly together:
|
||||
|
||||
- **`leann`** - Core CLI tool for indexing and searching (installed globally via `uv tool install`)
|
||||
- **`leann_mcp`** - MCP server that wraps `leann` commands for Claude Code integration
|
||||
- **Claude Code** - Calls `leann_mcp`, which executes `leann` commands and returns intelligent results
|
||||
|
||||
## 📁 File Support
|
||||
|
||||
LEANN understands **30+ file types** including:
|
||||
- **Programming**: Python, JavaScript, TypeScript, Java, Go, Rust, C++, C#
|
||||
- **Data**: SQL, YAML, JSON, CSV, XML
|
||||
- **Documentation**: Markdown, TXT, PDF
|
||||
- **And many more!**
|
||||
|
||||
## 💾 Storage & Organization
|
||||
|
||||
- **Project indexes**: Stored in `.leann/` directory (just like `.git`)
|
||||
- **Global registry**: Project tracking at `~/.leann/projects.json`
|
||||
- **Multi-project support**: Switch between different codebases seamlessly
|
||||
- **Portable**: Transfer indexes between machines with minimal overhead
|
||||
|
||||
## 🗑️ Uninstalling
|
||||
|
||||
To remove the LEANN MCP server from Claude Code:
|
||||
|
||||
```bash
|
||||
claude mcp remove leann-server
|
||||
```
|
||||
To remove LEANN
|
||||
```
|
||||
uv pip uninstall leann leann-backend-hnsw leann-core
|
||||
```
|
||||
@@ -5,11 +5,8 @@ LEANN is a revolutionary vector database that democratizes personal AI. Transfor
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# Default installation (HNSW backend, recommended)
|
||||
# Default installation (includes both HNSW and DiskANN backends)
|
||||
uv pip install leann
|
||||
|
||||
# With DiskANN backend (for large-scale deployments)
|
||||
uv pip install leann[diskann]
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
@@ -19,8 +16,8 @@ from leann import LeannBuilder, LeannSearcher, LeannChat
|
||||
from pathlib import Path
|
||||
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
|
||||
|
||||
# Build an index
|
||||
builder = LeannBuilder(backend_name="hnsw")
|
||||
# Build an index (choose backend: "hnsw" or "diskann")
|
||||
builder = LeannBuilder(backend_name="hnsw") # or "diskann" for large-scale deployments
|
||||
builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
|
||||
builder.add_text("Tung Tung Tung Sahur called—they need their banana‑crocodile hybrid back")
|
||||
builder.build_index(INDEX_PATH)
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "leann"
|
||||
version = "0.1.15"
|
||||
version = "0.2.7"
|
||||
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.9"
|
||||
@@ -24,16 +24,15 @@ classifiers = [
|
||||
"Programming Language :: Python :: 3.12",
|
||||
]
|
||||
|
||||
# Default installation: core + hnsw
|
||||
# Default installation: core + hnsw + diskann
|
||||
dependencies = [
|
||||
"leann-core>=0.1.0",
|
||||
"leann-backend-hnsw>=0.1.0",
|
||||
"leann-backend-diskann>=0.1.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
diskann = [
|
||||
"leann-backend-diskann>=0.1.0",
|
||||
]
|
||||
# All backends now included by default
|
||||
|
||||
[project.urls]
|
||||
Repository = "https://github.com/yichuan-w/LEANN"
|
||||
|
||||
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
|
||||
[project]
|
||||
name = "leann-workspace"
|
||||
version = "0.1.0"
|
||||
requires-python = ">=3.10"
|
||||
requires-python = ">=3.9"
|
||||
|
||||
dependencies = [
|
||||
"leann-core",
|
||||
@@ -32,9 +32,9 @@ dependencies = [
|
||||
"pypdfium2>=4.30.0",
|
||||
# LlamaIndex core and readers - updated versions
|
||||
"llama-index>=0.12.44",
|
||||
"llama-index-readers-file>=0.4.0", # Essential for PDF parsing
|
||||
"llama-index-readers-docling",
|
||||
"llama-index-node-parser-docling",
|
||||
"llama-index-readers-file>=0.4.0", # Essential for PDF parsing
|
||||
# "llama-index-readers-docling", # Requires Python >= 3.10
|
||||
# "llama-index-node-parser-docling", # Requires Python >= 3.10
|
||||
"llama-index-vector-stores-faiss>=0.4.0",
|
||||
"llama-index-embeddings-huggingface>=0.5.5",
|
||||
# Other dependencies
|
||||
@@ -43,19 +43,35 @@ dependencies = [
|
||||
"mlx>=0.26.3; sys_platform == 'darwin'",
|
||||
"mlx-lm>=0.26.0; sys_platform == 'darwin'",
|
||||
"psutil>=5.8.0",
|
||||
"pybind11>=3.0.0",
|
||||
"pathspec>=0.12.1",
|
||||
"nbconvert>=7.16.6",
|
||||
"gitignore-parser>=0.1.12",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=7.0",
|
||||
"pytest-cov>=4.0",
|
||||
"pytest>=8.3.0", # Minimum version for Python 3.13 support
|
||||
"pytest-cov>=5.0",
|
||||
"pytest-xdist>=3.5", # For parallel test execution
|
||||
"black>=23.0",
|
||||
"ruff>=0.1.0",
|
||||
"ruff==0.12.7", # Fixed version to ensure consistent formatting across all environments
|
||||
"matplotlib",
|
||||
"huggingface-hub>=0.20.0",
|
||||
"pre-commit>=3.5.0",
|
||||
]
|
||||
|
||||
test = [
|
||||
"pytest>=8.3.0", # Minimum version for Python 3.13 support
|
||||
"pytest-timeout>=2.3",
|
||||
"anyio>=4.0", # For async test support (includes pytest plugin)
|
||||
"psutil>=5.9.0", # For process cleanup in tests
|
||||
"llama-index-core>=0.12.0",
|
||||
"llama-index-readers-file>=0.4.0",
|
||||
"python-dotenv>=1.0.0",
|
||||
"sentence-transformers>=2.2.0",
|
||||
]
|
||||
|
||||
diskann = [
|
||||
"leann-backend-diskann",
|
||||
]
|
||||
@@ -78,7 +94,7 @@ leann-backend-diskann = { path = "packages/leann-backend-diskann", editable = tr
|
||||
leann-backend-hnsw = { path = "packages/leann-backend-hnsw", editable = true }
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py310"
|
||||
target-version = "py39"
|
||||
line-length = 100
|
||||
extend-exclude = [
|
||||
"third_party",
|
||||
@@ -123,3 +139,33 @@ line-ending = "auto"
|
||||
dev = [
|
||||
"ruff>=0.12.4",
|
||||
]
|
||||
|
||||
[tool.lychee]
|
||||
accept = ["200", "403", "429", "503"]
|
||||
timeout = 20
|
||||
max_retries = 2
|
||||
exclude = ["localhost", "127.0.0.1", "example.com"]
|
||||
exclude_path = [".git/", ".venv/", "__pycache__/", "third_party/"]
|
||||
scheme = ["https", "http"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
python_files = ["test_*.py"]
|
||||
python_classes = ["Test*"]
|
||||
python_functions = ["test_*"]
|
||||
markers = [
|
||||
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
|
||||
"openai: marks tests that require OpenAI API key",
|
||||
]
|
||||
timeout = 300 # Reduced from 600s (10min) to 300s (5min) for CI safety
|
||||
timeout_method = "thread" # Use thread method to avoid non-daemon thread issues
|
||||
addopts = [
|
||||
"-v",
|
||||
"--tb=short",
|
||||
"--strict-markers",
|
||||
"--disable-warnings",
|
||||
]
|
||||
env = [
|
||||
"HF_HUB_DISABLE_SYMLINKS=1",
|
||||
"TOKENIZERS_PARALLELISM=false",
|
||||
]
|
||||
|
||||
103
scripts/diagnose_hang.sh
Executable file
103
scripts/diagnose_hang.sh
Executable file
@@ -0,0 +1,103 @@
|
||||
#!/bin/bash
|
||||
# Diagnostic script for debugging CI hangs
|
||||
|
||||
echo "========================================="
|
||||
echo " CI HANG DIAGNOSTIC SCRIPT"
|
||||
echo "========================================="
|
||||
echo ""
|
||||
|
||||
echo "📅 Current time: $(date)"
|
||||
echo "🖥️ Hostname: $(hostname)"
|
||||
echo "👤 User: $(whoami)"
|
||||
echo "📂 Working directory: $(pwd)"
|
||||
echo ""
|
||||
|
||||
echo "=== PYTHON ENVIRONMENT ==="
|
||||
python --version 2>&1 || echo "Python not found"
|
||||
pip list 2>&1 | head -20 || echo "pip not available"
|
||||
echo ""
|
||||
|
||||
echo "=== PROCESS INFORMATION ==="
|
||||
echo "Current shell PID: $$"
|
||||
echo "Parent PID: $PPID"
|
||||
echo ""
|
||||
|
||||
echo "All Python processes:"
|
||||
ps aux | grep -E "[p]ython" || echo "No Python processes"
|
||||
echo ""
|
||||
|
||||
echo "All pytest processes:"
|
||||
ps aux | grep -E "[p]ytest" || echo "No pytest processes"
|
||||
echo ""
|
||||
|
||||
echo "Embedding server processes:"
|
||||
ps aux | grep -E "[e]mbedding_server" || echo "No embedding server processes"
|
||||
echo ""
|
||||
|
||||
echo "Zombie processes:"
|
||||
ps aux | grep "<defunct>" || echo "No zombie processes"
|
||||
echo ""
|
||||
|
||||
echo "=== NETWORK INFORMATION ==="
|
||||
echo "Network listeners on typical embedding server ports:"
|
||||
ss -ltn 2>/dev/null | grep -E ":555[0-9]|:556[0-9]" || netstat -ltn 2>/dev/null | grep -E ":555[0-9]|:556[0-9]" || echo "No listeners on embedding ports"
|
||||
echo ""
|
||||
|
||||
echo "All network listeners:"
|
||||
ss -ltn 2>/dev/null | head -20 || netstat -ltn 2>/dev/null | head -20 || echo "Cannot get network info"
|
||||
echo ""
|
||||
|
||||
echo "=== FILE DESCRIPTORS ==="
|
||||
echo "Open files for current shell:"
|
||||
lsof -p $$ 2>/dev/null | head -20 || echo "lsof not available"
|
||||
echo ""
|
||||
|
||||
if [ -d "/proc/$$" ]; then
|
||||
echo "File descriptors for current shell (/proc/$$/fd):"
|
||||
ls -la /proc/$$/fd 2>/dev/null | head -20 || echo "Cannot access /proc/$$/fd"
|
||||
echo ""
|
||||
fi
|
||||
|
||||
echo "=== SYSTEM RESOURCES ==="
|
||||
echo "Memory usage:"
|
||||
free -h 2>/dev/null || vm_stat 2>/dev/null || echo "Cannot get memory info"
|
||||
echo ""
|
||||
|
||||
echo "Disk usage:"
|
||||
df -h . 2>/dev/null || echo "Cannot get disk info"
|
||||
echo ""
|
||||
|
||||
echo "CPU info:"
|
||||
nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo "Cannot get CPU info"
|
||||
echo ""
|
||||
|
||||
echo "=== PYTHON SPECIFIC CHECKS ==="
|
||||
python -c "
|
||||
import sys
|
||||
import os
|
||||
print(f'Python executable: {sys.executable}')
|
||||
print(f'Python path: {sys.path[:3]}...')
|
||||
print(f'Environment PYTHONPATH: {os.environ.get(\"PYTHONPATH\", \"Not set\")}')
|
||||
print(f'Site packages: {[p for p in sys.path if \"site-packages\" in p][:2]}')
|
||||
" 2>&1 || echo "Cannot run Python diagnostics"
|
||||
echo ""
|
||||
|
||||
echo "=== ZMQ SPECIFIC CHECKS ==="
|
||||
python -c "
|
||||
try:
|
||||
import zmq
|
||||
print(f'ZMQ version: {zmq.zmq_version()}')
|
||||
print(f'PyZMQ version: {zmq.pyzmq_version()}')
|
||||
ctx = zmq.Context.instance()
|
||||
print(f'ZMQ context instance: {ctx}')
|
||||
except Exception as e:
|
||||
print(f'ZMQ check failed: {e}')
|
||||
" 2>&1 || echo "Cannot check ZMQ"
|
||||
echo ""
|
||||
|
||||
echo "=== PYTEST CHECK ==="
|
||||
pytest --version 2>&1 || echo "pytest not found"
|
||||
echo ""
|
||||
|
||||
echo "=== END OF DIAGNOSTICS ==="
|
||||
echo "Generated at: $(date)"
|
||||
@@ -1,161 +0,0 @@
|
||||
import email
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core import Document, VectorStoreIndex
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
|
||||
|
||||
class EmlxReader(BaseReader):
|
||||
"""
|
||||
Apple Mail .emlx file reader.
|
||||
|
||||
Reads individual .emlx files from Apple Mail's storage format.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize."""
|
||||
pass
|
||||
|
||||
def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Load data from the input directory containing .emlx files.
|
||||
|
||||
Args:
|
||||
input_dir: Directory containing .emlx files
|
||||
**load_kwargs:
|
||||
max_count (int): Maximum amount of messages to read.
|
||||
"""
|
||||
docs: list[Document] = []
|
||||
max_count = load_kwargs.get("max_count", 1000)
|
||||
count = 0
|
||||
|
||||
# Walk through the directory recursively
|
||||
for dirpath, dirnames, filenames in os.walk(input_dir):
|
||||
# Skip hidden directories
|
||||
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
||||
|
||||
for filename in filenames:
|
||||
if count >= max_count:
|
||||
break
|
||||
|
||||
if filename.endswith(".emlx"):
|
||||
filepath = os.path.join(dirpath, filename)
|
||||
try:
|
||||
# Read the .emlx file
|
||||
with open(filepath, encoding="utf-8", errors="ignore") as f:
|
||||
content = f.read()
|
||||
|
||||
# .emlx files have a length prefix followed by the email content
|
||||
# The first line contains the length, followed by the email
|
||||
lines = content.split("\n", 1)
|
||||
if len(lines) >= 2:
|
||||
email_content = lines[1]
|
||||
|
||||
# Parse the email using Python's email module
|
||||
try:
|
||||
msg = email.message_from_string(email_content)
|
||||
|
||||
# Extract email metadata
|
||||
subject = msg.get("Subject", "No Subject")
|
||||
from_addr = msg.get("From", "Unknown")
|
||||
to_addr = msg.get("To", "Unknown")
|
||||
date = msg.get("Date", "Unknown")
|
||||
|
||||
# Extract email body
|
||||
body = ""
|
||||
if msg.is_multipart():
|
||||
for part in msg.walk():
|
||||
if (
|
||||
part.get_content_type() == "text/plain"
|
||||
or part.get_content_type() == "text/html"
|
||||
):
|
||||
body += part.get_payload(decode=True).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
# break
|
||||
else:
|
||||
body = msg.get_payload(decode=True).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
|
||||
# Create document content
|
||||
doc_content = f"""
|
||||
From: {from_addr}
|
||||
To: {to_addr}
|
||||
Subject: {subject}
|
||||
Date: {date}
|
||||
|
||||
{body}
|
||||
"""
|
||||
|
||||
# Create metadata
|
||||
metadata = {
|
||||
"file_path": filepath,
|
||||
"subject": subject,
|
||||
"from": from_addr,
|
||||
"to": to_addr,
|
||||
"date": date,
|
||||
"filename": filename,
|
||||
}
|
||||
if count == 0:
|
||||
print("--------------------------------")
|
||||
print("dir path", dirpath)
|
||||
print(metadata)
|
||||
print(doc_content)
|
||||
print("--------------------------------")
|
||||
body = []
|
||||
if msg.is_multipart():
|
||||
for part in msg.walk():
|
||||
print(
|
||||
"-------------------------------- get content type -------------------------------"
|
||||
)
|
||||
print(part.get_content_type())
|
||||
print(part)
|
||||
# body.append(part.get_payload(decode=True).decode('utf-8', errors='ignore'))
|
||||
print(
|
||||
"-------------------------------- get content type -------------------------------"
|
||||
)
|
||||
else:
|
||||
body = msg.get_payload(decode=True).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
print(body)
|
||||
|
||||
print(body)
|
||||
print("--------------------------------")
|
||||
doc = Document(text=doc_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"!!!!!!! Error parsing email from {filepath}: {e} !!!!!!!!")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"!!!!!!! Error reading file !!!!!!!! {filepath}: {e}")
|
||||
continue
|
||||
|
||||
print(f"Loaded {len(docs)} email documents")
|
||||
return docs
|
||||
|
||||
|
||||
# Use the custom EmlxReader instead of MboxReader
|
||||
documents = EmlxReader().load_data(
|
||||
"/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages",
|
||||
max_count=1000,
|
||||
) # Returns list of documents
|
||||
|
||||
# Configure the index with larger chunk size to handle long metadata
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
# Create a custom text splitter with larger chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=2048, chunk_overlap=200)
|
||||
|
||||
index = VectorStoreIndex.from_documents(
|
||||
documents, transformations=[text_splitter]
|
||||
) # Initialize index with documents
|
||||
|
||||
query_engine = index.as_query_engine()
|
||||
res = query_engine.query("Hows Berkeley Graduate Student Instructor")
|
||||
print(res)
|
||||
@@ -1,219 +0,0 @@
|
||||
import email
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core import Document, StorageContext, VectorStoreIndex
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
|
||||
|
||||
class EmlxReader(BaseReader):
|
||||
"""
|
||||
Apple Mail .emlx file reader.
|
||||
|
||||
Reads individual .emlx files from Apple Mail's storage format.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize."""
|
||||
pass
|
||||
|
||||
def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Load data from the input directory containing .emlx files.
|
||||
|
||||
Args:
|
||||
input_dir: Directory containing .emlx files
|
||||
**load_kwargs:
|
||||
max_count (int): Maximum amount of messages to read.
|
||||
"""
|
||||
docs: list[Document] = []
|
||||
max_count = load_kwargs.get("max_count", 1000)
|
||||
count = 0
|
||||
|
||||
# Walk through the directory recursively
|
||||
for dirpath, dirnames, filenames in os.walk(input_dir):
|
||||
# Skip hidden directories
|
||||
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
||||
|
||||
for filename in filenames:
|
||||
if count >= max_count:
|
||||
break
|
||||
|
||||
if filename.endswith(".emlx"):
|
||||
filepath = os.path.join(dirpath, filename)
|
||||
try:
|
||||
# Read the .emlx file
|
||||
with open(filepath, encoding="utf-8", errors="ignore") as f:
|
||||
content = f.read()
|
||||
|
||||
# .emlx files have a length prefix followed by the email content
|
||||
# The first line contains the length, followed by the email
|
||||
lines = content.split("\n", 1)
|
||||
if len(lines) >= 2:
|
||||
email_content = lines[1]
|
||||
|
||||
# Parse the email using Python's email module
|
||||
try:
|
||||
msg = email.message_from_string(email_content)
|
||||
|
||||
# Extract email metadata
|
||||
subject = msg.get("Subject", "No Subject")
|
||||
from_addr = msg.get("From", "Unknown")
|
||||
to_addr = msg.get("To", "Unknown")
|
||||
date = msg.get("Date", "Unknown")
|
||||
|
||||
# Extract email body
|
||||
body = ""
|
||||
if msg.is_multipart():
|
||||
for part in msg.walk():
|
||||
if part.get_content_type() == "text/plain":
|
||||
body = part.get_payload(decode=True).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
break
|
||||
else:
|
||||
body = msg.get_payload(decode=True).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
|
||||
# Create document content
|
||||
doc_content = f"""
|
||||
From: {from_addr}
|
||||
To: {to_addr}
|
||||
Subject: {subject}
|
||||
Date: {date}
|
||||
|
||||
{body}
|
||||
"""
|
||||
|
||||
# Create metadata
|
||||
metadata = {
|
||||
"file_path": filepath,
|
||||
"subject": subject,
|
||||
"from": from_addr,
|
||||
"to": to_addr,
|
||||
"date": date,
|
||||
"filename": filename,
|
||||
}
|
||||
|
||||
doc = Document(text=doc_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error parsing email from {filepath}: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error reading file {filepath}: {e}")
|
||||
continue
|
||||
|
||||
print(f"Loaded {len(docs)} email documents")
|
||||
return docs
|
||||
|
||||
|
||||
def create_and_save_index(mail_path: str, save_dir: str = "mail_index", max_count: int = 1000):
|
||||
"""
|
||||
Create the index from mail data and save it to disk.
|
||||
|
||||
Args:
|
||||
mail_path: Path to the mail directory
|
||||
save_dir: Directory to save the index
|
||||
max_count: Maximum number of emails to process
|
||||
"""
|
||||
print("Creating index from mail data...")
|
||||
|
||||
# Load documents
|
||||
documents = EmlxReader().load_data(mail_path, max_count=max_count)
|
||||
|
||||
if not documents:
|
||||
print("No documents loaded. Exiting.")
|
||||
return None
|
||||
|
||||
# Create text splitter
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=0)
|
||||
|
||||
# Create index
|
||||
index = VectorStoreIndex.from_documents(documents, transformations=[text_splitter])
|
||||
|
||||
# Save the index
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
index.storage_context.persist(persist_dir=save_dir)
|
||||
print(f"Index saved to {save_dir}")
|
||||
|
||||
return index
|
||||
|
||||
|
||||
def load_index(save_dir: str = "mail_index"):
|
||||
"""
|
||||
Load the saved index from disk.
|
||||
|
||||
Args:
|
||||
save_dir: Directory where the index is saved
|
||||
|
||||
Returns:
|
||||
Loaded index or None if loading fails
|
||||
"""
|
||||
try:
|
||||
# Load storage context
|
||||
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
|
||||
|
||||
# Load index
|
||||
index = VectorStoreIndex.from_vector_store(
|
||||
storage_context.vector_store, storage_context=storage_context
|
||||
)
|
||||
|
||||
print(f"Index loaded from {save_dir}")
|
||||
return index
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading index: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def query_index(index, query: str):
|
||||
"""
|
||||
Query the loaded index.
|
||||
|
||||
Args:
|
||||
index: The loaded index
|
||||
query: The query string
|
||||
"""
|
||||
if index is None:
|
||||
print("No index available for querying.")
|
||||
return
|
||||
|
||||
query_engine = index.as_query_engine()
|
||||
response = query_engine.query(query)
|
||||
print(f"Query: {query}")
|
||||
print(f"Response: {response}")
|
||||
|
||||
|
||||
def main():
|
||||
mail_path = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages"
|
||||
save_dir = "mail_index"
|
||||
|
||||
# Check if index already exists
|
||||
if os.path.exists(save_dir) and os.path.exists(os.path.join(save_dir, "vector_store.json")):
|
||||
print("Loading existing index...")
|
||||
index = load_index(save_dir)
|
||||
else:
|
||||
print("Creating new index...")
|
||||
index = create_and_save_index(mail_path, save_dir, max_count=1000)
|
||||
|
||||
if index:
|
||||
# Example queries
|
||||
queries = [
|
||||
"Hows Berkeley Graduate Student Instructor",
|
||||
"What emails mention GSR appointments?",
|
||||
"Find emails about deadlines",
|
||||
]
|
||||
|
||||
for query in queries:
|
||||
print("\n" + "=" * 50)
|
||||
query_index(index, query)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,219 +0,0 @@
|
||||
import email
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core import Document, StorageContext, VectorStoreIndex
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
|
||||
|
||||
class EmlxReader(BaseReader):
|
||||
"""
|
||||
Apple Mail .emlx file reader with reduced metadata.
|
||||
|
||||
Reads individual .emlx files from Apple Mail's storage format.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize."""
|
||||
pass
|
||||
|
||||
def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Load data from the input directory containing .emlx files.
|
||||
|
||||
Args:
|
||||
input_dir: Directory containing .emlx files
|
||||
**load_kwargs:
|
||||
max_count (int): Maximum amount of messages to read.
|
||||
"""
|
||||
docs: list[Document] = []
|
||||
max_count = load_kwargs.get("max_count", 1000)
|
||||
count = 0
|
||||
|
||||
# Walk through the directory recursively
|
||||
for dirpath, dirnames, filenames in os.walk(input_dir):
|
||||
# Skip hidden directories
|
||||
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
||||
|
||||
for filename in filenames:
|
||||
if count >= max_count:
|
||||
break
|
||||
|
||||
if filename.endswith(".emlx"):
|
||||
filepath = os.path.join(dirpath, filename)
|
||||
try:
|
||||
# Read the .emlx file
|
||||
with open(filepath, encoding="utf-8", errors="ignore") as f:
|
||||
content = f.read()
|
||||
|
||||
# .emlx files have a length prefix followed by the email content
|
||||
# The first line contains the length, followed by the email
|
||||
lines = content.split("\n", 1)
|
||||
if len(lines) >= 2:
|
||||
email_content = lines[1]
|
||||
|
||||
# Parse the email using Python's email module
|
||||
try:
|
||||
msg = email.message_from_string(email_content)
|
||||
|
||||
# Extract email metadata
|
||||
subject = msg.get("Subject", "No Subject")
|
||||
from_addr = msg.get("From", "Unknown")
|
||||
to_addr = msg.get("To", "Unknown")
|
||||
date = msg.get("Date", "Unknown")
|
||||
|
||||
# Extract email body
|
||||
body = ""
|
||||
if msg.is_multipart():
|
||||
for part in msg.walk():
|
||||
if part.get_content_type() == "text/plain":
|
||||
body = part.get_payload(decode=True).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
break
|
||||
else:
|
||||
body = msg.get_payload(decode=True).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
|
||||
# Create document content with metadata embedded in text
|
||||
doc_content = f"""
|
||||
From: {from_addr}
|
||||
To: {to_addr}
|
||||
Subject: {subject}
|
||||
Date: {date}
|
||||
|
||||
{body}
|
||||
"""
|
||||
|
||||
# Create minimal metadata (only essential info)
|
||||
metadata = {
|
||||
"subject": subject[:50], # Truncate subject
|
||||
"from": from_addr[:30], # Truncate from
|
||||
"date": date[:20], # Truncate date
|
||||
"filename": filename, # Keep filename
|
||||
}
|
||||
|
||||
doc = Document(text=doc_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error parsing email from {filepath}: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error reading file {filepath}: {e}")
|
||||
continue
|
||||
|
||||
print(f"Loaded {len(docs)} email documents")
|
||||
return docs
|
||||
|
||||
|
||||
def create_and_save_index(
|
||||
mail_path: str, save_dir: str = "mail_index_small", max_count: int = 1000
|
||||
):
|
||||
"""
|
||||
Create the index from mail data and save it to disk.
|
||||
|
||||
Args:
|
||||
mail_path: Path to the mail directory
|
||||
save_dir: Directory to save the index
|
||||
max_count: Maximum number of emails to process
|
||||
"""
|
||||
print("Creating index from mail data with small chunks...")
|
||||
|
||||
# Load documents
|
||||
documents = EmlxReader().load_data(mail_path, max_count=max_count)
|
||||
|
||||
if not documents:
|
||||
print("No documents loaded. Exiting.")
|
||||
return None
|
||||
|
||||
# Create text splitter with small chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=512, chunk_overlap=50)
|
||||
|
||||
# Create index
|
||||
index = VectorStoreIndex.from_documents(documents, transformations=[text_splitter])
|
||||
|
||||
# Save the index
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
index.storage_context.persist(persist_dir=save_dir)
|
||||
print(f"Index saved to {save_dir}")
|
||||
|
||||
return index
|
||||
|
||||
|
||||
def load_index(save_dir: str = "mail_index_small"):
|
||||
"""
|
||||
Load the saved index from disk.
|
||||
|
||||
Args:
|
||||
save_dir: Directory where the index is saved
|
||||
|
||||
Returns:
|
||||
Loaded index or None if loading fails
|
||||
"""
|
||||
try:
|
||||
# Load storage context
|
||||
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
|
||||
|
||||
# Load index
|
||||
index = VectorStoreIndex.from_vector_store(
|
||||
storage_context.vector_store, storage_context=storage_context
|
||||
)
|
||||
|
||||
print(f"Index loaded from {save_dir}")
|
||||
return index
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading index: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def query_index(index, query: str):
|
||||
"""
|
||||
Query the loaded index.
|
||||
|
||||
Args:
|
||||
index: The loaded index
|
||||
query: The query string
|
||||
"""
|
||||
if index is None:
|
||||
print("No index available for querying.")
|
||||
return
|
||||
|
||||
query_engine = index.as_query_engine()
|
||||
response = query_engine.query(query)
|
||||
print(f"Query: {query}")
|
||||
print(f"Response: {response}")
|
||||
|
||||
|
||||
def main():
|
||||
mail_path = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages"
|
||||
save_dir = "mail_index_small"
|
||||
|
||||
# Check if index already exists
|
||||
if os.path.exists(save_dir) and os.path.exists(os.path.join(save_dir, "vector_store.json")):
|
||||
print("Loading existing index...")
|
||||
index = load_index(save_dir)
|
||||
else:
|
||||
print("Creating new index...")
|
||||
index = create_and_save_index(mail_path, save_dir, max_count=1000)
|
||||
|
||||
if index:
|
||||
# Example queries
|
||||
queries = [
|
||||
"Hows Berkeley Graduate Student Instructor",
|
||||
"What emails mention GSR appointments?",
|
||||
"Find emails about deadlines",
|
||||
]
|
||||
|
||||
for query in queries:
|
||||
print("\n" + "=" * 50)
|
||||
query_index(index, query)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,154 +0,0 @@
|
||||
import email
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core import Document, VectorStoreIndex
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
|
||||
|
||||
class EmlxReader(BaseReader):
|
||||
"""
|
||||
Apple Mail .emlx file reader.
|
||||
|
||||
Reads individual .emlx files from Apple Mail's storage format.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize."""
|
||||
pass
|
||||
|
||||
def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Load data from the input directory containing .emlx files.
|
||||
|
||||
Args:
|
||||
input_dir: Directory containing .emlx files
|
||||
**load_kwargs:
|
||||
max_count (int): Maximum amount of messages to read.
|
||||
"""
|
||||
docs: list[Document] = []
|
||||
max_count = load_kwargs.get("max_count", 1000)
|
||||
count = 0
|
||||
|
||||
# Check if directory exists and is accessible
|
||||
if not os.path.exists(input_dir):
|
||||
print(f"Error: Directory '{input_dir}' does not exist")
|
||||
return docs
|
||||
|
||||
if not os.access(input_dir, os.R_OK):
|
||||
print(f"Error: Directory '{input_dir}' is not accessible (permission denied)")
|
||||
print("This is likely due to macOS security restrictions on Mail app data")
|
||||
return docs
|
||||
|
||||
print(f"Scanning directory: {input_dir}")
|
||||
|
||||
# Walk through the directory recursively
|
||||
for dirpath, dirnames, filenames in os.walk(input_dir):
|
||||
# Skip hidden directories
|
||||
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
||||
|
||||
for filename in filenames:
|
||||
if count >= max_count:
|
||||
break
|
||||
|
||||
if filename.endswith(".emlx"):
|
||||
filepath = os.path.join(dirpath, filename)
|
||||
print(f"Found .emlx file: {filepath}")
|
||||
try:
|
||||
# Read the .emlx file
|
||||
with open(filepath, encoding="utf-8", errors="ignore") as f:
|
||||
content = f.read()
|
||||
|
||||
# .emlx files have a length prefix followed by the email content
|
||||
# The first line contains the length, followed by the email
|
||||
lines = content.split("\n", 1)
|
||||
if len(lines) >= 2:
|
||||
email_content = lines[1]
|
||||
|
||||
# Parse the email using Python's email module
|
||||
try:
|
||||
msg = email.message_from_string(email_content)
|
||||
|
||||
# Extract email metadata
|
||||
subject = msg.get("Subject", "No Subject")
|
||||
from_addr = msg.get("From", "Unknown")
|
||||
to_addr = msg.get("To", "Unknown")
|
||||
date = msg.get("Date", "Unknown")
|
||||
|
||||
# Extract email body
|
||||
body = ""
|
||||
if msg.is_multipart():
|
||||
for part in msg.walk():
|
||||
if part.get_content_type() == "text/plain":
|
||||
body = part.get_payload(decode=True).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
break
|
||||
else:
|
||||
body = msg.get_payload(decode=True).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
|
||||
# Create document content
|
||||
doc_content = f"""
|
||||
From: {from_addr}
|
||||
To: {to_addr}
|
||||
Subject: {subject}
|
||||
Date: {date}
|
||||
|
||||
{body}
|
||||
"""
|
||||
|
||||
# Create metadata
|
||||
metadata = {
|
||||
"file_path": filepath,
|
||||
"subject": subject,
|
||||
"from": from_addr,
|
||||
"to": to_addr,
|
||||
"date": date,
|
||||
"filename": filename,
|
||||
}
|
||||
|
||||
doc = Document(text=doc_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error parsing email from {filepath}: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error reading file {filepath}: {e}")
|
||||
continue
|
||||
|
||||
print(f"Loaded {len(docs)} email documents")
|
||||
return docs
|
||||
|
||||
|
||||
def main():
|
||||
# Use the current directory where the sample.emlx file is located
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
print("Testing EmlxReader with sample .emlx file...")
|
||||
print(f"Scanning directory: {current_dir}")
|
||||
|
||||
# Use the custom EmlxReader
|
||||
documents = EmlxReader().load_data(current_dir, max_count=1000)
|
||||
|
||||
if not documents:
|
||||
print("No documents loaded. Make sure sample.emlx exists in the examples directory.")
|
||||
return
|
||||
|
||||
print(f"\nSuccessfully loaded {len(documents)} document(s)")
|
||||
|
||||
# Initialize index with documents
|
||||
index = VectorStoreIndex.from_documents(documents)
|
||||
query_engine = index.as_query_engine()
|
||||
|
||||
print("\nTesting query: 'Hows Berkeley Graduate Student Instructor'")
|
||||
res = query_engine.query("Hows Berkeley Graduate Student Instructor")
|
||||
print(f"Response: {res}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,105 +0,0 @@
|
||||
import os
|
||||
|
||||
from llama_index.core import StorageContext, VectorStoreIndex
|
||||
|
||||
|
||||
def load_index(save_dir: str = "mail_index"):
|
||||
"""
|
||||
Load the saved index from disk.
|
||||
|
||||
Args:
|
||||
save_dir: Directory where the index is saved
|
||||
|
||||
Returns:
|
||||
Loaded index or None if loading fails
|
||||
"""
|
||||
try:
|
||||
# Load storage context
|
||||
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
|
||||
|
||||
# Load index
|
||||
index = VectorStoreIndex.from_vector_store(
|
||||
storage_context.vector_store, storage_context=storage_context
|
||||
)
|
||||
|
||||
print(f"Index loaded from {save_dir}")
|
||||
return index
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading index: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def query_index(index, query: str):
|
||||
"""
|
||||
Query the loaded index.
|
||||
|
||||
Args:
|
||||
index: The loaded index
|
||||
query: The query string
|
||||
"""
|
||||
if index is None:
|
||||
print("No index available for querying.")
|
||||
return
|
||||
|
||||
query_engine = index.as_query_engine()
|
||||
response = query_engine.query(query)
|
||||
print(f"\nQuery: {query}")
|
||||
print(f"Response: {response}")
|
||||
|
||||
|
||||
def main():
|
||||
save_dir = "mail_index"
|
||||
|
||||
# Check if index exists
|
||||
if not os.path.exists(save_dir) or not os.path.exists(
|
||||
os.path.join(save_dir, "vector_store.json")
|
||||
):
|
||||
print(f"Index not found in {save_dir}")
|
||||
print("Please run mail_reader_save_load.py first to create the index.")
|
||||
return
|
||||
|
||||
# Load the index
|
||||
index = load_index(save_dir)
|
||||
|
||||
if not index:
|
||||
print("Failed to load index.")
|
||||
return
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Email Query Interface")
|
||||
print("=" * 60)
|
||||
print("Type 'quit' to exit")
|
||||
print("Type 'help' for example queries")
|
||||
print("=" * 60)
|
||||
|
||||
# Interactive query loop
|
||||
while True:
|
||||
try:
|
||||
query = input("\nEnter your query: ").strip()
|
||||
|
||||
if query.lower() == "quit":
|
||||
print("Goodbye!")
|
||||
break
|
||||
elif query.lower() == "help":
|
||||
print("\nExample queries:")
|
||||
print("- Hows Berkeley Graduate Student Instructor")
|
||||
print("- What emails mention GSR appointments?")
|
||||
print("- Find emails about deadlines")
|
||||
print("- Search for emails from specific sender")
|
||||
print("- Find emails about meetings")
|
||||
continue
|
||||
elif not query:
|
||||
continue
|
||||
|
||||
query_index(index, query)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nGoodbye!")
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"Error processing query: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,117 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Debug script to test ZMQ communication with the exact same setup as main_cli_example.py
|
||||
"""
|
||||
|
||||
import sys
|
||||
import time
|
||||
|
||||
import zmq
|
||||
|
||||
sys.path.append("packages/leann-backend-diskann")
|
||||
from leann_backend_diskann import embedding_pb2
|
||||
|
||||
|
||||
def test_zmq_with_same_model():
|
||||
print("=== Testing ZMQ with same model as main_cli_example.py ===")
|
||||
|
||||
# Test the exact same model that main_cli_example.py uses
|
||||
model_name = "sentence-transformers/all-mpnet-base-v2"
|
||||
|
||||
# Start server with the same model
|
||||
import subprocess
|
||||
|
||||
server_cmd = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"packages.leann-backend-diskann.leann_backend_diskann.embedding_server",
|
||||
"--zmq-port",
|
||||
"5556", # Use different port to avoid conflicts
|
||||
"--model-name",
|
||||
model_name,
|
||||
]
|
||||
|
||||
print(f"Starting server with command: {' '.join(server_cmd)}")
|
||||
server_process = subprocess.Popen(
|
||||
server_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
|
||||
)
|
||||
|
||||
# Wait for server to start
|
||||
print("Waiting for server to start...")
|
||||
time.sleep(10)
|
||||
|
||||
# Check if server is running
|
||||
if server_process.poll() is not None:
|
||||
stdout, stderr = server_process.communicate()
|
||||
print(f"Server failed to start. stdout: {stdout}")
|
||||
print(f"Server failed to start. stderr: {stderr}")
|
||||
return False
|
||||
|
||||
print(f"Server started with PID: {server_process.pid}")
|
||||
|
||||
try:
|
||||
# Test client
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.connect("tcp://127.0.0.1:5556")
|
||||
socket.setsockopt(zmq.RCVTIMEO, 30000) # 30 second timeout like C++
|
||||
socket.setsockopt(zmq.SNDTIMEO, 30000)
|
||||
|
||||
# Create request with same format as C++
|
||||
request = embedding_pb2.NodeEmbeddingRequest()
|
||||
request.node_ids.extend([0, 1, 2, 3, 4]) # Test with some node IDs
|
||||
|
||||
print(f"Sending request with {len(request.node_ids)} node IDs...")
|
||||
start_time = time.time()
|
||||
|
||||
# Send request
|
||||
socket.send(request.SerializeToString())
|
||||
|
||||
# Receive response
|
||||
response_data = socket.recv()
|
||||
end_time = time.time()
|
||||
|
||||
print(f"Received response in {end_time - start_time:.3f} seconds")
|
||||
print(f"Response size: {len(response_data)} bytes")
|
||||
|
||||
# Parse response
|
||||
response = embedding_pb2.NodeEmbeddingResponse()
|
||||
response.ParseFromString(response_data)
|
||||
|
||||
print(f"Response dimensions: {list(response.dimensions)}")
|
||||
print(f"Embeddings data size: {len(response.embeddings_data)} bytes")
|
||||
print(f"Missing IDs: {list(response.missing_ids)}")
|
||||
|
||||
# Calculate expected size
|
||||
if len(response.dimensions) == 2:
|
||||
batch_size = response.dimensions[0]
|
||||
embedding_dim = response.dimensions[1]
|
||||
expected_bytes = batch_size * embedding_dim * 4 # 4 bytes per float
|
||||
print(f"Expected bytes: {expected_bytes}, Actual: {len(response.embeddings_data)}")
|
||||
|
||||
if len(response.embeddings_data) == expected_bytes:
|
||||
print("✅ Response format is correct!")
|
||||
return True
|
||||
else:
|
||||
print("❌ Response format mismatch!")
|
||||
return False
|
||||
else:
|
||||
print("❌ Invalid response dimensions!")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error during ZMQ test: {e}")
|
||||
return False
|
||||
finally:
|
||||
# Clean up
|
||||
server_process.terminate()
|
||||
server_process.wait()
|
||||
print("Server terminated")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = test_zmq_with_same_model()
|
||||
if success:
|
||||
print("\n✅ ZMQ communication test passed!")
|
||||
else:
|
||||
print("\n❌ ZMQ communication test failed!")
|
||||
106
tests/README.md
Normal file
106
tests/README.md
Normal file
@@ -0,0 +1,106 @@
|
||||
# LEANN Tests
|
||||
|
||||
This directory contains automated tests for the LEANN project using pytest.
|
||||
|
||||
## Test Files
|
||||
|
||||
### `test_readme_examples.py`
|
||||
Tests the examples shown in README.md:
|
||||
- The basic example code that users see first (parametrized for both HNSW and DiskANN backends)
|
||||
- Import statements work correctly
|
||||
- Different backend options (HNSW, DiskANN)
|
||||
- Different LLM configuration options (parametrized for both backends)
|
||||
- **All main README examples are tested with both HNSW and DiskANN backends using pytest parametrization**
|
||||
|
||||
### `test_basic.py`
|
||||
Basic functionality tests that verify:
|
||||
- All packages can be imported correctly
|
||||
- C++ extensions (FAISS, DiskANN) load properly
|
||||
- Basic index building and searching works for both HNSW and DiskANN backends
|
||||
- Uses parametrized tests to test both backends
|
||||
|
||||
### `test_document_rag.py`
|
||||
Tests the document RAG example functionality:
|
||||
- Tests with facebook/contriever embeddings
|
||||
- Tests with OpenAI embeddings (if API key is available)
|
||||
- Tests error handling with invalid parameters
|
||||
- Verifies that normalized embeddings are detected and cosine distance is used
|
||||
|
||||
### `test_diskann_partition.py`
|
||||
Tests DiskANN graph partitioning functionality:
|
||||
- Tests DiskANN index building without partitioning (baseline)
|
||||
- Tests automatic graph partitioning with `is_recompute=True`
|
||||
- Verifies that partition files are created and large files are cleaned up for storage saving
|
||||
- Tests search functionality with partitioned indices
|
||||
- Validates medoid and max_base_norm file generation and usage
|
||||
- Includes performance comparison between DiskANN (with partition) and HNSW
|
||||
- **Note**: These tests are skipped in CI due to hardware requirements and computation time
|
||||
|
||||
## Running Tests
|
||||
|
||||
### Install test dependencies:
|
||||
```bash
|
||||
# Using extras
|
||||
uv pip install -e ".[test]"
|
||||
```
|
||||
|
||||
### Run all tests:
|
||||
```bash
|
||||
pytest tests/
|
||||
|
||||
# Or with coverage
|
||||
pytest tests/ --cov=leann --cov-report=html
|
||||
|
||||
# Run in parallel (faster)
|
||||
pytest tests/ -n auto
|
||||
```
|
||||
|
||||
### Run specific tests:
|
||||
```bash
|
||||
# Only basic tests
|
||||
pytest tests/test_basic.py
|
||||
|
||||
# Only tests that don't require OpenAI
|
||||
pytest tests/ -m "not openai"
|
||||
|
||||
# Skip slow tests
|
||||
pytest tests/ -m "not slow"
|
||||
|
||||
# Run DiskANN partition tests (requires local machine, not CI)
|
||||
pytest tests/test_diskann_partition.py
|
||||
```
|
||||
|
||||
### Run with specific backend:
|
||||
```bash
|
||||
# Test only HNSW backend
|
||||
pytest tests/test_basic.py::test_backend_basic[hnsw]
|
||||
pytest tests/test_readme_examples.py::test_readme_basic_example[hnsw]
|
||||
|
||||
# Test only DiskANN backend
|
||||
pytest tests/test_basic.py::test_backend_basic[diskann]
|
||||
pytest tests/test_readme_examples.py::test_readme_basic_example[diskann]
|
||||
|
||||
# All DiskANN tests (parametrized + specialized partition tests)
|
||||
pytest tests/ -k diskann
|
||||
```
|
||||
|
||||
## CI/CD Integration
|
||||
|
||||
Tests are automatically run in GitHub Actions:
|
||||
1. After building wheel packages
|
||||
2. On multiple Python versions (3.9 - 3.13)
|
||||
3. On both Ubuntu and macOS
|
||||
4. Using pytest with appropriate markers and flags
|
||||
|
||||
### pytest.ini Configuration
|
||||
|
||||
The `pytest.ini` file configures:
|
||||
- Test discovery paths
|
||||
- Default timeout (600 seconds)
|
||||
- Environment variables (HF_HUB_DISABLE_SYMLINKS, TOKENIZERS_PARALLELISM)
|
||||
- Custom markers for slow and OpenAI tests
|
||||
- Verbose output with short tracebacks
|
||||
|
||||
### Known Issues
|
||||
|
||||
- OpenAI tests are automatically skipped if no API key is provided
|
||||
301
tests/conftest.py
Normal file
301
tests/conftest.py
Normal file
@@ -0,0 +1,301 @@
|
||||
"""Global test configuration and cleanup fixtures."""
|
||||
|
||||
import faulthandler
|
||||
import os
|
||||
import signal
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
# Enable faulthandler to dump stack traces
|
||||
faulthandler.enable()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def _ci_backtraces():
|
||||
"""Dump stack traces before CI timeout to diagnose hanging."""
|
||||
if os.getenv("CI") == "true":
|
||||
# Dump stack traces 10s before the 180s timeout
|
||||
faulthandler.dump_traceback_later(170, repeat=True)
|
||||
yield
|
||||
faulthandler.cancel_dump_traceback_later()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def global_test_cleanup() -> Generator:
|
||||
"""Global cleanup fixture that runs after all tests.
|
||||
|
||||
This ensures all ZMQ connections and child processes are properly cleaned up,
|
||||
preventing the test runner from hanging on exit.
|
||||
"""
|
||||
yield
|
||||
|
||||
# Cleanup after all tests
|
||||
print("\n🧹 Running global test cleanup...")
|
||||
|
||||
# 1. Force cleanup of any LeannSearcher instances
|
||||
try:
|
||||
import gc
|
||||
|
||||
# Force garbage collection to trigger __del__ methods
|
||||
gc.collect()
|
||||
time.sleep(0.2)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 2. Set ZMQ linger but DON'T term Context.instance()
|
||||
# Terminating the global instance can block if other code still has sockets
|
||||
try:
|
||||
import zmq
|
||||
|
||||
# Just set linger on the global instance, don't terminate it
|
||||
ctx = zmq.Context.instance()
|
||||
ctx.linger = 0
|
||||
# Do NOT call ctx.term() or ctx.destroy() on the global instance!
|
||||
# That would block waiting for all sockets to close
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Kill any leftover child processes (including grandchildren)
|
||||
try:
|
||||
import psutil
|
||||
|
||||
current_process = psutil.Process()
|
||||
# Get ALL descendants recursively
|
||||
children = current_process.children(recursive=True)
|
||||
|
||||
if children:
|
||||
print(f"\n⚠️ Cleaning up {len(children)} leftover child processes...")
|
||||
|
||||
# First try to terminate gracefully
|
||||
for child in children:
|
||||
try:
|
||||
print(f" Terminating {child.pid} ({child.name()})")
|
||||
child.terminate()
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied):
|
||||
pass
|
||||
|
||||
# Wait a bit for processes to terminate
|
||||
gone, alive = psutil.wait_procs(children, timeout=2)
|
||||
|
||||
# Force kill any remaining processes
|
||||
for child in alive:
|
||||
try:
|
||||
print(f" Force killing process {child.pid} ({child.name()})")
|
||||
child.kill()
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied):
|
||||
pass
|
||||
|
||||
# Final wait to ensure cleanup
|
||||
psutil.wait_procs(alive, timeout=1)
|
||||
except ImportError:
|
||||
# psutil not installed, try basic process cleanup
|
||||
try:
|
||||
# Send SIGTERM to all child processes
|
||||
os.killpg(os.getpgid(os.getpid()), signal.SIGTERM)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"Warning: Error during process cleanup: {e}")
|
||||
|
||||
# List and clean up remaining threads
|
||||
try:
|
||||
import threading
|
||||
|
||||
threads = [t for t in threading.enumerate() if t is not threading.main_thread()]
|
||||
if threads:
|
||||
print(f"\n⚠️ {len(threads)} non-main threads still running:")
|
||||
for t in threads:
|
||||
print(f" - {t.name} (daemon={t.daemon})")
|
||||
|
||||
# Force cleanup of pytest-timeout threads that block exit
|
||||
if "pytest_timeout" in t.name and not t.daemon:
|
||||
print(f" 🔧 Converting pytest-timeout thread to daemon: {t.name}")
|
||||
try:
|
||||
t.daemon = True
|
||||
print(" ✓ Converted to daemon thread")
|
||||
except Exception as e:
|
||||
print(f" ✗ Failed: {e}")
|
||||
|
||||
# Check if only daemon threads remain
|
||||
non_daemon = [
|
||||
t for t in threading.enumerate() if t is not threading.main_thread() and not t.daemon
|
||||
]
|
||||
if non_daemon:
|
||||
print(f"\n⚠️ {len(non_daemon)} non-daemon threads still blocking exit")
|
||||
# Force exit in CI to prevent hanging
|
||||
if os.environ.get("CI") == "true":
|
||||
print("🔨 Forcing exit in CI environment...")
|
||||
os._exit(0)
|
||||
except Exception as e:
|
||||
print(f"Thread cleanup error: {e}")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auto_cleanup_searcher():
|
||||
"""Fixture that automatically cleans up LeannSearcher instances."""
|
||||
searchers = []
|
||||
|
||||
def register(searcher):
|
||||
"""Register a searcher for cleanup."""
|
||||
searchers.append(searcher)
|
||||
return searcher
|
||||
|
||||
yield register
|
||||
|
||||
# Cleanup all registered searchers
|
||||
for searcher in searchers:
|
||||
try:
|
||||
searcher.cleanup()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Force garbage collection
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def _reap_children():
|
||||
"""Reap all child processes at session end as a safety net."""
|
||||
yield
|
||||
|
||||
# Final aggressive cleanup
|
||||
try:
|
||||
import psutil
|
||||
|
||||
me = psutil.Process()
|
||||
kids = me.children(recursive=True)
|
||||
for p in kids:
|
||||
try:
|
||||
p.terminate()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
_, alive = psutil.wait_procs(kids, timeout=2)
|
||||
for p in alive:
|
||||
try:
|
||||
p.kill()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_after_each_test():
|
||||
"""Cleanup after each test to prevent resource leaks."""
|
||||
yield
|
||||
|
||||
# Force garbage collection to trigger any __del__ methods
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
|
||||
# Give a moment for async cleanup
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
"""Configure pytest with better timeout handling."""
|
||||
# Set default timeout method to thread if not specified
|
||||
if not config.getoption("--timeout-method", None):
|
||||
config.option.timeout_method = "thread"
|
||||
|
||||
# Add more logging
|
||||
print(f"🔧 Pytest configured at {time.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
print(f" Python version: {os.sys.version}")
|
||||
print(f" Platform: {os.sys.platform}")
|
||||
|
||||
|
||||
def pytest_sessionstart(session):
|
||||
"""Called after the Session object has been created."""
|
||||
print(f"🏁 Pytest session starting at {time.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
print(f" Session ID: {id(session)}")
|
||||
|
||||
# Show initial process state
|
||||
try:
|
||||
import psutil
|
||||
|
||||
current = psutil.Process()
|
||||
print(f" Current PID: {current.pid}")
|
||||
print(f" Parent PID: {current.ppid()}")
|
||||
children = current.children(recursive=True)
|
||||
if children:
|
||||
print(f" ⚠️ Already have {len(children)} child processes at start!")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def pytest_sessionfinish(session, exitstatus):
|
||||
"""Called after whole test run finished."""
|
||||
print(f"🏁 Pytest session finishing at {time.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
print(f" Exit status: {exitstatus}")
|
||||
|
||||
# Aggressive cleanup before pytest exits
|
||||
print("🧹 Starting aggressive cleanup...")
|
||||
|
||||
# First, clean up child processes
|
||||
try:
|
||||
import psutil
|
||||
|
||||
current = psutil.Process()
|
||||
children = current.children(recursive=True)
|
||||
|
||||
if children:
|
||||
print(f" Found {len(children)} child processes to clean up:")
|
||||
for child in children:
|
||||
try:
|
||||
print(f" - PID {child.pid}: {child.name()} (status: {child.status()})")
|
||||
child.terminate()
|
||||
except Exception as e:
|
||||
print(f" - Failed to terminate {child.pid}: {e}")
|
||||
|
||||
# Wait briefly then kill
|
||||
time.sleep(0.5)
|
||||
_, alive = psutil.wait_procs(children, timeout=1)
|
||||
|
||||
for child in alive:
|
||||
try:
|
||||
print(f" - Force killing {child.pid}")
|
||||
child.kill()
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
print(" No child processes found")
|
||||
|
||||
except Exception as e:
|
||||
print(f" Process cleanup error: {e}")
|
||||
|
||||
# Second, clean up problematic threads
|
||||
try:
|
||||
import threading
|
||||
|
||||
threads = [t for t in threading.enumerate() if t is not threading.main_thread()]
|
||||
if threads:
|
||||
print(f" Found {len(threads)} non-main threads:")
|
||||
for t in threads:
|
||||
print(f" - {t.name} (daemon={t.daemon})")
|
||||
# Convert pytest-timeout threads to daemon so they don't block exit
|
||||
if "pytest_timeout" in t.name and not t.daemon:
|
||||
try:
|
||||
t.daemon = True
|
||||
print(" ✓ Converted to daemon")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Force exit if non-daemon threads remain in CI
|
||||
non_daemon = [
|
||||
t for t in threading.enumerate() if t is not threading.main_thread() and not t.daemon
|
||||
]
|
||||
if non_daemon and os.environ.get("CI") == "true":
|
||||
print(f" ⚠️ {len(non_daemon)} non-daemon threads remain, forcing exit...")
|
||||
os._exit(exitstatus or 0)
|
||||
|
||||
except Exception as e:
|
||||
print(f" Thread cleanup error: {e}")
|
||||
|
||||
print(f"✅ Pytest exiting at {time.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
95
tests/test_basic.py
Normal file
95
tests/test_basic.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""
|
||||
Basic functionality tests for CI pipeline using pytest.
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from test_timeout import ci_timeout
|
||||
|
||||
|
||||
def test_imports():
|
||||
"""Test that all packages can be imported."""
|
||||
|
||||
# Test C++ extensions
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true", reason="Skip model tests in CI to avoid MPS memory issues"
|
||||
)
|
||||
@pytest.mark.parametrize("backend_name", ["hnsw", "diskann"])
|
||||
@ci_timeout(120) # 2 minute timeout for backend tests
|
||||
def test_backend_basic(backend_name):
|
||||
"""Test basic functionality for each backend."""
|
||||
from leann.api import LeannBuilder, LeannSearcher, SearchResult
|
||||
|
||||
# Create temporary directory for index
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
index_path = str(Path(temp_dir) / f"test.{backend_name}")
|
||||
|
||||
# Test with small data
|
||||
texts = [f"This is document {i} about topic {i % 5}" for i in range(100)]
|
||||
|
||||
# Configure builder based on backend
|
||||
if backend_name == "hnsw":
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="facebook/contriever",
|
||||
embedding_mode="sentence-transformers",
|
||||
M=16,
|
||||
efConstruction=200,
|
||||
)
|
||||
else: # diskann
|
||||
builder = LeannBuilder(
|
||||
backend_name="diskann",
|
||||
embedding_model="facebook/contriever",
|
||||
embedding_mode="sentence-transformers",
|
||||
num_neighbors=32,
|
||||
search_list_size=50,
|
||||
)
|
||||
|
||||
# Add texts
|
||||
for text in texts:
|
||||
builder.add_text(text)
|
||||
|
||||
# Build index
|
||||
builder.build_index(index_path)
|
||||
|
||||
# Test search
|
||||
searcher = LeannSearcher(index_path)
|
||||
results = searcher.search("document about topic 2", top_k=5)
|
||||
|
||||
# Verify results
|
||||
assert len(results) > 0
|
||||
assert isinstance(results[0], SearchResult)
|
||||
assert "topic 2" in results[0].text or "document" in results[0].text
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true", reason="Skip model tests in CI to avoid MPS memory issues"
|
||||
)
|
||||
@ci_timeout(180) # 3 minute timeout for large index test
|
||||
def test_large_index():
|
||||
"""Test with larger dataset."""
|
||||
from leann.api import LeannBuilder, LeannSearcher
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
index_path = str(Path(temp_dir) / "test_large.hnsw")
|
||||
texts = [f"Document {i}: {' '.join([f'word{j}' for j in range(50)])}" for i in range(1000)]
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="facebook/contriever",
|
||||
embedding_mode="sentence-transformers",
|
||||
)
|
||||
|
||||
for text in texts:
|
||||
builder.add_text(text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
|
||||
searcher = LeannSearcher(index_path)
|
||||
results = searcher.search(["word10 word20"], top_k=10)
|
||||
assert len(results[0]) == 10
|
||||
49
tests/test_ci_minimal.py
Normal file
49
tests/test_ci_minimal.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""
|
||||
Minimal tests for CI that don't require model loading or significant memory.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
def test_package_imports():
|
||||
"""Test that all core packages can be imported."""
|
||||
# Core package
|
||||
|
||||
# Backend packages
|
||||
|
||||
# Core modules
|
||||
|
||||
assert True # If we get here, imports worked
|
||||
|
||||
|
||||
def test_cli_help():
|
||||
"""Test that CLI example shows help."""
|
||||
result = subprocess.run(
|
||||
[sys.executable, "apps/document_rag.py", "--help"], capture_output=True, text=True
|
||||
)
|
||||
|
||||
assert result.returncode == 0
|
||||
assert "usage:" in result.stdout.lower() or "usage:" in result.stderr.lower()
|
||||
assert "--llm" in result.stdout or "--llm" in result.stderr
|
||||
|
||||
|
||||
def test_backend_registration():
|
||||
"""Test that backends are properly registered."""
|
||||
from leann.api import get_registered_backends
|
||||
|
||||
backends = get_registered_backends()
|
||||
assert "hnsw" in backends
|
||||
assert "diskann" in backends
|
||||
|
||||
|
||||
def test_version_info():
|
||||
"""Test that packages have version information."""
|
||||
import leann
|
||||
import leann_backend_diskann
|
||||
import leann_backend_hnsw
|
||||
|
||||
# Check that packages have __version__ or can be imported
|
||||
assert hasattr(leann, "__version__") or True
|
||||
assert hasattr(leann_backend_hnsw, "__version__") or True
|
||||
assert hasattr(leann_backend_diskann, "__version__") or True
|
||||
369
tests/test_diskann_partition.py
Normal file
369
tests/test_diskann_partition.py
Normal file
@@ -0,0 +1,369 @@
|
||||
"""
|
||||
Test DiskANN graph partitioning functionality.
|
||||
|
||||
Tests the automatic graph partitioning feature that was implemented to save
|
||||
storage space by partitioning large DiskANN indices and safely deleting
|
||||
redundant files while maintaining search functionality.
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true",
|
||||
reason="Skip DiskANN partition tests in CI - requires specific hardware and large memory",
|
||||
)
|
||||
def test_diskann_without_partition():
|
||||
"""Test DiskANN index building without partition (baseline)."""
|
||||
from leann.api import LeannBuilder, LeannSearcher
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
index_path = str(Path(temp_dir) / "test_no_partition.leann")
|
||||
|
||||
# Test data - enough to trigger index building
|
||||
texts = [
|
||||
f"Document {i} discusses topic {i % 10} with detailed analysis of subject {i // 10}."
|
||||
for i in range(500)
|
||||
]
|
||||
|
||||
# Build without partition (is_recompute=False)
|
||||
builder = LeannBuilder(
|
||||
backend_name="diskann",
|
||||
embedding_model="facebook/contriever",
|
||||
embedding_mode="sentence-transformers",
|
||||
num_neighbors=32,
|
||||
search_list_size=50,
|
||||
is_recompute=False, # No partition
|
||||
)
|
||||
|
||||
for text in texts:
|
||||
builder.add_text(text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
|
||||
# Verify index was created
|
||||
index_dir = Path(index_path).parent
|
||||
assert index_dir.exists()
|
||||
|
||||
# Check that traditional DiskANN files exist
|
||||
index_prefix = Path(index_path).stem
|
||||
# Core DiskANN files (beam search index may not be created for small datasets)
|
||||
required_files = [
|
||||
f"{index_prefix}_disk.index",
|
||||
f"{index_prefix}_pq_compressed.bin",
|
||||
f"{index_prefix}_pq_pivots.bin",
|
||||
]
|
||||
|
||||
# Check all generated files first for debugging
|
||||
generated_files = [f.name for f in index_dir.glob(f"{index_prefix}*")]
|
||||
print(f"Generated files: {generated_files}")
|
||||
|
||||
for required_file in required_files:
|
||||
file_path = index_dir / required_file
|
||||
assert file_path.exists(), f"Required file {required_file} not found"
|
||||
|
||||
# Ensure no partition files exist in non-partition mode
|
||||
partition_files = [f"{index_prefix}_disk_graph.index", f"{index_prefix}_partition.bin"]
|
||||
|
||||
for partition_file in partition_files:
|
||||
file_path = index_dir / partition_file
|
||||
assert not file_path.exists(), (
|
||||
f"Partition file {partition_file} should not exist in non-partition mode"
|
||||
)
|
||||
|
||||
# Test search functionality
|
||||
searcher = LeannSearcher(index_path)
|
||||
results = searcher.search("topic 3 analysis", top_k=3)
|
||||
|
||||
assert len(results) > 0
|
||||
assert all(result.score is not None and result.score != float("-inf") for result in results)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true",
|
||||
reason="Skip DiskANN partition tests in CI - requires specific hardware and large memory",
|
||||
)
|
||||
def test_diskann_with_partition():
|
||||
"""Test DiskANN index building with automatic graph partitioning."""
|
||||
from leann.api import LeannBuilder
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
index_path = str(Path(temp_dir) / "test_with_partition.leann")
|
||||
|
||||
# Test data - enough to trigger partitioning
|
||||
texts = [
|
||||
f"Document {i} explores subject {i % 15} with comprehensive coverage of area {i // 15}."
|
||||
for i in range(500)
|
||||
]
|
||||
|
||||
# Build with partition (is_recompute=True)
|
||||
builder = LeannBuilder(
|
||||
backend_name="diskann",
|
||||
embedding_model="facebook/contriever",
|
||||
embedding_mode="sentence-transformers",
|
||||
num_neighbors=32,
|
||||
search_list_size=50,
|
||||
is_recompute=True, # Enable automatic partitioning
|
||||
)
|
||||
|
||||
for text in texts:
|
||||
builder.add_text(text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
|
||||
# Verify index was created
|
||||
index_dir = Path(index_path).parent
|
||||
assert index_dir.exists()
|
||||
|
||||
# Check that partition files exist
|
||||
index_prefix = Path(index_path).stem
|
||||
partition_files = [
|
||||
f"{index_prefix}_disk_graph.index", # Partitioned graph
|
||||
f"{index_prefix}_partition.bin", # Partition metadata
|
||||
f"{index_prefix}_pq_compressed.bin",
|
||||
f"{index_prefix}_pq_pivots.bin",
|
||||
]
|
||||
|
||||
for partition_file in partition_files:
|
||||
file_path = index_dir / partition_file
|
||||
assert file_path.exists(), f"Expected partition file {partition_file} not found"
|
||||
|
||||
# Check that large files were cleaned up (storage saving goal)
|
||||
large_files = [f"{index_prefix}_disk.index", f"{index_prefix}_disk_beam_search.index"]
|
||||
|
||||
for large_file in large_files:
|
||||
file_path = index_dir / large_file
|
||||
assert not file_path.exists(), (
|
||||
f"Large file {large_file} should have been deleted for storage saving"
|
||||
)
|
||||
|
||||
# Verify required auxiliary files for partition mode exist
|
||||
required_files = [
|
||||
f"{index_prefix}_disk.index_medoids.bin",
|
||||
f"{index_prefix}_disk.index_max_base_norm.bin",
|
||||
]
|
||||
|
||||
for req_file in required_files:
|
||||
file_path = index_dir / req_file
|
||||
assert file_path.exists(), (
|
||||
f"Required auxiliary file {req_file} missing for partition mode"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true",
|
||||
reason="Skip DiskANN partition tests in CI - requires specific hardware and large memory",
|
||||
)
|
||||
def test_diskann_partition_search_functionality():
|
||||
"""Test that search works correctly with partitioned indices."""
|
||||
from leann.api import LeannBuilder, LeannSearcher
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
index_path = str(Path(temp_dir) / "test_partition_search.leann")
|
||||
|
||||
# Create diverse test data
|
||||
texts = [
|
||||
"LEANN is a storage-efficient approximate nearest neighbor search system.",
|
||||
"Graph partitioning helps reduce memory usage in large scale vector search.",
|
||||
"DiskANN provides high-performance disk-based approximate nearest neighbor search.",
|
||||
"Vector embeddings enable semantic search over unstructured text data.",
|
||||
"Approximate nearest neighbor algorithms trade accuracy for speed and storage.",
|
||||
] * 100 # Repeat to get enough data
|
||||
|
||||
# Build with partitioning
|
||||
builder = LeannBuilder(
|
||||
backend_name="diskann",
|
||||
embedding_model="facebook/contriever",
|
||||
embedding_mode="sentence-transformers",
|
||||
is_recompute=True, # Enable partitioning
|
||||
)
|
||||
|
||||
for text in texts:
|
||||
builder.add_text(text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
|
||||
# Test search with partitioned index
|
||||
searcher = LeannSearcher(index_path)
|
||||
|
||||
# Test various queries
|
||||
test_queries = [
|
||||
("vector search algorithms", 5),
|
||||
("LEANN storage efficiency", 3),
|
||||
("graph partitioning memory", 4),
|
||||
("approximate nearest neighbor", 7),
|
||||
]
|
||||
|
||||
for query, top_k in test_queries:
|
||||
results = searcher.search(query, top_k=top_k)
|
||||
|
||||
# Verify search results
|
||||
assert len(results) == top_k, f"Expected {top_k} results for query '{query}'"
|
||||
assert all(result.score is not None for result in results), (
|
||||
"All results should have scores"
|
||||
)
|
||||
assert all(result.score != float("-inf") for result in results), (
|
||||
"No result should have -inf score"
|
||||
)
|
||||
assert all(result.text is not None for result in results), (
|
||||
"All results should have text"
|
||||
)
|
||||
|
||||
# Scores should be in descending order (higher similarity first)
|
||||
scores = [result.score for result in results]
|
||||
assert scores == sorted(scores, reverse=True), (
|
||||
"Results should be sorted by score descending"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true",
|
||||
reason="Skip DiskANN partition tests in CI - requires specific hardware and large memory",
|
||||
)
|
||||
def test_diskann_medoid_and_norm_files():
|
||||
"""Test that medoid and max_base_norm files are correctly generated and used."""
|
||||
import struct
|
||||
|
||||
from leann.api import LeannBuilder, LeannSearcher
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
index_path = str(Path(temp_dir) / "test_medoid_norm.leann")
|
||||
|
||||
# Small but sufficient dataset
|
||||
texts = [f"Test document {i} with content about subject {i % 10}." for i in range(200)]
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name="diskann",
|
||||
embedding_model="facebook/contriever",
|
||||
embedding_mode="sentence-transformers",
|
||||
is_recompute=True,
|
||||
)
|
||||
|
||||
for text in texts:
|
||||
builder.add_text(text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
|
||||
index_dir = Path(index_path).parent
|
||||
index_prefix = Path(index_path).stem
|
||||
|
||||
# Test medoids file
|
||||
medoids_file = index_dir / f"{index_prefix}_disk.index_medoids.bin"
|
||||
assert medoids_file.exists(), "Medoids file should be generated"
|
||||
|
||||
# Read and validate medoids file format
|
||||
with open(medoids_file, "rb") as f:
|
||||
nshards = struct.unpack("<I", f.read(4))[0]
|
||||
one_val = struct.unpack("<I", f.read(4))[0]
|
||||
medoid_id = struct.unpack("<I", f.read(4))[0]
|
||||
|
||||
assert nshards == 1, "Single-shot build should have 1 shard"
|
||||
assert one_val == 1, "Expected value should be 1"
|
||||
assert medoid_id >= 0, "Medoid ID should be valid (not hardcoded 0)"
|
||||
|
||||
# Test max_base_norm file
|
||||
norm_file = index_dir / f"{index_prefix}_disk.index_max_base_norm.bin"
|
||||
assert norm_file.exists(), "Max base norm file should be generated"
|
||||
|
||||
# Read and validate norm file
|
||||
with open(norm_file, "rb") as f:
|
||||
npts = struct.unpack("<I", f.read(4))[0]
|
||||
ndims = struct.unpack("<I", f.read(4))[0]
|
||||
norm_val = struct.unpack("<f", f.read(4))[0]
|
||||
|
||||
assert npts == 1, "Should have 1 norm point"
|
||||
assert ndims == 1, "Should have 1 dimension"
|
||||
assert norm_val > 0, "Norm value should be positive"
|
||||
assert norm_val != float("inf"), "Norm value should be finite"
|
||||
|
||||
# Test that search works with these files
|
||||
searcher = LeannSearcher(index_path)
|
||||
results = searcher.search("test subject", top_k=3)
|
||||
|
||||
# Verify that scores are not -inf (which indicates norm file was loaded correctly)
|
||||
assert len(results) > 0
|
||||
assert all(result.score != float("-inf") for result in results), (
|
||||
"Scores should not be -inf when norm file is correct"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true",
|
||||
reason="Skip performance comparison in CI - requires significant compute time",
|
||||
)
|
||||
def test_diskann_vs_hnsw_performance():
|
||||
"""Compare DiskANN (with partition) vs HNSW performance."""
|
||||
import time
|
||||
|
||||
from leann.api import LeannBuilder, LeannSearcher
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Test data
|
||||
texts = [
|
||||
f"Performance test document {i} covering topic {i % 20} in detail." for i in range(1000)
|
||||
]
|
||||
query = "performance topic test"
|
||||
|
||||
# Test DiskANN with partitioning
|
||||
diskann_path = str(Path(temp_dir) / "perf_diskann.leann")
|
||||
diskann_builder = LeannBuilder(
|
||||
backend_name="diskann",
|
||||
embedding_model="facebook/contriever",
|
||||
embedding_mode="sentence-transformers",
|
||||
is_recompute=True,
|
||||
)
|
||||
|
||||
for text in texts:
|
||||
diskann_builder.add_text(text)
|
||||
|
||||
start_time = time.time()
|
||||
diskann_builder.build_index(diskann_path)
|
||||
|
||||
# Test HNSW
|
||||
hnsw_path = str(Path(temp_dir) / "perf_hnsw.leann")
|
||||
hnsw_builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="facebook/contriever",
|
||||
embedding_mode="sentence-transformers",
|
||||
is_recompute=True,
|
||||
)
|
||||
|
||||
for text in texts:
|
||||
hnsw_builder.add_text(text)
|
||||
|
||||
start_time = time.time()
|
||||
hnsw_builder.build_index(hnsw_path)
|
||||
|
||||
# Compare search performance
|
||||
diskann_searcher = LeannSearcher(diskann_path)
|
||||
hnsw_searcher = LeannSearcher(hnsw_path)
|
||||
|
||||
# Warm up searches
|
||||
diskann_searcher.search(query, top_k=5)
|
||||
hnsw_searcher.search(query, top_k=5)
|
||||
|
||||
# Timed searches
|
||||
start_time = time.time()
|
||||
diskann_results = diskann_searcher.search(query, top_k=10)
|
||||
diskann_search_time = time.time() - start_time
|
||||
|
||||
start_time = time.time()
|
||||
hnsw_results = hnsw_searcher.search(query, top_k=10)
|
||||
hnsw_search_time = time.time() - start_time
|
||||
|
||||
# Basic assertions
|
||||
assert len(diskann_results) == 10
|
||||
assert len(hnsw_results) == 10
|
||||
assert all(r.score != float("-inf") for r in diskann_results)
|
||||
assert all(r.score != float("-inf") for r in hnsw_results)
|
||||
|
||||
# Performance ratio (informational)
|
||||
if hnsw_search_time > 0:
|
||||
speed_ratio = hnsw_search_time / diskann_search_time
|
||||
print(f"DiskANN search time: {diskann_search_time:.4f}s")
|
||||
print(f"HNSW search time: {hnsw_search_time:.4f}s")
|
||||
print(f"DiskANN is {speed_ratio:.2f}x faster than HNSW")
|
||||
125
tests/test_document_rag.py
Normal file
125
tests/test_document_rag.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""
|
||||
Test document_rag functionality using pytest.
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from test_timeout import ci_timeout
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_data_dir():
|
||||
"""Return the path to test data directory."""
|
||||
return Path("data")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true", reason="Skip model tests in CI to avoid MPS memory issues"
|
||||
)
|
||||
def test_document_rag_simulated(test_data_dir):
|
||||
"""Test document_rag with simulated LLM."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Use a subdirectory that doesn't exist yet to force index creation
|
||||
index_dir = Path(temp_dir) / "test_index"
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"apps/document_rag.py",
|
||||
"--llm",
|
||||
"simulated",
|
||||
"--embedding-model",
|
||||
"facebook/contriever",
|
||||
"--embedding-mode",
|
||||
"sentence-transformers",
|
||||
"--index-dir",
|
||||
str(index_dir),
|
||||
"--data-dir",
|
||||
str(test_data_dir),
|
||||
"--query",
|
||||
"What is Pride and Prejudice about?",
|
||||
]
|
||||
|
||||
env = os.environ.copy()
|
||||
env["HF_HUB_DISABLE_SYMLINKS"] = "1"
|
||||
env["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=600, env=env)
|
||||
|
||||
# Check return code
|
||||
assert result.returncode == 0, f"Command failed: {result.stderr}"
|
||||
|
||||
# Verify output
|
||||
output = result.stdout + result.stderr
|
||||
assert "Index saved to" in output or "Using existing index" in output
|
||||
assert "This is a simulated answer" in output
|
||||
|
||||
|
||||
@pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OpenAI API key not available")
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true", reason="Skip OpenAI embedding tests in CI to avoid hanging"
|
||||
)
|
||||
@ci_timeout(60) # 60 second timeout to avoid hanging on OpenAI API calls
|
||||
def test_document_rag_openai(test_data_dir):
|
||||
"""Test document_rag with OpenAI embeddings."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Use a subdirectory that doesn't exist yet to force index creation
|
||||
index_dir = Path(temp_dir) / "test_index_openai"
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"apps/document_rag.py",
|
||||
"--llm",
|
||||
"simulated", # Use simulated LLM to avoid GPT-4 costs
|
||||
"--embedding-model",
|
||||
"text-embedding-3-small",
|
||||
"--embedding-mode",
|
||||
"openai",
|
||||
"--index-dir",
|
||||
str(index_dir),
|
||||
"--data-dir",
|
||||
str(test_data_dir),
|
||||
"--query",
|
||||
"What is Pride and Prejudice about?",
|
||||
]
|
||||
|
||||
env = os.environ.copy()
|
||||
env["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=600, env=env)
|
||||
|
||||
assert result.returncode == 0, f"Command failed: {result.stderr}"
|
||||
|
||||
# Verify cosine distance was used
|
||||
output = result.stdout + result.stderr
|
||||
assert any(
|
||||
msg in output
|
||||
for msg in [
|
||||
"distance_metric='cosine'",
|
||||
"Automatically setting distance_metric='cosine'",
|
||||
"Using cosine distance",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_document_rag_error_handling(test_data_dir):
|
||||
"""Test document_rag with invalid parameters."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"apps/document_rag.py",
|
||||
"--llm",
|
||||
"invalid_llm_type",
|
||||
"--index-dir",
|
||||
temp_dir,
|
||||
"--data-dir",
|
||||
str(test_data_dir),
|
||||
]
|
||||
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=60)
|
||||
|
||||
# Should fail with invalid LLM type
|
||||
assert result.returncode != 0
|
||||
assert "invalid choice" in result.stderr or "invalid_llm_type" in result.stderr
|
||||
178
tests/test_readme_examples.py
Normal file
178
tests/test_readme_examples.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""
|
||||
Test examples from README.md to ensure documentation is accurate.
|
||||
"""
|
||||
|
||||
import os
|
||||
import platform
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from test_timeout import ci_timeout
|
||||
|
||||
|
||||
@pytest.mark.parametrize("backend_name", ["hnsw", "diskann"])
|
||||
@ci_timeout(90) # 90 second timeout for this comprehensive test
|
||||
def test_readme_basic_example(backend_name):
|
||||
"""Test the basic example from README.md with both backends."""
|
||||
# Skip on macOS CI due to MPS environment issues with all-MiniLM-L6-v2
|
||||
if os.environ.get("CI") == "true" and platform.system() == "Darwin":
|
||||
pytest.skip("Skipping on macOS CI due to MPS environment issues with all-MiniLM-L6-v2")
|
||||
|
||||
# This is the exact code from README (with smaller model for CI)
|
||||
from leann import LeannBuilder, LeannChat, LeannSearcher
|
||||
from leann.api import SearchResult
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
INDEX_PATH = str(Path(temp_dir) / f"demo_{backend_name}.leann")
|
||||
|
||||
# Build an index
|
||||
# In CI, use a smaller model to avoid memory issues
|
||||
if os.environ.get("CI") == "true":
|
||||
builder = LeannBuilder(
|
||||
backend_name=backend_name,
|
||||
embedding_model="sentence-transformers/all-MiniLM-L6-v2", # Smaller model
|
||||
dimensions=384, # Smaller dimensions
|
||||
)
|
||||
else:
|
||||
builder = LeannBuilder(backend_name=backend_name)
|
||||
builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
|
||||
builder.add_text("Tung Tung Tung Sahur called—they need their banana-crocodile hybrid back")
|
||||
builder.build_index(INDEX_PATH)
|
||||
|
||||
# Verify index was created
|
||||
# The index path should be a directory containing index files
|
||||
index_dir = Path(INDEX_PATH).parent
|
||||
assert index_dir.exists()
|
||||
# Check that index files were created
|
||||
index_files = list(index_dir.glob(f"{Path(INDEX_PATH).stem}.*"))
|
||||
assert len(index_files) > 0
|
||||
|
||||
# Search
|
||||
searcher = LeannSearcher(INDEX_PATH)
|
||||
results = searcher.search("fantastical AI-generated creatures", top_k=1)
|
||||
|
||||
# Verify search results
|
||||
assert len(results) > 0
|
||||
assert isinstance(results[0], SearchResult)
|
||||
assert results[0].score != float("-inf"), (
|
||||
f"should return valid scores, got {results[0].score}"
|
||||
)
|
||||
# The second text about banana-crocodile should be more relevant
|
||||
assert "banana" in results[0].text or "crocodile" in results[0].text
|
||||
|
||||
# Chat with your data (using simulated LLM to avoid external dependencies)
|
||||
chat = LeannChat(INDEX_PATH, llm_config={"type": "simulated"})
|
||||
response = chat.ask("How much storage does LEANN save?", top_k=1)
|
||||
|
||||
# Verify chat works
|
||||
assert isinstance(response, str)
|
||||
assert len(response) > 0
|
||||
|
||||
|
||||
def test_readme_imports():
|
||||
"""Test that the imports shown in README work correctly."""
|
||||
# These are the imports shown in README
|
||||
from leann import LeannBuilder, LeannChat, LeannSearcher
|
||||
|
||||
# Verify they are the correct types
|
||||
assert callable(LeannBuilder)
|
||||
assert callable(LeannSearcher)
|
||||
assert callable(LeannChat)
|
||||
|
||||
|
||||
@ci_timeout(60) # 60 second timeout
|
||||
def test_backend_options():
|
||||
"""Test different backend options mentioned in documentation."""
|
||||
# Skip on macOS CI due to MPS environment issues with all-MiniLM-L6-v2
|
||||
if os.environ.get("CI") == "true" and platform.system() == "Darwin":
|
||||
pytest.skip("Skipping on macOS CI due to MPS environment issues with all-MiniLM-L6-v2")
|
||||
|
||||
from leann import LeannBuilder
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Use smaller model in CI to avoid memory issues
|
||||
if os.environ.get("CI") == "true":
|
||||
model_args = {
|
||||
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
|
||||
"dimensions": 384,
|
||||
}
|
||||
else:
|
||||
model_args = {}
|
||||
|
||||
# Test HNSW backend (as shown in README)
|
||||
hnsw_path = str(Path(temp_dir) / "test_hnsw.leann")
|
||||
builder_hnsw = LeannBuilder(backend_name="hnsw", **model_args)
|
||||
builder_hnsw.add_text("Test document for HNSW backend")
|
||||
builder_hnsw.build_index(hnsw_path)
|
||||
assert Path(hnsw_path).parent.exists()
|
||||
assert len(list(Path(hnsw_path).parent.glob(f"{Path(hnsw_path).stem}.*"))) > 0
|
||||
|
||||
# Test DiskANN backend (mentioned as available option)
|
||||
diskann_path = str(Path(temp_dir) / "test_diskann.leann")
|
||||
builder_diskann = LeannBuilder(backend_name="diskann", **model_args)
|
||||
builder_diskann.add_text("Test document for DiskANN backend")
|
||||
builder_diskann.build_index(diskann_path)
|
||||
assert Path(diskann_path).parent.exists()
|
||||
assert len(list(Path(diskann_path).parent.glob(f"{Path(diskann_path).stem}.*"))) > 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("backend_name", ["hnsw", "diskann"])
|
||||
@ci_timeout(75) # 75 second timeout for LLM tests
|
||||
def test_llm_config_simulated(backend_name):
|
||||
"""Test simulated LLM configuration option with both backends."""
|
||||
# Skip on macOS CI due to MPS environment issues with all-MiniLM-L6-v2
|
||||
if os.environ.get("CI") == "true" and platform.system() == "Darwin":
|
||||
pytest.skip("Skipping on macOS CI due to MPS environment issues with all-MiniLM-L6-v2")
|
||||
|
||||
# Skip DiskANN tests in CI due to hardware requirements
|
||||
if os.environ.get("CI") == "true" and backend_name == "diskann":
|
||||
pytest.skip("Skip DiskANN tests in CI - requires specific hardware and large memory")
|
||||
|
||||
from leann import LeannBuilder, LeannChat
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Build a simple index
|
||||
index_path = str(Path(temp_dir) / f"test_{backend_name}.leann")
|
||||
# Use smaller model in CI to avoid memory issues
|
||||
if os.environ.get("CI") == "true":
|
||||
builder = LeannBuilder(
|
||||
backend_name=backend_name,
|
||||
embedding_model="sentence-transformers/all-MiniLM-L6-v2",
|
||||
dimensions=384,
|
||||
)
|
||||
else:
|
||||
builder = LeannBuilder(backend_name=backend_name)
|
||||
builder.add_text("Test document for LLM testing")
|
||||
builder.build_index(index_path)
|
||||
|
||||
# Test simulated LLM config
|
||||
llm_config = {"type": "simulated"}
|
||||
chat = LeannChat(index_path, llm_config=llm_config)
|
||||
response = chat.ask("What is this document about?", top_k=1)
|
||||
|
||||
assert isinstance(response, str)
|
||||
assert len(response) > 0
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Requires HF model download and may timeout")
|
||||
def test_llm_config_hf():
|
||||
"""Test HuggingFace LLM configuration option."""
|
||||
from leann import LeannBuilder, LeannChat
|
||||
|
||||
pytest.importorskip("transformers") # Skip if transformers not installed
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Build a simple index
|
||||
index_path = str(Path(temp_dir) / "test.leann")
|
||||
builder = LeannBuilder(backend_name="hnsw")
|
||||
builder.add_text("Test document for LLM testing")
|
||||
builder.build_index(index_path)
|
||||
|
||||
# Test HF LLM config
|
||||
llm_config = {"type": "hf", "model": "Qwen/Qwen3-0.6B"}
|
||||
chat = LeannChat(index_path, llm_config=llm_config)
|
||||
response = chat.ask("What is this document about?", top_k=1)
|
||||
|
||||
assert isinstance(response, str)
|
||||
assert len(response) > 0
|
||||
129
tests/test_timeout.py
Normal file
129
tests/test_timeout.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""
|
||||
Test timeout utilities for CI environments.
|
||||
"""
|
||||
|
||||
import functools
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
from typing import Any, Callable
|
||||
|
||||
|
||||
def timeout_test(seconds: int = 30):
|
||||
"""
|
||||
Decorator to add timeout to test functions, especially useful in CI environments.
|
||||
|
||||
Args:
|
||||
seconds: Timeout in seconds (default: 30)
|
||||
"""
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
# Only apply timeout in CI environment
|
||||
if os.environ.get("CI") != "true":
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# Set up timeout handler
|
||||
def timeout_handler(signum, frame):
|
||||
print(f"\n❌ Test {func.__name__} timed out after {seconds} seconds in CI!")
|
||||
print("This usually indicates a hanging process or infinite loop.")
|
||||
# Try to cleanup any hanging processes
|
||||
try:
|
||||
import subprocess
|
||||
|
||||
subprocess.run(
|
||||
["pkill", "-f", "embedding_server"], capture_output=True, timeout=2
|
||||
)
|
||||
subprocess.run(
|
||||
["pkill", "-f", "hnsw_embedding"], capture_output=True, timeout=2
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
# Exit with timeout code
|
||||
sys.exit(124) # Standard timeout exit code
|
||||
|
||||
# Set signal handler and alarm
|
||||
old_handler = signal.signal(signal.SIGALRM, timeout_handler)
|
||||
signal.alarm(seconds)
|
||||
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
signal.alarm(0) # Cancel alarm
|
||||
return result
|
||||
except Exception:
|
||||
signal.alarm(0) # Cancel alarm on exception
|
||||
raise
|
||||
finally:
|
||||
# Restore original handler
|
||||
signal.signal(signal.SIGALRM, old_handler)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def ci_timeout(seconds: int = 60):
|
||||
"""
|
||||
Timeout decorator specifically for CI environments.
|
||||
Uses threading for more reliable timeout handling.
|
||||
|
||||
Args:
|
||||
seconds: Timeout in seconds (default: 60)
|
||||
"""
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
# Only apply in CI
|
||||
if os.environ.get("CI") != "true":
|
||||
return func(*args, **kwargs)
|
||||
|
||||
import threading
|
||||
|
||||
result = [None]
|
||||
exception = [None]
|
||||
finished = threading.Event()
|
||||
|
||||
def target():
|
||||
try:
|
||||
result[0] = func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
exception[0] = e
|
||||
finally:
|
||||
finished.set()
|
||||
|
||||
# Start function in thread
|
||||
thread = threading.Thread(target=target, daemon=True)
|
||||
thread.start()
|
||||
|
||||
# Wait for completion or timeout
|
||||
if not finished.wait(timeout=seconds):
|
||||
print(f"\n💥 CI TIMEOUT: Test {func.__name__} exceeded {seconds}s limit!")
|
||||
print("This usually indicates hanging embedding servers or infinite loops.")
|
||||
|
||||
# Try to cleanup embedding servers
|
||||
try:
|
||||
import subprocess
|
||||
|
||||
subprocess.run(
|
||||
["pkill", "-9", "-f", "embedding_server"], capture_output=True, timeout=2
|
||||
)
|
||||
subprocess.run(
|
||||
["pkill", "-9", "-f", "hnsw_embedding"], capture_output=True, timeout=2
|
||||
)
|
||||
print("Attempted to kill hanging embedding servers.")
|
||||
except Exception as e:
|
||||
print(f"Cleanup failed: {e}")
|
||||
|
||||
# Raise TimeoutError instead of sys.exit for better pytest integration
|
||||
raise TimeoutError(f"Test {func.__name__} timed out after {seconds} seconds")
|
||||
|
||||
if exception[0]:
|
||||
raise exception[0]
|
||||
|
||||
return result[0]
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
Reference in New Issue
Block a user