Compare commits
1 Commits
feature/sk
...
apps
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b92ec04178 |
12
.github/workflows/build-and-publish.yml
vendored
12
.github/workflows/build-and-publish.yml
vendored
@@ -1,12 +0,0 @@
|
|||||||
name: CI
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches: [ main ]
|
|
||||||
pull_request:
|
|
||||||
branches: [ main ]
|
|
||||||
workflow_dispatch:
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
build:
|
|
||||||
uses: ./.github/workflows/build-reusable.yml
|
|
||||||
306
.github/workflows/build-reusable.yml
vendored
306
.github/workflows/build-reusable.yml
vendored
@@ -1,306 +0,0 @@
|
|||||||
name: Reusable Build
|
|
||||||
|
|
||||||
on:
|
|
||||||
workflow_call:
|
|
||||||
inputs:
|
|
||||||
ref:
|
|
||||||
description: 'Git ref to build'
|
|
||||||
required: false
|
|
||||||
type: string
|
|
||||||
default: ''
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
lint:
|
|
||||||
name: Lint and Format Check
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v4
|
|
||||||
with:
|
|
||||||
ref: ${{ inputs.ref }}
|
|
||||||
|
|
||||||
- name: Setup Python
|
|
||||||
uses: actions/setup-python@v5
|
|
||||||
with:
|
|
||||||
python-version: '3.11'
|
|
||||||
|
|
||||||
- name: Install uv
|
|
||||||
uses: astral-sh/setup-uv@v4
|
|
||||||
|
|
||||||
- name: Install ruff
|
|
||||||
run: |
|
|
||||||
uv tool install ruff
|
|
||||||
|
|
||||||
- name: Run ruff check
|
|
||||||
run: |
|
|
||||||
ruff check .
|
|
||||||
|
|
||||||
- name: Run ruff format check
|
|
||||||
run: |
|
|
||||||
ruff format --check .
|
|
||||||
|
|
||||||
build:
|
|
||||||
needs: lint
|
|
||||||
name: Build ${{ matrix.os }} Python ${{ matrix.python }}
|
|
||||||
strategy:
|
|
||||||
matrix:
|
|
||||||
include:
|
|
||||||
- os: ubuntu-22.04
|
|
||||||
python: '3.9'
|
|
||||||
- os: ubuntu-22.04
|
|
||||||
python: '3.10'
|
|
||||||
- os: ubuntu-22.04
|
|
||||||
python: '3.11'
|
|
||||||
- os: ubuntu-22.04
|
|
||||||
python: '3.12'
|
|
||||||
- os: ubuntu-22.04
|
|
||||||
python: '3.13'
|
|
||||||
- os: macos-14
|
|
||||||
python: '3.9'
|
|
||||||
- os: macos-14
|
|
||||||
python: '3.10'
|
|
||||||
- os: macos-14
|
|
||||||
python: '3.11'
|
|
||||||
- os: macos-14
|
|
||||||
python: '3.12'
|
|
||||||
- os: macos-14
|
|
||||||
python: '3.13'
|
|
||||||
- os: macos-15
|
|
||||||
python: '3.9'
|
|
||||||
- os: macos-15
|
|
||||||
python: '3.10'
|
|
||||||
- os: macos-15
|
|
||||||
python: '3.11'
|
|
||||||
- os: macos-15
|
|
||||||
python: '3.12'
|
|
||||||
- os: macos-15
|
|
||||||
python: '3.13'
|
|
||||||
- os: macos-13
|
|
||||||
python: '3.9'
|
|
||||||
- os: macos-13
|
|
||||||
python: '3.10'
|
|
||||||
- os: macos-13
|
|
||||||
python: '3.11'
|
|
||||||
- os: macos-13
|
|
||||||
python: '3.12'
|
|
||||||
# Note: macos-13 + Python 3.13 excluded due to PyTorch compatibility
|
|
||||||
# (PyTorch 2.5+ supports Python 3.13 but not Intel Mac x86_64)
|
|
||||||
runs-on: ${{ matrix.os }}
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v4
|
|
||||||
with:
|
|
||||||
ref: ${{ inputs.ref }}
|
|
||||||
submodules: recursive
|
|
||||||
|
|
||||||
- name: Setup Python
|
|
||||||
uses: actions/setup-python@v5
|
|
||||||
with:
|
|
||||||
python-version: ${{ matrix.python }}
|
|
||||||
|
|
||||||
- name: Install uv
|
|
||||||
uses: astral-sh/setup-uv@v4
|
|
||||||
|
|
||||||
- name: Install system dependencies (Ubuntu)
|
|
||||||
if: runner.os == 'Linux'
|
|
||||||
run: |
|
|
||||||
sudo apt-get update
|
|
||||||
sudo apt-get install -y libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
|
|
||||||
pkg-config libopenblas-dev patchelf libabsl-dev libaio-dev libprotobuf-dev
|
|
||||||
|
|
||||||
# Install Intel MKL for DiskANN
|
|
||||||
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
|
|
||||||
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
|
|
||||||
source /opt/intel/oneapi/setvars.sh
|
|
||||||
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
|
|
||||||
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/mkl/latest/lib/intel64:$LD_LIBRARY_PATH" >> $GITHUB_ENV
|
|
||||||
|
|
||||||
- name: Install system dependencies (macOS)
|
|
||||||
if: runner.os == 'macOS'
|
|
||||||
run: |
|
|
||||||
# Don't install LLVM, use system clang for better compatibility
|
|
||||||
brew install libomp boost protobuf zeromq
|
|
||||||
|
|
||||||
- name: Install build dependencies
|
|
||||||
run: |
|
|
||||||
uv pip install --system scikit-build-core numpy swig Cython pybind11
|
|
||||||
if [[ "$RUNNER_OS" == "Linux" ]]; then
|
|
||||||
uv pip install --system auditwheel
|
|
||||||
else
|
|
||||||
uv pip install --system delocate
|
|
||||||
fi
|
|
||||||
|
|
||||||
- name: Set macOS environment variables
|
|
||||||
if: runner.os == 'macOS'
|
|
||||||
run: |
|
|
||||||
# Use brew --prefix to automatically detect Homebrew installation path
|
|
||||||
HOMEBREW_PREFIX=$(brew --prefix)
|
|
||||||
echo "HOMEBREW_PREFIX=${HOMEBREW_PREFIX}" >> $GITHUB_ENV
|
|
||||||
echo "OpenMP_ROOT=${HOMEBREW_PREFIX}/opt/libomp" >> $GITHUB_ENV
|
|
||||||
|
|
||||||
# Set CMAKE_PREFIX_PATH to let CMake find all packages automatically
|
|
||||||
echo "CMAKE_PREFIX_PATH=${HOMEBREW_PREFIX}" >> $GITHUB_ENV
|
|
||||||
|
|
||||||
# Set compiler flags for OpenMP (required for both backends)
|
|
||||||
echo "LDFLAGS=-L${HOMEBREW_PREFIX}/opt/libomp/lib" >> $GITHUB_ENV
|
|
||||||
echo "CPPFLAGS=-I${HOMEBREW_PREFIX}/opt/libomp/include" >> $GITHUB_ENV
|
|
||||||
|
|
||||||
- name: Build packages
|
|
||||||
run: |
|
|
||||||
# Build core (platform independent)
|
|
||||||
cd packages/leann-core
|
|
||||||
uv build
|
|
||||||
cd ../..
|
|
||||||
|
|
||||||
# Build HNSW backend
|
|
||||||
cd packages/leann-backend-hnsw
|
|
||||||
if [[ "${{ matrix.os }}" == macos-* ]]; then
|
|
||||||
# Use system clang for better compatibility
|
|
||||||
export CC=clang
|
|
||||||
export CXX=clang++
|
|
||||||
# Homebrew libraries on each macOS version require matching minimum version
|
|
||||||
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
|
|
||||||
export MACOSX_DEPLOYMENT_TARGET=13.0
|
|
||||||
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
|
|
||||||
export MACOSX_DEPLOYMENT_TARGET=14.0
|
|
||||||
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
|
|
||||||
export MACOSX_DEPLOYMENT_TARGET=15.0
|
|
||||||
fi
|
|
||||||
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
|
||||||
else
|
|
||||||
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
|
||||||
fi
|
|
||||||
cd ../..
|
|
||||||
|
|
||||||
# Build DiskANN backend
|
|
||||||
cd packages/leann-backend-diskann
|
|
||||||
if [[ "${{ matrix.os }}" == macos-* ]]; then
|
|
||||||
# Use system clang for better compatibility
|
|
||||||
export CC=clang
|
|
||||||
export CXX=clang++
|
|
||||||
# DiskANN requires macOS 13.3+ for sgesdd_ LAPACK function
|
|
||||||
# But Homebrew libraries on each macOS version require matching minimum version
|
|
||||||
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
|
|
||||||
export MACOSX_DEPLOYMENT_TARGET=13.3
|
|
||||||
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
|
|
||||||
export MACOSX_DEPLOYMENT_TARGET=14.0
|
|
||||||
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
|
|
||||||
export MACOSX_DEPLOYMENT_TARGET=15.0
|
|
||||||
fi
|
|
||||||
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
|
||||||
else
|
|
||||||
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
|
||||||
fi
|
|
||||||
cd ../..
|
|
||||||
|
|
||||||
# Build meta package (platform independent)
|
|
||||||
cd packages/leann
|
|
||||||
uv build
|
|
||||||
cd ../..
|
|
||||||
|
|
||||||
- name: Repair wheels (Linux)
|
|
||||||
if: runner.os == 'Linux'
|
|
||||||
run: |
|
|
||||||
# Repair HNSW wheel
|
|
||||||
cd packages/leann-backend-hnsw
|
|
||||||
if [ -d dist ]; then
|
|
||||||
auditwheel repair dist/*.whl -w dist_repaired
|
|
||||||
rm -rf dist
|
|
||||||
mv dist_repaired dist
|
|
||||||
fi
|
|
||||||
cd ../..
|
|
||||||
|
|
||||||
# Repair DiskANN wheel
|
|
||||||
cd packages/leann-backend-diskann
|
|
||||||
if [ -d dist ]; then
|
|
||||||
auditwheel repair dist/*.whl -w dist_repaired
|
|
||||||
rm -rf dist
|
|
||||||
mv dist_repaired dist
|
|
||||||
fi
|
|
||||||
cd ../..
|
|
||||||
|
|
||||||
- name: Repair wheels (macOS)
|
|
||||||
if: runner.os == 'macOS'
|
|
||||||
run: |
|
|
||||||
# Determine deployment target based on runner OS
|
|
||||||
# Must match the Homebrew libraries for each macOS version
|
|
||||||
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
|
|
||||||
HNSW_TARGET="13.0"
|
|
||||||
DISKANN_TARGET="13.3"
|
|
||||||
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
|
|
||||||
HNSW_TARGET="14.0"
|
|
||||||
DISKANN_TARGET="14.0"
|
|
||||||
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
|
|
||||||
HNSW_TARGET="15.0"
|
|
||||||
DISKANN_TARGET="15.0"
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Repair HNSW wheel
|
|
||||||
cd packages/leann-backend-hnsw
|
|
||||||
if [ -d dist ]; then
|
|
||||||
export MACOSX_DEPLOYMENT_TARGET=$HNSW_TARGET
|
|
||||||
delocate-wheel -w dist_repaired -v --require-target-macos-version $HNSW_TARGET dist/*.whl
|
|
||||||
rm -rf dist
|
|
||||||
mv dist_repaired dist
|
|
||||||
fi
|
|
||||||
cd ../..
|
|
||||||
|
|
||||||
# Repair DiskANN wheel
|
|
||||||
cd packages/leann-backend-diskann
|
|
||||||
if [ -d dist ]; then
|
|
||||||
export MACOSX_DEPLOYMENT_TARGET=$DISKANN_TARGET
|
|
||||||
delocate-wheel -w dist_repaired -v --require-target-macos-version $DISKANN_TARGET dist/*.whl
|
|
||||||
rm -rf dist
|
|
||||||
mv dist_repaired dist
|
|
||||||
fi
|
|
||||||
cd ../..
|
|
||||||
|
|
||||||
- name: List built packages
|
|
||||||
run: |
|
|
||||||
echo "📦 Built packages:"
|
|
||||||
find packages/*/dist -name "*.whl" -o -name "*.tar.gz" | sort
|
|
||||||
|
|
||||||
|
|
||||||
- name: Install built packages for testing
|
|
||||||
run: |
|
|
||||||
# Create a virtual environment with the correct Python version
|
|
||||||
uv venv --python ${{ matrix.python }}
|
|
||||||
source .venv/bin/activate || source .venv/Scripts/activate
|
|
||||||
|
|
||||||
# Install packages using --find-links to prioritize local builds
|
|
||||||
uv pip install --find-links packages/leann-core/dist --find-links packages/leann-backend-hnsw/dist --find-links packages/leann-backend-diskann/dist packages/leann-core/dist/*.whl || uv pip install --find-links packages/leann-core/dist packages/leann-core/dist/*.tar.gz
|
|
||||||
uv pip install --find-links packages/leann-core/dist packages/leann-backend-hnsw/dist/*.whl
|
|
||||||
uv pip install --find-links packages/leann-core/dist packages/leann-backend-diskann/dist/*.whl
|
|
||||||
uv pip install packages/leann/dist/*.whl || uv pip install packages/leann/dist/*.tar.gz
|
|
||||||
|
|
||||||
# Install test dependencies using extras
|
|
||||||
uv pip install -e ".[test]"
|
|
||||||
|
|
||||||
- name: Run tests with pytest
|
|
||||||
env:
|
|
||||||
CI: true
|
|
||||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
|
||||||
HF_HUB_DISABLE_SYMLINKS: 1
|
|
||||||
TOKENIZERS_PARALLELISM: false
|
|
||||||
PYTORCH_ENABLE_MPS_FALLBACK: 0
|
|
||||||
OMP_NUM_THREADS: 1
|
|
||||||
MKL_NUM_THREADS: 1
|
|
||||||
run: |
|
|
||||||
source .venv/bin/activate || source .venv/Scripts/activate
|
|
||||||
pytest tests/ -v --tb=short
|
|
||||||
|
|
||||||
- name: Run sanity checks (optional)
|
|
||||||
run: |
|
|
||||||
# Activate virtual environment
|
|
||||||
source .venv/bin/activate || source .venv/Scripts/activate
|
|
||||||
|
|
||||||
# Run distance function tests if available
|
|
||||||
if [ -f test/sanity_checks/test_distance_functions.py ]; then
|
|
||||||
echo "Running distance function sanity checks..."
|
|
||||||
python test/sanity_checks/test_distance_functions.py || echo "⚠️ Distance function test failed, continuing..."
|
|
||||||
fi
|
|
||||||
|
|
||||||
- name: Upload artifacts
|
|
||||||
uses: actions/upload-artifact@v4
|
|
||||||
with:
|
|
||||||
name: packages-${{ matrix.os }}-py${{ matrix.python }}
|
|
||||||
path: packages/*/dist/
|
|
||||||
19
.github/workflows/link-check.yml
vendored
19
.github/workflows/link-check.yml
vendored
@@ -1,19 +0,0 @@
|
|||||||
name: Link Check
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches: [ main, master ]
|
|
||||||
pull_request:
|
|
||||||
schedule:
|
|
||||||
- cron: "0 3 * * 1"
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
link-check:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v4
|
|
||||||
- uses: lycheeverse/lychee-action@v2
|
|
||||||
with:
|
|
||||||
args: --no-progress --insecure README.md docs/ apps/ examples/ benchmarks/
|
|
||||||
env:
|
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
129
.github/workflows/release-manual.yml
vendored
129
.github/workflows/release-manual.yml
vendored
@@ -1,129 +0,0 @@
|
|||||||
name: Release
|
|
||||||
|
|
||||||
on:
|
|
||||||
workflow_dispatch:
|
|
||||||
inputs:
|
|
||||||
version:
|
|
||||||
description: 'Version to release (e.g., 0.1.2)'
|
|
||||||
required: true
|
|
||||||
type: string
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
update-version:
|
|
||||||
name: Update Version
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
permissions:
|
|
||||||
contents: write
|
|
||||||
outputs:
|
|
||||||
commit-sha: ${{ steps.push.outputs.commit-sha }}
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Validate version
|
|
||||||
run: |
|
|
||||||
# Remove 'v' prefix if present for validation
|
|
||||||
VERSION_CLEAN="${{ inputs.version }}"
|
|
||||||
VERSION_CLEAN="${VERSION_CLEAN#v}"
|
|
||||||
if ! [[ "$VERSION_CLEAN" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
|
|
||||||
echo "❌ Invalid version format. Expected format: X.Y.Z or vX.Y.Z"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
echo "✅ Version format valid: ${{ inputs.version }}"
|
|
||||||
|
|
||||||
- name: Update versions and push
|
|
||||||
id: push
|
|
||||||
run: |
|
|
||||||
# Check current version
|
|
||||||
CURRENT_VERSION=$(grep "^version" packages/leann-core/pyproject.toml | cut -d'"' -f2)
|
|
||||||
echo "Current version: $CURRENT_VERSION"
|
|
||||||
echo "Target version: ${{ inputs.version }}"
|
|
||||||
|
|
||||||
if [ "$CURRENT_VERSION" = "${{ inputs.version }}" ]; then
|
|
||||||
echo "⚠️ Version is already ${{ inputs.version }}, skipping update"
|
|
||||||
COMMIT_SHA=$(git rev-parse HEAD)
|
|
||||||
else
|
|
||||||
./scripts/bump_version.sh ${{ inputs.version }}
|
|
||||||
git config user.name "GitHub Actions"
|
|
||||||
git config user.email "actions@github.com"
|
|
||||||
git add packages/*/pyproject.toml
|
|
||||||
git commit -m "chore: release v${{ inputs.version }}"
|
|
||||||
git push origin main
|
|
||||||
COMMIT_SHA=$(git rev-parse HEAD)
|
|
||||||
echo "✅ Pushed version update: $COMMIT_SHA"
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo "commit-sha=$COMMIT_SHA" >> $GITHUB_OUTPUT
|
|
||||||
|
|
||||||
build-packages:
|
|
||||||
name: Build packages
|
|
||||||
needs: update-version
|
|
||||||
uses: ./.github/workflows/build-reusable.yml
|
|
||||||
with:
|
|
||||||
ref: 'main'
|
|
||||||
|
|
||||||
publish:
|
|
||||||
name: Publish and Release
|
|
||||||
needs: [update-version, build-packages]
|
|
||||||
if: always() && needs.update-version.result == 'success' && needs.build-packages.result == 'success'
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
permissions:
|
|
||||||
contents: write
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v4
|
|
||||||
with:
|
|
||||||
ref: 'main'
|
|
||||||
|
|
||||||
- name: Download all artifacts
|
|
||||||
uses: actions/download-artifact@v4
|
|
||||||
with:
|
|
||||||
path: dist-artifacts
|
|
||||||
|
|
||||||
- name: Collect packages
|
|
||||||
run: |
|
|
||||||
mkdir -p dist
|
|
||||||
find dist-artifacts -name "*.whl" -exec cp {} dist/ \;
|
|
||||||
find dist-artifacts -name "*.tar.gz" -exec cp {} dist/ \;
|
|
||||||
|
|
||||||
echo "📦 Packages to publish:"
|
|
||||||
ls -la dist/
|
|
||||||
|
|
||||||
- name: Publish to PyPI
|
|
||||||
env:
|
|
||||||
TWINE_USERNAME: __token__
|
|
||||||
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
|
|
||||||
run: |
|
|
||||||
if [ -z "$TWINE_PASSWORD" ]; then
|
|
||||||
echo "❌ PYPI_API_TOKEN not configured!"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
pip install twine
|
|
||||||
twine upload dist/* --skip-existing --verbose
|
|
||||||
|
|
||||||
echo "✅ Published to PyPI!"
|
|
||||||
|
|
||||||
- name: Create release
|
|
||||||
run: |
|
|
||||||
# Check if tag already exists
|
|
||||||
if git rev-parse "v${{ inputs.version }}" >/dev/null 2>&1; then
|
|
||||||
echo "⚠️ Tag v${{ inputs.version }} already exists, skipping tag creation"
|
|
||||||
else
|
|
||||||
git tag "v${{ inputs.version }}"
|
|
||||||
git push origin "v${{ inputs.version }}"
|
|
||||||
echo "✅ Created and pushed tag v${{ inputs.version }}"
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Check if release already exists
|
|
||||||
if gh release view "v${{ inputs.version }}" >/dev/null 2>&1; then
|
|
||||||
echo "⚠️ Release v${{ inputs.version }} already exists, skipping release creation"
|
|
||||||
else
|
|
||||||
gh release create "v${{ inputs.version }}" \
|
|
||||||
--title "Release v${{ inputs.version }}" \
|
|
||||||
--notes "🚀 Released to PyPI: https://pypi.org/project/leann/${{ inputs.version }}/" \
|
|
||||||
--latest
|
|
||||||
echo "✅ Created GitHub release v${{ inputs.version }}"
|
|
||||||
fi
|
|
||||||
env:
|
|
||||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
17
.gitignore
vendored
17
.gitignore
vendored
@@ -12,6 +12,7 @@ outputs/
|
|||||||
*.idx
|
*.idx
|
||||||
*.map
|
*.map
|
||||||
.history/
|
.history/
|
||||||
|
scripts/
|
||||||
lm_eval.egg-info/
|
lm_eval.egg-info/
|
||||||
demo/experiment_results/**/*.json
|
demo/experiment_results/**/*.json
|
||||||
*.jsonl
|
*.jsonl
|
||||||
@@ -34,15 +35,11 @@ build/
|
|||||||
nprobe_logs/
|
nprobe_logs/
|
||||||
micro/results
|
micro/results
|
||||||
micro/contriever-INT8
|
micro/contriever-INT8
|
||||||
data/*
|
examples/data/*
|
||||||
!data/2501.14312v1 (1).pdf
|
!examples/data/2501.14312v1 (1).pdf
|
||||||
!data/2506.08276v1.pdf
|
!examples/data/2506.08276v1.pdf
|
||||||
!data/PrideandPrejudice.txt
|
!examples/data/PrideandPrejudice.txt
|
||||||
!data/huawei_pangu.md
|
!examples/data/README.md
|
||||||
!data/ground_truth/
|
|
||||||
!data/indices/
|
|
||||||
!data/queries/
|
|
||||||
!data/.gitattributes
|
|
||||||
*.qdstrm
|
*.qdstrm
|
||||||
benchmark_results/
|
benchmark_results/
|
||||||
results/
|
results/
|
||||||
@@ -90,5 +87,3 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/
|
|||||||
*.passages.json
|
*.passages.json
|
||||||
|
|
||||||
batchtest.py
|
batchtest.py
|
||||||
tests/__pytest_cache__/
|
|
||||||
tests/__pycache__/
|
|
||||||
|
|||||||
@@ -1,16 +0,0 @@
|
|||||||
repos:
|
|
||||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
|
||||||
rev: v5.0.0
|
|
||||||
hooks:
|
|
||||||
- id: trailing-whitespace
|
|
||||||
- id: end-of-file-fixer
|
|
||||||
- id: check-yaml
|
|
||||||
- id: check-added-large-files
|
|
||||||
- id: check-merge-conflict
|
|
||||||
- id: debug-statements
|
|
||||||
|
|
||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
|
||||||
rev: v0.12.7 # Fixed version to match pyproject.toml
|
|
||||||
hooks:
|
|
||||||
- id: ruff
|
|
||||||
- id: ruff-format
|
|
||||||
613
README.md
613
README.md
@@ -3,25 +3,20 @@
|
|||||||
</p>
|
</p>
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<img src="https://img.shields.io/badge/Python-3.9%20%7C%203.10%20%7C%203.11%20%7C%203.12%20%7C%203.13-blue.svg" alt="Python Versions">
|
<img src="https://img.shields.io/badge/Python-3.9%2B-blue.svg" alt="Python 3.9+">
|
||||||
<img src="https://github.com/yichuan-w/LEANN/actions/workflows/build-and-publish.yml/badge.svg" alt="CI Status">
|
|
||||||
<img src="https://img.shields.io/badge/Platform-Ubuntu%20%7C%20macOS%20(ARM64%2FIntel)-lightgrey" alt="Platform">
|
|
||||||
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
|
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
|
||||||
<img src="https://img.shields.io/badge/MCP-Native%20Integration-blue" alt="MCP Integration">
|
<img src="https://img.shields.io/badge/Platform-Linux%20%7C%20macOS-lightgrey" alt="Platform">
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
<h2 align="center" tabindex="-1" class="heading-element" dir="auto">
|
<h2 align="center" tabindex="-1" class="heading-element" dir="auto">
|
||||||
The smallest vector index in the world. RAG Everything with LEANN!
|
The smallest vector index in the world. RAG Everything with LEANN!
|
||||||
</h2>
|
</h2>
|
||||||
|
|
||||||
LEANN is an innovative vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.
|
LEANN is a revolutionary vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **[97% less storage]** than traditional solutions **without accuracy loss**.
|
||||||
|
|
||||||
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
|
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
|
||||||
|
|
||||||
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can semantic search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)**, **[codebase](#-claude-code-integration-transform-your-development-workflow)**\* , or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can search your **[file system](#process-any-documents-pdf-txt-md)**, **[emails](#search-your-entire-life)**, **[browser history](#time-machine-for-the-web)**, **[chat history](#wechat-detective)**, or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
||||||
|
|
||||||
|
|
||||||
\* Claude Code only supports basic `grep`-style keyword search. **LEANN** is a drop-in **semantic search MCP service fully compatible with Claude Code**, unlocking intelligent retrieval without changing your workflow. 🔥 Check out [the easy setup →](packages/leann-mcp/README.md)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -31,125 +26,57 @@ LEANN achieves this through *graph-based selective recomputation* with *high-deg
|
|||||||
<img src="assets/effects.png" alt="LEANN vs Traditional Vector DB Storage Comparison" width="70%">
|
<img src="assets/effects.png" alt="LEANN vs Traditional Vector DB Storage Comparison" width="70%">
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
> **The numbers speak for themselves:** Index 60 million text chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#storage-comparison)
|
> **The numbers speak for themselves:** Index 60 million Wikipedia chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#storage-usage-comparison)
|
||||||
|
|
||||||
|
|
||||||
🔒 **Privacy:** Your data never leaves your laptop. No OpenAI, no cloud, no "terms of service".
|
🔒 **Privacy:** Your data never leaves your laptop. No OpenAI, no cloud, no "terms of service".
|
||||||
|
|
||||||
🪶 **Lightweight:** Graph-based recomputation eliminates heavy embedding storage, while smart graph pruning and CSR format minimize graph storage overhead. Always less storage, less memory usage!
|
🪶 **Lightweight:** Graph-based recomputation eliminates heavy embedding storage, while smart graph pruning and CSR format minimize graph storage overhead. Always less storage, less memory usage!
|
||||||
|
|
||||||
📦 **Portable:** Transfer your entire knowledge base between devices (even with others) with minimal cost - your personal AI memory travels with you.
|
|
||||||
|
|
||||||
📈 **Scalability:** Handle messy personal data that would crash traditional vector DBs, easily managing your growing personalized data and agent generated memory!
|
📈 **Scalability:** Handle messy personal data that would crash traditional vector DBs, easily managing your growing personalized data and agent generated memory!
|
||||||
|
|
||||||
✨ **No Accuracy Loss:** Maintain the same search quality as heavyweight solutions while using 97% less storage.
|
✨ **No Accuracy Loss:** Maintain the same search quality as heavyweight solutions while using 97% less storage.
|
||||||
|
|
||||||
## Installation
|
## Quick Start in 1 minute
|
||||||
|
|
||||||
### 📦 Prerequisites: Install uv
|
|
||||||
|
|
||||||
[Install uv](https://docs.astral.sh/uv/getting-started/installation/#installation-methods) first if you don't have it. Typically, you can install it with:
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
git clone git@github.com:yichuan-w/LEANN.git leann
|
||||||
```
|
|
||||||
|
|
||||||
### 🚀 Quick Install
|
|
||||||
|
|
||||||
Clone the repository to access all examples and try amazing applications,
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git clone https://github.com/yichuan-w/LEANN.git leann
|
|
||||||
cd leann
|
|
||||||
```
|
|
||||||
|
|
||||||
and install LEANN from [PyPI](https://pypi.org/project/leann/) to run them immediately:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
uv venv
|
|
||||||
source .venv/bin/activate
|
|
||||||
uv pip install leann
|
|
||||||
```
|
|
||||||
|
|
||||||
> Low-resource? See “Low-resource setups” in the [Configuration Guide](docs/configuration-guide.md#low-resource-setups).
|
|
||||||
|
|
||||||
<details>
|
|
||||||
<summary>
|
|
||||||
<strong>🔧 Build from Source (Recommended for development)</strong>
|
|
||||||
</summary>
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git clone https://github.com/yichuan-w/LEANN.git leann
|
|
||||||
cd leann
|
cd leann
|
||||||
git submodule update --init --recursive
|
git submodule update --init --recursive
|
||||||
```
|
```
|
||||||
|
|
||||||
**macOS:**
|
**macOS:**
|
||||||
```bash
|
```bash
|
||||||
brew install llvm libomp boost protobuf zeromq pkgconf
|
brew install llvm libomp boost protobuf zeromq
|
||||||
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync
|
export CC=$(brew --prefix llvm)/bin/clang
|
||||||
|
export CXX=$(brew --prefix llvm)/bin/clang++
|
||||||
|
|
||||||
|
# Install with HNSW backend (default, recommended for most users)
|
||||||
|
uv sync
|
||||||
|
|
||||||
|
# Or add DiskANN backend if you want to test more options
|
||||||
|
uv sync --extra diskann
|
||||||
```
|
```
|
||||||
|
|
||||||
**Linux:**
|
**Linux (Ubuntu/Debian):**
|
||||||
```bash
|
```bash
|
||||||
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
|
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
|
||||||
|
|
||||||
|
# Install with HNSW backend (default, recommended for most users)
|
||||||
uv sync
|
uv sync
|
||||||
|
|
||||||
|
# Or add DiskANN backend if you want to test more options
|
||||||
|
uv sync --extra diskann
|
||||||
```
|
```
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
|
|
||||||
## Quick Start
|
**Ollama Setup (Recommended for full privacy):**
|
||||||
|
|
||||||
Our declarative API makes RAG as easy as writing a config file.
|
> *You can skip this installation if you only want to use OpenAI API for generation.*
|
||||||
|
|
||||||
Check out [demo.ipynb](demo.ipynb) or [](https://colab.research.google.com/github/yichuan-w/LEANN/blob/main/demo.ipynb)
|
|
||||||
|
|
||||||
```python
|
*macOS:*
|
||||||
from leann import LeannBuilder, LeannSearcher, LeannChat
|
|
||||||
from pathlib import Path
|
|
||||||
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
|
|
||||||
|
|
||||||
# Build an index
|
|
||||||
builder = LeannBuilder(backend_name="hnsw")
|
|
||||||
builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
|
|
||||||
builder.add_text("Tung Tung Tung Sahur called—they need their banana‑crocodile hybrid back")
|
|
||||||
builder.build_index(INDEX_PATH)
|
|
||||||
|
|
||||||
# Search
|
|
||||||
searcher = LeannSearcher(INDEX_PATH)
|
|
||||||
results = searcher.search("fantastical AI-generated creatures", top_k=1)
|
|
||||||
|
|
||||||
# Chat with your data
|
|
||||||
chat = LeannChat(INDEX_PATH, llm_config={"type": "hf", "model": "Qwen/Qwen3-0.6B"})
|
|
||||||
response = chat.ask("How much storage does LEANN save?", top_k=1)
|
|
||||||
```
|
|
||||||
|
|
||||||
## RAG on Everything!
|
|
||||||
|
|
||||||
LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`, `.md`), Apple Mail, Google Search History, WeChat, and more.
|
|
||||||
|
|
||||||
### Generation Model Setup
|
|
||||||
|
|
||||||
LEANN supports multiple LLM providers for text generation (OpenAI API, HuggingFace, Ollama).
|
|
||||||
|
|
||||||
<details>
|
|
||||||
<summary><strong>🔑 OpenAI API Setup (Default)</strong></summary>
|
|
||||||
|
|
||||||
Set your OpenAI API key as an environment variable:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
export OPENAI_API_KEY="your-api-key-here"
|
|
||||||
```
|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
<details>
|
|
||||||
<summary><strong>🔧 Ollama Setup (Recommended for full privacy)</strong></summary>
|
|
||||||
|
|
||||||
**macOS:**
|
|
||||||
|
|
||||||
First, [download Ollama for macOS](https://ollama.com/download/mac).
|
First, [download Ollama for macOS](https://ollama.com/download/mac).
|
||||||
|
|
||||||
@@ -158,8 +85,7 @@ First, [download Ollama for macOS](https://ollama.com/download/mac).
|
|||||||
ollama pull llama3.2:1b
|
ollama pull llama3.2:1b
|
||||||
```
|
```
|
||||||
|
|
||||||
**Linux:**
|
*Linux:*
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Install Ollama
|
# Install Ollama
|
||||||
curl -fsSL https://ollama.ai/install.sh | sh
|
curl -fsSL https://ollama.ai/install.sh | sh
|
||||||
@@ -171,120 +97,90 @@ ollama serve &
|
|||||||
ollama pull llama3.2:1b
|
ollama pull llama3.2:1b
|
||||||
```
|
```
|
||||||
|
|
||||||
</details>
|
## Dead Simple API
|
||||||
|
|
||||||
### ⭐ Flexible Configuration
|
Just 3 lines of code. Our declarative API makes RAG as easy as writing a config file:
|
||||||
|
|
||||||
LEANN provides flexible parameters for embedding models, search strategies, and data processing to fit your specific needs.
|
```python
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||||
|
|
||||||
📚 **Need configuration best practices?** Check our [Configuration Guide](docs/configuration-guide.md) for detailed optimization tips, model selection advice, and solutions to common issues like slow embeddings or poor search quality.
|
# 1. Build the index (no embeddings stored!)
|
||||||
|
builder = LeannBuilder(backend_name="hnsw")
|
||||||
|
builder.add_text("C# is a powerful programming language")
|
||||||
|
builder.add_text("Python is a powerful programming language and it is very popular")
|
||||||
|
builder.add_text("Machine learning transforms industries")
|
||||||
|
builder.add_text("Neural networks process complex data")
|
||||||
|
builder.add_text("Leann is a great storage saving engine for RAG on your MacBook")
|
||||||
|
builder.build_index("knowledge.leann")
|
||||||
|
|
||||||
|
# 2. Search with real-time embeddings
|
||||||
|
searcher = LeannSearcher("knowledge.leann")
|
||||||
|
results = searcher.search("programming languages", top_k=2)
|
||||||
|
|
||||||
|
# 3. Chat with LEANN using retrieved results
|
||||||
|
llm_config = {
|
||||||
|
"type": "ollama",
|
||||||
|
"model": "llama3.2:1b"
|
||||||
|
}
|
||||||
|
|
||||||
|
chat = LeannChat(index_path="knowledge.leann", llm_config=llm_config)
|
||||||
|
response = chat.ask(
|
||||||
|
"Compare the two retrieved programming languages and say which one is more popular today.",
|
||||||
|
top_k=2,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**That's it.** No cloud setup, no API keys, no "fine-tuning". Just your data, your questions, your laptop.
|
||||||
|
|
||||||
|
[Try the interactive demo →](demo.ipynb)
|
||||||
|
|
||||||
|
## Wild Things You Can Do
|
||||||
|
|
||||||
|
LEANN supports RAGing a lot of data sources, like .pdf, .txt, .md, and also supports RAGing your WeChat, Google Search History, and more.
|
||||||
|
|
||||||
|
### Process Any Documents (.pdf, .txt, .md)
|
||||||
|
|
||||||
|
Above we showed the Python API, while this CLI script demonstrates the same concepts while directly processing PDFs and documents, and even any directory that stores your personal files!
|
||||||
|
|
||||||
|
The following scripts use Ollama `qwen3:8b` by default, so you need `ollama pull qwen3:8b` first. For other models: `--llm openai --model gpt-4o` (requires `OPENAI_API_KEY` environment variable) or `--llm hf --model Qwen/Qwen3-4B`.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Drop your PDFs, .txt, .md files into apps/documents/data/
|
||||||
|
python -m apps.documents
|
||||||
|
|
||||||
|
# Or with uv
|
||||||
|
uv run python -m apps.documents
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
**Works with any text format** - research papers, personal notes, presentations. Built with LlamaIndex for document parsing.
|
||||||
|
|
||||||
|
### Search Your Entire Life
|
||||||
|
```bash
|
||||||
|
python -m apps.email
|
||||||
|
# "What's the number of class recommend to take per semester for incoming EECS students?"
|
||||||
|
```
|
||||||
|
**90K emails → 14MB.** Finally, search your email like you search Google.
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary><strong>📋 Click to expand: Common Parameters (Available in All Examples)</strong></summary>
|
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
|
||||||
|
|
||||||
All RAG examples share these common parameters. **Interactive mode** is available in all examples - simply run without `--query` to start a continuous Q&A session where you can ask multiple questions. Type 'quit' to exit.
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Core Parameters (General preprocessing for all examples)
|
# Use default mail path (works for most macOS setups)
|
||||||
--index-dir DIR # Directory to store the index (default: current directory)
|
python -m apps.email
|
||||||
--query "YOUR QUESTION" # Single query mode. Omit for interactive chat (type 'quit' to exit), and now you can play with your index interactively
|
|
||||||
--max-items N # Limit data preprocessing (default: -1, process all data)
|
|
||||||
--force-rebuild # Force rebuild index even if it exists
|
|
||||||
|
|
||||||
# Embedding Parameters
|
# Run with custom index directory
|
||||||
--embedding-model MODEL # e.g., facebook/contriever, text-embedding-3-small, mlx-community/Qwen3-Embedding-0.6B-8bit or nomic-embed-text
|
python -m apps.email --index-dir "./my_mail_index"
|
||||||
--embedding-mode MODE # sentence-transformers, openai, mlx, or ollama
|
|
||||||
|
|
||||||
# LLM Parameters (Text generation models)
|
# Process all emails (may take time but indexes everything)
|
||||||
--llm TYPE # LLM backend: openai, ollama, or hf (default: openai)
|
python -m apps.email --max-emails -1
|
||||||
--llm-model MODEL # Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct
|
|
||||||
--thinking-budget LEVEL # Thinking budget for reasoning models: low/medium/high (supported by o3, o3-mini, GPT-Oss:20b, and other reasoning models)
|
|
||||||
|
|
||||||
# Search Parameters
|
# Limit number of emails processed (useful for testing)
|
||||||
--top-k N # Number of results to retrieve (default: 20)
|
python -m apps.email --max-emails 1000
|
||||||
--search-complexity N # Search complexity for graph traversal (default: 32)
|
|
||||||
|
|
||||||
# Chunking Parameters
|
# Run a single query
|
||||||
--chunk-size N # Size of text chunks (default varies by source: 256 for most, 192 for WeChat)
|
python -m apps.email --query "What did my boss say about deadlines?"
|
||||||
--chunk-overlap N # Overlap between chunks (default varies: 25-128 depending on source)
|
|
||||||
|
|
||||||
# Index Building Parameters
|
|
||||||
--backend-name NAME # Backend to use: hnsw or diskann (default: hnsw)
|
|
||||||
--graph-degree N # Graph degree for index construction (default: 32)
|
|
||||||
--build-complexity N # Build complexity for index construction (default: 64)
|
|
||||||
--compact / --no-compact # Use compact storage (default: true). Must be `no-compact` for `no-recompute` build.
|
|
||||||
--recompute / --no-recompute # Enable/disable embedding recomputation (default: enabled). Should not do a `no-recompute` search in a `recompute` build.
|
|
||||||
```
|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
### 📄 Personal Data Manager: Process Any Documents (`.pdf`, `.txt`, `.md`)!
|
|
||||||
|
|
||||||
Ask questions directly about your personal PDFs, documents, and any directory containing your files!
|
|
||||||
|
|
||||||
<p align="center">
|
|
||||||
<img src="videos/paper_clear.gif" alt="LEANN Document Search Demo" width="600">
|
|
||||||
</p>
|
|
||||||
|
|
||||||
The example below asks a question about summarizing our paper (uses default data in `data/`, which is a directory with diverse data sources: two papers, Pride and Prejudice, and a Technical report about LLM in Huawei in Chinese), and this is the **easiest example** to run here:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
source .venv/bin/activate # Don't forget to activate the virtual environment
|
|
||||||
python -m apps.document_rag --query "What are the main techniques LEANN explores?"
|
|
||||||
```
|
|
||||||
|
|
||||||
<details>
|
|
||||||
<summary><strong>📋 Click to expand: Document-Specific Arguments</strong></summary>
|
|
||||||
|
|
||||||
#### Parameters
|
|
||||||
```bash
|
|
||||||
--data-dir DIR # Directory containing documents to process (default: data)
|
|
||||||
--file-types .ext .ext # Filter by specific file types (optional - all LlamaIndex supported types if omitted)
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Example Commands
|
|
||||||
```bash
|
|
||||||
# Process all documents with larger chunks for academic papers
|
|
||||||
python -m apps.document_rag --data-dir "~/Documents/Papers" --chunk-size 1024
|
|
||||||
|
|
||||||
# Filter only markdown and Python files with smaller chunks
|
|
||||||
python -m apps.document_rag --data-dir "./docs" --chunk-size 256 --file-types .md .py
|
|
||||||
```
|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
### 📧 Your Personal Email Secretary: RAG on Apple Mail!
|
|
||||||
|
|
||||||
> **Note:** The examples below currently support macOS only. Windows support coming soon.
|
|
||||||
|
|
||||||
|
|
||||||
<p align="center">
|
|
||||||
<img src="videos/mail_clear.gif" alt="LEANN Email Search Demo" width="600">
|
|
||||||
</p>
|
|
||||||
|
|
||||||
Before running the example below, you need to grant full disk access to your terminal/VS Code in System Preferences → Privacy & Security → Full Disk Access.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python -m apps.email_rag --query "What's the food I ordered by DoorDash or Uber Eats mostly?"
|
|
||||||
```
|
|
||||||
**780K email chunks → 78MB storage.** Finally, search your email like you search Google.
|
|
||||||
|
|
||||||
<details>
|
|
||||||
<summary><strong>📋 Click to expand: Email-Specific Arguments</strong></summary>
|
|
||||||
|
|
||||||
#### Parameters
|
|
||||||
```bash
|
|
||||||
--mail-path PATH # Path to specific mail directory (auto-detects if omitted)
|
|
||||||
--include-html # Include HTML content in processing (useful for newsletters)
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Example Commands
|
|
||||||
```bash
|
|
||||||
# Search work emails from a specific account
|
|
||||||
python -m apps.email_rag --mail-path "~/Library/Mail/V10/WORK_ACCOUNT"
|
|
||||||
|
|
||||||
# Find all receipts and order confirmations (includes HTML)
|
|
||||||
python -m apps.email_rag --query "receipt order confirmation invoice" --include-html
|
|
||||||
```
|
```
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
@@ -298,32 +194,28 @@ Once the index is built, you can ask questions like:
|
|||||||
- "Show me emails about travel expenses"
|
- "Show me emails about travel expenses"
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
### 🔍 Time Machine for the Web: RAG Your Entire Chrome Browser History!
|
### Time Machine for the Web
|
||||||
|
|
||||||
<p align="center">
|
|
||||||
<img src="videos/google_clear.gif" alt="LEANN Browser History Search Demo" width="600">
|
|
||||||
</p>
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python -m apps.browser_rag --query "Tell me my browser history about machine learning?"
|
python -m apps.browser
|
||||||
|
# "Tell me my browser history about machine learning system stuff?"
|
||||||
```
|
```
|
||||||
**38K browser entries → 6MB storage.** Your browser history becomes your personal search engine.
|
**38K browser entries → 6MB.** Your browser history becomes your personal search engine.
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary><strong>📋 Click to expand: Browser-Specific Arguments</strong></summary>
|
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
|
||||||
|
|
||||||
#### Parameters
|
|
||||||
```bash
|
```bash
|
||||||
--chrome-profile PATH # Path to Chrome profile directory (auto-detects if omitted)
|
# Use default Chrome profile (auto-finds all profiles)
|
||||||
```
|
python -m apps.browser
|
||||||
|
|
||||||
#### Example Commands
|
# Run with custom index directory
|
||||||
```bash
|
python -m apps.browser --index-dir "./my_chrome_index"
|
||||||
# Search academic research from your browsing history
|
|
||||||
python -m apps.browser_rag --query "arxiv papers machine learning transformer architecture"
|
|
||||||
|
|
||||||
# Track competitor analysis across work profile
|
# Limit number of history entries processed (useful for testing)
|
||||||
python -m apps.browser_rag --chrome-profile "~/Library/Application Support/Google/Chrome/Work Profile" --max-items 5000
|
python -m apps.browser --max-entries 500
|
||||||
|
|
||||||
|
# Run a single query
|
||||||
|
python -m apps.browser --query "What websites did I visit about machine learning?"
|
||||||
```
|
```
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
@@ -356,58 +248,44 @@ Once the index is built, you can ask questions like:
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
### 💬 WeChat Detective: Unlock Your Golden Memories!
|
### WeChat Detective
|
||||||
|
|
||||||
<p align="center">
|
|
||||||
<img src="videos/wechat_clear.gif" alt="LEANN WeChat Search Demo" width="600">
|
|
||||||
</p>
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python -m apps.wechat_rag --query "Show me all group chats about weekend plans"
|
python -m apps.wechat
|
||||||
|
# "Show me all group chats about weekend plans"
|
||||||
```
|
```
|
||||||
**400K messages → 64MB storage** Search years of chat history in any language.
|
**400K messages → 64MB.** Search years of chat history in any language.
|
||||||
|
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary><strong>🔧 Click to expand: Installation Requirements</strong></summary>
|
<summary><strong>🔧 Click to expand: Installation Requirements</strong></summary>
|
||||||
|
|
||||||
First, you need to install the [WeChat exporter](https://github.com/sunnyyoung/WeChatTweak-CLI),
|
First, you need to install the WeChat exporter:
|
||||||
|
|
||||||
```bash
|
|
||||||
brew install sunnyyoung/repo/wechattweak-cli
|
|
||||||
```
|
|
||||||
|
|
||||||
or install it manually (if you have issues with Homebrew):
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
sudo packages/wechat-exporter/wechattweak-cli install
|
sudo packages/wechat-exporter/wechattweak-cli install
|
||||||
```
|
```
|
||||||
|
|
||||||
**Troubleshooting:**
|
**Troubleshooting**: If you encounter installation issues, check the [WeChatTweak-CLI issues page](https://github.com/sunnyyoung/WeChatTweak-CLI/issues/41).
|
||||||
- **Installation issues**: Check the [WeChatTweak-CLI issues page](https://github.com/sunnyyoung/WeChatTweak-CLI/issues/41)
|
|
||||||
- **Export errors**: If you encounter the error below, try restarting WeChat
|
|
||||||
```bash
|
|
||||||
Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.
|
|
||||||
Failed to find or export WeChat data. Exiting.
|
|
||||||
```
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary><strong>📋 Click to expand: WeChat-Specific Arguments</strong></summary>
|
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
|
||||||
|
|
||||||
#### Parameters
|
|
||||||
```bash
|
```bash
|
||||||
--export-dir DIR # Directory to store exported WeChat data (default: wechat_export_direct)
|
# Use default settings (recommended for first run)
|
||||||
--force-export # Force re-export even if data exists
|
python -m apps.wechat
|
||||||
```
|
|
||||||
|
|
||||||
#### Example Commands
|
# Run with custom export directory and wehn we run the first time, LEANN will export all chat history automatically for you
|
||||||
```bash
|
python -m apps.wechat --export-dir "./my_wechat_exports"
|
||||||
# Search for travel plans discussed in group chats
|
|
||||||
python -m apps.wechat_rag --query "travel plans" --max-items 10000
|
|
||||||
|
|
||||||
# Re-export and search recent chats (useful after new messages)
|
# Run with custom index directory
|
||||||
python -m apps.wechat_rag --force-export --query "work schedule"
|
python -m apps.wechat --index-dir "./my_wechat_index"
|
||||||
|
|
||||||
|
# Limit number of chat entries processed (useful for testing)
|
||||||
|
python -m apps.wechat --max-entries 1000
|
||||||
|
|
||||||
|
# Run a single query
|
||||||
|
python -m apps.wechat --query "Show me conversations about travel plans"
|
||||||
```
|
```
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
@@ -421,57 +299,15 @@ Once the index is built, you can ask questions like:
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
### 🚀 Claude Code Integration: Transform Your Development Workflow!
|
|
||||||
|
|
||||||
**The future of code assistance is here.** Transform your development workflow with LEANN's native MCP integration for Claude Code. Index your entire codebase and get intelligent code assistance directly in your IDE.
|
|
||||||
|
|
||||||
**Key features:**
|
|
||||||
- 🔍 **Semantic code search** across your entire project
|
|
||||||
- 📚 **Context-aware assistance** for debugging and development
|
|
||||||
- 🚀 **Zero-config setup** with automatic language detection
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Install LEANN globally for MCP integration
|
|
||||||
uv tool install leann-core
|
|
||||||
|
|
||||||
# Setup is automatic - just start using Claude Code!
|
|
||||||
```
|
|
||||||
Try our fully agentic pipeline with auto query rewriting, semantic search planning, and more:
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
**Ready to supercharge your coding?** [Complete Setup Guide →](packages/leann-mcp/README.md)
|
|
||||||
|
|
||||||
## 🖥️ Command Line Interface
|
## 🖥️ Command Line Interface
|
||||||
|
|
||||||
LEANN includes a powerful CLI for document processing and search. Perfect for quick document indexing and interactive chat.
|
LEANN includes a powerful CLI for document processing and search. Perfect for quick document indexing and interactive chat.
|
||||||
|
|
||||||
### Installation
|
|
||||||
|
|
||||||
If you followed the Quick Start, `leann` is already installed in your virtual environment:
|
|
||||||
```bash
|
```bash
|
||||||
source .venv/bin/activate
|
# Build an index from documents
|
||||||
leann --help
|
leann build my-docs --docs ./documents
|
||||||
```
|
|
||||||
|
|
||||||
**To make it globally available:**
|
|
||||||
```bash
|
|
||||||
# Install the LEANN CLI globally using uv tool
|
|
||||||
uv tool install leann-core
|
|
||||||
|
|
||||||
# Now you can use leann from anywhere without activating venv
|
|
||||||
leann --help
|
|
||||||
```
|
|
||||||
|
|
||||||
> **Note**: Global installation is required for Claude Code integration. The `leann_mcp` server depends on the globally available `leann` command.
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
### Usage Examples
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# build from a specific directory, and my_docs is the index name(Here you can also build from multiple dict or multiple files)
|
|
||||||
leann build my-docs --docs ./your_documents
|
|
||||||
|
|
||||||
# Search your documents
|
# Search your documents
|
||||||
leann search my-docs "machine learning concepts"
|
leann search my-docs "machine learning concepts"
|
||||||
@@ -484,29 +320,27 @@ leann list
|
|||||||
```
|
```
|
||||||
|
|
||||||
**Key CLI features:**
|
**Key CLI features:**
|
||||||
- Auto-detects document formats (PDF, TXT, MD, DOCX, PPTX + code files)
|
- Auto-detects document formats (PDF, TXT, MD, DOCX)
|
||||||
- Smart text chunking with overlap
|
- Smart text chunking with overlap
|
||||||
- Multiple LLM providers (Ollama, OpenAI, HuggingFace)
|
- Multiple LLM providers (Ollama, OpenAI, HuggingFace)
|
||||||
- Organized index storage in `.leann/indexes/` (project-local)
|
- Organized index storage in `~/.leann/indexes/`
|
||||||
- Support for advanced search parameters
|
- Support for advanced search parameters
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary><strong>📋 Click to expand: Complete CLI Reference</strong></summary>
|
<summary><strong>📋 Click to expand: Complete CLI Reference</strong></summary>
|
||||||
|
|
||||||
You can use `leann --help`, or `leann build --help`, `leann search --help`, `leann ask --help` to get the complete CLI reference.
|
|
||||||
|
|
||||||
**Build Command:**
|
**Build Command:**
|
||||||
```bash
|
```bash
|
||||||
leann build INDEX_NAME --docs DIRECTORY|FILE [DIRECTORY|FILE ...] [OPTIONS]
|
leann build INDEX_NAME --docs DIRECTORY [OPTIONS]
|
||||||
|
|
||||||
Options:
|
Options:
|
||||||
--backend {hnsw,diskann} Backend to use (default: hnsw)
|
--backend {hnsw,diskann} Backend to use (default: hnsw)
|
||||||
--embedding-model MODEL Embedding model (default: facebook/contriever)
|
--embedding-model MODEL Embedding model (default: facebook/contriever)
|
||||||
--graph-degree N Graph degree (default: 32)
|
--graph-degree N Graph degree (default: 32)
|
||||||
--complexity N Build complexity (default: 64)
|
--complexity N Build complexity (default: 64)
|
||||||
--force Force rebuild existing index
|
--force Force rebuild existing index
|
||||||
--compact / --no-compact Use compact storage (default: true). Must be `no-compact` for `no-recompute` build.
|
--compact Use compact storage (default: true)
|
||||||
--recompute / --no-recompute Enable recomputation (default: true)
|
--recompute Enable recomputation (default: true)
|
||||||
```
|
```
|
||||||
|
|
||||||
**Search Command:**
|
**Search Command:**
|
||||||
@@ -514,9 +348,9 @@ Options:
|
|||||||
leann search INDEX_NAME QUERY [OPTIONS]
|
leann search INDEX_NAME QUERY [OPTIONS]
|
||||||
|
|
||||||
Options:
|
Options:
|
||||||
--top-k N Number of results (default: 5)
|
--top-k N Number of results (default: 5)
|
||||||
--complexity N Search complexity (default: 64)
|
--complexity N Search complexity (default: 64)
|
||||||
--recompute / --no-recompute Enable/disable embedding recomputation (default: enabled). Should not do a `no-recompute` search in a `recompute` build.
|
--recompute-embeddings Use recomputation for highest accuracy
|
||||||
--pruning-strategy {global,local,proportional}
|
--pruning-strategy {global,local,proportional}
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -547,31 +381,56 @@ Options:
|
|||||||
- **Dynamic batching:** Efficiently batch embedding computations for GPU utilization
|
- **Dynamic batching:** Efficiently batch embedding computations for GPU utilization
|
||||||
- **Two-level search:** Smart graph traversal that prioritizes promising nodes
|
- **Two-level search:** Smart graph traversal that prioritizes promising nodes
|
||||||
|
|
||||||
**Backends:**
|
**Backends:** DiskANN or HNSW - pick what works for your data size.
|
||||||
- **HNSW** (default): Ideal for most datasets with maximum storage savings through full recomputation
|
|
||||||
- **DiskANN**: Advanced option with superior search performance, using PQ-based graph traversal with real-time reranking for the best speed-accuracy trade-off
|
|
||||||
|
|
||||||
## Benchmarks
|
## Benchmarks
|
||||||
|
|
||||||
**[DiskANN vs HNSW Performance Comparison →](benchmarks/diskann_vs_hnsw_speed_comparison.py)** - Compare search performance between both backends
|
Run the comparison yourself:
|
||||||
|
```bash
|
||||||
|
python -m apps.benchmarks
|
||||||
|
```
|
||||||
|
|
||||||
**[Simple Example: Compare LEANN vs FAISS →](benchmarks/compare_faiss_vs_leann.py)** - See storage savings in action
|
| System | Storage |
|
||||||
|
|--------|---------|
|
||||||
|
| FAISS HNSW | 5.5 MB |
|
||||||
|
| LEANN | 0.5 MB |
|
||||||
|
| **Savings** | **91%** |
|
||||||
|
|
||||||
### 📊 Storage Comparison
|
Same dataset, same hardware, same embedding model. LEANN just works better.
|
||||||
|
|
||||||
| System | DPR (2.1M) | Wiki (60M) | Chat (400K) | Email (780K) | Browser (38K) |
|
|
||||||
|--------|-------------|------------|-------------|--------------|---------------|
|
|
||||||
| Traditional vector database (e.g., FAISS) | 3.8 GB | 201 GB | 1.8 GB | 2.4 GB | 130 MB |
|
|
||||||
| LEANN | 324 MB | 6 GB | 64 MB | 79 MB | 6.4 MB |
|
|
||||||
| Savings| 91% | 97% | 97% | 97% | 95% |
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
### Storage Usage Comparison
|
||||||
|
|
||||||
|
| System | DPR (2.1M chunks) | RPJ-wiki (60M chunks) | Chat history (400K messages) | Apple emails (90K messages chunks) |Google Search History (38K entries)
|
||||||
|
|-----------------------|------------------|------------------------|-----------------------------|------------------------------|------------------------------|
|
||||||
|
| Traditional Vector DB(FAISS) | 3.8 GB | 201 GB | 1.8G | 305.8 MB |130.4 MB |
|
||||||
|
| **LEANN** | **324 MB** | **6 GB** | **64 MB** | **14.8 MB** |**6.4MB** |
|
||||||
|
| **Reduction** | **91% smaller** | **97% smaller** | **97% smaller** | **95% smaller** |**95% smaller** |
|
||||||
|
|
||||||
|
<!-- ### Memory Usage Comparison
|
||||||
|
|
||||||
|
| System j | DPR(2M docs) | RPJ-wiki(60M docs) | Chat history() |
|
||||||
|
| --------------------- | ---------------- | ---------------- | ---------------- |
|
||||||
|
| Traditional Vector DB(LLamaindex faiss) | x GB | x GB | x GB |
|
||||||
|
| **Leann** | **xx MB** | **x GB** | **x GB** |
|
||||||
|
| **Reduction** | **x%** | **x%** | **x%** |
|
||||||
|
|
||||||
|
### Query Performance of LEANN
|
||||||
|
|
||||||
|
| Backend | Index Size | Query Time | Recall@3 |
|
||||||
|
| ------------------- | ---------- | ---------- | --------- |
|
||||||
|
| DiskANN | 1M docs | xms | 0.95 |
|
||||||
|
| HNSW | 1M docs | xms | 0.95 | -->
|
||||||
|
|
||||||
|
*Benchmarks run on Apple M3 Pro 36 GB*
|
||||||
|
|
||||||
## Reproduce Our Results
|
## Reproduce Our Results
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
uv pip install -e ".[dev]" # Install dev dependencies
|
uv pip install -e ".[dev]" # Install dev dependencies
|
||||||
python benchmarks/run_evaluation.py # Will auto-download evaluation data and run benchmarks
|
python -m apps.evaluation data/indices/dpr/dpr_diskann # DPR dataset
|
||||||
|
python -m apps.evaluation data/indices/rpj_wiki/rpj_wiki.index # Wikipedia
|
||||||
```
|
```
|
||||||
|
|
||||||
The evaluation script downloads data automatically on first run. The last three results were tested with partial personal data, and you can reproduce them with your own data!
|
The evaluation script downloads data automatically on first run. The last three results were tested with partial personal data, and you can reproduce them with your own data!
|
||||||
@@ -593,15 +452,98 @@ If you find Leann useful, please cite:
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
## ✨ [Detailed Features →](docs/features.md)
|
## ✨ Features
|
||||||
|
|
||||||
## 🤝 [CONTRIBUTING →](docs/CONTRIBUTING.md)
|
### 🔥 Core Features
|
||||||
|
|
||||||
|
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage with dynamic computation using optimized ZMQ servers and highly optimized search paradigm (overlapping and batching) with highly optimized embedding engine
|
||||||
|
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
|
||||||
|
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
|
||||||
|
- **🏗️ Pluggable Backends** - DiskANN, HNSW/FAISS with unified API
|
||||||
|
|
||||||
|
### 🛠️ Technical Highlights
|
||||||
|
- **🔄 Recompute Mode** - Highest accuracy scenarios while eliminating vector storage overhead
|
||||||
|
- **⚡ Zero-copy Operations** - Minimize IPC overhead by transferring distances instead of embeddings
|
||||||
|
- **🚀 High-throughput Embedding Pipeline** - Optimized batched processing for maximum efficiency
|
||||||
|
- **🎯 Two-level Search** - Novel coarse-to-fine search overlap for accelerated query processing (optional)
|
||||||
|
- **💾 Memory-mapped Indices** - Fast startup with raw text mapping to reduce memory overhead
|
||||||
|
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](test/build_mlx_index.py))
|
||||||
|
|
||||||
|
### 🎨 Developer Experience
|
||||||
|
|
||||||
|
- **Simple Python API** - Get started in minutes
|
||||||
|
- **Extensible backend system** - Easy to add new algorithms
|
||||||
|
- **Comprehensive examples** - From basic usage to production deployment
|
||||||
|
|
||||||
|
## 🤝 Contributing
|
||||||
|
|
||||||
|
We welcome contributions! Leann is built by the community, for the community.
|
||||||
|
|
||||||
|
### Ways to Contribute
|
||||||
|
|
||||||
|
- 🐛 **Bug Reports**: Found an issue? Let us know!
|
||||||
|
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
|
||||||
|
- 🔧 **Code Contributions**: PRs welcome for all skill levels
|
||||||
|
- 📖 **Documentation**: Help make Leann more accessible
|
||||||
|
- 🧪 **Benchmarks**: Share your performance results
|
||||||
|
|
||||||
|
|
||||||
## ❓ [FAQ →](docs/faq.md)
|
<!-- ## ❓ FAQ
|
||||||
|
|
||||||
|
### Common Issues
|
||||||
|
|
||||||
|
#### NCCL Topology Error
|
||||||
|
|
||||||
|
**Problem**: You encounter `ncclTopoComputePaths` error during document processing:
|
||||||
|
|
||||||
|
```
|
||||||
|
ncclTopoComputePaths (system=<optimized out>, comm=comm@entry=0x5555a82fa3c0) at graph/paths.cc:688
|
||||||
|
```
|
||||||
|
|
||||||
|
**Solution**: Set these environment variables before running your script:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export NCCL_TOPO_DUMP_FILE=/tmp/nccl_topo.xml
|
||||||
|
export NCCL_DEBUG=INFO
|
||||||
|
export NCCL_DEBUG_SUBSYS=INIT,GRAPH
|
||||||
|
export NCCL_IB_DISABLE=1
|
||||||
|
export NCCL_NET_PLUGIN=none
|
||||||
|
export NCCL_SOCKET_IFNAME=ens5
|
||||||
|
``` -->
|
||||||
|
## FAQ
|
||||||
|
|
||||||
|
### 1. My building time seems long
|
||||||
|
|
||||||
|
You can speed up the process by using a lightweight embedding model. Add this to your arguments:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
--embedding-model sentence-transformers/all-MiniLM-L6-v2
|
||||||
|
```
|
||||||
|
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)
|
||||||
|
|
||||||
|
|
||||||
## 📈 [Roadmap →](docs/roadmap.md)
|
## 📈 Roadmap
|
||||||
|
|
||||||
|
### 🎯 Q2 2025
|
||||||
|
|
||||||
|
- [X] DiskANN backend with MIPS/L2/Cosine support
|
||||||
|
- [X] HNSW backend integration
|
||||||
|
- [X] Real-time embedding pipeline
|
||||||
|
- [X] Memory-efficient graph pruning
|
||||||
|
|
||||||
|
### 🚀 Q3 2025
|
||||||
|
|
||||||
|
|
||||||
|
- [ ] Advanced caching strategies
|
||||||
|
- [ ] Add contextual-retrieval https://www.anthropic.com/news/contextual-retrieval
|
||||||
|
- [ ] Add sleep-time-compute and summarize agent! to summarilze the file on computer!
|
||||||
|
- [ ] Add OpenAI recompute API
|
||||||
|
|
||||||
|
### 🌟 Q4 2025
|
||||||
|
|
||||||
|
- [ ] Integration with LangChain/LlamaIndex
|
||||||
|
- [ ] Visual similarity search
|
||||||
|
- [ ] Query rewrtiting, rerank and expansion
|
||||||
|
|
||||||
## 📄 License
|
## 📄 License
|
||||||
|
|
||||||
@@ -609,15 +551,13 @@ MIT License - see [LICENSE](LICENSE) for details.
|
|||||||
|
|
||||||
## 🙏 Acknowledgments
|
## 🙏 Acknowledgments
|
||||||
|
|
||||||
Core Contributors: [Yichuan Wang](https://yichuan-w.github.io/) & [Zhifei Li](https://github.com/andylizf).
|
- **Microsoft Research** for the DiskANN algorithm
|
||||||
|
- **Meta AI** for FAISS and optimization insights
|
||||||
|
- **HuggingFace** for the transformer ecosystem
|
||||||
|
- **Our amazing contributors** who make this possible
|
||||||
|
|
||||||
We welcome more contributors! Feel free to open issues or submit PRs.
|
---
|
||||||
|
|
||||||
This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/).
|
|
||||||
|
|
||||||
## Star History
|
|
||||||
|
|
||||||
[](https://www.star-history.com/#yichuan-w/LEANN&Date)
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<strong>⭐ Star us on GitHub if Leann is useful for your research or applications!</strong>
|
<strong>⭐ Star us on GitHub if Leann is useful for your research or applications!</strong>
|
||||||
</p>
|
</p>
|
||||||
@@ -625,3 +565,4 @@ This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.ed
|
|||||||
<p align="center">
|
<p align="center">
|
||||||
Made with ❤️ by the Leann team
|
Made with ❤️ by the Leann team
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
|
|||||||
@@ -1,324 +0,0 @@
|
|||||||
"""
|
|
||||||
Base class for unified RAG examples interface.
|
|
||||||
Provides common parameters and functionality for all RAG examples.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import dotenv
|
|
||||||
from leann.api import LeannBuilder, LeannChat
|
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
|
||||||
|
|
||||||
dotenv.load_dotenv()
|
|
||||||
|
|
||||||
|
|
||||||
class BaseRAGExample(ABC):
|
|
||||||
"""Base class for all RAG examples with unified interface."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
description: str,
|
|
||||||
default_index_name: str,
|
|
||||||
):
|
|
||||||
self.name = name
|
|
||||||
self.description = description
|
|
||||||
self.default_index_name = default_index_name
|
|
||||||
self.parser = self._create_parser()
|
|
||||||
|
|
||||||
def _create_parser(self) -> argparse.ArgumentParser:
|
|
||||||
"""Create argument parser with common parameters."""
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description=self.description, formatter_class=argparse.RawDescriptionHelpFormatter
|
|
||||||
)
|
|
||||||
|
|
||||||
# Core parameters (all examples share these)
|
|
||||||
core_group = parser.add_argument_group("Core Parameters")
|
|
||||||
core_group.add_argument(
|
|
||||||
"--index-dir",
|
|
||||||
type=str,
|
|
||||||
default=f"./{self.default_index_name}",
|
|
||||||
help=f"Directory to store the index (default: ./{self.default_index_name})",
|
|
||||||
)
|
|
||||||
core_group.add_argument(
|
|
||||||
"--query",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="Query to run (if not provided, will run in interactive mode)",
|
|
||||||
)
|
|
||||||
# Allow subclasses to override default max_items
|
|
||||||
max_items_default = getattr(self, "max_items_default", -1)
|
|
||||||
core_group.add_argument(
|
|
||||||
"--max-items",
|
|
||||||
type=int,
|
|
||||||
default=max_items_default,
|
|
||||||
help="Maximum number of items to process -1 for all, means index all documents, and you should set it to a reasonable number if you have a large dataset and try at the first time)",
|
|
||||||
)
|
|
||||||
core_group.add_argument(
|
|
||||||
"--force-rebuild", action="store_true", help="Force rebuild index even if it exists"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Embedding parameters
|
|
||||||
embedding_group = parser.add_argument_group("Embedding Parameters")
|
|
||||||
# Allow subclasses to override default embedding_model
|
|
||||||
embedding_model_default = getattr(self, "embedding_model_default", "facebook/contriever")
|
|
||||||
embedding_group.add_argument(
|
|
||||||
"--embedding-model",
|
|
||||||
type=str,
|
|
||||||
default=embedding_model_default,
|
|
||||||
help=f"Embedding model to use (default: {embedding_model_default}), we provide facebook/contriever, text-embedding-3-small,mlx-community/Qwen3-Embedding-0.6B-8bit or nomic-embed-text",
|
|
||||||
)
|
|
||||||
embedding_group.add_argument(
|
|
||||||
"--embedding-mode",
|
|
||||||
type=str,
|
|
||||||
default="sentence-transformers",
|
|
||||||
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
|
||||||
help="Embedding backend mode (default: sentence-transformers), we provide sentence-transformers, openai, mlx, or ollama",
|
|
||||||
)
|
|
||||||
|
|
||||||
# LLM parameters
|
|
||||||
llm_group = parser.add_argument_group("LLM Parameters")
|
|
||||||
llm_group.add_argument(
|
|
||||||
"--llm",
|
|
||||||
type=str,
|
|
||||||
default="openai",
|
|
||||||
choices=["openai", "ollama", "hf", "simulated"],
|
|
||||||
help="LLM backend: openai, ollama, or hf (default: openai)",
|
|
||||||
)
|
|
||||||
llm_group.add_argument(
|
|
||||||
"--llm-model",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct",
|
|
||||||
)
|
|
||||||
llm_group.add_argument(
|
|
||||||
"--llm-host",
|
|
||||||
type=str,
|
|
||||||
default="http://localhost:11434",
|
|
||||||
help="Host for Ollama API (default: http://localhost:11434)",
|
|
||||||
)
|
|
||||||
llm_group.add_argument(
|
|
||||||
"--thinking-budget",
|
|
||||||
type=str,
|
|
||||||
choices=["low", "medium", "high"],
|
|
||||||
default=None,
|
|
||||||
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Search parameters
|
|
||||||
search_group = parser.add_argument_group("Search Parameters")
|
|
||||||
search_group.add_argument(
|
|
||||||
"--top-k", type=int, default=20, help="Number of results to retrieve (default: 20)"
|
|
||||||
)
|
|
||||||
search_group.add_argument(
|
|
||||||
"--search-complexity",
|
|
||||||
type=int,
|
|
||||||
default=32,
|
|
||||||
help="Search complexity for graph traversal (default: 64)",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Index building parameters
|
|
||||||
index_group = parser.add_argument_group("Index Building Parameters")
|
|
||||||
index_group.add_argument(
|
|
||||||
"--backend-name",
|
|
||||||
type=str,
|
|
||||||
default="hnsw",
|
|
||||||
choices=["hnsw", "diskann"],
|
|
||||||
help="Backend to use for index (default: hnsw)",
|
|
||||||
)
|
|
||||||
index_group.add_argument(
|
|
||||||
"--graph-degree",
|
|
||||||
type=int,
|
|
||||||
default=32,
|
|
||||||
help="Graph degree for index construction (default: 32)",
|
|
||||||
)
|
|
||||||
index_group.add_argument(
|
|
||||||
"--build-complexity",
|
|
||||||
type=int,
|
|
||||||
default=64,
|
|
||||||
help="Build complexity for index construction (default: 64)",
|
|
||||||
)
|
|
||||||
index_group.add_argument(
|
|
||||||
"--no-compact",
|
|
||||||
action="store_true",
|
|
||||||
help="Disable compact index storage",
|
|
||||||
)
|
|
||||||
index_group.add_argument(
|
|
||||||
"--no-recompute",
|
|
||||||
action="store_true",
|
|
||||||
help="Disable embedding recomputation",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add source-specific parameters
|
|
||||||
self._add_specific_arguments(parser)
|
|
||||||
|
|
||||||
return parser
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
|
|
||||||
"""Add source-specific arguments. Override in subclasses."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def load_data(self, args) -> list[str]:
|
|
||||||
"""Load data from the source. Returns list of text chunks."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def get_llm_config(self, args) -> dict[str, Any]:
|
|
||||||
"""Get LLM configuration based on arguments."""
|
|
||||||
config = {"type": args.llm}
|
|
||||||
|
|
||||||
if args.llm == "openai":
|
|
||||||
config["model"] = args.llm_model or "gpt-4o"
|
|
||||||
elif args.llm == "ollama":
|
|
||||||
config["model"] = args.llm_model or "llama3.2:1b"
|
|
||||||
config["host"] = args.llm_host
|
|
||||||
elif args.llm == "hf":
|
|
||||||
config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
|
|
||||||
elif args.llm == "simulated":
|
|
||||||
# Simulated LLM doesn't need additional configuration
|
|
||||||
pass
|
|
||||||
|
|
||||||
return config
|
|
||||||
|
|
||||||
async def build_index(self, args, texts: list[str]) -> str:
|
|
||||||
"""Build LEANN index from texts."""
|
|
||||||
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
|
|
||||||
|
|
||||||
print(f"\n[Building Index] Creating {self.name} index...")
|
|
||||||
print(f"Total text chunks: {len(texts)}")
|
|
||||||
|
|
||||||
builder = LeannBuilder(
|
|
||||||
backend_name=args.backend_name,
|
|
||||||
embedding_model=args.embedding_model,
|
|
||||||
embedding_mode=args.embedding_mode,
|
|
||||||
graph_degree=args.graph_degree,
|
|
||||||
complexity=args.build_complexity,
|
|
||||||
is_compact=not args.no_compact,
|
|
||||||
is_recompute=not args.no_recompute,
|
|
||||||
num_threads=1, # Force single-threaded mode
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add texts in batches for better progress tracking
|
|
||||||
batch_size = 1000
|
|
||||||
for i in range(0, len(texts), batch_size):
|
|
||||||
batch = texts[i : i + batch_size]
|
|
||||||
for text in batch:
|
|
||||||
builder.add_text(text)
|
|
||||||
print(f"Added {min(i + batch_size, len(texts))}/{len(texts)} texts...")
|
|
||||||
|
|
||||||
print("Building index structure...")
|
|
||||||
builder.build_index(index_path)
|
|
||||||
print(f"Index saved to: {index_path}")
|
|
||||||
|
|
||||||
return index_path
|
|
||||||
|
|
||||||
async def run_interactive_chat(self, args, index_path: str):
|
|
||||||
"""Run interactive chat with the index."""
|
|
||||||
chat = LeannChat(
|
|
||||||
index_path,
|
|
||||||
llm_config=self.get_llm_config(args),
|
|
||||||
system_prompt=f"You are a helpful assistant that answers questions about {self.name} data.",
|
|
||||||
complexity=args.search_complexity,
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"\n[Interactive Mode] Chat with your {self.name} data!")
|
|
||||||
print("Type 'quit' or 'exit' to stop.\n")
|
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
query = input("You: ").strip()
|
|
||||||
if query.lower() in ["quit", "exit", "q"]:
|
|
||||||
print("Goodbye!")
|
|
||||||
break
|
|
||||||
|
|
||||||
if not query:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Prepare LLM kwargs with thinking budget if specified
|
|
||||||
llm_kwargs = {}
|
|
||||||
if hasattr(args, "thinking_budget") and args.thinking_budget:
|
|
||||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
|
||||||
|
|
||||||
response = chat.ask(
|
|
||||||
query,
|
|
||||||
top_k=args.top_k,
|
|
||||||
complexity=args.search_complexity,
|
|
||||||
llm_kwargs=llm_kwargs,
|
|
||||||
)
|
|
||||||
print(f"\nAssistant: {response}\n")
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print("\nGoodbye!")
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error: {e}")
|
|
||||||
|
|
||||||
async def run_single_query(self, args, index_path: str, query: str):
|
|
||||||
"""Run a single query against the index."""
|
|
||||||
chat = LeannChat(
|
|
||||||
index_path,
|
|
||||||
llm_config=self.get_llm_config(args),
|
|
||||||
system_prompt=f"You are a helpful assistant that answers questions about {self.name} data.",
|
|
||||||
complexity=args.search_complexity,
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"\n[Query]: \033[36m{query}\033[0m")
|
|
||||||
|
|
||||||
# Prepare LLM kwargs with thinking budget if specified
|
|
||||||
llm_kwargs = {}
|
|
||||||
if hasattr(args, "thinking_budget") and args.thinking_budget:
|
|
||||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
|
||||||
|
|
||||||
response = chat.ask(
|
|
||||||
query, top_k=args.top_k, complexity=args.search_complexity, llm_kwargs=llm_kwargs
|
|
||||||
)
|
|
||||||
print(f"\n[Response]: \033[36m{response}\033[0m")
|
|
||||||
|
|
||||||
async def run(self):
|
|
||||||
"""Main entry point for the example."""
|
|
||||||
args = self.parser.parse_args()
|
|
||||||
|
|
||||||
# Check if index exists
|
|
||||||
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
|
|
||||||
index_exists = Path(args.index_dir).exists()
|
|
||||||
|
|
||||||
if not index_exists or args.force_rebuild:
|
|
||||||
# Load data and build index
|
|
||||||
print(f"\n{'Rebuilding' if index_exists else 'Building'} index...")
|
|
||||||
texts = await self.load_data(args)
|
|
||||||
|
|
||||||
if not texts:
|
|
||||||
print("No data found to index!")
|
|
||||||
return
|
|
||||||
|
|
||||||
index_path = await self.build_index(args, texts)
|
|
||||||
else:
|
|
||||||
print(f"\nUsing existing index in {args.index_dir}")
|
|
||||||
|
|
||||||
# Run query or interactive mode
|
|
||||||
if args.query:
|
|
||||||
await self.run_single_query(args, index_path, args.query)
|
|
||||||
else:
|
|
||||||
await self.run_interactive_chat(args, index_path)
|
|
||||||
|
|
||||||
|
|
||||||
def create_text_chunks(documents, chunk_size=256, chunk_overlap=25) -> list[str]:
|
|
||||||
"""Helper function to create text chunks from documents."""
|
|
||||||
node_parser = SentenceSplitter(
|
|
||||||
chunk_size=chunk_size,
|
|
||||||
chunk_overlap=chunk_overlap,
|
|
||||||
separator=" ",
|
|
||||||
paragraph_separator="\n\n",
|
|
||||||
)
|
|
||||||
|
|
||||||
all_texts = []
|
|
||||||
for doc in documents:
|
|
||||||
nodes = node_parser.get_nodes_from_documents([doc])
|
|
||||||
if nodes:
|
|
||||||
all_texts.extend(node.get_content() for node in nodes)
|
|
||||||
|
|
||||||
return all_texts
|
|
||||||
0
apps/benchmarks/__init__.py
Normal file
0
apps/benchmarks/__init__.py
Normal file
338
apps/benchmarks/__main__.py
Normal file
338
apps/benchmarks/__main__.py
Normal file
@@ -0,0 +1,338 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Memory comparison between Faiss HNSW and LEANN HNSW backend
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import psutil
|
||||||
|
import gc
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_memory_usage():
|
||||||
|
"""Get current memory usage in MB"""
|
||||||
|
process = psutil.Process()
|
||||||
|
return process.memory_info().rss / 1024 / 1024
|
||||||
|
|
||||||
|
|
||||||
|
def print_memory_stats(stage: str, start_mem: float):
|
||||||
|
"""Print memory statistics"""
|
||||||
|
current_mem = get_memory_usage()
|
||||||
|
diff = current_mem - start_mem
|
||||||
|
print(f"[{stage}] Memory: {current_mem:.1f} MB (+{diff:.1f} MB)")
|
||||||
|
return current_mem
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryTracker:
|
||||||
|
def __init__(self, name: str):
|
||||||
|
self.name = name
|
||||||
|
self.start_mem = get_memory_usage()
|
||||||
|
self.stages = []
|
||||||
|
|
||||||
|
def checkpoint(self, stage: str):
|
||||||
|
current_mem = print_memory_stats(f"{self.name} - {stage}", self.start_mem)
|
||||||
|
self.stages.append((stage, current_mem))
|
||||||
|
return current_mem
|
||||||
|
|
||||||
|
def summary(self):
|
||||||
|
print(f"\n=== {self.name} Memory Summary ===")
|
||||||
|
for stage, mem in self.stages:
|
||||||
|
print(f"{stage}: {mem:.1f} MB")
|
||||||
|
peak_mem = max(mem for _, mem in self.stages)
|
||||||
|
print(f"Peak Memory: {peak_mem:.1f} MB")
|
||||||
|
print(f"Total Memory Increase: {peak_mem - self.start_mem:.1f} MB")
|
||||||
|
return peak_mem
|
||||||
|
|
||||||
|
|
||||||
|
def test_faiss_hnsw():
|
||||||
|
"""Test Faiss HNSW Vector Store in subprocess"""
|
||||||
|
print("\n" + "=" * 50)
|
||||||
|
print("TESTING FAISS HNSW VECTOR STORE")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get the directory of this script
|
||||||
|
script_dir = Path(__file__).parent
|
||||||
|
faiss_script = script_dir / "faiss_only.py"
|
||||||
|
result = subprocess.run(
|
||||||
|
[sys.executable, str(faiss_script)],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=300,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(result.stdout)
|
||||||
|
if result.stderr:
|
||||||
|
print("Stderr:", result.stderr)
|
||||||
|
|
||||||
|
if result.returncode != 0:
|
||||||
|
return {
|
||||||
|
"peak_memory": float("inf"),
|
||||||
|
"error": f"Process failed with code {result.returncode}",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Parse peak memory from output
|
||||||
|
lines = result.stdout.split("\n")
|
||||||
|
peak_memory = 0.0
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
if "Peak Memory:" in line:
|
||||||
|
peak_memory = float(
|
||||||
|
line.split("Peak Memory:")[1].split("MB")[0].strip()
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"peak_memory": peak_memory}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"peak_memory": float("inf"),
|
||||||
|
"error": str(e),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_leann_hnsw():
|
||||||
|
"""Test LEANN HNSW Search Memory (load existing index)"""
|
||||||
|
print("\n" + "=" * 50)
|
||||||
|
print("TESTING LEANN HNSW SEARCH MEMORY")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
tracker = MemoryTracker("LEANN HNSW Search")
|
||||||
|
|
||||||
|
# Import and setup
|
||||||
|
tracker.checkpoint("Initial")
|
||||||
|
|
||||||
|
from leann.api import LeannSearcher
|
||||||
|
|
||||||
|
tracker.checkpoint("After imports")
|
||||||
|
|
||||||
|
from llama_index.core import SimpleDirectoryReader
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher
|
||||||
|
|
||||||
|
|
||||||
|
# Load and parse documents
|
||||||
|
documents = SimpleDirectoryReader(
|
||||||
|
"../documents/data",
|
||||||
|
recursive=True,
|
||||||
|
encoding="utf-8",
|
||||||
|
required_exts=[".pdf", ".txt", ".md"],
|
||||||
|
).load_data()
|
||||||
|
|
||||||
|
tracker.checkpoint("After document loading")
|
||||||
|
|
||||||
|
# Parse into chunks
|
||||||
|
node_parser = SentenceSplitter(
|
||||||
|
chunk_size=256, chunk_overlap=20, separator=" ", paragraph_separator="\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
all_texts = []
|
||||||
|
for doc in documents:
|
||||||
|
nodes = node_parser.get_nodes_from_documents([doc])
|
||||||
|
for node in nodes:
|
||||||
|
all_texts.append(node.get_content())
|
||||||
|
|
||||||
|
tracker.checkpoint("After text chunking")
|
||||||
|
|
||||||
|
# Build LEANN index
|
||||||
|
INDEX_DIR = Path("./test_leann_comparison")
|
||||||
|
INDEX_PATH = str(INDEX_DIR / "comparison.leann")
|
||||||
|
|
||||||
|
# Check if index already exists
|
||||||
|
if os.path.exists(INDEX_PATH + ".meta.json"):
|
||||||
|
print("Loading existing LEANN HNSW index...")
|
||||||
|
tracker.checkpoint("After loading existing index")
|
||||||
|
else:
|
||||||
|
print("Building new LEANN HNSW index...")
|
||||||
|
# Clean up previous index
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
if INDEX_DIR.exists():
|
||||||
|
shutil.rmtree(INDEX_DIR)
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="facebook/contriever",
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
is_compact=True,
|
||||||
|
is_recompute=True,
|
||||||
|
num_threads=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
tracker.checkpoint("After builder setup")
|
||||||
|
|
||||||
|
print("Building LEANN HNSW index...")
|
||||||
|
|
||||||
|
for chunk_text in all_texts:
|
||||||
|
builder.add_text(chunk_text)
|
||||||
|
|
||||||
|
builder.build_index(INDEX_PATH)
|
||||||
|
del builder
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
tracker.checkpoint("After index building")
|
||||||
|
|
||||||
|
# Find existing LEANN index
|
||||||
|
index_paths = [
|
||||||
|
"./test_leann_comparison/comparison.leann",
|
||||||
|
]
|
||||||
|
index_path = None
|
||||||
|
for path in index_paths:
|
||||||
|
if os.path.exists(path + ".meta.json"):
|
||||||
|
index_path = path
|
||||||
|
break
|
||||||
|
|
||||||
|
if not index_path:
|
||||||
|
print("❌ LEANN index not found. Please build it first")
|
||||||
|
return {"peak_memory": float("inf"), "error": "Index not found"}
|
||||||
|
|
||||||
|
# Measure runtime memory overhead
|
||||||
|
print("\nMeasuring runtime memory overhead...")
|
||||||
|
runtime_start_mem = get_memory_usage()
|
||||||
|
print(f"Before load memory: {runtime_start_mem:.1f} MB")
|
||||||
|
tracker.checkpoint("Before load memory")
|
||||||
|
|
||||||
|
# Load searcher
|
||||||
|
searcher = LeannSearcher(index_path)
|
||||||
|
tracker.checkpoint("After searcher loading")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
print("Running search queries...")
|
||||||
|
queries = [
|
||||||
|
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
|
||||||
|
"What is LEANN and how does it work?",
|
||||||
|
"华为诺亚方舟实验室的主要研究内容",
|
||||||
|
]
|
||||||
|
|
||||||
|
for i, query in enumerate(queries):
|
||||||
|
start_time = time.time()
|
||||||
|
# Use same parameters as Faiss: top_k=20, ef=120 (complexity parameter)
|
||||||
|
_ = searcher.search(query, top_k=20, ef=120)
|
||||||
|
query_time = time.time() - start_time
|
||||||
|
print(f"Query {i + 1} time: {query_time:.3f}s")
|
||||||
|
tracker.checkpoint(f"After query {i + 1}")
|
||||||
|
|
||||||
|
runtime_end_mem = get_memory_usage()
|
||||||
|
runtime_overhead = runtime_end_mem - runtime_start_mem
|
||||||
|
|
||||||
|
peak_memory = tracker.summary()
|
||||||
|
print(f"Runtime Memory Overhead: {runtime_overhead:.1f} MB")
|
||||||
|
|
||||||
|
# Get storage size before cleanup
|
||||||
|
storage_size = 0
|
||||||
|
INDEX_DIR = Path(index_path).parent
|
||||||
|
if INDEX_DIR.exists():
|
||||||
|
total_size = 0
|
||||||
|
for dirpath, _, filenames in os.walk(str(INDEX_DIR)):
|
||||||
|
for filename in filenames:
|
||||||
|
# Only count actual index files, skip text data and backups
|
||||||
|
if filename.endswith((".old", ".tmp", ".bak", ".jsonl", ".json")):
|
||||||
|
continue
|
||||||
|
# Count .index, .idx, .map files (actual index structures)
|
||||||
|
if filename.endswith((".index", ".idx", ".map")):
|
||||||
|
filepath = os.path.join(dirpath, filename)
|
||||||
|
total_size += os.path.getsize(filepath)
|
||||||
|
storage_size = total_size / (1024 * 1024) # Convert to MB
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
del searcher
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"peak_memory": peak_memory,
|
||||||
|
"storage_size": storage_size,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Run comparison tests"""
|
||||||
|
print("Storage + Search Memory Comparison: Faiss HNSW vs LEANN HNSW")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Test Faiss HNSW
|
||||||
|
faiss_results = test_faiss_hnsw()
|
||||||
|
|
||||||
|
# Force garbage collection
|
||||||
|
gc.collect()
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
# Test LEANN HNSW
|
||||||
|
leann_results = test_leann_hnsw()
|
||||||
|
|
||||||
|
# Final comparison
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("STORAGE + SEARCH MEMORY COMPARISON")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Get storage sizes
|
||||||
|
faiss_storage_size = 0
|
||||||
|
leann_storage_size = leann_results.get("storage_size", 0)
|
||||||
|
|
||||||
|
# Get Faiss storage size using Python
|
||||||
|
if os.path.exists("./storage_faiss"):
|
||||||
|
total_size = 0
|
||||||
|
for dirpath, _, filenames in os.walk("./storage_faiss"):
|
||||||
|
for filename in filenames:
|
||||||
|
filepath = os.path.join(dirpath, filename)
|
||||||
|
total_size += os.path.getsize(filepath)
|
||||||
|
faiss_storage_size = total_size / (1024 * 1024) # Convert to MB
|
||||||
|
|
||||||
|
print("Faiss HNSW:")
|
||||||
|
if "error" in faiss_results:
|
||||||
|
print(f" ❌ Failed: {faiss_results['error']}")
|
||||||
|
else:
|
||||||
|
print(f" Search Memory: {faiss_results['peak_memory']:.1f} MB")
|
||||||
|
print(f" Storage Size: {faiss_storage_size:.1f} MB")
|
||||||
|
|
||||||
|
print("\nLEANN HNSW:")
|
||||||
|
if "error" in leann_results:
|
||||||
|
print(f" ❌ Failed: {leann_results['error']}")
|
||||||
|
else:
|
||||||
|
print(f" Search Memory: {leann_results['peak_memory']:.1f} MB")
|
||||||
|
print(f" Storage Size: {leann_storage_size:.1f} MB")
|
||||||
|
|
||||||
|
# Calculate improvements only if both tests succeeded
|
||||||
|
if "error" not in faiss_results and "error" not in leann_results:
|
||||||
|
memory_ratio = faiss_results["peak_memory"] / leann_results["peak_memory"]
|
||||||
|
|
||||||
|
print("\nLEANN vs Faiss Performance:")
|
||||||
|
memory_saving = faiss_results["peak_memory"] - leann_results["peak_memory"]
|
||||||
|
print(
|
||||||
|
f" Search Memory: {memory_ratio:.1f}x less ({memory_saving:.1f} MB saved)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Storage comparison
|
||||||
|
if leann_storage_size > faiss_storage_size:
|
||||||
|
storage_ratio = leann_storage_size / faiss_storage_size
|
||||||
|
print(
|
||||||
|
f" Storage Size: {storage_ratio:.1f}x larger (LEANN uses more storage)"
|
||||||
|
)
|
||||||
|
elif faiss_storage_size > leann_storage_size:
|
||||||
|
storage_ratio = faiss_storage_size / leann_storage_size
|
||||||
|
print(
|
||||||
|
f" Storage Size: {storage_ratio:.1f}x smaller (LEANN uses less storage)"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(" Storage Size: similar")
|
||||||
|
else:
|
||||||
|
if "error" not in leann_results:
|
||||||
|
print("\n✅ LEANN HNSW completed successfully!")
|
||||||
|
print(f"📊 Search Memory: {leann_results['peak_memory']:.1f} MB")
|
||||||
|
print(f"📊 Storage Size: {leann_storage_size:.1f} MB")
|
||||||
|
if "error" not in faiss_results:
|
||||||
|
print("\n✅ Faiss HNSW completed successfully!")
|
||||||
|
print(f"📊 Search Memory: {faiss_results['peak_memory']:.1f} MB")
|
||||||
|
print(f"📊 Storage Size: {faiss_storage_size:.1f} MB")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
151
apps/benchmarks/faiss_only.py
Normal file
151
apps/benchmarks/faiss_only.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Test only Faiss HNSW"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import psutil
|
||||||
|
import gc
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def get_memory_usage():
|
||||||
|
process = psutil.Process()
|
||||||
|
return process.memory_info().rss / 1024 / 1024
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryTracker:
|
||||||
|
def __init__(self, name: str):
|
||||||
|
self.name = name
|
||||||
|
self.start_mem = get_memory_usage()
|
||||||
|
self.stages = []
|
||||||
|
|
||||||
|
def checkpoint(self, stage: str):
|
||||||
|
current_mem = get_memory_usage()
|
||||||
|
diff = current_mem - self.start_mem
|
||||||
|
print(f"[{self.name} - {stage}] Memory: {current_mem:.1f} MB (+{diff:.1f} MB)")
|
||||||
|
self.stages.append((stage, current_mem))
|
||||||
|
return current_mem
|
||||||
|
|
||||||
|
def summary(self):
|
||||||
|
peak_mem = max(mem for _, mem in self.stages)
|
||||||
|
print(f"Peak Memory: {peak_mem:.1f} MB")
|
||||||
|
return peak_mem
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
try:
|
||||||
|
import faiss
|
||||||
|
except ImportError:
|
||||||
|
print("Faiss is not installed.")
|
||||||
|
print("Please install it with `uv pip install faiss-cpu`")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
from llama_index.core import (
|
||||||
|
SimpleDirectoryReader,
|
||||||
|
VectorStoreIndex,
|
||||||
|
StorageContext,
|
||||||
|
Settings,
|
||||||
|
node_parser,
|
||||||
|
Document,
|
||||||
|
)
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||||
|
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||||
|
|
||||||
|
tracker = MemoryTracker("Faiss HNSW")
|
||||||
|
tracker.checkpoint("Initial")
|
||||||
|
|
||||||
|
embed_model = HuggingFaceEmbedding(model_name="facebook/contriever")
|
||||||
|
Settings.embed_model = embed_model
|
||||||
|
tracker.checkpoint("After embedding model setup")
|
||||||
|
|
||||||
|
d = 768
|
||||||
|
faiss_index = faiss.IndexHNSWFlat(d, 32)
|
||||||
|
faiss_index.hnsw.efConstruction = 64
|
||||||
|
tracker.checkpoint("After Faiss index creation")
|
||||||
|
|
||||||
|
documents = SimpleDirectoryReader(
|
||||||
|
"../documents/data",
|
||||||
|
recursive=True,
|
||||||
|
encoding="utf-8",
|
||||||
|
required_exts=[".pdf", ".txt", ".md"],
|
||||||
|
).load_data()
|
||||||
|
tracker.checkpoint("After document loading")
|
||||||
|
|
||||||
|
# Parse into chunks using the same splitter as LEANN
|
||||||
|
node_parser = SentenceSplitter(
|
||||||
|
chunk_size=256, chunk_overlap=20, separator=" ", paragraph_separator="\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
tracker.checkpoint("After text splitter setup")
|
||||||
|
|
||||||
|
# Check if index already exists and try to load it
|
||||||
|
index_loaded = False
|
||||||
|
if os.path.exists("./storage_faiss"):
|
||||||
|
print("Loading existing Faiss HNSW index...")
|
||||||
|
try:
|
||||||
|
# Use the correct Faiss loading pattern from the example
|
||||||
|
vector_store = FaissVectorStore.from_persist_dir("./storage_faiss")
|
||||||
|
storage_context = StorageContext.from_defaults(
|
||||||
|
vector_store=vector_store, persist_dir="./storage_faiss"
|
||||||
|
)
|
||||||
|
from llama_index.core import load_index_from_storage
|
||||||
|
index = load_index_from_storage(storage_context=storage_context)
|
||||||
|
print(f"Index loaded from ./storage_faiss")
|
||||||
|
tracker.checkpoint("After loading existing index")
|
||||||
|
index_loaded = True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to load existing index: {e}")
|
||||||
|
print("Cleaning up corrupted index and building new one...")
|
||||||
|
# Clean up corrupted index
|
||||||
|
import shutil
|
||||||
|
if os.path.exists("./storage_faiss"):
|
||||||
|
shutil.rmtree("./storage_faiss")
|
||||||
|
|
||||||
|
if not index_loaded:
|
||||||
|
print("Building new Faiss HNSW index...")
|
||||||
|
|
||||||
|
# Use the correct Faiss building pattern from the example
|
||||||
|
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
||||||
|
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
||||||
|
index = VectorStoreIndex.from_documents(
|
||||||
|
documents,
|
||||||
|
storage_context=storage_context,
|
||||||
|
transformations=[node_parser]
|
||||||
|
)
|
||||||
|
tracker.checkpoint("After index building")
|
||||||
|
|
||||||
|
# Save index to disk using the correct pattern
|
||||||
|
index.storage_context.persist(persist_dir="./storage_faiss")
|
||||||
|
tracker.checkpoint("After index saving")
|
||||||
|
|
||||||
|
# Measure runtime memory overhead
|
||||||
|
print("\nMeasuring runtime memory overhead...")
|
||||||
|
runtime_start_mem = get_memory_usage()
|
||||||
|
print(f"Before load memory: {runtime_start_mem:.1f} MB")
|
||||||
|
tracker.checkpoint("Before load memory")
|
||||||
|
|
||||||
|
query_engine = index.as_query_engine(similarity_top_k=20)
|
||||||
|
queries = [
|
||||||
|
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
|
||||||
|
"What is LEANN and how does it work?",
|
||||||
|
"华为诺亚方舟实验室的主要研究内容",
|
||||||
|
]
|
||||||
|
|
||||||
|
for i, query in enumerate(queries):
|
||||||
|
start_time = time.time()
|
||||||
|
_ = query_engine.query(query)
|
||||||
|
query_time = time.time() - start_time
|
||||||
|
print(f"Query {i + 1} time: {query_time:.3f}s")
|
||||||
|
tracker.checkpoint(f"After query {i + 1}")
|
||||||
|
|
||||||
|
runtime_end_mem = get_memory_usage()
|
||||||
|
runtime_overhead = runtime_end_mem - runtime_start_mem
|
||||||
|
|
||||||
|
peak_memory = tracker.summary()
|
||||||
|
print(f"Peak Memory: {peak_memory:.1f} MB")
|
||||||
|
print(f"Runtime Memory Overhead: {runtime_overhead:.1f} MB")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
0
apps/browser/__init__.py
Normal file
0
apps/browser/__init__.py
Normal file
201
apps/browser/__main__.py
Normal file
201
apps/browser/__main__.py
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
import argparse
|
||||||
|
try:
|
||||||
|
import dotenv
|
||||||
|
dotenv.load_dotenv()
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
# python-dotenv is not installed; skip loading environment variables
|
||||||
|
dotenv = None
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Any
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
|
# Default Chrome profile path
|
||||||
|
DEFAULT_CHROME_PROFILE = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
|
||||||
|
|
||||||
|
def create_leann_index_from_multiple_chrome_profiles(profile_dirs: List[Path], index_path: str = "chrome_history_index.leann", max_count: int = -1):
|
||||||
|
"""
|
||||||
|
Create LEANN index from multiple Chrome profile data sources.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
profile_dirs: List of Path objects pointing to Chrome profile directories
|
||||||
|
index_path: Path to save the LEANN index
|
||||||
|
max_count: Maximum number of history entries to process per profile
|
||||||
|
"""
|
||||||
|
print("Creating LEANN index from multiple Chrome profile data sources...")
|
||||||
|
|
||||||
|
# Load documents using ChromeHistoryReader from local readers module
|
||||||
|
from .readers import ChromeHistoryReader
|
||||||
|
reader = ChromeHistoryReader()
|
||||||
|
|
||||||
|
INDEX_DIR = Path(index_path).parent
|
||||||
|
|
||||||
|
if not INDEX_DIR.exists():
|
||||||
|
print(f"--- Index directory not found, building new index ---")
|
||||||
|
all_documents = []
|
||||||
|
total_processed = 0
|
||||||
|
|
||||||
|
# Process each Chrome profile directory
|
||||||
|
for i, profile_dir in enumerate(profile_dirs):
|
||||||
|
print(f"\nProcessing Chrome profile {i+1}/{len(profile_dirs)}: {profile_dir}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
documents = reader.load_data(
|
||||||
|
chrome_profile_path=str(profile_dir),
|
||||||
|
max_count=max_count
|
||||||
|
)
|
||||||
|
if documents:
|
||||||
|
print(f"Loaded {len(documents)} history documents from {profile_dir}")
|
||||||
|
all_documents.extend(documents)
|
||||||
|
total_processed += len(documents)
|
||||||
|
|
||||||
|
# Check if we've reached the max count
|
||||||
|
if max_count > 0 and total_processed >= max_count:
|
||||||
|
print(f"Reached max count of {max_count} documents")
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
print(f"No documents loaded from {profile_dir}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing {profile_dir}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not all_documents:
|
||||||
|
print("No documents loaded from any source. Exiting.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
print(f"\nTotal loaded {len(all_documents)} history documents from {len(profile_dirs)} profiles")
|
||||||
|
|
||||||
|
# Create text splitter with 256 chunk size
|
||||||
|
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||||
|
|
||||||
|
# Convert Documents to text strings and chunk them
|
||||||
|
all_texts = []
|
||||||
|
for doc in all_documents:
|
||||||
|
# Split the document into chunks
|
||||||
|
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||||
|
for node in nodes:
|
||||||
|
all_texts.append(node.get_content())
|
||||||
|
|
||||||
|
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
|
||||||
|
|
||||||
|
# Create LEANN index directory
|
||||||
|
print(f"--- Index directory not found, building new index ---")
|
||||||
|
INDEX_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
print(f"--- Building new LEANN index ---")
|
||||||
|
|
||||||
|
print(f"\n[PHASE 1] Building Leann index...")
|
||||||
|
|
||||||
|
# Use HNSW backend for better macOS compatibility
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="facebook/contriever",
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
is_compact=True,
|
||||||
|
is_recompute=True,
|
||||||
|
num_threads=1 # Force single-threaded mode
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Adding {len(all_texts)} history chunks to index...")
|
||||||
|
for chunk_text in all_texts:
|
||||||
|
builder.add_text(chunk_text)
|
||||||
|
|
||||||
|
builder.build_index(index_path)
|
||||||
|
print(f"\nLEANN index built at {index_path}!")
|
||||||
|
else:
|
||||||
|
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||||
|
|
||||||
|
return index_path
|
||||||
|
|
||||||
|
async def query_leann_index(index_path: str, query: str):
|
||||||
|
"""
|
||||||
|
Query the LEANN index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index_path: Path to the LEANN index
|
||||||
|
query: The query string
|
||||||
|
"""
|
||||||
|
print(f"\n[PHASE 2] Starting Leann chat session...")
|
||||||
|
chat = LeannChat(index_path=index_path)
|
||||||
|
|
||||||
|
print(f"You: {query}")
|
||||||
|
chat_response = chat.ask(
|
||||||
|
query,
|
||||||
|
top_k=10,
|
||||||
|
recompute_beighbor_embeddings=True,
|
||||||
|
complexity=32,
|
||||||
|
beam_width=1,
|
||||||
|
llm_config={
|
||||||
|
"type": "openai",
|
||||||
|
"model": "gpt-4o",
|
||||||
|
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||||
|
},
|
||||||
|
llm_kwargs={
|
||||||
|
"temperature": 0.0,
|
||||||
|
"max_tokens": 1000
|
||||||
|
}
|
||||||
|
)
|
||||||
|
print(f"Leann: {chat_response}")
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
# Parse command line arguments
|
||||||
|
parser = argparse.ArgumentParser(description='LEANN Chrome History Reader - Create and query browser history index')
|
||||||
|
parser.add_argument('--chrome-profile', type=str, default=DEFAULT_CHROME_PROFILE,
|
||||||
|
help=f'Path to Chrome profile directory (default: {DEFAULT_CHROME_PROFILE}), usually you dont need to change this')
|
||||||
|
parser.add_argument('--index-dir', type=str, default="./chrome_history_index_leann_test",
|
||||||
|
help='Directory to store the LEANN index (default: ./chrome_history_index_leann_test)')
|
||||||
|
parser.add_argument('--max-entries', type=int, default=1000,
|
||||||
|
help='Maximum number of history entries to process (default: 1000)')
|
||||||
|
parser.add_argument('--query', type=str, default=None,
|
||||||
|
help='Single query to run (default: runs example queries)')
|
||||||
|
parser.add_argument('--auto-find-profiles', action='store_true', default=True,
|
||||||
|
help='Automatically find all Chrome profiles (default: True)')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
INDEX_DIR = Path(args.index_dir)
|
||||||
|
INDEX_PATH = str(INDEX_DIR / "chrome_history.leann")
|
||||||
|
|
||||||
|
print(f"Using Chrome profile: {args.chrome_profile}")
|
||||||
|
print(f"Index directory: {INDEX_DIR}")
|
||||||
|
print(f"Max entries: {args.max_entries}")
|
||||||
|
|
||||||
|
# Find Chrome profile directories
|
||||||
|
from .readers import ChromeHistoryReader
|
||||||
|
|
||||||
|
if args.auto_find_profiles:
|
||||||
|
profile_dirs = ChromeHistoryReader.find_chrome_profiles()
|
||||||
|
if not profile_dirs:
|
||||||
|
print("No Chrome profiles found automatically. Exiting.")
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
# Use single specified profile
|
||||||
|
profile_path = Path(args.chrome_profile)
|
||||||
|
if not profile_path.exists():
|
||||||
|
print(f"Chrome profile not found: {profile_path}")
|
||||||
|
return
|
||||||
|
profile_dirs = [profile_path]
|
||||||
|
|
||||||
|
# Create or load the LEANN index from all sources
|
||||||
|
index_path = create_leann_index_from_multiple_chrome_profiles(profile_dirs, INDEX_PATH, args.max_entries)
|
||||||
|
|
||||||
|
if index_path:
|
||||||
|
if args.query:
|
||||||
|
# Run single query
|
||||||
|
await query_leann_index(index_path, args.query)
|
||||||
|
else:
|
||||||
|
# Example queries
|
||||||
|
queries = [
|
||||||
|
"What websites did I visit about machine learning?",
|
||||||
|
"Find my search history about programming"
|
||||||
|
]
|
||||||
|
|
||||||
|
for query in queries:
|
||||||
|
print("\n" + "="*60)
|
||||||
|
await query_leann_index(index_path, query)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
176
apps/browser/readers.py
Normal file
176
apps/browser/readers.py
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
import sqlite3
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Any
|
||||||
|
from llama_index.core import Document
|
||||||
|
from llama_index.core.readers.base import BaseReader
|
||||||
|
|
||||||
|
class ChromeHistoryReader(BaseReader):
|
||||||
|
"""
|
||||||
|
Chrome browser history reader that extracts browsing data from SQLite database.
|
||||||
|
|
||||||
|
Reads Chrome history from the default Chrome profile location and creates documents
|
||||||
|
with embedded metadata similar to the email reader structure.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Initialize."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
|
||||||
|
"""
|
||||||
|
Load Chrome history data from the default Chrome profile location.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_dir: Not used for Chrome history (kept for compatibility)
|
||||||
|
**load_kwargs:
|
||||||
|
max_count (int): Maximum amount of history entries to read.
|
||||||
|
chrome_profile_path (str): Custom path to Chrome profile directory.
|
||||||
|
"""
|
||||||
|
docs: List[Document] = []
|
||||||
|
max_count = load_kwargs.get('max_count', 1000)
|
||||||
|
chrome_profile_path = load_kwargs.get('chrome_profile_path', None)
|
||||||
|
|
||||||
|
# Default Chrome profile path on macOS
|
||||||
|
if chrome_profile_path is None:
|
||||||
|
chrome_profile_path = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
|
||||||
|
|
||||||
|
history_db_path = os.path.join(chrome_profile_path, "History")
|
||||||
|
|
||||||
|
if not os.path.exists(history_db_path):
|
||||||
|
print(f"Chrome history database not found at: {history_db_path}")
|
||||||
|
return docs
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Connect to the Chrome history database
|
||||||
|
print(f"Connecting to database: {history_db_path}")
|
||||||
|
conn = sqlite3.connect(history_db_path)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
# Query to get browsing history with metadata (removed created_time column)
|
||||||
|
query = """
|
||||||
|
SELECT
|
||||||
|
datetime(last_visit_time/1000000-11644473600,'unixepoch','localtime') as last_visit,
|
||||||
|
url,
|
||||||
|
title,
|
||||||
|
visit_count,
|
||||||
|
typed_count,
|
||||||
|
hidden
|
||||||
|
FROM urls
|
||||||
|
ORDER BY last_visit_time DESC
|
||||||
|
"""
|
||||||
|
|
||||||
|
print(f"Executing query on database: {history_db_path}")
|
||||||
|
cursor.execute(query)
|
||||||
|
rows = cursor.fetchall()
|
||||||
|
print(f"Query returned {len(rows)} rows")
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
for row in rows:
|
||||||
|
if count >= max_count and max_count > 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
last_visit, url, title, visit_count, typed_count, hidden = row
|
||||||
|
|
||||||
|
# Create document content with metadata embedded in text
|
||||||
|
doc_content = f"""
|
||||||
|
[BROWSING HISTORY METADATA]
|
||||||
|
URL: {url}
|
||||||
|
Title: {title}
|
||||||
|
Last Visit: {last_visit}
|
||||||
|
Visit Count: {visit_count}
|
||||||
|
Typed Count: {typed_count}
|
||||||
|
Hidden: {hidden}
|
||||||
|
[END METADATA]
|
||||||
|
|
||||||
|
Title: {title}
|
||||||
|
URL: {url}
|
||||||
|
Last visited: {last_visit}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Create document with embedded metadata
|
||||||
|
doc = Document(text=doc_content, metadata={})
|
||||||
|
docs.append(doc)
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
conn.close()
|
||||||
|
print(f"Loaded {len(docs)} Chrome history documents")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error reading Chrome history: {e}")
|
||||||
|
return docs
|
||||||
|
|
||||||
|
return docs
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def find_chrome_profiles() -> List[Path]:
|
||||||
|
"""
|
||||||
|
Find all Chrome profile directories.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of Path objects pointing to Chrome profile directories
|
||||||
|
"""
|
||||||
|
chrome_base_path = Path(os.path.expanduser("~/Library/Application Support/Google/Chrome"))
|
||||||
|
profile_dirs = []
|
||||||
|
|
||||||
|
if not chrome_base_path.exists():
|
||||||
|
print(f"Chrome directory not found at: {chrome_base_path}")
|
||||||
|
return profile_dirs
|
||||||
|
|
||||||
|
# Find all profile directories
|
||||||
|
for profile_dir in chrome_base_path.iterdir():
|
||||||
|
if profile_dir.is_dir() and profile_dir.name != "System Profile":
|
||||||
|
history_path = profile_dir / "History"
|
||||||
|
if history_path.exists():
|
||||||
|
profile_dirs.append(profile_dir)
|
||||||
|
print(f"Found Chrome profile: {profile_dir}")
|
||||||
|
|
||||||
|
print(f"Found {len(profile_dirs)} Chrome profiles")
|
||||||
|
return profile_dirs
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def export_history_to_file(output_file: str = "chrome_history_export.txt", max_count: int = 1000):
|
||||||
|
"""
|
||||||
|
Export Chrome history to a text file using the same SQL query format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_file: Path to the output file
|
||||||
|
max_count: Maximum number of entries to export
|
||||||
|
"""
|
||||||
|
chrome_profile_path = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
|
||||||
|
history_db_path = os.path.join(chrome_profile_path, "History")
|
||||||
|
|
||||||
|
if not os.path.exists(history_db_path):
|
||||||
|
print(f"Chrome history database not found at: {history_db_path}")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
conn = sqlite3.connect(history_db_path)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
query = """
|
||||||
|
SELECT
|
||||||
|
datetime(last_visit_time/1000000-11644473600,'unixepoch','localtime') as last_visit,
|
||||||
|
url,
|
||||||
|
title,
|
||||||
|
visit_count,
|
||||||
|
typed_count,
|
||||||
|
hidden
|
||||||
|
FROM urls
|
||||||
|
ORDER BY last_visit_time DESC
|
||||||
|
LIMIT ?
|
||||||
|
"""
|
||||||
|
|
||||||
|
cursor.execute(query, (max_count,))
|
||||||
|
rows = cursor.fetchall()
|
||||||
|
|
||||||
|
with open(output_file, 'w', encoding='utf-8') as f:
|
||||||
|
for row in rows:
|
||||||
|
last_visit, url, title, visit_count, typed_count, hidden = row
|
||||||
|
f.write(f"{last_visit}\t{url}\t{title}\t{visit_count}\t{typed_count}\t{hidden}\n")
|
||||||
|
|
||||||
|
conn.close()
|
||||||
|
print(f"Exported {len(rows)} history entries to {output_file}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error exporting Chrome history: {e}")
|
||||||
@@ -1,170 +0,0 @@
|
|||||||
"""
|
|
||||||
Browser History RAG example using the unified interface.
|
|
||||||
Supports Chrome browser history.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# Add parent directory to path for imports
|
|
||||||
sys.path.insert(0, str(Path(__file__).parent))
|
|
||||||
|
|
||||||
from base_rag_example import BaseRAGExample, create_text_chunks
|
|
||||||
|
|
||||||
from .history_data.history import ChromeHistoryReader
|
|
||||||
|
|
||||||
|
|
||||||
class BrowserRAG(BaseRAGExample):
|
|
||||||
"""RAG example for Chrome browser history."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
# Set default values BEFORE calling super().__init__
|
|
||||||
self.embedding_model_default = (
|
|
||||||
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__init__(
|
|
||||||
name="Browser History",
|
|
||||||
description="Process and query Chrome browser history with LEANN",
|
|
||||||
default_index_name="google_history_index",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _add_specific_arguments(self, parser):
|
|
||||||
"""Add browser-specific arguments."""
|
|
||||||
browser_group = parser.add_argument_group("Browser Parameters")
|
|
||||||
browser_group.add_argument(
|
|
||||||
"--chrome-profile",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="Path to Chrome profile directory (auto-detected if not specified)",
|
|
||||||
)
|
|
||||||
browser_group.add_argument(
|
|
||||||
"--auto-find-profiles",
|
|
||||||
action="store_true",
|
|
||||||
default=True,
|
|
||||||
help="Automatically find all Chrome profiles (default: True)",
|
|
||||||
)
|
|
||||||
browser_group.add_argument(
|
|
||||||
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
|
|
||||||
)
|
|
||||||
browser_group.add_argument(
|
|
||||||
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_chrome_base_path(self) -> Path:
|
|
||||||
"""Get the base Chrome profile path based on OS."""
|
|
||||||
if sys.platform == "darwin":
|
|
||||||
return Path.home() / "Library" / "Application Support" / "Google" / "Chrome"
|
|
||||||
elif sys.platform.startswith("linux"):
|
|
||||||
return Path.home() / ".config" / "google-chrome"
|
|
||||||
elif sys.platform == "win32":
|
|
||||||
return Path(os.environ["LOCALAPPDATA"]) / "Google" / "Chrome" / "User Data"
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported platform: {sys.platform}")
|
|
||||||
|
|
||||||
def _find_chrome_profiles(self) -> list[Path]:
|
|
||||||
"""Auto-detect all Chrome profiles."""
|
|
||||||
base_path = self._get_chrome_base_path()
|
|
||||||
if not base_path.exists():
|
|
||||||
return []
|
|
||||||
|
|
||||||
profiles = []
|
|
||||||
|
|
||||||
# Check Default profile
|
|
||||||
default_profile = base_path / "Default"
|
|
||||||
if default_profile.exists() and (default_profile / "History").exists():
|
|
||||||
profiles.append(default_profile)
|
|
||||||
|
|
||||||
# Check numbered profiles
|
|
||||||
for item in base_path.iterdir():
|
|
||||||
if item.is_dir() and item.name.startswith("Profile "):
|
|
||||||
if (item / "History").exists():
|
|
||||||
profiles.append(item)
|
|
||||||
|
|
||||||
return profiles
|
|
||||||
|
|
||||||
async def load_data(self, args) -> list[str]:
|
|
||||||
"""Load browser history and convert to text chunks."""
|
|
||||||
# Determine Chrome profiles
|
|
||||||
if args.chrome_profile and not args.auto_find_profiles:
|
|
||||||
profile_dirs = [Path(args.chrome_profile)]
|
|
||||||
else:
|
|
||||||
print("Auto-detecting Chrome profiles...")
|
|
||||||
profile_dirs = self._find_chrome_profiles()
|
|
||||||
|
|
||||||
# If specific profile given, filter to just that one
|
|
||||||
if args.chrome_profile:
|
|
||||||
profile_path = Path(args.chrome_profile)
|
|
||||||
profile_dirs = [p for p in profile_dirs if p == profile_path]
|
|
||||||
|
|
||||||
if not profile_dirs:
|
|
||||||
print("No Chrome profiles found!")
|
|
||||||
print("Please specify --chrome-profile manually")
|
|
||||||
return []
|
|
||||||
|
|
||||||
print(f"Found {len(profile_dirs)} Chrome profiles")
|
|
||||||
|
|
||||||
# Create reader
|
|
||||||
reader = ChromeHistoryReader()
|
|
||||||
|
|
||||||
# Process each profile
|
|
||||||
all_documents = []
|
|
||||||
total_processed = 0
|
|
||||||
|
|
||||||
for i, profile_dir in enumerate(profile_dirs):
|
|
||||||
print(f"\nProcessing profile {i + 1}/{len(profile_dirs)}: {profile_dir.name}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Apply max_items limit per profile
|
|
||||||
max_per_profile = -1
|
|
||||||
if args.max_items > 0:
|
|
||||||
remaining = args.max_items - total_processed
|
|
||||||
if remaining <= 0:
|
|
||||||
break
|
|
||||||
max_per_profile = remaining
|
|
||||||
|
|
||||||
# Load history
|
|
||||||
documents = reader.load_data(
|
|
||||||
chrome_profile_path=str(profile_dir),
|
|
||||||
max_count=max_per_profile,
|
|
||||||
)
|
|
||||||
|
|
||||||
if documents:
|
|
||||||
all_documents.extend(documents)
|
|
||||||
total_processed += len(documents)
|
|
||||||
print(f"Processed {len(documents)} history entries from this profile")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error processing {profile_dir}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not all_documents:
|
|
||||||
print("No browser history found to process!")
|
|
||||||
return []
|
|
||||||
|
|
||||||
print(f"\nTotal history entries processed: {len(all_documents)}")
|
|
||||||
|
|
||||||
# Convert to text chunks
|
|
||||||
all_texts = create_text_chunks(
|
|
||||||
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
|
||||||
)
|
|
||||||
|
|
||||||
return all_texts
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
# Example queries for browser history RAG
|
|
||||||
print("\n🌐 Browser History RAG Example")
|
|
||||||
print("=" * 50)
|
|
||||||
print("\nExample queries you can try:")
|
|
||||||
print("- 'What websites did I visit about machine learning?'")
|
|
||||||
print("- 'Find my search history about programming'")
|
|
||||||
print("- 'What YouTube videos did I watch recently?'")
|
|
||||||
print("- 'Show me websites about travel planning'")
|
|
||||||
print("\nNote: Make sure Chrome is closed before running\n")
|
|
||||||
|
|
||||||
rag = BrowserRAG()
|
|
||||||
asyncio.run(rag.run())
|
|
||||||
@@ -1,108 +0,0 @@
|
|||||||
"""
|
|
||||||
Document RAG example using the unified interface.
|
|
||||||
Supports PDF, TXT, MD, and other document formats.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# Add parent directory to path for imports
|
|
||||||
sys.path.insert(0, str(Path(__file__).parent))
|
|
||||||
|
|
||||||
from base_rag_example import BaseRAGExample, create_text_chunks
|
|
||||||
from llama_index.core import SimpleDirectoryReader
|
|
||||||
|
|
||||||
|
|
||||||
class DocumentRAG(BaseRAGExample):
|
|
||||||
"""RAG example for document processing (PDF, TXT, MD, etc.)."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
name="Document",
|
|
||||||
description="Process and query documents (PDF, TXT, MD, etc.) with LEANN",
|
|
||||||
default_index_name="test_doc_files",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _add_specific_arguments(self, parser):
|
|
||||||
"""Add document-specific arguments."""
|
|
||||||
doc_group = parser.add_argument_group("Document Parameters")
|
|
||||||
doc_group.add_argument(
|
|
||||||
"--data-dir",
|
|
||||||
type=str,
|
|
||||||
default="data",
|
|
||||||
help="Directory containing documents to index (default: data)",
|
|
||||||
)
|
|
||||||
doc_group.add_argument(
|
|
||||||
"--file-types",
|
|
||||||
nargs="+",
|
|
||||||
default=None,
|
|
||||||
help="Filter by file types (e.g., .pdf .txt .md). If not specified, all supported types are processed",
|
|
||||||
)
|
|
||||||
doc_group.add_argument(
|
|
||||||
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
|
|
||||||
)
|
|
||||||
doc_group.add_argument(
|
|
||||||
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def load_data(self, args) -> list[str]:
|
|
||||||
"""Load documents and convert to text chunks."""
|
|
||||||
print(f"Loading documents from: {args.data_dir}")
|
|
||||||
if args.file_types:
|
|
||||||
print(f"Filtering by file types: {args.file_types}")
|
|
||||||
else:
|
|
||||||
print("Processing all supported file types")
|
|
||||||
|
|
||||||
# Check if data directory exists
|
|
||||||
data_path = Path(args.data_dir)
|
|
||||||
if not data_path.exists():
|
|
||||||
raise ValueError(f"Data directory not found: {args.data_dir}")
|
|
||||||
|
|
||||||
# Load documents
|
|
||||||
reader_kwargs = {
|
|
||||||
"recursive": True,
|
|
||||||
"encoding": "utf-8",
|
|
||||||
}
|
|
||||||
if args.file_types:
|
|
||||||
reader_kwargs["required_exts"] = args.file_types
|
|
||||||
|
|
||||||
documents = SimpleDirectoryReader(args.data_dir, **reader_kwargs).load_data(
|
|
||||||
show_progress=True
|
|
||||||
)
|
|
||||||
|
|
||||||
if not documents:
|
|
||||||
print(f"No documents found in {args.data_dir} with extensions {args.file_types}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
print(f"Loaded {len(documents)} documents")
|
|
||||||
|
|
||||||
# Convert to text chunks
|
|
||||||
all_texts = create_text_chunks(
|
|
||||||
documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
|
||||||
)
|
|
||||||
|
|
||||||
# Apply max_items limit if specified
|
|
||||||
if args.max_items > 0 and len(all_texts) > args.max_items:
|
|
||||||
print(f"Limiting to {args.max_items} chunks (from {len(all_texts)})")
|
|
||||||
all_texts = all_texts[: args.max_items]
|
|
||||||
|
|
||||||
return all_texts
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
# Example queries for document RAG
|
|
||||||
print("\n📄 Document RAG Example")
|
|
||||||
print("=" * 50)
|
|
||||||
print("\nExample queries you can try:")
|
|
||||||
print("- 'What are the main techniques LEANN uses?'")
|
|
||||||
print("- 'What is the technique DLPM?'")
|
|
||||||
print("- 'Who does Elizabeth Bennet marry?'")
|
|
||||||
print(
|
|
||||||
"- 'What is the problem of developing pan gu model Huawei meets? (盘古大模型开发中遇到什么问题?)'"
|
|
||||||
)
|
|
||||||
print("\nOr run without --query for interactive mode\n")
|
|
||||||
|
|
||||||
rag = DocumentRAG()
|
|
||||||
asyncio.run(rag.run())
|
|
||||||
0
apps/documents/__init__.py
Normal file
0
apps/documents/__init__.py
Normal file
113
apps/documents/__main__.py
Normal file
113
apps/documents/__main__.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
import argparse
|
||||||
|
from llama_index.core import SimpleDirectoryReader
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
import asyncio
|
||||||
|
import dotenv
|
||||||
|
from leann.api import LeannBuilder, LeannChat
|
||||||
|
from pathlib import Path
|
||||||
|
import os
|
||||||
|
|
||||||
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
async def main(args):
|
||||||
|
INDEX_DIR = Path(args.index_dir)
|
||||||
|
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
|
||||||
|
|
||||||
|
if not INDEX_DIR.exists():
|
||||||
|
node_parser = SentenceSplitter(
|
||||||
|
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Loading documents...")
|
||||||
|
# Get the data directory relative to this module
|
||||||
|
current_dir = Path(__file__).parent
|
||||||
|
data_dir = current_dir / "data"
|
||||||
|
|
||||||
|
documents = SimpleDirectoryReader(
|
||||||
|
str(data_dir),
|
||||||
|
recursive=True,
|
||||||
|
encoding="utf-8",
|
||||||
|
required_exts=[".pdf", ".txt", ".md"],
|
||||||
|
).load_data(show_progress=True)
|
||||||
|
print("Documents loaded.")
|
||||||
|
all_texts = []
|
||||||
|
for doc in documents:
|
||||||
|
nodes = node_parser.get_nodes_from_documents([doc])
|
||||||
|
for node in nodes:
|
||||||
|
all_texts.append(node.get_content())
|
||||||
|
|
||||||
|
print("--- Index directory not found, building new index ---")
|
||||||
|
|
||||||
|
print("\n[PHASE 1] Building Leann index...")
|
||||||
|
|
||||||
|
# Use HNSW backend for better macOS compatibility
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="facebook/contriever",
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
is_compact=True,
|
||||||
|
is_recompute=True,
|
||||||
|
num_threads=1, # Force single-threaded mode
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Loaded {len(all_texts)} text chunks from documents.")
|
||||||
|
for chunk_text in all_texts:
|
||||||
|
builder.add_text(chunk_text)
|
||||||
|
|
||||||
|
builder.build_index(INDEX_PATH)
|
||||||
|
print(f"\nLeann index built at {INDEX_PATH}!")
|
||||||
|
else:
|
||||||
|
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||||
|
|
||||||
|
print(f"\n[PHASE 2] Starting Leann chat session...")
|
||||||
|
|
||||||
|
# llm_config = {"type": "hf", "model": "Qwen/Qwen3-4B"}
|
||||||
|
llm_config = {"type": "ollama", "model": "qwen3:8b"}
|
||||||
|
|
||||||
|
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
|
||||||
|
|
||||||
|
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
|
||||||
|
|
||||||
|
# query = (
|
||||||
|
# "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
|
||||||
|
# )
|
||||||
|
|
||||||
|
print(f"You: {query}")
|
||||||
|
chat_response = chat.ask(query, top_k=20, recompute_embeddings=True, complexity=32)
|
||||||
|
print(f"Leann: {chat_response}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Run Leann Chat with various LLM backends."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--llm",
|
||||||
|
type=str,
|
||||||
|
default="hf",
|
||||||
|
choices=["simulated", "ollama", "hf", "openai"],
|
||||||
|
help="The LLM backend to use.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model",
|
||||||
|
type=str,
|
||||||
|
default="Qwen/Qwen3-0.6B",
|
||||||
|
help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf, 'gpt-4o' for openai).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--host",
|
||||||
|
type=str,
|
||||||
|
default="http://localhost:11434",
|
||||||
|
help="The host for the Ollama API.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--index-dir",
|
||||||
|
type=str,
|
||||||
|
default="./test_doc_files",
|
||||||
|
help="Directory where the Leann index will be stored.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
asyncio.run(main(args))
|
||||||
82
apps/documents/data/pangu.md
Normal file
82
apps/documents/data/pangu.md
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
# 盘古之殇:华为诺亚盘古大模型研发历程的心酸与黑暗
|
||||||
|
|
||||||
|
各位好,
|
||||||
|
|
||||||
|
我是一名盘古大模型团队,华为诺亚方舟实验室的员工。
|
||||||
|
|
||||||
|
首先为自证身份,列举一些细节:
|
||||||
|
|
||||||
|
1. 现诺亚主任,前算法应用部部长,后改名为小模型实验室的主任王云鹤。前诺亚主任:姚骏(大家称姚老师)。几个实验室主任:唐睿明(明哥,明队,已离职),尚利峰,张维(维哥),郝建业(郝老师),刘武龙(称呼为武龙所)等。其他骨干成员和专家陆续有很多人离职。
|
||||||
|
2. 我们隶属于“四野”这个组织。四野下属有许多纵队,基础语言大模型是四纵。王云鹤的小模型是十六纵队。我们参加过苏州的集结,有各种月份的时间节点。在苏州攻关会颁发任务令,需要在节点前达成目标。苏州集结会把各地的人员都集中在苏州研究所,平常住宾馆,比如在甪直的酒店,与家人孩子天各一方。
|
||||||
|
3. 在苏州集结的时候周六默认上班,非常辛苦,不过周六有下午茶,有一次还有小龙虾。在苏州研究所的工位搬迁过一次,从一栋楼换到了另一栋。苏州研究所楼栋都是欧式装修,门口有大坡,里面景色很不错。去苏州集结一般至少要去一周,甚至更久,多的人甚至一两个月都回不了家。
|
||||||
|
4. 诺亚曾经传说是研究型的,但是来了之后因为在四野做大模型项目,项目成员完全变成了交付型的,且充满了例会,评审,汇报。很多时候做实验都要申请。团队需要对接终端小艺,华为云,ICT等诸多业务线,交付压力不小。
|
||||||
|
5. 诺亚研发的盘古模型早期内部代号叫做“盘古智子”,一开始只有内部需要申请试用的网页版,到后续迫于压力在welink上接入和公测开放。
|
||||||
|
|
||||||
|
这些天发生关于质疑盘古大模型抄袭千问的事情闹的沸沸扬扬。作为一个盘古团队的成员,我最近夜夜辗转反侧,难以入眠。盘古的品牌受到如此大的影响,一方面,我自私的为我的职业发展担忧,也为自己过去的努力工作感到不值。另一方面,由于有人开始揭露这些事情我内心又感到大快人心。在多少个日日夜夜,我们对内部某些人一次次靠着造假而又获得了无数利益的行为咬牙切齿而又无能为力。这种压抑和羞辱也逐渐消磨了我对华为的感情,让我在这里的时日逐渐浑浑噩噩,迷茫无措,时常怀疑自己的人生和自我价值。
|
||||||
|
|
||||||
|
我承认我是一个懦弱的人,作为一个小小的打工人,我不仅不敢和王云鹤等内部手眼通天的人做对,更不敢和华为这样的庞然大物做对。我很怕失去我的工作,毕竟我也有家人和孩子,所以我打心眼里很佩服揭露者。但是,看到内部还在试图洗地掩盖事实,蒙蔽公众的时候,我实在不能容忍了。我也希望勇敢一次,顺从自己本心。就算自损八百,我也希望能伤敌一千。我决定把我在这里的所见所闻(部分来自于同事口述)公布出来,关于盘古大模型的“传奇故事”:
|
||||||
|
|
||||||
|
华为确实主要在昇腾卡上训练大模型(小模型实验室有不少英伟达的卡,他们之前也会用来训练,后面转移到昇腾)。曾经我被华为“打造世界第二选择”的决心而折服,我本身也曾经对华为有深厚的感情。我们陪着昇腾一步步摸爬滚打,从充满bug到现在能训出模型,付出了巨大的心血和代价。
|
||||||
|
|
||||||
|
最初我们的算力非常有限,在910A上训练模型。那会只支持fp16,训练的稳定性远不如bf16。盘古的moe开始很早,23年就主要是训练38Bmoe模型和后续的71B dense模型。71B的dense模型通过扩增变成了第一代的135Bdense模型,后面主力模型也逐渐在910B上训练。
|
||||||
|
|
||||||
|
71B和135B模型都有一个巨大的硬伤就是tokenizer。当时使用的tokenizer编码效率极低,每个单个的符号,数字,空格,乃至汉字都会占用一个token。可想而知这会非常浪费算力,且使得模型的效果很差。这时候小模型实验室正好有个自己训的词表。姚老师当时怀疑是不是模型的tokenizer不好(虽然事后来看,他的怀疑是无疑正确的),于是就决定,让71B和135B换tokenizer,因为小模型实验室曾经尝试过。团队缝合了两个tokenizer,开始了tokenizer的更换。71B模型的更换失败了,而135B因为采用了更精细的embedding初始化策略,续训了至少1T的数据后词表总算更换成功,但可想而知,效果并不会变好。
|
||||||
|
|
||||||
|
于此同期,阿里和智谱等国内其他公司在GPU上训练,且已经摸索出了正确的方法,盘古和竞品的差距越来越大。内部一个230B从头训练的dense模型又因为各种原因训练失败,导致项目的状况几乎陷入绝境。面临几个节点的压力以及内部对盘古的强烈质疑时,团队的士气低迷到了极点。团队在算力极其有限的时候,做出了很多努力和挣扎。比如,团队偶然发现当时的38B moe并没有预期moe的效果。于是去掉了moe参数,还原为了13B的dense模型。由于38B的moe源自很早的pangu alpha 13B,架构相对落后,团队进行了一系列的操作,比如切换绝对位置编码到rope,去掉bias,切换为rmsnorm。同时鉴于tokenizer的一些失败和换词表的经验,这个模型的词表也更换为了王云鹤的小模型实验室7B模型所使用的词表。后面这个13B模型进行了扩增续训,变成了第二代38B dense模型(在几个月内这个模型都是主要的盘古中档位模型),曾经具有一定的竞争力。但是,由于更大的135B模型架构落后,且更换词表模型损伤巨大(后续分析发现当时更换的缝合词表有更严重的bug),续训后也与千问等当时国内领先模型存在很大差距。这时由于内部的质疑声和领导的压力也越来越大。团队的状态几乎陷入了绝境。
|
||||||
|
|
||||||
|
在这种情况下,王云鹤和他的小模型实验室出手了。他们声称是从旧的135B参数继承改造而来,通过训练短短的几百B数据,各项指标平均提升了十个点左右。实际上,这就是他们套壳应用到大模型的第一次杰作。华为的外行领导内行,使得领导完全对于这种扯淡的事情没有概念,他们只会觉得肯定是有什么算法创新。经过内部的分析,他们实际上是使用Qwen 1.5 110B续训而来,通过加层,扩增ffn维度,添加盘古pi论文的一些机制得来,凑够了大概135B的参数。实际上,旧的135B有107层,而这个模型只有82层,各种配置也都不一样。新的来路不明的135B训练完很多参数的分布也和Qwen 110B几乎一模一样。连模型代码的类名当时都是Qwen,甚至懒得改名。后续这个模型就是所谓的135B V2。而这个模型当时也提供给了很多下游,甚至包括外部客户。
|
||||||
|
|
||||||
|
这件事对于我们这些认真诚实做事的同事们带来了巨大的冲击,内部很多人其实都知道这件事,甚至包括终端和华为云。我们都戏称以后别叫盘古模型了,叫千古吧。当时团队成员就想向bcg举报了,毕竟这已经是重大的业务造假了。但是后面据说被领导拦了下来,因为更高级别的领导(比如姚老师,以及可能熊总和查老)其实后面也知道了,但是并不管,因为通过套壳拿出好的结果,对他们也是有利的。这件事使得当时团队几位最强的同事开始心灰意冷,离职跑路也逐渐成为挂在嘴边的事。
|
||||||
|
|
||||||
|
此时,盘古似乎迎来了转机。由于前面所述的这些盘古模型基本都是续训和改造而来,当时诺亚完全没有掌握从头训练的技术,何况还是在昇腾的NPU上进行训练。在当时团队的核心成员的极力争取下,盘古开始了第三代模型的训练,付出了巨大的努力后,在数据架构和训练算法方面都与业界逐渐接轨,而这其中的艰辛和小模型实验室的人一点关系都没有。
|
||||||
|
|
||||||
|
一开始团队成员毫无信心,只从一个13B的模型开始训练,但是后面发现效果还不错,于是这个模型后续再次进行了一次参数扩增,变成了第三代的38B,代号38B V3。想必很多产品线的兄弟都对这个模型很熟悉。当时这个模型的tokenizer是基于llama的词表进行扩展的(也是业界常见的做法)。而当时王云鹤的实验室做出来了另一个词表(也就是后续pangu系列的词表)。当时两个词表还被迫进行了一次赛马,最终没有明显的好坏结论。于是,领导当即决定,应该统一词表,使用王云鹤他们的。于是,在后续从头训练的135B V3(也就是对外的Pangu Ultra),便是采用了这个tokenizer。这也解释了很多使用我们模型的兄弟的疑惑,为什么当时同为V3代的两个不同档位的模型,会使用不同的tokenizer。
|
||||||
|
|
||||||
|
|
||||||
|
我们打心眼里觉得,135B V3是我们四纵团队当时的骄傲。这是第一个真正意义上的,华为全栈自研,正经从头训练的千亿级别的模型,且效果与24年同期竞品可比的。写到这里我已经热泪盈眶,太不容易了。当时为了稳定训练,团队做了大量实验对比,并且多次在模型梯度出现异常的时候进行及时回退重启。这个模型真正做到了后面技术报告所说的训练全程没有一个loss spike。我们克服了不知道多少困难,我们做到了,我们愿用生命和荣誉保证这个模型训练的真实性。多少个凌晨,我们为了它的训练而不眠。在被内部心声骂的一文不值的时候,我们有多么不甘,有多少的委屈,我们挺住了。
|
||||||
|
|
||||||
|
我们这帮人是真的在为打磨国产算力底座燃烧自己的青春啊……客居他乡,我们放弃了家庭,放弃了假期,放弃了健康,放弃了娱乐,抛头颅洒热血,其中的艰辛与困苦,寥寥数笔不足以概括其万一。在各种动员大会上,当时口号中喊出的盘古必胜,华为必胜,我们心里是真的深深被感动。
|
||||||
|
|
||||||
|
然而,我们的所有辛苦的成果,经常被小模型实验室轻飘飘的拿走了。数据,直接要走。代码,直接要走,还要求我们配合适配到能一键运行。我们当时戏称小模型实验室为点鼠标实验室。我们付出辛苦,他们取得荣耀。果然应了那句话,你在负重前行是因为有人替你岁月静好。在这种情况下,越来越多的战友再也坚持不下去了,选择了离开。看到身边那些优秀的同事一个个离职,我的内心又感叹又难过。在这种作战一样的环境下,我们比起同事来说更像是战友。他们在技术上也有无数值得我学习的地方,堪称良师。看到他们去了诸如字节Seed,Deepseek,月之暗面,腾讯和快手等等很多出色的团队,我打心眼里为他们高兴和祝福,脱离了这个辛苦却肮脏的地方。我至今还对一位离职同事的话记忆犹新,ta说:“来这里是我技术生涯中的耻辱,在这里再呆每一天都是浪费生命”。话虽难听却让我无言以对。我担心我自己技术方面的积累不足,以及没法适应互联网公司高淘汰的环境,让我多次想离职的心始终没有迈出这一步。
|
||||||
|
|
||||||
|
盘古除了dense模型,后续也启动了moe的探索。一开始训练的是一个224B的moe模型。而与之平行的,小模型实验室也开启了第二次主要的套壳行动(次要的插曲可能还包括一些别的模型,比如math模型),即这次流传甚广的pangu pro moe 72B。这个模型内部自称是从小模型实验室的7B扩增上来的(就算如此,这也与技术报告不符,何况是套壳qwen 2.5的14b续训)。还记得他们训了没几天,内部的评测就立刻追上了当时的38B V3。AI系统实验室很多兄弟因为需要适配模型,都知道他们的套壳行动,只是迫于各种原因,无法伸张正义。实际上,对于后续训了很久很久的这个模型,Honestagi能够分析出这个量级的相似性我已经很诧异了,因为这个模型为了续训洗参数,所付出的算力甚至早就足够从头训一个同档位的模型了。听同事说他们为了洗掉千问的水印,采取了不少办法,甚至包括故意训了脏数据。这也为学术界研究模型血缘提供了一个前所未有的特殊模范吧。以后新的血缘方法提出可以拿出来溜溜。
|
||||||
|
|
||||||
|
24年底和25年初,在Deepseek v3和r1发布之后,由于其惊艳的技术水平,团队受到了巨大的冲击,也受到了更大的质疑。于是为了紧跟潮流,盘古模仿Deepseek的模型尺寸,开启了718B moe的训练。这个时候,小模型实验室再次出手了。他们选择了套壳Deepseekv3续训。他们通过冻住Deepseek加载的参数,进行训练。连任务加载ckpt的目录都是deepseekv3,改都不改,何其嚣张?与之相反,一些有真正技术信仰的同事,在从头训练另一个718B的moe。但其中出现了各种各样的问题。但是很显然,这个模型怎么可能比直接套壳的好呢?如果不是团队leader坚持,早就被叫停了。
|
||||||
|
|
||||||
|
华为的流程管理之繁重,严重拖累了大模型的研发节奏,例如版本管理,模型血缘,各种流程化,各种可追溯。讽刺的是,小模型实验室的模型似乎从来不受这些流程的约束,想套壳就套壳,想续训就续训,算力源源不断的伸手拿走。这种强烈到近乎魔幻的对比,说明了当前流程管理的情况:只许州官放火,不许百姓点灯。何其可笑?何其可悲?何其可恶?何其可耻!
|
||||||
|
|
||||||
|
HonestAGI的事情出来后,内部让大家不停的研讨分析,如何公关和“回应”。诚然,这个原文的分析也许不够有力,给了王云鹤与小模型实验室他们狡辩和颠倒黑白的机会。为此,这两天我内心感到作呕,时时怀疑自己的人生意义以及苍天无眼。我不奉陪了,我要离职了,同时我也在申请从盘古部分技术报告的作者名单中移除。曾经在这些技术报告上署名是我一生都无法抹除的污点。当时我没想到,他们竟然猖狂到敢开源。我没想到,他们敢如此愚弄世人,大肆宣发。当时,我也许是存了侥幸心理,没有拒绝署名。我相信很多扎实做事的战友,也只是被迫上了贼船,或者不知情。但这件事已经无法挽回,我希望我的余生能够坚持扎实做真正有意义的事,为我当时的软弱和不坚定赎罪。
|
||||||
|
|
||||||
|
深夜写到这里,我已经泪流满面,泣不成声。还记得一些出色的同事离职时,我苦笑问他们要不要发个长长的心声惯例帖,揭露一下现状。对方说:不了,浪费时间,而且我也怕揭露出来你们过的更糟。我当时一下黯然神伤,因为曾经共同为了理想奋斗过的战友已经彻底对华为彻底灰心了。当时大家调侃,我们用着当年共产党的小米加步枪,组织却有着堪比当年国民党的作风。
|
||||||
|
|
||||||
|
曾几何时,我为我们用着小米加步枪打败洋枪洋炮而自豪。
|
||||||
|
|
||||||
|
现在,我累了,我想投降。
|
||||||
|
|
||||||
|
其实时至今日,我还是真心希望华为能认真吸取教训,能做好盘古,把盘古做到世界一流,把昇腾变成英伟达的水平。内部的劣币驱逐良币,使得诺亚乃至华为在短时间内急剧流失了大量出色的大模型人才。相信他们也正在如Deepseek等各个团队闪耀着,施展着他们的抱负才华,为中美在AI的激烈竞赛中奉献力量。我时常感叹,华为不是没有人才,而是根本不知道怎么留住人才。如果给这些人合适的环境,合适的资源,更少的枷锁,更少的政治斗争,盘古何愁不成?
|
||||||
|
|
||||||
|
最后:我以生命,人格和荣誉发誓,我写的以上所有内容均为真实(至少在我有限的认知范围内)。我没有那么高的技术水平以及机会去做详尽扎实的分析,也不敢直接用内部记录举证,怕因为信息安全抓到。但是我相信我很多曾经的战友,会为我作证。在华为内部的兄弟,包括我们曾经服务过的产品线兄弟们,相信本文的无数细节能和你们的印象对照,印证我的说法。你们可能也曾经被蒙骗,但这些残酷的真相不会被尘封。我们奋战过的痕迹,也不应该被扭曲和埋葬。
|
||||||
|
|
||||||
|
写了这么多,某些人肯定想把我找出来,抹杀掉。公司搞不好也想让我噤声乃至追责。如果真的这样,我,乃至我的家人的人身乃至生命安全可能都会受到威胁。为了自我保护,我近期每天会跟大家报平安。
|
||||||
|
|
||||||
|
如果我消失了,就当是我为了真理和理想,为了华为乃至中国能够更好地发展算力和AI而牺牲了吧,我愿埋葬于那片曾经奋斗过的地方。
|
||||||
|
|
||||||
|
诺亚,再见
|
||||||
|
|
||||||
|
2025年7月6日凌晨 写于深圳
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
各位好,
|
||||||
|
|
||||||
|
感谢大家的关心与祝福。我目前暂时安全,但公司应该在进行排查与某些名单收集,后续情况未知。
|
||||||
|
|
||||||
|
我补充一些细节,以免某些人继续颠倒黑白。
|
||||||
|
|
||||||
|
关于135B V2,小模型实验室在迅速地完成套壳并拿完所有套壳带来的好处后(比如任务令表彰和及时激励),因为不想继续支撑下游应用和模型迭代,又把这个烫手山芋甩给了四纵。确实技高一筹,直接把四纵的兄弟们拉下水。同事提供过去一个老旧的模型,最终拿回了一个当时一个魔改的先进的千问。做大模型的人,自己做的模型就像自己孩子一样熟悉,不要把别人都当傻子。就像自家儿子出门一趟,回来个别人家孩子。
|
||||||
|
|
||||||
|
盘古report的署名是不符合学术规范的。例如,135B V3有不少有技术贡献的人,因为作者名额数量限制,劳动成果没有得到应有的回报,团队内曾经有不小的意见。这个模型当时是大家智慧和汗水的结晶,甚至是团队当时的精神支柱,支撑着不少兄弟们继续留在诺亚。所谓的名额限制,以及挂名了一些毫无技术贡献的人(如一些小模型实验室的人),让兄弟们何其心寒。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
暂时平安。另外,支持我勇于说出真相的战友们 https://github.com/HW-whistleblower/True-Story-of-Pangu/issues/317
|
||||||
0
apps/email/__init__.py
Normal file
0
apps/email/__init__.py
Normal file
193
apps/email/__main__.py
Normal file
193
apps/email/__main__.py
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import asyncio
|
||||||
|
import dotenv
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Any
|
||||||
|
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
|
# Auto-detect user's mail path
|
||||||
|
def get_mail_path():
|
||||||
|
"""Get the mail path for the current user"""
|
||||||
|
home_dir = os.path.expanduser("~")
|
||||||
|
return os.path.join(home_dir, "Library", "Mail")
|
||||||
|
|
||||||
|
def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_path: str = "mail_index.leann", max_count: int = -1, include_html: bool = False, embedding_model: str = "facebook/contriever"):
|
||||||
|
"""
|
||||||
|
Create LEANN index from multiple mail data sources.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages_dirs: List of Path objects pointing to Messages directories
|
||||||
|
index_path: Path to save the LEANN index
|
||||||
|
max_count: Maximum number of emails to process per directory
|
||||||
|
include_html: Whether to include HTML content in email processing
|
||||||
|
"""
|
||||||
|
print("Creating LEANN index from multiple mail data sources...")
|
||||||
|
|
||||||
|
# Load documents using EmlxReader from local readers module
|
||||||
|
from .readers import EmlxReader, find_all_messages_directories
|
||||||
|
reader = EmlxReader(include_html=include_html)
|
||||||
|
INDEX_DIR = Path(index_path).parent
|
||||||
|
|
||||||
|
if not INDEX_DIR.exists():
|
||||||
|
print(f"--- Index directory not found, building new index ---")
|
||||||
|
all_documents = []
|
||||||
|
total_processed = 0
|
||||||
|
|
||||||
|
# Process each Messages directory
|
||||||
|
for i, messages_dir in enumerate(messages_dirs):
|
||||||
|
print(f"\nProcessing Messages directory {i+1}/{len(messages_dirs)}: {messages_dir}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
documents = reader.load_data(messages_dir)
|
||||||
|
if documents:
|
||||||
|
print(f"Loaded {len(documents)} email documents from {messages_dir}")
|
||||||
|
all_documents.extend(documents)
|
||||||
|
total_processed += len(documents)
|
||||||
|
|
||||||
|
# Check if we've reached the max count
|
||||||
|
if max_count > 0 and total_processed >= max_count:
|
||||||
|
print(f"Reached max count of {max_count} documents")
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
print(f"No documents loaded from {messages_dir}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing {messages_dir}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not all_documents:
|
||||||
|
print("No documents loaded from any source. Exiting.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
print(f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories")
|
||||||
|
|
||||||
|
# Create text splitter with 256 chunk size
|
||||||
|
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||||
|
|
||||||
|
# Convert Documents to text strings and chunk them
|
||||||
|
all_texts = []
|
||||||
|
for doc in all_documents:
|
||||||
|
# Split the document into chunks
|
||||||
|
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||||
|
for node in nodes:
|
||||||
|
all_texts.append(node.get_content())
|
||||||
|
|
||||||
|
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
|
||||||
|
|
||||||
|
# Create LEANN index directory
|
||||||
|
print(f"--- Index directory not found, building new index ---")
|
||||||
|
INDEX_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
print(f"--- Building new LEANN index ---")
|
||||||
|
|
||||||
|
print(f"\n[PHASE 1] Building Leann index...")
|
||||||
|
|
||||||
|
# Use HNSW backend for better macOS compatibility
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model=embedding_model,
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
is_compact=True,
|
||||||
|
is_recompute=True,
|
||||||
|
num_threads=1 # Force single-threaded mode
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Adding {len(all_texts)} email chunks to index...")
|
||||||
|
for chunk_text in all_texts:
|
||||||
|
builder.add_text(chunk_text)
|
||||||
|
|
||||||
|
builder.build_index(index_path)
|
||||||
|
print(f"\nLEANN index built at {index_path}!")
|
||||||
|
else:
|
||||||
|
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||||
|
|
||||||
|
return index_path
|
||||||
|
|
||||||
|
async def query_leann_index(index_path: str, query: str):
|
||||||
|
"""
|
||||||
|
Query the LEANN index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index_path: Path to the LEANN index
|
||||||
|
query: The query string
|
||||||
|
"""
|
||||||
|
print(f"\n[PHASE 2] Starting Leann chat session...")
|
||||||
|
chat = LeannChat(index_path=index_path,
|
||||||
|
llm_config={"type": "openai", "model": "gpt-4o"})
|
||||||
|
|
||||||
|
print(f"You: {query}")
|
||||||
|
import time
|
||||||
|
start_time = time.time()
|
||||||
|
chat_response = chat.ask(
|
||||||
|
query,
|
||||||
|
top_k=10,
|
||||||
|
recompute_beighbor_embeddings=True,
|
||||||
|
complexity=12,
|
||||||
|
beam_width=1,
|
||||||
|
|
||||||
|
)
|
||||||
|
end_time = time.time()
|
||||||
|
print(f"Time taken: {end_time - start_time} seconds")
|
||||||
|
print(f"Leann: {chat_response}")
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
# Parse command line arguments
|
||||||
|
parser = argparse.ArgumentParser(description='LEANN Mail Reader - Create and query email index')
|
||||||
|
parser.add_argument('--index-dir', type=str, default="./mail_index_leann_raw_text_all_dicts",
|
||||||
|
help='Directory to store the LEANN index (default: ./mail_index_leann_raw_text_all_dicts)')
|
||||||
|
parser.add_argument('--max-emails', type=int, default=1000,
|
||||||
|
help='Maximum number of emails to process (-1 means all)')
|
||||||
|
parser.add_argument('--query', type=str, default="Give me some funny advertisement about apple or other companies",
|
||||||
|
help='Single query to run (default: runs example queries)')
|
||||||
|
parser.add_argument('--include-html', action='store_true', default=False,
|
||||||
|
help='Include HTML content in email processing (default: False)')
|
||||||
|
parser.add_argument('--embedding-model', type=str, default="facebook/contriever",
|
||||||
|
help='Embedding model to use (default: facebook/contriever)')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
print(f"args: {args}")
|
||||||
|
|
||||||
|
# Automatically find all Messages directories under the current user's Mail directory
|
||||||
|
from .readers import find_all_messages_directories
|
||||||
|
mail_path = get_mail_path()
|
||||||
|
print(f"Searching for email data in: {mail_path}")
|
||||||
|
messages_dirs = find_all_messages_directories(mail_path)
|
||||||
|
|
||||||
|
print('len(messages_dirs): ', len(messages_dirs))
|
||||||
|
|
||||||
|
if not messages_dirs:
|
||||||
|
print("No Messages directories found. Exiting.")
|
||||||
|
return
|
||||||
|
|
||||||
|
INDEX_DIR = Path(args.index_dir)
|
||||||
|
INDEX_PATH = str(INDEX_DIR / "mail_documents.leann")
|
||||||
|
print(f"Index directory: {INDEX_DIR}")
|
||||||
|
print(f"Found {len(messages_dirs)} Messages directories.")
|
||||||
|
|
||||||
|
# Create or load the LEANN index from all sources
|
||||||
|
index_path = create_leann_index_from_multiple_sources(messages_dirs, INDEX_PATH, args.max_emails, args.include_html, args.embedding_model)
|
||||||
|
|
||||||
|
if index_path:
|
||||||
|
if args.query:
|
||||||
|
# Run single query
|
||||||
|
await query_leann_index(index_path, args.query)
|
||||||
|
else:
|
||||||
|
# Example queries
|
||||||
|
queries = [
|
||||||
|
"Hows Berkeley Graduate Student Instructor",
|
||||||
|
"how's the icloud related advertisement saying",
|
||||||
|
"Whats the number of class recommend to take per semester for incoming EECS students"
|
||||||
|
]
|
||||||
|
for query in queries:
|
||||||
|
print("\n" + "="*60)
|
||||||
|
await query_leann_index(index_path, query)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
@@ -7,9 +7,9 @@ Contains simple parser for mbox files.
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from fsspec import AbstractFileSystem
|
from fsspec import AbstractFileSystem
|
||||||
|
|
||||||
from llama_index.core.readers.base import BaseReader
|
from llama_index.core.readers.base import BaseReader
|
||||||
from llama_index.core.schema import Document
|
from llama_index.core.schema import Document
|
||||||
|
|
||||||
@@ -27,7 +27,11 @@ class MboxReader(BaseReader):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
DEFAULT_MESSAGE_FORMAT: str = (
|
DEFAULT_MESSAGE_FORMAT: str = (
|
||||||
"Date: {_date}\nFrom: {_from}\nTo: {_to}\nSubject: {_subject}\nContent: {_content}"
|
"Date: {_date}\n"
|
||||||
|
"From: {_from}\n"
|
||||||
|
"To: {_to}\n"
|
||||||
|
"Subject: {_subject}\n"
|
||||||
|
"Content: {_content}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -41,7 +45,9 @@ class MboxReader(BaseReader):
|
|||||||
try:
|
try:
|
||||||
from bs4 import BeautifulSoup # noqa
|
from bs4 import BeautifulSoup # noqa
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError("`beautifulsoup4` package not found: `pip install beautifulsoup4`")
|
raise ImportError(
|
||||||
|
"`beautifulsoup4` package not found: `pip install beautifulsoup4`"
|
||||||
|
)
|
||||||
|
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.max_count = max_count
|
self.max_count = max_count
|
||||||
@@ -50,9 +56,9 @@ class MboxReader(BaseReader):
|
|||||||
def load_data(
|
def load_data(
|
||||||
self,
|
self,
|
||||||
file: Path,
|
file: Path,
|
||||||
extra_info: dict | None = None,
|
extra_info: Optional[Dict] = None,
|
||||||
fs: AbstractFileSystem | None = None,
|
fs: Optional[AbstractFileSystem] = None,
|
||||||
) -> list[Document]:
|
) -> List[Document]:
|
||||||
"""Parse file into string."""
|
"""Parse file into string."""
|
||||||
# Import required libraries
|
# Import required libraries
|
||||||
import mailbox
|
import mailbox
|
||||||
@@ -68,7 +74,7 @@ class MboxReader(BaseReader):
|
|||||||
)
|
)
|
||||||
|
|
||||||
i = 0
|
i = 0
|
||||||
results: list[str] = []
|
results: List[str] = []
|
||||||
# Load file using mailbox
|
# Load file using mailbox
|
||||||
bytes_parser = BytesParser(policy=default).parse
|
bytes_parser = BytesParser(policy=default).parse
|
||||||
mbox = mailbox.mbox(file, factory=bytes_parser) # type: ignore
|
mbox = mailbox.mbox(file, factory=bytes_parser) # type: ignore
|
||||||
@@ -128,12 +134,12 @@ class EmlxMboxReader(MboxReader):
|
|||||||
def load_data(
|
def load_data(
|
||||||
self,
|
self,
|
||||||
directory: Path,
|
directory: Path,
|
||||||
extra_info: dict | None = None,
|
extra_info: Optional[Dict] = None,
|
||||||
fs: AbstractFileSystem | None = None,
|
fs: Optional[AbstractFileSystem] = None,
|
||||||
) -> list[Document]:
|
) -> List[Document]:
|
||||||
"""Parse .emlx files from directory into strings using MboxReader logic."""
|
"""Parse .emlx files from directory into strings using MboxReader logic."""
|
||||||
import os
|
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import os
|
||||||
|
|
||||||
if fs:
|
if fs:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -150,18 +156,18 @@ class EmlxMboxReader(MboxReader):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
# Create a temporary mbox file
|
# Create a temporary mbox file
|
||||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".mbox", delete=False) as temp_mbox:
|
with tempfile.NamedTemporaryFile(mode='w', suffix='.mbox', delete=False) as temp_mbox:
|
||||||
temp_mbox_path = temp_mbox.name
|
temp_mbox_path = temp_mbox.name
|
||||||
|
|
||||||
# Convert .emlx files to mbox format
|
# Convert .emlx files to mbox format
|
||||||
for emlx_file in emlx_files:
|
for emlx_file in emlx_files:
|
||||||
try:
|
try:
|
||||||
# Read the .emlx file
|
# Read the .emlx file
|
||||||
with open(emlx_file, encoding="utf-8", errors="ignore") as f:
|
with open(emlx_file, 'r', encoding='utf-8', errors='ignore') as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
|
|
||||||
# .emlx format: first line is length, rest is email content
|
# .emlx format: first line is length, rest is email content
|
||||||
lines = content.split("\n", 1)
|
lines = content.split('\n', 1)
|
||||||
if len(lines) >= 2:
|
if len(lines) >= 2:
|
||||||
email_content = lines[1] # Skip the length line
|
email_content = lines[1] # Skip the length line
|
||||||
|
|
||||||
@@ -182,5 +188,5 @@ class EmlxMboxReader(MboxReader):
|
|||||||
# Clean up temporary file
|
# Clean up temporary file
|
||||||
try:
|
try:
|
||||||
os.unlink(temp_mbox_path)
|
os.unlink(temp_mbox_path)
|
||||||
except OSError:
|
except:
|
||||||
pass
|
pass
|
||||||
124
apps/email/readers.py
Normal file
124
apps/email/readers.py
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
import os
|
||||||
|
import email
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Any
|
||||||
|
from llama_index.core import Document
|
||||||
|
from llama_index.core.readers.base import BaseReader
|
||||||
|
|
||||||
|
def find_all_messages_directories(root: str = None) -> List[Path]:
|
||||||
|
"""
|
||||||
|
Recursively find all 'Messages' directories under the given root.
|
||||||
|
Returns a list of Path objects.
|
||||||
|
"""
|
||||||
|
if root is None:
|
||||||
|
# Auto-detect user's mail path
|
||||||
|
home_dir = os.path.expanduser("~")
|
||||||
|
root = os.path.join(home_dir, "Library", "Mail")
|
||||||
|
|
||||||
|
messages_dirs = []
|
||||||
|
for dirpath, dirnames, filenames in os.walk(root):
|
||||||
|
if os.path.basename(dirpath) == "Messages":
|
||||||
|
messages_dirs.append(Path(dirpath))
|
||||||
|
return messages_dirs
|
||||||
|
|
||||||
|
class EmlxReader(BaseReader):
|
||||||
|
"""
|
||||||
|
Apple Mail .emlx file reader with embedded metadata.
|
||||||
|
|
||||||
|
Reads individual .emlx files from Apple Mail's storage format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, include_html: bool = False) -> None:
|
||||||
|
"""
|
||||||
|
Initialize.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
include_html: Whether to include HTML content in the email body (default: False)
|
||||||
|
"""
|
||||||
|
self.include_html = include_html
|
||||||
|
|
||||||
|
def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]:
|
||||||
|
"""
|
||||||
|
Load data from the input directory containing .emlx files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_dir: Directory containing .emlx files
|
||||||
|
**load_kwargs:
|
||||||
|
max_count (int): Maximum amount of messages to read.
|
||||||
|
"""
|
||||||
|
docs: List[Document] = []
|
||||||
|
max_count = load_kwargs.get('max_count', 1000)
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
# Walk through the directory recursively
|
||||||
|
for dirpath, dirnames, filenames in os.walk(input_dir):
|
||||||
|
# Skip hidden directories
|
||||||
|
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
||||||
|
|
||||||
|
for filename in filenames:
|
||||||
|
if count >= max_count:
|
||||||
|
break
|
||||||
|
|
||||||
|
if filename.endswith(".emlx"):
|
||||||
|
filepath = os.path.join(dirpath, filename)
|
||||||
|
try:
|
||||||
|
# Read the .emlx file
|
||||||
|
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
|
||||||
|
content = f.read()
|
||||||
|
|
||||||
|
# .emlx files have a length prefix followed by the email content
|
||||||
|
# The first line contains the length, followed by the email
|
||||||
|
lines = content.split('\n', 1)
|
||||||
|
if len(lines) >= 2:
|
||||||
|
email_content = lines[1]
|
||||||
|
|
||||||
|
# Parse the email using Python's email module
|
||||||
|
try:
|
||||||
|
msg = email.message_from_string(email_content)
|
||||||
|
|
||||||
|
# Extract email metadata
|
||||||
|
subject = msg.get('Subject', 'No Subject')
|
||||||
|
from_addr = msg.get('From', 'Unknown')
|
||||||
|
to_addr = msg.get('To', 'Unknown')
|
||||||
|
date = msg.get('Date', 'Unknown')
|
||||||
|
|
||||||
|
# Extract email body
|
||||||
|
body = ""
|
||||||
|
if msg.is_multipart():
|
||||||
|
for part in msg.walk():
|
||||||
|
if part.get_content_type() == "text/plain" or part.get_content_type() == "text/html":
|
||||||
|
if part.get_content_type() == "text/html" and not self.include_html:
|
||||||
|
continue
|
||||||
|
body += part.get_payload(decode=True).decode('utf-8', errors='ignore')
|
||||||
|
# break
|
||||||
|
else:
|
||||||
|
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore')
|
||||||
|
|
||||||
|
# Create document content with metadata embedded in text
|
||||||
|
doc_content = f"""
|
||||||
|
[EMAIL METADATA]
|
||||||
|
File: {filename}
|
||||||
|
From: {from_addr}
|
||||||
|
To: {to_addr}
|
||||||
|
Subject: {subject}
|
||||||
|
Date: {date}
|
||||||
|
[END METADATA]
|
||||||
|
|
||||||
|
{body}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# No separate metadata - everything is in the text
|
||||||
|
doc = Document(text=doc_content, metadata={})
|
||||||
|
docs.append(doc)
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error parsing email from {filepath}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error reading file {filepath}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"Loaded {len(docs)} email documents")
|
||||||
|
return docs
|
||||||
@@ -1,167 +0,0 @@
|
|||||||
import email
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from llama_index.core import Document
|
|
||||||
from llama_index.core.readers.base import BaseReader
|
|
||||||
|
|
||||||
|
|
||||||
def find_all_messages_directories(root: str | None = None) -> list[Path]:
|
|
||||||
"""
|
|
||||||
Recursively find all 'Messages' directories under the given root.
|
|
||||||
Returns a list of Path objects.
|
|
||||||
"""
|
|
||||||
if root is None:
|
|
||||||
# Auto-detect user's mail path
|
|
||||||
home_dir = os.path.expanduser("~")
|
|
||||||
root = os.path.join(home_dir, "Library", "Mail")
|
|
||||||
|
|
||||||
messages_dirs = []
|
|
||||||
for dirpath, _dirnames, _filenames in os.walk(root):
|
|
||||||
if os.path.basename(dirpath) == "Messages":
|
|
||||||
messages_dirs.append(Path(dirpath))
|
|
||||||
return messages_dirs
|
|
||||||
|
|
||||||
|
|
||||||
class EmlxReader(BaseReader):
|
|
||||||
"""
|
|
||||||
Apple Mail .emlx file reader with embedded metadata.
|
|
||||||
|
|
||||||
Reads individual .emlx files from Apple Mail's storage format.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, include_html: bool = False) -> None:
|
|
||||||
"""
|
|
||||||
Initialize.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
include_html: Whether to include HTML content in the email body (default: False)
|
|
||||||
"""
|
|
||||||
self.include_html = include_html
|
|
||||||
|
|
||||||
def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
|
|
||||||
"""
|
|
||||||
Load data from the input directory containing .emlx files.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_dir: Directory containing .emlx files
|
|
||||||
**load_kwargs:
|
|
||||||
max_count (int): Maximum amount of messages to read.
|
|
||||||
"""
|
|
||||||
docs: list[Document] = []
|
|
||||||
max_count = load_kwargs.get("max_count", 1000)
|
|
||||||
count = 0
|
|
||||||
total_files = 0
|
|
||||||
successful_files = 0
|
|
||||||
failed_files = 0
|
|
||||||
|
|
||||||
print(f"Starting to process directory: {input_dir}")
|
|
||||||
|
|
||||||
# Walk through the directory recursively
|
|
||||||
for dirpath, dirnames, filenames in os.walk(input_dir):
|
|
||||||
# Skip hidden directories
|
|
||||||
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
|
||||||
|
|
||||||
for filename in filenames:
|
|
||||||
# Check if we've reached the max count (skip if max_count == -1)
|
|
||||||
if max_count > 0 and count >= max_count:
|
|
||||||
break
|
|
||||||
|
|
||||||
if filename.endswith(".emlx"):
|
|
||||||
total_files += 1
|
|
||||||
filepath = os.path.join(dirpath, filename)
|
|
||||||
try:
|
|
||||||
# Read the .emlx file
|
|
||||||
with open(filepath, encoding="utf-8", errors="ignore") as f:
|
|
||||||
content = f.read()
|
|
||||||
|
|
||||||
# .emlx files have a length prefix followed by the email content
|
|
||||||
# The first line contains the length, followed by the email
|
|
||||||
lines = content.split("\n", 1)
|
|
||||||
if len(lines) >= 2:
|
|
||||||
email_content = lines[1]
|
|
||||||
|
|
||||||
# Parse the email using Python's email module
|
|
||||||
try:
|
|
||||||
msg = email.message_from_string(email_content)
|
|
||||||
|
|
||||||
# Extract email metadata
|
|
||||||
subject = msg.get("Subject", "No Subject")
|
|
||||||
from_addr = msg.get("From", "Unknown")
|
|
||||||
to_addr = msg.get("To", "Unknown")
|
|
||||||
date = msg.get("Date", "Unknown")
|
|
||||||
|
|
||||||
# Extract email body
|
|
||||||
body = ""
|
|
||||||
if msg.is_multipart():
|
|
||||||
for part in msg.walk():
|
|
||||||
if (
|
|
||||||
part.get_content_type() == "text/plain"
|
|
||||||
or part.get_content_type() == "text/html"
|
|
||||||
):
|
|
||||||
if (
|
|
||||||
part.get_content_type() == "text/html"
|
|
||||||
and not self.include_html
|
|
||||||
):
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
payload = part.get_payload(decode=True)
|
|
||||||
if payload:
|
|
||||||
body += payload.decode("utf-8", errors="ignore")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error decoding payload: {e}")
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
payload = msg.get_payload(decode=True)
|
|
||||||
if payload:
|
|
||||||
body = payload.decode("utf-8", errors="ignore")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error decoding single part payload: {e}")
|
|
||||||
body = ""
|
|
||||||
|
|
||||||
# Only create document if we have some content
|
|
||||||
if body.strip() or subject != "No Subject":
|
|
||||||
# Create document content with metadata embedded in text
|
|
||||||
doc_content = f"""
|
|
||||||
[File]: {filename}
|
|
||||||
[From]: {from_addr}
|
|
||||||
[To]: {to_addr}
|
|
||||||
[Subject]: {subject}
|
|
||||||
[Date]: {date}
|
|
||||||
[EMAIL BODY Start]:
|
|
||||||
{body}
|
|
||||||
"""
|
|
||||||
|
|
||||||
# No separate metadata - everything is in the text
|
|
||||||
doc = Document(text=doc_content, metadata={})
|
|
||||||
docs.append(doc)
|
|
||||||
count += 1
|
|
||||||
successful_files += 1
|
|
||||||
|
|
||||||
# Print first few successful files for debugging
|
|
||||||
if successful_files <= 3:
|
|
||||||
print(
|
|
||||||
f"Successfully loaded: {filename} - Subject: {subject[:50]}..."
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
failed_files += 1
|
|
||||||
if failed_files <= 5: # Only print first few errors
|
|
||||||
print(f"Error parsing email from {filepath}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
failed_files += 1
|
|
||||||
if failed_files <= 5: # Only print first few errors
|
|
||||||
print(f"Error reading file {filepath}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
print("Processing summary:")
|
|
||||||
print(f" Total .emlx files found: {total_files}")
|
|
||||||
print(f" Successfully loaded: {successful_files}")
|
|
||||||
print(f" Failed to load: {failed_files}")
|
|
||||||
print(f" Final documents: {len(docs)}")
|
|
||||||
|
|
||||||
return docs
|
|
||||||
@@ -1,156 +0,0 @@
|
|||||||
"""
|
|
||||||
Email RAG example using the unified interface.
|
|
||||||
Supports Apple Mail on macOS.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# Add parent directory to path for imports
|
|
||||||
sys.path.insert(0, str(Path(__file__).parent))
|
|
||||||
|
|
||||||
from base_rag_example import BaseRAGExample, create_text_chunks
|
|
||||||
|
|
||||||
from .email_data.LEANN_email_reader import EmlxReader
|
|
||||||
|
|
||||||
|
|
||||||
class EmailRAG(BaseRAGExample):
|
|
||||||
"""RAG example for Apple Mail processing."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
# Set default values BEFORE calling super().__init__
|
|
||||||
self.max_items_default = -1 # Process all emails by default
|
|
||||||
self.embedding_model_default = (
|
|
||||||
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__init__(
|
|
||||||
name="Email",
|
|
||||||
description="Process and query Apple Mail emails with LEANN",
|
|
||||||
default_index_name="mail_index",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _add_specific_arguments(self, parser):
|
|
||||||
"""Add email-specific arguments."""
|
|
||||||
email_group = parser.add_argument_group("Email Parameters")
|
|
||||||
email_group.add_argument(
|
|
||||||
"--mail-path",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="Path to Apple Mail directory (auto-detected if not specified)",
|
|
||||||
)
|
|
||||||
email_group.add_argument(
|
|
||||||
"--include-html", action="store_true", help="Include HTML content in email processing"
|
|
||||||
)
|
|
||||||
email_group.add_argument(
|
|
||||||
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
|
|
||||||
)
|
|
||||||
email_group.add_argument(
|
|
||||||
"--chunk-overlap", type=int, default=25, help="Text chunk overlap (default: 25)"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _find_mail_directories(self) -> list[Path]:
|
|
||||||
"""Auto-detect all Apple Mail directories."""
|
|
||||||
mail_base = Path.home() / "Library" / "Mail"
|
|
||||||
if not mail_base.exists():
|
|
||||||
return []
|
|
||||||
|
|
||||||
# Find all Messages directories
|
|
||||||
messages_dirs = []
|
|
||||||
for item in mail_base.rglob("Messages"):
|
|
||||||
if item.is_dir():
|
|
||||||
messages_dirs.append(item)
|
|
||||||
|
|
||||||
return messages_dirs
|
|
||||||
|
|
||||||
async def load_data(self, args) -> list[str]:
|
|
||||||
"""Load emails and convert to text chunks."""
|
|
||||||
# Determine mail directories
|
|
||||||
if args.mail_path:
|
|
||||||
messages_dirs = [Path(args.mail_path)]
|
|
||||||
else:
|
|
||||||
print("Auto-detecting Apple Mail directories...")
|
|
||||||
messages_dirs = self._find_mail_directories()
|
|
||||||
|
|
||||||
if not messages_dirs:
|
|
||||||
print("No Apple Mail directories found!")
|
|
||||||
print("Please specify --mail-path manually")
|
|
||||||
return []
|
|
||||||
|
|
||||||
print(f"Found {len(messages_dirs)} mail directories")
|
|
||||||
|
|
||||||
# Create reader
|
|
||||||
reader = EmlxReader(include_html=args.include_html)
|
|
||||||
|
|
||||||
# Process each directory
|
|
||||||
all_documents = []
|
|
||||||
total_processed = 0
|
|
||||||
|
|
||||||
for i, messages_dir in enumerate(messages_dirs):
|
|
||||||
print(f"\nProcessing directory {i + 1}/{len(messages_dirs)}: {messages_dir}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Count emlx files
|
|
||||||
emlx_files = list(messages_dir.glob("*.emlx"))
|
|
||||||
print(f"Found {len(emlx_files)} email files")
|
|
||||||
|
|
||||||
# Apply max_items limit per directory
|
|
||||||
max_per_dir = -1 # Default to process all
|
|
||||||
if args.max_items > 0:
|
|
||||||
remaining = args.max_items - total_processed
|
|
||||||
if remaining <= 0:
|
|
||||||
break
|
|
||||||
max_per_dir = remaining
|
|
||||||
# If args.max_items == -1, max_per_dir stays -1 (process all)
|
|
||||||
|
|
||||||
# Load emails - fix the parameter passing
|
|
||||||
documents = reader.load_data(
|
|
||||||
input_dir=str(messages_dir),
|
|
||||||
max_count=max_per_dir,
|
|
||||||
)
|
|
||||||
|
|
||||||
if documents:
|
|
||||||
all_documents.extend(documents)
|
|
||||||
total_processed += len(documents)
|
|
||||||
print(f"Processed {len(documents)} emails from this directory")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error processing {messages_dir}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not all_documents:
|
|
||||||
print("No emails found to process!")
|
|
||||||
return []
|
|
||||||
|
|
||||||
print(f"\nTotal emails processed: {len(all_documents)}")
|
|
||||||
print("now starting to split into text chunks ... take some time")
|
|
||||||
|
|
||||||
# Convert to text chunks
|
|
||||||
# Email reader uses chunk_overlap=25 as in original
|
|
||||||
all_texts = create_text_chunks(
|
|
||||||
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
|
||||||
)
|
|
||||||
|
|
||||||
return all_texts
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
# Check platform
|
|
||||||
if sys.platform != "darwin":
|
|
||||||
print("\n⚠️ Warning: This example is designed for macOS (Apple Mail)")
|
|
||||||
print(" Windows/Linux support coming soon!\n")
|
|
||||||
|
|
||||||
# Example queries for email RAG
|
|
||||||
print("\n📧 Email RAG Example")
|
|
||||||
print("=" * 50)
|
|
||||||
print("\nExample queries you can try:")
|
|
||||||
print("- 'What did my boss say about deadlines?'")
|
|
||||||
print("- 'Find emails about travel expenses'")
|
|
||||||
print("- 'Show me emails from last month about the project'")
|
|
||||||
print("- 'What food did I order from DoorDash?'")
|
|
||||||
print("\nNote: You may need to grant Full Disk Access to your terminal\n")
|
|
||||||
|
|
||||||
rag = EmailRAG()
|
|
||||||
asyncio.run(rag.run())
|
|
||||||
0
apps/evaluation/__init__.py
Normal file
0
apps/evaluation/__init__.py
Normal file
382
apps/evaluation/__main__.py
Normal file
382
apps/evaluation/__main__.py
Normal file
@@ -0,0 +1,382 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
This script runs a recall evaluation on a given LEANN index.
|
||||||
|
It correctly compares results by fetching the text content for both the new search
|
||||||
|
results and the golden standard results, making the comparison robust to ID changes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import argparse
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
import sys
|
||||||
|
import numpy as np
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from leann.api import LeannSearcher, LeannBuilder
|
||||||
|
|
||||||
|
|
||||||
|
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
||||||
|
"""Checks if the data directory exists, and if not, downloads it from HF Hub."""
|
||||||
|
if not data_root.exists():
|
||||||
|
print(f"Data directory '{data_root}' not found.")
|
||||||
|
print(
|
||||||
|
"Downloading evaluation data from Hugging Face Hub... (this may take a moment)"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
if download_embeddings:
|
||||||
|
# Download everything including embeddings (large files)
|
||||||
|
snapshot_download(
|
||||||
|
repo_id="LEANN-RAG/leann-rag-evaluation-data",
|
||||||
|
repo_type="dataset",
|
||||||
|
local_dir=data_root,
|
||||||
|
local_dir_use_symlinks=False,
|
||||||
|
)
|
||||||
|
print("Data download complete (including embeddings)!")
|
||||||
|
else:
|
||||||
|
# Download only specific folders, excluding embeddings
|
||||||
|
allow_patterns = [
|
||||||
|
"ground_truth/**",
|
||||||
|
"indices/**",
|
||||||
|
"queries/**",
|
||||||
|
"*.md",
|
||||||
|
"*.txt",
|
||||||
|
]
|
||||||
|
snapshot_download(
|
||||||
|
repo_id="LEANN-RAG/leann-rag-evaluation-data",
|
||||||
|
repo_type="dataset",
|
||||||
|
local_dir=data_root,
|
||||||
|
local_dir_use_symlinks=False,
|
||||||
|
allow_patterns=allow_patterns,
|
||||||
|
)
|
||||||
|
print("Data download complete (excluding embeddings)!")
|
||||||
|
except ImportError:
|
||||||
|
print(
|
||||||
|
"Error: huggingface_hub is not installed. Please install it to download the data:"
|
||||||
|
)
|
||||||
|
print("uv pip install -e '.[dev]'")
|
||||||
|
sys.exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"An error occurred during data download: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
def download_embeddings_if_needed(data_root: Path, dataset_type: str = None):
|
||||||
|
"""Download embeddings files specifically."""
|
||||||
|
embeddings_dir = data_root / "embeddings"
|
||||||
|
|
||||||
|
if dataset_type:
|
||||||
|
# Check if specific dataset embeddings exist
|
||||||
|
target_file = embeddings_dir / dataset_type / "passages_00.pkl"
|
||||||
|
if target_file.exists():
|
||||||
|
print(f"Embeddings for {dataset_type} already exist")
|
||||||
|
return str(target_file)
|
||||||
|
|
||||||
|
print("Downloading embeddings from HuggingFace Hub...")
|
||||||
|
try:
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
# Download only embeddings folder
|
||||||
|
snapshot_download(
|
||||||
|
repo_id="LEANN-RAG/leann-rag-evaluation-data",
|
||||||
|
repo_type="dataset",
|
||||||
|
local_dir=data_root,
|
||||||
|
local_dir_use_symlinks=False,
|
||||||
|
allow_patterns=["embeddings/**/*.pkl"],
|
||||||
|
)
|
||||||
|
print("Embeddings download complete!")
|
||||||
|
|
||||||
|
if dataset_type:
|
||||||
|
target_file = embeddings_dir / dataset_type / "passages_00.pkl"
|
||||||
|
if target_file.exists():
|
||||||
|
return str(target_file)
|
||||||
|
|
||||||
|
return str(embeddings_dir)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error downloading embeddings: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Helper Function to get Golden Passages ---
|
||||||
|
def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set:
|
||||||
|
"""
|
||||||
|
Retrieves the text for golden passage IDs directly from the LeannSearcher's
|
||||||
|
passage manager.
|
||||||
|
"""
|
||||||
|
golden_texts = set()
|
||||||
|
for gid in golden_ids:
|
||||||
|
try:
|
||||||
|
# PassageManager uses string IDs
|
||||||
|
passage_data = searcher.passage_manager.get_passage(str(gid))
|
||||||
|
golden_texts.add(passage_data["text"])
|
||||||
|
except KeyError:
|
||||||
|
print(
|
||||||
|
f"Warning: Golden passage ID '{gid}' not found in the index's passage data."
|
||||||
|
)
|
||||||
|
return golden_texts
|
||||||
|
|
||||||
|
|
||||||
|
def load_queries(file_path: Path) -> List[str]:
|
||||||
|
queries = []
|
||||||
|
with open(file_path, "r", encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
data = json.loads(line)
|
||||||
|
queries.append(data["query"])
|
||||||
|
return queries
|
||||||
|
|
||||||
|
|
||||||
|
def build_index_from_embeddings(
|
||||||
|
embeddings_file: str, output_path: str, backend: str = "hnsw"
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Build a LEANN index from pre-computed embeddings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embeddings_file: Path to pickle file with (ids, embeddings) tuple
|
||||||
|
output_path: Path where to save the index
|
||||||
|
backend: Backend to use ("hnsw" or "diskann")
|
||||||
|
"""
|
||||||
|
print(f"Building {backend} index from embeddings: {embeddings_file}")
|
||||||
|
|
||||||
|
# Create builder with appropriate parameters
|
||||||
|
if backend == "hnsw":
|
||||||
|
builder_kwargs = {
|
||||||
|
"M": 32, # Graph degree
|
||||||
|
"efConstruction": 256, # Construction complexity
|
||||||
|
"is_compact": True, # Use compact storage
|
||||||
|
"is_recompute": True, # Enable pruning for better recall
|
||||||
|
}
|
||||||
|
elif backend == "diskann":
|
||||||
|
builder_kwargs = {
|
||||||
|
"complexity": 64,
|
||||||
|
"graph_degree": 32,
|
||||||
|
"search_memory_maximum": 8.0, # GB
|
||||||
|
"build_memory_maximum": 16.0, # GB
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
builder_kwargs = {}
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name=backend,
|
||||||
|
embedding_model="facebook/contriever-msmarco", # Model used to create embeddings
|
||||||
|
dimensions=768, # Will be auto-detected from embeddings
|
||||||
|
**builder_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build index from precomputed embeddings
|
||||||
|
builder.build_index_from_embeddings(output_path, embeddings_file)
|
||||||
|
print(f"Index saved to: {output_path}")
|
||||||
|
return output_path
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Run recall evaluation on a LEANN index."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"index_path",
|
||||||
|
type=str,
|
||||||
|
nargs="?",
|
||||||
|
help="Path to the LEANN index to evaluate or build (optional).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--mode",
|
||||||
|
choices=["evaluate", "build"],
|
||||||
|
default="evaluate",
|
||||||
|
help="Mode: 'evaluate' existing index or 'build' from embeddings",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--embeddings-file",
|
||||||
|
type=str,
|
||||||
|
help="Path to embeddings pickle file (optional for build mode)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--backend",
|
||||||
|
choices=["hnsw", "diskann"],
|
||||||
|
default="hnsw",
|
||||||
|
help="Backend to use for building index (default: hnsw)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-queries", type=int, default=10, help="Number of queries to evaluate."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--top-k", type=int, default=3, help="The 'k' value for recall@k."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# --- Path Configuration ---
|
||||||
|
# Assumes a project structure where the script is in 'examples/'
|
||||||
|
# and data is in 'data/' at the project root.
|
||||||
|
project_root = Path(__file__).resolve().parent.parent
|
||||||
|
data_root = project_root / "data"
|
||||||
|
|
||||||
|
# Download data based on mode
|
||||||
|
if args.mode == "build":
|
||||||
|
# For building mode, we need embeddings
|
||||||
|
download_data_if_needed(
|
||||||
|
data_root, download_embeddings=False
|
||||||
|
) # Basic data first
|
||||||
|
|
||||||
|
# Auto-detect dataset type and download embeddings
|
||||||
|
if args.embeddings_file:
|
||||||
|
embeddings_file = args.embeddings_file
|
||||||
|
# Try to detect dataset type from embeddings file path
|
||||||
|
if "rpj_wiki" in str(embeddings_file):
|
||||||
|
dataset_type = "rpj_wiki"
|
||||||
|
elif "dpr" in str(embeddings_file):
|
||||||
|
dataset_type = "dpr"
|
||||||
|
else:
|
||||||
|
dataset_type = "dpr" # Default
|
||||||
|
else:
|
||||||
|
# Auto-detect from index path if provided, otherwise default to DPR
|
||||||
|
if args.index_path:
|
||||||
|
index_path_str = str(args.index_path)
|
||||||
|
if "rpj_wiki" in index_path_str:
|
||||||
|
dataset_type = "rpj_wiki"
|
||||||
|
elif "dpr" in index_path_str:
|
||||||
|
dataset_type = "dpr"
|
||||||
|
else:
|
||||||
|
dataset_type = "dpr" # Default to DPR
|
||||||
|
else:
|
||||||
|
dataset_type = "dpr" # Default to DPR
|
||||||
|
|
||||||
|
embeddings_file = download_embeddings_if_needed(data_root, dataset_type)
|
||||||
|
|
||||||
|
# Auto-generate index path if not provided
|
||||||
|
if not args.index_path:
|
||||||
|
indices_dir = data_root / "indices" / dataset_type
|
||||||
|
indices_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
args.index_path = str(indices_dir / f"{dataset_type}_from_embeddings")
|
||||||
|
print(f"Auto-generated index path: {args.index_path}")
|
||||||
|
|
||||||
|
print(f"Building index from embeddings: {embeddings_file}")
|
||||||
|
built_index_path = build_index_from_embeddings(
|
||||||
|
embeddings_file, args.index_path, args.backend
|
||||||
|
)
|
||||||
|
print(f"Index built successfully: {built_index_path}")
|
||||||
|
|
||||||
|
# Ask if user wants to run evaluation
|
||||||
|
eval_response = (
|
||||||
|
input("Run evaluation on the built index? (y/n): ").strip().lower()
|
||||||
|
)
|
||||||
|
if eval_response != "y":
|
||||||
|
print("Index building complete. Exiting.")
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
# For evaluation mode, don't need embeddings
|
||||||
|
download_data_if_needed(data_root, download_embeddings=False)
|
||||||
|
|
||||||
|
# Auto-detect index path if not provided
|
||||||
|
if not args.index_path:
|
||||||
|
# Default to using downloaded indices
|
||||||
|
indices_dir = data_root / "indices"
|
||||||
|
|
||||||
|
# Try common datasets in order of preference
|
||||||
|
for dataset in ["dpr", "rpj_wiki"]:
|
||||||
|
dataset_dir = indices_dir / dataset
|
||||||
|
if dataset_dir.exists():
|
||||||
|
# Look for index files
|
||||||
|
index_files = list(dataset_dir.glob("*.index")) + list(
|
||||||
|
dataset_dir.glob("*_disk.index")
|
||||||
|
)
|
||||||
|
if index_files:
|
||||||
|
args.index_path = str(
|
||||||
|
index_files[0].with_suffix("")
|
||||||
|
) # Remove .index extension
|
||||||
|
print(f"Using index: {args.index_path}")
|
||||||
|
break
|
||||||
|
|
||||||
|
if not args.index_path:
|
||||||
|
print(
|
||||||
|
"No indices found. The data download should have included pre-built indices."
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
"Please check the data/indices/ directory or provide --index-path manually."
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Detect dataset type from index path to select the correct ground truth
|
||||||
|
index_path_str = str(args.index_path)
|
||||||
|
if "rpj_wiki" in index_path_str:
|
||||||
|
dataset_type = "rpj_wiki"
|
||||||
|
elif "dpr" in index_path_str:
|
||||||
|
dataset_type = "dpr"
|
||||||
|
else:
|
||||||
|
# Fallback: try to infer from the index directory name
|
||||||
|
dataset_type = Path(args.index_path).name
|
||||||
|
print(
|
||||||
|
f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'."
|
||||||
|
)
|
||||||
|
|
||||||
|
queries_file = data_root / "queries" / "nq_open.jsonl"
|
||||||
|
golden_results_file = (
|
||||||
|
data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"INFO: Detected dataset type: {dataset_type}")
|
||||||
|
print(f"INFO: Using queries file: {queries_file}")
|
||||||
|
print(f"INFO: Using ground truth file: {golden_results_file}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
searcher = LeannSearcher(args.index_path)
|
||||||
|
queries = load_queries(queries_file)
|
||||||
|
|
||||||
|
with open(golden_results_file, "r") as f:
|
||||||
|
golden_results_data = json.load(f)
|
||||||
|
|
||||||
|
num_eval_queries = min(args.num_queries, len(queries))
|
||||||
|
queries = queries[:num_eval_queries]
|
||||||
|
|
||||||
|
print(f"\nRunning evaluation on {num_eval_queries} queries...")
|
||||||
|
recall_scores = []
|
||||||
|
search_times = []
|
||||||
|
|
||||||
|
for i in range(num_eval_queries):
|
||||||
|
start_time = time.time()
|
||||||
|
new_results = searcher.search(
|
||||||
|
queries[i], top_k=args.top_k, ef=args.ef_search
|
||||||
|
)
|
||||||
|
search_times.append(time.time() - start_time)
|
||||||
|
|
||||||
|
# Correct Recall Calculation: Based on TEXT content
|
||||||
|
new_texts = {result.text for result in new_results}
|
||||||
|
|
||||||
|
# Get golden texts directly from the searcher's passage manager
|
||||||
|
golden_ids = golden_results_data["indices"][i][: args.top_k]
|
||||||
|
golden_texts = get_golden_texts(searcher, golden_ids)
|
||||||
|
|
||||||
|
overlap = len(new_texts & golden_texts)
|
||||||
|
recall = overlap / len(golden_texts) if golden_texts else 0
|
||||||
|
recall_scores.append(recall)
|
||||||
|
|
||||||
|
print("\n--- EVALUATION RESULTS ---")
|
||||||
|
print(f"Query: {queries[i]}")
|
||||||
|
print(f"New Results: {new_texts}")
|
||||||
|
print(f"Golden Results: {golden_texts}")
|
||||||
|
print(f"Overlap: {overlap}")
|
||||||
|
print(f"Recall: {recall}")
|
||||||
|
print(f"Search Time: {search_times[-1]:.4f}s")
|
||||||
|
print("--------------------------------")
|
||||||
|
|
||||||
|
avg_recall = np.mean(recall_scores) if recall_scores else 0
|
||||||
|
avg_time = np.mean(search_times) if search_times else 0
|
||||||
|
|
||||||
|
print("\n🎉 --- Evaluation Complete ---")
|
||||||
|
print(f"Avg. Recall@{args.top_k} (efSearch={args.ef_search}): {avg_recall:.4f}")
|
||||||
|
print(f"Avg. Search Time: {avg_time:.4f}s")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ An error occurred during evaluation: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
0
apps/wechat/__init__.py
Normal file
0
apps/wechat/__init__.py
Normal file
230
apps/wechat/__main__.py
Normal file
230
apps/wechat/__main__.py
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
import dotenv
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Any, Optional
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
import requests
|
||||||
|
import time
|
||||||
|
|
||||||
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
|
# Default WeChat export directory
|
||||||
|
DEFAULT_WECHAT_EXPORT_DIR = "./wechat_export_direct"
|
||||||
|
|
||||||
|
def create_leann_index_from_multiple_wechat_exports(
|
||||||
|
export_dirs: List[Path],
|
||||||
|
index_path: str = "wechat_history_index.leann",
|
||||||
|
max_count: int = -1,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create LEANN index from multiple WeChat export data sources.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
export_dirs: List of Path objects pointing to WeChat export directories
|
||||||
|
index_path: Path to save the LEANN index
|
||||||
|
max_count: Maximum number of chat entries to process per export
|
||||||
|
"""
|
||||||
|
print("Creating LEANN index from multiple WeChat export data sources...")
|
||||||
|
|
||||||
|
# Load documents using WeChatHistoryReader from local readers module
|
||||||
|
from .readers import WeChatHistoryReader
|
||||||
|
|
||||||
|
reader = WeChatHistoryReader()
|
||||||
|
|
||||||
|
INDEX_DIR = Path(index_path).parent
|
||||||
|
|
||||||
|
if not INDEX_DIR.exists():
|
||||||
|
print(f"--- Index directory not found, building new index ---")
|
||||||
|
all_documents = []
|
||||||
|
total_processed = 0
|
||||||
|
|
||||||
|
# Process each WeChat export directory
|
||||||
|
for i, export_dir in enumerate(export_dirs):
|
||||||
|
print(
|
||||||
|
f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
documents = reader.load_data(
|
||||||
|
wechat_export_dir=str(export_dir),
|
||||||
|
max_count=max_count,
|
||||||
|
concatenate_messages=True, # Disable concatenation - one message per document
|
||||||
|
)
|
||||||
|
if documents:
|
||||||
|
print(f"Loaded {len(documents)} chat documents from {export_dir}")
|
||||||
|
all_documents.extend(documents)
|
||||||
|
total_processed += len(documents)
|
||||||
|
|
||||||
|
# Check if we've reached the max count
|
||||||
|
if max_count > 0 and total_processed >= max_count:
|
||||||
|
print(f"Reached max count of {max_count} documents")
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
print(f"No documents loaded from {export_dir}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing {export_dir}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not all_documents:
|
||||||
|
print("No documents loaded from any source. Exiting.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create text splitter with 256 chunk size
|
||||||
|
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
|
||||||
|
|
||||||
|
# Convert Documents to text strings and chunk them
|
||||||
|
all_texts = []
|
||||||
|
for doc in all_documents:
|
||||||
|
# Split the document into chunks
|
||||||
|
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||||
|
for node in nodes:
|
||||||
|
text = '[Contact] means the message is from: ' + doc.metadata["contact_name"] + '\n' + node.get_content()
|
||||||
|
all_texts.append(text)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Created {len(all_texts)} text chunks from {len(all_documents)} documents"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create LEANN index directory
|
||||||
|
print(f"--- Index directory not found, building new index ---")
|
||||||
|
INDEX_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
print(f"--- Building new LEANN index ---")
|
||||||
|
|
||||||
|
print(f"\n[PHASE 1] Building Leann index...")
|
||||||
|
|
||||||
|
# Use HNSW backend for better macOS compatibility
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="Qwen/Qwen3-Embedding-0.6B",
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
is_compact=True,
|
||||||
|
is_recompute=True,
|
||||||
|
num_threads=1, # Force single-threaded mode
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Adding {len(all_texts)} chat chunks to index...")
|
||||||
|
for chunk_text in all_texts:
|
||||||
|
builder.add_text(chunk_text)
|
||||||
|
|
||||||
|
builder.build_index(index_path)
|
||||||
|
print(f"\nLEANN index built at {index_path}!")
|
||||||
|
else:
|
||||||
|
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||||
|
|
||||||
|
return index_path
|
||||||
|
|
||||||
|
async def query_leann_index(index_path: str, query: str):
|
||||||
|
"""
|
||||||
|
Query the LEANN index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index_path: Path to the LEANN index
|
||||||
|
query: The query string
|
||||||
|
"""
|
||||||
|
print(f"\n[PHASE 2] Starting Leann chat session...")
|
||||||
|
chat = LeannChat(index_path=index_path)
|
||||||
|
|
||||||
|
print(f"You: {query}")
|
||||||
|
chat_response = chat.ask(
|
||||||
|
query,
|
||||||
|
top_k=20,
|
||||||
|
recompute_beighbor_embeddings=True,
|
||||||
|
complexity=16,
|
||||||
|
beam_width=1,
|
||||||
|
llm_config={
|
||||||
|
"type": "openai",
|
||||||
|
"model": "gpt-4o",
|
||||||
|
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||||
|
},
|
||||||
|
llm_kwargs={"temperature": 0.0, "max_tokens": 1000},
|
||||||
|
)
|
||||||
|
print(f"Leann: {chat_response}")
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""Main function with integrated WeChat export functionality."""
|
||||||
|
|
||||||
|
# Parse command line arguments
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="LEANN WeChat History Reader - Create and query WeChat chat history index"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--export-dir",
|
||||||
|
type=str,
|
||||||
|
default=DEFAULT_WECHAT_EXPORT_DIR,
|
||||||
|
help=f"Directory to store WeChat exports (default: {DEFAULT_WECHAT_EXPORT_DIR})",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--index-dir",
|
||||||
|
type=str,
|
||||||
|
default="./wechat_history_magic_test_11Debug_new",
|
||||||
|
help="Directory to store the LEANN index (default: ./wechat_history_index_leann_test)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-entries",
|
||||||
|
type=int,
|
||||||
|
default=50,
|
||||||
|
help="Maximum number of chat entries to process (default: 5000)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--query",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Single query to run (default: runs example queries)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--force-export",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Force re-export of WeChat data even if exports exist",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
INDEX_DIR = Path(args.index_dir)
|
||||||
|
INDEX_PATH = str(INDEX_DIR / "wechat_history.leann")
|
||||||
|
|
||||||
|
print(f"Using WeChat export directory: {args.export_dir}")
|
||||||
|
print(f"Index directory: {INDEX_DIR}")
|
||||||
|
print(f"Max entries: {args.max_entries}")
|
||||||
|
|
||||||
|
# Initialize WeChat reader with export capabilities
|
||||||
|
from .readers import WeChatHistoryReader
|
||||||
|
|
||||||
|
reader = WeChatHistoryReader()
|
||||||
|
|
||||||
|
# Find existing exports or create new ones using the centralized method
|
||||||
|
export_dirs = reader.find_or_export_wechat_data(args.export_dir)
|
||||||
|
if not export_dirs:
|
||||||
|
print("Failed to find or export WeChat data. Exiting.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Create or load the LEANN index from all sources
|
||||||
|
index_path = create_leann_index_from_multiple_wechat_exports(
|
||||||
|
export_dirs, INDEX_PATH, max_count=args.max_entries
|
||||||
|
)
|
||||||
|
|
||||||
|
if index_path:
|
||||||
|
if args.query:
|
||||||
|
# Run single query
|
||||||
|
await query_leann_index(index_path, args.query)
|
||||||
|
else:
|
||||||
|
# Example queries
|
||||||
|
queries = [
|
||||||
|
"我想买魔术师约翰逊的球衣,给我一些对应聊天记录?",
|
||||||
|
]
|
||||||
|
|
||||||
|
for query in queries:
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
await query_leann_index(index_path, query)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
719
apps/wechat/readers.py
Normal file
719
apps/wechat/readers.py
Normal file
@@ -0,0 +1,719 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Any, Dict, Optional
|
||||||
|
from llama_index.core import Document
|
||||||
|
from llama_index.core.readers.base import BaseReader
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
class WeChatHistoryReader(BaseReader):
|
||||||
|
"""
|
||||||
|
WeChat chat history reader that extracts chat data from exported JSON files.
|
||||||
|
|
||||||
|
Reads WeChat chat history from exported JSON files (from wechat-exporter tool)
|
||||||
|
and creates documents with embedded metadata similar to the Chrome history reader structure.
|
||||||
|
|
||||||
|
Also includes utilities for automatic WeChat chat history export.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Initialize."""
|
||||||
|
self.packages_dir = Path(__file__).parent.parent.parent / "packages"
|
||||||
|
self.wechat_exporter_dir = self.packages_dir / "wechat-exporter"
|
||||||
|
self.wechat_decipher_dir = self.packages_dir / "wechat-decipher-macos"
|
||||||
|
|
||||||
|
def check_wechat_running(self) -> bool:
|
||||||
|
"""Check if WeChat is currently running."""
|
||||||
|
try:
|
||||||
|
result = subprocess.run(["pgrep", "-f", "WeChat"], capture_output=True, text=True)
|
||||||
|
return result.returncode == 0
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def install_wechattweak(self) -> bool:
|
||||||
|
"""Install WeChatTweak CLI tool."""
|
||||||
|
try:
|
||||||
|
# Create wechat-exporter directory if it doesn't exist
|
||||||
|
self.wechat_exporter_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
wechattweak_path = self.wechat_exporter_dir / "wechattweak-cli"
|
||||||
|
if not wechattweak_path.exists():
|
||||||
|
print("Downloading WeChatTweak CLI...")
|
||||||
|
subprocess.run([
|
||||||
|
"curl", "-L", "-o", str(wechattweak_path),
|
||||||
|
"https://github.com/JettChenT/WeChatTweak-CLI/releases/latest/download/wechattweak-cli"
|
||||||
|
], check=True)
|
||||||
|
|
||||||
|
# Make executable
|
||||||
|
wechattweak_path.chmod(0o755)
|
||||||
|
|
||||||
|
# Install WeChatTweak
|
||||||
|
print("Installing WeChatTweak...")
|
||||||
|
subprocess.run(["sudo", str(wechattweak_path), "install"], check=True)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error installing WeChatTweak: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def restart_wechat(self):
|
||||||
|
"""Restart WeChat to apply WeChatTweak."""
|
||||||
|
try:
|
||||||
|
print("Restarting WeChat...")
|
||||||
|
subprocess.run(["pkill", "-f", "WeChat"], check=False)
|
||||||
|
time.sleep(2)
|
||||||
|
subprocess.run(["open", "-a", "WeChat"], check=True)
|
||||||
|
time.sleep(5) # Wait for WeChat to start
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error restarting WeChat: {e}")
|
||||||
|
|
||||||
|
def check_api_available(self) -> bool:
|
||||||
|
"""Check if WeChatTweak API is available."""
|
||||||
|
try:
|
||||||
|
result = subprocess.run([
|
||||||
|
"curl", "-s", "http://localhost:48065/wechat/allcontacts"
|
||||||
|
], capture_output=True, text=True, timeout=5)
|
||||||
|
return result.returncode == 0 and result.stdout.strip()
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_readable_text(self, content: str) -> str:
|
||||||
|
"""
|
||||||
|
Extract readable text from message content, removing XML and system messages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: The raw message content (can be string or dict)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cleaned, readable text
|
||||||
|
"""
|
||||||
|
if not content:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# Handle dictionary content (like quoted messages)
|
||||||
|
if isinstance(content, dict):
|
||||||
|
# Extract text from dictionary structure
|
||||||
|
text_parts = []
|
||||||
|
if 'title' in content:
|
||||||
|
text_parts.append(str(content['title']))
|
||||||
|
if 'quoted' in content:
|
||||||
|
text_parts.append(str(content['quoted']))
|
||||||
|
if 'content' in content:
|
||||||
|
text_parts.append(str(content['content']))
|
||||||
|
if 'text' in content:
|
||||||
|
text_parts.append(str(content['text']))
|
||||||
|
|
||||||
|
if text_parts:
|
||||||
|
return " | ".join(text_parts)
|
||||||
|
else:
|
||||||
|
# If we can't extract meaningful text from dict, return empty
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# Handle string content
|
||||||
|
if not isinstance(content, str):
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# Remove common prefixes like "wxid_xxx:\n"
|
||||||
|
clean_content = re.sub(r'^wxid_[^:]+:\s*', '', content)
|
||||||
|
clean_content = re.sub(r'^[^:]+:\s*', '', clean_content)
|
||||||
|
|
||||||
|
# If it's just XML or system message, return empty
|
||||||
|
if clean_content.strip().startswith('<') or 'recalled a message' in clean_content:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
return clean_content.strip()
|
||||||
|
|
||||||
|
def _is_text_message(self, content: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a message contains readable text content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: The message content (can be string or dict)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the message contains readable text, False otherwise
|
||||||
|
"""
|
||||||
|
if not content:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Handle dictionary content
|
||||||
|
if isinstance(content, dict):
|
||||||
|
# Check if dict has any readable text fields
|
||||||
|
text_fields = ['title', 'quoted', 'content', 'text']
|
||||||
|
for field in text_fields:
|
||||||
|
if field in content and content[field]:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Handle string content
|
||||||
|
if not isinstance(content, str):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Skip image messages (contain XML with img tags)
|
||||||
|
if '<img' in content and 'cdnurl' in content:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Skip emoji messages (contain emoji XML tags)
|
||||||
|
if '<emoji' in content and 'productid' in content:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Skip voice messages
|
||||||
|
if '<voice' in content:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Skip video messages
|
||||||
|
if '<video' in content:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Skip file messages
|
||||||
|
if '<appmsg' in content and 'appid' in content:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Skip system messages (like "recalled a message")
|
||||||
|
if 'recalled a message' in content:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if there's actual readable text (not just XML or system messages)
|
||||||
|
# Remove common prefixes like "wxid_xxx:\n" and check for actual content
|
||||||
|
clean_content = re.sub(r'^wxid_[^:]+:\s*', '', content)
|
||||||
|
clean_content = re.sub(r'^[^:]+:\s*', '', clean_content)
|
||||||
|
|
||||||
|
# If after cleaning we have meaningful text, consider it readable
|
||||||
|
if len(clean_content.strip()) > 0 and not clean_content.strip().startswith('<'):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _concatenate_messages(self, messages: List[Dict], max_length: int = 128,
|
||||||
|
time_window_minutes: int = 30, overlap_messages: int = 0) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
Concatenate messages based on length and time rules.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of message dictionaries
|
||||||
|
max_length: Maximum length for concatenated message groups. Use -1 to disable length constraint.
|
||||||
|
time_window_minutes: Time window in minutes to group messages together. Use -1 to disable time constraint.
|
||||||
|
overlap_messages: Number of messages to overlap between consecutive groups
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of concatenated message groups
|
||||||
|
"""
|
||||||
|
if not messages:
|
||||||
|
return []
|
||||||
|
|
||||||
|
concatenated_groups = []
|
||||||
|
current_group = []
|
||||||
|
current_length = 0
|
||||||
|
last_timestamp = None
|
||||||
|
|
||||||
|
for message in messages:
|
||||||
|
# Extract message info
|
||||||
|
content = message.get('content', '')
|
||||||
|
message_text = message.get('message', '')
|
||||||
|
create_time = message.get('createTime', 0)
|
||||||
|
from_user = message.get('fromUser', '')
|
||||||
|
to_user = message.get('toUser', '')
|
||||||
|
is_sent_from_self = message.get('isSentFromSelf', False)
|
||||||
|
|
||||||
|
# Extract readable text
|
||||||
|
readable_text = self._extract_readable_text(content)
|
||||||
|
if not readable_text:
|
||||||
|
readable_text = message_text
|
||||||
|
|
||||||
|
# Skip empty messages
|
||||||
|
if not readable_text.strip():
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check time window constraint (only if time_window_minutes != -1)
|
||||||
|
if time_window_minutes != -1 and last_timestamp is not None and create_time > 0:
|
||||||
|
time_diff_minutes = (create_time - last_timestamp) / 60
|
||||||
|
if time_diff_minutes > time_window_minutes:
|
||||||
|
# Time gap too large, start new group
|
||||||
|
if current_group:
|
||||||
|
concatenated_groups.append({
|
||||||
|
'messages': current_group,
|
||||||
|
'total_length': current_length,
|
||||||
|
'start_time': current_group[0].get('createTime', 0),
|
||||||
|
'end_time': current_group[-1].get('createTime', 0)
|
||||||
|
})
|
||||||
|
# Keep last few messages for overlap
|
||||||
|
if overlap_messages > 0 and len(current_group) > overlap_messages:
|
||||||
|
current_group = current_group[-overlap_messages:]
|
||||||
|
current_length = sum(len(self._extract_readable_text(msg.get('content', '')) or msg.get('message', '')) for msg in current_group)
|
||||||
|
else:
|
||||||
|
current_group = []
|
||||||
|
current_length = 0
|
||||||
|
|
||||||
|
# Check length constraint (only if max_length != -1)
|
||||||
|
message_length = len(readable_text)
|
||||||
|
if max_length != -1 and current_length + message_length > max_length and current_group:
|
||||||
|
# Current group would exceed max length, save it and start new
|
||||||
|
concatenated_groups.append({
|
||||||
|
'messages': current_group,
|
||||||
|
'total_length': current_length,
|
||||||
|
'start_time': current_group[0].get('createTime', 0),
|
||||||
|
'end_time': current_group[-1].get('createTime', 0)
|
||||||
|
})
|
||||||
|
# Keep last few messages for overlap
|
||||||
|
if overlap_messages > 0 and len(current_group) > overlap_messages:
|
||||||
|
current_group = current_group[-overlap_messages:]
|
||||||
|
current_length = sum(len(self._extract_readable_text(msg.get('content', '')) or msg.get('message', '')) for msg in current_group)
|
||||||
|
else:
|
||||||
|
current_group = []
|
||||||
|
current_length = 0
|
||||||
|
|
||||||
|
# Add message to current group
|
||||||
|
current_group.append(message)
|
||||||
|
current_length += message_length
|
||||||
|
last_timestamp = create_time
|
||||||
|
|
||||||
|
# Add the last group if it exists
|
||||||
|
if current_group:
|
||||||
|
concatenated_groups.append({
|
||||||
|
'messages': current_group,
|
||||||
|
'total_length': current_length,
|
||||||
|
'start_time': current_group[0].get('createTime', 0),
|
||||||
|
'end_time': current_group[-1].get('createTime', 0)
|
||||||
|
})
|
||||||
|
|
||||||
|
return concatenated_groups
|
||||||
|
|
||||||
|
def _create_concatenated_content(self, message_group: Dict, contact_name: str) -> str:
|
||||||
|
"""
|
||||||
|
Create concatenated content from a group of messages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message_group: Dictionary containing messages and metadata
|
||||||
|
contact_name: Name of the contact
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted concatenated content
|
||||||
|
"""
|
||||||
|
messages = message_group['messages']
|
||||||
|
start_time = message_group['start_time']
|
||||||
|
end_time = message_group['end_time']
|
||||||
|
|
||||||
|
# Format timestamps
|
||||||
|
if start_time:
|
||||||
|
try:
|
||||||
|
start_timestamp = datetime.fromtimestamp(start_time)
|
||||||
|
start_time_str = start_timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
except:
|
||||||
|
start_time_str = str(start_time)
|
||||||
|
else:
|
||||||
|
start_time_str = "Unknown"
|
||||||
|
|
||||||
|
if end_time:
|
||||||
|
try:
|
||||||
|
end_timestamp = datetime.fromtimestamp(end_time)
|
||||||
|
end_time_str = end_timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
except:
|
||||||
|
end_time_str = str(end_time)
|
||||||
|
else:
|
||||||
|
end_time_str = "Unknown"
|
||||||
|
|
||||||
|
# Build concatenated message content
|
||||||
|
message_parts = []
|
||||||
|
for message in messages:
|
||||||
|
content = message.get('content', '')
|
||||||
|
message_text = message.get('message', '')
|
||||||
|
create_time = message.get('createTime', 0)
|
||||||
|
is_sent_from_self = message.get('isSentFromSelf', False)
|
||||||
|
|
||||||
|
# Extract readable text
|
||||||
|
readable_text = self._extract_readable_text(content)
|
||||||
|
if not readable_text:
|
||||||
|
readable_text = message_text
|
||||||
|
|
||||||
|
# Format individual message
|
||||||
|
if create_time:
|
||||||
|
try:
|
||||||
|
timestamp = datetime.fromtimestamp(create_time)
|
||||||
|
# change to YYYY-MM-DD HH:MM:SS
|
||||||
|
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
except:
|
||||||
|
time_str = str(create_time)
|
||||||
|
else:
|
||||||
|
time_str = "Unknown"
|
||||||
|
|
||||||
|
sender = "[Me]" if is_sent_from_self else "[Contact]"
|
||||||
|
message_parts.append(f"({time_str}) {sender}: {readable_text}")
|
||||||
|
|
||||||
|
concatenated_text = "\n".join(message_parts)
|
||||||
|
|
||||||
|
# Create final document content
|
||||||
|
doc_content = f"""
|
||||||
|
Contact: {contact_name}
|
||||||
|
Time Range: {start_time_str} - {end_time_str}
|
||||||
|
Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
||||||
|
|
||||||
|
{concatenated_text}
|
||||||
|
"""
|
||||||
|
# TODO @yichuan give better format and rich info here!
|
||||||
|
doc_content = f"""
|
||||||
|
{concatenated_text}
|
||||||
|
"""
|
||||||
|
return doc_content, contact_name
|
||||||
|
|
||||||
|
def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
|
||||||
|
"""
|
||||||
|
Load WeChat chat history data from exported JSON files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_dir: Directory containing exported WeChat JSON files
|
||||||
|
**load_kwargs:
|
||||||
|
max_count (int): Maximum amount of chat entries to read.
|
||||||
|
wechat_export_dir (str): Custom path to WeChat export directory.
|
||||||
|
include_non_text (bool): Whether to include non-text messages (images, emojis, etc.)
|
||||||
|
concatenate_messages (bool): Whether to concatenate messages based on length rules.
|
||||||
|
max_length (int): Maximum length for concatenated message groups (default: 1000).
|
||||||
|
time_window_minutes (int): Time window in minutes to group messages together (default: 30).
|
||||||
|
overlap_messages (int): Number of messages to overlap between consecutive groups (default: 2).
|
||||||
|
"""
|
||||||
|
docs: List[Document] = []
|
||||||
|
max_count = load_kwargs.get('max_count', 1000)
|
||||||
|
wechat_export_dir = load_kwargs.get('wechat_export_dir', None)
|
||||||
|
include_non_text = load_kwargs.get('include_non_text', False)
|
||||||
|
concatenate_messages = load_kwargs.get('concatenate_messages', False)
|
||||||
|
max_length = load_kwargs.get('max_length', 1000)
|
||||||
|
time_window_minutes = load_kwargs.get('time_window_minutes', 30)
|
||||||
|
|
||||||
|
# Default WeChat export path
|
||||||
|
if wechat_export_dir is None:
|
||||||
|
wechat_export_dir = "./wechat_export_test"
|
||||||
|
|
||||||
|
if not os.path.exists(wechat_export_dir):
|
||||||
|
print(f"WeChat export directory not found at: {wechat_export_dir}")
|
||||||
|
return docs
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Find all JSON files in the export directory
|
||||||
|
json_files = list(Path(wechat_export_dir).glob("*.json"))
|
||||||
|
print(f"Found {len(json_files)} WeChat chat history files")
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
for json_file in json_files:
|
||||||
|
if count >= max_count and max_count > 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(json_file, 'r', encoding='utf-8') as f:
|
||||||
|
chat_data = json.load(f)
|
||||||
|
|
||||||
|
# Extract contact name from filename
|
||||||
|
contact_name = json_file.stem
|
||||||
|
|
||||||
|
if concatenate_messages:
|
||||||
|
# Filter messages to only include readable text messages
|
||||||
|
readable_messages = []
|
||||||
|
for message in chat_data:
|
||||||
|
try:
|
||||||
|
content = message.get('content', '')
|
||||||
|
if not include_non_text and not self._is_text_message(content):
|
||||||
|
continue
|
||||||
|
|
||||||
|
readable_text = self._extract_readable_text(content)
|
||||||
|
if not readable_text and not include_non_text:
|
||||||
|
continue
|
||||||
|
|
||||||
|
readable_messages.append(message)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing message in {json_file}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Concatenate messages based on rules
|
||||||
|
message_groups = self._concatenate_messages(
|
||||||
|
readable_messages,
|
||||||
|
max_length=-1,
|
||||||
|
time_window_minutes=-1,
|
||||||
|
overlap_messages=0 # Keep 2 messages overlap between groups
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create documents from concatenated groups
|
||||||
|
for message_group in message_groups:
|
||||||
|
if count >= max_count and max_count > 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
doc_content, contact_name = self._create_concatenated_content(message_group, contact_name)
|
||||||
|
doc = Document(text=doc_content, metadata={"contact_name": contact_name})
|
||||||
|
docs.append(doc)
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
print(f"Created {len(message_groups)} concatenated message groups for {contact_name}")
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Original single-message processing
|
||||||
|
for message in chat_data:
|
||||||
|
if count >= max_count and max_count > 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Extract message information
|
||||||
|
from_user = message.get('fromUser', '')
|
||||||
|
to_user = message.get('toUser', '')
|
||||||
|
content = message.get('content', '')
|
||||||
|
message_text = message.get('message', '')
|
||||||
|
create_time = message.get('createTime', 0)
|
||||||
|
is_sent_from_self = message.get('isSentFromSelf', False)
|
||||||
|
|
||||||
|
# Handle content that might be dict or string
|
||||||
|
try:
|
||||||
|
# Check if this is a readable text message
|
||||||
|
if not include_non_text and not self._is_text_message(content):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Extract readable text
|
||||||
|
readable_text = self._extract_readable_text(content)
|
||||||
|
if not readable_text and not include_non_text:
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
# Skip messages that cause processing errors
|
||||||
|
print(f"Error processing message in {json_file}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Convert timestamp to readable format
|
||||||
|
if create_time:
|
||||||
|
try:
|
||||||
|
timestamp = datetime.fromtimestamp(create_time)
|
||||||
|
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
except:
|
||||||
|
time_str = str(create_time)
|
||||||
|
else:
|
||||||
|
time_str = "Unknown"
|
||||||
|
|
||||||
|
# Create document content with metadata header and contact info
|
||||||
|
doc_content = f"""
|
||||||
|
Contact: {contact_name}
|
||||||
|
Is sent from self: {is_sent_from_self}
|
||||||
|
Time: {time_str}
|
||||||
|
Message: {readable_text if readable_text else message_text}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Create document with embedded metadata
|
||||||
|
doc = Document(text=doc_content, metadata={})
|
||||||
|
docs.append(doc)
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error reading {json_file}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"Loaded {len(docs)} WeChat chat documents")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error reading WeChat history: {e}")
|
||||||
|
return docs
|
||||||
|
|
||||||
|
return docs
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def find_wechat_export_dirs() -> List[Path]:
|
||||||
|
"""
|
||||||
|
Find all WeChat export directories.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of Path objects pointing to WeChat export directories
|
||||||
|
"""
|
||||||
|
export_dirs = []
|
||||||
|
|
||||||
|
# Look for common export directory names
|
||||||
|
possible_dirs = [
|
||||||
|
Path("./wechat_export_test"),
|
||||||
|
Path("./wechat_export"),
|
||||||
|
Path("./wechat_chat_history"),
|
||||||
|
Path("./chat_export")
|
||||||
|
]
|
||||||
|
|
||||||
|
for export_dir in possible_dirs:
|
||||||
|
if export_dir.exists() and export_dir.is_dir():
|
||||||
|
json_files = list(export_dir.glob("*.json"))
|
||||||
|
if json_files:
|
||||||
|
export_dirs.append(export_dir)
|
||||||
|
print(f"Found WeChat export directory: {export_dir} with {len(json_files)} files")
|
||||||
|
|
||||||
|
print(f"Found {len(export_dirs)} WeChat export directories")
|
||||||
|
return export_dirs
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def export_chat_to_file(output_file: str = "wechat_chat_export.txt", max_count: int = 1000, export_dir: str = None, include_non_text: bool = False):
|
||||||
|
"""
|
||||||
|
Export WeChat chat history to a text file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_file: Path to the output file
|
||||||
|
max_count: Maximum number of entries to export
|
||||||
|
export_dir: Directory containing WeChat JSON files
|
||||||
|
include_non_text: Whether to include non-text messages
|
||||||
|
"""
|
||||||
|
if export_dir is None:
|
||||||
|
export_dir = "./wechat_export_test"
|
||||||
|
|
||||||
|
if not os.path.exists(export_dir):
|
||||||
|
print(f"WeChat export directory not found at: {export_dir}")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
json_files = list(Path(export_dir).glob("*.json"))
|
||||||
|
|
||||||
|
with open(output_file, 'w', encoding='utf-8') as f:
|
||||||
|
count = 0
|
||||||
|
for json_file in json_files:
|
||||||
|
if count >= max_count and max_count > 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(json_file, 'r', encoding='utf-8') as json_f:
|
||||||
|
chat_data = json.load(json_f)
|
||||||
|
|
||||||
|
contact_name = json_file.stem
|
||||||
|
f.write(f"\n=== Chat with {contact_name} ===\n")
|
||||||
|
|
||||||
|
for message in chat_data:
|
||||||
|
if count >= max_count and max_count > 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
from_user = message.get('fromUser', '')
|
||||||
|
content = message.get('content', '')
|
||||||
|
message_text = message.get('message', '')
|
||||||
|
create_time = message.get('createTime', 0)
|
||||||
|
|
||||||
|
# Skip non-text messages unless requested
|
||||||
|
if not include_non_text:
|
||||||
|
reader = WeChatHistoryReader()
|
||||||
|
if not reader._is_text_message(content):
|
||||||
|
continue
|
||||||
|
readable_text = reader._extract_readable_text(content)
|
||||||
|
if not readable_text:
|
||||||
|
continue
|
||||||
|
message_text = readable_text
|
||||||
|
|
||||||
|
if create_time:
|
||||||
|
try:
|
||||||
|
timestamp = datetime.fromtimestamp(create_time)
|
||||||
|
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
except:
|
||||||
|
time_str = str(create_time)
|
||||||
|
else:
|
||||||
|
time_str = "Unknown"
|
||||||
|
|
||||||
|
f.write(f"[{time_str}] {from_user}: {message_text}\n")
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing {json_file}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"Exported {count} chat entries to {output_file}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error exporting WeChat chat history: {e}")
|
||||||
|
|
||||||
|
def export_wechat_chat_history(self, export_dir: str = "./wechat_export_direct") -> Optional[Path]:
|
||||||
|
"""
|
||||||
|
Export WeChat chat history using wechat-exporter tool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
export_dir: Directory to save exported chat history
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to export directory if successful, None otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# Create export directory
|
||||||
|
export_path = Path(export_dir)
|
||||||
|
export_path.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
print(f"Exporting WeChat chat history to {export_path}...")
|
||||||
|
|
||||||
|
# Check if wechat-exporter directory exists
|
||||||
|
if not self.wechat_exporter_dir.exists():
|
||||||
|
print(f"wechat-exporter directory not found at: {self.wechat_exporter_dir}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Install requirements if needed
|
||||||
|
requirements_file = self.wechat_exporter_dir / "requirements.txt"
|
||||||
|
if requirements_file.exists():
|
||||||
|
print("Installing wechat-exporter requirements...")
|
||||||
|
subprocess.run([
|
||||||
|
"uv", "pip", "install", "-r", str(requirements_file)
|
||||||
|
], check=True)
|
||||||
|
|
||||||
|
# Run the export command
|
||||||
|
print("Running wechat-exporter...")
|
||||||
|
result = subprocess.run([
|
||||||
|
sys.executable, str(self.wechat_exporter_dir / "main.py"),
|
||||||
|
"export-all", str(export_path)
|
||||||
|
], capture_output=True, text=True, check=True)
|
||||||
|
|
||||||
|
print("Export command output:")
|
||||||
|
print(result.stdout)
|
||||||
|
if result.stderr:
|
||||||
|
print("Export errors:")
|
||||||
|
print(result.stderr)
|
||||||
|
|
||||||
|
# Check if export was successful
|
||||||
|
if export_path.exists() and any(export_path.glob("*.json")):
|
||||||
|
json_files = list(export_path.glob("*.json"))
|
||||||
|
print(f"Successfully exported {len(json_files)} chat history files to {export_path}")
|
||||||
|
return export_path
|
||||||
|
else:
|
||||||
|
print("Export completed but no JSON files found")
|
||||||
|
return None
|
||||||
|
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
print(f"Export command failed: {e}")
|
||||||
|
print(f"Command output: {e.stdout}")
|
||||||
|
print(f"Command errors: {e.stderr}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Export failed: {e}")
|
||||||
|
print("Please ensure WeChat is running and WeChatTweak is installed.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def find_or_export_wechat_data(self, export_dir: str = "./wechat_export_direct") -> List[Path]:
|
||||||
|
"""
|
||||||
|
Find existing WeChat exports or create new ones.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
export_dir: Directory to save exported chat history if needed
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of Path objects pointing to WeChat export directories
|
||||||
|
"""
|
||||||
|
export_dirs = []
|
||||||
|
|
||||||
|
# Look for existing exports in common locations
|
||||||
|
possible_export_dirs = [
|
||||||
|
Path("./wechat_database_export"),
|
||||||
|
Path("./wechat_export_test"),
|
||||||
|
Path("./wechat_export"),
|
||||||
|
Path("./wechat_export_direct"),
|
||||||
|
Path("./wechat_chat_history"),
|
||||||
|
Path("./chat_export")
|
||||||
|
]
|
||||||
|
|
||||||
|
for export_dir_path in possible_export_dirs:
|
||||||
|
if export_dir_path.exists() and any(export_dir_path.glob("*.json")):
|
||||||
|
export_dirs.append(export_dir_path)
|
||||||
|
print(f"Found existing export: {export_dir_path}")
|
||||||
|
|
||||||
|
# If no existing exports, try to export automatically
|
||||||
|
if not export_dirs:
|
||||||
|
print("No existing WeChat exports found. Starting direct export...")
|
||||||
|
|
||||||
|
# Try to export using wechat-exporter
|
||||||
|
exported_path = self.export_wechat_chat_history(export_dir)
|
||||||
|
if exported_path:
|
||||||
|
export_dirs = [exported_path]
|
||||||
|
else:
|
||||||
|
print("Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.")
|
||||||
|
|
||||||
|
return export_dirs
|
||||||
@@ -1,189 +0,0 @@
|
|||||||
"""
|
|
||||||
WeChat History RAG example using the unified interface.
|
|
||||||
Supports WeChat chat history export and search.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# Add parent directory to path for imports
|
|
||||||
sys.path.insert(0, str(Path(__file__).parent))
|
|
||||||
|
|
||||||
from base_rag_example import BaseRAGExample
|
|
||||||
|
|
||||||
from .history_data.wechat_history import WeChatHistoryReader
|
|
||||||
|
|
||||||
|
|
||||||
class WeChatRAG(BaseRAGExample):
|
|
||||||
"""RAG example for WeChat chat history."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
# Set default values BEFORE calling super().__init__
|
|
||||||
self.max_items_default = -1 # Match original default
|
|
||||||
self.embedding_model_default = (
|
|
||||||
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__init__(
|
|
||||||
name="WeChat History",
|
|
||||||
description="Process and query WeChat chat history with LEANN",
|
|
||||||
default_index_name="wechat_history_magic_test_11Debug_new",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _add_specific_arguments(self, parser):
|
|
||||||
"""Add WeChat-specific arguments."""
|
|
||||||
wechat_group = parser.add_argument_group("WeChat Parameters")
|
|
||||||
wechat_group.add_argument(
|
|
||||||
"--export-dir",
|
|
||||||
type=str,
|
|
||||||
default="./wechat_export",
|
|
||||||
help="Directory to store WeChat exports (default: ./wechat_export)",
|
|
||||||
)
|
|
||||||
wechat_group.add_argument(
|
|
||||||
"--force-export",
|
|
||||||
action="store_true",
|
|
||||||
help="Force re-export of WeChat data even if exports exist",
|
|
||||||
)
|
|
||||||
wechat_group.add_argument(
|
|
||||||
"--chunk-size", type=int, default=192, help="Text chunk size (default: 192)"
|
|
||||||
)
|
|
||||||
wechat_group.add_argument(
|
|
||||||
"--chunk-overlap", type=int, default=64, help="Text chunk overlap (default: 64)"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _export_wechat_data(self, export_dir: Path) -> bool:
|
|
||||||
"""Export WeChat data using wechattweak-cli."""
|
|
||||||
print("Exporting WeChat data...")
|
|
||||||
|
|
||||||
# Check if WeChat is running
|
|
||||||
try:
|
|
||||||
result = subprocess.run(["pgrep", "WeChat"], capture_output=True, text=True)
|
|
||||||
if result.returncode != 0:
|
|
||||||
print("WeChat is not running. Please start WeChat first.")
|
|
||||||
return False
|
|
||||||
except Exception:
|
|
||||||
pass # pgrep might not be available on all systems
|
|
||||||
|
|
||||||
# Create export directory
|
|
||||||
export_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# Run export command
|
|
||||||
cmd = ["packages/wechat-exporter/wechattweak-cli", "export", str(export_dir)]
|
|
||||||
|
|
||||||
try:
|
|
||||||
print(f"Running: {' '.join(cmd)}")
|
|
||||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
|
||||||
|
|
||||||
if result.returncode == 0:
|
|
||||||
print("WeChat data exported successfully!")
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
print(f"Export failed: {result.stderr}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
except FileNotFoundError:
|
|
||||||
print("\nError: wechattweak-cli not found!")
|
|
||||||
print("Please install it first:")
|
|
||||||
print(" sudo packages/wechat-exporter/wechattweak-cli install")
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Export error: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def load_data(self, args) -> list[str]:
|
|
||||||
"""Load WeChat history and convert to text chunks."""
|
|
||||||
# Initialize WeChat reader with export capabilities
|
|
||||||
reader = WeChatHistoryReader()
|
|
||||||
|
|
||||||
# Find existing exports or create new ones using the centralized method
|
|
||||||
export_dirs = reader.find_or_export_wechat_data(args.export_dir)
|
|
||||||
if not export_dirs:
|
|
||||||
print("Failed to find or export WeChat data. Trying to find any existing exports...")
|
|
||||||
# Try to find any existing exports in common locations
|
|
||||||
export_dirs = reader.find_wechat_export_dirs()
|
|
||||||
if not export_dirs:
|
|
||||||
print("No WeChat data found. Please ensure WeChat exports exist.")
|
|
||||||
return []
|
|
||||||
|
|
||||||
# Load documents from all found export directories
|
|
||||||
all_documents = []
|
|
||||||
total_processed = 0
|
|
||||||
|
|
||||||
for i, export_dir in enumerate(export_dirs):
|
|
||||||
print(f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Apply max_items limit per export
|
|
||||||
max_per_export = -1
|
|
||||||
if args.max_items > 0:
|
|
||||||
remaining = args.max_items - total_processed
|
|
||||||
if remaining <= 0:
|
|
||||||
break
|
|
||||||
max_per_export = remaining
|
|
||||||
|
|
||||||
documents = reader.load_data(
|
|
||||||
wechat_export_dir=str(export_dir),
|
|
||||||
max_count=max_per_export,
|
|
||||||
concatenate_messages=True, # Enable message concatenation for better context
|
|
||||||
)
|
|
||||||
|
|
||||||
if documents:
|
|
||||||
print(f"Loaded {len(documents)} chat documents from {export_dir}")
|
|
||||||
all_documents.extend(documents)
|
|
||||||
total_processed += len(documents)
|
|
||||||
else:
|
|
||||||
print(f"No documents loaded from {export_dir}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error processing {export_dir}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not all_documents:
|
|
||||||
print("No documents loaded from any source. Exiting.")
|
|
||||||
return []
|
|
||||||
|
|
||||||
print(f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports")
|
|
||||||
print("now starting to split into text chunks ... take some time")
|
|
||||||
|
|
||||||
# Convert to text chunks with contact information
|
|
||||||
all_texts = []
|
|
||||||
for doc in all_documents:
|
|
||||||
# Split the document into chunks
|
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
|
||||||
|
|
||||||
text_splitter = SentenceSplitter(
|
|
||||||
chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
|
||||||
)
|
|
||||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
|
||||||
|
|
||||||
for node in nodes:
|
|
||||||
# Add contact information to each chunk
|
|
||||||
contact_name = doc.metadata.get("contact_name", "Unknown")
|
|
||||||
text = f"[Contact] means the message is from: {contact_name}\n" + node.get_content()
|
|
||||||
all_texts.append(text)
|
|
||||||
|
|
||||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
|
|
||||||
return all_texts
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
# Check platform
|
|
||||||
if sys.platform != "darwin":
|
|
||||||
print("\n⚠️ Warning: WeChat export is only supported on macOS")
|
|
||||||
print(" You can still query existing exports on other platforms\n")
|
|
||||||
|
|
||||||
# Example queries for WeChat RAG
|
|
||||||
print("\n💬 WeChat History RAG Example")
|
|
||||||
print("=" * 50)
|
|
||||||
print("\nExample queries you can try:")
|
|
||||||
print("- 'Show me conversations about travel plans'")
|
|
||||||
print("- 'Find group chats about weekend activities'")
|
|
||||||
print("- '我想买魔术师约翰逊的球衣,给我一些对应聊天记录?'")
|
|
||||||
print("- 'What did we discuss about the project last month?'")
|
|
||||||
print("\nNote: WeChat must be running for export to work\n")
|
|
||||||
|
|
||||||
rag = WeChatRAG()
|
|
||||||
asyncio.run(rag.run())
|
|
||||||
Binary file not shown.
|
Before Width: | Height: | Size: 73 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 224 KiB |
@@ -1,148 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from leann import LeannBuilder, LeannSearcher
|
|
||||||
|
|
||||||
|
|
||||||
def _meta_exists(index_path: str) -> bool:
|
|
||||||
p = Path(index_path)
|
|
||||||
return (p.parent / f"{p.stem}.meta.json").exists()
|
|
||||||
|
|
||||||
|
|
||||||
def ensure_index(index_path: str, backend_name: str, num_docs: int, is_recompute: bool) -> None:
|
|
||||||
# if _meta_exists(index_path):
|
|
||||||
# return
|
|
||||||
kwargs = {}
|
|
||||||
if backend_name == "hnsw":
|
|
||||||
kwargs["is_compact"] = is_recompute
|
|
||||||
builder = LeannBuilder(
|
|
||||||
backend_name=backend_name,
|
|
||||||
embedding_model=os.getenv("LEANN_EMBED_MODEL", "facebook/contriever"),
|
|
||||||
embedding_mode=os.getenv("LEANN_EMBED_MODE", "sentence-transformers"),
|
|
||||||
graph_degree=32,
|
|
||||||
complexity=64,
|
|
||||||
is_recompute=is_recompute,
|
|
||||||
num_threads=4,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
for i in range(num_docs):
|
|
||||||
builder.add_text(
|
|
||||||
f"This is a test document number {i}. It contains some repeated text for benchmarking."
|
|
||||||
)
|
|
||||||
builder.build_index(index_path)
|
|
||||||
|
|
||||||
|
|
||||||
def _bench_group(
|
|
||||||
index_path: str,
|
|
||||||
recompute: bool,
|
|
||||||
query: str,
|
|
||||||
repeats: int,
|
|
||||||
complexity: int = 32,
|
|
||||||
top_k: int = 10,
|
|
||||||
) -> float:
|
|
||||||
# Independent searcher per group; fixed port when recompute
|
|
||||||
searcher = LeannSearcher(index_path=index_path)
|
|
||||||
|
|
||||||
# Warm-up once
|
|
||||||
_ = searcher.search(
|
|
||||||
query,
|
|
||||||
top_k=top_k,
|
|
||||||
complexity=complexity,
|
|
||||||
recompute_embeddings=recompute,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _once() -> float:
|
|
||||||
t0 = time.time()
|
|
||||||
_ = searcher.search(
|
|
||||||
query,
|
|
||||||
top_k=top_k,
|
|
||||||
complexity=complexity,
|
|
||||||
recompute_embeddings=recompute,
|
|
||||||
)
|
|
||||||
return time.time() - t0
|
|
||||||
|
|
||||||
if repeats <= 1:
|
|
||||||
t = _once()
|
|
||||||
else:
|
|
||||||
vals = [_once() for _ in range(repeats)]
|
|
||||||
vals.sort()
|
|
||||||
t = vals[len(vals) // 2]
|
|
||||||
|
|
||||||
searcher.cleanup()
|
|
||||||
return t
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument("--num-docs", type=int, default=5000)
|
|
||||||
parser.add_argument("--repeats", type=int, default=3)
|
|
||||||
parser.add_argument("--complexity", type=int, default=32)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
base = Path.cwd() / ".leann" / "indexes" / f"bench_n{args.num_docs}"
|
|
||||||
base.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
# ---------- Build HNSW variants ----------
|
|
||||||
hnsw_r = str(base / f"hnsw_recompute_n{args.num_docs}.leann")
|
|
||||||
hnsw_nr = str(base / f"hnsw_norecompute_n{args.num_docs}.leann")
|
|
||||||
ensure_index(hnsw_r, "hnsw", args.num_docs, True)
|
|
||||||
ensure_index(hnsw_nr, "hnsw", args.num_docs, False)
|
|
||||||
|
|
||||||
# ---------- Build DiskANN variants ----------
|
|
||||||
diskann_r = str(base / "diskann_r.leann")
|
|
||||||
diskann_nr = str(base / "diskann_nr.leann")
|
|
||||||
ensure_index(diskann_r, "diskann", args.num_docs, True)
|
|
||||||
ensure_index(diskann_nr, "diskann", args.num_docs, False)
|
|
||||||
|
|
||||||
# ---------- Helpers ----------
|
|
||||||
def _size_for(prefix: str) -> int:
|
|
||||||
p = Path(prefix)
|
|
||||||
base_dir = p.parent
|
|
||||||
stem = p.stem
|
|
||||||
total = 0
|
|
||||||
for f in base_dir.iterdir():
|
|
||||||
if f.is_file() and f.name.startswith(stem):
|
|
||||||
total += f.stat().st_size
|
|
||||||
return total
|
|
||||||
|
|
||||||
# ---------- HNSW benchmark ----------
|
|
||||||
t_hnsw_r = _bench_group(
|
|
||||||
hnsw_r, True, "test document number 42", repeats=args.repeats, complexity=args.complexity
|
|
||||||
)
|
|
||||||
t_hnsw_nr = _bench_group(
|
|
||||||
hnsw_nr, False, "test document number 42", repeats=args.repeats, complexity=args.complexity
|
|
||||||
)
|
|
||||||
size_hnsw_r = _size_for(hnsw_r)
|
|
||||||
size_hnsw_nr = _size_for(hnsw_nr)
|
|
||||||
|
|
||||||
print("Benchmark results (HNSW):")
|
|
||||||
print(f" recompute=True: search_time={t_hnsw_r:.3f}s, size={size_hnsw_r / 1024 / 1024:.1f}MB")
|
|
||||||
print(
|
|
||||||
f" recompute=False: search_time={t_hnsw_nr:.3f}s, size={size_hnsw_nr / 1024 / 1024:.1f}MB"
|
|
||||||
)
|
|
||||||
print(" Expectation: no-recompute should be faster but larger on disk.")
|
|
||||||
|
|
||||||
# ---------- DiskANN benchmark ----------
|
|
||||||
t_diskann_r = _bench_group(
|
|
||||||
diskann_r, True, "DiskANN R test doc 123", repeats=args.repeats, complexity=args.complexity
|
|
||||||
)
|
|
||||||
t_diskann_nr = _bench_group(
|
|
||||||
diskann_nr,
|
|
||||||
False,
|
|
||||||
"DiskANN NR test doc 123",
|
|
||||||
repeats=args.repeats,
|
|
||||||
complexity=args.complexity,
|
|
||||||
)
|
|
||||||
size_diskann_r = _size_for(diskann_r)
|
|
||||||
size_diskann_nr = _size_for(diskann_nr)
|
|
||||||
|
|
||||||
print("\nBenchmark results (DiskANN):")
|
|
||||||
print(f" build(recompute=True, partition): size={size_diskann_r / 1024 / 1024:.1f}MB")
|
|
||||||
print(f" build(recompute=False): size={size_diskann_nr / 1024 / 1024:.1f}MB")
|
|
||||||
print(f" search recompute=True (final rerank): {t_diskann_r:.3f}s")
|
|
||||||
print(f" search recompute=False (PQ only): {t_diskann_nr:.3f}s")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,286 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
DiskANN vs HNSW Search Performance Comparison
|
|
||||||
|
|
||||||
This benchmark compares search performance between DiskANN and HNSW backends:
|
|
||||||
- DiskANN: With graph partitioning enabled (is_recompute=True)
|
|
||||||
- HNSW: With recompute enabled (is_recompute=True)
|
|
||||||
- Tests performance across different dataset sizes
|
|
||||||
- Measures search latency, recall, and index size
|
|
||||||
"""
|
|
||||||
|
|
||||||
import gc
|
|
||||||
import multiprocessing as mp
|
|
||||||
import tempfile
|
|
||||||
import time
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
# Prefer 'fork' start method to avoid POSIX semaphore leaks on macOS
|
|
||||||
try:
|
|
||||||
mp.set_start_method("fork", force=True)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def create_test_texts(n_docs: int) -> list[str]:
|
|
||||||
"""Create synthetic test documents for benchmarking."""
|
|
||||||
np.random.seed(42)
|
|
||||||
topics = [
|
|
||||||
"machine learning and artificial intelligence",
|
|
||||||
"natural language processing and text analysis",
|
|
||||||
"computer vision and image recognition",
|
|
||||||
"data science and statistical analysis",
|
|
||||||
"deep learning and neural networks",
|
|
||||||
"information retrieval and search engines",
|
|
||||||
"database systems and data management",
|
|
||||||
"software engineering and programming",
|
|
||||||
"cybersecurity and network protection",
|
|
||||||
"cloud computing and distributed systems",
|
|
||||||
]
|
|
||||||
|
|
||||||
texts = []
|
|
||||||
for i in range(n_docs):
|
|
||||||
topic = topics[i % len(topics)]
|
|
||||||
variation = np.random.randint(1, 100)
|
|
||||||
text = (
|
|
||||||
f"This is document {i} about {topic}. Content variation {variation}. "
|
|
||||||
f"Additional information about {topic} with details and examples. "
|
|
||||||
f"Technical discussion of {topic} including implementation aspects."
|
|
||||||
)
|
|
||||||
texts.append(text)
|
|
||||||
|
|
||||||
return texts
|
|
||||||
|
|
||||||
|
|
||||||
def benchmark_backend(
|
|
||||||
backend_name: str, texts: list[str], test_queries: list[str], backend_kwargs: dict[str, Any]
|
|
||||||
) -> dict[str, float]:
|
|
||||||
"""Benchmark a specific backend with the given configuration."""
|
|
||||||
from leann.api import LeannBuilder, LeannSearcher
|
|
||||||
|
|
||||||
print(f"\n🔧 Testing {backend_name.upper()} backend...")
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as temp_dir:
|
|
||||||
index_path = str(Path(temp_dir) / f"benchmark_{backend_name}.leann")
|
|
||||||
|
|
||||||
# Build index
|
|
||||||
print(f"📦 Building {backend_name} index with {len(texts)} documents...")
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
builder = LeannBuilder(
|
|
||||||
backend_name=backend_name,
|
|
||||||
embedding_model="facebook/contriever",
|
|
||||||
embedding_mode="sentence-transformers",
|
|
||||||
**backend_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
for text in texts:
|
|
||||||
builder.add_text(text)
|
|
||||||
|
|
||||||
builder.build_index(index_path)
|
|
||||||
build_time = time.time() - start_time
|
|
||||||
|
|
||||||
# Measure index size
|
|
||||||
index_dir = Path(index_path).parent
|
|
||||||
index_files = list(index_dir.glob(f"{Path(index_path).stem}.*"))
|
|
||||||
total_size = sum(f.stat().st_size for f in index_files if f.is_file())
|
|
||||||
size_mb = total_size / (1024 * 1024)
|
|
||||||
|
|
||||||
print(f" ✅ Build completed in {build_time:.2f}s, index size: {size_mb:.1f}MB")
|
|
||||||
|
|
||||||
# Search benchmark
|
|
||||||
print("🔍 Running search benchmark...")
|
|
||||||
searcher = LeannSearcher(index_path)
|
|
||||||
|
|
||||||
search_times = []
|
|
||||||
all_results = []
|
|
||||||
|
|
||||||
for query in test_queries:
|
|
||||||
start_time = time.time()
|
|
||||||
results = searcher.search(query, top_k=5)
|
|
||||||
search_time = time.time() - start_time
|
|
||||||
search_times.append(search_time)
|
|
||||||
all_results.append(results)
|
|
||||||
|
|
||||||
avg_search_time = np.mean(search_times) * 1000 # Convert to ms
|
|
||||||
print(f" ✅ Average search time: {avg_search_time:.1f}ms")
|
|
||||||
|
|
||||||
# Check for valid scores (detect -inf issues)
|
|
||||||
all_scores = [
|
|
||||||
result.score
|
|
||||||
for results in all_results
|
|
||||||
for result in results
|
|
||||||
if result.score is not None
|
|
||||||
]
|
|
||||||
valid_scores = [
|
|
||||||
score for score in all_scores if score != float("-inf") and score != float("inf")
|
|
||||||
]
|
|
||||||
score_validity_rate = len(valid_scores) / len(all_scores) if all_scores else 0
|
|
||||||
|
|
||||||
# Clean up (ensure embedding server shutdown and object GC)
|
|
||||||
try:
|
|
||||||
if hasattr(searcher, "cleanup"):
|
|
||||||
searcher.cleanup()
|
|
||||||
del searcher
|
|
||||||
del builder
|
|
||||||
gc.collect()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"⚠️ Warning: Resource cleanup error: {e}")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"build_time": build_time,
|
|
||||||
"avg_search_time_ms": avg_search_time,
|
|
||||||
"index_size_mb": size_mb,
|
|
||||||
"score_validity_rate": score_validity_rate,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def run_comparison(n_docs: int = 500, n_queries: int = 10):
|
|
||||||
"""Run performance comparison between DiskANN and HNSW."""
|
|
||||||
print("🚀 Starting DiskANN vs HNSW Performance Comparison")
|
|
||||||
print(f"📊 Dataset: {n_docs} documents, {n_queries} test queries")
|
|
||||||
|
|
||||||
# Create test data
|
|
||||||
texts = create_test_texts(n_docs)
|
|
||||||
test_queries = [
|
|
||||||
"machine learning algorithms",
|
|
||||||
"natural language processing",
|
|
||||||
"computer vision techniques",
|
|
||||||
"data analysis methods",
|
|
||||||
"neural network architectures",
|
|
||||||
"database query optimization",
|
|
||||||
"software development practices",
|
|
||||||
"security vulnerabilities",
|
|
||||||
"cloud infrastructure",
|
|
||||||
"distributed computing",
|
|
||||||
][:n_queries]
|
|
||||||
|
|
||||||
# HNSW benchmark
|
|
||||||
hnsw_results = benchmark_backend(
|
|
||||||
backend_name="hnsw",
|
|
||||||
texts=texts,
|
|
||||||
test_queries=test_queries,
|
|
||||||
backend_kwargs={
|
|
||||||
"is_recompute": True, # Enable recompute for fair comparison
|
|
||||||
"M": 16,
|
|
||||||
"efConstruction": 200,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# DiskANN benchmark
|
|
||||||
diskann_results = benchmark_backend(
|
|
||||||
backend_name="diskann",
|
|
||||||
texts=texts,
|
|
||||||
test_queries=test_queries,
|
|
||||||
backend_kwargs={
|
|
||||||
"is_recompute": True, # Enable graph partitioning
|
|
||||||
"num_neighbors": 32,
|
|
||||||
"search_list_size": 50,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Performance comparison
|
|
||||||
print("\n📈 Performance Comparison Results")
|
|
||||||
print(f"{'=' * 60}")
|
|
||||||
print(f"{'Metric':<25} {'HNSW':<15} {'DiskANN':<15} {'Speedup':<10}")
|
|
||||||
print(f"{'-' * 60}")
|
|
||||||
|
|
||||||
# Build time comparison
|
|
||||||
build_speedup = hnsw_results["build_time"] / diskann_results["build_time"]
|
|
||||||
print(
|
|
||||||
f"{'Build Time (s)':<25} {hnsw_results['build_time']:<15.2f} {diskann_results['build_time']:<15.2f} {build_speedup:<10.2f}x"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Search time comparison
|
|
||||||
search_speedup = hnsw_results["avg_search_time_ms"] / diskann_results["avg_search_time_ms"]
|
|
||||||
print(
|
|
||||||
f"{'Search Time (ms)':<25} {hnsw_results['avg_search_time_ms']:<15.1f} {diskann_results['avg_search_time_ms']:<15.1f} {search_speedup:<10.2f}x"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Index size comparison
|
|
||||||
size_ratio = diskann_results["index_size_mb"] / hnsw_results["index_size_mb"]
|
|
||||||
print(
|
|
||||||
f"{'Index Size (MB)':<25} {hnsw_results['index_size_mb']:<15.1f} {diskann_results['index_size_mb']:<15.1f} {size_ratio:<10.2f}x"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Score validity
|
|
||||||
print(
|
|
||||||
f"{'Score Validity (%)':<25} {hnsw_results['score_validity_rate'] * 100:<15.1f} {diskann_results['score_validity_rate'] * 100:<15.1f}"
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"{'=' * 60}")
|
|
||||||
print("\n🎯 Summary:")
|
|
||||||
if search_speedup > 1:
|
|
||||||
print(f" DiskANN is {search_speedup:.2f}x faster than HNSW for search")
|
|
||||||
else:
|
|
||||||
print(f" HNSW is {1 / search_speedup:.2f}x faster than DiskANN for search")
|
|
||||||
|
|
||||||
if size_ratio > 1:
|
|
||||||
print(f" DiskANN uses {size_ratio:.2f}x more storage than HNSW")
|
|
||||||
else:
|
|
||||||
print(f" DiskANN uses {1 / size_ratio:.2f}x less storage than HNSW")
|
|
||||||
|
|
||||||
print(
|
|
||||||
f" Both backends achieved {min(hnsw_results['score_validity_rate'], diskann_results['score_validity_rate']) * 100:.1f}% score validity"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import sys
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Handle help request
|
|
||||||
if len(sys.argv) > 1 and sys.argv[1] in ["-h", "--help", "help"]:
|
|
||||||
print("DiskANN vs HNSW Performance Comparison")
|
|
||||||
print("=" * 50)
|
|
||||||
print(f"Usage: python {sys.argv[0]} [n_docs] [n_queries]")
|
|
||||||
print()
|
|
||||||
print("Arguments:")
|
|
||||||
print(" n_docs Number of documents to index (default: 500)")
|
|
||||||
print(" n_queries Number of test queries to run (default: 10)")
|
|
||||||
print()
|
|
||||||
print("Examples:")
|
|
||||||
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py")
|
|
||||||
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 1000")
|
|
||||||
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20")
|
|
||||||
sys.exit(0)
|
|
||||||
|
|
||||||
# Parse command line arguments
|
|
||||||
n_docs = int(sys.argv[1]) if len(sys.argv) > 1 else 500
|
|
||||||
n_queries = int(sys.argv[2]) if len(sys.argv) > 2 else 10
|
|
||||||
|
|
||||||
print("DiskANN vs HNSW Performance Comparison")
|
|
||||||
print("=" * 50)
|
|
||||||
print(f"Dataset: {n_docs} documents, {n_queries} queries")
|
|
||||||
print()
|
|
||||||
|
|
||||||
run_comparison(n_docs=n_docs, n_queries=n_queries)
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print("\n⚠️ Benchmark interrupted by user")
|
|
||||||
sys.exit(130)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"\n❌ Benchmark failed: {e}")
|
|
||||||
sys.exit(1)
|
|
||||||
finally:
|
|
||||||
# Ensure clean exit (forceful to prevent rare hangs from atexit/threads)
|
|
||||||
try:
|
|
||||||
gc.collect()
|
|
||||||
print("\n🧹 Cleanup completed")
|
|
||||||
# Flush stdio to ensure message is visible before hard-exit
|
|
||||||
try:
|
|
||||||
import sys as _sys
|
|
||||||
|
|
||||||
_sys.stdout.flush()
|
|
||||||
_sys.stderr.flush()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
# Use os._exit to bypass atexit handlers that may hang in rare cases
|
|
||||||
import os as _os
|
|
||||||
|
|
||||||
_os._exit(0)
|
|
||||||
Binary file not shown.
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
44
data/README.md
Normal file
44
data/README.md
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
---
|
||||||
|
license: mit
|
||||||
|
---
|
||||||
|
|
||||||
|
# LEANN-RAG Evaluation Data
|
||||||
|
|
||||||
|
This repository contains the necessary data to run the recall evaluation scripts for the [LEANN-RAG](https://huggingface.co/LEANN-RAG) project.
|
||||||
|
|
||||||
|
## Dataset Components
|
||||||
|
|
||||||
|
This dataset is structured into three main parts:
|
||||||
|
|
||||||
|
1. **Pre-built LEANN Indices**:
|
||||||
|
* `dpr/`: A pre-built index for the DPR dataset.
|
||||||
|
* `rpj_wiki/`: A pre-built index for the RPJ-Wiki dataset.
|
||||||
|
These indices were created using the `leann-core` library and are required by the `LeannSearcher`.
|
||||||
|
|
||||||
|
2. **Ground Truth Data**:
|
||||||
|
* `ground_truth/`: Contains the ground truth files (`flat_results_nq_k3.json`) for both the DPR and RPJ-Wiki datasets. These files map queries to the original passage IDs from the Natural Questions benchmark, evaluated using the Contriever model.
|
||||||
|
|
||||||
|
3. **Queries**:
|
||||||
|
* `queries/`: Contains the `nq_open.jsonl` file with the Natural Questions queries used for the evaluation.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
To use this data, you can download it locally using the `huggingface-hub` library. First, install the library:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install huggingface-hub
|
||||||
|
```
|
||||||
|
|
||||||
|
Then, you can download the entire dataset to a local directory (e.g., `data/`) with the following Python script:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
snapshot_download(
|
||||||
|
repo_id="LEANN-RAG/leann-rag-evaluation-data",
|
||||||
|
repo_type="dataset",
|
||||||
|
local_dir="data"
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
This will download all the necessary files into a local `data` folder, preserving the repository structure. The evaluation scripts in the main [LEANN-RAG Space](https://huggingface.co/LEANN-RAG) are configured to work with this data structure.
|
||||||
105
demo.ipynb
105
demo.ipynb
@@ -1,116 +1,37 @@
|
|||||||
{
|
{
|
||||||
"cells": [
|
"cells": [
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"# Quick Start \n",
|
|
||||||
"\n",
|
|
||||||
"**Home GitHub Repository:** [LEANN on GitHub](https://github.com/yichuan-w/LEANN)\n",
|
|
||||||
"\n",
|
|
||||||
"**Important for Colab users:** Set your runtime type to T4 GPU for optimal performance. Go to Runtime → Change runtime type → Hardware accelerator → T4 GPU."
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# install this if you are using colab\n",
|
"from leann.api import LeannBuilder, LeannSearcher, LeannChat\n",
|
||||||
"! uv pip install leann-core leann-backend-hnsw --no-deps\n",
|
|
||||||
"! uv pip install leann --no-deps\n",
|
|
||||||
"# For Colab environment, we need to set some environment variables\n",
|
|
||||||
"import os\n",
|
|
||||||
"\n",
|
|
||||||
"os.environ[\"LEANN_LOG_LEVEL\"] = \"INFO\" # Enable more detailed logging"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"from pathlib import Path\n",
|
|
||||||
"\n",
|
|
||||||
"INDEX_DIR = Path(\"./\").resolve()\n",
|
|
||||||
"INDEX_PATH = str(INDEX_DIR / \"demo.leann\")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"## Build the index"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"from leann.api import LeannBuilder\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
|
"# 1. Build the index (no embeddings stored!)\n",
|
||||||
"builder = LeannBuilder(backend_name=\"hnsw\")\n",
|
"builder = LeannBuilder(backend_name=\"hnsw\")\n",
|
||||||
"builder.add_text(\"C# is a powerful programming language and it is good at game development\")\n",
|
"builder.add_text(\"C# is a powerful programming language\")\n",
|
||||||
"builder.add_text(\n",
|
"builder.add_text(\"Python is a powerful programming language and it is very popular\")\n",
|
||||||
" \"Python is a powerful programming language and it is good at machine learning tasks\"\n",
|
|
||||||
")\n",
|
|
||||||
"builder.add_text(\"Machine learning transforms industries\")\n",
|
"builder.add_text(\"Machine learning transforms industries\")\n",
|
||||||
"builder.add_text(\"Neural networks process complex data\")\n",
|
"builder.add_text(\"Neural networks process complex data\")\n",
|
||||||
"builder.add_text(\"Leann is a great storage saving engine for RAG on your MacBook\")\n",
|
"builder.add_text(\"Leann is a great storage saving engine for RAG on your MacBook\")\n",
|
||||||
"builder.build_index(INDEX_PATH)"
|
"builder.build_index(\"knowledge.leann\")\n",
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"## Search with real-time embeddings"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"from leann.api import LeannSearcher\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"searcher = LeannSearcher(INDEX_PATH)\n",
|
"# 2. Search with real-time embeddings\n",
|
||||||
|
"searcher = LeannSearcher(\"knowledge.leann\")\n",
|
||||||
"results = searcher.search(\"programming languages\", top_k=2)\n",
|
"results = searcher.search(\"programming languages\", top_k=2)\n",
|
||||||
"results"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"## Chat with LEANN using retrieved results"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"from leann.api import LeannChat\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
|
"# 3. Chat with LEANN using retrieved results\n",
|
||||||
"llm_config = {\n",
|
"llm_config = {\n",
|
||||||
" \"type\": \"hf\",\n",
|
" \"type\": \"ollama\",\n",
|
||||||
" \"model\": \"Qwen/Qwen3-0.6B\",\n",
|
" \"model\": \"llama3.2:1b\"\n",
|
||||||
"}\n",
|
"}\n",
|
||||||
"\n",
|
"\n",
|
||||||
"chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)\n",
|
"chat = LeannChat(index_path=\"knowledge.leann\", llm_config=llm_config)\n",
|
||||||
"response = chat.ask(\n",
|
"response = chat.ask(\n",
|
||||||
" \"Compare the two retrieved programming languages and tell me their advantages.\",\n",
|
" \"Compare the two retrieved programming languages and say which one is more popular today.\",\n",
|
||||||
" top_k=2,\n",
|
" top_k=2,\n",
|
||||||
" llm_kwargs={\"max_tokens\": 128},\n",
|
")"
|
||||||
")\n",
|
|
||||||
"response"
|
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
|||||||
@@ -1,220 +0,0 @@
|
|||||||
# 🤝 Contributing
|
|
||||||
|
|
||||||
We welcome contributions! Leann is built by the community, for the community.
|
|
||||||
|
|
||||||
## Ways to Contribute
|
|
||||||
|
|
||||||
- 🐛 **Bug Reports**: Found an issue? Let us know!
|
|
||||||
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
|
|
||||||
- 🔧 **Code Contributions**: PRs welcome for all skill levels
|
|
||||||
- 📖 **Documentation**: Help make Leann more accessible
|
|
||||||
- 🧪 **Benchmarks**: Share your performance results
|
|
||||||
|
|
||||||
## 🚀 Development Setup
|
|
||||||
|
|
||||||
### Prerequisites
|
|
||||||
|
|
||||||
1. **Install uv** (fast Python package installer):
|
|
||||||
```bash
|
|
||||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
|
||||||
```
|
|
||||||
|
|
||||||
2. **Clone the repository**:
|
|
||||||
```bash
|
|
||||||
git clone https://github.com/LEANN-RAG/LEANN-RAG.git
|
|
||||||
cd LEANN-RAG
|
|
||||||
```
|
|
||||||
|
|
||||||
3. **Install system dependencies**:
|
|
||||||
|
|
||||||
**macOS:**
|
|
||||||
```bash
|
|
||||||
brew install llvm libomp boost protobuf zeromq pkgconf
|
|
||||||
```
|
|
||||||
|
|
||||||
**Ubuntu/Debian:**
|
|
||||||
```bash
|
|
||||||
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler \
|
|
||||||
libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
|
|
||||||
```
|
|
||||||
|
|
||||||
4. **Build from source**:
|
|
||||||
```bash
|
|
||||||
# macOS
|
|
||||||
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync
|
|
||||||
|
|
||||||
# Ubuntu/Debian
|
|
||||||
uv sync
|
|
||||||
```
|
|
||||||
|
|
||||||
## 🔨 Pre-commit Hooks
|
|
||||||
|
|
||||||
We use pre-commit hooks to ensure code quality and consistency. This runs automatically before each commit.
|
|
||||||
|
|
||||||
### Setup Pre-commit
|
|
||||||
|
|
||||||
1. **Install pre-commit** (already included when you run `uv sync`):
|
|
||||||
```bash
|
|
||||||
uv pip install pre-commit
|
|
||||||
```
|
|
||||||
|
|
||||||
2. **Install the git hooks**:
|
|
||||||
```bash
|
|
||||||
pre-commit install
|
|
||||||
```
|
|
||||||
|
|
||||||
3. **Run pre-commit manually** (optional):
|
|
||||||
```bash
|
|
||||||
pre-commit run --all-files
|
|
||||||
```
|
|
||||||
|
|
||||||
### Pre-commit Checks
|
|
||||||
|
|
||||||
Our pre-commit configuration includes:
|
|
||||||
- **Trailing whitespace removal**
|
|
||||||
- **End-of-file fixing**
|
|
||||||
- **YAML validation**
|
|
||||||
- **Large file prevention**
|
|
||||||
- **Merge conflict detection**
|
|
||||||
- **Debug statement detection**
|
|
||||||
- **Code formatting with ruff**
|
|
||||||
- **Code linting with ruff**
|
|
||||||
|
|
||||||
## 🧪 Testing
|
|
||||||
|
|
||||||
### Running Tests
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Run all tests
|
|
||||||
uv run pytest
|
|
||||||
|
|
||||||
# Run specific test file
|
|
||||||
uv run pytest test/test_filename.py
|
|
||||||
|
|
||||||
# Run with coverage
|
|
||||||
uv run pytest --cov=leann
|
|
||||||
```
|
|
||||||
|
|
||||||
### Writing Tests
|
|
||||||
|
|
||||||
- Place tests in the `test/` directory
|
|
||||||
- Follow the naming convention `test_*.py`
|
|
||||||
- Use descriptive test names that explain what's being tested
|
|
||||||
- Include both positive and negative test cases
|
|
||||||
|
|
||||||
## 📝 Code Style
|
|
||||||
|
|
||||||
We use `ruff` for both linting and formatting to ensure consistent code style.
|
|
||||||
|
|
||||||
### Format Your Code
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Format all files
|
|
||||||
ruff format
|
|
||||||
|
|
||||||
# Check formatting without changing files
|
|
||||||
ruff format --check
|
|
||||||
```
|
|
||||||
|
|
||||||
### Lint Your Code
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Run linter with auto-fix
|
|
||||||
ruff check --fix
|
|
||||||
|
|
||||||
# Just check without fixing
|
|
||||||
ruff check
|
|
||||||
```
|
|
||||||
|
|
||||||
### Style Guidelines
|
|
||||||
|
|
||||||
- Follow PEP 8 conventions
|
|
||||||
- Use descriptive variable names
|
|
||||||
- Add type hints where appropriate
|
|
||||||
- Write docstrings for all public functions and classes
|
|
||||||
- Keep functions focused and single-purpose
|
|
||||||
|
|
||||||
## 🚦 CI/CD
|
|
||||||
|
|
||||||
Our CI pipeline runs automatically on all pull requests. It includes:
|
|
||||||
|
|
||||||
1. **Linting and Formatting**: Ensures code follows our style guidelines
|
|
||||||
2. **Multi-platform builds**: Tests on Ubuntu and macOS
|
|
||||||
3. **Python version matrix**: Tests on Python 3.9-3.13
|
|
||||||
4. **Wheel building**: Ensures packages can be built and distributed
|
|
||||||
|
|
||||||
### CI Commands
|
|
||||||
|
|
||||||
The CI uses the same commands as pre-commit to ensure consistency:
|
|
||||||
```bash
|
|
||||||
# Linting
|
|
||||||
ruff check .
|
|
||||||
|
|
||||||
# Format checking
|
|
||||||
ruff format --check .
|
|
||||||
```
|
|
||||||
|
|
||||||
Make sure your code passes these checks locally before pushing!
|
|
||||||
|
|
||||||
## 🔄 Pull Request Process
|
|
||||||
|
|
||||||
1. **Fork the repository** and create your branch from `main`:
|
|
||||||
```bash
|
|
||||||
git checkout -b feature/your-feature-name
|
|
||||||
```
|
|
||||||
|
|
||||||
2. **Make your changes**:
|
|
||||||
- Write clean, documented code
|
|
||||||
- Add tests for new functionality
|
|
||||||
- Update documentation as needed
|
|
||||||
|
|
||||||
3. **Run pre-commit checks**:
|
|
||||||
```bash
|
|
||||||
pre-commit run --all-files
|
|
||||||
```
|
|
||||||
|
|
||||||
4. **Test your changes**:
|
|
||||||
```bash
|
|
||||||
uv run pytest
|
|
||||||
```
|
|
||||||
|
|
||||||
5. **Commit with descriptive messages**:
|
|
||||||
```bash
|
|
||||||
git commit -m "feat: add new search algorithm"
|
|
||||||
```
|
|
||||||
|
|
||||||
Follow [Conventional Commits](https://www.conventionalcommits.org/):
|
|
||||||
- `feat:` for new features
|
|
||||||
- `fix:` for bug fixes
|
|
||||||
- `docs:` for documentation changes
|
|
||||||
- `test:` for test additions/changes
|
|
||||||
- `refactor:` for code refactoring
|
|
||||||
- `perf:` for performance improvements
|
|
||||||
|
|
||||||
6. **Push and create a pull request**:
|
|
||||||
- Provide a clear description of your changes
|
|
||||||
- Reference any related issues
|
|
||||||
- Include examples or screenshots if applicable
|
|
||||||
|
|
||||||
## 📚 Documentation
|
|
||||||
|
|
||||||
When adding new features or making significant changes:
|
|
||||||
|
|
||||||
1. Update relevant documentation in `/docs`
|
|
||||||
2. Add docstrings to new functions/classes
|
|
||||||
3. Update README.md if needed
|
|
||||||
4. Include usage examples
|
|
||||||
|
|
||||||
## 🤔 Getting Help
|
|
||||||
|
|
||||||
- **Discord**: Join our community for discussions
|
|
||||||
- **Issues**: Check existing issues or create a new one
|
|
||||||
- **Discussions**: For general questions and ideas
|
|
||||||
|
|
||||||
## 📄 License
|
|
||||||
|
|
||||||
By contributing, you agree that your contributions will be licensed under the same license as the project (MIT).
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
Thank you for contributing to LEANN! Every contribution, no matter how small, helps make the project better for everyone. 🌟
|
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
# Release Guide
|
|
||||||
|
|
||||||
## Setup (One-time)
|
|
||||||
|
|
||||||
Add `PYPI_API_TOKEN` to GitHub Secrets:
|
|
||||||
1. Get token: https://pypi.org/manage/account/token/
|
|
||||||
2. Add to secrets: Settings → Secrets → Actions → `PYPI_API_TOKEN`
|
|
||||||
|
|
||||||
## Release (One-click)
|
|
||||||
|
|
||||||
1. Go to: https://github.com/yichuan-w/LEANN/actions/workflows/release-manual.yml
|
|
||||||
2. Click "Run workflow"
|
|
||||||
3. Enter version: `0.1.2`
|
|
||||||
4. Click green "Run workflow" button
|
|
||||||
|
|
||||||
That's it! The workflow will automatically:
|
|
||||||
- ✅ Update version in all packages
|
|
||||||
- ✅ Build all packages
|
|
||||||
- ✅ Publish to PyPI
|
|
||||||
- ✅ Create GitHub tag and release
|
|
||||||
|
|
||||||
Check progress: https://github.com/yichuan-w/LEANN/actions
|
|
||||||
@@ -1,123 +0,0 @@
|
|||||||
# Thinking Budget Feature Implementation
|
|
||||||
|
|
||||||
## Overview
|
|
||||||
|
|
||||||
This document describes the implementation of the **thinking budget** feature for LEANN, which allows users to control the computational effort for reasoning models like GPT-Oss:20b.
|
|
||||||
|
|
||||||
## Feature Description
|
|
||||||
|
|
||||||
The thinking budget feature provides three levels of computational effort for reasoning models:
|
|
||||||
- **`low`**: Fast responses, basic reasoning (default for simple queries)
|
|
||||||
- **`medium`**: Balanced speed and reasoning depth
|
|
||||||
- **`high`**: Maximum reasoning effort, best for complex analytical questions
|
|
||||||
|
|
||||||
## Implementation Details
|
|
||||||
|
|
||||||
### 1. Command Line Interface
|
|
||||||
|
|
||||||
Added `--thinking-budget` parameter to both CLI and RAG examples:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# LEANN CLI
|
|
||||||
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget high
|
|
||||||
|
|
||||||
# RAG Examples
|
|
||||||
python apps/email_rag.py --llm ollama --llm-model gpt-oss:20b --thinking-budget high
|
|
||||||
python apps/document_rag.py --llm openai --llm-model o3 --thinking-budget medium
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. LLM Backend Support
|
|
||||||
|
|
||||||
#### Ollama Backend (`packages/leann-core/src/leann/chat.py`)
|
|
||||||
|
|
||||||
```python
|
|
||||||
def ask(self, prompt: str, **kwargs) -> str:
|
|
||||||
# Handle thinking budget for reasoning models
|
|
||||||
options = kwargs.copy()
|
|
||||||
thinking_budget = kwargs.get("thinking_budget")
|
|
||||||
if thinking_budget:
|
|
||||||
options.pop("thinking_budget", None)
|
|
||||||
if thinking_budget in ["low", "medium", "high"]:
|
|
||||||
options["reasoning"] = {"effort": thinking_budget, "exclude": False}
|
|
||||||
```
|
|
||||||
|
|
||||||
**API Format**: Uses Ollama's `reasoning` parameter with `effort` and `exclude` fields.
|
|
||||||
|
|
||||||
#### OpenAI Backend (`packages/leann-core/src/leann/chat.py`)
|
|
||||||
|
|
||||||
```python
|
|
||||||
def ask(self, prompt: str, **kwargs) -> str:
|
|
||||||
# Handle thinking budget for reasoning models
|
|
||||||
thinking_budget = kwargs.get("thinking_budget")
|
|
||||||
if thinking_budget and thinking_budget in ["low", "medium", "high"]:
|
|
||||||
# Check if this is an o-series model
|
|
||||||
o_series_models = ["o3", "o3-mini", "o4-mini", "o1", "o3-pro", "o3-deep-research"]
|
|
||||||
if any(model in self.model for model in o_series_models):
|
|
||||||
params["reasoning_effort"] = thinking_budget
|
|
||||||
```
|
|
||||||
|
|
||||||
**API Format**: Uses OpenAI's `reasoning_effort` parameter for o-series models.
|
|
||||||
|
|
||||||
### 3. Parameter Propagation
|
|
||||||
|
|
||||||
The thinking budget parameter is properly propagated through the LEANN architecture:
|
|
||||||
|
|
||||||
1. **CLI** (`packages/leann-core/src/leann/cli.py`): Captures `--thinking-budget` argument
|
|
||||||
2. **Base RAG** (`apps/base_rag_example.py`): Adds parameter to argument parser
|
|
||||||
3. **LeannChat** (`packages/leann-core/src/leann/api.py`): Passes `llm_kwargs` to LLM
|
|
||||||
4. **LLM Interface**: Handles the parameter in backend-specific implementations
|
|
||||||
|
|
||||||
## Files Modified
|
|
||||||
|
|
||||||
### Core Implementation
|
|
||||||
- `packages/leann-core/src/leann/chat.py`: Added thinking budget support to OllamaChat and OpenAIChat
|
|
||||||
- `packages/leann-core/src/leann/cli.py`: Added `--thinking-budget` argument
|
|
||||||
- `apps/base_rag_example.py`: Added thinking budget parameter to RAG examples
|
|
||||||
|
|
||||||
### Documentation
|
|
||||||
- `README.md`: Added thinking budget parameter to usage examples
|
|
||||||
- `docs/configuration-guide.md`: Added detailed documentation and usage guidelines
|
|
||||||
|
|
||||||
### Examples
|
|
||||||
- `examples/thinking_budget_demo.py`: Comprehensive demo script with usage examples
|
|
||||||
|
|
||||||
## Usage Examples
|
|
||||||
|
|
||||||
### Basic Usage
|
|
||||||
```bash
|
|
||||||
# High reasoning effort for complex questions
|
|
||||||
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget high
|
|
||||||
|
|
||||||
# Medium reasoning for balanced performance
|
|
||||||
leann ask my-index --llm openai --model gpt-4o --thinking-budget medium
|
|
||||||
|
|
||||||
# Low reasoning for fast responses
|
|
||||||
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget low
|
|
||||||
```
|
|
||||||
|
|
||||||
### RAG Examples
|
|
||||||
```bash
|
|
||||||
# Email RAG with high reasoning
|
|
||||||
python apps/email_rag.py --llm ollama --llm-model gpt-oss:20b --thinking-budget high
|
|
||||||
|
|
||||||
# Document RAG with medium reasoning
|
|
||||||
python apps/document_rag.py --llm openai --llm-model gpt-4o --thinking-budget medium
|
|
||||||
```
|
|
||||||
|
|
||||||
## Supported Models
|
|
||||||
|
|
||||||
### Ollama Models
|
|
||||||
- **GPT-Oss:20b**: Primary target model with reasoning capabilities
|
|
||||||
- **Other reasoning models**: Any Ollama model that supports the `reasoning` parameter
|
|
||||||
|
|
||||||
### OpenAI Models
|
|
||||||
- **o3, o3-mini, o4-mini, o1**: o-series reasoning models with `reasoning_effort` parameter
|
|
||||||
- **GPT-OSS models**: Models that support reasoning capabilities
|
|
||||||
|
|
||||||
## Testing
|
|
||||||
|
|
||||||
The implementation includes comprehensive testing:
|
|
||||||
- Parameter handling verification
|
|
||||||
- Backend-specific API format validation
|
|
||||||
- CLI argument parsing tests
|
|
||||||
- Integration with existing LEANN architecture
|
|
||||||
@@ -1,98 +0,0 @@
|
|||||||
"""
|
|
||||||
Comparison between Sentence Transformers and OpenAI embeddings
|
|
||||||
|
|
||||||
This example shows how different embedding models handle complex queries
|
|
||||||
and demonstrates the differences between local and API-based embeddings.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from leann.embedding_compute import compute_embeddings
|
|
||||||
|
|
||||||
# OpenAI API key should be set as environment variable
|
|
||||||
# export OPENAI_API_KEY="your-api-key-here"
|
|
||||||
|
|
||||||
# Test data
|
|
||||||
conference_text = "[Title]: COLING 2025 Conference\n[URL]: https://coling2025.org/"
|
|
||||||
browser_text = "[Title]: Browser Use Tool\n[URL]: https://github.com/browser-use"
|
|
||||||
|
|
||||||
# Two queries with same intent but different wording
|
|
||||||
query1 = "Tell me my browser history about some conference i often visit"
|
|
||||||
query2 = "browser history about conference I often visit"
|
|
||||||
|
|
||||||
texts = [query1, query2, conference_text, browser_text]
|
|
||||||
|
|
||||||
|
|
||||||
def cosine_similarity(a, b):
|
|
||||||
return np.dot(a, b) # Already normalized
|
|
||||||
|
|
||||||
|
|
||||||
def analyze_embeddings(embeddings, model_name):
|
|
||||||
print(f"\n=== {model_name} Results ===")
|
|
||||||
|
|
||||||
# Results for Query 1
|
|
||||||
sim1_conf = cosine_similarity(embeddings[0], embeddings[2])
|
|
||||||
sim1_browser = cosine_similarity(embeddings[0], embeddings[3])
|
|
||||||
|
|
||||||
print(f"Query 1: '{query1}'")
|
|
||||||
print(f" → Conference similarity: {sim1_conf:.4f} {'✓' if sim1_conf > sim1_browser else ''}")
|
|
||||||
print(
|
|
||||||
f" → Browser similarity: {sim1_browser:.4f} {'✓' if sim1_browser > sim1_conf else ''}"
|
|
||||||
)
|
|
||||||
print(f" Winner: {'Conference' if sim1_conf > sim1_browser else 'Browser'}")
|
|
||||||
|
|
||||||
# Results for Query 2
|
|
||||||
sim2_conf = cosine_similarity(embeddings[1], embeddings[2])
|
|
||||||
sim2_browser = cosine_similarity(embeddings[1], embeddings[3])
|
|
||||||
|
|
||||||
print(f"\nQuery 2: '{query2}'")
|
|
||||||
print(f" → Conference similarity: {sim2_conf:.4f} {'✓' if sim2_conf > sim2_browser else ''}")
|
|
||||||
print(
|
|
||||||
f" → Browser similarity: {sim2_browser:.4f} {'✓' if sim2_browser > sim2_conf else ''}"
|
|
||||||
)
|
|
||||||
print(f" Winner: {'Conference' if sim2_conf > sim2_browser else 'Browser'}")
|
|
||||||
|
|
||||||
# Show the impact
|
|
||||||
print("\n=== Impact Analysis ===")
|
|
||||||
print(f"Conference similarity change: {sim2_conf - sim1_conf:+.4f}")
|
|
||||||
print(f"Browser similarity change: {sim2_browser - sim1_browser:+.4f}")
|
|
||||||
|
|
||||||
if sim1_conf > sim1_browser and sim2_browser > sim2_conf:
|
|
||||||
print("❌ FLIP: Adding 'browser history' flips winner from Conference to Browser!")
|
|
||||||
elif sim1_conf > sim1_browser and sim2_conf > sim2_browser:
|
|
||||||
print("✅ STABLE: Conference remains winner in both queries")
|
|
||||||
elif sim1_browser > sim1_conf and sim2_browser > sim2_conf:
|
|
||||||
print("✅ STABLE: Browser remains winner in both queries")
|
|
||||||
else:
|
|
||||||
print("🔄 MIXED: Results vary between queries")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"query1_conf": sim1_conf,
|
|
||||||
"query1_browser": sim1_browser,
|
|
||||||
"query2_conf": sim2_conf,
|
|
||||||
"query2_browser": sim2_browser,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# Test Sentence Transformers
|
|
||||||
print("Testing Sentence Transformers (facebook/contriever)...")
|
|
||||||
try:
|
|
||||||
st_embeddings = compute_embeddings(texts, "facebook/contriever", mode="sentence-transformers")
|
|
||||||
st_results = analyze_embeddings(st_embeddings, "Sentence Transformers (facebook/contriever)")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ Sentence Transformers failed: {e}")
|
|
||||||
st_results = None
|
|
||||||
|
|
||||||
# Test OpenAI
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("Testing OpenAI (text-embedding-3-small)...")
|
|
||||||
try:
|
|
||||||
openai_embeddings = compute_embeddings(texts, "text-embedding-3-small", mode="openai")
|
|
||||||
openai_results = analyze_embeddings(openai_embeddings, "OpenAI (text-embedding-3-small)")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ OpenAI failed: {e}")
|
|
||||||
openai_results = None
|
|
||||||
|
|
||||||
# Compare results
|
|
||||||
if st_results and openai_results:
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("=== COMPARISON SUMMARY ===")
|
|
||||||
@@ -1,384 +0,0 @@
|
|||||||
# LEANN Configuration Guide
|
|
||||||
|
|
||||||
This guide helps you optimize LEANN for different use cases and understand the trade-offs between various configuration options.
|
|
||||||
|
|
||||||
## Getting Started: Simple is Better
|
|
||||||
|
|
||||||
When first trying LEANN, start with a small dataset to quickly validate your approach:
|
|
||||||
|
|
||||||
**For document RAG**: The default `data/` directory works perfectly - includes 2 AI research papers, Pride and Prejudice literature, and a technical report
|
|
||||||
```bash
|
|
||||||
python -m apps.document_rag --query "What techniques does LEANN use?"
|
|
||||||
```
|
|
||||||
|
|
||||||
**For other data sources**: Limit the dataset size for quick testing
|
|
||||||
```bash
|
|
||||||
# WeChat: Test with recent messages only
|
|
||||||
python -m apps.wechat_rag --max-items 100 --query "What did we discuss about the project timeline?"
|
|
||||||
|
|
||||||
# Browser history: Last few days
|
|
||||||
python -m apps.browser_rag --max-items 500 --query "Find documentation about vector databases"
|
|
||||||
|
|
||||||
# Email: Recent inbox
|
|
||||||
python -m apps.email_rag --max-items 200 --query "Who sent updates about the deployment status?"
|
|
||||||
```
|
|
||||||
|
|
||||||
Once validated, scale up gradually:
|
|
||||||
- 100 documents → 1,000 → 10,000 → full dataset (`--max-items -1`)
|
|
||||||
- This helps identify issues early before committing to long processing times
|
|
||||||
|
|
||||||
## Embedding Model Selection: Understanding the Trade-offs
|
|
||||||
|
|
||||||
Based on our experience developing LEANN, embedding models fall into three categories:
|
|
||||||
|
|
||||||
### Small Models (< 100M parameters)
|
|
||||||
**Example**: `sentence-transformers/all-MiniLM-L6-v2` (22M params)
|
|
||||||
- **Pros**: Lightweight, fast for both indexing and inference
|
|
||||||
- **Cons**: Lower semantic understanding, may miss nuanced relationships
|
|
||||||
- **Use when**: Speed is critical, handling simple queries, interactive mode, or just experimenting with LEANN. If time is not a constraint, consider using a larger/better embedding model
|
|
||||||
|
|
||||||
### Medium Models (100M-500M parameters)
|
|
||||||
**Example**: `facebook/contriever` (110M params), `BAAI/bge-base-en-v1.5` (110M params)
|
|
||||||
- **Pros**: Balanced performance, good multilingual support, reasonable speed
|
|
||||||
- **Cons**: Requires more compute than small models
|
|
||||||
- **Use when**: Need quality results without extreme compute requirements, general-purpose RAG applications
|
|
||||||
|
|
||||||
### Large Models (500M+ parameters)
|
|
||||||
**Example**: `Qwen/Qwen3-Embedding-0.6B` (600M params), `intfloat/multilingual-e5-large` (560M params)
|
|
||||||
- **Pros**: Best semantic understanding, captures complex relationships, excellent multilingual support. **Qwen3-Embedding-0.6B achieves nearly OpenAI API performance!**
|
|
||||||
- **Cons**: Slower inference, longer index build times
|
|
||||||
- **Use when**: Quality is paramount and you have sufficient compute resources. **Highly recommended** for production use
|
|
||||||
|
|
||||||
### Quick Start: Cloud and Local Embedding Options
|
|
||||||
|
|
||||||
**OpenAI Embeddings (Fastest Setup)**
|
|
||||||
For immediate testing without local model downloads(also if you [do not have GPU](https://github.com/yichuan-w/LEANN/issues/43) and do not care that much about your document leak, you should use this, we compute the embedding and recompute using openai API):
|
|
||||||
```bash
|
|
||||||
# Set OpenAI embeddings (requires OPENAI_API_KEY)
|
|
||||||
--embedding-mode openai --embedding-model text-embedding-3-small
|
|
||||||
```
|
|
||||||
|
|
||||||
**Ollama Embeddings (Privacy-Focused)**
|
|
||||||
For local embeddings with complete privacy:
|
|
||||||
```bash
|
|
||||||
# First, pull an embedding model
|
|
||||||
ollama pull nomic-embed-text
|
|
||||||
|
|
||||||
# Use Ollama embeddings
|
|
||||||
--embedding-mode ollama --embedding-model nomic-embed-text
|
|
||||||
```
|
|
||||||
|
|
||||||
<details>
|
|
||||||
<summary><strong>Cloud vs Local Trade-offs</strong></summary>
|
|
||||||
|
|
||||||
**OpenAI Embeddings** (`text-embedding-3-small/large`)
|
|
||||||
- **Pros**: No local compute needed, consistently fast, high quality
|
|
||||||
- **Cons**: Requires API key, costs money, data leaves your system, [known limitations with certain languages](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
|
|
||||||
- **When to use**: Prototyping, non-sensitive data, need immediate results
|
|
||||||
|
|
||||||
**Local Embeddings**
|
|
||||||
- **Pros**: Complete privacy, no ongoing costs, full control, can sometimes outperform OpenAI embeddings
|
|
||||||
- **Cons**: Slower than cloud APIs, requires local compute resources
|
|
||||||
- **When to use**: Production systems, sensitive data, cost-sensitive applications
|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
## Index Selection: Matching Your Scale
|
|
||||||
|
|
||||||
### HNSW (Hierarchical Navigable Small World)
|
|
||||||
**Best for**: Small to medium datasets (< 10M vectors) - **Default and recommended for extreme low storage**
|
|
||||||
- Full recomputation required
|
|
||||||
- High memory usage during build phase
|
|
||||||
- Excellent recall (95%+)
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Optimal for most use cases
|
|
||||||
--backend-name hnsw --graph-degree 32 --build-complexity 64
|
|
||||||
```
|
|
||||||
|
|
||||||
### DiskANN
|
|
||||||
**Best for**: Large datasets, especially when you want `recompute=True`.
|
|
||||||
|
|
||||||
**Key advantages:**
|
|
||||||
- **Faster search** on large datasets (3x+ speedup vs HNSW in many cases)
|
|
||||||
- **Smart storage**: `recompute=True` enables automatic graph partitioning for smaller indexes
|
|
||||||
- **Better scaling**: Designed for 100k+ documents
|
|
||||||
|
|
||||||
**Recompute behavior:**
|
|
||||||
- `recompute=True` (recommended): Pure PQ traversal + final reranking - faster and enables partitioning
|
|
||||||
- `recompute=False`: PQ + partial real distances during traversal - slower but higher accuracy
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Recommended for most use cases
|
|
||||||
--backend-name diskann --graph-degree 32 --build-complexity 64
|
|
||||||
```
|
|
||||||
|
|
||||||
**Performance Benchmark**: Run `uv run benchmarks/diskann_vs_hnsw_speed_comparison.py` to compare DiskANN and HNSW on your system.
|
|
||||||
|
|
||||||
## LLM Selection: Engine and Model Comparison
|
|
||||||
|
|
||||||
### LLM Engines
|
|
||||||
|
|
||||||
**OpenAI** (`--llm openai`)
|
|
||||||
- **Pros**: Best quality, consistent performance, no local resources needed
|
|
||||||
- **Cons**: Costs money ($0.15-2.5 per million tokens), requires internet, data privacy concerns
|
|
||||||
- **Models**: `gpt-4o-mini` (fast, cheap), `gpt-4o` (best quality), `o3` (reasoning), `o3-mini` (reasoning, cheaper)
|
|
||||||
- **Thinking Budget**: Use `--thinking-budget low/medium/high` for o-series reasoning models (o3, o3-mini, o4-mini)
|
|
||||||
- **Note**: Our current default, but we recommend switching to Ollama for most use cases
|
|
||||||
|
|
||||||
**Ollama** (`--llm ollama`)
|
|
||||||
- **Pros**: Fully local, free, privacy-preserving, good model variety
|
|
||||||
- **Cons**: Requires local GPU/CPU resources, slower than cloud APIs, need to install extra [ollama app](https://github.com/ollama/ollama?tab=readme-ov-file#ollama) and pre-download models by `ollama pull`
|
|
||||||
- **Models**: `qwen3:0.6b` (ultra-fast), `qwen3:1.7b` (balanced), `qwen3:4b` (good quality), `qwen3:7b` (high quality), `deepseek-r1:1.5b` (reasoning)
|
|
||||||
- **Thinking Budget**: Use `--thinking-budget low/medium/high` for reasoning models like GPT-Oss:20b
|
|
||||||
|
|
||||||
**HuggingFace** (`--llm hf`)
|
|
||||||
- **Pros**: Free tier available, huge model selection, direct model loading (vs Ollama's server-based approach)
|
|
||||||
- **Cons**: More complex initial setup
|
|
||||||
- **Models**: `Qwen/Qwen3-1.7B-FP8`
|
|
||||||
|
|
||||||
## Parameter Tuning Guide
|
|
||||||
|
|
||||||
### Search Complexity Parameters
|
|
||||||
|
|
||||||
**`--build-complexity`** (index building)
|
|
||||||
- Controls thoroughness during index construction
|
|
||||||
- Higher = better recall but slower build
|
|
||||||
- Recommendations:
|
|
||||||
- 32: Quick prototyping
|
|
||||||
- 64: Balanced (default)
|
|
||||||
- 128: Production systems
|
|
||||||
- 256: Maximum quality
|
|
||||||
|
|
||||||
**`--search-complexity`** (query time)
|
|
||||||
- Controls search thoroughness
|
|
||||||
- Higher = better results but slower
|
|
||||||
- Recommendations:
|
|
||||||
- 16: Fast/Interactive search
|
|
||||||
- 32: High quality with diversity
|
|
||||||
- 64+: Maximum accuracy
|
|
||||||
|
|
||||||
### Top-K Selection
|
|
||||||
|
|
||||||
**`--top-k`** (number of retrieved chunks)
|
|
||||||
- More chunks = better context but slower LLM processing
|
|
||||||
- Should be always smaller than `--search-complexity`
|
|
||||||
- Guidelines:
|
|
||||||
- 10-20: General questions (default: 20)
|
|
||||||
- 30+: Complex multi-hop reasoning requiring comprehensive context
|
|
||||||
|
|
||||||
**Trade-off formula**:
|
|
||||||
- Retrieval time ∝ log(n) × search_complexity
|
|
||||||
- LLM processing time ∝ top_k × chunk_size
|
|
||||||
- Total context = top_k × chunk_size tokens
|
|
||||||
|
|
||||||
### Thinking Budget for Reasoning Models
|
|
||||||
|
|
||||||
**`--thinking-budget`** (reasoning effort level)
|
|
||||||
- Controls the computational effort for reasoning models
|
|
||||||
- Options: `low`, `medium`, `high`
|
|
||||||
- Guidelines:
|
|
||||||
- `low`: Fast responses, basic reasoning (default for simple queries)
|
|
||||||
- `medium`: Balanced speed and reasoning depth
|
|
||||||
- `high`: Maximum reasoning effort, best for complex analytical questions
|
|
||||||
- **Supported Models**:
|
|
||||||
- **Ollama**: `gpt-oss:20b`, `gpt-oss:120b`
|
|
||||||
- **OpenAI**: `o3`, `o3-mini`, `o4-mini`, `o1` (o-series reasoning models)
|
|
||||||
- **Note**: Models without reasoning support will show a warning and proceed without reasoning parameters
|
|
||||||
- **Example**: `--thinking-budget high` for complex analytical questions
|
|
||||||
|
|
||||||
**📖 For detailed usage examples and implementation details, check out [Thinking Budget Documentation](THINKING_BUDGET_FEATURE.md)**
|
|
||||||
|
|
||||||
**💡 Quick Examples:**
|
|
||||||
```bash
|
|
||||||
# OpenAI o-series reasoning model
|
|
||||||
python apps/document_rag.py --query "What are the main techniques LEANN explores?" \
|
|
||||||
--index-dir hnswbuild --backend hnsw \
|
|
||||||
--llm openai --llm-model o3 --thinking-budget medium
|
|
||||||
|
|
||||||
# Ollama reasoning model
|
|
||||||
python apps/document_rag.py --query "What are the main techniques LEANN explores?" \
|
|
||||||
--index-dir hnswbuild --backend hnsw \
|
|
||||||
--llm ollama --llm-model gpt-oss:20b --thinking-budget high
|
|
||||||
```
|
|
||||||
|
|
||||||
### Graph Degree (HNSW/DiskANN)
|
|
||||||
|
|
||||||
**`--graph-degree`**
|
|
||||||
- Number of connections per node in the graph
|
|
||||||
- Higher = better recall but more memory
|
|
||||||
- HNSW: 16-32 (default: 32)
|
|
||||||
- DiskANN: 32-128 (default: 64)
|
|
||||||
|
|
||||||
|
|
||||||
## Performance Optimization Checklist
|
|
||||||
|
|
||||||
### If Embedding is Too Slow
|
|
||||||
|
|
||||||
1. **Switch to smaller model**:
|
|
||||||
```bash
|
|
||||||
# From large model
|
|
||||||
--embedding-model Qwen/Qwen3-Embedding-0.6B
|
|
||||||
# To small model
|
|
||||||
--embedding-model sentence-transformers/all-MiniLM-L6-v2
|
|
||||||
```
|
|
||||||
|
|
||||||
2. **Limit dataset size for testing**:
|
|
||||||
```bash
|
|
||||||
--max-items 1000 # Process first 1k items only
|
|
||||||
```
|
|
||||||
|
|
||||||
3. **Use MLX on Apple Silicon** (optional optimization):
|
|
||||||
```bash
|
|
||||||
--embedding-mode mlx --embedding-model mlx-community/Qwen3-Embedding-0.6B-8bit
|
|
||||||
```
|
|
||||||
MLX might not be the best choice, as we tested and found that it only offers 1.3x acceleration compared to HF, so maybe using ollama is a better choice for embedding generation
|
|
||||||
|
|
||||||
4. **Use Ollama**
|
|
||||||
```bash
|
|
||||||
--embedding-mode ollama --embedding-model nomic-embed-text
|
|
||||||
```
|
|
||||||
To discover additional embedding models in ollama, check out https://ollama.com/search?c=embedding or read more about embedding models at https://ollama.com/blog/embedding-models, please do check the model size that works best for you
|
|
||||||
### If Search Quality is Poor
|
|
||||||
|
|
||||||
1. **Increase retrieval count**:
|
|
||||||
```bash
|
|
||||||
--top-k 30 # Retrieve more candidates
|
|
||||||
```
|
|
||||||
|
|
||||||
2. **Upgrade embedding model**:
|
|
||||||
```bash
|
|
||||||
# For English
|
|
||||||
--embedding-model BAAI/bge-base-en-v1.5
|
|
||||||
# For multilingual
|
|
||||||
--embedding-model intfloat/multilingual-e5-large
|
|
||||||
```
|
|
||||||
|
|
||||||
## Understanding the Trade-offs
|
|
||||||
|
|
||||||
Every configuration choice involves trade-offs:
|
|
||||||
|
|
||||||
| Factor | Small/Fast | Large/Quality |
|
|
||||||
|--------|------------|---------------|
|
|
||||||
| Embedding Model | `all-MiniLM-L6-v2` | `Qwen/Qwen3-Embedding-0.6B` |
|
|
||||||
| Chunk Size | 512 tokens | 128 tokens |
|
|
||||||
| Index Type | HNSW | DiskANN |
|
|
||||||
| LLM | `qwen3:1.7b` | `gpt-4o` |
|
|
||||||
|
|
||||||
The key is finding the right balance for your specific use case. Start small and simple, measure performance, then scale up only where needed.
|
|
||||||
|
|
||||||
## Low-resource setups
|
|
||||||
|
|
||||||
If you don’t have a local GPU or builds/searches are too slow, use one or more of the options below.
|
|
||||||
|
|
||||||
### 1) Use OpenAI embeddings (no local compute)
|
|
||||||
|
|
||||||
Fastest path with zero local GPU requirements. Set your API key and use OpenAI embeddings during build and search:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
export OPENAI_API_KEY=sk-...
|
|
||||||
|
|
||||||
# Build with OpenAI embeddings
|
|
||||||
leann build my-index \
|
|
||||||
--embedding-mode openai \
|
|
||||||
--embedding-model text-embedding-3-small
|
|
||||||
|
|
||||||
# Search with OpenAI embeddings (recompute at query time)
|
|
||||||
leann search my-index "your query" \
|
|
||||||
--recompute
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2) Run remote builds with SkyPilot (cloud GPU)
|
|
||||||
|
|
||||||
Offload embedding generation and index building to a GPU VM using [SkyPilot](https://skypilot.readthedocs.io/en/latest/). A template is provided at `sky/leann-build.yaml`.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# One-time: install and configure SkyPilot
|
|
||||||
pip install skypilot
|
|
||||||
|
|
||||||
# Launch with defaults (L4:1) and mount ./data to ~/leann-data; the build runs automatically
|
|
||||||
sky launch -c leann-gpu sky/leann-build.yaml
|
|
||||||
|
|
||||||
# Override parameters via -e key=value (optional)
|
|
||||||
sky launch -c leann-gpu sky/leann-build.yaml \
|
|
||||||
-e index_name=my-index \
|
|
||||||
-e backend=hnsw \
|
|
||||||
-e embedding_mode=sentence-transformers \
|
|
||||||
-e embedding_model=Qwen/Qwen3-Embedding-0.6B
|
|
||||||
|
|
||||||
# Copy the built index back to your local .leann (use rsync)
|
|
||||||
rsync -Pavz leann-gpu:~/.leann/indexes/my-index ./.leann/indexes/
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3) Disable recomputation to trade storage for speed
|
|
||||||
|
|
||||||
If you need lower latency and have more storage/memory, disable recomputation. This stores full embeddings and avoids recomputing at search time.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Build without recomputation (HNSW requires non-compact in this mode)
|
|
||||||
leann build my-index --no-recompute --no-compact
|
|
||||||
|
|
||||||
# Search without recomputation
|
|
||||||
leann search my-index "your query" --no-recompute
|
|
||||||
```
|
|
||||||
|
|
||||||
When to use:
|
|
||||||
- Extreme low latency requirements (high QPS, interactive assistants)
|
|
||||||
- Read-heavy workloads where storage is cheaper than latency
|
|
||||||
- No always-available GPU
|
|
||||||
|
|
||||||
Constraints:
|
|
||||||
- HNSW: when `--no-recompute` is set, LEANN automatically disables compact mode during build
|
|
||||||
- DiskANN: supported; `--no-recompute` skips selective recompute during search
|
|
||||||
|
|
||||||
Storage impact:
|
|
||||||
- Storing N embeddings of dimension D with float32 requires approximately N × D × 4 bytes
|
|
||||||
- Example: 1,000,000 chunks × 768 dims × 4 bytes ≈ 2.86 GB (plus graph/metadata)
|
|
||||||
|
|
||||||
Converting an existing index (rebuild required):
|
|
||||||
```bash
|
|
||||||
# Rebuild in-place (ensure you still have original docs or can regenerate chunks)
|
|
||||||
leann build my-index --force --no-recompute --no-compact
|
|
||||||
```
|
|
||||||
|
|
||||||
Python API usage:
|
|
||||||
```python
|
|
||||||
from leann import LeannSearcher
|
|
||||||
|
|
||||||
searcher = LeannSearcher("/path/to/my-index.leann")
|
|
||||||
results = searcher.search("your query", top_k=10, recompute_embeddings=False)
|
|
||||||
```
|
|
||||||
|
|
||||||
Trade-offs:
|
|
||||||
- Lower latency and fewer network hops at query time
|
|
||||||
- Significantly higher storage (10–100× vs selective recomputation)
|
|
||||||
- Slightly larger memory footprint during build and search
|
|
||||||
|
|
||||||
Quick benchmark results (`benchmarks/benchmark_no_recompute.py` with 5k texts, complexity=32):
|
|
||||||
|
|
||||||
- HNSW
|
|
||||||
|
|
||||||
```text
|
|
||||||
recompute=True: search_time=0.818s, size=1.1MB
|
|
||||||
recompute=False: search_time=0.012s, size=16.6MB
|
|
||||||
```
|
|
||||||
|
|
||||||
- DiskANN
|
|
||||||
|
|
||||||
```text
|
|
||||||
recompute=True: search_time=0.041s, size=5.9MB
|
|
||||||
recompute=False: search_time=0.013s, size=24.6MB
|
|
||||||
```
|
|
||||||
|
|
||||||
Conclusion:
|
|
||||||
- **HNSW**: `no-recompute` is significantly faster (no embedding recomputation) but requires much more storage (stores all embeddings)
|
|
||||||
- **DiskANN**: `no-recompute` uses PQ + partial real distances during traversal (slower but higher accuracy), while `recompute=True` uses pure PQ traversal + final reranking (faster traversal, enables build-time partitioning for smaller storage)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## Further Reading
|
|
||||||
|
|
||||||
- [Lessons Learned Developing LEANN](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
|
|
||||||
- [LEANN Technical Paper](https://arxiv.org/abs/2506.08276)
|
|
||||||
- [DiskANN Original Paper](https://papers.nips.cc/paper/2019/file/09853c7fb1d3f8ee67a61b6bf4a7f8e6-Paper.pdf)
|
|
||||||
- [SSD-based Graph Partitioning](https://github.com/SonglinLife/SSD_BASED_PLAN)
|
|
||||||
10
docs/faq.md
10
docs/faq.md
@@ -1,10 +0,0 @@
|
|||||||
# FAQ
|
|
||||||
|
|
||||||
## 1. My building time seems long
|
|
||||||
|
|
||||||
You can speed up the process by using a lightweight embedding model. Add this to your arguments:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
--embedding-model sentence-transformers/all-MiniLM-L6-v2
|
|
||||||
```
|
|
||||||
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)
|
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
# ✨ Detailed Features
|
|
||||||
|
|
||||||
## 🔥 Core Features
|
|
||||||
|
|
||||||
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage with dynamic computation using optimized ZMQ servers and highly optimized search paradigm (overlapping and batching) with highly optimized embedding engine
|
|
||||||
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
|
|
||||||
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
|
|
||||||
- **🏗️ Pluggable Backends** - HNSW/FAISS (default), with optional DiskANN for large-scale deployments
|
|
||||||
|
|
||||||
## 🛠️ Technical Highlights
|
|
||||||
- **🔄 Recompute Mode** - Highest accuracy scenarios while eliminating vector storage overhead
|
|
||||||
- **⚡ Zero-copy Operations** - Minimize IPC overhead by transferring distances instead of embeddings
|
|
||||||
- **🚀 High-throughput Embedding Pipeline** - Optimized batched processing for maximum efficiency
|
|
||||||
- **🎯 Two-level Search** - Novel coarse-to-fine search overlap for accelerated query processing (optional)
|
|
||||||
- **💾 Memory-mapped Indices** - Fast startup with raw text mapping to reduce memory overhead
|
|
||||||
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](../examples/mlx_demo.py))
|
|
||||||
|
|
||||||
## 🎨 Developer Experience
|
|
||||||
|
|
||||||
- **Simple Python API** - Get started in minutes
|
|
||||||
- **Extensible backend system** - Easy to add new algorithms
|
|
||||||
- **Comprehensive examples** - From basic usage to production deployment
|
|
||||||
@@ -1,75 +0,0 @@
|
|||||||
# Normalized Embeddings Support in LEANN
|
|
||||||
|
|
||||||
LEANN now automatically detects normalized embedding models and sets the appropriate distance metric for optimal performance.
|
|
||||||
|
|
||||||
## What are Normalized Embeddings?
|
|
||||||
|
|
||||||
Normalized embeddings are vectors with L2 norm = 1 (unit vectors). These embeddings are optimized for cosine similarity rather than Maximum Inner Product Search (MIPS).
|
|
||||||
|
|
||||||
## Automatic Detection
|
|
||||||
|
|
||||||
When you create a `LeannBuilder` instance with a normalized embedding model, LEANN will:
|
|
||||||
|
|
||||||
1. **Automatically set `distance_metric="cosine"`** if not specified
|
|
||||||
2. **Show a warning** if you manually specify a different distance metric
|
|
||||||
3. **Provide optimal search performance** with the correct metric
|
|
||||||
|
|
||||||
## Supported Normalized Embedding Models
|
|
||||||
|
|
||||||
### OpenAI
|
|
||||||
All OpenAI text embedding models are normalized:
|
|
||||||
- `text-embedding-ada-002`
|
|
||||||
- `text-embedding-3-small`
|
|
||||||
- `text-embedding-3-large`
|
|
||||||
|
|
||||||
### Voyage AI
|
|
||||||
All Voyage AI embedding models are normalized:
|
|
||||||
- `voyage-2`
|
|
||||||
- `voyage-3`
|
|
||||||
- `voyage-large-2`
|
|
||||||
- `voyage-multilingual-2`
|
|
||||||
- `voyage-code-2`
|
|
||||||
|
|
||||||
### Cohere
|
|
||||||
All Cohere embedding models are normalized:
|
|
||||||
- `embed-english-v3.0`
|
|
||||||
- `embed-multilingual-v3.0`
|
|
||||||
- `embed-english-light-v3.0`
|
|
||||||
- `embed-multilingual-light-v3.0`
|
|
||||||
|
|
||||||
## Example Usage
|
|
||||||
|
|
||||||
```python
|
|
||||||
from leann.api import LeannBuilder
|
|
||||||
|
|
||||||
# Automatic detection - will use cosine distance
|
|
||||||
builder = LeannBuilder(
|
|
||||||
backend_name="hnsw",
|
|
||||||
embedding_model="text-embedding-3-small",
|
|
||||||
embedding_mode="openai"
|
|
||||||
)
|
|
||||||
# Warning: Detected normalized embeddings model 'text-embedding-3-small'...
|
|
||||||
# Automatically setting distance_metric='cosine'
|
|
||||||
|
|
||||||
# Manual override (not recommended)
|
|
||||||
builder = LeannBuilder(
|
|
||||||
backend_name="hnsw",
|
|
||||||
embedding_model="text-embedding-3-small",
|
|
||||||
embedding_mode="openai",
|
|
||||||
distance_metric="mips" # Will show warning
|
|
||||||
)
|
|
||||||
# Warning: Using 'mips' distance metric with normalized embeddings...
|
|
||||||
```
|
|
||||||
|
|
||||||
## Non-Normalized Embeddings
|
|
||||||
|
|
||||||
Models like `facebook/contriever` and other sentence-transformers models that are not normalized will continue to use MIPS by default, which is optimal for them.
|
|
||||||
|
|
||||||
## Why This Matters
|
|
||||||
|
|
||||||
Using the wrong distance metric with normalized embeddings can lead to:
|
|
||||||
- **Poor search quality** due to HNSW's early termination with narrow score ranges
|
|
||||||
- **Incorrect ranking** of search results
|
|
||||||
- **Suboptimal performance** compared to using the correct metric
|
|
||||||
|
|
||||||
For more details on why this happens, see our analysis in the [embedding detection code](../packages/leann-core/src/leann/api.py) which automatically handles normalized embeddings and MIPS distance metric issues.
|
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
# 📈 Roadmap
|
|
||||||
|
|
||||||
## 🎯 Q2 2025
|
|
||||||
|
|
||||||
- [X] HNSW backend integration
|
|
||||||
- [X] DiskANN backend with MIPS/L2/Cosine support
|
|
||||||
- [X] Real-time embedding pipeline
|
|
||||||
- [X] Memory-efficient graph pruning
|
|
||||||
|
|
||||||
## 🚀 Q3 2025
|
|
||||||
|
|
||||||
- [ ] Advanced caching strategies
|
|
||||||
- [ ] Add contextual-retrieval https://www.anthropic.com/news/contextual-retrieval
|
|
||||||
- [ ] Add sleep-time-compute and summarize agent! to summarilze the file on computer!
|
|
||||||
- [ ] Add OpenAI recompute API
|
|
||||||
|
|
||||||
## 🌟 Q4 2025
|
|
||||||
|
|
||||||
- [ ] Integration with LangChain/LlamaIndex
|
|
||||||
- [ ] Visual similarity search
|
|
||||||
- [ ] Query rewrtiting, rerank and expansion
|
|
||||||
@@ -1,88 +0,0 @@
|
|||||||
"""
|
|
||||||
Simple demo showing basic leann usage
|
|
||||||
Run: uv run python examples/basic_demo.py
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
from leann import LeannBuilder, LeannChat, LeannSearcher
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="Simple demo of Leann with selectable embedding models."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--embedding_model",
|
|
||||||
type=str,
|
|
||||||
default="sentence-transformers/all-mpnet-base-v2",
|
|
||||||
help="The embedding model to use, e.g., 'sentence-transformers/all-mpnet-base-v2' or 'text-embedding-ada-002'.",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
print(f"=== Leann Simple Demo with {args.embedding_model} ===")
|
|
||||||
print()
|
|
||||||
|
|
||||||
# Sample knowledge base
|
|
||||||
chunks = [
|
|
||||||
"Machine learning is a subset of artificial intelligence that enables computers to learn without being explicitly programmed.",
|
|
||||||
"Deep learning uses neural networks with multiple layers to process data and make decisions.",
|
|
||||||
"Natural language processing helps computers understand and generate human language.",
|
|
||||||
"Computer vision enables machines to interpret and understand visual information from images and videos.",
|
|
||||||
"Reinforcement learning teaches agents to make decisions by receiving rewards or penalties for their actions.",
|
|
||||||
"Data science combines statistics, programming, and domain expertise to extract insights from data.",
|
|
||||||
"Big data refers to extremely large datasets that require special tools and techniques to process.",
|
|
||||||
"Cloud computing provides on-demand access to computing resources over the internet.",
|
|
||||||
]
|
|
||||||
|
|
||||||
print("1. Building index (no embeddings stored)...")
|
|
||||||
builder = LeannBuilder(
|
|
||||||
embedding_model=args.embedding_model,
|
|
||||||
backend_name="hnsw",
|
|
||||||
)
|
|
||||||
for chunk in chunks:
|
|
||||||
builder.add_text(chunk)
|
|
||||||
builder.build_index("demo_knowledge.leann")
|
|
||||||
print()
|
|
||||||
|
|
||||||
print("2. Searching with real-time embeddings...")
|
|
||||||
searcher = LeannSearcher("demo_knowledge.leann")
|
|
||||||
|
|
||||||
queries = [
|
|
||||||
"What is machine learning?",
|
|
||||||
"How does neural network work?",
|
|
||||||
"Tell me about data processing",
|
|
||||||
]
|
|
||||||
|
|
||||||
for query in queries:
|
|
||||||
print(f"Query: {query}")
|
|
||||||
results = searcher.search(query, top_k=2)
|
|
||||||
|
|
||||||
for i, result in enumerate(results, 1):
|
|
||||||
print(f" {i}. Score: {result.score:.3f}")
|
|
||||||
print(f" Text: {result.text[:100]}...")
|
|
||||||
print()
|
|
||||||
|
|
||||||
print("3. Interactive chat demo:")
|
|
||||||
print(" (Note: Requires OpenAI API key for real responses)")
|
|
||||||
|
|
||||||
chat = LeannChat("demo_knowledge.leann")
|
|
||||||
|
|
||||||
# Demo questions
|
|
||||||
demo_questions: list[str] = [
|
|
||||||
"What is the difference between machine learning and deep learning?",
|
|
||||||
"How is data science related to big data?",
|
|
||||||
]
|
|
||||||
|
|
||||||
for question in demo_questions:
|
|
||||||
print(f" Q: {question}")
|
|
||||||
response = chat.ask(question)
|
|
||||||
print(f" A: {response}")
|
|
||||||
print()
|
|
||||||
|
|
||||||
print("Demo completed! Try running:")
|
|
||||||
print(" uv run python apps/document_rag.py")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -3,15 +3,14 @@
|
|||||||
Memory comparison between Faiss HNSW and LEANN HNSW backend
|
Memory comparison between Faiss HNSW and LEANN HNSW backend
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import gc
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import subprocess
|
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
|
import gc
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
# Setup logging
|
# Setup logging
|
||||||
@@ -62,7 +61,7 @@ def test_faiss_hnsw():
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
result = subprocess.run(
|
result = subprocess.run(
|
||||||
[sys.executable, "benchmarks/faiss_only.py"],
|
[sys.executable, "examples/faiss_only.py"],
|
||||||
capture_output=True,
|
capture_output=True,
|
||||||
text=True,
|
text=True,
|
||||||
timeout=300,
|
timeout=300,
|
||||||
@@ -84,7 +83,9 @@ def test_faiss_hnsw():
|
|||||||
|
|
||||||
for line in lines:
|
for line in lines:
|
||||||
if "Peak Memory:" in line:
|
if "Peak Memory:" in line:
|
||||||
peak_memory = float(line.split("Peak Memory:")[1].split("MB")[0].strip())
|
peak_memory = float(
|
||||||
|
line.split("Peak Memory:")[1].split("MB")[0].strip()
|
||||||
|
)
|
||||||
|
|
||||||
return {"peak_memory": peak_memory}
|
return {"peak_memory": peak_memory}
|
||||||
|
|
||||||
@@ -110,12 +111,13 @@ def test_leann_hnsw():
|
|||||||
|
|
||||||
tracker.checkpoint("After imports")
|
tracker.checkpoint("After imports")
|
||||||
|
|
||||||
from leann.api import LeannBuilder
|
|
||||||
from llama_index.core import SimpleDirectoryReader
|
from llama_index.core import SimpleDirectoryReader
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher
|
||||||
|
|
||||||
|
|
||||||
# Load and parse documents
|
# Load and parse documents
|
||||||
documents = SimpleDirectoryReader(
|
documents = SimpleDirectoryReader(
|
||||||
"data",
|
"examples/data",
|
||||||
recursive=True,
|
recursive=True,
|
||||||
encoding="utf-8",
|
encoding="utf-8",
|
||||||
required_exts=[".pdf", ".txt", ".md"],
|
required_exts=[".pdf", ".txt", ".md"],
|
||||||
@@ -133,7 +135,6 @@ def test_leann_hnsw():
|
|||||||
nodes = node_parser.get_nodes_from_documents([doc])
|
nodes = node_parser.get_nodes_from_documents([doc])
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
all_texts.append(node.get_content())
|
all_texts.append(node.get_content())
|
||||||
print(f"Total number of chunks: {len(all_texts)}")
|
|
||||||
|
|
||||||
tracker.checkpoint("After text chunking")
|
tracker.checkpoint("After text chunking")
|
||||||
|
|
||||||
@@ -200,9 +201,11 @@ def test_leann_hnsw():
|
|||||||
searcher = LeannSearcher(index_path)
|
searcher = LeannSearcher(index_path)
|
||||||
tracker.checkpoint("After searcher loading")
|
tracker.checkpoint("After searcher loading")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
print("Running search queries...")
|
print("Running search queries...")
|
||||||
queries = [
|
queries = [
|
||||||
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
|
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
|
||||||
"What is LEANN and how does it work?",
|
"What is LEANN and how does it work?",
|
||||||
"华为诺亚方舟实验室的主要研究内容",
|
"华为诺亚方舟实验室的主要研究内容",
|
||||||
]
|
]
|
||||||
@@ -300,15 +303,21 @@ def main():
|
|||||||
|
|
||||||
print("\nLEANN vs Faiss Performance:")
|
print("\nLEANN vs Faiss Performance:")
|
||||||
memory_saving = faiss_results["peak_memory"] - leann_results["peak_memory"]
|
memory_saving = faiss_results["peak_memory"] - leann_results["peak_memory"]
|
||||||
print(f" Search Memory: {memory_ratio:.1f}x less ({memory_saving:.1f} MB saved)")
|
print(
|
||||||
|
f" Search Memory: {memory_ratio:.1f}x less ({memory_saving:.1f} MB saved)"
|
||||||
|
)
|
||||||
|
|
||||||
# Storage comparison
|
# Storage comparison
|
||||||
if leann_storage_size > faiss_storage_size:
|
if leann_storage_size > faiss_storage_size:
|
||||||
storage_ratio = leann_storage_size / faiss_storage_size
|
storage_ratio = leann_storage_size / faiss_storage_size
|
||||||
print(f" Storage Size: {storage_ratio:.1f}x larger (LEANN uses more storage)")
|
print(
|
||||||
|
f" Storage Size: {storage_ratio:.1f}x larger (LEANN uses more storage)"
|
||||||
|
)
|
||||||
elif faiss_storage_size > leann_storage_size:
|
elif faiss_storage_size > leann_storage_size:
|
||||||
storage_ratio = faiss_storage_size / leann_storage_size
|
storage_ratio = faiss_storage_size / leann_storage_size
|
||||||
print(f" Storage Size: {storage_ratio:.1f}x smaller (LEANN uses less storage)")
|
print(
|
||||||
|
f" Storage Size: {storage_ratio:.1f}x smaller (LEANN uses less storage)"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print(" Storage Size: similar")
|
print(" Storage Size: similar")
|
||||||
else:
|
else:
|
||||||
122
examples/email_data/LEANN_email_reader.py
Normal file
122
examples/email_data/LEANN_email_reader.py
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
import os
|
||||||
|
import email
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Any
|
||||||
|
from llama_index.core import Document
|
||||||
|
from llama_index.core.readers.base import BaseReader
|
||||||
|
|
||||||
|
def find_all_messages_directories(root: str = None) -> List[Path]:
|
||||||
|
"""
|
||||||
|
Recursively find all 'Messages' directories under the given root.
|
||||||
|
Returns a list of Path objects.
|
||||||
|
"""
|
||||||
|
if root is None:
|
||||||
|
# Auto-detect user's mail path
|
||||||
|
home_dir = os.path.expanduser("~")
|
||||||
|
root = os.path.join(home_dir, "Library", "Mail")
|
||||||
|
|
||||||
|
messages_dirs = []
|
||||||
|
for dirpath, dirnames, filenames in os.walk(root):
|
||||||
|
if os.path.basename(dirpath) == "Messages":
|
||||||
|
messages_dirs.append(Path(dirpath))
|
||||||
|
return messages_dirs
|
||||||
|
|
||||||
|
class EmlxReader(BaseReader):
|
||||||
|
"""
|
||||||
|
Apple Mail .emlx file reader with embedded metadata.
|
||||||
|
|
||||||
|
Reads individual .emlx files from Apple Mail's storage format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, include_html: bool = False) -> None:
|
||||||
|
"""
|
||||||
|
Initialize.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
include_html: Whether to include HTML content in the email body (default: False)
|
||||||
|
"""
|
||||||
|
self.include_html = include_html
|
||||||
|
|
||||||
|
def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]:
|
||||||
|
"""
|
||||||
|
Load data from the input directory containing .emlx files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_dir: Directory containing .emlx files
|
||||||
|
**load_kwargs:
|
||||||
|
max_count (int): Maximum amount of messages to read.
|
||||||
|
"""
|
||||||
|
docs: List[Document] = []
|
||||||
|
max_count = load_kwargs.get('max_count', 1000)
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
# Walk through the directory recursively
|
||||||
|
for dirpath, dirnames, filenames in os.walk(input_dir):
|
||||||
|
# Skip hidden directories
|
||||||
|
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
||||||
|
|
||||||
|
for filename in filenames:
|
||||||
|
if count >= max_count:
|
||||||
|
break
|
||||||
|
|
||||||
|
if filename.endswith(".emlx"):
|
||||||
|
filepath = os.path.join(dirpath, filename)
|
||||||
|
try:
|
||||||
|
# Read the .emlx file
|
||||||
|
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
|
||||||
|
content = f.read()
|
||||||
|
|
||||||
|
# .emlx files have a length prefix followed by the email content
|
||||||
|
# The first line contains the length, followed by the email
|
||||||
|
lines = content.split('\n', 1)
|
||||||
|
if len(lines) >= 2:
|
||||||
|
email_content = lines[1]
|
||||||
|
|
||||||
|
# Parse the email using Python's email module
|
||||||
|
try:
|
||||||
|
msg = email.message_from_string(email_content)
|
||||||
|
|
||||||
|
# Extract email metadata
|
||||||
|
subject = msg.get('Subject', 'No Subject')
|
||||||
|
from_addr = msg.get('From', 'Unknown')
|
||||||
|
to_addr = msg.get('To', 'Unknown')
|
||||||
|
date = msg.get('Date', 'Unknown')
|
||||||
|
|
||||||
|
# Extract email body
|
||||||
|
body = ""
|
||||||
|
if msg.is_multipart():
|
||||||
|
for part in msg.walk():
|
||||||
|
if part.get_content_type() == "text/plain" or part.get_content_type() == "text/html":
|
||||||
|
if part.get_content_type() == "text/html" and not self.include_html:
|
||||||
|
continue
|
||||||
|
body += part.get_payload(decode=True).decode('utf-8', errors='ignore')
|
||||||
|
# break
|
||||||
|
else:
|
||||||
|
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore')
|
||||||
|
|
||||||
|
# Create document content with metadata embedded in text
|
||||||
|
doc_content = f"""
|
||||||
|
[File]: {filename}
|
||||||
|
[From]: {from_addr}
|
||||||
|
[To]: {to_addr}
|
||||||
|
[Subject]: {subject}
|
||||||
|
[Date]: {date}
|
||||||
|
[EMAIL BODY Start]:
|
||||||
|
{body}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# No separate metadata - everything is in the text
|
||||||
|
doc = Document(text=doc_content, metadata={})
|
||||||
|
docs.append(doc)
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error parsing email from {filepath}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error reading file {filepath}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"Loaded {len(docs)} email documents")
|
||||||
|
return docs
|
||||||
192
examples/email_data/email.py
Normal file
192
examples/email_data/email.py
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
"""
|
||||||
|
Mbox parser.
|
||||||
|
|
||||||
|
Contains simple parser for mbox files.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
from fsspec import AbstractFileSystem
|
||||||
|
|
||||||
|
from llama_index.core.readers.base import BaseReader
|
||||||
|
from llama_index.core.schema import Document
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MboxReader(BaseReader):
|
||||||
|
"""
|
||||||
|
Mbox parser.
|
||||||
|
|
||||||
|
Extract messages from mailbox files.
|
||||||
|
Returns string including date, subject, sender, receiver and
|
||||||
|
content for each message.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_MESSAGE_FORMAT: str = (
|
||||||
|
"Date: {_date}\n"
|
||||||
|
"From: {_from}\n"
|
||||||
|
"To: {_to}\n"
|
||||||
|
"Subject: {_subject}\n"
|
||||||
|
"Content: {_content}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*args: Any,
|
||||||
|
max_count: int = 0,
|
||||||
|
message_format: str = DEFAULT_MESSAGE_FORMAT,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
"""Init params."""
|
||||||
|
try:
|
||||||
|
from bs4 import BeautifulSoup # noqa
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"`beautifulsoup4` package not found: `pip install beautifulsoup4`"
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.max_count = max_count
|
||||||
|
self.message_format = message_format
|
||||||
|
|
||||||
|
def load_data(
|
||||||
|
self,
|
||||||
|
file: Path,
|
||||||
|
extra_info: Optional[Dict] = None,
|
||||||
|
fs: Optional[AbstractFileSystem] = None,
|
||||||
|
) -> List[Document]:
|
||||||
|
"""Parse file into string."""
|
||||||
|
# Import required libraries
|
||||||
|
import mailbox
|
||||||
|
from email.parser import BytesParser
|
||||||
|
from email.policy import default
|
||||||
|
|
||||||
|
from bs4 import BeautifulSoup
|
||||||
|
|
||||||
|
if fs:
|
||||||
|
logger.warning(
|
||||||
|
"fs was specified but MboxReader doesn't support loading "
|
||||||
|
"from fsspec filesystems. Will load from local filesystem instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
i = 0
|
||||||
|
results: List[str] = []
|
||||||
|
# Load file using mailbox
|
||||||
|
bytes_parser = BytesParser(policy=default).parse
|
||||||
|
mbox = mailbox.mbox(file, factory=bytes_parser) # type: ignore
|
||||||
|
|
||||||
|
# Iterate through all messages
|
||||||
|
for _, _msg in enumerate(mbox):
|
||||||
|
try:
|
||||||
|
msg: mailbox.mboxMessage = _msg
|
||||||
|
# Parse multipart messages
|
||||||
|
if msg.is_multipart():
|
||||||
|
for part in msg.walk():
|
||||||
|
ctype = part.get_content_type()
|
||||||
|
cdispo = str(part.get("Content-Disposition"))
|
||||||
|
if "attachment" in cdispo:
|
||||||
|
print(f"Attachment found: {part.get_filename()}")
|
||||||
|
if ctype == "text/plain" and "attachment" not in cdispo:
|
||||||
|
content = part.get_payload(decode=True) # decode
|
||||||
|
break
|
||||||
|
# Get plain message payload for non-multipart messages
|
||||||
|
else:
|
||||||
|
content = msg.get_payload(decode=True)
|
||||||
|
|
||||||
|
# Parse message HTML content and remove unneeded whitespace
|
||||||
|
soup = BeautifulSoup(content)
|
||||||
|
stripped_content = " ".join(soup.get_text().split())
|
||||||
|
# Format message to include date, sender, receiver and subject
|
||||||
|
msg_string = self.message_format.format(
|
||||||
|
_date=msg["date"],
|
||||||
|
_from=msg["from"],
|
||||||
|
_to=msg["to"],
|
||||||
|
_subject=msg["subject"],
|
||||||
|
_content=stripped_content,
|
||||||
|
)
|
||||||
|
# Add message string to results
|
||||||
|
results.append(msg_string)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to parse message:\n{_msg}\n with exception {e}")
|
||||||
|
|
||||||
|
# Increment counter and return if max count is met
|
||||||
|
i += 1
|
||||||
|
if self.max_count > 0 and i >= self.max_count:
|
||||||
|
break
|
||||||
|
|
||||||
|
return [Document(text=result, metadata=extra_info or {}) for result in results]
|
||||||
|
|
||||||
|
|
||||||
|
class EmlxMboxReader(MboxReader):
|
||||||
|
"""
|
||||||
|
EmlxMboxReader - Modified MboxReader that handles directories of .emlx files.
|
||||||
|
|
||||||
|
Extends MboxReader to work with Apple Mail's .emlx format by:
|
||||||
|
1. Reading .emlx files from a directory
|
||||||
|
2. Converting them to mbox format in memory
|
||||||
|
3. Using the parent MboxReader's parsing logic
|
||||||
|
"""
|
||||||
|
|
||||||
|
def load_data(
|
||||||
|
self,
|
||||||
|
directory: Path,
|
||||||
|
extra_info: Optional[Dict] = None,
|
||||||
|
fs: Optional[AbstractFileSystem] = None,
|
||||||
|
) -> List[Document]:
|
||||||
|
"""Parse .emlx files from directory into strings using MboxReader logic."""
|
||||||
|
import tempfile
|
||||||
|
import os
|
||||||
|
|
||||||
|
if fs:
|
||||||
|
logger.warning(
|
||||||
|
"fs was specified but EmlxMboxReader doesn't support loading "
|
||||||
|
"from fsspec filesystems. Will load from local filesystem instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Find all .emlx files in the directory
|
||||||
|
emlx_files = list(directory.glob("*.emlx"))
|
||||||
|
logger.info(f"Found {len(emlx_files)} .emlx files in {directory}")
|
||||||
|
|
||||||
|
if not emlx_files:
|
||||||
|
logger.warning(f"No .emlx files found in {directory}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Create a temporary mbox file
|
||||||
|
with tempfile.NamedTemporaryFile(mode='w', suffix='.mbox', delete=False) as temp_mbox:
|
||||||
|
temp_mbox_path = temp_mbox.name
|
||||||
|
|
||||||
|
# Convert .emlx files to mbox format
|
||||||
|
for emlx_file in emlx_files:
|
||||||
|
try:
|
||||||
|
# Read the .emlx file
|
||||||
|
with open(emlx_file, 'r', encoding='utf-8', errors='ignore') as f:
|
||||||
|
content = f.read()
|
||||||
|
|
||||||
|
# .emlx format: first line is length, rest is email content
|
||||||
|
lines = content.split('\n', 1)
|
||||||
|
if len(lines) >= 2:
|
||||||
|
email_content = lines[1] # Skip the length line
|
||||||
|
|
||||||
|
# Write to mbox format (each message starts with "From " and ends with blank line)
|
||||||
|
temp_mbox.write(f"From {emlx_file.name} {email_content}\n\n")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to process {emlx_file}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Close the temporary file so MboxReader can read it
|
||||||
|
temp_mbox.close()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Use the parent MboxReader's logic to parse the mbox file
|
||||||
|
return super().load_data(Path(temp_mbox_path), extra_info, fs)
|
||||||
|
finally:
|
||||||
|
# Clean up temporary file
|
||||||
|
try:
|
||||||
|
os.unlink(temp_mbox_path)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
@@ -1,11 +1,11 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""Test only Faiss HNSW"""
|
"""Test only Faiss HNSW"""
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
|
import gc
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
def get_memory_usage():
|
def get_memory_usage():
|
||||||
@@ -37,20 +37,20 @@ def main():
|
|||||||
import faiss
|
import faiss
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print("Faiss is not installed.")
|
print("Faiss is not installed.")
|
||||||
print(
|
print("Please install it with `uv pip install faiss-cpu`")
|
||||||
"Please install it with `uv pip install faiss-cpu` and you can then run this script again"
|
|
||||||
)
|
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
from llama_index.core import (
|
from llama_index.core import (
|
||||||
Settings,
|
|
||||||
SimpleDirectoryReader,
|
SimpleDirectoryReader,
|
||||||
StorageContext,
|
|
||||||
VectorStoreIndex,
|
VectorStoreIndex,
|
||||||
|
StorageContext,
|
||||||
|
Settings,
|
||||||
|
node_parser,
|
||||||
|
Document,
|
||||||
)
|
)
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
|
||||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||||
|
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||||
|
|
||||||
tracker = MemoryTracker("Faiss HNSW")
|
tracker = MemoryTracker("Faiss HNSW")
|
||||||
tracker.checkpoint("Initial")
|
tracker.checkpoint("Initial")
|
||||||
@@ -65,7 +65,7 @@ def main():
|
|||||||
tracker.checkpoint("After Faiss index creation")
|
tracker.checkpoint("After Faiss index creation")
|
||||||
|
|
||||||
documents = SimpleDirectoryReader(
|
documents = SimpleDirectoryReader(
|
||||||
"data",
|
"examples/data",
|
||||||
recursive=True,
|
recursive=True,
|
||||||
encoding="utf-8",
|
encoding="utf-8",
|
||||||
required_exts=[".pdf", ".txt", ".md"],
|
required_exts=[".pdf", ".txt", ".md"],
|
||||||
@@ -90,9 +90,8 @@ def main():
|
|||||||
vector_store=vector_store, persist_dir="./storage_faiss"
|
vector_store=vector_store, persist_dir="./storage_faiss"
|
||||||
)
|
)
|
||||||
from llama_index.core import load_index_from_storage
|
from llama_index.core import load_index_from_storage
|
||||||
|
|
||||||
index = load_index_from_storage(storage_context=storage_context)
|
index = load_index_from_storage(storage_context=storage_context)
|
||||||
print("Index loaded from ./storage_faiss")
|
print(f"Index loaded from ./storage_faiss")
|
||||||
tracker.checkpoint("After loading existing index")
|
tracker.checkpoint("After loading existing index")
|
||||||
index_loaded = True
|
index_loaded = True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -100,7 +99,6 @@ def main():
|
|||||||
print("Cleaning up corrupted index and building new one...")
|
print("Cleaning up corrupted index and building new one...")
|
||||||
# Clean up corrupted index
|
# Clean up corrupted index
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
if os.path.exists("./storage_faiss"):
|
if os.path.exists("./storage_faiss"):
|
||||||
shutil.rmtree("./storage_faiss")
|
shutil.rmtree("./storage_faiss")
|
||||||
|
|
||||||
@@ -111,7 +109,9 @@ def main():
|
|||||||
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
||||||
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
||||||
index = VectorStoreIndex.from_documents(
|
index = VectorStoreIndex.from_documents(
|
||||||
documents, storage_context=storage_context, transformations=[node_parser]
|
documents,
|
||||||
|
storage_context=storage_context,
|
||||||
|
transformations=[node_parser]
|
||||||
)
|
)
|
||||||
tracker.checkpoint("After index building")
|
tracker.checkpoint("After index building")
|
||||||
|
|
||||||
@@ -127,7 +127,7 @@ def main():
|
|||||||
|
|
||||||
query_engine = index.as_query_engine(similarity_top_k=20)
|
query_engine = index.as_query_engine(similarity_top_k=20)
|
||||||
queries = [
|
queries = [
|
||||||
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
|
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
|
||||||
"What is LEANN and how does it work?",
|
"What is LEANN and how does it work?",
|
||||||
"华为诺亚方舟实验室的主要研究内容",
|
"华为诺亚方舟实验室的主要研究内容",
|
||||||
]
|
]
|
||||||
285
examples/google_history_reader_leann.py
Normal file
285
examples/google_history_reader_leann.py
Normal file
@@ -0,0 +1,285 @@
|
|||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
import argparse
|
||||||
|
try:
|
||||||
|
import dotenv
|
||||||
|
dotenv.load_dotenv()
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
# python-dotenv is not installed; skip loading environment variables
|
||||||
|
dotenv = None
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Any
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
|
# dotenv.load_dotenv() # handled above if python-dotenv is available
|
||||||
|
|
||||||
|
# Default Chrome profile path
|
||||||
|
DEFAULT_CHROME_PROFILE = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
|
||||||
|
|
||||||
|
def create_leann_index_from_multiple_chrome_profiles(profile_dirs: List[Path], index_path: str = "chrome_history_index.leann", max_count: int = -1):
|
||||||
|
"""
|
||||||
|
Create LEANN index from multiple Chrome profile data sources.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
profile_dirs: List of Path objects pointing to Chrome profile directories
|
||||||
|
index_path: Path to save the LEANN index
|
||||||
|
max_count: Maximum number of history entries to process per profile
|
||||||
|
"""
|
||||||
|
print("Creating LEANN index from multiple Chrome profile data sources...")
|
||||||
|
|
||||||
|
# Load documents using ChromeHistoryReader from history_data
|
||||||
|
from history_data.history import ChromeHistoryReader
|
||||||
|
reader = ChromeHistoryReader()
|
||||||
|
|
||||||
|
INDEX_DIR = Path(index_path).parent
|
||||||
|
|
||||||
|
if not INDEX_DIR.exists():
|
||||||
|
print(f"--- Index directory not found, building new index ---")
|
||||||
|
all_documents = []
|
||||||
|
total_processed = 0
|
||||||
|
|
||||||
|
# Process each Chrome profile directory
|
||||||
|
for i, profile_dir in enumerate(profile_dirs):
|
||||||
|
print(f"\nProcessing Chrome profile {i+1}/{len(profile_dirs)}: {profile_dir}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
documents = reader.load_data(
|
||||||
|
chrome_profile_path=str(profile_dir),
|
||||||
|
max_count=max_count
|
||||||
|
)
|
||||||
|
if documents:
|
||||||
|
print(f"Loaded {len(documents)} history documents from {profile_dir}")
|
||||||
|
all_documents.extend(documents)
|
||||||
|
total_processed += len(documents)
|
||||||
|
|
||||||
|
# Check if we've reached the max count
|
||||||
|
if max_count > 0 and total_processed >= max_count:
|
||||||
|
print(f"Reached max count of {max_count} documents")
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
print(f"No documents loaded from {profile_dir}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing {profile_dir}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not all_documents:
|
||||||
|
print("No documents loaded from any source. Exiting.")
|
||||||
|
# highlight info that you need to close all chrome browser before running this script and high light the instruction!!
|
||||||
|
print("\033[91mYou need to close or quit all chrome browser before running this script\033[0m")
|
||||||
|
return None
|
||||||
|
|
||||||
|
print(f"\nTotal loaded {len(all_documents)} history documents from {len(profile_dirs)} profiles")
|
||||||
|
|
||||||
|
# Create text splitter with 256 chunk size
|
||||||
|
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
|
||||||
|
|
||||||
|
# Convert Documents to text strings and chunk them
|
||||||
|
all_texts = []
|
||||||
|
for doc in all_documents:
|
||||||
|
# Split the document into chunks
|
||||||
|
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||||
|
for node in nodes:
|
||||||
|
text = node.get_content()
|
||||||
|
# text = '[Title] ' + doc.metadata["title"] + '\n' + text
|
||||||
|
all_texts.append(text)
|
||||||
|
|
||||||
|
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
|
||||||
|
|
||||||
|
# Create LEANN index directory
|
||||||
|
print(f"--- Index directory not found, building new index ---")
|
||||||
|
INDEX_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
print(f"--- Building new LEANN index ---")
|
||||||
|
|
||||||
|
print(f"\n[PHASE 1] Building Leann index...")
|
||||||
|
|
||||||
|
# Use HNSW backend for better macOS compatibility
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="facebook/contriever",
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
is_compact=True,
|
||||||
|
is_recompute=True,
|
||||||
|
num_threads=1 # Force single-threaded mode
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Adding {len(all_texts)} history chunks to index...")
|
||||||
|
for chunk_text in all_texts:
|
||||||
|
builder.add_text(chunk_text)
|
||||||
|
|
||||||
|
builder.build_index(index_path)
|
||||||
|
print(f"\nLEANN index built at {index_path}!")
|
||||||
|
else:
|
||||||
|
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||||
|
|
||||||
|
return index_path
|
||||||
|
|
||||||
|
def create_leann_index(profile_path: str = None, index_path: str = "chrome_history_index.leann", max_count: int = 1000):
|
||||||
|
"""
|
||||||
|
Create LEANN index from Chrome history data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
profile_path: Path to the Chrome profile directory (optional, uses default if None)
|
||||||
|
index_path: Path to save the LEANN index
|
||||||
|
max_count: Maximum number of history entries to process
|
||||||
|
"""
|
||||||
|
print("Creating LEANN index from Chrome history data...")
|
||||||
|
INDEX_DIR = Path(index_path).parent
|
||||||
|
|
||||||
|
if not INDEX_DIR.exists():
|
||||||
|
print(f"--- Index directory not found, building new index ---")
|
||||||
|
INDEX_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
print(f"--- Building new LEANN index ---")
|
||||||
|
|
||||||
|
print(f"\n[PHASE 1] Building Leann index...")
|
||||||
|
|
||||||
|
# Load documents using ChromeHistoryReader from history_data
|
||||||
|
from history_data.history import ChromeHistoryReader
|
||||||
|
reader = ChromeHistoryReader()
|
||||||
|
|
||||||
|
documents = reader.load_data(
|
||||||
|
chrome_profile_path=profile_path,
|
||||||
|
max_count=max_count
|
||||||
|
)
|
||||||
|
|
||||||
|
if not documents:
|
||||||
|
print("No documents loaded. Exiting.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
print(f"Loaded {len(documents)} history documents")
|
||||||
|
|
||||||
|
# Create text splitter with 256 chunk size
|
||||||
|
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||||
|
|
||||||
|
# Convert Documents to text strings and chunk them
|
||||||
|
all_texts = []
|
||||||
|
for doc in documents:
|
||||||
|
# Split the document into chunks
|
||||||
|
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||||
|
for node in nodes:
|
||||||
|
all_texts.append(node.get_content())
|
||||||
|
|
||||||
|
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
|
||||||
|
|
||||||
|
# Create LEANN index directory
|
||||||
|
print(f"--- Index directory not found, building new index ---")
|
||||||
|
INDEX_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
print(f"--- Building new LEANN index ---")
|
||||||
|
|
||||||
|
print(f"\n[PHASE 1] Building Leann index...")
|
||||||
|
|
||||||
|
# Use HNSW backend for better macOS compatibility
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="facebook/contriever",
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
is_compact=True,
|
||||||
|
is_recompute=True,
|
||||||
|
num_threads=1 # Force single-threaded mode
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Adding {len(all_texts)} history chunks to index...")
|
||||||
|
for chunk_text in all_texts:
|
||||||
|
builder.add_text(chunk_text)
|
||||||
|
|
||||||
|
builder.build_index(index_path)
|
||||||
|
print(f"\nLEANN index built at {index_path}!")
|
||||||
|
else:
|
||||||
|
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||||
|
|
||||||
|
return index_path
|
||||||
|
|
||||||
|
async def query_leann_index(index_path: str, query: str):
|
||||||
|
"""
|
||||||
|
Query the LEANN index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index_path: Path to the LEANN index
|
||||||
|
query: The query string
|
||||||
|
"""
|
||||||
|
print(f"\n[PHASE 2] Starting Leann chat session...")
|
||||||
|
chat = LeannChat(index_path=index_path)
|
||||||
|
|
||||||
|
print(f"You: {query}")
|
||||||
|
chat_response = chat.ask(
|
||||||
|
query,
|
||||||
|
top_k=10,
|
||||||
|
recompute_beighbor_embeddings=True,
|
||||||
|
complexity=32,
|
||||||
|
beam_width=1,
|
||||||
|
llm_config={
|
||||||
|
"type": "openai",
|
||||||
|
"model": "gpt-4o",
|
||||||
|
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||||
|
},
|
||||||
|
llm_kwargs={
|
||||||
|
"temperature": 0.0,
|
||||||
|
"max_tokens": 1000
|
||||||
|
}
|
||||||
|
)
|
||||||
|
print(f"Leann: {chat_response}")
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
# Parse command line arguments
|
||||||
|
parser = argparse.ArgumentParser(description='LEANN Chrome History Reader - Create and query browser history index')
|
||||||
|
parser.add_argument('--chrome-profile', type=str, default=DEFAULT_CHROME_PROFILE,
|
||||||
|
help=f'Path to Chrome profile directory (default: {DEFAULT_CHROME_PROFILE}), usually you dont need to change this')
|
||||||
|
parser.add_argument('--index-dir', type=str, default="./all_google_new",
|
||||||
|
help='Directory to store the LEANN index (default: ./chrome_history_index_leann_test)')
|
||||||
|
parser.add_argument('--max-entries', type=int, default=1000,
|
||||||
|
help='Maximum number of history entries to process (default: 1000)')
|
||||||
|
parser.add_argument('--query', type=str, default=None,
|
||||||
|
help='Single query to run (default: runs example queries)')
|
||||||
|
parser.add_argument('--auto-find-profiles', action='store_true', default=True,
|
||||||
|
help='Automatically find all Chrome profiles (default: True)')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
INDEX_DIR = Path(args.index_dir)
|
||||||
|
INDEX_PATH = str(INDEX_DIR / "chrome_history.leann")
|
||||||
|
|
||||||
|
print(f"Using Chrome profile: {args.chrome_profile}")
|
||||||
|
print(f"Index directory: {INDEX_DIR}")
|
||||||
|
print(f"Max entries: {args.max_entries}")
|
||||||
|
|
||||||
|
# Find Chrome profile directories
|
||||||
|
from history_data.history import ChromeHistoryReader
|
||||||
|
|
||||||
|
if args.auto_find_profiles:
|
||||||
|
profile_dirs = ChromeHistoryReader.find_chrome_profiles()
|
||||||
|
if not profile_dirs:
|
||||||
|
print("No Chrome profiles found automatically. Exiting.")
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
# Use single specified profile
|
||||||
|
profile_path = Path(args.chrome_profile)
|
||||||
|
if not profile_path.exists():
|
||||||
|
print(f"Chrome profile not found: {profile_path}")
|
||||||
|
return
|
||||||
|
profile_dirs = [profile_path]
|
||||||
|
|
||||||
|
# Create or load the LEANN index from all sources
|
||||||
|
index_path = create_leann_index_from_multiple_chrome_profiles(profile_dirs, INDEX_PATH, args.max_entries)
|
||||||
|
|
||||||
|
if index_path:
|
||||||
|
if args.query:
|
||||||
|
# Run single query
|
||||||
|
await query_leann_index(index_path, args.query)
|
||||||
|
else:
|
||||||
|
# Example queries
|
||||||
|
queries = [
|
||||||
|
"What websites did I visit about machine learning?",
|
||||||
|
"Find my search history about programming"
|
||||||
|
]
|
||||||
|
|
||||||
|
for query in queries:
|
||||||
|
print("\n" + "="*60)
|
||||||
|
await query_leann_index(index_path, query)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
@@ -1,3 +1,3 @@
|
|||||||
from .history import ChromeHistoryReader
|
from .history import ChromeHistoryReader
|
||||||
|
|
||||||
__all__ = ["ChromeHistoryReader"]
|
__all__ = ['ChromeHistoryReader']
|
||||||
@@ -1,12 +1,10 @@
|
|||||||
import os
|
|
||||||
import sqlite3
|
import sqlite3
|
||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import List, Any
|
||||||
|
|
||||||
from llama_index.core import Document
|
from llama_index.core import Document
|
||||||
from llama_index.core.readers.base import BaseReader
|
from llama_index.core.readers.base import BaseReader
|
||||||
|
|
||||||
|
|
||||||
class ChromeHistoryReader(BaseReader):
|
class ChromeHistoryReader(BaseReader):
|
||||||
"""
|
"""
|
||||||
Chrome browser history reader that extracts browsing data from SQLite database.
|
Chrome browser history reader that extracts browsing data from SQLite database.
|
||||||
@@ -19,7 +17,7 @@ class ChromeHistoryReader(BaseReader):
|
|||||||
"""Initialize."""
|
"""Initialize."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
|
def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
|
||||||
"""
|
"""
|
||||||
Load Chrome history data from the default Chrome profile location.
|
Load Chrome history data from the default Chrome profile location.
|
||||||
|
|
||||||
@@ -29,15 +27,13 @@ class ChromeHistoryReader(BaseReader):
|
|||||||
max_count (int): Maximum amount of history entries to read.
|
max_count (int): Maximum amount of history entries to read.
|
||||||
chrome_profile_path (str): Custom path to Chrome profile directory.
|
chrome_profile_path (str): Custom path to Chrome profile directory.
|
||||||
"""
|
"""
|
||||||
docs: list[Document] = []
|
docs: List[Document] = []
|
||||||
max_count = load_kwargs.get("max_count", 1000)
|
max_count = load_kwargs.get('max_count', 1000)
|
||||||
chrome_profile_path = load_kwargs.get("chrome_profile_path", None)
|
chrome_profile_path = load_kwargs.get('chrome_profile_path', None)
|
||||||
|
|
||||||
# Default Chrome profile path on macOS
|
# Default Chrome profile path on macOS
|
||||||
if chrome_profile_path is None:
|
if chrome_profile_path is None:
|
||||||
chrome_profile_path = os.path.expanduser(
|
chrome_profile_path = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
|
||||||
"~/Library/Application Support/Google/Chrome/Default"
|
|
||||||
)
|
|
||||||
|
|
||||||
history_db_path = os.path.join(chrome_profile_path, "History")
|
history_db_path = os.path.join(chrome_profile_path, "History")
|
||||||
|
|
||||||
@@ -86,7 +82,7 @@ class ChromeHistoryReader(BaseReader):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Create document with embedded metadata
|
# Create document with embedded metadata
|
||||||
doc = Document(text=doc_content, metadata={"title": title[0:150]})
|
doc = Document(text=doc_content, metadata={ "title": title[0:150]})
|
||||||
# if len(title) > 150:
|
# if len(title) > 150:
|
||||||
# print(f"Title is too long: {title}")
|
# print(f"Title is too long: {title}")
|
||||||
docs.append(doc)
|
docs.append(doc)
|
||||||
@@ -97,17 +93,12 @@ class ChromeHistoryReader(BaseReader):
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error reading Chrome history: {e}")
|
print(f"Error reading Chrome history: {e}")
|
||||||
# add you may need to close your browser to make the database file available
|
|
||||||
# also highlight in red
|
|
||||||
print(
|
|
||||||
"\033[91mYou may need to close your browser to make the database file available\033[0m"
|
|
||||||
)
|
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def find_chrome_profiles() -> list[Path]:
|
def find_chrome_profiles() -> List[Path]:
|
||||||
"""
|
"""
|
||||||
Find all Chrome profile directories.
|
Find all Chrome profile directories.
|
||||||
|
|
||||||
@@ -133,9 +124,7 @@ class ChromeHistoryReader(BaseReader):
|
|||||||
return profile_dirs
|
return profile_dirs
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def export_history_to_file(
|
def export_history_to_file(output_file: str = "chrome_history_export.txt", max_count: int = 1000):
|
||||||
output_file: str = "chrome_history_export.txt", max_count: int = 1000
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Export Chrome history to a text file using the same SQL query format.
|
Export Chrome history to a text file using the same SQL query format.
|
||||||
|
|
||||||
@@ -143,9 +132,7 @@ class ChromeHistoryReader(BaseReader):
|
|||||||
output_file: Path to the output file
|
output_file: Path to the output file
|
||||||
max_count: Maximum number of entries to export
|
max_count: Maximum number of entries to export
|
||||||
"""
|
"""
|
||||||
chrome_profile_path = os.path.expanduser(
|
chrome_profile_path = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
|
||||||
"~/Library/Application Support/Google/Chrome/Default"
|
|
||||||
)
|
|
||||||
history_db_path = os.path.join(chrome_profile_path, "History")
|
history_db_path = os.path.join(chrome_profile_path, "History")
|
||||||
|
|
||||||
if not os.path.exists(history_db_path):
|
if not os.path.exists(history_db_path):
|
||||||
@@ -172,12 +159,10 @@ class ChromeHistoryReader(BaseReader):
|
|||||||
cursor.execute(query, (max_count,))
|
cursor.execute(query, (max_count,))
|
||||||
rows = cursor.fetchall()
|
rows = cursor.fetchall()
|
||||||
|
|
||||||
with open(output_file, "w", encoding="utf-8") as f:
|
with open(output_file, 'w', encoding='utf-8') as f:
|
||||||
for row in rows:
|
for row in rows:
|
||||||
last_visit, url, title, visit_count, typed_count, hidden = row
|
last_visit, url, title, visit_count, typed_count, hidden = row
|
||||||
f.write(
|
f.write(f"{last_visit}\t{url}\t{title}\t{visit_count}\t{typed_count}\t{hidden}\n")
|
||||||
f"{last_visit}\t{url}\t{title}\t{visit_count}\t{typed_count}\t{hidden}\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
conn.close()
|
conn.close()
|
||||||
print(f"Exported {len(rows)} history entries to {output_file}")
|
print(f"Exported {len(rows)} history entries to {output_file}")
|
||||||
@@ -2,14 +2,13 @@ import json
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import List, Any, Dict, Optional
|
||||||
|
|
||||||
from llama_index.core import Document
|
from llama_index.core import Document
|
||||||
from llama_index.core.readers.base import BaseReader
|
from llama_index.core.readers.base import BaseReader
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
class WeChatHistoryReader(BaseReader):
|
class WeChatHistoryReader(BaseReader):
|
||||||
"""
|
"""
|
||||||
@@ -44,16 +43,10 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
wechattweak_path = self.wechat_exporter_dir / "wechattweak-cli"
|
wechattweak_path = self.wechat_exporter_dir / "wechattweak-cli"
|
||||||
if not wechattweak_path.exists():
|
if not wechattweak_path.exists():
|
||||||
print("Downloading WeChatTweak CLI...")
|
print("Downloading WeChatTweak CLI...")
|
||||||
subprocess.run(
|
subprocess.run([
|
||||||
[
|
"curl", "-L", "-o", str(wechattweak_path),
|
||||||
"curl",
|
"https://github.com/JettChenT/WeChatTweak-CLI/releases/latest/download/wechattweak-cli"
|
||||||
"-L",
|
], check=True)
|
||||||
"-o",
|
|
||||||
str(wechattweak_path),
|
|
||||||
"https://github.com/JettChenT/WeChatTweak-CLI/releases/latest/download/wechattweak-cli",
|
|
||||||
],
|
|
||||||
check=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Make executable
|
# Make executable
|
||||||
wechattweak_path.chmod(0o755)
|
wechattweak_path.chmod(0o755)
|
||||||
@@ -80,16 +73,16 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
def check_api_available(self) -> bool:
|
def check_api_available(self) -> bool:
|
||||||
"""Check if WeChatTweak API is available."""
|
"""Check if WeChatTweak API is available."""
|
||||||
try:
|
try:
|
||||||
result = subprocess.run(
|
result = subprocess.run([
|
||||||
["curl", "-s", "http://localhost:48065/wechat/allcontacts"],
|
"curl", "-s", "http://localhost:48065/wechat/allcontacts"
|
||||||
capture_output=True,
|
], capture_output=True, text=True, timeout=5)
|
||||||
text=True,
|
|
||||||
timeout=5,
|
|
||||||
)
|
|
||||||
return result.returncode == 0 and result.stdout.strip()
|
return result.returncode == 0 and result.stdout.strip()
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_readable_text(self, content: str) -> str:
|
def _extract_readable_text(self, content: str) -> str:
|
||||||
"""
|
"""
|
||||||
Extract readable text from message content, removing XML and system messages.
|
Extract readable text from message content, removing XML and system messages.
|
||||||
@@ -107,14 +100,14 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
if isinstance(content, dict):
|
if isinstance(content, dict):
|
||||||
# Extract text from dictionary structure
|
# Extract text from dictionary structure
|
||||||
text_parts = []
|
text_parts = []
|
||||||
if "title" in content:
|
if 'title' in content:
|
||||||
text_parts.append(str(content["title"]))
|
text_parts.append(str(content['title']))
|
||||||
if "quoted" in content:
|
if 'quoted' in content:
|
||||||
text_parts.append(str(content["quoted"]))
|
text_parts.append(str(content['quoted']))
|
||||||
if "content" in content:
|
if 'content' in content:
|
||||||
text_parts.append(str(content["content"]))
|
text_parts.append(str(content['content']))
|
||||||
if "text" in content:
|
if 'text' in content:
|
||||||
text_parts.append(str(content["text"]))
|
text_parts.append(str(content['text']))
|
||||||
|
|
||||||
if text_parts:
|
if text_parts:
|
||||||
return " | ".join(text_parts)
|
return " | ".join(text_parts)
|
||||||
@@ -127,11 +120,11 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
# Remove common prefixes like "wxid_xxx:\n"
|
# Remove common prefixes like "wxid_xxx:\n"
|
||||||
clean_content = re.sub(r"^wxid_[^:]+:\s*", "", content)
|
clean_content = re.sub(r'^wxid_[^:]+:\s*', '', content)
|
||||||
clean_content = re.sub(r"^[^:]+:\s*", "", clean_content)
|
clean_content = re.sub(r'^[^:]+:\s*', '', clean_content)
|
||||||
|
|
||||||
# If it's just XML or system message, return empty
|
# If it's just XML or system message, return empty
|
||||||
if clean_content.strip().startswith("<") or "recalled a message" in clean_content:
|
if clean_content.strip().startswith('<') or 'recalled a message' in clean_content:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
return clean_content.strip()
|
return clean_content.strip()
|
||||||
@@ -152,9 +145,9 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
# Handle dictionary content
|
# Handle dictionary content
|
||||||
if isinstance(content, dict):
|
if isinstance(content, dict):
|
||||||
# Check if dict has any readable text fields
|
# Check if dict has any readable text fields
|
||||||
text_fields = ["title", "quoted", "content", "text"]
|
text_fields = ['title', 'quoted', 'content', 'text']
|
||||||
for field in text_fields:
|
for field in text_fields:
|
||||||
if content.get(field):
|
if field in content and content[field]:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -163,47 +156,42 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
# Skip image messages (contain XML with img tags)
|
# Skip image messages (contain XML with img tags)
|
||||||
if "<img" in content and "cdnurl" in content:
|
if '<img' in content and 'cdnurl' in content:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Skip emoji messages (contain emoji XML tags)
|
# Skip emoji messages (contain emoji XML tags)
|
||||||
if "<emoji" in content and "productid" in content:
|
if '<emoji' in content and 'productid' in content:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Skip voice messages
|
# Skip voice messages
|
||||||
if "<voice" in content:
|
if '<voice' in content:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Skip video messages
|
# Skip video messages
|
||||||
if "<video" in content:
|
if '<video' in content:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Skip file messages
|
# Skip file messages
|
||||||
if "<appmsg" in content and "appid" in content:
|
if '<appmsg' in content and 'appid' in content:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Skip system messages (like "recalled a message")
|
# Skip system messages (like "recalled a message")
|
||||||
if "recalled a message" in content:
|
if 'recalled a message' in content:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Check if there's actual readable text (not just XML or system messages)
|
# Check if there's actual readable text (not just XML or system messages)
|
||||||
# Remove common prefixes like "wxid_xxx:\n" and check for actual content
|
# Remove common prefixes like "wxid_xxx:\n" and check for actual content
|
||||||
clean_content = re.sub(r"^wxid_[^:]+:\s*", "", content)
|
clean_content = re.sub(r'^wxid_[^:]+:\s*', '', content)
|
||||||
clean_content = re.sub(r"^[^:]+:\s*", "", clean_content)
|
clean_content = re.sub(r'^[^:]+:\s*', '', clean_content)
|
||||||
|
|
||||||
# If after cleaning we have meaningful text, consider it readable
|
# If after cleaning we have meaningful text, consider it readable
|
||||||
if len(clean_content.strip()) > 0 and not clean_content.strip().startswith("<"):
|
if len(clean_content.strip()) > 0 and not clean_content.strip().startswith('<'):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _concatenate_messages(
|
def _concatenate_messages(self, messages: List[Dict], max_length: int = 128,
|
||||||
self,
|
time_window_minutes: int = 30, overlap_messages: int = 0) -> List[Dict]:
|
||||||
messages: list[dict],
|
|
||||||
max_length: int = 128,
|
|
||||||
time_window_minutes: int = 30,
|
|
||||||
overlap_messages: int = 0,
|
|
||||||
) -> list[dict]:
|
|
||||||
"""
|
"""
|
||||||
Concatenate messages based on length and time rules.
|
Concatenate messages based on length and time rules.
|
||||||
|
|
||||||
@@ -226,12 +214,12 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
|
|
||||||
for message in messages:
|
for message in messages:
|
||||||
# Extract message info
|
# Extract message info
|
||||||
content = message.get("content", "")
|
content = message.get('content', '')
|
||||||
message_text = message.get("message", "")
|
message_text = message.get('message', '')
|
||||||
create_time = message.get("createTime", 0)
|
create_time = message.get('createTime', 0)
|
||||||
message.get("fromUser", "")
|
from_user = message.get('fromUser', '')
|
||||||
message.get("toUser", "")
|
to_user = message.get('toUser', '')
|
||||||
message.get("isSentFromSelf", False)
|
is_sent_from_self = message.get('isSentFromSelf', False)
|
||||||
|
|
||||||
# Extract readable text
|
# Extract readable text
|
||||||
readable_text = self._extract_readable_text(content)
|
readable_text = self._extract_readable_text(content)
|
||||||
@@ -248,24 +236,16 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
if time_diff_minutes > time_window_minutes:
|
if time_diff_minutes > time_window_minutes:
|
||||||
# Time gap too large, start new group
|
# Time gap too large, start new group
|
||||||
if current_group:
|
if current_group:
|
||||||
concatenated_groups.append(
|
concatenated_groups.append({
|
||||||
{
|
'messages': current_group,
|
||||||
"messages": current_group,
|
'total_length': current_length,
|
||||||
"total_length": current_length,
|
'start_time': current_group[0].get('createTime', 0),
|
||||||
"start_time": current_group[0].get("createTime", 0),
|
'end_time': current_group[-1].get('createTime', 0)
|
||||||
"end_time": current_group[-1].get("createTime", 0),
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
# Keep last few messages for overlap
|
# Keep last few messages for overlap
|
||||||
if overlap_messages > 0 and len(current_group) > overlap_messages:
|
if overlap_messages > 0 and len(current_group) > overlap_messages:
|
||||||
current_group = current_group[-overlap_messages:]
|
current_group = current_group[-overlap_messages:]
|
||||||
current_length = sum(
|
current_length = sum(len(self._extract_readable_text(msg.get('content', '')) or msg.get('message', '')) for msg in current_group)
|
||||||
len(
|
|
||||||
self._extract_readable_text(msg.get("content", ""))
|
|
||||||
or msg.get("message", "")
|
|
||||||
)
|
|
||||||
for msg in current_group
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
current_group = []
|
current_group = []
|
||||||
current_length = 0
|
current_length = 0
|
||||||
@@ -274,24 +254,16 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
message_length = len(readable_text)
|
message_length = len(readable_text)
|
||||||
if max_length != -1 and current_length + message_length > max_length and current_group:
|
if max_length != -1 and current_length + message_length > max_length and current_group:
|
||||||
# Current group would exceed max length, save it and start new
|
# Current group would exceed max length, save it and start new
|
||||||
concatenated_groups.append(
|
concatenated_groups.append({
|
||||||
{
|
'messages': current_group,
|
||||||
"messages": current_group,
|
'total_length': current_length,
|
||||||
"total_length": current_length,
|
'start_time': current_group[0].get('createTime', 0),
|
||||||
"start_time": current_group[0].get("createTime", 0),
|
'end_time': current_group[-1].get('createTime', 0)
|
||||||
"end_time": current_group[-1].get("createTime", 0),
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
# Keep last few messages for overlap
|
# Keep last few messages for overlap
|
||||||
if overlap_messages > 0 and len(current_group) > overlap_messages:
|
if overlap_messages > 0 and len(current_group) > overlap_messages:
|
||||||
current_group = current_group[-overlap_messages:]
|
current_group = current_group[-overlap_messages:]
|
||||||
current_length = sum(
|
current_length = sum(len(self._extract_readable_text(msg.get('content', '')) or msg.get('message', '')) for msg in current_group)
|
||||||
len(
|
|
||||||
self._extract_readable_text(msg.get("content", ""))
|
|
||||||
or msg.get("message", "")
|
|
||||||
)
|
|
||||||
for msg in current_group
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
current_group = []
|
current_group = []
|
||||||
current_length = 0
|
current_length = 0
|
||||||
@@ -303,18 +275,16 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
|
|
||||||
# Add the last group if it exists
|
# Add the last group if it exists
|
||||||
if current_group:
|
if current_group:
|
||||||
concatenated_groups.append(
|
concatenated_groups.append({
|
||||||
{
|
'messages': current_group,
|
||||||
"messages": current_group,
|
'total_length': current_length,
|
||||||
"total_length": current_length,
|
'start_time': current_group[0].get('createTime', 0),
|
||||||
"start_time": current_group[0].get("createTime", 0),
|
'end_time': current_group[-1].get('createTime', 0)
|
||||||
"end_time": current_group[-1].get("createTime", 0),
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return concatenated_groups
|
return concatenated_groups
|
||||||
|
|
||||||
def _create_concatenated_content(self, message_group: dict, contact_name: str) -> str:
|
def _create_concatenated_content(self, message_group: Dict, contact_name: str) -> str:
|
||||||
"""
|
"""
|
||||||
Create concatenated content from a group of messages.
|
Create concatenated content from a group of messages.
|
||||||
|
|
||||||
@@ -325,16 +295,16 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
Returns:
|
Returns:
|
||||||
Formatted concatenated content
|
Formatted concatenated content
|
||||||
"""
|
"""
|
||||||
messages = message_group["messages"]
|
messages = message_group['messages']
|
||||||
start_time = message_group["start_time"]
|
start_time = message_group['start_time']
|
||||||
end_time = message_group["end_time"]
|
end_time = message_group['end_time']
|
||||||
|
|
||||||
# Format timestamps
|
# Format timestamps
|
||||||
if start_time:
|
if start_time:
|
||||||
try:
|
try:
|
||||||
start_timestamp = datetime.fromtimestamp(start_time)
|
start_timestamp = datetime.fromtimestamp(start_time)
|
||||||
start_time_str = start_timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
start_time_str = start_timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||||
except (ValueError, OSError):
|
except:
|
||||||
start_time_str = str(start_time)
|
start_time_str = str(start_time)
|
||||||
else:
|
else:
|
||||||
start_time_str = "Unknown"
|
start_time_str = "Unknown"
|
||||||
@@ -342,8 +312,8 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
if end_time:
|
if end_time:
|
||||||
try:
|
try:
|
||||||
end_timestamp = datetime.fromtimestamp(end_time)
|
end_timestamp = datetime.fromtimestamp(end_time)
|
||||||
end_time_str = end_timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
end_time_str = end_timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||||
except (ValueError, OSError):
|
except:
|
||||||
end_time_str = str(end_time)
|
end_time_str = str(end_time)
|
||||||
else:
|
else:
|
||||||
end_time_str = "Unknown"
|
end_time_str = "Unknown"
|
||||||
@@ -351,10 +321,10 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
# Build concatenated message content
|
# Build concatenated message content
|
||||||
message_parts = []
|
message_parts = []
|
||||||
for message in messages:
|
for message in messages:
|
||||||
content = message.get("content", "")
|
content = message.get('content', '')
|
||||||
message_text = message.get("message", "")
|
message_text = message.get('message', '')
|
||||||
create_time = message.get("createTime", 0)
|
create_time = message.get('createTime', 0)
|
||||||
is_sent_from_self = message.get("isSentFromSelf", False)
|
is_sent_from_self = message.get('isSentFromSelf', False)
|
||||||
|
|
||||||
# Extract readable text
|
# Extract readable text
|
||||||
readable_text = self._extract_readable_text(content)
|
readable_text = self._extract_readable_text(content)
|
||||||
@@ -366,8 +336,8 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
try:
|
try:
|
||||||
timestamp = datetime.fromtimestamp(create_time)
|
timestamp = datetime.fromtimestamp(create_time)
|
||||||
# change to YYYY-MM-DD HH:MM:SS
|
# change to YYYY-MM-DD HH:MM:SS
|
||||||
time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||||
except (ValueError, OSError):
|
except:
|
||||||
time_str = str(create_time)
|
time_str = str(create_time)
|
||||||
else:
|
else:
|
||||||
time_str = "Unknown"
|
time_str = "Unknown"
|
||||||
@@ -381,7 +351,7 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
doc_content = f"""
|
doc_content = f"""
|
||||||
Contact: {contact_name}
|
Contact: {contact_name}
|
||||||
Time Range: {start_time_str} - {end_time_str}
|
Time Range: {start_time_str} - {end_time_str}
|
||||||
Messages ({len(messages)} messages, {message_group["total_length"]} chars):
|
Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
||||||
|
|
||||||
{concatenated_text}
|
{concatenated_text}
|
||||||
"""
|
"""
|
||||||
@@ -391,7 +361,7 @@ Messages ({len(messages)} messages, {message_group["total_length"]} chars):
|
|||||||
"""
|
"""
|
||||||
return doc_content, contact_name
|
return doc_content, contact_name
|
||||||
|
|
||||||
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
|
def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
|
||||||
"""
|
"""
|
||||||
Load WeChat chat history data from exported JSON files.
|
Load WeChat chat history data from exported JSON files.
|
||||||
|
|
||||||
@@ -406,13 +376,13 @@ Messages ({len(messages)} messages, {message_group["total_length"]} chars):
|
|||||||
time_window_minutes (int): Time window in minutes to group messages together (default: 30).
|
time_window_minutes (int): Time window in minutes to group messages together (default: 30).
|
||||||
overlap_messages (int): Number of messages to overlap between consecutive groups (default: 2).
|
overlap_messages (int): Number of messages to overlap between consecutive groups (default: 2).
|
||||||
"""
|
"""
|
||||||
docs: list[Document] = []
|
docs: List[Document] = []
|
||||||
max_count = load_kwargs.get("max_count", 1000)
|
max_count = load_kwargs.get('max_count', 1000)
|
||||||
wechat_export_dir = load_kwargs.get("wechat_export_dir", None)
|
wechat_export_dir = load_kwargs.get('wechat_export_dir', None)
|
||||||
include_non_text = load_kwargs.get("include_non_text", False)
|
include_non_text = load_kwargs.get('include_non_text', False)
|
||||||
concatenate_messages = load_kwargs.get("concatenate_messages", False)
|
concatenate_messages = load_kwargs.get('concatenate_messages', False)
|
||||||
max_length = load_kwargs.get("max_length", 1000)
|
max_length = load_kwargs.get('max_length', 1000)
|
||||||
time_window_minutes = load_kwargs.get("time_window_minutes", 30)
|
time_window_minutes = load_kwargs.get('time_window_minutes', 30)
|
||||||
|
|
||||||
# Default WeChat export path
|
# Default WeChat export path
|
||||||
if wechat_export_dir is None:
|
if wechat_export_dir is None:
|
||||||
@@ -433,7 +403,7 @@ Messages ({len(messages)} messages, {message_group["total_length"]} chars):
|
|||||||
break
|
break
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(json_file, encoding="utf-8") as f:
|
with open(json_file, 'r', encoding='utf-8') as f:
|
||||||
chat_data = json.load(f)
|
chat_data = json.load(f)
|
||||||
|
|
||||||
# Extract contact name from filename
|
# Extract contact name from filename
|
||||||
@@ -444,7 +414,7 @@ Messages ({len(messages)} messages, {message_group["total_length"]} chars):
|
|||||||
readable_messages = []
|
readable_messages = []
|
||||||
for message in chat_data:
|
for message in chat_data:
|
||||||
try:
|
try:
|
||||||
content = message.get("content", "")
|
content = message.get('content', '')
|
||||||
if not include_non_text and not self._is_text_message(content):
|
if not include_non_text and not self._is_text_message(content):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -460,9 +430,9 @@ Messages ({len(messages)} messages, {message_group["total_length"]} chars):
|
|||||||
# Concatenate messages based on rules
|
# Concatenate messages based on rules
|
||||||
message_groups = self._concatenate_messages(
|
message_groups = self._concatenate_messages(
|
||||||
readable_messages,
|
readable_messages,
|
||||||
max_length=max_length,
|
max_length=-1,
|
||||||
time_window_minutes=time_window_minutes,
|
time_window_minutes=-1,
|
||||||
overlap_messages=0, # No overlap between groups
|
overlap_messages=0 # Keep 2 messages overlap between groups
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create documents from concatenated groups
|
# Create documents from concatenated groups
|
||||||
@@ -470,19 +440,12 @@ Messages ({len(messages)} messages, {message_group["total_length"]} chars):
|
|||||||
if count >= max_count and max_count > 0:
|
if count >= max_count and max_count > 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
doc_content, contact_name = self._create_concatenated_content(
|
doc_content, contact_name = self._create_concatenated_content(message_group, contact_name)
|
||||||
message_group, contact_name
|
doc = Document(text=doc_content, metadata={"contact_name": contact_name})
|
||||||
)
|
|
||||||
doc = Document(
|
|
||||||
text=doc_content,
|
|
||||||
metadata={"contact_name": contact_name},
|
|
||||||
)
|
|
||||||
docs.append(doc)
|
docs.append(doc)
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
print(
|
print(f"Created {len(message_groups)} concatenated message groups for {contact_name}")
|
||||||
f"Created {len(message_groups)} concatenated message groups for {contact_name}"
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Original single-message processing
|
# Original single-message processing
|
||||||
@@ -491,12 +454,12 @@ Messages ({len(messages)} messages, {message_group["total_length"]} chars):
|
|||||||
break
|
break
|
||||||
|
|
||||||
# Extract message information
|
# Extract message information
|
||||||
message.get("fromUser", "")
|
from_user = message.get('fromUser', '')
|
||||||
message.get("toUser", "")
|
to_user = message.get('toUser', '')
|
||||||
content = message.get("content", "")
|
content = message.get('content', '')
|
||||||
message_text = message.get("message", "")
|
message_text = message.get('message', '')
|
||||||
create_time = message.get("createTime", 0)
|
create_time = message.get('createTime', 0)
|
||||||
is_sent_from_self = message.get("isSentFromSelf", False)
|
is_sent_from_self = message.get('isSentFromSelf', False)
|
||||||
|
|
||||||
# Handle content that might be dict or string
|
# Handle content that might be dict or string
|
||||||
try:
|
try:
|
||||||
@@ -517,8 +480,8 @@ Messages ({len(messages)} messages, {message_group["total_length"]} chars):
|
|||||||
if create_time:
|
if create_time:
|
||||||
try:
|
try:
|
||||||
timestamp = datetime.fromtimestamp(create_time)
|
timestamp = datetime.fromtimestamp(create_time)
|
||||||
time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||||
except (ValueError, OSError):
|
except:
|
||||||
time_str = str(create_time)
|
time_str = str(create_time)
|
||||||
else:
|
else:
|
||||||
time_str = "Unknown"
|
time_str = "Unknown"
|
||||||
@@ -532,9 +495,7 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Create document with embedded metadata
|
# Create document with embedded metadata
|
||||||
doc = Document(
|
doc = Document(text=doc_content, metadata={})
|
||||||
text=doc_content, metadata={"contact_name": contact_name}
|
|
||||||
)
|
|
||||||
docs.append(doc)
|
docs.append(doc)
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
@@ -551,7 +512,7 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
return docs
|
return docs
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def find_wechat_export_dirs() -> list[Path]:
|
def find_wechat_export_dirs() -> List[Path]:
|
||||||
"""
|
"""
|
||||||
Find all WeChat export directories.
|
Find all WeChat export directories.
|
||||||
|
|
||||||
@@ -562,10 +523,10 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
|
|
||||||
# Look for common export directory names
|
# Look for common export directory names
|
||||||
possible_dirs = [
|
possible_dirs = [
|
||||||
|
Path("./wechat_export_test"),
|
||||||
Path("./wechat_export"),
|
Path("./wechat_export"),
|
||||||
Path("./wechat_export_direct"),
|
|
||||||
Path("./wechat_chat_history"),
|
Path("./wechat_chat_history"),
|
||||||
Path("./chat_export"),
|
Path("./chat_export")
|
||||||
]
|
]
|
||||||
|
|
||||||
for export_dir in possible_dirs:
|
for export_dir in possible_dirs:
|
||||||
@@ -573,20 +534,13 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
json_files = list(export_dir.glob("*.json"))
|
json_files = list(export_dir.glob("*.json"))
|
||||||
if json_files:
|
if json_files:
|
||||||
export_dirs.append(export_dir)
|
export_dirs.append(export_dir)
|
||||||
print(
|
print(f"Found WeChat export directory: {export_dir} with {len(json_files)} files")
|
||||||
f"Found WeChat export directory: {export_dir} with {len(json_files)} files"
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Found {len(export_dirs)} WeChat export directories")
|
print(f"Found {len(export_dirs)} WeChat export directories")
|
||||||
return export_dirs
|
return export_dirs
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def export_chat_to_file(
|
def export_chat_to_file(output_file: str = "wechat_chat_export.txt", max_count: int = 1000, export_dir: str = None, include_non_text: bool = False):
|
||||||
output_file: str = "wechat_chat_export.txt",
|
|
||||||
max_count: int = 1000,
|
|
||||||
export_dir: str | None = None,
|
|
||||||
include_non_text: bool = False,
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Export WeChat chat history to a text file.
|
Export WeChat chat history to a text file.
|
||||||
|
|
||||||
@@ -606,14 +560,14 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
try:
|
try:
|
||||||
json_files = list(Path(export_dir).glob("*.json"))
|
json_files = list(Path(export_dir).glob("*.json"))
|
||||||
|
|
||||||
with open(output_file, "w", encoding="utf-8") as f:
|
with open(output_file, 'w', encoding='utf-8') as f:
|
||||||
count = 0
|
count = 0
|
||||||
for json_file in json_files:
|
for json_file in json_files:
|
||||||
if count >= max_count and max_count > 0:
|
if count >= max_count and max_count > 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(json_file, encoding="utf-8") as json_f:
|
with open(json_file, 'r', encoding='utf-8') as json_f:
|
||||||
chat_data = json.load(json_f)
|
chat_data = json.load(json_f)
|
||||||
|
|
||||||
contact_name = json_file.stem
|
contact_name = json_file.stem
|
||||||
@@ -623,10 +577,10 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
if count >= max_count and max_count > 0:
|
if count >= max_count and max_count > 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
from_user = message.get("fromUser", "")
|
from_user = message.get('fromUser', '')
|
||||||
content = message.get("content", "")
|
content = message.get('content', '')
|
||||||
message_text = message.get("message", "")
|
message_text = message.get('message', '')
|
||||||
create_time = message.get("createTime", 0)
|
create_time = message.get('createTime', 0)
|
||||||
|
|
||||||
# Skip non-text messages unless requested
|
# Skip non-text messages unless requested
|
||||||
if not include_non_text:
|
if not include_non_text:
|
||||||
@@ -641,8 +595,8 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
if create_time:
|
if create_time:
|
||||||
try:
|
try:
|
||||||
timestamp = datetime.fromtimestamp(create_time)
|
timestamp = datetime.fromtimestamp(create_time)
|
||||||
time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||||
except (ValueError, OSError):
|
except:
|
||||||
time_str = str(create_time)
|
time_str = str(create_time)
|
||||||
else:
|
else:
|
||||||
time_str = "Unknown"
|
time_str = "Unknown"
|
||||||
@@ -659,7 +613,7 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error exporting WeChat chat history: {e}")
|
print(f"Error exporting WeChat chat history: {e}")
|
||||||
|
|
||||||
def export_wechat_chat_history(self, export_dir: str = "./wechat_export_direct") -> Path | None:
|
def export_wechat_chat_history(self, export_dir: str = "./wechat_export_direct") -> Optional[Path]:
|
||||||
"""
|
"""
|
||||||
Export WeChat chat history using wechat-exporter tool.
|
Export WeChat chat history using wechat-exporter tool.
|
||||||
|
|
||||||
@@ -688,21 +642,16 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
requirements_file = self.wechat_exporter_dir / "requirements.txt"
|
requirements_file = self.wechat_exporter_dir / "requirements.txt"
|
||||||
if requirements_file.exists():
|
if requirements_file.exists():
|
||||||
print("Installing wechat-exporter requirements...")
|
print("Installing wechat-exporter requirements...")
|
||||||
subprocess.run(["uv", "pip", "install", "-r", str(requirements_file)], check=True)
|
subprocess.run([
|
||||||
|
"uv", "pip", "install", "-r", str(requirements_file)
|
||||||
|
], check=True)
|
||||||
|
|
||||||
# Run the export command
|
# Run the export command
|
||||||
print("Running wechat-exporter...")
|
print("Running wechat-exporter...")
|
||||||
result = subprocess.run(
|
result = subprocess.run([
|
||||||
[
|
sys.executable, str(self.wechat_exporter_dir / "main.py"),
|
||||||
sys.executable,
|
"export-all", str(export_path)
|
||||||
str(self.wechat_exporter_dir / "main.py"),
|
], capture_output=True, text=True, check=True)
|
||||||
"export-all",
|
|
||||||
str(export_path),
|
|
||||||
],
|
|
||||||
capture_output=True,
|
|
||||||
text=True,
|
|
||||||
check=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
print("Export command output:")
|
print("Export command output:")
|
||||||
print(result.stdout)
|
print(result.stdout)
|
||||||
@@ -713,9 +662,7 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
# Check if export was successful
|
# Check if export was successful
|
||||||
if export_path.exists() and any(export_path.glob("*.json")):
|
if export_path.exists() and any(export_path.glob("*.json")):
|
||||||
json_files = list(export_path.glob("*.json"))
|
json_files = list(export_path.glob("*.json"))
|
||||||
print(
|
print(f"Successfully exported {len(json_files)} chat history files to {export_path}")
|
||||||
f"Successfully exported {len(json_files)} chat history files to {export_path}"
|
|
||||||
)
|
|
||||||
return export_path
|
return export_path
|
||||||
else:
|
else:
|
||||||
print("Export completed but no JSON files found")
|
print("Export completed but no JSON files found")
|
||||||
@@ -731,7 +678,7 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
print("Please ensure WeChat is running and WeChatTweak is installed.")
|
print("Please ensure WeChat is running and WeChatTweak is installed.")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def find_or_export_wechat_data(self, export_dir: str = "./wechat_export_direct") -> list[Path]:
|
def find_or_export_wechat_data(self, export_dir: str = "./wechat_export_direct") -> List[Path]:
|
||||||
"""
|
"""
|
||||||
Find existing WeChat exports or create new ones.
|
Find existing WeChat exports or create new ones.
|
||||||
|
|
||||||
@@ -750,7 +697,7 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
Path("./wechat_export"),
|
Path("./wechat_export"),
|
||||||
Path("./wechat_export_direct"),
|
Path("./wechat_export_direct"),
|
||||||
Path("./wechat_chat_history"),
|
Path("./wechat_chat_history"),
|
||||||
Path("./chat_export"),
|
Path("./chat_export")
|
||||||
]
|
]
|
||||||
|
|
||||||
for export_dir_path in possible_export_dirs:
|
for export_dir_path in possible_export_dirs:
|
||||||
@@ -767,8 +714,6 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
if exported_path:
|
if exported_path:
|
||||||
export_dirs = [exported_path]
|
export_dirs = [exported_path]
|
||||||
else:
|
else:
|
||||||
print(
|
print("Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.")
|
||||||
"Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed."
|
|
||||||
)
|
|
||||||
|
|
||||||
return export_dirs
|
return export_dirs
|
||||||
288
examples/mail_reader_leann.py
Normal file
288
examples/mail_reader_leann.py
Normal file
@@ -0,0 +1,288 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import asyncio
|
||||||
|
import dotenv
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Any
|
||||||
|
|
||||||
|
# Add the project root to Python path so we can import from examples
|
||||||
|
project_root = Path(__file__).parent.parent
|
||||||
|
sys.path.insert(0, str(project_root))
|
||||||
|
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
|
# Auto-detect user's mail path
|
||||||
|
def get_mail_path():
|
||||||
|
"""Get the mail path for the current user"""
|
||||||
|
home_dir = os.path.expanduser("~")
|
||||||
|
return os.path.join(home_dir, "Library", "Mail")
|
||||||
|
|
||||||
|
# Default mail path for macOS
|
||||||
|
# DEFAULT_MAIL_PATH = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data"
|
||||||
|
|
||||||
|
def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_path: str = "mail_index.leann", max_count: int = -1, include_html: bool = False, embedding_model: str = "facebook/contriever"):
|
||||||
|
"""
|
||||||
|
Create LEANN index from multiple mail data sources.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages_dirs: List of Path objects pointing to Messages directories
|
||||||
|
index_path: Path to save the LEANN index
|
||||||
|
max_count: Maximum number of emails to process per directory
|
||||||
|
include_html: Whether to include HTML content in email processing
|
||||||
|
"""
|
||||||
|
print("Creating LEANN index from multiple mail data sources...")
|
||||||
|
|
||||||
|
# Load documents using EmlxReader from LEANN_email_reader
|
||||||
|
from examples.email_data.LEANN_email_reader import EmlxReader
|
||||||
|
reader = EmlxReader(include_html=include_html)
|
||||||
|
# from email_data.email import EmlxMboxReader
|
||||||
|
# from pathlib import Path
|
||||||
|
# reader = EmlxMboxReader()
|
||||||
|
INDEX_DIR = Path(index_path).parent
|
||||||
|
|
||||||
|
if not INDEX_DIR.exists():
|
||||||
|
print(f"--- Index directory not found, building new index ---")
|
||||||
|
all_documents = []
|
||||||
|
total_processed = 0
|
||||||
|
|
||||||
|
# Process each Messages directory
|
||||||
|
for i, messages_dir in enumerate(messages_dirs):
|
||||||
|
print(f"\nProcessing Messages directory {i+1}/{len(messages_dirs)}: {messages_dir}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
documents = reader.load_data(messages_dir)
|
||||||
|
if documents:
|
||||||
|
print(f"Loaded {len(documents)} email documents from {messages_dir}")
|
||||||
|
all_documents.extend(documents)
|
||||||
|
total_processed += len(documents)
|
||||||
|
|
||||||
|
# Check if we've reached the max count
|
||||||
|
if max_count > 0 and total_processed >= max_count:
|
||||||
|
print(f"Reached max count of {max_count} documents")
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
print(f"No documents loaded from {messages_dir}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing {messages_dir}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not all_documents:
|
||||||
|
print("No documents loaded from any source. Exiting.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
print(f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories and starting to split them into chunks")
|
||||||
|
|
||||||
|
# Create text splitter with 256 chunk size
|
||||||
|
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
|
||||||
|
|
||||||
|
# Convert Documents to text strings and chunk them
|
||||||
|
all_texts = []
|
||||||
|
for doc in all_documents:
|
||||||
|
# Split the document into chunks
|
||||||
|
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||||
|
for node in nodes:
|
||||||
|
text = node.get_content()
|
||||||
|
# text = '[subject] ' + doc.metadata["subject"] + '\n' + text
|
||||||
|
all_texts.append(text)
|
||||||
|
|
||||||
|
print(f"Finished splitting {len(all_documents)} documents into {len(all_texts)} text chunks")
|
||||||
|
|
||||||
|
# Create LEANN index directory
|
||||||
|
|
||||||
|
print(f"--- Index directory not found, building new index ---")
|
||||||
|
INDEX_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
print(f"--- Building new LEANN index ---")
|
||||||
|
|
||||||
|
print(f"\n[PHASE 1] Building Leann index...")
|
||||||
|
|
||||||
|
# Use HNSW backend for better macOS compatibility
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model=embedding_model,
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
is_compact=True,
|
||||||
|
is_recompute=True,
|
||||||
|
num_threads=1 # Force single-threaded mode
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Adding {len(all_texts)} email chunks to index...")
|
||||||
|
for chunk_text in all_texts:
|
||||||
|
builder.add_text(chunk_text)
|
||||||
|
|
||||||
|
builder.build_index(index_path)
|
||||||
|
print(f"\nLEANN index built at {index_path}!")
|
||||||
|
else:
|
||||||
|
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||||
|
|
||||||
|
return index_path
|
||||||
|
|
||||||
|
def create_leann_index(mail_path: str, index_path: str = "mail_index.leann", max_count: int = 1000, include_html: bool = False, embedding_model: str = "facebook/contriever"):
|
||||||
|
"""
|
||||||
|
Create LEANN index from mail data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mail_path: Path to the mail directory
|
||||||
|
index_path: Path to save the LEANN index
|
||||||
|
max_count: Maximum number of emails to process
|
||||||
|
include_html: Whether to include HTML content in email processing
|
||||||
|
"""
|
||||||
|
print("Creating LEANN index from mail data...")
|
||||||
|
INDEX_DIR = Path(index_path).parent
|
||||||
|
|
||||||
|
if not INDEX_DIR.exists():
|
||||||
|
print(f"--- Index directory not found, building new index ---")
|
||||||
|
INDEX_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
print(f"--- Building new LEANN index ---")
|
||||||
|
|
||||||
|
print(f"\n[PHASE 1] Building Leann index...")
|
||||||
|
|
||||||
|
# Load documents using EmlxReader from LEANN_email_reader
|
||||||
|
from examples.email_data.LEANN_email_reader import EmlxReader
|
||||||
|
reader = EmlxReader(include_html=include_html)
|
||||||
|
# from email_data.email import EmlxMboxReader
|
||||||
|
# from pathlib import Path
|
||||||
|
# reader = EmlxMboxReader()
|
||||||
|
documents = reader.load_data(Path(mail_path))
|
||||||
|
|
||||||
|
if not documents:
|
||||||
|
print("No documents loaded. Exiting.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
print(f"Loaded {len(documents)} email documents")
|
||||||
|
|
||||||
|
# Create text splitter with 256 chunk size
|
||||||
|
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||||
|
|
||||||
|
# Convert Documents to text strings and chunk them
|
||||||
|
all_texts = []
|
||||||
|
for doc in documents:
|
||||||
|
# Split the document into chunks
|
||||||
|
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||||
|
for node in nodes:
|
||||||
|
all_texts.append(node.get_content())
|
||||||
|
|
||||||
|
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
|
||||||
|
|
||||||
|
# Create LEANN index directory
|
||||||
|
|
||||||
|
print(f"--- Index directory not found, building new index ---")
|
||||||
|
INDEX_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
print(f"--- Building new LEANN index ---")
|
||||||
|
|
||||||
|
print(f"\n[PHASE 1] Building Leann index...")
|
||||||
|
|
||||||
|
# Use HNSW backend for better macOS compatibility
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model=embedding_model,
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
is_compact=True,
|
||||||
|
is_recompute=True,
|
||||||
|
num_threads=1 # Force single-threaded mode
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Adding {len(all_texts)} email chunks to index...")
|
||||||
|
for chunk_text in all_texts:
|
||||||
|
builder.add_text(chunk_text)
|
||||||
|
|
||||||
|
builder.build_index(index_path)
|
||||||
|
print(f"\nLEANN index built at {index_path}!")
|
||||||
|
else:
|
||||||
|
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||||
|
|
||||||
|
return index_path
|
||||||
|
|
||||||
|
async def query_leann_index(index_path: str, query: str):
|
||||||
|
"""
|
||||||
|
Query the LEANN index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index_path: Path to the LEANN index
|
||||||
|
query: The query string
|
||||||
|
"""
|
||||||
|
print(f"\n[PHASE 2] Starting Leann chat session...")
|
||||||
|
chat = LeannChat(index_path=index_path,
|
||||||
|
llm_config={"type": "openai", "model": "gpt-4o"})
|
||||||
|
|
||||||
|
print(f"You: {query}")
|
||||||
|
import time
|
||||||
|
start_time = time.time()
|
||||||
|
chat_response = chat.ask(
|
||||||
|
query,
|
||||||
|
top_k=10,
|
||||||
|
recompute_beighbor_embeddings=True,
|
||||||
|
complexity=12,
|
||||||
|
beam_width=1,
|
||||||
|
|
||||||
|
)
|
||||||
|
end_time = time.time()
|
||||||
|
print(f"Time taken: {end_time - start_time} seconds")
|
||||||
|
print(f"Leann: {chat_response}")
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
# Parse command line arguments
|
||||||
|
parser = argparse.ArgumentParser(description='LEANN Mail Reader - Create and query email index')
|
||||||
|
# Remove --mail-path argument and auto-detect all Messages directories
|
||||||
|
# Remove DEFAULT_MAIL_PATH
|
||||||
|
parser.add_argument('--index-dir', type=str, default="./mail_index_leann_debug",
|
||||||
|
help='Directory to store the LEANN index (default: ./mail_index_leann_raw_text_all_dicts)')
|
||||||
|
parser.add_argument('--max-emails', type=int, default=1000,
|
||||||
|
help='Maximum number of emails to process (-1 means all)')
|
||||||
|
parser.add_argument('--query', type=str, default="Give me some funny advertisement about apple or other companies",
|
||||||
|
help='Single query to run (default: runs example queries)')
|
||||||
|
parser.add_argument('--include-html', action='store_true', default=False,
|
||||||
|
help='Include HTML content in email processing (default: False)')
|
||||||
|
parser.add_argument('--embedding-model', type=str, default="facebook/contriever",
|
||||||
|
help='Embedding model to use (default: facebook/contriever)')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
print(f"args: {args}")
|
||||||
|
|
||||||
|
# Automatically find all Messages directories under the current user's Mail directory
|
||||||
|
from examples.email_data.LEANN_email_reader import find_all_messages_directories
|
||||||
|
mail_path = get_mail_path()
|
||||||
|
print(f"Searching for email data in: {mail_path}")
|
||||||
|
messages_dirs = find_all_messages_directories(mail_path)
|
||||||
|
|
||||||
|
print('len(messages_dirs): ', len(messages_dirs))
|
||||||
|
|
||||||
|
|
||||||
|
if not messages_dirs:
|
||||||
|
print("No Messages directories found. Exiting.")
|
||||||
|
return
|
||||||
|
|
||||||
|
INDEX_DIR = Path(args.index_dir)
|
||||||
|
INDEX_PATH = str(INDEX_DIR / "mail_documents.leann")
|
||||||
|
print(f"Index directory: {INDEX_DIR}")
|
||||||
|
print(f"Found {len(messages_dirs)} Messages directories.")
|
||||||
|
|
||||||
|
# Create or load the LEANN index from all sources
|
||||||
|
index_path = create_leann_index_from_multiple_sources(messages_dirs, INDEX_PATH, args.max_emails, args.include_html, args.embedding_model)
|
||||||
|
|
||||||
|
if index_path:
|
||||||
|
if args.query:
|
||||||
|
# Run single query
|
||||||
|
await query_leann_index(index_path, args.query)
|
||||||
|
else:
|
||||||
|
# Example queries
|
||||||
|
queries = [
|
||||||
|
"Hows Berkeley Graduate Student Instructor",
|
||||||
|
"how's the icloud related advertisement saying",
|
||||||
|
"Whats the number of class recommend to take per semester for incoming EECS students"
|
||||||
|
]
|
||||||
|
for query in queries:
|
||||||
|
print("\n" + "="*60)
|
||||||
|
await query_leann_index(index_path, query)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
115
examples/main_cli_example.py
Normal file
115
examples/main_cli_example.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
import argparse
|
||||||
|
from llama_index.core import SimpleDirectoryReader
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
import asyncio
|
||||||
|
import dotenv
|
||||||
|
from leann.api import LeannBuilder, LeannChat
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
async def main(args):
|
||||||
|
INDEX_DIR = Path(args.index_dir)
|
||||||
|
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
|
||||||
|
|
||||||
|
if not INDEX_DIR.exists():
|
||||||
|
node_parser = SentenceSplitter(
|
||||||
|
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Loading documents...")
|
||||||
|
documents = SimpleDirectoryReader(
|
||||||
|
args.data_dir,
|
||||||
|
recursive=True,
|
||||||
|
encoding="utf-8",
|
||||||
|
required_exts=[".pdf", ".txt", ".md"],
|
||||||
|
).load_data(show_progress=True)
|
||||||
|
print("Documents loaded.")
|
||||||
|
all_texts = []
|
||||||
|
for doc in documents:
|
||||||
|
nodes = node_parser.get_nodes_from_documents([doc])
|
||||||
|
for node in nodes:
|
||||||
|
all_texts.append(node.get_content())
|
||||||
|
|
||||||
|
print("--- Index directory not found, building new index ---")
|
||||||
|
|
||||||
|
print("\n[PHASE 1] Building Leann index...")
|
||||||
|
|
||||||
|
# Use HNSW backend for better macOS compatibility
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="facebook/contriever",
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
is_compact=True,
|
||||||
|
is_recompute=True,
|
||||||
|
num_threads=1, # Force single-threaded mode
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Loaded {len(all_texts)} text chunks from documents.")
|
||||||
|
for chunk_text in all_texts:
|
||||||
|
builder.add_text(chunk_text)
|
||||||
|
|
||||||
|
builder.build_index(INDEX_PATH)
|
||||||
|
print(f"\nLeann index built at {INDEX_PATH}!")
|
||||||
|
else:
|
||||||
|
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||||
|
|
||||||
|
print(f"\n[PHASE 2] Starting Leann chat session...")
|
||||||
|
|
||||||
|
llm_config = {"type": "hf", "model": "Qwen/Qwen3-4B"}
|
||||||
|
llm_config = {"type": "ollama", "model": "qwen3:8b"}
|
||||||
|
llm_config = {"type": "openai", "model": "gpt-4o"}
|
||||||
|
|
||||||
|
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
|
||||||
|
|
||||||
|
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
|
||||||
|
|
||||||
|
# query = (
|
||||||
|
# "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
|
||||||
|
# )
|
||||||
|
|
||||||
|
print(f"You: {query}")
|
||||||
|
chat_response = chat.ask(query, top_k=20, recompute_embeddings=True, complexity=32)
|
||||||
|
print(f"Leann: {chat_response}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Run Leann Chat with various LLM backends."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--llm",
|
||||||
|
type=str,
|
||||||
|
default="hf",
|
||||||
|
choices=["simulated", "ollama", "hf", "openai"],
|
||||||
|
help="The LLM backend to use.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model",
|
||||||
|
type=str,
|
||||||
|
default="Qwen/Qwen3-0.6B",
|
||||||
|
help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf, 'gpt-4o' for openai).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--host",
|
||||||
|
type=str,
|
||||||
|
default="http://localhost:11434",
|
||||||
|
help="The host for the Ollama API.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--index-dir",
|
||||||
|
type=str,
|
||||||
|
default="./test_doc_files",
|
||||||
|
help="Directory where the Leann index will be stored.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--data-dir",
|
||||||
|
type=str,
|
||||||
|
default="examples/data",
|
||||||
|
help="Directory containing documents to index (PDF, TXT, MD files).",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
asyncio.run(main(args))
|
||||||
@@ -5,21 +5,24 @@ It correctly compares results by fetching the text content for both the new sear
|
|||||||
results and the golden standard results, making the comparison robust to ID changes.
|
results and the golden standard results, making the comparison robust to ID changes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
|
||||||
import json
|
import json
|
||||||
import sys
|
import argparse
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import sys
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from leann.api import LeannBuilder, LeannSearcher
|
from typing import List
|
||||||
|
|
||||||
|
from leann.api import LeannSearcher, LeannBuilder
|
||||||
|
|
||||||
|
|
||||||
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
||||||
"""Checks if the data directory exists, and if not, downloads it from HF Hub."""
|
"""Checks if the data directory exists, and if not, downloads it from HF Hub."""
|
||||||
if not data_root.exists():
|
if not data_root.exists():
|
||||||
print(f"Data directory '{data_root}' not found.")
|
print(f"Data directory '{data_root}' not found.")
|
||||||
print("Downloading evaluation data from Hugging Face Hub... (this may take a moment)")
|
print(
|
||||||
|
"Downloading evaluation data from Hugging Face Hub... (this may take a moment)"
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
@@ -60,7 +63,7 @@ def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def download_embeddings_if_needed(data_root: Path, dataset_type: str | None = None):
|
def download_embeddings_if_needed(data_root: Path, dataset_type: str = None):
|
||||||
"""Download embeddings files specifically."""
|
"""Download embeddings files specifically."""
|
||||||
embeddings_dir = data_root / "embeddings"
|
embeddings_dir = data_root / "embeddings"
|
||||||
|
|
||||||
@@ -98,7 +101,7 @@ def download_embeddings_if_needed(data_root: Path, dataset_type: str | None = No
|
|||||||
|
|
||||||
|
|
||||||
# --- Helper Function to get Golden Passages ---
|
# --- Helper Function to get Golden Passages ---
|
||||||
def get_golden_texts(searcher: LeannSearcher, golden_ids: list[int]) -> set:
|
def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set:
|
||||||
"""
|
"""
|
||||||
Retrieves the text for golden passage IDs directly from the LeannSearcher's
|
Retrieves the text for golden passage IDs directly from the LeannSearcher's
|
||||||
passage manager.
|
passage manager.
|
||||||
@@ -110,20 +113,24 @@ def get_golden_texts(searcher: LeannSearcher, golden_ids: list[int]) -> set:
|
|||||||
passage_data = searcher.passage_manager.get_passage(str(gid))
|
passage_data = searcher.passage_manager.get_passage(str(gid))
|
||||||
golden_texts.add(passage_data["text"])
|
golden_texts.add(passage_data["text"])
|
||||||
except KeyError:
|
except KeyError:
|
||||||
print(f"Warning: Golden passage ID '{gid}' not found in the index's passage data.")
|
print(
|
||||||
|
f"Warning: Golden passage ID '{gid}' not found in the index's passage data."
|
||||||
|
)
|
||||||
return golden_texts
|
return golden_texts
|
||||||
|
|
||||||
|
|
||||||
def load_queries(file_path: Path) -> list[str]:
|
def load_queries(file_path: Path) -> List[str]:
|
||||||
queries = []
|
queries = []
|
||||||
with open(file_path, encoding="utf-8") as f:
|
with open(file_path, "r", encoding="utf-8") as f:
|
||||||
for line in f:
|
for line in f:
|
||||||
data = json.loads(line)
|
data = json.loads(line)
|
||||||
queries.append(data["query"])
|
queries.append(data["query"])
|
||||||
return queries
|
return queries
|
||||||
|
|
||||||
|
|
||||||
def build_index_from_embeddings(embeddings_file: str, output_path: str, backend: str = "hnsw"):
|
def build_index_from_embeddings(
|
||||||
|
embeddings_file: str, output_path: str, backend: str = "hnsw"
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Build a LEANN index from pre-computed embeddings.
|
Build a LEANN index from pre-computed embeddings.
|
||||||
|
|
||||||
@@ -166,7 +173,9 @@ def build_index_from_embeddings(embeddings_file: str, output_path: str, backend:
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description="Run recall evaluation on a LEANN index.")
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Run recall evaluation on a LEANN index."
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"index_path",
|
"index_path",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -193,22 +202,26 @@ def main():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-queries", type=int, default=10, help="Number of queries to evaluate."
|
"--num-queries", type=int, default=10, help="Number of queries to evaluate."
|
||||||
)
|
)
|
||||||
parser.add_argument("--top-k", type=int, default=3, help="The 'k' value for recall@k.")
|
parser.add_argument(
|
||||||
|
"--top-k", type=int, default=3, help="The 'k' value for recall@k."
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
|
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# --- Path Configuration ---
|
# --- Path Configuration ---
|
||||||
# Assumes a project structure where the script is in 'benchmarks/'
|
# Assumes a project structure where the script is in 'examples/'
|
||||||
# and evaluation data is in 'benchmarks/data/'.
|
# and data is in 'data/' at the project root.
|
||||||
script_dir = Path(__file__).resolve().parent
|
project_root = Path(__file__).resolve().parent.parent
|
||||||
data_root = script_dir / "data"
|
data_root = project_root / "data"
|
||||||
|
|
||||||
# Download data based on mode
|
# Download data based on mode
|
||||||
if args.mode == "build":
|
if args.mode == "build":
|
||||||
# For building mode, we need embeddings
|
# For building mode, we need embeddings
|
||||||
download_data_if_needed(data_root, download_embeddings=False) # Basic data first
|
download_data_if_needed(
|
||||||
|
data_root, download_embeddings=False
|
||||||
|
) # Basic data first
|
||||||
|
|
||||||
# Auto-detect dataset type and download embeddings
|
# Auto-detect dataset type and download embeddings
|
||||||
if args.embeddings_file:
|
if args.embeddings_file:
|
||||||
@@ -249,7 +262,9 @@ def main():
|
|||||||
print(f"Index built successfully: {built_index_path}")
|
print(f"Index built successfully: {built_index_path}")
|
||||||
|
|
||||||
# Ask if user wants to run evaluation
|
# Ask if user wants to run evaluation
|
||||||
eval_response = input("Run evaluation on the built index? (y/n): ").strip().lower()
|
eval_response = (
|
||||||
|
input("Run evaluation on the built index? (y/n): ").strip().lower()
|
||||||
|
)
|
||||||
if eval_response != "y":
|
if eval_response != "y":
|
||||||
print("Index building complete. Exiting.")
|
print("Index building complete. Exiting.")
|
||||||
return
|
return
|
||||||
@@ -278,9 +293,11 @@ def main():
|
|||||||
break
|
break
|
||||||
|
|
||||||
if not args.index_path:
|
if not args.index_path:
|
||||||
print("No indices found. The data download should have included pre-built indices.")
|
|
||||||
print(
|
print(
|
||||||
"Please check the benchmarks/data/indices/ directory or provide --index-path manually."
|
"No indices found. The data download should have included pre-built indices."
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
"Please check the data/indices/ directory or provide --index-path manually."
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
@@ -293,10 +310,14 @@ def main():
|
|||||||
else:
|
else:
|
||||||
# Fallback: try to infer from the index directory name
|
# Fallback: try to infer from the index directory name
|
||||||
dataset_type = Path(args.index_path).name
|
dataset_type = Path(args.index_path).name
|
||||||
print(f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'.")
|
print(
|
||||||
|
f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'."
|
||||||
|
)
|
||||||
|
|
||||||
queries_file = data_root / "queries" / "nq_open.jsonl"
|
queries_file = data_root / "queries" / "nq_open.jsonl"
|
||||||
golden_results_file = data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
|
golden_results_file = (
|
||||||
|
data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
|
||||||
|
)
|
||||||
|
|
||||||
print(f"INFO: Detected dataset type: {dataset_type}")
|
print(f"INFO: Detected dataset type: {dataset_type}")
|
||||||
print(f"INFO: Using queries file: {queries_file}")
|
print(f"INFO: Using queries file: {queries_file}")
|
||||||
@@ -306,7 +327,7 @@ def main():
|
|||||||
searcher = LeannSearcher(args.index_path)
|
searcher = LeannSearcher(args.index_path)
|
||||||
queries = load_queries(queries_file)
|
queries = load_queries(queries_file)
|
||||||
|
|
||||||
with open(golden_results_file) as f:
|
with open(golden_results_file, "r") as f:
|
||||||
golden_results_data = json.load(f)
|
golden_results_data = json.load(f)
|
||||||
|
|
||||||
num_eval_queries = min(args.num_queries, len(queries))
|
num_eval_queries = min(args.num_queries, len(queries))
|
||||||
@@ -318,7 +339,9 @@ def main():
|
|||||||
|
|
||||||
for i in range(num_eval_queries):
|
for i in range(num_eval_queries):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
new_results = searcher.search(queries[i], top_k=args.top_k, ef=args.ef_search)
|
new_results = searcher.search(
|
||||||
|
queries[i], top_k=args.top_k, ef=args.ef_search
|
||||||
|
)
|
||||||
search_times.append(time.time() - start_time)
|
search_times.append(time.time() - start_time)
|
||||||
|
|
||||||
# Correct Recall Calculation: Based on TEXT content
|
# Correct Recall Calculation: Based on TEXT content
|
||||||
319
examples/wechat_history_reader_leann.py
Normal file
319
examples/wechat_history_reader_leann.py
Normal file
@@ -0,0 +1,319 @@
|
|||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
import dotenv
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Any, Optional
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
import requests
|
||||||
|
import time
|
||||||
|
|
||||||
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
|
# Default WeChat export directory
|
||||||
|
DEFAULT_WECHAT_EXPORT_DIR = "./wechat_export_direct"
|
||||||
|
|
||||||
|
|
||||||
|
def create_leann_index_from_multiple_wechat_exports(
|
||||||
|
export_dirs: List[Path],
|
||||||
|
index_path: str = "wechat_history_index.leann",
|
||||||
|
max_count: int = -1,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create LEANN index from multiple WeChat export data sources.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
export_dirs: List of Path objects pointing to WeChat export directories
|
||||||
|
index_path: Path to save the LEANN index
|
||||||
|
max_count: Maximum number of chat entries to process per export
|
||||||
|
"""
|
||||||
|
print("Creating LEANN index from multiple WeChat export data sources...")
|
||||||
|
|
||||||
|
# Load documents using WeChatHistoryReader from history_data
|
||||||
|
from history_data.wechat_history import WeChatHistoryReader
|
||||||
|
|
||||||
|
reader = WeChatHistoryReader()
|
||||||
|
|
||||||
|
INDEX_DIR = Path(index_path).parent
|
||||||
|
|
||||||
|
if not INDEX_DIR.exists():
|
||||||
|
print(f"--- Index directory not found, building new index ---")
|
||||||
|
all_documents = []
|
||||||
|
total_processed = 0
|
||||||
|
|
||||||
|
# Process each WeChat export directory
|
||||||
|
for i, export_dir in enumerate(export_dirs):
|
||||||
|
print(
|
||||||
|
f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
documents = reader.load_data(
|
||||||
|
wechat_export_dir=str(export_dir),
|
||||||
|
max_count=max_count,
|
||||||
|
concatenate_messages=True, # Disable concatenation - one message per document
|
||||||
|
)
|
||||||
|
if documents:
|
||||||
|
print(f"Loaded {len(documents)} chat documents from {export_dir}")
|
||||||
|
all_documents.extend(documents)
|
||||||
|
total_processed += len(documents)
|
||||||
|
|
||||||
|
# Check if we've reached the max count
|
||||||
|
if max_count > 0 and total_processed >= max_count:
|
||||||
|
print(f"Reached max count of {max_count} documents")
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
print(f"No documents loaded from {export_dir}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing {export_dir}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not all_documents:
|
||||||
|
print("No documents loaded from any source. Exiting.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports and starting to split them into chunks"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create text splitter with 256 chunk size
|
||||||
|
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
|
||||||
|
|
||||||
|
# Convert Documents to text strings and chunk them
|
||||||
|
all_texts = []
|
||||||
|
for doc in all_documents:
|
||||||
|
# Split the document into chunks
|
||||||
|
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||||
|
for node in nodes:
|
||||||
|
text = '[Contact] means the message is from: ' + doc.metadata["contact_name"] + '\n' + node.get_content()
|
||||||
|
all_texts.append(text)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Finished splitting {len(all_documents)} documents into {len(all_texts)} text chunks"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create LEANN index directory
|
||||||
|
print(f"--- Index directory not found, building new index ---")
|
||||||
|
INDEX_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
print(f"--- Building new LEANN index ---")
|
||||||
|
|
||||||
|
print(f"\n[PHASE 1] Building Leann index...")
|
||||||
|
|
||||||
|
# Use HNSW backend for better macOS compatibility
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="Qwen/Qwen3-Embedding-0.6B",
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
is_compact=True,
|
||||||
|
is_recompute=True,
|
||||||
|
num_threads=1, # Force single-threaded mode
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Adding {len(all_texts)} chat chunks to index...")
|
||||||
|
for chunk_text in all_texts:
|
||||||
|
builder.add_text(chunk_text)
|
||||||
|
|
||||||
|
builder.build_index(index_path)
|
||||||
|
print(f"\nLEANN index built at {index_path}!")
|
||||||
|
else:
|
||||||
|
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||||
|
|
||||||
|
return index_path
|
||||||
|
|
||||||
|
|
||||||
|
def create_leann_index(
|
||||||
|
export_dir: str = None,
|
||||||
|
index_path: str = "wechat_history_index.leann",
|
||||||
|
max_count: int = 1000,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create LEANN index from WeChat chat history data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
export_dir: Path to the WeChat export directory (optional, uses default if None)
|
||||||
|
index_path: Path to save the LEANN index
|
||||||
|
max_count: Maximum number of chat entries to process
|
||||||
|
"""
|
||||||
|
print("Creating LEANN index from WeChat chat history data...")
|
||||||
|
INDEX_DIR = Path(index_path).parent
|
||||||
|
|
||||||
|
if not INDEX_DIR.exists():
|
||||||
|
print(f"--- Index directory not found, building new index ---")
|
||||||
|
INDEX_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
print(f"--- Building new LEANN index ---")
|
||||||
|
|
||||||
|
print(f"\n[PHASE 1] Building Leann index...")
|
||||||
|
|
||||||
|
# Load documents using WeChatHistoryReader from history_data
|
||||||
|
from history_data.wechat_history import WeChatHistoryReader
|
||||||
|
|
||||||
|
reader = WeChatHistoryReader()
|
||||||
|
|
||||||
|
documents = reader.load_data(
|
||||||
|
wechat_export_dir=export_dir,
|
||||||
|
max_count=max_count,
|
||||||
|
concatenate_messages=False, # Disable concatenation - one message per document
|
||||||
|
)
|
||||||
|
|
||||||
|
if not documents:
|
||||||
|
print("No documents loaded. Exiting.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
print(f"Loaded {len(documents)} chat documents")
|
||||||
|
|
||||||
|
# Create text splitter with 256 chunk size
|
||||||
|
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||||
|
|
||||||
|
# Convert Documents to text strings and chunk them
|
||||||
|
all_texts = []
|
||||||
|
for doc in documents:
|
||||||
|
# Split the document into chunks
|
||||||
|
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||||
|
for node in nodes:
|
||||||
|
all_texts.append(node.get_content())
|
||||||
|
|
||||||
|
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
|
||||||
|
|
||||||
|
# Create LEANN index directory
|
||||||
|
print(f"--- Index directory not found, building new index ---")
|
||||||
|
INDEX_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
print(f"--- Building new LEANN index ---")
|
||||||
|
|
||||||
|
print(f"\n[PHASE 1] Building Leann index...")
|
||||||
|
|
||||||
|
# Use HNSW backend for better macOS compatibility
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ", # MLX-optimized model
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
is_compact=True,
|
||||||
|
is_recompute=True,
|
||||||
|
num_threads=1, # Force single-threaded mode
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Adding {len(all_texts)} chat chunks to index...")
|
||||||
|
for chunk_text in all_texts:
|
||||||
|
builder.add_text(chunk_text)
|
||||||
|
|
||||||
|
builder.build_index(index_path)
|
||||||
|
print(f"\nLEANN index built at {index_path}!")
|
||||||
|
else:
|
||||||
|
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||||
|
|
||||||
|
return index_path
|
||||||
|
|
||||||
|
|
||||||
|
async def query_leann_index(index_path: str, query: str):
|
||||||
|
"""
|
||||||
|
Query the LEANN index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index_path: Path to the LEANN index
|
||||||
|
query: The query string
|
||||||
|
"""
|
||||||
|
print(f"\n[PHASE 2] Starting Leann chat session...")
|
||||||
|
chat = LeannChat(index_path=index_path)
|
||||||
|
|
||||||
|
print(f"You: {query}")
|
||||||
|
chat_response = chat.ask(
|
||||||
|
query,
|
||||||
|
top_k=20,
|
||||||
|
recompute_beighbor_embeddings=True,
|
||||||
|
complexity=16,
|
||||||
|
beam_width=1,
|
||||||
|
llm_config={
|
||||||
|
"type": "openai",
|
||||||
|
"model": "gpt-4o",
|
||||||
|
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||||
|
},
|
||||||
|
llm_kwargs={"temperature": 0.0, "max_tokens": 1000},
|
||||||
|
)
|
||||||
|
print(f"Leann: {chat_response}")
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""Main function with integrated WeChat export functionality."""
|
||||||
|
|
||||||
|
# Parse command line arguments
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="LEANN WeChat History Reader - Create and query WeChat chat history index"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--export-dir",
|
||||||
|
type=str,
|
||||||
|
default=DEFAULT_WECHAT_EXPORT_DIR,
|
||||||
|
help=f"Directory to store WeChat exports (default: {DEFAULT_WECHAT_EXPORT_DIR})",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--index-dir",
|
||||||
|
type=str,
|
||||||
|
default="./wechat_history_magic_test_11Debug_new",
|
||||||
|
help="Directory to store the LEANN index (default: ./wechat_history_index_leann_test)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-entries",
|
||||||
|
type=int,
|
||||||
|
default=50,
|
||||||
|
help="Maximum number of chat entries to process (default: 5000)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--query",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Single query to run (default: runs example queries)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--force-export",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Force re-export of WeChat data even if exports exist",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
INDEX_DIR = Path(args.index_dir)
|
||||||
|
INDEX_PATH = str(INDEX_DIR / "wechat_history.leann")
|
||||||
|
|
||||||
|
print(f"Using WeChat export directory: {args.export_dir}")
|
||||||
|
print(f"Index directory: {INDEX_DIR}")
|
||||||
|
print(f"Max entries: {args.max_entries}")
|
||||||
|
|
||||||
|
# Initialize WeChat reader with export capabilities
|
||||||
|
from history_data.wechat_history import WeChatHistoryReader
|
||||||
|
|
||||||
|
reader = WeChatHistoryReader()
|
||||||
|
|
||||||
|
# Find existing exports or create new ones using the centralized method
|
||||||
|
export_dirs = reader.find_or_export_wechat_data(args.export_dir)
|
||||||
|
if not export_dirs:
|
||||||
|
print("Failed to find or export WeChat data. Exiting.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Create or load the LEANN index from all sources
|
||||||
|
index_path = create_leann_index_from_multiple_wechat_exports(
|
||||||
|
export_dirs, INDEX_PATH, max_count=args.max_entries
|
||||||
|
)
|
||||||
|
|
||||||
|
if index_path:
|
||||||
|
if args.query:
|
||||||
|
# Run single query
|
||||||
|
await query_leann_index(index_path, args.query)
|
||||||
|
else:
|
||||||
|
# Example queries
|
||||||
|
queries = [
|
||||||
|
"我想买魔术师约翰逊的球衣,给我一些对应聊天记录?",
|
||||||
|
]
|
||||||
|
|
||||||
|
for query in queries:
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
await query_leann_index(index_path, query)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
|
||||||
|
|||||||
8
packages/leann-backend-diskann/CMakeLists.txt
Normal file
8
packages/leann-backend-diskann/CMakeLists.txt
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
# packages/leann-backend-diskann/CMakeLists.txt (simplified version)
|
||||||
|
|
||||||
|
cmake_minimum_required(VERSION 3.20)
|
||||||
|
project(leann_backend_diskann_wrapper)
|
||||||
|
|
||||||
|
# Tell CMake to directly enter the DiskANN submodule and execute its own CMakeLists.txt
|
||||||
|
# DiskANN will handle everything itself, including compiling Python bindings
|
||||||
|
add_subdirectory(src/third_party/DiskANN)
|
||||||
@@ -1,7 +1 @@
|
|||||||
from . import diskann_backend as diskann_backend
|
from . import diskann_backend
|
||||||
from . import graph_partition
|
|
||||||
|
|
||||||
# Export main classes and functions
|
|
||||||
from .graph_partition import GraphPartitioner, partition_graph
|
|
||||||
|
|
||||||
__all__ = ["GraphPartitioner", "diskann_backend", "graph_partition", "partition_graph"]
|
|
||||||
@@ -1,20 +1,20 @@
|
|||||||
import contextlib
|
import numpy as np
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import struct
|
import struct
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Literal, Optional
|
from typing import Dict, Any, List, Literal, Optional
|
||||||
|
import contextlib
|
||||||
|
|
||||||
import numpy as np
|
import logging
|
||||||
import psutil
|
|
||||||
|
from leann.searcher_base import BaseSearcher
|
||||||
|
from leann.registry import register_backend
|
||||||
from leann.interface import (
|
from leann.interface import (
|
||||||
LeannBackendBuilderInterface,
|
|
||||||
LeannBackendFactoryInterface,
|
LeannBackendFactoryInterface,
|
||||||
|
LeannBackendBuilderInterface,
|
||||||
LeannBackendSearcherInterface,
|
LeannBackendSearcherInterface,
|
||||||
)
|
)
|
||||||
from leann.registry import register_backend
|
|
||||||
from leann.searcher_base import BaseSearcher
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -22,11 +22,6 @@ logger = logging.getLogger(__name__)
|
|||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def suppress_cpp_output_if_needed():
|
def suppress_cpp_output_if_needed():
|
||||||
"""Suppress C++ stdout/stderr based on LEANN_LOG_LEVEL"""
|
"""Suppress C++ stdout/stderr based on LEANN_LOG_LEVEL"""
|
||||||
# In CI we avoid fiddling with low-level file descriptors to prevent aborts
|
|
||||||
if os.getenv("CI") == "true":
|
|
||||||
yield
|
|
||||||
return
|
|
||||||
|
|
||||||
log_level = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
log_level = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||||
|
|
||||||
# Only suppress if log level is WARNING or higher (ERROR, CRITICAL)
|
# Only suppress if log level is WARNING or higher (ERROR, CRITICAL)
|
||||||
@@ -90,43 +85,6 @@ def _write_vectors_to_bin(data: np.ndarray, file_path: Path):
|
|||||||
f.write(data.tobytes())
|
f.write(data.tobytes())
|
||||||
|
|
||||||
|
|
||||||
def _calculate_smart_memory_config(data: np.ndarray) -> tuple[float, float]:
|
|
||||||
"""
|
|
||||||
Calculate smart memory configuration for DiskANN based on data size and system specs.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data: The embedding data array
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple: (search_memory_maximum, build_memory_maximum) in GB
|
|
||||||
"""
|
|
||||||
num_vectors, dim = data.shape
|
|
||||||
|
|
||||||
# Calculate embedding storage size
|
|
||||||
embedding_size_bytes = num_vectors * dim * 4 # float32 = 4 bytes
|
|
||||||
embedding_size_gb = embedding_size_bytes / (1024**3)
|
|
||||||
|
|
||||||
# search_memory_maximum: 1/10 of embedding size for optimal PQ compression
|
|
||||||
# This controls Product Quantization size - smaller means more compression
|
|
||||||
search_memory_gb = max(0.1, embedding_size_gb / 10) # At least 100MB
|
|
||||||
|
|
||||||
# build_memory_maximum: Based on available system RAM for sharding control
|
|
||||||
# This controls how much memory DiskANN uses during index construction
|
|
||||||
available_memory_gb = psutil.virtual_memory().available / (1024**3)
|
|
||||||
total_memory_gb = psutil.virtual_memory().total / (1024**3)
|
|
||||||
|
|
||||||
# Use 50% of available memory, but at least 2GB and at most 75% of total
|
|
||||||
build_memory_gb = max(2.0, min(available_memory_gb * 0.5, total_memory_gb * 0.75))
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Smart memory config - Data: {embedding_size_gb:.2f}GB, "
|
|
||||||
f"Search mem: {search_memory_gb:.2f}GB (PQ control), "
|
|
||||||
f"Build mem: {build_memory_gb:.2f}GB (sharding control)"
|
|
||||||
)
|
|
||||||
|
|
||||||
return search_memory_gb, build_memory_gb
|
|
||||||
|
|
||||||
|
|
||||||
@register_backend("diskann")
|
@register_backend("diskann")
|
||||||
class DiskannBackend(LeannBackendFactoryInterface):
|
class DiskannBackend(LeannBackendFactoryInterface):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -142,72 +100,7 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
|||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.build_params = kwargs
|
self.build_params = kwargs
|
||||||
|
|
||||||
def _safe_cleanup_after_partition(self, index_dir: Path, index_prefix: str):
|
def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
|
||||||
"""
|
|
||||||
Safely cleanup files after partition.
|
|
||||||
In partition mode, C++ doesn't read _disk.index content,
|
|
||||||
so we can delete it if all derived files exist.
|
|
||||||
"""
|
|
||||||
disk_index_file = index_dir / f"{index_prefix}_disk.index"
|
|
||||||
beam_search_file = index_dir / f"{index_prefix}_disk_beam_search.index"
|
|
||||||
|
|
||||||
# Required files that C++ partition mode needs
|
|
||||||
# Note: C++ generates these with _disk.index suffix
|
|
||||||
disk_suffix = "_disk.index"
|
|
||||||
required_files = [
|
|
||||||
f"{index_prefix}{disk_suffix}_medoids.bin", # Critical: assert fails if missing
|
|
||||||
# Note: _centroids.bin is not created in single-shot build - C++ handles this automatically
|
|
||||||
f"{index_prefix}_pq_pivots.bin", # PQ table
|
|
||||||
f"{index_prefix}_pq_compressed.bin", # PQ compressed vectors
|
|
||||||
]
|
|
||||||
|
|
||||||
# Check if all required files exist
|
|
||||||
missing_files = []
|
|
||||||
for filename in required_files:
|
|
||||||
file_path = index_dir / filename
|
|
||||||
if not file_path.exists():
|
|
||||||
missing_files.append(filename)
|
|
||||||
|
|
||||||
if missing_files:
|
|
||||||
logger.warning(
|
|
||||||
f"Cannot safely delete _disk.index - missing required files: {missing_files}"
|
|
||||||
)
|
|
||||||
logger.info("Keeping all original files for safety")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Calculate space savings
|
|
||||||
space_saved = 0
|
|
||||||
files_to_delete = []
|
|
||||||
|
|
||||||
if disk_index_file.exists():
|
|
||||||
space_saved += disk_index_file.stat().st_size
|
|
||||||
files_to_delete.append(disk_index_file)
|
|
||||||
|
|
||||||
if beam_search_file.exists():
|
|
||||||
space_saved += beam_search_file.stat().st_size
|
|
||||||
files_to_delete.append(beam_search_file)
|
|
||||||
|
|
||||||
# Safe to delete!
|
|
||||||
for file_to_delete in files_to_delete:
|
|
||||||
try:
|
|
||||||
os.remove(file_to_delete)
|
|
||||||
logger.info(f"✅ Safely deleted: {file_to_delete.name}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to delete {file_to_delete.name}: {e}")
|
|
||||||
|
|
||||||
if space_saved > 0:
|
|
||||||
space_saved_mb = space_saved / (1024 * 1024)
|
|
||||||
logger.info(f"💾 Space saved: {space_saved_mb:.1f} MB")
|
|
||||||
|
|
||||||
# Show what files are kept
|
|
||||||
logger.info("📁 Kept essential files for partition mode:")
|
|
||||||
for filename in required_files:
|
|
||||||
file_path = index_dir / filename
|
|
||||||
if file_path.exists():
|
|
||||||
size_mb = file_path.stat().st_size / (1024 * 1024)
|
|
||||||
logger.info(f" - {filename} ({size_mb:.1f} MB)")
|
|
||||||
|
|
||||||
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
|
|
||||||
path = Path(index_path)
|
path = Path(index_path)
|
||||||
index_dir = path.parent
|
index_dir = path.parent
|
||||||
index_prefix = path.stem
|
index_prefix = path.stem
|
||||||
@@ -221,17 +114,6 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
|||||||
_write_vectors_to_bin(data, index_dir / data_filename)
|
_write_vectors_to_bin(data, index_dir / data_filename)
|
||||||
|
|
||||||
build_kwargs = {**self.build_params, **kwargs}
|
build_kwargs = {**self.build_params, **kwargs}
|
||||||
|
|
||||||
# Extract is_recompute from nested backend_kwargs if needed
|
|
||||||
is_recompute = build_kwargs.get("is_recompute", False)
|
|
||||||
if not is_recompute and "backend_kwargs" in build_kwargs:
|
|
||||||
is_recompute = build_kwargs["backend_kwargs"].get("is_recompute", False)
|
|
||||||
|
|
||||||
# Flatten all backend_kwargs parameters to top level for compatibility
|
|
||||||
if "backend_kwargs" in build_kwargs:
|
|
||||||
nested_params = build_kwargs.pop("backend_kwargs")
|
|
||||||
build_kwargs.update(nested_params)
|
|
||||||
|
|
||||||
metric_enum = _get_diskann_metrics().get(
|
metric_enum = _get_diskann_metrics().get(
|
||||||
build_kwargs.get("distance_metric", "mips").lower()
|
build_kwargs.get("distance_metric", "mips").lower()
|
||||||
)
|
)
|
||||||
@@ -240,16 +122,6 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
|||||||
f"Unsupported distance_metric '{build_kwargs.get('distance_metric', 'unknown')}'."
|
f"Unsupported distance_metric '{build_kwargs.get('distance_metric', 'unknown')}'."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Calculate smart memory configuration if not explicitly provided
|
|
||||||
if (
|
|
||||||
"search_memory_maximum" not in build_kwargs
|
|
||||||
or "build_memory_maximum" not in build_kwargs
|
|
||||||
):
|
|
||||||
smart_search_mem, smart_build_mem = _calculate_smart_memory_config(data)
|
|
||||||
else:
|
|
||||||
smart_search_mem = build_kwargs.get("search_memory_maximum", 4.0)
|
|
||||||
smart_build_mem = build_kwargs.get("build_memory_maximum", 8.0)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from . import _diskannpy as diskannpy # type: ignore
|
from . import _diskannpy as diskannpy # type: ignore
|
||||||
|
|
||||||
@@ -260,36 +132,12 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
|||||||
index_prefix,
|
index_prefix,
|
||||||
build_kwargs.get("complexity", 64),
|
build_kwargs.get("complexity", 64),
|
||||||
build_kwargs.get("graph_degree", 32),
|
build_kwargs.get("graph_degree", 32),
|
||||||
build_kwargs.get("search_memory_maximum", smart_search_mem),
|
build_kwargs.get("search_memory_maximum", 4.0),
|
||||||
build_kwargs.get("build_memory_maximum", smart_build_mem),
|
build_kwargs.get("build_memory_maximum", 8.0),
|
||||||
build_kwargs.get("num_threads", 8),
|
build_kwargs.get("num_threads", 8),
|
||||||
build_kwargs.get("pq_disk_bytes", 0),
|
build_kwargs.get("pq_disk_bytes", 0),
|
||||||
"",
|
"",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Auto-partition if is_recompute is enabled
|
|
||||||
if build_kwargs.get("is_recompute", False):
|
|
||||||
logger.info("is_recompute=True, starting automatic graph partitioning...")
|
|
||||||
from .graph_partition import partition_graph
|
|
||||||
|
|
||||||
# Partition the index using absolute paths
|
|
||||||
# Convert to absolute paths to avoid issues with working directory changes
|
|
||||||
absolute_index_dir = Path(index_dir).resolve()
|
|
||||||
absolute_index_prefix_path = str(absolute_index_dir / index_prefix)
|
|
||||||
disk_graph_path, partition_bin_path = partition_graph(
|
|
||||||
index_prefix_path=absolute_index_prefix_path,
|
|
||||||
output_dir=str(absolute_index_dir),
|
|
||||||
partition_prefix=index_prefix,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Safe cleanup: In partition mode, C++ doesn't read _disk.index content
|
|
||||||
# but still needs the derived files (_medoids.bin, _centroids.bin, etc.)
|
|
||||||
self._safe_cleanup_after_partition(index_dir, index_prefix)
|
|
||||||
|
|
||||||
logger.info("✅ Graph partitioning completed successfully!")
|
|
||||||
logger.info(f" - Disk graph: {disk_graph_path}")
|
|
||||||
logger.info(f" - Partition file: {partition_bin_path}")
|
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
temp_data_file = index_dir / data_filename
|
temp_data_file = index_dir / data_filename
|
||||||
if temp_data_file.exists():
|
if temp_data_file.exists():
|
||||||
@@ -316,69 +164,18 @@ class DiskannSearcher(BaseSearcher):
|
|||||||
|
|
||||||
self.num_threads = kwargs.get("num_threads", 8)
|
self.num_threads = kwargs.get("num_threads", 8)
|
||||||
|
|
||||||
# For DiskANN, we need to reinitialize the index when zmq_port changes
|
fake_zmq_port = 6666
|
||||||
# Store the initialization parameters for later use
|
full_index_prefix = str(self.index_dir / self.index_path.stem)
|
||||||
# Note: C++ load method expects the BASE path (without _disk.index suffix)
|
self._index = diskannpy.StaticDiskFloatIndex(
|
||||||
# C++ internally constructs: index_prefix + "_disk.index"
|
metric_enum,
|
||||||
index_name = self.index_path.stem # "simple_test.leann" -> "simple_test"
|
full_index_prefix,
|
||||||
diskann_index_prefix = str(self.index_dir / index_name) # /path/to/simple_test
|
self.num_threads,
|
||||||
full_index_prefix = diskann_index_prefix # /path/to/simple_test (base path)
|
kwargs.get("num_nodes_to_cache", 0),
|
||||||
|
1,
|
||||||
# Auto-detect partition files and set partition_prefix
|
fake_zmq_port, # Initial port, can be updated at runtime
|
||||||
partition_graph_file = self.index_dir / f"{index_name}_disk_graph.index"
|
"",
|
||||||
partition_bin_file = self.index_dir / f"{index_name}_partition.bin"
|
"",
|
||||||
|
)
|
||||||
partition_prefix = ""
|
|
||||||
if partition_graph_file.exists() and partition_bin_file.exists():
|
|
||||||
# C++ expects full path prefix, not just filename
|
|
||||||
partition_prefix = str(self.index_dir / index_name) # /path/to/simple_test
|
|
||||||
logger.info(
|
|
||||||
f"✅ Detected partition files, using partition_prefix='{partition_prefix}'"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.debug("No partition files detected, using standard index files")
|
|
||||||
|
|
||||||
self._init_params = {
|
|
||||||
"metric_enum": metric_enum,
|
|
||||||
"full_index_prefix": full_index_prefix,
|
|
||||||
"num_threads": self.num_threads,
|
|
||||||
"num_nodes_to_cache": kwargs.get("num_nodes_to_cache", 0),
|
|
||||||
"cache_mechanism": 1,
|
|
||||||
"pq_prefix": "",
|
|
||||||
"partition_prefix": partition_prefix,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Log partition configuration for debugging
|
|
||||||
if partition_prefix:
|
|
||||||
logger.info(
|
|
||||||
f"✅ Detected partition files, using partition_prefix='{partition_prefix}'"
|
|
||||||
)
|
|
||||||
self._diskannpy = diskannpy
|
|
||||||
self._current_zmq_port = None
|
|
||||||
self._index = None
|
|
||||||
logger.debug("DiskANN searcher initialized (index will be loaded on first search)")
|
|
||||||
|
|
||||||
def _ensure_index_loaded(self, zmq_port: int):
|
|
||||||
"""Ensure the index is loaded with the correct zmq_port."""
|
|
||||||
if self._index is None or self._current_zmq_port != zmq_port:
|
|
||||||
# Need to (re)load the index with the correct zmq_port
|
|
||||||
with suppress_cpp_output_if_needed():
|
|
||||||
if self._index is not None:
|
|
||||||
logger.debug(f"Reloading DiskANN index with new zmq_port: {zmq_port}")
|
|
||||||
else:
|
|
||||||
logger.debug(f"Loading DiskANN index with zmq_port: {zmq_port}")
|
|
||||||
|
|
||||||
self._index = self._diskannpy.StaticDiskFloatIndex(
|
|
||||||
self._init_params["metric_enum"],
|
|
||||||
self._init_params["full_index_prefix"],
|
|
||||||
self._init_params["num_threads"],
|
|
||||||
self._init_params["num_nodes_to_cache"],
|
|
||||||
self._init_params["cache_mechanism"],
|
|
||||||
zmq_port,
|
|
||||||
self._init_params["pq_prefix"],
|
|
||||||
self._init_params["partition_prefix"],
|
|
||||||
)
|
|
||||||
self._current_zmq_port = zmq_port
|
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
@@ -393,7 +190,7 @@ class DiskannSearcher(BaseSearcher):
|
|||||||
batch_recompute: bool = False,
|
batch_recompute: bool = False,
|
||||||
dedup_node_dis: bool = False,
|
dedup_node_dis: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Search for nearest neighbors using DiskANN index.
|
Search for nearest neighbors using DiskANN index.
|
||||||
|
|
||||||
@@ -416,15 +213,18 @@ class DiskannSearcher(BaseSearcher):
|
|||||||
Returns:
|
Returns:
|
||||||
Dict with 'labels' (list of lists) and 'distances' (ndarray)
|
Dict with 'labels' (list of lists) and 'distances' (ndarray)
|
||||||
"""
|
"""
|
||||||
# Handle zmq_port compatibility: Ensure index is loaded with correct port
|
# Handle zmq_port compatibility: DiskANN can now update port at runtime
|
||||||
if recompute_embeddings:
|
if recompute_embeddings:
|
||||||
if zmq_port is None:
|
if zmq_port is None:
|
||||||
raise ValueError("zmq_port must be provided if recompute_embeddings is True")
|
raise ValueError(
|
||||||
self._ensure_index_loaded(zmq_port)
|
"zmq_port must be provided if recompute_embeddings is True"
|
||||||
else:
|
)
|
||||||
# If not recomputing, we still need an index, use a default port
|
current_port = self._index.get_zmq_port()
|
||||||
if self._index is None:
|
if zmq_port != current_port:
|
||||||
self._ensure_index_loaded(6666) # Default port when not recomputing
|
logger.debug(
|
||||||
|
f"Updating DiskANN zmq_port from {current_port} to {zmq_port}"
|
||||||
|
)
|
||||||
|
self._index.set_zmq_port(zmq_port)
|
||||||
|
|
||||||
# DiskANN doesn't support "proportional" strategy
|
# DiskANN doesn't support "proportional" strategy
|
||||||
if pruning_strategy == "proportional":
|
if pruning_strategy == "proportional":
|
||||||
@@ -441,14 +241,7 @@ class DiskannSearcher(BaseSearcher):
|
|||||||
else: # "global"
|
else: # "global"
|
||||||
use_global_pruning = True
|
use_global_pruning = True
|
||||||
|
|
||||||
# Strategy:
|
# Perform search with suppressed C++ output based on log level
|
||||||
# - Traversal always uses PQ distances
|
|
||||||
# - If recompute_embeddings=True, do a single final rerank via deferred fetch
|
|
||||||
# (fetch embeddings for the final candidate set only)
|
|
||||||
# - Do not recompute neighbor distances along the path
|
|
||||||
use_deferred_fetch = True if recompute_embeddings else False
|
|
||||||
recompute_neighors = False # Expected typo. For backward compatibility.
|
|
||||||
|
|
||||||
with suppress_cpp_output_if_needed():
|
with suppress_cpp_output_if_needed():
|
||||||
labels, distances = self._index.batch_search(
|
labels, distances = self._index.batch_search(
|
||||||
query,
|
query,
|
||||||
@@ -457,15 +250,17 @@ class DiskannSearcher(BaseSearcher):
|
|||||||
complexity,
|
complexity,
|
||||||
beam_width,
|
beam_width,
|
||||||
self.num_threads,
|
self.num_threads,
|
||||||
use_deferred_fetch,
|
kwargs.get("USE_DEFERRED_FETCH", False),
|
||||||
kwargs.get("skip_search_reorder", False),
|
kwargs.get("skip_search_reorder", False),
|
||||||
recompute_neighors,
|
recompute_embeddings,
|
||||||
dedup_node_dis,
|
dedup_node_dis,
|
||||||
prune_ratio,
|
prune_ratio,
|
||||||
batch_recompute,
|
batch_recompute,
|
||||||
use_global_pruning,
|
use_global_pruning,
|
||||||
)
|
)
|
||||||
|
|
||||||
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
|
string_labels = [
|
||||||
|
[str(int_label) for int_label in batch_labels] for batch_labels in labels
|
||||||
|
]
|
||||||
|
|
||||||
return {"labels": string_labels, "distances": distances}
|
return {"labels": string_labels, "distances": distances}
|
||||||
|
|||||||
@@ -3,17 +3,16 @@ DiskANN-specific embedding server
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
import os
|
||||||
|
import zmq
|
||||||
|
import numpy as np
|
||||||
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
import sys
|
||||||
import numpy as np
|
import logging
|
||||||
import zmq
|
|
||||||
|
|
||||||
# Set up logging based on environment variable
|
# Set up logging based on environment variable
|
||||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||||
@@ -37,7 +36,6 @@ def create_diskann_embedding_server(
|
|||||||
zmq_port: int = 5555,
|
zmq_port: int = 5555,
|
||||||
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
||||||
embedding_mode: str = "sentence-transformers",
|
embedding_mode: str = "sentence-transformers",
|
||||||
distance_metric: str = "l2",
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create and start a ZMQ-based embedding server for DiskANN backend.
|
Create and start a ZMQ-based embedding server for DiskANN backend.
|
||||||
@@ -52,8 +50,8 @@ def create_diskann_embedding_server(
|
|||||||
sys.path.insert(0, str(leann_core_path))
|
sys.path.insert(0, str(leann_core_path))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from leann.api import PassageManager
|
|
||||||
from leann.embedding_compute import compute_embeddings
|
from leann.embedding_compute import compute_embeddings
|
||||||
|
from leann.api import PassageManager
|
||||||
|
|
||||||
logger.info("Successfully imported unified embedding computation module")
|
logger.info("Successfully imported unified embedding computation module")
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
@@ -78,11 +76,10 @@ def create_diskann_embedding_server(
|
|||||||
raise ValueError("Only metadata files (.meta.json) are supported")
|
raise ValueError("Only metadata files (.meta.json) are supported")
|
||||||
|
|
||||||
# Load metadata to get passage sources
|
# Load metadata to get passage sources
|
||||||
with open(passages_file) as f:
|
with open(passages_file, "r") as f:
|
||||||
meta = json.load(f)
|
meta = json.load(f)
|
||||||
|
|
||||||
logger.info(f"Loading PassageManager with metadata_file_path: {passages_file}")
|
passages = PassageManager(meta["passage_sources"])
|
||||||
passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file)
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
||||||
)
|
)
|
||||||
@@ -103,9 +100,8 @@ def create_diskann_embedding_server(
|
|||||||
socket.bind(f"tcp://*:{zmq_port}")
|
socket.bind(f"tcp://*:{zmq_port}")
|
||||||
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
|
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
|
||||||
|
|
||||||
socket.setsockopt(zmq.RCVTIMEO, 1000)
|
socket.setsockopt(zmq.RCVTIMEO, 300000)
|
||||||
socket.setsockopt(zmq.SNDTIMEO, 1000)
|
socket.setsockopt(zmq.SNDTIMEO, 300000)
|
||||||
socket.setsockopt(zmq.LINGER, 0)
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@@ -154,7 +150,9 @@ def create_diskann_embedding_server(
|
|||||||
):
|
):
|
||||||
texts = request
|
texts = request
|
||||||
is_text_request = True
|
is_text_request = True
|
||||||
logger.info(f"✅ MSGPACK: Direct text request for {len(texts)} texts")
|
logger.info(
|
||||||
|
f"✅ MSGPACK: Direct text request for {len(texts)} texts"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Not a valid msgpack text request")
|
raise ValueError("Not a valid msgpack text request")
|
||||||
except Exception as msgpack_error:
|
except Exception as msgpack_error:
|
||||||
@@ -169,7 +167,9 @@ def create_diskann_embedding_server(
|
|||||||
passage_data = passages.get_passage(str(nid))
|
passage_data = passages.get_passage(str(nid))
|
||||||
txt = passage_data["text"]
|
txt = passage_data["text"]
|
||||||
if not txt:
|
if not txt:
|
||||||
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
|
raise RuntimeError(
|
||||||
|
f"FATAL: Empty text for passage ID {nid}"
|
||||||
|
)
|
||||||
texts.append(txt)
|
texts.append(txt)
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
logger.error(f"Passage ID {nid} not found: {e}")
|
logger.error(f"Passage ID {nid} not found: {e}")
|
||||||
@@ -180,7 +180,9 @@ def create_diskann_embedding_server(
|
|||||||
|
|
||||||
# Debug logging
|
# Debug logging
|
||||||
logger.debug(f"Processing {len(texts)} texts")
|
logger.debug(f"Processing {len(texts)} texts")
|
||||||
logger.debug(f"Text lengths: {[len(t) for t in texts[:5]]}") # Show first 5
|
logger.debug(
|
||||||
|
f"Text lengths: {[len(t) for t in texts[:5]]}"
|
||||||
|
) # Show first 5
|
||||||
|
|
||||||
# Process embeddings using unified computation
|
# Process embeddings using unified computation
|
||||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||||
@@ -197,7 +199,9 @@ def create_diskann_embedding_server(
|
|||||||
else:
|
else:
|
||||||
# For DiskANN C++ compatibility: return protobuf format
|
# For DiskANN C++ compatibility: return protobuf format
|
||||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||||
hidden_contiguous = np.ascontiguousarray(embeddings, dtype=np.float32)
|
hidden_contiguous = np.ascontiguousarray(
|
||||||
|
embeddings, dtype=np.float32
|
||||||
|
)
|
||||||
|
|
||||||
# Serialize embeddings data
|
# Serialize embeddings data
|
||||||
resp_proto.embeddings_data = hidden_contiguous.tobytes()
|
resp_proto.embeddings_data = hidden_contiguous.tobytes()
|
||||||
@@ -222,217 +226,30 @@ def create_diskann_embedding_server(
|
|||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def zmq_server_thread_with_shutdown(shutdown_event):
|
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
|
||||||
"""ZMQ server thread that respects shutdown signal.
|
|
||||||
|
|
||||||
This creates its own REP socket, binds to zmq_port, and periodically
|
|
||||||
checks shutdown_event using recv timeouts to exit cleanly.
|
|
||||||
"""
|
|
||||||
logger.info("DiskANN ZMQ server thread started with shutdown support")
|
|
||||||
|
|
||||||
context = zmq.Context()
|
|
||||||
rep_socket = context.socket(zmq.REP)
|
|
||||||
rep_socket.bind(f"tcp://*:{zmq_port}")
|
|
||||||
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
|
|
||||||
|
|
||||||
# Set receive timeout so we can check shutdown_event periodically
|
|
||||||
rep_socket.setsockopt(zmq.RCVTIMEO, 1000) # 1 second timeout
|
|
||||||
rep_socket.setsockopt(zmq.SNDTIMEO, 1000)
|
|
||||||
rep_socket.setsockopt(zmq.LINGER, 0)
|
|
||||||
|
|
||||||
try:
|
|
||||||
while not shutdown_event.is_set():
|
|
||||||
try:
|
|
||||||
e2e_start = time.time()
|
|
||||||
# REP socket receives single-part messages
|
|
||||||
message = rep_socket.recv()
|
|
||||||
|
|
||||||
# Check for empty messages - REP socket requires response to every request
|
|
||||||
if not message:
|
|
||||||
logger.warning("Received empty message, sending empty response")
|
|
||||||
rep_socket.send(b"")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Try protobuf first (same logic as original)
|
|
||||||
texts = []
|
|
||||||
is_text_request = False
|
|
||||||
|
|
||||||
try:
|
|
||||||
req_proto = embedding_pb2.NodeEmbeddingRequest()
|
|
||||||
req_proto.ParseFromString(message)
|
|
||||||
node_ids = list(req_proto.node_ids)
|
|
||||||
|
|
||||||
# Look up texts by node IDs
|
|
||||||
for nid in node_ids:
|
|
||||||
try:
|
|
||||||
passage_data = passages.get_passage(str(nid))
|
|
||||||
txt = passage_data["text"]
|
|
||||||
if not txt:
|
|
||||||
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
|
|
||||||
texts.append(txt)
|
|
||||||
except KeyError:
|
|
||||||
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
|
|
||||||
|
|
||||||
logger.info(f"ZMQ received protobuf request for {len(node_ids)} node IDs")
|
|
||||||
except Exception:
|
|
||||||
# Fallback to msgpack for text requests
|
|
||||||
try:
|
|
||||||
import msgpack
|
|
||||||
|
|
||||||
request = msgpack.unpackb(message)
|
|
||||||
if isinstance(request, list) and all(
|
|
||||||
isinstance(item, str) for item in request
|
|
||||||
):
|
|
||||||
texts = request
|
|
||||||
is_text_request = True
|
|
||||||
logger.info(
|
|
||||||
f"ZMQ received msgpack text request for {len(texts)} texts"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError("Not a valid msgpack text request")
|
|
||||||
except Exception:
|
|
||||||
logger.error("Both protobuf and msgpack parsing failed!")
|
|
||||||
# Send error response
|
|
||||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
|
||||||
rep_socket.send(resp_proto.SerializeToString())
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Process the request
|
|
||||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
|
||||||
logger.info(f"Computed embeddings shape: {embeddings.shape}")
|
|
||||||
|
|
||||||
# Validation
|
|
||||||
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
|
||||||
logger.error("NaN or Inf detected in embeddings!")
|
|
||||||
# Send error response
|
|
||||||
if is_text_request:
|
|
||||||
import msgpack
|
|
||||||
|
|
||||||
response_data = msgpack.packb([])
|
|
||||||
else:
|
|
||||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
|
||||||
response_data = resp_proto.SerializeToString()
|
|
||||||
rep_socket.send(response_data)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Prepare response based on request type
|
|
||||||
if is_text_request:
|
|
||||||
# For direct text requests, return msgpack
|
|
||||||
import msgpack
|
|
||||||
|
|
||||||
response_data = msgpack.packb(embeddings.tolist())
|
|
||||||
else:
|
|
||||||
# For protobuf requests, return protobuf
|
|
||||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
|
||||||
hidden_contiguous = np.ascontiguousarray(embeddings, dtype=np.float32)
|
|
||||||
|
|
||||||
resp_proto.embeddings_data = hidden_contiguous.tobytes()
|
|
||||||
resp_proto.dimensions.append(hidden_contiguous.shape[0])
|
|
||||||
resp_proto.dimensions.append(hidden_contiguous.shape[1])
|
|
||||||
|
|
||||||
response_data = resp_proto.SerializeToString()
|
|
||||||
|
|
||||||
# Send response back to the client
|
|
||||||
rep_socket.send(response_data)
|
|
||||||
|
|
||||||
e2e_end = time.time()
|
|
||||||
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
|
|
||||||
|
|
||||||
except zmq.Again:
|
|
||||||
# Timeout - check shutdown_event and continue
|
|
||||||
continue
|
|
||||||
except Exception as e:
|
|
||||||
if not shutdown_event.is_set():
|
|
||||||
logger.error(f"Error in ZMQ server loop: {e}")
|
|
||||||
try:
|
|
||||||
# Send error response for REP socket
|
|
||||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
|
||||||
rep_socket.send(resp_proto.SerializeToString())
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
logger.info("Shutdown in progress, ignoring ZMQ error")
|
|
||||||
break
|
|
||||||
finally:
|
|
||||||
try:
|
|
||||||
rep_socket.close(0)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
context.term()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
logger.info("DiskANN ZMQ server thread exiting gracefully")
|
|
||||||
|
|
||||||
# Add shutdown coordination
|
|
||||||
shutdown_event = threading.Event()
|
|
||||||
|
|
||||||
def shutdown_zmq_server():
|
|
||||||
"""Gracefully shutdown ZMQ server."""
|
|
||||||
logger.info("Initiating graceful shutdown...")
|
|
||||||
shutdown_event.set()
|
|
||||||
|
|
||||||
if zmq_thread.is_alive():
|
|
||||||
logger.info("Waiting for ZMQ thread to finish...")
|
|
||||||
zmq_thread.join(timeout=5)
|
|
||||||
if zmq_thread.is_alive():
|
|
||||||
logger.warning("ZMQ thread did not finish in time")
|
|
||||||
|
|
||||||
# Clean up ZMQ resources
|
|
||||||
try:
|
|
||||||
# Note: socket and context are cleaned up by thread exit
|
|
||||||
logger.info("ZMQ resources cleaned up")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Error cleaning ZMQ resources: {e}")
|
|
||||||
|
|
||||||
# Clean up other resources
|
|
||||||
try:
|
|
||||||
import gc
|
|
||||||
|
|
||||||
gc.collect()
|
|
||||||
logger.info("Additional resources cleaned up")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Error cleaning additional resources: {e}")
|
|
||||||
|
|
||||||
logger.info("Graceful shutdown completed")
|
|
||||||
sys.exit(0)
|
|
||||||
|
|
||||||
# Register signal handlers within this function scope
|
|
||||||
import signal
|
|
||||||
|
|
||||||
def signal_handler(sig, frame):
|
|
||||||
logger.info(f"Received signal {sig}, shutting down gracefully...")
|
|
||||||
shutdown_zmq_server()
|
|
||||||
|
|
||||||
signal.signal(signal.SIGTERM, signal_handler)
|
|
||||||
signal.signal(signal.SIGINT, signal_handler)
|
|
||||||
|
|
||||||
# Start ZMQ thread (NOT daemon!)
|
|
||||||
zmq_thread = threading.Thread(
|
|
||||||
target=lambda: zmq_server_thread_with_shutdown(shutdown_event),
|
|
||||||
daemon=False, # Not daemon - we want to wait for it
|
|
||||||
)
|
|
||||||
zmq_thread.start()
|
zmq_thread.start()
|
||||||
logger.info(f"Started DiskANN ZMQ server thread on port {zmq_port}")
|
logger.info(f"Started DiskANN ZMQ server thread on port {zmq_port}")
|
||||||
|
|
||||||
# Keep the main thread alive
|
# Keep the main thread alive
|
||||||
try:
|
try:
|
||||||
while not shutdown_event.is_set():
|
while True:
|
||||||
time.sleep(0.1) # Check shutdown more frequently
|
time.sleep(1)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logger.info("DiskANN Server shutting down...")
|
logger.info("DiskANN Server shutting down...")
|
||||||
shutdown_zmq_server()
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# If we reach here, shutdown was triggered by signal
|
|
||||||
logger.info("Main loop exited, process should be shutting down")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
import signal
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
# Signal handlers are now registered within create_diskann_embedding_server
|
def signal_handler(sig, frame):
|
||||||
|
logger.info(f"Received signal {sig}, shutting down gracefully...")
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
# Register signal handlers for graceful shutdown
|
||||||
|
signal.signal(signal.SIGTERM, signal_handler)
|
||||||
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="DiskANN Embedding service")
|
parser = argparse.ArgumentParser(description="DiskANN Embedding service")
|
||||||
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
||||||
@@ -451,16 +268,9 @@ if __name__ == "__main__":
|
|||||||
"--embedding-mode",
|
"--embedding-mode",
|
||||||
type=str,
|
type=str,
|
||||||
default="sentence-transformers",
|
default="sentence-transformers",
|
||||||
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
choices=["sentence-transformers", "openai", "mlx"],
|
||||||
help="Embedding backend mode",
|
help="Embedding backend mode",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--distance-metric",
|
|
||||||
type=str,
|
|
||||||
default="l2",
|
|
||||||
choices=["l2", "mips", "cosine"],
|
|
||||||
help="Distance metric for similarity computation",
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@@ -470,5 +280,4 @@ if __name__ == "__main__":
|
|||||||
zmq_port=args.zmq_port,
|
zmq_port=args.zmq_port,
|
||||||
model_name=args.model_name,
|
model_name=args.model_name,
|
||||||
embedding_mode=args.embedding_mode,
|
embedding_mode=args.embedding_mode,
|
||||||
distance_metric=args.distance_metric,
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,28 +1,27 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||||
# source: embedding.proto
|
# source: embedding.proto
|
||||||
# ruff: noqa
|
|
||||||
"""Generated protocol buffer code."""
|
"""Generated protocol buffer code."""
|
||||||
|
from google.protobuf.internal import builder as _builder
|
||||||
from google.protobuf import descriptor as _descriptor
|
from google.protobuf import descriptor as _descriptor
|
||||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||||
from google.protobuf import symbol_database as _symbol_database
|
from google.protobuf import symbol_database as _symbol_database
|
||||||
from google.protobuf.internal import builder as _builder
|
|
||||||
|
|
||||||
# @@protoc_insertion_point(imports)
|
# @@protoc_insertion_point(imports)
|
||||||
|
|
||||||
_sym_db = _symbol_database.Default()
|
_sym_db = _symbol_database.Default()
|
||||||
|
|
||||||
|
|
||||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
|
|
||||||
b'\n\x0f\x65mbedding.proto\x12\x0eprotoembedding"(\n\x14NodeEmbeddingRequest\x12\x10\n\x08node_ids\x18\x01 \x03(\r"Y\n\x15NodeEmbeddingResponse\x12\x17\n\x0f\x65mbeddings_data\x18\x01 \x01(\x0c\x12\x12\n\ndimensions\x18\x02 \x03(\x05\x12\x13\n\x0bmissing_ids\x18\x03 \x03(\rb\x06proto3'
|
|
||||||
)
|
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0f\x65mbedding.proto\x12\x0eprotoembedding\"(\n\x14NodeEmbeddingRequest\x12\x10\n\x08node_ids\x18\x01 \x03(\r\"Y\n\x15NodeEmbeddingResponse\x12\x17\n\x0f\x65mbeddings_data\x18\x01 \x01(\x0c\x12\x12\n\ndimensions\x18\x02 \x03(\x05\x12\x13\n\x0bmissing_ids\x18\x03 \x03(\rb\x06proto3')
|
||||||
|
|
||||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
||||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "embedding_pb2", globals())
|
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'embedding_pb2', globals())
|
||||||
if not _descriptor._USE_C_DESCRIPTORS:
|
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||||
DESCRIPTOR._options = None
|
|
||||||
_NODEEMBEDDINGREQUEST._serialized_start = 35
|
DESCRIPTOR._options = None
|
||||||
_NODEEMBEDDINGREQUEST._serialized_end = 75
|
_NODEEMBEDDINGREQUEST._serialized_start=35
|
||||||
_NODEEMBEDDINGRESPONSE._serialized_start = 77
|
_NODEEMBEDDINGREQUEST._serialized_end=75
|
||||||
_NODEEMBEDDINGRESPONSE._serialized_end = 166
|
_NODEEMBEDDINGRESPONSE._serialized_start=77
|
||||||
|
_NODEEMBEDDINGRESPONSE._serialized_end=166
|
||||||
# @@protoc_insertion_point(module_scope)
|
# @@protoc_insertion_point(module_scope)
|
||||||
|
|||||||
@@ -1,299 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Graph Partition Module for LEANN DiskANN Backend
|
|
||||||
|
|
||||||
This module provides Python bindings for the graph partition functionality
|
|
||||||
of DiskANN, allowing users to partition disk-based indices for better
|
|
||||||
performance.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import shutil
|
|
||||||
import subprocess
|
|
||||||
import tempfile
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
|
|
||||||
class GraphPartitioner:
|
|
||||||
"""
|
|
||||||
A Python interface for DiskANN's graph partition functionality.
|
|
||||||
|
|
||||||
This class provides methods to partition disk-based indices for improved
|
|
||||||
search performance and memory efficiency.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, build_type: str = "release"):
|
|
||||||
"""
|
|
||||||
Initialize the GraphPartitioner.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
build_type: Build type for the executables ("debug" or "release")
|
|
||||||
"""
|
|
||||||
self.build_type = build_type
|
|
||||||
self._ensure_executables()
|
|
||||||
|
|
||||||
def _get_executable_path(self, name: str) -> str:
|
|
||||||
"""Get the path to a graph partition executable."""
|
|
||||||
# Get the directory where this Python module is located
|
|
||||||
module_dir = Path(__file__).parent
|
|
||||||
# Navigate to the graph_partition directory
|
|
||||||
graph_partition_dir = module_dir.parent / "third_party" / "DiskANN" / "graph_partition"
|
|
||||||
executable_path = graph_partition_dir / "build" / self.build_type / "graph_partition" / name
|
|
||||||
|
|
||||||
if not executable_path.exists():
|
|
||||||
raise FileNotFoundError(f"Executable {name} not found at {executable_path}")
|
|
||||||
|
|
||||||
return str(executable_path)
|
|
||||||
|
|
||||||
def _ensure_executables(self):
|
|
||||||
"""Ensure that the required executables are built."""
|
|
||||||
try:
|
|
||||||
self._get_executable_path("partitioner")
|
|
||||||
self._get_executable_path("index_relayout")
|
|
||||||
except FileNotFoundError:
|
|
||||||
# Try to build the executables automatically
|
|
||||||
print("Executables not found, attempting to build them...")
|
|
||||||
self._build_executables()
|
|
||||||
|
|
||||||
def _build_executables(self):
|
|
||||||
"""Build the required executables."""
|
|
||||||
graph_partition_dir = (
|
|
||||||
Path(__file__).parent.parent / "third_party" / "DiskANN" / "graph_partition"
|
|
||||||
)
|
|
||||||
original_dir = os.getcwd()
|
|
||||||
|
|
||||||
try:
|
|
||||||
os.chdir(graph_partition_dir)
|
|
||||||
|
|
||||||
# Clean any existing build
|
|
||||||
if (graph_partition_dir / "build").exists():
|
|
||||||
shutil.rmtree(graph_partition_dir / "build")
|
|
||||||
|
|
||||||
# Run the build script
|
|
||||||
cmd = ["./build.sh", self.build_type, "split_graph", "/tmp/dummy"]
|
|
||||||
subprocess.run(cmd, capture_output=True, text=True, cwd=graph_partition_dir)
|
|
||||||
|
|
||||||
# Check if executables were created
|
|
||||||
partitioner_path = self._get_executable_path("partitioner")
|
|
||||||
relayout_path = self._get_executable_path("index_relayout")
|
|
||||||
|
|
||||||
print(f"✅ Built partitioner: {partitioner_path}")
|
|
||||||
print(f"✅ Built index_relayout: {relayout_path}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(f"Failed to build executables: {e}")
|
|
||||||
finally:
|
|
||||||
os.chdir(original_dir)
|
|
||||||
|
|
||||||
def partition_graph(
|
|
||||||
self,
|
|
||||||
index_prefix_path: str,
|
|
||||||
output_dir: Optional[str] = None,
|
|
||||||
partition_prefix: Optional[str] = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> tuple[str, str]:
|
|
||||||
"""
|
|
||||||
Partition a disk-based index for improved performance.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
index_prefix_path: Path to the index prefix (e.g., "/path/to/index")
|
|
||||||
output_dir: Output directory for results (defaults to parent of index_prefix_path)
|
|
||||||
partition_prefix: Prefix for output files (defaults to basename of index_prefix_path)
|
|
||||||
**kwargs: Additional parameters for graph partitioning:
|
|
||||||
- gp_times: Number of LDG partition iterations (default: 10)
|
|
||||||
- lock_nums: Number of lock nodes (default: 10)
|
|
||||||
- cut: Cut adjacency list degree (default: 100)
|
|
||||||
- scale_factor: Scale factor (default: 1)
|
|
||||||
- data_type: Data type (default: "float")
|
|
||||||
- thread_nums: Number of threads (default: 10)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (disk_graph_index_path, partition_bin_path)
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: If the partitioning process fails
|
|
||||||
"""
|
|
||||||
# Set default parameters
|
|
||||||
params = {
|
|
||||||
"gp_times": 10,
|
|
||||||
"lock_nums": 10,
|
|
||||||
"cut": 100,
|
|
||||||
"scale_factor": 1,
|
|
||||||
"data_type": "float",
|
|
||||||
"thread_nums": 10,
|
|
||||||
**kwargs,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Determine output directory
|
|
||||||
if output_dir is None:
|
|
||||||
output_dir = str(Path(index_prefix_path).parent)
|
|
||||||
|
|
||||||
# Create output directory if it doesn't exist
|
|
||||||
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# Determine partition prefix
|
|
||||||
if partition_prefix is None:
|
|
||||||
partition_prefix = Path(index_prefix_path).name
|
|
||||||
|
|
||||||
# Get executable paths
|
|
||||||
partitioner_path = self._get_executable_path("partitioner")
|
|
||||||
relayout_path = self._get_executable_path("index_relayout")
|
|
||||||
|
|
||||||
# Create temporary directory for processing
|
|
||||||
with tempfile.TemporaryDirectory() as temp_dir:
|
|
||||||
# Change to the graph_partition directory for temporary files
|
|
||||||
graph_partition_dir = (
|
|
||||||
Path(__file__).parent.parent / "third_party" / "DiskANN" / "graph_partition"
|
|
||||||
)
|
|
||||||
original_dir = os.getcwd()
|
|
||||||
|
|
||||||
try:
|
|
||||||
os.chdir(graph_partition_dir)
|
|
||||||
|
|
||||||
# Create temporary data directory
|
|
||||||
temp_data_dir = Path(temp_dir) / "data"
|
|
||||||
temp_data_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# Set up paths for temporary files
|
|
||||||
graph_path = temp_data_dir / "starling" / "_M_R_L_B" / "GRAPH"
|
|
||||||
graph_gp_path = (
|
|
||||||
graph_path
|
|
||||||
/ f"GP_TIMES_{params['gp_times']}_LOCK_{params['lock_nums']}_GP_USE_FREQ0_CUT{params['cut']}_SCALE{params['scale_factor']}"
|
|
||||||
)
|
|
||||||
graph_gp_path.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# Find input index file
|
|
||||||
old_index_file = f"{index_prefix_path}_disk_beam_search.index"
|
|
||||||
if not os.path.exists(old_index_file):
|
|
||||||
old_index_file = f"{index_prefix_path}_disk.index"
|
|
||||||
|
|
||||||
if not os.path.exists(old_index_file):
|
|
||||||
raise RuntimeError(f"Index file not found: {old_index_file}")
|
|
||||||
|
|
||||||
# Run partitioner
|
|
||||||
gp_file_path = graph_gp_path / "_part.bin"
|
|
||||||
partitioner_cmd = [
|
|
||||||
partitioner_path,
|
|
||||||
"--index_file",
|
|
||||||
old_index_file,
|
|
||||||
"--data_type",
|
|
||||||
params["data_type"],
|
|
||||||
"--gp_file",
|
|
||||||
str(gp_file_path),
|
|
||||||
"-T",
|
|
||||||
str(params["thread_nums"]),
|
|
||||||
"--ldg_times",
|
|
||||||
str(params["gp_times"]),
|
|
||||||
"--scale",
|
|
||||||
str(params["scale_factor"]),
|
|
||||||
"--mode",
|
|
||||||
"1",
|
|
||||||
]
|
|
||||||
|
|
||||||
print(f"Running partitioner: {' '.join(partitioner_cmd)}")
|
|
||||||
result = subprocess.run(
|
|
||||||
partitioner_cmd, capture_output=True, text=True, cwd=graph_partition_dir
|
|
||||||
)
|
|
||||||
|
|
||||||
if result.returncode != 0:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Partitioner failed with return code {result.returncode}.\n"
|
|
||||||
f"stdout: {result.stdout}\n"
|
|
||||||
f"stderr: {result.stderr}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Run relayout
|
|
||||||
part_tmp_index = graph_gp_path / "_part_tmp.index"
|
|
||||||
relayout_cmd = [
|
|
||||||
relayout_path,
|
|
||||||
old_index_file,
|
|
||||||
str(gp_file_path),
|
|
||||||
params["data_type"],
|
|
||||||
"1",
|
|
||||||
]
|
|
||||||
|
|
||||||
print(f"Running relayout: {' '.join(relayout_cmd)}")
|
|
||||||
result = subprocess.run(
|
|
||||||
relayout_cmd, capture_output=True, text=True, cwd=graph_partition_dir
|
|
||||||
)
|
|
||||||
|
|
||||||
if result.returncode != 0:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Relayout failed with return code {result.returncode}.\n"
|
|
||||||
f"stdout: {result.stdout}\n"
|
|
||||||
f"stderr: {result.stderr}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Copy results to output directory
|
|
||||||
disk_graph_path = Path(output_dir) / f"{partition_prefix}_disk_graph.index"
|
|
||||||
partition_bin_path = Path(output_dir) / f"{partition_prefix}_partition.bin"
|
|
||||||
|
|
||||||
shutil.copy2(part_tmp_index, disk_graph_path)
|
|
||||||
shutil.copy2(gp_file_path, partition_bin_path)
|
|
||||||
|
|
||||||
print(f"Results copied to: {output_dir}")
|
|
||||||
return str(disk_graph_path), str(partition_bin_path)
|
|
||||||
|
|
||||||
finally:
|
|
||||||
os.chdir(original_dir)
|
|
||||||
|
|
||||||
def get_partition_info(self, partition_bin_path: str) -> dict:
|
|
||||||
"""
|
|
||||||
Get information about a partition file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
partition_bin_path: Path to the partition binary file
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary containing partition information
|
|
||||||
"""
|
|
||||||
if not os.path.exists(partition_bin_path):
|
|
||||||
raise FileNotFoundError(f"Partition file not found: {partition_bin_path}")
|
|
||||||
|
|
||||||
# For now, return basic file information
|
|
||||||
# In the future, this could parse the binary file for detailed info
|
|
||||||
stat = os.stat(partition_bin_path)
|
|
||||||
return {
|
|
||||||
"file_size": stat.st_size,
|
|
||||||
"file_path": partition_bin_path,
|
|
||||||
"modified_time": stat.st_mtime,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def partition_graph(
|
|
||||||
index_prefix_path: str,
|
|
||||||
output_dir: Optional[str] = None,
|
|
||||||
partition_prefix: Optional[str] = None,
|
|
||||||
build_type: str = "release",
|
|
||||||
**kwargs,
|
|
||||||
) -> tuple[str, str]:
|
|
||||||
"""
|
|
||||||
Convenience function to partition a graph index.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
index_prefix_path: Path to the index prefix
|
|
||||||
output_dir: Output directory (defaults to parent of index_prefix_path)
|
|
||||||
partition_prefix: Prefix for output files (defaults to basename of index_prefix_path)
|
|
||||||
build_type: Build type for executables ("debug" or "release")
|
|
||||||
**kwargs: Additional parameters for graph partitioning
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (disk_graph_index_path, partition_bin_path)
|
|
||||||
"""
|
|
||||||
partitioner = GraphPartitioner(build_type=build_type)
|
|
||||||
return partitioner.partition_graph(index_prefix_path, output_dir, partition_prefix, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
# Example usage:
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# Example: partition an index
|
|
||||||
try:
|
|
||||||
disk_graph_path, partition_bin_path = partition_graph(
|
|
||||||
"/path/to/your/index_prefix", gp_times=10, lock_nums=10, cut=100
|
|
||||||
)
|
|
||||||
print("Partitioning completed successfully!")
|
|
||||||
print(f"Disk graph index: {disk_graph_path}")
|
|
||||||
print(f"Partition binary: {partition_bin_path}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Partitioning failed: {e}")
|
|
||||||
@@ -4,8 +4,8 @@ build-backend = "scikit_build_core.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann-backend-diskann"
|
name = "leann-backend-diskann"
|
||||||
version = "0.2.9"
|
version = "0.1.0"
|
||||||
dependencies = ["leann-core==0.2.9", "numpy", "protobuf>=3.19.0"]
|
dependencies = ["leann-core==0.1.0", "numpy"]
|
||||||
|
|
||||||
[tool.scikit-build]
|
[tool.scikit-build]
|
||||||
# Key: simplified CMake path
|
# Key: simplified CMake path
|
||||||
@@ -17,5 +17,3 @@ editable.mode = "redirect"
|
|||||||
cmake.build-type = "Release"
|
cmake.build-type = "Release"
|
||||||
build.verbose = true
|
build.verbose = true
|
||||||
build.tool-args = ["-j8"]
|
build.tool-args = ["-j8"]
|
||||||
# Let CMake find packages via Homebrew prefix
|
|
||||||
cmake.define = {CMAKE_PREFIX_PATH = {env = "CMAKE_PREFIX_PATH"}, OpenMP_ROOT = {env = "OpenMP_ROOT"}}
|
|
||||||
|
|||||||
Submodule packages/leann-backend-diskann/third_party/DiskANN updated: 04048bb302...25339b0341
@@ -5,28 +5,11 @@ set(CMAKE_CXX_COMPILER_WORKS 1)
|
|||||||
|
|
||||||
# Set OpenMP path for macOS
|
# Set OpenMP path for macOS
|
||||||
if(APPLE)
|
if(APPLE)
|
||||||
# Detect Homebrew installation path (Apple Silicon vs Intel)
|
set(OpenMP_C_FLAGS "-Xpreprocessor -fopenmp -I/opt/homebrew/opt/libomp/include")
|
||||||
if(EXISTS "/opt/homebrew/opt/libomp")
|
set(OpenMP_CXX_FLAGS "-Xpreprocessor -fopenmp -I/opt/homebrew/opt/libomp/include")
|
||||||
set(HOMEBREW_PREFIX "/opt/homebrew")
|
|
||||||
elseif(EXISTS "/usr/local/opt/libomp")
|
|
||||||
set(HOMEBREW_PREFIX "/usr/local")
|
|
||||||
else()
|
|
||||||
message(FATAL_ERROR "Could not find libomp installation. Please install with: brew install libomp")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
set(OpenMP_C_FLAGS "-Xpreprocessor -fopenmp -I${HOMEBREW_PREFIX}/opt/libomp/include")
|
|
||||||
set(OpenMP_CXX_FLAGS "-Xpreprocessor -fopenmp -I${HOMEBREW_PREFIX}/opt/libomp/include")
|
|
||||||
set(OpenMP_C_LIB_NAMES "omp")
|
set(OpenMP_C_LIB_NAMES "omp")
|
||||||
set(OpenMP_CXX_LIB_NAMES "omp")
|
set(OpenMP_CXX_LIB_NAMES "omp")
|
||||||
set(OpenMP_omp_LIBRARY "${HOMEBREW_PREFIX}/opt/libomp/lib/libomp.dylib")
|
set(OpenMP_omp_LIBRARY "/opt/homebrew/opt/libomp/lib/libomp.dylib")
|
||||||
|
|
||||||
# Force use of system libc++ to avoid version mismatch
|
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -stdlib=libc++")
|
|
||||||
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -stdlib=libc++")
|
|
||||||
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -stdlib=libc++")
|
|
||||||
|
|
||||||
# Set minimum macOS version for better compatibility
|
|
||||||
set(CMAKE_OSX_DEPLOYMENT_TARGET "11.0" CACHE STRING "Minimum macOS version")
|
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# Use system ZeroMQ instead of building from source
|
# Use system ZeroMQ instead of building from source
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
from . import hnsw_backend as hnsw_backend
|
from . import hnsw_backend
|
||||||
|
|||||||
@@ -1,122 +1,87 @@
|
|||||||
import argparse
|
|
||||||
import gc # Import garbage collector interface
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import struct
|
import struct
|
||||||
import sys
|
import sys
|
||||||
import time
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import os
|
||||||
# Set up logging to avoid print buffer issues
|
import argparse
|
||||||
logger = logging.getLogger(__name__)
|
import gc # Import garbage collector interface
|
||||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
import time
|
||||||
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
|
|
||||||
logger.setLevel(log_level)
|
|
||||||
|
|
||||||
# --- FourCCs (add more if needed) ---
|
# --- FourCCs (add more if needed) ---
|
||||||
INDEX_HNSW_FLAT_FOURCC = int.from_bytes(b"IHNf", "little")
|
INDEX_HNSW_FLAT_FOURCC = int.from_bytes(b'IHNf', 'little')
|
||||||
# Add other HNSW fourccs if you expect different storage types inside HNSW
|
# Add other HNSW fourccs if you expect different storage types inside HNSW
|
||||||
# INDEX_HNSW_PQ_FOURCC = int.from_bytes(b'IHNp', 'little')
|
# INDEX_HNSW_PQ_FOURCC = int.from_bytes(b'IHNp', 'little')
|
||||||
# INDEX_HNSW_SQ_FOURCC = int.from_bytes(b'IHNs', 'little')
|
# INDEX_HNSW_SQ_FOURCC = int.from_bytes(b'IHNs', 'little')
|
||||||
# INDEX_HNSW_CAGRA_FOURCC = int.from_bytes(b'IHNc', 'little') # Example
|
# INDEX_HNSW_CAGRA_FOURCC = int.from_bytes(b'IHNc', 'little') # Example
|
||||||
|
|
||||||
EXPECTED_HNSW_FOURCCS = {INDEX_HNSW_FLAT_FOURCC} # Modify if needed
|
EXPECTED_HNSW_FOURCCS = {INDEX_HNSW_FLAT_FOURCC} # Modify if needed
|
||||||
NULL_INDEX_FOURCC = int.from_bytes(b"null", "little")
|
NULL_INDEX_FOURCC = int.from_bytes(b'null', 'little')
|
||||||
|
|
||||||
# --- Helper functions for reading/writing binary data ---
|
# --- Helper functions for reading/writing binary data ---
|
||||||
|
|
||||||
|
|
||||||
def read_struct(f, fmt):
|
def read_struct(f, fmt):
|
||||||
"""Reads data according to the struct format."""
|
"""Reads data according to the struct format."""
|
||||||
size = struct.calcsize(fmt)
|
size = struct.calcsize(fmt)
|
||||||
data = f.read(size)
|
data = f.read(size)
|
||||||
if len(data) != size:
|
if len(data) != size:
|
||||||
raise EOFError(
|
raise EOFError(f"File ended unexpectedly reading struct fmt '{fmt}'. Expected {size} bytes, got {len(data)}.")
|
||||||
f"File ended unexpectedly reading struct fmt '{fmt}'. Expected {size} bytes, got {len(data)}."
|
|
||||||
)
|
|
||||||
return struct.unpack(fmt, data)[0]
|
return struct.unpack(fmt, data)[0]
|
||||||
|
|
||||||
|
|
||||||
def read_vector_raw(f, element_fmt_char):
|
def read_vector_raw(f, element_fmt_char):
|
||||||
"""Reads a vector (size followed by data), returns count and raw bytes."""
|
"""Reads a vector (size followed by data), returns count and raw bytes."""
|
||||||
count = -1 # Initialize count
|
count = -1 # Initialize count
|
||||||
total_bytes = -1 # Initialize total_bytes
|
total_bytes = -1 # Initialize total_bytes
|
||||||
try:
|
try:
|
||||||
count = read_struct(f, "<Q") # size_t usually 64-bit unsigned
|
count = read_struct(f, '<Q') # size_t usually 64-bit unsigned
|
||||||
element_size = struct.calcsize(element_fmt_char)
|
element_size = struct.calcsize(element_fmt_char)
|
||||||
# --- FIX for MemoryError: Check for unreasonably large count ---
|
# --- FIX for MemoryError: Check for unreasonably large count ---
|
||||||
max_reasonable_count = 10 * (10**9) # ~10 billion elements limit
|
max_reasonable_count = 10 * (10**9) # ~10 billion elements limit
|
||||||
if count > max_reasonable_count or count < 0:
|
if count > max_reasonable_count or count < 0:
|
||||||
raise MemoryError(
|
raise MemoryError(f"Vector count {count} seems unreasonably large, possibly due to file corruption or incorrect format read.")
|
||||||
f"Vector count {count} seems unreasonably large, possibly due to file corruption or incorrect format read."
|
|
||||||
)
|
|
||||||
|
|
||||||
total_bytes = count * element_size
|
total_bytes = count * element_size
|
||||||
# --- FIX for MemoryError: Check for huge byte size before allocation ---
|
# --- FIX for MemoryError: Check for huge byte size before allocation ---
|
||||||
max_reasonable_bytes = 50 * (1024**3) # ~50 GB limit
|
max_reasonable_bytes = 50 * (1024**3) # ~50 GB limit
|
||||||
if total_bytes > max_reasonable_bytes or total_bytes < 0: # Check for overflow
|
if total_bytes > max_reasonable_bytes or total_bytes < 0: # Check for overflow
|
||||||
raise MemoryError(
|
raise MemoryError(f"Attempting to read {total_bytes} bytes ({count} elements * {element_size} bytes/element), which exceeds the safety limit. File might be corrupted or format mismatch.")
|
||||||
f"Attempting to read {total_bytes} bytes ({count} elements * {element_size} bytes/element), which exceeds the safety limit. File might be corrupted or format mismatch."
|
|
||||||
)
|
|
||||||
|
|
||||||
data_bytes = f.read(total_bytes)
|
data_bytes = f.read(total_bytes)
|
||||||
|
|
||||||
if len(data_bytes) != total_bytes:
|
if len(data_bytes) != total_bytes:
|
||||||
raise EOFError(
|
raise EOFError(f"File ended unexpectedly reading vector data. Expected {total_bytes} bytes, got {len(data_bytes)}.")
|
||||||
f"File ended unexpectedly reading vector data. Expected {total_bytes} bytes, got {len(data_bytes)}."
|
|
||||||
)
|
|
||||||
return count, data_bytes
|
return count, data_bytes
|
||||||
except (MemoryError, OverflowError) as e:
|
except (MemoryError, OverflowError) as e:
|
||||||
# Add context to the error message
|
# Add context to the error message
|
||||||
print(
|
print(f"\nError during raw vector read (element_fmt='{element_fmt_char}', count={count}, total_bytes={total_bytes}): {e}", file=sys.stderr)
|
||||||
f"\nError during raw vector read (element_fmt='{element_fmt_char}', count={count}, total_bytes={total_bytes}): {e}",
|
raise e # Re-raise the original error type
|
||||||
file=sys.stderr,
|
|
||||||
)
|
|
||||||
raise e # Re-raise the original error type
|
|
||||||
|
|
||||||
|
|
||||||
def read_numpy_vector(f, np_dtype, struct_fmt_char):
|
def read_numpy_vector(f, np_dtype, struct_fmt_char):
|
||||||
"""Reads a vector into a NumPy array."""
|
"""Reads a vector into a NumPy array."""
|
||||||
count = -1 # Initialize count for robust error handling
|
count = -1 # Initialize count for robust error handling
|
||||||
print(
|
print(f" Reading vector (dtype={np_dtype}, fmt='{struct_fmt_char}')... ", end='', flush=True)
|
||||||
f" Reading vector (dtype={np_dtype}, fmt='{struct_fmt_char}')... ",
|
|
||||||
end="",
|
|
||||||
flush=True,
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
count, data_bytes = read_vector_raw(f, struct_fmt_char)
|
count, data_bytes = read_vector_raw(f, struct_fmt_char)
|
||||||
print(f"Count={count}, Bytes={len(data_bytes)}")
|
print(f"Count={count}, Bytes={len(data_bytes)}")
|
||||||
if count > 0 and len(data_bytes) > 0:
|
if count > 0 and len(data_bytes) > 0:
|
||||||
arr = np.frombuffer(data_bytes, dtype=np_dtype)
|
arr = np.frombuffer(data_bytes, dtype=np_dtype)
|
||||||
if arr.size != count:
|
if arr.size != count:
|
||||||
raise ValueError(
|
raise ValueError(f"Inconsistent array size after reading. Expected {count}, got {arr.size}")
|
||||||
f"Inconsistent array size after reading. Expected {count}, got {arr.size}"
|
|
||||||
)
|
|
||||||
return arr
|
return arr
|
||||||
elif count == 0:
|
elif count == 0:
|
||||||
return np.array([], dtype=np_dtype)
|
return np.array([], dtype=np_dtype)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Read zero bytes but count > 0.")
|
raise ValueError("Read zero bytes but count > 0.")
|
||||||
except MemoryError as e:
|
except MemoryError as e:
|
||||||
# Now count should be defined (or -1 if error was in read_struct)
|
# Now count should be defined (or -1 if error was in read_struct)
|
||||||
print(
|
print(f"\nMemoryError creating NumPy array (dtype={np_dtype}, count={count}). {e}", file=sys.stderr)
|
||||||
f"\nMemoryError creating NumPy array (dtype={np_dtype}, count={count}). {e}",
|
|
||||||
file=sys.stderr,
|
|
||||||
)
|
|
||||||
raise e
|
raise e
|
||||||
except Exception as e: # Catch other potential errors like ValueError
|
except Exception as e: # Catch other potential errors like ValueError
|
||||||
print(
|
print(f"\nError reading numpy vector (dtype={np_dtype}, fmt='{struct_fmt_char}', count={count}): {e}", file=sys.stderr)
|
||||||
f"\nError reading numpy vector (dtype={np_dtype}, fmt='{struct_fmt_char}', count={count}): {e}",
|
|
||||||
file=sys.stderr,
|
|
||||||
)
|
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def write_numpy_vector(f, arr, struct_fmt_char):
|
def write_numpy_vector(f, arr, struct_fmt_char):
|
||||||
"""Writes a NumPy array as a vector (size followed by data)."""
|
"""Writes a NumPy array as a vector (size followed by data)."""
|
||||||
count = arr.size
|
count = arr.size
|
||||||
f.write(struct.pack("<Q", count))
|
f.write(struct.pack('<Q', count))
|
||||||
try:
|
try:
|
||||||
expected_dtype = np.dtype(struct_fmt_char)
|
expected_dtype = np.dtype(struct_fmt_char)
|
||||||
if arr.dtype != expected_dtype:
|
if arr.dtype != expected_dtype:
|
||||||
@@ -124,30 +89,23 @@ def write_numpy_vector(f, arr, struct_fmt_char):
|
|||||||
else:
|
else:
|
||||||
data_to_write = arr.tobytes()
|
data_to_write = arr.tobytes()
|
||||||
f.write(data_to_write)
|
f.write(data_to_write)
|
||||||
del data_to_write # Hint GC
|
del data_to_write # Hint GC
|
||||||
except MemoryError as e:
|
except MemoryError as e:
|
||||||
print(
|
print(f"\nMemoryError converting NumPy array to bytes for writing (size={count}, dtype={arr.dtype}). {e}", file=sys.stderr)
|
||||||
f"\nMemoryError converting NumPy array to bytes for writing (size={count}, dtype={arr.dtype}). {e}",
|
raise e
|
||||||
file=sys.stderr,
|
|
||||||
)
|
|
||||||
raise e
|
|
||||||
|
|
||||||
|
|
||||||
def write_list_vector(f, lst, struct_fmt_char):
|
def write_list_vector(f, lst, struct_fmt_char):
|
||||||
"""Writes a Python list as a vector iteratively."""
|
"""Writes a Python list as a vector iteratively."""
|
||||||
count = len(lst)
|
count = len(lst)
|
||||||
f.write(struct.pack("<Q", count))
|
f.write(struct.pack('<Q', count))
|
||||||
fmt = "<" + struct_fmt_char
|
fmt = '<' + struct_fmt_char
|
||||||
chunk_size = 1024 * 1024
|
chunk_size = 1024 * 1024
|
||||||
element_size = struct.calcsize(fmt)
|
element_size = struct.calcsize(fmt)
|
||||||
# Allocate buffer outside the loop if possible, or handle MemoryError during allocation
|
# Allocate buffer outside the loop if possible, or handle MemoryError during allocation
|
||||||
try:
|
try:
|
||||||
buffer = bytearray(chunk_size * element_size)
|
buffer = bytearray(chunk_size * element_size)
|
||||||
except MemoryError:
|
except MemoryError:
|
||||||
print(
|
print(f"MemoryError: Cannot allocate buffer for writing list vector chunk (size {chunk_size * element_size} bytes).", file=sys.stderr)
|
||||||
f"MemoryError: Cannot allocate buffer for writing list vector chunk (size {chunk_size * element_size} bytes).",
|
|
||||||
file=sys.stderr,
|
|
||||||
)
|
|
||||||
raise
|
raise
|
||||||
buffer_count = 0
|
buffer_count = 0
|
||||||
|
|
||||||
@@ -158,79 +116,65 @@ def write_list_vector(f, lst, struct_fmt_char):
|
|||||||
buffer_count += 1
|
buffer_count += 1
|
||||||
|
|
||||||
if buffer_count == chunk_size or i == count - 1:
|
if buffer_count == chunk_size or i == count - 1:
|
||||||
f.write(buffer[: buffer_count * element_size])
|
f.write(buffer[:buffer_count * element_size])
|
||||||
buffer_count = 0
|
buffer_count = 0
|
||||||
|
|
||||||
except struct.error as e:
|
except struct.error as e:
|
||||||
print(
|
print(f"\nStruct packing error for item {item} at index {i} with format '{fmt}'. {e}", file=sys.stderr)
|
||||||
f"\nStruct packing error for item {item} at index {i} with format '{fmt}'. {e}",
|
|
||||||
file=sys.stderr,
|
|
||||||
)
|
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def get_cum_neighbors(cum_nneighbor_per_level_np, level):
|
def get_cum_neighbors(cum_nneighbor_per_level_np, level):
|
||||||
"""Helper to get cumulative neighbors count, matching C++ logic."""
|
"""Helper to get cumulative neighbors count, matching C++ logic."""
|
||||||
if level < 0:
|
if level < 0: return 0
|
||||||
return 0
|
|
||||||
if level < len(cum_nneighbor_per_level_np):
|
if level < len(cum_nneighbor_per_level_np):
|
||||||
return cum_nneighbor_per_level_np[level]
|
return cum_nneighbor_per_level_np[level]
|
||||||
else:
|
else:
|
||||||
return cum_nneighbor_per_level_np[-1] if len(cum_nneighbor_per_level_np) > 0 else 0
|
return cum_nneighbor_per_level_np[-1] if len(cum_nneighbor_per_level_np) > 0 else 0
|
||||||
|
|
||||||
|
def write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np,
|
||||||
def write_compact_format(
|
levels_np, compact_level_ptr, compact_node_offsets_np,
|
||||||
f_out,
|
compact_neighbors_data, storage_fourcc, storage_data):
|
||||||
original_hnsw_data,
|
|
||||||
assign_probas_np,
|
|
||||||
cum_nneighbor_per_level_np,
|
|
||||||
levels_np,
|
|
||||||
compact_level_ptr,
|
|
||||||
compact_node_offsets_np,
|
|
||||||
compact_neighbors_data,
|
|
||||||
storage_fourcc,
|
|
||||||
storage_data,
|
|
||||||
):
|
|
||||||
"""Write HNSW data in compact format following C++ read order exactly."""
|
"""Write HNSW data in compact format following C++ read order exactly."""
|
||||||
# Write IndexHNSW Header
|
# Write IndexHNSW Header
|
||||||
f_out.write(struct.pack("<I", original_hnsw_data["index_fourcc"]))
|
f_out.write(struct.pack('<I', original_hnsw_data['index_fourcc']))
|
||||||
f_out.write(struct.pack("<i", original_hnsw_data["d"]))
|
f_out.write(struct.pack('<i', original_hnsw_data['d']))
|
||||||
f_out.write(struct.pack("<q", original_hnsw_data["ntotal"]))
|
f_out.write(struct.pack('<q', original_hnsw_data['ntotal']))
|
||||||
f_out.write(struct.pack("<q", original_hnsw_data["dummy1"]))
|
f_out.write(struct.pack('<q', original_hnsw_data['dummy1']))
|
||||||
f_out.write(struct.pack("<q", original_hnsw_data["dummy2"]))
|
f_out.write(struct.pack('<q', original_hnsw_data['dummy2']))
|
||||||
f_out.write(struct.pack("<?", original_hnsw_data["is_trained"]))
|
f_out.write(struct.pack('<?', original_hnsw_data['is_trained']))
|
||||||
f_out.write(struct.pack("<i", original_hnsw_data["metric_type"]))
|
f_out.write(struct.pack('<i', original_hnsw_data['metric_type']))
|
||||||
if original_hnsw_data["metric_type"] > 1:
|
if original_hnsw_data['metric_type'] > 1:
|
||||||
f_out.write(struct.pack("<f", original_hnsw_data["metric_arg"]))
|
f_out.write(struct.pack('<f', original_hnsw_data['metric_arg']))
|
||||||
|
|
||||||
# Write HNSW struct parts (standard order)
|
# Write HNSW struct parts (standard order)
|
||||||
write_numpy_vector(f_out, assign_probas_np, "d")
|
write_numpy_vector(f_out, assign_probas_np, 'd')
|
||||||
write_numpy_vector(f_out, cum_nneighbor_per_level_np, "i")
|
write_numpy_vector(f_out, cum_nneighbor_per_level_np, 'i')
|
||||||
write_numpy_vector(f_out, levels_np, "i")
|
write_numpy_vector(f_out, levels_np, 'i')
|
||||||
|
|
||||||
# Write compact format flag
|
# Write compact format flag
|
||||||
f_out.write(struct.pack("<?", True)) # storage_is_compact = True
|
f_out.write(struct.pack('<?', True)) # storage_is_compact = True
|
||||||
|
|
||||||
# Write compact data in CORRECT C++ read order: level_ptr, node_offsets FIRST
|
# Write compact data in CORRECT C++ read order: level_ptr, node_offsets FIRST
|
||||||
if isinstance(compact_level_ptr, np.ndarray):
|
if isinstance(compact_level_ptr, np.ndarray):
|
||||||
write_numpy_vector(f_out, compact_level_ptr, "Q")
|
write_numpy_vector(f_out, compact_level_ptr, 'Q')
|
||||||
else:
|
else:
|
||||||
write_list_vector(f_out, compact_level_ptr, "Q")
|
write_list_vector(f_out, compact_level_ptr, 'Q')
|
||||||
|
|
||||||
write_numpy_vector(f_out, compact_node_offsets_np, "Q")
|
write_numpy_vector(f_out, compact_node_offsets_np, 'Q')
|
||||||
|
|
||||||
# Write HNSW scalar parameters
|
# Write HNSW scalar parameters
|
||||||
f_out.write(struct.pack("<i", original_hnsw_data["entry_point"]))
|
f_out.write(struct.pack('<i', original_hnsw_data['entry_point']))
|
||||||
f_out.write(struct.pack("<i", original_hnsw_data["max_level"]))
|
f_out.write(struct.pack('<i', original_hnsw_data['max_level']))
|
||||||
f_out.write(struct.pack("<i", original_hnsw_data["efConstruction"]))
|
f_out.write(struct.pack('<i', original_hnsw_data['efConstruction']))
|
||||||
f_out.write(struct.pack("<i", original_hnsw_data["efSearch"]))
|
f_out.write(struct.pack('<i', original_hnsw_data['efSearch']))
|
||||||
f_out.write(struct.pack("<i", original_hnsw_data["dummy_upper_beam"]))
|
f_out.write(struct.pack('<i', original_hnsw_data['dummy_upper_beam']))
|
||||||
|
|
||||||
# Write storage fourcc (this determines how to read what follows)
|
# Write storage fourcc (this determines how to read what follows)
|
||||||
f_out.write(struct.pack("<I", storage_fourcc))
|
f_out.write(struct.pack('<I', storage_fourcc))
|
||||||
|
|
||||||
# Write compact neighbors data AFTER storage fourcc
|
# Write compact neighbors data AFTER storage fourcc
|
||||||
write_list_vector(f_out, compact_neighbors_data, "i")
|
write_list_vector(f_out, compact_neighbors_data, 'i')
|
||||||
|
|
||||||
# Write storage data if not NULL (only after neighbors)
|
# Write storage data if not NULL (only after neighbors)
|
||||||
if storage_fourcc != NULL_INDEX_FOURCC and storage_data:
|
if storage_fourcc != NULL_INDEX_FOURCC and storage_data:
|
||||||
@@ -239,7 +183,6 @@ def write_compact_format(
|
|||||||
|
|
||||||
# --- Main Conversion Logic ---
|
# --- Main Conversion Logic ---
|
||||||
|
|
||||||
|
|
||||||
def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=True):
|
def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=True):
|
||||||
"""
|
"""
|
||||||
Converts an HNSW graph file to the CSR format.
|
Converts an HNSW graph file to the CSR format.
|
||||||
@@ -250,120 +193,94 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
|||||||
output_filename: Output CSR index file
|
output_filename: Output CSR index file
|
||||||
prune_embeddings: Whether to prune embedding storage (write NULL storage marker)
|
prune_embeddings: Whether to prune embedding storage (write NULL storage marker)
|
||||||
"""
|
"""
|
||||||
# Keep prints simple; rely on CI runner to flush output as needed
|
|
||||||
|
|
||||||
print(f"Starting conversion: {input_filename} -> {output_filename}")
|
print(f"Starting conversion: {input_filename} -> {output_filename}")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
original_hnsw_data = {}
|
original_hnsw_data = {}
|
||||||
neighbors_np = None # Initialize to allow check in finally block
|
neighbors_np = None # Initialize to allow check in finally block
|
||||||
try:
|
try:
|
||||||
with open(input_filename, "rb") as f_in, open(output_filename, "wb") as f_out:
|
with open(input_filename, 'rb') as f_in, open(output_filename, 'wb') as f_out:
|
||||||
|
|
||||||
# --- Read IndexHNSW FourCC and Header ---
|
# --- Read IndexHNSW FourCC and Header ---
|
||||||
print(f"[{time.time() - start_time:.2f}s] Reading Index HNSW header...")
|
print(f"[{time.time() - start_time:.2f}s] Reading Index HNSW header...")
|
||||||
# ... (Keep the header reading logic as before) ...
|
# ... (Keep the header reading logic as before) ...
|
||||||
hnsw_index_fourcc = read_struct(f_in, "<I")
|
hnsw_index_fourcc = read_struct(f_in, '<I')
|
||||||
if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS:
|
if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS:
|
||||||
print(
|
print(f"Error: Expected HNSW Index FourCC ({list(EXPECTED_HNSW_FOURCCS)}), got {hnsw_index_fourcc:08x}.", file=sys.stderr)
|
||||||
f"Error: Expected HNSW Index FourCC ({list(EXPECTED_HNSW_FOURCCS)}), got {hnsw_index_fourcc:08x}.",
|
return False
|
||||||
file=sys.stderr,
|
original_hnsw_data['index_fourcc'] = hnsw_index_fourcc
|
||||||
)
|
original_hnsw_data['d'] = read_struct(f_in, '<i')
|
||||||
return False
|
original_hnsw_data['ntotal'] = read_struct(f_in, '<q')
|
||||||
original_hnsw_data["index_fourcc"] = hnsw_index_fourcc
|
original_hnsw_data['dummy1'] = read_struct(f_in, '<q')
|
||||||
original_hnsw_data["d"] = read_struct(f_in, "<i")
|
original_hnsw_data['dummy2'] = read_struct(f_in, '<q')
|
||||||
original_hnsw_data["ntotal"] = read_struct(f_in, "<q")
|
original_hnsw_data['is_trained'] = read_struct(f_in, '?')
|
||||||
original_hnsw_data["dummy1"] = read_struct(f_in, "<q")
|
original_hnsw_data['metric_type'] = read_struct(f_in, '<i')
|
||||||
original_hnsw_data["dummy2"] = read_struct(f_in, "<q")
|
original_hnsw_data['metric_arg'] = 0.0
|
||||||
original_hnsw_data["is_trained"] = read_struct(f_in, "?")
|
if original_hnsw_data['metric_type'] > 1:
|
||||||
original_hnsw_data["metric_type"] = read_struct(f_in, "<i")
|
original_hnsw_data['metric_arg'] = read_struct(f_in, '<f')
|
||||||
original_hnsw_data["metric_arg"] = 0.0
|
print(f"[{time.time() - start_time:.2f}s] Header read: d={original_hnsw_data['d']}, ntotal={original_hnsw_data['ntotal']}")
|
||||||
if original_hnsw_data["metric_type"] > 1:
|
|
||||||
original_hnsw_data["metric_arg"] = read_struct(f_in, "<f")
|
|
||||||
print(
|
|
||||||
f"[{time.time() - start_time:.2f}s] Header read: d={original_hnsw_data['d']}, ntotal={original_hnsw_data['ntotal']}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# --- Read original HNSW struct data ---
|
# --- Read original HNSW struct data ---
|
||||||
print(f"[{time.time() - start_time:.2f}s] Reading HNSW struct vectors...")
|
print(f"[{time.time() - start_time:.2f}s] Reading HNSW struct vectors...")
|
||||||
assign_probas_np = read_numpy_vector(f_in, np.float64, "d")
|
assign_probas_np = read_numpy_vector(f_in, np.float64, 'd')
|
||||||
print(
|
print(f"[{time.time() - start_time:.2f}s] Read assign_probas ({assign_probas_np.size})")
|
||||||
f"[{time.time() - start_time:.2f}s] Read assign_probas ({assign_probas_np.size})"
|
|
||||||
)
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
cum_nneighbor_per_level_np = read_numpy_vector(f_in, np.int32, "i")
|
cum_nneighbor_per_level_np = read_numpy_vector(f_in, np.int32, 'i')
|
||||||
print(
|
print(f"[{time.time() - start_time:.2f}s] Read cum_nneighbor_per_level ({cum_nneighbor_per_level_np.size})")
|
||||||
f"[{time.time() - start_time:.2f}s] Read cum_nneighbor_per_level ({cum_nneighbor_per_level_np.size})"
|
|
||||||
)
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
levels_np = read_numpy_vector(f_in, np.int32, "i")
|
levels_np = read_numpy_vector(f_in, np.int32, 'i')
|
||||||
print(f"[{time.time() - start_time:.2f}s] Read levels ({levels_np.size})")
|
print(f"[{time.time() - start_time:.2f}s] Read levels ({levels_np.size})")
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
ntotal = len(levels_np)
|
ntotal = len(levels_np)
|
||||||
if ntotal != original_hnsw_data["ntotal"]:
|
if ntotal != original_hnsw_data['ntotal']:
|
||||||
print(
|
print(f"Warning: ntotal mismatch! Header says {original_hnsw_data['ntotal']}, levels vector size is {ntotal}. Using levels vector size.", file=sys.stderr)
|
||||||
f"Warning: ntotal mismatch! Header says {original_hnsw_data['ntotal']}, levels vector size is {ntotal}. Using levels vector size.",
|
original_hnsw_data['ntotal'] = ntotal
|
||||||
file=sys.stderr,
|
|
||||||
)
|
|
||||||
original_hnsw_data["ntotal"] = ntotal
|
|
||||||
|
|
||||||
# --- Check for compact format flag ---
|
# --- Check for compact format flag ---
|
||||||
print(f"[{time.time() - start_time:.2f}s] Probing for compact storage flag...")
|
print(f"[{time.time() - start_time:.2f}s] Probing for compact storage flag...")
|
||||||
pos_before_compact = f_in.tell()
|
pos_before_compact = f_in.tell()
|
||||||
try:
|
try:
|
||||||
is_compact_flag = read_struct(f_in, "<?")
|
is_compact_flag = read_struct(f_in, '<?')
|
||||||
print(f"[{time.time() - start_time:.2f}s] Found compact flag: {is_compact_flag}")
|
print(f"[{time.time() - start_time:.2f}s] Found compact flag: {is_compact_flag}")
|
||||||
|
|
||||||
if is_compact_flag:
|
if is_compact_flag:
|
||||||
# Input is already in compact format - read compact data
|
# Input is already in compact format - read compact data
|
||||||
print(
|
print(f"[{time.time() - start_time:.2f}s] Input is already in compact format, reading compact data...")
|
||||||
f"[{time.time() - start_time:.2f}s] Input is already in compact format, reading compact data..."
|
|
||||||
)
|
|
||||||
|
|
||||||
compact_level_ptr = read_numpy_vector(f_in, np.uint64, "Q")
|
compact_level_ptr = read_numpy_vector(f_in, np.uint64, 'Q')
|
||||||
print(
|
print(f"[{time.time() - start_time:.2f}s] Read compact_level_ptr ({compact_level_ptr.size})")
|
||||||
f"[{time.time() - start_time:.2f}s] Read compact_level_ptr ({compact_level_ptr.size})"
|
|
||||||
)
|
|
||||||
|
|
||||||
compact_node_offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
|
compact_node_offsets_np = read_numpy_vector(f_in, np.uint64, 'Q')
|
||||||
print(
|
print(f"[{time.time() - start_time:.2f}s] Read compact_node_offsets ({compact_node_offsets_np.size})")
|
||||||
f"[{time.time() - start_time:.2f}s] Read compact_node_offsets ({compact_node_offsets_np.size})"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Read scalar parameters
|
# Read scalar parameters
|
||||||
original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
|
original_hnsw_data['entry_point'] = read_struct(f_in, '<i')
|
||||||
original_hnsw_data["max_level"] = read_struct(f_in, "<i")
|
original_hnsw_data['max_level'] = read_struct(f_in, '<i')
|
||||||
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
|
original_hnsw_data['efConstruction'] = read_struct(f_in, '<i')
|
||||||
original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
|
original_hnsw_data['efSearch'] = read_struct(f_in, '<i')
|
||||||
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
|
original_hnsw_data['dummy_upper_beam'] = read_struct(f_in, '<i')
|
||||||
print(
|
print(f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})")
|
||||||
f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Read storage fourcc
|
# Read storage fourcc
|
||||||
storage_fourcc = read_struct(f_in, "<I")
|
storage_fourcc = read_struct(f_in, '<I')
|
||||||
print(
|
print(f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}")
|
||||||
f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if prune_embeddings and storage_fourcc != NULL_INDEX_FOURCC:
|
if prune_embeddings and storage_fourcc != NULL_INDEX_FOURCC:
|
||||||
# Read compact neighbors data
|
# Read compact neighbors data
|
||||||
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i")
|
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, 'i')
|
||||||
print(
|
print(f"[{time.time() - start_time:.2f}s] Read compact neighbors data ({compact_neighbors_data_np.size})")
|
||||||
f"[{time.time() - start_time:.2f}s] Read compact neighbors data ({compact_neighbors_data_np.size})"
|
|
||||||
)
|
|
||||||
compact_neighbors_data = compact_neighbors_data_np.tolist()
|
compact_neighbors_data = compact_neighbors_data_np.tolist()
|
||||||
del compact_neighbors_data_np
|
del compact_neighbors_data_np
|
||||||
|
|
||||||
# Skip storage data and write with NULL marker
|
# Skip storage data and write with NULL marker
|
||||||
print(
|
print(f"[{time.time() - start_time:.2f}s] Pruning embeddings: Writing NULL storage marker.")
|
||||||
f"[{time.time() - start_time:.2f}s] Pruning embeddings: Writing NULL storage marker."
|
|
||||||
)
|
|
||||||
storage_fourcc = NULL_INDEX_FOURCC
|
storage_fourcc = NULL_INDEX_FOURCC
|
||||||
elif not prune_embeddings:
|
elif not prune_embeddings:
|
||||||
# Read and preserve compact neighbors and storage
|
# Read and preserve compact neighbors and storage
|
||||||
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i")
|
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, 'i')
|
||||||
compact_neighbors_data = compact_neighbors_data_np.tolist()
|
compact_neighbors_data = compact_neighbors_data_np.tolist()
|
||||||
del compact_neighbors_data_np
|
del compact_neighbors_data_np
|
||||||
|
|
||||||
@@ -371,25 +288,16 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
|||||||
storage_data = f_in.read()
|
storage_data = f_in.read()
|
||||||
else:
|
else:
|
||||||
# Already pruned (NULL storage)
|
# Already pruned (NULL storage)
|
||||||
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i")
|
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, 'i')
|
||||||
compact_neighbors_data = compact_neighbors_data_np.tolist()
|
compact_neighbors_data = compact_neighbors_data_np.tolist()
|
||||||
del compact_neighbors_data_np
|
del compact_neighbors_data_np
|
||||||
storage_data = b""
|
storage_data = b''
|
||||||
|
|
||||||
# Write the updated compact format
|
# Write the updated compact format
|
||||||
print(f"[{time.time() - start_time:.2f}s] Writing updated compact format...")
|
print(f"[{time.time() - start_time:.2f}s] Writing updated compact format...")
|
||||||
write_compact_format(
|
write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np,
|
||||||
f_out,
|
levels_np, compact_level_ptr, compact_node_offsets_np,
|
||||||
original_hnsw_data,
|
compact_neighbors_data, storage_fourcc, storage_data if not prune_embeddings else b'')
|
||||||
assign_probas_np,
|
|
||||||
cum_nneighbor_per_level_np,
|
|
||||||
levels_np,
|
|
||||||
compact_level_ptr,
|
|
||||||
compact_node_offsets_np,
|
|
||||||
compact_neighbors_data,
|
|
||||||
storage_fourcc,
|
|
||||||
storage_data if not prune_embeddings else b"",
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"[{time.time() - start_time:.2f}s] Conversion complete.")
|
print(f"[{time.time() - start_time:.2f}s] Conversion complete.")
|
||||||
return True
|
return True
|
||||||
@@ -397,86 +305,63 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
|||||||
else:
|
else:
|
||||||
# is_compact=False, rewind and read original format
|
# is_compact=False, rewind and read original format
|
||||||
f_in.seek(pos_before_compact)
|
f_in.seek(pos_before_compact)
|
||||||
print(
|
print(f"[{time.time() - start_time:.2f}s] Compact flag is False, reading original format...")
|
||||||
f"[{time.time() - start_time:.2f}s] Compact flag is False, reading original format..."
|
|
||||||
)
|
|
||||||
|
|
||||||
except EOFError:
|
except EOFError:
|
||||||
# No compact flag found, assume original format
|
# No compact flag found, assume original format
|
||||||
f_in.seek(pos_before_compact)
|
f_in.seek(pos_before_compact)
|
||||||
print(
|
print(f"[{time.time() - start_time:.2f}s] No compact flag found, assuming original format...")
|
||||||
f"[{time.time() - start_time:.2f}s] No compact flag found, assuming original format..."
|
|
||||||
)
|
|
||||||
|
|
||||||
# --- Handle potential extra byte in original format (like C++ code) ---
|
# --- Handle potential extra byte in original format (like C++ code) ---
|
||||||
print(
|
print(f"[{time.time() - start_time:.2f}s] Probing for potential extra byte before non-compact offsets...")
|
||||||
f"[{time.time() - start_time:.2f}s] Probing for potential extra byte before non-compact offsets..."
|
|
||||||
)
|
|
||||||
pos_before_probe = f_in.tell()
|
pos_before_probe = f_in.tell()
|
||||||
try:
|
try:
|
||||||
suspected_flag = read_struct(f_in, "<B") # Read 1 byte
|
suspected_flag = read_struct(f_in, '<B') # Read 1 byte
|
||||||
if suspected_flag == 0x00:
|
if suspected_flag == 0x00:
|
||||||
print(
|
print(f"[{time.time() - start_time:.2f}s] Found and consumed an unexpected 0x00 byte.")
|
||||||
f"[{time.time() - start_time:.2f}s] Found and consumed an unexpected 0x00 byte."
|
|
||||||
)
|
|
||||||
elif suspected_flag == 0x01:
|
elif suspected_flag == 0x01:
|
||||||
print(
|
print(f"[{time.time() - start_time:.2f}s] ERROR: Found 0x01 but is_compact should be False")
|
||||||
f"[{time.time() - start_time:.2f}s] ERROR: Found 0x01 but is_compact should be False"
|
|
||||||
)
|
|
||||||
raise ValueError("Inconsistent compact flag state")
|
raise ValueError("Inconsistent compact flag state")
|
||||||
else:
|
else:
|
||||||
# Rewind - this byte is part of offsets data
|
# Rewind - this byte is part of offsets data
|
||||||
f_in.seek(pos_before_probe)
|
f_in.seek(pos_before_probe)
|
||||||
print(
|
print(f"[{time.time() - start_time:.2f}s] Rewound to original position (byte was 0x{suspected_flag:02x})")
|
||||||
f"[{time.time() - start_time:.2f}s] Rewound to original position (byte was 0x{suspected_flag:02x})"
|
|
||||||
)
|
|
||||||
except EOFError:
|
except EOFError:
|
||||||
f_in.seek(pos_before_probe)
|
f_in.seek(pos_before_probe)
|
||||||
print(
|
print(f"[{time.time() - start_time:.2f}s] No extra byte found (EOF), proceeding with offsets read")
|
||||||
f"[{time.time() - start_time:.2f}s] No extra byte found (EOF), proceeding with offsets read"
|
|
||||||
)
|
|
||||||
|
|
||||||
# --- Read original format data ---
|
# --- Read original format data ---
|
||||||
offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
|
offsets_np = read_numpy_vector(f_in, np.uint64, 'Q')
|
||||||
print(f"[{time.time() - start_time:.2f}s] Read offsets ({offsets_np.size})")
|
print(f"[{time.time() - start_time:.2f}s] Read offsets ({offsets_np.size})")
|
||||||
if len(offsets_np) != ntotal + 1:
|
if len(offsets_np) != ntotal + 1:
|
||||||
raise ValueError(
|
raise ValueError(f"Inconsistent offsets size: len(levels)={ntotal} but len(offsets)={len(offsets_np)}")
|
||||||
f"Inconsistent offsets size: len(levels)={ntotal} but len(offsets)={len(offsets_np)}"
|
|
||||||
)
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
print(f"[{time.time() - start_time:.2f}s] Attempting to read neighbors vector...")
|
print(f"[{time.time() - start_time:.2f}s] Attempting to read neighbors vector...")
|
||||||
neighbors_np = read_numpy_vector(f_in, np.int32, "i")
|
neighbors_np = read_numpy_vector(f_in, np.int32, 'i')
|
||||||
print(f"[{time.time() - start_time:.2f}s] Read neighbors ({neighbors_np.size})")
|
print(f"[{time.time() - start_time:.2f}s] Read neighbors ({neighbors_np.size})")
|
||||||
expected_neighbors_size = offsets_np[-1] if ntotal > 0 else 0
|
expected_neighbors_size = offsets_np[-1] if ntotal > 0 else 0
|
||||||
if neighbors_np.size != expected_neighbors_size:
|
if neighbors_np.size != expected_neighbors_size:
|
||||||
print(
|
print(f"Warning: neighbors vector size mismatch. Expected {expected_neighbors_size} based on offsets, got {neighbors_np.size}.")
|
||||||
f"Warning: neighbors vector size mismatch. Expected {expected_neighbors_size} based on offsets, got {neighbors_np.size}."
|
|
||||||
)
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
|
original_hnsw_data['entry_point'] = read_struct(f_in, '<i')
|
||||||
original_hnsw_data["max_level"] = read_struct(f_in, "<i")
|
original_hnsw_data['max_level'] = read_struct(f_in, '<i')
|
||||||
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
|
original_hnsw_data['efConstruction'] = read_struct(f_in, '<i')
|
||||||
original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
|
original_hnsw_data['efSearch'] = read_struct(f_in, '<i')
|
||||||
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
|
original_hnsw_data['dummy_upper_beam'] = read_struct(f_in, '<i')
|
||||||
print(
|
print(f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})")
|
||||||
f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})"
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"[{time.time() - start_time:.2f}s] Checking for storage data...")
|
print(f"[{time.time() - start_time:.2f}s] Checking for storage data...")
|
||||||
storage_fourcc = None
|
storage_fourcc = None
|
||||||
try:
|
try:
|
||||||
storage_fourcc = read_struct(f_in, "<I")
|
storage_fourcc = read_struct(f_in, '<I')
|
||||||
print(
|
print(f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}.")
|
||||||
f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}."
|
|
||||||
)
|
|
||||||
except EOFError:
|
except EOFError:
|
||||||
print(f"[{time.time() - start_time:.2f}s] No storage data found (EOF).")
|
print(f"[{time.time() - start_time:.2f}s] No storage data found (EOF).")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(
|
print(f"[{time.time() - start_time:.2f}s] Error reading potential storage data: {e}")
|
||||||
f"[{time.time() - start_time:.2f}s] Error reading potential storage data: {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# --- Perform Conversion ---
|
# --- Perform Conversion ---
|
||||||
print(f"[{time.time() - start_time:.2f}s] Converting to CSR format...")
|
print(f"[{time.time() - start_time:.2f}s] Converting to CSR format...")
|
||||||
@@ -488,21 +373,17 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
|||||||
|
|
||||||
current_level_ptr_idx = 0
|
current_level_ptr_idx = 0
|
||||||
current_data_idx = 0
|
current_data_idx = 0
|
||||||
total_valid_neighbors_counted = 0 # For validation
|
total_valid_neighbors_counted = 0 # For validation
|
||||||
|
|
||||||
# Optimize calculation by getting slices once per node if possible
|
# Optimize calculation by getting slices once per node if possible
|
||||||
for i in range(ntotal):
|
for i in range(ntotal):
|
||||||
if i > 0 and i % (ntotal // 100 or 1) == 0: # Log progress roughly every 1%
|
if i > 0 and i % (ntotal // 100 or 1) == 0: # Log progress roughly every 1%
|
||||||
progress = (i / ntotal) * 100
|
progress = (i / ntotal) * 100
|
||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
print(
|
print(f"\r[{elapsed:.2f}s] Converting node {i}/{ntotal} ({progress:.1f}%)...", end="")
|
||||||
f"\r[{elapsed:.2f}s] Converting node {i}/{ntotal} ({progress:.1f}%)...",
|
|
||||||
end="",
|
|
||||||
)
|
|
||||||
|
|
||||||
node_max_level = levels_np[i] - 1
|
node_max_level = levels_np[i] - 1
|
||||||
if node_max_level < -1:
|
if node_max_level < -1: node_max_level = -1
|
||||||
node_max_level = -1
|
|
||||||
|
|
||||||
node_ptr_start_index = current_level_ptr_idx
|
node_ptr_start_index = current_level_ptr_idx
|
||||||
compact_node_offsets_np[i] = node_ptr_start_index
|
compact_node_offsets_np[i] = node_ptr_start_index
|
||||||
@@ -513,17 +394,13 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
|||||||
for level in range(node_max_level + 1):
|
for level in range(node_max_level + 1):
|
||||||
compact_level_ptr.append(current_data_idx)
|
compact_level_ptr.append(current_data_idx)
|
||||||
|
|
||||||
begin_orig_np = original_offset_start + get_cum_neighbors(
|
begin_orig_np = original_offset_start + get_cum_neighbors(cum_nneighbor_per_level_np, level)
|
||||||
cum_nneighbor_per_level_np, level
|
end_orig_np = original_offset_start + get_cum_neighbors(cum_nneighbor_per_level_np, level + 1)
|
||||||
)
|
|
||||||
end_orig_np = original_offset_start + get_cum_neighbors(
|
|
||||||
cum_nneighbor_per_level_np, level + 1
|
|
||||||
)
|
|
||||||
|
|
||||||
begin_orig = int(begin_orig_np)
|
begin_orig = int(begin_orig_np)
|
||||||
end_orig = int(end_orig_np)
|
end_orig = int(end_orig_np)
|
||||||
|
|
||||||
neighbors_len = len(neighbors_np) # Cache length
|
neighbors_len = len(neighbors_np) # Cache length
|
||||||
begin_orig = min(max(0, begin_orig), neighbors_len)
|
begin_orig = min(max(0, begin_orig), neighbors_len)
|
||||||
end_orig = min(max(begin_orig, end_orig), neighbors_len)
|
end_orig = min(max(begin_orig, end_orig), neighbors_len)
|
||||||
|
|
||||||
@@ -536,116 +413,82 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
|||||||
|
|
||||||
if num_valid > 0:
|
if num_valid > 0:
|
||||||
# Append valid neighbors
|
# Append valid neighbors
|
||||||
compact_neighbors_data.extend(
|
compact_neighbors_data.extend(level_neighbors_slice[valid_neighbors_mask])
|
||||||
level_neighbors_slice[valid_neighbors_mask]
|
|
||||||
)
|
|
||||||
current_data_idx += num_valid
|
current_data_idx += num_valid
|
||||||
total_valid_neighbors_counted += num_valid
|
total_valid_neighbors_counted += num_valid
|
||||||
|
|
||||||
|
|
||||||
compact_level_ptr.append(current_data_idx)
|
compact_level_ptr.append(current_data_idx)
|
||||||
current_level_ptr_idx += num_pointers_expected
|
current_level_ptr_idx += num_pointers_expected
|
||||||
|
|
||||||
compact_node_offsets_np[ntotal] = current_level_ptr_idx
|
compact_node_offsets_np[ntotal] = current_level_ptr_idx
|
||||||
print(
|
print(f"\r[{time.time() - start_time:.2f}s] Conversion loop finished. ") # Clear progress line
|
||||||
f"\r[{time.time() - start_time:.2f}s] Conversion loop finished. "
|
|
||||||
) # Clear progress line
|
|
||||||
|
|
||||||
# --- Validation Checks ---
|
# --- Validation Checks ---
|
||||||
print(f"[{time.time() - start_time:.2f}s] Running validation checks...")
|
print(f"[{time.time() - start_time:.2f}s] Running validation checks...")
|
||||||
valid_check_passed = True
|
valid_check_passed = True
|
||||||
# Check 1: Total valid neighbors count
|
# Check 1: Total valid neighbors count
|
||||||
print(" Checking total valid neighbor count...")
|
print(f" Checking total valid neighbor count...")
|
||||||
expected_valid_count = np.sum(neighbors_np >= 0)
|
expected_valid_count = np.sum(neighbors_np >= 0)
|
||||||
if total_valid_neighbors_counted != len(compact_neighbors_data):
|
if total_valid_neighbors_counted != len(compact_neighbors_data):
|
||||||
print(
|
print(f"Error: Mismatch between counted valid neighbors ({total_valid_neighbors_counted}) and final compact_data size ({len(compact_neighbors_data)})!", file=sys.stderr)
|
||||||
f"Error: Mismatch between counted valid neighbors ({total_valid_neighbors_counted}) and final compact_data size ({len(compact_neighbors_data)})!",
|
valid_check_passed = False
|
||||||
file=sys.stderr,
|
|
||||||
)
|
|
||||||
valid_check_passed = False
|
|
||||||
if expected_valid_count != len(compact_neighbors_data):
|
if expected_valid_count != len(compact_neighbors_data):
|
||||||
print(
|
print(f"Error: Mismatch between NumPy count of valid neighbors ({expected_valid_count}) and final compact_data size ({len(compact_neighbors_data)})!", file=sys.stderr)
|
||||||
f"Error: Mismatch between NumPy count of valid neighbors ({expected_valid_count}) and final compact_data size ({len(compact_neighbors_data)})!",
|
valid_check_passed = False
|
||||||
file=sys.stderr,
|
|
||||||
)
|
|
||||||
valid_check_passed = False
|
|
||||||
else:
|
else:
|
||||||
print(f" OK: Total valid neighbors = {len(compact_neighbors_data)}")
|
print(f" OK: Total valid neighbors = {len(compact_neighbors_data)}")
|
||||||
|
|
||||||
# Check 2: Final pointer indices consistency
|
# Check 2: Final pointer indices consistency
|
||||||
print(" Checking final pointer indices...")
|
print(f" Checking final pointer indices...")
|
||||||
if compact_node_offsets_np[ntotal] != len(compact_level_ptr):
|
if compact_node_offsets_np[ntotal] != len(compact_level_ptr):
|
||||||
print(
|
print(f"Error: Final node offset ({compact_node_offsets_np[ntotal]}) doesn't match level_ptr size ({len(compact_level_ptr)})!", file=sys.stderr)
|
||||||
f"Error: Final node offset ({compact_node_offsets_np[ntotal]}) doesn't match level_ptr size ({len(compact_level_ptr)})!",
|
valid_check_passed = False
|
||||||
file=sys.stderr,
|
if (len(compact_level_ptr) > 0 and compact_level_ptr[-1] != len(compact_neighbors_data)) or \
|
||||||
)
|
(len(compact_level_ptr) == 0 and len(compact_neighbors_data) != 0):
|
||||||
valid_check_passed = False
|
last_ptr = compact_level_ptr[-1] if len(compact_level_ptr) > 0 else -1
|
||||||
if (
|
print(f"Error: Last level pointer ({last_ptr}) doesn't match compact_data size ({len(compact_neighbors_data)})!", file=sys.stderr)
|
||||||
len(compact_level_ptr) > 0 and compact_level_ptr[-1] != len(compact_neighbors_data)
|
valid_check_passed = False
|
||||||
) or (len(compact_level_ptr) == 0 and len(compact_neighbors_data) != 0):
|
|
||||||
last_ptr = compact_level_ptr[-1] if len(compact_level_ptr) > 0 else -1
|
|
||||||
print(
|
|
||||||
f"Error: Last level pointer ({last_ptr}) doesn't match compact_data size ({len(compact_neighbors_data)})!",
|
|
||||||
file=sys.stderr,
|
|
||||||
)
|
|
||||||
valid_check_passed = False
|
|
||||||
else:
|
else:
|
||||||
print(" OK: Final pointers match data size.")
|
print(f" OK: Final pointers match data size.")
|
||||||
|
|
||||||
if not valid_check_passed:
|
if not valid_check_passed:
|
||||||
print(
|
print("Error: Validation checks failed. Output file might be incorrect.", file=sys.stderr)
|
||||||
"Error: Validation checks failed. Output file might be incorrect.",
|
|
||||||
file=sys.stderr,
|
|
||||||
)
|
|
||||||
# Optional: Exit here if validation fails
|
# Optional: Exit here if validation fails
|
||||||
# return False
|
# return False
|
||||||
|
|
||||||
# --- Explicitly delete large intermediate arrays ---
|
# --- Explicitly delete large intermediate arrays ---
|
||||||
print(
|
print(f"[{time.time() - start_time:.2f}s] Deleting original neighbors and offsets arrays...")
|
||||||
f"[{time.time() - start_time:.2f}s] Deleting original neighbors and offsets arrays..."
|
|
||||||
)
|
|
||||||
del neighbors_np
|
del neighbors_np
|
||||||
del offsets_np
|
del offsets_np
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
print(
|
print(f" CSR Stats: |data|={len(compact_neighbors_data)}, |level_ptr|={len(compact_level_ptr)}")
|
||||||
f" CSR Stats: |data|={len(compact_neighbors_data)}, |level_ptr|={len(compact_level_ptr)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# --- Write CSR HNSW graph data using unified function ---
|
# --- Write CSR HNSW graph data using unified function ---
|
||||||
print(
|
print(f"[{time.time() - start_time:.2f}s] Writing CSR HNSW graph data in FAISS-compatible order...")
|
||||||
f"[{time.time() - start_time:.2f}s] Writing CSR HNSW graph data in FAISS-compatible order..."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Determine storage fourcc and data based on prune_embeddings
|
# Determine storage fourcc and data based on prune_embeddings
|
||||||
if prune_embeddings:
|
if prune_embeddings:
|
||||||
print(" Pruning embeddings: Writing NULL storage marker.")
|
print(f" Pruning embeddings: Writing NULL storage marker.")
|
||||||
output_storage_fourcc = NULL_INDEX_FOURCC
|
output_storage_fourcc = NULL_INDEX_FOURCC
|
||||||
storage_data = b""
|
storage_data = b''
|
||||||
else:
|
else:
|
||||||
# Keep embeddings - read and preserve original storage data
|
# Keep embeddings - read and preserve original storage data
|
||||||
if storage_fourcc and storage_fourcc != NULL_INDEX_FOURCC:
|
if storage_fourcc and storage_fourcc != NULL_INDEX_FOURCC:
|
||||||
print(" Preserving embeddings: Reading original storage data...")
|
print(f" Preserving embeddings: Reading original storage data...")
|
||||||
storage_data = f_in.read() # Read remaining storage data
|
storage_data = f_in.read() # Read remaining storage data
|
||||||
output_storage_fourcc = storage_fourcc
|
output_storage_fourcc = storage_fourcc
|
||||||
print(f" Read {len(storage_data)} bytes of storage data")
|
print(f" Read {len(storage_data)} bytes of storage data")
|
||||||
else:
|
else:
|
||||||
print(" No embeddings found in original file (NULL storage)")
|
print(f" No embeddings found in original file (NULL storage)")
|
||||||
output_storage_fourcc = NULL_INDEX_FOURCC
|
output_storage_fourcc = NULL_INDEX_FOURCC
|
||||||
storage_data = b""
|
storage_data = b''
|
||||||
|
|
||||||
# Use the unified write function
|
# Use the unified write function
|
||||||
write_compact_format(
|
write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np,
|
||||||
f_out,
|
levels_np, compact_level_ptr, compact_node_offsets_np,
|
||||||
original_hnsw_data,
|
compact_neighbors_data, output_storage_fourcc, storage_data)
|
||||||
assign_probas_np,
|
|
||||||
cum_nneighbor_per_level_np,
|
|
||||||
levels_np,
|
|
||||||
compact_level_ptr,
|
|
||||||
compact_node_offsets_np,
|
|
||||||
compact_neighbors_data,
|
|
||||||
output_storage_fourcc,
|
|
||||||
storage_data,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Clean up memory
|
# Clean up memory
|
||||||
del assign_probas_np, cum_nneighbor_per_level_np, levels_np
|
del assign_probas_np, cum_nneighbor_per_level_np, levels_np
|
||||||
@@ -660,66 +503,40 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
|||||||
print(f"Error: Input file not found: {input_filename}", file=sys.stderr)
|
print(f"Error: Input file not found: {input_filename}", file=sys.stderr)
|
||||||
return False
|
return False
|
||||||
except MemoryError as e:
|
except MemoryError as e:
|
||||||
print(
|
print(f"\nFatal MemoryError during conversion: {e}. Insufficient RAM.", file=sys.stderr)
|
||||||
f"\nFatal MemoryError during conversion: {e}. Insufficient RAM.",
|
# Clean up potentially partially written output file?
|
||||||
file=sys.stderr,
|
try: os.remove(output_filename)
|
||||||
)
|
except OSError: pass
|
||||||
# Clean up potentially partially written output file?
|
return False
|
||||||
try:
|
|
||||||
os.remove(output_filename)
|
|
||||||
except OSError:
|
|
||||||
pass
|
|
||||||
return False
|
|
||||||
except EOFError as e:
|
except EOFError as e:
|
||||||
print(
|
print(f"Error: Reached end of file unexpectedly reading {input_filename}. {e}", file=sys.stderr)
|
||||||
f"Error: Reached end of file unexpectedly reading {input_filename}. {e}",
|
try: os.remove(output_filename)
|
||||||
file=sys.stderr,
|
except OSError: pass
|
||||||
)
|
|
||||||
try:
|
|
||||||
os.remove(output_filename)
|
|
||||||
except OSError:
|
|
||||||
pass
|
|
||||||
return False
|
return False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"An unexpected error occurred during conversion: {e}", file=sys.stderr)
|
print(f"An unexpected error occurred during conversion: {e}", file=sys.stderr)
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
try:
|
try:
|
||||||
os.remove(output_filename)
|
os.remove(output_filename)
|
||||||
except OSError:
|
except OSError: pass
|
||||||
pass
|
|
||||||
return False
|
return False
|
||||||
# Ensure neighbors_np is deleted even if an error occurs after its allocation
|
# Ensure neighbors_np is deleted even if an error occurs after its allocation
|
||||||
finally:
|
finally:
|
||||||
try:
|
if 'neighbors_np' in locals() and neighbors_np is not None:
|
||||||
if "neighbors_np" in locals() and neighbors_np is not None:
|
del neighbors_np
|
||||||
del neighbors_np
|
gc.collect()
|
||||||
gc.collect()
|
|
||||||
except NameError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# --- Script Execution ---
|
# --- Script Execution ---
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(description="Convert a Faiss IndexHNSWFlat file to a CSR-based HNSW graph file.")
|
||||||
description="Convert a Faiss IndexHNSWFlat file to a CSR-based HNSW graph file."
|
|
||||||
)
|
|
||||||
parser.add_argument("input_index_file", help="Path to the input IndexHNSWFlat file")
|
parser.add_argument("input_index_file", help="Path to the input IndexHNSWFlat file")
|
||||||
parser.add_argument(
|
parser.add_argument("output_csr_graph_file", help="Path to write the output CSR HNSW graph file")
|
||||||
"output_csr_graph_file", help="Path to write the output CSR HNSW graph file"
|
parser.add_argument("--prune-embeddings", action="store_true", default=True,
|
||||||
)
|
help="Prune embedding storage (write NULL storage marker)")
|
||||||
parser.add_argument(
|
parser.add_argument("--keep-embeddings", action="store_true",
|
||||||
"--prune-embeddings",
|
help="Keep embedding storage (overrides --prune-embeddings)")
|
||||||
action="store_true",
|
|
||||||
default=True,
|
|
||||||
help="Prune embedding storage (write NULL storage marker)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--keep-embeddings",
|
|
||||||
action="store_true",
|
|
||||||
help="Keep embedding storage (overrides --prune-embeddings)",
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@@ -728,12 +545,10 @@ if __name__ == "__main__":
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
if os.path.abspath(args.input_index_file) == os.path.abspath(args.output_csr_graph_file):
|
if os.path.abspath(args.input_index_file) == os.path.abspath(args.output_csr_graph_file):
|
||||||
print("Error: Input and output filenames cannot be the same.", file=sys.stderr)
|
print(f"Error: Input and output filenames cannot be the same.", file=sys.stderr)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
prune_embeddings = args.prune_embeddings and not args.keep_embeddings
|
prune_embeddings = args.prune_embeddings and not args.keep_embeddings
|
||||||
success = convert_hnsw_graph_to_csr(
|
success = convert_hnsw_graph_to_csr(args.input_index_file, args.output_csr_graph_file, prune_embeddings)
|
||||||
args.input_index_file, args.output_csr_graph_file, prune_embeddings
|
|
||||||
)
|
|
||||||
if not success:
|
if not success:
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
@@ -1,19 +1,19 @@
|
|||||||
import logging
|
|
||||||
import os
|
|
||||||
import shutil
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Literal, Optional
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Any, List, Literal, Optional
|
||||||
|
import shutil
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from leann.searcher_base import BaseSearcher
|
||||||
|
from .convert_to_csr import convert_hnsw_graph_to_csr
|
||||||
|
|
||||||
|
from leann.registry import register_backend
|
||||||
from leann.interface import (
|
from leann.interface import (
|
||||||
LeannBackendBuilderInterface,
|
|
||||||
LeannBackendFactoryInterface,
|
LeannBackendFactoryInterface,
|
||||||
|
LeannBackendBuilderInterface,
|
||||||
LeannBackendSearcherInterface,
|
LeannBackendSearcherInterface,
|
||||||
)
|
)
|
||||||
from leann.registry import register_backend
|
|
||||||
from leann.searcher_base import BaseSearcher
|
|
||||||
|
|
||||||
from .convert_to_csr import convert_hnsw_graph_to_csr
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -28,12 +28,6 @@ def get_metric_map():
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def normalize_l2(data: np.ndarray) -> np.ndarray:
|
|
||||||
norms = np.linalg.norm(data, axis=1, keepdims=True)
|
|
||||||
norms[norms == 0] = 1 # Avoid division by zero
|
|
||||||
return data / norms
|
|
||||||
|
|
||||||
|
|
||||||
@register_backend("hnsw")
|
@register_backend("hnsw")
|
||||||
class HNSWBackend(LeannBackendFactoryInterface):
|
class HNSWBackend(LeannBackendFactoryInterface):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -54,15 +48,8 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
|||||||
self.efConstruction = self.build_params.setdefault("efConstruction", 200)
|
self.efConstruction = self.build_params.setdefault("efConstruction", 200)
|
||||||
self.distance_metric = self.build_params.setdefault("distance_metric", "mips")
|
self.distance_metric = self.build_params.setdefault("distance_metric", "mips")
|
||||||
self.dimensions = self.build_params.get("dimensions")
|
self.dimensions = self.build_params.get("dimensions")
|
||||||
if not self.is_recompute and self.is_compact:
|
|
||||||
# Auto-correct: non-recompute requires non-compact storage for HNSW
|
|
||||||
logger.warning(
|
|
||||||
"is_recompute=False requires non-compact HNSW. Forcing is_compact=False."
|
|
||||||
)
|
|
||||||
self.is_compact = False
|
|
||||||
self.build_params["is_compact"] = False
|
|
||||||
|
|
||||||
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
|
def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
|
||||||
from . import faiss # type: ignore
|
from . import faiss # type: ignore
|
||||||
|
|
||||||
path = Path(index_path)
|
path = Path(index_path)
|
||||||
@@ -83,7 +70,7 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
|||||||
index.hnsw.efConstruction = self.efConstruction
|
index.hnsw.efConstruction = self.efConstruction
|
||||||
|
|
||||||
if self.distance_metric.lower() == "cosine":
|
if self.distance_metric.lower() == "cosine":
|
||||||
data = normalize_l2(data)
|
faiss.normalize_L2(data)
|
||||||
|
|
||||||
index.add(data.shape[0], faiss.swig_ptr(data))
|
index.add(data.shape[0], faiss.swig_ptr(data))
|
||||||
index_file = index_dir / f"{index_prefix}.index"
|
index_file = index_dir / f"{index_prefix}.index"
|
||||||
@@ -105,15 +92,19 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
|||||||
|
|
||||||
if success:
|
if success:
|
||||||
logger.info("✅ CSR conversion successful.")
|
logger.info("✅ CSR conversion successful.")
|
||||||
# index_file_old = index_file.with_suffix(".old")
|
index_file_old = index_file.with_suffix(".old")
|
||||||
# shutil.move(str(index_file), str(index_file_old))
|
shutil.move(str(index_file), str(index_file_old))
|
||||||
shutil.move(str(csr_temp_file), str(index_file))
|
shutil.move(str(csr_temp_file), str(index_file))
|
||||||
logger.info(f"INFO: Replaced original index with {mode_str} version at '{index_file}'")
|
logger.info(
|
||||||
|
f"INFO: Replaced original index with {mode_str} version at '{index_file}'"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Clean up and fail fast
|
# Clean up and fail fast
|
||||||
if csr_temp_file.exists():
|
if csr_temp_file.exists():
|
||||||
os.remove(csr_temp_file)
|
os.remove(csr_temp_file)
|
||||||
raise RuntimeError("CSR conversion failed - cannot proceed with compact format")
|
raise RuntimeError(
|
||||||
|
"CSR conversion failed - cannot proceed with compact format"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class HNSWSearcher(BaseSearcher):
|
class HNSWSearcher(BaseSearcher):
|
||||||
@@ -125,9 +116,7 @@ class HNSWSearcher(BaseSearcher):
|
|||||||
)
|
)
|
||||||
from . import faiss # type: ignore
|
from . import faiss # type: ignore
|
||||||
|
|
||||||
self.distance_metric = (
|
self.distance_metric = self.meta.get("distance_metric", "mips").lower()
|
||||||
self.meta.get("backend_kwargs", {}).get("distance_metric", "mips").lower()
|
|
||||||
)
|
|
||||||
metric_enum = get_metric_map().get(self.distance_metric)
|
metric_enum = get_metric_map().get(self.distance_metric)
|
||||||
if metric_enum is None:
|
if metric_enum is None:
|
||||||
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
|
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
|
||||||
@@ -161,7 +150,7 @@ class HNSWSearcher(BaseSearcher):
|
|||||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
batch_size: int = 0,
|
batch_size: int = 0,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Search for nearest neighbors using HNSW index.
|
Search for nearest neighbors using HNSW index.
|
||||||
|
|
||||||
@@ -185,36 +174,28 @@ class HNSWSearcher(BaseSearcher):
|
|||||||
"""
|
"""
|
||||||
from . import faiss # type: ignore
|
from . import faiss # type: ignore
|
||||||
|
|
||||||
if not recompute_embeddings and self.is_pruned:
|
if not recompute_embeddings:
|
||||||
raise RuntimeError(
|
if self.is_pruned:
|
||||||
"Recompute is required for pruned/compact HNSW index. "
|
raise RuntimeError("Recompute is required for pruned index.")
|
||||||
"Re-run search with --recompute, or rebuild with --no-recompute and --no-compact."
|
|
||||||
)
|
|
||||||
if recompute_embeddings:
|
if recompute_embeddings:
|
||||||
if zmq_port is None:
|
if zmq_port is None:
|
||||||
raise ValueError("zmq_port must be provided if recompute_embeddings is True")
|
raise ValueError(
|
||||||
|
"zmq_port must be provided if recompute_embeddings is True"
|
||||||
|
)
|
||||||
|
|
||||||
if query.dtype != np.float32:
|
if query.dtype != np.float32:
|
||||||
query = query.astype(np.float32)
|
query = query.astype(np.float32)
|
||||||
if self.distance_metric == "cosine":
|
if self.distance_metric == "cosine":
|
||||||
query = normalize_l2(query)
|
faiss.normalize_L2(query)
|
||||||
|
|
||||||
params = faiss.SearchParametersHNSW()
|
params = faiss.SearchParametersHNSW()
|
||||||
if zmq_port is not None:
|
if zmq_port is not None:
|
||||||
params.zmq_port = zmq_port # C++ code won't use this if recompute_embeddings is False
|
params.zmq_port = (
|
||||||
|
zmq_port # C++ code won't use this if recompute_embeddings is False
|
||||||
|
)
|
||||||
params.efSearch = complexity
|
params.efSearch = complexity
|
||||||
params.beam_size = beam_width
|
params.beam_size = beam_width
|
||||||
|
|
||||||
# For OpenAI embeddings with cosine distance, disable relative distance check
|
|
||||||
# This prevents early termination when all scores are in a narrow range
|
|
||||||
embedding_model = self.meta.get("embedding_model", "").lower()
|
|
||||||
if self.distance_metric == "cosine" and any(
|
|
||||||
openai_model in embedding_model for openai_model in ["text-embedding", "openai"]
|
|
||||||
):
|
|
||||||
params.check_relative_distance = False
|
|
||||||
else:
|
|
||||||
params.check_relative_distance = True
|
|
||||||
|
|
||||||
# PQ pruning: direct mapping to HNSW's pq_pruning_ratio
|
# PQ pruning: direct mapping to HNSW's pq_pruning_ratio
|
||||||
params.pq_pruning_ratio = prune_ratio
|
params.pq_pruning_ratio = prune_ratio
|
||||||
|
|
||||||
@@ -224,7 +205,9 @@ class HNSWSearcher(BaseSearcher):
|
|||||||
params.send_neigh_times_ratio = 0.0
|
params.send_neigh_times_ratio = 0.0
|
||||||
elif pruning_strategy == "proportional":
|
elif pruning_strategy == "proportional":
|
||||||
params.local_prune = False
|
params.local_prune = False
|
||||||
params.send_neigh_times_ratio = 1.0 # Any value > 1e-6 triggers proportional mode
|
params.send_neigh_times_ratio = (
|
||||||
|
1.0 # Any value > 1e-6 triggers proportional mode
|
||||||
|
)
|
||||||
else: # "global"
|
else: # "global"
|
||||||
params.local_prune = False
|
params.local_prune = False
|
||||||
params.send_neigh_times_ratio = 0.0
|
params.send_neigh_times_ratio = 0.0
|
||||||
@@ -245,6 +228,8 @@ class HNSWSearcher(BaseSearcher):
|
|||||||
params,
|
params,
|
||||||
)
|
)
|
||||||
|
|
||||||
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
|
string_labels = [
|
||||||
|
[str(int_label) for int_label in batch_labels] for batch_labels in labels
|
||||||
|
]
|
||||||
|
|
||||||
return {"labels": string_labels, "distances": distances}
|
return {"labels": string_labels, "distances": distances}
|
||||||
|
|||||||
@@ -3,18 +3,17 @@ HNSW-specific embedding server
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
import os
|
||||||
|
import zmq
|
||||||
|
import numpy as np
|
||||||
|
import msgpack
|
||||||
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
import sys
|
||||||
import msgpack
|
import logging
|
||||||
import numpy as np
|
|
||||||
import zmq
|
|
||||||
|
|
||||||
# Set up logging based on environment variable
|
# Set up logging based on environment variable
|
||||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||||
@@ -53,8 +52,8 @@ def create_hnsw_embedding_server(
|
|||||||
sys.path.insert(0, str(leann_core_path))
|
sys.path.insert(0, str(leann_core_path))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from leann.api import PassageManager
|
|
||||||
from leann.embedding_compute import compute_embeddings
|
from leann.embedding_compute import compute_embeddings
|
||||||
|
from leann.api import PassageManager
|
||||||
|
|
||||||
logger.info("Successfully imported unified embedding computation module")
|
logger.info("Successfully imported unified embedding computation module")
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
@@ -79,320 +78,206 @@ def create_hnsw_embedding_server(
|
|||||||
raise ValueError("Only metadata files (.meta.json) are supported")
|
raise ValueError("Only metadata files (.meta.json) are supported")
|
||||||
|
|
||||||
# Load metadata to get passage sources
|
# Load metadata to get passage sources
|
||||||
with open(passages_file) as f:
|
with open(passages_file, "r") as f:
|
||||||
meta = json.load(f)
|
meta = json.load(f)
|
||||||
|
|
||||||
# Let PassageManager handle path resolution uniformly. It supports fallback order:
|
passages = PassageManager(meta["passage_sources"])
|
||||||
# 1) path/index_path; 2) *_relative; 3) standard siblings next to meta
|
|
||||||
passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file)
|
|
||||||
# Dimension from metadata for shaping responses
|
|
||||||
try:
|
|
||||||
embedding_dim: int = int(meta.get("dimensions", 0))
|
|
||||||
except Exception:
|
|
||||||
embedding_dim = 0
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
||||||
)
|
)
|
||||||
|
|
||||||
# (legacy ZMQ thread removed; using shutdown-capable server only)
|
def zmq_server_thread():
|
||||||
|
"""ZMQ server thread"""
|
||||||
def zmq_server_thread_with_shutdown(shutdown_event):
|
|
||||||
"""ZMQ server thread that respects shutdown signal.
|
|
||||||
|
|
||||||
Creates its own REP socket bound to zmq_port and polls with timeouts
|
|
||||||
to allow graceful shutdown.
|
|
||||||
"""
|
|
||||||
logger.info("ZMQ server thread started with shutdown support")
|
|
||||||
|
|
||||||
context = zmq.Context()
|
context = zmq.Context()
|
||||||
rep_socket = context.socket(zmq.REP)
|
socket = context.socket(zmq.REP)
|
||||||
rep_socket.bind(f"tcp://*:{zmq_port}")
|
socket.bind(f"tcp://*:{zmq_port}")
|
||||||
logger.info(f"HNSW ZMQ REP server listening on port {zmq_port}")
|
logger.info(f"HNSW ZMQ server listening on port {zmq_port}")
|
||||||
rep_socket.setsockopt(zmq.RCVTIMEO, 1000)
|
|
||||||
# Keep sends from blocking during shutdown; fail fast and drop on close
|
|
||||||
rep_socket.setsockopt(zmq.SNDTIMEO, 1000)
|
|
||||||
rep_socket.setsockopt(zmq.LINGER, 0)
|
|
||||||
|
|
||||||
# Track last request type/length for shape-correct fallbacks
|
socket.setsockopt(zmq.RCVTIMEO, 300000)
|
||||||
last_request_type = "unknown" # 'text' | 'distance' | 'embedding' | 'unknown'
|
socket.setsockopt(zmq.SNDTIMEO, 300000)
|
||||||
last_request_length = 0
|
|
||||||
|
|
||||||
try:
|
while True:
|
||||||
while not shutdown_event.is_set():
|
try:
|
||||||
try:
|
message_bytes = socket.recv()
|
||||||
e2e_start = time.time()
|
logger.debug(f"Received ZMQ request of size {len(message_bytes)} bytes")
|
||||||
logger.debug("🔍 Waiting for ZMQ message...")
|
|
||||||
request_bytes = rep_socket.recv()
|
|
||||||
|
|
||||||
# Rest of the processing logic (same as original)
|
e2e_start = time.time()
|
||||||
request = msgpack.unpackb(request_bytes)
|
request_payload = msgpack.unpackb(message_bytes)
|
||||||
|
|
||||||
if len(request) == 1 and request[0] == "__QUERY_MODEL__":
|
# Handle direct text embedding request
|
||||||
response_bytes = msgpack.packb([model_name])
|
if isinstance(request_payload, list) and len(request_payload) > 0:
|
||||||
rep_socket.send(response_bytes)
|
# Check if this is a direct text request (list of strings)
|
||||||
continue
|
if all(isinstance(item, str) for item in request_payload):
|
||||||
|
logger.info(
|
||||||
|
f"Processing direct text embedding request for {len(request_payload)} texts in {embedding_mode} mode"
|
||||||
|
)
|
||||||
|
|
||||||
# Handle direct text embedding request
|
# Use unified embedding computation (now with model caching)
|
||||||
if (
|
embeddings = compute_embeddings(
|
||||||
isinstance(request, list)
|
request_payload, model_name, mode=embedding_mode
|
||||||
and request
|
)
|
||||||
and all(isinstance(item, str) for item in request)
|
|
||||||
):
|
response = embeddings.tolist()
|
||||||
last_request_type = "text"
|
socket.send(msgpack.packb(response))
|
||||||
last_request_length = len(request)
|
|
||||||
embeddings = compute_embeddings(request, model_name, mode=embedding_mode)
|
|
||||||
rep_socket.send(msgpack.packb(embeddings.tolist()))
|
|
||||||
e2e_end = time.time()
|
e2e_end = time.time()
|
||||||
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
|
logger.info(
|
||||||
|
f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s"
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Handle distance calculation request: [[ids], [query_vector]]
|
# Handle distance calculation requests
|
||||||
if (
|
if (
|
||||||
isinstance(request, list)
|
isinstance(request_payload, list)
|
||||||
and len(request) == 2
|
and len(request_payload) == 2
|
||||||
and isinstance(request[0], list)
|
and isinstance(request_payload[0], list)
|
||||||
and isinstance(request[1], list)
|
and isinstance(request_payload[1], list)
|
||||||
):
|
):
|
||||||
node_ids = request[0]
|
node_ids = request_payload[0]
|
||||||
# Handle nested [[ids]] shape defensively
|
query_vector = np.array(request_payload[1], dtype=np.float32)
|
||||||
if len(node_ids) == 1 and isinstance(node_ids[0], list):
|
|
||||||
node_ids = node_ids[0]
|
|
||||||
query_vector = np.array(request[1], dtype=np.float32)
|
|
||||||
last_request_type = "distance"
|
|
||||||
last_request_length = len(node_ids)
|
|
||||||
|
|
||||||
logger.debug("Distance calculation request received")
|
logger.debug("Distance calculation request received")
|
||||||
logger.debug(f" Node IDs: {node_ids}")
|
logger.debug(f" Node IDs: {node_ids}")
|
||||||
logger.debug(f" Query vector dim: {len(query_vector)}")
|
logger.debug(f" Query vector dim: {len(query_vector)}")
|
||||||
|
|
||||||
# Gather texts for found ids
|
# Get embeddings for node IDs
|
||||||
texts: list[str] = []
|
texts = []
|
||||||
found_indices: list[int] = []
|
for nid in node_ids:
|
||||||
for idx, nid in enumerate(node_ids):
|
|
||||||
try:
|
|
||||||
passage_data = passages.get_passage(str(nid))
|
|
||||||
txt = passage_data.get("text", "")
|
|
||||||
if isinstance(txt, str) and len(txt) > 0:
|
|
||||||
texts.append(txt)
|
|
||||||
found_indices.append(idx)
|
|
||||||
else:
|
|
||||||
logger.error(f"Empty text for passage ID {nid}")
|
|
||||||
except KeyError:
|
|
||||||
logger.error(f"Passage ID {nid} not found")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
|
||||||
|
|
||||||
# Prepare full-length response with large sentinel values
|
|
||||||
large_distance = 1e9
|
|
||||||
response_distances = [large_distance] * len(node_ids)
|
|
||||||
|
|
||||||
if texts:
|
|
||||||
try:
|
|
||||||
embeddings = compute_embeddings(
|
|
||||||
texts, model_name, mode=embedding_mode
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
|
||||||
)
|
|
||||||
if distance_metric == "l2":
|
|
||||||
partial = np.sum(
|
|
||||||
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
|
|
||||||
)
|
|
||||||
else: # mips or cosine
|
|
||||||
partial = -np.dot(embeddings, query_vector)
|
|
||||||
|
|
||||||
for pos, dval in zip(found_indices, partial.flatten().tolist()):
|
|
||||||
response_distances[pos] = float(dval)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Distance computation error, using sentinels: {e}")
|
|
||||||
|
|
||||||
# Send response in expected shape [[distances]]
|
|
||||||
rep_socket.send(msgpack.packb([response_distances], use_single_float=True))
|
|
||||||
e2e_end = time.time()
|
|
||||||
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Fallback: treat as embedding-by-id request
|
|
||||||
if (
|
|
||||||
isinstance(request, list)
|
|
||||||
and len(request) == 1
|
|
||||||
and isinstance(request[0], list)
|
|
||||||
):
|
|
||||||
node_ids = request[0]
|
|
||||||
elif isinstance(request, list):
|
|
||||||
node_ids = request
|
|
||||||
else:
|
|
||||||
node_ids = []
|
|
||||||
last_request_type = "embedding"
|
|
||||||
last_request_length = len(node_ids)
|
|
||||||
logger.info(f"ZMQ received {len(node_ids)} node IDs for embedding fetch")
|
|
||||||
|
|
||||||
# Preallocate zero-filled flat data for robustness
|
|
||||||
if embedding_dim <= 0:
|
|
||||||
dims = [0, 0]
|
|
||||||
flat_data: list[float] = []
|
|
||||||
else:
|
|
||||||
dims = [len(node_ids), embedding_dim]
|
|
||||||
flat_data = [0.0] * (dims[0] * dims[1])
|
|
||||||
|
|
||||||
# Collect texts for found ids
|
|
||||||
texts: list[str] = []
|
|
||||||
found_indices: list[int] = []
|
|
||||||
for idx, nid in enumerate(node_ids):
|
|
||||||
try:
|
try:
|
||||||
passage_data = passages.get_passage(str(nid))
|
passage_data = passages.get_passage(str(nid))
|
||||||
txt = passage_data.get("text", "")
|
txt = passage_data["text"]
|
||||||
if isinstance(txt, str) and len(txt) > 0:
|
texts.append(txt)
|
||||||
texts.append(txt)
|
|
||||||
found_indices.append(idx)
|
|
||||||
else:
|
|
||||||
logger.error(f"Empty text for passage ID {nid}")
|
|
||||||
except KeyError:
|
except KeyError:
|
||||||
logger.error(f"Passage with ID {nid} not found")
|
logger.error(f"Passage ID {nid} not found")
|
||||||
|
raise RuntimeError(
|
||||||
|
f"FATAL: Passage with ID {nid} not found"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
if texts:
|
# Process embeddings
|
||||||
try:
|
embeddings = compute_embeddings(
|
||||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
texts, model_name, mode=embedding_mode
|
||||||
logger.info(
|
)
|
||||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
logger.info(
|
||||||
)
|
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
# Calculate distances
|
||||||
logger.error(
|
if distance_metric == "l2":
|
||||||
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
|
distances = np.sum(
|
||||||
)
|
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
|
||||||
dims = [0, embedding_dim]
|
)
|
||||||
flat_data = []
|
else: # mips or cosine
|
||||||
else:
|
distances = -np.dot(embeddings, query_vector)
|
||||||
emb_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
|
|
||||||
flat = emb_f32.flatten().tolist()
|
|
||||||
for j, pos in enumerate(found_indices):
|
|
||||||
start = pos * embedding_dim
|
|
||||||
end = start + embedding_dim
|
|
||||||
if end <= len(flat_data):
|
|
||||||
flat_data[start:end] = flat[
|
|
||||||
j * embedding_dim : (j + 1) * embedding_dim
|
|
||||||
]
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Embedding computation error, returning zeros: {e}")
|
|
||||||
|
|
||||||
response_payload = [dims, flat_data]
|
response_payload = distances.flatten().tolist()
|
||||||
response_bytes = msgpack.packb(response_payload, use_single_float=True)
|
response_bytes = msgpack.packb(
|
||||||
|
[response_payload], use_single_float=True
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"Sending distance response with {len(distances)} distances"
|
||||||
|
)
|
||||||
|
|
||||||
rep_socket.send(response_bytes)
|
socket.send(response_bytes)
|
||||||
e2e_end = time.time()
|
e2e_end = time.time()
|
||||||
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
|
logger.info(
|
||||||
|
f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s"
|
||||||
except zmq.Again:
|
)
|
||||||
# Timeout - check shutdown_event and continue
|
|
||||||
continue
|
continue
|
||||||
except Exception as e:
|
|
||||||
if not shutdown_event.is_set():
|
|
||||||
logger.error(f"Error in ZMQ server loop: {e}")
|
|
||||||
# Shape-correct fallback
|
|
||||||
try:
|
|
||||||
if last_request_type == "distance":
|
|
||||||
large_distance = 1e9
|
|
||||||
fallback_len = max(0, int(last_request_length))
|
|
||||||
safe = [[large_distance] * fallback_len]
|
|
||||||
elif last_request_type == "embedding":
|
|
||||||
bsz = max(0, int(last_request_length))
|
|
||||||
dim = max(0, int(embedding_dim))
|
|
||||||
safe = (
|
|
||||||
[[bsz, dim], [0.0] * (bsz * dim)] if dim > 0 else [[0, 0], []]
|
|
||||||
)
|
|
||||||
elif last_request_type == "text":
|
|
||||||
safe = [] # direct text embeddings expectation is a flat list
|
|
||||||
else:
|
|
||||||
safe = [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
|
|
||||||
rep_socket.send(msgpack.packb(safe, use_single_float=True))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
logger.info("Shutdown in progress, ignoring ZMQ error")
|
|
||||||
break
|
|
||||||
finally:
|
|
||||||
try:
|
|
||||||
rep_socket.close(0)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
context.term()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
logger.info("ZMQ server thread exiting gracefully")
|
# Standard embedding request (passage ID lookup)
|
||||||
|
if (
|
||||||
|
not isinstance(request_payload, list)
|
||||||
|
or len(request_payload) != 1
|
||||||
|
or not isinstance(request_payload[0], list)
|
||||||
|
):
|
||||||
|
logger.error(
|
||||||
|
f"Invalid MessagePack request format. Expected [[ids...]] or [texts...], got: {type(request_payload)}"
|
||||||
|
)
|
||||||
|
socket.send(msgpack.packb([[], []]))
|
||||||
|
continue
|
||||||
|
|
||||||
# Add shutdown coordination
|
node_ids = request_payload[0]
|
||||||
shutdown_event = threading.Event()
|
logger.debug(f"Request for {len(node_ids)} node embeddings")
|
||||||
|
|
||||||
def shutdown_zmq_server():
|
# Look up texts by node IDs
|
||||||
"""Gracefully shutdown ZMQ server."""
|
texts = []
|
||||||
logger.info("Initiating graceful shutdown...")
|
for nid in node_ids:
|
||||||
shutdown_event.set()
|
try:
|
||||||
|
passage_data = passages.get_passage(str(nid))
|
||||||
|
txt = passage_data["text"]
|
||||||
|
if not txt:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"FATAL: Empty text for passage ID {nid}"
|
||||||
|
)
|
||||||
|
texts.append(txt)
|
||||||
|
except KeyError:
|
||||||
|
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
if zmq_thread.is_alive():
|
# Process embeddings
|
||||||
logger.info("Waiting for ZMQ thread to finish...")
|
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||||
zmq_thread.join(timeout=5)
|
logger.info(
|
||||||
if zmq_thread.is_alive():
|
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||||
logger.warning("ZMQ thread did not finish in time")
|
)
|
||||||
|
|
||||||
# Clean up ZMQ resources
|
# Serialization and response
|
||||||
try:
|
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||||
# Note: socket and context are cleaned up by thread exit
|
logger.error(
|
||||||
logger.info("ZMQ resources cleaned up")
|
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
|
||||||
except Exception as e:
|
)
|
||||||
logger.warning(f"Error cleaning ZMQ resources: {e}")
|
assert False
|
||||||
|
|
||||||
# Clean up other resources
|
hidden_contiguous_f32 = np.ascontiguousarray(
|
||||||
try:
|
embeddings, dtype=np.float32
|
||||||
import gc
|
)
|
||||||
|
response_payload = [
|
||||||
|
list(hidden_contiguous_f32.shape),
|
||||||
|
hidden_contiguous_f32.flatten().tolist(),
|
||||||
|
]
|
||||||
|
response_bytes = msgpack.packb(response_payload, use_single_float=True)
|
||||||
|
|
||||||
gc.collect()
|
socket.send(response_bytes)
|
||||||
logger.info("Additional resources cleaned up")
|
e2e_end = time.time()
|
||||||
except Exception as e:
|
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
|
||||||
logger.warning(f"Error cleaning additional resources: {e}")
|
|
||||||
|
|
||||||
logger.info("Graceful shutdown completed")
|
except zmq.Again:
|
||||||
sys.exit(0)
|
logger.debug("ZMQ socket timeout, continuing to listen")
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in ZMQ server loop: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
# Register signal handlers within this function scope
|
traceback.print_exc()
|
||||||
import signal
|
socket.send(msgpack.packb([[], []]))
|
||||||
|
|
||||||
def signal_handler(sig, frame):
|
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
|
||||||
logger.info(f"Received signal {sig}, shutting down gracefully...")
|
|
||||||
shutdown_zmq_server()
|
|
||||||
|
|
||||||
signal.signal(signal.SIGTERM, signal_handler)
|
|
||||||
signal.signal(signal.SIGINT, signal_handler)
|
|
||||||
|
|
||||||
# Pass shutdown_event to ZMQ thread
|
|
||||||
zmq_thread = threading.Thread(
|
|
||||||
target=lambda: zmq_server_thread_with_shutdown(shutdown_event),
|
|
||||||
daemon=False, # Not daemon - we want to wait for it
|
|
||||||
)
|
|
||||||
zmq_thread.start()
|
zmq_thread.start()
|
||||||
logger.info(f"Started HNSW ZMQ server thread on port {zmq_port}")
|
logger.info(f"Started HNSW ZMQ server thread on port {zmq_port}")
|
||||||
|
|
||||||
# Keep the main thread alive
|
# Keep the main thread alive
|
||||||
try:
|
try:
|
||||||
while not shutdown_event.is_set():
|
while True:
|
||||||
time.sleep(0.1) # Check shutdown more frequently
|
time.sleep(1)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logger.info("HNSW Server shutting down...")
|
logger.info("HNSW Server shutting down...")
|
||||||
shutdown_zmq_server()
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# If we reach here, shutdown was triggered by signal
|
|
||||||
logger.info("Main loop exited, process should be shutting down")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
import signal
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
# Signal handlers are now registered within create_hnsw_embedding_server
|
def signal_handler(sig, frame):
|
||||||
|
logger.info(f"Received signal {sig}, shutting down gracefully...")
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
# Register signal handlers for graceful shutdown
|
||||||
|
signal.signal(signal.SIGTERM, signal_handler)
|
||||||
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="HNSW Embedding service")
|
parser = argparse.ArgumentParser(description="HNSW Embedding service")
|
||||||
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
||||||
@@ -414,7 +299,7 @@ if __name__ == "__main__":
|
|||||||
"--embedding-mode",
|
"--embedding-mode",
|
||||||
type=str,
|
type=str,
|
||||||
default="sentence-transformers",
|
default="sentence-transformers",
|
||||||
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
choices=["sentence-transformers", "openai", "mlx"],
|
||||||
help="Embedding backend mode",
|
help="Embedding backend mode",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -6,14 +6,9 @@ build-backend = "scikit_build_core.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann-backend-hnsw"
|
name = "leann-backend-hnsw"
|
||||||
version = "0.2.9"
|
version = "0.1.0"
|
||||||
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
|
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
|
||||||
dependencies = [
|
dependencies = ["leann-core==0.1.0", "numpy"]
|
||||||
"leann-core==0.2.9",
|
|
||||||
"numpy",
|
|
||||||
"pyzmq>=23.0.0",
|
|
||||||
"msgpack>=1.0.0",
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.scikit-build]
|
[tool.scikit-build]
|
||||||
wheel.packages = ["leann_backend_hnsw"]
|
wheel.packages = ["leann_backend_hnsw"]
|
||||||
@@ -22,8 +17,6 @@ cmake.build-type = "Release"
|
|||||||
build.verbose = true
|
build.verbose = true
|
||||||
build.tool-args = ["-j8"]
|
build.tool-args = ["-j8"]
|
||||||
|
|
||||||
# CMake definitions to optimize compilation and find Homebrew packages
|
# CMake definitions to optimize compilation
|
||||||
[tool.scikit-build.cmake.define]
|
[tool.scikit-build.cmake.define]
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL = "8"
|
CMAKE_BUILD_PARALLEL_LEVEL = "8"
|
||||||
CMAKE_PREFIX_PATH = {env = "CMAKE_PREFIX_PATH"}
|
|
||||||
OpenMP_ROOT = {env = "OpenMP_ROOT"}
|
|
||||||
|
|||||||
Submodule packages/leann-backend-hnsw/third_party/faiss updated: 4a2c0d67d3...ff22e2c86b
@@ -4,49 +4,19 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann-core"
|
name = "leann-core"
|
||||||
version = "0.2.9"
|
version = "0.1.0"
|
||||||
description = "Core API and plugin system for LEANN"
|
description = "Core API and plugin system for Leann."
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.9"
|
||||||
license = { text = "MIT" }
|
license = { text = "MIT" }
|
||||||
|
|
||||||
# All required dependencies included
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"numpy>=1.20.0",
|
"numpy>=1.20.0",
|
||||||
"tqdm>=4.60.0",
|
"tqdm>=4.60.0"
|
||||||
"psutil>=5.8.0",
|
|
||||||
"pyzmq>=23.0.0",
|
|
||||||
"msgpack>=1.0.0",
|
|
||||||
"torch>=2.0.0",
|
|
||||||
"sentence-transformers>=2.2.0",
|
|
||||||
"llama-index-core>=0.12.0",
|
|
||||||
"llama-index-readers-file>=0.4.0", # Essential for document reading
|
|
||||||
"llama-index-embeddings-huggingface>=0.5.5", # For embeddings
|
|
||||||
"python-dotenv>=1.0.0",
|
|
||||||
"openai>=1.0.0",
|
|
||||||
"huggingface-hub>=0.20.0",
|
|
||||||
"transformers>=4.30.0",
|
|
||||||
"requests>=2.25.0",
|
|
||||||
"accelerate>=0.20.0",
|
|
||||||
"PyPDF2>=3.0.0",
|
|
||||||
"pymupdf>=1.23.0",
|
|
||||||
"pdfplumber>=0.10.0",
|
|
||||||
"nbconvert>=7.0.0", # For .ipynb file support
|
|
||||||
"gitignore-parser>=0.1.12", # For proper .gitignore handling
|
|
||||||
"mlx>=0.26.3; sys_platform == 'darwin' and platform_machine == 'arm64'",
|
|
||||||
"mlx-lm>=0.26.0; sys_platform == 'darwin' and platform_machine == 'arm64'",
|
|
||||||
]
|
|
||||||
|
|
||||||
[project.optional-dependencies]
|
|
||||||
colab = [
|
|
||||||
"torch>=2.0.0,<3.0.0", # Limit torch version to avoid conflicts
|
|
||||||
"transformers>=4.30.0,<5.0.0", # Limit transformers version
|
|
||||||
"accelerate>=0.20.0,<1.0.0", # Limit accelerate version
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
leann = "leann.cli:main"
|
leann = "leann.cli:main"
|
||||||
leann_mcp = "leann.mcp:main"
|
|
||||||
|
|
||||||
[tool.setuptools.packages.find]
|
[tool.setuptools.packages.find]
|
||||||
where = ["src"]
|
where = ["src"]
|
||||||
@@ -8,14 +8,10 @@ if platform.system() == "Darwin":
|
|||||||
os.environ["MKL_NUM_THREADS"] = "1"
|
os.environ["MKL_NUM_THREADS"] = "1"
|
||||||
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
||||||
os.environ["KMP_BLOCKTIME"] = "0"
|
os.environ["KMP_BLOCKTIME"] = "0"
|
||||||
# Additional fixes for PyTorch/sentence-transformers on macOS ARM64 only in CI
|
|
||||||
if os.environ.get("CI") == "true":
|
|
||||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "0"
|
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
||||||
|
|
||||||
from .api import LeannBuilder, LeannChat, LeannSearcher
|
from .api import LeannBuilder, LeannChat, LeannSearcher
|
||||||
from .registry import BACKEND_REGISTRY, autodiscover_backends
|
from .registry import BACKEND_REGISTRY, autodiscover_backends
|
||||||
|
|
||||||
autodiscover_backends()
|
autodiscover_backends()
|
||||||
|
|
||||||
__all__ = ["BACKEND_REGISTRY", "LeannBuilder", "LeannChat", "LeannSearcher"]
|
__all__ = ["LeannBuilder", "LeannSearcher", "LeannChat", "BACKEND_REGISTRY"]
|
||||||
@@ -4,32 +4,23 @@ with the correct, original embedding logic from the user's reference code.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
import pickle
|
import pickle
|
||||||
import time
|
|
||||||
import warnings
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Literal, Optional
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from leann.interface import LeannBackendSearcherInterface
|
from leann.interface import LeannBackendSearcherInterface
|
||||||
|
import numpy as np
|
||||||
from .chat import get_llm
|
import time
|
||||||
from .interface import LeannBackendFactoryInterface
|
from pathlib import Path
|
||||||
|
from typing import List, Dict, Any, Optional, Literal
|
||||||
|
from dataclasses import dataclass, field
|
||||||
from .registry import BACKEND_REGISTRY
|
from .registry import BACKEND_REGISTRY
|
||||||
|
from .interface import LeannBackendFactoryInterface
|
||||||
|
from .chat import get_llm
|
||||||
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_registered_backends() -> list[str]:
|
|
||||||
"""Get list of registered backend names."""
|
|
||||||
return list(BACKEND_REGISTRY.keys())
|
|
||||||
|
|
||||||
|
|
||||||
def compute_embeddings(
|
def compute_embeddings(
|
||||||
chunks: list[str],
|
chunks: List[str],
|
||||||
model_name: str,
|
model_name: str,
|
||||||
mode: str = "sentence-transformers",
|
mode: str = "sentence-transformers",
|
||||||
use_server: bool = True,
|
use_server: bool = True,
|
||||||
@@ -70,7 +61,9 @@ def compute_embeddings(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def compute_embeddings_via_server(chunks: list[str], model_name: str, port: int) -> np.ndarray:
|
def compute_embeddings_via_server(
|
||||||
|
chunks: List[str], model_name: str, port: int
|
||||||
|
) -> np.ndarray:
|
||||||
"""Computes embeddings using sentence-transformers.
|
"""Computes embeddings using sentence-transformers.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -80,9 +73,9 @@ def compute_embeddings_via_server(chunks: list[str], model_name: str, port: int)
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
|
f"Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
|
||||||
)
|
)
|
||||||
|
import zmq
|
||||||
import msgpack
|
import msgpack
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import zmq
|
|
||||||
|
|
||||||
# Connect to embedding server
|
# Connect to embedding server
|
||||||
context = zmq.Context()
|
context = zmq.Context()
|
||||||
@@ -111,70 +104,21 @@ class SearchResult:
|
|||||||
id: str
|
id: str
|
||||||
score: float
|
score: float
|
||||||
text: str
|
text: str
|
||||||
metadata: dict[str, Any] = field(default_factory=dict)
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class PassageManager:
|
class PassageManager:
|
||||||
def __init__(
|
def __init__(self, passage_sources: List[Dict[str, Any]]):
|
||||||
self, passage_sources: list[dict[str, Any]], metadata_file_path: Optional[str] = None
|
|
||||||
):
|
|
||||||
self.offset_maps = {}
|
self.offset_maps = {}
|
||||||
self.passage_files = {}
|
self.passage_files = {}
|
||||||
self.global_offset_map = {} # Combined map for fast lookup
|
self.global_offset_map = {} # Combined map for fast lookup
|
||||||
|
|
||||||
# Derive index base name for standard sibling fallbacks, e.g., <index_name>.passages.*
|
|
||||||
index_name_base = None
|
|
||||||
if metadata_file_path:
|
|
||||||
meta_name = Path(metadata_file_path).name
|
|
||||||
if meta_name.endswith(".meta.json"):
|
|
||||||
index_name_base = meta_name[: -len(".meta.json")]
|
|
||||||
|
|
||||||
for source in passage_sources:
|
for source in passage_sources:
|
||||||
assert source["type"] == "jsonl", "only jsonl is supported"
|
assert source["type"] == "jsonl", "only jsonl is supported"
|
||||||
passage_file = source.get("path", "")
|
passage_file = source["path"]
|
||||||
index_file = source.get("index_path", "") # .idx file
|
index_file = source["index_path"] # .idx file
|
||||||
|
|
||||||
# Fix path resolution - relative paths should be relative to metadata file directory
|
|
||||||
def _resolve_candidates(
|
|
||||||
primary: str,
|
|
||||||
relative_key: str,
|
|
||||||
default_name: Optional[str],
|
|
||||||
source_dict: dict[str, Any],
|
|
||||||
) -> list[Path]:
|
|
||||||
candidates: list[Path] = []
|
|
||||||
# 1) Primary as-is (absolute or relative)
|
|
||||||
if primary:
|
|
||||||
p = Path(primary)
|
|
||||||
candidates.append(p if p.is_absolute() else (Path.cwd() / p))
|
|
||||||
# 2) metadata-relative explicit relative key
|
|
||||||
if metadata_file_path and source_dict.get(relative_key):
|
|
||||||
candidates.append(Path(metadata_file_path).parent / source_dict[relative_key])
|
|
||||||
# 3) metadata-relative standard sibling filename
|
|
||||||
if metadata_file_path and default_name:
|
|
||||||
candidates.append(Path(metadata_file_path).parent / default_name)
|
|
||||||
return candidates
|
|
||||||
|
|
||||||
# Build candidate lists and pick first existing; otherwise keep last candidate for error message
|
|
||||||
idx_default = f"{index_name_base}.passages.idx" if index_name_base else None
|
|
||||||
idx_candidates = _resolve_candidates(
|
|
||||||
index_file, "index_path_relative", idx_default, source
|
|
||||||
)
|
|
||||||
pas_default = f"{index_name_base}.passages.jsonl" if index_name_base else None
|
|
||||||
pas_candidates = _resolve_candidates(passage_file, "path_relative", pas_default, source)
|
|
||||||
|
|
||||||
def _pick_existing(cands: list[Path]) -> str:
|
|
||||||
for c in cands:
|
|
||||||
if c.exists():
|
|
||||||
return str(c.resolve())
|
|
||||||
# Fallback to last candidate (best guess) even if not exists; will error below
|
|
||||||
return str(cands[-1].resolve()) if cands else ""
|
|
||||||
|
|
||||||
index_file = _pick_existing(idx_candidates)
|
|
||||||
passage_file = _pick_existing(pas_candidates)
|
|
||||||
|
|
||||||
if not Path(index_file).exists():
|
if not Path(index_file).exists():
|
||||||
raise FileNotFoundError(f"Passage index file not found: {index_file}")
|
raise FileNotFoundError(f"Passage index file not found: {index_file}")
|
||||||
|
|
||||||
with open(index_file, "rb") as f:
|
with open(index_file, "rb") as f:
|
||||||
offset_map = pickle.load(f)
|
offset_map = pickle.load(f)
|
||||||
self.offset_maps[passage_file] = offset_map
|
self.offset_maps[passage_file] = offset_map
|
||||||
@@ -184,11 +128,11 @@ class PassageManager:
|
|||||||
for passage_id, offset in offset_map.items():
|
for passage_id, offset in offset_map.items():
|
||||||
self.global_offset_map[passage_id] = (passage_file, offset)
|
self.global_offset_map[passage_id] = (passage_file, offset)
|
||||||
|
|
||||||
def get_passage(self, passage_id: str) -> dict[str, Any]:
|
def get_passage(self, passage_id: str) -> Dict[str, Any]:
|
||||||
if passage_id in self.global_offset_map:
|
if passage_id in self.global_offset_map:
|
||||||
passage_file, offset = self.global_offset_map[passage_id]
|
passage_file, offset = self.global_offset_map[passage_id]
|
||||||
# Lazy file opening - only open when needed
|
# Lazy file opening - only open when needed
|
||||||
with open(passage_file, encoding="utf-8") as f:
|
with open(passage_file, "r", encoding="utf-8") as f:
|
||||||
f.seek(offset)
|
f.seek(offset)
|
||||||
return json.loads(f.readline())
|
return json.loads(f.readline())
|
||||||
raise KeyError(f"Passage ID not found: {passage_id}")
|
raise KeyError(f"Passage ID not found: {passage_id}")
|
||||||
@@ -198,105 +142,25 @@ class LeannBuilder:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
backend_name: str,
|
backend_name: str,
|
||||||
embedding_model: str = "facebook/contriever",
|
embedding_model: str = "facebook/contriever-msmarco",
|
||||||
dimensions: Optional[int] = None,
|
dimensions: Optional[int] = None,
|
||||||
embedding_mode: str = "sentence-transformers",
|
embedding_mode: str = "sentence-transformers",
|
||||||
**backend_kwargs,
|
**backend_kwargs,
|
||||||
):
|
):
|
||||||
self.backend_name = backend_name
|
self.backend_name = backend_name
|
||||||
# Normalize incompatible combinations early (for consistent metadata)
|
backend_factory: LeannBackendFactoryInterface | None = BACKEND_REGISTRY.get(
|
||||||
if backend_name == "hnsw":
|
backend_name
|
||||||
is_recompute = backend_kwargs.get("is_recompute", True)
|
)
|
||||||
is_compact = backend_kwargs.get("is_compact", True)
|
|
||||||
if is_recompute is False and is_compact is True:
|
|
||||||
warnings.warn(
|
|
||||||
"HNSW with is_recompute=False requires non-compact storage. Forcing is_compact=False.",
|
|
||||||
UserWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
backend_kwargs["is_compact"] = False
|
|
||||||
|
|
||||||
backend_factory: Optional[LeannBackendFactoryInterface] = BACKEND_REGISTRY.get(backend_name)
|
|
||||||
if backend_factory is None:
|
if backend_factory is None:
|
||||||
raise ValueError(f"Backend '{backend_name}' not found or not registered.")
|
raise ValueError(f"Backend '{backend_name}' not found or not registered.")
|
||||||
self.backend_factory = backend_factory
|
self.backend_factory = backend_factory
|
||||||
self.embedding_model = embedding_model
|
self.embedding_model = embedding_model
|
||||||
self.dimensions = dimensions
|
self.dimensions = dimensions
|
||||||
self.embedding_mode = embedding_mode
|
self.embedding_mode = embedding_mode
|
||||||
|
|
||||||
# Check if we need to use cosine distance for normalized embeddings
|
|
||||||
normalized_embeddings_models = {
|
|
||||||
# OpenAI models
|
|
||||||
("openai", "text-embedding-ada-002"),
|
|
||||||
("openai", "text-embedding-3-small"),
|
|
||||||
("openai", "text-embedding-3-large"),
|
|
||||||
# Voyage AI models
|
|
||||||
("voyage", "voyage-2"),
|
|
||||||
("voyage", "voyage-3"),
|
|
||||||
("voyage", "voyage-large-2"),
|
|
||||||
("voyage", "voyage-multilingual-2"),
|
|
||||||
("voyage", "voyage-code-2"),
|
|
||||||
# Cohere models
|
|
||||||
("cohere", "embed-english-v3.0"),
|
|
||||||
("cohere", "embed-multilingual-v3.0"),
|
|
||||||
("cohere", "embed-english-light-v3.0"),
|
|
||||||
("cohere", "embed-multilingual-light-v3.0"),
|
|
||||||
}
|
|
||||||
|
|
||||||
# Also check for patterns in model names
|
|
||||||
is_normalized = False
|
|
||||||
current_model_lower = embedding_model.lower()
|
|
||||||
current_mode_lower = embedding_mode.lower()
|
|
||||||
|
|
||||||
# Check exact matches
|
|
||||||
for mode, model in normalized_embeddings_models:
|
|
||||||
if (current_mode_lower == mode and current_model_lower == model) or (
|
|
||||||
mode in current_mode_lower and model in current_model_lower
|
|
||||||
):
|
|
||||||
is_normalized = True
|
|
||||||
break
|
|
||||||
|
|
||||||
# Check patterns
|
|
||||||
if not is_normalized:
|
|
||||||
# OpenAI patterns
|
|
||||||
if "openai" in current_mode_lower or "openai" in current_model_lower:
|
|
||||||
if any(
|
|
||||||
pattern in current_model_lower
|
|
||||||
for pattern in ["text-embedding", "ada", "3-small", "3-large"]
|
|
||||||
):
|
|
||||||
is_normalized = True
|
|
||||||
# Voyage patterns
|
|
||||||
elif "voyage" in current_mode_lower or "voyage" in current_model_lower:
|
|
||||||
is_normalized = True
|
|
||||||
# Cohere patterns
|
|
||||||
elif "cohere" in current_mode_lower or "cohere" in current_model_lower:
|
|
||||||
if "embed" in current_model_lower:
|
|
||||||
is_normalized = True
|
|
||||||
|
|
||||||
# Handle distance metric
|
|
||||||
if is_normalized and "distance_metric" not in backend_kwargs:
|
|
||||||
backend_kwargs["distance_metric"] = "cosine"
|
|
||||||
warnings.warn(
|
|
||||||
f"Detected normalized embeddings model '{embedding_model}' with mode '{embedding_mode}'. "
|
|
||||||
f"Automatically setting distance_metric='cosine' for optimal performance. "
|
|
||||||
f"Normalized embeddings (L2 norm = 1) should use cosine similarity instead of MIPS.",
|
|
||||||
UserWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
elif is_normalized and backend_kwargs.get("distance_metric", "").lower() != "cosine":
|
|
||||||
current_metric = backend_kwargs.get("distance_metric", "mips")
|
|
||||||
warnings.warn(
|
|
||||||
f"Warning: Using '{current_metric}' distance metric with normalized embeddings model "
|
|
||||||
f"'{embedding_model}' may lead to suboptimal search results. "
|
|
||||||
f"Consider using 'cosine' distance metric for better performance.",
|
|
||||||
UserWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.backend_kwargs = backend_kwargs
|
self.backend_kwargs = backend_kwargs
|
||||||
self.chunks: list[dict[str, Any]] = []
|
self.chunks: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
def add_text(self, text: str, metadata: Optional[dict[str, Any]] = None):
|
def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
|
||||||
if metadata is None:
|
if metadata is None:
|
||||||
metadata = {}
|
metadata = {}
|
||||||
passage_id = metadata.get("id", str(len(self.chunks)))
|
passage_id = metadata.get("id", str(len(self.chunks)))
|
||||||
@@ -326,7 +190,9 @@ class LeannBuilder:
|
|||||||
try:
|
try:
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
chunk_iterator = tqdm(self.chunks, desc="Writing passages", unit="chunk")
|
chunk_iterator = tqdm(
|
||||||
|
self.chunks, desc="Writing passages", unit="chunk"
|
||||||
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
chunk_iterator = self.chunks
|
chunk_iterator = self.chunks
|
||||||
|
|
||||||
@@ -356,7 +222,9 @@ class LeannBuilder:
|
|||||||
string_ids = [chunk["id"] for chunk in self.chunks]
|
string_ids = [chunk["id"] for chunk in self.chunks]
|
||||||
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
|
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
|
||||||
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
|
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
|
||||||
builder_instance.build(embeddings, string_ids, index_path, **current_backend_kwargs)
|
builder_instance.build(
|
||||||
|
embeddings, string_ids, index_path, **current_backend_kwargs
|
||||||
|
)
|
||||||
leann_meta_path = index_dir / f"{index_name}.meta.json"
|
leann_meta_path = index_dir / f"{index_name}.meta.json"
|
||||||
meta_data = {
|
meta_data = {
|
||||||
"version": "1.0",
|
"version": "1.0",
|
||||||
@@ -368,12 +236,8 @@ class LeannBuilder:
|
|||||||
"passage_sources": [
|
"passage_sources": [
|
||||||
{
|
{
|
||||||
"type": "jsonl",
|
"type": "jsonl",
|
||||||
# Preserve existing relative file names (backward-compatible)
|
"path": str(passages_file),
|
||||||
"path": passages_file.name,
|
"index_path": str(offset_file),
|
||||||
"index_path": offset_file.name,
|
|
||||||
# Add optional redundant relative keys for remote build portability (non-breaking)
|
|
||||||
"path_relative": passages_file.name,
|
|
||||||
"index_path_relative": offset_file.name,
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
@@ -409,7 +273,9 @@ class LeannBuilder:
|
|||||||
ids, embeddings = data
|
ids, embeddings = data
|
||||||
|
|
||||||
if not isinstance(embeddings, np.ndarray):
|
if not isinstance(embeddings, np.ndarray):
|
||||||
raise ValueError(f"Expected embeddings to be numpy array, got {type(embeddings)}")
|
raise ValueError(
|
||||||
|
f"Expected embeddings to be numpy array, got {type(embeddings)}"
|
||||||
|
)
|
||||||
|
|
||||||
if len(ids) != embeddings.shape[0]:
|
if len(ids) != embeddings.shape[0]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -421,7 +287,9 @@ class LeannBuilder:
|
|||||||
if self.dimensions is None:
|
if self.dimensions is None:
|
||||||
self.dimensions = embedding_dim
|
self.dimensions = embedding_dim
|
||||||
elif self.dimensions != embedding_dim:
|
elif self.dimensions != embedding_dim:
|
||||||
raise ValueError(f"Dimension mismatch: expected {self.dimensions}, got {embedding_dim}")
|
raise ValueError(
|
||||||
|
f"Dimension mismatch: expected {self.dimensions}, got {embedding_dim}"
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Building index from precomputed embeddings: {len(ids)} items, {embedding_dim} dimensions"
|
f"Building index from precomputed embeddings: {len(ids)} items, {embedding_dim} dimensions"
|
||||||
@@ -488,12 +356,8 @@ class LeannBuilder:
|
|||||||
"passage_sources": [
|
"passage_sources": [
|
||||||
{
|
{
|
||||||
"type": "jsonl",
|
"type": "jsonl",
|
||||||
# Preserve existing relative file names (backward-compatible)
|
"path": str(passages_file),
|
||||||
"path": passages_file.name,
|
"index_path": str(offset_file),
|
||||||
"index_path": offset_file.name,
|
|
||||||
# Add optional redundant relative keys for remote build portability (non-breaking)
|
|
||||||
"path_relative": passages_file.name,
|
|
||||||
"index_path_relative": offset_file.name,
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"built_from_precomputed_embeddings": True,
|
"built_from_precomputed_embeddings": True,
|
||||||
@@ -510,35 +374,27 @@ class LeannBuilder:
|
|||||||
with open(leann_meta_path, "w", encoding="utf-8") as f:
|
with open(leann_meta_path, "w", encoding="utf-8") as f:
|
||||||
json.dump(meta_data, f, indent=2)
|
json.dump(meta_data, f, indent=2)
|
||||||
|
|
||||||
logger.info(f"Index built successfully from precomputed embeddings: {index_path}")
|
logger.info(
|
||||||
|
f"Index built successfully from precomputed embeddings: {index_path}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LeannSearcher:
|
class LeannSearcher:
|
||||||
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
|
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
|
||||||
# Fix path resolution for Colab and other environments
|
|
||||||
if not Path(index_path).is_absolute():
|
|
||||||
index_path = str(Path(index_path).resolve())
|
|
||||||
|
|
||||||
self.meta_path_str = f"{index_path}.meta.json"
|
self.meta_path_str = f"{index_path}.meta.json"
|
||||||
if not Path(self.meta_path_str).exists():
|
if not Path(self.meta_path_str).exists():
|
||||||
parent_dir = Path(index_path).parent
|
|
||||||
print(
|
|
||||||
f"Leann metadata file not found at {self.meta_path_str}, and you may need to rm -rf {parent_dir}"
|
|
||||||
)
|
|
||||||
# highlight in red the filenotfound error
|
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
f"Leann metadata file not found at {self.meta_path_str}, \033[91m you may need to rm -rf {parent_dir}\033[0m"
|
f"Leann metadata file not found at {self.meta_path_str}"
|
||||||
)
|
)
|
||||||
with open(self.meta_path_str, encoding="utf-8") as f:
|
with open(self.meta_path_str, "r", encoding="utf-8") as f:
|
||||||
self.meta_data = json.load(f)
|
self.meta_data = json.load(f)
|
||||||
backend_name = self.meta_data["backend_name"]
|
backend_name = self.meta_data["backend_name"]
|
||||||
self.embedding_model = self.meta_data["embedding_model"]
|
self.embedding_model = self.meta_data["embedding_model"]
|
||||||
# Support both old and new format
|
# Support both old and new format
|
||||||
self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
|
self.embedding_mode = self.meta_data.get(
|
||||||
# Delegate portability handling to PassageManager
|
"embedding_mode", "sentence-transformers"
|
||||||
self.passage_manager = PassageManager(
|
|
||||||
self.meta_data.get("passage_sources", []), metadata_file_path=self.meta_path_str
|
|
||||||
)
|
)
|
||||||
|
self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
|
||||||
backend_factory = BACKEND_REGISTRY.get(backend_name)
|
backend_factory = BACKEND_REGISTRY.get(backend_name)
|
||||||
if backend_factory is None:
|
if backend_factory is None:
|
||||||
raise ValueError(f"Backend '{backend_name}' not found.")
|
raise ValueError(f"Backend '{backend_name}' not found.")
|
||||||
@@ -559,22 +415,12 @@ class LeannSearcher:
|
|||||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
expected_zmq_port: int = 5557,
|
expected_zmq_port: int = 5557,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> list[SearchResult]:
|
) -> List[SearchResult]:
|
||||||
logger.info("🔍 LeannSearcher.search() called:")
|
logger.info("🔍 LeannSearcher.search() called:")
|
||||||
logger.info(f" Query: '{query}'")
|
logger.info(f" Query: '{query}'")
|
||||||
logger.info(f" Top_k: {top_k}")
|
logger.info(f" Top_k: {top_k}")
|
||||||
logger.info(f" Additional kwargs: {kwargs}")
|
logger.info(f" Additional kwargs: {kwargs}")
|
||||||
|
|
||||||
# Smart top_k detection and adjustment
|
|
||||||
total_docs = len(self.passage_manager.global_offset_map)
|
|
||||||
original_top_k = top_k
|
|
||||||
if top_k > total_docs:
|
|
||||||
top_k = total_docs
|
|
||||||
logger.warning(
|
|
||||||
f" ⚠️ Requested top_k ({original_top_k}) exceeds total documents ({total_docs})"
|
|
||||||
)
|
|
||||||
logger.warning(f" ✅ Auto-adjusted top_k to {top_k} to match available documents")
|
|
||||||
|
|
||||||
zmq_port = None
|
zmq_port = None
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@@ -595,9 +441,9 @@ class LeannSearcher:
|
|||||||
use_server_if_available=recompute_embeddings,
|
use_server_if_available=recompute_embeddings,
|
||||||
zmq_port=zmq_port,
|
zmq_port=zmq_port,
|
||||||
)
|
)
|
||||||
# logger.info(f" Generated embedding shape: {query_embedding.shape}")
|
logger.info(f" Generated embedding shape: {query_embedding.shape}")
|
||||||
time.time() - start_time
|
embedding_time = time.time() - start_time
|
||||||
# logger.info(f" Embedding time: {embedding_time} seconds")
|
logger.info(f" Embedding time: {embedding_time} seconds")
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
results = self.backend_impl.search(
|
results = self.backend_impl.search(
|
||||||
@@ -611,13 +457,15 @@ class LeannSearcher:
|
|||||||
zmq_port=zmq_port,
|
zmq_port=zmq_port,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
# logger.info(f" Search time: {search_time} seconds")
|
search_time = time.time() - start_time
|
||||||
logger.info(f" Backend returned: labels={len(results.get('labels', [[]])[0])} results")
|
logger.info(f" Search time: {search_time} seconds")
|
||||||
|
logger.info(
|
||||||
|
f" Backend returned: labels={len(results.get('labels', [[]])[0])} results"
|
||||||
|
)
|
||||||
|
|
||||||
enriched_results = []
|
enriched_results = []
|
||||||
if "labels" in results and "distances" in results:
|
if "labels" in results and "distances" in results:
|
||||||
logger.info(f" Processing {len(results['labels'][0])} passage IDs:")
|
logger.info(f" Processing {len(results['labels'][0])} passage IDs:")
|
||||||
# Python 3.9 does not support zip(strict=...); lengths are expected to match
|
|
||||||
for i, (string_id, dist) in enumerate(
|
for i, (string_id, dist) in enumerate(
|
||||||
zip(results["labels"][0], results["distances"][0])
|
zip(results["labels"][0], results["distances"][0])
|
||||||
):
|
):
|
||||||
@@ -631,63 +479,23 @@ class LeannSearcher:
|
|||||||
metadata=passage_data.get("metadata", {}),
|
metadata=passage_data.get("metadata", {}),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Color codes for better logging
|
|
||||||
GREEN = "\033[92m"
|
|
||||||
BLUE = "\033[94m"
|
|
||||||
YELLOW = "\033[93m"
|
|
||||||
RESET = "\033[0m"
|
|
||||||
|
|
||||||
# Truncate text for display (first 100 chars)
|
|
||||||
display_text = passage_data["text"]
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f" {GREEN}✓{RESET} {BLUE}[{i + 1:2d}]{RESET} {YELLOW}ID:{RESET} '{string_id}' {YELLOW}Score:{RESET} {dist:.4f} {YELLOW}Text:{RESET} {display_text}"
|
f" {i + 1}. passage_id='{string_id}' -> SUCCESS: {passage_data['text']}..."
|
||||||
)
|
)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
RED = "\033[91m"
|
|
||||||
RESET = "\033[0m"
|
|
||||||
logger.error(
|
logger.error(
|
||||||
f" {RED}✗{RESET} [{i + 1:2d}] ID: '{string_id}' -> {RED}ERROR: Passage not found!{RESET}"
|
f" {i + 1}. passage_id='{string_id}' -> ERROR: Passage not found in PassageManager!"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Define color codes outside the loop for final message
|
logger.info(f" Final enriched results: {len(enriched_results)} passages")
|
||||||
GREEN = "\033[92m"
|
|
||||||
RESET = "\033[0m"
|
|
||||||
logger.info(f" {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}")
|
|
||||||
return enriched_results
|
return enriched_results
|
||||||
|
|
||||||
def cleanup(self):
|
|
||||||
"""Explicitly cleanup embedding server resources.
|
|
||||||
|
|
||||||
This method should be called after you're done using the searcher,
|
|
||||||
especially in test environments or batch processing scenarios.
|
|
||||||
"""
|
|
||||||
if hasattr(self.backend_impl, "embedding_server_manager"):
|
|
||||||
self.backend_impl.embedding_server_manager.stop_server()
|
|
||||||
|
|
||||||
# Enable automatic cleanup patterns
|
|
||||||
def __enter__(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc, tb):
|
|
||||||
try:
|
|
||||||
self.cleanup()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
try:
|
|
||||||
self.cleanup()
|
|
||||||
except Exception:
|
|
||||||
# Avoid noisy errors during interpreter shutdown
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class LeannChat:
|
class LeannChat:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
index_path: str,
|
index_path: str,
|
||||||
llm_config: Optional[dict[str, Any]] = None,
|
llm_config: Optional[Dict[str, Any]] = None,
|
||||||
enable_warmup: bool = False,
|
enable_warmup: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
@@ -703,13 +511,13 @@ class LeannChat:
|
|||||||
prune_ratio: float = 0.0,
|
prune_ratio: float = 0.0,
|
||||||
recompute_embeddings: bool = True,
|
recompute_embeddings: bool = True,
|
||||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
llm_kwargs: Optional[dict[str, Any]] = None,
|
llm_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
expected_zmq_port: int = 5557,
|
expected_zmq_port: int = 5557,
|
||||||
**search_kwargs,
|
**search_kwargs,
|
||||||
):
|
):
|
||||||
if llm_kwargs is None:
|
if llm_kwargs is None:
|
||||||
llm_kwargs = {}
|
llm_kwargs = {}
|
||||||
search_time = time.time()
|
|
||||||
results = self.searcher.search(
|
results = self.searcher.search(
|
||||||
question,
|
question,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
@@ -721,8 +529,6 @@ class LeannChat:
|
|||||||
expected_zmq_port=expected_zmq_port,
|
expected_zmq_port=expected_zmq_port,
|
||||||
**search_kwargs,
|
**search_kwargs,
|
||||||
)
|
)
|
||||||
search_time = time.time() - search_time
|
|
||||||
# logger.info(f" Search time: {search_time} seconds")
|
|
||||||
context = "\n\n".join([r.text for r in results])
|
context = "\n\n".join([r.text for r in results])
|
||||||
prompt = (
|
prompt = (
|
||||||
"Here is some retrieved context that might help answer your question:\n\n"
|
"Here is some retrieved context that might help answer your question:\n\n"
|
||||||
@@ -731,10 +537,7 @@ class LeannChat:
|
|||||||
"Please provide the best answer you can based on this context and your knowledge."
|
"Please provide the best answer you can based on this context and your knowledge."
|
||||||
)
|
)
|
||||||
|
|
||||||
ask_time = time.time()
|
|
||||||
ans = self.llm.ask(prompt, **llm_kwargs)
|
ans = self.llm.ask(prompt, **llm_kwargs)
|
||||||
ask_time = time.time() - ask_time
|
|
||||||
logger.info(f" Ask time: {ask_time} seconds")
|
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
def start_interactive(self):
|
def start_interactive(self):
|
||||||
@@ -751,28 +554,3 @@ class LeannChat:
|
|||||||
except (KeyboardInterrupt, EOFError):
|
except (KeyboardInterrupt, EOFError):
|
||||||
print("\nGoodbye!")
|
print("\nGoodbye!")
|
||||||
break
|
break
|
||||||
|
|
||||||
def cleanup(self):
|
|
||||||
"""Explicitly cleanup embedding server resources.
|
|
||||||
|
|
||||||
This method should be called after you're done using the chat interface,
|
|
||||||
especially in test environments or batch processing scenarios.
|
|
||||||
"""
|
|
||||||
if hasattr(self.searcher, "cleanup"):
|
|
||||||
self.searcher.cleanup()
|
|
||||||
|
|
||||||
# Enable automatic cleanup patterns
|
|
||||||
def __enter__(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc, tb):
|
|
||||||
try:
|
|
||||||
self.cleanup()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
try:
|
|
||||||
self.cleanup()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|||||||
@@ -4,25 +4,22 @@ This file contains the chat generation logic for the LEANN project,
|
|||||||
supporting different backends like Ollama, Hugging Face Transformers, and a simulation mode.
|
supporting different backends like Ollama, Hugging Face Transformers, and a simulation mode.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import difflib
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Dict, Any, Optional, List
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
import difflib
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def check_ollama_models(host: str) -> list[str]:
|
def check_ollama_models() -> List[str]:
|
||||||
"""Check available Ollama models and return a list"""
|
"""Check available Ollama models and return a list"""
|
||||||
try:
|
try:
|
||||||
import requests
|
import requests
|
||||||
|
response = requests.get("http://localhost:11434/api/tags", timeout=5)
|
||||||
response = requests.get(f"{host}/api/tags", timeout=5)
|
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
data = response.json()
|
data = response.json()
|
||||||
return [model["name"] for model in data.get("models", [])]
|
return [model["name"] for model in data.get("models", [])]
|
||||||
@@ -31,70 +28,7 @@ def check_ollama_models(host: str) -> list[str]:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
def check_ollama_model_exists_remotely(model_name: str) -> tuple[bool, list[str]]:
|
def search_ollama_models_fuzzy(query: str, available_models: List[str]) -> List[str]:
|
||||||
"""Check if a model exists in Ollama's remote library and return available tags
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(model_exists, available_tags): bool and list of matching tags
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
import re
|
|
||||||
|
|
||||||
import requests
|
|
||||||
|
|
||||||
# Split model name and tag
|
|
||||||
if ":" in model_name:
|
|
||||||
base_model, requested_tag = model_name.split(":", 1)
|
|
||||||
else:
|
|
||||||
base_model, requested_tag = model_name, None
|
|
||||||
|
|
||||||
# First check if base model exists in library
|
|
||||||
library_response = requests.get("https://ollama.com/library", timeout=8)
|
|
||||||
if library_response.status_code != 200:
|
|
||||||
return True, [] # Assume exists if can't check
|
|
||||||
|
|
||||||
# Extract model names from library page
|
|
||||||
models_in_library = re.findall(r'href="/library/([^"]+)"', library_response.text)
|
|
||||||
|
|
||||||
if base_model not in models_in_library:
|
|
||||||
return False, [] # Base model doesn't exist
|
|
||||||
|
|
||||||
# If base model exists, get available tags
|
|
||||||
tags_response = requests.get(f"https://ollama.com/library/{base_model}/tags", timeout=8)
|
|
||||||
if tags_response.status_code != 200:
|
|
||||||
return True, [] # Base model exists but can't get tags
|
|
||||||
|
|
||||||
# Extract tags for this model - be more specific to avoid HTML artifacts
|
|
||||||
tag_pattern = rf"{re.escape(base_model)}:[a-zA-Z0-9\.\-_]+"
|
|
||||||
raw_tags = re.findall(tag_pattern, tags_response.text)
|
|
||||||
|
|
||||||
# Clean up tags - remove HTML artifacts and duplicates
|
|
||||||
available_tags = []
|
|
||||||
seen = set()
|
|
||||||
for tag in raw_tags:
|
|
||||||
# Skip if it looks like HTML (contains < or >)
|
|
||||||
if "<" in tag or ">" in tag:
|
|
||||||
continue
|
|
||||||
if tag not in seen:
|
|
||||||
seen.add(tag)
|
|
||||||
available_tags.append(tag)
|
|
||||||
|
|
||||||
# Check if exact model exists
|
|
||||||
if requested_tag is None:
|
|
||||||
# User just requested base model, suggest tags
|
|
||||||
return True, available_tags[:10] # Return up to 10 tags
|
|
||||||
else:
|
|
||||||
exact_match = model_name in available_tags
|
|
||||||
return exact_match, available_tags[:10]
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# If scraping fails, assume model might exist (don't block user)
|
|
||||||
return True, []
|
|
||||||
|
|
||||||
|
|
||||||
def search_ollama_models_fuzzy(query: str, available_models: list[str]) -> list[str]:
|
|
||||||
"""Use intelligent fuzzy search for Ollama models"""
|
"""Use intelligent fuzzy search for Ollama models"""
|
||||||
if not available_models:
|
if not available_models:
|
||||||
return []
|
return []
|
||||||
@@ -107,9 +41,7 @@ def search_ollama_models_fuzzy(query: str, available_models: list[str]) -> list[
|
|||||||
suggestions.extend(exact_matches)
|
suggestions.extend(exact_matches)
|
||||||
|
|
||||||
# 2. Starts with query
|
# 2. Starts with query
|
||||||
starts_with = [
|
starts_with = [m for m in available_models if m.lower().startswith(query_lower) and m not in suggestions]
|
||||||
m for m in available_models if m.lower().startswith(query_lower) and m not in suggestions
|
|
||||||
]
|
|
||||||
suggestions.extend(starts_with)
|
suggestions.extend(starts_with)
|
||||||
|
|
||||||
# 3. Contains query
|
# 3. Contains query
|
||||||
@@ -119,25 +51,24 @@ def search_ollama_models_fuzzy(query: str, available_models: list[str]) -> list[
|
|||||||
# 4. Base model name matching (remove version numbers)
|
# 4. Base model name matching (remove version numbers)
|
||||||
def get_base_name(model_name: str) -> str:
|
def get_base_name(model_name: str) -> str:
|
||||||
"""Extract base name without version (e.g., 'llama3:8b' -> 'llama3')"""
|
"""Extract base name without version (e.g., 'llama3:8b' -> 'llama3')"""
|
||||||
return model_name.split(":")[0].split("-")[0]
|
return model_name.split(':')[0].split('-')[0]
|
||||||
|
|
||||||
query_base = get_base_name(query_lower)
|
query_base = get_base_name(query_lower)
|
||||||
base_matches = [
|
base_matches = [
|
||||||
m
|
m for m in available_models
|
||||||
for m in available_models
|
|
||||||
if get_base_name(m.lower()) == query_base and m not in suggestions
|
if get_base_name(m.lower()) == query_base and m not in suggestions
|
||||||
]
|
]
|
||||||
suggestions.extend(base_matches)
|
suggestions.extend(base_matches)
|
||||||
|
|
||||||
# 5. Family/variant matching
|
# 5. Family/variant matching
|
||||||
model_families = {
|
model_families = {
|
||||||
"llama": ["llama2", "llama3", "alpaca", "vicuna", "codellama"],
|
'llama': ['llama2', 'llama3', 'alpaca', 'vicuna', 'codellama'],
|
||||||
"qwen": ["qwen", "qwen2", "qwen3"],
|
'qwen': ['qwen', 'qwen2', 'qwen3'],
|
||||||
"gemma": ["gemma", "gemma2"],
|
'gemma': ['gemma', 'gemma2'],
|
||||||
"phi": ["phi", "phi2", "phi3"],
|
'phi': ['phi', 'phi2', 'phi3'],
|
||||||
"mistral": ["mistral", "mixtral", "openhermes"],
|
'mistral': ['mistral', 'mixtral', 'openhermes'],
|
||||||
"dolphin": ["dolphin", "openchat"],
|
'dolphin': ['dolphin', 'openchat'],
|
||||||
"deepseek": ["deepseek", "deepseek-coder"],
|
'deepseek': ['deepseek', 'deepseek-coder']
|
||||||
}
|
}
|
||||||
|
|
||||||
query_family = None
|
query_family = None
|
||||||
@@ -149,8 +80,7 @@ def search_ollama_models_fuzzy(query: str, available_models: list[str]) -> list[
|
|||||||
if query_family:
|
if query_family:
|
||||||
family_variants = model_families[query_family]
|
family_variants = model_families[query_family]
|
||||||
family_matches = [
|
family_matches = [
|
||||||
m
|
m for m in available_models
|
||||||
for m in available_models
|
|
||||||
if any(variant in m.lower() for variant in family_variants) and m not in suggestions
|
if any(variant in m.lower() for variant in family_variants) and m not in suggestions
|
||||||
]
|
]
|
||||||
suggestions.extend(family_matches)
|
suggestions.extend(family_matches)
|
||||||
@@ -169,13 +99,15 @@ def search_ollama_models_fuzzy(query: str, available_models: list[str]) -> list[
|
|||||||
# Remove this too - no need for fallback
|
# Remove this too - no need for fallback
|
||||||
|
|
||||||
|
|
||||||
def suggest_similar_models(invalid_model: str, available_models: list[str]) -> list[str]:
|
def suggest_similar_models(invalid_model: str, available_models: List[str]) -> List[str]:
|
||||||
"""Use difflib to find similar model names"""
|
"""Use difflib to find similar model names"""
|
||||||
if not available_models:
|
if not available_models:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# Get close matches using fuzzy matching
|
# Get close matches using fuzzy matching
|
||||||
suggestions = difflib.get_close_matches(invalid_model, available_models, n=3, cutoff=0.3)
|
suggestions = difflib.get_close_matches(
|
||||||
|
invalid_model, available_models, n=3, cutoff=0.3
|
||||||
|
)
|
||||||
return suggestions
|
return suggestions
|
||||||
|
|
||||||
|
|
||||||
@@ -183,14 +115,13 @@ def check_hf_model_exists(model_name: str) -> bool:
|
|||||||
"""Quick check if HuggingFace model exists without downloading"""
|
"""Quick check if HuggingFace model exists without downloading"""
|
||||||
try:
|
try:
|
||||||
from huggingface_hub import model_info
|
from huggingface_hub import model_info
|
||||||
|
|
||||||
model_info(model_name)
|
model_info(model_name)
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def get_popular_hf_models() -> list[str]:
|
def get_popular_hf_models() -> List[str]:
|
||||||
"""Return a list of popular HuggingFace models for suggestions"""
|
"""Return a list of popular HuggingFace models for suggestions"""
|
||||||
try:
|
try:
|
||||||
from huggingface_hub import list_models
|
from huggingface_hub import list_models
|
||||||
@@ -200,15 +131,15 @@ def get_popular_hf_models() -> list[str]:
|
|||||||
filter="text-generation",
|
filter="text-generation",
|
||||||
sort="downloads",
|
sort="downloads",
|
||||||
direction=-1,
|
direction=-1,
|
||||||
limit=20, # Get top 20 most downloaded
|
limit=20 # Get top 20 most downloaded
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract model names and filter for chat/conversation models
|
# Extract model names and filter for chat/conversation models
|
||||||
model_names = []
|
model_names = []
|
||||||
chat_keywords = ["chat", "instruct", "dialog", "conversation", "assistant"]
|
chat_keywords = ['chat', 'instruct', 'dialog', 'conversation', 'assistant']
|
||||||
|
|
||||||
for model in models:
|
for model in models:
|
||||||
model_name = model.id if hasattr(model, "id") else str(model)
|
model_name = model.id if hasattr(model, 'id') else str(model)
|
||||||
# Prioritize models with chat-related keywords
|
# Prioritize models with chat-related keywords
|
||||||
if any(keyword in model_name.lower() for keyword in chat_keywords):
|
if any(keyword in model_name.lower() for keyword in chat_keywords):
|
||||||
model_names.append(model_name)
|
model_names.append(model_name)
|
||||||
@@ -222,7 +153,7 @@ def get_popular_hf_models() -> list[str]:
|
|||||||
return _get_fallback_hf_models()
|
return _get_fallback_hf_models()
|
||||||
|
|
||||||
|
|
||||||
def _get_fallback_hf_models() -> list[str]:
|
def _get_fallback_hf_models() -> List[str]:
|
||||||
"""Fallback list of popular HuggingFace models"""
|
"""Fallback list of popular HuggingFace models"""
|
||||||
return [
|
return [
|
||||||
"microsoft/DialoGPT-medium",
|
"microsoft/DialoGPT-medium",
|
||||||
@@ -234,11 +165,11 @@ def _get_fallback_hf_models() -> list[str]:
|
|||||||
"facebook/blenderbot_small-90M",
|
"facebook/blenderbot_small-90M",
|
||||||
"microsoft/phi-1_5",
|
"microsoft/phi-1_5",
|
||||||
"facebook/opt-350m",
|
"facebook/opt-350m",
|
||||||
"EleutherAI/gpt-neo-1.3B",
|
"EleutherAI/gpt-neo-1.3B"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def search_hf_models_fuzzy(query: str, limit: int = 10) -> list[str]:
|
def search_hf_models_fuzzy(query: str, limit: int = 10) -> List[str]:
|
||||||
"""Use HuggingFace Hub's native fuzzy search for model suggestions"""
|
"""Use HuggingFace Hub's native fuzzy search for model suggestions"""
|
||||||
try:
|
try:
|
||||||
from huggingface_hub import list_models
|
from huggingface_hub import list_models
|
||||||
@@ -249,10 +180,10 @@ def search_hf_models_fuzzy(query: str, limit: int = 10) -> list[str]:
|
|||||||
filter="text-generation",
|
filter="text-generation",
|
||||||
sort="downloads",
|
sort="downloads",
|
||||||
direction=-1,
|
direction=-1,
|
||||||
limit=limit,
|
limit=limit
|
||||||
)
|
)
|
||||||
|
|
||||||
model_names = [model.id if hasattr(model, "id") else str(model) for model in models]
|
model_names = [model.id if hasattr(model, 'id') else str(model) for model in models]
|
||||||
|
|
||||||
# If direct search doesn't return enough results, try some variations
|
# If direct search doesn't return enough results, try some variations
|
||||||
if len(model_names) < 3:
|
if len(model_names) < 3:
|
||||||
@@ -260,17 +191,17 @@ def search_hf_models_fuzzy(query: str, limit: int = 10) -> list[str]:
|
|||||||
variations = []
|
variations = []
|
||||||
|
|
||||||
# Extract base name (e.g., "gpt3" from "gpt-3.5")
|
# Extract base name (e.g., "gpt3" from "gpt-3.5")
|
||||||
base_query = query.lower().replace("-", "").replace(".", "").replace("_", "")
|
base_query = query.lower().replace('-', '').replace('.', '').replace('_', '')
|
||||||
if base_query != query.lower():
|
if base_query != query.lower():
|
||||||
variations.append(base_query)
|
variations.append(base_query)
|
||||||
|
|
||||||
# Try common model name patterns
|
# Try common model name patterns
|
||||||
if "gpt" in query.lower():
|
if 'gpt' in query.lower():
|
||||||
variations.extend(["gpt2", "gpt-neo", "gpt-j", "dialoGPT"])
|
variations.extend(['gpt2', 'gpt-neo', 'gpt-j', 'dialoGPT'])
|
||||||
elif "llama" in query.lower():
|
elif 'llama' in query.lower():
|
||||||
variations.extend(["llama2", "alpaca", "vicuna"])
|
variations.extend(['llama2', 'alpaca', 'vicuna'])
|
||||||
elif "bert" in query.lower():
|
elif 'bert' in query.lower():
|
||||||
variations.extend(["roberta", "distilbert", "albert"])
|
variations.extend(['roberta', 'distilbert', 'albert'])
|
||||||
|
|
||||||
# Search with variations
|
# Search with variations
|
||||||
for var in variations[:2]: # Limit to 2 variations to avoid too many API calls
|
for var in variations[:2]: # Limit to 2 variations to avoid too many API calls
|
||||||
@@ -280,13 +211,11 @@ def search_hf_models_fuzzy(query: str, limit: int = 10) -> list[str]:
|
|||||||
filter="text-generation",
|
filter="text-generation",
|
||||||
sort="downloads",
|
sort="downloads",
|
||||||
direction=-1,
|
direction=-1,
|
||||||
limit=3,
|
limit=3
|
||||||
)
|
)
|
||||||
var_names = [
|
var_names = [model.id if hasattr(model, 'id') else str(model) for model in var_models]
|
||||||
model.id if hasattr(model, "id") else str(model) for model in var_models
|
|
||||||
]
|
|
||||||
model_names.extend(var_names)
|
model_names.extend(var_names)
|
||||||
except Exception:
|
except:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Remove duplicates while preserving order
|
# Remove duplicates while preserving order
|
||||||
@@ -304,86 +233,34 @@ def search_hf_models_fuzzy(query: str, limit: int = 10) -> list[str]:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
def search_hf_models(query: str, limit: int = 10) -> list[str]:
|
def search_hf_models(query: str, limit: int = 10) -> List[str]:
|
||||||
"""Simple search for HuggingFace models based on query (kept for backward compatibility)"""
|
"""Simple search for HuggingFace models based on query (kept for backward compatibility)"""
|
||||||
return search_hf_models_fuzzy(query, limit)
|
return search_hf_models_fuzzy(query, limit)
|
||||||
|
|
||||||
|
|
||||||
def validate_model_and_suggest(
|
def validate_model_and_suggest(model_name: str, llm_type: str) -> Optional[str]:
|
||||||
model_name: str, llm_type: str, host: str = "http://localhost:11434"
|
|
||||||
) -> Optional[str]:
|
|
||||||
"""Validate model name and provide suggestions if invalid"""
|
"""Validate model name and provide suggestions if invalid"""
|
||||||
if llm_type == "ollama":
|
if llm_type == "ollama":
|
||||||
available_models = check_ollama_models(host)
|
available_models = check_ollama_models()
|
||||||
if available_models and model_name not in available_models:
|
if available_models and model_name not in available_models:
|
||||||
|
# Use intelligent fuzzy search based on locally installed models
|
||||||
|
suggestions = search_ollama_models_fuzzy(model_name, available_models)
|
||||||
|
|
||||||
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
|
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
|
||||||
|
if suggestions:
|
||||||
# Check if the model exists remotely and get available tags
|
error_msg += "\n\nDid you mean one of these installed models?\n"
|
||||||
model_exists_remotely, available_tags = check_ollama_model_exists_remotely(model_name)
|
for i, suggestion in enumerate(suggestions, 1):
|
||||||
|
error_msg += f" {i}. {suggestion}\n"
|
||||||
if model_exists_remotely and model_name in available_tags:
|
|
||||||
# Exact model exists remotely - suggest pulling it
|
|
||||||
error_msg += "\n\nTo install the requested model:\n"
|
|
||||||
error_msg += f" ollama pull {model_name}\n"
|
|
||||||
|
|
||||||
# Show local alternatives
|
|
||||||
suggestions = search_ollama_models_fuzzy(model_name, available_models)
|
|
||||||
if suggestions:
|
|
||||||
error_msg += "\nOr use one of these similar installed models:\n"
|
|
||||||
for i, suggestion in enumerate(suggestions, 1):
|
|
||||||
error_msg += f" {i}. {suggestion}\n"
|
|
||||||
|
|
||||||
elif model_exists_remotely and available_tags:
|
|
||||||
# Base model exists but requested tag doesn't - suggest correct tags
|
|
||||||
base_model = model_name.split(":")[0]
|
|
||||||
requested_tag = model_name.split(":", 1)[1] if ":" in model_name else None
|
|
||||||
|
|
||||||
error_msg += (
|
|
||||||
f"\n\nModel '{base_model}' exists, but tag '{requested_tag}' is not available."
|
|
||||||
)
|
|
||||||
error_msg += f"\n\nAvailable {base_model} models you can install:\n"
|
|
||||||
for i, tag in enumerate(available_tags[:8], 1):
|
|
||||||
error_msg += f" {i}. ollama pull {tag}\n"
|
|
||||||
if len(available_tags) > 8:
|
|
||||||
error_msg += f" ... and {len(available_tags) - 8} more variants\n"
|
|
||||||
|
|
||||||
# Also show local alternatives
|
|
||||||
suggestions = search_ollama_models_fuzzy(model_name, available_models)
|
|
||||||
if suggestions:
|
|
||||||
error_msg += "\nOr use one of these similar installed models:\n"
|
|
||||||
for i, suggestion in enumerate(suggestions, 1):
|
|
||||||
error_msg += f" {i}. {suggestion}\n"
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Model doesn't exist remotely - show fuzzy suggestions
|
error_msg += "\n\nYour installed models:\n"
|
||||||
suggestions = search_ollama_models_fuzzy(model_name, available_models)
|
for i, model in enumerate(available_models[:8], 1):
|
||||||
error_msg += f"\n\nModel '{model_name}' was not found in Ollama's library."
|
error_msg += f" {i}. {model}\n"
|
||||||
|
if len(available_models) > 8:
|
||||||
|
error_msg += f" ... and {len(available_models) - 8} more\n"
|
||||||
|
|
||||||
if suggestions:
|
error_msg += "\nTo list all models: ollama list"
|
||||||
error_msg += (
|
error_msg += "\nTo download a new model: ollama pull <model_name>"
|
||||||
"\n\nDid you mean one of these installed models?\n"
|
error_msg += "\nBrowse models: https://ollama.com/library"
|
||||||
+ "\nTry to use ollama pull to install the model you need\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
for i, suggestion in enumerate(suggestions, 1):
|
|
||||||
error_msg += f" {i}. {suggestion}\n"
|
|
||||||
else:
|
|
||||||
error_msg += "\n\nYour installed models:\n"
|
|
||||||
for i, model in enumerate(available_models[:8], 1):
|
|
||||||
error_msg += f" {i}. {model}\n"
|
|
||||||
if len(available_models) > 8:
|
|
||||||
error_msg += f" ... and {len(available_models) - 8} more\n"
|
|
||||||
|
|
||||||
error_msg += "\n\nCommands:"
|
|
||||||
error_msg += "\n ollama list # List installed models"
|
|
||||||
if model_exists_remotely and available_tags:
|
|
||||||
if model_name in available_tags:
|
|
||||||
error_msg += f"\n ollama pull {model_name} # Install requested model"
|
|
||||||
else:
|
|
||||||
error_msg += (
|
|
||||||
f"\n ollama pull {available_tags[0]} # Install recommended variant"
|
|
||||||
)
|
|
||||||
error_msg += "\n https://ollama.com/library # Browse available models"
|
|
||||||
return error_msg
|
return error_msg
|
||||||
|
|
||||||
elif llm_type == "hf":
|
elif llm_type == "hf":
|
||||||
@@ -422,6 +299,7 @@ class LLMInterface(ABC):
|
|||||||
top_k=10,
|
top_k=10,
|
||||||
complexity=64,
|
complexity=64,
|
||||||
beam_width=8,
|
beam_width=8,
|
||||||
|
USE_DEFERRED_FETCH=True,
|
||||||
skip_search_reorder=True,
|
skip_search_reorder=True,
|
||||||
recompute_beighbor_embeddings=True,
|
recompute_beighbor_embeddings=True,
|
||||||
dedup_node_dis=True,
|
dedup_node_dis=True,
|
||||||
@@ -433,6 +311,7 @@ class LLMInterface(ABC):
|
|||||||
Supported kwargs:
|
Supported kwargs:
|
||||||
- complexity (int): Search complexity parameter (default: 32)
|
- complexity (int): Search complexity parameter (default: 32)
|
||||||
- beam_width (int): Beam width for search (default: 4)
|
- beam_width (int): Beam width for search (default: 4)
|
||||||
|
- USE_DEFERRED_FETCH (bool): Enable deferred fetch mode (default: False)
|
||||||
- skip_search_reorder (bool): Skip search reorder step (default: False)
|
- skip_search_reorder (bool): Skip search reorder step (default: False)
|
||||||
- recompute_beighbor_embeddings (bool): Enable ZMQ embedding server for neighbor recomputation (default: False)
|
- recompute_beighbor_embeddings (bool): Enable ZMQ embedding server for neighbor recomputation (default: False)
|
||||||
- dedup_node_dis (bool): Deduplicate nodes by distance (default: False)
|
- dedup_node_dis (bool): Deduplicate nodes by distance (default: False)
|
||||||
@@ -469,7 +348,7 @@ class OllamaChat(LLMInterface):
|
|||||||
requests.get(host)
|
requests.get(host)
|
||||||
|
|
||||||
# Pre-check model availability with helpful suggestions
|
# Pre-check model availability with helpful suggestions
|
||||||
model_error = validate_model_and_suggest(model, "ollama", host)
|
model_error = validate_model_and_suggest(model, "ollama")
|
||||||
if model_error:
|
if model_error:
|
||||||
raise ValueError(model_error)
|
raise ValueError(model_error)
|
||||||
|
|
||||||
@@ -478,50 +357,27 @@ class OllamaChat(LLMInterface):
|
|||||||
"The 'requests' library is required for Ollama. Please install it with 'pip install requests'."
|
"The 'requests' library is required for Ollama. Please install it with 'pip install requests'."
|
||||||
)
|
)
|
||||||
except requests.exceptions.ConnectionError:
|
except requests.exceptions.ConnectionError:
|
||||||
logger.error(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.")
|
logger.error(
|
||||||
|
f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
|
||||||
|
)
|
||||||
raise ConnectionError(
|
raise ConnectionError(
|
||||||
f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
|
f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
|
||||||
)
|
)
|
||||||
|
|
||||||
def ask(self, prompt: str, **kwargs) -> str:
|
def ask(self, prompt: str, **kwargs) -> str:
|
||||||
|
import requests
|
||||||
import json
|
import json
|
||||||
|
|
||||||
import requests
|
|
||||||
|
|
||||||
full_url = f"{self.host}/api/generate"
|
full_url = f"{self.host}/api/generate"
|
||||||
|
|
||||||
# Handle thinking budget for reasoning models
|
|
||||||
options = kwargs.copy()
|
|
||||||
thinking_budget = kwargs.get("thinking_budget")
|
|
||||||
if thinking_budget:
|
|
||||||
# Remove thinking_budget from options as it's not a standard Ollama option
|
|
||||||
options.pop("thinking_budget", None)
|
|
||||||
# Only apply reasoning parameters to models that support it
|
|
||||||
reasoning_supported_models = [
|
|
||||||
"gpt-oss:20b",
|
|
||||||
"gpt-oss:120b",
|
|
||||||
"deepseek-r1",
|
|
||||||
"deepseek-coder",
|
|
||||||
]
|
|
||||||
|
|
||||||
if thinking_budget in ["low", "medium", "high"]:
|
|
||||||
if any(model in self.model.lower() for model in reasoning_supported_models):
|
|
||||||
options["reasoning"] = {"effort": thinking_budget, "exclude": False}
|
|
||||||
logger.info(f"Applied reasoning effort={thinking_budget} to model {self.model}")
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
f"Thinking budget '{thinking_budget}' requested but model '{self.model}' may not support reasoning parameters. Proceeding without reasoning."
|
|
||||||
)
|
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
"stream": False, # Keep it simple for now
|
"stream": False, # Keep it simple for now
|
||||||
"options": options,
|
"options": kwargs,
|
||||||
}
|
}
|
||||||
logger.debug(f"Sending request to Ollama: {payload}")
|
logger.debug(f"Sending request to Ollama: {payload}")
|
||||||
try:
|
try:
|
||||||
logger.info("Sending request to Ollama and waiting for response...")
|
logger.info(f"Sending request to Ollama and waiting for response...")
|
||||||
response = requests.post(full_url, data=json.dumps(payload))
|
response = requests.post(full_url, data=json.dumps(payload))
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
@@ -541,7 +397,7 @@ class OllamaChat(LLMInterface):
|
|||||||
|
|
||||||
|
|
||||||
class HFChat(LLMInterface):
|
class HFChat(LLMInterface):
|
||||||
"""LLM interface for local Hugging Face Transformers models with proper chat templates."""
|
"""LLM interface for local Hugging Face Transformers models."""
|
||||||
|
|
||||||
def __init__(self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat"):
|
def __init__(self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat"):
|
||||||
logger.info(f"Initializing HFChat with model='{model_name}'")
|
logger.info(f"Initializing HFChat with model='{model_name}'")
|
||||||
@@ -552,8 +408,8 @@ class HFChat(LLMInterface):
|
|||||||
raise ValueError(model_error)
|
raise ValueError(model_error)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
from transformers.pipelines import pipeline
|
||||||
import torch
|
import torch
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"The 'transformers' and 'torch' libraries are required for Hugging Face models. Please install them with 'pip install transformers torch'."
|
"The 'transformers' and 'torch' libraries are required for Hugging Face models. Please install them with 'pip install transformers torch'."
|
||||||
@@ -561,123 +417,54 @@ class HFChat(LLMInterface):
|
|||||||
|
|
||||||
# Auto-detect device
|
# Auto-detect device
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
self.device = "cuda"
|
device = "cuda"
|
||||||
logger.info("CUDA is available. Using GPU.")
|
logger.info("CUDA is available. Using GPU.")
|
||||||
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||||
self.device = "mps"
|
device = "mps"
|
||||||
logger.info("MPS is available. Using Apple Silicon GPU.")
|
logger.info("MPS is available. Using Apple Silicon GPU.")
|
||||||
else:
|
else:
|
||||||
self.device = "cpu"
|
device = "cpu"
|
||||||
logger.info("No GPU detected. Using CPU.")
|
logger.info("No GPU detected. Using CPU.")
|
||||||
|
|
||||||
# Load tokenizer and model with timeout protection
|
self.pipeline = pipeline("text-generation", model=model_name, device=device)
|
||||||
try:
|
|
||||||
import signal
|
|
||||||
|
|
||||||
def timeout_handler(signum, frame):
|
|
||||||
raise TimeoutError("Model download/loading timed out")
|
|
||||||
|
|
||||||
# Set timeout for model loading (60 seconds)
|
|
||||||
old_handler = signal.signal(signal.SIGALRM, timeout_handler)
|
|
||||||
signal.alarm(60)
|
|
||||||
|
|
||||||
try:
|
|
||||||
logger.info(f"Loading tokenizer for {model_name}...")
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
||||||
|
|
||||||
logger.info(f"Loading model {model_name}...")
|
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
model_name,
|
|
||||||
torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
|
|
||||||
device_map="auto" if self.device != "cpu" else None,
|
|
||||||
trust_remote_code=True,
|
|
||||||
)
|
|
||||||
logger.info(f"Successfully loaded {model_name}")
|
|
||||||
finally:
|
|
||||||
signal.alarm(0) # Cancel the alarm
|
|
||||||
signal.signal(signal.SIGALRM, old_handler) # Restore old handler
|
|
||||||
|
|
||||||
except TimeoutError:
|
|
||||||
logger.error(f"Model loading timed out for {model_name}")
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Model loading timed out for {model_name}. Please check your internet connection or try a smaller model."
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to load model {model_name}: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
# Move model to device if not using device_map
|
|
||||||
if self.device != "cpu" and "device_map" not in str(self.model):
|
|
||||||
self.model = self.model.to(self.device)
|
|
||||||
|
|
||||||
# Set pad token if not present
|
|
||||||
if self.tokenizer.pad_token is None:
|
|
||||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
|
||||||
|
|
||||||
def ask(self, prompt: str, **kwargs) -> str:
|
def ask(self, prompt: str, **kwargs) -> str:
|
||||||
print("kwargs in HF: ", kwargs)
|
# Map OpenAI-style arguments to Hugging Face equivalents
|
||||||
# Check if this is a Qwen model and add /no_think by default
|
if "max_tokens" in kwargs:
|
||||||
is_qwen_model = "qwen" in self.model.config._name_or_path.lower()
|
# Prefer user-provided max_new_tokens if both are present
|
||||||
|
kwargs.setdefault("max_new_tokens", kwargs["max_tokens"])
|
||||||
|
# Remove the unsupported key to avoid errors in Transformers
|
||||||
|
kwargs.pop("max_tokens")
|
||||||
|
|
||||||
# For Qwen models, automatically add /no_think to the prompt
|
# Handle temperature=0 edge-case for greedy decoding
|
||||||
if is_qwen_model and "/no_think" not in prompt and "/think" not in prompt:
|
if "temperature" in kwargs and kwargs["temperature"] == 0.0:
|
||||||
prompt = prompt + " /no_think"
|
# Remove unsupported zero temperature and use deterministic generation
|
||||||
|
kwargs.pop("temperature")
|
||||||
|
kwargs.setdefault("do_sample", False)
|
||||||
|
|
||||||
# Prepare chat template
|
# Sensible defaults for text generation
|
||||||
messages = [{"role": "user", "content": prompt}]
|
params = {"max_length": 500, "num_return_sequences": 1, **kwargs}
|
||||||
|
logger.info(f"Generating text with Hugging Face model with params: {params}")
|
||||||
|
results = self.pipeline(prompt, **params)
|
||||||
|
|
||||||
# Apply chat template if available
|
# Handle different response formats from transformers
|
||||||
if hasattr(self.tokenizer, "apply_chat_template"):
|
if isinstance(results, list) and len(results) > 0:
|
||||||
try:
|
generated_text = (
|
||||||
formatted_prompt = self.tokenizer.apply_chat_template(
|
results[0].get("generated_text", "")
|
||||||
messages, tokenize=False, add_generation_prompt=True
|
if isinstance(results[0], dict)
|
||||||
)
|
else str(results[0])
|
||||||
except Exception as e:
|
)
|
||||||
logger.warning(f"Chat template failed, using raw prompt: {e}")
|
|
||||||
formatted_prompt = prompt
|
|
||||||
else:
|
else:
|
||||||
# Fallback for models without chat template
|
generated_text = str(results)
|
||||||
formatted_prompt = prompt
|
|
||||||
|
|
||||||
# Tokenize input
|
# Extract only the newly generated portion by removing the original prompt
|
||||||
inputs = self.tokenizer(
|
if isinstance(generated_text, str) and generated_text.startswith(prompt):
|
||||||
formatted_prompt,
|
response = generated_text[len(prompt) :].strip()
|
||||||
return_tensors="pt",
|
else:
|
||||||
padding=True,
|
# Fallback: return the full response if prompt removal fails
|
||||||
truncation=True,
|
response = str(generated_text)
|
||||||
max_length=2048,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Move inputs to device
|
return response
|
||||||
if self.device != "cpu":
|
|
||||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
|
||||||
|
|
||||||
# Set generation parameters
|
|
||||||
generation_config = {
|
|
||||||
"max_new_tokens": kwargs.get("max_tokens", kwargs.get("max_new_tokens", 512)),
|
|
||||||
"temperature": kwargs.get("temperature", 0.7),
|
|
||||||
"top_p": kwargs.get("top_p", 0.9),
|
|
||||||
"do_sample": kwargs.get("temperature", 0.7) > 0,
|
|
||||||
"pad_token_id": self.tokenizer.eos_token_id,
|
|
||||||
"eos_token_id": self.tokenizer.eos_token_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Handle temperature=0 for greedy decoding
|
|
||||||
if generation_config["temperature"] == 0.0:
|
|
||||||
generation_config["do_sample"] = False
|
|
||||||
generation_config.pop("temperature")
|
|
||||||
|
|
||||||
logger.info(f"Generating with HuggingFace model, config: {generation_config}")
|
|
||||||
|
|
||||||
# Generate
|
|
||||||
with torch.no_grad():
|
|
||||||
outputs = self.model.generate(**inputs, **generation_config)
|
|
||||||
|
|
||||||
# Decode response
|
|
||||||
generated_tokens = outputs[0][inputs["input_ids"].shape[1] :]
|
|
||||||
response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
|
||||||
|
|
||||||
return response.strip()
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAIChat(LLMInterface):
|
class OpenAIChat(LLMInterface):
|
||||||
@@ -708,38 +495,15 @@ class OpenAIChat(LLMInterface):
|
|||||||
params = {
|
params = {
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
"messages": [{"role": "user", "content": prompt}],
|
"messages": [{"role": "user", "content": prompt}],
|
||||||
|
"max_tokens": kwargs.get("max_tokens", 1000),
|
||||||
"temperature": kwargs.get("temperature", 0.7),
|
"temperature": kwargs.get("temperature", 0.7),
|
||||||
|
**{
|
||||||
|
k: v
|
||||||
|
for k, v in kwargs.items()
|
||||||
|
if k not in ["max_tokens", "temperature"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Handle max_tokens vs max_completion_tokens based on model
|
|
||||||
max_tokens = kwargs.get("max_tokens", 1000)
|
|
||||||
if "o3" in self.model or "o4" in self.model or "o1" in self.model:
|
|
||||||
# o-series models use max_completion_tokens
|
|
||||||
params["max_completion_tokens"] = max_tokens
|
|
||||||
params["temperature"] = 1.0
|
|
||||||
else:
|
|
||||||
# Other models use max_tokens
|
|
||||||
params["max_tokens"] = max_tokens
|
|
||||||
|
|
||||||
# Handle thinking budget for reasoning models
|
|
||||||
thinking_budget = kwargs.get("thinking_budget")
|
|
||||||
if thinking_budget and thinking_budget in ["low", "medium", "high"]:
|
|
||||||
# Check if this is an o-series model (partial match for model names)
|
|
||||||
o_series_models = ["o3", "o3-mini", "o4-mini", "o1", "o3-pro", "o3-deep-research"]
|
|
||||||
if any(model in self.model for model in o_series_models):
|
|
||||||
# Use the correct OpenAI reasoning parameter format
|
|
||||||
params["reasoning_effort"] = thinking_budget
|
|
||||||
logger.info(f"Applied reasoning_effort={thinking_budget} to model {self.model}")
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
f"Thinking budget '{thinking_budget}' requested but model '{self.model}' may not support reasoning parameters. Proceeding without reasoning."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add other kwargs (excluding thinking_budget as it's handled above)
|
|
||||||
for k, v in kwargs.items():
|
|
||||||
if k not in ["max_tokens", "temperature", "thinking_budget"]:
|
|
||||||
params[k] = v
|
|
||||||
|
|
||||||
logger.info(f"Sending request to OpenAI with model {self.model}")
|
logger.info(f"Sending request to OpenAI with model {self.model}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -759,7 +523,7 @@ class SimulatedChat(LLMInterface):
|
|||||||
return "This is a simulated answer from the LLM based on the retrieved context."
|
return "This is a simulated answer from the LLM based on the retrieved context."
|
||||||
|
|
||||||
|
|
||||||
def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
|
def get_llm(llm_config: Optional[Dict[str, Any]] = None) -> LLMInterface:
|
||||||
"""
|
"""
|
||||||
Factory function to get an LLM interface based on configuration.
|
Factory function to get an LLM interface based on configuration.
|
||||||
|
|
||||||
|
|||||||
@@ -1,65 +1,22 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
from llama_index.core import SimpleDirectoryReader
|
from llama_index.core import SimpleDirectoryReader
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from .api import LeannBuilder, LeannChat, LeannSearcher
|
from .api import LeannBuilder, LeannSearcher, LeannChat
|
||||||
|
|
||||||
|
|
||||||
def extract_pdf_text_with_pymupdf(file_path: str) -> str:
|
|
||||||
"""Extract text from PDF using PyMuPDF for better quality."""
|
|
||||||
try:
|
|
||||||
import fitz # PyMuPDF
|
|
||||||
|
|
||||||
doc = fitz.open(file_path)
|
|
||||||
text = ""
|
|
||||||
for page in doc:
|
|
||||||
text += page.get_text()
|
|
||||||
doc.close()
|
|
||||||
return text
|
|
||||||
except ImportError:
|
|
||||||
# Fallback to default reader
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def extract_pdf_text_with_pdfplumber(file_path: str) -> str:
|
|
||||||
"""Extract text from PDF using pdfplumber for better quality."""
|
|
||||||
try:
|
|
||||||
import pdfplumber
|
|
||||||
|
|
||||||
text = ""
|
|
||||||
with pdfplumber.open(file_path) as pdf:
|
|
||||||
for page in pdf.pages:
|
|
||||||
text += page.extract_text() or ""
|
|
||||||
return text
|
|
||||||
except ImportError:
|
|
||||||
# Fallback to default reader
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class LeannCLI:
|
class LeannCLI:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# Always use project-local .leann directory (like .git)
|
self.indexes_dir = Path.home() / ".leann" / "indexes"
|
||||||
self.indexes_dir = Path.cwd() / ".leann" / "indexes"
|
|
||||||
self.indexes_dir.mkdir(parents=True, exist_ok=True)
|
self.indexes_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Default parser for documents
|
|
||||||
self.node_parser = SentenceSplitter(
|
self.node_parser = SentenceSplitter(
|
||||||
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
|
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Code-optimized parser
|
|
||||||
self.code_parser = SentenceSplitter(
|
|
||||||
chunk_size=512, # Larger chunks for code context
|
|
||||||
chunk_overlap=50, # Less overlap to preserve function boundaries
|
|
||||||
separator="\n", # Split by lines for code
|
|
||||||
paragraph_separator="\n\n", # Preserve logical code blocks
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_index_path(self, index_name: str) -> str:
|
def get_index_path(self, index_name: str) -> str:
|
||||||
index_dir = self.indexes_dir / index_name
|
index_dir = self.indexes_dir / index_name
|
||||||
return str(index_dir / "documents.leann")
|
return str(index_dir / "documents.leann")
|
||||||
@@ -72,18 +29,14 @@ class LeannCLI:
|
|||||||
def create_parser(self) -> argparse.ArgumentParser:
|
def create_parser(self) -> argparse.ArgumentParser:
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
prog="leann",
|
prog="leann",
|
||||||
description="The smallest vector index in the world. RAG Everything with LEANN!",
|
description="LEANN - Local Enhanced AI Navigation",
|
||||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
epilog="""
|
epilog="""
|
||||||
Examples:
|
Examples:
|
||||||
leann build my-docs --docs ./documents # Build index from directory
|
leann build my-docs --docs ./documents # Build index named my-docs
|
||||||
leann build my-code --docs ./src ./tests ./config # Build index from multiple directories
|
leann search my-docs "query" # Search in my-docs index
|
||||||
leann build my-files --docs ./file1.py ./file2.txt ./docs/ # Build index from files and directories
|
leann ask my-docs "question" # Ask my-docs index
|
||||||
leann build my-mixed --docs ./readme.md ./src/ ./config.json # Build index from mixed files/dirs
|
leann list # List all stored indexes
|
||||||
leann build my-ppts --docs ./ --file-types .pptx,.pdf # Index only PowerPoint and PDF files
|
|
||||||
leann search my-docs "query" # Search in my-docs index
|
|
||||||
leann ask my-docs "question" # Ask my-docs index
|
|
||||||
leann list # List all stored indexes
|
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -91,88 +44,38 @@ Examples:
|
|||||||
|
|
||||||
# Build command
|
# Build command
|
||||||
build_parser = subparsers.add_parser("build", help="Build document index")
|
build_parser = subparsers.add_parser("build", help="Build document index")
|
||||||
|
build_parser.add_argument("index_name", help="Index name")
|
||||||
build_parser.add_argument(
|
build_parser.add_argument(
|
||||||
"index_name", nargs="?", help="Index name (default: current directory name)"
|
"--docs", type=str, required=True, help="Documents directory"
|
||||||
)
|
)
|
||||||
build_parser.add_argument(
|
build_parser.add_argument(
|
||||||
"--docs",
|
"--backend", type=str, default="hnsw", choices=["hnsw", "diskann"]
|
||||||
type=str,
|
|
||||||
nargs="+",
|
|
||||||
default=["."],
|
|
||||||
help="Documents directories and/or files (default: current directory)",
|
|
||||||
)
|
)
|
||||||
build_parser.add_argument(
|
build_parser.add_argument(
|
||||||
"--backend",
|
"--embedding-model", type=str, default="facebook/contriever"
|
||||||
type=str,
|
|
||||||
default="hnsw",
|
|
||||||
choices=["hnsw", "diskann"],
|
|
||||||
help="Backend to use (default: hnsw)",
|
|
||||||
)
|
)
|
||||||
build_parser.add_argument(
|
build_parser.add_argument(
|
||||||
"--embedding-model",
|
"--force", "-f", action="store_true", help="Force rebuild"
|
||||||
type=str,
|
|
||||||
default="facebook/contriever",
|
|
||||||
help="Embedding model (default: facebook/contriever)",
|
|
||||||
)
|
|
||||||
build_parser.add_argument(
|
|
||||||
"--embedding-mode",
|
|
||||||
type=str,
|
|
||||||
default="sentence-transformers",
|
|
||||||
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
|
||||||
help="Embedding backend mode (default: sentence-transformers)",
|
|
||||||
)
|
|
||||||
build_parser.add_argument(
|
|
||||||
"--force", "-f", action="store_true", help="Force rebuild existing index"
|
|
||||||
)
|
|
||||||
build_parser.add_argument(
|
|
||||||
"--graph-degree", type=int, default=32, help="Graph degree (default: 32)"
|
|
||||||
)
|
|
||||||
build_parser.add_argument(
|
|
||||||
"--complexity", type=int, default=64, help="Build complexity (default: 64)"
|
|
||||||
)
|
)
|
||||||
|
build_parser.add_argument("--graph-degree", type=int, default=32)
|
||||||
|
build_parser.add_argument("--complexity", type=int, default=64)
|
||||||
build_parser.add_argument("--num-threads", type=int, default=1)
|
build_parser.add_argument("--num-threads", type=int, default=1)
|
||||||
build_parser.add_argument(
|
build_parser.add_argument("--compact", action="store_true", default=True)
|
||||||
"--compact",
|
build_parser.add_argument("--recompute", action="store_true", default=True)
|
||||||
action=argparse.BooleanOptionalAction,
|
|
||||||
default=True,
|
|
||||||
help="Use compact storage (default: true). Must be `no-compact` for `no-recompute` build.",
|
|
||||||
)
|
|
||||||
build_parser.add_argument(
|
|
||||||
"--recompute",
|
|
||||||
action=argparse.BooleanOptionalAction,
|
|
||||||
default=True,
|
|
||||||
help="Enable recomputation (default: true)",
|
|
||||||
)
|
|
||||||
build_parser.add_argument(
|
|
||||||
"--file-types",
|
|
||||||
type=str,
|
|
||||||
help="Comma-separated list of file extensions to include (e.g., '.txt,.pdf,.pptx'). If not specified, uses default supported types.",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Search command
|
# Search command
|
||||||
search_parser = subparsers.add_parser("search", help="Search documents")
|
search_parser = subparsers.add_parser("search", help="Search documents")
|
||||||
search_parser.add_argument("index_name", help="Index name")
|
search_parser.add_argument("index_name", help="Index name")
|
||||||
search_parser.add_argument("query", help="Search query")
|
search_parser.add_argument("query", help="Search query")
|
||||||
search_parser.add_argument(
|
search_parser.add_argument("--top-k", type=int, default=5)
|
||||||
"--top-k", type=int, default=5, help="Number of results (default: 5)"
|
search_parser.add_argument("--complexity", type=int, default=64)
|
||||||
)
|
|
||||||
search_parser.add_argument(
|
|
||||||
"--complexity", type=int, default=64, help="Search complexity (default: 64)"
|
|
||||||
)
|
|
||||||
search_parser.add_argument("--beam-width", type=int, default=1)
|
search_parser.add_argument("--beam-width", type=int, default=1)
|
||||||
search_parser.add_argument("--prune-ratio", type=float, default=0.0)
|
search_parser.add_argument("--prune-ratio", type=float, default=0.0)
|
||||||
search_parser.add_argument(
|
search_parser.add_argument("--recompute-embeddings", action="store_true")
|
||||||
"--recompute",
|
|
||||||
dest="recompute_embeddings",
|
|
||||||
action=argparse.BooleanOptionalAction,
|
|
||||||
default=True,
|
|
||||||
help="Enable/disable embedding recomputation (default: enabled). Should not do a `no-recompute` search in a `recompute` build.",
|
|
||||||
)
|
|
||||||
search_parser.add_argument(
|
search_parser.add_argument(
|
||||||
"--pruning-strategy",
|
"--pruning-strategy",
|
||||||
choices=["global", "local", "proportional"],
|
choices=["global", "local", "proportional"],
|
||||||
default="global",
|
default="global",
|
||||||
help="Pruning strategy (default: global)",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Ask command
|
# Ask command
|
||||||
@@ -183,513 +86,75 @@ Examples:
|
|||||||
type=str,
|
type=str,
|
||||||
default="ollama",
|
default="ollama",
|
||||||
choices=["simulated", "ollama", "hf", "openai"],
|
choices=["simulated", "ollama", "hf", "openai"],
|
||||||
help="LLM provider (default: ollama)",
|
|
||||||
)
|
|
||||||
ask_parser.add_argument(
|
|
||||||
"--model", type=str, default="qwen3:8b", help="Model name (default: qwen3:8b)"
|
|
||||||
)
|
)
|
||||||
|
ask_parser.add_argument("--model", type=str, default="qwen3:8b")
|
||||||
ask_parser.add_argument("--host", type=str, default="http://localhost:11434")
|
ask_parser.add_argument("--host", type=str, default="http://localhost:11434")
|
||||||
ask_parser.add_argument(
|
ask_parser.add_argument("--interactive", "-i", action="store_true")
|
||||||
"--interactive", "-i", action="store_true", help="Interactive chat mode"
|
ask_parser.add_argument("--top-k", type=int, default=20)
|
||||||
)
|
|
||||||
ask_parser.add_argument(
|
|
||||||
"--top-k", type=int, default=20, help="Retrieval count (default: 20)"
|
|
||||||
)
|
|
||||||
ask_parser.add_argument("--complexity", type=int, default=32)
|
ask_parser.add_argument("--complexity", type=int, default=32)
|
||||||
ask_parser.add_argument("--beam-width", type=int, default=1)
|
ask_parser.add_argument("--beam-width", type=int, default=1)
|
||||||
ask_parser.add_argument("--prune-ratio", type=float, default=0.0)
|
ask_parser.add_argument("--prune-ratio", type=float, default=0.0)
|
||||||
ask_parser.add_argument(
|
ask_parser.add_argument("--recompute-embeddings", action="store_true")
|
||||||
"--recompute",
|
|
||||||
dest="recompute_embeddings",
|
|
||||||
action=argparse.BooleanOptionalAction,
|
|
||||||
default=True,
|
|
||||||
help="Enable/disable embedding recomputation during ask (default: enabled)",
|
|
||||||
)
|
|
||||||
ask_parser.add_argument(
|
ask_parser.add_argument(
|
||||||
"--pruning-strategy",
|
"--pruning-strategy",
|
||||||
choices=["global", "local", "proportional"],
|
choices=["global", "local", "proportional"],
|
||||||
default="global",
|
default="global",
|
||||||
)
|
)
|
||||||
ask_parser.add_argument(
|
|
||||||
"--thinking-budget",
|
|
||||||
type=str,
|
|
||||||
choices=["low", "medium", "high"],
|
|
||||||
default=None,
|
|
||||||
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
|
|
||||||
)
|
|
||||||
|
|
||||||
# List command
|
# List command
|
||||||
subparsers.add_parser("list", help="List all indexes")
|
list_parser = subparsers.add_parser("list", help="List all indexes")
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
def register_project_dir(self):
|
|
||||||
"""Register current project directory in global registry"""
|
|
||||||
global_registry = Path.home() / ".leann" / "projects.json"
|
|
||||||
global_registry.parent.mkdir(exist_ok=True)
|
|
||||||
|
|
||||||
current_dir = str(Path.cwd())
|
|
||||||
|
|
||||||
# Load existing registry
|
|
||||||
projects = []
|
|
||||||
if global_registry.exists():
|
|
||||||
try:
|
|
||||||
import json
|
|
||||||
|
|
||||||
with open(global_registry) as f:
|
|
||||||
projects = json.load(f)
|
|
||||||
except Exception:
|
|
||||||
projects = []
|
|
||||||
|
|
||||||
# Add current directory if not already present
|
|
||||||
if current_dir not in projects:
|
|
||||||
projects.append(current_dir)
|
|
||||||
|
|
||||||
# Save registry
|
|
||||||
import json
|
|
||||||
|
|
||||||
with open(global_registry, "w") as f:
|
|
||||||
json.dump(projects, f, indent=2)
|
|
||||||
|
|
||||||
def _build_gitignore_parser(self, docs_dir: str):
|
|
||||||
"""Build gitignore parser using gitignore-parser library."""
|
|
||||||
from gitignore_parser import parse_gitignore
|
|
||||||
|
|
||||||
# Try to parse the root .gitignore
|
|
||||||
gitignore_path = Path(docs_dir) / ".gitignore"
|
|
||||||
|
|
||||||
if gitignore_path.exists():
|
|
||||||
try:
|
|
||||||
# gitignore-parser automatically handles all subdirectory .gitignore files!
|
|
||||||
matches = parse_gitignore(str(gitignore_path))
|
|
||||||
print(f"📋 Loaded .gitignore from {docs_dir} (includes all subdirectories)")
|
|
||||||
return matches
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Warning: Could not parse .gitignore: {e}")
|
|
||||||
else:
|
|
||||||
print("📋 No .gitignore found")
|
|
||||||
|
|
||||||
# Fallback: basic pattern matching for essential files
|
|
||||||
essential_patterns = {".git", ".DS_Store", "__pycache__", "node_modules", ".venv", "venv"}
|
|
||||||
|
|
||||||
def basic_matches(file_path):
|
|
||||||
path_parts = Path(file_path).parts
|
|
||||||
return any(part in essential_patterns for part in path_parts)
|
|
||||||
|
|
||||||
return basic_matches
|
|
||||||
|
|
||||||
def _should_exclude_file(self, relative_path: Path, gitignore_matches) -> bool:
|
|
||||||
"""Check if a file should be excluded using gitignore parser."""
|
|
||||||
return gitignore_matches(str(relative_path))
|
|
||||||
|
|
||||||
def _is_git_submodule(self, path: Path) -> bool:
|
|
||||||
"""Check if a path is a git submodule."""
|
|
||||||
try:
|
|
||||||
# Find the git repo root
|
|
||||||
current_dir = Path.cwd()
|
|
||||||
while current_dir != current_dir.parent:
|
|
||||||
if (current_dir / ".git").exists():
|
|
||||||
gitmodules_path = current_dir / ".gitmodules"
|
|
||||||
if gitmodules_path.exists():
|
|
||||||
# Read .gitmodules to check if this path is a submodule
|
|
||||||
gitmodules_content = gitmodules_path.read_text()
|
|
||||||
# Convert path to relative to git root
|
|
||||||
try:
|
|
||||||
relative_path = path.resolve().relative_to(current_dir)
|
|
||||||
# Check if this path appears in .gitmodules
|
|
||||||
return f"path = {relative_path}" in gitmodules_content
|
|
||||||
except ValueError:
|
|
||||||
# Path is not under git root
|
|
||||||
return False
|
|
||||||
break
|
|
||||||
current_dir = current_dir.parent
|
|
||||||
return False
|
|
||||||
except Exception:
|
|
||||||
# If anything goes wrong, assume it's not a submodule
|
|
||||||
return False
|
|
||||||
|
|
||||||
def list_indexes(self):
|
def list_indexes(self):
|
||||||
print("Stored LEANN indexes:")
|
print("Stored LEANN indexes:")
|
||||||
|
|
||||||
# Get all project directories with .leann
|
if not self.indexes_dir.exists():
|
||||||
global_registry = Path.home() / ".leann" / "projects.json"
|
|
||||||
all_projects = []
|
|
||||||
|
|
||||||
if global_registry.exists():
|
|
||||||
try:
|
|
||||||
import json
|
|
||||||
|
|
||||||
with open(global_registry) as f:
|
|
||||||
all_projects = json.load(f)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Filter to only existing directories with .leann
|
|
||||||
valid_projects = []
|
|
||||||
for project_dir in all_projects:
|
|
||||||
project_path = Path(project_dir)
|
|
||||||
if project_path.exists() and (project_path / ".leann" / "indexes").exists():
|
|
||||||
valid_projects.append(project_path)
|
|
||||||
|
|
||||||
# Add current project if it has .leann but not in registry
|
|
||||||
current_path = Path.cwd()
|
|
||||||
if (current_path / ".leann" / "indexes").exists() and current_path not in valid_projects:
|
|
||||||
valid_projects.append(current_path)
|
|
||||||
|
|
||||||
if not valid_projects:
|
|
||||||
print(
|
print(
|
||||||
"No indexes found. Use 'leann build <name> --docs <dir> [<dir2> ...]' to create one."
|
"No indexes found. Use 'leann build <name> --docs <dir>' to create one."
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
total_indexes = 0
|
index_dirs = [d for d in self.indexes_dir.iterdir() if d.is_dir()]
|
||||||
current_dir = Path.cwd()
|
|
||||||
|
|
||||||
for project_path in valid_projects:
|
if not index_dirs:
|
||||||
indexes_dir = project_path / ".leann" / "indexes"
|
|
||||||
if not indexes_dir.exists():
|
|
||||||
continue
|
|
||||||
|
|
||||||
index_dirs = [d for d in indexes_dir.iterdir() if d.is_dir()]
|
|
||||||
if not index_dirs:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Show project header
|
|
||||||
if project_path == current_dir:
|
|
||||||
print(f"\n📁 Current project ({project_path}):")
|
|
||||||
else:
|
|
||||||
print(f"\n📂 {project_path}:")
|
|
||||||
|
|
||||||
for index_dir in index_dirs:
|
|
||||||
total_indexes += 1
|
|
||||||
index_name = index_dir.name
|
|
||||||
meta_file = index_dir / "documents.leann.meta.json"
|
|
||||||
status = "✓" if meta_file.exists() else "✗"
|
|
||||||
|
|
||||||
print(f" {total_indexes}. {index_name} [{status}]")
|
|
||||||
if status == "✓":
|
|
||||||
size_mb = sum(f.stat().st_size for f in index_dir.iterdir() if f.is_file()) / (
|
|
||||||
1024 * 1024
|
|
||||||
)
|
|
||||||
print(f" Size: {size_mb:.1f} MB")
|
|
||||||
|
|
||||||
if total_indexes > 0:
|
|
||||||
print(f"\nTotal: {total_indexes} indexes across {len(valid_projects)} projects")
|
|
||||||
print("\nUsage (current project only):")
|
|
||||||
|
|
||||||
# Show example from current project
|
|
||||||
current_indexes_dir = current_dir / ".leann" / "indexes"
|
|
||||||
if current_indexes_dir.exists():
|
|
||||||
current_index_dirs = [d for d in current_indexes_dir.iterdir() if d.is_dir()]
|
|
||||||
if current_index_dirs:
|
|
||||||
example_name = current_index_dirs[0].name
|
|
||||||
print(f' leann search {example_name} "your query"')
|
|
||||||
print(f" leann ask {example_name} --interactive")
|
|
||||||
|
|
||||||
def load_documents(
|
|
||||||
self, docs_paths: Union[str, list], custom_file_types: Union[str, None] = None
|
|
||||||
):
|
|
||||||
# Handle both single path (string) and multiple paths (list) for backward compatibility
|
|
||||||
if isinstance(docs_paths, str):
|
|
||||||
docs_paths = [docs_paths]
|
|
||||||
|
|
||||||
# Separate files and directories
|
|
||||||
files = []
|
|
||||||
directories = []
|
|
||||||
for path in docs_paths:
|
|
||||||
path_obj = Path(path)
|
|
||||||
if path_obj.is_file():
|
|
||||||
files.append(str(path_obj))
|
|
||||||
elif path_obj.is_dir():
|
|
||||||
# Check if this is a git submodule - if so, skip it
|
|
||||||
if self._is_git_submodule(path_obj):
|
|
||||||
print(f"⚠️ Skipping git submodule: {path}")
|
|
||||||
continue
|
|
||||||
directories.append(str(path_obj))
|
|
||||||
else:
|
|
||||||
print(f"⚠️ Warning: Path '{path}' does not exist, skipping...")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Print summary of what we're processing
|
|
||||||
total_items = len(files) + len(directories)
|
|
||||||
items_desc = []
|
|
||||||
if files:
|
|
||||||
items_desc.append(f"{len(files)} file{'s' if len(files) > 1 else ''}")
|
|
||||||
if directories:
|
|
||||||
items_desc.append(
|
|
||||||
f"{len(directories)} director{'ies' if len(directories) > 1 else 'y'}"
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Loading documents from {' and '.join(items_desc)} ({total_items} total):")
|
|
||||||
if files:
|
|
||||||
print(f" 📄 Files: {', '.join([Path(f).name for f in files])}")
|
|
||||||
if directories:
|
|
||||||
print(f" 📁 Directories: {', '.join(directories)}")
|
|
||||||
|
|
||||||
if custom_file_types:
|
|
||||||
print(f"Using custom file types: {custom_file_types}")
|
|
||||||
|
|
||||||
all_documents = []
|
|
||||||
|
|
||||||
# First, process individual files if any
|
|
||||||
if files:
|
|
||||||
print(f"\n🔄 Processing {len(files)} individual file{'s' if len(files) > 1 else ''}...")
|
|
||||||
|
|
||||||
# Load individual files using SimpleDirectoryReader with input_files
|
|
||||||
# Note: We skip gitignore filtering for explicitly specified files
|
|
||||||
try:
|
|
||||||
# Group files by their parent directory for efficient loading
|
|
||||||
from collections import defaultdict
|
|
||||||
|
|
||||||
files_by_dir = defaultdict(list)
|
|
||||||
for file_path in files:
|
|
||||||
parent_dir = str(Path(file_path).parent)
|
|
||||||
files_by_dir[parent_dir].append(file_path)
|
|
||||||
|
|
||||||
# Load files from each parent directory
|
|
||||||
for parent_dir, file_list in files_by_dir.items():
|
|
||||||
print(
|
|
||||||
f" Loading {len(file_list)} file{'s' if len(file_list) > 1 else ''} from {parent_dir}"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
file_docs = SimpleDirectoryReader(
|
|
||||||
parent_dir,
|
|
||||||
input_files=file_list,
|
|
||||||
filename_as_id=True,
|
|
||||||
).load_data()
|
|
||||||
all_documents.extend(file_docs)
|
|
||||||
print(
|
|
||||||
f" ✅ Loaded {len(file_docs)} document{'s' if len(file_docs) > 1 else ''}"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print(f" ❌ Warning: Could not load files from {parent_dir}: {e}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ Error processing individual files: {e}")
|
|
||||||
|
|
||||||
# Define file extensions to process
|
|
||||||
if custom_file_types:
|
|
||||||
# Parse custom file types from comma-separated string
|
|
||||||
code_extensions = [ext.strip() for ext in custom_file_types.split(",") if ext.strip()]
|
|
||||||
# Ensure extensions start with a dot
|
|
||||||
code_extensions = [ext if ext.startswith(".") else f".{ext}" for ext in code_extensions]
|
|
||||||
else:
|
|
||||||
# Use default supported file types
|
|
||||||
code_extensions = [
|
|
||||||
# Original document types
|
|
||||||
".txt",
|
|
||||||
".md",
|
|
||||||
".docx",
|
|
||||||
".pptx",
|
|
||||||
# Code files for Claude Code integration
|
|
||||||
".py",
|
|
||||||
".js",
|
|
||||||
".ts",
|
|
||||||
".jsx",
|
|
||||||
".tsx",
|
|
||||||
".java",
|
|
||||||
".cpp",
|
|
||||||
".c",
|
|
||||||
".h",
|
|
||||||
".hpp",
|
|
||||||
".cs",
|
|
||||||
".go",
|
|
||||||
".rs",
|
|
||||||
".rb",
|
|
||||||
".php",
|
|
||||||
".swift",
|
|
||||||
".kt",
|
|
||||||
".scala",
|
|
||||||
".r",
|
|
||||||
".sql",
|
|
||||||
".sh",
|
|
||||||
".bash",
|
|
||||||
".zsh",
|
|
||||||
".fish",
|
|
||||||
".ps1",
|
|
||||||
".bat",
|
|
||||||
# Config and markup files
|
|
||||||
".json",
|
|
||||||
".yaml",
|
|
||||||
".yml",
|
|
||||||
".xml",
|
|
||||||
".toml",
|
|
||||||
".ini",
|
|
||||||
".cfg",
|
|
||||||
".conf",
|
|
||||||
".html",
|
|
||||||
".css",
|
|
||||||
".scss",
|
|
||||||
".less",
|
|
||||||
".vue",
|
|
||||||
".svelte",
|
|
||||||
# Data science
|
|
||||||
".ipynb",
|
|
||||||
".R",
|
|
||||||
".py",
|
|
||||||
".jl",
|
|
||||||
]
|
|
||||||
|
|
||||||
# Process each directory
|
|
||||||
if directories:
|
|
||||||
print(
|
print(
|
||||||
f"\n🔄 Processing {len(directories)} director{'ies' if len(directories) > 1 else 'y'}..."
|
"No indexes found. Use 'leann build <name> --docs <dir>' to create one."
|
||||||
)
|
)
|
||||||
|
return
|
||||||
|
|
||||||
for docs_dir in directories:
|
print(f"Found {len(index_dirs)} indexes:")
|
||||||
print(f"Processing directory: {docs_dir}")
|
for i, index_dir in enumerate(index_dirs, 1):
|
||||||
# Build gitignore parser for each directory
|
index_name = index_dir.name
|
||||||
gitignore_matches = self._build_gitignore_parser(docs_dir)
|
status = "✓" if self.index_exists(index_name) else "✗"
|
||||||
|
|
||||||
# Try to use better PDF parsers first, but only if PDFs are requested
|
print(f" {i}. {index_name} [{status}]")
|
||||||
documents = []
|
if self.index_exists(index_name):
|
||||||
docs_path = Path(docs_dir)
|
meta_file = index_dir / "documents.leann.meta.json"
|
||||||
|
size_mb = sum(
|
||||||
|
f.stat().st_size for f in index_dir.iterdir() if f.is_file()
|
||||||
|
) / (1024 * 1024)
|
||||||
|
print(f" Size: {size_mb:.1f} MB")
|
||||||
|
|
||||||
# Check if we should process PDFs
|
if index_dirs:
|
||||||
should_process_pdfs = custom_file_types is None or ".pdf" in custom_file_types
|
example_name = index_dirs[0].name
|
||||||
|
print(f"\nUsage:")
|
||||||
|
print(f' leann search {example_name} "your query"')
|
||||||
|
print(f" leann ask {example_name} --interactive")
|
||||||
|
|
||||||
if should_process_pdfs:
|
def load_documents(self, docs_dir: str):
|
||||||
for file_path in docs_path.rglob("*.pdf"):
|
print(f"Loading documents from {docs_dir}...")
|
||||||
# Check if file matches any exclude pattern
|
|
||||||
try:
|
|
||||||
relative_path = file_path.relative_to(docs_path)
|
|
||||||
if self._should_exclude_file(relative_path, gitignore_matches):
|
|
||||||
continue
|
|
||||||
except ValueError:
|
|
||||||
# Skip files that can't be made relative to docs_path
|
|
||||||
print(f"⚠️ Skipping file outside directory scope: {file_path}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
print(f"Processing PDF: {file_path}")
|
documents = SimpleDirectoryReader(
|
||||||
|
docs_dir,
|
||||||
# Try PyMuPDF first (best quality)
|
recursive=True,
|
||||||
text = extract_pdf_text_with_pymupdf(str(file_path))
|
encoding="utf-8",
|
||||||
if text is None:
|
required_exts=[".pdf", ".txt", ".md", ".docx"],
|
||||||
# Try pdfplumber
|
).load_data(show_progress=True)
|
||||||
text = extract_pdf_text_with_pdfplumber(str(file_path))
|
|
||||||
|
|
||||||
if text:
|
|
||||||
# Create a simple document structure
|
|
||||||
from llama_index.core import Document
|
|
||||||
|
|
||||||
doc = Document(text=text, metadata={"source": str(file_path)})
|
|
||||||
documents.append(doc)
|
|
||||||
else:
|
|
||||||
# Fallback to default reader
|
|
||||||
print(f"Using default reader for {file_path}")
|
|
||||||
try:
|
|
||||||
default_docs = SimpleDirectoryReader(
|
|
||||||
str(file_path.parent),
|
|
||||||
filename_as_id=True,
|
|
||||||
required_exts=[file_path.suffix],
|
|
||||||
).load_data()
|
|
||||||
documents.extend(default_docs)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Warning: Could not process {file_path}: {e}")
|
|
||||||
|
|
||||||
# Load other file types with default reader
|
|
||||||
try:
|
|
||||||
# Create a custom file filter function using our PathSpec
|
|
||||||
def file_filter(
|
|
||||||
file_path: str, docs_dir=docs_dir, gitignore_matches=gitignore_matches
|
|
||||||
) -> bool:
|
|
||||||
"""Return True if file should be included (not excluded)"""
|
|
||||||
try:
|
|
||||||
docs_path_obj = Path(docs_dir)
|
|
||||||
file_path_obj = Path(file_path)
|
|
||||||
relative_path = file_path_obj.relative_to(docs_path_obj)
|
|
||||||
return not self._should_exclude_file(relative_path, gitignore_matches)
|
|
||||||
except (ValueError, OSError):
|
|
||||||
return True # Include files that can't be processed
|
|
||||||
|
|
||||||
other_docs = SimpleDirectoryReader(
|
|
||||||
docs_dir,
|
|
||||||
recursive=True,
|
|
||||||
encoding="utf-8",
|
|
||||||
required_exts=code_extensions,
|
|
||||||
file_extractor={}, # Use default extractors
|
|
||||||
filename_as_id=True,
|
|
||||||
).load_data(show_progress=True)
|
|
||||||
|
|
||||||
# Filter documents after loading based on gitignore rules
|
|
||||||
filtered_docs = []
|
|
||||||
for doc in other_docs:
|
|
||||||
file_path = doc.metadata.get("file_path", "")
|
|
||||||
if file_filter(file_path):
|
|
||||||
filtered_docs.append(doc)
|
|
||||||
|
|
||||||
documents.extend(filtered_docs)
|
|
||||||
except ValueError as e:
|
|
||||||
if "No files found" in str(e):
|
|
||||||
print(f"No additional files found for other supported types in {docs_dir}.")
|
|
||||||
else:
|
|
||||||
raise e
|
|
||||||
|
|
||||||
all_documents.extend(documents)
|
|
||||||
print(f"Loaded {len(documents)} documents from {docs_dir}")
|
|
||||||
|
|
||||||
documents = all_documents
|
|
||||||
|
|
||||||
all_texts = []
|
all_texts = []
|
||||||
|
for doc in documents:
|
||||||
# Define code file extensions for intelligent chunking
|
nodes = self.node_parser.get_nodes_from_documents([doc])
|
||||||
code_file_exts = {
|
|
||||||
".py",
|
|
||||||
".js",
|
|
||||||
".ts",
|
|
||||||
".jsx",
|
|
||||||
".tsx",
|
|
||||||
".java",
|
|
||||||
".cpp",
|
|
||||||
".c",
|
|
||||||
".h",
|
|
||||||
".hpp",
|
|
||||||
".cs",
|
|
||||||
".go",
|
|
||||||
".rs",
|
|
||||||
".rb",
|
|
||||||
".php",
|
|
||||||
".swift",
|
|
||||||
".kt",
|
|
||||||
".scala",
|
|
||||||
".r",
|
|
||||||
".sql",
|
|
||||||
".sh",
|
|
||||||
".bash",
|
|
||||||
".zsh",
|
|
||||||
".fish",
|
|
||||||
".ps1",
|
|
||||||
".bat",
|
|
||||||
".json",
|
|
||||||
".yaml",
|
|
||||||
".yml",
|
|
||||||
".xml",
|
|
||||||
".toml",
|
|
||||||
".ini",
|
|
||||||
".cfg",
|
|
||||||
".conf",
|
|
||||||
".html",
|
|
||||||
".css",
|
|
||||||
".scss",
|
|
||||||
".less",
|
|
||||||
".vue",
|
|
||||||
".svelte",
|
|
||||||
".ipynb",
|
|
||||||
".R",
|
|
||||||
".jl",
|
|
||||||
}
|
|
||||||
|
|
||||||
print("start chunking documents")
|
|
||||||
# Add progress bar for document chunking
|
|
||||||
for doc in tqdm(documents, desc="Chunking documents", unit="doc"):
|
|
||||||
# Check if this is a code file based on source path
|
|
||||||
source_path = doc.metadata.get("source", "")
|
|
||||||
is_code_file = any(source_path.endswith(ext) for ext in code_file_exts)
|
|
||||||
|
|
||||||
# Use appropriate parser based on file type
|
|
||||||
parser = self.code_parser if is_code_file else self.node_parser
|
|
||||||
nodes = parser.get_nodes_from_documents([doc])
|
|
||||||
|
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
all_texts.append(node.get_content())
|
all_texts.append(node.get_content())
|
||||||
|
|
||||||
@@ -697,36 +162,16 @@ Examples:
|
|||||||
return all_texts
|
return all_texts
|
||||||
|
|
||||||
async def build_index(self, args):
|
async def build_index(self, args):
|
||||||
docs_paths = args.docs
|
docs_dir = args.docs
|
||||||
# Use current directory name if index_name not provided
|
index_name = args.index_name
|
||||||
if args.index_name:
|
|
||||||
index_name = args.index_name
|
|
||||||
else:
|
|
||||||
index_name = Path.cwd().name
|
|
||||||
print(f"Using current directory name as index: '{index_name}'")
|
|
||||||
|
|
||||||
index_dir = self.indexes_dir / index_name
|
index_dir = self.indexes_dir / index_name
|
||||||
index_path = self.get_index_path(index_name)
|
index_path = self.get_index_path(index_name)
|
||||||
|
|
||||||
# Display all paths being indexed with file/directory distinction
|
|
||||||
files = [p for p in docs_paths if Path(p).is_file()]
|
|
||||||
directories = [p for p in docs_paths if Path(p).is_dir()]
|
|
||||||
|
|
||||||
print(f"📂 Indexing {len(docs_paths)} path{'s' if len(docs_paths) > 1 else ''}:")
|
|
||||||
if files:
|
|
||||||
print(f" 📄 Files ({len(files)}):")
|
|
||||||
for i, file_path in enumerate(files, 1):
|
|
||||||
print(f" {i}. {Path(file_path).resolve()}")
|
|
||||||
if directories:
|
|
||||||
print(f" 📁 Directories ({len(directories)}):")
|
|
||||||
for i, dir_path in enumerate(directories, 1):
|
|
||||||
print(f" {i}. {Path(dir_path).resolve()}")
|
|
||||||
|
|
||||||
if index_dir.exists() and not args.force:
|
if index_dir.exists() and not args.force:
|
||||||
print(f"Index '{index_name}' already exists. Use --force to rebuild.")
|
print(f"Index '{index_name}' already exists. Use --force to rebuild.")
|
||||||
return
|
return
|
||||||
|
|
||||||
all_texts = self.load_documents(docs_paths, args.file_types)
|
all_texts = self.load_documents(docs_dir)
|
||||||
if not all_texts:
|
if not all_texts:
|
||||||
print("No documents found")
|
print("No documents found")
|
||||||
return
|
return
|
||||||
@@ -738,7 +183,6 @@ Examples:
|
|||||||
builder = LeannBuilder(
|
builder = LeannBuilder(
|
||||||
backend_name=args.backend,
|
backend_name=args.backend,
|
||||||
embedding_model=args.embedding_model,
|
embedding_model=args.embedding_model,
|
||||||
embedding_mode=args.embedding_mode,
|
|
||||||
graph_degree=args.graph_degree,
|
graph_degree=args.graph_degree,
|
||||||
complexity=args.complexity,
|
complexity=args.complexity,
|
||||||
is_compact=args.compact,
|
is_compact=args.compact,
|
||||||
@@ -752,9 +196,6 @@ Examples:
|
|||||||
builder.build_index(index_path)
|
builder.build_index(index_path)
|
||||||
print(f"Index built at {index_path}")
|
print(f"Index built at {index_path}")
|
||||||
|
|
||||||
# Register this project directory in global registry
|
|
||||||
self.register_project_dir()
|
|
||||||
|
|
||||||
async def search_documents(self, args):
|
async def search_documents(self, args):
|
||||||
index_name = args.index_name
|
index_name = args.index_name
|
||||||
query = args.query
|
query = args.query
|
||||||
@@ -762,7 +203,7 @@ Examples:
|
|||||||
|
|
||||||
if not self.index_exists(index_name):
|
if not self.index_exists(index_name):
|
||||||
print(
|
print(
|
||||||
f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir> [<dir2> ...]' to create it."
|
f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it."
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -789,7 +230,7 @@ Examples:
|
|||||||
|
|
||||||
if not self.index_exists(index_name):
|
if not self.index_exists(index_name):
|
||||||
print(
|
print(
|
||||||
f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir> [<dir2> ...]' to create it."
|
f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it."
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -815,11 +256,6 @@ Examples:
|
|||||||
if not user_input:
|
if not user_input:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Prepare LLM kwargs with thinking budget if specified
|
|
||||||
llm_kwargs = {}
|
|
||||||
if args.thinking_budget:
|
|
||||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
|
||||||
|
|
||||||
response = chat.ask(
|
response = chat.ask(
|
||||||
user_input,
|
user_input,
|
||||||
top_k=args.top_k,
|
top_k=args.top_k,
|
||||||
@@ -828,17 +264,11 @@ Examples:
|
|||||||
prune_ratio=args.prune_ratio,
|
prune_ratio=args.prune_ratio,
|
||||||
recompute_embeddings=args.recompute_embeddings,
|
recompute_embeddings=args.recompute_embeddings,
|
||||||
pruning_strategy=args.pruning_strategy,
|
pruning_strategy=args.pruning_strategy,
|
||||||
llm_kwargs=llm_kwargs,
|
|
||||||
)
|
)
|
||||||
print(f"LEANN: {response}")
|
print(f"LEANN: {response}")
|
||||||
else:
|
else:
|
||||||
query = input("Enter your question: ").strip()
|
query = input("Enter your question: ").strip()
|
||||||
if query:
|
if query:
|
||||||
# Prepare LLM kwargs with thinking budget if specified
|
|
||||||
llm_kwargs = {}
|
|
||||||
if args.thinking_budget:
|
|
||||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
|
||||||
|
|
||||||
response = chat.ask(
|
response = chat.ask(
|
||||||
query,
|
query,
|
||||||
top_k=args.top_k,
|
top_k=args.top_k,
|
||||||
@@ -847,7 +277,6 @@ Examples:
|
|||||||
prune_ratio=args.prune_ratio,
|
prune_ratio=args.prune_ratio,
|
||||||
recompute_embeddings=args.recompute_embeddings,
|
recompute_embeddings=args.recompute_embeddings,
|
||||||
pruning_strategy=args.pruning_strategy,
|
pruning_strategy=args.pruning_strategy,
|
||||||
llm_kwargs=llm_kwargs,
|
|
||||||
)
|
)
|
||||||
print(f"LEANN: {response}")
|
print(f"LEANN: {response}")
|
||||||
|
|
||||||
|
|||||||
@@ -4,12 +4,11 @@ Consolidates all embedding computation logic using SentenceTransformer
|
|||||||
Preserves all optimization parameters to ensure performance
|
Preserves all optimization parameters to ensure performance
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
# Set up logger with proper level
|
# Set up logger with proper level
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -18,11 +17,11 @@ log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
|
|||||||
logger.setLevel(log_level)
|
logger.setLevel(log_level)
|
||||||
|
|
||||||
# Global model cache to avoid repeated loading
|
# Global model cache to avoid repeated loading
|
||||||
_model_cache: dict[str, Any] = {}
|
_model_cache: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
|
||||||
def compute_embeddings(
|
def compute_embeddings(
|
||||||
texts: list[str],
|
texts: List[str],
|
||||||
model_name: str,
|
model_name: str,
|
||||||
mode: str = "sentence-transformers",
|
mode: str = "sentence-transformers",
|
||||||
is_build: bool = False,
|
is_build: bool = False,
|
||||||
@@ -35,7 +34,7 @@ def compute_embeddings(
|
|||||||
Args:
|
Args:
|
||||||
texts: List of texts to compute embeddings for
|
texts: List of texts to compute embeddings for
|
||||||
model_name: Model name
|
model_name: Model name
|
||||||
mode: Computation mode ('sentence-transformers', 'openai', 'mlx', 'ollama')
|
mode: Computation mode ('sentence-transformers', 'openai', 'mlx')
|
||||||
is_build: Whether this is a build operation (shows progress bar)
|
is_build: Whether this is a build operation (shows progress bar)
|
||||||
batch_size: Batch size for processing
|
batch_size: Batch size for processing
|
||||||
adaptive_optimization: Whether to use adaptive optimization based on batch size
|
adaptive_optimization: Whether to use adaptive optimization based on batch size
|
||||||
@@ -55,14 +54,12 @@ def compute_embeddings(
|
|||||||
return compute_embeddings_openai(texts, model_name)
|
return compute_embeddings_openai(texts, model_name)
|
||||||
elif mode == "mlx":
|
elif mode == "mlx":
|
||||||
return compute_embeddings_mlx(texts, model_name)
|
return compute_embeddings_mlx(texts, model_name)
|
||||||
elif mode == "ollama":
|
|
||||||
return compute_embeddings_ollama(texts, model_name, is_build=is_build)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported embedding mode: {mode}")
|
raise ValueError(f"Unsupported embedding mode: {mode}")
|
||||||
|
|
||||||
|
|
||||||
def compute_embeddings_sentence_transformers(
|
def compute_embeddings_sentence_transformers(
|
||||||
texts: list[str],
|
texts: List[str],
|
||||||
model_name: str,
|
model_name: str,
|
||||||
use_fp16: bool = True,
|
use_fp16: bool = True,
|
||||||
device: str = "auto",
|
device: str = "auto",
|
||||||
@@ -104,7 +101,7 @@ def compute_embeddings_sentence_transformers(
|
|||||||
if device == "mps":
|
if device == "mps":
|
||||||
batch_size = 128 # MPS optimal batch size from benchmark
|
batch_size = 128 # MPS optimal batch size from benchmark
|
||||||
if model_name == "Qwen/Qwen3-Embedding-0.6B":
|
if model_name == "Qwen/Qwen3-Embedding-0.6B":
|
||||||
batch_size = 32
|
batch_size = 64
|
||||||
elif device == "cuda":
|
elif device == "cuda":
|
||||||
batch_size = 256 # CUDA optimal batch size
|
batch_size = 256 # CUDA optimal batch size
|
||||||
# Keep original batch_size for CPU
|
# Keep original batch_size for CPU
|
||||||
@@ -117,7 +114,9 @@ def compute_embeddings_sentence_transformers(
|
|||||||
logger.info(f"Using cached optimized model: {model_name}")
|
logger.info(f"Using cached optimized model: {model_name}")
|
||||||
model = _model_cache[cache_key]
|
model = _model_cache[cache_key]
|
||||||
else:
|
else:
|
||||||
logger.info(f"Loading and caching optimized SentenceTransformer model: {model_name}")
|
logger.info(
|
||||||
|
f"Loading and caching optimized SentenceTransformer model: {model_name}"
|
||||||
|
)
|
||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
logger.info(f"Using device: {device}")
|
logger.info(f"Using device: {device}")
|
||||||
@@ -135,7 +134,9 @@ def compute_embeddings_sentence_transformers(
|
|||||||
if hasattr(torch.mps, "set_per_process_memory_fraction"):
|
if hasattr(torch.mps, "set_per_process_memory_fraction"):
|
||||||
torch.mps.set_per_process_memory_fraction(0.9)
|
torch.mps.set_per_process_memory_fraction(0.9)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
logger.warning("Some MPS optimizations not available in this PyTorch version")
|
logger.warning(
|
||||||
|
"Some MPS optimizations not available in this PyTorch version"
|
||||||
|
)
|
||||||
elif device == "cpu":
|
elif device == "cpu":
|
||||||
# TODO: Haven't tested this yet
|
# TODO: Haven't tested this yet
|
||||||
torch.set_num_threads(min(8, os.cpu_count() or 4))
|
torch.set_num_threads(min(8, os.cpu_count() or 4))
|
||||||
@@ -225,22 +226,25 @@ def compute_embeddings_sentence_transformers(
|
|||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
|
logger.info(
|
||||||
|
f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
|
||||||
|
)
|
||||||
|
|
||||||
# Validate results
|
# Validate results
|
||||||
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||||
raise RuntimeError(f"Detected NaN or Inf values in embeddings, model: {model_name}")
|
raise RuntimeError(
|
||||||
|
f"Detected NaN or Inf values in embeddings, model: {model_name}"
|
||||||
|
)
|
||||||
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
|
def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
|
||||||
# TODO: @yichuan-w add progress bar only in build mode
|
# TODO: @yichuan-w add progress bar only in build mode
|
||||||
"""Compute embeddings using OpenAI API"""
|
"""Compute embeddings using OpenAI API"""
|
||||||
try:
|
try:
|
||||||
import os
|
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
|
import os
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise ImportError(f"OpenAI package not installed: {e}")
|
raise ImportError(f"OpenAI package not installed: {e}")
|
||||||
|
|
||||||
@@ -260,10 +264,9 @@ def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
|
f"Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
|
||||||
)
|
)
|
||||||
print(f"len of texts: {len(texts)}")
|
|
||||||
|
|
||||||
# OpenAI has limits on batch size and input length
|
# OpenAI has limits on batch size and input length
|
||||||
max_batch_size = 1000 # Conservative batch size
|
max_batch_size = 100 # Conservative batch size
|
||||||
all_embeddings = []
|
all_embeddings = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -290,12 +293,15 @@ def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
embeddings = np.array(all_embeddings, dtype=np.float32)
|
embeddings = np.array(all_embeddings, dtype=np.float32)
|
||||||
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
|
logger.info(
|
||||||
print(f"len of embeddings: {len(embeddings)}")
|
f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
|
||||||
|
)
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int = 16) -> np.ndarray:
|
def compute_embeddings_mlx(
|
||||||
|
chunks: List[str], model_name: str, batch_size: int = 16
|
||||||
|
) -> np.ndarray:
|
||||||
# TODO: @yichuan-w add progress bar only in build mode
|
# TODO: @yichuan-w add progress bar only in build mode
|
||||||
"""Computes embeddings using an MLX model."""
|
"""Computes embeddings using an MLX model."""
|
||||||
try:
|
try:
|
||||||
@@ -367,286 +373,3 @@ def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int =
|
|||||||
|
|
||||||
# Stack numpy arrays
|
# Stack numpy arrays
|
||||||
return np.stack(all_embeddings)
|
return np.stack(all_embeddings)
|
||||||
|
|
||||||
|
|
||||||
def compute_embeddings_ollama(
|
|
||||||
texts: list[str], model_name: str, is_build: bool = False, host: str = "http://localhost:11434"
|
|
||||||
) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
Compute embeddings using Ollama API with simplified batch processing.
|
|
||||||
|
|
||||||
Uses batch size of 32 for MPS/CPU and 128 for CUDA to optimize performance.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
texts: List of texts to compute embeddings for
|
|
||||||
model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large")
|
|
||||||
is_build: Whether this is a build operation (shows progress bar)
|
|
||||||
host: Ollama host URL (default: http://localhost:11434)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
import requests
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError(
|
|
||||||
"The 'requests' library is required for Ollama embeddings. Install with: uv pip install requests"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not texts:
|
|
||||||
raise ValueError("Cannot compute embeddings for empty text list")
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}'"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check if Ollama is running
|
|
||||||
try:
|
|
||||||
response = requests.get(f"{host}/api/version", timeout=5)
|
|
||||||
response.raise_for_status()
|
|
||||||
except requests.exceptions.ConnectionError:
|
|
||||||
error_msg = (
|
|
||||||
f"❌ Could not connect to Ollama at {host}.\n\n"
|
|
||||||
"Please ensure Ollama is running:\n"
|
|
||||||
" • macOS/Linux: ollama serve\n"
|
|
||||||
" • Windows: Make sure Ollama is running in the system tray\n\n"
|
|
||||||
"Installation: https://ollama.com/download"
|
|
||||||
)
|
|
||||||
raise RuntimeError(error_msg)
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(f"Unexpected error connecting to Ollama: {e}")
|
|
||||||
|
|
||||||
# Check if model exists and provide helpful suggestions
|
|
||||||
try:
|
|
||||||
response = requests.get(f"{host}/api/tags", timeout=5)
|
|
||||||
response.raise_for_status()
|
|
||||||
models = response.json()
|
|
||||||
model_names = [model["name"] for model in models.get("models", [])]
|
|
||||||
|
|
||||||
# Filter for embedding models (models that support embeddings)
|
|
||||||
embedding_models = []
|
|
||||||
suggested_embedding_models = [
|
|
||||||
"nomic-embed-text",
|
|
||||||
"mxbai-embed-large",
|
|
||||||
"bge-m3",
|
|
||||||
"all-minilm",
|
|
||||||
"snowflake-arctic-embed",
|
|
||||||
]
|
|
||||||
|
|
||||||
for model in model_names:
|
|
||||||
# Check if it's an embedding model (by name patterns or known models)
|
|
||||||
base_name = model.split(":")[0]
|
|
||||||
if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5"]):
|
|
||||||
embedding_models.append(model)
|
|
||||||
|
|
||||||
# Check if model exists (handle versioned names) and resolve to full name
|
|
||||||
resolved_model_name = None
|
|
||||||
for name in model_names:
|
|
||||||
# Exact match
|
|
||||||
if model_name == name:
|
|
||||||
resolved_model_name = name
|
|
||||||
break
|
|
||||||
# Match without version tag (use the versioned name)
|
|
||||||
elif model_name == name.split(":")[0]:
|
|
||||||
resolved_model_name = name
|
|
||||||
break
|
|
||||||
|
|
||||||
if not resolved_model_name:
|
|
||||||
error_msg = f"❌ Model '{model_name}' not found in local Ollama.\n\n"
|
|
||||||
|
|
||||||
# Suggest pulling the model
|
|
||||||
error_msg += "📦 To install this embedding model:\n"
|
|
||||||
error_msg += f" ollama pull {model_name}\n\n"
|
|
||||||
|
|
||||||
# Show available embedding models
|
|
||||||
if embedding_models:
|
|
||||||
error_msg += "✅ Available embedding models:\n"
|
|
||||||
for model in embedding_models[:5]:
|
|
||||||
error_msg += f" • {model}\n"
|
|
||||||
if len(embedding_models) > 5:
|
|
||||||
error_msg += f" ... and {len(embedding_models) - 5} more\n"
|
|
||||||
else:
|
|
||||||
error_msg += "💡 Popular embedding models to install:\n"
|
|
||||||
for model in suggested_embedding_models[:3]:
|
|
||||||
error_msg += f" • ollama pull {model}\n"
|
|
||||||
|
|
||||||
error_msg += "\n📚 Browse more: https://ollama.com/library"
|
|
||||||
raise ValueError(error_msg)
|
|
||||||
|
|
||||||
# Use the resolved model name for all subsequent operations
|
|
||||||
if resolved_model_name != model_name:
|
|
||||||
logger.info(f"Resolved model name '{model_name}' to '{resolved_model_name}'")
|
|
||||||
model_name = resolved_model_name
|
|
||||||
|
|
||||||
# Verify the model supports embeddings by testing it
|
|
||||||
try:
|
|
||||||
test_response = requests.post(
|
|
||||||
f"{host}/api/embeddings", json={"model": model_name, "prompt": "test"}, timeout=10
|
|
||||||
)
|
|
||||||
if test_response.status_code != 200:
|
|
||||||
error_msg = (
|
|
||||||
f"⚠️ Model '{model_name}' exists but may not support embeddings.\n\n"
|
|
||||||
f"Please use an embedding model like:\n"
|
|
||||||
)
|
|
||||||
for model in suggested_embedding_models[:3]:
|
|
||||||
error_msg += f" • {model}\n"
|
|
||||||
raise ValueError(error_msg)
|
|
||||||
except requests.exceptions.RequestException:
|
|
||||||
# If test fails, continue anyway - model might still work
|
|
||||||
pass
|
|
||||||
|
|
||||||
except requests.exceptions.RequestException as e:
|
|
||||||
logger.warning(f"Could not verify model existence: {e}")
|
|
||||||
|
|
||||||
# Determine batch size based on device availability
|
|
||||||
# Check for CUDA/MPS availability using torch if available
|
|
||||||
batch_size = 32 # Default for MPS/CPU
|
|
||||||
try:
|
|
||||||
import torch
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
batch_size = 128 # CUDA gets larger batch size
|
|
||||||
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
|
||||||
batch_size = 32 # MPS gets smaller batch size
|
|
||||||
except ImportError:
|
|
||||||
# If torch is not available, use conservative batch size
|
|
||||||
batch_size = 32
|
|
||||||
|
|
||||||
logger.info(f"Using batch size: {batch_size}")
|
|
||||||
|
|
||||||
def get_batch_embeddings(batch_texts):
|
|
||||||
"""Get embeddings for a batch of texts."""
|
|
||||||
all_embeddings = []
|
|
||||||
failed_indices = []
|
|
||||||
|
|
||||||
for i, text in enumerate(batch_texts):
|
|
||||||
max_retries = 3
|
|
||||||
retry_count = 0
|
|
||||||
|
|
||||||
# Truncate very long texts to avoid API issues
|
|
||||||
truncated_text = text[:8000] if len(text) > 8000 else text
|
|
||||||
while retry_count < max_retries:
|
|
||||||
try:
|
|
||||||
response = requests.post(
|
|
||||||
f"{host}/api/embeddings",
|
|
||||||
json={"model": model_name, "prompt": truncated_text},
|
|
||||||
timeout=30,
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
result = response.json()
|
|
||||||
embedding = result.get("embedding")
|
|
||||||
|
|
||||||
if embedding is None:
|
|
||||||
raise ValueError(f"No embedding returned for text {i}")
|
|
||||||
|
|
||||||
if not isinstance(embedding, list) or len(embedding) == 0:
|
|
||||||
raise ValueError(f"Invalid embedding format for text {i}")
|
|
||||||
|
|
||||||
all_embeddings.append(embedding)
|
|
||||||
break
|
|
||||||
|
|
||||||
except requests.exceptions.Timeout:
|
|
||||||
retry_count += 1
|
|
||||||
if retry_count >= max_retries:
|
|
||||||
logger.warning(f"Timeout for text {i} after {max_retries} retries")
|
|
||||||
failed_indices.append(i)
|
|
||||||
all_embeddings.append(None)
|
|
||||||
break
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
retry_count += 1
|
|
||||||
if retry_count >= max_retries:
|
|
||||||
logger.error(f"Failed to get embedding for text {i}: {e}")
|
|
||||||
failed_indices.append(i)
|
|
||||||
all_embeddings.append(None)
|
|
||||||
break
|
|
||||||
return all_embeddings, failed_indices
|
|
||||||
|
|
||||||
# Process texts in batches
|
|
||||||
all_embeddings = []
|
|
||||||
all_failed_indices = []
|
|
||||||
|
|
||||||
# Setup progress bar if needed
|
|
||||||
show_progress = is_build or len(texts) > 10
|
|
||||||
try:
|
|
||||||
if show_progress:
|
|
||||||
from tqdm import tqdm
|
|
||||||
except ImportError:
|
|
||||||
show_progress = False
|
|
||||||
|
|
||||||
# Process batches
|
|
||||||
num_batches = (len(texts) + batch_size - 1) // batch_size
|
|
||||||
|
|
||||||
if show_progress:
|
|
||||||
batch_iterator = tqdm(range(num_batches), desc="Computing Ollama embeddings")
|
|
||||||
else:
|
|
||||||
batch_iterator = range(num_batches)
|
|
||||||
|
|
||||||
for batch_idx in batch_iterator:
|
|
||||||
start_idx = batch_idx * batch_size
|
|
||||||
end_idx = min(start_idx + batch_size, len(texts))
|
|
||||||
batch_texts = texts[start_idx:end_idx]
|
|
||||||
|
|
||||||
batch_embeddings, batch_failed = get_batch_embeddings(batch_texts)
|
|
||||||
|
|
||||||
# Adjust failed indices to global indices
|
|
||||||
global_failed = [start_idx + idx for idx in batch_failed]
|
|
||||||
all_failed_indices.extend(global_failed)
|
|
||||||
all_embeddings.extend(batch_embeddings)
|
|
||||||
|
|
||||||
# Handle failed embeddings
|
|
||||||
if all_failed_indices:
|
|
||||||
if len(all_failed_indices) == len(texts):
|
|
||||||
raise RuntimeError("Failed to compute any embeddings")
|
|
||||||
|
|
||||||
logger.warning(
|
|
||||||
f"Failed to compute embeddings for {len(all_failed_indices)}/{len(texts)} texts"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use zero embeddings as fallback for failed ones
|
|
||||||
valid_embedding = next((e for e in all_embeddings if e is not None), None)
|
|
||||||
if valid_embedding:
|
|
||||||
embedding_dim = len(valid_embedding)
|
|
||||||
for i, embedding in enumerate(all_embeddings):
|
|
||||||
if embedding is None:
|
|
||||||
all_embeddings[i] = [0.0] * embedding_dim
|
|
||||||
|
|
||||||
# Remove None values
|
|
||||||
all_embeddings = [e for e in all_embeddings if e is not None]
|
|
||||||
|
|
||||||
if not all_embeddings:
|
|
||||||
raise RuntimeError("No valid embeddings were computed")
|
|
||||||
|
|
||||||
# Validate embedding dimensions
|
|
||||||
expected_dim = len(all_embeddings[0])
|
|
||||||
inconsistent_dims = []
|
|
||||||
for i, embedding in enumerate(all_embeddings):
|
|
||||||
if len(embedding) != expected_dim:
|
|
||||||
inconsistent_dims.append((i, len(embedding)))
|
|
||||||
|
|
||||||
if inconsistent_dims:
|
|
||||||
error_msg = f"Ollama returned inconsistent embedding dimensions. Expected {expected_dim}, but got:\n"
|
|
||||||
for idx, dim in inconsistent_dims[:10]: # Show first 10 inconsistent ones
|
|
||||||
error_msg += f" - Text {idx}: {dim} dimensions\n"
|
|
||||||
if len(inconsistent_dims) > 10:
|
|
||||||
error_msg += f" ... and {len(inconsistent_dims) - 10} more\n"
|
|
||||||
error_msg += f"\nThis is likely an Ollama API bug with model '{model_name}'. Please try:\n"
|
|
||||||
error_msg += "1. Restart Ollama service: 'ollama serve'\n"
|
|
||||||
error_msg += f"2. Re-pull the model: 'ollama pull {model_name}'\n"
|
|
||||||
error_msg += (
|
|
||||||
"3. Use sentence-transformers instead: --embedding-mode sentence-transformers\n"
|
|
||||||
)
|
|
||||||
error_msg += "4. Report this issue to Ollama: https://github.com/ollama/ollama/issues"
|
|
||||||
raise ValueError(error_msg)
|
|
||||||
|
|
||||||
# Convert to numpy array and normalize
|
|
||||||
embeddings = np.array(all_embeddings, dtype=np.float32)
|
|
||||||
|
|
||||||
# Normalize embeddings (L2 normalization)
|
|
||||||
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
|
|
||||||
embeddings = embeddings / (norms + 1e-8) # Add small epsilon to avoid division by zero
|
|
||||||
|
|
||||||
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
|
|
||||||
|
|
||||||
return embeddings
|
|
||||||
|
|||||||
@@ -1,14 +1,13 @@
|
|||||||
|
import time
|
||||||
import atexit
|
import atexit
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import socket
|
import socket
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import time
|
import os
|
||||||
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
import psutil
|
||||||
# Lightweight, self-contained server manager with no cross-process inspection
|
|
||||||
|
|
||||||
# Set up logging based on environment variable
|
# Set up logging based on environment variable
|
||||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||||
@@ -19,31 +18,136 @@ logging.basicConfig(
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _is_colab_environment() -> bool:
|
|
||||||
"""Check if we're running in Google Colab environment."""
|
|
||||||
return "COLAB_GPU" in os.environ or "COLAB_TPU" in os.environ
|
|
||||||
|
|
||||||
|
|
||||||
def _get_available_port(start_port: int = 5557) -> int:
|
|
||||||
"""Get an available port starting from start_port."""
|
|
||||||
port = start_port
|
|
||||||
while port < start_port + 100: # Try up to 100 ports
|
|
||||||
try:
|
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
|
||||||
s.bind(("localhost", port))
|
|
||||||
return port
|
|
||||||
except OSError:
|
|
||||||
port += 1
|
|
||||||
raise RuntimeError(f"No available ports found in range {start_port}-{start_port + 100}")
|
|
||||||
|
|
||||||
|
|
||||||
def _check_port(port: int) -> bool:
|
def _check_port(port: int) -> bool:
|
||||||
"""Check if a port is in use"""
|
"""Check if a port is in use"""
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
return s.connect_ex(("localhost", port)) == 0
|
return s.connect_ex(("localhost", port)) == 0
|
||||||
|
|
||||||
|
|
||||||
# Note: All cross-process scanning helpers removed for simplicity
|
def _check_process_matches_config(
|
||||||
|
port: int, expected_model: str, expected_passages_file: str
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the process using the port matches our expected model and passages file.
|
||||||
|
Returns True if matches, False otherwise.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
for proc in psutil.process_iter(["pid", "cmdline"]):
|
||||||
|
if not _is_process_listening_on_port(proc, port):
|
||||||
|
continue
|
||||||
|
|
||||||
|
cmdline = proc.info["cmdline"]
|
||||||
|
if not cmdline:
|
||||||
|
continue
|
||||||
|
|
||||||
|
return _check_cmdline_matches_config(
|
||||||
|
cmdline, port, expected_model, expected_passages_file
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"No process found listening on port {port}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Could not check process on port {port}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _is_process_listening_on_port(proc, port: int) -> bool:
|
||||||
|
"""Check if a process is listening on the given port."""
|
||||||
|
try:
|
||||||
|
connections = proc.net_connections()
|
||||||
|
for conn in connections:
|
||||||
|
if conn.laddr.port == port and conn.status == psutil.CONN_LISTEN:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _check_cmdline_matches_config(
|
||||||
|
cmdline: list, port: int, expected_model: str, expected_passages_file: str
|
||||||
|
) -> bool:
|
||||||
|
"""Check if command line matches our expected configuration."""
|
||||||
|
cmdline_str = " ".join(cmdline)
|
||||||
|
logger.debug(f"Found process on port {port}: {cmdline_str}")
|
||||||
|
|
||||||
|
# Check if it's our embedding server
|
||||||
|
is_embedding_server = any(
|
||||||
|
server_type in cmdline_str
|
||||||
|
for server_type in [
|
||||||
|
"embedding_server",
|
||||||
|
"leann_backend_diskann.embedding_server",
|
||||||
|
"leann_backend_hnsw.hnsw_embedding_server",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
if not is_embedding_server:
|
||||||
|
logger.debug(f"Process on port {port} is not our embedding server")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check model name
|
||||||
|
model_matches = _check_model_in_cmdline(cmdline, expected_model)
|
||||||
|
|
||||||
|
# Check passages file if provided
|
||||||
|
passages_matches = _check_passages_in_cmdline(cmdline, expected_passages_file)
|
||||||
|
|
||||||
|
result = model_matches and passages_matches
|
||||||
|
logger.debug(
|
||||||
|
f"model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}"
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _check_model_in_cmdline(cmdline: list, expected_model: str) -> bool:
|
||||||
|
"""Check if the command line contains the expected model."""
|
||||||
|
if "--model-name" not in cmdline:
|
||||||
|
return False
|
||||||
|
|
||||||
|
model_idx = cmdline.index("--model-name")
|
||||||
|
if model_idx + 1 >= len(cmdline):
|
||||||
|
return False
|
||||||
|
|
||||||
|
actual_model = cmdline[model_idx + 1]
|
||||||
|
return actual_model == expected_model
|
||||||
|
|
||||||
|
|
||||||
|
def _check_passages_in_cmdline(cmdline: list, expected_passages_file: str) -> bool:
|
||||||
|
"""Check if the command line contains the expected passages file."""
|
||||||
|
if "--passages-file" not in cmdline:
|
||||||
|
return False # Expected but not found
|
||||||
|
|
||||||
|
passages_idx = cmdline.index("--passages-file")
|
||||||
|
if passages_idx + 1 >= len(cmdline):
|
||||||
|
return False
|
||||||
|
|
||||||
|
actual_passages = cmdline[passages_idx + 1]
|
||||||
|
expected_path = Path(expected_passages_file).resolve()
|
||||||
|
actual_path = Path(actual_passages).resolve()
|
||||||
|
return actual_path == expected_path
|
||||||
|
|
||||||
|
|
||||||
|
def _find_compatible_port_or_next_available(
|
||||||
|
start_port: int, model_name: str, passages_file: str, max_attempts: int = 100
|
||||||
|
) -> tuple[int, bool]:
|
||||||
|
"""
|
||||||
|
Find a port that either has a compatible server or is available.
|
||||||
|
Returns (port, is_compatible) where is_compatible indicates if we found a matching server.
|
||||||
|
"""
|
||||||
|
for port in range(start_port, start_port + max_attempts):
|
||||||
|
if not _check_port(port):
|
||||||
|
# Port is available
|
||||||
|
return port, False
|
||||||
|
|
||||||
|
# Port is in use, check if it's compatible
|
||||||
|
if _check_process_matches_config(port, model_name, passages_file):
|
||||||
|
logger.info(f"Found compatible server on port {port}")
|
||||||
|
return port, True
|
||||||
|
else:
|
||||||
|
logger.info(f"Port {port} has incompatible server, trying next port...")
|
||||||
|
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Could not find compatible or available port in range {start_port}-{start_port + max_attempts}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingServerManager:
|
class EmbeddingServerManager:
|
||||||
@@ -62,16 +166,7 @@ class EmbeddingServerManager:
|
|||||||
self.backend_module_name = backend_module_name
|
self.backend_module_name = backend_module_name
|
||||||
self.server_process: Optional[subprocess.Popen] = None
|
self.server_process: Optional[subprocess.Popen] = None
|
||||||
self.server_port: Optional[int] = None
|
self.server_port: Optional[int] = None
|
||||||
# Track last-started config for in-process reuse only
|
|
||||||
self._server_config: Optional[dict] = None
|
|
||||||
self._atexit_registered = False
|
self._atexit_registered = False
|
||||||
# Also register a weakref finalizer to ensure cleanup when manager is GC'ed
|
|
||||||
try:
|
|
||||||
import weakref
|
|
||||||
|
|
||||||
self._finalizer = weakref.finalize(self, self._finalize_process)
|
|
||||||
except Exception:
|
|
||||||
self._finalizer = None
|
|
||||||
|
|
||||||
def start_server(
|
def start_server(
|
||||||
self,
|
self,
|
||||||
@@ -80,58 +175,69 @@ class EmbeddingServerManager:
|
|||||||
embedding_mode: str = "sentence-transformers",
|
embedding_mode: str = "sentence-transformers",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> tuple[bool, int]:
|
) -> tuple[bool, int]:
|
||||||
"""Start the embedding server."""
|
"""
|
||||||
# passages_file may be present in kwargs for server CLI, but we don't need it here
|
Starts the embedding server process.
|
||||||
|
|
||||||
# If this manager already has a live server, just reuse it
|
Args:
|
||||||
if self.server_process and self.server_process.poll() is None and self.server_port:
|
port (int): The preferred ZMQ port for the server.
|
||||||
logger.info("Reusing in-process server")
|
model_name (str): The name of the embedding model to use.
|
||||||
|
**kwargs: Additional arguments for the server.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[bool, int]: (success, actual_port_used)
|
||||||
|
"""
|
||||||
|
passages_file = kwargs.get("passages_file")
|
||||||
|
assert isinstance(passages_file, str), "passages_file must be a string"
|
||||||
|
|
||||||
|
# Check if we have a compatible running server
|
||||||
|
if self._has_compatible_running_server(model_name, passages_file):
|
||||||
|
assert self.server_port is not None, (
|
||||||
|
"a compatible running server should set server_port"
|
||||||
|
)
|
||||||
return True, self.server_port
|
return True, self.server_port
|
||||||
|
|
||||||
# For Colab environment, use a different strategy
|
# Find available port (compatible or free)
|
||||||
if _is_colab_environment():
|
|
||||||
logger.info("Detected Colab environment, using alternative startup strategy")
|
|
||||||
return self._start_server_colab(port, model_name, embedding_mode, **kwargs)
|
|
||||||
|
|
||||||
# Always pick a fresh available port
|
|
||||||
try:
|
try:
|
||||||
actual_port = _get_available_port(port)
|
actual_port, is_compatible = _find_compatible_port_or_next_available(
|
||||||
except RuntimeError:
|
port, model_name, passages_file
|
||||||
logger.error("No available ports found")
|
)
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.error(str(e))
|
||||||
return False, port
|
return False, port
|
||||||
|
|
||||||
# Start a new server
|
if is_compatible:
|
||||||
|
logger.info(f"Using existing compatible server on port {actual_port}")
|
||||||
|
self.server_port = actual_port
|
||||||
|
self.server_process = None # We don't own this process
|
||||||
|
return True, actual_port
|
||||||
|
|
||||||
|
if actual_port != port:
|
||||||
|
logger.info(f"Using port {actual_port} instead of {port}")
|
||||||
|
|
||||||
|
# Start new server
|
||||||
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
|
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
|
||||||
|
|
||||||
def _start_server_colab(
|
def _has_compatible_running_server(
|
||||||
self,
|
self, model_name: str, passages_file: str
|
||||||
port: int,
|
) -> bool:
|
||||||
model_name: str,
|
"""Check if we have a compatible running server."""
|
||||||
embedding_mode: str = "sentence-transformers",
|
if not (
|
||||||
**kwargs,
|
self.server_process
|
||||||
) -> tuple[bool, int]:
|
and self.server_process.poll() is None
|
||||||
"""Start server with Colab-specific configuration."""
|
and self.server_port
|
||||||
# Try to find an available port
|
):
|
||||||
try:
|
return False
|
||||||
actual_port = _get_available_port(port)
|
|
||||||
except RuntimeError:
|
|
||||||
logger.error("No available ports found")
|
|
||||||
return False, port
|
|
||||||
|
|
||||||
logger.info(f"Starting server on port {actual_port} for Colab environment")
|
if _check_process_matches_config(self.server_port, model_name, passages_file):
|
||||||
|
logger.info(
|
||||||
|
f"Existing server process (PID {self.server_process.pid}) is compatible"
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
# Use a simpler startup strategy for Colab
|
logger.info(
|
||||||
command = self._build_server_command(actual_port, model_name, embedding_mode, **kwargs)
|
"Existing server process is incompatible. Should start a new server."
|
||||||
|
)
|
||||||
try:
|
return False
|
||||||
# In Colab, we'll use a more direct approach
|
|
||||||
self._launch_server_process_colab(command, actual_port)
|
|
||||||
return self._wait_for_server_ready_colab(actual_port)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to start embedding server in Colab: {e}")
|
|
||||||
return False, actual_port
|
|
||||||
|
|
||||||
# Note: No compatibility check needed; manager is per-searcher and configs are stable per instance
|
|
||||||
|
|
||||||
def _start_new_server(
|
def _start_new_server(
|
||||||
self, port: int, model_name: str, embedding_mode: str, **kwargs
|
self, port: int, model_name: str, embedding_mode: str, **kwargs
|
||||||
@@ -163,13 +269,9 @@ class EmbeddingServerManager:
|
|||||||
]
|
]
|
||||||
|
|
||||||
if kwargs.get("passages_file"):
|
if kwargs.get("passages_file"):
|
||||||
# Convert to absolute path to ensure subprocess can find the file
|
command.extend(["--passages-file", str(kwargs["passages_file"])])
|
||||||
passages_file = Path(kwargs["passages_file"]).resolve()
|
|
||||||
command.extend(["--passages-file", str(passages_file)])
|
|
||||||
if embedding_mode != "sentence-transformers":
|
if embedding_mode != "sentence-transformers":
|
||||||
command.extend(["--embedding-mode", embedding_mode])
|
command.extend(["--embedding-mode", embedding_mode])
|
||||||
if kwargs.get("distance_metric"):
|
|
||||||
command.extend(["--distance-metric", kwargs["distance_metric"]])
|
|
||||||
|
|
||||||
return command
|
return command
|
||||||
|
|
||||||
@@ -178,61 +280,22 @@ class EmbeddingServerManager:
|
|||||||
project_root = Path(__file__).parent.parent.parent.parent.parent
|
project_root = Path(__file__).parent.parent.parent.parent.parent
|
||||||
logger.info(f"Command: {' '.join(command)}")
|
logger.info(f"Command: {' '.join(command)}")
|
||||||
|
|
||||||
# In CI environment, redirect stdout to avoid buffer deadlock but keep stderr for debugging
|
# Let server output go directly to console
|
||||||
# Embedding servers use many print statements that can fill stdout buffers
|
# The server will respect LEANN_LOG_LEVEL environment variable
|
||||||
is_ci = os.environ.get("CI") == "true"
|
|
||||||
if is_ci:
|
|
||||||
stdout_target = subprocess.DEVNULL
|
|
||||||
stderr_target = None # Keep stderr for error debugging in CI
|
|
||||||
logger.info(
|
|
||||||
"CI environment detected, redirecting embedding server stdout to DEVNULL, keeping stderr"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
stdout_target = None # Direct to console for visible logs
|
|
||||||
stderr_target = None # Direct to console for visible logs
|
|
||||||
|
|
||||||
# Start embedding server subprocess
|
|
||||||
self.server_process = subprocess.Popen(
|
self.server_process = subprocess.Popen(
|
||||||
command,
|
command,
|
||||||
cwd=project_root,
|
cwd=project_root,
|
||||||
stdout=stdout_target,
|
stdout=None, # Direct to console
|
||||||
stderr=stderr_target,
|
stderr=None, # Direct to console
|
||||||
)
|
)
|
||||||
self.server_port = port
|
self.server_port = port
|
||||||
# Record config for in-process reuse
|
|
||||||
try:
|
|
||||||
self._server_config = {
|
|
||||||
"model_name": command[command.index("--model-name") + 1]
|
|
||||||
if "--model-name" in command
|
|
||||||
else "",
|
|
||||||
"passages_file": command[command.index("--passages-file") + 1]
|
|
||||||
if "--passages-file" in command
|
|
||||||
else "",
|
|
||||||
"embedding_mode": command[command.index("--embedding-mode") + 1]
|
|
||||||
if "--embedding-mode" in command
|
|
||||||
else "sentence-transformers",
|
|
||||||
}
|
|
||||||
except Exception:
|
|
||||||
self._server_config = {
|
|
||||||
"model_name": "",
|
|
||||||
"passages_file": "",
|
|
||||||
"embedding_mode": "sentence-transformers",
|
|
||||||
}
|
|
||||||
logger.info(f"Server process started with PID: {self.server_process.pid}")
|
logger.info(f"Server process started with PID: {self.server_process.pid}")
|
||||||
|
|
||||||
# Register atexit callback only when we actually start a process
|
# Register atexit callback only when we actually start a process
|
||||||
if not self._atexit_registered:
|
if not self._atexit_registered:
|
||||||
# Always attempt best-effort finalize at interpreter exit
|
# Use a lambda to avoid issues with bound methods
|
||||||
atexit.register(self._finalize_process)
|
atexit.register(lambda: self.stop_server() if self.server_process else None)
|
||||||
self._atexit_registered = True
|
self._atexit_registered = True
|
||||||
# Touch finalizer so it knows there is a live process
|
|
||||||
if getattr(self, "_finalizer", None) is not None and not self._finalizer.alive:
|
|
||||||
try:
|
|
||||||
import weakref
|
|
||||||
|
|
||||||
self._finalizer = weakref.finalize(self, self._finalize_process)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _wait_for_server_ready(self, port: int) -> tuple[bool, int]:
|
def _wait_for_server_ready(self, port: int) -> tuple[bool, int]:
|
||||||
"""Wait for the server to be ready."""
|
"""Wait for the server to be ready."""
|
||||||
@@ -257,114 +320,29 @@ class EmbeddingServerManager:
|
|||||||
if not self.server_process:
|
if not self.server_process:
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.server_process and self.server_process.poll() is not None:
|
if self.server_process.poll() is not None:
|
||||||
# Process already terminated
|
# Process already terminated
|
||||||
self.server_process = None
|
self.server_process = None
|
||||||
self.server_port = None
|
|
||||||
self._server_config = None
|
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
|
f"Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
|
||||||
)
|
)
|
||||||
|
self.server_process.terminate()
|
||||||
# Use simple termination first; if the server installed signal handlers,
|
|
||||||
# it will exit cleanly. Otherwise escalate to kill after a short wait.
|
|
||||||
try:
|
|
||||||
self.server_process.terminate()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.server_process.wait(timeout=5) # Give more time for graceful shutdown
|
self.server_process.wait(timeout=5)
|
||||||
logger.info(f"Server process {self.server_process.pid} terminated gracefully.")
|
logger.info(f"Server process {self.server_process.pid} terminated.")
|
||||||
except subprocess.TimeoutExpired:
|
except subprocess.TimeoutExpired:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Server process {self.server_process.pid} did not terminate within 5 seconds, force killing..."
|
f"Server process {self.server_process.pid} did not terminate gracefully, killing it."
|
||||||
)
|
)
|
||||||
try:
|
self.server_process.kill()
|
||||||
self.server_process.kill()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
self.server_process.wait(timeout=2)
|
|
||||||
logger.info(f"Server process {self.server_process.pid} killed successfully.")
|
|
||||||
except subprocess.TimeoutExpired:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to kill server process {self.server_process.pid} - it may be hung"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Clean up process resources with timeout to avoid CI hang
|
# Clean up process resources to prevent resource tracker warnings
|
||||||
try:
|
try:
|
||||||
# Use shorter timeout in CI environments
|
self.server_process.wait() # Ensure process is fully cleaned up
|
||||||
is_ci = os.environ.get("CI") == "true"
|
|
||||||
timeout = 3 if is_ci else 10
|
|
||||||
self.server_process.wait(timeout=timeout)
|
|
||||||
logger.info(f"Server process {self.server_process.pid} cleanup completed")
|
|
||||||
except subprocess.TimeoutExpired:
|
|
||||||
logger.warning(f"Process cleanup timeout after {timeout}s, proceeding anyway")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Error during process cleanup: {e}")
|
|
||||||
finally:
|
|
||||||
self.server_process = None
|
|
||||||
self.server_port = None
|
|
||||||
self._server_config = None
|
|
||||||
|
|
||||||
def _finalize_process(self) -> None:
|
|
||||||
"""Best-effort cleanup used by weakref.finalize/atexit."""
|
|
||||||
try:
|
|
||||||
self.stop_server()
|
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _adopt_existing_server(self, *args, **kwargs) -> None:
|
self.server_process = None
|
||||||
# Removed: cross-process adoption no longer supported
|
|
||||||
return
|
|
||||||
|
|
||||||
def _launch_server_process_colab(self, command: list, port: int) -> None:
|
|
||||||
"""Launch the server process with Colab-specific settings."""
|
|
||||||
logger.info(f"Colab Command: {' '.join(command)}")
|
|
||||||
|
|
||||||
# In Colab, we need to be more careful about process management
|
|
||||||
self.server_process = subprocess.Popen(
|
|
||||||
command,
|
|
||||||
stdout=subprocess.PIPE,
|
|
||||||
stderr=subprocess.PIPE,
|
|
||||||
text=True,
|
|
||||||
)
|
|
||||||
self.server_port = port
|
|
||||||
logger.info(f"Colab server process started with PID: {self.server_process.pid}")
|
|
||||||
|
|
||||||
# Register atexit callback (unified)
|
|
||||||
if not self._atexit_registered:
|
|
||||||
atexit.register(self._finalize_process)
|
|
||||||
self._atexit_registered = True
|
|
||||||
# Record config for in-process reuse is best-effort in Colab mode
|
|
||||||
self._server_config = {
|
|
||||||
"model_name": "",
|
|
||||||
"passages_file": "",
|
|
||||||
"embedding_mode": "sentence-transformers",
|
|
||||||
}
|
|
||||||
|
|
||||||
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
|
|
||||||
"""Wait for the server to be ready with Colab-specific timeout."""
|
|
||||||
max_wait, wait_interval = 30, 0.5 # Shorter timeout for Colab
|
|
||||||
|
|
||||||
for _ in range(int(max_wait / wait_interval)):
|
|
||||||
if _check_port(port):
|
|
||||||
logger.info("Colab embedding server is ready!")
|
|
||||||
return True, port
|
|
||||||
|
|
||||||
if self.server_process and self.server_process.poll() is not None:
|
|
||||||
# Check for error output
|
|
||||||
stdout, stderr = self.server_process.communicate()
|
|
||||||
logger.error("Colab server terminated during startup.")
|
|
||||||
logger.error(f"stdout: {stdout}")
|
|
||||||
logger.error(f"stderr: {stderr}")
|
|
||||||
return False, port
|
|
||||||
|
|
||||||
time.sleep(wait_interval)
|
|
||||||
|
|
||||||
logger.error(f"Colab server failed to start within {max_wait} seconds.")
|
|
||||||
self.stop_server()
|
|
||||||
return False, port
|
|
||||||
|
|||||||
@@ -1,14 +1,15 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Literal, Optional
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from typing import Dict, Any, List, Literal, Optional
|
||||||
|
|
||||||
|
|
||||||
class LeannBackendBuilderInterface(ABC):
|
class LeannBackendBuilderInterface(ABC):
|
||||||
"""Backend interface for building indexes"""
|
"""Backend interface for building indexes"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs) -> None:
|
def build(
|
||||||
|
self, data: np.ndarray, ids: List[str], index_path: str, **kwargs
|
||||||
|
) -> None:
|
||||||
"""Build index
|
"""Build index
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -52,7 +53,7 @@ class LeannBackendSearcherInterface(ABC):
|
|||||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
zmq_port: Optional[int] = None,
|
zmq_port: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Search for nearest neighbors
|
"""Search for nearest neighbors
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -1,175 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
|
|
||||||
import json
|
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
|
|
||||||
|
|
||||||
def handle_request(request):
|
|
||||||
if request.get("method") == "initialize":
|
|
||||||
return {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": request.get("id"),
|
|
||||||
"result": {
|
|
||||||
"capabilities": {"tools": {}},
|
|
||||||
"protocolVersion": "2024-11-05",
|
|
||||||
"serverInfo": {"name": "leann-mcp", "version": "1.0.0"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
elif request.get("method") == "tools/list":
|
|
||||||
return {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": request.get("id"),
|
|
||||||
"result": {
|
|
||||||
"tools": [
|
|
||||||
{
|
|
||||||
"name": "leann_search",
|
|
||||||
"description": """🔍 Search code using natural language - like having a coding assistant who knows your entire codebase!
|
|
||||||
|
|
||||||
🎯 **Perfect for**:
|
|
||||||
- "How does authentication work?" → finds auth-related code
|
|
||||||
- "Error handling patterns" → locates try-catch blocks and error logic
|
|
||||||
- "Database connection setup" → finds DB initialization code
|
|
||||||
- "API endpoint definitions" → locates route handlers
|
|
||||||
- "Configuration management" → finds config files and usage
|
|
||||||
|
|
||||||
💡 **Pro tip**: Use this before making any changes to understand existing patterns and conventions.""",
|
|
||||||
"inputSchema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"index_name": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Name of the LEANN index to search. Use 'leann_list' first to see available indexes.",
|
|
||||||
},
|
|
||||||
"query": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Search query - can be natural language (e.g., 'how to handle errors') or technical terms (e.g., 'async function definition')",
|
|
||||||
},
|
|
||||||
"top_k": {
|
|
||||||
"type": "integer",
|
|
||||||
"default": 5,
|
|
||||||
"minimum": 1,
|
|
||||||
"maximum": 20,
|
|
||||||
"description": "Number of search results to return. Use 5-10 for focused results, 15-20 for comprehensive exploration.",
|
|
||||||
},
|
|
||||||
"complexity": {
|
|
||||||
"type": "integer",
|
|
||||||
"default": 32,
|
|
||||||
"minimum": 16,
|
|
||||||
"maximum": 128,
|
|
||||||
"description": "Search complexity level. Use 16-32 for fast searches (recommended), 64+ for higher precision when needed.",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": ["index_name", "query"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "leann_status",
|
|
||||||
"description": "📊 Check the health and stats of your code indexes - like a medical checkup for your codebase knowledge!",
|
|
||||||
"inputSchema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"index_name": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Optional: Name of specific index to check. If not provided, shows status of all indexes.",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "leann_list",
|
|
||||||
"description": "📋 Show all your indexed codebases - your personal code library! Use this to see what's available for search.",
|
|
||||||
"inputSchema": {"type": "object", "properties": {}},
|
|
||||||
},
|
|
||||||
]
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
elif request.get("method") == "tools/call":
|
|
||||||
tool_name = request["params"]["name"]
|
|
||||||
args = request["params"].get("arguments", {})
|
|
||||||
|
|
||||||
try:
|
|
||||||
if tool_name == "leann_search":
|
|
||||||
# Validate required parameters
|
|
||||||
if not args.get("index_name") or not args.get("query"):
|
|
||||||
return {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": request.get("id"),
|
|
||||||
"result": {
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": "Error: Both index_name and query are required",
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
# Build simplified command
|
|
||||||
cmd = [
|
|
||||||
"leann",
|
|
||||||
"search",
|
|
||||||
args["index_name"],
|
|
||||||
args["query"],
|
|
||||||
f"--top-k={args.get('top_k', 5)}",
|
|
||||||
f"--complexity={args.get('complexity', 32)}",
|
|
||||||
]
|
|
||||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
|
||||||
|
|
||||||
elif tool_name == "leann_status":
|
|
||||||
if args.get("index_name"):
|
|
||||||
# Check specific index status - for now, we'll use leann list and filter
|
|
||||||
result = subprocess.run(["leann", "list"], capture_output=True, text=True)
|
|
||||||
# We could enhance this to show more detailed status per index
|
|
||||||
else:
|
|
||||||
# Show all indexes status
|
|
||||||
result = subprocess.run(["leann", "list"], capture_output=True, text=True)
|
|
||||||
|
|
||||||
elif tool_name == "leann_list":
|
|
||||||
result = subprocess.run(["leann", "list"], capture_output=True, text=True)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": request.get("id"),
|
|
||||||
"result": {
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": result.stdout
|
|
||||||
if result.returncode == 0
|
|
||||||
else f"Error: {result.stderr}",
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": request.get("id"),
|
|
||||||
"error": {"code": -1, "message": str(e)},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
for line in sys.stdin:
|
|
||||||
try:
|
|
||||||
request = json.loads(line.strip())
|
|
||||||
response = handle_request(request)
|
|
||||||
if response:
|
|
||||||
print(json.dumps(response))
|
|
||||||
sys.stdout.flush()
|
|
||||||
except Exception as e:
|
|
||||||
error_response = {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": None,
|
|
||||||
"error": {"code": -1, "message": str(e)},
|
|
||||||
}
|
|
||||||
print(json.dumps(error_response))
|
|
||||||
sys.stdout.flush()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,13 +1,13 @@
|
|||||||
# packages/leann-core/src/leann/registry.py
|
# packages/leann-core/src/leann/registry.py
|
||||||
|
|
||||||
|
from typing import Dict, TYPE_CHECKING
|
||||||
import importlib
|
import importlib
|
||||||
import importlib.metadata
|
import importlib.metadata
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from leann.interface import LeannBackendFactoryInterface
|
from leann.interface import LeannBackendFactoryInterface
|
||||||
|
|
||||||
BACKEND_REGISTRY: dict[str, "LeannBackendFactoryInterface"] = {}
|
BACKEND_REGISTRY: Dict[str, "LeannBackendFactoryInterface"] = {}
|
||||||
|
|
||||||
|
|
||||||
def register_backend(name: str):
|
def register_backend(name: str):
|
||||||
@@ -31,11 +31,13 @@ def autodiscover_backends():
|
|||||||
backend_module_name = dist_name.replace("-", "_")
|
backend_module_name = dist_name.replace("-", "_")
|
||||||
discovered_backends.append(backend_module_name)
|
discovered_backends.append(backend_module_name)
|
||||||
|
|
||||||
for backend_module_name in sorted(discovered_backends): # sort for deterministic loading
|
for backend_module_name in sorted(
|
||||||
|
discovered_backends
|
||||||
|
): # sort for deterministic loading
|
||||||
try:
|
try:
|
||||||
importlib.import_module(backend_module_name)
|
importlib.import_module(backend_module_name)
|
||||||
# Registration message is printed by the decorator
|
# Registration message is printed by the decorator
|
||||||
except ImportError:
|
except ImportError as e:
|
||||||
# print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
|
# print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
|
||||||
pass
|
pass
|
||||||
# print("INFO: Backend auto-discovery finished.")
|
# print("INFO: Backend auto-discovery finished.")
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Literal, Optional
|
from typing import Dict, Any, Literal, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -38,7 +38,9 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
|||||||
|
|
||||||
self.embedding_model = self.meta.get("embedding_model")
|
self.embedding_model = self.meta.get("embedding_model")
|
||||||
if not self.embedding_model:
|
if not self.embedding_model:
|
||||||
print("WARNING: embedding_model not found in meta.json. Recompute will fail.")
|
print(
|
||||||
|
"WARNING: embedding_model not found in meta.json. Recompute will fail."
|
||||||
|
)
|
||||||
|
|
||||||
self.embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
|
self.embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
|
||||||
|
|
||||||
@@ -46,40 +48,39 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
|||||||
backend_module_name=backend_module_name,
|
backend_module_name=backend_module_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _load_meta(self) -> dict[str, Any]:
|
def _load_meta(self) -> Dict[str, Any]:
|
||||||
"""Loads the metadata file associated with the index."""
|
"""Loads the metadata file associated with the index."""
|
||||||
# This is the corrected logic for finding the meta file.
|
# This is the corrected logic for finding the meta file.
|
||||||
meta_path = self.index_dir / f"{self.index_path.name}.meta.json"
|
meta_path = self.index_dir / f"{self.index_path.name}.meta.json"
|
||||||
if not meta_path.exists():
|
if not meta_path.exists():
|
||||||
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}")
|
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}")
|
||||||
with open(meta_path, encoding="utf-8") as f:
|
with open(meta_path, "r", encoding="utf-8") as f:
|
||||||
return json.load(f)
|
return json.load(f)
|
||||||
|
|
||||||
def _ensure_server_running(self, passages_source_file: str, port: int, **kwargs) -> int:
|
def _ensure_server_running(
|
||||||
|
self, passages_source_file: str, port: int, **kwargs
|
||||||
|
) -> int:
|
||||||
"""
|
"""
|
||||||
Ensures the embedding server is running if recompute is needed.
|
Ensures the embedding server is running if recompute is needed.
|
||||||
This is a helper for subclasses.
|
This is a helper for subclasses.
|
||||||
"""
|
"""
|
||||||
if not self.embedding_model:
|
if not self.embedding_model:
|
||||||
raise ValueError("Cannot use recompute mode without 'embedding_model' in meta.json.")
|
raise ValueError(
|
||||||
|
"Cannot use recompute mode without 'embedding_model' in meta.json."
|
||||||
# Get distance_metric from meta if not provided in kwargs
|
)
|
||||||
distance_metric = (
|
|
||||||
kwargs.get("distance_metric")
|
|
||||||
or self.meta.get("backend_kwargs", {}).get("distance_metric")
|
|
||||||
or "mips"
|
|
||||||
)
|
|
||||||
|
|
||||||
server_started, actual_port = self.embedding_server_manager.start_server(
|
server_started, actual_port = self.embedding_server_manager.start_server(
|
||||||
port=port,
|
port=port,
|
||||||
model_name=self.embedding_model,
|
model_name=self.embedding_model,
|
||||||
embedding_mode=self.embedding_mode,
|
embedding_mode=self.embedding_mode,
|
||||||
passages_file=passages_source_file,
|
passages_file=passages_source_file,
|
||||||
distance_metric=distance_metric,
|
distance_metric=kwargs.get("distance_metric"),
|
||||||
enable_warmup=kwargs.get("enable_warmup", False),
|
enable_warmup=kwargs.get("enable_warmup", False),
|
||||||
)
|
)
|
||||||
if not server_started:
|
if not server_started:
|
||||||
raise RuntimeError(f"Failed to start embedding server on port {actual_port}")
|
raise RuntimeError(
|
||||||
|
f"Failed to start embedding server on port {actual_port}"
|
||||||
|
)
|
||||||
|
|
||||||
return actual_port
|
return actual_port
|
||||||
|
|
||||||
@@ -108,10 +109,11 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
|||||||
# on that port?
|
# on that port?
|
||||||
|
|
||||||
# Ensure we have a server with passages_file for compatibility
|
# Ensure we have a server with passages_file for compatibility
|
||||||
passages_source_file = self.index_dir / f"{self.index_path.name}.meta.json"
|
passages_source_file = (
|
||||||
# Convert to absolute path to ensure server can find it
|
self.index_dir / f"{self.index_path.name}.meta.json"
|
||||||
|
)
|
||||||
zmq_port = self._ensure_server_running(
|
zmq_port = self._ensure_server_running(
|
||||||
str(passages_source_file.resolve()), zmq_port
|
str(passages_source_file), zmq_port
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._compute_embedding_via_server([query], zmq_port)[
|
return self._compute_embedding_via_server([query], zmq_port)[
|
||||||
@@ -129,8 +131,8 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
|||||||
|
|
||||||
def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray:
|
def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray:
|
||||||
"""Compute embeddings using the ZMQ embedding server."""
|
"""Compute embeddings using the ZMQ embedding server."""
|
||||||
import msgpack
|
|
||||||
import zmq
|
import zmq
|
||||||
|
import msgpack
|
||||||
|
|
||||||
try:
|
try:
|
||||||
context = zmq.Context()
|
context = zmq.Context()
|
||||||
@@ -171,7 +173,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
|||||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
zmq_port: Optional[int] = None,
|
zmq_port: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Search for the top_k nearest neighbors of the query vector.
|
Search for the top_k nearest neighbors of the query vector.
|
||||||
|
|
||||||
|
|||||||
@@ -1,119 +0,0 @@
|
|||||||
# 🔥 LEANN Claude Code Integration
|
|
||||||
|
|
||||||
Transform your development workflow with intelligent code assistance using LEANN's semantic search directly in Claude Code.
|
|
||||||
|
|
||||||
## Prerequisites
|
|
||||||
|
|
||||||
Install LEANN globally for MCP integration (with default backend):
|
|
||||||
|
|
||||||
```bash
|
|
||||||
uv tool install leann-core --with leann
|
|
||||||
```
|
|
||||||
This installs the `leann` CLI into an isolated tool environment and includes both backends so `leann build` works out-of-the-box.
|
|
||||||
|
|
||||||
## 🚀 Quick Setup
|
|
||||||
|
|
||||||
Add the LEANN MCP server to Claude Code:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
claude mcp add leann-server -- leann_mcp
|
|
||||||
```
|
|
||||||
|
|
||||||
## 🛠️ Available Tools
|
|
||||||
|
|
||||||
Once connected, you'll have access to these powerful semantic search tools in Claude Code:
|
|
||||||
|
|
||||||
- **`leann_list`** - List all available indexes across your projects
|
|
||||||
- **`leann_search`** - Perform semantic searches across code and documents
|
|
||||||
- **`leann_ask`** - Ask natural language questions and get AI-powered answers from your codebase
|
|
||||||
|
|
||||||
## 🎯 Quick Start Example
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Build an index for your project (change to your actual path)
|
|
||||||
leann build my-project --docs ./
|
|
||||||
|
|
||||||
# Start Claude Code
|
|
||||||
claude
|
|
||||||
```
|
|
||||||
|
|
||||||
## 🚀 Advanced Usage Examples
|
|
||||||
|
|
||||||
### Index Entire Git Repository
|
|
||||||
```bash
|
|
||||||
# Index all tracked files in your git repository, note right now we will skip submodules, but we can add it back easily if you want
|
|
||||||
leann build my-repo --docs $(git ls-files) --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
|
|
||||||
|
|
||||||
# Index only specific file types from git
|
|
||||||
leann build my-python-code --docs $(git ls-files "*.py") --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
|
|
||||||
```
|
|
||||||
|
|
||||||
### Multiple Directories and Files
|
|
||||||
```bash
|
|
||||||
# Index multiple directories
|
|
||||||
leann build my-codebase --docs ./src ./tests ./docs ./config --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
|
|
||||||
|
|
||||||
# Mix files and directories
|
|
||||||
leann build my-project --docs ./README.md ./src/ ./package.json ./docs/ --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
|
|
||||||
|
|
||||||
# Specific files only
|
|
||||||
leann build my-configs --docs ./tsconfig.json ./package.json ./webpack.config.js --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
|
|
||||||
```
|
|
||||||
|
|
||||||
### Advanced Git Integration
|
|
||||||
```bash
|
|
||||||
# Index recently modified files
|
|
||||||
leann build recent-changes --docs $(git diff --name-only HEAD~10..HEAD) --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
|
|
||||||
|
|
||||||
# Index files matching pattern
|
|
||||||
leann build frontend --docs $(git ls-files "*.tsx" "*.ts" "*.jsx" "*.js") --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
|
|
||||||
|
|
||||||
# Index documentation and config files
|
|
||||||
leann build docs-and-configs --docs $(git ls-files "*.md" "*.yml" "*.yaml" "*.json" "*.toml") --embedding-mode sentence-transformers --embedding-model all-MiniLM-L6-v2 --backend hnsw
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
**Try this in Claude Code:**
|
|
||||||
```
|
|
||||||
Help me understand this codebase. List available indexes and search for authentication patterns.
|
|
||||||
```
|
|
||||||
|
|
||||||
<p align="center">
|
|
||||||
<img src="../../assets/claude_code_leann.png" alt="LEANN in Claude Code" width="80%">
|
|
||||||
</p>
|
|
||||||
|
|
||||||
|
|
||||||
## 🧠 How It Works
|
|
||||||
|
|
||||||
The integration consists of three key components working seamlessly together:
|
|
||||||
|
|
||||||
- **`leann`** - Core CLI tool for indexing and searching (installed globally via `uv tool install`)
|
|
||||||
- **`leann_mcp`** - MCP server that wraps `leann` commands for Claude Code integration
|
|
||||||
- **Claude Code** - Calls `leann_mcp`, which executes `leann` commands and returns intelligent results
|
|
||||||
|
|
||||||
## 📁 File Support
|
|
||||||
|
|
||||||
LEANN understands **30+ file types** including:
|
|
||||||
- **Programming**: Python, JavaScript, TypeScript, Java, Go, Rust, C++, C#
|
|
||||||
- **Data**: SQL, YAML, JSON, CSV, XML
|
|
||||||
- **Documentation**: Markdown, TXT, PDF
|
|
||||||
- **And many more!**
|
|
||||||
|
|
||||||
## 💾 Storage & Organization
|
|
||||||
|
|
||||||
- **Project indexes**: Stored in `.leann/` directory (just like `.git`)
|
|
||||||
- **Global registry**: Project tracking at `~/.leann/projects.json`
|
|
||||||
- **Multi-project support**: Switch between different codebases seamlessly
|
|
||||||
- **Portable**: Transfer indexes between machines with minimal overhead
|
|
||||||
|
|
||||||
## 🗑️ Uninstalling
|
|
||||||
|
|
||||||
To remove the LEANN MCP server from Claude Code:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
claude mcp remove leann-server
|
|
||||||
```
|
|
||||||
To remove LEANN
|
|
||||||
```
|
|
||||||
uv pip uninstall leann leann-backend-hnsw leann-core
|
|
||||||
```
|
|
||||||
@@ -1,36 +0,0 @@
|
|||||||
# LEANN - The smallest vector index in the world
|
|
||||||
|
|
||||||
LEANN is a revolutionary vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.
|
|
||||||
|
|
||||||
## Installation
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Default installation (includes both HNSW and DiskANN backends)
|
|
||||||
uv pip install leann
|
|
||||||
```
|
|
||||||
|
|
||||||
## Quick Start
|
|
||||||
|
|
||||||
```python
|
|
||||||
from leann import LeannBuilder, LeannSearcher, LeannChat
|
|
||||||
from pathlib import Path
|
|
||||||
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
|
|
||||||
|
|
||||||
# Build an index (choose backend: "hnsw" or "diskann")
|
|
||||||
builder = LeannBuilder(backend_name="hnsw") # or "diskann" for large-scale deployments
|
|
||||||
builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
|
|
||||||
builder.add_text("Tung Tung Tung Sahur called—they need their banana‑crocodile hybrid back")
|
|
||||||
builder.build_index(INDEX_PATH)
|
|
||||||
|
|
||||||
# Search
|
|
||||||
searcher = LeannSearcher(INDEX_PATH)
|
|
||||||
results = searcher.search("fantastical AI-generated creatures", top_k=1)
|
|
||||||
|
|
||||||
# Chat with your data
|
|
||||||
chat = LeannChat(INDEX_PATH, llm_config={"type": "hf", "model": "Qwen/Qwen3-0.6B"})
|
|
||||||
response = chat.ask("How much storage does LEANN save?", top_k=1)
|
|
||||||
```
|
|
||||||
|
|
||||||
## License
|
|
||||||
|
|
||||||
MIT License
|
|
||||||
@@ -1,12 +0,0 @@
|
|||||||
"""
|
|
||||||
LEANN - Low-storage Embedding Approximation for Neural Networks
|
|
||||||
|
|
||||||
A revolutionary vector database that democratizes personal AI.
|
|
||||||
"""
|
|
||||||
|
|
||||||
__version__ = "0.1.0"
|
|
||||||
|
|
||||||
# Re-export main API from leann-core
|
|
||||||
from leann_core import LeannBuilder, LeannChat, LeannSearcher
|
|
||||||
|
|
||||||
__all__ = ["LeannBuilder", "LeannChat", "LeannSearcher"]
|
|
||||||
@@ -1,39 +0,0 @@
|
|||||||
[build-system]
|
|
||||||
requires = ["setuptools>=61.0"]
|
|
||||||
build-backend = "setuptools.build_meta"
|
|
||||||
|
|
||||||
[project]
|
|
||||||
name = "leann"
|
|
||||||
version = "0.2.9"
|
|
||||||
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
|
|
||||||
readme = "README.md"
|
|
||||||
requires-python = ">=3.9"
|
|
||||||
license = { text = "MIT" }
|
|
||||||
authors = [
|
|
||||||
{ name = "LEANN Team" }
|
|
||||||
]
|
|
||||||
keywords = ["vector-database", "rag", "embeddings", "search", "ai"]
|
|
||||||
classifiers = [
|
|
||||||
"Development Status :: 4 - Beta",
|
|
||||||
"Intended Audience :: Developers",
|
|
||||||
"License :: OSI Approved :: MIT License",
|
|
||||||
"Programming Language :: Python :: 3",
|
|
||||||
"Programming Language :: Python :: 3.9",
|
|
||||||
"Programming Language :: Python :: 3.10",
|
|
||||||
"Programming Language :: Python :: 3.11",
|
|
||||||
"Programming Language :: Python :: 3.12",
|
|
||||||
]
|
|
||||||
|
|
||||||
# Default installation: core + hnsw + diskann
|
|
||||||
dependencies = [
|
|
||||||
"leann-core>=0.1.0",
|
|
||||||
"leann-backend-hnsw>=0.1.0",
|
|
||||||
"leann-backend-diskann>=0.1.0",
|
|
||||||
]
|
|
||||||
|
|
||||||
[project.optional-dependencies]
|
|
||||||
# All backends now included by default
|
|
||||||
|
|
||||||
[project.urls]
|
|
||||||
Repository = "https://github.com/yichuan-w/LEANN"
|
|
||||||
Issues = "https://github.com/yichuan-w/LEANN/issues"
|
|
||||||
@@ -1,23 +1,22 @@
|
|||||||
import json
|
import json
|
||||||
import sqlite3
|
|
||||||
import xml.etree.ElementTree as ElementTree
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Annotated
|
|
||||||
|
|
||||||
import requests
|
|
||||||
import typer
|
import typer
|
||||||
|
from pathlib import Path
|
||||||
|
import requests
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
import xml.etree.ElementTree as ET
|
||||||
|
from typing_extensions import Annotated
|
||||||
|
import sqlite3
|
||||||
|
|
||||||
app = typer.Typer()
|
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
def get_safe_path(s: str) -> str:
|
def get_safe_path(s: str) -> str:
|
||||||
"""
|
"""
|
||||||
Remove invalid characters to sanitize a path.
|
Remove invalid characters to sanitize a path.
|
||||||
:param s: str to sanitize
|
:param s: str to sanitize
|
||||||
:returns: sanitized str
|
:returns: sanitized str
|
||||||
"""
|
"""
|
||||||
ban_chars = "\\ / : * ? \" ' < > | $ \r \n".replace(" ", "")
|
ban_chars = "\\ / : * ? \" ' < > | $ \r \n".replace(
|
||||||
|
' ', '')
|
||||||
for i in ban_chars:
|
for i in ban_chars:
|
||||||
s = s.replace(i, "")
|
s = s.replace(i, "")
|
||||||
return s
|
return s
|
||||||
@@ -26,40 +25,36 @@ def get_safe_path(s: str) -> str:
|
|||||||
def process_history(history: str):
|
def process_history(history: str):
|
||||||
if history.startswith("<?xml") or history.startswith("<msg>"):
|
if history.startswith("<?xml") or history.startswith("<msg>"):
|
||||||
try:
|
try:
|
||||||
root = ElementTree.fromstring(history)
|
root = ET.fromstring(history)
|
||||||
title = root.find(".//title").text if root.find(".//title") is not None else None
|
title = root.find('.//title').text if root.find('.//title') is not None else None
|
||||||
quoted = (
|
quoted = root.find('.//refermsg/content').text if root.find('.//refermsg/content') is not None else None
|
||||||
root.find(".//refermsg/content").text
|
|
||||||
if root.find(".//refermsg/content") is not None
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
if title and quoted:
|
if title and quoted:
|
||||||
return {"title": title, "quoted": process_history(quoted)}
|
return {
|
||||||
|
"title": title,
|
||||||
|
"quoted": process_history(quoted)
|
||||||
|
}
|
||||||
if title:
|
if title:
|
||||||
return title
|
return title
|
||||||
except Exception:
|
except Exception:
|
||||||
return history
|
return history
|
||||||
return history
|
return history
|
||||||
|
|
||||||
|
|
||||||
def get_message(history: dict | str):
|
def get_message(history: dict | str):
|
||||||
if isinstance(history, dict):
|
if isinstance(history, dict):
|
||||||
if "title" in history:
|
if 'title' in history:
|
||||||
return history["title"]
|
return history['title']
|
||||||
else:
|
else:
|
||||||
return history
|
return history
|
||||||
|
|
||||||
|
|
||||||
def export_chathistory(user_id: str):
|
def export_chathistory(user_id: str):
|
||||||
res = requests.get(
|
res = requests.get("http://localhost:48065/wechat/chatlog", params={
|
||||||
"http://localhost:48065/wechat/chatlog",
|
"userId": user_id,
|
||||||
params={"userId": user_id, "count": 100000},
|
"count": 100000
|
||||||
).json()
|
}).json()
|
||||||
for i in range(len(res["chatLogs"])):
|
for i in range(len(res['chatLogs'])):
|
||||||
res["chatLogs"][i]["content"] = process_history(res["chatLogs"][i]["content"])
|
res['chatLogs'][i]['content'] = process_history(res['chatLogs'][i]['content'])
|
||||||
res["chatLogs"][i]["message"] = get_message(res["chatLogs"][i]["content"])
|
res['chatLogs'][i]['message'] = get_message(res['chatLogs'][i]['content'])
|
||||||
return res["chatLogs"]
|
return res['chatLogs']
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def export_all(dest: Annotated[Path, typer.Argument(help="Destination path to export to.")]):
|
def export_all(dest: Annotated[Path, typer.Argument(help="Destination path to export to.")]):
|
||||||
@@ -69,7 +64,7 @@ def export_all(dest: Annotated[Path, typer.Argument(help="Destination path to ex
|
|||||||
if not dest.is_dir():
|
if not dest.is_dir():
|
||||||
if not dest.exists():
|
if not dest.exists():
|
||||||
inp = typer.prompt("Destination path does not exist, create it? (y/n)")
|
inp = typer.prompt("Destination path does not exist, create it? (y/n)")
|
||||||
if inp.lower() == "y":
|
if inp.lower() == 'y':
|
||||||
dest.mkdir(parents=True)
|
dest.mkdir(parents=True)
|
||||||
else:
|
else:
|
||||||
typer.echo("Aborted.", err=True)
|
typer.echo("Aborted.", err=True)
|
||||||
@@ -82,12 +77,12 @@ def export_all(dest: Annotated[Path, typer.Argument(help="Destination path to ex
|
|||||||
exported_count = 0
|
exported_count = 0
|
||||||
for user in tqdm(all_users):
|
for user in tqdm(all_users):
|
||||||
try:
|
try:
|
||||||
usr_chatlog = export_chathistory(user["arg"])
|
usr_chatlog = export_chathistory(user['arg'])
|
||||||
|
|
||||||
# Only write file if there are messages
|
# Only write file if there are messages
|
||||||
if len(usr_chatlog) > 0:
|
if len(usr_chatlog) > 0:
|
||||||
out_path = dest / get_safe_path((user["title"] or "") + "-" + user["arg"] + ".json")
|
out_path = dest/get_safe_path((user['title'] or "")+"-"+user['arg']+'.json')
|
||||||
with open(out_path, "w", encoding="utf-8") as f:
|
with open(out_path, 'w', encoding='utf-8') as f:
|
||||||
json.dump(usr_chatlog, f, ensure_ascii=False, indent=2)
|
json.dump(usr_chatlog, f, ensure_ascii=False, indent=2)
|
||||||
exported_count += 1
|
exported_count += 1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -96,43 +91,23 @@ def export_all(dest: Annotated[Path, typer.Argument(help="Destination path to ex
|
|||||||
|
|
||||||
print(f"Exported {exported_count} users' chat history to {dest} in json.")
|
print(f"Exported {exported_count} users' chat history to {dest} in json.")
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def export_sqlite(
|
def export_sqlite(dest: Annotated[Path, typer.Argument(help="Destination path to export to.")] = Path("chatlog.db")):
|
||||||
dest: Annotated[Path, typer.Argument(help="Destination path to export to.")] = Path(
|
|
||||||
"chatlog.db"
|
|
||||||
),
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Export all users' chat history to a sqlite database.
|
Export all users' chat history to a sqlite database.
|
||||||
"""
|
"""
|
||||||
connection = sqlite3.connect(dest)
|
connection = sqlite3.connect(dest)
|
||||||
cursor = connection.cursor()
|
cursor = connection.cursor()
|
||||||
cursor.execute(
|
cursor.execute("CREATE TABLE IF NOT EXISTS chatlog (id INTEGER PRIMARY KEY AUTOINCREMENT, with_id TEXT, from_user TEXT, to_user TEXT, message TEXT, timest DATETIME, auxiliary TEXT)")
|
||||||
"CREATE TABLE IF NOT EXISTS chatlog (id INTEGER PRIMARY KEY AUTOINCREMENT, with_id TEXT, from_user TEXT, to_user TEXT, message TEXT, timest DATETIME, auxiliary TEXT)"
|
|
||||||
)
|
|
||||||
cursor.execute("CREATE INDEX IF NOT EXISTS chatlog_with_id_index ON chatlog (with_id)")
|
cursor.execute("CREATE INDEX IF NOT EXISTS chatlog_with_id_index ON chatlog (with_id)")
|
||||||
cursor.execute("CREATE TABLE iF NOT EXISTS users (id TEXT PRIMARY KEY, name TEXT)")
|
cursor.execute("CREATE TABLE iF NOT EXISTS users (id TEXT PRIMARY KEY, name TEXT)")
|
||||||
|
|
||||||
all_users = requests.get("http://localhost:48065/wechat/allcontacts").json()
|
all_users = requests.get("http://localhost:48065/wechat/allcontacts").json()
|
||||||
for user in tqdm(all_users):
|
for user in tqdm(all_users):
|
||||||
cursor.execute(
|
cursor.execute("INSERT OR IGNORE INTO users (id, name) VALUES (?, ?)", (user['arg'], user['title']))
|
||||||
"INSERT OR IGNORE INTO users (id, name) VALUES (?, ?)",
|
usr_chatlog = export_chathistory(user['arg'])
|
||||||
(user["arg"], user["title"]),
|
|
||||||
)
|
|
||||||
usr_chatlog = export_chathistory(user["arg"])
|
|
||||||
for msg in usr_chatlog:
|
for msg in usr_chatlog:
|
||||||
cursor.execute(
|
cursor.execute("INSERT INTO chatlog (with_id, from_user, to_user, message, timest, auxiliary) VALUES (?, ?, ?, ?, ?, ?)", (user['arg'], msg['fromUser'], msg['toUser'], msg['message'], msg['createTime'], str(msg['content'])))
|
||||||
"INSERT INTO chatlog (with_id, from_user, to_user, message, timest, auxiliary) VALUES (?, ?, ?, ?, ?, ?)",
|
|
||||||
(
|
|
||||||
user["arg"],
|
|
||||||
msg["fromUser"],
|
|
||||||
msg["toUser"],
|
|
||||||
msg["message"],
|
|
||||||
msg["createTime"],
|
|
||||||
str(msg["content"]),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
connection.commit()
|
connection.commit()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
122
pyproject.toml
122
pyproject.toml
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
[project]
|
[project]
|
||||||
name = "leann-workspace"
|
name = "leann-workspace"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.10"
|
||||||
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"leann-core",
|
"leann-core",
|
||||||
@@ -25,63 +25,33 @@ dependencies = [
|
|||||||
"requests>=2.25.0",
|
"requests>=2.25.0",
|
||||||
"sentence-transformers>=2.2.0",
|
"sentence-transformers>=2.2.0",
|
||||||
"openai>=1.0.0",
|
"openai>=1.0.0",
|
||||||
# PDF parsing dependencies - essential for document processing
|
|
||||||
"PyPDF2>=3.0.0",
|
"PyPDF2>=3.0.0",
|
||||||
"pdfplumber>=0.11.0",
|
|
||||||
"pymupdf>=1.26.0",
|
|
||||||
"pypdfium2>=4.30.0",
|
|
||||||
# LlamaIndex core and readers - updated versions
|
|
||||||
"llama-index>=0.12.44",
|
"llama-index>=0.12.44",
|
||||||
"llama-index-readers-file>=0.4.0", # Essential for PDF parsing
|
"llama-index-readers-docling",
|
||||||
# "llama-index-readers-docling", # Requires Python >= 3.10
|
"llama-index-node-parser-docling",
|
||||||
# "llama-index-node-parser-docling", # Requires Python >= 3.10
|
|
||||||
"llama-index-vector-stores-faiss>=0.4.0",
|
|
||||||
"llama-index-embeddings-huggingface>=0.5.5",
|
|
||||||
# Other dependencies
|
|
||||||
"ipykernel==6.29.5",
|
"ipykernel==6.29.5",
|
||||||
"msgpack>=1.1.1",
|
"msgpack>=1.1.1",
|
||||||
"mlx>=0.26.3; sys_platform == 'darwin' and platform_machine == 'arm64'",
|
"llama-index-vector-stores-faiss>=0.4.0",
|
||||||
"mlx-lm>=0.26.0; sys_platform == 'darwin' and platform_machine == 'arm64'",
|
"llama-index-embeddings-huggingface>=0.5.5",
|
||||||
|
"mlx>=0.26.3",
|
||||||
|
"mlx-lm>=0.26.0",
|
||||||
"psutil>=5.8.0",
|
"psutil>=5.8.0",
|
||||||
"pybind11>=3.0.0",
|
|
||||||
"pathspec>=0.12.1",
|
|
||||||
"nbconvert>=7.16.6",
|
|
||||||
"gitignore-parser>=0.1.12",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
dev = [
|
dev = [
|
||||||
"pytest>=7.0",
|
"pytest>=7.0",
|
||||||
"pytest-cov>=4.0",
|
"pytest-cov>=4.0",
|
||||||
"pytest-xdist>=3.0", # For parallel test execution
|
|
||||||
"black>=23.0",
|
"black>=23.0",
|
||||||
"ruff==0.12.7", # Fixed version to ensure consistent formatting across all environments
|
"ruff>=0.1.0",
|
||||||
"matplotlib",
|
"matplotlib",
|
||||||
"huggingface-hub>=0.20.0",
|
"huggingface-hub>=0.20.0",
|
||||||
"pre-commit>=3.5.0",
|
|
||||||
]
|
|
||||||
|
|
||||||
test = [
|
|
||||||
"pytest>=7.0",
|
|
||||||
"pytest-timeout>=2.0",
|
|
||||||
"llama-index-core>=0.12.0",
|
|
||||||
"llama-index-readers-file>=0.4.0",
|
|
||||||
"python-dotenv>=1.0.0",
|
|
||||||
"sentence-transformers>=2.2.0",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
diskann = [
|
diskann = [
|
||||||
"leann-backend-diskann",
|
"leann-backend-diskann",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Add a new optional dependency group for document processing
|
|
||||||
documents = [
|
|
||||||
"beautifulsoup4>=4.13.0", # For HTML parsing
|
|
||||||
"python-docx>=0.8.11", # For Word documents
|
|
||||||
"openpyxl>=3.1.0", # For Excel files
|
|
||||||
"pandas>=2.2.0", # For data processing
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.setuptools]
|
[tool.setuptools]
|
||||||
py-modules = []
|
py-modules = []
|
||||||
|
|
||||||
@@ -90,79 +60,3 @@ py-modules = []
|
|||||||
leann-core = { path = "packages/leann-core", editable = true }
|
leann-core = { path = "packages/leann-core", editable = true }
|
||||||
leann-backend-diskann = { path = "packages/leann-backend-diskann", editable = true }
|
leann-backend-diskann = { path = "packages/leann-backend-diskann", editable = true }
|
||||||
leann-backend-hnsw = { path = "packages/leann-backend-hnsw", editable = true }
|
leann-backend-hnsw = { path = "packages/leann-backend-hnsw", editable = true }
|
||||||
|
|
||||||
[tool.ruff]
|
|
||||||
target-version = "py39"
|
|
||||||
line-length = 100
|
|
||||||
extend-exclude = [
|
|
||||||
"third_party",
|
|
||||||
"*.egg-info",
|
|
||||||
"__pycache__",
|
|
||||||
".git",
|
|
||||||
".venv",
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.ruff.lint]
|
|
||||||
select = [
|
|
||||||
"E", # pycodestyle errors
|
|
||||||
"W", # pycodestyle warnings
|
|
||||||
"F", # pyflakes
|
|
||||||
"I", # isort
|
|
||||||
"B", # flake8-bugbear
|
|
||||||
"C4", # flake8-comprehensions
|
|
||||||
"UP", # pyupgrade
|
|
||||||
"N", # pep8-naming
|
|
||||||
"RUF", # ruff-specific rules
|
|
||||||
]
|
|
||||||
ignore = [
|
|
||||||
"E501", # line too long (handled by formatter)
|
|
||||||
"B008", # do not perform function calls in argument defaults
|
|
||||||
"B904", # raise without from
|
|
||||||
"N812", # lowercase imported as non-lowercase
|
|
||||||
"N806", # variable in function should be lowercase
|
|
||||||
"RUF012", # mutable class attributes should be annotated with typing.ClassVar
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.ruff.lint.per-file-ignores]
|
|
||||||
"test/**/*.py" = ["E402"] # module level import not at top of file (common in tests)
|
|
||||||
"examples/**/*.py" = ["E402"] # module level import not at top of file (common in examples)
|
|
||||||
|
|
||||||
[tool.ruff.format]
|
|
||||||
quote-style = "double"
|
|
||||||
indent-style = "space"
|
|
||||||
skip-magic-trailing-comma = false
|
|
||||||
line-ending = "auto"
|
|
||||||
|
|
||||||
[dependency-groups]
|
|
||||||
dev = [
|
|
||||||
"ruff>=0.12.4",
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.lychee]
|
|
||||||
accept = ["200", "403", "429", "503"]
|
|
||||||
timeout = 20
|
|
||||||
max_retries = 2
|
|
||||||
exclude = ["localhost", "127.0.0.1", "example.com"]
|
|
||||||
exclude_path = [".git/", ".venv/", "__pycache__/", "third_party/"]
|
|
||||||
scheme = ["https", "http"]
|
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
|
||||||
testpaths = ["tests"]
|
|
||||||
python_files = ["test_*.py"]
|
|
||||||
python_classes = ["Test*"]
|
|
||||||
python_functions = ["test_*"]
|
|
||||||
markers = [
|
|
||||||
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
|
|
||||||
"openai: marks tests that require OpenAI API key",
|
|
||||||
]
|
|
||||||
timeout = 300 # Reduced from 600s (10min) to 300s (5min) for CI safety
|
|
||||||
addopts = [
|
|
||||||
"-v",
|
|
||||||
"--tb=short",
|
|
||||||
"--strict-markers",
|
|
||||||
"--disable-warnings",
|
|
||||||
]
|
|
||||||
env = [
|
|
||||||
"HF_HUB_DISABLE_SYMLINKS=1",
|
|
||||||
"TOKENIZERS_PARALLELISM=false",
|
|
||||||
]
|
|
||||||
|
|||||||
12
research/micro/analyze_HNSW.py
Normal file
12
research/micro/analyze_HNSW.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
import faiss
|
||||||
|
hnsw_index = faiss.read_index("/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/hnsw/hnsw_IP_M30_efC128.index", faiss.IO_FLAG_ONDISK_SAME_DIR)
|
||||||
|
|
||||||
|
# print total number of nodes
|
||||||
|
print(hnsw_index.ntotal)
|
||||||
|
|
||||||
|
# print stats of the graph
|
||||||
|
print(hnsw_index.hnsw.print_neighbor_stats(0))
|
||||||
|
|
||||||
|
|
||||||
|
# save_degree_distribution
|
||||||
|
hnsw_index.hnsw.save_degree_distribution(0, "degree_distribution_HNSW_M30.txt")
|
||||||
11
research/micro/analyze_NSG.py
Normal file
11
research/micro/analyze_NSG.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
import faiss
|
||||||
|
nsg_index = faiss.read_index("/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/nsg_R16.index", faiss.IO_FLAG_ONDISK_SAME_DIR)
|
||||||
|
|
||||||
|
# print total number of nodes
|
||||||
|
print(nsg_index.ntotal)
|
||||||
|
|
||||||
|
# print stats of the graph
|
||||||
|
print(nsg_index.nsg.print_neighbor_stats(0))
|
||||||
|
|
||||||
|
# save degree distribution
|
||||||
|
nsg_index.nsg.save_degree_distribution("degree_distribution_NSG_R60.txt")
|
||||||
63
research/micro/bnbtest.py
Normal file
63
research/micro/bnbtest.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import time
|
||||||
|
|
||||||
|
# import bitsandbytes as bnb
|
||||||
|
from bitsandbytes.nn import Linear8bitLt
|
||||||
|
|
||||||
|
# set default to half
|
||||||
|
import torch
|
||||||
|
torch.set_default_dtype(torch.float16)
|
||||||
|
|
||||||
|
M = 2048
|
||||||
|
N = 2048
|
||||||
|
|
||||||
|
bsz = 2048
|
||||||
|
import torch_int
|
||||||
|
from torch_int.nn.linear import W8A8BFP32OFP32Linear, W8A8B8O8Linear, W8A8B8O8LinearReLU
|
||||||
|
|
||||||
|
fp16_model = nn.Sequential(
|
||||||
|
nn.Linear(M, N),
|
||||||
|
# nn.Linear(2048, 2048)
|
||||||
|
)
|
||||||
|
|
||||||
|
int8_model = nn.Sequential(
|
||||||
|
Linear8bitLt(M, N, has_fp16_weights=False),
|
||||||
|
# Linear8bitLt(2048, 2048, has_fp16_weights=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
int8_model.load_state_dict(fp16_model.state_dict())
|
||||||
|
int8_model = int8_model.to(0) # Quantization happens here
|
||||||
|
fp16_model = fp16_model.to(0) # Move fp16 model to GPU as well
|
||||||
|
|
||||||
|
# Create random input tensor
|
||||||
|
input_tensor = torch.randn(bsz, M, device=0) # Batch of 1000 vectors
|
||||||
|
|
||||||
|
# Speed test function
|
||||||
|
def speed_test(model, input_tensor, name, num_iterations=100):
|
||||||
|
# Warmup
|
||||||
|
for _ in range(10):
|
||||||
|
_ = model(input_tensor)
|
||||||
|
|
||||||
|
# Actual timing
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
for _ in range(num_iterations):
|
||||||
|
_ = model(input_tensor)
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
end_time = time.time()
|
||||||
|
|
||||||
|
avg_time = (end_time - start_time) / num_iterations
|
||||||
|
print(f"{name} model: {avg_time:.6f} seconds per iteration")
|
||||||
|
return avg_time
|
||||||
|
|
||||||
|
# Run speed tests
|
||||||
|
with torch.no_grad(): # Disable gradient calculation for inference
|
||||||
|
fp16_time = speed_test(fp16_model, input_tensor, "FP16")
|
||||||
|
int8_time = speed_test(int8_model, input_tensor, "INT8")
|
||||||
|
|
||||||
|
# Calculate speedup
|
||||||
|
speedup = fp16_time / int8_time
|
||||||
|
print(f"INT8 is {speedup:.2f}x faster than FP16")
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user