Compare commits
111 Commits
v0.1.14
...
feature/gr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d9e5d5d6aa | ||
|
|
239e35e2e6 | ||
|
|
2fac0c6fbf | ||
|
|
9801aa581b | ||
|
|
5e97916608 | ||
|
|
8b9c2be8c9 | ||
|
|
a437f558a3 | ||
|
|
742c9baabc | ||
|
|
60eef4b440 | ||
|
|
f2c5355c73 | ||
|
|
439debbd3f | ||
|
|
3ff5aac8e0 | ||
|
|
a35bfb0354 | ||
|
|
a6dad47280 | ||
|
|
67fef60466 | ||
|
|
131f10b286 | ||
|
|
e3762458fc | ||
|
|
b6ab6f1993 | ||
|
|
9f2e82a838 | ||
|
|
05e1efa00a | ||
|
|
6363fc5f83 | ||
|
|
319dc34a24 | ||
|
|
72a5993f02 | ||
|
|
250272a3be | ||
|
|
042da1fe09 | ||
|
|
2d9c183ebb | ||
|
|
0b2b799d5a | ||
|
|
0f790fbbd9 | ||
|
|
387ae21eba | ||
|
|
3cc329c3e7 | ||
|
|
a8421c0475 | ||
|
|
0ec00e1a60 | ||
|
|
777b5fed01 | ||
|
|
440ad6e816 | ||
|
|
5567302316 | ||
|
|
8714472cd8 | ||
|
|
075d4bd167 | ||
|
|
e4bcc76f88 | ||
|
|
710e83b1fd | ||
|
|
c799d61a5a | ||
|
|
c96d653072 | ||
|
|
e409933149 | ||
|
|
bc31876a9f | ||
|
|
e421c44b8b | ||
|
|
af69aa0508 | ||
|
|
575b354976 | ||
|
|
65bbff1d93 | ||
|
|
df798d350d | ||
|
|
3fa6b2aa17 | ||
|
|
ba95554fe7 | ||
|
|
677eb0bae3 | ||
|
|
9cdfcec331 | ||
|
|
f30d1a2530 | ||
|
|
df69a49123 | ||
|
|
65b54ff905 | ||
|
|
4db3e94f35 | ||
|
|
a2568f3ddc | ||
|
|
45bdad4fa7 | ||
|
|
8b538d1ef9 | ||
|
|
ada8bcbc70 | ||
|
|
6061e8f2de | ||
|
|
9842ad8330 | ||
|
|
7d920f9071 | ||
|
|
f28f15000c | ||
|
|
1d657fd9f6 | ||
|
|
d217adbe40 | ||
|
|
f790ec634f | ||
|
|
b8da9d7b12 | ||
|
|
0cb0463929 | ||
|
|
b982241249 | ||
|
|
c66f197e1d | ||
|
|
4a1353761a | ||
|
|
a72090d2ab | ||
|
|
669e622430 | ||
|
|
77d7b60a61 | ||
|
|
8b22d2b5d3 | ||
|
|
4cb544ee38 | ||
|
|
f94ce63d51 | ||
|
|
4271ff9d84 | ||
|
|
0d448c4a41 | ||
|
|
af5599e33c | ||
|
|
efdf6d917a | ||
|
|
dd71ac8d71 | ||
|
|
8bee1d4100 | ||
|
|
33521d6d00 | ||
|
|
8899734952 | ||
|
|
54df6310c5 | ||
|
|
19bcc07814 | ||
|
|
8356e3c668 | ||
|
|
08eac5c821 | ||
|
|
4671ed9b36 | ||
|
|
055c086398 | ||
|
|
d505dcc5e3 | ||
|
|
261006c36a | ||
|
|
b2eba23e21 | ||
|
|
e9ee687472 | ||
|
|
6f5d5e4a77 | ||
|
|
5c8921673a | ||
|
|
e9d2d420bd | ||
|
|
ebabfad066 | ||
|
|
e6f612b5e8 | ||
|
|
51c41acd82 | ||
|
|
455f93fb7c | ||
|
|
48207c3b69 | ||
|
|
4de1caa40f | ||
|
|
60eaa8165c | ||
|
|
c1a5d0c624 | ||
|
|
af1790395a | ||
|
|
383c6d8d7e | ||
|
|
bc0d839693 | ||
|
|
8596562de5 |
11
.github/workflows/build-and-publish.yml
vendored
11
.github/workflows/build-and-publish.yml
vendored
@@ -5,7 +5,16 @@ on:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
debug_enabled:
|
||||
type: boolean
|
||||
description: 'Run with tmate debugging enabled (SSH access to runner)'
|
||||
required: false
|
||||
default: false
|
||||
|
||||
jobs:
|
||||
build:
|
||||
uses: ./.github/workflows/build-reusable.yml
|
||||
uses: ./.github/workflows/build-reusable.yml
|
||||
with:
|
||||
debug_enabled: ${{ github.event_name == 'workflow_dispatch' && inputs.debug_enabled || false }}
|
||||
|
||||
344
.github/workflows/build-reusable.yml
vendored
344
.github/workflows/build-reusable.yml
vendored
@@ -8,6 +8,11 @@ on:
|
||||
required: false
|
||||
type: string
|
||||
default: ''
|
||||
debug_enabled:
|
||||
description: 'Enable tmate debugging session for troubleshooting'
|
||||
required: false
|
||||
type: boolean
|
||||
default: false
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
@@ -17,23 +22,23 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ inputs.ref }}
|
||||
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
|
||||
- name: Install ruff
|
||||
run: |
|
||||
uv tool install ruff
|
||||
|
||||
uv tool install ruff==0.12.7
|
||||
|
||||
- name: Run ruff check
|
||||
run: |
|
||||
ruff check .
|
||||
|
||||
|
||||
- name: Run ruff format check
|
||||
run: |
|
||||
ruff format --check .
|
||||
@@ -65,40 +70,41 @@ jobs:
|
||||
- os: macos-latest
|
||||
python: '3.13'
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ inputs.ref }}
|
||||
submodules: recursive
|
||||
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python }}
|
||||
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
|
||||
- name: Install system dependencies (Ubuntu)
|
||||
if: runner.os == 'Linux'
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
|
||||
pkg-config libopenblas-dev patchelf libabsl-dev libaio-dev libprotobuf-dev
|
||||
|
||||
|
||||
# Install Intel MKL for DiskANN
|
||||
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
|
||||
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
|
||||
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/mkl/latest/lib/intel64:$LD_LIBRARY_PATH" >> $GITHUB_ENV
|
||||
|
||||
|
||||
- name: Install system dependencies (macOS)
|
||||
if: runner.os == 'macOS'
|
||||
run: |
|
||||
brew install llvm libomp boost protobuf zeromq
|
||||
|
||||
# Don't install LLVM, use system clang for better compatibility
|
||||
brew install libomp boost protobuf zeromq
|
||||
|
||||
- name: Install build dependencies
|
||||
run: |
|
||||
uv pip install --system scikit-build-core numpy swig Cython pybind11
|
||||
@@ -107,41 +113,46 @@ jobs:
|
||||
else
|
||||
uv pip install --system delocate
|
||||
fi
|
||||
|
||||
|
||||
- name: Build packages
|
||||
run: |
|
||||
# Build core (platform independent)
|
||||
if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
|
||||
cd packages/leann-core
|
||||
uv build
|
||||
cd ../..
|
||||
fi
|
||||
|
||||
# Build core (platform independent) on all platforms for consistency
|
||||
cd packages/leann-core
|
||||
uv build
|
||||
cd ../..
|
||||
|
||||
# Build HNSW backend
|
||||
cd packages/leann-backend-hnsw
|
||||
if [ "${{ matrix.os }}" == "macos-latest" ]; then
|
||||
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv build --wheel --python python
|
||||
# Use system clang instead of homebrew LLVM for better compatibility
|
||||
export CC=clang
|
||||
export CXX=clang++
|
||||
export MACOSX_DEPLOYMENT_TARGET=11.0
|
||||
uv build --wheel --python python
|
||||
else
|
||||
uv build --wheel --python python
|
||||
fi
|
||||
cd ../..
|
||||
|
||||
|
||||
# Build DiskANN backend
|
||||
cd packages/leann-backend-diskann
|
||||
if [ "${{ matrix.os }}" == "macos-latest" ]; then
|
||||
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv build --wheel --python python
|
||||
# Use system clang instead of homebrew LLVM for better compatibility
|
||||
export CC=clang
|
||||
export CXX=clang++
|
||||
# sgesdd_ is only available on macOS 13.3+
|
||||
export MACOSX_DEPLOYMENT_TARGET=13.3
|
||||
uv build --wheel --python python
|
||||
else
|
||||
uv build --wheel --python python
|
||||
fi
|
||||
cd ../..
|
||||
|
||||
# Build meta package (platform independent)
|
||||
if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
|
||||
cd packages/leann
|
||||
uv build
|
||||
cd ../..
|
||||
fi
|
||||
|
||||
|
||||
# Build meta package (platform independent) on all platforms
|
||||
cd packages/leann
|
||||
uv build
|
||||
cd ../..
|
||||
|
||||
- name: Repair wheels (Linux)
|
||||
if: runner.os == 'Linux'
|
||||
run: |
|
||||
@@ -153,16 +164,21 @@ jobs:
|
||||
mv dist_repaired dist
|
||||
fi
|
||||
cd ../..
|
||||
|
||||
# Repair DiskANN wheel
|
||||
|
||||
# Repair DiskANN wheel - use show first to debug
|
||||
cd packages/leann-backend-diskann
|
||||
if [ -d dist ]; then
|
||||
echo "Checking DiskANN wheel contents before repair:"
|
||||
unzip -l dist/*.whl | grep -E "\.so|\.pyd|_diskannpy" || echo "No .so files found"
|
||||
auditwheel show dist/*.whl || echo "auditwheel show failed"
|
||||
auditwheel repair dist/*.whl -w dist_repaired
|
||||
echo "Checking DiskANN wheel contents after repair:"
|
||||
unzip -l dist_repaired/*.whl | grep -E "\.so|\.pyd|_diskannpy" || echo "No .so files found after repair"
|
||||
rm -rf dist
|
||||
mv dist_repaired dist
|
||||
fi
|
||||
cd ../..
|
||||
|
||||
|
||||
- name: Repair wheels (macOS)
|
||||
if: runner.os == 'macOS'
|
||||
run: |
|
||||
@@ -174,7 +190,7 @@ jobs:
|
||||
mv dist_repaired dist
|
||||
fi
|
||||
cd ../..
|
||||
|
||||
|
||||
# Repair DiskANN wheel
|
||||
cd packages/leann-backend-diskann
|
||||
if [ -d dist ]; then
|
||||
@@ -183,14 +199,262 @@ jobs:
|
||||
mv dist_repaired dist
|
||||
fi
|
||||
cd ../..
|
||||
|
||||
|
||||
- name: List built packages
|
||||
run: |
|
||||
echo "📦 Built packages:"
|
||||
find packages/*/dist -name "*.whl" -o -name "*.tar.gz" | sort
|
||||
|
||||
|
||||
- name: Install built packages for testing
|
||||
run: |
|
||||
# Create a virtual environment with the correct Python version
|
||||
uv venv --python python${{ matrix.python }}
|
||||
source .venv/bin/activate || source .venv/Scripts/activate
|
||||
|
||||
# Install the built wheels directly to ensure we use locally built packages
|
||||
# Use only locally built wheels on all platforms for full consistency
|
||||
FIND_LINKS="--find-links packages/leann-core/dist --find-links packages/leann/dist"
|
||||
FIND_LINKS="$FIND_LINKS --find-links packages/leann-backend-hnsw/dist --find-links packages/leann-backend-diskann/dist"
|
||||
|
||||
uv pip install leann-core leann leann-backend-hnsw leann-backend-diskann \
|
||||
$FIND_LINKS --force-reinstall
|
||||
|
||||
# Install test dependencies using extras
|
||||
uv pip install -e ".[test]"
|
||||
|
||||
# Debug: Check if _diskannpy module is installed correctly
|
||||
echo "Checking installed DiskANN module structure:"
|
||||
python -c "import leann_backend_diskann; print('leann_backend_diskann location:', leann_backend_diskann.__file__)" || echo "Failed to import leann_backend_diskann"
|
||||
python -c "from leann_backend_diskann import _diskannpy; print('_diskannpy imported successfully')" || echo "Failed to import _diskannpy"
|
||||
ls -la $(python -c "import leann_backend_diskann; import os; print(os.path.dirname(leann_backend_diskann.__file__))" 2>/dev/null) 2>/dev/null || echo "Failed to list module directory"
|
||||
|
||||
# Extra debugging for Python 3.13
|
||||
if [[ "${{ matrix.python }}" == "3.13" ]]; then
|
||||
echo "=== Python 3.13 Debug Info ==="
|
||||
echo "Python version details:"
|
||||
python --version
|
||||
python -c "import sys; print(f'sys.version_info: {sys.version_info}')"
|
||||
|
||||
echo "Pytest version:"
|
||||
python -m pytest --version
|
||||
|
||||
echo "Testing basic pytest collection:"
|
||||
if [[ "$RUNNER_OS" == "Linux" ]]; then
|
||||
timeout --signal=INT 10 python -m pytest --collect-only tests/test_ci_minimal.py -v || echo "Collection timed out or failed"
|
||||
else
|
||||
# No timeout on macOS/Windows
|
||||
python -m pytest --collect-only tests/test_ci_minimal.py -v || echo "Collection failed"
|
||||
fi
|
||||
|
||||
echo "Testing single simple test:"
|
||||
if [[ "$RUNNER_OS" == "Linux" ]]; then
|
||||
timeout --signal=INT 10 python -m pytest tests/test_ci_minimal.py::test_package_imports --full-trace -v || echo "Simple test timed out or failed"
|
||||
else
|
||||
# No timeout on macOS/Windows
|
||||
python -m pytest tests/test_ci_minimal.py::test_package_imports --full-trace -v || echo "Simple test failed"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Enable tmate debugging session if requested
|
||||
- name: Setup tmate session for debugging
|
||||
if: ${{ inputs.debug_enabled }}
|
||||
uses: mxschmitt/action-tmate@v3
|
||||
with:
|
||||
detached: true
|
||||
timeout-minutes: 30
|
||||
limit-access-to-actor: true
|
||||
|
||||
- name: Run tests with pytest
|
||||
# Timeout hierarchy:
|
||||
# 1. Individual test timeout: 20s (see pyproject.toml markers)
|
||||
# 2. Pytest session timeout: 300s (see pyproject.toml [tool.pytest.ini_options])
|
||||
# 3. Outer shell timeout: 360s (300s + 60s buffer for cleanup)
|
||||
# 4. GitHub Actions job timeout: 6 hours (default)
|
||||
env:
|
||||
CI: true # Mark as CI environment to skip memory-intensive tests
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
HF_HUB_DISABLE_SYMLINKS: 1
|
||||
TOKENIZERS_PARALLELISM: false
|
||||
PYTORCH_ENABLE_MPS_FALLBACK: 0 # Disable MPS on macOS CI to avoid memory issues
|
||||
OMP_NUM_THREADS: 1 # Disable OpenMP parallelism to avoid libomp crashes
|
||||
MKL_NUM_THREADS: 1 # Single thread for MKL operations
|
||||
run: |
|
||||
# Activate virtual environment
|
||||
source .venv/bin/activate || source .venv/Scripts/activate
|
||||
|
||||
# Define comprehensive diagnostic function
|
||||
diag() {
|
||||
echo "===== COMPREHENSIVE DIAGNOSTICS BEGIN ====="
|
||||
date
|
||||
echo ""
|
||||
echo "### Current Shell Info ###"
|
||||
echo "Shell PID: $$"
|
||||
echo "Shell PPID: $PPID"
|
||||
echo "Current directory: $(pwd)"
|
||||
echo ""
|
||||
|
||||
echo "### Process Tree (full) ###"
|
||||
pstree -ap 2>/dev/null || ps auxf || true
|
||||
echo ""
|
||||
|
||||
echo "### All Python/Pytest Processes ###"
|
||||
ps -ef | grep -E 'python|pytest' | grep -v grep || true
|
||||
echo ""
|
||||
|
||||
echo "### Embedding Server Processes ###"
|
||||
ps -ef | grep -E 'embedding|zmq|diskann' | grep -v grep || true
|
||||
echo ""
|
||||
|
||||
echo "### Network Listeners ###"
|
||||
ss -ltnp 2>/dev/null || netstat -ltn 2>/dev/null || true
|
||||
echo ""
|
||||
|
||||
echo "### Open File Descriptors (lsof) ###"
|
||||
lsof -p $$ 2>/dev/null | head -20 || true
|
||||
echo ""
|
||||
|
||||
echo "### Zombie Processes ###"
|
||||
ps aux | grep '<defunct>' || echo "No zombie processes"
|
||||
echo ""
|
||||
|
||||
echo "### Current Jobs ###"
|
||||
jobs -l || true
|
||||
echo ""
|
||||
|
||||
echo "### /proc/PID/fd for current shell ###"
|
||||
ls -la /proc/$$/fd 2>/dev/null || true
|
||||
echo ""
|
||||
|
||||
echo "===== COMPREHENSIVE DIAGNOSTICS END ====="
|
||||
}
|
||||
|
||||
# Enable verbose logging for debugging
|
||||
export PYTHONUNBUFFERED=1
|
||||
export PYTEST_CURRENT_TEST=1
|
||||
|
||||
# Run all tests with extensive logging
|
||||
if [[ "$RUNNER_OS" == "Linux" ]]; then
|
||||
echo "🚀 Starting Linux test execution with timeout..."
|
||||
echo "Current time: $(date)"
|
||||
echo "Shell PID: $$"
|
||||
echo "Python: $(python --version)"
|
||||
echo "Pytest: $(pytest --version)"
|
||||
|
||||
# Show environment variables for debugging
|
||||
echo "📦 Environment variables:"
|
||||
env | grep -E "PYTHON|PYTEST|CI|RUNNER" | sort
|
||||
|
||||
# Set trap for diagnostics
|
||||
trap diag INT TERM EXIT
|
||||
|
||||
echo "📋 Pre-test diagnostics:"
|
||||
ps -ef | grep -E 'python|pytest' | grep -v grep || echo "No python/pytest processes before test"
|
||||
|
||||
# Check for any listening ports before test
|
||||
echo "🔌 Pre-test network state:"
|
||||
ss -ltn 2>/dev/null | grep -E "555[0-9]|556[0-9]" || echo "No embedding server ports open"
|
||||
|
||||
# Set timeouts - outer must be larger than pytest's internal timeout
|
||||
# IMPORTANT: Keep PYTEST_TIMEOUT_SEC in sync with pyproject.toml [tool.pytest.ini_options] timeout
|
||||
PYTEST_TIMEOUT_SEC=${PYTEST_TIMEOUT_SEC:-300} # Default 300s, matches pyproject.toml
|
||||
BUFFER_SEC=${TIMEOUT_BUFFER_SEC:-60} # Buffer for cleanup after pytest timeout
|
||||
OUTER_TIMEOUT_SEC=${OUTER_TIMEOUT_SEC:-$((PYTEST_TIMEOUT_SEC + BUFFER_SEC))}
|
||||
|
||||
echo "⏰ Timeout configuration:"
|
||||
echo " - Pytest internal timeout: ${PYTEST_TIMEOUT_SEC}s (from pyproject.toml)"
|
||||
echo " - Cleanup buffer: ${BUFFER_SEC}s"
|
||||
echo " - Outer shell timeout: ${OUTER_TIMEOUT_SEC}s (${PYTEST_TIMEOUT_SEC}s + ${BUFFER_SEC}s buffer)"
|
||||
echo " - This ensures pytest can complete its own timeout handling and cleanup"
|
||||
|
||||
echo "🏃 Running pytest with ${OUTER_TIMEOUT_SEC}s outer timeout..."
|
||||
|
||||
# Export for inner shell
|
||||
export PYTEST_TIMEOUT_SEC OUTER_TIMEOUT_SEC BUFFER_SEC
|
||||
|
||||
timeout --preserve-status --signal=INT --kill-after=10 ${OUTER_TIMEOUT_SEC} bash -c '
|
||||
echo "⏱️ Pytest starting at: $(date)"
|
||||
echo "Running command: pytest tests/ -vv --maxfail=3 --tb=short --capture=no"
|
||||
|
||||
# Run pytest with maximum verbosity and no output capture
|
||||
pytest tests/ -vv --maxfail=3 --tb=short --capture=no --log-cli-level=DEBUG 2>&1 | tee pytest.log
|
||||
PYTEST_EXIT=${PIPESTATUS[0]}
|
||||
|
||||
echo "✅ Pytest finished at: $(date) with exit code: $PYTEST_EXIT"
|
||||
echo "Last 20 lines of pytest output:"
|
||||
tail -20 pytest.log || true
|
||||
|
||||
# Immediately check for leftover processes
|
||||
echo "🔍 Post-pytest process check:"
|
||||
ps -ef | grep -E "python|pytest|embedding" | grep -v grep || echo "No leftover processes"
|
||||
|
||||
# Clean up any children before exit
|
||||
echo "🧹 Cleaning up child processes..."
|
||||
pkill -TERM -P $$ 2>/dev/null || true
|
||||
sleep 0.5
|
||||
pkill -KILL -P $$ 2>/dev/null || true
|
||||
|
||||
echo "📊 Final check before exit:"
|
||||
ps -ef | grep -E "python|pytest|embedding" | grep -v grep || echo "All clean"
|
||||
|
||||
exit $PYTEST_EXIT
|
||||
'
|
||||
|
||||
EXIT_CODE=$?
|
||||
echo "🔚 Timeout command exited with code: $EXIT_CODE"
|
||||
|
||||
if [ $EXIT_CODE -eq 124 ]; then
|
||||
echo "⚠️ TIMEOUT TRIGGERED - Tests took more than ${OUTER_TIMEOUT_SEC} seconds!"
|
||||
echo "📸 Capturing full diagnostics..."
|
||||
diag
|
||||
|
||||
# Run diagnostic script if available
|
||||
if [ -f scripts/diagnose_hang.sh ]; then
|
||||
echo "🔍 Running diagnostic script..."
|
||||
bash scripts/diagnose_hang.sh || true
|
||||
fi
|
||||
|
||||
# More aggressive cleanup
|
||||
echo "💀 Killing all Python processes owned by runner..."
|
||||
pkill -9 -u runner python || true
|
||||
pkill -9 -u runner pytest || true
|
||||
elif [ $EXIT_CODE -ne 0 ]; then
|
||||
echo "❌ Tests failed with exit code: $EXIT_CODE"
|
||||
else
|
||||
echo "✅ All tests passed!"
|
||||
fi
|
||||
|
||||
# Always show final state
|
||||
echo "📍 Final state check:"
|
||||
ps -ef | grep -E 'python|pytest|embedding' | grep -v grep || echo "No Python processes remaining"
|
||||
|
||||
exit $EXIT_CODE
|
||||
else
|
||||
# For macOS/Windows, run without GNU timeout
|
||||
echo "🚀 Running tests on $RUNNER_OS..."
|
||||
pytest tests/ -vv --maxfail=3 --tb=short --capture=no --log-cli-level=INFO
|
||||
fi
|
||||
|
||||
# Provide tmate session on test failure for debugging
|
||||
- name: Setup tmate session on failure
|
||||
if: ${{ failure() && (inputs.debug_enabled || contains(github.event.head_commit.message, '[debug]')) }}
|
||||
uses: mxschmitt/action-tmate@v3
|
||||
with:
|
||||
timeout-minutes: 30
|
||||
limit-access-to-actor: true
|
||||
|
||||
- name: Run sanity checks (optional)
|
||||
run: |
|
||||
# Activate virtual environment
|
||||
source .venv/bin/activate || source .venv/Scripts/activate
|
||||
|
||||
# Run distance function tests if available
|
||||
if [ -f test/sanity_checks/test_distance_functions.py ]; then
|
||||
echo "Running distance function sanity checks..."
|
||||
python test/sanity_checks/test_distance_functions.py || echo "⚠️ Distance function test failed, continuing..."
|
||||
fi
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: packages-${{ matrix.os }}-py${{ matrix.python }}
|
||||
path: packages/*/dist/
|
||||
path: packages/*/dist/
|
||||
|
||||
19
.github/workflows/link-check.yml
vendored
Normal file
19
.github/workflows/link-check.yml
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
name: Link Check
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master ]
|
||||
pull_request:
|
||||
schedule:
|
||||
- cron: "0 3 * * 1"
|
||||
|
||||
jobs:
|
||||
link-check:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: lycheeverse/lychee-action@v2
|
||||
with:
|
||||
args: --no-progress --insecure README.md docs/ apps/ examples/ benchmarks/
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
34
.github/workflows/release-manual.yml
vendored
34
.github/workflows/release-manual.yml
vendored
@@ -16,10 +16,10 @@ jobs:
|
||||
contents: write
|
||||
outputs:
|
||||
commit-sha: ${{ steps.push.outputs.commit-sha }}
|
||||
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
|
||||
- name: Validate version
|
||||
run: |
|
||||
# Remove 'v' prefix if present for validation
|
||||
@@ -30,7 +30,7 @@ jobs:
|
||||
exit 1
|
||||
fi
|
||||
echo "✅ Version format valid: ${{ inputs.version }}"
|
||||
|
||||
|
||||
- name: Update versions and push
|
||||
id: push
|
||||
run: |
|
||||
@@ -38,7 +38,7 @@ jobs:
|
||||
CURRENT_VERSION=$(grep "^version" packages/leann-core/pyproject.toml | cut -d'"' -f2)
|
||||
echo "Current version: $CURRENT_VERSION"
|
||||
echo "Target version: ${{ inputs.version }}"
|
||||
|
||||
|
||||
if [ "$CURRENT_VERSION" = "${{ inputs.version }}" ]; then
|
||||
echo "⚠️ Version is already ${{ inputs.version }}, skipping update"
|
||||
COMMIT_SHA=$(git rev-parse HEAD)
|
||||
@@ -52,7 +52,7 @@ jobs:
|
||||
COMMIT_SHA=$(git rev-parse HEAD)
|
||||
echo "✅ Pushed version update: $COMMIT_SHA"
|
||||
fi
|
||||
|
||||
|
||||
echo "commit-sha=$COMMIT_SHA" >> $GITHUB_OUTPUT
|
||||
|
||||
build-packages:
|
||||
@@ -60,7 +60,7 @@ jobs:
|
||||
needs: update-version
|
||||
uses: ./.github/workflows/build-reusable.yml
|
||||
with:
|
||||
ref: 'main'
|
||||
ref: 'main'
|
||||
|
||||
publish:
|
||||
name: Publish and Release
|
||||
@@ -69,26 +69,26 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: 'main'
|
||||
|
||||
ref: 'main'
|
||||
|
||||
- name: Download all artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
path: dist-artifacts
|
||||
|
||||
|
||||
- name: Collect packages
|
||||
run: |
|
||||
mkdir -p dist
|
||||
find dist-artifacts -name "*.whl" -exec cp {} dist/ \;
|
||||
find dist-artifacts -name "*.tar.gz" -exec cp {} dist/ \;
|
||||
|
||||
|
||||
echo "📦 Packages to publish:"
|
||||
ls -la dist/
|
||||
|
||||
|
||||
- name: Publish to PyPI
|
||||
env:
|
||||
TWINE_USERNAME: __token__
|
||||
@@ -98,12 +98,12 @@ jobs:
|
||||
echo "❌ PYPI_API_TOKEN not configured!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
pip install twine
|
||||
twine upload dist/* --skip-existing --verbose
|
||||
|
||||
|
||||
echo "✅ Published to PyPI!"
|
||||
|
||||
|
||||
- name: Create release
|
||||
run: |
|
||||
# Check if tag already exists
|
||||
@@ -114,7 +114,7 @@ jobs:
|
||||
git push origin "v${{ inputs.version }}"
|
||||
echo "✅ Created and pushed tag v${{ inputs.version }}"
|
||||
fi
|
||||
|
||||
|
||||
# Check if release already exists
|
||||
if gh release view "v${{ inputs.version }}" >/dev/null 2>&1; then
|
||||
echo "⚠️ Release v${{ inputs.version }} already exists, skipping release creation"
|
||||
@@ -126,4 +126,4 @@ jobs:
|
||||
echo "✅ Created GitHub release v${{ inputs.version }}"
|
||||
fi
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
20
.gitignore
vendored
20
.gitignore
vendored
@@ -9,7 +9,7 @@ demo/indices/
|
||||
outputs/
|
||||
*.pkl
|
||||
*.pdf
|
||||
*.idx
|
||||
*.idx
|
||||
*.map
|
||||
.history/
|
||||
lm_eval.egg-info/
|
||||
@@ -34,11 +34,15 @@ build/
|
||||
nprobe_logs/
|
||||
micro/results
|
||||
micro/contriever-INT8
|
||||
examples/data/*
|
||||
!examples/data/2501.14312v1 (1).pdf
|
||||
!examples/data/2506.08276v1.pdf
|
||||
!examples/data/PrideandPrejudice.txt
|
||||
!examples/data/README.md
|
||||
data/*
|
||||
!data/2501.14312v1 (1).pdf
|
||||
!data/2506.08276v1.pdf
|
||||
!data/PrideandPrejudice.txt
|
||||
!data/huawei_pangu.md
|
||||
!data/ground_truth/
|
||||
!data/indices/
|
||||
!data/queries/
|
||||
!data/.gitattributes
|
||||
*.qdstrm
|
||||
benchmark_results/
|
||||
results/
|
||||
@@ -85,4 +89,6 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/
|
||||
*.meta.json
|
||||
*.passages.json
|
||||
|
||||
batchtest.py
|
||||
batchtest.py
|
||||
tests/__pytest_cache__/
|
||||
tests/__pycache__/
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.5.0
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
- id: end-of-file-fixer
|
||||
@@ -9,15 +9,8 @@ repos:
|
||||
- id: check-merge-conflict
|
||||
- id: debug-statements
|
||||
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 24.1.1
|
||||
hooks:
|
||||
- id: black
|
||||
language_version: python3
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.2.1
|
||||
rev: v0.12.7 # Fixed version to match pyproject.toml
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
- id: ruff-format
|
||||
|
||||
377
README.md
377
README.md
@@ -6,6 +6,7 @@
|
||||
<img src="https://img.shields.io/badge/Python-3.9%2B-blue.svg" alt="Python 3.9+">
|
||||
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
|
||||
<img src="https://img.shields.io/badge/Platform-Linux%20%7C%20macOS-lightgrey" alt="Platform">
|
||||
<img src="https://img.shields.io/badge/MCP-Native%20Integration-blue?style=flat-square" alt="MCP Integration">
|
||||
</p>
|
||||
|
||||
<h2 align="center" tabindex="-1" class="heading-element" dir="auto">
|
||||
@@ -16,7 +17,10 @@ LEANN is an innovative vector database that democratizes personal AI. Transform
|
||||
|
||||
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
|
||||
|
||||
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)**, or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
||||
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)**, **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
||||
|
||||
|
||||
\* Claude Code only supports basic `grep`-style keyword search. **LEANN** is a drop-in **semantic search MCP service fully compatible with Claude Code**, unlocking intelligent retrieval without changing your workflow. 🔥 Check out [the easy setup →](packages/leann-mcp/README.md)
|
||||
|
||||
|
||||
|
||||
@@ -26,21 +30,55 @@ LEANN achieves this through *graph-based selective recomputation* with *high-deg
|
||||
<img src="assets/effects.png" alt="LEANN vs Traditional Vector DB Storage Comparison" width="70%">
|
||||
</p>
|
||||
|
||||
> **The numbers speak for themselves:** Index 60 million Wikipedia chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#storage-comparison)
|
||||
> **The numbers speak for themselves:** Index 60 million text chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#storage-comparison)
|
||||
|
||||
|
||||
🔒 **Privacy:** Your data never leaves your laptop. No OpenAI, no cloud, no "terms of service".
|
||||
|
||||
🪶 **Lightweight:** Graph-based recomputation eliminates heavy embedding storage, while smart graph pruning and CSR format minimize graph storage overhead. Always less storage, less memory usage!
|
||||
|
||||
📦 **Portable:** Transfer your entire knowledge base between devices (even with others) with minimal cost - your personal AI memory travels with you.
|
||||
|
||||
📈 **Scalability:** Handle messy personal data that would crash traditional vector DBs, easily managing your growing personalized data and agent generated memory!
|
||||
|
||||
✨ **No Accuracy Loss:** Maintain the same search quality as heavyweight solutions while using 97% less storage.
|
||||
|
||||
## Installation
|
||||
> `pip leann` coming soon!
|
||||
|
||||
### 📦 Prerequisites: Install uv
|
||||
|
||||
[Install uv](https://docs.astral.sh/uv/getting-started/installation/#installation-methods) first if you don't have it. Typically, you can install it with:
|
||||
|
||||
```bash
|
||||
git clone git@github.com:yichuan-w/LEANN.git leann
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
```
|
||||
|
||||
### 🚀 Quick Install
|
||||
|
||||
Clone the repository to access all examples and try amazing applications,
|
||||
|
||||
```bash
|
||||
git clone https://github.com/yichuan-w/LEANN.git leann
|
||||
cd leann
|
||||
```
|
||||
|
||||
and install LEANN from [PyPI](https://pypi.org/project/leann/) to run them immediately:
|
||||
|
||||
```bash
|
||||
uv venv
|
||||
source .venv/bin/activate
|
||||
uv pip install leann
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary>
|
||||
<strong>🔧 Build from Source (Recommended for development)</strong>
|
||||
</summary>
|
||||
|
||||
|
||||
|
||||
```bash
|
||||
git clone https://github.com/yichuan-w/LEANN.git leann
|
||||
cd leann
|
||||
git submodule update --init --recursive
|
||||
```
|
||||
@@ -48,27 +86,65 @@ git submodule update --init --recursive
|
||||
**macOS:**
|
||||
```bash
|
||||
brew install llvm libomp boost protobuf zeromq pkgconf
|
||||
|
||||
# Install with HNSW backend (default, recommended for most users)
|
||||
# Install uv first if you don't have it:
|
||||
# curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
# See: https://docs.astral.sh/uv/getting-started/installation/#installation-methods
|
||||
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync
|
||||
```
|
||||
|
||||
**Linux:**
|
||||
```bash
|
||||
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
|
||||
|
||||
# Install with HNSW backend (default, recommended for most users)
|
||||
uv sync
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
**Ollama Setup (Recommended for full privacy):**
|
||||
|
||||
> *You can skip this installation if you only want to use OpenAI API for generation.*
|
||||
## Quick Start
|
||||
|
||||
Our declarative API makes RAG as easy as writing a config file.
|
||||
|
||||
Check out [demo.ipynb](demo.ipynb) or [](https://colab.research.google.com/github/yichuan-w/LEANN/blob/main/demo.ipynb)
|
||||
|
||||
```python
|
||||
from leann import LeannBuilder, LeannSearcher, LeannChat
|
||||
from pathlib import Path
|
||||
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
|
||||
|
||||
# Build an index
|
||||
builder = LeannBuilder(backend_name="hnsw")
|
||||
builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
|
||||
builder.add_text("Tung Tung Tung Sahur called—they need their banana‑crocodile hybrid back")
|
||||
builder.build_index(INDEX_PATH)
|
||||
|
||||
# Search
|
||||
searcher = LeannSearcher(INDEX_PATH)
|
||||
results = searcher.search("fantastical AI-generated creatures", top_k=1)
|
||||
|
||||
# Chat with your data
|
||||
chat = LeannChat(INDEX_PATH, llm_config={"type": "hf", "model": "Qwen/Qwen3-0.6B"})
|
||||
response = chat.ask("How much storage does LEANN save?", top_k=1)
|
||||
```
|
||||
|
||||
## RAG on Everything!
|
||||
|
||||
LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`, `.md`), Apple Mail, Google Search History, WeChat, and more.
|
||||
|
||||
### Generation Model Setup
|
||||
|
||||
LEANN supports multiple LLM providers for text generation (OpenAI API, HuggingFace, Ollama).
|
||||
|
||||
<details>
|
||||
<summary><strong>🔑 OpenAI API Setup (Default)</strong></summary>
|
||||
|
||||
Set your OpenAI API key as an environment variable:
|
||||
|
||||
```bash
|
||||
export OPENAI_API_KEY="your-api-key-here"
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>🔧 Ollama Setup (Recommended for full privacy)</strong></summary>
|
||||
|
||||
**macOS:**
|
||||
|
||||
@@ -80,6 +156,7 @@ ollama pull llama3.2:1b
|
||||
```
|
||||
|
||||
**Linux:**
|
||||
|
||||
```bash
|
||||
# Install Ollama
|
||||
curl -fsSL https://ollama.ai/install.sh | sh
|
||||
@@ -91,45 +168,54 @@ ollama serve &
|
||||
ollama pull llama3.2:1b
|
||||
```
|
||||
|
||||
## Quick Start in 30s
|
||||
</details>
|
||||
|
||||
Our declarative API makes RAG as easy as writing a config file.
|
||||
[Try in this ipynb file →](demo.ipynb) [](https://colab.research.google.com/github/yichuan-w/LEANN/blob/main/demo.ipynb)
|
||||
### ⭐ Flexible Configuration
|
||||
|
||||
```python
|
||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||
LEANN provides flexible parameters for embedding models, search strategies, and data processing to fit your specific needs.
|
||||
|
||||
# 1. Build the index (no embeddings stored!)
|
||||
builder = LeannBuilder(backend_name="hnsw")
|
||||
builder.add_text("C# is a powerful programming language")
|
||||
builder.add_text("Python is a powerful programming language and it is very popular")
|
||||
builder.add_text("Machine learning transforms industries")
|
||||
builder.add_text("Neural networks process complex data")
|
||||
builder.add_text("Leann is a great storage saving engine for RAG on your MacBook")
|
||||
builder.build_index("knowledge.leann")
|
||||
📚 **Need configuration best practices?** Check our [Configuration Guide](docs/configuration-guide.md) for detailed optimization tips, model selection advice, and solutions to common issues like slow embeddings or poor search quality.
|
||||
|
||||
# 2. Search with real-time embeddings
|
||||
searcher = LeannSearcher("knowledge.leann")
|
||||
results = searcher.search("programming languages", top_k=2)
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Common Parameters (Available in All Examples)</strong></summary>
|
||||
|
||||
# 3. Chat with LEANN using retrieved results
|
||||
llm_config = {
|
||||
"type": "ollama",
|
||||
"model": "llama3.2:1b"
|
||||
}
|
||||
All RAG examples share these common parameters. **Interactive mode** is available in all examples - simply run without `--query` to start a continuous Q&A session where you can ask multiple questions. Type 'quit' to exit.
|
||||
|
||||
chat = LeannChat(index_path="knowledge.leann", llm_config=llm_config)
|
||||
response = chat.ask(
|
||||
"Compare the two retrieved programming languages and say which one is more popular today.",
|
||||
top_k=2,
|
||||
)
|
||||
```bash
|
||||
# Core Parameters (General preprocessing for all examples)
|
||||
--index-dir DIR # Directory to store the index (default: current directory)
|
||||
--query "YOUR QUESTION" # Single query mode. Omit for interactive chat (type 'quit' to exit), and now you can play with your index interactively
|
||||
--max-items N # Limit data preprocessing (default: -1, process all data)
|
||||
--force-rebuild # Force rebuild index even if it exists
|
||||
|
||||
# Embedding Parameters
|
||||
--embedding-model MODEL # e.g., facebook/contriever, text-embedding-3-small, nomic-embed-text, mlx-community/Qwen3-Embedding-0.6B-8bit or nomic-embed-text
|
||||
--embedding-mode MODE # sentence-transformers, openai, mlx, or ollama
|
||||
|
||||
# LLM Parameters (Text generation models)
|
||||
--llm TYPE # LLM backend: openai, ollama, or hf (default: openai)
|
||||
--llm-model MODEL # Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct
|
||||
--thinking-budget LEVEL # Thinking budget for reasoning models: low/medium/high (supported by o3, o3-mini, GPT-Oss:20b, and other reasoning models)
|
||||
|
||||
# Search Parameters
|
||||
--top-k N # Number of results to retrieve (default: 20)
|
||||
--search-complexity N # Search complexity for graph traversal (default: 32)
|
||||
|
||||
# Chunking Parameters
|
||||
--chunk-size N # Size of text chunks (default varies by source: 256 for most, 192 for WeChat)
|
||||
--chunk-overlap N # Overlap between chunks (default varies: 25-128 depending on source)
|
||||
|
||||
# Index Building Parameters
|
||||
--backend-name NAME # Backend to use: hnsw or diskann (default: hnsw)
|
||||
--graph-degree N # Graph degree for index construction (default: 32)
|
||||
--build-complexity N # Build complexity for index construction (default: 64)
|
||||
--no-compact # Disable compact index storage (compact storage IS enabled to save storage by default)
|
||||
--no-recompute # Disable embedding recomputation (recomputation IS enabled to save storage by default)
|
||||
```
|
||||
|
||||
## RAG on Everything!
|
||||
</details>
|
||||
|
||||
LEANN supports RAG on various data sources including documents (.pdf, .txt, .md), Apple Mail, Google Search History, WeChat, and more.
|
||||
|
||||
### 📄 Personal Data Manager: Process Any Documents (.pdf, .txt, .md)!
|
||||
### 📄 Personal Data Manager: Process Any Documents (`.pdf`, `.txt`, `.md`)!
|
||||
|
||||
Ask questions directly about your personal PDFs, documents, and any directory containing your files!
|
||||
|
||||
@@ -137,51 +223,65 @@ Ask questions directly about your personal PDFs, documents, and any directory co
|
||||
<img src="videos/paper_clear.gif" alt="LEANN Document Search Demo" width="600">
|
||||
</p>
|
||||
|
||||
The example below asks a question about summarizing two papers (uses default data in `examples/data`):
|
||||
The example below asks a question about summarizing our paper (uses default data in `data/`, which is a directory with diverse data sources: two papers, Pride and Prejudice, and a Technical report about LLM in Huawei in Chinese), and this is the **easiest example** to run here:
|
||||
|
||||
```bash
|
||||
# Drop your PDFs, .txt, .md files into examples/data/
|
||||
uv run ./examples/main_cli_example.py
|
||||
source .venv/bin/activate # Don't forget to activate the virtual environment
|
||||
python -m apps.document_rag --query "What are the main techniques LEANN explores?"
|
||||
```
|
||||
|
||||
```
|
||||
# Or use python directly
|
||||
source .venv/bin/activate
|
||||
python ./examples/main_cli_example.py
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Document-Specific Arguments</strong></summary>
|
||||
|
||||
#### Parameters
|
||||
```bash
|
||||
--data-dir DIR # Directory containing documents to process (default: data)
|
||||
--file-types .ext .ext # Filter by specific file types (optional - all LlamaIndex supported types if omitted)
|
||||
```
|
||||
|
||||
#### Example Commands
|
||||
```bash
|
||||
# Process all documents with larger chunks for academic papers
|
||||
python -m apps.document_rag --data-dir "~/Documents/Papers" --chunk-size 1024
|
||||
|
||||
# Filter only markdown and Python files with smaller chunks
|
||||
python -m apps.document_rag --data-dir "./docs" --chunk-size 256 --file-types .md .py
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
### 📧 Your Personal Email Secretary: RAG on Apple Mail!
|
||||
|
||||
> **Note:** The examples below currently support macOS only. Windows support coming soon.
|
||||
|
||||
|
||||
<p align="center">
|
||||
<img src="videos/mail_clear.gif" alt="LEANN Email Search Demo" width="600">
|
||||
</p>
|
||||
|
||||
**Note:** You need to grant full disk access to your terminal/VS Code in System Preferences → Privacy & Security → Full Disk Access.
|
||||
Before running the example below, you need to grant full disk access to your terminal/VS Code in System Preferences → Privacy & Security → Full Disk Access.
|
||||
|
||||
```bash
|
||||
python examples/mail_reader_leann.py --query "What's the food I ordered by doordash or Uber eat mostly?"
|
||||
python -m apps.email_rag --query "What's the food I ordered by DoorDash or Uber Eats mostly?"
|
||||
```
|
||||
**780K email chunks → 78MB storage** Finally, search your email like you search Google.
|
||||
**780K email chunks → 78MB storage.** Finally, search your email like you search Google.
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
|
||||
<summary><strong>📋 Click to expand: Email-Specific Arguments</strong></summary>
|
||||
|
||||
#### Parameters
|
||||
```bash
|
||||
# Use default mail path (works for most macOS setups)
|
||||
python examples/mail_reader_leann.py
|
||||
--mail-path PATH # Path to specific mail directory (auto-detects if omitted)
|
||||
--include-html # Include HTML content in processing (useful for newsletters)
|
||||
```
|
||||
|
||||
# Run with custom index directory
|
||||
python examples/mail_reader_leann.py --index-dir "./my_mail_index"
|
||||
#### Example Commands
|
||||
```bash
|
||||
# Search work emails from a specific account
|
||||
python -m apps.email_rag --mail-path "~/Library/Mail/V10/WORK_ACCOUNT"
|
||||
|
||||
# Process all emails (may take time but indexes everything)
|
||||
python examples/mail_reader_leann.py --max-emails -1
|
||||
|
||||
# Limit number of emails processed (useful for testing)
|
||||
python examples/mail_reader_leann.py --max-emails 1000
|
||||
|
||||
# Run a single query
|
||||
python examples/mail_reader_leann.py --query "What did my boss say about deadlines?"
|
||||
# Find all receipts and order confirmations (includes HTML)
|
||||
python -m apps.email_rag --query "receipt order confirmation invoice" --include-html
|
||||
```
|
||||
|
||||
</details>
|
||||
@@ -202,25 +302,25 @@ Once the index is built, you can ask questions like:
|
||||
</p>
|
||||
|
||||
```bash
|
||||
python examples/google_history_reader_leann.py --query "Tell me my browser history about machine learning?"
|
||||
python -m apps.browser_rag --query "Tell me my browser history about machine learning?"
|
||||
```
|
||||
**38K browser entries → 6MB storage.** Your browser history becomes your personal search engine.
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
|
||||
<summary><strong>📋 Click to expand: Browser-Specific Arguments</strong></summary>
|
||||
|
||||
#### Parameters
|
||||
```bash
|
||||
# Use default Chrome profile (auto-finds all profiles)
|
||||
python examples/google_history_reader_leann.py
|
||||
--chrome-profile PATH # Path to Chrome profile directory (auto-detects if omitted)
|
||||
```
|
||||
|
||||
# Run with custom index directory
|
||||
python examples/google_history_reader_leann.py --index-dir "./my_chrome_index"
|
||||
#### Example Commands
|
||||
```bash
|
||||
# Search academic research from your browsing history
|
||||
python -m apps.browser_rag --query "arxiv papers machine learning transformer architecture"
|
||||
|
||||
# Limit number of history entries processed (useful for testing)
|
||||
python examples/google_history_reader_leann.py --max-entries 500
|
||||
|
||||
# Run a single query
|
||||
python examples/google_history_reader_leann.py --query "What websites did I visit about machine learning?"
|
||||
# Track competitor analysis across work profile
|
||||
python -m apps.browser_rag --chrome-profile "~/Library/Application Support/Google/Chrome/Work Profile" --max-items 5000
|
||||
```
|
||||
|
||||
</details>
|
||||
@@ -260,7 +360,7 @@ Once the index is built, you can ask questions like:
|
||||
</p>
|
||||
|
||||
```bash
|
||||
python examples/wechat_history_reader_leann.py --query "Show me all group chats about weekend plans"
|
||||
python -m apps.wechat_rag --query "Show me all group chats about weekend plans"
|
||||
```
|
||||
**400K messages → 64MB storage** Search years of chat history in any language.
|
||||
|
||||
@@ -268,7 +368,13 @@ python examples/wechat_history_reader_leann.py --query "Show me all group chats
|
||||
<details>
|
||||
<summary><strong>🔧 Click to expand: Installation Requirements</strong></summary>
|
||||
|
||||
First, you need to install the WeChat exporter:
|
||||
First, you need to install the [WeChat exporter](https://github.com/sunnyyoung/WeChatTweak-CLI),
|
||||
|
||||
```bash
|
||||
brew install sunnyyoung/repo/wechattweak-cli
|
||||
```
|
||||
|
||||
or install it manually (if you have issues with Homebrew):
|
||||
|
||||
```bash
|
||||
sudo packages/wechat-exporter/wechattweak-cli install
|
||||
@@ -277,30 +383,28 @@ sudo packages/wechat-exporter/wechattweak-cli install
|
||||
**Troubleshooting:**
|
||||
- **Installation issues**: Check the [WeChatTweak-CLI issues page](https://github.com/sunnyyoung/WeChatTweak-CLI/issues/41)
|
||||
- **Export errors**: If you encounter the error below, try restarting WeChat
|
||||
```
|
||||
Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.
|
||||
Failed to find or export WeChat data. Exiting.
|
||||
```
|
||||
```bash
|
||||
Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.
|
||||
Failed to find or export WeChat data. Exiting.
|
||||
```
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
|
||||
<summary><strong>📋 Click to expand: WeChat-Specific Arguments</strong></summary>
|
||||
|
||||
#### Parameters
|
||||
```bash
|
||||
# Use default settings (recommended for first run)
|
||||
python examples/wechat_history_reader_leann.py
|
||||
--export-dir DIR # Directory to store exported WeChat data (default: wechat_export_direct)
|
||||
--force-export # Force re-export even if data exists
|
||||
```
|
||||
|
||||
# Run with custom export directory and wehn we run the first time, LEANN will export all chat history automatically for you
|
||||
python examples/wechat_history_reader_leann.py --export-dir "./my_wechat_exports"
|
||||
#### Example Commands
|
||||
```bash
|
||||
# Search for travel plans discussed in group chats
|
||||
python -m apps.wechat_rag --query "travel plans" --max-items 10000
|
||||
|
||||
# Run with custom index directory
|
||||
python examples/wechat_history_reader_leann.py --index-dir "./my_wechat_index"
|
||||
|
||||
# Limit number of chat entries processed (useful for testing)
|
||||
python examples/wechat_history_reader_leann.py --max-entries 1000
|
||||
|
||||
# Run a single query
|
||||
python examples/wechat_history_reader_leann.py --query "Show me conversations about travel plans"
|
||||
# Re-export and search recent chats (useful after new messages)
|
||||
python -m apps.wechat_rag --force-export --query "work schedule"
|
||||
```
|
||||
|
||||
</details>
|
||||
@@ -314,17 +418,59 @@ Once the index is built, you can ask questions like:
|
||||
|
||||
</details>
|
||||
|
||||
### 🚀 Claude Code Integration: Transform Your Development Workflow!
|
||||
|
||||
**The future of code assistance is here.** Transform your development workflow with LEANN's native MCP integration for Claude Code. Index your entire codebase and get intelligent code assistance directly in your IDE.
|
||||
|
||||
**Key features:**
|
||||
- 🔍 **Semantic code search** across your entire project
|
||||
- 📚 **Context-aware assistance** for debugging and development
|
||||
- 🚀 **Zero-config setup** with automatic language detection
|
||||
|
||||
```bash
|
||||
# Install LEANN globally for MCP integration
|
||||
uv tool install leann-core
|
||||
|
||||
# Setup is automatic - just start using Claude Code!
|
||||
```
|
||||
Try our fully agentic pipeline with auto query rewriting, semantic search planning, and more:
|
||||
|
||||

|
||||
|
||||
**Ready to supercharge your coding?** [Complete Setup Guide →](packages/leann-mcp/README.md)
|
||||
|
||||
## 🖥️ Command Line Interface
|
||||
|
||||
LEANN includes a powerful CLI for document processing and search. Perfect for quick document indexing and interactive chat.
|
||||
|
||||
```bash
|
||||
# Build an index from documents
|
||||
leann build my-docs --docs ./documents
|
||||
### Installation
|
||||
|
||||
# Search your documents
|
||||
If you followed the Quick Start, `leann` is already installed in your virtual environment:
|
||||
```bash
|
||||
source .venv/bin/activate
|
||||
leann --help
|
||||
```
|
||||
|
||||
**To make it globally available:**
|
||||
```bash
|
||||
# Install the LEANN CLI globally using uv tool
|
||||
uv tool install leann-core
|
||||
|
||||
# Now you can use leann from anywhere without activating venv
|
||||
leann --help
|
||||
```
|
||||
|
||||
> **Note**: Global installation is required for Claude Code integration. The `leann_mcp` server depends on the globally available `leann` command.
|
||||
|
||||
|
||||
|
||||
### Usage Examples
|
||||
|
||||
```bash
|
||||
# build from a specific directory, and my_docs is the index name
|
||||
leann build my-docs --docs ./your_documents
|
||||
|
||||
# Search your documents
|
||||
leann search my-docs "machine learning concepts"
|
||||
|
||||
# Interactive chat with your documents
|
||||
@@ -392,17 +538,21 @@ Options:
|
||||
|
||||
**Core techniques:**
|
||||
- **Graph-based selective recomputation:** Only compute embeddings for nodes in the search path
|
||||
- **High-degree preserving pruning:** Keep important "hub" nodes while removing redundant connections
|
||||
- **High-degree preserving pruning:** Keep important "hub" nodes while removing redundant connections
|
||||
- **Dynamic batching:** Efficiently batch embedding computations for GPU utilization
|
||||
- **Two-level search:** Smart graph traversal that prioritizes promising nodes
|
||||
|
||||
**Backends:** DiskANN or HNSW - pick what works for your data size.
|
||||
**Backends:**
|
||||
- **HNSW** (default): Ideal for most datasets with maximum storage savings through full recomputation
|
||||
- **DiskANN**: Advanced option with superior search performance, using PQ-based graph traversal with real-time reranking for the best speed-accuracy trade-off
|
||||
|
||||
## Benchmarks
|
||||
|
||||
**[DiskANN vs HNSW Performance Comparison →](benchmarks/diskann_vs_hnsw_speed_comparison.py)** - Compare search performance between both backends
|
||||
|
||||
📊 **[Simple Example: Compare LEANN vs FAISS →](examples/compare_faiss_vs_leann.py)**
|
||||
### Storage Comparison
|
||||
**[Simple Example: Compare LEANN vs FAISS →](benchmarks/compare_faiss_vs_leann.py)** - See storage savings in action
|
||||
|
||||
### 📊 Storage Comparison
|
||||
|
||||
| System | DPR (2.1M) | Wiki (60M) | Chat (400K) | Email (780K) | Browser (38K) |
|
||||
|--------|-------------|------------|-------------|--------------|---------------|
|
||||
@@ -416,8 +566,7 @@ Options:
|
||||
|
||||
```bash
|
||||
uv pip install -e ".[dev]" # Install dev dependencies
|
||||
python examples/run_evaluation.py data/indices/dpr/dpr_diskann # DPR dataset
|
||||
python examples/run_evaluation.py data/indices/rpj_wiki/rpj_wiki.index # Wikipedia
|
||||
python benchmarks/run_evaluation.py # Will auto-download evaluation data and run benchmarks
|
||||
```
|
||||
|
||||
The evaluation script downloads data automatically on first run. The last three results were tested with partial personal data, and you can reproduce them with your own data!
|
||||
@@ -429,22 +578,22 @@ If you find Leann useful, please cite:
|
||||
|
||||
```bibtex
|
||||
@misc{wang2025leannlowstoragevectorindex,
|
||||
title={LEANN: A Low-Storage Vector Index},
|
||||
title={LEANN: A Low-Storage Vector Index},
|
||||
author={Yichuan Wang and Shu Liu and Zhifei Li and Yongji Wu and Ziming Mao and Yilong Zhao and Xiao Yan and Zhiying Xu and Yang Zhou and Ion Stoica and Sewon Min and Matei Zaharia and Joseph E. Gonzalez},
|
||||
year={2025},
|
||||
eprint={2506.08276},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.DB},
|
||||
url={https://arxiv.org/abs/2506.08276},
|
||||
url={https://arxiv.org/abs/2506.08276},
|
||||
}
|
||||
```
|
||||
|
||||
## ✨ [Detailed Features →](docs/features.md)
|
||||
|
||||
## 🤝 [Contributing →](docs/contributing.md)
|
||||
## 🤝 [CONTRIBUTING →](docs/CONTRIBUTING.md)
|
||||
|
||||
|
||||
## [FAQ →](docs/faq.md)
|
||||
## ❓ [FAQ →](docs/faq.md)
|
||||
|
||||
|
||||
## 📈 [Roadmap →](docs/roadmap.md)
|
||||
@@ -455,7 +604,12 @@ MIT License - see [LICENSE](LICENSE) for details.
|
||||
|
||||
## 🙏 Acknowledgments
|
||||
|
||||
This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/)
|
||||
Core Contributors: [Yichuan Wang](https://yichuan-w.github.io/) & [Zhifei Li](https://github.com/andylizf).
|
||||
|
||||
We welcome more contributors! Feel free to open issues or submit PRs.
|
||||
|
||||
This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/).
|
||||
|
||||
---
|
||||
|
||||
<p align="center">
|
||||
@@ -465,4 +619,3 @@ This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.e
|
||||
<p align="center">
|
||||
Made with ❤️ by the Leann team
|
||||
</p>
|
||||
|
||||
|
||||
0
apps/__init__.py
Normal file
0
apps/__init__.py
Normal file
324
apps/base_rag_example.py
Normal file
324
apps/base_rag_example.py
Normal file
@@ -0,0 +1,324 @@
|
||||
"""
|
||||
Base class for unified RAG examples interface.
|
||||
Provides common parameters and functionality for all RAG examples.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import dotenv
|
||||
from leann.api import LeannBuilder, LeannChat
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
|
||||
class BaseRAGExample(ABC):
|
||||
"""Base class for all RAG examples with unified interface."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
default_index_name: str,
|
||||
):
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.default_index_name = default_index_name
|
||||
self.parser = self._create_parser()
|
||||
|
||||
def _create_parser(self) -> argparse.ArgumentParser:
|
||||
"""Create argument parser with common parameters."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description=self.description, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
|
||||
# Core parameters (all examples share these)
|
||||
core_group = parser.add_argument_group("Core Parameters")
|
||||
core_group.add_argument(
|
||||
"--index-dir",
|
||||
type=str,
|
||||
default=f"./{self.default_index_name}",
|
||||
help=f"Directory to store the index (default: ./{self.default_index_name})",
|
||||
)
|
||||
core_group.add_argument(
|
||||
"--query",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Query to run (if not provided, will run in interactive mode)",
|
||||
)
|
||||
# Allow subclasses to override default max_items
|
||||
max_items_default = getattr(self, "max_items_default", -1)
|
||||
core_group.add_argument(
|
||||
"--max-items",
|
||||
type=int,
|
||||
default=max_items_default,
|
||||
help="Maximum number of items to process -1 for all, means index all documents, and you should set it to a reasonable number if you have a large dataset and try at the first time)",
|
||||
)
|
||||
core_group.add_argument(
|
||||
"--force-rebuild", action="store_true", help="Force rebuild index even if it exists"
|
||||
)
|
||||
|
||||
# Embedding parameters
|
||||
embedding_group = parser.add_argument_group("Embedding Parameters")
|
||||
# Allow subclasses to override default embedding_model
|
||||
embedding_model_default = getattr(self, "embedding_model_default", "facebook/contriever")
|
||||
embedding_group.add_argument(
|
||||
"--embedding-model",
|
||||
type=str,
|
||||
default=embedding_model_default,
|
||||
help=f"Embedding model to use (default: {embedding_model_default})",
|
||||
)
|
||||
embedding_group.add_argument(
|
||||
"--embedding-mode",
|
||||
type=str,
|
||||
default="sentence-transformers",
|
||||
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||
help="Embedding backend mode (default: sentence-transformers)",
|
||||
)
|
||||
|
||||
# LLM parameters
|
||||
llm_group = parser.add_argument_group("LLM Parameters")
|
||||
llm_group.add_argument(
|
||||
"--llm",
|
||||
type=str,
|
||||
default="openai",
|
||||
choices=["openai", "ollama", "hf", "simulated"],
|
||||
help="LLM backend to use (default: openai)",
|
||||
)
|
||||
llm_group.add_argument(
|
||||
"--llm-model",
|
||||
type=str,
|
||||
default=None,
|
||||
help="LLM model name (default: gpt-4o for openai, llama3.2:1b for ollama)",
|
||||
)
|
||||
llm_group.add_argument(
|
||||
"--llm-host",
|
||||
type=str,
|
||||
default="http://localhost:11434",
|
||||
help="Host for Ollama API (default: http://localhost:11434)",
|
||||
)
|
||||
llm_group.add_argument(
|
||||
"--thinking-budget",
|
||||
type=str,
|
||||
choices=["low", "medium", "high"],
|
||||
default=None,
|
||||
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
|
||||
)
|
||||
|
||||
# Search parameters
|
||||
search_group = parser.add_argument_group("Search Parameters")
|
||||
search_group.add_argument(
|
||||
"--top-k", type=int, default=20, help="Number of results to retrieve (default: 20)"
|
||||
)
|
||||
search_group.add_argument(
|
||||
"--search-complexity",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Search complexity for graph traversal (default: 64)",
|
||||
)
|
||||
|
||||
# Index building parameters
|
||||
index_group = parser.add_argument_group("Index Building Parameters")
|
||||
index_group.add_argument(
|
||||
"--backend-name",
|
||||
type=str,
|
||||
default="hnsw",
|
||||
choices=["hnsw", "diskann"],
|
||||
help="Backend to use for index (default: hnsw)",
|
||||
)
|
||||
index_group.add_argument(
|
||||
"--graph-degree",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Graph degree for index construction (default: 32)",
|
||||
)
|
||||
index_group.add_argument(
|
||||
"--build-complexity",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Build complexity for index construction (default: 64)",
|
||||
)
|
||||
index_group.add_argument(
|
||||
"--no-compact",
|
||||
action="store_true",
|
||||
help="Disable compact index storage",
|
||||
)
|
||||
index_group.add_argument(
|
||||
"--no-recompute",
|
||||
action="store_true",
|
||||
help="Disable embedding recomputation",
|
||||
)
|
||||
|
||||
# Add source-specific parameters
|
||||
self._add_specific_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
@abstractmethod
|
||||
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
|
||||
"""Add source-specific arguments. Override in subclasses."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load data from the source. Returns list of text chunks."""
|
||||
pass
|
||||
|
||||
def get_llm_config(self, args) -> dict[str, Any]:
|
||||
"""Get LLM configuration based on arguments."""
|
||||
config = {"type": args.llm}
|
||||
|
||||
if args.llm == "openai":
|
||||
config["model"] = args.llm_model or "gpt-4o"
|
||||
elif args.llm == "ollama":
|
||||
config["model"] = args.llm_model or "llama3.2:1b"
|
||||
config["host"] = args.llm_host
|
||||
elif args.llm == "hf":
|
||||
config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
elif args.llm == "simulated":
|
||||
# Simulated LLM doesn't need additional configuration
|
||||
pass
|
||||
|
||||
return config
|
||||
|
||||
async def build_index(self, args, texts: list[str]) -> str:
|
||||
"""Build LEANN index from texts."""
|
||||
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
|
||||
|
||||
print(f"\n[Building Index] Creating {self.name} index...")
|
||||
print(f"Total text chunks: {len(texts)}")
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name=args.backend_name,
|
||||
embedding_model=args.embedding_model,
|
||||
embedding_mode=args.embedding_mode,
|
||||
graph_degree=args.graph_degree,
|
||||
complexity=args.build_complexity,
|
||||
is_compact=not args.no_compact,
|
||||
is_recompute=not args.no_recompute,
|
||||
num_threads=1, # Force single-threaded mode
|
||||
)
|
||||
|
||||
# Add texts in batches for better progress tracking
|
||||
batch_size = 1000
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i : i + batch_size]
|
||||
for text in batch:
|
||||
builder.add_text(text)
|
||||
print(f"Added {min(i + batch_size, len(texts))}/{len(texts)} texts...")
|
||||
|
||||
print("Building index structure...")
|
||||
builder.build_index(index_path)
|
||||
print(f"Index saved to: {index_path}")
|
||||
|
||||
return index_path
|
||||
|
||||
async def run_interactive_chat(self, args, index_path: str):
|
||||
"""Run interactive chat with the index."""
|
||||
chat = LeannChat(
|
||||
index_path,
|
||||
llm_config=self.get_llm_config(args),
|
||||
system_prompt=f"You are a helpful assistant that answers questions about {self.name} data.",
|
||||
complexity=args.search_complexity,
|
||||
)
|
||||
|
||||
print(f"\n[Interactive Mode] Chat with your {self.name} data!")
|
||||
print("Type 'quit' or 'exit' to stop.\n")
|
||||
|
||||
while True:
|
||||
try:
|
||||
query = input("You: ").strip()
|
||||
if query.lower() in ["quit", "exit", "q"]:
|
||||
print("Goodbye!")
|
||||
break
|
||||
|
||||
if not query:
|
||||
continue
|
||||
|
||||
# Prepare LLM kwargs with thinking budget if specified
|
||||
llm_kwargs = {}
|
||||
if hasattr(args, "thinking_budget") and args.thinking_budget:
|
||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||
|
||||
response = chat.ask(
|
||||
query,
|
||||
top_k=args.top_k,
|
||||
complexity=args.search_complexity,
|
||||
llm_kwargs=llm_kwargs,
|
||||
)
|
||||
print(f"\nAssistant: {response}\n")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nGoodbye!")
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
async def run_single_query(self, args, index_path: str, query: str):
|
||||
"""Run a single query against the index."""
|
||||
chat = LeannChat(
|
||||
index_path,
|
||||
llm_config=self.get_llm_config(args),
|
||||
system_prompt=f"You are a helpful assistant that answers questions about {self.name} data.",
|
||||
complexity=args.search_complexity,
|
||||
)
|
||||
|
||||
print(f"\n[Query]: \033[36m{query}\033[0m")
|
||||
|
||||
# Prepare LLM kwargs with thinking budget if specified
|
||||
llm_kwargs = {}
|
||||
if hasattr(args, "thinking_budget") and args.thinking_budget:
|
||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||
|
||||
response = chat.ask(
|
||||
query, top_k=args.top_k, complexity=args.search_complexity, llm_kwargs=llm_kwargs
|
||||
)
|
||||
print(f"\n[Response]: \033[36m{response}\033[0m")
|
||||
|
||||
async def run(self):
|
||||
"""Main entry point for the example."""
|
||||
args = self.parser.parse_args()
|
||||
|
||||
# Check if index exists
|
||||
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
|
||||
index_exists = Path(args.index_dir).exists()
|
||||
|
||||
if not index_exists or args.force_rebuild:
|
||||
# Load data and build index
|
||||
print(f"\n{'Rebuilding' if index_exists else 'Building'} index...")
|
||||
texts = await self.load_data(args)
|
||||
|
||||
if not texts:
|
||||
print("No data found to index!")
|
||||
return
|
||||
|
||||
index_path = await self.build_index(args, texts)
|
||||
else:
|
||||
print(f"\nUsing existing index in {args.index_dir}")
|
||||
|
||||
# Run query or interactive mode
|
||||
if args.query:
|
||||
await self.run_single_query(args, index_path, args.query)
|
||||
else:
|
||||
await self.run_interactive_chat(args, index_path)
|
||||
|
||||
|
||||
def create_text_chunks(documents, chunk_size=256, chunk_overlap=25) -> list[str]:
|
||||
"""Helper function to create text chunks from documents."""
|
||||
node_parser = SentenceSplitter(
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
separator=" ",
|
||||
paragraph_separator="\n\n",
|
||||
)
|
||||
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
nodes = node_parser.get_nodes_from_documents([doc])
|
||||
if nodes:
|
||||
all_texts.extend(node.get_content() for node in nodes)
|
||||
|
||||
return all_texts
|
||||
170
apps/browser_rag.py
Normal file
170
apps/browser_rag.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
Browser History RAG example using the unified interface.
|
||||
Supports Chrome browser history.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample, create_text_chunks
|
||||
|
||||
from .history_data.history import ChromeHistoryReader
|
||||
|
||||
|
||||
class BrowserRAG(BaseRAGExample):
|
||||
"""RAG example for Chrome browser history."""
|
||||
|
||||
def __init__(self):
|
||||
# Set default values BEFORE calling super().__init__
|
||||
self.embedding_model_default = (
|
||||
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
name="Browser History",
|
||||
description="Process and query Chrome browser history with LEANN",
|
||||
default_index_name="google_history_index",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add browser-specific arguments."""
|
||||
browser_group = parser.add_argument_group("Browser Parameters")
|
||||
browser_group.add_argument(
|
||||
"--chrome-profile",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to Chrome profile directory (auto-detected if not specified)",
|
||||
)
|
||||
browser_group.add_argument(
|
||||
"--auto-find-profiles",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Automatically find all Chrome profiles (default: True)",
|
||||
)
|
||||
browser_group.add_argument(
|
||||
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
|
||||
)
|
||||
browser_group.add_argument(
|
||||
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
||||
)
|
||||
|
||||
def _get_chrome_base_path(self) -> Path:
|
||||
"""Get the base Chrome profile path based on OS."""
|
||||
if sys.platform == "darwin":
|
||||
return Path.home() / "Library" / "Application Support" / "Google" / "Chrome"
|
||||
elif sys.platform.startswith("linux"):
|
||||
return Path.home() / ".config" / "google-chrome"
|
||||
elif sys.platform == "win32":
|
||||
return Path(os.environ["LOCALAPPDATA"]) / "Google" / "Chrome" / "User Data"
|
||||
else:
|
||||
raise ValueError(f"Unsupported platform: {sys.platform}")
|
||||
|
||||
def _find_chrome_profiles(self) -> list[Path]:
|
||||
"""Auto-detect all Chrome profiles."""
|
||||
base_path = self._get_chrome_base_path()
|
||||
if not base_path.exists():
|
||||
return []
|
||||
|
||||
profiles = []
|
||||
|
||||
# Check Default profile
|
||||
default_profile = base_path / "Default"
|
||||
if default_profile.exists() and (default_profile / "History").exists():
|
||||
profiles.append(default_profile)
|
||||
|
||||
# Check numbered profiles
|
||||
for item in base_path.iterdir():
|
||||
if item.is_dir() and item.name.startswith("Profile "):
|
||||
if (item / "History").exists():
|
||||
profiles.append(item)
|
||||
|
||||
return profiles
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load browser history and convert to text chunks."""
|
||||
# Determine Chrome profiles
|
||||
if args.chrome_profile and not args.auto_find_profiles:
|
||||
profile_dirs = [Path(args.chrome_profile)]
|
||||
else:
|
||||
print("Auto-detecting Chrome profiles...")
|
||||
profile_dirs = self._find_chrome_profiles()
|
||||
|
||||
# If specific profile given, filter to just that one
|
||||
if args.chrome_profile:
|
||||
profile_path = Path(args.chrome_profile)
|
||||
profile_dirs = [p for p in profile_dirs if p == profile_path]
|
||||
|
||||
if not profile_dirs:
|
||||
print("No Chrome profiles found!")
|
||||
print("Please specify --chrome-profile manually")
|
||||
return []
|
||||
|
||||
print(f"Found {len(profile_dirs)} Chrome profiles")
|
||||
|
||||
# Create reader
|
||||
reader = ChromeHistoryReader()
|
||||
|
||||
# Process each profile
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
for i, profile_dir in enumerate(profile_dirs):
|
||||
print(f"\nProcessing profile {i + 1}/{len(profile_dirs)}: {profile_dir.name}")
|
||||
|
||||
try:
|
||||
# Apply max_items limit per profile
|
||||
max_per_profile = -1
|
||||
if args.max_items > 0:
|
||||
remaining = args.max_items - total_processed
|
||||
if remaining <= 0:
|
||||
break
|
||||
max_per_profile = remaining
|
||||
|
||||
# Load history
|
||||
documents = reader.load_data(
|
||||
chrome_profile_path=str(profile_dir),
|
||||
max_count=max_per_profile,
|
||||
)
|
||||
|
||||
if documents:
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
print(f"Processed {len(documents)} history entries from this profile")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {profile_dir}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No browser history found to process!")
|
||||
return []
|
||||
|
||||
print(f"\nTotal history entries processed: {len(all_documents)}")
|
||||
|
||||
# Convert to text chunks
|
||||
all_texts = create_text_chunks(
|
||||
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||
)
|
||||
|
||||
return all_texts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# Example queries for browser history RAG
|
||||
print("\n🌐 Browser History RAG Example")
|
||||
print("=" * 50)
|
||||
print("\nExample queries you can try:")
|
||||
print("- 'What websites did I visit about machine learning?'")
|
||||
print("- 'Find my search history about programming'")
|
||||
print("- 'What YouTube videos did I watch recently?'")
|
||||
print("- 'Show me websites about travel planning'")
|
||||
print("\nNote: Make sure Chrome is closed before running\n")
|
||||
|
||||
rag = BrowserRAG()
|
||||
asyncio.run(rag.run())
|
||||
108
apps/document_rag.py
Normal file
108
apps/document_rag.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""
|
||||
Document RAG example using the unified interface.
|
||||
Supports PDF, TXT, MD, and other document formats.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample, create_text_chunks
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
|
||||
|
||||
class DocumentRAG(BaseRAGExample):
|
||||
"""RAG example for document processing (PDF, TXT, MD, etc.)."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="Document",
|
||||
description="Process and query documents (PDF, TXT, MD, etc.) with LEANN",
|
||||
default_index_name="test_doc_files",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add document-specific arguments."""
|
||||
doc_group = parser.add_argument_group("Document Parameters")
|
||||
doc_group.add_argument(
|
||||
"--data-dir",
|
||||
type=str,
|
||||
default="data",
|
||||
help="Directory containing documents to index (default: data)",
|
||||
)
|
||||
doc_group.add_argument(
|
||||
"--file-types",
|
||||
nargs="+",
|
||||
default=None,
|
||||
help="Filter by file types (e.g., .pdf .txt .md). If not specified, all supported types are processed",
|
||||
)
|
||||
doc_group.add_argument(
|
||||
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
|
||||
)
|
||||
doc_group.add_argument(
|
||||
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
||||
)
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load documents and convert to text chunks."""
|
||||
print(f"Loading documents from: {args.data_dir}")
|
||||
if args.file_types:
|
||||
print(f"Filtering by file types: {args.file_types}")
|
||||
else:
|
||||
print("Processing all supported file types")
|
||||
|
||||
# Check if data directory exists
|
||||
data_path = Path(args.data_dir)
|
||||
if not data_path.exists():
|
||||
raise ValueError(f"Data directory not found: {args.data_dir}")
|
||||
|
||||
# Load documents
|
||||
reader_kwargs = {
|
||||
"recursive": True,
|
||||
"encoding": "utf-8",
|
||||
}
|
||||
if args.file_types:
|
||||
reader_kwargs["required_exts"] = args.file_types
|
||||
|
||||
documents = SimpleDirectoryReader(args.data_dir, **reader_kwargs).load_data(
|
||||
show_progress=True
|
||||
)
|
||||
|
||||
if not documents:
|
||||
print(f"No documents found in {args.data_dir} with extensions {args.file_types}")
|
||||
return []
|
||||
|
||||
print(f"Loaded {len(documents)} documents")
|
||||
|
||||
# Convert to text chunks
|
||||
all_texts = create_text_chunks(
|
||||
documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||
)
|
||||
|
||||
# Apply max_items limit if specified
|
||||
if args.max_items > 0 and len(all_texts) > args.max_items:
|
||||
print(f"Limiting to {args.max_items} chunks (from {len(all_texts)})")
|
||||
all_texts = all_texts[: args.max_items]
|
||||
|
||||
return all_texts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# Example queries for document RAG
|
||||
print("\n📄 Document RAG Example")
|
||||
print("=" * 50)
|
||||
print("\nExample queries you can try:")
|
||||
print("- 'What are the main techniques LEANN uses?'")
|
||||
print("- 'What is the technique DLPM?'")
|
||||
print("- 'Who does Elizabeth Bennet marry?'")
|
||||
print(
|
||||
"- 'What is the problem of developing pan gu model Huawei meets? (盘古大模型开发中遇到什么问题?)'"
|
||||
)
|
||||
print("\nOr run without --query for interactive mode\n")
|
||||
|
||||
rag = DocumentRAG()
|
||||
asyncio.run(rag.run())
|
||||
@@ -52,6 +52,11 @@ class EmlxReader(BaseReader):
|
||||
docs: list[Document] = []
|
||||
max_count = load_kwargs.get("max_count", 1000)
|
||||
count = 0
|
||||
total_files = 0
|
||||
successful_files = 0
|
||||
failed_files = 0
|
||||
|
||||
print(f"Starting to process directory: {input_dir}")
|
||||
|
||||
# Walk through the directory recursively
|
||||
for dirpath, dirnames, filenames in os.walk(input_dir):
|
||||
@@ -59,10 +64,12 @@ class EmlxReader(BaseReader):
|
||||
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
||||
|
||||
for filename in filenames:
|
||||
if count >= max_count:
|
||||
# Check if we've reached the max count (skip if max_count == -1)
|
||||
if max_count > 0 and count >= max_count:
|
||||
break
|
||||
|
||||
if filename.endswith(".emlx"):
|
||||
total_files += 1
|
||||
filepath = os.path.join(dirpath, filename)
|
||||
try:
|
||||
# Read the .emlx file
|
||||
@@ -98,17 +105,26 @@ class EmlxReader(BaseReader):
|
||||
and not self.include_html
|
||||
):
|
||||
continue
|
||||
body += part.get_payload(decode=True).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
# break
|
||||
try:
|
||||
payload = part.get_payload(decode=True)
|
||||
if payload:
|
||||
body += payload.decode("utf-8", errors="ignore")
|
||||
except Exception as e:
|
||||
print(f"Error decoding payload: {e}")
|
||||
continue
|
||||
else:
|
||||
body = msg.get_payload(decode=True).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
try:
|
||||
payload = msg.get_payload(decode=True)
|
||||
if payload:
|
||||
body = payload.decode("utf-8", errors="ignore")
|
||||
except Exception as e:
|
||||
print(f"Error decoding single part payload: {e}")
|
||||
body = ""
|
||||
|
||||
# Create document content with metadata embedded in text
|
||||
doc_content = f"""
|
||||
# Only create document if we have some content
|
||||
if body.strip() or subject != "No Subject":
|
||||
# Create document content with metadata embedded in text
|
||||
doc_content = f"""
|
||||
[File]: {filename}
|
||||
[From]: {from_addr}
|
||||
[To]: {to_addr}
|
||||
@@ -118,18 +134,34 @@ class EmlxReader(BaseReader):
|
||||
{body}
|
||||
"""
|
||||
|
||||
# No separate metadata - everything is in the text
|
||||
doc = Document(text=doc_content, metadata={})
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
# No separate metadata - everything is in the text
|
||||
doc = Document(text=doc_content, metadata={})
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
successful_files += 1
|
||||
|
||||
# Print first few successful files for debugging
|
||||
if successful_files <= 3:
|
||||
print(
|
||||
f"Successfully loaded: {filename} - Subject: {subject[:50]}..."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error parsing email from {filepath}: {e}")
|
||||
failed_files += 1
|
||||
if failed_files <= 5: # Only print first few errors
|
||||
print(f"Error parsing email from {filepath}: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error reading file {filepath}: {e}")
|
||||
failed_files += 1
|
||||
if failed_files <= 5: # Only print first few errors
|
||||
print(f"Error reading file {filepath}: {e}")
|
||||
continue
|
||||
|
||||
print(f"Loaded {len(docs)} email documents")
|
||||
print("Processing summary:")
|
||||
print(f" Total .emlx files found: {total_files}")
|
||||
print(f" Successfully loaded: {successful_files}")
|
||||
print(f" Failed to load: {failed_files}")
|
||||
print(f" Final documents: {len(docs)}")
|
||||
|
||||
return docs
|
||||
156
apps/email_rag.py
Normal file
156
apps/email_rag.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
Email RAG example using the unified interface.
|
||||
Supports Apple Mail on macOS.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample, create_text_chunks
|
||||
|
||||
from .email_data.LEANN_email_reader import EmlxReader
|
||||
|
||||
|
||||
class EmailRAG(BaseRAGExample):
|
||||
"""RAG example for Apple Mail processing."""
|
||||
|
||||
def __init__(self):
|
||||
# Set default values BEFORE calling super().__init__
|
||||
self.max_items_default = -1 # Process all emails by default
|
||||
self.embedding_model_default = (
|
||||
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
name="Email",
|
||||
description="Process and query Apple Mail emails with LEANN",
|
||||
default_index_name="mail_index",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add email-specific arguments."""
|
||||
email_group = parser.add_argument_group("Email Parameters")
|
||||
email_group.add_argument(
|
||||
"--mail-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to Apple Mail directory (auto-detected if not specified)",
|
||||
)
|
||||
email_group.add_argument(
|
||||
"--include-html", action="store_true", help="Include HTML content in email processing"
|
||||
)
|
||||
email_group.add_argument(
|
||||
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
|
||||
)
|
||||
email_group.add_argument(
|
||||
"--chunk-overlap", type=int, default=25, help="Text chunk overlap (default: 25)"
|
||||
)
|
||||
|
||||
def _find_mail_directories(self) -> list[Path]:
|
||||
"""Auto-detect all Apple Mail directories."""
|
||||
mail_base = Path.home() / "Library" / "Mail"
|
||||
if not mail_base.exists():
|
||||
return []
|
||||
|
||||
# Find all Messages directories
|
||||
messages_dirs = []
|
||||
for item in mail_base.rglob("Messages"):
|
||||
if item.is_dir():
|
||||
messages_dirs.append(item)
|
||||
|
||||
return messages_dirs
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load emails and convert to text chunks."""
|
||||
# Determine mail directories
|
||||
if args.mail_path:
|
||||
messages_dirs = [Path(args.mail_path)]
|
||||
else:
|
||||
print("Auto-detecting Apple Mail directories...")
|
||||
messages_dirs = self._find_mail_directories()
|
||||
|
||||
if not messages_dirs:
|
||||
print("No Apple Mail directories found!")
|
||||
print("Please specify --mail-path manually")
|
||||
return []
|
||||
|
||||
print(f"Found {len(messages_dirs)} mail directories")
|
||||
|
||||
# Create reader
|
||||
reader = EmlxReader(include_html=args.include_html)
|
||||
|
||||
# Process each directory
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
for i, messages_dir in enumerate(messages_dirs):
|
||||
print(f"\nProcessing directory {i + 1}/{len(messages_dirs)}: {messages_dir}")
|
||||
|
||||
try:
|
||||
# Count emlx files
|
||||
emlx_files = list(messages_dir.glob("*.emlx"))
|
||||
print(f"Found {len(emlx_files)} email files")
|
||||
|
||||
# Apply max_items limit per directory
|
||||
max_per_dir = -1 # Default to process all
|
||||
if args.max_items > 0:
|
||||
remaining = args.max_items - total_processed
|
||||
if remaining <= 0:
|
||||
break
|
||||
max_per_dir = remaining
|
||||
# If args.max_items == -1, max_per_dir stays -1 (process all)
|
||||
|
||||
# Load emails - fix the parameter passing
|
||||
documents = reader.load_data(
|
||||
input_dir=str(messages_dir),
|
||||
max_count=max_per_dir,
|
||||
)
|
||||
|
||||
if documents:
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
print(f"Processed {len(documents)} emails from this directory")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {messages_dir}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No emails found to process!")
|
||||
return []
|
||||
|
||||
print(f"\nTotal emails processed: {len(all_documents)}")
|
||||
print("now starting to split into text chunks ... take some time")
|
||||
|
||||
# Convert to text chunks
|
||||
# Email reader uses chunk_overlap=25 as in original
|
||||
all_texts = create_text_chunks(
|
||||
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||
)
|
||||
|
||||
return all_texts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# Check platform
|
||||
if sys.platform != "darwin":
|
||||
print("\n⚠️ Warning: This example is designed for macOS (Apple Mail)")
|
||||
print(" Windows/Linux support coming soon!\n")
|
||||
|
||||
# Example queries for email RAG
|
||||
print("\n📧 Email RAG Example")
|
||||
print("=" * 50)
|
||||
print("\nExample queries you can try:")
|
||||
print("- 'What did my boss say about deadlines?'")
|
||||
print("- 'Find emails about travel expenses'")
|
||||
print("- 'Show me emails from last month about the project'")
|
||||
print("- 'What food did I order from DoorDash?'")
|
||||
print("\nNote: You may need to grant Full Disk Access to your terminal\n")
|
||||
|
||||
rag = EmailRAG()
|
||||
asyncio.run(rag.run())
|
||||
@@ -97,6 +97,11 @@ class ChromeHistoryReader(BaseReader):
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error reading Chrome history: {e}")
|
||||
# add you may need to close your browser to make the database file available
|
||||
# also highlight in red
|
||||
print(
|
||||
"\033[91mYou may need to close your browser to make the database file available\033[0m"
|
||||
)
|
||||
return docs
|
||||
|
||||
return docs
|
||||
@@ -411,8 +411,8 @@ Messages ({len(messages)} messages, {message_group["total_length"]} chars):
|
||||
wechat_export_dir = load_kwargs.get("wechat_export_dir", None)
|
||||
include_non_text = load_kwargs.get("include_non_text", False)
|
||||
concatenate_messages = load_kwargs.get("concatenate_messages", False)
|
||||
load_kwargs.get("max_length", 1000)
|
||||
load_kwargs.get("time_window_minutes", 30)
|
||||
max_length = load_kwargs.get("max_length", 1000)
|
||||
time_window_minutes = load_kwargs.get("time_window_minutes", 30)
|
||||
|
||||
# Default WeChat export path
|
||||
if wechat_export_dir is None:
|
||||
@@ -460,9 +460,9 @@ Messages ({len(messages)} messages, {message_group["total_length"]} chars):
|
||||
# Concatenate messages based on rules
|
||||
message_groups = self._concatenate_messages(
|
||||
readable_messages,
|
||||
max_length=-1,
|
||||
time_window_minutes=-1,
|
||||
overlap_messages=0, # Keep 2 messages overlap between groups
|
||||
max_length=max_length,
|
||||
time_window_minutes=time_window_minutes,
|
||||
overlap_messages=0, # No overlap between groups
|
||||
)
|
||||
|
||||
# Create documents from concatenated groups
|
||||
@@ -474,7 +474,8 @@ Messages ({len(messages)} messages, {message_group["total_length"]} chars):
|
||||
message_group, contact_name
|
||||
)
|
||||
doc = Document(
|
||||
text=doc_content, metadata={"contact_name": contact_name}
|
||||
text=doc_content,
|
||||
metadata={"contact_name": contact_name},
|
||||
)
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
@@ -531,7 +532,9 @@ Message: {readable_text if readable_text else message_text}
|
||||
"""
|
||||
|
||||
# Create document with embedded metadata
|
||||
doc = Document(text=doc_content, metadata={})
|
||||
doc = Document(
|
||||
text=doc_content, metadata={"contact_name": contact_name}
|
||||
)
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
@@ -559,8 +562,8 @@ Message: {readable_text if readable_text else message_text}
|
||||
|
||||
# Look for common export directory names
|
||||
possible_dirs = [
|
||||
Path("./wechat_export_test"),
|
||||
Path("./wechat_export"),
|
||||
Path("./wechat_export_direct"),
|
||||
Path("./wechat_chat_history"),
|
||||
Path("./chat_export"),
|
||||
]
|
||||
189
apps/wechat_rag.py
Normal file
189
apps/wechat_rag.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""
|
||||
WeChat History RAG example using the unified interface.
|
||||
Supports WeChat chat history export and search.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample
|
||||
|
||||
from .history_data.wechat_history import WeChatHistoryReader
|
||||
|
||||
|
||||
class WeChatRAG(BaseRAGExample):
|
||||
"""RAG example for WeChat chat history."""
|
||||
|
||||
def __init__(self):
|
||||
# Set default values BEFORE calling super().__init__
|
||||
self.max_items_default = -1 # Match original default
|
||||
self.embedding_model_default = (
|
||||
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
name="WeChat History",
|
||||
description="Process and query WeChat chat history with LEANN",
|
||||
default_index_name="wechat_history_magic_test_11Debug_new",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add WeChat-specific arguments."""
|
||||
wechat_group = parser.add_argument_group("WeChat Parameters")
|
||||
wechat_group.add_argument(
|
||||
"--export-dir",
|
||||
type=str,
|
||||
default="./wechat_export",
|
||||
help="Directory to store WeChat exports (default: ./wechat_export)",
|
||||
)
|
||||
wechat_group.add_argument(
|
||||
"--force-export",
|
||||
action="store_true",
|
||||
help="Force re-export of WeChat data even if exports exist",
|
||||
)
|
||||
wechat_group.add_argument(
|
||||
"--chunk-size", type=int, default=192, help="Text chunk size (default: 192)"
|
||||
)
|
||||
wechat_group.add_argument(
|
||||
"--chunk-overlap", type=int, default=64, help="Text chunk overlap (default: 64)"
|
||||
)
|
||||
|
||||
def _export_wechat_data(self, export_dir: Path) -> bool:
|
||||
"""Export WeChat data using wechattweak-cli."""
|
||||
print("Exporting WeChat data...")
|
||||
|
||||
# Check if WeChat is running
|
||||
try:
|
||||
result = subprocess.run(["pgrep", "WeChat"], capture_output=True, text=True)
|
||||
if result.returncode != 0:
|
||||
print("WeChat is not running. Please start WeChat first.")
|
||||
return False
|
||||
except Exception:
|
||||
pass # pgrep might not be available on all systems
|
||||
|
||||
# Create export directory
|
||||
export_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Run export command
|
||||
cmd = ["packages/wechat-exporter/wechattweak-cli", "export", str(export_dir)]
|
||||
|
||||
try:
|
||||
print(f"Running: {' '.join(cmd)}")
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
|
||||
if result.returncode == 0:
|
||||
print("WeChat data exported successfully!")
|
||||
return True
|
||||
else:
|
||||
print(f"Export failed: {result.stderr}")
|
||||
return False
|
||||
|
||||
except FileNotFoundError:
|
||||
print("\nError: wechattweak-cli not found!")
|
||||
print("Please install it first:")
|
||||
print(" sudo packages/wechat-exporter/wechattweak-cli install")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"Export error: {e}")
|
||||
return False
|
||||
|
||||
async def load_data(self, args) -> list[str]:
|
||||
"""Load WeChat history and convert to text chunks."""
|
||||
# Initialize WeChat reader with export capabilities
|
||||
reader = WeChatHistoryReader()
|
||||
|
||||
# Find existing exports or create new ones using the centralized method
|
||||
export_dirs = reader.find_or_export_wechat_data(args.export_dir)
|
||||
if not export_dirs:
|
||||
print("Failed to find or export WeChat data. Trying to find any existing exports...")
|
||||
# Try to find any existing exports in common locations
|
||||
export_dirs = reader.find_wechat_export_dirs()
|
||||
if not export_dirs:
|
||||
print("No WeChat data found. Please ensure WeChat exports exist.")
|
||||
return []
|
||||
|
||||
# Load documents from all found export directories
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
for i, export_dir in enumerate(export_dirs):
|
||||
print(f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}")
|
||||
|
||||
try:
|
||||
# Apply max_items limit per export
|
||||
max_per_export = -1
|
||||
if args.max_items > 0:
|
||||
remaining = args.max_items - total_processed
|
||||
if remaining <= 0:
|
||||
break
|
||||
max_per_export = remaining
|
||||
|
||||
documents = reader.load_data(
|
||||
wechat_export_dir=str(export_dir),
|
||||
max_count=max_per_export,
|
||||
concatenate_messages=True, # Enable message concatenation for better context
|
||||
)
|
||||
|
||||
if documents:
|
||||
print(f"Loaded {len(documents)} chat documents from {export_dir}")
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
else:
|
||||
print(f"No documents loaded from {export_dir}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {export_dir}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No documents loaded from any source. Exiting.")
|
||||
return []
|
||||
|
||||
print(f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports")
|
||||
print("now starting to split into text chunks ... take some time")
|
||||
|
||||
# Convert to text chunks with contact information
|
||||
all_texts = []
|
||||
for doc in all_documents:
|
||||
# Split the document into chunks
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
text_splitter = SentenceSplitter(
|
||||
chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||
)
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
|
||||
for node in nodes:
|
||||
# Add contact information to each chunk
|
||||
contact_name = doc.metadata.get("contact_name", "Unknown")
|
||||
text = f"[Contact] means the message is from: {contact_name}\n" + node.get_content()
|
||||
all_texts.append(text)
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
|
||||
return all_texts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# Check platform
|
||||
if sys.platform != "darwin":
|
||||
print("\n⚠️ Warning: WeChat export is only supported on macOS")
|
||||
print(" You can still query existing exports on other platforms\n")
|
||||
|
||||
# Example queries for WeChat RAG
|
||||
print("\n💬 WeChat History RAG Example")
|
||||
print("=" * 50)
|
||||
print("\nExample queries you can try:")
|
||||
print("- 'Show me conversations about travel plans'")
|
||||
print("- 'Find group chats about weekend activities'")
|
||||
print("- '我想买魔术师约翰逊的球衣,给我一些对应聊天记录?'")
|
||||
print("- 'What did we discuss about the project last month?'")
|
||||
print("\nNote: WeChat must be running for export to work\n")
|
||||
|
||||
rag = WeChatRAG()
|
||||
asyncio.run(rag.run())
|
||||
BIN
assets/claude_code_leann.png
Normal file
BIN
assets/claude_code_leann.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 73 KiB |
BIN
assets/mcp_leann.png
Normal file
BIN
assets/mcp_leann.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 224 KiB |
@@ -1,13 +1,28 @@
|
||||
# 🧪 Leann Sanity Checks
|
||||
# 🧪 LEANN Benchmarks & Testing
|
||||
|
||||
This directory contains comprehensive sanity checks for the Leann system, ensuring all components work correctly across different configurations.
|
||||
This directory contains performance benchmarks and comprehensive tests for the LEANN system, including backend comparisons and sanity checks across different configurations.
|
||||
|
||||
## 📁 Test Files
|
||||
|
||||
### `diskann_vs_hnsw_speed_comparison.py`
|
||||
Performance comparison between DiskANN and HNSW backends:
|
||||
- ✅ **Search latency** comparison with both backends using recompute
|
||||
- ✅ **Index size** and **build time** measurements
|
||||
- ✅ **Score validity** testing (ensures no -inf scores)
|
||||
- ✅ **Configurable dataset sizes** for different scales
|
||||
|
||||
```bash
|
||||
# Quick comparison with 500 docs, 10 queries
|
||||
python benchmarks/diskann_vs_hnsw_speed_comparison.py
|
||||
|
||||
# Large-scale comparison with 2000 docs, 20 queries
|
||||
python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20
|
||||
```
|
||||
|
||||
### `test_distance_functions.py`
|
||||
Tests all supported distance functions across DiskANN backend:
|
||||
- ✅ **MIPS** (Maximum Inner Product Search)
|
||||
- ✅ **L2** (Euclidean Distance)
|
||||
- ✅ **L2** (Euclidean Distance)
|
||||
- ✅ **Cosine** (Cosine Similarity)
|
||||
|
||||
```bash
|
||||
@@ -27,7 +42,7 @@ uv run python tests/sanity_checks/test_l2_verification.py
|
||||
### `test_sanity_check.py`
|
||||
Comprehensive end-to-end verification including:
|
||||
- Distance function testing
|
||||
- Embedding model compatibility
|
||||
- Embedding model compatibility
|
||||
- Search result correctness validation
|
||||
- Backend integration testing
|
||||
|
||||
@@ -64,7 +79,7 @@ When all tests pass, you should see:
|
||||
```
|
||||
📊 测试结果总结:
|
||||
mips : ✅ 通过
|
||||
l2 : ✅ 通过
|
||||
l2 : ✅ 通过
|
||||
cosine : ✅ 通过
|
||||
|
||||
🎉 测试完成!
|
||||
@@ -98,7 +113,7 @@ pkill -f "embedding_server"
|
||||
|
||||
### Typical Timing (3 documents, consumer hardware):
|
||||
- **Index Building**: 2-5 seconds per distance function
|
||||
- **Search Query**: 50-200ms
|
||||
- **Search Query**: 50-200ms
|
||||
- **Recompute Mode**: 5-15 seconds (higher accuracy)
|
||||
|
||||
### Memory Usage:
|
||||
@@ -117,4 +132,4 @@ These tests are designed to be run in automated environments:
|
||||
uv run python tests/sanity_checks/test_l2_verification.py
|
||||
```
|
||||
|
||||
The tests are deterministic and should produce consistent results across different platforms.
|
||||
The tests are deterministic and should produce consistent results across different platforms.
|
||||
@@ -115,7 +115,13 @@ def main():
|
||||
# --- Plotting ---
|
||||
print("\n--- Generating Plot ---")
|
||||
plt.figure(figsize=(10, 6))
|
||||
plt.plot(BATCH_SIZES, results_torch, marker="o", linestyle="-", label=f"PyTorch ({device})")
|
||||
plt.plot(
|
||||
BATCH_SIZES,
|
||||
results_torch,
|
||||
marker="o",
|
||||
linestyle="-",
|
||||
label=f"PyTorch ({device})",
|
||||
)
|
||||
plt.plot(BATCH_SIZES, results_mlx, marker="s", linestyle="-", label="MLX")
|
||||
|
||||
plt.title(f"Embedding Performance: MLX vs PyTorch\nModel: {MODEL_NAME_TORCH}")
|
||||
@@ -62,7 +62,7 @@ def test_faiss_hnsw():
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[sys.executable, "examples/faiss_only.py"],
|
||||
[sys.executable, "benchmarks/faiss_only.py"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300,
|
||||
@@ -115,7 +115,7 @@ def test_leann_hnsw():
|
||||
|
||||
# Load and parse documents
|
||||
documents = SimpleDirectoryReader(
|
||||
"examples/data",
|
||||
"data",
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=[".pdf", ".txt", ".md"],
|
||||
268
benchmarks/diskann_vs_hnsw_speed_comparison.py
Normal file
268
benchmarks/diskann_vs_hnsw_speed_comparison.py
Normal file
@@ -0,0 +1,268 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
DiskANN vs HNSW Search Performance Comparison
|
||||
|
||||
This benchmark compares search performance between DiskANN and HNSW backends:
|
||||
- DiskANN: With graph partitioning enabled (is_recompute=True)
|
||||
- HNSW: With recompute enabled (is_recompute=True)
|
||||
- Tests performance across different dataset sizes
|
||||
- Measures search latency, recall, and index size
|
||||
"""
|
||||
|
||||
import gc
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def create_test_texts(n_docs: int) -> list[str]:
|
||||
"""Create synthetic test documents for benchmarking."""
|
||||
np.random.seed(42)
|
||||
topics = [
|
||||
"machine learning and artificial intelligence",
|
||||
"natural language processing and text analysis",
|
||||
"computer vision and image recognition",
|
||||
"data science and statistical analysis",
|
||||
"deep learning and neural networks",
|
||||
"information retrieval and search engines",
|
||||
"database systems and data management",
|
||||
"software engineering and programming",
|
||||
"cybersecurity and network protection",
|
||||
"cloud computing and distributed systems",
|
||||
]
|
||||
|
||||
texts = []
|
||||
for i in range(n_docs):
|
||||
topic = topics[i % len(topics)]
|
||||
variation = np.random.randint(1, 100)
|
||||
text = (
|
||||
f"This is document {i} about {topic}. Content variation {variation}. "
|
||||
f"Additional information about {topic} with details and examples. "
|
||||
f"Technical discussion of {topic} including implementation aspects."
|
||||
)
|
||||
texts.append(text)
|
||||
|
||||
return texts
|
||||
|
||||
|
||||
def benchmark_backend(
|
||||
backend_name: str, texts: list[str], test_queries: list[str], backend_kwargs: dict[str, Any]
|
||||
) -> dict[str, float]:
|
||||
"""Benchmark a specific backend with the given configuration."""
|
||||
from leann.api import LeannBuilder, LeannSearcher
|
||||
|
||||
print(f"\n🔧 Testing {backend_name.upper()} backend...")
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
index_path = str(Path(temp_dir) / f"benchmark_{backend_name}.leann")
|
||||
|
||||
# Build index
|
||||
print(f"📦 Building {backend_name} index with {len(texts)} documents...")
|
||||
start_time = time.time()
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name=backend_name,
|
||||
embedding_model="facebook/contriever",
|
||||
embedding_mode="sentence-transformers",
|
||||
**backend_kwargs,
|
||||
)
|
||||
|
||||
for text in texts:
|
||||
builder.add_text(text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
build_time = time.time() - start_time
|
||||
|
||||
# Measure index size
|
||||
index_dir = Path(index_path).parent
|
||||
index_files = list(index_dir.glob(f"{Path(index_path).stem}.*"))
|
||||
total_size = sum(f.stat().st_size for f in index_files if f.is_file())
|
||||
size_mb = total_size / (1024 * 1024)
|
||||
|
||||
print(f" ✅ Build completed in {build_time:.2f}s, index size: {size_mb:.1f}MB")
|
||||
|
||||
# Search benchmark
|
||||
print("🔍 Running search benchmark...")
|
||||
searcher = LeannSearcher(index_path)
|
||||
|
||||
search_times = []
|
||||
all_results = []
|
||||
|
||||
for query in test_queries:
|
||||
start_time = time.time()
|
||||
results = searcher.search(query, top_k=5)
|
||||
search_time = time.time() - start_time
|
||||
search_times.append(search_time)
|
||||
all_results.append(results)
|
||||
|
||||
avg_search_time = np.mean(search_times) * 1000 # Convert to ms
|
||||
print(f" ✅ Average search time: {avg_search_time:.1f}ms")
|
||||
|
||||
# Check for valid scores (detect -inf issues)
|
||||
all_scores = [
|
||||
result.score
|
||||
for results in all_results
|
||||
for result in results
|
||||
if result.score is not None
|
||||
]
|
||||
valid_scores = [
|
||||
score for score in all_scores if score != float("-inf") and score != float("inf")
|
||||
]
|
||||
score_validity_rate = len(valid_scores) / len(all_scores) if all_scores else 0
|
||||
|
||||
# Clean up
|
||||
try:
|
||||
if hasattr(searcher, "__del__"):
|
||||
searcher.__del__()
|
||||
del searcher
|
||||
del builder
|
||||
gc.collect()
|
||||
except Exception as e:
|
||||
print(f"⚠️ Warning: Resource cleanup error: {e}")
|
||||
|
||||
return {
|
||||
"build_time": build_time,
|
||||
"avg_search_time_ms": avg_search_time,
|
||||
"index_size_mb": size_mb,
|
||||
"score_validity_rate": score_validity_rate,
|
||||
}
|
||||
|
||||
|
||||
def run_comparison(n_docs: int = 500, n_queries: int = 10):
|
||||
"""Run performance comparison between DiskANN and HNSW."""
|
||||
print("🚀 Starting DiskANN vs HNSW Performance Comparison")
|
||||
print(f"📊 Dataset: {n_docs} documents, {n_queries} test queries")
|
||||
|
||||
# Create test data
|
||||
texts = create_test_texts(n_docs)
|
||||
test_queries = [
|
||||
"machine learning algorithms",
|
||||
"natural language processing",
|
||||
"computer vision techniques",
|
||||
"data analysis methods",
|
||||
"neural network architectures",
|
||||
"database query optimization",
|
||||
"software development practices",
|
||||
"security vulnerabilities",
|
||||
"cloud infrastructure",
|
||||
"distributed computing",
|
||||
][:n_queries]
|
||||
|
||||
# HNSW benchmark
|
||||
hnsw_results = benchmark_backend(
|
||||
backend_name="hnsw",
|
||||
texts=texts,
|
||||
test_queries=test_queries,
|
||||
backend_kwargs={
|
||||
"is_recompute": True, # Enable recompute for fair comparison
|
||||
"M": 16,
|
||||
"efConstruction": 200,
|
||||
},
|
||||
)
|
||||
|
||||
# DiskANN benchmark
|
||||
diskann_results = benchmark_backend(
|
||||
backend_name="diskann",
|
||||
texts=texts,
|
||||
test_queries=test_queries,
|
||||
backend_kwargs={
|
||||
"is_recompute": True, # Enable graph partitioning
|
||||
"num_neighbors": 32,
|
||||
"search_list_size": 50,
|
||||
},
|
||||
)
|
||||
|
||||
# Performance comparison
|
||||
print("\n📈 Performance Comparison Results")
|
||||
print(f"{'=' * 60}")
|
||||
print(f"{'Metric':<25} {'HNSW':<15} {'DiskANN':<15} {'Speedup':<10}")
|
||||
print(f"{'-' * 60}")
|
||||
|
||||
# Build time comparison
|
||||
build_speedup = hnsw_results["build_time"] / diskann_results["build_time"]
|
||||
print(
|
||||
f"{'Build Time (s)':<25} {hnsw_results['build_time']:<15.2f} {diskann_results['build_time']:<15.2f} {build_speedup:<10.2f}x"
|
||||
)
|
||||
|
||||
# Search time comparison
|
||||
search_speedup = hnsw_results["avg_search_time_ms"] / diskann_results["avg_search_time_ms"]
|
||||
print(
|
||||
f"{'Search Time (ms)':<25} {hnsw_results['avg_search_time_ms']:<15.1f} {diskann_results['avg_search_time_ms']:<15.1f} {search_speedup:<10.2f}x"
|
||||
)
|
||||
|
||||
# Index size comparison
|
||||
size_ratio = diskann_results["index_size_mb"] / hnsw_results["index_size_mb"]
|
||||
print(
|
||||
f"{'Index Size (MB)':<25} {hnsw_results['index_size_mb']:<15.1f} {diskann_results['index_size_mb']:<15.1f} {size_ratio:<10.2f}x"
|
||||
)
|
||||
|
||||
# Score validity
|
||||
print(
|
||||
f"{'Score Validity (%)':<25} {hnsw_results['score_validity_rate'] * 100:<15.1f} {diskann_results['score_validity_rate'] * 100:<15.1f}"
|
||||
)
|
||||
|
||||
print(f"{'=' * 60}")
|
||||
print("\n🎯 Summary:")
|
||||
if search_speedup > 1:
|
||||
print(f" DiskANN is {search_speedup:.2f}x faster than HNSW for search")
|
||||
else:
|
||||
print(f" HNSW is {1 / search_speedup:.2f}x faster than DiskANN for search")
|
||||
|
||||
if size_ratio > 1:
|
||||
print(f" DiskANN uses {size_ratio:.2f}x more storage than HNSW")
|
||||
else:
|
||||
print(f" DiskANN uses {1 / size_ratio:.2f}x less storage than HNSW")
|
||||
|
||||
print(
|
||||
f" Both backends achieved {min(hnsw_results['score_validity_rate'], diskann_results['score_validity_rate']) * 100:.1f}% score validity"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
try:
|
||||
# Handle help request
|
||||
if len(sys.argv) > 1 and sys.argv[1] in ["-h", "--help", "help"]:
|
||||
print("DiskANN vs HNSW Performance Comparison")
|
||||
print("=" * 50)
|
||||
print(f"Usage: python {sys.argv[0]} [n_docs] [n_queries]")
|
||||
print()
|
||||
print("Arguments:")
|
||||
print(" n_docs Number of documents to index (default: 500)")
|
||||
print(" n_queries Number of test queries to run (default: 10)")
|
||||
print()
|
||||
print("Examples:")
|
||||
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py")
|
||||
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 1000")
|
||||
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20")
|
||||
sys.exit(0)
|
||||
|
||||
# Parse command line arguments
|
||||
n_docs = int(sys.argv[1]) if len(sys.argv) > 1 else 500
|
||||
n_queries = int(sys.argv[2]) if len(sys.argv) > 2 else 10
|
||||
|
||||
print("DiskANN vs HNSW Performance Comparison")
|
||||
print("=" * 50)
|
||||
print(f"Dataset: {n_docs} documents, {n_queries} queries")
|
||||
print()
|
||||
|
||||
run_comparison(n_docs=n_docs, n_queries=n_queries)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n⚠️ Benchmark interrupted by user")
|
||||
sys.exit(130)
|
||||
except Exception as e:
|
||||
print(f"\n❌ Benchmark failed: {e}")
|
||||
sys.exit(1)
|
||||
finally:
|
||||
# Ensure clean exit
|
||||
try:
|
||||
gc.collect()
|
||||
print("\n🧹 Cleanup completed")
|
||||
except Exception:
|
||||
pass
|
||||
sys.exit(0)
|
||||
@@ -65,7 +65,7 @@ def main():
|
||||
tracker.checkpoint("After Faiss index creation")
|
||||
|
||||
documents = SimpleDirectoryReader(
|
||||
"examples/data",
|
||||
"data",
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=[".pdf", ".txt", ".md"],
|
||||
@@ -58,7 +58,8 @@ class GraphWrapper:
|
||||
self.graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(self.graph):
|
||||
self.static_output = self.model(
|
||||
input_ids=self.static_input, attention_mask=self.static_attention_mask
|
||||
input_ids=self.static_input,
|
||||
attention_mask=self.static_attention_mask,
|
||||
)
|
||||
self.use_cuda_graph = True
|
||||
else:
|
||||
@@ -82,7 +83,10 @@ class GraphWrapper:
|
||||
def _warmup(self, num_warmup: int = 3):
|
||||
with torch.no_grad():
|
||||
for _ in range(num_warmup):
|
||||
self.model(input_ids=self.static_input, attention_mask=self.static_attention_mask)
|
||||
self.model(
|
||||
input_ids=self.static_input,
|
||||
attention_mask=self.static_attention_mask,
|
||||
)
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
||||
if self.use_cuda_graph:
|
||||
@@ -261,7 +265,10 @@ class Benchmark:
|
||||
# print size
|
||||
print(f"in_features: {in_features}, out_features: {out_features}")
|
||||
new_module = bnb.nn.Linear8bitLt(
|
||||
in_features, out_features, bias=bias, has_fp16_weights=False
|
||||
in_features,
|
||||
out_features,
|
||||
bias=bias,
|
||||
has_fp16_weights=False,
|
||||
)
|
||||
|
||||
# Copy weights and bias
|
||||
@@ -350,8 +357,6 @@ class Benchmark:
|
||||
# Try xformers if available (only on CUDA)
|
||||
if torch.cuda.is_available():
|
||||
try:
|
||||
from xformers.ops import memory_efficient_attention # noqa: F401
|
||||
|
||||
if hasattr(model, "enable_xformers_memory_efficient_attention"):
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
print("- Enabled xformers memory efficient attention")
|
||||
@@ -427,7 +432,11 @@ class Benchmark:
|
||||
else "cpu"
|
||||
)
|
||||
return torch.randint(
|
||||
0, 1000, (batch_size, self.config.seq_length), device=device, dtype=torch.long
|
||||
0,
|
||||
1000,
|
||||
(batch_size, self.config.seq_length),
|
||||
device=device,
|
||||
dtype=torch.long,
|
||||
)
|
||||
|
||||
def _run_inference(
|
||||
@@ -200,10 +200,10 @@ def main():
|
||||
args = parser.parse_args()
|
||||
|
||||
# --- Path Configuration ---
|
||||
# Assumes a project structure where the script is in 'examples/'
|
||||
# and data is in 'data/' at the project root.
|
||||
project_root = Path(__file__).resolve().parent.parent
|
||||
data_root = project_root / "data"
|
||||
# Assumes a project structure where the script is in 'benchmarks/'
|
||||
# and evaluation data is in 'benchmarks/data/'.
|
||||
script_dir = Path(__file__).resolve().parent
|
||||
data_root = script_dir / "data"
|
||||
|
||||
# Download data based on mode
|
||||
if args.mode == "build":
|
||||
@@ -279,7 +279,9 @@ def main():
|
||||
|
||||
if not args.index_path:
|
||||
print("No indices found. The data download should have included pre-built indices.")
|
||||
print("Please check the data/indices/ directory or provide --index-path manually.")
|
||||
print(
|
||||
"Please check the benchmarks/data/indices/ directory or provide --index-path manually."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Detect dataset type from index path to select the correct ground truth
|
||||
@@ -170,7 +170,11 @@ class Benchmark:
|
||||
|
||||
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
|
||||
return torch.randint(
|
||||
0, 1000, (batch_size, self.config.seq_length), device=self.device, dtype=torch.long
|
||||
0,
|
||||
1000,
|
||||
(batch_size, self.config.seq_length),
|
||||
device=self.device,
|
||||
dtype=torch.long,
|
||||
)
|
||||
|
||||
def _run_inference(self, input_ids: torch.Tensor) -> float:
|
||||
@@ -256,7 +260,11 @@ def run_mlx_benchmark():
|
||||
"""Run MLX-specific benchmark"""
|
||||
if not MLX_AVAILABLE:
|
||||
print("MLX not available, skipping MLX benchmark")
|
||||
return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": "MLX not available"}
|
||||
return {
|
||||
"max_throughput": 0.0,
|
||||
"avg_throughput": 0.0,
|
||||
"error": "MLX not available",
|
||||
}
|
||||
|
||||
config = BenchmarkConfig(model_path="mlx-community/all-MiniLM-L6-v2-4bit", use_mlx=True)
|
||||
|
||||
@@ -265,7 +273,11 @@ def run_mlx_benchmark():
|
||||
results = benchmark.run()
|
||||
|
||||
if not results:
|
||||
return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": "No valid results"}
|
||||
return {
|
||||
"max_throughput": 0.0,
|
||||
"avg_throughput": 0.0,
|
||||
"error": "No valid results",
|
||||
}
|
||||
|
||||
max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
|
||||
avg_throughput = np.mean([results[batch_size]["throughput"] for batch_size in results])
|
||||
@@ -1,5 +1,5 @@
|
||||
The Project Gutenberg eBook of Pride and Prejudice
|
||||
|
||||
|
||||
This ebook is for the use of anyone anywhere in the United States and
|
||||
most other parts of the world at no cost and with almost no restrictions
|
||||
whatsoever. You may copy it, give it away or re-use it under the terms
|
||||
@@ -14557,7 +14557,7 @@ her into Derbyshire, had been the means of uniting them.
|
||||
*** END OF THE PROJECT GUTENBERG EBOOK PRIDE AND PREJUDICE ***
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Updated editions will replace the previous one—the old editions will
|
||||
be renamed.
|
||||
@@ -14662,7 +14662,7 @@ performed, viewed, copied or distributed:
|
||||
at www.gutenberg.org. If you
|
||||
are not located in the United States, you will have to check the laws
|
||||
of the country where you are located before using this eBook.
|
||||
|
||||
|
||||
1.E.2. If an individual Project Gutenberg™ electronic work is
|
||||
derived from texts not protected by U.S. copyright law (does not
|
||||
contain a notice indicating that it is posted with permission of the
|
||||
@@ -14724,7 +14724,7 @@ provided that:
|
||||
Gutenberg Literary Archive Foundation at the address specified in
|
||||
Section 4, “Information about donations to the Project Gutenberg
|
||||
Literary Archive Foundation.”
|
||||
|
||||
|
||||
• You provide a full refund of any money paid by a user who notifies
|
||||
you in writing (or by e-mail) within 30 days of receipt that s/he
|
||||
does not agree to the terms of the full Project Gutenberg™
|
||||
@@ -14732,15 +14732,15 @@ provided that:
|
||||
copies of the works possessed in a physical medium and discontinue
|
||||
all use of and all access to other copies of Project Gutenberg™
|
||||
works.
|
||||
|
||||
|
||||
• You provide, in accordance with paragraph 1.F.3, a full refund of
|
||||
any money paid for a work or a replacement copy, if a defect in the
|
||||
electronic work is discovered and reported to you within 90 days of
|
||||
receipt of the work.
|
||||
|
||||
|
||||
• You comply with all other terms of this agreement for free
|
||||
distribution of Project Gutenberg™ works.
|
||||
|
||||
|
||||
|
||||
1.E.9. If you wish to charge a fee or distribute a Project
|
||||
Gutenberg™ electronic work or group of works on different terms than
|
||||
@@ -14903,5 +14903,3 @@ This website includes information about Project Gutenberg™,
|
||||
including how to make donations to the Project Gutenberg Literary
|
||||
Archive Foundation, how to help produce our new eBooks, and how to
|
||||
subscribe to our email newsletter to hear about new eBooks.
|
||||
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
---
|
||||
license: mit
|
||||
---
|
||||
|
||||
# LEANN-RAG Evaluation Data
|
||||
|
||||
This repository contains the necessary data to run the recall evaluation scripts for the [LEANN-RAG](https://huggingface.co/LEANN-RAG) project.
|
||||
|
||||
## Dataset Components
|
||||
|
||||
This dataset is structured into three main parts:
|
||||
|
||||
1. **Pre-built LEANN Indices**:
|
||||
* `dpr/`: A pre-built index for the DPR dataset.
|
||||
* `rpj_wiki/`: A pre-built index for the RPJ-Wiki dataset.
|
||||
These indices were created using the `leann-core` library and are required by the `LeannSearcher`.
|
||||
|
||||
2. **Ground Truth Data**:
|
||||
* `ground_truth/`: Contains the ground truth files (`flat_results_nq_k3.json`) for both the DPR and RPJ-Wiki datasets. These files map queries to the original passage IDs from the Natural Questions benchmark, evaluated using the Contriever model.
|
||||
|
||||
3. **Queries**:
|
||||
* `queries/`: Contains the `nq_open.jsonl` file with the Natural Questions queries used for the evaluation.
|
||||
|
||||
## Usage
|
||||
|
||||
To use this data, you can download it locally using the `huggingface-hub` library. First, install the library:
|
||||
|
||||
```bash
|
||||
pip install huggingface-hub
|
||||
```
|
||||
|
||||
Then, you can download the entire dataset to a local directory (e.g., `data/`) with the following Python script:
|
||||
|
||||
```python
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
snapshot_download(
|
||||
repo_id="LEANN-RAG/leann-rag-evaluation-data",
|
||||
repo_type="dataset",
|
||||
local_dir="data"
|
||||
)
|
||||
```
|
||||
|
||||
This will download all the necessary files into a local `data` folder, preserving the repository structure. The evaluation scripts in the main [LEANN-RAG Space](https://huggingface.co/LEANN-RAG) are configured to work with this data structure.
|
||||
141
demo.ipynb
141
demo.ipynb
@@ -4,7 +4,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Quick Start in 30s\n",
|
||||
"# Quick Start \n",
|
||||
"\n",
|
||||
"**Home GitHub Repository:** [LEANN on GitHub](https://github.com/yichuan-w/LEANN)\n",
|
||||
"\n",
|
||||
@@ -49,68 +49,7 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Writing passages: 100%|██████████| 5/5 [00:00<00:00, 17077.79chunk/s]\n",
|
||||
"Batches: 100%|██████████| 1/1 [00:00<00:00, 36.43it/s]\n",
|
||||
"WARNING:leann_backend_hnsw.hnsw_backend:Converting data to float32, shape: (5, 768)\n",
|
||||
"INFO:leann_backend_hnsw.hnsw_backend:INFO: Converting HNSW index to CSR-pruned format...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"M: 64 for level: 0\n",
|
||||
"Starting conversion: index.index -> index.csr.tmp\n",
|
||||
"[0.00s] Reading Index HNSW header...\n",
|
||||
"[0.00s] Header read: d=768, ntotal=5\n",
|
||||
"[0.00s] Reading HNSW struct vectors...\n",
|
||||
" Reading vector (dtype=<class 'numpy.float64'>, fmt='d')... Count=6, Bytes=48\n",
|
||||
"[0.00s] Read assign_probas (6)\n",
|
||||
" Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=7, Bytes=28\n",
|
||||
"[0.14s] Read cum_nneighbor_per_level (7)\n",
|
||||
" Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=5, Bytes=20\n",
|
||||
"[0.24s] Read levels (5)\n",
|
||||
"[0.33s] Probing for compact storage flag...\n",
|
||||
"[0.33s] Found compact flag: False\n",
|
||||
"[0.33s] Compact flag is False, reading original format...\n",
|
||||
"[0.33s] Probing for potential extra byte before non-compact offsets...\n",
|
||||
"[0.33s] Found and consumed an unexpected 0x00 byte.\n",
|
||||
" Reading vector (dtype=<class 'numpy.uint64'>, fmt='Q')... Count=6, Bytes=48\n",
|
||||
"[0.33s] Read offsets (6)\n",
|
||||
"[0.41s] Attempting to read neighbors vector...\n",
|
||||
" Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=320, Bytes=1280\n",
|
||||
"[0.41s] Read neighbors (320)\n",
|
||||
"[0.54s] Read scalar params (ep=4, max_lvl=0)\n",
|
||||
"[0.54s] Checking for storage data...\n",
|
||||
"[0.54s] Found storage fourcc: 49467849.\n",
|
||||
"[0.54s] Converting to CSR format...\n",
|
||||
"[0.54s] Conversion loop finished. \n",
|
||||
"[0.54s] Running validation checks...\n",
|
||||
" Checking total valid neighbor count...\n",
|
||||
" OK: Total valid neighbors = 20\n",
|
||||
" Checking final pointer indices...\n",
|
||||
" OK: Final pointers match data size.\n",
|
||||
"[0.54s] Deleting original neighbors and offsets arrays...\n",
|
||||
" CSR Stats: |data|=20, |level_ptr|=10\n",
|
||||
"[0.63s] Writing CSR HNSW graph data in FAISS-compatible order...\n",
|
||||
" Pruning embeddings: Writing NULL storage marker.\n",
|
||||
"[0.71s] Conversion complete.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:leann_backend_hnsw.hnsw_backend:✅ CSR conversion successful.\n",
|
||||
"INFO:leann_backend_hnsw.hnsw_backend:INFO: Replaced original index with CSR-pruned version at 'index.index'\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from leann.api import LeannBuilder\n",
|
||||
"\n",
|
||||
@@ -136,81 +75,7 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:leann.api:🔍 LeannSearcher.search() called:\n",
|
||||
"INFO:leann.api: Query: 'programming languages'\n",
|
||||
"INFO:leann.api: Top_k: 2\n",
|
||||
"INFO:leann.api: Additional kwargs: {}\n",
|
||||
"INFO:leann.embedding_server_manager:Port 5557 has incompatible server, trying next port...\n",
|
||||
"INFO:leann.embedding_server_manager:Port 5558 has incompatible server, trying next port...\n",
|
||||
"INFO:leann.embedding_server_manager:Port 5559 has incompatible server, trying next port...\n",
|
||||
"INFO:leann.embedding_server_manager:Port 5560 has incompatible server, trying next port...\n",
|
||||
"INFO:leann.embedding_server_manager:Port 5561 has incompatible server, trying next port...\n",
|
||||
"INFO:leann.embedding_server_manager:Port 5562 has incompatible server, trying next port...\n",
|
||||
"INFO:leann.embedding_server_manager:Starting embedding server on port 5563...\n",
|
||||
"INFO:leann.embedding_server_manager:Command: /Users/yichuan/Desktop/code/test_leann_pip/LEANN/.venv/bin/python -m leann_backend_hnsw.hnsw_embedding_server --zmq-port 5563 --model-name facebook/contriever --passages-file /Users/yichuan/Desktop/code/test_leann_pip/LEANN/content/index.meta.json\n",
|
||||
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
|
||||
"To disable this warning, you can either:\n",
|
||||
"\t- Avoid using `tokenizers` before the fork if possible\n",
|
||||
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
|
||||
"INFO:leann.embedding_server_manager:Server process started with PID: 31699\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[read_HNSW - CSR NL v4] Reading metadata & CSR indices (manual offset)...\n",
|
||||
"[read_HNSW NL v4] Read levels vector, size: 5\n",
|
||||
"[read_HNSW NL v4] Reading Compact Storage format indices...\n",
|
||||
"[read_HNSW NL v4] Read compact_level_ptr, size: 10\n",
|
||||
"[read_HNSW NL v4] Read compact_node_offsets, size: 6\n",
|
||||
"[read_HNSW NL v4] Read entry_point: 4, max_level: 0\n",
|
||||
"[read_HNSW NL v4] Read storage fourcc: 0x6c6c756e\n",
|
||||
"[read_HNSW NL v4 FIX] Detected FileIOReader. Neighbors size field offset: 326\n",
|
||||
"[read_HNSW NL v4] Reading neighbors data into memory.\n",
|
||||
"[read_HNSW NL v4] Read neighbors data, size: 20\n",
|
||||
"[read_HNSW NL v4] Finished reading metadata and CSR indices.\n",
|
||||
"INFO: Skipping external storage loading, since is_recompute is true.\n",
|
||||
"INFO: Registering backend 'hnsw'\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Traceback (most recent call last):\n",
|
||||
" File \"<frozen runpy>\", line 198, in _run_module_as_main\n",
|
||||
" File \"<frozen runpy>\", line 88, in _run_code\n",
|
||||
" File \"/Users/yichuan/Desktop/code/test_leann_pip/LEANN/.venv/lib/python3.11/site-packages/leann_backend_hnsw/hnsw_embedding_server.py\", line 323, in <module>\n",
|
||||
" create_hnsw_embedding_server(\n",
|
||||
" File \"/Users/yichuan/Desktop/code/test_leann_pip/LEANN/.venv/lib/python3.11/site-packages/leann_backend_hnsw/hnsw_embedding_server.py\", line 98, in create_hnsw_embedding_server\n",
|
||||
" passages = PassageManager(passage_sources)\n",
|
||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
||||
" File \"/Users/yichuan/Desktop/code/test_leann_pip/LEANN/.venv/lib/python3.11/site-packages/leann/api.py\", line 127, in __init__\n",
|
||||
" raise FileNotFoundError(f\"Passage index file not found: {index_file}\")\n",
|
||||
"FileNotFoundError: Passage index file not found: /Users/yichuan/Desktop/code/test_leann_pip/LEANN/index.passages.idx\n",
|
||||
"ERROR:leann.embedding_server_manager:Server terminated during startup.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"ename": "RuntimeError",
|
||||
"evalue": "Failed to start embedding server on port 5563",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
|
||||
"\u001b[31mRuntimeError\u001b[39m Traceback (most recent call last)",
|
||||
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[4]\u001b[39m\u001b[32m, line 4\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mleann\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mapi\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m LeannSearcher\n\u001b[32m 3\u001b[39m searcher = LeannSearcher(\u001b[33m\"\u001b[39m\u001b[33mindex\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m----> \u001b[39m\u001b[32m4\u001b[39m results = \u001b[43msearcher\u001b[49m\u001b[43m.\u001b[49m\u001b[43msearch\u001b[49m\u001b[43m(\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mprogramming languages\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtop_k\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m2\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[32m 5\u001b[39m results\n",
|
||||
"\u001b[36mFile \u001b[39m\u001b[32m~/Desktop/code/test_leann_pip/LEANN/.venv/lib/python3.11/site-packages/leann/api.py:439\u001b[39m, in \u001b[36mLeannSearcher.search\u001b[39m\u001b[34m(self, query, top_k, complexity, beam_width, prune_ratio, recompute_embeddings, pruning_strategy, expected_zmq_port, **kwargs)\u001b[39m\n\u001b[32m 437\u001b[39m start_time = time.time()\n\u001b[32m 438\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m recompute_embeddings:\n\u001b[32m--> \u001b[39m\u001b[32m439\u001b[39m zmq_port = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mbackend_impl\u001b[49m\u001b[43m.\u001b[49m\u001b[43m_ensure_server_running\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 440\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mmeta_path_str\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 441\u001b[39m \u001b[43m \u001b[49m\u001b[43mport\u001b[49m\u001b[43m=\u001b[49m\u001b[43mexpected_zmq_port\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 442\u001b[39m \u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 443\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 444\u001b[39m \u001b[38;5;28;01mdel\u001b[39;00m expected_zmq_port\n\u001b[32m 445\u001b[39m zmq_time = time.time() - start_time\n",
|
||||
"\u001b[36mFile \u001b[39m\u001b[32m~/Desktop/code/test_leann_pip/LEANN/.venv/lib/python3.11/site-packages/leann/searcher_base.py:81\u001b[39m, in \u001b[36mBaseSearcher._ensure_server_running\u001b[39m\u001b[34m(self, passages_source_file, port, **kwargs)\u001b[39m\n\u001b[32m 72\u001b[39m server_started, actual_port = \u001b[38;5;28mself\u001b[39m.embedding_server_manager.start_server(\n\u001b[32m 73\u001b[39m port=port,\n\u001b[32m 74\u001b[39m model_name=\u001b[38;5;28mself\u001b[39m.embedding_model,\n\u001b[32m (...)\u001b[39m\u001b[32m 78\u001b[39m enable_warmup=kwargs.get(\u001b[33m\"\u001b[39m\u001b[33menable_warmup\u001b[39m\u001b[33m\"\u001b[39m, \u001b[38;5;28;01mFalse\u001b[39;00m),\n\u001b[32m 79\u001b[39m )\n\u001b[32m 80\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m server_started:\n\u001b[32m---> \u001b[39m\u001b[32m81\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[32m 82\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mFailed to start embedding server on port \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mactual_port\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m\n\u001b[32m 83\u001b[39m )\n\u001b[32m 85\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m actual_port\n",
|
||||
"\u001b[31mRuntimeError\u001b[39m: Failed to start embedding server on port 5563"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from leann.api import LeannSearcher\n",
|
||||
"\n",
|
||||
|
||||
220
docs/CONTRIBUTING.md
Normal file
220
docs/CONTRIBUTING.md
Normal file
@@ -0,0 +1,220 @@
|
||||
# 🤝 Contributing
|
||||
|
||||
We welcome contributions! Leann is built by the community, for the community.
|
||||
|
||||
## Ways to Contribute
|
||||
|
||||
- 🐛 **Bug Reports**: Found an issue? Let us know!
|
||||
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
|
||||
- 🔧 **Code Contributions**: PRs welcome for all skill levels
|
||||
- 📖 **Documentation**: Help make Leann more accessible
|
||||
- 🧪 **Benchmarks**: Share your performance results
|
||||
|
||||
## 🚀 Development Setup
|
||||
|
||||
### Prerequisites
|
||||
|
||||
1. **Install uv** (fast Python package installer):
|
||||
```bash
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
```
|
||||
|
||||
2. **Clone the repository**:
|
||||
```bash
|
||||
git clone https://github.com/LEANN-RAG/LEANN-RAG.git
|
||||
cd LEANN-RAG
|
||||
```
|
||||
|
||||
3. **Install system dependencies**:
|
||||
|
||||
**macOS:**
|
||||
```bash
|
||||
brew install llvm libomp boost protobuf zeromq pkgconf
|
||||
```
|
||||
|
||||
**Ubuntu/Debian:**
|
||||
```bash
|
||||
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler \
|
||||
libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
|
||||
```
|
||||
|
||||
4. **Build from source**:
|
||||
```bash
|
||||
# macOS
|
||||
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync
|
||||
|
||||
# Ubuntu/Debian
|
||||
uv sync
|
||||
```
|
||||
|
||||
## 🔨 Pre-commit Hooks
|
||||
|
||||
We use pre-commit hooks to ensure code quality and consistency. This runs automatically before each commit.
|
||||
|
||||
### Setup Pre-commit
|
||||
|
||||
1. **Install pre-commit** (already included when you run `uv sync`):
|
||||
```bash
|
||||
uv pip install pre-commit
|
||||
```
|
||||
|
||||
2. **Install the git hooks**:
|
||||
```bash
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
3. **Run pre-commit manually** (optional):
|
||||
```bash
|
||||
pre-commit run --all-files
|
||||
```
|
||||
|
||||
### Pre-commit Checks
|
||||
|
||||
Our pre-commit configuration includes:
|
||||
- **Trailing whitespace removal**
|
||||
- **End-of-file fixing**
|
||||
- **YAML validation**
|
||||
- **Large file prevention**
|
||||
- **Merge conflict detection**
|
||||
- **Debug statement detection**
|
||||
- **Code formatting with ruff**
|
||||
- **Code linting with ruff**
|
||||
|
||||
## 🧪 Testing
|
||||
|
||||
### Running Tests
|
||||
|
||||
```bash
|
||||
# Run all tests
|
||||
uv run pytest
|
||||
|
||||
# Run specific test file
|
||||
uv run pytest test/test_filename.py
|
||||
|
||||
# Run with coverage
|
||||
uv run pytest --cov=leann
|
||||
```
|
||||
|
||||
### Writing Tests
|
||||
|
||||
- Place tests in the `test/` directory
|
||||
- Follow the naming convention `test_*.py`
|
||||
- Use descriptive test names that explain what's being tested
|
||||
- Include both positive and negative test cases
|
||||
|
||||
## 📝 Code Style
|
||||
|
||||
We use `ruff` for both linting and formatting to ensure consistent code style.
|
||||
|
||||
### Format Your Code
|
||||
|
||||
```bash
|
||||
# Format all files
|
||||
ruff format
|
||||
|
||||
# Check formatting without changing files
|
||||
ruff format --check
|
||||
```
|
||||
|
||||
### Lint Your Code
|
||||
|
||||
```bash
|
||||
# Run linter with auto-fix
|
||||
ruff check --fix
|
||||
|
||||
# Just check without fixing
|
||||
ruff check
|
||||
```
|
||||
|
||||
### Style Guidelines
|
||||
|
||||
- Follow PEP 8 conventions
|
||||
- Use descriptive variable names
|
||||
- Add type hints where appropriate
|
||||
- Write docstrings for all public functions and classes
|
||||
- Keep functions focused and single-purpose
|
||||
|
||||
## 🚦 CI/CD
|
||||
|
||||
Our CI pipeline runs automatically on all pull requests. It includes:
|
||||
|
||||
1. **Linting and Formatting**: Ensures code follows our style guidelines
|
||||
2. **Multi-platform builds**: Tests on Ubuntu and macOS
|
||||
3. **Python version matrix**: Tests on Python 3.9-3.13
|
||||
4. **Wheel building**: Ensures packages can be built and distributed
|
||||
|
||||
### CI Commands
|
||||
|
||||
The CI uses the same commands as pre-commit to ensure consistency:
|
||||
```bash
|
||||
# Linting
|
||||
ruff check .
|
||||
|
||||
# Format checking
|
||||
ruff format --check .
|
||||
```
|
||||
|
||||
Make sure your code passes these checks locally before pushing!
|
||||
|
||||
## 🔄 Pull Request Process
|
||||
|
||||
1. **Fork the repository** and create your branch from `main`:
|
||||
```bash
|
||||
git checkout -b feature/your-feature-name
|
||||
```
|
||||
|
||||
2. **Make your changes**:
|
||||
- Write clean, documented code
|
||||
- Add tests for new functionality
|
||||
- Update documentation as needed
|
||||
|
||||
3. **Run pre-commit checks**:
|
||||
```bash
|
||||
pre-commit run --all-files
|
||||
```
|
||||
|
||||
4. **Test your changes**:
|
||||
```bash
|
||||
uv run pytest
|
||||
```
|
||||
|
||||
5. **Commit with descriptive messages**:
|
||||
```bash
|
||||
git commit -m "feat: add new search algorithm"
|
||||
```
|
||||
|
||||
Follow [Conventional Commits](https://www.conventionalcommits.org/):
|
||||
- `feat:` for new features
|
||||
- `fix:` for bug fixes
|
||||
- `docs:` for documentation changes
|
||||
- `test:` for test additions/changes
|
||||
- `refactor:` for code refactoring
|
||||
- `perf:` for performance improvements
|
||||
|
||||
6. **Push and create a pull request**:
|
||||
- Provide a clear description of your changes
|
||||
- Reference any related issues
|
||||
- Include examples or screenshots if applicable
|
||||
|
||||
## 📚 Documentation
|
||||
|
||||
When adding new features or making significant changes:
|
||||
|
||||
1. Update relevant documentation in `/docs`
|
||||
2. Add docstrings to new functions/classes
|
||||
3. Update README.md if needed
|
||||
4. Include usage examples
|
||||
|
||||
## 🤔 Getting Help
|
||||
|
||||
- **Discord**: Join our community for discussions
|
||||
- **Issues**: Check existing issues or create a new one
|
||||
- **Discussions**: For general questions and ideas
|
||||
|
||||
## 📄 License
|
||||
|
||||
By contributing, you agree that your contributions will be licensed under the same license as the project (MIT).
|
||||
|
||||
---
|
||||
|
||||
Thank you for contributing to LEANN! Every contribution, no matter how small, helps make the project better for everyone. 🌟
|
||||
@@ -19,4 +19,4 @@ That's it! The workflow will automatically:
|
||||
- ✅ Publish to PyPI
|
||||
- ✅ Create GitHub tag and release
|
||||
|
||||
Check progress: https://github.com/yichuan-w/LEANN/actions
|
||||
Check progress: https://github.com/yichuan-w/LEANN/actions
|
||||
|
||||
123
docs/THINKING_BUDGET_FEATURE.md
Normal file
123
docs/THINKING_BUDGET_FEATURE.md
Normal file
@@ -0,0 +1,123 @@
|
||||
# Thinking Budget Feature Implementation
|
||||
|
||||
## Overview
|
||||
|
||||
This document describes the implementation of the **thinking budget** feature for LEANN, which allows users to control the computational effort for reasoning models like GPT-Oss:20b.
|
||||
|
||||
## Feature Description
|
||||
|
||||
The thinking budget feature provides three levels of computational effort for reasoning models:
|
||||
- **`low`**: Fast responses, basic reasoning (default for simple queries)
|
||||
- **`medium`**: Balanced speed and reasoning depth
|
||||
- **`high`**: Maximum reasoning effort, best for complex analytical questions
|
||||
|
||||
## Implementation Details
|
||||
|
||||
### 1. Command Line Interface
|
||||
|
||||
Added `--thinking-budget` parameter to both CLI and RAG examples:
|
||||
|
||||
```bash
|
||||
# LEANN CLI
|
||||
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget high
|
||||
|
||||
# RAG Examples
|
||||
python apps/email_rag.py --llm ollama --llm-model gpt-oss:20b --thinking-budget high
|
||||
python apps/document_rag.py --llm openai --llm-model o3 --thinking-budget medium
|
||||
```
|
||||
|
||||
### 2. LLM Backend Support
|
||||
|
||||
#### Ollama Backend (`packages/leann-core/src/leann/chat.py`)
|
||||
|
||||
```python
|
||||
def ask(self, prompt: str, **kwargs) -> str:
|
||||
# Handle thinking budget for reasoning models
|
||||
options = kwargs.copy()
|
||||
thinking_budget = kwargs.get("thinking_budget")
|
||||
if thinking_budget:
|
||||
options.pop("thinking_budget", None)
|
||||
if thinking_budget in ["low", "medium", "high"]:
|
||||
options["reasoning"] = {"effort": thinking_budget, "exclude": False}
|
||||
```
|
||||
|
||||
**API Format**: Uses Ollama's `reasoning` parameter with `effort` and `exclude` fields.
|
||||
|
||||
#### OpenAI Backend (`packages/leann-core/src/leann/chat.py`)
|
||||
|
||||
```python
|
||||
def ask(self, prompt: str, **kwargs) -> str:
|
||||
# Handle thinking budget for reasoning models
|
||||
thinking_budget = kwargs.get("thinking_budget")
|
||||
if thinking_budget and thinking_budget in ["low", "medium", "high"]:
|
||||
# Check if this is an o-series model
|
||||
o_series_models = ["o3", "o3-mini", "o4-mini", "o1", "o3-pro", "o3-deep-research"]
|
||||
if any(model in self.model for model in o_series_models):
|
||||
params["reasoning_effort"] = thinking_budget
|
||||
```
|
||||
|
||||
**API Format**: Uses OpenAI's `reasoning_effort` parameter for o-series models.
|
||||
|
||||
### 3. Parameter Propagation
|
||||
|
||||
The thinking budget parameter is properly propagated through the LEANN architecture:
|
||||
|
||||
1. **CLI** (`packages/leann-core/src/leann/cli.py`): Captures `--thinking-budget` argument
|
||||
2. **Base RAG** (`apps/base_rag_example.py`): Adds parameter to argument parser
|
||||
3. **LeannChat** (`packages/leann-core/src/leann/api.py`): Passes `llm_kwargs` to LLM
|
||||
4. **LLM Interface**: Handles the parameter in backend-specific implementations
|
||||
|
||||
## Files Modified
|
||||
|
||||
### Core Implementation
|
||||
- `packages/leann-core/src/leann/chat.py`: Added thinking budget support to OllamaChat and OpenAIChat
|
||||
- `packages/leann-core/src/leann/cli.py`: Added `--thinking-budget` argument
|
||||
- `apps/base_rag_example.py`: Added thinking budget parameter to RAG examples
|
||||
|
||||
### Documentation
|
||||
- `README.md`: Added thinking budget parameter to usage examples
|
||||
- `docs/configuration-guide.md`: Added detailed documentation and usage guidelines
|
||||
|
||||
### Examples
|
||||
- `examples/thinking_budget_demo.py`: Comprehensive demo script with usage examples
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Basic Usage
|
||||
```bash
|
||||
# High reasoning effort for complex questions
|
||||
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget high
|
||||
|
||||
# Medium reasoning for balanced performance
|
||||
leann ask my-index --llm openai --model gpt-4o --thinking-budget medium
|
||||
|
||||
# Low reasoning for fast responses
|
||||
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget low
|
||||
```
|
||||
|
||||
### RAG Examples
|
||||
```bash
|
||||
# Email RAG with high reasoning
|
||||
python apps/email_rag.py --llm ollama --llm-model gpt-oss:20b --thinking-budget high
|
||||
|
||||
# Document RAG with medium reasoning
|
||||
python apps/document_rag.py --llm openai --llm-model gpt-4o --thinking-budget medium
|
||||
```
|
||||
|
||||
## Supported Models
|
||||
|
||||
### Ollama Models
|
||||
- **GPT-Oss:20b**: Primary target model with reasoning capabilities
|
||||
- **Other reasoning models**: Any Ollama model that supports the `reasoning` parameter
|
||||
|
||||
### OpenAI Models
|
||||
- **o3, o3-mini, o4-mini, o1**: o-series reasoning models with `reasoning_effort` parameter
|
||||
- **GPT-OSS models**: Models that support reasoning capabilities
|
||||
|
||||
## Testing
|
||||
|
||||
The implementation includes comprehensive testing:
|
||||
- Parameter handling verification
|
||||
- Backend-specific API format validation
|
||||
- CLI argument parsing tests
|
||||
- Integration with existing LEANN architecture
|
||||
98
docs/code/embedding_model_compare.py
Normal file
98
docs/code/embedding_model_compare.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""
|
||||
Comparison between Sentence Transformers and OpenAI embeddings
|
||||
|
||||
This example shows how different embedding models handle complex queries
|
||||
and demonstrates the differences between local and API-based embeddings.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from leann.embedding_compute import compute_embeddings
|
||||
|
||||
# OpenAI API key should be set as environment variable
|
||||
# export OPENAI_API_KEY="your-api-key-here"
|
||||
|
||||
# Test data
|
||||
conference_text = "[Title]: COLING 2025 Conference\n[URL]: https://coling2025.org/"
|
||||
browser_text = "[Title]: Browser Use Tool\n[URL]: https://github.com/browser-use"
|
||||
|
||||
# Two queries with same intent but different wording
|
||||
query1 = "Tell me my browser history about some conference i often visit"
|
||||
query2 = "browser history about conference I often visit"
|
||||
|
||||
texts = [query1, query2, conference_text, browser_text]
|
||||
|
||||
|
||||
def cosine_similarity(a, b):
|
||||
return np.dot(a, b) # Already normalized
|
||||
|
||||
|
||||
def analyze_embeddings(embeddings, model_name):
|
||||
print(f"\n=== {model_name} Results ===")
|
||||
|
||||
# Results for Query 1
|
||||
sim1_conf = cosine_similarity(embeddings[0], embeddings[2])
|
||||
sim1_browser = cosine_similarity(embeddings[0], embeddings[3])
|
||||
|
||||
print(f"Query 1: '{query1}'")
|
||||
print(f" → Conference similarity: {sim1_conf:.4f} {'✓' if sim1_conf > sim1_browser else ''}")
|
||||
print(
|
||||
f" → Browser similarity: {sim1_browser:.4f} {'✓' if sim1_browser > sim1_conf else ''}"
|
||||
)
|
||||
print(f" Winner: {'Conference' if sim1_conf > sim1_browser else 'Browser'}")
|
||||
|
||||
# Results for Query 2
|
||||
sim2_conf = cosine_similarity(embeddings[1], embeddings[2])
|
||||
sim2_browser = cosine_similarity(embeddings[1], embeddings[3])
|
||||
|
||||
print(f"\nQuery 2: '{query2}'")
|
||||
print(f" → Conference similarity: {sim2_conf:.4f} {'✓' if sim2_conf > sim2_browser else ''}")
|
||||
print(
|
||||
f" → Browser similarity: {sim2_browser:.4f} {'✓' if sim2_browser > sim2_conf else ''}"
|
||||
)
|
||||
print(f" Winner: {'Conference' if sim2_conf > sim2_browser else 'Browser'}")
|
||||
|
||||
# Show the impact
|
||||
print("\n=== Impact Analysis ===")
|
||||
print(f"Conference similarity change: {sim2_conf - sim1_conf:+.4f}")
|
||||
print(f"Browser similarity change: {sim2_browser - sim1_browser:+.4f}")
|
||||
|
||||
if sim1_conf > sim1_browser and sim2_browser > sim2_conf:
|
||||
print("❌ FLIP: Adding 'browser history' flips winner from Conference to Browser!")
|
||||
elif sim1_conf > sim1_browser and sim2_conf > sim2_browser:
|
||||
print("✅ STABLE: Conference remains winner in both queries")
|
||||
elif sim1_browser > sim1_conf and sim2_browser > sim2_conf:
|
||||
print("✅ STABLE: Browser remains winner in both queries")
|
||||
else:
|
||||
print("🔄 MIXED: Results vary between queries")
|
||||
|
||||
return {
|
||||
"query1_conf": sim1_conf,
|
||||
"query1_browser": sim1_browser,
|
||||
"query2_conf": sim2_conf,
|
||||
"query2_browser": sim2_browser,
|
||||
}
|
||||
|
||||
|
||||
# Test Sentence Transformers
|
||||
print("Testing Sentence Transformers (facebook/contriever)...")
|
||||
try:
|
||||
st_embeddings = compute_embeddings(texts, "facebook/contriever", mode="sentence-transformers")
|
||||
st_results = analyze_embeddings(st_embeddings, "Sentence Transformers (facebook/contriever)")
|
||||
except Exception as e:
|
||||
print(f"❌ Sentence Transformers failed: {e}")
|
||||
st_results = None
|
||||
|
||||
# Test OpenAI
|
||||
print("\n" + "=" * 60)
|
||||
print("Testing OpenAI (text-embedding-3-small)...")
|
||||
try:
|
||||
openai_embeddings = compute_embeddings(texts, "text-embedding-3-small", mode="openai")
|
||||
openai_results = analyze_embeddings(openai_embeddings, "OpenAI (text-embedding-3-small)")
|
||||
except Exception as e:
|
||||
print(f"❌ OpenAI failed: {e}")
|
||||
openai_results = None
|
||||
|
||||
# Compare results
|
||||
if st_results and openai_results:
|
||||
print("\n" + "=" * 60)
|
||||
print("=== COMPARISON SUMMARY ===")
|
||||
300
docs/configuration-guide.md
Normal file
300
docs/configuration-guide.md
Normal file
@@ -0,0 +1,300 @@
|
||||
# LEANN Configuration Guide
|
||||
|
||||
This guide helps you optimize LEANN for different use cases and understand the trade-offs between various configuration options.
|
||||
|
||||
## Getting Started: Simple is Better
|
||||
|
||||
When first trying LEANN, start with a small dataset to quickly validate your approach:
|
||||
|
||||
**For document RAG**: The default `data/` directory works perfectly - includes 2 AI research papers, Pride and Prejudice literature, and a technical report
|
||||
```bash
|
||||
python -m apps.document_rag --query "What techniques does LEANN use?"
|
||||
```
|
||||
|
||||
**For other data sources**: Limit the dataset size for quick testing
|
||||
```bash
|
||||
# WeChat: Test with recent messages only
|
||||
python -m apps.wechat_rag --max-items 100 --query "What did we discuss about the project timeline?"
|
||||
|
||||
# Browser history: Last few days
|
||||
python -m apps.browser_rag --max-items 500 --query "Find documentation about vector databases"
|
||||
|
||||
# Email: Recent inbox
|
||||
python -m apps.email_rag --max-items 200 --query "Who sent updates about the deployment status?"
|
||||
```
|
||||
|
||||
Once validated, scale up gradually:
|
||||
- 100 documents → 1,000 → 10,000 → full dataset (`--max-items -1`)
|
||||
- This helps identify issues early before committing to long processing times
|
||||
|
||||
## Embedding Model Selection: Understanding the Trade-offs
|
||||
|
||||
Based on our experience developing LEANN, embedding models fall into three categories:
|
||||
|
||||
### Small Models (< 100M parameters)
|
||||
**Example**: `sentence-transformers/all-MiniLM-L6-v2` (22M params)
|
||||
- **Pros**: Lightweight, fast for both indexing and inference
|
||||
- **Cons**: Lower semantic understanding, may miss nuanced relationships
|
||||
- **Use when**: Speed is critical, handling simple queries, interactive mode, or just experimenting with LEANN. If time is not a constraint, consider using a larger/better embedding model
|
||||
|
||||
### Medium Models (100M-500M parameters)
|
||||
**Example**: `facebook/contriever` (110M params), `BAAI/bge-base-en-v1.5` (110M params)
|
||||
- **Pros**: Balanced performance, good multilingual support, reasonable speed
|
||||
- **Cons**: Requires more compute than small models
|
||||
- **Use when**: Need quality results without extreme compute requirements, general-purpose RAG applications
|
||||
|
||||
### Large Models (500M+ parameters)
|
||||
**Example**: `Qwen/Qwen3-Embedding-0.6B` (600M params), `intfloat/multilingual-e5-large` (560M params)
|
||||
- **Pros**: Best semantic understanding, captures complex relationships, excellent multilingual support. **Qwen3-Embedding-0.6B achieves nearly OpenAI API performance!**
|
||||
- **Cons**: Slower inference, longer index build times
|
||||
- **Use when**: Quality is paramount and you have sufficient compute resources. **Highly recommended** for production use
|
||||
|
||||
### Quick Start: Cloud and Local Embedding Options
|
||||
|
||||
**OpenAI Embeddings (Fastest Setup)**
|
||||
For immediate testing without local model downloads:
|
||||
```bash
|
||||
# Set OpenAI embeddings (requires OPENAI_API_KEY)
|
||||
--embedding-mode openai --embedding-model text-embedding-3-small
|
||||
```
|
||||
|
||||
**Ollama Embeddings (Privacy-Focused)**
|
||||
For local embeddings with complete privacy:
|
||||
```bash
|
||||
# First, pull an embedding model
|
||||
ollama pull nomic-embed-text
|
||||
|
||||
# Use Ollama embeddings
|
||||
--embedding-mode ollama --embedding-model nomic-embed-text
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary><strong>Cloud vs Local Trade-offs</strong></summary>
|
||||
|
||||
**OpenAI Embeddings** (`text-embedding-3-small/large`)
|
||||
- **Pros**: No local compute needed, consistently fast, high quality
|
||||
- **Cons**: Requires API key, costs money, data leaves your system, [known limitations with certain languages](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
|
||||
- **When to use**: Prototyping, non-sensitive data, need immediate results
|
||||
|
||||
**Local Embeddings**
|
||||
- **Pros**: Complete privacy, no ongoing costs, full control, can sometimes outperform OpenAI embeddings
|
||||
- **Cons**: Slower than cloud APIs, requires local compute resources
|
||||
- **When to use**: Production systems, sensitive data, cost-sensitive applications
|
||||
|
||||
</details>
|
||||
|
||||
## Index Selection: Matching Your Scale
|
||||
|
||||
### HNSW (Hierarchical Navigable Small World)
|
||||
**Best for**: Small to medium datasets (< 10M vectors) - **Default and recommended for extreme low storage**
|
||||
- Full recomputation required
|
||||
- High memory usage during build phase
|
||||
- Excellent recall (95%+)
|
||||
|
||||
```bash
|
||||
# Optimal for most use cases
|
||||
--backend-name hnsw --graph-degree 32 --build-complexity 64
|
||||
```
|
||||
|
||||
### DiskANN
|
||||
**Best for**: Performance-critical applications and large datasets - **Production-ready with automatic graph partitioning**
|
||||
|
||||
**How it works:**
|
||||
- **Product Quantization (PQ) + Real-time Reranking**: Uses compressed PQ codes for fast graph traversal, then recomputes exact embeddings for final candidates
|
||||
- **Automatic Graph Partitioning**: When `is_recompute=True`, automatically partitions large indices and safely removes redundant files to save storage
|
||||
- **Superior Speed-Accuracy Trade-off**: Faster search than HNSW while maintaining high accuracy
|
||||
|
||||
**Trade-offs compared to HNSW:**
|
||||
- ✅ **Faster search latency** (typically 2-8x speedup)
|
||||
- ✅ **Better scaling** for large datasets
|
||||
- ✅ **Smart storage management** with automatic partitioning
|
||||
- ✅ **Better graph locality** with `--ldg-times` parameter for SSD optimization
|
||||
- ⚠️ **Slightly larger index size** due to PQ tables and graph metadata
|
||||
|
||||
```bash
|
||||
# Recommended for most use cases
|
||||
--backend-name diskann --graph-degree 32 --build-complexity 64
|
||||
|
||||
# For large-scale deployments
|
||||
--backend-name diskann --graph-degree 64 --build-complexity 128
|
||||
```
|
||||
|
||||
**Performance Benchmark**: Run `python benchmarks/diskann_vs_hnsw_speed_comparison.py` to compare DiskANN and HNSW on your system.
|
||||
|
||||
## LLM Selection: Engine and Model Comparison
|
||||
|
||||
### LLM Engines
|
||||
|
||||
**OpenAI** (`--llm openai`)
|
||||
- **Pros**: Best quality, consistent performance, no local resources needed
|
||||
- **Cons**: Costs money ($0.15-2.5 per million tokens), requires internet, data privacy concerns
|
||||
- **Models**: `gpt-4o-mini` (fast, cheap), `gpt-4o` (best quality), `o3` (reasoning), `o3-mini` (reasoning, cheaper)
|
||||
- **Thinking Budget**: Use `--thinking-budget low/medium/high` for o-series reasoning models (o3, o3-mini, o4-mini)
|
||||
- **Note**: Our current default, but we recommend switching to Ollama for most use cases
|
||||
|
||||
**Ollama** (`--llm ollama`)
|
||||
- **Pros**: Fully local, free, privacy-preserving, good model variety
|
||||
- **Cons**: Requires local GPU/CPU resources, slower than cloud APIs, need to install extra [ollama app](https://github.com/ollama/ollama?tab=readme-ov-file#ollama) and pre-download models by `ollama pull`
|
||||
- **Models**: `qwen3:0.6b` (ultra-fast), `qwen3:1.7b` (balanced), `qwen3:4b` (good quality), `qwen3:7b` (high quality), `deepseek-r1:1.5b` (reasoning)
|
||||
- **Thinking Budget**: Use `--thinking-budget low/medium/high` for reasoning models like GPT-Oss:20b
|
||||
|
||||
**HuggingFace** (`--llm hf`)
|
||||
- **Pros**: Free tier available, huge model selection, direct model loading (vs Ollama's server-based approach)
|
||||
- **Cons**: More complex initial setup
|
||||
- **Models**: `Qwen/Qwen3-1.7B-FP8`
|
||||
|
||||
## Parameter Tuning Guide
|
||||
|
||||
### Search Complexity Parameters
|
||||
|
||||
**`--build-complexity`** (index building)
|
||||
- Controls thoroughness during index construction
|
||||
- Higher = better recall but slower build
|
||||
- Recommendations:
|
||||
- 32: Quick prototyping
|
||||
- 64: Balanced (default)
|
||||
- 128: Production systems
|
||||
- 256: Maximum quality
|
||||
|
||||
**`--search-complexity`** (query time)
|
||||
- Controls search thoroughness
|
||||
- Higher = better results but slower
|
||||
- Recommendations:
|
||||
- 16: Fast/Interactive search
|
||||
- 32: High quality with diversity
|
||||
- 64+: Maximum accuracy
|
||||
|
||||
### Top-K Selection
|
||||
|
||||
**`--top-k`** (number of retrieved chunks)
|
||||
- More chunks = better context but slower LLM processing
|
||||
- Should be always smaller than `--search-complexity`
|
||||
- Guidelines:
|
||||
- 10-20: General questions (default: 20)
|
||||
- 30+: Complex multi-hop reasoning requiring comprehensive context
|
||||
|
||||
**Trade-off formula**:
|
||||
- Retrieval time ∝ log(n) × search_complexity
|
||||
- LLM processing time ∝ top_k × chunk_size
|
||||
- Total context = top_k × chunk_size tokens
|
||||
|
||||
### Thinking Budget for Reasoning Models
|
||||
|
||||
**`--thinking-budget`** (reasoning effort level)
|
||||
- Controls the computational effort for reasoning models
|
||||
- Options: `low`, `medium`, `high`
|
||||
- Guidelines:
|
||||
- `low`: Fast responses, basic reasoning (default for simple queries)
|
||||
- `medium`: Balanced speed and reasoning depth
|
||||
- `high`: Maximum reasoning effort, best for complex analytical questions
|
||||
- **Supported Models**:
|
||||
- **Ollama**: `gpt-oss:20b`, `gpt-oss:120b`
|
||||
- **OpenAI**: `o3`, `o3-mini`, `o4-mini`, `o1` (o-series reasoning models)
|
||||
- **Note**: Models without reasoning support will show a warning and proceed without reasoning parameters
|
||||
- **Example**: `--thinking-budget high` for complex analytical questions
|
||||
|
||||
**📖 For detailed usage examples and implementation details, check out [Thinking Budget Documentation](THINKING_BUDGET_FEATURE.md)**
|
||||
|
||||
**💡 Quick Examples:**
|
||||
```bash
|
||||
# OpenAI o-series reasoning model
|
||||
python apps/document_rag.py --query "What are the main techniques LEANN explores?" \
|
||||
--index-dir hnswbuild --backend hnsw \
|
||||
--llm openai --llm-model o3 --thinking-budget medium
|
||||
|
||||
# Ollama reasoning model
|
||||
python apps/document_rag.py --query "What are the main techniques LEANN explores?" \
|
||||
--index-dir hnswbuild --backend hnsw \
|
||||
--llm ollama --llm-model gpt-oss:20b --thinking-budget high
|
||||
```
|
||||
|
||||
### Graph Degree (HNSW/DiskANN)
|
||||
|
||||
**`--graph-degree`**
|
||||
- Number of connections per node in the graph
|
||||
- Higher = better recall but more memory
|
||||
- HNSW: 16-32 (default: 32)
|
||||
- DiskANN: 32-128 (default: 64)
|
||||
|
||||
|
||||
## Performance Optimization Checklist
|
||||
|
||||
### If Embedding is Too Slow
|
||||
|
||||
1. **Switch to smaller model**:
|
||||
```bash
|
||||
# From large model
|
||||
--embedding-model Qwen/Qwen3-Embedding-0.6B
|
||||
# To small model
|
||||
--embedding-model sentence-transformers/all-MiniLM-L6-v2
|
||||
```
|
||||
|
||||
2. **Limit dataset size for testing**:
|
||||
```bash
|
||||
--max-items 1000 # Process first 1k items only
|
||||
```
|
||||
|
||||
3. **Use MLX on Apple Silicon** (optional optimization):
|
||||
```bash
|
||||
--embedding-mode mlx --embedding-model mlx-community/Qwen3-Embedding-0.6B-8bit
|
||||
```
|
||||
MLX might not be the best choice, as we tested and found that it only offers 1.3x acceleration compared to HF, so maybe using ollama is a better choice for embedding generation
|
||||
|
||||
4. **Use Ollama**
|
||||
```bash
|
||||
--embedding-mode ollama --embedding-model nomic-embed-text
|
||||
```
|
||||
To discover additional embedding models in ollama, check out https://ollama.com/search?c=embedding or read more about embedding models at https://ollama.com/blog/embedding-models, please do check the model size that works best for you
|
||||
### If Search Quality is Poor
|
||||
|
||||
1. **Increase retrieval count**:
|
||||
```bash
|
||||
--top-k 30 # Retrieve more candidates
|
||||
```
|
||||
|
||||
2. **Upgrade embedding model**:
|
||||
```bash
|
||||
# For English
|
||||
--embedding-model BAAI/bge-base-en-v1.5
|
||||
# For multilingual
|
||||
--embedding-model intfloat/multilingual-e5-large
|
||||
```
|
||||
|
||||
## Understanding the Trade-offs
|
||||
|
||||
Every configuration choice involves trade-offs:
|
||||
|
||||
| Factor | Small/Fast | Large/Quality |
|
||||
|--------|------------|---------------|
|
||||
| Embedding Model | `all-MiniLM-L6-v2` | `Qwen/Qwen3-Embedding-0.6B` |
|
||||
| Chunk Size | 512 tokens | 128 tokens |
|
||||
| Index Type | HNSW | DiskANN |
|
||||
| LLM | `qwen3:1.7b` | `gpt-4o` |
|
||||
|
||||
The key is finding the right balance for your specific use case. Start small and simple, measure performance, then scale up only where needed.
|
||||
|
||||
## Deep Dive: Critical Configuration Decisions
|
||||
|
||||
### When to Disable Recomputation
|
||||
|
||||
LEANN's recomputation feature provides exact distance calculations but can be disabled for extreme QPS requirements:
|
||||
|
||||
```bash
|
||||
--no-recompute # Disable selective recomputation
|
||||
```
|
||||
|
||||
**Trade-offs**:
|
||||
- **With recomputation** (default): Exact distances, best quality, higher latency, minimal storage (only stores metadata, recomputes embeddings on-demand)
|
||||
- **Without recomputation**: Must store full embeddings, significantly higher memory and storage usage (10-100x more), but faster search
|
||||
|
||||
**Disable when**:
|
||||
- You have abundant storage and memory
|
||||
- Need extremely low latency (< 100ms)
|
||||
- Running a read-heavy workload where storage cost is acceptable
|
||||
|
||||
## Further Reading
|
||||
|
||||
- [Lessons Learned Developing LEANN](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
|
||||
- [LEANN Technical Paper](https://arxiv.org/abs/2506.08276)
|
||||
- [DiskANN Original Paper](https://papers.nips.cc/paper/2019/file/09853c7fb1d3f8ee67a61b6bf4a7f8e6-Paper.pdf)
|
||||
- [SSD-based Graph Partitioning](https://github.com/SonglinLife/SSD_BASED_PLAN)
|
||||
@@ -1,11 +0,0 @@
|
||||
# 🤝 Contributing
|
||||
|
||||
We welcome contributions! Leann is built by the community, for the community.
|
||||
|
||||
## Ways to Contribute
|
||||
|
||||
- 🐛 **Bug Reports**: Found an issue? Let us know!
|
||||
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
|
||||
- 🔧 **Code Contributions**: PRs welcome for all skill levels
|
||||
- 📖 **Documentation**: Help make Leann more accessible
|
||||
- 🧪 **Benchmarks**: Share your performance results
|
||||
@@ -7,4 +7,4 @@ You can speed up the process by using a lightweight embedding model. Add this to
|
||||
```bash
|
||||
--embedding-model sentence-transformers/all-MiniLM-L6-v2
|
||||
```
|
||||
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)
|
||||
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage with dynamic computation using optimized ZMQ servers and highly optimized search paradigm (overlapping and batching) with highly optimized embedding engine
|
||||
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
|
||||
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
|
||||
- **🏗️ Pluggable Backends** - DiskANN, HNSW/FAISS with unified API
|
||||
- **🏗️ Pluggable Backends** - HNSW/FAISS (default), with optional DiskANN for large-scale deployments
|
||||
|
||||
## 🛠️ Technical Highlights
|
||||
- **🔄 Recompute Mode** - Highest accuracy scenarios while eliminating vector storage overhead
|
||||
@@ -13,10 +13,10 @@
|
||||
- **🚀 High-throughput Embedding Pipeline** - Optimized batched processing for maximum efficiency
|
||||
- **🎯 Two-level Search** - Novel coarse-to-fine search overlap for accelerated query processing (optional)
|
||||
- **💾 Memory-mapped Indices** - Fast startup with raw text mapping to reduce memory overhead
|
||||
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](test/build_mlx_index.py))
|
||||
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](../examples/mlx_demo.py))
|
||||
|
||||
## 🎨 Developer Experience
|
||||
|
||||
- **Simple Python API** - Get started in minutes
|
||||
- **Extensible backend system** - Easy to add new algorithms
|
||||
- **Comprehensive examples** - From basic usage to production deployment
|
||||
- **Comprehensive examples** - From basic usage to production deployment
|
||||
|
||||
75
docs/normalized_embeddings.md
Normal file
75
docs/normalized_embeddings.md
Normal file
@@ -0,0 +1,75 @@
|
||||
# Normalized Embeddings Support in LEANN
|
||||
|
||||
LEANN now automatically detects normalized embedding models and sets the appropriate distance metric for optimal performance.
|
||||
|
||||
## What are Normalized Embeddings?
|
||||
|
||||
Normalized embeddings are vectors with L2 norm = 1 (unit vectors). These embeddings are optimized for cosine similarity rather than Maximum Inner Product Search (MIPS).
|
||||
|
||||
## Automatic Detection
|
||||
|
||||
When you create a `LeannBuilder` instance with a normalized embedding model, LEANN will:
|
||||
|
||||
1. **Automatically set `distance_metric="cosine"`** if not specified
|
||||
2. **Show a warning** if you manually specify a different distance metric
|
||||
3. **Provide optimal search performance** with the correct metric
|
||||
|
||||
## Supported Normalized Embedding Models
|
||||
|
||||
### OpenAI
|
||||
All OpenAI text embedding models are normalized:
|
||||
- `text-embedding-ada-002`
|
||||
- `text-embedding-3-small`
|
||||
- `text-embedding-3-large`
|
||||
|
||||
### Voyage AI
|
||||
All Voyage AI embedding models are normalized:
|
||||
- `voyage-2`
|
||||
- `voyage-3`
|
||||
- `voyage-large-2`
|
||||
- `voyage-multilingual-2`
|
||||
- `voyage-code-2`
|
||||
|
||||
### Cohere
|
||||
All Cohere embedding models are normalized:
|
||||
- `embed-english-v3.0`
|
||||
- `embed-multilingual-v3.0`
|
||||
- `embed-english-light-v3.0`
|
||||
- `embed-multilingual-light-v3.0`
|
||||
|
||||
## Example Usage
|
||||
|
||||
```python
|
||||
from leann.api import LeannBuilder
|
||||
|
||||
# Automatic detection - will use cosine distance
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="text-embedding-3-small",
|
||||
embedding_mode="openai"
|
||||
)
|
||||
# Warning: Detected normalized embeddings model 'text-embedding-3-small'...
|
||||
# Automatically setting distance_metric='cosine'
|
||||
|
||||
# Manual override (not recommended)
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="text-embedding-3-small",
|
||||
embedding_mode="openai",
|
||||
distance_metric="mips" # Will show warning
|
||||
)
|
||||
# Warning: Using 'mips' distance metric with normalized embeddings...
|
||||
```
|
||||
|
||||
## Non-Normalized Embeddings
|
||||
|
||||
Models like `facebook/contriever` and other sentence-transformers models that are not normalized will continue to use MIPS by default, which is optimal for them.
|
||||
|
||||
## Why This Matters
|
||||
|
||||
Using the wrong distance metric with normalized embeddings can lead to:
|
||||
- **Poor search quality** due to HNSW's early termination with narrow score ranges
|
||||
- **Incorrect ranking** of search results
|
||||
- **Suboptimal performance** compared to using the correct metric
|
||||
|
||||
For more details on why this happens, see our analysis in the [embedding detection code](../packages/leann-core/src/leann/api.py) which automatically handles normalized embeddings and MIPS distance metric issues.
|
||||
@@ -2,8 +2,8 @@
|
||||
|
||||
## 🎯 Q2 2025
|
||||
|
||||
- [X] DiskANN backend with MIPS/L2/Cosine support
|
||||
- [X] HNSW backend integration
|
||||
- [X] DiskANN backend with MIPS/L2/Cosine support
|
||||
- [X] Real-time embedding pipeline
|
||||
- [X] Memory-efficient graph pruning
|
||||
|
||||
@@ -18,4 +18,4 @@
|
||||
|
||||
- [ ] Integration with LangChain/LlamaIndex
|
||||
- [ ] Visual similarity search
|
||||
- [ ] Query rewrtiting, rerank and expansion
|
||||
- [ ] Query rewrtiting, rerank and expansion
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""
|
||||
Simple demo showing basic leann usage
|
||||
Run: uv run python examples/simple_demo.py
|
||||
Run: uv run python examples/basic_demo.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@@ -81,7 +81,7 @@ def main():
|
||||
print()
|
||||
|
||||
print("Demo completed! Try running:")
|
||||
print(" uv run python examples/document_search.py")
|
||||
print(" uv run python apps/document_rag.py")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -1,155 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Document search demo with recompute mode
|
||||
"""
|
||||
|
||||
import shutil
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
# Import backend packages to trigger plugin registration
|
||||
try:
|
||||
import leann_backend_diskann # noqa: F401
|
||||
import leann_backend_hnsw # noqa: F401
|
||||
|
||||
print("INFO: Backend packages imported successfully.")
|
||||
except ImportError as e:
|
||||
print(f"WARNING: Could not import backend packages. Error: {e}")
|
||||
|
||||
# Import upper-level API from leann-core
|
||||
from leann.api import LeannBuilder, LeannChat, LeannSearcher
|
||||
|
||||
|
||||
def load_sample_documents():
|
||||
"""Create sample documents for demonstration"""
|
||||
docs = [
|
||||
{
|
||||
"title": "Intro to Python",
|
||||
"content": "Python is a high-level, interpreted language known for simplicity.",
|
||||
},
|
||||
{"title": "ML Basics", "content": "Machine learning builds systems that learn from data."},
|
||||
{
|
||||
"title": "Data Structures",
|
||||
"content": "Data structures like arrays, lists, and graphs organize data.",
|
||||
},
|
||||
]
|
||||
return docs
|
||||
|
||||
|
||||
def main():
|
||||
print("==========================================================")
|
||||
print("=== Leann Document Search Demo (DiskANN + Recompute) ===")
|
||||
print("==========================================================")
|
||||
|
||||
INDEX_DIR = Path("./test_indices")
|
||||
INDEX_PATH = str(INDEX_DIR / "documents.diskann")
|
||||
BACKEND_TO_TEST = "diskann"
|
||||
|
||||
if INDEX_DIR.exists():
|
||||
print(f"--- Cleaning up old index directory: {INDEX_DIR} ---")
|
||||
shutil.rmtree(INDEX_DIR)
|
||||
|
||||
# --- 1. Build index ---
|
||||
print(f"\n[PHASE 1] Building index using '{BACKEND_TO_TEST}' backend...")
|
||||
|
||||
builder = LeannBuilder(backend_name=BACKEND_TO_TEST, graph_degree=32, complexity=64)
|
||||
|
||||
documents = load_sample_documents()
|
||||
print(f"Loaded {len(documents)} sample documents.")
|
||||
for doc in documents:
|
||||
builder.add_text(doc["content"], metadata={"title": doc["title"]})
|
||||
|
||||
builder.build_index(INDEX_PATH)
|
||||
print("\nIndex built!")
|
||||
|
||||
# --- 2. Basic search demo ---
|
||||
print(f"\n[PHASE 2] Basic search using '{BACKEND_TO_TEST}' backend...")
|
||||
searcher = LeannSearcher(index_path=INDEX_PATH)
|
||||
|
||||
query = "What is machine learning?"
|
||||
print(f"\nQuery: '{query}'")
|
||||
|
||||
print("\n--- Basic search mode (PQ computation) ---")
|
||||
start_time = time.time()
|
||||
results = searcher.search(query, top_k=2)
|
||||
basic_time = time.time() - start_time
|
||||
|
||||
print(f"⏱️ Basic search time: {basic_time:.3f} seconds")
|
||||
print(">>> Basic search results <<<")
|
||||
for i, res in enumerate(results, 1):
|
||||
print(
|
||||
f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}"
|
||||
)
|
||||
|
||||
# --- 3. Recompute search demo ---
|
||||
print("\n[PHASE 3] Recompute search using embedding server...")
|
||||
|
||||
print("\n--- Recompute search mode (get real embeddings via network) ---")
|
||||
|
||||
# Configure recompute parameters
|
||||
recompute_params = {
|
||||
"recompute_beighbor_embeddings": True, # Enable network recomputation
|
||||
"USE_DEFERRED_FETCH": False, # Don't use deferred fetch
|
||||
"skip_search_reorder": True, # Skip search reordering
|
||||
"dedup_node_dis": True, # Enable node distance deduplication
|
||||
"prune_ratio": 0.1, # Pruning ratio 10%
|
||||
"batch_recompute": False, # Don't use batch recomputation
|
||||
"global_pruning": False, # Don't use global pruning
|
||||
"zmq_port": 5555, # ZMQ port
|
||||
"embedding_model": "sentence-transformers/all-mpnet-base-v2",
|
||||
}
|
||||
|
||||
print("Recompute parameter configuration:")
|
||||
for key, value in recompute_params.items():
|
||||
print(f" {key}: {value}")
|
||||
|
||||
print("\n🔄 Executing Recompute search...")
|
||||
try:
|
||||
start_time = time.time()
|
||||
recompute_results = searcher.search(query, top_k=2, **recompute_params)
|
||||
recompute_time = time.time() - start_time
|
||||
|
||||
print(f"⏱️ Recompute search time: {recompute_time:.3f} seconds")
|
||||
print(">>> Recompute search results <<<")
|
||||
for i, res in enumerate(recompute_results, 1):
|
||||
print(
|
||||
f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}"
|
||||
)
|
||||
|
||||
# Compare results
|
||||
print("\n--- Result comparison ---")
|
||||
print(f"Basic search time: {basic_time:.3f} seconds")
|
||||
print(f"Recompute time: {recompute_time:.3f} seconds")
|
||||
|
||||
print("\nBasic search vs Recompute results:")
|
||||
for i in range(min(len(results), len(recompute_results))):
|
||||
basic_score = results[i].score
|
||||
recompute_score = recompute_results[i].score
|
||||
score_diff = abs(basic_score - recompute_score)
|
||||
print(
|
||||
f" Position {i + 1}: PQ={basic_score:.4f}, Recompute={recompute_score:.4f}, Difference={score_diff:.4f}"
|
||||
)
|
||||
|
||||
if recompute_time > basic_time:
|
||||
print("✅ Recompute mode working correctly (more accurate but slower)")
|
||||
else:
|
||||
print("i️ Recompute time is unusually fast, network recomputation may not be enabled")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Recompute search failed: {e}")
|
||||
print("This usually indicates an embedding server connection issue")
|
||||
|
||||
# --- 4. Chat demo ---
|
||||
print("\n[PHASE 4] Starting chat session...")
|
||||
chat = LeannChat(index_path=INDEX_PATH)
|
||||
chat_response = chat.ask(query)
|
||||
print(f"You: {query}")
|
||||
print(f"Leann: {chat_response}")
|
||||
|
||||
print("\n==========================================================")
|
||||
print("✅ Demo finished successfully!")
|
||||
print("==========================================================")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,322 +0,0 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
try:
|
||||
import dotenv
|
||||
|
||||
dotenv.load_dotenv()
|
||||
except ModuleNotFoundError:
|
||||
# python-dotenv is not installed; skip loading environment variables
|
||||
dotenv = None
|
||||
from pathlib import Path
|
||||
|
||||
from leann.api import LeannBuilder, LeannChat
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
# dotenv.load_dotenv() # handled above if python-dotenv is available
|
||||
|
||||
# Default Chrome profile path
|
||||
DEFAULT_CHROME_PROFILE = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
|
||||
|
||||
|
||||
def create_leann_index_from_multiple_chrome_profiles(
|
||||
profile_dirs: list[Path], index_path: str = "chrome_history_index.leann", max_count: int = -1
|
||||
):
|
||||
"""
|
||||
Create LEANN index from multiple Chrome profile data sources.
|
||||
|
||||
Args:
|
||||
profile_dirs: List of Path objects pointing to Chrome profile directories
|
||||
index_path: Path to save the LEANN index
|
||||
max_count: Maximum number of history entries to process per profile
|
||||
"""
|
||||
print("Creating LEANN index from multiple Chrome profile data sources...")
|
||||
|
||||
# Load documents using ChromeHistoryReader from history_data
|
||||
from history_data.history import ChromeHistoryReader
|
||||
|
||||
reader = ChromeHistoryReader()
|
||||
|
||||
INDEX_DIR = Path(index_path).parent
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print("--- Index directory not found, building new index ---")
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
# Process each Chrome profile directory
|
||||
for i, profile_dir in enumerate(profile_dirs):
|
||||
print(f"\nProcessing Chrome profile {i + 1}/{len(profile_dirs)}: {profile_dir}")
|
||||
|
||||
try:
|
||||
documents = reader.load_data(
|
||||
chrome_profile_path=str(profile_dir), max_count=max_count
|
||||
)
|
||||
if documents:
|
||||
print(f"Loaded {len(documents)} history documents from {profile_dir}")
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
|
||||
# Check if we've reached the max count
|
||||
if max_count > 0 and total_processed >= max_count:
|
||||
print(f"Reached max count of {max_count} documents")
|
||||
break
|
||||
else:
|
||||
print(f"No documents loaded from {profile_dir}")
|
||||
except Exception as e:
|
||||
print(f"Error processing {profile_dir}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No documents loaded from any source. Exiting.")
|
||||
# highlight info that you need to close all chrome browser before running this script and high light the instruction!!
|
||||
print(
|
||||
"\033[91mYou need to close or quit all chrome browser before running this script\033[0m"
|
||||
)
|
||||
return None
|
||||
|
||||
print(
|
||||
f"\nTotal loaded {len(all_documents)} history documents from {len(profile_dirs)} profiles"
|
||||
)
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
for doc in all_documents:
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
text = node.get_content()
|
||||
# text = '[Title] ' + doc.metadata["title"] + '\n' + text
|
||||
all_texts.append(text)
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
|
||||
|
||||
# Create LEANN index directory
|
||||
print("--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print("--- Building new LEANN index ---")
|
||||
|
||||
print("\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="facebook/contriever",
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1, # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Adding {len(all_texts)} history chunks to index...")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"\nLEANN index built at {index_path}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
return index_path
|
||||
|
||||
|
||||
def create_leann_index(
|
||||
profile_path: str | None = None,
|
||||
index_path: str = "chrome_history_index.leann",
|
||||
max_count: int = 1000,
|
||||
):
|
||||
"""
|
||||
Create LEANN index from Chrome history data.
|
||||
|
||||
Args:
|
||||
profile_path: Path to the Chrome profile directory (optional, uses default if None)
|
||||
index_path: Path to save the LEANN index
|
||||
max_count: Maximum number of history entries to process
|
||||
"""
|
||||
print("Creating LEANN index from Chrome history data...")
|
||||
INDEX_DIR = Path(index_path).parent
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print("--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print("--- Building new LEANN index ---")
|
||||
|
||||
print("\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Load documents using ChromeHistoryReader from history_data
|
||||
from history_data.history import ChromeHistoryReader
|
||||
|
||||
reader = ChromeHistoryReader()
|
||||
|
||||
documents = reader.load_data(chrome_profile_path=profile_path, max_count=max_count)
|
||||
|
||||
if not documents:
|
||||
print("No documents loaded. Exiting.")
|
||||
return None
|
||||
|
||||
print(f"Loaded {len(documents)} history documents")
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
|
||||
|
||||
# Create LEANN index directory
|
||||
print("--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print("--- Building new LEANN index ---")
|
||||
|
||||
print("\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="facebook/contriever",
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1, # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Adding {len(all_texts)} history chunks to index...")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"\nLEANN index built at {index_path}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
return index_path
|
||||
|
||||
|
||||
async def query_leann_index(index_path: str, query: str):
|
||||
"""
|
||||
Query the LEANN index.
|
||||
|
||||
Args:
|
||||
index_path: Path to the LEANN index
|
||||
query: The query string
|
||||
"""
|
||||
print("\n[PHASE 2] Starting Leann chat session...")
|
||||
chat = LeannChat(index_path=index_path)
|
||||
|
||||
print(f"You: {query}")
|
||||
chat_response = chat.ask(
|
||||
query,
|
||||
top_k=10,
|
||||
recompute_beighbor_embeddings=True,
|
||||
complexity=32,
|
||||
beam_width=1,
|
||||
llm_config={
|
||||
"type": "openai",
|
||||
"model": "gpt-4o",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
},
|
||||
llm_kwargs={"temperature": 0.0, "max_tokens": 1000},
|
||||
)
|
||||
|
||||
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
|
||||
|
||||
|
||||
async def main():
|
||||
# Parse command line arguments
|
||||
parser = argparse.ArgumentParser(
|
||||
description="LEANN Chrome History Reader - Create and query browser history index"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chrome-profile",
|
||||
type=str,
|
||||
default=DEFAULT_CHROME_PROFILE,
|
||||
help=f"Path to Chrome profile directory (default: {DEFAULT_CHROME_PROFILE}), usually you dont need to change this",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index-dir",
|
||||
type=str,
|
||||
default="./google_history_index",
|
||||
help="Directory to store the LEANN index (default: ./chrome_history_index_leann_test)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-entries",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Maximum number of history entries to process (default: 1000)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--query",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Single query to run (default: runs example queries)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--auto-find-profiles",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Automatically find all Chrome profiles (default: True)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
INDEX_DIR = Path(args.index_dir)
|
||||
INDEX_PATH = str(INDEX_DIR / "chrome_history.leann")
|
||||
|
||||
print(f"Using Chrome profile: {args.chrome_profile}")
|
||||
print(f"Index directory: {INDEX_DIR}")
|
||||
print(f"Max entries: {args.max_entries}")
|
||||
|
||||
# Find Chrome profile directories
|
||||
from history_data.history import ChromeHistoryReader
|
||||
|
||||
if args.auto_find_profiles:
|
||||
profile_dirs = ChromeHistoryReader.find_chrome_profiles()
|
||||
if not profile_dirs:
|
||||
print("No Chrome profiles found automatically. Exiting.")
|
||||
return
|
||||
else:
|
||||
# Use single specified profile
|
||||
profile_path = Path(args.chrome_profile)
|
||||
if not profile_path.exists():
|
||||
print(f"Chrome profile not found: {profile_path}")
|
||||
return
|
||||
profile_dirs = [profile_path]
|
||||
|
||||
# Create or load the LEANN index from all sources
|
||||
index_path = create_leann_index_from_multiple_chrome_profiles(
|
||||
profile_dirs, INDEX_PATH, args.max_entries
|
||||
)
|
||||
|
||||
if index_path:
|
||||
if args.query:
|
||||
# Run single query
|
||||
await query_leann_index(index_path, args.query)
|
||||
else:
|
||||
# Example queries
|
||||
queries = [
|
||||
"What websites did I visit about machine learning?",
|
||||
"Find my search history about programming",
|
||||
]
|
||||
|
||||
for query in queries:
|
||||
print("\n" + "=" * 60)
|
||||
await query_leann_index(index_path, query)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,338 +0,0 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import dotenv
|
||||
|
||||
# Add the project root to Python path so we can import from examples
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from leann.api import LeannBuilder, LeannChat
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
|
||||
# Auto-detect user's mail path
|
||||
def get_mail_path():
|
||||
"""Get the mail path for the current user"""
|
||||
home_dir = os.path.expanduser("~")
|
||||
return os.path.join(home_dir, "Library", "Mail")
|
||||
|
||||
|
||||
# Default mail path for macOS
|
||||
DEFAULT_MAIL_PATH = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data"
|
||||
|
||||
|
||||
def create_leann_index_from_multiple_sources(
|
||||
messages_dirs: list[Path],
|
||||
index_path: str = "mail_index.leann",
|
||||
max_count: int = -1,
|
||||
include_html: bool = False,
|
||||
embedding_model: str = "facebook/contriever",
|
||||
):
|
||||
"""
|
||||
Create LEANN index from multiple mail data sources.
|
||||
|
||||
Args:
|
||||
messages_dirs: List of Path objects pointing to Messages directories
|
||||
index_path: Path to save the LEANN index
|
||||
max_count: Maximum number of emails to process per directory
|
||||
include_html: Whether to include HTML content in email processing
|
||||
"""
|
||||
print("Creating LEANN index from multiple mail data sources...")
|
||||
|
||||
# Load documents using EmlxReader from LEANN_email_reader
|
||||
from examples.email_data.LEANN_email_reader import EmlxReader
|
||||
|
||||
reader = EmlxReader(include_html=include_html)
|
||||
# from email_data.email import EmlxMboxReader
|
||||
# from pathlib import Path
|
||||
# reader = EmlxMboxReader()
|
||||
INDEX_DIR = Path(index_path).parent
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print("--- Index directory not found, building new index ---")
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
# Process each Messages directory
|
||||
for i, messages_dir in enumerate(messages_dirs):
|
||||
print(f"\nProcessing Messages directory {i + 1}/{len(messages_dirs)}: {messages_dir}")
|
||||
|
||||
try:
|
||||
documents = reader.load_data(messages_dir)
|
||||
if documents:
|
||||
print(f"Loaded {len(documents)} email documents from {messages_dir}")
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
|
||||
# Check if we've reached the max count
|
||||
if max_count > 0 and total_processed >= max_count:
|
||||
print(f"Reached max count of {max_count} documents")
|
||||
break
|
||||
else:
|
||||
print(f"No documents loaded from {messages_dir}")
|
||||
except Exception as e:
|
||||
print(f"Error processing {messages_dir}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No documents loaded from any source. Exiting.")
|
||||
return None
|
||||
|
||||
print(
|
||||
f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories and starting to split them into chunks"
|
||||
)
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
for doc in all_documents:
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
text = node.get_content()
|
||||
# text = '[subject] ' + doc.metadata["subject"] + '\n' + text
|
||||
all_texts.append(text)
|
||||
|
||||
print(
|
||||
f"Finished splitting {len(all_documents)} documents into {len(all_texts)} text chunks"
|
||||
)
|
||||
|
||||
# Create LEANN index directory
|
||||
|
||||
print("--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print("--- Building new LEANN index ---")
|
||||
|
||||
print("\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=embedding_model,
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1, # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Adding {len(all_texts)} email chunks to index...")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"\nLEANN index built at {index_path}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
return index_path
|
||||
|
||||
|
||||
def create_leann_index(
|
||||
mail_path: str,
|
||||
index_path: str = "mail_index.leann",
|
||||
max_count: int = 1000,
|
||||
include_html: bool = False,
|
||||
embedding_model: str = "facebook/contriever",
|
||||
):
|
||||
"""
|
||||
Create LEANN index from mail data.
|
||||
|
||||
Args:
|
||||
mail_path: Path to the mail directory
|
||||
index_path: Path to save the LEANN index
|
||||
max_count: Maximum number of emails to process
|
||||
include_html: Whether to include HTML content in email processing
|
||||
"""
|
||||
print("Creating LEANN index from mail data...")
|
||||
INDEX_DIR = Path(index_path).parent
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print("--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print("--- Building new LEANN index ---")
|
||||
|
||||
print("\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Load documents using EmlxReader from LEANN_email_reader
|
||||
from examples.email_data.LEANN_email_reader import EmlxReader
|
||||
|
||||
reader = EmlxReader(include_html=include_html)
|
||||
# from email_data.email import EmlxMboxReader
|
||||
# from pathlib import Path
|
||||
# reader = EmlxMboxReader()
|
||||
documents = reader.load_data(Path(mail_path))
|
||||
|
||||
if not documents:
|
||||
print("No documents loaded. Exiting.")
|
||||
return None
|
||||
|
||||
print(f"Loaded {len(documents)} email documents")
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
|
||||
|
||||
# Create LEANN index directory
|
||||
|
||||
print("--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print("--- Building new LEANN index ---")
|
||||
|
||||
print("\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=embedding_model,
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1, # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Adding {len(all_texts)} email chunks to index...")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"\nLEANN index built at {index_path}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
return index_path
|
||||
|
||||
|
||||
async def query_leann_index(index_path: str, query: str):
|
||||
"""
|
||||
Query the LEANN index.
|
||||
|
||||
Args:
|
||||
index_path: Path to the LEANN index
|
||||
query: The query string
|
||||
"""
|
||||
print("\n[PHASE 2] Starting Leann chat session...")
|
||||
chat = LeannChat(index_path=index_path, llm_config={"type": "openai", "model": "gpt-4o"})
|
||||
|
||||
print(f"You: {query}")
|
||||
import time
|
||||
|
||||
time.time()
|
||||
chat_response = chat.ask(
|
||||
query,
|
||||
top_k=20,
|
||||
recompute_beighbor_embeddings=True,
|
||||
complexity=32,
|
||||
beam_width=1,
|
||||
)
|
||||
time.time()
|
||||
# print(f"Time taken: {end_time - start_time} seconds")
|
||||
# highlight the answer
|
||||
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
|
||||
|
||||
|
||||
async def main():
|
||||
# Parse command line arguments
|
||||
parser = argparse.ArgumentParser(description="LEANN Mail Reader - Create and query email index")
|
||||
# Remove --mail-path argument and auto-detect all Messages directories
|
||||
# Remove DEFAULT_MAIL_PATH
|
||||
parser.add_argument(
|
||||
"--index-dir",
|
||||
type=str,
|
||||
default="./mail_index",
|
||||
help="Directory to store the LEANN index (default: ./mail_index_leann_raw_text_all_dicts)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-emails",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Maximum number of emails to process (-1 means all)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--query",
|
||||
type=str,
|
||||
default="Give me some funny advertisement about apple or other companies",
|
||||
help="Single query to run (default: runs example queries)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--include-html",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Include HTML content in email processing (default: False)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-model",
|
||||
type=str,
|
||||
default="facebook/contriever",
|
||||
help="Embedding model to use (default: facebook/contriever)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"args: {args}")
|
||||
|
||||
# Automatically find all Messages directories under the current user's Mail directory
|
||||
from examples.email_data.LEANN_email_reader import find_all_messages_directories
|
||||
|
||||
mail_path = get_mail_path()
|
||||
print(f"Searching for email data in: {mail_path}")
|
||||
messages_dirs = find_all_messages_directories(mail_path)
|
||||
# messages_dirs = find_all_messages_directories(DEFAULT_MAIL_PATH)
|
||||
# messages_dirs = [DEFAULT_MAIL_PATH]
|
||||
# messages_dirs = messages_dirs[:1]
|
||||
|
||||
print("len(messages_dirs): ", len(messages_dirs))
|
||||
|
||||
if not messages_dirs:
|
||||
print("No Messages directories found. Exiting.")
|
||||
return
|
||||
|
||||
INDEX_DIR = Path(args.index_dir)
|
||||
INDEX_PATH = str(INDEX_DIR / "mail_documents.leann")
|
||||
print(f"Index directory: {INDEX_DIR}")
|
||||
print(f"Found {len(messages_dirs)} Messages directories.")
|
||||
|
||||
# Create or load the LEANN index from all sources
|
||||
index_path = create_leann_index_from_multiple_sources(
|
||||
messages_dirs, INDEX_PATH, args.max_emails, args.include_html, args.embedding_model
|
||||
)
|
||||
|
||||
if index_path:
|
||||
if args.query:
|
||||
# Run single query
|
||||
await query_leann_index(index_path, args.query)
|
||||
else:
|
||||
# Example queries
|
||||
queries = [
|
||||
"Hows Berkeley Graduate Student Instructor",
|
||||
"how's the icloud related advertisement saying",
|
||||
"Whats the number of class recommend to take per semester for incoming EECS students",
|
||||
]
|
||||
for query in queries:
|
||||
print("\n" + "=" * 60)
|
||||
await query_leann_index(index_path, query)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,129 +0,0 @@
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add the project root to Python path so we can import from examples
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
import torch
|
||||
from llama_index.core import StorageContext, VectorStoreIndex
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
# --- EMBEDDING MODEL ---
|
||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||
|
||||
# --- END EMBEDDING MODEL ---
|
||||
# Import EmlxReader from the new module
|
||||
from examples.email_data.LEANN_email_reader import EmlxReader
|
||||
|
||||
|
||||
def create_and_save_index(
|
||||
mail_path: str,
|
||||
save_dir: str = "mail_index_embedded",
|
||||
max_count: int = 1000,
|
||||
include_html: bool = False,
|
||||
):
|
||||
print("Creating index from mail data with embedded metadata...")
|
||||
documents = EmlxReader(include_html=include_html).load_data(mail_path, max_count=max_count)
|
||||
if not documents:
|
||||
print("No documents loaded. Exiting.")
|
||||
return None
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||
# Use facebook/contriever as the embedder
|
||||
embed_model = HuggingFaceEmbedding(model_name="facebook/contriever")
|
||||
# set on device
|
||||
|
||||
if torch.cuda.is_available():
|
||||
embed_model._model.to("cuda")
|
||||
# set mps
|
||||
elif torch.backends.mps.is_available():
|
||||
embed_model._model.to("mps")
|
||||
else:
|
||||
embed_model._model.to("cpu")
|
||||
index = VectorStoreIndex.from_documents(
|
||||
documents, transformations=[text_splitter], embed_model=embed_model
|
||||
)
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
index.storage_context.persist(persist_dir=save_dir)
|
||||
print(f"Index saved to {save_dir}")
|
||||
return index
|
||||
|
||||
|
||||
def load_index(save_dir: str = "mail_index_embedded"):
|
||||
try:
|
||||
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
|
||||
index = VectorStoreIndex.from_vector_store(
|
||||
storage_context.vector_store, storage_context=storage_context
|
||||
)
|
||||
print(f"Index loaded from {save_dir}")
|
||||
return index
|
||||
except Exception as e:
|
||||
print(f"Error loading index: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def query_index(index, query: str):
|
||||
if index is None:
|
||||
print("No index available for querying.")
|
||||
return
|
||||
query_engine = index.as_query_engine()
|
||||
response = query_engine.query(query)
|
||||
print(f"Query: {query}")
|
||||
print(f"Response: {response}")
|
||||
|
||||
|
||||
def main():
|
||||
# Parse command line arguments
|
||||
parser = argparse.ArgumentParser(
|
||||
description="LlamaIndex Mail Reader - Create and query email index"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mail-path",
|
||||
type=str,
|
||||
default="/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages",
|
||||
help="Path to mail data directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-dir",
|
||||
type=str,
|
||||
default="mail_index_embedded",
|
||||
help="Directory to store the index (default: mail_index_embedded)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-emails", type=int, default=10000, help="Maximum number of emails to process"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--include-html",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Include HTML content in email processing (default: False)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
mail_path = args.mail_path
|
||||
save_dir = args.save_dir
|
||||
|
||||
if os.path.exists(save_dir) and os.path.exists(os.path.join(save_dir, "vector_store.json")):
|
||||
print("Loading existing index...")
|
||||
index = load_index(save_dir)
|
||||
else:
|
||||
print("Creating new index...")
|
||||
index = create_and_save_index(
|
||||
mail_path, save_dir, max_count=args.max_emails, include_html=args.include_html
|
||||
)
|
||||
if index:
|
||||
queries = [
|
||||
"Hows Berkeley Graduate Student Instructor",
|
||||
"how's the icloud related advertisement saying",
|
||||
"Whats the number of class recommend to take per semester for incoming EECS students",
|
||||
]
|
||||
for query in queries:
|
||||
print("\n" + "=" * 50)
|
||||
query_index(index, query)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,118 +0,0 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
import dotenv
|
||||
from leann.api import LeannBuilder, LeannChat
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
|
||||
async def main(args):
|
||||
INDEX_DIR = Path(args.index_dir)
|
||||
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
node_parser = SentenceSplitter(
|
||||
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
|
||||
)
|
||||
|
||||
print("Loading documents...")
|
||||
documents = SimpleDirectoryReader(
|
||||
args.data_dir,
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=[".pdf", ".txt", ".md"],
|
||||
).load_data(show_progress=True)
|
||||
print("Documents loaded.")
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
nodes = node_parser.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
print("--- Index directory not found, building new index ---")
|
||||
|
||||
print("\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="facebook/contriever",
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1, # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Loaded {len(all_texts)} text chunks from documents.")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(INDEX_PATH)
|
||||
print(f"\nLeann index built at {INDEX_PATH}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
print("\n[PHASE 2] Starting Leann chat session...")
|
||||
|
||||
llm_config = {"type": "hf", "model": "Qwen/Qwen3-4B"}
|
||||
llm_config = {"type": "ollama", "model": "qwen3:8b"}
|
||||
llm_config = {"type": "openai", "model": "gpt-4o"}
|
||||
|
||||
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
|
||||
# query = (
|
||||
# "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
|
||||
# )
|
||||
query = args.query
|
||||
|
||||
print(f"You: {query}")
|
||||
chat_response = chat.ask(query, top_k=20, recompute_embeddings=True, complexity=32)
|
||||
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run Leann Chat with various LLM backends.")
|
||||
parser.add_argument(
|
||||
"--llm",
|
||||
type=str,
|
||||
default="hf",
|
||||
choices=["simulated", "ollama", "hf", "openai"],
|
||||
help="The LLM backend to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="Qwen/Qwen3-0.6B",
|
||||
help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf, 'gpt-4o' for openai).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default="http://localhost:11434",
|
||||
help="The host for the Ollama API.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index-dir",
|
||||
type=str,
|
||||
default="./test_doc_files",
|
||||
help="Directory where the Leann index will be stored.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data-dir",
|
||||
type=str,
|
||||
default="examples/data",
|
||||
help="Directory containing documents to index (PDF, TXT, MD files).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--query",
|
||||
type=str,
|
||||
default="Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?",
|
||||
help="The query to ask the Leann chat system.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
asyncio.run(main(args))
|
||||
@@ -1,358 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Multi-Vector Aggregator for Fat Embeddings
|
||||
==========================================
|
||||
|
||||
This module implements aggregation strategies for multi-vector embeddings,
|
||||
similar to ColPali's approach where multiple patch vectors represent a single document.
|
||||
|
||||
Key features:
|
||||
- MaxSim aggregation (take maximum similarity across patches)
|
||||
- Voting-based aggregation (count patch matches)
|
||||
- Weighted aggregation (attention-score weighted)
|
||||
- Spatial clustering of matching patches
|
||||
- Document-level result consolidation
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass
|
||||
class PatchResult:
|
||||
"""Represents a single patch search result."""
|
||||
|
||||
patch_id: int
|
||||
image_name: str
|
||||
image_path: str
|
||||
coordinates: tuple[int, int, int, int] # (x1, y1, x2, y2)
|
||||
score: float
|
||||
attention_score: float
|
||||
scale: float
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class AggregatedResult:
|
||||
"""Represents an aggregated document-level result."""
|
||||
|
||||
image_name: str
|
||||
image_path: str
|
||||
doc_score: float
|
||||
patch_count: int
|
||||
best_patch: PatchResult
|
||||
all_patches: list[PatchResult]
|
||||
aggregation_method: str
|
||||
spatial_clusters: list[list[PatchResult]] | None = None
|
||||
|
||||
|
||||
class MultiVectorAggregator:
|
||||
"""
|
||||
Aggregates multiple patch-level results into document-level results.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
aggregation_method: str = "maxsim",
|
||||
spatial_clustering: bool = True,
|
||||
cluster_distance_threshold: float = 100.0,
|
||||
):
|
||||
"""
|
||||
Initialize the aggregator.
|
||||
|
||||
Args:
|
||||
aggregation_method: "maxsim", "voting", "weighted", or "mean"
|
||||
spatial_clustering: Whether to cluster spatially close patches
|
||||
cluster_distance_threshold: Distance threshold for spatial clustering
|
||||
"""
|
||||
self.aggregation_method = aggregation_method
|
||||
self.spatial_clustering = spatial_clustering
|
||||
self.cluster_distance_threshold = cluster_distance_threshold
|
||||
|
||||
def aggregate_results(
|
||||
self, search_results: list[dict[str, Any]], top_k: int = 10
|
||||
) -> list[AggregatedResult]:
|
||||
"""
|
||||
Aggregate patch-level search results into document-level results.
|
||||
|
||||
Args:
|
||||
search_results: List of search results from LeannSearcher
|
||||
top_k: Number of top documents to return
|
||||
|
||||
Returns:
|
||||
List of aggregated document results
|
||||
"""
|
||||
# Group results by image
|
||||
image_groups = defaultdict(list)
|
||||
|
||||
for result in search_results:
|
||||
metadata = result.metadata
|
||||
if "image_name" in metadata and "patch_id" in metadata:
|
||||
patch_result = PatchResult(
|
||||
patch_id=metadata["patch_id"],
|
||||
image_name=metadata["image_name"],
|
||||
image_path=metadata["image_path"],
|
||||
coordinates=tuple(metadata["coordinates"]),
|
||||
score=result.score,
|
||||
attention_score=metadata.get("attention_score", 0.0),
|
||||
scale=metadata.get("scale", 1.0),
|
||||
metadata=metadata,
|
||||
)
|
||||
image_groups[metadata["image_name"]].append(patch_result)
|
||||
|
||||
# Aggregate each image group
|
||||
aggregated_results = []
|
||||
for image_name, patches in image_groups.items():
|
||||
if len(patches) == 0:
|
||||
continue
|
||||
|
||||
agg_result = self._aggregate_image_patches(image_name, patches)
|
||||
aggregated_results.append(agg_result)
|
||||
|
||||
# Sort by aggregated score and return top-k
|
||||
aggregated_results.sort(key=lambda x: x.doc_score, reverse=True)
|
||||
return aggregated_results[:top_k]
|
||||
|
||||
def _aggregate_image_patches(
|
||||
self, image_name: str, patches: list[PatchResult]
|
||||
) -> AggregatedResult:
|
||||
"""Aggregate patches for a single image."""
|
||||
|
||||
if self.aggregation_method == "maxsim":
|
||||
doc_score = max(patch.score for patch in patches)
|
||||
best_patch = max(patches, key=lambda p: p.score)
|
||||
|
||||
elif self.aggregation_method == "voting":
|
||||
# Count patches above threshold
|
||||
threshold = np.percentile([p.score for p in patches], 75)
|
||||
doc_score = sum(1 for patch in patches if patch.score >= threshold)
|
||||
best_patch = max(patches, key=lambda p: p.score)
|
||||
|
||||
elif self.aggregation_method == "weighted":
|
||||
# Weight by attention scores
|
||||
total_weighted_score = sum(p.score * p.attention_score for p in patches)
|
||||
total_weights = sum(p.attention_score for p in patches)
|
||||
doc_score = total_weighted_score / max(total_weights, 1e-8)
|
||||
best_patch = max(patches, key=lambda p: p.score * p.attention_score)
|
||||
|
||||
elif self.aggregation_method == "mean":
|
||||
doc_score = np.mean([patch.score for patch in patches])
|
||||
best_patch = max(patches, key=lambda p: p.score)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown aggregation method: {self.aggregation_method}")
|
||||
|
||||
# Spatial clustering if enabled
|
||||
spatial_clusters = None
|
||||
if self.spatial_clustering:
|
||||
spatial_clusters = self._cluster_patches_spatially(patches)
|
||||
|
||||
return AggregatedResult(
|
||||
image_name=image_name,
|
||||
image_path=patches[0].image_path,
|
||||
doc_score=float(doc_score),
|
||||
patch_count=len(patches),
|
||||
best_patch=best_patch,
|
||||
all_patches=sorted(patches, key=lambda p: p.score, reverse=True),
|
||||
aggregation_method=self.aggregation_method,
|
||||
spatial_clusters=spatial_clusters,
|
||||
)
|
||||
|
||||
def _cluster_patches_spatially(self, patches: list[PatchResult]) -> list[list[PatchResult]]:
|
||||
"""Cluster patches that are spatially close to each other."""
|
||||
if len(patches) <= 1:
|
||||
return [patches]
|
||||
|
||||
clusters = []
|
||||
remaining_patches = patches.copy()
|
||||
|
||||
while remaining_patches:
|
||||
# Start new cluster with highest scoring remaining patch
|
||||
seed_patch = max(remaining_patches, key=lambda p: p.score)
|
||||
current_cluster = [seed_patch]
|
||||
remaining_patches.remove(seed_patch)
|
||||
|
||||
# Add nearby patches to cluster
|
||||
added_to_cluster = True
|
||||
while added_to_cluster:
|
||||
added_to_cluster = False
|
||||
for patch in remaining_patches.copy():
|
||||
if self._is_patch_nearby(patch, current_cluster):
|
||||
current_cluster.append(patch)
|
||||
remaining_patches.remove(patch)
|
||||
added_to_cluster = True
|
||||
|
||||
clusters.append(current_cluster)
|
||||
|
||||
return sorted(clusters, key=lambda cluster: max(p.score for p in cluster), reverse=True)
|
||||
|
||||
def _is_patch_nearby(self, patch: PatchResult, cluster: list[PatchResult]) -> bool:
|
||||
"""Check if a patch is spatially close to any patch in the cluster."""
|
||||
patch_center = self._get_patch_center(patch.coordinates)
|
||||
|
||||
for cluster_patch in cluster:
|
||||
cluster_center = self._get_patch_center(cluster_patch.coordinates)
|
||||
distance = np.sqrt(
|
||||
(patch_center[0] - cluster_center[0]) ** 2
|
||||
+ (patch_center[1] - cluster_center[1]) ** 2
|
||||
)
|
||||
|
||||
if distance <= self.cluster_distance_threshold:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _get_patch_center(self, coordinates: tuple[int, int, int, int]) -> tuple[float, float]:
|
||||
"""Get center point of a patch."""
|
||||
x1, y1, x2, y2 = coordinates
|
||||
return ((x1 + x2) / 2, (y1 + y2) / 2)
|
||||
|
||||
def print_aggregated_results(
|
||||
self, results: list[AggregatedResult], max_patches_per_doc: int = 3
|
||||
):
|
||||
"""Pretty print aggregated results."""
|
||||
print(f"\n🔍 Aggregated Results (method: {self.aggregation_method})")
|
||||
print("=" * 80)
|
||||
|
||||
for i, result in enumerate(results):
|
||||
print(f"\n{i + 1}. {result.image_name}")
|
||||
print(f" Doc Score: {result.doc_score:.4f} | Patches: {result.patch_count}")
|
||||
print(f" Path: {result.image_path}")
|
||||
|
||||
# Show best patch
|
||||
best = result.best_patch
|
||||
print(
|
||||
f" 🌟 Best Patch: #{best.patch_id} at {best.coordinates} (score: {best.score:.4f})"
|
||||
)
|
||||
|
||||
# Show top patches
|
||||
print(" 📍 Top Patches:")
|
||||
for j, patch in enumerate(result.all_patches[:max_patches_per_doc]):
|
||||
print(
|
||||
f" {j + 1}. Patch #{patch.patch_id}: {patch.score:.4f} at {patch.coordinates}"
|
||||
)
|
||||
|
||||
# Show spatial clusters if available
|
||||
if result.spatial_clusters and len(result.spatial_clusters) > 1:
|
||||
print(f" 🗂️ Spatial Clusters: {len(result.spatial_clusters)}")
|
||||
for j, cluster in enumerate(result.spatial_clusters[:2]): # Show top 2 clusters
|
||||
cluster_score = max(p.score for p in cluster)
|
||||
print(
|
||||
f" Cluster {j + 1}: {len(cluster)} patches (best: {cluster_score:.4f})"
|
||||
)
|
||||
|
||||
|
||||
def demo_aggregation():
|
||||
"""Demonstrate the multi-vector aggregation functionality."""
|
||||
print("=== Multi-Vector Aggregation Demo ===")
|
||||
|
||||
# Simulate some patch-level search results
|
||||
# In real usage, these would come from LeannSearcher.search()
|
||||
|
||||
class MockResult:
|
||||
def __init__(self, score, metadata):
|
||||
self.score = score
|
||||
self.metadata = metadata
|
||||
|
||||
# Simulate results for 2 images with multiple patches each
|
||||
mock_results = [
|
||||
# Image 1: cats_and_kitchen.jpg - 4 patches
|
||||
MockResult(
|
||||
0.85,
|
||||
{
|
||||
"image_name": "cats_and_kitchen.jpg",
|
||||
"image_path": "/path/to/cats_and_kitchen.jpg",
|
||||
"patch_id": 3,
|
||||
"coordinates": [100, 50, 224, 174], # Kitchen area
|
||||
"attention_score": 0.92,
|
||||
"scale": 1.0,
|
||||
},
|
||||
),
|
||||
MockResult(
|
||||
0.78,
|
||||
{
|
||||
"image_name": "cats_and_kitchen.jpg",
|
||||
"image_path": "/path/to/cats_and_kitchen.jpg",
|
||||
"patch_id": 7,
|
||||
"coordinates": [200, 300, 324, 424], # Cat area
|
||||
"attention_score": 0.88,
|
||||
"scale": 1.0,
|
||||
},
|
||||
),
|
||||
MockResult(
|
||||
0.72,
|
||||
{
|
||||
"image_name": "cats_and_kitchen.jpg",
|
||||
"image_path": "/path/to/cats_and_kitchen.jpg",
|
||||
"patch_id": 12,
|
||||
"coordinates": [150, 100, 274, 224], # Appliances
|
||||
"attention_score": 0.75,
|
||||
"scale": 1.0,
|
||||
},
|
||||
),
|
||||
MockResult(
|
||||
0.65,
|
||||
{
|
||||
"image_name": "cats_and_kitchen.jpg",
|
||||
"image_path": "/path/to/cats_and_kitchen.jpg",
|
||||
"patch_id": 15,
|
||||
"coordinates": [50, 250, 174, 374], # Furniture
|
||||
"attention_score": 0.70,
|
||||
"scale": 1.0,
|
||||
},
|
||||
),
|
||||
# Image 2: city_street.jpg - 3 patches
|
||||
MockResult(
|
||||
0.68,
|
||||
{
|
||||
"image_name": "city_street.jpg",
|
||||
"image_path": "/path/to/city_street.jpg",
|
||||
"patch_id": 2,
|
||||
"coordinates": [300, 100, 424, 224], # Buildings
|
||||
"attention_score": 0.80,
|
||||
"scale": 1.0,
|
||||
},
|
||||
),
|
||||
MockResult(
|
||||
0.62,
|
||||
{
|
||||
"image_name": "city_street.jpg",
|
||||
"image_path": "/path/to/city_street.jpg",
|
||||
"patch_id": 8,
|
||||
"coordinates": [100, 350, 224, 474], # Street level
|
||||
"attention_score": 0.75,
|
||||
"scale": 1.0,
|
||||
},
|
||||
),
|
||||
MockResult(
|
||||
0.55,
|
||||
{
|
||||
"image_name": "city_street.jpg",
|
||||
"image_path": "/path/to/city_street.jpg",
|
||||
"patch_id": 11,
|
||||
"coordinates": [400, 200, 524, 324], # Sky area
|
||||
"attention_score": 0.60,
|
||||
"scale": 1.0,
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
# Test different aggregation methods
|
||||
methods = ["maxsim", "voting", "weighted", "mean"]
|
||||
|
||||
for method in methods:
|
||||
print(f"\n{'=' * 20} {method.upper()} AGGREGATION {'=' * 20}")
|
||||
|
||||
aggregator = MultiVectorAggregator(
|
||||
aggregation_method=method, spatial_clustering=True, cluster_distance_threshold=100.0
|
||||
)
|
||||
|
||||
aggregated = aggregator.aggregate_results(mock_results, top_k=5)
|
||||
aggregator.print_aggregated_results(aggregated)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo_aggregation()
|
||||
@@ -1,113 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
OpenAI Embedding Example
|
||||
|
||||
Complete example showing how to build and search with OpenAI embeddings using HNSW backend.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import dotenv
|
||||
from leann.api import LeannBuilder, LeannSearcher
|
||||
|
||||
# Load environment variables
|
||||
dotenv.load_dotenv()
|
||||
|
||||
|
||||
def main():
|
||||
# Check if OpenAI API key is available
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
print("ERROR: OPENAI_API_KEY environment variable not set")
|
||||
return False
|
||||
|
||||
print(f"✅ OpenAI API key found: {api_key[:10]}...")
|
||||
|
||||
# Sample texts
|
||||
sample_texts = [
|
||||
"Machine learning is a powerful technology that enables computers to learn from data.",
|
||||
"Natural language processing helps computers understand and generate human language.",
|
||||
"Deep learning uses neural networks with multiple layers to solve complex problems.",
|
||||
"Computer vision allows machines to interpret and understand visual information.",
|
||||
"Reinforcement learning trains agents to make decisions through trial and error.",
|
||||
"Data science combines statistics, math, and programming to extract insights from data.",
|
||||
"Artificial intelligence aims to create machines that can perform human-like tasks.",
|
||||
"Python is a popular programming language used extensively in data science and AI.",
|
||||
"Neural networks are inspired by the structure and function of the human brain.",
|
||||
"Big data refers to extremely large datasets that require special tools to process.",
|
||||
]
|
||||
|
||||
INDEX_DIR = Path("./simple_openai_test_index")
|
||||
INDEX_PATH = str(INDEX_DIR / "simple_test.leann")
|
||||
|
||||
print("\n=== Building Index with OpenAI Embeddings ===")
|
||||
print(f"Index path: {INDEX_PATH}")
|
||||
|
||||
try:
|
||||
# Use proper configuration for OpenAI embeddings
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="text-embedding-3-small",
|
||||
embedding_mode="openai",
|
||||
# HNSW settings for OpenAI embeddings
|
||||
M=16, # Smaller graph degree
|
||||
efConstruction=64, # Smaller construction complexity
|
||||
is_compact=True, # Enable compact storage for recompute
|
||||
is_recompute=True, # MUST enable for OpenAI embeddings
|
||||
num_threads=1,
|
||||
)
|
||||
|
||||
print(f"Adding {len(sample_texts)} texts to the index...")
|
||||
for i, text in enumerate(sample_texts):
|
||||
metadata = {"id": f"doc_{i}", "topic": "AI"}
|
||||
builder.add_text(text, metadata)
|
||||
|
||||
print("Building index...")
|
||||
builder.build_index(INDEX_PATH)
|
||||
print("✅ Index built successfully!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error building index: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
print("\n=== Testing Search ===")
|
||||
|
||||
try:
|
||||
searcher = LeannSearcher(INDEX_PATH)
|
||||
|
||||
test_queries = [
|
||||
"What is machine learning?",
|
||||
"How do neural networks work?",
|
||||
"Programming languages for data science",
|
||||
]
|
||||
|
||||
for query in test_queries:
|
||||
print(f"\n🔍 Query: '{query}'")
|
||||
results = searcher.search(query, top_k=3)
|
||||
|
||||
print(f" Found {len(results)} results:")
|
||||
for i, result in enumerate(results):
|
||||
print(f" {i + 1}. Score: {result.score:.4f}")
|
||||
print(f" Text: {result.text[:80]}...")
|
||||
|
||||
print("\n✅ Search test completed successfully!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error during search: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = main()
|
||||
if success:
|
||||
print("\n🎉 Simple OpenAI index test completed successfully!")
|
||||
else:
|
||||
print("\n💥 Simple OpenAI index test failed!")
|
||||
@@ -1,23 +0,0 @@
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
from leann.api import LeannChat
|
||||
|
||||
INDEX_DIR = Path("./test_pdf_index_huawei")
|
||||
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
|
||||
|
||||
|
||||
async def main():
|
||||
print("\n[PHASE 2] Starting Leann chat session...")
|
||||
chat = LeannChat(index_path=INDEX_PATH)
|
||||
query = "What is the main idea of RL and give me 5 exapmle of classic RL algorithms?"
|
||||
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
|
||||
# query = "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
|
||||
response = chat.ask(
|
||||
query, top_k=20, recompute_beighbor_embeddings=True, complexity=32, beam_width=1
|
||||
)
|
||||
print(f"\n[PHASE 2] Response: {response}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,320 +0,0 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import dotenv
|
||||
from leann.api import LeannBuilder, LeannChat
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
# Default WeChat export directory
|
||||
DEFAULT_WECHAT_EXPORT_DIR = "./wechat_export_direct"
|
||||
|
||||
|
||||
def create_leann_index_from_multiple_wechat_exports(
|
||||
export_dirs: list[Path],
|
||||
index_path: str = "wechat_history_index.leann",
|
||||
max_count: int = -1,
|
||||
):
|
||||
"""
|
||||
Create LEANN index from multiple WeChat export data sources.
|
||||
|
||||
Args:
|
||||
export_dirs: List of Path objects pointing to WeChat export directories
|
||||
index_path: Path to save the LEANN index
|
||||
max_count: Maximum number of chat entries to process per export
|
||||
"""
|
||||
print("Creating LEANN index from multiple WeChat export data sources...")
|
||||
|
||||
# Load documents using WeChatHistoryReader from history_data
|
||||
from history_data.wechat_history import WeChatHistoryReader
|
||||
|
||||
reader = WeChatHistoryReader()
|
||||
|
||||
INDEX_DIR = Path(index_path).parent
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print("--- Index directory not found, building new index ---")
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
# Process each WeChat export directory
|
||||
for i, export_dir in enumerate(export_dirs):
|
||||
print(f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}")
|
||||
|
||||
try:
|
||||
documents = reader.load_data(
|
||||
wechat_export_dir=str(export_dir),
|
||||
max_count=max_count,
|
||||
concatenate_messages=True, # Disable concatenation - one message per document
|
||||
)
|
||||
if documents:
|
||||
print(f"Loaded {len(documents)} chat documents from {export_dir}")
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
|
||||
# Check if we've reached the max count
|
||||
if max_count > 0 and total_processed >= max_count:
|
||||
print(f"Reached max count of {max_count} documents")
|
||||
break
|
||||
else:
|
||||
print(f"No documents loaded from {export_dir}")
|
||||
except Exception as e:
|
||||
print(f"Error processing {export_dir}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No documents loaded from any source. Exiting.")
|
||||
return None
|
||||
|
||||
print(
|
||||
f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports and starting to split them into chunks"
|
||||
)
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=192, chunk_overlap=64)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
for doc in all_documents:
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
text = (
|
||||
"[Contact] means the message is from: "
|
||||
+ doc.metadata["contact_name"]
|
||||
+ "\n"
|
||||
+ node.get_content()
|
||||
)
|
||||
all_texts.append(text)
|
||||
|
||||
print(
|
||||
f"Finished splitting {len(all_documents)} documents into {len(all_texts)} text chunks"
|
||||
)
|
||||
|
||||
# Create LEANN index directory
|
||||
print("--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print("--- Building new LEANN index ---")
|
||||
|
||||
print("\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="Qwen/Qwen3-Embedding-0.6B",
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1, # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Adding {len(all_texts)} chat chunks to index...")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"\nLEANN index built at {index_path}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
return index_path
|
||||
|
||||
|
||||
def create_leann_index(
|
||||
export_dir: str | None = None,
|
||||
index_path: str = "wechat_history_index.leann",
|
||||
max_count: int = 1000,
|
||||
):
|
||||
"""
|
||||
Create LEANN index from WeChat chat history data.
|
||||
|
||||
Args:
|
||||
export_dir: Path to the WeChat export directory (optional, uses default if None)
|
||||
index_path: Path to save the LEANN index
|
||||
max_count: Maximum number of chat entries to process
|
||||
"""
|
||||
print("Creating LEANN index from WeChat chat history data...")
|
||||
INDEX_DIR = Path(index_path).parent
|
||||
|
||||
if not INDEX_DIR.exists():
|
||||
print("--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print("--- Building new LEANN index ---")
|
||||
|
||||
print("\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Load documents using WeChatHistoryReader from history_data
|
||||
from history_data.wechat_history import WeChatHistoryReader
|
||||
|
||||
reader = WeChatHistoryReader()
|
||||
|
||||
documents = reader.load_data(
|
||||
wechat_export_dir=export_dir,
|
||||
max_count=max_count,
|
||||
concatenate_messages=False, # Disable concatenation - one message per document
|
||||
)
|
||||
|
||||
if not documents:
|
||||
print("No documents loaded. Exiting.")
|
||||
return None
|
||||
|
||||
print(f"Loaded {len(documents)} chat documents")
|
||||
|
||||
# Create text splitter with 256 chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||
|
||||
# Convert Documents to text strings and chunk them
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
# Split the document into chunks
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
|
||||
|
||||
# Create LEANN index directory
|
||||
print("--- Index directory not found, building new index ---")
|
||||
INDEX_DIR.mkdir(exist_ok=True)
|
||||
|
||||
print("--- Building new LEANN index ---")
|
||||
|
||||
print("\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# Use HNSW backend for better macOS compatibility
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ", # MLX-optimized model
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True,
|
||||
num_threads=1, # Force single-threaded mode
|
||||
)
|
||||
|
||||
print(f"Adding {len(all_texts)} chat chunks to index...")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
print(f"\nLEANN index built at {index_path}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
return index_path
|
||||
|
||||
|
||||
async def query_leann_index(index_path: str, query: str):
|
||||
"""
|
||||
Query the LEANN index.
|
||||
|
||||
Args:
|
||||
index_path: Path to the LEANN index
|
||||
query: The query string
|
||||
"""
|
||||
print("\n[PHASE 2] Starting Leann chat session...")
|
||||
chat = LeannChat(index_path=index_path)
|
||||
|
||||
print(f"You: {query}")
|
||||
chat_response = chat.ask(
|
||||
query,
|
||||
top_k=20,
|
||||
recompute_beighbor_embeddings=True,
|
||||
complexity=16,
|
||||
beam_width=1,
|
||||
llm_config={
|
||||
"type": "openai",
|
||||
"model": "gpt-4o",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
},
|
||||
llm_kwargs={"temperature": 0.0, "max_tokens": 1000},
|
||||
)
|
||||
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main function with integrated WeChat export functionality."""
|
||||
|
||||
# Parse command line arguments
|
||||
parser = argparse.ArgumentParser(
|
||||
description="LEANN WeChat History Reader - Create and query WeChat chat history index"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--export-dir",
|
||||
type=str,
|
||||
default=DEFAULT_WECHAT_EXPORT_DIR,
|
||||
help=f"Directory to store WeChat exports (default: {DEFAULT_WECHAT_EXPORT_DIR})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index-dir",
|
||||
type=str,
|
||||
default="./wechat_history_magic_test_11Debug_new",
|
||||
help="Directory to store the LEANN index (default: ./wechat_history_index_leann_test)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-entries",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Maximum number of chat entries to process (default: 5000)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--query",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Single query to run (default: runs example queries)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force-export",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Force re-export of WeChat data even if exports exist",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
INDEX_DIR = Path(args.index_dir)
|
||||
INDEX_PATH = str(INDEX_DIR / "wechat_history.leann")
|
||||
|
||||
print(f"Using WeChat export directory: {args.export_dir}")
|
||||
print(f"Index directory: {INDEX_DIR}")
|
||||
print(f"Max entries: {args.max_entries}")
|
||||
|
||||
# Initialize WeChat reader with export capabilities
|
||||
from history_data.wechat_history import WeChatHistoryReader
|
||||
|
||||
reader = WeChatHistoryReader()
|
||||
|
||||
# Find existing exports or create new ones using the centralized method
|
||||
export_dirs = reader.find_or_export_wechat_data(args.export_dir)
|
||||
if not export_dirs:
|
||||
print("Failed to find or export WeChat data. Exiting.")
|
||||
return
|
||||
|
||||
# Create or load the LEANN index from all sources
|
||||
index_path = create_leann_index_from_multiple_wechat_exports(
|
||||
export_dirs, INDEX_PATH, max_count=args.max_entries
|
||||
)
|
||||
|
||||
if index_path:
|
||||
if args.query:
|
||||
# Run single query
|
||||
await query_leann_index(index_path, args.query)
|
||||
else:
|
||||
# Example queries
|
||||
queries = [
|
||||
"我想买魔术师约翰逊的球衣,给我一些对应聊天记录?",
|
||||
]
|
||||
|
||||
for query in queries:
|
||||
print("\n" + "=" * 60)
|
||||
await query_leann_index(index_path, query)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1 +0,0 @@
|
||||
|
||||
|
||||
@@ -1 +1,7 @@
|
||||
from . import diskann_backend as diskann_backend
|
||||
from . import graph_partition
|
||||
|
||||
# Export main classes and functions
|
||||
from .graph_partition import GraphPartitioner, partition_graph
|
||||
|
||||
__all__ = ["GraphPartitioner", "diskann_backend", "graph_partition", "partition_graph"]
|
||||
|
||||
@@ -4,9 +4,10 @@ import os
|
||||
import struct
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
import numpy as np
|
||||
import psutil
|
||||
from leann.interface import (
|
||||
LeannBackendBuilderInterface,
|
||||
LeannBackendFactoryInterface,
|
||||
@@ -84,6 +85,43 @@ def _write_vectors_to_bin(data: np.ndarray, file_path: Path):
|
||||
f.write(data.tobytes())
|
||||
|
||||
|
||||
def _calculate_smart_memory_config(data: np.ndarray) -> tuple[float, float]:
|
||||
"""
|
||||
Calculate smart memory configuration for DiskANN based on data size and system specs.
|
||||
|
||||
Args:
|
||||
data: The embedding data array
|
||||
|
||||
Returns:
|
||||
tuple: (search_memory_maximum, build_memory_maximum) in GB
|
||||
"""
|
||||
num_vectors, dim = data.shape
|
||||
|
||||
# Calculate embedding storage size
|
||||
embedding_size_bytes = num_vectors * dim * 4 # float32 = 4 bytes
|
||||
embedding_size_gb = embedding_size_bytes / (1024**3)
|
||||
|
||||
# search_memory_maximum: 1/10 of embedding size for optimal PQ compression
|
||||
# This controls Product Quantization size - smaller means more compression
|
||||
search_memory_gb = max(0.1, embedding_size_gb / 10) # At least 100MB
|
||||
|
||||
# build_memory_maximum: Based on available system RAM for sharding control
|
||||
# This controls how much memory DiskANN uses during index construction
|
||||
available_memory_gb = psutil.virtual_memory().available / (1024**3)
|
||||
total_memory_gb = psutil.virtual_memory().total / (1024**3)
|
||||
|
||||
# Use 50% of available memory, but at least 2GB and at most 75% of total
|
||||
build_memory_gb = max(2.0, min(available_memory_gb * 0.5, total_memory_gb * 0.75))
|
||||
|
||||
logger.info(
|
||||
f"Smart memory config - Data: {embedding_size_gb:.2f}GB, "
|
||||
f"Search mem: {search_memory_gb:.2f}GB (PQ control), "
|
||||
f"Build mem: {build_memory_gb:.2f}GB (sharding control)"
|
||||
)
|
||||
|
||||
return search_memory_gb, build_memory_gb
|
||||
|
||||
|
||||
@register_backend("diskann")
|
||||
class DiskannBackend(LeannBackendFactoryInterface):
|
||||
@staticmethod
|
||||
@@ -99,6 +137,71 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
||||
def __init__(self, **kwargs):
|
||||
self.build_params = kwargs
|
||||
|
||||
def _safe_cleanup_after_partition(self, index_dir: Path, index_prefix: str):
|
||||
"""
|
||||
Safely cleanup files after partition.
|
||||
In partition mode, C++ doesn't read _disk.index content,
|
||||
so we can delete it if all derived files exist.
|
||||
"""
|
||||
disk_index_file = index_dir / f"{index_prefix}_disk.index"
|
||||
beam_search_file = index_dir / f"{index_prefix}_disk_beam_search.index"
|
||||
|
||||
# Required files that C++ partition mode needs
|
||||
# Note: C++ generates these with _disk.index suffix
|
||||
disk_suffix = "_disk.index"
|
||||
required_files = [
|
||||
f"{index_prefix}{disk_suffix}_medoids.bin", # Critical: assert fails if missing
|
||||
# Note: _centroids.bin is not created in single-shot build - C++ handles this automatically
|
||||
f"{index_prefix}_pq_pivots.bin", # PQ table
|
||||
f"{index_prefix}_pq_compressed.bin", # PQ compressed vectors
|
||||
]
|
||||
|
||||
# Check if all required files exist
|
||||
missing_files = []
|
||||
for filename in required_files:
|
||||
file_path = index_dir / filename
|
||||
if not file_path.exists():
|
||||
missing_files.append(filename)
|
||||
|
||||
if missing_files:
|
||||
logger.warning(
|
||||
f"Cannot safely delete _disk.index - missing required files: {missing_files}"
|
||||
)
|
||||
logger.info("Keeping all original files for safety")
|
||||
return
|
||||
|
||||
# Calculate space savings
|
||||
space_saved = 0
|
||||
files_to_delete = []
|
||||
|
||||
if disk_index_file.exists():
|
||||
space_saved += disk_index_file.stat().st_size
|
||||
files_to_delete.append(disk_index_file)
|
||||
|
||||
if beam_search_file.exists():
|
||||
space_saved += beam_search_file.stat().st_size
|
||||
files_to_delete.append(beam_search_file)
|
||||
|
||||
# Safe to delete!
|
||||
for file_to_delete in files_to_delete:
|
||||
try:
|
||||
os.remove(file_to_delete)
|
||||
logger.info(f"✅ Safely deleted: {file_to_delete.name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete {file_to_delete.name}: {e}")
|
||||
|
||||
if space_saved > 0:
|
||||
space_saved_mb = space_saved / (1024 * 1024)
|
||||
logger.info(f"💾 Space saved: {space_saved_mb:.1f} MB")
|
||||
|
||||
# Show what files are kept
|
||||
logger.info("📁 Kept essential files for partition mode:")
|
||||
for filename in required_files:
|
||||
file_path = index_dir / filename
|
||||
if file_path.exists():
|
||||
size_mb = file_path.stat().st_size / (1024 * 1024)
|
||||
logger.info(f" - {filename} ({size_mb:.1f} MB)")
|
||||
|
||||
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
|
||||
path = Path(index_path)
|
||||
index_dir = path.parent
|
||||
@@ -113,6 +216,17 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
||||
_write_vectors_to_bin(data, index_dir / data_filename)
|
||||
|
||||
build_kwargs = {**self.build_params, **kwargs}
|
||||
|
||||
# Extract is_recompute from nested backend_kwargs if needed
|
||||
is_recompute = build_kwargs.get("is_recompute", False)
|
||||
if not is_recompute and "backend_kwargs" in build_kwargs:
|
||||
is_recompute = build_kwargs["backend_kwargs"].get("is_recompute", False)
|
||||
|
||||
# Flatten all backend_kwargs parameters to top level for compatibility
|
||||
if "backend_kwargs" in build_kwargs:
|
||||
nested_params = build_kwargs.pop("backend_kwargs")
|
||||
build_kwargs.update(nested_params)
|
||||
|
||||
metric_enum = _get_diskann_metrics().get(
|
||||
build_kwargs.get("distance_metric", "mips").lower()
|
||||
)
|
||||
@@ -121,6 +235,16 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
||||
f"Unsupported distance_metric '{build_kwargs.get('distance_metric', 'unknown')}'."
|
||||
)
|
||||
|
||||
# Calculate smart memory configuration if not explicitly provided
|
||||
if (
|
||||
"search_memory_maximum" not in build_kwargs
|
||||
or "build_memory_maximum" not in build_kwargs
|
||||
):
|
||||
smart_search_mem, smart_build_mem = _calculate_smart_memory_config(data)
|
||||
else:
|
||||
smart_search_mem = build_kwargs.get("search_memory_maximum", 4.0)
|
||||
smart_build_mem = build_kwargs.get("build_memory_maximum", 8.0)
|
||||
|
||||
try:
|
||||
from . import _diskannpy as diskannpy # type: ignore
|
||||
|
||||
@@ -131,12 +255,36 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
||||
index_prefix,
|
||||
build_kwargs.get("complexity", 64),
|
||||
build_kwargs.get("graph_degree", 32),
|
||||
build_kwargs.get("search_memory_maximum", 4.0),
|
||||
build_kwargs.get("build_memory_maximum", 8.0),
|
||||
build_kwargs.get("search_memory_maximum", smart_search_mem),
|
||||
build_kwargs.get("build_memory_maximum", smart_build_mem),
|
||||
build_kwargs.get("num_threads", 8),
|
||||
build_kwargs.get("pq_disk_bytes", 0),
|
||||
"",
|
||||
)
|
||||
|
||||
# Auto-partition if is_recompute is enabled
|
||||
if build_kwargs.get("is_recompute", False):
|
||||
logger.info("is_recompute=True, starting automatic graph partitioning...")
|
||||
from .graph_partition import partition_graph
|
||||
|
||||
# Partition the index using absolute paths
|
||||
# Convert to absolute paths to avoid issues with working directory changes
|
||||
absolute_index_dir = Path(index_dir).resolve()
|
||||
absolute_index_prefix_path = str(absolute_index_dir / index_prefix)
|
||||
disk_graph_path, partition_bin_path = partition_graph(
|
||||
index_prefix_path=absolute_index_prefix_path,
|
||||
output_dir=str(absolute_index_dir),
|
||||
partition_prefix=index_prefix,
|
||||
)
|
||||
|
||||
# Safe cleanup: In partition mode, C++ doesn't read _disk.index content
|
||||
# but still needs the derived files (_medoids.bin, _centroids.bin, etc.)
|
||||
self._safe_cleanup_after_partition(index_dir, index_prefix)
|
||||
|
||||
logger.info("✅ Graph partitioning completed successfully!")
|
||||
logger.info(f" - Disk graph: {disk_graph_path}")
|
||||
logger.info(f" - Partition file: {partition_bin_path}")
|
||||
|
||||
finally:
|
||||
temp_data_file = index_dir / data_filename
|
||||
if temp_data_file.exists():
|
||||
@@ -163,18 +311,69 @@ class DiskannSearcher(BaseSearcher):
|
||||
|
||||
self.num_threads = kwargs.get("num_threads", 8)
|
||||
|
||||
fake_zmq_port = 6666
|
||||
full_index_prefix = str(self.index_dir / self.index_path.stem)
|
||||
self._index = diskannpy.StaticDiskFloatIndex(
|
||||
metric_enum,
|
||||
full_index_prefix,
|
||||
self.num_threads,
|
||||
kwargs.get("num_nodes_to_cache", 0),
|
||||
1,
|
||||
fake_zmq_port, # Initial port, can be updated at runtime
|
||||
"",
|
||||
"",
|
||||
)
|
||||
# For DiskANN, we need to reinitialize the index when zmq_port changes
|
||||
# Store the initialization parameters for later use
|
||||
# Note: C++ load method expects the BASE path (without _disk.index suffix)
|
||||
# C++ internally constructs: index_prefix + "_disk.index"
|
||||
index_name = self.index_path.stem # "simple_test.leann" -> "simple_test"
|
||||
diskann_index_prefix = str(self.index_dir / index_name) # /path/to/simple_test
|
||||
full_index_prefix = diskann_index_prefix # /path/to/simple_test (base path)
|
||||
|
||||
# Auto-detect partition files and set partition_prefix
|
||||
partition_graph_file = self.index_dir / f"{index_name}_disk_graph.index"
|
||||
partition_bin_file = self.index_dir / f"{index_name}_partition.bin"
|
||||
|
||||
partition_prefix = ""
|
||||
if partition_graph_file.exists() and partition_bin_file.exists():
|
||||
# C++ expects full path prefix, not just filename
|
||||
partition_prefix = str(self.index_dir / index_name) # /path/to/simple_test
|
||||
logger.info(
|
||||
f"✅ Detected partition files, using partition_prefix='{partition_prefix}'"
|
||||
)
|
||||
else:
|
||||
logger.debug("No partition files detected, using standard index files")
|
||||
|
||||
self._init_params = {
|
||||
"metric_enum": metric_enum,
|
||||
"full_index_prefix": full_index_prefix,
|
||||
"num_threads": self.num_threads,
|
||||
"num_nodes_to_cache": kwargs.get("num_nodes_to_cache", 0),
|
||||
"cache_mechanism": 1,
|
||||
"pq_prefix": "",
|
||||
"partition_prefix": partition_prefix,
|
||||
}
|
||||
|
||||
# Log partition configuration for debugging
|
||||
if partition_prefix:
|
||||
logger.info(
|
||||
f"✅ Detected partition files, using partition_prefix='{partition_prefix}'"
|
||||
)
|
||||
self._diskannpy = diskannpy
|
||||
self._current_zmq_port = None
|
||||
self._index = None
|
||||
logger.debug("DiskANN searcher initialized (index will be loaded on first search)")
|
||||
|
||||
def _ensure_index_loaded(self, zmq_port: int):
|
||||
"""Ensure the index is loaded with the correct zmq_port."""
|
||||
if self._index is None or self._current_zmq_port != zmq_port:
|
||||
# Need to (re)load the index with the correct zmq_port
|
||||
with suppress_cpp_output_if_needed():
|
||||
if self._index is not None:
|
||||
logger.debug(f"Reloading DiskANN index with new zmq_port: {zmq_port}")
|
||||
else:
|
||||
logger.debug(f"Loading DiskANN index with zmq_port: {zmq_port}")
|
||||
|
||||
self._index = self._diskannpy.StaticDiskFloatIndex(
|
||||
self._init_params["metric_enum"],
|
||||
self._init_params["full_index_prefix"],
|
||||
self._init_params["num_threads"],
|
||||
self._init_params["num_nodes_to_cache"],
|
||||
self._init_params["cache_mechanism"],
|
||||
zmq_port,
|
||||
self._init_params["pq_prefix"],
|
||||
self._init_params["partition_prefix"],
|
||||
)
|
||||
self._current_zmq_port = zmq_port
|
||||
|
||||
def search(
|
||||
self,
|
||||
@@ -185,7 +384,7 @@ class DiskannSearcher(BaseSearcher):
|
||||
prune_ratio: float = 0.0,
|
||||
recompute_embeddings: bool = False,
|
||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||
zmq_port: int | None = None,
|
||||
zmq_port: Optional[int] = None,
|
||||
batch_recompute: bool = False,
|
||||
dedup_node_dis: bool = False,
|
||||
**kwargs,
|
||||
@@ -212,14 +411,15 @@ class DiskannSearcher(BaseSearcher):
|
||||
Returns:
|
||||
Dict with 'labels' (list of lists) and 'distances' (ndarray)
|
||||
"""
|
||||
# Handle zmq_port compatibility: DiskANN can now update port at runtime
|
||||
# Handle zmq_port compatibility: Ensure index is loaded with correct port
|
||||
if recompute_embeddings:
|
||||
if zmq_port is None:
|
||||
raise ValueError("zmq_port must be provided if recompute_embeddings is True")
|
||||
current_port = self._index.get_zmq_port()
|
||||
if zmq_port != current_port:
|
||||
logger.debug(f"Updating DiskANN zmq_port from {current_port} to {zmq_port}")
|
||||
self._index.set_zmq_port(zmq_port)
|
||||
self._ensure_index_loaded(zmq_port)
|
||||
else:
|
||||
# If not recomputing, we still need an index, use a default port
|
||||
if self._index is None:
|
||||
self._ensure_index_loaded(6666) # Default port when not recomputing
|
||||
|
||||
# DiskANN doesn't support "proportional" strategy
|
||||
if pruning_strategy == "proportional":
|
||||
@@ -237,6 +437,8 @@ class DiskannSearcher(BaseSearcher):
|
||||
use_global_pruning = True
|
||||
|
||||
# Perform search with suppressed C++ output based on log level
|
||||
use_deferred_fetch = kwargs.get("USE_DEFERRED_FETCH", True)
|
||||
recompute_neighors = False
|
||||
with suppress_cpp_output_if_needed():
|
||||
labels, distances = self._index.batch_search(
|
||||
query,
|
||||
@@ -245,9 +447,9 @@ class DiskannSearcher(BaseSearcher):
|
||||
complexity,
|
||||
beam_width,
|
||||
self.num_threads,
|
||||
kwargs.get("USE_DEFERRED_FETCH", False),
|
||||
use_deferred_fetch,
|
||||
kwargs.get("skip_search_reorder", False),
|
||||
recompute_embeddings,
|
||||
recompute_neighors,
|
||||
dedup_node_dis,
|
||||
prune_ratio,
|
||||
batch_recompute,
|
||||
@@ -257,3 +459,25 @@ class DiskannSearcher(BaseSearcher):
|
||||
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
|
||||
|
||||
return {"labels": string_labels, "distances": distances}
|
||||
|
||||
def cleanup(self):
|
||||
"""Cleanup DiskANN-specific resources including C++ index."""
|
||||
# Call parent cleanup first
|
||||
super().cleanup()
|
||||
|
||||
# Delete the C++ index to trigger destructors
|
||||
try:
|
||||
if hasattr(self, "_index") and self._index is not None:
|
||||
del self._index
|
||||
self._index = None
|
||||
self._current_zmq_port = None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Force garbage collection to ensure C++ objects are destroyed
|
||||
try:
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -10,6 +10,7 @@ import sys
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import zmq
|
||||
@@ -32,10 +33,11 @@ if not logger.handlers:
|
||||
|
||||
|
||||
def create_diskann_embedding_server(
|
||||
passages_file: str | None = None,
|
||||
passages_file: Optional[str] = None,
|
||||
zmq_port: int = 5555,
|
||||
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
||||
embedding_mode: str = "sentence-transformers",
|
||||
distance_metric: str = "l2",
|
||||
):
|
||||
"""
|
||||
Create and start a ZMQ-based embedding server for DiskANN backend.
|
||||
@@ -79,7 +81,8 @@ def create_diskann_embedding_server(
|
||||
with open(passages_file) as f:
|
||||
meta = json.load(f)
|
||||
|
||||
passages = PassageManager(meta["passage_sources"])
|
||||
logger.info(f"Loading PassageManager with metadata_file_path: {passages_file}")
|
||||
passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file)
|
||||
logger.info(
|
||||
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
||||
)
|
||||
@@ -97,6 +100,7 @@ def create_diskann_embedding_server(
|
||||
socket = context.socket(
|
||||
zmq.REP
|
||||
) # REP socket for both BaseSearcher and DiskANN C++ REQ clients
|
||||
socket.setsockopt(zmq.LINGER, 0) # Don't block on close
|
||||
socket.bind(f"tcp://*:{zmq_port}")
|
||||
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
|
||||
|
||||
@@ -260,9 +264,16 @@ if __name__ == "__main__":
|
||||
"--embedding-mode",
|
||||
type=str,
|
||||
default="sentence-transformers",
|
||||
choices=["sentence-transformers", "openai", "mlx"],
|
||||
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||
help="Embedding backend mode",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--distance-metric",
|
||||
type=str,
|
||||
default="l2",
|
||||
choices=["l2", "mips", "cosine"],
|
||||
help="Distance metric for similarity computation",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -272,4 +283,5 @@ if __name__ == "__main__":
|
||||
zmq_port=args.zmq_port,
|
||||
model_name=args.model_name,
|
||||
embedding_mode=args.embedding_mode,
|
||||
distance_metric=args.distance_metric,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,299 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Graph Partition Module for LEANN DiskANN Backend
|
||||
|
||||
This module provides Python bindings for the graph partition functionality
|
||||
of DiskANN, allowing users to partition disk-based indices for better
|
||||
performance.
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class GraphPartitioner:
|
||||
"""
|
||||
A Python interface for DiskANN's graph partition functionality.
|
||||
|
||||
This class provides methods to partition disk-based indices for improved
|
||||
search performance and memory efficiency.
|
||||
"""
|
||||
|
||||
def __init__(self, build_type: str = "release"):
|
||||
"""
|
||||
Initialize the GraphPartitioner.
|
||||
|
||||
Args:
|
||||
build_type: Build type for the executables ("debug" or "release")
|
||||
"""
|
||||
self.build_type = build_type
|
||||
self._ensure_executables()
|
||||
|
||||
def _get_executable_path(self, name: str) -> str:
|
||||
"""Get the path to a graph partition executable."""
|
||||
# Get the directory where this Python module is located
|
||||
module_dir = Path(__file__).parent
|
||||
# Navigate to the graph_partition directory
|
||||
graph_partition_dir = module_dir.parent / "third_party" / "DiskANN" / "graph_partition"
|
||||
executable_path = graph_partition_dir / "build" / self.build_type / "graph_partition" / name
|
||||
|
||||
if not executable_path.exists():
|
||||
raise FileNotFoundError(f"Executable {name} not found at {executable_path}")
|
||||
|
||||
return str(executable_path)
|
||||
|
||||
def _ensure_executables(self):
|
||||
"""Ensure that the required executables are built."""
|
||||
try:
|
||||
self._get_executable_path("partitioner")
|
||||
self._get_executable_path("index_relayout")
|
||||
except FileNotFoundError:
|
||||
# Try to build the executables automatically
|
||||
print("Executables not found, attempting to build them...")
|
||||
self._build_executables()
|
||||
|
||||
def _build_executables(self):
|
||||
"""Build the required executables."""
|
||||
graph_partition_dir = (
|
||||
Path(__file__).parent.parent / "third_party" / "DiskANN" / "graph_partition"
|
||||
)
|
||||
original_dir = os.getcwd()
|
||||
|
||||
try:
|
||||
os.chdir(graph_partition_dir)
|
||||
|
||||
# Clean any existing build
|
||||
if (graph_partition_dir / "build").exists():
|
||||
shutil.rmtree(graph_partition_dir / "build")
|
||||
|
||||
# Run the build script
|
||||
cmd = ["./build.sh", self.build_type, "split_graph", "/tmp/dummy"]
|
||||
subprocess.run(cmd, capture_output=True, text=True, cwd=graph_partition_dir)
|
||||
|
||||
# Check if executables were created
|
||||
partitioner_path = self._get_executable_path("partitioner")
|
||||
relayout_path = self._get_executable_path("index_relayout")
|
||||
|
||||
print(f"✅ Built partitioner: {partitioner_path}")
|
||||
print(f"✅ Built index_relayout: {relayout_path}")
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to build executables: {e}")
|
||||
finally:
|
||||
os.chdir(original_dir)
|
||||
|
||||
def partition_graph(
|
||||
self,
|
||||
index_prefix_path: str,
|
||||
output_dir: Optional[str] = None,
|
||||
partition_prefix: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Partition a disk-based index for improved performance.
|
||||
|
||||
Args:
|
||||
index_prefix_path: Path to the index prefix (e.g., "/path/to/index")
|
||||
output_dir: Output directory for results (defaults to parent of index_prefix_path)
|
||||
partition_prefix: Prefix for output files (defaults to basename of index_prefix_path)
|
||||
**kwargs: Additional parameters for graph partitioning:
|
||||
- gp_times: Number of LDG partition iterations (default: 10)
|
||||
- lock_nums: Number of lock nodes (default: 10)
|
||||
- cut: Cut adjacency list degree (default: 100)
|
||||
- scale_factor: Scale factor (default: 1)
|
||||
- data_type: Data type (default: "float")
|
||||
- thread_nums: Number of threads (default: 10)
|
||||
|
||||
Returns:
|
||||
Tuple of (disk_graph_index_path, partition_bin_path)
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the partitioning process fails
|
||||
"""
|
||||
# Set default parameters
|
||||
params = {
|
||||
"gp_times": 10,
|
||||
"lock_nums": 10,
|
||||
"cut": 100,
|
||||
"scale_factor": 1,
|
||||
"data_type": "float",
|
||||
"thread_nums": 10,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
# Determine output directory
|
||||
if output_dir is None:
|
||||
output_dir = str(Path(index_prefix_path).parent)
|
||||
|
||||
# Create output directory if it doesn't exist
|
||||
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Determine partition prefix
|
||||
if partition_prefix is None:
|
||||
partition_prefix = Path(index_prefix_path).name
|
||||
|
||||
# Get executable paths
|
||||
partitioner_path = self._get_executable_path("partitioner")
|
||||
relayout_path = self._get_executable_path("index_relayout")
|
||||
|
||||
# Create temporary directory for processing
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Change to the graph_partition directory for temporary files
|
||||
graph_partition_dir = (
|
||||
Path(__file__).parent.parent / "third_party" / "DiskANN" / "graph_partition"
|
||||
)
|
||||
original_dir = os.getcwd()
|
||||
|
||||
try:
|
||||
os.chdir(graph_partition_dir)
|
||||
|
||||
# Create temporary data directory
|
||||
temp_data_dir = Path(temp_dir) / "data"
|
||||
temp_data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Set up paths for temporary files
|
||||
graph_path = temp_data_dir / "starling" / "_M_R_L_B" / "GRAPH"
|
||||
graph_gp_path = (
|
||||
graph_path
|
||||
/ f"GP_TIMES_{params['gp_times']}_LOCK_{params['lock_nums']}_GP_USE_FREQ0_CUT{params['cut']}_SCALE{params['scale_factor']}"
|
||||
)
|
||||
graph_gp_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Find input index file
|
||||
old_index_file = f"{index_prefix_path}_disk_beam_search.index"
|
||||
if not os.path.exists(old_index_file):
|
||||
old_index_file = f"{index_prefix_path}_disk.index"
|
||||
|
||||
if not os.path.exists(old_index_file):
|
||||
raise RuntimeError(f"Index file not found: {old_index_file}")
|
||||
|
||||
# Run partitioner
|
||||
gp_file_path = graph_gp_path / "_part.bin"
|
||||
partitioner_cmd = [
|
||||
partitioner_path,
|
||||
"--index_file",
|
||||
old_index_file,
|
||||
"--data_type",
|
||||
params["data_type"],
|
||||
"--gp_file",
|
||||
str(gp_file_path),
|
||||
"-T",
|
||||
str(params["thread_nums"]),
|
||||
"--ldg_times",
|
||||
str(params["gp_times"]),
|
||||
"--scale",
|
||||
str(params["scale_factor"]),
|
||||
"--mode",
|
||||
"1",
|
||||
]
|
||||
|
||||
print(f"Running partitioner: {' '.join(partitioner_cmd)}")
|
||||
result = subprocess.run(
|
||||
partitioner_cmd, capture_output=True, text=True, cwd=graph_partition_dir
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(
|
||||
f"Partitioner failed with return code {result.returncode}.\n"
|
||||
f"stdout: {result.stdout}\n"
|
||||
f"stderr: {result.stderr}"
|
||||
)
|
||||
|
||||
# Run relayout
|
||||
part_tmp_index = graph_gp_path / "_part_tmp.index"
|
||||
relayout_cmd = [
|
||||
relayout_path,
|
||||
old_index_file,
|
||||
str(gp_file_path),
|
||||
params["data_type"],
|
||||
"1",
|
||||
]
|
||||
|
||||
print(f"Running relayout: {' '.join(relayout_cmd)}")
|
||||
result = subprocess.run(
|
||||
relayout_cmd, capture_output=True, text=True, cwd=graph_partition_dir
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(
|
||||
f"Relayout failed with return code {result.returncode}.\n"
|
||||
f"stdout: {result.stdout}\n"
|
||||
f"stderr: {result.stderr}"
|
||||
)
|
||||
|
||||
# Copy results to output directory
|
||||
disk_graph_path = Path(output_dir) / f"{partition_prefix}_disk_graph.index"
|
||||
partition_bin_path = Path(output_dir) / f"{partition_prefix}_partition.bin"
|
||||
|
||||
shutil.copy2(part_tmp_index, disk_graph_path)
|
||||
shutil.copy2(gp_file_path, partition_bin_path)
|
||||
|
||||
print(f"Results copied to: {output_dir}")
|
||||
return str(disk_graph_path), str(partition_bin_path)
|
||||
|
||||
finally:
|
||||
os.chdir(original_dir)
|
||||
|
||||
def get_partition_info(self, partition_bin_path: str) -> dict:
|
||||
"""
|
||||
Get information about a partition file.
|
||||
|
||||
Args:
|
||||
partition_bin_path: Path to the partition binary file
|
||||
|
||||
Returns:
|
||||
Dictionary containing partition information
|
||||
"""
|
||||
if not os.path.exists(partition_bin_path):
|
||||
raise FileNotFoundError(f"Partition file not found: {partition_bin_path}")
|
||||
|
||||
# For now, return basic file information
|
||||
# In the future, this could parse the binary file for detailed info
|
||||
stat = os.stat(partition_bin_path)
|
||||
return {
|
||||
"file_size": stat.st_size,
|
||||
"file_path": partition_bin_path,
|
||||
"modified_time": stat.st_mtime,
|
||||
}
|
||||
|
||||
|
||||
def partition_graph(
|
||||
index_prefix_path: str,
|
||||
output_dir: Optional[str] = None,
|
||||
partition_prefix: Optional[str] = None,
|
||||
build_type: str = "release",
|
||||
**kwargs,
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Convenience function to partition a graph index.
|
||||
|
||||
Args:
|
||||
index_prefix_path: Path to the index prefix
|
||||
output_dir: Output directory (defaults to parent of index_prefix_path)
|
||||
partition_prefix: Prefix for output files (defaults to basename of index_prefix_path)
|
||||
build_type: Build type for executables ("debug" or "release")
|
||||
**kwargs: Additional parameters for graph partitioning
|
||||
|
||||
Returns:
|
||||
Tuple of (disk_graph_index_path, partition_bin_path)
|
||||
"""
|
||||
partitioner = GraphPartitioner(build_type=build_type)
|
||||
return partitioner.partition_graph(index_prefix_path, output_dir, partition_prefix, **kwargs)
|
||||
|
||||
|
||||
# Example usage:
|
||||
if __name__ == "__main__":
|
||||
# Example: partition an index
|
||||
try:
|
||||
disk_graph_path, partition_bin_path = partition_graph(
|
||||
"/path/to/your/index_prefix", gp_times=10, lock_nums=10, cut=100
|
||||
)
|
||||
print("Partitioning completed successfully!")
|
||||
print(f"Disk graph index: {disk_graph_path}")
|
||||
print(f"Partition binary: {partition_bin_path}")
|
||||
except Exception as e:
|
||||
print(f"Partitioning failed: {e}")
|
||||
@@ -0,0 +1,137 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Simplified Graph Partition Module for LEANN DiskANN Backend
|
||||
|
||||
This module provides a simple Python interface for graph partitioning
|
||||
that directly calls the existing executables.
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def partition_graph_simple(
|
||||
index_prefix_path: str, output_dir: Optional[str] = None, **kwargs
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Simple function to partition a graph index.
|
||||
|
||||
Args:
|
||||
index_prefix_path: Path to the index prefix (e.g., "/path/to/index")
|
||||
output_dir: Output directory (defaults to parent of index_prefix_path)
|
||||
**kwargs: Additional parameters for graph partitioning
|
||||
|
||||
Returns:
|
||||
Tuple of (disk_graph_index_path, partition_bin_path)
|
||||
"""
|
||||
# Set default parameters
|
||||
params = {
|
||||
"gp_times": 10,
|
||||
"lock_nums": 10,
|
||||
"cut": 100,
|
||||
"scale_factor": 1,
|
||||
"data_type": "float",
|
||||
"thread_nums": 10,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
# Determine output directory
|
||||
if output_dir is None:
|
||||
output_dir = str(Path(index_prefix_path).parent)
|
||||
|
||||
# Find the graph_partition directory
|
||||
current_file = Path(__file__)
|
||||
graph_partition_dir = current_file.parent.parent / "third_party" / "DiskANN" / "graph_partition"
|
||||
|
||||
if not graph_partition_dir.exists():
|
||||
raise RuntimeError(f"Graph partition directory not found: {graph_partition_dir}")
|
||||
|
||||
# Find input index file
|
||||
old_index_file = f"{index_prefix_path}_disk_beam_search.index"
|
||||
if not os.path.exists(old_index_file):
|
||||
old_index_file = f"{index_prefix_path}_disk.index"
|
||||
|
||||
if not os.path.exists(old_index_file):
|
||||
raise RuntimeError(f"Index file not found: {old_index_file}")
|
||||
|
||||
# Create temporary directory for processing
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
temp_data_dir = Path(temp_dir) / "data"
|
||||
temp_data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Set up paths for temporary files
|
||||
graph_path = temp_data_dir / "starling" / "_M_R_L_B" / "GRAPH"
|
||||
graph_gp_path = (
|
||||
graph_path
|
||||
/ f"GP_TIMES_{params['gp_times']}_LOCK_{params['lock_nums']}_GP_USE_FREQ0_CUT{params['cut']}_SCALE{params['scale_factor']}"
|
||||
)
|
||||
graph_gp_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Run the build script with our parameters
|
||||
cmd = [str(graph_partition_dir / "build.sh"), "release", "split_graph", index_prefix_path]
|
||||
|
||||
# Set environment variables for parameters
|
||||
env = os.environ.copy()
|
||||
env.update(
|
||||
{
|
||||
"GP_TIMES": str(params["gp_times"]),
|
||||
"GP_LOCK_NUMS": str(params["lock_nums"]),
|
||||
"GP_CUT": str(params["cut"]),
|
||||
"GP_SCALE_F": str(params["scale_factor"]),
|
||||
"DATA_TYPE": params["data_type"],
|
||||
"GP_T": str(params["thread_nums"]),
|
||||
}
|
||||
)
|
||||
|
||||
print(f"Running graph partition with command: {' '.join(cmd)}")
|
||||
print(f"Working directory: {graph_partition_dir}")
|
||||
|
||||
# Run the command
|
||||
result = subprocess.run(
|
||||
cmd, env=env, capture_output=True, text=True, cwd=graph_partition_dir
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
print(f"Command failed with return code {result.returncode}")
|
||||
print(f"stdout: {result.stdout}")
|
||||
print(f"stderr: {result.stderr}")
|
||||
raise RuntimeError(
|
||||
f"Graph partitioning failed with return code {result.returncode}.\n"
|
||||
f"stdout: {result.stdout}\n"
|
||||
f"stderr: {result.stderr}"
|
||||
)
|
||||
|
||||
# Check if output files were created
|
||||
disk_graph_path = Path(output_dir) / "_disk_graph.index"
|
||||
partition_bin_path = Path(output_dir) / "_partition.bin"
|
||||
|
||||
if not disk_graph_path.exists():
|
||||
raise RuntimeError(f"Expected output file not found: {disk_graph_path}")
|
||||
|
||||
if not partition_bin_path.exists():
|
||||
raise RuntimeError(f"Expected output file not found: {partition_bin_path}")
|
||||
|
||||
print("✅ Partitioning completed successfully!")
|
||||
print(f" Disk graph index: {disk_graph_path}")
|
||||
print(f" Partition binary: {partition_bin_path}")
|
||||
|
||||
return str(disk_graph_path), str(partition_bin_path)
|
||||
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
disk_graph_path, partition_bin_path = partition_graph_simple(
|
||||
"/Users/yichuan/Desktop/release2/leann/diskannbuild/test_doc_files",
|
||||
gp_times=5,
|
||||
lock_nums=5,
|
||||
cut=50,
|
||||
)
|
||||
print("Success! Output files:")
|
||||
print(f" - {disk_graph_path}")
|
||||
print(f" - {partition_bin_path}")
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
@@ -4,8 +4,8 @@ build-backend = "scikit_build_core.build"
|
||||
|
||||
[project]
|
||||
name = "leann-backend-diskann"
|
||||
version = "0.1.14"
|
||||
dependencies = ["leann-core==0.1.14", "numpy", "protobuf>=3.19.0"]
|
||||
version = "0.2.7"
|
||||
dependencies = ["leann-core==0.2.7", "numpy", "protobuf>=3.19.0"]
|
||||
|
||||
[tool.scikit-build]
|
||||
# Key: simplified CMake path
|
||||
@@ -16,4 +16,4 @@ wheel.packages = ["leann_backend_diskann"]
|
||||
editable.mode = "redirect"
|
||||
cmake.build-type = "Release"
|
||||
build.verbose = true
|
||||
build.tool-args = ["-j8"]
|
||||
build.tool-args = ["-j8"]
|
||||
|
||||
Submodule packages/leann-backend-diskann/third_party/DiskANN updated: af2a26481e...b2dc4ea2c7
@@ -2,12 +2,12 @@ syntax = "proto3";
|
||||
|
||||
package protoembedding;
|
||||
|
||||
message NodeEmbeddingRequest {
|
||||
repeated uint32 node_ids = 1;
|
||||
message NodeEmbeddingRequest {
|
||||
repeated uint32 node_ids = 1;
|
||||
}
|
||||
|
||||
message NodeEmbeddingResponse {
|
||||
bytes embeddings_data = 1; // All embedded binary datas
|
||||
repeated int32 dimensions = 2; // Shape [batch_size, embedding_dim]
|
||||
repeated uint32 missing_ids = 3; // Missing node ids
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,14 @@ if(APPLE)
|
||||
set(OpenMP_C_LIB_NAMES "omp")
|
||||
set(OpenMP_CXX_LIB_NAMES "omp")
|
||||
set(OpenMP_omp_LIBRARY "/opt/homebrew/opt/libomp/lib/libomp.dylib")
|
||||
|
||||
# Force use of system libc++ to avoid version mismatch
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -stdlib=libc++")
|
||||
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -stdlib=libc++")
|
||||
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -stdlib=libc++")
|
||||
|
||||
# Set minimum macOS version for better compatibility
|
||||
set(CMAKE_OSX_DEPLOYMENT_TARGET "11.0" CACHE STRING "Minimum macOS version")
|
||||
endif()
|
||||
|
||||
# Use system ZeroMQ instead of building from source
|
||||
@@ -52,4 +60,4 @@ set(FAISS_BUILD_AVX512 OFF CACHE BOOL "" FORCE)
|
||||
# IMPORTANT: Disable building AVX versions to speed up compilation
|
||||
set(FAISS_BUILD_AVX_VERSIONS OFF CACHE BOOL "" FORCE)
|
||||
|
||||
add_subdirectory(third_party/faiss)
|
||||
add_subdirectory(third_party/faiss)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import argparse
|
||||
import gc # Import garbage collector interface
|
||||
import logging
|
||||
import os
|
||||
import struct
|
||||
import sys
|
||||
@@ -7,6 +8,12 @@ import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Set up logging to avoid print buffer issues
|
||||
logger = logging.getLogger(__name__)
|
||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
|
||||
logger.setLevel(log_level)
|
||||
|
||||
# --- FourCCs (add more if needed) ---
|
||||
INDEX_HNSW_FLAT_FOURCC = int.from_bytes(b"IHNf", "little")
|
||||
# Add other HNSW fourccs if you expect different storage types inside HNSW
|
||||
@@ -72,7 +79,11 @@ def read_vector_raw(f, element_fmt_char):
|
||||
def read_numpy_vector(f, np_dtype, struct_fmt_char):
|
||||
"""Reads a vector into a NumPy array."""
|
||||
count = -1 # Initialize count for robust error handling
|
||||
print(f" Reading vector (dtype={np_dtype}, fmt='{struct_fmt_char}')... ", end="", flush=True)
|
||||
print(
|
||||
f" Reading vector (dtype={np_dtype}, fmt='{struct_fmt_char}')... ",
|
||||
end="",
|
||||
flush=True,
|
||||
)
|
||||
try:
|
||||
count, data_bytes = read_vector_raw(f, struct_fmt_char)
|
||||
print(f"Count={count}, Bytes={len(data_bytes)}")
|
||||
@@ -239,6 +250,12 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
||||
output_filename: Output CSR index file
|
||||
prune_embeddings: Whether to prune embedding storage (write NULL storage marker)
|
||||
"""
|
||||
# Disable buffering for print statements to avoid deadlock in CI/pytest
|
||||
import functools
|
||||
|
||||
global print
|
||||
print = functools.partial(print, flush=True)
|
||||
|
||||
print(f"Starting conversion: {input_filename} -> {output_filename}")
|
||||
start_time = time.time()
|
||||
original_hnsw_data = {}
|
||||
@@ -647,7 +664,10 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
||||
print(f"Error: Input file not found: {input_filename}", file=sys.stderr)
|
||||
return False
|
||||
except MemoryError as e:
|
||||
print(f"\nFatal MemoryError during conversion: {e}. Insufficient RAM.", file=sys.stderr)
|
||||
print(
|
||||
f"\nFatal MemoryError during conversion: {e}. Insufficient RAM.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
# Clean up potentially partially written output file?
|
||||
try:
|
||||
os.remove(output_filename)
|
||||
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
import numpy as np
|
||||
from leann.interface import (
|
||||
@@ -28,6 +28,12 @@ def get_metric_map():
|
||||
}
|
||||
|
||||
|
||||
def normalize_l2(data: np.ndarray) -> np.ndarray:
|
||||
norms = np.linalg.norm(data, axis=1, keepdims=True)
|
||||
norms[norms == 0] = 1 # Avoid division by zero
|
||||
return data / norms
|
||||
|
||||
|
||||
@register_backend("hnsw")
|
||||
class HNSWBackend(LeannBackendFactoryInterface):
|
||||
@staticmethod
|
||||
@@ -76,7 +82,7 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
||||
index.hnsw.efConstruction = self.efConstruction
|
||||
|
||||
if self.distance_metric.lower() == "cosine":
|
||||
faiss.normalize_L2(data)
|
||||
data = normalize_l2(data)
|
||||
|
||||
index.add(data.shape[0], faiss.swig_ptr(data))
|
||||
index_file = index_dir / f"{index_prefix}.index"
|
||||
@@ -118,7 +124,9 @@ class HNSWSearcher(BaseSearcher):
|
||||
)
|
||||
from . import faiss # type: ignore
|
||||
|
||||
self.distance_metric = self.meta.get("distance_metric", "mips").lower()
|
||||
self.distance_metric = (
|
||||
self.meta.get("backend_kwargs", {}).get("distance_metric", "mips").lower()
|
||||
)
|
||||
metric_enum = get_metric_map().get(self.distance_metric)
|
||||
if metric_enum is None:
|
||||
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
|
||||
@@ -144,7 +152,7 @@ class HNSWSearcher(BaseSearcher):
|
||||
self,
|
||||
query: np.ndarray,
|
||||
top_k: int,
|
||||
zmq_port: int | None = None,
|
||||
zmq_port: Optional[int] = None,
|
||||
complexity: int = 64,
|
||||
beam_width: int = 1,
|
||||
prune_ratio: float = 0.0,
|
||||
@@ -186,7 +194,7 @@ class HNSWSearcher(BaseSearcher):
|
||||
if query.dtype != np.float32:
|
||||
query = query.astype(np.float32)
|
||||
if self.distance_metric == "cosine":
|
||||
faiss.normalize_L2(query)
|
||||
query = normalize_l2(query)
|
||||
|
||||
params = faiss.SearchParametersHNSW()
|
||||
if zmq_port is not None:
|
||||
@@ -194,6 +202,16 @@ class HNSWSearcher(BaseSearcher):
|
||||
params.efSearch = complexity
|
||||
params.beam_size = beam_width
|
||||
|
||||
# For OpenAI embeddings with cosine distance, disable relative distance check
|
||||
# This prevents early termination when all scores are in a narrow range
|
||||
embedding_model = self.meta.get("embedding_model", "").lower()
|
||||
if self.distance_metric == "cosine" and any(
|
||||
openai_model in embedding_model for openai_model in ["text-embedding", "openai"]
|
||||
):
|
||||
params.check_relative_distance = False
|
||||
else:
|
||||
params.check_relative_distance = True
|
||||
|
||||
# PQ pruning: direct mapping to HNSW's pq_pruning_ratio
|
||||
params.pq_pruning_ratio = prune_ratio
|
||||
|
||||
@@ -227,3 +245,25 @@ class HNSWSearcher(BaseSearcher):
|
||||
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
|
||||
|
||||
return {"labels": string_labels, "distances": distances}
|
||||
|
||||
def cleanup(self):
|
||||
"""Cleanup HNSW-specific resources including C++ ZMQ connections."""
|
||||
# Call parent cleanup first
|
||||
super().cleanup()
|
||||
|
||||
# Additional cleanup for C++ side ZMQ connections
|
||||
# The ZmqDistanceComputer in C++ uses ZMQ connections that need cleanup
|
||||
try:
|
||||
# Delete the index to trigger C++ destructors
|
||||
if hasattr(self, "index"):
|
||||
del self.index
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Force garbage collection to ensure C++ objects are destroyed
|
||||
try:
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -10,6 +10,7 @@ import sys
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import msgpack
|
||||
import numpy as np
|
||||
@@ -33,7 +34,7 @@ if not logger.handlers:
|
||||
|
||||
|
||||
def create_hnsw_embedding_server(
|
||||
passages_file: str | None = None,
|
||||
passages_file: Optional[str] = None,
|
||||
zmq_port: int = 5555,
|
||||
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
||||
distance_metric: str = "mips",
|
||||
@@ -81,19 +82,8 @@ def create_hnsw_embedding_server(
|
||||
with open(passages_file) as f:
|
||||
meta = json.load(f)
|
||||
|
||||
# Convert relative paths to absolute paths based on metadata file location
|
||||
metadata_dir = Path(passages_file).parent.parent # Go up one level from the metadata file
|
||||
passage_sources = []
|
||||
for source in meta["passage_sources"]:
|
||||
source_copy = source.copy()
|
||||
# Convert relative paths to absolute paths
|
||||
if not Path(source_copy["path"]).is_absolute():
|
||||
source_copy["path"] = str(metadata_dir / source_copy["path"])
|
||||
if not Path(source_copy["index_path"]).is_absolute():
|
||||
source_copy["index_path"] = str(metadata_dir / source_copy["index_path"])
|
||||
passage_sources.append(source_copy)
|
||||
|
||||
passages = PassageManager(passage_sources)
|
||||
# Let PassageManager handle path resolution uniformly
|
||||
passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file)
|
||||
logger.info(
|
||||
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
||||
)
|
||||
@@ -102,6 +92,7 @@ def create_hnsw_embedding_server(
|
||||
"""ZMQ server thread"""
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REP)
|
||||
socket.setsockopt(zmq.LINGER, 0) # Don't block on close
|
||||
socket.bind(f"tcp://*:{zmq_port}")
|
||||
logger.info(f"HNSW ZMQ server listening on port {zmq_port}")
|
||||
|
||||
@@ -295,7 +286,7 @@ if __name__ == "__main__":
|
||||
"--embedding-mode",
|
||||
type=str,
|
||||
default="sentence-transformers",
|
||||
choices=["sentence-transformers", "openai", "mlx"],
|
||||
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||
help="Embedding backend mode",
|
||||
)
|
||||
|
||||
|
||||
@@ -6,10 +6,10 @@ build-backend = "scikit_build_core.build"
|
||||
|
||||
[project]
|
||||
name = "leann-backend-hnsw"
|
||||
version = "0.1.14"
|
||||
version = "0.2.7"
|
||||
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
|
||||
dependencies = [
|
||||
"leann-core==0.1.14",
|
||||
"leann-core==0.2.7",
|
||||
"numpy",
|
||||
"pyzmq>=23.0.0",
|
||||
"msgpack>=1.0.0",
|
||||
@@ -24,4 +24,4 @@ build.tool-args = ["-j8"]
|
||||
|
||||
# CMake definitions to optimize compilation
|
||||
[tool.scikit-build.cmake.define]
|
||||
CMAKE_BUILD_PARALLEL_LEVEL = "8"
|
||||
CMAKE_BUILD_PARALLEL_LEVEL = "8"
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "leann-core"
|
||||
version = "0.1.14"
|
||||
version = "0.2.7"
|
||||
description = "Core API and plugin system for LEANN"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.9"
|
||||
@@ -31,6 +31,8 @@ dependencies = [
|
||||
"PyPDF2>=3.0.0",
|
||||
"pymupdf>=1.23.0",
|
||||
"pdfplumber>=0.10.0",
|
||||
"nbconvert>=7.0.0", # For .ipynb file support
|
||||
"gitignore-parser>=0.1.12", # For proper .gitignore handling
|
||||
"mlx>=0.26.3; sys_platform == 'darwin'",
|
||||
"mlx-lm>=0.26.0; sys_platform == 'darwin'",
|
||||
]
|
||||
@@ -44,6 +46,7 @@ colab = [
|
||||
|
||||
[project.scripts]
|
||||
leann = "leann.cli:main"
|
||||
leann_mcp = "leann.mcp:main"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
where = ["src"]
|
||||
|
||||
@@ -8,6 +8,10 @@ if platform.system() == "Darwin":
|
||||
os.environ["MKL_NUM_THREADS"] = "1"
|
||||
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
||||
os.environ["KMP_BLOCKTIME"] = "0"
|
||||
# Additional fixes for PyTorch/sentence-transformers on macOS ARM64 only in CI
|
||||
if os.environ.get("CI") == "true":
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "0"
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
from .api import LeannBuilder, LeannChat, LeannSearcher
|
||||
from .registry import BACKEND_REGISTRY, autodiscover_backends
|
||||
|
||||
@@ -7,9 +7,10 @@ import json
|
||||
import logging
|
||||
import pickle
|
||||
import time
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -22,12 +23,17 @@ from .registry import BACKEND_REGISTRY
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_registered_backends() -> list[str]:
|
||||
"""Get list of registered backend names."""
|
||||
return list(BACKEND_REGISTRY.keys())
|
||||
|
||||
|
||||
def compute_embeddings(
|
||||
chunks: list[str],
|
||||
model_name: str,
|
||||
mode: str = "sentence-transformers",
|
||||
use_server: bool = True,
|
||||
port: int | None = None,
|
||||
port: Optional[int] = None,
|
||||
is_build=False,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
@@ -81,21 +87,26 @@ def compute_embeddings_via_server(chunks: list[str], model_name: str, port: int)
|
||||
# Connect to embedding server
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.setsockopt(zmq.LINGER, 0) # Don't block on close
|
||||
socket.setsockopt(zmq.RCVTIMEO, 1000) # 1s timeout on receive
|
||||
socket.setsockopt(zmq.SNDTIMEO, 1000) # 1s timeout on send
|
||||
socket.setsockopt(zmq.IMMEDIATE, 1) # Don't wait for connection
|
||||
socket.connect(f"tcp://localhost:{port}")
|
||||
|
||||
# Send chunks to server for embedding computation
|
||||
request = chunks
|
||||
socket.send(msgpack.packb(request))
|
||||
try:
|
||||
# Send chunks to server for embedding computation
|
||||
request = chunks
|
||||
socket.send(msgpack.packb(request))
|
||||
|
||||
# Receive embeddings from server
|
||||
response = socket.recv()
|
||||
embeddings_list = msgpack.unpackb(response)
|
||||
# Receive embeddings from server
|
||||
response = socket.recv()
|
||||
embeddings_list = msgpack.unpackb(response)
|
||||
|
||||
# Convert back to numpy array
|
||||
embeddings = np.array(embeddings_list, dtype=np.float32)
|
||||
|
||||
socket.close()
|
||||
context.term()
|
||||
# Convert back to numpy array
|
||||
embeddings = np.array(embeddings_list, dtype=np.float32)
|
||||
finally:
|
||||
socket.close(linger=0)
|
||||
context.term()
|
||||
|
||||
return embeddings
|
||||
|
||||
@@ -109,7 +120,9 @@ class SearchResult:
|
||||
|
||||
|
||||
class PassageManager:
|
||||
def __init__(self, passage_sources: list[dict[str, Any]]):
|
||||
def __init__(
|
||||
self, passage_sources: list[dict[str, Any]], metadata_file_path: Optional[str] = None
|
||||
):
|
||||
self.offset_maps = {}
|
||||
self.passage_files = {}
|
||||
self.global_offset_map = {} # Combined map for fast lookup
|
||||
@@ -119,10 +132,26 @@ class PassageManager:
|
||||
passage_file = source["path"]
|
||||
index_file = source["index_path"] # .idx file
|
||||
|
||||
# Fix path resolution for Colab and other environments
|
||||
# Fix path resolution - relative paths should be relative to metadata file directory
|
||||
if not Path(index_file).is_absolute():
|
||||
# If relative path, try to resolve it properly
|
||||
index_file = str(Path(index_file).resolve())
|
||||
if metadata_file_path:
|
||||
# Resolve relative to metadata file directory
|
||||
metadata_dir = Path(metadata_file_path).parent
|
||||
logger.debug(
|
||||
f"PassageManager: Resolving relative paths from metadata_dir: {metadata_dir}"
|
||||
)
|
||||
index_file = str((metadata_dir / index_file).resolve())
|
||||
passage_file = str((metadata_dir / passage_file).resolve())
|
||||
logger.debug(f"PassageManager: Resolved index_file: {index_file}")
|
||||
else:
|
||||
# Fallback to current directory resolution (legacy behavior)
|
||||
logger.warning(
|
||||
"PassageManager: No metadata_file_path provided, using fallback resolution from cwd"
|
||||
)
|
||||
logger.debug(f"PassageManager: Current working directory: {Path.cwd()}")
|
||||
index_file = str(Path(index_file).resolve())
|
||||
passage_file = str(Path(passage_file).resolve())
|
||||
logger.debug(f"PassageManager: Fallback resolved index_file: {index_file}")
|
||||
|
||||
if not Path(index_file).exists():
|
||||
raise FileNotFoundError(f"Passage index file not found: {index_file}")
|
||||
@@ -151,22 +180,92 @@ class LeannBuilder:
|
||||
self,
|
||||
backend_name: str,
|
||||
embedding_model: str = "facebook/contriever",
|
||||
dimensions: int | None = None,
|
||||
dimensions: Optional[int] = None,
|
||||
embedding_mode: str = "sentence-transformers",
|
||||
**backend_kwargs,
|
||||
):
|
||||
self.backend_name = backend_name
|
||||
backend_factory: LeannBackendFactoryInterface | None = BACKEND_REGISTRY.get(backend_name)
|
||||
backend_factory: Optional[LeannBackendFactoryInterface] = BACKEND_REGISTRY.get(backend_name)
|
||||
if backend_factory is None:
|
||||
raise ValueError(f"Backend '{backend_name}' not found or not registered.")
|
||||
self.backend_factory = backend_factory
|
||||
self.embedding_model = embedding_model
|
||||
self.dimensions = dimensions
|
||||
self.embedding_mode = embedding_mode
|
||||
|
||||
# Check if we need to use cosine distance for normalized embeddings
|
||||
normalized_embeddings_models = {
|
||||
# OpenAI models
|
||||
("openai", "text-embedding-ada-002"),
|
||||
("openai", "text-embedding-3-small"),
|
||||
("openai", "text-embedding-3-large"),
|
||||
# Voyage AI models
|
||||
("voyage", "voyage-2"),
|
||||
("voyage", "voyage-3"),
|
||||
("voyage", "voyage-large-2"),
|
||||
("voyage", "voyage-multilingual-2"),
|
||||
("voyage", "voyage-code-2"),
|
||||
# Cohere models
|
||||
("cohere", "embed-english-v3.0"),
|
||||
("cohere", "embed-multilingual-v3.0"),
|
||||
("cohere", "embed-english-light-v3.0"),
|
||||
("cohere", "embed-multilingual-light-v3.0"),
|
||||
}
|
||||
|
||||
# Also check for patterns in model names
|
||||
is_normalized = False
|
||||
current_model_lower = embedding_model.lower()
|
||||
current_mode_lower = embedding_mode.lower()
|
||||
|
||||
# Check exact matches
|
||||
for mode, model in normalized_embeddings_models:
|
||||
if (current_mode_lower == mode and current_model_lower == model) or (
|
||||
mode in current_mode_lower and model in current_model_lower
|
||||
):
|
||||
is_normalized = True
|
||||
break
|
||||
|
||||
# Check patterns
|
||||
if not is_normalized:
|
||||
# OpenAI patterns
|
||||
if "openai" in current_mode_lower or "openai" in current_model_lower:
|
||||
if any(
|
||||
pattern in current_model_lower
|
||||
for pattern in ["text-embedding", "ada", "3-small", "3-large"]
|
||||
):
|
||||
is_normalized = True
|
||||
# Voyage patterns
|
||||
elif "voyage" in current_mode_lower or "voyage" in current_model_lower:
|
||||
is_normalized = True
|
||||
# Cohere patterns
|
||||
elif "cohere" in current_mode_lower or "cohere" in current_model_lower:
|
||||
if "embed" in current_model_lower:
|
||||
is_normalized = True
|
||||
|
||||
# Handle distance metric
|
||||
if is_normalized and "distance_metric" not in backend_kwargs:
|
||||
backend_kwargs["distance_metric"] = "cosine"
|
||||
warnings.warn(
|
||||
f"Detected normalized embeddings model '{embedding_model}' with mode '{embedding_mode}'. "
|
||||
f"Automatically setting distance_metric='cosine' for optimal performance. "
|
||||
f"Normalized embeddings (L2 norm = 1) should use cosine similarity instead of MIPS.",
|
||||
UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
elif is_normalized and backend_kwargs.get("distance_metric", "").lower() != "cosine":
|
||||
current_metric = backend_kwargs.get("distance_metric", "mips")
|
||||
warnings.warn(
|
||||
f"Warning: Using '{current_metric}' distance metric with normalized embeddings model "
|
||||
f"'{embedding_model}' may lead to suboptimal search results. "
|
||||
f"Consider using 'cosine' distance metric for better performance.",
|
||||
UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
self.backend_kwargs = backend_kwargs
|
||||
self.chunks: list[dict[str, Any]] = []
|
||||
|
||||
def add_text(self, text: str, metadata: dict[str, Any] | None = None):
|
||||
def add_text(self, text: str, metadata: Optional[dict[str, Any]] = None):
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
passage_id = metadata.get("id", str(len(self.chunks)))
|
||||
@@ -238,8 +337,8 @@ class LeannBuilder:
|
||||
"passage_sources": [
|
||||
{
|
||||
"type": "jsonl",
|
||||
"path": str(passages_file),
|
||||
"index_path": str(offset_file),
|
||||
"path": passages_file.name, # Use relative path (just filename)
|
||||
"index_path": offset_file.name, # Use relative path (just filename)
|
||||
}
|
||||
],
|
||||
}
|
||||
@@ -354,8 +453,8 @@ class LeannBuilder:
|
||||
"passage_sources": [
|
||||
{
|
||||
"type": "jsonl",
|
||||
"path": str(passages_file),
|
||||
"index_path": str(offset_file),
|
||||
"path": passages_file.name, # Use relative path (just filename)
|
||||
"index_path": offset_file.name, # Use relative path (just filename)
|
||||
}
|
||||
],
|
||||
"built_from_precomputed_embeddings": True,
|
||||
@@ -383,14 +482,23 @@ class LeannSearcher:
|
||||
|
||||
self.meta_path_str = f"{index_path}.meta.json"
|
||||
if not Path(self.meta_path_str).exists():
|
||||
raise FileNotFoundError(f"Leann metadata file not found at {self.meta_path_str}")
|
||||
parent_dir = Path(index_path).parent
|
||||
print(
|
||||
f"Leann metadata file not found at {self.meta_path_str}, and you may need to rm -rf {parent_dir}"
|
||||
)
|
||||
# highlight in red the filenotfound error
|
||||
raise FileNotFoundError(
|
||||
f"Leann metadata file not found at {self.meta_path_str}, \033[91m you may need to rm -rf {parent_dir}\033[0m"
|
||||
)
|
||||
with open(self.meta_path_str, encoding="utf-8") as f:
|
||||
self.meta_data = json.load(f)
|
||||
backend_name = self.meta_data["backend_name"]
|
||||
self.embedding_model = self.meta_data["embedding_model"]
|
||||
# Support both old and new format
|
||||
self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
|
||||
self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
|
||||
self.passage_manager = PassageManager(
|
||||
self.meta_data.get("passage_sources", []), metadata_file_path=self.meta_path_str
|
||||
)
|
||||
backend_factory = BACKEND_REGISTRY.get(backend_name)
|
||||
if backend_factory is None:
|
||||
raise ValueError(f"Backend '{backend_name}' not found.")
|
||||
@@ -417,6 +525,16 @@ class LeannSearcher:
|
||||
logger.info(f" Top_k: {top_k}")
|
||||
logger.info(f" Additional kwargs: {kwargs}")
|
||||
|
||||
# Smart top_k detection and adjustment
|
||||
total_docs = len(self.passage_manager.global_offset_map)
|
||||
original_top_k = top_k
|
||||
if top_k > total_docs:
|
||||
top_k = total_docs
|
||||
logger.warning(
|
||||
f" ⚠️ Requested top_k ({original_top_k}) exceeds total documents ({total_docs})"
|
||||
)
|
||||
logger.warning(f" ✅ Auto-adjusted top_k to {top_k} to match available documents")
|
||||
|
||||
zmq_port = None
|
||||
|
||||
start_time = time.time()
|
||||
@@ -453,15 +571,15 @@ class LeannSearcher:
|
||||
zmq_port=zmq_port,
|
||||
**kwargs,
|
||||
)
|
||||
time.time() - start_time
|
||||
# logger.info(f" Search time: {search_time} seconds")
|
||||
logger.info(f" Backend returned: labels={len(results.get('labels', [[]])[0])} results")
|
||||
|
||||
enriched_results = []
|
||||
if "labels" in results and "distances" in results:
|
||||
logger.info(f" Processing {len(results['labels'][0])} passage IDs:")
|
||||
# Python 3.9 does not support zip(strict=...); lengths are expected to match
|
||||
for i, (string_id, dist) in enumerate(
|
||||
zip(results["labels"][0], results["distances"][0], strict=False)
|
||||
zip(results["labels"][0], results["distances"][0])
|
||||
):
|
||||
try:
|
||||
passage_data = self.passage_manager.get_passage(string_id)
|
||||
@@ -487,19 +605,45 @@ class LeannSearcher:
|
||||
)
|
||||
except KeyError:
|
||||
RED = "\033[91m"
|
||||
RESET = "\033[0m"
|
||||
logger.error(
|
||||
f" {RED}✗{RESET} [{i + 1:2d}] ID: '{string_id}' -> {RED}ERROR: Passage not found!{RESET}"
|
||||
)
|
||||
|
||||
# Define color codes outside the loop for final message
|
||||
GREEN = "\033[92m"
|
||||
RESET = "\033[0m"
|
||||
logger.info(f" {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}")
|
||||
return enriched_results
|
||||
|
||||
def cleanup(self):
|
||||
"""Explicitly cleanup embedding server and ZMQ resources.
|
||||
|
||||
This method should be called after you're done using the searcher,
|
||||
especially in test environments or batch processing scenarios.
|
||||
"""
|
||||
# Stop embedding server
|
||||
if hasattr(self.backend_impl, "embedding_server_manager"):
|
||||
self.backend_impl.embedding_server_manager.stop_server()
|
||||
|
||||
# Set ZMQ linger but don't terminate global context
|
||||
try:
|
||||
import zmq
|
||||
|
||||
# Just set linger on the global instance
|
||||
ctx = zmq.Context.instance()
|
||||
ctx.linger = 0
|
||||
# NEVER call ctx.term() or destroy() on the global instance
|
||||
# That would block waiting for all sockets to close
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class LeannChat:
|
||||
def __init__(
|
||||
self,
|
||||
index_path: str,
|
||||
llm_config: dict[str, Any] | None = None,
|
||||
llm_config: Optional[dict[str, Any]] = None,
|
||||
enable_warmup: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -515,7 +659,7 @@ class LeannChat:
|
||||
prune_ratio: float = 0.0,
|
||||
recompute_embeddings: bool = True,
|
||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||
llm_kwargs: dict[str, Any] | None = None,
|
||||
llm_kwargs: Optional[dict[str, Any]] = None,
|
||||
expected_zmq_port: int = 5557,
|
||||
**search_kwargs,
|
||||
):
|
||||
@@ -543,7 +687,10 @@ class LeannChat:
|
||||
"Please provide the best answer you can based on this context and your knowledge."
|
||||
)
|
||||
|
||||
ask_time = time.time()
|
||||
ans = self.llm.ask(prompt, **llm_kwargs)
|
||||
ask_time = time.time() - ask_time
|
||||
logger.info(f" Ask time: {ask_time} seconds")
|
||||
return ans
|
||||
|
||||
def start_interactive(self):
|
||||
@@ -560,3 +707,12 @@ class LeannChat:
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print("\nGoodbye!")
|
||||
break
|
||||
|
||||
def cleanup(self):
|
||||
"""Explicitly cleanup embedding server resources.
|
||||
|
||||
This method should be called after you're done using the chat interface,
|
||||
especially in test environments or batch processing scenarios.
|
||||
"""
|
||||
if hasattr(self.searcher, "cleanup"):
|
||||
self.searcher.cleanup()
|
||||
|
||||
@@ -8,7 +8,7 @@ import difflib
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -17,12 +17,12 @@ logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def check_ollama_models() -> list[str]:
|
||||
def check_ollama_models(host: str) -> list[str]:
|
||||
"""Check available Ollama models and return a list"""
|
||||
try:
|
||||
import requests
|
||||
|
||||
response = requests.get("http://localhost:11434/api/tags", timeout=5)
|
||||
response = requests.get(f"{host}/api/tags", timeout=5)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
return [model["name"] for model in data.get("models", [])]
|
||||
@@ -245,7 +245,11 @@ def search_hf_models_fuzzy(query: str, limit: int = 10) -> list[str]:
|
||||
|
||||
# HF Hub's search is already fuzzy! It handles typos and partial matches
|
||||
models = list_models(
|
||||
search=query, filter="text-generation", sort="downloads", direction=-1, limit=limit
|
||||
search=query,
|
||||
filter="text-generation",
|
||||
sort="downloads",
|
||||
direction=-1,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
model_names = [model.id if hasattr(model, "id") else str(model) for model in models]
|
||||
@@ -305,10 +309,12 @@ def search_hf_models(query: str, limit: int = 10) -> list[str]:
|
||||
return search_hf_models_fuzzy(query, limit)
|
||||
|
||||
|
||||
def validate_model_and_suggest(model_name: str, llm_type: str) -> str | None:
|
||||
def validate_model_and_suggest(
|
||||
model_name: str, llm_type: str, host: str = "http://localhost:11434"
|
||||
) -> Optional[str]:
|
||||
"""Validate model name and provide suggestions if invalid"""
|
||||
if llm_type == "ollama":
|
||||
available_models = check_ollama_models()
|
||||
available_models = check_ollama_models(host)
|
||||
if available_models and model_name not in available_models:
|
||||
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
|
||||
|
||||
@@ -354,7 +360,11 @@ def validate_model_and_suggest(model_name: str, llm_type: str) -> str | None:
|
||||
error_msg += f"\n\nModel '{model_name}' was not found in Ollama's library."
|
||||
|
||||
if suggestions:
|
||||
error_msg += "\n\nDid you mean one of these installed models?\n"
|
||||
error_msg += (
|
||||
"\n\nDid you mean one of these installed models?\n"
|
||||
+ "\nTry to use ollama pull to install the model you need\n"
|
||||
)
|
||||
|
||||
for i, suggestion in enumerate(suggestions, 1):
|
||||
error_msg += f" {i}. {suggestion}\n"
|
||||
else:
|
||||
@@ -461,7 +471,7 @@ class OllamaChat(LLMInterface):
|
||||
requests.get(host)
|
||||
|
||||
# Pre-check model availability with helpful suggestions
|
||||
model_error = validate_model_and_suggest(model, "ollama")
|
||||
model_error = validate_model_and_suggest(model, "ollama", host)
|
||||
if model_error:
|
||||
raise ValueError(model_error)
|
||||
|
||||
@@ -481,11 +491,35 @@ class OllamaChat(LLMInterface):
|
||||
import requests
|
||||
|
||||
full_url = f"{self.host}/api/generate"
|
||||
|
||||
# Handle thinking budget for reasoning models
|
||||
options = kwargs.copy()
|
||||
thinking_budget = kwargs.get("thinking_budget")
|
||||
if thinking_budget:
|
||||
# Remove thinking_budget from options as it's not a standard Ollama option
|
||||
options.pop("thinking_budget", None)
|
||||
# Only apply reasoning parameters to models that support it
|
||||
reasoning_supported_models = [
|
||||
"gpt-oss:20b",
|
||||
"gpt-oss:120b",
|
||||
"deepseek-r1",
|
||||
"deepseek-coder",
|
||||
]
|
||||
|
||||
if thinking_budget in ["low", "medium", "high"]:
|
||||
if any(model in self.model.lower() for model in reasoning_supported_models):
|
||||
options["reasoning"] = {"effort": thinking_budget, "exclude": False}
|
||||
logger.info(f"Applied reasoning effort={thinking_budget} to model {self.model}")
|
||||
else:
|
||||
logger.warning(
|
||||
f"Thinking budget '{thinking_budget}' requested but model '{self.model}' may not support reasoning parameters. Proceeding without reasoning."
|
||||
)
|
||||
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"prompt": prompt,
|
||||
"stream": False, # Keep it simple for now
|
||||
"options": kwargs,
|
||||
"options": options,
|
||||
}
|
||||
logger.debug(f"Sending request to Ollama: {payload}")
|
||||
try:
|
||||
@@ -538,14 +572,41 @@ class HFChat(LLMInterface):
|
||||
self.device = "cpu"
|
||||
logger.info("No GPU detected. Using CPU.")
|
||||
|
||||
# Load tokenizer and model
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
|
||||
device_map="auto" if self.device != "cpu" else None,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
# Load tokenizer and model with timeout protection
|
||||
try:
|
||||
import signal
|
||||
|
||||
def timeout_handler(signum, frame):
|
||||
raise TimeoutError("Model download/loading timed out")
|
||||
|
||||
# Set timeout for model loading (60 seconds)
|
||||
old_handler = signal.signal(signal.SIGALRM, timeout_handler)
|
||||
signal.alarm(60)
|
||||
|
||||
try:
|
||||
logger.info(f"Loading tokenizer for {model_name}...")
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
logger.info(f"Loading model {model_name}...")
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
|
||||
device_map="auto" if self.device != "cpu" else None,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
logger.info(f"Successfully loaded {model_name}")
|
||||
finally:
|
||||
signal.alarm(0) # Cancel the alarm
|
||||
signal.signal(signal.SIGALRM, old_handler) # Restore old handler
|
||||
|
||||
except TimeoutError:
|
||||
logger.error(f"Model loading timed out for {model_name}")
|
||||
raise RuntimeError(
|
||||
f"Model loading timed out for {model_name}. Please check your internet connection or try a smaller model."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load model {model_name}: {e}")
|
||||
raise
|
||||
|
||||
# Move model to device if not using device_map
|
||||
if self.device != "cpu" and "device_map" not in str(self.model):
|
||||
@@ -582,7 +643,11 @@ class HFChat(LLMInterface):
|
||||
|
||||
# Tokenize input
|
||||
inputs = self.tokenizer(
|
||||
formatted_prompt, return_tensors="pt", padding=True, truncation=True, max_length=2048
|
||||
formatted_prompt,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=2048,
|
||||
)
|
||||
|
||||
# Move inputs to device
|
||||
@@ -620,7 +685,7 @@ class HFChat(LLMInterface):
|
||||
class OpenAIChat(LLMInterface):
|
||||
"""LLM interface for OpenAI models."""
|
||||
|
||||
def __init__(self, model: str = "gpt-4o", api_key: str | None = None):
|
||||
def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None):
|
||||
self.model = model
|
||||
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
||||
|
||||
@@ -645,11 +710,38 @@ class OpenAIChat(LLMInterface):
|
||||
params = {
|
||||
"model": self.model,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"max_tokens": kwargs.get("max_tokens", 1000),
|
||||
"temperature": kwargs.get("temperature", 0.7),
|
||||
**{k: v for k, v in kwargs.items() if k not in ["max_tokens", "temperature"]},
|
||||
}
|
||||
|
||||
# Handle max_tokens vs max_completion_tokens based on model
|
||||
max_tokens = kwargs.get("max_tokens", 1000)
|
||||
if "o3" in self.model or "o4" in self.model or "o1" in self.model:
|
||||
# o-series models use max_completion_tokens
|
||||
params["max_completion_tokens"] = max_tokens
|
||||
params["temperature"] = 1.0
|
||||
else:
|
||||
# Other models use max_tokens
|
||||
params["max_tokens"] = max_tokens
|
||||
|
||||
# Handle thinking budget for reasoning models
|
||||
thinking_budget = kwargs.get("thinking_budget")
|
||||
if thinking_budget and thinking_budget in ["low", "medium", "high"]:
|
||||
# Check if this is an o-series model (partial match for model names)
|
||||
o_series_models = ["o3", "o3-mini", "o4-mini", "o1", "o3-pro", "o3-deep-research"]
|
||||
if any(model in self.model for model in o_series_models):
|
||||
# Use the correct OpenAI reasoning parameter format
|
||||
params["reasoning_effort"] = thinking_budget
|
||||
logger.info(f"Applied reasoning_effort={thinking_budget} to model {self.model}")
|
||||
else:
|
||||
logger.warning(
|
||||
f"Thinking budget '{thinking_budget}' requested but model '{self.model}' may not support reasoning parameters. Proceeding without reasoning."
|
||||
)
|
||||
|
||||
# Add other kwargs (excluding thinking_budget as it's handled above)
|
||||
for k, v in kwargs.items():
|
||||
if k not in ["max_tokens", "temperature", "thinking_budget"]:
|
||||
params[k] = v
|
||||
|
||||
logger.info(f"Sending request to OpenAI with model {self.model}")
|
||||
|
||||
try:
|
||||
@@ -669,7 +761,7 @@ class SimulatedChat(LLMInterface):
|
||||
return "This is a simulated answer from the LLM based on the retrieved context."
|
||||
|
||||
|
||||
def get_llm(llm_config: dict[str, Any] | None = None) -> LLMInterface:
|
||||
def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
|
||||
"""
|
||||
Factory function to get an LLM interface based on configuration.
|
||||
|
||||
|
||||
@@ -41,13 +41,23 @@ def extract_pdf_text_with_pdfplumber(file_path: str) -> str:
|
||||
|
||||
class LeannCLI:
|
||||
def __init__(self):
|
||||
self.indexes_dir = Path.home() / ".leann" / "indexes"
|
||||
# Always use project-local .leann directory (like .git)
|
||||
self.indexes_dir = Path.cwd() / ".leann" / "indexes"
|
||||
self.indexes_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Default parser for documents
|
||||
self.node_parser = SentenceSplitter(
|
||||
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
|
||||
)
|
||||
|
||||
# Code-optimized parser
|
||||
self.code_parser = SentenceSplitter(
|
||||
chunk_size=512, # Larger chunks for code context
|
||||
chunk_overlap=50, # Less overlap to preserve function boundaries
|
||||
separator="\n", # Split by lines for code
|
||||
paragraph_separator="\n\n", # Preserve logical code blocks
|
||||
)
|
||||
|
||||
def get_index_path(self, index_name: str) -> str:
|
||||
index_dir = self.indexes_dir / index_name
|
||||
return str(index_dir / "documents.leann")
|
||||
@@ -64,10 +74,11 @@ class LeannCLI:
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
leann build my-docs --docs ./documents # Build index named my-docs
|
||||
leann search my-docs "query" # Search in my-docs index
|
||||
leann ask my-docs "question" # Ask my-docs index
|
||||
leann list # List all stored indexes
|
||||
leann build my-docs --docs ./documents # Build index named my-docs
|
||||
leann build my-ppts --docs ./ --file-types .pptx,.pdf # Index only PowerPoint and PDF files
|
||||
leann search my-docs "query" # Search in my-docs index
|
||||
leann ask my-docs "question" # Ask my-docs index
|
||||
leann list # List all stored indexes
|
||||
""",
|
||||
)
|
||||
|
||||
@@ -75,18 +86,34 @@ Examples:
|
||||
|
||||
# Build command
|
||||
build_parser = subparsers.add_parser("build", help="Build document index")
|
||||
build_parser.add_argument("index_name", help="Index name")
|
||||
build_parser.add_argument("--docs", type=str, required=True, help="Documents directory")
|
||||
build_parser.add_argument(
|
||||
"index_name", nargs="?", help="Index name (default: current directory name)"
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--docs", type=str, default=".", help="Documents directory (default: current directory)"
|
||||
)
|
||||
build_parser.add_argument(
|
||||
"--backend", type=str, default="hnsw", choices=["hnsw", "diskann"]
|
||||
)
|
||||
build_parser.add_argument("--embedding-model", type=str, default="facebook/contriever")
|
||||
build_parser.add_argument(
|
||||
"--embedding-mode",
|
||||
type=str,
|
||||
default="sentence-transformers",
|
||||
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||
help="Embedding backend mode (default: sentence-transformers)",
|
||||
)
|
||||
build_parser.add_argument("--force", "-f", action="store_true", help="Force rebuild")
|
||||
build_parser.add_argument("--graph-degree", type=int, default=32)
|
||||
build_parser.add_argument("--complexity", type=int, default=64)
|
||||
build_parser.add_argument("--num-threads", type=int, default=1)
|
||||
build_parser.add_argument("--compact", action="store_true", default=True)
|
||||
build_parser.add_argument("--recompute", action="store_true", default=True)
|
||||
build_parser.add_argument(
|
||||
"--file-types",
|
||||
type=str,
|
||||
help="Comma-separated list of file extensions to include (e.g., '.txt,.pdf,.pptx'). If not specified, uses default supported types.",
|
||||
)
|
||||
|
||||
# Search command
|
||||
search_parser = subparsers.add_parser("search", help="Search documents")
|
||||
@@ -96,7 +123,12 @@ Examples:
|
||||
search_parser.add_argument("--complexity", type=int, default=64)
|
||||
search_parser.add_argument("--beam-width", type=int, default=1)
|
||||
search_parser.add_argument("--prune-ratio", type=float, default=0.0)
|
||||
search_parser.add_argument("--recompute-embeddings", action="store_true")
|
||||
search_parser.add_argument(
|
||||
"--recompute-embeddings",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Recompute embeddings (default: True)",
|
||||
)
|
||||
search_parser.add_argument(
|
||||
"--pruning-strategy",
|
||||
choices=["global", "local", "proportional"],
|
||||
@@ -119,94 +151,370 @@ Examples:
|
||||
ask_parser.add_argument("--complexity", type=int, default=32)
|
||||
ask_parser.add_argument("--beam-width", type=int, default=1)
|
||||
ask_parser.add_argument("--prune-ratio", type=float, default=0.0)
|
||||
ask_parser.add_argument("--recompute-embeddings", action="store_true")
|
||||
ask_parser.add_argument(
|
||||
"--recompute-embeddings",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Recompute embeddings (default: True)",
|
||||
)
|
||||
ask_parser.add_argument(
|
||||
"--pruning-strategy",
|
||||
choices=["global", "local", "proportional"],
|
||||
default="global",
|
||||
)
|
||||
ask_parser.add_argument(
|
||||
"--thinking-budget",
|
||||
type=str,
|
||||
choices=["low", "medium", "high"],
|
||||
default=None,
|
||||
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
|
||||
)
|
||||
|
||||
# List command
|
||||
subparsers.add_parser("list", help="List all indexes")
|
||||
|
||||
return parser
|
||||
|
||||
def register_project_dir(self):
|
||||
"""Register current project directory in global registry"""
|
||||
global_registry = Path.home() / ".leann" / "projects.json"
|
||||
global_registry.parent.mkdir(exist_ok=True)
|
||||
|
||||
current_dir = str(Path.cwd())
|
||||
|
||||
# Load existing registry
|
||||
projects = []
|
||||
if global_registry.exists():
|
||||
try:
|
||||
import json
|
||||
|
||||
with open(global_registry) as f:
|
||||
projects = json.load(f)
|
||||
except Exception:
|
||||
projects = []
|
||||
|
||||
# Add current directory if not already present
|
||||
if current_dir not in projects:
|
||||
projects.append(current_dir)
|
||||
|
||||
# Save registry
|
||||
import json
|
||||
|
||||
with open(global_registry, "w") as f:
|
||||
json.dump(projects, f, indent=2)
|
||||
|
||||
def _build_gitignore_parser(self, docs_dir: str):
|
||||
"""Build gitignore parser using gitignore-parser library."""
|
||||
from gitignore_parser import parse_gitignore
|
||||
|
||||
# Try to parse the root .gitignore
|
||||
gitignore_path = Path(docs_dir) / ".gitignore"
|
||||
|
||||
if gitignore_path.exists():
|
||||
try:
|
||||
# gitignore-parser automatically handles all subdirectory .gitignore files!
|
||||
matches = parse_gitignore(str(gitignore_path))
|
||||
print(f"📋 Loaded .gitignore from {docs_dir} (includes all subdirectories)")
|
||||
return matches
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not parse .gitignore: {e}")
|
||||
else:
|
||||
print("📋 No .gitignore found")
|
||||
|
||||
# Fallback: basic pattern matching for essential files
|
||||
essential_patterns = {".git", ".DS_Store", "__pycache__", "node_modules", ".venv", "venv"}
|
||||
|
||||
def basic_matches(file_path):
|
||||
path_parts = Path(file_path).parts
|
||||
return any(part in essential_patterns for part in path_parts)
|
||||
|
||||
return basic_matches
|
||||
|
||||
def _should_exclude_file(self, relative_path: Path, gitignore_matches) -> bool:
|
||||
"""Check if a file should be excluded using gitignore parser."""
|
||||
return gitignore_matches(str(relative_path))
|
||||
|
||||
def list_indexes(self):
|
||||
print("Stored LEANN indexes:")
|
||||
|
||||
if not self.indexes_dir.exists():
|
||||
# Get all project directories with .leann
|
||||
global_registry = Path.home() / ".leann" / "projects.json"
|
||||
all_projects = []
|
||||
|
||||
if global_registry.exists():
|
||||
try:
|
||||
import json
|
||||
|
||||
with open(global_registry) as f:
|
||||
all_projects = json.load(f)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Filter to only existing directories with .leann
|
||||
valid_projects = []
|
||||
for project_dir in all_projects:
|
||||
project_path = Path(project_dir)
|
||||
if project_path.exists() and (project_path / ".leann" / "indexes").exists():
|
||||
valid_projects.append(project_path)
|
||||
|
||||
# Add current project if it has .leann but not in registry
|
||||
current_path = Path.cwd()
|
||||
if (current_path / ".leann" / "indexes").exists() and current_path not in valid_projects:
|
||||
valid_projects.append(current_path)
|
||||
|
||||
if not valid_projects:
|
||||
print("No indexes found. Use 'leann build <name> --docs <dir>' to create one.")
|
||||
return
|
||||
|
||||
index_dirs = [d for d in self.indexes_dir.iterdir() if d.is_dir()]
|
||||
total_indexes = 0
|
||||
current_dir = Path.cwd()
|
||||
|
||||
if not index_dirs:
|
||||
print("No indexes found. Use 'leann build <name> --docs <dir>' to create one.")
|
||||
return
|
||||
for project_path in valid_projects:
|
||||
indexes_dir = project_path / ".leann" / "indexes"
|
||||
if not indexes_dir.exists():
|
||||
continue
|
||||
|
||||
print(f"Found {len(index_dirs)} indexes:")
|
||||
for i, index_dir in enumerate(index_dirs, 1):
|
||||
index_name = index_dir.name
|
||||
status = "✓" if self.index_exists(index_name) else "✗"
|
||||
index_dirs = [d for d in indexes_dir.iterdir() if d.is_dir()]
|
||||
if not index_dirs:
|
||||
continue
|
||||
|
||||
print(f" {i}. {index_name} [{status}]")
|
||||
if self.index_exists(index_name):
|
||||
index_dir / "documents.leann.meta.json"
|
||||
size_mb = sum(f.stat().st_size for f in index_dir.iterdir() if f.is_file()) / (
|
||||
1024 * 1024
|
||||
)
|
||||
print(f" Size: {size_mb:.1f} MB")
|
||||
# Show project header
|
||||
if project_path == current_dir:
|
||||
print(f"\n📁 Current project ({project_path}):")
|
||||
else:
|
||||
print(f"\n📂 {project_path}:")
|
||||
|
||||
if index_dirs:
|
||||
example_name = index_dirs[0].name
|
||||
print("\nUsage:")
|
||||
print(f' leann search {example_name} "your query"')
|
||||
print(f" leann ask {example_name} --interactive")
|
||||
for index_dir in index_dirs:
|
||||
total_indexes += 1
|
||||
index_name = index_dir.name
|
||||
meta_file = index_dir / "documents.leann.meta.json"
|
||||
status = "✓" if meta_file.exists() else "✗"
|
||||
|
||||
def load_documents(self, docs_dir: str):
|
||||
print(f" {total_indexes}. {index_name} [{status}]")
|
||||
if status == "✓":
|
||||
size_mb = sum(f.stat().st_size for f in index_dir.iterdir() if f.is_file()) / (
|
||||
1024 * 1024
|
||||
)
|
||||
print(f" Size: {size_mb:.1f} MB")
|
||||
|
||||
if total_indexes > 0:
|
||||
print(f"\nTotal: {total_indexes} indexes across {len(valid_projects)} projects")
|
||||
print("\nUsage (current project only):")
|
||||
|
||||
# Show example from current project
|
||||
current_indexes_dir = current_dir / ".leann" / "indexes"
|
||||
if current_indexes_dir.exists():
|
||||
current_index_dirs = [d for d in current_indexes_dir.iterdir() if d.is_dir()]
|
||||
if current_index_dirs:
|
||||
example_name = current_index_dirs[0].name
|
||||
print(f' leann search {example_name} "your query"')
|
||||
print(f" leann ask {example_name} --interactive")
|
||||
|
||||
def load_documents(self, docs_dir: str, custom_file_types: str | None = None):
|
||||
print(f"Loading documents from {docs_dir}...")
|
||||
if custom_file_types:
|
||||
print(f"Using custom file types: {custom_file_types}")
|
||||
|
||||
# Try to use better PDF parsers first
|
||||
# Build gitignore parser
|
||||
gitignore_matches = self._build_gitignore_parser(docs_dir)
|
||||
|
||||
# Try to use better PDF parsers first, but only if PDFs are requested
|
||||
documents = []
|
||||
docs_path = Path(docs_dir)
|
||||
|
||||
for file_path in docs_path.rglob("*.pdf"):
|
||||
print(f"Processing PDF: {file_path}")
|
||||
# Check if we should process PDFs
|
||||
should_process_pdfs = custom_file_types is None or ".pdf" in custom_file_types
|
||||
|
||||
# Try PyMuPDF first (best quality)
|
||||
text = extract_pdf_text_with_pymupdf(str(file_path))
|
||||
if text is None:
|
||||
# Try pdfplumber
|
||||
text = extract_pdf_text_with_pdfplumber(str(file_path))
|
||||
if should_process_pdfs:
|
||||
for file_path in docs_path.rglob("*.pdf"):
|
||||
# Check if file matches any exclude pattern
|
||||
relative_path = file_path.relative_to(docs_path)
|
||||
if self._should_exclude_file(relative_path, gitignore_matches):
|
||||
continue
|
||||
|
||||
if text:
|
||||
# Create a simple document structure
|
||||
from llama_index.core import Document
|
||||
print(f"Processing PDF: {file_path}")
|
||||
|
||||
doc = Document(text=text, metadata={"source": str(file_path)})
|
||||
documents.append(doc)
|
||||
else:
|
||||
# Fallback to default reader
|
||||
print(f"Using default reader for {file_path}")
|
||||
default_docs = SimpleDirectoryReader(
|
||||
str(file_path.parent),
|
||||
filename_as_id=True,
|
||||
required_exts=[file_path.suffix],
|
||||
).load_data()
|
||||
documents.extend(default_docs)
|
||||
# Try PyMuPDF first (best quality)
|
||||
text = extract_pdf_text_with_pymupdf(str(file_path))
|
||||
if text is None:
|
||||
# Try pdfplumber
|
||||
text = extract_pdf_text_with_pdfplumber(str(file_path))
|
||||
|
||||
if text:
|
||||
# Create a simple document structure
|
||||
from llama_index.core import Document
|
||||
|
||||
doc = Document(text=text, metadata={"source": str(file_path)})
|
||||
documents.append(doc)
|
||||
else:
|
||||
# Fallback to default reader
|
||||
print(f"Using default reader for {file_path}")
|
||||
try:
|
||||
default_docs = SimpleDirectoryReader(
|
||||
str(file_path.parent),
|
||||
filename_as_id=True,
|
||||
required_exts=[file_path.suffix],
|
||||
).load_data()
|
||||
documents.extend(default_docs)
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not process {file_path}: {e}")
|
||||
|
||||
# Load other file types with default reader
|
||||
other_docs = SimpleDirectoryReader(
|
||||
docs_dir,
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=[".txt", ".md", ".docx"],
|
||||
).load_data(show_progress=True)
|
||||
documents.extend(other_docs)
|
||||
if custom_file_types:
|
||||
# Parse custom file types from comma-separated string
|
||||
code_extensions = [ext.strip() for ext in custom_file_types.split(",") if ext.strip()]
|
||||
# Ensure extensions start with a dot
|
||||
code_extensions = [ext if ext.startswith(".") else f".{ext}" for ext in code_extensions]
|
||||
else:
|
||||
# Use default supported file types
|
||||
code_extensions = [
|
||||
# Original document types
|
||||
".txt",
|
||||
".md",
|
||||
".docx",
|
||||
".pptx",
|
||||
# Code files for Claude Code integration
|
||||
".py",
|
||||
".js",
|
||||
".ts",
|
||||
".jsx",
|
||||
".tsx",
|
||||
".java",
|
||||
".cpp",
|
||||
".c",
|
||||
".h",
|
||||
".hpp",
|
||||
".cs",
|
||||
".go",
|
||||
".rs",
|
||||
".rb",
|
||||
".php",
|
||||
".swift",
|
||||
".kt",
|
||||
".scala",
|
||||
".r",
|
||||
".sql",
|
||||
".sh",
|
||||
".bash",
|
||||
".zsh",
|
||||
".fish",
|
||||
".ps1",
|
||||
".bat",
|
||||
# Config and markup files
|
||||
".json",
|
||||
".yaml",
|
||||
".yml",
|
||||
".xml",
|
||||
".toml",
|
||||
".ini",
|
||||
".cfg",
|
||||
".conf",
|
||||
".html",
|
||||
".css",
|
||||
".scss",
|
||||
".less",
|
||||
".vue",
|
||||
".svelte",
|
||||
# Data science
|
||||
".ipynb",
|
||||
".R",
|
||||
".py",
|
||||
".jl",
|
||||
]
|
||||
# Try to load other file types, but don't fail if none are found
|
||||
try:
|
||||
# Create a custom file filter function using our PathSpec
|
||||
def file_filter(file_path: str) -> bool:
|
||||
"""Return True if file should be included (not excluded)"""
|
||||
try:
|
||||
docs_path_obj = Path(docs_dir)
|
||||
file_path_obj = Path(file_path)
|
||||
relative_path = file_path_obj.relative_to(docs_path_obj)
|
||||
return not self._should_exclude_file(relative_path, gitignore_matches)
|
||||
except (ValueError, OSError):
|
||||
return True # Include files that can't be processed
|
||||
|
||||
other_docs = SimpleDirectoryReader(
|
||||
docs_dir,
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=code_extensions,
|
||||
file_extractor={}, # Use default extractors
|
||||
filename_as_id=True,
|
||||
).load_data(show_progress=True)
|
||||
|
||||
# Filter documents after loading based on gitignore rules
|
||||
filtered_docs = []
|
||||
for doc in other_docs:
|
||||
file_path = doc.metadata.get("file_path", "")
|
||||
if file_filter(file_path):
|
||||
filtered_docs.append(doc)
|
||||
|
||||
documents.extend(filtered_docs)
|
||||
except ValueError as e:
|
||||
if "No files found" in str(e):
|
||||
print("No additional files found for other supported types.")
|
||||
else:
|
||||
raise e
|
||||
|
||||
all_texts = []
|
||||
|
||||
# Define code file extensions for intelligent chunking
|
||||
code_file_exts = {
|
||||
".py",
|
||||
".js",
|
||||
".ts",
|
||||
".jsx",
|
||||
".tsx",
|
||||
".java",
|
||||
".cpp",
|
||||
".c",
|
||||
".h",
|
||||
".hpp",
|
||||
".cs",
|
||||
".go",
|
||||
".rs",
|
||||
".rb",
|
||||
".php",
|
||||
".swift",
|
||||
".kt",
|
||||
".scala",
|
||||
".r",
|
||||
".sql",
|
||||
".sh",
|
||||
".bash",
|
||||
".zsh",
|
||||
".fish",
|
||||
".ps1",
|
||||
".bat",
|
||||
".json",
|
||||
".yaml",
|
||||
".yml",
|
||||
".xml",
|
||||
".toml",
|
||||
".ini",
|
||||
".cfg",
|
||||
".conf",
|
||||
".html",
|
||||
".css",
|
||||
".scss",
|
||||
".less",
|
||||
".vue",
|
||||
".svelte",
|
||||
".ipynb",
|
||||
".R",
|
||||
".jl",
|
||||
}
|
||||
|
||||
for doc in documents:
|
||||
nodes = self.node_parser.get_nodes_from_documents([doc])
|
||||
# Check if this is a code file based on source path
|
||||
source_path = doc.metadata.get("source", "")
|
||||
is_code_file = any(source_path.endswith(ext) for ext in code_file_exts)
|
||||
|
||||
# Use appropriate parser based on file type
|
||||
parser = self.code_parser if is_code_file else self.node_parser
|
||||
nodes = parser.get_nodes_from_documents([doc])
|
||||
|
||||
for node in nodes:
|
||||
all_texts.append(node.get_content())
|
||||
|
||||
@@ -215,15 +523,23 @@ Examples:
|
||||
|
||||
async def build_index(self, args):
|
||||
docs_dir = args.docs
|
||||
index_name = args.index_name
|
||||
# Use current directory name if index_name not provided
|
||||
if args.index_name:
|
||||
index_name = args.index_name
|
||||
else:
|
||||
index_name = Path.cwd().name
|
||||
print(f"Using current directory name as index: '{index_name}'")
|
||||
|
||||
index_dir = self.indexes_dir / index_name
|
||||
index_path = self.get_index_path(index_name)
|
||||
|
||||
print(f"📂 Indexing: {Path(docs_dir).resolve()}")
|
||||
|
||||
if index_dir.exists() and not args.force:
|
||||
print(f"Index '{index_name}' already exists. Use --force to rebuild.")
|
||||
return
|
||||
|
||||
all_texts = self.load_documents(docs_dir)
|
||||
all_texts = self.load_documents(docs_dir, args.file_types)
|
||||
if not all_texts:
|
||||
print("No documents found")
|
||||
return
|
||||
@@ -235,6 +551,7 @@ Examples:
|
||||
builder = LeannBuilder(
|
||||
backend_name=args.backend,
|
||||
embedding_model=args.embedding_model,
|
||||
embedding_mode=args.embedding_mode,
|
||||
graph_degree=args.graph_degree,
|
||||
complexity=args.complexity,
|
||||
is_compact=args.compact,
|
||||
@@ -248,6 +565,9 @@ Examples:
|
||||
builder.build_index(index_path)
|
||||
print(f"Index built at {index_path}")
|
||||
|
||||
# Register this project directory in global registry
|
||||
self.register_project_dir()
|
||||
|
||||
async def search_documents(self, args):
|
||||
index_name = args.index_name
|
||||
query = args.query
|
||||
@@ -308,6 +628,11 @@ Examples:
|
||||
if not user_input:
|
||||
continue
|
||||
|
||||
# Prepare LLM kwargs with thinking budget if specified
|
||||
llm_kwargs = {}
|
||||
if args.thinking_budget:
|
||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||
|
||||
response = chat.ask(
|
||||
user_input,
|
||||
top_k=args.top_k,
|
||||
@@ -316,11 +641,17 @@ Examples:
|
||||
prune_ratio=args.prune_ratio,
|
||||
recompute_embeddings=args.recompute_embeddings,
|
||||
pruning_strategy=args.pruning_strategy,
|
||||
llm_kwargs=llm_kwargs,
|
||||
)
|
||||
print(f"LEANN: {response}")
|
||||
else:
|
||||
query = input("Enter your question: ").strip()
|
||||
if query:
|
||||
# Prepare LLM kwargs with thinking budget if specified
|
||||
llm_kwargs = {}
|
||||
if args.thinking_budget:
|
||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||
|
||||
response = chat.ask(
|
||||
query,
|
||||
top_k=args.top_k,
|
||||
@@ -329,6 +660,7 @@ Examples:
|
||||
prune_ratio=args.prune_ratio,
|
||||
recompute_embeddings=args.recompute_embeddings,
|
||||
pruning_strategy=args.pruning_strategy,
|
||||
llm_kwargs=llm_kwargs,
|
||||
)
|
||||
print(f"LEANN: {response}")
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ Preserves all optimization parameters to ensure performance
|
||||
|
||||
import logging
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
@@ -35,7 +36,7 @@ def compute_embeddings(
|
||||
Args:
|
||||
texts: List of texts to compute embeddings for
|
||||
model_name: Model name
|
||||
mode: Computation mode ('sentence-transformers', 'openai', 'mlx')
|
||||
mode: Computation mode ('sentence-transformers', 'openai', 'mlx', 'ollama')
|
||||
is_build: Whether this is a build operation (shows progress bar)
|
||||
batch_size: Batch size for processing
|
||||
adaptive_optimization: Whether to use adaptive optimization based on batch size
|
||||
@@ -55,6 +56,8 @@ def compute_embeddings(
|
||||
return compute_embeddings_openai(texts, model_name)
|
||||
elif mode == "mlx":
|
||||
return compute_embeddings_mlx(texts, model_name)
|
||||
elif mode == "ollama":
|
||||
return compute_embeddings_ollama(texts, model_name, is_build=is_build)
|
||||
else:
|
||||
raise ValueError(f"Unsupported embedding mode: {mode}")
|
||||
|
||||
@@ -365,3 +368,262 @@ def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int =
|
||||
|
||||
# Stack numpy arrays
|
||||
return np.stack(all_embeddings)
|
||||
|
||||
|
||||
def compute_embeddings_ollama(
|
||||
texts: list[str], model_name: str, is_build: bool = False, host: str = "http://localhost:11434"
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Compute embeddings using Ollama API.
|
||||
|
||||
Args:
|
||||
texts: List of texts to compute embeddings for
|
||||
model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large")
|
||||
is_build: Whether this is a build operation (shows progress bar)
|
||||
host: Ollama host URL (default: http://localhost:11434)
|
||||
|
||||
Returns:
|
||||
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
||||
"""
|
||||
try:
|
||||
import requests
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The 'requests' library is required for Ollama embeddings. Install with: uv pip install requests"
|
||||
)
|
||||
|
||||
if not texts:
|
||||
raise ValueError("Cannot compute embeddings for empty text list")
|
||||
|
||||
logger.info(
|
||||
f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}'"
|
||||
)
|
||||
|
||||
# Check if Ollama is running
|
||||
try:
|
||||
response = requests.get(f"{host}/api/version", timeout=5)
|
||||
response.raise_for_status()
|
||||
except requests.exceptions.ConnectionError:
|
||||
error_msg = (
|
||||
f"❌ Could not connect to Ollama at {host}.\n\n"
|
||||
"Please ensure Ollama is running:\n"
|
||||
" • macOS/Linux: ollama serve\n"
|
||||
" • Windows: Make sure Ollama is running in the system tray\n\n"
|
||||
"Installation: https://ollama.com/download"
|
||||
)
|
||||
raise RuntimeError(error_msg)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Unexpected error connecting to Ollama: {e}")
|
||||
|
||||
# Check if model exists and provide helpful suggestions
|
||||
try:
|
||||
response = requests.get(f"{host}/api/tags", timeout=5)
|
||||
response.raise_for_status()
|
||||
models = response.json()
|
||||
model_names = [model["name"] for model in models.get("models", [])]
|
||||
|
||||
# Filter for embedding models (models that support embeddings)
|
||||
embedding_models = []
|
||||
suggested_embedding_models = [
|
||||
"nomic-embed-text",
|
||||
"mxbai-embed-large",
|
||||
"bge-m3",
|
||||
"all-minilm",
|
||||
"snowflake-arctic-embed",
|
||||
]
|
||||
|
||||
for model in model_names:
|
||||
# Check if it's an embedding model (by name patterns or known models)
|
||||
base_name = model.split(":")[0]
|
||||
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5"]):
|
||||
embedding_models.append(model)
|
||||
|
||||
# Check if model exists (handle versioned names)
|
||||
model_found = any(
|
||||
model_name == name.split(":")[0] or model_name == name for name in model_names
|
||||
)
|
||||
|
||||
if not model_found:
|
||||
error_msg = f"❌ Model '{model_name}' not found in local Ollama.\n\n"
|
||||
|
||||
# Suggest pulling the model
|
||||
error_msg += "📦 To install this embedding model:\n"
|
||||
error_msg += f" ollama pull {model_name}\n\n"
|
||||
|
||||
# Show available embedding models
|
||||
if embedding_models:
|
||||
error_msg += "✅ Available embedding models:\n"
|
||||
for model in embedding_models[:5]:
|
||||
error_msg += f" • {model}\n"
|
||||
if len(embedding_models) > 5:
|
||||
error_msg += f" ... and {len(embedding_models) - 5} more\n"
|
||||
else:
|
||||
error_msg += "💡 Popular embedding models to install:\n"
|
||||
for model in suggested_embedding_models[:3]:
|
||||
error_msg += f" • ollama pull {model}\n"
|
||||
|
||||
error_msg += "\n📚 Browse more: https://ollama.com/library"
|
||||
raise ValueError(error_msg)
|
||||
|
||||
# Verify the model supports embeddings by testing it
|
||||
try:
|
||||
test_response = requests.post(
|
||||
f"{host}/api/embeddings", json={"model": model_name, "prompt": "test"}, timeout=10
|
||||
)
|
||||
if test_response.status_code != 200:
|
||||
error_msg = (
|
||||
f"⚠️ Model '{model_name}' exists but may not support embeddings.\n\n"
|
||||
f"Please use an embedding model like:\n"
|
||||
)
|
||||
for model in suggested_embedding_models[:3]:
|
||||
error_msg += f" • {model}\n"
|
||||
raise ValueError(error_msg)
|
||||
except requests.exceptions.RequestException:
|
||||
# If test fails, continue anyway - model might still work
|
||||
pass
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.warning(f"Could not verify model existence: {e}")
|
||||
|
||||
# Process embeddings with optimized concurrent processing
|
||||
import requests
|
||||
|
||||
def get_single_embedding(text_idx_tuple):
|
||||
"""Helper function to get embedding for a single text."""
|
||||
text, idx = text_idx_tuple
|
||||
max_retries = 3
|
||||
retry_count = 0
|
||||
|
||||
# Truncate very long texts to avoid API issues
|
||||
truncated_text = text[:8000] if len(text) > 8000 else text
|
||||
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{host}/api/embeddings",
|
||||
json={"model": model_name, "prompt": truncated_text},
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
embedding = result.get("embedding")
|
||||
|
||||
if embedding is None:
|
||||
raise ValueError(f"No embedding returned for text {idx}")
|
||||
|
||||
return idx, embedding
|
||||
|
||||
except requests.exceptions.Timeout:
|
||||
retry_count += 1
|
||||
if retry_count >= max_retries:
|
||||
logger.warning(f"Timeout for text {idx} after {max_retries} retries")
|
||||
return idx, None
|
||||
|
||||
except Exception as e:
|
||||
if retry_count >= max_retries - 1:
|
||||
logger.error(f"Failed to get embedding for text {idx}: {e}")
|
||||
return idx, None
|
||||
retry_count += 1
|
||||
|
||||
return idx, None
|
||||
|
||||
# Determine if we should use concurrent processing
|
||||
use_concurrent = (
|
||||
len(texts) > 5 and not is_build
|
||||
) # Don't use concurrent in build mode to avoid overwhelming
|
||||
max_workers = min(4, len(texts)) # Limit concurrent requests to avoid overwhelming Ollama
|
||||
|
||||
all_embeddings = [None] * len(texts) # Pre-allocate list to maintain order
|
||||
failed_indices = []
|
||||
|
||||
if use_concurrent:
|
||||
logger.info(
|
||||
f"Using concurrent processing with {max_workers} workers for {len(texts)} texts"
|
||||
)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# Submit all tasks
|
||||
future_to_idx = {
|
||||
executor.submit(get_single_embedding, (text, idx)): idx
|
||||
for idx, text in enumerate(texts)
|
||||
}
|
||||
|
||||
# Add progress bar for concurrent processing
|
||||
try:
|
||||
if is_build or len(texts) > 10:
|
||||
from tqdm import tqdm
|
||||
|
||||
futures_iterator = tqdm(
|
||||
as_completed(future_to_idx),
|
||||
total=len(texts),
|
||||
desc="Computing Ollama embeddings",
|
||||
)
|
||||
else:
|
||||
futures_iterator = as_completed(future_to_idx)
|
||||
except ImportError:
|
||||
futures_iterator = as_completed(future_to_idx)
|
||||
|
||||
# Collect results as they complete
|
||||
for future in futures_iterator:
|
||||
try:
|
||||
idx, embedding = future.result()
|
||||
if embedding is not None:
|
||||
all_embeddings[idx] = embedding
|
||||
else:
|
||||
failed_indices.append(idx)
|
||||
except Exception as e:
|
||||
idx = future_to_idx[future]
|
||||
logger.error(f"Exception for text {idx}: {e}")
|
||||
failed_indices.append(idx)
|
||||
|
||||
else:
|
||||
# Sequential processing with progress bar
|
||||
show_progress = is_build or len(texts) > 10
|
||||
|
||||
try:
|
||||
if show_progress:
|
||||
from tqdm import tqdm
|
||||
|
||||
iterator = tqdm(
|
||||
enumerate(texts), total=len(texts), desc="Computing Ollama embeddings"
|
||||
)
|
||||
else:
|
||||
iterator = enumerate(texts)
|
||||
except ImportError:
|
||||
iterator = enumerate(texts)
|
||||
|
||||
for idx, text in iterator:
|
||||
result_idx, embedding = get_single_embedding((text, idx))
|
||||
if embedding is not None:
|
||||
all_embeddings[idx] = embedding
|
||||
else:
|
||||
failed_indices.append(idx)
|
||||
|
||||
# Handle failed embeddings
|
||||
if failed_indices:
|
||||
if len(failed_indices) == len(texts):
|
||||
raise RuntimeError("Failed to compute any embeddings")
|
||||
|
||||
logger.warning(f"Failed to compute embeddings for {len(failed_indices)}/{len(texts)} texts")
|
||||
|
||||
# Use zero embeddings as fallback for failed ones
|
||||
valid_embedding = next((e for e in all_embeddings if e is not None), None)
|
||||
if valid_embedding:
|
||||
embedding_dim = len(valid_embedding)
|
||||
for idx in failed_indices:
|
||||
all_embeddings[idx] = [0.0] * embedding_dim
|
||||
|
||||
# Remove None values and convert to numpy array
|
||||
all_embeddings = [e for e in all_embeddings if e is not None]
|
||||
|
||||
# Convert to numpy array and normalize
|
||||
embeddings = np.array(all_embeddings, dtype=np.float32)
|
||||
|
||||
# Normalize embeddings (L2 normalization)
|
||||
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
|
||||
embeddings = embeddings / (norms + 1e-8) # Add small epsilon to avoid division by zero
|
||||
|
||||
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
|
||||
|
||||
return embeddings
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import atexit
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import psutil
|
||||
|
||||
@@ -182,8 +184,8 @@ class EmbeddingServerManager:
|
||||
e.g., "leann_backend_diskann.embedding_server"
|
||||
"""
|
||||
self.backend_module_name = backend_module_name
|
||||
self.server_process: subprocess.Popen | None = None
|
||||
self.server_port: int | None = None
|
||||
self.server_process: Optional[subprocess.Popen] = None
|
||||
self.server_port: Optional[int] = None
|
||||
self._atexit_registered = False
|
||||
|
||||
def start_server(
|
||||
@@ -293,6 +295,8 @@ class EmbeddingServerManager:
|
||||
command.extend(["--passages-file", str(passages_file)])
|
||||
if embedding_mode != "sentence-transformers":
|
||||
command.extend(["--embedding-mode", embedding_mode])
|
||||
if kwargs.get("distance_metric"):
|
||||
command.extend(["--distance-metric", kwargs["distance_metric"]])
|
||||
|
||||
return command
|
||||
|
||||
@@ -301,13 +305,24 @@ class EmbeddingServerManager:
|
||||
project_root = Path(__file__).parent.parent.parent.parent.parent
|
||||
logger.info(f"Command: {' '.join(command)}")
|
||||
|
||||
# Let server output go directly to console
|
||||
# The server will respect LEANN_LOG_LEVEL environment variable
|
||||
# In CI environment, redirect output to avoid buffer deadlock
|
||||
# Embedding servers use many print statements that can fill buffers
|
||||
is_ci = os.environ.get("CI") == "true"
|
||||
if is_ci:
|
||||
stdout_target = subprocess.DEVNULL
|
||||
stderr_target = subprocess.DEVNULL
|
||||
logger.info("CI environment detected, redirecting embedding server output to DEVNULL")
|
||||
else:
|
||||
stdout_target = None # Direct to console for visible logs
|
||||
stderr_target = None # Direct to console for visible logs
|
||||
|
||||
# IMPORTANT: Use a new session so we can manage the whole process group reliably
|
||||
self.server_process = subprocess.Popen(
|
||||
command,
|
||||
cwd=project_root,
|
||||
stdout=None, # Direct to console
|
||||
stderr=None, # Direct to console
|
||||
stdout=stdout_target,
|
||||
stderr=stderr_target,
|
||||
start_new_session=True,
|
||||
)
|
||||
self.server_port = port
|
||||
logger.info(f"Server process started with PID: {self.server_process.pid}")
|
||||
@@ -349,34 +364,50 @@ class EmbeddingServerManager:
|
||||
logger.info(
|
||||
f"Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
|
||||
)
|
||||
self.server_process.terminate()
|
||||
# Try terminating the whole process group first (POSIX)
|
||||
try:
|
||||
pgid = os.getpgid(self.server_process.pid)
|
||||
os.killpg(pgid, signal.SIGTERM)
|
||||
except Exception:
|
||||
# Fallback to terminating just the process
|
||||
self.server_process.terminate()
|
||||
|
||||
try:
|
||||
self.server_process.wait(timeout=5)
|
||||
self.server_process.wait(timeout=3)
|
||||
logger.info(f"Server process {self.server_process.pid} terminated.")
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning(
|
||||
f"Server process {self.server_process.pid} did not terminate gracefully, killing it."
|
||||
f"Server process {self.server_process.pid} did not terminate gracefully within 3 seconds, killing it."
|
||||
)
|
||||
self.server_process.kill()
|
||||
|
||||
# Clean up process resources to prevent resource tracker warnings
|
||||
try:
|
||||
self.server_process.wait() # Ensure process is fully cleaned up
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
pgid = os.getpgid(self.server_process.pid)
|
||||
os.killpg(pgid, signal.SIGKILL)
|
||||
except Exception:
|
||||
self.server_process.kill()
|
||||
try:
|
||||
self.server_process.wait(timeout=2)
|
||||
logger.info(f"Server process {self.server_process.pid} killed successfully.")
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error(
|
||||
f"Failed to kill server process {self.server_process.pid} - it may be hung"
|
||||
)
|
||||
# Don't hang indefinitely
|
||||
|
||||
# Clean up process resources without waiting
|
||||
# The process should already be terminated/killed above
|
||||
# Don't wait here as it can hang CI indefinitely
|
||||
self.server_process = None
|
||||
|
||||
def _launch_server_process_colab(self, command: list, port: int) -> None:
|
||||
"""Launch the server process with Colab-specific settings."""
|
||||
logger.info(f"Colab Command: {' '.join(command)}")
|
||||
|
||||
# In Colab, we need to be more careful about process management
|
||||
# In Colab, redirect to DEVNULL to avoid pipe blocking
|
||||
# PIPE without reading can cause hangs
|
||||
self.server_process = subprocess.Popen(
|
||||
command,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
text=True,
|
||||
)
|
||||
self.server_port = port
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -34,7 +34,9 @@ class LeannBackendSearcherInterface(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _ensure_server_running(self, passages_source_file: str, port: int | None, **kwargs) -> int:
|
||||
def _ensure_server_running(
|
||||
self, passages_source_file: str, port: Optional[int], **kwargs
|
||||
) -> int:
|
||||
"""Ensure server is running"""
|
||||
pass
|
||||
|
||||
@@ -48,7 +50,7 @@ class LeannBackendSearcherInterface(ABC):
|
||||
prune_ratio: float = 0.0,
|
||||
recompute_embeddings: bool = False,
|
||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||
zmq_port: int | None = None,
|
||||
zmq_port: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> dict[str, Any]:
|
||||
"""Search for nearest neighbors
|
||||
@@ -74,7 +76,7 @@ class LeannBackendSearcherInterface(ABC):
|
||||
self,
|
||||
query: str,
|
||||
use_server_if_available: bool = True,
|
||||
zmq_port: int | None = None,
|
||||
zmq_port: Optional[int] = None,
|
||||
) -> np.ndarray:
|
||||
"""Compute embedding for a query string
|
||||
|
||||
|
||||
176
packages/leann-core/src/leann/mcp.py
Executable file
176
packages/leann-core/src/leann/mcp.py
Executable file
@@ -0,0 +1,176 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
def handle_request(request):
|
||||
if request.get("method") == "initialize":
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request.get("id"),
|
||||
"result": {
|
||||
"capabilities": {"tools": {}},
|
||||
"protocolVersion": "2024-11-05",
|
||||
"serverInfo": {"name": "leann-mcp", "version": "1.0.0"},
|
||||
},
|
||||
}
|
||||
|
||||
elif request.get("method") == "tools/list":
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request.get("id"),
|
||||
"result": {
|
||||
"tools": [
|
||||
{
|
||||
"name": "leann_search",
|
||||
"description": """🔍 Search code using natural language - like having a coding assistant who knows your entire codebase!
|
||||
|
||||
🎯 **Perfect for**:
|
||||
- "How does authentication work?" → finds auth-related code
|
||||
- "Error handling patterns" → locates try-catch blocks and error logic
|
||||
- "Database connection setup" → finds DB initialization code
|
||||
- "API endpoint definitions" → locates route handlers
|
||||
- "Configuration management" → finds config files and usage
|
||||
|
||||
💡 **Pro tip**: Use this before making any changes to understand existing patterns and conventions.""",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"index_name": {
|
||||
"type": "string",
|
||||
"description": "Name of the LEANN index to search. Use 'leann_list' first to see available indexes.",
|
||||
},
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query - can be natural language (e.g., 'how to handle errors') or technical terms (e.g., 'async function definition')",
|
||||
},
|
||||
"top_k": {
|
||||
"type": "integer",
|
||||
"default": 5,
|
||||
"minimum": 1,
|
||||
"maximum": 20,
|
||||
"description": "Number of search results to return. Use 5-10 for focused results, 15-20 for comprehensive exploration.",
|
||||
},
|
||||
"complexity": {
|
||||
"type": "integer",
|
||||
"default": 32,
|
||||
"minimum": 16,
|
||||
"maximum": 128,
|
||||
"description": "Search complexity level. Use 16-32 for fast searches (recommended), 64+ for higher precision when needed.",
|
||||
},
|
||||
},
|
||||
"required": ["index_name", "query"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "leann_status",
|
||||
"description": "📊 Check the health and stats of your code indexes - like a medical checkup for your codebase knowledge!",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"index_name": {
|
||||
"type": "string",
|
||||
"description": "Optional: Name of specific index to check. If not provided, shows status of all indexes.",
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "leann_list",
|
||||
"description": "📋 Show all your indexed codebases - your personal code library! Use this to see what's available for search.",
|
||||
"inputSchema": {"type": "object", "properties": {}},
|
||||
},
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
elif request.get("method") == "tools/call":
|
||||
tool_name = request["params"]["name"]
|
||||
args = request["params"].get("arguments", {})
|
||||
|
||||
try:
|
||||
if tool_name == "leann_search":
|
||||
# Validate required parameters
|
||||
if not args.get("index_name") or not args.get("query"):
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request.get("id"),
|
||||
"result": {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Error: Both index_name and query are required",
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
# Build simplified command
|
||||
cmd = [
|
||||
"leann",
|
||||
"search",
|
||||
args["index_name"],
|
||||
args["query"],
|
||||
f"--top-k={args.get('top_k', 5)}",
|
||||
f"--complexity={args.get('complexity', 32)}",
|
||||
]
|
||||
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
|
||||
elif tool_name == "leann_status":
|
||||
if args.get("index_name"):
|
||||
# Check specific index status - for now, we'll use leann list and filter
|
||||
result = subprocess.run(["leann", "list"], capture_output=True, text=True)
|
||||
# We could enhance this to show more detailed status per index
|
||||
else:
|
||||
# Show all indexes status
|
||||
result = subprocess.run(["leann", "list"], capture_output=True, text=True)
|
||||
|
||||
elif tool_name == "leann_list":
|
||||
result = subprocess.run(["leann", "list"], capture_output=True, text=True)
|
||||
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request.get("id"),
|
||||
"result": {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": result.stdout
|
||||
if result.returncode == 0
|
||||
else f"Error: {result.stderr}",
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request.get("id"),
|
||||
"error": {"code": -1, "message": str(e)},
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
for line in sys.stdin:
|
||||
try:
|
||||
request = json.loads(line.strip())
|
||||
response = handle_request(request)
|
||||
if response:
|
||||
print(json.dumps(response))
|
||||
sys.stdout.flush()
|
||||
except Exception as e:
|
||||
error_response = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": None,
|
||||
"error": {"code": -1, "message": str(e)},
|
||||
}
|
||||
print(json.dumps(error_response))
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -63,12 +63,19 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
if not self.embedding_model:
|
||||
raise ValueError("Cannot use recompute mode without 'embedding_model' in meta.json.")
|
||||
|
||||
# Get distance_metric from meta if not provided in kwargs
|
||||
distance_metric = (
|
||||
kwargs.get("distance_metric")
|
||||
or self.meta.get("backend_kwargs", {}).get("distance_metric")
|
||||
or "mips"
|
||||
)
|
||||
|
||||
server_started, actual_port = self.embedding_server_manager.start_server(
|
||||
port=port,
|
||||
model_name=self.embedding_model,
|
||||
embedding_mode=self.embedding_mode,
|
||||
passages_file=passages_source_file,
|
||||
distance_metric=kwargs.get("distance_metric"),
|
||||
distance_metric=distance_metric,
|
||||
enable_warmup=kwargs.get("enable_warmup", False),
|
||||
)
|
||||
if not server_started:
|
||||
@@ -125,10 +132,15 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
import msgpack
|
||||
import zmq
|
||||
|
||||
context = None
|
||||
socket = None
|
||||
try:
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.setsockopt(zmq.RCVTIMEO, 30000) # 30 second timeout
|
||||
socket.setsockopt(zmq.LINGER, 0) # Don't block on close
|
||||
socket.setsockopt(zmq.RCVTIMEO, 5000) # 5 second timeout
|
||||
socket.setsockopt(zmq.SNDTIMEO, 5000) # 5 second timeout
|
||||
socket.setsockopt(zmq.IMMEDIATE, 1) # Don't wait for connection
|
||||
socket.connect(f"tcp://localhost:{zmq_port}")
|
||||
|
||||
# Send embedding request
|
||||
@@ -140,9 +152,6 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
response_bytes = socket.recv()
|
||||
response = msgpack.unpackb(response_bytes)
|
||||
|
||||
socket.close()
|
||||
context.term()
|
||||
|
||||
# Convert response to numpy array
|
||||
if isinstance(response, list) and len(response) > 0:
|
||||
return np.array(response, dtype=np.float32)
|
||||
@@ -151,6 +160,11 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to compute embeddings via server: {e}")
|
||||
finally:
|
||||
if socket:
|
||||
socket.close(linger=0)
|
||||
if context:
|
||||
context.term()
|
||||
|
||||
@abstractmethod
|
||||
def search(
|
||||
@@ -162,7 +176,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
prune_ratio: float = 0.0,
|
||||
recompute_embeddings: bool = False,
|
||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||
zmq_port: int | None = None,
|
||||
zmq_port: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
@@ -184,7 +198,27 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def __del__(self):
|
||||
"""Ensures the embedding server is stopped when the searcher is destroyed."""
|
||||
def cleanup(self):
|
||||
"""Cleanup resources including embedding server and ZMQ connections."""
|
||||
# Stop embedding server
|
||||
if hasattr(self, "embedding_server_manager"):
|
||||
self.embedding_server_manager.stop_server()
|
||||
|
||||
# Set ZMQ linger but don't terminate global context
|
||||
try:
|
||||
import zmq
|
||||
|
||||
# Just set linger on the global instance
|
||||
ctx = zmq.Context.instance()
|
||||
ctx.linger = 0
|
||||
# NEVER call ctx.term() on the global instance
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def __del__(self):
|
||||
"""Ensures resources are cleaned up when the searcher is destroyed."""
|
||||
try:
|
||||
self.cleanup()
|
||||
except Exception:
|
||||
# Ignore errors during destruction
|
||||
pass
|
||||
|
||||
91
packages/leann-mcp/README.md
Normal file
91
packages/leann-mcp/README.md
Normal file
@@ -0,0 +1,91 @@
|
||||
# 🔥 LEANN Claude Code Integration
|
||||
|
||||
Transform your development workflow with intelligent code assistance using LEANN's semantic search directly in Claude Code.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
**Step 1:** First, complete the basic LEANN installation following the [📦 Installation guide](../../README.md#installation) in the root README:
|
||||
|
||||
```bash
|
||||
uv venv
|
||||
source .venv/bin/activate
|
||||
uv pip install leann
|
||||
```
|
||||
|
||||
**Step 2:** Install LEANN globally for MCP integration:
|
||||
```bash
|
||||
uv tool install leann-core
|
||||
```
|
||||
|
||||
This makes the `leann` command available system-wide, which `leann_mcp` requires.
|
||||
|
||||
## 🚀 Quick Setup
|
||||
|
||||
Add the LEANN MCP server to Claude Code:
|
||||
|
||||
```bash
|
||||
claude mcp add leann-server -- leann_mcp
|
||||
```
|
||||
|
||||
## 🛠️ Available Tools
|
||||
|
||||
Once connected, you'll have access to these powerful semantic search tools in Claude Code:
|
||||
|
||||
- **`leann_list`** - List all available indexes across your projects
|
||||
- **`leann_search`** - Perform semantic searches across code and documents
|
||||
- **`leann_ask`** - Ask natural language questions and get AI-powered answers from your codebase
|
||||
|
||||
## 🎯 Quick Start Example
|
||||
|
||||
```bash
|
||||
# Build an index for your project (change to your actual path)
|
||||
leann build my-project --docs ./
|
||||
|
||||
# Start Claude Code
|
||||
claude
|
||||
```
|
||||
|
||||
**Try this in Claude Code:**
|
||||
```
|
||||
Help me understand this codebase. List available indexes and search for authentication patterns.
|
||||
```
|
||||
|
||||
<p align="center">
|
||||
<img src="../../assets/claude_code_leann.png" alt="LEANN in Claude Code" width="80%">
|
||||
</p>
|
||||
|
||||
|
||||
## 🧠 How It Works
|
||||
|
||||
The integration consists of three key components working seamlessly together:
|
||||
|
||||
- **`leann`** - Core CLI tool for indexing and searching (installed globally via `uv tool install`)
|
||||
- **`leann_mcp`** - MCP server that wraps `leann` commands for Claude Code integration
|
||||
- **Claude Code** - Calls `leann_mcp`, which executes `leann` commands and returns intelligent results
|
||||
|
||||
## 📁 File Support
|
||||
|
||||
LEANN understands **30+ file types** including:
|
||||
- **Programming**: Python, JavaScript, TypeScript, Java, Go, Rust, C++, C#
|
||||
- **Data**: SQL, YAML, JSON, CSV, XML
|
||||
- **Documentation**: Markdown, TXT, PDF
|
||||
- **And many more!**
|
||||
|
||||
## 💾 Storage & Organization
|
||||
|
||||
- **Project indexes**: Stored in `.leann/` directory (just like `.git`)
|
||||
- **Global registry**: Project tracking at `~/.leann/projects.json`
|
||||
- **Multi-project support**: Switch between different codebases seamlessly
|
||||
- **Portable**: Transfer indexes between machines with minimal overhead
|
||||
|
||||
## 🗑️ Uninstalling
|
||||
|
||||
To remove the LEANN MCP server from Claude Code:
|
||||
|
||||
```bash
|
||||
claude mcp remove leann-server
|
||||
```
|
||||
To remove LEANN
|
||||
```
|
||||
uv pip uninstall leann leann-backend-hnsw leann-core
|
||||
```
|
||||
@@ -5,36 +5,32 @@ LEANN is a revolutionary vector database that democratizes personal AI. Transfor
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# Default installation (HNSW backend, recommended)
|
||||
# Default installation (includes both HNSW and DiskANN backends)
|
||||
uv pip install leann
|
||||
|
||||
# With DiskANN backend (for large-scale deployments)
|
||||
uv pip install leann[diskann]
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
```python
|
||||
from leann import LeannBuilder, LeannSearcher, LeannChat
|
||||
from pathlib import Path
|
||||
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
|
||||
|
||||
# Build an index
|
||||
builder = LeannBuilder(backend_name="hnsw")
|
||||
# Build an index (choose backend: "hnsw" or "diskann")
|
||||
builder = LeannBuilder(backend_name="hnsw") # or "diskann" for large-scale deployments
|
||||
builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
|
||||
builder.build_index("my_index.leann")
|
||||
builder.add_text("Tung Tung Tung Sahur called—they need their banana‑crocodile hybrid back")
|
||||
builder.build_index(INDEX_PATH)
|
||||
|
||||
# Search
|
||||
searcher = LeannSearcher("my_index.leann")
|
||||
results = searcher.search("storage savings", top_k=3)
|
||||
searcher = LeannSearcher(INDEX_PATH)
|
||||
results = searcher.search("fantastical AI-generated creatures", top_k=1)
|
||||
|
||||
# Chat with your data
|
||||
chat = LeannChat("my_index.leann", llm_config={"type": "ollama", "model": "llama3.2:1b"})
|
||||
response = chat.ask("How much storage does LEANN save?")
|
||||
chat = LeannChat(INDEX_PATH, llm_config={"type": "hf", "model": "Qwen/Qwen3-0.6B"})
|
||||
response = chat.ask("How much storage does LEANN save?", top_k=1)
|
||||
```
|
||||
|
||||
## Documentation
|
||||
|
||||
For full documentation, visit [https://leann.readthedocs.io](https://leann.readthedocs.io)
|
||||
|
||||
## License
|
||||
|
||||
MIT License
|
||||
MIT License
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "leann"
|
||||
version = "0.1.14"
|
||||
version = "0.2.7"
|
||||
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.9"
|
||||
@@ -24,19 +24,16 @@ classifiers = [
|
||||
"Programming Language :: Python :: 3.12",
|
||||
]
|
||||
|
||||
# Default installation: core + hnsw
|
||||
# Default installation: core + hnsw + diskann
|
||||
dependencies = [
|
||||
"leann-core>=0.1.0",
|
||||
"leann-backend-hnsw>=0.1.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
diskann = [
|
||||
"leann-backend-diskann>=0.1.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
# All backends now included by default
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/yourusername/leann"
|
||||
Documentation = "https://leann.readthedocs.io"
|
||||
Repository = "https://github.com/yourusername/leann"
|
||||
Issues = "https://github.com/yourusername/leann/issues"
|
||||
Repository = "https://github.com/yichuan-w/LEANN"
|
||||
Issues = "https://github.com/yichuan-w/LEANN/issues"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
import sqlite3
|
||||
import xml.etree.ElementTree as ET
|
||||
import xml.etree.ElementTree as ElementTree
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
@@ -26,7 +26,7 @@ def get_safe_path(s: str) -> str:
|
||||
def process_history(history: str):
|
||||
if history.startswith("<?xml") or history.startswith("<msg>"):
|
||||
try:
|
||||
root = ET.fromstring(history)
|
||||
root = ElementTree.fromstring(history)
|
||||
title = root.find(".//title").text if root.find(".//title") is not None else None
|
||||
quoted = (
|
||||
root.find(".//refermsg/content").text
|
||||
@@ -52,7 +52,8 @@ def get_message(history: dict | str):
|
||||
|
||||
def export_chathistory(user_id: str):
|
||||
res = requests.get(
|
||||
"http://localhost:48065/wechat/chatlog", params={"userId": user_id, "count": 100000}
|
||||
"http://localhost:48065/wechat/chatlog",
|
||||
params={"userId": user_id, "count": 100000},
|
||||
).json()
|
||||
for i in range(len(res["chatLogs"])):
|
||||
res["chatLogs"][i]["content"] = process_history(res["chatLogs"][i]["content"])
|
||||
@@ -116,7 +117,8 @@ def export_sqlite(
|
||||
all_users = requests.get("http://localhost:48065/wechat/allcontacts").json()
|
||||
for user in tqdm(all_users):
|
||||
cursor.execute(
|
||||
"INSERT OR IGNORE INTO users (id, name) VALUES (?, ?)", (user["arg"], user["title"])
|
||||
"INSERT OR IGNORE INTO users (id, name) VALUES (?, ?)",
|
||||
(user["arg"], user["title"]),
|
||||
)
|
||||
usr_chatlog = export_chathistory(user["arg"])
|
||||
for msg in usr_chatlog:
|
||||
|
||||
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
|
||||
[project]
|
||||
name = "leann-workspace"
|
||||
version = "0.1.0"
|
||||
requires-python = ">=3.10"
|
||||
requires-python = ">=3.9"
|
||||
|
||||
dependencies = [
|
||||
"leann-core",
|
||||
@@ -32,9 +32,9 @@ dependencies = [
|
||||
"pypdfium2>=4.30.0",
|
||||
# LlamaIndex core and readers - updated versions
|
||||
"llama-index>=0.12.44",
|
||||
"llama-index-readers-file>=0.4.0", # Essential for PDF parsing
|
||||
"llama-index-readers-docling",
|
||||
"llama-index-node-parser-docling",
|
||||
"llama-index-readers-file>=0.4.0", # Essential for PDF parsing
|
||||
# "llama-index-readers-docling", # Requires Python >= 3.10
|
||||
# "llama-index-node-parser-docling", # Requires Python >= 3.10
|
||||
"llama-index-vector-stores-faiss>=0.4.0",
|
||||
"llama-index-embeddings-huggingface>=0.5.5",
|
||||
# Other dependencies
|
||||
@@ -43,16 +43,33 @@ dependencies = [
|
||||
"mlx>=0.26.3; sys_platform == 'darwin'",
|
||||
"mlx-lm>=0.26.0; sys_platform == 'darwin'",
|
||||
"psutil>=5.8.0",
|
||||
"pybind11>=3.0.0",
|
||||
"pathspec>=0.12.1",
|
||||
"nbconvert>=7.16.6",
|
||||
"gitignore-parser>=0.1.12",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=7.0",
|
||||
"pytest-cov>=4.0",
|
||||
"pytest>=8.3.0", # Minimum version for Python 3.13 support
|
||||
"pytest-cov>=5.0",
|
||||
"pytest-xdist>=3.5", # For parallel test execution
|
||||
"black>=23.0",
|
||||
"ruff>=0.1.0",
|
||||
"ruff==0.12.7", # Fixed version to ensure consistent formatting across all environments
|
||||
"matplotlib",
|
||||
"huggingface-hub>=0.20.0",
|
||||
"pre-commit>=3.5.0",
|
||||
]
|
||||
|
||||
test = [
|
||||
"pytest>=8.3.0", # Minimum version for Python 3.13 support
|
||||
"pytest-timeout>=2.3",
|
||||
"anyio>=4.0", # For async test support (includes pytest plugin)
|
||||
"psutil>=5.9.0", # For process cleanup in tests
|
||||
"llama-index-core>=0.12.0",
|
||||
"llama-index-readers-file>=0.4.0",
|
||||
"python-dotenv>=1.0.0",
|
||||
"sentence-transformers>=2.2.0",
|
||||
]
|
||||
|
||||
diskann = [
|
||||
@@ -77,7 +94,7 @@ leann-backend-diskann = { path = "packages/leann-backend-diskann", editable = tr
|
||||
leann-backend-hnsw = { path = "packages/leann-backend-hnsw", editable = true }
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py310"
|
||||
target-version = "py39"
|
||||
line-length = 100
|
||||
extend-exclude = [
|
||||
"third_party",
|
||||
@@ -122,3 +139,33 @@ line-ending = "auto"
|
||||
dev = [
|
||||
"ruff>=0.12.4",
|
||||
]
|
||||
|
||||
[tool.lychee]
|
||||
accept = ["200", "403", "429", "503"]
|
||||
timeout = 20
|
||||
max_retries = 2
|
||||
exclude = ["localhost", "127.0.0.1", "example.com"]
|
||||
exclude_path = [".git/", ".venv/", "__pycache__/", "third_party/"]
|
||||
scheme = ["https", "http"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
python_files = ["test_*.py"]
|
||||
python_classes = ["Test*"]
|
||||
python_functions = ["test_*"]
|
||||
markers = [
|
||||
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
|
||||
"openai: marks tests that require OpenAI API key",
|
||||
]
|
||||
timeout = 300 # Reduced from 600s (10min) to 300s (5min) for CI safety
|
||||
timeout_method = "thread" # Use thread method to avoid non-daemon thread issues
|
||||
addopts = [
|
||||
"-v",
|
||||
"--tb=short",
|
||||
"--strict-markers",
|
||||
"--disable-warnings",
|
||||
]
|
||||
env = [
|
||||
"HF_HUB_DISABLE_SYMLINKS=1",
|
||||
"TOKENIZERS_PARALLELISM=false",
|
||||
]
|
||||
|
||||
@@ -19,16 +19,16 @@ uv pip install build twine delocate auditwheel scikit-build-core cmake pybind11
|
||||
build_package() {
|
||||
local package_dir=$1
|
||||
local package_name=$(basename $package_dir)
|
||||
|
||||
|
||||
echo "Building $package_name..."
|
||||
cd $package_dir
|
||||
|
||||
|
||||
# Clean previous builds
|
||||
rm -rf dist/ build/ _skbuild/
|
||||
|
||||
|
||||
# Build directly with pip wheel (avoids sdist issues)
|
||||
pip wheel . --no-deps -w dist
|
||||
|
||||
|
||||
# Repair wheel for binary packages
|
||||
if [[ "$package_name" != "leann-core" ]] && [[ "$package_name" != "leann" ]]; then
|
||||
if [[ "$OSTYPE" == "darwin"* ]]; then
|
||||
@@ -57,7 +57,7 @@ build_package() {
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
|
||||
echo "Built wheels in $package_dir/dist/"
|
||||
ls -la dist/
|
||||
cd - > /dev/null
|
||||
@@ -84,4 +84,4 @@ else
|
||||
fi
|
||||
|
||||
echo -e "\nBuild complete! Test with:"
|
||||
echo "uv pip install packages/*/dist/*.whl"
|
||||
echo "uv pip install packages/*/dist/*.whl"
|
||||
|
||||
@@ -28,4 +28,4 @@ else
|
||||
fi
|
||||
|
||||
echo "✅ Version updated to $NEW_VERSION"
|
||||
echo "✅ Dependencies updated to use leann-core==$NEW_VERSION"
|
||||
echo "✅ Dependencies updated to use leann-core==$NEW_VERSION"
|
||||
|
||||
103
scripts/diagnose_hang.sh
Executable file
103
scripts/diagnose_hang.sh
Executable file
@@ -0,0 +1,103 @@
|
||||
#!/bin/bash
|
||||
# Diagnostic script for debugging CI hangs
|
||||
|
||||
echo "========================================="
|
||||
echo " CI HANG DIAGNOSTIC SCRIPT"
|
||||
echo "========================================="
|
||||
echo ""
|
||||
|
||||
echo "📅 Current time: $(date)"
|
||||
echo "🖥️ Hostname: $(hostname)"
|
||||
echo "👤 User: $(whoami)"
|
||||
echo "📂 Working directory: $(pwd)"
|
||||
echo ""
|
||||
|
||||
echo "=== PYTHON ENVIRONMENT ==="
|
||||
python --version 2>&1 || echo "Python not found"
|
||||
pip list 2>&1 | head -20 || echo "pip not available"
|
||||
echo ""
|
||||
|
||||
echo "=== PROCESS INFORMATION ==="
|
||||
echo "Current shell PID: $$"
|
||||
echo "Parent PID: $PPID"
|
||||
echo ""
|
||||
|
||||
echo "All Python processes:"
|
||||
ps aux | grep -E "[p]ython" || echo "No Python processes"
|
||||
echo ""
|
||||
|
||||
echo "All pytest processes:"
|
||||
ps aux | grep -E "[p]ytest" || echo "No pytest processes"
|
||||
echo ""
|
||||
|
||||
echo "Embedding server processes:"
|
||||
ps aux | grep -E "[e]mbedding_server" || echo "No embedding server processes"
|
||||
echo ""
|
||||
|
||||
echo "Zombie processes:"
|
||||
ps aux | grep "<defunct>" || echo "No zombie processes"
|
||||
echo ""
|
||||
|
||||
echo "=== NETWORK INFORMATION ==="
|
||||
echo "Network listeners on typical embedding server ports:"
|
||||
ss -ltn 2>/dev/null | grep -E ":555[0-9]|:556[0-9]" || netstat -ltn 2>/dev/null | grep -E ":555[0-9]|:556[0-9]" || echo "No listeners on embedding ports"
|
||||
echo ""
|
||||
|
||||
echo "All network listeners:"
|
||||
ss -ltn 2>/dev/null | head -20 || netstat -ltn 2>/dev/null | head -20 || echo "Cannot get network info"
|
||||
echo ""
|
||||
|
||||
echo "=== FILE DESCRIPTORS ==="
|
||||
echo "Open files for current shell:"
|
||||
lsof -p $$ 2>/dev/null | head -20 || echo "lsof not available"
|
||||
echo ""
|
||||
|
||||
if [ -d "/proc/$$" ]; then
|
||||
echo "File descriptors for current shell (/proc/$$/fd):"
|
||||
ls -la /proc/$$/fd 2>/dev/null | head -20 || echo "Cannot access /proc/$$/fd"
|
||||
echo ""
|
||||
fi
|
||||
|
||||
echo "=== SYSTEM RESOURCES ==="
|
||||
echo "Memory usage:"
|
||||
free -h 2>/dev/null || vm_stat 2>/dev/null || echo "Cannot get memory info"
|
||||
echo ""
|
||||
|
||||
echo "Disk usage:"
|
||||
df -h . 2>/dev/null || echo "Cannot get disk info"
|
||||
echo ""
|
||||
|
||||
echo "CPU info:"
|
||||
nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo "Cannot get CPU info"
|
||||
echo ""
|
||||
|
||||
echo "=== PYTHON SPECIFIC CHECKS ==="
|
||||
python -c "
|
||||
import sys
|
||||
import os
|
||||
print(f'Python executable: {sys.executable}')
|
||||
print(f'Python path: {sys.path[:3]}...')
|
||||
print(f'Environment PYTHONPATH: {os.environ.get(\"PYTHONPATH\", \"Not set\")}')
|
||||
print(f'Site packages: {[p for p in sys.path if \"site-packages\" in p][:2]}')
|
||||
" 2>&1 || echo "Cannot run Python diagnostics"
|
||||
echo ""
|
||||
|
||||
echo "=== ZMQ SPECIFIC CHECKS ==="
|
||||
python -c "
|
||||
try:
|
||||
import zmq
|
||||
print(f'ZMQ version: {zmq.zmq_version()}')
|
||||
print(f'PyZMQ version: {zmq.pyzmq_version()}')
|
||||
ctx = zmq.Context.instance()
|
||||
print(f'ZMQ context instance: {ctx}')
|
||||
except Exception as e:
|
||||
print(f'ZMQ check failed: {e}')
|
||||
" 2>&1 || echo "Cannot check ZMQ"
|
||||
echo ""
|
||||
|
||||
echo "=== PYTEST CHECK ==="
|
||||
pytest --version 2>&1 || echo "pytest not found"
|
||||
echo ""
|
||||
|
||||
echo "=== END OF DIAGNOSTICS ==="
|
||||
echo "Generated at: $(date)"
|
||||
@@ -15,4 +15,4 @@ VERSION=$1
|
||||
git add . && git commit -m "chore: bump version to $VERSION" && git push
|
||||
|
||||
# Create release (triggers CI)
|
||||
gh release create v$VERSION --generate-notes
|
||||
gh release create v$VERSION --generate-notes
|
||||
|
||||
@@ -27,4 +27,4 @@ else
|
||||
else
|
||||
echo "Cancelled"
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
@@ -1,161 +0,0 @@
|
||||
import email
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core import Document, VectorStoreIndex
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
|
||||
|
||||
class EmlxReader(BaseReader):
|
||||
"""
|
||||
Apple Mail .emlx file reader.
|
||||
|
||||
Reads individual .emlx files from Apple Mail's storage format.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize."""
|
||||
pass
|
||||
|
||||
def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Load data from the input directory containing .emlx files.
|
||||
|
||||
Args:
|
||||
input_dir: Directory containing .emlx files
|
||||
**load_kwargs:
|
||||
max_count (int): Maximum amount of messages to read.
|
||||
"""
|
||||
docs: list[Document] = []
|
||||
max_count = load_kwargs.get("max_count", 1000)
|
||||
count = 0
|
||||
|
||||
# Walk through the directory recursively
|
||||
for dirpath, dirnames, filenames in os.walk(input_dir):
|
||||
# Skip hidden directories
|
||||
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
||||
|
||||
for filename in filenames:
|
||||
if count >= max_count:
|
||||
break
|
||||
|
||||
if filename.endswith(".emlx"):
|
||||
filepath = os.path.join(dirpath, filename)
|
||||
try:
|
||||
# Read the .emlx file
|
||||
with open(filepath, encoding="utf-8", errors="ignore") as f:
|
||||
content = f.read()
|
||||
|
||||
# .emlx files have a length prefix followed by the email content
|
||||
# The first line contains the length, followed by the email
|
||||
lines = content.split("\n", 1)
|
||||
if len(lines) >= 2:
|
||||
email_content = lines[1]
|
||||
|
||||
# Parse the email using Python's email module
|
||||
try:
|
||||
msg = email.message_from_string(email_content)
|
||||
|
||||
# Extract email metadata
|
||||
subject = msg.get("Subject", "No Subject")
|
||||
from_addr = msg.get("From", "Unknown")
|
||||
to_addr = msg.get("To", "Unknown")
|
||||
date = msg.get("Date", "Unknown")
|
||||
|
||||
# Extract email body
|
||||
body = ""
|
||||
if msg.is_multipart():
|
||||
for part in msg.walk():
|
||||
if (
|
||||
part.get_content_type() == "text/plain"
|
||||
or part.get_content_type() == "text/html"
|
||||
):
|
||||
body += part.get_payload(decode=True).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
# break
|
||||
else:
|
||||
body = msg.get_payload(decode=True).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
|
||||
# Create document content
|
||||
doc_content = f"""
|
||||
From: {from_addr}
|
||||
To: {to_addr}
|
||||
Subject: {subject}
|
||||
Date: {date}
|
||||
|
||||
{body}
|
||||
"""
|
||||
|
||||
# Create metadata
|
||||
metadata = {
|
||||
"file_path": filepath,
|
||||
"subject": subject,
|
||||
"from": from_addr,
|
||||
"to": to_addr,
|
||||
"date": date,
|
||||
"filename": filename,
|
||||
}
|
||||
if count == 0:
|
||||
print("--------------------------------")
|
||||
print("dir path", dirpath)
|
||||
print(metadata)
|
||||
print(doc_content)
|
||||
print("--------------------------------")
|
||||
body = []
|
||||
if msg.is_multipart():
|
||||
for part in msg.walk():
|
||||
print(
|
||||
"-------------------------------- get content type -------------------------------"
|
||||
)
|
||||
print(part.get_content_type())
|
||||
print(part)
|
||||
# body.append(part.get_payload(decode=True).decode('utf-8', errors='ignore'))
|
||||
print(
|
||||
"-------------------------------- get content type -------------------------------"
|
||||
)
|
||||
else:
|
||||
body = msg.get_payload(decode=True).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
print(body)
|
||||
|
||||
print(body)
|
||||
print("--------------------------------")
|
||||
doc = Document(text=doc_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"!!!!!!! Error parsing email from {filepath}: {e} !!!!!!!!")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"!!!!!!! Error reading file !!!!!!!! {filepath}: {e}")
|
||||
continue
|
||||
|
||||
print(f"Loaded {len(docs)} email documents")
|
||||
return docs
|
||||
|
||||
|
||||
# Use the custom EmlxReader instead of MboxReader
|
||||
documents = EmlxReader().load_data(
|
||||
"/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages",
|
||||
max_count=1000,
|
||||
) # Returns list of documents
|
||||
|
||||
# Configure the index with larger chunk size to handle long metadata
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
# Create a custom text splitter with larger chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=2048, chunk_overlap=200)
|
||||
|
||||
index = VectorStoreIndex.from_documents(
|
||||
documents, transformations=[text_splitter]
|
||||
) # Initialize index with documents
|
||||
|
||||
query_engine = index.as_query_engine()
|
||||
res = query_engine.query("Hows Berkeley Graduate Student Instructor")
|
||||
print(res)
|
||||
@@ -1,219 +0,0 @@
|
||||
import email
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core import Document, StorageContext, VectorStoreIndex
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
|
||||
|
||||
class EmlxReader(BaseReader):
|
||||
"""
|
||||
Apple Mail .emlx file reader.
|
||||
|
||||
Reads individual .emlx files from Apple Mail's storage format.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize."""
|
||||
pass
|
||||
|
||||
def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Load data from the input directory containing .emlx files.
|
||||
|
||||
Args:
|
||||
input_dir: Directory containing .emlx files
|
||||
**load_kwargs:
|
||||
max_count (int): Maximum amount of messages to read.
|
||||
"""
|
||||
docs: list[Document] = []
|
||||
max_count = load_kwargs.get("max_count", 1000)
|
||||
count = 0
|
||||
|
||||
# Walk through the directory recursively
|
||||
for dirpath, dirnames, filenames in os.walk(input_dir):
|
||||
# Skip hidden directories
|
||||
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
||||
|
||||
for filename in filenames:
|
||||
if count >= max_count:
|
||||
break
|
||||
|
||||
if filename.endswith(".emlx"):
|
||||
filepath = os.path.join(dirpath, filename)
|
||||
try:
|
||||
# Read the .emlx file
|
||||
with open(filepath, encoding="utf-8", errors="ignore") as f:
|
||||
content = f.read()
|
||||
|
||||
# .emlx files have a length prefix followed by the email content
|
||||
# The first line contains the length, followed by the email
|
||||
lines = content.split("\n", 1)
|
||||
if len(lines) >= 2:
|
||||
email_content = lines[1]
|
||||
|
||||
# Parse the email using Python's email module
|
||||
try:
|
||||
msg = email.message_from_string(email_content)
|
||||
|
||||
# Extract email metadata
|
||||
subject = msg.get("Subject", "No Subject")
|
||||
from_addr = msg.get("From", "Unknown")
|
||||
to_addr = msg.get("To", "Unknown")
|
||||
date = msg.get("Date", "Unknown")
|
||||
|
||||
# Extract email body
|
||||
body = ""
|
||||
if msg.is_multipart():
|
||||
for part in msg.walk():
|
||||
if part.get_content_type() == "text/plain":
|
||||
body = part.get_payload(decode=True).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
break
|
||||
else:
|
||||
body = msg.get_payload(decode=True).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
|
||||
# Create document content
|
||||
doc_content = f"""
|
||||
From: {from_addr}
|
||||
To: {to_addr}
|
||||
Subject: {subject}
|
||||
Date: {date}
|
||||
|
||||
{body}
|
||||
"""
|
||||
|
||||
# Create metadata
|
||||
metadata = {
|
||||
"file_path": filepath,
|
||||
"subject": subject,
|
||||
"from": from_addr,
|
||||
"to": to_addr,
|
||||
"date": date,
|
||||
"filename": filename,
|
||||
}
|
||||
|
||||
doc = Document(text=doc_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error parsing email from {filepath}: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error reading file {filepath}: {e}")
|
||||
continue
|
||||
|
||||
print(f"Loaded {len(docs)} email documents")
|
||||
return docs
|
||||
|
||||
|
||||
def create_and_save_index(mail_path: str, save_dir: str = "mail_index", max_count: int = 1000):
|
||||
"""
|
||||
Create the index from mail data and save it to disk.
|
||||
|
||||
Args:
|
||||
mail_path: Path to the mail directory
|
||||
save_dir: Directory to save the index
|
||||
max_count: Maximum number of emails to process
|
||||
"""
|
||||
print("Creating index from mail data...")
|
||||
|
||||
# Load documents
|
||||
documents = EmlxReader().load_data(mail_path, max_count=max_count)
|
||||
|
||||
if not documents:
|
||||
print("No documents loaded. Exiting.")
|
||||
return None
|
||||
|
||||
# Create text splitter
|
||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=0)
|
||||
|
||||
# Create index
|
||||
index = VectorStoreIndex.from_documents(documents, transformations=[text_splitter])
|
||||
|
||||
# Save the index
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
index.storage_context.persist(persist_dir=save_dir)
|
||||
print(f"Index saved to {save_dir}")
|
||||
|
||||
return index
|
||||
|
||||
|
||||
def load_index(save_dir: str = "mail_index"):
|
||||
"""
|
||||
Load the saved index from disk.
|
||||
|
||||
Args:
|
||||
save_dir: Directory where the index is saved
|
||||
|
||||
Returns:
|
||||
Loaded index or None if loading fails
|
||||
"""
|
||||
try:
|
||||
# Load storage context
|
||||
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
|
||||
|
||||
# Load index
|
||||
index = VectorStoreIndex.from_vector_store(
|
||||
storage_context.vector_store, storage_context=storage_context
|
||||
)
|
||||
|
||||
print(f"Index loaded from {save_dir}")
|
||||
return index
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading index: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def query_index(index, query: str):
|
||||
"""
|
||||
Query the loaded index.
|
||||
|
||||
Args:
|
||||
index: The loaded index
|
||||
query: The query string
|
||||
"""
|
||||
if index is None:
|
||||
print("No index available for querying.")
|
||||
return
|
||||
|
||||
query_engine = index.as_query_engine()
|
||||
response = query_engine.query(query)
|
||||
print(f"Query: {query}")
|
||||
print(f"Response: {response}")
|
||||
|
||||
|
||||
def main():
|
||||
mail_path = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages"
|
||||
save_dir = "mail_index"
|
||||
|
||||
# Check if index already exists
|
||||
if os.path.exists(save_dir) and os.path.exists(os.path.join(save_dir, "vector_store.json")):
|
||||
print("Loading existing index...")
|
||||
index = load_index(save_dir)
|
||||
else:
|
||||
print("Creating new index...")
|
||||
index = create_and_save_index(mail_path, save_dir, max_count=1000)
|
||||
|
||||
if index:
|
||||
# Example queries
|
||||
queries = [
|
||||
"Hows Berkeley Graduate Student Instructor",
|
||||
"What emails mention GSR appointments?",
|
||||
"Find emails about deadlines",
|
||||
]
|
||||
|
||||
for query in queries:
|
||||
print("\n" + "=" * 50)
|
||||
query_index(index, query)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,219 +0,0 @@
|
||||
import email
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core import Document, StorageContext, VectorStoreIndex
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
|
||||
|
||||
class EmlxReader(BaseReader):
|
||||
"""
|
||||
Apple Mail .emlx file reader with reduced metadata.
|
||||
|
||||
Reads individual .emlx files from Apple Mail's storage format.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize."""
|
||||
pass
|
||||
|
||||
def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Load data from the input directory containing .emlx files.
|
||||
|
||||
Args:
|
||||
input_dir: Directory containing .emlx files
|
||||
**load_kwargs:
|
||||
max_count (int): Maximum amount of messages to read.
|
||||
"""
|
||||
docs: list[Document] = []
|
||||
max_count = load_kwargs.get("max_count", 1000)
|
||||
count = 0
|
||||
|
||||
# Walk through the directory recursively
|
||||
for dirpath, dirnames, filenames in os.walk(input_dir):
|
||||
# Skip hidden directories
|
||||
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
||||
|
||||
for filename in filenames:
|
||||
if count >= max_count:
|
||||
break
|
||||
|
||||
if filename.endswith(".emlx"):
|
||||
filepath = os.path.join(dirpath, filename)
|
||||
try:
|
||||
# Read the .emlx file
|
||||
with open(filepath, encoding="utf-8", errors="ignore") as f:
|
||||
content = f.read()
|
||||
|
||||
# .emlx files have a length prefix followed by the email content
|
||||
# The first line contains the length, followed by the email
|
||||
lines = content.split("\n", 1)
|
||||
if len(lines) >= 2:
|
||||
email_content = lines[1]
|
||||
|
||||
# Parse the email using Python's email module
|
||||
try:
|
||||
msg = email.message_from_string(email_content)
|
||||
|
||||
# Extract email metadata
|
||||
subject = msg.get("Subject", "No Subject")
|
||||
from_addr = msg.get("From", "Unknown")
|
||||
to_addr = msg.get("To", "Unknown")
|
||||
date = msg.get("Date", "Unknown")
|
||||
|
||||
# Extract email body
|
||||
body = ""
|
||||
if msg.is_multipart():
|
||||
for part in msg.walk():
|
||||
if part.get_content_type() == "text/plain":
|
||||
body = part.get_payload(decode=True).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
break
|
||||
else:
|
||||
body = msg.get_payload(decode=True).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
|
||||
# Create document content with metadata embedded in text
|
||||
doc_content = f"""
|
||||
From: {from_addr}
|
||||
To: {to_addr}
|
||||
Subject: {subject}
|
||||
Date: {date}
|
||||
|
||||
{body}
|
||||
"""
|
||||
|
||||
# Create minimal metadata (only essential info)
|
||||
metadata = {
|
||||
"subject": subject[:50], # Truncate subject
|
||||
"from": from_addr[:30], # Truncate from
|
||||
"date": date[:20], # Truncate date
|
||||
"filename": filename, # Keep filename
|
||||
}
|
||||
|
||||
doc = Document(text=doc_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error parsing email from {filepath}: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error reading file {filepath}: {e}")
|
||||
continue
|
||||
|
||||
print(f"Loaded {len(docs)} email documents")
|
||||
return docs
|
||||
|
||||
|
||||
def create_and_save_index(
|
||||
mail_path: str, save_dir: str = "mail_index_small", max_count: int = 1000
|
||||
):
|
||||
"""
|
||||
Create the index from mail data and save it to disk.
|
||||
|
||||
Args:
|
||||
mail_path: Path to the mail directory
|
||||
save_dir: Directory to save the index
|
||||
max_count: Maximum number of emails to process
|
||||
"""
|
||||
print("Creating index from mail data with small chunks...")
|
||||
|
||||
# Load documents
|
||||
documents = EmlxReader().load_data(mail_path, max_count=max_count)
|
||||
|
||||
if not documents:
|
||||
print("No documents loaded. Exiting.")
|
||||
return None
|
||||
|
||||
# Create text splitter with small chunk size
|
||||
text_splitter = SentenceSplitter(chunk_size=512, chunk_overlap=50)
|
||||
|
||||
# Create index
|
||||
index = VectorStoreIndex.from_documents(documents, transformations=[text_splitter])
|
||||
|
||||
# Save the index
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
index.storage_context.persist(persist_dir=save_dir)
|
||||
print(f"Index saved to {save_dir}")
|
||||
|
||||
return index
|
||||
|
||||
|
||||
def load_index(save_dir: str = "mail_index_small"):
|
||||
"""
|
||||
Load the saved index from disk.
|
||||
|
||||
Args:
|
||||
save_dir: Directory where the index is saved
|
||||
|
||||
Returns:
|
||||
Loaded index or None if loading fails
|
||||
"""
|
||||
try:
|
||||
# Load storage context
|
||||
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
|
||||
|
||||
# Load index
|
||||
index = VectorStoreIndex.from_vector_store(
|
||||
storage_context.vector_store, storage_context=storage_context
|
||||
)
|
||||
|
||||
print(f"Index loaded from {save_dir}")
|
||||
return index
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading index: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def query_index(index, query: str):
|
||||
"""
|
||||
Query the loaded index.
|
||||
|
||||
Args:
|
||||
index: The loaded index
|
||||
query: The query string
|
||||
"""
|
||||
if index is None:
|
||||
print("No index available for querying.")
|
||||
return
|
||||
|
||||
query_engine = index.as_query_engine()
|
||||
response = query_engine.query(query)
|
||||
print(f"Query: {query}")
|
||||
print(f"Response: {response}")
|
||||
|
||||
|
||||
def main():
|
||||
mail_path = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages"
|
||||
save_dir = "mail_index_small"
|
||||
|
||||
# Check if index already exists
|
||||
if os.path.exists(save_dir) and os.path.exists(os.path.join(save_dir, "vector_store.json")):
|
||||
print("Loading existing index...")
|
||||
index = load_index(save_dir)
|
||||
else:
|
||||
print("Creating new index...")
|
||||
index = create_and_save_index(mail_path, save_dir, max_count=1000)
|
||||
|
||||
if index:
|
||||
# Example queries
|
||||
queries = [
|
||||
"Hows Berkeley Graduate Student Instructor",
|
||||
"What emails mention GSR appointments?",
|
||||
"Find emails about deadlines",
|
||||
]
|
||||
|
||||
for query in queries:
|
||||
print("\n" + "=" * 50)
|
||||
query_index(index, query)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,154 +0,0 @@
|
||||
import email
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core import Document, VectorStoreIndex
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
|
||||
|
||||
class EmlxReader(BaseReader):
|
||||
"""
|
||||
Apple Mail .emlx file reader.
|
||||
|
||||
Reads individual .emlx files from Apple Mail's storage format.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize."""
|
||||
pass
|
||||
|
||||
def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Load data from the input directory containing .emlx files.
|
||||
|
||||
Args:
|
||||
input_dir: Directory containing .emlx files
|
||||
**load_kwargs:
|
||||
max_count (int): Maximum amount of messages to read.
|
||||
"""
|
||||
docs: list[Document] = []
|
||||
max_count = load_kwargs.get("max_count", 1000)
|
||||
count = 0
|
||||
|
||||
# Check if directory exists and is accessible
|
||||
if not os.path.exists(input_dir):
|
||||
print(f"Error: Directory '{input_dir}' does not exist")
|
||||
return docs
|
||||
|
||||
if not os.access(input_dir, os.R_OK):
|
||||
print(f"Error: Directory '{input_dir}' is not accessible (permission denied)")
|
||||
print("This is likely due to macOS security restrictions on Mail app data")
|
||||
return docs
|
||||
|
||||
print(f"Scanning directory: {input_dir}")
|
||||
|
||||
# Walk through the directory recursively
|
||||
for dirpath, dirnames, filenames in os.walk(input_dir):
|
||||
# Skip hidden directories
|
||||
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
||||
|
||||
for filename in filenames:
|
||||
if count >= max_count:
|
||||
break
|
||||
|
||||
if filename.endswith(".emlx"):
|
||||
filepath = os.path.join(dirpath, filename)
|
||||
print(f"Found .emlx file: {filepath}")
|
||||
try:
|
||||
# Read the .emlx file
|
||||
with open(filepath, encoding="utf-8", errors="ignore") as f:
|
||||
content = f.read()
|
||||
|
||||
# .emlx files have a length prefix followed by the email content
|
||||
# The first line contains the length, followed by the email
|
||||
lines = content.split("\n", 1)
|
||||
if len(lines) >= 2:
|
||||
email_content = lines[1]
|
||||
|
||||
# Parse the email using Python's email module
|
||||
try:
|
||||
msg = email.message_from_string(email_content)
|
||||
|
||||
# Extract email metadata
|
||||
subject = msg.get("Subject", "No Subject")
|
||||
from_addr = msg.get("From", "Unknown")
|
||||
to_addr = msg.get("To", "Unknown")
|
||||
date = msg.get("Date", "Unknown")
|
||||
|
||||
# Extract email body
|
||||
body = ""
|
||||
if msg.is_multipart():
|
||||
for part in msg.walk():
|
||||
if part.get_content_type() == "text/plain":
|
||||
body = part.get_payload(decode=True).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
break
|
||||
else:
|
||||
body = msg.get_payload(decode=True).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
|
||||
# Create document content
|
||||
doc_content = f"""
|
||||
From: {from_addr}
|
||||
To: {to_addr}
|
||||
Subject: {subject}
|
||||
Date: {date}
|
||||
|
||||
{body}
|
||||
"""
|
||||
|
||||
# Create metadata
|
||||
metadata = {
|
||||
"file_path": filepath,
|
||||
"subject": subject,
|
||||
"from": from_addr,
|
||||
"to": to_addr,
|
||||
"date": date,
|
||||
"filename": filename,
|
||||
}
|
||||
|
||||
doc = Document(text=doc_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error parsing email from {filepath}: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error reading file {filepath}: {e}")
|
||||
continue
|
||||
|
||||
print(f"Loaded {len(docs)} email documents")
|
||||
return docs
|
||||
|
||||
|
||||
def main():
|
||||
# Use the current directory where the sample.emlx file is located
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
print("Testing EmlxReader with sample .emlx file...")
|
||||
print(f"Scanning directory: {current_dir}")
|
||||
|
||||
# Use the custom EmlxReader
|
||||
documents = EmlxReader().load_data(current_dir, max_count=1000)
|
||||
|
||||
if not documents:
|
||||
print("No documents loaded. Make sure sample.emlx exists in the examples directory.")
|
||||
return
|
||||
|
||||
print(f"\nSuccessfully loaded {len(documents)} document(s)")
|
||||
|
||||
# Initialize index with documents
|
||||
index = VectorStoreIndex.from_documents(documents)
|
||||
query_engine = index.as_query_engine()
|
||||
|
||||
print("\nTesting query: 'Hows Berkeley Graduate Student Instructor'")
|
||||
res = query_engine.query("Hows Berkeley Graduate Student Instructor")
|
||||
print(f"Response: {res}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,105 +0,0 @@
|
||||
import os
|
||||
|
||||
from llama_index.core import StorageContext, VectorStoreIndex
|
||||
|
||||
|
||||
def load_index(save_dir: str = "mail_index"):
|
||||
"""
|
||||
Load the saved index from disk.
|
||||
|
||||
Args:
|
||||
save_dir: Directory where the index is saved
|
||||
|
||||
Returns:
|
||||
Loaded index or None if loading fails
|
||||
"""
|
||||
try:
|
||||
# Load storage context
|
||||
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
|
||||
|
||||
# Load index
|
||||
index = VectorStoreIndex.from_vector_store(
|
||||
storage_context.vector_store, storage_context=storage_context
|
||||
)
|
||||
|
||||
print(f"Index loaded from {save_dir}")
|
||||
return index
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading index: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def query_index(index, query: str):
|
||||
"""
|
||||
Query the loaded index.
|
||||
|
||||
Args:
|
||||
index: The loaded index
|
||||
query: The query string
|
||||
"""
|
||||
if index is None:
|
||||
print("No index available for querying.")
|
||||
return
|
||||
|
||||
query_engine = index.as_query_engine()
|
||||
response = query_engine.query(query)
|
||||
print(f"\nQuery: {query}")
|
||||
print(f"Response: {response}")
|
||||
|
||||
|
||||
def main():
|
||||
save_dir = "mail_index"
|
||||
|
||||
# Check if index exists
|
||||
if not os.path.exists(save_dir) or not os.path.exists(
|
||||
os.path.join(save_dir, "vector_store.json")
|
||||
):
|
||||
print(f"Index not found in {save_dir}")
|
||||
print("Please run mail_reader_save_load.py first to create the index.")
|
||||
return
|
||||
|
||||
# Load the index
|
||||
index = load_index(save_dir)
|
||||
|
||||
if not index:
|
||||
print("Failed to load index.")
|
||||
return
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Email Query Interface")
|
||||
print("=" * 60)
|
||||
print("Type 'quit' to exit")
|
||||
print("Type 'help' for example queries")
|
||||
print("=" * 60)
|
||||
|
||||
# Interactive query loop
|
||||
while True:
|
||||
try:
|
||||
query = input("\nEnter your query: ").strip()
|
||||
|
||||
if query.lower() == "quit":
|
||||
print("Goodbye!")
|
||||
break
|
||||
elif query.lower() == "help":
|
||||
print("\nExample queries:")
|
||||
print("- Hows Berkeley Graduate Student Instructor")
|
||||
print("- What emails mention GSR appointments?")
|
||||
print("- Find emails about deadlines")
|
||||
print("- Search for emails from specific sender")
|
||||
print("- Find emails about meetings")
|
||||
continue
|
||||
elif not query:
|
||||
continue
|
||||
|
||||
query_index(index, query)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nGoodbye!")
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"Error processing query: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,117 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Debug script to test ZMQ communication with the exact same setup as main_cli_example.py
|
||||
"""
|
||||
|
||||
import sys
|
||||
import time
|
||||
|
||||
import zmq
|
||||
|
||||
sys.path.append("packages/leann-backend-diskann")
|
||||
from leann_backend_diskann import embedding_pb2
|
||||
|
||||
|
||||
def test_zmq_with_same_model():
|
||||
print("=== Testing ZMQ with same model as main_cli_example.py ===")
|
||||
|
||||
# Test the exact same model that main_cli_example.py uses
|
||||
model_name = "sentence-transformers/all-mpnet-base-v2"
|
||||
|
||||
# Start server with the same model
|
||||
import subprocess
|
||||
|
||||
server_cmd = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"packages.leann-backend-diskann.leann_backend_diskann.embedding_server",
|
||||
"--zmq-port",
|
||||
"5556", # Use different port to avoid conflicts
|
||||
"--model-name",
|
||||
model_name,
|
||||
]
|
||||
|
||||
print(f"Starting server with command: {' '.join(server_cmd)}")
|
||||
server_process = subprocess.Popen(
|
||||
server_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
|
||||
)
|
||||
|
||||
# Wait for server to start
|
||||
print("Waiting for server to start...")
|
||||
time.sleep(10)
|
||||
|
||||
# Check if server is running
|
||||
if server_process.poll() is not None:
|
||||
stdout, stderr = server_process.communicate()
|
||||
print(f"Server failed to start. stdout: {stdout}")
|
||||
print(f"Server failed to start. stderr: {stderr}")
|
||||
return False
|
||||
|
||||
print(f"Server started with PID: {server_process.pid}")
|
||||
|
||||
try:
|
||||
# Test client
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REQ)
|
||||
socket.connect("tcp://127.0.0.1:5556")
|
||||
socket.setsockopt(zmq.RCVTIMEO, 30000) # 30 second timeout like C++
|
||||
socket.setsockopt(zmq.SNDTIMEO, 30000)
|
||||
|
||||
# Create request with same format as C++
|
||||
request = embedding_pb2.NodeEmbeddingRequest()
|
||||
request.node_ids.extend([0, 1, 2, 3, 4]) # Test with some node IDs
|
||||
|
||||
print(f"Sending request with {len(request.node_ids)} node IDs...")
|
||||
start_time = time.time()
|
||||
|
||||
# Send request
|
||||
socket.send(request.SerializeToString())
|
||||
|
||||
# Receive response
|
||||
response_data = socket.recv()
|
||||
end_time = time.time()
|
||||
|
||||
print(f"Received response in {end_time - start_time:.3f} seconds")
|
||||
print(f"Response size: {len(response_data)} bytes")
|
||||
|
||||
# Parse response
|
||||
response = embedding_pb2.NodeEmbeddingResponse()
|
||||
response.ParseFromString(response_data)
|
||||
|
||||
print(f"Response dimensions: {list(response.dimensions)}")
|
||||
print(f"Embeddings data size: {len(response.embeddings_data)} bytes")
|
||||
print(f"Missing IDs: {list(response.missing_ids)}")
|
||||
|
||||
# Calculate expected size
|
||||
if len(response.dimensions) == 2:
|
||||
batch_size = response.dimensions[0]
|
||||
embedding_dim = response.dimensions[1]
|
||||
expected_bytes = batch_size * embedding_dim * 4 # 4 bytes per float
|
||||
print(f"Expected bytes: {expected_bytes}, Actual: {len(response.embeddings_data)}")
|
||||
|
||||
if len(response.embeddings_data) == expected_bytes:
|
||||
print("✅ Response format is correct!")
|
||||
return True
|
||||
else:
|
||||
print("❌ Response format mismatch!")
|
||||
return False
|
||||
else:
|
||||
print("❌ Invalid response dimensions!")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error during ZMQ test: {e}")
|
||||
return False
|
||||
finally:
|
||||
# Clean up
|
||||
server_process.terminate()
|
||||
server_process.wait()
|
||||
print("Server terminated")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = test_zmq_with_same_model()
|
||||
if success:
|
||||
print("\n✅ ZMQ communication test passed!")
|
||||
else:
|
||||
print("\n❌ ZMQ communication test failed!")
|
||||
106
tests/README.md
Normal file
106
tests/README.md
Normal file
@@ -0,0 +1,106 @@
|
||||
# LEANN Tests
|
||||
|
||||
This directory contains automated tests for the LEANN project using pytest.
|
||||
|
||||
## Test Files
|
||||
|
||||
### `test_readme_examples.py`
|
||||
Tests the examples shown in README.md:
|
||||
- The basic example code that users see first (parametrized for both HNSW and DiskANN backends)
|
||||
- Import statements work correctly
|
||||
- Different backend options (HNSW, DiskANN)
|
||||
- Different LLM configuration options (parametrized for both backends)
|
||||
- **All main README examples are tested with both HNSW and DiskANN backends using pytest parametrization**
|
||||
|
||||
### `test_basic.py`
|
||||
Basic functionality tests that verify:
|
||||
- All packages can be imported correctly
|
||||
- C++ extensions (FAISS, DiskANN) load properly
|
||||
- Basic index building and searching works for both HNSW and DiskANN backends
|
||||
- Uses parametrized tests to test both backends
|
||||
|
||||
### `test_document_rag.py`
|
||||
Tests the document RAG example functionality:
|
||||
- Tests with facebook/contriever embeddings
|
||||
- Tests with OpenAI embeddings (if API key is available)
|
||||
- Tests error handling with invalid parameters
|
||||
- Verifies that normalized embeddings are detected and cosine distance is used
|
||||
|
||||
### `test_diskann_partition.py`
|
||||
Tests DiskANN graph partitioning functionality:
|
||||
- Tests DiskANN index building without partitioning (baseline)
|
||||
- Tests automatic graph partitioning with `is_recompute=True`
|
||||
- Verifies that partition files are created and large files are cleaned up for storage saving
|
||||
- Tests search functionality with partitioned indices
|
||||
- Validates medoid and max_base_norm file generation and usage
|
||||
- Includes performance comparison between DiskANN (with partition) and HNSW
|
||||
- **Note**: These tests are skipped in CI due to hardware requirements and computation time
|
||||
|
||||
## Running Tests
|
||||
|
||||
### Install test dependencies:
|
||||
```bash
|
||||
# Using extras
|
||||
uv pip install -e ".[test]"
|
||||
```
|
||||
|
||||
### Run all tests:
|
||||
```bash
|
||||
pytest tests/
|
||||
|
||||
# Or with coverage
|
||||
pytest tests/ --cov=leann --cov-report=html
|
||||
|
||||
# Run in parallel (faster)
|
||||
pytest tests/ -n auto
|
||||
```
|
||||
|
||||
### Run specific tests:
|
||||
```bash
|
||||
# Only basic tests
|
||||
pytest tests/test_basic.py
|
||||
|
||||
# Only tests that don't require OpenAI
|
||||
pytest tests/ -m "not openai"
|
||||
|
||||
# Skip slow tests
|
||||
pytest tests/ -m "not slow"
|
||||
|
||||
# Run DiskANN partition tests (requires local machine, not CI)
|
||||
pytest tests/test_diskann_partition.py
|
||||
```
|
||||
|
||||
### Run with specific backend:
|
||||
```bash
|
||||
# Test only HNSW backend
|
||||
pytest tests/test_basic.py::test_backend_basic[hnsw]
|
||||
pytest tests/test_readme_examples.py::test_readme_basic_example[hnsw]
|
||||
|
||||
# Test only DiskANN backend
|
||||
pytest tests/test_basic.py::test_backend_basic[diskann]
|
||||
pytest tests/test_readme_examples.py::test_readme_basic_example[diskann]
|
||||
|
||||
# All DiskANN tests (parametrized + specialized partition tests)
|
||||
pytest tests/ -k diskann
|
||||
```
|
||||
|
||||
## CI/CD Integration
|
||||
|
||||
Tests are automatically run in GitHub Actions:
|
||||
1. After building wheel packages
|
||||
2. On multiple Python versions (3.9 - 3.13)
|
||||
3. On both Ubuntu and macOS
|
||||
4. Using pytest with appropriate markers and flags
|
||||
|
||||
### pytest.ini Configuration
|
||||
|
||||
The `pytest.ini` file configures:
|
||||
- Test discovery paths
|
||||
- Default timeout (600 seconds)
|
||||
- Environment variables (HF_HUB_DISABLE_SYMLINKS, TOKENIZERS_PARALLELISM)
|
||||
- Custom markers for slow and OpenAI tests
|
||||
- Verbose output with short tracebacks
|
||||
|
||||
### Known Issues
|
||||
|
||||
- OpenAI tests are automatically skipped if no API key is provided
|
||||
301
tests/conftest.py
Normal file
301
tests/conftest.py
Normal file
@@ -0,0 +1,301 @@
|
||||
"""Global test configuration and cleanup fixtures."""
|
||||
|
||||
import faulthandler
|
||||
import os
|
||||
import signal
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
# Enable faulthandler to dump stack traces
|
||||
faulthandler.enable()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def _ci_backtraces():
|
||||
"""Dump stack traces before CI timeout to diagnose hanging."""
|
||||
if os.getenv("CI") == "true":
|
||||
# Dump stack traces 10s before the 180s timeout
|
||||
faulthandler.dump_traceback_later(170, repeat=True)
|
||||
yield
|
||||
faulthandler.cancel_dump_traceback_later()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def global_test_cleanup() -> Generator:
|
||||
"""Global cleanup fixture that runs after all tests.
|
||||
|
||||
This ensures all ZMQ connections and child processes are properly cleaned up,
|
||||
preventing the test runner from hanging on exit.
|
||||
"""
|
||||
yield
|
||||
|
||||
# Cleanup after all tests
|
||||
print("\n🧹 Running global test cleanup...")
|
||||
|
||||
# 1. Force cleanup of any LeannSearcher instances
|
||||
try:
|
||||
import gc
|
||||
|
||||
# Force garbage collection to trigger __del__ methods
|
||||
gc.collect()
|
||||
time.sleep(0.2)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 2. Set ZMQ linger but DON'T term Context.instance()
|
||||
# Terminating the global instance can block if other code still has sockets
|
||||
try:
|
||||
import zmq
|
||||
|
||||
# Just set linger on the global instance, don't terminate it
|
||||
ctx = zmq.Context.instance()
|
||||
ctx.linger = 0
|
||||
# Do NOT call ctx.term() or ctx.destroy() on the global instance!
|
||||
# That would block waiting for all sockets to close
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Kill any leftover child processes (including grandchildren)
|
||||
try:
|
||||
import psutil
|
||||
|
||||
current_process = psutil.Process()
|
||||
# Get ALL descendants recursively
|
||||
children = current_process.children(recursive=True)
|
||||
|
||||
if children:
|
||||
print(f"\n⚠️ Cleaning up {len(children)} leftover child processes...")
|
||||
|
||||
# First try to terminate gracefully
|
||||
for child in children:
|
||||
try:
|
||||
print(f" Terminating {child.pid} ({child.name()})")
|
||||
child.terminate()
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied):
|
||||
pass
|
||||
|
||||
# Wait a bit for processes to terminate
|
||||
gone, alive = psutil.wait_procs(children, timeout=2)
|
||||
|
||||
# Force kill any remaining processes
|
||||
for child in alive:
|
||||
try:
|
||||
print(f" Force killing process {child.pid} ({child.name()})")
|
||||
child.kill()
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied):
|
||||
pass
|
||||
|
||||
# Final wait to ensure cleanup
|
||||
psutil.wait_procs(alive, timeout=1)
|
||||
except ImportError:
|
||||
# psutil not installed, try basic process cleanup
|
||||
try:
|
||||
# Send SIGTERM to all child processes
|
||||
os.killpg(os.getpgid(os.getpid()), signal.SIGTERM)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"Warning: Error during process cleanup: {e}")
|
||||
|
||||
# List and clean up remaining threads
|
||||
try:
|
||||
import threading
|
||||
|
||||
threads = [t for t in threading.enumerate() if t is not threading.main_thread()]
|
||||
if threads:
|
||||
print(f"\n⚠️ {len(threads)} non-main threads still running:")
|
||||
for t in threads:
|
||||
print(f" - {t.name} (daemon={t.daemon})")
|
||||
|
||||
# Force cleanup of pytest-timeout threads that block exit
|
||||
if "pytest_timeout" in t.name and not t.daemon:
|
||||
print(f" 🔧 Converting pytest-timeout thread to daemon: {t.name}")
|
||||
try:
|
||||
t.daemon = True
|
||||
print(" ✓ Converted to daemon thread")
|
||||
except Exception as e:
|
||||
print(f" ✗ Failed: {e}")
|
||||
|
||||
# Check if only daemon threads remain
|
||||
non_daemon = [
|
||||
t for t in threading.enumerate() if t is not threading.main_thread() and not t.daemon
|
||||
]
|
||||
if non_daemon:
|
||||
print(f"\n⚠️ {len(non_daemon)} non-daemon threads still blocking exit")
|
||||
# Force exit in CI to prevent hanging
|
||||
if os.environ.get("CI") == "true":
|
||||
print("🔨 Forcing exit in CI environment...")
|
||||
os._exit(0)
|
||||
except Exception as e:
|
||||
print(f"Thread cleanup error: {e}")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auto_cleanup_searcher():
|
||||
"""Fixture that automatically cleans up LeannSearcher instances."""
|
||||
searchers = []
|
||||
|
||||
def register(searcher):
|
||||
"""Register a searcher for cleanup."""
|
||||
searchers.append(searcher)
|
||||
return searcher
|
||||
|
||||
yield register
|
||||
|
||||
# Cleanup all registered searchers
|
||||
for searcher in searchers:
|
||||
try:
|
||||
searcher.cleanup()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Force garbage collection
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def _reap_children():
|
||||
"""Reap all child processes at session end as a safety net."""
|
||||
yield
|
||||
|
||||
# Final aggressive cleanup
|
||||
try:
|
||||
import psutil
|
||||
|
||||
me = psutil.Process()
|
||||
kids = me.children(recursive=True)
|
||||
for p in kids:
|
||||
try:
|
||||
p.terminate()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
_, alive = psutil.wait_procs(kids, timeout=2)
|
||||
for p in alive:
|
||||
try:
|
||||
p.kill()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_after_each_test():
|
||||
"""Cleanup after each test to prevent resource leaks."""
|
||||
yield
|
||||
|
||||
# Force garbage collection to trigger any __del__ methods
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
|
||||
# Give a moment for async cleanup
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
"""Configure pytest with better timeout handling."""
|
||||
# Set default timeout method to thread if not specified
|
||||
if not config.getoption("--timeout-method", None):
|
||||
config.option.timeout_method = "thread"
|
||||
|
||||
# Add more logging
|
||||
print(f"🔧 Pytest configured at {time.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
print(f" Python version: {os.sys.version}")
|
||||
print(f" Platform: {os.sys.platform}")
|
||||
|
||||
|
||||
def pytest_sessionstart(session):
|
||||
"""Called after the Session object has been created."""
|
||||
print(f"🏁 Pytest session starting at {time.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
print(f" Session ID: {id(session)}")
|
||||
|
||||
# Show initial process state
|
||||
try:
|
||||
import psutil
|
||||
|
||||
current = psutil.Process()
|
||||
print(f" Current PID: {current.pid}")
|
||||
print(f" Parent PID: {current.ppid()}")
|
||||
children = current.children(recursive=True)
|
||||
if children:
|
||||
print(f" ⚠️ Already have {len(children)} child processes at start!")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def pytest_sessionfinish(session, exitstatus):
|
||||
"""Called after whole test run finished."""
|
||||
print(f"🏁 Pytest session finishing at {time.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
print(f" Exit status: {exitstatus}")
|
||||
|
||||
# Aggressive cleanup before pytest exits
|
||||
print("🧹 Starting aggressive cleanup...")
|
||||
|
||||
# First, clean up child processes
|
||||
try:
|
||||
import psutil
|
||||
|
||||
current = psutil.Process()
|
||||
children = current.children(recursive=True)
|
||||
|
||||
if children:
|
||||
print(f" Found {len(children)} child processes to clean up:")
|
||||
for child in children:
|
||||
try:
|
||||
print(f" - PID {child.pid}: {child.name()} (status: {child.status()})")
|
||||
child.terminate()
|
||||
except Exception as e:
|
||||
print(f" - Failed to terminate {child.pid}: {e}")
|
||||
|
||||
# Wait briefly then kill
|
||||
time.sleep(0.5)
|
||||
_, alive = psutil.wait_procs(children, timeout=1)
|
||||
|
||||
for child in alive:
|
||||
try:
|
||||
print(f" - Force killing {child.pid}")
|
||||
child.kill()
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
print(" No child processes found")
|
||||
|
||||
except Exception as e:
|
||||
print(f" Process cleanup error: {e}")
|
||||
|
||||
# Second, clean up problematic threads
|
||||
try:
|
||||
import threading
|
||||
|
||||
threads = [t for t in threading.enumerate() if t is not threading.main_thread()]
|
||||
if threads:
|
||||
print(f" Found {len(threads)} non-main threads:")
|
||||
for t in threads:
|
||||
print(f" - {t.name} (daemon={t.daemon})")
|
||||
# Convert pytest-timeout threads to daemon so they don't block exit
|
||||
if "pytest_timeout" in t.name and not t.daemon:
|
||||
try:
|
||||
t.daemon = True
|
||||
print(" ✓ Converted to daemon")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Force exit if non-daemon threads remain in CI
|
||||
non_daemon = [
|
||||
t for t in threading.enumerate() if t is not threading.main_thread() and not t.daemon
|
||||
]
|
||||
if non_daemon and os.environ.get("CI") == "true":
|
||||
print(f" ⚠️ {len(non_daemon)} non-daemon threads remain, forcing exit...")
|
||||
os._exit(exitstatus or 0)
|
||||
|
||||
except Exception as e:
|
||||
print(f" Thread cleanup error: {e}")
|
||||
|
||||
print(f"✅ Pytest exiting at {time.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
95
tests/test_basic.py
Normal file
95
tests/test_basic.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""
|
||||
Basic functionality tests for CI pipeline using pytest.
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from test_timeout import ci_timeout
|
||||
|
||||
|
||||
def test_imports():
|
||||
"""Test that all packages can be imported."""
|
||||
|
||||
# Test C++ extensions
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true", reason="Skip model tests in CI to avoid MPS memory issues"
|
||||
)
|
||||
@pytest.mark.parametrize("backend_name", ["hnsw", "diskann"])
|
||||
@ci_timeout(120) # 2 minute timeout for backend tests
|
||||
def test_backend_basic(backend_name):
|
||||
"""Test basic functionality for each backend."""
|
||||
from leann.api import LeannBuilder, LeannSearcher, SearchResult
|
||||
|
||||
# Create temporary directory for index
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
index_path = str(Path(temp_dir) / f"test.{backend_name}")
|
||||
|
||||
# Test with small data
|
||||
texts = [f"This is document {i} about topic {i % 5}" for i in range(100)]
|
||||
|
||||
# Configure builder based on backend
|
||||
if backend_name == "hnsw":
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="facebook/contriever",
|
||||
embedding_mode="sentence-transformers",
|
||||
M=16,
|
||||
efConstruction=200,
|
||||
)
|
||||
else: # diskann
|
||||
builder = LeannBuilder(
|
||||
backend_name="diskann",
|
||||
embedding_model="facebook/contriever",
|
||||
embedding_mode="sentence-transformers",
|
||||
num_neighbors=32,
|
||||
search_list_size=50,
|
||||
)
|
||||
|
||||
# Add texts
|
||||
for text in texts:
|
||||
builder.add_text(text)
|
||||
|
||||
# Build index
|
||||
builder.build_index(index_path)
|
||||
|
||||
# Test search
|
||||
searcher = LeannSearcher(index_path)
|
||||
results = searcher.search("document about topic 2", top_k=5)
|
||||
|
||||
# Verify results
|
||||
assert len(results) > 0
|
||||
assert isinstance(results[0], SearchResult)
|
||||
assert "topic 2" in results[0].text or "document" in results[0].text
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true", reason="Skip model tests in CI to avoid MPS memory issues"
|
||||
)
|
||||
@ci_timeout(180) # 3 minute timeout for large index test
|
||||
def test_large_index():
|
||||
"""Test with larger dataset."""
|
||||
from leann.api import LeannBuilder, LeannSearcher
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
index_path = str(Path(temp_dir) / "test_large.hnsw")
|
||||
texts = [f"Document {i}: {' '.join([f'word{j}' for j in range(50)])}" for i in range(1000)]
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="facebook/contriever",
|
||||
embedding_mode="sentence-transformers",
|
||||
)
|
||||
|
||||
for text in texts:
|
||||
builder.add_text(text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
|
||||
searcher = LeannSearcher(index_path)
|
||||
results = searcher.search(["word10 word20"], top_k=10)
|
||||
assert len(results[0]) == 10
|
||||
49
tests/test_ci_minimal.py
Normal file
49
tests/test_ci_minimal.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""
|
||||
Minimal tests for CI that don't require model loading or significant memory.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
def test_package_imports():
|
||||
"""Test that all core packages can be imported."""
|
||||
# Core package
|
||||
|
||||
# Backend packages
|
||||
|
||||
# Core modules
|
||||
|
||||
assert True # If we get here, imports worked
|
||||
|
||||
|
||||
def test_cli_help():
|
||||
"""Test that CLI example shows help."""
|
||||
result = subprocess.run(
|
||||
[sys.executable, "apps/document_rag.py", "--help"], capture_output=True, text=True
|
||||
)
|
||||
|
||||
assert result.returncode == 0
|
||||
assert "usage:" in result.stdout.lower() or "usage:" in result.stderr.lower()
|
||||
assert "--llm" in result.stdout or "--llm" in result.stderr
|
||||
|
||||
|
||||
def test_backend_registration():
|
||||
"""Test that backends are properly registered."""
|
||||
from leann.api import get_registered_backends
|
||||
|
||||
backends = get_registered_backends()
|
||||
assert "hnsw" in backends
|
||||
assert "diskann" in backends
|
||||
|
||||
|
||||
def test_version_info():
|
||||
"""Test that packages have version information."""
|
||||
import leann
|
||||
import leann_backend_diskann
|
||||
import leann_backend_hnsw
|
||||
|
||||
# Check that packages have __version__ or can be imported
|
||||
assert hasattr(leann, "__version__") or True
|
||||
assert hasattr(leann_backend_hnsw, "__version__") or True
|
||||
assert hasattr(leann_backend_diskann, "__version__") or True
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user