Compare commits
155 Commits
apps
...
fix-macos-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dc4987591b | ||
|
|
d8b6ae8d1a | ||
|
|
f2ffcf5665 | ||
|
|
27d0d73f99 | ||
|
|
b124709bcd | ||
|
|
78251a6d4c | ||
|
|
16c833da86 | ||
|
|
c246cb4a01 | ||
|
|
0f34aee5db | ||
|
|
3e53d3d264 | ||
|
|
22c8f861bc | ||
|
|
a52e3c583a | ||
|
|
ab339886dd | ||
|
|
055c086398 | ||
|
|
d505dcc5e3 | ||
|
|
8c988cf98b | ||
|
|
ac5fd844a5 | ||
|
|
4b4b825fec | ||
|
|
34ef0db42f | ||
|
|
41812c7d22 | ||
|
|
2047a1a128 | ||
|
|
261006c36a | ||
|
|
b2eba23e21 | ||
|
|
e9ee687472 | ||
|
|
6f5d5e4a77 | ||
|
|
5c8921673a | ||
|
|
e9d2d420bd | ||
|
|
ebabfad066 | ||
|
|
e6f612b5e8 | ||
|
|
51c41acd82 | ||
|
|
402e8f97ad | ||
|
|
9a5c197acd | ||
|
|
455f93fb7c | ||
|
|
48207c3b69 | ||
|
|
4de1caa40f | ||
|
|
60eaa8165c | ||
|
|
c1a5d0c624 | ||
|
|
af1790395a | ||
|
|
383c6d8d7e | ||
|
|
bc0d839693 | ||
|
|
8596562de5 | ||
|
|
5d09586853 | ||
|
|
a7cba078dd | ||
|
|
b3e9ee96fa | ||
|
|
8537a6b17e | ||
|
|
7c8d7dc5c2 | ||
|
|
8e23d663e6 | ||
|
|
8a3994bf80 | ||
|
|
8375f601ba | ||
|
|
c87c0fe662 | ||
|
|
73927b68ef | ||
|
|
cc1a62e5aa | ||
|
|
802020cb41 | ||
|
|
cdb92f7cf4 | ||
|
|
dc69bdec00 | ||
|
|
98073e9868 | ||
|
|
cf2ef48967 | ||
|
|
0692bbf7a2 | ||
|
|
52584a171f | ||
|
|
efd6b5324b | ||
|
|
2baaa4549b | ||
|
|
35310ddd52 | ||
|
|
fc9c5cb39d | ||
|
|
8f2a1e87ea | ||
|
|
50caf65f28 | ||
|
|
1b48794ca8 | ||
|
|
4aef1d814e | ||
|
|
75ddcd6158 | ||
|
|
2a4df11f5c | ||
|
|
5eb893c62b | ||
|
|
d91ce2e94d | ||
|
|
5c2ff8a641 | ||
|
|
d4f474c9b7 | ||
|
|
170f7644e9 | ||
|
|
cd8b970eff | ||
|
|
52153bbb69 | ||
|
|
e1ae087207 | ||
|
|
48c5e12ac1 | ||
|
|
f8b5c97190 | ||
|
|
d038c81b8b | ||
|
|
29cbbbd0d6 | ||
|
|
179f30bc36 | ||
|
|
c4a0a68581 | ||
|
|
5c836ad08e | ||
|
|
673fd9b7cd | ||
|
|
84b24b233d | ||
|
|
499cdd7822 | ||
|
|
800d4cf111 | ||
|
|
b6d43f5fd9 | ||
|
|
3603cd5034 | ||
|
|
6df7893173 | ||
|
|
e64b599276 | ||
|
|
2dd59c4ba1 | ||
|
|
166986d5e6 | ||
|
|
a6aec68f32 | ||
|
|
ed27a127d5 | ||
|
|
d8b4ea7564 | ||
|
|
f0a2ef96b4 | ||
|
|
7d73c2c803 | ||
|
|
e8d2ecab03 | ||
|
|
32a374d094 | ||
|
|
d45c013806 | ||
|
|
9000a7083d | ||
|
|
8307555d54 | ||
|
|
20f2aece08 | ||
|
|
43eb4f9a1d | ||
|
|
5461b71d8c | ||
|
|
374db0ebb8 | ||
|
|
cea1f6f87c | ||
|
|
6c0e39372b | ||
|
|
2bec67d2b6 | ||
|
|
133e715832 | ||
|
|
95cf2f16e2 | ||
|
|
47a4c153eb | ||
|
|
faf5ae3533 | ||
|
|
a44dccecac | ||
|
|
9cf9358b9c | ||
|
|
de252fef31 | ||
|
|
9076bc27b8 | ||
|
|
50686c0819 | ||
|
|
1614203786 | ||
|
|
3d4c75a56c | ||
|
|
2684ee71dc | ||
|
|
1d321953ba | ||
|
|
b3cb251369 | ||
|
|
0a17d2c9d8 | ||
|
|
e3defbca84 | ||
|
|
e407f63977 | ||
|
|
7add391b2c | ||
|
|
efd6373b32 | ||
|
|
d502fa24b0 | ||
|
|
258a9a5c7f | ||
|
|
5d41ac6115 | ||
|
|
2a0fdb49b8 | ||
|
|
9d1b7231b6 | ||
|
|
ed3095b478 | ||
|
|
88eca75917 | ||
|
|
42de27e16a | ||
|
|
c083bda5b7 | ||
|
|
e86da38726 | ||
|
|
99076e38bc | ||
|
|
9698c1a02c | ||
|
|
851f0f04c3 | ||
|
|
ae16d9d888 | ||
|
|
6e1af2eb0c | ||
|
|
7695dd0d50 | ||
|
|
c2065473ad | ||
|
|
5f3870564d | ||
|
|
c214b2e33e | ||
|
|
2420c5fd35 | ||
|
|
f48f526f0a | ||
|
|
5dd74982ba | ||
|
|
e07aaf52a7 | ||
|
|
30e5f12616 | ||
|
|
594427bf87 |
11
.github/workflows/build-and-publish.yml
vendored
Normal file
11
.github/workflows/build-and-publish.yml
vendored
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
name: CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ main ]
|
||||||
|
pull_request:
|
||||||
|
branches: [ main ]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
uses: ./.github/workflows/build-reusable.yml
|
||||||
251
.github/workflows/build-reusable.yml
vendored
Normal file
251
.github/workflows/build-reusable.yml
vendored
Normal file
@@ -0,0 +1,251 @@
|
|||||||
|
name: Reusable Build
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_call:
|
||||||
|
inputs:
|
||||||
|
ref:
|
||||||
|
description: 'Git ref to build'
|
||||||
|
required: false
|
||||||
|
type: string
|
||||||
|
default: ''
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
lint:
|
||||||
|
name: Lint and Format Check
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: ${{ inputs.ref }}
|
||||||
|
|
||||||
|
- name: Setup Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: '3.11'
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v4
|
||||||
|
|
||||||
|
- name: Install ruff
|
||||||
|
run: |
|
||||||
|
uv tool install ruff
|
||||||
|
|
||||||
|
- name: Run ruff check
|
||||||
|
run: |
|
||||||
|
ruff check .
|
||||||
|
|
||||||
|
- name: Run ruff format check
|
||||||
|
run: |
|
||||||
|
ruff format --check .
|
||||||
|
|
||||||
|
build:
|
||||||
|
needs: lint
|
||||||
|
name: Build ${{ matrix.os }} Python ${{ matrix.python }}
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- os: ubuntu-22.04
|
||||||
|
python: '3.9'
|
||||||
|
- os: ubuntu-22.04
|
||||||
|
python: '3.10'
|
||||||
|
- os: ubuntu-22.04
|
||||||
|
python: '3.11'
|
||||||
|
- os: ubuntu-22.04
|
||||||
|
python: '3.12'
|
||||||
|
- os: ubuntu-22.04
|
||||||
|
python: '3.13'
|
||||||
|
- os: macos-latest
|
||||||
|
python: '3.9'
|
||||||
|
- os: macos-latest
|
||||||
|
python: '3.10'
|
||||||
|
- os: macos-latest
|
||||||
|
python: '3.11'
|
||||||
|
- os: macos-latest
|
||||||
|
python: '3.12'
|
||||||
|
- os: macos-latest
|
||||||
|
python: '3.13'
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: ${{ inputs.ref }}
|
||||||
|
submodules: recursive
|
||||||
|
|
||||||
|
- name: Setup Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python }}
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v4
|
||||||
|
|
||||||
|
- name: Install system dependencies (Ubuntu)
|
||||||
|
if: runner.os == 'Linux'
|
||||||
|
run: |
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install -y libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
|
||||||
|
pkg-config libopenblas-dev patchelf libabsl-dev libaio-dev libprotobuf-dev
|
||||||
|
|
||||||
|
# Install Intel MKL for DiskANN
|
||||||
|
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
|
||||||
|
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
|
||||||
|
source /opt/intel/oneapi/setvars.sh
|
||||||
|
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
|
||||||
|
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/mkl/latest/lib/intel64:$LD_LIBRARY_PATH" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Install system dependencies (macOS)
|
||||||
|
if: runner.os == 'macOS'
|
||||||
|
run: |
|
||||||
|
# Don't install LLVM, use system clang for better compatibility
|
||||||
|
brew install libomp boost protobuf zeromq
|
||||||
|
|
||||||
|
- name: Install build dependencies
|
||||||
|
run: |
|
||||||
|
uv pip install --system scikit-build-core numpy swig Cython pybind11
|
||||||
|
if [[ "$RUNNER_OS" == "Linux" ]]; then
|
||||||
|
uv pip install --system auditwheel
|
||||||
|
else
|
||||||
|
uv pip install --system delocate
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Build packages
|
||||||
|
run: |
|
||||||
|
# Build core (platform independent)
|
||||||
|
if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
|
||||||
|
cd packages/leann-core
|
||||||
|
uv build
|
||||||
|
cd ../..
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Build HNSW backend
|
||||||
|
cd packages/leann-backend-hnsw
|
||||||
|
if [ "${{ matrix.os }}" == "macos-latest" ]; then
|
||||||
|
# Use system clang instead of homebrew LLVM for better compatibility
|
||||||
|
export CC=clang
|
||||||
|
export CXX=clang++
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=11.0
|
||||||
|
uv build --wheel --python python
|
||||||
|
else
|
||||||
|
uv build --wheel --python python
|
||||||
|
fi
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
# Build DiskANN backend
|
||||||
|
cd packages/leann-backend-diskann
|
||||||
|
if [ "${{ matrix.os }}" == "macos-latest" ]; then
|
||||||
|
# Use system clang instead of homebrew LLVM for better compatibility
|
||||||
|
export CC=clang
|
||||||
|
export CXX=clang++
|
||||||
|
# DiskANN requires macOS 13.3+ for sgesdd_ LAPACK function
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=13.3
|
||||||
|
uv build --wheel --python python
|
||||||
|
else
|
||||||
|
uv build --wheel --python python
|
||||||
|
fi
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
# Build meta package (platform independent)
|
||||||
|
if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
|
||||||
|
cd packages/leann
|
||||||
|
uv build
|
||||||
|
cd ../..
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Repair wheels (Linux)
|
||||||
|
if: runner.os == 'Linux'
|
||||||
|
run: |
|
||||||
|
# Repair HNSW wheel
|
||||||
|
cd packages/leann-backend-hnsw
|
||||||
|
if [ -d dist ]; then
|
||||||
|
auditwheel repair dist/*.whl -w dist_repaired
|
||||||
|
rm -rf dist
|
||||||
|
mv dist_repaired dist
|
||||||
|
fi
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
# Repair DiskANN wheel
|
||||||
|
cd packages/leann-backend-diskann
|
||||||
|
if [ -d dist ]; then
|
||||||
|
auditwheel repair dist/*.whl -w dist_repaired
|
||||||
|
rm -rf dist
|
||||||
|
mv dist_repaired dist
|
||||||
|
fi
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
- name: Repair wheels (macOS)
|
||||||
|
if: runner.os == 'macOS'
|
||||||
|
run: |
|
||||||
|
# Repair HNSW wheel
|
||||||
|
cd packages/leann-backend-hnsw
|
||||||
|
if [ -d dist ]; then
|
||||||
|
delocate-wheel -w dist_repaired -v dist/*.whl
|
||||||
|
rm -rf dist
|
||||||
|
mv dist_repaired dist
|
||||||
|
fi
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
# Repair DiskANN wheel
|
||||||
|
cd packages/leann-backend-diskann
|
||||||
|
if [ -d dist ]; then
|
||||||
|
delocate-wheel -w dist_repaired -v dist/*.whl
|
||||||
|
rm -rf dist
|
||||||
|
mv dist_repaired dist
|
||||||
|
fi
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
- name: List built packages
|
||||||
|
run: |
|
||||||
|
echo "📦 Built packages:"
|
||||||
|
find packages/*/dist -name "*.whl" -o -name "*.tar.gz" | sort
|
||||||
|
|
||||||
|
- name: Install built packages for testing
|
||||||
|
run: |
|
||||||
|
# Create a virtual environment
|
||||||
|
uv venv
|
||||||
|
source .venv/bin/activate || source .venv/Scripts/activate
|
||||||
|
|
||||||
|
# Install the built wheels
|
||||||
|
# Use --find-links to let uv choose the correct wheel for the platform
|
||||||
|
if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
|
||||||
|
uv pip install leann-core --find-links packages/leann-core/dist
|
||||||
|
uv pip install leann --find-links packages/leann/dist
|
||||||
|
fi
|
||||||
|
uv pip install leann-backend-hnsw --find-links packages/leann-backend-hnsw/dist
|
||||||
|
uv pip install leann-backend-diskann --find-links packages/leann-backend-diskann/dist
|
||||||
|
|
||||||
|
# Install test dependencies using extras
|
||||||
|
uv pip install -e ".[test]"
|
||||||
|
|
||||||
|
- name: Run tests with pytest
|
||||||
|
env:
|
||||||
|
CI: true # Mark as CI environment to skip memory-intensive tests
|
||||||
|
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||||
|
HF_HUB_DISABLE_SYMLINKS: 1
|
||||||
|
TOKENIZERS_PARALLELISM: false
|
||||||
|
PYTORCH_ENABLE_MPS_FALLBACK: 0 # Disable MPS on macOS CI to avoid memory issues
|
||||||
|
OMP_NUM_THREADS: 1 # Disable OpenMP parallelism to avoid libomp crashes
|
||||||
|
MKL_NUM_THREADS: 1 # Single thread for MKL operations
|
||||||
|
run: |
|
||||||
|
# Activate virtual environment
|
||||||
|
source .venv/bin/activate || source .venv/Scripts/activate
|
||||||
|
|
||||||
|
# Run all tests
|
||||||
|
pytest tests/
|
||||||
|
|
||||||
|
- name: Run sanity checks (optional)
|
||||||
|
run: |
|
||||||
|
# Activate virtual environment
|
||||||
|
source .venv/bin/activate || source .venv/Scripts/activate
|
||||||
|
|
||||||
|
# Run distance function tests if available
|
||||||
|
if [ -f test/sanity_checks/test_distance_functions.py ]; then
|
||||||
|
echo "Running distance function sanity checks..."
|
||||||
|
python test/sanity_checks/test_distance_functions.py || echo "⚠️ Distance function test failed, continuing..."
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Upload artifacts
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: packages-${{ matrix.os }}-py${{ matrix.python }}
|
||||||
|
path: packages/*/dist/
|
||||||
129
.github/workflows/release-manual.yml
vendored
Normal file
129
.github/workflows/release-manual.yml
vendored
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
name: Release
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_dispatch:
|
||||||
|
inputs:
|
||||||
|
version:
|
||||||
|
description: 'Version to release (e.g., 0.1.2)'
|
||||||
|
required: true
|
||||||
|
type: string
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
update-version:
|
||||||
|
name: Update Version
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
contents: write
|
||||||
|
outputs:
|
||||||
|
commit-sha: ${{ steps.push.outputs.commit-sha }}
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Validate version
|
||||||
|
run: |
|
||||||
|
# Remove 'v' prefix if present for validation
|
||||||
|
VERSION_CLEAN="${{ inputs.version }}"
|
||||||
|
VERSION_CLEAN="${VERSION_CLEAN#v}"
|
||||||
|
if ! [[ "$VERSION_CLEAN" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
|
||||||
|
echo "❌ Invalid version format. Expected format: X.Y.Z or vX.Y.Z"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo "✅ Version format valid: ${{ inputs.version }}"
|
||||||
|
|
||||||
|
- name: Update versions and push
|
||||||
|
id: push
|
||||||
|
run: |
|
||||||
|
# Check current version
|
||||||
|
CURRENT_VERSION=$(grep "^version" packages/leann-core/pyproject.toml | cut -d'"' -f2)
|
||||||
|
echo "Current version: $CURRENT_VERSION"
|
||||||
|
echo "Target version: ${{ inputs.version }}"
|
||||||
|
|
||||||
|
if [ "$CURRENT_VERSION" = "${{ inputs.version }}" ]; then
|
||||||
|
echo "⚠️ Version is already ${{ inputs.version }}, skipping update"
|
||||||
|
COMMIT_SHA=$(git rev-parse HEAD)
|
||||||
|
else
|
||||||
|
./scripts/bump_version.sh ${{ inputs.version }}
|
||||||
|
git config user.name "GitHub Actions"
|
||||||
|
git config user.email "actions@github.com"
|
||||||
|
git add packages/*/pyproject.toml
|
||||||
|
git commit -m "chore: release v${{ inputs.version }}"
|
||||||
|
git push origin main
|
||||||
|
COMMIT_SHA=$(git rev-parse HEAD)
|
||||||
|
echo "✅ Pushed version update: $COMMIT_SHA"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "commit-sha=$COMMIT_SHA" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
|
build-packages:
|
||||||
|
name: Build packages
|
||||||
|
needs: update-version
|
||||||
|
uses: ./.github/workflows/build-reusable.yml
|
||||||
|
with:
|
||||||
|
ref: 'main'
|
||||||
|
|
||||||
|
publish:
|
||||||
|
name: Publish and Release
|
||||||
|
needs: [update-version, build-packages]
|
||||||
|
if: always() && needs.update-version.result == 'success' && needs.build-packages.result == 'success'
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
contents: write
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: 'main'
|
||||||
|
|
||||||
|
- name: Download all artifacts
|
||||||
|
uses: actions/download-artifact@v4
|
||||||
|
with:
|
||||||
|
path: dist-artifacts
|
||||||
|
|
||||||
|
- name: Collect packages
|
||||||
|
run: |
|
||||||
|
mkdir -p dist
|
||||||
|
find dist-artifacts -name "*.whl" -exec cp {} dist/ \;
|
||||||
|
find dist-artifacts -name "*.tar.gz" -exec cp {} dist/ \;
|
||||||
|
|
||||||
|
echo "📦 Packages to publish:"
|
||||||
|
ls -la dist/
|
||||||
|
|
||||||
|
- name: Publish to PyPI
|
||||||
|
env:
|
||||||
|
TWINE_USERNAME: __token__
|
||||||
|
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
|
||||||
|
run: |
|
||||||
|
if [ -z "$TWINE_PASSWORD" ]; then
|
||||||
|
echo "❌ PYPI_API_TOKEN not configured!"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
pip install twine
|
||||||
|
twine upload dist/* --skip-existing --verbose
|
||||||
|
|
||||||
|
echo "✅ Published to PyPI!"
|
||||||
|
|
||||||
|
- name: Create release
|
||||||
|
run: |
|
||||||
|
# Check if tag already exists
|
||||||
|
if git rev-parse "v${{ inputs.version }}" >/dev/null 2>&1; then
|
||||||
|
echo "⚠️ Tag v${{ inputs.version }} already exists, skipping tag creation"
|
||||||
|
else
|
||||||
|
git tag "v${{ inputs.version }}"
|
||||||
|
git push origin "v${{ inputs.version }}"
|
||||||
|
echo "✅ Created and pushed tag v${{ inputs.version }}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Check if release already exists
|
||||||
|
if gh release view "v${{ inputs.version }}" >/dev/null 2>&1; then
|
||||||
|
echo "⚠️ Release v${{ inputs.version }} already exists, skipping release creation"
|
||||||
|
else
|
||||||
|
gh release create "v${{ inputs.version }}" \
|
||||||
|
--title "Release v${{ inputs.version }}" \
|
||||||
|
--notes "🚀 Released to PyPI: https://pypi.org/project/leann/${{ inputs.version }}/" \
|
||||||
|
--latest
|
||||||
|
echo "✅ Created GitHub release v${{ inputs.version }}"
|
||||||
|
fi
|
||||||
|
env:
|
||||||
|
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
7
.gitignore
vendored
7
.gitignore
vendored
@@ -9,10 +9,9 @@ demo/indices/
|
|||||||
outputs/
|
outputs/
|
||||||
*.pkl
|
*.pkl
|
||||||
*.pdf
|
*.pdf
|
||||||
*.idx
|
*.idx
|
||||||
*.map
|
*.map
|
||||||
.history/
|
.history/
|
||||||
scripts/
|
|
||||||
lm_eval.egg-info/
|
lm_eval.egg-info/
|
||||||
demo/experiment_results/**/*.json
|
demo/experiment_results/**/*.json
|
||||||
*.jsonl
|
*.jsonl
|
||||||
@@ -86,4 +85,6 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/
|
|||||||
*.meta.json
|
*.meta.json
|
||||||
*.passages.json
|
*.passages.json
|
||||||
|
|
||||||
batchtest.py
|
batchtest.py
|
||||||
|
tests/__pytest_cache__/
|
||||||
|
tests/__pycache__/
|
||||||
|
|||||||
16
.pre-commit-config.yaml
Normal file
16
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
repos:
|
||||||
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
|
rev: v4.5.0
|
||||||
|
hooks:
|
||||||
|
- id: trailing-whitespace
|
||||||
|
- id: end-of-file-fixer
|
||||||
|
- id: check-yaml
|
||||||
|
- id: check-added-large-files
|
||||||
|
- id: check-merge-conflict
|
||||||
|
- id: debug-statements
|
||||||
|
|
||||||
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
|
rev: v0.2.1
|
||||||
|
hooks:
|
||||||
|
- id: ruff
|
||||||
|
- id: ruff-format
|
||||||
363
README.md
363
README.md
@@ -12,11 +12,11 @@
|
|||||||
The smallest vector index in the world. RAG Everything with LEANN!
|
The smallest vector index in the world. RAG Everything with LEANN!
|
||||||
</h2>
|
</h2>
|
||||||
|
|
||||||
LEANN is a revolutionary vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **[97% less storage]** than traditional solutions **without accuracy loss**.
|
LEANN is an innovative vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.
|
||||||
|
|
||||||
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
|
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig →](#️-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
|
||||||
|
|
||||||
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can search your **[file system](#process-any-documents-pdf-txt-md)**, **[emails](#search-your-entire-life)**, **[browser history](#time-machine-for-the-web)**, **[chat history](#wechat-detective)**, or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)**, or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -26,18 +26,52 @@ LEANN achieves this through *graph-based selective recomputation* with *high-deg
|
|||||||
<img src="assets/effects.png" alt="LEANN vs Traditional Vector DB Storage Comparison" width="70%">
|
<img src="assets/effects.png" alt="LEANN vs Traditional Vector DB Storage Comparison" width="70%">
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
> **The numbers speak for themselves:** Index 60 million Wikipedia chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#storage-usage-comparison)
|
> **The numbers speak for themselves:** Index 60 million Wikipedia chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#storage-comparison)
|
||||||
|
|
||||||
|
|
||||||
🔒 **Privacy:** Your data never leaves your laptop. No OpenAI, no cloud, no "terms of service".
|
🔒 **Privacy:** Your data never leaves your laptop. No OpenAI, no cloud, no "terms of service".
|
||||||
|
|
||||||
🪶 **Lightweight:** Graph-based recomputation eliminates heavy embedding storage, while smart graph pruning and CSR format minimize graph storage overhead. Always less storage, less memory usage!
|
🪶 **Lightweight:** Graph-based recomputation eliminates heavy embedding storage, while smart graph pruning and CSR format minimize graph storage overhead. Always less storage, less memory usage!
|
||||||
|
|
||||||
|
📦 **Portable:** Transfer your entire knowledge base between devices (even with others) with minimal cost - your personal AI memory travels with you.
|
||||||
|
|
||||||
📈 **Scalability:** Handle messy personal data that would crash traditional vector DBs, easily managing your growing personalized data and agent generated memory!
|
📈 **Scalability:** Handle messy personal data that would crash traditional vector DBs, easily managing your growing personalized data and agent generated memory!
|
||||||
|
|
||||||
✨ **No Accuracy Loss:** Maintain the same search quality as heavyweight solutions while using 97% less storage.
|
✨ **No Accuracy Loss:** Maintain the same search quality as heavyweight solutions while using 97% less storage.
|
||||||
|
|
||||||
## Quick Start in 1 minute
|
## Installation
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>📦 Prerequisites: Install uv (if you don't have it)</strong></summary>
|
||||||
|
|
||||||
|
Install uv first if you don't have it:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
|
```
|
||||||
|
|
||||||
|
📖 [Detailed uv installation methods →](https://docs.astral.sh/uv/getting-started/installation/#installation-methods)
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
LEANN provides two installation methods: **pip install** (quick and easy) and **build from source** (recommended for development).
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
### 🚀 Quick Install (Recommended for most users)
|
||||||
|
|
||||||
|
Clone the repository to access all examples and install LEANN from [PyPI](https://pypi.org/project/leann/) to run them immediately:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone git@github.com:yichuan-w/LEANN.git leann
|
||||||
|
cd leann
|
||||||
|
uv venv
|
||||||
|
source .venv/bin/activate
|
||||||
|
uv pip install leann
|
||||||
|
```
|
||||||
|
|
||||||
|
### 🔧 Build from Source (Recommended for development)
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git clone git@github.com:yichuan-w/LEANN.git leann
|
git clone git@github.com:yichuan-w/LEANN.git leann
|
||||||
@@ -47,36 +81,68 @@ git submodule update --init --recursive
|
|||||||
|
|
||||||
**macOS:**
|
**macOS:**
|
||||||
```bash
|
```bash
|
||||||
brew install llvm libomp boost protobuf zeromq
|
brew install llvm libomp boost protobuf zeromq pkgconf
|
||||||
export CC=$(brew --prefix llvm)/bin/clang
|
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync
|
||||||
export CXX=$(brew --prefix llvm)/bin/clang++
|
|
||||||
|
|
||||||
# Install with HNSW backend (default, recommended for most users)
|
|
||||||
uv sync
|
|
||||||
|
|
||||||
# Or add DiskANN backend if you want to test more options
|
|
||||||
uv sync --extra diskann
|
|
||||||
```
|
```
|
||||||
|
|
||||||
**Linux (Ubuntu/Debian):**
|
**Linux:**
|
||||||
```bash
|
```bash
|
||||||
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
|
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
|
||||||
|
|
||||||
# Install with HNSW backend (default, recommended for most users)
|
|
||||||
uv sync
|
uv sync
|
||||||
|
|
||||||
# Or add DiskANN backend if you want to test more options
|
|
||||||
uv sync --extra diskann
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
**Ollama Setup (Recommended for full privacy):**
|
|
||||||
|
|
||||||
> *You can skip this installation if you only want to use OpenAI API for generation.*
|
## Quick Start
|
||||||
|
|
||||||
|
Our declarative API makes RAG as easy as writing a config file.
|
||||||
|
|
||||||
|
[](https://colab.research.google.com/github/yichuan-w/LEANN/blob/main/demo.ipynb) [Try in this ipynb file →](demo.ipynb)
|
||||||
|
|
||||||
|
```python
|
||||||
|
from leann import LeannBuilder, LeannSearcher, LeannChat
|
||||||
|
from pathlib import Path
|
||||||
|
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
|
||||||
|
|
||||||
|
# Build an index
|
||||||
|
builder = LeannBuilder(backend_name="hnsw")
|
||||||
|
builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
|
||||||
|
builder.add_text("Tung Tung Tung Sahur called—they need their banana‑crocodile hybrid back")
|
||||||
|
builder.build_index(INDEX_PATH)
|
||||||
|
|
||||||
|
# Search
|
||||||
|
searcher = LeannSearcher(INDEX_PATH)
|
||||||
|
results = searcher.search("fantastical AI-generated creatures", top_k=1)
|
||||||
|
|
||||||
|
# Chat with your data
|
||||||
|
chat = LeannChat(INDEX_PATH, llm_config={"type": "hf", "model": "Qwen/Qwen3-0.6B"})
|
||||||
|
response = chat.ask("How much storage does LEANN save?", top_k=1)
|
||||||
|
```
|
||||||
|
|
||||||
|
## RAG on Everything!
|
||||||
|
|
||||||
|
LEANN supports RAG on various data sources including documents (.pdf, .txt, .md), Apple Mail, Google Search History, WeChat, and more.
|
||||||
|
|
||||||
|
|
||||||
*macOS:*
|
> **Generation Model Setup**
|
||||||
|
> LEANN supports multiple LLM providers for text generation (OpenAI API, HuggingFace, Ollama).
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>🔑 OpenAI API Setup (Default)</strong></summary>
|
||||||
|
|
||||||
|
Set your OpenAI API key as an environment variable:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export OPENAI_API_KEY="your-api-key-here"
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>🔧 Ollama Setup (Recommended for full privacy)</strong></summary>
|
||||||
|
|
||||||
|
**macOS:**
|
||||||
|
|
||||||
First, [download Ollama for macOS](https://ollama.com/download/mac).
|
First, [download Ollama for macOS](https://ollama.com/download/mac).
|
||||||
|
|
||||||
@@ -85,7 +151,8 @@ First, [download Ollama for macOS](https://ollama.com/download/mac).
|
|||||||
ollama pull llama3.2:1b
|
ollama pull llama3.2:1b
|
||||||
```
|
```
|
||||||
|
|
||||||
*Linux:*
|
**Linux:**
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Install Ollama
|
# Install Ollama
|
||||||
curl -fsSL https://ollama.ai/install.sh | sh
|
curl -fsSL https://ollama.ai/install.sh | sh
|
||||||
@@ -97,57 +164,19 @@ ollama serve &
|
|||||||
ollama pull llama3.2:1b
|
ollama pull llama3.2:1b
|
||||||
```
|
```
|
||||||
|
|
||||||
## Dead Simple API
|
</details>
|
||||||
|
|
||||||
Just 3 lines of code. Our declarative API makes RAG as easy as writing a config file:
|
### 📄 Personal Data Manager: Process Any Documents (.pdf, .txt, .md)!
|
||||||
|
|
||||||
```python
|
Ask questions directly about your personal PDFs, documents, and any directory containing your files!
|
||||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
|
||||||
|
|
||||||
# 1. Build the index (no embeddings stored!)
|
<p align="center">
|
||||||
builder = LeannBuilder(backend_name="hnsw")
|
<img src="videos/paper_clear.gif" alt="LEANN Document Search Demo" width="600">
|
||||||
builder.add_text("C# is a powerful programming language")
|
</p>
|
||||||
builder.add_text("Python is a powerful programming language and it is very popular")
|
|
||||||
builder.add_text("Machine learning transforms industries")
|
|
||||||
builder.add_text("Neural networks process complex data")
|
|
||||||
builder.add_text("Leann is a great storage saving engine for RAG on your MacBook")
|
|
||||||
builder.build_index("knowledge.leann")
|
|
||||||
|
|
||||||
# 2. Search with real-time embeddings
|
The example below asks a question about summarizing two papers (uses default data in `examples/data`):
|
||||||
searcher = LeannSearcher("knowledge.leann")
|
|
||||||
results = searcher.search("programming languages", top_k=2)
|
|
||||||
|
|
||||||
# 3. Chat with LEANN using retrieved results
|
|
||||||
llm_config = {
|
|
||||||
"type": "ollama",
|
|
||||||
"model": "llama3.2:1b"
|
|
||||||
}
|
|
||||||
|
|
||||||
chat = LeannChat(index_path="knowledge.leann", llm_config=llm_config)
|
|
||||||
response = chat.ask(
|
|
||||||
"Compare the two retrieved programming languages and say which one is more popular today.",
|
|
||||||
top_k=2,
|
|
||||||
)
|
|
||||||
```
|
```
|
||||||
|
|
||||||
**That's it.** No cloud setup, no API keys, no "fine-tuning". Just your data, your questions, your laptop.
|
|
||||||
|
|
||||||
[Try the interactive demo →](demo.ipynb)
|
|
||||||
|
|
||||||
## Wild Things You Can Do
|
|
||||||
|
|
||||||
LEANN supports RAGing a lot of data sources, like .pdf, .txt, .md, and also supports RAGing your WeChat, Google Search History, and more.
|
|
||||||
|
|
||||||
### Process Any Documents (.pdf, .txt, .md)
|
|
||||||
|
|
||||||
Above we showed the Python API, while this CLI script demonstrates the same concepts while directly processing PDFs and documents, and even any directory that stores your personal files!
|
|
||||||
|
|
||||||
The following scripts use Ollama `qwen3:8b` by default, so you need `ollama pull qwen3:8b` first. For other models: `--llm openai --model gpt-4o` (requires `OPENAI_API_KEY` environment variable) or `--llm hf --model Qwen/Qwen3-4B`.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Drop your PDFs, .txt, .md files into examples/data/
|
|
||||||
uv run ./examples/main_cli_example.py
|
|
||||||
|
|
||||||
# Or use python directly
|
# Or use python directly
|
||||||
source .venv/bin/activate
|
source .venv/bin/activate
|
||||||
python ./examples/main_cli_example.py
|
python ./examples/main_cli_example.py
|
||||||
@@ -155,14 +184,20 @@ python ./examples/main_cli_example.py
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
**Works with any text format** - research papers, personal notes, presentations. Built with LlamaIndex for document parsing.
|
### 📧 Your Personal Email Secretary: RAG on Apple Mail!
|
||||||
|
|
||||||
### Search Your Entire Life
|
> **Note:** The examples below currently support macOS only. Windows support coming soon.
|
||||||
|
|
||||||
|
|
||||||
|
<p align="center">
|
||||||
|
<img src="videos/mail_clear.gif" alt="LEANN Email Search Demo" width="600">
|
||||||
|
</p>
|
||||||
|
|
||||||
|
**Note:** You need to grant full disk access to your terminal/VS Code in System Preferences → Privacy & Security → Full Disk Access.
|
||||||
```bash
|
```bash
|
||||||
python examples/mail_reader_leann.py
|
python examples/mail_reader_leann.py --query "What's the food I ordered by doordash or Uber eat mostly?"
|
||||||
# "What's the number of class recommend to take per semester for incoming EECS students?"
|
|
||||||
```
|
```
|
||||||
**90K emails → 14MB.** Finally, search your email like you search Google.
|
**780K email chunks → 78MB storage** Finally, search your email like you search Google.
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
|
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
|
||||||
@@ -195,12 +230,16 @@ Once the index is built, you can ask questions like:
|
|||||||
- "Show me emails about travel expenses"
|
- "Show me emails about travel expenses"
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
### Time Machine for the Web
|
### 🔍 Time Machine for the Web: RAG Your Entire Chrome Browser History!
|
||||||
|
|
||||||
|
<p align="center">
|
||||||
|
<img src="videos/google_clear.gif" alt="LEANN Browser History Search Demo" width="600">
|
||||||
|
</p>
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python examples/google_history_reader_leann.py
|
python examples/google_history_reader_leann.py --query "Tell me my browser history about machine learning?"
|
||||||
# "Tell me my browser history about machine learning system stuff?"
|
|
||||||
```
|
```
|
||||||
**38K browser entries → 6MB.** Your browser history becomes your personal search engine.
|
**38K browser entries → 6MB storage.** Your browser history becomes your personal search engine.
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
|
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
|
||||||
@@ -249,13 +288,17 @@ Once the index is built, you can ask questions like:
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
### WeChat Detective
|
### 💬 WeChat Detective: Unlock Your Golden Memories!
|
||||||
|
|
||||||
|
<p align="center">
|
||||||
|
<img src="videos/wechat_clear.gif" alt="LEANN WeChat Search Demo" width="600">
|
||||||
|
</p>
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python examples/wechat_history_reader_leann.py
|
python examples/wechat_history_reader_leann.py --query "Show me all group chats about weekend plans"
|
||||||
# "Show me all group chats about weekend plans"
|
|
||||||
```
|
```
|
||||||
**400K messages → 64MB.** Search years of chat history in any language.
|
**400K messages → 64MB storage** Search years of chat history in any language.
|
||||||
|
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary><strong>🔧 Click to expand: Installation Requirements</strong></summary>
|
<summary><strong>🔧 Click to expand: Installation Requirements</strong></summary>
|
||||||
@@ -266,7 +309,13 @@ First, you need to install the WeChat exporter:
|
|||||||
sudo packages/wechat-exporter/wechattweak-cli install
|
sudo packages/wechat-exporter/wechattweak-cli install
|
||||||
```
|
```
|
||||||
|
|
||||||
**Troubleshooting**: If you encounter installation issues, check the [WeChatTweak-CLI issues page](https://github.com/sunnyyoung/WeChatTweak-CLI/issues/41).
|
**Troubleshooting:**
|
||||||
|
- **Installation issues**: Check the [WeChatTweak-CLI issues page](https://github.com/sunnyyoung/WeChatTweak-CLI/issues/41)
|
||||||
|
- **Export errors**: If you encounter the error below, try restarting WeChat
|
||||||
|
```
|
||||||
|
Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.
|
||||||
|
Failed to find or export WeChat data. Exiting.
|
||||||
|
```
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
@@ -310,7 +359,7 @@ LEANN includes a powerful CLI for document processing and search. Perfect for qu
|
|||||||
# Build an index from documents
|
# Build an index from documents
|
||||||
leann build my-docs --docs ./documents
|
leann build my-docs --docs ./documents
|
||||||
|
|
||||||
# Search your documents
|
# Search your documents
|
||||||
leann search my-docs "machine learning concepts"
|
leann search my-docs "machine learning concepts"
|
||||||
|
|
||||||
# Interactive chat with your documents
|
# Interactive chat with your documents
|
||||||
@@ -378,7 +427,7 @@ Options:
|
|||||||
|
|
||||||
**Core techniques:**
|
**Core techniques:**
|
||||||
- **Graph-based selective recomputation:** Only compute embeddings for nodes in the search path
|
- **Graph-based selective recomputation:** Only compute embeddings for nodes in the search path
|
||||||
- **High-degree preserving pruning:** Keep important "hub" nodes while removing redundant connections
|
- **High-degree preserving pruning:** Keep important "hub" nodes while removing redundant connections
|
||||||
- **Dynamic batching:** Efficiently batch embedding computations for GPU utilization
|
- **Dynamic batching:** Efficiently batch embedding computations for GPU utilization
|
||||||
- **Two-level search:** Smart graph traversal that prioritizes promising nodes
|
- **Two-level search:** Smart graph traversal that prioritizes promising nodes
|
||||||
|
|
||||||
@@ -386,46 +435,18 @@ Options:
|
|||||||
|
|
||||||
## Benchmarks
|
## Benchmarks
|
||||||
|
|
||||||
Run the comparison yourself:
|
|
||||||
```bash
|
|
||||||
python examples/compare_faiss_vs_leann.py
|
|
||||||
```
|
|
||||||
|
|
||||||
| System | Storage |
|
📊 **[Simple Example: Compare LEANN vs FAISS →](examples/compare_faiss_vs_leann.py)**
|
||||||
|--------|---------|
|
### Storage Comparison
|
||||||
| FAISS HNSW | 5.5 MB |
|
|
||||||
| LEANN | 0.5 MB |
|
|
||||||
| **Savings** | **91%** |
|
|
||||||
|
|
||||||
Same dataset, same hardware, same embedding model. LEANN just works better.
|
| System | DPR (2.1M) | Wiki (60M) | Chat (400K) | Email (780K) | Browser (38K) |
|
||||||
|
|--------|-------------|------------|-------------|--------------|---------------|
|
||||||
|
| Traditional vector database (e.g., FAISS) | 3.8 GB | 201 GB | 1.8 GB | 2.4 GB | 130 MB |
|
||||||
|
| LEANN | 324 MB | 6 GB | 64 MB | 79 MB | 6.4 MB |
|
||||||
|
| Savings| 91% | 97% | 97% | 97% | 95% |
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
### Storage Usage Comparison
|
|
||||||
|
|
||||||
| System | DPR (2.1M chunks) | RPJ-wiki (60M chunks) | Chat history (400K messages) | Apple emails (90K messages chunks) |Google Search History (38K entries)
|
|
||||||
|-----------------------|------------------|------------------------|-----------------------------|------------------------------|------------------------------|
|
|
||||||
| Traditional Vector DB(FAISS) | 3.8 GB | 201 GB | 1.8G | 305.8 MB |130.4 MB |
|
|
||||||
| **LEANN** | **324 MB** | **6 GB** | **64 MB** | **14.8 MB** |**6.4MB** |
|
|
||||||
| **Reduction** | **91% smaller** | **97% smaller** | **97% smaller** | **95% smaller** |**95% smaller** |
|
|
||||||
|
|
||||||
<!-- ### Memory Usage Comparison
|
|
||||||
|
|
||||||
| System j | DPR(2M docs) | RPJ-wiki(60M docs) | Chat history() |
|
|
||||||
| --------------------- | ---------------- | ---------------- | ---------------- |
|
|
||||||
| Traditional Vector DB(LLamaindex faiss) | x GB | x GB | x GB |
|
|
||||||
| **Leann** | **xx MB** | **x GB** | **x GB** |
|
|
||||||
| **Reduction** | **x%** | **x%** | **x%** |
|
|
||||||
|
|
||||||
### Query Performance of LEANN
|
|
||||||
|
|
||||||
| Backend | Index Size | Query Time | Recall@3 |
|
|
||||||
| ------------------- | ---------- | ---------- | --------- |
|
|
||||||
| DiskANN | 1M docs | xms | 0.95 |
|
|
||||||
| HNSW | 1M docs | xms | 0.95 | -->
|
|
||||||
|
|
||||||
*Benchmarks run on Apple M3 Pro 36 GB*
|
|
||||||
|
|
||||||
## Reproduce Our Results
|
## Reproduce Our Results
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -443,108 +464,25 @@ If you find Leann useful, please cite:
|
|||||||
|
|
||||||
```bibtex
|
```bibtex
|
||||||
@misc{wang2025leannlowstoragevectorindex,
|
@misc{wang2025leannlowstoragevectorindex,
|
||||||
title={LEANN: A Low-Storage Vector Index},
|
title={LEANN: A Low-Storage Vector Index},
|
||||||
author={Yichuan Wang and Shu Liu and Zhifei Li and Yongji Wu and Ziming Mao and Yilong Zhao and Xiao Yan and Zhiying Xu and Yang Zhou and Ion Stoica and Sewon Min and Matei Zaharia and Joseph E. Gonzalez},
|
author={Yichuan Wang and Shu Liu and Zhifei Li and Yongji Wu and Ziming Mao and Yilong Zhao and Xiao Yan and Zhiying Xu and Yang Zhou and Ion Stoica and Sewon Min and Matei Zaharia and Joseph E. Gonzalez},
|
||||||
year={2025},
|
year={2025},
|
||||||
eprint={2506.08276},
|
eprint={2506.08276},
|
||||||
archivePrefix={arXiv},
|
archivePrefix={arXiv},
|
||||||
primaryClass={cs.DB},
|
primaryClass={cs.DB},
|
||||||
url={https://arxiv.org/abs/2506.08276},
|
url={https://arxiv.org/abs/2506.08276},
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
## ✨ Features
|
## ✨ [Detailed Features →](docs/features.md)
|
||||||
|
|
||||||
### 🔥 Core Features
|
## 🤝 [CONTRIBUTING →](docs/CONTRIBUTING.md)
|
||||||
|
|
||||||
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage with dynamic computation using optimized ZMQ servers and highly optimized search paradigm (overlapping and batching) with highly optimized embedding engine
|
|
||||||
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
|
|
||||||
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
|
|
||||||
- **🏗️ Pluggable Backends** - DiskANN, HNSW/FAISS with unified API
|
|
||||||
|
|
||||||
### 🛠️ Technical Highlights
|
|
||||||
- **🔄 Recompute Mode** - Highest accuracy scenarios while eliminating vector storage overhead
|
|
||||||
- **⚡ Zero-copy Operations** - Minimize IPC overhead by transferring distances instead of embeddings
|
|
||||||
- **🚀 High-throughput Embedding Pipeline** - Optimized batched processing for maximum efficiency
|
|
||||||
- **🎯 Two-level Search** - Novel coarse-to-fine search overlap for accelerated query processing (optional)
|
|
||||||
- **💾 Memory-mapped Indices** - Fast startup with raw text mapping to reduce memory overhead
|
|
||||||
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](test/build_mlx_index.py))
|
|
||||||
|
|
||||||
### 🎨 Developer Experience
|
|
||||||
|
|
||||||
- **Simple Python API** - Get started in minutes
|
|
||||||
- **Extensible backend system** - Easy to add new algorithms
|
|
||||||
- **Comprehensive examples** - From basic usage to production deployment
|
|
||||||
|
|
||||||
## 🤝 Contributing
|
|
||||||
|
|
||||||
We welcome contributions! Leann is built by the community, for the community.
|
|
||||||
|
|
||||||
### Ways to Contribute
|
|
||||||
|
|
||||||
- 🐛 **Bug Reports**: Found an issue? Let us know!
|
|
||||||
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
|
|
||||||
- 🔧 **Code Contributions**: PRs welcome for all skill levels
|
|
||||||
- 📖 **Documentation**: Help make Leann more accessible
|
|
||||||
- 🧪 **Benchmarks**: Share your performance results
|
|
||||||
|
|
||||||
|
|
||||||
<!-- ## ❓ FAQ
|
## ❓ [FAQ →](docs/faq.md)
|
||||||
|
|
||||||
### Common Issues
|
|
||||||
|
|
||||||
#### NCCL Topology Error
|
|
||||||
|
|
||||||
**Problem**: You encounter `ncclTopoComputePaths` error during document processing:
|
|
||||||
|
|
||||||
```
|
|
||||||
ncclTopoComputePaths (system=<optimized out>, comm=comm@entry=0x5555a82fa3c0) at graph/paths.cc:688
|
|
||||||
```
|
|
||||||
|
|
||||||
**Solution**: Set these environment variables before running your script:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
export NCCL_TOPO_DUMP_FILE=/tmp/nccl_topo.xml
|
|
||||||
export NCCL_DEBUG=INFO
|
|
||||||
export NCCL_DEBUG_SUBSYS=INIT,GRAPH
|
|
||||||
export NCCL_IB_DISABLE=1
|
|
||||||
export NCCL_NET_PLUGIN=none
|
|
||||||
export NCCL_SOCKET_IFNAME=ens5
|
|
||||||
``` -->
|
|
||||||
## FAQ
|
|
||||||
|
|
||||||
### 1. My building time seems long
|
|
||||||
|
|
||||||
You can speed up the process by using a lightweight embedding model. Add this to your arguments:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
--embedding-model sentence-transformers/all-MiniLM-L6-v2
|
|
||||||
```
|
|
||||||
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)
|
|
||||||
|
|
||||||
|
|
||||||
## 📈 Roadmap
|
## 📈 [Roadmap →](docs/roadmap.md)
|
||||||
|
|
||||||
### 🎯 Q2 2025
|
|
||||||
|
|
||||||
- [X] DiskANN backend with MIPS/L2/Cosine support
|
|
||||||
- [X] HNSW backend integration
|
|
||||||
- [X] Real-time embedding pipeline
|
|
||||||
- [X] Memory-efficient graph pruning
|
|
||||||
|
|
||||||
### 🚀 Q3 2025
|
|
||||||
|
|
||||||
|
|
||||||
- [ ] Advanced caching strategies
|
|
||||||
- [ ] Add contextual-retrieval https://www.anthropic.com/news/contextual-retrieval
|
|
||||||
- [ ] Add sleep-time-compute and summarize agent! to summarilze the file on computer!
|
|
||||||
- [ ] Add OpenAI recompute API
|
|
||||||
|
|
||||||
### 🌟 Q4 2025
|
|
||||||
|
|
||||||
- [ ] Integration with LangChain/LlamaIndex
|
|
||||||
- [ ] Visual similarity search
|
|
||||||
- [ ] Query rewrtiting, rerank and expansion
|
|
||||||
|
|
||||||
## 📄 License
|
## 📄 License
|
||||||
|
|
||||||
@@ -552,11 +490,7 @@ MIT License - see [LICENSE](LICENSE) for details.
|
|||||||
|
|
||||||
## 🙏 Acknowledgments
|
## 🙏 Acknowledgments
|
||||||
|
|
||||||
- **Microsoft Research** for the DiskANN algorithm
|
This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/)
|
||||||
- **Meta AI** for FAISS and optimization insights
|
|
||||||
- **HuggingFace** for the transformer ecosystem
|
|
||||||
- **Our amazing contributors** who make this possible
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
@@ -566,4 +500,3 @@ MIT License - see [LICENSE](LICENSE) for details.
|
|||||||
<p align="center">
|
<p align="center">
|
||||||
Made with ❤️ by the Leann team
|
Made with ❤️ by the Leann team
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
|
|||||||
105
demo.ipynb
105
demo.ipynb
@@ -1,37 +1,116 @@
|
|||||||
{
|
{
|
||||||
"cells": [
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Quick Start \n",
|
||||||
|
"\n",
|
||||||
|
"**Home GitHub Repository:** [LEANN on GitHub](https://github.com/yichuan-w/LEANN)\n",
|
||||||
|
"\n",
|
||||||
|
"**Important for Colab users:** Set your runtime type to T4 GPU for optimal performance. Go to Runtime → Change runtime type → Hardware accelerator → T4 GPU."
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from leann.api import LeannBuilder, LeannSearcher, LeannChat\n",
|
"# install this if you are using colab\n",
|
||||||
|
"! uv pip install leann-core leann-backend-hnsw --no-deps\n",
|
||||||
|
"! uv pip install leann --no-deps\n",
|
||||||
|
"# For Colab environment, we need to set some environment variables\n",
|
||||||
|
"import os\n",
|
||||||
|
"\n",
|
||||||
|
"os.environ[\"LEANN_LOG_LEVEL\"] = \"INFO\" # Enable more detailed logging"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from pathlib import Path\n",
|
||||||
|
"\n",
|
||||||
|
"INDEX_DIR = Path(\"./\").resolve()\n",
|
||||||
|
"INDEX_PATH = str(INDEX_DIR / \"demo.leann\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Build the index"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from leann.api import LeannBuilder\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# 1. Build the index (no embeddings stored!)\n",
|
|
||||||
"builder = LeannBuilder(backend_name=\"hnsw\")\n",
|
"builder = LeannBuilder(backend_name=\"hnsw\")\n",
|
||||||
"builder.add_text(\"C# is a powerful programming language\")\n",
|
"builder.add_text(\"C# is a powerful programming language and it is good at game development\")\n",
|
||||||
"builder.add_text(\"Python is a powerful programming language and it is very popular\")\n",
|
"builder.add_text(\n",
|
||||||
|
" \"Python is a powerful programming language and it is good at machine learning tasks\"\n",
|
||||||
|
")\n",
|
||||||
"builder.add_text(\"Machine learning transforms industries\")\n",
|
"builder.add_text(\"Machine learning transforms industries\")\n",
|
||||||
"builder.add_text(\"Neural networks process complex data\")\n",
|
"builder.add_text(\"Neural networks process complex data\")\n",
|
||||||
"builder.add_text(\"Leann is a great storage saving engine for RAG on your MacBook\")\n",
|
"builder.add_text(\"Leann is a great storage saving engine for RAG on your MacBook\")\n",
|
||||||
"builder.build_index(\"knowledge.leann\")\n",
|
"builder.build_index(INDEX_PATH)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Search with real-time embeddings"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from leann.api import LeannSearcher\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# 2. Search with real-time embeddings\n",
|
"searcher = LeannSearcher(INDEX_PATH)\n",
|
||||||
"searcher = LeannSearcher(\"knowledge.leann\")\n",
|
|
||||||
"results = searcher.search(\"programming languages\", top_k=2)\n",
|
"results = searcher.search(\"programming languages\", top_k=2)\n",
|
||||||
|
"results"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Chat with LEANN using retrieved results"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from leann.api import LeannChat\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# 3. Chat with LEANN using retrieved results\n",
|
|
||||||
"llm_config = {\n",
|
"llm_config = {\n",
|
||||||
" \"type\": \"ollama\",\n",
|
" \"type\": \"hf\",\n",
|
||||||
" \"model\": \"llama3.2:1b\"\n",
|
" \"model\": \"Qwen/Qwen3-0.6B\",\n",
|
||||||
"}\n",
|
"}\n",
|
||||||
"\n",
|
"\n",
|
||||||
"chat = LeannChat(index_path=\"knowledge.leann\", llm_config=llm_config)\n",
|
"chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)\n",
|
||||||
"response = chat.ask(\n",
|
"response = chat.ask(\n",
|
||||||
" \"Compare the two retrieved programming languages and say which one is more popular today.\",\n",
|
" \"Compare the two retrieved programming languages and tell me their advantages.\",\n",
|
||||||
" top_k=2,\n",
|
" top_k=2,\n",
|
||||||
")"
|
" llm_kwargs={\"max_tokens\": 128},\n",
|
||||||
|
")\n",
|
||||||
|
"response"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
|||||||
220
docs/CONTRIBUTING.md
Normal file
220
docs/CONTRIBUTING.md
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
# 🤝 Contributing
|
||||||
|
|
||||||
|
We welcome contributions! Leann is built by the community, for the community.
|
||||||
|
|
||||||
|
## Ways to Contribute
|
||||||
|
|
||||||
|
- 🐛 **Bug Reports**: Found an issue? Let us know!
|
||||||
|
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
|
||||||
|
- 🔧 **Code Contributions**: PRs welcome for all skill levels
|
||||||
|
- 📖 **Documentation**: Help make Leann more accessible
|
||||||
|
- 🧪 **Benchmarks**: Share your performance results
|
||||||
|
|
||||||
|
## 🚀 Development Setup
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
|
||||||
|
1. **Install uv** (fast Python package installer):
|
||||||
|
```bash
|
||||||
|
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Clone the repository**:
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/LEANN-RAG/LEANN-RAG.git
|
||||||
|
cd LEANN-RAG
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Install system dependencies**:
|
||||||
|
|
||||||
|
**macOS:**
|
||||||
|
```bash
|
||||||
|
brew install llvm libomp boost protobuf zeromq pkgconf
|
||||||
|
```
|
||||||
|
|
||||||
|
**Ubuntu/Debian:**
|
||||||
|
```bash
|
||||||
|
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler \
|
||||||
|
libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
|
||||||
|
```
|
||||||
|
|
||||||
|
4. **Build from source**:
|
||||||
|
```bash
|
||||||
|
# macOS
|
||||||
|
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync
|
||||||
|
|
||||||
|
# Ubuntu/Debian
|
||||||
|
uv sync
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🔨 Pre-commit Hooks
|
||||||
|
|
||||||
|
We use pre-commit hooks to ensure code quality and consistency. This runs automatically before each commit.
|
||||||
|
|
||||||
|
### Setup Pre-commit
|
||||||
|
|
||||||
|
1. **Install pre-commit** (already included when you run `uv sync`):
|
||||||
|
```bash
|
||||||
|
uv pip install pre-commit
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Install the git hooks**:
|
||||||
|
```bash
|
||||||
|
pre-commit install
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Run pre-commit manually** (optional):
|
||||||
|
```bash
|
||||||
|
pre-commit run --all-files
|
||||||
|
```
|
||||||
|
|
||||||
|
### Pre-commit Checks
|
||||||
|
|
||||||
|
Our pre-commit configuration includes:
|
||||||
|
- **Trailing whitespace removal**
|
||||||
|
- **End-of-file fixing**
|
||||||
|
- **YAML validation**
|
||||||
|
- **Large file prevention**
|
||||||
|
- **Merge conflict detection**
|
||||||
|
- **Debug statement detection**
|
||||||
|
- **Code formatting with ruff**
|
||||||
|
- **Code linting with ruff**
|
||||||
|
|
||||||
|
## 🧪 Testing
|
||||||
|
|
||||||
|
### Running Tests
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run all tests
|
||||||
|
uv run pytest
|
||||||
|
|
||||||
|
# Run specific test file
|
||||||
|
uv run pytest test/test_filename.py
|
||||||
|
|
||||||
|
# Run with coverage
|
||||||
|
uv run pytest --cov=leann
|
||||||
|
```
|
||||||
|
|
||||||
|
### Writing Tests
|
||||||
|
|
||||||
|
- Place tests in the `test/` directory
|
||||||
|
- Follow the naming convention `test_*.py`
|
||||||
|
- Use descriptive test names that explain what's being tested
|
||||||
|
- Include both positive and negative test cases
|
||||||
|
|
||||||
|
## 📝 Code Style
|
||||||
|
|
||||||
|
We use `ruff` for both linting and formatting to ensure consistent code style.
|
||||||
|
|
||||||
|
### Format Your Code
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Format all files
|
||||||
|
ruff format
|
||||||
|
|
||||||
|
# Check formatting without changing files
|
||||||
|
ruff format --check
|
||||||
|
```
|
||||||
|
|
||||||
|
### Lint Your Code
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run linter with auto-fix
|
||||||
|
ruff check --fix
|
||||||
|
|
||||||
|
# Just check without fixing
|
||||||
|
ruff check
|
||||||
|
```
|
||||||
|
|
||||||
|
### Style Guidelines
|
||||||
|
|
||||||
|
- Follow PEP 8 conventions
|
||||||
|
- Use descriptive variable names
|
||||||
|
- Add type hints where appropriate
|
||||||
|
- Write docstrings for all public functions and classes
|
||||||
|
- Keep functions focused and single-purpose
|
||||||
|
|
||||||
|
## 🚦 CI/CD
|
||||||
|
|
||||||
|
Our CI pipeline runs automatically on all pull requests. It includes:
|
||||||
|
|
||||||
|
1. **Linting and Formatting**: Ensures code follows our style guidelines
|
||||||
|
2. **Multi-platform builds**: Tests on Ubuntu and macOS
|
||||||
|
3. **Python version matrix**: Tests on Python 3.9-3.13
|
||||||
|
4. **Wheel building**: Ensures packages can be built and distributed
|
||||||
|
|
||||||
|
### CI Commands
|
||||||
|
|
||||||
|
The CI uses the same commands as pre-commit to ensure consistency:
|
||||||
|
```bash
|
||||||
|
# Linting
|
||||||
|
ruff check .
|
||||||
|
|
||||||
|
# Format checking
|
||||||
|
ruff format --check .
|
||||||
|
```
|
||||||
|
|
||||||
|
Make sure your code passes these checks locally before pushing!
|
||||||
|
|
||||||
|
## 🔄 Pull Request Process
|
||||||
|
|
||||||
|
1. **Fork the repository** and create your branch from `main`:
|
||||||
|
```bash
|
||||||
|
git checkout -b feature/your-feature-name
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Make your changes**:
|
||||||
|
- Write clean, documented code
|
||||||
|
- Add tests for new functionality
|
||||||
|
- Update documentation as needed
|
||||||
|
|
||||||
|
3. **Run pre-commit checks**:
|
||||||
|
```bash
|
||||||
|
pre-commit run --all-files
|
||||||
|
```
|
||||||
|
|
||||||
|
4. **Test your changes**:
|
||||||
|
```bash
|
||||||
|
uv run pytest
|
||||||
|
```
|
||||||
|
|
||||||
|
5. **Commit with descriptive messages**:
|
||||||
|
```bash
|
||||||
|
git commit -m "feat: add new search algorithm"
|
||||||
|
```
|
||||||
|
|
||||||
|
Follow [Conventional Commits](https://www.conventionalcommits.org/):
|
||||||
|
- `feat:` for new features
|
||||||
|
- `fix:` for bug fixes
|
||||||
|
- `docs:` for documentation changes
|
||||||
|
- `test:` for test additions/changes
|
||||||
|
- `refactor:` for code refactoring
|
||||||
|
- `perf:` for performance improvements
|
||||||
|
|
||||||
|
6. **Push and create a pull request**:
|
||||||
|
- Provide a clear description of your changes
|
||||||
|
- Reference any related issues
|
||||||
|
- Include examples or screenshots if applicable
|
||||||
|
|
||||||
|
## 📚 Documentation
|
||||||
|
|
||||||
|
When adding new features or making significant changes:
|
||||||
|
|
||||||
|
1. Update relevant documentation in `/docs`
|
||||||
|
2. Add docstrings to new functions/classes
|
||||||
|
3. Update README.md if needed
|
||||||
|
4. Include usage examples
|
||||||
|
|
||||||
|
## 🤔 Getting Help
|
||||||
|
|
||||||
|
- **Discord**: Join our community for discussions
|
||||||
|
- **Issues**: Check existing issues or create a new one
|
||||||
|
- **Discussions**: For general questions and ideas
|
||||||
|
|
||||||
|
## 📄 License
|
||||||
|
|
||||||
|
By contributing, you agree that your contributions will be licensed under the same license as the project (MIT).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
Thank you for contributing to LEANN! Every contribution, no matter how small, helps make the project better for everyone. 🌟
|
||||||
22
docs/RELEASE.md
Normal file
22
docs/RELEASE.md
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
# Release Guide
|
||||||
|
|
||||||
|
## Setup (One-time)
|
||||||
|
|
||||||
|
Add `PYPI_API_TOKEN` to GitHub Secrets:
|
||||||
|
1. Get token: https://pypi.org/manage/account/token/
|
||||||
|
2. Add to secrets: Settings → Secrets → Actions → `PYPI_API_TOKEN`
|
||||||
|
|
||||||
|
## Release (One-click)
|
||||||
|
|
||||||
|
1. Go to: https://github.com/yichuan-w/LEANN/actions/workflows/release-manual.yml
|
||||||
|
2. Click "Run workflow"
|
||||||
|
3. Enter version: `0.1.2`
|
||||||
|
4. Click green "Run workflow" button
|
||||||
|
|
||||||
|
That's it! The workflow will automatically:
|
||||||
|
- ✅ Update version in all packages
|
||||||
|
- ✅ Build all packages
|
||||||
|
- ✅ Publish to PyPI
|
||||||
|
- ✅ Create GitHub tag and release
|
||||||
|
|
||||||
|
Check progress: https://github.com/yichuan-w/LEANN/actions
|
||||||
98
docs/code/embedding_model_compare.py
Normal file
98
docs/code/embedding_model_compare.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
"""
|
||||||
|
Comparison between Sentence Transformers and OpenAI embeddings
|
||||||
|
|
||||||
|
This example shows how different embedding models handle complex queries
|
||||||
|
and demonstrates the differences between local and API-based embeddings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from leann.embedding_compute import compute_embeddings
|
||||||
|
|
||||||
|
# OpenAI API key should be set as environment variable
|
||||||
|
# export OPENAI_API_KEY="your-api-key-here"
|
||||||
|
|
||||||
|
# Test data
|
||||||
|
conference_text = "[Title]: COLING 2025 Conference\n[URL]: https://coling2025.org/"
|
||||||
|
browser_text = "[Title]: Browser Use Tool\n[URL]: https://github.com/browser-use"
|
||||||
|
|
||||||
|
# Two queries with same intent but different wording
|
||||||
|
query1 = "Tell me my browser history about some conference i often visit"
|
||||||
|
query2 = "browser history about conference I often visit"
|
||||||
|
|
||||||
|
texts = [query1, query2, conference_text, browser_text]
|
||||||
|
|
||||||
|
|
||||||
|
def cosine_similarity(a, b):
|
||||||
|
return np.dot(a, b) # Already normalized
|
||||||
|
|
||||||
|
|
||||||
|
def analyze_embeddings(embeddings, model_name):
|
||||||
|
print(f"\n=== {model_name} Results ===")
|
||||||
|
|
||||||
|
# Results for Query 1
|
||||||
|
sim1_conf = cosine_similarity(embeddings[0], embeddings[2])
|
||||||
|
sim1_browser = cosine_similarity(embeddings[0], embeddings[3])
|
||||||
|
|
||||||
|
print(f"Query 1: '{query1}'")
|
||||||
|
print(f" → Conference similarity: {sim1_conf:.4f} {'✓' if sim1_conf > sim1_browser else ''}")
|
||||||
|
print(
|
||||||
|
f" → Browser similarity: {sim1_browser:.4f} {'✓' if sim1_browser > sim1_conf else ''}"
|
||||||
|
)
|
||||||
|
print(f" Winner: {'Conference' if sim1_conf > sim1_browser else 'Browser'}")
|
||||||
|
|
||||||
|
# Results for Query 2
|
||||||
|
sim2_conf = cosine_similarity(embeddings[1], embeddings[2])
|
||||||
|
sim2_browser = cosine_similarity(embeddings[1], embeddings[3])
|
||||||
|
|
||||||
|
print(f"\nQuery 2: '{query2}'")
|
||||||
|
print(f" → Conference similarity: {sim2_conf:.4f} {'✓' if sim2_conf > sim2_browser else ''}")
|
||||||
|
print(
|
||||||
|
f" → Browser similarity: {sim2_browser:.4f} {'✓' if sim2_browser > sim2_conf else ''}"
|
||||||
|
)
|
||||||
|
print(f" Winner: {'Conference' if sim2_conf > sim2_browser else 'Browser'}")
|
||||||
|
|
||||||
|
# Show the impact
|
||||||
|
print("\n=== Impact Analysis ===")
|
||||||
|
print(f"Conference similarity change: {sim2_conf - sim1_conf:+.4f}")
|
||||||
|
print(f"Browser similarity change: {sim2_browser - sim1_browser:+.4f}")
|
||||||
|
|
||||||
|
if sim1_conf > sim1_browser and sim2_browser > sim2_conf:
|
||||||
|
print("❌ FLIP: Adding 'browser history' flips winner from Conference to Browser!")
|
||||||
|
elif sim1_conf > sim1_browser and sim2_conf > sim2_browser:
|
||||||
|
print("✅ STABLE: Conference remains winner in both queries")
|
||||||
|
elif sim1_browser > sim1_conf and sim2_browser > sim2_conf:
|
||||||
|
print("✅ STABLE: Browser remains winner in both queries")
|
||||||
|
else:
|
||||||
|
print("🔄 MIXED: Results vary between queries")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"query1_conf": sim1_conf,
|
||||||
|
"query1_browser": sim1_browser,
|
||||||
|
"query2_conf": sim2_conf,
|
||||||
|
"query2_browser": sim2_browser,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Test Sentence Transformers
|
||||||
|
print("Testing Sentence Transformers (facebook/contriever)...")
|
||||||
|
try:
|
||||||
|
st_embeddings = compute_embeddings(texts, "facebook/contriever", mode="sentence-transformers")
|
||||||
|
st_results = analyze_embeddings(st_embeddings, "Sentence Transformers (facebook/contriever)")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Sentence Transformers failed: {e}")
|
||||||
|
st_results = None
|
||||||
|
|
||||||
|
# Test OpenAI
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Testing OpenAI (text-embedding-3-small)...")
|
||||||
|
try:
|
||||||
|
openai_embeddings = compute_embeddings(texts, "text-embedding-3-small", mode="openai")
|
||||||
|
openai_results = analyze_embeddings(openai_embeddings, "OpenAI (text-embedding-3-small)")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ OpenAI failed: {e}")
|
||||||
|
openai_results = None
|
||||||
|
|
||||||
|
# Compare results
|
||||||
|
if st_results and openai_results:
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("=== COMPARISON SUMMARY ===")
|
||||||
10
docs/faq.md
Normal file
10
docs/faq.md
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
# FAQ
|
||||||
|
|
||||||
|
## 1. My building time seems long
|
||||||
|
|
||||||
|
You can speed up the process by using a lightweight embedding model. Add this to your arguments:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
--embedding-model sentence-transformers/all-MiniLM-L6-v2
|
||||||
|
```
|
||||||
|
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)
|
||||||
22
docs/features.md
Normal file
22
docs/features.md
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
# ✨ Detailed Features
|
||||||
|
|
||||||
|
## 🔥 Core Features
|
||||||
|
|
||||||
|
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage with dynamic computation using optimized ZMQ servers and highly optimized search paradigm (overlapping and batching) with highly optimized embedding engine
|
||||||
|
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
|
||||||
|
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
|
||||||
|
- **🏗️ Pluggable Backends** - DiskANN, HNSW/FAISS with unified API
|
||||||
|
|
||||||
|
## 🛠️ Technical Highlights
|
||||||
|
- **🔄 Recompute Mode** - Highest accuracy scenarios while eliminating vector storage overhead
|
||||||
|
- **⚡ Zero-copy Operations** - Minimize IPC overhead by transferring distances instead of embeddings
|
||||||
|
- **🚀 High-throughput Embedding Pipeline** - Optimized batched processing for maximum efficiency
|
||||||
|
- **🎯 Two-level Search** - Novel coarse-to-fine search overlap for accelerated query processing (optional)
|
||||||
|
- **💾 Memory-mapped Indices** - Fast startup with raw text mapping to reduce memory overhead
|
||||||
|
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](test/build_mlx_index.py))
|
||||||
|
|
||||||
|
## 🎨 Developer Experience
|
||||||
|
|
||||||
|
- **Simple Python API** - Get started in minutes
|
||||||
|
- **Extensible backend system** - Easy to add new algorithms
|
||||||
|
- **Comprehensive examples** - From basic usage to production deployment
|
||||||
75
docs/normalized_embeddings.md
Normal file
75
docs/normalized_embeddings.md
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
# Normalized Embeddings Support in LEANN
|
||||||
|
|
||||||
|
LEANN now automatically detects normalized embedding models and sets the appropriate distance metric for optimal performance.
|
||||||
|
|
||||||
|
## What are Normalized Embeddings?
|
||||||
|
|
||||||
|
Normalized embeddings are vectors with L2 norm = 1 (unit vectors). These embeddings are optimized for cosine similarity rather than Maximum Inner Product Search (MIPS).
|
||||||
|
|
||||||
|
## Automatic Detection
|
||||||
|
|
||||||
|
When you create a `LeannBuilder` instance with a normalized embedding model, LEANN will:
|
||||||
|
|
||||||
|
1. **Automatically set `distance_metric="cosine"`** if not specified
|
||||||
|
2. **Show a warning** if you manually specify a different distance metric
|
||||||
|
3. **Provide optimal search performance** with the correct metric
|
||||||
|
|
||||||
|
## Supported Normalized Embedding Models
|
||||||
|
|
||||||
|
### OpenAI
|
||||||
|
All OpenAI text embedding models are normalized:
|
||||||
|
- `text-embedding-ada-002`
|
||||||
|
- `text-embedding-3-small`
|
||||||
|
- `text-embedding-3-large`
|
||||||
|
|
||||||
|
### Voyage AI
|
||||||
|
All Voyage AI embedding models are normalized:
|
||||||
|
- `voyage-2`
|
||||||
|
- `voyage-3`
|
||||||
|
- `voyage-large-2`
|
||||||
|
- `voyage-multilingual-2`
|
||||||
|
- `voyage-code-2`
|
||||||
|
|
||||||
|
### Cohere
|
||||||
|
All Cohere embedding models are normalized:
|
||||||
|
- `embed-english-v3.0`
|
||||||
|
- `embed-multilingual-v3.0`
|
||||||
|
- `embed-english-light-v3.0`
|
||||||
|
- `embed-multilingual-light-v3.0`
|
||||||
|
|
||||||
|
## Example Usage
|
||||||
|
|
||||||
|
```python
|
||||||
|
from leann.api import LeannBuilder
|
||||||
|
|
||||||
|
# Automatic detection - will use cosine distance
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="text-embedding-3-small",
|
||||||
|
embedding_mode="openai"
|
||||||
|
)
|
||||||
|
# Warning: Detected normalized embeddings model 'text-embedding-3-small'...
|
||||||
|
# Automatically setting distance_metric='cosine'
|
||||||
|
|
||||||
|
# Manual override (not recommended)
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="text-embedding-3-small",
|
||||||
|
embedding_mode="openai",
|
||||||
|
distance_metric="mips" # Will show warning
|
||||||
|
)
|
||||||
|
# Warning: Using 'mips' distance metric with normalized embeddings...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Non-Normalized Embeddings
|
||||||
|
|
||||||
|
Models like `facebook/contriever` and other sentence-transformers models that are not normalized will continue to use MIPS by default, which is optimal for them.
|
||||||
|
|
||||||
|
## Why This Matters
|
||||||
|
|
||||||
|
Using the wrong distance metric with normalized embeddings can lead to:
|
||||||
|
- **Poor search quality** due to HNSW's early termination with narrow score ranges
|
||||||
|
- **Incorrect ranking** of search results
|
||||||
|
- **Suboptimal performance** compared to using the correct metric
|
||||||
|
|
||||||
|
For more details on why this happens, see our analysis of [OpenAI embeddings with MIPS](../examples/main_cli_example.py).
|
||||||
21
docs/roadmap.md
Normal file
21
docs/roadmap.md
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
# 📈 Roadmap
|
||||||
|
|
||||||
|
## 🎯 Q2 2025
|
||||||
|
|
||||||
|
- [X] DiskANN backend with MIPS/L2/Cosine support
|
||||||
|
- [X] HNSW backend integration
|
||||||
|
- [X] Real-time embedding pipeline
|
||||||
|
- [X] Memory-efficient graph pruning
|
||||||
|
|
||||||
|
## 🚀 Q3 2025
|
||||||
|
|
||||||
|
- [ ] Advanced caching strategies
|
||||||
|
- [ ] Add contextual-retrieval https://www.anthropic.com/news/contextual-retrieval
|
||||||
|
- [ ] Add sleep-time-compute and summarize agent! to summarilze the file on computer!
|
||||||
|
- [ ] Add OpenAI recompute API
|
||||||
|
|
||||||
|
## 🌟 Q4 2025
|
||||||
|
|
||||||
|
- [ ] Integration with LangChain/LlamaIndex
|
||||||
|
- [ ] Visual similarity search
|
||||||
|
- [ ] Query rewrtiting, rerank and expansion
|
||||||
@@ -3,14 +3,15 @@
|
|||||||
Memory comparison between Faiss HNSW and LEANN HNSW backend
|
Memory comparison between Faiss HNSW and LEANN HNSW backend
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import gc
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import psutil
|
|
||||||
import gc
|
|
||||||
import subprocess
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import psutil
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
# Setup logging
|
# Setup logging
|
||||||
@@ -83,9 +84,7 @@ def test_faiss_hnsw():
|
|||||||
|
|
||||||
for line in lines:
|
for line in lines:
|
||||||
if "Peak Memory:" in line:
|
if "Peak Memory:" in line:
|
||||||
peak_memory = float(
|
peak_memory = float(line.split("Peak Memory:")[1].split("MB")[0].strip())
|
||||||
line.split("Peak Memory:")[1].split("MB")[0].strip()
|
|
||||||
)
|
|
||||||
|
|
||||||
return {"peak_memory": peak_memory}
|
return {"peak_memory": peak_memory}
|
||||||
|
|
||||||
@@ -111,9 +110,8 @@ def test_leann_hnsw():
|
|||||||
|
|
||||||
tracker.checkpoint("After imports")
|
tracker.checkpoint("After imports")
|
||||||
|
|
||||||
|
from leann.api import LeannBuilder
|
||||||
from llama_index.core import SimpleDirectoryReader
|
from llama_index.core import SimpleDirectoryReader
|
||||||
from leann.api import LeannBuilder, LeannSearcher
|
|
||||||
|
|
||||||
|
|
||||||
# Load and parse documents
|
# Load and parse documents
|
||||||
documents = SimpleDirectoryReader(
|
documents = SimpleDirectoryReader(
|
||||||
@@ -135,6 +133,7 @@ def test_leann_hnsw():
|
|||||||
nodes = node_parser.get_nodes_from_documents([doc])
|
nodes = node_parser.get_nodes_from_documents([doc])
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
all_texts.append(node.get_content())
|
all_texts.append(node.get_content())
|
||||||
|
print(f"Total number of chunks: {len(all_texts)}")
|
||||||
|
|
||||||
tracker.checkpoint("After text chunking")
|
tracker.checkpoint("After text chunking")
|
||||||
|
|
||||||
@@ -196,16 +195,14 @@ def test_leann_hnsw():
|
|||||||
runtime_start_mem = get_memory_usage()
|
runtime_start_mem = get_memory_usage()
|
||||||
print(f"Before load memory: {runtime_start_mem:.1f} MB")
|
print(f"Before load memory: {runtime_start_mem:.1f} MB")
|
||||||
tracker.checkpoint("Before load memory")
|
tracker.checkpoint("Before load memory")
|
||||||
|
|
||||||
# Load searcher
|
# Load searcher
|
||||||
searcher = LeannSearcher(index_path)
|
searcher = LeannSearcher(index_path)
|
||||||
tracker.checkpoint("After searcher loading")
|
tracker.checkpoint("After searcher loading")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
print("Running search queries...")
|
print("Running search queries...")
|
||||||
queries = [
|
queries = [
|
||||||
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
|
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
|
||||||
"What is LEANN and how does it work?",
|
"What is LEANN and how does it work?",
|
||||||
"华为诺亚方舟实验室的主要研究内容",
|
"华为诺亚方舟实验室的主要研究内容",
|
||||||
]
|
]
|
||||||
@@ -303,21 +300,15 @@ def main():
|
|||||||
|
|
||||||
print("\nLEANN vs Faiss Performance:")
|
print("\nLEANN vs Faiss Performance:")
|
||||||
memory_saving = faiss_results["peak_memory"] - leann_results["peak_memory"]
|
memory_saving = faiss_results["peak_memory"] - leann_results["peak_memory"]
|
||||||
print(
|
print(f" Search Memory: {memory_ratio:.1f}x less ({memory_saving:.1f} MB saved)")
|
||||||
f" Search Memory: {memory_ratio:.1f}x less ({memory_saving:.1f} MB saved)"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Storage comparison
|
# Storage comparison
|
||||||
if leann_storage_size > faiss_storage_size:
|
if leann_storage_size > faiss_storage_size:
|
||||||
storage_ratio = leann_storage_size / faiss_storage_size
|
storage_ratio = leann_storage_size / faiss_storage_size
|
||||||
print(
|
print(f" Storage Size: {storage_ratio:.1f}x larger (LEANN uses more storage)")
|
||||||
f" Storage Size: {storage_ratio:.1f}x larger (LEANN uses more storage)"
|
|
||||||
)
|
|
||||||
elif faiss_storage_size > leann_storage_size:
|
elif faiss_storage_size > leann_storage_size:
|
||||||
storage_ratio = faiss_storage_size / leann_storage_size
|
storage_ratio = faiss_storage_size / leann_storage_size
|
||||||
print(
|
print(f" Storage Size: {storage_ratio:.1f}x smaller (LEANN uses less storage)")
|
||||||
f" Storage Size: {storage_ratio:.1f}x smaller (LEANN uses less storage)"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
print(" Storage Size: similar")
|
print(" Storage Size: similar")
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
The Project Gutenberg eBook of Pride and Prejudice
|
The Project Gutenberg eBook of Pride and Prejudice
|
||||||
|
|
||||||
This ebook is for the use of anyone anywhere in the United States and
|
This ebook is for the use of anyone anywhere in the United States and
|
||||||
most other parts of the world at no cost and with almost no restrictions
|
most other parts of the world at no cost and with almost no restrictions
|
||||||
whatsoever. You may copy it, give it away or re-use it under the terms
|
whatsoever. You may copy it, give it away or re-use it under the terms
|
||||||
@@ -14557,7 +14557,7 @@ her into Derbyshire, had been the means of uniting them.
|
|||||||
*** END OF THE PROJECT GUTENBERG EBOOK PRIDE AND PREJUDICE ***
|
*** END OF THE PROJECT GUTENBERG EBOOK PRIDE AND PREJUDICE ***
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Updated editions will replace the previous one—the old editions will
|
Updated editions will replace the previous one—the old editions will
|
||||||
be renamed.
|
be renamed.
|
||||||
@@ -14662,7 +14662,7 @@ performed, viewed, copied or distributed:
|
|||||||
at www.gutenberg.org. If you
|
at www.gutenberg.org. If you
|
||||||
are not located in the United States, you will have to check the laws
|
are not located in the United States, you will have to check the laws
|
||||||
of the country where you are located before using this eBook.
|
of the country where you are located before using this eBook.
|
||||||
|
|
||||||
1.E.2. If an individual Project Gutenberg™ electronic work is
|
1.E.2. If an individual Project Gutenberg™ electronic work is
|
||||||
derived from texts not protected by U.S. copyright law (does not
|
derived from texts not protected by U.S. copyright law (does not
|
||||||
contain a notice indicating that it is posted with permission of the
|
contain a notice indicating that it is posted with permission of the
|
||||||
@@ -14724,7 +14724,7 @@ provided that:
|
|||||||
Gutenberg Literary Archive Foundation at the address specified in
|
Gutenberg Literary Archive Foundation at the address specified in
|
||||||
Section 4, “Information about donations to the Project Gutenberg
|
Section 4, “Information about donations to the Project Gutenberg
|
||||||
Literary Archive Foundation.”
|
Literary Archive Foundation.”
|
||||||
|
|
||||||
• You provide a full refund of any money paid by a user who notifies
|
• You provide a full refund of any money paid by a user who notifies
|
||||||
you in writing (or by e-mail) within 30 days of receipt that s/he
|
you in writing (or by e-mail) within 30 days of receipt that s/he
|
||||||
does not agree to the terms of the full Project Gutenberg™
|
does not agree to the terms of the full Project Gutenberg™
|
||||||
@@ -14732,15 +14732,15 @@ provided that:
|
|||||||
copies of the works possessed in a physical medium and discontinue
|
copies of the works possessed in a physical medium and discontinue
|
||||||
all use of and all access to other copies of Project Gutenberg™
|
all use of and all access to other copies of Project Gutenberg™
|
||||||
works.
|
works.
|
||||||
|
|
||||||
• You provide, in accordance with paragraph 1.F.3, a full refund of
|
• You provide, in accordance with paragraph 1.F.3, a full refund of
|
||||||
any money paid for a work or a replacement copy, if a defect in the
|
any money paid for a work or a replacement copy, if a defect in the
|
||||||
electronic work is discovered and reported to you within 90 days of
|
electronic work is discovered and reported to you within 90 days of
|
||||||
receipt of the work.
|
receipt of the work.
|
||||||
|
|
||||||
• You comply with all other terms of this agreement for free
|
• You comply with all other terms of this agreement for free
|
||||||
distribution of Project Gutenberg™ works.
|
distribution of Project Gutenberg™ works.
|
||||||
|
|
||||||
|
|
||||||
1.E.9. If you wish to charge a fee or distribute a Project
|
1.E.9. If you wish to charge a fee or distribute a Project
|
||||||
Gutenberg™ electronic work or group of works on different terms than
|
Gutenberg™ electronic work or group of works on different terms than
|
||||||
@@ -14903,5 +14903,3 @@ This website includes information about Project Gutenberg™,
|
|||||||
including how to make donations to the Project Gutenberg Literary
|
including how to make donations to the Project Gutenberg Literary
|
||||||
Archive Foundation, how to help produce our new eBooks, and how to
|
Archive Foundation, how to help produce our new eBooks, and how to
|
||||||
subscribe to our email newsletter to hear about new eBooks.
|
subscribe to our email newsletter to hear about new eBooks.
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,37 +3,47 @@
|
|||||||
Document search demo with recompute mode
|
Document search demo with recompute mode
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
import shutil
|
import shutil
|
||||||
import time
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
# Import backend packages to trigger plugin registration
|
# Import backend packages to trigger plugin registration
|
||||||
try:
|
try:
|
||||||
import leann_backend_diskann
|
import leann_backend_diskann # noqa: F401
|
||||||
import leann_backend_hnsw
|
import leann_backend_hnsw # noqa: F401
|
||||||
|
|
||||||
print("INFO: Backend packages imported successfully.")
|
print("INFO: Backend packages imported successfully.")
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
print(f"WARNING: Could not import backend packages. Error: {e}")
|
print(f"WARNING: Could not import backend packages. Error: {e}")
|
||||||
|
|
||||||
# Import upper-level API from leann-core
|
# Import upper-level API from leann-core
|
||||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
from leann.api import LeannBuilder, LeannChat, LeannSearcher
|
||||||
|
|
||||||
|
|
||||||
def load_sample_documents():
|
def load_sample_documents():
|
||||||
"""Create sample documents for demonstration"""
|
"""Create sample documents for demonstration"""
|
||||||
docs = [
|
docs = [
|
||||||
{"title": "Intro to Python", "content": "Python is a high-level, interpreted language known for simplicity."},
|
{
|
||||||
{"title": "ML Basics", "content": "Machine learning builds systems that learn from data."},
|
"title": "Intro to Python",
|
||||||
{"title": "Data Structures", "content": "Data structures like arrays, lists, and graphs organize data."},
|
"content": "Python is a high-level, interpreted language known for simplicity.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"title": "ML Basics",
|
||||||
|
"content": "Machine learning builds systems that learn from data.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"title": "Data Structures",
|
||||||
|
"content": "Data structures like arrays, lists, and graphs organize data.",
|
||||||
|
},
|
||||||
]
|
]
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
print("==========================================================")
|
print("==========================================================")
|
||||||
print("=== Leann Document Search Demo (DiskANN + Recompute) ===")
|
print("=== Leann Document Search Demo (DiskANN + Recompute) ===")
|
||||||
print("==========================================================")
|
print("==========================================================")
|
||||||
|
|
||||||
INDEX_DIR = Path("./test_indices")
|
INDEX_DIR = Path("./test_indices")
|
||||||
INDEX_PATH = str(INDEX_DIR / "documents.diskann")
|
INDEX_PATH = str(INDEX_DIR / "documents.diskann")
|
||||||
BACKEND_TO_TEST = "diskann"
|
BACKEND_TO_TEST = "diskann"
|
||||||
@@ -44,94 +54,96 @@ def main():
|
|||||||
|
|
||||||
# --- 1. Build index ---
|
# --- 1. Build index ---
|
||||||
print(f"\n[PHASE 1] Building index using '{BACKEND_TO_TEST}' backend...")
|
print(f"\n[PHASE 1] Building index using '{BACKEND_TO_TEST}' backend...")
|
||||||
|
|
||||||
builder = LeannBuilder(
|
builder = LeannBuilder(backend_name=BACKEND_TO_TEST, graph_degree=32, complexity=64)
|
||||||
backend_name=BACKEND_TO_TEST,
|
|
||||||
graph_degree=32,
|
|
||||||
complexity=64
|
|
||||||
)
|
|
||||||
|
|
||||||
documents = load_sample_documents()
|
documents = load_sample_documents()
|
||||||
print(f"Loaded {len(documents)} sample documents.")
|
print(f"Loaded {len(documents)} sample documents.")
|
||||||
for doc in documents:
|
for doc in documents:
|
||||||
builder.add_text(doc["content"], metadata={"title": doc["title"]})
|
builder.add_text(doc["content"], metadata={"title": doc["title"]})
|
||||||
|
|
||||||
builder.build_index(INDEX_PATH)
|
builder.build_index(INDEX_PATH)
|
||||||
print(f"\nIndex built!")
|
print("\nIndex built!")
|
||||||
|
|
||||||
# --- 2. Basic search demo ---
|
# --- 2. Basic search demo ---
|
||||||
print(f"\n[PHASE 2] Basic search using '{BACKEND_TO_TEST}' backend...")
|
print(f"\n[PHASE 2] Basic search using '{BACKEND_TO_TEST}' backend...")
|
||||||
searcher = LeannSearcher(index_path=INDEX_PATH)
|
searcher = LeannSearcher(index_path=INDEX_PATH)
|
||||||
|
|
||||||
query = "What is machine learning?"
|
query = "What is machine learning?"
|
||||||
print(f"\nQuery: '{query}'")
|
print(f"\nQuery: '{query}'")
|
||||||
|
|
||||||
print("\n--- Basic search mode (PQ computation) ---")
|
print("\n--- Basic search mode (PQ computation) ---")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
results = searcher.search(query, top_k=2)
|
results = searcher.search(query, top_k=2)
|
||||||
basic_time = time.time() - start_time
|
basic_time = time.time() - start_time
|
||||||
|
|
||||||
print(f"⏱️ Basic search time: {basic_time:.3f} seconds")
|
print(f"⏱️ Basic search time: {basic_time:.3f} seconds")
|
||||||
print(">>> Basic search results <<<")
|
print(">>> Basic search results <<<")
|
||||||
for i, res in enumerate(results, 1):
|
for i, res in enumerate(results, 1):
|
||||||
print(f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}")
|
print(
|
||||||
|
f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}"
|
||||||
|
)
|
||||||
|
|
||||||
# --- 3. Recompute search demo ---
|
# --- 3. Recompute search demo ---
|
||||||
print(f"\n[PHASE 3] Recompute search using embedding server...")
|
print("\n[PHASE 3] Recompute search using embedding server...")
|
||||||
|
|
||||||
print("\n--- Recompute search mode (get real embeddings via network) ---")
|
print("\n--- Recompute search mode (get real embeddings via network) ---")
|
||||||
|
|
||||||
# Configure recompute parameters
|
# Configure recompute parameters
|
||||||
recompute_params = {
|
recompute_params = {
|
||||||
"recompute_beighbor_embeddings": True, # Enable network recomputation
|
"recompute_beighbor_embeddings": True, # Enable network recomputation
|
||||||
"USE_DEFERRED_FETCH": False, # Don't use deferred fetch
|
"USE_DEFERRED_FETCH": False, # Don't use deferred fetch
|
||||||
"skip_search_reorder": True, # Skip search reordering
|
"skip_search_reorder": True, # Skip search reordering
|
||||||
"dedup_node_dis": True, # Enable node distance deduplication
|
"dedup_node_dis": True, # Enable node distance deduplication
|
||||||
"prune_ratio": 0.1, # Pruning ratio 10%
|
"prune_ratio": 0.1, # Pruning ratio 10%
|
||||||
"batch_recompute": False, # Don't use batch recomputation
|
"batch_recompute": False, # Don't use batch recomputation
|
||||||
"global_pruning": False, # Don't use global pruning
|
"global_pruning": False, # Don't use global pruning
|
||||||
"zmq_port": 5555, # ZMQ port
|
"zmq_port": 5555, # ZMQ port
|
||||||
"embedding_model": "sentence-transformers/all-mpnet-base-v2"
|
"embedding_model": "sentence-transformers/all-mpnet-base-v2",
|
||||||
}
|
}
|
||||||
|
|
||||||
print("Recompute parameter configuration:")
|
print("Recompute parameter configuration:")
|
||||||
for key, value in recompute_params.items():
|
for key, value in recompute_params.items():
|
||||||
print(f" {key}: {value}")
|
print(f" {key}: {value}")
|
||||||
|
|
||||||
print(f"\n🔄 Executing Recompute search...")
|
print("\n🔄 Executing Recompute search...")
|
||||||
try:
|
try:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
recompute_results = searcher.search(query, top_k=2, **recompute_params)
|
recompute_results = searcher.search(query, top_k=2, **recompute_params)
|
||||||
recompute_time = time.time() - start_time
|
recompute_time = time.time() - start_time
|
||||||
|
|
||||||
print(f"⏱️ Recompute search time: {recompute_time:.3f} seconds")
|
print(f"⏱️ Recompute search time: {recompute_time:.3f} seconds")
|
||||||
print(">>> Recompute search results <<<")
|
print(">>> Recompute search results <<<")
|
||||||
for i, res in enumerate(recompute_results, 1):
|
for i, res in enumerate(recompute_results, 1):
|
||||||
print(f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}")
|
print(
|
||||||
|
f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}"
|
||||||
|
)
|
||||||
|
|
||||||
# Compare results
|
# Compare results
|
||||||
print(f"\n--- Result comparison ---")
|
print("\n--- Result comparison ---")
|
||||||
print(f"Basic search time: {basic_time:.3f} seconds")
|
print(f"Basic search time: {basic_time:.3f} seconds")
|
||||||
print(f"Recompute time: {recompute_time:.3f} seconds")
|
print(f"Recompute time: {recompute_time:.3f} seconds")
|
||||||
|
|
||||||
print("\nBasic search vs Recompute results:")
|
print("\nBasic search vs Recompute results:")
|
||||||
for i in range(min(len(results), len(recompute_results))):
|
for i in range(min(len(results), len(recompute_results))):
|
||||||
basic_score = results[i].score
|
basic_score = results[i].score
|
||||||
recompute_score = recompute_results[i].score
|
recompute_score = recompute_results[i].score
|
||||||
score_diff = abs(basic_score - recompute_score)
|
score_diff = abs(basic_score - recompute_score)
|
||||||
print(f" Position {i+1}: PQ={basic_score:.4f}, Recompute={recompute_score:.4f}, Difference={score_diff:.4f}")
|
print(
|
||||||
|
f" Position {i + 1}: PQ={basic_score:.4f}, Recompute={recompute_score:.4f}, Difference={score_diff:.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
if recompute_time > basic_time:
|
if recompute_time > basic_time:
|
||||||
print(f"✅ Recompute mode working correctly (more accurate but slower)")
|
print("✅ Recompute mode working correctly (more accurate but slower)")
|
||||||
else:
|
else:
|
||||||
print(f"ℹ️ Recompute time is unusually fast, network recomputation may not be enabled")
|
print("i️ Recompute time is unusually fast, network recomputation may not be enabled")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"❌ Recompute search failed: {e}")
|
print(f"❌ Recompute search failed: {e}")
|
||||||
print("This usually indicates an embedding server connection issue")
|
print("This usually indicates an embedding server connection issue")
|
||||||
|
|
||||||
# --- 4. Chat demo ---
|
# --- 4. Chat demo ---
|
||||||
print(f"\n[PHASE 4] Starting chat session...")
|
print("\n[PHASE 4] Starting chat session...")
|
||||||
chat = LeannChat(index_path=INDEX_PATH)
|
chat = LeannChat(index_path=INDEX_PATH)
|
||||||
chat_response = chat.ask(query)
|
chat_response = chat.ask(query)
|
||||||
print(f"You: {query}")
|
print(f"You: {query}")
|
||||||
@@ -143,4 +155,4 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
import os
|
|
||||||
import email
|
import email
|
||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Any
|
from typing import Any
|
||||||
|
|
||||||
from llama_index.core import Document
|
from llama_index.core import Document
|
||||||
from llama_index.core.readers.base import BaseReader
|
from llama_index.core.readers.base import BaseReader
|
||||||
|
|
||||||
def find_all_messages_directories(root: str = None) -> List[Path]:
|
|
||||||
|
def find_all_messages_directories(root: str | None = None) -> list[Path]:
|
||||||
"""
|
"""
|
||||||
Recursively find all 'Messages' directories under the given root.
|
Recursively find all 'Messages' directories under the given root.
|
||||||
Returns a list of Path objects.
|
Returns a list of Path objects.
|
||||||
@@ -14,86 +16,97 @@ def find_all_messages_directories(root: str = None) -> List[Path]:
|
|||||||
# Auto-detect user's mail path
|
# Auto-detect user's mail path
|
||||||
home_dir = os.path.expanduser("~")
|
home_dir = os.path.expanduser("~")
|
||||||
root = os.path.join(home_dir, "Library", "Mail")
|
root = os.path.join(home_dir, "Library", "Mail")
|
||||||
|
|
||||||
messages_dirs = []
|
messages_dirs = []
|
||||||
for dirpath, dirnames, filenames in os.walk(root):
|
for dirpath, _dirnames, _filenames in os.walk(root):
|
||||||
if os.path.basename(dirpath) == "Messages":
|
if os.path.basename(dirpath) == "Messages":
|
||||||
messages_dirs.append(Path(dirpath))
|
messages_dirs.append(Path(dirpath))
|
||||||
return messages_dirs
|
return messages_dirs
|
||||||
|
|
||||||
|
|
||||||
class EmlxReader(BaseReader):
|
class EmlxReader(BaseReader):
|
||||||
"""
|
"""
|
||||||
Apple Mail .emlx file reader with embedded metadata.
|
Apple Mail .emlx file reader with embedded metadata.
|
||||||
|
|
||||||
Reads individual .emlx files from Apple Mail's storage format.
|
Reads individual .emlx files from Apple Mail's storage format.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, include_html: bool = False) -> None:
|
def __init__(self, include_html: bool = False) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize.
|
Initialize.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
include_html: Whether to include HTML content in the email body (default: False)
|
include_html: Whether to include HTML content in the email body (default: False)
|
||||||
"""
|
"""
|
||||||
self.include_html = include_html
|
self.include_html = include_html
|
||||||
|
|
||||||
def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]:
|
def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
|
||||||
"""
|
"""
|
||||||
Load data from the input directory containing .emlx files.
|
Load data from the input directory containing .emlx files.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_dir: Directory containing .emlx files
|
input_dir: Directory containing .emlx files
|
||||||
**load_kwargs:
|
**load_kwargs:
|
||||||
max_count (int): Maximum amount of messages to read.
|
max_count (int): Maximum amount of messages to read.
|
||||||
"""
|
"""
|
||||||
docs: List[Document] = []
|
docs: list[Document] = []
|
||||||
max_count = load_kwargs.get('max_count', 1000)
|
max_count = load_kwargs.get("max_count", 1000)
|
||||||
count = 0
|
count = 0
|
||||||
|
|
||||||
# Walk through the directory recursively
|
# Walk through the directory recursively
|
||||||
for dirpath, dirnames, filenames in os.walk(input_dir):
|
for dirpath, dirnames, filenames in os.walk(input_dir):
|
||||||
# Skip hidden directories
|
# Skip hidden directories
|
||||||
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
||||||
|
|
||||||
for filename in filenames:
|
for filename in filenames:
|
||||||
if count >= max_count:
|
if count >= max_count:
|
||||||
break
|
break
|
||||||
|
|
||||||
if filename.endswith(".emlx"):
|
if filename.endswith(".emlx"):
|
||||||
filepath = os.path.join(dirpath, filename)
|
filepath = os.path.join(dirpath, filename)
|
||||||
try:
|
try:
|
||||||
# Read the .emlx file
|
# Read the .emlx file
|
||||||
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
|
with open(filepath, encoding="utf-8", errors="ignore") as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
|
|
||||||
# .emlx files have a length prefix followed by the email content
|
# .emlx files have a length prefix followed by the email content
|
||||||
# The first line contains the length, followed by the email
|
# The first line contains the length, followed by the email
|
||||||
lines = content.split('\n', 1)
|
lines = content.split("\n", 1)
|
||||||
if len(lines) >= 2:
|
if len(lines) >= 2:
|
||||||
email_content = lines[1]
|
email_content = lines[1]
|
||||||
|
|
||||||
# Parse the email using Python's email module
|
# Parse the email using Python's email module
|
||||||
try:
|
try:
|
||||||
msg = email.message_from_string(email_content)
|
msg = email.message_from_string(email_content)
|
||||||
|
|
||||||
# Extract email metadata
|
# Extract email metadata
|
||||||
subject = msg.get('Subject', 'No Subject')
|
subject = msg.get("Subject", "No Subject")
|
||||||
from_addr = msg.get('From', 'Unknown')
|
from_addr = msg.get("From", "Unknown")
|
||||||
to_addr = msg.get('To', 'Unknown')
|
to_addr = msg.get("To", "Unknown")
|
||||||
date = msg.get('Date', 'Unknown')
|
date = msg.get("Date", "Unknown")
|
||||||
|
|
||||||
# Extract email body
|
# Extract email body
|
||||||
body = ""
|
body = ""
|
||||||
if msg.is_multipart():
|
if msg.is_multipart():
|
||||||
for part in msg.walk():
|
for part in msg.walk():
|
||||||
if part.get_content_type() == "text/plain" or part.get_content_type() == "text/html":
|
if (
|
||||||
if part.get_content_type() == "text/html" and not self.include_html:
|
part.get_content_type() == "text/plain"
|
||||||
|
or part.get_content_type() == "text/html"
|
||||||
|
):
|
||||||
|
if (
|
||||||
|
part.get_content_type() == "text/html"
|
||||||
|
and not self.include_html
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
body += part.get_payload(decode=True).decode('utf-8', errors='ignore')
|
body += part.get_payload(decode=True).decode(
|
||||||
|
"utf-8", errors="ignore"
|
||||||
|
)
|
||||||
# break
|
# break
|
||||||
else:
|
else:
|
||||||
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore')
|
body = msg.get_payload(decode=True).decode(
|
||||||
|
"utf-8", errors="ignore"
|
||||||
|
)
|
||||||
|
|
||||||
# Create document content with metadata embedded in text
|
# Create document content with metadata embedded in text
|
||||||
doc_content = f"""
|
doc_content = f"""
|
||||||
[File]: {filename}
|
[File]: {filename}
|
||||||
@@ -104,19 +117,19 @@ class EmlxReader(BaseReader):
|
|||||||
[EMAIL BODY Start]:
|
[EMAIL BODY Start]:
|
||||||
{body}
|
{body}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# No separate metadata - everything is in the text
|
# No separate metadata - everything is in the text
|
||||||
doc = Document(text=doc_content, metadata={})
|
doc = Document(text=doc_content, metadata={})
|
||||||
docs.append(doc)
|
docs.append(doc)
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error parsing email from {filepath}: {e}")
|
print(f"Error parsing email from {filepath}: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error reading file {filepath}: {e}")
|
print(f"Error reading file {filepath}: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
print(f"Loaded {len(docs)} email documents")
|
print(f"Loaded {len(docs)} email documents")
|
||||||
return docs
|
return docs
|
||||||
|
|||||||
@@ -7,9 +7,9 @@ Contains simple parser for mbox files.
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any
|
||||||
from fsspec import AbstractFileSystem
|
|
||||||
|
|
||||||
|
from fsspec import AbstractFileSystem
|
||||||
from llama_index.core.readers.base import BaseReader
|
from llama_index.core.readers.base import BaseReader
|
||||||
from llama_index.core.schema import Document
|
from llama_index.core.schema import Document
|
||||||
|
|
||||||
@@ -27,11 +27,7 @@ class MboxReader(BaseReader):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
DEFAULT_MESSAGE_FORMAT: str = (
|
DEFAULT_MESSAGE_FORMAT: str = (
|
||||||
"Date: {_date}\n"
|
"Date: {_date}\nFrom: {_from}\nTo: {_to}\nSubject: {_subject}\nContent: {_content}"
|
||||||
"From: {_from}\n"
|
|
||||||
"To: {_to}\n"
|
|
||||||
"Subject: {_subject}\n"
|
|
||||||
"Content: {_content}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -45,9 +41,7 @@ class MboxReader(BaseReader):
|
|||||||
try:
|
try:
|
||||||
from bs4 import BeautifulSoup # noqa
|
from bs4 import BeautifulSoup # noqa
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError("`beautifulsoup4` package not found: `pip install beautifulsoup4`")
|
||||||
"`beautifulsoup4` package not found: `pip install beautifulsoup4`"
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.max_count = max_count
|
self.max_count = max_count
|
||||||
@@ -56,9 +50,9 @@ class MboxReader(BaseReader):
|
|||||||
def load_data(
|
def load_data(
|
||||||
self,
|
self,
|
||||||
file: Path,
|
file: Path,
|
||||||
extra_info: Optional[Dict] = None,
|
extra_info: dict | None = None,
|
||||||
fs: Optional[AbstractFileSystem] = None,
|
fs: AbstractFileSystem | None = None,
|
||||||
) -> List[Document]:
|
) -> list[Document]:
|
||||||
"""Parse file into string."""
|
"""Parse file into string."""
|
||||||
# Import required libraries
|
# Import required libraries
|
||||||
import mailbox
|
import mailbox
|
||||||
@@ -74,7 +68,7 @@ class MboxReader(BaseReader):
|
|||||||
)
|
)
|
||||||
|
|
||||||
i = 0
|
i = 0
|
||||||
results: List[str] = []
|
results: list[str] = []
|
||||||
# Load file using mailbox
|
# Load file using mailbox
|
||||||
bytes_parser = BytesParser(policy=default).parse
|
bytes_parser = BytesParser(policy=default).parse
|
||||||
mbox = mailbox.mbox(file, factory=bytes_parser) # type: ignore
|
mbox = mailbox.mbox(file, factory=bytes_parser) # type: ignore
|
||||||
@@ -124,7 +118,7 @@ class MboxReader(BaseReader):
|
|||||||
class EmlxMboxReader(MboxReader):
|
class EmlxMboxReader(MboxReader):
|
||||||
"""
|
"""
|
||||||
EmlxMboxReader - Modified MboxReader that handles directories of .emlx files.
|
EmlxMboxReader - Modified MboxReader that handles directories of .emlx files.
|
||||||
|
|
||||||
Extends MboxReader to work with Apple Mail's .emlx format by:
|
Extends MboxReader to work with Apple Mail's .emlx format by:
|
||||||
1. Reading .emlx files from a directory
|
1. Reading .emlx files from a directory
|
||||||
2. Converting them to mbox format in memory
|
2. Converting them to mbox format in memory
|
||||||
@@ -134,13 +128,13 @@ class EmlxMboxReader(MboxReader):
|
|||||||
def load_data(
|
def load_data(
|
||||||
self,
|
self,
|
||||||
directory: Path,
|
directory: Path,
|
||||||
extra_info: Optional[Dict] = None,
|
extra_info: dict | None = None,
|
||||||
fs: Optional[AbstractFileSystem] = None,
|
fs: AbstractFileSystem | None = None,
|
||||||
) -> List[Document]:
|
) -> list[Document]:
|
||||||
"""Parse .emlx files from directory into strings using MboxReader logic."""
|
"""Parse .emlx files from directory into strings using MboxReader logic."""
|
||||||
import tempfile
|
|
||||||
import os
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
if fs:
|
if fs:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"fs was specified but EmlxMboxReader doesn't support loading "
|
"fs was specified but EmlxMboxReader doesn't support loading "
|
||||||
@@ -150,37 +144,37 @@ class EmlxMboxReader(MboxReader):
|
|||||||
# Find all .emlx files in the directory
|
# Find all .emlx files in the directory
|
||||||
emlx_files = list(directory.glob("*.emlx"))
|
emlx_files = list(directory.glob("*.emlx"))
|
||||||
logger.info(f"Found {len(emlx_files)} .emlx files in {directory}")
|
logger.info(f"Found {len(emlx_files)} .emlx files in {directory}")
|
||||||
|
|
||||||
if not emlx_files:
|
if not emlx_files:
|
||||||
logger.warning(f"No .emlx files found in {directory}")
|
logger.warning(f"No .emlx files found in {directory}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# Create a temporary mbox file
|
# Create a temporary mbox file
|
||||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.mbox', delete=False) as temp_mbox:
|
with tempfile.NamedTemporaryFile(mode="w", suffix=".mbox", delete=False) as temp_mbox:
|
||||||
temp_mbox_path = temp_mbox.name
|
temp_mbox_path = temp_mbox.name
|
||||||
|
|
||||||
# Convert .emlx files to mbox format
|
# Convert .emlx files to mbox format
|
||||||
for emlx_file in emlx_files:
|
for emlx_file in emlx_files:
|
||||||
try:
|
try:
|
||||||
# Read the .emlx file
|
# Read the .emlx file
|
||||||
with open(emlx_file, 'r', encoding='utf-8', errors='ignore') as f:
|
with open(emlx_file, encoding="utf-8", errors="ignore") as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
|
|
||||||
# .emlx format: first line is length, rest is email content
|
# .emlx format: first line is length, rest is email content
|
||||||
lines = content.split('\n', 1)
|
lines = content.split("\n", 1)
|
||||||
if len(lines) >= 2:
|
if len(lines) >= 2:
|
||||||
email_content = lines[1] # Skip the length line
|
email_content = lines[1] # Skip the length line
|
||||||
|
|
||||||
# Write to mbox format (each message starts with "From " and ends with blank line)
|
# Write to mbox format (each message starts with "From " and ends with blank line)
|
||||||
temp_mbox.write(f"From {emlx_file.name} {email_content}\n\n")
|
temp_mbox.write(f"From {emlx_file.name} {email_content}\n\n")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to process {emlx_file}: {e}")
|
logger.warning(f"Failed to process {emlx_file}: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Close the temporary file so MboxReader can read it
|
# Close the temporary file so MboxReader can read it
|
||||||
temp_mbox.close()
|
temp_mbox.close()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Use the parent MboxReader's logic to parse the mbox file
|
# Use the parent MboxReader's logic to parse the mbox file
|
||||||
return super().load_data(Path(temp_mbox_path), extra_info, fs)
|
return super().load_data(Path(temp_mbox_path), extra_info, fs)
|
||||||
@@ -188,5 +182,5 @@ class EmlxMboxReader(MboxReader):
|
|||||||
# Clean up temporary file
|
# Clean up temporary file
|
||||||
try:
|
try:
|
||||||
os.unlink(temp_mbox_path)
|
os.unlink(temp_mbox_path)
|
||||||
except:
|
except OSError:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""Test only Faiss HNSW"""
|
"""Test only Faiss HNSW"""
|
||||||
|
|
||||||
|
import os
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
import gc
|
|
||||||
import os
|
|
||||||
|
|
||||||
|
|
||||||
def get_memory_usage():
|
def get_memory_usage():
|
||||||
@@ -37,20 +37,20 @@ def main():
|
|||||||
import faiss
|
import faiss
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print("Faiss is not installed.")
|
print("Faiss is not installed.")
|
||||||
print("Please install it with `uv pip install faiss-cpu`")
|
print(
|
||||||
|
"Please install it with `uv pip install faiss-cpu` and you can then run this script again"
|
||||||
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
from llama_index.core import (
|
from llama_index.core import (
|
||||||
SimpleDirectoryReader,
|
|
||||||
VectorStoreIndex,
|
|
||||||
StorageContext,
|
|
||||||
Settings,
|
Settings,
|
||||||
node_parser,
|
SimpleDirectoryReader,
|
||||||
Document,
|
StorageContext,
|
||||||
|
VectorStoreIndex,
|
||||||
)
|
)
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
|
||||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||||
|
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||||
|
|
||||||
tracker = MemoryTracker("Faiss HNSW")
|
tracker = MemoryTracker("Faiss HNSW")
|
||||||
tracker.checkpoint("Initial")
|
tracker.checkpoint("Initial")
|
||||||
@@ -90,8 +90,9 @@ def main():
|
|||||||
vector_store=vector_store, persist_dir="./storage_faiss"
|
vector_store=vector_store, persist_dir="./storage_faiss"
|
||||||
)
|
)
|
||||||
from llama_index.core import load_index_from_storage
|
from llama_index.core import load_index_from_storage
|
||||||
|
|
||||||
index = load_index_from_storage(storage_context=storage_context)
|
index = load_index_from_storage(storage_context=storage_context)
|
||||||
print(f"Index loaded from ./storage_faiss")
|
print("Index loaded from ./storage_faiss")
|
||||||
tracker.checkpoint("After loading existing index")
|
tracker.checkpoint("After loading existing index")
|
||||||
index_loaded = True
|
index_loaded = True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -99,19 +100,18 @@ def main():
|
|||||||
print("Cleaning up corrupted index and building new one...")
|
print("Cleaning up corrupted index and building new one...")
|
||||||
# Clean up corrupted index
|
# Clean up corrupted index
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
if os.path.exists("./storage_faiss"):
|
if os.path.exists("./storage_faiss"):
|
||||||
shutil.rmtree("./storage_faiss")
|
shutil.rmtree("./storage_faiss")
|
||||||
|
|
||||||
if not index_loaded:
|
if not index_loaded:
|
||||||
print("Building new Faiss HNSW index...")
|
print("Building new Faiss HNSW index...")
|
||||||
|
|
||||||
# Use the correct Faiss building pattern from the example
|
# Use the correct Faiss building pattern from the example
|
||||||
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
||||||
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
||||||
index = VectorStoreIndex.from_documents(
|
index = VectorStoreIndex.from_documents(
|
||||||
documents,
|
documents, storage_context=storage_context, transformations=[node_parser]
|
||||||
storage_context=storage_context,
|
|
||||||
transformations=[node_parser]
|
|
||||||
)
|
)
|
||||||
tracker.checkpoint("After index building")
|
tracker.checkpoint("After index building")
|
||||||
|
|
||||||
@@ -124,10 +124,10 @@ def main():
|
|||||||
runtime_start_mem = get_memory_usage()
|
runtime_start_mem = get_memory_usage()
|
||||||
print(f"Before load memory: {runtime_start_mem:.1f} MB")
|
print(f"Before load memory: {runtime_start_mem:.1f} MB")
|
||||||
tracker.checkpoint("Before load memory")
|
tracker.checkpoint("Before load memory")
|
||||||
|
|
||||||
query_engine = index.as_query_engine(similarity_top_k=20)
|
query_engine = index.as_query_engine(similarity_top_k=20)
|
||||||
queries = [
|
queries = [
|
||||||
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
|
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
|
||||||
"What is LEANN and how does it work?",
|
"What is LEANN and how does it work?",
|
||||||
"华为诺亚方舟实验室的主要研究内容",
|
"华为诺亚方舟实验室的主要研究内容",
|
||||||
]
|
]
|
||||||
@@ -141,7 +141,7 @@ def main():
|
|||||||
|
|
||||||
runtime_end_mem = get_memory_usage()
|
runtime_end_mem = get_memory_usage()
|
||||||
runtime_overhead = runtime_end_mem - runtime_start_mem
|
runtime_overhead = runtime_end_mem - runtime_start_mem
|
||||||
|
|
||||||
peak_memory = tracker.summary()
|
peak_memory = tracker.summary()
|
||||||
print(f"Peak Memory: {peak_memory:.1f} MB")
|
print(f"Peak Memory: {peak_memory:.1f} MB")
|
||||||
print(f"Runtime Memory Overhead: {runtime_overhead:.1f} MB")
|
print(f"Runtime Memory Overhead: {runtime_overhead:.1f} MB")
|
||||||
|
|||||||
@@ -1,15 +1,17 @@
|
|||||||
import os
|
|
||||||
import asyncio
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import dotenv
|
import dotenv
|
||||||
|
|
||||||
dotenv.load_dotenv()
|
dotenv.load_dotenv()
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
# python-dotenv is not installed; skip loading environment variables
|
# python-dotenv is not installed; skip loading environment variables
|
||||||
dotenv = None
|
dotenv = None
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Any
|
|
||||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
from leann.api import LeannBuilder, LeannChat
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
# dotenv.load_dotenv() # handled above if python-dotenv is available
|
# dotenv.load_dotenv() # handled above if python-dotenv is available
|
||||||
@@ -17,42 +19,51 @@ from llama_index.core.node_parser import SentenceSplitter
|
|||||||
# Default Chrome profile path
|
# Default Chrome profile path
|
||||||
DEFAULT_CHROME_PROFILE = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
|
DEFAULT_CHROME_PROFILE = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
|
||||||
|
|
||||||
def create_leann_index_from_multiple_chrome_profiles(profile_dirs: List[Path], index_path: str = "chrome_history_index.leann", max_count: int = -1):
|
|
||||||
|
def create_leann_index_from_multiple_chrome_profiles(
|
||||||
|
profile_dirs: list[Path],
|
||||||
|
index_path: str = "chrome_history_index.leann",
|
||||||
|
max_count: int = -1,
|
||||||
|
embedding_model: str = "facebook/contriever",
|
||||||
|
embedding_mode: str = "sentence-transformers",
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Create LEANN index from multiple Chrome profile data sources.
|
Create LEANN index from multiple Chrome profile data sources.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
profile_dirs: List of Path objects pointing to Chrome profile directories
|
profile_dirs: List of Path objects pointing to Chrome profile directories
|
||||||
index_path: Path to save the LEANN index
|
index_path: Path to save the LEANN index
|
||||||
max_count: Maximum number of history entries to process per profile
|
max_count: Maximum number of history entries to process per profile
|
||||||
|
embedding_model: The embedding model to use
|
||||||
|
embedding_mode: The embedding backend mode
|
||||||
"""
|
"""
|
||||||
print("Creating LEANN index from multiple Chrome profile data sources...")
|
print("Creating LEANN index from multiple Chrome profile data sources...")
|
||||||
|
|
||||||
# Load documents using ChromeHistoryReader from history_data
|
# Load documents using ChromeHistoryReader from history_data
|
||||||
from history_data.history import ChromeHistoryReader
|
from history_data.history import ChromeHistoryReader
|
||||||
|
|
||||||
reader = ChromeHistoryReader()
|
reader = ChromeHistoryReader()
|
||||||
|
|
||||||
INDEX_DIR = Path(index_path).parent
|
INDEX_DIR = Path(index_path).parent
|
||||||
|
|
||||||
if not INDEX_DIR.exists():
|
if not INDEX_DIR.exists():
|
||||||
print(f"--- Index directory not found, building new index ---")
|
print("--- Index directory not found, building new index ---")
|
||||||
all_documents = []
|
all_documents = []
|
||||||
total_processed = 0
|
total_processed = 0
|
||||||
|
|
||||||
# Process each Chrome profile directory
|
# Process each Chrome profile directory
|
||||||
for i, profile_dir in enumerate(profile_dirs):
|
for i, profile_dir in enumerate(profile_dirs):
|
||||||
print(f"\nProcessing Chrome profile {i+1}/{len(profile_dirs)}: {profile_dir}")
|
print(f"\nProcessing Chrome profile {i + 1}/{len(profile_dirs)}: {profile_dir}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
documents = reader.load_data(
|
documents = reader.load_data(
|
||||||
chrome_profile_path=str(profile_dir),
|
chrome_profile_path=str(profile_dir), max_count=max_count
|
||||||
max_count=max_count
|
|
||||||
)
|
)
|
||||||
if documents:
|
if documents:
|
||||||
print(f"Loaded {len(documents)} history documents from {profile_dir}")
|
print(f"Loaded {len(documents)} history documents from {profile_dir}")
|
||||||
all_documents.extend(documents)
|
all_documents.extend(documents)
|
||||||
total_processed += len(documents)
|
total_processed += len(documents)
|
||||||
|
|
||||||
# Check if we've reached the max count
|
# Check if we've reached the max count
|
||||||
if max_count > 0 and total_processed >= max_count:
|
if max_count > 0 and total_processed >= max_count:
|
||||||
print(f"Reached max count of {max_count} documents")
|
print(f"Reached max count of {max_count} documents")
|
||||||
@@ -62,18 +73,22 @@ def create_leann_index_from_multiple_chrome_profiles(profile_dirs: List[Path], i
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error processing {profile_dir}: {e}")
|
print(f"Error processing {profile_dir}: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if not all_documents:
|
if not all_documents:
|
||||||
print("No documents loaded from any source. Exiting.")
|
print("No documents loaded from any source. Exiting.")
|
||||||
# highlight info that you need to close all chrome browser before running this script and high light the instruction!!
|
# highlight info that you need to close all chrome browser before running this script and high light the instruction!!
|
||||||
print("\033[91mYou need to close or quit all chrome browser before running this script\033[0m")
|
print(
|
||||||
|
"\033[91mYou need to close or quit all chrome browser before running this script\033[0m"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
print(f"\nTotal loaded {len(all_documents)} history documents from {len(profile_dirs)} profiles")
|
print(
|
||||||
|
f"\nTotal loaded {len(all_documents)} history documents from {len(profile_dirs)} profiles"
|
||||||
|
)
|
||||||
|
|
||||||
# Create text splitter with 256 chunk size
|
# Create text splitter with 256 chunk size
|
||||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
|
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
|
||||||
|
|
||||||
# Convert Documents to text strings and chunk them
|
# Convert Documents to text strings and chunk them
|
||||||
all_texts = []
|
all_texts = []
|
||||||
for doc in all_documents:
|
for doc in all_documents:
|
||||||
@@ -83,77 +98,86 @@ def create_leann_index_from_multiple_chrome_profiles(profile_dirs: List[Path], i
|
|||||||
text = node.get_content()
|
text = node.get_content()
|
||||||
# text = '[Title] ' + doc.metadata["title"] + '\n' + text
|
# text = '[Title] ' + doc.metadata["title"] + '\n' + text
|
||||||
all_texts.append(text)
|
all_texts.append(text)
|
||||||
|
|
||||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
|
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
|
||||||
|
|
||||||
# Create LEANN index directory
|
# Create LEANN index directory
|
||||||
print(f"--- Index directory not found, building new index ---")
|
print("--- Index directory not found, building new index ---")
|
||||||
INDEX_DIR.mkdir(exist_ok=True)
|
INDEX_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
print(f"--- Building new LEANN index ---")
|
print("--- Building new LEANN index ---")
|
||||||
|
|
||||||
print(f"\n[PHASE 1] Building Leann index...")
|
print("\n[PHASE 1] Building Leann index...")
|
||||||
|
|
||||||
# Use HNSW backend for better macOS compatibility
|
# Use HNSW backend for better macOS compatibility
|
||||||
|
# LeannBuilder will automatically detect normalized embeddings and set appropriate distance metric
|
||||||
builder = LeannBuilder(
|
builder = LeannBuilder(
|
||||||
backend_name="hnsw",
|
backend_name="hnsw",
|
||||||
embedding_model="facebook/contriever",
|
embedding_model=embedding_model,
|
||||||
graph_degree=32,
|
embedding_mode=embedding_mode,
|
||||||
|
graph_degree=32,
|
||||||
complexity=64,
|
complexity=64,
|
||||||
is_compact=True,
|
is_compact=True,
|
||||||
is_recompute=True,
|
is_recompute=True,
|
||||||
num_threads=1 # Force single-threaded mode
|
num_threads=1, # Force single-threaded mode
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"Adding {len(all_texts)} history chunks to index...")
|
print(f"Adding {len(all_texts)} history chunks to index...")
|
||||||
for chunk_text in all_texts:
|
for chunk_text in all_texts:
|
||||||
builder.add_text(chunk_text)
|
builder.add_text(chunk_text)
|
||||||
|
|
||||||
builder.build_index(index_path)
|
builder.build_index(index_path)
|
||||||
print(f"\nLEANN index built at {index_path}!")
|
print(f"\nLEANN index built at {index_path}!")
|
||||||
else:
|
else:
|
||||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||||
|
|
||||||
return index_path
|
return index_path
|
||||||
|
|
||||||
def create_leann_index(profile_path: str = None, index_path: str = "chrome_history_index.leann", max_count: int = 1000):
|
|
||||||
|
def create_leann_index(
|
||||||
|
profile_path: str | None = None,
|
||||||
|
index_path: str = "chrome_history_index.leann",
|
||||||
|
max_count: int = 1000,
|
||||||
|
embedding_model: str = "facebook/contriever",
|
||||||
|
embedding_mode: str = "sentence-transformers",
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Create LEANN index from Chrome history data.
|
Create LEANN index from Chrome history data.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
profile_path: Path to the Chrome profile directory (optional, uses default if None)
|
profile_path: Path to the Chrome profile directory (optional, uses default if None)
|
||||||
index_path: Path to save the LEANN index
|
index_path: Path to save the LEANN index
|
||||||
max_count: Maximum number of history entries to process
|
max_count: Maximum number of history entries to process
|
||||||
|
embedding_model: The embedding model to use
|
||||||
|
embedding_mode: The embedding backend mode
|
||||||
"""
|
"""
|
||||||
print("Creating LEANN index from Chrome history data...")
|
print("Creating LEANN index from Chrome history data...")
|
||||||
INDEX_DIR = Path(index_path).parent
|
INDEX_DIR = Path(index_path).parent
|
||||||
|
|
||||||
if not INDEX_DIR.exists():
|
if not INDEX_DIR.exists():
|
||||||
print(f"--- Index directory not found, building new index ---")
|
print("--- Index directory not found, building new index ---")
|
||||||
INDEX_DIR.mkdir(exist_ok=True)
|
INDEX_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
print(f"--- Building new LEANN index ---")
|
print("--- Building new LEANN index ---")
|
||||||
|
|
||||||
print(f"\n[PHASE 1] Building Leann index...")
|
print("\n[PHASE 1] Building Leann index...")
|
||||||
|
|
||||||
# Load documents using ChromeHistoryReader from history_data
|
# Load documents using ChromeHistoryReader from history_data
|
||||||
from history_data.history import ChromeHistoryReader
|
from history_data.history import ChromeHistoryReader
|
||||||
|
|
||||||
reader = ChromeHistoryReader()
|
reader = ChromeHistoryReader()
|
||||||
|
|
||||||
documents = reader.load_data(
|
documents = reader.load_data(chrome_profile_path=profile_path, max_count=max_count)
|
||||||
chrome_profile_path=profile_path,
|
|
||||||
max_count=max_count
|
|
||||||
)
|
|
||||||
|
|
||||||
if not documents:
|
if not documents:
|
||||||
print("No documents loaded. Exiting.")
|
print("No documents loaded. Exiting.")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
print(f"Loaded {len(documents)} history documents")
|
print(f"Loaded {len(documents)} history documents")
|
||||||
|
|
||||||
# Create text splitter with 256 chunk size
|
# Create text splitter with 256 chunk size
|
||||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||||
|
|
||||||
# Convert Documents to text strings and chunk them
|
# Convert Documents to text strings and chunk them
|
||||||
all_texts = []
|
all_texts = []
|
||||||
for doc in documents:
|
for doc in documents:
|
||||||
@@ -161,54 +185,57 @@ def create_leann_index(profile_path: str = None, index_path: str = "chrome_histo
|
|||||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
all_texts.append(node.get_content())
|
all_texts.append(node.get_content())
|
||||||
|
|
||||||
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
|
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
|
||||||
|
|
||||||
# Create LEANN index directory
|
# Create LEANN index directory
|
||||||
print(f"--- Index directory not found, building new index ---")
|
print("--- Index directory not found, building new index ---")
|
||||||
INDEX_DIR.mkdir(exist_ok=True)
|
INDEX_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
print(f"--- Building new LEANN index ---")
|
print("--- Building new LEANN index ---")
|
||||||
|
|
||||||
print(f"\n[PHASE 1] Building Leann index...")
|
print("\n[PHASE 1] Building Leann index...")
|
||||||
|
|
||||||
# Use HNSW backend for better macOS compatibility
|
# Use HNSW backend for better macOS compatibility
|
||||||
|
# LeannBuilder will automatically detect normalized embeddings and set appropriate distance metric
|
||||||
builder = LeannBuilder(
|
builder = LeannBuilder(
|
||||||
backend_name="hnsw",
|
backend_name="hnsw",
|
||||||
embedding_model="facebook/contriever",
|
embedding_model=embedding_model,
|
||||||
graph_degree=32,
|
embedding_mode=embedding_mode,
|
||||||
|
graph_degree=32,
|
||||||
complexity=64,
|
complexity=64,
|
||||||
is_compact=True,
|
is_compact=True,
|
||||||
is_recompute=True,
|
is_recompute=True,
|
||||||
num_threads=1 # Force single-threaded mode
|
num_threads=1, # Force single-threaded mode
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"Adding {len(all_texts)} history chunks to index...")
|
print(f"Adding {len(all_texts)} history chunks to index...")
|
||||||
for chunk_text in all_texts:
|
for chunk_text in all_texts:
|
||||||
builder.add_text(chunk_text)
|
builder.add_text(chunk_text)
|
||||||
|
|
||||||
builder.build_index(index_path)
|
builder.build_index(index_path)
|
||||||
print(f"\nLEANN index built at {index_path}!")
|
print(f"\nLEANN index built at {index_path}!")
|
||||||
else:
|
else:
|
||||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||||
|
|
||||||
return index_path
|
return index_path
|
||||||
|
|
||||||
|
|
||||||
async def query_leann_index(index_path: str, query: str):
|
async def query_leann_index(index_path: str, query: str):
|
||||||
"""
|
"""
|
||||||
Query the LEANN index.
|
Query the LEANN index.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
index_path: Path to the LEANN index
|
index_path: Path to the LEANN index
|
||||||
query: The query string
|
query: The query string
|
||||||
"""
|
"""
|
||||||
print(f"\n[PHASE 2] Starting Leann chat session...")
|
print("\n[PHASE 2] Starting Leann chat session...")
|
||||||
chat = LeannChat(index_path=index_path)
|
chat = LeannChat(index_path=index_path)
|
||||||
|
|
||||||
print(f"You: {query}")
|
print(f"You: {query}")
|
||||||
chat_response = chat.ask(
|
chat_response = chat.ask(
|
||||||
query,
|
query,
|
||||||
top_k=10,
|
top_k=10,
|
||||||
recompute_beighbor_embeddings=True,
|
recompute_beighbor_embeddings=True,
|
||||||
complexity=32,
|
complexity=32,
|
||||||
beam_width=1,
|
beam_width=1,
|
||||||
@@ -217,55 +244,104 @@ async def query_leann_index(index_path: str, query: str):
|
|||||||
"model": "gpt-4o",
|
"model": "gpt-4o",
|
||||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||||
},
|
},
|
||||||
llm_kwargs={
|
llm_kwargs={"temperature": 0.0, "max_tokens": 1000},
|
||||||
"temperature": 0.0,
|
|
||||||
"max_tokens": 1000
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
print(f"Leann: {chat_response}")
|
|
||||||
|
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
# Parse command line arguments
|
# Parse command line arguments
|
||||||
parser = argparse.ArgumentParser(description='LEANN Chrome History Reader - Create and query browser history index')
|
parser = argparse.ArgumentParser(
|
||||||
parser.add_argument('--chrome-profile', type=str, default=DEFAULT_CHROME_PROFILE,
|
description="LEANN Chrome History Reader - Create and query browser history index"
|
||||||
help=f'Path to Chrome profile directory (default: {DEFAULT_CHROME_PROFILE}), usually you dont need to change this')
|
)
|
||||||
parser.add_argument('--index-dir', type=str, default="./all_google_new",
|
parser.add_argument(
|
||||||
help='Directory to store the LEANN index (default: ./chrome_history_index_leann_test)')
|
"--chrome-profile",
|
||||||
parser.add_argument('--max-entries', type=int, default=1000,
|
type=str,
|
||||||
help='Maximum number of history entries to process (default: 1000)')
|
default=DEFAULT_CHROME_PROFILE,
|
||||||
parser.add_argument('--query', type=str, default=None,
|
help=f"Path to Chrome profile directory (default: {DEFAULT_CHROME_PROFILE}), usually you dont need to change this",
|
||||||
help='Single query to run (default: runs example queries)')
|
)
|
||||||
parser.add_argument('--auto-find-profiles', action='store_true', default=True,
|
parser.add_argument(
|
||||||
help='Automatically find all Chrome profiles (default: True)')
|
"--index-dir",
|
||||||
|
type=str,
|
||||||
|
default="./google_history_index",
|
||||||
|
help="Directory to store the LEANN index (default: ./chrome_history_index_leann_test)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-entries",
|
||||||
|
type=int,
|
||||||
|
default=1000,
|
||||||
|
help="Maximum number of history entries to process (default: 1000)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--query",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Single query to run (default: runs example queries)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--auto-find-profiles",
|
||||||
|
action="store_true",
|
||||||
|
default=True,
|
||||||
|
help="Automatically find all Chrome profiles (default: True)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding-model",
|
||||||
|
type=str,
|
||||||
|
default="facebook/contriever",
|
||||||
|
help="The embedding model to use (e.g., 'facebook/contriever', 'text-embedding-3-small')",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding-mode",
|
||||||
|
type=str,
|
||||||
|
default="sentence-transformers",
|
||||||
|
choices=["sentence-transformers", "openai", "mlx"],
|
||||||
|
help="The embedding backend mode",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-existing-index",
|
||||||
|
action="store_true",
|
||||||
|
help="Use existing index without rebuilding",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
INDEX_DIR = Path(args.index_dir)
|
INDEX_DIR = Path(args.index_dir)
|
||||||
INDEX_PATH = str(INDEX_DIR / "chrome_history.leann")
|
INDEX_PATH = str(INDEX_DIR / "chrome_history.leann")
|
||||||
|
|
||||||
print(f"Using Chrome profile: {args.chrome_profile}")
|
print(f"Using Chrome profile: {args.chrome_profile}")
|
||||||
print(f"Index directory: {INDEX_DIR}")
|
print(f"Index directory: {INDEX_DIR}")
|
||||||
print(f"Max entries: {args.max_entries}")
|
print(f"Max entries: {args.max_entries}")
|
||||||
|
|
||||||
# Find Chrome profile directories
|
if args.use_existing_index:
|
||||||
from history_data.history import ChromeHistoryReader
|
# Use existing index without rebuilding
|
||||||
|
if not Path(INDEX_PATH).exists():
|
||||||
if args.auto_find_profiles:
|
print(f"Error: Index file not found at {INDEX_PATH}")
|
||||||
profile_dirs = ChromeHistoryReader.find_chrome_profiles()
|
|
||||||
if not profile_dirs:
|
|
||||||
print("No Chrome profiles found automatically. Exiting.")
|
|
||||||
return
|
return
|
||||||
|
print(f"Using existing index at {INDEX_PATH}")
|
||||||
|
index_path = INDEX_PATH
|
||||||
else:
|
else:
|
||||||
# Use single specified profile
|
# Find Chrome profile directories
|
||||||
profile_path = Path(args.chrome_profile)
|
from history_data.history import ChromeHistoryReader
|
||||||
if not profile_path.exists():
|
|
||||||
print(f"Chrome profile not found: {profile_path}")
|
if args.auto_find_profiles:
|
||||||
return
|
profile_dirs = ChromeHistoryReader.find_chrome_profiles()
|
||||||
profile_dirs = [profile_path]
|
if not profile_dirs:
|
||||||
|
print("No Chrome profiles found automatically. Exiting.")
|
||||||
# Create or load the LEANN index from all sources
|
return
|
||||||
index_path = create_leann_index_from_multiple_chrome_profiles(profile_dirs, INDEX_PATH, args.max_entries)
|
else:
|
||||||
|
# Use single specified profile
|
||||||
|
profile_path = Path(args.chrome_profile)
|
||||||
|
if not profile_path.exists():
|
||||||
|
print(f"Chrome profile not found: {profile_path}")
|
||||||
|
return
|
||||||
|
profile_dirs = [profile_path]
|
||||||
|
|
||||||
|
# Create or load the LEANN index from all sources
|
||||||
|
index_path = create_leann_index_from_multiple_chrome_profiles(
|
||||||
|
profile_dirs, INDEX_PATH, args.max_entries, args.embedding_model, args.embedding_mode
|
||||||
|
)
|
||||||
|
|
||||||
if index_path:
|
if index_path:
|
||||||
if args.query:
|
if args.query:
|
||||||
# Run single query
|
# Run single query
|
||||||
@@ -274,12 +350,13 @@ async def main():
|
|||||||
# Example queries
|
# Example queries
|
||||||
queries = [
|
queries = [
|
||||||
"What websites did I visit about machine learning?",
|
"What websites did I visit about machine learning?",
|
||||||
"Find my search history about programming"
|
"Find my search history about programming",
|
||||||
]
|
]
|
||||||
|
|
||||||
for query in queries:
|
for query in queries:
|
||||||
print("\n" + "="*60)
|
print("\n" + "=" * 60)
|
||||||
await query_leann_index(index_path, query)
|
await query_leann_index(index_path, query)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
from .history import ChromeHistoryReader
|
from .history import ChromeHistoryReader
|
||||||
|
|
||||||
__all__ = ['ChromeHistoryReader']
|
__all__ = ["ChromeHistoryReader"]
|
||||||
|
|||||||
@@ -1,77 +1,81 @@
|
|||||||
import sqlite3
|
|
||||||
import os
|
import os
|
||||||
|
import sqlite3
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Any
|
from typing import Any
|
||||||
|
|
||||||
from llama_index.core import Document
|
from llama_index.core import Document
|
||||||
from llama_index.core.readers.base import BaseReader
|
from llama_index.core.readers.base import BaseReader
|
||||||
|
|
||||||
|
|
||||||
class ChromeHistoryReader(BaseReader):
|
class ChromeHistoryReader(BaseReader):
|
||||||
"""
|
"""
|
||||||
Chrome browser history reader that extracts browsing data from SQLite database.
|
Chrome browser history reader that extracts browsing data from SQLite database.
|
||||||
|
|
||||||
Reads Chrome history from the default Chrome profile location and creates documents
|
Reads Chrome history from the default Chrome profile location and creates documents
|
||||||
with embedded metadata similar to the email reader structure.
|
with embedded metadata similar to the email reader structure.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
"""Initialize."""
|
"""Initialize."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
|
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
|
||||||
"""
|
"""
|
||||||
Load Chrome history data from the default Chrome profile location.
|
Load Chrome history data from the default Chrome profile location.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_dir: Not used for Chrome history (kept for compatibility)
|
input_dir: Not used for Chrome history (kept for compatibility)
|
||||||
**load_kwargs:
|
**load_kwargs:
|
||||||
max_count (int): Maximum amount of history entries to read.
|
max_count (int): Maximum amount of history entries to read.
|
||||||
chrome_profile_path (str): Custom path to Chrome profile directory.
|
chrome_profile_path (str): Custom path to Chrome profile directory.
|
||||||
"""
|
"""
|
||||||
docs: List[Document] = []
|
docs: list[Document] = []
|
||||||
max_count = load_kwargs.get('max_count', 1000)
|
max_count = load_kwargs.get("max_count", 1000)
|
||||||
chrome_profile_path = load_kwargs.get('chrome_profile_path', None)
|
chrome_profile_path = load_kwargs.get("chrome_profile_path", None)
|
||||||
|
|
||||||
# Default Chrome profile path on macOS
|
# Default Chrome profile path on macOS
|
||||||
if chrome_profile_path is None:
|
if chrome_profile_path is None:
|
||||||
chrome_profile_path = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
|
chrome_profile_path = os.path.expanduser(
|
||||||
|
"~/Library/Application Support/Google/Chrome/Default"
|
||||||
|
)
|
||||||
|
|
||||||
history_db_path = os.path.join(chrome_profile_path, "History")
|
history_db_path = os.path.join(chrome_profile_path, "History")
|
||||||
|
|
||||||
if not os.path.exists(history_db_path):
|
if not os.path.exists(history_db_path):
|
||||||
print(f"Chrome history database not found at: {history_db_path}")
|
print(f"Chrome history database not found at: {history_db_path}")
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Connect to the Chrome history database
|
# Connect to the Chrome history database
|
||||||
print(f"Connecting to database: {history_db_path}")
|
print(f"Connecting to database: {history_db_path}")
|
||||||
conn = sqlite3.connect(history_db_path)
|
conn = sqlite3.connect(history_db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
# Query to get browsing history with metadata (removed created_time column)
|
# Query to get browsing history with metadata (removed created_time column)
|
||||||
query = """
|
query = """
|
||||||
SELECT
|
SELECT
|
||||||
datetime(last_visit_time/1000000-11644473600,'unixepoch','localtime') as last_visit,
|
datetime(last_visit_time/1000000-11644473600,'unixepoch','localtime') as last_visit,
|
||||||
url,
|
url,
|
||||||
title,
|
title,
|
||||||
visit_count,
|
visit_count,
|
||||||
typed_count,
|
typed_count,
|
||||||
hidden
|
hidden
|
||||||
FROM urls
|
FROM urls
|
||||||
ORDER BY last_visit_time DESC
|
ORDER BY last_visit_time DESC
|
||||||
"""
|
"""
|
||||||
|
|
||||||
print(f"Executing query on database: {history_db_path}")
|
print(f"Executing query on database: {history_db_path}")
|
||||||
cursor.execute(query)
|
cursor.execute(query)
|
||||||
rows = cursor.fetchall()
|
rows = cursor.fetchall()
|
||||||
print(f"Query returned {len(rows)} rows")
|
print(f"Query returned {len(rows)} rows")
|
||||||
|
|
||||||
count = 0
|
count = 0
|
||||||
for row in rows:
|
for row in rows:
|
||||||
if count >= max_count and max_count > 0:
|
if count >= max_count and max_count > 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
last_visit, url, title, visit_count, typed_count, hidden = row
|
last_visit, url, title, visit_count, typed_count, hidden = row
|
||||||
|
|
||||||
# Create document content with metadata embedded in text
|
# Create document content with metadata embedded in text
|
||||||
doc_content = f"""
|
doc_content = f"""
|
||||||
[Title]: {title}
|
[Title]: {title}
|
||||||
@@ -80,38 +84,38 @@ class ChromeHistoryReader(BaseReader):
|
|||||||
[Visit times]: {visit_count}
|
[Visit times]: {visit_count}
|
||||||
[Typed times]: {typed_count}
|
[Typed times]: {typed_count}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Create document with embedded metadata
|
# Create document with embedded metadata
|
||||||
doc = Document(text=doc_content, metadata={ "title": title[0:150]})
|
doc = Document(text=doc_content, metadata={"title": title[0:150]})
|
||||||
# if len(title) > 150:
|
# if len(title) > 150:
|
||||||
# print(f"Title is too long: {title}")
|
# print(f"Title is too long: {title}")
|
||||||
docs.append(doc)
|
docs.append(doc)
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
conn.close()
|
conn.close()
|
||||||
print(f"Loaded {len(docs)} Chrome history documents")
|
print(f"Loaded {len(docs)} Chrome history documents")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error reading Chrome history: {e}")
|
print(f"Error reading Chrome history: {e}")
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def find_chrome_profiles() -> List[Path]:
|
def find_chrome_profiles() -> list[Path]:
|
||||||
"""
|
"""
|
||||||
Find all Chrome profile directories.
|
Find all Chrome profile directories.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of Path objects pointing to Chrome profile directories
|
List of Path objects pointing to Chrome profile directories
|
||||||
"""
|
"""
|
||||||
chrome_base_path = Path(os.path.expanduser("~/Library/Application Support/Google/Chrome"))
|
chrome_base_path = Path(os.path.expanduser("~/Library/Application Support/Google/Chrome"))
|
||||||
profile_dirs = []
|
profile_dirs = []
|
||||||
|
|
||||||
if not chrome_base_path.exists():
|
if not chrome_base_path.exists():
|
||||||
print(f"Chrome directory not found at: {chrome_base_path}")
|
print(f"Chrome directory not found at: {chrome_base_path}")
|
||||||
return profile_dirs
|
return profile_dirs
|
||||||
|
|
||||||
# Find all profile directories
|
# Find all profile directories
|
||||||
for profile_dir in chrome_base_path.iterdir():
|
for profile_dir in chrome_base_path.iterdir():
|
||||||
if profile_dir.is_dir() and profile_dir.name != "System Profile":
|
if profile_dir.is_dir() and profile_dir.name != "System Profile":
|
||||||
@@ -119,53 +123,59 @@ class ChromeHistoryReader(BaseReader):
|
|||||||
if history_path.exists():
|
if history_path.exists():
|
||||||
profile_dirs.append(profile_dir)
|
profile_dirs.append(profile_dir)
|
||||||
print(f"Found Chrome profile: {profile_dir}")
|
print(f"Found Chrome profile: {profile_dir}")
|
||||||
|
|
||||||
print(f"Found {len(profile_dirs)} Chrome profiles")
|
print(f"Found {len(profile_dirs)} Chrome profiles")
|
||||||
return profile_dirs
|
return profile_dirs
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def export_history_to_file(output_file: str = "chrome_history_export.txt", max_count: int = 1000):
|
def export_history_to_file(
|
||||||
|
output_file: str = "chrome_history_export.txt", max_count: int = 1000
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Export Chrome history to a text file using the same SQL query format.
|
Export Chrome history to a text file using the same SQL query format.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
output_file: Path to the output file
|
output_file: Path to the output file
|
||||||
max_count: Maximum number of entries to export
|
max_count: Maximum number of entries to export
|
||||||
"""
|
"""
|
||||||
chrome_profile_path = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
|
chrome_profile_path = os.path.expanduser(
|
||||||
|
"~/Library/Application Support/Google/Chrome/Default"
|
||||||
|
)
|
||||||
history_db_path = os.path.join(chrome_profile_path, "History")
|
history_db_path = os.path.join(chrome_profile_path, "History")
|
||||||
|
|
||||||
if not os.path.exists(history_db_path):
|
if not os.path.exists(history_db_path):
|
||||||
print(f"Chrome history database not found at: {history_db_path}")
|
print(f"Chrome history database not found at: {history_db_path}")
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(history_db_path)
|
conn = sqlite3.connect(history_db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
query = """
|
query = """
|
||||||
SELECT
|
SELECT
|
||||||
datetime(last_visit_time/1000000-11644473600,'unixepoch','localtime') as last_visit,
|
datetime(last_visit_time/1000000-11644473600,'unixepoch','localtime') as last_visit,
|
||||||
url,
|
url,
|
||||||
title,
|
title,
|
||||||
visit_count,
|
visit_count,
|
||||||
typed_count,
|
typed_count,
|
||||||
hidden
|
hidden
|
||||||
FROM urls
|
FROM urls
|
||||||
ORDER BY last_visit_time DESC
|
ORDER BY last_visit_time DESC
|
||||||
LIMIT ?
|
LIMIT ?
|
||||||
"""
|
"""
|
||||||
|
|
||||||
cursor.execute(query, (max_count,))
|
cursor.execute(query, (max_count,))
|
||||||
rows = cursor.fetchall()
|
rows = cursor.fetchall()
|
||||||
|
|
||||||
with open(output_file, 'w', encoding='utf-8') as f:
|
with open(output_file, "w", encoding="utf-8") as f:
|
||||||
for row in rows:
|
for row in rows:
|
||||||
last_visit, url, title, visit_count, typed_count, hidden = row
|
last_visit, url, title, visit_count, typed_count, hidden = row
|
||||||
f.write(f"{last_visit}\t{url}\t{title}\t{visit_count}\t{typed_count}\t{hidden}\n")
|
f.write(
|
||||||
|
f"{last_visit}\t{url}\t{title}\t{visit_count}\t{typed_count}\t{hidden}\n"
|
||||||
|
)
|
||||||
|
|
||||||
conn.close()
|
conn.close()
|
||||||
print(f"Exported {len(rows)} history entries to {output_file}")
|
print(f"Exported {len(rows)} history entries to {output_file}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error exporting Chrome history: {e}")
|
print(f"Error exporting Chrome history: {e}")
|
||||||
|
|||||||
@@ -2,30 +2,31 @@ import json
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
|
||||||
import time
|
import time
|
||||||
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Any, Dict, Optional
|
from typing import Any
|
||||||
|
|
||||||
from llama_index.core import Document
|
from llama_index.core import Document
|
||||||
from llama_index.core.readers.base import BaseReader
|
from llama_index.core.readers.base import BaseReader
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
class WeChatHistoryReader(BaseReader):
|
class WeChatHistoryReader(BaseReader):
|
||||||
"""
|
"""
|
||||||
WeChat chat history reader that extracts chat data from exported JSON files.
|
WeChat chat history reader that extracts chat data from exported JSON files.
|
||||||
|
|
||||||
Reads WeChat chat history from exported JSON files (from wechat-exporter tool)
|
Reads WeChat chat history from exported JSON files (from wechat-exporter tool)
|
||||||
and creates documents with embedded metadata similar to the Chrome history reader structure.
|
and creates documents with embedded metadata similar to the Chrome history reader structure.
|
||||||
|
|
||||||
Also includes utilities for automatic WeChat chat history export.
|
Also includes utilities for automatic WeChat chat history export.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
"""Initialize."""
|
"""Initialize."""
|
||||||
self.packages_dir = Path(__file__).parent.parent.parent / "packages"
|
self.packages_dir = Path(__file__).parent.parent.parent / "packages"
|
||||||
self.wechat_exporter_dir = self.packages_dir / "wechat-exporter"
|
self.wechat_exporter_dir = self.packages_dir / "wechat-exporter"
|
||||||
self.wechat_decipher_dir = self.packages_dir / "wechat-decipher-macos"
|
self.wechat_decipher_dir = self.packages_dir / "wechat-decipher-macos"
|
||||||
|
|
||||||
def check_wechat_running(self) -> bool:
|
def check_wechat_running(self) -> bool:
|
||||||
"""Check if WeChat is currently running."""
|
"""Check if WeChat is currently running."""
|
||||||
try:
|
try:
|
||||||
@@ -33,24 +34,30 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
return result.returncode == 0
|
return result.returncode == 0
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def install_wechattweak(self) -> bool:
|
def install_wechattweak(self) -> bool:
|
||||||
"""Install WeChatTweak CLI tool."""
|
"""Install WeChatTweak CLI tool."""
|
||||||
try:
|
try:
|
||||||
# Create wechat-exporter directory if it doesn't exist
|
# Create wechat-exporter directory if it doesn't exist
|
||||||
self.wechat_exporter_dir.mkdir(parents=True, exist_ok=True)
|
self.wechat_exporter_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
wechattweak_path = self.wechat_exporter_dir / "wechattweak-cli"
|
wechattweak_path = self.wechat_exporter_dir / "wechattweak-cli"
|
||||||
if not wechattweak_path.exists():
|
if not wechattweak_path.exists():
|
||||||
print("Downloading WeChatTweak CLI...")
|
print("Downloading WeChatTweak CLI...")
|
||||||
subprocess.run([
|
subprocess.run(
|
||||||
"curl", "-L", "-o", str(wechattweak_path),
|
[
|
||||||
"https://github.com/JettChenT/WeChatTweak-CLI/releases/latest/download/wechattweak-cli"
|
"curl",
|
||||||
], check=True)
|
"-L",
|
||||||
|
"-o",
|
||||||
|
str(wechattweak_path),
|
||||||
|
"https://github.com/JettChenT/WeChatTweak-CLI/releases/latest/download/wechattweak-cli",
|
||||||
|
],
|
||||||
|
check=True,
|
||||||
|
)
|
||||||
|
|
||||||
# Make executable
|
# Make executable
|
||||||
wechattweak_path.chmod(0o755)
|
wechattweak_path.chmod(0o755)
|
||||||
|
|
||||||
# Install WeChatTweak
|
# Install WeChatTweak
|
||||||
print("Installing WeChatTweak...")
|
print("Installing WeChatTweak...")
|
||||||
subprocess.run(["sudo", str(wechattweak_path), "install"], check=True)
|
subprocess.run(["sudo", str(wechattweak_path), "install"], check=True)
|
||||||
@@ -58,7 +65,7 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error installing WeChatTweak: {e}")
|
print(f"Error installing WeChatTweak: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def restart_wechat(self):
|
def restart_wechat(self):
|
||||||
"""Restart WeChat to apply WeChatTweak."""
|
"""Restart WeChat to apply WeChatTweak."""
|
||||||
try:
|
try:
|
||||||
@@ -69,302 +76,325 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
time.sleep(5) # Wait for WeChat to start
|
time.sleep(5) # Wait for WeChat to start
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error restarting WeChat: {e}")
|
print(f"Error restarting WeChat: {e}")
|
||||||
|
|
||||||
def check_api_available(self) -> bool:
|
def check_api_available(self) -> bool:
|
||||||
"""Check if WeChatTweak API is available."""
|
"""Check if WeChatTweak API is available."""
|
||||||
try:
|
try:
|
||||||
result = subprocess.run([
|
result = subprocess.run(
|
||||||
"curl", "-s", "http://localhost:48065/wechat/allcontacts"
|
["curl", "-s", "http://localhost:48065/wechat/allcontacts"],
|
||||||
], capture_output=True, text=True, timeout=5)
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=5,
|
||||||
|
)
|
||||||
return result.returncode == 0 and result.stdout.strip()
|
return result.returncode == 0 and result.stdout.strip()
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_readable_text(self, content: str) -> str:
|
def _extract_readable_text(self, content: str) -> str:
|
||||||
"""
|
"""
|
||||||
Extract readable text from message content, removing XML and system messages.
|
Extract readable text from message content, removing XML and system messages.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
content: The raw message content (can be string or dict)
|
content: The raw message content (can be string or dict)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Cleaned, readable text
|
Cleaned, readable text
|
||||||
"""
|
"""
|
||||||
if not content:
|
if not content:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# Handle dictionary content (like quoted messages)
|
# Handle dictionary content (like quoted messages)
|
||||||
if isinstance(content, dict):
|
if isinstance(content, dict):
|
||||||
# Extract text from dictionary structure
|
# Extract text from dictionary structure
|
||||||
text_parts = []
|
text_parts = []
|
||||||
if 'title' in content:
|
if "title" in content:
|
||||||
text_parts.append(str(content['title']))
|
text_parts.append(str(content["title"]))
|
||||||
if 'quoted' in content:
|
if "quoted" in content:
|
||||||
text_parts.append(str(content['quoted']))
|
text_parts.append(str(content["quoted"]))
|
||||||
if 'content' in content:
|
if "content" in content:
|
||||||
text_parts.append(str(content['content']))
|
text_parts.append(str(content["content"]))
|
||||||
if 'text' in content:
|
if "text" in content:
|
||||||
text_parts.append(str(content['text']))
|
text_parts.append(str(content["text"]))
|
||||||
|
|
||||||
if text_parts:
|
if text_parts:
|
||||||
return " | ".join(text_parts)
|
return " | ".join(text_parts)
|
||||||
else:
|
else:
|
||||||
# If we can't extract meaningful text from dict, return empty
|
# If we can't extract meaningful text from dict, return empty
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# Handle string content
|
# Handle string content
|
||||||
if not isinstance(content, str):
|
if not isinstance(content, str):
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# Remove common prefixes like "wxid_xxx:\n"
|
# Remove common prefixes like "wxid_xxx:\n"
|
||||||
clean_content = re.sub(r'^wxid_[^:]+:\s*', '', content)
|
clean_content = re.sub(r"^wxid_[^:]+:\s*", "", content)
|
||||||
clean_content = re.sub(r'^[^:]+:\s*', '', clean_content)
|
clean_content = re.sub(r"^[^:]+:\s*", "", clean_content)
|
||||||
|
|
||||||
# If it's just XML or system message, return empty
|
# If it's just XML or system message, return empty
|
||||||
if clean_content.strip().startswith('<') or 'recalled a message' in clean_content:
|
if clean_content.strip().startswith("<") or "recalled a message" in clean_content:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
return clean_content.strip()
|
return clean_content.strip()
|
||||||
|
|
||||||
def _is_text_message(self, content: str) -> bool:
|
def _is_text_message(self, content: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if a message contains readable text content.
|
Check if a message contains readable text content.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
content: The message content (can be string or dict)
|
content: The message content (can be string or dict)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if the message contains readable text, False otherwise
|
True if the message contains readable text, False otherwise
|
||||||
"""
|
"""
|
||||||
if not content:
|
if not content:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Handle dictionary content
|
# Handle dictionary content
|
||||||
if isinstance(content, dict):
|
if isinstance(content, dict):
|
||||||
# Check if dict has any readable text fields
|
# Check if dict has any readable text fields
|
||||||
text_fields = ['title', 'quoted', 'content', 'text']
|
text_fields = ["title", "quoted", "content", "text"]
|
||||||
for field in text_fields:
|
for field in text_fields:
|
||||||
if field in content and content[field]:
|
if content.get(field):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Handle string content
|
# Handle string content
|
||||||
if not isinstance(content, str):
|
if not isinstance(content, str):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Skip image messages (contain XML with img tags)
|
# Skip image messages (contain XML with img tags)
|
||||||
if '<img' in content and 'cdnurl' in content:
|
if "<img" in content and "cdnurl" in content:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Skip emoji messages (contain emoji XML tags)
|
# Skip emoji messages (contain emoji XML tags)
|
||||||
if '<emoji' in content and 'productid' in content:
|
if "<emoji" in content and "productid" in content:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Skip voice messages
|
# Skip voice messages
|
||||||
if '<voice' in content:
|
if "<voice" in content:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Skip video messages
|
# Skip video messages
|
||||||
if '<video' in content:
|
if "<video" in content:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Skip file messages
|
# Skip file messages
|
||||||
if '<appmsg' in content and 'appid' in content:
|
if "<appmsg" in content and "appid" in content:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Skip system messages (like "recalled a message")
|
# Skip system messages (like "recalled a message")
|
||||||
if 'recalled a message' in content:
|
if "recalled a message" in content:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Check if there's actual readable text (not just XML or system messages)
|
# Check if there's actual readable text (not just XML or system messages)
|
||||||
# Remove common prefixes like "wxid_xxx:\n" and check for actual content
|
# Remove common prefixes like "wxid_xxx:\n" and check for actual content
|
||||||
clean_content = re.sub(r'^wxid_[^:]+:\s*', '', content)
|
clean_content = re.sub(r"^wxid_[^:]+:\s*", "", content)
|
||||||
clean_content = re.sub(r'^[^:]+:\s*', '', clean_content)
|
clean_content = re.sub(r"^[^:]+:\s*", "", clean_content)
|
||||||
|
|
||||||
# If after cleaning we have meaningful text, consider it readable
|
# If after cleaning we have meaningful text, consider it readable
|
||||||
if len(clean_content.strip()) > 0 and not clean_content.strip().startswith('<'):
|
if len(clean_content.strip()) > 0 and not clean_content.strip().startswith("<"):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _concatenate_messages(self, messages: List[Dict], max_length: int = 128,
|
def _concatenate_messages(
|
||||||
time_window_minutes: int = 30, overlap_messages: int = 0) -> List[Dict]:
|
self,
|
||||||
|
messages: list[dict],
|
||||||
|
max_length: int = 128,
|
||||||
|
time_window_minutes: int = 30,
|
||||||
|
overlap_messages: int = 0,
|
||||||
|
) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
Concatenate messages based on length and time rules.
|
Concatenate messages based on length and time rules.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages: List of message dictionaries
|
messages: List of message dictionaries
|
||||||
max_length: Maximum length for concatenated message groups. Use -1 to disable length constraint.
|
max_length: Maximum length for concatenated message groups. Use -1 to disable length constraint.
|
||||||
time_window_minutes: Time window in minutes to group messages together. Use -1 to disable time constraint.
|
time_window_minutes: Time window in minutes to group messages together. Use -1 to disable time constraint.
|
||||||
overlap_messages: Number of messages to overlap between consecutive groups
|
overlap_messages: Number of messages to overlap between consecutive groups
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of concatenated message groups
|
List of concatenated message groups
|
||||||
"""
|
"""
|
||||||
if not messages:
|
if not messages:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
concatenated_groups = []
|
concatenated_groups = []
|
||||||
current_group = []
|
current_group = []
|
||||||
current_length = 0
|
current_length = 0
|
||||||
last_timestamp = None
|
last_timestamp = None
|
||||||
|
|
||||||
for message in messages:
|
for message in messages:
|
||||||
# Extract message info
|
# Extract message info
|
||||||
content = message.get('content', '')
|
content = message.get("content", "")
|
||||||
message_text = message.get('message', '')
|
message_text = message.get("message", "")
|
||||||
create_time = message.get('createTime', 0)
|
create_time = message.get("createTime", 0)
|
||||||
from_user = message.get('fromUser', '')
|
message.get("fromUser", "")
|
||||||
to_user = message.get('toUser', '')
|
message.get("toUser", "")
|
||||||
is_sent_from_self = message.get('isSentFromSelf', False)
|
message.get("isSentFromSelf", False)
|
||||||
|
|
||||||
# Extract readable text
|
# Extract readable text
|
||||||
readable_text = self._extract_readable_text(content)
|
readable_text = self._extract_readable_text(content)
|
||||||
if not readable_text:
|
if not readable_text:
|
||||||
readable_text = message_text
|
readable_text = message_text
|
||||||
|
|
||||||
# Skip empty messages
|
# Skip empty messages
|
||||||
if not readable_text.strip():
|
if not readable_text.strip():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Check time window constraint (only if time_window_minutes != -1)
|
# Check time window constraint (only if time_window_minutes != -1)
|
||||||
if time_window_minutes != -1 and last_timestamp is not None and create_time > 0:
|
if time_window_minutes != -1 and last_timestamp is not None and create_time > 0:
|
||||||
time_diff_minutes = (create_time - last_timestamp) / 60
|
time_diff_minutes = (create_time - last_timestamp) / 60
|
||||||
if time_diff_minutes > time_window_minutes:
|
if time_diff_minutes > time_window_minutes:
|
||||||
# Time gap too large, start new group
|
# Time gap too large, start new group
|
||||||
if current_group:
|
if current_group:
|
||||||
concatenated_groups.append({
|
concatenated_groups.append(
|
||||||
'messages': current_group,
|
{
|
||||||
'total_length': current_length,
|
"messages": current_group,
|
||||||
'start_time': current_group[0].get('createTime', 0),
|
"total_length": current_length,
|
||||||
'end_time': current_group[-1].get('createTime', 0)
|
"start_time": current_group[0].get("createTime", 0),
|
||||||
})
|
"end_time": current_group[-1].get("createTime", 0),
|
||||||
|
}
|
||||||
|
)
|
||||||
# Keep last few messages for overlap
|
# Keep last few messages for overlap
|
||||||
if overlap_messages > 0 and len(current_group) > overlap_messages:
|
if overlap_messages > 0 and len(current_group) > overlap_messages:
|
||||||
current_group = current_group[-overlap_messages:]
|
current_group = current_group[-overlap_messages:]
|
||||||
current_length = sum(len(self._extract_readable_text(msg.get('content', '')) or msg.get('message', '')) for msg in current_group)
|
current_length = sum(
|
||||||
|
len(
|
||||||
|
self._extract_readable_text(msg.get("content", ""))
|
||||||
|
or msg.get("message", "")
|
||||||
|
)
|
||||||
|
for msg in current_group
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
current_group = []
|
current_group = []
|
||||||
current_length = 0
|
current_length = 0
|
||||||
|
|
||||||
# Check length constraint (only if max_length != -1)
|
# Check length constraint (only if max_length != -1)
|
||||||
message_length = len(readable_text)
|
message_length = len(readable_text)
|
||||||
if max_length != -1 and current_length + message_length > max_length and current_group:
|
if max_length != -1 and current_length + message_length > max_length and current_group:
|
||||||
# Current group would exceed max length, save it and start new
|
# Current group would exceed max length, save it and start new
|
||||||
concatenated_groups.append({
|
concatenated_groups.append(
|
||||||
'messages': current_group,
|
{
|
||||||
'total_length': current_length,
|
"messages": current_group,
|
||||||
'start_time': current_group[0].get('createTime', 0),
|
"total_length": current_length,
|
||||||
'end_time': current_group[-1].get('createTime', 0)
|
"start_time": current_group[0].get("createTime", 0),
|
||||||
})
|
"end_time": current_group[-1].get("createTime", 0),
|
||||||
|
}
|
||||||
|
)
|
||||||
# Keep last few messages for overlap
|
# Keep last few messages for overlap
|
||||||
if overlap_messages > 0 and len(current_group) > overlap_messages:
|
if overlap_messages > 0 and len(current_group) > overlap_messages:
|
||||||
current_group = current_group[-overlap_messages:]
|
current_group = current_group[-overlap_messages:]
|
||||||
current_length = sum(len(self._extract_readable_text(msg.get('content', '')) or msg.get('message', '')) for msg in current_group)
|
current_length = sum(
|
||||||
|
len(
|
||||||
|
self._extract_readable_text(msg.get("content", ""))
|
||||||
|
or msg.get("message", "")
|
||||||
|
)
|
||||||
|
for msg in current_group
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
current_group = []
|
current_group = []
|
||||||
current_length = 0
|
current_length = 0
|
||||||
|
|
||||||
# Add message to current group
|
# Add message to current group
|
||||||
current_group.append(message)
|
current_group.append(message)
|
||||||
current_length += message_length
|
current_length += message_length
|
||||||
last_timestamp = create_time
|
last_timestamp = create_time
|
||||||
|
|
||||||
# Add the last group if it exists
|
# Add the last group if it exists
|
||||||
if current_group:
|
if current_group:
|
||||||
concatenated_groups.append({
|
concatenated_groups.append(
|
||||||
'messages': current_group,
|
{
|
||||||
'total_length': current_length,
|
"messages": current_group,
|
||||||
'start_time': current_group[0].get('createTime', 0),
|
"total_length": current_length,
|
||||||
'end_time': current_group[-1].get('createTime', 0)
|
"start_time": current_group[0].get("createTime", 0),
|
||||||
})
|
"end_time": current_group[-1].get("createTime", 0),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return concatenated_groups
|
return concatenated_groups
|
||||||
|
|
||||||
def _create_concatenated_content(self, message_group: Dict, contact_name: str) -> str:
|
def _create_concatenated_content(self, message_group: dict, contact_name: str) -> str:
|
||||||
"""
|
"""
|
||||||
Create concatenated content from a group of messages.
|
Create concatenated content from a group of messages.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message_group: Dictionary containing messages and metadata
|
message_group: Dictionary containing messages and metadata
|
||||||
contact_name: Name of the contact
|
contact_name: Name of the contact
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Formatted concatenated content
|
Formatted concatenated content
|
||||||
"""
|
"""
|
||||||
messages = message_group['messages']
|
messages = message_group["messages"]
|
||||||
start_time = message_group['start_time']
|
start_time = message_group["start_time"]
|
||||||
end_time = message_group['end_time']
|
end_time = message_group["end_time"]
|
||||||
|
|
||||||
# Format timestamps
|
# Format timestamps
|
||||||
if start_time:
|
if start_time:
|
||||||
try:
|
try:
|
||||||
start_timestamp = datetime.fromtimestamp(start_time)
|
start_timestamp = datetime.fromtimestamp(start_time)
|
||||||
start_time_str = start_timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
start_time_str = start_timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
except:
|
except (ValueError, OSError):
|
||||||
start_time_str = str(start_time)
|
start_time_str = str(start_time)
|
||||||
else:
|
else:
|
||||||
start_time_str = "Unknown"
|
start_time_str = "Unknown"
|
||||||
|
|
||||||
if end_time:
|
if end_time:
|
||||||
try:
|
try:
|
||||||
end_timestamp = datetime.fromtimestamp(end_time)
|
end_timestamp = datetime.fromtimestamp(end_time)
|
||||||
end_time_str = end_timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
end_time_str = end_timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
except:
|
except (ValueError, OSError):
|
||||||
end_time_str = str(end_time)
|
end_time_str = str(end_time)
|
||||||
else:
|
else:
|
||||||
end_time_str = "Unknown"
|
end_time_str = "Unknown"
|
||||||
|
|
||||||
# Build concatenated message content
|
# Build concatenated message content
|
||||||
message_parts = []
|
message_parts = []
|
||||||
for message in messages:
|
for message in messages:
|
||||||
content = message.get('content', '')
|
content = message.get("content", "")
|
||||||
message_text = message.get('message', '')
|
message_text = message.get("message", "")
|
||||||
create_time = message.get('createTime', 0)
|
create_time = message.get("createTime", 0)
|
||||||
is_sent_from_self = message.get('isSentFromSelf', False)
|
is_sent_from_self = message.get("isSentFromSelf", False)
|
||||||
|
|
||||||
# Extract readable text
|
# Extract readable text
|
||||||
readable_text = self._extract_readable_text(content)
|
readable_text = self._extract_readable_text(content)
|
||||||
if not readable_text:
|
if not readable_text:
|
||||||
readable_text = message_text
|
readable_text = message_text
|
||||||
|
|
||||||
# Format individual message
|
# Format individual message
|
||||||
if create_time:
|
if create_time:
|
||||||
try:
|
try:
|
||||||
timestamp = datetime.fromtimestamp(create_time)
|
timestamp = datetime.fromtimestamp(create_time)
|
||||||
# change to YYYY-MM-DD HH:MM:SS
|
# change to YYYY-MM-DD HH:MM:SS
|
||||||
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
except:
|
except (ValueError, OSError):
|
||||||
time_str = str(create_time)
|
time_str = str(create_time)
|
||||||
else:
|
else:
|
||||||
time_str = "Unknown"
|
time_str = "Unknown"
|
||||||
|
|
||||||
sender = "[Me]" if is_sent_from_self else "[Contact]"
|
sender = "[Me]" if is_sent_from_self else "[Contact]"
|
||||||
message_parts.append(f"({time_str}) {sender}: {readable_text}")
|
message_parts.append(f"({time_str}) {sender}: {readable_text}")
|
||||||
|
|
||||||
concatenated_text = "\n".join(message_parts)
|
concatenated_text = "\n".join(message_parts)
|
||||||
|
|
||||||
# Create final document content
|
# Create final document content
|
||||||
doc_content = f"""
|
doc_content = f"""
|
||||||
Contact: {contact_name}
|
Contact: {contact_name}
|
||||||
Time Range: {start_time_str} - {end_time_str}
|
Time Range: {start_time_str} - {end_time_str}
|
||||||
Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
Messages ({len(messages)} messages, {message_group["total_length"]} chars):
|
||||||
|
|
||||||
{concatenated_text}
|
{concatenated_text}
|
||||||
"""
|
"""
|
||||||
# TODO @yichuan give better format and rich info here!
|
# TODO @yichuan give better format and rich info here!
|
||||||
doc_content = f"""
|
doc_content = f"""
|
||||||
{concatenated_text}
|
{concatenated_text}
|
||||||
"""
|
"""
|
||||||
return doc_content, contact_name
|
return doc_content, contact_name
|
||||||
|
|
||||||
def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
|
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
|
||||||
"""
|
"""
|
||||||
Load WeChat chat history data from exported JSON files.
|
Load WeChat chat history data from exported JSON files.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_dir: Directory containing exported WeChat JSON files
|
input_dir: Directory containing exported WeChat JSON files
|
||||||
**load_kwargs:
|
**load_kwargs:
|
||||||
@@ -376,97 +406,104 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
|||||||
time_window_minutes (int): Time window in minutes to group messages together (default: 30).
|
time_window_minutes (int): Time window in minutes to group messages together (default: 30).
|
||||||
overlap_messages (int): Number of messages to overlap between consecutive groups (default: 2).
|
overlap_messages (int): Number of messages to overlap between consecutive groups (default: 2).
|
||||||
"""
|
"""
|
||||||
docs: List[Document] = []
|
docs: list[Document] = []
|
||||||
max_count = load_kwargs.get('max_count', 1000)
|
max_count = load_kwargs.get("max_count", 1000)
|
||||||
wechat_export_dir = load_kwargs.get('wechat_export_dir', None)
|
wechat_export_dir = load_kwargs.get("wechat_export_dir", None)
|
||||||
include_non_text = load_kwargs.get('include_non_text', False)
|
include_non_text = load_kwargs.get("include_non_text", False)
|
||||||
concatenate_messages = load_kwargs.get('concatenate_messages', False)
|
concatenate_messages = load_kwargs.get("concatenate_messages", False)
|
||||||
max_length = load_kwargs.get('max_length', 1000)
|
load_kwargs.get("max_length", 1000)
|
||||||
time_window_minutes = load_kwargs.get('time_window_minutes', 30)
|
load_kwargs.get("time_window_minutes", 30)
|
||||||
|
|
||||||
# Default WeChat export path
|
# Default WeChat export path
|
||||||
if wechat_export_dir is None:
|
if wechat_export_dir is None:
|
||||||
wechat_export_dir = "./wechat_export_test"
|
wechat_export_dir = "./wechat_export_test"
|
||||||
|
|
||||||
if not os.path.exists(wechat_export_dir):
|
if not os.path.exists(wechat_export_dir):
|
||||||
print(f"WeChat export directory not found at: {wechat_export_dir}")
|
print(f"WeChat export directory not found at: {wechat_export_dir}")
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Find all JSON files in the export directory
|
# Find all JSON files in the export directory
|
||||||
json_files = list(Path(wechat_export_dir).glob("*.json"))
|
json_files = list(Path(wechat_export_dir).glob("*.json"))
|
||||||
print(f"Found {len(json_files)} WeChat chat history files")
|
print(f"Found {len(json_files)} WeChat chat history files")
|
||||||
|
|
||||||
count = 0
|
count = 0
|
||||||
for json_file in json_files:
|
for json_file in json_files:
|
||||||
if count >= max_count and max_count > 0:
|
if count >= max_count and max_count > 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(json_file, 'r', encoding='utf-8') as f:
|
with open(json_file, encoding="utf-8") as f:
|
||||||
chat_data = json.load(f)
|
chat_data = json.load(f)
|
||||||
|
|
||||||
# Extract contact name from filename
|
# Extract contact name from filename
|
||||||
contact_name = json_file.stem
|
contact_name = json_file.stem
|
||||||
|
|
||||||
if concatenate_messages:
|
if concatenate_messages:
|
||||||
# Filter messages to only include readable text messages
|
# Filter messages to only include readable text messages
|
||||||
readable_messages = []
|
readable_messages = []
|
||||||
for message in chat_data:
|
for message in chat_data:
|
||||||
try:
|
try:
|
||||||
content = message.get('content', '')
|
content = message.get("content", "")
|
||||||
if not include_non_text and not self._is_text_message(content):
|
if not include_non_text and not self._is_text_message(content):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
readable_text = self._extract_readable_text(content)
|
readable_text = self._extract_readable_text(content)
|
||||||
if not readable_text and not include_non_text:
|
if not readable_text and not include_non_text:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
readable_messages.append(message)
|
readable_messages.append(message)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error processing message in {json_file}: {e}")
|
print(f"Error processing message in {json_file}: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Concatenate messages based on rules
|
# Concatenate messages based on rules
|
||||||
message_groups = self._concatenate_messages(
|
message_groups = self._concatenate_messages(
|
||||||
readable_messages,
|
readable_messages,
|
||||||
max_length=-1,
|
max_length=-1,
|
||||||
time_window_minutes=-1,
|
time_window_minutes=-1,
|
||||||
overlap_messages=0 # Keep 2 messages overlap between groups
|
overlap_messages=0, # Keep 2 messages overlap between groups
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create documents from concatenated groups
|
# Create documents from concatenated groups
|
||||||
for message_group in message_groups:
|
for message_group in message_groups:
|
||||||
if count >= max_count and max_count > 0:
|
if count >= max_count and max_count > 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
doc_content, contact_name = self._create_concatenated_content(message_group, contact_name)
|
doc_content, contact_name = self._create_concatenated_content(
|
||||||
doc = Document(text=doc_content, metadata={"contact_name": contact_name})
|
message_group, contact_name
|
||||||
|
)
|
||||||
|
doc = Document(
|
||||||
|
text=doc_content,
|
||||||
|
metadata={"contact_name": contact_name},
|
||||||
|
)
|
||||||
docs.append(doc)
|
docs.append(doc)
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
print(f"Created {len(message_groups)} concatenated message groups for {contact_name}")
|
print(
|
||||||
|
f"Created {len(message_groups)} concatenated message groups for {contact_name}"
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Original single-message processing
|
# Original single-message processing
|
||||||
for message in chat_data:
|
for message in chat_data:
|
||||||
if count >= max_count and max_count > 0:
|
if count >= max_count and max_count > 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
# Extract message information
|
# Extract message information
|
||||||
from_user = message.get('fromUser', '')
|
message.get("fromUser", "")
|
||||||
to_user = message.get('toUser', '')
|
message.get("toUser", "")
|
||||||
content = message.get('content', '')
|
content = message.get("content", "")
|
||||||
message_text = message.get('message', '')
|
message_text = message.get("message", "")
|
||||||
create_time = message.get('createTime', 0)
|
create_time = message.get("createTime", 0)
|
||||||
is_sent_from_self = message.get('isSentFromSelf', False)
|
is_sent_from_self = message.get("isSentFromSelf", False)
|
||||||
|
|
||||||
# Handle content that might be dict or string
|
# Handle content that might be dict or string
|
||||||
try:
|
try:
|
||||||
# Check if this is a readable text message
|
# Check if this is a readable text message
|
||||||
if not include_non_text and not self._is_text_message(content):
|
if not include_non_text and not self._is_text_message(content):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Extract readable text
|
# Extract readable text
|
||||||
readable_text = self._extract_readable_text(content)
|
readable_text = self._extract_readable_text(content)
|
||||||
if not readable_text and not include_non_text:
|
if not readable_text and not include_non_text:
|
||||||
@@ -475,17 +512,17 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
|||||||
# Skip messages that cause processing errors
|
# Skip messages that cause processing errors
|
||||||
print(f"Error processing message in {json_file}: {e}")
|
print(f"Error processing message in {json_file}: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Convert timestamp to readable format
|
# Convert timestamp to readable format
|
||||||
if create_time:
|
if create_time:
|
||||||
try:
|
try:
|
||||||
timestamp = datetime.fromtimestamp(create_time)
|
timestamp = datetime.fromtimestamp(create_time)
|
||||||
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
except:
|
except (ValueError, OSError):
|
||||||
time_str = str(create_time)
|
time_str = str(create_time)
|
||||||
else:
|
else:
|
||||||
time_str = "Unknown"
|
time_str = "Unknown"
|
||||||
|
|
||||||
# Create document content with metadata header and contact info
|
# Create document content with metadata header and contact info
|
||||||
doc_content = f"""
|
doc_content = f"""
|
||||||
Contact: {contact_name}
|
Contact: {contact_name}
|
||||||
@@ -493,57 +530,64 @@ Is sent from self: {is_sent_from_self}
|
|||||||
Time: {time_str}
|
Time: {time_str}
|
||||||
Message: {readable_text if readable_text else message_text}
|
Message: {readable_text if readable_text else message_text}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Create document with embedded metadata
|
# Create document with embedded metadata
|
||||||
doc = Document(text=doc_content, metadata={})
|
doc = Document(text=doc_content, metadata={})
|
||||||
docs.append(doc)
|
docs.append(doc)
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error reading {json_file}: {e}")
|
print(f"Error reading {json_file}: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
print(f"Loaded {len(docs)} WeChat chat documents")
|
print(f"Loaded {len(docs)} WeChat chat documents")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error reading WeChat history: {e}")
|
print(f"Error reading WeChat history: {e}")
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def find_wechat_export_dirs() -> List[Path]:
|
def find_wechat_export_dirs() -> list[Path]:
|
||||||
"""
|
"""
|
||||||
Find all WeChat export directories.
|
Find all WeChat export directories.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of Path objects pointing to WeChat export directories
|
List of Path objects pointing to WeChat export directories
|
||||||
"""
|
"""
|
||||||
export_dirs = []
|
export_dirs = []
|
||||||
|
|
||||||
# Look for common export directory names
|
# Look for common export directory names
|
||||||
possible_dirs = [
|
possible_dirs = [
|
||||||
Path("./wechat_export_test"),
|
Path("./wechat_export_test"),
|
||||||
Path("./wechat_export"),
|
Path("./wechat_export"),
|
||||||
Path("./wechat_chat_history"),
|
Path("./wechat_chat_history"),
|
||||||
Path("./chat_export")
|
Path("./chat_export"),
|
||||||
]
|
]
|
||||||
|
|
||||||
for export_dir in possible_dirs:
|
for export_dir in possible_dirs:
|
||||||
if export_dir.exists() and export_dir.is_dir():
|
if export_dir.exists() and export_dir.is_dir():
|
||||||
json_files = list(export_dir.glob("*.json"))
|
json_files = list(export_dir.glob("*.json"))
|
||||||
if json_files:
|
if json_files:
|
||||||
export_dirs.append(export_dir)
|
export_dirs.append(export_dir)
|
||||||
print(f"Found WeChat export directory: {export_dir} with {len(json_files)} files")
|
print(
|
||||||
|
f"Found WeChat export directory: {export_dir} with {len(json_files)} files"
|
||||||
|
)
|
||||||
|
|
||||||
print(f"Found {len(export_dirs)} WeChat export directories")
|
print(f"Found {len(export_dirs)} WeChat export directories")
|
||||||
return export_dirs
|
return export_dirs
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def export_chat_to_file(output_file: str = "wechat_chat_export.txt", max_count: int = 1000, export_dir: str = None, include_non_text: bool = False):
|
def export_chat_to_file(
|
||||||
|
output_file: str = "wechat_chat_export.txt",
|
||||||
|
max_count: int = 1000,
|
||||||
|
export_dir: str | None = None,
|
||||||
|
include_non_text: bool = False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Export WeChat chat history to a text file.
|
Export WeChat chat history to a text file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
output_file: Path to the output file
|
output_file: Path to the output file
|
||||||
max_count: Maximum number of entries to export
|
max_count: Maximum number of entries to export
|
||||||
@@ -552,36 +596,36 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
"""
|
"""
|
||||||
if export_dir is None:
|
if export_dir is None:
|
||||||
export_dir = "./wechat_export_test"
|
export_dir = "./wechat_export_test"
|
||||||
|
|
||||||
if not os.path.exists(export_dir):
|
if not os.path.exists(export_dir):
|
||||||
print(f"WeChat export directory not found at: {export_dir}")
|
print(f"WeChat export directory not found at: {export_dir}")
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
json_files = list(Path(export_dir).glob("*.json"))
|
json_files = list(Path(export_dir).glob("*.json"))
|
||||||
|
|
||||||
with open(output_file, 'w', encoding='utf-8') as f:
|
with open(output_file, "w", encoding="utf-8") as f:
|
||||||
count = 0
|
count = 0
|
||||||
for json_file in json_files:
|
for json_file in json_files:
|
||||||
if count >= max_count and max_count > 0:
|
if count >= max_count and max_count > 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(json_file, 'r', encoding='utf-8') as json_f:
|
with open(json_file, encoding="utf-8") as json_f:
|
||||||
chat_data = json.load(json_f)
|
chat_data = json.load(json_f)
|
||||||
|
|
||||||
contact_name = json_file.stem
|
contact_name = json_file.stem
|
||||||
f.write(f"\n=== Chat with {contact_name} ===\n")
|
f.write(f"\n=== Chat with {contact_name} ===\n")
|
||||||
|
|
||||||
for message in chat_data:
|
for message in chat_data:
|
||||||
if count >= max_count and max_count > 0:
|
if count >= max_count and max_count > 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
from_user = message.get('fromUser', '')
|
from_user = message.get("fromUser", "")
|
||||||
content = message.get('content', '')
|
content = message.get("content", "")
|
||||||
message_text = message.get('message', '')
|
message_text = message.get("message", "")
|
||||||
create_time = message.get('createTime', 0)
|
create_time = message.get("createTime", 0)
|
||||||
|
|
||||||
# Skip non-text messages unless requested
|
# Skip non-text messages unless requested
|
||||||
if not include_non_text:
|
if not include_non_text:
|
||||||
reader = WeChatHistoryReader()
|
reader = WeChatHistoryReader()
|
||||||
@@ -591,83 +635,90 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
if not readable_text:
|
if not readable_text:
|
||||||
continue
|
continue
|
||||||
message_text = readable_text
|
message_text = readable_text
|
||||||
|
|
||||||
if create_time:
|
if create_time:
|
||||||
try:
|
try:
|
||||||
timestamp = datetime.fromtimestamp(create_time)
|
timestamp = datetime.fromtimestamp(create_time)
|
||||||
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
except:
|
except (ValueError, OSError):
|
||||||
time_str = str(create_time)
|
time_str = str(create_time)
|
||||||
else:
|
else:
|
||||||
time_str = "Unknown"
|
time_str = "Unknown"
|
||||||
|
|
||||||
f.write(f"[{time_str}] {from_user}: {message_text}\n")
|
f.write(f"[{time_str}] {from_user}: {message_text}\n")
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error processing {json_file}: {e}")
|
print(f"Error processing {json_file}: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
print(f"Exported {count} chat entries to {output_file}")
|
print(f"Exported {count} chat entries to {output_file}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error exporting WeChat chat history: {e}")
|
print(f"Error exporting WeChat chat history: {e}")
|
||||||
|
|
||||||
def export_wechat_chat_history(self, export_dir: str = "./wechat_export_direct") -> Optional[Path]:
|
def export_wechat_chat_history(self, export_dir: str = "./wechat_export_direct") -> Path | None:
|
||||||
"""
|
"""
|
||||||
Export WeChat chat history using wechat-exporter tool.
|
Export WeChat chat history using wechat-exporter tool.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
export_dir: Directory to save exported chat history
|
export_dir: Directory to save exported chat history
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Path to export directory if successful, None otherwise
|
Path to export directory if successful, None otherwise
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
# Create export directory
|
# Create export directory
|
||||||
export_path = Path(export_dir)
|
export_path = Path(export_dir)
|
||||||
export_path.mkdir(exist_ok=True)
|
export_path.mkdir(exist_ok=True)
|
||||||
|
|
||||||
print(f"Exporting WeChat chat history to {export_path}...")
|
print(f"Exporting WeChat chat history to {export_path}...")
|
||||||
|
|
||||||
# Check if wechat-exporter directory exists
|
# Check if wechat-exporter directory exists
|
||||||
if not self.wechat_exporter_dir.exists():
|
if not self.wechat_exporter_dir.exists():
|
||||||
print(f"wechat-exporter directory not found at: {self.wechat_exporter_dir}")
|
print(f"wechat-exporter directory not found at: {self.wechat_exporter_dir}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Install requirements if needed
|
# Install requirements if needed
|
||||||
requirements_file = self.wechat_exporter_dir / "requirements.txt"
|
requirements_file = self.wechat_exporter_dir / "requirements.txt"
|
||||||
if requirements_file.exists():
|
if requirements_file.exists():
|
||||||
print("Installing wechat-exporter requirements...")
|
print("Installing wechat-exporter requirements...")
|
||||||
subprocess.run([
|
subprocess.run(["uv", "pip", "install", "-r", str(requirements_file)], check=True)
|
||||||
"uv", "pip", "install", "-r", str(requirements_file)
|
|
||||||
], check=True)
|
|
||||||
|
|
||||||
# Run the export command
|
# Run the export command
|
||||||
print("Running wechat-exporter...")
|
print("Running wechat-exporter...")
|
||||||
result = subprocess.run([
|
result = subprocess.run(
|
||||||
sys.executable, str(self.wechat_exporter_dir / "main.py"),
|
[
|
||||||
"export-all", str(export_path)
|
sys.executable,
|
||||||
], capture_output=True, text=True, check=True)
|
str(self.wechat_exporter_dir / "main.py"),
|
||||||
|
"export-all",
|
||||||
|
str(export_path),
|
||||||
|
],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
check=True,
|
||||||
|
)
|
||||||
|
|
||||||
print("Export command output:")
|
print("Export command output:")
|
||||||
print(result.stdout)
|
print(result.stdout)
|
||||||
if result.stderr:
|
if result.stderr:
|
||||||
print("Export errors:")
|
print("Export errors:")
|
||||||
print(result.stderr)
|
print(result.stderr)
|
||||||
|
|
||||||
# Check if export was successful
|
# Check if export was successful
|
||||||
if export_path.exists() and any(export_path.glob("*.json")):
|
if export_path.exists() and any(export_path.glob("*.json")):
|
||||||
json_files = list(export_path.glob("*.json"))
|
json_files = list(export_path.glob("*.json"))
|
||||||
print(f"Successfully exported {len(json_files)} chat history files to {export_path}")
|
print(
|
||||||
|
f"Successfully exported {len(json_files)} chat history files to {export_path}"
|
||||||
|
)
|
||||||
return export_path
|
return export_path
|
||||||
else:
|
else:
|
||||||
print("Export completed but no JSON files found")
|
print("Export completed but no JSON files found")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
except subprocess.CalledProcessError as e:
|
except subprocess.CalledProcessError as e:
|
||||||
print(f"Export command failed: {e}")
|
print(f"Export command failed: {e}")
|
||||||
print(f"Command output: {e.stdout}")
|
print(f"Command output: {e.stdout}")
|
||||||
@@ -678,18 +729,18 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
print("Please ensure WeChat is running and WeChatTweak is installed.")
|
print("Please ensure WeChat is running and WeChatTweak is installed.")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def find_or_export_wechat_data(self, export_dir: str = "./wechat_export_direct") -> List[Path]:
|
def find_or_export_wechat_data(self, export_dir: str = "./wechat_export_direct") -> list[Path]:
|
||||||
"""
|
"""
|
||||||
Find existing WeChat exports or create new ones.
|
Find existing WeChat exports or create new ones.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
export_dir: Directory to save exported chat history if needed
|
export_dir: Directory to save exported chat history if needed
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of Path objects pointing to WeChat export directories
|
List of Path objects pointing to WeChat export directories
|
||||||
"""
|
"""
|
||||||
export_dirs = []
|
export_dirs = []
|
||||||
|
|
||||||
# Look for existing exports in common locations
|
# Look for existing exports in common locations
|
||||||
possible_export_dirs = [
|
possible_export_dirs = [
|
||||||
Path("./wechat_database_export"),
|
Path("./wechat_database_export"),
|
||||||
@@ -697,23 +748,25 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
Path("./wechat_export"),
|
Path("./wechat_export"),
|
||||||
Path("./wechat_export_direct"),
|
Path("./wechat_export_direct"),
|
||||||
Path("./wechat_chat_history"),
|
Path("./wechat_chat_history"),
|
||||||
Path("./chat_export")
|
Path("./chat_export"),
|
||||||
]
|
]
|
||||||
|
|
||||||
for export_dir_path in possible_export_dirs:
|
for export_dir_path in possible_export_dirs:
|
||||||
if export_dir_path.exists() and any(export_dir_path.glob("*.json")):
|
if export_dir_path.exists() and any(export_dir_path.glob("*.json")):
|
||||||
export_dirs.append(export_dir_path)
|
export_dirs.append(export_dir_path)
|
||||||
print(f"Found existing export: {export_dir_path}")
|
print(f"Found existing export: {export_dir_path}")
|
||||||
|
|
||||||
# If no existing exports, try to export automatically
|
# If no existing exports, try to export automatically
|
||||||
if not export_dirs:
|
if not export_dirs:
|
||||||
print("No existing WeChat exports found. Starting direct export...")
|
print("No existing WeChat exports found. Starting direct export...")
|
||||||
|
|
||||||
# Try to export using wechat-exporter
|
# Try to export using wechat-exporter
|
||||||
exported_path = self.export_wechat_chat_history(export_dir)
|
exported_path = self.export_wechat_chat_history(export_dir)
|
||||||
if exported_path:
|
if exported_path:
|
||||||
export_dirs = [exported_path]
|
export_dirs = [exported_path]
|
||||||
else:
|
else:
|
||||||
print("Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.")
|
print(
|
||||||
|
"Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed."
|
||||||
return export_dirs
|
)
|
||||||
|
|
||||||
|
return export_dirs
|
||||||
|
|||||||
@@ -1,33 +1,42 @@
|
|||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import asyncio
|
|
||||||
import dotenv
|
|
||||||
import argparse
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Any
|
|
||||||
|
import dotenv
|
||||||
|
|
||||||
# Add the project root to Python path so we can import from examples
|
# Add the project root to Python path so we can import from examples
|
||||||
project_root = Path(__file__).parent.parent
|
project_root = Path(__file__).parent.parent
|
||||||
sys.path.insert(0, str(project_root))
|
sys.path.insert(0, str(project_root))
|
||||||
|
|
||||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
from leann.api import LeannBuilder, LeannChat
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
dotenv.load_dotenv()
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
# Auto-detect user's mail path
|
# Auto-detect user's mail path
|
||||||
def get_mail_path():
|
def get_mail_path():
|
||||||
"""Get the mail path for the current user"""
|
"""Get the mail path for the current user"""
|
||||||
home_dir = os.path.expanduser("~")
|
home_dir = os.path.expanduser("~")
|
||||||
return os.path.join(home_dir, "Library", "Mail")
|
return os.path.join(home_dir, "Library", "Mail")
|
||||||
|
|
||||||
# Default mail path for macOS
|
|
||||||
# DEFAULT_MAIL_PATH = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data"
|
|
||||||
|
|
||||||
def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_path: str = "mail_index.leann", max_count: int = -1, include_html: bool = False, embedding_model: str = "facebook/contriever"):
|
# Default mail path for macOS
|
||||||
|
DEFAULT_MAIL_PATH = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data"
|
||||||
|
|
||||||
|
|
||||||
|
def create_leann_index_from_multiple_sources(
|
||||||
|
messages_dirs: list[Path],
|
||||||
|
index_path: str = "mail_index.leann",
|
||||||
|
max_count: int = -1,
|
||||||
|
include_html: bool = False,
|
||||||
|
embedding_model: str = "facebook/contriever",
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Create LEANN index from multiple mail data sources.
|
Create LEANN index from multiple mail data sources.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages_dirs: List of Path objects pointing to Messages directories
|
messages_dirs: List of Path objects pointing to Messages directories
|
||||||
index_path: Path to save the LEANN index
|
index_path: Path to save the LEANN index
|
||||||
@@ -35,31 +44,32 @@ def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_pa
|
|||||||
include_html: Whether to include HTML content in email processing
|
include_html: Whether to include HTML content in email processing
|
||||||
"""
|
"""
|
||||||
print("Creating LEANN index from multiple mail data sources...")
|
print("Creating LEANN index from multiple mail data sources...")
|
||||||
|
|
||||||
# Load documents using EmlxReader from LEANN_email_reader
|
# Load documents using EmlxReader from LEANN_email_reader
|
||||||
from examples.email_data.LEANN_email_reader import EmlxReader
|
from examples.email_data.LEANN_email_reader import EmlxReader
|
||||||
|
|
||||||
reader = EmlxReader(include_html=include_html)
|
reader = EmlxReader(include_html=include_html)
|
||||||
# from email_data.email import EmlxMboxReader
|
# from email_data.email import EmlxMboxReader
|
||||||
# from pathlib import Path
|
# from pathlib import Path
|
||||||
# reader = EmlxMboxReader()
|
# reader = EmlxMboxReader()
|
||||||
INDEX_DIR = Path(index_path).parent
|
INDEX_DIR = Path(index_path).parent
|
||||||
|
|
||||||
if not INDEX_DIR.exists():
|
if not INDEX_DIR.exists():
|
||||||
print(f"--- Index directory not found, building new index ---")
|
print("--- Index directory not found, building new index ---")
|
||||||
all_documents = []
|
all_documents = []
|
||||||
total_processed = 0
|
total_processed = 0
|
||||||
|
|
||||||
# Process each Messages directory
|
# Process each Messages directory
|
||||||
for i, messages_dir in enumerate(messages_dirs):
|
for i, messages_dir in enumerate(messages_dirs):
|
||||||
print(f"\nProcessing Messages directory {i+1}/{len(messages_dirs)}: {messages_dir}")
|
print(f"\nProcessing Messages directory {i + 1}/{len(messages_dirs)}: {messages_dir}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
documents = reader.load_data(messages_dir)
|
documents = reader.load_data(messages_dir)
|
||||||
if documents:
|
if documents:
|
||||||
print(f"Loaded {len(documents)} email documents from {messages_dir}")
|
print(f"Loaded {len(documents)} email documents from {messages_dir}")
|
||||||
all_documents.extend(documents)
|
all_documents.extend(documents)
|
||||||
total_processed += len(documents)
|
total_processed += len(documents)
|
||||||
|
|
||||||
# Check if we've reached the max count
|
# Check if we've reached the max count
|
||||||
if max_count > 0 and total_processed >= max_count:
|
if max_count > 0 and total_processed >= max_count:
|
||||||
print(f"Reached max count of {max_count} documents")
|
print(f"Reached max count of {max_count} documents")
|
||||||
@@ -69,16 +79,18 @@ def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_pa
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error processing {messages_dir}: {e}")
|
print(f"Error processing {messages_dir}: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if not all_documents:
|
if not all_documents:
|
||||||
print("No documents loaded from any source. Exiting.")
|
print("No documents loaded from any source. Exiting.")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
print(f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories and starting to split them into chunks")
|
print(
|
||||||
|
f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories and starting to split them into chunks"
|
||||||
|
)
|
||||||
|
|
||||||
# Create text splitter with 256 chunk size
|
# Create text splitter with 256 chunk size
|
||||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
|
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
||||||
|
|
||||||
# Convert Documents to text strings and chunk them
|
# Convert Documents to text strings and chunk them
|
||||||
all_texts = []
|
all_texts = []
|
||||||
for doc in all_documents:
|
for doc in all_documents:
|
||||||
@@ -88,44 +100,53 @@ def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_pa
|
|||||||
text = node.get_content()
|
text = node.get_content()
|
||||||
# text = '[subject] ' + doc.metadata["subject"] + '\n' + text
|
# text = '[subject] ' + doc.metadata["subject"] + '\n' + text
|
||||||
all_texts.append(text)
|
all_texts.append(text)
|
||||||
|
|
||||||
print(f"Finished splitting {len(all_documents)} documents into {len(all_texts)} text chunks")
|
print(
|
||||||
|
f"Finished splitting {len(all_documents)} documents into {len(all_texts)} text chunks"
|
||||||
|
)
|
||||||
|
|
||||||
# Create LEANN index directory
|
# Create LEANN index directory
|
||||||
|
|
||||||
print(f"--- Index directory not found, building new index ---")
|
print("--- Index directory not found, building new index ---")
|
||||||
INDEX_DIR.mkdir(exist_ok=True)
|
INDEX_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
print(f"--- Building new LEANN index ---")
|
print("--- Building new LEANN index ---")
|
||||||
|
|
||||||
print(f"\n[PHASE 1] Building Leann index...")
|
print("\n[PHASE 1] Building Leann index...")
|
||||||
|
|
||||||
# Use HNSW backend for better macOS compatibility
|
# Use HNSW backend for better macOS compatibility
|
||||||
builder = LeannBuilder(
|
builder = LeannBuilder(
|
||||||
backend_name="hnsw",
|
backend_name="hnsw",
|
||||||
embedding_model=embedding_model,
|
embedding_model=embedding_model,
|
||||||
graph_degree=32,
|
graph_degree=32,
|
||||||
complexity=64,
|
complexity=64,
|
||||||
is_compact=True,
|
is_compact=True,
|
||||||
is_recompute=True,
|
is_recompute=True,
|
||||||
num_threads=1 # Force single-threaded mode
|
num_threads=1, # Force single-threaded mode
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"Adding {len(all_texts)} email chunks to index...")
|
print(f"Adding {len(all_texts)} email chunks to index...")
|
||||||
for chunk_text in all_texts:
|
for chunk_text in all_texts:
|
||||||
builder.add_text(chunk_text)
|
builder.add_text(chunk_text)
|
||||||
|
|
||||||
builder.build_index(index_path)
|
builder.build_index(index_path)
|
||||||
print(f"\nLEANN index built at {index_path}!")
|
print(f"\nLEANN index built at {index_path}!")
|
||||||
else:
|
else:
|
||||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||||
|
|
||||||
return index_path
|
return index_path
|
||||||
|
|
||||||
def create_leann_index(mail_path: str, index_path: str = "mail_index.leann", max_count: int = 1000, include_html: bool = False, embedding_model: str = "facebook/contriever"):
|
|
||||||
|
def create_leann_index(
|
||||||
|
mail_path: str,
|
||||||
|
index_path: str = "mail_index.leann",
|
||||||
|
max_count: int = 1000,
|
||||||
|
include_html: bool = False,
|
||||||
|
embedding_model: str = "facebook/contriever",
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Create LEANN index from mail data.
|
Create LEANN index from mail data.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mail_path: Path to the mail directory
|
mail_path: Path to the mail directory
|
||||||
index_path: Path to save the LEANN index
|
index_path: Path to save the LEANN index
|
||||||
@@ -134,32 +155,33 @@ def create_leann_index(mail_path: str, index_path: str = "mail_index.leann", max
|
|||||||
"""
|
"""
|
||||||
print("Creating LEANN index from mail data...")
|
print("Creating LEANN index from mail data...")
|
||||||
INDEX_DIR = Path(index_path).parent
|
INDEX_DIR = Path(index_path).parent
|
||||||
|
|
||||||
if not INDEX_DIR.exists():
|
if not INDEX_DIR.exists():
|
||||||
print(f"--- Index directory not found, building new index ---")
|
print("--- Index directory not found, building new index ---")
|
||||||
INDEX_DIR.mkdir(exist_ok=True)
|
INDEX_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
print(f"--- Building new LEANN index ---")
|
print("--- Building new LEANN index ---")
|
||||||
|
|
||||||
print(f"\n[PHASE 1] Building Leann index...")
|
print("\n[PHASE 1] Building Leann index...")
|
||||||
|
|
||||||
# Load documents using EmlxReader from LEANN_email_reader
|
# Load documents using EmlxReader from LEANN_email_reader
|
||||||
from examples.email_data.LEANN_email_reader import EmlxReader
|
from examples.email_data.LEANN_email_reader import EmlxReader
|
||||||
|
|
||||||
reader = EmlxReader(include_html=include_html)
|
reader = EmlxReader(include_html=include_html)
|
||||||
# from email_data.email import EmlxMboxReader
|
# from email_data.email import EmlxMboxReader
|
||||||
# from pathlib import Path
|
# from pathlib import Path
|
||||||
# reader = EmlxMboxReader()
|
# reader = EmlxMboxReader()
|
||||||
documents = reader.load_data(Path(mail_path))
|
documents = reader.load_data(Path(mail_path))
|
||||||
|
|
||||||
if not documents:
|
if not documents:
|
||||||
print("No documents loaded. Exiting.")
|
print("No documents loaded. Exiting.")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
print(f"Loaded {len(documents)} email documents")
|
print(f"Loaded {len(documents)} email documents")
|
||||||
|
|
||||||
# Create text splitter with 256 chunk size
|
# Create text splitter with 256 chunk size
|
||||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
|
||||||
|
|
||||||
# Convert Documents to text strings and chunk them
|
# Convert Documents to text strings and chunk them
|
||||||
all_texts = []
|
all_texts = []
|
||||||
for doc in documents:
|
for doc in documents:
|
||||||
@@ -167,108 +189,139 @@ def create_leann_index(mail_path: str, index_path: str = "mail_index.leann", max
|
|||||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
all_texts.append(node.get_content())
|
all_texts.append(node.get_content())
|
||||||
|
|
||||||
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
|
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
|
||||||
|
|
||||||
# Create LEANN index directory
|
# Create LEANN index directory
|
||||||
|
|
||||||
print(f"--- Index directory not found, building new index ---")
|
print("--- Index directory not found, building new index ---")
|
||||||
INDEX_DIR.mkdir(exist_ok=True)
|
INDEX_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
print(f"--- Building new LEANN index ---")
|
print("--- Building new LEANN index ---")
|
||||||
|
|
||||||
print(f"\n[PHASE 1] Building Leann index...")
|
print("\n[PHASE 1] Building Leann index...")
|
||||||
|
|
||||||
# Use HNSW backend for better macOS compatibility
|
# Use HNSW backend for better macOS compatibility
|
||||||
builder = LeannBuilder(
|
builder = LeannBuilder(
|
||||||
backend_name="hnsw",
|
backend_name="hnsw",
|
||||||
embedding_model=embedding_model,
|
embedding_model=embedding_model,
|
||||||
graph_degree=32,
|
graph_degree=32,
|
||||||
complexity=64,
|
complexity=64,
|
||||||
is_compact=True,
|
is_compact=True,
|
||||||
is_recompute=True,
|
is_recompute=True,
|
||||||
num_threads=1 # Force single-threaded mode
|
num_threads=1, # Force single-threaded mode
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"Adding {len(all_texts)} email chunks to index...")
|
print(f"Adding {len(all_texts)} email chunks to index...")
|
||||||
for chunk_text in all_texts:
|
for chunk_text in all_texts:
|
||||||
builder.add_text(chunk_text)
|
builder.add_text(chunk_text)
|
||||||
|
|
||||||
builder.build_index(index_path)
|
builder.build_index(index_path)
|
||||||
print(f"\nLEANN index built at {index_path}!")
|
print(f"\nLEANN index built at {index_path}!")
|
||||||
else:
|
else:
|
||||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||||
|
|
||||||
return index_path
|
return index_path
|
||||||
|
|
||||||
|
|
||||||
async def query_leann_index(index_path: str, query: str):
|
async def query_leann_index(index_path: str, query: str):
|
||||||
"""
|
"""
|
||||||
Query the LEANN index.
|
Query the LEANN index.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
index_path: Path to the LEANN index
|
index_path: Path to the LEANN index
|
||||||
query: The query string
|
query: The query string
|
||||||
"""
|
"""
|
||||||
print(f"\n[PHASE 2] Starting Leann chat session...")
|
print("\n[PHASE 2] Starting Leann chat session...")
|
||||||
chat = LeannChat(index_path=index_path,
|
chat = LeannChat(index_path=index_path, llm_config={"type": "openai", "model": "gpt-4o"})
|
||||||
llm_config={"type": "openai", "model": "gpt-4o"})
|
|
||||||
|
|
||||||
print(f"You: {query}")
|
print(f"You: {query}")
|
||||||
import time
|
import time
|
||||||
start_time = time.time()
|
|
||||||
|
time.time()
|
||||||
chat_response = chat.ask(
|
chat_response = chat.ask(
|
||||||
query,
|
query,
|
||||||
top_k=10,
|
top_k=20,
|
||||||
recompute_beighbor_embeddings=True,
|
recompute_beighbor_embeddings=True,
|
||||||
complexity=12,
|
complexity=32,
|
||||||
beam_width=1,
|
beam_width=1,
|
||||||
|
|
||||||
)
|
)
|
||||||
end_time = time.time()
|
time.time()
|
||||||
print(f"Time taken: {end_time - start_time} seconds")
|
# print(f"Time taken: {end_time - start_time} seconds")
|
||||||
print(f"Leann: {chat_response}")
|
# highlight the answer
|
||||||
|
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
# Parse command line arguments
|
# Parse command line arguments
|
||||||
parser = argparse.ArgumentParser(description='LEANN Mail Reader - Create and query email index')
|
parser = argparse.ArgumentParser(description="LEANN Mail Reader - Create and query email index")
|
||||||
# Remove --mail-path argument and auto-detect all Messages directories
|
# Remove --mail-path argument and auto-detect all Messages directories
|
||||||
# Remove DEFAULT_MAIL_PATH
|
# Remove DEFAULT_MAIL_PATH
|
||||||
parser.add_argument('--index-dir', type=str, default="./mail_index_leann_debug",
|
parser.add_argument(
|
||||||
help='Directory to store the LEANN index (default: ./mail_index_leann_raw_text_all_dicts)')
|
"--index-dir",
|
||||||
parser.add_argument('--max-emails', type=int, default=1000,
|
type=str,
|
||||||
help='Maximum number of emails to process (-1 means all)')
|
default="./mail_index",
|
||||||
parser.add_argument('--query', type=str, default="Give me some funny advertisement about apple or other companies",
|
help="Directory to store the LEANN index (default: ./mail_index_leann_raw_text_all_dicts)",
|
||||||
help='Single query to run (default: runs example queries)')
|
)
|
||||||
parser.add_argument('--include-html', action='store_true', default=False,
|
parser.add_argument(
|
||||||
help='Include HTML content in email processing (default: False)')
|
"--max-emails",
|
||||||
parser.add_argument('--embedding-model', type=str, default="facebook/contriever",
|
type=int,
|
||||||
help='Embedding model to use (default: facebook/contriever)')
|
default=1000,
|
||||||
|
help="Maximum number of emails to process (-1 means all)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--query",
|
||||||
|
type=str,
|
||||||
|
default="Give me some funny advertisement about apple or other companies",
|
||||||
|
help="Single query to run (default: runs example queries)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--include-html",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Include HTML content in email processing (default: False)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding-model",
|
||||||
|
type=str,
|
||||||
|
default="facebook/contriever",
|
||||||
|
help="Embedding model to use (default: facebook/contriever)",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
print(f"args: {args}")
|
print(f"args: {args}")
|
||||||
|
|
||||||
# Automatically find all Messages directories under the current user's Mail directory
|
# Automatically find all Messages directories under the current user's Mail directory
|
||||||
from examples.email_data.LEANN_email_reader import find_all_messages_directories
|
from examples.email_data.LEANN_email_reader import find_all_messages_directories
|
||||||
|
|
||||||
mail_path = get_mail_path()
|
mail_path = get_mail_path()
|
||||||
print(f"Searching for email data in: {mail_path}")
|
print(f"Searching for email data in: {mail_path}")
|
||||||
messages_dirs = find_all_messages_directories(mail_path)
|
messages_dirs = find_all_messages_directories(mail_path)
|
||||||
|
# messages_dirs = find_all_messages_directories(DEFAULT_MAIL_PATH)
|
||||||
print('len(messages_dirs): ', len(messages_dirs))
|
# messages_dirs = [DEFAULT_MAIL_PATH]
|
||||||
|
# messages_dirs = messages_dirs[:1]
|
||||||
|
|
||||||
|
print("len(messages_dirs): ", len(messages_dirs))
|
||||||
|
|
||||||
if not messages_dirs:
|
if not messages_dirs:
|
||||||
print("No Messages directories found. Exiting.")
|
print("No Messages directories found. Exiting.")
|
||||||
return
|
return
|
||||||
|
|
||||||
INDEX_DIR = Path(args.index_dir)
|
INDEX_DIR = Path(args.index_dir)
|
||||||
INDEX_PATH = str(INDEX_DIR / "mail_documents.leann")
|
INDEX_PATH = str(INDEX_DIR / "mail_documents.leann")
|
||||||
print(f"Index directory: {INDEX_DIR}")
|
print(f"Index directory: {INDEX_DIR}")
|
||||||
print(f"Found {len(messages_dirs)} Messages directories.")
|
print(f"Found {len(messages_dirs)} Messages directories.")
|
||||||
|
|
||||||
# Create or load the LEANN index from all sources
|
# Create or load the LEANN index from all sources
|
||||||
index_path = create_leann_index_from_multiple_sources(messages_dirs, INDEX_PATH, args.max_emails, args.include_html, args.embedding_model)
|
index_path = create_leann_index_from_multiple_sources(
|
||||||
|
messages_dirs,
|
||||||
|
INDEX_PATH,
|
||||||
|
args.max_emails,
|
||||||
|
args.include_html,
|
||||||
|
args.embedding_model,
|
||||||
|
)
|
||||||
|
|
||||||
if index_path:
|
if index_path:
|
||||||
if args.query:
|
if args.query:
|
||||||
# Run single query
|
# Run single query
|
||||||
@@ -278,11 +331,12 @@ async def main():
|
|||||||
queries = [
|
queries = [
|
||||||
"Hows Berkeley Graduate Student Instructor",
|
"Hows Berkeley Graduate Student Instructor",
|
||||||
"how's the icloud related advertisement saying",
|
"how's the icloud related advertisement saying",
|
||||||
"Whats the number of class recommend to take per semester for incoming EECS students"
|
"Whats the number of class recommend to take per semester for incoming EECS students",
|
||||||
]
|
]
|
||||||
for query in queries:
|
for query in queries:
|
||||||
print("\n" + "="*60)
|
print("\n" + "=" * 60)
|
||||||
await query_leann_index(index_path, query)
|
await query_leann_index(index_path, query)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
|||||||
@@ -1,26 +1,30 @@
|
|||||||
|
import argparse
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import argparse
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Any
|
|
||||||
|
|
||||||
# Add the project root to Python path so we can import from examples
|
# Add the project root to Python path so we can import from examples
|
||||||
project_root = Path(__file__).parent.parent
|
project_root = Path(__file__).parent.parent
|
||||||
sys.path.insert(0, str(project_root))
|
sys.path.insert(0, str(project_root))
|
||||||
|
|
||||||
from llama_index.core import VectorStoreIndex, StorageContext
|
import torch
|
||||||
|
from llama_index.core import StorageContext, VectorStoreIndex
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
# --- EMBEDDING MODEL ---
|
# --- EMBEDDING MODEL ---
|
||||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||||
import torch
|
|
||||||
|
|
||||||
# --- END EMBEDDING MODEL ---
|
# --- END EMBEDDING MODEL ---
|
||||||
|
|
||||||
# Import EmlxReader from the new module
|
# Import EmlxReader from the new module
|
||||||
from examples.email_data.LEANN_email_reader import EmlxReader
|
from examples.email_data.LEANN_email_reader import EmlxReader
|
||||||
|
|
||||||
def create_and_save_index(mail_path: str, save_dir: str = "mail_index_embedded", max_count: int = 1000, include_html: bool = False):
|
|
||||||
|
def create_and_save_index(
|
||||||
|
mail_path: str,
|
||||||
|
save_dir: str = "mail_index_embedded",
|
||||||
|
max_count: int = 1000,
|
||||||
|
include_html: bool = False,
|
||||||
|
):
|
||||||
print("Creating index from mail data with embedded metadata...")
|
print("Creating index from mail data with embedded metadata...")
|
||||||
documents = EmlxReader(include_html=include_html).load_data(mail_path, max_count=max_count)
|
documents = EmlxReader(include_html=include_html).load_data(mail_path, max_count=max_count)
|
||||||
if not documents:
|
if not documents:
|
||||||
@@ -30,7 +34,7 @@ def create_and_save_index(mail_path: str, save_dir: str = "mail_index_embedded",
|
|||||||
# Use facebook/contriever as the embedder
|
# Use facebook/contriever as the embedder
|
||||||
embed_model = HuggingFaceEmbedding(model_name="facebook/contriever")
|
embed_model = HuggingFaceEmbedding(model_name="facebook/contriever")
|
||||||
# set on device
|
# set on device
|
||||||
import torch
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
embed_model._model.to("cuda")
|
embed_model._model.to("cuda")
|
||||||
# set mps
|
# set mps
|
||||||
@@ -39,21 +43,19 @@ def create_and_save_index(mail_path: str, save_dir: str = "mail_index_embedded",
|
|||||||
else:
|
else:
|
||||||
embed_model._model.to("cpu")
|
embed_model._model.to("cpu")
|
||||||
index = VectorStoreIndex.from_documents(
|
index = VectorStoreIndex.from_documents(
|
||||||
documents,
|
documents, transformations=[text_splitter], embed_model=embed_model
|
||||||
transformations=[text_splitter],
|
|
||||||
embed_model=embed_model
|
|
||||||
)
|
)
|
||||||
os.makedirs(save_dir, exist_ok=True)
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
index.storage_context.persist(persist_dir=save_dir)
|
index.storage_context.persist(persist_dir=save_dir)
|
||||||
print(f"Index saved to {save_dir}")
|
print(f"Index saved to {save_dir}")
|
||||||
return index
|
return index
|
||||||
|
|
||||||
|
|
||||||
def load_index(save_dir: str = "mail_index_embedded"):
|
def load_index(save_dir: str = "mail_index_embedded"):
|
||||||
try:
|
try:
|
||||||
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
|
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
|
||||||
index = VectorStoreIndex.from_vector_store(
|
index = VectorStoreIndex.from_vector_store(
|
||||||
storage_context.vector_store,
|
storage_context.vector_store, storage_context=storage_context
|
||||||
storage_context=storage_context
|
|
||||||
)
|
)
|
||||||
print(f"Index loaded from {save_dir}")
|
print(f"Index loaded from {save_dir}")
|
||||||
return index
|
return index
|
||||||
@@ -61,6 +63,7 @@ def load_index(save_dir: str = "mail_index_embedded"):
|
|||||||
print(f"Error loading index: {e}")
|
print(f"Error loading index: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def query_index(index, query: str):
|
def query_index(index, query: str):
|
||||||
if index is None:
|
if index is None:
|
||||||
print("No index available for querying.")
|
print("No index available for querying.")
|
||||||
@@ -70,39 +73,63 @@ def query_index(index, query: str):
|
|||||||
print(f"Query: {query}")
|
print(f"Query: {query}")
|
||||||
print(f"Response: {response}")
|
print(f"Response: {response}")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# Parse command line arguments
|
# Parse command line arguments
|
||||||
parser = argparse.ArgumentParser(description='LlamaIndex Mail Reader - Create and query email index')
|
parser = argparse.ArgumentParser(
|
||||||
parser.add_argument('--mail-path', type=str,
|
description="LlamaIndex Mail Reader - Create and query email index"
|
||||||
default="/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages",
|
)
|
||||||
help='Path to mail data directory')
|
parser.add_argument(
|
||||||
parser.add_argument('--save-dir', type=str, default="mail_index_embedded",
|
"--mail-path",
|
||||||
help='Directory to store the index (default: mail_index_embedded)')
|
type=str,
|
||||||
parser.add_argument('--max-emails', type=int, default=10000,
|
default="/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages",
|
||||||
help='Maximum number of emails to process')
|
help="Path to mail data directory",
|
||||||
parser.add_argument('--include-html', action='store_true', default=False,
|
)
|
||||||
help='Include HTML content in email processing (default: False)')
|
parser.add_argument(
|
||||||
|
"--save-dir",
|
||||||
|
type=str,
|
||||||
|
default="mail_index_embedded",
|
||||||
|
help="Directory to store the index (default: mail_index_embedded)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-emails",
|
||||||
|
type=int,
|
||||||
|
default=10000,
|
||||||
|
help="Maximum number of emails to process",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--include-html",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Include HTML content in email processing (default: False)",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
mail_path = args.mail_path
|
mail_path = args.mail_path
|
||||||
save_dir = args.save_dir
|
save_dir = args.save_dir
|
||||||
|
|
||||||
if os.path.exists(save_dir) and os.path.exists(os.path.join(save_dir, "vector_store.json")):
|
if os.path.exists(save_dir) and os.path.exists(os.path.join(save_dir, "vector_store.json")):
|
||||||
print("Loading existing index...")
|
print("Loading existing index...")
|
||||||
index = load_index(save_dir)
|
index = load_index(save_dir)
|
||||||
else:
|
else:
|
||||||
print("Creating new index...")
|
print("Creating new index...")
|
||||||
index = create_and_save_index(mail_path, save_dir, max_count=args.max_emails, include_html=args.include_html)
|
index = create_and_save_index(
|
||||||
|
mail_path,
|
||||||
|
save_dir,
|
||||||
|
max_count=args.max_emails,
|
||||||
|
include_html=args.include_html,
|
||||||
|
)
|
||||||
if index:
|
if index:
|
||||||
queries = [
|
queries = [
|
||||||
"Hows Berkeley Graduate Student Instructor",
|
"Hows Berkeley Graduate Student Instructor",
|
||||||
"how's the icloud related advertisement saying",
|
"how's the icloud related advertisement saying",
|
||||||
"Whats the number of class recommend to take per semester for incoming EECS students"
|
"Whats the number of class recommend to take per semester for incoming EECS students",
|
||||||
]
|
]
|
||||||
for query in queries:
|
for query in queries:
|
||||||
print("\n" + "="*50)
|
print("\n" + "=" * 50)
|
||||||
query_index(index, query)
|
query_index(index, query)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
import argparse
|
import argparse
|
||||||
from llama_index.core import SimpleDirectoryReader
|
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import dotenv
|
import dotenv
|
||||||
from leann.api import LeannBuilder, LeannChat
|
from leann.api import LeannBuilder, LeannChat
|
||||||
from pathlib import Path
|
from llama_index.core import SimpleDirectoryReader
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
dotenv.load_dotenv()
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
@@ -29,17 +30,22 @@ async def main(args):
|
|||||||
all_texts = []
|
all_texts = []
|
||||||
for doc in documents:
|
for doc in documents:
|
||||||
nodes = node_parser.get_nodes_from_documents([doc])
|
nodes = node_parser.get_nodes_from_documents([doc])
|
||||||
for node in nodes:
|
if nodes:
|
||||||
all_texts.append(node.get_content())
|
all_texts.extend(node.get_content() for node in nodes)
|
||||||
|
|
||||||
print("--- Index directory not found, building new index ---")
|
print("--- Index directory not found, building new index ---")
|
||||||
|
|
||||||
print("\n[PHASE 1] Building Leann index...")
|
print("\n[PHASE 1] Building Leann index...")
|
||||||
|
|
||||||
|
# LeannBuilder now automatically detects normalized embeddings and sets appropriate distance metric
|
||||||
|
print(f"Using {args.embedding_model} with {args.embedding_mode} mode")
|
||||||
|
|
||||||
# Use HNSW backend for better macOS compatibility
|
# Use HNSW backend for better macOS compatibility
|
||||||
builder = LeannBuilder(
|
builder = LeannBuilder(
|
||||||
backend_name="hnsw",
|
backend_name="hnsw",
|
||||||
embedding_model="facebook/contriever",
|
embedding_model=args.embedding_model,
|
||||||
|
embedding_mode=args.embedding_mode,
|
||||||
|
# distance_metric is automatically set based on embedding model
|
||||||
graph_degree=32,
|
graph_degree=32,
|
||||||
complexity=64,
|
complexity=64,
|
||||||
is_compact=True,
|
is_compact=True,
|
||||||
@@ -56,29 +62,35 @@ async def main(args):
|
|||||||
else:
|
else:
|
||||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||||
|
|
||||||
print(f"\n[PHASE 2] Starting Leann chat session...")
|
print("\n[PHASE 2] Starting Leann chat session...")
|
||||||
|
|
||||||
llm_config = {"type": "hf", "model": "Qwen/Qwen3-4B"}
|
# Build llm_config based on command line arguments
|
||||||
llm_config = {"type": "ollama", "model": "qwen3:8b"}
|
if args.llm == "simulated":
|
||||||
llm_config = {"type": "openai", "model": "gpt-4o"}
|
llm_config = {"type": "simulated"}
|
||||||
|
elif args.llm == "ollama":
|
||||||
|
llm_config = {"type": "ollama", "model": args.model, "host": args.host}
|
||||||
|
elif args.llm == "hf":
|
||||||
|
llm_config = {"type": "hf", "model": args.model}
|
||||||
|
elif args.llm == "openai":
|
||||||
|
llm_config = {"type": "openai", "model": args.model}
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown LLM type: {args.llm}")
|
||||||
|
|
||||||
|
print(f"Using LLM: {args.llm} with model: {args.model if args.llm != 'simulated' else 'N/A'}")
|
||||||
|
|
||||||
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
|
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
|
||||||
|
|
||||||
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
|
|
||||||
|
|
||||||
# query = (
|
# query = (
|
||||||
# "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
|
# "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
|
||||||
# )
|
# )
|
||||||
|
query = args.query
|
||||||
|
|
||||||
print(f"You: {query}")
|
print(f"You: {query}")
|
||||||
chat_response = chat.ask(query, top_k=20, recompute_embeddings=True, complexity=32)
|
chat_response = chat.ask(query, top_k=20, recompute_embeddings=True, complexity=32)
|
||||||
print(f"Leann: {chat_response}")
|
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(description="Run Leann Chat with various LLM backends.")
|
||||||
description="Run Leann Chat with various LLM backends."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--llm",
|
"--llm",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -92,6 +104,19 @@ if __name__ == "__main__":
|
|||||||
default="Qwen/Qwen3-0.6B",
|
default="Qwen/Qwen3-0.6B",
|
||||||
help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf, 'gpt-4o' for openai).",
|
help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf, 'gpt-4o' for openai).",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding-model",
|
||||||
|
type=str,
|
||||||
|
default="facebook/contriever",
|
||||||
|
help="The embedding model to use (e.g., 'facebook/contriever', 'text-embedding-3-small').",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding-mode",
|
||||||
|
type=str,
|
||||||
|
default="sentence-transformers",
|
||||||
|
choices=["sentence-transformers", "openai", "mlx"],
|
||||||
|
help="The embedding backend mode.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--host",
|
"--host",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -110,6 +135,12 @@ if __name__ == "__main__":
|
|||||||
default="examples/data",
|
default="examples/data",
|
||||||
help="Directory containing documents to index (PDF, TXT, MD files).",
|
help="Directory containing documents to index (PDF, TXT, MD files).",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--query",
|
||||||
|
type=str,
|
||||||
|
default="Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?",
|
||||||
|
help="The query to ask the Leann chat system.",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
asyncio.run(main(args))
|
asyncio.run(main(args))
|
||||||
|
|||||||
@@ -14,48 +14,55 @@ Key features:
|
|||||||
- Document-level result consolidation
|
- Document-level result consolidation
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from typing import List, Dict, Any, Tuple, Optional
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
import json
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PatchResult:
|
class PatchResult:
|
||||||
"""Represents a single patch search result."""
|
"""Represents a single patch search result."""
|
||||||
|
|
||||||
patch_id: int
|
patch_id: int
|
||||||
image_name: str
|
image_name: str
|
||||||
image_path: str
|
image_path: str
|
||||||
coordinates: Tuple[int, int, int, int] # (x1, y1, x2, y2)
|
coordinates: tuple[int, int, int, int] # (x1, y1, x2, y2)
|
||||||
score: float
|
score: float
|
||||||
attention_score: float
|
attention_score: float
|
||||||
scale: float
|
scale: float
|
||||||
metadata: Dict[str, Any]
|
metadata: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AggregatedResult:
|
class AggregatedResult:
|
||||||
"""Represents an aggregated document-level result."""
|
"""Represents an aggregated document-level result."""
|
||||||
|
|
||||||
image_name: str
|
image_name: str
|
||||||
image_path: str
|
image_path: str
|
||||||
doc_score: float
|
doc_score: float
|
||||||
patch_count: int
|
patch_count: int
|
||||||
best_patch: PatchResult
|
best_patch: PatchResult
|
||||||
all_patches: List[PatchResult]
|
all_patches: list[PatchResult]
|
||||||
aggregation_method: str
|
aggregation_method: str
|
||||||
spatial_clusters: Optional[List[List[PatchResult]]] = None
|
spatial_clusters: list[list[PatchResult]] | None = None
|
||||||
|
|
||||||
|
|
||||||
class MultiVectorAggregator:
|
class MultiVectorAggregator:
|
||||||
"""
|
"""
|
||||||
Aggregates multiple patch-level results into document-level results.
|
Aggregates multiple patch-level results into document-level results.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
aggregation_method: str = "maxsim",
|
self,
|
||||||
spatial_clustering: bool = True,
|
aggregation_method: str = "maxsim",
|
||||||
cluster_distance_threshold: float = 100.0):
|
spatial_clustering: bool = True,
|
||||||
|
cluster_distance_threshold: float = 100.0,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the aggregator.
|
Initialize the aggregator.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
aggregation_method: "maxsim", "voting", "weighted", or "mean"
|
aggregation_method: "maxsim", "voting", "weighted", or "mean"
|
||||||
spatial_clustering: Whether to cluster spatially close patches
|
spatial_clustering: Whether to cluster spatially close patches
|
||||||
@@ -64,23 +71,23 @@ class MultiVectorAggregator:
|
|||||||
self.aggregation_method = aggregation_method
|
self.aggregation_method = aggregation_method
|
||||||
self.spatial_clustering = spatial_clustering
|
self.spatial_clustering = spatial_clustering
|
||||||
self.cluster_distance_threshold = cluster_distance_threshold
|
self.cluster_distance_threshold = cluster_distance_threshold
|
||||||
|
|
||||||
def aggregate_results(self,
|
def aggregate_results(
|
||||||
search_results: List[Dict[str, Any]],
|
self, search_results: list[dict[str, Any]], top_k: int = 10
|
||||||
top_k: int = 10) -> List[AggregatedResult]:
|
) -> list[AggregatedResult]:
|
||||||
"""
|
"""
|
||||||
Aggregate patch-level search results into document-level results.
|
Aggregate patch-level search results into document-level results.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
search_results: List of search results from LeannSearcher
|
search_results: List of search results from LeannSearcher
|
||||||
top_k: Number of top documents to return
|
top_k: Number of top documents to return
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of aggregated document results
|
List of aggregated document results
|
||||||
"""
|
"""
|
||||||
# Group results by image
|
# Group results by image
|
||||||
image_groups = defaultdict(list)
|
image_groups = defaultdict(list)
|
||||||
|
|
||||||
for result in search_results:
|
for result in search_results:
|
||||||
metadata = result.metadata
|
metadata = result.metadata
|
||||||
if "image_name" in metadata and "patch_id" in metadata:
|
if "image_name" in metadata and "patch_id" in metadata:
|
||||||
@@ -92,55 +99,57 @@ class MultiVectorAggregator:
|
|||||||
score=result.score,
|
score=result.score,
|
||||||
attention_score=metadata.get("attention_score", 0.0),
|
attention_score=metadata.get("attention_score", 0.0),
|
||||||
scale=metadata.get("scale", 1.0),
|
scale=metadata.get("scale", 1.0),
|
||||||
metadata=metadata
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
image_groups[metadata["image_name"]].append(patch_result)
|
image_groups[metadata["image_name"]].append(patch_result)
|
||||||
|
|
||||||
# Aggregate each image group
|
# Aggregate each image group
|
||||||
aggregated_results = []
|
aggregated_results = []
|
||||||
for image_name, patches in image_groups.items():
|
for image_name, patches in image_groups.items():
|
||||||
if len(patches) == 0:
|
if len(patches) == 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
agg_result = self._aggregate_image_patches(image_name, patches)
|
agg_result = self._aggregate_image_patches(image_name, patches)
|
||||||
aggregated_results.append(agg_result)
|
aggregated_results.append(agg_result)
|
||||||
|
|
||||||
# Sort by aggregated score and return top-k
|
# Sort by aggregated score and return top-k
|
||||||
aggregated_results.sort(key=lambda x: x.doc_score, reverse=True)
|
aggregated_results.sort(key=lambda x: x.doc_score, reverse=True)
|
||||||
return aggregated_results[:top_k]
|
return aggregated_results[:top_k]
|
||||||
|
|
||||||
def _aggregate_image_patches(self, image_name: str, patches: List[PatchResult]) -> AggregatedResult:
|
def _aggregate_image_patches(
|
||||||
|
self, image_name: str, patches: list[PatchResult]
|
||||||
|
) -> AggregatedResult:
|
||||||
"""Aggregate patches for a single image."""
|
"""Aggregate patches for a single image."""
|
||||||
|
|
||||||
if self.aggregation_method == "maxsim":
|
if self.aggregation_method == "maxsim":
|
||||||
doc_score = max(patch.score for patch in patches)
|
doc_score = max(patch.score for patch in patches)
|
||||||
best_patch = max(patches, key=lambda p: p.score)
|
best_patch = max(patches, key=lambda p: p.score)
|
||||||
|
|
||||||
elif self.aggregation_method == "voting":
|
elif self.aggregation_method == "voting":
|
||||||
# Count patches above threshold
|
# Count patches above threshold
|
||||||
threshold = np.percentile([p.score for p in patches], 75)
|
threshold = np.percentile([p.score for p in patches], 75)
|
||||||
doc_score = sum(1 for patch in patches if patch.score >= threshold)
|
doc_score = sum(1 for patch in patches if patch.score >= threshold)
|
||||||
best_patch = max(patches, key=lambda p: p.score)
|
best_patch = max(patches, key=lambda p: p.score)
|
||||||
|
|
||||||
elif self.aggregation_method == "weighted":
|
elif self.aggregation_method == "weighted":
|
||||||
# Weight by attention scores
|
# Weight by attention scores
|
||||||
total_weighted_score = sum(p.score * p.attention_score for p in patches)
|
total_weighted_score = sum(p.score * p.attention_score for p in patches)
|
||||||
total_weights = sum(p.attention_score for p in patches)
|
total_weights = sum(p.attention_score for p in patches)
|
||||||
doc_score = total_weighted_score / max(total_weights, 1e-8)
|
doc_score = total_weighted_score / max(total_weights, 1e-8)
|
||||||
best_patch = max(patches, key=lambda p: p.score * p.attention_score)
|
best_patch = max(patches, key=lambda p: p.score * p.attention_score)
|
||||||
|
|
||||||
elif self.aggregation_method == "mean":
|
elif self.aggregation_method == "mean":
|
||||||
doc_score = np.mean([patch.score for patch in patches])
|
doc_score = np.mean([patch.score for patch in patches])
|
||||||
best_patch = max(patches, key=lambda p: p.score)
|
best_patch = max(patches, key=lambda p: p.score)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown aggregation method: {self.aggregation_method}")
|
raise ValueError(f"Unknown aggregation method: {self.aggregation_method}")
|
||||||
|
|
||||||
# Spatial clustering if enabled
|
# Spatial clustering if enabled
|
||||||
spatial_clusters = None
|
spatial_clusters = None
|
||||||
if self.spatial_clustering:
|
if self.spatial_clustering:
|
||||||
spatial_clusters = self._cluster_patches_spatially(patches)
|
spatial_clusters = self._cluster_patches_spatially(patches)
|
||||||
|
|
||||||
return AggregatedResult(
|
return AggregatedResult(
|
||||||
image_name=image_name,
|
image_name=image_name,
|
||||||
image_path=patches[0].image_path,
|
image_path=patches[0].image_path,
|
||||||
@@ -149,23 +158,23 @@ class MultiVectorAggregator:
|
|||||||
best_patch=best_patch,
|
best_patch=best_patch,
|
||||||
all_patches=sorted(patches, key=lambda p: p.score, reverse=True),
|
all_patches=sorted(patches, key=lambda p: p.score, reverse=True),
|
||||||
aggregation_method=self.aggregation_method,
|
aggregation_method=self.aggregation_method,
|
||||||
spatial_clusters=spatial_clusters
|
spatial_clusters=spatial_clusters,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _cluster_patches_spatially(self, patches: List[PatchResult]) -> List[List[PatchResult]]:
|
def _cluster_patches_spatially(self, patches: list[PatchResult]) -> list[list[PatchResult]]:
|
||||||
"""Cluster patches that are spatially close to each other."""
|
"""Cluster patches that are spatially close to each other."""
|
||||||
if len(patches) <= 1:
|
if len(patches) <= 1:
|
||||||
return [patches]
|
return [patches]
|
||||||
|
|
||||||
clusters = []
|
clusters = []
|
||||||
remaining_patches = patches.copy()
|
remaining_patches = patches.copy()
|
||||||
|
|
||||||
while remaining_patches:
|
while remaining_patches:
|
||||||
# Start new cluster with highest scoring remaining patch
|
# Start new cluster with highest scoring remaining patch
|
||||||
seed_patch = max(remaining_patches, key=lambda p: p.score)
|
seed_patch = max(remaining_patches, key=lambda p: p.score)
|
||||||
current_cluster = [seed_patch]
|
current_cluster = [seed_patch]
|
||||||
remaining_patches.remove(seed_patch)
|
remaining_patches.remove(seed_patch)
|
||||||
|
|
||||||
# Add nearby patches to cluster
|
# Add nearby patches to cluster
|
||||||
added_to_cluster = True
|
added_to_cluster = True
|
||||||
while added_to_cluster:
|
while added_to_cluster:
|
||||||
@@ -175,145 +184,177 @@ class MultiVectorAggregator:
|
|||||||
current_cluster.append(patch)
|
current_cluster.append(patch)
|
||||||
remaining_patches.remove(patch)
|
remaining_patches.remove(patch)
|
||||||
added_to_cluster = True
|
added_to_cluster = True
|
||||||
|
|
||||||
clusters.append(current_cluster)
|
clusters.append(current_cluster)
|
||||||
|
|
||||||
return sorted(clusters, key=lambda cluster: max(p.score for p in cluster), reverse=True)
|
return sorted(clusters, key=lambda cluster: max(p.score for p in cluster), reverse=True)
|
||||||
|
|
||||||
def _is_patch_nearby(self, patch: PatchResult, cluster: List[PatchResult]) -> bool:
|
def _is_patch_nearby(self, patch: PatchResult, cluster: list[PatchResult]) -> bool:
|
||||||
"""Check if a patch is spatially close to any patch in the cluster."""
|
"""Check if a patch is spatially close to any patch in the cluster."""
|
||||||
patch_center = self._get_patch_center(patch.coordinates)
|
patch_center = self._get_patch_center(patch.coordinates)
|
||||||
|
|
||||||
for cluster_patch in cluster:
|
for cluster_patch in cluster:
|
||||||
cluster_center = self._get_patch_center(cluster_patch.coordinates)
|
cluster_center = self._get_patch_center(cluster_patch.coordinates)
|
||||||
distance = np.sqrt((patch_center[0] - cluster_center[0])**2 +
|
distance = np.sqrt(
|
||||||
(patch_center[1] - cluster_center[1])**2)
|
(patch_center[0] - cluster_center[0]) ** 2
|
||||||
|
+ (patch_center[1] - cluster_center[1]) ** 2
|
||||||
|
)
|
||||||
|
|
||||||
if distance <= self.cluster_distance_threshold:
|
if distance <= self.cluster_distance_threshold:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _get_patch_center(self, coordinates: Tuple[int, int, int, int]) -> Tuple[float, float]:
|
def _get_patch_center(self, coordinates: tuple[int, int, int, int]) -> tuple[float, float]:
|
||||||
"""Get center point of a patch."""
|
"""Get center point of a patch."""
|
||||||
x1, y1, x2, y2 = coordinates
|
x1, y1, x2, y2 = coordinates
|
||||||
return ((x1 + x2) / 2, (y1 + y2) / 2)
|
return ((x1 + x2) / 2, (y1 + y2) / 2)
|
||||||
|
|
||||||
def print_aggregated_results(self, results: List[AggregatedResult], max_patches_per_doc: int = 3):
|
def print_aggregated_results(
|
||||||
|
self, results: list[AggregatedResult], max_patches_per_doc: int = 3
|
||||||
|
):
|
||||||
"""Pretty print aggregated results."""
|
"""Pretty print aggregated results."""
|
||||||
print(f"\n🔍 Aggregated Results (method: {self.aggregation_method})")
|
print(f"\n🔍 Aggregated Results (method: {self.aggregation_method})")
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
|
|
||||||
for i, result in enumerate(results):
|
for i, result in enumerate(results):
|
||||||
print(f"\n{i+1}. {result.image_name}")
|
print(f"\n{i + 1}. {result.image_name}")
|
||||||
print(f" Doc Score: {result.doc_score:.4f} | Patches: {result.patch_count}")
|
print(f" Doc Score: {result.doc_score:.4f} | Patches: {result.patch_count}")
|
||||||
print(f" Path: {result.image_path}")
|
print(f" Path: {result.image_path}")
|
||||||
|
|
||||||
# Show best patch
|
# Show best patch
|
||||||
best = result.best_patch
|
best = result.best_patch
|
||||||
print(f" 🌟 Best Patch: #{best.patch_id} at {best.coordinates} (score: {best.score:.4f})")
|
print(
|
||||||
|
f" 🌟 Best Patch: #{best.patch_id} at {best.coordinates} (score: {best.score:.4f})"
|
||||||
|
)
|
||||||
|
|
||||||
# Show top patches
|
# Show top patches
|
||||||
print(f" 📍 Top Patches:")
|
print(" 📍 Top Patches:")
|
||||||
for j, patch in enumerate(result.all_patches[:max_patches_per_doc]):
|
for j, patch in enumerate(result.all_patches[:max_patches_per_doc]):
|
||||||
print(f" {j+1}. Patch #{patch.patch_id}: {patch.score:.4f} at {patch.coordinates}")
|
print(
|
||||||
|
f" {j + 1}. Patch #{patch.patch_id}: {patch.score:.4f} at {patch.coordinates}"
|
||||||
|
)
|
||||||
|
|
||||||
# Show spatial clusters if available
|
# Show spatial clusters if available
|
||||||
if result.spatial_clusters and len(result.spatial_clusters) > 1:
|
if result.spatial_clusters and len(result.spatial_clusters) > 1:
|
||||||
print(f" 🗂️ Spatial Clusters: {len(result.spatial_clusters)}")
|
print(f" 🗂️ Spatial Clusters: {len(result.spatial_clusters)}")
|
||||||
for j, cluster in enumerate(result.spatial_clusters[:2]): # Show top 2 clusters
|
for j, cluster in enumerate(result.spatial_clusters[:2]): # Show top 2 clusters
|
||||||
cluster_score = max(p.score for p in cluster)
|
cluster_score = max(p.score for p in cluster)
|
||||||
print(f" Cluster {j+1}: {len(cluster)} patches (best: {cluster_score:.4f})")
|
print(
|
||||||
|
f" Cluster {j + 1}: {len(cluster)} patches (best: {cluster_score:.4f})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def demo_aggregation():
|
def demo_aggregation():
|
||||||
"""Demonstrate the multi-vector aggregation functionality."""
|
"""Demonstrate the multi-vector aggregation functionality."""
|
||||||
print("=== Multi-Vector Aggregation Demo ===")
|
print("=== Multi-Vector Aggregation Demo ===")
|
||||||
|
|
||||||
# Simulate some patch-level search results
|
# Simulate some patch-level search results
|
||||||
# In real usage, these would come from LeannSearcher.search()
|
# In real usage, these would come from LeannSearcher.search()
|
||||||
|
|
||||||
class MockResult:
|
class MockResult:
|
||||||
def __init__(self, score, metadata):
|
def __init__(self, score, metadata):
|
||||||
self.score = score
|
self.score = score
|
||||||
self.metadata = metadata
|
self.metadata = metadata
|
||||||
|
|
||||||
# Simulate results for 2 images with multiple patches each
|
# Simulate results for 2 images with multiple patches each
|
||||||
mock_results = [
|
mock_results = [
|
||||||
# Image 1: cats_and_kitchen.jpg - 4 patches
|
# Image 1: cats_and_kitchen.jpg - 4 patches
|
||||||
MockResult(0.85, {
|
MockResult(
|
||||||
"image_name": "cats_and_kitchen.jpg",
|
0.85,
|
||||||
"image_path": "/path/to/cats_and_kitchen.jpg",
|
{
|
||||||
"patch_id": 3,
|
"image_name": "cats_and_kitchen.jpg",
|
||||||
"coordinates": [100, 50, 224, 174], # Kitchen area
|
"image_path": "/path/to/cats_and_kitchen.jpg",
|
||||||
"attention_score": 0.92,
|
"patch_id": 3,
|
||||||
"scale": 1.0
|
"coordinates": [100, 50, 224, 174], # Kitchen area
|
||||||
}),
|
"attention_score": 0.92,
|
||||||
MockResult(0.78, {
|
"scale": 1.0,
|
||||||
"image_name": "cats_and_kitchen.jpg",
|
},
|
||||||
"image_path": "/path/to/cats_and_kitchen.jpg",
|
),
|
||||||
"patch_id": 7,
|
MockResult(
|
||||||
"coordinates": [200, 300, 324, 424], # Cat area
|
0.78,
|
||||||
"attention_score": 0.88,
|
{
|
||||||
"scale": 1.0
|
"image_name": "cats_and_kitchen.jpg",
|
||||||
}),
|
"image_path": "/path/to/cats_and_kitchen.jpg",
|
||||||
MockResult(0.72, {
|
"patch_id": 7,
|
||||||
"image_name": "cats_and_kitchen.jpg",
|
"coordinates": [200, 300, 324, 424], # Cat area
|
||||||
"image_path": "/path/to/cats_and_kitchen.jpg",
|
"attention_score": 0.88,
|
||||||
"patch_id": 12,
|
"scale": 1.0,
|
||||||
"coordinates": [150, 100, 274, 224], # Appliances
|
},
|
||||||
"attention_score": 0.75,
|
),
|
||||||
"scale": 1.0
|
MockResult(
|
||||||
}),
|
0.72,
|
||||||
MockResult(0.65, {
|
{
|
||||||
"image_name": "cats_and_kitchen.jpg",
|
"image_name": "cats_and_kitchen.jpg",
|
||||||
"image_path": "/path/to/cats_and_kitchen.jpg",
|
"image_path": "/path/to/cats_and_kitchen.jpg",
|
||||||
"patch_id": 15,
|
"patch_id": 12,
|
||||||
"coordinates": [50, 250, 174, 374], # Furniture
|
"coordinates": [150, 100, 274, 224], # Appliances
|
||||||
"attention_score": 0.70,
|
"attention_score": 0.75,
|
||||||
"scale": 1.0
|
"scale": 1.0,
|
||||||
}),
|
},
|
||||||
|
),
|
||||||
# Image 2: city_street.jpg - 3 patches
|
MockResult(
|
||||||
MockResult(0.68, {
|
0.65,
|
||||||
"image_name": "city_street.jpg",
|
{
|
||||||
"image_path": "/path/to/city_street.jpg",
|
"image_name": "cats_and_kitchen.jpg",
|
||||||
"patch_id": 2,
|
"image_path": "/path/to/cats_and_kitchen.jpg",
|
||||||
"coordinates": [300, 100, 424, 224], # Buildings
|
"patch_id": 15,
|
||||||
"attention_score": 0.80,
|
"coordinates": [50, 250, 174, 374], # Furniture
|
||||||
"scale": 1.0
|
"attention_score": 0.70,
|
||||||
}),
|
"scale": 1.0,
|
||||||
MockResult(0.62, {
|
},
|
||||||
"image_name": "city_street.jpg",
|
),
|
||||||
"image_path": "/path/to/city_street.jpg",
|
# Image 2: city_street.jpg - 3 patches
|
||||||
"patch_id": 8,
|
MockResult(
|
||||||
"coordinates": [100, 350, 224, 474], # Street level
|
0.68,
|
||||||
"attention_score": 0.75,
|
{
|
||||||
"scale": 1.0
|
"image_name": "city_street.jpg",
|
||||||
}),
|
"image_path": "/path/to/city_street.jpg",
|
||||||
MockResult(0.55, {
|
"patch_id": 2,
|
||||||
"image_name": "city_street.jpg",
|
"coordinates": [300, 100, 424, 224], # Buildings
|
||||||
"image_path": "/path/to/city_street.jpg",
|
"attention_score": 0.80,
|
||||||
"patch_id": 11,
|
"scale": 1.0,
|
||||||
"coordinates": [400, 200, 524, 324], # Sky area
|
},
|
||||||
"attention_score": 0.60,
|
),
|
||||||
"scale": 1.0
|
MockResult(
|
||||||
}),
|
0.62,
|
||||||
|
{
|
||||||
|
"image_name": "city_street.jpg",
|
||||||
|
"image_path": "/path/to/city_street.jpg",
|
||||||
|
"patch_id": 8,
|
||||||
|
"coordinates": [100, 350, 224, 474], # Street level
|
||||||
|
"attention_score": 0.75,
|
||||||
|
"scale": 1.0,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
MockResult(
|
||||||
|
0.55,
|
||||||
|
{
|
||||||
|
"image_name": "city_street.jpg",
|
||||||
|
"image_path": "/path/to/city_street.jpg",
|
||||||
|
"patch_id": 11,
|
||||||
|
"coordinates": [400, 200, 524, 324], # Sky area
|
||||||
|
"attention_score": 0.60,
|
||||||
|
"scale": 1.0,
|
||||||
|
},
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
# Test different aggregation methods
|
# Test different aggregation methods
|
||||||
methods = ["maxsim", "voting", "weighted", "mean"]
|
methods = ["maxsim", "voting", "weighted", "mean"]
|
||||||
|
|
||||||
for method in methods:
|
for method in methods:
|
||||||
print(f"\n{'='*20} {method.upper()} AGGREGATION {'='*20}")
|
print(f"\n{'=' * 20} {method.upper()} AGGREGATION {'=' * 20}")
|
||||||
|
|
||||||
aggregator = MultiVectorAggregator(
|
aggregator = MultiVectorAggregator(
|
||||||
aggregation_method=method,
|
aggregation_method=method,
|
||||||
spatial_clustering=True,
|
spatial_clustering=True,
|
||||||
cluster_distance_threshold=100.0
|
cluster_distance_threshold=100.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
aggregated = aggregator.aggregate_results(mock_results, top_k=5)
|
aggregated = aggregator.aggregate_results(mock_results, top_k=5)
|
||||||
aggregator.print_aggregated_results(aggregated)
|
aggregator.print_aggregated_results(aggregated)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
demo_aggregation()
|
demo_aggregation()
|
||||||
|
|||||||
@@ -6,22 +6,24 @@ Complete example showing how to build and search with OpenAI embeddings using HN
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import dotenv
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import dotenv
|
||||||
from leann.api import LeannBuilder, LeannSearcher
|
from leann.api import LeannBuilder, LeannSearcher
|
||||||
|
|
||||||
# Load environment variables
|
# Load environment variables
|
||||||
dotenv.load_dotenv()
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# Check if OpenAI API key is available
|
# Check if OpenAI API key is available
|
||||||
api_key = os.getenv("OPENAI_API_KEY")
|
api_key = os.getenv("OPENAI_API_KEY")
|
||||||
if not api_key:
|
if not api_key:
|
||||||
print("ERROR: OPENAI_API_KEY environment variable not set")
|
print("ERROR: OPENAI_API_KEY environment variable not set")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
print(f"✅ OpenAI API key found: {api_key[:10]}...")
|
print(f"✅ OpenAI API key found: {api_key[:10]}...")
|
||||||
|
|
||||||
# Sample texts
|
# Sample texts
|
||||||
sample_texts = [
|
sample_texts = [
|
||||||
"Machine learning is a powerful technology that enables computers to learn from data.",
|
"Machine learning is a powerful technology that enables computers to learn from data.",
|
||||||
@@ -33,15 +35,15 @@ def main():
|
|||||||
"Artificial intelligence aims to create machines that can perform human-like tasks.",
|
"Artificial intelligence aims to create machines that can perform human-like tasks.",
|
||||||
"Python is a popular programming language used extensively in data science and AI.",
|
"Python is a popular programming language used extensively in data science and AI.",
|
||||||
"Neural networks are inspired by the structure and function of the human brain.",
|
"Neural networks are inspired by the structure and function of the human brain.",
|
||||||
"Big data refers to extremely large datasets that require special tools to process."
|
"Big data refers to extremely large datasets that require special tools to process.",
|
||||||
]
|
]
|
||||||
|
|
||||||
INDEX_DIR = Path("./simple_openai_test_index")
|
INDEX_DIR = Path("./simple_openai_test_index")
|
||||||
INDEX_PATH = str(INDEX_DIR / "simple_test.leann")
|
INDEX_PATH = str(INDEX_DIR / "simple_test.leann")
|
||||||
|
|
||||||
print(f"\n=== Building Index with OpenAI Embeddings ===")
|
print("\n=== Building Index with OpenAI Embeddings ===")
|
||||||
print(f"Index path: {INDEX_PATH}")
|
print(f"Index path: {INDEX_PATH}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Use proper configuration for OpenAI embeddings
|
# Use proper configuration for OpenAI embeddings
|
||||||
builder = LeannBuilder(
|
builder = LeannBuilder(
|
||||||
@@ -49,60 +51,63 @@ def main():
|
|||||||
embedding_model="text-embedding-3-small",
|
embedding_model="text-embedding-3-small",
|
||||||
embedding_mode="openai",
|
embedding_mode="openai",
|
||||||
# HNSW settings for OpenAI embeddings
|
# HNSW settings for OpenAI embeddings
|
||||||
M=16, # Smaller graph degree
|
M=16, # Smaller graph degree
|
||||||
efConstruction=64, # Smaller construction complexity
|
efConstruction=64, # Smaller construction complexity
|
||||||
is_compact=True, # Enable compact storage for recompute
|
is_compact=True, # Enable compact storage for recompute
|
||||||
is_recompute=True, # MUST enable for OpenAI embeddings
|
is_recompute=True, # MUST enable for OpenAI embeddings
|
||||||
num_threads=1,
|
num_threads=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"Adding {len(sample_texts)} texts to the index...")
|
print(f"Adding {len(sample_texts)} texts to the index...")
|
||||||
for i, text in enumerate(sample_texts):
|
for i, text in enumerate(sample_texts):
|
||||||
metadata = {"id": f"doc_{i}", "topic": "AI"}
|
metadata = {"id": f"doc_{i}", "topic": "AI"}
|
||||||
builder.add_text(text, metadata)
|
builder.add_text(text, metadata)
|
||||||
|
|
||||||
print("Building index...")
|
print("Building index...")
|
||||||
builder.build_index(INDEX_PATH)
|
builder.build_index(INDEX_PATH)
|
||||||
print(f"✅ Index built successfully!")
|
print("✅ Index built successfully!")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"❌ Error building index: {e}")
|
print(f"❌ Error building index: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
print(f"\n=== Testing Search ===")
|
print("\n=== Testing Search ===")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
searcher = LeannSearcher(INDEX_PATH)
|
searcher = LeannSearcher(INDEX_PATH)
|
||||||
|
|
||||||
test_queries = [
|
test_queries = [
|
||||||
"What is machine learning?",
|
"What is machine learning?",
|
||||||
"How do neural networks work?",
|
"How do neural networks work?",
|
||||||
"Programming languages for data science"
|
"Programming languages for data science",
|
||||||
]
|
]
|
||||||
|
|
||||||
for query in test_queries:
|
for query in test_queries:
|
||||||
print(f"\n🔍 Query: '{query}'")
|
print(f"\n🔍 Query: '{query}'")
|
||||||
results = searcher.search(query, top_k=3)
|
results = searcher.search(query, top_k=3)
|
||||||
|
|
||||||
print(f" Found {len(results)} results:")
|
print(f" Found {len(results)} results:")
|
||||||
for i, result in enumerate(results):
|
for i, result in enumerate(results):
|
||||||
print(f" {i+1}. Score: {result.score:.4f}")
|
print(f" {i + 1}. Score: {result.score:.4f}")
|
||||||
print(f" Text: {result.text[:80]}...")
|
print(f" Text: {result.text[:80]}...")
|
||||||
|
|
||||||
print(f"\n✅ Search test completed successfully!")
|
print("\n✅ Search test completed successfully!")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"❌ Error during search: {e}")
|
print(f"❌ Error during search: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
success = main()
|
success = main()
|
||||||
if success:
|
if success:
|
||||||
print(f"\n🎉 Simple OpenAI index test completed successfully!")
|
print("\n🎉 Simple OpenAI index test completed successfully!")
|
||||||
else:
|
else:
|
||||||
print(f"\n💥 Simple OpenAI index test failed!")
|
print("\n💥 Simple OpenAI index test failed!")
|
||||||
|
|||||||
@@ -1,18 +1,23 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from leann.api import LeannChat
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
from leann.api import LeannChat
|
||||||
|
|
||||||
INDEX_DIR = Path("./test_pdf_index_huawei")
|
INDEX_DIR = Path("./test_pdf_index_huawei")
|
||||||
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
|
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
print(f"\n[PHASE 2] Starting Leann chat session...")
|
print("\n[PHASE 2] Starting Leann chat session...")
|
||||||
chat = LeannChat(index_path=INDEX_PATH)
|
chat = LeannChat(index_path=INDEX_PATH)
|
||||||
query = "What is the main idea of RL and give me 5 exapmle of classic RL algorithms?"
|
query = "What is the main idea of RL and give me 5 exapmle of classic RL algorithms?"
|
||||||
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
|
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
|
||||||
# query = "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
|
# query = "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
|
||||||
response = chat.ask(query,top_k=20,recompute_beighbor_embeddings=True,complexity=32,beam_width=1)
|
response = chat.ask(
|
||||||
|
query, top_k=20, recompute_beighbor_embeddings=True, complexity=32, beam_width=1
|
||||||
|
)
|
||||||
print(f"\n[PHASE 2] Response: {response}")
|
print(f"\n[PHASE 2] Response: {response}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
|||||||
@@ -5,24 +5,21 @@ It correctly compares results by fetching the text content for both the new sear
|
|||||||
results and the golden standard results, making the comparison robust to ID changes.
|
results and the golden standard results, making the comparison robust to ID changes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import sys
|
|
||||||
import numpy as np
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from leann.api import LeannSearcher, LeannBuilder
|
import numpy as np
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher
|
||||||
|
|
||||||
|
|
||||||
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
||||||
"""Checks if the data directory exists, and if not, downloads it from HF Hub."""
|
"""Checks if the data directory exists, and if not, downloads it from HF Hub."""
|
||||||
if not data_root.exists():
|
if not data_root.exists():
|
||||||
print(f"Data directory '{data_root}' not found.")
|
print(f"Data directory '{data_root}' not found.")
|
||||||
print(
|
print("Downloading evaluation data from Hugging Face Hub... (this may take a moment)")
|
||||||
"Downloading evaluation data from Hugging Face Hub... (this may take a moment)"
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
@@ -63,7 +60,7 @@ def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def download_embeddings_if_needed(data_root: Path, dataset_type: str = None):
|
def download_embeddings_if_needed(data_root: Path, dataset_type: str | None = None):
|
||||||
"""Download embeddings files specifically."""
|
"""Download embeddings files specifically."""
|
||||||
embeddings_dir = data_root / "embeddings"
|
embeddings_dir = data_root / "embeddings"
|
||||||
|
|
||||||
@@ -101,7 +98,7 @@ def download_embeddings_if_needed(data_root: Path, dataset_type: str = None):
|
|||||||
|
|
||||||
|
|
||||||
# --- Helper Function to get Golden Passages ---
|
# --- Helper Function to get Golden Passages ---
|
||||||
def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set:
|
def get_golden_texts(searcher: LeannSearcher, golden_ids: list[int]) -> set:
|
||||||
"""
|
"""
|
||||||
Retrieves the text for golden passage IDs directly from the LeannSearcher's
|
Retrieves the text for golden passage IDs directly from the LeannSearcher's
|
||||||
passage manager.
|
passage manager.
|
||||||
@@ -113,24 +110,20 @@ def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set:
|
|||||||
passage_data = searcher.passage_manager.get_passage(str(gid))
|
passage_data = searcher.passage_manager.get_passage(str(gid))
|
||||||
golden_texts.add(passage_data["text"])
|
golden_texts.add(passage_data["text"])
|
||||||
except KeyError:
|
except KeyError:
|
||||||
print(
|
print(f"Warning: Golden passage ID '{gid}' not found in the index's passage data.")
|
||||||
f"Warning: Golden passage ID '{gid}' not found in the index's passage data."
|
|
||||||
)
|
|
||||||
return golden_texts
|
return golden_texts
|
||||||
|
|
||||||
|
|
||||||
def load_queries(file_path: Path) -> List[str]:
|
def load_queries(file_path: Path) -> list[str]:
|
||||||
queries = []
|
queries = []
|
||||||
with open(file_path, "r", encoding="utf-8") as f:
|
with open(file_path, encoding="utf-8") as f:
|
||||||
for line in f:
|
for line in f:
|
||||||
data = json.loads(line)
|
data = json.loads(line)
|
||||||
queries.append(data["query"])
|
queries.append(data["query"])
|
||||||
return queries
|
return queries
|
||||||
|
|
||||||
|
|
||||||
def build_index_from_embeddings(
|
def build_index_from_embeddings(embeddings_file: str, output_path: str, backend: str = "hnsw"):
|
||||||
embeddings_file: str, output_path: str, backend: str = "hnsw"
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Build a LEANN index from pre-computed embeddings.
|
Build a LEANN index from pre-computed embeddings.
|
||||||
|
|
||||||
@@ -173,9 +166,7 @@ def build_index_from_embeddings(
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(description="Run recall evaluation on a LEANN index.")
|
||||||
description="Run recall evaluation on a LEANN index."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"index_path",
|
"index_path",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -202,9 +193,7 @@ def main():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-queries", type=int, default=10, help="Number of queries to evaluate."
|
"--num-queries", type=int, default=10, help="Number of queries to evaluate."
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument("--top-k", type=int, default=3, help="The 'k' value for recall@k.")
|
||||||
"--top-k", type=int, default=3, help="The 'k' value for recall@k."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
|
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
|
||||||
)
|
)
|
||||||
@@ -219,9 +208,7 @@ def main():
|
|||||||
# Download data based on mode
|
# Download data based on mode
|
||||||
if args.mode == "build":
|
if args.mode == "build":
|
||||||
# For building mode, we need embeddings
|
# For building mode, we need embeddings
|
||||||
download_data_if_needed(
|
download_data_if_needed(data_root, download_embeddings=False) # Basic data first
|
||||||
data_root, download_embeddings=False
|
|
||||||
) # Basic data first
|
|
||||||
|
|
||||||
# Auto-detect dataset type and download embeddings
|
# Auto-detect dataset type and download embeddings
|
||||||
if args.embeddings_file:
|
if args.embeddings_file:
|
||||||
@@ -262,9 +249,7 @@ def main():
|
|||||||
print(f"Index built successfully: {built_index_path}")
|
print(f"Index built successfully: {built_index_path}")
|
||||||
|
|
||||||
# Ask if user wants to run evaluation
|
# Ask if user wants to run evaluation
|
||||||
eval_response = (
|
eval_response = input("Run evaluation on the built index? (y/n): ").strip().lower()
|
||||||
input("Run evaluation on the built index? (y/n): ").strip().lower()
|
|
||||||
)
|
|
||||||
if eval_response != "y":
|
if eval_response != "y":
|
||||||
print("Index building complete. Exiting.")
|
print("Index building complete. Exiting.")
|
||||||
return
|
return
|
||||||
@@ -293,12 +278,8 @@ def main():
|
|||||||
break
|
break
|
||||||
|
|
||||||
if not args.index_path:
|
if not args.index_path:
|
||||||
print(
|
print("No indices found. The data download should have included pre-built indices.")
|
||||||
"No indices found. The data download should have included pre-built indices."
|
print("Please check the data/indices/ directory or provide --index-path manually.")
|
||||||
)
|
|
||||||
print(
|
|
||||||
"Please check the data/indices/ directory or provide --index-path manually."
|
|
||||||
)
|
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
# Detect dataset type from index path to select the correct ground truth
|
# Detect dataset type from index path to select the correct ground truth
|
||||||
@@ -310,14 +291,10 @@ def main():
|
|||||||
else:
|
else:
|
||||||
# Fallback: try to infer from the index directory name
|
# Fallback: try to infer from the index directory name
|
||||||
dataset_type = Path(args.index_path).name
|
dataset_type = Path(args.index_path).name
|
||||||
print(
|
print(f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'.")
|
||||||
f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'."
|
|
||||||
)
|
|
||||||
|
|
||||||
queries_file = data_root / "queries" / "nq_open.jsonl"
|
queries_file = data_root / "queries" / "nq_open.jsonl"
|
||||||
golden_results_file = (
|
golden_results_file = data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
|
||||||
data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"INFO: Detected dataset type: {dataset_type}")
|
print(f"INFO: Detected dataset type: {dataset_type}")
|
||||||
print(f"INFO: Using queries file: {queries_file}")
|
print(f"INFO: Using queries file: {queries_file}")
|
||||||
@@ -327,7 +304,7 @@ def main():
|
|||||||
searcher = LeannSearcher(args.index_path)
|
searcher = LeannSearcher(args.index_path)
|
||||||
queries = load_queries(queries_file)
|
queries = load_queries(queries_file)
|
||||||
|
|
||||||
with open(golden_results_file, "r") as f:
|
with open(golden_results_file) as f:
|
||||||
golden_results_data = json.load(f)
|
golden_results_data = json.load(f)
|
||||||
|
|
||||||
num_eval_queries = min(args.num_queries, len(queries))
|
num_eval_queries = min(args.num_queries, len(queries))
|
||||||
@@ -339,9 +316,7 @@ def main():
|
|||||||
|
|
||||||
for i in range(num_eval_queries):
|
for i in range(num_eval_queries):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
new_results = searcher.search(
|
new_results = searcher.search(queries[i], top_k=args.top_k, ef=args.ef_search)
|
||||||
queries[i], top_k=args.top_k, ef=args.ef_search
|
|
||||||
)
|
|
||||||
search_times.append(time.time() - start_time)
|
search_times.append(time.time() - start_time)
|
||||||
|
|
||||||
# Correct Recall Calculation: Based on TEXT content
|
# Correct Recall Calculation: Based on TEXT content
|
||||||
|
|||||||
@@ -4,18 +4,25 @@ Run: uv run python examples/simple_demo.py
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
from leann import LeannBuilder, LeannSearcher, LeannChat
|
|
||||||
|
from leann import LeannBuilder, LeannChat, LeannSearcher
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description="Simple demo of Leann with selectable embedding models.")
|
parser = argparse.ArgumentParser(
|
||||||
parser.add_argument("--embedding_model", type=str, default="sentence-transformers/all-mpnet-base-v2",
|
description="Simple demo of Leann with selectable embedding models."
|
||||||
help="The embedding model to use, e.g., 'sentence-transformers/all-mpnet-base-v2' or 'text-embedding-ada-002'.")
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding_model",
|
||||||
|
type=str,
|
||||||
|
default="sentence-transformers/all-mpnet-base-v2",
|
||||||
|
help="The embedding model to use, e.g., 'sentence-transformers/all-mpnet-base-v2' or 'text-embedding-ada-002'.",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
print(f"=== Leann Simple Demo with {args.embedding_model} ===")
|
print(f"=== Leann Simple Demo with {args.embedding_model} ===")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
# Sample knowledge base
|
# Sample knowledge base
|
||||||
chunks = [
|
chunks = [
|
||||||
"Machine learning is a subset of artificial intelligence that enables computers to learn without being explicitly programmed.",
|
"Machine learning is a subset of artificial intelligence that enables computers to learn without being explicitly programmed.",
|
||||||
@@ -27,7 +34,7 @@ def main():
|
|||||||
"Big data refers to extremely large datasets that require special tools and techniques to process.",
|
"Big data refers to extremely large datasets that require special tools and techniques to process.",
|
||||||
"Cloud computing provides on-demand access to computing resources over the internet.",
|
"Cloud computing provides on-demand access to computing resources over the internet.",
|
||||||
]
|
]
|
||||||
|
|
||||||
print("1. Building index (no embeddings stored)...")
|
print("1. Building index (no embeddings stored)...")
|
||||||
builder = LeannBuilder(
|
builder = LeannBuilder(
|
||||||
embedding_model=args.embedding_model,
|
embedding_model=args.embedding_model,
|
||||||
@@ -37,45 +44,45 @@ def main():
|
|||||||
builder.add_text(chunk)
|
builder.add_text(chunk)
|
||||||
builder.build_index("demo_knowledge.leann")
|
builder.build_index("demo_knowledge.leann")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
print("2. Searching with real-time embeddings...")
|
print("2. Searching with real-time embeddings...")
|
||||||
searcher = LeannSearcher("demo_knowledge.leann")
|
searcher = LeannSearcher("demo_knowledge.leann")
|
||||||
|
|
||||||
queries = [
|
queries = [
|
||||||
"What is machine learning?",
|
"What is machine learning?",
|
||||||
"How does neural network work?",
|
"How does neural network work?",
|
||||||
"Tell me about data processing",
|
"Tell me about data processing",
|
||||||
]
|
]
|
||||||
|
|
||||||
for query in queries:
|
for query in queries:
|
||||||
print(f"Query: {query}")
|
print(f"Query: {query}")
|
||||||
results = searcher.search(query, top_k=2)
|
results = searcher.search(query, top_k=2)
|
||||||
|
|
||||||
for i, result in enumerate(results, 1):
|
for i, result in enumerate(results, 1):
|
||||||
print(f" {i}. Score: {result.score:.3f}")
|
print(f" {i}. Score: {result.score:.3f}")
|
||||||
print(f" Text: {result.text[:100]}...")
|
print(f" Text: {result.text[:100]}...")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
print("3. Interactive chat demo:")
|
print("3. Interactive chat demo:")
|
||||||
print(" (Note: Requires OpenAI API key for real responses)")
|
print(" (Note: Requires OpenAI API key for real responses)")
|
||||||
|
|
||||||
chat = LeannChat("demo_knowledge.leann")
|
chat = LeannChat("demo_knowledge.leann")
|
||||||
|
|
||||||
# Demo questions
|
# Demo questions
|
||||||
demo_questions: list[str] = [
|
demo_questions: list[str] = [
|
||||||
"What is the difference between machine learning and deep learning?",
|
"What is the difference between machine learning and deep learning?",
|
||||||
"How is data science related to big data?",
|
"How is data science related to big data?",
|
||||||
]
|
]
|
||||||
|
|
||||||
for question in demo_questions:
|
for question in demo_questions:
|
||||||
print(f" Q: {question}")
|
print(f" Q: {question}")
|
||||||
response = chat.ask(question)
|
response = chat.ask(question)
|
||||||
print(f" A: {response}")
|
print(f" A: {response}")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
print("Demo completed! Try running:")
|
print("Demo completed! Try running:")
|
||||||
print(" uv run python examples/document_search.py")
|
print(" uv run python examples/document_search.py")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -1,13 +1,11 @@
|
|||||||
import os
|
|
||||||
import asyncio
|
|
||||||
import dotenv
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Any, Optional
|
|
||||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
import dotenv
|
||||||
|
from leann.api import LeannBuilder, LeannChat
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
import requests
|
|
||||||
import time
|
|
||||||
|
|
||||||
dotenv.load_dotenv()
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
@@ -16,7 +14,7 @@ DEFAULT_WECHAT_EXPORT_DIR = "./wechat_export_direct"
|
|||||||
|
|
||||||
|
|
||||||
def create_leann_index_from_multiple_wechat_exports(
|
def create_leann_index_from_multiple_wechat_exports(
|
||||||
export_dirs: List[Path],
|
export_dirs: list[Path],
|
||||||
index_path: str = "wechat_history_index.leann",
|
index_path: str = "wechat_history_index.leann",
|
||||||
max_count: int = -1,
|
max_count: int = -1,
|
||||||
):
|
):
|
||||||
@@ -38,15 +36,13 @@ def create_leann_index_from_multiple_wechat_exports(
|
|||||||
INDEX_DIR = Path(index_path).parent
|
INDEX_DIR = Path(index_path).parent
|
||||||
|
|
||||||
if not INDEX_DIR.exists():
|
if not INDEX_DIR.exists():
|
||||||
print(f"--- Index directory not found, building new index ---")
|
print("--- Index directory not found, building new index ---")
|
||||||
all_documents = []
|
all_documents = []
|
||||||
total_processed = 0
|
total_processed = 0
|
||||||
|
|
||||||
# Process each WeChat export directory
|
# Process each WeChat export directory
|
||||||
for i, export_dir in enumerate(export_dirs):
|
for i, export_dir in enumerate(export_dirs):
|
||||||
print(
|
print(f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}")
|
||||||
f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
documents = reader.load_data(
|
documents = reader.load_data(
|
||||||
@@ -78,7 +74,7 @@ def create_leann_index_from_multiple_wechat_exports(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Create text splitter with 256 chunk size
|
# Create text splitter with 256 chunk size
|
||||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
|
text_splitter = SentenceSplitter(chunk_size=192, chunk_overlap=64)
|
||||||
|
|
||||||
# Convert Documents to text strings and chunk them
|
# Convert Documents to text strings and chunk them
|
||||||
all_texts = []
|
all_texts = []
|
||||||
@@ -86,7 +82,12 @@ def create_leann_index_from_multiple_wechat_exports(
|
|||||||
# Split the document into chunks
|
# Split the document into chunks
|
||||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
text = '[Contact] means the message is from: ' + doc.metadata["contact_name"] + '\n' + node.get_content()
|
text = (
|
||||||
|
"[Contact] means the message is from: "
|
||||||
|
+ doc.metadata["contact_name"]
|
||||||
|
+ "\n"
|
||||||
|
+ node.get_content()
|
||||||
|
)
|
||||||
all_texts.append(text)
|
all_texts.append(text)
|
||||||
|
|
||||||
print(
|
print(
|
||||||
@@ -94,12 +95,12 @@ def create_leann_index_from_multiple_wechat_exports(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Create LEANN index directory
|
# Create LEANN index directory
|
||||||
print(f"--- Index directory not found, building new index ---")
|
print("--- Index directory not found, building new index ---")
|
||||||
INDEX_DIR.mkdir(exist_ok=True)
|
INDEX_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
print(f"--- Building new LEANN index ---")
|
print("--- Building new LEANN index ---")
|
||||||
|
|
||||||
print(f"\n[PHASE 1] Building Leann index...")
|
print("\n[PHASE 1] Building Leann index...")
|
||||||
|
|
||||||
# Use HNSW backend for better macOS compatibility
|
# Use HNSW backend for better macOS compatibility
|
||||||
builder = LeannBuilder(
|
builder = LeannBuilder(
|
||||||
@@ -125,7 +126,7 @@ def create_leann_index_from_multiple_wechat_exports(
|
|||||||
|
|
||||||
|
|
||||||
def create_leann_index(
|
def create_leann_index(
|
||||||
export_dir: str = None,
|
export_dir: str | None = None,
|
||||||
index_path: str = "wechat_history_index.leann",
|
index_path: str = "wechat_history_index.leann",
|
||||||
max_count: int = 1000,
|
max_count: int = 1000,
|
||||||
):
|
):
|
||||||
@@ -141,12 +142,12 @@ def create_leann_index(
|
|||||||
INDEX_DIR = Path(index_path).parent
|
INDEX_DIR = Path(index_path).parent
|
||||||
|
|
||||||
if not INDEX_DIR.exists():
|
if not INDEX_DIR.exists():
|
||||||
print(f"--- Index directory not found, building new index ---")
|
print("--- Index directory not found, building new index ---")
|
||||||
INDEX_DIR.mkdir(exist_ok=True)
|
INDEX_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
print(f"--- Building new LEANN index ---")
|
print("--- Building new LEANN index ---")
|
||||||
|
|
||||||
print(f"\n[PHASE 1] Building Leann index...")
|
print("\n[PHASE 1] Building Leann index...")
|
||||||
|
|
||||||
# Load documents using WeChatHistoryReader from history_data
|
# Load documents using WeChatHistoryReader from history_data
|
||||||
from history_data.wechat_history import WeChatHistoryReader
|
from history_data.wechat_history import WeChatHistoryReader
|
||||||
@@ -179,12 +180,12 @@ def create_leann_index(
|
|||||||
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
|
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
|
||||||
|
|
||||||
# Create LEANN index directory
|
# Create LEANN index directory
|
||||||
print(f"--- Index directory not found, building new index ---")
|
print("--- Index directory not found, building new index ---")
|
||||||
INDEX_DIR.mkdir(exist_ok=True)
|
INDEX_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
print(f"--- Building new LEANN index ---")
|
print("--- Building new LEANN index ---")
|
||||||
|
|
||||||
print(f"\n[PHASE 1] Building Leann index...")
|
print("\n[PHASE 1] Building Leann index...")
|
||||||
|
|
||||||
# Use HNSW backend for better macOS compatibility
|
# Use HNSW backend for better macOS compatibility
|
||||||
builder = LeannBuilder(
|
builder = LeannBuilder(
|
||||||
@@ -217,7 +218,7 @@ async def query_leann_index(index_path: str, query: str):
|
|||||||
index_path: Path to the LEANN index
|
index_path: Path to the LEANN index
|
||||||
query: The query string
|
query: The query string
|
||||||
"""
|
"""
|
||||||
print(f"\n[PHASE 2] Starting Leann chat session...")
|
print("\n[PHASE 2] Starting Leann chat session...")
|
||||||
chat = LeannChat(index_path=index_path)
|
chat = LeannChat(index_path=index_path)
|
||||||
|
|
||||||
print(f"You: {query}")
|
print(f"You: {query}")
|
||||||
@@ -234,7 +235,7 @@ async def query_leann_index(index_path: str, query: str):
|
|||||||
},
|
},
|
||||||
llm_kwargs={"temperature": 0.0, "max_tokens": 1000},
|
llm_kwargs={"temperature": 0.0, "max_tokens": 1000},
|
||||||
)
|
)
|
||||||
print(f"Leann: {chat_response}")
|
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
@@ -307,7 +308,7 @@ async def main():
|
|||||||
else:
|
else:
|
||||||
# Example queries
|
# Example queries
|
||||||
queries = [
|
queries = [
|
||||||
"我想买魔术师约翰逊的球衣,给我一些对应聊天记录?",
|
"我想买魔术师约翰逊的球衣,给我一些对应聊天记录?",
|
||||||
]
|
]
|
||||||
|
|
||||||
for query in queries:
|
for query in queries:
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
|
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
# This file makes the directory a Python package
|
# This file makes the directory a Python package
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
from . import diskann_backend
|
from . import diskann_backend as diskann_backend
|
||||||
|
|||||||
@@ -1,20 +1,19 @@
|
|||||||
import numpy as np
|
import contextlib
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import struct
|
import struct
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Any, List, Literal, Optional
|
from typing import Any, Literal
|
||||||
import contextlib
|
|
||||||
|
|
||||||
import logging
|
import numpy as np
|
||||||
|
|
||||||
from leann.searcher_base import BaseSearcher
|
|
||||||
from leann.registry import register_backend
|
|
||||||
from leann.interface import (
|
from leann.interface import (
|
||||||
LeannBackendFactoryInterface,
|
|
||||||
LeannBackendBuilderInterface,
|
LeannBackendBuilderInterface,
|
||||||
|
LeannBackendFactoryInterface,
|
||||||
LeannBackendSearcherInterface,
|
LeannBackendSearcherInterface,
|
||||||
)
|
)
|
||||||
|
from leann.registry import register_backend
|
||||||
|
from leann.searcher_base import BaseSearcher
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -100,7 +99,7 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
|||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.build_params = kwargs
|
self.build_params = kwargs
|
||||||
|
|
||||||
def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
|
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
|
||||||
path = Path(index_path)
|
path = Path(index_path)
|
||||||
index_dir = path.parent
|
index_dir = path.parent
|
||||||
index_prefix = path.stem
|
index_prefix = path.stem
|
||||||
@@ -164,18 +163,44 @@ class DiskannSearcher(BaseSearcher):
|
|||||||
|
|
||||||
self.num_threads = kwargs.get("num_threads", 8)
|
self.num_threads = kwargs.get("num_threads", 8)
|
||||||
|
|
||||||
fake_zmq_port = 6666
|
# For DiskANN, we need to reinitialize the index when zmq_port changes
|
||||||
|
# Store the initialization parameters for later use
|
||||||
full_index_prefix = str(self.index_dir / self.index_path.stem)
|
full_index_prefix = str(self.index_dir / self.index_path.stem)
|
||||||
self._index = diskannpy.StaticDiskFloatIndex(
|
self._init_params = {
|
||||||
metric_enum,
|
"metric_enum": metric_enum,
|
||||||
full_index_prefix,
|
"full_index_prefix": full_index_prefix,
|
||||||
self.num_threads,
|
"num_threads": self.num_threads,
|
||||||
kwargs.get("num_nodes_to_cache", 0),
|
"num_nodes_to_cache": kwargs.get("num_nodes_to_cache", 0),
|
||||||
1,
|
"cache_mechanism": 1,
|
||||||
fake_zmq_port, # Initial port, can be updated at runtime
|
"pq_prefix": "",
|
||||||
"",
|
"partition_prefix": "",
|
||||||
"",
|
}
|
||||||
)
|
self._diskannpy = diskannpy
|
||||||
|
self._current_zmq_port = None
|
||||||
|
self._index = None
|
||||||
|
logger.debug("DiskANN searcher initialized (index will be loaded on first search)")
|
||||||
|
|
||||||
|
def _ensure_index_loaded(self, zmq_port: int):
|
||||||
|
"""Ensure the index is loaded with the correct zmq_port."""
|
||||||
|
if self._index is None or self._current_zmq_port != zmq_port:
|
||||||
|
# Need to (re)load the index with the correct zmq_port
|
||||||
|
with suppress_cpp_output_if_needed():
|
||||||
|
if self._index is not None:
|
||||||
|
logger.debug(f"Reloading DiskANN index with new zmq_port: {zmq_port}")
|
||||||
|
else:
|
||||||
|
logger.debug(f"Loading DiskANN index with zmq_port: {zmq_port}")
|
||||||
|
|
||||||
|
self._index = self._diskannpy.StaticDiskFloatIndex(
|
||||||
|
self._init_params["metric_enum"],
|
||||||
|
self._init_params["full_index_prefix"],
|
||||||
|
self._init_params["num_threads"],
|
||||||
|
self._init_params["num_nodes_to_cache"],
|
||||||
|
self._init_params["cache_mechanism"],
|
||||||
|
zmq_port,
|
||||||
|
self._init_params["pq_prefix"],
|
||||||
|
self._init_params["partition_prefix"],
|
||||||
|
)
|
||||||
|
self._current_zmq_port = zmq_port
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
@@ -186,11 +211,11 @@ class DiskannSearcher(BaseSearcher):
|
|||||||
prune_ratio: float = 0.0,
|
prune_ratio: float = 0.0,
|
||||||
recompute_embeddings: bool = False,
|
recompute_embeddings: bool = False,
|
||||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
zmq_port: Optional[int] = None,
|
zmq_port: int | None = None,
|
||||||
batch_recompute: bool = False,
|
batch_recompute: bool = False,
|
||||||
dedup_node_dis: bool = False,
|
dedup_node_dis: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Search for nearest neighbors using DiskANN index.
|
Search for nearest neighbors using DiskANN index.
|
||||||
|
|
||||||
@@ -213,18 +238,15 @@ class DiskannSearcher(BaseSearcher):
|
|||||||
Returns:
|
Returns:
|
||||||
Dict with 'labels' (list of lists) and 'distances' (ndarray)
|
Dict with 'labels' (list of lists) and 'distances' (ndarray)
|
||||||
"""
|
"""
|
||||||
# Handle zmq_port compatibility: DiskANN can now update port at runtime
|
# Handle zmq_port compatibility: Ensure index is loaded with correct port
|
||||||
if recompute_embeddings:
|
if recompute_embeddings:
|
||||||
if zmq_port is None:
|
if zmq_port is None:
|
||||||
raise ValueError(
|
raise ValueError("zmq_port must be provided if recompute_embeddings is True")
|
||||||
"zmq_port must be provided if recompute_embeddings is True"
|
self._ensure_index_loaded(zmq_port)
|
||||||
)
|
else:
|
||||||
current_port = self._index.get_zmq_port()
|
# If not recomputing, we still need an index, use a default port
|
||||||
if zmq_port != current_port:
|
if self._index is None:
|
||||||
logger.debug(
|
self._ensure_index_loaded(6666) # Default port when not recomputing
|
||||||
f"Updating DiskANN zmq_port from {current_port} to {zmq_port}"
|
|
||||||
)
|
|
||||||
self._index.set_zmq_port(zmq_port)
|
|
||||||
|
|
||||||
# DiskANN doesn't support "proportional" strategy
|
# DiskANN doesn't support "proportional" strategy
|
||||||
if pruning_strategy == "proportional":
|
if pruning_strategy == "proportional":
|
||||||
@@ -259,8 +281,6 @@ class DiskannSearcher(BaseSearcher):
|
|||||||
use_global_pruning,
|
use_global_pruning,
|
||||||
)
|
)
|
||||||
|
|
||||||
string_labels = [
|
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
|
||||||
[str(int_label) for int_label in batch_labels] for batch_labels in labels
|
|
||||||
]
|
|
||||||
|
|
||||||
return {"labels": string_labels, "distances": distances}
|
return {"labels": string_labels, "distances": distances}
|
||||||
|
|||||||
@@ -3,16 +3,16 @@ DiskANN-specific embedding server
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import os
|
|
||||||
import zmq
|
|
||||||
import numpy as np
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
|
||||||
import sys
|
import numpy as np
|
||||||
import logging
|
import zmq
|
||||||
|
|
||||||
# Set up logging based on environment variable
|
# Set up logging based on environment variable
|
||||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||||
@@ -32,10 +32,11 @@ if not logger.handlers:
|
|||||||
|
|
||||||
|
|
||||||
def create_diskann_embedding_server(
|
def create_diskann_embedding_server(
|
||||||
passages_file: Optional[str] = None,
|
passages_file: str | None = None,
|
||||||
zmq_port: int = 5555,
|
zmq_port: int = 5555,
|
||||||
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
||||||
embedding_mode: str = "sentence-transformers",
|
embedding_mode: str = "sentence-transformers",
|
||||||
|
distance_metric: str = "l2",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create and start a ZMQ-based embedding server for DiskANN backend.
|
Create and start a ZMQ-based embedding server for DiskANN backend.
|
||||||
@@ -50,8 +51,8 @@ def create_diskann_embedding_server(
|
|||||||
sys.path.insert(0, str(leann_core_path))
|
sys.path.insert(0, str(leann_core_path))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from leann.embedding_compute import compute_embeddings
|
|
||||||
from leann.api import PassageManager
|
from leann.api import PassageManager
|
||||||
|
from leann.embedding_compute import compute_embeddings
|
||||||
|
|
||||||
logger.info("Successfully imported unified embedding computation module")
|
logger.info("Successfully imported unified embedding computation module")
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
@@ -76,7 +77,7 @@ def create_diskann_embedding_server(
|
|||||||
raise ValueError("Only metadata files (.meta.json) are supported")
|
raise ValueError("Only metadata files (.meta.json) are supported")
|
||||||
|
|
||||||
# Load metadata to get passage sources
|
# Load metadata to get passage sources
|
||||||
with open(passages_file, "r") as f:
|
with open(passages_file) as f:
|
||||||
meta = json.load(f)
|
meta = json.load(f)
|
||||||
|
|
||||||
passages = PassageManager(meta["passage_sources"])
|
passages = PassageManager(meta["passage_sources"])
|
||||||
@@ -150,9 +151,7 @@ def create_diskann_embedding_server(
|
|||||||
):
|
):
|
||||||
texts = request
|
texts = request
|
||||||
is_text_request = True
|
is_text_request = True
|
||||||
logger.info(
|
logger.info(f"✅ MSGPACK: Direct text request for {len(texts)} texts")
|
||||||
f"✅ MSGPACK: Direct text request for {len(texts)} texts"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("Not a valid msgpack text request")
|
raise ValueError("Not a valid msgpack text request")
|
||||||
except Exception as msgpack_error:
|
except Exception as msgpack_error:
|
||||||
@@ -167,9 +166,7 @@ def create_diskann_embedding_server(
|
|||||||
passage_data = passages.get_passage(str(nid))
|
passage_data = passages.get_passage(str(nid))
|
||||||
txt = passage_data["text"]
|
txt = passage_data["text"]
|
||||||
if not txt:
|
if not txt:
|
||||||
raise RuntimeError(
|
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
|
||||||
f"FATAL: Empty text for passage ID {nid}"
|
|
||||||
)
|
|
||||||
texts.append(txt)
|
texts.append(txt)
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
logger.error(f"Passage ID {nid} not found: {e}")
|
logger.error(f"Passage ID {nid} not found: {e}")
|
||||||
@@ -180,9 +177,7 @@ def create_diskann_embedding_server(
|
|||||||
|
|
||||||
# Debug logging
|
# Debug logging
|
||||||
logger.debug(f"Processing {len(texts)} texts")
|
logger.debug(f"Processing {len(texts)} texts")
|
||||||
logger.debug(
|
logger.debug(f"Text lengths: {[len(t) for t in texts[:5]]}") # Show first 5
|
||||||
f"Text lengths: {[len(t) for t in texts[:5]]}"
|
|
||||||
) # Show first 5
|
|
||||||
|
|
||||||
# Process embeddings using unified computation
|
# Process embeddings using unified computation
|
||||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||||
@@ -199,9 +194,7 @@ def create_diskann_embedding_server(
|
|||||||
else:
|
else:
|
||||||
# For DiskANN C++ compatibility: return protobuf format
|
# For DiskANN C++ compatibility: return protobuf format
|
||||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||||
hidden_contiguous = np.ascontiguousarray(
|
hidden_contiguous = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||||
embeddings, dtype=np.float32
|
|
||||||
)
|
|
||||||
|
|
||||||
# Serialize embeddings data
|
# Serialize embeddings data
|
||||||
resp_proto.embeddings_data = hidden_contiguous.tobytes()
|
resp_proto.embeddings_data = hidden_contiguous.tobytes()
|
||||||
@@ -271,6 +264,13 @@ if __name__ == "__main__":
|
|||||||
choices=["sentence-transformers", "openai", "mlx"],
|
choices=["sentence-transformers", "openai", "mlx"],
|
||||||
help="Embedding backend mode",
|
help="Embedding backend mode",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--distance-metric",
|
||||||
|
type=str,
|
||||||
|
default="l2",
|
||||||
|
choices=["l2", "mips", "cosine"],
|
||||||
|
help="Distance metric for similarity computation",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@@ -280,4 +280,5 @@ if __name__ == "__main__":
|
|||||||
zmq_port=args.zmq_port,
|
zmq_port=args.zmq_port,
|
||||||
model_name=args.model_name,
|
model_name=args.model_name,
|
||||||
embedding_mode=args.embedding_mode,
|
embedding_mode=args.embedding_mode,
|
||||||
|
distance_metric=args.distance_metric,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,27 +1,28 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||||
# source: embedding.proto
|
# source: embedding.proto
|
||||||
|
# ruff: noqa
|
||||||
"""Generated protocol buffer code."""
|
"""Generated protocol buffer code."""
|
||||||
from google.protobuf.internal import builder as _builder
|
|
||||||
from google.protobuf import descriptor as _descriptor
|
from google.protobuf import descriptor as _descriptor
|
||||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||||
from google.protobuf import symbol_database as _symbol_database
|
from google.protobuf import symbol_database as _symbol_database
|
||||||
|
from google.protobuf.internal import builder as _builder
|
||||||
|
|
||||||
# @@protoc_insertion_point(imports)
|
# @@protoc_insertion_point(imports)
|
||||||
|
|
||||||
_sym_db = _symbol_database.Default()
|
_sym_db = _symbol_database.Default()
|
||||||
|
|
||||||
|
|
||||||
|
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
|
||||||
|
b'\n\x0f\x65mbedding.proto\x12\x0eprotoembedding"(\n\x14NodeEmbeddingRequest\x12\x10\n\x08node_ids\x18\x01 \x03(\r"Y\n\x15NodeEmbeddingResponse\x12\x17\n\x0f\x65mbeddings_data\x18\x01 \x01(\x0c\x12\x12\n\ndimensions\x18\x02 \x03(\x05\x12\x13\n\x0bmissing_ids\x18\x03 \x03(\rb\x06proto3'
|
||||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0f\x65mbedding.proto\x12\x0eprotoembedding\"(\n\x14NodeEmbeddingRequest\x12\x10\n\x08node_ids\x18\x01 \x03(\r\"Y\n\x15NodeEmbeddingResponse\x12\x17\n\x0f\x65mbeddings_data\x18\x01 \x01(\x0c\x12\x12\n\ndimensions\x18\x02 \x03(\x05\x12\x13\n\x0bmissing_ids\x18\x03 \x03(\rb\x06proto3')
|
)
|
||||||
|
|
||||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
||||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'embedding_pb2', globals())
|
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "embedding_pb2", globals())
|
||||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
if not _descriptor._USE_C_DESCRIPTORS:
|
||||||
|
DESCRIPTOR._options = None
|
||||||
DESCRIPTOR._options = None
|
_NODEEMBEDDINGREQUEST._serialized_start = 35
|
||||||
_NODEEMBEDDINGREQUEST._serialized_start=35
|
_NODEEMBEDDINGREQUEST._serialized_end = 75
|
||||||
_NODEEMBEDDINGREQUEST._serialized_end=75
|
_NODEEMBEDDINGRESPONSE._serialized_start = 77
|
||||||
_NODEEMBEDDINGRESPONSE._serialized_start=77
|
_NODEEMBEDDINGRESPONSE._serialized_end = 166
|
||||||
_NODEEMBEDDINGRESPONSE._serialized_end=166
|
|
||||||
# @@protoc_insertion_point(module_scope)
|
# @@protoc_insertion_point(module_scope)
|
||||||
|
|||||||
@@ -4,8 +4,8 @@ build-backend = "scikit_build_core.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann-backend-diskann"
|
name = "leann-backend-diskann"
|
||||||
version = "0.1.0"
|
version = "0.1.15"
|
||||||
dependencies = ["leann-core==0.1.0", "numpy"]
|
dependencies = ["leann-core==0.1.15", "numpy", "protobuf>=3.19.0"]
|
||||||
|
|
||||||
[tool.scikit-build]
|
[tool.scikit-build]
|
||||||
# Key: simplified CMake path
|
# Key: simplified CMake path
|
||||||
@@ -16,4 +16,4 @@ wheel.packages = ["leann_backend_diskann"]
|
|||||||
editable.mode = "redirect"
|
editable.mode = "redirect"
|
||||||
cmake.build-type = "Release"
|
cmake.build-type = "Release"
|
||||||
build.verbose = true
|
build.verbose = true
|
||||||
build.tool-args = ["-j8"]
|
build.tool-args = ["-j8"]
|
||||||
|
|||||||
@@ -2,12 +2,12 @@ syntax = "proto3";
|
|||||||
|
|
||||||
package protoembedding;
|
package protoembedding;
|
||||||
|
|
||||||
message NodeEmbeddingRequest {
|
message NodeEmbeddingRequest {
|
||||||
repeated uint32 node_ids = 1;
|
repeated uint32 node_ids = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
message NodeEmbeddingResponse {
|
message NodeEmbeddingResponse {
|
||||||
bytes embeddings_data = 1; // All embedded binary datas
|
bytes embeddings_data = 1; // All embedded binary datas
|
||||||
repeated int32 dimensions = 2; // Shape [batch_size, embedding_dim]
|
repeated int32 dimensions = 2; // Shape [batch_size, embedding_dim]
|
||||||
repeated uint32 missing_ids = 3; // Missing node ids
|
repeated uint32 missing_ids = 3; // Missing node ids
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,14 @@ if(APPLE)
|
|||||||
set(OpenMP_C_LIB_NAMES "omp")
|
set(OpenMP_C_LIB_NAMES "omp")
|
||||||
set(OpenMP_CXX_LIB_NAMES "omp")
|
set(OpenMP_CXX_LIB_NAMES "omp")
|
||||||
set(OpenMP_omp_LIBRARY "/opt/homebrew/opt/libomp/lib/libomp.dylib")
|
set(OpenMP_omp_LIBRARY "/opt/homebrew/opt/libomp/lib/libomp.dylib")
|
||||||
|
|
||||||
|
# Force use of system libc++ to avoid version mismatch
|
||||||
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -stdlib=libc++")
|
||||||
|
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -stdlib=libc++")
|
||||||
|
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -stdlib=libc++")
|
||||||
|
|
||||||
|
# Set minimum macOS version for better compatibility
|
||||||
|
set(CMAKE_OSX_DEPLOYMENT_TARGET "11.0" CACHE STRING "Minimum macOS version")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# Use system ZeroMQ instead of building from source
|
# Use system ZeroMQ instead of building from source
|
||||||
@@ -52,4 +60,4 @@ set(FAISS_BUILD_AVX512 OFF CACHE BOOL "" FORCE)
|
|||||||
# IMPORTANT: Disable building AVX versions to speed up compilation
|
# IMPORTANT: Disable building AVX versions to speed up compilation
|
||||||
set(FAISS_BUILD_AVX_VERSIONS OFF CACHE BOOL "" FORCE)
|
set(FAISS_BUILD_AVX_VERSIONS OFF CACHE BOOL "" FORCE)
|
||||||
|
|
||||||
add_subdirectory(third_party/faiss)
|
add_subdirectory(third_party/faiss)
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
from . import hnsw_backend
|
from . import hnsw_backend as hnsw_backend
|
||||||
|
|||||||
@@ -1,87 +1,115 @@
|
|||||||
|
import argparse
|
||||||
|
import gc # Import garbage collector interface
|
||||||
|
import os
|
||||||
import struct
|
import struct
|
||||||
import sys
|
import sys
|
||||||
import numpy as np
|
|
||||||
import os
|
|
||||||
import argparse
|
|
||||||
import gc # Import garbage collector interface
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
# --- FourCCs (add more if needed) ---
|
# --- FourCCs (add more if needed) ---
|
||||||
INDEX_HNSW_FLAT_FOURCC = int.from_bytes(b'IHNf', 'little')
|
INDEX_HNSW_FLAT_FOURCC = int.from_bytes(b"IHNf", "little")
|
||||||
# Add other HNSW fourccs if you expect different storage types inside HNSW
|
# Add other HNSW fourccs if you expect different storage types inside HNSW
|
||||||
# INDEX_HNSW_PQ_FOURCC = int.from_bytes(b'IHNp', 'little')
|
# INDEX_HNSW_PQ_FOURCC = int.from_bytes(b'IHNp', 'little')
|
||||||
# INDEX_HNSW_SQ_FOURCC = int.from_bytes(b'IHNs', 'little')
|
# INDEX_HNSW_SQ_FOURCC = int.from_bytes(b'IHNs', 'little')
|
||||||
# INDEX_HNSW_CAGRA_FOURCC = int.from_bytes(b'IHNc', 'little') # Example
|
# INDEX_HNSW_CAGRA_FOURCC = int.from_bytes(b'IHNc', 'little') # Example
|
||||||
|
|
||||||
EXPECTED_HNSW_FOURCCS = {INDEX_HNSW_FLAT_FOURCC} # Modify if needed
|
EXPECTED_HNSW_FOURCCS = {INDEX_HNSW_FLAT_FOURCC} # Modify if needed
|
||||||
NULL_INDEX_FOURCC = int.from_bytes(b'null', 'little')
|
NULL_INDEX_FOURCC = int.from_bytes(b"null", "little")
|
||||||
|
|
||||||
# --- Helper functions for reading/writing binary data ---
|
# --- Helper functions for reading/writing binary data ---
|
||||||
|
|
||||||
|
|
||||||
def read_struct(f, fmt):
|
def read_struct(f, fmt):
|
||||||
"""Reads data according to the struct format."""
|
"""Reads data according to the struct format."""
|
||||||
size = struct.calcsize(fmt)
|
size = struct.calcsize(fmt)
|
||||||
data = f.read(size)
|
data = f.read(size)
|
||||||
if len(data) != size:
|
if len(data) != size:
|
||||||
raise EOFError(f"File ended unexpectedly reading struct fmt '{fmt}'. Expected {size} bytes, got {len(data)}.")
|
raise EOFError(
|
||||||
|
f"File ended unexpectedly reading struct fmt '{fmt}'. Expected {size} bytes, got {len(data)}."
|
||||||
|
)
|
||||||
return struct.unpack(fmt, data)[0]
|
return struct.unpack(fmt, data)[0]
|
||||||
|
|
||||||
|
|
||||||
def read_vector_raw(f, element_fmt_char):
|
def read_vector_raw(f, element_fmt_char):
|
||||||
"""Reads a vector (size followed by data), returns count and raw bytes."""
|
"""Reads a vector (size followed by data), returns count and raw bytes."""
|
||||||
count = -1 # Initialize count
|
count = -1 # Initialize count
|
||||||
total_bytes = -1 # Initialize total_bytes
|
total_bytes = -1 # Initialize total_bytes
|
||||||
try:
|
try:
|
||||||
count = read_struct(f, '<Q') # size_t usually 64-bit unsigned
|
count = read_struct(f, "<Q") # size_t usually 64-bit unsigned
|
||||||
element_size = struct.calcsize(element_fmt_char)
|
element_size = struct.calcsize(element_fmt_char)
|
||||||
# --- FIX for MemoryError: Check for unreasonably large count ---
|
# --- FIX for MemoryError: Check for unreasonably large count ---
|
||||||
max_reasonable_count = 10 * (10**9) # ~10 billion elements limit
|
max_reasonable_count = 10 * (10**9) # ~10 billion elements limit
|
||||||
if count > max_reasonable_count or count < 0:
|
if count > max_reasonable_count or count < 0:
|
||||||
raise MemoryError(f"Vector count {count} seems unreasonably large, possibly due to file corruption or incorrect format read.")
|
raise MemoryError(
|
||||||
|
f"Vector count {count} seems unreasonably large, possibly due to file corruption or incorrect format read."
|
||||||
|
)
|
||||||
|
|
||||||
total_bytes = count * element_size
|
total_bytes = count * element_size
|
||||||
# --- FIX for MemoryError: Check for huge byte size before allocation ---
|
# --- FIX for MemoryError: Check for huge byte size before allocation ---
|
||||||
max_reasonable_bytes = 50 * (1024**3) # ~50 GB limit
|
max_reasonable_bytes = 50 * (1024**3) # ~50 GB limit
|
||||||
if total_bytes > max_reasonable_bytes or total_bytes < 0: # Check for overflow
|
if total_bytes > max_reasonable_bytes or total_bytes < 0: # Check for overflow
|
||||||
raise MemoryError(f"Attempting to read {total_bytes} bytes ({count} elements * {element_size} bytes/element), which exceeds the safety limit. File might be corrupted or format mismatch.")
|
raise MemoryError(
|
||||||
|
f"Attempting to read {total_bytes} bytes ({count} elements * {element_size} bytes/element), which exceeds the safety limit. File might be corrupted or format mismatch."
|
||||||
|
)
|
||||||
|
|
||||||
data_bytes = f.read(total_bytes)
|
data_bytes = f.read(total_bytes)
|
||||||
|
|
||||||
if len(data_bytes) != total_bytes:
|
if len(data_bytes) != total_bytes:
|
||||||
raise EOFError(f"File ended unexpectedly reading vector data. Expected {total_bytes} bytes, got {len(data_bytes)}.")
|
raise EOFError(
|
||||||
|
f"File ended unexpectedly reading vector data. Expected {total_bytes} bytes, got {len(data_bytes)}."
|
||||||
|
)
|
||||||
return count, data_bytes
|
return count, data_bytes
|
||||||
except (MemoryError, OverflowError) as e:
|
except (MemoryError, OverflowError) as e:
|
||||||
# Add context to the error message
|
# Add context to the error message
|
||||||
print(f"\nError during raw vector read (element_fmt='{element_fmt_char}', count={count}, total_bytes={total_bytes}): {e}", file=sys.stderr)
|
print(
|
||||||
raise e # Re-raise the original error type
|
f"\nError during raw vector read (element_fmt='{element_fmt_char}', count={count}, total_bytes={total_bytes}): {e}",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
raise e # Re-raise the original error type
|
||||||
|
|
||||||
|
|
||||||
def read_numpy_vector(f, np_dtype, struct_fmt_char):
|
def read_numpy_vector(f, np_dtype, struct_fmt_char):
|
||||||
"""Reads a vector into a NumPy array."""
|
"""Reads a vector into a NumPy array."""
|
||||||
count = -1 # Initialize count for robust error handling
|
count = -1 # Initialize count for robust error handling
|
||||||
print(f" Reading vector (dtype={np_dtype}, fmt='{struct_fmt_char}')... ", end='', flush=True)
|
print(
|
||||||
|
f" Reading vector (dtype={np_dtype}, fmt='{struct_fmt_char}')... ",
|
||||||
|
end="",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
count, data_bytes = read_vector_raw(f, struct_fmt_char)
|
count, data_bytes = read_vector_raw(f, struct_fmt_char)
|
||||||
print(f"Count={count}, Bytes={len(data_bytes)}")
|
print(f"Count={count}, Bytes={len(data_bytes)}")
|
||||||
if count > 0 and len(data_bytes) > 0:
|
if count > 0 and len(data_bytes) > 0:
|
||||||
arr = np.frombuffer(data_bytes, dtype=np_dtype)
|
arr = np.frombuffer(data_bytes, dtype=np_dtype)
|
||||||
if arr.size != count:
|
if arr.size != count:
|
||||||
raise ValueError(f"Inconsistent array size after reading. Expected {count}, got {arr.size}")
|
raise ValueError(
|
||||||
|
f"Inconsistent array size after reading. Expected {count}, got {arr.size}"
|
||||||
|
)
|
||||||
return arr
|
return arr
|
||||||
elif count == 0:
|
elif count == 0:
|
||||||
return np.array([], dtype=np_dtype)
|
return np.array([], dtype=np_dtype)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Read zero bytes but count > 0.")
|
raise ValueError("Read zero bytes but count > 0.")
|
||||||
except MemoryError as e:
|
except MemoryError as e:
|
||||||
# Now count should be defined (or -1 if error was in read_struct)
|
# Now count should be defined (or -1 if error was in read_struct)
|
||||||
print(f"\nMemoryError creating NumPy array (dtype={np_dtype}, count={count}). {e}", file=sys.stderr)
|
print(
|
||||||
|
f"\nMemoryError creating NumPy array (dtype={np_dtype}, count={count}). {e}",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
raise e
|
raise e
|
||||||
except Exception as e: # Catch other potential errors like ValueError
|
except Exception as e: # Catch other potential errors like ValueError
|
||||||
print(f"\nError reading numpy vector (dtype={np_dtype}, fmt='{struct_fmt_char}', count={count}): {e}", file=sys.stderr)
|
print(
|
||||||
|
f"\nError reading numpy vector (dtype={np_dtype}, fmt='{struct_fmt_char}', count={count}): {e}",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def write_numpy_vector(f, arr, struct_fmt_char):
|
def write_numpy_vector(f, arr, struct_fmt_char):
|
||||||
"""Writes a NumPy array as a vector (size followed by data)."""
|
"""Writes a NumPy array as a vector (size followed by data)."""
|
||||||
count = arr.size
|
count = arr.size
|
||||||
f.write(struct.pack('<Q', count))
|
f.write(struct.pack("<Q", count))
|
||||||
try:
|
try:
|
||||||
expected_dtype = np.dtype(struct_fmt_char)
|
expected_dtype = np.dtype(struct_fmt_char)
|
||||||
if arr.dtype != expected_dtype:
|
if arr.dtype != expected_dtype:
|
||||||
@@ -89,23 +117,30 @@ def write_numpy_vector(f, arr, struct_fmt_char):
|
|||||||
else:
|
else:
|
||||||
data_to_write = arr.tobytes()
|
data_to_write = arr.tobytes()
|
||||||
f.write(data_to_write)
|
f.write(data_to_write)
|
||||||
del data_to_write # Hint GC
|
del data_to_write # Hint GC
|
||||||
except MemoryError as e:
|
except MemoryError as e:
|
||||||
print(f"\nMemoryError converting NumPy array to bytes for writing (size={count}, dtype={arr.dtype}). {e}", file=sys.stderr)
|
print(
|
||||||
raise e
|
f"\nMemoryError converting NumPy array to bytes for writing (size={count}, dtype={arr.dtype}). {e}",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def write_list_vector(f, lst, struct_fmt_char):
|
def write_list_vector(f, lst, struct_fmt_char):
|
||||||
"""Writes a Python list as a vector iteratively."""
|
"""Writes a Python list as a vector iteratively."""
|
||||||
count = len(lst)
|
count = len(lst)
|
||||||
f.write(struct.pack('<Q', count))
|
f.write(struct.pack("<Q", count))
|
||||||
fmt = '<' + struct_fmt_char
|
fmt = "<" + struct_fmt_char
|
||||||
chunk_size = 1024 * 1024
|
chunk_size = 1024 * 1024
|
||||||
element_size = struct.calcsize(fmt)
|
element_size = struct.calcsize(fmt)
|
||||||
# Allocate buffer outside the loop if possible, or handle MemoryError during allocation
|
# Allocate buffer outside the loop if possible, or handle MemoryError during allocation
|
||||||
try:
|
try:
|
||||||
buffer = bytearray(chunk_size * element_size)
|
buffer = bytearray(chunk_size * element_size)
|
||||||
except MemoryError:
|
except MemoryError:
|
||||||
print(f"MemoryError: Cannot allocate buffer for writing list vector chunk (size {chunk_size * element_size} bytes).", file=sys.stderr)
|
print(
|
||||||
|
f"MemoryError: Cannot allocate buffer for writing list vector chunk (size {chunk_size * element_size} bytes).",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
raise
|
raise
|
||||||
buffer_count = 0
|
buffer_count = 0
|
||||||
|
|
||||||
@@ -116,66 +151,80 @@ def write_list_vector(f, lst, struct_fmt_char):
|
|||||||
buffer_count += 1
|
buffer_count += 1
|
||||||
|
|
||||||
if buffer_count == chunk_size or i == count - 1:
|
if buffer_count == chunk_size or i == count - 1:
|
||||||
f.write(buffer[:buffer_count * element_size])
|
f.write(buffer[: buffer_count * element_size])
|
||||||
buffer_count = 0
|
buffer_count = 0
|
||||||
|
|
||||||
except struct.error as e:
|
except struct.error as e:
|
||||||
print(f"\nStruct packing error for item {item} at index {i} with format '{fmt}'. {e}", file=sys.stderr)
|
print(
|
||||||
|
f"\nStruct packing error for item {item} at index {i} with format '{fmt}'. {e}",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def get_cum_neighbors(cum_nneighbor_per_level_np, level):
|
def get_cum_neighbors(cum_nneighbor_per_level_np, level):
|
||||||
"""Helper to get cumulative neighbors count, matching C++ logic."""
|
"""Helper to get cumulative neighbors count, matching C++ logic."""
|
||||||
if level < 0: return 0
|
if level < 0:
|
||||||
|
return 0
|
||||||
if level < len(cum_nneighbor_per_level_np):
|
if level < len(cum_nneighbor_per_level_np):
|
||||||
return cum_nneighbor_per_level_np[level]
|
return cum_nneighbor_per_level_np[level]
|
||||||
else:
|
else:
|
||||||
return cum_nneighbor_per_level_np[-1] if len(cum_nneighbor_per_level_np) > 0 else 0
|
return cum_nneighbor_per_level_np[-1] if len(cum_nneighbor_per_level_np) > 0 else 0
|
||||||
|
|
||||||
def write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np,
|
|
||||||
levels_np, compact_level_ptr, compact_node_offsets_np,
|
def write_compact_format(
|
||||||
compact_neighbors_data, storage_fourcc, storage_data):
|
f_out,
|
||||||
|
original_hnsw_data,
|
||||||
|
assign_probas_np,
|
||||||
|
cum_nneighbor_per_level_np,
|
||||||
|
levels_np,
|
||||||
|
compact_level_ptr,
|
||||||
|
compact_node_offsets_np,
|
||||||
|
compact_neighbors_data,
|
||||||
|
storage_fourcc,
|
||||||
|
storage_data,
|
||||||
|
):
|
||||||
"""Write HNSW data in compact format following C++ read order exactly."""
|
"""Write HNSW data in compact format following C++ read order exactly."""
|
||||||
# Write IndexHNSW Header
|
# Write IndexHNSW Header
|
||||||
f_out.write(struct.pack('<I', original_hnsw_data['index_fourcc']))
|
f_out.write(struct.pack("<I", original_hnsw_data["index_fourcc"]))
|
||||||
f_out.write(struct.pack('<i', original_hnsw_data['d']))
|
f_out.write(struct.pack("<i", original_hnsw_data["d"]))
|
||||||
f_out.write(struct.pack('<q', original_hnsw_data['ntotal']))
|
f_out.write(struct.pack("<q", original_hnsw_data["ntotal"]))
|
||||||
f_out.write(struct.pack('<q', original_hnsw_data['dummy1']))
|
f_out.write(struct.pack("<q", original_hnsw_data["dummy1"]))
|
||||||
f_out.write(struct.pack('<q', original_hnsw_data['dummy2']))
|
f_out.write(struct.pack("<q", original_hnsw_data["dummy2"]))
|
||||||
f_out.write(struct.pack('<?', original_hnsw_data['is_trained']))
|
f_out.write(struct.pack("<?", original_hnsw_data["is_trained"]))
|
||||||
f_out.write(struct.pack('<i', original_hnsw_data['metric_type']))
|
f_out.write(struct.pack("<i", original_hnsw_data["metric_type"]))
|
||||||
if original_hnsw_data['metric_type'] > 1:
|
if original_hnsw_data["metric_type"] > 1:
|
||||||
f_out.write(struct.pack('<f', original_hnsw_data['metric_arg']))
|
f_out.write(struct.pack("<f", original_hnsw_data["metric_arg"]))
|
||||||
|
|
||||||
# Write HNSW struct parts (standard order)
|
# Write HNSW struct parts (standard order)
|
||||||
write_numpy_vector(f_out, assign_probas_np, 'd')
|
write_numpy_vector(f_out, assign_probas_np, "d")
|
||||||
write_numpy_vector(f_out, cum_nneighbor_per_level_np, 'i')
|
write_numpy_vector(f_out, cum_nneighbor_per_level_np, "i")
|
||||||
write_numpy_vector(f_out, levels_np, 'i')
|
write_numpy_vector(f_out, levels_np, "i")
|
||||||
|
|
||||||
# Write compact format flag
|
# Write compact format flag
|
||||||
f_out.write(struct.pack('<?', True)) # storage_is_compact = True
|
f_out.write(struct.pack("<?", True)) # storage_is_compact = True
|
||||||
|
|
||||||
# Write compact data in CORRECT C++ read order: level_ptr, node_offsets FIRST
|
# Write compact data in CORRECT C++ read order: level_ptr, node_offsets FIRST
|
||||||
if isinstance(compact_level_ptr, np.ndarray):
|
if isinstance(compact_level_ptr, np.ndarray):
|
||||||
write_numpy_vector(f_out, compact_level_ptr, 'Q')
|
write_numpy_vector(f_out, compact_level_ptr, "Q")
|
||||||
else:
|
else:
|
||||||
write_list_vector(f_out, compact_level_ptr, 'Q')
|
write_list_vector(f_out, compact_level_ptr, "Q")
|
||||||
|
|
||||||
write_numpy_vector(f_out, compact_node_offsets_np, 'Q')
|
write_numpy_vector(f_out, compact_node_offsets_np, "Q")
|
||||||
|
|
||||||
# Write HNSW scalar parameters
|
# Write HNSW scalar parameters
|
||||||
f_out.write(struct.pack('<i', original_hnsw_data['entry_point']))
|
f_out.write(struct.pack("<i", original_hnsw_data["entry_point"]))
|
||||||
f_out.write(struct.pack('<i', original_hnsw_data['max_level']))
|
f_out.write(struct.pack("<i", original_hnsw_data["max_level"]))
|
||||||
f_out.write(struct.pack('<i', original_hnsw_data['efConstruction']))
|
f_out.write(struct.pack("<i", original_hnsw_data["efConstruction"]))
|
||||||
f_out.write(struct.pack('<i', original_hnsw_data['efSearch']))
|
f_out.write(struct.pack("<i", original_hnsw_data["efSearch"]))
|
||||||
f_out.write(struct.pack('<i', original_hnsw_data['dummy_upper_beam']))
|
f_out.write(struct.pack("<i", original_hnsw_data["dummy_upper_beam"]))
|
||||||
|
|
||||||
# Write storage fourcc (this determines how to read what follows)
|
# Write storage fourcc (this determines how to read what follows)
|
||||||
f_out.write(struct.pack('<I', storage_fourcc))
|
f_out.write(struct.pack("<I", storage_fourcc))
|
||||||
|
|
||||||
# Write compact neighbors data AFTER storage fourcc
|
# Write compact neighbors data AFTER storage fourcc
|
||||||
write_list_vector(f_out, compact_neighbors_data, 'i')
|
write_list_vector(f_out, compact_neighbors_data, "i")
|
||||||
|
|
||||||
# Write storage data if not NULL (only after neighbors)
|
# Write storage data if not NULL (only after neighbors)
|
||||||
if storage_fourcc != NULL_INDEX_FOURCC and storage_data:
|
if storage_fourcc != NULL_INDEX_FOURCC and storage_data:
|
||||||
f_out.write(storage_data)
|
f_out.write(storage_data)
|
||||||
@@ -183,11 +232,12 @@ def write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneigh
|
|||||||
|
|
||||||
# --- Main Conversion Logic ---
|
# --- Main Conversion Logic ---
|
||||||
|
|
||||||
|
|
||||||
def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=True):
|
def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=True):
|
||||||
"""
|
"""
|
||||||
Converts an HNSW graph file to the CSR format.
|
Converts an HNSW graph file to the CSR format.
|
||||||
Supports both original and already-compact formats (backward compatibility).
|
Supports both original and already-compact formats (backward compatibility).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_filename: Input HNSW index file
|
input_filename: Input HNSW index file
|
||||||
output_filename: Output CSR index file
|
output_filename: Output CSR index file
|
||||||
@@ -196,172 +246,228 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
|||||||
print(f"Starting conversion: {input_filename} -> {output_filename}")
|
print(f"Starting conversion: {input_filename} -> {output_filename}")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
original_hnsw_data = {}
|
original_hnsw_data = {}
|
||||||
neighbors_np = None # Initialize to allow check in finally block
|
neighbors_np = None # Initialize to allow check in finally block
|
||||||
try:
|
try:
|
||||||
with open(input_filename, 'rb') as f_in, open(output_filename, 'wb') as f_out:
|
with open(input_filename, "rb") as f_in, open(output_filename, "wb") as f_out:
|
||||||
|
|
||||||
# --- Read IndexHNSW FourCC and Header ---
|
# --- Read IndexHNSW FourCC and Header ---
|
||||||
print(f"[{time.time() - start_time:.2f}s] Reading Index HNSW header...")
|
print(f"[{time.time() - start_time:.2f}s] Reading Index HNSW header...")
|
||||||
# ... (Keep the header reading logic as before) ...
|
# ... (Keep the header reading logic as before) ...
|
||||||
hnsw_index_fourcc = read_struct(f_in, '<I')
|
hnsw_index_fourcc = read_struct(f_in, "<I")
|
||||||
if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS:
|
if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS:
|
||||||
print(f"Error: Expected HNSW Index FourCC ({list(EXPECTED_HNSW_FOURCCS)}), got {hnsw_index_fourcc:08x}.", file=sys.stderr)
|
print(
|
||||||
return False
|
f"Error: Expected HNSW Index FourCC ({list(EXPECTED_HNSW_FOURCCS)}), got {hnsw_index_fourcc:08x}.",
|
||||||
original_hnsw_data['index_fourcc'] = hnsw_index_fourcc
|
file=sys.stderr,
|
||||||
original_hnsw_data['d'] = read_struct(f_in, '<i')
|
)
|
||||||
original_hnsw_data['ntotal'] = read_struct(f_in, '<q')
|
return False
|
||||||
original_hnsw_data['dummy1'] = read_struct(f_in, '<q')
|
original_hnsw_data["index_fourcc"] = hnsw_index_fourcc
|
||||||
original_hnsw_data['dummy2'] = read_struct(f_in, '<q')
|
original_hnsw_data["d"] = read_struct(f_in, "<i")
|
||||||
original_hnsw_data['is_trained'] = read_struct(f_in, '?')
|
original_hnsw_data["ntotal"] = read_struct(f_in, "<q")
|
||||||
original_hnsw_data['metric_type'] = read_struct(f_in, '<i')
|
original_hnsw_data["dummy1"] = read_struct(f_in, "<q")
|
||||||
original_hnsw_data['metric_arg'] = 0.0
|
original_hnsw_data["dummy2"] = read_struct(f_in, "<q")
|
||||||
if original_hnsw_data['metric_type'] > 1:
|
original_hnsw_data["is_trained"] = read_struct(f_in, "?")
|
||||||
original_hnsw_data['metric_arg'] = read_struct(f_in, '<f')
|
original_hnsw_data["metric_type"] = read_struct(f_in, "<i")
|
||||||
print(f"[{time.time() - start_time:.2f}s] Header read: d={original_hnsw_data['d']}, ntotal={original_hnsw_data['ntotal']}")
|
original_hnsw_data["metric_arg"] = 0.0
|
||||||
|
if original_hnsw_data["metric_type"] > 1:
|
||||||
|
original_hnsw_data["metric_arg"] = read_struct(f_in, "<f")
|
||||||
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] Header read: d={original_hnsw_data['d']}, ntotal={original_hnsw_data['ntotal']}"
|
||||||
|
)
|
||||||
|
|
||||||
# --- Read original HNSW struct data ---
|
# --- Read original HNSW struct data ---
|
||||||
print(f"[{time.time() - start_time:.2f}s] Reading HNSW struct vectors...")
|
print(f"[{time.time() - start_time:.2f}s] Reading HNSW struct vectors...")
|
||||||
assign_probas_np = read_numpy_vector(f_in, np.float64, 'd')
|
assign_probas_np = read_numpy_vector(f_in, np.float64, "d")
|
||||||
print(f"[{time.time() - start_time:.2f}s] Read assign_probas ({assign_probas_np.size})")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] Read assign_probas ({assign_probas_np.size})"
|
||||||
|
)
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
cum_nneighbor_per_level_np = read_numpy_vector(f_in, np.int32, 'i')
|
cum_nneighbor_per_level_np = read_numpy_vector(f_in, np.int32, "i")
|
||||||
print(f"[{time.time() - start_time:.2f}s] Read cum_nneighbor_per_level ({cum_nneighbor_per_level_np.size})")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] Read cum_nneighbor_per_level ({cum_nneighbor_per_level_np.size})"
|
||||||
|
)
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
levels_np = read_numpy_vector(f_in, np.int32, 'i')
|
levels_np = read_numpy_vector(f_in, np.int32, "i")
|
||||||
print(f"[{time.time() - start_time:.2f}s] Read levels ({levels_np.size})")
|
print(f"[{time.time() - start_time:.2f}s] Read levels ({levels_np.size})")
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
ntotal = len(levels_np)
|
ntotal = len(levels_np)
|
||||||
if ntotal != original_hnsw_data['ntotal']:
|
if ntotal != original_hnsw_data["ntotal"]:
|
||||||
print(f"Warning: ntotal mismatch! Header says {original_hnsw_data['ntotal']}, levels vector size is {ntotal}. Using levels vector size.", file=sys.stderr)
|
print(
|
||||||
original_hnsw_data['ntotal'] = ntotal
|
f"Warning: ntotal mismatch! Header says {original_hnsw_data['ntotal']}, levels vector size is {ntotal}. Using levels vector size.",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
original_hnsw_data["ntotal"] = ntotal
|
||||||
|
|
||||||
# --- Check for compact format flag ---
|
# --- Check for compact format flag ---
|
||||||
print(f"[{time.time() - start_time:.2f}s] Probing for compact storage flag...")
|
print(f"[{time.time() - start_time:.2f}s] Probing for compact storage flag...")
|
||||||
pos_before_compact = f_in.tell()
|
pos_before_compact = f_in.tell()
|
||||||
try:
|
try:
|
||||||
is_compact_flag = read_struct(f_in, '<?')
|
is_compact_flag = read_struct(f_in, "<?")
|
||||||
print(f"[{time.time() - start_time:.2f}s] Found compact flag: {is_compact_flag}")
|
print(f"[{time.time() - start_time:.2f}s] Found compact flag: {is_compact_flag}")
|
||||||
|
|
||||||
if is_compact_flag:
|
if is_compact_flag:
|
||||||
# Input is already in compact format - read compact data
|
# Input is already in compact format - read compact data
|
||||||
print(f"[{time.time() - start_time:.2f}s] Input is already in compact format, reading compact data...")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] Input is already in compact format, reading compact data..."
|
||||||
compact_level_ptr = read_numpy_vector(f_in, np.uint64, 'Q')
|
)
|
||||||
print(f"[{time.time() - start_time:.2f}s] Read compact_level_ptr ({compact_level_ptr.size})")
|
|
||||||
|
compact_level_ptr = read_numpy_vector(f_in, np.uint64, "Q")
|
||||||
compact_node_offsets_np = read_numpy_vector(f_in, np.uint64, 'Q')
|
print(
|
||||||
print(f"[{time.time() - start_time:.2f}s] Read compact_node_offsets ({compact_node_offsets_np.size})")
|
f"[{time.time() - start_time:.2f}s] Read compact_level_ptr ({compact_level_ptr.size})"
|
||||||
|
)
|
||||||
|
|
||||||
|
compact_node_offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
|
||||||
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] Read compact_node_offsets ({compact_node_offsets_np.size})"
|
||||||
|
)
|
||||||
|
|
||||||
# Read scalar parameters
|
# Read scalar parameters
|
||||||
original_hnsw_data['entry_point'] = read_struct(f_in, '<i')
|
original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
|
||||||
original_hnsw_data['max_level'] = read_struct(f_in, '<i')
|
original_hnsw_data["max_level"] = read_struct(f_in, "<i")
|
||||||
original_hnsw_data['efConstruction'] = read_struct(f_in, '<i')
|
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
|
||||||
original_hnsw_data['efSearch'] = read_struct(f_in, '<i')
|
original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
|
||||||
original_hnsw_data['dummy_upper_beam'] = read_struct(f_in, '<i')
|
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
|
||||||
print(f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})"
|
||||||
|
)
|
||||||
|
|
||||||
# Read storage fourcc
|
# Read storage fourcc
|
||||||
storage_fourcc = read_struct(f_in, '<I')
|
storage_fourcc = read_struct(f_in, "<I")
|
||||||
print(f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}"
|
||||||
|
)
|
||||||
|
|
||||||
if prune_embeddings and storage_fourcc != NULL_INDEX_FOURCC:
|
if prune_embeddings and storage_fourcc != NULL_INDEX_FOURCC:
|
||||||
# Read compact neighbors data
|
# Read compact neighbors data
|
||||||
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, 'i')
|
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i")
|
||||||
print(f"[{time.time() - start_time:.2f}s] Read compact neighbors data ({compact_neighbors_data_np.size})")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] Read compact neighbors data ({compact_neighbors_data_np.size})"
|
||||||
|
)
|
||||||
compact_neighbors_data = compact_neighbors_data_np.tolist()
|
compact_neighbors_data = compact_neighbors_data_np.tolist()
|
||||||
del compact_neighbors_data_np
|
del compact_neighbors_data_np
|
||||||
|
|
||||||
# Skip storage data and write with NULL marker
|
# Skip storage data and write with NULL marker
|
||||||
print(f"[{time.time() - start_time:.2f}s] Pruning embeddings: Writing NULL storage marker.")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] Pruning embeddings: Writing NULL storage marker."
|
||||||
|
)
|
||||||
storage_fourcc = NULL_INDEX_FOURCC
|
storage_fourcc = NULL_INDEX_FOURCC
|
||||||
elif not prune_embeddings:
|
elif not prune_embeddings:
|
||||||
# Read and preserve compact neighbors and storage
|
# Read and preserve compact neighbors and storage
|
||||||
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, 'i')
|
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i")
|
||||||
compact_neighbors_data = compact_neighbors_data_np.tolist()
|
compact_neighbors_data = compact_neighbors_data_np.tolist()
|
||||||
del compact_neighbors_data_np
|
del compact_neighbors_data_np
|
||||||
|
|
||||||
# Read remaining storage data
|
# Read remaining storage data
|
||||||
storage_data = f_in.read()
|
storage_data = f_in.read()
|
||||||
else:
|
else:
|
||||||
# Already pruned (NULL storage)
|
# Already pruned (NULL storage)
|
||||||
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, 'i')
|
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i")
|
||||||
compact_neighbors_data = compact_neighbors_data_np.tolist()
|
compact_neighbors_data = compact_neighbors_data_np.tolist()
|
||||||
del compact_neighbors_data_np
|
del compact_neighbors_data_np
|
||||||
storage_data = b''
|
storage_data = b""
|
||||||
|
|
||||||
# Write the updated compact format
|
# Write the updated compact format
|
||||||
print(f"[{time.time() - start_time:.2f}s] Writing updated compact format...")
|
print(f"[{time.time() - start_time:.2f}s] Writing updated compact format...")
|
||||||
write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np,
|
write_compact_format(
|
||||||
levels_np, compact_level_ptr, compact_node_offsets_np,
|
f_out,
|
||||||
compact_neighbors_data, storage_fourcc, storage_data if not prune_embeddings else b'')
|
original_hnsw_data,
|
||||||
|
assign_probas_np,
|
||||||
|
cum_nneighbor_per_level_np,
|
||||||
|
levels_np,
|
||||||
|
compact_level_ptr,
|
||||||
|
compact_node_offsets_np,
|
||||||
|
compact_neighbors_data,
|
||||||
|
storage_fourcc,
|
||||||
|
storage_data if not prune_embeddings else b"",
|
||||||
|
)
|
||||||
|
|
||||||
print(f"[{time.time() - start_time:.2f}s] Conversion complete.")
|
print(f"[{time.time() - start_time:.2f}s] Conversion complete.")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# is_compact=False, rewind and read original format
|
# is_compact=False, rewind and read original format
|
||||||
f_in.seek(pos_before_compact)
|
f_in.seek(pos_before_compact)
|
||||||
print(f"[{time.time() - start_time:.2f}s] Compact flag is False, reading original format...")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] Compact flag is False, reading original format..."
|
||||||
|
)
|
||||||
|
|
||||||
except EOFError:
|
except EOFError:
|
||||||
# No compact flag found, assume original format
|
# No compact flag found, assume original format
|
||||||
f_in.seek(pos_before_compact)
|
f_in.seek(pos_before_compact)
|
||||||
print(f"[{time.time() - start_time:.2f}s] No compact flag found, assuming original format...")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] No compact flag found, assuming original format..."
|
||||||
|
)
|
||||||
|
|
||||||
# --- Handle potential extra byte in original format (like C++ code) ---
|
# --- Handle potential extra byte in original format (like C++ code) ---
|
||||||
print(f"[{time.time() - start_time:.2f}s] Probing for potential extra byte before non-compact offsets...")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] Probing for potential extra byte before non-compact offsets..."
|
||||||
|
)
|
||||||
pos_before_probe = f_in.tell()
|
pos_before_probe = f_in.tell()
|
||||||
try:
|
try:
|
||||||
suspected_flag = read_struct(f_in, '<B') # Read 1 byte
|
suspected_flag = read_struct(f_in, "<B") # Read 1 byte
|
||||||
if suspected_flag == 0x00:
|
if suspected_flag == 0x00:
|
||||||
print(f"[{time.time() - start_time:.2f}s] Found and consumed an unexpected 0x00 byte.")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] Found and consumed an unexpected 0x00 byte."
|
||||||
|
)
|
||||||
elif suspected_flag == 0x01:
|
elif suspected_flag == 0x01:
|
||||||
print(f"[{time.time() - start_time:.2f}s] ERROR: Found 0x01 but is_compact should be False")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] ERROR: Found 0x01 but is_compact should be False"
|
||||||
|
)
|
||||||
raise ValueError("Inconsistent compact flag state")
|
raise ValueError("Inconsistent compact flag state")
|
||||||
else:
|
else:
|
||||||
# Rewind - this byte is part of offsets data
|
# Rewind - this byte is part of offsets data
|
||||||
f_in.seek(pos_before_probe)
|
f_in.seek(pos_before_probe)
|
||||||
print(f"[{time.time() - start_time:.2f}s] Rewound to original position (byte was 0x{suspected_flag:02x})")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] Rewound to original position (byte was 0x{suspected_flag:02x})"
|
||||||
|
)
|
||||||
except EOFError:
|
except EOFError:
|
||||||
f_in.seek(pos_before_probe)
|
f_in.seek(pos_before_probe)
|
||||||
print(f"[{time.time() - start_time:.2f}s] No extra byte found (EOF), proceeding with offsets read")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] No extra byte found (EOF), proceeding with offsets read"
|
||||||
|
)
|
||||||
|
|
||||||
# --- Read original format data ---
|
# --- Read original format data ---
|
||||||
offsets_np = read_numpy_vector(f_in, np.uint64, 'Q')
|
offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
|
||||||
print(f"[{time.time() - start_time:.2f}s] Read offsets ({offsets_np.size})")
|
print(f"[{time.time() - start_time:.2f}s] Read offsets ({offsets_np.size})")
|
||||||
if len(offsets_np) != ntotal + 1:
|
if len(offsets_np) != ntotal + 1:
|
||||||
raise ValueError(f"Inconsistent offsets size: len(levels)={ntotal} but len(offsets)={len(offsets_np)}")
|
raise ValueError(
|
||||||
|
f"Inconsistent offsets size: len(levels)={ntotal} but len(offsets)={len(offsets_np)}"
|
||||||
|
)
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
print(f"[{time.time() - start_time:.2f}s] Attempting to read neighbors vector...")
|
print(f"[{time.time() - start_time:.2f}s] Attempting to read neighbors vector...")
|
||||||
neighbors_np = read_numpy_vector(f_in, np.int32, 'i')
|
neighbors_np = read_numpy_vector(f_in, np.int32, "i")
|
||||||
print(f"[{time.time() - start_time:.2f}s] Read neighbors ({neighbors_np.size})")
|
print(f"[{time.time() - start_time:.2f}s] Read neighbors ({neighbors_np.size})")
|
||||||
expected_neighbors_size = offsets_np[-1] if ntotal > 0 else 0
|
expected_neighbors_size = offsets_np[-1] if ntotal > 0 else 0
|
||||||
if neighbors_np.size != expected_neighbors_size:
|
if neighbors_np.size != expected_neighbors_size:
|
||||||
print(f"Warning: neighbors vector size mismatch. Expected {expected_neighbors_size} based on offsets, got {neighbors_np.size}.")
|
print(
|
||||||
|
f"Warning: neighbors vector size mismatch. Expected {expected_neighbors_size} based on offsets, got {neighbors_np.size}."
|
||||||
|
)
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
original_hnsw_data['entry_point'] = read_struct(f_in, '<i')
|
original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
|
||||||
original_hnsw_data['max_level'] = read_struct(f_in, '<i')
|
original_hnsw_data["max_level"] = read_struct(f_in, "<i")
|
||||||
original_hnsw_data['efConstruction'] = read_struct(f_in, '<i')
|
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
|
||||||
original_hnsw_data['efSearch'] = read_struct(f_in, '<i')
|
original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
|
||||||
original_hnsw_data['dummy_upper_beam'] = read_struct(f_in, '<i')
|
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
|
||||||
print(f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] Read scalar params (ep={original_hnsw_data['entry_point']}, max_lvl={original_hnsw_data['max_level']})"
|
||||||
|
)
|
||||||
|
|
||||||
print(f"[{time.time() - start_time:.2f}s] Checking for storage data...")
|
print(f"[{time.time() - start_time:.2f}s] Checking for storage data...")
|
||||||
storage_fourcc = None
|
storage_fourcc = None
|
||||||
try:
|
try:
|
||||||
storage_fourcc = read_struct(f_in, '<I')
|
storage_fourcc = read_struct(f_in, "<I")
|
||||||
print(f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}.")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] Found storage fourcc: {storage_fourcc:08x}."
|
||||||
|
)
|
||||||
except EOFError:
|
except EOFError:
|
||||||
print(f"[{time.time() - start_time:.2f}s] No storage data found (EOF).")
|
print(f"[{time.time() - start_time:.2f}s] No storage data found (EOF).")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[{time.time() - start_time:.2f}s] Error reading potential storage data: {e}")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] Error reading potential storage data: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
# --- Perform Conversion ---
|
# --- Perform Conversion ---
|
||||||
print(f"[{time.time() - start_time:.2f}s] Converting to CSR format...")
|
print(f"[{time.time() - start_time:.2f}s] Converting to CSR format...")
|
||||||
@@ -373,17 +479,21 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
|||||||
|
|
||||||
current_level_ptr_idx = 0
|
current_level_ptr_idx = 0
|
||||||
current_data_idx = 0
|
current_data_idx = 0
|
||||||
total_valid_neighbors_counted = 0 # For validation
|
total_valid_neighbors_counted = 0 # For validation
|
||||||
|
|
||||||
# Optimize calculation by getting slices once per node if possible
|
# Optimize calculation by getting slices once per node if possible
|
||||||
for i in range(ntotal):
|
for i in range(ntotal):
|
||||||
if i > 0 and i % (ntotal // 100 or 1) == 0: # Log progress roughly every 1%
|
if i > 0 and i % (ntotal // 100 or 1) == 0: # Log progress roughly every 1%
|
||||||
progress = (i / ntotal) * 100
|
progress = (i / ntotal) * 100
|
||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
print(f"\r[{elapsed:.2f}s] Converting node {i}/{ntotal} ({progress:.1f}%)...", end="")
|
print(
|
||||||
|
f"\r[{elapsed:.2f}s] Converting node {i}/{ntotal} ({progress:.1f}%)...",
|
||||||
|
end="",
|
||||||
|
)
|
||||||
|
|
||||||
node_max_level = levels_np[i] - 1
|
node_max_level = levels_np[i] - 1
|
||||||
if node_max_level < -1: node_max_level = -1
|
if node_max_level < -1:
|
||||||
|
node_max_level = -1
|
||||||
|
|
||||||
node_ptr_start_index = current_level_ptr_idx
|
node_ptr_start_index = current_level_ptr_idx
|
||||||
compact_node_offsets_np[i] = node_ptr_start_index
|
compact_node_offsets_np[i] = node_ptr_start_index
|
||||||
@@ -394,13 +504,17 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
|||||||
for level in range(node_max_level + 1):
|
for level in range(node_max_level + 1):
|
||||||
compact_level_ptr.append(current_data_idx)
|
compact_level_ptr.append(current_data_idx)
|
||||||
|
|
||||||
begin_orig_np = original_offset_start + get_cum_neighbors(cum_nneighbor_per_level_np, level)
|
begin_orig_np = original_offset_start + get_cum_neighbors(
|
||||||
end_orig_np = original_offset_start + get_cum_neighbors(cum_nneighbor_per_level_np, level + 1)
|
cum_nneighbor_per_level_np, level
|
||||||
|
)
|
||||||
|
end_orig_np = original_offset_start + get_cum_neighbors(
|
||||||
|
cum_nneighbor_per_level_np, level + 1
|
||||||
|
)
|
||||||
|
|
||||||
begin_orig = int(begin_orig_np)
|
begin_orig = int(begin_orig_np)
|
||||||
end_orig = int(end_orig_np)
|
end_orig = int(end_orig_np)
|
||||||
|
|
||||||
neighbors_len = len(neighbors_np) # Cache length
|
neighbors_len = len(neighbors_np) # Cache length
|
||||||
begin_orig = min(max(0, begin_orig), neighbors_len)
|
begin_orig = min(max(0, begin_orig), neighbors_len)
|
||||||
end_orig = min(max(begin_orig, end_orig), neighbors_len)
|
end_orig = min(max(begin_orig, end_orig), neighbors_len)
|
||||||
|
|
||||||
@@ -413,83 +527,117 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
|||||||
|
|
||||||
if num_valid > 0:
|
if num_valid > 0:
|
||||||
# Append valid neighbors
|
# Append valid neighbors
|
||||||
compact_neighbors_data.extend(level_neighbors_slice[valid_neighbors_mask])
|
compact_neighbors_data.extend(
|
||||||
|
level_neighbors_slice[valid_neighbors_mask]
|
||||||
|
)
|
||||||
current_data_idx += num_valid
|
current_data_idx += num_valid
|
||||||
total_valid_neighbors_counted += num_valid
|
total_valid_neighbors_counted += num_valid
|
||||||
|
|
||||||
|
|
||||||
compact_level_ptr.append(current_data_idx)
|
compact_level_ptr.append(current_data_idx)
|
||||||
current_level_ptr_idx += num_pointers_expected
|
current_level_ptr_idx += num_pointers_expected
|
||||||
|
|
||||||
compact_node_offsets_np[ntotal] = current_level_ptr_idx
|
compact_node_offsets_np[ntotal] = current_level_ptr_idx
|
||||||
print(f"\r[{time.time() - start_time:.2f}s] Conversion loop finished. ") # Clear progress line
|
print(
|
||||||
|
f"\r[{time.time() - start_time:.2f}s] Conversion loop finished. "
|
||||||
|
) # Clear progress line
|
||||||
|
|
||||||
# --- Validation Checks ---
|
# --- Validation Checks ---
|
||||||
print(f"[{time.time() - start_time:.2f}s] Running validation checks...")
|
print(f"[{time.time() - start_time:.2f}s] Running validation checks...")
|
||||||
valid_check_passed = True
|
valid_check_passed = True
|
||||||
# Check 1: Total valid neighbors count
|
# Check 1: Total valid neighbors count
|
||||||
print(f" Checking total valid neighbor count...")
|
print(" Checking total valid neighbor count...")
|
||||||
expected_valid_count = np.sum(neighbors_np >= 0)
|
expected_valid_count = np.sum(neighbors_np >= 0)
|
||||||
if total_valid_neighbors_counted != len(compact_neighbors_data):
|
if total_valid_neighbors_counted != len(compact_neighbors_data):
|
||||||
print(f"Error: Mismatch between counted valid neighbors ({total_valid_neighbors_counted}) and final compact_data size ({len(compact_neighbors_data)})!", file=sys.stderr)
|
print(
|
||||||
valid_check_passed = False
|
f"Error: Mismatch between counted valid neighbors ({total_valid_neighbors_counted}) and final compact_data size ({len(compact_neighbors_data)})!",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
valid_check_passed = False
|
||||||
if expected_valid_count != len(compact_neighbors_data):
|
if expected_valid_count != len(compact_neighbors_data):
|
||||||
print(f"Error: Mismatch between NumPy count of valid neighbors ({expected_valid_count}) and final compact_data size ({len(compact_neighbors_data)})!", file=sys.stderr)
|
print(
|
||||||
valid_check_passed = False
|
f"Error: Mismatch between NumPy count of valid neighbors ({expected_valid_count}) and final compact_data size ({len(compact_neighbors_data)})!",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
valid_check_passed = False
|
||||||
else:
|
else:
|
||||||
print(f" OK: Total valid neighbors = {len(compact_neighbors_data)}")
|
print(f" OK: Total valid neighbors = {len(compact_neighbors_data)}")
|
||||||
|
|
||||||
# Check 2: Final pointer indices consistency
|
# Check 2: Final pointer indices consistency
|
||||||
print(f" Checking final pointer indices...")
|
print(" Checking final pointer indices...")
|
||||||
if compact_node_offsets_np[ntotal] != len(compact_level_ptr):
|
if compact_node_offsets_np[ntotal] != len(compact_level_ptr):
|
||||||
print(f"Error: Final node offset ({compact_node_offsets_np[ntotal]}) doesn't match level_ptr size ({len(compact_level_ptr)})!", file=sys.stderr)
|
print(
|
||||||
valid_check_passed = False
|
f"Error: Final node offset ({compact_node_offsets_np[ntotal]}) doesn't match level_ptr size ({len(compact_level_ptr)})!",
|
||||||
if (len(compact_level_ptr) > 0 and compact_level_ptr[-1] != len(compact_neighbors_data)) or \
|
file=sys.stderr,
|
||||||
(len(compact_level_ptr) == 0 and len(compact_neighbors_data) != 0):
|
)
|
||||||
last_ptr = compact_level_ptr[-1] if len(compact_level_ptr) > 0 else -1
|
valid_check_passed = False
|
||||||
print(f"Error: Last level pointer ({last_ptr}) doesn't match compact_data size ({len(compact_neighbors_data)})!", file=sys.stderr)
|
if (
|
||||||
valid_check_passed = False
|
len(compact_level_ptr) > 0 and compact_level_ptr[-1] != len(compact_neighbors_data)
|
||||||
|
) or (len(compact_level_ptr) == 0 and len(compact_neighbors_data) != 0):
|
||||||
|
last_ptr = compact_level_ptr[-1] if len(compact_level_ptr) > 0 else -1
|
||||||
|
print(
|
||||||
|
f"Error: Last level pointer ({last_ptr}) doesn't match compact_data size ({len(compact_neighbors_data)})!",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
valid_check_passed = False
|
||||||
else:
|
else:
|
||||||
print(f" OK: Final pointers match data size.")
|
print(" OK: Final pointers match data size.")
|
||||||
|
|
||||||
if not valid_check_passed:
|
if not valid_check_passed:
|
||||||
print("Error: Validation checks failed. Output file might be incorrect.", file=sys.stderr)
|
print(
|
||||||
|
"Error: Validation checks failed. Output file might be incorrect.",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
# Optional: Exit here if validation fails
|
# Optional: Exit here if validation fails
|
||||||
# return False
|
# return False
|
||||||
|
|
||||||
# --- Explicitly delete large intermediate arrays ---
|
# --- Explicitly delete large intermediate arrays ---
|
||||||
print(f"[{time.time() - start_time:.2f}s] Deleting original neighbors and offsets arrays...")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] Deleting original neighbors and offsets arrays..."
|
||||||
|
)
|
||||||
del neighbors_np
|
del neighbors_np
|
||||||
del offsets_np
|
del offsets_np
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
print(f" CSR Stats: |data|={len(compact_neighbors_data)}, |level_ptr|={len(compact_level_ptr)}")
|
print(
|
||||||
|
f" CSR Stats: |data|={len(compact_neighbors_data)}, |level_ptr|={len(compact_level_ptr)}"
|
||||||
|
)
|
||||||
|
|
||||||
# --- Write CSR HNSW graph data using unified function ---
|
# --- Write CSR HNSW graph data using unified function ---
|
||||||
print(f"[{time.time() - start_time:.2f}s] Writing CSR HNSW graph data in FAISS-compatible order...")
|
print(
|
||||||
|
f"[{time.time() - start_time:.2f}s] Writing CSR HNSW graph data in FAISS-compatible order..."
|
||||||
|
)
|
||||||
|
|
||||||
# Determine storage fourcc and data based on prune_embeddings
|
# Determine storage fourcc and data based on prune_embeddings
|
||||||
if prune_embeddings:
|
if prune_embeddings:
|
||||||
print(f" Pruning embeddings: Writing NULL storage marker.")
|
print(" Pruning embeddings: Writing NULL storage marker.")
|
||||||
output_storage_fourcc = NULL_INDEX_FOURCC
|
output_storage_fourcc = NULL_INDEX_FOURCC
|
||||||
storage_data = b''
|
storage_data = b""
|
||||||
else:
|
else:
|
||||||
# Keep embeddings - read and preserve original storage data
|
# Keep embeddings - read and preserve original storage data
|
||||||
if storage_fourcc and storage_fourcc != NULL_INDEX_FOURCC:
|
if storage_fourcc and storage_fourcc != NULL_INDEX_FOURCC:
|
||||||
print(f" Preserving embeddings: Reading original storage data...")
|
print(" Preserving embeddings: Reading original storage data...")
|
||||||
storage_data = f_in.read() # Read remaining storage data
|
storage_data = f_in.read() # Read remaining storage data
|
||||||
output_storage_fourcc = storage_fourcc
|
output_storage_fourcc = storage_fourcc
|
||||||
print(f" Read {len(storage_data)} bytes of storage data")
|
print(f" Read {len(storage_data)} bytes of storage data")
|
||||||
else:
|
else:
|
||||||
print(f" No embeddings found in original file (NULL storage)")
|
print(" No embeddings found in original file (NULL storage)")
|
||||||
output_storage_fourcc = NULL_INDEX_FOURCC
|
output_storage_fourcc = NULL_INDEX_FOURCC
|
||||||
storage_data = b''
|
storage_data = b""
|
||||||
|
|
||||||
# Use the unified write function
|
# Use the unified write function
|
||||||
write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np,
|
write_compact_format(
|
||||||
levels_np, compact_level_ptr, compact_node_offsets_np,
|
f_out,
|
||||||
compact_neighbors_data, output_storage_fourcc, storage_data)
|
original_hnsw_data,
|
||||||
|
assign_probas_np,
|
||||||
|
cum_nneighbor_per_level_np,
|
||||||
|
levels_np,
|
||||||
|
compact_level_ptr,
|
||||||
|
compact_node_offsets_np,
|
||||||
|
compact_neighbors_data,
|
||||||
|
output_storage_fourcc,
|
||||||
|
storage_data,
|
||||||
|
)
|
||||||
|
|
||||||
# Clean up memory
|
# Clean up memory
|
||||||
del assign_probas_np, cum_nneighbor_per_level_np, levels_np
|
del assign_probas_np, cum_nneighbor_per_level_np, levels_np
|
||||||
del compact_neighbors_data, compact_level_ptr, compact_node_offsets_np
|
del compact_neighbors_data, compact_level_ptr, compact_node_offsets_np
|
||||||
@@ -503,40 +651,66 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
|||||||
print(f"Error: Input file not found: {input_filename}", file=sys.stderr)
|
print(f"Error: Input file not found: {input_filename}", file=sys.stderr)
|
||||||
return False
|
return False
|
||||||
except MemoryError as e:
|
except MemoryError as e:
|
||||||
print(f"\nFatal MemoryError during conversion: {e}. Insufficient RAM.", file=sys.stderr)
|
print(
|
||||||
# Clean up potentially partially written output file?
|
f"\nFatal MemoryError during conversion: {e}. Insufficient RAM.",
|
||||||
try: os.remove(output_filename)
|
file=sys.stderr,
|
||||||
except OSError: pass
|
)
|
||||||
return False
|
# Clean up potentially partially written output file?
|
||||||
|
try:
|
||||||
|
os.remove(output_filename)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
return False
|
||||||
except EOFError as e:
|
except EOFError as e:
|
||||||
print(f"Error: Reached end of file unexpectedly reading {input_filename}. {e}", file=sys.stderr)
|
print(
|
||||||
try: os.remove(output_filename)
|
f"Error: Reached end of file unexpectedly reading {input_filename}. {e}",
|
||||||
except OSError: pass
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
os.remove(output_filename)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
return False
|
return False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"An unexpected error occurred during conversion: {e}", file=sys.stderr)
|
print(f"An unexpected error occurred during conversion: {e}", file=sys.stderr)
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
try:
|
try:
|
||||||
os.remove(output_filename)
|
os.remove(output_filename)
|
||||||
except OSError: pass
|
except OSError:
|
||||||
|
pass
|
||||||
return False
|
return False
|
||||||
# Ensure neighbors_np is deleted even if an error occurs after its allocation
|
# Ensure neighbors_np is deleted even if an error occurs after its allocation
|
||||||
finally:
|
finally:
|
||||||
if 'neighbors_np' in locals() and neighbors_np is not None:
|
try:
|
||||||
del neighbors_np
|
if "neighbors_np" in locals() and neighbors_np is not None:
|
||||||
gc.collect()
|
del neighbors_np
|
||||||
|
gc.collect()
|
||||||
|
except NameError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
# --- Script Execution ---
|
# --- Script Execution ---
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Convert a Faiss IndexHNSWFlat file to a CSR-based HNSW graph file.")
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Convert a Faiss IndexHNSWFlat file to a CSR-based HNSW graph file."
|
||||||
|
)
|
||||||
parser.add_argument("input_index_file", help="Path to the input IndexHNSWFlat file")
|
parser.add_argument("input_index_file", help="Path to the input IndexHNSWFlat file")
|
||||||
parser.add_argument("output_csr_graph_file", help="Path to write the output CSR HNSW graph file")
|
parser.add_argument(
|
||||||
parser.add_argument("--prune-embeddings", action="store_true", default=True,
|
"output_csr_graph_file", help="Path to write the output CSR HNSW graph file"
|
||||||
help="Prune embedding storage (write NULL storage marker)")
|
)
|
||||||
parser.add_argument("--keep-embeddings", action="store_true",
|
parser.add_argument(
|
||||||
help="Keep embedding storage (overrides --prune-embeddings)")
|
"--prune-embeddings",
|
||||||
|
action="store_true",
|
||||||
|
default=True,
|
||||||
|
help="Prune embedding storage (write NULL storage marker)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--keep-embeddings",
|
||||||
|
action="store_true",
|
||||||
|
help="Keep embedding storage (overrides --prune-embeddings)",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@@ -545,10 +719,12 @@ if __name__ == "__main__":
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
if os.path.abspath(args.input_index_file) == os.path.abspath(args.output_csr_graph_file):
|
if os.path.abspath(args.input_index_file) == os.path.abspath(args.output_csr_graph_file):
|
||||||
print(f"Error: Input and output filenames cannot be the same.", file=sys.stderr)
|
print("Error: Input and output filenames cannot be the same.", file=sys.stderr)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
prune_embeddings = args.prune_embeddings and not args.keep_embeddings
|
prune_embeddings = args.prune_embeddings and not args.keep_embeddings
|
||||||
success = convert_hnsw_graph_to_csr(args.input_index_file, args.output_csr_graph_file, prune_embeddings)
|
success = convert_hnsw_graph_to_csr(
|
||||||
|
args.input_index_file, args.output_csr_graph_file, prune_embeddings
|
||||||
|
)
|
||||||
if not success:
|
if not success:
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|||||||
@@ -1,19 +1,19 @@
|
|||||||
import numpy as np
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, Any, List, Literal, Optional
|
|
||||||
import shutil
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
from leann.searcher_base import BaseSearcher
|
import numpy as np
|
||||||
from .convert_to_csr import convert_hnsw_graph_to_csr
|
|
||||||
|
|
||||||
from leann.registry import register_backend
|
|
||||||
from leann.interface import (
|
from leann.interface import (
|
||||||
LeannBackendFactoryInterface,
|
|
||||||
LeannBackendBuilderInterface,
|
LeannBackendBuilderInterface,
|
||||||
|
LeannBackendFactoryInterface,
|
||||||
LeannBackendSearcherInterface,
|
LeannBackendSearcherInterface,
|
||||||
)
|
)
|
||||||
|
from leann.registry import register_backend
|
||||||
|
from leann.searcher_base import BaseSearcher
|
||||||
|
|
||||||
|
from .convert_to_csr import convert_hnsw_graph_to_csr
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -28,6 +28,12 @@ def get_metric_map():
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_l2(data: np.ndarray) -> np.ndarray:
|
||||||
|
norms = np.linalg.norm(data, axis=1, keepdims=True)
|
||||||
|
norms[norms == 0] = 1 # Avoid division by zero
|
||||||
|
return data / norms
|
||||||
|
|
||||||
|
|
||||||
@register_backend("hnsw")
|
@register_backend("hnsw")
|
||||||
class HNSWBackend(LeannBackendFactoryInterface):
|
class HNSWBackend(LeannBackendFactoryInterface):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -48,8 +54,14 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
|||||||
self.efConstruction = self.build_params.setdefault("efConstruction", 200)
|
self.efConstruction = self.build_params.setdefault("efConstruction", 200)
|
||||||
self.distance_metric = self.build_params.setdefault("distance_metric", "mips")
|
self.distance_metric = self.build_params.setdefault("distance_metric", "mips")
|
||||||
self.dimensions = self.build_params.get("dimensions")
|
self.dimensions = self.build_params.get("dimensions")
|
||||||
|
if not self.is_recompute:
|
||||||
|
if self.is_compact:
|
||||||
|
# TODO: support this case @andy
|
||||||
|
raise ValueError(
|
||||||
|
"is_recompute is False, but is_compact is True. This is not compatible now. change is compact to False and you can use the original HNSW index."
|
||||||
|
)
|
||||||
|
|
||||||
def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
|
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
|
||||||
from . import faiss # type: ignore
|
from . import faiss # type: ignore
|
||||||
|
|
||||||
path = Path(index_path)
|
path = Path(index_path)
|
||||||
@@ -70,7 +82,7 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
|||||||
index.hnsw.efConstruction = self.efConstruction
|
index.hnsw.efConstruction = self.efConstruction
|
||||||
|
|
||||||
if self.distance_metric.lower() == "cosine":
|
if self.distance_metric.lower() == "cosine":
|
||||||
faiss.normalize_L2(data)
|
data = normalize_l2(data)
|
||||||
|
|
||||||
index.add(data.shape[0], faiss.swig_ptr(data))
|
index.add(data.shape[0], faiss.swig_ptr(data))
|
||||||
index_file = index_dir / f"{index_prefix}.index"
|
index_file = index_dir / f"{index_prefix}.index"
|
||||||
@@ -92,19 +104,15 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
|||||||
|
|
||||||
if success:
|
if success:
|
||||||
logger.info("✅ CSR conversion successful.")
|
logger.info("✅ CSR conversion successful.")
|
||||||
index_file_old = index_file.with_suffix(".old")
|
# index_file_old = index_file.with_suffix(".old")
|
||||||
shutil.move(str(index_file), str(index_file_old))
|
# shutil.move(str(index_file), str(index_file_old))
|
||||||
shutil.move(str(csr_temp_file), str(index_file))
|
shutil.move(str(csr_temp_file), str(index_file))
|
||||||
logger.info(
|
logger.info(f"INFO: Replaced original index with {mode_str} version at '{index_file}'")
|
||||||
f"INFO: Replaced original index with {mode_str} version at '{index_file}'"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# Clean up and fail fast
|
# Clean up and fail fast
|
||||||
if csr_temp_file.exists():
|
if csr_temp_file.exists():
|
||||||
os.remove(csr_temp_file)
|
os.remove(csr_temp_file)
|
||||||
raise RuntimeError(
|
raise RuntimeError("CSR conversion failed - cannot proceed with compact format")
|
||||||
"CSR conversion failed - cannot proceed with compact format"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class HNSWSearcher(BaseSearcher):
|
class HNSWSearcher(BaseSearcher):
|
||||||
@@ -116,7 +124,9 @@ class HNSWSearcher(BaseSearcher):
|
|||||||
)
|
)
|
||||||
from . import faiss # type: ignore
|
from . import faiss # type: ignore
|
||||||
|
|
||||||
self.distance_metric = self.meta.get("distance_metric", "mips").lower()
|
self.distance_metric = (
|
||||||
|
self.meta.get("backend_kwargs", {}).get("distance_metric", "mips").lower()
|
||||||
|
)
|
||||||
metric_enum = get_metric_map().get(self.distance_metric)
|
metric_enum = get_metric_map().get(self.distance_metric)
|
||||||
if metric_enum is None:
|
if metric_enum is None:
|
||||||
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
|
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
|
||||||
@@ -142,7 +152,7 @@ class HNSWSearcher(BaseSearcher):
|
|||||||
self,
|
self,
|
||||||
query: np.ndarray,
|
query: np.ndarray,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
zmq_port: Optional[int] = None,
|
zmq_port: int | None = None,
|
||||||
complexity: int = 64,
|
complexity: int = 64,
|
||||||
beam_width: int = 1,
|
beam_width: int = 1,
|
||||||
prune_ratio: float = 0.0,
|
prune_ratio: float = 0.0,
|
||||||
@@ -150,7 +160,7 @@ class HNSWSearcher(BaseSearcher):
|
|||||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
batch_size: int = 0,
|
batch_size: int = 0,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Search for nearest neighbors using HNSW index.
|
Search for nearest neighbors using HNSW index.
|
||||||
|
|
||||||
@@ -179,23 +189,29 @@ class HNSWSearcher(BaseSearcher):
|
|||||||
raise RuntimeError("Recompute is required for pruned index.")
|
raise RuntimeError("Recompute is required for pruned index.")
|
||||||
if recompute_embeddings:
|
if recompute_embeddings:
|
||||||
if zmq_port is None:
|
if zmq_port is None:
|
||||||
raise ValueError(
|
raise ValueError("zmq_port must be provided if recompute_embeddings is True")
|
||||||
"zmq_port must be provided if recompute_embeddings is True"
|
|
||||||
)
|
|
||||||
|
|
||||||
if query.dtype != np.float32:
|
if query.dtype != np.float32:
|
||||||
query = query.astype(np.float32)
|
query = query.astype(np.float32)
|
||||||
if self.distance_metric == "cosine":
|
if self.distance_metric == "cosine":
|
||||||
faiss.normalize_L2(query)
|
query = normalize_l2(query)
|
||||||
|
|
||||||
params = faiss.SearchParametersHNSW()
|
params = faiss.SearchParametersHNSW()
|
||||||
if zmq_port is not None:
|
if zmq_port is not None:
|
||||||
params.zmq_port = (
|
params.zmq_port = zmq_port # C++ code won't use this if recompute_embeddings is False
|
||||||
zmq_port # C++ code won't use this if recompute_embeddings is False
|
|
||||||
)
|
|
||||||
params.efSearch = complexity
|
params.efSearch = complexity
|
||||||
params.beam_size = beam_width
|
params.beam_size = beam_width
|
||||||
|
|
||||||
|
# For OpenAI embeddings with cosine distance, disable relative distance check
|
||||||
|
# This prevents early termination when all scores are in a narrow range
|
||||||
|
embedding_model = self.meta.get("embedding_model", "").lower()
|
||||||
|
if self.distance_metric == "cosine" and any(
|
||||||
|
openai_model in embedding_model for openai_model in ["text-embedding", "openai"]
|
||||||
|
):
|
||||||
|
params.check_relative_distance = False
|
||||||
|
else:
|
||||||
|
params.check_relative_distance = True
|
||||||
|
|
||||||
# PQ pruning: direct mapping to HNSW's pq_pruning_ratio
|
# PQ pruning: direct mapping to HNSW's pq_pruning_ratio
|
||||||
params.pq_pruning_ratio = prune_ratio
|
params.pq_pruning_ratio = prune_ratio
|
||||||
|
|
||||||
@@ -205,9 +221,7 @@ class HNSWSearcher(BaseSearcher):
|
|||||||
params.send_neigh_times_ratio = 0.0
|
params.send_neigh_times_ratio = 0.0
|
||||||
elif pruning_strategy == "proportional":
|
elif pruning_strategy == "proportional":
|
||||||
params.local_prune = False
|
params.local_prune = False
|
||||||
params.send_neigh_times_ratio = (
|
params.send_neigh_times_ratio = 1.0 # Any value > 1e-6 triggers proportional mode
|
||||||
1.0 # Any value > 1e-6 triggers proportional mode
|
|
||||||
)
|
|
||||||
else: # "global"
|
else: # "global"
|
||||||
params.local_prune = False
|
params.local_prune = False
|
||||||
params.send_neigh_times_ratio = 0.0
|
params.send_neigh_times_ratio = 0.0
|
||||||
@@ -228,8 +242,6 @@ class HNSWSearcher(BaseSearcher):
|
|||||||
params,
|
params,
|
||||||
)
|
)
|
||||||
|
|
||||||
string_labels = [
|
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
|
||||||
[str(int_label) for int_label in batch_labels] for batch_labels in labels
|
|
||||||
]
|
|
||||||
|
|
||||||
return {"labels": string_labels, "distances": distances}
|
return {"labels": string_labels, "distances": distances}
|
||||||
|
|||||||
@@ -3,17 +3,17 @@ HNSW-specific embedding server
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import os
|
|
||||||
import zmq
|
|
||||||
import numpy as np
|
|
||||||
import msgpack
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
|
||||||
import sys
|
import msgpack
|
||||||
import logging
|
import numpy as np
|
||||||
|
import zmq
|
||||||
|
|
||||||
# Set up logging based on environment variable
|
# Set up logging based on environment variable
|
||||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||||
@@ -33,7 +33,7 @@ if not logger.handlers:
|
|||||||
|
|
||||||
|
|
||||||
def create_hnsw_embedding_server(
|
def create_hnsw_embedding_server(
|
||||||
passages_file: Optional[str] = None,
|
passages_file: str | None = None,
|
||||||
zmq_port: int = 5555,
|
zmq_port: int = 5555,
|
||||||
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
||||||
distance_metric: str = "mips",
|
distance_metric: str = "mips",
|
||||||
@@ -52,8 +52,8 @@ def create_hnsw_embedding_server(
|
|||||||
sys.path.insert(0, str(leann_core_path))
|
sys.path.insert(0, str(leann_core_path))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from leann.embedding_compute import compute_embeddings
|
|
||||||
from leann.api import PassageManager
|
from leann.api import PassageManager
|
||||||
|
from leann.embedding_compute import compute_embeddings
|
||||||
|
|
||||||
logger.info("Successfully imported unified embedding computation module")
|
logger.info("Successfully imported unified embedding computation module")
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
@@ -78,10 +78,22 @@ def create_hnsw_embedding_server(
|
|||||||
raise ValueError("Only metadata files (.meta.json) are supported")
|
raise ValueError("Only metadata files (.meta.json) are supported")
|
||||||
|
|
||||||
# Load metadata to get passage sources
|
# Load metadata to get passage sources
|
||||||
with open(passages_file, "r") as f:
|
with open(passages_file) as f:
|
||||||
meta = json.load(f)
|
meta = json.load(f)
|
||||||
|
|
||||||
passages = PassageManager(meta["passage_sources"])
|
# Convert relative paths to absolute paths based on metadata file location
|
||||||
|
metadata_dir = Path(passages_file).parent.parent # Go up one level from the metadata file
|
||||||
|
passage_sources = []
|
||||||
|
for source in meta["passage_sources"]:
|
||||||
|
source_copy = source.copy()
|
||||||
|
# Convert relative paths to absolute paths
|
||||||
|
if not Path(source_copy["path"]).is_absolute():
|
||||||
|
source_copy["path"] = str(metadata_dir / source_copy["path"])
|
||||||
|
if not Path(source_copy["index_path"]).is_absolute():
|
||||||
|
source_copy["index_path"] = str(metadata_dir / source_copy["index_path"])
|
||||||
|
passage_sources.append(source_copy)
|
||||||
|
|
||||||
|
passages = PassageManager(passage_sources)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
||||||
)
|
)
|
||||||
@@ -120,9 +132,7 @@ def create_hnsw_embedding_server(
|
|||||||
response = embeddings.tolist()
|
response = embeddings.tolist()
|
||||||
socket.send(msgpack.packb(response))
|
socket.send(msgpack.packb(response))
|
||||||
e2e_end = time.time()
|
e2e_end = time.time()
|
||||||
logger.info(
|
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
|
||||||
f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s"
|
|
||||||
)
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Handle distance calculation requests
|
# Handle distance calculation requests
|
||||||
@@ -148,17 +158,13 @@ def create_hnsw_embedding_server(
|
|||||||
texts.append(txt)
|
texts.append(txt)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
logger.error(f"Passage ID {nid} not found")
|
logger.error(f"Passage ID {nid} not found")
|
||||||
raise RuntimeError(
|
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
|
||||||
f"FATAL: Passage with ID {nid} not found"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# Process embeddings
|
# Process embeddings
|
||||||
embeddings = compute_embeddings(
|
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||||
texts, model_name, mode=embedding_mode
|
|
||||||
)
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||||
)
|
)
|
||||||
@@ -172,18 +178,12 @@ def create_hnsw_embedding_server(
|
|||||||
distances = -np.dot(embeddings, query_vector)
|
distances = -np.dot(embeddings, query_vector)
|
||||||
|
|
||||||
response_payload = distances.flatten().tolist()
|
response_payload = distances.flatten().tolist()
|
||||||
response_bytes = msgpack.packb(
|
response_bytes = msgpack.packb([response_payload], use_single_float=True)
|
||||||
[response_payload], use_single_float=True
|
logger.debug(f"Sending distance response with {len(distances)} distances")
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
f"Sending distance response with {len(distances)} distances"
|
|
||||||
)
|
|
||||||
|
|
||||||
socket.send(response_bytes)
|
socket.send(response_bytes)
|
||||||
e2e_end = time.time()
|
e2e_end = time.time()
|
||||||
logger.info(
|
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
|
||||||
f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s"
|
|
||||||
)
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Standard embedding request (passage ID lookup)
|
# Standard embedding request (passage ID lookup)
|
||||||
@@ -208,9 +208,7 @@ def create_hnsw_embedding_server(
|
|||||||
passage_data = passages.get_passage(str(nid))
|
passage_data = passages.get_passage(str(nid))
|
||||||
txt = passage_data["text"]
|
txt = passage_data["text"]
|
||||||
if not txt:
|
if not txt:
|
||||||
raise RuntimeError(
|
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
|
||||||
f"FATAL: Empty text for passage ID {nid}"
|
|
||||||
)
|
|
||||||
texts.append(txt)
|
texts.append(txt)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
|
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
|
||||||
@@ -229,11 +227,9 @@ def create_hnsw_embedding_server(
|
|||||||
logger.error(
|
logger.error(
|
||||||
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
|
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
|
||||||
)
|
)
|
||||||
assert False
|
raise AssertionError()
|
||||||
|
|
||||||
hidden_contiguous_f32 = np.ascontiguousarray(
|
hidden_contiguous_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||||
embeddings, dtype=np.float32
|
|
||||||
)
|
|
||||||
response_payload = [
|
response_payload = [
|
||||||
list(hidden_contiguous_f32.shape),
|
list(hidden_contiguous_f32.shape),
|
||||||
hidden_contiguous_f32.flatten().tolist(),
|
hidden_contiguous_f32.flatten().tolist(),
|
||||||
@@ -270,15 +266,15 @@ def create_hnsw_embedding_server(
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
def signal_handler(sig, frame):
|
def signal_handler(sig, frame):
|
||||||
logger.info(f"Received signal {sig}, shutting down gracefully...")
|
logger.info(f"Received signal {sig}, shutting down gracefully...")
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
# Register signal handlers for graceful shutdown
|
# Register signal handlers for graceful shutdown
|
||||||
signal.signal(signal.SIGTERM, signal_handler)
|
signal.signal(signal.SIGTERM, signal_handler)
|
||||||
signal.signal(signal.SIGINT, signal_handler)
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="HNSW Embedding service")
|
parser = argparse.ArgumentParser(description="HNSW Embedding service")
|
||||||
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
@@ -6,9 +6,14 @@ build-backend = "scikit_build_core.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann-backend-hnsw"
|
name = "leann-backend-hnsw"
|
||||||
version = "0.1.0"
|
version = "0.1.15"
|
||||||
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
|
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
|
||||||
dependencies = ["leann-core==0.1.0", "numpy"]
|
dependencies = [
|
||||||
|
"leann-core==0.1.15",
|
||||||
|
"numpy",
|
||||||
|
"pyzmq>=23.0.0",
|
||||||
|
"msgpack>=1.0.0",
|
||||||
|
]
|
||||||
|
|
||||||
[tool.scikit-build]
|
[tool.scikit-build]
|
||||||
wheel.packages = ["leann_backend_hnsw"]
|
wheel.packages = ["leann_backend_hnsw"]
|
||||||
@@ -19,4 +24,4 @@ build.tool-args = ["-j8"]
|
|||||||
|
|
||||||
# CMake definitions to optimize compilation
|
# CMake definitions to optimize compilation
|
||||||
[tool.scikit-build.cmake.define]
|
[tool.scikit-build.cmake.define]
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL = "8"
|
CMAKE_BUILD_PARALLEL_LEVEL = "8"
|
||||||
|
|||||||
@@ -4,19 +4,46 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann-core"
|
name = "leann-core"
|
||||||
version = "0.1.0"
|
version = "0.1.15"
|
||||||
description = "Core API and plugin system for Leann."
|
description = "Core API and plugin system for LEANN"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.9"
|
||||||
license = { text = "MIT" }
|
license = { text = "MIT" }
|
||||||
|
|
||||||
|
# All required dependencies included
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"numpy>=1.20.0",
|
"numpy>=1.20.0",
|
||||||
"tqdm>=4.60.0"
|
"tqdm>=4.60.0",
|
||||||
|
"psutil>=5.8.0",
|
||||||
|
"pyzmq>=23.0.0",
|
||||||
|
"msgpack>=1.0.0",
|
||||||
|
"torch>=2.0.0",
|
||||||
|
"sentence-transformers>=2.2.0",
|
||||||
|
"llama-index-core>=0.12.0",
|
||||||
|
"llama-index-readers-file>=0.4.0", # Essential for document reading
|
||||||
|
"llama-index-embeddings-huggingface>=0.5.5", # For embeddings
|
||||||
|
"python-dotenv>=1.0.0",
|
||||||
|
"openai>=1.0.0",
|
||||||
|
"huggingface-hub>=0.20.0",
|
||||||
|
"transformers>=4.30.0",
|
||||||
|
"requests>=2.25.0",
|
||||||
|
"accelerate>=0.20.0",
|
||||||
|
"PyPDF2>=3.0.0",
|
||||||
|
"pymupdf>=1.23.0",
|
||||||
|
"pdfplumber>=0.10.0",
|
||||||
|
"mlx>=0.26.3; sys_platform == 'darwin'",
|
||||||
|
"mlx-lm>=0.26.0; sys_platform == 'darwin'",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
colab = [
|
||||||
|
"torch>=2.0.0,<3.0.0", # Limit torch version to avoid conflicts
|
||||||
|
"transformers>=4.30.0,<5.0.0", # Limit transformers version
|
||||||
|
"accelerate>=0.20.0,<1.0.0", # Limit accelerate version
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
leann = "leann.cli:main"
|
leann = "leann.cli:main"
|
||||||
|
|
||||||
[tool.setuptools.packages.find]
|
[tool.setuptools.packages.find]
|
||||||
where = ["src"]
|
where = ["src"]
|
||||||
|
|||||||
@@ -8,10 +8,14 @@ if platform.system() == "Darwin":
|
|||||||
os.environ["MKL_NUM_THREADS"] = "1"
|
os.environ["MKL_NUM_THREADS"] = "1"
|
||||||
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
||||||
os.environ["KMP_BLOCKTIME"] = "0"
|
os.environ["KMP_BLOCKTIME"] = "0"
|
||||||
|
# Additional fixes for PyTorch/sentence-transformers on macOS ARM64 only in CI
|
||||||
|
if os.environ.get("CI") == "true":
|
||||||
|
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "0"
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
from .api import LeannBuilder, LeannChat, LeannSearcher
|
from .api import LeannBuilder, LeannChat, LeannSearcher
|
||||||
from .registry import BACKEND_REGISTRY, autodiscover_backends
|
from .registry import BACKEND_REGISTRY, autodiscover_backends
|
||||||
|
|
||||||
autodiscover_backends()
|
autodiscover_backends()
|
||||||
|
|
||||||
__all__ = ["LeannBuilder", "LeannSearcher", "LeannChat", "BACKEND_REGISTRY"]
|
__all__ = ["BACKEND_REGISTRY", "LeannBuilder", "LeannChat", "LeannSearcher"]
|
||||||
|
|||||||
@@ -4,27 +4,36 @@ with the correct, original embedding logic from the user's reference code.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import pickle
|
|
||||||
from leann.interface import LeannBackendSearcherInterface
|
|
||||||
import numpy as np
|
|
||||||
import time
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List, Dict, Any, Optional, Literal
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from .registry import BACKEND_REGISTRY
|
|
||||||
from .interface import LeannBackendFactoryInterface
|
|
||||||
from .chat import get_llm
|
|
||||||
import logging
|
import logging
|
||||||
|
import pickle
|
||||||
|
import time
|
||||||
|
import warnings
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from leann.interface import LeannBackendSearcherInterface
|
||||||
|
|
||||||
|
from .chat import get_llm
|
||||||
|
from .interface import LeannBackendFactoryInterface
|
||||||
|
from .registry import BACKEND_REGISTRY
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_registered_backends() -> list[str]:
|
||||||
|
"""Get list of registered backend names."""
|
||||||
|
return list(BACKEND_REGISTRY.keys())
|
||||||
|
|
||||||
|
|
||||||
def compute_embeddings(
|
def compute_embeddings(
|
||||||
chunks: List[str],
|
chunks: list[str],
|
||||||
model_name: str,
|
model_name: str,
|
||||||
mode: str = "sentence-transformers",
|
mode: str = "sentence-transformers",
|
||||||
use_server: bool = True,
|
use_server: bool = True,
|
||||||
port: Optional[int] = None,
|
port: int | None = None,
|
||||||
is_build=False,
|
is_build=False,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
@@ -61,9 +70,7 @@ def compute_embeddings(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def compute_embeddings_via_server(
|
def compute_embeddings_via_server(chunks: list[str], model_name: str, port: int) -> np.ndarray:
|
||||||
chunks: List[str], model_name: str, port: int
|
|
||||||
) -> np.ndarray:
|
|
||||||
"""Computes embeddings using sentence-transformers.
|
"""Computes embeddings using sentence-transformers.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -73,9 +80,9 @@ def compute_embeddings_via_server(
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
|
f"Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
|
||||||
)
|
)
|
||||||
import zmq
|
|
||||||
import msgpack
|
import msgpack
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import zmq
|
||||||
|
|
||||||
# Connect to embedding server
|
# Connect to embedding server
|
||||||
context = zmq.Context()
|
context = zmq.Context()
|
||||||
@@ -104,11 +111,11 @@ class SearchResult:
|
|||||||
id: str
|
id: str
|
||||||
score: float
|
score: float
|
||||||
text: str
|
text: str
|
||||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class PassageManager:
|
class PassageManager:
|
||||||
def __init__(self, passage_sources: List[Dict[str, Any]]):
|
def __init__(self, passage_sources: list[dict[str, Any]]):
|
||||||
self.offset_maps = {}
|
self.offset_maps = {}
|
||||||
self.passage_files = {}
|
self.passage_files = {}
|
||||||
self.global_offset_map = {} # Combined map for fast lookup
|
self.global_offset_map = {} # Combined map for fast lookup
|
||||||
@@ -117,8 +124,15 @@ class PassageManager:
|
|||||||
assert source["type"] == "jsonl", "only jsonl is supported"
|
assert source["type"] == "jsonl", "only jsonl is supported"
|
||||||
passage_file = source["path"]
|
passage_file = source["path"]
|
||||||
index_file = source["index_path"] # .idx file
|
index_file = source["index_path"] # .idx file
|
||||||
|
|
||||||
|
# Fix path resolution for Colab and other environments
|
||||||
|
if not Path(index_file).is_absolute():
|
||||||
|
# If relative path, try to resolve it properly
|
||||||
|
index_file = str(Path(index_file).resolve())
|
||||||
|
|
||||||
if not Path(index_file).exists():
|
if not Path(index_file).exists():
|
||||||
raise FileNotFoundError(f"Passage index file not found: {index_file}")
|
raise FileNotFoundError(f"Passage index file not found: {index_file}")
|
||||||
|
|
||||||
with open(index_file, "rb") as f:
|
with open(index_file, "rb") as f:
|
||||||
offset_map = pickle.load(f)
|
offset_map = pickle.load(f)
|
||||||
self.offset_maps[passage_file] = offset_map
|
self.offset_maps[passage_file] = offset_map
|
||||||
@@ -128,11 +142,11 @@ class PassageManager:
|
|||||||
for passage_id, offset in offset_map.items():
|
for passage_id, offset in offset_map.items():
|
||||||
self.global_offset_map[passage_id] = (passage_file, offset)
|
self.global_offset_map[passage_id] = (passage_file, offset)
|
||||||
|
|
||||||
def get_passage(self, passage_id: str) -> Dict[str, Any]:
|
def get_passage(self, passage_id: str) -> dict[str, Any]:
|
||||||
if passage_id in self.global_offset_map:
|
if passage_id in self.global_offset_map:
|
||||||
passage_file, offset = self.global_offset_map[passage_id]
|
passage_file, offset = self.global_offset_map[passage_id]
|
||||||
# Lazy file opening - only open when needed
|
# Lazy file opening - only open when needed
|
||||||
with open(passage_file, "r", encoding="utf-8") as f:
|
with open(passage_file, encoding="utf-8") as f:
|
||||||
f.seek(offset)
|
f.seek(offset)
|
||||||
return json.loads(f.readline())
|
return json.loads(f.readline())
|
||||||
raise KeyError(f"Passage ID not found: {passage_id}")
|
raise KeyError(f"Passage ID not found: {passage_id}")
|
||||||
@@ -142,25 +156,93 @@ class LeannBuilder:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
backend_name: str,
|
backend_name: str,
|
||||||
embedding_model: str = "facebook/contriever-msmarco",
|
embedding_model: str = "facebook/contriever",
|
||||||
dimensions: Optional[int] = None,
|
dimensions: int | None = None,
|
||||||
embedding_mode: str = "sentence-transformers",
|
embedding_mode: str = "sentence-transformers",
|
||||||
**backend_kwargs,
|
**backend_kwargs,
|
||||||
):
|
):
|
||||||
self.backend_name = backend_name
|
self.backend_name = backend_name
|
||||||
backend_factory: LeannBackendFactoryInterface | None = BACKEND_REGISTRY.get(
|
backend_factory: LeannBackendFactoryInterface | None = BACKEND_REGISTRY.get(backend_name)
|
||||||
backend_name
|
|
||||||
)
|
|
||||||
if backend_factory is None:
|
if backend_factory is None:
|
||||||
raise ValueError(f"Backend '{backend_name}' not found or not registered.")
|
raise ValueError(f"Backend '{backend_name}' not found or not registered.")
|
||||||
self.backend_factory = backend_factory
|
self.backend_factory = backend_factory
|
||||||
self.embedding_model = embedding_model
|
self.embedding_model = embedding_model
|
||||||
self.dimensions = dimensions
|
self.dimensions = dimensions
|
||||||
self.embedding_mode = embedding_mode
|
self.embedding_mode = embedding_mode
|
||||||
self.backend_kwargs = backend_kwargs
|
|
||||||
self.chunks: List[Dict[str, Any]] = []
|
|
||||||
|
|
||||||
def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
|
# Check if we need to use cosine distance for normalized embeddings
|
||||||
|
normalized_embeddings_models = {
|
||||||
|
# OpenAI models
|
||||||
|
("openai", "text-embedding-ada-002"),
|
||||||
|
("openai", "text-embedding-3-small"),
|
||||||
|
("openai", "text-embedding-3-large"),
|
||||||
|
# Voyage AI models
|
||||||
|
("voyage", "voyage-2"),
|
||||||
|
("voyage", "voyage-3"),
|
||||||
|
("voyage", "voyage-large-2"),
|
||||||
|
("voyage", "voyage-multilingual-2"),
|
||||||
|
("voyage", "voyage-code-2"),
|
||||||
|
# Cohere models
|
||||||
|
("cohere", "embed-english-v3.0"),
|
||||||
|
("cohere", "embed-multilingual-v3.0"),
|
||||||
|
("cohere", "embed-english-light-v3.0"),
|
||||||
|
("cohere", "embed-multilingual-light-v3.0"),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Also check for patterns in model names
|
||||||
|
is_normalized = False
|
||||||
|
current_model_lower = embedding_model.lower()
|
||||||
|
current_mode_lower = embedding_mode.lower()
|
||||||
|
|
||||||
|
# Check exact matches
|
||||||
|
for mode, model in normalized_embeddings_models:
|
||||||
|
if (current_mode_lower == mode and current_model_lower == model) or (
|
||||||
|
mode in current_mode_lower and model in current_model_lower
|
||||||
|
):
|
||||||
|
is_normalized = True
|
||||||
|
break
|
||||||
|
|
||||||
|
# Check patterns
|
||||||
|
if not is_normalized:
|
||||||
|
# OpenAI patterns
|
||||||
|
if "openai" in current_mode_lower or "openai" in current_model_lower:
|
||||||
|
if any(
|
||||||
|
pattern in current_model_lower
|
||||||
|
for pattern in ["text-embedding", "ada", "3-small", "3-large"]
|
||||||
|
):
|
||||||
|
is_normalized = True
|
||||||
|
# Voyage patterns
|
||||||
|
elif "voyage" in current_mode_lower or "voyage" in current_model_lower:
|
||||||
|
is_normalized = True
|
||||||
|
# Cohere patterns
|
||||||
|
elif "cohere" in current_mode_lower or "cohere" in current_model_lower:
|
||||||
|
if "embed" in current_model_lower:
|
||||||
|
is_normalized = True
|
||||||
|
|
||||||
|
# Handle distance metric
|
||||||
|
if is_normalized and "distance_metric" not in backend_kwargs:
|
||||||
|
backend_kwargs["distance_metric"] = "cosine"
|
||||||
|
warnings.warn(
|
||||||
|
f"Detected normalized embeddings model '{embedding_model}' with mode '{embedding_mode}'. "
|
||||||
|
f"Automatically setting distance_metric='cosine' for optimal performance. "
|
||||||
|
f"Normalized embeddings (L2 norm = 1) should use cosine similarity instead of MIPS.",
|
||||||
|
UserWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
elif is_normalized and backend_kwargs.get("distance_metric", "").lower() != "cosine":
|
||||||
|
current_metric = backend_kwargs.get("distance_metric", "mips")
|
||||||
|
warnings.warn(
|
||||||
|
f"Warning: Using '{current_metric}' distance metric with normalized embeddings model "
|
||||||
|
f"'{embedding_model}' may lead to suboptimal search results. "
|
||||||
|
f"Consider using 'cosine' distance metric for better performance.",
|
||||||
|
UserWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.backend_kwargs = backend_kwargs
|
||||||
|
self.chunks: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
def add_text(self, text: str, metadata: dict[str, Any] | None = None):
|
||||||
if metadata is None:
|
if metadata is None:
|
||||||
metadata = {}
|
metadata = {}
|
||||||
passage_id = metadata.get("id", str(len(self.chunks)))
|
passage_id = metadata.get("id", str(len(self.chunks)))
|
||||||
@@ -190,9 +272,7 @@ class LeannBuilder:
|
|||||||
try:
|
try:
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
chunk_iterator = tqdm(
|
chunk_iterator = tqdm(self.chunks, desc="Writing passages", unit="chunk")
|
||||||
self.chunks, desc="Writing passages", unit="chunk"
|
|
||||||
)
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
chunk_iterator = self.chunks
|
chunk_iterator = self.chunks
|
||||||
|
|
||||||
@@ -222,9 +302,7 @@ class LeannBuilder:
|
|||||||
string_ids = [chunk["id"] for chunk in self.chunks]
|
string_ids = [chunk["id"] for chunk in self.chunks]
|
||||||
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
|
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
|
||||||
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
|
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
|
||||||
builder_instance.build(
|
builder_instance.build(embeddings, string_ids, index_path, **current_backend_kwargs)
|
||||||
embeddings, string_ids, index_path, **current_backend_kwargs
|
|
||||||
)
|
|
||||||
leann_meta_path = index_dir / f"{index_name}.meta.json"
|
leann_meta_path = index_dir / f"{index_name}.meta.json"
|
||||||
meta_data = {
|
meta_data = {
|
||||||
"version": "1.0",
|
"version": "1.0",
|
||||||
@@ -273,9 +351,7 @@ class LeannBuilder:
|
|||||||
ids, embeddings = data
|
ids, embeddings = data
|
||||||
|
|
||||||
if not isinstance(embeddings, np.ndarray):
|
if not isinstance(embeddings, np.ndarray):
|
||||||
raise ValueError(
|
raise ValueError(f"Expected embeddings to be numpy array, got {type(embeddings)}")
|
||||||
f"Expected embeddings to be numpy array, got {type(embeddings)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(ids) != embeddings.shape[0]:
|
if len(ids) != embeddings.shape[0]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -287,9 +363,7 @@ class LeannBuilder:
|
|||||||
if self.dimensions is None:
|
if self.dimensions is None:
|
||||||
self.dimensions = embedding_dim
|
self.dimensions = embedding_dim
|
||||||
elif self.dimensions != embedding_dim:
|
elif self.dimensions != embedding_dim:
|
||||||
raise ValueError(
|
raise ValueError(f"Dimension mismatch: expected {self.dimensions}, got {embedding_dim}")
|
||||||
f"Dimension mismatch: expected {self.dimensions}, got {embedding_dim}"
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Building index from precomputed embeddings: {len(ids)} items, {embedding_dim} dimensions"
|
f"Building index from precomputed embeddings: {len(ids)} items, {embedding_dim} dimensions"
|
||||||
@@ -374,26 +448,24 @@ class LeannBuilder:
|
|||||||
with open(leann_meta_path, "w", encoding="utf-8") as f:
|
with open(leann_meta_path, "w", encoding="utf-8") as f:
|
||||||
json.dump(meta_data, f, indent=2)
|
json.dump(meta_data, f, indent=2)
|
||||||
|
|
||||||
logger.info(
|
logger.info(f"Index built successfully from precomputed embeddings: {index_path}")
|
||||||
f"Index built successfully from precomputed embeddings: {index_path}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LeannSearcher:
|
class LeannSearcher:
|
||||||
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
|
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
|
||||||
|
# Fix path resolution for Colab and other environments
|
||||||
|
if not Path(index_path).is_absolute():
|
||||||
|
index_path = str(Path(index_path).resolve())
|
||||||
|
|
||||||
self.meta_path_str = f"{index_path}.meta.json"
|
self.meta_path_str = f"{index_path}.meta.json"
|
||||||
if not Path(self.meta_path_str).exists():
|
if not Path(self.meta_path_str).exists():
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError(f"Leann metadata file not found at {self.meta_path_str}")
|
||||||
f"Leann metadata file not found at {self.meta_path_str}"
|
with open(self.meta_path_str, encoding="utf-8") as f:
|
||||||
)
|
|
||||||
with open(self.meta_path_str, "r", encoding="utf-8") as f:
|
|
||||||
self.meta_data = json.load(f)
|
self.meta_data = json.load(f)
|
||||||
backend_name = self.meta_data["backend_name"]
|
backend_name = self.meta_data["backend_name"]
|
||||||
self.embedding_model = self.meta_data["embedding_model"]
|
self.embedding_model = self.meta_data["embedding_model"]
|
||||||
# Support both old and new format
|
# Support both old and new format
|
||||||
self.embedding_mode = self.meta_data.get(
|
self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
|
||||||
"embedding_mode", "sentence-transformers"
|
|
||||||
)
|
|
||||||
self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
|
self.passage_manager = PassageManager(self.meta_data.get("passage_sources", []))
|
||||||
backend_factory = BACKEND_REGISTRY.get(backend_name)
|
backend_factory = BACKEND_REGISTRY.get(backend_name)
|
||||||
if backend_factory is None:
|
if backend_factory is None:
|
||||||
@@ -415,7 +487,7 @@ class LeannSearcher:
|
|||||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
expected_zmq_port: int = 5557,
|
expected_zmq_port: int = 5557,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[SearchResult]:
|
) -> list[SearchResult]:
|
||||||
logger.info("🔍 LeannSearcher.search() called:")
|
logger.info("🔍 LeannSearcher.search() called:")
|
||||||
logger.info(f" Query: '{query}'")
|
logger.info(f" Query: '{query}'")
|
||||||
logger.info(f" Top_k: {top_k}")
|
logger.info(f" Top_k: {top_k}")
|
||||||
@@ -441,9 +513,9 @@ class LeannSearcher:
|
|||||||
use_server_if_available=recompute_embeddings,
|
use_server_if_available=recompute_embeddings,
|
||||||
zmq_port=zmq_port,
|
zmq_port=zmq_port,
|
||||||
)
|
)
|
||||||
logger.info(f" Generated embedding shape: {query_embedding.shape}")
|
# logger.info(f" Generated embedding shape: {query_embedding.shape}")
|
||||||
embedding_time = time.time() - start_time
|
time.time() - start_time
|
||||||
logger.info(f" Embedding time: {embedding_time} seconds")
|
# logger.info(f" Embedding time: {embedding_time} seconds")
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
results = self.backend_impl.search(
|
results = self.backend_impl.search(
|
||||||
@@ -457,17 +529,15 @@ class LeannSearcher:
|
|||||||
zmq_port=zmq_port,
|
zmq_port=zmq_port,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
search_time = time.time() - start_time
|
time.time() - start_time
|
||||||
logger.info(f" Search time: {search_time} seconds")
|
# logger.info(f" Search time: {search_time} seconds")
|
||||||
logger.info(
|
logger.info(f" Backend returned: labels={len(results.get('labels', [[]])[0])} results")
|
||||||
f" Backend returned: labels={len(results.get('labels', [[]])[0])} results"
|
|
||||||
)
|
|
||||||
|
|
||||||
enriched_results = []
|
enriched_results = []
|
||||||
if "labels" in results and "distances" in results:
|
if "labels" in results and "distances" in results:
|
||||||
logger.info(f" Processing {len(results['labels'][0])} passage IDs:")
|
logger.info(f" Processing {len(results['labels'][0])} passage IDs:")
|
||||||
for i, (string_id, dist) in enumerate(
|
for i, (string_id, dist) in enumerate(
|
||||||
zip(results["labels"][0], results["distances"][0])
|
zip(results["labels"][0], results["distances"][0], strict=False)
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
passage_data = self.passage_manager.get_passage(string_id)
|
passage_data = self.passage_manager.get_passage(string_id)
|
||||||
@@ -479,15 +549,25 @@ class LeannSearcher:
|
|||||||
metadata=passage_data.get("metadata", {}),
|
metadata=passage_data.get("metadata", {}),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Color codes for better logging
|
||||||
|
GREEN = "\033[92m"
|
||||||
|
BLUE = "\033[94m"
|
||||||
|
YELLOW = "\033[93m"
|
||||||
|
RESET = "\033[0m"
|
||||||
|
|
||||||
|
# Truncate text for display (first 100 chars)
|
||||||
|
display_text = passage_data["text"]
|
||||||
logger.info(
|
logger.info(
|
||||||
f" {i + 1}. passage_id='{string_id}' -> SUCCESS: {passage_data['text']}..."
|
f" {GREEN}✓{RESET} {BLUE}[{i + 1:2d}]{RESET} {YELLOW}ID:{RESET} '{string_id}' {YELLOW}Score:{RESET} {dist:.4f} {YELLOW}Text:{RESET} {display_text}"
|
||||||
)
|
)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
|
RED = "\033[91m"
|
||||||
logger.error(
|
logger.error(
|
||||||
f" {i + 1}. passage_id='{string_id}' -> ERROR: Passage not found in PassageManager!"
|
f" {RED}✗{RESET} [{i + 1:2d}] ID: '{string_id}' -> {RED}ERROR: Passage not found!{RESET}"
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f" Final enriched results: {len(enriched_results)} passages")
|
logger.info(f" {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}")
|
||||||
return enriched_results
|
return enriched_results
|
||||||
|
|
||||||
|
|
||||||
@@ -495,7 +575,7 @@ class LeannChat:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
index_path: str,
|
index_path: str,
|
||||||
llm_config: Optional[Dict[str, Any]] = None,
|
llm_config: dict[str, Any] | None = None,
|
||||||
enable_warmup: bool = False,
|
enable_warmup: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
@@ -511,13 +591,13 @@ class LeannChat:
|
|||||||
prune_ratio: float = 0.0,
|
prune_ratio: float = 0.0,
|
||||||
recompute_embeddings: bool = True,
|
recompute_embeddings: bool = True,
|
||||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
llm_kwargs: Optional[Dict[str, Any]] = None,
|
llm_kwargs: dict[str, Any] | None = None,
|
||||||
expected_zmq_port: int = 5557,
|
expected_zmq_port: int = 5557,
|
||||||
**search_kwargs,
|
**search_kwargs,
|
||||||
):
|
):
|
||||||
if llm_kwargs is None:
|
if llm_kwargs is None:
|
||||||
llm_kwargs = {}
|
llm_kwargs = {}
|
||||||
|
search_time = time.time()
|
||||||
results = self.searcher.search(
|
results = self.searcher.search(
|
||||||
question,
|
question,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
@@ -529,6 +609,8 @@ class LeannChat:
|
|||||||
expected_zmq_port=expected_zmq_port,
|
expected_zmq_port=expected_zmq_port,
|
||||||
**search_kwargs,
|
**search_kwargs,
|
||||||
)
|
)
|
||||||
|
search_time = time.time() - search_time
|
||||||
|
# logger.info(f" Search time: {search_time} seconds")
|
||||||
context = "\n\n".join([r.text for r in results])
|
context = "\n\n".join([r.text for r in results])
|
||||||
prompt = (
|
prompt = (
|
||||||
"Here is some retrieved context that might help answer your question:\n\n"
|
"Here is some retrieved context that might help answer your question:\n\n"
|
||||||
|
|||||||
@@ -4,21 +4,24 @@ This file contains the chat generation logic for the LEANN project,
|
|||||||
supporting different backends like Ollama, Hugging Face Transformers, and a simulation mode.
|
supporting different backends like Ollama, Hugging Face Transformers, and a simulation mode.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
import difflib
|
||||||
from typing import Dict, Any, Optional, List
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import difflib
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def check_ollama_models() -> List[str]:
|
def check_ollama_models() -> list[str]:
|
||||||
"""Check available Ollama models and return a list"""
|
"""Check available Ollama models and return a list"""
|
||||||
try:
|
try:
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
response = requests.get("http://localhost:11434/api/tags", timeout=5)
|
response = requests.get("http://localhost:11434/api/tags", timeout=5)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
data = response.json()
|
data = response.json()
|
||||||
@@ -28,68 +31,135 @@ def check_ollama_models() -> List[str]:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
def search_ollama_models_fuzzy(query: str, available_models: List[str]) -> List[str]:
|
def check_ollama_model_exists_remotely(model_name: str) -> tuple[bool, list[str]]:
|
||||||
|
"""Check if a model exists in Ollama's remote library and return available tags
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(model_exists, available_tags): bool and list of matching tags
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import re
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
# Split model name and tag
|
||||||
|
if ":" in model_name:
|
||||||
|
base_model, requested_tag = model_name.split(":", 1)
|
||||||
|
else:
|
||||||
|
base_model, requested_tag = model_name, None
|
||||||
|
|
||||||
|
# First check if base model exists in library
|
||||||
|
library_response = requests.get("https://ollama.com/library", timeout=8)
|
||||||
|
if library_response.status_code != 200:
|
||||||
|
return True, [] # Assume exists if can't check
|
||||||
|
|
||||||
|
# Extract model names from library page
|
||||||
|
models_in_library = re.findall(r'href="/library/([^"]+)"', library_response.text)
|
||||||
|
|
||||||
|
if base_model not in models_in_library:
|
||||||
|
return False, [] # Base model doesn't exist
|
||||||
|
|
||||||
|
# If base model exists, get available tags
|
||||||
|
tags_response = requests.get(f"https://ollama.com/library/{base_model}/tags", timeout=8)
|
||||||
|
if tags_response.status_code != 200:
|
||||||
|
return True, [] # Base model exists but can't get tags
|
||||||
|
|
||||||
|
# Extract tags for this model - be more specific to avoid HTML artifacts
|
||||||
|
tag_pattern = rf"{re.escape(base_model)}:[a-zA-Z0-9\.\-_]+"
|
||||||
|
raw_tags = re.findall(tag_pattern, tags_response.text)
|
||||||
|
|
||||||
|
# Clean up tags - remove HTML artifacts and duplicates
|
||||||
|
available_tags = []
|
||||||
|
seen = set()
|
||||||
|
for tag in raw_tags:
|
||||||
|
# Skip if it looks like HTML (contains < or >)
|
||||||
|
if "<" in tag or ">" in tag:
|
||||||
|
continue
|
||||||
|
if tag not in seen:
|
||||||
|
seen.add(tag)
|
||||||
|
available_tags.append(tag)
|
||||||
|
|
||||||
|
# Check if exact model exists
|
||||||
|
if requested_tag is None:
|
||||||
|
# User just requested base model, suggest tags
|
||||||
|
return True, available_tags[:10] # Return up to 10 tags
|
||||||
|
else:
|
||||||
|
exact_match = model_name in available_tags
|
||||||
|
return exact_match, available_tags[:10]
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# If scraping fails, assume model might exist (don't block user)
|
||||||
|
return True, []
|
||||||
|
|
||||||
|
|
||||||
|
def search_ollama_models_fuzzy(query: str, available_models: list[str]) -> list[str]:
|
||||||
"""Use intelligent fuzzy search for Ollama models"""
|
"""Use intelligent fuzzy search for Ollama models"""
|
||||||
if not available_models:
|
if not available_models:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
query_lower = query.lower()
|
query_lower = query.lower()
|
||||||
suggestions = []
|
suggestions = []
|
||||||
|
|
||||||
# 1. Exact matches first
|
# 1. Exact matches first
|
||||||
exact_matches = [m for m in available_models if query_lower == m.lower()]
|
exact_matches = [m for m in available_models if query_lower == m.lower()]
|
||||||
suggestions.extend(exact_matches)
|
suggestions.extend(exact_matches)
|
||||||
|
|
||||||
# 2. Starts with query
|
# 2. Starts with query
|
||||||
starts_with = [m for m in available_models if m.lower().startswith(query_lower) and m not in suggestions]
|
starts_with = [
|
||||||
|
m for m in available_models if m.lower().startswith(query_lower) and m not in suggestions
|
||||||
|
]
|
||||||
suggestions.extend(starts_with)
|
suggestions.extend(starts_with)
|
||||||
|
|
||||||
# 3. Contains query
|
# 3. Contains query
|
||||||
contains = [m for m in available_models if query_lower in m.lower() and m not in suggestions]
|
contains = [m for m in available_models if query_lower in m.lower() and m not in suggestions]
|
||||||
suggestions.extend(contains)
|
suggestions.extend(contains)
|
||||||
|
|
||||||
# 4. Base model name matching (remove version numbers)
|
# 4. Base model name matching (remove version numbers)
|
||||||
def get_base_name(model_name: str) -> str:
|
def get_base_name(model_name: str) -> str:
|
||||||
"""Extract base name without version (e.g., 'llama3:8b' -> 'llama3')"""
|
"""Extract base name without version (e.g., 'llama3:8b' -> 'llama3')"""
|
||||||
return model_name.split(':')[0].split('-')[0]
|
return model_name.split(":")[0].split("-")[0]
|
||||||
|
|
||||||
query_base = get_base_name(query_lower)
|
query_base = get_base_name(query_lower)
|
||||||
base_matches = [
|
base_matches = [
|
||||||
m for m in available_models
|
m
|
||||||
|
for m in available_models
|
||||||
if get_base_name(m.lower()) == query_base and m not in suggestions
|
if get_base_name(m.lower()) == query_base and m not in suggestions
|
||||||
]
|
]
|
||||||
suggestions.extend(base_matches)
|
suggestions.extend(base_matches)
|
||||||
|
|
||||||
# 5. Family/variant matching
|
# 5. Family/variant matching
|
||||||
model_families = {
|
model_families = {
|
||||||
'llama': ['llama2', 'llama3', 'alpaca', 'vicuna', 'codellama'],
|
"llama": ["llama2", "llama3", "alpaca", "vicuna", "codellama"],
|
||||||
'qwen': ['qwen', 'qwen2', 'qwen3'],
|
"qwen": ["qwen", "qwen2", "qwen3"],
|
||||||
'gemma': ['gemma', 'gemma2'],
|
"gemma": ["gemma", "gemma2"],
|
||||||
'phi': ['phi', 'phi2', 'phi3'],
|
"phi": ["phi", "phi2", "phi3"],
|
||||||
'mistral': ['mistral', 'mixtral', 'openhermes'],
|
"mistral": ["mistral", "mixtral", "openhermes"],
|
||||||
'dolphin': ['dolphin', 'openchat'],
|
"dolphin": ["dolphin", "openchat"],
|
||||||
'deepseek': ['deepseek', 'deepseek-coder']
|
"deepseek": ["deepseek", "deepseek-coder"],
|
||||||
}
|
}
|
||||||
|
|
||||||
query_family = None
|
query_family = None
|
||||||
for family, variants in model_families.items():
|
for family, variants in model_families.items():
|
||||||
if any(variant in query_lower for variant in variants):
|
if any(variant in query_lower for variant in variants):
|
||||||
query_family = family
|
query_family = family
|
||||||
break
|
break
|
||||||
|
|
||||||
if query_family:
|
if query_family:
|
||||||
family_variants = model_families[query_family]
|
family_variants = model_families[query_family]
|
||||||
family_matches = [
|
family_matches = [
|
||||||
m for m in available_models
|
m
|
||||||
|
for m in available_models
|
||||||
if any(variant in m.lower() for variant in family_variants) and m not in suggestions
|
if any(variant in m.lower() for variant in family_variants) and m not in suggestions
|
||||||
]
|
]
|
||||||
suggestions.extend(family_matches)
|
suggestions.extend(family_matches)
|
||||||
|
|
||||||
# 6. Use difflib for remaining fuzzy matches
|
# 6. Use difflib for remaining fuzzy matches
|
||||||
remaining_models = [m for m in available_models if m not in suggestions]
|
remaining_models = [m for m in available_models if m not in suggestions]
|
||||||
difflib_matches = difflib.get_close_matches(query_lower, remaining_models, n=3, cutoff=0.4)
|
difflib_matches = difflib.get_close_matches(query_lower, remaining_models, n=3, cutoff=0.4)
|
||||||
suggestions.extend(difflib_matches)
|
suggestions.extend(difflib_matches)
|
||||||
|
|
||||||
return suggestions[:8] # Return top 8 suggestions
|
return suggestions[:8] # Return top 8 suggestions
|
||||||
|
|
||||||
|
|
||||||
@@ -99,15 +169,13 @@ def search_ollama_models_fuzzy(query: str, available_models: List[str]) -> List[
|
|||||||
# Remove this too - no need for fallback
|
# Remove this too - no need for fallback
|
||||||
|
|
||||||
|
|
||||||
def suggest_similar_models(invalid_model: str, available_models: List[str]) -> List[str]:
|
def suggest_similar_models(invalid_model: str, available_models: list[str]) -> list[str]:
|
||||||
"""Use difflib to find similar model names"""
|
"""Use difflib to find similar model names"""
|
||||||
if not available_models:
|
if not available_models:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# Get close matches using fuzzy matching
|
# Get close matches using fuzzy matching
|
||||||
suggestions = difflib.get_close_matches(
|
suggestions = difflib.get_close_matches(invalid_model, available_models, n=3, cutoff=0.3)
|
||||||
invalid_model, available_models, n=3, cutoff=0.3
|
|
||||||
)
|
|
||||||
return suggestions
|
return suggestions
|
||||||
|
|
||||||
|
|
||||||
@@ -115,49 +183,50 @@ def check_hf_model_exists(model_name: str) -> bool:
|
|||||||
"""Quick check if HuggingFace model exists without downloading"""
|
"""Quick check if HuggingFace model exists without downloading"""
|
||||||
try:
|
try:
|
||||||
from huggingface_hub import model_info
|
from huggingface_hub import model_info
|
||||||
|
|
||||||
model_info(model_name)
|
model_info(model_name)
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def get_popular_hf_models() -> List[str]:
|
def get_popular_hf_models() -> list[str]:
|
||||||
"""Return a list of popular HuggingFace models for suggestions"""
|
"""Return a list of popular HuggingFace models for suggestions"""
|
||||||
try:
|
try:
|
||||||
from huggingface_hub import list_models
|
from huggingface_hub import list_models
|
||||||
|
|
||||||
# Get popular text-generation models, sorted by downloads
|
# Get popular text-generation models, sorted by downloads
|
||||||
models = list_models(
|
models = list_models(
|
||||||
filter="text-generation",
|
filter="text-generation",
|
||||||
sort="downloads",
|
sort="downloads",
|
||||||
direction=-1,
|
direction=-1,
|
||||||
limit=20 # Get top 20 most downloaded
|
limit=20, # Get top 20 most downloaded
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract model names and filter for chat/conversation models
|
# Extract model names and filter for chat/conversation models
|
||||||
model_names = []
|
model_names = []
|
||||||
chat_keywords = ['chat', 'instruct', 'dialog', 'conversation', 'assistant']
|
chat_keywords = ["chat", "instruct", "dialog", "conversation", "assistant"]
|
||||||
|
|
||||||
for model in models:
|
for model in models:
|
||||||
model_name = model.id if hasattr(model, 'id') else str(model)
|
model_name = model.id if hasattr(model, "id") else str(model)
|
||||||
# Prioritize models with chat-related keywords
|
# Prioritize models with chat-related keywords
|
||||||
if any(keyword in model_name.lower() for keyword in chat_keywords):
|
if any(keyword in model_name.lower() for keyword in chat_keywords):
|
||||||
model_names.append(model_name)
|
model_names.append(model_name)
|
||||||
elif len(model_names) < 10: # Fill up with other popular models
|
elif len(model_names) < 10: # Fill up with other popular models
|
||||||
model_names.append(model_name)
|
model_names.append(model_name)
|
||||||
|
|
||||||
return model_names[:10] if model_names else _get_fallback_hf_models()
|
return model_names[:10] if model_names else _get_fallback_hf_models()
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
# Fallback to static list if API call fails
|
# Fallback to static list if API call fails
|
||||||
return _get_fallback_hf_models()
|
return _get_fallback_hf_models()
|
||||||
|
|
||||||
|
|
||||||
def _get_fallback_hf_models() -> List[str]:
|
def _get_fallback_hf_models() -> list[str]:
|
||||||
"""Fallback list of popular HuggingFace models"""
|
"""Fallback list of popular HuggingFace models"""
|
||||||
return [
|
return [
|
||||||
"microsoft/DialoGPT-medium",
|
"microsoft/DialoGPT-medium",
|
||||||
"microsoft/DialoGPT-large",
|
"microsoft/DialoGPT-large",
|
||||||
"facebook/blenderbot-400M-distill",
|
"facebook/blenderbot-400M-distill",
|
||||||
"microsoft/phi-2",
|
"microsoft/phi-2",
|
||||||
"deepseek-ai/deepseek-llm-7b-chat",
|
"deepseek-ai/deepseek-llm-7b-chat",
|
||||||
@@ -165,44 +234,44 @@ def _get_fallback_hf_models() -> List[str]:
|
|||||||
"facebook/blenderbot_small-90M",
|
"facebook/blenderbot_small-90M",
|
||||||
"microsoft/phi-1_5",
|
"microsoft/phi-1_5",
|
||||||
"facebook/opt-350m",
|
"facebook/opt-350m",
|
||||||
"EleutherAI/gpt-neo-1.3B"
|
"EleutherAI/gpt-neo-1.3B",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def search_hf_models_fuzzy(query: str, limit: int = 10) -> List[str]:
|
def search_hf_models_fuzzy(query: str, limit: int = 10) -> list[str]:
|
||||||
"""Use HuggingFace Hub's native fuzzy search for model suggestions"""
|
"""Use HuggingFace Hub's native fuzzy search for model suggestions"""
|
||||||
try:
|
try:
|
||||||
from huggingface_hub import list_models
|
from huggingface_hub import list_models
|
||||||
|
|
||||||
# HF Hub's search is already fuzzy! It handles typos and partial matches
|
# HF Hub's search is already fuzzy! It handles typos and partial matches
|
||||||
models = list_models(
|
models = list_models(
|
||||||
search=query,
|
search=query,
|
||||||
filter="text-generation",
|
filter="text-generation",
|
||||||
sort="downloads",
|
sort="downloads",
|
||||||
direction=-1,
|
direction=-1,
|
||||||
limit=limit
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
model_names = [model.id if hasattr(model, 'id') else str(model) for model in models]
|
model_names = [model.id if hasattr(model, "id") else str(model) for model in models]
|
||||||
|
|
||||||
# If direct search doesn't return enough results, try some variations
|
# If direct search doesn't return enough results, try some variations
|
||||||
if len(model_names) < 3:
|
if len(model_names) < 3:
|
||||||
# Try searching for partial matches or common variations
|
# Try searching for partial matches or common variations
|
||||||
variations = []
|
variations = []
|
||||||
|
|
||||||
# Extract base name (e.g., "gpt3" from "gpt-3.5")
|
# Extract base name (e.g., "gpt3" from "gpt-3.5")
|
||||||
base_query = query.lower().replace('-', '').replace('.', '').replace('_', '')
|
base_query = query.lower().replace("-", "").replace(".", "").replace("_", "")
|
||||||
if base_query != query.lower():
|
if base_query != query.lower():
|
||||||
variations.append(base_query)
|
variations.append(base_query)
|
||||||
|
|
||||||
# Try common model name patterns
|
# Try common model name patterns
|
||||||
if 'gpt' in query.lower():
|
if "gpt" in query.lower():
|
||||||
variations.extend(['gpt2', 'gpt-neo', 'gpt-j', 'dialoGPT'])
|
variations.extend(["gpt2", "gpt-neo", "gpt-j", "dialoGPT"])
|
||||||
elif 'llama' in query.lower():
|
elif "llama" in query.lower():
|
||||||
variations.extend(['llama2', 'alpaca', 'vicuna'])
|
variations.extend(["llama2", "alpaca", "vicuna"])
|
||||||
elif 'bert' in query.lower():
|
elif "bert" in query.lower():
|
||||||
variations.extend(['roberta', 'distilbert', 'albert'])
|
variations.extend(["roberta", "distilbert", "albert"])
|
||||||
|
|
||||||
# Search with variations
|
# Search with variations
|
||||||
for var in variations[:2]: # Limit to 2 variations to avoid too many API calls
|
for var in variations[:2]: # Limit to 2 variations to avoid too many API calls
|
||||||
try:
|
try:
|
||||||
@@ -211,13 +280,15 @@ def search_hf_models_fuzzy(query: str, limit: int = 10) -> List[str]:
|
|||||||
filter="text-generation",
|
filter="text-generation",
|
||||||
sort="downloads",
|
sort="downloads",
|
||||||
direction=-1,
|
direction=-1,
|
||||||
limit=3
|
limit=3,
|
||||||
)
|
)
|
||||||
var_names = [model.id if hasattr(model, 'id') else str(model) for model in var_models]
|
var_names = [
|
||||||
|
model.id if hasattr(model, "id") else str(model) for model in var_models
|
||||||
|
]
|
||||||
model_names.extend(var_names)
|
model_names.extend(var_names)
|
||||||
except:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Remove duplicates while preserving order
|
# Remove duplicates while preserving order
|
||||||
seen = set()
|
seen = set()
|
||||||
unique_models = []
|
unique_models = []
|
||||||
@@ -225,50 +296,96 @@ def search_hf_models_fuzzy(query: str, limit: int = 10) -> List[str]:
|
|||||||
if model not in seen:
|
if model not in seen:
|
||||||
seen.add(model)
|
seen.add(model)
|
||||||
unique_models.append(model)
|
unique_models.append(model)
|
||||||
|
|
||||||
return unique_models[:limit]
|
return unique_models[:limit]
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
# If search fails, return empty list
|
# If search fails, return empty list
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
def search_hf_models(query: str, limit: int = 10) -> List[str]:
|
def search_hf_models(query: str, limit: int = 10) -> list[str]:
|
||||||
"""Simple search for HuggingFace models based on query (kept for backward compatibility)"""
|
"""Simple search for HuggingFace models based on query (kept for backward compatibility)"""
|
||||||
return search_hf_models_fuzzy(query, limit)
|
return search_hf_models_fuzzy(query, limit)
|
||||||
|
|
||||||
|
|
||||||
def validate_model_and_suggest(model_name: str, llm_type: str) -> Optional[str]:
|
def validate_model_and_suggest(model_name: str, llm_type: str) -> str | None:
|
||||||
"""Validate model name and provide suggestions if invalid"""
|
"""Validate model name and provide suggestions if invalid"""
|
||||||
if llm_type == "ollama":
|
if llm_type == "ollama":
|
||||||
available_models = check_ollama_models()
|
available_models = check_ollama_models()
|
||||||
if available_models and model_name not in available_models:
|
if available_models and model_name not in available_models:
|
||||||
# Use intelligent fuzzy search based on locally installed models
|
|
||||||
suggestions = search_ollama_models_fuzzy(model_name, available_models)
|
|
||||||
|
|
||||||
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
|
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
|
||||||
if suggestions:
|
|
||||||
error_msg += "\n\nDid you mean one of these installed models?\n"
|
# Check if the model exists remotely and get available tags
|
||||||
for i, suggestion in enumerate(suggestions, 1):
|
model_exists_remotely, available_tags = check_ollama_model_exists_remotely(model_name)
|
||||||
error_msg += f" {i}. {suggestion}\n"
|
|
||||||
|
if model_exists_remotely and model_name in available_tags:
|
||||||
|
# Exact model exists remotely - suggest pulling it
|
||||||
|
error_msg += "\n\nTo install the requested model:\n"
|
||||||
|
error_msg += f" ollama pull {model_name}\n"
|
||||||
|
|
||||||
|
# Show local alternatives
|
||||||
|
suggestions = search_ollama_models_fuzzy(model_name, available_models)
|
||||||
|
if suggestions:
|
||||||
|
error_msg += "\nOr use one of these similar installed models:\n"
|
||||||
|
for i, suggestion in enumerate(suggestions, 1):
|
||||||
|
error_msg += f" {i}. {suggestion}\n"
|
||||||
|
|
||||||
|
elif model_exists_remotely and available_tags:
|
||||||
|
# Base model exists but requested tag doesn't - suggest correct tags
|
||||||
|
base_model = model_name.split(":")[0]
|
||||||
|
requested_tag = model_name.split(":", 1)[1] if ":" in model_name else None
|
||||||
|
|
||||||
|
error_msg += (
|
||||||
|
f"\n\nModel '{base_model}' exists, but tag '{requested_tag}' is not available."
|
||||||
|
)
|
||||||
|
error_msg += f"\n\nAvailable {base_model} models you can install:\n"
|
||||||
|
for i, tag in enumerate(available_tags[:8], 1):
|
||||||
|
error_msg += f" {i}. ollama pull {tag}\n"
|
||||||
|
if len(available_tags) > 8:
|
||||||
|
error_msg += f" ... and {len(available_tags) - 8} more variants\n"
|
||||||
|
|
||||||
|
# Also show local alternatives
|
||||||
|
suggestions = search_ollama_models_fuzzy(model_name, available_models)
|
||||||
|
if suggestions:
|
||||||
|
error_msg += "\nOr use one of these similar installed models:\n"
|
||||||
|
for i, suggestion in enumerate(suggestions, 1):
|
||||||
|
error_msg += f" {i}. {suggestion}\n"
|
||||||
|
|
||||||
else:
|
else:
|
||||||
error_msg += "\n\nYour installed models:\n"
|
# Model doesn't exist remotely - show fuzzy suggestions
|
||||||
for i, model in enumerate(available_models[:8], 1):
|
suggestions = search_ollama_models_fuzzy(model_name, available_models)
|
||||||
error_msg += f" {i}. {model}\n"
|
error_msg += f"\n\nModel '{model_name}' was not found in Ollama's library."
|
||||||
if len(available_models) > 8:
|
|
||||||
error_msg += f" ... and {len(available_models) - 8} more\n"
|
if suggestions:
|
||||||
|
error_msg += "\n\nDid you mean one of these installed models?\n"
|
||||||
error_msg += "\nTo list all models: ollama list"
|
for i, suggestion in enumerate(suggestions, 1):
|
||||||
error_msg += "\nTo download a new model: ollama pull <model_name>"
|
error_msg += f" {i}. {suggestion}\n"
|
||||||
error_msg += "\nBrowse models: https://ollama.com/library"
|
else:
|
||||||
|
error_msg += "\n\nYour installed models:\n"
|
||||||
|
for i, model in enumerate(available_models[:8], 1):
|
||||||
|
error_msg += f" {i}. {model}\n"
|
||||||
|
if len(available_models) > 8:
|
||||||
|
error_msg += f" ... and {len(available_models) - 8} more\n"
|
||||||
|
|
||||||
|
error_msg += "\n\nCommands:"
|
||||||
|
error_msg += "\n ollama list # List installed models"
|
||||||
|
if model_exists_remotely and available_tags:
|
||||||
|
if model_name in available_tags:
|
||||||
|
error_msg += f"\n ollama pull {model_name} # Install requested model"
|
||||||
|
else:
|
||||||
|
error_msg += (
|
||||||
|
f"\n ollama pull {available_tags[0]} # Install recommended variant"
|
||||||
|
)
|
||||||
|
error_msg += "\n https://ollama.com/library # Browse available models"
|
||||||
return error_msg
|
return error_msg
|
||||||
|
|
||||||
elif llm_type == "hf":
|
elif llm_type == "hf":
|
||||||
# For HF models, we can do a quick existence check
|
# For HF models, we can do a quick existence check
|
||||||
if not check_hf_model_exists(model_name):
|
if not check_hf_model_exists(model_name):
|
||||||
# Use HF Hub's native fuzzy search directly
|
# Use HF Hub's native fuzzy search directly
|
||||||
search_suggestions = search_hf_models_fuzzy(model_name, limit=8)
|
search_suggestions = search_hf_models_fuzzy(model_name, limit=8)
|
||||||
|
|
||||||
error_msg = f"Model '{model_name}' not found on HuggingFace Hub."
|
error_msg = f"Model '{model_name}' not found on HuggingFace Hub."
|
||||||
if search_suggestions:
|
if search_suggestions:
|
||||||
error_msg += "\n\nDid you mean one of these?\n"
|
error_msg += "\n\nDid you mean one of these?\n"
|
||||||
@@ -280,10 +397,10 @@ def validate_model_and_suggest(model_name: str, llm_type: str) -> Optional[str]:
|
|||||||
error_msg += "\n\nPopular chat models:\n"
|
error_msg += "\n\nPopular chat models:\n"
|
||||||
for i, model in enumerate(popular_models[:5], 1):
|
for i, model in enumerate(popular_models[:5], 1):
|
||||||
error_msg += f" {i}. {model}\n"
|
error_msg += f" {i}. {model}\n"
|
||||||
|
|
||||||
error_msg += f"\nSearch more: https://huggingface.co/models?search={model_name}&pipeline_tag=text-generation"
|
error_msg += f"\nSearch more: https://huggingface.co/models?search={model_name}&pipeline_tag=text-generation"
|
||||||
return error_msg
|
return error_msg
|
||||||
|
|
||||||
return None # Model is valid or we can't check
|
return None # Model is valid or we can't check
|
||||||
|
|
||||||
|
|
||||||
@@ -346,28 +463,27 @@ class OllamaChat(LLMInterface):
|
|||||||
# Check if the Ollama server is responsive
|
# Check if the Ollama server is responsive
|
||||||
if host:
|
if host:
|
||||||
requests.get(host)
|
requests.get(host)
|
||||||
|
|
||||||
# Pre-check model availability with helpful suggestions
|
# Pre-check model availability with helpful suggestions
|
||||||
model_error = validate_model_and_suggest(model, "ollama")
|
model_error = validate_model_and_suggest(model, "ollama")
|
||||||
if model_error:
|
if model_error:
|
||||||
raise ValueError(model_error)
|
raise ValueError(model_error)
|
||||||
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"The 'requests' library is required for Ollama. Please install it with 'pip install requests'."
|
"The 'requests' library is required for Ollama. Please install it with 'pip install requests'."
|
||||||
)
|
)
|
||||||
except requests.exceptions.ConnectionError:
|
except requests.exceptions.ConnectionError:
|
||||||
logger.error(
|
logger.error(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.")
|
||||||
f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
|
|
||||||
)
|
|
||||||
raise ConnectionError(
|
raise ConnectionError(
|
||||||
f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
|
f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
|
||||||
)
|
)
|
||||||
|
|
||||||
def ask(self, prompt: str, **kwargs) -> str:
|
def ask(self, prompt: str, **kwargs) -> str:
|
||||||
import requests
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
full_url = f"{self.host}/api/generate"
|
full_url = f"{self.host}/api/generate"
|
||||||
payload = {
|
payload = {
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
@@ -377,7 +493,7 @@ class OllamaChat(LLMInterface):
|
|||||||
}
|
}
|
||||||
logger.debug(f"Sending request to Ollama: {payload}")
|
logger.debug(f"Sending request to Ollama: {payload}")
|
||||||
try:
|
try:
|
||||||
logger.info(f"Sending request to Ollama and waiting for response...")
|
logger.info("Sending request to Ollama and waiting for response...")
|
||||||
response = requests.post(full_url, data=json.dumps(payload))
|
response = requests.post(full_url, data=json.dumps(payload))
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
@@ -397,19 +513,19 @@ class OllamaChat(LLMInterface):
|
|||||||
|
|
||||||
|
|
||||||
class HFChat(LLMInterface):
|
class HFChat(LLMInterface):
|
||||||
"""LLM interface for local Hugging Face Transformers models."""
|
"""LLM interface for local Hugging Face Transformers models with proper chat templates."""
|
||||||
|
|
||||||
def __init__(self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat"):
|
def __init__(self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat"):
|
||||||
logger.info(f"Initializing HFChat with model='{model_name}'")
|
logger.info(f"Initializing HFChat with model='{model_name}'")
|
||||||
|
|
||||||
# Pre-check model availability with helpful suggestions
|
# Pre-check model availability with helpful suggestions
|
||||||
model_error = validate_model_and_suggest(model_name, "hf")
|
model_error = validate_model_and_suggest(model_name, "hf")
|
||||||
if model_error:
|
if model_error:
|
||||||
raise ValueError(model_error)
|
raise ValueError(model_error)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from transformers.pipelines import pipeline
|
|
||||||
import torch
|
import torch
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"The 'transformers' and 'torch' libraries are required for Hugging Face models. Please install them with 'pip install transformers torch'."
|
"The 'transformers' and 'torch' libraries are required for Hugging Face models. Please install them with 'pip install transformers torch'."
|
||||||
@@ -417,60 +533,102 @@ class HFChat(LLMInterface):
|
|||||||
|
|
||||||
# Auto-detect device
|
# Auto-detect device
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = "cuda"
|
self.device = "cuda"
|
||||||
logger.info("CUDA is available. Using GPU.")
|
logger.info("CUDA is available. Using GPU.")
|
||||||
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||||
device = "mps"
|
self.device = "mps"
|
||||||
logger.info("MPS is available. Using Apple Silicon GPU.")
|
logger.info("MPS is available. Using Apple Silicon GPU.")
|
||||||
else:
|
else:
|
||||||
device = "cpu"
|
self.device = "cpu"
|
||||||
logger.info("No GPU detected. Using CPU.")
|
logger.info("No GPU detected. Using CPU.")
|
||||||
|
|
||||||
self.pipeline = pipeline("text-generation", model=model_name, device=device)
|
# Load tokenizer and model
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
self.model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
|
||||||
|
device_map="auto" if self.device != "cpu" else None,
|
||||||
|
trust_remote_code=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Move model to device if not using device_map
|
||||||
|
if self.device != "cpu" and "device_map" not in str(self.model):
|
||||||
|
self.model = self.model.to(self.device)
|
||||||
|
|
||||||
|
# Set pad token if not present
|
||||||
|
if self.tokenizer.pad_token is None:
|
||||||
|
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||||
|
|
||||||
def ask(self, prompt: str, **kwargs) -> str:
|
def ask(self, prompt: str, **kwargs) -> str:
|
||||||
# Map OpenAI-style arguments to Hugging Face equivalents
|
print("kwargs in HF: ", kwargs)
|
||||||
if "max_tokens" in kwargs:
|
# Check if this is a Qwen model and add /no_think by default
|
||||||
# Prefer user-provided max_new_tokens if both are present
|
is_qwen_model = "qwen" in self.model.config._name_or_path.lower()
|
||||||
kwargs.setdefault("max_new_tokens", kwargs["max_tokens"])
|
|
||||||
# Remove the unsupported key to avoid errors in Transformers
|
|
||||||
kwargs.pop("max_tokens")
|
|
||||||
|
|
||||||
# Handle temperature=0 edge-case for greedy decoding
|
# For Qwen models, automatically add /no_think to the prompt
|
||||||
if "temperature" in kwargs and kwargs["temperature"] == 0.0:
|
if is_qwen_model and "/no_think" not in prompt and "/think" not in prompt:
|
||||||
# Remove unsupported zero temperature and use deterministic generation
|
prompt = prompt + " /no_think"
|
||||||
kwargs.pop("temperature")
|
|
||||||
kwargs.setdefault("do_sample", False)
|
|
||||||
|
|
||||||
# Sensible defaults for text generation
|
# Prepare chat template
|
||||||
params = {"max_length": 500, "num_return_sequences": 1, **kwargs}
|
messages = [{"role": "user", "content": prompt}]
|
||||||
logger.info(f"Generating text with Hugging Face model with params: {params}")
|
|
||||||
results = self.pipeline(prompt, **params)
|
|
||||||
|
|
||||||
# Handle different response formats from transformers
|
# Apply chat template if available
|
||||||
if isinstance(results, list) and len(results) > 0:
|
if hasattr(self.tokenizer, "apply_chat_template"):
|
||||||
generated_text = (
|
try:
|
||||||
results[0].get("generated_text", "")
|
formatted_prompt = self.tokenizer.apply_chat_template(
|
||||||
if isinstance(results[0], dict)
|
messages, tokenize=False, add_generation_prompt=True
|
||||||
else str(results[0])
|
)
|
||||||
)
|
except Exception as e:
|
||||||
|
logger.warning(f"Chat template failed, using raw prompt: {e}")
|
||||||
|
formatted_prompt = prompt
|
||||||
else:
|
else:
|
||||||
generated_text = str(results)
|
# Fallback for models without chat template
|
||||||
|
formatted_prompt = prompt
|
||||||
|
|
||||||
# Extract only the newly generated portion by removing the original prompt
|
# Tokenize input
|
||||||
if isinstance(generated_text, str) and generated_text.startswith(prompt):
|
inputs = self.tokenizer(
|
||||||
response = generated_text[len(prompt) :].strip()
|
formatted_prompt,
|
||||||
else:
|
return_tensors="pt",
|
||||||
# Fallback: return the full response if prompt removal fails
|
padding=True,
|
||||||
response = str(generated_text)
|
truncation=True,
|
||||||
|
max_length=2048,
|
||||||
|
)
|
||||||
|
|
||||||
return response
|
# Move inputs to device
|
||||||
|
if self.device != "cpu":
|
||||||
|
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||||
|
|
||||||
|
# Set generation parameters
|
||||||
|
generation_config = {
|
||||||
|
"max_new_tokens": kwargs.get("max_tokens", kwargs.get("max_new_tokens", 512)),
|
||||||
|
"temperature": kwargs.get("temperature", 0.7),
|
||||||
|
"top_p": kwargs.get("top_p", 0.9),
|
||||||
|
"do_sample": kwargs.get("temperature", 0.7) > 0,
|
||||||
|
"pad_token_id": self.tokenizer.eos_token_id,
|
||||||
|
"eos_token_id": self.tokenizer.eos_token_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Handle temperature=0 for greedy decoding
|
||||||
|
if generation_config["temperature"] == 0.0:
|
||||||
|
generation_config["do_sample"] = False
|
||||||
|
generation_config.pop("temperature")
|
||||||
|
|
||||||
|
logger.info(f"Generating with HuggingFace model, config: {generation_config}")
|
||||||
|
|
||||||
|
# Generate
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = self.model.generate(**inputs, **generation_config)
|
||||||
|
|
||||||
|
# Decode response
|
||||||
|
generated_tokens = outputs[0][inputs["input_ids"].shape[1] :]
|
||||||
|
response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
||||||
|
|
||||||
|
return response.strip()
|
||||||
|
|
||||||
|
|
||||||
class OpenAIChat(LLMInterface):
|
class OpenAIChat(LLMInterface):
|
||||||
"""LLM interface for OpenAI models."""
|
"""LLM interface for OpenAI models."""
|
||||||
|
|
||||||
def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None):
|
def __init__(self, model: str = "gpt-4o", api_key: str | None = None):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
||||||
|
|
||||||
@@ -497,11 +655,7 @@ class OpenAIChat(LLMInterface):
|
|||||||
"messages": [{"role": "user", "content": prompt}],
|
"messages": [{"role": "user", "content": prompt}],
|
||||||
"max_tokens": kwargs.get("max_tokens", 1000),
|
"max_tokens": kwargs.get("max_tokens", 1000),
|
||||||
"temperature": kwargs.get("temperature", 0.7),
|
"temperature": kwargs.get("temperature", 0.7),
|
||||||
**{
|
**{k: v for k, v in kwargs.items() if k not in ["max_tokens", "temperature"]},
|
||||||
k: v
|
|
||||||
for k, v in kwargs.items()
|
|
||||||
if k not in ["max_tokens", "temperature"]
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(f"Sending request to OpenAI with model {self.model}")
|
logger.info(f"Sending request to OpenAI with model {self.model}")
|
||||||
@@ -523,7 +677,7 @@ class SimulatedChat(LLMInterface):
|
|||||||
return "This is a simulated answer from the LLM based on the retrieved context."
|
return "This is a simulated answer from the LLM based on the retrieved context."
|
||||||
|
|
||||||
|
|
||||||
def get_llm(llm_config: Optional[Dict[str, Any]] = None) -> LLMInterface:
|
def get_llm(llm_config: dict[str, Any] | None = None) -> LLMInterface:
|
||||||
"""
|
"""
|
||||||
Factory function to get an LLM interface based on configuration.
|
Factory function to get an LLM interface based on configuration.
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,38 @@ from pathlib import Path
|
|||||||
from llama_index.core import SimpleDirectoryReader
|
from llama_index.core import SimpleDirectoryReader
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
from .api import LeannBuilder, LeannSearcher, LeannChat
|
from .api import LeannBuilder, LeannChat, LeannSearcher
|
||||||
|
|
||||||
|
|
||||||
|
def extract_pdf_text_with_pymupdf(file_path: str) -> str:
|
||||||
|
"""Extract text from PDF using PyMuPDF for better quality."""
|
||||||
|
try:
|
||||||
|
import fitz # PyMuPDF
|
||||||
|
|
||||||
|
doc = fitz.open(file_path)
|
||||||
|
text = ""
|
||||||
|
for page in doc:
|
||||||
|
text += page.get_text()
|
||||||
|
doc.close()
|
||||||
|
return text
|
||||||
|
except ImportError:
|
||||||
|
# Fallback to default reader
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def extract_pdf_text_with_pdfplumber(file_path: str) -> str:
|
||||||
|
"""Extract text from PDF using pdfplumber for better quality."""
|
||||||
|
try:
|
||||||
|
import pdfplumber
|
||||||
|
|
||||||
|
text = ""
|
||||||
|
with pdfplumber.open(file_path) as pdf:
|
||||||
|
for page in pdf.pages:
|
||||||
|
text += page.extract_text() or ""
|
||||||
|
return text
|
||||||
|
except ImportError:
|
||||||
|
# Fallback to default reader
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class LeannCLI:
|
class LeannCLI:
|
||||||
@@ -45,18 +76,12 @@ Examples:
|
|||||||
# Build command
|
# Build command
|
||||||
build_parser = subparsers.add_parser("build", help="Build document index")
|
build_parser = subparsers.add_parser("build", help="Build document index")
|
||||||
build_parser.add_argument("index_name", help="Index name")
|
build_parser.add_argument("index_name", help="Index name")
|
||||||
build_parser.add_argument(
|
build_parser.add_argument("--docs", type=str, required=True, help="Documents directory")
|
||||||
"--docs", type=str, required=True, help="Documents directory"
|
|
||||||
)
|
|
||||||
build_parser.add_argument(
|
build_parser.add_argument(
|
||||||
"--backend", type=str, default="hnsw", choices=["hnsw", "diskann"]
|
"--backend", type=str, default="hnsw", choices=["hnsw", "diskann"]
|
||||||
)
|
)
|
||||||
build_parser.add_argument(
|
build_parser.add_argument("--embedding-model", type=str, default="facebook/contriever")
|
||||||
"--embedding-model", type=str, default="facebook/contriever"
|
build_parser.add_argument("--force", "-f", action="store_true", help="Force rebuild")
|
||||||
)
|
|
||||||
build_parser.add_argument(
|
|
||||||
"--force", "-f", action="store_true", help="Force rebuild"
|
|
||||||
)
|
|
||||||
build_parser.add_argument("--graph-degree", type=int, default=32)
|
build_parser.add_argument("--graph-degree", type=int, default=32)
|
||||||
build_parser.add_argument("--complexity", type=int, default=64)
|
build_parser.add_argument("--complexity", type=int, default=64)
|
||||||
build_parser.add_argument("--num-threads", type=int, default=1)
|
build_parser.add_argument("--num-threads", type=int, default=1)
|
||||||
@@ -102,7 +127,7 @@ Examples:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# List command
|
# List command
|
||||||
list_parser = subparsers.add_parser("list", help="List all indexes")
|
subparsers.add_parser("list", help="List all indexes")
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
@@ -110,17 +135,13 @@ Examples:
|
|||||||
print("Stored LEANN indexes:")
|
print("Stored LEANN indexes:")
|
||||||
|
|
||||||
if not self.indexes_dir.exists():
|
if not self.indexes_dir.exists():
|
||||||
print(
|
print("No indexes found. Use 'leann build <name> --docs <dir>' to create one.")
|
||||||
"No indexes found. Use 'leann build <name> --docs <dir>' to create one."
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
index_dirs = [d for d in self.indexes_dir.iterdir() if d.is_dir()]
|
index_dirs = [d for d in self.indexes_dir.iterdir() if d.is_dir()]
|
||||||
|
|
||||||
if not index_dirs:
|
if not index_dirs:
|
||||||
print(
|
print("No indexes found. Use 'leann build <name> --docs <dir>' to create one.")
|
||||||
"No indexes found. Use 'leann build <name> --docs <dir>' to create one."
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
print(f"Found {len(index_dirs)} indexes:")
|
print(f"Found {len(index_dirs)} indexes:")
|
||||||
@@ -130,27 +151,58 @@ Examples:
|
|||||||
|
|
||||||
print(f" {i}. {index_name} [{status}]")
|
print(f" {i}. {index_name} [{status}]")
|
||||||
if self.index_exists(index_name):
|
if self.index_exists(index_name):
|
||||||
meta_file = index_dir / "documents.leann.meta.json"
|
index_dir / "documents.leann.meta.json"
|
||||||
size_mb = sum(
|
size_mb = sum(f.stat().st_size for f in index_dir.iterdir() if f.is_file()) / (
|
||||||
f.stat().st_size for f in index_dir.iterdir() if f.is_file()
|
1024 * 1024
|
||||||
) / (1024 * 1024)
|
)
|
||||||
print(f" Size: {size_mb:.1f} MB")
|
print(f" Size: {size_mb:.1f} MB")
|
||||||
|
|
||||||
if index_dirs:
|
if index_dirs:
|
||||||
example_name = index_dirs[0].name
|
example_name = index_dirs[0].name
|
||||||
print(f"\nUsage:")
|
print("\nUsage:")
|
||||||
print(f' leann search {example_name} "your query"')
|
print(f' leann search {example_name} "your query"')
|
||||||
print(f" leann ask {example_name} --interactive")
|
print(f" leann ask {example_name} --interactive")
|
||||||
|
|
||||||
def load_documents(self, docs_dir: str):
|
def load_documents(self, docs_dir: str):
|
||||||
print(f"Loading documents from {docs_dir}...")
|
print(f"Loading documents from {docs_dir}...")
|
||||||
|
|
||||||
documents = SimpleDirectoryReader(
|
# Try to use better PDF parsers first
|
||||||
|
documents = []
|
||||||
|
docs_path = Path(docs_dir)
|
||||||
|
|
||||||
|
for file_path in docs_path.rglob("*.pdf"):
|
||||||
|
print(f"Processing PDF: {file_path}")
|
||||||
|
|
||||||
|
# Try PyMuPDF first (best quality)
|
||||||
|
text = extract_pdf_text_with_pymupdf(str(file_path))
|
||||||
|
if text is None:
|
||||||
|
# Try pdfplumber
|
||||||
|
text = extract_pdf_text_with_pdfplumber(str(file_path))
|
||||||
|
|
||||||
|
if text:
|
||||||
|
# Create a simple document structure
|
||||||
|
from llama_index.core import Document
|
||||||
|
|
||||||
|
doc = Document(text=text, metadata={"source": str(file_path)})
|
||||||
|
documents.append(doc)
|
||||||
|
else:
|
||||||
|
# Fallback to default reader
|
||||||
|
print(f"Using default reader for {file_path}")
|
||||||
|
default_docs = SimpleDirectoryReader(
|
||||||
|
str(file_path.parent),
|
||||||
|
filename_as_id=True,
|
||||||
|
required_exts=[file_path.suffix],
|
||||||
|
).load_data()
|
||||||
|
documents.extend(default_docs)
|
||||||
|
|
||||||
|
# Load other file types with default reader
|
||||||
|
other_docs = SimpleDirectoryReader(
|
||||||
docs_dir,
|
docs_dir,
|
||||||
recursive=True,
|
recursive=True,
|
||||||
encoding="utf-8",
|
encoding="utf-8",
|
||||||
required_exts=[".pdf", ".txt", ".md", ".docx"],
|
required_exts=[".txt", ".md", ".docx"],
|
||||||
).load_data(show_progress=True)
|
).load_data(show_progress=True)
|
||||||
|
documents.extend(other_docs)
|
||||||
|
|
||||||
all_texts = []
|
all_texts = []
|
||||||
for doc in documents:
|
for doc in documents:
|
||||||
|
|||||||
@@ -4,11 +4,12 @@ Consolidates all embedding computation logic using SentenceTransformer
|
|||||||
Preserves all optimization parameters to ensure performance
|
Preserves all optimization parameters to ensure performance
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from typing import List, Dict, Any
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
# Set up logger with proper level
|
# Set up logger with proper level
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -17,11 +18,11 @@ log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
|
|||||||
logger.setLevel(log_level)
|
logger.setLevel(log_level)
|
||||||
|
|
||||||
# Global model cache to avoid repeated loading
|
# Global model cache to avoid repeated loading
|
||||||
_model_cache: Dict[str, Any] = {}
|
_model_cache: dict[str, Any] = {}
|
||||||
|
|
||||||
|
|
||||||
def compute_embeddings(
|
def compute_embeddings(
|
||||||
texts: List[str],
|
texts: list[str],
|
||||||
model_name: str,
|
model_name: str,
|
||||||
mode: str = "sentence-transformers",
|
mode: str = "sentence-transformers",
|
||||||
is_build: bool = False,
|
is_build: bool = False,
|
||||||
@@ -59,7 +60,7 @@ def compute_embeddings(
|
|||||||
|
|
||||||
|
|
||||||
def compute_embeddings_sentence_transformers(
|
def compute_embeddings_sentence_transformers(
|
||||||
texts: List[str],
|
texts: list[str],
|
||||||
model_name: str,
|
model_name: str,
|
||||||
use_fp16: bool = True,
|
use_fp16: bool = True,
|
||||||
device: str = "auto",
|
device: str = "auto",
|
||||||
@@ -101,7 +102,7 @@ def compute_embeddings_sentence_transformers(
|
|||||||
if device == "mps":
|
if device == "mps":
|
||||||
batch_size = 128 # MPS optimal batch size from benchmark
|
batch_size = 128 # MPS optimal batch size from benchmark
|
||||||
if model_name == "Qwen/Qwen3-Embedding-0.6B":
|
if model_name == "Qwen/Qwen3-Embedding-0.6B":
|
||||||
batch_size = 64
|
batch_size = 32
|
||||||
elif device == "cuda":
|
elif device == "cuda":
|
||||||
batch_size = 256 # CUDA optimal batch size
|
batch_size = 256 # CUDA optimal batch size
|
||||||
# Keep original batch_size for CPU
|
# Keep original batch_size for CPU
|
||||||
@@ -114,9 +115,7 @@ def compute_embeddings_sentence_transformers(
|
|||||||
logger.info(f"Using cached optimized model: {model_name}")
|
logger.info(f"Using cached optimized model: {model_name}")
|
||||||
model = _model_cache[cache_key]
|
model = _model_cache[cache_key]
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(f"Loading and caching optimized SentenceTransformer model: {model_name}")
|
||||||
f"Loading and caching optimized SentenceTransformer model: {model_name}"
|
|
||||||
)
|
|
||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
logger.info(f"Using device: {device}")
|
logger.info(f"Using device: {device}")
|
||||||
@@ -134,9 +133,7 @@ def compute_embeddings_sentence_transformers(
|
|||||||
if hasattr(torch.mps, "set_per_process_memory_fraction"):
|
if hasattr(torch.mps, "set_per_process_memory_fraction"):
|
||||||
torch.mps.set_per_process_memory_fraction(0.9)
|
torch.mps.set_per_process_memory_fraction(0.9)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
logger.warning(
|
logger.warning("Some MPS optimizations not available in this PyTorch version")
|
||||||
"Some MPS optimizations not available in this PyTorch version"
|
|
||||||
)
|
|
||||||
elif device == "cpu":
|
elif device == "cpu":
|
||||||
# TODO: Haven't tested this yet
|
# TODO: Haven't tested this yet
|
||||||
torch.set_num_threads(min(8, os.cpu_count() or 4))
|
torch.set_num_threads(min(8, os.cpu_count() or 4))
|
||||||
@@ -226,25 +223,22 @@ def compute_embeddings_sentence_transformers(
|
|||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
|
||||||
f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Validate results
|
# Validate results
|
||||||
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||||
raise RuntimeError(
|
raise RuntimeError(f"Detected NaN or Inf values in embeddings, model: {model_name}")
|
||||||
f"Detected NaN or Inf values in embeddings, model: {model_name}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
|
def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
|
||||||
# TODO: @yichuan-w add progress bar only in build mode
|
# TODO: @yichuan-w add progress bar only in build mode
|
||||||
"""Compute embeddings using OpenAI API"""
|
"""Compute embeddings using OpenAI API"""
|
||||||
try:
|
try:
|
||||||
import openai
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
import openai
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise ImportError(f"OpenAI package not installed: {e}")
|
raise ImportError(f"OpenAI package not installed: {e}")
|
||||||
|
|
||||||
@@ -264,9 +258,10 @@ def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
|
f"Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
|
||||||
)
|
)
|
||||||
|
print(f"len of texts: {len(texts)}")
|
||||||
|
|
||||||
# OpenAI has limits on batch size and input length
|
# OpenAI has limits on batch size and input length
|
||||||
max_batch_size = 100 # Conservative batch size
|
max_batch_size = 1000 # Conservative batch size
|
||||||
all_embeddings = []
|
all_embeddings = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -293,15 +288,12 @@ def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
embeddings = np.array(all_embeddings, dtype=np.float32)
|
embeddings = np.array(all_embeddings, dtype=np.float32)
|
||||||
logger.info(
|
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
|
||||||
f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
|
print(f"len of embeddings: {len(embeddings)}")
|
||||||
)
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
def compute_embeddings_mlx(
|
def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int = 16) -> np.ndarray:
|
||||||
chunks: List[str], model_name: str, batch_size: int = 16
|
|
||||||
) -> np.ndarray:
|
|
||||||
# TODO: @yichuan-w add progress bar only in build mode
|
# TODO: @yichuan-w add progress bar only in build mode
|
||||||
"""Computes embeddings using an MLX model."""
|
"""Computes embeddings using an MLX model."""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
import time
|
|
||||||
import atexit
|
import atexit
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
import socket
|
import socket
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import os
|
import time
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
|
||||||
import psutil
|
import psutil
|
||||||
|
|
||||||
# Set up logging based on environment variable
|
# Set up logging based on environment variable
|
||||||
@@ -18,6 +18,24 @@ logging.basicConfig(
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_colab_environment() -> bool:
|
||||||
|
"""Check if we're running in Google Colab environment."""
|
||||||
|
return "COLAB_GPU" in os.environ or "COLAB_TPU" in os.environ
|
||||||
|
|
||||||
|
|
||||||
|
def _get_available_port(start_port: int = 5557) -> int:
|
||||||
|
"""Get an available port starting from start_port."""
|
||||||
|
port = start_port
|
||||||
|
while port < start_port + 100: # Try up to 100 ports
|
||||||
|
try:
|
||||||
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
|
s.bind(("localhost", port))
|
||||||
|
return port
|
||||||
|
except OSError:
|
||||||
|
port += 1
|
||||||
|
raise RuntimeError(f"No available ports found in range {start_port}-{start_port + 100}")
|
||||||
|
|
||||||
|
|
||||||
def _check_port(port: int) -> bool:
|
def _check_port(port: int) -> bool:
|
||||||
"""Check if a port is in use"""
|
"""Check if a port is in use"""
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
@@ -164,8 +182,8 @@ class EmbeddingServerManager:
|
|||||||
e.g., "leann_backend_diskann.embedding_server"
|
e.g., "leann_backend_diskann.embedding_server"
|
||||||
"""
|
"""
|
||||||
self.backend_module_name = backend_module_name
|
self.backend_module_name = backend_module_name
|
||||||
self.server_process: Optional[subprocess.Popen] = None
|
self.server_process: subprocess.Popen | None = None
|
||||||
self.server_port: Optional[int] = None
|
self.server_port: int | None = None
|
||||||
self._atexit_registered = False
|
self._atexit_registered = False
|
||||||
|
|
||||||
def start_server(
|
def start_server(
|
||||||
@@ -175,68 +193,69 @@ class EmbeddingServerManager:
|
|||||||
embedding_mode: str = "sentence-transformers",
|
embedding_mode: str = "sentence-transformers",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> tuple[bool, int]:
|
) -> tuple[bool, int]:
|
||||||
"""
|
"""Start the embedding server."""
|
||||||
Starts the embedding server process.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
port (int): The preferred ZMQ port for the server.
|
|
||||||
model_name (str): The name of the embedding model to use.
|
|
||||||
**kwargs: Additional arguments for the server.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple[bool, int]: (success, actual_port_used)
|
|
||||||
"""
|
|
||||||
passages_file = kwargs.get("passages_file")
|
passages_file = kwargs.get("passages_file")
|
||||||
assert isinstance(passages_file, str), "passages_file must be a string"
|
|
||||||
|
|
||||||
# Check if we have a compatible running server
|
# Check if we have a compatible server already running
|
||||||
if self._has_compatible_running_server(model_name, passages_file):
|
if self._has_compatible_running_server(model_name, passages_file):
|
||||||
assert self.server_port is not None, (
|
logger.info("Found compatible running server!")
|
||||||
"a compatible running server should set server_port"
|
return True, port
|
||||||
)
|
|
||||||
return True, self.server_port
|
|
||||||
|
|
||||||
# Find available port (compatible or free)
|
# For Colab environment, use a different strategy
|
||||||
try:
|
if _is_colab_environment():
|
||||||
actual_port, is_compatible = _find_compatible_port_or_next_available(
|
logger.info("Detected Colab environment, using alternative startup strategy")
|
||||||
port, model_name, passages_file
|
return self._start_server_colab(port, model_name, embedding_mode, **kwargs)
|
||||||
)
|
|
||||||
except RuntimeError as e:
|
# Find a compatible port or next available
|
||||||
logger.error(str(e))
|
actual_port, is_compatible = _find_compatible_port_or_next_available(
|
||||||
return False, port
|
port, model_name, passages_file
|
||||||
|
)
|
||||||
|
|
||||||
if is_compatible:
|
if is_compatible:
|
||||||
logger.info(f"Using existing compatible server on port {actual_port}")
|
logger.info(f"Found compatible server on port {actual_port}")
|
||||||
self.server_port = actual_port
|
|
||||||
self.server_process = None # We don't own this process
|
|
||||||
return True, actual_port
|
return True, actual_port
|
||||||
|
|
||||||
if actual_port != port:
|
# Start a new server
|
||||||
logger.info(f"Using port {actual_port} instead of {port}")
|
|
||||||
|
|
||||||
# Start new server
|
|
||||||
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
|
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
|
||||||
|
|
||||||
def _has_compatible_running_server(
|
def _start_server_colab(
|
||||||
self, model_name: str, passages_file: str
|
self,
|
||||||
) -> bool:
|
port: int,
|
||||||
|
model_name: str,
|
||||||
|
embedding_mode: str = "sentence-transformers",
|
||||||
|
**kwargs,
|
||||||
|
) -> tuple[bool, int]:
|
||||||
|
"""Start server with Colab-specific configuration."""
|
||||||
|
# Try to find an available port
|
||||||
|
try:
|
||||||
|
actual_port = _get_available_port(port)
|
||||||
|
except RuntimeError:
|
||||||
|
logger.error("No available ports found")
|
||||||
|
return False, port
|
||||||
|
|
||||||
|
logger.info(f"Starting server on port {actual_port} for Colab environment")
|
||||||
|
|
||||||
|
# Use a simpler startup strategy for Colab
|
||||||
|
command = self._build_server_command(actual_port, model_name, embedding_mode, **kwargs)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# In Colab, we'll use a more direct approach
|
||||||
|
self._launch_server_process_colab(command, actual_port)
|
||||||
|
return self._wait_for_server_ready_colab(actual_port)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to start embedding server in Colab: {e}")
|
||||||
|
return False, actual_port
|
||||||
|
|
||||||
|
def _has_compatible_running_server(self, model_name: str, passages_file: str) -> bool:
|
||||||
"""Check if we have a compatible running server."""
|
"""Check if we have a compatible running server."""
|
||||||
if not (
|
if not (self.server_process and self.server_process.poll() is None and self.server_port):
|
||||||
self.server_process
|
|
||||||
and self.server_process.poll() is None
|
|
||||||
and self.server_port
|
|
||||||
):
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if _check_process_matches_config(self.server_port, model_name, passages_file):
|
if _check_process_matches_config(self.server_port, model_name, passages_file):
|
||||||
logger.info(
|
logger.info(f"Existing server process (PID {self.server_process.pid}) is compatible")
|
||||||
f"Existing server process (PID {self.server_process.pid}) is compatible"
|
|
||||||
)
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
logger.info(
|
logger.info("Existing server process is incompatible. Should start a new server.")
|
||||||
"Existing server process is incompatible. Should start a new server."
|
|
||||||
)
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _start_new_server(
|
def _start_new_server(
|
||||||
@@ -269,9 +288,13 @@ class EmbeddingServerManager:
|
|||||||
]
|
]
|
||||||
|
|
||||||
if kwargs.get("passages_file"):
|
if kwargs.get("passages_file"):
|
||||||
command.extend(["--passages-file", str(kwargs["passages_file"])])
|
# Convert to absolute path to ensure subprocess can find the file
|
||||||
|
passages_file = Path(kwargs["passages_file"]).resolve()
|
||||||
|
command.extend(["--passages-file", str(passages_file)])
|
||||||
if embedding_mode != "sentence-transformers":
|
if embedding_mode != "sentence-transformers":
|
||||||
command.extend(["--embedding-mode", embedding_mode])
|
command.extend(["--embedding-mode", embedding_mode])
|
||||||
|
if kwargs.get("distance_metric"):
|
||||||
|
command.extend(["--distance-metric", kwargs["distance_metric"]])
|
||||||
|
|
||||||
return command
|
return command
|
||||||
|
|
||||||
@@ -346,3 +369,45 @@ class EmbeddingServerManager:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
self.server_process = None
|
self.server_process = None
|
||||||
|
|
||||||
|
def _launch_server_process_colab(self, command: list, port: int) -> None:
|
||||||
|
"""Launch the server process with Colab-specific settings."""
|
||||||
|
logger.info(f"Colab Command: {' '.join(command)}")
|
||||||
|
|
||||||
|
# In Colab, we need to be more careful about process management
|
||||||
|
self.server_process = subprocess.Popen(
|
||||||
|
command,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.PIPE,
|
||||||
|
text=True,
|
||||||
|
)
|
||||||
|
self.server_port = port
|
||||||
|
logger.info(f"Colab server process started with PID: {self.server_process.pid}")
|
||||||
|
|
||||||
|
# Register atexit callback
|
||||||
|
if not self._atexit_registered:
|
||||||
|
atexit.register(lambda: self.stop_server() if self.server_process else None)
|
||||||
|
self._atexit_registered = True
|
||||||
|
|
||||||
|
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
|
||||||
|
"""Wait for the server to be ready with Colab-specific timeout."""
|
||||||
|
max_wait, wait_interval = 30, 0.5 # Shorter timeout for Colab
|
||||||
|
|
||||||
|
for _ in range(int(max_wait / wait_interval)):
|
||||||
|
if _check_port(port):
|
||||||
|
logger.info("Colab embedding server is ready!")
|
||||||
|
return True, port
|
||||||
|
|
||||||
|
if self.server_process and self.server_process.poll() is not None:
|
||||||
|
# Check for error output
|
||||||
|
stdout, stderr = self.server_process.communicate()
|
||||||
|
logger.error("Colab server terminated during startup.")
|
||||||
|
logger.error(f"stdout: {stdout}")
|
||||||
|
logger.error(f"stderr: {stderr}")
|
||||||
|
return False, port
|
||||||
|
|
||||||
|
time.sleep(wait_interval)
|
||||||
|
|
||||||
|
logger.error(f"Colab server failed to start within {max_wait} seconds.")
|
||||||
|
self.stop_server()
|
||||||
|
return False, port
|
||||||
|
|||||||
@@ -1,15 +1,14 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import Dict, Any, List, Literal, Optional
|
|
||||||
|
|
||||||
|
|
||||||
class LeannBackendBuilderInterface(ABC):
|
class LeannBackendBuilderInterface(ABC):
|
||||||
"""Backend interface for building indexes"""
|
"""Backend interface for building indexes"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def build(
|
def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs) -> None:
|
||||||
self, data: np.ndarray, ids: List[str], index_path: str, **kwargs
|
|
||||||
) -> None:
|
|
||||||
"""Build index
|
"""Build index
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -35,9 +34,7 @@ class LeannBackendSearcherInterface(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _ensure_server_running(
|
def _ensure_server_running(self, passages_source_file: str, port: int | None, **kwargs) -> int:
|
||||||
self, passages_source_file: str, port: Optional[int], **kwargs
|
|
||||||
) -> int:
|
|
||||||
"""Ensure server is running"""
|
"""Ensure server is running"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -51,9 +48,9 @@ class LeannBackendSearcherInterface(ABC):
|
|||||||
prune_ratio: float = 0.0,
|
prune_ratio: float = 0.0,
|
||||||
recompute_embeddings: bool = False,
|
recompute_embeddings: bool = False,
|
||||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
zmq_port: Optional[int] = None,
|
zmq_port: int | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Search for nearest neighbors
|
"""Search for nearest neighbors
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -77,7 +74,7 @@ class LeannBackendSearcherInterface(ABC):
|
|||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
use_server_if_available: bool = True,
|
use_server_if_available: bool = True,
|
||||||
zmq_port: Optional[int] = None,
|
zmq_port: int | None = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Compute embedding for a query string
|
"""Compute embedding for a query string
|
||||||
|
|
||||||
|
|||||||
@@ -1,13 +1,13 @@
|
|||||||
# packages/leann-core/src/leann/registry.py
|
# packages/leann-core/src/leann/registry.py
|
||||||
|
|
||||||
from typing import Dict, TYPE_CHECKING
|
|
||||||
import importlib
|
import importlib
|
||||||
import importlib.metadata
|
import importlib.metadata
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from leann.interface import LeannBackendFactoryInterface
|
from leann.interface import LeannBackendFactoryInterface
|
||||||
|
|
||||||
BACKEND_REGISTRY: Dict[str, "LeannBackendFactoryInterface"] = {}
|
BACKEND_REGISTRY: dict[str, "LeannBackendFactoryInterface"] = {}
|
||||||
|
|
||||||
|
|
||||||
def register_backend(name: str):
|
def register_backend(name: str):
|
||||||
@@ -31,13 +31,11 @@ def autodiscover_backends():
|
|||||||
backend_module_name = dist_name.replace("-", "_")
|
backend_module_name = dist_name.replace("-", "_")
|
||||||
discovered_backends.append(backend_module_name)
|
discovered_backends.append(backend_module_name)
|
||||||
|
|
||||||
for backend_module_name in sorted(
|
for backend_module_name in sorted(discovered_backends): # sort for deterministic loading
|
||||||
discovered_backends
|
|
||||||
): # sort for deterministic loading
|
|
||||||
try:
|
try:
|
||||||
importlib.import_module(backend_module_name)
|
importlib.import_module(backend_module_name)
|
||||||
# Registration message is printed by the decorator
|
# Registration message is printed by the decorator
|
||||||
except ImportError as e:
|
except ImportError:
|
||||||
# print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
|
# print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
|
||||||
pass
|
pass
|
||||||
# print("INFO: Backend auto-discovery finished.")
|
# print("INFO: Backend auto-discovery finished.")
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Any, Literal, Optional
|
from typing import Any, Literal
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -38,9 +38,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
|||||||
|
|
||||||
self.embedding_model = self.meta.get("embedding_model")
|
self.embedding_model = self.meta.get("embedding_model")
|
||||||
if not self.embedding_model:
|
if not self.embedding_model:
|
||||||
print(
|
print("WARNING: embedding_model not found in meta.json. Recompute will fail.")
|
||||||
"WARNING: embedding_model not found in meta.json. Recompute will fail."
|
|
||||||
)
|
|
||||||
|
|
||||||
self.embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
|
self.embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
|
||||||
|
|
||||||
@@ -48,39 +46,40 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
|||||||
backend_module_name=backend_module_name,
|
backend_module_name=backend_module_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _load_meta(self) -> Dict[str, Any]:
|
def _load_meta(self) -> dict[str, Any]:
|
||||||
"""Loads the metadata file associated with the index."""
|
"""Loads the metadata file associated with the index."""
|
||||||
# This is the corrected logic for finding the meta file.
|
# This is the corrected logic for finding the meta file.
|
||||||
meta_path = self.index_dir / f"{self.index_path.name}.meta.json"
|
meta_path = self.index_dir / f"{self.index_path.name}.meta.json"
|
||||||
if not meta_path.exists():
|
if not meta_path.exists():
|
||||||
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}")
|
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}")
|
||||||
with open(meta_path, "r", encoding="utf-8") as f:
|
with open(meta_path, encoding="utf-8") as f:
|
||||||
return json.load(f)
|
return json.load(f)
|
||||||
|
|
||||||
def _ensure_server_running(
|
def _ensure_server_running(self, passages_source_file: str, port: int, **kwargs) -> int:
|
||||||
self, passages_source_file: str, port: int, **kwargs
|
|
||||||
) -> int:
|
|
||||||
"""
|
"""
|
||||||
Ensures the embedding server is running if recompute is needed.
|
Ensures the embedding server is running if recompute is needed.
|
||||||
This is a helper for subclasses.
|
This is a helper for subclasses.
|
||||||
"""
|
"""
|
||||||
if not self.embedding_model:
|
if not self.embedding_model:
|
||||||
raise ValueError(
|
raise ValueError("Cannot use recompute mode without 'embedding_model' in meta.json.")
|
||||||
"Cannot use recompute mode without 'embedding_model' in meta.json."
|
|
||||||
)
|
# Get distance_metric from meta if not provided in kwargs
|
||||||
|
distance_metric = (
|
||||||
|
kwargs.get("distance_metric")
|
||||||
|
or self.meta.get("backend_kwargs", {}).get("distance_metric")
|
||||||
|
or "mips"
|
||||||
|
)
|
||||||
|
|
||||||
server_started, actual_port = self.embedding_server_manager.start_server(
|
server_started, actual_port = self.embedding_server_manager.start_server(
|
||||||
port=port,
|
port=port,
|
||||||
model_name=self.embedding_model,
|
model_name=self.embedding_model,
|
||||||
embedding_mode=self.embedding_mode,
|
embedding_mode=self.embedding_mode,
|
||||||
passages_file=passages_source_file,
|
passages_file=passages_source_file,
|
||||||
distance_metric=kwargs.get("distance_metric"),
|
distance_metric=distance_metric,
|
||||||
enable_warmup=kwargs.get("enable_warmup", False),
|
enable_warmup=kwargs.get("enable_warmup", False),
|
||||||
)
|
)
|
||||||
if not server_started:
|
if not server_started:
|
||||||
raise RuntimeError(
|
raise RuntimeError(f"Failed to start embedding server on port {actual_port}")
|
||||||
f"Failed to start embedding server on port {actual_port}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return actual_port
|
return actual_port
|
||||||
|
|
||||||
@@ -109,11 +108,10 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
|||||||
# on that port?
|
# on that port?
|
||||||
|
|
||||||
# Ensure we have a server with passages_file for compatibility
|
# Ensure we have a server with passages_file for compatibility
|
||||||
passages_source_file = (
|
passages_source_file = self.index_dir / f"{self.index_path.name}.meta.json"
|
||||||
self.index_dir / f"{self.index_path.name}.meta.json"
|
# Convert to absolute path to ensure server can find it
|
||||||
)
|
|
||||||
zmq_port = self._ensure_server_running(
|
zmq_port = self._ensure_server_running(
|
||||||
str(passages_source_file), zmq_port
|
str(passages_source_file.resolve()), zmq_port
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._compute_embedding_via_server([query], zmq_port)[
|
return self._compute_embedding_via_server([query], zmq_port)[
|
||||||
@@ -131,8 +129,8 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
|||||||
|
|
||||||
def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray:
|
def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray:
|
||||||
"""Compute embeddings using the ZMQ embedding server."""
|
"""Compute embeddings using the ZMQ embedding server."""
|
||||||
import zmq
|
|
||||||
import msgpack
|
import msgpack
|
||||||
|
import zmq
|
||||||
|
|
||||||
try:
|
try:
|
||||||
context = zmq.Context()
|
context = zmq.Context()
|
||||||
@@ -171,9 +169,9 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
|||||||
prune_ratio: float = 0.0,
|
prune_ratio: float = 0.0,
|
||||||
recompute_embeddings: bool = False,
|
recompute_embeddings: bool = False,
|
||||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
zmq_port: Optional[int] = None,
|
zmq_port: int | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Search for the top_k nearest neighbors of the query vector.
|
Search for the top_k nearest neighbors of the query vector.
|
||||||
|
|
||||||
|
|||||||
39
packages/leann/README.md
Normal file
39
packages/leann/README.md
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
# LEANN - The smallest vector index in the world
|
||||||
|
|
||||||
|
LEANN is a revolutionary vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Default installation (HNSW backend, recommended)
|
||||||
|
uv pip install leann
|
||||||
|
|
||||||
|
# With DiskANN backend (for large-scale deployments)
|
||||||
|
uv pip install leann[diskann]
|
||||||
|
```
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
```python
|
||||||
|
from leann import LeannBuilder, LeannSearcher, LeannChat
|
||||||
|
from pathlib import Path
|
||||||
|
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
|
||||||
|
|
||||||
|
# Build an index
|
||||||
|
builder = LeannBuilder(backend_name="hnsw")
|
||||||
|
builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
|
||||||
|
builder.add_text("Tung Tung Tung Sahur called—they need their banana‑crocodile hybrid back")
|
||||||
|
builder.build_index(INDEX_PATH)
|
||||||
|
|
||||||
|
# Search
|
||||||
|
searcher = LeannSearcher(INDEX_PATH)
|
||||||
|
results = searcher.search("fantastical AI-generated creatures", top_k=1)
|
||||||
|
|
||||||
|
# Chat with your data
|
||||||
|
chat = LeannChat(INDEX_PATH, llm_config={"type": "hf", "model": "Qwen/Qwen3-0.6B"})
|
||||||
|
response = chat.ask("How much storage does LEANN save?", top_k=1)
|
||||||
|
```
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
MIT License
|
||||||
12
packages/leann/__init__.py
Normal file
12
packages/leann/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
"""
|
||||||
|
LEANN - Low-storage Embedding Approximation for Neural Networks
|
||||||
|
|
||||||
|
A revolutionary vector database that democratizes personal AI.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__version__ = "0.1.0"
|
||||||
|
|
||||||
|
# Re-export main API from leann-core
|
||||||
|
from leann_core import LeannBuilder, LeannChat, LeannSearcher
|
||||||
|
|
||||||
|
__all__ = ["LeannBuilder", "LeannChat", "LeannSearcher"]
|
||||||
40
packages/leann/pyproject.toml
Normal file
40
packages/leann/pyproject.toml
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
[build-system]
|
||||||
|
requires = ["setuptools>=61.0"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "leann"
|
||||||
|
version = "0.1.15"
|
||||||
|
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.9"
|
||||||
|
license = { text = "MIT" }
|
||||||
|
authors = [
|
||||||
|
{ name = "LEANN Team" }
|
||||||
|
]
|
||||||
|
keywords = ["vector-database", "rag", "embeddings", "search", "ai"]
|
||||||
|
classifiers = [
|
||||||
|
"Development Status :: 4 - Beta",
|
||||||
|
"Intended Audience :: Developers",
|
||||||
|
"License :: OSI Approved :: MIT License",
|
||||||
|
"Programming Language :: Python :: 3",
|
||||||
|
"Programming Language :: Python :: 3.9",
|
||||||
|
"Programming Language :: Python :: 3.10",
|
||||||
|
"Programming Language :: Python :: 3.11",
|
||||||
|
"Programming Language :: Python :: 3.12",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Default installation: core + hnsw
|
||||||
|
dependencies = [
|
||||||
|
"leann-core>=0.1.0",
|
||||||
|
"leann-backend-hnsw>=0.1.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
diskann = [
|
||||||
|
"leann-backend-diskann>=0.1.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.urls]
|
||||||
|
Repository = "https://github.com/yichuan-w/LEANN"
|
||||||
|
Issues = "https://github.com/yichuan-w/LEANN/issues"
|
||||||
@@ -1,22 +1,23 @@
|
|||||||
import json
|
import json
|
||||||
import typer
|
|
||||||
from pathlib import Path
|
|
||||||
import requests
|
|
||||||
from tqdm import tqdm
|
|
||||||
import xml.etree.ElementTree as ET
|
|
||||||
from typing_extensions import Annotated
|
|
||||||
import sqlite3
|
import sqlite3
|
||||||
|
import xml.etree.ElementTree as ElementTree
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import typer
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
app = typer.Typer()
|
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
def get_safe_path(s: str) -> str:
|
def get_safe_path(s: str) -> str:
|
||||||
"""
|
"""
|
||||||
Remove invalid characters to sanitize a path.
|
Remove invalid characters to sanitize a path.
|
||||||
:param s: str to sanitize
|
:param s: str to sanitize
|
||||||
:returns: sanitized str
|
:returns: sanitized str
|
||||||
"""
|
"""
|
||||||
ban_chars = "\\ / : * ? \" ' < > | $ \r \n".replace(
|
ban_chars = "\\ / : * ? \" ' < > | $ \r \n".replace(" ", "")
|
||||||
' ', '')
|
|
||||||
for i in ban_chars:
|
for i in ban_chars:
|
||||||
s = s.replace(i, "")
|
s = s.replace(i, "")
|
||||||
return s
|
return s
|
||||||
@@ -25,36 +26,40 @@ def get_safe_path(s: str) -> str:
|
|||||||
def process_history(history: str):
|
def process_history(history: str):
|
||||||
if history.startswith("<?xml") or history.startswith("<msg>"):
|
if history.startswith("<?xml") or history.startswith("<msg>"):
|
||||||
try:
|
try:
|
||||||
root = ET.fromstring(history)
|
root = ElementTree.fromstring(history)
|
||||||
title = root.find('.//title').text if root.find('.//title') is not None else None
|
title = root.find(".//title").text if root.find(".//title") is not None else None
|
||||||
quoted = root.find('.//refermsg/content').text if root.find('.//refermsg/content') is not None else None
|
quoted = (
|
||||||
|
root.find(".//refermsg/content").text
|
||||||
|
if root.find(".//refermsg/content") is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
if title and quoted:
|
if title and quoted:
|
||||||
return {
|
return {"title": title, "quoted": process_history(quoted)}
|
||||||
"title": title,
|
|
||||||
"quoted": process_history(quoted)
|
|
||||||
}
|
|
||||||
if title:
|
if title:
|
||||||
return title
|
return title
|
||||||
except Exception:
|
except Exception:
|
||||||
return history
|
return history
|
||||||
return history
|
return history
|
||||||
|
|
||||||
|
|
||||||
def get_message(history: dict | str):
|
def get_message(history: dict | str):
|
||||||
if isinstance(history, dict):
|
if isinstance(history, dict):
|
||||||
if 'title' in history:
|
if "title" in history:
|
||||||
return history['title']
|
return history["title"]
|
||||||
else:
|
else:
|
||||||
return history
|
return history
|
||||||
|
|
||||||
|
|
||||||
def export_chathistory(user_id: str):
|
def export_chathistory(user_id: str):
|
||||||
res = requests.get("http://localhost:48065/wechat/chatlog", params={
|
res = requests.get(
|
||||||
"userId": user_id,
|
"http://localhost:48065/wechat/chatlog",
|
||||||
"count": 100000
|
params={"userId": user_id, "count": 100000},
|
||||||
}).json()
|
).json()
|
||||||
for i in range(len(res['chatLogs'])):
|
for i in range(len(res["chatLogs"])):
|
||||||
res['chatLogs'][i]['content'] = process_history(res['chatLogs'][i]['content'])
|
res["chatLogs"][i]["content"] = process_history(res["chatLogs"][i]["content"])
|
||||||
res['chatLogs'][i]['message'] = get_message(res['chatLogs'][i]['content'])
|
res["chatLogs"][i]["message"] = get_message(res["chatLogs"][i]["content"])
|
||||||
return res['chatLogs']
|
return res["chatLogs"]
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def export_all(dest: Annotated[Path, typer.Argument(help="Destination path to export to.")]):
|
def export_all(dest: Annotated[Path, typer.Argument(help="Destination path to export to.")]):
|
||||||
@@ -64,7 +69,7 @@ def export_all(dest: Annotated[Path, typer.Argument(help="Destination path to ex
|
|||||||
if not dest.is_dir():
|
if not dest.is_dir():
|
||||||
if not dest.exists():
|
if not dest.exists():
|
||||||
inp = typer.prompt("Destination path does not exist, create it? (y/n)")
|
inp = typer.prompt("Destination path does not exist, create it? (y/n)")
|
||||||
if inp.lower() == 'y':
|
if inp.lower() == "y":
|
||||||
dest.mkdir(parents=True)
|
dest.mkdir(parents=True)
|
||||||
else:
|
else:
|
||||||
typer.echo("Aborted.", err=True)
|
typer.echo("Aborted.", err=True)
|
||||||
@@ -77,12 +82,12 @@ def export_all(dest: Annotated[Path, typer.Argument(help="Destination path to ex
|
|||||||
exported_count = 0
|
exported_count = 0
|
||||||
for user in tqdm(all_users):
|
for user in tqdm(all_users):
|
||||||
try:
|
try:
|
||||||
usr_chatlog = export_chathistory(user['arg'])
|
usr_chatlog = export_chathistory(user["arg"])
|
||||||
|
|
||||||
# Only write file if there are messages
|
# Only write file if there are messages
|
||||||
if len(usr_chatlog) > 0:
|
if len(usr_chatlog) > 0:
|
||||||
out_path = dest/get_safe_path((user['title'] or "")+"-"+user['arg']+'.json')
|
out_path = dest / get_safe_path((user["title"] or "") + "-" + user["arg"] + ".json")
|
||||||
with open(out_path, 'w', encoding='utf-8') as f:
|
with open(out_path, "w", encoding="utf-8") as f:
|
||||||
json.dump(usr_chatlog, f, ensure_ascii=False, indent=2)
|
json.dump(usr_chatlog, f, ensure_ascii=False, indent=2)
|
||||||
exported_count += 1
|
exported_count += 1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -91,23 +96,43 @@ def export_all(dest: Annotated[Path, typer.Argument(help="Destination path to ex
|
|||||||
|
|
||||||
print(f"Exported {exported_count} users' chat history to {dest} in json.")
|
print(f"Exported {exported_count} users' chat history to {dest} in json.")
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def export_sqlite(dest: Annotated[Path, typer.Argument(help="Destination path to export to.")] = Path("chatlog.db")):
|
def export_sqlite(
|
||||||
|
dest: Annotated[Path, typer.Argument(help="Destination path to export to.")] = Path(
|
||||||
|
"chatlog.db"
|
||||||
|
),
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Export all users' chat history to a sqlite database.
|
Export all users' chat history to a sqlite database.
|
||||||
"""
|
"""
|
||||||
connection = sqlite3.connect(dest)
|
connection = sqlite3.connect(dest)
|
||||||
cursor = connection.cursor()
|
cursor = connection.cursor()
|
||||||
cursor.execute("CREATE TABLE IF NOT EXISTS chatlog (id INTEGER PRIMARY KEY AUTOINCREMENT, with_id TEXT, from_user TEXT, to_user TEXT, message TEXT, timest DATETIME, auxiliary TEXT)")
|
cursor.execute(
|
||||||
|
"CREATE TABLE IF NOT EXISTS chatlog (id INTEGER PRIMARY KEY AUTOINCREMENT, with_id TEXT, from_user TEXT, to_user TEXT, message TEXT, timest DATETIME, auxiliary TEXT)"
|
||||||
|
)
|
||||||
cursor.execute("CREATE INDEX IF NOT EXISTS chatlog_with_id_index ON chatlog (with_id)")
|
cursor.execute("CREATE INDEX IF NOT EXISTS chatlog_with_id_index ON chatlog (with_id)")
|
||||||
cursor.execute("CREATE TABLE iF NOT EXISTS users (id TEXT PRIMARY KEY, name TEXT)")
|
cursor.execute("CREATE TABLE iF NOT EXISTS users (id TEXT PRIMARY KEY, name TEXT)")
|
||||||
|
|
||||||
all_users = requests.get("http://localhost:48065/wechat/allcontacts").json()
|
all_users = requests.get("http://localhost:48065/wechat/allcontacts").json()
|
||||||
for user in tqdm(all_users):
|
for user in tqdm(all_users):
|
||||||
cursor.execute("INSERT OR IGNORE INTO users (id, name) VALUES (?, ?)", (user['arg'], user['title']))
|
cursor.execute(
|
||||||
usr_chatlog = export_chathistory(user['arg'])
|
"INSERT OR IGNORE INTO users (id, name) VALUES (?, ?)",
|
||||||
|
(user["arg"], user["title"]),
|
||||||
|
)
|
||||||
|
usr_chatlog = export_chathistory(user["arg"])
|
||||||
for msg in usr_chatlog:
|
for msg in usr_chatlog:
|
||||||
cursor.execute("INSERT INTO chatlog (with_id, from_user, to_user, message, timest, auxiliary) VALUES (?, ?, ?, ?, ?, ?)", (user['arg'], msg['fromUser'], msg['toUser'], msg['message'], msg['createTime'], str(msg['content'])))
|
cursor.execute(
|
||||||
|
"INSERT INTO chatlog (with_id, from_user, to_user, message, timest, auxiliary) VALUES (?, ?, ?, ?, ?, ?)",
|
||||||
|
(
|
||||||
|
user["arg"],
|
||||||
|
msg["fromUser"],
|
||||||
|
msg["toUser"],
|
||||||
|
msg["message"],
|
||||||
|
msg["createTime"],
|
||||||
|
str(msg["content"]),
|
||||||
|
),
|
||||||
|
)
|
||||||
connection.commit()
|
connection.commit()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
108
pyproject.toml
108
pyproject.toml
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
[project]
|
[project]
|
||||||
name = "leann-workspace"
|
name = "leann-workspace"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.9"
|
||||||
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"leann-core",
|
"leann-core",
|
||||||
@@ -25,16 +25,23 @@ dependencies = [
|
|||||||
"requests>=2.25.0",
|
"requests>=2.25.0",
|
||||||
"sentence-transformers>=2.2.0",
|
"sentence-transformers>=2.2.0",
|
||||||
"openai>=1.0.0",
|
"openai>=1.0.0",
|
||||||
|
# PDF parsing dependencies - essential for document processing
|
||||||
"PyPDF2>=3.0.0",
|
"PyPDF2>=3.0.0",
|
||||||
|
"pdfplumber>=0.11.0",
|
||||||
|
"pymupdf>=1.26.0",
|
||||||
|
"pypdfium2>=4.30.0",
|
||||||
|
# LlamaIndex core and readers - updated versions
|
||||||
"llama-index>=0.12.44",
|
"llama-index>=0.12.44",
|
||||||
"llama-index-readers-docling",
|
"llama-index-readers-file>=0.4.0", # Essential for PDF parsing
|
||||||
"llama-index-node-parser-docling",
|
# "llama-index-readers-docling", # Requires Python >= 3.10
|
||||||
"ipykernel==6.29.5",
|
# "llama-index-node-parser-docling", # Requires Python >= 3.10
|
||||||
"msgpack>=1.1.1",
|
|
||||||
"llama-index-vector-stores-faiss>=0.4.0",
|
"llama-index-vector-stores-faiss>=0.4.0",
|
||||||
"llama-index-embeddings-huggingface>=0.5.5",
|
"llama-index-embeddings-huggingface>=0.5.5",
|
||||||
"mlx>=0.26.3",
|
# Other dependencies
|
||||||
"mlx-lm>=0.26.0",
|
"ipykernel==6.29.5",
|
||||||
|
"msgpack>=1.1.1",
|
||||||
|
"mlx>=0.26.3; sys_platform == 'darwin'",
|
||||||
|
"mlx-lm>=0.26.0; sys_platform == 'darwin'",
|
||||||
"psutil>=5.8.0",
|
"psutil>=5.8.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -42,16 +49,35 @@ dependencies = [
|
|||||||
dev = [
|
dev = [
|
||||||
"pytest>=7.0",
|
"pytest>=7.0",
|
||||||
"pytest-cov>=4.0",
|
"pytest-cov>=4.0",
|
||||||
|
"pytest-xdist>=3.0", # For parallel test execution
|
||||||
"black>=23.0",
|
"black>=23.0",
|
||||||
"ruff>=0.1.0",
|
"ruff>=0.1.0",
|
||||||
"matplotlib",
|
"matplotlib",
|
||||||
"huggingface-hub>=0.20.0",
|
"huggingface-hub>=0.20.0",
|
||||||
|
"pre-commit>=3.5.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
test = [
|
||||||
|
"pytest>=7.0",
|
||||||
|
"pytest-timeout>=2.0",
|
||||||
|
"llama-index-core>=0.12.0",
|
||||||
|
"llama-index-readers-file>=0.4.0",
|
||||||
|
"python-dotenv>=1.0.0",
|
||||||
|
"sentence-transformers>=2.2.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
diskann = [
|
diskann = [
|
||||||
"leann-backend-diskann",
|
"leann-backend-diskann",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Add a new optional dependency group for document processing
|
||||||
|
documents = [
|
||||||
|
"beautifulsoup4>=4.13.0", # For HTML parsing
|
||||||
|
"python-docx>=0.8.11", # For Word documents
|
||||||
|
"openpyxl>=3.1.0", # For Excel files
|
||||||
|
"pandas>=2.2.0", # For data processing
|
||||||
|
]
|
||||||
|
|
||||||
[tool.setuptools]
|
[tool.setuptools]
|
||||||
py-modules = []
|
py-modules = []
|
||||||
|
|
||||||
@@ -60,3 +86,71 @@ py-modules = []
|
|||||||
leann-core = { path = "packages/leann-core", editable = true }
|
leann-core = { path = "packages/leann-core", editable = true }
|
||||||
leann-backend-diskann = { path = "packages/leann-backend-diskann", editable = true }
|
leann-backend-diskann = { path = "packages/leann-backend-diskann", editable = true }
|
||||||
leann-backend-hnsw = { path = "packages/leann-backend-hnsw", editable = true }
|
leann-backend-hnsw = { path = "packages/leann-backend-hnsw", editable = true }
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
target-version = "py310"
|
||||||
|
line-length = 100
|
||||||
|
extend-exclude = [
|
||||||
|
"third_party",
|
||||||
|
"*.egg-info",
|
||||||
|
"__pycache__",
|
||||||
|
".git",
|
||||||
|
".venv",
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.ruff.lint]
|
||||||
|
select = [
|
||||||
|
"E", # pycodestyle errors
|
||||||
|
"W", # pycodestyle warnings
|
||||||
|
"F", # pyflakes
|
||||||
|
"I", # isort
|
||||||
|
"B", # flake8-bugbear
|
||||||
|
"C4", # flake8-comprehensions
|
||||||
|
"UP", # pyupgrade
|
||||||
|
"N", # pep8-naming
|
||||||
|
"RUF", # ruff-specific rules
|
||||||
|
]
|
||||||
|
ignore = [
|
||||||
|
"E501", # line too long (handled by formatter)
|
||||||
|
"B008", # do not perform function calls in argument defaults
|
||||||
|
"B904", # raise without from
|
||||||
|
"N812", # lowercase imported as non-lowercase
|
||||||
|
"N806", # variable in function should be lowercase
|
||||||
|
"RUF012", # mutable class attributes should be annotated with typing.ClassVar
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.ruff.lint.per-file-ignores]
|
||||||
|
"test/**/*.py" = ["E402"] # module level import not at top of file (common in tests)
|
||||||
|
"examples/**/*.py" = ["E402"] # module level import not at top of file (common in examples)
|
||||||
|
|
||||||
|
[tool.ruff.format]
|
||||||
|
quote-style = "double"
|
||||||
|
indent-style = "space"
|
||||||
|
skip-magic-trailing-comma = false
|
||||||
|
line-ending = "auto"
|
||||||
|
|
||||||
|
[dependency-groups]
|
||||||
|
dev = [
|
||||||
|
"ruff>=0.12.4",
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
testpaths = ["tests"]
|
||||||
|
python_files = ["test_*.py"]
|
||||||
|
python_classes = ["Test*"]
|
||||||
|
python_functions = ["test_*"]
|
||||||
|
markers = [
|
||||||
|
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
|
||||||
|
"openai: marks tests that require OpenAI API key",
|
||||||
|
]
|
||||||
|
timeout = 600
|
||||||
|
addopts = [
|
||||||
|
"-v",
|
||||||
|
"--tb=short",
|
||||||
|
"--strict-markers",
|
||||||
|
"--disable-warnings",
|
||||||
|
]
|
||||||
|
env = [
|
||||||
|
"HF_HUB_DISABLE_SYMLINKS=1",
|
||||||
|
"TOKENIZERS_PARALLELISM=false",
|
||||||
|
]
|
||||||
|
|||||||
@@ -1,12 +0,0 @@
|
|||||||
import faiss
|
|
||||||
hnsw_index = faiss.read_index("/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/hnsw/hnsw_IP_M30_efC128.index", faiss.IO_FLAG_ONDISK_SAME_DIR)
|
|
||||||
|
|
||||||
# print total number of nodes
|
|
||||||
print(hnsw_index.ntotal)
|
|
||||||
|
|
||||||
# print stats of the graph
|
|
||||||
print(hnsw_index.hnsw.print_neighbor_stats(0))
|
|
||||||
|
|
||||||
|
|
||||||
# save_degree_distribution
|
|
||||||
hnsw_index.hnsw.save_degree_distribution(0, "degree_distribution_HNSW_M30.txt")
|
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
import faiss
|
|
||||||
nsg_index = faiss.read_index("/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/nsg_R16.index", faiss.IO_FLAG_ONDISK_SAME_DIR)
|
|
||||||
|
|
||||||
# print total number of nodes
|
|
||||||
print(nsg_index.ntotal)
|
|
||||||
|
|
||||||
# print stats of the graph
|
|
||||||
print(nsg_index.nsg.print_neighbor_stats(0))
|
|
||||||
|
|
||||||
# save degree distribution
|
|
||||||
nsg_index.nsg.save_degree_distribution("degree_distribution_NSG_R60.txt")
|
|
||||||
@@ -1,63 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import time
|
|
||||||
|
|
||||||
# import bitsandbytes as bnb
|
|
||||||
from bitsandbytes.nn import Linear8bitLt
|
|
||||||
|
|
||||||
# set default to half
|
|
||||||
import torch
|
|
||||||
torch.set_default_dtype(torch.float16)
|
|
||||||
|
|
||||||
M = 2048
|
|
||||||
N = 2048
|
|
||||||
|
|
||||||
bsz = 2048
|
|
||||||
import torch_int
|
|
||||||
from torch_int.nn.linear import W8A8BFP32OFP32Linear, W8A8B8O8Linear, W8A8B8O8LinearReLU
|
|
||||||
|
|
||||||
fp16_model = nn.Sequential(
|
|
||||||
nn.Linear(M, N),
|
|
||||||
# nn.Linear(2048, 2048)
|
|
||||||
)
|
|
||||||
|
|
||||||
int8_model = nn.Sequential(
|
|
||||||
Linear8bitLt(M, N, has_fp16_weights=False),
|
|
||||||
# Linear8bitLt(2048, 2048, has_fp16_weights=False)
|
|
||||||
)
|
|
||||||
|
|
||||||
int8_model.load_state_dict(fp16_model.state_dict())
|
|
||||||
int8_model = int8_model.to(0) # Quantization happens here
|
|
||||||
fp16_model = fp16_model.to(0) # Move fp16 model to GPU as well
|
|
||||||
|
|
||||||
# Create random input tensor
|
|
||||||
input_tensor = torch.randn(bsz, M, device=0) # Batch of 1000 vectors
|
|
||||||
|
|
||||||
# Speed test function
|
|
||||||
def speed_test(model, input_tensor, name, num_iterations=100):
|
|
||||||
# Warmup
|
|
||||||
for _ in range(10):
|
|
||||||
_ = model(input_tensor)
|
|
||||||
|
|
||||||
# Actual timing
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
for _ in range(num_iterations):
|
|
||||||
_ = model(input_tensor)
|
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
end_time = time.time()
|
|
||||||
|
|
||||||
avg_time = (end_time - start_time) / num_iterations
|
|
||||||
print(f"{name} model: {avg_time:.6f} seconds per iteration")
|
|
||||||
return avg_time
|
|
||||||
|
|
||||||
# Run speed tests
|
|
||||||
with torch.no_grad(): # Disable gradient calculation for inference
|
|
||||||
fp16_time = speed_test(fp16_model, input_tensor, "FP16")
|
|
||||||
int8_time = speed_test(int8_model, input_tensor, "INT8")
|
|
||||||
|
|
||||||
# Calculate speedup
|
|
||||||
speedup = fp16_time / int8_time
|
|
||||||
print(f"INT8 is {speedup:.2f}x faster than FP16")
|
|
||||||
@@ -1,89 +0,0 @@
|
|||||||
n,d,seqlen,bs,latency,h,flop,io,intensity,throughput,series
|
|
||||||
3,256,256,2048,0.009623501679245285,768,618475290624,167.48502132816208,3692720015.912285,64267177503366.266,dense
|
|
||||||
3,256,256,1024,0.004853848615384615,768,309237645312,166.15392854317415,1861151572.059558,63709783682138.234,dense
|
|
||||||
3,256,256,512,0.0024687246971962615,768,154618822656,163.57953256539062,945221081.3366361,62631051097597.516,dense
|
|
||||||
3,256,256,256,0.0012845360838052097,768,77309411328,157.64931990085577,490388486.1451936,60184694149645.54,dense
|
|
||||||
3,256,256,128,0.0006901147179878049,768,38654705664,147.57393422494675,261934506.70684624,56012000116019.945,dense
|
|
||||||
3,256,256,64,0.0003363830693015702,768,19327352832,153.1328437752606,126212981.84970059,57456378146882.51,dense
|
|
||||||
3,256,256,32,0.00018671159748991485,768,9663676416,141.10249365427362,68486928.65540518,51757237075334.75,dense
|
|
||||||
3,256,256,16,0.00012353640857142858,768,4831838208,111.40488993609125,43371868.24359184,39112665358133.98,dense
|
|
||||||
3,256,256,8,9.774760007849294e-05,768,2415919104,76.43260800265766,31608487.09906635,24715891766754.14,dense
|
|
||||||
3,256,256,4,6.672271167474822e-05,768,1207959552,64.82614227498455,18633833.660438772,18104173551704.773,dense
|
|
||||||
3,256,256,2,4.9758770289855074e-05,768,603979776,55.317122669351576,10918495.880745342,12138157202874.861,dense
|
|
||||||
3,256,1,2048,9.785507940251571e-05,768,2415919104,76.34865809334705,31643242.518371396,24688745017132.86,dense
|
|
||||||
3,256,1,1024,6.692813470149253e-05,768,1207959552,64.62717090938949,18691202.70936228,18048606275785.867,dense
|
|
||||||
3,256,1,512,4.9680950036205655e-05,768,603979776,55.40377142534654,10901419.893658841,12157170415618.898,dense
|
|
||||||
3,256,1,256,4.2781118741058655e-05,768,301989888,45.95672244805227,6571179.83862661,7058952568020.829,dense
|
|
||||||
3,256,1,128,5.0662328255350016e-05,768,150994944,31.046026784880404,4863583.512513602,2980418571348.519,dense
|
|
||||||
3,256,1,64,4.475009253945481e-05,768,75497472,30.75426042497223,2454862.219307235,1687090857598.4766,dense
|
|
||||||
3,256,1,32,4.51682671454219e-05,768,37748736,28.29313765537115,1334201.1218340008,835735758435.5786,dense
|
|
||||||
3,256,1,16,5.03585186661834e-05,768,18874368,24.401035466223117,773506.846712577,374799904761.1871,dense
|
|
||||||
3,256,1,8,5.023459565217391e-05,768,9437184,23.972005435021096,393675.19858030166,187862246674.45105,dense
|
|
||||||
3,256,1,4,5.053219391083726e-05,768,4718592,23.58765586356967,200044.97383259286,93377936614.54384,dense
|
|
||||||
3,256,1,2,4.4607398995335484e-05,768,2359296,26.58285456464288,88752.54515134107,52890239133.797226,dense
|
|
||||||
12,256,256,2048,0.14480779847058822,3072,9895604649984,44.620009282941716,221775046868.20184,68336130750540.26,dense
|
|
||||||
12,256,256,1024,0.07254347629166667,3072,4947802324992,44.664248332585096,110777691547.58836,68204648824643.82,dense
|
|
||||||
12,256,256,512,0.036310761444444443,3072,2473901162496,44.876147984203506,55127306456.13385,68131349056975.164,dense
|
|
||||||
12,256,256,256,0.01821551906896552,3072,1236950581248,45.24607467289738,27338295977.947884,67906414116709.98,dense
|
|
||||||
12,256,256,128,0.009229417903030302,3072,618475290624,45.67217092440895,13541622351.335684,67011299859001.46,dense
|
|
||||||
12,256,256,64,0.004754550595394737,3072,309237645312,46.31372736116993,6677019167.566916,65040352207320.695,dense
|
|
||||||
12,256,256,32,0.002405752659340659,3072,154618822656,49.68826015254682,3111777755.5766335,64270456921525.82,dense
|
|
||||||
12,256,256,16,0.0012287219045005488,3072,77309411328,56.323579604557374,1372594069.3184311,62918558743709.18,dense
|
|
||||||
12,256,256,8,0.0006206816149425287,3072,38654705664,70.95456179103653,544781120.315271,62277832520589.78,dense
|
|
||||||
12,256,256,4,0.0003875502697142857,3072,19327352832,81.16954743236613,238110885.71245712,49870569942445.75,dense
|
|
||||||
12,256,256,2,0.00027502018627941914,3072,9663676416,91.50537035282076,105607751.53129694,35138062215483.168,dense
|
|
||||||
12,256,1,2048,0.0006202853873290136,3072,38654705664,70.99988634205897,544433345.6784943,62317614526515.766,dense
|
|
||||||
12,256,1,1024,0.00038721467732724153,3072,19327352832,81.2398957010995,237904697.74985722,49913791918755.53,dense
|
|
||||||
12,256,1,512,0.000274364799,3072,9663676416,91.72395326121995,105356082.81599998,35221998052308.45,dense
|
|
||||||
12,256,1,256,0.00012488918589482266,3072,4831838208,176.31707535146046,27404255.647778228,38689003962834.75,dense
|
|
||||||
12,256,1,128,8.976711102514506e-05,3072,2415919104,227.78088507574267,10606329.425740216,26913187652026.21,dense
|
|
||||||
12,256,1,64,8.715176287471176e-05,3072,1207959552,225.59268282689945,5354604.31102229,13860414432884.701,dense
|
|
||||||
12,256,1,32,8.523013435114503e-05,3072,603979776,226.06539514085782,2671703.8033338524,7086458100741.991,dense
|
|
||||||
12,256,1,16,7.901561645904116e-05,3072,301989888,241.35704882952732,1251216.3595988373,3821901309300.556,dense
|
|
||||||
12,256,1,8,7.827949114210329e-05,3072,150994944,242.37091635608994,622991.1833900034,1928920867994.581,dense
|
|
||||||
12,256,1,4,7.779445951035782e-05,3072,75497472,243.25022783249054,310369.58391664835,970473636235.5986,dense
|
|
||||||
12,256,1,2,7.758845406626506e-05,3072,37748736,243.57933441822672,154975.11761480253,486525172518.07056,dense
|
|
||||||
3,256,256,2048,0.00507974918466899,768,206158430208,475.59810852303485,433471930.42508715,40584371927298.98,qk_init
|
|
||||||
3,256,256,1024,0.0025616677649325623,768,103079215104,471.5519977009198,218595649.27424532,40239103803811.82,qk_init
|
|
||||||
3,256,256,512,0.0013029336670480549,768,51539607552,463.55374128015677,111183672.92143403,39556585922573.38,qk_init
|
|
||||||
3,256,256,256,0.0006738189029345373,768,25769803776,448.1766342333362,57499213.050413854,38244406121244.69,qk_init
|
|
||||||
3,256,256,128,0.000358254672959467,768,12884901888,421.47375986100144,30571065.425874516,35965760841472.125,qk_init
|
|
||||||
3,256,256,64,0.0002007051105022831,768,6442450944,376.1611839930762,17126836.096194826,32099087700742.5,qk_init
|
|
||||||
3,256,256,32,0.00012189697230142565,768,3221225472,309.6773881032524,10401874.969721656,26425803784810.87,qk_init
|
|
||||||
3,256,256,16,8.453561698040722e-05,768,1610612736,223.2711923587723,7213705.982328083,19052475081281.902,qk_init
|
|
||||||
3,256,256,8,6.407660705009276e-05,768,805306368,147.2797083750448,5467870.468274581,12567868448003.822,qk_init
|
|
||||||
3,256,256,4,5.036328747284576e-05,768,402653184,93.69110391262903,4297667.197682838,7994974200544.344,qk_init
|
|
||||||
3,256,256,2,4.5488761135057476e-05,768,201326592,51.865470527877875,3881707.616858238,4425853485045.578,qk_init
|
|
||||||
12,256,256,2048,0.020202365999999996,3072,824633720832,478.3437947812648,1723935231.9999998,40818670488001.266,qk_init
|
|
||||||
12,256,256,1024,0.010124155888157895,3072,412316860416,477.2583770318811,863927969.1228071,40726048173387.19,qk_init
|
|
||||||
12,256,256,512,0.005085633937062937,3072,206158430208,475.04777848703077,433974095.9627039,40537410430893.29,qk_init
|
|
||||||
12,256,256,256,0.0025654916853281853,3072,103079215104,470.84913933193053,218921957.14800516,40179126556324.74,qk_init
|
|
||||||
12,256,256,128,0.0013045765704467354,3072,51539607552,462.9699702434292,111323867.34478809,39506770794105.96,qk_init
|
|
||||||
12,256,256,64,0.0006742801519939804,3072,25769803776,447.87005387442576,57538572.970153,38218244597284.33,qk_init
|
|
||||||
12,256,256,32,0.00035831976790671853,3072,12884901888,421.3971919051604,30576620.194706645,35959227042573.69,qk_init
|
|
||||||
12,256,256,16,0.0002005369068918302,3072,6442450944,376.4766953382971,17112482.721436176,32126011335534.68,qk_init
|
|
||||||
12,256,256,8,0.00012179187250509165,3072,3221225472,309.94462293386505,10392906.453767821,26448607823689.82,qk_init
|
|
||||||
12,256,256,4,8.452507263643351e-05,3072,1610612736,223.2990450204527,7212806.198308992,19054851841745.297,qk_init
|
|
||||||
12,256,256,2,6.412381767545489e-05,3072,805306368,147.17127491946468,5471899.108305484,12558615459794.32,qk_init
|
|
||||||
3,256,256,2048,0.0016183739398395718,768,805306368,811597824.0,0.9922480620155039,1265467.7325087283,qk_ar
|
|
||||||
3,256,256,1024,0.0008322699728813558,768,402653184,405798912.0,0.9922480620155039,1230369.9921491416,qk_ar
|
|
||||||
3,256,256,512,0.00043886859397590365,768,201326592,202899456.0,0.9922480620155039,1166636.2255762408,qk_ar
|
|
||||||
3,256,256,256,0.00024185948322147648,768,100663296,101449728.0,0.9922480620155039,1058465.8355760013,qk_ar
|
|
||||||
3,256,256,128,0.00014308985100166944,768,50331648,50724864.0,0.9922480620155039,894542.82818777,qk_ar
|
|
||||||
3,256,256,64,9.382939365815932e-05,768,25165824,25362432.0,0.9922480620155039,682089.028872613,qk_ar
|
|
||||||
3,256,256,32,6.856070612244899e-05,768,12582912,12681216.0,0.9922480620155039,466739.6503012703,qk_ar
|
|
||||||
3,256,256,16,5.452260553129549e-05,768,6291456,6340608.0,0.9922480620155039,293456.26174846216,qk_ar
|
|
||||||
3,256,256,8,4.608557533261417e-05,768,3145728,3170304.0,0.9922480620155039,173590.1080166944,qk_ar
|
|
||||||
3,256,256,4,4.386146957766642e-05,768,1572864,1585152.0,0.9922480620155039,91196.21477609445,qk_ar
|
|
||||||
3,256,256,2,4.330941094420601e-05,768,786432,792576.0,0.9922480620155039,46179.33969539622,qk_ar
|
|
||||||
12,256,256,2048,0.006347041645299144,3072,3221225472,3246391296.0,0.9922480620155039,322670.011392918,qk_ar
|
|
||||||
12,256,256,1024,0.0031943104467592586,3072,1610612736,1623195648.0,0.9922480620155039,320569.96872013,qk_ar
|
|
||||||
12,256,256,512,0.0016183416350267381,3072,805306368,811597824.0,0.9922480620155039,316373.2483416833,qk_ar
|
|
||||||
12,256,256,256,0.0008325934893977947,3072,402653184,405798912.0,0.9922480620155039,307472.9784221131,qk_ar
|
|
||||||
12,256,256,128,0.0004389725746987952,3072,201326592,202899456.0,0.9922480620155039,291589.9702568624,qk_ar
|
|
||||||
12,256,256,64,0.00024191767449664432,3072,100663296,101449728.0,0.9922480620155039,264552.8076159138,qk_ar
|
|
||||||
12,256,256,32,0.0001431546143572621,3072,50331648,50724864.0,0.9922480620155039,223534.53392804778,qk_ar
|
|
||||||
12,256,256,16,9.404283597678917e-05,3072,25165824,25362432.0,0.9922480620155039,170135.23501087292,qk_ar
|
|
||||||
12,256,256,8,6.855550037091989e-05,3072,12582912,12681216.0,0.9922480620155039,116693.773026467,qk_ar
|
|
||||||
12,256,256,4,5.4802094978165945e-05,3072,6291456,6340608.0,0.9922480620155039,72989.91036006316,qk_ar
|
|
||||||
12,256,256,2,4.608510707869206e-05,3072,3145728,3170304.0,0.9922480620155039,43397.96795057727,qk_ar
|
|
||||||
|
Binary file not shown.
|
Before Width: | Height: | Size: 45 KiB |
@@ -1,594 +0,0 @@
|
|||||||
# python embedd_micro.py --use_int8 Fastest
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import time
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from torchao import quantize_
|
|
||||||
from transformers import AutoModel, BitsAndBytesConfig
|
|
||||||
from tqdm import tqdm
|
|
||||||
from contextlib import contextmanager
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BenchmarkConfig:
|
|
||||||
model_path: str
|
|
||||||
batch_sizes: List[int]
|
|
||||||
seq_length: int
|
|
||||||
num_runs: int
|
|
||||||
use_fp16: bool = True
|
|
||||||
use_int4: bool = False
|
|
||||||
use_int8: bool = False # Add this parameter
|
|
||||||
use_cuda_graphs: bool = False
|
|
||||||
use_flash_attention: bool = False
|
|
||||||
use_linear8bitlt: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
class CUDAGraphContainer:
|
|
||||||
"""Container for managing CUDA graphs for different batch sizes."""
|
|
||||||
|
|
||||||
def __init__(self, model: nn.Module, seq_length: int):
|
|
||||||
self.model = model
|
|
||||||
self.seq_length = seq_length
|
|
||||||
self.graphs: Dict[int, CUDAGraphWrapper] = {}
|
|
||||||
|
|
||||||
def get_or_create(self, batch_size: int) -> 'CUDAGraphWrapper':
|
|
||||||
if batch_size not in self.graphs:
|
|
||||||
self.graphs[batch_size] = CUDAGraphWrapper(
|
|
||||||
self.model, batch_size, self.seq_length
|
|
||||||
)
|
|
||||||
return self.graphs[batch_size]
|
|
||||||
|
|
||||||
|
|
||||||
class CUDAGraphWrapper:
|
|
||||||
"""Wrapper for CUDA graph capture and replay."""
|
|
||||||
|
|
||||||
def __init__(self, model: nn.Module, batch_size: int, seq_length: int):
|
|
||||||
self.model = model
|
|
||||||
self.static_input = self._create_random_batch(batch_size, seq_length)
|
|
||||||
self.static_attention_mask = torch.ones_like(self.static_input)
|
|
||||||
|
|
||||||
# Warm up
|
|
||||||
self._warmup()
|
|
||||||
|
|
||||||
# Capture graph
|
|
||||||
self.graph = torch.cuda.CUDAGraph()
|
|
||||||
with torch.cuda.graph(self.graph):
|
|
||||||
self.static_output = self.model(
|
|
||||||
input_ids=self.static_input,
|
|
||||||
attention_mask=self.static_attention_mask
|
|
||||||
)
|
|
||||||
|
|
||||||
def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
|
|
||||||
return torch.randint(
|
|
||||||
0, 1000, (batch_size, seq_length),
|
|
||||||
device="cuda",
|
|
||||||
dtype=torch.long
|
|
||||||
)
|
|
||||||
|
|
||||||
def _warmup(self, num_warmup: int = 3):
|
|
||||||
with torch.no_grad():
|
|
||||||
for _ in range(num_warmup):
|
|
||||||
self.model(
|
|
||||||
input_ids=self.static_input,
|
|
||||||
attention_mask=self.static_attention_mask
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
|
||||||
self.static_input.copy_(input_ids)
|
|
||||||
self.static_attention_mask.copy_(attention_mask)
|
|
||||||
self.graph.replay()
|
|
||||||
return self.static_output
|
|
||||||
|
|
||||||
|
|
||||||
class ModelOptimizer:
|
|
||||||
"""Applies various optimizations to the model."""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def optimize(model: nn.Module, config: BenchmarkConfig) -> nn.Module:
|
|
||||||
print("\nApplying model optimizations:")
|
|
||||||
|
|
||||||
if model is None:
|
|
||||||
raise ValueError("Cannot optimize None model")
|
|
||||||
|
|
||||||
# Move to GPU
|
|
||||||
model = model.cuda()
|
|
||||||
print("- Model moved to GPU")
|
|
||||||
|
|
||||||
# FP16
|
|
||||||
if config.use_fp16 and not config.use_int4:
|
|
||||||
model = model.half()
|
|
||||||
# use torch compile
|
|
||||||
model = torch.compile(model)
|
|
||||||
print("- Using FP16 precision")
|
|
||||||
|
|
||||||
# Check if using SDPA
|
|
||||||
if torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
|
|
||||||
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
|
||||||
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
|
||||||
else:
|
|
||||||
print("- PyTorch SDPA not available")
|
|
||||||
|
|
||||||
# Flash Attention
|
|
||||||
if config.use_flash_attention:
|
|
||||||
try:
|
|
||||||
from flash_attn.flash_attention import FlashAttention
|
|
||||||
print("- Flash Attention 2 available")
|
|
||||||
if hasattr(model.config, "attention_mode"):
|
|
||||||
model.config.attention_mode = "flash_attention_2"
|
|
||||||
print(" - Enabled Flash Attention 2 mode")
|
|
||||||
except ImportError:
|
|
||||||
print("- Flash Attention not available")
|
|
||||||
|
|
||||||
# Memory efficient attention
|
|
||||||
try:
|
|
||||||
from xformers.ops import memory_efficient_attention
|
|
||||||
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
|
|
||||||
model.enable_xformers_memory_efficient_attention()
|
|
||||||
print("- Enabled xformers memory efficient attention")
|
|
||||||
else:
|
|
||||||
print("- Model doesn't support xformers")
|
|
||||||
except (ImportError, AttributeError):
|
|
||||||
print("- Xformers not available")
|
|
||||||
|
|
||||||
model.eval()
|
|
||||||
print("- Model set to eval mode")
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
class Timer:
|
|
||||||
"""Handles accurate GPU timing using CUDA events."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.start_event = torch.cuda.Event(enable_timing=True)
|
|
||||||
self.end_event = torch.cuda.Event(enable_timing=True)
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def timing(self):
|
|
||||||
self.start_event.record()
|
|
||||||
yield
|
|
||||||
self.end_event.record()
|
|
||||||
self.end_event.synchronize()
|
|
||||||
|
|
||||||
def elapsed_time(self) -> float:
|
|
||||||
return self.start_event.elapsed_time(self.end_event) / 1000 # ms to seconds
|
|
||||||
|
|
||||||
|
|
||||||
class Benchmark:
|
|
||||||
"""Main benchmark runner."""
|
|
||||||
|
|
||||||
def __init__(self, config: BenchmarkConfig):
|
|
||||||
self.config = config
|
|
||||||
try:
|
|
||||||
self.model = self._load_model()
|
|
||||||
if self.model is None:
|
|
||||||
raise ValueError("Model initialization failed - model is None")
|
|
||||||
|
|
||||||
self.cuda_graphs = (
|
|
||||||
CUDAGraphContainer(self.model, config.seq_length)
|
|
||||||
if config.use_cuda_graphs
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
self.timer = Timer()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"ERROR in benchmark initialization: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _load_model(self) -> nn.Module:
|
|
||||||
print(f"Loading model from {self.config.model_path}...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Int4 quantization using HuggingFace integration
|
|
||||||
if self.config.use_int4:
|
|
||||||
import bitsandbytes as bnb
|
|
||||||
print(f"- bitsandbytes version: {bnb.__version__}")
|
|
||||||
|
|
||||||
# 检查是否使用自定义的8bit量化
|
|
||||||
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt:
|
|
||||||
print("- Using custom Linear8bitLt replacement for all linear layers")
|
|
||||||
|
|
||||||
# 加载原始模型(不使用量化配置)
|
|
||||||
import bitsandbytes as bnb
|
|
||||||
import torch
|
|
||||||
# set default to half
|
|
||||||
torch.set_default_dtype(torch.float16)
|
|
||||||
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
|
|
||||||
model = AutoModel.from_pretrained(
|
|
||||||
self.config.model_path,
|
|
||||||
torch_dtype=compute_dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 定义替换函数
|
|
||||||
def replace_linear_with_linear8bitlt(model):
|
|
||||||
"""递归地将模型中的所有nn.Linear层替换为Linear8bitLt"""
|
|
||||||
for name, module in list(model.named_children()):
|
|
||||||
if isinstance(module, nn.Linear):
|
|
||||||
# 获取原始线性层的参数
|
|
||||||
in_features = module.in_features
|
|
||||||
out_features = module.out_features
|
|
||||||
bias = module.bias is not None
|
|
||||||
|
|
||||||
# 创建8bit线性层
|
|
||||||
# print size
|
|
||||||
print(f"in_features: {in_features}, out_features: {out_features}")
|
|
||||||
new_module = bnb.nn.Linear8bitLt(
|
|
||||||
in_features,
|
|
||||||
out_features,
|
|
||||||
bias=bias,
|
|
||||||
has_fp16_weights=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# 复制权重和偏置
|
|
||||||
new_module.weight.data = module.weight.data
|
|
||||||
if bias:
|
|
||||||
new_module.bias.data = module.bias.data
|
|
||||||
|
|
||||||
# 替换模块
|
|
||||||
setattr(model, name, new_module)
|
|
||||||
else:
|
|
||||||
# 递归处理子模块
|
|
||||||
replace_linear_with_linear8bitlt(module)
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
# 替换所有线性层
|
|
||||||
model = replace_linear_with_linear8bitlt(model)
|
|
||||||
# add torch compile
|
|
||||||
model = torch.compile(model)
|
|
||||||
|
|
||||||
# 将模型移到GPU(量化发生在这里)
|
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
model = model.to(device)
|
|
||||||
|
|
||||||
print("- All linear layers replaced with Linear8bitLt")
|
|
||||||
|
|
||||||
else:
|
|
||||||
# 使用原来的Int4量化方法
|
|
||||||
print("- Using bitsandbytes for Int4 quantization")
|
|
||||||
|
|
||||||
# Create quantization config
|
|
||||||
|
|
||||||
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
|
|
||||||
quantization_config = BitsAndBytesConfig(
|
|
||||||
load_in_4bit=True,
|
|
||||||
bnb_4bit_compute_dtype=compute_dtype,
|
|
||||||
bnb_4bit_use_double_quant=True,
|
|
||||||
bnb_4bit_quant_type="nf4"
|
|
||||||
)
|
|
||||||
|
|
||||||
print("- Quantization config:", quantization_config)
|
|
||||||
|
|
||||||
# Load model directly with quantization config
|
|
||||||
model = AutoModel.from_pretrained(
|
|
||||||
self.config.model_path,
|
|
||||||
quantization_config=quantization_config,
|
|
||||||
torch_dtype=compute_dtype,
|
|
||||||
device_map="auto" # Let HF decide on device mapping
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check if model loaded successfully
|
|
||||||
if model is None:
|
|
||||||
raise ValueError("Model loading returned None")
|
|
||||||
|
|
||||||
print(f"- Model type: {type(model)}")
|
|
||||||
|
|
||||||
# Apply optimizations directly here
|
|
||||||
print("\nApplying model optimizations:")
|
|
||||||
|
|
||||||
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt:
|
|
||||||
print("- Model moved to GPU with Linear8bitLt quantization")
|
|
||||||
else:
|
|
||||||
# Skip moving to GPU since device_map="auto" already did that
|
|
||||||
print("- Model already on GPU due to device_map='auto'")
|
|
||||||
|
|
||||||
# Skip FP16 conversion since we specified compute_dtype
|
|
||||||
print(f"- Using {compute_dtype} for compute dtype")
|
|
||||||
|
|
||||||
# Check CUDA and SDPA
|
|
||||||
if torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
|
|
||||||
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
|
||||||
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
|
||||||
else:
|
|
||||||
print("- PyTorch SDPA not available")
|
|
||||||
|
|
||||||
# Try xformers if available
|
|
||||||
try:
|
|
||||||
from xformers.ops import memory_efficient_attention
|
|
||||||
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
|
|
||||||
model.enable_xformers_memory_efficient_attention()
|
|
||||||
print("- Enabled xformers memory efficient attention")
|
|
||||||
else:
|
|
||||||
print("- Model doesn't support xformers")
|
|
||||||
except (ImportError, AttributeError):
|
|
||||||
print("- Xformers not available")
|
|
||||||
|
|
||||||
# Set to eval mode
|
|
||||||
model.eval()
|
|
||||||
print("- Model set to eval mode")
|
|
||||||
# Int8 quantization using HuggingFace integration
|
|
||||||
# Int8 quantization using TorchAO
|
|
||||||
elif self.config.use_int8:
|
|
||||||
print("- Using TorchAO for Int8 dynamic activation and Int8 weight quantization")
|
|
||||||
|
|
||||||
# Import the quantize_ function and the quantization config
|
|
||||||
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
|
|
||||||
print("- Successfully imported TorchAO")
|
|
||||||
|
|
||||||
# Load model normally first
|
|
||||||
# set default to half
|
|
||||||
import torch
|
|
||||||
torch.set_default_dtype(torch.bfloat16)
|
|
||||||
model = AutoModel.from_pretrained(
|
|
||||||
self.config.model_path,
|
|
||||||
device_map="auto"
|
|
||||||
)
|
|
||||||
|
|
||||||
print("- Model loaded in full precision")
|
|
||||||
print(f"- Model type: {type(model)}")
|
|
||||||
|
|
||||||
# Apply quantization - call the function to get the config, then apply it
|
|
||||||
# quantize_(model, int8_dynamic_activation_int8_weight())
|
|
||||||
# from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig,int8_dynamic_activation_int8_semi_sparse_weight,int4_weight_only,Int8DynActInt4WeightGPTQQuantizer,int8_dynamic_activation_int4_weight,Int8DynamicActivationInt4WeightConfig,Int4DynamicActivationInt4WeightConfig
|
|
||||||
from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig
|
|
||||||
quantize_(model, Int8DynamicActivationInt8WeightConfig())
|
|
||||||
print("- Model successfully quantized with int8 weights and int8 activations")
|
|
||||||
# add torch compile
|
|
||||||
model = torch.compile(model)
|
|
||||||
# For older PyTorch versions that have issues with tensor subclasses
|
|
||||||
from torchao.utils import unwrap_tensor_subclass
|
|
||||||
import torch
|
|
||||||
if hasattr(torch, '_version') and not torch.version >= "2.5.0":
|
|
||||||
print("- Unwrapping tensor subclasses for compatibility with older PyTorch")
|
|
||||||
unwrap_tensor_subclass(model)
|
|
||||||
|
|
||||||
# Apply optimizations
|
|
||||||
if torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
|
|
||||||
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
|
||||||
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
|
||||||
else:
|
|
||||||
print("- PyTorch SDPA not available")
|
|
||||||
|
|
||||||
# Set to eval mode
|
|
||||||
model.eval()
|
|
||||||
print("- Model set to eval mode")
|
|
||||||
|
|
||||||
# For better performance with int8 dynamic quantization
|
|
||||||
torch._inductor.config.force_fuse_int_mm_with_mul = True
|
|
||||||
print("- Enabled fusion of int matmul with mul operations")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Standard loading for FP16/FP32
|
|
||||||
model = AutoModel.from_pretrained(self.config.model_path)
|
|
||||||
print("- Model loaded in standard precision")
|
|
||||||
print(f"- Model type: {type(model)}")
|
|
||||||
|
|
||||||
# Apply standard optimizations
|
|
||||||
# set default to half
|
|
||||||
import torch
|
|
||||||
torch.set_default_dtype(torch.bfloat16)
|
|
||||||
model = ModelOptimizer.optimize(model, self.config)
|
|
||||||
model = model.half()
|
|
||||||
# add torch compile
|
|
||||||
model = torch.compile(model)
|
|
||||||
|
|
||||||
# Final check to ensure model is not None
|
|
||||||
if model is None:
|
|
||||||
raise ValueError("Model is None after optimization")
|
|
||||||
|
|
||||||
print(f"- Final model type: {type(model)}")
|
|
||||||
return model
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"ERROR loading model: {str(e)}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
|
|
||||||
return torch.randint(
|
|
||||||
0, 1000,
|
|
||||||
(batch_size, self.config.seq_length),
|
|
||||||
device="cuda",
|
|
||||||
dtype=torch.long
|
|
||||||
)
|
|
||||||
|
|
||||||
def _run_inference(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
cuda_graph_wrapper: Optional[CUDAGraphWrapper] = None
|
|
||||||
) -> Tuple[float, torch.Tensor]:
|
|
||||||
attention_mask = torch.ones_like(input_ids)
|
|
||||||
|
|
||||||
with torch.no_grad(), self.timer.timing():
|
|
||||||
if cuda_graph_wrapper is not None:
|
|
||||||
output = cuda_graph_wrapper(input_ids, attention_mask)
|
|
||||||
else:
|
|
||||||
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
|
||||||
|
|
||||||
return self.timer.elapsed_time(), output
|
|
||||||
|
|
||||||
def run(self) -> Dict[int, Dict[str, float]]:
|
|
||||||
results = {}
|
|
||||||
|
|
||||||
# Reset peak memory stats
|
|
||||||
torch.cuda.reset_peak_memory_stats()
|
|
||||||
|
|
||||||
for batch_size in self.config.batch_sizes:
|
|
||||||
print(f"\nTesting batch size: {batch_size}")
|
|
||||||
times = []
|
|
||||||
|
|
||||||
# Get or create CUDA graph for this batch size
|
|
||||||
cuda_graph_wrapper = (
|
|
||||||
self.cuda_graphs.get_or_create(batch_size)
|
|
||||||
if self.cuda_graphs is not None
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
# Pre-allocate input tensor
|
|
||||||
input_ids = self._create_random_batch(batch_size)
|
|
||||||
print(f"Input shape: {input_ids.shape}")
|
|
||||||
|
|
||||||
# Run benchmark
|
|
||||||
for i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
|
|
||||||
try:
|
|
||||||
elapsed_time, output = self._run_inference(input_ids, cuda_graph_wrapper)
|
|
||||||
if i == 0: # Only print on first run
|
|
||||||
print(f"Output shape: {output.last_hidden_state.shape}")
|
|
||||||
times.append(elapsed_time)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error during inference: {e}")
|
|
||||||
break
|
|
||||||
|
|
||||||
if not times:
|
|
||||||
print(f"No successful runs for batch size {batch_size}, skipping")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Calculate statistics
|
|
||||||
avg_time = np.mean(times)
|
|
||||||
std_time = np.std(times)
|
|
||||||
throughput = batch_size / avg_time
|
|
||||||
|
|
||||||
results[batch_size] = {
|
|
||||||
"avg_time": avg_time,
|
|
||||||
"std_time": std_time,
|
|
||||||
"throughput": throughput,
|
|
||||||
}
|
|
||||||
|
|
||||||
print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
|
|
||||||
print(f"Throughput: {throughput:.2f} sequences/second")
|
|
||||||
|
|
||||||
# Log memory usage
|
|
||||||
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)
|
|
||||||
print(f"\nPeak GPU memory usage: {peak_memory_gb:.2f} GB")
|
|
||||||
|
|
||||||
# Add memory info to results
|
|
||||||
for batch_size in results:
|
|
||||||
results[batch_size]["peak_memory_gb"] = peak_memory_gb
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(description="Model Inference Benchmark")
|
|
||||||
parser.add_argument(
|
|
||||||
"--model_path",
|
|
||||||
type=str,
|
|
||||||
default="facebook/contriever",
|
|
||||||
help="Path to the model",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--batch_sizes",
|
|
||||||
type=str,
|
|
||||||
default="1,2,4,8,10,16,20,32,40,64,128,256,512,1024,2048,4096,8192",
|
|
||||||
help="Comma-separated list of batch sizes",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--seq_length",
|
|
||||||
type=int,
|
|
||||||
default=256,
|
|
||||||
help="Sequence length for input",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--num_runs",
|
|
||||||
type=int,
|
|
||||||
default=5,
|
|
||||||
help="Number of runs for each batch size",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_fp16",
|
|
||||||
action="store_true",
|
|
||||||
help="Enable FP16 inference",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_int4",
|
|
||||||
action="store_true",
|
|
||||||
help="Enable INT4 quantization using bitsandbytes",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_int8",
|
|
||||||
action="store_true",
|
|
||||||
help="Enable INT8 quantization for both activations and weights using bitsandbytes",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_cuda_graphs",
|
|
||||||
action="store_true",
|
|
||||||
help="Enable CUDA Graphs optimization",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_flash_attention",
|
|
||||||
action="store_true",
|
|
||||||
help="Enable Flash Attention 2 if available",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_linear8bitlt",
|
|
||||||
action="store_true",
|
|
||||||
help="Enable Linear8bitLt quantization for all linear layers",
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Print arguments for debugging
|
|
||||||
print("\nCommand line arguments:")
|
|
||||||
for arg, value in vars(args).items():
|
|
||||||
print(f"- {arg}: {value}")
|
|
||||||
|
|
||||||
config = BenchmarkConfig(
|
|
||||||
model_path=args.model_path,
|
|
||||||
batch_sizes=[int(bs) for bs in args.batch_sizes.split(",")],
|
|
||||||
seq_length=args.seq_length,
|
|
||||||
num_runs=args.num_runs,
|
|
||||||
use_fp16=args.use_fp16,
|
|
||||||
use_int4=args.use_int4,
|
|
||||||
use_int8=args.use_int8, # Add this line
|
|
||||||
use_cuda_graphs=args.use_cuda_graphs,
|
|
||||||
use_flash_attention=args.use_flash_attention,
|
|
||||||
use_linear8bitlt=args.use_linear8bitlt,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Print configuration for debugging
|
|
||||||
print("\nBenchmark configuration:")
|
|
||||||
for field, value in vars(config).items():
|
|
||||||
print(f"- {field}: {value}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
benchmark = Benchmark(config)
|
|
||||||
results = benchmark.run()
|
|
||||||
|
|
||||||
# Save results to file
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
|
|
||||||
# Create results directory if it doesn't exist
|
|
||||||
os.makedirs("results", exist_ok=True)
|
|
||||||
|
|
||||||
# Generate filename based on configuration
|
|
||||||
precision_type = "int4" if config.use_int4 else "fp16" if config.use_fp16 else "fp32"
|
|
||||||
model_name = os.path.basename(config.model_path)
|
|
||||||
output_file = f"results/benchmark_{model_name}_{precision_type}.json"
|
|
||||||
|
|
||||||
# Save results
|
|
||||||
with open(output_file, "w") as f:
|
|
||||||
json.dump(
|
|
||||||
{
|
|
||||||
"config": {k: str(v) if isinstance(v, list) else v for k, v in vars(config).items()},
|
|
||||||
"results": {str(k): v for k, v in results.items()}
|
|
||||||
},
|
|
||||||
f,
|
|
||||||
indent=2
|
|
||||||
)
|
|
||||||
print(f"Results saved to {output_file}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Benchmark failed: {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,376 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import time
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from transformers import AutoModel
|
|
||||||
from tqdm import tqdm
|
|
||||||
from contextlib import contextmanager
|
|
||||||
import math
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BenchmarkConfig:
|
|
||||||
model_path: str
|
|
||||||
batch_sizes: List[int]
|
|
||||||
seq_length: int
|
|
||||||
num_runs: int
|
|
||||||
use_fp16: bool = True
|
|
||||||
use_cuda_graphs: bool = False
|
|
||||||
use_flash_attention: bool = False
|
|
||||||
max_batch_size: int = 256 # Maximum batch size before splitting
|
|
||||||
|
|
||||||
|
|
||||||
class CUDAGraphContainer:
|
|
||||||
"""Container for managing CUDA graphs for different batch sizes."""
|
|
||||||
|
|
||||||
def __init__(self, model: nn.Module, seq_length: int, max_batch_size: int):
|
|
||||||
self.model = model
|
|
||||||
self.seq_length = seq_length
|
|
||||||
self.max_batch_size = max_batch_size
|
|
||||||
self.graphs: Dict[int, CUDAGraphWrapper] = {}
|
|
||||||
|
|
||||||
def get_or_create(self, batch_size: int) -> 'CUDAGraphWrapper':
|
|
||||||
# For CUDA graphs, we always use the actual batch size or max_batch_size
|
|
||||||
effective_batch_size = min(batch_size, self.max_batch_size)
|
|
||||||
|
|
||||||
if effective_batch_size not in self.graphs:
|
|
||||||
self.graphs[effective_batch_size] = CUDAGraphWrapper(
|
|
||||||
self.model, effective_batch_size, self.seq_length
|
|
||||||
)
|
|
||||||
return self.graphs[effective_batch_size]
|
|
||||||
|
|
||||||
|
|
||||||
class CUDAGraphWrapper:
|
|
||||||
"""Wrapper for CUDA graph capture and replay."""
|
|
||||||
|
|
||||||
def __init__(self, model: nn.Module, batch_size: int, seq_length: int):
|
|
||||||
self.model = model
|
|
||||||
self.static_input = self._create_random_batch(batch_size, seq_length)
|
|
||||||
self.static_attention_mask = torch.ones_like(self.static_input)
|
|
||||||
|
|
||||||
# Warm up
|
|
||||||
self._warmup()
|
|
||||||
|
|
||||||
# Capture graph
|
|
||||||
self.graph = torch.cuda.CUDAGraph()
|
|
||||||
with torch.cuda.graph(self.graph):
|
|
||||||
self.static_output = self.model(
|
|
||||||
input_ids=self.static_input,
|
|
||||||
attention_mask=self.static_attention_mask
|
|
||||||
)
|
|
||||||
|
|
||||||
def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
|
|
||||||
return torch.randint(
|
|
||||||
0, 1000, (batch_size, seq_length),
|
|
||||||
device="cuda",
|
|
||||||
dtype=torch.long
|
|
||||||
)
|
|
||||||
|
|
||||||
def _warmup(self, num_warmup: int = 3):
|
|
||||||
with torch.no_grad():
|
|
||||||
for _ in range(num_warmup):
|
|
||||||
self.model(
|
|
||||||
input_ids=self.static_input,
|
|
||||||
attention_mask=self.static_attention_mask
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
|
||||||
self.static_input.copy_(input_ids)
|
|
||||||
self.static_attention_mask.copy_(attention_mask)
|
|
||||||
self.graph.replay()
|
|
||||||
return self.static_output
|
|
||||||
|
|
||||||
|
|
||||||
class ModelOptimizer:
|
|
||||||
"""Applies various optimizations to the model."""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def optimize(model: nn.Module, config: BenchmarkConfig) -> nn.Module:
|
|
||||||
print("\nApplying model optimizations:")
|
|
||||||
|
|
||||||
# Move to GPU
|
|
||||||
model = model.cuda()
|
|
||||||
print("- Model moved to GPU")
|
|
||||||
|
|
||||||
# FP16
|
|
||||||
if config.use_fp16:
|
|
||||||
model = model.half()
|
|
||||||
print("- Using FP16 precision")
|
|
||||||
|
|
||||||
# Check if using SDPA
|
|
||||||
if torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
|
|
||||||
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
|
||||||
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
|
||||||
# No need to do anything as it's automatically enabled
|
|
||||||
else:
|
|
||||||
print("- PyTorch SDPA not available")
|
|
||||||
|
|
||||||
# Flash Attention
|
|
||||||
if config.use_flash_attention:
|
|
||||||
try:
|
|
||||||
from flash_attn.flash_attention import FlashAttention
|
|
||||||
print("- Flash Attention 2 available")
|
|
||||||
if hasattr(model.config, "attention_mode"):
|
|
||||||
model.config.attention_mode = "flash_attention_2"
|
|
||||||
print(" - Enabled Flash Attention 2 mode")
|
|
||||||
except ImportError:
|
|
||||||
print("- Flash Attention not available")
|
|
||||||
|
|
||||||
# Optimize LayerNorm
|
|
||||||
try:
|
|
||||||
num_layernorms = 0
|
|
||||||
for module in model.modules():
|
|
||||||
if isinstance(module, torch.nn.LayerNorm):
|
|
||||||
module.forward = torch.jit.script(module.forward)
|
|
||||||
num_layernorms += 1
|
|
||||||
if num_layernorms > 0:
|
|
||||||
print(f"- Optimized {num_layernorms} LayerNorm modules with TorchScript")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"- LayerNorm optimization failed: {e}")
|
|
||||||
|
|
||||||
# Memory efficient attention
|
|
||||||
try:
|
|
||||||
from xformers.ops import memory_efficient_attention
|
|
||||||
model.enable_xformers_memory_efficient_attention()
|
|
||||||
print("- Enabled xformers memory efficient attention")
|
|
||||||
except (ImportError, AttributeError):
|
|
||||||
print("- Xformers not available")
|
|
||||||
|
|
||||||
model.eval()
|
|
||||||
print("- Model set to eval mode")
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
class Timer:
|
|
||||||
"""Handles accurate GPU timing using CUDA events."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.start_event = torch.cuda.Event(enable_timing=True)
|
|
||||||
self.end_event = torch.cuda.Event(enable_timing=True)
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def timing(self):
|
|
||||||
self.start_event.record()
|
|
||||||
yield
|
|
||||||
self.end_event.record()
|
|
||||||
self.end_event.synchronize()
|
|
||||||
|
|
||||||
def elapsed_time(self) -> float:
|
|
||||||
return self.start_event.elapsed_time(self.end_event) / 1000 # ms to seconds
|
|
||||||
|
|
||||||
|
|
||||||
class Benchmark:
|
|
||||||
"""Main benchmark runner."""
|
|
||||||
|
|
||||||
def __init__(self, config: BenchmarkConfig):
|
|
||||||
self.config = config
|
|
||||||
self.model = self._load_model()
|
|
||||||
self.cuda_graphs = (
|
|
||||||
CUDAGraphContainer(self.model, config.seq_length, config.max_batch_size)
|
|
||||||
if config.use_cuda_graphs
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
self.timer = Timer()
|
|
||||||
|
|
||||||
def _load_model(self) -> nn.Module:
|
|
||||||
print(f"Loading model from {self.config.model_path}...")
|
|
||||||
model = AutoModel.from_pretrained(self.config.model_path)
|
|
||||||
return ModelOptimizer.optimize(model, self.config)
|
|
||||||
|
|
||||||
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
|
|
||||||
return torch.randint(
|
|
||||||
0, 1000,
|
|
||||||
(batch_size, self.config.seq_length),
|
|
||||||
device="cuda",
|
|
||||||
dtype=torch.long
|
|
||||||
)
|
|
||||||
|
|
||||||
def _run_inference(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
cuda_graph_wrapper: Optional[CUDAGraphWrapper] = None
|
|
||||||
) -> Tuple[float, torch.Tensor]:
|
|
||||||
attention_mask = torch.ones_like(input_ids)
|
|
||||||
original_batch_size = input_ids.shape[0]
|
|
||||||
print(f"Original input_ids shape: {input_ids.shape}")
|
|
||||||
|
|
||||||
# Split large batches to avoid OOM
|
|
||||||
max_batch_size = self.config.max_batch_size
|
|
||||||
if original_batch_size > max_batch_size:
|
|
||||||
print(f"Splitting batch of size {original_batch_size} into chunks of {max_batch_size}")
|
|
||||||
total_time = 0
|
|
||||||
outputs = []
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
for i in range(0, original_batch_size, max_batch_size):
|
|
||||||
end_idx = min(i + max_batch_size, original_batch_size)
|
|
||||||
batch_slice = input_ids[i:end_idx]
|
|
||||||
mask_slice = attention_mask[i:end_idx]
|
|
||||||
|
|
||||||
print(f"Processing chunk {i//max_batch_size + 1}: shape {batch_slice.shape}")
|
|
||||||
|
|
||||||
# Use CUDA graph if available (with the smaller batch size)
|
|
||||||
chunk_cuda_graph = None
|
|
||||||
if cuda_graph_wrapper is not None:
|
|
||||||
chunk_cuda_graph = self.cuda_graphs.get_or_create(batch_slice.shape[0])
|
|
||||||
|
|
||||||
with self.timer.timing():
|
|
||||||
if chunk_cuda_graph is not None:
|
|
||||||
chunk_output = chunk_cuda_graph(batch_slice, mask_slice)
|
|
||||||
else:
|
|
||||||
chunk_output = self.model(input_ids=batch_slice, attention_mask=mask_slice)
|
|
||||||
|
|
||||||
total_time += self.timer.elapsed_time()
|
|
||||||
outputs.append(chunk_output.last_hidden_state)
|
|
||||||
|
|
||||||
# Combine outputs
|
|
||||||
combined_output = torch.cat(outputs, dim=0)
|
|
||||||
print(f"Combined output shape: {combined_output.shape}")
|
|
||||||
|
|
||||||
# Create a wrapper object similar to model output to maintain consistency
|
|
||||||
class DummyOutput:
|
|
||||||
def __init__(self, hidden_states):
|
|
||||||
self.last_hidden_state = hidden_states
|
|
||||||
|
|
||||||
output = DummyOutput(combined_output)
|
|
||||||
return total_time, output
|
|
||||||
else:
|
|
||||||
# Process normally for small batches
|
|
||||||
with torch.no_grad(), self.timer.timing():
|
|
||||||
if cuda_graph_wrapper is not None:
|
|
||||||
output = cuda_graph_wrapper(input_ids, attention_mask)
|
|
||||||
else:
|
|
||||||
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
|
||||||
|
|
||||||
print(f"Output shape: {output.last_hidden_state.shape}")
|
|
||||||
return self.timer.elapsed_time(), output
|
|
||||||
|
|
||||||
def run(self) -> Dict[int, Dict[str, float]]:
|
|
||||||
results = {}
|
|
||||||
|
|
||||||
for batch_size in self.config.batch_sizes:
|
|
||||||
print(f"\nTesting batch size: {batch_size}")
|
|
||||||
times = []
|
|
||||||
|
|
||||||
# Get or create CUDA graph for this batch size
|
|
||||||
cuda_graph_wrapper = None
|
|
||||||
if self.cuda_graphs is not None:
|
|
||||||
if batch_size <= self.config.max_batch_size:
|
|
||||||
cuda_graph_wrapper = self.cuda_graphs.get_or_create(batch_size)
|
|
||||||
else:
|
|
||||||
# For large batches, we'll use the max_batch_size graph in chunks
|
|
||||||
cuda_graph_wrapper = True # Just a flag to indicate we want to use CUDA graphs
|
|
||||||
|
|
||||||
# Pre-allocate input tensor
|
|
||||||
input_ids = self._create_random_batch(batch_size)
|
|
||||||
|
|
||||||
# Run benchmark
|
|
||||||
for run_idx in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
|
|
||||||
elapsed_time, _ = self._run_inference(input_ids, cuda_graph_wrapper)
|
|
||||||
times.append(elapsed_time)
|
|
||||||
print(f"Run {run_idx+1}: {elapsed_time:.4f}s")
|
|
||||||
|
|
||||||
# Calculate statistics
|
|
||||||
avg_time = np.mean(times)
|
|
||||||
std_time = np.std(times)
|
|
||||||
throughput = batch_size / avg_time
|
|
||||||
|
|
||||||
results[batch_size] = {
|
|
||||||
"avg_time": avg_time,
|
|
||||||
"std_time": std_time,
|
|
||||||
"throughput": throughput,
|
|
||||||
}
|
|
||||||
|
|
||||||
print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
|
|
||||||
print(f"Throughput: {throughput:.2f} sequences/second")
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(description="Model Inference Benchmark")
|
|
||||||
parser.add_argument(
|
|
||||||
"--model_path",
|
|
||||||
type=str,
|
|
||||||
default="facebook/contriever",
|
|
||||||
help="Path to the model",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--batch_sizes",
|
|
||||||
type=str,
|
|
||||||
default="1,2,4,8,16,32,64,128,256,512,1024,2048,4096",
|
|
||||||
help="Comma-separated list of batch sizes",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--seq_length",
|
|
||||||
type=int,
|
|
||||||
default=256,
|
|
||||||
help="Sequence length for input",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--num_runs",
|
|
||||||
type=int,
|
|
||||||
default=5,
|
|
||||||
help="Number of runs for each batch size",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--no_fp16",
|
|
||||||
action="store_true",
|
|
||||||
help="Disable FP16 inference",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_cuda_graphs",
|
|
||||||
action="store_true",
|
|
||||||
help="Enable CUDA Graphs optimization",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--use_flash_attention",
|
|
||||||
action="store_true",
|
|
||||||
help="Enable Flash Attention 2 if available",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--max_batch_size",
|
|
||||||
type=int,
|
|
||||||
default=256,
|
|
||||||
help="Maximum batch size before splitting to prevent OOM",
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
config = BenchmarkConfig(
|
|
||||||
model_path=args.model_path,
|
|
||||||
batch_sizes=[int(bs) for bs in args.batch_sizes.split(",")],
|
|
||||||
seq_length=args.seq_length,
|
|
||||||
num_runs=args.num_runs,
|
|
||||||
use_fp16=not args.no_fp16,
|
|
||||||
use_cuda_graphs=args.use_cuda_graphs,
|
|
||||||
use_flash_attention=args.use_flash_attention,
|
|
||||||
max_batch_size=args.max_batch_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
benchmark = Benchmark(config)
|
|
||||||
results = benchmark.run()
|
|
||||||
|
|
||||||
# Print overall summary
|
|
||||||
print("\n===== BENCHMARK SUMMARY =====")
|
|
||||||
print(f"Model: {config.model_path}")
|
|
||||||
print(f"Sequence Length: {config.seq_length}")
|
|
||||||
print(f"FP16: {config.use_fp16}")
|
|
||||||
print(f"CUDA Graphs: {config.use_cuda_graphs}")
|
|
||||||
print(f"Flash Attention: {config.use_flash_attention}")
|
|
||||||
print(f"Max Batch Size: {config.max_batch_size}")
|
|
||||||
print("\nResults:")
|
|
||||||
|
|
||||||
print("\nBatch Size | Avg Time (s) | Throughput (seq/s)")
|
|
||||||
print("-" * 50)
|
|
||||||
for bs in sorted(results.keys()):
|
|
||||||
r = results[bs]
|
|
||||||
print(f"{bs:^10} | {r['avg_time']:^12.4f} | {r['throughput']:^17.2f}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,218 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import time
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
# Import necessary functions from the quantize.py file
|
|
||||||
def get_group_qparams(w, n_bit=4, groupsize=128):
|
|
||||||
# needed for GPTQ with padding
|
|
||||||
if groupsize > w.shape[-1]:
|
|
||||||
groupsize = w.shape[-1]
|
|
||||||
assert groupsize > 1
|
|
||||||
assert w.shape[-1] % groupsize == 0
|
|
||||||
assert w.dim() == 2
|
|
||||||
|
|
||||||
to_quant = w.reshape(-1, groupsize)
|
|
||||||
assert torch.isnan(to_quant).sum() == 0
|
|
||||||
|
|
||||||
max_val = to_quant.amax(dim=1, keepdim=True)
|
|
||||||
min_val = to_quant.amin(dim=1, keepdim=True)
|
|
||||||
max_int = 2**n_bit - 1
|
|
||||||
scales = (max_val - min_val).clamp(min=1e-6) / max_int
|
|
||||||
zeros = min_val + scales * (2 ** (n_bit - 1))
|
|
||||||
return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(
|
|
||||||
torch.bfloat16
|
|
||||||
).reshape(w.shape[0], -1)
|
|
||||||
|
|
||||||
def pack_scales_and_zeros(scales, zeros):
|
|
||||||
assert scales.shape == zeros.shape
|
|
||||||
assert scales.dtype == torch.bfloat16
|
|
||||||
assert zeros.dtype == torch.bfloat16
|
|
||||||
return (
|
|
||||||
torch.cat(
|
|
||||||
[
|
|
||||||
scales.reshape(scales.size(0), scales.size(1), 1),
|
|
||||||
zeros.reshape(zeros.size(0), zeros.size(1), 1),
|
|
||||||
],
|
|
||||||
2,
|
|
||||||
)
|
|
||||||
.transpose(0, 1)
|
|
||||||
.contiguous()
|
|
||||||
)
|
|
||||||
|
|
||||||
def group_quantize_tensor(w, n_bit=4, groupsize=128):
|
|
||||||
scales, zeros = get_group_qparams(w, n_bit, groupsize)
|
|
||||||
w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
|
|
||||||
scales_and_zeros = pack_scales_and_zeros(scales, zeros)
|
|
||||||
return w_int32, scales_and_zeros
|
|
||||||
|
|
||||||
def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
|
|
||||||
assert groupsize > 1
|
|
||||||
# needed for GPTQ single column quantize
|
|
||||||
if groupsize > w.shape[-1] and scales.shape[-1] == 1:
|
|
||||||
groupsize = w.shape[-1]
|
|
||||||
|
|
||||||
assert w.shape[-1] % groupsize == 0
|
|
||||||
assert w.dim() == 2
|
|
||||||
|
|
||||||
to_quant = w.reshape(-1, groupsize)
|
|
||||||
assert torch.isnan(to_quant).sum() == 0
|
|
||||||
|
|
||||||
scales = scales.reshape(-1, 1)
|
|
||||||
zeros = zeros.reshape(-1, 1)
|
|
||||||
min_val = zeros - scales * (2 ** (n_bit - 1))
|
|
||||||
max_int = 2**n_bit - 1
|
|
||||||
min_int = 0
|
|
||||||
w_int32 = (
|
|
||||||
to_quant.sub(min_val)
|
|
||||||
.div(scales)
|
|
||||||
.round()
|
|
||||||
.clamp_(min_int, max_int)
|
|
||||||
.to(torch.int32)
|
|
||||||
.reshape_as(w)
|
|
||||||
)
|
|
||||||
|
|
||||||
return w_int32
|
|
||||||
|
|
||||||
def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
|
|
||||||
weight_int32, scales_and_zeros = group_quantize_tensor(
|
|
||||||
weight_bf16, n_bit=4, groupsize=groupsize
|
|
||||||
)
|
|
||||||
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
|
|
||||||
return weight_int4pack, scales_and_zeros
|
|
||||||
|
|
||||||
def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
|
|
||||||
origin_x_size = x.size()
|
|
||||||
x = x.reshape(-1, origin_x_size[-1])
|
|
||||||
c = torch.ops.aten._weight_int4pack_mm(x, weight_int4pack, groupsize, scales_and_zeros)
|
|
||||||
new_shape = origin_x_size[:-1] + (out_features,)
|
|
||||||
c = c.reshape(new_shape)
|
|
||||||
return c
|
|
||||||
|
|
||||||
class WeightOnlyInt4Linear(torch.nn.Module):
|
|
||||||
__constants__ = ['in_features', 'out_features']
|
|
||||||
in_features: int
|
|
||||||
out_features: int
|
|
||||||
weight: torch.Tensor
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, in_features: int, out_features: int,
|
|
||||||
bias=False, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.in_features = in_features
|
|
||||||
self.out_features = out_features
|
|
||||||
self.groupsize = groupsize
|
|
||||||
self.inner_k_tiles = inner_k_tiles
|
|
||||||
|
|
||||||
assert out_features % 8 == 0, "require out_features % 8 == 0"
|
|
||||||
assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
|
|
||||||
self.register_buffer(
|
|
||||||
"weight",
|
|
||||||
torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
|
|
||||||
)
|
|
||||||
self.register_buffer(
|
|
||||||
"scales_and_zeros",
|
|
||||||
torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16)
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
|
||||||
input = input.to(torch.bfloat16)
|
|
||||||
return linear_forward_int4(
|
|
||||||
input,
|
|
||||||
self.weight, self.scales_and_zeros, self.out_features, self.groupsize
|
|
||||||
)
|
|
||||||
|
|
||||||
# Define dimensions that satisfy the requirements for INT4 quantization
|
|
||||||
# in_features must be divisible by inner_k_tiles * 16
|
|
||||||
# out_features must be divisible by 8
|
|
||||||
in_features = 1024 # Must be divisible by inner_k_tiles * 16
|
|
||||||
out_features = 2048 # Must be divisible by 8
|
|
||||||
groupsize = 128
|
|
||||||
inner_k_tiles = 8
|
|
||||||
|
|
||||||
# Create models
|
|
||||||
fp16_model = nn.Sequential(
|
|
||||||
nn.Linear(in_features, out_features, bias=False)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create INT4 model
|
|
||||||
int4_model = nn.Sequential(
|
|
||||||
WeightOnlyInt4Linear(in_features, out_features, bias=False,
|
|
||||||
groupsize=groupsize, inner_k_tiles=inner_k_tiles)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Quantize the weights and set up the INT4 model
|
|
||||||
with torch.no_grad():
|
|
||||||
# Convert FP16 weights to INT4
|
|
||||||
fp16_weight = fp16_model[0].weight.data.to(torch.bfloat16)
|
|
||||||
weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros(
|
|
||||||
fp16_weight, groupsize, inner_k_tiles
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set the quantized weights in the INT4 model
|
|
||||||
int4_model[0].weight.copy_(weight_int4pack)
|
|
||||||
int4_model[0].scales_and_zeros.copy_(scales_and_zeros)
|
|
||||||
|
|
||||||
# Move models to GPU
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
fp16_model = fp16_model.to(device)
|
|
||||||
int4_model = int4_model.to(device)
|
|
||||||
|
|
||||||
# Create random input tensor
|
|
||||||
batch_size = 1024
|
|
||||||
input_tensor = torch.randn(batch_size, in_features, device=device)
|
|
||||||
input_tensor_bf16 = input_tensor.to(torch.bfloat16)
|
|
||||||
|
|
||||||
# Speed test function
|
|
||||||
def speed_test(model, input_tensor, name, num_iterations=100):
|
|
||||||
# Warmup
|
|
||||||
for _ in range(10):
|
|
||||||
_ = model(input_tensor)
|
|
||||||
|
|
||||||
# Actual timing
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
for _ in range(num_iterations):
|
|
||||||
_ = model(input_tensor)
|
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
end_time = time.time()
|
|
||||||
|
|
||||||
avg_time = (end_time - start_time) / num_iterations
|
|
||||||
print(f"{name} model: {avg_time:.6f} seconds per iteration")
|
|
||||||
return avg_time
|
|
||||||
|
|
||||||
# Run speed tests
|
|
||||||
with torch.no_grad(): # Disable gradient calculation for inference
|
|
||||||
print(f"Running benchmark with batch_size={batch_size}, in_features={in_features}, out_features={out_features}")
|
|
||||||
print(f"INT4 parameters: groupsize={groupsize}, inner_k_tiles={inner_k_tiles}")
|
|
||||||
|
|
||||||
fp16_time = speed_test(fp16_model, input_tensor_bf16, "FP16")
|
|
||||||
int4_time = speed_test(int4_model, input_tensor, "INT4")
|
|
||||||
|
|
||||||
# Calculate speedup
|
|
||||||
speedup = fp16_time / int4_time
|
|
||||||
print(f"INT4 is {speedup:.2f}x faster than FP16")
|
|
||||||
|
|
||||||
# Calculate memory savings
|
|
||||||
fp16_memory = fp16_model[0].weight.nelement() * fp16_model[0].weight.element_size()
|
|
||||||
int4_memory = (int4_model[0].weight.nelement() * int4_model[0].weight.element_size() +
|
|
||||||
int4_model[0].scales_and_zeros.nelement() * int4_model[0].scales_and_zeros.element_size())
|
|
||||||
|
|
||||||
memory_reduction = fp16_memory / int4_memory
|
|
||||||
print(f"Memory reduction: {memory_reduction:.2f}x ({fp16_memory/1024/1024:.2f} MB vs {int4_memory/1024/1024:.2f} MB)")
|
|
||||||
|
|
||||||
# Check accuracy
|
|
||||||
with torch.no_grad():
|
|
||||||
fp16_output = fp16_model(input_tensor_bf16)
|
|
||||||
int4_output = int4_model(input_tensor)
|
|
||||||
|
|
||||||
# Calculate error metrics
|
|
||||||
abs_error = torch.abs(fp16_output - int4_output)
|
|
||||||
rel_error = abs_error / (torch.abs(fp16_output) + 1e-7)
|
|
||||||
|
|
||||||
print(f"Mean absolute error: {abs_error.mean().item():.6f}")
|
|
||||||
print(f"Max absolute error: {abs_error.max().item():.6f}")
|
|
||||||
print(f"Mean relative error: {rel_error.mean().item():.6f}")
|
|
||||||
@@ -1,83 +0,0 @@
|
|||||||
import torch
|
|
||||||
import nvmath.bindings.cublas
|
|
||||||
import ctypes
|
|
||||||
|
|
||||||
# 创建 CUBLAS 句柄
|
|
||||||
handle = nvmath.bindings.cublas.create()
|
|
||||||
|
|
||||||
# 准备数据 - 使用 uint8 类型,并确保内存连续
|
|
||||||
m, n, k = 64, 32, 48
|
|
||||||
a = (torch.rand(m, k, device="cuda") * 255).to(torch.uint8).contiguous()
|
|
||||||
b = (torch.rand(k, n, device="cuda") * 255).to(torch.uint8).contiguous()
|
|
||||||
c = torch.zeros(m, n, device="cuda", dtype=torch.uint8).contiguous()
|
|
||||||
|
|
||||||
# 确保张量在 CUDA 上
|
|
||||||
assert a.is_cuda and b.is_cuda and c.is_cuda
|
|
||||||
# 确保张量是连续的
|
|
||||||
assert a.is_contiguous() and b.is_contiguous() and c.is_contiguous()
|
|
||||||
|
|
||||||
# 获取指针
|
|
||||||
a_ptr = a.data_ptr()
|
|
||||||
b_ptr = b.data_ptr()
|
|
||||||
c_ptr = c.data_ptr()
|
|
||||||
|
|
||||||
# 设置参数
|
|
||||||
transa = 0 # CUBLAS_OP_N (不转置)
|
|
||||||
transb = 0 # CUBLAS_OP_N (不转置)
|
|
||||||
transc = 0 # CUBLAS_OP_N (不转置)
|
|
||||||
|
|
||||||
# 设置偏置值
|
|
||||||
a_bias = 0
|
|
||||||
b_bias = 0
|
|
||||||
c_bias = 0
|
|
||||||
|
|
||||||
# 设置正确的 leading dimensions
|
|
||||||
lda = k # A 的 leading dimension
|
|
||||||
ldb = n # B 的 leading dimension
|
|
||||||
ldc = n # C 的 leading dimension
|
|
||||||
|
|
||||||
c_mult = 1
|
|
||||||
c_shift = 0
|
|
||||||
|
|
||||||
# 打印调试信息
|
|
||||||
print(f"a shape: {a.shape}, a_ptr: {a_ptr}")
|
|
||||||
print(f"b shape: {b.shape}, b_ptr: {b_ptr}")
|
|
||||||
print(f"c shape: {c.shape}, c_ptr: {c_ptr}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 调用 uint8gemm_bias
|
|
||||||
nvmath.bindings.cublas.uint8gemm_bias(
|
|
||||||
handle,
|
|
||||||
transa, transb, transc,
|
|
||||||
m, n, k,
|
|
||||||
a_ptr, a_bias, lda,
|
|
||||||
b_ptr, b_bias, ldb,
|
|
||||||
c_ptr, c_bias, ldc,
|
|
||||||
c_mult, c_shift
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error: {e}")
|
|
||||||
# 尝试使用 ctypes 转换指针
|
|
||||||
a_ptr_c = ctypes.c_void_p(a_ptr).value
|
|
||||||
b_ptr_c = ctypes.c_void_p(b_ptr).value
|
|
||||||
c_ptr_c = ctypes.c_void_p(c_ptr).value
|
|
||||||
|
|
||||||
print(f"Using ctypes: a_ptr: {a_ptr_c}, b_ptr: {b_ptr_c}, c_ptr: {c_ptr_c}")
|
|
||||||
|
|
||||||
# 再次尝试调用
|
|
||||||
nvmath.bindings.cublas.uint8gemm_bias(
|
|
||||||
handle,
|
|
||||||
transa, transb, transc,
|
|
||||||
m, n, k,
|
|
||||||
a_ptr_c, a_bias, lda,
|
|
||||||
b_ptr_c, b_bias, ldb,
|
|
||||||
c_ptr_c, c_bias, ldc,
|
|
||||||
c_mult, c_shift
|
|
||||||
)
|
|
||||||
|
|
||||||
# 销毁 CUBLAS 句柄
|
|
||||||
nvmath.bindings.cublas.destroy(handle)
|
|
||||||
|
|
||||||
# 打印结果
|
|
||||||
print("Result:")
|
|
||||||
print(c)
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
|
|
||||||
from llmcompressor.modifiers.quantization import GPTQModifier
|
|
||||||
from llmcompressor import oneshot
|
|
||||||
|
|
||||||
# Select quantization algorithm. In this case, we:
|
|
||||||
# * apply SmoothQuant to make the activations easier to quantize
|
|
||||||
# * quantize the weights to int8 with GPTQ (static per channel)
|
|
||||||
# * quantize the activations to int8 (dynamic per token)
|
|
||||||
recipe = [
|
|
||||||
SmoothQuantModifier(smoothing_strength=0.8),
|
|
||||||
GPTQModifier(scheme="W8A8", targets="Linear", ignore=["lm_head"]),
|
|
||||||
]
|
|
||||||
|
|
||||||
# Apply quantization using the built in open_platypus dataset.
|
|
||||||
# * See examples for demos showing how to pass a custom calibration set
|
|
||||||
oneshot(
|
|
||||||
model="facebook/contriever",
|
|
||||||
dataset="open_platypus",
|
|
||||||
recipe=recipe,
|
|
||||||
output_dir="contriever-INT4",
|
|
||||||
max_seq_length=2048,
|
|
||||||
num_calibration_samples=512,
|
|
||||||
)
|
|
||||||
@@ -1,41 +0,0 @@
|
|||||||
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
|
|
||||||
#
|
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
|
|
||||||
"""
|
|
||||||
This example demonstrates basic matrix multiplication of FP8 tensors.
|
|
||||||
|
|
||||||
In narrow-precision operations, quantization scales must be provided for each tensor. These
|
|
||||||
scales are used to dequantize input operands and quantize the result. Without proper
|
|
||||||
scaling, the results of FP8 operations will likely exceed the type's range.
|
|
||||||
|
|
||||||
FP8 is only supported with cuBLAS 12.8 or newer and on devices with compute
|
|
||||||
capability 8.9 or higher.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
import nvmath
|
|
||||||
|
|
||||||
# Prepare sample input data. Note that N, M and K must be divisible by 16 for FP8.
|
|
||||||
# cuBLAS requires B to be column-major, so we first create a row-major tensor and then
|
|
||||||
# transpose it.
|
|
||||||
m, n, k = 64, 32, 48
|
|
||||||
a = (torch.rand(m, k, device="cuda") * 10).type(torch.float8_e4m3fn)
|
|
||||||
b = (torch.rand(n, k, device="cuda") * 10).type(torch.float8_e4m3fn).T
|
|
||||||
|
|
||||||
# Prepare quantization scales. The scales must allow the result to fit within the dynamic
|
|
||||||
# range of the data type used. Scales can be provided either as a dictionary or as a
|
|
||||||
# MatmulQuantizationScales object. Note that scales are only allowed for FP8 operands.
|
|
||||||
scales = {"a": 1, "b": 1, "d": 0.1}
|
|
||||||
|
|
||||||
# Perform the multiplication. The result of the multiplication will be:
|
|
||||||
# (scales.a * A) @ (scales.b * B) * scales.d
|
|
||||||
result = nvmath.linalg.advanced.matmul(a, b, quantization_scales=scales)
|
|
||||||
|
|
||||||
# Check how scaling helped to fit into the dynamic range of float8_e4m3fn type.
|
|
||||||
result_without_scaling = nvmath.linalg.advanced.matmul(a, b, quantization_scales={"a": 1, "b": 1, "d": 1})
|
|
||||||
print("Without scaling, most of the elements were clamped to the maximum value of float8_e4m3fn type (448):")
|
|
||||||
print(result_without_scaling)
|
|
||||||
print(f"\nWith D scale set to {scales['d']}, they were scaled down to fit into the dynamic range of float8_e4m3fn:")
|
|
||||||
print(result)
|
|
||||||
@@ -1,58 +0,0 @@
|
|||||||
import os
|
|
||||||
import torch
|
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
def save_model_in_pth_format(model_name, output_dir):
|
|
||||||
"""
|
|
||||||
Download a model from Hugging Face and save it in PTH format
|
|
||||||
for use with quantization benchmarks.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_name: Name of the model on Hugging Face
|
|
||||||
output_dir: Directory to save the model
|
|
||||||
"""
|
|
||||||
print(f"Loading model {model_name}...")
|
|
||||||
|
|
||||||
# Create output directory if it doesn't exist
|
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
|
||||||
|
|
||||||
# Load tokenizer and model
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
model_name,
|
|
||||||
torch_dtype=torch.bfloat16,
|
|
||||||
low_cpu_mem_usage=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Save tokenizer
|
|
||||||
tokenizer.save_pretrained(output_dir)
|
|
||||||
|
|
||||||
# Extract and save the model weights in PTH format
|
|
||||||
model_state_dict = model.state_dict()
|
|
||||||
|
|
||||||
# Save the model weights
|
|
||||||
model_path = Path(output_dir) / "model.pth"
|
|
||||||
torch.save(model_state_dict, model_path)
|
|
||||||
|
|
||||||
print(f"Model saved to {model_path}")
|
|
||||||
|
|
||||||
# Print model size information
|
|
||||||
param_count = sum(p.numel() for p in model.parameters())
|
|
||||||
model_size_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 * 1024)
|
|
||||||
|
|
||||||
print(f"Model parameters: {param_count:,}")
|
|
||||||
print(f"Model size: {model_size_mb:.2f} MB")
|
|
||||||
|
|
||||||
return model_path
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# Use a small model for testing
|
|
||||||
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
|
||||||
output_dir = "./tinyllama-1.1b-chat"
|
|
||||||
|
|
||||||
model_path = save_model_in_pth_format(model_name, output_dir)
|
|
||||||
|
|
||||||
print("\nYou can now use this model with the INT4 benchmark script.")
|
|
||||||
print("Example command:")
|
|
||||||
print(f"python int4benchmark.py --model_path {model_path}")
|
|
||||||
@@ -1,677 +0,0 @@
|
|||||||
{
|
|
||||||
"cells": [
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 1,
|
|
||||||
"id": "cab91cfc",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"/home/ubuntu/Power-RAG/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
|
||||||
" from .autonotebook import tqdm as notebook_tqdm\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"import copy\n",
|
|
||||||
"import dataclasses\n",
|
|
||||||
"import os\n",
|
|
||||||
"import time\n",
|
|
||||||
"import pathlib\n",
|
|
||||||
"import itertools\n",
|
|
||||||
"import multiprocessing\n",
|
|
||||||
"import scipy\n",
|
|
||||||
"import numpy as np\n",
|
|
||||||
"import pandas as pd\n",
|
|
||||||
"import pickle\n",
|
|
||||||
"import gzip\n",
|
|
||||||
"import threading\n",
|
|
||||||
"import queue\n",
|
|
||||||
"import pytz\n",
|
|
||||||
"import traceback\n",
|
|
||||||
"from datetime import datetime\n",
|
|
||||||
"from tqdm.auto import tqdm, trange\n",
|
|
||||||
"from typing import Any\n",
|
|
||||||
"\n",
|
|
||||||
"import matplotlib.pyplot as plt\n",
|
|
||||||
"import matplotlib.ticker as mtick\n",
|
|
||||||
"%matplotlib inline\n",
|
|
||||||
"%config InlineBackend.figure_format='retina'"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 2,
|
|
||||||
"id": "8d24fbd7",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Sat Apr 12 00:10:05 2025 \n",
|
|
||||||
"+-----------------------------------------------------------------------------------------+\n",
|
|
||||||
"| NVIDIA-SMI 550.120 Driver Version: 550.120 CUDA Version: 12.4 |\n",
|
|
||||||
"|-----------------------------------------+------------------------+----------------------+\n",
|
|
||||||
"| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
|
|
||||||
"| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
|
|
||||||
"| | | MIG M. |\n",
|
|
||||||
"|=========================================+========================+======================|\n",
|
|
||||||
"| 0 NVIDIA A10G Off | 00000000:00:1E.0 Off | 0 |\n",
|
|
||||||
"| 0% 27C P8 15W / 300W | 4MiB / 23028MiB | 0% Default |\n",
|
|
||||||
"| | | N/A |\n",
|
|
||||||
"+-----------------------------------------+------------------------+----------------------+\n",
|
|
||||||
" \n",
|
|
||||||
"+-----------------------------------------------------------------------------------------+\n",
|
|
||||||
"| Processes: |\n",
|
|
||||||
"| GPU GI CI PID Type Process name GPU Memory |\n",
|
|
||||||
"| ID ID Usage |\n",
|
|
||||||
"|=========================================================================================|\n",
|
|
||||||
"| No running processes found |\n",
|
|
||||||
"+-----------------------------------------------------------------------------------------+\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"!nvidia-smi"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 3,
|
|
||||||
"id": "538b2c11",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"def benchmark(f, *, f_setup=None, min_repeat: int, min_secs: float, tqdm_kwargs: dict | None=None) -> np.ndarray:\n",
|
|
||||||
" latency = []\n",
|
|
||||||
" \n",
|
|
||||||
" # First run, ignore min_secs\n",
|
|
||||||
" if f_setup is not None:\n",
|
|
||||||
" f_setup()\n",
|
|
||||||
" st = time.perf_counter_ns()\n",
|
|
||||||
" f()\n",
|
|
||||||
" ed = time.perf_counter_ns()\n",
|
|
||||||
" latency.append((ed-st)/1e9)\n",
|
|
||||||
" \n",
|
|
||||||
" # Subsequent runs, until reaching both min_repeat and min_secs\n",
|
|
||||||
" min_nanos = int(min_secs * 1e9)\n",
|
|
||||||
" start_nanos = time.perf_counter_ns()\n",
|
|
||||||
" while True:\n",
|
|
||||||
" now_nanos = time.perf_counter_ns()\n",
|
|
||||||
" if len(latency) > min_repeat and now_nanos - start_nanos > min_nanos:\n",
|
|
||||||
" break\n",
|
|
||||||
" if f_setup is not None:\n",
|
|
||||||
" f_setup()\n",
|
|
||||||
" st = time.perf_counter_ns()\n",
|
|
||||||
" f()\n",
|
|
||||||
" ed = time.perf_counter_ns()\n",
|
|
||||||
" latency.append((ed-st)/1e9)\n",
|
|
||||||
" return np.array(latency)\n",
|
|
||||||
"\n",
|
|
||||||
"def tail_mean(xs, skip=0.2):\n",
|
|
||||||
" return xs[int(len(xs) * skip):].mean()"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 4,
|
|
||||||
"id": "02c9c9b1",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"<torch.autograd.grad_mode.set_grad_enabled at 0x7c5afc12b850>"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 4,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"import torch\n",
|
|
||||||
"torch.set_grad_enabled(False)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 5,
|
|
||||||
"id": "3405fdc7",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"nd_list = list(itertools.chain(itertools.product([12, 3], [256])))\n",
|
|
||||||
"seqlen_list = [256]\n",
|
|
||||||
"bs_list = [2,4,8,16,32,64,128,256,512,1024,2048]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 6,
|
|
||||||
"id": "10dc981a",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"[(12, 256), (3, 256)]\n",
|
|
||||||
"[256]\n",
|
|
||||||
"[2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"print(nd_list)\n",
|
|
||||||
"print(seqlen_list)\n",
|
|
||||||
"print(bs_list)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 7,
|
|
||||||
"id": "7e0ee385",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"def benchmark_dense(out, nd_list, seqlen_list, bs_list):\n",
|
|
||||||
" seqlen_list = [1] + seqlen_list\n",
|
|
||||||
" total = len(list(itertools.product(nd_list, seqlen_list, bs_list)))\n",
|
|
||||||
" pbar = tqdm(total=total)\n",
|
|
||||||
" for (n, d), seqlen in reversed(list(itertools.product(nd_list, seqlen_list))):\n",
|
|
||||||
" h = n * d\n",
|
|
||||||
" maxbs = max(bs_list)\n",
|
|
||||||
" print(maxbs, n, d, seqlen)\n",
|
|
||||||
" cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=\"cuda:0\")\n",
|
|
||||||
" X = torch.rand((maxbs, seqlen, h), dtype=torch.bfloat16, device=\"cuda:0\")\n",
|
|
||||||
" W = torch.rand((h, h), dtype=torch.bfloat16, device=\"cuda:0\")\n",
|
|
||||||
" torch.cuda.synchronize()\n",
|
|
||||||
" for bs in reversed(bs_list):\n",
|
|
||||||
" pbar.set_postfix(n=n, h=h, d=d, seqlen=seqlen, bs=bs)\n",
|
|
||||||
" def run():\n",
|
|
||||||
" torch.matmul(X[:bs], W)\n",
|
|
||||||
" torch.cuda.synchronize()\n",
|
|
||||||
" def clear_cache():\n",
|
|
||||||
" cache.zero_()\n",
|
|
||||||
" torch.cuda.synchronize()\n",
|
|
||||||
" latency = benchmark(run, f_setup=clear_cache, min_repeat=20, min_secs=2)\n",
|
|
||||||
" l = tail_mean(latency)\n",
|
|
||||||
" out.append({\n",
|
|
||||||
" \"n\": n,\n",
|
|
||||||
" \"d\": d,\n",
|
|
||||||
" \"seqlen\": seqlen,\n",
|
|
||||||
" \"bs\": bs,\n",
|
|
||||||
" \"latency\": l\n",
|
|
||||||
" })\n",
|
|
||||||
" pbar.update()\n",
|
|
||||||
" del cache, X, W\n",
|
|
||||||
" torch.cuda.empty_cache()\n",
|
|
||||||
" pbar.close()"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 8,
|
|
||||||
"id": "c206a502",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"def benchmark_qk_init(out, nd_list, seqlen_list, bs_list):\n",
|
|
||||||
" total = len(list(itertools.product(nd_list, seqlen_list, bs_list)))\n",
|
|
||||||
" pbar = tqdm(total=total)\n",
|
|
||||||
" for (n, d), seqlen in reversed(list(itertools.product(nd_list, seqlen_list))):\n",
|
|
||||||
" h = n * d\n",
|
|
||||||
" try:\n",
|
|
||||||
" maxbs = max(b for b in bs_list if b*n*seqlen*d*2*2+b*n*seqlen**2*2 < 80e9)\n",
|
|
||||||
" except ValueError:\n",
|
|
||||||
" pbar.update(len(bs_list))\n",
|
|
||||||
" continue\n",
|
|
||||||
" cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=\"cuda:0\")\n",
|
|
||||||
" Qmax = torch.rand((maxbs, n, seqlen, d), dtype=torch.bfloat16, device=\"cuda:0\")\n",
|
|
||||||
" Kmax = torch.rand((maxbs, n, seqlen, d), dtype=torch.bfloat16, device=\"cuda:0\")\n",
|
|
||||||
" torch.cuda.synchronize()\n",
|
|
||||||
" for bs in reversed(bs_list):\n",
|
|
||||||
" pbar.set_postfix(n=n, h=h, d=d, seqlen=seqlen, bs=bs)\n",
|
|
||||||
" if bs > maxbs:\n",
|
|
||||||
" pbar.update()\n",
|
|
||||||
" continue\n",
|
|
||||||
" Q = Qmax[:bs]\n",
|
|
||||||
" K = Kmax[:bs]\n",
|
|
||||||
" def run():\n",
|
|
||||||
" torch.bmm(Q.view(bs * n, seqlen, d), K.view(bs * n, seqlen, d).transpose(1, 2))\n",
|
|
||||||
" torch.cuda.synchronize()\n",
|
|
||||||
" def clear_cache():\n",
|
|
||||||
" cache.zero_()\n",
|
|
||||||
" torch.cuda.synchronize()\n",
|
|
||||||
" latency = benchmark(run, f_setup=clear_cache, min_repeat=20, min_secs=2)\n",
|
|
||||||
" l = tail_mean(latency)\n",
|
|
||||||
" out.append({\n",
|
|
||||||
" \"n\": n,\n",
|
|
||||||
" \"d\": d,\n",
|
|
||||||
" \"seqlen\": seqlen,\n",
|
|
||||||
" \"bs\": bs,\n",
|
|
||||||
" \"latency\": l\n",
|
|
||||||
" })\n",
|
|
||||||
" pbar.update()\n",
|
|
||||||
" del cache, Q, K, Qmax, Kmax\n",
|
|
||||||
" torch.cuda.empty_cache()\n",
|
|
||||||
" pbar.close()"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 9,
|
|
||||||
"id": "a3a2103c",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"def benchmark_qk_ar(out, nd_list, seqlen_list, bs_list):\n",
|
|
||||||
" total = len(list(itertools.product(nd_list, seqlen_list, bs_list)))\n",
|
|
||||||
" pbar = tqdm(total=total)\n",
|
|
||||||
" for (n, d), seqlen in reversed(list(itertools.product(nd_list, seqlen_list))):\n",
|
|
||||||
" h = n * d\n",
|
|
||||||
" try:\n",
|
|
||||||
" maxbs = max(b for b in bs_list if b*n*(1+seqlen)*d*2+b*n*seqlen*2 < 80e9)\n",
|
|
||||||
" except ValueError:\n",
|
|
||||||
" pbar.update(len(bs_list))\n",
|
|
||||||
" continue\n",
|
|
||||||
" cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=\"cuda:0\")\n",
|
|
||||||
" Qmax = torch.rand((maxbs, n, 1, d), dtype=torch.bfloat16, device=\"cuda:0\")\n",
|
|
||||||
" Kmax = torch.rand((maxbs, n, seqlen, d), dtype=torch.bfloat16, device=\"cuda:0\")\n",
|
|
||||||
" torch.cuda.synchronize()\n",
|
|
||||||
" for bs in reversed(bs_list):\n",
|
|
||||||
" pbar.set_postfix(n=n, h=h, d=d, seqlen=seqlen, bs=bs)\n",
|
|
||||||
" if bs > maxbs:\n",
|
|
||||||
" pbar.update()\n",
|
|
||||||
" continue\n",
|
|
||||||
" Q = Qmax[:bs]\n",
|
|
||||||
" K = Kmax[:bs]\n",
|
|
||||||
" def run():\n",
|
|
||||||
" torch.bmm(Q.view(bs * n, 1, d), K.view(bs * n, seqlen, d).transpose(1, 2))\n",
|
|
||||||
" torch.cuda.synchronize()\n",
|
|
||||||
" def clear_cache():\n",
|
|
||||||
" cache.zero_()\n",
|
|
||||||
" torch.cuda.synchronize()\n",
|
|
||||||
" latency = benchmark(run, f_setup=clear_cache, min_repeat=20, min_secs=2)\n",
|
|
||||||
" l = tail_mean(latency)\n",
|
|
||||||
" out.append({\n",
|
|
||||||
" \"n\": n,\n",
|
|
||||||
" \"d\": d,\n",
|
|
||||||
" \"seqlen\": seqlen,\n",
|
|
||||||
" \"bs\": bs,\n",
|
|
||||||
" \"latency\": l\n",
|
|
||||||
" })\n",
|
|
||||||
" pbar.update()\n",
|
|
||||||
" del cache, Q, K, Qmax, Kmax\n",
|
|
||||||
" torch.cuda.empty_cache()\n",
|
|
||||||
" pbar.close()"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 10,
|
|
||||||
"id": "3aaad98a",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"data = {}"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 11,
|
|
||||||
"id": "18137de3",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
" 0%| | 0/22 [00:00<?, ?it/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"100%|██████████| 22/22 [00:44<00:00, 2.04s/it, bs=2, d=256, h=3072, n=12, seqlen=256] \n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"db = []\n",
|
|
||||||
"benchmark_qk_init(db, nd_list, seqlen_list, bs_list)\n",
|
|
||||||
"data[\"qk_init\"] = db"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 12,
|
|
||||||
"id": "26c76e15",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"100%|██████████| 22/22 [00:44<00:00, 2.01s/it, bs=2, d=256, h=3072, n=12, seqlen=256] \n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"db = []\n",
|
|
||||||
"benchmark_qk_ar(db, nd_list, seqlen_list, bs_list)\n",
|
|
||||||
"data[\"qk_ar\"] = db"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 13,
|
|
||||||
"id": "313e36eb",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
" 0%| | 0/44 [00:00<?, ?it/s, bs=2048, d=256, h=768, n=3, seqlen=256]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"2048 3 256 256\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
" 25%|██▌ | 11/44 [00:22<01:06, 2.00s/it, bs=2048, d=256, h=768, n=3, seqlen=1] "
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"2048 3 256 1\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
" 50%|█████ | 22/44 [00:44<00:44, 2.00s/it, bs=2048, d=256, h=3072, n=12, seqlen=256]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"2048 12 256 256\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
" 75%|███████▌ | 33/44 [01:07<00:22, 2.02s/it, bs=2048, d=256, h=3072, n=12, seqlen=1] "
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"2048 12 256 1\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"100%|██████████| 44/44 [01:29<00:00, 2.03s/it, bs=2, d=256, h=3072, n=12, seqlen=1] \n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"db = []\n",
|
|
||||||
"benchmark_dense(db, nd_list, seqlen_list, bs_list)\n",
|
|
||||||
"data[\"dense\"] = db"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 14,
|
|
||||||
"id": "50c37959",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"with gzip.open(\"data/20230516-transformer-batching1.pkl.gz\", \"wb\") as f:\n",
|
|
||||||
" pickle.dump(data, f)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 15,
|
|
||||||
"id": "828ddb54",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"df_dense = (\n",
|
|
||||||
" pd.DataFrame.from_dict(data[\"dense\"])\n",
|
|
||||||
" .assign(h=lambda x: x[\"n\"] * x[\"d\"])\n",
|
|
||||||
" .assign(flop=lambda x: (x[\"bs\"] * x[\"seqlen\"] * x[\"h\"]**2) * 2)\n",
|
|
||||||
" .assign(io=lambda x: (x[\"bs\"]*x[\"seqlen\"]*x[\"h\"]*2 + x[\"h\"]**2) * 2/x['latency']/1e9)\n",
|
|
||||||
" .assign(intensity=lambda x: x[\"flop\"] / x[\"io\"])\n",
|
|
||||||
" .assign(throughput=lambda x: x[\"flop\"] / x[\"latency\"])\n",
|
|
||||||
" .assign(series=\"dense\")\n",
|
|
||||||
")\n",
|
|
||||||
"df_qk_init = (\n",
|
|
||||||
" pd.DataFrame.from_dict(data[\"qk_init\"])\n",
|
|
||||||
" .assign(h=lambda x: x[\"n\"] * x[\"d\"])\n",
|
|
||||||
" .assign(flop=lambda x: (x[\"bs\"]*x[\"n\"]*x[\"d\"]*x[\"seqlen\"]**2) * 2)\n",
|
|
||||||
" .assign(io=lambda x: (x[\"bs\"]*x[\"n\"]*(x[\"seqlen\"]*x[\"d\"]*2 + x[\"seqlen\"]**2)) * 2/x['latency']/1e9)\n",
|
|
||||||
" .assign(intensity=lambda x: x[\"flop\"] / x[\"io\"])\n",
|
|
||||||
" .assign(throughput=lambda x: x[\"flop\"] / x[\"latency\"])\n",
|
|
||||||
" .assign(series=\"qk_init\")\n",
|
|
||||||
")\n",
|
|
||||||
"df_qk_ar = (\n",
|
|
||||||
" pd.DataFrame.from_dict(data[\"qk_ar\"])\n",
|
|
||||||
" .assign(h=lambda x: x[\"n\"] * x[\"d\"])\n",
|
|
||||||
" .assign(flop=lambda x: (x[\"bs\"]*x[\"n\"]*x[\"d\"]*x[\"seqlen\"]) * 2)\n",
|
|
||||||
" .assign(io=lambda x: (x[\"bs\"]*x[\"n\"]*(x[\"d\"] + x[\"seqlen\"]*x[\"d\"] + x[\"seqlen\"])) * 2)\n",
|
|
||||||
" .assign(intensity=lambda x: x[\"flop\"] / x[\"io\"])\n",
|
|
||||||
" .assign(throughput=lambda x: x[\"bs\"] / x[\"latency\"])\n",
|
|
||||||
" .assign(series=\"qk_ar\")\n",
|
|
||||||
")\n",
|
|
||||||
"pd.concat([df_dense, df_qk_init, df_qk_ar]).to_csv(\"data/transformer-batching-microbenchmarks.csv\", index=False)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 39,
|
|
||||||
"id": "c296a395",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"<module 'pandas' from '/home/ubuntu/Power-RAG/.venv/lib/python3.10/site-packages/pandas/__init__.py'>"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 39,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"pd\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "a25cdd5a",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": []
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "63b8a531",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"import transformers"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "af90eff1",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"def _gen_opt_cfg(n_layers: int, d_model: int, n_heads: int, **kwargs) -> transformers.OPTConfig:\n",
|
|
||||||
" return transformers.OPTConfig(\n",
|
|
||||||
" num_hidden_layers=n_layers,\n",
|
|
||||||
" hidden_size=d_model,\n",
|
|
||||||
" ffn_dim=d_model*4,\n",
|
|
||||||
" num_attention_heads=n_heads,\n",
|
|
||||||
" **kwargs\n",
|
|
||||||
" )\n",
|
|
||||||
"optcfg = {\n",
|
|
||||||
" # https://arxiv.org/pdf/2205.01068.pdf Table 2.1\n",
|
|
||||||
" \"125m\": _gen_opt_cfg(12, 768, 12),\n",
|
|
||||||
" \"350m\": _gen_opt_cfg(24, 1024, 16),\n",
|
|
||||||
" \"760m\": _gen_opt_cfg(24, 1536, 16),\n",
|
|
||||||
" \"1.3b\": _gen_opt_cfg(24, 2048, 32),\n",
|
|
||||||
" \"2.7b\": _gen_opt_cfg(32, 2560, 32),\n",
|
|
||||||
" \"6.7b\": _gen_opt_cfg(32, 4096, 32),\n",
|
|
||||||
" \"13b\": _gen_opt_cfg(40, 5120, 40),\n",
|
|
||||||
" \"13b_1layer\": _gen_opt_cfg(1, 5120, 40),\n",
|
|
||||||
" \"30b\": _gen_opt_cfg(48, 7168, 56),\n",
|
|
||||||
" \"66b\": _gen_opt_cfg(64, 9216, 72),\n",
|
|
||||||
" \"175b\": _gen_opt_cfg(96, 12288, 96),\n",
|
|
||||||
" \"175b_1layer\": _gen_opt_cfg(1, 12288, 96),\n",
|
|
||||||
"}"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "5b9ebbec",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"def greedy_sample_one(model, input_ids, attention_mask=None, past_key_values=None):\n",
|
|
||||||
" bs, tgt_len = input_ids.shape\n",
|
|
||||||
" if past_key_values is not None:\n",
|
|
||||||
" _bs, _num_heads, src_len, _head_dims = past_key_values[0][0].shape\n",
|
|
||||||
" assert bs == _bs\n",
|
|
||||||
" else:\n",
|
|
||||||
" src_len = 0\n",
|
|
||||||
" if attention_mask is None:\n",
|
|
||||||
" attention_mask = torch.ones((bs, src_len + tgt_len), device=model.device)\n",
|
|
||||||
" ret = model(\n",
|
|
||||||
" input_ids=input_ids,\n",
|
|
||||||
" attention_mask=attention_mask,\n",
|
|
||||||
" past_key_values=past_key_values,\n",
|
|
||||||
" use_cache=True, output_hidden_states=False, return_dict=True,\n",
|
|
||||||
" )\n",
|
|
||||||
" return ret\n",
|
|
||||||
"\n",
|
|
||||||
"def time_greedy_generate(model, input_ids, new_tokens):\n",
|
|
||||||
" ts = []\n",
|
|
||||||
" output = input_ids\n",
|
|
||||||
" past_key_values = None\n",
|
|
||||||
" cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=model.device)\n",
|
|
||||||
" attention_mask = torch.ones(input_ids.shape, device=model.device) \n",
|
|
||||||
" for _ in range(new_tokens):\n",
|
|
||||||
" cache.zero_()\n",
|
|
||||||
" torch.cuda.synchronize()\n",
|
|
||||||
" st = time.perf_counter_ns()\n",
|
|
||||||
" \n",
|
|
||||||
" ret = greedy_sample_one(model, input_ids, attention_mask, past_key_values)\n",
|
|
||||||
" input_ids = torch.argmax(ret.logits[:, -1, :], axis=-1)[:, None]\n",
|
|
||||||
" output = torch.cat([output, input_ids], axis=1)\n",
|
|
||||||
" past_key_values = ret.past_key_values\n",
|
|
||||||
" attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)\n",
|
|
||||||
" \n",
|
|
||||||
" torch.cuda.synchronize()\n",
|
|
||||||
" ed = time.perf_counter_ns()\n",
|
|
||||||
" ts.append((ed-st)/1e9)\n",
|
|
||||||
" return np.array(ts)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "fc92f940",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"opt_config = optcfg[\"6.7b\"]\n",
|
|
||||||
"\n",
|
|
||||||
"torch.set_default_dtype(torch.bfloat16)\n",
|
|
||||||
"with transformers.modeling_utils.no_init_weights():\n",
|
|
||||||
" model = transformers.models.opt.OPTForCausalLM(opt_config).to(\"cuda\")\n",
|
|
||||||
"torch.set_default_dtype(torch.float32)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "c19fa396",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"db = {}\n",
|
|
||||||
"input_tokens = 200\n",
|
|
||||||
"new_tokens = 500\n",
|
|
||||||
"for bs in tqdm(list(itertools.chain(range(1, 8), range(8, 16, 2), [16]))):\n",
|
|
||||||
" x = torch.randint(1000, 10000, (bs, input_tokens), device=model.device)\n",
|
|
||||||
" stack = []\n",
|
|
||||||
" for _ in range(10):\n",
|
|
||||||
" l = time_greedy_generate(model, x, new_tokens=new_tokens)\n",
|
|
||||||
" stack.append(l)\n",
|
|
||||||
" db[bs] = np.median(np.stack(stack), axis=0)\n",
|
|
||||||
" del x\n",
|
|
||||||
" torch.cuda.empty_cache()\n",
|
|
||||||
"del model\n",
|
|
||||||
"torch.cuda.empty_cache()\n",
|
|
||||||
"\n",
|
|
||||||
"with gzip.open(\"data/20230516-e2e-text-generation-batch.pkl.gz\", \"wb\") as f:\n",
|
|
||||||
" pickle.dump(db, f)"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": ".venv",
|
|
||||||
"language": "python",
|
|
||||||
"name": "python3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 3
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
|
||||||
"nbconvert_exporter": "python",
|
|
||||||
"pygments_lexer": "ipython3",
|
|
||||||
"version": "3.10.12"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 5
|
|
||||||
}
|
|
||||||
@@ -1,165 +0,0 @@
|
|||||||
import matplotlib.pyplot as plt
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
|
|
||||||
# Set plot parameters
|
|
||||||
plt.rcParams["font.family"] = "Helvetica"
|
|
||||||
plt.rcParams["ytick.direction"] = "in"
|
|
||||||
plt.rcParams["hatch.linewidth"] = 1.5
|
|
||||||
plt.rcParams["font.weight"] = "bold"
|
|
||||||
plt.rcParams["axes.labelweight"] = "bold"
|
|
||||||
plt.rcParams["text.usetex"] = True
|
|
||||||
|
|
||||||
# Path settings
|
|
||||||
FIGURE_PATH = "./paper_plot/figures"
|
|
||||||
|
|
||||||
# Load accuracy data
|
|
||||||
acc_data = pd.read_csv("./paper_plot/data/acc.csv")
|
|
||||||
|
|
||||||
# Create figure with 4 subplots (one for each dataset)
|
|
||||||
fig, axs = plt.subplots(1, 4)
|
|
||||||
fig.set_size_inches(9, 2.5)
|
|
||||||
|
|
||||||
# Reduce the spacing between subplots
|
|
||||||
# plt.subplots_adjust(wspace=0.2) # Reduced from 0.3 to 0.1
|
|
||||||
|
|
||||||
# Define datasets and their columns
|
|
||||||
datasets = ["NQ", "TriviaQA", "GPQA", "HotpotQA"]
|
|
||||||
metrics = ["Exact Match", "F1"]
|
|
||||||
|
|
||||||
# Define bar settings - make bars thicker
|
|
||||||
# total_width, n = 0.9, 3 # increased total width and n for three models
|
|
||||||
# width = total_width / n
|
|
||||||
# The 'width' variable below now defines the distance between the centers of adjacent bars within a group.
|
|
||||||
# It's also used as the base for calculating the actual plotted bar width.
|
|
||||||
# Original 2 bars had centers 1.0 apart. For 3 bars, we need a smaller distance.
|
|
||||||
# A value of 0.64 for distance between centers, with a scaling factor of 0.8 for bar width,
|
|
||||||
# results in an actual bar width of ~0.51, and a group span of ~1.79, similar to original's ~1.76.
|
|
||||||
n = 3 # Number of models
|
|
||||||
width = 0.64 # Distance between centers of adjacent bars in a group
|
|
||||||
bar_width_plotting_factor = 0.8 # Bar takes 80% of the space defined by 'width'
|
|
||||||
|
|
||||||
# Colors and hatches
|
|
||||||
edgecolors = ["dimgrey", "#63B8B6", "tomato"] # Added color for PQ 5
|
|
||||||
hatches = ["/////", "xxxxx", "\\\\\\\\\\"] # Added hatch for PQ 5
|
|
||||||
labels = ["BM25", "PQ Compressed", "Ours"] # Added PQ 5
|
|
||||||
|
|
||||||
# Create plots for each dataset
|
|
||||||
for i, dataset in enumerate(datasets):
|
|
||||||
ax = axs[i]
|
|
||||||
|
|
||||||
# Get data for this dataset and convert to percentages
|
|
||||||
em_values = [
|
|
||||||
acc_data.loc[0, f"{dataset} Exact Match"] * 100,
|
|
||||||
acc_data.loc[1, f"{dataset} Exact Match"] * 100,
|
|
||||||
acc_data.loc[2, f"{dataset} Exact Match"] * 100 # Added PQ 5 EM data
|
|
||||||
]
|
|
||||||
f1_values = [
|
|
||||||
acc_data.loc[0, f"{dataset} F1"] * 100,
|
|
||||||
acc_data.loc[1, f"{dataset} F1"] * 100,
|
|
||||||
acc_data.loc[2, f"{dataset} F1"] * 100 # Added PQ 5 F1 data
|
|
||||||
]
|
|
||||||
|
|
||||||
# Define x positions for bars
|
|
||||||
# For EM: center - width, center, center + width
|
|
||||||
# For F1: center - width, center, center + width
|
|
||||||
group_centers = [1.0, 3.0] # Centers for EM and F1 groups
|
|
||||||
bar_offsets = [-width, 0, width]
|
|
||||||
|
|
||||||
# Plot all bars on the same axis
|
|
||||||
for metric_idx, metric_group_center in enumerate(group_centers):
|
|
||||||
values_to_plot = em_values if metric_idx == 0 else f1_values
|
|
||||||
for j, model_label in enumerate(labels):
|
|
||||||
x_pos = metric_group_center + bar_offsets[j]
|
|
||||||
bar_value = values_to_plot[j]
|
|
||||||
|
|
||||||
ax.bar(
|
|
||||||
x_pos,
|
|
||||||
bar_value,
|
|
||||||
width=width * bar_width_plotting_factor, # Use the new factor for bar width
|
|
||||||
color="white",
|
|
||||||
edgecolor=edgecolors[j],
|
|
||||||
hatch=hatches[j],
|
|
||||||
linewidth=1.5,
|
|
||||||
label=model_label if i == 0 and metric_idx == 0 else None # Label only once
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add value on top of bar
|
|
||||||
ax.text(x_pos, bar_value + (0.1 if dataset == "GPQA" else 0.1),
|
|
||||||
f"{bar_value:.1f}", ha='center', va='bottom',
|
|
||||||
fontsize=9, fontweight='bold') # Reduced fontsize for text on bars
|
|
||||||
|
|
||||||
# Set x-ticks and labels
|
|
||||||
ax.set_xticks(group_centers) # Position ticks at the center of each group
|
|
||||||
xticklabels = ax.set_xticklabels(metrics, fontsize=12)
|
|
||||||
|
|
||||||
# Now, shift these labels slightly to the right
|
|
||||||
# Adjust this value to control the amount of shift (in data coordinates)
|
|
||||||
# Given your group_centers are 1.0 and 3.0, a small value like 0.05 to 0.15 might be appropriate.
|
|
||||||
# horizontal_shift = 0.7 # Try adjusting this value
|
|
||||||
|
|
||||||
# for label in xticklabels:
|
|
||||||
# # Get the current x position (which is the tick location)
|
|
||||||
# current_x_pos = label.get_position()[0]
|
|
||||||
# # Set the new x position by adding the shift
|
|
||||||
# label.set_position((current_x_pos + horizontal_shift, label.get_position()[1]))
|
|
||||||
# # Ensure the label remains horizontally centered on this new x position
|
|
||||||
# # (set_xticklabels defaults to 'center', so this re-affirms it if needed)
|
|
||||||
# label.set_horizontalalignment('center')
|
|
||||||
|
|
||||||
# Set title
|
|
||||||
ax.set_title(dataset, fontsize=14)
|
|
||||||
|
|
||||||
# Set y-label for all subplots
|
|
||||||
if i == 0:
|
|
||||||
ax.set_ylabel("Accuracy (\%)", fontsize=12, fontweight="bold")
|
|
||||||
else:
|
|
||||||
# Hide y-tick labels for non-first subplots to save space
|
|
||||||
ax.tick_params(axis='y', labelsize=10)
|
|
||||||
|
|
||||||
# Set y-limits based on data range
|
|
||||||
all_values = em_values + f1_values
|
|
||||||
max_val = max(all_values)
|
|
||||||
min_val = min(all_values)
|
|
||||||
|
|
||||||
# Special handling for GPQA which has very low values
|
|
||||||
if dataset == "GPQA":
|
|
||||||
ax.set_ylim(0, 10.0) # Set a fixed range for GPQA
|
|
||||||
else:
|
|
||||||
# Reduce the extra space above the bars
|
|
||||||
ax.set_ylim(min_val * 0.9, max_val * 1.1) # Adjusted upper limit for text
|
|
||||||
|
|
||||||
# Format y-ticks as percentages
|
|
||||||
ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: ' {:.0f}'.format(y)))
|
|
||||||
|
|
||||||
# Set x-limits to properly space the bars with less blank space
|
|
||||||
# ax.set_xlim(group_centers[0] - total_width, group_centers[1] + total_width)
|
|
||||||
# Set xlim to be similar to original (0,4) for group_centers (1,3) => margin of 1.0
|
|
||||||
ax.set_xlim(group_centers[0] - 1.0, group_centers[1] + 1.0)
|
|
||||||
|
|
||||||
# Add a box around the subplot
|
|
||||||
# for spine in ax.spines.values():
|
|
||||||
# spine.set_visible(True)
|
|
||||||
# spine.set_linewidth(1.0)
|
|
||||||
|
|
||||||
# Add legend to first subplot
|
|
||||||
if i == 0:
|
|
||||||
ax.legend(
|
|
||||||
bbox_to_anchor=(2.21, 1.35), # Adjusted anchor if needed
|
|
||||||
ncol=3, # Changed to 3 columns for three labels
|
|
||||||
loc="upper center",
|
|
||||||
labelspacing=0.1,
|
|
||||||
edgecolor="black",
|
|
||||||
facecolor="white",
|
|
||||||
framealpha=1,
|
|
||||||
shadow=False,
|
|
||||||
fancybox=False,
|
|
||||||
handlelength=1.0,
|
|
||||||
handletextpad=0.6,
|
|
||||||
columnspacing=0.8,
|
|
||||||
prop={"weight": "bold", "size": 12},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Save figure with tight layout but no additional padding
|
|
||||||
plt.savefig(FIGURE_PATH + "/accuracy_comparison.pdf", bbox_inches='tight', pad_inches=0.05)
|
|
||||||
plt.show()
|
|
||||||
@@ -1,309 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding:utf-8 -*-
|
|
||||||
# Motto: Were It to Benefit My Country, I Would Lay Down My Life!
|
|
||||||
# \file: /hnsw_degree_visit_plot_binned_academic.py
|
|
||||||
# \brief: Generates a binned bar plot of HNSW node average per-query visit probability
|
|
||||||
# per degree bin, styled for academic publications, with caching.
|
|
||||||
# Author: raphael hao (Original script by user, styling and caching adapted by Gemini)
|
|
||||||
|
|
||||||
# %%
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import re
|
|
||||||
from collections import Counter
|
|
||||||
import os # For robust filepath manipulation
|
|
||||||
import math # For calculating scaling factor
|
|
||||||
import pickle # For caching data
|
|
||||||
|
|
||||||
# %%
|
|
||||||
# --- Matplotlib parameters for academic paper style (from reference) ---
|
|
||||||
plt.rcParams["font.family"] = "Helvetica"
|
|
||||||
plt.rcParams["ytick.direction"] = "in"
|
|
||||||
plt.rcParams["hatch.linewidth"] = 1.5
|
|
||||||
plt.rcParams["font.weight"] = "bold"
|
|
||||||
plt.rcParams["axes.labelweight"] = "bold"
|
|
||||||
plt.rcParams["text.usetex"] = True # Use LaTeX for text rendering (if available)
|
|
||||||
|
|
||||||
# --- Define styles from reference ---
|
|
||||||
edgecolors_ref = ["dimgrey", "#63B8B6", "tomato", "silver", "slategray"]
|
|
||||||
|
|
||||||
# %%
|
|
||||||
# --- File Paths ---
|
|
||||||
degree_file = '/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/hnsw/degree_distribution.txt'
|
|
||||||
visit_log_file = './re.log'
|
|
||||||
output_image_file = './paper_plot/figures/hnsw_visit_count_per_degree_corrected.pdf'
|
|
||||||
# --- CACHE FILE PATH: Keep this consistent ---
|
|
||||||
CACHE_FILE_PATH = './binned_plot_data_cache.pkl'
|
|
||||||
|
|
||||||
# --- Configuration ---
|
|
||||||
# Set to True to bypass cache and force recomputation.
|
|
||||||
# Otherwise, delete CACHE_FILE_PATH manually to force recomputation.
|
|
||||||
FORCE_RECOMPUTE = False
|
|
||||||
NUMBER_OF_QUERIES = 1000.0 # Number of queries the visit_counts are based on
|
|
||||||
|
|
||||||
# Create directory for figures if it doesn't exist
|
|
||||||
output_dir = os.path.dirname(output_image_file)
|
|
||||||
if output_dir and not os.path.exists(output_dir):
|
|
||||||
os.makedirs(output_dir)
|
|
||||||
print(f"Created directory: {output_dir}")
|
|
||||||
|
|
||||||
# %%
|
|
||||||
# --- Attempt to load data from cache or compute ---
|
|
||||||
df_plot_data = None
|
|
||||||
bin_size_for_plot = None # Will hold the bin_size associated with df_plot_data
|
|
||||||
|
|
||||||
if not FORCE_RECOMPUTE and os.path.exists(CACHE_FILE_PATH):
|
|
||||||
try:
|
|
||||||
with open(CACHE_FILE_PATH, 'rb') as f:
|
|
||||||
cache_content = pickle.load(f)
|
|
||||||
df_plot_data = cache_content['data']
|
|
||||||
bin_size_for_plot = cache_content['bin_size']
|
|
||||||
# Basic validation of cached data
|
|
||||||
# Expecting 'average_visit_count_per_node_in_bin' (raw average over NUMBER_OF_QUERIES)
|
|
||||||
if not isinstance(df_plot_data, pd.DataFrame) or \
|
|
||||||
'degree_bin_label' not in df_plot_data.columns or \
|
|
||||||
'average_visit_count_per_node_in_bin' not in df_plot_data.columns or \
|
|
||||||
not isinstance(bin_size_for_plot, int):
|
|
||||||
print("Cached data is not in the expected format or missing 'average_visit_count_per_node_in_bin'. Recomputing.")
|
|
||||||
df_plot_data = None # Invalidate to trigger recomputation
|
|
||||||
else:
|
|
||||||
print(f"Successfully loaded binned data from cache: {CACHE_FILE_PATH}")
|
|
||||||
|
|
||||||
# --- Modify the label loaded from cache for display purpose ---
|
|
||||||
# This modification only happens when data is loaded from cache and meets specific conditions.
|
|
||||||
# Assumption: If the bin_size_for_plot in cache is 5,
|
|
||||||
# then the original label "0-4" actually represents nodes with degree 1-4 (because you guarantee no 0-degree nodes).
|
|
||||||
if df_plot_data is not None and 'degree_bin_label' in df_plot_data.columns and bin_size_for_plot == 5:
|
|
||||||
# Check if "0-4" label exists
|
|
||||||
if '0-4' in df_plot_data['degree_bin_label'].values:
|
|
||||||
# Use .loc to ensure the modification is on the original DataFrame
|
|
||||||
df_plot_data.loc[df_plot_data['degree_bin_label'] == '0-4', 'degree_bin_label'] = '1-4'
|
|
||||||
print("Modified degree_bin_label from '0-4' to '1-4' for display purpose.")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error loading from cache: {e}. Recomputing.")
|
|
||||||
df_plot_data = None # Invalidate to trigger recomputation
|
|
||||||
|
|
||||||
if df_plot_data is None:
|
|
||||||
print("Cache not found, invalid, or recompute forced. Computing data from scratch...")
|
|
||||||
# --- 1. Read Degree Distribution File ---
|
|
||||||
degrees_data = []
|
|
||||||
try:
|
|
||||||
with open(degree_file, 'r') as f:
|
|
||||||
for i, line in enumerate(f):
|
|
||||||
line_stripped = line.strip()
|
|
||||||
if line_stripped:
|
|
||||||
degrees_data.append({'node_id': i, 'degree': int(line_stripped)})
|
|
||||||
except FileNotFoundError:
|
|
||||||
print(f"Error: Degree file '{degree_file}' not found. Using dummy data for degrees.")
|
|
||||||
degrees_data = [{'node_id': i, 'degree': (i % 20) + 1 } for i in range(200)]
|
|
||||||
degrees_data.extend([{'node_id': 200+i, 'degree': i} for i in range(58, 67)]) # For 60-64 bin
|
|
||||||
degrees_data.extend([{'node_id': 300+i, 'degree': (i % 5)+1} for i in range(10)]) # Low degrees
|
|
||||||
degrees_data.extend([{'node_id': 400+i, 'degree': 80 + (i%5)} for i in range(10)]) # High degrees
|
|
||||||
|
|
||||||
|
|
||||||
if not degrees_data:
|
|
||||||
print(f"Critical Error: No data loaded or generated for degrees. Exiting.")
|
|
||||||
exit()
|
|
||||||
df_degrees = pd.DataFrame(degrees_data)
|
|
||||||
print(f"Successfully loaded/generated {len(df_degrees)} degree entries.")
|
|
||||||
|
|
||||||
# --- 2. Read Visit Log File and Count Frequencies ---
|
|
||||||
visit_counts = Counter()
|
|
||||||
node_id_pattern = re.compile(r"Vis(i)?ted node: (\d+)")
|
|
||||||
try:
|
|
||||||
with open(visit_log_file, 'r') as f_log:
|
|
||||||
for line_num, line in enumerate(f_log, 1):
|
|
||||||
match = node_id_pattern.search(line)
|
|
||||||
if match:
|
|
||||||
try:
|
|
||||||
node_id = int(match.group(2))
|
|
||||||
visit_counts[node_id] += 1 # Increment visit count for the node
|
|
||||||
except ValueError:
|
|
||||||
print(f"Warning: Non-integer node_id in log '{visit_log_file}' line {line_num}: {line.strip()}")
|
|
||||||
except FileNotFoundError:
|
|
||||||
print(f"Warning: Visit log file '{visit_log_file}' not found. Using dummy visit counts.")
|
|
||||||
if not df_degrees.empty:
|
|
||||||
for node_id_val in df_degrees['node_id'].sample(frac=0.9, random_state=1234): # Seed for reproducibility
|
|
||||||
degree_val = df_degrees[df_degrees['node_id'] == node_id_val]['degree'].iloc[0]
|
|
||||||
# Generate visit counts to test different probability magnitudes
|
|
||||||
if node_id_val % 23 == 0: # Very low probability
|
|
||||||
lambda_val = 0.0005 * (100 / (max(1,degree_val) + 1)) # avg visits over 1k queries
|
|
||||||
elif node_id_val % 11 == 0: # Low probability
|
|
||||||
lambda_val = 0.05 * (100 / (max(1,degree_val) + 1))
|
|
||||||
elif node_id_val % 5 == 0: # Moderate probability
|
|
||||||
lambda_val = 2.5 * (100 / (max(1,degree_val) + 1))
|
|
||||||
else: # Higher probability (but still < 1000 visits for a single node usually)
|
|
||||||
lambda_val = 50 * (100 / (max(1,degree_val) + 1))
|
|
||||||
visit_counts[node_id_val] = np.random.poisson(lambda_val)
|
|
||||||
if visit_counts[node_id_val] < 0: visit_counts[node_id_val] = 0
|
|
||||||
|
|
||||||
if not visit_counts:
|
|
||||||
print(f"Warning: No visit data parsed/generated. Plot may show zero visits.")
|
|
||||||
df_visits = pd.DataFrame(columns=['node_id', 'visit_count'])
|
|
||||||
else:
|
|
||||||
df_visits_list = [{'node_id': nid, 'visit_count': count} for nid, count in visit_counts.items()]
|
|
||||||
df_visits = pd.DataFrame(df_visits_list)
|
|
||||||
print(f"Parsed/generated {len(df_visits)} unique visited nodes, totaling {sum(visit_counts.values())} visits (simulated over {NUMBER_OF_QUERIES} queries).")
|
|
||||||
|
|
||||||
# --- 3. Merge Degree Data with Visit Data ---
|
|
||||||
df_merged = pd.merge(df_degrees, df_visits, on='node_id', how='left')
|
|
||||||
df_merged['visit_count'] = df_merged['visit_count'].fillna(0).astype(float) # visit_count is total over NUMBER_OF_QUERIES
|
|
||||||
print(f"Merged data contains {len(df_merged)} entries.")
|
|
||||||
|
|
||||||
# --- 5. Binning Degrees and Calculating Average Visit Count per Node in Bin (over NUMBER_OF_QUERIES) ---
|
|
||||||
current_bin_size = 5
|
|
||||||
bin_size_for_plot = current_bin_size
|
|
||||||
|
|
||||||
if not df_degrees.empty:
|
|
||||||
print(f"\nBinning degrees into groups of {current_bin_size} for average visit count calculation...")
|
|
||||||
|
|
||||||
df_merged_with_bins = df_merged.copy()
|
|
||||||
df_merged_with_bins['degree_bin_start'] = (df_merged_with_bins['degree'] // current_bin_size) * current_bin_size
|
|
||||||
|
|
||||||
df_binned_analysis = df_merged_with_bins.groupby('degree_bin_start').agg(
|
|
||||||
total_visit_count_in_bin=('visit_count', 'sum'),
|
|
||||||
node_count_in_bin=('node_id', 'nunique')
|
|
||||||
).reset_index()
|
|
||||||
|
|
||||||
# This is the average number of times a node in this bin was visited over NUMBER_OF_QUERIES queries.
|
|
||||||
# This value is what gets cached.
|
|
||||||
df_binned_analysis['average_visit_count_per_node_in_bin'] = 0.0
|
|
||||||
df_binned_analysis.loc[df_binned_analysis['node_count_in_bin'] > 0, 'average_visit_count_per_node_in_bin'] = \
|
|
||||||
df_binned_analysis['total_visit_count_in_bin'] / df_binned_analysis['node_count_in_bin']
|
|
||||||
|
|
||||||
df_binned_analysis['degree_bin_label'] = df_binned_analysis['degree_bin_start'].astype(str) + '-' + \
|
|
||||||
(df_binned_analysis['degree_bin_start'] + current_bin_size - 1).astype(str)
|
|
||||||
|
|
||||||
bin_to_drop_label = '60-64'
|
|
||||||
original_length = len(df_binned_analysis)
|
|
||||||
df_plot_data_intermediate = df_binned_analysis[df_binned_analysis['degree_bin_label'] != bin_to_drop_label].copy()
|
|
||||||
if len(df_plot_data_intermediate) < original_length:
|
|
||||||
print(f"\nManually dropped the bin: '{bin_to_drop_label}'")
|
|
||||||
else:
|
|
||||||
print(f"\nNote: Bin '{bin_to_drop_label}' not found for dropping or already removed.")
|
|
||||||
|
|
||||||
df_plot_data = df_plot_data_intermediate
|
|
||||||
|
|
||||||
print(f"\nBinned data (average visit count per node in bin over {NUMBER_OF_QUERIES} queries) for plotting prepared:")
|
|
||||||
print(df_plot_data[['degree_bin_label', 'average_visit_count_per_node_in_bin']].head())
|
|
||||||
|
|
||||||
if df_plot_data is not None and not df_plot_data.empty:
|
|
||||||
try:
|
|
||||||
with open(CACHE_FILE_PATH, 'wb') as f:
|
|
||||||
pickle.dump({'data': df_plot_data, 'bin_size': bin_size_for_plot}, f)
|
|
||||||
print(f"Saved computed binned data to cache: {CACHE_FILE_PATH}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error saving data to cache: {e}")
|
|
||||||
elif df_plot_data is None or df_plot_data.empty:
|
|
||||||
print("Computed data for binned plot is empty, not saving to cache.")
|
|
||||||
else:
|
|
||||||
print("Degree data (df_degrees) is empty. Cannot perform binning.")
|
|
||||||
df_plot_data = pd.DataFrame()
|
|
||||||
bin_size_for_plot = current_bin_size
|
|
||||||
|
|
||||||
# %%
|
|
||||||
# --- 6. Plotting (Binned Bar Chart - Academic Style) ---
|
|
||||||
|
|
||||||
if df_plot_data is not None and not df_plot_data.empty and 'average_visit_count_per_node_in_bin' in df_plot_data.columns:
|
|
||||||
base_name, ext = os.path.splitext(output_image_file)
|
|
||||||
# --- OUTPUT PDF FILE NAME: Keep this consistent ---
|
|
||||||
binned_output_image_file = base_name + ext
|
|
||||||
|
|
||||||
fig, ax = plt.subplots(figsize=(6, 2.5)) # Adjusted figure size
|
|
||||||
|
|
||||||
df_plot_data_plotting = df_plot_data.copy()
|
|
||||||
# Calculate per-query probability: (avg visits over N queries) / N
|
|
||||||
df_plot_data_plotting['per_query_visit_probability'] = \
|
|
||||||
df_plot_data_plotting['average_visit_count_per_node_in_bin'] / NUMBER_OF_QUERIES
|
|
||||||
|
|
||||||
max_probability = df_plot_data_plotting['per_query_visit_probability'].max()
|
|
||||||
|
|
||||||
y_axis_values_to_plot = df_plot_data_plotting['per_query_visit_probability']
|
|
||||||
y_axis_label = r"Per-Query Node Visit Probability in Bin" # Base label
|
|
||||||
|
|
||||||
apply_scaling_to_label_and_values = False # Initialize flag
|
|
||||||
exponent_for_label_display = 0 # Initialize exponent
|
|
||||||
|
|
||||||
if pd.notna(max_probability) and max_probability > 0:
|
|
||||||
potential_exponent = math.floor(math.log10(max_probability))
|
|
||||||
|
|
||||||
if potential_exponent <= -4 or potential_exponent >= 0:
|
|
||||||
apply_scaling_to_label_and_values = True
|
|
||||||
exponent_for_label_display = potential_exponent
|
|
||||||
# No specific adjustment for potential_exponent >=0 here, it's handled by the general logic.
|
|
||||||
|
|
||||||
if apply_scaling_to_label_and_values:
|
|
||||||
y_axis_label = rf"Visit Probability ($\times 10^{{{exponent_for_label_display}}}$)"
|
|
||||||
y_axis_values_to_plot = df_plot_data_plotting['per_query_visit_probability'] / (10**exponent_for_label_display)
|
|
||||||
print(f"Plotting with Max per-query probability: {max_probability:.2e}, Exponent for label: {exponent_for_label_display}. Y-axis values scaled for plot.")
|
|
||||||
else:
|
|
||||||
print(f"Plotting with Max per-query probability: {max_probability:.2e}. Plotting direct probabilities without label scaling (exponent {potential_exponent} is within no-scale range [-3, -1]).")
|
|
||||||
|
|
||||||
elif pd.notna(max_probability) and max_probability == 0:
|
|
||||||
print("Max per-query probability is 0. Plotting direct probabilities (all zeros).")
|
|
||||||
else:
|
|
||||||
print(f"Max per-query probability is NaN or invalid ({max_probability}). Plotting direct probabilities without scaling if possible.")
|
|
||||||
|
|
||||||
ax.bar(
|
|
||||||
df_plot_data_plotting['degree_bin_label'],
|
|
||||||
y_axis_values_to_plot,
|
|
||||||
color='white',
|
|
||||||
edgecolor=edgecolors_ref[0],
|
|
||||||
linewidth=1.5,
|
|
||||||
width=0.8
|
|
||||||
)
|
|
||||||
|
|
||||||
ax.set_xlabel('Node Degree', fontsize=10.5, labelpad=6)
|
|
||||||
# MODIFIED LINE: Added labelpad to move the y-axis label to the left
|
|
||||||
ax.set_ylabel(y_axis_label, fontsize=10.5, labelpad=10)
|
|
||||||
|
|
||||||
ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, pos: f"{x:.0f}%"))
|
|
||||||
|
|
||||||
num_bins = len(df_plot_data_plotting)
|
|
||||||
if num_bins > 12:
|
|
||||||
ax.set_xticks(ax.get_xticks())
|
|
||||||
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right", fontsize=9)
|
|
||||||
elif num_bins > 8:
|
|
||||||
ax.tick_params(axis='x', labelsize=9)
|
|
||||||
else:
|
|
||||||
ax.tick_params(axis='x', labelsize=10)
|
|
||||||
|
|
||||||
ax.tick_params(axis='y', labelsize=10)
|
|
||||||
|
|
||||||
padding_factor = 0.05
|
|
||||||
current_max_y_on_axis = y_axis_values_to_plot.max()
|
|
||||||
|
|
||||||
upper_y_limit = 0.1 # Default small upper limit
|
|
||||||
if pd.notna(current_max_y_on_axis):
|
|
||||||
if current_max_y_on_axis > 0:
|
|
||||||
# Adjust minimum visible range based on whether scaling was applied and the exponent
|
|
||||||
min_meaningful_limit = 0.01
|
|
||||||
if apply_scaling_to_label_and_values and exponent_for_label_display >= 0 : # Numbers on axis are smaller due to positive exponent scaling
|
|
||||||
min_meaningful_limit = 0.1 # If original numbers were e.g. 2500 (2.5 x 10^3), scaled axis is 2.5, 0.1 is fine
|
|
||||||
elif not apply_scaling_to_label_and_values and pd.notna(max_probability) and max_probability >=1: # Direct large probabilities
|
|
||||||
min_meaningful_limit = 1 # If max prob is 2.5 (250%), axis value 2.5, needs larger base limit
|
|
||||||
|
|
||||||
upper_y_limit = max(min_meaningful_limit, current_max_y_on_axis * (1 + padding_factor))
|
|
||||||
|
|
||||||
else: # current_max_y_on_axis is 0
|
|
||||||
upper_y_limit = 0.1
|
|
||||||
ax.set_ylim(0, upper_y_limit)
|
|
||||||
else:
|
|
||||||
ax.set_ylim(0, 1.0) # Default for empty or NaN data
|
|
||||||
|
|
||||||
plt.tight_layout()
|
|
||||||
plt.savefig(binned_output_image_file, bbox_inches="tight", dpi=300)
|
|
||||||
print(f"Binned bar chart saved to {binned_output_image_file}")
|
|
||||||
plt.show()
|
|
||||||
plt.close(fig)
|
|
||||||
else:
|
|
||||||
if df_plot_data is None:
|
|
||||||
print("Data for plotting (df_plot_data) is None. Skipping plot generation.")
|
|
||||||
elif df_plot_data.empty:
|
|
||||||
print("Data for plotting (df_plot_data) is empty. Skipping plot generation.")
|
|
||||||
elif 'average_visit_count_per_node_in_bin' not in df_plot_data.columns:
|
|
||||||
print("Essential column 'average_visit_count_per_node_in_bin' is missing in df_plot_data. Skipping plot generation.")
|
|
||||||
|
|
||||||
# %%
|
|
||||||
print("Script finished.")
|
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
In this paper, we present LiteANN, a storage-efficient approximate nearest neighbor (ANN) search index optimized for resource-constrained personal devices. LiteANN combines a compact graph-based structure with an efficient on-the-fly recomputation strategy to enable fast and accurate retrieval wih minimal storage overhead. Our evaluation shows that LiteANN reduces index size to under 5% of the original raw data – up to 50× smaller than standard indexes – while achieving 90% top-3 recall in under 2 seconds on real-world question-answering benchmarks.
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -1,81 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
import os
|
|
||||||
|
|
||||||
# --- Configuration for Data Paths and Labels (Mirrors plotting script for consistency) ---
|
|
||||||
BIG_GRAPH_PATHS = [
|
|
||||||
"/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/hnsw/",
|
|
||||||
"/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/99_4_degree_based_hnsw_IP_M32_efC256/",
|
|
||||||
"/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/d9_hnsw_IP_M8_efC128/",
|
|
||||||
"/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/half_edges_IP_M32_efC128/"
|
|
||||||
]
|
|
||||||
STATS_FILE_NAME = "degree_distribution.txt"
|
|
||||||
BIG_GRAPH_LABELS = [ # These will be used as keys in the cached file
|
|
||||||
"HNSW-Base",
|
|
||||||
"DegreeGuide",
|
|
||||||
"HNSW-D9",
|
|
||||||
"RandCut",
|
|
||||||
]
|
|
||||||
# Average degrees are static and can be directly used in the plotting script or also cached.
|
|
||||||
# For simplicity here, we'll focus on caching the dynamic degree arrays.
|
|
||||||
# BIG_GRAPH_AVG_DEG = [18, 9, 9, 9]
|
|
||||||
|
|
||||||
# --- Cache File Configuration ---
|
|
||||||
DATA_CACHE_DIR = "./paper_plot/data/"
|
|
||||||
CACHE_FILE_NAME = "big_graph_degree_data.npz" # Using .npz for multiple arrays
|
|
||||||
|
|
||||||
def create_degree_data_cache():
|
|
||||||
"""
|
|
||||||
Reads degree distribution data from specified text files and saves it
|
|
||||||
into a compressed NumPy (.npz) cache file.
|
|
||||||
"""
|
|
||||||
os.makedirs(DATA_CACHE_DIR, exist_ok=True)
|
|
||||||
cache_file_path = os.path.join(DATA_CACHE_DIR, CACHE_FILE_NAME)
|
|
||||||
|
|
||||||
cached_data = {}
|
|
||||||
print(f"Starting data caching process for {len(BIG_GRAPH_PATHS)} graph types...")
|
|
||||||
|
|
||||||
for i, base_path in enumerate(BIG_GRAPH_PATHS):
|
|
||||||
method_label = BIG_GRAPH_LABELS[i]
|
|
||||||
degree_file_path = os.path.join(base_path, STATS_FILE_NAME)
|
|
||||||
|
|
||||||
print(f"Processing: {method_label} from {degree_file_path}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Load degrees as integers
|
|
||||||
degrees = np.loadtxt(degree_file_path, dtype=int)
|
|
||||||
|
|
||||||
if degrees.size == 0:
|
|
||||||
print(f" [WARN] Degree file is empty: {degree_file_path}. Storing as empty array for {method_label}.")
|
|
||||||
# Store an empty array or handle as needed. For npz, an empty array is fine.
|
|
||||||
cached_data[method_label] = np.array([], dtype=int)
|
|
||||||
else:
|
|
||||||
# Store the loaded degrees array with the method label as the key
|
|
||||||
cached_data[method_label] = degrees
|
|
||||||
print(f" [INFO] Loaded {len(degrees)} degrees for {method_label}. Max degree: {np.max(degrees) if degrees.size > 0 else 'N/A'}")
|
|
||||||
|
|
||||||
except FileNotFoundError:
|
|
||||||
print(f" [ERROR] Degree file not found: {degree_file_path}. Skipping {method_label}.")
|
|
||||||
# Optionally store a placeholder or skip. For robustness, store None or an empty array.
|
|
||||||
# Storing None might require special handling when loading. Empty array is safer for np.load.
|
|
||||||
cached_data[method_label] = np.array([], dtype=int) # Store empty array if file not found
|
|
||||||
except Exception as e:
|
|
||||||
print(f" [ERROR] An error occurred loading {degree_file_path} for {method_label}: {e}")
|
|
||||||
cached_data[method_label] = np.array([], dtype=int) # Store empty array on other errors
|
|
||||||
|
|
||||||
if not cached_data:
|
|
||||||
print("[ERROR] No data was successfully processed or loaded. Cache file will not be created.")
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Save all collected degree arrays into a single .npz file.
|
|
||||||
# Using savez_compressed for potentially smaller file size.
|
|
||||||
np.savez_compressed(cache_file_path, **cached_data)
|
|
||||||
print(f"\n[SUCCESS] Degree distribution data successfully cached to: {os.path.abspath(cache_file_path)}")
|
|
||||||
print("Cached arrays (keys):", list(cached_data.keys()))
|
|
||||||
except Exception as e:
|
|
||||||
print(f"\n[ERROR] Failed to save data to cache file {cache_file_path}: {e}")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
print("--- Degree Distribution Data Caching Script ---")
|
|
||||||
create_degree_data_cache()
|
|
||||||
print("--- Caching script finished. ---")
|
|
||||||
@@ -1,4 +0,0 @@
|
|||||||
Model,NQ Exact Match,NQ F1,TriviaQA Exact Match,TriviaQA F1,GPQA Exact Match,GPQA F1,HotpotQA Exact Match,HotpotQA F1
|
|
||||||
BM25,0.192,0.277,0.406,0.474,0.020089,0.04524,0.162,0.239
|
|
||||||
PQ 5,0.2075,0.291,0.422,0.495,0.0201,0.0445,0.148,0.219
|
|
||||||
Ours,0.265,0.361,0.533,0.604,0.02008,0.0452,0.182,0.2729
|
|
||||||
|
@@ -1,3 +0,0 @@
|
|||||||
version https://git-lfs.github.com/spec/v1
|
|
||||||
oid sha256:1296720e79196bbdf38f051043c1b054667803726a24036c0b6a87cedb204ea5
|
|
||||||
size 227482438
|
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
2,1,512,1024,0.541,0.326,1.659509202
|
|
||||||
2,2,512,1024,0.979,0.621,1.576489533
|
|
||||||
2,4,512,1024,1.846,0.977,1.889457523
|
|
||||||
2,8,512,1024,3.575,1.943,1.83993824
|
|
||||||
2,16,512,1024,7.035,3.733,1.884543263
|
|
||||||
2,32,512,1024,15.655,8.517,1.838088529
|
|
||||||
2,64,512,1024,32.772,17.43,1.88020654
|
|
||||||
4,1,512,1024,2.675,1.38,1.938405797
|
|
||||||
4,2,512,1024,5.397,2.339,2.307396323
|
|
||||||
4,4,512,1024,10.672,4.944,2.158576052
|
|
||||||
4,8,512,1024,21.061,9.266,2.272933305
|
|
||||||
4,16,512,1024,46.332,18.334,2.527108105
|
|
||||||
4,32,512,1024,99.607,36.156,2.754923111
|
|
||||||
4,64,512,1024,186.348,72.356,2.575432583
|
|
||||||
8,1,512,1024,7.325,4.087,1.792268167
|
|
||||||
8,2,512,1024,14.109,7.491,1.883460152
|
|
||||||
8,4,512,1024,28.499,14.013,2.033754371
|
|
||||||
8,8,512,1024,65.222,27.453,2.375769497
|
|
||||||
8,16,512,1024,146.294,52.55,2.783901047
|
|
||||||
8,32,512,1024,277.099,103.61,2.674442621
|
|
||||||
8,64,512,1024,512.979,208.36,2.461984066
|
|
||||||
|
@@ -1,9 +0,0 @@
|
|||||||
Dataset,Metric,Original,original + batch,original + two_level,original + two_level + batch
|
|
||||||
NQ,Latency,6.9,5.8,4.2,3.7
|
|
||||||
NQ,SpeedUp,1,1.18965517,1.64285714,1.86486486
|
|
||||||
TriviaQA,Latency,17.054,14.542,12.046,10.83
|
|
||||||
TriviaQA,SpeedUp,1,1.17274103,1.41573967,1.57469990
|
|
||||||
GPQA,Latency,9.164,7.639,6.798,5.77
|
|
||||||
GPQA,SpeedUp,1,1.19963346,1.34804354,1.58821490
|
|
||||||
HotpotQA,Latency,60.279,39.827,50.664,29.868
|
|
||||||
HotpotQA,SpeedUp,1,1.51352098,1.18977972,2.01817999
|
|
||||||
|
@@ -1,25 +0,0 @@
|
|||||||
Dataset,Hardware,Recall_target,HNSW,IVF,DiskANN,IVF-Disk,IVF-Recompute,Our,BM25,LLM_Gen_Time_1B,LLM_Gen_Time_3B,LLM_Gen_Time_7B
|
|
||||||
NQ,A10,85%,0.046,1.656,0.017,2.996,482.53,3.323,0.021,0.085,0.217,0.472
|
|
||||||
NQ,A10,90%,0.051,2.552,0.028,3.437,769.04,4.616,0,0.085,0.217,0.472
|
|
||||||
NQ,A10,95%,0.055,5.163,0.070,5.602,1436.26,19.494,0,0.085,0.217,0.472
|
|
||||||
NQ,MAC,85%,0,0,0.152,2.199,1535.10,7.971,0.033,0.316,0.717,1.468
|
|
||||||
NQ,MAC,90%,0,0,0.37,2.936,2446.60,13.843,0,0.316,0.717,1.468
|
|
||||||
NQ,MAC,95%,0,0,1.207,4.191,4569.29,44.363,0,0.316,0.717,1.468
|
|
||||||
TriviaQA,A10,85%,0.042,1.772,0.032,2.464,560.5,3.752,0.033,0.139,0.156,0.315
|
|
||||||
TriviaQA,A10,90%,0.043,3.541,0.057,3.651,997.81,5.777,0,0.139,0.156,0.315
|
|
||||||
TriviaQA,A10,95%,0.053,7.168,0.090,5.458,2005.33,20.944,0,0.139,0.156,0.315
|
|
||||||
TriviaQA,MAC,85%,0,0,0.481,1.875,1783.14787,8.889,0.036,0.325,0.692,1.415
|
|
||||||
TriviaQA,MAC,90%,0,0,0.984,2.639,3174.410301,17.145,0,0.325,0.692,1.415
|
|
||||||
TriviaQA,MAC,95%,0,0,1.578,3.884,6379.712245,47.909,0,0.325,0.692,1.415
|
|
||||||
GPQA,A10,85%,0.041,0.134,0.024,0.048,40.16,1.897,0.137,0.443,0.396,0.651
|
|
||||||
GPQA,A10,90%,0.042,0.174,0.034,0.06,54.71,1.733,0,0.443,0.396,0.651
|
|
||||||
GPQA,A10,95%,0.045,0.292,0.051,0.11,97.67,4.033,0,0.443,0.396,0.651
|
|
||||||
GPQA,MAC,85%,0,0,0.144,0.087,127.7707505,4.762,0.100,0.37,0.813,1.676
|
|
||||||
GPQA,MAC,90%,0,0,0.288,0.108,174.0647409,5.223,0,0.37,0.813,1.676
|
|
||||||
GPQA,MAC,95%,0,0,0.497,0.132,310.7380142,9.715,0,0.37,0.813,1.676
|
|
||||||
HotpotQA,A10,85%,0.044,2.519,0.054,4.048,724.26,10.358,0.70,0.144,0.196,0.420
|
|
||||||
HotpotQA,A10,90%,0.049,3.867,0.109,5.045,1173.67,15.515,0,0.144,0.196,0.420
|
|
||||||
HotpotQA,A10,95%,0.07,10.928,0.412,8.659,3079.57,61.757,0,0.144,0.196,0.420
|
|
||||||
HotpotQA,MAC,85%,0,0,0.974,2.844,2304.125187,23.636,0.052,0.144,0.196,0.420
|
|
||||||
HotpotQA,MAC,90%,0,0,1.913,3.542,3415.736201,44.803,0,0.144,0.196,0.420
|
|
||||||
HotpotQA,MAC,95%,0,0,5.783,6.764,9797.244043,140.62,0,0.144,0.196,0.420
|
|
||||||
|
@@ -1,25 +0,0 @@
|
|||||||
Dataset,Hardware,Recall_target,HNSW,IVF,DiskANN,IVF-Disk,IVF-Recompute,Our,
|
|
||||||
NQ,A10,85%,0.046,1.656,0.017,2.996,482.53,4.243,
|
|
||||||
NQ,A10,90%,0.051,2.552,0.028,3.437,769.04,8.136,
|
|
||||||
NQ,A10,95%,0.055,5.163,0.070,5.602,1436.26,27.275,
|
|
||||||
NQ,MAC,85%,0,0,0.152,2.199,1535.10,10.672,
|
|
||||||
NQ,MAC,90%,0,0,0.37,2.936,2446.60,19.941,
|
|
||||||
NQ,MAC,95%,0,0,1.207,4.191,4569.29,61.383,
|
|
||||||
TriviaQA,A10,85%,0.042,1.772,0.032,2.464,560.5,5.612,
|
|
||||||
TriviaQA,A10,90%,0.043,3.541,0.057,3.651,997.81,10.737,
|
|
||||||
TriviaQA,A10,95%,0.053,7.168,0.090,5.458,2005.33,36.387,
|
|
||||||
TriviaQA,MAC,85%,0,0,0.481,1.875,1783.14787,12.825,
|
|
||||||
TriviaQA,MAC,90%,0,0,0.984,2.639,3174.410301,24.977,
|
|
||||||
TriviaQA,MAC,95%,0,0,1.578,3.884,6379.712245,85.734,
|
|
||||||
GPQA,A10,85%,0.041,0.134,0.024,0.048,40.16,2.269,
|
|
||||||
GPQA,A10,90%,0.042,0.174,0.034,0.06,54.71,3.200,
|
|
||||||
GPQA,A10,95%,0.045,0.292,0.051,0.11,97.67,7.445,
|
|
||||||
GPQA,MAC,85%,0,0,0.144,0.087,127.7707505,6.123,
|
|
||||||
GPQA,MAC,90%,0,0,0.288,0.108,174.0647409,8.507,
|
|
||||||
GPQA,MAC,95%,0,0,0.497,0.132,310.7380142,19.577,
|
|
||||||
HotpotQA,A10,85%,0.044,2.519,0.054,4.048,724.26,14.713,
|
|
||||||
HotpotQA,A10,90%,0.049,3.867,0.109,5.045,1173.67,33.561,
|
|
||||||
HotpotQA,A10,95%,0.07,10.928,0.412,8.659,3079.57,68.626,
|
|
||||||
HotpotQA,MAC,85%,0,0,0.974,2.844,2304.125187,34.783,
|
|
||||||
HotpotQA,MAC,90%,0,0,1.913,3.542,3415.736201,53.004,
|
|
||||||
HotpotQA,MAC,95%,0,0,5.783,6.764,9797.244043,95.413,
|
|
||||||
|
@@ -1,3 +0,0 @@
|
|||||||
Hardware,HNSW,IVF,DiskANN,IVF-Disk,IVF-Recompute,Our,BM25
|
|
||||||
RAM,190,171,10,0,0,0,0
|
|
||||||
Storage,185.4,171,240,171,0.5,5,59
|
|
||||||
|
@@ -1,12 +0,0 @@
|
|||||||
Torch,8,55.592
|
|
||||||
Torch,16,75.439
|
|
||||||
Torch,32,110.025
|
|
||||||
Torch,64,186.496
|
|
||||||
Tutel,8,56.718
|
|
||||||
Tutel,16,82.121
|
|
||||||
Tutel,32,125.070
|
|
||||||
Tutel,64,216.191
|
|
||||||
BRT,8,56.725
|
|
||||||
BRT,16,79.291
|
|
||||||
BRT,32,93.180
|
|
||||||
BRT,64,118.923
|
|
||||||
|
@@ -1,6 +0,0 @@
|
|||||||
Disk cache size,0,2.5%(180G*2.5%),5%,8%,10%
|
|
||||||
Latency,,,,,
|
|
||||||
NQ,4.616,4.133,3.826,3.511,3.323
|
|
||||||
TriviaQA,5.777,4.979,4.553,4.141,3.916
|
|
||||||
GPQA,1.733,1.593,1.468,1.336,1.259
|
|
||||||
Hotpot,15.515,13.479,12.383,11.216,10.606
|
|
||||||
|
@@ -1,151 +0,0 @@
|
|||||||
import matplotlib
|
|
||||||
from matplotlib.axes import Axes
|
|
||||||
import numpy as np
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import matplotlib.patches as mpatches
|
|
||||||
from matplotlib.lines import Line2D
|
|
||||||
|
|
||||||
# plt.rcParams["font.family"] = "Helvetica"
|
|
||||||
plt.rcParams["ytick.direction"] = "in"
|
|
||||||
plt.rcParams["hatch.linewidth"] = 1
|
|
||||||
plt.rcParams["font.weight"] = "bold"
|
|
||||||
plt.rcParams["axes.labelweight"] = "bold"
|
|
||||||
plt.rcParams["text.usetex"] = True
|
|
||||||
plt.rcParams["font.family"] = "sans-serif" # Use generic sans-serif family
|
|
||||||
plt.rcParams['text.latex.preamble'] = r"""
|
|
||||||
\usepackage{helvet} % Use Helvetica font for text
|
|
||||||
\usepackage{sfmath} % Use sans-serif font for math
|
|
||||||
\renewcommand{\familydefault}{\sfdefault} % Set sans-serif as default text font
|
|
||||||
\usepackage[T1]{fontenc} % Recommended for font encoding
|
|
||||||
"""
|
|
||||||
# plt.rcParams['mathtext.fontset'] = 'dejavusans'
|
|
||||||
SAVE_PTH = "./paper_plot/figures"
|
|
||||||
font_size = 16
|
|
||||||
|
|
||||||
# New data in dictionary format
|
|
||||||
datasets = ["NQ", "TriviaQA", "GPQA", "Hotpot"]
|
|
||||||
|
|
||||||
cache_ratios = ["4.2G\n (0\%)", "8.7G\n (2.5\%)", "13.2G\n (5\%)", "18.6G\n (8\%)", "22.2G\n (10\%)"]
|
|
||||||
latency_data = {
|
|
||||||
"NQ": [4.616, 4.133, 3.826, 3.511, 3.323],
|
|
||||||
"TriviaQA": [5.777, 4.979, 4.553, 4.141, 3.916],
|
|
||||||
"GPQA": [1.733, 1.593, 1.468, 1.336, 1.259],
|
|
||||||
"Hotpot": [15.515, 13.479, 12.383, 11.216, 10.606],
|
|
||||||
}
|
|
||||||
cache_hit_counts = {
|
|
||||||
"NQ": [0, 14.81, 23.36, 31.99, 36.73],
|
|
||||||
"TriviaQA": [0, 18.55, 27.99, 37.06, 41.86],
|
|
||||||
"GPQA": [0, 10.99, 20.31, 29.71, 35.01],
|
|
||||||
"Hotpot": [0, 17.47, 26.91, 36.2, 41.06]
|
|
||||||
}
|
|
||||||
|
|
||||||
# Create the figure with 4 subplots in a 2x2 grid
|
|
||||||
fig, axes_grid = plt.subplots(2, 2, figsize=(7,6))
|
|
||||||
axes = axes_grid.flatten() # Flatten the 2x2 grid to a 1D array
|
|
||||||
|
|
||||||
# Bar style settings
|
|
||||||
width = 0.7
|
|
||||||
x = np.arange(len(cache_ratios))
|
|
||||||
|
|
||||||
# Define hatch patterns for different cache ratios
|
|
||||||
hatch_patterns = ['//', '//', '//', '//', '//']
|
|
||||||
|
|
||||||
# Find max cache hit value across all datasets for unified y-axis
|
|
||||||
all_hit_counts = []
|
|
||||||
for dataset in datasets:
|
|
||||||
all_hit_counts.extend(cache_hit_counts[dataset])
|
|
||||||
max_unified_hit = max(all_hit_counts) * 1.13
|
|
||||||
|
|
||||||
for i, dataset in enumerate(datasets):
|
|
||||||
latencies = latency_data[dataset]
|
|
||||||
hit_counts = cache_hit_counts[dataset]
|
|
||||||
|
|
||||||
for j, val in enumerate(latencies):
|
|
||||||
container = axes[i].bar(
|
|
||||||
x[j],
|
|
||||||
val,
|
|
||||||
width=width,
|
|
||||||
color="white",
|
|
||||||
edgecolor="black",
|
|
||||||
linewidth=1.0,
|
|
||||||
zorder=10,
|
|
||||||
)
|
|
||||||
axes[i].bar_label(
|
|
||||||
container,
|
|
||||||
[f"{val:.2f}"],
|
|
||||||
fontsize=10,
|
|
||||||
zorder=200,
|
|
||||||
fontweight="bold",
|
|
||||||
)
|
|
||||||
|
|
||||||
axes[i].set_title(dataset, fontsize=font_size)
|
|
||||||
axes[i].set_xticks(x)
|
|
||||||
axes[i].set_xticklabels(cache_ratios, fontsize=12, rotation=0, ha='center', fontweight="bold")
|
|
||||||
|
|
||||||
max_val_ratios = [1.35, 1.65, 1.45, 1.75]
|
|
||||||
max_val = max(latencies) * max_val_ratios[i]
|
|
||||||
axes[i].set_ylim(0, max_val)
|
|
||||||
axes[i].tick_params(axis='y', labelsize=12)
|
|
||||||
|
|
||||||
if i % 2 == 0:
|
|
||||||
axes[i].set_ylabel("Latency (s)", fontsize=font_size)
|
|
||||||
axes[i].yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter('%.1f'))
|
|
||||||
|
|
||||||
ax2: Axes = axes[i].twinx()
|
|
||||||
ax2.plot(x, hit_counts,
|
|
||||||
linestyle='--',
|
|
||||||
marker='o',
|
|
||||||
markersize=6,
|
|
||||||
linewidth=1.5,
|
|
||||||
color='k',
|
|
||||||
markerfacecolor='none',
|
|
||||||
zorder=20)
|
|
||||||
|
|
||||||
ax2.set_ylim(0, max_unified_hit)
|
|
||||||
ax2.tick_params(axis='y', labelsize=12)
|
|
||||||
if i % 2 == 1:
|
|
||||||
ax2.set_ylabel(r"Cache Hit (\%)", fontsize=font_size)
|
|
||||||
|
|
||||||
for j, val in enumerate(hit_counts):
|
|
||||||
if val > 0:
|
|
||||||
ax2.annotate(f"{val:.1f}%",
|
|
||||||
(x[j], val),
|
|
||||||
textcoords="offset points",
|
|
||||||
xytext=(0, 5),
|
|
||||||
ha='center',
|
|
||||||
va='bottom',
|
|
||||||
fontsize=10,
|
|
||||||
fontweight='bold')
|
|
||||||
|
|
||||||
# Create legend for both plots
|
|
||||||
bar_patch = mpatches.Patch(facecolor='white', edgecolor='black', label='Latency')
|
|
||||||
line_patch = Line2D([0], [0], color='black', linestyle='--', label='Cache Hit Rate')
|
|
||||||
|
|
||||||
# --- MODIFICATION FOR LEGEND AT THE TOP ---
|
|
||||||
fig.legend(handles=[bar_patch, line_patch],
|
|
||||||
loc='upper center', # Position the legend at the upper center
|
|
||||||
bbox_to_anchor=(0.5, 0.995), # Anchor point (0.5 means horizontal center of figure,
|
|
||||||
# 0.97 means 97% from the bottom, so near the top)
|
|
||||||
ncol=3,
|
|
||||||
fontsize=font_size-2)
|
|
||||||
# --- END OF MODIFICATION ---
|
|
||||||
|
|
||||||
# Set common x-axis label - you might want to add this back if needed
|
|
||||||
# fig.text(0.5, 0.02, "Disk Cache Size", ha='center', fontsize=font_size, fontweight='bold') # Adjusted y for potential bottom label
|
|
||||||
|
|
||||||
# --- MODIFICATION FOR TIGHT LAYOUT ---
|
|
||||||
# Adjust rect to make space for the legend at the top.
|
|
||||||
# (left, bottom, right, top_for_subplots)
|
|
||||||
# We want subplots to occupy space from y=0 up to y=0.93 (or similar)
|
|
||||||
# leaving the top portion (0.93 to 1.0) for the legend.
|
|
||||||
plt.tight_layout(rect=(0, 0, 1, 0.93)) # Ensure subplots are below the legend
|
|
||||||
# --- END OF MODIFICATION ---
|
|
||||||
|
|
||||||
# Create directory if it doesn't exist (optional, good practice)
|
|
||||||
import os
|
|
||||||
if not os.path.exists(SAVE_PTH):
|
|
||||||
os.makedirs(SAVE_PTH)
|
|
||||||
|
|
||||||
plt.savefig(f"{SAVE_PTH}/disk_cache_latency.pdf", dpi=300) # Changed filename slightly for testing
|
|
||||||
print(f"Save to {SAVE_PTH}/disk_cache_latency.pdf")
|
|
||||||
# plt.show() # Optional: to display the plot
|
|
||||||
Binary file not shown.
Binary file not shown.
|
Before Width: | Height: | Size: 130 KiB |
Binary file not shown.
Binary file not shown.
|
Before Width: | Height: | Size: 100 KiB |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user