Compare commits

..

93 Commits
apps ... v0.1.9

Author SHA1 Message Date
Andy Lee
fc9c5cb39d fix: make release workflow idempotent
- Check if version is already updated before trying to update
- Check if tag already exists before creating
- Check if GitHub release already exists before creating
- This allows re-running the workflow after partial failures

Previously, if the workflow failed after updating version but before
completing the release, it couldn't be re-run with the same version.
2025-07-25 14:47:35 -07:00
Andy Lee
8f2a1e87ea Merge pull request #7 from yichuan-w/fix/simple-ubuntu22-build
fix: simplify build system for Colab compatibility
2025-07-25 14:08:37 -07:00
Andy Lee
50caf65f28 fix: change ubuntu-latest to ubuntu-22.04 and add Python 3.13
- Explicitly use ubuntu-22.04 instead of ubuntu-latest
- Add Python 3.13 to the build matrix
- This ensures we build on the same OS version as Google Colab
2025-07-25 13:48:59 -07:00
Andy Lee
1b48794ca8 cleanup: remove cibuildwheel workflow files
- Remove ci-cibuildwheel.yml and build-cibuildwheel.yml
- These files were not present in v0.1.5
- Keep only the simple build system
2025-07-25 13:48:08 -07:00
Andy Lee
4aef1d814e revert: simplify build system by removing manylinux/cibuildwheel
- Revert to simple Ubuntu 22.04 builds that should work with Colab
- Remove all manylinux container complexity
- Colab runs on Ubuntu 22.04, so direct builds should be compatible
- Restore build-reusable.yml to v0.1.5 version
- Remove cibuildwheel option from release workflow

This should fix the overcomplicated build issues while maintaining
Colab compatibility through direct Ubuntu 22.04 builds.
2025-07-25 13:46:51 -07:00
GitHub Actions
75ddcd6158 chore: release v0.1.9 2025-07-25 20:04:42 +00:00
Andy Lee
2a4df11f5c fix: absolute path for passages 2025-07-25 11:59:30 -07:00
Andy Lee
5eb893c62b ci: add Python 3.13 support to build matrix 2025-07-25 09:53:36 -07:00
yichuan520030910320
d91ce2e94d readme 2025-07-25 02:19:54 -07:00
yichuan520030910320
5c2ff8a641 clean research stuff 2025-07-25 02:14:15 -07:00
yichuan520030910320
d4f474c9b7 update broken link 2025-07-25 02:13:22 -07:00
yichuan520030910320
170f7644e9 simplify readme 2025-07-25 02:11:02 -07:00
yichuan520030910320
cd8b970eff Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-25 01:45:57 -07:00
yichuan520030910320
52153bbb69 update faiss compare 2025-07-25 01:45:50 -07:00
GitHub Actions
e1ae087207 chore: release v0.1.8 2025-07-25 08:24:40 +00:00
Andy Lee
48c5e12ac1 fix: use absolute path for passages_file to prevent FileNotFoundError
When embedding server is launched as a subprocess, it may run in a different
working directory. Using absolute paths ensures the server can always find
the metadata file regardless of where it's launched from.
2025-07-25 01:23:47 -07:00
yichuan520030910320
f8b5c97190 Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-25 00:37:33 -07:00
yichuan520030910320
d038c81b8b update benchmard section 2025-07-25 00:37:27 -07:00
Andy Lee
29cbbbd0d6 fix: resolve libzmq pkg-config issues in manylinux containers
- Add gcc-c++ and cmake to dependencies
- Create libzmq.pc file if missing (CentOS 7 issue)
- Set PKG_CONFIG_PATH through CIBW_ENVIRONMENT_LINUX
- Add protobuf-devel to ensure all headers are available
- Fix shell variable escaping in heredoc
2025-07-25 00:35:52 -07:00
Andy Lee
179f30bc36 fix: improve system dependency installation in manylinux containers
- Add yum cache cleaning and updating
- Make package installations more resilient with fallbacks
- Use pkgconfig instead of pkg-config (CentOS 7 naming)
- Handle optional packages that might not be available
- Add error handling for package installation failures
2025-07-25 00:30:29 -07:00
Andy Lee
c4a0a68581 fix: handle pure Python packages in cibuildwheel workflow
- Build pure Python packages (leann-core, leann) with standard build tool
- Only use cibuildwheel for C extension packages (leann-backend-hnsw, leann-backend-diskann)
- Build pure Python packages only once on ubuntu-latest
- Add Python setup for building pure packages
- Add package listing step for debugging
2025-07-25 00:26:15 -07:00
Andy Lee
5c836ad08e fix: handle git dubious ownership error in manylinux containers
- Add multiple safe.directory configurations to cover different possible paths
- This fixes 'detected dubious ownership in repository' error
- Ensures git works properly in manylinux2014 containers
2025-07-25 00:22:01 -07:00
Andy Lee
673fd9b7cd fix: upgrade to actions v4 and handle manylinux2014 compatibility
- Upgrade all GitHub Actions to v4 (v3 is deprecated)
- Use manual git checkout in manylinux2014 containers to avoid Node.js issues
- Update artifact naming to ensure uniqueness (required by v4)
- Add fail-fast: false to build strategies
- This maintains manylinux2014 compatibility while using latest actions
2025-07-25 00:20:21 -07:00
Andy Lee
84b24b233d feat: add cibuildwheel option to release workflow
- Add optional use_cibuildwheel parameter to release workflow
- Create separate CI workflow for testing cibuildwheel
- Support conditional build workflow selection in release process
- This allows building wheels compatible with Google Colab and older systems
- Maintains backward compatibility with existing build process
2025-07-25 00:16:08 -07:00
Andy Lee
499cdd7822 feat: add cibuildwheel workflow for better platform compatibility
- Use cibuildwheel for professional wheel building
- Specifically use manylinux2014 for Google Colab compatibility
- Supports Python 3.9-3.12 on Linux and macOS
- Handles monorepo structure with separate builds per package
- Includes basic import tests for each package
- This should resolve compatibility issues with older systems like Google Colab
2025-07-25 00:16:08 -07:00
yichuan520030910320
800d4cf111 Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-25 00:12:47 -07:00
yichuan520030910320
b6d43f5fd9 add gif 2025-07-25 00:12:35 -07:00
Andy Lee
3603cd5034 fix: downgrade GitHub Actions versions for manylinux2014 compatibility
- Use actions/checkout@v3 instead of v4 (Node.js 16 vs 20)
- Use actions/setup-python@v4 instead of v5
- Use actions/upload-artifact@v3 and download-artifact@v3
- This fixes GLIBC version errors in manylinux2014 containers
- manylinux2014 (CentOS 7) has glibc 2.17 but Node.js 20 needs 2.25+
2025-07-25 00:12:05 -07:00
Andy Lee
6df7893173 feat: use manylinux2014 containers for better Linux compatibility
- Add manylinux2014 Docker containers for Linux builds
- This will generate wheels compatible with older Linux systems (CentOS 7+, Ubuntu 16.04+)
- Separate build logic for container vs regular environments
- Install appropriate system dependencies for yum-based manylinux environment
- Use pip instead of uv in containers for better compatibility
- Fix Python version format for manylinux container paths
2025-07-25 00:08:42 -07:00
GitHub Actions
e64b599276 chore: release v0.1.7 2025-07-25 04:47:57 +00:00
Andy Lee
2dd59c4ba1 fix: let auditwheel auto-detect manylinux platform tag
- Remove --plat manylinux2014_x86_64 flag that was causing build failures
- Let auditwheel automatically determine the appropriate manylinux tag
- Add auditwheel show command to display compatibility info
- This fixes the 'too-recent versioned symbols' error
2025-07-24 21:44:15 -07:00
GitHub Actions
166986d5e6 chore: release v0.1.6 2025-07-25 04:30:07 +00:00
Andy Lee
a6aec68f32 fix: use manylinux2014 for better Linux compatibility
- Change auditwheel --plat to manylinux2014_x86_64
- This ensures wheels work on Ubuntu 16.04+ instead of requiring 24.04+
- Fixes compatibility issues for users on Ubuntu 22.04 and similar systems
2025-07-24 21:26:28 -07:00
GitHub Actions
ed27a127d5 chore: release v0.1.5 2025-07-25 04:00:54 +00:00
Andy Lee
d8b4ea7564 fix: add write permissions for GitHub Actions to push commits 2025-07-24 20:55:24 -07:00
Andy Lee
f0a2ef96b4 fix: restore complete build configuration from working version 2025-07-24 19:49:38 -07:00
Andy Lee
7d73c2c803 fix: remove invalid --extra build flag from build commands 2025-07-24 19:43:23 -07:00
Andy Lee
e8d2ecab03 refactor: use reusable workflow to avoid code duplication 2025-07-24 19:35:12 -07:00
Andy Lee
32a374d094 feat: true one-click automated release with multi-platform support 2025-07-24 19:30:44 -07:00
Andy Lee
d45c013806 fix: handle workflow trigger permission gracefully 2025-07-24 19:25:29 -07:00
GitHub Actions
9000a7083d chore: release v0.1.4 2025-07-25 02:23:36 +00:00
Andy Lee
8307555d54 fix: manually trigger CI after version push in release workflow 2025-07-24 19:21:32 -07:00
GitHub Actions
20f2aece08 chore: release v0.1.3 2025-07-25 02:05:11 +00:00
yichuan520030910320
43eb4f9a1d Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-24 19:03:52 -07:00
yichuan520030910320
5461b71d8c colab dev 2025-07-24 19:03:46 -07:00
Andy Lee
374db0ebb8 fix: release workflow to build new version before publishing 2025-07-24 19:03:09 -07:00
GitHub Actions
cea1f6f87c chore: release v0.1.2 2025-07-25 01:53:29 +00:00
Andy Lee
6c0e39372b fix: download all artifacts in release workflow 2025-07-24 17:45:46 -07:00
Andy Lee
2bec67d2b6 feat: auto-update leann-core dependencies during release
- Enhanced bump_version.sh to automatically update leann-core dependency versions
- Script now updates both package versions and their leann-core dependencies
- This ensures version consistency across all packages during release

No more manual dependency version updates needed!
2025-07-24 17:22:41 -07:00
Andy Lee
133e715832 fix: resolve CI issues and consolidate workflows
- Fix version dependencies: update backend packages to depend on leann-core==0.1.1
- Remove duplicate ci.yml workflow (keeping build-and-publish.yml as main CI)
- Update release-manual.yml to reference correct CI workflow name

This fixes the dependency resolution error and eliminates duplicate builds.
2025-07-24 17:20:58 -07:00
Andy Lee
95cf2f16e2 refactor: consolidate release and publish into single workflow
- Manual Release workflow now directly publishes to PyPI after downloading CI artifacts
- No more duplicate builds - reuses artifacts from CI
- build-and-publish.yml renamed to 'CI - Build Multi-Platform Packages'
- Publishing in CI workflow only for emergency manual triggers
- Updated RELEASE.md to reflect the new streamlined process

This fixes the issue where releases would trigger redundant builds.
2025-07-24 17:04:47 -07:00
Andy Lee
47a4c153eb fix: enable PyPI publish on tag push
- Manual Release workflow creates tags but build-and-publish.yml only published on 'release' events
- Now build-and-publish.yml will also publish when v* tags are pushed
- This fixes the issue where manual releases didn't trigger PyPI uploads
2025-07-24 17:00:21 -07:00
GitHub Actions
faf5ae3533 chore: release v0.1.1 2025-07-24 23:36:23 +00:00
Andy Lee
a44dccecac fix: make TestPyPI upload optional and non-blocking
- Add continue-on-error to TestPyPI step
- Check if TEST_PYPI_API_TOKEN exists before attempting upload
- Add graceful failure handling with clear messages
- Update docs to explain TestPyPI token configuration
- Clarify that TestPyPI testing is optional

Now the release won't fail if TestPyPI is not configured or upload fails
2025-07-24 16:02:07 -07:00
yichuan520030910320
9cf9358b9c Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-24 14:40:39 -07:00
yichuan520030910320
de252fef31 [chat] update 30s example 2025-07-24 14:40:33 -07:00
Andy Lee
9076bc27b8 fix: resolve CI run detection issues in release workflow
- Add 'actions: read' permission to access workflow runs
- Use workflow name instead of filename for gh run list
- Look for CI run on HEAD~1 (before version bump commit)
- Improve error messages for better debugging

Fixes HTTP 403 error when trying to find successful CI runs
2025-07-24 14:27:26 -07:00
Andy Lee
50686c0819 refactor: use CI artifacts in release workflow instead of rebuilding
- Download pre-built wheels from successful CI runs
- Avoids duplicate builds and ensures consistency
- CI artifacts are already tested across all platforms
- Faster release process (no build time)
- Updates release documentation to reflect new flow

This ensures the released packages are exactly what was tested in CI.
2025-07-24 14:24:03 -07:00
Andy Lee
1614203786 fix: make bump_version.sh work on both macOS and Linux
- macOS uses sed -i '' while Linux uses sed -i
- Add OS detection to use correct syntax
- Ensures script works in CI (Linux) and local dev (macOS)
2025-07-24 14:13:31 -07:00
Andy Lee
3d4c75a56c fix: add missing scripts directory to git
- Remove scripts/ from .gitignore
- Add build_and_test.sh for local testing
- Add bump_version.sh for version updates (used by CI)
- Add release.sh and upload_to_pypi.sh for publishing
- Fixes CI error: ./scripts/bump_version.sh: No such file or directory
2025-07-24 14:13:14 -07:00
Andy Lee
2684ee71dc fix: ensure uv build uses correct Python version in CI
- Add --python python flag to uv build commands
- This ensures wheels are built with the correct Python version (cp313 for Python 3.13, etc)
- Fixes issue where Python 3.13 CI was building cp311 wheels
- Also adds Python version verification before build
2025-07-24 13:44:02 -07:00
Andy Lee
1d321953ba ci: update all GitHub Actions to latest versions
- Update actions/upload-artifact from v3 to v4 (v3 deprecated April 2024)
- Update actions/setup-python from v4 to v5 (latest version)
- Add Python 3.12 and 3.13 to CI test matrix
- Ensure compatibility with latest Python versions and GitHub Actions
2025-07-24 13:36:21 -07:00
Andy Lee
b3cb251369 ci: add Python 3.12 and 3.13 to test matrix
- Add Python 3.12 and 3.13 to CI test matrix
- Ensure compatibility with latest Python versions
- Python 3.12 is stable, 3.13 was released in Oct 2024
2025-07-24 13:32:29 -07:00
Andy Lee
0a17d2c9d8 feat: implement comprehensive CI/CD pipeline with two-stage release
- Add ci.yml for continuous integration on every commit
  - Test builds on Ubuntu/macOS with Python 3.9/3.10/3.11
  - Ensure code quality before any release

- Add release-manual.yml for controlled releases
  - Manual trigger prevents accidental releases
  - Version validation and tag creation
  - Optional TestPyPI testing before production
  - Only creates tag after validation passes

- Keep build-and-publish.yml for automated PyPI deployment
  - Triggered by new tags (separation of concerns)
  - Handles multi-platform wheel building
  - Allows retry if PyPI upload fails

- Update RELEASE.md with clear prerequisites and workflow

This setup ensures:
1. Every commit is tested (CI)
2. Releases are deliberate (manual trigger)
3. Failed CI won't create broken tags
4. PyPI publish can be retried independently
2025-07-24 13:29:21 -07:00
Andy Lee
e3defbca84 fix: add minimal CI dependencies for HNSW and DiskANN backends
- HNSW (Ubuntu): add libopenblas-dev for BLAS requirements
- DiskANN (Ubuntu): keep MKL, remove redundant pkg-config (HNSW already has it)
- DiskANN (macOS): add protobuf for build requirements
- Both: ensure patchelf for auditwheel on Linux

This avoids OpenBLAS/MKL conflicts by using them in separate jobs
2025-07-24 01:06:57 -07:00
Andy Lee
e407f63977 chore: fix uv build 2025-07-24 00:51:57 -07:00
Andy Lee
7add391b2c chore: build and package 2025-07-24 00:47:46 -07:00
yichuan520030910320
efd6373b32 [chat] update huggingface chat and make qwen no thinking 2025-07-24 00:11:42 -07:00
yichuan520030910320
d502fa24b0 [installation] update install for linux 2025-07-24 02:17:17 +00:00
yichuan520030910320
258a9a5c7f [misc]test link again 2025-07-23 18:29:32 -07:00
yichuan520030910320
5d41ac6115 test link 2025-07-23 18:28:22 -07:00
yichuan520030910320
2a0fdb49b8 test link 2025-07-23 18:27:08 -07:00
yichuan520030910320
9d1b7231b6 fix broken link 2025-07-23 18:25:22 -07:00
yichuan520030910320
ed3095b478 fix broken link 2025-07-23 18:24:17 -07:00
yichuan520030910320
88eca75917 fix readme 2025-07-23 18:22:10 -07:00
yichuan520030910320
42de27e16a Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-23 18:17:19 -07:00
yichuan520030910320
c083bda5b7 fix several bug 2025-07-23 18:17:11 -07:00
Andy Lee
e86da38726 fix: ollama hint for similar models 2025-07-23 15:45:10 -07:00
yichuan520030910320
99076e38bc update install 2025-07-23 14:55:34 -07:00
yichuan520030910320
9698c1a02c fix readme 2025-07-23 14:52:01 -07:00
yichuan520030910320
851f0f04c3 fix some para 2025-07-23 01:46:34 -07:00
yichuan520030910320
ae16d9d888 fix readme 2025-07-23 00:44:43 -07:00
yichuan520030910320
6e1af2eb0c fix readme 2025-07-23 00:43:46 -07:00
yichuan520030910320
7695dd0d50 fix readme 2025-07-23 00:42:17 -07:00
yichuan520030910320
c2065473ad fix readme 2025-07-23 00:30:42 -07:00
yichuan520030910320
5f3870564d Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-23 00:09:30 -07:00
yichuan520030910320
c214b2e33e fix readme 2025-07-23 00:09:24 -07:00
Andy Lee
2420c5fd35 chore: update sentence-transformer to prevent MixIn not found error 2025-07-22 23:27:25 -07:00
yichuan520030910320
f48f526f0a fix readme 2025-07-22 23:21:15 -07:00
yichuan520030910320
5dd74982ba fix readme 2025-07-22 23:14:31 -07:00
Andy Lee
e07aaf52a7 docs: align 2025-07-22 22:37:27 -07:00
Andy Lee
30e5f12616 docs: quick start 2025-07-22 22:33:04 -07:00
Andy Lee
594427bf87 docs: demo 2025-07-22 22:32:18 -07:00
155 changed files with 25849 additions and 13972 deletions

11
.github/workflows/build-and-publish.yml vendored Normal file
View File

@@ -0,0 +1,11 @@
name: CI
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
jobs:
build:
uses: ./.github/workflows/build-reusable.yml

167
.github/workflows/build-reusable.yml vendored Normal file
View File

@@ -0,0 +1,167 @@
name: Reusable Build
on:
workflow_call:
inputs:
ref:
description: 'Git ref to build'
required: false
type: string
default: ''
jobs:
build:
name: Build ${{ matrix.os }} Python ${{ matrix.python }}
strategy:
matrix:
include:
- os: ubuntu-22.04
python: '3.9'
- os: ubuntu-22.04
python: '3.10'
- os: ubuntu-22.04
python: '3.11'
- os: ubuntu-22.04
python: '3.12'
- os: ubuntu-22.04
python: '3.13'
- os: macos-latest
python: '3.9'
- os: macos-latest
python: '3.10'
- os: macos-latest
python: '3.11'
- os: macos-latest
python: '3.12'
- os: macos-latest
python: '3.13'
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.ref }}
submodules: recursive
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python }}
- name: Install uv
uses: astral-sh/setup-uv@v4
- name: Install system dependencies (Ubuntu)
if: runner.os == 'Linux'
run: |
sudo apt-get update
sudo apt-get install -y libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
pkg-config libopenblas-dev patchelf libabsl-dev libaio-dev libprotobuf-dev
# Install Intel MKL for DiskANN
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
source /opt/intel/oneapi/setvars.sh
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/mkl/latest/lib/intel64:$LD_LIBRARY_PATH" >> $GITHUB_ENV
- name: Install system dependencies (macOS)
if: runner.os == 'macOS'
run: |
brew install llvm libomp boost protobuf zeromq
- name: Install build dependencies
run: |
uv pip install --system scikit-build-core numpy swig Cython pybind11
if [[ "$RUNNER_OS" == "Linux" ]]; then
uv pip install --system auditwheel
else
uv pip install --system delocate
fi
- name: Build packages
run: |
# Build core (platform independent)
if [ "${{ matrix.os }}" == "ubuntu-latest" ]; then
cd packages/leann-core
uv build
cd ../..
fi
# Build HNSW backend
cd packages/leann-backend-hnsw
if [ "${{ matrix.os }}" == "macos-latest" ]; then
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv build --wheel --python python
else
uv build --wheel --python python
fi
cd ../..
# Build DiskANN backend
cd packages/leann-backend-diskann
if [ "${{ matrix.os }}" == "macos-latest" ]; then
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv build --wheel --python python
else
uv build --wheel --python python
fi
cd ../..
# Build meta package (platform independent)
if [ "${{ matrix.os }}" == "ubuntu-latest" ]; then
cd packages/leann
uv build
cd ../..
fi
- name: Repair wheels (Linux)
if: runner.os == 'Linux'
run: |
# Repair HNSW wheel
cd packages/leann-backend-hnsw
if [ -d dist ]; then
auditwheel repair dist/*.whl -w dist_repaired
rm -rf dist
mv dist_repaired dist
fi
cd ../..
# Repair DiskANN wheel
cd packages/leann-backend-diskann
if [ -d dist ]; then
auditwheel repair dist/*.whl -w dist_repaired
rm -rf dist
mv dist_repaired dist
fi
cd ../..
- name: Repair wheels (macOS)
if: runner.os == 'macOS'
run: |
# Repair HNSW wheel
cd packages/leann-backend-hnsw
if [ -d dist ]; then
delocate-wheel -w dist_repaired -v dist/*.whl
rm -rf dist
mv dist_repaired dist
fi
cd ../..
# Repair DiskANN wheel
cd packages/leann-backend-diskann
if [ -d dist ]; then
delocate-wheel -w dist_repaired -v dist/*.whl
rm -rf dist
mv dist_repaired dist
fi
cd ../..
- name: List built packages
run: |
echo "📦 Built packages:"
find packages/*/dist -name "*.whl" -o -name "*.tar.gz" | sort
- name: Upload artifacts
uses: actions/upload-artifact@v4
with:
name: packages-${{ matrix.os }}-py${{ matrix.python }}
path: packages/*/dist/

126
.github/workflows/release-manual.yml vendored Normal file
View File

@@ -0,0 +1,126 @@
name: Release
on:
workflow_dispatch:
inputs:
version:
description: 'Version to release (e.g., 0.1.2)'
required: true
type: string
jobs:
update-version:
name: Update Version
runs-on: ubuntu-latest
permissions:
contents: write
outputs:
commit-sha: ${{ steps.push.outputs.commit-sha }}
steps:
- uses: actions/checkout@v4
- name: Validate version
run: |
if ! [[ "${{ inputs.version }}" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
echo "❌ Invalid version format"
exit 1
fi
echo "✅ Version format valid"
- name: Update versions and push
id: push
run: |
# Check current version
CURRENT_VERSION=$(grep "^version" packages/leann-core/pyproject.toml | cut -d'"' -f2)
echo "Current version: $CURRENT_VERSION"
echo "Target version: ${{ inputs.version }}"
if [ "$CURRENT_VERSION" = "${{ inputs.version }}" ]; then
echo "⚠️ Version is already ${{ inputs.version }}, skipping update"
COMMIT_SHA=$(git rev-parse HEAD)
else
./scripts/bump_version.sh ${{ inputs.version }}
git config user.name "GitHub Actions"
git config user.email "actions@github.com"
git add packages/*/pyproject.toml
git commit -m "chore: release v${{ inputs.version }}"
git push origin main
COMMIT_SHA=$(git rev-parse HEAD)
echo "✅ Pushed version update: $COMMIT_SHA"
fi
echo "commit-sha=$COMMIT_SHA" >> $GITHUB_OUTPUT
build-packages:
name: Build packages
needs: update-version
uses: ./.github/workflows/build-reusable.yml
with:
ref: ${{ needs.update-version.outputs.commit-sha }}
publish:
name: Publish and Release
needs: [update-version, build-packages]
if: always() && needs.update-version.result == 'success' && needs.build-packages.result == 'success'
runs-on: ubuntu-latest
permissions:
contents: write
steps:
- uses: actions/checkout@v4
with:
ref: ${{ needs.update-version.outputs.commit-sha }}
- name: Download all artifacts
uses: actions/download-artifact@v4
with:
path: dist-artifacts
- name: Collect packages
run: |
mkdir -p dist
find dist-artifacts -name "*.whl" -exec cp {} dist/ \;
find dist-artifacts -name "*.tar.gz" -exec cp {} dist/ \;
echo "📦 Packages to publish:"
ls -la dist/
- name: Publish to PyPI
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
run: |
if [ -z "$TWINE_PASSWORD" ]; then
echo "❌ PYPI_API_TOKEN not configured!"
exit 1
fi
pip install twine
twine upload dist/* --skip-existing --verbose
echo "✅ Published to PyPI!"
- name: Create release
run: |
# Check if tag already exists
if git rev-parse "v${{ inputs.version }}" >/dev/null 2>&1; then
echo "⚠️ Tag v${{ inputs.version }} already exists, skipping tag creation"
else
git tag "v${{ inputs.version }}"
git push origin "v${{ inputs.version }}"
echo "✅ Created and pushed tag v${{ inputs.version }}"
fi
# Check if release already exists
if gh release view "v${{ inputs.version }}" >/dev/null 2>&1; then
echo "⚠️ Release v${{ inputs.version }} already exists, skipping release creation"
else
gh release create "v${{ inputs.version }}" \
--title "Release v${{ inputs.version }}" \
--notes "🚀 Released to PyPI: https://pypi.org/project/leann/${{ inputs.version }}/" \
--latest
echo "✅ Created GitHub release v${{ inputs.version }}"
fi
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}

1
.gitignore vendored
View File

@@ -12,7 +12,6 @@ outputs/
*.idx
*.map
.history/
scripts/
lm_eval.egg-info/
demo/experiment_results/**/*.json
*.jsonl

282
README.md
View File

@@ -12,11 +12,11 @@
The smallest vector index in the world. RAG Everything with LEANN!
</h2>
LEANN is a revolutionary vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **[97% less storage]** than traditional solutions **without accuracy loss**.
LEANN is a revolutionary vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration →](#-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
LEANN achieves this through *graph-based selective recomputation* with *high-degree preserving pruning*, computing embeddings on-demand instead of storing them all. [Illustration Fig ](#-architecture--how-it-works) | [Paper →](https://arxiv.org/abs/2506.08276)
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can search your **[file system](#process-any-documents-pdf-txt-md)**, **[emails](#search-your-entire-life)**, **[browser history](#time-machine-for-the-web)**, **[chat history](#wechat-detective)**, or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
**Ready to RAG Everything?** Transform your laptop into a personal AI assistant that can search your **[file system](#-personal-data-manager-process-any-documents-pdf-txt-md)**, **[emails](#-your-personal-email-secretary-rag-on-apple-mail)**, **[browser history](#-time-machine-for-the-web-rag-your-entire-browser-history)**, **[chat history](#-wechat-detective-unlock-your-golden-memories)**, or external knowledge bases (i.e., 60M documents) - all on your laptop, with zero cloud costs and complete privacy.
@@ -26,7 +26,7 @@ LEANN achieves this through *graph-based selective recomputation* with *high-deg
<img src="assets/effects.png" alt="LEANN vs Traditional Vector DB Storage Comparison" width="70%">
</p>
> **The numbers speak for themselves:** Index 60 million Wikipedia chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#storage-usage-comparison)
> **The numbers speak for themselves:** Index 60 million Wikipedia chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#storage-comparison)
🔒 **Privacy:** Your data never leaves your laptop. No OpenAI, no cloud, no "terms of service".
@@ -37,8 +37,8 @@ LEANN achieves this through *graph-based selective recomputation* with *high-deg
**No Accuracy Loss:** Maintain the same search quality as heavyweight solutions while using 97% less storage.
## Quick Start in 1 minute
## Installation
> `pip leann` coming soon!
```bash
git clone git@github.com:yichuan-w/LEANN.git leann
cd leann
@@ -47,36 +47,30 @@ git submodule update --init --recursive
**macOS:**
```bash
brew install llvm libomp boost protobuf zeromq
export CC=$(brew --prefix llvm)/bin/clang
export CXX=$(brew --prefix llvm)/bin/clang++
brew install llvm libomp boost protobuf zeromq pkgconf
# Install with HNSW backend (default, recommended for most users)
uv sync
# Or add DiskANN backend if you want to test more options
uv sync --extra diskann
# Install uv first if you don't have it:
# curl -LsSf https://astral.sh/uv/install.sh | sh
# See: https://docs.astral.sh/uv/getting-started/installation/#installation-methods
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync
```
**Linux (Ubuntu/Debian):**
**Linux:**
```bash
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
# Install with HNSW backend (default, recommended for most users)
uv sync
# Or add DiskANN backend if you want to test more options
uv sync --extra diskann
```
**Ollama Setup (Recommended for full privacy):**
> *You can skip this installation if you only want to use OpenAI API for generation.*
*macOS:*
**macOS:**
First, [download Ollama for macOS](https://ollama.com/download/mac).
@@ -85,7 +79,7 @@ First, [download Ollama for macOS](https://ollama.com/download/mac).
ollama pull llama3.2:1b
```
*Linux:*
**Linux:**
```bash
# Install Ollama
curl -fsSL https://ollama.ai/install.sh | sh
@@ -97,9 +91,10 @@ ollama serve &
ollama pull llama3.2:1b
```
## Dead Simple API
## Quick Start in 30s
Just 3 lines of code. Our declarative API makes RAG as easy as writing a config file:
Our declarative API makes RAG as easy as writing a config file.
[Try in this ipynb file →](demo.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/yichuan-w/LEANN/blob/main/demo.ipynb)
```python
from leann.api import LeannBuilder, LeannSearcher, LeannChat
@@ -130,57 +125,63 @@ response = chat.ask(
)
```
**That's it.** No cloud setup, no API keys, no "fine-tuning". Just your data, your questions, your laptop.
## RAG on Everything!
[Try the interactive demo →](demo.ipynb)
LEANN supports RAG on various data sources including documents (.pdf, .txt, .md), Apple Mail, Google Search History, WeChat, and more.
## Wild Things You Can Do
### 📄 Personal Data Manager: Process Any Documents (.pdf, .txt, .md)!
LEANN supports RAGing a lot of data sources, like .pdf, .txt, .md, and also supports RAGing your WeChat, Google Search History, and more.
Ask questions directly about your personal PDFs, documents, and any directory containing your files!
### Process Any Documents (.pdf, .txt, .md)
<p align="center">
<img src="videos/paper_clear.gif" alt="LEANN Document Search Demo" width="600">
</p>
Above we showed the Python API, while this CLI script demonstrates the same concepts while directly processing PDFs and documents, and even any directory that stores your personal files!
The following scripts use Ollama `qwen3:8b` by default, so you need `ollama pull qwen3:8b` first. For other models: `--llm openai --model gpt-4o` (requires `OPENAI_API_KEY` environment variable) or `--llm hf --model Qwen/Qwen3-4B`.
The example below asks a question about summarizing two papers (uses default data in `examples/data`):
```bash
# Drop your PDFs, .txt, .md files into apps/documents/data/
python -m apps.documents
# Drop your PDFs, .txt, .md files into examples/data/
uv run ./examples/main_cli_example.py
```
# Or with uv
uv run python -m apps.documents
```
# Or use python directly
source .venv/bin/activate
python ./examples/main_cli_example.py
```
**Works with any text format** - research papers, personal notes, presentations. Built with LlamaIndex for document parsing.
### 📧 Your Personal Email Secretary: RAG on Apple Mail!
### Search Your Entire Life
<p align="center">
<img src="videos/mail_clear.gif" alt="LEANN Email Search Demo" width="600">
</p>
**Note:** You need to grant full disk access to your terminal/VS Code in System Preferences → Privacy & Security → Full Disk Access.
```bash
python -m apps.email
# "What's the number of class recommend to take per semester for incoming EECS students?"
python examples/mail_reader_leann.py --query "What's the food I ordered by doordash or Uber eat mostly?"
```
**90K emails → 14MB.** Finally, search your email like you search Google.
**780K email chunks → 78MB storage** Finally, search your email like you search Google.
<details>
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
```bash
# Use default mail path (works for most macOS setups)
python -m apps.email
python examples/mail_reader_leann.py
# Run with custom index directory
python -m apps.email --index-dir "./my_mail_index"
python examples/mail_reader_leann.py --index-dir "./my_mail_index"
# Process all emails (may take time but indexes everything)
python -m apps.email --max-emails -1
python examples/mail_reader_leann.py --max-emails -1
# Limit number of emails processed (useful for testing)
python -m apps.email --max-emails 1000
python examples/mail_reader_leann.py --max-emails 1000
# Run a single query
python -m apps.email --query "What did my boss say about deadlines?"
python examples/mail_reader_leann.py --query "What did my boss say about deadlines?"
```
</details>
@@ -194,28 +195,32 @@ Once the index is built, you can ask questions like:
- "Show me emails about travel expenses"
</details>
### Time Machine for the Web
### 🔍 Time Machine for the Web: RAG Your Entire Google Browser History!
<p align="center">
<img src="videos/google_clear.gif" alt="LEANN Browser History Search Demo" width="600">
</p>
```bash
python -m apps.browser
# "Tell me my browser history about machine learning system stuff?"
python examples/google_history_reader_leann.py --query "Tell me my browser history about machine learning?"
```
**38K browser entries → 6MB.** Your browser history becomes your personal search engine.
**38K browser entries → 6MB storage.** Your browser history becomes your personal search engine.
<details>
<summary><strong>📋 Click to expand: Command Examples</strong></summary>
```bash
# Use default Chrome profile (auto-finds all profiles)
python -m apps.browser
python examples/google_history_reader_leann.py
# Run with custom index directory
python -m apps.browser --index-dir "./my_chrome_index"
python examples/google_history_reader_leann.py --index-dir "./my_chrome_index"
# Limit number of history entries processed (useful for testing)
python -m apps.browser --max-entries 500
python examples/google_history_reader_leann.py --max-entries 500
# Run a single query
python -m apps.browser --query "What websites did I visit about machine learning?"
python examples/google_history_reader_leann.py --query "What websites did I visit about machine learning?"
```
</details>
@@ -248,13 +253,17 @@ Once the index is built, you can ask questions like:
</details>
### WeChat Detective
### 💬 WeChat Detective: Unlock Your Golden Memories!
<p align="center">
<img src="videos/wechat_clear.gif" alt="LEANN WeChat Search Demo" width="600">
</p>
```bash
python -m apps.wechat
# "Show me all group chats about weekend plans"
python examples/wechat_history_reader_leann.py --query "Show me all group chats about weekend plans"
```
**400K messages → 64MB.** Search years of chat history in any language.
**400K messages → 64MB storage** Search years of chat history in any language.
<details>
<summary><strong>🔧 Click to expand: Installation Requirements</strong></summary>
@@ -265,7 +274,13 @@ First, you need to install the WeChat exporter:
sudo packages/wechat-exporter/wechattweak-cli install
```
**Troubleshooting**: If you encounter installation issues, check the [WeChatTweak-CLI issues page](https://github.com/sunnyyoung/WeChatTweak-CLI/issues/41).
**Troubleshooting:**
- **Installation issues**: Check the [WeChatTweak-CLI issues page](https://github.com/sunnyyoung/WeChatTweak-CLI/issues/41)
- **Export errors**: If you encounter the error below, try restarting WeChat
```
Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.
Failed to find or export WeChat data. Exiting.
```
</details>
<details>
@@ -273,19 +288,19 @@ sudo packages/wechat-exporter/wechattweak-cli install
```bash
# Use default settings (recommended for first run)
python -m apps.wechat
python examples/wechat_history_reader_leann.py
# Run with custom export directory and wehn we run the first time, LEANN will export all chat history automatically for you
python -m apps.wechat --export-dir "./my_wechat_exports"
python examples/wechat_history_reader_leann.py --export-dir "./my_wechat_exports"
# Run with custom index directory
python -m apps.wechat --index-dir "./my_wechat_index"
python examples/wechat_history_reader_leann.py --index-dir "./my_wechat_index"
# Limit number of chat entries processed (useful for testing)
python -m apps.wechat --max-entries 1000
python examples/wechat_history_reader_leann.py --max-entries 1000
# Run a single query
python -m apps.wechat --query "Show me conversations about travel plans"
python examples/wechat_history_reader_leann.py --query "Show me conversations about travel plans"
```
</details>
@@ -385,52 +400,24 @@ Options:
## Benchmarks
Run the comparison yourself:
```bash
python -m apps.benchmarks
```
| System | Storage |
|--------|---------|
| FAISS HNSW | 5.5 MB |
| LEANN | 0.5 MB |
| **Savings** | **91%** |
📊 **[Simple Example: Compare LEANN vs FAISS →](examples/compare_faiss_vs_leann.py)**
### Storage Comparison
Same dataset, same hardware, same embedding model. LEANN just works better.
| System | DPR (2.1M) | Wiki (60M) | Chat (400K) | Email (780K) | Browser (38K) |
|--------|-------------|------------|-------------|--------------|---------------|
| Traditional vector database (e.g., FAISS) | 3.8 GB | 201 GB | 1.8 GB | 2.4 GB | 130 MB |
| LEANN | 324 MB | 6 GB | 64 MB | 79 MB | 6.4 MB |
| Savings| 91% | 97% | 97% | 97% | 95% |
### Storage Usage Comparison
| System | DPR (2.1M chunks) | RPJ-wiki (60M chunks) | Chat history (400K messages) | Apple emails (90K messages chunks) |Google Search History (38K entries)
|-----------------------|------------------|------------------------|-----------------------------|------------------------------|------------------------------|
| Traditional Vector DB(FAISS) | 3.8 GB | 201 GB | 1.8G | 305.8 MB |130.4 MB |
| **LEANN** | **324 MB** | **6 GB** | **64 MB** | **14.8 MB** |**6.4MB** |
| **Reduction** | **91% smaller** | **97% smaller** | **97% smaller** | **95% smaller** |**95% smaller** |
<!-- ### Memory Usage Comparison
| System j | DPR(2M docs) | RPJ-wiki(60M docs) | Chat history() |
| --------------------- | ---------------- | ---------------- | ---------------- |
| Traditional Vector DB(LLamaindex faiss) | x GB | x GB | x GB |
| **Leann** | **xx MB** | **x GB** | **x GB** |
| **Reduction** | **x%** | **x%** | **x%** |
### Query Performance of LEANN
| Backend | Index Size | Query Time | Recall@3 |
| ------------------- | ---------- | ---------- | --------- |
| DiskANN | 1M docs | xms | 0.95 |
| HNSW | 1M docs | xms | 0.95 | -->
*Benchmarks run on Apple M3 Pro 36 GB*
## Reproduce Our Results
```bash
uv pip install -e ".[dev]" # Install dev dependencies
python -m apps.evaluation data/indices/dpr/dpr_diskann # DPR dataset
python -m apps.evaluation data/indices/rpj_wiki/rpj_wiki.index # Wikipedia
python examples/run_evaluation.py data/indices/dpr/dpr_diskann # DPR dataset
python examples/run_evaluation.py data/indices/rpj_wiki/rpj_wiki.index # Wikipedia
```
The evaluation script downloads data automatically on first run. The last three results were tested with partial personal data, and you can reproduce them with your own data!
@@ -452,98 +439,15 @@ If you find Leann useful, please cite:
}
```
## ✨ Features
## ✨ [Detailed Features →](docs/features.md)
### 🔥 Core Features
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage with dynamic computation using optimized ZMQ servers and highly optimized search paradigm (overlapping and batching) with highly optimized embedding engine
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
- **🏗️ Pluggable Backends** - DiskANN, HNSW/FAISS with unified API
### 🛠️ Technical Highlights
- **🔄 Recompute Mode** - Highest accuracy scenarios while eliminating vector storage overhead
- **⚡ Zero-copy Operations** - Minimize IPC overhead by transferring distances instead of embeddings
- **🚀 High-throughput Embedding Pipeline** - Optimized batched processing for maximum efficiency
- **🎯 Two-level Search** - Novel coarse-to-fine search overlap for accelerated query processing (optional)
- **💾 Memory-mapped Indices** - Fast startup with raw text mapping to reduce memory overhead
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](test/build_mlx_index.py))
### 🎨 Developer Experience
- **Simple Python API** - Get started in minutes
- **Extensible backend system** - Easy to add new algorithms
- **Comprehensive examples** - From basic usage to production deployment
## 🤝 Contributing
We welcome contributions! Leann is built by the community, for the community.
### Ways to Contribute
- 🐛 **Bug Reports**: Found an issue? Let us know!
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
- 🔧 **Code Contributions**: PRs welcome for all skill levels
- 📖 **Documentation**: Help make Leann more accessible
- 🧪 **Benchmarks**: Share your performance results
## 🤝 [Contributing →](docs/contributing.md)
<!-- ## FAQ
### Common Issues
#### NCCL Topology Error
**Problem**: You encounter `ncclTopoComputePaths` error during document processing:
```
ncclTopoComputePaths (system=<optimized out>, comm=comm@entry=0x5555a82fa3c0) at graph/paths.cc:688
```
**Solution**: Set these environment variables before running your script:
```bash
export NCCL_TOPO_DUMP_FILE=/tmp/nccl_topo.xml
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=INIT,GRAPH
export NCCL_IB_DISABLE=1
export NCCL_NET_PLUGIN=none
export NCCL_SOCKET_IFNAME=ens5
``` -->
## FAQ
### 1. My building time seems long
You can speed up the process by using a lightweight embedding model. Add this to your arguments:
```bash
--embedding-model sentence-transformers/all-MiniLM-L6-v2
```
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)
## [FAQ →](docs/faq.md)
## 📈 Roadmap
### 🎯 Q2 2025
- [X] DiskANN backend with MIPS/L2/Cosine support
- [X] HNSW backend integration
- [X] Real-time embedding pipeline
- [X] Memory-efficient graph pruning
### 🚀 Q3 2025
- [ ] Advanced caching strategies
- [ ] Add contextual-retrieval https://www.anthropic.com/news/contextual-retrieval
- [ ] Add sleep-time-compute and summarize agent! to summarilze the file on computer!
- [ ] Add OpenAI recompute API
### 🌟 Q4 2025
- [ ] Integration with LangChain/LlamaIndex
- [ ] Visual similarity search
- [ ] Query rewrtiting, rerank and expansion
## 📈 [Roadmap →](docs/roadmap.md)
## 📄 License
@@ -551,11 +455,7 @@ MIT License - see [LICENSE](LICENSE) for details.
## 🙏 Acknowledgments
- **Microsoft Research** for the DiskANN algorithm
- **Meta AI** for FAISS and optimization insights
- **HuggingFace** for the transformer ecosystem
- **Our amazing contributors** who make this possible
This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/)
---
<p align="center">

View File

View File

View File

@@ -1,338 +0,0 @@
#!/usr/bin/env python3
"""
Memory comparison between Faiss HNSW and LEANN HNSW backend
"""
import logging
import os
import sys
import time
import psutil
import gc
import subprocess
from pathlib import Path
from llama_index.core.node_parser import SentenceSplitter
# Setup logging
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)
def get_memory_usage():
"""Get current memory usage in MB"""
process = psutil.Process()
return process.memory_info().rss / 1024 / 1024
def print_memory_stats(stage: str, start_mem: float):
"""Print memory statistics"""
current_mem = get_memory_usage()
diff = current_mem - start_mem
print(f"[{stage}] Memory: {current_mem:.1f} MB (+{diff:.1f} MB)")
return current_mem
class MemoryTracker:
def __init__(self, name: str):
self.name = name
self.start_mem = get_memory_usage()
self.stages = []
def checkpoint(self, stage: str):
current_mem = print_memory_stats(f"{self.name} - {stage}", self.start_mem)
self.stages.append((stage, current_mem))
return current_mem
def summary(self):
print(f"\n=== {self.name} Memory Summary ===")
for stage, mem in self.stages:
print(f"{stage}: {mem:.1f} MB")
peak_mem = max(mem for _, mem in self.stages)
print(f"Peak Memory: {peak_mem:.1f} MB")
print(f"Total Memory Increase: {peak_mem - self.start_mem:.1f} MB")
return peak_mem
def test_faiss_hnsw():
"""Test Faiss HNSW Vector Store in subprocess"""
print("\n" + "=" * 50)
print("TESTING FAISS HNSW VECTOR STORE")
print("=" * 50)
try:
# Get the directory of this script
script_dir = Path(__file__).parent
faiss_script = script_dir / "faiss_only.py"
result = subprocess.run(
[sys.executable, str(faiss_script)],
capture_output=True,
text=True,
timeout=300,
)
print(result.stdout)
if result.stderr:
print("Stderr:", result.stderr)
if result.returncode != 0:
return {
"peak_memory": float("inf"),
"error": f"Process failed with code {result.returncode}",
}
# Parse peak memory from output
lines = result.stdout.split("\n")
peak_memory = 0.0
for line in lines:
if "Peak Memory:" in line:
peak_memory = float(
line.split("Peak Memory:")[1].split("MB")[0].strip()
)
return {"peak_memory": peak_memory}
except Exception as e:
return {
"peak_memory": float("inf"),
"error": str(e),
}
def test_leann_hnsw():
"""Test LEANN HNSW Search Memory (load existing index)"""
print("\n" + "=" * 50)
print("TESTING LEANN HNSW SEARCH MEMORY")
print("=" * 50)
tracker = MemoryTracker("LEANN HNSW Search")
# Import and setup
tracker.checkpoint("Initial")
from leann.api import LeannSearcher
tracker.checkpoint("After imports")
from llama_index.core import SimpleDirectoryReader
from leann.api import LeannBuilder, LeannSearcher
# Load and parse documents
documents = SimpleDirectoryReader(
"../documents/data",
recursive=True,
encoding="utf-8",
required_exts=[".pdf", ".txt", ".md"],
).load_data()
tracker.checkpoint("After document loading")
# Parse into chunks
node_parser = SentenceSplitter(
chunk_size=256, chunk_overlap=20, separator=" ", paragraph_separator="\n\n"
)
all_texts = []
for doc in documents:
nodes = node_parser.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
tracker.checkpoint("After text chunking")
# Build LEANN index
INDEX_DIR = Path("./test_leann_comparison")
INDEX_PATH = str(INDEX_DIR / "comparison.leann")
# Check if index already exists
if os.path.exists(INDEX_PATH + ".meta.json"):
print("Loading existing LEANN HNSW index...")
tracker.checkpoint("After loading existing index")
else:
print("Building new LEANN HNSW index...")
# Clean up previous index
import shutil
if INDEX_DIR.exists():
shutil.rmtree(INDEX_DIR)
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="facebook/contriever",
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1,
)
tracker.checkpoint("After builder setup")
print("Building LEANN HNSW index...")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(INDEX_PATH)
del builder
gc.collect()
tracker.checkpoint("After index building")
# Find existing LEANN index
index_paths = [
"./test_leann_comparison/comparison.leann",
]
index_path = None
for path in index_paths:
if os.path.exists(path + ".meta.json"):
index_path = path
break
if not index_path:
print("❌ LEANN index not found. Please build it first")
return {"peak_memory": float("inf"), "error": "Index not found"}
# Measure runtime memory overhead
print("\nMeasuring runtime memory overhead...")
runtime_start_mem = get_memory_usage()
print(f"Before load memory: {runtime_start_mem:.1f} MB")
tracker.checkpoint("Before load memory")
# Load searcher
searcher = LeannSearcher(index_path)
tracker.checkpoint("After searcher loading")
print("Running search queries...")
queries = [
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
"What is LEANN and how does it work?",
"华为诺亚方舟实验室的主要研究内容",
]
for i, query in enumerate(queries):
start_time = time.time()
# Use same parameters as Faiss: top_k=20, ef=120 (complexity parameter)
_ = searcher.search(query, top_k=20, ef=120)
query_time = time.time() - start_time
print(f"Query {i + 1} time: {query_time:.3f}s")
tracker.checkpoint(f"After query {i + 1}")
runtime_end_mem = get_memory_usage()
runtime_overhead = runtime_end_mem - runtime_start_mem
peak_memory = tracker.summary()
print(f"Runtime Memory Overhead: {runtime_overhead:.1f} MB")
# Get storage size before cleanup
storage_size = 0
INDEX_DIR = Path(index_path).parent
if INDEX_DIR.exists():
total_size = 0
for dirpath, _, filenames in os.walk(str(INDEX_DIR)):
for filename in filenames:
# Only count actual index files, skip text data and backups
if filename.endswith((".old", ".tmp", ".bak", ".jsonl", ".json")):
continue
# Count .index, .idx, .map files (actual index structures)
if filename.endswith((".index", ".idx", ".map")):
filepath = os.path.join(dirpath, filename)
total_size += os.path.getsize(filepath)
storage_size = total_size / (1024 * 1024) # Convert to MB
# Clean up
del searcher
gc.collect()
return {
"peak_memory": peak_memory,
"storage_size": storage_size,
}
def main():
"""Run comparison tests"""
print("Storage + Search Memory Comparison: Faiss HNSW vs LEANN HNSW")
print("=" * 60)
# Test Faiss HNSW
faiss_results = test_faiss_hnsw()
# Force garbage collection
gc.collect()
time.sleep(2)
# Test LEANN HNSW
leann_results = test_leann_hnsw()
# Final comparison
print("\n" + "=" * 60)
print("STORAGE + SEARCH MEMORY COMPARISON")
print("=" * 60)
# Get storage sizes
faiss_storage_size = 0
leann_storage_size = leann_results.get("storage_size", 0)
# Get Faiss storage size using Python
if os.path.exists("./storage_faiss"):
total_size = 0
for dirpath, _, filenames in os.walk("./storage_faiss"):
for filename in filenames:
filepath = os.path.join(dirpath, filename)
total_size += os.path.getsize(filepath)
faiss_storage_size = total_size / (1024 * 1024) # Convert to MB
print("Faiss HNSW:")
if "error" in faiss_results:
print(f" ❌ Failed: {faiss_results['error']}")
else:
print(f" Search Memory: {faiss_results['peak_memory']:.1f} MB")
print(f" Storage Size: {faiss_storage_size:.1f} MB")
print("\nLEANN HNSW:")
if "error" in leann_results:
print(f" ❌ Failed: {leann_results['error']}")
else:
print(f" Search Memory: {leann_results['peak_memory']:.1f} MB")
print(f" Storage Size: {leann_storage_size:.1f} MB")
# Calculate improvements only if both tests succeeded
if "error" not in faiss_results and "error" not in leann_results:
memory_ratio = faiss_results["peak_memory"] / leann_results["peak_memory"]
print("\nLEANN vs Faiss Performance:")
memory_saving = faiss_results["peak_memory"] - leann_results["peak_memory"]
print(
f" Search Memory: {memory_ratio:.1f}x less ({memory_saving:.1f} MB saved)"
)
# Storage comparison
if leann_storage_size > faiss_storage_size:
storage_ratio = leann_storage_size / faiss_storage_size
print(
f" Storage Size: {storage_ratio:.1f}x larger (LEANN uses more storage)"
)
elif faiss_storage_size > leann_storage_size:
storage_ratio = faiss_storage_size / leann_storage_size
print(
f" Storage Size: {storage_ratio:.1f}x smaller (LEANN uses less storage)"
)
else:
print(" Storage Size: similar")
else:
if "error" not in leann_results:
print("\n✅ LEANN HNSW completed successfully!")
print(f"📊 Search Memory: {leann_results['peak_memory']:.1f} MB")
print(f"📊 Storage Size: {leann_storage_size:.1f} MB")
if "error" not in faiss_results:
print("\n✅ Faiss HNSW completed successfully!")
print(f"📊 Search Memory: {faiss_results['peak_memory']:.1f} MB")
print(f"📊 Storage Size: {faiss_storage_size:.1f} MB")
if __name__ == "__main__":
main()

View File

@@ -1,151 +0,0 @@
#!/usr/bin/env python3
"""Test only Faiss HNSW"""
import sys
import time
import psutil
import gc
import os
def get_memory_usage():
process = psutil.Process()
return process.memory_info().rss / 1024 / 1024
class MemoryTracker:
def __init__(self, name: str):
self.name = name
self.start_mem = get_memory_usage()
self.stages = []
def checkpoint(self, stage: str):
current_mem = get_memory_usage()
diff = current_mem - self.start_mem
print(f"[{self.name} - {stage}] Memory: {current_mem:.1f} MB (+{diff:.1f} MB)")
self.stages.append((stage, current_mem))
return current_mem
def summary(self):
peak_mem = max(mem for _, mem in self.stages)
print(f"Peak Memory: {peak_mem:.1f} MB")
return peak_mem
def main():
try:
import faiss
except ImportError:
print("Faiss is not installed.")
print("Please install it with `uv pip install faiss-cpu`")
sys.exit(1)
from llama_index.core import (
SimpleDirectoryReader,
VectorStoreIndex,
StorageContext,
Settings,
node_parser,
Document,
)
from llama_index.core.node_parser import SentenceSplitter
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
tracker = MemoryTracker("Faiss HNSW")
tracker.checkpoint("Initial")
embed_model = HuggingFaceEmbedding(model_name="facebook/contriever")
Settings.embed_model = embed_model
tracker.checkpoint("After embedding model setup")
d = 768
faiss_index = faiss.IndexHNSWFlat(d, 32)
faiss_index.hnsw.efConstruction = 64
tracker.checkpoint("After Faiss index creation")
documents = SimpleDirectoryReader(
"../documents/data",
recursive=True,
encoding="utf-8",
required_exts=[".pdf", ".txt", ".md"],
).load_data()
tracker.checkpoint("After document loading")
# Parse into chunks using the same splitter as LEANN
node_parser = SentenceSplitter(
chunk_size=256, chunk_overlap=20, separator=" ", paragraph_separator="\n\n"
)
tracker.checkpoint("After text splitter setup")
# Check if index already exists and try to load it
index_loaded = False
if os.path.exists("./storage_faiss"):
print("Loading existing Faiss HNSW index...")
try:
# Use the correct Faiss loading pattern from the example
vector_store = FaissVectorStore.from_persist_dir("./storage_faiss")
storage_context = StorageContext.from_defaults(
vector_store=vector_store, persist_dir="./storage_faiss"
)
from llama_index.core import load_index_from_storage
index = load_index_from_storage(storage_context=storage_context)
print(f"Index loaded from ./storage_faiss")
tracker.checkpoint("After loading existing index")
index_loaded = True
except Exception as e:
print(f"Failed to load existing index: {e}")
print("Cleaning up corrupted index and building new one...")
# Clean up corrupted index
import shutil
if os.path.exists("./storage_faiss"):
shutil.rmtree("./storage_faiss")
if not index_loaded:
print("Building new Faiss HNSW index...")
# Use the correct Faiss building pattern from the example
vector_store = FaissVectorStore(faiss_index=faiss_index)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = VectorStoreIndex.from_documents(
documents,
storage_context=storage_context,
transformations=[node_parser]
)
tracker.checkpoint("After index building")
# Save index to disk using the correct pattern
index.storage_context.persist(persist_dir="./storage_faiss")
tracker.checkpoint("After index saving")
# Measure runtime memory overhead
print("\nMeasuring runtime memory overhead...")
runtime_start_mem = get_memory_usage()
print(f"Before load memory: {runtime_start_mem:.1f} MB")
tracker.checkpoint("Before load memory")
query_engine = index.as_query_engine(similarity_top_k=20)
queries = [
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
"What is LEANN and how does it work?",
"华为诺亚方舟实验室的主要研究内容",
]
for i, query in enumerate(queries):
start_time = time.time()
_ = query_engine.query(query)
query_time = time.time() - start_time
print(f"Query {i + 1} time: {query_time:.3f}s")
tracker.checkpoint(f"After query {i + 1}")
runtime_end_mem = get_memory_usage()
runtime_overhead = runtime_end_mem - runtime_start_mem
peak_memory = tracker.summary()
print(f"Peak Memory: {peak_memory:.1f} MB")
print(f"Runtime Memory Overhead: {runtime_overhead:.1f} MB")
if __name__ == "__main__":
main()

View File

View File

@@ -1,201 +0,0 @@
import os
import asyncio
import argparse
try:
import dotenv
dotenv.load_dotenv()
except ModuleNotFoundError:
# python-dotenv is not installed; skip loading environment variables
dotenv = None
from pathlib import Path
from typing import List, Any
from leann.api import LeannBuilder, LeannSearcher, LeannChat
from llama_index.core.node_parser import SentenceSplitter
# Default Chrome profile path
DEFAULT_CHROME_PROFILE = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
def create_leann_index_from_multiple_chrome_profiles(profile_dirs: List[Path], index_path: str = "chrome_history_index.leann", max_count: int = -1):
"""
Create LEANN index from multiple Chrome profile data sources.
Args:
profile_dirs: List of Path objects pointing to Chrome profile directories
index_path: Path to save the LEANN index
max_count: Maximum number of history entries to process per profile
"""
print("Creating LEANN index from multiple Chrome profile data sources...")
# Load documents using ChromeHistoryReader from local readers module
from .readers import ChromeHistoryReader
reader = ChromeHistoryReader()
INDEX_DIR = Path(index_path).parent
if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---")
all_documents = []
total_processed = 0
# Process each Chrome profile directory
for i, profile_dir in enumerate(profile_dirs):
print(f"\nProcessing Chrome profile {i+1}/{len(profile_dirs)}: {profile_dir}")
try:
documents = reader.load_data(
chrome_profile_path=str(profile_dir),
max_count=max_count
)
if documents:
print(f"Loaded {len(documents)} history documents from {profile_dir}")
all_documents.extend(documents)
total_processed += len(documents)
# Check if we've reached the max count
if max_count > 0 and total_processed >= max_count:
print(f"Reached max count of {max_count} documents")
break
else:
print(f"No documents loaded from {profile_dir}")
except Exception as e:
print(f"Error processing {profile_dir}: {e}")
continue
if not all_documents:
print("No documents loaded from any source. Exiting.")
return None
print(f"\nTotal loaded {len(all_documents)} history documents from {len(profile_dirs)} profiles")
# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
# Convert Documents to text strings and chunk them
all_texts = []
for doc in all_documents:
# Split the document into chunks
nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
# Create LEANN index directory
print(f"--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="facebook/contriever",
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1 # Force single-threaded mode
)
print(f"Adding {len(all_texts)} history chunks to index...")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(index_path)
print(f"\nLEANN index built at {index_path}!")
else:
print(f"--- Using existing index at {INDEX_DIR} ---")
return index_path
async def query_leann_index(index_path: str, query: str):
"""
Query the LEANN index.
Args:
index_path: Path to the LEANN index
query: The query string
"""
print(f"\n[PHASE 2] Starting Leann chat session...")
chat = LeannChat(index_path=index_path)
print(f"You: {query}")
chat_response = chat.ask(
query,
top_k=10,
recompute_beighbor_embeddings=True,
complexity=32,
beam_width=1,
llm_config={
"type": "openai",
"model": "gpt-4o",
"api_key": os.getenv("OPENAI_API_KEY"),
},
llm_kwargs={
"temperature": 0.0,
"max_tokens": 1000
}
)
print(f"Leann: {chat_response}")
async def main():
# Parse command line arguments
parser = argparse.ArgumentParser(description='LEANN Chrome History Reader - Create and query browser history index')
parser.add_argument('--chrome-profile', type=str, default=DEFAULT_CHROME_PROFILE,
help=f'Path to Chrome profile directory (default: {DEFAULT_CHROME_PROFILE}), usually you dont need to change this')
parser.add_argument('--index-dir', type=str, default="./chrome_history_index_leann_test",
help='Directory to store the LEANN index (default: ./chrome_history_index_leann_test)')
parser.add_argument('--max-entries', type=int, default=1000,
help='Maximum number of history entries to process (default: 1000)')
parser.add_argument('--query', type=str, default=None,
help='Single query to run (default: runs example queries)')
parser.add_argument('--auto-find-profiles', action='store_true', default=True,
help='Automatically find all Chrome profiles (default: True)')
args = parser.parse_args()
INDEX_DIR = Path(args.index_dir)
INDEX_PATH = str(INDEX_DIR / "chrome_history.leann")
print(f"Using Chrome profile: {args.chrome_profile}")
print(f"Index directory: {INDEX_DIR}")
print(f"Max entries: {args.max_entries}")
# Find Chrome profile directories
from .readers import ChromeHistoryReader
if args.auto_find_profiles:
profile_dirs = ChromeHistoryReader.find_chrome_profiles()
if not profile_dirs:
print("No Chrome profiles found automatically. Exiting.")
return
else:
# Use single specified profile
profile_path = Path(args.chrome_profile)
if not profile_path.exists():
print(f"Chrome profile not found: {profile_path}")
return
profile_dirs = [profile_path]
# Create or load the LEANN index from all sources
index_path = create_leann_index_from_multiple_chrome_profiles(profile_dirs, INDEX_PATH, args.max_entries)
if index_path:
if args.query:
# Run single query
await query_leann_index(index_path, args.query)
else:
# Example queries
queries = [
"What websites did I visit about machine learning?",
"Find my search history about programming"
]
for query in queries:
print("\n" + "="*60)
await query_leann_index(index_path, query)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,176 +0,0 @@
import sqlite3
import os
from pathlib import Path
from typing import List, Any
from llama_index.core import Document
from llama_index.core.readers.base import BaseReader
class ChromeHistoryReader(BaseReader):
"""
Chrome browser history reader that extracts browsing data from SQLite database.
Reads Chrome history from the default Chrome profile location and creates documents
with embedded metadata similar to the email reader structure.
"""
def __init__(self) -> None:
"""Initialize."""
pass
def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
"""
Load Chrome history data from the default Chrome profile location.
Args:
input_dir: Not used for Chrome history (kept for compatibility)
**load_kwargs:
max_count (int): Maximum amount of history entries to read.
chrome_profile_path (str): Custom path to Chrome profile directory.
"""
docs: List[Document] = []
max_count = load_kwargs.get('max_count', 1000)
chrome_profile_path = load_kwargs.get('chrome_profile_path', None)
# Default Chrome profile path on macOS
if chrome_profile_path is None:
chrome_profile_path = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
history_db_path = os.path.join(chrome_profile_path, "History")
if not os.path.exists(history_db_path):
print(f"Chrome history database not found at: {history_db_path}")
return docs
try:
# Connect to the Chrome history database
print(f"Connecting to database: {history_db_path}")
conn = sqlite3.connect(history_db_path)
cursor = conn.cursor()
# Query to get browsing history with metadata (removed created_time column)
query = """
SELECT
datetime(last_visit_time/1000000-11644473600,'unixepoch','localtime') as last_visit,
url,
title,
visit_count,
typed_count,
hidden
FROM urls
ORDER BY last_visit_time DESC
"""
print(f"Executing query on database: {history_db_path}")
cursor.execute(query)
rows = cursor.fetchall()
print(f"Query returned {len(rows)} rows")
count = 0
for row in rows:
if count >= max_count and max_count > 0:
break
last_visit, url, title, visit_count, typed_count, hidden = row
# Create document content with metadata embedded in text
doc_content = f"""
[BROWSING HISTORY METADATA]
URL: {url}
Title: {title}
Last Visit: {last_visit}
Visit Count: {visit_count}
Typed Count: {typed_count}
Hidden: {hidden}
[END METADATA]
Title: {title}
URL: {url}
Last visited: {last_visit}
"""
# Create document with embedded metadata
doc = Document(text=doc_content, metadata={})
docs.append(doc)
count += 1
conn.close()
print(f"Loaded {len(docs)} Chrome history documents")
except Exception as e:
print(f"Error reading Chrome history: {e}")
return docs
return docs
@staticmethod
def find_chrome_profiles() -> List[Path]:
"""
Find all Chrome profile directories.
Returns:
List of Path objects pointing to Chrome profile directories
"""
chrome_base_path = Path(os.path.expanduser("~/Library/Application Support/Google/Chrome"))
profile_dirs = []
if not chrome_base_path.exists():
print(f"Chrome directory not found at: {chrome_base_path}")
return profile_dirs
# Find all profile directories
for profile_dir in chrome_base_path.iterdir():
if profile_dir.is_dir() and profile_dir.name != "System Profile":
history_path = profile_dir / "History"
if history_path.exists():
profile_dirs.append(profile_dir)
print(f"Found Chrome profile: {profile_dir}")
print(f"Found {len(profile_dirs)} Chrome profiles")
return profile_dirs
@staticmethod
def export_history_to_file(output_file: str = "chrome_history_export.txt", max_count: int = 1000):
"""
Export Chrome history to a text file using the same SQL query format.
Args:
output_file: Path to the output file
max_count: Maximum number of entries to export
"""
chrome_profile_path = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
history_db_path = os.path.join(chrome_profile_path, "History")
if not os.path.exists(history_db_path):
print(f"Chrome history database not found at: {history_db_path}")
return
try:
conn = sqlite3.connect(history_db_path)
cursor = conn.cursor()
query = """
SELECT
datetime(last_visit_time/1000000-11644473600,'unixepoch','localtime') as last_visit,
url,
title,
visit_count,
typed_count,
hidden
FROM urls
ORDER BY last_visit_time DESC
LIMIT ?
"""
cursor.execute(query, (max_count,))
rows = cursor.fetchall()
with open(output_file, 'w', encoding='utf-8') as f:
for row in rows:
last_visit, url, title, visit_count, typed_count, hidden = row
f.write(f"{last_visit}\t{url}\t{title}\t{visit_count}\t{typed_count}\t{hidden}\n")
conn.close()
print(f"Exported {len(rows)} history entries to {output_file}")
except Exception as e:
print(f"Error exporting Chrome history: {e}")

View File

View File

@@ -1,113 +0,0 @@
import argparse
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
import asyncio
import dotenv
from leann.api import LeannBuilder, LeannChat
from pathlib import Path
import os
dotenv.load_dotenv()
async def main(args):
INDEX_DIR = Path(args.index_dir)
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
if not INDEX_DIR.exists():
node_parser = SentenceSplitter(
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
)
print("Loading documents...")
# Get the data directory relative to this module
current_dir = Path(__file__).parent
data_dir = current_dir / "data"
documents = SimpleDirectoryReader(
str(data_dir),
recursive=True,
encoding="utf-8",
required_exts=[".pdf", ".txt", ".md"],
).load_data(show_progress=True)
print("Documents loaded.")
all_texts = []
for doc in documents:
nodes = node_parser.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
print("--- Index directory not found, building new index ---")
print("\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="facebook/contriever",
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1, # Force single-threaded mode
)
print(f"Loaded {len(all_texts)} text chunks from documents.")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(INDEX_PATH)
print(f"\nLeann index built at {INDEX_PATH}!")
else:
print(f"--- Using existing index at {INDEX_DIR} ---")
print(f"\n[PHASE 2] Starting Leann chat session...")
# llm_config = {"type": "hf", "model": "Qwen/Qwen3-4B"}
llm_config = {"type": "ollama", "model": "qwen3:8b"}
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
# query = (
# "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
# )
print(f"You: {query}")
chat_response = chat.ask(query, top_k=20, recompute_embeddings=True, complexity=32)
print(f"Leann: {chat_response}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Run Leann Chat with various LLM backends."
)
parser.add_argument(
"--llm",
type=str,
default="hf",
choices=["simulated", "ollama", "hf", "openai"],
help="The LLM backend to use.",
)
parser.add_argument(
"--model",
type=str,
default="Qwen/Qwen3-0.6B",
help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf, 'gpt-4o' for openai).",
)
parser.add_argument(
"--host",
type=str,
default="http://localhost:11434",
help="The host for the Ollama API.",
)
parser.add_argument(
"--index-dir",
type=str,
default="./test_doc_files",
help="Directory where the Leann index will be stored.",
)
args = parser.parse_args()
asyncio.run(main(args))

View File

@@ -1,82 +0,0 @@
# 盘古之殇:华为诺亚盘古大模型研发历程的心酸与黑暗
各位好,
我是一名盘古大模型团队,华为诺亚方舟实验室的员工。
首先为自证身份,列举一些细节:
1. 现诺亚主任,前算法应用部部长,后改名为小模型实验室的主任王云鹤。前诺亚主任:姚骏(大家称姚老师)。几个实验室主任:唐睿明(明哥,明队,已离职),尚利峰,张维(维哥),郝建业(郝老师),刘武龙(称呼为武龙所)等。其他骨干成员和专家陆续有很多人离职。
2. 我们隶属于“四野”这个组织。四野下属有许多纵队,基础语言大模型是四纵。王云鹤的小模型是十六纵队。我们参加过苏州的集结,有各种月份的时间节点。在苏州攻关会颁发任务令,需要在节点前达成目标。苏州集结会把各地的人员都集中在苏州研究所,平常住宾馆,比如在甪直的酒店,与家人孩子天各一方。
3. 在苏州集结的时候周六默认上班,非常辛苦,不过周六有下午茶,有一次还有小龙虾。在苏州研究所的工位搬迁过一次,从一栋楼换到了另一栋。苏州研究所楼栋都是欧式装修,门口有大坡,里面景色很不错。去苏州集结一般至少要去一周,甚至更久,多的人甚至一两个月都回不了家。
4. 诺亚曾经传说是研究型的但是来了之后因为在四野做大模型项目项目成员完全变成了交付型的且充满了例会评审汇报。很多时候做实验都要申请。团队需要对接终端小艺华为云ICT等诸多业务线交付压力不小。
5. 诺亚研发的盘古模型早期内部代号叫做“盘古智子”一开始只有内部需要申请试用的网页版到后续迫于压力在welink上接入和公测开放。
这些天发生关于质疑盘古大模型抄袭千问的事情闹的沸沸扬扬。作为一个盘古团队的成员,我最近夜夜辗转反侧,难以入眠。盘古的品牌受到如此大的影响,一方面,我自私的为我的职业发展担忧,也为自己过去的努力工作感到不值。另一方面,由于有人开始揭露这些事情我内心又感到大快人心。在多少个日日夜夜,我们对内部某些人一次次靠着造假而又获得了无数利益的行为咬牙切齿而又无能为力。这种压抑和羞辱也逐渐消磨了我对华为的感情,让我在这里的时日逐渐浑浑噩噩,迷茫无措,时常怀疑自己的人生和自我价值。
我承认我是一个懦弱的人,作为一个小小的打工人,我不仅不敢和王云鹤等内部手眼通天的人做对,更不敢和华为这样的庞然大物做对。我很怕失去我的工作,毕竟我也有家人和孩子,所以我打心眼里很佩服揭露者。但是,看到内部还在试图洗地掩盖事实,蒙蔽公众的时候,我实在不能容忍了。我也希望勇敢一次,顺从自己本心。就算自损八百,我也希望能伤敌一千。我决定把我在这里的所见所闻(部分来自于同事口述)公布出来,关于盘古大模型的“传奇故事”:
华为确实主要在昇腾卡上训练大模型小模型实验室有不少英伟达的卡他们之前也会用来训练后面转移到昇腾。曾经我被华为“打造世界第二选择”的决心而折服我本身也曾经对华为有深厚的感情。我们陪着昇腾一步步摸爬滚打从充满bug到现在能训出模型付出了巨大的心血和代价。
最初我们的算力非常有限在910A上训练模型。那会只支持fp16训练的稳定性远不如bf16。盘古的moe开始很早23年就主要是训练38Bmoe模型和后续的71B dense模型。71B的dense模型通过扩增变成了第一代的135Bdense模型后面主力模型也逐渐在910B上训练。
71B和135B模型都有一个巨大的硬伤就是tokenizer。当时使用的tokenizer编码效率极低每个单个的符号数字空格乃至汉字都会占用一个token。可想而知这会非常浪费算力且使得模型的效果很差。这时候小模型实验室正好有个自己训的词表。姚老师当时怀疑是不是模型的tokenizer不好虽然事后来看他的怀疑是无疑正确的于是就决定让71B和135B换tokenizer因为小模型实验室曾经尝试过。团队缝合了两个tokenizer开始了tokenizer的更换。71B模型的更换失败了而135B因为采用了更精细的embedding初始化策略续训了至少1T的数据后词表总算更换成功但可想而知效果并不会变好。
于此同期阿里和智谱等国内其他公司在GPU上训练且已经摸索出了正确的方法盘古和竞品的差距越来越大。内部一个230B从头训练的dense模型又因为各种原因训练失败导致项目的状况几乎陷入绝境。面临几个节点的压力以及内部对盘古的强烈质疑时团队的士气低迷到了极点。团队在算力极其有限的时候做出了很多努力和挣扎。比如团队偶然发现当时的38B moe并没有预期moe的效果。于是去掉了moe参数还原为了13B的dense模型。由于38B的moe源自很早的pangu alpha 13B架构相对落后团队进行了一系列的操作比如切换绝对位置编码到rope去掉bias切换为rmsnorm。同时鉴于tokenizer的一些失败和换词表的经验这个模型的词表也更换为了王云鹤的小模型实验室7B模型所使用的词表。后面这个13B模型进行了扩增续训变成了第二代38B dense模型在几个月内这个模型都是主要的盘古中档位模型曾经具有一定的竞争力。但是由于更大的135B模型架构落后且更换词表模型损伤巨大后续分析发现当时更换的缝合词表有更严重的bug续训后也与千问等当时国内领先模型存在很大差距。这时由于内部的质疑声和领导的压力也越来越大。团队的状态几乎陷入了绝境。
在这种情况下王云鹤和他的小模型实验室出手了。他们声称是从旧的135B参数继承改造而来通过训练短短的几百B数据各项指标平均提升了十个点左右。实际上这就是他们套壳应用到大模型的第一次杰作。华为的外行领导内行使得领导完全对于这种扯淡的事情没有概念他们只会觉得肯定是有什么算法创新。经过内部的分析他们实际上是使用Qwen 1.5 110B续训而来通过加层扩增ffn维度添加盘古pi论文的一些机制得来凑够了大概135B的参数。实际上旧的135B有107层而这个模型只有82层各种配置也都不一样。新的来路不明的135B训练完很多参数的分布也和Qwen 110B几乎一模一样。连模型代码的类名当时都是Qwen甚至懒得改名。后续这个模型就是所谓的135B V2。而这个模型当时也提供给了很多下游甚至包括外部客户。
这件事对于我们这些认真诚实做事的同事们带来了巨大的冲击内部很多人其实都知道这件事甚至包括终端和华为云。我们都戏称以后别叫盘古模型了叫千古吧。当时团队成员就想向bcg举报了毕竟这已经是重大的业务造假了。但是后面据说被领导拦了下来因为更高级别的领导比如姚老师以及可能熊总和查老其实后面也知道了但是并不管因为通过套壳拿出好的结果对他们也是有利的。这件事使得当时团队几位最强的同事开始心灰意冷离职跑路也逐渐成为挂在嘴边的事。
此时盘古似乎迎来了转机。由于前面所述的这些盘古模型基本都是续训和改造而来当时诺亚完全没有掌握从头训练的技术何况还是在昇腾的NPU上进行训练。在当时团队的核心成员的极力争取下盘古开始了第三代模型的训练付出了巨大的努力后在数据架构和训练算法方面都与业界逐渐接轨而这其中的艰辛和小模型实验室的人一点关系都没有。
一开始团队成员毫无信心只从一个13B的模型开始训练但是后面发现效果还不错于是这个模型后续再次进行了一次参数扩增变成了第三代的38B代号38B V3。想必很多产品线的兄弟都对这个模型很熟悉。当时这个模型的tokenizer是基于llama的词表进行扩展的也是业界常见的做法。而当时王云鹤的实验室做出来了另一个词表也就是后续pangu系列的词表。当时两个词表还被迫进行了一次赛马最终没有明显的好坏结论。于是领导当即决定应该统一词表使用王云鹤他们的。于是在后续从头训练的135B V3也就是对外的Pangu Ultra便是采用了这个tokenizer。这也解释了很多使用我们模型的兄弟的疑惑为什么当时同为V3代的两个不同档位的模型会使用不同的tokenizer。
我们打心眼里觉得135B V3是我们四纵团队当时的骄傲。这是第一个真正意义上的华为全栈自研正经从头训练的千亿级别的模型且效果与24年同期竞品可比的。写到这里我已经热泪盈眶太不容易了。当时为了稳定训练团队做了大量实验对比并且多次在模型梯度出现异常的时候进行及时回退重启。这个模型真正做到了后面技术报告所说的训练全程没有一个loss spike。我们克服了不知道多少困难我们做到了我们愿用生命和荣誉保证这个模型训练的真实性。多少个凌晨我们为了它的训练而不眠。在被内部心声骂的一文不值的时候我们有多么不甘有多少的委屈我们挺住了。
我们这帮人是真的在为打磨国产算力底座燃烧自己的青春啊……客居他乡,我们放弃了家庭,放弃了假期,放弃了健康,放弃了娱乐,抛头颅洒热血,其中的艰辛与困苦,寥寥数笔不足以概括其万一。在各种动员大会上,当时口号中喊出的盘古必胜,华为必胜,我们心里是真的深深被感动。
然而我们的所有辛苦的成果经常被小模型实验室轻飘飘的拿走了。数据直接要走。代码直接要走还要求我们配合适配到能一键运行。我们当时戏称小模型实验室为点鼠标实验室。我们付出辛苦他们取得荣耀。果然应了那句话你在负重前行是因为有人替你岁月静好。在这种情况下越来越多的战友再也坚持不下去了选择了离开。看到身边那些优秀的同事一个个离职我的内心又感叹又难过。在这种作战一样的环境下我们比起同事来说更像是战友。他们在技术上也有无数值得我学习的地方堪称良师。看到他们去了诸如字节SeedDeepseek月之暗面腾讯和快手等等很多出色的团队我打心眼里为他们高兴和祝福脱离了这个辛苦却肮脏的地方。我至今还对一位离职同事的话记忆犹新ta说“来这里是我技术生涯中的耻辱在这里再呆每一天都是浪费生命”。话虽难听却让我无言以对。我担心我自己技术方面的积累不足以及没法适应互联网公司高淘汰的环境让我多次想离职的心始终没有迈出这一步。
盘古除了dense模型后续也启动了moe的探索。一开始训练的是一个224B的moe模型。而与之平行的小模型实验室也开启了第二次主要的套壳行动次要的插曲可能还包括一些别的模型比如math模型即这次流传甚广的pangu pro moe 72B。这个模型内部自称是从小模型实验室的7B扩增上来的就算如此这也与技术报告不符何况是套壳qwen 2.5的14b续训。还记得他们训了没几天内部的评测就立刻追上了当时的38B V3。AI系统实验室很多兄弟因为需要适配模型都知道他们的套壳行动只是迫于各种原因无法伸张正义。实际上对于后续训了很久很久的这个模型Honestagi能够分析出这个量级的相似性我已经很诧异了因为这个模型为了续训洗参数所付出的算力甚至早就足够从头训一个同档位的模型了。听同事说他们为了洗掉千问的水印采取了不少办法甚至包括故意训了脏数据。这也为学术界研究模型血缘提供了一个前所未有的特殊模范吧。以后新的血缘方法提出可以拿出来溜溜。
24年底和25年初在Deepseek v3和r1发布之后由于其惊艳的技术水平团队受到了巨大的冲击也受到了更大的质疑。于是为了紧跟潮流盘古模仿Deepseek的模型尺寸开启了718B moe的训练。这个时候小模型实验室再次出手了。他们选择了套壳Deepseekv3续训。他们通过冻住Deepseek加载的参数进行训练。连任务加载ckpt的目录都是deepseekv3改都不改何其嚣张与之相反一些有真正技术信仰的同事在从头训练另一个718B的moe。但其中出现了各种各样的问题。但是很显然这个模型怎么可能比直接套壳的好呢如果不是团队leader坚持早就被叫停了。
华为的流程管理之繁重,严重拖累了大模型的研发节奏,例如版本管理,模型血缘,各种流程化,各种可追溯。讽刺的是,小模型实验室的模型似乎从来不受这些流程的约束,想套壳就套壳,想续训就续训,算力源源不断的伸手拿走。这种强烈到近乎魔幻的对比,说明了当前流程管理的情况:只许州官放火,不许百姓点灯。何其可笑?何其可悲?何其可恶?何其可耻!
HonestAGI的事情出来后内部让大家不停的研讨分析如何公关和“回应”。诚然这个原文的分析也许不够有力给了王云鹤与小模型实验室他们狡辩和颠倒黑白的机会。为此这两天我内心感到作呕时时怀疑自己的人生意义以及苍天无眼。我不奉陪了我要离职了同时我也在申请从盘古部分技术报告的作者名单中移除。曾经在这些技术报告上署名是我一生都无法抹除的污点。当时我没想到他们竟然猖狂到敢开源。我没想到他们敢如此愚弄世人大肆宣发。当时我也许是存了侥幸心理没有拒绝署名。我相信很多扎实做事的战友也只是被迫上了贼船或者不知情。但这件事已经无法挽回我希望我的余生能够坚持扎实做真正有意义的事为我当时的软弱和不坚定赎罪。
深夜写到这里,我已经泪流满面,泣不成声。还记得一些出色的同事离职时,我苦笑问他们要不要发个长长的心声惯例帖,揭露一下现状。对方说:不了,浪费时间,而且我也怕揭露出来你们过的更糟。我当时一下黯然神伤,因为曾经共同为了理想奋斗过的战友已经彻底对华为彻底灰心了。当时大家调侃,我们用着当年共产党的小米加步枪,组织却有着堪比当年国民党的作风。
曾几何时,我为我们用着小米加步枪打败洋枪洋炮而自豪。
现在,我累了,我想投降。
其实时至今日我还是真心希望华为能认真吸取教训能做好盘古把盘古做到世界一流把昇腾变成英伟达的水平。内部的劣币驱逐良币使得诺亚乃至华为在短时间内急剧流失了大量出色的大模型人才。相信他们也正在如Deepseek等各个团队闪耀着施展着他们的抱负才华为中美在AI的激烈竞赛中奉献力量。我时常感叹华为不是没有人才而是根本不知道怎么留住人才。如果给这些人合适的环境合适的资源更少的枷锁更少的政治斗争盘古何愁不成
最后:我以生命,人格和荣誉发誓,我写的以上所有内容均为真实(至少在我有限的认知范围内)。我没有那么高的技术水平以及机会去做详尽扎实的分析,也不敢直接用内部记录举证,怕因为信息安全抓到。但是我相信我很多曾经的战友,会为我作证。在华为内部的兄弟,包括我们曾经服务过的产品线兄弟们,相信本文的无数细节能和你们的印象对照,印证我的说法。你们可能也曾经被蒙骗,但这些残酷的真相不会被尘封。我们奋战过的痕迹,也不应该被扭曲和埋葬。
写了这么多,某些人肯定想把我找出来,抹杀掉。公司搞不好也想让我噤声乃至追责。如果真的这样,我,乃至我的家人的人身乃至生命安全可能都会受到威胁。为了自我保护,我近期每天会跟大家报平安。
如果我消失了就当是我为了真理和理想为了华为乃至中国能够更好地发展算力和AI而牺牲了吧我愿埋葬于那片曾经奋斗过的地方。
诺亚,再见
2025年7月6日凌晨 写于深圳
---
各位好,
感谢大家的关心与祝福。我目前暂时安全,但公司应该在进行排查与某些名单收集,后续情况未知。
我补充一些细节,以免某些人继续颠倒黑白。
关于135B V2小模型实验室在迅速地完成套壳并拿完所有套壳带来的好处后比如任务令表彰和及时激励因为不想继续支撑下游应用和模型迭代又把这个烫手山芋甩给了四纵。确实技高一筹直接把四纵的兄弟们拉下水。同事提供过去一个老旧的模型最终拿回了一个当时一个魔改的先进的千问。做大模型的人自己做的模型就像自己孩子一样熟悉不要把别人都当傻子。就像自家儿子出门一趟回来个别人家孩子。
盘古report的署名是不符合学术规范的。例如135B V3有不少有技术贡献的人因为作者名额数量限制劳动成果没有得到应有的回报团队内曾经有不小的意见。这个模型当时是大家智慧和汗水的结晶甚至是团队当时的精神支柱支撑着不少兄弟们继续留在诺亚。所谓的名额限制以及挂名了一些毫无技术贡献的人如一些小模型实验室的人让兄弟们何其心寒。
---
暂时平安。另外,支持我勇于说出真相的战友们 https://github.com/HW-whistleblower/True-Story-of-Pangu/issues/317

View File

View File

@@ -1,193 +0,0 @@
import os
import sys
import asyncio
import dotenv
import argparse
from pathlib import Path
from typing import List, Any
from leann.api import LeannBuilder, LeannSearcher, LeannChat
from llama_index.core.node_parser import SentenceSplitter
dotenv.load_dotenv()
# Auto-detect user's mail path
def get_mail_path():
"""Get the mail path for the current user"""
home_dir = os.path.expanduser("~")
return os.path.join(home_dir, "Library", "Mail")
def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_path: str = "mail_index.leann", max_count: int = -1, include_html: bool = False, embedding_model: str = "facebook/contriever"):
"""
Create LEANN index from multiple mail data sources.
Args:
messages_dirs: List of Path objects pointing to Messages directories
index_path: Path to save the LEANN index
max_count: Maximum number of emails to process per directory
include_html: Whether to include HTML content in email processing
"""
print("Creating LEANN index from multiple mail data sources...")
# Load documents using EmlxReader from local readers module
from .readers import EmlxReader, find_all_messages_directories
reader = EmlxReader(include_html=include_html)
INDEX_DIR = Path(index_path).parent
if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---")
all_documents = []
total_processed = 0
# Process each Messages directory
for i, messages_dir in enumerate(messages_dirs):
print(f"\nProcessing Messages directory {i+1}/{len(messages_dirs)}: {messages_dir}")
try:
documents = reader.load_data(messages_dir)
if documents:
print(f"Loaded {len(documents)} email documents from {messages_dir}")
all_documents.extend(documents)
total_processed += len(documents)
# Check if we've reached the max count
if max_count > 0 and total_processed >= max_count:
print(f"Reached max count of {max_count} documents")
break
else:
print(f"No documents loaded from {messages_dir}")
except Exception as e:
print(f"Error processing {messages_dir}: {e}")
continue
if not all_documents:
print("No documents loaded from any source. Exiting.")
return None
print(f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories")
# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
# Convert Documents to text strings and chunk them
all_texts = []
for doc in all_documents:
# Split the document into chunks
nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
# Create LEANN index directory
print(f"--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=embedding_model,
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1 # Force single-threaded mode
)
print(f"Adding {len(all_texts)} email chunks to index...")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(index_path)
print(f"\nLEANN index built at {index_path}!")
else:
print(f"--- Using existing index at {INDEX_DIR} ---")
return index_path
async def query_leann_index(index_path: str, query: str):
"""
Query the LEANN index.
Args:
index_path: Path to the LEANN index
query: The query string
"""
print(f"\n[PHASE 2] Starting Leann chat session...")
chat = LeannChat(index_path=index_path,
llm_config={"type": "openai", "model": "gpt-4o"})
print(f"You: {query}")
import time
start_time = time.time()
chat_response = chat.ask(
query,
top_k=10,
recompute_beighbor_embeddings=True,
complexity=12,
beam_width=1,
)
end_time = time.time()
print(f"Time taken: {end_time - start_time} seconds")
print(f"Leann: {chat_response}")
async def main():
# Parse command line arguments
parser = argparse.ArgumentParser(description='LEANN Mail Reader - Create and query email index')
parser.add_argument('--index-dir', type=str, default="./mail_index_leann_raw_text_all_dicts",
help='Directory to store the LEANN index (default: ./mail_index_leann_raw_text_all_dicts)')
parser.add_argument('--max-emails', type=int, default=1000,
help='Maximum number of emails to process (-1 means all)')
parser.add_argument('--query', type=str, default="Give me some funny advertisement about apple or other companies",
help='Single query to run (default: runs example queries)')
parser.add_argument('--include-html', action='store_true', default=False,
help='Include HTML content in email processing (default: False)')
parser.add_argument('--embedding-model', type=str, default="facebook/contriever",
help='Embedding model to use (default: facebook/contriever)')
args = parser.parse_args()
print(f"args: {args}")
# Automatically find all Messages directories under the current user's Mail directory
from .readers import find_all_messages_directories
mail_path = get_mail_path()
print(f"Searching for email data in: {mail_path}")
messages_dirs = find_all_messages_directories(mail_path)
print('len(messages_dirs): ', len(messages_dirs))
if not messages_dirs:
print("No Messages directories found. Exiting.")
return
INDEX_DIR = Path(args.index_dir)
INDEX_PATH = str(INDEX_DIR / "mail_documents.leann")
print(f"Index directory: {INDEX_DIR}")
print(f"Found {len(messages_dirs)} Messages directories.")
# Create or load the LEANN index from all sources
index_path = create_leann_index_from_multiple_sources(messages_dirs, INDEX_PATH, args.max_emails, args.include_html, args.embedding_model)
if index_path:
if args.query:
# Run single query
await query_leann_index(index_path, args.query)
else:
# Example queries
queries = [
"Hows Berkeley Graduate Student Instructor",
"how's the icloud related advertisement saying",
"Whats the number of class recommend to take per semester for incoming EECS students"
]
for query in queries:
print("\n" + "="*60)
await query_leann_index(index_path, query)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,192 +0,0 @@
"""
Mbox parser.
Contains simple parser for mbox files.
"""
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional
from fsspec import AbstractFileSystem
from llama_index.core.readers.base import BaseReader
from llama_index.core.schema import Document
logger = logging.getLogger(__name__)
class MboxReader(BaseReader):
"""
Mbox parser.
Extract messages from mailbox files.
Returns string including date, subject, sender, receiver and
content for each message.
"""
DEFAULT_MESSAGE_FORMAT: str = (
"Date: {_date}\n"
"From: {_from}\n"
"To: {_to}\n"
"Subject: {_subject}\n"
"Content: {_content}"
)
def __init__(
self,
*args: Any,
max_count: int = 0,
message_format: str = DEFAULT_MESSAGE_FORMAT,
**kwargs: Any,
) -> None:
"""Init params."""
try:
from bs4 import BeautifulSoup # noqa
except ImportError:
raise ImportError(
"`beautifulsoup4` package not found: `pip install beautifulsoup4`"
)
super().__init__(*args, **kwargs)
self.max_count = max_count
self.message_format = message_format
def load_data(
self,
file: Path,
extra_info: Optional[Dict] = None,
fs: Optional[AbstractFileSystem] = None,
) -> List[Document]:
"""Parse file into string."""
# Import required libraries
import mailbox
from email.parser import BytesParser
from email.policy import default
from bs4 import BeautifulSoup
if fs:
logger.warning(
"fs was specified but MboxReader doesn't support loading "
"from fsspec filesystems. Will load from local filesystem instead."
)
i = 0
results: List[str] = []
# Load file using mailbox
bytes_parser = BytesParser(policy=default).parse
mbox = mailbox.mbox(file, factory=bytes_parser) # type: ignore
# Iterate through all messages
for _, _msg in enumerate(mbox):
try:
msg: mailbox.mboxMessage = _msg
# Parse multipart messages
if msg.is_multipart():
for part in msg.walk():
ctype = part.get_content_type()
cdispo = str(part.get("Content-Disposition"))
if "attachment" in cdispo:
print(f"Attachment found: {part.get_filename()}")
if ctype == "text/plain" and "attachment" not in cdispo:
content = part.get_payload(decode=True) # decode
break
# Get plain message payload for non-multipart messages
else:
content = msg.get_payload(decode=True)
# Parse message HTML content and remove unneeded whitespace
soup = BeautifulSoup(content)
stripped_content = " ".join(soup.get_text().split())
# Format message to include date, sender, receiver and subject
msg_string = self.message_format.format(
_date=msg["date"],
_from=msg["from"],
_to=msg["to"],
_subject=msg["subject"],
_content=stripped_content,
)
# Add message string to results
results.append(msg_string)
except Exception as e:
logger.warning(f"Failed to parse message:\n{_msg}\n with exception {e}")
# Increment counter and return if max count is met
i += 1
if self.max_count > 0 and i >= self.max_count:
break
return [Document(text=result, metadata=extra_info or {}) for result in results]
class EmlxMboxReader(MboxReader):
"""
EmlxMboxReader - Modified MboxReader that handles directories of .emlx files.
Extends MboxReader to work with Apple Mail's .emlx format by:
1. Reading .emlx files from a directory
2. Converting them to mbox format in memory
3. Using the parent MboxReader's parsing logic
"""
def load_data(
self,
directory: Path,
extra_info: Optional[Dict] = None,
fs: Optional[AbstractFileSystem] = None,
) -> List[Document]:
"""Parse .emlx files from directory into strings using MboxReader logic."""
import tempfile
import os
if fs:
logger.warning(
"fs was specified but EmlxMboxReader doesn't support loading "
"from fsspec filesystems. Will load from local filesystem instead."
)
# Find all .emlx files in the directory
emlx_files = list(directory.glob("*.emlx"))
logger.info(f"Found {len(emlx_files)} .emlx files in {directory}")
if not emlx_files:
logger.warning(f"No .emlx files found in {directory}")
return []
# Create a temporary mbox file
with tempfile.NamedTemporaryFile(mode='w', suffix='.mbox', delete=False) as temp_mbox:
temp_mbox_path = temp_mbox.name
# Convert .emlx files to mbox format
for emlx_file in emlx_files:
try:
# Read the .emlx file
with open(emlx_file, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read()
# .emlx format: first line is length, rest is email content
lines = content.split('\n', 1)
if len(lines) >= 2:
email_content = lines[1] # Skip the length line
# Write to mbox format (each message starts with "From " and ends with blank line)
temp_mbox.write(f"From {emlx_file.name} {email_content}\n\n")
except Exception as e:
logger.warning(f"Failed to process {emlx_file}: {e}")
continue
# Close the temporary file so MboxReader can read it
temp_mbox.close()
try:
# Use the parent MboxReader's logic to parse the mbox file
return super().load_data(Path(temp_mbox_path), extra_info, fs)
finally:
# Clean up temporary file
try:
os.unlink(temp_mbox_path)
except:
pass

View File

@@ -1,124 +0,0 @@
import os
import email
from pathlib import Path
from typing import List, Any
from llama_index.core import Document
from llama_index.core.readers.base import BaseReader
def find_all_messages_directories(root: str = None) -> List[Path]:
"""
Recursively find all 'Messages' directories under the given root.
Returns a list of Path objects.
"""
if root is None:
# Auto-detect user's mail path
home_dir = os.path.expanduser("~")
root = os.path.join(home_dir, "Library", "Mail")
messages_dirs = []
for dirpath, dirnames, filenames in os.walk(root):
if os.path.basename(dirpath) == "Messages":
messages_dirs.append(Path(dirpath))
return messages_dirs
class EmlxReader(BaseReader):
"""
Apple Mail .emlx file reader with embedded metadata.
Reads individual .emlx files from Apple Mail's storage format.
"""
def __init__(self, include_html: bool = False) -> None:
"""
Initialize.
Args:
include_html: Whether to include HTML content in the email body (default: False)
"""
self.include_html = include_html
def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]:
"""
Load data from the input directory containing .emlx files.
Args:
input_dir: Directory containing .emlx files
**load_kwargs:
max_count (int): Maximum amount of messages to read.
"""
docs: List[Document] = []
max_count = load_kwargs.get('max_count', 1000)
count = 0
# Walk through the directory recursively
for dirpath, dirnames, filenames in os.walk(input_dir):
# Skip hidden directories
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
for filename in filenames:
if count >= max_count:
break
if filename.endswith(".emlx"):
filepath = os.path.join(dirpath, filename)
try:
# Read the .emlx file
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read()
# .emlx files have a length prefix followed by the email content
# The first line contains the length, followed by the email
lines = content.split('\n', 1)
if len(lines) >= 2:
email_content = lines[1]
# Parse the email using Python's email module
try:
msg = email.message_from_string(email_content)
# Extract email metadata
subject = msg.get('Subject', 'No Subject')
from_addr = msg.get('From', 'Unknown')
to_addr = msg.get('To', 'Unknown')
date = msg.get('Date', 'Unknown')
# Extract email body
body = ""
if msg.is_multipart():
for part in msg.walk():
if part.get_content_type() == "text/plain" or part.get_content_type() == "text/html":
if part.get_content_type() == "text/html" and not self.include_html:
continue
body += part.get_payload(decode=True).decode('utf-8', errors='ignore')
# break
else:
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore')
# Create document content with metadata embedded in text
doc_content = f"""
[EMAIL METADATA]
File: {filename}
From: {from_addr}
To: {to_addr}
Subject: {subject}
Date: {date}
[END METADATA]
{body}
"""
# No separate metadata - everything is in the text
doc = Document(text=doc_content, metadata={})
docs.append(doc)
count += 1
except Exception as e:
print(f"Error parsing email from {filepath}: {e}")
continue
except Exception as e:
print(f"Error reading file {filepath}: {e}")
continue
print(f"Loaded {len(docs)} email documents")
return docs

View File

View File

@@ -1,382 +0,0 @@
#!/usr/bin/env python3
"""
This script runs a recall evaluation on a given LEANN index.
It correctly compares results by fetching the text content for both the new search
results and the golden standard results, making the comparison robust to ID changes.
"""
import json
import argparse
import time
from pathlib import Path
import sys
import numpy as np
from typing import List
from leann.api import LeannSearcher, LeannBuilder
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
"""Checks if the data directory exists, and if not, downloads it from HF Hub."""
if not data_root.exists():
print(f"Data directory '{data_root}' not found.")
print(
"Downloading evaluation data from Hugging Face Hub... (this may take a moment)"
)
try:
from huggingface_hub import snapshot_download
if download_embeddings:
# Download everything including embeddings (large files)
snapshot_download(
repo_id="LEANN-RAG/leann-rag-evaluation-data",
repo_type="dataset",
local_dir=data_root,
local_dir_use_symlinks=False,
)
print("Data download complete (including embeddings)!")
else:
# Download only specific folders, excluding embeddings
allow_patterns = [
"ground_truth/**",
"indices/**",
"queries/**",
"*.md",
"*.txt",
]
snapshot_download(
repo_id="LEANN-RAG/leann-rag-evaluation-data",
repo_type="dataset",
local_dir=data_root,
local_dir_use_symlinks=False,
allow_patterns=allow_patterns,
)
print("Data download complete (excluding embeddings)!")
except ImportError:
print(
"Error: huggingface_hub is not installed. Please install it to download the data:"
)
print("uv pip install -e '.[dev]'")
sys.exit(1)
except Exception as e:
print(f"An error occurred during data download: {e}")
sys.exit(1)
def download_embeddings_if_needed(data_root: Path, dataset_type: str = None):
"""Download embeddings files specifically."""
embeddings_dir = data_root / "embeddings"
if dataset_type:
# Check if specific dataset embeddings exist
target_file = embeddings_dir / dataset_type / "passages_00.pkl"
if target_file.exists():
print(f"Embeddings for {dataset_type} already exist")
return str(target_file)
print("Downloading embeddings from HuggingFace Hub...")
try:
from huggingface_hub import snapshot_download
# Download only embeddings folder
snapshot_download(
repo_id="LEANN-RAG/leann-rag-evaluation-data",
repo_type="dataset",
local_dir=data_root,
local_dir_use_symlinks=False,
allow_patterns=["embeddings/**/*.pkl"],
)
print("Embeddings download complete!")
if dataset_type:
target_file = embeddings_dir / dataset_type / "passages_00.pkl"
if target_file.exists():
return str(target_file)
return str(embeddings_dir)
except Exception as e:
print(f"Error downloading embeddings: {e}")
sys.exit(1)
# --- Helper Function to get Golden Passages ---
def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set:
"""
Retrieves the text for golden passage IDs directly from the LeannSearcher's
passage manager.
"""
golden_texts = set()
for gid in golden_ids:
try:
# PassageManager uses string IDs
passage_data = searcher.passage_manager.get_passage(str(gid))
golden_texts.add(passage_data["text"])
except KeyError:
print(
f"Warning: Golden passage ID '{gid}' not found in the index's passage data."
)
return golden_texts
def load_queries(file_path: Path) -> List[str]:
queries = []
with open(file_path, "r", encoding="utf-8") as f:
for line in f:
data = json.loads(line)
queries.append(data["query"])
return queries
def build_index_from_embeddings(
embeddings_file: str, output_path: str, backend: str = "hnsw"
):
"""
Build a LEANN index from pre-computed embeddings.
Args:
embeddings_file: Path to pickle file with (ids, embeddings) tuple
output_path: Path where to save the index
backend: Backend to use ("hnsw" or "diskann")
"""
print(f"Building {backend} index from embeddings: {embeddings_file}")
# Create builder with appropriate parameters
if backend == "hnsw":
builder_kwargs = {
"M": 32, # Graph degree
"efConstruction": 256, # Construction complexity
"is_compact": True, # Use compact storage
"is_recompute": True, # Enable pruning for better recall
}
elif backend == "diskann":
builder_kwargs = {
"complexity": 64,
"graph_degree": 32,
"search_memory_maximum": 8.0, # GB
"build_memory_maximum": 16.0, # GB
}
else:
builder_kwargs = {}
builder = LeannBuilder(
backend_name=backend,
embedding_model="facebook/contriever-msmarco", # Model used to create embeddings
dimensions=768, # Will be auto-detected from embeddings
**builder_kwargs,
)
# Build index from precomputed embeddings
builder.build_index_from_embeddings(output_path, embeddings_file)
print(f"Index saved to: {output_path}")
return output_path
def main():
parser = argparse.ArgumentParser(
description="Run recall evaluation on a LEANN index."
)
parser.add_argument(
"index_path",
type=str,
nargs="?",
help="Path to the LEANN index to evaluate or build (optional).",
)
parser.add_argument(
"--mode",
choices=["evaluate", "build"],
default="evaluate",
help="Mode: 'evaluate' existing index or 'build' from embeddings",
)
parser.add_argument(
"--embeddings-file",
type=str,
help="Path to embeddings pickle file (optional for build mode)",
)
parser.add_argument(
"--backend",
choices=["hnsw", "diskann"],
default="hnsw",
help="Backend to use for building index (default: hnsw)",
)
parser.add_argument(
"--num-queries", type=int, default=10, help="Number of queries to evaluate."
)
parser.add_argument(
"--top-k", type=int, default=3, help="The 'k' value for recall@k."
)
parser.add_argument(
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
)
args = parser.parse_args()
# --- Path Configuration ---
# Assumes a project structure where the script is in 'examples/'
# and data is in 'data/' at the project root.
project_root = Path(__file__).resolve().parent.parent
data_root = project_root / "data"
# Download data based on mode
if args.mode == "build":
# For building mode, we need embeddings
download_data_if_needed(
data_root, download_embeddings=False
) # Basic data first
# Auto-detect dataset type and download embeddings
if args.embeddings_file:
embeddings_file = args.embeddings_file
# Try to detect dataset type from embeddings file path
if "rpj_wiki" in str(embeddings_file):
dataset_type = "rpj_wiki"
elif "dpr" in str(embeddings_file):
dataset_type = "dpr"
else:
dataset_type = "dpr" # Default
else:
# Auto-detect from index path if provided, otherwise default to DPR
if args.index_path:
index_path_str = str(args.index_path)
if "rpj_wiki" in index_path_str:
dataset_type = "rpj_wiki"
elif "dpr" in index_path_str:
dataset_type = "dpr"
else:
dataset_type = "dpr" # Default to DPR
else:
dataset_type = "dpr" # Default to DPR
embeddings_file = download_embeddings_if_needed(data_root, dataset_type)
# Auto-generate index path if not provided
if not args.index_path:
indices_dir = data_root / "indices" / dataset_type
indices_dir.mkdir(parents=True, exist_ok=True)
args.index_path = str(indices_dir / f"{dataset_type}_from_embeddings")
print(f"Auto-generated index path: {args.index_path}")
print(f"Building index from embeddings: {embeddings_file}")
built_index_path = build_index_from_embeddings(
embeddings_file, args.index_path, args.backend
)
print(f"Index built successfully: {built_index_path}")
# Ask if user wants to run evaluation
eval_response = (
input("Run evaluation on the built index? (y/n): ").strip().lower()
)
if eval_response != "y":
print("Index building complete. Exiting.")
return
else:
# For evaluation mode, don't need embeddings
download_data_if_needed(data_root, download_embeddings=False)
# Auto-detect index path if not provided
if not args.index_path:
# Default to using downloaded indices
indices_dir = data_root / "indices"
# Try common datasets in order of preference
for dataset in ["dpr", "rpj_wiki"]:
dataset_dir = indices_dir / dataset
if dataset_dir.exists():
# Look for index files
index_files = list(dataset_dir.glob("*.index")) + list(
dataset_dir.glob("*_disk.index")
)
if index_files:
args.index_path = str(
index_files[0].with_suffix("")
) # Remove .index extension
print(f"Using index: {args.index_path}")
break
if not args.index_path:
print(
"No indices found. The data download should have included pre-built indices."
)
print(
"Please check the data/indices/ directory or provide --index-path manually."
)
sys.exit(1)
# Detect dataset type from index path to select the correct ground truth
index_path_str = str(args.index_path)
if "rpj_wiki" in index_path_str:
dataset_type = "rpj_wiki"
elif "dpr" in index_path_str:
dataset_type = "dpr"
else:
# Fallback: try to infer from the index directory name
dataset_type = Path(args.index_path).name
print(
f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'."
)
queries_file = data_root / "queries" / "nq_open.jsonl"
golden_results_file = (
data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
)
print(f"INFO: Detected dataset type: {dataset_type}")
print(f"INFO: Using queries file: {queries_file}")
print(f"INFO: Using ground truth file: {golden_results_file}")
try:
searcher = LeannSearcher(args.index_path)
queries = load_queries(queries_file)
with open(golden_results_file, "r") as f:
golden_results_data = json.load(f)
num_eval_queries = min(args.num_queries, len(queries))
queries = queries[:num_eval_queries]
print(f"\nRunning evaluation on {num_eval_queries} queries...")
recall_scores = []
search_times = []
for i in range(num_eval_queries):
start_time = time.time()
new_results = searcher.search(
queries[i], top_k=args.top_k, ef=args.ef_search
)
search_times.append(time.time() - start_time)
# Correct Recall Calculation: Based on TEXT content
new_texts = {result.text for result in new_results}
# Get golden texts directly from the searcher's passage manager
golden_ids = golden_results_data["indices"][i][: args.top_k]
golden_texts = get_golden_texts(searcher, golden_ids)
overlap = len(new_texts & golden_texts)
recall = overlap / len(golden_texts) if golden_texts else 0
recall_scores.append(recall)
print("\n--- EVALUATION RESULTS ---")
print(f"Query: {queries[i]}")
print(f"New Results: {new_texts}")
print(f"Golden Results: {golden_texts}")
print(f"Overlap: {overlap}")
print(f"Recall: {recall}")
print(f"Search Time: {search_times[-1]:.4f}s")
print("--------------------------------")
avg_recall = np.mean(recall_scores) if recall_scores else 0
avg_time = np.mean(search_times) if search_times else 0
print("\n🎉 --- Evaluation Complete ---")
print(f"Avg. Recall@{args.top_k} (efSearch={args.ef_search}): {avg_recall:.4f}")
print(f"Avg. Search Time: {avg_time:.4f}s")
except Exception as e:
print(f"\n❌ An error occurred during evaluation: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()

View File

View File

@@ -1,230 +0,0 @@
import os
import asyncio
import dotenv
import argparse
from pathlib import Path
from typing import List, Any, Optional
from leann.api import LeannBuilder, LeannSearcher, LeannChat
from llama_index.core.node_parser import SentenceSplitter
import requests
import time
dotenv.load_dotenv()
# Default WeChat export directory
DEFAULT_WECHAT_EXPORT_DIR = "./wechat_export_direct"
def create_leann_index_from_multiple_wechat_exports(
export_dirs: List[Path],
index_path: str = "wechat_history_index.leann",
max_count: int = -1,
):
"""
Create LEANN index from multiple WeChat export data sources.
Args:
export_dirs: List of Path objects pointing to WeChat export directories
index_path: Path to save the LEANN index
max_count: Maximum number of chat entries to process per export
"""
print("Creating LEANN index from multiple WeChat export data sources...")
# Load documents using WeChatHistoryReader from local readers module
from .readers import WeChatHistoryReader
reader = WeChatHistoryReader()
INDEX_DIR = Path(index_path).parent
if not INDEX_DIR.exists():
print(f"--- Index directory not found, building new index ---")
all_documents = []
total_processed = 0
# Process each WeChat export directory
for i, export_dir in enumerate(export_dirs):
print(
f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}"
)
try:
documents = reader.load_data(
wechat_export_dir=str(export_dir),
max_count=max_count,
concatenate_messages=True, # Disable concatenation - one message per document
)
if documents:
print(f"Loaded {len(documents)} chat documents from {export_dir}")
all_documents.extend(documents)
total_processed += len(documents)
# Check if we've reached the max count
if max_count > 0 and total_processed >= max_count:
print(f"Reached max count of {max_count} documents")
break
else:
print(f"No documents loaded from {export_dir}")
except Exception as e:
print(f"Error processing {export_dir}: {e}")
continue
if not all_documents:
print("No documents loaded from any source. Exiting.")
return None
print(
f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports"
)
# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
# Convert Documents to text strings and chunk them
all_texts = []
for doc in all_documents:
# Split the document into chunks
nodes = text_splitter.get_nodes_from_documents([doc])
for node in nodes:
text = '[Contact] means the message is from: ' + doc.metadata["contact_name"] + '\n' + node.get_content()
all_texts.append(text)
print(
f"Created {len(all_texts)} text chunks from {len(all_documents)} documents"
)
# Create LEANN index directory
print(f"--- Index directory not found, building new index ---")
INDEX_DIR.mkdir(exist_ok=True)
print(f"--- Building new LEANN index ---")
print(f"\n[PHASE 1] Building Leann index...")
# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="Qwen/Qwen3-Embedding-0.6B",
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
num_threads=1, # Force single-threaded mode
)
print(f"Adding {len(all_texts)} chat chunks to index...")
for chunk_text in all_texts:
builder.add_text(chunk_text)
builder.build_index(index_path)
print(f"\nLEANN index built at {index_path}!")
else:
print(f"--- Using existing index at {INDEX_DIR} ---")
return index_path
async def query_leann_index(index_path: str, query: str):
"""
Query the LEANN index.
Args:
index_path: Path to the LEANN index
query: The query string
"""
print(f"\n[PHASE 2] Starting Leann chat session...")
chat = LeannChat(index_path=index_path)
print(f"You: {query}")
chat_response = chat.ask(
query,
top_k=20,
recompute_beighbor_embeddings=True,
complexity=16,
beam_width=1,
llm_config={
"type": "openai",
"model": "gpt-4o",
"api_key": os.getenv("OPENAI_API_KEY"),
},
llm_kwargs={"temperature": 0.0, "max_tokens": 1000},
)
print(f"Leann: {chat_response}")
async def main():
"""Main function with integrated WeChat export functionality."""
# Parse command line arguments
parser = argparse.ArgumentParser(
description="LEANN WeChat History Reader - Create and query WeChat chat history index"
)
parser.add_argument(
"--export-dir",
type=str,
default=DEFAULT_WECHAT_EXPORT_DIR,
help=f"Directory to store WeChat exports (default: {DEFAULT_WECHAT_EXPORT_DIR})",
)
parser.add_argument(
"--index-dir",
type=str,
default="./wechat_history_magic_test_11Debug_new",
help="Directory to store the LEANN index (default: ./wechat_history_index_leann_test)",
)
parser.add_argument(
"--max-entries",
type=int,
default=50,
help="Maximum number of chat entries to process (default: 5000)",
)
parser.add_argument(
"--query",
type=str,
default=None,
help="Single query to run (default: runs example queries)",
)
parser.add_argument(
"--force-export",
action="store_true",
default=False,
help="Force re-export of WeChat data even if exports exist",
)
args = parser.parse_args()
INDEX_DIR = Path(args.index_dir)
INDEX_PATH = str(INDEX_DIR / "wechat_history.leann")
print(f"Using WeChat export directory: {args.export_dir}")
print(f"Index directory: {INDEX_DIR}")
print(f"Max entries: {args.max_entries}")
# Initialize WeChat reader with export capabilities
from .readers import WeChatHistoryReader
reader = WeChatHistoryReader()
# Find existing exports or create new ones using the centralized method
export_dirs = reader.find_or_export_wechat_data(args.export_dir)
if not export_dirs:
print("Failed to find or export WeChat data. Exiting.")
return
# Create or load the LEANN index from all sources
index_path = create_leann_index_from_multiple_wechat_exports(
export_dirs, INDEX_PATH, max_count=args.max_entries
)
if index_path:
if args.query:
# Run single query
await query_leann_index(index_path, args.query)
else:
# Example queries
queries = [
"我想买魔术师约翰逊的球衣,给我一些对应聊天记录?",
]
for query in queries:
print("\n" + "=" * 60)
await query_leann_index(index_path, query)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,719 +0,0 @@
import json
import os
import re
import subprocess
import sys
import time
from pathlib import Path
from typing import List, Any, Dict, Optional
from llama_index.core import Document
from llama_index.core.readers.base import BaseReader
from datetime import datetime
class WeChatHistoryReader(BaseReader):
"""
WeChat chat history reader that extracts chat data from exported JSON files.
Reads WeChat chat history from exported JSON files (from wechat-exporter tool)
and creates documents with embedded metadata similar to the Chrome history reader structure.
Also includes utilities for automatic WeChat chat history export.
"""
def __init__(self) -> None:
"""Initialize."""
self.packages_dir = Path(__file__).parent.parent.parent / "packages"
self.wechat_exporter_dir = self.packages_dir / "wechat-exporter"
self.wechat_decipher_dir = self.packages_dir / "wechat-decipher-macos"
def check_wechat_running(self) -> bool:
"""Check if WeChat is currently running."""
try:
result = subprocess.run(["pgrep", "-f", "WeChat"], capture_output=True, text=True)
return result.returncode == 0
except Exception:
return False
def install_wechattweak(self) -> bool:
"""Install WeChatTweak CLI tool."""
try:
# Create wechat-exporter directory if it doesn't exist
self.wechat_exporter_dir.mkdir(parents=True, exist_ok=True)
wechattweak_path = self.wechat_exporter_dir / "wechattweak-cli"
if not wechattweak_path.exists():
print("Downloading WeChatTweak CLI...")
subprocess.run([
"curl", "-L", "-o", str(wechattweak_path),
"https://github.com/JettChenT/WeChatTweak-CLI/releases/latest/download/wechattweak-cli"
], check=True)
# Make executable
wechattweak_path.chmod(0o755)
# Install WeChatTweak
print("Installing WeChatTweak...")
subprocess.run(["sudo", str(wechattweak_path), "install"], check=True)
return True
except Exception as e:
print(f"Error installing WeChatTweak: {e}")
return False
def restart_wechat(self):
"""Restart WeChat to apply WeChatTweak."""
try:
print("Restarting WeChat...")
subprocess.run(["pkill", "-f", "WeChat"], check=False)
time.sleep(2)
subprocess.run(["open", "-a", "WeChat"], check=True)
time.sleep(5) # Wait for WeChat to start
except Exception as e:
print(f"Error restarting WeChat: {e}")
def check_api_available(self) -> bool:
"""Check if WeChatTweak API is available."""
try:
result = subprocess.run([
"curl", "-s", "http://localhost:48065/wechat/allcontacts"
], capture_output=True, text=True, timeout=5)
return result.returncode == 0 and result.stdout.strip()
except Exception:
return False
def _extract_readable_text(self, content: str) -> str:
"""
Extract readable text from message content, removing XML and system messages.
Args:
content: The raw message content (can be string or dict)
Returns:
Cleaned, readable text
"""
if not content:
return ""
# Handle dictionary content (like quoted messages)
if isinstance(content, dict):
# Extract text from dictionary structure
text_parts = []
if 'title' in content:
text_parts.append(str(content['title']))
if 'quoted' in content:
text_parts.append(str(content['quoted']))
if 'content' in content:
text_parts.append(str(content['content']))
if 'text' in content:
text_parts.append(str(content['text']))
if text_parts:
return " | ".join(text_parts)
else:
# If we can't extract meaningful text from dict, return empty
return ""
# Handle string content
if not isinstance(content, str):
return ""
# Remove common prefixes like "wxid_xxx:\n"
clean_content = re.sub(r'^wxid_[^:]+:\s*', '', content)
clean_content = re.sub(r'^[^:]+:\s*', '', clean_content)
# If it's just XML or system message, return empty
if clean_content.strip().startswith('<') or 'recalled a message' in clean_content:
return ""
return clean_content.strip()
def _is_text_message(self, content: str) -> bool:
"""
Check if a message contains readable text content.
Args:
content: The message content (can be string or dict)
Returns:
True if the message contains readable text, False otherwise
"""
if not content:
return False
# Handle dictionary content
if isinstance(content, dict):
# Check if dict has any readable text fields
text_fields = ['title', 'quoted', 'content', 'text']
for field in text_fields:
if field in content and content[field]:
return True
return False
# Handle string content
if not isinstance(content, str):
return False
# Skip image messages (contain XML with img tags)
if '<img' in content and 'cdnurl' in content:
return False
# Skip emoji messages (contain emoji XML tags)
if '<emoji' in content and 'productid' in content:
return False
# Skip voice messages
if '<voice' in content:
return False
# Skip video messages
if '<video' in content:
return False
# Skip file messages
if '<appmsg' in content and 'appid' in content:
return False
# Skip system messages (like "recalled a message")
if 'recalled a message' in content:
return False
# Check if there's actual readable text (not just XML or system messages)
# Remove common prefixes like "wxid_xxx:\n" and check for actual content
clean_content = re.sub(r'^wxid_[^:]+:\s*', '', content)
clean_content = re.sub(r'^[^:]+:\s*', '', clean_content)
# If after cleaning we have meaningful text, consider it readable
if len(clean_content.strip()) > 0 and not clean_content.strip().startswith('<'):
return True
return False
def _concatenate_messages(self, messages: List[Dict], max_length: int = 128,
time_window_minutes: int = 30, overlap_messages: int = 0) -> List[Dict]:
"""
Concatenate messages based on length and time rules.
Args:
messages: List of message dictionaries
max_length: Maximum length for concatenated message groups. Use -1 to disable length constraint.
time_window_minutes: Time window in minutes to group messages together. Use -1 to disable time constraint.
overlap_messages: Number of messages to overlap between consecutive groups
Returns:
List of concatenated message groups
"""
if not messages:
return []
concatenated_groups = []
current_group = []
current_length = 0
last_timestamp = None
for message in messages:
# Extract message info
content = message.get('content', '')
message_text = message.get('message', '')
create_time = message.get('createTime', 0)
from_user = message.get('fromUser', '')
to_user = message.get('toUser', '')
is_sent_from_self = message.get('isSentFromSelf', False)
# Extract readable text
readable_text = self._extract_readable_text(content)
if not readable_text:
readable_text = message_text
# Skip empty messages
if not readable_text.strip():
continue
# Check time window constraint (only if time_window_minutes != -1)
if time_window_minutes != -1 and last_timestamp is not None and create_time > 0:
time_diff_minutes = (create_time - last_timestamp) / 60
if time_diff_minutes > time_window_minutes:
# Time gap too large, start new group
if current_group:
concatenated_groups.append({
'messages': current_group,
'total_length': current_length,
'start_time': current_group[0].get('createTime', 0),
'end_time': current_group[-1].get('createTime', 0)
})
# Keep last few messages for overlap
if overlap_messages > 0 and len(current_group) > overlap_messages:
current_group = current_group[-overlap_messages:]
current_length = sum(len(self._extract_readable_text(msg.get('content', '')) or msg.get('message', '')) for msg in current_group)
else:
current_group = []
current_length = 0
# Check length constraint (only if max_length != -1)
message_length = len(readable_text)
if max_length != -1 and current_length + message_length > max_length and current_group:
# Current group would exceed max length, save it and start new
concatenated_groups.append({
'messages': current_group,
'total_length': current_length,
'start_time': current_group[0].get('createTime', 0),
'end_time': current_group[-1].get('createTime', 0)
})
# Keep last few messages for overlap
if overlap_messages > 0 and len(current_group) > overlap_messages:
current_group = current_group[-overlap_messages:]
current_length = sum(len(self._extract_readable_text(msg.get('content', '')) or msg.get('message', '')) for msg in current_group)
else:
current_group = []
current_length = 0
# Add message to current group
current_group.append(message)
current_length += message_length
last_timestamp = create_time
# Add the last group if it exists
if current_group:
concatenated_groups.append({
'messages': current_group,
'total_length': current_length,
'start_time': current_group[0].get('createTime', 0),
'end_time': current_group[-1].get('createTime', 0)
})
return concatenated_groups
def _create_concatenated_content(self, message_group: Dict, contact_name: str) -> str:
"""
Create concatenated content from a group of messages.
Args:
message_group: Dictionary containing messages and metadata
contact_name: Name of the contact
Returns:
Formatted concatenated content
"""
messages = message_group['messages']
start_time = message_group['start_time']
end_time = message_group['end_time']
# Format timestamps
if start_time:
try:
start_timestamp = datetime.fromtimestamp(start_time)
start_time_str = start_timestamp.strftime('%Y-%m-%d %H:%M:%S')
except:
start_time_str = str(start_time)
else:
start_time_str = "Unknown"
if end_time:
try:
end_timestamp = datetime.fromtimestamp(end_time)
end_time_str = end_timestamp.strftime('%Y-%m-%d %H:%M:%S')
except:
end_time_str = str(end_time)
else:
end_time_str = "Unknown"
# Build concatenated message content
message_parts = []
for message in messages:
content = message.get('content', '')
message_text = message.get('message', '')
create_time = message.get('createTime', 0)
is_sent_from_self = message.get('isSentFromSelf', False)
# Extract readable text
readable_text = self._extract_readable_text(content)
if not readable_text:
readable_text = message_text
# Format individual message
if create_time:
try:
timestamp = datetime.fromtimestamp(create_time)
# change to YYYY-MM-DD HH:MM:SS
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
except:
time_str = str(create_time)
else:
time_str = "Unknown"
sender = "[Me]" if is_sent_from_self else "[Contact]"
message_parts.append(f"({time_str}) {sender}: {readable_text}")
concatenated_text = "\n".join(message_parts)
# Create final document content
doc_content = f"""
Contact: {contact_name}
Time Range: {start_time_str} - {end_time_str}
Messages ({len(messages)} messages, {message_group['total_length']} chars):
{concatenated_text}
"""
# TODO @yichuan give better format and rich info here!
doc_content = f"""
{concatenated_text}
"""
return doc_content, contact_name
def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
"""
Load WeChat chat history data from exported JSON files.
Args:
input_dir: Directory containing exported WeChat JSON files
**load_kwargs:
max_count (int): Maximum amount of chat entries to read.
wechat_export_dir (str): Custom path to WeChat export directory.
include_non_text (bool): Whether to include non-text messages (images, emojis, etc.)
concatenate_messages (bool): Whether to concatenate messages based on length rules.
max_length (int): Maximum length for concatenated message groups (default: 1000).
time_window_minutes (int): Time window in minutes to group messages together (default: 30).
overlap_messages (int): Number of messages to overlap between consecutive groups (default: 2).
"""
docs: List[Document] = []
max_count = load_kwargs.get('max_count', 1000)
wechat_export_dir = load_kwargs.get('wechat_export_dir', None)
include_non_text = load_kwargs.get('include_non_text', False)
concatenate_messages = load_kwargs.get('concatenate_messages', False)
max_length = load_kwargs.get('max_length', 1000)
time_window_minutes = load_kwargs.get('time_window_minutes', 30)
# Default WeChat export path
if wechat_export_dir is None:
wechat_export_dir = "./wechat_export_test"
if not os.path.exists(wechat_export_dir):
print(f"WeChat export directory not found at: {wechat_export_dir}")
return docs
try:
# Find all JSON files in the export directory
json_files = list(Path(wechat_export_dir).glob("*.json"))
print(f"Found {len(json_files)} WeChat chat history files")
count = 0
for json_file in json_files:
if count >= max_count and max_count > 0:
break
try:
with open(json_file, 'r', encoding='utf-8') as f:
chat_data = json.load(f)
# Extract contact name from filename
contact_name = json_file.stem
if concatenate_messages:
# Filter messages to only include readable text messages
readable_messages = []
for message in chat_data:
try:
content = message.get('content', '')
if not include_non_text and not self._is_text_message(content):
continue
readable_text = self._extract_readable_text(content)
if not readable_text and not include_non_text:
continue
readable_messages.append(message)
except Exception as e:
print(f"Error processing message in {json_file}: {e}")
continue
# Concatenate messages based on rules
message_groups = self._concatenate_messages(
readable_messages,
max_length=-1,
time_window_minutes=-1,
overlap_messages=0 # Keep 2 messages overlap between groups
)
# Create documents from concatenated groups
for message_group in message_groups:
if count >= max_count and max_count > 0:
break
doc_content, contact_name = self._create_concatenated_content(message_group, contact_name)
doc = Document(text=doc_content, metadata={"contact_name": contact_name})
docs.append(doc)
count += 1
print(f"Created {len(message_groups)} concatenated message groups for {contact_name}")
else:
# Original single-message processing
for message in chat_data:
if count >= max_count and max_count > 0:
break
# Extract message information
from_user = message.get('fromUser', '')
to_user = message.get('toUser', '')
content = message.get('content', '')
message_text = message.get('message', '')
create_time = message.get('createTime', 0)
is_sent_from_self = message.get('isSentFromSelf', False)
# Handle content that might be dict or string
try:
# Check if this is a readable text message
if not include_non_text and not self._is_text_message(content):
continue
# Extract readable text
readable_text = self._extract_readable_text(content)
if not readable_text and not include_non_text:
continue
except Exception as e:
# Skip messages that cause processing errors
print(f"Error processing message in {json_file}: {e}")
continue
# Convert timestamp to readable format
if create_time:
try:
timestamp = datetime.fromtimestamp(create_time)
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
except:
time_str = str(create_time)
else:
time_str = "Unknown"
# Create document content with metadata header and contact info
doc_content = f"""
Contact: {contact_name}
Is sent from self: {is_sent_from_self}
Time: {time_str}
Message: {readable_text if readable_text else message_text}
"""
# Create document with embedded metadata
doc = Document(text=doc_content, metadata={})
docs.append(doc)
count += 1
except Exception as e:
print(f"Error reading {json_file}: {e}")
continue
print(f"Loaded {len(docs)} WeChat chat documents")
except Exception as e:
print(f"Error reading WeChat history: {e}")
return docs
return docs
@staticmethod
def find_wechat_export_dirs() -> List[Path]:
"""
Find all WeChat export directories.
Returns:
List of Path objects pointing to WeChat export directories
"""
export_dirs = []
# Look for common export directory names
possible_dirs = [
Path("./wechat_export_test"),
Path("./wechat_export"),
Path("./wechat_chat_history"),
Path("./chat_export")
]
for export_dir in possible_dirs:
if export_dir.exists() and export_dir.is_dir():
json_files = list(export_dir.glob("*.json"))
if json_files:
export_dirs.append(export_dir)
print(f"Found WeChat export directory: {export_dir} with {len(json_files)} files")
print(f"Found {len(export_dirs)} WeChat export directories")
return export_dirs
@staticmethod
def export_chat_to_file(output_file: str = "wechat_chat_export.txt", max_count: int = 1000, export_dir: str = None, include_non_text: bool = False):
"""
Export WeChat chat history to a text file.
Args:
output_file: Path to the output file
max_count: Maximum number of entries to export
export_dir: Directory containing WeChat JSON files
include_non_text: Whether to include non-text messages
"""
if export_dir is None:
export_dir = "./wechat_export_test"
if not os.path.exists(export_dir):
print(f"WeChat export directory not found at: {export_dir}")
return
try:
json_files = list(Path(export_dir).glob("*.json"))
with open(output_file, 'w', encoding='utf-8') as f:
count = 0
for json_file in json_files:
if count >= max_count and max_count > 0:
break
try:
with open(json_file, 'r', encoding='utf-8') as json_f:
chat_data = json.load(json_f)
contact_name = json_file.stem
f.write(f"\n=== Chat with {contact_name} ===\n")
for message in chat_data:
if count >= max_count and max_count > 0:
break
from_user = message.get('fromUser', '')
content = message.get('content', '')
message_text = message.get('message', '')
create_time = message.get('createTime', 0)
# Skip non-text messages unless requested
if not include_non_text:
reader = WeChatHistoryReader()
if not reader._is_text_message(content):
continue
readable_text = reader._extract_readable_text(content)
if not readable_text:
continue
message_text = readable_text
if create_time:
try:
timestamp = datetime.fromtimestamp(create_time)
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
except:
time_str = str(create_time)
else:
time_str = "Unknown"
f.write(f"[{time_str}] {from_user}: {message_text}\n")
count += 1
except Exception as e:
print(f"Error processing {json_file}: {e}")
continue
print(f"Exported {count} chat entries to {output_file}")
except Exception as e:
print(f"Error exporting WeChat chat history: {e}")
def export_wechat_chat_history(self, export_dir: str = "./wechat_export_direct") -> Optional[Path]:
"""
Export WeChat chat history using wechat-exporter tool.
Args:
export_dir: Directory to save exported chat history
Returns:
Path to export directory if successful, None otherwise
"""
try:
import subprocess
import sys
# Create export directory
export_path = Path(export_dir)
export_path.mkdir(exist_ok=True)
print(f"Exporting WeChat chat history to {export_path}...")
# Check if wechat-exporter directory exists
if not self.wechat_exporter_dir.exists():
print(f"wechat-exporter directory not found at: {self.wechat_exporter_dir}")
return None
# Install requirements if needed
requirements_file = self.wechat_exporter_dir / "requirements.txt"
if requirements_file.exists():
print("Installing wechat-exporter requirements...")
subprocess.run([
"uv", "pip", "install", "-r", str(requirements_file)
], check=True)
# Run the export command
print("Running wechat-exporter...")
result = subprocess.run([
sys.executable, str(self.wechat_exporter_dir / "main.py"),
"export-all", str(export_path)
], capture_output=True, text=True, check=True)
print("Export command output:")
print(result.stdout)
if result.stderr:
print("Export errors:")
print(result.stderr)
# Check if export was successful
if export_path.exists() and any(export_path.glob("*.json")):
json_files = list(export_path.glob("*.json"))
print(f"Successfully exported {len(json_files)} chat history files to {export_path}")
return export_path
else:
print("Export completed but no JSON files found")
return None
except subprocess.CalledProcessError as e:
print(f"Export command failed: {e}")
print(f"Command output: {e.stdout}")
print(f"Command errors: {e.stderr}")
return None
except Exception as e:
print(f"Export failed: {e}")
print("Please ensure WeChat is running and WeChatTweak is installed.")
return None
def find_or_export_wechat_data(self, export_dir: str = "./wechat_export_direct") -> List[Path]:
"""
Find existing WeChat exports or create new ones.
Args:
export_dir: Directory to save exported chat history if needed
Returns:
List of Path objects pointing to WeChat export directories
"""
export_dirs = []
# Look for existing exports in common locations
possible_export_dirs = [
Path("./wechat_database_export"),
Path("./wechat_export_test"),
Path("./wechat_export"),
Path("./wechat_export_direct"),
Path("./wechat_chat_history"),
Path("./chat_export")
]
for export_dir_path in possible_export_dirs:
if export_dir_path.exists() and any(export_dir_path.glob("*.json")):
export_dirs.append(export_dir_path)
print(f"Found existing export: {export_dir_path}")
# If no existing exports, try to export automatically
if not export_dirs:
print("No existing WeChat exports found. Starting direct export...")
# Try to export using wechat-exporter
exported_path = self.export_wechat_chat_history(export_dir)
if exported_path:
export_dirs = [exported_path]
else:
print("Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.")
return export_dirs

View File

@@ -1,37 +1,321 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Quick Start in 30s"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from leann.api import LeannBuilder, LeannSearcher, LeannChat\n",
"# install this if you areusing colab\n",
"! pip install leann"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Build the index"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO: Registering backend 'hnsw'\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/yichuan/Desktop/code/LEANN/leann/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: facebook/contriever\n",
"WARNING:sentence_transformers.SentenceTransformer:No sentence-transformers model found with name facebook/contriever. Creating a new one with mean pooling.\n",
"Writing passages: 100%|██████████| 5/5 [00:00<00:00, 27887.66chunk/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 13.51it/s]\n",
"WARNING:leann_backend_hnsw.hnsw_backend:Converting data to float32, shape: (5, 768)\n",
"INFO:leann_backend_hnsw.hnsw_backend:INFO: Converting HNSW index to CSR-pruned format...\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"M: 64 for level: 0\n",
"Starting conversion: knowledge.index -> knowledge.csr.tmp\n",
"[0.00s] Reading Index HNSW header...\n",
"[0.00s] Header read: d=768, ntotal=5\n",
"[0.00s] Reading HNSW struct vectors...\n",
" Reading vector (dtype=<class 'numpy.float64'>, fmt='d')... Count=6, Bytes=48\n",
"[0.00s] Read assign_probas (6)\n",
" Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=7, Bytes=28\n",
"[0.11s] Read cum_nneighbor_per_level (7)\n",
" Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=5, Bytes=20\n",
"[0.21s] Read levels (5)\n",
"[0.30s] Probing for compact storage flag...\n",
"[0.30s] Found compact flag: False\n",
"[0.30s] Compact flag is False, reading original format...\n",
"[0.30s] Probing for potential extra byte before non-compact offsets...\n",
"[0.30s] Found and consumed an unexpected 0x00 byte.\n",
" Reading vector (dtype=<class 'numpy.uint64'>, fmt='Q')... Count=6, Bytes=48\n",
"[0.30s] Read offsets (6)\n",
"[0.40s] Attempting to read neighbors vector...\n",
" Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=320, Bytes=1280\n",
"[0.40s] Read neighbors (320)\n",
"[0.50s] Read scalar params (ep=4, max_lvl=0)\n",
"[0.50s] Checking for storage data...\n",
"[0.50s] Found storage fourcc: 49467849.\n",
"[0.50s] Converting to CSR format...\n",
"[0.50s] Conversion loop finished. \n",
"[0.50s] Running validation checks...\n",
" Checking total valid neighbor count...\n",
" OK: Total valid neighbors = 20\n",
" Checking final pointer indices...\n",
" OK: Final pointers match data size.\n",
"[0.50s] Deleting original neighbors and offsets arrays...\n",
" CSR Stats: |data|=20, |level_ptr|=10\n",
"[0.59s] Writing CSR HNSW graph data in FAISS-compatible order...\n",
" Pruning embeddings: Writing NULL storage marker.\n",
"[0.69s] Conversion complete.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:leann_backend_hnsw.hnsw_backend:✅ CSR conversion successful.\n",
"INFO:leann_backend_hnsw.hnsw_backend:INFO: Replaced original index with CSR-pruned version at 'knowledge.index'\n"
]
}
],
"source": [
"from leann.api import LeannBuilder\n",
"\n",
"# 1. Build the index (no embeddings stored!)\n",
"builder = LeannBuilder(backend_name=\"hnsw\")\n",
"builder.add_text(\"C# is a powerful programming language\")\n",
"builder.add_text(\"Python is a powerful programming language and it is very popular\")\n",
"builder.add_text(\"C# is a powerful programming language and it is good at game development\")\n",
"builder.add_text(\"Python is a powerful programming language and it is good at machine learning tasks\")\n",
"builder.add_text(\"Machine learning transforms industries\")\n",
"builder.add_text(\"Neural networks process complex data\")\n",
"builder.add_text(\"Leann is a great storage saving engine for RAG on your MacBook\")\n",
"builder.build_index(\"knowledge.leann\")\n",
"builder.build_index(\"knowledge.leann\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Search with real-time embeddings"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:leann.api:🔍 LeannSearcher.search() called:\n",
"INFO:leann.api: Query: 'programming languages'\n",
"INFO:leann.api: Top_k: 2\n",
"INFO:leann.api: Additional kwargs: {}\n",
"INFO:leann.embedding_server_manager:Port 5557 has incompatible server, trying next port...\n",
"INFO:leann.embedding_server_manager:Port 5558 has incompatible server, trying next port...\n",
"INFO:leann.embedding_server_manager:Port 5559 has incompatible server, trying next port...\n",
"INFO:leann.embedding_server_manager:Using port 5560 instead of 5557\n",
"INFO:leann.embedding_server_manager:Starting embedding server on port 5560...\n",
"INFO:leann.embedding_server_manager:Command: /Users/yichuan/Desktop/code/LEANN/leann/.venv/bin/python -m leann_backend_hnsw.hnsw_embedding_server --zmq-port 5560 --model-name facebook/contriever --passages-file knowledge.leann.meta.json\n",
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
"INFO:leann.embedding_server_manager:Server process started with PID: 4574\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[read_HNSW - CSR NL v4] Reading metadata & CSR indices (manual offset)...\n",
"[read_HNSW NL v4] Read levels vector, size: 5\n",
"[read_HNSW NL v4] Reading Compact Storage format indices...\n",
"[read_HNSW NL v4] Read compact_level_ptr, size: 10\n",
"[read_HNSW NL v4] Read compact_node_offsets, size: 6\n",
"[read_HNSW NL v4] Read entry_point: 4, max_level: 0\n",
"[read_HNSW NL v4] Read storage fourcc: 0x6c6c756e\n",
"[read_HNSW NL v4 FIX] Detected FileIOReader. Neighbors size field offset: 326\n",
"[read_HNSW NL v4] Reading neighbors data into memory.\n",
"[read_HNSW NL v4] Read neighbors data, size: 20\n",
"[read_HNSW NL v4] Finished reading metadata and CSR indices.\n",
"INFO: Skipping external storage loading, since is_recompute is true.\n",
"INFO: Registering backend 'hnsw'\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:leann.embedding_server_manager:Embedding server is ready!\n",
"INFO:leann.api: Launching server time: 1.078078269958496 seconds\n",
"INFO:leann.embedding_server_manager:Existing server process (PID 4574) is compatible\n",
"INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: facebook/contriever\n",
"WARNING:sentence_transformers.SentenceTransformer:No sentence-transformers model found with name facebook/contriever. Creating a new one with mean pooling.\n",
"INFO:leann.api: Generated embedding shape: (1, 768)\n",
"INFO:leann.api: Embedding time: 2.9307072162628174 seconds\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"ZmqDistanceComputer initialized: d=768, metric=0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:leann.api: Search time: 0.27327895164489746 seconds\n",
"INFO:leann.api: Backend returned: labels=2 results\n",
"INFO:leann.api: Processing 2 passage IDs:\n",
"INFO:leann.api: 1. passage_id='0' -> SUCCESS: C# is a powerful programming language and it is good at game development...\n",
"INFO:leann.api: 2. passage_id='1' -> SUCCESS: Python is a powerful programming language and it is good at machine learning tasks...\n",
"INFO:leann.api: Final enriched results: 2 passages\n"
]
},
{
"data": {
"text/plain": [
"[SearchResult(id='0', score=np.float32(0.9874103), text='C# is a powerful programming language and it is good at game development', metadata={}),\n",
" SearchResult(id='1', score=np.float32(0.8922168), text='Python is a powerful programming language and it is good at machine learning tasks', metadata={})]"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from leann.api import LeannSearcher\n",
"\n",
"# 2. Search with real-time embeddings\n",
"searcher = LeannSearcher(\"knowledge.leann\")\n",
"results = searcher.search(\"programming languages\", top_k=2)\n",
"results"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Chat with LEANN using retrieved results"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:leann.chat:Attempting to create LLM of type='hf' with model='Qwen/Qwen3-0.6B'\n",
"INFO:leann.chat:Initializing HFChat with model='Qwen/Qwen3-0.6B'\n",
"INFO:leann.chat:MPS is available. Using Apple Silicon GPU.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[read_HNSW - CSR NL v4] Reading metadata & CSR indices (manual offset)...\n",
"[read_HNSW NL v4] Read levels vector, size: 5\n",
"[read_HNSW NL v4] Reading Compact Storage format indices...\n",
"[read_HNSW NL v4] Read compact_level_ptr, size: 10\n",
"[read_HNSW NL v4] Read compact_node_offsets, size: 6\n",
"[read_HNSW NL v4] Read entry_point: 4, max_level: 0\n",
"[read_HNSW NL v4] Read storage fourcc: 0x6c6c756e\n",
"[read_HNSW NL v4 FIX] Detected FileIOReader. Neighbors size field offset: 326\n",
"[read_HNSW NL v4] Reading neighbors data into memory.\n",
"[read_HNSW NL v4] Read neighbors data, size: 20\n",
"[read_HNSW NL v4] Finished reading metadata and CSR indices.\n",
"INFO: Skipping external storage loading, since is_recompute is true.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:leann.api:🔍 LeannSearcher.search() called:\n",
"INFO:leann.api: Query: 'Compare the two retrieved programming languages and tell me their advantages.'\n",
"INFO:leann.api: Top_k: 2\n",
"INFO:leann.api: Additional kwargs: {}\n",
"INFO:leann.embedding_server_manager:Port 5557 has incompatible server, trying next port...\n",
"INFO:leann.embedding_server_manager:Port 5558 has incompatible server, trying next port...\n",
"INFO:leann.embedding_server_manager:Port 5559 has incompatible server, trying next port...\n",
"INFO:leann.embedding_server_manager:Found compatible server on port 5560\n",
"INFO:leann.embedding_server_manager:Using existing compatible server on port 5560\n",
"INFO:leann.api: Launching server time: 0.04932403564453125 seconds\n",
"INFO:leann.embedding_server_manager:Found compatible server on port 5560\n",
"INFO:leann.embedding_server_manager:Using existing compatible server on port 5560\n",
"INFO:leann.api: Generated embedding shape: (1, 768)\n",
"INFO:leann.api: Embedding time: 0.06902289390563965 seconds\n",
"INFO:leann.api: Search time: 0.026793241500854492 seconds\n",
"INFO:leann.api: Backend returned: labels=2 results\n",
"INFO:leann.api: Processing 2 passage IDs:\n",
"INFO:leann.api: 1. passage_id='0' -> SUCCESS: C# is a powerful programming language and it is good at game development...\n",
"INFO:leann.api: 2. passage_id='1' -> SUCCESS: Python is a powerful programming language and it is good at machine learning tasks...\n",
"INFO:leann.api: Final enriched results: 2 passages\n",
"INFO:leann.chat:Generating with HuggingFace model, config: {'max_new_tokens': 128, 'temperature': 0.7, 'top_p': 0.9, 'do_sample': True, 'pad_token_id': 151645, 'eos_token_id': 151645}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"ZmqDistanceComputer initialized: d=768, metric=0\n"
]
},
{
"data": {
"text/plain": [
"\"<think>\\n\\n</think>\\n\\nBased on the context provided, here's a comparison of the two retrieved programming languages:\\n\\n**C#** is known for being a powerful programming language and is well-suited for game development. It is often used in game development and is popular among developers working on Windows applications.\\n\\n**Python**, on the other hand, is also a powerful language and is well-suited for machine learning tasks. It is widely used for data analysis, scientific computing, and other applications that require handling large datasets or performing complex calculations.\\n\\n**Advantages**:\\n- C#: Strong for game development and cross-platform compatibility.\\n- Python: Strong for\""
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from leann.api import LeannChat\n",
"\n",
"# 3. Chat with LEANN using retrieved results\n",
"llm_config = {\n",
" \"type\": \"ollama\",\n",
" \"model\": \"llama3.2:1b\"\n",
" \"type\": \"hf\",\n",
" \"model\": \"Qwen/Qwen3-0.6B\",\n",
"}\n",
"\n",
"chat = LeannChat(index_path=\"knowledge.leann\", llm_config=llm_config)\n",
"response = chat.ask(\n",
" \"Compare the two retrieved programming languages and say which one is more popular today.\",\n",
" \"Compare the two retrieved programming languages and tell me their advantages.\",\n",
" top_k=2,\n",
")"
" llm_kwargs={\"max_tokens\": 128}\n",
")\n",
"response"
]
}
],

22
docs/RELEASE.md Normal file
View File

@@ -0,0 +1,22 @@
# Release Guide
## Setup (One-time)
Add `PYPI_API_TOKEN` to GitHub Secrets:
1. Get token: https://pypi.org/manage/account/token/
2. Add to secrets: Settings → Secrets → Actions → `PYPI_API_TOKEN`
## Release (One-click)
1. Go to: https://github.com/yichuan-w/LEANN/actions/workflows/release-manual.yml
2. Click "Run workflow"
3. Enter version: `0.1.2`
4. Click green "Run workflow" button
That's it! The workflow will automatically:
- ✅ Update version in all packages
- ✅ Build all packages
- ✅ Publish to PyPI
- ✅ Create GitHub tag and release
Check progress: https://github.com/yichuan-w/LEANN/actions

11
docs/contributing.md Normal file
View File

@@ -0,0 +1,11 @@
# 🤝 Contributing
We welcome contributions! Leann is built by the community, for the community.
## Ways to Contribute
- 🐛 **Bug Reports**: Found an issue? Let us know!
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
- 🔧 **Code Contributions**: PRs welcome for all skill levels
- 📖 **Documentation**: Help make Leann more accessible
- 🧪 **Benchmarks**: Share your performance results

10
docs/faq.md Normal file
View File

@@ -0,0 +1,10 @@
# FAQ
## 1. My building time seems long
You can speed up the process by using a lightweight embedding model. Add this to your arguments:
```bash
--embedding-model sentence-transformers/all-MiniLM-L6-v2
```
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)

22
docs/features.md Normal file
View File

@@ -0,0 +1,22 @@
# ✨ Detailed Features
## 🔥 Core Features
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage with dynamic computation using optimized ZMQ servers and highly optimized search paradigm (overlapping and batching) with highly optimized embedding engine
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
- **🏗️ Pluggable Backends** - DiskANN, HNSW/FAISS with unified API
## 🛠️ Technical Highlights
- **🔄 Recompute Mode** - Highest accuracy scenarios while eliminating vector storage overhead
- **⚡ Zero-copy Operations** - Minimize IPC overhead by transferring distances instead of embeddings
- **🚀 High-throughput Embedding Pipeline** - Optimized batched processing for maximum efficiency
- **🎯 Two-level Search** - Novel coarse-to-fine search overlap for accelerated query processing (optional)
- **💾 Memory-mapped Indices** - Fast startup with raw text mapping to reduce memory overhead
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](test/build_mlx_index.py))
## 🎨 Developer Experience
- **Simple Python API** - Get started in minutes
- **Extensible backend system** - Easy to add new algorithms
- **Comprehensive examples** - From basic usage to production deployment

21
docs/roadmap.md Normal file
View File

@@ -0,0 +1,21 @@
# 📈 Roadmap
## 🎯 Q2 2025
- [X] DiskANN backend with MIPS/L2/Cosine support
- [X] HNSW backend integration
- [X] Real-time embedding pipeline
- [X] Memory-efficient graph pruning
## 🚀 Q3 2025
- [ ] Advanced caching strategies
- [ ] Add contextual-retrieval https://www.anthropic.com/news/contextual-retrieval
- [ ] Add sleep-time-compute and summarize agent! to summarilze the file on computer!
- [ ] Add OpenAI recompute API
## 🌟 Q4 2025
- [ ] Integration with LangChain/LlamaIndex
- [ ] Visual similarity search
- [ ] Query rewrtiting, rerank and expansion

View File

@@ -135,6 +135,7 @@ def test_leann_hnsw():
nodes = node_parser.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
print(f"Total number of chunks: {len(all_texts)}")
tracker.checkpoint("After text chunking")

View File

Binary file not shown.

View File

File diff suppressed because it is too large Load Diff

View File

File diff suppressed because it is too large Load Diff

146
examples/document_search.py Normal file
View File

@@ -0,0 +1,146 @@
#!/usr/bin/env python3
"""
Document search demo with recompute mode
"""
import os
from pathlib import Path
import shutil
import time
# Import backend packages to trigger plugin registration
try:
import leann_backend_diskann
import leann_backend_hnsw
print("INFO: Backend packages imported successfully.")
except ImportError as e:
print(f"WARNING: Could not import backend packages. Error: {e}")
# Import upper-level API from leann-core
from leann.api import LeannBuilder, LeannSearcher, LeannChat
def load_sample_documents():
"""Create sample documents for demonstration"""
docs = [
{"title": "Intro to Python", "content": "Python is a high-level, interpreted language known for simplicity."},
{"title": "ML Basics", "content": "Machine learning builds systems that learn from data."},
{"title": "Data Structures", "content": "Data structures like arrays, lists, and graphs organize data."},
]
return docs
def main():
print("==========================================================")
print("=== Leann Document Search Demo (DiskANN + Recompute) ===")
print("==========================================================")
INDEX_DIR = Path("./test_indices")
INDEX_PATH = str(INDEX_DIR / "documents.diskann")
BACKEND_TO_TEST = "diskann"
if INDEX_DIR.exists():
print(f"--- Cleaning up old index directory: {INDEX_DIR} ---")
shutil.rmtree(INDEX_DIR)
# --- 1. Build index ---
print(f"\n[PHASE 1] Building index using '{BACKEND_TO_TEST}' backend...")
builder = LeannBuilder(
backend_name=BACKEND_TO_TEST,
graph_degree=32,
complexity=64
)
documents = load_sample_documents()
print(f"Loaded {len(documents)} sample documents.")
for doc in documents:
builder.add_text(doc["content"], metadata={"title": doc["title"]})
builder.build_index(INDEX_PATH)
print(f"\nIndex built!")
# --- 2. Basic search demo ---
print(f"\n[PHASE 2] Basic search using '{BACKEND_TO_TEST}' backend...")
searcher = LeannSearcher(index_path=INDEX_PATH)
query = "What is machine learning?"
print(f"\nQuery: '{query}'")
print("\n--- Basic search mode (PQ computation) ---")
start_time = time.time()
results = searcher.search(query, top_k=2)
basic_time = time.time() - start_time
print(f"⏱️ Basic search time: {basic_time:.3f} seconds")
print(">>> Basic search results <<<")
for i, res in enumerate(results, 1):
print(f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}")
# --- 3. Recompute search demo ---
print(f"\n[PHASE 3] Recompute search using embedding server...")
print("\n--- Recompute search mode (get real embeddings via network) ---")
# Configure recompute parameters
recompute_params = {
"recompute_beighbor_embeddings": True, # Enable network recomputation
"USE_DEFERRED_FETCH": False, # Don't use deferred fetch
"skip_search_reorder": True, # Skip search reordering
"dedup_node_dis": True, # Enable node distance deduplication
"prune_ratio": 0.1, # Pruning ratio 10%
"batch_recompute": False, # Don't use batch recomputation
"global_pruning": False, # Don't use global pruning
"zmq_port": 5555, # ZMQ port
"embedding_model": "sentence-transformers/all-mpnet-base-v2"
}
print("Recompute parameter configuration:")
for key, value in recompute_params.items():
print(f" {key}: {value}")
print(f"\n🔄 Executing Recompute search...")
try:
start_time = time.time()
recompute_results = searcher.search(query, top_k=2, **recompute_params)
recompute_time = time.time() - start_time
print(f"⏱️ Recompute search time: {recompute_time:.3f} seconds")
print(">>> Recompute search results <<<")
for i, res in enumerate(recompute_results, 1):
print(f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}")
# Compare results
print(f"\n--- Result comparison ---")
print(f"Basic search time: {basic_time:.3f} seconds")
print(f"Recompute time: {recompute_time:.3f} seconds")
print("\nBasic search vs Recompute results:")
for i in range(min(len(results), len(recompute_results))):
basic_score = results[i].score
recompute_score = recompute_results[i].score
score_diff = abs(basic_score - recompute_score)
print(f" Position {i+1}: PQ={basic_score:.4f}, Recompute={recompute_score:.4f}, Difference={score_diff:.4f}")
if recompute_time > basic_time:
print(f"✅ Recompute mode working correctly (more accurate but slower)")
else:
print(f" Recompute time is unusually fast, network recomputation may not be enabled")
except Exception as e:
print(f"❌ Recompute search failed: {e}")
print("This usually indicates an embedding server connection issue")
# --- 4. Chat demo ---
print(f"\n[PHASE 4] Starting chat session...")
chat = LeannChat(index_path=INDEX_PATH)
chat_response = chat.ask(query)
print(f"You: {query}")
print(f"Leann: {chat_response}")
print("\n==========================================================")
print("✅ Demo finished successfully!")
print("==========================================================")
if __name__ == "__main__":
main()

View File

@@ -37,7 +37,7 @@ def main():
import faiss
except ImportError:
print("Faiss is not installed.")
print("Please install it with `uv pip install faiss-cpu`")
print("Please install it with `uv pip install faiss-cpu` and you can then run this script again")
sys.exit(1)
from llama_index.core import (

View File

@@ -222,14 +222,15 @@ async def query_leann_index(index_path: str, query: str):
"max_tokens": 1000
}
)
print(f"Leann: {chat_response}")
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
async def main():
# Parse command line arguments
parser = argparse.ArgumentParser(description='LEANN Chrome History Reader - Create and query browser history index')
parser.add_argument('--chrome-profile', type=str, default=DEFAULT_CHROME_PROFILE,
help=f'Path to Chrome profile directory (default: {DEFAULT_CHROME_PROFILE}), usually you dont need to change this')
parser.add_argument('--index-dir', type=str, default="./all_google_new",
parser.add_argument('--index-dir', type=str, default="./google_history_index",
help='Directory to store the LEANN index (default: ./chrome_history_index_leann_test)')
parser.add_argument('--max-entries', type=int, default=1000,
help='Maximum number of history entries to process (default: 1000)')

View File

@@ -22,7 +22,7 @@ def get_mail_path():
return os.path.join(home_dir, "Library", "Mail")
# Default mail path for macOS
# DEFAULT_MAIL_PATH = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data"
DEFAULT_MAIL_PATH = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data"
def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_path: str = "mail_index.leann", max_count: int = -1, include_html: bool = False, embedding_model: str = "facebook/contriever"):
"""
@@ -77,7 +77,7 @@ def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_pa
print(f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories and starting to split them into chunks")
# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
# Convert Documents to text strings and chunk them
all_texts = []
@@ -158,7 +158,7 @@ def create_leann_index(mail_path: str, index_path: str = "mail_index.leann", max
print(f"Loaded {len(documents)} email documents")
# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
# Convert Documents to text strings and chunk them
all_texts = []
@@ -218,22 +218,22 @@ async def query_leann_index(index_path: str, query: str):
start_time = time.time()
chat_response = chat.ask(
query,
top_k=10,
top_k=20,
recompute_beighbor_embeddings=True,
complexity=12,
complexity=32,
beam_width=1,
)
end_time = time.time()
print(f"Time taken: {end_time - start_time} seconds")
print(f"Leann: {chat_response}")
# print(f"Time taken: {end_time - start_time} seconds")
# highlight the answer
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
async def main():
# Parse command line arguments
parser = argparse.ArgumentParser(description='LEANN Mail Reader - Create and query email index')
# Remove --mail-path argument and auto-detect all Messages directories
# Remove DEFAULT_MAIL_PATH
parser.add_argument('--index-dir', type=str, default="./mail_index_leann_debug",
parser.add_argument('--index-dir', type=str, default="./mail_index",
help='Directory to store the LEANN index (default: ./mail_index_leann_raw_text_all_dicts)')
parser.add_argument('--max-emails', type=int, default=1000,
help='Maximum number of emails to process (-1 means all)')
@@ -253,6 +253,9 @@ async def main():
mail_path = get_mail_path()
print(f"Searching for email data in: {mail_path}")
messages_dirs = find_all_messages_directories(mail_path)
# messages_dirs = find_all_messages_directories(DEFAULT_MAIL_PATH)
# messages_dirs = [DEFAULT_MAIL_PATH]
# messages_dirs = messages_dirs[:1]
print('len(messages_dirs): ', len(messages_dirs))

View File

@@ -0,0 +1,108 @@
import os
import sys
import argparse
from pathlib import Path
from typing import List, Any
# Add the project root to Python path so we can import from examples
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from llama_index.core import VectorStoreIndex, StorageContext
from llama_index.core.node_parser import SentenceSplitter
# --- EMBEDDING MODEL ---
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
import torch
# --- END EMBEDDING MODEL ---
# Import EmlxReader from the new module
from examples.email_data.LEANN_email_reader import EmlxReader
def create_and_save_index(mail_path: str, save_dir: str = "mail_index_embedded", max_count: int = 1000, include_html: bool = False):
print("Creating index from mail data with embedded metadata...")
documents = EmlxReader(include_html=include_html).load_data(mail_path, max_count=max_count)
if not documents:
print("No documents loaded. Exiting.")
return None
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
# Use facebook/contriever as the embedder
embed_model = HuggingFaceEmbedding(model_name="facebook/contriever")
# set on device
import torch
if torch.cuda.is_available():
embed_model._model.to("cuda")
# set mps
elif torch.backends.mps.is_available():
embed_model._model.to("mps")
else:
embed_model._model.to("cpu")
index = VectorStoreIndex.from_documents(
documents,
transformations=[text_splitter],
embed_model=embed_model
)
os.makedirs(save_dir, exist_ok=True)
index.storage_context.persist(persist_dir=save_dir)
print(f"Index saved to {save_dir}")
return index
def load_index(save_dir: str = "mail_index_embedded"):
try:
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
index = VectorStoreIndex.from_vector_store(
storage_context.vector_store,
storage_context=storage_context
)
print(f"Index loaded from {save_dir}")
return index
except Exception as e:
print(f"Error loading index: {e}")
return None
def query_index(index, query: str):
if index is None:
print("No index available for querying.")
return
query_engine = index.as_query_engine()
response = query_engine.query(query)
print(f"Query: {query}")
print(f"Response: {response}")
def main():
# Parse command line arguments
parser = argparse.ArgumentParser(description='LlamaIndex Mail Reader - Create and query email index')
parser.add_argument('--mail-path', type=str,
default="/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages",
help='Path to mail data directory')
parser.add_argument('--save-dir', type=str, default="mail_index_embedded",
help='Directory to store the index (default: mail_index_embedded)')
parser.add_argument('--max-emails', type=int, default=10000,
help='Maximum number of emails to process')
parser.add_argument('--include-html', action='store_true', default=False,
help='Include HTML content in email processing (default: False)')
args = parser.parse_args()
mail_path = args.mail_path
save_dir = args.save_dir
if os.path.exists(save_dir) and os.path.exists(os.path.join(save_dir, "vector_store.json")):
print("Loading existing index...")
index = load_index(save_dir)
else:
print("Creating new index...")
index = create_and_save_index(mail_path, save_dir, max_count=args.max_emails, include_html=args.include_html)
if index:
queries = [
"Hows Berkeley Graduate Student Instructor",
"how's the icloud related advertisement saying",
"Whats the number of class recommend to take per semester for incoming EECS students"
]
for query in queries:
print("\n" + "="*50)
query_index(index, query)
if __name__ == "__main__":
main()

View File

@@ -63,16 +63,14 @@ async def main(args):
llm_config = {"type": "openai", "model": "gpt-4o"}
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
# query = (
# "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
# )
query = args.query
print(f"You: {query}")
chat_response = chat.ask(query, top_k=20, recompute_embeddings=True, complexity=32)
print(f"Leann: {chat_response}")
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
if __name__ == "__main__":
@@ -110,6 +108,12 @@ if __name__ == "__main__":
default="examples/data",
help="Directory containing documents to index (PDF, TXT, MD files).",
)
parser.add_argument(
"--query",
type=str,
default="Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?",
help="The query to ask the Leann chat system.",
)
args = parser.parse_args()
asyncio.run(main(args))

View File

@@ -0,0 +1,319 @@
#!/usr/bin/env python3
"""
Multi-Vector Aggregator for Fat Embeddings
==========================================
This module implements aggregation strategies for multi-vector embeddings,
similar to ColPali's approach where multiple patch vectors represent a single document.
Key features:
- MaxSim aggregation (take maximum similarity across patches)
- Voting-based aggregation (count patch matches)
- Weighted aggregation (attention-score weighted)
- Spatial clustering of matching patches
- Document-level result consolidation
"""
import numpy as np
from typing import List, Dict, Any, Tuple, Optional
from dataclasses import dataclass
from collections import defaultdict
import json
@dataclass
class PatchResult:
"""Represents a single patch search result."""
patch_id: int
image_name: str
image_path: str
coordinates: Tuple[int, int, int, int] # (x1, y1, x2, y2)
score: float
attention_score: float
scale: float
metadata: Dict[str, Any]
@dataclass
class AggregatedResult:
"""Represents an aggregated document-level result."""
image_name: str
image_path: str
doc_score: float
patch_count: int
best_patch: PatchResult
all_patches: List[PatchResult]
aggregation_method: str
spatial_clusters: Optional[List[List[PatchResult]]] = None
class MultiVectorAggregator:
"""
Aggregates multiple patch-level results into document-level results.
"""
def __init__(self,
aggregation_method: str = "maxsim",
spatial_clustering: bool = True,
cluster_distance_threshold: float = 100.0):
"""
Initialize the aggregator.
Args:
aggregation_method: "maxsim", "voting", "weighted", or "mean"
spatial_clustering: Whether to cluster spatially close patches
cluster_distance_threshold: Distance threshold for spatial clustering
"""
self.aggregation_method = aggregation_method
self.spatial_clustering = spatial_clustering
self.cluster_distance_threshold = cluster_distance_threshold
def aggregate_results(self,
search_results: List[Dict[str, Any]],
top_k: int = 10) -> List[AggregatedResult]:
"""
Aggregate patch-level search results into document-level results.
Args:
search_results: List of search results from LeannSearcher
top_k: Number of top documents to return
Returns:
List of aggregated document results
"""
# Group results by image
image_groups = defaultdict(list)
for result in search_results:
metadata = result.metadata
if "image_name" in metadata and "patch_id" in metadata:
patch_result = PatchResult(
patch_id=metadata["patch_id"],
image_name=metadata["image_name"],
image_path=metadata["image_path"],
coordinates=tuple(metadata["coordinates"]),
score=result.score,
attention_score=metadata.get("attention_score", 0.0),
scale=metadata.get("scale", 1.0),
metadata=metadata
)
image_groups[metadata["image_name"]].append(patch_result)
# Aggregate each image group
aggregated_results = []
for image_name, patches in image_groups.items():
if len(patches) == 0:
continue
agg_result = self._aggregate_image_patches(image_name, patches)
aggregated_results.append(agg_result)
# Sort by aggregated score and return top-k
aggregated_results.sort(key=lambda x: x.doc_score, reverse=True)
return aggregated_results[:top_k]
def _aggregate_image_patches(self, image_name: str, patches: List[PatchResult]) -> AggregatedResult:
"""Aggregate patches for a single image."""
if self.aggregation_method == "maxsim":
doc_score = max(patch.score for patch in patches)
best_patch = max(patches, key=lambda p: p.score)
elif self.aggregation_method == "voting":
# Count patches above threshold
threshold = np.percentile([p.score for p in patches], 75)
doc_score = sum(1 for patch in patches if patch.score >= threshold)
best_patch = max(patches, key=lambda p: p.score)
elif self.aggregation_method == "weighted":
# Weight by attention scores
total_weighted_score = sum(p.score * p.attention_score for p in patches)
total_weights = sum(p.attention_score for p in patches)
doc_score = total_weighted_score / max(total_weights, 1e-8)
best_patch = max(patches, key=lambda p: p.score * p.attention_score)
elif self.aggregation_method == "mean":
doc_score = np.mean([patch.score for patch in patches])
best_patch = max(patches, key=lambda p: p.score)
else:
raise ValueError(f"Unknown aggregation method: {self.aggregation_method}")
# Spatial clustering if enabled
spatial_clusters = None
if self.spatial_clustering:
spatial_clusters = self._cluster_patches_spatially(patches)
return AggregatedResult(
image_name=image_name,
image_path=patches[0].image_path,
doc_score=float(doc_score),
patch_count=len(patches),
best_patch=best_patch,
all_patches=sorted(patches, key=lambda p: p.score, reverse=True),
aggregation_method=self.aggregation_method,
spatial_clusters=spatial_clusters
)
def _cluster_patches_spatially(self, patches: List[PatchResult]) -> List[List[PatchResult]]:
"""Cluster patches that are spatially close to each other."""
if len(patches) <= 1:
return [patches]
clusters = []
remaining_patches = patches.copy()
while remaining_patches:
# Start new cluster with highest scoring remaining patch
seed_patch = max(remaining_patches, key=lambda p: p.score)
current_cluster = [seed_patch]
remaining_patches.remove(seed_patch)
# Add nearby patches to cluster
added_to_cluster = True
while added_to_cluster:
added_to_cluster = False
for patch in remaining_patches.copy():
if self._is_patch_nearby(patch, current_cluster):
current_cluster.append(patch)
remaining_patches.remove(patch)
added_to_cluster = True
clusters.append(current_cluster)
return sorted(clusters, key=lambda cluster: max(p.score for p in cluster), reverse=True)
def _is_patch_nearby(self, patch: PatchResult, cluster: List[PatchResult]) -> bool:
"""Check if a patch is spatially close to any patch in the cluster."""
patch_center = self._get_patch_center(patch.coordinates)
for cluster_patch in cluster:
cluster_center = self._get_patch_center(cluster_patch.coordinates)
distance = np.sqrt((patch_center[0] - cluster_center[0])**2 +
(patch_center[1] - cluster_center[1])**2)
if distance <= self.cluster_distance_threshold:
return True
return False
def _get_patch_center(self, coordinates: Tuple[int, int, int, int]) -> Tuple[float, float]:
"""Get center point of a patch."""
x1, y1, x2, y2 = coordinates
return ((x1 + x2) / 2, (y1 + y2) / 2)
def print_aggregated_results(self, results: List[AggregatedResult], max_patches_per_doc: int = 3):
"""Pretty print aggregated results."""
print(f"\n🔍 Aggregated Results (method: {self.aggregation_method})")
print("=" * 80)
for i, result in enumerate(results):
print(f"\n{i+1}. {result.image_name}")
print(f" Doc Score: {result.doc_score:.4f} | Patches: {result.patch_count}")
print(f" Path: {result.image_path}")
# Show best patch
best = result.best_patch
print(f" 🌟 Best Patch: #{best.patch_id} at {best.coordinates} (score: {best.score:.4f})")
# Show top patches
print(f" 📍 Top Patches:")
for j, patch in enumerate(result.all_patches[:max_patches_per_doc]):
print(f" {j+1}. Patch #{patch.patch_id}: {patch.score:.4f} at {patch.coordinates}")
# Show spatial clusters if available
if result.spatial_clusters and len(result.spatial_clusters) > 1:
print(f" 🗂️ Spatial Clusters: {len(result.spatial_clusters)}")
for j, cluster in enumerate(result.spatial_clusters[:2]): # Show top 2 clusters
cluster_score = max(p.score for p in cluster)
print(f" Cluster {j+1}: {len(cluster)} patches (best: {cluster_score:.4f})")
def demo_aggregation():
"""Demonstrate the multi-vector aggregation functionality."""
print("=== Multi-Vector Aggregation Demo ===")
# Simulate some patch-level search results
# In real usage, these would come from LeannSearcher.search()
class MockResult:
def __init__(self, score, metadata):
self.score = score
self.metadata = metadata
# Simulate results for 2 images with multiple patches each
mock_results = [
# Image 1: cats_and_kitchen.jpg - 4 patches
MockResult(0.85, {
"image_name": "cats_and_kitchen.jpg",
"image_path": "/path/to/cats_and_kitchen.jpg",
"patch_id": 3,
"coordinates": [100, 50, 224, 174], # Kitchen area
"attention_score": 0.92,
"scale": 1.0
}),
MockResult(0.78, {
"image_name": "cats_and_kitchen.jpg",
"image_path": "/path/to/cats_and_kitchen.jpg",
"patch_id": 7,
"coordinates": [200, 300, 324, 424], # Cat area
"attention_score": 0.88,
"scale": 1.0
}),
MockResult(0.72, {
"image_name": "cats_and_kitchen.jpg",
"image_path": "/path/to/cats_and_kitchen.jpg",
"patch_id": 12,
"coordinates": [150, 100, 274, 224], # Appliances
"attention_score": 0.75,
"scale": 1.0
}),
MockResult(0.65, {
"image_name": "cats_and_kitchen.jpg",
"image_path": "/path/to/cats_and_kitchen.jpg",
"patch_id": 15,
"coordinates": [50, 250, 174, 374], # Furniture
"attention_score": 0.70,
"scale": 1.0
}),
# Image 2: city_street.jpg - 3 patches
MockResult(0.68, {
"image_name": "city_street.jpg",
"image_path": "/path/to/city_street.jpg",
"patch_id": 2,
"coordinates": [300, 100, 424, 224], # Buildings
"attention_score": 0.80,
"scale": 1.0
}),
MockResult(0.62, {
"image_name": "city_street.jpg",
"image_path": "/path/to/city_street.jpg",
"patch_id": 8,
"coordinates": [100, 350, 224, 474], # Street level
"attention_score": 0.75,
"scale": 1.0
}),
MockResult(0.55, {
"image_name": "city_street.jpg",
"image_path": "/path/to/city_street.jpg",
"patch_id": 11,
"coordinates": [400, 200, 524, 324], # Sky area
"attention_score": 0.60,
"scale": 1.0
}),
]
# Test different aggregation methods
methods = ["maxsim", "voting", "weighted", "mean"]
for method in methods:
print(f"\n{'='*20} {method.upper()} AGGREGATION {'='*20}")
aggregator = MultiVectorAggregator(
aggregation_method=method,
spatial_clustering=True,
cluster_distance_threshold=100.0
)
aggregated = aggregator.aggregate_results(mock_results, top_k=5)
aggregator.print_aggregated_results(aggregated)
if __name__ == "__main__":
demo_aggregation()

View File

@@ -0,0 +1,108 @@
#!/usr/bin/env python3
"""
OpenAI Embedding Example
Complete example showing how to build and search with OpenAI embeddings using HNSW backend.
"""
import os
import dotenv
from pathlib import Path
from leann.api import LeannBuilder, LeannSearcher
# Load environment variables
dotenv.load_dotenv()
def main():
# Check if OpenAI API key is available
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
print("ERROR: OPENAI_API_KEY environment variable not set")
return False
print(f"✅ OpenAI API key found: {api_key[:10]}...")
# Sample texts
sample_texts = [
"Machine learning is a powerful technology that enables computers to learn from data.",
"Natural language processing helps computers understand and generate human language.",
"Deep learning uses neural networks with multiple layers to solve complex problems.",
"Computer vision allows machines to interpret and understand visual information.",
"Reinforcement learning trains agents to make decisions through trial and error.",
"Data science combines statistics, math, and programming to extract insights from data.",
"Artificial intelligence aims to create machines that can perform human-like tasks.",
"Python is a popular programming language used extensively in data science and AI.",
"Neural networks are inspired by the structure and function of the human brain.",
"Big data refers to extremely large datasets that require special tools to process."
]
INDEX_DIR = Path("./simple_openai_test_index")
INDEX_PATH = str(INDEX_DIR / "simple_test.leann")
print(f"\n=== Building Index with OpenAI Embeddings ===")
print(f"Index path: {INDEX_PATH}")
try:
# Use proper configuration for OpenAI embeddings
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="text-embedding-3-small",
embedding_mode="openai",
# HNSW settings for OpenAI embeddings
M=16, # Smaller graph degree
efConstruction=64, # Smaller construction complexity
is_compact=True, # Enable compact storage for recompute
is_recompute=True, # MUST enable for OpenAI embeddings
num_threads=1,
)
print(f"Adding {len(sample_texts)} texts to the index...")
for i, text in enumerate(sample_texts):
metadata = {"id": f"doc_{i}", "topic": "AI"}
builder.add_text(text, metadata)
print("Building index...")
builder.build_index(INDEX_PATH)
print(f"✅ Index built successfully!")
except Exception as e:
print(f"❌ Error building index: {e}")
import traceback
traceback.print_exc()
return False
print(f"\n=== Testing Search ===")
try:
searcher = LeannSearcher(INDEX_PATH)
test_queries = [
"What is machine learning?",
"How do neural networks work?",
"Programming languages for data science"
]
for query in test_queries:
print(f"\n🔍 Query: '{query}'")
results = searcher.search(query, top_k=3)
print(f" Found {len(results)} results:")
for i, result in enumerate(results):
print(f" {i+1}. Score: {result.score:.4f}")
print(f" Text: {result.text[:80]}...")
print(f"\n✅ Search test completed successfully!")
return True
except Exception as e:
print(f"❌ Error during search: {e}")
import traceback
traceback.print_exc()
return False
if __name__ == "__main__":
success = main()
if success:
print(f"\n🎉 Simple OpenAI index test completed successfully!")
else:
print(f"\n💥 Simple OpenAI index test failed!")

18
examples/resue_index.py Normal file
View File

@@ -0,0 +1,18 @@
import asyncio
from leann.api import LeannChat
from pathlib import Path
INDEX_DIR = Path("./test_pdf_index_huawei")
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
async def main():
print(f"\n[PHASE 2] Starting Leann chat session...")
chat = LeannChat(index_path=INDEX_PATH)
query = "What is the main idea of RL and give me 5 exapmle of classic RL algorithms?"
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
# query = "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
response = chat.ask(query,top_k=20,recompute_beighbor_embeddings=True,complexity=32,beam_width=1)
print(f"\n[PHASE 2] Response: {response}")
if __name__ == "__main__":
asyncio.run(main())

81
examples/simple_demo.py Normal file
View File

@@ -0,0 +1,81 @@
"""
Simple demo showing basic leann usage
Run: uv run python examples/simple_demo.py
"""
import argparse
from leann import LeannBuilder, LeannSearcher, LeannChat
def main():
parser = argparse.ArgumentParser(description="Simple demo of Leann with selectable embedding models.")
parser.add_argument("--embedding_model", type=str, default="sentence-transformers/all-mpnet-base-v2",
help="The embedding model to use, e.g., 'sentence-transformers/all-mpnet-base-v2' or 'text-embedding-ada-002'.")
args = parser.parse_args()
print(f"=== Leann Simple Demo with {args.embedding_model} ===")
print()
# Sample knowledge base
chunks = [
"Machine learning is a subset of artificial intelligence that enables computers to learn without being explicitly programmed.",
"Deep learning uses neural networks with multiple layers to process data and make decisions.",
"Natural language processing helps computers understand and generate human language.",
"Computer vision enables machines to interpret and understand visual information from images and videos.",
"Reinforcement learning teaches agents to make decisions by receiving rewards or penalties for their actions.",
"Data science combines statistics, programming, and domain expertise to extract insights from data.",
"Big data refers to extremely large datasets that require special tools and techniques to process.",
"Cloud computing provides on-demand access to computing resources over the internet.",
]
print("1. Building index (no embeddings stored)...")
builder = LeannBuilder(
embedding_model=args.embedding_model,
backend_name="hnsw",
)
for chunk in chunks:
builder.add_text(chunk)
builder.build_index("demo_knowledge.leann")
print()
print("2. Searching with real-time embeddings...")
searcher = LeannSearcher("demo_knowledge.leann")
queries = [
"What is machine learning?",
"How does neural network work?",
"Tell me about data processing",
]
for query in queries:
print(f"Query: {query}")
results = searcher.search(query, top_k=2)
for i, result in enumerate(results, 1):
print(f" {i}. Score: {result.score:.3f}")
print(f" Text: {result.text[:100]}...")
print()
print("3. Interactive chat demo:")
print(" (Note: Requires OpenAI API key for real responses)")
chat = LeannChat("demo_knowledge.leann")
# Demo questions
demo_questions: list[str] = [
"What is the difference between machine learning and deep learning?",
"How is data science related to big data?",
]
for question in demo_questions:
print(f" Q: {question}")
response = chat.ask(question)
print(f" A: {response}")
print()
print("Demo completed! Try running:")
print(" uv run python examples/document_search.py")
if __name__ == "__main__":
main()

View File

@@ -78,7 +78,7 @@ def create_leann_index_from_multiple_wechat_exports(
)
# Create text splitter with 256 chunk size
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
text_splitter = SentenceSplitter(chunk_size=192, chunk_overlap=64)
# Convert Documents to text strings and chunk them
all_texts = []
@@ -234,7 +234,7 @@ async def query_leann_index(index_path: str, query: str):
},
llm_kwargs={"temperature": 0.0, "max_tokens": 1000},
)
print(f"Leann: {chat_response}")
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
async def main():

View File

@@ -4,8 +4,8 @@ build-backend = "scikit_build_core.build"
[project]
name = "leann-backend-diskann"
version = "0.1.0"
dependencies = ["leann-core==0.1.0", "numpy"]
version = "0.1.9"
dependencies = ["leann-core==0.1.9", "numpy"]
[tool.scikit-build]
# Key: simplified CMake path

View File

@@ -92,8 +92,8 @@ class HNSWBuilder(LeannBackendBuilderInterface):
if success:
logger.info("✅ CSR conversion successful.")
index_file_old = index_file.with_suffix(".old")
shutil.move(str(index_file), str(index_file_old))
# index_file_old = index_file.with_suffix(".old")
# shutil.move(str(index_file), str(index_file_old))
shutil.move(str(csr_temp_file), str(index_file))
logger.info(
f"INFO: Replaced original index with {mode_str} version at '{index_file}'"

View File

@@ -6,9 +6,14 @@ build-backend = "scikit_build_core.build"
[project]
name = "leann-backend-hnsw"
version = "0.1.0"
version = "0.1.9"
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
dependencies = ["leann-core==0.1.0", "numpy"]
dependencies = [
"leann-core==0.1.9",
"numpy",
"pyzmq>=23.0.0",
"msgpack>=1.0.0",
]
[tool.scikit-build]
wheel.packages = ["leann_backend_hnsw"]

View File

@@ -4,15 +4,23 @@ build-backend = "setuptools.build_meta"
[project]
name = "leann-core"
version = "0.1.0"
description = "Core API and plugin system for Leann."
version = "0.1.9"
description = "Core API and plugin system for LEANN"
readme = "README.md"
requires-python = ">=3.9"
license = { text = "MIT" }
# All required dependencies included
dependencies = [
"numpy>=1.20.0",
"tqdm>=4.60.0"
"tqdm>=4.60.0",
"psutil>=5.8.0",
"pyzmq>=23.0.0",
"msgpack>=1.0.0",
"torch>=2.0.0",
"sentence-transformers>=2.2.0",
"llama-index-core>=0.12.0",
"python-dotenv>=1.0.0",
]
[project.scripts]

View File

@@ -142,7 +142,7 @@ class LeannBuilder:
def __init__(
self,
backend_name: str,
embedding_model: str = "facebook/contriever-msmarco",
embedding_model: str = "facebook/contriever",
dimensions: Optional[int] = None,
embedding_mode: str = "sentence-transformers",
**backend_kwargs,
@@ -441,9 +441,9 @@ class LeannSearcher:
use_server_if_available=recompute_embeddings,
zmq_port=zmq_port,
)
logger.info(f" Generated embedding shape: {query_embedding.shape}")
# logger.info(f" Generated embedding shape: {query_embedding.shape}")
embedding_time = time.time() - start_time
logger.info(f" Embedding time: {embedding_time} seconds")
# logger.info(f" Embedding time: {embedding_time} seconds")
start_time = time.time()
results = self.backend_impl.search(
@@ -458,7 +458,7 @@ class LeannSearcher:
**kwargs,
)
search_time = time.time() - start_time
logger.info(f" Search time: {search_time} seconds")
# logger.info(f" Search time: {search_time} seconds")
logger.info(
f" Backend returned: labels={len(results.get('labels', [[]])[0])} results"
)
@@ -479,15 +479,25 @@ class LeannSearcher:
metadata=passage_data.get("metadata", {}),
)
)
# Color codes for better logging
GREEN = "\033[92m"
BLUE = "\033[94m"
YELLOW = "\033[93m"
RESET = "\033[0m"
# Truncate text for display (first 100 chars)
display_text = passage_data['text']
logger.info(
f" {i + 1}. passage_id='{string_id}' -> SUCCESS: {passage_data['text']}..."
f" {GREEN}{RESET} {BLUE}[{i + 1:2d}]{RESET} {YELLOW}ID:{RESET} '{string_id}' {YELLOW}Score:{RESET} {dist:.4f} {YELLOW}Text:{RESET} {display_text}"
)
except KeyError:
RED = "\033[91m"
logger.error(
f" {i + 1}. passage_id='{string_id}' -> ERROR: Passage not found in PassageManager!"
f" {RED}{RESET} [{i + 1:2d}] ID: '{string_id}' -> {RED}ERROR: Passage not found!{RESET}"
)
logger.info(f" Final enriched results: {len(enriched_results)} passages")
logger.info(f" {GREEN} Final enriched results: {len(enriched_results)} passages{RESET}")
return enriched_results
@@ -517,7 +527,7 @@ class LeannChat:
):
if llm_kwargs is None:
llm_kwargs = {}
search_time = time.time()
results = self.searcher.search(
question,
top_k=top_k,
@@ -529,6 +539,8 @@ class LeannChat:
expected_zmq_port=expected_zmq_port,
**search_kwargs,
)
search_time = time.time() - search_time
# logger.info(f" Search time: {search_time} seconds")
context = "\n\n".join([r.text for r in results])
prompt = (
"Here is some retrieved context that might help answer your question:\n\n"

View File

@@ -9,6 +9,7 @@ from typing import Dict, Any, Optional, List
import logging
import os
import difflib
import torch
# Configure logging
logging.basicConfig(level=logging.INFO)
@@ -28,6 +29,68 @@ def check_ollama_models() -> List[str]:
return []
def check_ollama_model_exists_remotely(model_name: str) -> tuple[bool, list[str]]:
"""Check if a model exists in Ollama's remote library and return available tags
Returns:
(model_exists, available_tags): bool and list of matching tags
"""
try:
import requests
import re
# Split model name and tag
if ':' in model_name:
base_model, requested_tag = model_name.split(':', 1)
else:
base_model, requested_tag = model_name, None
# First check if base model exists in library
library_response = requests.get("https://ollama.com/library", timeout=8)
if library_response.status_code != 200:
return True, [] # Assume exists if can't check
# Extract model names from library page
models_in_library = re.findall(r'href="/library/([^"]+)"', library_response.text)
if base_model not in models_in_library:
return False, [] # Base model doesn't exist
# If base model exists, get available tags
tags_response = requests.get(f"https://ollama.com/library/{base_model}/tags", timeout=8)
if tags_response.status_code != 200:
return True, [] # Base model exists but can't get tags
# Extract tags for this model - be more specific to avoid HTML artifacts
tag_pattern = rf'{re.escape(base_model)}:[a-zA-Z0-9\.\-_]+'
raw_tags = re.findall(tag_pattern, tags_response.text)
# Clean up tags - remove HTML artifacts and duplicates
available_tags = []
seen = set()
for tag in raw_tags:
# Skip if it looks like HTML (contains < or >)
if '<' in tag or '>' in tag:
continue
if tag not in seen:
seen.add(tag)
available_tags.append(tag)
# Check if exact model exists
if requested_tag is None:
# User just requested base model, suggest tags
return True, available_tags[:10] # Return up to 10 tags
else:
exact_match = model_name in available_tags
return exact_match, available_tags[:10]
except Exception:
pass
# If scraping fails, assume model might exist (don't block user)
return True, []
def search_ollama_models_fuzzy(query: str, available_models: List[str]) -> List[str]:
"""Use intelligent fuzzy search for Ollama models"""
if not available_models:
@@ -243,24 +306,66 @@ def validate_model_and_suggest(model_name: str, llm_type: str) -> Optional[str]:
if llm_type == "ollama":
available_models = check_ollama_models()
if available_models and model_name not in available_models:
# Use intelligent fuzzy search based on locally installed models
suggestions = search_ollama_models_fuzzy(model_name, available_models)
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
if suggestions:
error_msg += "\n\nDid you mean one of these installed models?\n"
for i, suggestion in enumerate(suggestions, 1):
error_msg += f" {i}. {suggestion}\n"
else:
error_msg += "\n\nYour installed models:\n"
for i, model in enumerate(available_models[:8], 1):
error_msg += f" {i}. {model}\n"
if len(available_models) > 8:
error_msg += f" ... and {len(available_models) - 8} more\n"
error_msg += "\nTo list all models: ollama list"
error_msg += "\nTo download a new model: ollama pull <model_name>"
error_msg += "\nBrowse models: https://ollama.com/library"
# Check if the model exists remotely and get available tags
model_exists_remotely, available_tags = check_ollama_model_exists_remotely(model_name)
if model_exists_remotely and model_name in available_tags:
# Exact model exists remotely - suggest pulling it
error_msg += f"\n\nTo install the requested model:\n"
error_msg += f" ollama pull {model_name}\n"
# Show local alternatives
suggestions = search_ollama_models_fuzzy(model_name, available_models)
if suggestions:
error_msg += "\nOr use one of these similar installed models:\n"
for i, suggestion in enumerate(suggestions, 1):
error_msg += f" {i}. {suggestion}\n"
elif model_exists_remotely and available_tags:
# Base model exists but requested tag doesn't - suggest correct tags
base_model = model_name.split(':')[0]
requested_tag = model_name.split(':', 1)[1] if ':' in model_name else None
error_msg += f"\n\nModel '{base_model}' exists, but tag '{requested_tag}' is not available."
error_msg += f"\n\nAvailable {base_model} models you can install:\n"
for i, tag in enumerate(available_tags[:8], 1):
error_msg += f" {i}. ollama pull {tag}\n"
if len(available_tags) > 8:
error_msg += f" ... and {len(available_tags) - 8} more variants\n"
# Also show local alternatives
suggestions = search_ollama_models_fuzzy(model_name, available_models)
if suggestions:
error_msg += "\nOr use one of these similar installed models:\n"
for i, suggestion in enumerate(suggestions, 1):
error_msg += f" {i}. {suggestion}\n"
else:
# Model doesn't exist remotely - show fuzzy suggestions
suggestions = search_ollama_models_fuzzy(model_name, available_models)
error_msg += f"\n\nModel '{model_name}' was not found in Ollama's library."
if suggestions:
error_msg += "\n\nDid you mean one of these installed models?\n"
for i, suggestion in enumerate(suggestions, 1):
error_msg += f" {i}. {suggestion}\n"
else:
error_msg += "\n\nYour installed models:\n"
for i, model in enumerate(available_models[:8], 1):
error_msg += f" {i}. {model}\n"
if len(available_models) > 8:
error_msg += f" ... and {len(available_models) - 8} more\n"
error_msg += "\n\nCommands:"
error_msg += "\n ollama list # List installed models"
if model_exists_remotely and available_tags:
if model_name in available_tags:
error_msg += f"\n ollama pull {model_name} # Install requested model"
else:
error_msg += f"\n ollama pull {available_tags[0]} # Install recommended variant"
error_msg += "\n https://ollama.com/library # Browse available models"
return error_msg
elif llm_type == "hf":
@@ -397,7 +502,7 @@ class OllamaChat(LLMInterface):
class HFChat(LLMInterface):
"""LLM interface for local Hugging Face Transformers models."""
"""LLM interface for local Hugging Face Transformers models with proper chat templates."""
def __init__(self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat"):
logger.info(f"Initializing HFChat with model='{model_name}'")
@@ -408,7 +513,7 @@ class HFChat(LLMInterface):
raise ValueError(model_error)
try:
from transformers.pipelines import pipeline
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
except ImportError:
raise ImportError(
@@ -417,54 +522,101 @@ class HFChat(LLMInterface):
# Auto-detect device
if torch.cuda.is_available():
device = "cuda"
self.device = "cuda"
logger.info("CUDA is available. Using GPU.")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
device = "mps"
self.device = "mps"
logger.info("MPS is available. Using Apple Silicon GPU.")
else:
device = "cpu"
self.device = "cpu"
logger.info("No GPU detected. Using CPU.")
self.pipeline = pipeline("text-generation", model=model_name, device=device)
# Load tokenizer and model
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16 if self.device != "cpu" else torch.float32,
device_map="auto" if self.device != "cpu" else None,
trust_remote_code=True
)
# Move model to device if not using device_map
if self.device != "cpu" and "device_map" not in str(self.model):
self.model = self.model.to(self.device)
# Set pad token if not present
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
def ask(self, prompt: str, **kwargs) -> str:
# Map OpenAI-style arguments to Hugging Face equivalents
if "max_tokens" in kwargs:
# Prefer user-provided max_new_tokens if both are present
kwargs.setdefault("max_new_tokens", kwargs["max_tokens"])
# Remove the unsupported key to avoid errors in Transformers
kwargs.pop("max_tokens")
print('kwargs in HF: ', kwargs)
# Check if this is a Qwen model and add /no_think by default
is_qwen_model = "qwen" in self.model.config._name_or_path.lower()
# For Qwen models, automatically add /no_think to the prompt
if is_qwen_model and "/no_think" not in prompt and "/think" not in prompt:
prompt = prompt + " /no_think"
# Prepare chat template
messages = [{"role": "user", "content": prompt}]
# Apply chat template if available
if hasattr(self.tokenizer, "apply_chat_template"):
try:
formatted_prompt = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
except Exception as e:
logger.warning(f"Chat template failed, using raw prompt: {e}")
formatted_prompt = prompt
else:
# Fallback for models without chat template
formatted_prompt = prompt
# Handle temperature=0 edge-case for greedy decoding
if "temperature" in kwargs and kwargs["temperature"] == 0.0:
# Remove unsupported zero temperature and use deterministic generation
kwargs.pop("temperature")
kwargs.setdefault("do_sample", False)
# Tokenize input
inputs = self.tokenizer(
formatted_prompt,
return_tensors="pt",
padding=True,
truncation=True,
max_length=2048
)
# Move inputs to device
if self.device != "cpu":
inputs = {k: v.to(self.device) for k, v in inputs.items()}
# Sensible defaults for text generation
params = {"max_length": 500, "num_return_sequences": 1, **kwargs}
logger.info(f"Generating text with Hugging Face model with params: {params}")
results = self.pipeline(prompt, **params)
# Set generation parameters
generation_config = {
"max_new_tokens": kwargs.get("max_tokens", kwargs.get("max_new_tokens", 512)),
"temperature": kwargs.get("temperature", 0.7),
"top_p": kwargs.get("top_p", 0.9),
"do_sample": kwargs.get("temperature", 0.7) > 0,
"pad_token_id": self.tokenizer.eos_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
}
# Handle temperature=0 for greedy decoding
if generation_config["temperature"] == 0.0:
generation_config["do_sample"] = False
generation_config.pop("temperature")
# Handle different response formats from transformers
if isinstance(results, list) and len(results) > 0:
generated_text = (
results[0].get("generated_text", "")
if isinstance(results[0], dict)
else str(results[0])
logger.info(f"Generating with HuggingFace model, config: {generation_config}")
# Generate
with torch.no_grad():
outputs = self.model.generate(
**inputs,
**generation_config
)
else:
generated_text = str(results)
# Extract only the newly generated portion by removing the original prompt
if isinstance(generated_text, str) and generated_text.startswith(prompt):
response = generated_text[len(prompt) :].strip()
else:
# Fallback: return the full response if prompt removal fails
response = str(generated_text)
return response
# Decode response
generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
return response.strip()
class OpenAIChat(LLMInterface):

View File

@@ -101,7 +101,7 @@ def compute_embeddings_sentence_transformers(
if device == "mps":
batch_size = 128 # MPS optimal batch size from benchmark
if model_name == "Qwen/Qwen3-Embedding-0.6B":
batch_size = 64
batch_size = 32
elif device == "cuda":
batch_size = 256 # CUDA optimal batch size
# Keep original batch_size for CPU

View File

@@ -269,7 +269,9 @@ class EmbeddingServerManager:
]
if kwargs.get("passages_file"):
command.extend(["--passages-file", str(kwargs["passages_file"])])
# Convert to absolute path to ensure subprocess can find the file
passages_file = Path(kwargs["passages_file"]).resolve()
command.extend(["--passages-file", str(passages_file)])
if embedding_mode != "sentence-transformers":
command.extend(["--embedding-mode", embedding_mode])

View File

@@ -112,8 +112,9 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
passages_source_file = (
self.index_dir / f"{self.index_path.name}.meta.json"
)
# Convert to absolute path to ensure server can find it
zmq_port = self._ensure_server_running(
str(passages_source_file), zmq_port
str(passages_source_file.resolve()), zmq_port
)
return self._compute_embedding_via_server([query], zmq_port)[

40
packages/leann/README.md Normal file
View File

@@ -0,0 +1,40 @@
# LEANN - The smallest vector index in the world
LEANN is a revolutionary vector database that democratizes personal AI. Transform your laptop into a powerful RAG system that can index and search through millions of documents while using **97% less storage** than traditional solutions **without accuracy loss**.
## Installation
```bash
# Default installation (HNSW backend, recommended)
uv pip install leann
# With DiskANN backend (for large-scale deployments)
uv pip install leann[diskann]
```
## Quick Start
```python
from leann import LeannBuilder, LeannSearcher, LeannChat
# Build an index
builder = LeannBuilder(backend_name="hnsw")
builder.add_text("LEANN saves 97% storage compared to traditional vector databases.")
builder.build_index("my_index.leann")
# Search
searcher = LeannSearcher("my_index.leann")
results = searcher.search("storage savings", top_k=3)
# Chat with your data
chat = LeannChat("my_index.leann", llm_config={"type": "ollama", "model": "llama3.2:1b"})
response = chat.ask("How much storage does LEANN save?")
```
## Documentation
For full documentation, visit [https://leann.readthedocs.io](https://leann.readthedocs.io)
## License
MIT License

View File

@@ -0,0 +1,12 @@
"""
LEANN - Low-storage Embedding Approximation for Neural Networks
A revolutionary vector database that democratizes personal AI.
"""
__version__ = "0.1.0"
# Re-export main API from leann-core
from leann_core import LeannBuilder, LeannSearcher, LeannChat
__all__ = ["LeannBuilder", "LeannSearcher", "LeannChat"]

View File

@@ -0,0 +1,42 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[project]
name = "leann"
version = "0.1.9"
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
readme = "README.md"
requires-python = ">=3.9"
license = { text = "MIT" }
authors = [
{ name = "LEANN Team" }
]
keywords = ["vector-database", "rag", "embeddings", "search", "ai"]
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
]
# Default installation: core + hnsw
dependencies = [
"leann-core>=0.1.0",
"leann-backend-hnsw>=0.1.0",
]
[project.optional-dependencies]
diskann = [
"leann-backend-diskann>=0.1.0",
]
[project.urls]
Homepage = "https://github.com/yourusername/leann"
Documentation = "https://leann.readthedocs.io"
Repository = "https://github.com/yourusername/leann"
Issues = "https://github.com/yourusername/leann/issues"

View File

@@ -33,8 +33,8 @@ dependencies = [
"msgpack>=1.1.1",
"llama-index-vector-stores-faiss>=0.4.0",
"llama-index-embeddings-huggingface>=0.5.5",
"mlx>=0.26.3",
"mlx-lm>=0.26.0",
"mlx>=0.26.3; sys_platform == 'darwin'",
"mlx-lm>=0.26.0; sys_platform == 'darwin'",
"psutil>=5.8.0",
]

View File

@@ -1,12 +0,0 @@
import faiss
hnsw_index = faiss.read_index("/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/hnsw/hnsw_IP_M30_efC128.index", faiss.IO_FLAG_ONDISK_SAME_DIR)
# print total number of nodes
print(hnsw_index.ntotal)
# print stats of the graph
print(hnsw_index.hnsw.print_neighbor_stats(0))
# save_degree_distribution
hnsw_index.hnsw.save_degree_distribution(0, "degree_distribution_HNSW_M30.txt")

View File

@@ -1,11 +0,0 @@
import faiss
nsg_index = faiss.read_index("/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/nsg_R16.index", faiss.IO_FLAG_ONDISK_SAME_DIR)
# print total number of nodes
print(nsg_index.ntotal)
# print stats of the graph
print(nsg_index.nsg.print_neighbor_stats(0))
# save degree distribution
nsg_index.nsg.save_degree_distribution("degree_distribution_NSG_R60.txt")

View File

@@ -1,63 +0,0 @@
import torch
import torch.nn as nn
import time
# import bitsandbytes as bnb
from bitsandbytes.nn import Linear8bitLt
# set default to half
import torch
torch.set_default_dtype(torch.float16)
M = 2048
N = 2048
bsz = 2048
import torch_int
from torch_int.nn.linear import W8A8BFP32OFP32Linear, W8A8B8O8Linear, W8A8B8O8LinearReLU
fp16_model = nn.Sequential(
nn.Linear(M, N),
# nn.Linear(2048, 2048)
)
int8_model = nn.Sequential(
Linear8bitLt(M, N, has_fp16_weights=False),
# Linear8bitLt(2048, 2048, has_fp16_weights=False)
)
int8_model.load_state_dict(fp16_model.state_dict())
int8_model = int8_model.to(0) # Quantization happens here
fp16_model = fp16_model.to(0) # Move fp16 model to GPU as well
# Create random input tensor
input_tensor = torch.randn(bsz, M, device=0) # Batch of 1000 vectors
# Speed test function
def speed_test(model, input_tensor, name, num_iterations=100):
# Warmup
for _ in range(10):
_ = model(input_tensor)
# Actual timing
torch.cuda.synchronize()
start_time = time.time()
for _ in range(num_iterations):
_ = model(input_tensor)
torch.cuda.synchronize()
end_time = time.time()
avg_time = (end_time - start_time) / num_iterations
print(f"{name} model: {avg_time:.6f} seconds per iteration")
return avg_time
# Run speed tests
with torch.no_grad(): # Disable gradient calculation for inference
fp16_time = speed_test(fp16_model, input_tensor, "FP16")
int8_time = speed_test(int8_model, input_tensor, "INT8")
# Calculate speedup
speedup = fp16_time / int8_time
print(f"INT8 is {speedup:.2f}x faster than FP16")

View File

@@ -1,89 +0,0 @@
n,d,seqlen,bs,latency,h,flop,io,intensity,throughput,series
3,256,256,2048,0.009623501679245285,768,618475290624,167.48502132816208,3692720015.912285,64267177503366.266,dense
3,256,256,1024,0.004853848615384615,768,309237645312,166.15392854317415,1861151572.059558,63709783682138.234,dense
3,256,256,512,0.0024687246971962615,768,154618822656,163.57953256539062,945221081.3366361,62631051097597.516,dense
3,256,256,256,0.0012845360838052097,768,77309411328,157.64931990085577,490388486.1451936,60184694149645.54,dense
3,256,256,128,0.0006901147179878049,768,38654705664,147.57393422494675,261934506.70684624,56012000116019.945,dense
3,256,256,64,0.0003363830693015702,768,19327352832,153.1328437752606,126212981.84970059,57456378146882.51,dense
3,256,256,32,0.00018671159748991485,768,9663676416,141.10249365427362,68486928.65540518,51757237075334.75,dense
3,256,256,16,0.00012353640857142858,768,4831838208,111.40488993609125,43371868.24359184,39112665358133.98,dense
3,256,256,8,9.774760007849294e-05,768,2415919104,76.43260800265766,31608487.09906635,24715891766754.14,dense
3,256,256,4,6.672271167474822e-05,768,1207959552,64.82614227498455,18633833.660438772,18104173551704.773,dense
3,256,256,2,4.9758770289855074e-05,768,603979776,55.317122669351576,10918495.880745342,12138157202874.861,dense
3,256,1,2048,9.785507940251571e-05,768,2415919104,76.34865809334705,31643242.518371396,24688745017132.86,dense
3,256,1,1024,6.692813470149253e-05,768,1207959552,64.62717090938949,18691202.70936228,18048606275785.867,dense
3,256,1,512,4.9680950036205655e-05,768,603979776,55.40377142534654,10901419.893658841,12157170415618.898,dense
3,256,1,256,4.2781118741058655e-05,768,301989888,45.95672244805227,6571179.83862661,7058952568020.829,dense
3,256,1,128,5.0662328255350016e-05,768,150994944,31.046026784880404,4863583.512513602,2980418571348.519,dense
3,256,1,64,4.475009253945481e-05,768,75497472,30.75426042497223,2454862.219307235,1687090857598.4766,dense
3,256,1,32,4.51682671454219e-05,768,37748736,28.29313765537115,1334201.1218340008,835735758435.5786,dense
3,256,1,16,5.03585186661834e-05,768,18874368,24.401035466223117,773506.846712577,374799904761.1871,dense
3,256,1,8,5.023459565217391e-05,768,9437184,23.972005435021096,393675.19858030166,187862246674.45105,dense
3,256,1,4,5.053219391083726e-05,768,4718592,23.58765586356967,200044.97383259286,93377936614.54384,dense
3,256,1,2,4.4607398995335484e-05,768,2359296,26.58285456464288,88752.54515134107,52890239133.797226,dense
12,256,256,2048,0.14480779847058822,3072,9895604649984,44.620009282941716,221775046868.20184,68336130750540.26,dense
12,256,256,1024,0.07254347629166667,3072,4947802324992,44.664248332585096,110777691547.58836,68204648824643.82,dense
12,256,256,512,0.036310761444444443,3072,2473901162496,44.876147984203506,55127306456.13385,68131349056975.164,dense
12,256,256,256,0.01821551906896552,3072,1236950581248,45.24607467289738,27338295977.947884,67906414116709.98,dense
12,256,256,128,0.009229417903030302,3072,618475290624,45.67217092440895,13541622351.335684,67011299859001.46,dense
12,256,256,64,0.004754550595394737,3072,309237645312,46.31372736116993,6677019167.566916,65040352207320.695,dense
12,256,256,32,0.002405752659340659,3072,154618822656,49.68826015254682,3111777755.5766335,64270456921525.82,dense
12,256,256,16,0.0012287219045005488,3072,77309411328,56.323579604557374,1372594069.3184311,62918558743709.18,dense
12,256,256,8,0.0006206816149425287,3072,38654705664,70.95456179103653,544781120.315271,62277832520589.78,dense
12,256,256,4,0.0003875502697142857,3072,19327352832,81.16954743236613,238110885.71245712,49870569942445.75,dense
12,256,256,2,0.00027502018627941914,3072,9663676416,91.50537035282076,105607751.53129694,35138062215483.168,dense
12,256,1,2048,0.0006202853873290136,3072,38654705664,70.99988634205897,544433345.6784943,62317614526515.766,dense
12,256,1,1024,0.00038721467732724153,3072,19327352832,81.2398957010995,237904697.74985722,49913791918755.53,dense
12,256,1,512,0.000274364799,3072,9663676416,91.72395326121995,105356082.81599998,35221998052308.45,dense
12,256,1,256,0.00012488918589482266,3072,4831838208,176.31707535146046,27404255.647778228,38689003962834.75,dense
12,256,1,128,8.976711102514506e-05,3072,2415919104,227.78088507574267,10606329.425740216,26913187652026.21,dense
12,256,1,64,8.715176287471176e-05,3072,1207959552,225.59268282689945,5354604.31102229,13860414432884.701,dense
12,256,1,32,8.523013435114503e-05,3072,603979776,226.06539514085782,2671703.8033338524,7086458100741.991,dense
12,256,1,16,7.901561645904116e-05,3072,301989888,241.35704882952732,1251216.3595988373,3821901309300.556,dense
12,256,1,8,7.827949114210329e-05,3072,150994944,242.37091635608994,622991.1833900034,1928920867994.581,dense
12,256,1,4,7.779445951035782e-05,3072,75497472,243.25022783249054,310369.58391664835,970473636235.5986,dense
12,256,1,2,7.758845406626506e-05,3072,37748736,243.57933441822672,154975.11761480253,486525172518.07056,dense
3,256,256,2048,0.00507974918466899,768,206158430208,475.59810852303485,433471930.42508715,40584371927298.98,qk_init
3,256,256,1024,0.0025616677649325623,768,103079215104,471.5519977009198,218595649.27424532,40239103803811.82,qk_init
3,256,256,512,0.0013029336670480549,768,51539607552,463.55374128015677,111183672.92143403,39556585922573.38,qk_init
3,256,256,256,0.0006738189029345373,768,25769803776,448.1766342333362,57499213.050413854,38244406121244.69,qk_init
3,256,256,128,0.000358254672959467,768,12884901888,421.47375986100144,30571065.425874516,35965760841472.125,qk_init
3,256,256,64,0.0002007051105022831,768,6442450944,376.1611839930762,17126836.096194826,32099087700742.5,qk_init
3,256,256,32,0.00012189697230142565,768,3221225472,309.6773881032524,10401874.969721656,26425803784810.87,qk_init
3,256,256,16,8.453561698040722e-05,768,1610612736,223.2711923587723,7213705.982328083,19052475081281.902,qk_init
3,256,256,8,6.407660705009276e-05,768,805306368,147.2797083750448,5467870.468274581,12567868448003.822,qk_init
3,256,256,4,5.036328747284576e-05,768,402653184,93.69110391262903,4297667.197682838,7994974200544.344,qk_init
3,256,256,2,4.5488761135057476e-05,768,201326592,51.865470527877875,3881707.616858238,4425853485045.578,qk_init
12,256,256,2048,0.020202365999999996,3072,824633720832,478.3437947812648,1723935231.9999998,40818670488001.266,qk_init
12,256,256,1024,0.010124155888157895,3072,412316860416,477.2583770318811,863927969.1228071,40726048173387.19,qk_init
12,256,256,512,0.005085633937062937,3072,206158430208,475.04777848703077,433974095.9627039,40537410430893.29,qk_init
12,256,256,256,0.0025654916853281853,3072,103079215104,470.84913933193053,218921957.14800516,40179126556324.74,qk_init
12,256,256,128,0.0013045765704467354,3072,51539607552,462.9699702434292,111323867.34478809,39506770794105.96,qk_init
12,256,256,64,0.0006742801519939804,3072,25769803776,447.87005387442576,57538572.970153,38218244597284.33,qk_init
12,256,256,32,0.00035831976790671853,3072,12884901888,421.3971919051604,30576620.194706645,35959227042573.69,qk_init
12,256,256,16,0.0002005369068918302,3072,6442450944,376.4766953382971,17112482.721436176,32126011335534.68,qk_init
12,256,256,8,0.00012179187250509165,3072,3221225472,309.94462293386505,10392906.453767821,26448607823689.82,qk_init
12,256,256,4,8.452507263643351e-05,3072,1610612736,223.2990450204527,7212806.198308992,19054851841745.297,qk_init
12,256,256,2,6.412381767545489e-05,3072,805306368,147.17127491946468,5471899.108305484,12558615459794.32,qk_init
3,256,256,2048,0.0016183739398395718,768,805306368,811597824.0,0.9922480620155039,1265467.7325087283,qk_ar
3,256,256,1024,0.0008322699728813558,768,402653184,405798912.0,0.9922480620155039,1230369.9921491416,qk_ar
3,256,256,512,0.00043886859397590365,768,201326592,202899456.0,0.9922480620155039,1166636.2255762408,qk_ar
3,256,256,256,0.00024185948322147648,768,100663296,101449728.0,0.9922480620155039,1058465.8355760013,qk_ar
3,256,256,128,0.00014308985100166944,768,50331648,50724864.0,0.9922480620155039,894542.82818777,qk_ar
3,256,256,64,9.382939365815932e-05,768,25165824,25362432.0,0.9922480620155039,682089.028872613,qk_ar
3,256,256,32,6.856070612244899e-05,768,12582912,12681216.0,0.9922480620155039,466739.6503012703,qk_ar
3,256,256,16,5.452260553129549e-05,768,6291456,6340608.0,0.9922480620155039,293456.26174846216,qk_ar
3,256,256,8,4.608557533261417e-05,768,3145728,3170304.0,0.9922480620155039,173590.1080166944,qk_ar
3,256,256,4,4.386146957766642e-05,768,1572864,1585152.0,0.9922480620155039,91196.21477609445,qk_ar
3,256,256,2,4.330941094420601e-05,768,786432,792576.0,0.9922480620155039,46179.33969539622,qk_ar
12,256,256,2048,0.006347041645299144,3072,3221225472,3246391296.0,0.9922480620155039,322670.011392918,qk_ar
12,256,256,1024,0.0031943104467592586,3072,1610612736,1623195648.0,0.9922480620155039,320569.96872013,qk_ar
12,256,256,512,0.0016183416350267381,3072,805306368,811597824.0,0.9922480620155039,316373.2483416833,qk_ar
12,256,256,256,0.0008325934893977947,3072,402653184,405798912.0,0.9922480620155039,307472.9784221131,qk_ar
12,256,256,128,0.0004389725746987952,3072,201326592,202899456.0,0.9922480620155039,291589.9702568624,qk_ar
12,256,256,64,0.00024191767449664432,3072,100663296,101449728.0,0.9922480620155039,264552.8076159138,qk_ar
12,256,256,32,0.0001431546143572621,3072,50331648,50724864.0,0.9922480620155039,223534.53392804778,qk_ar
12,256,256,16,9.404283597678917e-05,3072,25165824,25362432.0,0.9922480620155039,170135.23501087292,qk_ar
12,256,256,8,6.855550037091989e-05,3072,12582912,12681216.0,0.9922480620155039,116693.773026467,qk_ar
12,256,256,4,5.4802094978165945e-05,3072,6291456,6340608.0,0.9922480620155039,72989.91036006316,qk_ar
12,256,256,2,4.608510707869206e-05,3072,3145728,3170304.0,0.9922480620155039,43397.96795057727,qk_ar
1 n d seqlen bs latency h flop io intensity throughput series
2 3 256 256 2048 0.009623501679245285 768 618475290624 167.48502132816208 3692720015.912285 64267177503366.266 dense
3 3 256 256 1024 0.004853848615384615 768 309237645312 166.15392854317415 1861151572.059558 63709783682138.234 dense
4 3 256 256 512 0.0024687246971962615 768 154618822656 163.57953256539062 945221081.3366361 62631051097597.516 dense
5 3 256 256 256 0.0012845360838052097 768 77309411328 157.64931990085577 490388486.1451936 60184694149645.54 dense
6 3 256 256 128 0.0006901147179878049 768 38654705664 147.57393422494675 261934506.70684624 56012000116019.945 dense
7 3 256 256 64 0.0003363830693015702 768 19327352832 153.1328437752606 126212981.84970059 57456378146882.51 dense
8 3 256 256 32 0.00018671159748991485 768 9663676416 141.10249365427362 68486928.65540518 51757237075334.75 dense
9 3 256 256 16 0.00012353640857142858 768 4831838208 111.40488993609125 43371868.24359184 39112665358133.98 dense
10 3 256 256 8 9.774760007849294e-05 768 2415919104 76.43260800265766 31608487.09906635 24715891766754.14 dense
11 3 256 256 4 6.672271167474822e-05 768 1207959552 64.82614227498455 18633833.660438772 18104173551704.773 dense
12 3 256 256 2 4.9758770289855074e-05 768 603979776 55.317122669351576 10918495.880745342 12138157202874.861 dense
13 3 256 1 2048 9.785507940251571e-05 768 2415919104 76.34865809334705 31643242.518371396 24688745017132.86 dense
14 3 256 1 1024 6.692813470149253e-05 768 1207959552 64.62717090938949 18691202.70936228 18048606275785.867 dense
15 3 256 1 512 4.9680950036205655e-05 768 603979776 55.40377142534654 10901419.893658841 12157170415618.898 dense
16 3 256 1 256 4.2781118741058655e-05 768 301989888 45.95672244805227 6571179.83862661 7058952568020.829 dense
17 3 256 1 128 5.0662328255350016e-05 768 150994944 31.046026784880404 4863583.512513602 2980418571348.519 dense
18 3 256 1 64 4.475009253945481e-05 768 75497472 30.75426042497223 2454862.219307235 1687090857598.4766 dense
19 3 256 1 32 4.51682671454219e-05 768 37748736 28.29313765537115 1334201.1218340008 835735758435.5786 dense
20 3 256 1 16 5.03585186661834e-05 768 18874368 24.401035466223117 773506.846712577 374799904761.1871 dense
21 3 256 1 8 5.023459565217391e-05 768 9437184 23.972005435021096 393675.19858030166 187862246674.45105 dense
22 3 256 1 4 5.053219391083726e-05 768 4718592 23.58765586356967 200044.97383259286 93377936614.54384 dense
23 3 256 1 2 4.4607398995335484e-05 768 2359296 26.58285456464288 88752.54515134107 52890239133.797226 dense
24 12 256 256 2048 0.14480779847058822 3072 9895604649984 44.620009282941716 221775046868.20184 68336130750540.26 dense
25 12 256 256 1024 0.07254347629166667 3072 4947802324992 44.664248332585096 110777691547.58836 68204648824643.82 dense
26 12 256 256 512 0.036310761444444443 3072 2473901162496 44.876147984203506 55127306456.13385 68131349056975.164 dense
27 12 256 256 256 0.01821551906896552 3072 1236950581248 45.24607467289738 27338295977.947884 67906414116709.98 dense
28 12 256 256 128 0.009229417903030302 3072 618475290624 45.67217092440895 13541622351.335684 67011299859001.46 dense
29 12 256 256 64 0.004754550595394737 3072 309237645312 46.31372736116993 6677019167.566916 65040352207320.695 dense
30 12 256 256 32 0.002405752659340659 3072 154618822656 49.68826015254682 3111777755.5766335 64270456921525.82 dense
31 12 256 256 16 0.0012287219045005488 3072 77309411328 56.323579604557374 1372594069.3184311 62918558743709.18 dense
32 12 256 256 8 0.0006206816149425287 3072 38654705664 70.95456179103653 544781120.315271 62277832520589.78 dense
33 12 256 256 4 0.0003875502697142857 3072 19327352832 81.16954743236613 238110885.71245712 49870569942445.75 dense
34 12 256 256 2 0.00027502018627941914 3072 9663676416 91.50537035282076 105607751.53129694 35138062215483.168 dense
35 12 256 1 2048 0.0006202853873290136 3072 38654705664 70.99988634205897 544433345.6784943 62317614526515.766 dense
36 12 256 1 1024 0.00038721467732724153 3072 19327352832 81.2398957010995 237904697.74985722 49913791918755.53 dense
37 12 256 1 512 0.000274364799 3072 9663676416 91.72395326121995 105356082.81599998 35221998052308.45 dense
38 12 256 1 256 0.00012488918589482266 3072 4831838208 176.31707535146046 27404255.647778228 38689003962834.75 dense
39 12 256 1 128 8.976711102514506e-05 3072 2415919104 227.78088507574267 10606329.425740216 26913187652026.21 dense
40 12 256 1 64 8.715176287471176e-05 3072 1207959552 225.59268282689945 5354604.31102229 13860414432884.701 dense
41 12 256 1 32 8.523013435114503e-05 3072 603979776 226.06539514085782 2671703.8033338524 7086458100741.991 dense
42 12 256 1 16 7.901561645904116e-05 3072 301989888 241.35704882952732 1251216.3595988373 3821901309300.556 dense
43 12 256 1 8 7.827949114210329e-05 3072 150994944 242.37091635608994 622991.1833900034 1928920867994.581 dense
44 12 256 1 4 7.779445951035782e-05 3072 75497472 243.25022783249054 310369.58391664835 970473636235.5986 dense
45 12 256 1 2 7.758845406626506e-05 3072 37748736 243.57933441822672 154975.11761480253 486525172518.07056 dense
46 3 256 256 2048 0.00507974918466899 768 206158430208 475.59810852303485 433471930.42508715 40584371927298.98 qk_init
47 3 256 256 1024 0.0025616677649325623 768 103079215104 471.5519977009198 218595649.27424532 40239103803811.82 qk_init
48 3 256 256 512 0.0013029336670480549 768 51539607552 463.55374128015677 111183672.92143403 39556585922573.38 qk_init
49 3 256 256 256 0.0006738189029345373 768 25769803776 448.1766342333362 57499213.050413854 38244406121244.69 qk_init
50 3 256 256 128 0.000358254672959467 768 12884901888 421.47375986100144 30571065.425874516 35965760841472.125 qk_init
51 3 256 256 64 0.0002007051105022831 768 6442450944 376.1611839930762 17126836.096194826 32099087700742.5 qk_init
52 3 256 256 32 0.00012189697230142565 768 3221225472 309.6773881032524 10401874.969721656 26425803784810.87 qk_init
53 3 256 256 16 8.453561698040722e-05 768 1610612736 223.2711923587723 7213705.982328083 19052475081281.902 qk_init
54 3 256 256 8 6.407660705009276e-05 768 805306368 147.2797083750448 5467870.468274581 12567868448003.822 qk_init
55 3 256 256 4 5.036328747284576e-05 768 402653184 93.69110391262903 4297667.197682838 7994974200544.344 qk_init
56 3 256 256 2 4.5488761135057476e-05 768 201326592 51.865470527877875 3881707.616858238 4425853485045.578 qk_init
57 12 256 256 2048 0.020202365999999996 3072 824633720832 478.3437947812648 1723935231.9999998 40818670488001.266 qk_init
58 12 256 256 1024 0.010124155888157895 3072 412316860416 477.2583770318811 863927969.1228071 40726048173387.19 qk_init
59 12 256 256 512 0.005085633937062937 3072 206158430208 475.04777848703077 433974095.9627039 40537410430893.29 qk_init
60 12 256 256 256 0.0025654916853281853 3072 103079215104 470.84913933193053 218921957.14800516 40179126556324.74 qk_init
61 12 256 256 128 0.0013045765704467354 3072 51539607552 462.9699702434292 111323867.34478809 39506770794105.96 qk_init
62 12 256 256 64 0.0006742801519939804 3072 25769803776 447.87005387442576 57538572.970153 38218244597284.33 qk_init
63 12 256 256 32 0.00035831976790671853 3072 12884901888 421.3971919051604 30576620.194706645 35959227042573.69 qk_init
64 12 256 256 16 0.0002005369068918302 3072 6442450944 376.4766953382971 17112482.721436176 32126011335534.68 qk_init
65 12 256 256 8 0.00012179187250509165 3072 3221225472 309.94462293386505 10392906.453767821 26448607823689.82 qk_init
66 12 256 256 4 8.452507263643351e-05 3072 1610612736 223.2990450204527 7212806.198308992 19054851841745.297 qk_init
67 12 256 256 2 6.412381767545489e-05 3072 805306368 147.17127491946468 5471899.108305484 12558615459794.32 qk_init
68 3 256 256 2048 0.0016183739398395718 768 805306368 811597824.0 0.9922480620155039 1265467.7325087283 qk_ar
69 3 256 256 1024 0.0008322699728813558 768 402653184 405798912.0 0.9922480620155039 1230369.9921491416 qk_ar
70 3 256 256 512 0.00043886859397590365 768 201326592 202899456.0 0.9922480620155039 1166636.2255762408 qk_ar
71 3 256 256 256 0.00024185948322147648 768 100663296 101449728.0 0.9922480620155039 1058465.8355760013 qk_ar
72 3 256 256 128 0.00014308985100166944 768 50331648 50724864.0 0.9922480620155039 894542.82818777 qk_ar
73 3 256 256 64 9.382939365815932e-05 768 25165824 25362432.0 0.9922480620155039 682089.028872613 qk_ar
74 3 256 256 32 6.856070612244899e-05 768 12582912 12681216.0 0.9922480620155039 466739.6503012703 qk_ar
75 3 256 256 16 5.452260553129549e-05 768 6291456 6340608.0 0.9922480620155039 293456.26174846216 qk_ar
76 3 256 256 8 4.608557533261417e-05 768 3145728 3170304.0 0.9922480620155039 173590.1080166944 qk_ar
77 3 256 256 4 4.386146957766642e-05 768 1572864 1585152.0 0.9922480620155039 91196.21477609445 qk_ar
78 3 256 256 2 4.330941094420601e-05 768 786432 792576.0 0.9922480620155039 46179.33969539622 qk_ar
79 12 256 256 2048 0.006347041645299144 3072 3221225472 3246391296.0 0.9922480620155039 322670.011392918 qk_ar
80 12 256 256 1024 0.0031943104467592586 3072 1610612736 1623195648.0 0.9922480620155039 320569.96872013 qk_ar
81 12 256 256 512 0.0016183416350267381 3072 805306368 811597824.0 0.9922480620155039 316373.2483416833 qk_ar
82 12 256 256 256 0.0008325934893977947 3072 402653184 405798912.0 0.9922480620155039 307472.9784221131 qk_ar
83 12 256 256 128 0.0004389725746987952 3072 201326592 202899456.0 0.9922480620155039 291589.9702568624 qk_ar
84 12 256 256 64 0.00024191767449664432 3072 100663296 101449728.0 0.9922480620155039 264552.8076159138 qk_ar
85 12 256 256 32 0.0001431546143572621 3072 50331648 50724864.0 0.9922480620155039 223534.53392804778 qk_ar
86 12 256 256 16 9.404283597678917e-05 3072 25165824 25362432.0 0.9922480620155039 170135.23501087292 qk_ar
87 12 256 256 8 6.855550037091989e-05 3072 12582912 12681216.0 0.9922480620155039 116693.773026467 qk_ar
88 12 256 256 4 5.4802094978165945e-05 3072 6291456 6340608.0 0.9922480620155039 72989.91036006316 qk_ar
89 12 256 256 2 4.608510707869206e-05 3072 3145728 3170304.0 0.9922480620155039 43397.96795057727 qk_ar

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 45 KiB

View File

@@ -1,594 +0,0 @@
# python embedd_micro.py --use_int8 Fastest
import argparse
import time
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
from torch import nn
from torchao import quantize_
from transformers import AutoModel, BitsAndBytesConfig
from tqdm import tqdm
from contextlib import contextmanager
@dataclass
class BenchmarkConfig:
model_path: str
batch_sizes: List[int]
seq_length: int
num_runs: int
use_fp16: bool = True
use_int4: bool = False
use_int8: bool = False # Add this parameter
use_cuda_graphs: bool = False
use_flash_attention: bool = False
use_linear8bitlt: bool = False
class CUDAGraphContainer:
"""Container for managing CUDA graphs for different batch sizes."""
def __init__(self, model: nn.Module, seq_length: int):
self.model = model
self.seq_length = seq_length
self.graphs: Dict[int, CUDAGraphWrapper] = {}
def get_or_create(self, batch_size: int) -> 'CUDAGraphWrapper':
if batch_size not in self.graphs:
self.graphs[batch_size] = CUDAGraphWrapper(
self.model, batch_size, self.seq_length
)
return self.graphs[batch_size]
class CUDAGraphWrapper:
"""Wrapper for CUDA graph capture and replay."""
def __init__(self, model: nn.Module, batch_size: int, seq_length: int):
self.model = model
self.static_input = self._create_random_batch(batch_size, seq_length)
self.static_attention_mask = torch.ones_like(self.static_input)
# Warm up
self._warmup()
# Capture graph
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph):
self.static_output = self.model(
input_ids=self.static_input,
attention_mask=self.static_attention_mask
)
def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
return torch.randint(
0, 1000, (batch_size, seq_length),
device="cuda",
dtype=torch.long
)
def _warmup(self, num_warmup: int = 3):
with torch.no_grad():
for _ in range(num_warmup):
self.model(
input_ids=self.static_input,
attention_mask=self.static_attention_mask
)
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
self.static_input.copy_(input_ids)
self.static_attention_mask.copy_(attention_mask)
self.graph.replay()
return self.static_output
class ModelOptimizer:
"""Applies various optimizations to the model."""
@staticmethod
def optimize(model: nn.Module, config: BenchmarkConfig) -> nn.Module:
print("\nApplying model optimizations:")
if model is None:
raise ValueError("Cannot optimize None model")
# Move to GPU
model = model.cuda()
print("- Model moved to GPU")
# FP16
if config.use_fp16 and not config.use_int4:
model = model.half()
# use torch compile
model = torch.compile(model)
print("- Using FP16 precision")
# Check if using SDPA
if torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
else:
print("- PyTorch SDPA not available")
# Flash Attention
if config.use_flash_attention:
try:
from flash_attn.flash_attention import FlashAttention
print("- Flash Attention 2 available")
if hasattr(model.config, "attention_mode"):
model.config.attention_mode = "flash_attention_2"
print(" - Enabled Flash Attention 2 mode")
except ImportError:
print("- Flash Attention not available")
# Memory efficient attention
try:
from xformers.ops import memory_efficient_attention
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
model.enable_xformers_memory_efficient_attention()
print("- Enabled xformers memory efficient attention")
else:
print("- Model doesn't support xformers")
except (ImportError, AttributeError):
print("- Xformers not available")
model.eval()
print("- Model set to eval mode")
return model
class Timer:
"""Handles accurate GPU timing using CUDA events."""
def __init__(self):
self.start_event = torch.cuda.Event(enable_timing=True)
self.end_event = torch.cuda.Event(enable_timing=True)
@contextmanager
def timing(self):
self.start_event.record()
yield
self.end_event.record()
self.end_event.synchronize()
def elapsed_time(self) -> float:
return self.start_event.elapsed_time(self.end_event) / 1000 # ms to seconds
class Benchmark:
"""Main benchmark runner."""
def __init__(self, config: BenchmarkConfig):
self.config = config
try:
self.model = self._load_model()
if self.model is None:
raise ValueError("Model initialization failed - model is None")
self.cuda_graphs = (
CUDAGraphContainer(self.model, config.seq_length)
if config.use_cuda_graphs
else None
)
self.timer = Timer()
except Exception as e:
print(f"ERROR in benchmark initialization: {str(e)}")
raise
def _load_model(self) -> nn.Module:
print(f"Loading model from {self.config.model_path}...")
try:
# Int4 quantization using HuggingFace integration
if self.config.use_int4:
import bitsandbytes as bnb
print(f"- bitsandbytes version: {bnb.__version__}")
# 检查是否使用自定义的8bit量化
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt:
print("- Using custom Linear8bitLt replacement for all linear layers")
# 加载原始模型(不使用量化配置)
import bitsandbytes as bnb
import torch
# set default to half
torch.set_default_dtype(torch.float16)
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
model = AutoModel.from_pretrained(
self.config.model_path,
torch_dtype=compute_dtype,
)
# 定义替换函数
def replace_linear_with_linear8bitlt(model):
"""递归地将模型中的所有nn.Linear层替换为Linear8bitLt"""
for name, module in list(model.named_children()):
if isinstance(module, nn.Linear):
# 获取原始线性层的参数
in_features = module.in_features
out_features = module.out_features
bias = module.bias is not None
# 创建8bit线性层
# print size
print(f"in_features: {in_features}, out_features: {out_features}")
new_module = bnb.nn.Linear8bitLt(
in_features,
out_features,
bias=bias,
has_fp16_weights=False
)
# 复制权重和偏置
new_module.weight.data = module.weight.data
if bias:
new_module.bias.data = module.bias.data
# 替换模块
setattr(model, name, new_module)
else:
# 递归处理子模块
replace_linear_with_linear8bitlt(module)
return model
# 替换所有线性层
model = replace_linear_with_linear8bitlt(model)
# add torch compile
model = torch.compile(model)
# 将模型移到GPU量化发生在这里
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
print("- All linear layers replaced with Linear8bitLt")
else:
# 使用原来的Int4量化方法
print("- Using bitsandbytes for Int4 quantization")
# Create quantization config
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4"
)
print("- Quantization config:", quantization_config)
# Load model directly with quantization config
model = AutoModel.from_pretrained(
self.config.model_path,
quantization_config=quantization_config,
torch_dtype=compute_dtype,
device_map="auto" # Let HF decide on device mapping
)
# Check if model loaded successfully
if model is None:
raise ValueError("Model loading returned None")
print(f"- Model type: {type(model)}")
# Apply optimizations directly here
print("\nApplying model optimizations:")
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt:
print("- Model moved to GPU with Linear8bitLt quantization")
else:
# Skip moving to GPU since device_map="auto" already did that
print("- Model already on GPU due to device_map='auto'")
# Skip FP16 conversion since we specified compute_dtype
print(f"- Using {compute_dtype} for compute dtype")
# Check CUDA and SDPA
if torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
else:
print("- PyTorch SDPA not available")
# Try xformers if available
try:
from xformers.ops import memory_efficient_attention
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
model.enable_xformers_memory_efficient_attention()
print("- Enabled xformers memory efficient attention")
else:
print("- Model doesn't support xformers")
except (ImportError, AttributeError):
print("- Xformers not available")
# Set to eval mode
model.eval()
print("- Model set to eval mode")
# Int8 quantization using HuggingFace integration
# Int8 quantization using TorchAO
elif self.config.use_int8:
print("- Using TorchAO for Int8 dynamic activation and Int8 weight quantization")
# Import the quantize_ function and the quantization config
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
print("- Successfully imported TorchAO")
# Load model normally first
# set default to half
import torch
torch.set_default_dtype(torch.bfloat16)
model = AutoModel.from_pretrained(
self.config.model_path,
device_map="auto"
)
print("- Model loaded in full precision")
print(f"- Model type: {type(model)}")
# Apply quantization - call the function to get the config, then apply it
# quantize_(model, int8_dynamic_activation_int8_weight())
# from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig,int8_dynamic_activation_int8_semi_sparse_weight,int4_weight_only,Int8DynActInt4WeightGPTQQuantizer,int8_dynamic_activation_int4_weight,Int8DynamicActivationInt4WeightConfig,Int4DynamicActivationInt4WeightConfig
from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig
quantize_(model, Int8DynamicActivationInt8WeightConfig())
print("- Model successfully quantized with int8 weights and int8 activations")
# add torch compile
model = torch.compile(model)
# For older PyTorch versions that have issues with tensor subclasses
from torchao.utils import unwrap_tensor_subclass
import torch
if hasattr(torch, '_version') and not torch.version >= "2.5.0":
print("- Unwrapping tensor subclasses for compatibility with older PyTorch")
unwrap_tensor_subclass(model)
# Apply optimizations
if torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
else:
print("- PyTorch SDPA not available")
# Set to eval mode
model.eval()
print("- Model set to eval mode")
# For better performance with int8 dynamic quantization
torch._inductor.config.force_fuse_int_mm_with_mul = True
print("- Enabled fusion of int matmul with mul operations")
else:
# Standard loading for FP16/FP32
model = AutoModel.from_pretrained(self.config.model_path)
print("- Model loaded in standard precision")
print(f"- Model type: {type(model)}")
# Apply standard optimizations
# set default to half
import torch
torch.set_default_dtype(torch.bfloat16)
model = ModelOptimizer.optimize(model, self.config)
model = model.half()
# add torch compile
model = torch.compile(model)
# Final check to ensure model is not None
if model is None:
raise ValueError("Model is None after optimization")
print(f"- Final model type: {type(model)}")
return model
except Exception as e:
print(f"ERROR loading model: {str(e)}")
import traceback
traceback.print_exc()
raise
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
return torch.randint(
0, 1000,
(batch_size, self.config.seq_length),
device="cuda",
dtype=torch.long
)
def _run_inference(
self,
input_ids: torch.Tensor,
cuda_graph_wrapper: Optional[CUDAGraphWrapper] = None
) -> Tuple[float, torch.Tensor]:
attention_mask = torch.ones_like(input_ids)
with torch.no_grad(), self.timer.timing():
if cuda_graph_wrapper is not None:
output = cuda_graph_wrapper(input_ids, attention_mask)
else:
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
return self.timer.elapsed_time(), output
def run(self) -> Dict[int, Dict[str, float]]:
results = {}
# Reset peak memory stats
torch.cuda.reset_peak_memory_stats()
for batch_size in self.config.batch_sizes:
print(f"\nTesting batch size: {batch_size}")
times = []
# Get or create CUDA graph for this batch size
cuda_graph_wrapper = (
self.cuda_graphs.get_or_create(batch_size)
if self.cuda_graphs is not None
else None
)
# Pre-allocate input tensor
input_ids = self._create_random_batch(batch_size)
print(f"Input shape: {input_ids.shape}")
# Run benchmark
for i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
try:
elapsed_time, output = self._run_inference(input_ids, cuda_graph_wrapper)
if i == 0: # Only print on first run
print(f"Output shape: {output.last_hidden_state.shape}")
times.append(elapsed_time)
except Exception as e:
print(f"Error during inference: {e}")
break
if not times:
print(f"No successful runs for batch size {batch_size}, skipping")
continue
# Calculate statistics
avg_time = np.mean(times)
std_time = np.std(times)
throughput = batch_size / avg_time
results[batch_size] = {
"avg_time": avg_time,
"std_time": std_time,
"throughput": throughput,
}
print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
print(f"Throughput: {throughput:.2f} sequences/second")
# Log memory usage
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)
print(f"\nPeak GPU memory usage: {peak_memory_gb:.2f} GB")
# Add memory info to results
for batch_size in results:
results[batch_size]["peak_memory_gb"] = peak_memory_gb
return results
def main():
parser = argparse.ArgumentParser(description="Model Inference Benchmark")
parser.add_argument(
"--model_path",
type=str,
default="facebook/contriever",
help="Path to the model",
)
parser.add_argument(
"--batch_sizes",
type=str,
default="1,2,4,8,10,16,20,32,40,64,128,256,512,1024,2048,4096,8192",
help="Comma-separated list of batch sizes",
)
parser.add_argument(
"--seq_length",
type=int,
default=256,
help="Sequence length for input",
)
parser.add_argument(
"--num_runs",
type=int,
default=5,
help="Number of runs for each batch size",
)
parser.add_argument(
"--use_fp16",
action="store_true",
help="Enable FP16 inference",
)
parser.add_argument(
"--use_int4",
action="store_true",
help="Enable INT4 quantization using bitsandbytes",
)
parser.add_argument(
"--use_int8",
action="store_true",
help="Enable INT8 quantization for both activations and weights using bitsandbytes",
)
parser.add_argument(
"--use_cuda_graphs",
action="store_true",
help="Enable CUDA Graphs optimization",
)
parser.add_argument(
"--use_flash_attention",
action="store_true",
help="Enable Flash Attention 2 if available",
)
parser.add_argument(
"--use_linear8bitlt",
action="store_true",
help="Enable Linear8bitLt quantization for all linear layers",
)
args = parser.parse_args()
# Print arguments for debugging
print("\nCommand line arguments:")
for arg, value in vars(args).items():
print(f"- {arg}: {value}")
config = BenchmarkConfig(
model_path=args.model_path,
batch_sizes=[int(bs) for bs in args.batch_sizes.split(",")],
seq_length=args.seq_length,
num_runs=args.num_runs,
use_fp16=args.use_fp16,
use_int4=args.use_int4,
use_int8=args.use_int8, # Add this line
use_cuda_graphs=args.use_cuda_graphs,
use_flash_attention=args.use_flash_attention,
use_linear8bitlt=args.use_linear8bitlt,
)
# Print configuration for debugging
print("\nBenchmark configuration:")
for field, value in vars(config).items():
print(f"- {field}: {value}")
try:
benchmark = Benchmark(config)
results = benchmark.run()
# Save results to file
import json
import os
# Create results directory if it doesn't exist
os.makedirs("results", exist_ok=True)
# Generate filename based on configuration
precision_type = "int4" if config.use_int4 else "fp16" if config.use_fp16 else "fp32"
model_name = os.path.basename(config.model_path)
output_file = f"results/benchmark_{model_name}_{precision_type}.json"
# Save results
with open(output_file, "w") as f:
json.dump(
{
"config": {k: str(v) if isinstance(v, list) else v for k, v in vars(config).items()},
"results": {str(k): v for k, v in results.items()}
},
f,
indent=2
)
print(f"Results saved to {output_file}")
except Exception as e:
print(f"Benchmark failed: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()

View File

@@ -1,376 +0,0 @@
import argparse
import time
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
from torch import nn
from transformers import AutoModel
from tqdm import tqdm
from contextlib import contextmanager
import math
@dataclass
class BenchmarkConfig:
model_path: str
batch_sizes: List[int]
seq_length: int
num_runs: int
use_fp16: bool = True
use_cuda_graphs: bool = False
use_flash_attention: bool = False
max_batch_size: int = 256 # Maximum batch size before splitting
class CUDAGraphContainer:
"""Container for managing CUDA graphs for different batch sizes."""
def __init__(self, model: nn.Module, seq_length: int, max_batch_size: int):
self.model = model
self.seq_length = seq_length
self.max_batch_size = max_batch_size
self.graphs: Dict[int, CUDAGraphWrapper] = {}
def get_or_create(self, batch_size: int) -> 'CUDAGraphWrapper':
# For CUDA graphs, we always use the actual batch size or max_batch_size
effective_batch_size = min(batch_size, self.max_batch_size)
if effective_batch_size not in self.graphs:
self.graphs[effective_batch_size] = CUDAGraphWrapper(
self.model, effective_batch_size, self.seq_length
)
return self.graphs[effective_batch_size]
class CUDAGraphWrapper:
"""Wrapper for CUDA graph capture and replay."""
def __init__(self, model: nn.Module, batch_size: int, seq_length: int):
self.model = model
self.static_input = self._create_random_batch(batch_size, seq_length)
self.static_attention_mask = torch.ones_like(self.static_input)
# Warm up
self._warmup()
# Capture graph
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph):
self.static_output = self.model(
input_ids=self.static_input,
attention_mask=self.static_attention_mask
)
def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
return torch.randint(
0, 1000, (batch_size, seq_length),
device="cuda",
dtype=torch.long
)
def _warmup(self, num_warmup: int = 3):
with torch.no_grad():
for _ in range(num_warmup):
self.model(
input_ids=self.static_input,
attention_mask=self.static_attention_mask
)
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
self.static_input.copy_(input_ids)
self.static_attention_mask.copy_(attention_mask)
self.graph.replay()
return self.static_output
class ModelOptimizer:
"""Applies various optimizations to the model."""
@staticmethod
def optimize(model: nn.Module, config: BenchmarkConfig) -> nn.Module:
print("\nApplying model optimizations:")
# Move to GPU
model = model.cuda()
print("- Model moved to GPU")
# FP16
if config.use_fp16:
model = model.half()
print("- Using FP16 precision")
# Check if using SDPA
if torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
# No need to do anything as it's automatically enabled
else:
print("- PyTorch SDPA not available")
# Flash Attention
if config.use_flash_attention:
try:
from flash_attn.flash_attention import FlashAttention
print("- Flash Attention 2 available")
if hasattr(model.config, "attention_mode"):
model.config.attention_mode = "flash_attention_2"
print(" - Enabled Flash Attention 2 mode")
except ImportError:
print("- Flash Attention not available")
# Optimize LayerNorm
try:
num_layernorms = 0
for module in model.modules():
if isinstance(module, torch.nn.LayerNorm):
module.forward = torch.jit.script(module.forward)
num_layernorms += 1
if num_layernorms > 0:
print(f"- Optimized {num_layernorms} LayerNorm modules with TorchScript")
except Exception as e:
print(f"- LayerNorm optimization failed: {e}")
# Memory efficient attention
try:
from xformers.ops import memory_efficient_attention
model.enable_xformers_memory_efficient_attention()
print("- Enabled xformers memory efficient attention")
except (ImportError, AttributeError):
print("- Xformers not available")
model.eval()
print("- Model set to eval mode")
return model
class Timer:
"""Handles accurate GPU timing using CUDA events."""
def __init__(self):
self.start_event = torch.cuda.Event(enable_timing=True)
self.end_event = torch.cuda.Event(enable_timing=True)
@contextmanager
def timing(self):
self.start_event.record()
yield
self.end_event.record()
self.end_event.synchronize()
def elapsed_time(self) -> float:
return self.start_event.elapsed_time(self.end_event) / 1000 # ms to seconds
class Benchmark:
"""Main benchmark runner."""
def __init__(self, config: BenchmarkConfig):
self.config = config
self.model = self._load_model()
self.cuda_graphs = (
CUDAGraphContainer(self.model, config.seq_length, config.max_batch_size)
if config.use_cuda_graphs
else None
)
self.timer = Timer()
def _load_model(self) -> nn.Module:
print(f"Loading model from {self.config.model_path}...")
model = AutoModel.from_pretrained(self.config.model_path)
return ModelOptimizer.optimize(model, self.config)
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
return torch.randint(
0, 1000,
(batch_size, self.config.seq_length),
device="cuda",
dtype=torch.long
)
def _run_inference(
self,
input_ids: torch.Tensor,
cuda_graph_wrapper: Optional[CUDAGraphWrapper] = None
) -> Tuple[float, torch.Tensor]:
attention_mask = torch.ones_like(input_ids)
original_batch_size = input_ids.shape[0]
print(f"Original input_ids shape: {input_ids.shape}")
# Split large batches to avoid OOM
max_batch_size = self.config.max_batch_size
if original_batch_size > max_batch_size:
print(f"Splitting batch of size {original_batch_size} into chunks of {max_batch_size}")
total_time = 0
outputs = []
with torch.no_grad():
for i in range(0, original_batch_size, max_batch_size):
end_idx = min(i + max_batch_size, original_batch_size)
batch_slice = input_ids[i:end_idx]
mask_slice = attention_mask[i:end_idx]
print(f"Processing chunk {i//max_batch_size + 1}: shape {batch_slice.shape}")
# Use CUDA graph if available (with the smaller batch size)
chunk_cuda_graph = None
if cuda_graph_wrapper is not None:
chunk_cuda_graph = self.cuda_graphs.get_or_create(batch_slice.shape[0])
with self.timer.timing():
if chunk_cuda_graph is not None:
chunk_output = chunk_cuda_graph(batch_slice, mask_slice)
else:
chunk_output = self.model(input_ids=batch_slice, attention_mask=mask_slice)
total_time += self.timer.elapsed_time()
outputs.append(chunk_output.last_hidden_state)
# Combine outputs
combined_output = torch.cat(outputs, dim=0)
print(f"Combined output shape: {combined_output.shape}")
# Create a wrapper object similar to model output to maintain consistency
class DummyOutput:
def __init__(self, hidden_states):
self.last_hidden_state = hidden_states
output = DummyOutput(combined_output)
return total_time, output
else:
# Process normally for small batches
with torch.no_grad(), self.timer.timing():
if cuda_graph_wrapper is not None:
output = cuda_graph_wrapper(input_ids, attention_mask)
else:
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
print(f"Output shape: {output.last_hidden_state.shape}")
return self.timer.elapsed_time(), output
def run(self) -> Dict[int, Dict[str, float]]:
results = {}
for batch_size in self.config.batch_sizes:
print(f"\nTesting batch size: {batch_size}")
times = []
# Get or create CUDA graph for this batch size
cuda_graph_wrapper = None
if self.cuda_graphs is not None:
if batch_size <= self.config.max_batch_size:
cuda_graph_wrapper = self.cuda_graphs.get_or_create(batch_size)
else:
# For large batches, we'll use the max_batch_size graph in chunks
cuda_graph_wrapper = True # Just a flag to indicate we want to use CUDA graphs
# Pre-allocate input tensor
input_ids = self._create_random_batch(batch_size)
# Run benchmark
for run_idx in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
elapsed_time, _ = self._run_inference(input_ids, cuda_graph_wrapper)
times.append(elapsed_time)
print(f"Run {run_idx+1}: {elapsed_time:.4f}s")
# Calculate statistics
avg_time = np.mean(times)
std_time = np.std(times)
throughput = batch_size / avg_time
results[batch_size] = {
"avg_time": avg_time,
"std_time": std_time,
"throughput": throughput,
}
print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
print(f"Throughput: {throughput:.2f} sequences/second")
return results
def main():
parser = argparse.ArgumentParser(description="Model Inference Benchmark")
parser.add_argument(
"--model_path",
type=str,
default="facebook/contriever",
help="Path to the model",
)
parser.add_argument(
"--batch_sizes",
type=str,
default="1,2,4,8,16,32,64,128,256,512,1024,2048,4096",
help="Comma-separated list of batch sizes",
)
parser.add_argument(
"--seq_length",
type=int,
default=256,
help="Sequence length for input",
)
parser.add_argument(
"--num_runs",
type=int,
default=5,
help="Number of runs for each batch size",
)
parser.add_argument(
"--no_fp16",
action="store_true",
help="Disable FP16 inference",
)
parser.add_argument(
"--use_cuda_graphs",
action="store_true",
help="Enable CUDA Graphs optimization",
)
parser.add_argument(
"--use_flash_attention",
action="store_true",
help="Enable Flash Attention 2 if available",
)
parser.add_argument(
"--max_batch_size",
type=int,
default=256,
help="Maximum batch size before splitting to prevent OOM",
)
args = parser.parse_args()
config = BenchmarkConfig(
model_path=args.model_path,
batch_sizes=[int(bs) for bs in args.batch_sizes.split(",")],
seq_length=args.seq_length,
num_runs=args.num_runs,
use_fp16=not args.no_fp16,
use_cuda_graphs=args.use_cuda_graphs,
use_flash_attention=args.use_flash_attention,
max_batch_size=args.max_batch_size,
)
benchmark = Benchmark(config)
results = benchmark.run()
# Print overall summary
print("\n===== BENCHMARK SUMMARY =====")
print(f"Model: {config.model_path}")
print(f"Sequence Length: {config.seq_length}")
print(f"FP16: {config.use_fp16}")
print(f"CUDA Graphs: {config.use_cuda_graphs}")
print(f"Flash Attention: {config.use_flash_attention}")
print(f"Max Batch Size: {config.max_batch_size}")
print("\nResults:")
print("\nBatch Size | Avg Time (s) | Throughput (seq/s)")
print("-" * 50)
for bs in sorted(results.keys()):
r = results[bs]
print(f"{bs:^10} | {r['avg_time']:^12.4f} | {r['throughput']:^17.2f}")
if __name__ == "__main__":
main()

View File

@@ -1,218 +0,0 @@
import torch
import torch.nn as nn
import time
import torch.nn.functional as F
# Import necessary functions from the quantize.py file
def get_group_qparams(w, n_bit=4, groupsize=128):
# needed for GPTQ with padding
if groupsize > w.shape[-1]:
groupsize = w.shape[-1]
assert groupsize > 1
assert w.shape[-1] % groupsize == 0
assert w.dim() == 2
to_quant = w.reshape(-1, groupsize)
assert torch.isnan(to_quant).sum() == 0
max_val = to_quant.amax(dim=1, keepdim=True)
min_val = to_quant.amin(dim=1, keepdim=True)
max_int = 2**n_bit - 1
scales = (max_val - min_val).clamp(min=1e-6) / max_int
zeros = min_val + scales * (2 ** (n_bit - 1))
return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(
torch.bfloat16
).reshape(w.shape[0], -1)
def pack_scales_and_zeros(scales, zeros):
assert scales.shape == zeros.shape
assert scales.dtype == torch.bfloat16
assert zeros.dtype == torch.bfloat16
return (
torch.cat(
[
scales.reshape(scales.size(0), scales.size(1), 1),
zeros.reshape(zeros.size(0), zeros.size(1), 1),
],
2,
)
.transpose(0, 1)
.contiguous()
)
def group_quantize_tensor(w, n_bit=4, groupsize=128):
scales, zeros = get_group_qparams(w, n_bit, groupsize)
w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
scales_and_zeros = pack_scales_and_zeros(scales, zeros)
return w_int32, scales_and_zeros
def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
assert groupsize > 1
# needed for GPTQ single column quantize
if groupsize > w.shape[-1] and scales.shape[-1] == 1:
groupsize = w.shape[-1]
assert w.shape[-1] % groupsize == 0
assert w.dim() == 2
to_quant = w.reshape(-1, groupsize)
assert torch.isnan(to_quant).sum() == 0
scales = scales.reshape(-1, 1)
zeros = zeros.reshape(-1, 1)
min_val = zeros - scales * (2 ** (n_bit - 1))
max_int = 2**n_bit - 1
min_int = 0
w_int32 = (
to_quant.sub(min_val)
.div(scales)
.round()
.clamp_(min_int, max_int)
.to(torch.int32)
.reshape_as(w)
)
return w_int32
def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
weight_int32, scales_and_zeros = group_quantize_tensor(
weight_bf16, n_bit=4, groupsize=groupsize
)
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
return weight_int4pack, scales_and_zeros
def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
origin_x_size = x.size()
x = x.reshape(-1, origin_x_size[-1])
c = torch.ops.aten._weight_int4pack_mm(x, weight_int4pack, groupsize, scales_and_zeros)
new_shape = origin_x_size[:-1] + (out_features,)
c = c.reshape(new_shape)
return c
class WeightOnlyInt4Linear(torch.nn.Module):
__constants__ = ['in_features', 'out_features']
in_features: int
out_features: int
weight: torch.Tensor
def __init__(
self, in_features: int, out_features: int,
bias=False, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8
) -> None:
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.groupsize = groupsize
self.inner_k_tiles = inner_k_tiles
assert out_features % 8 == 0, "require out_features % 8 == 0"
assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
self.register_buffer(
"weight",
torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
)
self.register_buffer(
"scales_and_zeros",
torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16)
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
input = input.to(torch.bfloat16)
return linear_forward_int4(
input,
self.weight, self.scales_and_zeros, self.out_features, self.groupsize
)
# Define dimensions that satisfy the requirements for INT4 quantization
# in_features must be divisible by inner_k_tiles * 16
# out_features must be divisible by 8
in_features = 1024 # Must be divisible by inner_k_tiles * 16
out_features = 2048 # Must be divisible by 8
groupsize = 128
inner_k_tiles = 8
# Create models
fp16_model = nn.Sequential(
nn.Linear(in_features, out_features, bias=False)
)
# Create INT4 model
int4_model = nn.Sequential(
WeightOnlyInt4Linear(in_features, out_features, bias=False,
groupsize=groupsize, inner_k_tiles=inner_k_tiles)
)
# Quantize the weights and set up the INT4 model
with torch.no_grad():
# Convert FP16 weights to INT4
fp16_weight = fp16_model[0].weight.data.to(torch.bfloat16)
weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros(
fp16_weight, groupsize, inner_k_tiles
)
# Set the quantized weights in the INT4 model
int4_model[0].weight.copy_(weight_int4pack)
int4_model[0].scales_and_zeros.copy_(scales_and_zeros)
# Move models to GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fp16_model = fp16_model.to(device)
int4_model = int4_model.to(device)
# Create random input tensor
batch_size = 1024
input_tensor = torch.randn(batch_size, in_features, device=device)
input_tensor_bf16 = input_tensor.to(torch.bfloat16)
# Speed test function
def speed_test(model, input_tensor, name, num_iterations=100):
# Warmup
for _ in range(10):
_ = model(input_tensor)
# Actual timing
torch.cuda.synchronize()
start_time = time.time()
for _ in range(num_iterations):
_ = model(input_tensor)
torch.cuda.synchronize()
end_time = time.time()
avg_time = (end_time - start_time) / num_iterations
print(f"{name} model: {avg_time:.6f} seconds per iteration")
return avg_time
# Run speed tests
with torch.no_grad(): # Disable gradient calculation for inference
print(f"Running benchmark with batch_size={batch_size}, in_features={in_features}, out_features={out_features}")
print(f"INT4 parameters: groupsize={groupsize}, inner_k_tiles={inner_k_tiles}")
fp16_time = speed_test(fp16_model, input_tensor_bf16, "FP16")
int4_time = speed_test(int4_model, input_tensor, "INT4")
# Calculate speedup
speedup = fp16_time / int4_time
print(f"INT4 is {speedup:.2f}x faster than FP16")
# Calculate memory savings
fp16_memory = fp16_model[0].weight.nelement() * fp16_model[0].weight.element_size()
int4_memory = (int4_model[0].weight.nelement() * int4_model[0].weight.element_size() +
int4_model[0].scales_and_zeros.nelement() * int4_model[0].scales_and_zeros.element_size())
memory_reduction = fp16_memory / int4_memory
print(f"Memory reduction: {memory_reduction:.2f}x ({fp16_memory/1024/1024:.2f} MB vs {int4_memory/1024/1024:.2f} MB)")
# Check accuracy
with torch.no_grad():
fp16_output = fp16_model(input_tensor_bf16)
int4_output = int4_model(input_tensor)
# Calculate error metrics
abs_error = torch.abs(fp16_output - int4_output)
rel_error = abs_error / (torch.abs(fp16_output) + 1e-7)
print(f"Mean absolute error: {abs_error.mean().item():.6f}")
print(f"Max absolute error: {abs_error.max().item():.6f}")
print(f"Mean relative error: {rel_error.mean().item():.6f}")

View File

@@ -1,83 +0,0 @@
import torch
import nvmath.bindings.cublas
import ctypes
# 创建 CUBLAS 句柄
handle = nvmath.bindings.cublas.create()
# 准备数据 - 使用 uint8 类型,并确保内存连续
m, n, k = 64, 32, 48
a = (torch.rand(m, k, device="cuda") * 255).to(torch.uint8).contiguous()
b = (torch.rand(k, n, device="cuda") * 255).to(torch.uint8).contiguous()
c = torch.zeros(m, n, device="cuda", dtype=torch.uint8).contiguous()
# 确保张量在 CUDA 上
assert a.is_cuda and b.is_cuda and c.is_cuda
# 确保张量是连续的
assert a.is_contiguous() and b.is_contiguous() and c.is_contiguous()
# 获取指针
a_ptr = a.data_ptr()
b_ptr = b.data_ptr()
c_ptr = c.data_ptr()
# 设置参数
transa = 0 # CUBLAS_OP_N (不转置)
transb = 0 # CUBLAS_OP_N (不转置)
transc = 0 # CUBLAS_OP_N (不转置)
# 设置偏置值
a_bias = 0
b_bias = 0
c_bias = 0
# 设置正确的 leading dimensions
lda = k # A 的 leading dimension
ldb = n # B 的 leading dimension
ldc = n # C 的 leading dimension
c_mult = 1
c_shift = 0
# 打印调试信息
print(f"a shape: {a.shape}, a_ptr: {a_ptr}")
print(f"b shape: {b.shape}, b_ptr: {b_ptr}")
print(f"c shape: {c.shape}, c_ptr: {c_ptr}")
try:
# 调用 uint8gemm_bias
nvmath.bindings.cublas.uint8gemm_bias(
handle,
transa, transb, transc,
m, n, k,
a_ptr, a_bias, lda,
b_ptr, b_bias, ldb,
c_ptr, c_bias, ldc,
c_mult, c_shift
)
except Exception as e:
print(f"Error: {e}")
# 尝试使用 ctypes 转换指针
a_ptr_c = ctypes.c_void_p(a_ptr).value
b_ptr_c = ctypes.c_void_p(b_ptr).value
c_ptr_c = ctypes.c_void_p(c_ptr).value
print(f"Using ctypes: a_ptr: {a_ptr_c}, b_ptr: {b_ptr_c}, c_ptr: {c_ptr_c}")
# 再次尝试调用
nvmath.bindings.cublas.uint8gemm_bias(
handle,
transa, transb, transc,
m, n, k,
a_ptr_c, a_bias, lda,
b_ptr_c, b_bias, ldb,
c_ptr_c, c_bias, ldc,
c_mult, c_shift
)
# 销毁 CUBLAS 句柄
nvmath.bindings.cublas.destroy(handle)
# 打印结果
print("Result:")
print(c)

View File

@@ -1,23 +0,0 @@
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor import oneshot
# Select quantization algorithm. In this case, we:
# * apply SmoothQuant to make the activations easier to quantize
# * quantize the weights to int8 with GPTQ (static per channel)
# * quantize the activations to int8 (dynamic per token)
recipe = [
SmoothQuantModifier(smoothing_strength=0.8),
GPTQModifier(scheme="W8A8", targets="Linear", ignore=["lm_head"]),
]
# Apply quantization using the built in open_platypus dataset.
# * See examples for demos showing how to pass a custom calibration set
oneshot(
model="facebook/contriever",
dataset="open_platypus",
recipe=recipe,
output_dir="contriever-INT4",
max_seq_length=2048,
num_calibration_samples=512,
)

View File

@@ -1,41 +0,0 @@
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
# SPDX-License-Identifier: Apache-2.0
"""
This example demonstrates basic matrix multiplication of FP8 tensors.
In narrow-precision operations, quantization scales must be provided for each tensor. These
scales are used to dequantize input operands and quantize the result. Without proper
scaling, the results of FP8 operations will likely exceed the type's range.
FP8 is only supported with cuBLAS 12.8 or newer and on devices with compute
capability 8.9 or higher.
"""
import torch
import nvmath
# Prepare sample input data. Note that N, M and K must be divisible by 16 for FP8.
# cuBLAS requires B to be column-major, so we first create a row-major tensor and then
# transpose it.
m, n, k = 64, 32, 48
a = (torch.rand(m, k, device="cuda") * 10).type(torch.float8_e4m3fn)
b = (torch.rand(n, k, device="cuda") * 10).type(torch.float8_e4m3fn).T
# Prepare quantization scales. The scales must allow the result to fit within the dynamic
# range of the data type used. Scales can be provided either as a dictionary or as a
# MatmulQuantizationScales object. Note that scales are only allowed for FP8 operands.
scales = {"a": 1, "b": 1, "d": 0.1}
# Perform the multiplication. The result of the multiplication will be:
# (scales.a * A) @ (scales.b * B) * scales.d
result = nvmath.linalg.advanced.matmul(a, b, quantization_scales=scales)
# Check how scaling helped to fit into the dynamic range of float8_e4m3fn type.
result_without_scaling = nvmath.linalg.advanced.matmul(a, b, quantization_scales={"a": 1, "b": 1, "d": 1})
print("Without scaling, most of the elements were clamped to the maximum value of float8_e4m3fn type (448):")
print(result_without_scaling)
print(f"\nWith D scale set to {scales['d']}, they were scaled down to fit into the dynamic range of float8_e4m3fn:")
print(result)

View File

View File

@@ -1,58 +0,0 @@
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from pathlib import Path
def save_model_in_pth_format(model_name, output_dir):
"""
Download a model from Hugging Face and save it in PTH format
for use with quantization benchmarks.
Args:
model_name: Name of the model on Hugging Face
output_dir: Directory to save the model
"""
print(f"Loading model {model_name}...")
# Create output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)
# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True
)
# Save tokenizer
tokenizer.save_pretrained(output_dir)
# Extract and save the model weights in PTH format
model_state_dict = model.state_dict()
# Save the model weights
model_path = Path(output_dir) / "model.pth"
torch.save(model_state_dict, model_path)
print(f"Model saved to {model_path}")
# Print model size information
param_count = sum(p.numel() for p in model.parameters())
model_size_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 * 1024)
print(f"Model parameters: {param_count:,}")
print(f"Model size: {model_size_mb:.2f} MB")
return model_path
if __name__ == "__main__":
# Use a small model for testing
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
output_dir = "./tinyllama-1.1b-chat"
model_path = save_model_in_pth_format(model_name, output_dir)
print("\nYou can now use this model with the INT4 benchmark script.")
print("Example command:")
print(f"python int4benchmark.py --model_path {model_path}")

View File

@@ -1,677 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "cab91cfc",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ubuntu/Power-RAG/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import copy\n",
"import dataclasses\n",
"import os\n",
"import time\n",
"import pathlib\n",
"import itertools\n",
"import multiprocessing\n",
"import scipy\n",
"import numpy as np\n",
"import pandas as pd\n",
"import pickle\n",
"import gzip\n",
"import threading\n",
"import queue\n",
"import pytz\n",
"import traceback\n",
"from datetime import datetime\n",
"from tqdm.auto import tqdm, trange\n",
"from typing import Any\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.ticker as mtick\n",
"%matplotlib inline\n",
"%config InlineBackend.figure_format='retina'"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "8d24fbd7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sat Apr 12 00:10:05 2025 \n",
"+-----------------------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 550.120 Driver Version: 550.120 CUDA Version: 12.4 |\n",
"|-----------------------------------------+------------------------+----------------------+\n",
"| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|=========================================+========================+======================|\n",
"| 0 NVIDIA A10G Off | 00000000:00:1E.0 Off | 0 |\n",
"| 0% 27C P8 15W / 300W | 4MiB / 23028MiB | 0% Default |\n",
"| | | N/A |\n",
"+-----------------------------------------+------------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=========================================================================================|\n",
"| No running processes found |\n",
"+-----------------------------------------------------------------------------------------+\n"
]
}
],
"source": [
"!nvidia-smi"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "538b2c11",
"metadata": {},
"outputs": [],
"source": [
"def benchmark(f, *, f_setup=None, min_repeat: int, min_secs: float, tqdm_kwargs: dict | None=None) -> np.ndarray:\n",
" latency = []\n",
" \n",
" # First run, ignore min_secs\n",
" if f_setup is not None:\n",
" f_setup()\n",
" st = time.perf_counter_ns()\n",
" f()\n",
" ed = time.perf_counter_ns()\n",
" latency.append((ed-st)/1e9)\n",
" \n",
" # Subsequent runs, until reaching both min_repeat and min_secs\n",
" min_nanos = int(min_secs * 1e9)\n",
" start_nanos = time.perf_counter_ns()\n",
" while True:\n",
" now_nanos = time.perf_counter_ns()\n",
" if len(latency) > min_repeat and now_nanos - start_nanos > min_nanos:\n",
" break\n",
" if f_setup is not None:\n",
" f_setup()\n",
" st = time.perf_counter_ns()\n",
" f()\n",
" ed = time.perf_counter_ns()\n",
" latency.append((ed-st)/1e9)\n",
" return np.array(latency)\n",
"\n",
"def tail_mean(xs, skip=0.2):\n",
" return xs[int(len(xs) * skip):].mean()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "02c9c9b1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch.autograd.grad_mode.set_grad_enabled at 0x7c5afc12b850>"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch\n",
"torch.set_grad_enabled(False)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "3405fdc7",
"metadata": {},
"outputs": [],
"source": [
"nd_list = list(itertools.chain(itertools.product([12, 3], [256])))\n",
"seqlen_list = [256]\n",
"bs_list = [2,4,8,16,32,64,128,256,512,1024,2048]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "10dc981a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[(12, 256), (3, 256)]\n",
"[256]\n",
"[2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]\n"
]
}
],
"source": [
"print(nd_list)\n",
"print(seqlen_list)\n",
"print(bs_list)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "7e0ee385",
"metadata": {},
"outputs": [],
"source": [
"def benchmark_dense(out, nd_list, seqlen_list, bs_list):\n",
" seqlen_list = [1] + seqlen_list\n",
" total = len(list(itertools.product(nd_list, seqlen_list, bs_list)))\n",
" pbar = tqdm(total=total)\n",
" for (n, d), seqlen in reversed(list(itertools.product(nd_list, seqlen_list))):\n",
" h = n * d\n",
" maxbs = max(bs_list)\n",
" print(maxbs, n, d, seqlen)\n",
" cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=\"cuda:0\")\n",
" X = torch.rand((maxbs, seqlen, h), dtype=torch.bfloat16, device=\"cuda:0\")\n",
" W = torch.rand((h, h), dtype=torch.bfloat16, device=\"cuda:0\")\n",
" torch.cuda.synchronize()\n",
" for bs in reversed(bs_list):\n",
" pbar.set_postfix(n=n, h=h, d=d, seqlen=seqlen, bs=bs)\n",
" def run():\n",
" torch.matmul(X[:bs], W)\n",
" torch.cuda.synchronize()\n",
" def clear_cache():\n",
" cache.zero_()\n",
" torch.cuda.synchronize()\n",
" latency = benchmark(run, f_setup=clear_cache, min_repeat=20, min_secs=2)\n",
" l = tail_mean(latency)\n",
" out.append({\n",
" \"n\": n,\n",
" \"d\": d,\n",
" \"seqlen\": seqlen,\n",
" \"bs\": bs,\n",
" \"latency\": l\n",
" })\n",
" pbar.update()\n",
" del cache, X, W\n",
" torch.cuda.empty_cache()\n",
" pbar.close()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "c206a502",
"metadata": {},
"outputs": [],
"source": [
"def benchmark_qk_init(out, nd_list, seqlen_list, bs_list):\n",
" total = len(list(itertools.product(nd_list, seqlen_list, bs_list)))\n",
" pbar = tqdm(total=total)\n",
" for (n, d), seqlen in reversed(list(itertools.product(nd_list, seqlen_list))):\n",
" h = n * d\n",
" try:\n",
" maxbs = max(b for b in bs_list if b*n*seqlen*d*2*2+b*n*seqlen**2*2 < 80e9)\n",
" except ValueError:\n",
" pbar.update(len(bs_list))\n",
" continue\n",
" cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=\"cuda:0\")\n",
" Qmax = torch.rand((maxbs, n, seqlen, d), dtype=torch.bfloat16, device=\"cuda:0\")\n",
" Kmax = torch.rand((maxbs, n, seqlen, d), dtype=torch.bfloat16, device=\"cuda:0\")\n",
" torch.cuda.synchronize()\n",
" for bs in reversed(bs_list):\n",
" pbar.set_postfix(n=n, h=h, d=d, seqlen=seqlen, bs=bs)\n",
" if bs > maxbs:\n",
" pbar.update()\n",
" continue\n",
" Q = Qmax[:bs]\n",
" K = Kmax[:bs]\n",
" def run():\n",
" torch.bmm(Q.view(bs * n, seqlen, d), K.view(bs * n, seqlen, d).transpose(1, 2))\n",
" torch.cuda.synchronize()\n",
" def clear_cache():\n",
" cache.zero_()\n",
" torch.cuda.synchronize()\n",
" latency = benchmark(run, f_setup=clear_cache, min_repeat=20, min_secs=2)\n",
" l = tail_mean(latency)\n",
" out.append({\n",
" \"n\": n,\n",
" \"d\": d,\n",
" \"seqlen\": seqlen,\n",
" \"bs\": bs,\n",
" \"latency\": l\n",
" })\n",
" pbar.update()\n",
" del cache, Q, K, Qmax, Kmax\n",
" torch.cuda.empty_cache()\n",
" pbar.close()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "a3a2103c",
"metadata": {},
"outputs": [],
"source": [
"def benchmark_qk_ar(out, nd_list, seqlen_list, bs_list):\n",
" total = len(list(itertools.product(nd_list, seqlen_list, bs_list)))\n",
" pbar = tqdm(total=total)\n",
" for (n, d), seqlen in reversed(list(itertools.product(nd_list, seqlen_list))):\n",
" h = n * d\n",
" try:\n",
" maxbs = max(b for b in bs_list if b*n*(1+seqlen)*d*2+b*n*seqlen*2 < 80e9)\n",
" except ValueError:\n",
" pbar.update(len(bs_list))\n",
" continue\n",
" cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=\"cuda:0\")\n",
" Qmax = torch.rand((maxbs, n, 1, d), dtype=torch.bfloat16, device=\"cuda:0\")\n",
" Kmax = torch.rand((maxbs, n, seqlen, d), dtype=torch.bfloat16, device=\"cuda:0\")\n",
" torch.cuda.synchronize()\n",
" for bs in reversed(bs_list):\n",
" pbar.set_postfix(n=n, h=h, d=d, seqlen=seqlen, bs=bs)\n",
" if bs > maxbs:\n",
" pbar.update()\n",
" continue\n",
" Q = Qmax[:bs]\n",
" K = Kmax[:bs]\n",
" def run():\n",
" torch.bmm(Q.view(bs * n, 1, d), K.view(bs * n, seqlen, d).transpose(1, 2))\n",
" torch.cuda.synchronize()\n",
" def clear_cache():\n",
" cache.zero_()\n",
" torch.cuda.synchronize()\n",
" latency = benchmark(run, f_setup=clear_cache, min_repeat=20, min_secs=2)\n",
" l = tail_mean(latency)\n",
" out.append({\n",
" \"n\": n,\n",
" \"d\": d,\n",
" \"seqlen\": seqlen,\n",
" \"bs\": bs,\n",
" \"latency\": l\n",
" })\n",
" pbar.update()\n",
" del cache, Q, K, Qmax, Kmax\n",
" torch.cuda.empty_cache()\n",
" pbar.close()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "3aaad98a",
"metadata": {},
"outputs": [],
"source": [
"data = {}"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "18137de3",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/22 [00:00<?, ?it/s]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 22/22 [00:44<00:00, 2.04s/it, bs=2, d=256, h=3072, n=12, seqlen=256] \n"
]
}
],
"source": [
"db = []\n",
"benchmark_qk_init(db, nd_list, seqlen_list, bs_list)\n",
"data[\"qk_init\"] = db"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "26c76e15",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 22/22 [00:44<00:00, 2.01s/it, bs=2, d=256, h=3072, n=12, seqlen=256] \n"
]
}
],
"source": [
"db = []\n",
"benchmark_qk_ar(db, nd_list, seqlen_list, bs_list)\n",
"data[\"qk_ar\"] = db"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "313e36eb",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/44 [00:00<?, ?it/s, bs=2048, d=256, h=768, n=3, seqlen=256]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2048 3 256 256\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 25%|██▌ | 11/44 [00:22<01:06, 2.00s/it, bs=2048, d=256, h=768, n=3, seqlen=1] "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2048 3 256 1\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 50%|█████ | 22/44 [00:44<00:44, 2.00s/it, bs=2048, d=256, h=3072, n=12, seqlen=256]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2048 12 256 256\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 75%|███████▌ | 33/44 [01:07<00:22, 2.02s/it, bs=2048, d=256, h=3072, n=12, seqlen=1] "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2048 12 256 1\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 44/44 [01:29<00:00, 2.03s/it, bs=2, d=256, h=3072, n=12, seqlen=1] \n"
]
}
],
"source": [
"db = []\n",
"benchmark_dense(db, nd_list, seqlen_list, bs_list)\n",
"data[\"dense\"] = db"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "50c37959",
"metadata": {},
"outputs": [],
"source": [
"with gzip.open(\"data/20230516-transformer-batching1.pkl.gz\", \"wb\") as f:\n",
" pickle.dump(data, f)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "828ddb54",
"metadata": {},
"outputs": [],
"source": [
"df_dense = (\n",
" pd.DataFrame.from_dict(data[\"dense\"])\n",
" .assign(h=lambda x: x[\"n\"] * x[\"d\"])\n",
" .assign(flop=lambda x: (x[\"bs\"] * x[\"seqlen\"] * x[\"h\"]**2) * 2)\n",
" .assign(io=lambda x: (x[\"bs\"]*x[\"seqlen\"]*x[\"h\"]*2 + x[\"h\"]**2) * 2/x['latency']/1e9)\n",
" .assign(intensity=lambda x: x[\"flop\"] / x[\"io\"])\n",
" .assign(throughput=lambda x: x[\"flop\"] / x[\"latency\"])\n",
" .assign(series=\"dense\")\n",
")\n",
"df_qk_init = (\n",
" pd.DataFrame.from_dict(data[\"qk_init\"])\n",
" .assign(h=lambda x: x[\"n\"] * x[\"d\"])\n",
" .assign(flop=lambda x: (x[\"bs\"]*x[\"n\"]*x[\"d\"]*x[\"seqlen\"]**2) * 2)\n",
" .assign(io=lambda x: (x[\"bs\"]*x[\"n\"]*(x[\"seqlen\"]*x[\"d\"]*2 + x[\"seqlen\"]**2)) * 2/x['latency']/1e9)\n",
" .assign(intensity=lambda x: x[\"flop\"] / x[\"io\"])\n",
" .assign(throughput=lambda x: x[\"flop\"] / x[\"latency\"])\n",
" .assign(series=\"qk_init\")\n",
")\n",
"df_qk_ar = (\n",
" pd.DataFrame.from_dict(data[\"qk_ar\"])\n",
" .assign(h=lambda x: x[\"n\"] * x[\"d\"])\n",
" .assign(flop=lambda x: (x[\"bs\"]*x[\"n\"]*x[\"d\"]*x[\"seqlen\"]) * 2)\n",
" .assign(io=lambda x: (x[\"bs\"]*x[\"n\"]*(x[\"d\"] + x[\"seqlen\"]*x[\"d\"] + x[\"seqlen\"])) * 2)\n",
" .assign(intensity=lambda x: x[\"flop\"] / x[\"io\"])\n",
" .assign(throughput=lambda x: x[\"bs\"] / x[\"latency\"])\n",
" .assign(series=\"qk_ar\")\n",
")\n",
"pd.concat([df_dense, df_qk_init, df_qk_ar]).to_csv(\"data/transformer-batching-microbenchmarks.csv\", index=False)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "c296a395",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<module 'pandas' from '/home/ubuntu/Power-RAG/.venv/lib/python3.10/site-packages/pandas/__init__.py'>"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a25cdd5a",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "63b8a531",
"metadata": {},
"outputs": [],
"source": [
"import transformers"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "af90eff1",
"metadata": {},
"outputs": [],
"source": [
"def _gen_opt_cfg(n_layers: int, d_model: int, n_heads: int, **kwargs) -> transformers.OPTConfig:\n",
" return transformers.OPTConfig(\n",
" num_hidden_layers=n_layers,\n",
" hidden_size=d_model,\n",
" ffn_dim=d_model*4,\n",
" num_attention_heads=n_heads,\n",
" **kwargs\n",
" )\n",
"optcfg = {\n",
" # https://arxiv.org/pdf/2205.01068.pdf Table 2.1\n",
" \"125m\": _gen_opt_cfg(12, 768, 12),\n",
" \"350m\": _gen_opt_cfg(24, 1024, 16),\n",
" \"760m\": _gen_opt_cfg(24, 1536, 16),\n",
" \"1.3b\": _gen_opt_cfg(24, 2048, 32),\n",
" \"2.7b\": _gen_opt_cfg(32, 2560, 32),\n",
" \"6.7b\": _gen_opt_cfg(32, 4096, 32),\n",
" \"13b\": _gen_opt_cfg(40, 5120, 40),\n",
" \"13b_1layer\": _gen_opt_cfg(1, 5120, 40),\n",
" \"30b\": _gen_opt_cfg(48, 7168, 56),\n",
" \"66b\": _gen_opt_cfg(64, 9216, 72),\n",
" \"175b\": _gen_opt_cfg(96, 12288, 96),\n",
" \"175b_1layer\": _gen_opt_cfg(1, 12288, 96),\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5b9ebbec",
"metadata": {},
"outputs": [],
"source": [
"def greedy_sample_one(model, input_ids, attention_mask=None, past_key_values=None):\n",
" bs, tgt_len = input_ids.shape\n",
" if past_key_values is not None:\n",
" _bs, _num_heads, src_len, _head_dims = past_key_values[0][0].shape\n",
" assert bs == _bs\n",
" else:\n",
" src_len = 0\n",
" if attention_mask is None:\n",
" attention_mask = torch.ones((bs, src_len + tgt_len), device=model.device)\n",
" ret = model(\n",
" input_ids=input_ids,\n",
" attention_mask=attention_mask,\n",
" past_key_values=past_key_values,\n",
" use_cache=True, output_hidden_states=False, return_dict=True,\n",
" )\n",
" return ret\n",
"\n",
"def time_greedy_generate(model, input_ids, new_tokens):\n",
" ts = []\n",
" output = input_ids\n",
" past_key_values = None\n",
" cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=model.device)\n",
" attention_mask = torch.ones(input_ids.shape, device=model.device) \n",
" for _ in range(new_tokens):\n",
" cache.zero_()\n",
" torch.cuda.synchronize()\n",
" st = time.perf_counter_ns()\n",
" \n",
" ret = greedy_sample_one(model, input_ids, attention_mask, past_key_values)\n",
" input_ids = torch.argmax(ret.logits[:, -1, :], axis=-1)[:, None]\n",
" output = torch.cat([output, input_ids], axis=1)\n",
" past_key_values = ret.past_key_values\n",
" attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)\n",
" \n",
" torch.cuda.synchronize()\n",
" ed = time.perf_counter_ns()\n",
" ts.append((ed-st)/1e9)\n",
" return np.array(ts)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fc92f940",
"metadata": {},
"outputs": [],
"source": [
"opt_config = optcfg[\"6.7b\"]\n",
"\n",
"torch.set_default_dtype(torch.bfloat16)\n",
"with transformers.modeling_utils.no_init_weights():\n",
" model = transformers.models.opt.OPTForCausalLM(opt_config).to(\"cuda\")\n",
"torch.set_default_dtype(torch.float32)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c19fa396",
"metadata": {},
"outputs": [],
"source": [
"db = {}\n",
"input_tokens = 200\n",
"new_tokens = 500\n",
"for bs in tqdm(list(itertools.chain(range(1, 8), range(8, 16, 2), [16]))):\n",
" x = torch.randint(1000, 10000, (bs, input_tokens), device=model.device)\n",
" stack = []\n",
" for _ in range(10):\n",
" l = time_greedy_generate(model, x, new_tokens=new_tokens)\n",
" stack.append(l)\n",
" db[bs] = np.median(np.stack(stack), axis=0)\n",
" del x\n",
" torch.cuda.empty_cache()\n",
"del model\n",
"torch.cuda.empty_cache()\n",
"\n",
"with gzip.open(\"data/20230516-e2e-text-generation-batch.pkl.gz\", \"wb\") as f:\n",
" pickle.dump(db, f)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -1,165 +0,0 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# Set plot parameters
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["hatch.linewidth"] = 1.5
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True
# Path settings
FIGURE_PATH = "./paper_plot/figures"
# Load accuracy data
acc_data = pd.read_csv("./paper_plot/data/acc.csv")
# Create figure with 4 subplots (one for each dataset)
fig, axs = plt.subplots(1, 4)
fig.set_size_inches(9, 2.5)
# Reduce the spacing between subplots
# plt.subplots_adjust(wspace=0.2) # Reduced from 0.3 to 0.1
# Define datasets and their columns
datasets = ["NQ", "TriviaQA", "GPQA", "HotpotQA"]
metrics = ["Exact Match", "F1"]
# Define bar settings - make bars thicker
# total_width, n = 0.9, 3 # increased total width and n for three models
# width = total_width / n
# The 'width' variable below now defines the distance between the centers of adjacent bars within a group.
# It's also used as the base for calculating the actual plotted bar width.
# Original 2 bars had centers 1.0 apart. For 3 bars, we need a smaller distance.
# A value of 0.64 for distance between centers, with a scaling factor of 0.8 for bar width,
# results in an actual bar width of ~0.51, and a group span of ~1.79, similar to original's ~1.76.
n = 3 # Number of models
width = 0.64 # Distance between centers of adjacent bars in a group
bar_width_plotting_factor = 0.8 # Bar takes 80% of the space defined by 'width'
# Colors and hatches
edgecolors = ["dimgrey", "#63B8B6", "tomato"] # Added color for PQ 5
hatches = ["/////", "xxxxx", "\\\\\\\\\\"] # Added hatch for PQ 5
labels = ["BM25", "PQ Compressed", "Ours"] # Added PQ 5
# Create plots for each dataset
for i, dataset in enumerate(datasets):
ax = axs[i]
# Get data for this dataset and convert to percentages
em_values = [
acc_data.loc[0, f"{dataset} Exact Match"] * 100,
acc_data.loc[1, f"{dataset} Exact Match"] * 100,
acc_data.loc[2, f"{dataset} Exact Match"] * 100 # Added PQ 5 EM data
]
f1_values = [
acc_data.loc[0, f"{dataset} F1"] * 100,
acc_data.loc[1, f"{dataset} F1"] * 100,
acc_data.loc[2, f"{dataset} F1"] * 100 # Added PQ 5 F1 data
]
# Define x positions for bars
# For EM: center - width, center, center + width
# For F1: center - width, center, center + width
group_centers = [1.0, 3.0] # Centers for EM and F1 groups
bar_offsets = [-width, 0, width]
# Plot all bars on the same axis
for metric_idx, metric_group_center in enumerate(group_centers):
values_to_plot = em_values if metric_idx == 0 else f1_values
for j, model_label in enumerate(labels):
x_pos = metric_group_center + bar_offsets[j]
bar_value = values_to_plot[j]
ax.bar(
x_pos,
bar_value,
width=width * bar_width_plotting_factor, # Use the new factor for bar width
color="white",
edgecolor=edgecolors[j],
hatch=hatches[j],
linewidth=1.5,
label=model_label if i == 0 and metric_idx == 0 else None # Label only once
)
# Add value on top of bar
ax.text(x_pos, bar_value + (0.1 if dataset == "GPQA" else 0.1),
f"{bar_value:.1f}", ha='center', va='bottom',
fontsize=9, fontweight='bold') # Reduced fontsize for text on bars
# Set x-ticks and labels
ax.set_xticks(group_centers) # Position ticks at the center of each group
xticklabels = ax.set_xticklabels(metrics, fontsize=12)
# Now, shift these labels slightly to the right
# Adjust this value to control the amount of shift (in data coordinates)
# Given your group_centers are 1.0 and 3.0, a small value like 0.05 to 0.15 might be appropriate.
# horizontal_shift = 0.7 # Try adjusting this value
# for label in xticklabels:
# # Get the current x position (which is the tick location)
# current_x_pos = label.get_position()[0]
# # Set the new x position by adding the shift
# label.set_position((current_x_pos + horizontal_shift, label.get_position()[1]))
# # Ensure the label remains horizontally centered on this new x position
# # (set_xticklabels defaults to 'center', so this re-affirms it if needed)
# label.set_horizontalalignment('center')
# Set title
ax.set_title(dataset, fontsize=14)
# Set y-label for all subplots
if i == 0:
ax.set_ylabel("Accuracy (\%)", fontsize=12, fontweight="bold")
else:
# Hide y-tick labels for non-first subplots to save space
ax.tick_params(axis='y', labelsize=10)
# Set y-limits based on data range
all_values = em_values + f1_values
max_val = max(all_values)
min_val = min(all_values)
# Special handling for GPQA which has very low values
if dataset == "GPQA":
ax.set_ylim(0, 10.0) # Set a fixed range for GPQA
else:
# Reduce the extra space above the bars
ax.set_ylim(min_val * 0.9, max_val * 1.1) # Adjusted upper limit for text
# Format y-ticks as percentages
ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: ' {:.0f}'.format(y)))
# Set x-limits to properly space the bars with less blank space
# ax.set_xlim(group_centers[0] - total_width, group_centers[1] + total_width)
# Set xlim to be similar to original (0,4) for group_centers (1,3) => margin of 1.0
ax.set_xlim(group_centers[0] - 1.0, group_centers[1] + 1.0)
# Add a box around the subplot
# for spine in ax.spines.values():
# spine.set_visible(True)
# spine.set_linewidth(1.0)
# Add legend to first subplot
if i == 0:
ax.legend(
bbox_to_anchor=(2.21, 1.35), # Adjusted anchor if needed
ncol=3, # Changed to 3 columns for three labels
loc="upper center",
labelspacing=0.1,
edgecolor="black",
facecolor="white",
framealpha=1,
shadow=False,
fancybox=False,
handlelength=1.0,
handletextpad=0.6,
columnspacing=0.8,
prop={"weight": "bold", "size": 12},
)
# Save figure with tight layout but no additional padding
plt.savefig(FIGURE_PATH + "/accuracy_comparison.pdf", bbox_inches='tight', pad_inches=0.05)
plt.show()

View File

@@ -1,309 +0,0 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Motto: Were It to Benefit My Country, I Would Lay Down My Life!
# \file: /hnsw_degree_visit_plot_binned_academic.py
# \brief: Generates a binned bar plot of HNSW node average per-query visit probability
# per degree bin, styled for academic publications, with caching.
# Author: raphael hao (Original script by user, styling and caching adapted by Gemini)
# %%
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import re
from collections import Counter
import os # For robust filepath manipulation
import math # For calculating scaling factor
import pickle # For caching data
# %%
# --- Matplotlib parameters for academic paper style (from reference) ---
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["hatch.linewidth"] = 1.5
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True # Use LaTeX for text rendering (if available)
# --- Define styles from reference ---
edgecolors_ref = ["dimgrey", "#63B8B6", "tomato", "silver", "slategray"]
# %%
# --- File Paths ---
degree_file = '/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/hnsw/degree_distribution.txt'
visit_log_file = './re.log'
output_image_file = './paper_plot/figures/hnsw_visit_count_per_degree_corrected.pdf'
# --- CACHE FILE PATH: Keep this consistent ---
CACHE_FILE_PATH = './binned_plot_data_cache.pkl'
# --- Configuration ---
# Set to True to bypass cache and force recomputation.
# Otherwise, delete CACHE_FILE_PATH manually to force recomputation.
FORCE_RECOMPUTE = False
NUMBER_OF_QUERIES = 1000.0 # Number of queries the visit_counts are based on
# Create directory for figures if it doesn't exist
output_dir = os.path.dirname(output_image_file)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir)
print(f"Created directory: {output_dir}")
# %%
# --- Attempt to load data from cache or compute ---
df_plot_data = None
bin_size_for_plot = None # Will hold the bin_size associated with df_plot_data
if not FORCE_RECOMPUTE and os.path.exists(CACHE_FILE_PATH):
try:
with open(CACHE_FILE_PATH, 'rb') as f:
cache_content = pickle.load(f)
df_plot_data = cache_content['data']
bin_size_for_plot = cache_content['bin_size']
# Basic validation of cached data
# Expecting 'average_visit_count_per_node_in_bin' (raw average over NUMBER_OF_QUERIES)
if not isinstance(df_plot_data, pd.DataFrame) or \
'degree_bin_label' not in df_plot_data.columns or \
'average_visit_count_per_node_in_bin' not in df_plot_data.columns or \
not isinstance(bin_size_for_plot, int):
print("Cached data is not in the expected format or missing 'average_visit_count_per_node_in_bin'. Recomputing.")
df_plot_data = None # Invalidate to trigger recomputation
else:
print(f"Successfully loaded binned data from cache: {CACHE_FILE_PATH}")
# --- Modify the label loaded from cache for display purpose ---
# This modification only happens when data is loaded from cache and meets specific conditions.
# Assumption: If the bin_size_for_plot in cache is 5,
# then the original label "0-4" actually represents nodes with degree 1-4 (because you guarantee no 0-degree nodes).
if df_plot_data is not None and 'degree_bin_label' in df_plot_data.columns and bin_size_for_plot == 5:
# Check if "0-4" label exists
if '0-4' in df_plot_data['degree_bin_label'].values:
# Use .loc to ensure the modification is on the original DataFrame
df_plot_data.loc[df_plot_data['degree_bin_label'] == '0-4', 'degree_bin_label'] = '1-4'
print("Modified degree_bin_label from '0-4' to '1-4' for display purpose.")
except Exception as e:
print(f"Error loading from cache: {e}. Recomputing.")
df_plot_data = None # Invalidate to trigger recomputation
if df_plot_data is None:
print("Cache not found, invalid, or recompute forced. Computing data from scratch...")
# --- 1. Read Degree Distribution File ---
degrees_data = []
try:
with open(degree_file, 'r') as f:
for i, line in enumerate(f):
line_stripped = line.strip()
if line_stripped:
degrees_data.append({'node_id': i, 'degree': int(line_stripped)})
except FileNotFoundError:
print(f"Error: Degree file '{degree_file}' not found. Using dummy data for degrees.")
degrees_data = [{'node_id': i, 'degree': (i % 20) + 1 } for i in range(200)]
degrees_data.extend([{'node_id': 200+i, 'degree': i} for i in range(58, 67)]) # For 60-64 bin
degrees_data.extend([{'node_id': 300+i, 'degree': (i % 5)+1} for i in range(10)]) # Low degrees
degrees_data.extend([{'node_id': 400+i, 'degree': 80 + (i%5)} for i in range(10)]) # High degrees
if not degrees_data:
print(f"Critical Error: No data loaded or generated for degrees. Exiting.")
exit()
df_degrees = pd.DataFrame(degrees_data)
print(f"Successfully loaded/generated {len(df_degrees)} degree entries.")
# --- 2. Read Visit Log File and Count Frequencies ---
visit_counts = Counter()
node_id_pattern = re.compile(r"Vis(i)?ted node: (\d+)")
try:
with open(visit_log_file, 'r') as f_log:
for line_num, line in enumerate(f_log, 1):
match = node_id_pattern.search(line)
if match:
try:
node_id = int(match.group(2))
visit_counts[node_id] += 1 # Increment visit count for the node
except ValueError:
print(f"Warning: Non-integer node_id in log '{visit_log_file}' line {line_num}: {line.strip()}")
except FileNotFoundError:
print(f"Warning: Visit log file '{visit_log_file}' not found. Using dummy visit counts.")
if not df_degrees.empty:
for node_id_val in df_degrees['node_id'].sample(frac=0.9, random_state=1234): # Seed for reproducibility
degree_val = df_degrees[df_degrees['node_id'] == node_id_val]['degree'].iloc[0]
# Generate visit counts to test different probability magnitudes
if node_id_val % 23 == 0: # Very low probability
lambda_val = 0.0005 * (100 / (max(1,degree_val) + 1)) # avg visits over 1k queries
elif node_id_val % 11 == 0: # Low probability
lambda_val = 0.05 * (100 / (max(1,degree_val) + 1))
elif node_id_val % 5 == 0: # Moderate probability
lambda_val = 2.5 * (100 / (max(1,degree_val) + 1))
else: # Higher probability (but still < 1000 visits for a single node usually)
lambda_val = 50 * (100 / (max(1,degree_val) + 1))
visit_counts[node_id_val] = np.random.poisson(lambda_val)
if visit_counts[node_id_val] < 0: visit_counts[node_id_val] = 0
if not visit_counts:
print(f"Warning: No visit data parsed/generated. Plot may show zero visits.")
df_visits = pd.DataFrame(columns=['node_id', 'visit_count'])
else:
df_visits_list = [{'node_id': nid, 'visit_count': count} for nid, count in visit_counts.items()]
df_visits = pd.DataFrame(df_visits_list)
print(f"Parsed/generated {len(df_visits)} unique visited nodes, totaling {sum(visit_counts.values())} visits (simulated over {NUMBER_OF_QUERIES} queries).")
# --- 3. Merge Degree Data with Visit Data ---
df_merged = pd.merge(df_degrees, df_visits, on='node_id', how='left')
df_merged['visit_count'] = df_merged['visit_count'].fillna(0).astype(float) # visit_count is total over NUMBER_OF_QUERIES
print(f"Merged data contains {len(df_merged)} entries.")
# --- 5. Binning Degrees and Calculating Average Visit Count per Node in Bin (over NUMBER_OF_QUERIES) ---
current_bin_size = 5
bin_size_for_plot = current_bin_size
if not df_degrees.empty:
print(f"\nBinning degrees into groups of {current_bin_size} for average visit count calculation...")
df_merged_with_bins = df_merged.copy()
df_merged_with_bins['degree_bin_start'] = (df_merged_with_bins['degree'] // current_bin_size) * current_bin_size
df_binned_analysis = df_merged_with_bins.groupby('degree_bin_start').agg(
total_visit_count_in_bin=('visit_count', 'sum'),
node_count_in_bin=('node_id', 'nunique')
).reset_index()
# This is the average number of times a node in this bin was visited over NUMBER_OF_QUERIES queries.
# This value is what gets cached.
df_binned_analysis['average_visit_count_per_node_in_bin'] = 0.0
df_binned_analysis.loc[df_binned_analysis['node_count_in_bin'] > 0, 'average_visit_count_per_node_in_bin'] = \
df_binned_analysis['total_visit_count_in_bin'] / df_binned_analysis['node_count_in_bin']
df_binned_analysis['degree_bin_label'] = df_binned_analysis['degree_bin_start'].astype(str) + '-' + \
(df_binned_analysis['degree_bin_start'] + current_bin_size - 1).astype(str)
bin_to_drop_label = '60-64'
original_length = len(df_binned_analysis)
df_plot_data_intermediate = df_binned_analysis[df_binned_analysis['degree_bin_label'] != bin_to_drop_label].copy()
if len(df_plot_data_intermediate) < original_length:
print(f"\nManually dropped the bin: '{bin_to_drop_label}'")
else:
print(f"\nNote: Bin '{bin_to_drop_label}' not found for dropping or already removed.")
df_plot_data = df_plot_data_intermediate
print(f"\nBinned data (average visit count per node in bin over {NUMBER_OF_QUERIES} queries) for plotting prepared:")
print(df_plot_data[['degree_bin_label', 'average_visit_count_per_node_in_bin']].head())
if df_plot_data is not None and not df_plot_data.empty:
try:
with open(CACHE_FILE_PATH, 'wb') as f:
pickle.dump({'data': df_plot_data, 'bin_size': bin_size_for_plot}, f)
print(f"Saved computed binned data to cache: {CACHE_FILE_PATH}")
except Exception as e:
print(f"Error saving data to cache: {e}")
elif df_plot_data is None or df_plot_data.empty:
print("Computed data for binned plot is empty, not saving to cache.")
else:
print("Degree data (df_degrees) is empty. Cannot perform binning.")
df_plot_data = pd.DataFrame()
bin_size_for_plot = current_bin_size
# %%
# --- 6. Plotting (Binned Bar Chart - Academic Style) ---
if df_plot_data is not None and not df_plot_data.empty and 'average_visit_count_per_node_in_bin' in df_plot_data.columns:
base_name, ext = os.path.splitext(output_image_file)
# --- OUTPUT PDF FILE NAME: Keep this consistent ---
binned_output_image_file = base_name + ext
fig, ax = plt.subplots(figsize=(6, 2.5)) # Adjusted figure size
df_plot_data_plotting = df_plot_data.copy()
# Calculate per-query probability: (avg visits over N queries) / N
df_plot_data_plotting['per_query_visit_probability'] = \
df_plot_data_plotting['average_visit_count_per_node_in_bin'] / NUMBER_OF_QUERIES
max_probability = df_plot_data_plotting['per_query_visit_probability'].max()
y_axis_values_to_plot = df_plot_data_plotting['per_query_visit_probability']
y_axis_label = r"Per-Query Node Visit Probability in Bin" # Base label
apply_scaling_to_label_and_values = False # Initialize flag
exponent_for_label_display = 0 # Initialize exponent
if pd.notna(max_probability) and max_probability > 0:
potential_exponent = math.floor(math.log10(max_probability))
if potential_exponent <= -4 or potential_exponent >= 0:
apply_scaling_to_label_and_values = True
exponent_for_label_display = potential_exponent
# No specific adjustment for potential_exponent >=0 here, it's handled by the general logic.
if apply_scaling_to_label_and_values:
y_axis_label = rf"Visit Probability ($\times 10^{{{exponent_for_label_display}}}$)"
y_axis_values_to_plot = df_plot_data_plotting['per_query_visit_probability'] / (10**exponent_for_label_display)
print(f"Plotting with Max per-query probability: {max_probability:.2e}, Exponent for label: {exponent_for_label_display}. Y-axis values scaled for plot.")
else:
print(f"Plotting with Max per-query probability: {max_probability:.2e}. Plotting direct probabilities without label scaling (exponent {potential_exponent} is within no-scale range [-3, -1]).")
elif pd.notna(max_probability) and max_probability == 0:
print("Max per-query probability is 0. Plotting direct probabilities (all zeros).")
else:
print(f"Max per-query probability is NaN or invalid ({max_probability}). Plotting direct probabilities without scaling if possible.")
ax.bar(
df_plot_data_plotting['degree_bin_label'],
y_axis_values_to_plot,
color='white',
edgecolor=edgecolors_ref[0],
linewidth=1.5,
width=0.8
)
ax.set_xlabel('Node Degree', fontsize=10.5, labelpad=6)
# MODIFIED LINE: Added labelpad to move the y-axis label to the left
ax.set_ylabel(y_axis_label, fontsize=10.5, labelpad=10)
ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, pos: f"{x:.0f}%"))
num_bins = len(df_plot_data_plotting)
if num_bins > 12:
ax.set_xticks(ax.get_xticks())
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right", fontsize=9)
elif num_bins > 8:
ax.tick_params(axis='x', labelsize=9)
else:
ax.tick_params(axis='x', labelsize=10)
ax.tick_params(axis='y', labelsize=10)
padding_factor = 0.05
current_max_y_on_axis = y_axis_values_to_plot.max()
upper_y_limit = 0.1 # Default small upper limit
if pd.notna(current_max_y_on_axis):
if current_max_y_on_axis > 0:
# Adjust minimum visible range based on whether scaling was applied and the exponent
min_meaningful_limit = 0.01
if apply_scaling_to_label_and_values and exponent_for_label_display >= 0 : # Numbers on axis are smaller due to positive exponent scaling
min_meaningful_limit = 0.1 # If original numbers were e.g. 2500 (2.5 x 10^3), scaled axis is 2.5, 0.1 is fine
elif not apply_scaling_to_label_and_values and pd.notna(max_probability) and max_probability >=1: # Direct large probabilities
min_meaningful_limit = 1 # If max prob is 2.5 (250%), axis value 2.5, needs larger base limit
upper_y_limit = max(min_meaningful_limit, current_max_y_on_axis * (1 + padding_factor))
else: # current_max_y_on_axis is 0
upper_y_limit = 0.1
ax.set_ylim(0, upper_y_limit)
else:
ax.set_ylim(0, 1.0) # Default for empty or NaN data
plt.tight_layout()
plt.savefig(binned_output_image_file, bbox_inches="tight", dpi=300)
print(f"Binned bar chart saved to {binned_output_image_file}")
plt.show()
plt.close(fig)
else:
if df_plot_data is None:
print("Data for plotting (df_plot_data) is None. Skipping plot generation.")
elif df_plot_data.empty:
print("Data for plotting (df_plot_data) is empty. Skipping plot generation.")
elif 'average_visit_count_per_node_in_bin' not in df_plot_data.columns:
print("Essential column 'average_visit_count_per_node_in_bin' is missing in df_plot_data. Skipping plot generation.")
# %%
print("Script finished.")

View File

@@ -1,7 +0,0 @@
In this paper, we present LiteANN, a storage-efficient approximate nearest neighbor (ANN) search index optimized for resource-constrained personal devices. LiteANN combines a compact graph-based structure with an efficient on-the-fly recomputation strategy to enable fast and accurate retrieval wih minimal storage overhead. Our evaluation shows that LiteANN reduces index size to under 5% of the original raw data up to 50× smaller than standard indexes while achieving 90% top-3 recall in under 2 seconds on real-world question-answering benchmarks.

View File

@@ -1,81 +0,0 @@
import numpy as np
import os
# --- Configuration for Data Paths and Labels (Mirrors plotting script for consistency) ---
BIG_GRAPH_PATHS = [
"/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/hnsw/",
"/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/99_4_degree_based_hnsw_IP_M32_efC256/",
"/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/d9_hnsw_IP_M8_efC128/",
"/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/half_edges_IP_M32_efC128/"
]
STATS_FILE_NAME = "degree_distribution.txt"
BIG_GRAPH_LABELS = [ # These will be used as keys in the cached file
"HNSW-Base",
"DegreeGuide",
"HNSW-D9",
"RandCut",
]
# Average degrees are static and can be directly used in the plotting script or also cached.
# For simplicity here, we'll focus on caching the dynamic degree arrays.
# BIG_GRAPH_AVG_DEG = [18, 9, 9, 9]
# --- Cache File Configuration ---
DATA_CACHE_DIR = "./paper_plot/data/"
CACHE_FILE_NAME = "big_graph_degree_data.npz" # Using .npz for multiple arrays
def create_degree_data_cache():
"""
Reads degree distribution data from specified text files and saves it
into a compressed NumPy (.npz) cache file.
"""
os.makedirs(DATA_CACHE_DIR, exist_ok=True)
cache_file_path = os.path.join(DATA_CACHE_DIR, CACHE_FILE_NAME)
cached_data = {}
print(f"Starting data caching process for {len(BIG_GRAPH_PATHS)} graph types...")
for i, base_path in enumerate(BIG_GRAPH_PATHS):
method_label = BIG_GRAPH_LABELS[i]
degree_file_path = os.path.join(base_path, STATS_FILE_NAME)
print(f"Processing: {method_label} from {degree_file_path}")
try:
# Load degrees as integers
degrees = np.loadtxt(degree_file_path, dtype=int)
if degrees.size == 0:
print(f" [WARN] Degree file is empty: {degree_file_path}. Storing as empty array for {method_label}.")
# Store an empty array or handle as needed. For npz, an empty array is fine.
cached_data[method_label] = np.array([], dtype=int)
else:
# Store the loaded degrees array with the method label as the key
cached_data[method_label] = degrees
print(f" [INFO] Loaded {len(degrees)} degrees for {method_label}. Max degree: {np.max(degrees) if degrees.size > 0 else 'N/A'}")
except FileNotFoundError:
print(f" [ERROR] Degree file not found: {degree_file_path}. Skipping {method_label}.")
# Optionally store a placeholder or skip. For robustness, store None or an empty array.
# Storing None might require special handling when loading. Empty array is safer for np.load.
cached_data[method_label] = np.array([], dtype=int) # Store empty array if file not found
except Exception as e:
print(f" [ERROR] An error occurred loading {degree_file_path} for {method_label}: {e}")
cached_data[method_label] = np.array([], dtype=int) # Store empty array on other errors
if not cached_data:
print("[ERROR] No data was successfully processed or loaded. Cache file will not be created.")
return
try:
# Save all collected degree arrays into a single .npz file.
# Using savez_compressed for potentially smaller file size.
np.savez_compressed(cache_file_path, **cached_data)
print(f"\n[SUCCESS] Degree distribution data successfully cached to: {os.path.abspath(cache_file_path)}")
print("Cached arrays (keys):", list(cached_data.keys()))
except Exception as e:
print(f"\n[ERROR] Failed to save data to cache file {cache_file_path}: {e}")
if __name__ == "__main__":
print("--- Degree Distribution Data Caching Script ---")
create_degree_data_cache()
print("--- Caching script finished. ---")

View File

@@ -1,4 +0,0 @@
Model,NQ Exact Match,NQ F1,TriviaQA Exact Match,TriviaQA F1,GPQA Exact Match,GPQA F1,HotpotQA Exact Match,HotpotQA F1
BM25,0.192,0.277,0.406,0.474,0.020089,0.04524,0.162,0.239
PQ 5,0.2075,0.291,0.422,0.495,0.0201,0.0445,0.148,0.219
Ours,0.265,0.361,0.533,0.604,0.02008,0.0452,0.182,0.2729
1 Model NQ Exact Match NQ F1 TriviaQA Exact Match TriviaQA F1 GPQA Exact Match GPQA F1 HotpotQA Exact Match HotpotQA F1
2 BM25 0.192 0.277 0.406 0.474 0.020089 0.04524 0.162 0.239
3 PQ 5 0.2075 0.291 0.422 0.495 0.0201 0.0445 0.148 0.219
4 Ours 0.265 0.361 0.533 0.604 0.02008 0.0452 0.182 0.2729

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1296720e79196bbdf38f051043c1b054667803726a24036c0b6a87cedb204ea5
size 227482438

View File

@@ -1,21 +0,0 @@
2,1,512,1024,0.541,0.326,1.659509202
2,2,512,1024,0.979,0.621,1.576489533
2,4,512,1024,1.846,0.977,1.889457523
2,8,512,1024,3.575,1.943,1.83993824
2,16,512,1024,7.035,3.733,1.884543263
2,32,512,1024,15.655,8.517,1.838088529
2,64,512,1024,32.772,17.43,1.88020654
4,1,512,1024,2.675,1.38,1.938405797
4,2,512,1024,5.397,2.339,2.307396323
4,4,512,1024,10.672,4.944,2.158576052
4,8,512,1024,21.061,9.266,2.272933305
4,16,512,1024,46.332,18.334,2.527108105
4,32,512,1024,99.607,36.156,2.754923111
4,64,512,1024,186.348,72.356,2.575432583
8,1,512,1024,7.325,4.087,1.792268167
8,2,512,1024,14.109,7.491,1.883460152
8,4,512,1024,28.499,14.013,2.033754371
8,8,512,1024,65.222,27.453,2.375769497
8,16,512,1024,146.294,52.55,2.783901047
8,32,512,1024,277.099,103.61,2.674442621
8,64,512,1024,512.979,208.36,2.461984066
1 2 1 512 1024 0.541 0.326 1.659509202
2 2 2 512 1024 0.979 0.621 1.576489533
3 2 4 512 1024 1.846 0.977 1.889457523
4 2 8 512 1024 3.575 1.943 1.83993824
5 2 16 512 1024 7.035 3.733 1.884543263
6 2 32 512 1024 15.655 8.517 1.838088529
7 2 64 512 1024 32.772 17.43 1.88020654
8 4 1 512 1024 2.675 1.38 1.938405797
9 4 2 512 1024 5.397 2.339 2.307396323
10 4 4 512 1024 10.672 4.944 2.158576052
11 4 8 512 1024 21.061 9.266 2.272933305
12 4 16 512 1024 46.332 18.334 2.527108105
13 4 32 512 1024 99.607 36.156 2.754923111
14 4 64 512 1024 186.348 72.356 2.575432583
15 8 1 512 1024 7.325 4.087 1.792268167
16 8 2 512 1024 14.109 7.491 1.883460152
17 8 4 512 1024 28.499 14.013 2.033754371
18 8 8 512 1024 65.222 27.453 2.375769497
19 8 16 512 1024 146.294 52.55 2.783901047
20 8 32 512 1024 277.099 103.61 2.674442621
21 8 64 512 1024 512.979 208.36 2.461984066

View File

@@ -1,9 +0,0 @@
Dataset,Metric,Original,original + batch,original + two_level,original + two_level + batch
NQ,Latency,6.9,5.8,4.2,3.7
NQ,SpeedUp,1,1.18965517,1.64285714,1.86486486
TriviaQA,Latency,17.054,14.542,12.046,10.83
TriviaQA,SpeedUp,1,1.17274103,1.41573967,1.57469990
GPQA,Latency,9.164,7.639,6.798,5.77
GPQA,SpeedUp,1,1.19963346,1.34804354,1.58821490
HotpotQA,Latency,60.279,39.827,50.664,29.868
HotpotQA,SpeedUp,1,1.51352098,1.18977972,2.01817999
1 Dataset Metric Original original + batch original + two_level original + two_level + batch
2 NQ Latency 6.9 5.8 4.2 3.7
3 NQ SpeedUp 1 1.18965517 1.64285714 1.86486486
4 TriviaQA Latency 17.054 14.542 12.046 10.83
5 TriviaQA SpeedUp 1 1.17274103 1.41573967 1.57469990
6 GPQA Latency 9.164 7.639 6.798 5.77
7 GPQA SpeedUp 1 1.19963346 1.34804354 1.58821490
8 HotpotQA Latency 60.279 39.827 50.664 29.868
9 HotpotQA SpeedUp 1 1.51352098 1.18977972 2.01817999

View File

@@ -1,25 +0,0 @@
Dataset,Hardware,Recall_target,HNSW,IVF,DiskANN,IVF-Disk,IVF-Recompute,Our,BM25,LLM_Gen_Time_1B,LLM_Gen_Time_3B,LLM_Gen_Time_7B
NQ,A10,85%,0.046,1.656,0.017,2.996,482.53,3.323,0.021,0.085,0.217,0.472
NQ,A10,90%,0.051,2.552,0.028,3.437,769.04,4.616,0,0.085,0.217,0.472
NQ,A10,95%,0.055,5.163,0.070,5.602,1436.26,19.494,0,0.085,0.217,0.472
NQ,MAC,85%,0,0,0.152,2.199,1535.10,7.971,0.033,0.316,0.717,1.468
NQ,MAC,90%,0,0,0.37,2.936,2446.60,13.843,0,0.316,0.717,1.468
NQ,MAC,95%,0,0,1.207,4.191,4569.29,44.363,0,0.316,0.717,1.468
TriviaQA,A10,85%,0.042,1.772,0.032,2.464,560.5,3.752,0.033,0.139,0.156,0.315
TriviaQA,A10,90%,0.043,3.541,0.057,3.651,997.81,5.777,0,0.139,0.156,0.315
TriviaQA,A10,95%,0.053,7.168,0.090,5.458,2005.33,20.944,0,0.139,0.156,0.315
TriviaQA,MAC,85%,0,0,0.481,1.875,1783.14787,8.889,0.036,0.325,0.692,1.415
TriviaQA,MAC,90%,0,0,0.984,2.639,3174.410301,17.145,0,0.325,0.692,1.415
TriviaQA,MAC,95%,0,0,1.578,3.884,6379.712245,47.909,0,0.325,0.692,1.415
GPQA,A10,85%,0.041,0.134,0.024,0.048,40.16,1.897,0.137,0.443,0.396,0.651
GPQA,A10,90%,0.042,0.174,0.034,0.06,54.71,1.733,0,0.443,0.396,0.651
GPQA,A10,95%,0.045,0.292,0.051,0.11,97.67,4.033,0,0.443,0.396,0.651
GPQA,MAC,85%,0,0,0.144,0.087,127.7707505,4.762,0.100,0.37,0.813,1.676
GPQA,MAC,90%,0,0,0.288,0.108,174.0647409,5.223,0,0.37,0.813,1.676
GPQA,MAC,95%,0,0,0.497,0.132,310.7380142,9.715,0,0.37,0.813,1.676
HotpotQA,A10,85%,0.044,2.519,0.054,4.048,724.26,10.358,0.70,0.144,0.196,0.420
HotpotQA,A10,90%,0.049,3.867,0.109,5.045,1173.67,15.515,0,0.144,0.196,0.420
HotpotQA,A10,95%,0.07,10.928,0.412,8.659,3079.57,61.757,0,0.144,0.196,0.420
HotpotQA,MAC,85%,0,0,0.974,2.844,2304.125187,23.636,0.052,0.144,0.196,0.420
HotpotQA,MAC,90%,0,0,1.913,3.542,3415.736201,44.803,0,0.144,0.196,0.420
HotpotQA,MAC,95%,0,0,5.783,6.764,9797.244043,140.62,0,0.144,0.196,0.420
1 Dataset Hardware Recall_target HNSW IVF DiskANN IVF-Disk IVF-Recompute Our BM25 LLM_Gen_Time_1B LLM_Gen_Time_3B LLM_Gen_Time_7B
2 NQ A10 85% 0.046 1.656 0.017 2.996 482.53 3.323 0.021 0.085 0.217 0.472
3 NQ A10 90% 0.051 2.552 0.028 3.437 769.04 4.616 0 0.085 0.217 0.472
4 NQ A10 95% 0.055 5.163 0.070 5.602 1436.26 19.494 0 0.085 0.217 0.472
5 NQ MAC 85% 0 0 0.152 2.199 1535.10 7.971 0.033 0.316 0.717 1.468
6 NQ MAC 90% 0 0 0.37 2.936 2446.60 13.843 0 0.316 0.717 1.468
7 NQ MAC 95% 0 0 1.207 4.191 4569.29 44.363 0 0.316 0.717 1.468
8 TriviaQA A10 85% 0.042 1.772 0.032 2.464 560.5 3.752 0.033 0.139 0.156 0.315
9 TriviaQA A10 90% 0.043 3.541 0.057 3.651 997.81 5.777 0 0.139 0.156 0.315
10 TriviaQA A10 95% 0.053 7.168 0.090 5.458 2005.33 20.944 0 0.139 0.156 0.315
11 TriviaQA MAC 85% 0 0 0.481 1.875 1783.14787 8.889 0.036 0.325 0.692 1.415
12 TriviaQA MAC 90% 0 0 0.984 2.639 3174.410301 17.145 0 0.325 0.692 1.415
13 TriviaQA MAC 95% 0 0 1.578 3.884 6379.712245 47.909 0 0.325 0.692 1.415
14 GPQA A10 85% 0.041 0.134 0.024 0.048 40.16 1.897 0.137 0.443 0.396 0.651
15 GPQA A10 90% 0.042 0.174 0.034 0.06 54.71 1.733 0 0.443 0.396 0.651
16 GPQA A10 95% 0.045 0.292 0.051 0.11 97.67 4.033 0 0.443 0.396 0.651
17 GPQA MAC 85% 0 0 0.144 0.087 127.7707505 4.762 0.100 0.37 0.813 1.676
18 GPQA MAC 90% 0 0 0.288 0.108 174.0647409 5.223 0 0.37 0.813 1.676
19 GPQA MAC 95% 0 0 0.497 0.132 310.7380142 9.715 0 0.37 0.813 1.676
20 HotpotQA A10 85% 0.044 2.519 0.054 4.048 724.26 10.358 0.70 0.144 0.196 0.420
21 HotpotQA A10 90% 0.049 3.867 0.109 5.045 1173.67 15.515 0 0.144 0.196 0.420
22 HotpotQA A10 95% 0.07 10.928 0.412 8.659 3079.57 61.757 0 0.144 0.196 0.420
23 HotpotQA MAC 85% 0 0 0.974 2.844 2304.125187 23.636 0.052 0.144 0.196 0.420
24 HotpotQA MAC 90% 0 0 1.913 3.542 3415.736201 44.803 0 0.144 0.196 0.420
25 HotpotQA MAC 95% 0 0 5.783 6.764 9797.244043 140.62 0 0.144 0.196 0.420

View File

@@ -1,25 +0,0 @@
Dataset,Hardware,Recall_target,HNSW,IVF,DiskANN,IVF-Disk,IVF-Recompute,Our,
NQ,A10,85%,0.046,1.656,0.017,2.996,482.53,4.243,
NQ,A10,90%,0.051,2.552,0.028,3.437,769.04,8.136,
NQ,A10,95%,0.055,5.163,0.070,5.602,1436.26,27.275,
NQ,MAC,85%,0,0,0.152,2.199,1535.10,10.672,
NQ,MAC,90%,0,0,0.37,2.936,2446.60,19.941,
NQ,MAC,95%,0,0,1.207,4.191,4569.29,61.383,
TriviaQA,A10,85%,0.042,1.772,0.032,2.464,560.5,5.612,
TriviaQA,A10,90%,0.043,3.541,0.057,3.651,997.81,10.737,
TriviaQA,A10,95%,0.053,7.168,0.090,5.458,2005.33,36.387,
TriviaQA,MAC,85%,0,0,0.481,1.875,1783.14787,12.825,
TriviaQA,MAC,90%,0,0,0.984,2.639,3174.410301,24.977,
TriviaQA,MAC,95%,0,0,1.578,3.884,6379.712245,85.734,
GPQA,A10,85%,0.041,0.134,0.024,0.048,40.16,2.269,
GPQA,A10,90%,0.042,0.174,0.034,0.06,54.71,3.200,
GPQA,A10,95%,0.045,0.292,0.051,0.11,97.67,7.445,
GPQA,MAC,85%,0,0,0.144,0.087,127.7707505,6.123,
GPQA,MAC,90%,0,0,0.288,0.108,174.0647409,8.507,
GPQA,MAC,95%,0,0,0.497,0.132,310.7380142,19.577,
HotpotQA,A10,85%,0.044,2.519,0.054,4.048,724.26,14.713,
HotpotQA,A10,90%,0.049,3.867,0.109,5.045,1173.67,33.561,
HotpotQA,A10,95%,0.07,10.928,0.412,8.659,3079.57,68.626,
HotpotQA,MAC,85%,0,0,0.974,2.844,2304.125187,34.783,
HotpotQA,MAC,90%,0,0,1.913,3.542,3415.736201,53.004,
HotpotQA,MAC,95%,0,0,5.783,6.764,9797.244043,95.413,
1 Dataset Hardware Recall_target HNSW IVF DiskANN IVF-Disk IVF-Recompute Our
2 NQ A10 85% 0.046 1.656 0.017 2.996 482.53 4.243
3 NQ A10 90% 0.051 2.552 0.028 3.437 769.04 8.136
4 NQ A10 95% 0.055 5.163 0.070 5.602 1436.26 27.275
5 NQ MAC 85% 0 0 0.152 2.199 1535.10 10.672
6 NQ MAC 90% 0 0 0.37 2.936 2446.60 19.941
7 NQ MAC 95% 0 0 1.207 4.191 4569.29 61.383
8 TriviaQA A10 85% 0.042 1.772 0.032 2.464 560.5 5.612
9 TriviaQA A10 90% 0.043 3.541 0.057 3.651 997.81 10.737
10 TriviaQA A10 95% 0.053 7.168 0.090 5.458 2005.33 36.387
11 TriviaQA MAC 85% 0 0 0.481 1.875 1783.14787 12.825
12 TriviaQA MAC 90% 0 0 0.984 2.639 3174.410301 24.977
13 TriviaQA MAC 95% 0 0 1.578 3.884 6379.712245 85.734
14 GPQA A10 85% 0.041 0.134 0.024 0.048 40.16 2.269
15 GPQA A10 90% 0.042 0.174 0.034 0.06 54.71 3.200
16 GPQA A10 95% 0.045 0.292 0.051 0.11 97.67 7.445
17 GPQA MAC 85% 0 0 0.144 0.087 127.7707505 6.123
18 GPQA MAC 90% 0 0 0.288 0.108 174.0647409 8.507
19 GPQA MAC 95% 0 0 0.497 0.132 310.7380142 19.577
20 HotpotQA A10 85% 0.044 2.519 0.054 4.048 724.26 14.713
21 HotpotQA A10 90% 0.049 3.867 0.109 5.045 1173.67 33.561
22 HotpotQA A10 95% 0.07 10.928 0.412 8.659 3079.57 68.626
23 HotpotQA MAC 85% 0 0 0.974 2.844 2304.125187 34.783
24 HotpotQA MAC 90% 0 0 1.913 3.542 3415.736201 53.004
25 HotpotQA MAC 95% 0 0 5.783 6.764 9797.244043 95.413

View File

@@ -1,3 +0,0 @@
Hardware,HNSW,IVF,DiskANN,IVF-Disk,IVF-Recompute,Our,BM25
RAM,190,171,10,0,0,0,0
Storage,185.4,171,240,171,0.5,5,59
1 Hardware HNSW IVF DiskANN IVF-Disk IVF-Recompute Our BM25
2 RAM 190 171 10 0 0 0 0
3 Storage 185.4 171 240 171 0.5 5 59

View File

@@ -1,12 +0,0 @@
Torch,8,55.592
Torch,16,75.439
Torch,32,110.025
Torch,64,186.496
Tutel,8,56.718
Tutel,16,82.121
Tutel,32,125.070
Tutel,64,216.191
BRT,8,56.725
BRT,16,79.291
BRT,32,93.180
BRT,64,118.923
1 Torch 8 55.592
2 Torch 16 75.439
3 Torch 32 110.025
4 Torch 64 186.496
5 Tutel 8 56.718
6 Tutel 16 82.121
7 Tutel 32 125.070
8 Tutel 64 216.191
9 BRT 8 56.725
10 BRT 16 79.291
11 BRT 32 93.180
12 BRT 64 118.923

View File

@@ -1,6 +0,0 @@
Disk cache size,0,2.5%(180G*2.5%),5%,8%,10%
Latency,,,,,
NQ,4.616,4.133,3.826,3.511,3.323
TriviaQA,5.777,4.979,4.553,4.141,3.916
GPQA,1.733,1.593,1.468,1.336,1.259
Hotpot,15.515,13.479,12.383,11.216,10.606
1 Disk cache size 0 2.5%(180G*2.5%) 5% 8% 10%
2 Latency
3 NQ 4.616 4.133 3.826 3.511 3.323
4 TriviaQA 5.777 4.979 4.553 4.141 3.916
5 GPQA 1.733 1.593 1.468 1.336 1.259
6 Hotpot 15.515 13.479 12.383 11.216 10.606

View File

@@ -1,151 +0,0 @@
import matplotlib
from matplotlib.axes import Axes
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.lines import Line2D
# plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["hatch.linewidth"] = 1
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True
plt.rcParams["font.family"] = "sans-serif" # Use generic sans-serif family
plt.rcParams['text.latex.preamble'] = r"""
\usepackage{helvet} % Use Helvetica font for text
\usepackage{sfmath} % Use sans-serif font for math
\renewcommand{\familydefault}{\sfdefault} % Set sans-serif as default text font
\usepackage[T1]{fontenc} % Recommended for font encoding
"""
# plt.rcParams['mathtext.fontset'] = 'dejavusans'
SAVE_PTH = "./paper_plot/figures"
font_size = 16
# New data in dictionary format
datasets = ["NQ", "TriviaQA", "GPQA", "Hotpot"]
cache_ratios = ["4.2G\n (0\%)", "8.7G\n (2.5\%)", "13.2G\n (5\%)", "18.6G\n (8\%)", "22.2G\n (10\%)"]
latency_data = {
"NQ": [4.616, 4.133, 3.826, 3.511, 3.323],
"TriviaQA": [5.777, 4.979, 4.553, 4.141, 3.916],
"GPQA": [1.733, 1.593, 1.468, 1.336, 1.259],
"Hotpot": [15.515, 13.479, 12.383, 11.216, 10.606],
}
cache_hit_counts = {
"NQ": [0, 14.81, 23.36, 31.99, 36.73],
"TriviaQA": [0, 18.55, 27.99, 37.06, 41.86],
"GPQA": [0, 10.99, 20.31, 29.71, 35.01],
"Hotpot": [0, 17.47, 26.91, 36.2, 41.06]
}
# Create the figure with 4 subplots in a 2x2 grid
fig, axes_grid = plt.subplots(2, 2, figsize=(7,6))
axes = axes_grid.flatten() # Flatten the 2x2 grid to a 1D array
# Bar style settings
width = 0.7
x = np.arange(len(cache_ratios))
# Define hatch patterns for different cache ratios
hatch_patterns = ['//', '//', '//', '//', '//']
# Find max cache hit value across all datasets for unified y-axis
all_hit_counts = []
for dataset in datasets:
all_hit_counts.extend(cache_hit_counts[dataset])
max_unified_hit = max(all_hit_counts) * 1.13
for i, dataset in enumerate(datasets):
latencies = latency_data[dataset]
hit_counts = cache_hit_counts[dataset]
for j, val in enumerate(latencies):
container = axes[i].bar(
x[j],
val,
width=width,
color="white",
edgecolor="black",
linewidth=1.0,
zorder=10,
)
axes[i].bar_label(
container,
[f"{val:.2f}"],
fontsize=10,
zorder=200,
fontweight="bold",
)
axes[i].set_title(dataset, fontsize=font_size)
axes[i].set_xticks(x)
axes[i].set_xticklabels(cache_ratios, fontsize=12, rotation=0, ha='center', fontweight="bold")
max_val_ratios = [1.35, 1.65, 1.45, 1.75]
max_val = max(latencies) * max_val_ratios[i]
axes[i].set_ylim(0, max_val)
axes[i].tick_params(axis='y', labelsize=12)
if i % 2 == 0:
axes[i].set_ylabel("Latency (s)", fontsize=font_size)
axes[i].yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter('%.1f'))
ax2: Axes = axes[i].twinx()
ax2.plot(x, hit_counts,
linestyle='--',
marker='o',
markersize=6,
linewidth=1.5,
color='k',
markerfacecolor='none',
zorder=20)
ax2.set_ylim(0, max_unified_hit)
ax2.tick_params(axis='y', labelsize=12)
if i % 2 == 1:
ax2.set_ylabel(r"Cache Hit (\%)", fontsize=font_size)
for j, val in enumerate(hit_counts):
if val > 0:
ax2.annotate(f"{val:.1f}%",
(x[j], val),
textcoords="offset points",
xytext=(0, 5),
ha='center',
va='bottom',
fontsize=10,
fontweight='bold')
# Create legend for both plots
bar_patch = mpatches.Patch(facecolor='white', edgecolor='black', label='Latency')
line_patch = Line2D([0], [0], color='black', linestyle='--', label='Cache Hit Rate')
# --- MODIFICATION FOR LEGEND AT THE TOP ---
fig.legend(handles=[bar_patch, line_patch],
loc='upper center', # Position the legend at the upper center
bbox_to_anchor=(0.5, 0.995), # Anchor point (0.5 means horizontal center of figure,
# 0.97 means 97% from the bottom, so near the top)
ncol=3,
fontsize=font_size-2)
# --- END OF MODIFICATION ---
# Set common x-axis label - you might want to add this back if needed
# fig.text(0.5, 0.02, "Disk Cache Size", ha='center', fontsize=font_size, fontweight='bold') # Adjusted y for potential bottom label
# --- MODIFICATION FOR TIGHT LAYOUT ---
# Adjust rect to make space for the legend at the top.
# (left, bottom, right, top_for_subplots)
# We want subplots to occupy space from y=0 up to y=0.93 (or similar)
# leaving the top portion (0.93 to 1.0) for the legend.
plt.tight_layout(rect=(0, 0, 1, 0.93)) # Ensure subplots are below the legend
# --- END OF MODIFICATION ---
# Create directory if it doesn't exist (optional, good practice)
import os
if not os.path.exists(SAVE_PTH):
os.makedirs(SAVE_PTH)
plt.savefig(f"{SAVE_PTH}/disk_cache_latency.pdf", dpi=300) # Changed filename slightly for testing
print(f"Save to {SAVE_PTH}/disk_cache_latency.pdf")
# plt.show() # Optional: to display the plot

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 130 KiB

View File

Binary file not shown.

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 100 KiB

View File

Binary file not shown.

View File

Binary file not shown.

View File

Binary file not shown.

View File

Binary file not shown.

View File

Binary file not shown.

View File

Binary file not shown.

Before

Width:  |  Height:  |  Size: 41 KiB

Some files were not shown because too many files have changed in this diff Show More