Compare commits


1 commit

Author: Andy Lee
SHA1: b92ec04178
Message: refactor: move to apps
Date: 2025-07-22 22:18:17 -07:00
212 changed files with 18531 additions and 34728 deletions


@@ -1,20 +0,0 @@
name: CI
on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]
  workflow_dispatch:
    inputs:
      debug_enabled:
        type: boolean
        description: 'Run with tmate debugging enabled (SSH access to runner)'
        required: false
        default: false
jobs:
  build:
    uses: ./.github/workflows/build-reusable.yml
    with:
      debug_enabled: ${{ github.event_name == 'workflow_dispatch' && inputs.debug_enabled || false }}
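For reference, a hedged sketch of how a maintainer might dispatch this workflow manually with the debugging input enabled, assuming the GitHub CLI (`gh`) is installed and authenticated against this repository and that the workflow is addressed by its `name: CI`:

```bash
# Manually dispatch the CI workflow with tmate debugging enabled (hypothetical invocation;
# gh resolves the workflow by name or file, and the boolean input is passed as a string).
gh workflow run CI -f debug_enabled=true

# Inspect the most recent run of the dispatched workflow.
gh run list --workflow=CI --limit 1
```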


@@ -1,460 +0,0 @@
name: Reusable Build
on:
workflow_call:
inputs:
ref:
description: 'Git ref to build'
required: false
type: string
default: ''
debug_enabled:
description: 'Enable tmate debugging session for troubleshooting'
required: false
type: boolean
default: false
jobs:
lint:
name: Lint and Format Check
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.ref }}
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
- name: Install uv
uses: astral-sh/setup-uv@v4
- name: Install ruff
run: |
uv tool install ruff==0.12.7
- name: Run ruff check
run: |
ruff check .
- name: Run ruff format check
run: |
ruff format --check .
build:
needs: lint
name: Build ${{ matrix.os }} Python ${{ matrix.python }}
strategy:
matrix:
include:
- os: ubuntu-22.04
python: '3.9'
- os: ubuntu-22.04
python: '3.10'
- os: ubuntu-22.04
python: '3.11'
- os: ubuntu-22.04
python: '3.12'
- os: ubuntu-22.04
python: '3.13'
- os: macos-latest
python: '3.9'
- os: macos-latest
python: '3.10'
- os: macos-latest
python: '3.11'
- os: macos-latest
python: '3.12'
- os: macos-latest
python: '3.13'
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.ref }}
submodules: recursive
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python }}
- name: Install uv
uses: astral-sh/setup-uv@v4
- name: Install system dependencies (Ubuntu)
if: runner.os == 'Linux'
run: |
sudo apt-get update
sudo apt-get install -y libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
pkg-config libopenblas-dev patchelf libabsl-dev libaio-dev libprotobuf-dev
# Install Intel MKL for DiskANN
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
source /opt/intel/oneapi/setvars.sh
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/mkl/latest/lib/intel64:$LD_LIBRARY_PATH" >> $GITHUB_ENV
- name: Install system dependencies (macOS)
if: runner.os == 'macOS'
run: |
# Don't install LLVM, use system clang for better compatibility
brew install libomp boost protobuf zeromq
- name: Install build dependencies
run: |
uv pip install --system scikit-build-core numpy swig Cython pybind11
if [[ "$RUNNER_OS" == "Linux" ]]; then
uv pip install --system auditwheel
else
uv pip install --system delocate
fi
- name: Build packages
run: |
# Build core (platform independent) on all platforms for consistency
cd packages/leann-core
uv build
cd ../..
# Build HNSW backend
cd packages/leann-backend-hnsw
if [ "${{ matrix.os }}" == "macos-latest" ]; then
# Use system clang instead of homebrew LLVM for better compatibility
export CC=clang
export CXX=clang++
export MACOSX_DEPLOYMENT_TARGET=11.0
uv build --wheel --python python
else
uv build --wheel --python python
fi
cd ../..
# Build DiskANN backend
cd packages/leann-backend-diskann
if [ "${{ matrix.os }}" == "macos-latest" ]; then
# Use system clang instead of homebrew LLVM for better compatibility
export CC=clang
export CXX=clang++
# sgesdd_ is only available on macOS 13.3+
export MACOSX_DEPLOYMENT_TARGET=13.3
uv build --wheel --python python
else
uv build --wheel --python python
fi
cd ../..
# Build meta package (platform independent) on all platforms
cd packages/leann
uv build
cd ../..
- name: Repair wheels (Linux)
if: runner.os == 'Linux'
run: |
# Repair HNSW wheel
cd packages/leann-backend-hnsw
if [ -d dist ]; then
auditwheel repair dist/*.whl -w dist_repaired
rm -rf dist
mv dist_repaired dist
fi
cd ../..
# Repair DiskANN wheel - use show first to debug
cd packages/leann-backend-diskann
if [ -d dist ]; then
echo "Checking DiskANN wheel contents before repair:"
unzip -l dist/*.whl | grep -E "\.so|\.pyd|_diskannpy" || echo "No .so files found"
auditwheel show dist/*.whl || echo "auditwheel show failed"
auditwheel repair dist/*.whl -w dist_repaired
echo "Checking DiskANN wheel contents after repair:"
unzip -l dist_repaired/*.whl | grep -E "\.so|\.pyd|_diskannpy" || echo "No .so files found after repair"
rm -rf dist
mv dist_repaired dist
fi
cd ../..
- name: Repair wheels (macOS)
if: runner.os == 'macOS'
run: |
# Repair HNSW wheel
cd packages/leann-backend-hnsw
if [ -d dist ]; then
delocate-wheel -w dist_repaired -v dist/*.whl
rm -rf dist
mv dist_repaired dist
fi
cd ../..
# Repair DiskANN wheel
cd packages/leann-backend-diskann
if [ -d dist ]; then
delocate-wheel -w dist_repaired -v dist/*.whl
rm -rf dist
mv dist_repaired dist
fi
cd ../..
- name: List built packages
run: |
echo "📦 Built packages:"
find packages/*/dist -name "*.whl" -o -name "*.tar.gz" | sort
- name: Install built packages for testing
run: |
# Create a virtual environment with the correct Python version
uv venv --python python${{ matrix.python }}
source .venv/bin/activate || source .venv/Scripts/activate
# Install the built wheels directly to ensure we use locally built packages
# Use only locally built wheels on all platforms for full consistency
FIND_LINKS="--find-links packages/leann-core/dist --find-links packages/leann/dist"
FIND_LINKS="$FIND_LINKS --find-links packages/leann-backend-hnsw/dist --find-links packages/leann-backend-diskann/dist"
uv pip install leann-core leann leann-backend-hnsw leann-backend-diskann \
$FIND_LINKS --force-reinstall
# Install test dependencies using extras
uv pip install -e ".[test]"
# Debug: Check if _diskannpy module is installed correctly
echo "Checking installed DiskANN module structure:"
python -c "import leann_backend_diskann; print('leann_backend_diskann location:', leann_backend_diskann.__file__)" || echo "Failed to import leann_backend_diskann"
python -c "from leann_backend_diskann import _diskannpy; print('_diskannpy imported successfully')" || echo "Failed to import _diskannpy"
ls -la $(python -c "import leann_backend_diskann; import os; print(os.path.dirname(leann_backend_diskann.__file__))" 2>/dev/null) 2>/dev/null || echo "Failed to list module directory"
# Extra debugging for Python 3.13
if [[ "${{ matrix.python }}" == "3.13" ]]; then
echo "=== Python 3.13 Debug Info ==="
echo "Python version details:"
python --version
python -c "import sys; print(f'sys.version_info: {sys.version_info}')"
echo "Pytest version:"
python -m pytest --version
echo "Testing basic pytest collection:"
if [[ "$RUNNER_OS" == "Linux" ]]; then
timeout --signal=INT 10 python -m pytest --collect-only tests/test_ci_minimal.py -v || echo "Collection timed out or failed"
else
# No timeout on macOS/Windows
python -m pytest --collect-only tests/test_ci_minimal.py -v || echo "Collection failed"
fi
echo "Testing single simple test:"
if [[ "$RUNNER_OS" == "Linux" ]]; then
timeout --signal=INT 10 python -m pytest tests/test_ci_minimal.py::test_package_imports --full-trace -v || echo "Simple test timed out or failed"
else
# No timeout on macOS/Windows
python -m pytest tests/test_ci_minimal.py::test_package_imports --full-trace -v || echo "Simple test failed"
fi
fi
# Enable tmate debugging session if requested
- name: Setup tmate session for debugging
if: ${{ inputs.debug_enabled }}
uses: mxschmitt/action-tmate@v3
with:
detached: true
timeout-minutes: 30
limit-access-to-actor: true
- name: Run tests with pytest
# Timeout hierarchy:
# 1. Individual test timeout: 20s (see pyproject.toml markers)
# 2. Pytest session timeout: 300s (see pyproject.toml [tool.pytest.ini_options])
# 3. Outer shell timeout: 360s (300s + 60s buffer for cleanup)
# 4. GitHub Actions job timeout: 6 hours (default)
env:
CI: true # Mark as CI environment to skip memory-intensive tests
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
HF_HUB_DISABLE_SYMLINKS: 1
TOKENIZERS_PARALLELISM: false
PYTORCH_ENABLE_MPS_FALLBACK: 0 # Disable MPS on macOS CI to avoid memory issues
OMP_NUM_THREADS: 1 # Disable OpenMP parallelism to avoid libomp crashes
MKL_NUM_THREADS: 1 # Single thread for MKL operations
run: |
# Activate virtual environment
source .venv/bin/activate || source .venv/Scripts/activate
# Define comprehensive diagnostic function
diag() {
echo "===== COMPREHENSIVE DIAGNOSTICS BEGIN ====="
date
echo ""
echo "### Current Shell Info ###"
echo "Shell PID: $$"
echo "Shell PPID: $PPID"
echo "Current directory: $(pwd)"
echo ""
echo "### Process Tree (full) ###"
pstree -ap 2>/dev/null || ps auxf || true
echo ""
echo "### All Python/Pytest Processes ###"
ps -ef | grep -E 'python|pytest' | grep -v grep || true
echo ""
echo "### Embedding Server Processes ###"
ps -ef | grep -E 'embedding|zmq|diskann' | grep -v grep || true
echo ""
echo "### Network Listeners ###"
ss -ltnp 2>/dev/null || netstat -ltn 2>/dev/null || true
echo ""
echo "### Open File Descriptors (lsof) ###"
lsof -p $$ 2>/dev/null | head -20 || true
echo ""
echo "### Zombie Processes ###"
ps aux | grep '<defunct>' || echo "No zombie processes"
echo ""
echo "### Current Jobs ###"
jobs -l || true
echo ""
echo "### /proc/PID/fd for current shell ###"
ls -la /proc/$$/fd 2>/dev/null || true
echo ""
echo "===== COMPREHENSIVE DIAGNOSTICS END ====="
}
# Enable verbose logging for debugging
export PYTHONUNBUFFERED=1
export PYTEST_CURRENT_TEST=1
# Run all tests with extensive logging
if [[ "$RUNNER_OS" == "Linux" ]]; then
echo "🚀 Starting Linux test execution with timeout..."
echo "Current time: $(date)"
echo "Shell PID: $$"
echo "Python: $(python --version)"
echo "Pytest: $(pytest --version)"
# Show environment variables for debugging
echo "📦 Environment variables:"
env | grep -E "PYTHON|PYTEST|CI|RUNNER" | sort
# Set trap for diagnostics
trap diag INT TERM EXIT
echo "📋 Pre-test diagnostics:"
ps -ef | grep -E 'python|pytest' | grep -v grep || echo "No python/pytest processes before test"
# Check for any listening ports before test
echo "🔌 Pre-test network state:"
ss -ltn 2>/dev/null | grep -E "555[0-9]|556[0-9]" || echo "No embedding server ports open"
# Set timeouts - outer must be larger than pytest's internal timeout
# IMPORTANT: Keep PYTEST_TIMEOUT_SEC in sync with pyproject.toml [tool.pytest.ini_options] timeout
PYTEST_TIMEOUT_SEC=${PYTEST_TIMEOUT_SEC:-300} # Default 300s, matches pyproject.toml
BUFFER_SEC=${TIMEOUT_BUFFER_SEC:-60} # Buffer for cleanup after pytest timeout
OUTER_TIMEOUT_SEC=${OUTER_TIMEOUT_SEC:-$((PYTEST_TIMEOUT_SEC + BUFFER_SEC))}
echo "⏰ Timeout configuration:"
echo " - Pytest internal timeout: ${PYTEST_TIMEOUT_SEC}s (from pyproject.toml)"
echo " - Cleanup buffer: ${BUFFER_SEC}s"
echo " - Outer shell timeout: ${OUTER_TIMEOUT_SEC}s (${PYTEST_TIMEOUT_SEC}s + ${BUFFER_SEC}s buffer)"
echo " - This ensures pytest can complete its own timeout handling and cleanup"
echo "🏃 Running pytest with ${OUTER_TIMEOUT_SEC}s outer timeout..."
# Export for inner shell
export PYTEST_TIMEOUT_SEC OUTER_TIMEOUT_SEC BUFFER_SEC
timeout --preserve-status --signal=INT --kill-after=10 ${OUTER_TIMEOUT_SEC} bash -c '
echo "⏱️ Pytest starting at: $(date)"
echo "Running command: pytest tests/ -vv --maxfail=3 --tb=short --capture=no"
# Run pytest with maximum verbosity and no output capture
pytest tests/ -vv --maxfail=3 --tb=short --capture=no --log-cli-level=DEBUG 2>&1 | tee pytest.log
PYTEST_EXIT=${PIPESTATUS[0]}
echo "✅ Pytest finished at: $(date) with exit code: $PYTEST_EXIT"
echo "Last 20 lines of pytest output:"
tail -20 pytest.log || true
# Immediately check for leftover processes
echo "🔍 Post-pytest process check:"
ps -ef | grep -E "python|pytest|embedding" | grep -v grep || echo "No leftover processes"
# Clean up any children before exit
echo "🧹 Cleaning up child processes..."
pkill -TERM -P $$ 2>/dev/null || true
sleep 0.5
pkill -KILL -P $$ 2>/dev/null || true
echo "📊 Final check before exit:"
ps -ef | grep -E "python|pytest|embedding" | grep -v grep || echo "All clean"
exit $PYTEST_EXIT
'
EXIT_CODE=$?
echo "🔚 Timeout command exited with code: $EXIT_CODE"
if [ $EXIT_CODE -eq 124 ]; then
echo "⚠️ TIMEOUT TRIGGERED - Tests took more than ${OUTER_TIMEOUT_SEC} seconds!"
echo "📸 Capturing full diagnostics..."
diag
# Run diagnostic script if available
if [ -f scripts/diagnose_hang.sh ]; then
echo "🔍 Running diagnostic script..."
bash scripts/diagnose_hang.sh || true
fi
# More aggressive cleanup
echo "💀 Killing all Python processes owned by runner..."
pkill -9 -u runner python || true
pkill -9 -u runner pytest || true
elif [ $EXIT_CODE -ne 0 ]; then
echo "❌ Tests failed with exit code: $EXIT_CODE"
else
echo "✅ All tests passed!"
fi
# Always show final state
echo "📍 Final state check:"
ps -ef | grep -E 'python|pytest|embedding' | grep -v grep || echo "No Python processes remaining"
exit $EXIT_CODE
else
# For macOS/Windows, run without GNU timeout
echo "🚀 Running tests on $RUNNER_OS..."
pytest tests/ -vv --maxfail=3 --tb=short --capture=no --log-cli-level=INFO
fi
# Provide tmate session on test failure for debugging
- name: Setup tmate session on failure
if: ${{ failure() && (inputs.debug_enabled || contains(github.event.head_commit.message, '[debug]')) }}
uses: mxschmitt/action-tmate@v3
with:
timeout-minutes: 30
limit-access-to-actor: true
- name: Run sanity checks (optional)
run: |
# Activate virtual environment
source .venv/bin/activate || source .venv/Scripts/activate
# Run distance function tests if available
if [ -f test/sanity_checks/test_distance_functions.py ]; then
echo "Running distance function sanity checks..."
python test/sanity_checks/test_distance_functions.py || echo "⚠️ Distance function test failed, continuing..."
fi
- name: Upload artifacts
uses: actions/upload-artifact@v4
with:
name: packages-${{ matrix.os }}-py${{ matrix.python }}
path: packages/*/dist/
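The test step in this workflow layers an outer GNU `timeout` on top of pytest's own session timeout so pytest can finish its own timeout handling and cleanup before being signalled. A minimal local sketch of the same pattern, assuming the 300s session timeout configured in `pyproject.toml` under `[tool.pytest.ini_options]`:

```bash
# Layered-timeout sketch mirroring the "Run tests with pytest" step above: the outer
# timeout gets the pytest session timeout plus a cleanup buffer. The 300s value is an
# assumption; keep it in sync with the timeout set in pyproject.toml.
PYTEST_TIMEOUT_SEC=300
BUFFER_SEC=60
timeout --preserve-status --signal=INT --kill-after=10 $((PYTEST_TIMEOUT_SEC + BUFFER_SEC)) \
  pytest tests/ -vv --maxfail=3 --tb=short
```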


@@ -1,19 +0,0 @@
name: Link Check
on:
  push:
    branches: [ main, master ]
  pull_request:
  schedule:
    - cron: "0 3 * * 1"
jobs:
  link-check:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: lycheeverse/lychee-action@v2
        with:
          args: --no-progress --insecure README.md docs/ apps/ examples/ benchmarks/
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
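A rough local equivalent of the lychee-action invocation above, useful for checking links before pushing; this is a sketch that assumes `lychee` is installed locally (for example via `brew install lychee` or `cargo install lychee`):

```bash
# Check the same paths the workflow checks; exporting GITHUB_TOKEN beforehand is optional
# but raises the rate limit for github.com links.
lychee --no-progress --insecure README.md docs/ apps/ examples/ benchmarks/
```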


@@ -1,129 +0,0 @@
name: Release
on:
workflow_dispatch:
inputs:
version:
description: 'Version to release (e.g., 0.1.2)'
required: true
type: string
jobs:
update-version:
name: Update Version
runs-on: ubuntu-latest
permissions:
contents: write
outputs:
commit-sha: ${{ steps.push.outputs.commit-sha }}
steps:
- uses: actions/checkout@v4
- name: Validate version
run: |
# Remove 'v' prefix if present for validation
VERSION_CLEAN="${{ inputs.version }}"
VERSION_CLEAN="${VERSION_CLEAN#v}"
if ! [[ "$VERSION_CLEAN" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
echo "❌ Invalid version format. Expected format: X.Y.Z or vX.Y.Z"
exit 1
fi
echo "✅ Version format valid: ${{ inputs.version }}"
- name: Update versions and push
id: push
run: |
# Check current version
CURRENT_VERSION=$(grep "^version" packages/leann-core/pyproject.toml | cut -d'"' -f2)
echo "Current version: $CURRENT_VERSION"
echo "Target version: ${{ inputs.version }}"
if [ "$CURRENT_VERSION" = "${{ inputs.version }}" ]; then
echo "⚠️ Version is already ${{ inputs.version }}, skipping update"
COMMIT_SHA=$(git rev-parse HEAD)
else
./scripts/bump_version.sh ${{ inputs.version }}
git config user.name "GitHub Actions"
git config user.email "actions@github.com"
git add packages/*/pyproject.toml
git commit -m "chore: release v${{ inputs.version }}"
git push origin main
COMMIT_SHA=$(git rev-parse HEAD)
echo "✅ Pushed version update: $COMMIT_SHA"
fi
echo "commit-sha=$COMMIT_SHA" >> $GITHUB_OUTPUT
build-packages:
name: Build packages
needs: update-version
uses: ./.github/workflows/build-reusable.yml
with:
ref: 'main'
publish:
name: Publish and Release
needs: [update-version, build-packages]
if: always() && needs.update-version.result == 'success' && needs.build-packages.result == 'success'
runs-on: ubuntu-latest
permissions:
contents: write
steps:
- uses: actions/checkout@v4
with:
ref: 'main'
- name: Download all artifacts
uses: actions/download-artifact@v4
with:
path: dist-artifacts
- name: Collect packages
run: |
mkdir -p dist
find dist-artifacts -name "*.whl" -exec cp {} dist/ \;
find dist-artifacts -name "*.tar.gz" -exec cp {} dist/ \;
echo "📦 Packages to publish:"
ls -la dist/
- name: Publish to PyPI
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
run: |
if [ -z "$TWINE_PASSWORD" ]; then
echo "❌ PYPI_API_TOKEN not configured!"
exit 1
fi
pip install twine
twine upload dist/* --skip-existing --verbose
echo "✅ Published to PyPI!"
- name: Create release
run: |
# Check if tag already exists
if git rev-parse "v${{ inputs.version }}" >/dev/null 2>&1; then
echo "⚠️ Tag v${{ inputs.version }} already exists, skipping tag creation"
else
git tag "v${{ inputs.version }}"
git push origin "v${{ inputs.version }}"
echo "✅ Created and pushed tag v${{ inputs.version }}"
fi
# Check if release already exists
if gh release view "v${{ inputs.version }}" >/dev/null 2>&1; then
echo "⚠️ Release v${{ inputs.version }} already exists, skipping release creation"
else
gh release create "v${{ inputs.version }}" \
--title "Release v${{ inputs.version }}" \
--notes "🚀 Released to PyPI: https://pypi.org/project/leann/${{ inputs.version }}/" \
--latest
echo "✅ Created GitHub release v${{ inputs.version }}"
fi
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
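A hedged example of kicking off this release pipeline from the command line, assuming an authenticated GitHub CLI with write access to the repository; `0.1.2` is only the placeholder version format used in the workflow input description:

```bash
# Dispatch the Release workflow with a target version (hypothetical invocation).
gh workflow run Release -f version=0.1.2

# Check the dispatched run; it should update versions, build, publish to PyPI,
# and create the tag and GitHub release.
gh run list --workflow=Release --limit 1
```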

.gitignore

@@ -9,9 +9,10 @@ demo/indices/
 outputs/
 *.pkl
 *.pdf
 *.idx
 *.map
 .history/
+scripts/
 lm_eval.egg-info/
 demo/experiment_results/**/*.json
 *.jsonl
@@ -34,15 +35,11 @@ build/
 nprobe_logs/
 micro/results
 micro/contriever-INT8
-data/*
-!data/2501.14312v1 (1).pdf
-!data/2506.08276v1.pdf
-!data/PrideandPrejudice.txt
-!data/huawei_pangu.md
-!data/ground_truth/
-!data/indices/
-!data/queries/
-!data/.gitattributes
+examples/data/*
+!examples/data/2501.14312v1 (1).pdf
+!examples/data/2506.08276v1.pdf
+!examples/data/PrideandPrejudice.txt
+!examples/data/README.md
 *.qdstrm
 benchmark_results/
 results/
@@ -89,6 +86,4 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/
 *.meta.json
 *.passages.json
 batchtest.py
-tests/__pytest_cache__/
-tests/__pycache__/


@@ -1,16 +0,0 @@
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v5.0.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-added-large-files
      - id: check-merge-conflict
      - id: debug-statements
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.12.7  # Fixed version to match pyproject.toml
    hooks:
      - id: ruff
      - id: ruff-format
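Typical local usage of this pre-commit configuration (a sketch, assuming `pre-commit` is already installed, e.g. with `uv tool install pre-commit`):

```bash
# Install the git hook once, then run every configured hook against the whole tree
# (mirrors the ruff checks enforced by the lint job in CI).
pre-commit install
pre-commit run --all-files
```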

README.md

@@ -6,21 +6,17 @@
<img src="https://img.shields.io/badge/Python-3.9%2B-blue.svg" alt="Python 3.9+"> <img src="https://img.shields.io/badge/Python-3.9%2B-blue.svg" alt="Python 3.9+">
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License"> <img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
<img src="https://img.shields.io/badge/Platform-Linux%20%7C%20macOS-lightgrey" alt="Platform"> <img src="https://img.shields.io/badge/Platform-Linux%20%7C%20macOS-lightgrey" alt="Platform">
<img src="https://img.shields.io/badge/MCP-Native%20Integration-blue?style=flat-square" alt="MCP Integration">
</p> </p>
<h2 align="center" tabindex="-1" class="heading-element" dir="auto"> <h2 align="center" tabindex="-1" class="heading-element" dir="auto">
The smallest vector index in the world. RAG Everything with LEANN! The smallest vector index in the world. RAG Everything with LEANN!
</h2> </h2>
LEANN is an innovative vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**. LEANN is a revolutionary vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **[97% less storage]** than traditional solutions **without accuracy loss**.
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig ](#-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276) LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration →](#-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)**, **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy. **Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can search your **[file system](#process-any-documents-pdf-txt-md)**, **[emails](#search-your-entire-life)**, **[browser history](#time-machine-for-the-web)**, **[chat history](#wechat-detective)**, or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
\* Claude Code only supports basic `grep`-style keyword search. **LEANN** is a drop-in **semantic search MCP service fully compatible with Claude Code**, unlocking intelligent retrieval without changing your workflow. 🔥 Check out [the easy setup →](packages/leann-mcp/README.md)
@@ -30,123 +26,57 @@ LEANN achieves this through *graph-based selective recomputation* with *high-deg
<img src="assets/effects.png" alt="LEANN vs Traditional Vector DB Storage Comparison" width="70%"> <img src="assets/effects.png" alt="LEANN vs Traditional Vector DB Storage Comparison" width="70%">
</p> </p>
> **The numbers speak for themselves:** Index 60 million text chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#storage-comparison) > **The numbers speak for themselves:** Index 60 million Wikipedia chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#storage-usage-comparison)
🔒 **Privacy:** Your data never leaves your laptop. No OpenAI, no cloud, no "terms of service". 🔒 **Privacy:** Your data never leaves your laptop. No OpenAI, no cloud, no "terms of service".
🪶 **Lightweight:** Graph-based recomputation eliminates heavy embedding storage, while smart graph pruning and CSR format minimize graph storage overhead. Always less storage, less memory usage! 🪶 **Lightweight:** Graph-based recomputation eliminates heavy embedding storage, while smart graph pruning and CSR format minimize graph storage overhead. Always less storage, less memory usage!
📦 **Portable:** Transfer your entire knowledge base between devices (or even share it with others) at minimal cost - your personal AI memory travels with you.
📈 **Scalability:** Handle messy personal data that would crash traditional vector DBs, easily managing your growing personalized data and agent generated memory! 📈 **Scalability:** Handle messy personal data that would crash traditional vector DBs, easily managing your growing personalized data and agent generated memory!
**No Accuracy Loss:** Maintain the same search quality as heavyweight solutions while using 97% less storage. **No Accuracy Loss:** Maintain the same search quality as heavyweight solutions while using 97% less storage.
## Installation ## Quick Start in 1 minute
### 📦 Prerequisites: Install uv
[Install uv](https://docs.astral.sh/uv/getting-started/installation/#installation-methods) first if you don't have it. Typically, you can install it with:
```bash ```bash
curl -LsSf https://astral.sh/uv/install.sh | sh git clone git@github.com:yichuan-w/LEANN.git leann
```
### 🚀 Quick Install
Clone the repository to access all examples and try amazing applications,
```bash
git clone https://github.com/yichuan-w/LEANN.git leann
cd leann
```
and install LEANN from [PyPI](https://pypi.org/project/leann/) to run them immediately:
```bash
uv venv
source .venv/bin/activate
uv pip install leann
```
<details>
<summary>
<strong>🔧 Build from Source (Recommended for development)</strong>
</summary>
```bash
git clone https://github.com/yichuan-w/LEANN.git leann
cd leann cd leann
git submodule update --init --recursive git submodule update --init --recursive
``` ```
**macOS:** **macOS:**
```bash ```bash
brew install llvm libomp boost protobuf zeromq pkgconf brew install llvm libomp boost protobuf zeromq
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync export CC=$(brew --prefix llvm)/bin/clang
export CXX=$(brew --prefix llvm)/bin/clang++
# Install with HNSW backend (default, recommended for most users)
uv sync
# Or add DiskANN backend if you want to test more options
uv sync --extra diskann
``` ```
**Linux:** **Linux (Ubuntu/Debian):**
```bash ```bash
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
# Install with HNSW backend (default, recommended for most users)
uv sync uv sync
# Or add DiskANN backend if you want to test more options
uv sync --extra diskann
``` ```
</details>
## Quick Start **Ollama Setup (Recommended for full privacy):**
Our declarative API makes RAG as easy as writing a config file. > *You can skip this installation if you only want to use OpenAI API for generation.*
Check out [demo.ipynb](demo.ipynb) or [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/yichuan-w/LEANN/blob/main/demo.ipynb)
```python *macOS:*
from leann import LeannBuilder, LeannSearcher, LeannChat
from pathlib import Path
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
# Build an index
builder = LeannBuilder(backend_name="hnsw")
builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
builder.add_text("Tung Tung Tung Sahur called—they need their bananacrocodile hybrid back")
builder.build_index(INDEX_PATH)
# Search
searcher = LeannSearcher(INDEX_PATH)
results = searcher.search("fantastical AI-generated creatures", top_k=1)
# Chat with your data
chat = LeannChat(INDEX_PATH, llm_config={"type": "hf", "model": "Qwen/Qwen3-0.6B"})
response = chat.ask("How much storage does LEANN save?", top_k=1)
```
## RAG on Everything!
LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`, `.md`), Apple Mail, Google Search History, WeChat, and more.
### Generation Model Setup
LEANN supports multiple LLM providers for text generation (OpenAI API, HuggingFace, Ollama).
<details>
<summary><strong>🔑 OpenAI API Setup (Default)</strong></summary>
Set your OpenAI API key as an environment variable:
```bash
export OPENAI_API_KEY="your-api-key-here"
```
</details>
<details>
<summary><strong>🔧 Ollama Setup (Recommended for full privacy)</strong></summary>
**macOS:**
First, [download Ollama for macOS](https://ollama.com/download/mac). First, [download Ollama for macOS](https://ollama.com/download/mac).
@@ -155,8 +85,7 @@ First, [download Ollama for macOS](https://ollama.com/download/mac).
ollama pull llama3.2:1b ollama pull llama3.2:1b
``` ```
**Linux:** *Linux:*
```bash ```bash
# Install Ollama # Install Ollama
curl -fsSL https://ollama.ai/install.sh | sh curl -fsSL https://ollama.ai/install.sh | sh
@@ -168,120 +97,90 @@ ollama serve &
ollama pull llama3.2:1b ollama pull llama3.2:1b
``` ```
</details> ## Dead Simple API
### ⭐ Flexible Configuration Just 3 lines of code. Our declarative API makes RAG as easy as writing a config file:
LEANN provides flexible parameters for embedding models, search strategies, and data processing to fit your specific needs. ```python
from leann.api import LeannBuilder, LeannSearcher, LeannChat
📚 **Need configuration best practices?** Check our [Configuration Guide](docs/configuration-guide.md) for detailed optimization tips, model selection advice, and solutions to common issues like slow embeddings or poor search quality. # 1. Build the index (no embeddings stored!)
builder = LeannBuilder(backend_name="hnsw")
builder.add_text("C# is a powerful programming language")
builder.add_text("Python is a powerful programming language and it is very popular")
builder.add_text("Machine learning transforms industries")
builder.add_text("Neural networks process complex data")
builder.add_text("Leann is a great storage saving engine for RAG on your MacBook")
builder.build_index("knowledge.leann")
# 2. Search with real-time embeddings
searcher = LeannSearcher("knowledge.leann")
results = searcher.search("programming languages", top_k=2)
# 3. Chat with LEANN using retrieved results
llm_config = {
"type": "ollama",
"model": "llama3.2:1b"
}
chat = LeannChat(index_path="knowledge.leann", llm_config=llm_config)
response = chat.ask(
"Compare the two retrieved programming languages and say which one is more popular today.",
top_k=2,
)
```
**That's it.** No cloud setup, no API keys, no "fine-tuning". Just your data, your questions, your laptop.
[Try the interactive demo →](demo.ipynb)
## Wild Things You Can Do
LEANN supports RAGing a lot of data sources, like .pdf, .txt, .md, and also supports RAGing your WeChat, Google Search History, and more.
### Process Any Documents (.pdf, .txt, .md)
Above we showed the Python API; this CLI script demonstrates the same concepts, directly processing PDFs, documents, and even any directory that stores your personal files!
The following scripts use Ollama `qwen3:8b` by default, so you need `ollama pull qwen3:8b` first. For other models: `--llm openai --llm-model gpt-4o` (requires the `OPENAI_API_KEY` environment variable) or `--llm hf --llm-model Qwen/Qwen3-4B`.
```bash
# Drop your PDFs, .txt, .md files into apps/documents/data/
python -m apps.documents
# Or with uv
uv run python -m apps.documents
```
**Works with any text format** - research papers, personal notes, presentations. Built with LlamaIndex for document parsing.
### Search Your Entire Life
```bash
python -m apps.email
# "What's the number of class recommend to take per semester for incoming EECS students?"
```
**90K emails → 14MB.** Finally, search your email like you search Google.
<details> <details>
<summary><strong>📋 Click to expand: Common Parameters (Available in All Examples)</strong></summary> <summary><strong>📋 Click to expand: Command Examples</strong></summary>
All RAG examples share these common parameters. **Interactive mode** is available in all examples - simply run without `--query` to start a continuous Q&A session where you can ask multiple questions. Type 'quit' to exit.
```bash ```bash
# Core Parameters (General preprocessing for all examples) # Use default mail path (works for most macOS setups)
--index-dir DIR # Directory to store the index (default: current directory) python -m apps.email
--query "YOUR QUESTION" # Single query mode. Omit for interactive chat (type 'quit' to exit), and now you can play with your index interactively
--max-items N # Limit data preprocessing (default: -1, process all data)
--force-rebuild # Force rebuild index even if it exists
# Embedding Parameters # Run with custom index directory
--embedding-model MODEL # e.g., facebook/contriever, text-embedding-3-small, nomic-embed-text, mlx-community/Qwen3-Embedding-0.6B-8bit or nomic-embed-text python -m apps.email --index-dir "./my_mail_index"
--embedding-mode MODE # sentence-transformers, openai, mlx, or ollama
# LLM Parameters (Text generation models) # Process all emails (may take time but indexes everything)
--llm TYPE # LLM backend: openai, ollama, or hf (default: openai) python -m apps.email --max-emails -1
--llm-model MODEL # Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct
--thinking-budget LEVEL # Thinking budget for reasoning models: low/medium/high (supported by o3, o3-mini, GPT-Oss:20b, and other reasoning models)
# Search Parameters # Limit number of emails processed (useful for testing)
--top-k N # Number of results to retrieve (default: 20) python -m apps.email --max-emails 1000
--search-complexity N # Search complexity for graph traversal (default: 32)
# Chunking Parameters # Run a single query
--chunk-size N # Size of text chunks (default varies by source: 256 for most, 192 for WeChat) python -m apps.email --query "What did my boss say about deadlines?"
--chunk-overlap N # Overlap between chunks (default varies: 25-128 depending on source)
# Index Building Parameters
--backend-name NAME # Backend to use: hnsw or diskann (default: hnsw)
--graph-degree N # Graph degree for index construction (default: 32)
--build-complexity N # Build complexity for index construction (default: 64)
--no-compact # Disable compact index storage (compact storage IS enabled to save storage by default)
--no-recompute # Disable embedding recomputation (recomputation IS enabled to save storage by default)
```
</details>
### 📄 Personal Data Manager: Process Any Documents (`.pdf`, `.txt`, `.md`)!
Ask questions directly about your personal PDFs, documents, and any directory containing your files!
<p align="center">
<img src="videos/paper_clear.gif" alt="LEANN Document Search Demo" width="600">
</p>
The example below asks a question that summarizes our paper (it uses the default data in `data/`, a directory with diverse sources: two papers, Pride and Prejudice, and a technical report, in Chinese, about LLMs at Huawei), and it is the **easiest example** to run here:
```bash
source .venv/bin/activate # Don't forget to activate the virtual environment
python -m apps.document_rag --query "What are the main techniques LEANN explores?"
```
<details>
<summary><strong>📋 Click to expand: Document-Specific Arguments</strong></summary>
#### Parameters
```bash
--data-dir DIR # Directory containing documents to process (default: data)
--file-types .ext .ext # Filter by specific file types (optional - all LlamaIndex supported types if omitted)
```
#### Example Commands
```bash
# Process all documents with larger chunks for academic papers
python -m apps.document_rag --data-dir "~/Documents/Papers" --chunk-size 1024
# Filter only markdown and Python files with smaller chunks
python -m apps.document_rag --data-dir "./docs" --chunk-size 256 --file-types .md .py
```
</details>
### 📧 Your Personal Email Secretary: RAG on Apple Mail!
> **Note:** The examples below currently support macOS only. Windows support coming soon.
<p align="center">
<img src="videos/mail_clear.gif" alt="LEANN Email Search Demo" width="600">
</p>
Before running the example below, you need to grant full disk access to your terminal/VS Code in System Preferences → Privacy & Security → Full Disk Access.
```bash
python -m apps.email_rag --query "What's the food I ordered by DoorDash or Uber Eats mostly?"
```
**780K email chunks → 78MB storage.** Finally, search your email like you search Google.
<details>
<summary><strong>📋 Click to expand: Email-Specific Arguments</strong></summary>
#### Parameters
```bash
--mail-path PATH # Path to specific mail directory (auto-detects if omitted)
--include-html # Include HTML content in processing (useful for newsletters)
```
#### Example Commands
```bash
# Search work emails from a specific account
python -m apps.email_rag --mail-path "~/Library/Mail/V10/WORK_ACCOUNT"
# Find all receipts and order confirmations (includes HTML)
python -m apps.email_rag --query "receipt order confirmation invoice" --include-html
``` ```
</details> </details>
@@ -295,32 +194,28 @@ Once the index is built, you can ask questions like:
- "Show me emails about travel expenses" - "Show me emails about travel expenses"
</details> </details>
### 🔍 Time Machine for the Web: RAG Your Entire Chrome Browser History! ### Time Machine for the Web
<p align="center">
<img src="videos/google_clear.gif" alt="LEANN Browser History Search Demo" width="600">
</p>
```bash ```bash
python -m apps.browser_rag --query "Tell me my browser history about machine learning?" python -m apps.browser
# "Tell me my browser history about machine learning system stuff?"
``` ```
**38K browser entries → 6MB storage.** Your browser history becomes your personal search engine. **38K browser entries → 6MB.** Your browser history becomes your personal search engine.
<details> <details>
<summary><strong>📋 Click to expand: Browser-Specific Arguments</strong></summary> <summary><strong>📋 Click to expand: Command Examples</strong></summary>
#### Parameters
```bash ```bash
--chrome-profile PATH # Path to Chrome profile directory (auto-detects if omitted) # Use default Chrome profile (auto-finds all profiles)
``` python -m apps.browser
#### Example Commands # Run with custom index directory
```bash python -m apps.browser --index-dir "./my_chrome_index"
# Search academic research from your browsing history
python -m apps.browser_rag --query "arxiv papers machine learning transformer architecture"
# Track competitor analysis across work profile # Limit number of history entries processed (useful for testing)
python -m apps.browser_rag --chrome-profile "~/Library/Application Support/Google/Chrome/Work Profile" --max-items 5000 python -m apps.browser --max-entries 500
# Run a single query
python -m apps.browser --query "What websites did I visit about machine learning?"
``` ```
</details> </details>
@@ -353,58 +248,44 @@ Once the index is built, you can ask questions like:
</details> </details>
### 💬 WeChat Detective: Unlock Your Golden Memories! ### WeChat Detective
<p align="center">
<img src="videos/wechat_clear.gif" alt="LEANN WeChat Search Demo" width="600">
</p>
```bash ```bash
python -m apps.wechat_rag --query "Show me all group chats about weekend plans" python -m apps.wechat
# "Show me all group chats about weekend plans"
``` ```
**400K messages → 64MB storage.** Search years of chat history in any language. **400K messages → 64MB.** Search years of chat history in any language.
<details> <details>
<summary><strong>🔧 Click to expand: Installation Requirements</strong></summary> <summary><strong>🔧 Click to expand: Installation Requirements</strong></summary>
First, you need to install the [WeChat exporter](https://github.com/sunnyyoung/WeChatTweak-CLI), First, you need to install the WeChat exporter:
```bash
brew install sunnyyoung/repo/wechattweak-cli
```
or install it manually (if you have issues with Homebrew):
```bash ```bash
sudo packages/wechat-exporter/wechattweak-cli install sudo packages/wechat-exporter/wechattweak-cli install
``` ```
**Troubleshooting:** **Troubleshooting**: If you encounter installation issues, check the [WeChatTweak-CLI issues page](https://github.com/sunnyyoung/WeChatTweak-CLI/issues/41).
- **Installation issues**: Check the [WeChatTweak-CLI issues page](https://github.com/sunnyyoung/WeChatTweak-CLI/issues/41)
- **Export errors**: If you encounter the error below, try restarting WeChat
```bash
Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.
Failed to find or export WeChat data. Exiting.
```
</details> </details>
<details> <details>
<summary><strong>📋 Click to expand: WeChat-Specific Arguments</strong></summary> <summary><strong>📋 Click to expand: Command Examples</strong></summary>
#### Parameters
```bash ```bash
--export-dir DIR # Directory to store exported WeChat data (default: wechat_export_direct) # Use default settings (recommended for first run)
--force-export # Force re-export even if data exists python -m apps.wechat
```
#### Example Commands # Run with a custom export directory; on the first run, LEANN will export all chat history automatically for you
```bash python -m apps.wechat --export-dir "./my_wechat_exports"
# Search for travel plans discussed in group chats
python -m apps.wechat_rag --query "travel plans" --max-items 10000
# Re-export and search recent chats (useful after new messages) # Run with custom index directory
python -m apps.wechat_rag --force-export --query "work schedule" python -m apps.wechat --index-dir "./my_wechat_index"
# Limit number of chat entries processed (useful for testing)
python -m apps.wechat --max-entries 1000
# Run a single query
python -m apps.wechat --query "Show me conversations about travel plans"
``` ```
</details> </details>
@@ -418,59 +299,17 @@ Once the index is built, you can ask questions like:
</details> </details>
### 🚀 Claude Code Integration: Transform Your Development Workflow!
**The future of code assistance is here.** Transform your development workflow with LEANN's native MCP integration for Claude Code. Index your entire codebase and get intelligent code assistance directly in your IDE.
**Key features:**
- 🔍 **Semantic code search** across your entire project
- 📚 **Context-aware assistance** for debugging and development
- 🚀 **Zero-config setup** with automatic language detection
```bash
# Install LEANN globally for MCP integration
uv tool install leann-core
# Setup is automatic - just start using Claude Code!
```
Try our fully agentic pipeline with auto query rewriting, semantic search planning, and more:
![LEANN MCP Integration](assets/mcp_leann.png)
**Ready to supercharge your coding?** [Complete Setup Guide →](packages/leann-mcp/README.md)
## 🖥️ Command Line Interface ## 🖥️ Command Line Interface
LEANN includes a powerful CLI for document processing and search. Perfect for quick document indexing and interactive chat. LEANN includes a powerful CLI for document processing and search. Perfect for quick document indexing and interactive chat.
### Installation
If you followed the Quick Start, `leann` is already installed in your virtual environment:
```bash ```bash
source .venv/bin/activate # Build an index from documents
leann --help leann build my-docs --docs ./documents
```
**To make it globally available:** # Search your documents
```bash
# Install the LEANN CLI globally using uv tool
uv tool install leann-core
# Now you can use leann from anywhere without activating venv
leann --help
```
> **Note**: Global installation is required for Claude Code integration. The `leann_mcp` server depends on the globally available `leann` command.
### Usage Examples
```bash
# build from a specific directory, and my-docs is the index name
leann build my-docs --docs ./your_documents
# Search your documents
leann search my-docs "machine learning concepts" leann search my-docs "machine learning concepts"
# Interactive chat with your documents # Interactive chat with your documents
@@ -538,35 +377,60 @@ Options:
**Core techniques:** **Core techniques:**
- **Graph-based selective recomputation:** Only compute embeddings for nodes in the search path - **Graph-based selective recomputation:** Only compute embeddings for nodes in the search path
- **High-degree preserving pruning:** Keep important "hub" nodes while removing redundant connections - **High-degree preserving pruning:** Keep important "hub" nodes while removing redundant connections
- **Dynamic batching:** Efficiently batch embedding computations for GPU utilization - **Dynamic batching:** Efficiently batch embedding computations for GPU utilization
- **Two-level search:** Smart graph traversal that prioritizes promising nodes - **Two-level search:** Smart graph traversal that prioritizes promising nodes
**Backends:** **Backends:** DiskANN or HNSW - pick what works for your data size.
- **HNSW** (default): Ideal for most datasets with maximum storage savings through full recomputation
- **DiskANN**: Advanced option with superior search performance, using PQ-based graph traversal with real-time reranking for the best speed-accuracy trade-off
## Benchmarks ## Benchmarks
**[DiskANN vs HNSW Performance Comparison →](benchmarks/diskann_vs_hnsw_speed_comparison.py)** - Compare search performance between both backends Run the comparison yourself:
```bash
python -m apps.benchmarks
```
**[Simple Example: Compare LEANN vs FAISS →](benchmarks/compare_faiss_vs_leann.py)** - See storage savings in action | System | Storage |
|--------|---------|
| FAISS HNSW | 5.5 MB |
| LEANN | 0.5 MB |
| **Savings** | **91%** |
### 📊 Storage Comparison Same dataset, same hardware, same embedding model. LEANN just works better.
| System | DPR (2.1M) | Wiki (60M) | Chat (400K) | Email (780K) | Browser (38K) |
|--------|-------------|------------|-------------|--------------|---------------|
| Traditional vector database (e.g., FAISS) | 3.8 GB | 201 GB | 1.8 GB | 2.4 GB | 130 MB |
| LEANN | 324 MB | 6 GB | 64 MB | 79 MB | 6.4 MB |
| Savings | 91% | 97% | 97% | 97% | 95% |
### Storage Usage Comparison
| System | DPR (2.1M chunks) | RPJ-wiki (60M chunks) | Chat history (400K messages) | Apple emails (90K message chunks) | Google Search History (38K entries) |
|--------|-------------------|------------------------|------------------------------|-----------------------------------|-------------------------------------|
| Traditional Vector DB (FAISS) | 3.8 GB | 201 GB | 1.8 GB | 305.8 MB | 130.4 MB |
| **LEANN** | **324 MB** | **6 GB** | **64 MB** | **14.8 MB** | **6.4 MB** |
| **Reduction** | **91% smaller** | **97% smaller** | **97% smaller** | **95% smaller** | **95% smaller** |
<!-- ### Memory Usage Comparison
| System j | DPR(2M docs) | RPJ-wiki(60M docs) | Chat history() |
| --------------------- | ---------------- | ---------------- | ---------------- |
| Traditional Vector DB(LLamaindex faiss) | x GB | x GB | x GB |
| **Leann** | **xx MB** | **x GB** | **x GB** |
| **Reduction** | **x%** | **x%** | **x%** |
### Query Performance of LEANN
| Backend | Index Size | Query Time | Recall@3 |
| ------------------- | ---------- | ---------- | --------- |
| DiskANN | 1M docs | xms | 0.95 |
| HNSW | 1M docs | xms | 0.95 | -->
*Benchmarks run on Apple M3 Pro 36 GB*
## Reproduce Our Results ## Reproduce Our Results
```bash ```bash
uv pip install -e ".[dev]" # Install dev dependencies uv pip install -e ".[dev]" # Install dev dependencies
python benchmarks/run_evaluation.py # Will auto-download evaluation data and run benchmarks python -m apps.evaluation data/indices/dpr/dpr_diskann # DPR dataset
python -m apps.evaluation data/indices/rpj_wiki/rpj_wiki.index # Wikipedia
``` ```
The evaluation script downloads data automatically on first run. The last three results were tested with partial personal data, and you can reproduce them with your own data! The evaluation script downloads data automatically on first run. The last three results were tested with partial personal data, and you can reproduce them with your own data!
@@ -578,25 +442,108 @@ If you find Leann useful, please cite:
```bibtex
@misc{wang2025leannlowstoragevectorindex,
title={LEANN: A Low-Storage Vector Index},
author={Yichuan Wang and Shu Liu and Zhifei Li and Yongji Wu and Ziming Mao and Yilong Zhao and Xiao Yan and Zhiying Xu and Yang Zhou and Ion Stoica and Sewon Min and Matei Zaharia and Joseph E. Gonzalez},
year={2025},
eprint={2506.08276},
archivePrefix={arXiv},
primaryClass={cs.DB},
url={https://arxiv.org/abs/2506.08276},
}
```
## ✨ [Detailed Features →](docs/features.md) ## ✨ Features
## 🤝 [CONTRIBUTING →](docs/CONTRIBUTING.md) ### 🔥 Core Features
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage with on-demand computation, using optimized ZMQ servers and an overlapped, batched search path backed by a high-throughput embedding engine
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
- **🎯 Graph Pruning** - Advanced pruning techniques that keep the storage overhead of vector search to a minimal footprint
- **🏗️ Pluggable Backends** - DiskANN, HNSW/FAISS with unified API
### 🛠️ Technical Highlights
- **🔄 Recompute Mode** - Highest accuracy scenarios while eliminating vector storage overhead
- **⚡ Zero-copy Operations** - Minimize IPC overhead by transferring distances instead of embeddings
- **🚀 High-throughput Embedding Pipeline** - Optimized batched processing for maximum efficiency
- **🎯 Two-level Search** - Novel coarse-to-fine search overlap for accelerated query processing (optional)
- **💾 Memory-mapped Indices** - Fast startup with raw text mapping to reduce memory overhead
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](test/build_mlx_index.py))
### 🎨 Developer Experience
- **Simple Python API** - Get started in minutes
- **Extensible backend system** - Easy to add new algorithms
- **Comprehensive examples** - From basic usage to production deployment
## 🤝 Contributing
We welcome contributions! Leann is built by the community, for the community.
### Ways to Contribute
- 🐛 **Bug Reports**: Found an issue? Let us know!
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
- 🔧 **Code Contributions**: PRs welcome for all skill levels
- 📖 **Documentation**: Help make Leann more accessible
- 🧪 **Benchmarks**: Share your performance results
## ❓ [FAQ →](docs/faq.md) <!-- ## ❓ FAQ
### Common Issues
#### NCCL Topology Error
**Problem**: You encounter `ncclTopoComputePaths` error during document processing:
```
ncclTopoComputePaths (system=<optimized out>, comm=comm@entry=0x5555a82fa3c0) at graph/paths.cc:688
```
**Solution**: Set these environment variables before running your script:
```bash
export NCCL_TOPO_DUMP_FILE=/tmp/nccl_topo.xml
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=INIT,GRAPH
export NCCL_IB_DISABLE=1
export NCCL_NET_PLUGIN=none
export NCCL_SOCKET_IFNAME=ens5
``` -->
## FAQ
### 1. My index build time seems long
You can speed up the process by using a lightweight embedding model. Add this to your arguments:
```bash
--embedding-model sentence-transformers/all-MiniLM-L6-v2
```
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)
## 📈 [Roadmap →](docs/roadmap.md) ## 📈 Roadmap
### 🎯 Q2 2025
- [X] DiskANN backend with MIPS/L2/Cosine support
- [X] HNSW backend integration
- [X] Real-time embedding pipeline
- [X] Memory-efficient graph pruning
### 🚀 Q3 2025
- [ ] Advanced caching strategies
- [ ] Add contextual-retrieval https://www.anthropic.com/news/contextual-retrieval
- [ ] Add sleep-time compute and a summarization agent to summarize files on your computer
- [ ] Add OpenAI recompute API
### 🌟 Q4 2025
- [ ] Integration with LangChain/LlamaIndex
- [ ] Visual similarity search
- [ ] Query rewriting, reranking, and expansion
## 📄 License ## 📄 License
@@ -604,11 +551,10 @@ MIT License - see [LICENSE](LICENSE) for details.
## 🙏 Acknowledgments ## 🙏 Acknowledgments
Core Contributors: [Yichuan Wang](https://yichuan-w.github.io/) & [Zhifei Li](https://github.com/andylizf). - **Microsoft Research** for the DiskANN algorithm
- **Meta AI** for FAISS and optimization insights
We welcome more contributors! Feel free to open issues or submit PRs. - **HuggingFace** for the transformer ecosystem
- **Our amazing contributors** who make this possible
This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/).
--- ---
@@ -619,3 +565,4 @@ This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.ed
<p align="center"> <p align="center">
Made with ❤️ by the Leann team Made with ❤️ by the Leann team
</p> </p>


@@ -1,324 +0,0 @@
"""
Base class for unified RAG examples interface.
Provides common parameters and functionality for all RAG examples.
"""
import argparse
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any
import dotenv
from leann.api import LeannBuilder, LeannChat
from llama_index.core.node_parser import SentenceSplitter
dotenv.load_dotenv()
class BaseRAGExample(ABC):
"""Base class for all RAG examples with unified interface."""
def __init__(
self,
name: str,
description: str,
default_index_name: str,
):
self.name = name
self.description = description
self.default_index_name = default_index_name
self.parser = self._create_parser()
def _create_parser(self) -> argparse.ArgumentParser:
"""Create argument parser with common parameters."""
parser = argparse.ArgumentParser(
description=self.description, formatter_class=argparse.RawDescriptionHelpFormatter
)
# Core parameters (all examples share these)
core_group = parser.add_argument_group("Core Parameters")
core_group.add_argument(
"--index-dir",
type=str,
default=f"./{self.default_index_name}",
help=f"Directory to store the index (default: ./{self.default_index_name})",
)
core_group.add_argument(
"--query",
type=str,
default=None,
help="Query to run (if not provided, will run in interactive mode)",
)
# Allow subclasses to override default max_items
max_items_default = getattr(self, "max_items_default", -1)
core_group.add_argument(
"--max-items",
type=int,
default=max_items_default,
help="Maximum number of items to process -1 for all, means index all documents, and you should set it to a reasonable number if you have a large dataset and try at the first time)",
)
core_group.add_argument(
"--force-rebuild", action="store_true", help="Force rebuild index even if it exists"
)
# Embedding parameters
embedding_group = parser.add_argument_group("Embedding Parameters")
# Allow subclasses to override default embedding_model
embedding_model_default = getattr(self, "embedding_model_default", "facebook/contriever")
embedding_group.add_argument(
"--embedding-model",
type=str,
default=embedding_model_default,
help=f"Embedding model to use (default: {embedding_model_default})",
)
embedding_group.add_argument(
"--embedding-mode",
type=str,
default="sentence-transformers",
choices=["sentence-transformers", "openai", "mlx", "ollama"],
help="Embedding backend mode (default: sentence-transformers)",
)
# LLM parameters
llm_group = parser.add_argument_group("LLM Parameters")
llm_group.add_argument(
"--llm",
type=str,
default="openai",
choices=["openai", "ollama", "hf", "simulated"],
help="LLM backend to use (default: openai)",
)
llm_group.add_argument(
"--llm-model",
type=str,
default=None,
help="LLM model name (default: gpt-4o for openai, llama3.2:1b for ollama)",
)
llm_group.add_argument(
"--llm-host",
type=str,
default="http://localhost:11434",
help="Host for Ollama API (default: http://localhost:11434)",
)
llm_group.add_argument(
"--thinking-budget",
type=str,
choices=["low", "medium", "high"],
default=None,
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
)
# Search parameters
search_group = parser.add_argument_group("Search Parameters")
search_group.add_argument(
"--top-k", type=int, default=20, help="Number of results to retrieve (default: 20)"
)
search_group.add_argument(
"--search-complexity",
type=int,
default=32,
help="Search complexity for graph traversal (default: 64)",
)
# Index building parameters
index_group = parser.add_argument_group("Index Building Parameters")
index_group.add_argument(
"--backend-name",
type=str,
default="hnsw",
choices=["hnsw", "diskann"],
help="Backend to use for index (default: hnsw)",
)
index_group.add_argument(
"--graph-degree",
type=int,
default=32,
help="Graph degree for index construction (default: 32)",
)
index_group.add_argument(
"--build-complexity",
type=int,
default=64,
help="Build complexity for index construction (default: 64)",
)
index_group.add_argument(
"--no-compact",
action="store_true",
help="Disable compact index storage",
)
index_group.add_argument(
"--no-recompute",
action="store_true",
help="Disable embedding recomputation",
)
# Add source-specific parameters
self._add_specific_arguments(parser)
return parser
@abstractmethod
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
"""Add source-specific arguments. Override in subclasses."""
pass
@abstractmethod
async def load_data(self, args) -> list[str]:
"""Load data from the source. Returns list of text chunks."""
pass
def get_llm_config(self, args) -> dict[str, Any]:
"""Get LLM configuration based on arguments."""
config = {"type": args.llm}
if args.llm == "openai":
config["model"] = args.llm_model or "gpt-4o"
elif args.llm == "ollama":
config["model"] = args.llm_model or "llama3.2:1b"
config["host"] = args.llm_host
elif args.llm == "hf":
config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
elif args.llm == "simulated":
# Simulated LLM doesn't need additional configuration
pass
return config
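# Example (illustrative): `--llm ollama --llm-model llama3.2:1b` yields
# {"type": "ollama", "model": "llama3.2:1b", "host": "http://localhost:11434"}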
async def build_index(self, args, texts: list[str]) -> str:
"""Build LEANN index from texts."""
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
print(f"\n[Building Index] Creating {self.name} index...")
print(f"Total text chunks: {len(texts)}")
builder = LeannBuilder(
backend_name=args.backend_name,
embedding_model=args.embedding_model,
embedding_mode=args.embedding_mode,
graph_degree=args.graph_degree,
complexity=args.build_complexity,
is_compact=not args.no_compact,
is_recompute=not args.no_recompute,
num_threads=1, # Force single-threaded mode
)
# Add texts in batches for better progress tracking
batch_size = 1000
for i in range(0, len(texts), batch_size):
batch = texts[i : i + batch_size]
for text in batch:
builder.add_text(text)
print(f"Added {min(i + batch_size, len(texts))}/{len(texts)} texts...")
print("Building index structure...")
builder.build_index(index_path)
print(f"Index saved to: {index_path}")
return index_path
async def run_interactive_chat(self, args, index_path: str):
"""Run interactive chat with the index."""
chat = LeannChat(
index_path,
llm_config=self.get_llm_config(args),
system_prompt=f"You are a helpful assistant that answers questions about {self.name} data.",
complexity=args.search_complexity,
)
print(f"\n[Interactive Mode] Chat with your {self.name} data!")
print("Type 'quit' or 'exit' to stop.\n")
while True:
try:
query = input("You: ").strip()
if query.lower() in ["quit", "exit", "q"]:
print("Goodbye!")
break
if not query:
continue
# Prepare LLM kwargs with thinking budget if specified
llm_kwargs = {}
if hasattr(args, "thinking_budget") and args.thinking_budget:
llm_kwargs["thinking_budget"] = args.thinking_budget
response = chat.ask(
query,
top_k=args.top_k,
complexity=args.search_complexity,
llm_kwargs=llm_kwargs,
)
print(f"\nAssistant: {response}\n")
except KeyboardInterrupt:
print("\nGoodbye!")
break
except Exception as e:
print(f"Error: {e}")
async def run_single_query(self, args, index_path: str, query: str):
"""Run a single query against the index."""
chat = LeannChat(
index_path,
llm_config=self.get_llm_config(args),
system_prompt=f"You are a helpful assistant that answers questions about {self.name} data.",
complexity=args.search_complexity,
)
print(f"\n[Query]: \033[36m{query}\033[0m")
# Prepare LLM kwargs with thinking budget if specified
llm_kwargs = {}
if hasattr(args, "thinking_budget") and args.thinking_budget:
llm_kwargs["thinking_budget"] = args.thinking_budget
response = chat.ask(
query, top_k=args.top_k, complexity=args.search_complexity, llm_kwargs=llm_kwargs
)
print(f"\n[Response]: \033[36m{response}\033[0m")
async def run(self):
"""Main entry point for the example."""
args = self.parser.parse_args()
# Check if index exists
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
index_exists = Path(args.index_dir).exists()
if not index_exists or args.force_rebuild:
# Load data and build index
print(f"\n{'Rebuilding' if index_exists else 'Building'} index...")
texts = await self.load_data(args)
if not texts:
print("No data found to index!")
return
index_path = await self.build_index(args, texts)
else:
print(f"\nUsing existing index in {args.index_dir}")
# Run query or interactive mode
if args.query:
await self.run_single_query(args, index_path, args.query)
else:
await self.run_interactive_chat(args, index_path)
def create_text_chunks(documents, chunk_size=256, chunk_overlap=25) -> list[str]:
"""Helper function to create text chunks from documents."""
node_parser = SentenceSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
separator=" ",
paragraph_separator="\n\n",
)
all_texts = []
for doc in documents:
nodes = node_parser.get_nodes_from_documents([doc])
if nodes:
all_texts.extend(node.get_content() for node in nodes)
return all_texts
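The two abstract hooks above (`_add_specific_arguments` and `load_data`) are all a concrete example has to supply. Below is a minimal, hypothetical subclass sketch showing how the pieces fit together; the class name, `--text-file` flag, and data source are illustrative and not part of the repository:

import asyncio
from pathlib import Path

from llama_index.core import Document


class PlainTextRAG(BaseRAGExample):
    """Hypothetical example: index a single plain-text file."""

    def __init__(self):
        super().__init__(
            name="Plain Text",
            description="Query a plain-text file with LEANN",
            default_index_name="plain_text_index",
        )

    def _add_specific_arguments(self, parser):
        parser.add_argument("--text-file", type=str, default="notes.txt")

    async def load_data(self, args) -> list[str]:
        raw = Path(args.text_file).read_text(encoding="utf-8")
        # create_text_chunks() expects llama_index Documents, so wrap the raw text
        texts = create_text_chunks([Document(text=raw)], chunk_size=256, chunk_overlap=25)
        # Respect --max-items the same way the other examples do
        if args.max_items > 0:
            texts = texts[: args.max_items]
        return texts


if __name__ == "__main__":
    asyncio.run(PlainTextRAG().run())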

View File

338
apps/benchmarks/__main__.py Normal file
View File

@@ -0,0 +1,338 @@
#!/usr/bin/env python3
"""
Memory comparison between Faiss HNSW and LEANN HNSW backend
"""
import logging
import os
import sys
import time
import psutil
import gc
import subprocess
from pathlib import Path
from llama_index.core.node_parser import SentenceSplitter
# Setup logging
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)
def get_memory_usage():
"""Get current memory usage in MB"""
process = psutil.Process()
return process.memory_info().rss / 1024 / 1024
def print_memory_stats(stage: str, start_mem: float):
"""Print memory statistics"""
current_mem = get_memory_usage()
diff = current_mem - start_mem
print(f"[{stage}] Memory: {current_mem:.1f} MB (+{diff:.1f} MB)")
return current_mem
class MemoryTracker:
def __init__(self, name: str):
self.name = name
self.start_mem = get_memory_usage()
self.stages = []
def checkpoint(self, stage: str):
current_mem = print_memory_stats(f"{self.name} - {stage}", self.start_mem)
self.stages.append((stage, current_mem))
return current_mem
def summary(self):
print(f"\n=== {self.name} Memory Summary ===")
for stage, mem in self.stages:
print(f"{stage}: {mem:.1f} MB")
peak_mem = max(mem for _, mem in self.stages)
print(f"Peak Memory: {peak_mem:.1f} MB")
print(f"Total Memory Increase: {peak_mem - self.start_mem:.1f} MB")
return peak_mem
def test_faiss_hnsw():
"""Test Faiss HNSW Vector Store in subprocess"""
print("\n" + "=" * 50)
print("TESTING FAISS HNSW VECTOR STORE")
print("=" * 50)
try:
# Get the directory of this script
script_dir = Path(__file__).parent
faiss_script = script_dir / "faiss_only.py"
result = subprocess.run(
[sys.executable, str(faiss_script)],
capture_output=True,
text=True,
timeout=300,
)
print(result.stdout)
if result.stderr:
print("Stderr:", result.stderr)
if result.returncode != 0:
return {
"peak_memory": float("inf"),
"error": f"Process failed with code {result.returncode}",
}
# Parse peak memory from output
lines = result.stdout.split("\n")
peak_memory = 0.0
for line in lines:
if "Peak Memory:" in line:
peak_memory = float(
line.split("Peak Memory:")[1].split("MB")[0].strip()
)
return {"peak_memory": peak_memory}
except Exception as e:
return {
"peak_memory": float("inf"),
"error": str(e),
}
def test_leann_hnsw():
"""Test LEANN HNSW Search Memory (load existing index)"""
print("\n" + "=" * 50)
print("TESTING LEANN HNSW SEARCH MEMORY")
print("=" * 50)
tracker = MemoryTracker("LEANN HNSW Search")
# Import and setup
tracker.checkpoint("Initial")
from leann.api import LeannSearcher
tracker.checkpoint("After imports")
from llama_index.core import SimpleDirectoryReader
from leann.api import LeannBuilder
# Load and parse documents
documents = SimpleDirectoryReader(
"../documents/data",
recursive=True,
encoding="utf-8",
required_exts=[".pdf", ".txt", ".md"],
).load_data()
tracker.checkpoint("After document loading")
# Parse into chunks
node_parser = SentenceSplitter(
chunk_size=256, chunk_overlap=20, separator=" ", paragraph_separator="\n\n"
)
all_texts = []
for doc in documents:
nodes = node_parser.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
tracker.checkpoint("After text chunking")
# Build LEANN index
INDEX_DIR = Path("./test_leann_comparison")
INDEX_PATH = str(INDEX_DIR / "comparison.leann")
# Check if index already exists
if os.path.exists(INDEX_PATH + ".meta.json"):
print("Loading existing LEANN HNSW index...")
tracker.checkpoint("After loading existing index")
else:
print("Building new LEANN HNSW index...")
# Clean up previous index
import shutil
if INDEX_DIR.exists():
shutil.rmtree(INDEX_DIR)
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="facebook/contriever",
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1,
)
tracker.checkpoint("After builder setup")
print("Building LEANN HNSW index...")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(INDEX_PATH)
del builder
gc.collect()
tracker.checkpoint("After index building")
# Find existing LEANN index
index_paths = [
"./test_leann_comparison/comparison.leann",
]
index_path = None
for path in index_paths:
if os.path.exists(path + ".meta.json"):
index_path = path
break
if not index_path:
print("❌ LEANN index not found. Please build it first")
return {"peak_memory": float("inf"), "error": "Index not found"}
# Measure runtime memory overhead
print("\nMeasuring runtime memory overhead...")
runtime_start_mem = get_memory_usage()
print(f"Before load memory: {runtime_start_mem:.1f} MB")
tracker.checkpoint("Before load memory")
# Load searcher
searcher = LeannSearcher(index_path)
tracker.checkpoint("After searcher loading")
print("Running search queries...")
queries = [
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
"What is LEANN and how does it work?",
"华为诺亚方舟实验室的主要研究内容",
]
for i, query in enumerate(queries):
start_time = time.time()
# Use same parameters as Faiss: top_k=20, ef=120 (complexity parameter)
_ = searcher.search(query, top_k=20, ef=120)
query_time = time.time() - start_time
print(f"Query {i + 1} time: {query_time:.3f}s")
tracker.checkpoint(f"After query {i + 1}")
runtime_end_mem = get_memory_usage()
runtime_overhead = runtime_end_mem - runtime_start_mem
peak_memory = tracker.summary()
print(f"Runtime Memory Overhead: {runtime_overhead:.1f} MB")
# Get storage size before cleanup
storage_size = 0
INDEX_DIR = Path(index_path).parent
if INDEX_DIR.exists():
total_size = 0
for dirpath, _, filenames in os.walk(str(INDEX_DIR)):
for filename in filenames:
# Only count actual index files, skip text data and backups
if filename.endswith((".old", ".tmp", ".bak", ".jsonl", ".json")):
continue
# Count .index, .idx, .map files (actual index structures)
if filename.endswith((".index", ".idx", ".map")):
filepath = os.path.join(dirpath, filename)
total_size += os.path.getsize(filepath)
storage_size = total_size / (1024 * 1024) # Convert to MB
# Clean up
del searcher
gc.collect()
return {
"peak_memory": peak_memory,
"storage_size": storage_size,
}
def main():
"""Run comparison tests"""
print("Storage + Search Memory Comparison: Faiss HNSW vs LEANN HNSW")
print("=" * 60)
# Test Faiss HNSW
faiss_results = test_faiss_hnsw()
# Force garbage collection
gc.collect()
time.sleep(2)
# Test LEANN HNSW
leann_results = test_leann_hnsw()
# Final comparison
print("\n" + "=" * 60)
print("STORAGE + SEARCH MEMORY COMPARISON")
print("=" * 60)
# Get storage sizes
faiss_storage_size = 0
leann_storage_size = leann_results.get("storage_size", 0)
# Get Faiss storage size using Python
if os.path.exists("./storage_faiss"):
total_size = 0
for dirpath, _, filenames in os.walk("./storage_faiss"):
for filename in filenames:
filepath = os.path.join(dirpath, filename)
total_size += os.path.getsize(filepath)
faiss_storage_size = total_size / (1024 * 1024) # Convert to MB
print("Faiss HNSW:")
if "error" in faiss_results:
print(f" ❌ Failed: {faiss_results['error']}")
else:
print(f" Search Memory: {faiss_results['peak_memory']:.1f} MB")
print(f" Storage Size: {faiss_storage_size:.1f} MB")
print("\nLEANN HNSW:")
if "error" in leann_results:
print(f" ❌ Failed: {leann_results['error']}")
else:
print(f" Search Memory: {leann_results['peak_memory']:.1f} MB")
print(f" Storage Size: {leann_storage_size:.1f} MB")
# Calculate improvements only if both tests succeeded
if "error" not in faiss_results and "error" not in leann_results:
memory_ratio = faiss_results["peak_memory"] / leann_results["peak_memory"]
print("\nLEANN vs Faiss Performance:")
memory_saving = faiss_results["peak_memory"] - leann_results["peak_memory"]
print(
f" Search Memory: {memory_ratio:.1f}x less ({memory_saving:.1f} MB saved)"
)
# Storage comparison
if leann_storage_size > faiss_storage_size:
storage_ratio = leann_storage_size / faiss_storage_size
print(
f" Storage Size: {storage_ratio:.1f}x larger (LEANN uses more storage)"
)
elif faiss_storage_size > leann_storage_size:
storage_ratio = faiss_storage_size / leann_storage_size
print(
f" Storage Size: {storage_ratio:.1f}x smaller (LEANN uses less storage)"
)
else:
print(" Storage Size: similar")
else:
if "error" not in leann_results:
print("\n✅ LEANN HNSW completed successfully!")
print(f"📊 Search Memory: {leann_results['peak_memory']:.1f} MB")
print(f"📊 Storage Size: {leann_storage_size:.1f} MB")
if "error" not in faiss_results:
print("\n✅ Faiss HNSW completed successfully!")
print(f"📊 Search Memory: {faiss_results['peak_memory']:.1f} MB")
print(f"📊 Storage Size: {faiss_storage_size:.1f} MB")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,151 @@
#!/usr/bin/env python3
"""Test only Faiss HNSW"""
import sys
import time
import psutil
import gc
import os
def get_memory_usage():
process = psutil.Process()
return process.memory_info().rss / 1024 / 1024
class MemoryTracker:
def __init__(self, name: str):
self.name = name
self.start_mem = get_memory_usage()
self.stages = []
def checkpoint(self, stage: str):
current_mem = get_memory_usage()
diff = current_mem - self.start_mem
print(f"[{self.name} - {stage}] Memory: {current_mem:.1f} MB (+{diff:.1f} MB)")
self.stages.append((stage, current_mem))
return current_mem
def summary(self):
peak_mem = max(mem for _, mem in self.stages)
print(f"Peak Memory: {peak_mem:.1f} MB")
return peak_mem
def main():
try:
import faiss
except ImportError:
print("Faiss is not installed.")
print("Please install it with `uv pip install faiss-cpu`")
sys.exit(1)
from llama_index.core import (
SimpleDirectoryReader,
VectorStoreIndex,
StorageContext,
Settings,
)
from llama_index.core.node_parser import SentenceSplitter
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
tracker = MemoryTracker("Faiss HNSW")
tracker.checkpoint("Initial")
embed_model = HuggingFaceEmbedding(model_name="facebook/contriever")
Settings.embed_model = embed_model
tracker.checkpoint("After embedding model setup")
d = 768
faiss_index = faiss.IndexHNSWFlat(d, 32)
faiss_index.hnsw.efConstruction = 64
tracker.checkpoint("After Faiss index creation")
documents = SimpleDirectoryReader(
"../documents/data",
recursive=True,
encoding="utf-8",
required_exts=[".pdf", ".txt", ".md"],
).load_data()
tracker.checkpoint("After document loading")
# Parse into chunks using the same splitter as LEANN
node_parser = SentenceSplitter(
chunk_size=256, chunk_overlap=20, separator=" ", paragraph_separator="\n\n"
)
tracker.checkpoint("After text splitter setup")
# Check if index already exists and try to load it
index_loaded = False
if os.path.exists("./storage_faiss"):
print("Loading existing Faiss HNSW index...")
try:
# Use the correct Faiss loading pattern from the example
vector_store = FaissVectorStore.from_persist_dir("./storage_faiss")
storage_context = StorageContext.from_defaults(
vector_store=vector_store, persist_dir="./storage_faiss"
)
from llama_index.core import load_index_from_storage
index = load_index_from_storage(storage_context=storage_context)
print(f"Index loaded from ./storage_faiss")
tracker.checkpoint("After loading existing index")
index_loaded = True
except Exception as e:
print(f"Failed to load existing index: {e}")
print("Cleaning up corrupted index and building new one...")
# Clean up corrupted index
import shutil
if os.path.exists("./storage_faiss"):
shutil.rmtree("./storage_faiss")
if not index_loaded:
print("Building new Faiss HNSW index...")
# Use the correct Faiss building pattern from the example
vector_store = FaissVectorStore(faiss_index=faiss_index)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = VectorStoreIndex.from_documents(
documents,
storage_context=storage_context,
transformations=[node_parser]
)
tracker.checkpoint("After index building")
# Save index to disk using the correct pattern
index.storage_context.persist(persist_dir="./storage_faiss")
tracker.checkpoint("After index saving")
# Measure runtime memory overhead
print("\nMeasuring runtime memory overhead...")
runtime_start_mem = get_memory_usage()
print(f"Before load memory: {runtime_start_mem:.1f} MB")
tracker.checkpoint("Before load memory")
query_engine = index.as_query_engine(similarity_top_k=20)
queries = [
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
"What is LEANN and how does it work?",
"华为诺亚方舟实验室的主要研究内容",
]
for i, query in enumerate(queries):
start_time = time.time()
_ = query_engine.query(query)
query_time = time.time() - start_time
print(f"Query {i + 1} time: {query_time:.3f}s")
tracker.checkpoint(f"After query {i + 1}")
runtime_end_mem = get_memory_usage()
runtime_overhead = runtime_end_mem - runtime_start_mem
peak_memory = tracker.summary()
print(f"Peak Memory: {peak_memory:.1f} MB")
print(f"Runtime Memory Overhead: {runtime_overhead:.1f} MB")
if __name__ == "__main__":
main()

0
apps/browser/__init__.py Normal file
View File

201
apps/browser/__main__.py Normal file
View File

@@ -0,0 +1,201 @@
import os
import asyncio
import argparse
try:
import dotenv
dotenv.load_dotenv()
except ModuleNotFoundError:
# python-dotenv is not installed; skip loading environment variables
dotenv = None
from pathlib import Path
from typing import List, Any
from leann.api import LeannBuilder, LeannSearcher, LeannChat
from llama_index.core.node_parser import SentenceSplitter
# Default Chrome profile path
DEFAULT_CHROME_PROFILE = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
def create_leann_index_from_multiple_chrome_profiles(profile_dirs: List[Path], index_path: str = "chrome_history_index.leann", max_count: int = -1):
"""
Create LEANN index from multiple Chrome profile data sources.
Args:
profile_dirs: List of Path objects pointing to Chrome profile directories
index_path: Path to save the LEANN index
max_count: Maximum number of history entries to process per profile
"""
print("Creating LEANN index from multiple Chrome profile data sources...")
# Load documents using ChromeHistoryReader from local readers module
from .readers import ChromeHistoryReader
reader = ChromeHistoryReader()
INDEX_DIR = Path(index_path).parent
if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---")
all_documents = []
total_processed = 0
# Process each Chrome profile directory
for i, profile_dir in enumerate(profile_dirs):
print(f"\nProcessing Chrome profile {i+1}/{len(profile_dirs)}: {profile_dir}")
try:
documents = reader.load_data(
chrome_profile_path=str(profile_dir),
max_count=max_count
)
if documents:
print(f"Loaded {len(documents)} history documents from {profile_dir}")
all_documents.extend(documents)
total_processed += len(documents)
# Check if we've reached the max count
if max_count > 0 and total_processed >= max_count:
print(f"Reached max count of {max_count} documents")
break
else:
print(f"No documents loaded from {profile_dir}")
except Exception as e:
print(f"Error processing {profile_dir}: {e}")
continue
if not all_documents:
print("No documents loaded from any source. Exiting.")
return None
print(f"\nTotal loaded {len(all_documents)} history documents from {len(profile_dirs)} profiles")
# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
# Convert Documents to text strings and chunk them
all_texts = []
for doc in all_documents:
# Split the document into chunks
nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
# Create LEANN index directory
INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="facebook/contriever",
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1 # Force single-threaded mode
)
print(f"Adding {len(all_texts)} history chunks to index...")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(index_path)
print(f"\nLEANN index built at {index_path}!")
else:
print(f"--- Using existing index at {INDEX_DIR} ---")
return index_path
async def query_leann_index(index_path: str, query: str):
"""
Query the LEANN index.
Args:
index_path: Path to the LEANN index
query: The query string
"""
print(f"\n[PHASE 2] Starting Leann chat session...")
chat = LeannChat(index_path=index_path)
print(f"You: {query}")
chat_response = chat.ask(
query,
top_k=10,
recompute_beighbor_embeddings=True,
complexity=32,
beam_width=1,
llm_config={
"type": "openai",
"model": "gpt-4o",
"api_key": os.getenv("OPENAI_API_KEY"),
},
llm_kwargs={
"temperature": 0.0,
"max_tokens": 1000
}
)
print(f"Leann: {chat_response}")
async def main():
# Parse command line arguments
parser = argparse.ArgumentParser(description='LEANN Chrome History Reader - Create and query browser history index')
parser.add_argument('--chrome-profile', type=str, default=DEFAULT_CHROME_PROFILE,
help=f'Path to Chrome profile directory (default: {DEFAULT_CHROME_PROFILE}); usually you do not need to change this')
parser.add_argument('--index-dir', type=str, default="./chrome_history_index_leann_test",
help='Directory to store the LEANN index (default: ./chrome_history_index_leann_test)')
parser.add_argument('--max-entries', type=int, default=1000,
help='Maximum number of history entries to process (default: 1000)')
parser.add_argument('--query', type=str, default=None,
help='Single query to run (default: runs example queries)')
parser.add_argument('--auto-find-profiles', action='store_true', default=True,
help='Automatically find all Chrome profiles (default: True)')
args = parser.parse_args()
INDEX_DIR = Path(args.index_dir)
INDEX_PATH = str(INDEX_DIR / "chrome_history.leann")
print(f"Using Chrome profile: {args.chrome_profile}")
print(f"Index directory: {INDEX_DIR}")
print(f"Max entries: {args.max_entries}")
# Find Chrome profile directories
from .readers import ChromeHistoryReader
if args.auto_find_profiles:
profile_dirs = ChromeHistoryReader.find_chrome_profiles()
if not profile_dirs:
print("No Chrome profiles found automatically. Exiting.")
return
else:
# Use single specified profile
profile_path = Path(args.chrome_profile)
if not profile_path.exists():
print(f"Chrome profile not found: {profile_path}")
return
profile_dirs = [profile_path]
# Create or load the LEANN index from all sources
index_path = create_leann_index_from_multiple_chrome_profiles(profile_dirs, INDEX_PATH, args.max_entries)
if index_path:
if args.query:
# Run single query
await query_leann_index(index_path, args.query)
else:
# Example queries
queries = [
"What websites did I visit about machine learning?",
"Find my search history about programming"
]
for query in queries:
print("\n" + "="*60)
await query_leann_index(index_path, query)
if __name__ == "__main__":
asyncio.run(main())

176
apps/browser/readers.py Normal file
View File

@@ -0,0 +1,176 @@
import sqlite3
import os
from pathlib import Path
from typing import List, Any
from llama_index.core import Document
from llama_index.core.readers.base import BaseReader
class ChromeHistoryReader(BaseReader):
"""
Chrome browser history reader that extracts browsing data from SQLite database.
Reads Chrome history from the default Chrome profile location and creates documents
with embedded metadata similar to the email reader structure.
"""
def __init__(self) -> None:
"""Initialize."""
pass
def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
"""
Load Chrome history data from the default Chrome profile location.
Args:
input_dir: Not used for Chrome history (kept for compatibility)
**load_kwargs:
max_count (int): Maximum amount of history entries to read.
chrome_profile_path (str): Custom path to Chrome profile directory.
"""
docs: List[Document] = []
max_count = load_kwargs.get('max_count', 1000)
chrome_profile_path = load_kwargs.get('chrome_profile_path', None)
# Default Chrome profile path on macOS
if chrome_profile_path is None:
chrome_profile_path = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
history_db_path = os.path.join(chrome_profile_path, "History")
if not os.path.exists(history_db_path):
print(f"Chrome history database not found at: {history_db_path}")
return docs
try:
# Connect to the Chrome history database
print(f"Connecting to database: {history_db_path}")
conn = sqlite3.connect(history_db_path)
cursor = conn.cursor()
# Query to get browsing history with metadata (removed created_time column)
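# Chrome stores last_visit_time as microseconds since 1601-01-01 UTC (the Windows/WebKit
# epoch); dividing by 1,000,000 and subtracting 11644473600 seconds converts it to Unix
# time, which datetime(..., 'unixepoch', 'localtime') then renders as local time.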
query = """
SELECT
datetime(last_visit_time/1000000-11644473600,'unixepoch','localtime') as last_visit,
url,
title,
visit_count,
typed_count,
hidden
FROM urls
ORDER BY last_visit_time DESC
"""
print(f"Executing query on database: {history_db_path}")
cursor.execute(query)
rows = cursor.fetchall()
print(f"Query returned {len(rows)} rows")
count = 0
for row in rows:
if count >= max_count and max_count > 0:
break
last_visit, url, title, visit_count, typed_count, hidden = row
# Create document content with metadata embedded in text
doc_content = f"""
[BROWSING HISTORY METADATA]
URL: {url}
Title: {title}
Last Visit: {last_visit}
Visit Count: {visit_count}
Typed Count: {typed_count}
Hidden: {hidden}
[END METADATA]
Title: {title}
URL: {url}
Last visited: {last_visit}
"""
# Create document with embedded metadata
doc = Document(text=doc_content, metadata={})
docs.append(doc)
count += 1
conn.close()
print(f"Loaded {len(docs)} Chrome history documents")
except Exception as e:
print(f"Error reading Chrome history: {e}")
return docs
return docs
@staticmethod
def find_chrome_profiles() -> List[Path]:
"""
Find all Chrome profile directories.
Returns:
List of Path objects pointing to Chrome profile directories
"""
chrome_base_path = Path(os.path.expanduser("~/Library/Application Support/Google/Chrome"))
profile_dirs = []
if not chrome_base_path.exists():
print(f"Chrome directory not found at: {chrome_base_path}")
return profile_dirs
# Find all profile directories
for profile_dir in chrome_base_path.iterdir():
if profile_dir.is_dir() and profile_dir.name != "System Profile":
history_path = profile_dir / "History"
if history_path.exists():
profile_dirs.append(profile_dir)
print(f"Found Chrome profile: {profile_dir}")
print(f"Found {len(profile_dirs)} Chrome profiles")
return profile_dirs
@staticmethod
def export_history_to_file(output_file: str = "chrome_history_export.txt", max_count: int = 1000):
"""
Export Chrome history to a text file using the same SQL query format.
Args:
output_file: Path to the output file
max_count: Maximum number of entries to export
"""
chrome_profile_path = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
history_db_path = os.path.join(chrome_profile_path, "History")
if not os.path.exists(history_db_path):
print(f"Chrome history database not found at: {history_db_path}")
return
try:
conn = sqlite3.connect(history_db_path)
cursor = conn.cursor()
query = """
SELECT
datetime(last_visit_time/1000000-11644473600,'unixepoch','localtime') as last_visit,
url,
title,
visit_count,
typed_count,
hidden
FROM urls
ORDER BY last_visit_time DESC
LIMIT ?
"""
cursor.execute(query, (max_count,))
rows = cursor.fetchall()
with open(output_file, 'w', encoding='utf-8') as f:
for row in rows:
last_visit, url, title, visit_count, typed_count, hidden = row
f.write(f"{last_visit}\t{url}\t{title}\t{visit_count}\t{typed_count}\t{hidden}\n")
conn.close()
print(f"Exported {len(rows)} history entries to {output_file}")
except Exception as e:
print(f"Error exporting Chrome history: {e}")

View File

@@ -1,170 +0,0 @@
"""
Browser History RAG example using the unified interface.
Supports Chrome browser history.
"""
import os
import sys
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from base_rag_example import BaseRAGExample, create_text_chunks
from .history_data.history import ChromeHistoryReader
class BrowserRAG(BaseRAGExample):
"""RAG example for Chrome browser history."""
def __init__(self):
# Set default values BEFORE calling super().__init__
self.embedding_model_default = (
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
)
super().__init__(
name="Browser History",
description="Process and query Chrome browser history with LEANN",
default_index_name="google_history_index",
)
def _add_specific_arguments(self, parser):
"""Add browser-specific arguments."""
browser_group = parser.add_argument_group("Browser Parameters")
browser_group.add_argument(
"--chrome-profile",
type=str,
default=None,
help="Path to Chrome profile directory (auto-detected if not specified)",
)
browser_group.add_argument(
"--auto-find-profiles",
action="store_true",
default=True,
help="Automatically find all Chrome profiles (default: True)",
)
browser_group.add_argument(
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
)
browser_group.add_argument(
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
)
def _get_chrome_base_path(self) -> Path:
"""Get the base Chrome profile path based on OS."""
if sys.platform == "darwin":
return Path.home() / "Library" / "Application Support" / "Google" / "Chrome"
elif sys.platform.startswith("linux"):
return Path.home() / ".config" / "google-chrome"
elif sys.platform == "win32":
return Path(os.environ["LOCALAPPDATA"]) / "Google" / "Chrome" / "User Data"
else:
raise ValueError(f"Unsupported platform: {sys.platform}")
def _find_chrome_profiles(self) -> list[Path]:
"""Auto-detect all Chrome profiles."""
base_path = self._get_chrome_base_path()
if not base_path.exists():
return []
profiles = []
# Check Default profile
default_profile = base_path / "Default"
if default_profile.exists() and (default_profile / "History").exists():
profiles.append(default_profile)
# Check numbered profiles
for item in base_path.iterdir():
if item.is_dir() and item.name.startswith("Profile "):
if (item / "History").exists():
profiles.append(item)
return profiles
async def load_data(self, args) -> list[str]:
"""Load browser history and convert to text chunks."""
# Determine Chrome profiles
if args.chrome_profile and not args.auto_find_profiles:
profile_dirs = [Path(args.chrome_profile)]
else:
print("Auto-detecting Chrome profiles...")
profile_dirs = self._find_chrome_profiles()
# If specific profile given, filter to just that one
if args.chrome_profile:
profile_path = Path(args.chrome_profile)
profile_dirs = [p for p in profile_dirs if p == profile_path]
if not profile_dirs:
print("No Chrome profiles found!")
print("Please specify --chrome-profile manually")
return []
print(f"Found {len(profile_dirs)} Chrome profiles")
# Create reader
reader = ChromeHistoryReader()
# Process each profile
all_documents = []
total_processed = 0
for i, profile_dir in enumerate(profile_dirs):
print(f"\nProcessing profile {i + 1}/{len(profile_dirs)}: {profile_dir.name}")
try:
# Apply max_items limit per profile
max_per_profile = -1
if args.max_items > 0:
remaining = args.max_items - total_processed
if remaining <= 0:
break
max_per_profile = remaining
# Load history
documents = reader.load_data(
chrome_profile_path=str(profile_dir),
max_count=max_per_profile,
)
if documents:
all_documents.extend(documents)
total_processed += len(documents)
print(f"Processed {len(documents)} history entries from this profile")
except Exception as e:
print(f"Error processing {profile_dir}: {e}")
continue
if not all_documents:
print("No browser history found to process!")
return []
print(f"\nTotal history entries processed: {len(all_documents)}")
# Convert to text chunks
all_texts = create_text_chunks(
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
)
return all_texts
if __name__ == "__main__":
import asyncio
# Example queries for browser history RAG
print("\n🌐 Browser History RAG Example")
print("=" * 50)
print("\nExample queries you can try:")
print("- 'What websites did I visit about machine learning?'")
print("- 'Find my search history about programming'")
print("- 'What YouTube videos did I watch recently?'")
print("- 'Show me websites about travel planning'")
print("\nNote: Make sure Chrome is closed before running\n")
rag = BrowserRAG()
asyncio.run(rag.run())

View File

@@ -1,108 +0,0 @@
"""
Document RAG example using the unified interface.
Supports PDF, TXT, MD, and other document formats.
"""
import sys
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from base_rag_example import BaseRAGExample, create_text_chunks
from llama_index.core import SimpleDirectoryReader
class DocumentRAG(BaseRAGExample):
"""RAG example for document processing (PDF, TXT, MD, etc.)."""
def __init__(self):
super().__init__(
name="Document",
description="Process and query documents (PDF, TXT, MD, etc.) with LEANN",
default_index_name="test_doc_files",
)
def _add_specific_arguments(self, parser):
"""Add document-specific arguments."""
doc_group = parser.add_argument_group("Document Parameters")
doc_group.add_argument(
"--data-dir",
type=str,
default="data",
help="Directory containing documents to index (default: data)",
)
doc_group.add_argument(
"--file-types",
nargs="+",
default=None,
help="Filter by file types (e.g., .pdf .txt .md). If not specified, all supported types are processed",
)
doc_group.add_argument(
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
)
doc_group.add_argument(
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
)
async def load_data(self, args) -> list[str]:
"""Load documents and convert to text chunks."""
print(f"Loading documents from: {args.data_dir}")
if args.file_types:
print(f"Filtering by file types: {args.file_types}")
else:
print("Processing all supported file types")
# Check if data directory exists
data_path = Path(args.data_dir)
if not data_path.exists():
raise ValueError(f"Data directory not found: {args.data_dir}")
# Load documents
reader_kwargs = {
"recursive": True,
"encoding": "utf-8",
}
if args.file_types:
reader_kwargs["required_exts"] = args.file_types
documents = SimpleDirectoryReader(args.data_dir, **reader_kwargs).load_data(
show_progress=True
)
if not documents:
print(f"No documents found in {args.data_dir} with extensions {args.file_types}")
return []
print(f"Loaded {len(documents)} documents")
# Convert to text chunks
all_texts = create_text_chunks(
documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
)
# Apply max_items limit if specified
if args.max_items > 0 and len(all_texts) > args.max_items:
print(f"Limiting to {args.max_items} chunks (from {len(all_texts)})")
all_texts = all_texts[: args.max_items]
return all_texts
if __name__ == "__main__":
import asyncio
# Example queries for document RAG
print("\n📄 Document RAG Example")
print("=" * 50)
print("\nExample queries you can try:")
print("- 'What are the main techniques LEANN uses?'")
print("- 'What is the technique DLPM?'")
print("- 'Who does Elizabeth Bennet marry?'")
print(
"- 'What is the problem of developing pan gu model Huawei meets? (盘古大模型开发中遇到什么问题?)'"
)
print("\nOr run without --query for interactive mode\n")
rag = DocumentRAG()
asyncio.run(rag.run())

View File

113
apps/documents/__main__.py Normal file
View File

@@ -0,0 +1,113 @@
import argparse
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
import asyncio
import dotenv
from leann.api import LeannBuilder, LeannChat
from pathlib import Path
import os
dotenv.load_dotenv()
async def main(args):
INDEX_DIR = Path(args.index_dir)
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
if not INDEX_DIR.exists():
node_parser = SentenceSplitter(
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
)
print("Loading documents...")
# Get the data directory relative to this module
current_dir = Path(__file__).parent
data_dir = current_dir / "data"
documents = SimpleDirectoryReader(
str(data_dir),
recursive=True,
encoding="utf-8",
required_exts=[".pdf", ".txt", ".md"],
).load_data(show_progress=True)
print("Documents loaded.")
all_texts = []
for doc in documents:
nodes = node_parser.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
print("--- Index directory not found, building new index ---")
print("\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="facebook/contriever",
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1, # Force single-threaded mode
)
print(f"Loaded {len(all_texts)} text chunks from documents.")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(INDEX_PATH)
print(f"\nLeann index built at {INDEX_PATH}!")
else:
print(f"--- Using existing index at {INDEX_DIR} ---")
print(f"\n[PHASE 2] Starting Leann chat session...")
# llm_config = {"type": "hf", "model": "Qwen/Qwen3-4B"}
llm_config = {"type": "ollama", "model": "qwen3:8b"}
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
# query = (
# "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
# )
print(f"You: {query}")
chat_response = chat.ask(query, top_k=20, recompute_embeddings=True, complexity=32)
print(f"Leann: {chat_response}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Run Leann Chat with various LLM backends."
)
parser.add_argument(
"--llm",
type=str,
default="hf",
choices=["simulated", "ollama", "hf", "openai"],
help="The LLM backend to use.",
)
parser.add_argument(
"--model",
type=str,
default="Qwen/Qwen3-0.6B",
help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf, 'gpt-4o' for openai).",
)
parser.add_argument(
"--host",
type=str,
default="http://localhost:11434",
help="The host for the Ollama API.",
)
parser.add_argument(
"--index-dir",
type=str,
default="./test_doc_files",
help="Directory where the Leann index will be stored.",
)
args = parser.parse_args()
asyncio.run(main(args))

View File

@@ -0,0 +1,82 @@
# 盘古之殇:华为诺亚盘古大模型研发历程的心酸与黑暗
各位好,
我是一名盘古大模型团队,华为诺亚方舟实验室的员工。
首先为自证身份,列举一些细节:
1. 现诺亚主任,前算法应用部部长,后改名为小模型实验室的主任王云鹤。前诺亚主任:姚骏(大家称姚老师)。几个实验室主任:唐睿明(明哥,明队,已离职),尚利峰,张维(维哥),郝建业(郝老师),刘武龙(称呼为武龙所)等。其他骨干成员和专家陆续有很多人离职。
2. 我们隶属于“四野”这个组织。四野下属有许多纵队,基础语言大模型是四纵。王云鹤的小模型是十六纵队。我们参加过苏州的集结,有各种月份的时间节点。在苏州攻关会颁发任务令,需要在节点前达成目标。苏州集结会把各地的人员都集中在苏州研究所,平常住宾馆,比如在甪直的酒店,与家人孩子天各一方。
3. 在苏州集结的时候周六默认上班,非常辛苦,不过周六有下午茶,有一次还有小龙虾。在苏州研究所的工位搬迁过一次,从一栋楼换到了另一栋。苏州研究所楼栋都是欧式装修,门口有大坡,里面景色很不错。去苏州集结一般至少要去一周,甚至更久,多的人甚至一两个月都回不了家。
4. 诺亚曾经传说是研究型的但是来了之后因为在四野做大模型项目项目成员完全变成了交付型的且充满了例会评审汇报。很多时候做实验都要申请。团队需要对接终端小艺华为云ICT等诸多业务线交付压力不小。
5. 诺亚研发的盘古模型早期内部代号叫做“盘古智子”一开始只有内部需要申请试用的网页版到后续迫于压力在welink上接入和公测开放。
这些天发生关于质疑盘古大模型抄袭千问的事情闹的沸沸扬扬。作为一个盘古团队的成员,我最近夜夜辗转反侧,难以入眠。盘古的品牌受到如此大的影响,一方面,我自私的为我的职业发展担忧,也为自己过去的努力工作感到不值。另一方面,由于有人开始揭露这些事情我内心又感到大快人心。在多少个日日夜夜,我们对内部某些人一次次靠着造假而又获得了无数利益的行为咬牙切齿而又无能为力。这种压抑和羞辱也逐渐消磨了我对华为的感情,让我在这里的时日逐渐浑浑噩噩,迷茫无措,时常怀疑自己的人生和自我价值。
我承认我是一个懦弱的人,作为一个小小的打工人,我不仅不敢和王云鹤等内部手眼通天的人做对,更不敢和华为这样的庞然大物做对。我很怕失去我的工作,毕竟我也有家人和孩子,所以我打心眼里很佩服揭露者。但是,看到内部还在试图洗地掩盖事实,蒙蔽公众的时候,我实在不能容忍了。我也希望勇敢一次,顺从自己本心。就算自损八百,我也希望能伤敌一千。我决定把我在这里的所见所闻(部分来自于同事口述)公布出来,关于盘古大模型的“传奇故事”:
华为确实主要在昇腾卡上训练大模型小模型实验室有不少英伟达的卡他们之前也会用来训练后面转移到昇腾。曾经我被华为“打造世界第二选择”的决心而折服我本身也曾经对华为有深厚的感情。我们陪着昇腾一步步摸爬滚打从充满bug到现在能训出模型付出了巨大的心血和代价。
最初我们的算力非常有限在910A上训练模型。那会只支持fp16训练的稳定性远不如bf16。盘古的moe开始很早23年就主要是训练38Bmoe模型和后续的71B dense模型。71B的dense模型通过扩增变成了第一代的135Bdense模型后面主力模型也逐渐在910B上训练。
71B和135B模型都有一个巨大的硬伤就是tokenizer。当时使用的tokenizer编码效率极低每个单个的符号数字空格乃至汉字都会占用一个token。可想而知这会非常浪费算力且使得模型的效果很差。这时候小模型实验室正好有个自己训的词表。姚老师当时怀疑是不是模型的tokenizer不好虽然事后来看他的怀疑是无疑正确的于是就决定让71B和135B换tokenizer因为小模型实验室曾经尝试过。团队缝合了两个tokenizer开始了tokenizer的更换。71B模型的更换失败了而135B因为采用了更精细的embedding初始化策略续训了至少1T的数据后词表总算更换成功但可想而知效果并不会变好。
于此同期阿里和智谱等国内其他公司在GPU上训练且已经摸索出了正确的方法盘古和竞品的差距越来越大。内部一个230B从头训练的dense模型又因为各种原因训练失败导致项目的状况几乎陷入绝境。面临几个节点的压力以及内部对盘古的强烈质疑时团队的士气低迷到了极点。团队在算力极其有限的时候做出了很多努力和挣扎。比如团队偶然发现当时的38B moe并没有预期moe的效果。于是去掉了moe参数还原为了13B的dense模型。由于38B的moe源自很早的pangu alpha 13B架构相对落后团队进行了一系列的操作比如切换绝对位置编码到rope去掉bias切换为rmsnorm。同时鉴于tokenizer的一些失败和换词表的经验这个模型的词表也更换为了王云鹤的小模型实验室7B模型所使用的词表。后面这个13B模型进行了扩增续训变成了第二代38B dense模型在几个月内这个模型都是主要的盘古中档位模型曾经具有一定的竞争力。但是由于更大的135B模型架构落后且更换词表模型损伤巨大后续分析发现当时更换的缝合词表有更严重的bug续训后也与千问等当时国内领先模型存在很大差距。这时由于内部的质疑声和领导的压力也越来越大。团队的状态几乎陷入了绝境。
在这种情况下王云鹤和他的小模型实验室出手了。他们声称是从旧的135B参数继承改造而来通过训练短短的几百B数据各项指标平均提升了十个点左右。实际上这就是他们套壳应用到大模型的第一次杰作。华为的外行领导内行使得领导完全对于这种扯淡的事情没有概念他们只会觉得肯定是有什么算法创新。经过内部的分析他们实际上是使用Qwen 1.5 110B续训而来通过加层扩增ffn维度添加盘古pi论文的一些机制得来凑够了大概135B的参数。实际上旧的135B有107层而这个模型只有82层各种配置也都不一样。新的来路不明的135B训练完很多参数的分布也和Qwen 110B几乎一模一样。连模型代码的类名当时都是Qwen甚至懒得改名。后续这个模型就是所谓的135B V2。而这个模型当时也提供给了很多下游甚至包括外部客户。
这件事对于我们这些认真诚实做事的同事们带来了巨大的冲击内部很多人其实都知道这件事甚至包括终端和华为云。我们都戏称以后别叫盘古模型了叫千古吧。当时团队成员就想向bcg举报了毕竟这已经是重大的业务造假了。但是后面据说被领导拦了下来因为更高级别的领导比如姚老师以及可能熊总和查老其实后面也知道了但是并不管因为通过套壳拿出好的结果对他们也是有利的。这件事使得当时团队几位最强的同事开始心灰意冷离职跑路也逐渐成为挂在嘴边的事。
此时盘古似乎迎来了转机。由于前面所述的这些盘古模型基本都是续训和改造而来当时诺亚完全没有掌握从头训练的技术何况还是在昇腾的NPU上进行训练。在当时团队的核心成员的极力争取下盘古开始了第三代模型的训练付出了巨大的努力后在数据架构和训练算法方面都与业界逐渐接轨而这其中的艰辛和小模型实验室的人一点关系都没有。
一开始团队成员毫无信心只从一个13B的模型开始训练但是后面发现效果还不错于是这个模型后续再次进行了一次参数扩增变成了第三代的38B代号38B V3。想必很多产品线的兄弟都对这个模型很熟悉。当时这个模型的tokenizer是基于llama的词表进行扩展的也是业界常见的做法。而当时王云鹤的实验室做出来了另一个词表也就是后续pangu系列的词表。当时两个词表还被迫进行了一次赛马最终没有明显的好坏结论。于是领导当即决定应该统一词表使用王云鹤他们的。于是在后续从头训练的135B V3也就是对外的Pangu Ultra便是采用了这个tokenizer。这也解释了很多使用我们模型的兄弟的疑惑为什么当时同为V3代的两个不同档位的模型会使用不同的tokenizer。
我们打心眼里觉得135B V3是我们四纵团队当时的骄傲。这是第一个真正意义上的华为全栈自研正经从头训练的千亿级别的模型且效果与24年同期竞品可比的。写到这里我已经热泪盈眶太不容易了。当时为了稳定训练团队做了大量实验对比并且多次在模型梯度出现异常的时候进行及时回退重启。这个模型真正做到了后面技术报告所说的训练全程没有一个loss spike。我们克服了不知道多少困难我们做到了我们愿用生命和荣誉保证这个模型训练的真实性。多少个凌晨我们为了它的训练而不眠。在被内部心声骂的一文不值的时候我们有多么不甘有多少的委屈我们挺住了。
我们这帮人是真的在为打磨国产算力底座燃烧自己的青春啊……客居他乡,我们放弃了家庭,放弃了假期,放弃了健康,放弃了娱乐,抛头颅洒热血,其中的艰辛与困苦,寥寥数笔不足以概括其万一。在各种动员大会上,当时口号中喊出的盘古必胜,华为必胜,我们心里是真的深深被感动。
然而我们的所有辛苦的成果经常被小模型实验室轻飘飘的拿走了。数据直接要走。代码直接要走还要求我们配合适配到能一键运行。我们当时戏称小模型实验室为点鼠标实验室。我们付出辛苦他们取得荣耀。果然应了那句话你在负重前行是因为有人替你岁月静好。在这种情况下越来越多的战友再也坚持不下去了选择了离开。看到身边那些优秀的同事一个个离职我的内心又感叹又难过。在这种作战一样的环境下我们比起同事来说更像是战友。他们在技术上也有无数值得我学习的地方堪称良师。看到他们去了诸如字节SeedDeepseek月之暗面腾讯和快手等等很多出色的团队我打心眼里为他们高兴和祝福脱离了这个辛苦却肮脏的地方。我至今还对一位离职同事的话记忆犹新ta说“来这里是我技术生涯中的耻辱在这里再呆每一天都是浪费生命”。话虽难听却让我无言以对。我担心我自己技术方面的积累不足以及没法适应互联网公司高淘汰的环境让我多次想离职的心始终没有迈出这一步。
盘古除了dense模型后续也启动了moe的探索。一开始训练的是一个224B的moe模型。而与之平行的小模型实验室也开启了第二次主要的套壳行动次要的插曲可能还包括一些别的模型比如math模型即这次流传甚广的pangu pro moe 72B。这个模型内部自称是从小模型实验室的7B扩增上来的就算如此这也与技术报告不符何况是套壳qwen 2.5的14b续训。还记得他们训了没几天内部的评测就立刻追上了当时的38B V3。AI系统实验室很多兄弟因为需要适配模型都知道他们的套壳行动只是迫于各种原因无法伸张正义。实际上对于后续训了很久很久的这个模型Honestagi能够分析出这个量级的相似性我已经很诧异了因为这个模型为了续训洗参数所付出的算力甚至早就足够从头训一个同档位的模型了。听同事说他们为了洗掉千问的水印采取了不少办法甚至包括故意训了脏数据。这也为学术界研究模型血缘提供了一个前所未有的特殊模范吧。以后新的血缘方法提出可以拿出来溜溜。
24年底和25年初在Deepseek v3和r1发布之后由于其惊艳的技术水平团队受到了巨大的冲击也受到了更大的质疑。于是为了紧跟潮流盘古模仿Deepseek的模型尺寸开启了718B moe的训练。这个时候小模型实验室再次出手了。他们选择了套壳Deepseekv3续训。他们通过冻住Deepseek加载的参数进行训练。连任务加载ckpt的目录都是deepseekv3改都不改何其嚣张与之相反一些有真正技术信仰的同事在从头训练另一个718B的moe。但其中出现了各种各样的问题。但是很显然这个模型怎么可能比直接套壳的好呢如果不是团队leader坚持早就被叫停了。
华为的流程管理之繁重,严重拖累了大模型的研发节奏,例如版本管理,模型血缘,各种流程化,各种可追溯。讽刺的是,小模型实验室的模型似乎从来不受这些流程的约束,想套壳就套壳,想续训就续训,算力源源不断的伸手拿走。这种强烈到近乎魔幻的对比,说明了当前流程管理的情况:只许州官放火,不许百姓点灯。何其可笑?何其可悲?何其可恶?何其可耻!
HonestAGI的事情出来后内部让大家不停的研讨分析如何公关和“回应”。诚然这个原文的分析也许不够有力给了王云鹤与小模型实验室他们狡辩和颠倒黑白的机会。为此这两天我内心感到作呕时时怀疑自己的人生意义以及苍天无眼。我不奉陪了我要离职了同时我也在申请从盘古部分技术报告的作者名单中移除。曾经在这些技术报告上署名是我一生都无法抹除的污点。当时我没想到他们竟然猖狂到敢开源。我没想到他们敢如此愚弄世人大肆宣发。当时我也许是存了侥幸心理没有拒绝署名。我相信很多扎实做事的战友也只是被迫上了贼船或者不知情。但这件事已经无法挽回我希望我的余生能够坚持扎实做真正有意义的事为我当时的软弱和不坚定赎罪。
深夜写到这里,我已经泪流满面,泣不成声。还记得一些出色的同事离职时,我苦笑问他们要不要发个长长的心声惯例帖,揭露一下现状。对方说:不了,浪费时间,而且我也怕揭露出来你们过的更糟。我当时一下黯然神伤,因为曾经共同为了理想奋斗过的战友已经彻底对华为彻底灰心了。当时大家调侃,我们用着当年共产党的小米加步枪,组织却有着堪比当年国民党的作风。
曾几何时,我为我们用着小米加步枪打败洋枪洋炮而自豪。
现在,我累了,我想投降。
其实时至今日我还是真心希望华为能认真吸取教训能做好盘古把盘古做到世界一流把昇腾变成英伟达的水平。内部的劣币驱逐良币使得诺亚乃至华为在短时间内急剧流失了大量出色的大模型人才。相信他们也正在如Deepseek等各个团队闪耀着施展着他们的抱负才华为中美在AI的激烈竞赛中奉献力量。我时常感叹华为不是没有人才而是根本不知道怎么留住人才。如果给这些人合适的环境合适的资源更少的枷锁更少的政治斗争盘古何愁不成
最后:我以生命,人格和荣誉发誓,我写的以上所有内容均为真实(至少在我有限的认知范围内)。我没有那么高的技术水平以及机会去做详尽扎实的分析,也不敢直接用内部记录举证,怕因为信息安全抓到。但是我相信我很多曾经的战友,会为我作证。在华为内部的兄弟,包括我们曾经服务过的产品线兄弟们,相信本文的无数细节能和你们的印象对照,印证我的说法。你们可能也曾经被蒙骗,但这些残酷的真相不会被尘封。我们奋战过的痕迹,也不应该被扭曲和埋葬。
写了这么多,某些人肯定想把我找出来,抹杀掉。公司搞不好也想让我噤声乃至追责。如果真的这样,我,乃至我的家人的人身乃至生命安全可能都会受到威胁。为了自我保护,我近期每天会跟大家报平安。
如果我消失了就当是我为了真理和理想为了华为乃至中国能够更好地发展算力和AI而牺牲了吧我愿埋葬于那片曾经奋斗过的地方。
诺亚,再见
2025年7月6日凌晨 写于深圳
---
各位好,
感谢大家的关心与祝福。我目前暂时安全,但公司应该在进行排查与某些名单收集,后续情况未知。
我补充一些细节,以免某些人继续颠倒黑白。
关于135B V2小模型实验室在迅速地完成套壳并拿完所有套壳带来的好处后比如任务令表彰和及时激励因为不想继续支撑下游应用和模型迭代又把这个烫手山芋甩给了四纵。确实技高一筹直接把四纵的兄弟们拉下水。同事提供过去一个老旧的模型最终拿回了一个当时一个魔改的先进的千问。做大模型的人自己做的模型就像自己孩子一样熟悉不要把别人都当傻子。就像自家儿子出门一趟回来个别人家孩子。
盘古report的署名是不符合学术规范的。例如135B V3有不少有技术贡献的人因为作者名额数量限制劳动成果没有得到应有的回报团队内曾经有不小的意见。这个模型当时是大家智慧和汗水的结晶甚至是团队当时的精神支柱支撑着不少兄弟们继续留在诺亚。所谓的名额限制以及挂名了一些毫无技术贡献的人如一些小模型实验室的人让兄弟们何其心寒。
---
暂时平安。另外,支持我勇于说出真相的战友们 https://github.com/HW-whistleblower/True-Story-of-Pangu/issues/317

0
apps/email/__init__.py Normal file
View File

193
apps/email/__main__.py Normal file
View File

@@ -0,0 +1,193 @@
import os
import sys
import asyncio
import dotenv
import argparse
from pathlib import Path
from typing import List, Any
from leann.api import LeannBuilder, LeannSearcher, LeannChat
from llama_index.core.node_parser import SentenceSplitter
dotenv.load_dotenv()
# Auto-detect user's mail path
def get_mail_path():
"""Get the mail path for the current user"""
home_dir = os.path.expanduser("~")
return os.path.join(home_dir, "Library", "Mail")
def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_path: str = "mail_index.leann", max_count: int = -1, include_html: bool = False, embedding_model: str = "facebook/contriever"):
"""
Create LEANN index from multiple mail data sources.
Args:
messages_dirs: List of Path objects pointing to Messages directories
index_path: Path to save the LEANN index
max_count: Maximum number of emails to process per directory
include_html: Whether to include HTML content in email processing
"""
print("Creating LEANN index from multiple mail data sources...")
# Load documents using EmlxReader from local readers module
from .readers import EmlxReader, find_all_messages_directories
reader = EmlxReader(include_html=include_html)
INDEX_DIR = Path(index_path).parent
if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---")
all_documents = []
total_processed = 0
# Process each Messages directory
for i, messages_dir in enumerate(messages_dirs):
print(f"\nProcessing Messages directory {i+1}/{len(messages_dirs)}: {messages_dir}")
try:
documents = reader.load_data(messages_dir)
if documents:
print(f"Loaded {len(documents)} email documents from {messages_dir}")
all_documents.extend(documents)
total_processed += len(documents)
# Check if we've reached the max count
if max_count > 0 and total_processed >= max_count:
print(f"Reached max count of {max_count} documents")
break
else:
print(f"No documents loaded from {messages_dir}")
except Exception as e:
print(f"Error processing {messages_dir}: {e}")
continue
if not all_documents:
print("No documents loaded from any source. Exiting.")
return None
print(f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories")
# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
# Convert Documents to text strings and chunk them
all_texts = []
for doc in all_documents:
# Split the document into chunks
nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
# Create LEANN index directory
INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=embedding_model,
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1 # Force single-threaded mode
)
print(f"Adding {len(all_texts)} email chunks to index...")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(index_path)
print(f"\nLEANN index built at {index_path}!")
else:
print(f"--- Using existing index at {INDEX_DIR} ---")
return index_path
async def query_leann_index(index_path: str, query: str):
"""
Query the LEANN index.
Args:
index_path: Path to the LEANN index
query: The query string
"""
print(f"\n[PHASE 2] Starting Leann chat session...")
chat = LeannChat(index_path=index_path,
llm_config={"type": "openai", "model": "gpt-4o"})
print(f"You: {query}")
import time
start_time = time.time()
chat_response = chat.ask(
query,
top_k=10,
recompute_beighbor_embeddings=True,
complexity=12,
beam_width=1,
)
end_time = time.time()
print(f"Time taken: {end_time - start_time} seconds")
print(f"Leann: {chat_response}")
async def main():
# Parse command line arguments
parser = argparse.ArgumentParser(description='LEANN Mail Reader - Create and query email index')
parser.add_argument('--index-dir', type=str, default="./mail_index_leann_raw_text_all_dicts",
help='Directory to store the LEANN index (default: ./mail_index_leann_raw_text_all_dicts)')
parser.add_argument('--max-emails', type=int, default=1000,
help='Maximum number of emails to process (-1 means all)')
parser.add_argument('--query', type=str, default="Give me some funny advertisement about apple or other companies",
help='Single query to run (default: runs example queries)')
parser.add_argument('--include-html', action='store_true', default=False,
help='Include HTML content in email processing (default: False)')
parser.add_argument('--embedding-model', type=str, default="facebook/contriever",
help='Embedding model to use (default: facebook/contriever)')
args = parser.parse_args()
print(f"args: {args}")
# Automatically find all Messages directories under the current user's Mail directory
from .readers import find_all_messages_directories
mail_path = get_mail_path()
print(f"Searching for email data in: {mail_path}")
messages_dirs = find_all_messages_directories(mail_path)
print('len(messages_dirs): ', len(messages_dirs))
if not messages_dirs:
print("No Messages directories found. Exiting.")
return
INDEX_DIR = Path(args.index_dir)
INDEX_PATH = str(INDEX_DIR / "mail_documents.leann")
print(f"Index directory: {INDEX_DIR}")
print(f"Found {len(messages_dirs)} Messages directories.")
# Create or load the LEANN index from all sources
index_path = create_leann_index_from_multiple_sources(messages_dirs, INDEX_PATH, args.max_emails, args.include_html, args.embedding_model)
if index_path:
if args.query:
# Run single query
await query_leann_index(index_path, args.query)
else:
# Example queries
queries = [
"Hows Berkeley Graduate Student Instructor",
"how's the icloud related advertisement saying",
"Whats the number of class recommend to take per semester for incoming EECS students"
]
for query in queries:
print("\n" + "="*60)
await query_leann_index(index_path, query)
if __name__ == "__main__":
asyncio.run(main())
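For reference, a minimal sketch of driving the email reader directly; the `apps.email.readers` import path is inferred from the file layout shown below, and the reader and helper are the ones defined in apps/email/readers.py:

from apps.email.readers import EmlxReader, find_all_messages_directories

messages_dirs = find_all_messages_directories()  # scans ~/Library/Mail for Messages/ directories
if messages_dirs:
    reader = EmlxReader(include_html=False)
    # load_data walks the directory for .emlx files; max_count caps how many are parsed
    docs = reader.load_data(str(messages_dirs[0]), max_count=50)
    print(f"Loaded {len(docs)} email documents from {messages_dirs[0]}")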

View File

@@ -7,9 +7,9 @@ Contains simple parser for mbox files.
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any, Dict, List, Optional
from fsspec import AbstractFileSystem from fsspec import AbstractFileSystem
from llama_index.core.readers.base import BaseReader from llama_index.core.readers.base import BaseReader
from llama_index.core.schema import Document from llama_index.core.schema import Document
@@ -27,7 +27,11 @@ class MboxReader(BaseReader):
""" """
DEFAULT_MESSAGE_FORMAT: str = ( DEFAULT_MESSAGE_FORMAT: str = (
"Date: {_date}\nFrom: {_from}\nTo: {_to}\nSubject: {_subject}\nContent: {_content}" "Date: {_date}\n"
"From: {_from}\n"
"To: {_to}\n"
"Subject: {_subject}\n"
"Content: {_content}"
) )
def __init__( def __init__(
@@ -41,7 +45,9 @@ class MboxReader(BaseReader):
try: try:
from bs4 import BeautifulSoup # noqa from bs4 import BeautifulSoup # noqa
except ImportError: except ImportError:
raise ImportError("`beautifulsoup4` package not found: `pip install beautifulsoup4`") raise ImportError(
"`beautifulsoup4` package not found: `pip install beautifulsoup4`"
)
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.max_count = max_count self.max_count = max_count
@@ -50,9 +56,9 @@ class MboxReader(BaseReader):
def load_data( def load_data(
self, self,
file: Path, file: Path,
extra_info: dict | None = None, extra_info: Optional[Dict] = None,
fs: AbstractFileSystem | None = None,
) -> list[Document]:
"""Parse file into string."""
# Import required libraries
import mailbox
@@ -68,7 +74,7 @@ class MboxReader(BaseReader):
)
i = 0
results: list[str] = []
# Load file using mailbox
bytes_parser = BytesParser(policy=default).parse
mbox = mailbox.mbox(file, factory=bytes_parser)  # type: ignore
@@ -118,7 +124,7 @@ class MboxReader(BaseReader):
class EmlxMboxReader(MboxReader):
"""
EmlxMboxReader - Modified MboxReader that handles directories of .emlx files.
Extends MboxReader to work with Apple Mail's .emlx format by:
1. Reading .emlx files from a directory
2. Converting them to mbox format in memory
@@ -128,13 +134,13 @@ class EmlxMboxReader(MboxReader):
def load_data(
self,
directory: Path,
extra_info: dict | None = None,
fs: AbstractFileSystem | None = None,
) -> list[Document]:
"""Parse .emlx files from directory into strings using MboxReader logic."""
import os
import tempfile
if fs:
logger.warning(
"fs was specified but EmlxMboxReader doesn't support loading "
@@ -144,37 +150,37 @@ class EmlxMboxReader(MboxReader):
# Find all .emlx files in the directory
emlx_files = list(directory.glob("*.emlx"))
logger.info(f"Found {len(emlx_files)} .emlx files in {directory}")
if not emlx_files:
logger.warning(f"No .emlx files found in {directory}")
return []
# Create a temporary mbox file
with tempfile.NamedTemporaryFile(mode="w", suffix=".mbox", delete=False) as temp_mbox:
temp_mbox_path = temp_mbox.name
# Convert .emlx files to mbox format
for emlx_file in emlx_files:
try:
# Read the .emlx file
with open(emlx_file, encoding="utf-8", errors="ignore") as f:
content = f.read()
# .emlx format: first line is length, rest is email content
lines = content.split("\n", 1)
if len(lines) >= 2:
email_content = lines[1]  # Skip the length line
# Write to mbox format (each message starts with "From " and ends with blank line)
temp_mbox.write(f"From {emlx_file.name} {email_content}\n\n")
except Exception as e:
logger.warning(f"Failed to process {emlx_file}: {e}")
continue
# Close the temporary file so MboxReader can read it
temp_mbox.close()
try:
# Use the parent MboxReader's logic to parse the mbox file
return super().load_data(Path(temp_mbox_path), extra_info, fs)
@@ -182,5 +188,5 @@ class EmlxMboxReader(MboxReader):
# Clean up temporary file
try:
os.unlink(temp_mbox_path)
except OSError:
pass
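
For context, a minimal sketch of how the EmlxMboxReader above would be invoked; the mail directory path is a placeholder, not a path taken from this repository:

# Hedged usage sketch; point messages_dir at a directory that contains .emlx files.
from pathlib import Path

reader = EmlxMboxReader()
messages_dir = Path.home() / "Library" / "Mail"  # placeholder location
docs = reader.load_data(messages_dir)
print(f"Parsed {len(docs)} mail documents")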

apps/email/readers.py Normal file

@@ -0,0 +1,124 @@
import os
import email
from pathlib import Path
from typing import Any, List, Optional
from llama_index.core import Document
from llama_index.core.readers.base import BaseReader
def find_all_messages_directories(root: Optional[str] = None) -> List[Path]:
"""
Recursively find all 'Messages' directories under the given root.
Returns a list of Path objects.
"""
if root is None:
# Auto-detect user's mail path
home_dir = os.path.expanduser("~")
root = os.path.join(home_dir, "Library", "Mail")
messages_dirs = []
for dirpath, dirnames, filenames in os.walk(root):
if os.path.basename(dirpath) == "Messages":
messages_dirs.append(Path(dirpath))
return messages_dirs
class EmlxReader(BaseReader):
"""
Apple Mail .emlx file reader with embedded metadata.
Reads individual .emlx files from Apple Mail's storage format.
"""
def __init__(self, include_html: bool = False) -> None:
"""
Initialize.
Args:
include_html: Whether to include HTML content in the email body (default: False)
"""
self.include_html = include_html
def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]:
"""
Load data from the input directory containing .emlx files.
Args:
input_dir: Directory containing .emlx files
**load_kwargs:
max_count (int): Maximum amount of messages to read.
"""
docs: List[Document] = []
max_count = load_kwargs.get('max_count', 1000)
count = 0
# Walk through the directory recursively
for dirpath, dirnames, filenames in os.walk(input_dir):
# Skip hidden directories
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
for filename in filenames:
if count >= max_count:
break
if filename.endswith(".emlx"):
filepath = os.path.join(dirpath, filename)
try:
# Read the .emlx file
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read()
# .emlx files have a length prefix followed by the email content
# The first line contains the length, followed by the email
lines = content.split('\n', 1)
if len(lines) >= 2:
email_content = lines[1]
# Parse the email using Python's email module
try:
msg = email.message_from_string(email_content)
# Extract email metadata
subject = msg.get('Subject', 'No Subject')
from_addr = msg.get('From', 'Unknown')
to_addr = msg.get('To', 'Unknown')
date = msg.get('Date', 'Unknown')
# Extract email body
body = ""
if msg.is_multipart():
for part in msg.walk():
if part.get_content_type() == "text/plain" or part.get_content_type() == "text/html":
if part.get_content_type() == "text/html" and not self.include_html:
continue
payload = part.get_payload(decode=True)
if payload:
    body += payload.decode('utf-8', errors='ignore')
# break
else:
payload = msg.get_payload(decode=True)
body = payload.decode('utf-8', errors='ignore') if payload else ""
# Create document content with metadata embedded in text
doc_content = f"""
[EMAIL METADATA]
File: {filename}
From: {from_addr}
To: {to_addr}
Subject: {subject}
Date: {date}
[END METADATA]
{body}
"""
# No separate metadata - everything is in the text
doc = Document(text=doc_content, metadata={})
docs.append(doc)
count += 1
except Exception as e:
print(f"Error parsing email from {filepath}: {e}")
continue
except Exception as e:
print(f"Error reading file {filepath}: {e}")
continue
print(f"Loaded {len(docs)} email documents")
return docs
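
As a quick orientation for the reader class above, a hedged usage sketch; the 100-message cap is an arbitrary illustration:

# Illustrative only: walk every detected Apple Mail 'Messages' directory and load a capped number of emails.
reader = EmlxReader(include_html=False)
for messages_dir in find_all_messages_directories():
    docs = reader.load_data(input_dir=str(messages_dir), max_count=100)
    print(f"{messages_dir}: {len(docs)} documents")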


@@ -1,167 +0,0 @@
import email
import os
from pathlib import Path
from typing import Any
from llama_index.core import Document
from llama_index.core.readers.base import BaseReader
def find_all_messages_directories(root: str | None = None) -> list[Path]:
"""
Recursively find all 'Messages' directories under the given root.
Returns a list of Path objects.
"""
if root is None:
# Auto-detect user's mail path
home_dir = os.path.expanduser("~")
root = os.path.join(home_dir, "Library", "Mail")
messages_dirs = []
for dirpath, _dirnames, _filenames in os.walk(root):
if os.path.basename(dirpath) == "Messages":
messages_dirs.append(Path(dirpath))
return messages_dirs
class EmlxReader(BaseReader):
"""
Apple Mail .emlx file reader with embedded metadata.
Reads individual .emlx files from Apple Mail's storage format.
"""
def __init__(self, include_html: bool = False) -> None:
"""
Initialize.
Args:
include_html: Whether to include HTML content in the email body (default: False)
"""
self.include_html = include_html
def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
"""
Load data from the input directory containing .emlx files.
Args:
input_dir: Directory containing .emlx files
**load_kwargs:
max_count (int): Maximum amount of messages to read.
"""
docs: list[Document] = []
max_count = load_kwargs.get("max_count", 1000)
count = 0
total_files = 0
successful_files = 0
failed_files = 0
print(f"Starting to process directory: {input_dir}")
# Walk through the directory recursively
for dirpath, dirnames, filenames in os.walk(input_dir):
# Skip hidden directories
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
for filename in filenames:
# Check if we've reached the max count (skip if max_count == -1)
if max_count > 0 and count >= max_count:
break
if filename.endswith(".emlx"):
total_files += 1
filepath = os.path.join(dirpath, filename)
try:
# Read the .emlx file
with open(filepath, encoding="utf-8", errors="ignore") as f:
content = f.read()
# .emlx files have a length prefix followed by the email content
# The first line contains the length, followed by the email
lines = content.split("\n", 1)
if len(lines) >= 2:
email_content = lines[1]
# Parse the email using Python's email module
try:
msg = email.message_from_string(email_content)
# Extract email metadata
subject = msg.get("Subject", "No Subject")
from_addr = msg.get("From", "Unknown")
to_addr = msg.get("To", "Unknown")
date = msg.get("Date", "Unknown")
# Extract email body
body = ""
if msg.is_multipart():
for part in msg.walk():
if (
part.get_content_type() == "text/plain"
or part.get_content_type() == "text/html"
):
if (
part.get_content_type() == "text/html"
and not self.include_html
):
continue
try:
payload = part.get_payload(decode=True)
if payload:
body += payload.decode("utf-8", errors="ignore")
except Exception as e:
print(f"Error decoding payload: {e}")
continue
else:
try:
payload = msg.get_payload(decode=True)
if payload:
body = payload.decode("utf-8", errors="ignore")
except Exception as e:
print(f"Error decoding single part payload: {e}")
body = ""
# Only create document if we have some content
if body.strip() or subject != "No Subject":
# Create document content with metadata embedded in text
doc_content = f"""
[File]: {filename}
[From]: {from_addr}
[To]: {to_addr}
[Subject]: {subject}
[Date]: {date}
[EMAIL BODY Start]:
{body}
"""
# No separate metadata - everything is in the text
doc = Document(text=doc_content, metadata={})
docs.append(doc)
count += 1
successful_files += 1
# Print first few successful files for debugging
if successful_files <= 3:
print(
f"Successfully loaded: {filename} - Subject: {subject[:50]}..."
)
except Exception as e:
failed_files += 1
if failed_files <= 5: # Only print first few errors
print(f"Error parsing email from {filepath}: {e}")
continue
except Exception as e:
failed_files += 1
if failed_files <= 5: # Only print first few errors
print(f"Error reading file {filepath}: {e}")
continue
print("Processing summary:")
print(f" Total .emlx files found: {total_files}")
print(f" Successfully loaded: {successful_files}")
print(f" Failed to load: {failed_files}")
print(f" Final documents: {len(docs)}")
return docs


@@ -1,156 +0,0 @@
"""
Email RAG example using the unified interface.
Supports Apple Mail on macOS.
"""
import sys
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from base_rag_example import BaseRAGExample, create_text_chunks
from .email_data.LEANN_email_reader import EmlxReader
class EmailRAG(BaseRAGExample):
"""RAG example for Apple Mail processing."""
def __init__(self):
# Set default values BEFORE calling super().__init__
self.max_items_default = -1 # Process all emails by default
self.embedding_model_default = (
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
)
super().__init__(
name="Email",
description="Process and query Apple Mail emails with LEANN",
default_index_name="mail_index",
)
def _add_specific_arguments(self, parser):
"""Add email-specific arguments."""
email_group = parser.add_argument_group("Email Parameters")
email_group.add_argument(
"--mail-path",
type=str,
default=None,
help="Path to Apple Mail directory (auto-detected if not specified)",
)
email_group.add_argument(
"--include-html", action="store_true", help="Include HTML content in email processing"
)
email_group.add_argument(
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
)
email_group.add_argument(
"--chunk-overlap", type=int, default=25, help="Text chunk overlap (default: 25)"
)
def _find_mail_directories(self) -> list[Path]:
"""Auto-detect all Apple Mail directories."""
mail_base = Path.home() / "Library" / "Mail"
if not mail_base.exists():
return []
# Find all Messages directories
messages_dirs = []
for item in mail_base.rglob("Messages"):
if item.is_dir():
messages_dirs.append(item)
return messages_dirs
async def load_data(self, args) -> list[str]:
"""Load emails and convert to text chunks."""
# Determine mail directories
if args.mail_path:
messages_dirs = [Path(args.mail_path)]
else:
print("Auto-detecting Apple Mail directories...")
messages_dirs = self._find_mail_directories()
if not messages_dirs:
print("No Apple Mail directories found!")
print("Please specify --mail-path manually")
return []
print(f"Found {len(messages_dirs)} mail directories")
# Create reader
reader = EmlxReader(include_html=args.include_html)
# Process each directory
all_documents = []
total_processed = 0
for i, messages_dir in enumerate(messages_dirs):
print(f"\nProcessing directory {i + 1}/{len(messages_dirs)}: {messages_dir}")
try:
# Count emlx files
emlx_files = list(messages_dir.glob("*.emlx"))
print(f"Found {len(emlx_files)} email files")
# Apply max_items limit per directory
max_per_dir = -1 # Default to process all
if args.max_items > 0:
remaining = args.max_items - total_processed
if remaining <= 0:
break
max_per_dir = remaining
# If args.max_items == -1, max_per_dir stays -1 (process all)
# Load emails - fix the parameter passing
documents = reader.load_data(
input_dir=str(messages_dir),
max_count=max_per_dir,
)
if documents:
all_documents.extend(documents)
total_processed += len(documents)
print(f"Processed {len(documents)} emails from this directory")
except Exception as e:
print(f"Error processing {messages_dir}: {e}")
continue
if not all_documents:
print("No emails found to process!")
return []
print(f"\nTotal emails processed: {len(all_documents)}")
print("now starting to split into text chunks ... take some time")
# Convert to text chunks
# Email reader uses chunk_overlap=25 as in original
all_texts = create_text_chunks(
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
)
return all_texts
if __name__ == "__main__":
import asyncio
# Check platform
if sys.platform != "darwin":
print("\n⚠️ Warning: This example is designed for macOS (Apple Mail)")
print(" Windows/Linux support coming soon!\n")
# Example queries for email RAG
print("\n📧 Email RAG Example")
print("=" * 50)
print("\nExample queries you can try:")
print("- 'What did my boss say about deadlines?'")
print("- 'Find emails about travel expenses'")
print("- 'Show me emails from last month about the project'")
print("- 'What food did I order from DoorDash?'")
print("\nNote: You may need to grant Full Disk Access to your terminal\n")
rag = EmailRAG()
asyncio.run(rag.run())


apps/evaluation/__main__.py Normal file

@@ -0,0 +1,382 @@
#!/usr/bin/env python3
"""
This script runs a recall evaluation on a given LEANN index.
It correctly compares results by fetching the text content for both the new search
results and the golden standard results, making the comparison robust to ID changes.
"""
import json
import argparse
import time
from pathlib import Path
import sys
import numpy as np
from typing import List, Optional
from leann.api import LeannSearcher, LeannBuilder
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
"""Checks if the data directory exists, and if not, downloads it from HF Hub."""
if not data_root.exists():
print(f"Data directory '{data_root}' not found.")
print(
"Downloading evaluation data from Hugging Face Hub... (this may take a moment)"
)
try:
from huggingface_hub import snapshot_download
if download_embeddings:
# Download everything including embeddings (large files)
snapshot_download(
repo_id="LEANN-RAG/leann-rag-evaluation-data",
repo_type="dataset",
local_dir=data_root,
local_dir_use_symlinks=False,
)
print("Data download complete (including embeddings)!")
else:
# Download only specific folders, excluding embeddings
allow_patterns = [
"ground_truth/**",
"indices/**",
"queries/**",
"*.md",
"*.txt",
]
snapshot_download(
repo_id="LEANN-RAG/leann-rag-evaluation-data",
repo_type="dataset",
local_dir=data_root,
local_dir_use_symlinks=False,
allow_patterns=allow_patterns,
)
print("Data download complete (excluding embeddings)!")
except ImportError:
print(
"Error: huggingface_hub is not installed. Please install it to download the data:"
)
print("uv pip install -e '.[dev]'")
sys.exit(1)
except Exception as e:
print(f"An error occurred during data download: {e}")
sys.exit(1)
def download_embeddings_if_needed(data_root: Path, dataset_type: Optional[str] = None):
"""Download embeddings files specifically."""
embeddings_dir = data_root / "embeddings"
if dataset_type:
# Check if specific dataset embeddings exist
target_file = embeddings_dir / dataset_type / "passages_00.pkl"
if target_file.exists():
print(f"Embeddings for {dataset_type} already exist")
return str(target_file)
print("Downloading embeddings from HuggingFace Hub...")
try:
from huggingface_hub import snapshot_download
# Download only embeddings folder
snapshot_download(
repo_id="LEANN-RAG/leann-rag-evaluation-data",
repo_type="dataset",
local_dir=data_root,
local_dir_use_symlinks=False,
allow_patterns=["embeddings/**/*.pkl"],
)
print("Embeddings download complete!")
if dataset_type:
target_file = embeddings_dir / dataset_type / "passages_00.pkl"
if target_file.exists():
return str(target_file)
return str(embeddings_dir)
except Exception as e:
print(f"Error downloading embeddings: {e}")
sys.exit(1)
# --- Helper Function to get Golden Passages ---
def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set:
"""
Retrieves the text for golden passage IDs directly from the LeannSearcher's
passage manager.
"""
golden_texts = set()
for gid in golden_ids:
try:
# PassageManager uses string IDs
passage_data = searcher.passage_manager.get_passage(str(gid))
golden_texts.add(passage_data["text"])
except KeyError:
print(
f"Warning: Golden passage ID '{gid}' not found in the index's passage data."
)
return golden_texts
def load_queries(file_path: Path) -> List[str]:
queries = []
with open(file_path, "r", encoding="utf-8") as f:
for line in f:
data = json.loads(line)
queries.append(data["query"])
return queries
def build_index_from_embeddings(
embeddings_file: str, output_path: str, backend: str = "hnsw"
):
"""
Build a LEANN index from pre-computed embeddings.
Args:
embeddings_file: Path to pickle file with (ids, embeddings) tuple
output_path: Path where to save the index
backend: Backend to use ("hnsw" or "diskann")
"""
print(f"Building {backend} index from embeddings: {embeddings_file}")
# Create builder with appropriate parameters
if backend == "hnsw":
builder_kwargs = {
"M": 32, # Graph degree
"efConstruction": 256, # Construction complexity
"is_compact": True, # Use compact storage
"is_recompute": True, # Enable pruning for better recall
}
elif backend == "diskann":
builder_kwargs = {
"complexity": 64,
"graph_degree": 32,
"search_memory_maximum": 8.0, # GB
"build_memory_maximum": 16.0, # GB
}
else:
builder_kwargs = {}
builder = LeannBuilder(
backend_name=backend,
embedding_model="facebook/contriever-msmarco", # Model used to create embeddings
dimensions=768, # Will be auto-detected from embeddings
**builder_kwargs,
)
# Build index from precomputed embeddings
builder.build_index_from_embeddings(output_path, embeddings_file)
print(f"Index saved to: {output_path}")
return output_path
def main():
parser = argparse.ArgumentParser(
description="Run recall evaluation on a LEANN index."
)
parser.add_argument(
"index_path",
type=str,
nargs="?",
help="Path to the LEANN index to evaluate or build (optional).",
)
parser.add_argument(
"--mode",
choices=["evaluate", "build"],
default="evaluate",
help="Mode: 'evaluate' existing index or 'build' from embeddings",
)
parser.add_argument(
"--embeddings-file",
type=str,
help="Path to embeddings pickle file (optional for build mode)",
)
parser.add_argument(
"--backend",
choices=["hnsw", "diskann"],
default="hnsw",
help="Backend to use for building index (default: hnsw)",
)
parser.add_argument(
"--num-queries", type=int, default=10, help="Number of queries to evaluate."
)
parser.add_argument(
"--top-k", type=int, default=3, help="The 'k' value for recall@k."
)
parser.add_argument(
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
)
args = parser.parse_args()
# --- Path Configuration ---
# Resolves data_root two levels above this script; a 'data/' directory
# is expected to exist there.
project_root = Path(__file__).resolve().parent.parent
data_root = project_root / "data"
# Download data based on mode
if args.mode == "build":
# For building mode, we need embeddings
download_data_if_needed(
data_root, download_embeddings=False
) # Basic data first
# Auto-detect dataset type and download embeddings
if args.embeddings_file:
embeddings_file = args.embeddings_file
# Try to detect dataset type from embeddings file path
if "rpj_wiki" in str(embeddings_file):
dataset_type = "rpj_wiki"
elif "dpr" in str(embeddings_file):
dataset_type = "dpr"
else:
dataset_type = "dpr" # Default
else:
# Auto-detect from index path if provided, otherwise default to DPR
if args.index_path:
index_path_str = str(args.index_path)
if "rpj_wiki" in index_path_str:
dataset_type = "rpj_wiki"
elif "dpr" in index_path_str:
dataset_type = "dpr"
else:
dataset_type = "dpr" # Default to DPR
else:
dataset_type = "dpr" # Default to DPR
embeddings_file = download_embeddings_if_needed(data_root, dataset_type)
# Auto-generate index path if not provided
if not args.index_path:
indices_dir = data_root / "indices" / dataset_type
indices_dir.mkdir(parents=True, exist_ok=True)
args.index_path = str(indices_dir / f"{dataset_type}_from_embeddings")
print(f"Auto-generated index path: {args.index_path}")
print(f"Building index from embeddings: {embeddings_file}")
built_index_path = build_index_from_embeddings(
embeddings_file, args.index_path, args.backend
)
print(f"Index built successfully: {built_index_path}")
# Ask if user wants to run evaluation
eval_response = (
input("Run evaluation on the built index? (y/n): ").strip().lower()
)
if eval_response != "y":
print("Index building complete. Exiting.")
return
else:
# For evaluation mode, don't need embeddings
download_data_if_needed(data_root, download_embeddings=False)
# Auto-detect index path if not provided
if not args.index_path:
# Default to using downloaded indices
indices_dir = data_root / "indices"
# Try common datasets in order of preference
for dataset in ["dpr", "rpj_wiki"]:
dataset_dir = indices_dir / dataset
if dataset_dir.exists():
# Look for index files
index_files = list(dataset_dir.glob("*.index")) + list(
dataset_dir.glob("*_disk.index")
)
if index_files:
args.index_path = str(
index_files[0].with_suffix("")
) # Remove .index extension
print(f"Using index: {args.index_path}")
break
if not args.index_path:
print(
"No indices found. The data download should have included pre-built indices."
)
print(
"Please check the data/indices/ directory or provide --index-path manually."
)
sys.exit(1)
# Detect dataset type from index path to select the correct ground truth
index_path_str = str(args.index_path)
if "rpj_wiki" in index_path_str:
dataset_type = "rpj_wiki"
elif "dpr" in index_path_str:
dataset_type = "dpr"
else:
# Fallback: try to infer from the index directory name
dataset_type = Path(args.index_path).name
print(
f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'."
)
queries_file = data_root / "queries" / "nq_open.jsonl"
golden_results_file = (
data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
)
print(f"INFO: Detected dataset type: {dataset_type}")
print(f"INFO: Using queries file: {queries_file}")
print(f"INFO: Using ground truth file: {golden_results_file}")
try:
searcher = LeannSearcher(args.index_path)
queries = load_queries(queries_file)
with open(golden_results_file, "r") as f:
golden_results_data = json.load(f)
num_eval_queries = min(args.num_queries, len(queries))
queries = queries[:num_eval_queries]
print(f"\nRunning evaluation on {num_eval_queries} queries...")
recall_scores = []
search_times = []
for i in range(num_eval_queries):
start_time = time.time()
new_results = searcher.search(
queries[i], top_k=args.top_k, ef=args.ef_search
)
search_times.append(time.time() - start_time)
# Correct Recall Calculation: Based on TEXT content
new_texts = {result.text for result in new_results}
# Get golden texts directly from the searcher's passage manager
golden_ids = golden_results_data["indices"][i][: args.top_k]
golden_texts = get_golden_texts(searcher, golden_ids)
overlap = len(new_texts & golden_texts)
recall = overlap / len(golden_texts) if golden_texts else 0
recall_scores.append(recall)
print("\n--- EVALUATION RESULTS ---")
print(f"Query: {queries[i]}")
print(f"New Results: {new_texts}")
print(f"Golden Results: {golden_texts}")
print(f"Overlap: {overlap}")
print(f"Recall: {recall}")
print(f"Search Time: {search_times[-1]:.4f}s")
print("--------------------------------")
avg_recall = np.mean(recall_scores) if recall_scores else 0
avg_time = np.mean(search_times) if search_times else 0
print("\n🎉 --- Evaluation Complete ---")
print(f"Avg. Recall@{args.top_k} (efSearch={args.ef_search}): {avg_recall:.4f}")
print(f"Avg. Search Time: {avg_time:.4f}s")
except Exception as e:
print(f"\n❌ An error occurred during evaluation: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()
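
To make the recall metric above concrete, a tiny worked example of the text-overlap calculation used in the evaluation loop; the passage strings are fabricated:

# Fabricated passages illustrating the recall@k computation above.
new_texts = {"passage a", "passage b", "passage c"}     # texts returned by the searcher
golden_texts = {"passage a", "passage c", "passage d"}  # texts of the golden passage IDs
overlap = len(new_texts & golden_texts)                 # 2
recall = overlap / len(golden_texts) if golden_texts else 0
print(recall)  # 0.666...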

apps/wechat/__init__.py Normal file

apps/wechat/__main__.py Normal file

@@ -0,0 +1,230 @@
import os
import asyncio
import dotenv
import argparse
from pathlib import Path
from typing import List, Any, Optional
from leann.api import LeannBuilder, LeannSearcher, LeannChat
from llama_index.core.node_parser import SentenceSplitter
import requests
import time
dotenv.load_dotenv()
# Default WeChat export directory
DEFAULT_WECHAT_EXPORT_DIR = "./wechat_export_direct"
def create_leann_index_from_multiple_wechat_exports(
export_dirs: List[Path],
index_path: str = "wechat_history_index.leann",
max_count: int = -1,
):
"""
Create LEANN index from multiple WeChat export data sources.
Args:
export_dirs: List of Path objects pointing to WeChat export directories
index_path: Path to save the LEANN index
max_count: Maximum number of chat entries to process per export
"""
print("Creating LEANN index from multiple WeChat export data sources...")
# Load documents using WeChatHistoryReader from local readers module
from .readers import WeChatHistoryReader
reader = WeChatHistoryReader()
INDEX_DIR = Path(index_path).parent
if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---")
all_documents = []
total_processed = 0
# Process each WeChat export directory
for i, export_dir in enumerate(export_dirs):
print(
f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}"
)
try:
documents = reader.load_data(
wechat_export_dir=str(export_dir),
max_count=max_count,
concatenate_messages=True,  # Concatenate messages into grouped documents for better context
)
if documents:
print(f"Loaded {len(documents)} chat documents from {export_dir}")
all_documents.extend(documents)
total_processed += len(documents)
# Check if we've reached the max count
if max_count > 0 and total_processed >= max_count:
print(f"Reached max count of {max_count} documents")
break
else:
print(f"No documents loaded from {export_dir}")
except Exception as e:
print(f"Error processing {export_dir}: {e}")
continue
if not all_documents:
print("No documents loaded from any source. Exiting.")
return None
print(
f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports"
)
# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
# Convert Documents to text strings and chunk them
all_texts = []
for doc in all_documents:
# Split the document into chunks
nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes:
text = '[Contact] means the message is from: ' + doc.metadata["contact_name"] + '\n' + node.get_content()
all_texts.append(text)
print(
f"Created {len(all_texts)} text chunks from {len(all_documents)} documents"
)
# Create LEANN index directory
print(f"--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="Qwen/Qwen3-Embedding-0.6B",
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1, # Force single-threaded mode
)
print(f"Adding {len(all_texts)} chat chunks to index...")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(index_path)
print(f"\nLEANN index built at {index_path}!")
else:
print(f"--- Using existing index at {INDEX_DIR} ---")
return index_path
async def query_leann_index(index_path: str, query: str):
"""
Query the LEANN index.
Args:
index_path: Path to the LEANN index
query: The query string
"""
print(f"\n[PHASE 2] Starting Leann chat session...")
chat = LeannChat(index_path=index_path)
print(f"You: {query}")
chat_response = chat.ask(
query,
top_k=20,
recompute_beighbor_embeddings=True,
complexity=16,
beam_width=1,
llm_config={
"type": "openai",
"model": "gpt-4o",
"api_key": os.getenv("OPENAI_API_KEY"),
},
llm_kwargs={"temperature": 0.0, "max_tokens": 1000},
)
print(f"Leann: {chat_response}")
async def main():
"""Main function with integrated WeChat export functionality."""
# Parse command line arguments
parser = argparse.ArgumentParser(
description="LEANN WeChat History Reader - Create and query WeChat chat history index"
)
parser.add_argument(
"--export-dir",
type=str,
default=DEFAULT_WECHAT_EXPORT_DIR,
help=f"Directory to store WeChat exports (default: {DEFAULT_WECHAT_EXPORT_DIR})",
)
parser.add_argument(
"--index-dir",
type=str,
default="./wechat_history_magic_test_11Debug_new",
help="Directory to store the LEANN index (default: ./wechat_history_index_leann_test)",
)
parser.add_argument(
"--max-entries",
type=int,
default=50,
help="Maximum number of chat entries to process (default: 5000)",
)
parser.add_argument(
"--query",
type=str,
default=None,
help="Single query to run (default: runs example queries)",
)
parser.add_argument(
"--force-export",
action="store_true",
default=False,
help="Force re-export of WeChat data even if exports exist",
)
args = parser.parse_args()
INDEX_DIR = Path(args.index_dir)
INDEX_PATH = str(INDEX_DIR / "wechat_history.leann")
print(f"Using WeChat export directory: {args.export_dir}")
print(f"Index directory: {INDEX_DIR}")
print(f"Max entries: {args.max_entries}")
# Initialize WeChat reader with export capabilities
from .readers import WeChatHistoryReader
reader = WeChatHistoryReader()
# Find existing exports or create new ones using the centralized method
export_dirs = reader.find_or_export_wechat_data(args.export_dir)
if not export_dirs:
print("Failed to find or export WeChat data. Exiting.")
return
# Create or load the LEANN index from all sources
index_path = create_leann_index_from_multiple_wechat_exports(
export_dirs, INDEX_PATH, max_count=args.max_entries
)
if index_path:
if args.query:
# Run single query
await query_leann_index(index_path, args.query)
else:
# Example queries
queries = [
"我想买魔术师约翰逊的球衣,给我一些对应聊天记录?",
]
for query in queries:
print("\n" + "=" * 60)
await query_leann_index(index_path, query)
if __name__ == "__main__":
asyncio.run(main())
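
For reference, a hedged way to run this module; whether "python -m apps.wechat" resolves depends on how the apps package is laid out, and the query string below is only an illustration:

# Assumed invocation from the repository root; the flags are the ones defined in main() above.
#   python -m apps.wechat --export-dir ./wechat_export_direct --max-entries 200 --query "messages about buying a jersey"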

apps/wechat/readers.py Normal file

@@ -0,0 +1,719 @@
import json
import os
import re
import subprocess
import sys
import time
from pathlib import Path
from typing import List, Any, Dict, Optional
from llama_index.core import Document
from llama_index.core.readers.base import BaseReader
from datetime import datetime
class WeChatHistoryReader(BaseReader):
"""
WeChat chat history reader that extracts chat data from exported JSON files.
Reads WeChat chat history from exported JSON files (from wechat-exporter tool)
and creates documents with embedded metadata similar to the Chrome history reader structure.
Also includes utilities for automatic WeChat chat history export.
"""
def __init__(self) -> None:
"""Initialize."""
self.packages_dir = Path(__file__).parent.parent.parent / "packages"
self.wechat_exporter_dir = self.packages_dir / "wechat-exporter"
self.wechat_decipher_dir = self.packages_dir / "wechat-decipher-macos"
def check_wechat_running(self) -> bool:
"""Check if WeChat is currently running."""
try:
result = subprocess.run(["pgrep", "-f", "WeChat"], capture_output=True, text=True)
return result.returncode == 0
except Exception:
return False
def install_wechattweak(self) -> bool:
"""Install WeChatTweak CLI tool."""
try:
# Create wechat-exporter directory if it doesn't exist
self.wechat_exporter_dir.mkdir(parents=True, exist_ok=True)
wechattweak_path = self.wechat_exporter_dir / "wechattweak-cli"
if not wechattweak_path.exists():
print("Downloading WeChatTweak CLI...")
subprocess.run([
"curl", "-L", "-o", str(wechattweak_path),
"https://github.com/JettChenT/WeChatTweak-CLI/releases/latest/download/wechattweak-cli"
], check=True)
# Make executable
wechattweak_path.chmod(0o755)
# Install WeChatTweak
print("Installing WeChatTweak...")
subprocess.run(["sudo", str(wechattweak_path), "install"], check=True)
return True
except Exception as e:
print(f"Error installing WeChatTweak: {e}")
return False
def restart_wechat(self):
"""Restart WeChat to apply WeChatTweak."""
try:
print("Restarting WeChat...")
subprocess.run(["pkill", "-f", "WeChat"], check=False)
time.sleep(2)
subprocess.run(["open", "-a", "WeChat"], check=True)
time.sleep(5) # Wait for WeChat to start
except Exception as e:
print(f"Error restarting WeChat: {e}")
def check_api_available(self) -> bool:
"""Check if WeChatTweak API is available."""
try:
result = subprocess.run([
"curl", "-s", "http://localhost:48065/wechat/allcontacts"
], capture_output=True, text=True, timeout=5)
return result.returncode == 0 and bool(result.stdout.strip())
except Exception:
return False
def _extract_readable_text(self, content: str) -> str:
"""
Extract readable text from message content, removing XML and system messages.
Args:
content: The raw message content (can be string or dict)
Returns:
Cleaned, readable text
"""
if not content:
return ""
# Handle dictionary content (like quoted messages)
if isinstance(content, dict):
# Extract text from dictionary structure
text_parts = []
if 'title' in content:
text_parts.append(str(content['title']))
if 'quoted' in content:
text_parts.append(str(content['quoted']))
if 'content' in content:
text_parts.append(str(content['content']))
if 'text' in content:
text_parts.append(str(content['text']))
if text_parts:
return " | ".join(text_parts)
else:
# If we can't extract meaningful text from dict, return empty
return ""
# Handle string content
if not isinstance(content, str):
return ""
# Remove common prefixes like "wxid_xxx:\n"
clean_content = re.sub(r'^wxid_[^:]+:\s*', '', content)
clean_content = re.sub(r'^[^:]+:\s*', '', clean_content)
# If it's just XML or system message, return empty
if clean_content.strip().startswith('<') or 'recalled a message' in clean_content:
return ""
return clean_content.strip()
def _is_text_message(self, content: str) -> bool:
"""
Check if a message contains readable text content.
Args:
content: The message content (can be string or dict)
Returns:
True if the message contains readable text, False otherwise
"""
if not content:
return False
# Handle dictionary content
if isinstance(content, dict):
# Check if dict has any readable text fields
text_fields = ['title', 'quoted', 'content', 'text']
for field in text_fields:
if field in content and content[field]:
return True
return False
# Handle string content
if not isinstance(content, str):
return False
# Skip image messages (contain XML with img tags)
if '<img' in content and 'cdnurl' in content:
return False
# Skip emoji messages (contain emoji XML tags)
if '<emoji' in content and 'productid' in content:
return False
# Skip voice messages
if '<voice' in content:
return False
# Skip video messages
if '<video' in content:
return False
# Skip file messages
if '<appmsg' in content and 'appid' in content:
return False
# Skip system messages (like "recalled a message")
if 'recalled a message' in content:
return False
# Check if there's actual readable text (not just XML or system messages)
# Remove common prefixes like "wxid_xxx:\n" and check for actual content
clean_content = re.sub(r'^wxid_[^:]+:\s*', '', content)
clean_content = re.sub(r'^[^:]+:\s*', '', clean_content)
# If after cleaning we have meaningful text, consider it readable
if len(clean_content.strip()) > 0 and not clean_content.strip().startswith('<'):
return True
return False
def _concatenate_messages(self, messages: List[Dict], max_length: int = 128,
time_window_minutes: int = 30, overlap_messages: int = 0) -> List[Dict]:
"""
Concatenate messages based on length and time rules.
Args:
messages: List of message dictionaries
max_length: Maximum length for concatenated message groups. Use -1 to disable length constraint.
time_window_minutes: Time window in minutes to group messages together. Use -1 to disable time constraint.
overlap_messages: Number of messages to overlap between consecutive groups
Returns:
List of concatenated message groups
"""
if not messages:
return []
concatenated_groups = []
current_group = []
current_length = 0
last_timestamp = None
for message in messages:
# Extract message info
content = message.get('content', '')
message_text = message.get('message', '')
create_time = message.get('createTime', 0)
from_user = message.get('fromUser', '')
to_user = message.get('toUser', '')
is_sent_from_self = message.get('isSentFromSelf', False)
# Extract readable text
readable_text = self._extract_readable_text(content)
if not readable_text:
readable_text = message_text
# Skip empty messages
if not readable_text.strip():
continue
# Check time window constraint (only if time_window_minutes != -1)
if time_window_minutes != -1 and last_timestamp is not None and create_time > 0:
time_diff_minutes = (create_time - last_timestamp) / 60
if time_diff_minutes > time_window_minutes:
# Time gap too large, start new group
if current_group:
concatenated_groups.append({
'messages': current_group,
'total_length': current_length,
'start_time': current_group[0].get('createTime', 0),
'end_time': current_group[-1].get('createTime', 0)
})
# Keep last few messages for overlap
if overlap_messages > 0 and len(current_group) > overlap_messages:
current_group = current_group[-overlap_messages:]
current_length = sum(len(self._extract_readable_text(msg.get('content', '')) or msg.get('message', '')) for msg in current_group)
else:
current_group = []
current_length = 0
# Check length constraint (only if max_length != -1)
message_length = len(readable_text)
if max_length != -1 and current_length + message_length > max_length and current_group:
# Current group would exceed max length, save it and start new
concatenated_groups.append({
'messages': current_group,
'total_length': current_length,
'start_time': current_group[0].get('createTime', 0),
'end_time': current_group[-1].get('createTime', 0)
})
# Keep last few messages for overlap
if overlap_messages > 0 and len(current_group) > overlap_messages:
current_group = current_group[-overlap_messages:]
current_length = sum(len(self._extract_readable_text(msg.get('content', '')) or msg.get('message', '')) for msg in current_group)
else:
current_group = []
current_length = 0
# Add message to current group
current_group.append(message)
current_length += message_length
last_timestamp = create_time
# Add the last group if it exists
if current_group:
concatenated_groups.append({
'messages': current_group,
'total_length': current_length,
'start_time': current_group[0].get('createTime', 0),
'end_time': current_group[-1].get('createTime', 0)
})
return concatenated_groups
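# Illustrative sketch of the grouping rules above, using fabricated message dicts;
# the thresholds are chosen only to show the time-window cut. The third message
# arrives ~61 minutes later, so with a 30-minute window it starts a second group
# (overlap_messages defaults to 0):
#
#   fake_messages = [
#       {"content": "hello there", "createTime": 1_700_000_000},
#       {"content": "how are you", "createTime": 1_700_000_060},  # +1 minute
#       {"content": "new topic", "createTime": 1_700_003_700},    # +~61 minutes
#   ]
#   reader = WeChatHistoryReader()
#   groups = reader._concatenate_messages(fake_messages, max_length=128, time_window_minutes=30)
#   len(groups)  # -> 2 under these assumptions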
def _create_concatenated_content(self, message_group: Dict, contact_name: str) -> str:
"""
Create concatenated content from a group of messages.
Args:
message_group: Dictionary containing messages and metadata
contact_name: Name of the contact
Returns:
Tuple of (formatted document content, contact name)
"""
messages = message_group['messages']
start_time = message_group['start_time']
end_time = message_group['end_time']
# Format timestamps
if start_time:
try:
start_timestamp = datetime.fromtimestamp(start_time)
start_time_str = start_timestamp.strftime('%Y-%m-%d %H:%M:%S')
except Exception:
start_time_str = str(start_time)
else:
start_time_str = "Unknown"
if end_time:
try:
end_timestamp = datetime.fromtimestamp(end_time)
end_time_str = end_timestamp.strftime('%Y-%m-%d %H:%M:%S')
except Exception:
end_time_str = str(end_time)
else:
end_time_str = "Unknown"
# Build concatenated message content
message_parts = []
for message in messages:
content = message.get('content', '')
message_text = message.get('message', '')
create_time = message.get('createTime', 0)
is_sent_from_self = message.get('isSentFromSelf', False)
# Extract readable text
readable_text = self._extract_readable_text(content)
if not readable_text:
readable_text = message_text
# Format individual message
if create_time:
try:
timestamp = datetime.fromtimestamp(create_time)
# change to YYYY-MM-DD HH:MM:SS
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
except Exception:
time_str = str(create_time)
else:
time_str = "Unknown"
sender = "[Me]" if is_sent_from_self else "[Contact]"
message_parts.append(f"({time_str}) {sender}: {readable_text}")
concatenated_text = "\n".join(message_parts)
# Create final document content
doc_content = f"""
Contact: {contact_name}
Time Range: {start_time_str} - {end_time_str}
Messages ({len(messages)} messages, {message_group['total_length']} chars):
{concatenated_text}
"""
# TODO @yichuan give better format and rich info here!
doc_content = f"""
{concatenated_text}
"""
return doc_content, contact_name
def load_data(self, input_dir: Optional[str] = None, **load_kwargs: Any) -> List[Document]:
"""
Load WeChat chat history data from exported JSON files.
Args:
input_dir: Directory containing exported WeChat JSON files
**load_kwargs:
max_count (int): Maximum amount of chat entries to read.
wechat_export_dir (str): Custom path to WeChat export directory.
include_non_text (bool): Whether to include non-text messages (images, emojis, etc.)
concatenate_messages (bool): Whether to concatenate messages based on length rules.
max_length (int): Maximum length for concatenated message groups (default: 1000).
time_window_minutes (int): Time window in minutes to group messages together (default: 30).
overlap_messages (int): Number of messages to overlap between consecutive groups (default: 2).
"""
docs: List[Document] = []
max_count = load_kwargs.get('max_count', 1000)
wechat_export_dir = load_kwargs.get('wechat_export_dir', None)
include_non_text = load_kwargs.get('include_non_text', False)
concatenate_messages = load_kwargs.get('concatenate_messages', False)
max_length = load_kwargs.get('max_length', 1000)
time_window_minutes = load_kwargs.get('time_window_minutes', 30)
# Default WeChat export path
if wechat_export_dir is None:
wechat_export_dir = "./wechat_export_test"
if not os.path.exists(wechat_export_dir):
print(f"WeChat export directory not found at: {wechat_export_dir}")
return docs
try:
# Find all JSON files in the export directory
json_files = list(Path(wechat_export_dir).glob("*.json"))
print(f"Found {len(json_files)} WeChat chat history files")
count = 0
for json_file in json_files:
if count >= max_count and max_count > 0:
break
try:
with open(json_file, 'r', encoding='utf-8') as f:
chat_data = json.load(f)
# Extract contact name from filename
contact_name = json_file.stem
if concatenate_messages:
# Filter messages to only include readable text messages
readable_messages = []
for message in chat_data:
try:
content = message.get('content', '')
if not include_non_text and not self._is_text_message(content):
continue
readable_text = self._extract_readable_text(content)
if not readable_text and not include_non_text:
continue
readable_messages.append(message)
except Exception as e:
print(f"Error processing message in {json_file}: {e}")
continue
# Concatenate messages based on rules
message_groups = self._concatenate_messages(
readable_messages,
max_length=-1,  # -1 disables the length constraint
time_window_minutes=-1,  # -1 disables the time-window constraint
overlap_messages=0,  # no overlap between consecutive groups
)
# Create documents from concatenated groups
for message_group in message_groups:
if count >= max_count and max_count > 0:
break
doc_content, contact_name = self._create_concatenated_content(message_group, contact_name)
doc = Document(text=doc_content, metadata={"contact_name": contact_name})
docs.append(doc)
count += 1
print(f"Created {len(message_groups)} concatenated message groups for {contact_name}")
else:
# Original single-message processing
for message in chat_data:
if count >= max_count and max_count > 0:
break
# Extract message information
from_user = message.get('fromUser', '')
to_user = message.get('toUser', '')
content = message.get('content', '')
message_text = message.get('message', '')
create_time = message.get('createTime', 0)
is_sent_from_self = message.get('isSentFromSelf', False)
# Handle content that might be dict or string
try:
# Check if this is a readable text message
if not include_non_text and not self._is_text_message(content):
continue
# Extract readable text
readable_text = self._extract_readable_text(content)
if not readable_text and not include_non_text:
continue
except Exception as e:
# Skip messages that cause processing errors
print(f"Error processing message in {json_file}: {e}")
continue
# Convert timestamp to readable format
if create_time:
try:
timestamp = datetime.fromtimestamp(create_time)
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
except Exception:
time_str = str(create_time)
else:
time_str = "Unknown"
# Create document content with metadata header and contact info
doc_content = f"""
Contact: {contact_name}
Is sent from self: {is_sent_from_self}
Time: {time_str}
Message: {readable_text if readable_text else message_text}
"""
# Create document with embedded metadata
doc = Document(text=doc_content, metadata={})
docs.append(doc)
count += 1
except Exception as e:
print(f"Error reading {json_file}: {e}")
continue
print(f"Loaded {len(docs)} WeChat chat documents")
except Exception as e:
print(f"Error reading WeChat history: {e}")
return docs
return docs
@staticmethod
def find_wechat_export_dirs() -> List[Path]:
"""
Find all WeChat export directories.
Returns:
List of Path objects pointing to WeChat export directories
"""
export_dirs = []
# Look for common export directory names
possible_dirs = [
Path("./wechat_export_test"),
Path("./wechat_export"),
Path("./wechat_chat_history"),
Path("./chat_export")
]
for export_dir in possible_dirs:
if export_dir.exists() and export_dir.is_dir():
json_files = list(export_dir.glob("*.json"))
if json_files:
export_dirs.append(export_dir)
print(f"Found WeChat export directory: {export_dir} with {len(json_files)} files")
print(f"Found {len(export_dirs)} WeChat export directories")
return export_dirs
@staticmethod
def export_chat_to_file(output_file: str = "wechat_chat_export.txt", max_count: int = 1000, export_dir: str = None, include_non_text: bool = False):
"""
Export WeChat chat history to a text file.
Args:
output_file: Path to the output file
max_count: Maximum number of entries to export
export_dir: Directory containing WeChat JSON files
include_non_text: Whether to include non-text messages
"""
if export_dir is None:
export_dir = "./wechat_export_test"
if not os.path.exists(export_dir):
print(f"WeChat export directory not found at: {export_dir}")
return
try:
json_files = list(Path(export_dir).glob("*.json"))
with open(output_file, 'w', encoding='utf-8') as f:
count = 0
for json_file in json_files:
if count >= max_count and max_count > 0:
break
try:
with open(json_file, 'r', encoding='utf-8') as json_f:
chat_data = json.load(json_f)
contact_name = json_file.stem
f.write(f"\n=== Chat with {contact_name} ===\n")
for message in chat_data:
if count >= max_count and max_count > 0:
break
from_user = message.get('fromUser', '')
content = message.get('content', '')
message_text = message.get('message', '')
create_time = message.get('createTime', 0)
# Skip non-text messages unless requested
if not include_non_text:
reader = WeChatHistoryReader()
if not reader._is_text_message(content):
continue
readable_text = reader._extract_readable_text(content)
if not readable_text:
continue
message_text = readable_text
if create_time:
try:
timestamp = datetime.fromtimestamp(create_time)
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
except Exception:
time_str = str(create_time)
else:
time_str = "Unknown"
f.write(f"[{time_str}] {from_user}: {message_text}\n")
count += 1
except Exception as e:
print(f"Error processing {json_file}: {e}")
continue
print(f"Exported {count} chat entries to {output_file}")
except Exception as e:
print(f"Error exporting WeChat chat history: {e}")
def export_wechat_chat_history(self, export_dir: str = "./wechat_export_direct") -> Optional[Path]:
"""
Export WeChat chat history using wechat-exporter tool.
Args:
export_dir: Directory to save exported chat history
Returns:
Path to export directory if successful, None otherwise
"""
try:
import subprocess
import sys
# Create export directory
export_path = Path(export_dir)
export_path.mkdir(exist_ok=True)
print(f"Exporting WeChat chat history to {export_path}...")
# Check if wechat-exporter directory exists
if not self.wechat_exporter_dir.exists():
print(f"wechat-exporter directory not found at: {self.wechat_exporter_dir}")
return None
# Install requirements if needed
requirements_file = self.wechat_exporter_dir / "requirements.txt"
if requirements_file.exists():
print("Installing wechat-exporter requirements...")
subprocess.run([
"uv", "pip", "install", "-r", str(requirements_file)
], check=True)
# Run the export command
print("Running wechat-exporter...")
result = subprocess.run([
sys.executable, str(self.wechat_exporter_dir / "main.py"),
"export-all", str(export_path)
], capture_output=True, text=True, check=True)
print("Export command output:")
print(result.stdout)
if result.stderr:
print("Export errors:")
print(result.stderr)
# Check if export was successful
if export_path.exists() and any(export_path.glob("*.json")):
json_files = list(export_path.glob("*.json"))
print(f"Successfully exported {len(json_files)} chat history files to {export_path}")
return export_path
else:
print("Export completed but no JSON files found")
return None
except subprocess.CalledProcessError as e:
print(f"Export command failed: {e}")
print(f"Command output: {e.stdout}")
print(f"Command errors: {e.stderr}")
return None
except Exception as e:
print(f"Export failed: {e}")
print("Please ensure WeChat is running and WeChatTweak is installed.")
return None
def find_or_export_wechat_data(self, export_dir: str = "./wechat_export_direct") -> List[Path]:
"""
Find existing WeChat exports or create new ones.
Args:
export_dir: Directory to save exported chat history if needed
Returns:
List of Path objects pointing to WeChat export directories
"""
export_dirs = []
# Look for existing exports in common locations
possible_export_dirs = [
Path("./wechat_database_export"),
Path("./wechat_export_test"),
Path("./wechat_export"),
Path("./wechat_export_direct"),
Path("./wechat_chat_history"),
Path("./chat_export")
]
for export_dir_path in possible_export_dirs:
if export_dir_path.exists() and any(export_dir_path.glob("*.json")):
export_dirs.append(export_dir_path)
print(f"Found existing export: {export_dir_path}")
# If no existing exports, try to export automatically
if not export_dirs:
print("No existing WeChat exports found. Starting direct export...")
# Try to export using wechat-exporter
exported_path = self.export_wechat_chat_history(export_dir)
if exported_path:
export_dirs = [exported_path]
else:
print("Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.")
return export_dirs
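
Finally, a hedged end-to-end sketch of the reader above; the export directory and max_count are illustrative values, not defaults mandated by the code:

# Illustrative only: find (or export) WeChat JSON dumps, then load grouped chat documents.
reader = WeChatHistoryReader()
export_dirs = reader.find_or_export_wechat_data("./wechat_export_direct")
for export_dir in export_dirs:
    docs = reader.load_data(
        wechat_export_dir=str(export_dir),
        max_count=100,
        concatenate_messages=True,
    )
    print(f"{export_dir}: {len(docs)} chat documents")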


@@ -1,189 +0,0 @@
"""
WeChat History RAG example using the unified interface.
Supports WeChat chat history export and search.
"""
import subprocess
import sys
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from base_rag_example import BaseRAGExample
from .history_data.wechat_history import WeChatHistoryReader
class WeChatRAG(BaseRAGExample):
"""RAG example for WeChat chat history."""
def __init__(self):
# Set default values BEFORE calling super().__init__
self.max_items_default = -1 # Match original default
self.embedding_model_default = (
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
)
super().__init__(
name="WeChat History",
description="Process and query WeChat chat history with LEANN",
default_index_name="wechat_history_magic_test_11Debug_new",
)
def _add_specific_arguments(self, parser):
"""Add WeChat-specific arguments."""
wechat_group = parser.add_argument_group("WeChat Parameters")
wechat_group.add_argument(
"--export-dir",
type=str,
default="./wechat_export",
help="Directory to store WeChat exports (default: ./wechat_export)",
)
wechat_group.add_argument(
"--force-export",
action="store_true",
help="Force re-export of WeChat data even if exports exist",
)
wechat_group.add_argument(
"--chunk-size", type=int, default=192, help="Text chunk size (default: 192)"
)
wechat_group.add_argument(
"--chunk-overlap", type=int, default=64, help="Text chunk overlap (default: 64)"
)
def _export_wechat_data(self, export_dir: Path) -> bool:
"""Export WeChat data using wechattweak-cli."""
print("Exporting WeChat data...")
# Check if WeChat is running
try:
result = subprocess.run(["pgrep", "WeChat"], capture_output=True, text=True)
if result.returncode != 0:
print("WeChat is not running. Please start WeChat first.")
return False
except Exception:
pass # pgrep might not be available on all systems
# Create export directory
export_dir.mkdir(parents=True, exist_ok=True)
# Run export command
cmd = ["packages/wechat-exporter/wechattweak-cli", "export", str(export_dir)]
try:
print(f"Running: {' '.join(cmd)}")
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode == 0:
print("WeChat data exported successfully!")
return True
else:
print(f"Export failed: {result.stderr}")
return False
except FileNotFoundError:
print("\nError: wechattweak-cli not found!")
print("Please install it first:")
print(" sudo packages/wechat-exporter/wechattweak-cli install")
return False
except Exception as e:
print(f"Export error: {e}")
return False
async def load_data(self, args) -> list[str]:
"""Load WeChat history and convert to text chunks."""
# Initialize WeChat reader with export capabilities
reader = WeChatHistoryReader()
# Find existing exports or create new ones using the centralized method
export_dirs = reader.find_or_export_wechat_data(args.export_dir)
if not export_dirs:
print("Failed to find or export WeChat data. Trying to find any existing exports...")
# Try to find any existing exports in common locations
export_dirs = reader.find_wechat_export_dirs()
if not export_dirs:
print("No WeChat data found. Please ensure WeChat exports exist.")
return []
# Load documents from all found export directories
all_documents = []
total_processed = 0
for i, export_dir in enumerate(export_dirs):
print(f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}")
try:
# Apply max_items limit per export
max_per_export = -1
if args.max_items > 0:
remaining = args.max_items - total_processed
if remaining <= 0:
break
max_per_export = remaining
documents = reader.load_data(
wechat_export_dir=str(export_dir),
max_count=max_per_export,
concatenate_messages=True, # Enable message concatenation for better context
)
if documents:
print(f"Loaded {len(documents)} chat documents from {export_dir}")
all_documents.extend(documents)
total_processed += len(documents)
else:
print(f"No documents loaded from {export_dir}")
except Exception as e:
print(f"Error processing {export_dir}: {e}")
continue
if not all_documents:
print("No documents loaded from any source. Exiting.")
return []
print(f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports")
print("now starting to split into text chunks ... take some time")
# Convert to text chunks with contact information
all_texts = []
for doc in all_documents:
# Split the document into chunks
from llama_index.core.node_parser import SentenceSplitter
text_splitter = SentenceSplitter(
chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
)
nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes:
# Add contact information to each chunk
contact_name = doc.metadata.get("contact_name", "Unknown")
text = f"[Contact] means the message is from: {contact_name}\n" + node.get_content()
all_texts.append(text)
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
return all_texts
if __name__ == "__main__":
import asyncio
# Check platform
if sys.platform != "darwin":
print("\n⚠️ Warning: WeChat export is only supported on macOS")
print(" You can still query existing exports on other platforms\n")
# Example queries for WeChat RAG
print("\n💬 WeChat History RAG Example")
print("=" * 50)
print("\nExample queries you can try:")
print("- 'Show me conversations about travel plans'")
print("- 'Find group chats about weekend activities'")
print("- '我想买魔术师约翰逊的球衣,给我一些对应聊天记录?'")
print("- 'What did we discuss about the project last month?'")
print("\nNote: WeChat must be running for export to work\n")
rag = WeChatRAG()
asyncio.run(rag.run())

Binary file not shown (removed image, 73 KiB)

Binary file not shown (removed image, 224 KiB)

@@ -1,268 +0,0 @@
#!/usr/bin/env python3
"""
DiskANN vs HNSW Search Performance Comparison
This benchmark compares search performance between DiskANN and HNSW backends:
- DiskANN: With graph partitioning enabled (is_recompute=True)
- HNSW: With recompute enabled (is_recompute=True)
- Tests performance across different dataset sizes
- Measures search latency, recall, and index size
"""
import gc
import tempfile
import time
from pathlib import Path
from typing import Any
import numpy as np
def create_test_texts(n_docs: int) -> list[str]:
"""Create synthetic test documents for benchmarking."""
np.random.seed(42)
topics = [
"machine learning and artificial intelligence",
"natural language processing and text analysis",
"computer vision and image recognition",
"data science and statistical analysis",
"deep learning and neural networks",
"information retrieval and search engines",
"database systems and data management",
"software engineering and programming",
"cybersecurity and network protection",
"cloud computing and distributed systems",
]
texts = []
for i in range(n_docs):
topic = topics[i % len(topics)]
variation = np.random.randint(1, 100)
text = (
f"This is document {i} about {topic}. Content variation {variation}. "
f"Additional information about {topic} with details and examples. "
f"Technical discussion of {topic} including implementation aspects."
)
texts.append(text)
return texts
def benchmark_backend(
backend_name: str, texts: list[str], test_queries: list[str], backend_kwargs: dict[str, Any]
) -> dict[str, float]:
"""Benchmark a specific backend with the given configuration."""
from leann.api import LeannBuilder, LeannSearcher
print(f"\n🔧 Testing {backend_name.upper()} backend...")
with tempfile.TemporaryDirectory() as temp_dir:
index_path = str(Path(temp_dir) / f"benchmark_{backend_name}.leann")
# Build index
print(f"📦 Building {backend_name} index with {len(texts)} documents...")
start_time = time.time()
builder = LeannBuilder(
backend_name=backend_name,
embedding_model="facebook/contriever",
embedding_mode="sentence-transformers",
**backend_kwargs,
)
for text in texts:
builder.add_text(text)
builder.build_index(index_path)
build_time = time.time() - start_time
# Measure index size
index_dir = Path(index_path).parent
index_files = list(index_dir.glob(f"{Path(index_path).stem}.*"))
total_size = sum(f.stat().st_size for f in index_files if f.is_file())
size_mb = total_size / (1024 * 1024)
print(f" ✅ Build completed in {build_time:.2f}s, index size: {size_mb:.1f}MB")
# Search benchmark
print("🔍 Running search benchmark...")
searcher = LeannSearcher(index_path)
search_times = []
all_results = []
for query in test_queries:
start_time = time.time()
results = searcher.search(query, top_k=5)
search_time = time.time() - start_time
search_times.append(search_time)
all_results.append(results)
avg_search_time = np.mean(search_times) * 1000 # Convert to ms
print(f" ✅ Average search time: {avg_search_time:.1f}ms")
# Check for valid scores (detect -inf issues)
all_scores = [
result.score
for results in all_results
for result in results
if result.score is not None
]
valid_scores = [
score for score in all_scores if score != float("-inf") and score != float("inf")
]
score_validity_rate = len(valid_scores) / len(all_scores) if all_scores else 0
# Clean up
try:
if hasattr(searcher, "__del__"):
searcher.__del__()
del searcher
del builder
gc.collect()
except Exception as e:
print(f"⚠️ Warning: Resource cleanup error: {e}")
return {
"build_time": build_time,
"avg_search_time_ms": avg_search_time,
"index_size_mb": size_mb,
"score_validity_rate": score_validity_rate,
}
def run_comparison(n_docs: int = 500, n_queries: int = 10):
"""Run performance comparison between DiskANN and HNSW."""
print("🚀 Starting DiskANN vs HNSW Performance Comparison")
print(f"📊 Dataset: {n_docs} documents, {n_queries} test queries")
# Create test data
texts = create_test_texts(n_docs)
test_queries = [
"machine learning algorithms",
"natural language processing",
"computer vision techniques",
"data analysis methods",
"neural network architectures",
"database query optimization",
"software development practices",
"security vulnerabilities",
"cloud infrastructure",
"distributed computing",
][:n_queries]
# HNSW benchmark
hnsw_results = benchmark_backend(
backend_name="hnsw",
texts=texts,
test_queries=test_queries,
backend_kwargs={
"is_recompute": True, # Enable recompute for fair comparison
"M": 16,
"efConstruction": 200,
},
)
# DiskANN benchmark
diskann_results = benchmark_backend(
backend_name="diskann",
texts=texts,
test_queries=test_queries,
backend_kwargs={
"is_recompute": True, # Enable graph partitioning
"num_neighbors": 32,
"search_list_size": 50,
},
)
# Performance comparison
print("\n📈 Performance Comparison Results")
print(f"{'=' * 60}")
print(f"{'Metric':<25} {'HNSW':<15} {'DiskANN':<15} {'Speedup':<10}")
print(f"{'-' * 60}")
# Build time comparison
build_speedup = hnsw_results["build_time"] / diskann_results["build_time"]
print(
f"{'Build Time (s)':<25} {hnsw_results['build_time']:<15.2f} {diskann_results['build_time']:<15.2f} {build_speedup:<10.2f}x"
)
# Search time comparison
search_speedup = hnsw_results["avg_search_time_ms"] / diskann_results["avg_search_time_ms"]
print(
f"{'Search Time (ms)':<25} {hnsw_results['avg_search_time_ms']:<15.1f} {diskann_results['avg_search_time_ms']:<15.1f} {search_speedup:<10.2f}x"
)
# Index size comparison
size_ratio = diskann_results["index_size_mb"] / hnsw_results["index_size_mb"]
print(
f"{'Index Size (MB)':<25} {hnsw_results['index_size_mb']:<15.1f} {diskann_results['index_size_mb']:<15.1f} {size_ratio:<10.2f}x"
)
# Score validity
print(
f"{'Score Validity (%)':<25} {hnsw_results['score_validity_rate'] * 100:<15.1f} {diskann_results['score_validity_rate'] * 100:<15.1f}"
)
print(f"{'=' * 60}")
print("\n🎯 Summary:")
if search_speedup > 1:
print(f" DiskANN is {search_speedup:.2f}x faster than HNSW for search")
else:
print(f" HNSW is {1 / search_speedup:.2f}x faster than DiskANN for search")
if size_ratio > 1:
print(f" DiskANN uses {size_ratio:.2f}x more storage than HNSW")
else:
print(f" DiskANN uses {1 / size_ratio:.2f}x less storage than HNSW")
print(
f" Both backends achieved {min(hnsw_results['score_validity_rate'], diskann_results['score_validity_rate']) * 100:.1f}% score validity"
)
if __name__ == "__main__":
import sys
try:
# Handle help request
if len(sys.argv) > 1 and sys.argv[1] in ["-h", "--help", "help"]:
print("DiskANN vs HNSW Performance Comparison")
print("=" * 50)
print(f"Usage: python {sys.argv[0]} [n_docs] [n_queries]")
print()
print("Arguments:")
print(" n_docs Number of documents to index (default: 500)")
print(" n_queries Number of test queries to run (default: 10)")
print()
print("Examples:")
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py")
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 1000")
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20")
sys.exit(0)
# Parse command line arguments
n_docs = int(sys.argv[1]) if len(sys.argv) > 1 else 500
n_queries = int(sys.argv[2]) if len(sys.argv) > 2 else 10
print("DiskANN vs HNSW Performance Comparison")
print("=" * 50)
print(f"Dataset: {n_docs} documents, {n_queries} queries")
print()
run_comparison(n_docs=n_docs, n_queries=n_queries)
except KeyboardInterrupt:
print("\n⚠️ Benchmark interrupted by user")
sys.exit(130)
except Exception as e:
print(f"\n❌ Benchmark failed: {e}")
sys.exit(1)
finally:
# Ensure cleanup runs; do not call sys.exit(0) here, or it would mask the exit codes set above
try:
gc.collect()
print("\n🧹 Cleanup completed")
except Exception:
pass

View File

View File

Binary file not shown.

View File

File diff suppressed because it is too large

View File

File diff suppressed because it is too large

44
data/README.md Normal file
View File

@@ -0,0 +1,44 @@
---
license: mit
---
# LEANN-RAG Evaluation Data
This repository contains the necessary data to run the recall evaluation scripts for the [LEANN-RAG](https://huggingface.co/LEANN-RAG) project.
## Dataset Components
This dataset is structured into three main parts:
1. **Pre-built LEANN Indices**:
* `dpr/`: A pre-built index for the DPR dataset.
* `rpj_wiki/`: A pre-built index for the RPJ-Wiki dataset.
These indices were created using the `leann-core` library and are required by the `LeannSearcher`.
2. **Ground Truth Data**:
* `ground_truth/`: Contains the ground truth files (`flat_results_nq_k3.json`) for both the DPR and RPJ-Wiki datasets. These files map queries to the original passage IDs from the Natural Questions benchmark, evaluated using the Contriever model.
3. **Queries**:
* `queries/`: Contains the `nq_open.jsonl` file with the Natural Questions queries used for the evaluation.
## Usage
To use this data, you can download it locally using the `huggingface-hub` library. First, install the library:
```bash
pip install huggingface-hub
```
Then, you can download the entire dataset to a local directory (e.g., `data/`) with the following Python script:
```python
from huggingface_hub import snapshot_download
snapshot_download(
repo_id="LEANN-RAG/leann-rag-evaluation-data",
repo_type="dataset",
local_dir="data"
)
```
This will download all the necessary files into a local `data` folder, preserving the repository structure. The evaluation scripts in the main [LEANN-RAG Space](https://huggingface.co/LEANN-RAG) are configured to work with this data structure.
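
After downloading, a quick sanity check is to peek at the query file (a minimal sketch; it only assumes that `queries/nq_open.jsonl` contains one JSON object per line, as described above):

```python
import json

# Print the first few Natural Questions queries from the downloaded dataset.
with open("data/queries/nq_open.jsonl", "r", encoding="utf-8") as f:
    for i, line in enumerate(f):
        print(json.loads(line))  # one query record per line
        if i >= 2:
            break
```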

View File

@@ -1,116 +1,37 @@
{
"cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Quick Start \n",
- "\n",
- "**Home GitHub Repository:** [LEANN on GitHub](https://github.com/yichuan-w/LEANN)\n",
- "\n",
- "**Important for Colab users:** Set your runtime type to T4 GPU for optimal performance. Go to Runtime → Change runtime type → Hardware accelerator → T4 GPU."
- ]
- },
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
- "# install this if you are using colab\n",
+ "from leann.api import LeannBuilder, LeannSearcher, LeannChat\n",
- "! uv pip install leann-core leann-backend-hnsw --no-deps\n",
- "! uv pip install leann --no-deps\n",
- "# For Colab environment, we need to set some environment variables\n",
- "import os\n",
- "\n",
- "os.environ[\"LEANN_LOG_LEVEL\"] = \"INFO\" # Enable more detailed logging"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from pathlib import Path\n",
- "\n",
- "INDEX_DIR = Path(\"./\").resolve()\n",
- "INDEX_PATH = str(INDEX_DIR / \"demo.leann\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Build the index"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from leann.api import LeannBuilder\n",
"\n",
+ "# 1. Build the index (no embeddings stored!)\n",
"builder = LeannBuilder(backend_name=\"hnsw\")\n",
- "builder.add_text(\"C# is a powerful programming language and it is good at game development\")\n",
+ "builder.add_text(\"C# is a powerful programming language\")\n",
- "builder.add_text(\n",
+ "builder.add_text(\"Python is a powerful programming language and it is very popular\")\n",
- " \"Python is a powerful programming language and it is good at machine learning tasks\"\n",
- ")\n",
"builder.add_text(\"Machine learning transforms industries\")\n",
"builder.add_text(\"Neural networks process complex data\")\n",
"builder.add_text(\"Leann is a great storage saving engine for RAG on your MacBook\")\n",
- "builder.build_index(INDEX_PATH)"
+ "builder.build_index(\"knowledge.leann\")\n",
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Search with real-time embeddings"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from leann.api import LeannSearcher\n",
"\n",
- "searcher = LeannSearcher(INDEX_PATH)\n",
+ "# 2. Search with real-time embeddings\n",
+ "searcher = LeannSearcher(\"knowledge.leann\")\n",
"results = searcher.search(\"programming languages\", top_k=2)\n",
- "results"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Chat with LEANN using retrieved results"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from leann.api import LeannChat\n",
"\n",
+ "# 3. Chat with LEANN using retrieved results\n",
"llm_config = {\n",
- " \"type\": \"hf\",\n",
+ " \"type\": \"ollama\",\n",
- " \"model\": \"Qwen/Qwen3-0.6B\",\n",
+ " \"model\": \"llama3.2:1b\"\n",
"}\n",
"\n",
- "chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)\n",
+ "chat = LeannChat(index_path=\"knowledge.leann\", llm_config=llm_config)\n",
"response = chat.ask(\n",
- " \"Compare the two retrieved programming languages and tell me their advantages.\",\n",
+ " \"Compare the two retrieved programming languages and say which one is more popular today.\",\n",
" top_k=2,\n",
- " llm_kwargs={\"max_tokens\": 128},\n",
+ ")"
- ")\n",
- "response"
]
}
],

View File

@@ -1,220 +0,0 @@
# 🤝 Contributing
We welcome contributions! Leann is built by the community, for the community.
## Ways to Contribute
- 🐛 **Bug Reports**: Found an issue? Let us know!
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
- 🔧 **Code Contributions**: PRs welcome for all skill levels
- 📖 **Documentation**: Help make Leann more accessible
- 🧪 **Benchmarks**: Share your performance results
## 🚀 Development Setup
### Prerequisites
1. **Install uv** (fast Python package installer):
```bash
curl -LsSf https://astral.sh/uv/install.sh | sh
```
2. **Clone the repository**:
```bash
git clone https://github.com/LEANN-RAG/LEANN-RAG.git
cd LEANN-RAG
```
3. **Install system dependencies**:
**macOS:**
```bash
brew install llvm libomp boost protobuf zeromq pkgconf
```
**Ubuntu/Debian:**
```bash
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler \
libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
```
4. **Build from source**:
```bash
# macOS
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync
# Ubuntu/Debian
uv sync
```
## 🔨 Pre-commit Hooks
We use pre-commit hooks to ensure code quality and consistency. This runs automatically before each commit.
### Setup Pre-commit
1. **Install pre-commit** (already included when you run `uv sync`):
```bash
uv pip install pre-commit
```
2. **Install the git hooks**:
```bash
pre-commit install
```
3. **Run pre-commit manually** (optional):
```bash
pre-commit run --all-files
```
### Pre-commit Checks
Our pre-commit configuration includes:
- **Trailing whitespace removal**
- **End-of-file fixing**
- **YAML validation**
- **Large file prevention**
- **Merge conflict detection**
- **Debug statement detection**
- **Code formatting with ruff**
- **Code linting with ruff**
## 🧪 Testing
### Running Tests
```bash
# Run all tests
uv run pytest
# Run specific test file
uv run pytest test/test_filename.py
# Run with coverage
uv run pytest --cov=leann
```
### Writing Tests
- Place tests in the `test/` directory
- Follow the naming convention `test_*.py`
- Use descriptive test names that explain what's being tested
- Include both positive and negative test cases
## 📝 Code Style
We use `ruff` for both linting and formatting to ensure consistent code style.
### Format Your Code
```bash
# Format all files
ruff format
# Check formatting without changing files
ruff format --check
```
### Lint Your Code
```bash
# Run linter with auto-fix
ruff check --fix
# Just check without fixing
ruff check
```
### Style Guidelines
- Follow PEP 8 conventions
- Use descriptive variable names
- Add type hints where appropriate
- Write docstrings for all public functions and classes
- Keep functions focused and single-purpose
## 🚦 CI/CD
Our CI pipeline runs automatically on all pull requests. It includes:
1. **Linting and Formatting**: Ensures code follows our style guidelines
2. **Multi-platform builds**: Tests on Ubuntu and macOS
3. **Python version matrix**: Tests on Python 3.9-3.13
4. **Wheel building**: Ensures packages can be built and distributed
### CI Commands
The CI uses the same commands as pre-commit to ensure consistency:
```bash
# Linting
ruff check .
# Format checking
ruff format --check .
```
Make sure your code passes these checks locally before pushing!
## 🔄 Pull Request Process
1. **Fork the repository** and create your branch from `main`:
```bash
git checkout -b feature/your-feature-name
```
2. **Make your changes**:
- Write clean, documented code
- Add tests for new functionality
- Update documentation as needed
3. **Run pre-commit checks**:
```bash
pre-commit run --all-files
```
4. **Test your changes**:
```bash
uv run pytest
```
5. **Commit with descriptive messages**:
```bash
git commit -m "feat: add new search algorithm"
```
Follow [Conventional Commits](https://www.conventionalcommits.org/):
- `feat:` for new features
- `fix:` for bug fixes
- `docs:` for documentation changes
- `test:` for test additions/changes
- `refactor:` for code refactoring
- `perf:` for performance improvements
6. **Push and create a pull request**:
- Provide a clear description of your changes
- Reference any related issues
- Include examples or screenshots if applicable
## 📚 Documentation
When adding new features or making significant changes:
1. Update relevant documentation in `/docs`
2. Add docstrings to new functions/classes
3. Update README.md if needed
4. Include usage examples
## 🤔 Getting Help
- **Discord**: Join our community for discussions
- **Issues**: Check existing issues or create a new one
- **Discussions**: For general questions and ideas
## 📄 License
By contributing, you agree that your contributions will be licensed under the same license as the project (MIT).
---
Thank you for contributing to LEANN! Every contribution, no matter how small, helps make the project better for everyone. 🌟

View File

@@ -1,22 +0,0 @@
# Release Guide
## Setup (One-time)
Add `PYPI_API_TOKEN` to GitHub Secrets:
1. Get token: https://pypi.org/manage/account/token/
2. Add to secrets: Settings → Secrets → Actions → `PYPI_API_TOKEN`
## Release (One-click)
1. Go to: https://github.com/yichuan-w/LEANN/actions/workflows/release-manual.yml
2. Click "Run workflow"
3. Enter version: `0.1.2`
4. Click green "Run workflow" button
That's it! The workflow will automatically:
- ✅ Update version in all packages
- ✅ Build all packages
- ✅ Publish to PyPI
- ✅ Create GitHub tag and release
Check progress: https://github.com/yichuan-w/LEANN/actions

View File

@@ -1,123 +0,0 @@
# Thinking Budget Feature Implementation
## Overview
This document describes the implementation of the **thinking budget** feature for LEANN, which allows users to control the computational effort for reasoning models like GPT-Oss:20b.
## Feature Description
The thinking budget feature provides three levels of computational effort for reasoning models:
- **`low`**: Fast responses, basic reasoning (default for simple queries)
- **`medium`**: Balanced speed and reasoning depth
- **`high`**: Maximum reasoning effort, best for complex analytical questions
## Implementation Details
### 1. Command Line Interface
Added `--thinking-budget` parameter to both CLI and RAG examples:
```bash
# LEANN CLI
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget high
# RAG Examples
python apps/email_rag.py --llm ollama --llm-model gpt-oss:20b --thinking-budget high
python apps/document_rag.py --llm openai --llm-model o3 --thinking-budget medium
```
### 2. LLM Backend Support
#### Ollama Backend (`packages/leann-core/src/leann/chat.py`)
```python
def ask(self, prompt: str, **kwargs) -> str:
# Handle thinking budget for reasoning models
options = kwargs.copy()
thinking_budget = kwargs.get("thinking_budget")
if thinking_budget:
options.pop("thinking_budget", None)
if thinking_budget in ["low", "medium", "high"]:
options["reasoning"] = {"effort": thinking_budget, "exclude": False}
```
**API Format**: Uses Ollama's `reasoning` parameter with `effort` and `exclude` fields.
#### OpenAI Backend (`packages/leann-core/src/leann/chat.py`)
```python
def ask(self, prompt: str, **kwargs) -> str:
# Handle thinking budget for reasoning models
thinking_budget = kwargs.get("thinking_budget")
if thinking_budget and thinking_budget in ["low", "medium", "high"]:
# Check if this is an o-series model
o_series_models = ["o3", "o3-mini", "o4-mini", "o1", "o3-pro", "o3-deep-research"]
if any(model in self.model for model in o_series_models):
params["reasoning_effort"] = thinking_budget
```
**API Format**: Uses OpenAI's `reasoning_effort` parameter for o-series models.
### 3. Parameter Propagation
The thinking budget parameter is properly propagated through the LEANN architecture:
1. **CLI** (`packages/leann-core/src/leann/cli.py`): Captures `--thinking-budget` argument
2. **Base RAG** (`apps/base_rag_example.py`): Adds parameter to argument parser
3. **LeannChat** (`packages/leann-core/src/leann/api.py`): Passes `llm_kwargs` to LLM
4. **LLM Interface**: Handles the parameter in backend-specific implementations
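
The same path can be exercised from Python. A minimal sketch (assuming, as the list above implies, that `llm_kwargs` passed to `LeannChat.ask()` reaches the backend unchanged):

```python
from leann.api import LeannChat

chat = LeannChat(
    index_path="my-index.leann",
    llm_config={"type": "ollama", "model": "gpt-oss:20b"},
)

# Programmatic equivalent of `--thinking-budget high`: the backend's ask()
# maps thinking_budget onto Ollama's `reasoning` / OpenAI's `reasoning_effort`.
response = chat.ask(
    "Summarize the key findings across these documents.",
    top_k=20,
    llm_kwargs={"thinking_budget": "high"},
)
print(response)
```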
## Files Modified
### Core Implementation
- `packages/leann-core/src/leann/chat.py`: Added thinking budget support to OllamaChat and OpenAIChat
- `packages/leann-core/src/leann/cli.py`: Added `--thinking-budget` argument
- `apps/base_rag_example.py`: Added thinking budget parameter to RAG examples
### Documentation
- `README.md`: Added thinking budget parameter to usage examples
- `docs/configuration-guide.md`: Added detailed documentation and usage guidelines
### Examples
- `examples/thinking_budget_demo.py`: Comprehensive demo script with usage examples
## Usage Examples
### Basic Usage
```bash
# High reasoning effort for complex questions
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget high
# Medium reasoning for balanced performance
leann ask my-index --llm openai --model gpt-4o --thinking-budget medium
# Low reasoning for fast responses
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget low
```
### RAG Examples
```bash
# Email RAG with high reasoning
python apps/email_rag.py --llm ollama --llm-model gpt-oss:20b --thinking-budget high
# Document RAG with medium reasoning
python apps/document_rag.py --llm openai --llm-model gpt-4o --thinking-budget medium
```
## Supported Models
### Ollama Models
- **GPT-Oss:20b**: Primary target model with reasoning capabilities
- **Other reasoning models**: Any Ollama model that supports the `reasoning` parameter
### OpenAI Models
- **o3, o3-mini, o4-mini, o1**: o-series reasoning models with `reasoning_effort` parameter
- **GPT-OSS models**: Models that support reasoning capabilities
## Testing
The implementation includes comprehensive testing:
- Parameter handling verification
- Backend-specific API format validation
- CLI argument parsing tests
- Integration with existing LEANN architecture

View File

@@ -1,98 +0,0 @@
"""
Comparison between Sentence Transformers and OpenAI embeddings
This example shows how different embedding models handle complex queries
and demonstrates the differences between local and API-based embeddings.
"""
import numpy as np
from leann.embedding_compute import compute_embeddings
# OpenAI API key should be set as environment variable
# export OPENAI_API_KEY="your-api-key-here"
# Test data
conference_text = "[Title]: COLING 2025 Conference\n[URL]: https://coling2025.org/"
browser_text = "[Title]: Browser Use Tool\n[URL]: https://github.com/browser-use"
# Two queries with same intent but different wording
query1 = "Tell me my browser history about some conference i often visit"
query2 = "browser history about conference I often visit"
texts = [query1, query2, conference_text, browser_text]
def cosine_similarity(a, b):
return np.dot(a, b) # Already normalized
def analyze_embeddings(embeddings, model_name):
print(f"\n=== {model_name} Results ===")
# Results for Query 1
sim1_conf = cosine_similarity(embeddings[0], embeddings[2])
sim1_browser = cosine_similarity(embeddings[0], embeddings[3])
print(f"Query 1: '{query1}'")
print(f" → Conference similarity: {sim1_conf:.4f} {'' if sim1_conf > sim1_browser else ''}")
print(
f" → Browser similarity: {sim1_browser:.4f} {'' if sim1_browser > sim1_conf else ''}"
)
print(f" Winner: {'Conference' if sim1_conf > sim1_browser else 'Browser'}")
# Results for Query 2
sim2_conf = cosine_similarity(embeddings[1], embeddings[2])
sim2_browser = cosine_similarity(embeddings[1], embeddings[3])
print(f"\nQuery 2: '{query2}'")
print(f" → Conference similarity: {sim2_conf:.4f} {'' if sim2_conf > sim2_browser else ''}")
print(
f" → Browser similarity: {sim2_browser:.4f} {'' if sim2_browser > sim2_conf else ''}"
)
print(f" Winner: {'Conference' if sim2_conf > sim2_browser else 'Browser'}")
# Show the impact
print("\n=== Impact Analysis ===")
print(f"Conference similarity change: {sim2_conf - sim1_conf:+.4f}")
print(f"Browser similarity change: {sim2_browser - sim1_browser:+.4f}")
if sim1_conf > sim1_browser and sim2_browser > sim2_conf:
print("❌ FLIP: Adding 'browser history' flips winner from Conference to Browser!")
elif sim1_conf > sim1_browser and sim2_conf > sim2_browser:
print("✅ STABLE: Conference remains winner in both queries")
elif sim1_browser > sim1_conf and sim2_browser > sim2_conf:
print("✅ STABLE: Browser remains winner in both queries")
else:
print("🔄 MIXED: Results vary between queries")
return {
"query1_conf": sim1_conf,
"query1_browser": sim1_browser,
"query2_conf": sim2_conf,
"query2_browser": sim2_browser,
}
# Test Sentence Transformers
print("Testing Sentence Transformers (facebook/contriever)...")
try:
st_embeddings = compute_embeddings(texts, "facebook/contriever", mode="sentence-transformers")
st_results = analyze_embeddings(st_embeddings, "Sentence Transformers (facebook/contriever)")
except Exception as e:
print(f"❌ Sentence Transformers failed: {e}")
st_results = None
# Test OpenAI
print("\n" + "=" * 60)
print("Testing OpenAI (text-embedding-3-small)...")
try:
openai_embeddings = compute_embeddings(texts, "text-embedding-3-small", mode="openai")
openai_results = analyze_embeddings(openai_embeddings, "OpenAI (text-embedding-3-small)")
except Exception as e:
print(f"❌ OpenAI failed: {e}")
openai_results = None
# Compare results
if st_results and openai_results:
print("\n" + "=" * 60)
print("=== COMPARISON SUMMARY ===")

View File

@@ -1,300 +0,0 @@
# LEANN Configuration Guide
This guide helps you optimize LEANN for different use cases and understand the trade-offs between various configuration options.
## Getting Started: Simple is Better
When first trying LEANN, start with a small dataset to quickly validate your approach:
**For document RAG**: The default `data/` directory works perfectly - it includes two AI research papers, the novel *Pride and Prejudice*, and a technical report
```bash
python -m apps.document_rag --query "What techniques does LEANN use?"
```
**For other data sources**: Limit the dataset size for quick testing
```bash
# WeChat: Test with recent messages only
python -m apps.wechat_rag --max-items 100 --query "What did we discuss about the project timeline?"
# Browser history: Last few days
python -m apps.browser_rag --max-items 500 --query "Find documentation about vector databases"
# Email: Recent inbox
python -m apps.email_rag --max-items 200 --query "Who sent updates about the deployment status?"
```
Once validated, scale up gradually:
- 100 documents → 1,000 → 10,000 → full dataset (`--max-items -1`)
- This helps identify issues early before committing to long processing times
## Embedding Model Selection: Understanding the Trade-offs
Based on our experience developing LEANN, embedding models fall into three categories:
### Small Models (< 100M parameters)
**Example**: `sentence-transformers/all-MiniLM-L6-v2` (22M params)
- **Pros**: Lightweight, fast for both indexing and inference
- **Cons**: Lower semantic understanding, may miss nuanced relationships
- **Use when**: Speed is critical, handling simple queries, interactive mode, or just experimenting with LEANN. If time is not a constraint, consider using a larger/better embedding model
### Medium Models (100M-500M parameters)
**Example**: `facebook/contriever` (110M params), `BAAI/bge-base-en-v1.5` (110M params)
- **Pros**: Balanced performance, good multilingual support, reasonable speed
- **Cons**: Requires more compute than small models
- **Use when**: Need quality results without extreme compute requirements, general-purpose RAG applications
### Large Models (500M+ parameters)
**Example**: `Qwen/Qwen3-Embedding-0.6B` (600M params), `intfloat/multilingual-e5-large` (560M params)
- **Pros**: Best semantic understanding, captures complex relationships, excellent multilingual support. **Qwen3-Embedding-0.6B achieves nearly OpenAI API performance!**
- **Cons**: Slower inference, longer index build times
- **Use when**: Quality is paramount and you have sufficient compute resources. **Highly recommended** for production use
### Quick Start: Cloud and Local Embedding Options
**OpenAI Embeddings (Fastest Setup)**
For immediate testing without local model downloads:
```bash
# Set OpenAI embeddings (requires OPENAI_API_KEY)
--embedding-mode openai --embedding-model text-embedding-3-small
```
**Ollama Embeddings (Privacy-Focused)**
For local embeddings with complete privacy:
```bash
# First, pull an embedding model
ollama pull nomic-embed-text
# Use Ollama embeddings
--embedding-mode ollama --embedding-model nomic-embed-text
```
<details>
<summary><strong>Cloud vs Local Trade-offs</strong></summary>
**OpenAI Embeddings** (`text-embedding-3-small/large`)
- **Pros**: No local compute needed, consistently fast, high quality
- **Cons**: Requires API key, costs money, data leaves your system, [known limitations with certain languages](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
- **When to use**: Prototyping, non-sensitive data, need immediate results
**Local Embeddings**
- **Pros**: Complete privacy, no ongoing costs, full control, can sometimes outperform OpenAI embeddings
- **Cons**: Slower than cloud APIs, requires local compute resources
- **When to use**: Production systems, sensitive data, cost-sensitive applications
</details>
## Index Selection: Matching Your Scale
### HNSW (Hierarchical Navigable Small World)
**Best for**: Small to medium datasets (< 10M vectors) - **the default, recommended when storage must be kept to a minimum**
- Full recomputation required
- High memory usage during build phase
- Excellent recall (95%+)
```bash
# Optimal for most use cases
--backend-name hnsw --graph-degree 32 --build-complexity 64
```
### DiskANN
**Best for**: Performance-critical applications and large datasets - **Production-ready with automatic graph partitioning**
**How it works:**
- **Product Quantization (PQ) + Real-time Reranking**: Uses compressed PQ codes for fast graph traversal, then recomputes exact embeddings for final candidates
- **Automatic Graph Partitioning**: When `is_recompute=True`, automatically partitions large indices and safely removes redundant files to save storage
- **Superior Speed-Accuracy Trade-off**: Faster search than HNSW while maintaining high accuracy
**Trade-offs compared to HNSW:**
- **Faster search latency** (typically 2-8x speedup)
- **Better scaling** for large datasets
- **Smart storage management** with automatic partitioning
- **Better graph locality** with `--ldg-times` parameter for SSD optimization
- ⚠️ **Slightly larger index size** due to PQ tables and graph metadata
```bash
# Recommended for most use cases
--backend-name diskann --graph-degree 32 --build-complexity 64
# For large-scale deployments
--backend-name diskann --graph-degree 64 --build-complexity 128
```
**Performance Benchmark**: Run `python benchmarks/diskann_vs_hnsw_speed_comparison.py` to compare DiskANN and HNSW on your system.
## LLM Selection: Engine and Model Comparison
### LLM Engines
**OpenAI** (`--llm openai`)
- **Pros**: Best quality, consistent performance, no local resources needed
- **Cons**: Costs money ($0.15-2.5 per million tokens), requires internet, data privacy concerns
- **Models**: `gpt-4o-mini` (fast, cheap), `gpt-4o` (best quality), `o3` (reasoning), `o3-mini` (reasoning, cheaper)
- **Thinking Budget**: Use `--thinking-budget low/medium/high` for o-series reasoning models (o3, o3-mini, o4-mini)
- **Note**: Our current default, but we recommend switching to Ollama for most use cases
**Ollama** (`--llm ollama`)
- **Pros**: Fully local, free, privacy-preserving, good model variety
- **Cons**: Requires local GPU/CPU resources, slower than cloud APIs, and requires installing the [Ollama app](https://github.com/ollama/ollama?tab=readme-ov-file#ollama) and pre-downloading models with `ollama pull`
- **Models**: `qwen3:0.6b` (ultra-fast), `qwen3:1.7b` (balanced), `qwen3:4b` (good quality), `qwen3:7b` (high quality), `deepseek-r1:1.5b` (reasoning)
- **Thinking Budget**: Use `--thinking-budget low/medium/high` for reasoning models like GPT-Oss:20b
**HuggingFace** (`--llm hf`)
- **Pros**: Free tier available, huge model selection, direct model loading (vs Ollama's server-based approach)
- **Cons**: More complex initial setup
- **Models**: `Qwen/Qwen3-1.7B-FP8`
## Parameter Tuning Guide
### Search Complexity Parameters
**`--build-complexity`** (index building)
- Controls thoroughness during index construction
- Higher = better recall but slower build
- Recommendations:
- 32: Quick prototyping
- 64: Balanced (default)
- 128: Production systems
- 256: Maximum quality
**`--search-complexity`** (query time)
- Controls search thoroughness
- Higher = better results but slower
- Recommendations:
- 16: Fast/Interactive search
- 32: High quality with diversity
- 64+: Maximum accuracy
### Top-K Selection
**`--top-k`** (number of retrieved chunks)
- More chunks = better context but slower LLM processing
- Should always be smaller than `--search-complexity`
- Guidelines:
- 10-20: General questions (default: 20)
- 30+: Complex multi-hop reasoning requiring comprehensive context
**Trade-off formula**:
- Retrieval time ∝ log(n) × search_complexity
- LLM processing time ∝ top_k × chunk_size
- Total context = top_k × chunk_size tokens
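
As a purely illustrative example: with `--top-k 20` and 256-token chunks, the LLM receives roughly 20 × 256 ≈ 5,000 tokens of retrieved context per query. Doubling `top-k` roughly doubles that context and the LLM processing time, while retrieval time grows only with `search_complexity`.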
### Thinking Budget for Reasoning Models
**`--thinking-budget`** (reasoning effort level)
- Controls the computational effort for reasoning models
- Options: `low`, `medium`, `high`
- Guidelines:
- `low`: Fast responses, basic reasoning (default for simple queries)
- `medium`: Balanced speed and reasoning depth
- `high`: Maximum reasoning effort, best for complex analytical questions
- **Supported Models**:
- **Ollama**: `gpt-oss:20b`, `gpt-oss:120b`
- **OpenAI**: `o3`, `o3-mini`, `o4-mini`, `o1` (o-series reasoning models)
- **Note**: Models without reasoning support will show a warning and proceed without reasoning parameters
- **Example**: `--thinking-budget high` for complex analytical questions
**📖 For detailed usage examples and implementation details, check out [Thinking Budget Documentation](THINKING_BUDGET_FEATURE.md)**
**💡 Quick Examples:**
```bash
# OpenAI o-series reasoning model
python apps/document_rag.py --query "What are the main techniques LEANN explores?" \
--index-dir hnswbuild --backend hnsw \
--llm openai --llm-model o3 --thinking-budget medium
# Ollama reasoning model
python apps/document_rag.py --query "What are the main techniques LEANN explores?" \
--index-dir hnswbuild --backend hnsw \
--llm ollama --llm-model gpt-oss:20b --thinking-budget high
```
### Graph Degree (HNSW/DiskANN)
**`--graph-degree`**
- Number of connections per node in the graph
- Higher = better recall but more memory
- HNSW: 16-32 (default: 32)
- DiskANN: 32-128 (default: 64)
## Performance Optimization Checklist
### If Embedding is Too Slow
1. **Switch to smaller model**:
```bash
# From large model
--embedding-model Qwen/Qwen3-Embedding-0.6B
# To small model
--embedding-model sentence-transformers/all-MiniLM-L6-v2
```
2. **Limit dataset size for testing**:
```bash
--max-items 1000 # Process first 1k items only
```
3. **Use MLX on Apple Silicon** (optional optimization):
```bash
--embedding-mode mlx --embedding-model mlx-community/Qwen3-Embedding-0.6B-8bit
```
MLX may not be the best choice here: in our tests it offered only about a 1.3x speedup over the Hugging Face backend, so Ollama may be a better option for embedding generation.
4. **Use Ollama**
```bash
--embedding-mode ollama --embedding-model nomic-embed-text
```
To discover additional embedding models in Ollama, see https://ollama.com/search?c=embedding or read more at https://ollama.com/blog/embedding-models. Be sure to pick a model size that works well for your hardware.
### If Search Quality is Poor
1. **Increase retrieval count**:
```bash
--top-k 30 # Retrieve more candidates
```
2. **Upgrade embedding model**:
```bash
# For English
--embedding-model BAAI/bge-base-en-v1.5
# For multilingual
--embedding-model intfloat/multilingual-e5-large
```
## Understanding the Trade-offs
Every configuration choice involves trade-offs:
| Factor | Small/Fast | Large/Quality |
|--------|------------|---------------|
| Embedding Model | `all-MiniLM-L6-v2` | `Qwen/Qwen3-Embedding-0.6B` |
| Chunk Size | 512 tokens | 128 tokens |
| Index Type | HNSW | DiskANN |
| LLM | `qwen3:1.7b` | `gpt-4o` |
The key is finding the right balance for your specific use case. Start small and simple, measure performance, then scale up only where needed.
## Deep Dive: Critical Configuration Decisions
### When to Disable Recomputation
LEANN's recomputation feature provides exact distance calculations but can be disabled for extreme QPS requirements:
```bash
--no-recompute # Disable selective recomputation
```
**Trade-offs**:
- **With recomputation** (default): Exact distances, best quality, higher latency, minimal storage (only stores metadata, recomputes embeddings on-demand)
- **Without recomputation**: Must store full embeddings, significantly higher memory and storage usage (10-100x more), but faster search
**Disable when**:
- You have abundant storage and memory
- Need extremely low latency (< 100ms)
- Running a read-heavy workload where storage cost is acceptable
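
If you build indices from Python rather than the CLI, the same switch can be expressed on the builder. A minimal sketch, assuming the `is_recompute` flag accepted by `LeannBuilder` (as used in the benchmark script above) mirrors the `--no-recompute` CLI option:

```python
from leann.api import LeannBuilder

# Keep full embeddings in the index: more storage, lower query latency.
builder = LeannBuilder(
    backend_name="hnsw",
    embedding_model="facebook/contriever",
    is_recompute=False,  # assumed equivalent of --no-recompute
)
builder.add_text("Example passage whose embedding is stored rather than recomputed.")
builder.build_index("no_recompute_demo.leann")
```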
## Further Reading
- [Lessons Learned Developing LEANN](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
- [LEANN Technical Paper](https://arxiv.org/abs/2506.08276)
- [DiskANN Original Paper](https://papers.nips.cc/paper/2019/file/09853c7fb1d3f8ee67a61b6bf4a7f8e6-Paper.pdf)
- [SSD-based Graph Partitioning](https://github.com/SonglinLife/SSD_BASED_PLAN)

View File

@@ -1,10 +0,0 @@
# FAQ
## 1. My building time seems long
You can speed up the process by using a lightweight embedding model. Add this to your arguments:
```bash
--embedding-model sentence-transformers/all-MiniLM-L6-v2
```
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)

View File

@@ -1,22 +0,0 @@
# ✨ Detailed Features
## 🔥 Core Features
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage by computing embeddings on the fly, using optimized ZMQ servers, an overlapped and batched search paradigm, and a highly optimized embedding engine
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
- **🏗️ Pluggable Backends** - HNSW/FAISS (default), with optional DiskANN for large-scale deployments
## 🛠️ Technical Highlights
- **🔄 Recompute Mode** - Highest accuracy scenarios while eliminating vector storage overhead
- **⚡ Zero-copy Operations** - Minimize IPC overhead by transferring distances instead of embeddings
- **🚀 High-throughput Embedding Pipeline** - Optimized batched processing for maximum efficiency
- **🎯 Two-level Search** - Novel coarse-to-fine search overlap for accelerated query processing (optional)
- **💾 Memory-mapped Indices** - Fast startup with raw text mapping to reduce memory overhead
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](../examples/mlx_demo.py))
## 🎨 Developer Experience
- **Simple Python API** - Get started in minutes
- **Extensible backend system** - Easy to add new algorithms
- **Comprehensive examples** - From basic usage to production deployment

View File

@@ -1,75 +0,0 @@
# Normalized Embeddings Support in LEANN
LEANN now automatically detects normalized embedding models and sets the appropriate distance metric for optimal performance.
## What are Normalized Embeddings?
Normalized embeddings are vectors with L2 norm = 1 (unit vectors). These embeddings are optimized for cosine similarity rather than Maximum Inner Product Search (MIPS).
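
A quick way to tell whether a model produces normalized vectors is to inspect the L2 norm of a few embeddings; a minimal sketch using the `compute_embeddings` helper from this repo:

```python
import numpy as np
from leann.embedding_compute import compute_embeddings

vecs = compute_embeddings(
    ["a short test sentence"], "facebook/contriever", mode="sentence-transformers"
)
# Norms very close to 1.0 indicate a normalized model (use cosine);
# anything else indicates unnormalized embeddings (MIPS is fine).
print(np.linalg.norm(vecs, axis=1))
```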
## Automatic Detection
When you create a `LeannBuilder` instance with a normalized embedding model, LEANN will:
1. **Automatically set `distance_metric="cosine"`** if not specified
2. **Show a warning** if you manually specify a different distance metric
3. **Provide optimal search performance** with the correct metric
## Supported Normalized Embedding Models
### OpenAI
All OpenAI text embedding models are normalized:
- `text-embedding-ada-002`
- `text-embedding-3-small`
- `text-embedding-3-large`
### Voyage AI
All Voyage AI embedding models are normalized:
- `voyage-2`
- `voyage-3`
- `voyage-large-2`
- `voyage-multilingual-2`
- `voyage-code-2`
### Cohere
All Cohere embedding models are normalized:
- `embed-english-v3.0`
- `embed-multilingual-v3.0`
- `embed-english-light-v3.0`
- `embed-multilingual-light-v3.0`
## Example Usage
```python
from leann.api import LeannBuilder
# Automatic detection - will use cosine distance
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="text-embedding-3-small",
embedding_mode="openai"
)
# Warning: Detected normalized embeddings model 'text-embedding-3-small'...
# Automatically setting distance_metric='cosine'
# Manual override (not recommended)
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="text-embedding-3-small",
embedding_mode="openai",
distance_metric="mips" # Will show warning
)
# Warning: Using 'mips' distance metric with normalized embeddings...
```
## Non-Normalized Embeddings
Models like `facebook/contriever` and other sentence-transformers models that are not normalized will continue to use MIPS by default, which is optimal for them.
## Why This Matters
Using the wrong distance metric with normalized embeddings can lead to:
- **Poor search quality** due to HNSW's early termination with narrow score ranges
- **Incorrect ranking** of search results
- **Suboptimal performance** compared to using the correct metric
For more details on why this happens, see our analysis in the [embedding detection code](../packages/leann-core/src/leann/api.py) which automatically handles normalized embeddings and MIPS distance metric issues.

View File

@@ -1,21 +0,0 @@
# 📈 Roadmap
## 🎯 Q2 2025
- [X] HNSW backend integration
- [X] DiskANN backend with MIPS/L2/Cosine support
- [X] Real-time embedding pipeline
- [X] Memory-efficient graph pruning
## 🚀 Q3 2025
- [ ] Advanced caching strategies
- [ ] Add contextual-retrieval https://www.anthropic.com/news/contextual-retrieval
- [ ] Add sleep-time compute and a summarization agent to summarize files on your computer
- [ ] Add OpenAI recompute API
## 🌟 Q4 2025
- [ ] Integration with LangChain/LlamaIndex
- [ ] Visual similarity search
- [ ] Query rewriting, reranking, and expansion

View File

@@ -1,88 +0,0 @@
"""
Simple demo showing basic leann usage
Run: uv run python examples/basic_demo.py
"""
import argparse
from leann import LeannBuilder, LeannChat, LeannSearcher
def main():
parser = argparse.ArgumentParser(
description="Simple demo of Leann with selectable embedding models."
)
parser.add_argument(
"--embedding_model",
type=str,
default="sentence-transformers/all-mpnet-base-v2",
help="The embedding model to use, e.g., 'sentence-transformers/all-mpnet-base-v2' or 'text-embedding-ada-002'.",
)
args = parser.parse_args()
print(f"=== Leann Simple Demo with {args.embedding_model} ===")
print()
# Sample knowledge base
chunks = [
"Machine learning is a subset of artificial intelligence that enables computers to learn without being explicitly programmed.",
"Deep learning uses neural networks with multiple layers to process data and make decisions.",
"Natural language processing helps computers understand and generate human language.",
"Computer vision enables machines to interpret and understand visual information from images and videos.",
"Reinforcement learning teaches agents to make decisions by receiving rewards or penalties for their actions.",
"Data science combines statistics, programming, and domain expertise to extract insights from data.",
"Big data refers to extremely large datasets that require special tools and techniques to process.",
"Cloud computing provides on-demand access to computing resources over the internet.",
]
print("1. Building index (no embeddings stored)...")
builder = LeannBuilder(
embedding_model=args.embedding_model,
backend_name="hnsw",
)
for chunk in chunks:
builder.add_text(chunk)
builder.build_index("demo_knowledge.leann")
print()
print("2. Searching with real-time embeddings...")
searcher = LeannSearcher("demo_knowledge.leann")
queries = [
"What is machine learning?",
"How does neural network work?",
"Tell me about data processing",
]
for query in queries:
print(f"Query: {query}")
results = searcher.search(query, top_k=2)
for i, result in enumerate(results, 1):
print(f" {i}. Score: {result.score:.3f}")
print(f" Text: {result.text[:100]}...")
print()
print("3. Interactive chat demo:")
print(" (Note: Requires OpenAI API key for real responses)")
chat = LeannChat("demo_knowledge.leann")
# Demo questions
demo_questions: list[str] = [
"What is the difference between machine learning and deep learning?",
"How is data science related to big data?",
]
for question in demo_questions:
print(f" Q: {question}")
response = chat.ask(question)
print(f" A: {response}")
print()
print("Demo completed! Try running:")
print(" uv run python apps/document_rag.py")
if __name__ == "__main__":
main()

View File

@@ -3,15 +3,14 @@
Memory comparison between Faiss HNSW and LEANN HNSW backend
"""
- import gc
import logging
import os
- import subprocess
import sys
import time
- from pathlib import Path
import psutil
+ import gc
+ import subprocess
+ from pathlib import Path
from llama_index.core.node_parser import SentenceSplitter
# Setup logging
@@ -62,7 +61,7 @@ def test_faiss_hnsw():
try:
result = subprocess.run(
- [sys.executable, "benchmarks/faiss_only.py"],
+ [sys.executable, "examples/faiss_only.py"],
capture_output=True,
text=True,
timeout=300,
@@ -84,7 +83,9 @@ def test_faiss_hnsw():
for line in lines:
if "Peak Memory:" in line:
- peak_memory = float(line.split("Peak Memory:")[1].split("MB")[0].strip())
+ peak_memory = float(
+     line.split("Peak Memory:")[1].split("MB")[0].strip()
+ )
return {"peak_memory": peak_memory}
@@ -110,12 +111,13 @@ def test_leann_hnsw():
tracker.checkpoint("After imports")
- from leann.api import LeannBuilder
from llama_index.core import SimpleDirectoryReader
+ from leann.api import LeannBuilder, LeannSearcher
# Load and parse documents
documents = SimpleDirectoryReader(
- "data",
+ "examples/data",
recursive=True,
encoding="utf-8",
required_exts=[".pdf", ".txt", ".md"],
@@ -133,7 +135,6 @@ def test_leann_hnsw():
nodes = node_parser.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
- print(f"Total number of chunks: {len(all_texts)}")
tracker.checkpoint("After text chunking")
@@ -195,14 +196,16 @@ def test_leann_hnsw():
runtime_start_mem = get_memory_usage()
print(f"Before load memory: {runtime_start_mem:.1f} MB")
tracker.checkpoint("Before load memory")
# Load searcher
searcher = LeannSearcher(index_path)
tracker.checkpoint("After searcher loading")
print("Running search queries...")
queries = [
- "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
+ "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面任务令一般在什么城市颁发",
"What is LEANN and how does it work?",
"华为诺亚方舟实验室的主要研究内容",
]
@@ -300,15 +303,21 @@ def main():
print("\nLEANN vs Faiss Performance:")
memory_saving = faiss_results["peak_memory"] - leann_results["peak_memory"]
- print(f" Search Memory: {memory_ratio:.1f}x less ({memory_saving:.1f} MB saved)")
+ print(
+     f" Search Memory: {memory_ratio:.1f}x less ({memory_saving:.1f} MB saved)"
+ )
# Storage comparison
if leann_storage_size > faiss_storage_size:
storage_ratio = leann_storage_size / faiss_storage_size
- print(f" Storage Size: {storage_ratio:.1f}x larger (LEANN uses more storage)")
+ print(
+     f" Storage Size: {storage_ratio:.1f}x larger (LEANN uses more storage)"
+ )
elif faiss_storage_size > leann_storage_size:
storage_ratio = faiss_storage_size / leann_storage_size
- print(f" Storage Size: {storage_ratio:.1f}x smaller (LEANN uses less storage)")
+ print(
+     f" Storage Size: {storage_ratio:.1f}x smaller (LEANN uses less storage)"
+ )
else:
print(" Storage Size: similar")
else:
View File

@@ -0,0 +1,122 @@
import os
import email
from pathlib import Path
from typing import Any, List, Optional
from llama_index.core import Document
from llama_index.core.readers.base import BaseReader
def find_all_messages_directories(root: Optional[str] = None) -> List[Path]:
"""
Recursively find all 'Messages' directories under the given root.
Returns a list of Path objects.
"""
if root is None:
# Auto-detect user's mail path
home_dir = os.path.expanduser("~")
root = os.path.join(home_dir, "Library", "Mail")
messages_dirs = []
for dirpath, dirnames, filenames in os.walk(root):
if os.path.basename(dirpath) == "Messages":
messages_dirs.append(Path(dirpath))
return messages_dirs
class EmlxReader(BaseReader):
"""
Apple Mail .emlx file reader with embedded metadata.
Reads individual .emlx files from Apple Mail's storage format.
"""
def __init__(self, include_html: bool = False) -> None:
"""
Initialize.
Args:
include_html: Whether to include HTML content in the email body (default: False)
"""
self.include_html = include_html
def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]:
"""
Load data from the input directory containing .emlx files.
Args:
input_dir: Directory containing .emlx files
**load_kwargs:
max_count (int): Maximum amount of messages to read.
"""
docs: List[Document] = []
max_count = load_kwargs.get('max_count', 1000)
count = 0
# Walk through the directory recursively
for dirpath, dirnames, filenames in os.walk(input_dir):
# Skip hidden directories
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
for filename in filenames:
if count >= max_count:
break
if filename.endswith(".emlx"):
filepath = os.path.join(dirpath, filename)
try:
# Read the .emlx file
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read()
# .emlx files have a length prefix followed by the email content
# The first line contains the length, followed by the email
lines = content.split('\n', 1)
if len(lines) >= 2:
email_content = lines[1]
# Parse the email using Python's email module
try:
msg = email.message_from_string(email_content)
# Extract email metadata
subject = msg.get('Subject', 'No Subject')
from_addr = msg.get('From', 'Unknown')
to_addr = msg.get('To', 'Unknown')
date = msg.get('Date', 'Unknown')
# Extract email body
body = ""
if msg.is_multipart():
for part in msg.walk():
if part.get_content_type() == "text/plain" or part.get_content_type() == "text/html":
if part.get_content_type() == "text/html" and not self.include_html:
continue
body += part.get_payload(decode=True).decode('utf-8', errors='ignore')
# break
else:
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore')
# Create document content with metadata embedded in text
doc_content = f"""
[File]: {filename}
[From]: {from_addr}
[To]: {to_addr}
[Subject]: {subject}
[Date]: {date}
[EMAIL BODY Start]:
{body}
"""
# No separate metadata - everything is in the text
doc = Document(text=doc_content, metadata={})
docs.append(doc)
count += 1
except Exception as e:
print(f"Error parsing email from {filepath}: {e}")
continue
except Exception as e:
print(f"Error reading file {filepath}: {e}")
continue
print(f"Loaded {len(docs)} email documents")
return docs
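# Usage sketch (assumes the default macOS Apple Mail layout under ~/Library/Mail):
if __name__ == "__main__":
    reader = EmlxReader(include_html=False)
    for messages_dir in find_all_messages_directories()[:1]:
        emails = reader.load_data(str(messages_dir), max_count=50)
        print(f"Loaded {len(emails)} emails from {messages_dir}")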

View File

@@ -0,0 +1,192 @@
"""
Mbox parser.
Contains simple parser for mbox files.
"""
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional
from fsspec import AbstractFileSystem
from llama_index.core.readers.base import BaseReader
from llama_index.core.schema import Document
logger = logging.getLogger(__name__)
class MboxReader(BaseReader):
"""
Mbox parser.
Extract messages from mailbox files.
Returns string including date, subject, sender, receiver and
content for each message.
"""
DEFAULT_MESSAGE_FORMAT: str = (
"Date: {_date}\n"
"From: {_from}\n"
"To: {_to}\n"
"Subject: {_subject}\n"
"Content: {_content}"
)
def __init__(
self,
*args: Any,
max_count: int = 0,
message_format: str = DEFAULT_MESSAGE_FORMAT,
**kwargs: Any,
) -> None:
"""Init params."""
try:
from bs4 import BeautifulSoup # noqa
except ImportError:
raise ImportError(
"`beautifulsoup4` package not found: `pip install beautifulsoup4`"
)
super().__init__(*args, **kwargs)
self.max_count = max_count
self.message_format = message_format
def load_data(
self,
file: Path,
extra_info: Optional[Dict] = None,
fs: Optional[AbstractFileSystem] = None,
) -> List[Document]:
"""Parse file into string."""
# Import required libraries
import mailbox
from email.parser import BytesParser
from email.policy import default
from bs4 import BeautifulSoup
if fs:
logger.warning(
"fs was specified but MboxReader doesn't support loading "
"from fsspec filesystems. Will load from local filesystem instead."
)
i = 0
results: List[str] = []
# Load file using mailbox
bytes_parser = BytesParser(policy=default).parse
mbox = mailbox.mbox(file, factory=bytes_parser) # type: ignore
# Iterate through all messages
for _, _msg in enumerate(mbox):
try:
msg: mailbox.mboxMessage = _msg
# Parse multipart messages
if msg.is_multipart():
for part in msg.walk():
ctype = part.get_content_type()
cdispo = str(part.get("Content-Disposition"))
if "attachment" in cdispo:
print(f"Attachment found: {part.get_filename()}")
if ctype == "text/plain" and "attachment" not in cdispo:
content = part.get_payload(decode=True) # decode
break
# Get plain message payload for non-multipart messages
else:
content = msg.get_payload(decode=True)
# Parse message HTML content and remove unneeded whitespace
soup = BeautifulSoup(content)
stripped_content = " ".join(soup.get_text().split())
# Format message to include date, sender, receiver and subject
msg_string = self.message_format.format(
_date=msg["date"],
_from=msg["from"],
_to=msg["to"],
_subject=msg["subject"],
_content=stripped_content,
)
# Add message string to results
results.append(msg_string)
except Exception as e:
logger.warning(f"Failed to parse message:\n{_msg}\n with exception {e}")
# Increment counter and return if max count is met
i += 1
if self.max_count > 0 and i >= self.max_count:
break
return [Document(text=result, metadata=extra_info or {}) for result in results]
class EmlxMboxReader(MboxReader):
"""
EmlxMboxReader - Modified MboxReader that handles directories of .emlx files.
Extends MboxReader to work with Apple Mail's .emlx format by:
1. Reading .emlx files from a directory
2. Converting them to mbox format in memory
3. Using the parent MboxReader's parsing logic
"""
def load_data(
self,
directory: Path,
extra_info: Optional[Dict] = None,
fs: Optional[AbstractFileSystem] = None,
) -> List[Document]:
"""Parse .emlx files from directory into strings using MboxReader logic."""
import tempfile
import os
if fs:
logger.warning(
"fs was specified but EmlxMboxReader doesn't support loading "
"from fsspec filesystems. Will load from local filesystem instead."
)
# Find all .emlx files in the directory
emlx_files = list(directory.glob("*.emlx"))
logger.info(f"Found {len(emlx_files)} .emlx files in {directory}")
if not emlx_files:
logger.warning(f"No .emlx files found in {directory}")
return []
# Create a temporary mbox file
with tempfile.NamedTemporaryFile(mode='w', suffix='.mbox', delete=False) as temp_mbox:
temp_mbox_path = temp_mbox.name
# Convert .emlx files to mbox format
for emlx_file in emlx_files:
try:
# Read the .emlx file
with open(emlx_file, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read()
# .emlx format: first line is length, rest is email content
lines = content.split('\n', 1)
if len(lines) >= 2:
email_content = lines[1] # Skip the length line
# Write to mbox format (each message starts with "From " and ends with blank line)
temp_mbox.write(f"From {emlx_file.name} {email_content}\n\n")
except Exception as e:
logger.warning(f"Failed to process {emlx_file}: {e}")
continue
# Close the temporary file so MboxReader can read it
temp_mbox.close()
try:
# Use the parent MboxReader's logic to parse the mbox file
return super().load_data(Path(temp_mbox_path), extra_info, fs)
finally:
# Clean up temporary file
try:
os.unlink(temp_mbox_path)
except:
pass
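
# --- Illustrative usage sketch (not part of the readers above) ---
# The .mbox path and .emlx directory below are placeholders; MboxReader expects a
# single mbox file, while EmlxMboxReader expects a directory containing .emlx files.
if __name__ == "__main__":
    mbox_docs = MboxReader(max_count=10).load_data(Path("example.mbox"))
    emlx_docs = EmlxMboxReader(max_count=10).load_data(Path("./emlx_messages"))
    print(f"Parsed {len(mbox_docs)} mbox messages and {len(emlx_docs)} emlx messages")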

View File

@@ -1,11 +1,11 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
"""Test only Faiss HNSW""" """Test only Faiss HNSW"""
import os
import sys import sys
import time import time
import psutil import psutil
import gc
import os
def get_memory_usage(): def get_memory_usage():
@@ -37,20 +37,20 @@ def main():
import faiss import faiss
except ImportError: except ImportError:
print("Faiss is not installed.") print("Faiss is not installed.")
print( print("Please install it with `uv pip install faiss-cpu`")
"Please install it with `uv pip install faiss-cpu` and you can then run this script again"
)
sys.exit(1) sys.exit(1)
from llama_index.core import ( from llama_index.core import (
Settings,
SimpleDirectoryReader, SimpleDirectoryReader,
StorageContext,
VectorStoreIndex, VectorStoreIndex,
StorageContext,
Settings,
node_parser,
Document,
) )
from llama_index.core.node_parser import SentenceSplitter from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.faiss import FaissVectorStore from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
tracker = MemoryTracker("Faiss HNSW") tracker = MemoryTracker("Faiss HNSW")
tracker.checkpoint("Initial") tracker.checkpoint("Initial")
@@ -65,7 +65,7 @@ def main():
tracker.checkpoint("After Faiss index creation") tracker.checkpoint("After Faiss index creation")
documents = SimpleDirectoryReader( documents = SimpleDirectoryReader(
"data", "examples/data",
recursive=True, recursive=True,
encoding="utf-8", encoding="utf-8",
required_exts=[".pdf", ".txt", ".md"], required_exts=[".pdf", ".txt", ".md"],
@@ -90,9 +90,8 @@ def main():
vector_store=vector_store, persist_dir="./storage_faiss" vector_store=vector_store, persist_dir="./storage_faiss"
) )
from llama_index.core import load_index_from_storage from llama_index.core import load_index_from_storage
index = load_index_from_storage(storage_context=storage_context) index = load_index_from_storage(storage_context=storage_context)
print("Index loaded from ./storage_faiss") print(f"Index loaded from ./storage_faiss")
tracker.checkpoint("After loading existing index") tracker.checkpoint("After loading existing index")
index_loaded = True index_loaded = True
except Exception as e: except Exception as e:
@@ -100,18 +99,19 @@ def main():
print("Cleaning up corrupted index and building new one...") print("Cleaning up corrupted index and building new one...")
# Clean up corrupted index # Clean up corrupted index
import shutil import shutil
if os.path.exists("./storage_faiss"): if os.path.exists("./storage_faiss"):
shutil.rmtree("./storage_faiss") shutil.rmtree("./storage_faiss")
if not index_loaded: if not index_loaded:
print("Building new Faiss HNSW index...") print("Building new Faiss HNSW index...")
# Use the correct Faiss building pattern from the example # Use the correct Faiss building pattern from the example
vector_store = FaissVectorStore(faiss_index=faiss_index) vector_store = FaissVectorStore(faiss_index=faiss_index)
storage_context = StorageContext.from_defaults(vector_store=vector_store) storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = VectorStoreIndex.from_documents( index = VectorStoreIndex.from_documents(
documents, storage_context=storage_context, transformations=[node_parser] documents,
storage_context=storage_context,
transformations=[node_parser]
) )
tracker.checkpoint("After index building") tracker.checkpoint("After index building")
@@ -124,10 +124,10 @@ def main():
runtime_start_mem = get_memory_usage() runtime_start_mem = get_memory_usage()
print(f"Before load memory: {runtime_start_mem:.1f} MB") print(f"Before load memory: {runtime_start_mem:.1f} MB")
tracker.checkpoint("Before load memory") tracker.checkpoint("Before load memory")
query_engine = index.as_query_engine(similarity_top_k=20) query_engine = index.as_query_engine(similarity_top_k=20)
queries = [ queries = [
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发", "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面任务令一般在什么城市颁发",
"What is LEANN and how does it work?", "What is LEANN and how does it work?",
"华为诺亚方舟实验室的主要研究内容", "华为诺亚方舟实验室的主要研究内容",
] ]
@@ -141,7 +141,7 @@ def main():
runtime_end_mem = get_memory_usage() runtime_end_mem = get_memory_usage()
runtime_overhead = runtime_end_mem - runtime_start_mem runtime_overhead = runtime_end_mem - runtime_start_mem
peak_memory = tracker.summary() peak_memory = tracker.summary()
print(f"Peak Memory: {peak_memory:.1f} MB") print(f"Peak Memory: {peak_memory:.1f} MB")
print(f"Runtime Memory Overhead: {runtime_overhead:.1f} MB") print(f"Runtime Memory Overhead: {runtime_overhead:.1f} MB")

View File

@@ -0,0 +1,285 @@
import os
import asyncio
import argparse
try:
import dotenv
dotenv.load_dotenv()
except ModuleNotFoundError:
# python-dotenv is not installed; skip loading environment variables
dotenv = None
from pathlib import Path
from typing import List, Any
from leann.api import LeannBuilder, LeannSearcher, LeannChat
from llama_index.core.node_parser import SentenceSplitter
# dotenv.load_dotenv() # handled above if python-dotenv is available
# Default Chrome profile path
DEFAULT_CHROME_PROFILE = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
def create_leann_index_from_multiple_chrome_profiles(profile_dirs: List[Path], index_path: str = "chrome_history_index.leann", max_count: int = -1):
"""
Create LEANN index from multiple Chrome profile data sources.
Args:
profile_dirs: List of Path objects pointing to Chrome profile directories
index_path: Path to save the LEANN index
max_count: Maximum number of history entries to process per profile
"""
print("Creating LEANN index from multiple Chrome profile data sources...")
# Load documents using ChromeHistoryReader from history_data
from history_data.history import ChromeHistoryReader
reader = ChromeHistoryReader()
INDEX_DIR = Path(index_path).parent
if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---")
all_documents = []
total_processed = 0
# Process each Chrome profile directory
for i, profile_dir in enumerate(profile_dirs):
print(f"\nProcessing Chrome profile {i+1}/{len(profile_dirs)}: {profile_dir}")
try:
documents = reader.load_data(
chrome_profile_path=str(profile_dir),
max_count=max_count
)
if documents:
print(f"Loaded {len(documents)} history documents from {profile_dir}")
all_documents.extend(documents)
total_processed += len(documents)
# Check if we've reached the max count
if max_count > 0 and total_processed >= max_count:
print(f"Reached max count of {max_count} documents")
break
else:
print(f"No documents loaded from {profile_dir}")
except Exception as e:
print(f"Error processing {profile_dir}: {e}")
continue
if not all_documents:
print("No documents loaded from any source. Exiting.")
            # Remind the user (highlighted in red) that Chrome must be fully quit before running this script
            print("\033[91mYou need to quit all Chrome windows before running this script\033[0m")
return None
print(f"\nTotal loaded {len(all_documents)} history documents from {len(profile_dirs)} profiles")
# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
# Convert Documents to text strings and chunk them
all_texts = []
for doc in all_documents:
# Split the document into chunks
nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes:
text = node.get_content()
# text = '[Title] ' + doc.metadata["title"] + '\n' + text
all_texts.append(text)
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
        # Create the LEANN index directory
        INDEX_DIR.mkdir(exist_ok=True)
        print("--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="facebook/contriever",
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1 # Force single-threaded mode
)
print(f"Adding {len(all_texts)} history chunks to index...")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(index_path)
print(f"\nLEANN index built at {index_path}!")
else:
print(f"--- Using existing index at {INDEX_DIR} ---")
return index_path
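
# --- Minimal sketch of the core indexing flow shared by the two builder functions ---
# Illustrative only: the chunk list and index path are made up for this example;
# the LeannBuilder parameters simply mirror the calls above.
def _build_demo_index(chunks: List[str], index_path: str = "demo_index.leann") -> str:
    builder = LeannBuilder(
        backend_name="hnsw",
        embedding_model="facebook/contriever",
        graph_degree=32,
        complexity=64,
        is_compact=True,
        is_recompute=True,
        num_threads=1,  # Force single-threaded mode, as in the functions above
    )
    for chunk in chunks:
        builder.add_text(chunk)
    builder.build_index(index_path)
    return index_path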
def create_leann_index(profile_path: str = None, index_path: str = "chrome_history_index.leann", max_count: int = 1000):
"""
Create LEANN index from Chrome history data.
Args:
profile_path: Path to the Chrome profile directory (optional, uses default if None)
index_path: Path to save the LEANN index
max_count: Maximum number of history entries to process
"""
print("Creating LEANN index from Chrome history data...")
INDEX_DIR = Path(index_path).parent
if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...")
# Load documents using ChromeHistoryReader from history_data
from history_data.history import ChromeHistoryReader
reader = ChromeHistoryReader()
documents = reader.load_data(
chrome_profile_path=profile_path,
max_count=max_count
)
if not documents:
print("No documents loaded. Exiting.")
return None
print(f"Loaded {len(documents)} history documents")
# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
# Convert Documents to text strings and chunk them
all_texts = []
for doc in documents:
# Split the document into chunks
nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
        # Ensure the LEANN index directory exists (already created above)
        INDEX_DIR.mkdir(exist_ok=True)
        print("--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="facebook/contriever",
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1 # Force single-threaded mode
)
print(f"Adding {len(all_texts)} history chunks to index...")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(index_path)
print(f"\nLEANN index built at {index_path}!")
else:
print(f"--- Using existing index at {INDEX_DIR} ---")
return index_path
async def query_leann_index(index_path: str, query: str):
"""
Query the LEANN index.
Args:
index_path: Path to the LEANN index
query: The query string
"""
print(f"\n[PHASE 2] Starting Leann chat session...")
chat = LeannChat(index_path=index_path)
print(f"You: {query}")
chat_response = chat.ask(
query,
top_k=10,
recompute_beighbor_embeddings=True,
complexity=32,
beam_width=1,
llm_config={
"type": "openai",
"model": "gpt-4o",
"api_key": os.getenv("OPENAI_API_KEY"),
},
llm_kwargs={
"temperature": 0.0,
"max_tokens": 1000
}
)
print(f"Leann: {chat_response}")
async def main():
# Parse command line arguments
parser = argparse.ArgumentParser(description='LEANN Chrome History Reader - Create and query browser history index')
parser.add_argument('--chrome-profile', type=str, default=DEFAULT_CHROME_PROFILE,
                        help=f'Path to Chrome profile directory (default: {DEFAULT_CHROME_PROFILE}); you usually do not need to change this')
parser.add_argument('--index-dir', type=str, default="./all_google_new",
                        help='Directory to store the LEANN index (default: ./all_google_new)')
parser.add_argument('--max-entries', type=int, default=1000,
help='Maximum number of history entries to process (default: 1000)')
parser.add_argument('--query', type=str, default=None,
help='Single query to run (default: runs example queries)')
    parser.add_argument('--auto-find-profiles', action=argparse.BooleanOptionalAction, default=True,
                        help='Automatically find all Chrome profiles (default: True; pass --no-auto-find-profiles to use only --chrome-profile)')
args = parser.parse_args()
INDEX_DIR = Path(args.index_dir)
INDEX_PATH = str(INDEX_DIR / "chrome_history.leann")
print(f"Using Chrome profile: {args.chrome_profile}")
print(f"Index directory: {INDEX_DIR}")
print(f"Max entries: {args.max_entries}")
# Find Chrome profile directories
from history_data.history import ChromeHistoryReader
if args.auto_find_profiles:
profile_dirs = ChromeHistoryReader.find_chrome_profiles()
if not profile_dirs:
print("No Chrome profiles found automatically. Exiting.")
return
else:
# Use single specified profile
profile_path = Path(args.chrome_profile)
if not profile_path.exists():
print(f"Chrome profile not found: {profile_path}")
return
profile_dirs = [profile_path]
# Create or load the LEANN index from all sources
index_path = create_leann_index_from_multiple_chrome_profiles(profile_dirs, INDEX_PATH, args.max_entries)
if index_path:
if args.query:
# Run single query
await query_leann_index(index_path, args.query)
else:
# Example queries
queries = [
"What websites did I visit about machine learning?",
"Find my search history about programming"
]
for query in queries:
print("\n" + "="*60)
await query_leann_index(index_path, query)
if __name__ == "__main__":
asyncio.run(main())
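
# --- Example invocations (illustrative; the script filename below is hypothetical) ---
#   python chrome_history_rag.py --index-dir ./chrome_index --max-entries 500
#   python chrome_history_rag.py --query "What websites did I visit about machine learning?"
# Quit Chrome before running, otherwise the History SQLite database may be locked.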

View File

@@ -1,3 +1,3 @@
from .history import ChromeHistoryReader from .history import ChromeHistoryReader
__all__ = ["ChromeHistoryReader"] __all__ = ['ChromeHistoryReader']

View File

@@ -1,81 +1,77 @@
import os
import sqlite3 import sqlite3
import os
from pathlib import Path from pathlib import Path
from typing import Any from typing import List, Any
from llama_index.core import Document from llama_index.core import Document
from llama_index.core.readers.base import BaseReader from llama_index.core.readers.base import BaseReader
class ChromeHistoryReader(BaseReader): class ChromeHistoryReader(BaseReader):
""" """
Chrome browser history reader that extracts browsing data from SQLite database. Chrome browser history reader that extracts browsing data from SQLite database.
Reads Chrome history from the default Chrome profile location and creates documents Reads Chrome history from the default Chrome profile location and creates documents
with embedded metadata similar to the email reader structure. with embedded metadata similar to the email reader structure.
""" """
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize.""" """Initialize."""
pass pass
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]: def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
""" """
Load Chrome history data from the default Chrome profile location. Load Chrome history data from the default Chrome profile location.
Args: Args:
input_dir: Not used for Chrome history (kept for compatibility) input_dir: Not used for Chrome history (kept for compatibility)
**load_kwargs: **load_kwargs:
max_count (int): Maximum amount of history entries to read. max_count (int): Maximum amount of history entries to read.
chrome_profile_path (str): Custom path to Chrome profile directory. chrome_profile_path (str): Custom path to Chrome profile directory.
""" """
docs: list[Document] = [] docs: List[Document] = []
max_count = load_kwargs.get("max_count", 1000) max_count = load_kwargs.get('max_count', 1000)
chrome_profile_path = load_kwargs.get("chrome_profile_path", None) chrome_profile_path = load_kwargs.get('chrome_profile_path', None)
# Default Chrome profile path on macOS # Default Chrome profile path on macOS
if chrome_profile_path is None: if chrome_profile_path is None:
chrome_profile_path = os.path.expanduser( chrome_profile_path = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
"~/Library/Application Support/Google/Chrome/Default"
)
history_db_path = os.path.join(chrome_profile_path, "History") history_db_path = os.path.join(chrome_profile_path, "History")
if not os.path.exists(history_db_path): if not os.path.exists(history_db_path):
print(f"Chrome history database not found at: {history_db_path}") print(f"Chrome history database not found at: {history_db_path}")
return docs return docs
try: try:
# Connect to the Chrome history database # Connect to the Chrome history database
print(f"Connecting to database: {history_db_path}") print(f"Connecting to database: {history_db_path}")
conn = sqlite3.connect(history_db_path) conn = sqlite3.connect(history_db_path)
cursor = conn.cursor() cursor = conn.cursor()
# Query to get browsing history with metadata (removed created_time column) # Query to get browsing history with metadata (removed created_time column)
query = """ query = """
SELECT SELECT
datetime(last_visit_time/1000000-11644473600,'unixepoch','localtime') as last_visit, datetime(last_visit_time/1000000-11644473600,'unixepoch','localtime') as last_visit,
url, url,
title, title,
visit_count, visit_count,
typed_count, typed_count,
hidden hidden
FROM urls FROM urls
ORDER BY last_visit_time DESC ORDER BY last_visit_time DESC
""" """
print(f"Executing query on database: {history_db_path}") print(f"Executing query on database: {history_db_path}")
cursor.execute(query) cursor.execute(query)
rows = cursor.fetchall() rows = cursor.fetchall()
print(f"Query returned {len(rows)} rows") print(f"Query returned {len(rows)} rows")
count = 0 count = 0
for row in rows: for row in rows:
if count >= max_count and max_count > 0: if count >= max_count and max_count > 0:
break break
last_visit, url, title, visit_count, typed_count, hidden = row last_visit, url, title, visit_count, typed_count, hidden = row
# Create document content with metadata embedded in text # Create document content with metadata embedded in text
doc_content = f""" doc_content = f"""
[Title]: {title} [Title]: {title}
@@ -84,43 +80,38 @@ class ChromeHistoryReader(BaseReader):
[Visit times]: {visit_count} [Visit times]: {visit_count}
[Typed times]: {typed_count} [Typed times]: {typed_count}
""" """
# Create document with embedded metadata # Create document with embedded metadata
doc = Document(text=doc_content, metadata={"title": title[0:150]}) doc = Document(text=doc_content, metadata={ "title": title[0:150]})
# if len(title) > 150: # if len(title) > 150:
# print(f"Title is too long: {title}") # print(f"Title is too long: {title}")
docs.append(doc) docs.append(doc)
count += 1 count += 1
conn.close() conn.close()
print(f"Loaded {len(docs)} Chrome history documents") print(f"Loaded {len(docs)} Chrome history documents")
except Exception as e: except Exception as e:
print(f"Error reading Chrome history: {e}") print(f"Error reading Chrome history: {e}")
# add you may need to close your browser to make the database file available
# also highlight in red
print(
"\033[91mYou may need to close your browser to make the database file available\033[0m"
)
return docs return docs
return docs return docs
@staticmethod @staticmethod
def find_chrome_profiles() -> list[Path]: def find_chrome_profiles() -> List[Path]:
""" """
Find all Chrome profile directories. Find all Chrome profile directories.
Returns: Returns:
List of Path objects pointing to Chrome profile directories List of Path objects pointing to Chrome profile directories
""" """
chrome_base_path = Path(os.path.expanduser("~/Library/Application Support/Google/Chrome")) chrome_base_path = Path(os.path.expanduser("~/Library/Application Support/Google/Chrome"))
profile_dirs = [] profile_dirs = []
if not chrome_base_path.exists(): if not chrome_base_path.exists():
print(f"Chrome directory not found at: {chrome_base_path}") print(f"Chrome directory not found at: {chrome_base_path}")
return profile_dirs return profile_dirs
# Find all profile directories # Find all profile directories
for profile_dir in chrome_base_path.iterdir(): for profile_dir in chrome_base_path.iterdir():
if profile_dir.is_dir() and profile_dir.name != "System Profile": if profile_dir.is_dir() and profile_dir.name != "System Profile":
@@ -128,59 +119,53 @@ class ChromeHistoryReader(BaseReader):
if history_path.exists(): if history_path.exists():
profile_dirs.append(profile_dir) profile_dirs.append(profile_dir)
print(f"Found Chrome profile: {profile_dir}") print(f"Found Chrome profile: {profile_dir}")
print(f"Found {len(profile_dirs)} Chrome profiles") print(f"Found {len(profile_dirs)} Chrome profiles")
return profile_dirs return profile_dirs
@staticmethod @staticmethod
def export_history_to_file( def export_history_to_file(output_file: str = "chrome_history_export.txt", max_count: int = 1000):
output_file: str = "chrome_history_export.txt", max_count: int = 1000
):
""" """
Export Chrome history to a text file using the same SQL query format. Export Chrome history to a text file using the same SQL query format.
Args: Args:
output_file: Path to the output file output_file: Path to the output file
max_count: Maximum number of entries to export max_count: Maximum number of entries to export
""" """
chrome_profile_path = os.path.expanduser( chrome_profile_path = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
"~/Library/Application Support/Google/Chrome/Default"
)
history_db_path = os.path.join(chrome_profile_path, "History") history_db_path = os.path.join(chrome_profile_path, "History")
if not os.path.exists(history_db_path): if not os.path.exists(history_db_path):
print(f"Chrome history database not found at: {history_db_path}") print(f"Chrome history database not found at: {history_db_path}")
return return
try: try:
conn = sqlite3.connect(history_db_path) conn = sqlite3.connect(history_db_path)
cursor = conn.cursor() cursor = conn.cursor()
query = """ query = """
SELECT SELECT
datetime(last_visit_time/1000000-11644473600,'unixepoch','localtime') as last_visit, datetime(last_visit_time/1000000-11644473600,'unixepoch','localtime') as last_visit,
url, url,
title, title,
visit_count, visit_count,
typed_count, typed_count,
hidden hidden
FROM urls FROM urls
ORDER BY last_visit_time DESC ORDER BY last_visit_time DESC
LIMIT ? LIMIT ?
""" """
cursor.execute(query, (max_count,)) cursor.execute(query, (max_count,))
rows = cursor.fetchall() rows = cursor.fetchall()
with open(output_file, "w", encoding="utf-8") as f: with open(output_file, 'w', encoding='utf-8') as f:
for row in rows: for row in rows:
last_visit, url, title, visit_count, typed_count, hidden = row last_visit, url, title, visit_count, typed_count, hidden = row
f.write( f.write(f"{last_visit}\t{url}\t{title}\t{visit_count}\t{typed_count}\t{hidden}\n")
f"{last_visit}\t{url}\t{title}\t{visit_count}\t{typed_count}\t{hidden}\n"
)
conn.close() conn.close()
print(f"Exported {len(rows)} history entries to {output_file}") print(f"Exported {len(rows)} history entries to {output_file}")
except Exception as e: except Exception as e:
print(f"Error exporting Chrome history: {e}") print(f"Error exporting Chrome history: {e}")

View File

@@ -2,31 +2,30 @@ import json
import os import os
import re import re
import subprocess import subprocess
import sys
import time import time
from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Any from typing import List, Any, Dict, Optional
from llama_index.core import Document from llama_index.core import Document
from llama_index.core.readers.base import BaseReader from llama_index.core.readers.base import BaseReader
from datetime import datetime
class WeChatHistoryReader(BaseReader): class WeChatHistoryReader(BaseReader):
""" """
WeChat chat history reader that extracts chat data from exported JSON files. WeChat chat history reader that extracts chat data from exported JSON files.
Reads WeChat chat history from exported JSON files (from wechat-exporter tool) Reads WeChat chat history from exported JSON files (from wechat-exporter tool)
and creates documents with embedded metadata similar to the Chrome history reader structure. and creates documents with embedded metadata similar to the Chrome history reader structure.
Also includes utilities for automatic WeChat chat history export. Also includes utilities for automatic WeChat chat history export.
""" """
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize.""" """Initialize."""
self.packages_dir = Path(__file__).parent.parent.parent / "packages" self.packages_dir = Path(__file__).parent.parent.parent / "packages"
self.wechat_exporter_dir = self.packages_dir / "wechat-exporter" self.wechat_exporter_dir = self.packages_dir / "wechat-exporter"
self.wechat_decipher_dir = self.packages_dir / "wechat-decipher-macos" self.wechat_decipher_dir = self.packages_dir / "wechat-decipher-macos"
def check_wechat_running(self) -> bool: def check_wechat_running(self) -> bool:
"""Check if WeChat is currently running.""" """Check if WeChat is currently running."""
try: try:
@@ -34,30 +33,24 @@ class WeChatHistoryReader(BaseReader):
return result.returncode == 0 return result.returncode == 0
except Exception: except Exception:
return False return False
def install_wechattweak(self) -> bool: def install_wechattweak(self) -> bool:
"""Install WeChatTweak CLI tool.""" """Install WeChatTweak CLI tool."""
try: try:
# Create wechat-exporter directory if it doesn't exist # Create wechat-exporter directory if it doesn't exist
self.wechat_exporter_dir.mkdir(parents=True, exist_ok=True) self.wechat_exporter_dir.mkdir(parents=True, exist_ok=True)
wechattweak_path = self.wechat_exporter_dir / "wechattweak-cli" wechattweak_path = self.wechat_exporter_dir / "wechattweak-cli"
if not wechattweak_path.exists(): if not wechattweak_path.exists():
print("Downloading WeChatTweak CLI...") print("Downloading WeChatTweak CLI...")
subprocess.run( subprocess.run([
[ "curl", "-L", "-o", str(wechattweak_path),
"curl", "https://github.com/JettChenT/WeChatTweak-CLI/releases/latest/download/wechattweak-cli"
"-L", ], check=True)
"-o",
str(wechattweak_path),
"https://github.com/JettChenT/WeChatTweak-CLI/releases/latest/download/wechattweak-cli",
],
check=True,
)
# Make executable # Make executable
wechattweak_path.chmod(0o755) wechattweak_path.chmod(0o755)
# Install WeChatTweak # Install WeChatTweak
print("Installing WeChatTweak...") print("Installing WeChatTweak...")
subprocess.run(["sudo", str(wechattweak_path), "install"], check=True) subprocess.run(["sudo", str(wechattweak_path), "install"], check=True)
@@ -65,7 +58,7 @@ class WeChatHistoryReader(BaseReader):
except Exception as e: except Exception as e:
print(f"Error installing WeChatTweak: {e}") print(f"Error installing WeChatTweak: {e}")
return False return False
def restart_wechat(self): def restart_wechat(self):
"""Restart WeChat to apply WeChatTweak.""" """Restart WeChat to apply WeChatTweak."""
try: try:
@@ -76,325 +69,302 @@ class WeChatHistoryReader(BaseReader):
time.sleep(5) # Wait for WeChat to start time.sleep(5) # Wait for WeChat to start
except Exception as e: except Exception as e:
print(f"Error restarting WeChat: {e}") print(f"Error restarting WeChat: {e}")
def check_api_available(self) -> bool: def check_api_available(self) -> bool:
"""Check if WeChatTweak API is available.""" """Check if WeChatTweak API is available."""
try: try:
result = subprocess.run( result = subprocess.run([
["curl", "-s", "http://localhost:48065/wechat/allcontacts"], "curl", "-s", "http://localhost:48065/wechat/allcontacts"
capture_output=True, ], capture_output=True, text=True, timeout=5)
text=True,
timeout=5,
)
return result.returncode == 0 and result.stdout.strip() return result.returncode == 0 and result.stdout.strip()
except Exception: except Exception:
return False return False
def _extract_readable_text(self, content: str) -> str: def _extract_readable_text(self, content: str) -> str:
""" """
Extract readable text from message content, removing XML and system messages. Extract readable text from message content, removing XML and system messages.
Args: Args:
content: The raw message content (can be string or dict) content: The raw message content (can be string or dict)
Returns: Returns:
Cleaned, readable text Cleaned, readable text
""" """
if not content: if not content:
return "" return ""
# Handle dictionary content (like quoted messages) # Handle dictionary content (like quoted messages)
if isinstance(content, dict): if isinstance(content, dict):
# Extract text from dictionary structure # Extract text from dictionary structure
text_parts = [] text_parts = []
if "title" in content: if 'title' in content:
text_parts.append(str(content["title"])) text_parts.append(str(content['title']))
if "quoted" in content: if 'quoted' in content:
text_parts.append(str(content["quoted"])) text_parts.append(str(content['quoted']))
if "content" in content: if 'content' in content:
text_parts.append(str(content["content"])) text_parts.append(str(content['content']))
if "text" in content: if 'text' in content:
text_parts.append(str(content["text"])) text_parts.append(str(content['text']))
if text_parts: if text_parts:
return " | ".join(text_parts) return " | ".join(text_parts)
else: else:
# If we can't extract meaningful text from dict, return empty # If we can't extract meaningful text from dict, return empty
return "" return ""
# Handle string content # Handle string content
if not isinstance(content, str): if not isinstance(content, str):
return "" return ""
# Remove common prefixes like "wxid_xxx:\n" # Remove common prefixes like "wxid_xxx:\n"
clean_content = re.sub(r"^wxid_[^:]+:\s*", "", content) clean_content = re.sub(r'^wxid_[^:]+:\s*', '', content)
clean_content = re.sub(r"^[^:]+:\s*", "", clean_content) clean_content = re.sub(r'^[^:]+:\s*', '', clean_content)
# If it's just XML or system message, return empty # If it's just XML or system message, return empty
if clean_content.strip().startswith("<") or "recalled a message" in clean_content: if clean_content.strip().startswith('<') or 'recalled a message' in clean_content:
return "" return ""
return clean_content.strip() return clean_content.strip()
def _is_text_message(self, content: str) -> bool: def _is_text_message(self, content: str) -> bool:
""" """
Check if a message contains readable text content. Check if a message contains readable text content.
Args: Args:
content: The message content (can be string or dict) content: The message content (can be string or dict)
Returns: Returns:
True if the message contains readable text, False otherwise True if the message contains readable text, False otherwise
""" """
if not content: if not content:
return False return False
# Handle dictionary content # Handle dictionary content
if isinstance(content, dict): if isinstance(content, dict):
# Check if dict has any readable text fields # Check if dict has any readable text fields
text_fields = ["title", "quoted", "content", "text"] text_fields = ['title', 'quoted', 'content', 'text']
for field in text_fields: for field in text_fields:
if content.get(field): if field in content and content[field]:
return True return True
return False return False
# Handle string content # Handle string content
if not isinstance(content, str): if not isinstance(content, str):
return False return False
# Skip image messages (contain XML with img tags) # Skip image messages (contain XML with img tags)
if "<img" in content and "cdnurl" in content: if '<img' in content and 'cdnurl' in content:
return False return False
# Skip emoji messages (contain emoji XML tags) # Skip emoji messages (contain emoji XML tags)
if "<emoji" in content and "productid" in content: if '<emoji' in content and 'productid' in content:
return False return False
# Skip voice messages # Skip voice messages
if "<voice" in content: if '<voice' in content:
return False return False
# Skip video messages # Skip video messages
if "<video" in content: if '<video' in content:
return False return False
# Skip file messages # Skip file messages
if "<appmsg" in content and "appid" in content: if '<appmsg' in content and 'appid' in content:
return False return False
# Skip system messages (like "recalled a message") # Skip system messages (like "recalled a message")
if "recalled a message" in content: if 'recalled a message' in content:
return False return False
# Check if there's actual readable text (not just XML or system messages) # Check if there's actual readable text (not just XML or system messages)
# Remove common prefixes like "wxid_xxx:\n" and check for actual content # Remove common prefixes like "wxid_xxx:\n" and check for actual content
clean_content = re.sub(r"^wxid_[^:]+:\s*", "", content) clean_content = re.sub(r'^wxid_[^:]+:\s*', '', content)
clean_content = re.sub(r"^[^:]+:\s*", "", clean_content) clean_content = re.sub(r'^[^:]+:\s*', '', clean_content)
# If after cleaning we have meaningful text, consider it readable # If after cleaning we have meaningful text, consider it readable
if len(clean_content.strip()) > 0 and not clean_content.strip().startswith("<"): if len(clean_content.strip()) > 0 and not clean_content.strip().startswith('<'):
return True return True
return False return False
def _concatenate_messages( def _concatenate_messages(self, messages: List[Dict], max_length: int = 128,
self, time_window_minutes: int = 30, overlap_messages: int = 0) -> List[Dict]:
messages: list[dict],
max_length: int = 128,
time_window_minutes: int = 30,
overlap_messages: int = 0,
) -> list[dict]:
""" """
Concatenate messages based on length and time rules. Concatenate messages based on length and time rules.
Args: Args:
messages: List of message dictionaries messages: List of message dictionaries
max_length: Maximum length for concatenated message groups. Use -1 to disable length constraint. max_length: Maximum length for concatenated message groups. Use -1 to disable length constraint.
time_window_minutes: Time window in minutes to group messages together. Use -1 to disable time constraint. time_window_minutes: Time window in minutes to group messages together. Use -1 to disable time constraint.
overlap_messages: Number of messages to overlap between consecutive groups overlap_messages: Number of messages to overlap between consecutive groups
Returns: Returns:
List of concatenated message groups List of concatenated message groups
""" """
if not messages: if not messages:
return [] return []
concatenated_groups = [] concatenated_groups = []
current_group = [] current_group = []
current_length = 0 current_length = 0
last_timestamp = None last_timestamp = None
for message in messages: for message in messages:
# Extract message info # Extract message info
content = message.get("content", "") content = message.get('content', '')
message_text = message.get("message", "") message_text = message.get('message', '')
create_time = message.get("createTime", 0) create_time = message.get('createTime', 0)
message.get("fromUser", "") from_user = message.get('fromUser', '')
message.get("toUser", "") to_user = message.get('toUser', '')
message.get("isSentFromSelf", False) is_sent_from_self = message.get('isSentFromSelf', False)
# Extract readable text # Extract readable text
readable_text = self._extract_readable_text(content) readable_text = self._extract_readable_text(content)
if not readable_text: if not readable_text:
readable_text = message_text readable_text = message_text
# Skip empty messages # Skip empty messages
if not readable_text.strip(): if not readable_text.strip():
continue continue
# Check time window constraint (only if time_window_minutes != -1) # Check time window constraint (only if time_window_minutes != -1)
if time_window_minutes != -1 and last_timestamp is not None and create_time > 0: if time_window_minutes != -1 and last_timestamp is not None and create_time > 0:
time_diff_minutes = (create_time - last_timestamp) / 60 time_diff_minutes = (create_time - last_timestamp) / 60
if time_diff_minutes > time_window_minutes: if time_diff_minutes > time_window_minutes:
# Time gap too large, start new group # Time gap too large, start new group
if current_group: if current_group:
concatenated_groups.append( concatenated_groups.append({
{ 'messages': current_group,
"messages": current_group, 'total_length': current_length,
"total_length": current_length, 'start_time': current_group[0].get('createTime', 0),
"start_time": current_group[0].get("createTime", 0), 'end_time': current_group[-1].get('createTime', 0)
"end_time": current_group[-1].get("createTime", 0), })
}
)
# Keep last few messages for overlap # Keep last few messages for overlap
if overlap_messages > 0 and len(current_group) > overlap_messages: if overlap_messages > 0 and len(current_group) > overlap_messages:
current_group = current_group[-overlap_messages:] current_group = current_group[-overlap_messages:]
current_length = sum( current_length = sum(len(self._extract_readable_text(msg.get('content', '')) or msg.get('message', '')) for msg in current_group)
len(
self._extract_readable_text(msg.get("content", ""))
or msg.get("message", "")
)
for msg in current_group
)
else: else:
current_group = [] current_group = []
current_length = 0 current_length = 0
# Check length constraint (only if max_length != -1) # Check length constraint (only if max_length != -1)
message_length = len(readable_text) message_length = len(readable_text)
if max_length != -1 and current_length + message_length > max_length and current_group: if max_length != -1 and current_length + message_length > max_length and current_group:
# Current group would exceed max length, save it and start new # Current group would exceed max length, save it and start new
concatenated_groups.append( concatenated_groups.append({
{ 'messages': current_group,
"messages": current_group, 'total_length': current_length,
"total_length": current_length, 'start_time': current_group[0].get('createTime', 0),
"start_time": current_group[0].get("createTime", 0), 'end_time': current_group[-1].get('createTime', 0)
"end_time": current_group[-1].get("createTime", 0), })
}
)
# Keep last few messages for overlap # Keep last few messages for overlap
if overlap_messages > 0 and len(current_group) > overlap_messages: if overlap_messages > 0 and len(current_group) > overlap_messages:
current_group = current_group[-overlap_messages:] current_group = current_group[-overlap_messages:]
current_length = sum( current_length = sum(len(self._extract_readable_text(msg.get('content', '')) or msg.get('message', '')) for msg in current_group)
len(
self._extract_readable_text(msg.get("content", ""))
or msg.get("message", "")
)
for msg in current_group
)
else: else:
current_group = [] current_group = []
current_length = 0 current_length = 0
# Add message to current group # Add message to current group
current_group.append(message) current_group.append(message)
current_length += message_length current_length += message_length
last_timestamp = create_time last_timestamp = create_time
# Add the last group if it exists # Add the last group if it exists
if current_group: if current_group:
concatenated_groups.append( concatenated_groups.append({
{ 'messages': current_group,
"messages": current_group, 'total_length': current_length,
"total_length": current_length, 'start_time': current_group[0].get('createTime', 0),
"start_time": current_group[0].get("createTime", 0), 'end_time': current_group[-1].get('createTime', 0)
"end_time": current_group[-1].get("createTime", 0), })
}
)
return concatenated_groups return concatenated_groups
def _create_concatenated_content(self, message_group: dict, contact_name: str) -> str: def _create_concatenated_content(self, message_group: Dict, contact_name: str) -> str:
""" """
Create concatenated content from a group of messages. Create concatenated content from a group of messages.
Args: Args:
message_group: Dictionary containing messages and metadata message_group: Dictionary containing messages and metadata
contact_name: Name of the contact contact_name: Name of the contact
Returns: Returns:
Formatted concatenated content Formatted concatenated content
""" """
messages = message_group["messages"] messages = message_group['messages']
start_time = message_group["start_time"] start_time = message_group['start_time']
end_time = message_group["end_time"] end_time = message_group['end_time']
# Format timestamps # Format timestamps
if start_time: if start_time:
try: try:
start_timestamp = datetime.fromtimestamp(start_time) start_timestamp = datetime.fromtimestamp(start_time)
start_time_str = start_timestamp.strftime("%Y-%m-%d %H:%M:%S") start_time_str = start_timestamp.strftime('%Y-%m-%d %H:%M:%S')
except (ValueError, OSError): except:
start_time_str = str(start_time) start_time_str = str(start_time)
else: else:
start_time_str = "Unknown" start_time_str = "Unknown"
if end_time: if end_time:
try: try:
end_timestamp = datetime.fromtimestamp(end_time) end_timestamp = datetime.fromtimestamp(end_time)
end_time_str = end_timestamp.strftime("%Y-%m-%d %H:%M:%S") end_time_str = end_timestamp.strftime('%Y-%m-%d %H:%M:%S')
except (ValueError, OSError): except:
end_time_str = str(end_time) end_time_str = str(end_time)
else: else:
end_time_str = "Unknown" end_time_str = "Unknown"
# Build concatenated message content # Build concatenated message content
message_parts = [] message_parts = []
for message in messages: for message in messages:
content = message.get("content", "") content = message.get('content', '')
message_text = message.get("message", "") message_text = message.get('message', '')
create_time = message.get("createTime", 0) create_time = message.get('createTime', 0)
is_sent_from_self = message.get("isSentFromSelf", False) is_sent_from_self = message.get('isSentFromSelf', False)
# Extract readable text # Extract readable text
readable_text = self._extract_readable_text(content) readable_text = self._extract_readable_text(content)
if not readable_text: if not readable_text:
readable_text = message_text readable_text = message_text
# Format individual message # Format individual message
if create_time: if create_time:
try: try:
timestamp = datetime.fromtimestamp(create_time) timestamp = datetime.fromtimestamp(create_time)
# change to YYYY-MM-DD HH:MM:SS # change to YYYY-MM-DD HH:MM:SS
time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S") time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
except (ValueError, OSError): except:
time_str = str(create_time) time_str = str(create_time)
else: else:
time_str = "Unknown" time_str = "Unknown"
sender = "[Me]" if is_sent_from_self else "[Contact]" sender = "[Me]" if is_sent_from_self else "[Contact]"
message_parts.append(f"({time_str}) {sender}: {readable_text}") message_parts.append(f"({time_str}) {sender}: {readable_text}")
concatenated_text = "\n".join(message_parts) concatenated_text = "\n".join(message_parts)
# Create final document content # Create final document content
doc_content = f""" doc_content = f"""
Contact: {contact_name} Contact: {contact_name}
Time Range: {start_time_str} - {end_time_str} Time Range: {start_time_str} - {end_time_str}
Messages ({len(messages)} messages, {message_group["total_length"]} chars): Messages ({len(messages)} messages, {message_group['total_length']} chars):
{concatenated_text} {concatenated_text}
""" """
# TODO @yichuan give better format and rich info here! # TODO @yichuan give better format and rich info here!
doc_content = f""" doc_content = f"""
{concatenated_text} {concatenated_text}
""" """
return doc_content, contact_name return doc_content, contact_name
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]: def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
""" """
Load WeChat chat history data from exported JSON files. Load WeChat chat history data from exported JSON files.
Args: Args:
input_dir: Directory containing exported WeChat JSON files input_dir: Directory containing exported WeChat JSON files
**load_kwargs: **load_kwargs:
@@ -406,104 +376,97 @@ Messages ({len(messages)} messages, {message_group["total_length"]} chars):
time_window_minutes (int): Time window in minutes to group messages together (default: 30). time_window_minutes (int): Time window in minutes to group messages together (default: 30).
overlap_messages (int): Number of messages to overlap between consecutive groups (default: 2). overlap_messages (int): Number of messages to overlap between consecutive groups (default: 2).
""" """
docs: list[Document] = [] docs: List[Document] = []
max_count = load_kwargs.get("max_count", 1000) max_count = load_kwargs.get('max_count', 1000)
wechat_export_dir = load_kwargs.get("wechat_export_dir", None) wechat_export_dir = load_kwargs.get('wechat_export_dir', None)
include_non_text = load_kwargs.get("include_non_text", False) include_non_text = load_kwargs.get('include_non_text', False)
concatenate_messages = load_kwargs.get("concatenate_messages", False) concatenate_messages = load_kwargs.get('concatenate_messages', False)
max_length = load_kwargs.get("max_length", 1000) max_length = load_kwargs.get('max_length', 1000)
time_window_minutes = load_kwargs.get("time_window_minutes", 30) time_window_minutes = load_kwargs.get('time_window_minutes', 30)
# Default WeChat export path # Default WeChat export path
if wechat_export_dir is None: if wechat_export_dir is None:
wechat_export_dir = "./wechat_export_test" wechat_export_dir = "./wechat_export_test"
if not os.path.exists(wechat_export_dir): if not os.path.exists(wechat_export_dir):
print(f"WeChat export directory not found at: {wechat_export_dir}") print(f"WeChat export directory not found at: {wechat_export_dir}")
return docs return docs
try: try:
# Find all JSON files in the export directory # Find all JSON files in the export directory
json_files = list(Path(wechat_export_dir).glob("*.json")) json_files = list(Path(wechat_export_dir).glob("*.json"))
print(f"Found {len(json_files)} WeChat chat history files") print(f"Found {len(json_files)} WeChat chat history files")
count = 0 count = 0
for json_file in json_files: for json_file in json_files:
if count >= max_count and max_count > 0: if count >= max_count and max_count > 0:
break break
try: try:
with open(json_file, encoding="utf-8") as f: with open(json_file, 'r', encoding='utf-8') as f:
chat_data = json.load(f) chat_data = json.load(f)
# Extract contact name from filename # Extract contact name from filename
contact_name = json_file.stem contact_name = json_file.stem
if concatenate_messages: if concatenate_messages:
# Filter messages to only include readable text messages # Filter messages to only include readable text messages
readable_messages = [] readable_messages = []
for message in chat_data: for message in chat_data:
try: try:
content = message.get("content", "") content = message.get('content', '')
if not include_non_text and not self._is_text_message(content): if not include_non_text and not self._is_text_message(content):
continue continue
readable_text = self._extract_readable_text(content) readable_text = self._extract_readable_text(content)
if not readable_text and not include_non_text: if not readable_text and not include_non_text:
continue continue
readable_messages.append(message) readable_messages.append(message)
except Exception as e: except Exception as e:
print(f"Error processing message in {json_file}: {e}") print(f"Error processing message in {json_file}: {e}")
continue continue
# Concatenate messages based on rules # Concatenate messages based on rules
message_groups = self._concatenate_messages( message_groups = self._concatenate_messages(
readable_messages, readable_messages,
max_length=max_length, max_length=-1,
time_window_minutes=time_window_minutes, time_window_minutes=-1,
overlap_messages=0, # No overlap between groups overlap_messages=0 # Keep 2 messages overlap between groups
) )
# Create documents from concatenated groups # Create documents from concatenated groups
for message_group in message_groups: for message_group in message_groups:
if count >= max_count and max_count > 0: if count >= max_count and max_count > 0:
break break
doc_content, contact_name = self._create_concatenated_content( doc_content, contact_name = self._create_concatenated_content(message_group, contact_name)
message_group, contact_name doc = Document(text=doc_content, metadata={"contact_name": contact_name})
)
doc = Document(
text=doc_content,
metadata={"contact_name": contact_name},
)
docs.append(doc) docs.append(doc)
count += 1 count += 1
print( print(f"Created {len(message_groups)} concatenated message groups for {contact_name}")
f"Created {len(message_groups)} concatenated message groups for {contact_name}"
)
else: else:
# Original single-message processing # Original single-message processing
for message in chat_data: for message in chat_data:
if count >= max_count and max_count > 0: if count >= max_count and max_count > 0:
break break
# Extract message information # Extract message information
message.get("fromUser", "") from_user = message.get('fromUser', '')
message.get("toUser", "") to_user = message.get('toUser', '')
content = message.get("content", "") content = message.get('content', '')
message_text = message.get("message", "") message_text = message.get('message', '')
create_time = message.get("createTime", 0) create_time = message.get('createTime', 0)
is_sent_from_self = message.get("isSentFromSelf", False) is_sent_from_self = message.get('isSentFromSelf', False)
# Handle content that might be dict or string # Handle content that might be dict or string
try: try:
# Check if this is a readable text message # Check if this is a readable text message
if not include_non_text and not self._is_text_message(content): if not include_non_text and not self._is_text_message(content):
continue continue
# Extract readable text # Extract readable text
readable_text = self._extract_readable_text(content) readable_text = self._extract_readable_text(content)
if not readable_text and not include_non_text: if not readable_text and not include_non_text:
@@ -512,17 +475,17 @@ Messages ({len(messages)} messages, {message_group["total_length"]} chars):
# Skip messages that cause processing errors # Skip messages that cause processing errors
print(f"Error processing message in {json_file}: {e}") print(f"Error processing message in {json_file}: {e}")
continue continue
# Convert timestamp to readable format # Convert timestamp to readable format
if create_time: if create_time:
try: try:
timestamp = datetime.fromtimestamp(create_time) timestamp = datetime.fromtimestamp(create_time)
time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S") time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
except (ValueError, OSError): except:
time_str = str(create_time) time_str = str(create_time)
else: else:
time_str = "Unknown" time_str = "Unknown"
# Create document content with metadata header and contact info # Create document content with metadata header and contact info
doc_content = f""" doc_content = f"""
Contact: {contact_name} Contact: {contact_name}
@@ -530,66 +493,57 @@ Is sent from self: {is_sent_from_self}
Time: {time_str} Time: {time_str}
Message: {readable_text if readable_text else message_text} Message: {readable_text if readable_text else message_text}
""" """
# Create document with embedded metadata # Create document with embedded metadata
doc = Document( doc = Document(text=doc_content, metadata={})
text=doc_content, metadata={"contact_name": contact_name}
)
docs.append(doc) docs.append(doc)
count += 1 count += 1
except Exception as e: except Exception as e:
print(f"Error reading {json_file}: {e}") print(f"Error reading {json_file}: {e}")
continue continue
print(f"Loaded {len(docs)} WeChat chat documents") print(f"Loaded {len(docs)} WeChat chat documents")
except Exception as e: except Exception as e:
print(f"Error reading WeChat history: {e}") print(f"Error reading WeChat history: {e}")
return docs return docs
return docs return docs
@staticmethod @staticmethod
def find_wechat_export_dirs() -> list[Path]: def find_wechat_export_dirs() -> List[Path]:
""" """
Find all WeChat export directories. Find all WeChat export directories.
Returns: Returns:
List of Path objects pointing to WeChat export directories List of Path objects pointing to WeChat export directories
""" """
export_dirs = [] export_dirs = []
# Look for common export directory names # Look for common export directory names
possible_dirs = [ possible_dirs = [
Path("./wechat_export_test"),
Path("./wechat_export"), Path("./wechat_export"),
Path("./wechat_export_direct"),
Path("./wechat_chat_history"), Path("./wechat_chat_history"),
Path("./chat_export"), Path("./chat_export")
] ]
for export_dir in possible_dirs: for export_dir in possible_dirs:
if export_dir.exists() and export_dir.is_dir(): if export_dir.exists() and export_dir.is_dir():
json_files = list(export_dir.glob("*.json")) json_files = list(export_dir.glob("*.json"))
if json_files: if json_files:
export_dirs.append(export_dir) export_dirs.append(export_dir)
print( print(f"Found WeChat export directory: {export_dir} with {len(json_files)} files")
f"Found WeChat export directory: {export_dir} with {len(json_files)} files"
)
print(f"Found {len(export_dirs)} WeChat export directories") print(f"Found {len(export_dirs)} WeChat export directories")
return export_dirs return export_dirs
@staticmethod @staticmethod
def export_chat_to_file( def export_chat_to_file(output_file: str = "wechat_chat_export.txt", max_count: int = 1000, export_dir: str = None, include_non_text: bool = False):
output_file: str = "wechat_chat_export.txt",
max_count: int = 1000,
export_dir: str | None = None,
include_non_text: bool = False,
):
""" """
Export WeChat chat history to a text file. Export WeChat chat history to a text file.
Args: Args:
output_file: Path to the output file output_file: Path to the output file
max_count: Maximum number of entries to export max_count: Maximum number of entries to export
@@ -598,36 +552,36 @@ Message: {readable_text if readable_text else message_text}
""" """
if export_dir is None: if export_dir is None:
export_dir = "./wechat_export_test" export_dir = "./wechat_export_test"
if not os.path.exists(export_dir): if not os.path.exists(export_dir):
print(f"WeChat export directory not found at: {export_dir}") print(f"WeChat export directory not found at: {export_dir}")
return return
try: try:
json_files = list(Path(export_dir).glob("*.json")) json_files = list(Path(export_dir).glob("*.json"))
with open(output_file, "w", encoding="utf-8") as f: with open(output_file, 'w', encoding='utf-8') as f:
count = 0 count = 0
for json_file in json_files: for json_file in json_files:
if count >= max_count and max_count > 0: if count >= max_count and max_count > 0:
break break
try: try:
with open(json_file, encoding="utf-8") as json_f: with open(json_file, 'r', encoding='utf-8') as json_f:
chat_data = json.load(json_f) chat_data = json.load(json_f)
contact_name = json_file.stem contact_name = json_file.stem
f.write(f"\n=== Chat with {contact_name} ===\n") f.write(f"\n=== Chat with {contact_name} ===\n")
for message in chat_data: for message in chat_data:
if count >= max_count and max_count > 0: if count >= max_count and max_count > 0:
break break
from_user = message.get("fromUser", "") from_user = message.get('fromUser', '')
content = message.get("content", "") content = message.get('content', '')
message_text = message.get("message", "") message_text = message.get('message', '')
create_time = message.get("createTime", 0) create_time = message.get('createTime', 0)
# Skip non-text messages unless requested # Skip non-text messages unless requested
if not include_non_text: if not include_non_text:
reader = WeChatHistoryReader() reader = WeChatHistoryReader()
@@ -637,90 +591,83 @@ Message: {readable_text if readable_text else message_text}
if not readable_text: if not readable_text:
continue continue
message_text = readable_text message_text = readable_text
if create_time: if create_time:
try: try:
timestamp = datetime.fromtimestamp(create_time) timestamp = datetime.fromtimestamp(create_time)
time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S") time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
except (ValueError, OSError): except:
time_str = str(create_time) time_str = str(create_time)
else: else:
time_str = "Unknown" time_str = "Unknown"
f.write(f"[{time_str}] {from_user}: {message_text}\n") f.write(f"[{time_str}] {from_user}: {message_text}\n")
count += 1 count += 1
except Exception as e: except Exception as e:
print(f"Error processing {json_file}: {e}") print(f"Error processing {json_file}: {e}")
continue continue
print(f"Exported {count} chat entries to {output_file}") print(f"Exported {count} chat entries to {output_file}")
except Exception as e: except Exception as e:
print(f"Error exporting WeChat chat history: {e}") print(f"Error exporting WeChat chat history: {e}")
def export_wechat_chat_history(self, export_dir: str = "./wechat_export_direct") -> Path | None: def export_wechat_chat_history(self, export_dir: str = "./wechat_export_direct") -> Optional[Path]:
""" """
Export WeChat chat history using wechat-exporter tool. Export WeChat chat history using wechat-exporter tool.
Args: Args:
export_dir: Directory to save exported chat history export_dir: Directory to save exported chat history
Returns: Returns:
Path to export directory if successful, None otherwise Path to export directory if successful, None otherwise
""" """
try: try:
import subprocess import subprocess
import sys import sys
# Create export directory # Create export directory
export_path = Path(export_dir) export_path = Path(export_dir)
export_path.mkdir(exist_ok=True) export_path.mkdir(exist_ok=True)
print(f"Exporting WeChat chat history to {export_path}...") print(f"Exporting WeChat chat history to {export_path}...")
# Check if wechat-exporter directory exists # Check if wechat-exporter directory exists
if not self.wechat_exporter_dir.exists(): if not self.wechat_exporter_dir.exists():
print(f"wechat-exporter directory not found at: {self.wechat_exporter_dir}") print(f"wechat-exporter directory not found at: {self.wechat_exporter_dir}")
return None return None
# Install requirements if needed # Install requirements if needed
requirements_file = self.wechat_exporter_dir / "requirements.txt" requirements_file = self.wechat_exporter_dir / "requirements.txt"
if requirements_file.exists(): if requirements_file.exists():
print("Installing wechat-exporter requirements...") print("Installing wechat-exporter requirements...")
subprocess.run(["uv", "pip", "install", "-r", str(requirements_file)], check=True) subprocess.run([
"uv", "pip", "install", "-r", str(requirements_file)
], check=True)
# Run the export command # Run the export command
print("Running wechat-exporter...") print("Running wechat-exporter...")
result = subprocess.run( result = subprocess.run([
[ sys.executable, str(self.wechat_exporter_dir / "main.py"),
sys.executable, "export-all", str(export_path)
str(self.wechat_exporter_dir / "main.py"), ], capture_output=True, text=True, check=True)
"export-all",
str(export_path),
],
capture_output=True,
text=True,
check=True,
)
print("Export command output:") print("Export command output:")
print(result.stdout) print(result.stdout)
if result.stderr: if result.stderr:
print("Export errors:") print("Export errors:")
print(result.stderr) print(result.stderr)
# Check if export was successful # Check if export was successful
if export_path.exists() and any(export_path.glob("*.json")): if export_path.exists() and any(export_path.glob("*.json")):
json_files = list(export_path.glob("*.json")) json_files = list(export_path.glob("*.json"))
print( print(f"Successfully exported {len(json_files)} chat history files to {export_path}")
f"Successfully exported {len(json_files)} chat history files to {export_path}"
)
return export_path return export_path
else: else:
print("Export completed but no JSON files found") print("Export completed but no JSON files found")
return None return None
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
print(f"Export command failed: {e}") print(f"Export command failed: {e}")
print(f"Command output: {e.stdout}") print(f"Command output: {e.stdout}")
@@ -731,18 +678,18 @@ Message: {readable_text if readable_text else message_text}
print("Please ensure WeChat is running and WeChatTweak is installed.") print("Please ensure WeChat is running and WeChatTweak is installed.")
return None return None
def find_or_export_wechat_data(self, export_dir: str = "./wechat_export_direct") -> list[Path]: def find_or_export_wechat_data(self, export_dir: str = "./wechat_export_direct") -> List[Path]:
""" """
Find existing WeChat exports or create new ones. Find existing WeChat exports or create new ones.
Args: Args:
export_dir: Directory to save exported chat history if needed export_dir: Directory to save exported chat history if needed
Returns: Returns:
List of Path objects pointing to WeChat export directories List of Path objects pointing to WeChat export directories
""" """
export_dirs = [] export_dirs = []
# Look for existing exports in common locations # Look for existing exports in common locations
possible_export_dirs = [ possible_export_dirs = [
Path("./wechat_database_export"), Path("./wechat_database_export"),
@@ -750,25 +697,23 @@ Message: {readable_text if readable_text else message_text}
Path("./wechat_export"), Path("./wechat_export"),
Path("./wechat_export_direct"), Path("./wechat_export_direct"),
Path("./wechat_chat_history"), Path("./wechat_chat_history"),
Path("./chat_export"), Path("./chat_export")
] ]
for export_dir_path in possible_export_dirs: for export_dir_path in possible_export_dirs:
if export_dir_path.exists() and any(export_dir_path.glob("*.json")): if export_dir_path.exists() and any(export_dir_path.glob("*.json")):
export_dirs.append(export_dir_path) export_dirs.append(export_dir_path)
print(f"Found existing export: {export_dir_path}") print(f"Found existing export: {export_dir_path}")
# If no existing exports, try to export automatically # If no existing exports, try to export automatically
if not export_dirs: if not export_dirs:
print("No existing WeChat exports found. Starting direct export...") print("No existing WeChat exports found. Starting direct export...")
# Try to export using wechat-exporter # Try to export using wechat-exporter
exported_path = self.export_wechat_chat_history(export_dir) exported_path = self.export_wechat_chat_history(export_dir)
if exported_path: if exported_path:
export_dirs = [exported_path] export_dirs = [exported_path]
else: else:
print( print("Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.")
"Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed."
) return export_dirs
return export_dirs
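For orientation, a minimal usage sketch of the reader methods touched above; it assumes WeChatHistoryReader is importable from history_data.wechat_history (as the example scripts later in this diff do) and that WeChat exports already exist or WeChatTweak is installed.

from history_data.wechat_history import WeChatHistoryReader

reader = WeChatHistoryReader()
# Reuse existing exports if any are found, otherwise run wechat-exporter.
export_dirs = reader.find_or_export_wechat_data("./wechat_export_direct")

if export_dirs:
    # Dump up to 500 readable text messages from the first export to a plain-text file.
    WeChatHistoryReader.export_chat_to_file(
        output_file="wechat_chat_export.txt",
        max_count=500,
        export_dir=str(export_dirs[0]),
        include_non_text=False,
    )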

View File

@@ -0,0 +1,288 @@
import os
import sys
import asyncio
import dotenv
import argparse
from pathlib import Path
from typing import List, Any
# Add the project root to Python path so we can import from examples
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from leann.api import LeannBuilder, LeannSearcher, LeannChat
from llama_index.core.node_parser import SentenceSplitter
dotenv.load_dotenv()
# Auto-detect user's mail path
def get_mail_path():
"""Get the mail path for the current user"""
home_dir = os.path.expanduser("~")
return os.path.join(home_dir, "Library", "Mail")
# Default mail path for macOS
# DEFAULT_MAIL_PATH = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data"
def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_path: str = "mail_index.leann", max_count: int = -1, include_html: bool = False, embedding_model: str = "facebook/contriever"):
"""
Create LEANN index from multiple mail data sources.
Args:
messages_dirs: List of Path objects pointing to Messages directories
index_path: Path to save the LEANN index
max_count: Maximum number of emails to process per directory
include_html: Whether to include HTML content in email processing
"""
print("Creating LEANN index from multiple mail data sources...")
# Load documents using EmlxReader from LEANN_email_reader
from examples.email_data.LEANN_email_reader import EmlxReader
reader = EmlxReader(include_html=include_html)
# from email_data.email import EmlxMboxReader
# from pathlib import Path
# reader = EmlxMboxReader()
INDEX_DIR = Path(index_path).parent
if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---")
all_documents = []
total_processed = 0
# Process each Messages directory
for i, messages_dir in enumerate(messages_dirs):
print(f"\nProcessing Messages directory {i+1}/{len(messages_dirs)}: {messages_dir}")
try:
documents = reader.load_data(messages_dir)
if documents:
print(f"Loaded {len(documents)} email documents from {messages_dir}")
all_documents.extend(documents)
total_processed += len(documents)
# Check if we've reached the max count
if max_count > 0 and total_processed >= max_count:
print(f"Reached max count of {max_count} documents")
break
else:
print(f"No documents loaded from {messages_dir}")
except Exception as e:
print(f"Error processing {messages_dir}: {e}")
continue
if not all_documents:
print("No documents loaded from any source. Exiting.")
return None
print(f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories and starting to split them into chunks")
# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
# Convert Documents to text strings and chunk them
all_texts = []
for doc in all_documents:
# Split the document into chunks
nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes:
text = node.get_content()
# text = '[subject] ' + doc.metadata["subject"] + '\n' + text
all_texts.append(text)
print(f"Finished splitting {len(all_documents)} documents into {len(all_texts)} text chunks")
# Create LEANN index directory
print(f"--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=embedding_model,
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1 # Force single-threaded mode
)
print(f"Adding {len(all_texts)} email chunks to index...")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(index_path)
print(f"\nLEANN index built at {index_path}!")
else:
print(f"--- Using existing index at {INDEX_DIR} ---")
return index_path
def create_leann_index(mail_path: str, index_path: str = "mail_index.leann", max_count: int = 1000, include_html: bool = False, embedding_model: str = "facebook/contriever"):
"""
Create LEANN index from mail data.
Args:
mail_path: Path to the mail directory
index_path: Path to save the LEANN index
max_count: Maximum number of emails to process
include_html: Whether to include HTML content in email processing
"""
print("Creating LEANN index from mail data...")
INDEX_DIR = Path(index_path).parent
if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...")
# Load documents using EmlxReader from LEANN_email_reader
from examples.email_data.LEANN_email_reader import EmlxReader
reader = EmlxReader(include_html=include_html)
# from email_data.email import EmlxMboxReader
# from pathlib import Path
# reader = EmlxMboxReader()
documents = reader.load_data(Path(mail_path))
if not documents:
print("No documents loaded. Exiting.")
return None
print(f"Loaded {len(documents)} email documents")
# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
# Convert Documents to text strings and chunk them
all_texts = []
for doc in documents:
# Split the document into chunks
nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
# Create LEANN index directory
print(f"--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=embedding_model,
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1 # Force single-threaded mode
)
print(f"Adding {len(all_texts)} email chunks to index...")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(index_path)
print(f"\nLEANN index built at {index_path}!")
else:
print(f"--- Using existing index at {INDEX_DIR} ---")
return index_path
async def query_leann_index(index_path: str, query: str):
"""
Query the LEANN index.
Args:
index_path: Path to the LEANN index
query: The query string
"""
print(f"\n[PHASE 2] Starting Leann chat session...")
chat = LeannChat(index_path=index_path,
llm_config={"type": "openai", "model": "gpt-4o"})
print(f"You: {query}")
import time
start_time = time.time()
chat_response = chat.ask(
query,
top_k=10,
recompute_beighbor_embeddings=True,
complexity=12,
beam_width=1,
)
end_time = time.time()
print(f"Time taken: {end_time - start_time} seconds")
print(f"Leann: {chat_response}")
async def main():
# Parse command line arguments
parser = argparse.ArgumentParser(description='LEANN Mail Reader - Create and query email index')
# Remove --mail-path argument and auto-detect all Messages directories
# Remove DEFAULT_MAIL_PATH
parser.add_argument('--index-dir', type=str, default="./mail_index_leann_debug",
help='Directory to store the LEANN index (default: ./mail_index_leann_debug)')
parser.add_argument('--max-emails', type=int, default=1000,
help='Maximum number of emails to process (-1 means all)')
parser.add_argument('--query', type=str, default="Give me some funny advertisement about apple or other companies",
help='Single query to run (default: runs example queries)')
parser.add_argument('--include-html', action='store_true', default=False,
help='Include HTML content in email processing (default: False)')
parser.add_argument('--embedding-model', type=str, default="facebook/contriever",
help='Embedding model to use (default: facebook/contriever)')
args = parser.parse_args()
print(f"args: {args}")
# Automatically find all Messages directories under the current user's Mail directory
from examples.email_data.LEANN_email_reader import find_all_messages_directories
mail_path = get_mail_path()
print(f"Searching for email data in: {mail_path}")
messages_dirs = find_all_messages_directories(mail_path)
print('len(messages_dirs): ', len(messages_dirs))
if not messages_dirs:
print("No Messages directories found. Exiting.")
return
INDEX_DIR = Path(args.index_dir)
INDEX_PATH = str(INDEX_DIR / "mail_documents.leann")
print(f"Index directory: {INDEX_DIR}")
print(f"Found {len(messages_dirs)} Messages directories.")
# Create or load the LEANN index from all sources
index_path = create_leann_index_from_multiple_sources(messages_dirs, INDEX_PATH, args.max_emails, args.include_html, args.embedding_model)
if index_path:
if args.query:
# Run single query
await query_leann_index(index_path, args.query)
else:
# Example queries
queries = [
"Hows Berkeley Graduate Student Instructor",
"how's the icloud related advertisement saying",
"Whats the number of class recommend to take per semester for incoming EECS students"
]
for query in queries:
print("\n" + "="*60)
await query_leann_index(index_path, query)
if __name__ == "__main__":
asyncio.run(main())
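A hedged sketch of driving the same pipeline programmatically rather than through argparse; it assumes the snippet runs inside the example module above (so get_mail_path, create_leann_index_from_multiple_sources, and query_leann_index are in scope), and the query text is only illustrative.

import asyncio

from examples.email_data.LEANN_email_reader import find_all_messages_directories

# Locate all Messages directories under ~/Library/Mail and build or load the index.
messages_dirs = find_all_messages_directories(get_mail_path())
index_path = create_leann_index_from_multiple_sources(
    messages_dirs,
    index_path="./mail_index_leann_debug/mail_documents.leann",
    max_count=1000,
)
if index_path:
    asyncio.run(query_leann_index(index_path, "Summarize emails about course registration"))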

View File

@@ -0,0 +1,115 @@
import argparse
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
import asyncio
import dotenv
from leann.api import LeannBuilder, LeannChat
from pathlib import Path
dotenv.load_dotenv()
async def main(args):
INDEX_DIR = Path(args.index_dir)
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
if not INDEX_DIR.exists():
node_parser = SentenceSplitter(
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
)
print("Loading documents...")
documents = SimpleDirectoryReader(
args.data_dir,
recursive=True,
encoding="utf-8",
required_exts=[".pdf", ".txt", ".md"],
).load_data(show_progress=True)
print("Documents loaded.")
all_texts = []
for doc in documents:
nodes = node_parser.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
print("--- Index directory not found, building new index ---")
print("\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="facebook/contriever",
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1, # Force single-threaded mode
)
print(f"Loaded {len(all_texts)} text chunks from documents.")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(INDEX_PATH)
print(f"\nLeann index built at {INDEX_PATH}!")
else:
print(f"--- Using existing index at {INDEX_DIR} ---")
print(f"\n[PHASE 2] Starting Leann chat session...")
llm_config = {"type": "hf", "model": "Qwen/Qwen3-4B"}
llm_config = {"type": "ollama", "model": "qwen3:8b"}
llm_config = {"type": "openai", "model": "gpt-4o"}
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
# query = (
#     "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
# )
# (English gloss of the Chinese example query: "What is the Pangu large model, what dark sides came up during its development, and in which city are task orders usually issued?")
print(f"You: {query}")
chat_response = chat.ask(query, top_k=20, recompute_embeddings=True, complexity=32)
print(f"Leann: {chat_response}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Run Leann Chat with various LLM backends."
)
parser.add_argument(
"--llm",
type=str,
default="hf",
choices=["simulated", "ollama", "hf", "openai"],
help="The LLM backend to use.",
)
parser.add_argument(
"--model",
type=str,
default="Qwen/Qwen3-0.6B",
help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf, 'gpt-4o' for openai).",
)
parser.add_argument(
"--host",
type=str,
default="http://localhost:11434",
help="The host for the Ollama API.",
)
parser.add_argument(
"--index-dir",
type=str,
default="./test_doc_files",
help="Directory where the Leann index will be stored.",
)
parser.add_argument(
"--data-dir",
type=str,
default="examples/data",
help="Directory containing documents to index (PDF, TXT, MD files).",
)
args = parser.parse_args()
asyncio.run(main(args))
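The script above parses --llm, --model, and --host but then overwrites llm_config three times, so the flags are effectively ignored. A small sketch of wiring the flags through instead; the dict keys follow the llm_config shapes used above, while passing --host as a "host" key is an assumption.

def build_llm_config(args) -> dict:
    # Map the CLI flags onto the llm_config dict shapes used in main().
    config = {"type": args.llm, "model": args.model}
    if args.llm == "ollama":
        # The --host flag is defined for Ollama; forwarding it as "host" is an assumption.
        config["host"] = args.host
    return config

# Usage inside main(): chat = LeannChat(index_path=INDEX_PATH, llm_config=build_llm_config(args))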

View File

@@ -5,21 +5,24 @@ It correctly compares results by fetching the text content for both the new sear
results and the golden standard results, making the comparison robust to ID changes. results and the golden standard results, making the comparison robust to ID changes.
""" """
import argparse
import json import json
import sys import argparse
import time import time
from pathlib import Path from pathlib import Path
import sys
import numpy as np import numpy as np
from leann.api import LeannBuilder, LeannSearcher from typing import List
from leann.api import LeannSearcher, LeannBuilder
def download_data_if_needed(data_root: Path, download_embeddings: bool = False): def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
"""Checks if the data directory exists, and if not, downloads it from HF Hub.""" """Checks if the data directory exists, and if not, downloads it from HF Hub."""
if not data_root.exists(): if not data_root.exists():
print(f"Data directory '{data_root}' not found.") print(f"Data directory '{data_root}' not found.")
print("Downloading evaluation data from Hugging Face Hub... (this may take a moment)") print(
"Downloading evaluation data from Hugging Face Hub... (this may take a moment)"
)
try: try:
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
@@ -60,7 +63,7 @@ def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
sys.exit(1) sys.exit(1)
def download_embeddings_if_needed(data_root: Path, dataset_type: str | None = None): def download_embeddings_if_needed(data_root: Path, dataset_type: str = None):
"""Download embeddings files specifically.""" """Download embeddings files specifically."""
embeddings_dir = data_root / "embeddings" embeddings_dir = data_root / "embeddings"
@@ -98,7 +101,7 @@ def download_embeddings_if_needed(data_root: Path, dataset_type: str | None = No
# --- Helper Function to get Golden Passages --- # --- Helper Function to get Golden Passages ---
def get_golden_texts(searcher: LeannSearcher, golden_ids: list[int]) -> set: def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set:
""" """
Retrieves the text for golden passage IDs directly from the LeannSearcher's Retrieves the text for golden passage IDs directly from the LeannSearcher's
passage manager. passage manager.
@@ -110,20 +113,24 @@ def get_golden_texts(searcher: LeannSearcher, golden_ids: list[int]) -> set:
passage_data = searcher.passage_manager.get_passage(str(gid)) passage_data = searcher.passage_manager.get_passage(str(gid))
golden_texts.add(passage_data["text"]) golden_texts.add(passage_data["text"])
except KeyError: except KeyError:
print(f"Warning: Golden passage ID '{gid}' not found in the index's passage data.") print(
f"Warning: Golden passage ID '{gid}' not found in the index's passage data."
)
return golden_texts return golden_texts
def load_queries(file_path: Path) -> list[str]: def load_queries(file_path: Path) -> List[str]:
queries = [] queries = []
with open(file_path, encoding="utf-8") as f: with open(file_path, "r", encoding="utf-8") as f:
for line in f: for line in f:
data = json.loads(line) data = json.loads(line)
queries.append(data["query"]) queries.append(data["query"])
return queries return queries
def build_index_from_embeddings(embeddings_file: str, output_path: str, backend: str = "hnsw"): def build_index_from_embeddings(
embeddings_file: str, output_path: str, backend: str = "hnsw"
):
""" """
Build a LEANN index from pre-computed embeddings. Build a LEANN index from pre-computed embeddings.
@@ -166,7 +173,9 @@ def build_index_from_embeddings(embeddings_file: str, output_path: str, backend:
def main(): def main():
parser = argparse.ArgumentParser(description="Run recall evaluation on a LEANN index.") parser = argparse.ArgumentParser(
description="Run recall evaluation on a LEANN index."
)
parser.add_argument( parser.add_argument(
"index_path", "index_path",
type=str, type=str,
@@ -193,22 +202,26 @@ def main():
parser.add_argument( parser.add_argument(
"--num-queries", type=int, default=10, help="Number of queries to evaluate." "--num-queries", type=int, default=10, help="Number of queries to evaluate."
) )
parser.add_argument("--top-k", type=int, default=3, help="The 'k' value for recall@k.") parser.add_argument(
"--top-k", type=int, default=3, help="The 'k' value for recall@k."
)
parser.add_argument( parser.add_argument(
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW." "--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
) )
args = parser.parse_args() args = parser.parse_args()
# --- Path Configuration --- # --- Path Configuration ---
# Assumes a project structure where the script is in 'benchmarks/' # Assumes a project structure where the script is in 'examples/'
# and evaluation data is in 'benchmarks/data/'. # and data is in 'data/' at the project root.
script_dir = Path(__file__).resolve().parent project_root = Path(__file__).resolve().parent.parent
data_root = script_dir / "data" data_root = project_root / "data"
# Download data based on mode # Download data based on mode
if args.mode == "build": if args.mode == "build":
# For building mode, we need embeddings # For building mode, we need embeddings
download_data_if_needed(data_root, download_embeddings=False) # Basic data first download_data_if_needed(
data_root, download_embeddings=False
) # Basic data first
# Auto-detect dataset type and download embeddings # Auto-detect dataset type and download embeddings
if args.embeddings_file: if args.embeddings_file:
@@ -249,7 +262,9 @@ def main():
print(f"Index built successfully: {built_index_path}") print(f"Index built successfully: {built_index_path}")
# Ask if user wants to run evaluation # Ask if user wants to run evaluation
eval_response = input("Run evaluation on the built index? (y/n): ").strip().lower() eval_response = (
input("Run evaluation on the built index? (y/n): ").strip().lower()
)
if eval_response != "y": if eval_response != "y":
print("Index building complete. Exiting.") print("Index building complete. Exiting.")
return return
@@ -278,9 +293,11 @@ def main():
break break
if not args.index_path: if not args.index_path:
print("No indices found. The data download should have included pre-built indices.")
print( print(
"Please check the benchmarks/data/indices/ directory or provide --index-path manually." "No indices found. The data download should have included pre-built indices."
)
print(
"Please check the data/indices/ directory or provide --index-path manually."
) )
sys.exit(1) sys.exit(1)
@@ -293,10 +310,14 @@ def main():
else: else:
# Fallback: try to infer from the index directory name # Fallback: try to infer from the index directory name
dataset_type = Path(args.index_path).name dataset_type = Path(args.index_path).name
print(f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'.") print(
f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'."
)
queries_file = data_root / "queries" / "nq_open.jsonl" queries_file = data_root / "queries" / "nq_open.jsonl"
golden_results_file = data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json" golden_results_file = (
data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
)
print(f"INFO: Detected dataset type: {dataset_type}") print(f"INFO: Detected dataset type: {dataset_type}")
print(f"INFO: Using queries file: {queries_file}") print(f"INFO: Using queries file: {queries_file}")
@@ -306,7 +327,7 @@ def main():
searcher = LeannSearcher(args.index_path) searcher = LeannSearcher(args.index_path)
queries = load_queries(queries_file) queries = load_queries(queries_file)
with open(golden_results_file) as f: with open(golden_results_file, "r") as f:
golden_results_data = json.load(f) golden_results_data = json.load(f)
num_eval_queries = min(args.num_queries, len(queries)) num_eval_queries = min(args.num_queries, len(queries))
@@ -318,7 +339,9 @@ def main():
for i in range(num_eval_queries): for i in range(num_eval_queries):
start_time = time.time() start_time = time.time()
new_results = searcher.search(queries[i], top_k=args.top_k, ef=args.ef_search) new_results = searcher.search(
queries[i], top_k=args.top_k, ef=args.ef_search
)
search_times.append(time.time() - start_time) search_times.append(time.time() - start_time)
# Correct Recall Calculation: Based on TEXT content # Correct Recall Calculation: Based on TEXT content
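A minimal sketch of the text-based recall@k comparison the docstring describes; it assumes each result returned by searcher.search() exposes its passage text (the attribute name is an assumption, since the result type is not shown in this diff).

def recall_at_k(new_results, golden_texts: set[str]) -> float:
    # Compare by passage text rather than by ID, so renumbered indices still count as hits.
    retrieved_texts = {result.text for result in new_results}  # attribute name assumed
    if not golden_texts:
        return 0.0
    return len(retrieved_texts & golden_texts) / len(golden_texts)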

View File

@@ -0,0 +1,319 @@
import os
import asyncio
import dotenv
import argparse
from pathlib import Path
from typing import List, Any, Optional
from leann.api import LeannBuilder, LeannSearcher, LeannChat
from llama_index.core.node_parser import SentenceSplitter
import requests
import time
dotenv.load_dotenv()
# Default WeChat export directory
DEFAULT_WECHAT_EXPORT_DIR = "./wechat_export_direct"
def create_leann_index_from_multiple_wechat_exports(
export_dirs: List[Path],
index_path: str = "wechat_history_index.leann",
max_count: int = -1,
):
"""
Create LEANN index from multiple WeChat export data sources.
Args:
export_dirs: List of Path objects pointing to WeChat export directories
index_path: Path to save the LEANN index
max_count: Maximum number of chat entries to process per export
"""
print("Creating LEANN index from multiple WeChat export data sources...")
# Load documents using WeChatHistoryReader from history_data
from history_data.wechat_history import WeChatHistoryReader
reader = WeChatHistoryReader()
INDEX_DIR = Path(index_path).parent
if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---")
all_documents = []
total_processed = 0
# Process each WeChat export directory
for i, export_dir in enumerate(export_dirs):
print(
f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}"
)
try:
documents = reader.load_data(
wechat_export_dir=str(export_dir),
max_count=max_count,
concatenate_messages=True,  # Concatenate adjacent messages into per-conversation documents
)
if documents:
print(f"Loaded {len(documents)} chat documents from {export_dir}")
all_documents.extend(documents)
total_processed += len(documents)
# Check if we've reached the max count
if max_count > 0 and total_processed >= max_count:
print(f"Reached max count of {max_count} documents")
break
else:
print(f"No documents loaded from {export_dir}")
except Exception as e:
print(f"Error processing {export_dir}: {e}")
continue
if not all_documents:
print("No documents loaded from any source. Exiting.")
return None
print(
f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports and starting to split them into chunks"
)
# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
# Convert Documents to text strings and chunk them
all_texts = []
for doc in all_documents:
# Split the document into chunks
nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes:
text = '[Contact] means the message is from: ' + doc.metadata["contact_name"] + '\n' + node.get_content()
all_texts.append(text)
print(
f"Finished splitting {len(all_documents)} documents into {len(all_texts)} text chunks"
)
# Create LEANN index directory
print(f"--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="Qwen/Qwen3-Embedding-0.6B",
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1, # Force single-threaded mode
)
print(f"Adding {len(all_texts)} chat chunks to index...")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(index_path)
print(f"\nLEANN index built at {index_path}!")
else:
print(f"--- Using existing index at {INDEX_DIR} ---")
return index_path
def create_leann_index(
export_dir: str = None,
index_path: str = "wechat_history_index.leann",
max_count: int = 1000,
):
"""
Create LEANN index from WeChat chat history data.
Args:
export_dir: Path to the WeChat export directory (optional, uses default if None)
index_path: Path to save the LEANN index
max_count: Maximum number of chat entries to process
"""
print("Creating LEANN index from WeChat chat history data...")
INDEX_DIR = Path(index_path).parent
if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...")
# Load documents using WeChatHistoryReader from history_data
from history_data.wechat_history import WeChatHistoryReader
reader = WeChatHistoryReader()
documents = reader.load_data(
wechat_export_dir=export_dir,
max_count=max_count,
concatenate_messages=False, # Disable concatenation - one message per document
)
if not documents:
print("No documents loaded. Exiting.")
return None
print(f"Loaded {len(documents)} chat documents")
# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
# Convert Documents to text strings and chunk them
all_texts = []
for doc in documents:
# Split the document into chunks
nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
# Create LEANN index directory
print(f"--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ", # MLX-optimized model
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1, # Force single-threaded mode
)
print(f"Adding {len(all_texts)} chat chunks to index...")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(index_path)
print(f"\nLEANN index built at {index_path}!")
else:
print(f"--- Using existing index at {INDEX_DIR} ---")
return index_path
async def query_leann_index(index_path: str, query: str):
"""
Query the LEANN index.
Args:
index_path: Path to the LEANN index
query: The query string
"""
print(f"\n[PHASE 2] Starting Leann chat session...")
chat = LeannChat(index_path=index_path)
print(f"You: {query}")
chat_response = chat.ask(
query,
top_k=20,
recompute_beighbor_embeddings=True,
complexity=16,
beam_width=1,
llm_config={
"type": "openai",
"model": "gpt-4o",
"api_key": os.getenv("OPENAI_API_KEY"),
},
llm_kwargs={"temperature": 0.0, "max_tokens": 1000},
)
print(f"Leann: {chat_response}")
async def main():
"""Main function with integrated WeChat export functionality."""
# Parse command line arguments
parser = argparse.ArgumentParser(
description="LEANN WeChat History Reader - Create and query WeChat chat history index"
)
parser.add_argument(
"--export-dir",
type=str,
default=DEFAULT_WECHAT_EXPORT_DIR,
help=f"Directory to store WeChat exports (default: {DEFAULT_WECHAT_EXPORT_DIR})",
)
parser.add_argument(
"--index-dir",
type=str,
default="./wechat_history_magic_test_11Debug_new",
help="Directory to store the LEANN index (default: ./wechat_history_index_leann_test)",
)
parser.add_argument(
"--max-entries",
type=int,
default=50,
help="Maximum number of chat entries to process (default: 5000)",
)
parser.add_argument(
"--query",
type=str,
default=None,
help="Single query to run (default: runs example queries)",
)
parser.add_argument(
"--force-export",
action="store_true",
default=False,
help="Force re-export of WeChat data even if exports exist",
)
args = parser.parse_args()
INDEX_DIR = Path(args.index_dir)
INDEX_PATH = str(INDEX_DIR / "wechat_history.leann")
print(f"Using WeChat export directory: {args.export_dir}")
print(f"Index directory: {INDEX_DIR}")
print(f"Max entries: {args.max_entries}")
# Initialize WeChat reader with export capabilities
from history_data.wechat_history import WeChatHistoryReader
reader = WeChatHistoryReader()
# Find existing exports or create new ones using the centralized method
export_dirs = reader.find_or_export_wechat_data(args.export_dir)
if not export_dirs:
print("Failed to find or export WeChat data. Exiting.")
return
# Create or load the LEANN index from all sources
index_path = create_leann_index_from_multiple_wechat_exports(
export_dirs, INDEX_PATH, max_count=args.max_entries
)
if index_path:
if args.query:
# Run single query
await query_leann_index(index_path, args.query)
else:
# Example queries
queries = [
"我想买魔术师约翰逊的球衣,给我一些对应聊天记录?",
]
for query in queries:
print("\n" + "=" * 60)
await query_leann_index(index_path, query)
if __name__ == "__main__":
asyncio.run(main())
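A small sketch of querying the built WeChat index directly with LeannSearcher (imported at the top of this file), skipping the LLM step; only top_k is passed, since the other search parameters shown elsewhere in this diff are backend-specific, and the result objects are assumed to be an iterable of scored passages.

from leann.api import LeannSearcher

searcher = LeannSearcher("./wechat_history_magic_test_11Debug_new/wechat_history.leann")
results = searcher.search("Magic Johnson jersey", top_k=20)
for result in results:
    print(result)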

View File

@@ -0,0 +1 @@

View File

@@ -1 +1 @@
# This file makes the directory a Python package # This file makes the directory a Python package

View File

@@ -1,7 +1 @@
from . import diskann_backend as diskann_backend from . import diskann_backend
from . import graph_partition
# Export main classes and functions
from .graph_partition import GraphPartitioner, partition_graph
__all__ = ["GraphPartitioner", "diskann_backend", "graph_partition", "partition_graph"]

View File

@@ -1,20 +1,20 @@
import contextlib import numpy as np
import logging
import os import os
import struct import struct
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Any, Literal, Optional from typing import Dict, Any, List, Literal, Optional
import contextlib
import numpy as np import logging
import psutil
from leann.searcher_base import BaseSearcher
from leann.registry import register_backend
from leann.interface import ( from leann.interface import (
LeannBackendBuilderInterface,
LeannBackendFactoryInterface, LeannBackendFactoryInterface,
LeannBackendBuilderInterface,
LeannBackendSearcherInterface, LeannBackendSearcherInterface,
) )
from leann.registry import register_backend
from leann.searcher_base import BaseSearcher
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -85,43 +85,6 @@ def _write_vectors_to_bin(data: np.ndarray, file_path: Path):
f.write(data.tobytes()) f.write(data.tobytes())
def _calculate_smart_memory_config(data: np.ndarray) -> tuple[float, float]:
"""
Calculate smart memory configuration for DiskANN based on data size and system specs.
Args:
data: The embedding data array
Returns:
tuple: (search_memory_maximum, build_memory_maximum) in GB
"""
num_vectors, dim = data.shape
# Calculate embedding storage size
embedding_size_bytes = num_vectors * dim * 4 # float32 = 4 bytes
embedding_size_gb = embedding_size_bytes / (1024**3)
# search_memory_maximum: 1/10 of embedding size for optimal PQ compression
# This controls Product Quantization size - smaller means more compression
search_memory_gb = max(0.1, embedding_size_gb / 10) # At least 100MB
# build_memory_maximum: Based on available system RAM for sharding control
# This controls how much memory DiskANN uses during index construction
available_memory_gb = psutil.virtual_memory().available / (1024**3)
total_memory_gb = psutil.virtual_memory().total / (1024**3)
# Use 50% of available memory, but at least 2GB and at most 75% of total
build_memory_gb = max(2.0, min(available_memory_gb * 0.5, total_memory_gb * 0.75))
logger.info(
f"Smart memory config - Data: {embedding_size_gb:.2f}GB, "
f"Search mem: {search_memory_gb:.2f}GB (PQ control), "
f"Build mem: {build_memory_gb:.2f}GB (sharding control)"
)
return search_memory_gb, build_memory_gb
@register_backend("diskann") @register_backend("diskann")
class DiskannBackend(LeannBackendFactoryInterface): class DiskannBackend(LeannBackendFactoryInterface):
@staticmethod @staticmethod
@@ -137,72 +100,7 @@ class DiskannBuilder(LeannBackendBuilderInterface):
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.build_params = kwargs self.build_params = kwargs
def _safe_cleanup_after_partition(self, index_dir: Path, index_prefix: str): def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
"""
Safely cleanup files after partition.
In partition mode, C++ doesn't read _disk.index content,
so we can delete it if all derived files exist.
"""
disk_index_file = index_dir / f"{index_prefix}_disk.index"
beam_search_file = index_dir / f"{index_prefix}_disk_beam_search.index"
# Required files that C++ partition mode needs
# Note: C++ generates these with _disk.index suffix
disk_suffix = "_disk.index"
required_files = [
f"{index_prefix}{disk_suffix}_medoids.bin", # Critical: assert fails if missing
# Note: _centroids.bin is not created in single-shot build - C++ handles this automatically
f"{index_prefix}_pq_pivots.bin", # PQ table
f"{index_prefix}_pq_compressed.bin", # PQ compressed vectors
]
# Check if all required files exist
missing_files = []
for filename in required_files:
file_path = index_dir / filename
if not file_path.exists():
missing_files.append(filename)
if missing_files:
logger.warning(
f"Cannot safely delete _disk.index - missing required files: {missing_files}"
)
logger.info("Keeping all original files for safety")
return
# Calculate space savings
space_saved = 0
files_to_delete = []
if disk_index_file.exists():
space_saved += disk_index_file.stat().st_size
files_to_delete.append(disk_index_file)
if beam_search_file.exists():
space_saved += beam_search_file.stat().st_size
files_to_delete.append(beam_search_file)
# Safe to delete!
for file_to_delete in files_to_delete:
try:
os.remove(file_to_delete)
logger.info(f"✅ Safely deleted: {file_to_delete.name}")
except Exception as e:
logger.warning(f"Failed to delete {file_to_delete.name}: {e}")
if space_saved > 0:
space_saved_mb = space_saved / (1024 * 1024)
logger.info(f"💾 Space saved: {space_saved_mb:.1f} MB")
# Show what files are kept
logger.info("📁 Kept essential files for partition mode:")
for filename in required_files:
file_path = index_dir / filename
if file_path.exists():
size_mb = file_path.stat().st_size / (1024 * 1024)
logger.info(f" - {filename} ({size_mb:.1f} MB)")
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
path = Path(index_path) path = Path(index_path)
index_dir = path.parent index_dir = path.parent
index_prefix = path.stem index_prefix = path.stem
@@ -216,17 +114,6 @@ class DiskannBuilder(LeannBackendBuilderInterface):
_write_vectors_to_bin(data, index_dir / data_filename) _write_vectors_to_bin(data, index_dir / data_filename)
build_kwargs = {**self.build_params, **kwargs} build_kwargs = {**self.build_params, **kwargs}
# Extract is_recompute from nested backend_kwargs if needed
is_recompute = build_kwargs.get("is_recompute", False)
if not is_recompute and "backend_kwargs" in build_kwargs:
is_recompute = build_kwargs["backend_kwargs"].get("is_recompute", False)
# Flatten all backend_kwargs parameters to top level for compatibility
if "backend_kwargs" in build_kwargs:
nested_params = build_kwargs.pop("backend_kwargs")
build_kwargs.update(nested_params)
metric_enum = _get_diskann_metrics().get( metric_enum = _get_diskann_metrics().get(
build_kwargs.get("distance_metric", "mips").lower() build_kwargs.get("distance_metric", "mips").lower()
) )
@@ -235,16 +122,6 @@ class DiskannBuilder(LeannBackendBuilderInterface):
f"Unsupported distance_metric '{build_kwargs.get('distance_metric', 'unknown')}'." f"Unsupported distance_metric '{build_kwargs.get('distance_metric', 'unknown')}'."
) )
# Calculate smart memory configuration if not explicitly provided
if (
"search_memory_maximum" not in build_kwargs
or "build_memory_maximum" not in build_kwargs
):
smart_search_mem, smart_build_mem = _calculate_smart_memory_config(data)
else:
smart_search_mem = build_kwargs.get("search_memory_maximum", 4.0)
smart_build_mem = build_kwargs.get("build_memory_maximum", 8.0)
try: try:
from . import _diskannpy as diskannpy # type: ignore from . import _diskannpy as diskannpy # type: ignore
@@ -255,36 +132,12 @@ class DiskannBuilder(LeannBackendBuilderInterface):
index_prefix, index_prefix,
build_kwargs.get("complexity", 64), build_kwargs.get("complexity", 64),
build_kwargs.get("graph_degree", 32), build_kwargs.get("graph_degree", 32),
build_kwargs.get("search_memory_maximum", smart_search_mem), build_kwargs.get("search_memory_maximum", 4.0),
build_kwargs.get("build_memory_maximum", smart_build_mem), build_kwargs.get("build_memory_maximum", 8.0),
build_kwargs.get("num_threads", 8), build_kwargs.get("num_threads", 8),
build_kwargs.get("pq_disk_bytes", 0), build_kwargs.get("pq_disk_bytes", 0),
"", "",
) )
# Auto-partition if is_recompute is enabled
if build_kwargs.get("is_recompute", False):
logger.info("is_recompute=True, starting automatic graph partitioning...")
from .graph_partition import partition_graph
# Partition the index using absolute paths
# Convert to absolute paths to avoid issues with working directory changes
absolute_index_dir = Path(index_dir).resolve()
absolute_index_prefix_path = str(absolute_index_dir / index_prefix)
disk_graph_path, partition_bin_path = partition_graph(
index_prefix_path=absolute_index_prefix_path,
output_dir=str(absolute_index_dir),
partition_prefix=index_prefix,
)
# Safe cleanup: In partition mode, C++ doesn't read _disk.index content
# but still needs the derived files (_medoids.bin, _centroids.bin, etc.)
self._safe_cleanup_after_partition(index_dir, index_prefix)
logger.info("✅ Graph partitioning completed successfully!")
logger.info(f" - Disk graph: {disk_graph_path}")
logger.info(f" - Partition file: {partition_bin_path}")
finally: finally:
temp_data_file = index_dir / data_filename temp_data_file = index_dir / data_filename
if temp_data_file.exists(): if temp_data_file.exists():
@@ -311,69 +164,18 @@ class DiskannSearcher(BaseSearcher):
self.num_threads = kwargs.get("num_threads", 8) self.num_threads = kwargs.get("num_threads", 8)
# For DiskANN, we need to reinitialize the index when zmq_port changes fake_zmq_port = 6666
# Store the initialization parameters for later use full_index_prefix = str(self.index_dir / self.index_path.stem)
# Note: C++ load method expects the BASE path (without _disk.index suffix) self._index = diskannpy.StaticDiskFloatIndex(
# C++ internally constructs: index_prefix + "_disk.index" metric_enum,
index_name = self.index_path.stem # "simple_test.leann" -> "simple_test" full_index_prefix,
diskann_index_prefix = str(self.index_dir / index_name) # /path/to/simple_test self.num_threads,
full_index_prefix = diskann_index_prefix # /path/to/simple_test (base path) kwargs.get("num_nodes_to_cache", 0),
1,
# Auto-detect partition files and set partition_prefix fake_zmq_port, # Initial port, can be updated at runtime
partition_graph_file = self.index_dir / f"{index_name}_disk_graph.index" "",
partition_bin_file = self.index_dir / f"{index_name}_partition.bin" "",
)
partition_prefix = ""
if partition_graph_file.exists() and partition_bin_file.exists():
# C++ expects full path prefix, not just filename
partition_prefix = str(self.index_dir / index_name) # /path/to/simple_test
logger.info(
f"✅ Detected partition files, using partition_prefix='{partition_prefix}'"
)
else:
logger.debug("No partition files detected, using standard index files")
self._init_params = {
"metric_enum": metric_enum,
"full_index_prefix": full_index_prefix,
"num_threads": self.num_threads,
"num_nodes_to_cache": kwargs.get("num_nodes_to_cache", 0),
"cache_mechanism": 1,
"pq_prefix": "",
"partition_prefix": partition_prefix,
}
# Log partition configuration for debugging
if partition_prefix:
logger.info(
f"✅ Detected partition files, using partition_prefix='{partition_prefix}'"
)
self._diskannpy = diskannpy
self._current_zmq_port = None
self._index = None
logger.debug("DiskANN searcher initialized (index will be loaded on first search)")
def _ensure_index_loaded(self, zmq_port: int):
"""Ensure the index is loaded with the correct zmq_port."""
if self._index is None or self._current_zmq_port != zmq_port:
# Need to (re)load the index with the correct zmq_port
with suppress_cpp_output_if_needed():
if self._index is not None:
logger.debug(f"Reloading DiskANN index with new zmq_port: {zmq_port}")
else:
logger.debug(f"Loading DiskANN index with zmq_port: {zmq_port}")
self._index = self._diskannpy.StaticDiskFloatIndex(
self._init_params["metric_enum"],
self._init_params["full_index_prefix"],
self._init_params["num_threads"],
self._init_params["num_nodes_to_cache"],
self._init_params["cache_mechanism"],
zmq_port,
self._init_params["pq_prefix"],
self._init_params["partition_prefix"],
)
self._current_zmq_port = zmq_port
def search( def search(
self, self,
@@ -388,7 +190,7 @@ class DiskannSearcher(BaseSearcher):
batch_recompute: bool = False, batch_recompute: bool = False,
dedup_node_dis: bool = False, dedup_node_dis: bool = False,
**kwargs, **kwargs,
) -> dict[str, Any]: ) -> Dict[str, Any]:
""" """
Search for nearest neighbors using DiskANN index. Search for nearest neighbors using DiskANN index.
@@ -411,15 +213,18 @@ class DiskannSearcher(BaseSearcher):
Returns: Returns:
Dict with 'labels' (list of lists) and 'distances' (ndarray) Dict with 'labels' (list of lists) and 'distances' (ndarray)
""" """
# Handle zmq_port compatibility: Ensure index is loaded with correct port # Handle zmq_port compatibility: DiskANN can now update port at runtime
if recompute_embeddings: if recompute_embeddings:
if zmq_port is None: if zmq_port is None:
raise ValueError("zmq_port must be provided if recompute_embeddings is True") raise ValueError(
self._ensure_index_loaded(zmq_port) "zmq_port must be provided if recompute_embeddings is True"
else: )
# If not recomputing, we still need an index, use a default port current_port = self._index.get_zmq_port()
if self._index is None: if zmq_port != current_port:
self._ensure_index_loaded(6666) # Default port when not recomputing logger.debug(
f"Updating DiskANN zmq_port from {current_port} to {zmq_port}"
)
self._index.set_zmq_port(zmq_port)
# DiskANN doesn't support "proportional" strategy # DiskANN doesn't support "proportional" strategy
if pruning_strategy == "proportional": if pruning_strategy == "proportional":
@@ -437,8 +242,6 @@ class DiskannSearcher(BaseSearcher):
use_global_pruning = True use_global_pruning = True
# Perform search with suppressed C++ output based on log level # Perform search with suppressed C++ output based on log level
use_deferred_fetch = kwargs.get("USE_DEFERRED_FETCH", True)
recompute_neighors = False
with suppress_cpp_output_if_needed(): with suppress_cpp_output_if_needed():
labels, distances = self._index.batch_search( labels, distances = self._index.batch_search(
query, query,
@@ -447,37 +250,17 @@ class DiskannSearcher(BaseSearcher):
complexity, complexity,
beam_width, beam_width,
self.num_threads, self.num_threads,
use_deferred_fetch, kwargs.get("USE_DEFERRED_FETCH", False),
kwargs.get("skip_search_reorder", False), kwargs.get("skip_search_reorder", False),
recompute_neighors, recompute_embeddings,
dedup_node_dis, dedup_node_dis,
prune_ratio, prune_ratio,
batch_recompute, batch_recompute,
use_global_pruning, use_global_pruning,
) )
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels] string_labels = [
[str(int_label) for int_label in batch_labels] for batch_labels in labels
]
return {"labels": string_labels, "distances": distances} return {"labels": string_labels, "distances": distances}
def cleanup(self):
"""Cleanup DiskANN-specific resources including C++ index."""
# Call parent cleanup first
super().cleanup()
# Delete the C++ index to trigger destructors
try:
if hasattr(self, "_index") and self._index is not None:
del self._index
self._index = None
self._current_zmq_port = None
except Exception:
pass
# Force garbage collection to ensure C++ objects are destroyed
try:
import gc
gc.collect()
except Exception:
pass
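For context, a hedged sketch of selecting this backend from the high-level API; the keyword names mirror the build_kwargs read in DiskannBuilder.build() above, but whether LeannBuilder forwards them unchanged is an assumption.

from leann.api import LeannBuilder

builder = LeannBuilder(
    backend_name="diskann",               # registered via @register_backend("diskann")
    embedding_model="facebook/contriever",
    complexity=64,                        # build_kwargs.get("complexity", 64)
    graph_degree=32,                      # build_kwargs.get("graph_degree", 32)
    num_threads=8,                        # build_kwargs.get("num_threads", 8)
    distance_metric="mips",               # default metric looked up in _get_diskann_metrics()
)
builder.add_text("example passage")
builder.build_index("./diskann_example/index.leann")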

View File

@@ -3,17 +3,16 @@ DiskANN-specific embedding server
""" """
import argparse import argparse
import json
import logging
import os
import sys
import threading import threading
import time import time
import os
import zmq
import numpy as np
import json
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import sys
import numpy as np import logging
import zmq
# Set up logging based on environment variable # Set up logging based on environment variable
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper() LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
@@ -37,7 +36,6 @@ def create_diskann_embedding_server(
zmq_port: int = 5555, zmq_port: int = 5555,
model_name: str = "sentence-transformers/all-mpnet-base-v2", model_name: str = "sentence-transformers/all-mpnet-base-v2",
embedding_mode: str = "sentence-transformers", embedding_mode: str = "sentence-transformers",
distance_metric: str = "l2",
): ):
""" """
Create and start a ZMQ-based embedding server for DiskANN backend. Create and start a ZMQ-based embedding server for DiskANN backend.
@@ -52,8 +50,8 @@ def create_diskann_embedding_server(
sys.path.insert(0, str(leann_core_path)) sys.path.insert(0, str(leann_core_path))
try: try:
from leann.api import PassageManager
from leann.embedding_compute import compute_embeddings from leann.embedding_compute import compute_embeddings
from leann.api import PassageManager
logger.info("Successfully imported unified embedding computation module") logger.info("Successfully imported unified embedding computation module")
except ImportError as e: except ImportError as e:
@@ -78,11 +76,10 @@ def create_diskann_embedding_server(
raise ValueError("Only metadata files (.meta.json) are supported") raise ValueError("Only metadata files (.meta.json) are supported")
# Load metadata to get passage sources # Load metadata to get passage sources
with open(passages_file) as f: with open(passages_file, "r") as f:
meta = json.load(f) meta = json.load(f)
logger.info(f"Loading PassageManager with metadata_file_path: {passages_file}") passages = PassageManager(meta["passage_sources"])
passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file)
logger.info( logger.info(
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata" f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
) )
@@ -100,7 +97,6 @@ def create_diskann_embedding_server(
socket = context.socket( socket = context.socket(
zmq.REP zmq.REP
) # REP socket for both BaseSearcher and DiskANN C++ REQ clients ) # REP socket for both BaseSearcher and DiskANN C++ REQ clients
socket.setsockopt(zmq.LINGER, 0) # Don't block on close
socket.bind(f"tcp://*:{zmq_port}") socket.bind(f"tcp://*:{zmq_port}")
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}") logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
@@ -154,7 +150,9 @@ def create_diskann_embedding_server(
): ):
texts = request texts = request
is_text_request = True is_text_request = True
logger.info(f"✅ MSGPACK: Direct text request for {len(texts)} texts") logger.info(
f"✅ MSGPACK: Direct text request for {len(texts)} texts"
)
else: else:
raise ValueError("Not a valid msgpack text request") raise ValueError("Not a valid msgpack text request")
except Exception as msgpack_error: except Exception as msgpack_error:
@@ -169,7 +167,9 @@ def create_diskann_embedding_server(
passage_data = passages.get_passage(str(nid)) passage_data = passages.get_passage(str(nid))
txt = passage_data["text"] txt = passage_data["text"]
if not txt: if not txt:
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}") raise RuntimeError(
f"FATAL: Empty text for passage ID {nid}"
)
texts.append(txt) texts.append(txt)
except KeyError as e: except KeyError as e:
logger.error(f"Passage ID {nid} not found: {e}") logger.error(f"Passage ID {nid} not found: {e}")
@@ -180,7 +180,9 @@ def create_diskann_embedding_server(
# Debug logging # Debug logging
logger.debug(f"Processing {len(texts)} texts") logger.debug(f"Processing {len(texts)} texts")
logger.debug(f"Text lengths: {[len(t) for t in texts[:5]]}") # Show first 5 logger.debug(
f"Text lengths: {[len(t) for t in texts[:5]]}"
) # Show first 5
# Process embeddings using unified computation # Process embeddings using unified computation
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode) embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
@@ -197,7 +199,9 @@ def create_diskann_embedding_server(
else: else:
# For DiskANN C++ compatibility: return protobuf format # For DiskANN C++ compatibility: return protobuf format
resp_proto = embedding_pb2.NodeEmbeddingResponse() resp_proto = embedding_pb2.NodeEmbeddingResponse()
hidden_contiguous = np.ascontiguousarray(embeddings, dtype=np.float32) hidden_contiguous = np.ascontiguousarray(
embeddings, dtype=np.float32
)
# Serialize embeddings data # Serialize embeddings data
resp_proto.embeddings_data = hidden_contiguous.tobytes() resp_proto.embeddings_data = hidden_contiguous.tobytes()
@@ -264,16 +268,9 @@ if __name__ == "__main__":
"--embedding-mode", "--embedding-mode",
type=str, type=str,
default="sentence-transformers", default="sentence-transformers",
choices=["sentence-transformers", "openai", "mlx", "ollama"], choices=["sentence-transformers", "openai", "mlx"],
help="Embedding backend mode", help="Embedding backend mode",
) )
parser.add_argument(
"--distance-metric",
type=str,
default="l2",
choices=["l2", "mips", "cosine"],
help="Distance metric for similarity computation",
)
args = parser.parse_args() args = parser.parse_args()
@@ -283,5 +280,4 @@ if __name__ == "__main__":
zmq_port=args.zmq_port, zmq_port=args.zmq_port,
model_name=args.model_name, model_name=args.model_name,
embedding_mode=args.embedding_mode, embedding_mode=args.embedding_mode,
distance_metric=args.distance_metric,
) )
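A minimal client sketch for the server above, using the msgpack "direct text request" path over the REP socket; the layout of the reply is not shown in this diff, so the decode step is only indicative.

import msgpack
import zmq

context = zmq.Context()
socket = context.socket(zmq.REQ)        # pairs with the server's REP socket
socket.connect("tcp://localhost:5555")

# Send a plain list of texts, which the server treats as a direct text request.
socket.send(msgpack.packb(["hello world", "another passage"]))
reply = socket.recv()
payload = msgpack.unpackb(reply)        # exact reply structure is an assumption
print(type(payload))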

View File

@@ -1,28 +1,27 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: embedding.proto
# ruff: noqa
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x0f\x65mbedding.proto\x12\x0eprotoembedding"(\n\x14NodeEmbeddingRequest\x12\x10\n\x08node_ids\x18\x01 \x03(\r"Y\n\x15NodeEmbeddingResponse\x12\x17\n\x0f\x65mbeddings_data\x18\x01 \x01(\x0c\x12\x12\n\ndimensions\x18\x02 \x03(\x05\x12\x13\n\x0bmissing_ids\x18\x03 \x03(\rb\x06proto3'
)
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "embedding_pb2", globals())
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._options = None
_NODEEMBEDDINGREQUEST._serialized_start = 35
_NODEEMBEDDINGREQUEST._serialized_end = 75
_NODEEMBEDDINGRESPONSE._serialized_start = 77
_NODEEMBEDDINGRESPONSE._serialized_end = 166
# @@protoc_insertion_point(module_scope)

View File

@@ -1,299 +0,0 @@
#!/usr/bin/env python3
"""
Graph Partition Module for LEANN DiskANN Backend
This module provides Python bindings for the graph partition functionality
of DiskANN, allowing users to partition disk-based indices for better
performance.
"""
import os
import shutil
import subprocess
import tempfile
from pathlib import Path
from typing import Optional
class GraphPartitioner:
"""
A Python interface for DiskANN's graph partition functionality.
This class provides methods to partition disk-based indices for improved
search performance and memory efficiency.
"""
def __init__(self, build_type: str = "release"):
"""
Initialize the GraphPartitioner.
Args:
build_type: Build type for the executables ("debug" or "release")
"""
self.build_type = build_type
self._ensure_executables()
def _get_executable_path(self, name: str) -> str:
"""Get the path to a graph partition executable."""
# Get the directory where this Python module is located
module_dir = Path(__file__).parent
# Navigate to the graph_partition directory
graph_partition_dir = module_dir.parent / "third_party" / "DiskANN" / "graph_partition"
executable_path = graph_partition_dir / "build" / self.build_type / "graph_partition" / name
if not executable_path.exists():
raise FileNotFoundError(f"Executable {name} not found at {executable_path}")
return str(executable_path)
def _ensure_executables(self):
"""Ensure that the required executables are built."""
try:
self._get_executable_path("partitioner")
self._get_executable_path("index_relayout")
except FileNotFoundError:
# Try to build the executables automatically
print("Executables not found, attempting to build them...")
self._build_executables()
def _build_executables(self):
"""Build the required executables."""
graph_partition_dir = (
Path(__file__).parent.parent / "third_party" / "DiskANN" / "graph_partition"
)
original_dir = os.getcwd()
try:
os.chdir(graph_partition_dir)
# Clean any existing build
if (graph_partition_dir / "build").exists():
shutil.rmtree(graph_partition_dir / "build")
# Run the build script
cmd = ["./build.sh", self.build_type, "split_graph", "/tmp/dummy"]
subprocess.run(cmd, capture_output=True, text=True, cwd=graph_partition_dir)
# Check if executables were created
partitioner_path = self._get_executable_path("partitioner")
relayout_path = self._get_executable_path("index_relayout")
print(f"✅ Built partitioner: {partitioner_path}")
print(f"✅ Built index_relayout: {relayout_path}")
except Exception as e:
raise RuntimeError(f"Failed to build executables: {e}")
finally:
os.chdir(original_dir)
def partition_graph(
self,
index_prefix_path: str,
output_dir: Optional[str] = None,
partition_prefix: Optional[str] = None,
**kwargs,
) -> tuple[str, str]:
"""
Partition a disk-based index for improved performance.
Args:
index_prefix_path: Path to the index prefix (e.g., "/path/to/index")
output_dir: Output directory for results (defaults to parent of index_prefix_path)
partition_prefix: Prefix for output files (defaults to basename of index_prefix_path)
**kwargs: Additional parameters for graph partitioning:
- gp_times: Number of LDG partition iterations (default: 10)
- lock_nums: Number of lock nodes (default: 10)
- cut: Cut adjacency list degree (default: 100)
- scale_factor: Scale factor (default: 1)
- data_type: Data type (default: "float")
- thread_nums: Number of threads (default: 10)
Returns:
Tuple of (disk_graph_index_path, partition_bin_path)
Raises:
RuntimeError: If the partitioning process fails
"""
# Set default parameters
params = {
"gp_times": 10,
"lock_nums": 10,
"cut": 100,
"scale_factor": 1,
"data_type": "float",
"thread_nums": 10,
**kwargs,
}
# Determine output directory
if output_dir is None:
output_dir = str(Path(index_prefix_path).parent)
# Create output directory if it doesn't exist
Path(output_dir).mkdir(parents=True, exist_ok=True)
# Determine partition prefix
if partition_prefix is None:
partition_prefix = Path(index_prefix_path).name
# Get executable paths
partitioner_path = self._get_executable_path("partitioner")
relayout_path = self._get_executable_path("index_relayout")
# Create temporary directory for processing
with tempfile.TemporaryDirectory() as temp_dir:
# Change to the graph_partition directory for temporary files
graph_partition_dir = (
Path(__file__).parent.parent / "third_party" / "DiskANN" / "graph_partition"
)
original_dir = os.getcwd()
try:
os.chdir(graph_partition_dir)
# Create temporary data directory
temp_data_dir = Path(temp_dir) / "data"
temp_data_dir.mkdir(parents=True, exist_ok=True)
# Set up paths for temporary files
graph_path = temp_data_dir / "starling" / "_M_R_L_B" / "GRAPH"
graph_gp_path = (
graph_path
/ f"GP_TIMES_{params['gp_times']}_LOCK_{params['lock_nums']}_GP_USE_FREQ0_CUT{params['cut']}_SCALE{params['scale_factor']}"
)
graph_gp_path.mkdir(parents=True, exist_ok=True)
# Find input index file
old_index_file = f"{index_prefix_path}_disk_beam_search.index"
if not os.path.exists(old_index_file):
old_index_file = f"{index_prefix_path}_disk.index"
if not os.path.exists(old_index_file):
raise RuntimeError(f"Index file not found: {old_index_file}")
# Run partitioner
gp_file_path = graph_gp_path / "_part.bin"
partitioner_cmd = [
partitioner_path,
"--index_file",
old_index_file,
"--data_type",
params["data_type"],
"--gp_file",
str(gp_file_path),
"-T",
str(params["thread_nums"]),
"--ldg_times",
str(params["gp_times"]),
"--scale",
str(params["scale_factor"]),
"--mode",
"1",
]
print(f"Running partitioner: {' '.join(partitioner_cmd)}")
result = subprocess.run(
partitioner_cmd, capture_output=True, text=True, cwd=graph_partition_dir
)
if result.returncode != 0:
raise RuntimeError(
f"Partitioner failed with return code {result.returncode}.\n"
f"stdout: {result.stdout}\n"
f"stderr: {result.stderr}"
)
# Run relayout
part_tmp_index = graph_gp_path / "_part_tmp.index"
relayout_cmd = [
relayout_path,
old_index_file,
str(gp_file_path),
params["data_type"],
"1",
]
print(f"Running relayout: {' '.join(relayout_cmd)}")
result = subprocess.run(
relayout_cmd, capture_output=True, text=True, cwd=graph_partition_dir
)
if result.returncode != 0:
raise RuntimeError(
f"Relayout failed with return code {result.returncode}.\n"
f"stdout: {result.stdout}\n"
f"stderr: {result.stderr}"
)
# Copy results to output directory
disk_graph_path = Path(output_dir) / f"{partition_prefix}_disk_graph.index"
partition_bin_path = Path(output_dir) / f"{partition_prefix}_partition.bin"
shutil.copy2(part_tmp_index, disk_graph_path)
shutil.copy2(gp_file_path, partition_bin_path)
print(f"Results copied to: {output_dir}")
return str(disk_graph_path), str(partition_bin_path)
finally:
os.chdir(original_dir)
def get_partition_info(self, partition_bin_path: str) -> dict:
"""
Get information about a partition file.
Args:
partition_bin_path: Path to the partition binary file
Returns:
Dictionary containing partition information
"""
if not os.path.exists(partition_bin_path):
raise FileNotFoundError(f"Partition file not found: {partition_bin_path}")
# For now, return basic file information
# In the future, this could parse the binary file for detailed info
stat = os.stat(partition_bin_path)
return {
"file_size": stat.st_size,
"file_path": partition_bin_path,
"modified_time": stat.st_mtime,
}
def partition_graph(
index_prefix_path: str,
output_dir: Optional[str] = None,
partition_prefix: Optional[str] = None,
build_type: str = "release",
**kwargs,
) -> tuple[str, str]:
"""
Convenience function to partition a graph index.
Args:
index_prefix_path: Path to the index prefix
output_dir: Output directory (defaults to parent of index_prefix_path)
partition_prefix: Prefix for output files (defaults to basename of index_prefix_path)
build_type: Build type for executables ("debug" or "release")
**kwargs: Additional parameters for graph partitioning
Returns:
Tuple of (disk_graph_index_path, partition_bin_path)
"""
partitioner = GraphPartitioner(build_type=build_type)
return partitioner.partition_graph(index_prefix_path, output_dir, partition_prefix, **kwargs)
# Example usage:
if __name__ == "__main__":
# Example: partition an index
try:
disk_graph_path, partition_bin_path = partition_graph(
"/path/to/your/index_prefix", gp_times=10, lock_nums=10, cut=100
)
print("Partitioning completed successfully!")
print(f"Disk graph index: {disk_graph_path}")
print(f"Partition binary: {partition_bin_path}")
except Exception as e:
print(f"Partitioning failed: {e}")

View File

@@ -1,137 +0,0 @@
#!/usr/bin/env python3
"""
Simplified Graph Partition Module for LEANN DiskANN Backend
This module provides a simple Python interface for graph partitioning
that directly calls the existing executables.
"""
import os
import subprocess
import tempfile
from pathlib import Path
from typing import Optional
def partition_graph_simple(
index_prefix_path: str, output_dir: Optional[str] = None, **kwargs
) -> tuple[str, str]:
"""
Simple function to partition a graph index.
Args:
index_prefix_path: Path to the index prefix (e.g., "/path/to/index")
output_dir: Output directory (defaults to parent of index_prefix_path)
**kwargs: Additional parameters for graph partitioning
Returns:
Tuple of (disk_graph_index_path, partition_bin_path)
"""
# Set default parameters
params = {
"gp_times": 10,
"lock_nums": 10,
"cut": 100,
"scale_factor": 1,
"data_type": "float",
"thread_nums": 10,
**kwargs,
}
# Determine output directory
if output_dir is None:
output_dir = str(Path(index_prefix_path).parent)
# Find the graph_partition directory
current_file = Path(__file__)
graph_partition_dir = current_file.parent.parent / "third_party" / "DiskANN" / "graph_partition"
if not graph_partition_dir.exists():
raise RuntimeError(f"Graph partition directory not found: {graph_partition_dir}")
# Find input index file
old_index_file = f"{index_prefix_path}_disk_beam_search.index"
if not os.path.exists(old_index_file):
old_index_file = f"{index_prefix_path}_disk.index"
if not os.path.exists(old_index_file):
raise RuntimeError(f"Index file not found: {old_index_file}")
# Create temporary directory for processing
with tempfile.TemporaryDirectory() as temp_dir:
temp_data_dir = Path(temp_dir) / "data"
temp_data_dir.mkdir(parents=True, exist_ok=True)
# Set up paths for temporary files
graph_path = temp_data_dir / "starling" / "_M_R_L_B" / "GRAPH"
graph_gp_path = (
graph_path
/ f"GP_TIMES_{params['gp_times']}_LOCK_{params['lock_nums']}_GP_USE_FREQ0_CUT{params['cut']}_SCALE{params['scale_factor']}"
)
graph_gp_path.mkdir(parents=True, exist_ok=True)
# Run the build script with our parameters
cmd = [str(graph_partition_dir / "build.sh"), "release", "split_graph", index_prefix_path]
# Set environment variables for parameters
env = os.environ.copy()
env.update(
{
"GP_TIMES": str(params["gp_times"]),
"GP_LOCK_NUMS": str(params["lock_nums"]),
"GP_CUT": str(params["cut"]),
"GP_SCALE_F": str(params["scale_factor"]),
"DATA_TYPE": params["data_type"],
"GP_T": str(params["thread_nums"]),
}
)
print(f"Running graph partition with command: {' '.join(cmd)}")
print(f"Working directory: {graph_partition_dir}")
# Run the command
result = subprocess.run(
cmd, env=env, capture_output=True, text=True, cwd=graph_partition_dir
)
if result.returncode != 0:
print(f"Command failed with return code {result.returncode}")
print(f"stdout: {result.stdout}")
print(f"stderr: {result.stderr}")
raise RuntimeError(
f"Graph partitioning failed with return code {result.returncode}.\n"
f"stdout: {result.stdout}\n"
f"stderr: {result.stderr}"
)
# Check if output files were created
disk_graph_path = Path(output_dir) / "_disk_graph.index"
partition_bin_path = Path(output_dir) / "_partition.bin"
if not disk_graph_path.exists():
raise RuntimeError(f"Expected output file not found: {disk_graph_path}")
if not partition_bin_path.exists():
raise RuntimeError(f"Expected output file not found: {partition_bin_path}")
print("✅ Partitioning completed successfully!")
print(f" Disk graph index: {disk_graph_path}")
print(f" Partition binary: {partition_bin_path}")
return str(disk_graph_path), str(partition_bin_path)
# Example usage
if __name__ == "__main__":
try:
disk_graph_path, partition_bin_path = partition_graph_simple(
"/Users/yichuan/Desktop/release2/leann/diskannbuild/test_doc_files",
gp_times=5,
lock_nums=5,
cut=50,
)
print("Success! Output files:")
print(f" - {disk_graph_path}")
print(f" - {partition_bin_path}")
except Exception as e:
print(f"Error: {e}")

View File

@@ -4,8 +4,8 @@ build-backend = "scikit_build_core.build"
[project]
name = "leann-backend-diskann"
version = "0.2.7"
dependencies = ["leann-core==0.2.7", "numpy", "protobuf>=3.19.0"]
[tool.scikit-build]
# Key: simplified CMake path
@@ -16,4 +16,4 @@ wheel.packages = ["leann_backend_diskann"]
editable.mode = "redirect"
cmake.build-type = "Release"
build.verbose = true
build.tool-args = ["-j8"]

View File

@@ -2,12 +2,12 @@ syntax = "proto3";
package protoembedding;
message NodeEmbeddingRequest {
repeated uint32 node_ids = 1;
}
message NodeEmbeddingResponse {
bytes embeddings_data = 1; // All embedded binary datas
repeated int32 dimensions = 2; // Shape [batch_size, embedding_dim]
repeated uint32 missing_ids = 3; // Missing node ids
}
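For orientation, a round trip with these two messages through the generated embedding_pb2 module might be sketched as below. Treat it as an illustration only: the dimension values and the zero-filled payload are made up here, and the real transport and reply handling live in the embedding server shown earlier.

# Hypothetical sketch using the generated embedding_pb2 bindings (illustrative only).
import numpy as np
import embedding_pb2  # module generated from this .proto

# Build a request for a few node IDs and serialize it for the wire.
req = embedding_pb2.NodeEmbeddingRequest()
req.node_ids.extend([0, 1, 2])
payload = req.SerializeToString()

# Simulate a response: float32 embeddings stored row-major in embeddings_data.
resp = embedding_pb2.NodeEmbeddingResponse()
resp.dimensions.extend([2, 4])  # [batch_size, embedding_dim]
resp.embeddings_data = np.zeros((2, 4), dtype=np.float32).tobytes()
wire = resp.SerializeToString()

# Parse it back and recover the embedding matrix.
parsed = embedding_pb2.NodeEmbeddingResponse()
parsed.ParseFromString(wire)
batch, dim = parsed.dimensions
embeddings = np.frombuffer(parsed.embeddings_data, dtype=np.float32).reshape(batch, dim)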

View File

@@ -10,14 +10,6 @@ if(APPLE)
set(OpenMP_C_LIB_NAMES "omp")
set(OpenMP_CXX_LIB_NAMES "omp")
set(OpenMP_omp_LIBRARY "/opt/homebrew/opt/libomp/lib/libomp.dylib")
# Force use of system libc++ to avoid version mismatch
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -stdlib=libc++")
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -stdlib=libc++")
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -stdlib=libc++")
# Set minimum macOS version for better compatibility
set(CMAKE_OSX_DEPLOYMENT_TARGET "11.0" CACHE STRING "Minimum macOS version")
endif()
# Use system ZeroMQ instead of building from source
@@ -60,4 +52,4 @@ set(FAISS_BUILD_AVX512 OFF CACHE BOOL "" FORCE)
# IMPORTANT: Disable building AVX versions to speed up compilation
set(FAISS_BUILD_AVX_VERSIONS OFF CACHE BOOL "" FORCE)
add_subdirectory(third_party/faiss)

View File

@@ -1 +1 @@
from . import hnsw_backend as hnsw_backend

View File

@@ -1,122 +1,87 @@
import argparse
import gc # Import garbage collector interface
import logging
import os
import struct import struct
import sys import sys
import time
import numpy as np import numpy as np
import os
# Set up logging to avoid print buffer issues import argparse
logger = logging.getLogger(__name__) import gc # Import garbage collector interface
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper() import time
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
logger.setLevel(log_level)
# --- FourCCs (add more if needed) --- # --- FourCCs (add more if needed) ---
INDEX_HNSW_FLAT_FOURCC = int.from_bytes(b"IHNf", "little") INDEX_HNSW_FLAT_FOURCC = int.from_bytes(b'IHNf', 'little')
# Add other HNSW fourccs if you expect different storage types inside HNSW # Add other HNSW fourccs if you expect different storage types inside HNSW
# INDEX_HNSW_PQ_FOURCC = int.from_bytes(b'IHNp', 'little') # INDEX_HNSW_PQ_FOURCC = int.from_bytes(b'IHNp', 'little')
# INDEX_HNSW_SQ_FOURCC = int.from_bytes(b'IHNs', 'little') # INDEX_HNSW_SQ_FOURCC = int.from_bytes(b'IHNs', 'little')
# INDEX_HNSW_CAGRA_FOURCC = int.from_bytes(b'IHNc', 'little') # Example # INDEX_HNSW_CAGRA_FOURCC = int.from_bytes(b'IHNc', 'little') # Example
EXPECTED_HNSW_FOURCCS = {INDEX_HNSW_FLAT_FOURCC} # Modify if needed EXPECTED_HNSW_FOURCCS = {INDEX_HNSW_FLAT_FOURCC} # Modify if needed
NULL_INDEX_FOURCC = int.from_bytes(b"null", "little") NULL_INDEX_FOURCC = int.from_bytes(b'null', 'little')
# --- Helper functions for reading/writing binary data --- # --- Helper functions for reading/writing binary data ---
def read_struct(f, fmt): def read_struct(f, fmt):
"""Reads data according to the struct format.""" """Reads data according to the struct format."""
size = struct.calcsize(fmt) size = struct.calcsize(fmt)
data = f.read(size) data = f.read(size)
if len(data) != size: if len(data) != size:
raise EOFError( raise EOFError(f"File ended unexpectedly reading struct fmt '{fmt}'. Expected {size} bytes, got {len(data)}.")
f"File ended unexpectedly reading struct fmt '{fmt}'. Expected {size} bytes, got {len(data)}."
)
return struct.unpack(fmt, data)[0] return struct.unpack(fmt, data)[0]
def read_vector_raw(f, element_fmt_char): def read_vector_raw(f, element_fmt_char):
"""Reads a vector (size followed by data), returns count and raw bytes.""" """Reads a vector (size followed by data), returns count and raw bytes."""
count = -1 # Initialize count count = -1 # Initialize count
total_bytes = -1 # Initialize total_bytes total_bytes = -1 # Initialize total_bytes
try: try:
count = read_struct(f, "<Q") # size_t usually 64-bit unsigned count = read_struct(f, '<Q') # size_t usually 64-bit unsigned
element_size = struct.calcsize(element_fmt_char) element_size = struct.calcsize(element_fmt_char)
# --- FIX for MemoryError: Check for unreasonably large count --- # --- FIX for MemoryError: Check for unreasonably large count ---
max_reasonable_count = 10 * (10**9) # ~10 billion elements limit max_reasonable_count = 10 * (10**9) # ~10 billion elements limit
if count > max_reasonable_count or count < 0: if count > max_reasonable_count or count < 0:
raise MemoryError( raise MemoryError(f"Vector count {count} seems unreasonably large, possibly due to file corruption or incorrect format read.")
f"Vector count {count} seems unreasonably large, possibly due to file corruption or incorrect format read."
)
total_bytes = count * element_size total_bytes = count * element_size
# --- FIX for MemoryError: Check for huge byte size before allocation --- # --- FIX for MemoryError: Check for huge byte size before allocation ---
max_reasonable_bytes = 50 * (1024**3) # ~50 GB limit max_reasonable_bytes = 50 * (1024**3) # ~50 GB limit
if total_bytes > max_reasonable_bytes or total_bytes < 0: # Check for overflow if total_bytes > max_reasonable_bytes or total_bytes < 0: # Check for overflow
raise MemoryError( raise MemoryError(f"Attempting to read {total_bytes} bytes ({count} elements * {element_size} bytes/element), which exceeds the safety limit. File might be corrupted or format mismatch.")
f"Attempting to read {total_bytes} bytes ({count} elements * {element_size} bytes/element), which exceeds the safety limit. File might be corrupted or format mismatch."
)
data_bytes = f.read(total_bytes) data_bytes = f.read(total_bytes)
if len(data_bytes) != total_bytes: if len(data_bytes) != total_bytes:
raise EOFError( raise EOFError(f"File ended unexpectedly reading vector data. Expected {total_bytes} bytes, got {len(data_bytes)}.")
f"File ended unexpectedly reading vector data. Expected {total_bytes} bytes, got {len(data_bytes)}."
)
return count, data_bytes return count, data_bytes
except (MemoryError, OverflowError) as e: except (MemoryError, OverflowError) as e:
# Add context to the error message # Add context to the error message
print( print(f"\nError during raw vector read (element_fmt='{element_fmt_char}', count={count}, total_bytes={total_bytes}): {e}", file=sys.stderr)
f"\nError during raw vector read (element_fmt='{element_fmt_char}', count={count}, total_bytes={total_bytes}): {e}", raise e # Re-raise the original error type
file=sys.stderr,
)
raise e # Re-raise the original error type
def read_numpy_vector(f, np_dtype, struct_fmt_char): def read_numpy_vector(f, np_dtype, struct_fmt_char):
"""Reads a vector into a NumPy array.""" """Reads a vector into a NumPy array."""
count = -1 # Initialize count for robust error handling count = -1 # Initialize count for robust error handling
print( print(f" Reading vector (dtype={np_dtype}, fmt='{struct_fmt_char}')... ", end='', flush=True)
f" Reading vector (dtype={np_dtype}, fmt='{struct_fmt_char}')... ",
end="",
flush=True,
)
try: try:
count, data_bytes = read_vector_raw(f, struct_fmt_char) count, data_bytes = read_vector_raw(f, struct_fmt_char)
print(f"Count={count}, Bytes={len(data_bytes)}") print(f"Count={count}, Bytes={len(data_bytes)}")
if count > 0 and len(data_bytes) > 0: if count > 0 and len(data_bytes) > 0:
arr = np.frombuffer(data_bytes, dtype=np_dtype) arr = np.frombuffer(data_bytes, dtype=np_dtype)
if arr.size != count: if arr.size != count:
raise ValueError( raise ValueError(f"Inconsistent array size after reading. Expected {count}, got {arr.size}")
f"Inconsistent array size after reading. Expected {count}, got {arr.size}"
)
return arr return arr
elif count == 0: elif count == 0:
return np.array([], dtype=np_dtype) return np.array([], dtype=np_dtype)
else: else:
raise ValueError("Read zero bytes but count > 0.") raise ValueError("Read zero bytes but count > 0.")
except MemoryError as e: except MemoryError as e:
# Now count should be defined (or -1 if error was in read_struct) # Now count should be defined (or -1 if error was in read_struct)
print( print(f"\nMemoryError creating NumPy array (dtype={np_dtype}, count={count}). {e}", file=sys.stderr)
f"\nMemoryError creating NumPy array (dtype={np_dtype}, count={count}). {e}",
file=sys.stderr,
)
raise e raise e
except Exception as e: # Catch other potential errors like ValueError except Exception as e: # Catch other potential errors like ValueError
print( print(f"\nError reading numpy vector (dtype={np_dtype}, fmt='{struct_fmt_char}', count={count}): {e}", file=sys.stderr)
f"\nError reading numpy vector (dtype={np_dtype}, fmt='{struct_fmt_char}', count={count}): {e}",
file=sys.stderr,
)
raise e raise e
def write_numpy_vector(f, arr, struct_fmt_char): def write_numpy_vector(f, arr, struct_fmt_char):
"""Writes a NumPy array as a vector (size followed by data).""" """Writes a NumPy array as a vector (size followed by data)."""
count = arr.size count = arr.size
f.write(struct.pack("<Q", count)) f.write(struct.pack('<Q', count))
try: try:
expected_dtype = np.dtype(struct_fmt_char) expected_dtype = np.dtype(struct_fmt_char)
if arr.dtype != expected_dtype: if arr.dtype != expected_dtype:
@@ -124,30 +89,23 @@ def write_numpy_vector(f, arr, struct_fmt_char):
else: else:
data_to_write = arr.tobytes() data_to_write = arr.tobytes()
f.write(data_to_write) f.write(data_to_write)
del data_to_write # Hint GC del data_to_write # Hint GC
except MemoryError as e: except MemoryError as e:
print( print(f"\nMemoryError converting NumPy array to bytes for writing (size={count}, dtype={arr.dtype}). {e}", file=sys.stderr)
f"\nMemoryError converting NumPy array to bytes for writing (size={count}, dtype={arr.dtype}). {e}", raise e
file=sys.stderr,
)
raise e
def write_list_vector(f, lst, struct_fmt_char): def write_list_vector(f, lst, struct_fmt_char):
"""Writes a Python list as a vector iteratively.""" """Writes a Python list as a vector iteratively."""
count = len(lst) count = len(lst)
f.write(struct.pack("<Q", count)) f.write(struct.pack('<Q', count))
fmt = "<" + struct_fmt_char fmt = '<' + struct_fmt_char
chunk_size = 1024 * 1024 chunk_size = 1024 * 1024
element_size = struct.calcsize(fmt) element_size = struct.calcsize(fmt)
# Allocate buffer outside the loop if possible, or handle MemoryError during allocation # Allocate buffer outside the loop if possible, or handle MemoryError during allocation
try: try:
buffer = bytearray(chunk_size * element_size) buffer = bytearray(chunk_size * element_size)
except MemoryError: except MemoryError:
print( print(f"MemoryError: Cannot allocate buffer for writing list vector chunk (size {chunk_size * element_size} bytes).", file=sys.stderr)
f"MemoryError: Cannot allocate buffer for writing list vector chunk (size {chunk_size * element_size} bytes).",
file=sys.stderr,
)
raise raise
buffer_count = 0 buffer_count = 0
@@ -158,80 +116,66 @@ def write_list_vector(f, lst, struct_fmt_char):
buffer_count += 1 buffer_count += 1
if buffer_count == chunk_size or i == count - 1: if buffer_count == chunk_size or i == count - 1:
f.write(buffer[: buffer_count * element_size]) f.write(buffer[:buffer_count * element_size])
buffer_count = 0 buffer_count = 0
except struct.error as e: except struct.error as e:
print( print(f"\nStruct packing error for item {item} at index {i} with format '{fmt}'. {e}", file=sys.stderr)
f"\nStruct packing error for item {item} at index {i} with format '{fmt}'. {e}",
file=sys.stderr,
)
raise e raise e
def get_cum_neighbors(cum_nneighbor_per_level_np, level): def get_cum_neighbors(cum_nneighbor_per_level_np, level):
"""Helper to get cumulative neighbors count, matching C++ logic.""" """Helper to get cumulative neighbors count, matching C++ logic."""
if level < 0: if level < 0: return 0
return 0
if level < len(cum_nneighbor_per_level_np): if level < len(cum_nneighbor_per_level_np):
return cum_nneighbor_per_level_np[level] return cum_nneighbor_per_level_np[level]
else: else:
return cum_nneighbor_per_level_np[-1] if len(cum_nneighbor_per_level_np) > 0 else 0 return cum_nneighbor_per_level_np[-1] if len(cum_nneighbor_per_level_np) > 0 else 0
def write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np,
def write_compact_format( levels_np, compact_level_ptr, compact_node_offsets_np,
f_out, compact_neighbors_data, storage_fourcc, storage_data):
original_hnsw_data,
assign_probas_np,
cum_nneighbor_per_level_np,
levels_np,
compact_level_ptr,
compact_node_offsets_np,
compact_neighbors_data,
storage_fourcc,
storage_data,
):
"""Write HNSW data in compact format following C++ read order exactly.""" """Write HNSW data in compact format following C++ read order exactly."""
# Write IndexHNSW Header # Write IndexHNSW Header
f_out.write(struct.pack("<I", original_hnsw_data["index_fourcc"])) f_out.write(struct.pack('<I', original_hnsw_data['index_fourcc']))
f_out.write(struct.pack("<i", original_hnsw_data["d"])) f_out.write(struct.pack('<i', original_hnsw_data['d']))
f_out.write(struct.pack("<q", original_hnsw_data["ntotal"])) f_out.write(struct.pack('<q', original_hnsw_data['ntotal']))
f_out.write(struct.pack("<q", original_hnsw_data["dummy1"])) f_out.write(struct.pack('<q', original_hnsw_data['dummy1']))
f_out.write(struct.pack("<q", original_hnsw_data["dummy2"])) f_out.write(struct.pack('<q', original_hnsw_data['dummy2']))
f_out.write(struct.pack("<?", original_hnsw_data["is_trained"])) f_out.write(struct.pack('<?', original_hnsw_data['is_trained']))
f_out.write(struct.pack("<i", original_hnsw_data["metric_type"])) f_out.write(struct.pack('<i', original_hnsw_data['metric_type']))
if original_hnsw_data["metric_type"] > 1: if original_hnsw_data['metric_type'] > 1:
f_out.write(struct.pack("<f", original_hnsw_data["metric_arg"])) f_out.write(struct.pack('<f', original_hnsw_data['metric_arg']))
# Write HNSW struct parts (standard order) # Write HNSW struct parts (standard order)
write_numpy_vector(f_out, assign_probas_np, "d") write_numpy_vector(f_out, assign_probas_np, 'd')
write_numpy_vector(f_out, cum_nneighbor_per_level_np, "i") write_numpy_vector(f_out, cum_nneighbor_per_level_np, 'i')
write_numpy_vector(f_out, levels_np, "i") write_numpy_vector(f_out, levels_np, 'i')
# Write compact format flag # Write compact format flag
f_out.write(struct.pack("<?", True)) # storage_is_compact = True f_out.write(struct.pack('<?', True)) # storage_is_compact = True
# Write compact data in CORRECT C++ read order: level_ptr, node_offsets FIRST # Write compact data in CORRECT C++ read order: level_ptr, node_offsets FIRST
if isinstance(compact_level_ptr, np.ndarray): if isinstance(compact_level_ptr, np.ndarray):
write_numpy_vector(f_out, compact_level_ptr, "Q") write_numpy_vector(f_out, compact_level_ptr, 'Q')
else: else:
write_list_vector(f_out, compact_level_ptr, "Q") write_list_vector(f_out, compact_level_ptr, 'Q')
write_numpy_vector(f_out, compact_node_offsets_np, "Q") write_numpy_vector(f_out, compact_node_offsets_np, 'Q')
# Write HNSW scalar parameters # Write HNSW scalar parameters
f_out.write(struct.pack("<i", original_hnsw_data["entry_point"])) f_out.write(struct.pack('<i', original_hnsw_data['entry_point']))
f_out.write(struct.pack("<i", original_hnsw_data["max_level"])) f_out.write(struct.pack('<i', original_hnsw_data['max_level']))
f_out.write(struct.pack("<i", original_hnsw_data["efConstruction"])) f_out.write(struct.pack('<i', original_hnsw_data['efConstruction']))
f_out.write(struct.pack("<i", original_hnsw_data["efSearch"])) f_out.write(struct.pack('<i', original_hnsw_data['efSearch']))
f_out.write(struct.pack("<i", original_hnsw_data["dummy_upper_beam"])) f_out.write(struct.pack('<i', original_hnsw_data['dummy_upper_beam']))
# Write storage fourcc (this determines how to read what follows) # Write storage fourcc (this determines how to read what follows)
f_out.write(struct.pack("<I", storage_fourcc)) f_out.write(struct.pack('<I', storage_fourcc))
# Write compact neighbors data AFTER storage fourcc # Write compact neighbors data AFTER storage fourcc
write_list_vector(f_out, compact_neighbors_data, "i") write_list_vector(f_out, compact_neighbors_data, 'i')
# Write storage data if not NULL (only after neighbors) # Write storage data if not NULL (only after neighbors)
if storage_fourcc != NULL_INDEX_FOURCC and storage_data: if storage_fourcc != NULL_INDEX_FOURCC and storage_data:
f_out.write(storage_data) f_out.write(storage_data)
@@ -239,248 +183,185 @@ def write_compact_format(
# --- Main Conversion Logic --- # --- Main Conversion Logic ---
def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=True): def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=True):
""" """
Converts an HNSW graph file to the CSR format. Converts an HNSW graph file to the CSR format.
Supports both original and already-compact formats (backward compatibility). Supports both original and already-compact formats (backward compatibility).
Args: Args:
input_filename: Input HNSW index file input_filename: Input HNSW index file
output_filename: Output CSR index file output_filename: Output CSR index file
prune_embeddings: Whether to prune embedding storage (write NULL storage marker) prune_embeddings: Whether to prune embedding storage (write NULL storage marker)
""" """
# Disable buffering for print statements to avoid deadlock in CI/pytest
import functools
global print
print = functools.partial(print, flush=True)
print(f"Starting conversion: {input_filename} -> {output_filename}") print(f"Starting conversion: {input_filename} -> {output_filename}")
start_time = time.time() start_time = time.time()
original_hnsw_data = {} original_hnsw_data = {}
neighbors_np = None # Initialize to allow check in finally block neighbors_np = None # Initialize to allow check in finally block
try: try:
with open(input_filename, "rb") as f_in, open(output_filename, "wb") as f_out: with open(input_filename, 'rb') as f_in, open(output_filename, 'wb') as f_out:
# --- Read IndexHNSW FourCC and Header --- # --- Read IndexHNSW FourCC and Header ---
print(f"[{time.time() - start_time:.2f}s] Reading Index HNSW header...") print(f"[{time.time() - start_time:.2f}s] Reading Index HNSW header...")
# ... (Keep the header reading logic as before) ... # ... (Keep the header reading logic as before) ...
hnsw_index_fourcc = read_struct(f_in, "<I") hnsw_index_fourcc = read_struct(f_in, '<I')
if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS: if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS:
print( print(f"Error: Expected HNSW Index FourCC ({list(EXPECTED_HNSW_FOURCCS)}), got {hnsw_index_fourcc:08x}.", file=sys.stderr)
f"Error: Expected HNSW Index FourCC ({list(EXPECTED_HNSW_FOURCCS)}), got {hnsw_index_fourcc:08x}.", return False
file=sys.stderr, original_hnsw_data['index_fourcc'] = hnsw_index_fourcc
) original_hnsw_data['d'] = read_struct(f_in, '<i')
return False original_hnsw_data['ntotal'] = read_struct(f_in, '<q')
original_hnsw_data["index_fourcc"] = hnsw_index_fourcc original_hnsw_data['dummy1'] = read_struct(f_in, '<q')
original_hnsw_data["d"] = read_struct(f_in, "<i") original_hnsw_data['dummy2'] = read_struct(f_in, '<q')
original_hnsw_data["ntotal"] = read_struct(f_in, "<q") original_hnsw_data['is_trained'] = read_struct(f_in, '?')
original_hnsw_data["dummy1"] = read_struct(f_in, "<q") original_hnsw_data['metric_type'] = read_struct(f_in, '<i')
original_hnsw_data["dummy2"] = read_struct(f_in, "<q") original_hnsw_data['metric_arg'] = 0.0
original_hnsw_data["is_trained"] = read_struct(f_in, "?") if original_hnsw_data['metric_type'] > 1:
original_hnsw_data["metric_type"] = read_struct(f_in, "<i") original_hnsw_data['metric_arg'] = read_struct(f_in, '<f')
original_hnsw_data["metric_arg"] = 0.0 print(f"[{time.time() - start_time:.2f}s] Header read: d={original_hnsw_data['d']}, ntotal={original_hnsw_data['ntotal']}")
if original_hnsw_data["metric_type"] > 1:
original_hnsw_data["metric_arg"] = read_struct(f_in, "<f")
print(
f"[{time.time() - start_time:.2f}s] Header read: d={original_hnsw_data['d']}, ntotal={original_hnsw_data['ntotal']}"
)
# --- Read original HNSW struct data --- # --- Read original HNSW struct data ---
print(f"[{time.time() - start_time:.2f}s] Reading HNSW struct vectors...") print(f"[{time.time() - start_time:.2f}s] Reading HNSW struct vectors...")
assign_probas_np = read_numpy_vector(f_in, np.float64, "d") assign_probas_np = read_numpy_vector(f_in, np.float64, 'd')
print( print(f"[{time.time() - start_time:.2f}s] Read assign_probas ({assign_probas_np.size})")
f"[{time.time() - start_time:.2f}s] Read assign_probas ({assign_probas_np.size})"
)
gc.collect() gc.collect()
cum_nneighbor_per_level_np = read_numpy_vector(f_in, np.int32, "i") cum_nneighbor_per_level_np = read_numpy_vector(f_in, np.int32, 'i')
print( print(f"[{time.time() - start_time:.2f}s] Read cum_nneighbor_per_level ({cum_nneighbor_per_level_np.size})")
f"[{time.time() - start_time:.2f}s] Read cum_nneighbor_per_level ({cum_nneighbor_per_level_np.size})"
)
gc.collect() gc.collect()
levels_np = read_numpy_vector(f_in, np.int32, "i") levels_np = read_numpy_vector(f_in, np.int32, 'i')
print(f"[{time.time() - start_time:.2f}s] Read levels ({levels_np.size})") print(f"[{time.time() - start_time:.2f}s] Read levels ({levels_np.size})")
gc.collect() gc.collect()
ntotal = len(levels_np) ntotal = len(levels_np)
if ntotal != original_hnsw_data["ntotal"]: if ntotal != original_hnsw_data['ntotal']:
print( print(f"Warning: ntotal mismatch! Header says {original_hnsw_data['ntotal']}, levels vector size is {ntotal}. Using levels vector size.", file=sys.stderr)
f"Warning: ntotal mismatch! Header says {original_hnsw_data['ntotal']}, levels vector size is {ntotal}. Using levels vector size.", original_hnsw_data['ntotal'] = ntotal
file=sys.stderr,
)
original_hnsw_data["ntotal"] = ntotal
# --- Check for compact format flag --- # --- Check for compact format flag ---
print(f"[{time.time() - start_time:.2f}s] Probing for compact storage flag...") print(f"[{time.time() - start_time:.2f}s] Probing for compact storage flag...")
pos_before_compact = f_in.tell() pos_before_compact = f_in.tell()
try: try:
is_compact_flag = read_struct(f_in, "<?") is_compact_flag = read_struct(f_in, '<?')
print(f"[{time.time() - start_time:.2f}s] Found compact flag: {is_compact_flag}") print(f"[{time.time() - start_time:.2f}s] Found compact flag: {is_compact_flag}")
if is_compact_flag: if is_compact_flag:
# Input is already in compact format - read compact data # Input is already in compact format - read compact data
print( print(f"[{time.time() - start_time:.2f}s] Input is already in compact format, reading compact data...")
f"[{time.time() - start_time:.2f}s] Input is already in compact format, reading compact data..."
) compact_level_ptr = read_numpy_vector(f_in, np.uint64, 'Q')
print(f"[{time.time() - start_time:.2f}s] Read compact_level_ptr ({compact_level_ptr.size})")
compact_level_ptr = read_numpy_vector(f_in, np.uint64, "Q")
print( compact_node_offsets_np = read_numpy_vector(f_in, np.uint64, 'Q')
f"[{time.time() - start_time:.2f}s] Read compact_level_ptr ({compact_level_ptr.size})" print(f"[{time.time() - start_time:.2f}s] Read compact_node_offsets ({compact_node_offsets_np.size})")
)
compact_node_offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
print(
f"[{time.time() - start_time:.2f}s] Read compact_node_offsets ({compact_node_offsets_np.size})"
)
# Read scalar parameters # Read scalar parameters
original_hnsw_data["entry_point"] = read_struct(f_in, "<i") original_hnsw_data['entry_point'] = read_struct(f_in, '<i')
original_hnsw_data["max_level"] = read_struct(f_in, "<i") original_hnsw_data['max_level'] = read_struct(f_in, '<i')
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i") original_hnsw_data['efConstruction'] = read_struct(f_in, '<i')
original_hnsw_data["efSearch"] = read_struct(f_in, "<i") original_hnsw_data['efSearch'] = read_struct(f_in, '<i')
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i") original_hnsw_data['dummy_upper_beam'] = read_struct(f_in, '<i')
print( print(f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})")
f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})"
)
# Read storage fourcc # Read storage fourcc
storage_fourcc = read_struct(f_in, "<I") storage_fourcc = read_struct(f_in, '<I')
print( print(f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}")
f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}"
)
if prune_embeddings and storage_fourcc != NULL_INDEX_FOURCC: if prune_embeddings and storage_fourcc != NULL_INDEX_FOURCC:
# Read compact neighbors data # Read compact neighbors data
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i") compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, 'i')
print( print(f"[{time.time() - start_time:.2f}s] Read compact neighbors data ({compact_neighbors_data_np.size})")
f"[{time.time() - start_time:.2f}s] Read compact neighbors data ({compact_neighbors_data_np.size})"
)
compact_neighbors_data = compact_neighbors_data_np.tolist() compact_neighbors_data = compact_neighbors_data_np.tolist()
del compact_neighbors_data_np del compact_neighbors_data_np
# Skip storage data and write with NULL marker # Skip storage data and write with NULL marker
print( print(f"[{time.time() - start_time:.2f}s] Pruning embeddings: Writing NULL storage marker.")
f"[{time.time() - start_time:.2f}s] Pruning embeddings: Writing NULL storage marker."
)
storage_fourcc = NULL_INDEX_FOURCC storage_fourcc = NULL_INDEX_FOURCC
elif not prune_embeddings: elif not prune_embeddings:
# Read and preserve compact neighbors and storage # Read and preserve compact neighbors and storage
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i") compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, 'i')
compact_neighbors_data = compact_neighbors_data_np.tolist() compact_neighbors_data = compact_neighbors_data_np.tolist()
del compact_neighbors_data_np del compact_neighbors_data_np
# Read remaining storage data # Read remaining storage data
storage_data = f_in.read() storage_data = f_in.read()
else: else:
# Already pruned (NULL storage) # Already pruned (NULL storage)
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i") compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, 'i')
compact_neighbors_data = compact_neighbors_data_np.tolist() compact_neighbors_data = compact_neighbors_data_np.tolist()
del compact_neighbors_data_np del compact_neighbors_data_np
storage_data = b"" storage_data = b''
# Write the updated compact format # Write the updated compact format
print(f"[{time.time() - start_time:.2f}s] Writing updated compact format...") print(f"[{time.time() - start_time:.2f}s] Writing updated compact format...")
write_compact_format( write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np,
f_out, levels_np, compact_level_ptr, compact_node_offsets_np,
original_hnsw_data, compact_neighbors_data, storage_fourcc, storage_data if not prune_embeddings else b'')
assign_probas_np,
cum_nneighbor_per_level_np,
levels_np,
compact_level_ptr,
compact_node_offsets_np,
compact_neighbors_data,
storage_fourcc,
storage_data if not prune_embeddings else b"",
)
print(f"[{time.time() - start_time:.2f}s] Conversion complete.") print(f"[{time.time() - start_time:.2f}s] Conversion complete.")
return True return True
else: else:
# is_compact=False, rewind and read original format # is_compact=False, rewind and read original format
f_in.seek(pos_before_compact) f_in.seek(pos_before_compact)
print( print(f"[{time.time() - start_time:.2f}s] Compact flag is False, reading original format...")
f"[{time.time() - start_time:.2f}s] Compact flag is False, reading original format..."
)
except EOFError: except EOFError:
# No compact flag found, assume original format # No compact flag found, assume original format
f_in.seek(pos_before_compact) f_in.seek(pos_before_compact)
print( print(f"[{time.time() - start_time:.2f}s] No compact flag found, assuming original format...")
f"[{time.time() - start_time:.2f}s] No compact flag found, assuming original format..."
)
# --- Handle potential extra byte in original format (like C++ code) --- # --- Handle potential extra byte in original format (like C++ code) ---
print( print(f"[{time.time() - start_time:.2f}s] Probing for potential extra byte before non-compact offsets...")
f"[{time.time() - start_time:.2f}s] Probing for potential extra byte before non-compact offsets..."
)
pos_before_probe = f_in.tell() pos_before_probe = f_in.tell()
try: try:
suspected_flag = read_struct(f_in, "<B") # Read 1 byte suspected_flag = read_struct(f_in, '<B') # Read 1 byte
if suspected_flag == 0x00: if suspected_flag == 0x00:
print( print(f"[{time.time() - start_time:.2f}s] Found and consumed an unexpected 0x00 byte.")
f"[{time.time() - start_time:.2f}s] Found and consumed an unexpected 0x00 byte."
)
elif suspected_flag == 0x01: elif suspected_flag == 0x01:
print( print(f"[{time.time() - start_time:.2f}s] ERROR: Found 0x01 but is_compact should be False")
f"[{time.time() - start_time:.2f}s] ERROR: Found 0x01 but is_compact should be False"
)
raise ValueError("Inconsistent compact flag state") raise ValueError("Inconsistent compact flag state")
else: else:
# Rewind - this byte is part of offsets data # Rewind - this byte is part of offsets data
f_in.seek(pos_before_probe) f_in.seek(pos_before_probe)
print( print(f"[{time.time() - start_time:.2f}s] Rewound to original position (byte was 0x{suspected_flag:02x})")
f"[{time.time() - start_time:.2f}s] Rewound to original position (byte was 0x{suspected_flag:02x})"
)
except EOFError: except EOFError:
f_in.seek(pos_before_probe) f_in.seek(pos_before_probe)
print( print(f"[{time.time() - start_time:.2f}s] No extra byte found (EOF), proceeding with offsets read")
f"[{time.time() - start_time:.2f}s] No extra byte found (EOF), proceeding with offsets read"
)
# --- Read original format data --- # --- Read original format data ---
offsets_np = read_numpy_vector(f_in, np.uint64, "Q") offsets_np = read_numpy_vector(f_in, np.uint64, 'Q')
print(f"[{time.time() - start_time:.2f}s] Read offsets ({offsets_np.size})") print(f"[{time.time() - start_time:.2f}s] Read offsets ({offsets_np.size})")
if len(offsets_np) != ntotal + 1: if len(offsets_np) != ntotal + 1:
raise ValueError( raise ValueError(f"Inconsistent offsets size: len(levels)={ntotal} but len(offsets)={len(offsets_np)}")
f"Inconsistent offsets size: len(levels)={ntotal} but len(offsets)={len(offsets_np)}"
)
gc.collect() gc.collect()
print(f"[{time.time() - start_time:.2f}s] Attempting to read neighbors vector...") print(f"[{time.time() - start_time:.2f}s] Attempting to read neighbors vector...")
neighbors_np = read_numpy_vector(f_in, np.int32, "i") neighbors_np = read_numpy_vector(f_in, np.int32, 'i')
print(f"[{time.time() - start_time:.2f}s] Read neighbors ({neighbors_np.size})") print(f"[{time.time() - start_time:.2f}s] Read neighbors ({neighbors_np.size})")
expected_neighbors_size = offsets_np[-1] if ntotal > 0 else 0 expected_neighbors_size = offsets_np[-1] if ntotal > 0 else 0
if neighbors_np.size != expected_neighbors_size: if neighbors_np.size != expected_neighbors_size:
print( print(f"Warning: neighbors vector size mismatch. Expected {expected_neighbors_size} based on offsets, got {neighbors_np.size}.")
f"Warning: neighbors vector size mismatch. Expected {expected_neighbors_size} based on offsets, got {neighbors_np.size}."
)
gc.collect() gc.collect()
original_hnsw_data["entry_point"] = read_struct(f_in, "<i") original_hnsw_data['entry_point'] = read_struct(f_in, '<i')
original_hnsw_data["max_level"] = read_struct(f_in, "<i") original_hnsw_data['max_level'] = read_struct(f_in, '<i')
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i") original_hnsw_data['efConstruction'] = read_struct(f_in, '<i')
original_hnsw_data["efSearch"] = read_struct(f_in, "<i") original_hnsw_data['efSearch'] = read_struct(f_in, '<i')
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i") original_hnsw_data['dummy_upper_beam'] = read_struct(f_in, '<i')
print( print(f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})")
f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})"
)
print(f"[{time.time() - start_time:.2f}s] Checking for storage data...") print(f"[{time.time() - start_time:.2f}s] Checking for storage data...")
storage_fourcc = None storage_fourcc = None
try: try:
storage_fourcc = read_struct(f_in, "<I") storage_fourcc = read_struct(f_in, '<I')
print( print(f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}.")
f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}."
)
except EOFError: except EOFError:
print(f"[{time.time() - start_time:.2f}s] No storage data found (EOF).") print(f"[{time.time() - start_time:.2f}s] No storage data found (EOF).")
except Exception as e: except Exception as e:
print( print(f"[{time.time() - start_time:.2f}s] Error reading potential storage data: {e}")
f"[{time.time() - start_time:.2f}s] Error reading potential storage data: {e}"
)
# --- Perform Conversion --- # --- Perform Conversion ---
print(f"[{time.time() - start_time:.2f}s] Converting to CSR format...") print(f"[{time.time() - start_time:.2f}s] Converting to CSR format...")
@@ -492,21 +373,17 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
current_level_ptr_idx = 0 current_level_ptr_idx = 0
current_data_idx = 0 current_data_idx = 0
total_valid_neighbors_counted = 0 # For validation total_valid_neighbors_counted = 0 # For validation
# Optimize calculation by getting slices once per node if possible # Optimize calculation by getting slices once per node if possible
for i in range(ntotal): for i in range(ntotal):
if i > 0 and i % (ntotal // 100 or 1) == 0: # Log progress roughly every 1% if i > 0 and i % (ntotal // 100 or 1) == 0: # Log progress roughly every 1%
progress = (i / ntotal) * 100 progress = (i / ntotal) * 100
elapsed = time.time() - start_time elapsed = time.time() - start_time
print( print(f"\r[{elapsed:.2f}s] Converting node {i}/{ntotal} ({progress:.1f}%)...", end="")
f"\r[{elapsed:.2f}s] Converting node {i}/{ntotal} ({progress:.1f}%)...",
end="",
)
node_max_level = levels_np[i] - 1 node_max_level = levels_np[i] - 1
if node_max_level < -1: if node_max_level < -1: node_max_level = -1
node_max_level = -1
node_ptr_start_index = current_level_ptr_idx node_ptr_start_index = current_level_ptr_idx
compact_node_offsets_np[i] = node_ptr_start_index compact_node_offsets_np[i] = node_ptr_start_index
@@ -517,17 +394,13 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
for level in range(node_max_level + 1): for level in range(node_max_level + 1):
compact_level_ptr.append(current_data_idx) compact_level_ptr.append(current_data_idx)
begin_orig_np = original_offset_start + get_cum_neighbors( begin_orig_np = original_offset_start + get_cum_neighbors(cum_nneighbor_per_level_np, level)
cum_nneighbor_per_level_np, level end_orig_np = original_offset_start + get_cum_neighbors(cum_nneighbor_per_level_np, level + 1)
)
end_orig_np = original_offset_start + get_cum_neighbors(
cum_nneighbor_per_level_np, level + 1
)
begin_orig = int(begin_orig_np) begin_orig = int(begin_orig_np)
end_orig = int(end_orig_np) end_orig = int(end_orig_np)
neighbors_len = len(neighbors_np) # Cache length neighbors_len = len(neighbors_np) # Cache length
begin_orig = min(max(0, begin_orig), neighbors_len) begin_orig = min(max(0, begin_orig), neighbors_len)
end_orig = min(max(begin_orig, end_orig), neighbors_len) end_orig = min(max(begin_orig, end_orig), neighbors_len)
@@ -540,117 +413,83 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
if num_valid > 0: if num_valid > 0:
# Append valid neighbors # Append valid neighbors
compact_neighbors_data.extend( compact_neighbors_data.extend(level_neighbors_slice[valid_neighbors_mask])
level_neighbors_slice[valid_neighbors_mask]
)
current_data_idx += num_valid current_data_idx += num_valid
total_valid_neighbors_counted += num_valid total_valid_neighbors_counted += num_valid
compact_level_ptr.append(current_data_idx) compact_level_ptr.append(current_data_idx)
current_level_ptr_idx += num_pointers_expected current_level_ptr_idx += num_pointers_expected
compact_node_offsets_np[ntotal] = current_level_ptr_idx compact_node_offsets_np[ntotal] = current_level_ptr_idx
print( print(f"\r[{time.time() - start_time:.2f}s] Conversion loop finished. ") # Clear progress line
f"\r[{time.time() - start_time:.2f}s] Conversion loop finished. "
) # Clear progress line
# --- Validation Checks --- # --- Validation Checks ---
print(f"[{time.time() - start_time:.2f}s] Running validation checks...") print(f"[{time.time() - start_time:.2f}s] Running validation checks...")
valid_check_passed = True valid_check_passed = True
# Check 1: Total valid neighbors count # Check 1: Total valid neighbors count
print(" Checking total valid neighbor count...") print(f" Checking total valid neighbor count...")
expected_valid_count = np.sum(neighbors_np >= 0) expected_valid_count = np.sum(neighbors_np >= 0)
if total_valid_neighbors_counted != len(compact_neighbors_data): if total_valid_neighbors_counted != len(compact_neighbors_data):
print( print(f"Error: Mismatch between counted valid neighbors ({total_valid_neighbors_counted}) and final compact_data size ({len(compact_neighbors_data)})!", file=sys.stderr)
f"Error: Mismatch between counted valid neighbors ({total_valid_neighbors_counted}) and final compact_data size ({len(compact_neighbors_data)})!", valid_check_passed = False
file=sys.stderr,
)
valid_check_passed = False
if expected_valid_count != len(compact_neighbors_data): if expected_valid_count != len(compact_neighbors_data):
print( print(f"Error: Mismatch between NumPy count of valid neighbors ({expected_valid_count}) and final compact_data size ({len(compact_neighbors_data)})!", file=sys.stderr)
f"Error: Mismatch between NumPy count of valid neighbors ({expected_valid_count}) and final compact_data size ({len(compact_neighbors_data)})!", valid_check_passed = False
file=sys.stderr,
)
valid_check_passed = False
else: else:
print(f" OK: Total valid neighbors = {len(compact_neighbors_data)}") print(f" OK: Total valid neighbors = {len(compact_neighbors_data)}")
# Check 2: Final pointer indices consistency # Check 2: Final pointer indices consistency
print(" Checking final pointer indices...") print(f" Checking final pointer indices...")
if compact_node_offsets_np[ntotal] != len(compact_level_ptr): if compact_node_offsets_np[ntotal] != len(compact_level_ptr):
print( print(f"Error: Final node offset ({compact_node_offsets_np[ntotal]}) doesn't match level_ptr size ({len(compact_level_ptr)})!", file=sys.stderr)
f"Error: Final node offset ({compact_node_offsets_np[ntotal]}) doesn't match level_ptr size ({len(compact_level_ptr)})!", valid_check_passed = False
file=sys.stderr, if (len(compact_level_ptr) > 0 and compact_level_ptr[-1] != len(compact_neighbors_data)) or \
) (len(compact_level_ptr) == 0 and len(compact_neighbors_data) != 0):
valid_check_passed = False last_ptr = compact_level_ptr[-1] if len(compact_level_ptr) > 0 else -1
if ( print(f"Error: Last level pointer ({last_ptr}) doesn't match compact_data size ({len(compact_neighbors_data)})!", file=sys.stderr)
len(compact_level_ptr) > 0 and compact_level_ptr[-1] != len(compact_neighbors_data) valid_check_passed = False
) or (len(compact_level_ptr) == 0 and len(compact_neighbors_data) != 0):
last_ptr = compact_level_ptr[-1] if len(compact_level_ptr) > 0 else -1
print(
f"Error: Last level pointer ({last_ptr}) doesn't match compact_data size ({len(compact_neighbors_data)})!",
file=sys.stderr,
)
valid_check_passed = False
else: else:
print(" OK: Final pointers match data size.") print(f" OK: Final pointers match data size.")
if not valid_check_passed: if not valid_check_passed:
print( print("Error: Validation checks failed. Output file might be incorrect.", file=sys.stderr)
"Error: Validation checks failed. Output file might be incorrect.",
file=sys.stderr,
)
# Optional: Exit here if validation fails # Optional: Exit here if validation fails
# return False # return False
# --- Explicitly delete large intermediate arrays ---
-print(
-    f"[{time.time() - start_time:.2f}s] Deleting original neighbors and offsets arrays..."
-)
+print(f"[{time.time() - start_time:.2f}s] Deleting original neighbors and offsets arrays...")
del neighbors_np
del offsets_np
gc.collect()
-print(
-    f"  CSR Stats: |data|={len(compact_neighbors_data)}, |level_ptr|={len(compact_level_ptr)}"
-)
+print(f"  CSR Stats: |data|={len(compact_neighbors_data)}, |level_ptr|={len(compact_level_ptr)}")
# --- Write CSR HNSW graph data using unified function ---
-print(
-    f"[{time.time() - start_time:.2f}s] Writing CSR HNSW graph data in FAISS-compatible order..."
-)
+print(f"[{time.time() - start_time:.2f}s] Writing CSR HNSW graph data in FAISS-compatible order...")
# Determine storage fourcc and data based on prune_embeddings
if prune_embeddings:
-    print(" Pruning embeddings: Writing NULL storage marker.")
+    print(f" Pruning embeddings: Writing NULL storage marker.")
    output_storage_fourcc = NULL_INDEX_FOURCC
-    storage_data = b""
+    storage_data = b''
else:
    # Keep embeddings - read and preserve original storage data
    if storage_fourcc and storage_fourcc != NULL_INDEX_FOURCC:
-        print(" Preserving embeddings: Reading original storage data...")
+        print(f" Preserving embeddings: Reading original storage data...")
        storage_data = f_in.read()  # Read remaining storage data
        output_storage_fourcc = storage_fourcc
        print(f"  Read {len(storage_data)} bytes of storage data")
    else:
-        print(" No embeddings found in original file (NULL storage)")
+        print(f" No embeddings found in original file (NULL storage)")
        output_storage_fourcc = NULL_INDEX_FOURCC
-        storage_data = b""
+        storage_data = b''
# Use the unified write function
-write_compact_format(
-    f_out,
-    original_hnsw_data,
-    assign_probas_np,
-    cum_nneighbor_per_level_np,
-    levels_np,
-    compact_level_ptr,
-    compact_node_offsets_np,
-    compact_neighbors_data,
-    output_storage_fourcc,
-    storage_data,
-)
+write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np,
+                     levels_np, compact_level_ptr, compact_node_offsets_np,
+                     compact_neighbors_data, output_storage_fourcc, storage_data)
# Clean up memory
del assign_probas_np, cum_nneighbor_per_level_np, levels_np
del compact_neighbors_data, compact_level_ptr, compact_node_offsets_np
@@ -664,66 +503,40 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
print(f"Error: Input file not found: {input_filename}", file=sys.stderr) print(f"Error: Input file not found: {input_filename}", file=sys.stderr)
return False return False
except MemoryError as e: except MemoryError as e:
print( print(f"\nFatal MemoryError during conversion: {e}. Insufficient RAM.", file=sys.stderr)
f"\nFatal MemoryError during conversion: {e}. Insufficient RAM.", # Clean up potentially partially written output file?
file=sys.stderr, try: os.remove(output_filename)
) except OSError: pass
# Clean up potentially partially written output file? return False
try:
os.remove(output_filename)
except OSError:
pass
return False
except EOFError as e: except EOFError as e:
print( print(f"Error: Reached end of file unexpectedly reading {input_filename}. {e}", file=sys.stderr)
f"Error: Reached end of file unexpectedly reading {input_filename}. {e}", try: os.remove(output_filename)
file=sys.stderr, except OSError: pass
)
try:
os.remove(output_filename)
except OSError:
pass
return False return False
except Exception as e: except Exception as e:
print(f"An unexpected error occurred during conversion: {e}", file=sys.stderr) print(f"An unexpected error occurred during conversion: {e}", file=sys.stderr)
import traceback import traceback
traceback.print_exc() traceback.print_exc()
try: try:
os.remove(output_filename) os.remove(output_filename)
except OSError: except OSError: pass
pass
return False return False
# Ensure neighbors_np is deleted even if an error occurs after its allocation # Ensure neighbors_np is deleted even if an error occurs after its allocation
finally: finally:
try: if 'neighbors_np' in locals() and neighbors_np is not None:
if "neighbors_np" in locals() and neighbors_np is not None: del neighbors_np
del neighbors_np gc.collect()
gc.collect()
except NameError:
pass
# --- Script Execution ---
if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description="Convert a Faiss IndexHNSWFlat file to a CSR-based HNSW graph file."
-    )
+    parser = argparse.ArgumentParser(description="Convert a Faiss IndexHNSWFlat file to a CSR-based HNSW graph file.")
    parser.add_argument("input_index_file", help="Path to the input IndexHNSWFlat file")
-    parser.add_argument(
-        "output_csr_graph_file", help="Path to write the output CSR HNSW graph file"
-    )
-    parser.add_argument(
-        "--prune-embeddings",
-        action="store_true",
-        default=True,
-        help="Prune embedding storage (write NULL storage marker)",
-    )
-    parser.add_argument(
-        "--keep-embeddings",
-        action="store_true",
-        help="Keep embedding storage (overrides --prune-embeddings)",
-    )
+    parser.add_argument("output_csr_graph_file", help="Path to write the output CSR HNSW graph file")
+    parser.add_argument("--prune-embeddings", action="store_true", default=True,
+                        help="Prune embedding storage (write NULL storage marker)")
+    parser.add_argument("--keep-embeddings", action="store_true",
+                        help="Keep embedding storage (overrides --prune-embeddings)")
    args = parser.parse_args()
@@ -732,12 +545,10 @@ if __name__ == "__main__":
        sys.exit(1)
    if os.path.abspath(args.input_index_file) == os.path.abspath(args.output_csr_graph_file):
-        print("Error: Input and output filenames cannot be the same.", file=sys.stderr)
+        print(f"Error: Input and output filenames cannot be the same.", file=sys.stderr)
        sys.exit(1)
    prune_embeddings = args.prune_embeddings and not args.keep_embeddings
-    success = convert_hnsw_graph_to_csr(
-        args.input_index_file, args.output_csr_graph_file, prune_embeddings
-    )
+    success = convert_hnsw_graph_to_csr(args.input_index_file, args.output_csr_graph_file, prune_embeddings)
    if not success:
        sys.exit(1)
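For reference, a minimal sketch (not part of this diff) of driving the converter from Python instead of the CLI. It only uses the convert_hnsw_graph_to_csr signature shown above; the module path leann_backend_hnsw.convert_to_csr and the file names are assumptions/placeholders.

from leann_backend_hnsw.convert_to_csr import convert_hnsw_graph_to_csr

ok = convert_hnsw_graph_to_csr(
    "demo.index",            # hypothetical input: a Faiss IndexHNSWFlat file
    "demo.csr.index",        # hypothetical output: CSR-based HNSW graph file
    prune_embeddings=True,   # mirror the CLI default: write a NULL storage marker
)
if not ok:
    raise SystemExit("CSR conversion failed")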

View File

@@ -1,19 +1,19 @@
-import logging
-import os
-import shutil
-from pathlib import Path
-from typing import Any, Literal, Optional
import numpy as np
+import os
+from pathlib import Path
+from typing import Dict, Any, List, Literal, Optional
+import shutil
+import logging
+from leann.searcher_base import BaseSearcher
+from .convert_to_csr import convert_hnsw_graph_to_csr
+from leann.registry import register_backend
from leann.interface import (
-    LeannBackendBuilderInterface,
    LeannBackendFactoryInterface,
+    LeannBackendBuilderInterface,
    LeannBackendSearcherInterface,
)
-from leann.registry import register_backend
-from leann.searcher_base import BaseSearcher
-from .convert_to_csr import convert_hnsw_graph_to_csr
logger = logging.getLogger(__name__)
@@ -28,12 +28,6 @@ def get_metric_map():
} }
def normalize_l2(data: np.ndarray) -> np.ndarray:
norms = np.linalg.norm(data, axis=1, keepdims=True)
norms[norms == 0] = 1 # Avoid division by zero
return data / norms
@register_backend("hnsw") @register_backend("hnsw")
class HNSWBackend(LeannBackendFactoryInterface): class HNSWBackend(LeannBackendFactoryInterface):
@staticmethod @staticmethod
@@ -54,14 +48,8 @@ class HNSWBuilder(LeannBackendBuilderInterface):
self.efConstruction = self.build_params.setdefault("efConstruction", 200) self.efConstruction = self.build_params.setdefault("efConstruction", 200)
self.distance_metric = self.build_params.setdefault("distance_metric", "mips") self.distance_metric = self.build_params.setdefault("distance_metric", "mips")
self.dimensions = self.build_params.get("dimensions") self.dimensions = self.build_params.get("dimensions")
if not self.is_recompute:
if self.is_compact:
# TODO: support this case @andy
raise ValueError(
"is_recompute is False, but is_compact is True. This is not compatible now. change is compact to False and you can use the original HNSW index."
)
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs): def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
from . import faiss # type: ignore from . import faiss # type: ignore
path = Path(index_path) path = Path(index_path)
@@ -82,7 +70,7 @@ class HNSWBuilder(LeannBackendBuilderInterface):
index.hnsw.efConstruction = self.efConstruction index.hnsw.efConstruction = self.efConstruction
if self.distance_metric.lower() == "cosine": if self.distance_metric.lower() == "cosine":
data = normalize_l2(data) faiss.normalize_L2(data)
index.add(data.shape[0], faiss.swig_ptr(data)) index.add(data.shape[0], faiss.swig_ptr(data))
index_file = index_dir / f"{index_prefix}.index" index_file = index_dir / f"{index_prefix}.index"
@@ -104,15 +92,19 @@ class HNSWBuilder(LeannBackendBuilderInterface):
if success: if success:
logger.info("✅ CSR conversion successful.") logger.info("✅ CSR conversion successful.")
# index_file_old = index_file.with_suffix(".old") index_file_old = index_file.with_suffix(".old")
# shutil.move(str(index_file), str(index_file_old)) shutil.move(str(index_file), str(index_file_old))
shutil.move(str(csr_temp_file), str(index_file)) shutil.move(str(csr_temp_file), str(index_file))
logger.info(f"INFO: Replaced original index with {mode_str} version at '{index_file}'") logger.info(
f"INFO: Replaced original index with {mode_str} version at '{index_file}'"
)
else: else:
# Clean up and fail fast # Clean up and fail fast
if csr_temp_file.exists(): if csr_temp_file.exists():
os.remove(csr_temp_file) os.remove(csr_temp_file)
raise RuntimeError("CSR conversion failed - cannot proceed with compact format") raise RuntimeError(
"CSR conversion failed - cannot proceed with compact format"
)
class HNSWSearcher(BaseSearcher): class HNSWSearcher(BaseSearcher):
@@ -124,9 +116,7 @@ class HNSWSearcher(BaseSearcher):
) )
from . import faiss # type: ignore from . import faiss # type: ignore
-        self.distance_metric = (
-            self.meta.get("backend_kwargs", {}).get("distance_metric", "mips").lower()
-        )
+        self.distance_metric = self.meta.get("distance_metric", "mips").lower()
metric_enum = get_metric_map().get(self.distance_metric) metric_enum = get_metric_map().get(self.distance_metric)
if metric_enum is None: if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.") raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
@@ -160,7 +150,7 @@ class HNSWSearcher(BaseSearcher):
pruning_strategy: Literal["global", "local", "proportional"] = "global", pruning_strategy: Literal["global", "local", "proportional"] = "global",
batch_size: int = 0, batch_size: int = 0,
**kwargs, **kwargs,
) -> dict[str, Any]: ) -> Dict[str, Any]:
""" """
Search for nearest neighbors using HNSW index. Search for nearest neighbors using HNSW index.
@@ -189,29 +179,23 @@ class HNSWSearcher(BaseSearcher):
raise RuntimeError("Recompute is required for pruned index.") raise RuntimeError("Recompute is required for pruned index.")
if recompute_embeddings: if recompute_embeddings:
if zmq_port is None: if zmq_port is None:
raise ValueError("zmq_port must be provided if recompute_embeddings is True") raise ValueError(
"zmq_port must be provided if recompute_embeddings is True"
)
if query.dtype != np.float32: if query.dtype != np.float32:
query = query.astype(np.float32) query = query.astype(np.float32)
if self.distance_metric == "cosine": if self.distance_metric == "cosine":
query = normalize_l2(query) faiss.normalize_L2(query)
params = faiss.SearchParametersHNSW() params = faiss.SearchParametersHNSW()
if zmq_port is not None: if zmq_port is not None:
params.zmq_port = zmq_port # C++ code won't use this if recompute_embeddings is False params.zmq_port = (
zmq_port # C++ code won't use this if recompute_embeddings is False
)
params.efSearch = complexity params.efSearch = complexity
params.beam_size = beam_width params.beam_size = beam_width
# For OpenAI embeddings with cosine distance, disable relative distance check
# This prevents early termination when all scores are in a narrow range
embedding_model = self.meta.get("embedding_model", "").lower()
if self.distance_metric == "cosine" and any(
openai_model in embedding_model for openai_model in ["text-embedding", "openai"]
):
params.check_relative_distance = False
else:
params.check_relative_distance = True
# PQ pruning: direct mapping to HNSW's pq_pruning_ratio # PQ pruning: direct mapping to HNSW's pq_pruning_ratio
params.pq_pruning_ratio = prune_ratio params.pq_pruning_ratio = prune_ratio
@@ -221,7 +205,9 @@ class HNSWSearcher(BaseSearcher):
params.send_neigh_times_ratio = 0.0 params.send_neigh_times_ratio = 0.0
elif pruning_strategy == "proportional": elif pruning_strategy == "proportional":
params.local_prune = False params.local_prune = False
params.send_neigh_times_ratio = 1.0 # Any value > 1e-6 triggers proportional mode params.send_neigh_times_ratio = (
1.0 # Any value > 1e-6 triggers proportional mode
)
else: # "global" else: # "global"
params.local_prune = False params.local_prune = False
params.send_neigh_times_ratio = 0.0 params.send_neigh_times_ratio = 0.0
@@ -242,28 +228,8 @@ class HNSWSearcher(BaseSearcher):
params, params,
) )
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels] string_labels = [
[str(int_label) for int_label in batch_labels] for batch_labels in labels
]
return {"labels": string_labels, "distances": distances} return {"labels": string_labels, "distances": distances}
def cleanup(self):
"""Cleanup HNSW-specific resources including C++ ZMQ connections."""
# Call parent cleanup first
super().cleanup()
# Additional cleanup for C++ side ZMQ connections
# The ZmqDistanceComputer in C++ uses ZMQ connections that need cleanup
try:
# Delete the index to trigger C++ destructors
if hasattr(self, "index"):
del self.index
except Exception:
pass
# Force garbage collection to ensure C++ objects are destroyed
try:
import gc
gc.collect()
except Exception:
pass
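The cosine handling in this backend works by L2-normalizing vectors (normalize_l2 above, or faiss.normalize_L2) so that an inner-product (MIPS) index ranks them by cosine similarity. A small self-contained check of that equivalence; the random data is purely illustrative and not part of the diff.

import numpy as np

def normalize_l2(data: np.ndarray) -> np.ndarray:
    norms = np.linalg.norm(data, axis=1, keepdims=True)
    norms[norms == 0] = 1  # avoid division by zero
    return data / norms

docs = np.random.rand(4, 8).astype(np.float32)
query = np.random.rand(1, 8).astype(np.float32)
cosine = (docs @ query.T).ravel() / (np.linalg.norm(docs, axis=1) * np.linalg.norm(query))
inner = (normalize_l2(docs) @ normalize_l2(query).T).ravel()
assert np.allclose(inner, cosine, atol=1e-5)  # on unit vectors, inner product equals cosine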

View File

@@ -3,18 +3,17 @@ HNSW-specific embedding server
""" """
import argparse
-import json
-import logging
-import os
-import sys
import threading
import time
+import os
+import zmq
+import numpy as np
+import msgpack
+import json
from pathlib import Path
from typing import Optional
-import msgpack
-import numpy as np
-import zmq
+import sys
+import logging
# Set up logging based on environment variable # Set up logging based on environment variable
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper() LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
@@ -53,8 +52,8 @@ def create_hnsw_embedding_server(
sys.path.insert(0, str(leann_core_path)) sys.path.insert(0, str(leann_core_path))
try: try:
from leann.api import PassageManager
from leann.embedding_compute import compute_embeddings from leann.embedding_compute import compute_embeddings
from leann.api import PassageManager
logger.info("Successfully imported unified embedding computation module") logger.info("Successfully imported unified embedding computation module")
except ImportError as e: except ImportError as e:
@@ -79,11 +78,10 @@ def create_hnsw_embedding_server(
raise ValueError("Only metadata files (.meta.json) are supported") raise ValueError("Only metadata files (.meta.json) are supported")
# Load metadata to get passage sources # Load metadata to get passage sources
with open(passages_file) as f: with open(passages_file, "r") as f:
meta = json.load(f) meta = json.load(f)
# Let PassageManager handle path resolution uniformly passages = PassageManager(meta["passage_sources"])
passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file)
logger.info( logger.info(
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata" f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
) )
@@ -92,7 +90,6 @@ def create_hnsw_embedding_server(
"""ZMQ server thread""" """ZMQ server thread"""
context = zmq.Context() context = zmq.Context()
socket = context.socket(zmq.REP) socket = context.socket(zmq.REP)
socket.setsockopt(zmq.LINGER, 0) # Don't block on close
socket.bind(f"tcp://*:{zmq_port}") socket.bind(f"tcp://*:{zmq_port}")
logger.info(f"HNSW ZMQ server listening on port {zmq_port}") logger.info(f"HNSW ZMQ server listening on port {zmq_port}")
@@ -123,7 +120,9 @@ def create_hnsw_embedding_server(
response = embeddings.tolist() response = embeddings.tolist()
socket.send(msgpack.packb(response)) socket.send(msgpack.packb(response))
e2e_end = time.time() e2e_end = time.time()
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s") logger.info(
f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s"
)
continue continue
# Handle distance calculation requests # Handle distance calculation requests
@@ -149,13 +148,17 @@ def create_hnsw_embedding_server(
texts.append(txt) texts.append(txt)
except KeyError: except KeyError:
logger.error(f"Passage ID {nid} not found") logger.error(f"Passage ID {nid} not found")
raise RuntimeError(f"FATAL: Passage with ID {nid} not found") raise RuntimeError(
f"FATAL: Passage with ID {nid} not found"
)
except Exception as e: except Exception as e:
logger.error(f"Exception looking up passage ID {nid}: {e}") logger.error(f"Exception looking up passage ID {nid}: {e}")
raise raise
# Process embeddings # Process embeddings
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode) embeddings = compute_embeddings(
texts, model_name, mode=embedding_mode
)
logger.info( logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}" f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
) )
@@ -169,12 +172,18 @@ def create_hnsw_embedding_server(
distances = -np.dot(embeddings, query_vector) distances = -np.dot(embeddings, query_vector)
response_payload = distances.flatten().tolist() response_payload = distances.flatten().tolist()
response_bytes = msgpack.packb([response_payload], use_single_float=True) response_bytes = msgpack.packb(
logger.debug(f"Sending distance response with {len(distances)} distances") [response_payload], use_single_float=True
)
logger.debug(
f"Sending distance response with {len(distances)} distances"
)
socket.send(response_bytes) socket.send(response_bytes)
e2e_end = time.time() e2e_end = time.time()
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s") logger.info(
f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s"
)
continue continue
# Standard embedding request (passage ID lookup) # Standard embedding request (passage ID lookup)
@@ -199,7 +208,9 @@ def create_hnsw_embedding_server(
passage_data = passages.get_passage(str(nid)) passage_data = passages.get_passage(str(nid))
txt = passage_data["text"] txt = passage_data["text"]
if not txt: if not txt:
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}") raise RuntimeError(
f"FATAL: Empty text for passage ID {nid}"
)
texts.append(txt) texts.append(txt)
except KeyError: except KeyError:
raise RuntimeError(f"FATAL: Passage with ID {nid} not found") raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
@@ -218,9 +229,11 @@ def create_hnsw_embedding_server(
logger.error( logger.error(
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..." f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
) )
raise AssertionError() assert False
hidden_contiguous_f32 = np.ascontiguousarray(embeddings, dtype=np.float32) hidden_contiguous_f32 = np.ascontiguousarray(
embeddings, dtype=np.float32
)
response_payload = [ response_payload = [
list(hidden_contiguous_f32.shape), list(hidden_contiguous_f32.shape),
hidden_contiguous_f32.flatten().tolist(), hidden_contiguous_f32.flatten().tolist(),
@@ -257,15 +270,15 @@ def create_hnsw_embedding_server(
if __name__ == "__main__": if __name__ == "__main__":
import signal import signal
import sys import sys
def signal_handler(sig, frame): def signal_handler(sig, frame):
logger.info(f"Received signal {sig}, shutting down gracefully...") logger.info(f"Received signal {sig}, shutting down gracefully...")
sys.exit(0) sys.exit(0)
# Register signal handlers for graceful shutdown # Register signal handlers for graceful shutdown
signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGINT, signal_handler)
parser = argparse.ArgumentParser(description="HNSW Embedding service") parser = argparse.ArgumentParser(description="HNSW Embedding service")
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on") parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
parser.add_argument( parser.add_argument(
@@ -286,7 +299,7 @@ if __name__ == "__main__":
"--embedding-mode", "--embedding-mode",
type=str, type=str,
default="sentence-transformers", default="sentence-transformers",
choices=["sentence-transformers", "openai", "mlx", "ollama"], choices=["sentence-transformers", "openai", "mlx"],
help="Embedding backend mode", help="Embedding backend mode",
) )
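A hedged sketch (not part of this diff) of a client for this server's text-embedding path. It mirrors compute_embeddings_via_server in leann/api.py: a msgpack-encoded list of texts is sent on a REQ socket and a nested list of floats comes back. The port matches the --zmq-port default above; the function name and everything else here are illustrative assumptions.

import msgpack
import numpy as np
import zmq

def embed_texts(texts, port=5555):
    ctx = zmq.Context()
    sock = ctx.socket(zmq.REQ)
    sock.connect(f"tcp://localhost:{port}")
    try:
        sock.send(msgpack.packb(texts))       # request: plain list of strings
        reply = msgpack.unpackb(sock.recv())  # reply: embeddings as nested lists
        return np.asarray(reply, dtype=np.float32)
    finally:
        sock.close(linger=0)
        ctx.term()

print(embed_texts(["hello leann"]).shape)  # requires a running embedding server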

View File

@@ -6,14 +6,9 @@ build-backend = "scikit_build_core.build"
[project] [project]
name = "leann-backend-hnsw" name = "leann-backend-hnsw"
version = "0.2.7" version = "0.1.0"
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit." description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
dependencies = [ dependencies = ["leann-core==0.1.0", "numpy"]
"leann-core==0.2.7",
"numpy",
"pyzmq>=23.0.0",
"msgpack>=1.0.0",
]
[tool.scikit-build] [tool.scikit-build]
wheel.packages = ["leann_backend_hnsw"] wheel.packages = ["leann_backend_hnsw"]
@@ -24,4 +19,4 @@ build.tool-args = ["-j8"]
# CMake definitions to optimize compilation # CMake definitions to optimize compilation
[tool.scikit-build.cmake.define] [tool.scikit-build.cmake.define]
CMAKE_BUILD_PARALLEL_LEVEL = "8" CMAKE_BUILD_PARALLEL_LEVEL = "8"

View File

@@ -4,49 +4,19 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "leann-core" name = "leann-core"
version = "0.2.7" version = "0.1.0"
description = "Core API and plugin system for LEANN" description = "Core API and plugin system for Leann."
readme = "README.md" readme = "README.md"
requires-python = ">=3.9" requires-python = ">=3.9"
license = { text = "MIT" } license = { text = "MIT" }
# All required dependencies included
dependencies = [ dependencies = [
"numpy>=1.20.0", "numpy>=1.20.0",
"tqdm>=4.60.0", "tqdm>=4.60.0"
"psutil>=5.8.0",
"pyzmq>=23.0.0",
"msgpack>=1.0.0",
"torch>=2.0.0",
"sentence-transformers>=2.2.0",
"llama-index-core>=0.12.0",
"llama-index-readers-file>=0.4.0", # Essential for document reading
"llama-index-embeddings-huggingface>=0.5.5", # For embeddings
"python-dotenv>=1.0.0",
"openai>=1.0.0",
"huggingface-hub>=0.20.0",
"transformers>=4.30.0",
"requests>=2.25.0",
"accelerate>=0.20.0",
"PyPDF2>=3.0.0",
"pymupdf>=1.23.0",
"pdfplumber>=0.10.0",
"nbconvert>=7.0.0", # For .ipynb file support
"gitignore-parser>=0.1.12", # For proper .gitignore handling
"mlx>=0.26.3; sys_platform == 'darwin'",
"mlx-lm>=0.26.0; sys_platform == 'darwin'",
]
[project.optional-dependencies]
colab = [
"torch>=2.0.0,<3.0.0", # Limit torch version to avoid conflicts
"transformers>=4.30.0,<5.0.0", # Limit transformers version
"accelerate>=0.20.0,<1.0.0", # Limit accelerate version
] ]
[project.scripts] [project.scripts]
leann = "leann.cli:main" leann = "leann.cli:main"
leann_mcp = "leann.mcp:main"
[tool.setuptools.packages.find] [tool.setuptools.packages.find]
where = ["src"] where = ["src"]

View File

@@ -8,14 +8,10 @@ if platform.system() == "Darwin":
os.environ["MKL_NUM_THREADS"] = "1" os.environ["MKL_NUM_THREADS"] = "1"
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
os.environ["KMP_BLOCKTIME"] = "0" os.environ["KMP_BLOCKTIME"] = "0"
# Additional fixes for PyTorch/sentence-transformers on macOS ARM64 only in CI
if os.environ.get("CI") == "true":
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "0"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from .api import LeannBuilder, LeannChat, LeannSearcher from .api import LeannBuilder, LeannChat, LeannSearcher
from .registry import BACKEND_REGISTRY, autodiscover_backends from .registry import BACKEND_REGISTRY, autodiscover_backends
autodiscover_backends() autodiscover_backends()
__all__ = ["BACKEND_REGISTRY", "LeannBuilder", "LeannChat", "LeannSearcher"] __all__ = ["LeannBuilder", "LeannSearcher", "LeannChat", "BACKEND_REGISTRY"]

View File

@@ -4,32 +4,23 @@ with the correct, original embedding logic from the user's reference code.
""" """
import json import json
import logging
import pickle import pickle
import time
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal, Optional
import numpy as np
from leann.interface import LeannBackendSearcherInterface from leann.interface import LeannBackendSearcherInterface
import numpy as np
from .chat import get_llm import time
from .interface import LeannBackendFactoryInterface from pathlib import Path
from typing import List, Dict, Any, Optional, Literal
from dataclasses import dataclass, field
from .registry import BACKEND_REGISTRY from .registry import BACKEND_REGISTRY
from .interface import LeannBackendFactoryInterface
from .chat import get_llm
import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_registered_backends() -> list[str]:
"""Get list of registered backend names."""
return list(BACKEND_REGISTRY.keys())
def compute_embeddings( def compute_embeddings(
chunks: list[str], chunks: List[str],
model_name: str, model_name: str,
mode: str = "sentence-transformers", mode: str = "sentence-transformers",
use_server: bool = True, use_server: bool = True,
@@ -70,7 +61,9 @@ def compute_embeddings(
) )
-def compute_embeddings_via_server(chunks: list[str], model_name: str, port: int) -> np.ndarray:
+def compute_embeddings_via_server(
+    chunks: List[str], model_name: str, port: int
+) -> np.ndarray:
    """Computes embeddings using sentence-transformers.
    Args:
@@ -80,33 +73,28 @@ def compute_embeddings_via_server(chunks: list[str], model_name: str, port: int)
    logger.info(
        f"Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
    )
+    import zmq
    import msgpack
    import numpy as np
-    import zmq
    # Connect to embedding server
    context = zmq.Context()
    socket = context.socket(zmq.REQ)
-    socket.setsockopt(zmq.LINGER, 0)  # Don't block on close
-    socket.setsockopt(zmq.RCVTIMEO, 1000)  # 1s timeout on receive
-    socket.setsockopt(zmq.SNDTIMEO, 1000)  # 1s timeout on send
-    socket.setsockopt(zmq.IMMEDIATE, 1)  # Don't wait for connection
    socket.connect(f"tcp://localhost:{port}")
-    try:
    # Send chunks to server for embedding computation
    request = chunks
    socket.send(msgpack.packb(request))
    # Receive embeddings from server
    response = socket.recv()
    embeddings_list = msgpack.unpackb(response)
    # Convert back to numpy array
    embeddings = np.array(embeddings_list, dtype=np.float32)
-    finally:
-        socket.close(linger=0)
+    socket.close()
    context.term()
    return embeddings
@@ -116,13 +104,11 @@ class SearchResult:
id: str id: str
score: float score: float
text: str text: str
metadata: dict[str, Any] = field(default_factory=dict) metadata: Dict[str, Any] = field(default_factory=dict)
class PassageManager: class PassageManager:
def __init__( def __init__(self, passage_sources: List[Dict[str, Any]]):
self, passage_sources: list[dict[str, Any]], metadata_file_path: Optional[str] = None
):
self.offset_maps = {} self.offset_maps = {}
self.passage_files = {} self.passage_files = {}
self.global_offset_map = {} # Combined map for fast lookup self.global_offset_map = {} # Combined map for fast lookup
@@ -131,31 +117,8 @@ class PassageManager:
assert source["type"] == "jsonl", "only jsonl is supported" assert source["type"] == "jsonl", "only jsonl is supported"
passage_file = source["path"] passage_file = source["path"]
index_file = source["index_path"] # .idx file index_file = source["index_path"] # .idx file
# Fix path resolution - relative paths should be relative to metadata file directory
if not Path(index_file).is_absolute():
if metadata_file_path:
# Resolve relative to metadata file directory
metadata_dir = Path(metadata_file_path).parent
logger.debug(
f"PassageManager: Resolving relative paths from metadata_dir: {metadata_dir}"
)
index_file = str((metadata_dir / index_file).resolve())
passage_file = str((metadata_dir / passage_file).resolve())
logger.debug(f"PassageManager: Resolved index_file: {index_file}")
else:
# Fallback to current directory resolution (legacy behavior)
logger.warning(
"PassageManager: No metadata_file_path provided, using fallback resolution from cwd"
)
logger.debug(f"PassageManager: Current working directory: {Path.cwd()}")
index_file = str(Path(index_file).resolve())
passage_file = str(Path(passage_file).resolve())
logger.debug(f"PassageManager: Fallback resolved index_file: {index_file}")
if not Path(index_file).exists(): if not Path(index_file).exists():
raise FileNotFoundError(f"Passage index file not found: {index_file}") raise FileNotFoundError(f"Passage index file not found: {index_file}")
with open(index_file, "rb") as f: with open(index_file, "rb") as f:
offset_map = pickle.load(f) offset_map = pickle.load(f)
self.offset_maps[passage_file] = offset_map self.offset_maps[passage_file] = offset_map
@@ -165,11 +128,11 @@ class PassageManager:
for passage_id, offset in offset_map.items(): for passage_id, offset in offset_map.items():
self.global_offset_map[passage_id] = (passage_file, offset) self.global_offset_map[passage_id] = (passage_file, offset)
def get_passage(self, passage_id: str) -> dict[str, Any]: def get_passage(self, passage_id: str) -> Dict[str, Any]:
if passage_id in self.global_offset_map: if passage_id in self.global_offset_map:
passage_file, offset = self.global_offset_map[passage_id] passage_file, offset = self.global_offset_map[passage_id]
# Lazy file opening - only open when needed # Lazy file opening - only open when needed
with open(passage_file, encoding="utf-8") as f: with open(passage_file, "r", encoding="utf-8") as f:
f.seek(offset) f.seek(offset)
return json.loads(f.readline()) return json.loads(f.readline())
raise KeyError(f"Passage ID not found: {passage_id}") raise KeyError(f"Passage ID not found: {passage_id}")
@@ -179,93 +142,25 @@ class LeannBuilder:
def __init__( def __init__(
self, self,
backend_name: str, backend_name: str,
embedding_model: str = "facebook/contriever", embedding_model: str = "facebook/contriever-msmarco",
dimensions: Optional[int] = None, dimensions: Optional[int] = None,
embedding_mode: str = "sentence-transformers", embedding_mode: str = "sentence-transformers",
**backend_kwargs, **backend_kwargs,
): ):
self.backend_name = backend_name self.backend_name = backend_name
backend_factory: Optional[LeannBackendFactoryInterface] = BACKEND_REGISTRY.get(backend_name) backend_factory: LeannBackendFactoryInterface | None = BACKEND_REGISTRY.get(
backend_name
)
if backend_factory is None: if backend_factory is None:
raise ValueError(f"Backend '{backend_name}' not found or not registered.") raise ValueError(f"Backend '{backend_name}' not found or not registered.")
self.backend_factory = backend_factory self.backend_factory = backend_factory
self.embedding_model = embedding_model self.embedding_model = embedding_model
self.dimensions = dimensions self.dimensions = dimensions
self.embedding_mode = embedding_mode self.embedding_mode = embedding_mode
# Check if we need to use cosine distance for normalized embeddings
normalized_embeddings_models = {
# OpenAI models
("openai", "text-embedding-ada-002"),
("openai", "text-embedding-3-small"),
("openai", "text-embedding-3-large"),
# Voyage AI models
("voyage", "voyage-2"),
("voyage", "voyage-3"),
("voyage", "voyage-large-2"),
("voyage", "voyage-multilingual-2"),
("voyage", "voyage-code-2"),
# Cohere models
("cohere", "embed-english-v3.0"),
("cohere", "embed-multilingual-v3.0"),
("cohere", "embed-english-light-v3.0"),
("cohere", "embed-multilingual-light-v3.0"),
}
# Also check for patterns in model names
is_normalized = False
current_model_lower = embedding_model.lower()
current_mode_lower = embedding_mode.lower()
# Check exact matches
for mode, model in normalized_embeddings_models:
if (current_mode_lower == mode and current_model_lower == model) or (
mode in current_mode_lower and model in current_model_lower
):
is_normalized = True
break
# Check patterns
if not is_normalized:
# OpenAI patterns
if "openai" in current_mode_lower or "openai" in current_model_lower:
if any(
pattern in current_model_lower
for pattern in ["text-embedding", "ada", "3-small", "3-large"]
):
is_normalized = True
# Voyage patterns
elif "voyage" in current_mode_lower or "voyage" in current_model_lower:
is_normalized = True
# Cohere patterns
elif "cohere" in current_mode_lower or "cohere" in current_model_lower:
if "embed" in current_model_lower:
is_normalized = True
# Handle distance metric
if is_normalized and "distance_metric" not in backend_kwargs:
backend_kwargs["distance_metric"] = "cosine"
warnings.warn(
f"Detected normalized embeddings model '{embedding_model}' with mode '{embedding_mode}'. "
f"Automatically setting distance_metric='cosine' for optimal performance. "
f"Normalized embeddings (L2 norm = 1) should use cosine similarity instead of MIPS.",
UserWarning,
stacklevel=2,
)
elif is_normalized and backend_kwargs.get("distance_metric", "").lower() != "cosine":
current_metric = backend_kwargs.get("distance_metric", "mips")
warnings.warn(
f"Warning: Using '{current_metric}' distance metric with normalized embeddings model "
f"'{embedding_model}' may lead to suboptimal search results. "
f"Consider using 'cosine' distance metric for better performance.",
UserWarning,
stacklevel=2,
)
self.backend_kwargs = backend_kwargs self.backend_kwargs = backend_kwargs
self.chunks: list[dict[str, Any]] = [] self.chunks: List[Dict[str, Any]] = []
def add_text(self, text: str, metadata: Optional[dict[str, Any]] = None): def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
if metadata is None: if metadata is None:
metadata = {} metadata = {}
passage_id = metadata.get("id", str(len(self.chunks))) passage_id = metadata.get("id", str(len(self.chunks)))
@@ -295,7 +190,9 @@ class LeannBuilder:
try: try:
from tqdm import tqdm from tqdm import tqdm
chunk_iterator = tqdm(self.chunks, desc="Writing passages", unit="chunk") chunk_iterator = tqdm(
self.chunks, desc="Writing passages", unit="chunk"
)
except ImportError: except ImportError:
chunk_iterator = self.chunks chunk_iterator = self.chunks
@@ -325,7 +222,9 @@ class LeannBuilder:
string_ids = [chunk["id"] for chunk in self.chunks] string_ids = [chunk["id"] for chunk in self.chunks]
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions} current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
builder_instance = self.backend_factory.builder(**current_backend_kwargs) builder_instance = self.backend_factory.builder(**current_backend_kwargs)
builder_instance.build(embeddings, string_ids, index_path, **current_backend_kwargs) builder_instance.build(
embeddings, string_ids, index_path, **current_backend_kwargs
)
leann_meta_path = index_dir / f"{index_name}.meta.json" leann_meta_path = index_dir / f"{index_name}.meta.json"
meta_data = { meta_data = {
"version": "1.0", "version": "1.0",
@@ -337,8 +236,8 @@ class LeannBuilder:
"passage_sources": [ "passage_sources": [
{ {
"type": "jsonl", "type": "jsonl",
"path": passages_file.name, # Use relative path (just filename) "path": str(passages_file),
"index_path": offset_file.name, # Use relative path (just filename) "index_path": str(offset_file),
} }
], ],
} }
@@ -374,7 +273,9 @@ class LeannBuilder:
ids, embeddings = data ids, embeddings = data
if not isinstance(embeddings, np.ndarray): if not isinstance(embeddings, np.ndarray):
raise ValueError(f"Expected embeddings to be numpy array, got {type(embeddings)}") raise ValueError(
f"Expected embeddings to be numpy array, got {type(embeddings)}"
)
if len(ids) != embeddings.shape[0]: if len(ids) != embeddings.shape[0]:
raise ValueError( raise ValueError(
@@ -386,7 +287,9 @@ class LeannBuilder:
if self.dimensions is None: if self.dimensions is None:
self.dimensions = embedding_dim self.dimensions = embedding_dim
elif self.dimensions != embedding_dim: elif self.dimensions != embedding_dim:
raise ValueError(f"Dimension mismatch: expected {self.dimensions}, got {embedding_dim}") raise ValueError(
f"Dimension mismatch: expected {self.dimensions}, got {embedding_dim}"
)
logger.info( logger.info(
f"Building index from precomputed embeddings: {len(ids)} items, {embedding_dim} dimensions" f"Building index from precomputed embeddings: {len(ids)} items, {embedding_dim} dimensions"
@@ -453,8 +356,8 @@ class LeannBuilder:
"passage_sources": [ "passage_sources": [
{ {
"type": "jsonl", "type": "jsonl",
"path": passages_file.name, # Use relative path (just filename) "path": str(passages_file),
"index_path": offset_file.name, # Use relative path (just filename) "index_path": str(offset_file),
} }
], ],
"built_from_precomputed_embeddings": True, "built_from_precomputed_embeddings": True,
@@ -471,34 +374,27 @@ class LeannBuilder:
with open(leann_meta_path, "w", encoding="utf-8") as f: with open(leann_meta_path, "w", encoding="utf-8") as f:
json.dump(meta_data, f, indent=2) json.dump(meta_data, f, indent=2)
logger.info(f"Index built successfully from precomputed embeddings: {index_path}") logger.info(
f"Index built successfully from precomputed embeddings: {index_path}"
)
class LeannSearcher: class LeannSearcher:
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs): def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
# Fix path resolution for Colab and other environments
if not Path(index_path).is_absolute():
index_path = str(Path(index_path).resolve())
self.meta_path_str = f"{index_path}.meta.json" self.meta_path_str = f"{index_path}.meta.json"
if not Path(self.meta_path_str).exists(): if not Path(self.meta_path_str).exists():
parent_dir = Path(index_path).parent
print(
f"Leann metadata file not found at {self.meta_path_str}, and you may need to rm -rf {parent_dir}"
)
# highlight in red the filenotfound error
raise FileNotFoundError( raise FileNotFoundError(
f"Leann metadata file not found at {self.meta_path_str}, \033[91m you may need to rm -rf {parent_dir}\033[0m" f"Leann metadata file not found at {self.meta_path_str}"
) )
with open(self.meta_path_str, encoding="utf-8") as f: with open(self.meta_path_str, "r", encoding="utf-8") as f:
self.meta_data = json.load(f) self.meta_data = json.load(f)
backend_name = self.meta_data["backend_name"] backend_name = self.meta_data["backend_name"]
self.embedding_model = self.meta_data["embedding_model"] self.embedding_model = self.meta_data["embedding_model"]
# Support both old and new format # Support both old and new format
self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers") self.embedding_mode = self.meta_data.get(
self.passage_manager = PassageManager( "embedding_mode", "sentence-transformers"
self.meta_data.get("passage_sources", []), metadata_file_path=self.meta_path_str
) )
self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
backend_factory = BACKEND_REGISTRY.get(backend_name) backend_factory = BACKEND_REGISTRY.get(backend_name)
if backend_factory is None: if backend_factory is None:
raise ValueError(f"Backend '{backend_name}' not found.") raise ValueError(f"Backend '{backend_name}' not found.")
@@ -519,22 +415,12 @@ class LeannSearcher:
pruning_strategy: Literal["global", "local", "proportional"] = "global", pruning_strategy: Literal["global", "local", "proportional"] = "global",
expected_zmq_port: int = 5557, expected_zmq_port: int = 5557,
**kwargs, **kwargs,
) -> list[SearchResult]: ) -> List[SearchResult]:
logger.info("🔍 LeannSearcher.search() called:") logger.info("🔍 LeannSearcher.search() called:")
logger.info(f" Query: '{query}'") logger.info(f" Query: '{query}'")
logger.info(f" Top_k: {top_k}") logger.info(f" Top_k: {top_k}")
logger.info(f" Additional kwargs: {kwargs}") logger.info(f" Additional kwargs: {kwargs}")
# Smart top_k detection and adjustment
total_docs = len(self.passage_manager.global_offset_map)
original_top_k = top_k
if top_k > total_docs:
top_k = total_docs
logger.warning(
f" ⚠️ Requested top_k ({original_top_k}) exceeds total documents ({total_docs})"
)
logger.warning(f" ✅ Auto-adjusted top_k to {top_k} to match available documents")
zmq_port = None zmq_port = None
start_time = time.time() start_time = time.time()
@@ -555,9 +441,9 @@ class LeannSearcher:
use_server_if_available=recompute_embeddings, use_server_if_available=recompute_embeddings,
zmq_port=zmq_port, zmq_port=zmq_port,
) )
# logger.info(f" Generated embedding shape: {query_embedding.shape}") logger.info(f" Generated embedding shape: {query_embedding.shape}")
time.time() - start_time embedding_time = time.time() - start_time
# logger.info(f" Embedding time: {embedding_time} seconds") logger.info(f" Embedding time: {embedding_time} seconds")
start_time = time.time() start_time = time.time()
results = self.backend_impl.search( results = self.backend_impl.search(
@@ -571,13 +457,15 @@ class LeannSearcher:
zmq_port=zmq_port, zmq_port=zmq_port,
**kwargs, **kwargs,
) )
# logger.info(f" Search time: {search_time} seconds") search_time = time.time() - start_time
logger.info(f" Backend returned: labels={len(results.get('labels', [[]])[0])} results") logger.info(f" Search time: {search_time} seconds")
logger.info(
f" Backend returned: labels={len(results.get('labels', [[]])[0])} results"
)
enriched_results = [] enriched_results = []
if "labels" in results and "distances" in results: if "labels" in results and "distances" in results:
logger.info(f" Processing {len(results['labels'][0])} passage IDs:") logger.info(f" Processing {len(results['labels'][0])} passage IDs:")
# Python 3.9 does not support zip(strict=...); lengths are expected to match
for i, (string_id, dist) in enumerate( for i, (string_id, dist) in enumerate(
zip(results["labels"][0], results["distances"][0]) zip(results["labels"][0], results["distances"][0])
): ):
@@ -591,59 +479,23 @@ class LeannSearcher:
metadata=passage_data.get("metadata", {}), metadata=passage_data.get("metadata", {}),
) )
) )
# Color codes for better logging
GREEN = "\033[92m"
BLUE = "\033[94m"
YELLOW = "\033[93m"
RESET = "\033[0m"
# Truncate text for display (first 100 chars)
display_text = passage_data["text"]
logger.info( logger.info(
f" {GREEN}{RESET} {BLUE}[{i + 1:2d}]{RESET} {YELLOW}ID:{RESET} '{string_id}' {YELLOW}Score:{RESET} {dist:.4f} {YELLOW}Text:{RESET} {display_text}" f" {i + 1}. passage_id='{string_id}' -> SUCCESS: {passage_data['text']}..."
) )
except KeyError: except KeyError:
RED = "\033[91m"
RESET = "\033[0m"
logger.error( logger.error(
f" {RED}{RESET} [{i + 1:2d}] ID: '{string_id}' -> {RED}ERROR: Passage not found!{RESET}" f" {i + 1}. passage_id='{string_id}' -> ERROR: Passage not found in PassageManager!"
) )
# Define color codes outside the loop for final message logger.info(f" Final enriched results: {len(enriched_results)} passages")
GREEN = "\033[92m"
RESET = "\033[0m"
logger.info(f" {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}")
return enriched_results return enriched_results
def cleanup(self):
"""Explicitly cleanup embedding server and ZMQ resources.
This method should be called after you're done using the searcher,
especially in test environments or batch processing scenarios.
"""
# Stop embedding server
if hasattr(self.backend_impl, "embedding_server_manager"):
self.backend_impl.embedding_server_manager.stop_server()
# Set ZMQ linger but don't terminate global context
try:
import zmq
# Just set linger on the global instance
ctx = zmq.Context.instance()
ctx.linger = 0
# NEVER call ctx.term() or destroy() on the global instance
# That would block waiting for all sockets to close
except Exception:
pass
class LeannChat: class LeannChat:
def __init__( def __init__(
self, self,
index_path: str, index_path: str,
llm_config: Optional[dict[str, Any]] = None, llm_config: Optional[Dict[str, Any]] = None,
enable_warmup: bool = False, enable_warmup: bool = False,
**kwargs, **kwargs,
): ):
@@ -659,13 +511,13 @@ class LeannChat:
prune_ratio: float = 0.0, prune_ratio: float = 0.0,
recompute_embeddings: bool = True, recompute_embeddings: bool = True,
pruning_strategy: Literal["global", "local", "proportional"] = "global", pruning_strategy: Literal["global", "local", "proportional"] = "global",
llm_kwargs: Optional[dict[str, Any]] = None, llm_kwargs: Optional[Dict[str, Any]] = None,
expected_zmq_port: int = 5557, expected_zmq_port: int = 5557,
**search_kwargs, **search_kwargs,
): ):
if llm_kwargs is None: if llm_kwargs is None:
llm_kwargs = {} llm_kwargs = {}
search_time = time.time()
results = self.searcher.search( results = self.searcher.search(
question, question,
top_k=top_k, top_k=top_k,
@@ -677,8 +529,6 @@ class LeannChat:
expected_zmq_port=expected_zmq_port, expected_zmq_port=expected_zmq_port,
**search_kwargs, **search_kwargs,
) )
search_time = time.time() - search_time
# logger.info(f" Search time: {search_time} seconds")
context = "\n\n".join([r.text for r in results]) context = "\n\n".join([r.text for r in results])
prompt = ( prompt = (
"Here is some retrieved context that might help answer your question:\n\n" "Here is some retrieved context that might help answer your question:\n\n"
@@ -687,10 +537,7 @@ class LeannChat:
"Please provide the best answer you can based on this context and your knowledge." "Please provide the best answer you can based on this context and your knowledge."
) )
ask_time = time.time()
ans = self.llm.ask(prompt, **llm_kwargs) ans = self.llm.ask(prompt, **llm_kwargs)
ask_time = time.time() - ask_time
logger.info(f" Ask time: {ask_time} seconds")
return ans return ans
def start_interactive(self): def start_interactive(self):
@@ -707,12 +554,3 @@ class LeannChat:
except (KeyboardInterrupt, EOFError): except (KeyboardInterrupt, EOFError):
print("\nGoodbye!") print("\nGoodbye!")
break break
def cleanup(self):
"""Explicitly cleanup embedding server resources.
This method should be called after you're done using the chat interface,
especially in test environments or batch processing scenarios.
"""
if hasattr(self.searcher, "cleanup"):
self.searcher.cleanup()
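PassageManager above resolves each passage through a pickled {passage_id: byte offset} map plus a seek into the passages .jsonl file. A hedged sketch (not from this diff) of producing a compatible pair of files; the file names, ids, and texts are illustrative only.

import json
import pickle

passages = [
    {"id": "0", "text": "first passage", "metadata": {}},
    {"id": "1", "text": "second passage", "metadata": {}},
]

offset_map = {}
with open("passages.jsonl", "w", encoding="utf-8") as f:
    for p in passages:
        offset_map[p["id"]] = f.tell()  # position of this line, later used by seek()
        f.write(json.dumps(p) + "\n")

with open("passages.idx", "wb") as f:
    pickle.dump(offset_map, f)  # PassageManager loads this with pickle.load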

View File

@@ -4,25 +4,22 @@ This file contains the chat generation logic for the LEANN project,
supporting different backends like Ollama, Hugging Face Transformers, and a simulation mode. supporting different backends like Ollama, Hugging Face Transformers, and a simulation mode.
""" """
import difflib from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, List
import logging import logging
import os import os
from abc import ABC, abstractmethod import difflib
from typing import Any, Optional
import torch
# Configure logging # Configure logging
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
-def check_ollama_models(host: str) -> list[str]:
+def check_ollama_models() -> List[str]:
    """Check available Ollama models and return a list"""
    try:
        import requests
-        response = requests.get(f"{host}/api/tags", timeout=5)
+        response = requests.get("http://localhost:11434/api/tags", timeout=5)
if response.status_code == 200: if response.status_code == 200:
data = response.json() data = response.json()
return [model["name"] for model in data.get("models", [])] return [model["name"] for model in data.get("models", [])]
@@ -31,135 +28,68 @@ def check_ollama_models(host: str) -> list[str]:
return [] return []
def check_ollama_model_exists_remotely(model_name: str) -> tuple[bool, list[str]]: def search_ollama_models_fuzzy(query: str, available_models: List[str]) -> List[str]:
"""Check if a model exists in Ollama's remote library and return available tags
Returns:
(model_exists, available_tags): bool and list of matching tags
"""
try:
import re
import requests
# Split model name and tag
if ":" in model_name:
base_model, requested_tag = model_name.split(":", 1)
else:
base_model, requested_tag = model_name, None
# First check if base model exists in library
library_response = requests.get("https://ollama.com/library", timeout=8)
if library_response.status_code != 200:
return True, [] # Assume exists if can't check
# Extract model names from library page
models_in_library = re.findall(r'href="/library/([^"]+)"', library_response.text)
if base_model not in models_in_library:
return False, [] # Base model doesn't exist
# If base model exists, get available tags
tags_response = requests.get(f"https://ollama.com/library/{base_model}/tags", timeout=8)
if tags_response.status_code != 200:
return True, [] # Base model exists but can't get tags
# Extract tags for this model - be more specific to avoid HTML artifacts
tag_pattern = rf"{re.escape(base_model)}:[a-zA-Z0-9\.\-_]+"
raw_tags = re.findall(tag_pattern, tags_response.text)
# Clean up tags - remove HTML artifacts and duplicates
available_tags = []
seen = set()
for tag in raw_tags:
# Skip if it looks like HTML (contains < or >)
if "<" in tag or ">" in tag:
continue
if tag not in seen:
seen.add(tag)
available_tags.append(tag)
# Check if exact model exists
if requested_tag is None:
# User just requested base model, suggest tags
return True, available_tags[:10] # Return up to 10 tags
else:
exact_match = model_name in available_tags
return exact_match, available_tags[:10]
except Exception:
pass
# If scraping fails, assume model might exist (don't block user)
return True, []
def search_ollama_models_fuzzy(query: str, available_models: list[str]) -> list[str]:
"""Use intelligent fuzzy search for Ollama models""" """Use intelligent fuzzy search for Ollama models"""
if not available_models: if not available_models:
return [] return []
query_lower = query.lower() query_lower = query.lower()
suggestions = [] suggestions = []
# 1. Exact matches first # 1. Exact matches first
exact_matches = [m for m in available_models if query_lower == m.lower()] exact_matches = [m for m in available_models if query_lower == m.lower()]
suggestions.extend(exact_matches) suggestions.extend(exact_matches)
# 2. Starts with query # 2. Starts with query
starts_with = [ starts_with = [m for m in available_models if m.lower().startswith(query_lower) and m not in suggestions]
m for m in available_models if m.lower().startswith(query_lower) and m not in suggestions
]
suggestions.extend(starts_with) suggestions.extend(starts_with)
# 3. Contains query # 3. Contains query
contains = [m for m in available_models if query_lower in m.lower() and m not in suggestions] contains = [m for m in available_models if query_lower in m.lower() and m not in suggestions]
suggestions.extend(contains) suggestions.extend(contains)
# 4. Base model name matching (remove version numbers) # 4. Base model name matching (remove version numbers)
def get_base_name(model_name: str) -> str: def get_base_name(model_name: str) -> str:
"""Extract base name without version (e.g., 'llama3:8b' -> 'llama3')""" """Extract base name without version (e.g., 'llama3:8b' -> 'llama3')"""
return model_name.split(":")[0].split("-")[0] return model_name.split(':')[0].split('-')[0]
query_base = get_base_name(query_lower) query_base = get_base_name(query_lower)
base_matches = [ base_matches = [
m m for m in available_models
for m in available_models
if get_base_name(m.lower()) == query_base and m not in suggestions if get_base_name(m.lower()) == query_base and m not in suggestions
] ]
suggestions.extend(base_matches) suggestions.extend(base_matches)
# 5. Family/variant matching # 5. Family/variant matching
model_families = { model_families = {
"llama": ["llama2", "llama3", "alpaca", "vicuna", "codellama"], 'llama': ['llama2', 'llama3', 'alpaca', 'vicuna', 'codellama'],
"qwen": ["qwen", "qwen2", "qwen3"], 'qwen': ['qwen', 'qwen2', 'qwen3'],
"gemma": ["gemma", "gemma2"], 'gemma': ['gemma', 'gemma2'],
"phi": ["phi", "phi2", "phi3"], 'phi': ['phi', 'phi2', 'phi3'],
"mistral": ["mistral", "mixtral", "openhermes"], 'mistral': ['mistral', 'mixtral', 'openhermes'],
"dolphin": ["dolphin", "openchat"], 'dolphin': ['dolphin', 'openchat'],
"deepseek": ["deepseek", "deepseek-coder"], 'deepseek': ['deepseek', 'deepseek-coder']
} }
query_family = None query_family = None
for family, variants in model_families.items(): for family, variants in model_families.items():
if any(variant in query_lower for variant in variants): if any(variant in query_lower for variant in variants):
query_family = family query_family = family
break break
if query_family: if query_family:
family_variants = model_families[query_family] family_variants = model_families[query_family]
family_matches = [ family_matches = [
m m for m in available_models
for m in available_models
if any(variant in m.lower() for variant in family_variants) and m not in suggestions if any(variant in m.lower() for variant in family_variants) and m not in suggestions
] ]
suggestions.extend(family_matches) suggestions.extend(family_matches)
# 6. Use difflib for remaining fuzzy matches # 6. Use difflib for remaining fuzzy matches
remaining_models = [m for m in available_models if m not in suggestions] remaining_models = [m for m in available_models if m not in suggestions]
difflib_matches = difflib.get_close_matches(query_lower, remaining_models, n=3, cutoff=0.4) difflib_matches = difflib.get_close_matches(query_lower, remaining_models, n=3, cutoff=0.4)
suggestions.extend(difflib_matches) suggestions.extend(difflib_matches)
return suggestions[:8] # Return top 8 suggestions return suggestions[:8] # Return top 8 suggestions
@@ -169,13 +99,15 @@ def search_ollama_models_fuzzy(query: str, available_models: list[str]) -> list[
# Remove this too - no need for fallback # Remove this too - no need for fallback
def suggest_similar_models(invalid_model: str, available_models: list[str]) -> list[str]: def suggest_similar_models(invalid_model: str, available_models: List[str]) -> List[str]:
"""Use difflib to find similar model names""" """Use difflib to find similar model names"""
if not available_models: if not available_models:
return [] return []
# Get close matches using fuzzy matching # Get close matches using fuzzy matching
suggestions = difflib.get_close_matches(invalid_model, available_models, n=3, cutoff=0.3) suggestions = difflib.get_close_matches(
invalid_model, available_models, n=3, cutoff=0.3
)
return suggestions return suggestions
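An illustrative call of the difflib matching used by suggest_similar_models; the model names below are made up for the example.

import difflib

available = ["llama3:8b", "qwen2:7b", "mistral:7b", "phi3:mini"]
print(difflib.get_close_matches("lama3", available, n=3, cutoff=0.3))
# closest matches first, e.g. ['llama3:8b']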
@@ -183,50 +115,49 @@ def check_hf_model_exists(model_name: str) -> bool:
"""Quick check if HuggingFace model exists without downloading""" """Quick check if HuggingFace model exists without downloading"""
try: try:
from huggingface_hub import model_info from huggingface_hub import model_info
model_info(model_name) model_info(model_name)
return True return True
except Exception: except Exception:
return False return False
def get_popular_hf_models() -> list[str]: def get_popular_hf_models() -> List[str]:
"""Return a list of popular HuggingFace models for suggestions""" """Return a list of popular HuggingFace models for suggestions"""
try: try:
from huggingface_hub import list_models from huggingface_hub import list_models
# Get popular text-generation models, sorted by downloads # Get popular text-generation models, sorted by downloads
models = list_models( models = list_models(
filter="text-generation", filter="text-generation",
sort="downloads", sort="downloads",
direction=-1, direction=-1,
limit=20, # Get top 20 most downloaded limit=20 # Get top 20 most downloaded
) )
# Extract model names and filter for chat/conversation models # Extract model names and filter for chat/conversation models
model_names = [] model_names = []
chat_keywords = ["chat", "instruct", "dialog", "conversation", "assistant"] chat_keywords = ['chat', 'instruct', 'dialog', 'conversation', 'assistant']
for model in models: for model in models:
model_name = model.id if hasattr(model, "id") else str(model) model_name = model.id if hasattr(model, 'id') else str(model)
# Prioritize models with chat-related keywords # Prioritize models with chat-related keywords
if any(keyword in model_name.lower() for keyword in chat_keywords): if any(keyword in model_name.lower() for keyword in chat_keywords):
model_names.append(model_name) model_names.append(model_name)
elif len(model_names) < 10: # Fill up with other popular models elif len(model_names) < 10: # Fill up with other popular models
model_names.append(model_name) model_names.append(model_name)
return model_names[:10] if model_names else _get_fallback_hf_models() return model_names[:10] if model_names else _get_fallback_hf_models()
except Exception: except Exception:
# Fallback to static list if API call fails # Fallback to static list if API call fails
return _get_fallback_hf_models() return _get_fallback_hf_models()
def _get_fallback_hf_models() -> list[str]: def _get_fallback_hf_models() -> List[str]:
"""Fallback list of popular HuggingFace models""" """Fallback list of popular HuggingFace models"""
return [ return [
"microsoft/DialoGPT-medium", "microsoft/DialoGPT-medium",
"microsoft/DialoGPT-large", "microsoft/DialoGPT-large",
"facebook/blenderbot-400M-distill", "facebook/blenderbot-400M-distill",
"microsoft/phi-2", "microsoft/phi-2",
"deepseek-ai/deepseek-llm-7b-chat", "deepseek-ai/deepseek-llm-7b-chat",
@@ -234,44 +165,44 @@ def _get_fallback_hf_models() -> list[str]:
"facebook/blenderbot_small-90M", "facebook/blenderbot_small-90M",
"microsoft/phi-1_5", "microsoft/phi-1_5",
"facebook/opt-350m", "facebook/opt-350m",
"EleutherAI/gpt-neo-1.3B", "EleutherAI/gpt-neo-1.3B"
] ]
def search_hf_models_fuzzy(query: str, limit: int = 10) -> list[str]: def search_hf_models_fuzzy(query: str, limit: int = 10) -> List[str]:
"""Use HuggingFace Hub's native fuzzy search for model suggestions""" """Use HuggingFace Hub's native fuzzy search for model suggestions"""
try: try:
from huggingface_hub import list_models from huggingface_hub import list_models
# HF Hub's search is already fuzzy! It handles typos and partial matches # HF Hub's search is already fuzzy! It handles typos and partial matches
models = list_models( models = list_models(
search=query, search=query,
filter="text-generation", filter="text-generation",
sort="downloads", sort="downloads",
direction=-1, direction=-1,
limit=limit, limit=limit
) )
model_names = [model.id if hasattr(model, "id") else str(model) for model in models] model_names = [model.id if hasattr(model, 'id') else str(model) for model in models]
# If direct search doesn't return enough results, try some variations # If direct search doesn't return enough results, try some variations
if len(model_names) < 3: if len(model_names) < 3:
# Try searching for partial matches or common variations # Try searching for partial matches or common variations
variations = [] variations = []
# Extract base name (e.g., "gpt3" from "gpt-3.5") # Extract base name (e.g., "gpt3" from "gpt-3.5")
base_query = query.lower().replace("-", "").replace(".", "").replace("_", "") base_query = query.lower().replace('-', '').replace('.', '').replace('_', '')
if base_query != query.lower(): if base_query != query.lower():
variations.append(base_query) variations.append(base_query)
# Try common model name patterns # Try common model name patterns
if "gpt" in query.lower(): if 'gpt' in query.lower():
variations.extend(["gpt2", "gpt-neo", "gpt-j", "dialoGPT"]) variations.extend(['gpt2', 'gpt-neo', 'gpt-j', 'dialoGPT'])
elif "llama" in query.lower(): elif 'llama' in query.lower():
variations.extend(["llama2", "alpaca", "vicuna"]) variations.extend(['llama2', 'alpaca', 'vicuna'])
elif "bert" in query.lower(): elif 'bert' in query.lower():
variations.extend(["roberta", "distilbert", "albert"]) variations.extend(['roberta', 'distilbert', 'albert'])
# Search with variations # Search with variations
for var in variations[:2]: # Limit to 2 variations to avoid too many API calls for var in variations[:2]: # Limit to 2 variations to avoid too many API calls
try: try:
@@ -280,15 +211,13 @@ def search_hf_models_fuzzy(query: str, limit: int = 10) -> list[str]:
filter="text-generation", filter="text-generation",
sort="downloads", sort="downloads",
direction=-1, direction=-1,
limit=3, limit=3
) )
var_names = [ var_names = [model.id if hasattr(model, 'id') else str(model) for model in var_models]
model.id if hasattr(model, "id") else str(model) for model in var_models
]
model_names.extend(var_names) model_names.extend(var_names)
except Exception: except:
continue continue
# Remove duplicates while preserving order # Remove duplicates while preserving order
seen = set() seen = set()
unique_models = [] unique_models = []
@@ -296,102 +225,50 @@ def search_hf_models_fuzzy(query: str, limit: int = 10) -> list[str]:
if model not in seen: if model not in seen:
seen.add(model) seen.add(model)
unique_models.append(model) unique_models.append(model)
return unique_models[:limit] return unique_models[:limit]
except Exception: except Exception:
# If search fails, return empty list # If search fails, return empty list
return [] return []
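# --- Usage sketch (assumption: network access and the huggingface_hub package are
# available). The Hub search itself tolerates typos, so a misspelled query still
# returns plausible candidates; the printed names are examples, not guaranteed output.
for candidate in search_hf_models_fuzzy("deepsek-llm", limit=5):
    print(candidate)                                   # e.g. "deepseek-ai/deepseek-llm-7b-chat", ...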
def search_hf_models(query: str, limit: int = 10) -> list[str]: def search_hf_models(query: str, limit: int = 10) -> List[str]:
"""Simple search for HuggingFace models based on query (kept for backward compatibility)""" """Simple search for HuggingFace models based on query (kept for backward compatibility)"""
return search_hf_models_fuzzy(query, limit) return search_hf_models_fuzzy(query, limit)
def validate_model_and_suggest( def validate_model_and_suggest(model_name: str, llm_type: str) -> Optional[str]:
model_name: str, llm_type: str, host: str = "http://localhost:11434"
) -> Optional[str]:
"""Validate model name and provide suggestions if invalid""" """Validate model name and provide suggestions if invalid"""
if llm_type == "ollama": if llm_type == "ollama":
available_models = check_ollama_models(host) available_models = check_ollama_models()
if available_models and model_name not in available_models: if available_models and model_name not in available_models:
# Use intelligent fuzzy search based on locally installed models
suggestions = search_ollama_models_fuzzy(model_name, available_models)
error_msg = f"Model '{model_name}' not found in your local Ollama installation." error_msg = f"Model '{model_name}' not found in your local Ollama installation."
if suggestions:
# Check if the model exists remotely and get available tags error_msg += "\n\nDid you mean one of these installed models?\n"
model_exists_remotely, available_tags = check_ollama_model_exists_remotely(model_name) for i, suggestion in enumerate(suggestions, 1):
error_msg += f" {i}. {suggestion}\n"
if model_exists_remotely and model_name in available_tags:
# Exact model exists remotely - suggest pulling it
error_msg += "\n\nTo install the requested model:\n"
error_msg += f" ollama pull {model_name}\n"
# Show local alternatives
suggestions = search_ollama_models_fuzzy(model_name, available_models)
if suggestions:
error_msg += "\nOr use one of these similar installed models:\n"
for i, suggestion in enumerate(suggestions, 1):
error_msg += f" {i}. {suggestion}\n"
elif model_exists_remotely and available_tags:
# Base model exists but requested tag doesn't - suggest correct tags
base_model = model_name.split(":")[0]
requested_tag = model_name.split(":", 1)[1] if ":" in model_name else None
error_msg += (
f"\n\nModel '{base_model}' exists, but tag '{requested_tag}' is not available."
)
error_msg += f"\n\nAvailable {base_model} models you can install:\n"
for i, tag in enumerate(available_tags[:8], 1):
error_msg += f" {i}. ollama pull {tag}\n"
if len(available_tags) > 8:
error_msg += f" ... and {len(available_tags) - 8} more variants\n"
# Also show local alternatives
suggestions = search_ollama_models_fuzzy(model_name, available_models)
if suggestions:
error_msg += "\nOr use one of these similar installed models:\n"
for i, suggestion in enumerate(suggestions, 1):
error_msg += f" {i}. {suggestion}\n"
else: else:
# Model doesn't exist remotely - show fuzzy suggestions error_msg += "\n\nYour installed models:\n"
suggestions = search_ollama_models_fuzzy(model_name, available_models) for i, model in enumerate(available_models[:8], 1):
error_msg += f"\n\nModel '{model_name}' was not found in Ollama's library." error_msg += f" {i}. {model}\n"
if len(available_models) > 8:
if suggestions: error_msg += f" ... and {len(available_models) - 8} more\n"
error_msg += (
"\n\nDid you mean one of these installed models?\n" error_msg += "\nTo list all models: ollama list"
+ "\nTry to use ollama pull to install the model you need\n" error_msg += "\nTo download a new model: ollama pull <model_name>"
) error_msg += "\nBrowse models: https://ollama.com/library"
for i, suggestion in enumerate(suggestions, 1):
error_msg += f" {i}. {suggestion}\n"
else:
error_msg += "\n\nYour installed models:\n"
for i, model in enumerate(available_models[:8], 1):
error_msg += f" {i}. {model}\n"
if len(available_models) > 8:
error_msg += f" ... and {len(available_models) - 8} more\n"
error_msg += "\n\nCommands:"
error_msg += "\n ollama list # List installed models"
if model_exists_remotely and available_tags:
if model_name in available_tags:
error_msg += f"\n ollama pull {model_name} # Install requested model"
else:
error_msg += (
f"\n ollama pull {available_tags[0]} # Install recommended variant"
)
error_msg += "\n https://ollama.com/library # Browse available models"
return error_msg return error_msg
elif llm_type == "hf": elif llm_type == "hf":
# For HF models, we can do a quick existence check # For HF models, we can do a quick existence check
if not check_hf_model_exists(model_name): if not check_hf_model_exists(model_name):
# Use HF Hub's native fuzzy search directly # Use HF Hub's native fuzzy search directly
search_suggestions = search_hf_models_fuzzy(model_name, limit=8) search_suggestions = search_hf_models_fuzzy(model_name, limit=8)
error_msg = f"Model '{model_name}' not found on HuggingFace Hub." error_msg = f"Model '{model_name}' not found on HuggingFace Hub."
if search_suggestions: if search_suggestions:
error_msg += "\n\nDid you mean one of these?\n" error_msg += "\n\nDid you mean one of these?\n"
@@ -403,10 +280,10 @@ def validate_model_and_suggest(
error_msg += "\n\nPopular chat models:\n" error_msg += "\n\nPopular chat models:\n"
for i, model in enumerate(popular_models[:5], 1): for i, model in enumerate(popular_models[:5], 1):
error_msg += f" {i}. {model}\n" error_msg += f" {i}. {model}\n"
error_msg += f"\nSearch more: https://huggingface.co/models?search={model_name}&pipeline_tag=text-generation" error_msg += f"\nSearch more: https://huggingface.co/models?search={model_name}&pipeline_tag=text-generation"
return error_msg return error_msg
return None # Model is valid or we can't check return None # Model is valid or we can't check
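# --- Minimal caller sketch: a non-None return value is a ready-to-raise error message,
# which is how the chat classes below consume this function (model name is an example).
error = validate_model_and_suggest("lama3", "ollama")
if error:
    raise ValueError(error)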
@@ -469,61 +346,38 @@ class OllamaChat(LLMInterface):
# Check if the Ollama server is responsive # Check if the Ollama server is responsive
if host: if host:
requests.get(host) requests.get(host)
# Pre-check model availability with helpful suggestions # Pre-check model availability with helpful suggestions
model_error = validate_model_and_suggest(model, "ollama", host) model_error = validate_model_and_suggest(model, "ollama")
if model_error: if model_error:
raise ValueError(model_error) raise ValueError(model_error)
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"The 'requests' library is required for Ollama. Please install it with 'pip install requests'." "The 'requests' library is required for Ollama. Please install it with 'pip install requests'."
) )
except requests.exceptions.ConnectionError: except requests.exceptions.ConnectionError:
logger.error(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.") logger.error(
f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
)
raise ConnectionError( raise ConnectionError(
f"Could not connect to Ollama at {host}. Please ensure Ollama is running." f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
) )
def ask(self, prompt: str, **kwargs) -> str: def ask(self, prompt: str, **kwargs) -> str:
import requests
import json import json
import requests
full_url = f"{self.host}/api/generate" full_url = f"{self.host}/api/generate"
# Handle thinking budget for reasoning models
options = kwargs.copy()
thinking_budget = kwargs.get("thinking_budget")
if thinking_budget:
# Remove thinking_budget from options as it's not a standard Ollama option
options.pop("thinking_budget", None)
# Only apply reasoning parameters to models that support it
reasoning_supported_models = [
"gpt-oss:20b",
"gpt-oss:120b",
"deepseek-r1",
"deepseek-coder",
]
if thinking_budget in ["low", "medium", "high"]:
if any(model in self.model.lower() for model in reasoning_supported_models):
options["reasoning"] = {"effort": thinking_budget, "exclude": False}
logger.info(f"Applied reasoning effort={thinking_budget} to model {self.model}")
else:
logger.warning(
f"Thinking budget '{thinking_budget}' requested but model '{self.model}' may not support reasoning parameters. Proceeding without reasoning."
)
payload = { payload = {
"model": self.model, "model": self.model,
"prompt": prompt, "prompt": prompt,
"stream": False, # Keep it simple for now "stream": False, # Keep it simple for now
"options": options, "options": kwargs,
} }
logger.debug(f"Sending request to Ollama: {payload}") logger.debug(f"Sending request to Ollama: {payload}")
try: try:
logger.info("Sending request to Ollama and waiting for response...") logger.info(f"Sending request to Ollama and waiting for response...")
response = requests.post(full_url, data=json.dumps(payload)) response = requests.post(full_url, data=json.dumps(payload))
response.raise_for_status() response.raise_for_status()
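# --- Standalone sketch of the option handling above: thinking_budget is stripped from
# the generic kwargs and, for models treated as reasoning-capable, translated into the
# "reasoning" option sent to Ollama. The model list here is an illustrative subset of
# the one used above; this is a sketch of the mapping, not a definitive Ollama contract.
def build_ollama_options(model: str, **kwargs) -> dict:
    options = dict(kwargs)
    budget = options.pop("thinking_budget", None)
    if budget in ("low", "medium", "high") and any(
        tag in model.lower() for tag in ("gpt-oss", "deepseek-r1")
    ):
        options["reasoning"] = {"effort": budget, "exclude": False}
    return options

# build_ollama_options("gpt-oss:20b", temperature=0.2, thinking_budget="high")
# -> {"temperature": 0.2, "reasoning": {"effort": "high", "exclude": False}}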
@@ -543,19 +397,19 @@ class OllamaChat(LLMInterface):
class HFChat(LLMInterface): class HFChat(LLMInterface):
"""LLM interface for local Hugging Face Transformers models with proper chat templates.""" """LLM interface for local Hugging Face Transformers models."""
def __init__(self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat"): def __init__(self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat"):
logger.info(f"Initializing HFChat with model='{model_name}'") logger.info(f"Initializing HFChat with model='{model_name}'")
# Pre-check model availability with helpful suggestions # Pre-check model availability with helpful suggestions
model_error = validate_model_and_suggest(model_name, "hf") model_error = validate_model_and_suggest(model_name, "hf")
if model_error: if model_error:
raise ValueError(model_error) raise ValueError(model_error)
try: try:
from transformers.pipelines import pipeline
import torch import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"The 'transformers' and 'torch' libraries are required for Hugging Face models. Please install them with 'pip install transformers torch'." "The 'transformers' and 'torch' libraries are required for Hugging Face models. Please install them with 'pip install transformers torch'."
@@ -563,123 +417,54 @@ class HFChat(LLMInterface):
# Auto-detect device # Auto-detect device
if torch.cuda.is_available(): if torch.cuda.is_available():
self.device = "cuda" device = "cuda"
logger.info("CUDA is available. Using GPU.") logger.info("CUDA is available. Using GPU.")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
self.device = "mps" device = "mps"
logger.info("MPS is available. Using Apple Silicon GPU.") logger.info("MPS is available. Using Apple Silicon GPU.")
else: else:
self.device = "cpu" device = "cpu"
logger.info("No GPU detected. Using CPU.") logger.info("No GPU detected. Using CPU.")
# Load tokenizer and model with timeout protection self.pipeline = pipeline("text-generation", model=model_name, device=device)
try:
import signal
def timeout_handler(signum, frame):
raise TimeoutError("Model download/loading timed out")
# Set timeout for model loading (60 seconds)
old_handler = signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(60)
try:
logger.info(f"Loading tokenizer for {model_name}...")
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
logger.info(f"Loading model {model_name}...")
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
device_map="auto" if self.device != "cpu" else None,
trust_remote_code=True,
)
logger.info(f"Successfully loaded {model_name}")
finally:
signal.alarm(0) # Cancel the alarm
signal.signal(signal.SIGALRM, old_handler) # Restore old handler
except TimeoutError:
logger.error(f"Model loading timed out for {model_name}")
raise RuntimeError(
f"Model loading timed out for {model_name}. Please check your internet connection or try a smaller model."
)
except Exception as e:
logger.error(f"Failed to load model {model_name}: {e}")
raise
# Move model to device if not using device_map
if self.device != "cpu" and "device_map" not in str(self.model):
self.model = self.model.to(self.device)
# Set pad token if not present
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
def ask(self, prompt: str, **kwargs) -> str: def ask(self, prompt: str, **kwargs) -> str:
print("kwargs in HF: ", kwargs) # Map OpenAI-style arguments to Hugging Face equivalents
# Check if this is a Qwen model and add /no_think by default if "max_tokens" in kwargs:
is_qwen_model = "qwen" in self.model.config._name_or_path.lower() # Prefer user-provided max_new_tokens if both are present
kwargs.setdefault("max_new_tokens", kwargs["max_tokens"])
# Remove the unsupported key to avoid errors in Transformers
kwargs.pop("max_tokens")
# For Qwen models, automatically add /no_think to the prompt # Handle temperature=0 edge-case for greedy decoding
if is_qwen_model and "/no_think" not in prompt and "/think" not in prompt: if "temperature" in kwargs and kwargs["temperature"] == 0.0:
prompt = prompt + " /no_think" # Remove unsupported zero temperature and use deterministic generation
kwargs.pop("temperature")
kwargs.setdefault("do_sample", False)
# Prepare chat template # Sensible defaults for text generation
messages = [{"role": "user", "content": prompt}] params = {"max_length": 500, "num_return_sequences": 1, **kwargs}
logger.info(f"Generating text with Hugging Face model with params: {params}")
results = self.pipeline(prompt, **params)
# Apply chat template if available # Handle different response formats from transformers
if hasattr(self.tokenizer, "apply_chat_template"): if isinstance(results, list) and len(results) > 0:
try: generated_text = (
formatted_prompt = self.tokenizer.apply_chat_template( results[0].get("generated_text", "")
messages, tokenize=False, add_generation_prompt=True if isinstance(results[0], dict)
) else str(results[0])
except Exception as e: )
logger.warning(f"Chat template failed, using raw prompt: {e}")
formatted_prompt = prompt
else: else:
# Fallback for models without chat template generated_text = str(results)
formatted_prompt = prompt
# Tokenize input # Extract only the newly generated portion by removing the original prompt
inputs = self.tokenizer( if isinstance(generated_text, str) and generated_text.startswith(prompt):
formatted_prompt, response = generated_text[len(prompt) :].strip()
return_tensors="pt", else:
padding=True, # Fallback: return the full response if prompt removal fails
truncation=True, response = str(generated_text)
max_length=2048,
)
# Move inputs to device return response
if self.device != "cpu":
inputs = {k: v.to(self.device) for k, v in inputs.items()}
# Set generation parameters
generation_config = {
"max_new_tokens": kwargs.get("max_tokens", kwargs.get("max_new_tokens", 512)),
"temperature": kwargs.get("temperature", 0.7),
"top_p": kwargs.get("top_p", 0.9),
"do_sample": kwargs.get("temperature", 0.7) > 0,
"pad_token_id": self.tokenizer.eos_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
}
# Handle temperature=0 for greedy decoding
if generation_config["temperature"] == 0.0:
generation_config["do_sample"] = False
generation_config.pop("temperature")
logger.info(f"Generating with HuggingFace model, config: {generation_config}")
# Generate
with torch.no_grad():
outputs = self.model.generate(**inputs, **generation_config)
# Decode response
generated_tokens = outputs[0][inputs["input_ids"].shape[1] :]
response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
return response.strip()
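# --- Minimal sketch of the chat-template step used above, in isolation. The model name
# is only an example; the exact formatted string depends on that model's template.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
formatted = tok.apply_chat_template(
    [{"role": "user", "content": "Hello"}],
    tokenize=False,
    add_generation_prompt=True,
)
print(formatted)   # prompt string wrapped in the model's chat markup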
class OpenAIChat(LLMInterface): class OpenAIChat(LLMInterface):
@@ -710,38 +495,15 @@ class OpenAIChat(LLMInterface):
params = { params = {
"model": self.model, "model": self.model,
"messages": [{"role": "user", "content": prompt}], "messages": [{"role": "user", "content": prompt}],
"max_tokens": kwargs.get("max_tokens", 1000),
"temperature": kwargs.get("temperature", 0.7), "temperature": kwargs.get("temperature", 0.7),
**{
k: v
for k, v in kwargs.items()
if k not in ["max_tokens", "temperature"]
},
} }
# Handle max_tokens vs max_completion_tokens based on model
max_tokens = kwargs.get("max_tokens", 1000)
if "o3" in self.model or "o4" in self.model or "o1" in self.model:
# o-series models use max_completion_tokens
params["max_completion_tokens"] = max_tokens
params["temperature"] = 1.0
else:
# Other models use max_tokens
params["max_tokens"] = max_tokens
# Handle thinking budget for reasoning models
thinking_budget = kwargs.get("thinking_budget")
if thinking_budget and thinking_budget in ["low", "medium", "high"]:
# Check if this is an o-series model (partial match for model names)
o_series_models = ["o3", "o3-mini", "o4-mini", "o1", "o3-pro", "o3-deep-research"]
if any(model in self.model for model in o_series_models):
# Use the correct OpenAI reasoning parameter format
params["reasoning_effort"] = thinking_budget
logger.info(f"Applied reasoning_effort={thinking_budget} to model {self.model}")
else:
logger.warning(
f"Thinking budget '{thinking_budget}' requested but model '{self.model}' may not support reasoning parameters. Proceeding without reasoning."
)
# Add other kwargs (excluding thinking_budget as it's handled above)
for k, v in kwargs.items():
if k not in ["max_tokens", "temperature", "thinking_budget"]:
params[k] = v
logger.info(f"Sending request to OpenAI with model {self.model}") logger.info(f"Sending request to OpenAI with model {self.model}")
try: try:
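# --- Side sketch of the parameter split above: o-series models take
# "max_completion_tokens" (with temperature pinned to 1.0), while other chat models
# keep the classic "max_tokens" field. The model-name tags mirror the checks above;
# this is a condensed illustration, not the full request-building logic.
def token_params(model: str, max_tokens: int) -> dict:
    if any(tag in model for tag in ("o1", "o3", "o4")):
        return {"max_completion_tokens": max_tokens, "temperature": 1.0}
    return {"max_tokens": max_tokens}

# token_params("o3-mini", 1000) -> {"max_completion_tokens": 1000, "temperature": 1.0}
# token_params("gpt-4.1", 1000) -> {"max_tokens": 1000}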
@@ -761,7 +523,7 @@ class SimulatedChat(LLMInterface):
return "This is a simulated answer from the LLM based on the retrieved context." return "This is a simulated answer from the LLM based on the retrieved context."
def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface: def get_llm(llm_config: Optional[Dict[str, Any]] = None) -> LLMInterface:
""" """
Factory function to get an LLM interface based on configuration. Factory function to get an LLM interface based on configuration.



@@ -5,59 +5,18 @@ from pathlib import Path
from llama_index.core import SimpleDirectoryReader from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter from llama_index.core.node_parser import SentenceSplitter
from .api import LeannBuilder, LeannChat, LeannSearcher from .api import LeannBuilder, LeannSearcher, LeannChat
def extract_pdf_text_with_pymupdf(file_path: str) -> str:
"""Extract text from PDF using PyMuPDF for better quality."""
try:
import fitz # PyMuPDF
doc = fitz.open(file_path)
text = ""
for page in doc:
text += page.get_text()
doc.close()
return text
except ImportError:
# Fallback to default reader
return None
def extract_pdf_text_with_pdfplumber(file_path: str) -> str:
"""Extract text from PDF using pdfplumber for better quality."""
try:
import pdfplumber
text = ""
with pdfplumber.open(file_path) as pdf:
for page in pdf.pages:
text += page.extract_text() or ""
return text
except ImportError:
# Fallback to default reader
return None
class LeannCLI: class LeannCLI:
def __init__(self): def __init__(self):
# Always use project-local .leann directory (like .git) self.indexes_dir = Path.home() / ".leann" / "indexes"
self.indexes_dir = Path.cwd() / ".leann" / "indexes"
self.indexes_dir.mkdir(parents=True, exist_ok=True) self.indexes_dir.mkdir(parents=True, exist_ok=True)
# Default parser for documents
self.node_parser = SentenceSplitter( self.node_parser = SentenceSplitter(
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n" chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
) )
# Code-optimized parser
self.code_parser = SentenceSplitter(
chunk_size=512, # Larger chunks for code context
chunk_overlap=50, # Less overlap to preserve function boundaries
separator="\n", # Split by lines for code
paragraph_separator="\n\n", # Preserve logical code blocks
)
def get_index_path(self, index_name: str) -> str: def get_index_path(self, index_name: str) -> str:
index_dir = self.indexes_dir / index_name index_dir = self.indexes_dir / index_name
return str(index_dir / "documents.leann") return str(index_dir / "documents.leann")
@@ -74,11 +33,10 @@ class LeannCLI:
formatter_class=argparse.RawDescriptionHelpFormatter, formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=""" epilog="""
Examples: Examples:
leann build my-docs --docs ./documents # Build index named my-docs leann build my-docs --docs ./documents # Build index named my-docs
leann build my-ppts --docs ./ --file-types .pptx,.pdf # Index only PowerPoint and PDF files leann search my-docs "query" # Search in my-docs index
leann search my-docs "query" # Search in my-docs index leann ask my-docs "question" # Ask my-docs index
leann ask my-docs "question" # Ask my-docs index leann list # List all stored indexes
leann list # List all stored indexes
""", """,
) )
@@ -86,34 +44,24 @@ Examples:
# Build command # Build command
build_parser = subparsers.add_parser("build", help="Build document index") build_parser = subparsers.add_parser("build", help="Build document index")
build_parser.add_argument("index_name", help="Index name")
build_parser.add_argument( build_parser.add_argument(
"index_name", nargs="?", help="Index name (default: current directory name)" "--docs", type=str, required=True, help="Documents directory"
)
build_parser.add_argument(
"--docs", type=str, default=".", help="Documents directory (default: current directory)"
) )
build_parser.add_argument( build_parser.add_argument(
"--backend", type=str, default="hnsw", choices=["hnsw", "diskann"] "--backend", type=str, default="hnsw", choices=["hnsw", "diskann"]
) )
build_parser.add_argument("--embedding-model", type=str, default="facebook/contriever")
build_parser.add_argument( build_parser.add_argument(
"--embedding-mode", "--embedding-model", type=str, default="facebook/contriever"
type=str, )
default="sentence-transformers", build_parser.add_argument(
choices=["sentence-transformers", "openai", "mlx", "ollama"], "--force", "-f", action="store_true", help="Force rebuild"
help="Embedding backend mode (default: sentence-transformers)",
) )
build_parser.add_argument("--force", "-f", action="store_true", help="Force rebuild")
build_parser.add_argument("--graph-degree", type=int, default=32) build_parser.add_argument("--graph-degree", type=int, default=32)
build_parser.add_argument("--complexity", type=int, default=64) build_parser.add_argument("--complexity", type=int, default=64)
build_parser.add_argument("--num-threads", type=int, default=1) build_parser.add_argument("--num-threads", type=int, default=1)
build_parser.add_argument("--compact", action="store_true", default=True) build_parser.add_argument("--compact", action="store_true", default=True)
build_parser.add_argument("--recompute", action="store_true", default=True) build_parser.add_argument("--recompute", action="store_true", default=True)
build_parser.add_argument(
"--file-types",
type=str,
help="Comma-separated list of file extensions to include (e.g., '.txt,.pdf,.pptx'). If not specified, uses default supported types.",
)
# Search command # Search command
search_parser = subparsers.add_parser("search", help="Search documents") search_parser = subparsers.add_parser("search", help="Search documents")
@@ -123,12 +71,7 @@ Examples:
search_parser.add_argument("--complexity", type=int, default=64) search_parser.add_argument("--complexity", type=int, default=64)
search_parser.add_argument("--beam-width", type=int, default=1) search_parser.add_argument("--beam-width", type=int, default=1)
search_parser.add_argument("--prune-ratio", type=float, default=0.0) search_parser.add_argument("--prune-ratio", type=float, default=0.0)
search_parser.add_argument( search_parser.add_argument("--recompute-embeddings", action="store_true")
"--recompute-embeddings",
action="store_true",
default=True,
help="Recompute embeddings (default: True)",
)
search_parser.add_argument( search_parser.add_argument(
"--pruning-strategy", "--pruning-strategy",
choices=["global", "local", "proportional"], choices=["global", "local", "proportional"],
@@ -151,370 +94,67 @@ Examples:
ask_parser.add_argument("--complexity", type=int, default=32) ask_parser.add_argument("--complexity", type=int, default=32)
ask_parser.add_argument("--beam-width", type=int, default=1) ask_parser.add_argument("--beam-width", type=int, default=1)
ask_parser.add_argument("--prune-ratio", type=float, default=0.0) ask_parser.add_argument("--prune-ratio", type=float, default=0.0)
ask_parser.add_argument( ask_parser.add_argument("--recompute-embeddings", action="store_true")
"--recompute-embeddings",
action="store_true",
default=True,
help="Recompute embeddings (default: True)",
)
ask_parser.add_argument( ask_parser.add_argument(
"--pruning-strategy", "--pruning-strategy",
choices=["global", "local", "proportional"], choices=["global", "local", "proportional"],
default="global", default="global",
) )
ask_parser.add_argument(
"--thinking-budget",
type=str,
choices=["low", "medium", "high"],
default=None,
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
)
# List command # List command
subparsers.add_parser("list", help="List all indexes") list_parser = subparsers.add_parser("list", help="List all indexes")
return parser return parser
def register_project_dir(self):
"""Register current project directory in global registry"""
global_registry = Path.home() / ".leann" / "projects.json"
global_registry.parent.mkdir(exist_ok=True)
current_dir = str(Path.cwd())
# Load existing registry
projects = []
if global_registry.exists():
try:
import json
with open(global_registry) as f:
projects = json.load(f)
except Exception:
projects = []
# Add current directory if not already present
if current_dir not in projects:
projects.append(current_dir)
# Save registry
import json
with open(global_registry, "w") as f:
json.dump(projects, f, indent=2)
def _build_gitignore_parser(self, docs_dir: str):
"""Build gitignore parser using gitignore-parser library."""
from gitignore_parser import parse_gitignore
# Try to parse the root .gitignore
gitignore_path = Path(docs_dir) / ".gitignore"
if gitignore_path.exists():
try:
# gitignore-parser automatically handles all subdirectory .gitignore files!
matches = parse_gitignore(str(gitignore_path))
print(f"📋 Loaded .gitignore from {docs_dir} (includes all subdirectories)")
return matches
except Exception as e:
print(f"Warning: Could not parse .gitignore: {e}")
else:
print("📋 No .gitignore found")
# Fallback: basic pattern matching for essential files
essential_patterns = {".git", ".DS_Store", "__pycache__", "node_modules", ".venv", "venv"}
def basic_matches(file_path):
path_parts = Path(file_path).parts
return any(part in essential_patterns for part in path_parts)
return basic_matches
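# --- Usage sketch for the matcher returned above. gitignore_parser exposes a single
# callable built from the root .gitignore; the paths below are illustrative.
from gitignore_parser import parse_gitignore

matches = parse_gitignore("/path/to/repo/.gitignore")
matches("/path/to/repo/.venv/lib/site.py")   # True  -> excluded from indexing
matches("/path/to/repo/src/app.py")          # False -> kept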
def _should_exclude_file(self, relative_path: Path, gitignore_matches) -> bool:
"""Check if a file should be excluded using gitignore parser."""
return gitignore_matches(str(relative_path))
def list_indexes(self): def list_indexes(self):
print("Stored LEANN indexes:") print("Stored LEANN indexes:")
# Get all project directories with .leann if not self.indexes_dir.exists():
global_registry = Path.home() / ".leann" / "projects.json" print(
all_projects = [] "No indexes found. Use 'leann build <name> --docs <dir>' to create one."
)
if global_registry.exists():
try:
import json
with open(global_registry) as f:
all_projects = json.load(f)
except Exception:
pass
# Filter to only existing directories with .leann
valid_projects = []
for project_dir in all_projects:
project_path = Path(project_dir)
if project_path.exists() and (project_path / ".leann" / "indexes").exists():
valid_projects.append(project_path)
# Add current project if it has .leann but not in registry
current_path = Path.cwd()
if (current_path / ".leann" / "indexes").exists() and current_path not in valid_projects:
valid_projects.append(current_path)
if not valid_projects:
print("No indexes found. Use 'leann build <name> --docs <dir>' to create one.")
return return
total_indexes = 0 index_dirs = [d for d in self.indexes_dir.iterdir() if d.is_dir()]
current_dir = Path.cwd()
for project_path in valid_projects: if not index_dirs:
indexes_dir = project_path / ".leann" / "indexes" print(
if not indexes_dir.exists(): "No indexes found. Use 'leann build <name> --docs <dir>' to create one."
continue )
return
index_dirs = [d for d in indexes_dir.iterdir() if d.is_dir()] print(f"Found {len(index_dirs)} indexes:")
if not index_dirs: for i, index_dir in enumerate(index_dirs, 1):
continue index_name = index_dir.name
status = "" if self.index_exists(index_name) else ""
# Show project header print(f" {i}. {index_name} [{status}]")
if project_path == current_dir: if self.index_exists(index_name):
print(f"\n📁 Current project ({project_path}):")
else:
print(f"\n📂 {project_path}:")
for index_dir in index_dirs:
total_indexes += 1
index_name = index_dir.name
meta_file = index_dir / "documents.leann.meta.json" meta_file = index_dir / "documents.leann.meta.json"
status = "" if meta_file.exists() else "" size_mb = sum(
f.stat().st_size for f in index_dir.iterdir() if f.is_file()
) / (1024 * 1024)
print(f" Size: {size_mb:.1f} MB")
print(f" {total_indexes}. {index_name} [{status}]") if index_dirs:
if status == "": example_name = index_dirs[0].name
size_mb = sum(f.stat().st_size for f in index_dir.iterdir() if f.is_file()) / ( print(f"\nUsage:")
1024 * 1024 print(f' leann search {example_name} "your query"')
) print(f" leann ask {example_name} --interactive")
print(f" Size: {size_mb:.1f} MB")
if total_indexes > 0: def load_documents(self, docs_dir: str):
print(f"\nTotal: {total_indexes} indexes across {len(valid_projects)} projects")
print("\nUsage (current project only):")
# Show example from current project
current_indexes_dir = current_dir / ".leann" / "indexes"
if current_indexes_dir.exists():
current_index_dirs = [d for d in current_indexes_dir.iterdir() if d.is_dir()]
if current_index_dirs:
example_name = current_index_dirs[0].name
print(f' leann search {example_name} "your query"')
print(f" leann ask {example_name} --interactive")
def load_documents(self, docs_dir: str, custom_file_types: str | None = None):
print(f"Loading documents from {docs_dir}...") print(f"Loading documents from {docs_dir}...")
if custom_file_types:
print(f"Using custom file types: {custom_file_types}")
# Build gitignore parser documents = SimpleDirectoryReader(
gitignore_matches = self._build_gitignore_parser(docs_dir) docs_dir,
recursive=True,
# Try to use better PDF parsers first, but only if PDFs are requested encoding="utf-8",
documents = [] required_exts=[".pdf", ".txt", ".md", ".docx"],
docs_path = Path(docs_dir) ).load_data(show_progress=True)
# Check if we should process PDFs
should_process_pdfs = custom_file_types is None or ".pdf" in custom_file_types
if should_process_pdfs:
for file_path in docs_path.rglob("*.pdf"):
# Check if file matches any exclude pattern
relative_path = file_path.relative_to(docs_path)
if self._should_exclude_file(relative_path, gitignore_matches):
continue
print(f"Processing PDF: {file_path}")
# Try PyMuPDF first (best quality)
text = extract_pdf_text_with_pymupdf(str(file_path))
if text is None:
# Try pdfplumber
text = extract_pdf_text_with_pdfplumber(str(file_path))
if text:
# Create a simple document structure
from llama_index.core import Document
doc = Document(text=text, metadata={"source": str(file_path)})
documents.append(doc)
else:
# Fallback to default reader
print(f"Using default reader for {file_path}")
try:
default_docs = SimpleDirectoryReader(
str(file_path.parent),
filename_as_id=True,
required_exts=[file_path.suffix],
).load_data()
documents.extend(default_docs)
except Exception as e:
print(f"Warning: Could not process {file_path}: {e}")
# Load other file types with default reader
if custom_file_types:
# Parse custom file types from comma-separated string
code_extensions = [ext.strip() for ext in custom_file_types.split(",") if ext.strip()]
# Ensure extensions start with a dot
code_extensions = [ext if ext.startswith(".") else f".{ext}" for ext in code_extensions]
else:
# Use default supported file types
code_extensions = [
# Original document types
".txt",
".md",
".docx",
".pptx",
# Code files for Claude Code integration
".py",
".js",
".ts",
".jsx",
".tsx",
".java",
".cpp",
".c",
".h",
".hpp",
".cs",
".go",
".rs",
".rb",
".php",
".swift",
".kt",
".scala",
".r",
".sql",
".sh",
".bash",
".zsh",
".fish",
".ps1",
".bat",
# Config and markup files
".json",
".yaml",
".yml",
".xml",
".toml",
".ini",
".cfg",
".conf",
".html",
".css",
".scss",
".less",
".vue",
".svelte",
# Data science
".ipynb",
".R",
".py",
".jl",
]
# Try to load other file types, but don't fail if none are found
try:
# Create a custom file filter function using our PathSpec
def file_filter(file_path: str) -> bool:
"""Return True if file should be included (not excluded)"""
try:
docs_path_obj = Path(docs_dir)
file_path_obj = Path(file_path)
relative_path = file_path_obj.relative_to(docs_path_obj)
return not self._should_exclude_file(relative_path, gitignore_matches)
except (ValueError, OSError):
return True # Include files that can't be processed
other_docs = SimpleDirectoryReader(
docs_dir,
recursive=True,
encoding="utf-8",
required_exts=code_extensions,
file_extractor={}, # Use default extractors
filename_as_id=True,
).load_data(show_progress=True)
# Filter documents after loading based on gitignore rules
filtered_docs = []
for doc in other_docs:
file_path = doc.metadata.get("file_path", "")
if file_filter(file_path):
filtered_docs.append(doc)
documents.extend(filtered_docs)
except ValueError as e:
if "No files found" in str(e):
print("No additional files found for other supported types.")
else:
raise e
all_texts = [] all_texts = []
# Define code file extensions for intelligent chunking
code_file_exts = {
".py",
".js",
".ts",
".jsx",
".tsx",
".java",
".cpp",
".c",
".h",
".hpp",
".cs",
".go",
".rs",
".rb",
".php",
".swift",
".kt",
".scala",
".r",
".sql",
".sh",
".bash",
".zsh",
".fish",
".ps1",
".bat",
".json",
".yaml",
".yml",
".xml",
".toml",
".ini",
".cfg",
".conf",
".html",
".css",
".scss",
".less",
".vue",
".svelte",
".ipynb",
".R",
".jl",
}
for doc in documents: for doc in documents:
# Check if this is a code file based on source path nodes = self.node_parser.get_nodes_from_documents([doc])
source_path = doc.metadata.get("source", "")
is_code_file = any(source_path.endswith(ext) for ext in code_file_exts)
# Use appropriate parser based on file type
parser = self.code_parser if is_code_file else self.node_parser
nodes = parser.get_nodes_from_documents([doc])
for node in nodes: for node in nodes:
all_texts.append(node.get_content()) all_texts.append(node.get_content())
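# --- Minimal chunking sketch showing the code-oriented splitter profile defined in
# __init__ (chunk_size=512, overlap=50, line separator). The sample text and the
# resulting node count are illustrative only.
from llama_index.core import Document
from llama_index.core.node_parser import SentenceSplitter

code_parser = SentenceSplitter(
    chunk_size=512, chunk_overlap=50, separator="\n", paragraph_separator="\n\n"
)
nodes = code_parser.get_nodes_from_documents(
    [Document(text="def f():\n    return 1\n\n\ndef g():\n    return 2\n")]
)
print(len(nodes), nodes[0].get_content()[:40])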
@@ -523,23 +163,15 @@ Examples:
async def build_index(self, args): async def build_index(self, args):
docs_dir = args.docs docs_dir = args.docs
# Use current directory name if index_name not provided index_name = args.index_name
if args.index_name:
index_name = args.index_name
else:
index_name = Path.cwd().name
print(f"Using current directory name as index: '{index_name}'")
index_dir = self.indexes_dir / index_name index_dir = self.indexes_dir / index_name
index_path = self.get_index_path(index_name) index_path = self.get_index_path(index_name)
print(f"📂 Indexing: {Path(docs_dir).resolve()}")
if index_dir.exists() and not args.force: if index_dir.exists() and not args.force:
print(f"Index '{index_name}' already exists. Use --force to rebuild.") print(f"Index '{index_name}' already exists. Use --force to rebuild.")
return return
all_texts = self.load_documents(docs_dir, args.file_types) all_texts = self.load_documents(docs_dir)
if not all_texts: if not all_texts:
print("No documents found") print("No documents found")
return return
@@ -551,7 +183,6 @@ Examples:
builder = LeannBuilder( builder = LeannBuilder(
backend_name=args.backend, backend_name=args.backend,
embedding_model=args.embedding_model, embedding_model=args.embedding_model,
embedding_mode=args.embedding_mode,
graph_degree=args.graph_degree, graph_degree=args.graph_degree,
complexity=args.complexity, complexity=args.complexity,
is_compact=args.compact, is_compact=args.compact,
@@ -565,9 +196,6 @@ Examples:
builder.build_index(index_path) builder.build_index(index_path)
print(f"Index built at {index_path}") print(f"Index built at {index_path}")
# Register this project directory in global registry
self.register_project_dir()
async def search_documents(self, args): async def search_documents(self, args):
index_name = args.index_name index_name = args.index_name
query = args.query query = args.query
@@ -628,11 +256,6 @@ Examples:
if not user_input: if not user_input:
continue continue
# Prepare LLM kwargs with thinking budget if specified
llm_kwargs = {}
if args.thinking_budget:
llm_kwargs["thinking_budget"] = args.thinking_budget
response = chat.ask( response = chat.ask(
user_input, user_input,
top_k=args.top_k, top_k=args.top_k,
@@ -641,17 +264,11 @@ Examples:
prune_ratio=args.prune_ratio, prune_ratio=args.prune_ratio,
recompute_embeddings=args.recompute_embeddings, recompute_embeddings=args.recompute_embeddings,
pruning_strategy=args.pruning_strategy, pruning_strategy=args.pruning_strategy,
llm_kwargs=llm_kwargs,
) )
print(f"LEANN: {response}") print(f"LEANN: {response}")
else: else:
query = input("Enter your question: ").strip() query = input("Enter your question: ").strip()
if query: if query:
# Prepare LLM kwargs with thinking budget if specified
llm_kwargs = {}
if args.thinking_budget:
llm_kwargs["thinking_budget"] = args.thinking_budget
response = chat.ask( response = chat.ask(
query, query,
top_k=args.top_k, top_k=args.top_k,
@@ -660,7 +277,6 @@ Examples:
prune_ratio=args.prune_ratio, prune_ratio=args.prune_ratio,
recompute_embeddings=args.recompute_embeddings, recompute_embeddings=args.recompute_embeddings,
pruning_strategy=args.pruning_strategy, pruning_strategy=args.pruning_strategy,
llm_kwargs=llm_kwargs,
) )
print(f"LEANN: {response}") print(f"LEANN: {response}")


@@ -4,13 +4,11 @@ Consolidates all embedding computation logic using SentenceTransformer
Preserves all optimization parameters to ensure performance Preserves all optimization parameters to ensure performance
""" """
import logging
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any
import numpy as np import numpy as np
import torch import torch
from typing import List, Dict, Any
import logging
import os
# Set up logger with proper level # Set up logger with proper level
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -19,11 +17,11 @@ log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
logger.setLevel(log_level) logger.setLevel(log_level)
# Global model cache to avoid repeated loading # Global model cache to avoid repeated loading
_model_cache: dict[str, Any] = {} _model_cache: Dict[str, Any] = {}
def compute_embeddings( def compute_embeddings(
texts: list[str], texts: List[str],
model_name: str, model_name: str,
mode: str = "sentence-transformers", mode: str = "sentence-transformers",
is_build: bool = False, is_build: bool = False,
@@ -36,7 +34,7 @@ def compute_embeddings(
Args: Args:
texts: List of texts to compute embeddings for texts: List of texts to compute embeddings for
model_name: Model name model_name: Model name
mode: Computation mode ('sentence-transformers', 'openai', 'mlx', 'ollama') mode: Computation mode ('sentence-transformers', 'openai', 'mlx')
is_build: Whether this is a build operation (shows progress bar) is_build: Whether this is a build operation (shows progress bar)
batch_size: Batch size for processing batch_size: Batch size for processing
adaptive_optimization: Whether to use adaptive optimization based on batch size adaptive_optimization: Whether to use adaptive optimization based on batch size
@@ -56,14 +54,12 @@ def compute_embeddings(
return compute_embeddings_openai(texts, model_name) return compute_embeddings_openai(texts, model_name)
elif mode == "mlx": elif mode == "mlx":
return compute_embeddings_mlx(texts, model_name) return compute_embeddings_mlx(texts, model_name)
elif mode == "ollama":
return compute_embeddings_ollama(texts, model_name, is_build=is_build)
else: else:
raise ValueError(f"Unsupported embedding mode: {mode}") raise ValueError(f"Unsupported embedding mode: {mode}")
def compute_embeddings_sentence_transformers( def compute_embeddings_sentence_transformers(
texts: list[str], texts: List[str],
model_name: str, model_name: str,
use_fp16: bool = True, use_fp16: bool = True,
device: str = "auto", device: str = "auto",
@@ -105,7 +101,7 @@ def compute_embeddings_sentence_transformers(
if device == "mps": if device == "mps":
batch_size = 128 # MPS optimal batch size from benchmark batch_size = 128 # MPS optimal batch size from benchmark
if model_name == "Qwen/Qwen3-Embedding-0.6B": if model_name == "Qwen/Qwen3-Embedding-0.6B":
batch_size = 32 batch_size = 64
elif device == "cuda": elif device == "cuda":
batch_size = 256 # CUDA optimal batch size batch_size = 256 # CUDA optimal batch size
# Keep original batch_size for CPU # Keep original batch_size for CPU
@@ -118,7 +114,9 @@ def compute_embeddings_sentence_transformers(
logger.info(f"Using cached optimized model: {model_name}") logger.info(f"Using cached optimized model: {model_name}")
model = _model_cache[cache_key] model = _model_cache[cache_key]
else: else:
logger.info(f"Loading and caching optimized SentenceTransformer model: {model_name}") logger.info(
f"Loading and caching optimized SentenceTransformer model: {model_name}"
)
from sentence_transformers import SentenceTransformer from sentence_transformers import SentenceTransformer
logger.info(f"Using device: {device}") logger.info(f"Using device: {device}")
@@ -136,7 +134,9 @@ def compute_embeddings_sentence_transformers(
if hasattr(torch.mps, "set_per_process_memory_fraction"): if hasattr(torch.mps, "set_per_process_memory_fraction"):
torch.mps.set_per_process_memory_fraction(0.9) torch.mps.set_per_process_memory_fraction(0.9)
except AttributeError: except AttributeError:
logger.warning("Some MPS optimizations not available in this PyTorch version") logger.warning(
"Some MPS optimizations not available in this PyTorch version"
)
elif device == "cpu": elif device == "cpu":
# TODO: Haven't tested this yet # TODO: Haven't tested this yet
torch.set_num_threads(min(8, os.cpu_count() or 4)) torch.set_num_threads(min(8, os.cpu_count() or 4))
@@ -226,22 +226,25 @@ def compute_embeddings_sentence_transformers(
device=device, device=device,
) )
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}") logger.info(
f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
)
# Validate results # Validate results
if np.isnan(embeddings).any() or np.isinf(embeddings).any(): if np.isnan(embeddings).any() or np.isinf(embeddings).any():
raise RuntimeError(f"Detected NaN or Inf values in embeddings, model: {model_name}") raise RuntimeError(
f"Detected NaN or Inf values in embeddings, model: {model_name}"
)
return embeddings return embeddings
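# --- Usage sketch: the model name and mode mirror the defaults used by the CLI builder;
# output shape is (n_texts, 768) for facebook/contriever.
vecs = compute_embeddings(
    ["LEANN builds a compact vector index.", "Embeddings are recomputed at query time."],
    model_name="facebook/contriever",
    mode="sentence-transformers",
)
print(vecs.shape)   # (2, 768)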
def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray: def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
# TODO: @yichuan-w add progress bar only in build mode # TODO: @yichuan-w add progress bar only in build mode
"""Compute embeddings using OpenAI API""" """Compute embeddings using OpenAI API"""
try: try:
import os
import openai import openai
import os
except ImportError as e: except ImportError as e:
raise ImportError(f"OpenAI package not installed: {e}") raise ImportError(f"OpenAI package not installed: {e}")
@@ -261,10 +264,9 @@ def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
logger.info( logger.info(
f"Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'" f"Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
) )
print(f"len of texts: {len(texts)}")
# OpenAI has limits on batch size and input length # OpenAI has limits on batch size and input length
max_batch_size = 1000 # Conservative batch size max_batch_size = 100 # Conservative batch size
all_embeddings = [] all_embeddings = []
try: try:
@@ -291,12 +293,15 @@ def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
raise raise
embeddings = np.array(all_embeddings, dtype=np.float32) embeddings = np.array(all_embeddings, dtype=np.float32)
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}") logger.info(
print(f"len of embeddings: {len(embeddings)}") f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
)
return embeddings return embeddings
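# --- Condensed sketch of the batched OpenAI call above using the v1 client. Assumes
# OPENAI_API_KEY is set; the model name and batch size are illustrative, not the
# project's fixed choices.
import numpy as np
from openai import OpenAI

def embed_openai(texts: list[str], model: str = "text-embedding-3-small", batch: int = 1000) -> np.ndarray:
    client = OpenAI()
    out: list[list[float]] = []
    for i in range(0, len(texts), batch):
        resp = client.embeddings.create(model=model, input=texts[i : i + batch])
        out.extend(d.embedding for d in resp.data)   # responses preserve input order
    return np.asarray(out, dtype=np.float32)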
def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int = 16) -> np.ndarray: def compute_embeddings_mlx(
chunks: List[str], model_name: str, batch_size: int = 16
) -> np.ndarray:
# TODO: @yichuan-w add progress bar only in build mode # TODO: @yichuan-w add progress bar only in build mode
"""Computes embeddings using an MLX model.""" """Computes embeddings using an MLX model."""
try: try:
@@ -368,262 +373,3 @@ def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int =
# Stack numpy arrays # Stack numpy arrays
return np.stack(all_embeddings) return np.stack(all_embeddings)
def compute_embeddings_ollama(
texts: list[str], model_name: str, is_build: bool = False, host: str = "http://localhost:11434"
) -> np.ndarray:
"""
Compute embeddings using Ollama API.
Args:
texts: List of texts to compute embeddings for
model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large")
is_build: Whether this is a build operation (shows progress bar)
host: Ollama host URL (default: http://localhost:11434)
Returns:
Normalized embeddings array, shape: (len(texts), embedding_dim)
"""
try:
import requests
except ImportError:
raise ImportError(
"The 'requests' library is required for Ollama embeddings. Install with: uv pip install requests"
)
if not texts:
raise ValueError("Cannot compute embeddings for empty text list")
logger.info(
f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}'"
)
# Check if Ollama is running
try:
response = requests.get(f"{host}/api/version", timeout=5)
response.raise_for_status()
except requests.exceptions.ConnectionError:
error_msg = (
f"❌ Could not connect to Ollama at {host}.\n\n"
"Please ensure Ollama is running:\n"
" • macOS/Linux: ollama serve\n"
" • Windows: Make sure Ollama is running in the system tray\n\n"
"Installation: https://ollama.com/download"
)
raise RuntimeError(error_msg)
except Exception as e:
raise RuntimeError(f"Unexpected error connecting to Ollama: {e}")
# Check if model exists and provide helpful suggestions
try:
response = requests.get(f"{host}/api/tags", timeout=5)
response.raise_for_status()
models = response.json()
model_names = [model["name"] for model in models.get("models", [])]
# Filter for embedding models (models that support embeddings)
embedding_models = []
suggested_embedding_models = [
"nomic-embed-text",
"mxbai-embed-large",
"bge-m3",
"all-minilm",
"snowflake-arctic-embed",
]
for model in model_names:
# Check if it's an embedding model (by name patterns or known models)
base_name = model.split(":")[0]
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5"]):
embedding_models.append(model)
# Check if model exists (handle versioned names)
model_found = any(
model_name == name.split(":")[0] or model_name == name for name in model_names
)
if not model_found:
error_msg = f"❌ Model '{model_name}' not found in local Ollama.\n\n"
# Suggest pulling the model
error_msg += "📦 To install this embedding model:\n"
error_msg += f" ollama pull {model_name}\n\n"
# Show available embedding models
if embedding_models:
error_msg += "✅ Available embedding models:\n"
for model in embedding_models[:5]:
error_msg += f"{model}\n"
if len(embedding_models) > 5:
error_msg += f" ... and {len(embedding_models) - 5} more\n"
else:
error_msg += "💡 Popular embedding models to install:\n"
for model in suggested_embedding_models[:3]:
error_msg += f" • ollama pull {model}\n"
error_msg += "\n📚 Browse more: https://ollama.com/library"
raise ValueError(error_msg)
# Verify the model supports embeddings by testing it
try:
test_response = requests.post(
f"{host}/api/embeddings", json={"model": model_name, "prompt": "test"}, timeout=10
)
if test_response.status_code != 200:
error_msg = (
f"⚠️ Model '{model_name}' exists but may not support embeddings.\n\n"
f"Please use an embedding model like:\n"
)
for model in suggested_embedding_models[:3]:
error_msg += f"{model}\n"
raise ValueError(error_msg)
except requests.exceptions.RequestException:
# If test fails, continue anyway - model might still work
pass
except requests.exceptions.RequestException as e:
logger.warning(f"Could not verify model existence: {e}")
# Process embeddings with optimized concurrent processing
import requests
def get_single_embedding(text_idx_tuple):
"""Helper function to get embedding for a single text."""
text, idx = text_idx_tuple
max_retries = 3
retry_count = 0
# Truncate very long texts to avoid API issues
truncated_text = text[:8000] if len(text) > 8000 else text
while retry_count < max_retries:
try:
response = requests.post(
f"{host}/api/embeddings",
json={"model": model_name, "prompt": truncated_text},
timeout=30,
)
response.raise_for_status()
result = response.json()
embedding = result.get("embedding")
if embedding is None:
raise ValueError(f"No embedding returned for text {idx}")
return idx, embedding
except requests.exceptions.Timeout:
retry_count += 1
if retry_count >= max_retries:
logger.warning(f"Timeout for text {idx} after {max_retries} retries")
return idx, None
except Exception as e:
if retry_count >= max_retries - 1:
logger.error(f"Failed to get embedding for text {idx}: {e}")
return idx, None
retry_count += 1
return idx, None
# Determine if we should use concurrent processing
use_concurrent = (
len(texts) > 5 and not is_build
) # Don't use concurrent in build mode to avoid overwhelming
max_workers = min(4, len(texts)) # Limit concurrent requests to avoid overwhelming Ollama
all_embeddings = [None] * len(texts) # Pre-allocate list to maintain order
failed_indices = []
if use_concurrent:
logger.info(
f"Using concurrent processing with {max_workers} workers for {len(texts)} texts"
)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit all tasks
future_to_idx = {
executor.submit(get_single_embedding, (text, idx)): idx
for idx, text in enumerate(texts)
}
# Add progress bar for concurrent processing
try:
if is_build or len(texts) > 10:
from tqdm import tqdm
futures_iterator = tqdm(
as_completed(future_to_idx),
total=len(texts),
desc="Computing Ollama embeddings",
)
else:
futures_iterator = as_completed(future_to_idx)
except ImportError:
futures_iterator = as_completed(future_to_idx)
# Collect results as they complete
for future in futures_iterator:
try:
idx, embedding = future.result()
if embedding is not None:
all_embeddings[idx] = embedding
else:
failed_indices.append(idx)
except Exception as e:
idx = future_to_idx[future]
logger.error(f"Exception for text {idx}: {e}")
failed_indices.append(idx)
else:
# Sequential processing with progress bar
show_progress = is_build or len(texts) > 10
try:
if show_progress:
from tqdm import tqdm
iterator = tqdm(
enumerate(texts), total=len(texts), desc="Computing Ollama embeddings"
)
else:
iterator = enumerate(texts)
except ImportError:
iterator = enumerate(texts)
for idx, text in iterator:
result_idx, embedding = get_single_embedding((text, idx))
if embedding is not None:
all_embeddings[idx] = embedding
else:
failed_indices.append(idx)
# Handle failed embeddings
if failed_indices:
if len(failed_indices) == len(texts):
raise RuntimeError("Failed to compute any embeddings")
logger.warning(f"Failed to compute embeddings for {len(failed_indices)}/{len(texts)} texts")
# Use zero embeddings as fallback for failed ones
valid_embedding = next((e for e in all_embeddings if e is not None), None)
if valid_embedding:
embedding_dim = len(valid_embedding)
for idx in failed_indices:
all_embeddings[idx] = [0.0] * embedding_dim
# Remove None values and convert to numpy array
all_embeddings = [e for e in all_embeddings if e is not None]
# Convert to numpy array and normalize
embeddings = np.array(all_embeddings, dtype=np.float32)
# Normalize embeddings (L2 normalization)
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
embeddings = embeddings / (norms + 1e-8) # Add small epsilon to avoid division by zero
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
return embeddings
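
For reference, the helper above is a thin client over Ollama's embeddings endpoint. A minimal sketch of the underlying request it issues per chunk, assuming Ollama's default local port and using `nomic-embed-text` purely as an example model:

```bash
# Same request shape the helper sends for each text chunk; the response JSON
# carries the vector under the "embedding" key (host and model are illustrative).
curl -s http://localhost:11434/api/embeddings \
  -d '{"model": "nomic-embed-text", "prompt": "test"}'
```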

View File

@@ -1,14 +1,12 @@
import time
import atexit import atexit
import logging
import os
import signal
import socket import socket
import subprocess import subprocess
import sys import sys
import time import os
import logging
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import psutil import psutil
# Set up logging based on environment variable # Set up logging based on environment variable
@@ -20,24 +18,6 @@ logging.basicConfig(
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _is_colab_environment() -> bool:
"""Check if we're running in Google Colab environment."""
return "COLAB_GPU" in os.environ or "COLAB_TPU" in os.environ
def _get_available_port(start_port: int = 5557) -> int:
"""Get an available port starting from start_port."""
port = start_port
while port < start_port + 100: # Try up to 100 ports
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("localhost", port))
return port
except OSError:
port += 1
raise RuntimeError(f"No available ports found in range {start_port}-{start_port + 100}")
def _check_port(port: int) -> bool: def _check_port(port: int) -> bool:
"""Check if a port is in use""" """Check if a port is in use"""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
@@ -195,69 +175,68 @@ class EmbeddingServerManager:
embedding_mode: str = "sentence-transformers", embedding_mode: str = "sentence-transformers",
**kwargs, **kwargs,
) -> tuple[bool, int]: ) -> tuple[bool, int]:
"""Start the embedding server.""" """
Starts the embedding server process.
Args:
port (int): The preferred ZMQ port for the server.
model_name (str): The name of the embedding model to use.
**kwargs: Additional arguments for the server.
Returns:
tuple[bool, int]: (success, actual_port_used)
"""
passages_file = kwargs.get("passages_file") passages_file = kwargs.get("passages_file")
assert isinstance(passages_file, str), "passages_file must be a string"
# Check if we have a compatible server already running # Check if we have a compatible running server
if self._has_compatible_running_server(model_name, passages_file): if self._has_compatible_running_server(model_name, passages_file):
logger.info("Found compatible running server!") assert self.server_port is not None, (
return True, port "a compatible running server should set server_port"
)
return True, self.server_port
# For Colab environment, use a different strategy # Find available port (compatible or free)
if _is_colab_environment():
logger.info("Detected Colab environment, using alternative startup strategy")
return self._start_server_colab(port, model_name, embedding_mode, **kwargs)
# Find a compatible port or next available
actual_port, is_compatible = _find_compatible_port_or_next_available(
port, model_name, passages_file
)
if is_compatible:
logger.info(f"Found compatible server on port {actual_port}")
return True, actual_port
# Start a new server
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
def _start_server_colab(
self,
port: int,
model_name: str,
embedding_mode: str = "sentence-transformers",
**kwargs,
) -> tuple[bool, int]:
"""Start server with Colab-specific configuration."""
# Try to find an available port
try: try:
actual_port = _get_available_port(port) actual_port, is_compatible = _find_compatible_port_or_next_available(
except RuntimeError: port, model_name, passages_file
logger.error("No available ports found") )
except RuntimeError as e:
logger.error(str(e))
return False, port return False, port
logger.info(f"Starting server on port {actual_port} for Colab environment") if is_compatible:
logger.info(f"Using existing compatible server on port {actual_port}")
self.server_port = actual_port
self.server_process = None # We don't own this process
return True, actual_port
# Use a simpler startup strategy for Colab if actual_port != port:
command = self._build_server_command(actual_port, model_name, embedding_mode, **kwargs) logger.info(f"Using port {actual_port} instead of {port}")
try: # Start new server
# In Colab, we'll use a more direct approach return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
self._launch_server_process_colab(command, actual_port)
return self._wait_for_server_ready_colab(actual_port)
except Exception as e:
logger.error(f"Failed to start embedding server in Colab: {e}")
return False, actual_port
def _has_compatible_running_server(self, model_name: str, passages_file: str) -> bool: def _has_compatible_running_server(
self, model_name: str, passages_file: str
) -> bool:
"""Check if we have a compatible running server.""" """Check if we have a compatible running server."""
if not (self.server_process and self.server_process.poll() is None and self.server_port): if not (
self.server_process
and self.server_process.poll() is None
and self.server_port
):
return False return False
if _check_process_matches_config(self.server_port, model_name, passages_file): if _check_process_matches_config(self.server_port, model_name, passages_file):
logger.info(f"Existing server process (PID {self.server_process.pid}) is compatible") logger.info(
f"Existing server process (PID {self.server_process.pid}) is compatible"
)
return True return True
logger.info("Existing server process is incompatible. Should start a new server.") logger.info(
"Existing server process is incompatible. Should start a new server."
)
return False return False
def _start_new_server( def _start_new_server(
@@ -290,13 +269,9 @@ class EmbeddingServerManager:
] ]
if kwargs.get("passages_file"): if kwargs.get("passages_file"):
# Convert to absolute path to ensure subprocess can find the file command.extend(["--passages-file", str(kwargs["passages_file"])])
passages_file = Path(kwargs["passages_file"]).resolve()
command.extend(["--passages-file", str(passages_file)])
if embedding_mode != "sentence-transformers": if embedding_mode != "sentence-transformers":
command.extend(["--embedding-mode", embedding_mode]) command.extend(["--embedding-mode", embedding_mode])
if kwargs.get("distance_metric"):
command.extend(["--distance-metric", kwargs["distance_metric"]])
return command return command
@@ -305,24 +280,13 @@ class EmbeddingServerManager:
project_root = Path(__file__).parent.parent.parent.parent.parent project_root = Path(__file__).parent.parent.parent.parent.parent
logger.info(f"Command: {' '.join(command)}") logger.info(f"Command: {' '.join(command)}")
# In CI environment, redirect output to avoid buffer deadlock # Let server output go directly to console
# Embedding servers use many print statements that can fill buffers # The server will respect LEANN_LOG_LEVEL environment variable
is_ci = os.environ.get("CI") == "true"
if is_ci:
stdout_target = subprocess.DEVNULL
stderr_target = subprocess.DEVNULL
logger.info("CI environment detected, redirecting embedding server output to DEVNULL")
else:
stdout_target = None # Direct to console for visible logs
stderr_target = None # Direct to console for visible logs
# IMPORTANT: Use a new session so we can manage the whole process group reliably
self.server_process = subprocess.Popen( self.server_process = subprocess.Popen(
command, command,
cwd=project_root, cwd=project_root,
stdout=stdout_target, stdout=None, # Direct to console
stderr=stderr_target, stderr=None, # Direct to console
start_new_session=True,
) )
self.server_port = port self.server_port = port
logger.info(f"Server process started with PID: {self.server_process.pid}") logger.info(f"Server process started with PID: {self.server_process.pid}")
@@ -364,79 +328,21 @@ class EmbeddingServerManager:
logger.info( logger.info(
f"Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..." f"Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
) )
# Try terminating the whole process group first (POSIX) self.server_process.terminate()
try:
pgid = os.getpgid(self.server_process.pid)
os.killpg(pgid, signal.SIGTERM)
except Exception:
# Fallback to terminating just the process
self.server_process.terminate()
try: try:
self.server_process.wait(timeout=3) self.server_process.wait(timeout=5)
logger.info(f"Server process {self.server_process.pid} terminated.") logger.info(f"Server process {self.server_process.pid} terminated.")
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
logger.warning( logger.warning(
f"Server process {self.server_process.pid} did not terminate gracefully within 3 seconds, killing it." f"Server process {self.server_process.pid} did not terminate gracefully, killing it."
) )
try: self.server_process.kill()
pgid = os.getpgid(self.server_process.pid)
os.killpg(pgid, signal.SIGKILL) # Clean up process resources to prevent resource tracker warnings
except Exception: try:
self.server_process.kill() self.server_process.wait() # Ensure process is fully cleaned up
try: except Exception:
self.server_process.wait(timeout=2) pass
logger.info(f"Server process {self.server_process.pid} killed successfully.")
except subprocess.TimeoutExpired:
logger.error(
f"Failed to kill server process {self.server_process.pid} - it may be hung"
)
# Don't hang indefinitely
# Clean up process resources without waiting
# The process should already be terminated/killed above
# Don't wait here as it can hang CI indefinitely
self.server_process = None self.server_process = None
def _launch_server_process_colab(self, command: list, port: int) -> None:
"""Launch the server process with Colab-specific settings."""
logger.info(f"Colab Command: {' '.join(command)}")
# In Colab, redirect to DEVNULL to avoid pipe blocking
# PIPE without reading can cause hangs
self.server_process = subprocess.Popen(
command,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
text=True,
)
self.server_port = port
logger.info(f"Colab server process started with PID: {self.server_process.pid}")
# Register atexit callback
if not self._atexit_registered:
atexit.register(lambda: self.stop_server() if self.server_process else None)
self._atexit_registered = True
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
"""Wait for the server to be ready with Colab-specific timeout."""
max_wait, wait_interval = 30, 0.5 # Shorter timeout for Colab
for _ in range(int(max_wait / wait_interval)):
if _check_port(port):
logger.info("Colab embedding server is ready!")
return True, port
if self.server_process and self.server_process.poll() is not None:
# Check for error output
stdout, stderr = self.server_process.communicate()
logger.error("Colab server terminated during startup.")
logger.error(f"stdout: {stdout}")
logger.error(f"stderr: {stderr}")
return False, port
time.sleep(wait_interval)
logger.error(f"Colab server failed to start within {max_wait} seconds.")
self.stop_server()
return False, port

View File

@@ -1,14 +1,15 @@
 from abc import ABC, abstractmethod
-from typing import Any, Literal, Optional
 import numpy as np
+from typing import Dict, Any, List, Literal, Optional

 class LeannBackendBuilderInterface(ABC):
     """Backend interface for building indexes"""

     @abstractmethod
-    def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs) -> None:
+    def build(
+        self, data: np.ndarray, ids: List[str], index_path: str, **kwargs
+    ) -> None:
         """Build index

         Args:
@@ -52,7 +53,7 @@ class LeannBackendSearcherInterface(ABC):
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
         zmq_port: Optional[int] = None,
         **kwargs,
-    ) -> dict[str, Any]:
+    ) -> Dict[str, Any]:
         """Search for nearest neighbors

         Args:

View File

@@ -1,176 +0,0 @@
#!/usr/bin/env python3
import json
import subprocess
import sys
def handle_request(request):
if request.get("method") == "initialize":
return {
"jsonrpc": "2.0",
"id": request.get("id"),
"result": {
"capabilities": {"tools": {}},
"protocolVersion": "2024-11-05",
"serverInfo": {"name": "leann-mcp", "version": "1.0.0"},
},
}
elif request.get("method") == "tools/list":
return {
"jsonrpc": "2.0",
"id": request.get("id"),
"result": {
"tools": [
{
"name": "leann_search",
"description": """🔍 Search code using natural language - like having a coding assistant who knows your entire codebase!
🎯 **Perfect for**:
- "How does authentication work?" → finds auth-related code
- "Error handling patterns" → locates try-catch blocks and error logic
- "Database connection setup" → finds DB initialization code
- "API endpoint definitions" → locates route handlers
- "Configuration management" → finds config files and usage
💡 **Pro tip**: Use this before making any changes to understand existing patterns and conventions.""",
"inputSchema": {
"type": "object",
"properties": {
"index_name": {
"type": "string",
"description": "Name of the LEANN index to search. Use 'leann_list' first to see available indexes.",
},
"query": {
"type": "string",
"description": "Search query - can be natural language (e.g., 'how to handle errors') or technical terms (e.g., 'async function definition')",
},
"top_k": {
"type": "integer",
"default": 5,
"minimum": 1,
"maximum": 20,
"description": "Number of search results to return. Use 5-10 for focused results, 15-20 for comprehensive exploration.",
},
"complexity": {
"type": "integer",
"default": 32,
"minimum": 16,
"maximum": 128,
"description": "Search complexity level. Use 16-32 for fast searches (recommended), 64+ for higher precision when needed.",
},
},
"required": ["index_name", "query"],
},
},
{
"name": "leann_status",
"description": "📊 Check the health and stats of your code indexes - like a medical checkup for your codebase knowledge!",
"inputSchema": {
"type": "object",
"properties": {
"index_name": {
"type": "string",
"description": "Optional: Name of specific index to check. If not provided, shows status of all indexes.",
}
},
},
},
{
"name": "leann_list",
"description": "📋 Show all your indexed codebases - your personal code library! Use this to see what's available for search.",
"inputSchema": {"type": "object", "properties": {}},
},
]
},
}
elif request.get("method") == "tools/call":
tool_name = request["params"]["name"]
args = request["params"].get("arguments", {})
try:
if tool_name == "leann_search":
# Validate required parameters
if not args.get("index_name") or not args.get("query"):
return {
"jsonrpc": "2.0",
"id": request.get("id"),
"result": {
"content": [
{
"type": "text",
"text": "Error: Both index_name and query are required",
}
]
},
}
# Build simplified command
cmd = [
"leann",
"search",
args["index_name"],
args["query"],
f"--top-k={args.get('top_k', 5)}",
f"--complexity={args.get('complexity', 32)}",
]
result = subprocess.run(cmd, capture_output=True, text=True)
elif tool_name == "leann_status":
if args.get("index_name"):
# Check specific index status - for now, we'll use leann list and filter
result = subprocess.run(["leann", "list"], capture_output=True, text=True)
# We could enhance this to show more detailed status per index
else:
# Show all indexes status
result = subprocess.run(["leann", "list"], capture_output=True, text=True)
elif tool_name == "leann_list":
result = subprocess.run(["leann", "list"], capture_output=True, text=True)
return {
"jsonrpc": "2.0",
"id": request.get("id"),
"result": {
"content": [
{
"type": "text",
"text": result.stdout
if result.returncode == 0
else f"Error: {result.stderr}",
}
]
},
}
except Exception as e:
return {
"jsonrpc": "2.0",
"id": request.get("id"),
"error": {"code": -1, "message": str(e)},
}
def main():
for line in sys.stdin:
try:
request = json.loads(line.strip())
response = handle_request(request)
if response:
print(json.dumps(response))
sys.stdout.flush()
except Exception as e:
error_response = {
"jsonrpc": "2.0",
"id": None,
"error": {"code": -1, "message": str(e)},
}
print(json.dumps(error_response))
sys.stdout.flush()
if __name__ == "__main__":
main()
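
Since the server above speaks JSON-RPC 2.0 over stdio, a single tool call can be exercised from the shell. This is only a sketch: it assumes the script is exposed as a `leann_mcp` command and uses a placeholder index name and query.

```bash
# Pipe one tools/call request into the stdio server and print its JSON-RPC reply.
echo '{"jsonrpc": "2.0", "id": 1, "method": "tools/call", "params": {"name": "leann_search", "arguments": {"index_name": "my-project", "query": "error handling"}}}' | leann_mcp
```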

View File

@@ -1,13 +1,13 @@
 # packages/leann-core/src/leann/registry.py
+from typing import Dict, TYPE_CHECKING
 import importlib
 import importlib.metadata
-from typing import TYPE_CHECKING

 if TYPE_CHECKING:
     from leann.interface import LeannBackendFactoryInterface

-BACKEND_REGISTRY: dict[str, "LeannBackendFactoryInterface"] = {}
+BACKEND_REGISTRY: Dict[str, "LeannBackendFactoryInterface"] = {}

 def register_backend(name: str):
@@ -31,11 +31,13 @@ def autodiscover_backends():
             backend_module_name = dist_name.replace("-", "_")
             discovered_backends.append(backend_module_name)

-    for backend_module_name in sorted(discovered_backends):  # sort for deterministic loading
+    for backend_module_name in sorted(
+        discovered_backends
+    ):  # sort for deterministic loading
         try:
             importlib.import_module(backend_module_name)
             # Registration message is printed by the decorator
-        except ImportError:
+        except ImportError as e:
             # print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
             pass
     # print("INFO: Backend auto-discovery finished.")

View File

@@ -1,7 +1,7 @@
import json import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import Any, Literal, Optional from typing import Dict, Any, Literal, Optional
import numpy as np import numpy as np
@@ -38,7 +38,9 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
self.embedding_model = self.meta.get("embedding_model") self.embedding_model = self.meta.get("embedding_model")
if not self.embedding_model: if not self.embedding_model:
print("WARNING: embedding_model not found in meta.json. Recompute will fail.") print(
"WARNING: embedding_model not found in meta.json. Recompute will fail."
)
self.embedding_mode = self.meta.get("embedding_mode", "sentence-transformers") self.embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
@@ -46,40 +48,39 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
backend_module_name=backend_module_name, backend_module_name=backend_module_name,
) )
def _load_meta(self) -> dict[str, Any]: def _load_meta(self) -> Dict[str, Any]:
"""Loads the metadata file associated with the index.""" """Loads the metadata file associated with the index."""
# This is the corrected logic for finding the meta file. # This is the corrected logic for finding the meta file.
meta_path = self.index_dir / f"{self.index_path.name}.meta.json" meta_path = self.index_dir / f"{self.index_path.name}.meta.json"
if not meta_path.exists(): if not meta_path.exists():
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}") raise FileNotFoundError(f"Leann metadata file not found at {meta_path}")
with open(meta_path, encoding="utf-8") as f: with open(meta_path, "r", encoding="utf-8") as f:
return json.load(f) return json.load(f)
def _ensure_server_running(self, passages_source_file: str, port: int, **kwargs) -> int: def _ensure_server_running(
self, passages_source_file: str, port: int, **kwargs
) -> int:
""" """
Ensures the embedding server is running if recompute is needed. Ensures the embedding server is running if recompute is needed.
This is a helper for subclasses. This is a helper for subclasses.
""" """
if not self.embedding_model: if not self.embedding_model:
raise ValueError("Cannot use recompute mode without 'embedding_model' in meta.json.") raise ValueError(
"Cannot use recompute mode without 'embedding_model' in meta.json."
# Get distance_metric from meta if not provided in kwargs )
distance_metric = (
kwargs.get("distance_metric")
or self.meta.get("backend_kwargs", {}).get("distance_metric")
or "mips"
)
server_started, actual_port = self.embedding_server_manager.start_server( server_started, actual_port = self.embedding_server_manager.start_server(
port=port, port=port,
model_name=self.embedding_model, model_name=self.embedding_model,
embedding_mode=self.embedding_mode, embedding_mode=self.embedding_mode,
passages_file=passages_source_file, passages_file=passages_source_file,
distance_metric=distance_metric, distance_metric=kwargs.get("distance_metric"),
enable_warmup=kwargs.get("enable_warmup", False), enable_warmup=kwargs.get("enable_warmup", False),
) )
if not server_started: if not server_started:
raise RuntimeError(f"Failed to start embedding server on port {actual_port}") raise RuntimeError(
f"Failed to start embedding server on port {actual_port}"
)
return actual_port return actual_port
@@ -108,10 +109,11 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
# on that port? # on that port?
# Ensure we have a server with passages_file for compatibility # Ensure we have a server with passages_file for compatibility
passages_source_file = self.index_dir / f"{self.index_path.name}.meta.json" passages_source_file = (
# Convert to absolute path to ensure server can find it self.index_dir / f"{self.index_path.name}.meta.json"
)
zmq_port = self._ensure_server_running( zmq_port = self._ensure_server_running(
str(passages_source_file.resolve()), zmq_port str(passages_source_file), zmq_port
) )
return self._compute_embedding_via_server([query], zmq_port)[ return self._compute_embedding_via_server([query], zmq_port)[
@@ -129,18 +131,13 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray: def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray:
"""Compute embeddings using the ZMQ embedding server.""" """Compute embeddings using the ZMQ embedding server."""
import msgpack
import zmq import zmq
import msgpack
context = None
socket = None
try: try:
context = zmq.Context() context = zmq.Context()
socket = context.socket(zmq.REQ) socket = context.socket(zmq.REQ)
socket.setsockopt(zmq.LINGER, 0) # Don't block on close socket.setsockopt(zmq.RCVTIMEO, 30000) # 30 second timeout
socket.setsockopt(zmq.RCVTIMEO, 5000) # 5 second timeout
socket.setsockopt(zmq.SNDTIMEO, 5000) # 5 second timeout
socket.setsockopt(zmq.IMMEDIATE, 1) # Don't wait for connection
socket.connect(f"tcp://localhost:{zmq_port}") socket.connect(f"tcp://localhost:{zmq_port}")
# Send embedding request # Send embedding request
@@ -152,6 +149,9 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
response_bytes = socket.recv() response_bytes = socket.recv()
response = msgpack.unpackb(response_bytes) response = msgpack.unpackb(response_bytes)
socket.close()
context.term()
# Convert response to numpy array # Convert response to numpy array
if isinstance(response, list) and len(response) > 0: if isinstance(response, list) and len(response) > 0:
return np.array(response, dtype=np.float32) return np.array(response, dtype=np.float32)
@@ -160,11 +160,6 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to compute embeddings via server: {e}") raise RuntimeError(f"Failed to compute embeddings via server: {e}")
finally:
if socket:
socket.close(linger=0)
if context:
context.term()
@abstractmethod @abstractmethod
def search( def search(
@@ -178,7 +173,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
pruning_strategy: Literal["global", "local", "proportional"] = "global", pruning_strategy: Literal["global", "local", "proportional"] = "global",
zmq_port: Optional[int] = None, zmq_port: Optional[int] = None,
**kwargs, **kwargs,
) -> dict[str, Any]: ) -> Dict[str, Any]:
""" """
Search for the top_k nearest neighbors of the query vector. Search for the top_k nearest neighbors of the query vector.
@@ -198,27 +193,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
""" """
pass pass
def cleanup(self): def __del__(self):
"""Cleanup resources including embedding server and ZMQ connections.""" """Ensures the embedding server is stopped when the searcher is destroyed."""
# Stop embedding server
if hasattr(self, "embedding_server_manager"): if hasattr(self, "embedding_server_manager"):
self.embedding_server_manager.stop_server() self.embedding_server_manager.stop_server()
# Set ZMQ linger but don't terminate global context
try:
import zmq
# Just set linger on the global instance
ctx = zmq.Context.instance()
ctx.linger = 0
# NEVER call ctx.term() on the global instance
except Exception:
pass
def __del__(self):
"""Ensures resources are cleaned up when the searcher is destroyed."""
try:
self.cleanup()
except Exception:
# Ignore errors during destruction
pass

View File

@@ -1,91 +0,0 @@
# 🔥 LEANN Claude Code Integration
Transform your development workflow with intelligent code assistance using LEANN's semantic search directly in Claude Code.
## Prerequisites
**Step 1:** First, complete the basic LEANN installation following the [📦 Installation guide](../../README.md#installation) in the root README:
```bash
uv venv
source .venv/bin/activate
uv pip install leann
```
**Step 2:** Install LEANN globally for MCP integration:
```bash
uv tool install leann-core
```
This makes the `leann` command available system-wide, which `leann_mcp` requires.
## 🚀 Quick Setup
Add the LEANN MCP server to Claude Code:
```bash
claude mcp add leann-server -- leann_mcp
```
## 🛠️ Available Tools
Once connected, you'll have access to these powerful semantic search tools in Claude Code:
- **`leann_list`** - List all available indexes across your projects
- **`leann_search`** - Perform semantic searches across code and documents
- **`leann_ask`** - Ask natural language questions and get AI-powered answers from your codebase
## 🎯 Quick Start Example
```bash
# Build an index for your project (change to your actual path)
leann build my-project --docs ./
# Start Claude Code
claude
```
**Try this in Claude Code:**
```
Help me understand this codebase. List available indexes and search for authentication patterns.
```
<p align="center">
<img src="../../assets/claude_code_leann.png" alt="LEANN in Claude Code" width="80%">
</p>
## 🧠 How It Works
The integration consists of three key components working seamlessly together:
- **`leann`** - Core CLI tool for indexing and searching (installed globally via `uv tool install`)
- **`leann_mcp`** - MCP server that wraps `leann` commands for Claude Code integration
- **Claude Code** - Calls `leann_mcp`, which executes `leann` commands and returns intelligent results (see the sketch below)
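
For a concrete sense of that flow, here is a rough sketch of the CLI calls `leann_mcp` issues when Claude Code invokes the `leann_list` and `leann_search` tools; the index name and query are placeholders, and the flags shown are just the tool defaults:

```bash
# Roughly what the MCP tool calls run under the hood (illustrative)
leann list
leann search my-project "authentication patterns" --top-k=5 --complexity=32
```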
## 📁 File Support
LEANN understands **30+ file types** including:
- **Programming**: Python, JavaScript, TypeScript, Java, Go, Rust, C++, C#
- **Data**: SQL, YAML, JSON, CSV, XML
- **Documentation**: Markdown, TXT, PDF
- **And many more!**
## 💾 Storage & Organization
- **Project indexes**: Stored in `.leann/` directory (just like `.git`)
- **Global registry**: Project tracking at `~/.leann/projects.json`
- **Multi-project support**: Switch between different codebases seamlessly
- **Portable**: Transfer indexes between machines with minimal overhead
## 🗑️ Uninstalling
To remove the LEANN MCP server from Claude Code:
```bash
claude mcp remove leann-server
```
To remove LEANN:
```bash
uv pip uninstall leann leann-backend-hnsw leann-core
```

View File

@@ -1,36 +0,0 @@
# LEANN - The smallest vector index in the world
LEANN is a revolutionary vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.
## Installation
```bash
# Default installation (includes both HNSW and DiskANN backends)
uv pip install leann
```
## Quick Start
```python
from leann import LeannBuilder, LeannSearcher, LeannChat
from pathlib import Path
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
# Build an index (choose backend: "hnsw" or "diskann")
builder = LeannBuilder(backend_name="hnsw") # or "diskann" for large-scale deployments
builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
builder.add_text("Tung Tung Tung Sahur called—they need their bananacrocodile hybrid back")
builder.build_index(INDEX_PATH)
# Search
searcher = LeannSearcher(INDEX_PATH)
results = searcher.search("fantastical AI-generated creatures", top_k=1)
# Chat with your data
chat = LeannChat(INDEX_PATH, llm_config={"type": "hf", "model": "Qwen/Qwen3-0.6B"})
response = chat.ask("How much storage does LEANN save?", top_k=1)
```
## License
MIT License

View File

@@ -1,12 +0,0 @@
"""
LEANN - Low-storage Embedding Approximation for Neural Networks
A revolutionary vector database that democratizes personal AI.
"""
__version__ = "0.1.0"
# Re-export main API from leann-core
from leann_core import LeannBuilder, LeannChat, LeannSearcher
__all__ = ["LeannBuilder", "LeannChat", "LeannSearcher"]

View File

@@ -1,39 +0,0 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[project]
name = "leann"
version = "0.2.7"
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
readme = "README.md"
requires-python = ">=3.9"
license = { text = "MIT" }
authors = [
{ name = "LEANN Team" }
]
keywords = ["vector-database", "rag", "embeddings", "search", "ai"]
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
]
# Default installation: core + hnsw + diskann
dependencies = [
"leann-core>=0.1.0",
"leann-backend-hnsw>=0.1.0",
"leann-backend-diskann>=0.1.0",
]
[project.optional-dependencies]
# All backends now included by default
[project.urls]
Repository = "https://github.com/yichuan-w/LEANN"
Issues = "https://github.com/yichuan-w/LEANN/issues"

View File

@@ -1,23 +1,22 @@
import json import json
import sqlite3
import xml.etree.ElementTree as ElementTree
from pathlib import Path
from typing import Annotated
import requests
import typer import typer
from pathlib import Path
import requests
from tqdm import tqdm from tqdm import tqdm
import xml.etree.ElementTree as ET
from typing_extensions import Annotated
import sqlite3
app = typer.Typer() app = typer.Typer()
def get_safe_path(s: str) -> str: def get_safe_path(s: str) -> str:
""" """
Remove invalid characters to sanitize a path. Remove invalid characters to sanitize a path.
:param s: str to sanitize :param s: str to sanitize
:returns: sanitized str :returns: sanitized str
""" """
ban_chars = "\\ / : * ? \" ' < > | $ \r \n".replace(" ", "") ban_chars = "\\ / : * ? \" ' < > | $ \r \n".replace(
' ', '')
for i in ban_chars: for i in ban_chars:
s = s.replace(i, "") s = s.replace(i, "")
return s return s
@@ -26,40 +25,36 @@ def get_safe_path(s: str) -> str:
def process_history(history: str): def process_history(history: str):
if history.startswith("<?xml") or history.startswith("<msg>"): if history.startswith("<?xml") or history.startswith("<msg>"):
try: try:
root = ElementTree.fromstring(history) root = ET.fromstring(history)
title = root.find(".//title").text if root.find(".//title") is not None else None title = root.find('.//title').text if root.find('.//title') is not None else None
quoted = ( quoted = root.find('.//refermsg/content').text if root.find('.//refermsg/content') is not None else None
root.find(".//refermsg/content").text
if root.find(".//refermsg/content") is not None
else None
)
if title and quoted: if title and quoted:
return {"title": title, "quoted": process_history(quoted)} return {
"title": title,
"quoted": process_history(quoted)
}
if title: if title:
return title return title
except Exception: except Exception:
return history return history
return history return history
def get_message(history: dict | str): def get_message(history: dict | str):
if isinstance(history, dict): if isinstance(history, dict):
if "title" in history: if 'title' in history:
return history["title"] return history['title']
else: else:
return history return history
def export_chathistory(user_id: str): def export_chathistory(user_id: str):
res = requests.get( res = requests.get("http://localhost:48065/wechat/chatlog", params={
"http://localhost:48065/wechat/chatlog", "userId": user_id,
params={"userId": user_id, "count": 100000}, "count": 100000
).json() }).json()
for i in range(len(res["chatLogs"])): for i in range(len(res['chatLogs'])):
res["chatLogs"][i]["content"] = process_history(res["chatLogs"][i]["content"]) res['chatLogs'][i]['content'] = process_history(res['chatLogs'][i]['content'])
res["chatLogs"][i]["message"] = get_message(res["chatLogs"][i]["content"]) res['chatLogs'][i]['message'] = get_message(res['chatLogs'][i]['content'])
return res["chatLogs"] return res['chatLogs']
@app.command() @app.command()
def export_all(dest: Annotated[Path, typer.Argument(help="Destination path to export to.")]): def export_all(dest: Annotated[Path, typer.Argument(help="Destination path to export to.")]):
@@ -69,7 +64,7 @@ def export_all(dest: Annotated[Path, typer.Argument(help="Destination path to ex
if not dest.is_dir(): if not dest.is_dir():
if not dest.exists(): if not dest.exists():
inp = typer.prompt("Destination path does not exist, create it? (y/n)") inp = typer.prompt("Destination path does not exist, create it? (y/n)")
if inp.lower() == "y": if inp.lower() == 'y':
dest.mkdir(parents=True) dest.mkdir(parents=True)
else: else:
typer.echo("Aborted.", err=True) typer.echo("Aborted.", err=True)
@@ -82,12 +77,12 @@ def export_all(dest: Annotated[Path, typer.Argument(help="Destination path to ex
exported_count = 0 exported_count = 0
for user in tqdm(all_users): for user in tqdm(all_users):
try: try:
usr_chatlog = export_chathistory(user["arg"]) usr_chatlog = export_chathistory(user['arg'])
# Only write file if there are messages # Only write file if there are messages
if len(usr_chatlog) > 0: if len(usr_chatlog) > 0:
out_path = dest / get_safe_path((user["title"] or "") + "-" + user["arg"] + ".json") out_path = dest/get_safe_path((user['title'] or "")+"-"+user['arg']+'.json')
with open(out_path, "w", encoding="utf-8") as f: with open(out_path, 'w', encoding='utf-8') as f:
json.dump(usr_chatlog, f, ensure_ascii=False, indent=2) json.dump(usr_chatlog, f, ensure_ascii=False, indent=2)
exported_count += 1 exported_count += 1
except Exception as e: except Exception as e:
@@ -96,43 +91,23 @@ def export_all(dest: Annotated[Path, typer.Argument(help="Destination path to ex
print(f"Exported {exported_count} users' chat history to {dest} in json.") print(f"Exported {exported_count} users' chat history to {dest} in json.")
@app.command() @app.command()
def export_sqlite( def export_sqlite(dest: Annotated[Path, typer.Argument(help="Destination path to export to.")] = Path("chatlog.db")):
dest: Annotated[Path, typer.Argument(help="Destination path to export to.")] = Path(
"chatlog.db"
),
):
""" """
Export all users' chat history to a sqlite database. Export all users' chat history to a sqlite database.
""" """
connection = sqlite3.connect(dest) connection = sqlite3.connect(dest)
cursor = connection.cursor() cursor = connection.cursor()
cursor.execute( cursor.execute("CREATE TABLE IF NOT EXISTS chatlog (id INTEGER PRIMARY KEY AUTOINCREMENT, with_id TEXT, from_user TEXT, to_user TEXT, message TEXT, timest DATETIME, auxiliary TEXT)")
"CREATE TABLE IF NOT EXISTS chatlog (id INTEGER PRIMARY KEY AUTOINCREMENT, with_id TEXT, from_user TEXT, to_user TEXT, message TEXT, timest DATETIME, auxiliary TEXT)"
)
cursor.execute("CREATE INDEX IF NOT EXISTS chatlog_with_id_index ON chatlog (with_id)") cursor.execute("CREATE INDEX IF NOT EXISTS chatlog_with_id_index ON chatlog (with_id)")
cursor.execute("CREATE TABLE iF NOT EXISTS users (id TEXT PRIMARY KEY, name TEXT)") cursor.execute("CREATE TABLE iF NOT EXISTS users (id TEXT PRIMARY KEY, name TEXT)")
all_users = requests.get("http://localhost:48065/wechat/allcontacts").json() all_users = requests.get("http://localhost:48065/wechat/allcontacts").json()
for user in tqdm(all_users): for user in tqdm(all_users):
cursor.execute( cursor.execute("INSERT OR IGNORE INTO users (id, name) VALUES (?, ?)", (user['arg'], user['title']))
"INSERT OR IGNORE INTO users (id, name) VALUES (?, ?)", usr_chatlog = export_chathistory(user['arg'])
(user["arg"], user["title"]),
)
usr_chatlog = export_chathistory(user["arg"])
for msg in usr_chatlog: for msg in usr_chatlog:
cursor.execute( cursor.execute("INSERT INTO chatlog (with_id, from_user, to_user, message, timest, auxiliary) VALUES (?, ?, ?, ?, ?, ?)", (user['arg'], msg['fromUser'], msg['toUser'], msg['message'], msg['createTime'], str(msg['content'])))
"INSERT INTO chatlog (with_id, from_user, to_user, message, timest, auxiliary) VALUES (?, ?, ?, ?, ?, ?)",
(
user["arg"],
msg["fromUser"],
msg["toUser"],
msg["message"],
msg["createTime"],
str(msg["content"]),
),
)
connection.commit() connection.commit()

View File

@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "leann-workspace" name = "leann-workspace"
version = "0.1.0" version = "0.1.0"
requires-python = ">=3.9" requires-python = ">=3.10"
dependencies = [ dependencies = [
"leann-core", "leann-core",
@@ -25,65 +25,33 @@ dependencies = [
"requests>=2.25.0", "requests>=2.25.0",
"sentence-transformers>=2.2.0", "sentence-transformers>=2.2.0",
"openai>=1.0.0", "openai>=1.0.0",
# PDF parsing dependencies - essential for document processing
"PyPDF2>=3.0.0", "PyPDF2>=3.0.0",
"pdfplumber>=0.11.0",
"pymupdf>=1.26.0",
"pypdfium2>=4.30.0",
# LlamaIndex core and readers - updated versions
"llama-index>=0.12.44", "llama-index>=0.12.44",
"llama-index-readers-file>=0.4.0", # Essential for PDF parsing "llama-index-readers-docling",
# "llama-index-readers-docling", # Requires Python >= 3.10 "llama-index-node-parser-docling",
# "llama-index-node-parser-docling", # Requires Python >= 3.10
"llama-index-vector-stores-faiss>=0.4.0",
"llama-index-embeddings-huggingface>=0.5.5",
# Other dependencies
"ipykernel==6.29.5", "ipykernel==6.29.5",
"msgpack>=1.1.1", "msgpack>=1.1.1",
"mlx>=0.26.3; sys_platform == 'darwin'", "llama-index-vector-stores-faiss>=0.4.0",
"mlx-lm>=0.26.0; sys_platform == 'darwin'", "llama-index-embeddings-huggingface>=0.5.5",
"mlx>=0.26.3",
"mlx-lm>=0.26.0",
"psutil>=5.8.0", "psutil>=5.8.0",
"pybind11>=3.0.0",
"pathspec>=0.12.1",
"nbconvert>=7.16.6",
"gitignore-parser>=0.1.12",
] ]
[project.optional-dependencies] [project.optional-dependencies]
dev = [ dev = [
"pytest>=8.3.0", # Minimum version for Python 3.13 support "pytest>=7.0",
"pytest-cov>=5.0", "pytest-cov>=4.0",
"pytest-xdist>=3.5", # For parallel test execution
"black>=23.0", "black>=23.0",
"ruff==0.12.7", # Fixed version to ensure consistent formatting across all environments "ruff>=0.1.0",
"matplotlib", "matplotlib",
"huggingface-hub>=0.20.0", "huggingface-hub>=0.20.0",
"pre-commit>=3.5.0",
]
test = [
"pytest>=8.3.0", # Minimum version for Python 3.13 support
"pytest-timeout>=2.3",
"anyio>=4.0", # For async test support (includes pytest plugin)
"psutil>=5.9.0", # For process cleanup in tests
"llama-index-core>=0.12.0",
"llama-index-readers-file>=0.4.0",
"python-dotenv>=1.0.0",
"sentence-transformers>=2.2.0",
] ]
diskann = [ diskann = [
"leann-backend-diskann", "leann-backend-diskann",
] ]
# Add a new optional dependency group for document processing
documents = [
"beautifulsoup4>=4.13.0", # For HTML parsing
"python-docx>=0.8.11", # For Word documents
"openpyxl>=3.1.0", # For Excel files
"pandas>=2.2.0", # For data processing
]
[tool.setuptools] [tool.setuptools]
py-modules = [] py-modules = []
@@ -92,80 +60,3 @@ py-modules = []
leann-core = { path = "packages/leann-core", editable = true } leann-core = { path = "packages/leann-core", editable = true }
leann-backend-diskann = { path = "packages/leann-backend-diskann", editable = true } leann-backend-diskann = { path = "packages/leann-backend-diskann", editable = true }
leann-backend-hnsw = { path = "packages/leann-backend-hnsw", editable = true } leann-backend-hnsw = { path = "packages/leann-backend-hnsw", editable = true }
[tool.ruff]
target-version = "py39"
line-length = 100
extend-exclude = [
"third_party",
"*.egg-info",
"__pycache__",
".git",
".venv",
]
[tool.ruff.lint]
select = [
"E", # pycodestyle errors
"W", # pycodestyle warnings
"F", # pyflakes
"I", # isort
"B", # flake8-bugbear
"C4", # flake8-comprehensions
"UP", # pyupgrade
"N", # pep8-naming
"RUF", # ruff-specific rules
]
ignore = [
"E501", # line too long (handled by formatter)
"B008", # do not perform function calls in argument defaults
"B904", # raise without from
"N812", # lowercase imported as non-lowercase
"N806", # variable in function should be lowercase
"RUF012", # mutable class attributes should be annotated with typing.ClassVar
]
[tool.ruff.lint.per-file-ignores]
"test/**/*.py" = ["E402"] # module level import not at top of file (common in tests)
"examples/**/*.py" = ["E402"] # module level import not at top of file (common in examples)
[tool.ruff.format]
quote-style = "double"
indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"
[dependency-groups]
dev = [
"ruff>=0.12.4",
]
[tool.lychee]
accept = ["200", "403", "429", "503"]
timeout = 20
max_retries = 2
exclude = ["localhost", "127.0.0.1", "example.com"]
exclude_path = [".git/", ".venv/", "__pycache__/", "third_party/"]
scheme = ["https", "http"]
[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
"openai: marks tests that require OpenAI API key",
]
timeout = 300 # Reduced from 600s (10min) to 300s (5min) for CI safety
timeout_method = "thread" # Use thread method to avoid non-daemon thread issues
addopts = [
"-v",
"--tb=short",
"--strict-markers",
"--disable-warnings",
]
env = [
"HF_HUB_DISABLE_SYMLINKS=1",
"TOKENIZERS_PARALLELISM=false",
]

View File

@@ -0,0 +1,12 @@
import faiss
hnsw_index = faiss.read_index("/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/hnsw/hnsw_IP_M30_efC128.index", faiss.IO_FLAG_ONDISK_SAME_DIR)
# print total number of nodes
print(hnsw_index.ntotal)
# print stats of the graph
print(hnsw_index.hnsw.print_neighbor_stats(0))
# save_degree_distribution
hnsw_index.hnsw.save_degree_distribution(0, "degree_distribution_HNSW_M30.txt")

View File

@@ -0,0 +1,11 @@
import faiss
nsg_index = faiss.read_index("/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/nsg_R16.index", faiss.IO_FLAG_ONDISK_SAME_DIR)
# print total number of nodes
print(nsg_index.ntotal)
# print stats of the graph
print(nsg_index.nsg.print_neighbor_stats(0))
# save degree distribution
nsg_index.nsg.save_degree_distribution("degree_distribution_NSG_R60.txt")

63
research/micro/bnbtest.py Normal file
View File

@@ -0,0 +1,63 @@
import torch
import torch.nn as nn
import time
# import bitsandbytes as bnb
from bitsandbytes.nn import Linear8bitLt
# set default to half
import torch
torch.set_default_dtype(torch.float16)
M = 2048
N = 2048
bsz = 2048
import torch_int
from torch_int.nn.linear import W8A8BFP32OFP32Linear, W8A8B8O8Linear, W8A8B8O8LinearReLU
fp16_model = nn.Sequential(
nn.Linear(M, N),
# nn.Linear(2048, 2048)
)
int8_model = nn.Sequential(
Linear8bitLt(M, N, has_fp16_weights=False),
# Linear8bitLt(2048, 2048, has_fp16_weights=False)
)
int8_model.load_state_dict(fp16_model.state_dict())
int8_model = int8_model.to(0) # Quantization happens here
fp16_model = fp16_model.to(0) # Move fp16 model to GPU as well
# Create random input tensor
input_tensor = torch.randn(bsz, M, device=0) # Batch of 1000 vectors
# Speed test function
def speed_test(model, input_tensor, name, num_iterations=100):
# Warmup
for _ in range(10):
_ = model(input_tensor)
# Actual timing
torch.cuda.synchronize()
start_time = time.time()
for _ in range(num_iterations):
_ = model(input_tensor)
torch.cuda.synchronize()
end_time = time.time()
avg_time = (end_time - start_time) / num_iterations
print(f"{name} model: {avg_time:.6f} seconds per iteration")
return avg_time
# Run speed tests
with torch.no_grad(): # Disable gradient calculation for inference
fp16_time = speed_test(fp16_model, input_tensor, "FP16")
int8_time = speed_test(int8_model, input_tensor, "INT8")
# Calculate speedup
speedup = fp16_time / int8_time
print(f"INT8 is {speedup:.2f}x faster than FP16")

Some files were not shown because too many files have changed in this diff.