Compare commits


62 Commits

Author SHA1 Message Date
GitHub Actions
802020cb41 chore: release v0.1.12 2025-07-26 23:35:28 +00:00
yichuan520030910320
cdb92f7cf4 update pytoml version && fix colab env && fix pdf extract in pip 2025-07-26 16:33:13 -07:00
yichuan520030910320
dc69bdec00 Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-25 17:54:43 -07:00
yichuan520030910320
98073e9868 update missing pkg 2025-07-25 17:54:21 -07:00
GitHub Actions
cf2ef48967 chore: release v0.1.11 2025-07-26 00:12:37 +00:00
yichuan520030910320
0692bbf7a2 change workflow 2025-07-25 17:11:56 -07:00
GitHub Actions
52584a171f chore: release v0.1.10 2025-07-25 23:12:16 +00:00
Andy Lee
efd6b5324b fix: add protobuf as a dependency for DiskANN backend
- Fixes 'No module named google' error when starting DiskANN embedding server
- Prevents users from having to manually install protobuf
2025-07-25 16:10:25 -07:00
Andy Lee
2baaa4549b fix: handle relative paths in HNSW embedding server metadata
- Convert relative paths to absolute paths based on metadata file location
- Fixes FileNotFoundError when starting embedding server
- Resolves issue with passages file not found in different working directories
2025-07-25 16:09:53 -07:00
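A minimal sketch of the resolution logic this commit describes, assuming a metadata file whose stored passages path may be relative (`META` and `REL` are illustrative names, not the actual fields):

```bash
# Resolve the passages path against the metadata file's directory,
# not the process's current working directory.
META="knowledge.leann.meta.json"   # hypothetical metadata file
REL="passages.jsonl"               # hypothetical relative path stored inside it
ABS="$(cd "$(dirname "$META")" && pwd)/$REL"
echo "Resolved passages file: $ABS"
```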
Andy Lee
35310ddd52 fix: pure Python packages not building due to ubuntu-latest check
The build workflow was checking for matrix.os == 'ubuntu-latest',
but we changed the matrix to use 'ubuntu-22.04', causing the
pure Python packages (leann-core and leann) to never be built.

Changed to use pattern matching [[ == ubuntu-* ]] to match any
Ubuntu version.

This explains why v0.1.9 only published the C++ backend packages
but not the pure Python packages.
2025-07-25 15:14:21 -07:00
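The fix boils down to a glob-style match instead of an exact string compare; a sketch (the real step appears in `build-reusable.yml` below):

```bash
# Matches ubuntu-22.04, ubuntu-24.04, ubuntu-latest, ...
if [[ "$MATRIX_OS" == ubuntu-* ]]; then
  echo "Building pure Python packages on $MATRIX_OS"
fi
```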
Andy Lee
fc9c5cb39d fix: make release workflow idempotent
- Check if version is already updated before trying to update
- Check if tag already exists before creating
- Check if GitHub release already exists before creating
- This allows re-running the workflow after partial failures

Previously, if the workflow failed after updating version but before
completing the release, it couldn't be re-run with the same version.
2025-07-25 14:47:35 -07:00
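The idempotency pattern is to guard each step with an existence check, roughly as follows (the actual steps appear in the release workflow diff further down):

```bash
VERSION="0.1.10"   # illustrative
# Only create the tag if it doesn't exist yet.
if git rev-parse "v$VERSION" >/dev/null 2>&1; then
  echo "Tag v$VERSION already exists, skipping"
else
  git tag "v$VERSION" && git push origin "v$VERSION"
fi
# Only create the GitHub release if it doesn't exist yet.
gh release view "v$VERSION" >/dev/null 2>&1 || \
  gh release create "v$VERSION" --title "Release v$VERSION"
```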
Andy Lee
8f2a1e87ea Merge pull request #7 from yichuan-w/fix/simple-ubuntu22-build
fix: simplify build system for Colab compatibility
2025-07-25 14:08:37 -07:00
Andy Lee
50caf65f28 fix: change ubuntu-latest to ubuntu-22.04 and add Python 3.13
- Explicitly use ubuntu-22.04 instead of ubuntu-latest
- Add Python 3.13 to the build matrix
- This ensures we build on the same OS version as Google Colab
2025-07-25 13:48:59 -07:00
Andy Lee
1b48794ca8 cleanup: remove cibuildwheel workflow files
- Remove ci-cibuildwheel.yml and build-cibuildwheel.yml
- These files were not present in v0.1.5
- Keep only the simple build system
2025-07-25 13:48:08 -07:00
Andy Lee
4aef1d814e revert: simplify build system by removing manylinux/cibuildwheel
- Revert to simple Ubuntu 22.04 builds that should work with Colab
- Remove all manylinux container complexity
- Colab runs on Ubuntu 22.04, so direct builds should be compatible
- Restore build-reusable.yml to v0.1.5 version
- Remove cibuildwheel option from release workflow

This should fix the overcomplicated build issues while maintaining
Colab compatibility through direct Ubuntu 22.04 builds.
2025-07-25 13:46:51 -07:00
GitHub Actions
75ddcd6158 chore: release v0.1.9 2025-07-25 20:04:42 +00:00
Andy Lee
2a4df11f5c fix: absolute path for passages 2025-07-25 11:59:30 -07:00
Andy Lee
5eb893c62b ci: add Python 3.13 support to build matrix 2025-07-25 09:53:36 -07:00
yichuan520030910320
d91ce2e94d readme 2025-07-25 02:19:54 -07:00
yichuan520030910320
5c2ff8a641 clean research stuff 2025-07-25 02:14:15 -07:00
yichuan520030910320
d4f474c9b7 update broken link 2025-07-25 02:13:22 -07:00
yichuan520030910320
170f7644e9 simplify readme 2025-07-25 02:11:02 -07:00
yichuan520030910320
cd8b970eff Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-25 01:45:57 -07:00
yichuan520030910320
52153bbb69 update faiss compare 2025-07-25 01:45:50 -07:00
GitHub Actions
e1ae087207 chore: release v0.1.8 2025-07-25 08:24:40 +00:00
Andy Lee
48c5e12ac1 fix: use absolute path for passages_file to prevent FileNotFoundError
When embedding server is launched as a subprocess, it may run in a different
working directory. Using absolute paths ensures the server can always find
the metadata file regardless of where it's launched from.
2025-07-25 01:23:47 -07:00
yichuan520030910320
f8b5c97190 Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-25 00:37:33 -07:00
yichuan520030910320
d038c81b8b update benchmard section 2025-07-25 00:37:27 -07:00
Andy Lee
29cbbbd0d6 fix: resolve libzmq pkg-config issues in manylinux containers
- Add gcc-c++ and cmake to dependencies
- Create libzmq.pc file if missing (CentOS 7 issue)
- Set PKG_CONFIG_PATH through CIBW_ENVIRONMENT_LINUX
- Add protobuf-devel to ensure all headers are available
- Fix shell variable escaping in heredoc
2025-07-25 00:35:52 -07:00
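Hand-writing the missing pkg-config file looks roughly like this (a sketch; the paths and version are assumptions for a CentOS 7 manylinux image):

```bash
PC_DIR=/usr/lib64/pkgconfig
if ! pkg-config --exists libzmq; then
  # Minimal .pc file so builds that query pkg-config can find ZeroMQ.
  cat > "$PC_DIR/libzmq.pc" <<'EOF'
Name: libzmq
Description: 0MQ C++ library
Version: 4.1.4
Libs: -L/usr/lib64 -lzmq
Cflags: -I/usr/include
EOF
fi
export PKG_CONFIG_PATH="$PC_DIR:$PKG_CONFIG_PATH"
```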
Andy Lee
179f30bc36 fix: improve system dependency installation in manylinux containers
- Add yum cache cleaning and updating
- Make package installations more resilient with fallbacks
- Use pkgconfig instead of pkg-config (CentOS 7 naming)
- Handle optional packages that might not be available
- Add error handling for package installation failures
2025-07-25 00:30:29 -07:00
Andy Lee
c4a0a68581 fix: handle pure Python packages in cibuildwheel workflow
- Build pure Python packages (leann-core, leann) with standard build tool
- Only use cibuildwheel for C extension packages (leann-backend-hnsw, leann-backend-diskann)
- Build pure Python packages only once on ubuntu-latest
- Add Python setup for building pure packages
- Add package listing step for debugging
2025-07-25 00:26:15 -07:00
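The split amounts to one universal wheel for the pure packages and per-platform cibuildwheel runs only where there are C extensions. A sketch:

```bash
# Pure Python packages: a single wheel works everywhere, so build once.
for pkg in leann-core leann; do
  (cd "packages/$pkg" && python -m build)
done
# C extension packages: per-platform wheels via cibuildwheel.
for pkg in leann-backend-hnsw leann-backend-diskann; do
  (cd "packages/$pkg" && cibuildwheel --output-dir wheelhouse)
done
```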
Andy Lee
5c836ad08e fix: handle git dubious ownership error in manylinux containers
- Add multiple safe.directory configurations to cover different possible paths
- This fixes 'detected dubious ownership in repository' error
- Ensures git works properly in manylinux2014 containers
2025-07-25 00:22:01 -07:00
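The workaround is to whitelist the checkout path(s); the container paths here are assumptions:

```bash
# Git refuses to touch repos owned by another UID unless whitelisted.
git config --global --add safe.directory /project
git config --global --add safe.directory "$GITHUB_WORKSPACE"
git config --global --add safe.directory '*'   # broadest fallback
```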
Andy Lee
673fd9b7cd fix: upgrade to actions v4 and handle manylinux2014 compatibility
- Upgrade all GitHub Actions to v4 (v3 is deprecated)
- Use manual git checkout in manylinux2014 containers to avoid Node.js issues
- Update artifact naming to ensure uniqueness (required by v4)
- Add fail-fast: false to build strategies
- This maintains manylinux2014 compatibility while using latest actions
2025-07-25 00:20:21 -07:00
Andy Lee
84b24b233d feat: add cibuildwheel option to release workflow
- Add optional use_cibuildwheel parameter to release workflow
- Create separate CI workflow for testing cibuildwheel
- Support conditional build workflow selection in release process
- This allows building wheels compatible with Google Colab and older systems
- Maintains backward compatibility with existing build process
2025-07-25 00:16:08 -07:00
Andy Lee
499cdd7822 feat: add cibuildwheel workflow for better platform compatibility
- Use cibuildwheel for professional wheel building
- Specifically use manylinux2014 for Google Colab compatibility
- Supports Python 3.9-3.12 on Linux and macOS
- Handles monorepo structure with separate builds per package
- Includes basic import tests for each package
- This should resolve compatibility issues with older systems like Google Colab
2025-07-25 00:16:08 -07:00
yichuan520030910320
800d4cf111 Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-25 00:12:47 -07:00
yichuan520030910320
b6d43f5fd9 add gif 2025-07-25 00:12:35 -07:00
Andy Lee
3603cd5034 fix: downgrade GitHub Actions versions for manylinux2014 compatibility
- Use actions/checkout@v3 instead of v4 (Node.js 16 vs 20)
- Use actions/setup-python@v4 instead of v5
- Use actions/upload-artifact@v3 and download-artifact@v3
- This fixes GLIBC version errors in manylinux2014 containers
- manylinux2014 (CentOS 7) has glibc 2.17 but Node.js 20 needs 2.25+
2025-07-25 00:12:05 -07:00
Andy Lee
6df7893173 feat: use manylinux2014 containers for better Linux compatibility
- Add manylinux2014 Docker containers for Linux builds
- This will generate wheels compatible with older Linux systems (CentOS 7+, Ubuntu 16.04+)
- Separate build logic for container vs regular environments
- Install appropriate system dependencies for yum-based manylinux environment
- Use pip instead of uv in containers for better compatibility
- Fix Python version format for manylinux container paths
2025-07-25 00:08:42 -07:00
GitHub Actions
e64b599276 chore: release v0.1.7 2025-07-25 04:47:57 +00:00
Andy Lee
2dd59c4ba1 fix: let auditwheel auto-detect manylinux platform tag
- Remove --plat manylinux2014_x86_64 flag that was causing build failures
- Let auditwheel automatically determine the appropriate manylinux tag
- Add auditwheel show command to display compatibility info
- This fixes the 'too-recent versioned symbols' error
2025-07-24 21:44:15 -07:00
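The change drops the forced platform tag and lets auditwheel pick the highest tag the wheel actually satisfies (sketch):

```bash
# Before: a forced tag could fail with 'too-recent versioned symbols'
#   auditwheel repair dist/*.whl --plat manylinux2014_x86_64 -w dist_repaired
# After: let auditwheel choose the tag
auditwheel repair dist/*.whl -w dist_repaired
auditwheel show dist_repaired/*.whl   # print the resulting compatibility tag
```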
GitHub Actions
166986d5e6 chore: release v0.1.6 2025-07-25 04:30:07 +00:00
Andy Lee
a6aec68f32 fix: use manylinux2014 for better Linux compatibility
- Change auditwheel --plat to manylinux2014_x86_64
- This ensures wheels work on Ubuntu 16.04+ instead of requiring 24.04+
- Fixes compatibility issues for users on Ubuntu 22.04 and similar systems
2025-07-24 21:26:28 -07:00
GitHub Actions
ed27a127d5 chore: release v0.1.5 2025-07-25 04:00:54 +00:00
Andy Lee
d8b4ea7564 fix: add write permissions for GitHub Actions to push commits 2025-07-24 20:55:24 -07:00
Andy Lee
f0a2ef96b4 fix: restore complete build configuration from working version 2025-07-24 19:49:38 -07:00
Andy Lee
7d73c2c803 fix: remove invalid --extra build flag from build commands 2025-07-24 19:43:23 -07:00
Andy Lee
e8d2ecab03 refactor: use reusable workflow to avoid code duplication 2025-07-24 19:35:12 -07:00
Andy Lee
32a374d094 feat: true one-click automated release with multi-platform support 2025-07-24 19:30:44 -07:00
Andy Lee
d45c013806 fix: handle workflow trigger permission gracefully 2025-07-24 19:25:29 -07:00
GitHub Actions
9000a7083d chore: release v0.1.4 2025-07-25 02:23:36 +00:00
Andy Lee
8307555d54 fix: manually trigger CI after version push in release workflow 2025-07-24 19:21:32 -07:00
GitHub Actions
20f2aece08 chore: release v0.1.3 2025-07-25 02:05:11 +00:00
yichuan520030910320
43eb4f9a1d Merge branch 'main' of https://github.com/yichuan-w/LEANN 2025-07-24 19:03:52 -07:00
yichuan520030910320
5461b71d8c colab dev 2025-07-24 19:03:46 -07:00
Andy Lee
374db0ebb8 fix: release workflow to build new version before publishing 2025-07-24 19:03:09 -07:00
GitHub Actions
cea1f6f87c chore: release v0.1.2 2025-07-25 01:53:29 +00:00
Andy Lee
6c0e39372b fix: download all artifacts in release workflow 2025-07-24 17:45:46 -07:00
Andy Lee
2bec67d2b6 feat: auto-update leann-core dependencies during release
- Enhanced bump_version.sh to automatically update leann-core dependency versions
- Script now updates both package versions and their leann-core dependencies
- This ensures version consistency across all packages during release

No more manual dependency version updates needed!
2025-07-24 17:22:41 -07:00
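A sketch of what the enhanced script plausibly does; the sed expressions are illustrative, not the actual `bump_version.sh` contents:

```bash
NEW_VERSION="$1"
for toml in packages/*/pyproject.toml; do
  # Bump the package's own version...
  sed -i.bak -E "s/^version = \"[0-9][^\"]*\"/version = \"$NEW_VERSION\"/" "$toml"
  # ...and pin its leann-core dependency to the same release.
  sed -i.bak -E "s/leann-core==[0-9][^\"]*/leann-core==$NEW_VERSION/" "$toml"
done
```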
Andy Lee
133e715832 fix: resolve CI issues and consolidate workflows
- Fix version dependencies: update backend packages to depend on leann-core==0.1.1
- Remove duplicate ci.yml workflow (keeping build-and-publish.yml as main CI)
- Update release-manual.yml to reference correct CI workflow name

This fixes the dependency resolution error and eliminates duplicate builds.
2025-07-24 17:20:58 -07:00
Andy Lee
95cf2f16e2 refactor: consolidate release and publish into single workflow
- Manual Release workflow now directly publishes to PyPI after downloading CI artifacts
- No more duplicate builds - reuses artifacts from CI
- build-and-publish.yml renamed to 'CI - Build Multi-Platform Packages'
- Publishing in CI workflow only for emergency manual triggers
- Updated RELEASE.md to reflect the new streamlined process

This fixes the issue where releases would trigger redundant builds.
2025-07-24 17:04:47 -07:00
Andy Lee
47a4c153eb fix: enable PyPI publish on tag push
- Manual Release workflow creates tags but build-and-publish.yml only published on 'release' events
- Now build-and-publish.yml will also publish when v* tags are pushed
- This fixes the issue where manual releases didn't trigger PyPI uploads
2025-07-24 17:00:21 -07:00
121 changed files with 3672 additions and 13891 deletions


@@ -1,256 +1,11 @@
name: Build and Publish to PyPI
name: CI
on:
release:
types: [published]
push:
tags:
- 'v*'
workflow_dispatch:
inputs:
publish:
description: 'Publish to PyPI'
required: true
default: 'false'
type: choice
options:
- 'false'
- 'test'
- 'prod'
branches: [ main ]
pull_request:
branches: [ main ]
jobs:
# Build pure Python package: leann-core
build-core:
name: Build leann-core
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
- name: Install uv
uses: astral-sh/setup-uv@v4
- name: Install build dependencies
run: |
uv pip install --system build twine
- name: Build package
run: |
cd packages/leann-core
uv build
- name: Upload artifacts
uses: actions/upload-artifact@v4
with:
name: leann-core-dist
path: packages/leann-core/dist/
# Build binary package: leann-backend-hnsw (default backend)
build-hnsw:
name: Build leann-backend-hnsw
strategy:
matrix:
os: [ubuntu-latest, macos-latest]
python-version: ['3.9', '3.10', '3.11', '3.12']
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4
with:
submodules: recursive
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install uv
uses: astral-sh/setup-uv@v4
- name: Install system dependencies (Ubuntu)
if: runner.os == 'Linux'
run: |
sudo apt-get update
sudo apt-get install -y libomp-dev libboost-all-dev libzmq3-dev \
pkg-config libopenblas-dev patchelf
- name: Install system dependencies (macOS)
if: runner.os == 'macOS'
run: |
brew install libomp boost zeromq
- name: Install build dependencies
run: |
uv pip install --system scikit-build-core numpy swig
uv pip install --system auditwheel delocate
- name: Build wheel
run: |
cd packages/leann-backend-hnsw
uv build --wheel --python python
- name: Repair wheel (Linux)
if: runner.os == 'Linux'
run: |
cd packages/leann-backend-hnsw
auditwheel repair dist/*.whl -w dist_repaired
rm -rf dist
mv dist_repaired dist
- name: Repair wheel (macOS)
if: runner.os == 'macOS'
run: |
cd packages/leann-backend-hnsw
delocate-wheel -w dist_repaired -v dist/*.whl
rm -rf dist
mv dist_repaired dist
- name: Upload artifacts
uses: actions/upload-artifact@v4
with:
name: hnsw-${{ matrix.os }}-py${{ matrix.python-version }}
path: packages/leann-backend-hnsw/dist/
# Build binary package: leann-backend-diskann (multi-platform)
build-diskann:
name: Build leann-backend-diskann
strategy:
matrix:
os: [ubuntu-latest, macos-latest]
python-version: ['3.9', '3.10', '3.11', '3.12']
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4
with:
submodules: recursive
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install uv
uses: astral-sh/setup-uv@v4
- name: Install system dependencies (Ubuntu)
if: runner.os == 'Linux'
run: |
sudo apt-get update
sudo apt-get install -y libomp-dev libboost-all-dev libaio-dev libzmq3-dev \
protobuf-compiler libprotobuf-dev libabsl-dev patchelf
# Install Intel MKL using Intel's installer
wget https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
source /opt/intel/oneapi/setvars.sh
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/mkl/latest/lib/intel64:$LD_LIBRARY_PATH" >> $GITHUB_ENV
- name: Install system dependencies (macOS)
if: runner.os == 'macOS'
run: |
brew install libomp boost zeromq protobuf
# MKL is not available on Homebrew, but DiskANN can work without it
- name: Install build dependencies
run: |
uv pip install --system scikit-build-core numpy Cython pybind11
if [[ "$RUNNER_OS" == "Linux" ]]; then
uv pip install --system auditwheel
else
uv pip install --system delocate
fi
- name: Build wheel
run: |
cd packages/leann-backend-diskann
uv build --wheel --python python
- name: Repair wheel (Linux)
if: runner.os == 'Linux'
run: |
cd packages/leann-backend-diskann
auditwheel repair dist/*.whl -w dist_repaired
rm -rf dist
mv dist_repaired dist
- name: Repair wheel (macOS)
if: runner.os == 'macOS'
run: |
cd packages/leann-backend-diskann
delocate-wheel -w dist_repaired -v dist/*.whl
rm -rf dist
mv dist_repaired dist
- name: Upload artifacts
uses: actions/upload-artifact@v4
with:
name: diskann-${{ matrix.os }}-py${{ matrix.python-version }}
path: packages/leann-backend-diskann/dist/
# Build meta-package: leann (build last)
build-meta:
name: Build leann meta-package
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
- name: Install uv
uses: astral-sh/setup-uv@v4
- name: Install build dependencies
run: |
uv pip install --system build
- name: Build package
run: |
cd packages/leann
uv build
- name: Upload artifacts
uses: actions/upload-artifact@v4
with:
name: leann-meta-dist
path: packages/leann/dist/
# Publish to PyPI
publish:
name: Publish to PyPI
needs: [build-core, build-hnsw, build-diskann, build-meta]
runs-on: ubuntu-latest
if: github.event_name == 'release' || github.event.inputs.publish != 'false'
steps:
- name: Download all artifacts
uses: actions/download-artifact@v4
with:
path: dist
- name: Flatten directory structure
run: |
mkdir -p all_wheels
find dist -name "*.whl" -exec cp {} all_wheels/ \;
find dist -name "*.tar.gz" -exec cp {} all_wheels/ \;
- name: Publish to Test PyPI
if: github.event.inputs.publish == 'test' || github.event_name == 'workflow_dispatch'
uses: pypa/gh-action-pypi-publish@release/v1
with:
password: ${{ secrets.TEST_PYPI_API_TOKEN }}
repository-url: https://test.pypi.org/legacy/
packages-dir: all_wheels/
- name: Publish to PyPI
if: github.event_name == 'release' || github.event.inputs.publish == 'prod'
uses: pypa/gh-action-pypi-publish@release/v1
with:
password: ${{ secrets.PYPI_API_TOKEN }}
packages-dir: all_wheels/
build:
uses: ./.github/workflows/build-reusable.yml

.github/workflows/build-reusable.yml (new file)

@@ -0,0 +1,167 @@
name: Reusable Build
on:
workflow_call:
inputs:
ref:
description: 'Git ref to build'
required: false
type: string
default: ''
jobs:
build:
name: Build ${{ matrix.os }} Python ${{ matrix.python }}
strategy:
matrix:
include:
- os: ubuntu-22.04
python: '3.9'
- os: ubuntu-22.04
python: '3.10'
- os: ubuntu-22.04
python: '3.11'
- os: ubuntu-22.04
python: '3.12'
- os: ubuntu-22.04
python: '3.13'
- os: macos-latest
python: '3.9'
- os: macos-latest
python: '3.10'
- os: macos-latest
python: '3.11'
- os: macos-latest
python: '3.12'
- os: macos-latest
python: '3.13'
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4
with:
ref: ${{ inputs.ref }}
submodules: recursive
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python }}
- name: Install uv
uses: astral-sh/setup-uv@v4
- name: Install system dependencies (Ubuntu)
if: runner.os == 'Linux'
run: |
sudo apt-get update
sudo apt-get install -y libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
pkg-config libopenblas-dev patchelf libabsl-dev libaio-dev libprotobuf-dev
# Install Intel MKL for DiskANN
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
source /opt/intel/oneapi/setvars.sh
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/mkl/latest/lib/intel64:$LD_LIBRARY_PATH" >> $GITHUB_ENV
- name: Install system dependencies (macOS)
if: runner.os == 'macOS'
run: |
brew install llvm libomp boost protobuf zeromq
- name: Install build dependencies
run: |
uv pip install --system scikit-build-core numpy swig Cython pybind11
if [[ "$RUNNER_OS" == "Linux" ]]; then
uv pip install --system auditwheel
else
uv pip install --system delocate
fi
- name: Build packages
run: |
# Build core (platform independent)
if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
cd packages/leann-core
uv build
cd ../..
fi
# Build HNSW backend
cd packages/leann-backend-hnsw
if [ "${{ matrix.os }}" == "macos-latest" ]; then
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv build --wheel --python python
else
uv build --wheel --python python
fi
cd ../..
# Build DiskANN backend
cd packages/leann-backend-diskann
if [ "${{ matrix.os }}" == "macos-latest" ]; then
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv build --wheel --python python
else
uv build --wheel --python python
fi
cd ../..
# Build meta package (platform independent)
if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
cd packages/leann
uv build
cd ../..
fi
- name: Repair wheels (Linux)
if: runner.os == 'Linux'
run: |
# Repair HNSW wheel
cd packages/leann-backend-hnsw
if [ -d dist ]; then
auditwheel repair dist/*.whl -w dist_repaired
rm -rf dist
mv dist_repaired dist
fi
cd ../..
# Repair DiskANN wheel
cd packages/leann-backend-diskann
if [ -d dist ]; then
auditwheel repair dist/*.whl -w dist_repaired
rm -rf dist
mv dist_repaired dist
fi
cd ../..
- name: Repair wheels (macOS)
if: runner.os == 'macOS'
run: |
# Repair HNSW wheel
cd packages/leann-backend-hnsw
if [ -d dist ]; then
delocate-wheel -w dist_repaired -v dist/*.whl
rm -rf dist
mv dist_repaired dist
fi
cd ../..
# Repair DiskANN wheel
cd packages/leann-backend-diskann
if [ -d dist ]; then
delocate-wheel -w dist_repaired -v dist/*.whl
rm -rf dist
mv dist_repaired dist
fi
cd ../..
- name: List built packages
run: |
echo "📦 Built packages:"
find packages/*/dist -name "*.whl" -o -name "*.tar.gz" | sort
- name: Upload artifacts
uses: actions/upload-artifact@v4
with:
name: packages-${{ matrix.os }}-py${{ matrix.python }}
path: packages/*/dist/


@@ -1,110 +0,0 @@
name: CI - Build and Test
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
jobs:
build-test:
strategy:
matrix:
os: [ubuntu-latest, macos-latest]
python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4
with:
submodules: recursive
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
echo "$HOME/.cargo/bin" >> $GITHUB_PATH
- name: Install system dependencies (Ubuntu)
if: runner.os == 'Linux'
run: |
sudo apt-get update
sudo apt-get install -y libomp-dev libboost-all-dev libzmq3-dev \
pkg-config libopenblas-dev patchelf \
libaio-dev protobuf-compiler libprotobuf-dev libabsl-dev
# Install Intel MKL for DiskANN
wget https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
source /opt/intel/oneapi/setvars.sh
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/mkl/latest/lib/intel64:$LD_LIBRARY_PATH" >> $GITHUB_ENV
- name: Install system dependencies (macOS)
if: runner.os == 'macOS'
run: |
brew install libomp boost zeromq protobuf
- name: Build all packages
run: |
echo "🔨 Building on ${{ matrix.os }} with Python ${{ matrix.python-version }}..."
export UV_SYSTEM_PYTHON=1
# Verify Python version
python --version
which python
# Build each package
for pkg in leann-core leann-backend-hnsw leann-backend-diskann leann; do
echo "Building $pkg..."
cd packages/$pkg
rm -rf dist/ build/ _skbuild/
# Use explicit python interpreter
uv build --wheel --python python
if [ ! -f dist/*.whl ]; then
echo "❌ Failed to build $pkg!"
exit 1
fi
echo "✅ $pkg built successfully"
cd ../..
done
- name: Install and test packages
run: |
# Create clean test environment
python -m venv test_env
if [[ "$OSTYPE" == "msys" || "$OSTYPE" == "win32" ]]; then
source test_env/Scripts/activate
else
source test_env/bin/activate
fi
# Install built packages
pip install packages/*/dist/*.whl
# Basic import test
python -c "import leann; print('✅ LEANN imported successfully')"
python -c "import leann_backend_hnsw; print('✅ HNSW backend imported')"
python -c "import leann_backend_diskann; print('✅ DiskANN backend imported')"
- name: Upload artifacts
uses: actions/upload-artifact@v4
with:
name: wheels-${{ matrix.os }}-py${{ matrix.python-version }}
path: packages/*/dist/*.whl
retention-days: 7
# Summary job to ensure all builds pass
ci-success:
needs: build-test
runs-on: ubuntu-latest
steps:
- name: CI Success
run: |
echo "✅ All CI builds passed!"
echo "Ready for manual release when needed."


@@ -1,194 +1,126 @@
name: Manual Release
name: Release
on:
workflow_dispatch:
inputs:
version:
description: 'Version to release (e.g., 0.1.1)'
description: 'Version to release (e.g., 0.1.2)'
required: true
type: string
test_pypi:
description: 'Test on TestPyPI first'
required: false
type: boolean
default: true
jobs:
validate-and-release:
update-version:
name: Update Version
runs-on: ubuntu-latest
permissions:
contents: write
actions: read
outputs:
commit-sha: ${{ steps.push.outputs.commit-sha }}
steps:
- uses: actions/checkout@v4
- name: Validate version
run: |
if ! [[ "${{ inputs.version }}" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
echo "❌ Invalid version format"
exit 1
fi
echo "✅ Version format valid"
- name: Update versions and push
id: push
run: |
# Check current version
CURRENT_VERSION=$(grep "^version" packages/leann-core/pyproject.toml | cut -d'"' -f2)
echo "Current version: $CURRENT_VERSION"
echo "Target version: ${{ inputs.version }}"
if [ "$CURRENT_VERSION" = "${{ inputs.version }}" ]; then
echo "⚠️ Version is already ${{ inputs.version }}, skipping update"
COMMIT_SHA=$(git rev-parse HEAD)
else
./scripts/bump_version.sh ${{ inputs.version }}
git config user.name "GitHub Actions"
git config user.email "actions@github.com"
git add packages/*/pyproject.toml
git commit -m "chore: release v${{ inputs.version }}"
git push origin main
COMMIT_SHA=$(git rev-parse HEAD)
echo "✅ Pushed version update: $COMMIT_SHA"
fi
echo "commit-sha=$COMMIT_SHA" >> $GITHUB_OUTPUT
build-packages:
name: Build packages
needs: update-version
uses: ./.github/workflows/build-reusable.yml
with:
ref: 'main'
publish:
name: Publish and Release
needs: [update-version, build-packages]
if: always() && needs.update-version.result == 'success' && needs.build-packages.result == 'success'
runs-on: ubuntu-latest
permissions:
contents: write
steps:
- uses: actions/checkout@v4
with:
token: ${{ secrets.GITHUB_TOKEN }}
ref: 'main'
- name: Check CI status
run: |
echo " This workflow will download build artifacts from the latest CI run."
echo " CI must have completed successfully on the current commit."
echo ""
- name: Validate version format
run: |
if ! [[ "${{ inputs.version }}" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
echo "❌ Invalid version format. Use semantic versioning (e.g., 0.1.1)"
exit 1
fi
echo "✅ Version format valid: ${{ inputs.version }}"
- name: Check if version already exists
run: |
if git tag | grep -q "^v${{ inputs.version }}$"; then
echo "❌ Version v${{ inputs.version }} already exists!"
exit 1
fi
echo "✅ Version is new"
- name: Set up Python
uses: actions/setup-python@v5
- name: Download all artifacts
uses: actions/download-artifact@v4
with:
python-version: '3.13'
path: dist-artifacts
- name: Install uv
- name: Collect packages
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
echo "$HOME/.cargo/bin" >> $GITHUB_PATH
mkdir -p dist
find dist-artifacts -name "*.whl" -exec cp {} dist/ \;
find dist-artifacts -name "*.tar.gz" -exec cp {} dist/ \;
echo "📦 Packages to publish:"
ls -la dist/
- name: Update versions
run: |
./scripts/bump_version.sh ${{ inputs.version }}
git config user.name "GitHub Actions"
git config user.email "actions@github.com"
git add packages/*/pyproject.toml
git commit -m "chore: release v${{ inputs.version }}"
- name: Get CI run ID
id: get-ci-run
run: |
# Get the latest successful CI run on the previous commit (before version bump)
COMMIT_SHA=$(git rev-parse HEAD~1)
RUN_ID=$(gh run list \
--workflow="CI - Build and Test" \
--status=success \
--commit=$COMMIT_SHA \
--json databaseId \
--jq '.[0].databaseId')
if [ -z "$RUN_ID" ]; then
echo "❌ No successful CI run found for commit $COMMIT_SHA"
echo ""
echo "This usually means:"
echo "1. CI hasn't run on the latest commit yet"
echo "2. CI failed on the latest commit"
echo ""
echo "Please ensure CI passes on main branch before releasing."
exit 1
fi
echo "✅ Found CI run: $RUN_ID"
echo "run-id=$RUN_ID" >> $GITHUB_OUTPUT
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Download artifacts from CI
run: |
echo "📦 Downloading artifacts from CI run ${{ steps.get-ci-run.outputs.run-id }}..."
# Download all wheel artifacts
gh run download ${{ steps.get-ci-run.outputs.run-id }} \
--pattern "wheels-*" \
--dir ./dist-downloads
# Consolidate all wheels into packages/*/dist/
mkdir -p packages/leann-core/dist
mkdir -p packages/leann-backend-hnsw/dist
mkdir -p packages/leann-backend-diskann/dist
mkdir -p packages/leann/dist
find ./dist-downloads -name "*.whl" -exec cp {} ./packages/ \;
# Move wheels to correct package directories
for wheel in packages/*.whl; do
if [[ $wheel == *"leann_core"* ]]; then
mv "$wheel" packages/leann-core/dist/
elif [[ $wheel == *"leann_backend_hnsw"* ]]; then
mv "$wheel" packages/leann-backend-hnsw/dist/
elif [[ $wheel == *"leann_backend_diskann"* ]]; then
mv "$wheel" packages/leann-backend-diskann/dist/
elif [[ $wheel == *"leann-"* ]] && [[ $wheel != *"backend"* ]] && [[ $wheel != *"core"* ]]; then
mv "$wheel" packages/leann/dist/
fi
done
# List downloaded wheels
echo "✅ Downloaded wheels:"
find packages/*/dist -name "*.whl" -type f | sort
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Test on TestPyPI (optional)
if: inputs.test_pypi
continue-on-error: true
- name: Publish to PyPI
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.TEST_PYPI_API_TOKEN }}
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
run: |
if [ -z "$TWINE_PASSWORD" ]; then
echo "⚠️ TEST_PYPI_API_TOKEN not configured, skipping TestPyPI upload"
echo " To enable TestPyPI testing, add TEST_PYPI_API_TOKEN to repository secrets"
exit 0
echo "PYPI_API_TOKEN not configured!"
exit 1
fi
pip install twine
echo "📦 Uploading to TestPyPI..."
twine upload --repository testpypi packages/*/dist/* --verbose || {
echo "⚠️ TestPyPI upload failed, but continuing with release"
echo " This is optional and won't block the release"
exit 0
}
echo "✅ Test upload successful!"
echo "📋 Check packages at: https://test.pypi.org/user/your-username/"
echo ""
echo "To test installation:"
echo "pip install -i https://test.pypi.org/simple/ leann"
twine upload dist/* --skip-existing --verbose
echo "✅ Published to PyPI!"
- name: Create and push tag
- name: Create release
run: |
git tag "v${{ inputs.version }}"
git push origin main
git push origin "v${{ inputs.version }}"
echo "✅ Tag v${{ inputs.version }} created and pushed"
- name: Create GitHub Release
uses: softprops/action-gh-release@v1
with:
tag_name: v${{ inputs.version }}
name: Release v${{ inputs.version }}
body: |
## 🚀 Release v${{ inputs.version }}
### What's Changed
See the [full changelog](https://github.com/${{ github.repository }}/compare/...v${{ inputs.version }})
### Installation
```bash
pip install leann==${{ inputs.version }}
```
### Test Installation (if using TestPyPI)
```bash
pip install -i https://test.pypi.org/simple/ leann==${{ inputs.version }}
```
draft: false
prerelease: false
# Check if tag already exists
if git rev-parse "v${{ inputs.version }}" >/dev/null 2>&1; then
echo "⚠️ Tag v${{ inputs.version }} already exists, skipping tag creation"
else
git tag "v${{ inputs.version }}"
git push origin "v${{ inputs.version }}"
echo "✅ Created and pushed tag v${{ inputs.version }}"
fi
# Check if release already exists
if gh release view "v${{ inputs.version }}" >/dev/null 2>&1; then
echo "⚠️ Release v${{ inputs.version }} already exists, skipping release creation"
else
gh release create "v${{ inputs.version }}" \
--title "Release v${{ inputs.version }}" \
--notes "🚀 Released to PyPI: https://pypi.org/project/leann/${{ inputs.version }}/" \
--latest
echo "✅ Created GitHub release v${{ inputs.version }}"
fi
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Trigger PyPI publish
run: |
echo "🚀 Triggering PyPI publish workflow..."
# The existing build-and-publish.yml will be triggered by the tag push
echo "✅ Release process completed! The publish workflow will run automatically."
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}

README.md

@@ -26,7 +26,7 @@ LEANN achieves this through *graph-based selective recomputation* with *high-deg
<img src="assets/effects.png" alt="LEANN vs Traditional Vector DB Storage Comparison" width="70%">
</p>
> **The numbers speak for themselves:** Index 60 million Wikipedia chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#storage-usage-comparison)
> **The numbers speak for themselves:** Index 60 million Wikipedia chunks in just 6GB instead of 201GB. From emails to browser history, everything fits on your laptop. [See detailed benchmarks for different applications below ↓](#storage-comparison)
🔒 **Privacy:** Your data never leaves your laptop. No OpenAI, no cloud, no "terms of service".
@@ -38,7 +38,7 @@ LEANN achieves this through *graph-based selective recomputation* with *high-deg
**No Accuracy Loss:** Maintain the same search quality as heavyweight solutions while using 97% less storage.
## Installation
> `pip leann` coming soon!
```bash
git clone git@github.com:yichuan-w/LEANN.git leann
cd leann
@@ -94,7 +94,7 @@ ollama pull llama3.2:1b
## Quick Start in 30s
Our declarative API makes RAG as easy as writing a config file.
[Try in this ipynb file →](demo.ipynb)
[Try in this ipynb file →](demo.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/yichuan-w/LEANN/blob/main/demo.ipynb)
```python
from leann.api import LeannBuilder, LeannSearcher, LeannChat
@@ -133,6 +133,10 @@ LEANN supports RAG on various data sources including documents (.pdf, .txt, .md)
Ask questions directly about your personal PDFs, documents, and any directory containing your files!
<p align="center">
<img src="videos/paper_clear.gif" alt="LEANN Document Search Demo" width="600">
</p>
The example below asks a question about summarizing two papers (uses default data in `examples/data`):
```bash
@@ -150,6 +154,10 @@ python ./examples/main_cli_example.py
### 📧 Your Personal Email Secretary: RAG on Apple Mail!
<p align="center">
<img src="videos/mail_clear.gif" alt="LEANN Email Search Demo" width="600">
</p>
**Note:** You need to grant full disk access to your terminal/VS Code in System Preferences → Privacy & Security → Full Disk Access.
```bash
python examples/mail_reader_leann.py --query "What's the food I ordered by doordash or Uber eat mostly?"
@@ -187,7 +195,12 @@ Once the index is built, you can ask questions like:
- "Show me emails about travel expenses"
</details>
### 🔍 Time Machine for the Web: RAG Your Entire Google Browser History!
### 🔍 Time Machine for the Web: RAG Your Entire Chrome Browser History!
<p align="center">
<img src="videos/google_clear.gif" alt="LEANN Browser History Search Demo" width="600">
</p>
```bash
python examples/google_history_reader_leann.py --query "Tell me my browser history about machine learning?"
```
@@ -242,6 +255,10 @@ Once the index is built, you can ask questions like:
### 💬 WeChat Detective: Unlock Your Golden Memories!
<p align="center">
<img src="videos/wechat_clear.gif" alt="LEANN WeChat Search Demo" width="600">
</p>
```bash
python examples/wechat_history_reader_leann.py --query "Show me all group chats about weekend plans"
```
@@ -383,46 +400,18 @@ Options:
## Benchmarks
Run the comparison yourself:
```bash
python examples/compare_faiss_vs_leann.py
```
| System | Storage |
|--------|---------|
| FAISS HNSW | 5.5 MB |
| LEANN | 0.5 MB |
| **Savings** | **91%** |
📊 **[Simple Example: Compare LEANN vs FAISS →](examples/compare_faiss_vs_leann.py)**
### Storage Comparison
Same dataset, same hardware, same embedding model. LEANN just works better.
| System | DPR (2.1M) | Wiki (60M) | Chat (400K) | Email (780K) | Browser (38K) |
|--------|-------------|------------|-------------|--------------|---------------|
| Traditional vector database (e.g., FAISS) | 3.8 GB | 201 GB | 1.8 GB | 2.4 GB | 130 MB |
| LEANN | 324 MB | 6 GB | 64 MB | 79 MB | 6.4 MB |
| Savings| 91% | 97% | 97% | 97% | 95% |
### Storage Usage Comparison
| System | DPR (2.1M chunks) | RPJ-wiki (60M chunks) | Chat history (400K messages) | Apple emails (780K message chunks) | Google Search History (38K entries) |
|--------|-------------------|-----------------------|------------------------------|------------------------------------|-------------------------------------|
| Traditional Vector DB (FAISS) | 3.8 GB | 201 GB | 1.8 GB | 2.4 GB | 130.4 MB |
| **LEANN** | **324 MB** | **6 GB** | **64 MB** | **79 MB** | **6.4 MB** |
| **Reduction** | **91% smaller** | **97% smaller** | **97% smaller** | **97% smaller** | **95% smaller** |
<!-- ### Memory Usage Comparison
| System j | DPR(2M docs) | RPJ-wiki(60M docs) | Chat history() |
| --------------------- | ---------------- | ---------------- | ---------------- |
| Traditional Vector DB(LLamaindex faiss) | x GB | x GB | x GB |
| **Leann** | **xx MB** | **x GB** | **x GB** |
| **Reduction** | **x%** | **x%** | **x%** |
### Query Performance of LEANN
| Backend | Index Size | Query Time | Recall@3 |
| ------------------- | ---------- | ---------- | --------- |
| DiskANN | 1M docs | xms | 0.95 |
| HNSW | 1M docs | xms | 0.95 | -->
*Benchmarks run on Apple M3 Pro 36 GB*
## Reproduce Our Results
```bash
@@ -450,98 +439,15 @@ If you find Leann useful, please cite:
}
```
## ✨ Features
## ✨ [Detailed Features →](docs/features.md)
### 🔥 Core Features
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage by computing embeddings on demand, using optimized ZMQ servers and an overlapped, batched search paradigm on top of a highly optimized embedding engine
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
- **🏗️ Pluggable Backends** - DiskANN, HNSW/FAISS with unified API
### 🛠️ Technical Highlights
- **🔄 Recompute Mode** - Highest accuracy scenarios while eliminating vector storage overhead
- **⚡ Zero-copy Operations** - Minimize IPC overhead by transferring distances instead of embeddings
- **🚀 High-throughput Embedding Pipeline** - Optimized batched processing for maximum efficiency
- **🎯 Two-level Search** - Novel coarse-to-fine search overlap for accelerated query processing (optional)
- **💾 Memory-mapped Indices** - Fast startup with raw text mapping to reduce memory overhead
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](test/build_mlx_index.py))
### 🎨 Developer Experience
- **Simple Python API** - Get started in minutes
- **Extensible backend system** - Easy to add new algorithms
- **Comprehensive examples** - From basic usage to production deployment
## 🤝 Contributing
We welcome contributions! Leann is built by the community, for the community.
### Ways to Contribute
- 🐛 **Bug Reports**: Found an issue? Let us know!
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
- 🔧 **Code Contributions**: PRs welcome for all skill levels
- 📖 **Documentation**: Help make Leann more accessible
- 🧪 **Benchmarks**: Share your performance results
## 🤝 [Contributing →](docs/contributing.md)
<!-- ## FAQ
### Common Issues
#### NCCL Topology Error
**Problem**: You encounter `ncclTopoComputePaths` error during document processing:
```
ncclTopoComputePaths (system=<optimized out>, comm=comm@entry=0x5555a82fa3c0) at graph/paths.cc:688
```
**Solution**: Set these environment variables before running your script:
```bash
export NCCL_TOPO_DUMP_FILE=/tmp/nccl_topo.xml
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=INIT,GRAPH
export NCCL_IB_DISABLE=1
export NCCL_NET_PLUGIN=none
export NCCL_SOCKET_IFNAME=ens5
``` -->
## FAQ
### 1. My building time seems long
You can speed up the process by using a lightweight embedding model. Add this to your arguments:
```bash
--embedding-model sentence-transformers/all-MiniLM-L6-v2
```
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)
## [FAQ →](docs/faq.md)
## 📈 Roadmap
### 🎯 Q2 2025
- [X] DiskANN backend with MIPS/L2/Cosine support
- [X] HNSW backend integration
- [X] Real-time embedding pipeline
- [X] Memory-efficient graph pruning
### 🚀 Q3 2025
- [ ] Advanced caching strategies
- [ ] Add contextual retrieval (https://www.anthropic.com/news/contextual-retrieval)
- [ ] Add sleep-time compute and a summarize agent to summarize the files on your computer
- [ ] Add OpenAI recompute API
### 🌟 Q4 2025
- [ ] Integration with LangChain/LlamaIndex
- [ ] Visual similarity search
- [ ] Query rewriting, reranking, and expansion
## 📈 [Roadmap →](docs/roadmap.md)
## 📄 License
@@ -549,11 +455,7 @@ MIT License - see [LICENSE](LICENSE) for details.
## 🙏 Acknowledgments
- **Microsoft Research** for the DiskANN algorithm
- **Meta AI** for FAISS and optimization insights
- **HuggingFace** for the transformer ecosystem
- **Our amazing contributors** who make this possible
This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/)
---
<p align="center">


@@ -13,8 +13,12 @@
"metadata": {},
"outputs": [],
"source": [
"# install this if you areusing colab\n",
"! pip install leann"
"# install this if you are using colab\n",
"! pip install leann\n",
"\n",
"# For Colab environment, we need to set some environment variables\n",
"import os\n",
"os.environ['LEANN_LOG_LEVEL'] = 'INFO' # Enable more detailed logging"
]
},
{
@@ -26,81 +30,9 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO: Registering backend 'hnsw'\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/yichuan/Desktop/code/LEANN/leann/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: facebook/contriever\n",
"WARNING:sentence_transformers.SentenceTransformer:No sentence-transformers model found with name facebook/contriever. Creating a new one with mean pooling.\n",
"Writing passages: 100%|██████████| 5/5 [00:00<00:00, 27887.66chunk/s]\n",
"Batches: 100%|██████████| 1/1 [00:00<00:00, 13.51it/s]\n",
"WARNING:leann_backend_hnsw.hnsw_backend:Converting data to float32, shape: (5, 768)\n",
"INFO:leann_backend_hnsw.hnsw_backend:INFO: Converting HNSW index to CSR-pruned format...\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"M: 64 for level: 0\n",
"Starting conversion: knowledge.index -> knowledge.csr.tmp\n",
"[0.00s] Reading Index HNSW header...\n",
"[0.00s] Header read: d=768, ntotal=5\n",
"[0.00s] Reading HNSW struct vectors...\n",
" Reading vector (dtype=<class 'numpy.float64'>, fmt='d')... Count=6, Bytes=48\n",
"[0.00s] Read assign_probas (6)\n",
" Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=7, Bytes=28\n",
"[0.11s] Read cum_nneighbor_per_level (7)\n",
" Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=5, Bytes=20\n",
"[0.21s] Read levels (5)\n",
"[0.30s] Probing for compact storage flag...\n",
"[0.30s] Found compact flag: False\n",
"[0.30s] Compact flag is False, reading original format...\n",
"[0.30s] Probing for potential extra byte before non-compact offsets...\n",
"[0.30s] Found and consumed an unexpected 0x00 byte.\n",
" Reading vector (dtype=<class 'numpy.uint64'>, fmt='Q')... Count=6, Bytes=48\n",
"[0.30s] Read offsets (6)\n",
"[0.40s] Attempting to read neighbors vector...\n",
" Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=320, Bytes=1280\n",
"[0.40s] Read neighbors (320)\n",
"[0.50s] Read scalar params (ep=4, max_lvl=0)\n",
"[0.50s] Checking for storage data...\n",
"[0.50s] Found storage fourcc: 49467849.\n",
"[0.50s] Converting to CSR format...\n",
"[0.50s] Conversion loop finished. \n",
"[0.50s] Running validation checks...\n",
" Checking total valid neighbor count...\n",
" OK: Total valid neighbors = 20\n",
" Checking final pointer indices...\n",
" OK: Final pointers match data size.\n",
"[0.50s] Deleting original neighbors and offsets arrays...\n",
" CSR Stats: |data|=20, |level_ptr|=10\n",
"[0.59s] Writing CSR HNSW graph data in FAISS-compatible order...\n",
" Pruning embeddings: Writing NULL storage marker.\n",
"[0.69s] Conversion complete.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:leann_backend_hnsw.hnsw_backend:✅ CSR conversion successful.\n",
"INFO:leann_backend_hnsw.hnsw_backend:INFO: Replaced original index with CSR-pruned version at 'knowledge.index'\n"
]
}
],
"outputs": [],
"source": [
"from leann.api import LeannBuilder\n",
"\n",
@@ -122,93 +54,9 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:leann.api:🔍 LeannSearcher.search() called:\n",
"INFO:leann.api: Query: 'programming languages'\n",
"INFO:leann.api: Top_k: 2\n",
"INFO:leann.api: Additional kwargs: {}\n",
"INFO:leann.embedding_server_manager:Port 5557 has incompatible server, trying next port...\n",
"INFO:leann.embedding_server_manager:Port 5558 has incompatible server, trying next port...\n",
"INFO:leann.embedding_server_manager:Port 5559 has incompatible server, trying next port...\n",
"INFO:leann.embedding_server_manager:Using port 5560 instead of 5557\n",
"INFO:leann.embedding_server_manager:Starting embedding server on port 5560...\n",
"INFO:leann.embedding_server_manager:Command: /Users/yichuan/Desktop/code/LEANN/leann/.venv/bin/python -m leann_backend_hnsw.hnsw_embedding_server --zmq-port 5560 --model-name facebook/contriever --passages-file knowledge.leann.meta.json\n",
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
"INFO:leann.embedding_server_manager:Server process started with PID: 4574\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[read_HNSW - CSR NL v4] Reading metadata & CSR indices (manual offset)...\n",
"[read_HNSW NL v4] Read levels vector, size: 5\n",
"[read_HNSW NL v4] Reading Compact Storage format indices...\n",
"[read_HNSW NL v4] Read compact_level_ptr, size: 10\n",
"[read_HNSW NL v4] Read compact_node_offsets, size: 6\n",
"[read_HNSW NL v4] Read entry_point: 4, max_level: 0\n",
"[read_HNSW NL v4] Read storage fourcc: 0x6c6c756e\n",
"[read_HNSW NL v4 FIX] Detected FileIOReader. Neighbors size field offset: 326\n",
"[read_HNSW NL v4] Reading neighbors data into memory.\n",
"[read_HNSW NL v4] Read neighbors data, size: 20\n",
"[read_HNSW NL v4] Finished reading metadata and CSR indices.\n",
"INFO: Skipping external storage loading, since is_recompute is true.\n",
"INFO: Registering backend 'hnsw'\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:leann.embedding_server_manager:Embedding server is ready!\n",
"INFO:leann.api: Launching server time: 1.078078269958496 seconds\n",
"INFO:leann.embedding_server_manager:Existing server process (PID 4574) is compatible\n",
"INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: facebook/contriever\n",
"WARNING:sentence_transformers.SentenceTransformer:No sentence-transformers model found with name facebook/contriever. Creating a new one with mean pooling.\n",
"INFO:leann.api: Generated embedding shape: (1, 768)\n",
"INFO:leann.api: Embedding time: 2.9307072162628174 seconds\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"ZmqDistanceComputer initialized: d=768, metric=0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:leann.api: Search time: 0.27327895164489746 seconds\n",
"INFO:leann.api: Backend returned: labels=2 results\n",
"INFO:leann.api: Processing 2 passage IDs:\n",
"INFO:leann.api: 1. passage_id='0' -> SUCCESS: C# is a powerful programming language and it is good at game development...\n",
"INFO:leann.api: 2. passage_id='1' -> SUCCESS: Python is a powerful programming language and it is good at machine learning tasks...\n",
"INFO:leann.api: Final enriched results: 2 passages\n"
]
},
{
"data": {
"text/plain": [
"[SearchResult(id='0', score=np.float32(0.9874103), text='C# is a powerful programming language and it is good at game development', metadata={}),\n",
" SearchResult(id='1', score=np.float32(0.8922168), text='Python is a powerful programming language and it is good at machine learning tasks', metadata={})]"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"from leann.api import LeannSearcher\n",
"\n",
@@ -228,79 +76,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:leann.chat:Attempting to create LLM of type='hf' with model='Qwen/Qwen3-0.6B'\n",
"INFO:leann.chat:Initializing HFChat with model='Qwen/Qwen3-0.6B'\n",
"INFO:leann.chat:MPS is available. Using Apple Silicon GPU.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[read_HNSW - CSR NL v4] Reading metadata & CSR indices (manual offset)...\n",
"[read_HNSW NL v4] Read levels vector, size: 5\n",
"[read_HNSW NL v4] Reading Compact Storage format indices...\n",
"[read_HNSW NL v4] Read compact_level_ptr, size: 10\n",
"[read_HNSW NL v4] Read compact_node_offsets, size: 6\n",
"[read_HNSW NL v4] Read entry_point: 4, max_level: 0\n",
"[read_HNSW NL v4] Read storage fourcc: 0x6c6c756e\n",
"[read_HNSW NL v4 FIX] Detected FileIOReader. Neighbors size field offset: 326\n",
"[read_HNSW NL v4] Reading neighbors data into memory.\n",
"[read_HNSW NL v4] Read neighbors data, size: 20\n",
"[read_HNSW NL v4] Finished reading metadata and CSR indices.\n",
"INFO: Skipping external storage loading, since is_recompute is true.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:leann.api:🔍 LeannSearcher.search() called:\n",
"INFO:leann.api: Query: 'Compare the two retrieved programming languages and tell me their advantages.'\n",
"INFO:leann.api: Top_k: 2\n",
"INFO:leann.api: Additional kwargs: {}\n",
"INFO:leann.embedding_server_manager:Port 5557 has incompatible server, trying next port...\n",
"INFO:leann.embedding_server_manager:Port 5558 has incompatible server, trying next port...\n",
"INFO:leann.embedding_server_manager:Port 5559 has incompatible server, trying next port...\n",
"INFO:leann.embedding_server_manager:Found compatible server on port 5560\n",
"INFO:leann.embedding_server_manager:Using existing compatible server on port 5560\n",
"INFO:leann.api: Launching server time: 0.04932403564453125 seconds\n",
"INFO:leann.embedding_server_manager:Found compatible server on port 5560\n",
"INFO:leann.embedding_server_manager:Using existing compatible server on port 5560\n",
"INFO:leann.api: Generated embedding shape: (1, 768)\n",
"INFO:leann.api: Embedding time: 0.06902289390563965 seconds\n",
"INFO:leann.api: Search time: 0.026793241500854492 seconds\n",
"INFO:leann.api: Backend returned: labels=2 results\n",
"INFO:leann.api: Processing 2 passage IDs:\n",
"INFO:leann.api: 1. passage_id='0' -> SUCCESS: C# is a powerful programming language and it is good at game development...\n",
"INFO:leann.api: 2. passage_id='1' -> SUCCESS: Python is a powerful programming language and it is good at machine learning tasks...\n",
"INFO:leann.api: Final enriched results: 2 passages\n",
"INFO:leann.chat:Generating with HuggingFace model, config: {'max_new_tokens': 128, 'temperature': 0.7, 'top_p': 0.9, 'do_sample': True, 'pad_token_id': 151645, 'eos_token_id': 151645}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"ZmqDistanceComputer initialized: d=768, metric=0\n"
]
},
{
"data": {
"text/plain": [
"\"<think>\\n\\n</think>\\n\\nBased on the context provided, here's a comparison of the two retrieved programming languages:\\n\\n**C#** is known for being a powerful programming language and is well-suited for game development. It is often used in game development and is popular among developers working on Windows applications.\\n\\n**Python**, on the other hand, is also a powerful language and is well-suited for machine learning tasks. It is widely used for data analysis, scientific computing, and other applications that require handling large datasets or performing complex calculations.\\n\\n**Advantages**:\\n- C#: Strong for game development and cross-platform compatibility.\\n- Python: Strong for\""
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"from leann.api import LeannChat\n",
"\n",


@@ -1,93 +1,22 @@
# Release Guide
## 📋 Prerequisites
## Setup (One-time)
Before releasing, ensure:
1. ✅ All code changes are committed and pushed
2. ✅ CI has passed on the latest commit (check [Actions](https://github.com/yichuan-w/LEANN/actions/workflows/ci.yml))
3. ✅ You have determined the new version number
Add `PYPI_API_TOKEN` to GitHub Secrets:
1. Get token: https://pypi.org/manage/account/token/
2. Add to secrets: Settings → Secrets → Actions → `PYPI_API_TOKEN`
### Optional: TestPyPI Configuration
## Release (One-click)
To enable TestPyPI testing (recommended but not required):
1. Get a TestPyPI API token from https://test.pypi.org/manage/account/token/
2. Add it to repository secrets: Settings → Secrets → Actions → New repository secret
- Name: `TEST_PYPI_API_TOKEN`
- Value: Your TestPyPI token (starts with `pypi-`)
1. Go to: https://github.com/yichuan-w/LEANN/actions/workflows/release-manual.yml
2. Click "Run workflow"
3. Enter version: `0.1.2`
4. Click green "Run workflow" button
**Note**: TestPyPI testing is optional. If not configured, the release will skip TestPyPI and proceed.
That's it! The workflow will automatically:
- ✅ Update version in all packages
- ✅ Build all packages
- ✅ Publish to PyPI
- ✅ Create GitHub tag and release
## 🚀 Recommended: Manual Release Workflow
### Via GitHub UI (Most Reliable)
1. **Verify CI Status**: Check that the latest commit has a green checkmark ✅
2. Go to [Actions → Manual Release](https://github.com/yichuan-w/LEANN/actions/workflows/release-manual.yml)
3. Click "Run workflow"
4. Enter version (e.g., `0.1.1`)
5. Toggle "Test on TestPyPI first" if desired
6. Click "Run workflow"
**What happens:**
- ✅ Validates version format
- ✅ Downloads pre-built packages from CI (no rebuild needed!)
- ✅ Updates all package versions
- ✅ Optionally tests on TestPyPI
- ✅ Creates tag and GitHub release
- ✅ Automatically triggers PyPI publish
### Via Command Line
```bash
gh workflow run release-manual.yml -f version=0.1.1 -f test_pypi=true
```
## ⚡ Quick Release (One-Line)
For experienced users who want the fastest path:
```bash
./scripts/release.sh 0.1.1
```
This script will:
1. Update all package versions
2. Commit and push changes
3. Create GitHub release
4. CI automatically builds and publishes to PyPI
⚠️ **Note**: If CI fails, you'll need to manually fix and re-tag
## Manual Testing Before Release
For testing specific packages locally (especially DiskANN on macOS):
```bash
# Build specific package locally
./scripts/build_and_test.sh diskann # or hnsw, core, meta, all
# Test installation in a clean environment
python -m venv test_env
source test_env/bin/activate
pip install packages/*/dist/*.whl
# Upload to Test PyPI (optional)
./scripts/upload_to_pypi.sh test
# Upload to Production PyPI (use with caution)
./scripts/upload_to_pypi.sh prod
```
## First-time setup
1. Install GitHub CLI:
```bash
brew install gh
gh auth login
```
2. Set PyPI token in GitHub:
```bash
gh secret set PYPI_API_TOKEN
# Paste your PyPI token when prompted
```
Check progress: https://github.com/yichuan-w/LEANN/actions

docs/contributing.md Normal file
View File

@@ -0,0 +1,11 @@
# 🤝 Contributing
We welcome contributions! Leann is built by the community, for the community.
## Ways to Contribute
- 🐛 **Bug Reports**: Found an issue? Let us know!
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
- 🔧 **Code Contributions**: PRs welcome for all skill levels
- 📖 **Documentation**: Help make Leann more accessible
- 🧪 **Benchmarks**: Share your performance results

docs/faq.md Normal file
View File

@@ -0,0 +1,10 @@
# FAQ
## 1. My index build time seems long
You can speed up the process by using a lightweight embedding model. Add this to your arguments:
```bash
--embedding-model sentence-transformers/all-MiniLM-L6-v2
```
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)
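In the Python API, the same choice is the `embedding_model` argument. A minimal sketch, assuming `LeannBuilder` is importable from `leann.api` alongside `LeannChat` (other builder arguments omitted):

```python
from leann.api import LeannBuilder

# A smaller embedding model makes index builds noticeably faster.
builder = LeannBuilder(
    backend_name="hnsw",
    embedding_model="sentence-transformers/all-MiniLM-L6-v2",  # ~30M params
)
```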

docs/features.md Normal file
View File

@@ -0,0 +1,22 @@
# ✨ Detailed Features
## 🔥 Core Features
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage by computing embeddings on the fly with optimized ZMQ servers and a search paradigm that overlaps and batches requests to a highly optimized embedding engine
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
- **🎯 Graph Pruning** - Advanced pruning techniques that shrink the storage overhead of vector search to a minimal footprint
- **🏗️ Pluggable Backends** - DiskANN, HNSW/FAISS with unified API
## 🛠️ Technical Highlights
- **🔄 Recompute Mode** - Delivers the highest accuracy while eliminating vector storage overhead
- **⚡ Zero-copy Operations** - Minimize IPC overhead by transferring distances instead of embeddings
- **🚀 High-throughput Embedding Pipeline** - Optimized batched processing for maximum efficiency
- **🎯 Two-level Search** - Novel coarse-to-fine search overlap for accelerated query processing (optional)
- **💾 Memory-mapped Indices** - Fast startup with raw text mapping to reduce memory overhead
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](test/build_mlx_index.py))
## 🎨 Developer Experience
- **Simple Python API** - Get started in minutes (see the sketch below)
- **Extensible backend system** - Easy to add new algorithms
- **Comprehensive examples** - From basic usage to production deployment
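Because backends are pluggable, switching algorithms is a constructor argument rather than a code rewrite. A hedged sketch reusing parameter names from the builder calls in this changeset (the DiskANN arguments are assumptions):

```python
from leann.api import LeannBuilder

# Same unified API; only backend_name changes.
hnsw_builder = LeannBuilder(
    backend_name="hnsw",
    embedding_model="facebook/contriever",
    graph_degree=32,
    complexity=64,
)

# Hypothetical DiskANN counterpart under the same interface.
diskann_builder = LeannBuilder(
    backend_name="diskann",
    embedding_model="facebook/contriever",
)
```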

docs/roadmap.md Normal file
View File

@@ -0,0 +1,21 @@
# 📈 Roadmap
## 🎯 Q2 2025
- [X] DiskANN backend with MIPS/L2/Cosine support
- [X] HNSW backend integration
- [X] Real-time embedding pipeline
- [X] Memory-efficient graph pruning
## 🚀 Q3 2025
- [ ] Advanced caching strategies
- [ ] Add contextual retrieval (https://www.anthropic.com/news/contextual-retrieval)
- [ ] Add sleep-time compute and a summarization agent to summarize files on your computer
- [ ] Add OpenAI recompute API
## 🌟 Q4 2025
- [ ] Integration with LangChain/LlamaIndex
- [ ] Visual similarity search
- [ ] Query rewriting, reranking, and expansion

View File

@@ -135,6 +135,7 @@ def test_leann_hnsw():
nodes = node_parser.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
print(f"Total number of chunks: {len(all_texts)}")
tracker.checkpoint("After text chunking")

View File

@@ -37,7 +37,7 @@ def main():
import faiss
except ImportError:
print("Faiss is not installed.")
print("Please install it with `uv pip install faiss-cpu`")
print("Please install it with `uv pip install faiss-cpu` and you can then run this script again")
sys.exit(1)
from llama_index.core import (

View File

@@ -97,11 +97,13 @@ def create_leann_index_from_multiple_chrome_profiles(profile_dirs: List[Path], i
# Use HNSW backend for better macOS compatibility
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="facebook/contriever",
embedding_model="text-embedding-3-small",
embedding_mode="openai",
graph_degree=32,
complexity=64,
is_compact=True,
is_recompute=True,
is_compact=False,
is_recompute=False,
num_threads=1 # Force single-threaded mode
)
@@ -222,14 +224,15 @@ async def query_leann_index(index_path: str, query: str):
"max_tokens": 1000
}
)
print(f"Leann: {chat_response}")
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
async def main():
# Parse command line arguments
parser = argparse.ArgumentParser(description='LEANN Chrome History Reader - Create and query browser history index')
parser.add_argument('--chrome-profile', type=str, default=DEFAULT_CHROME_PROFILE,
help=f'Path to Chrome profile directory (default: {DEFAULT_CHROME_PROFILE}); usually you do not need to change this')
parser.add_argument('--index-dir', type=str, default="./all_google_new",
parser.add_argument('--index-dir', type=str, default="./google_history_index",
help='Directory to store the LEANN index (default: ./google_history_index)')
parser.add_argument('--max-entries', type=int, default=1000,
help='Maximum number of history entries to process (default: 1000)')

View File

@@ -224,15 +224,16 @@ async def query_leann_index(index_path: str, query: str):
beam_width=1,
)
end_time = time.time()
print(f"Time taken: {end_time - start_time} seconds")
print(f"Leann: {chat_response}")
# print(f"Time taken: {end_time - start_time} seconds")
# highlight the answer
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
async def main():
# Parse command line arguments
parser = argparse.ArgumentParser(description='LEANN Mail Reader - Create and query email index')
# Remove --mail-path argument and auto-detect all Messages directories
# Remove DEFAULT_MAIL_PATH
parser.add_argument('--index-dir', type=str, default="./mail_index_index_file",
parser.add_argument('--index-dir', type=str, default="./mail_index",
help='Directory to store the LEANN index (default: ./mail_index)')
parser.add_argument('--max-emails', type=int, default=1000,
help='Maximum number of emails to process (-1 means all)')

View File

@@ -63,16 +63,14 @@ async def main(args):
llm_config = {"type": "openai", "model": "gpt-4o"}
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
# query = (
# "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
# )
query = args.query
print(f"You: {query}")
chat_response = chat.ask(query, top_k=20, recompute_embeddings=True, complexity=32)
print(f"Leann: {chat_response}")
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
if __name__ == "__main__":
@@ -110,6 +108,12 @@ if __name__ == "__main__":
default="examples/data",
help="Directory containing documents to index (PDF, TXT, MD files).",
)
parser.add_argument(
"--query",
type=str,
default="Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?",
help="The query to ask the Leann chat system.",
)
args = parser.parse_args()
asyncio.run(main(args))

View File

@@ -234,7 +234,7 @@ async def query_leann_index(index_path: str, query: str):
},
llm_kwargs={"temperature": 0.0, "max_tokens": 1000},
)
print(f"Leann: {chat_response}")
print(f"Leann chat response: \033[36m{chat_response}\033[0m")
async def main():

View File

@@ -4,8 +4,8 @@ build-backend = "scikit_build_core.build"
[project]
name = "leann-backend-diskann"
version = "0.1.1"
dependencies = ["leann-core==0.1.0", "numpy"]
version = "0.1.12"
dependencies = ["leann-core==0.1.12", "numpy", "protobuf>=3.19.0"]
[tool.scikit-build]
# Key: simplified CMake path

View File

@@ -48,6 +48,10 @@ class HNSWBuilder(LeannBackendBuilderInterface):
self.efConstruction = self.build_params.setdefault("efConstruction", 200)
self.distance_metric = self.build_params.setdefault("distance_metric", "mips")
self.dimensions = self.build_params.get("dimensions")
if not self.is_recompute:
if self.is_compact:
# TODO: support this case @andy
raise ValueError("is_recompute is False but is_compact is True; this combination is not supported yet. Set is_compact=False to use the original HNSW index.")
def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
from . import faiss # type: ignore
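In practice, the check above means a non-recompute HNSW build must also be non-compact. A sketch of a valid configuration, reusing the builder parameters shown elsewhere in this changeset:

```python
from leann.api import LeannBuilder

# Valid under the check above: recompute disabled, so compact
# storage must be disabled as well.
builder = LeannBuilder(
    backend_name="hnsw",
    is_compact=False,
    is_recompute=False,
)
```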
@@ -92,8 +96,8 @@ class HNSWBuilder(LeannBackendBuilderInterface):
if success:
logger.info("✅ CSR conversion successful.")
index_file_old = index_file.with_suffix(".old")
shutil.move(str(index_file), str(index_file_old))
# index_file_old = index_file.with_suffix(".old")
# shutil.move(str(index_file), str(index_file_old))
shutil.move(str(csr_temp_file), str(index_file))
logger.info(
f"INFO: Replaced original index with {mode_str} version at '{index_file}'"

View File

@@ -81,7 +81,21 @@ def create_hnsw_embedding_server(
with open(passages_file, "r") as f:
meta = json.load(f)
passages = PassageManager(meta["passage_sources"])
# Convert relative paths to absolute paths based on metadata file location
metadata_dir = Path(
passages_file
).parent.parent # Go up one level from the metadata file
passage_sources = []
for source in meta["passage_sources"]:
source_copy = source.copy()
# Convert relative paths to absolute paths
if not Path(source_copy["path"]).is_absolute():
source_copy["path"] = str(metadata_dir / source_copy["path"])
if not Path(source_copy["index_path"]).is_absolute():
source_copy["index_path"] = str(metadata_dir / source_copy["index_path"])
passage_sources.append(source_copy)
passages = PassageManager(passage_sources)
logger.info(
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
)
@@ -270,15 +284,15 @@ def create_hnsw_embedding_server(
if __name__ == "__main__":
import signal
import sys
def signal_handler(sig, frame):
logger.info(f"Received signal {sig}, shutting down gracefully...")
sys.exit(0)
# Register signal handlers for graceful shutdown
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
parser = argparse.ArgumentParser(description="HNSW Embedding service")
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
parser.add_argument(

View File

@@ -6,10 +6,10 @@ build-backend = "scikit_build_core.build"
[project]
name = "leann-backend-hnsw"
version = "0.1.1"
version = "0.1.12"
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
dependencies = [
"leann-core==0.1.0",
"leann-core==0.1.12",
"numpy",
"pyzmq>=23.0.0",
"msgpack>=1.0.0",

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "leann-core"
version = "0.1.1"
version = "0.1.12"
description = "Core API and plugin system for LEANN"
readme = "README.md"
requires-python = ">=3.9"
@@ -21,6 +21,16 @@ dependencies = [
"sentence-transformers>=2.2.0",
"llama-index-core>=0.12.0",
"python-dotenv>=1.0.0",
"openai>=1.0.0",
"huggingface-hub>=0.20.0",
"transformers>=4.30.0",
"requests>=2.25.0",
"accelerate>=0.20.0",
"PyPDF2>=3.0.0",
"pymupdf>=1.23.0",
"pdfplumber>=0.10.0",
"mlx>=0.26.3; sys_platform == 'darwin'",
"mlx-lm>=0.26.0; sys_platform == 'darwin'",
]
[project.scripts]

View File

@@ -441,9 +441,9 @@ class LeannSearcher:
use_server_if_available=recompute_embeddings,
zmq_port=zmq_port,
)
logger.info(f" Generated embedding shape: {query_embedding.shape}")
# logger.info(f" Generated embedding shape: {query_embedding.shape}")
embedding_time = time.time() - start_time
logger.info(f" Embedding time: {embedding_time} seconds")
# logger.info(f" Embedding time: {embedding_time} seconds")
start_time = time.time()
results = self.backend_impl.search(
@@ -458,7 +458,7 @@ class LeannSearcher:
**kwargs,
)
search_time = time.time() - start_time
logger.info(f" Search time: {search_time} seconds")
# logger.info(f" Search time: {search_time} seconds")
logger.info(
f" Backend returned: labels={len(results.get('labels', [[]])[0])} results"
)
@@ -479,15 +479,25 @@ class LeannSearcher:
metadata=passage_data.get("metadata", {}),
)
)
# Color codes for better logging
GREEN = "\033[92m"
BLUE = "\033[94m"
YELLOW = "\033[93m"
RESET = "\033[0m"
# Truncate text for display (first 100 chars)
display_text = passage_data["text"][:100]
logger.info(
f" {i + 1}. passage_id='{string_id}' -> SUCCESS: {passage_data['text']}..."
f" {GREEN}{RESET} {BLUE}[{i + 1:2d}]{RESET} {YELLOW}ID:{RESET} '{string_id}' {YELLOW}Score:{RESET} {dist:.4f} {YELLOW}Text:{RESET} {display_text}"
)
except KeyError:
RED = "\033[91m"
logger.error(
f" {i + 1}. passage_id='{string_id}' -> ERROR: Passage not found in PassageManager!"
f" {RED}{RESET} [{i + 1:2d}] ID: '{string_id}' -> {RED}ERROR: Passage not found!{RESET}"
)
logger.info(f" Final enriched results: {len(enriched_results)} passages")
logger.info(f" {GREEN} Final enriched results: {len(enriched_results)} passages{RESET}")
return enriched_results
@@ -517,7 +527,7 @@ class LeannChat:
):
if llm_kwargs is None:
llm_kwargs = {}
search_time = time.time()
results = self.searcher.search(
question,
top_k=top_k,
@@ -529,6 +539,8 @@ class LeannChat:
expected_zmq_port=expected_zmq_port,
**search_kwargs,
)
search_time = time.time() - search_time
# logger.info(f" Search time: {search_time} seconds")
context = "\n\n".join([r.text for r in results])
prompt = (
"Here is some retrieved context that might help answer your question:\n\n"

View File

@@ -7,6 +7,33 @@ from llama_index.core.node_parser import SentenceSplitter
from .api import LeannBuilder, LeannSearcher, LeannChat
def extract_pdf_text_with_pymupdf(file_path: str) -> "str | None":
"""Extract text from a PDF using PyMuPDF for better quality; returns None if PyMuPDF is unavailable."""
try:
import fitz # PyMuPDF
doc = fitz.open(file_path)
text = ""
for page in doc:
text += page.get_text()
doc.close()
return text
except ImportError:
# Fallback to default reader
return None
def extract_pdf_text_with_pdfplumber(file_path: str) -> "str | None":
"""Extract text from a PDF using pdfplumber for better quality; returns None if pdfplumber is unavailable."""
try:
import pdfplumber
text = ""
with pdfplumber.open(file_path) as pdf:
for page in pdf.pages:
text += page.extract_text() or ""
return text
except ImportError:
# Fallback to default reader
return None
class LeannCLI:
def __init__(self):
@@ -145,12 +172,42 @@ Examples:
def load_documents(self, docs_dir: str):
print(f"Loading documents from {docs_dir}...")
documents = SimpleDirectoryReader(
# Try to use better PDF parsers first
documents = []
docs_path = Path(docs_dir)
for file_path in docs_path.rglob("*.pdf"):
print(f"Processing PDF: {file_path}")
# Try PyMuPDF first (best quality)
text = extract_pdf_text_with_pymupdf(str(file_path))
if text is None:
# Try pdfplumber
text = extract_pdf_text_with_pdfplumber(str(file_path))
if text:
# Create a simple document structure
from llama_index.core import Document
doc = Document(text=text, metadata={"source": str(file_path)})
documents.append(doc)
else:
# Fallback to default reader
print(f"Using default reader for {file_path}")
default_docs = SimpleDirectoryReader(
str(file_path.parent),
filename_as_id=True,
required_exts=[file_path.suffix],
).load_data()
documents.extend(default_docs)
# Load other file types with default reader
other_docs = SimpleDirectoryReader(
docs_dir,
recursive=True,
encoding="utf-8",
required_exts=[".pdf", ".txt", ".md", ".docx"],
required_exts=[".txt", ".md", ".docx"],
).load_data(show_progress=True)
documents.extend(other_docs)
all_texts = []
for doc in documents:

View File

@@ -264,9 +264,10 @@ def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
logger.info(
f"Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'"
)
print(f"len of texts: {len(texts)}")
# OpenAI has limits on batch size and input length
max_batch_size = 100 # Conservative batch size
max_batch_size = 1000  # stay below OpenAI's 2048-inputs-per-request limit
all_embeddings = []
try:
@@ -296,6 +297,7 @@ def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray:
logger.info(
f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}"
)
print(f"len of embeddings: {len(embeddings)}")
return embeddings
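For reference, the batched call pattern looks roughly like this; a minimal sketch assuming the `openai>=1.0` client, with illustrative names rather than the exact implementation:

```python
import numpy as np
from openai import OpenAI

def embed_in_batches(texts, model_name, max_batch_size=1000):
    """Embed texts in chunks to stay under OpenAI's per-request input limit."""
    client = OpenAI()  # reads OPENAI_API_KEY from the environment
    all_embeddings = []
    for i in range(0, len(texts), max_batch_size):
        batch = texts[i : i + max_batch_size]
        resp = client.embeddings.create(model=model_name, input=batch)
        all_embeddings.extend(item.embedding for item in resp.data)
    return np.asarray(all_embeddings, dtype=np.float32)
```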

View File

@@ -18,6 +18,24 @@ logging.basicConfig(
logger = logging.getLogger(__name__)
def _is_colab_environment() -> bool:
"""Check if we're running in Google Colab environment."""
return "COLAB_GPU" in os.environ or "COLAB_TPU" in os.environ
def _get_available_port(start_port: int = 5557) -> int:
"""Get an available port starting from start_port."""
port = start_port
while port < start_port + 100: # Try up to 100 ports
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("localhost", port))
return port
except OSError:
port += 1
raise RuntimeError(f"No available ports found in range {start_port}-{start_port+100}")
def _check_port(port: int) -> bool:
"""Check if a port is in use"""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
@@ -175,48 +193,59 @@ class EmbeddingServerManager:
embedding_mode: str = "sentence-transformers",
**kwargs,
) -> tuple[bool, int]:
"""
Starts the embedding server process.
Args:
port (int): The preferred ZMQ port for the server.
model_name (str): The name of the embedding model to use.
**kwargs: Additional arguments for the server.
Returns:
tuple[bool, int]: (success, actual_port_used)
"""
"""Start the embedding server."""
passages_file = kwargs.get("passages_file")
assert isinstance(passages_file, str), "passages_file must be a string"
# Check if we have a compatible running server
# Check if we have a compatible server already running
if self._has_compatible_running_server(model_name, passages_file):
assert self.server_port is not None, (
"a compatible running server should set server_port"
)
return True, self.server_port
logger.info("Found compatible running server!")
return True, port
# Find available port (compatible or free)
try:
actual_port, is_compatible = _find_compatible_port_or_next_available(
port, model_name, passages_file
)
except RuntimeError as e:
logger.error(str(e))
return False, port
# For Colab environment, use a different strategy
if _is_colab_environment():
logger.info("Detected Colab environment, using alternative startup strategy")
return self._start_server_colab(port, model_name, embedding_mode, **kwargs)
# Find a compatible port or next available
actual_port, is_compatible = _find_compatible_port_or_next_available(
port, model_name, passages_file
)
if is_compatible:
logger.info(f"Using existing compatible server on port {actual_port}")
self.server_port = actual_port
self.server_process = None # We don't own this process
logger.info(f"Found compatible server on port {actual_port}")
return True, actual_port
if actual_port != port:
logger.info(f"Using port {actual_port} instead of {port}")
# Start new server
# Start a new server
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
def _start_server_colab(
self,
port: int,
model_name: str,
embedding_mode: str = "sentence-transformers",
**kwargs,
) -> tuple[bool, int]:
"""Start server with Colab-specific configuration."""
# Try to find an available port
try:
actual_port = _get_available_port(port)
except RuntimeError:
logger.error("No available ports found")
return False, port
logger.info(f"Starting server on port {actual_port} for Colab environment")
# Use a simpler startup strategy for Colab
command = self._build_server_command(actual_port, model_name, embedding_mode, **kwargs)
try:
# In Colab, we'll use a more direct approach
self._launch_server_process_colab(command, actual_port)
return self._wait_for_server_ready_colab(actual_port)
except Exception as e:
logger.error(f"Failed to start embedding server in Colab: {e}")
return False, actual_port
def _has_compatible_running_server(
self, model_name: str, passages_file: str
) -> bool:
@@ -269,7 +298,9 @@ class EmbeddingServerManager:
]
if kwargs.get("passages_file"):
command.extend(["--passages-file", str(kwargs["passages_file"])])
# Convert to absolute path to ensure subprocess can find the file
passages_file = Path(kwargs["passages_file"]).resolve()
command.extend(["--passages-file", str(passages_file)])
if embedding_mode != "sentence-transformers":
command.extend(["--embedding-mode", embedding_mode])
@@ -346,3 +377,45 @@ class EmbeddingServerManager:
pass
self.server_process = None
def _launch_server_process_colab(self, command: list, port: int) -> None:
"""Launch the server process with Colab-specific settings."""
logger.info(f"Colab Command: {' '.join(command)}")
# In Colab, we need to be more careful about process management
self.server_process = subprocess.Popen(
command,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
self.server_port = port
logger.info(f"Colab server process started with PID: {self.server_process.pid}")
# Register atexit callback
if not self._atexit_registered:
atexit.register(lambda: self.stop_server() if self.server_process else None)
self._atexit_registered = True
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
"""Wait for the server to be ready with Colab-specific timeout."""
max_wait, wait_interval = 30, 0.5 # Shorter timeout for Colab
for _ in range(int(max_wait / wait_interval)):
if _check_port(port):
logger.info("Colab embedding server is ready!")
return True, port
if self.server_process and self.server_process.poll() is not None:
# Check for error output
stdout, stderr = self.server_process.communicate()
logger.error(f"Colab server terminated during startup.")
logger.error(f"stdout: {stdout}")
logger.error(f"stderr: {stderr}")
return False, port
time.sleep(wait_interval)
logger.error(f"Colab server failed to start within {max_wait} seconds.")
self.stop_server()
return False, port
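A hypothetical call site for the manager (the no-argument construction is an assumption; `start_server`'s signature and `(success, actual_port)` return value match the method above):

```python
# Sketch: start (or reuse) an embedding server and get the actual port back.
manager = EmbeddingServerManager()  # assumed no-arg construction
ok, port = manager.start_server(
    port=5557,
    model_name="facebook/contriever",
    embedding_mode="sentence-transformers",
    passages_file="/abs/path/to/index.meta.json",  # must be an absolute path
)
if not ok:
    raise RuntimeError("embedding server failed to start")
```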

View File

@@ -112,8 +112,9 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
passages_source_file = (
self.index_dir / f"{self.index_path.name}.meta.json"
)
# Convert to absolute path to ensure server can find it
zmq_port = self._ensure_server_running(
str(passages_source_file), zmq_port
str(passages_source_file.resolve()), zmq_port
)
return self._compute_embedding_via_server([query], zmq_port)[

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "leann"
version = "0.1.1"
version = "0.1.12"
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
readme = "README.md"
requires-python = ">=3.9"

View File

@@ -1,12 +0,0 @@
import faiss
hnsw_index = faiss.read_index("/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/hnsw/hnsw_IP_M30_efC128.index", faiss.IO_FLAG_ONDISK_SAME_DIR)
# print total number of nodes
print(hnsw_index.ntotal)
# print stats of the graph
print(hnsw_index.hnsw.print_neighbor_stats(0))
# save_degree_distribution
hnsw_index.hnsw.save_degree_distribution(0, "degree_distribution_HNSW_M30.txt")

View File

@@ -1,11 +0,0 @@
import faiss
nsg_index = faiss.read_index("/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/nsg_R16.index", faiss.IO_FLAG_ONDISK_SAME_DIR)
# print total number of nodes
print(nsg_index.ntotal)
# print stats of the graph
print(nsg_index.nsg.print_neighbor_stats(0))
# save degree distribution
nsg_index.nsg.save_degree_distribution("degree_distribution_NSG_R60.txt")

View File

@@ -1,63 +0,0 @@
import torch
import torch.nn as nn
import time
# import bitsandbytes as bnb
from bitsandbytes.nn import Linear8bitLt
# set default to half
import torch
torch.set_default_dtype(torch.float16)
M = 2048
N = 2048
bsz = 2048
import torch_int
from torch_int.nn.linear import W8A8BFP32OFP32Linear, W8A8B8O8Linear, W8A8B8O8LinearReLU
fp16_model = nn.Sequential(
nn.Linear(M, N),
# nn.Linear(2048, 2048)
)
int8_model = nn.Sequential(
Linear8bitLt(M, N, has_fp16_weights=False),
# Linear8bitLt(2048, 2048, has_fp16_weights=False)
)
int8_model.load_state_dict(fp16_model.state_dict())
int8_model = int8_model.to(0) # Quantization happens here
fp16_model = fp16_model.to(0) # Move fp16 model to GPU as well
# Create random input tensor
input_tensor = torch.randn(bsz, M, device=0) # Batch of 1000 vectors
# Speed test function
def speed_test(model, input_tensor, name, num_iterations=100):
# Warmup
for _ in range(10):
_ = model(input_tensor)
# Actual timing
torch.cuda.synchronize()
start_time = time.time()
for _ in range(num_iterations):
_ = model(input_tensor)
torch.cuda.synchronize()
end_time = time.time()
avg_time = (end_time - start_time) / num_iterations
print(f"{name} model: {avg_time:.6f} seconds per iteration")
return avg_time
# Run speed tests
with torch.no_grad(): # Disable gradient calculation for inference
fp16_time = speed_test(fp16_model, input_tensor, "FP16")
int8_time = speed_test(int8_model, input_tensor, "INT8")
# Calculate speedup
speedup = fp16_time / int8_time
print(f"INT8 is {speedup:.2f}x faster than FP16")

View File

@@ -1,89 +0,0 @@
n,d,seqlen,bs,latency,h,flop,io,intensity,throughput,series
3,256,256,2048,0.009623501679245285,768,618475290624,167.48502132816208,3692720015.912285,64267177503366.266,dense
3,256,256,1024,0.004853848615384615,768,309237645312,166.15392854317415,1861151572.059558,63709783682138.234,dense
3,256,256,512,0.0024687246971962615,768,154618822656,163.57953256539062,945221081.3366361,62631051097597.516,dense
3,256,256,256,0.0012845360838052097,768,77309411328,157.64931990085577,490388486.1451936,60184694149645.54,dense
3,256,256,128,0.0006901147179878049,768,38654705664,147.57393422494675,261934506.70684624,56012000116019.945,dense
3,256,256,64,0.0003363830693015702,768,19327352832,153.1328437752606,126212981.84970059,57456378146882.51,dense
3,256,256,32,0.00018671159748991485,768,9663676416,141.10249365427362,68486928.65540518,51757237075334.75,dense
3,256,256,16,0.00012353640857142858,768,4831838208,111.40488993609125,43371868.24359184,39112665358133.98,dense
3,256,256,8,9.774760007849294e-05,768,2415919104,76.43260800265766,31608487.09906635,24715891766754.14,dense
3,256,256,4,6.672271167474822e-05,768,1207959552,64.82614227498455,18633833.660438772,18104173551704.773,dense
3,256,256,2,4.9758770289855074e-05,768,603979776,55.317122669351576,10918495.880745342,12138157202874.861,dense
3,256,1,2048,9.785507940251571e-05,768,2415919104,76.34865809334705,31643242.518371396,24688745017132.86,dense
3,256,1,1024,6.692813470149253e-05,768,1207959552,64.62717090938949,18691202.70936228,18048606275785.867,dense
3,256,1,512,4.9680950036205655e-05,768,603979776,55.40377142534654,10901419.893658841,12157170415618.898,dense
3,256,1,256,4.2781118741058655e-05,768,301989888,45.95672244805227,6571179.83862661,7058952568020.829,dense
3,256,1,128,5.0662328255350016e-05,768,150994944,31.046026784880404,4863583.512513602,2980418571348.519,dense
3,256,1,64,4.475009253945481e-05,768,75497472,30.75426042497223,2454862.219307235,1687090857598.4766,dense
3,256,1,32,4.51682671454219e-05,768,37748736,28.29313765537115,1334201.1218340008,835735758435.5786,dense
3,256,1,16,5.03585186661834e-05,768,18874368,24.401035466223117,773506.846712577,374799904761.1871,dense
3,256,1,8,5.023459565217391e-05,768,9437184,23.972005435021096,393675.19858030166,187862246674.45105,dense
3,256,1,4,5.053219391083726e-05,768,4718592,23.58765586356967,200044.97383259286,93377936614.54384,dense
3,256,1,2,4.4607398995335484e-05,768,2359296,26.58285456464288,88752.54515134107,52890239133.797226,dense
12,256,256,2048,0.14480779847058822,3072,9895604649984,44.620009282941716,221775046868.20184,68336130750540.26,dense
12,256,256,1024,0.07254347629166667,3072,4947802324992,44.664248332585096,110777691547.58836,68204648824643.82,dense
12,256,256,512,0.036310761444444443,3072,2473901162496,44.876147984203506,55127306456.13385,68131349056975.164,dense
12,256,256,256,0.01821551906896552,3072,1236950581248,45.24607467289738,27338295977.947884,67906414116709.98,dense
12,256,256,128,0.009229417903030302,3072,618475290624,45.67217092440895,13541622351.335684,67011299859001.46,dense
12,256,256,64,0.004754550595394737,3072,309237645312,46.31372736116993,6677019167.566916,65040352207320.695,dense
12,256,256,32,0.002405752659340659,3072,154618822656,49.68826015254682,3111777755.5766335,64270456921525.82,dense
12,256,256,16,0.0012287219045005488,3072,77309411328,56.323579604557374,1372594069.3184311,62918558743709.18,dense
12,256,256,8,0.0006206816149425287,3072,38654705664,70.95456179103653,544781120.315271,62277832520589.78,dense
12,256,256,4,0.0003875502697142857,3072,19327352832,81.16954743236613,238110885.71245712,49870569942445.75,dense
12,256,256,2,0.00027502018627941914,3072,9663676416,91.50537035282076,105607751.53129694,35138062215483.168,dense
12,256,1,2048,0.0006202853873290136,3072,38654705664,70.99988634205897,544433345.6784943,62317614526515.766,dense
12,256,1,1024,0.00038721467732724153,3072,19327352832,81.2398957010995,237904697.74985722,49913791918755.53,dense
12,256,1,512,0.000274364799,3072,9663676416,91.72395326121995,105356082.81599998,35221998052308.45,dense
12,256,1,256,0.00012488918589482266,3072,4831838208,176.31707535146046,27404255.647778228,38689003962834.75,dense
12,256,1,128,8.976711102514506e-05,3072,2415919104,227.78088507574267,10606329.425740216,26913187652026.21,dense
12,256,1,64,8.715176287471176e-05,3072,1207959552,225.59268282689945,5354604.31102229,13860414432884.701,dense
12,256,1,32,8.523013435114503e-05,3072,603979776,226.06539514085782,2671703.8033338524,7086458100741.991,dense
12,256,1,16,7.901561645904116e-05,3072,301989888,241.35704882952732,1251216.3595988373,3821901309300.556,dense
12,256,1,8,7.827949114210329e-05,3072,150994944,242.37091635608994,622991.1833900034,1928920867994.581,dense
12,256,1,4,7.779445951035782e-05,3072,75497472,243.25022783249054,310369.58391664835,970473636235.5986,dense
12,256,1,2,7.758845406626506e-05,3072,37748736,243.57933441822672,154975.11761480253,486525172518.07056,dense
3,256,256,2048,0.00507974918466899,768,206158430208,475.59810852303485,433471930.42508715,40584371927298.98,qk_init
3,256,256,1024,0.0025616677649325623,768,103079215104,471.5519977009198,218595649.27424532,40239103803811.82,qk_init
3,256,256,512,0.0013029336670480549,768,51539607552,463.55374128015677,111183672.92143403,39556585922573.38,qk_init
3,256,256,256,0.0006738189029345373,768,25769803776,448.1766342333362,57499213.050413854,38244406121244.69,qk_init
3,256,256,128,0.000358254672959467,768,12884901888,421.47375986100144,30571065.425874516,35965760841472.125,qk_init
3,256,256,64,0.0002007051105022831,768,6442450944,376.1611839930762,17126836.096194826,32099087700742.5,qk_init
3,256,256,32,0.00012189697230142565,768,3221225472,309.6773881032524,10401874.969721656,26425803784810.87,qk_init
3,256,256,16,8.453561698040722e-05,768,1610612736,223.2711923587723,7213705.982328083,19052475081281.902,qk_init
3,256,256,8,6.407660705009276e-05,768,805306368,147.2797083750448,5467870.468274581,12567868448003.822,qk_init
3,256,256,4,5.036328747284576e-05,768,402653184,93.69110391262903,4297667.197682838,7994974200544.344,qk_init
3,256,256,2,4.5488761135057476e-05,768,201326592,51.865470527877875,3881707.616858238,4425853485045.578,qk_init
12,256,256,2048,0.020202365999999996,3072,824633720832,478.3437947812648,1723935231.9999998,40818670488001.266,qk_init
12,256,256,1024,0.010124155888157895,3072,412316860416,477.2583770318811,863927969.1228071,40726048173387.19,qk_init
12,256,256,512,0.005085633937062937,3072,206158430208,475.04777848703077,433974095.9627039,40537410430893.29,qk_init
12,256,256,256,0.0025654916853281853,3072,103079215104,470.84913933193053,218921957.14800516,40179126556324.74,qk_init
12,256,256,128,0.0013045765704467354,3072,51539607552,462.9699702434292,111323867.34478809,39506770794105.96,qk_init
12,256,256,64,0.0006742801519939804,3072,25769803776,447.87005387442576,57538572.970153,38218244597284.33,qk_init
12,256,256,32,0.00035831976790671853,3072,12884901888,421.3971919051604,30576620.194706645,35959227042573.69,qk_init
12,256,256,16,0.0002005369068918302,3072,6442450944,376.4766953382971,17112482.721436176,32126011335534.68,qk_init
12,256,256,8,0.00012179187250509165,3072,3221225472,309.94462293386505,10392906.453767821,26448607823689.82,qk_init
12,256,256,4,8.452507263643351e-05,3072,1610612736,223.2990450204527,7212806.198308992,19054851841745.297,qk_init
12,256,256,2,6.412381767545489e-05,3072,805306368,147.17127491946468,5471899.108305484,12558615459794.32,qk_init
3,256,256,2048,0.0016183739398395718,768,805306368,811597824.0,0.9922480620155039,1265467.7325087283,qk_ar
3,256,256,1024,0.0008322699728813558,768,402653184,405798912.0,0.9922480620155039,1230369.9921491416,qk_ar
3,256,256,512,0.00043886859397590365,768,201326592,202899456.0,0.9922480620155039,1166636.2255762408,qk_ar
3,256,256,256,0.00024185948322147648,768,100663296,101449728.0,0.9922480620155039,1058465.8355760013,qk_ar
3,256,256,128,0.00014308985100166944,768,50331648,50724864.0,0.9922480620155039,894542.82818777,qk_ar
3,256,256,64,9.382939365815932e-05,768,25165824,25362432.0,0.9922480620155039,682089.028872613,qk_ar
3,256,256,32,6.856070612244899e-05,768,12582912,12681216.0,0.9922480620155039,466739.6503012703,qk_ar
3,256,256,16,5.452260553129549e-05,768,6291456,6340608.0,0.9922480620155039,293456.26174846216,qk_ar
3,256,256,8,4.608557533261417e-05,768,3145728,3170304.0,0.9922480620155039,173590.1080166944,qk_ar
3,256,256,4,4.386146957766642e-05,768,1572864,1585152.0,0.9922480620155039,91196.21477609445,qk_ar
3,256,256,2,4.330941094420601e-05,768,786432,792576.0,0.9922480620155039,46179.33969539622,qk_ar
12,256,256,2048,0.006347041645299144,3072,3221225472,3246391296.0,0.9922480620155039,322670.011392918,qk_ar
12,256,256,1024,0.0031943104467592586,3072,1610612736,1623195648.0,0.9922480620155039,320569.96872013,qk_ar
12,256,256,512,0.0016183416350267381,3072,805306368,811597824.0,0.9922480620155039,316373.2483416833,qk_ar
12,256,256,256,0.0008325934893977947,3072,402653184,405798912.0,0.9922480620155039,307472.9784221131,qk_ar
12,256,256,128,0.0004389725746987952,3072,201326592,202899456.0,0.9922480620155039,291589.9702568624,qk_ar
12,256,256,64,0.00024191767449664432,3072,100663296,101449728.0,0.9922480620155039,264552.8076159138,qk_ar
12,256,256,32,0.0001431546143572621,3072,50331648,50724864.0,0.9922480620155039,223534.53392804778,qk_ar
12,256,256,16,9.404283597678917e-05,3072,25165824,25362432.0,0.9922480620155039,170135.23501087292,qk_ar
12,256,256,8,6.855550037091989e-05,3072,12582912,12681216.0,0.9922480620155039,116693.773026467,qk_ar
12,256,256,4,5.4802094978165945e-05,3072,6291456,6340608.0,0.9922480620155039,72989.91036006316,qk_ar
12,256,256,2,4.608510707869206e-05,3072,3145728,3170304.0,0.9922480620155039,43397.96795057727,qk_ar

View File

Binary file not shown.


View File

@@ -1,594 +0,0 @@
# python embedd_micro.py --use_int8   (fastest configuration)
import argparse
import time
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
from torch import nn
from torchao import quantize_
from transformers import AutoModel, BitsAndBytesConfig
from tqdm import tqdm
from contextlib import contextmanager
@dataclass
class BenchmarkConfig:
model_path: str
batch_sizes: List[int]
seq_length: int
num_runs: int
use_fp16: bool = True
use_int4: bool = False
use_int8: bool = False # Add this parameter
use_cuda_graphs: bool = False
use_flash_attention: bool = False
use_linear8bitlt: bool = False
class CUDAGraphContainer:
"""Container for managing CUDA graphs for different batch sizes."""
def __init__(self, model: nn.Module, seq_length: int):
self.model = model
self.seq_length = seq_length
self.graphs: Dict[int, CUDAGraphWrapper] = {}
def get_or_create(self, batch_size: int) -> 'CUDAGraphWrapper':
if batch_size not in self.graphs:
self.graphs[batch_size] = CUDAGraphWrapper(
self.model, batch_size, self.seq_length
)
return self.graphs[batch_size]
class CUDAGraphWrapper:
"""Wrapper for CUDA graph capture and replay."""
def __init__(self, model: nn.Module, batch_size: int, seq_length: int):
self.model = model
self.static_input = self._create_random_batch(batch_size, seq_length)
self.static_attention_mask = torch.ones_like(self.static_input)
# Warm up
self._warmup()
# Capture graph
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph):
self.static_output = self.model(
input_ids=self.static_input,
attention_mask=self.static_attention_mask
)
def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
return torch.randint(
0, 1000, (batch_size, seq_length),
device="cuda",
dtype=torch.long
)
def _warmup(self, num_warmup: int = 3):
with torch.no_grad():
for _ in range(num_warmup):
self.model(
input_ids=self.static_input,
attention_mask=self.static_attention_mask
)
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
self.static_input.copy_(input_ids)
self.static_attention_mask.copy_(attention_mask)
self.graph.replay()
return self.static_output
class ModelOptimizer:
"""Applies various optimizations to the model."""
@staticmethod
def optimize(model: nn.Module, config: BenchmarkConfig) -> nn.Module:
print("\nApplying model optimizations:")
if model is None:
raise ValueError("Cannot optimize None model")
# Move to GPU
model = model.cuda()
print("- Model moved to GPU")
# FP16
if config.use_fp16 and not config.use_int4:
model = model.half()
# use torch compile
model = torch.compile(model)
print("- Using FP16 precision")
# Check if using SDPA
if torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
else:
print("- PyTorch SDPA not available")
# Flash Attention
if config.use_flash_attention:
try:
from flash_attn.flash_attention import FlashAttention
print("- Flash Attention 2 available")
if hasattr(model.config, "attention_mode"):
model.config.attention_mode = "flash_attention_2"
print(" - Enabled Flash Attention 2 mode")
except ImportError:
print("- Flash Attention not available")
# Memory efficient attention
try:
from xformers.ops import memory_efficient_attention
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
model.enable_xformers_memory_efficient_attention()
print("- Enabled xformers memory efficient attention")
else:
print("- Model doesn't support xformers")
except (ImportError, AttributeError):
print("- Xformers not available")
model.eval()
print("- Model set to eval mode")
return model
class Timer:
"""Handles accurate GPU timing using CUDA events."""
def __init__(self):
self.start_event = torch.cuda.Event(enable_timing=True)
self.end_event = torch.cuda.Event(enable_timing=True)
@contextmanager
def timing(self):
self.start_event.record()
yield
self.end_event.record()
self.end_event.synchronize()
def elapsed_time(self) -> float:
return self.start_event.elapsed_time(self.end_event) / 1000 # ms to seconds
class Benchmark:
"""Main benchmark runner."""
def __init__(self, config: BenchmarkConfig):
self.config = config
try:
self.model = self._load_model()
if self.model is None:
raise ValueError("Model initialization failed - model is None")
self.cuda_graphs = (
CUDAGraphContainer(self.model, config.seq_length)
if config.use_cuda_graphs
else None
)
self.timer = Timer()
except Exception as e:
print(f"ERROR in benchmark initialization: {str(e)}")
raise
def _load_model(self) -> nn.Module:
print(f"Loading model from {self.config.model_path}...")
try:
# Int4 quantization using HuggingFace integration
if self.config.use_int4:
import bitsandbytes as bnb
print(f"- bitsandbytes version: {bnb.__version__}")
# Check whether to use the custom 8-bit quantization
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt:
print("- Using custom Linear8bitLt replacement for all linear layers")
# Load the original model (without a quantization config)
import bitsandbytes as bnb
import torch
# set default to half
torch.set_default_dtype(torch.float16)
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
model = AutoModel.from_pretrained(
self.config.model_path,
torch_dtype=compute_dtype,
)
# Define the replacement helper
def replace_linear_with_linear8bitlt(model):
"""递归地将模型中的所有nn.Linear层替换为Linear8bitLt"""
for name, module in list(model.named_children()):
if isinstance(module, nn.Linear):
# Read the original linear layer's parameters
in_features = module.in_features
out_features = module.out_features
bias = module.bias is not None
# Create the 8-bit linear layer
# print size
print(f"in_features: {in_features}, out_features: {out_features}")
new_module = bnb.nn.Linear8bitLt(
in_features,
out_features,
bias=bias,
has_fp16_weights=False
)
# Copy over weights and bias
new_module.weight.data = module.weight.data
if bias:
new_module.bias.data = module.bias.data
# Swap in the new module
setattr(model, name, new_module)
else:
# Recurse into child modules
replace_linear_with_linear8bitlt(module)
return model
# Replace all linear layers
model = replace_linear_with_linear8bitlt(model)
# add torch compile
model = torch.compile(model)
# Move the model to the GPU (quantization happens here)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
print("- All linear layers replaced with Linear8bitLt")
else:
# Use the original Int4 quantization path
print("- Using bitsandbytes for Int4 quantization")
# Create quantization config
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4"
)
print("- Quantization config:", quantization_config)
# Load model directly with quantization config
model = AutoModel.from_pretrained(
self.config.model_path,
quantization_config=quantization_config,
torch_dtype=compute_dtype,
device_map="auto" # Let HF decide on device mapping
)
# Check if model loaded successfully
if model is None:
raise ValueError("Model loading returned None")
print(f"- Model type: {type(model)}")
# Apply optimizations directly here
print("\nApplying model optimizations:")
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt:
print("- Model moved to GPU with Linear8bitLt quantization")
else:
# Skip moving to GPU since device_map="auto" already did that
print("- Model already on GPU due to device_map='auto'")
# Skip FP16 conversion since we specified compute_dtype
print(f"- Using {compute_dtype} for compute dtype")
# Check CUDA and SDPA
if torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
else:
print("- PyTorch SDPA not available")
# Try xformers if available
try:
from xformers.ops import memory_efficient_attention
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
model.enable_xformers_memory_efficient_attention()
print("- Enabled xformers memory efficient attention")
else:
print("- Model doesn't support xformers")
except (ImportError, AttributeError):
print("- Xformers not available")
# Set to eval mode
model.eval()
print("- Model set to eval mode")
# Int8 quantization using HuggingFace integration
# Int8 quantization using TorchAO
elif self.config.use_int8:
print("- Using TorchAO for Int8 dynamic activation and Int8 weight quantization")
# Import the quantize_ function and the quantization config
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
print("- Successfully imported TorchAO")
# Load model normally first
# set default to half
import torch
torch.set_default_dtype(torch.bfloat16)
model = AutoModel.from_pretrained(
self.config.model_path,
device_map="auto"
)
print("- Model loaded in full precision")
print(f"- Model type: {type(model)}")
# Apply quantization - call the function to get the config, then apply it
# quantize_(model, int8_dynamic_activation_int8_weight())
# from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig,int8_dynamic_activation_int8_semi_sparse_weight,int4_weight_only,Int8DynActInt4WeightGPTQQuantizer,int8_dynamic_activation_int4_weight,Int8DynamicActivationInt4WeightConfig,Int4DynamicActivationInt4WeightConfig
from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig
quantize_(model, Int8DynamicActivationInt8WeightConfig())
print("- Model successfully quantized with int8 weights and int8 activations")
# add torch compile
model = torch.compile(model)
# For older PyTorch versions that have issues with tensor subclasses
from torchao.utils import unwrap_tensor_subclass
import torch
if torch.__version__ < "2.5.0":
print("- Unwrapping tensor subclasses for compatibility with older PyTorch")
unwrap_tensor_subclass(model)
# Apply optimizations
if torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
else:
print("- PyTorch SDPA not available")
# Set to eval mode
model.eval()
print("- Model set to eval mode")
# For better performance with int8 dynamic quantization
torch._inductor.config.force_fuse_int_mm_with_mul = True
print("- Enabled fusion of int matmul with mul operations")
else:
# Standard loading for FP16/FP32
model = AutoModel.from_pretrained(self.config.model_path)
print("- Model loaded in standard precision")
print(f"- Model type: {type(model)}")
# Apply standard optimizations
# set default to half
import torch
torch.set_default_dtype(torch.bfloat16)
model = ModelOptimizer.optimize(model, self.config)
model = model.half()
# add torch compile
model = torch.compile(model)
# Final check to ensure model is not None
if model is None:
raise ValueError("Model is None after optimization")
print(f"- Final model type: {type(model)}")
return model
except Exception as e:
print(f"ERROR loading model: {str(e)}")
import traceback
traceback.print_exc()
raise
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
return torch.randint(
0, 1000,
(batch_size, self.config.seq_length),
device="cuda",
dtype=torch.long
)
def _run_inference(
self,
input_ids: torch.Tensor,
cuda_graph_wrapper: Optional[CUDAGraphWrapper] = None
) -> Tuple[float, torch.Tensor]:
attention_mask = torch.ones_like(input_ids)
with torch.no_grad(), self.timer.timing():
if cuda_graph_wrapper is not None:
output = cuda_graph_wrapper(input_ids, attention_mask)
else:
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
return self.timer.elapsed_time(), output
def run(self) -> Dict[int, Dict[str, float]]:
results = {}
# Reset peak memory stats
torch.cuda.reset_peak_memory_stats()
for batch_size in self.config.batch_sizes:
print(f"\nTesting batch size: {batch_size}")
times = []
# Get or create CUDA graph for this batch size
cuda_graph_wrapper = (
self.cuda_graphs.get_or_create(batch_size)
if self.cuda_graphs is not None
else None
)
# Pre-allocate input tensor
input_ids = self._create_random_batch(batch_size)
print(f"Input shape: {input_ids.shape}")
# Run benchmark
for i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
try:
elapsed_time, output = self._run_inference(input_ids, cuda_graph_wrapper)
if i == 0: # Only print on first run
print(f"Output shape: {output.last_hidden_state.shape}")
times.append(elapsed_time)
except Exception as e:
print(f"Error during inference: {e}")
break
if not times:
print(f"No successful runs for batch size {batch_size}, skipping")
continue
# Calculate statistics
avg_time = np.mean(times)
std_time = np.std(times)
throughput = batch_size / avg_time
results[batch_size] = {
"avg_time": avg_time,
"std_time": std_time,
"throughput": throughput,
}
print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
print(f"Throughput: {throughput:.2f} sequences/second")
# Log memory usage
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)
print(f"\nPeak GPU memory usage: {peak_memory_gb:.2f} GB")
# Add memory info to results
for batch_size in results:
results[batch_size]["peak_memory_gb"] = peak_memory_gb
return results
def main():
parser = argparse.ArgumentParser(description="Model Inference Benchmark")
parser.add_argument(
"--model_path",
type=str,
default="facebook/contriever",
help="Path to the model",
)
parser.add_argument(
"--batch_sizes",
type=str,
default="1,2,4,8,10,16,20,32,40,64,128,256,512,1024,2048,4096,8192",
help="Comma-separated list of batch sizes",
)
parser.add_argument(
"--seq_length",
type=int,
default=256,
help="Sequence length for input",
)
parser.add_argument(
"--num_runs",
type=int,
default=5,
help="Number of runs for each batch size",
)
parser.add_argument(
"--use_fp16",
action="store_true",
help="Enable FP16 inference",
)
parser.add_argument(
"--use_int4",
action="store_true",
help="Enable INT4 quantization using bitsandbytes",
)
parser.add_argument(
"--use_int8",
action="store_true",
help="Enable INT8 quantization for both activations and weights using bitsandbytes",
)
parser.add_argument(
"--use_cuda_graphs",
action="store_true",
help="Enable CUDA Graphs optimization",
)
parser.add_argument(
"--use_flash_attention",
action="store_true",
help="Enable Flash Attention 2 if available",
)
parser.add_argument(
"--use_linear8bitlt",
action="store_true",
help="Enable Linear8bitLt quantization for all linear layers",
)
args = parser.parse_args()
# Print arguments for debugging
print("\nCommand line arguments:")
for arg, value in vars(args).items():
print(f"- {arg}: {value}")
config = BenchmarkConfig(
model_path=args.model_path,
batch_sizes=[int(bs) for bs in args.batch_sizes.split(",")],
seq_length=args.seq_length,
num_runs=args.num_runs,
use_fp16=args.use_fp16,
use_int4=args.use_int4,
use_int8=args.use_int8,
use_cuda_graphs=args.use_cuda_graphs,
use_flash_attention=args.use_flash_attention,
use_linear8bitlt=args.use_linear8bitlt,
)
# Print configuration for debugging
print("\nBenchmark configuration:")
for field, value in vars(config).items():
print(f"- {field}: {value}")
try:
benchmark = Benchmark(config)
results = benchmark.run()
# Save results to file
import json
import os
# Create results directory if it doesn't exist
os.makedirs("results", exist_ok=True)
# Generate filename based on configuration
precision_type = "int4" if config.use_int4 else "fp16" if config.use_fp16 else "fp32"
model_name = os.path.basename(config.model_path)
output_file = f"results/benchmark_{model_name}_{precision_type}.json"
# Save results
with open(output_file, "w") as f:
json.dump(
{
"config": {k: str(v) if isinstance(v, list) else v for k, v in vars(config).items()},
"results": {str(k): v for k, v in results.items()}
},
f,
indent=2
)
print(f"Results saved to {output_file}")
except Exception as e:
print(f"Benchmark failed: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()

View File

@@ -1,376 +0,0 @@
import argparse
import time
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
from torch import nn
from transformers import AutoModel
from tqdm import tqdm
from contextlib import contextmanager
import math
@dataclass
class BenchmarkConfig:
model_path: str
batch_sizes: List[int]
seq_length: int
num_runs: int
use_fp16: bool = True
use_cuda_graphs: bool = False
use_flash_attention: bool = False
max_batch_size: int = 256 # Maximum batch size before splitting
class CUDAGraphContainer:
"""Container for managing CUDA graphs for different batch sizes."""
def __init__(self, model: nn.Module, seq_length: int, max_batch_size: int):
self.model = model
self.seq_length = seq_length
self.max_batch_size = max_batch_size
self.graphs: Dict[int, CUDAGraphWrapper] = {}
def get_or_create(self, batch_size: int) -> 'CUDAGraphWrapper':
# For CUDA graphs, we always use the actual batch size or max_batch_size
effective_batch_size = min(batch_size, self.max_batch_size)
if effective_batch_size not in self.graphs:
self.graphs[effective_batch_size] = CUDAGraphWrapper(
self.model, effective_batch_size, self.seq_length
)
return self.graphs[effective_batch_size]
class CUDAGraphWrapper:
"""Wrapper for CUDA graph capture and replay."""
def __init__(self, model: nn.Module, batch_size: int, seq_length: int):
self.model = model
self.static_input = self._create_random_batch(batch_size, seq_length)
self.static_attention_mask = torch.ones_like(self.static_input)
# Warm up
self._warmup()
# Capture graph
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph):
self.static_output = self.model(
input_ids=self.static_input,
attention_mask=self.static_attention_mask
)
def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
return torch.randint(
0, 1000, (batch_size, seq_length),
device="cuda",
dtype=torch.long
)
def _warmup(self, num_warmup: int = 3):
with torch.no_grad():
for _ in range(num_warmup):
self.model(
input_ids=self.static_input,
attention_mask=self.static_attention_mask
)
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
self.static_input.copy_(input_ids)
self.static_attention_mask.copy_(attention_mask)
self.graph.replay()
return self.static_output
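# Illustrative usage (added sketch, not in the original file): capture happens
# once per (batch_size, seq_length) at construction; later calls only copy new
# data into the static buffers and replay. Input shapes must match exactly.
#
#   wrapper = CUDAGraphWrapper(model, batch_size=8, seq_length=256)
#   out = wrapper(input_ids, attention_mask)  # replays the captured graph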
class ModelOptimizer:
"""Applies various optimizations to the model."""
@staticmethod
def optimize(model: nn.Module, config: BenchmarkConfig) -> nn.Module:
print("\nApplying model optimizations:")
# Move to GPU
model = model.cuda()
print("- Model moved to GPU")
# FP16
if config.use_fp16:
model = model.half()
print("- Using FP16 precision")
# Check if using SDPA
if torch.version.cuda and tuple(int(v) for v in torch.version.cuda.split(".")[:2]) >= (11, 6):
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
# No need to do anything as it's automatically enabled
else:
print("- PyTorch SDPA not available")
# Flash Attention
if config.use_flash_attention:
try:
from flash_attn.flash_attention import FlashAttention
print("- Flash Attention 2 available")
if hasattr(model.config, "attention_mode"):
model.config.attention_mode = "flash_attention_2"
print(" - Enabled Flash Attention 2 mode")
except ImportError:
print("- Flash Attention not available")
# Optimize LayerNorm
try:
num_layernorms = 0
for module in model.modules():
if isinstance(module, torch.nn.LayerNorm):
module.forward = torch.jit.script(module.forward)
num_layernorms += 1
if num_layernorms > 0:
print(f"- Optimized {num_layernorms} LayerNorm modules with TorchScript")
except Exception as e:
print(f"- LayerNorm optimization failed: {e}")
# Memory efficient attention
try:
from xformers.ops import memory_efficient_attention
model.enable_xformers_memory_efficient_attention()
print("- Enabled xformers memory efficient attention")
except (ImportError, AttributeError):
print("- Xformers not available")
model.eval()
print("- Model set to eval mode")
return model
class Timer:
"""Handles accurate GPU timing using CUDA events."""
def __init__(self):
self.start_event = torch.cuda.Event(enable_timing=True)
self.end_event = torch.cuda.Event(enable_timing=True)
@contextmanager
def timing(self):
self.start_event.record()
yield
self.end_event.record()
self.end_event.synchronize()
def elapsed_time(self) -> float:
return self.start_event.elapsed_time(self.end_event) / 1000 # ms to seconds
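# Illustrative usage (added sketch): CUDA events record on the GPU stream, so
# elapsed_time() reflects device time for the region rather than host overhead.
#
#   timer = Timer()
#   with timer.timing():
#       _ = model(input_ids=x, attention_mask=mask)
#   seconds = timer.elapsed_time()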
class Benchmark:
"""Main benchmark runner."""
def __init__(self, config: BenchmarkConfig):
self.config = config
self.model = self._load_model()
self.cuda_graphs = (
CUDAGraphContainer(self.model, config.seq_length, config.max_batch_size)
if config.use_cuda_graphs
else None
)
self.timer = Timer()
def _load_model(self) -> nn.Module:
print(f"Loading model from {self.config.model_path}...")
model = AutoModel.from_pretrained(self.config.model_path)
return ModelOptimizer.optimize(model, self.config)
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
return torch.randint(
0, 1000,
(batch_size, self.config.seq_length),
device="cuda",
dtype=torch.long
)
def _run_inference(
self,
input_ids: torch.Tensor,
cuda_graph_wrapper: Optional[CUDAGraphWrapper] = None
) -> Tuple[float, torch.Tensor]:
attention_mask = torch.ones_like(input_ids)
original_batch_size = input_ids.shape[0]
print(f"Original input_ids shape: {input_ids.shape}")
# Split large batches to avoid OOM
max_batch_size = self.config.max_batch_size
if original_batch_size > max_batch_size:
print(f"Splitting batch of size {original_batch_size} into chunks of {max_batch_size}")
total_time = 0
outputs = []
with torch.no_grad():
for i in range(0, original_batch_size, max_batch_size):
end_idx = min(i + max_batch_size, original_batch_size)
batch_slice = input_ids[i:end_idx]
mask_slice = attention_mask[i:end_idx]
print(f"Processing chunk {i//max_batch_size + 1}: shape {batch_slice.shape}")
# Use CUDA graph if available (with the smaller batch size)
chunk_cuda_graph = None
if cuda_graph_wrapper is not None:
chunk_cuda_graph = self.cuda_graphs.get_or_create(batch_slice.shape[0])
with self.timer.timing():
if chunk_cuda_graph is not None:
chunk_output = chunk_cuda_graph(batch_slice, mask_slice)
else:
chunk_output = self.model(input_ids=batch_slice, attention_mask=mask_slice)
total_time += self.timer.elapsed_time()
outputs.append(chunk_output.last_hidden_state)
# Combine outputs
combined_output = torch.cat(outputs, dim=0)
print(f"Combined output shape: {combined_output.shape}")
# Create a wrapper object similar to model output to maintain consistency
class DummyOutput:
def __init__(self, hidden_states):
self.last_hidden_state = hidden_states
output = DummyOutput(combined_output)
return total_time, output
else:
# Process normally for small batches
with torch.no_grad(), self.timer.timing():
if cuda_graph_wrapper is not None:
output = cuda_graph_wrapper(input_ids, attention_mask)
else:
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
print(f"Output shape: {output.last_hidden_state.shape}")
return self.timer.elapsed_time(), output
def run(self) -> Dict[int, Dict[str, float]]:
results = {}
for batch_size in self.config.batch_sizes:
print(f"\nTesting batch size: {batch_size}")
times = []
# Get or create CUDA graph for this batch size
cuda_graph_wrapper = None
if self.cuda_graphs is not None:
if batch_size <= self.config.max_batch_size:
cuda_graph_wrapper = self.cuda_graphs.get_or_create(batch_size)
else:
# For large batches, we'll use the max_batch_size graph in chunks
cuda_graph_wrapper = True # Just a flag to indicate we want to use CUDA graphs
# Pre-allocate input tensor
input_ids = self._create_random_batch(batch_size)
# Run benchmark
for run_idx in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
elapsed_time, _ = self._run_inference(input_ids, cuda_graph_wrapper)
times.append(elapsed_time)
print(f"Run {run_idx+1}: {elapsed_time:.4f}s")
# Calculate statistics
avg_time = np.mean(times)
std_time = np.std(times)
throughput = batch_size / avg_time
results[batch_size] = {
"avg_time": avg_time,
"std_time": std_time,
"throughput": throughput,
}
print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
print(f"Throughput: {throughput:.2f} sequences/second")
return results
def main():
parser = argparse.ArgumentParser(description="Model Inference Benchmark")
parser.add_argument(
"--model_path",
type=str,
default="facebook/contriever",
help="Path to the model",
)
parser.add_argument(
"--batch_sizes",
type=str,
default="1,2,4,8,16,32,64,128,256,512,1024,2048,4096",
help="Comma-separated list of batch sizes",
)
parser.add_argument(
"--seq_length",
type=int,
default=256,
help="Sequence length for input",
)
parser.add_argument(
"--num_runs",
type=int,
default=5,
help="Number of runs for each batch size",
)
parser.add_argument(
"--no_fp16",
action="store_true",
help="Disable FP16 inference",
)
parser.add_argument(
"--use_cuda_graphs",
action="store_true",
help="Enable CUDA Graphs optimization",
)
parser.add_argument(
"--use_flash_attention",
action="store_true",
help="Enable Flash Attention 2 if available",
)
parser.add_argument(
"--max_batch_size",
type=int,
default=256,
help="Maximum batch size before splitting to prevent OOM",
)
args = parser.parse_args()
config = BenchmarkConfig(
model_path=args.model_path,
batch_sizes=[int(bs) for bs in args.batch_sizes.split(",")],
seq_length=args.seq_length,
num_runs=args.num_runs,
use_fp16=not args.no_fp16,
use_cuda_graphs=args.use_cuda_graphs,
use_flash_attention=args.use_flash_attention,
max_batch_size=args.max_batch_size,
)
benchmark = Benchmark(config)
results = benchmark.run()
# Print overall summary
print("\n===== BENCHMARK SUMMARY =====")
print(f"Model: {config.model_path}")
print(f"Sequence Length: {config.seq_length}")
print(f"FP16: {config.use_fp16}")
print(f"CUDA Graphs: {config.use_cuda_graphs}")
print(f"Flash Attention: {config.use_flash_attention}")
print(f"Max Batch Size: {config.max_batch_size}")
print("\nResults:")
print("\nBatch Size | Avg Time (s) | Throughput (seq/s)")
print("-" * 50)
for bs in sorted(results.keys()):
r = results[bs]
print(f"{bs:^10} | {r['avg_time']:^12.4f} | {r['throughput']:^17.2f}")
if __name__ == "__main__":
main()

View File

@@ -1,218 +0,0 @@
import torch
import torch.nn as nn
import time
import torch.nn.functional as F
# Import necessary functions from the quantize.py file
def get_group_qparams(w, n_bit=4, groupsize=128):
# needed for GPTQ with padding
if groupsize > w.shape[-1]:
groupsize = w.shape[-1]
assert groupsize > 1
assert w.shape[-1] % groupsize == 0
assert w.dim() == 2
to_quant = w.reshape(-1, groupsize)
assert torch.isnan(to_quant).sum() == 0
max_val = to_quant.amax(dim=1, keepdim=True)
min_val = to_quant.amin(dim=1, keepdim=True)
max_int = 2**n_bit - 1
scales = (max_val - min_val).clamp(min=1e-6) / max_int
zeros = min_val + scales * (2 ** (n_bit - 1))
return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(
torch.bfloat16
).reshape(w.shape[0], -1)
def pack_scales_and_zeros(scales, zeros):
assert scales.shape == zeros.shape
assert scales.dtype == torch.bfloat16
assert zeros.dtype == torch.bfloat16
return (
torch.cat(
[
scales.reshape(scales.size(0), scales.size(1), 1),
zeros.reshape(zeros.size(0), zeros.size(1), 1),
],
2,
)
.transpose(0, 1)
.contiguous()
)
def group_quantize_tensor(w, n_bit=4, groupsize=128):
scales, zeros = get_group_qparams(w, n_bit, groupsize)
w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
scales_and_zeros = pack_scales_and_zeros(scales, zeros)
return w_int32, scales_and_zeros
def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
assert groupsize > 1
# needed for GPTQ single column quantize
if groupsize > w.shape[-1] and scales.shape[-1] == 1:
groupsize = w.shape[-1]
assert w.shape[-1] % groupsize == 0
assert w.dim() == 2
to_quant = w.reshape(-1, groupsize)
assert torch.isnan(to_quant).sum() == 0
scales = scales.reshape(-1, 1)
zeros = zeros.reshape(-1, 1)
min_val = zeros - scales * (2 ** (n_bit - 1))
max_int = 2**n_bit - 1
min_int = 0
w_int32 = (
to_quant.sub(min_val)
.div(scales)
.round()
.clamp_(min_int, max_int)
.to(torch.int32)
.reshape_as(w)
)
return w_int32
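# Added for illustration (not in the original file): the inverse mapping of the
# quantizer above, useful for checking round-trip error. It assumes the same
# (scales, zeros) layout produced by get_group_qparams.
def group_dequantize_tensor_from_qparams(w_int32, scales, zeros, n_bit=4, groupsize=128):
    # w ≈ (w_int32 - 2**(n_bit - 1)) * scales + zeros, applied per group
    w_grouped = w_int32.reshape(-1, groupsize)
    scales = scales.reshape(-1, 1)
    zeros = zeros.reshape(-1, 1)
    return (
        w_grouped.sub(2 ** (n_bit - 1))
        .to(torch.bfloat16)
        .mul(scales)
        .add(zeros)
        .reshape_as(w_int32)
    )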
def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
weight_int32, scales_and_zeros = group_quantize_tensor(
weight_bf16, n_bit=4, groupsize=groupsize
)
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
return weight_int4pack, scales_and_zeros
def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
origin_x_size = x.size()
x = x.reshape(-1, origin_x_size[-1])
c = torch.ops.aten._weight_int4pack_mm(x, weight_int4pack, groupsize, scales_and_zeros)
new_shape = origin_x_size[:-1] + (out_features,)
c = c.reshape(new_shape)
return c
class WeightOnlyInt4Linear(torch.nn.Module):
__constants__ = ['in_features', 'out_features']
in_features: int
out_features: int
weight: torch.Tensor
def __init__(
self, in_features: int, out_features: int,
bias=False, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8
) -> None:
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.groupsize = groupsize
self.inner_k_tiles = inner_k_tiles
assert out_features % 8 == 0, "require out_features % 8 == 0"
assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
self.register_buffer(
"weight",
torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
)
self.register_buffer(
"scales_and_zeros",
torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16)
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
input = input.to(torch.bfloat16)
return linear_forward_int4(
input,
self.weight, self.scales_and_zeros, self.out_features, self.groupsize
)
# Define dimensions that satisfy the requirements for INT4 quantization
# in_features must be divisible by inner_k_tiles * 16
# out_features must be divisible by 8
in_features = 1024 # Must be divisible by inner_k_tiles * 16
out_features = 2048 # Must be divisible by 8
groupsize = 128
inner_k_tiles = 8
# Create models
fp16_model = nn.Sequential(
nn.Linear(in_features, out_features, bias=False)
)
# Create INT4 model
int4_model = nn.Sequential(
WeightOnlyInt4Linear(in_features, out_features, bias=False,
groupsize=groupsize, inner_k_tiles=inner_k_tiles)
)
# Quantize the weights and set up the INT4 model
with torch.no_grad():
# Convert FP16 weights to INT4
fp16_weight = fp16_model[0].weight.data.to(torch.bfloat16)
weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros(
fp16_weight, groupsize, inner_k_tiles
)
# Set the quantized weights in the INT4 model
int4_model[0].weight.copy_(weight_int4pack)
int4_model[0].scales_and_zeros.copy_(scales_and_zeros)
# Move models to GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fp16_model = fp16_model.to(device)
int4_model = int4_model.to(device)
# Create random input tensor
batch_size = 1024
input_tensor = torch.randn(batch_size, in_features, device=device)
input_tensor_bf16 = input_tensor.to(torch.bfloat16)
# Speed test function
def speed_test(model, input_tensor, name, num_iterations=100):
# Warmup
for _ in range(10):
_ = model(input_tensor)
# Actual timing
torch.cuda.synchronize()
start_time = time.time()
for _ in range(num_iterations):
_ = model(input_tensor)
torch.cuda.synchronize()
end_time = time.time()
avg_time = (end_time - start_time) / num_iterations
print(f"{name} model: {avg_time:.6f} seconds per iteration")
return avg_time
# Run speed tests
with torch.no_grad(): # Disable gradient calculation for inference
print(f"Running benchmark with batch_size={batch_size}, in_features={in_features}, out_features={out_features}")
print(f"INT4 parameters: groupsize={groupsize}, inner_k_tiles={inner_k_tiles}")
fp16_time = speed_test(fp16_model, input_tensor_bf16, "FP16")
int4_time = speed_test(int4_model, input_tensor, "INT4")
# Calculate speedup
speedup = fp16_time / int4_time
print(f"INT4 is {speedup:.2f}x faster than FP16")
# Calculate memory savings
fp16_memory = fp16_model[0].weight.nelement() * fp16_model[0].weight.element_size()
int4_memory = (int4_model[0].weight.nelement() * int4_model[0].weight.element_size() +
int4_model[0].scales_and_zeros.nelement() * int4_model[0].scales_and_zeros.element_size())
memory_reduction = fp16_memory / int4_memory
print(f"Memory reduction: {memory_reduction:.2f}x ({fp16_memory/1024/1024:.2f} MB vs {int4_memory/1024/1024:.2f} MB)")
# Check accuracy
with torch.no_grad():
fp16_output = fp16_model(input_tensor_bf16)
int4_output = int4_model(input_tensor)
# Calculate error metrics
abs_error = torch.abs(fp16_output - int4_output)
rel_error = abs_error / (torch.abs(fp16_output) + 1e-7)
print(f"Mean absolute error: {abs_error.mean().item():.6f}")
print(f"Max absolute error: {abs_error.max().item():.6f}")
print(f"Mean relative error: {rel_error.mean().item():.6f}")

View File

@@ -1,83 +0,0 @@
import torch
import nvmath.bindings.cublas
import ctypes
# Create a cuBLAS handle
handle = nvmath.bindings.cublas.create()
# Prepare the data: use uint8 dtype and make sure memory is contiguous
m, n, k = 64, 32, 48
a = (torch.rand(m, k, device="cuda") * 255).to(torch.uint8).contiguous()
b = (torch.rand(k, n, device="cuda") * 255).to(torch.uint8).contiguous()
c = torch.zeros(m, n, device="cuda", dtype=torch.uint8).contiguous()
# Make sure the tensors are on CUDA
assert a.is_cuda and b.is_cuda and c.is_cuda
# Make sure the tensors are contiguous
assert a.is_contiguous() and b.is_contiguous() and c.is_contiguous()
# Get the raw device pointers
a_ptr = a.data_ptr()
b_ptr = b.data_ptr()
c_ptr = c.data_ptr()
# Set the operation parameters
transa = 0 # CUBLAS_OP_N (no transpose)
transb = 0 # CUBLAS_OP_N (no transpose)
transc = 0 # CUBLAS_OP_N (no transpose)
# Bias values
a_bias = 0
b_bias = 0
c_bias = 0
# Set the correct leading dimensions
lda = k # leading dimension of A
ldb = n # leading dimension of B
ldc = n # leading dimension of C
c_mult = 1
c_shift = 0
# Print debug info
print(f"a shape: {a.shape}, a_ptr: {a_ptr}")
print(f"b shape: {b.shape}, b_ptr: {b_ptr}")
print(f"c shape: {c.shape}, c_ptr: {c_ptr}")
try:
# Call uint8gemm_bias
nvmath.bindings.cublas.uint8gemm_bias(
handle,
transa, transb, transc,
m, n, k,
a_ptr, a_bias, lda,
b_ptr, b_bias, ldb,
c_ptr, c_bias, ldc,
c_mult, c_shift
)
except Exception as e:
print(f"Error: {e}")
# Try converting the pointers with ctypes
a_ptr_c = ctypes.c_void_p(a_ptr).value
b_ptr_c = ctypes.c_void_p(b_ptr).value
c_ptr_c = ctypes.c_void_p(c_ptr).value
print(f"Using ctypes: a_ptr: {a_ptr_c}, b_ptr: {b_ptr_c}, c_ptr: {c_ptr_c}")
# Retry the call
nvmath.bindings.cublas.uint8gemm_bias(
handle,
transa, transb, transc,
m, n, k,
a_ptr_c, a_bias, lda,
b_ptr_c, b_bias, ldb,
c_ptr_c, c_bias, ldc,
c_mult, c_shift
)
# Destroy the cuBLAS handle
nvmath.bindings.cublas.destroy(handle)
# Print the result
print("Result:")
print(c)

View File

@@ -1,23 +0,0 @@
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor import oneshot
# Select quantization algorithm. In this case, we:
# * apply SmoothQuant to make the activations easier to quantize
# * quantize the weights to int8 with GPTQ (static per channel)
# * quantize the activations to int8 (dynamic per token)
recipe = [
SmoothQuantModifier(smoothing_strength=0.8),
GPTQModifier(scheme="W8A8", targets="Linear", ignore=["lm_head"]),
]
# Apply quantization using the built in open_platypus dataset.
# * See examples for demos showing how to pass a custom calibration set
oneshot(
model="facebook/contriever",
dataset="open_platypus",
recipe=recipe,
output_dir="contriever-INT4",
max_seq_length=2048,
num_calibration_samples=512,
)
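# Added note (sketch, not in the original file): oneshot() writes the compressed
# checkpoint to output_dir; loading it back is the usual from_pretrained call,
# assuming compressed-tensors support is available in your transformers install:
#
#   from transformers import AutoModel
#   model = AutoModel.from_pretrained("contriever-INT4")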

View File

@@ -1,41 +0,0 @@
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
# SPDX-License-Identifier: Apache-2.0
"""
This example demonstrates basic matrix multiplication of FP8 tensors.
In narrow-precision operations, quantization scales must be provided for each tensor. These
scales are used to dequantize input operands and quantize the result. Without proper
scaling, the results of FP8 operations will likely exceed the type's range.
FP8 is only supported with cuBLAS 12.8 or newer and on devices with compute
capability 8.9 or higher.
"""
import torch
import nvmath
# Prepare sample input data. Note that N, M and K must be divisible by 16 for FP8.
# cuBLAS requires B to be column-major, so we first create a row-major tensor and then
# transpose it.
m, n, k = 64, 32, 48
a = (torch.rand(m, k, device="cuda") * 10).type(torch.float8_e4m3fn)
b = (torch.rand(n, k, device="cuda") * 10).type(torch.float8_e4m3fn).T
# Prepare quantization scales. The scales must allow the result to fit within the dynamic
# range of the data type used. Scales can be provided either as a dictionary or as a
# MatmulQuantizationScales object. Note that scales are only allowed for FP8 operands.
scales = {"a": 1, "b": 1, "d": 0.1}
# Perform the multiplication. The result of the multiplication will be:
# (scales.a * A) @ (scales.b * B) * scales.d
result = nvmath.linalg.advanced.matmul(a, b, quantization_scales=scales)
# Check how scaling helped to fit into the dynamic range of float8_e4m3fn type.
result_without_scaling = nvmath.linalg.advanced.matmul(a, b, quantization_scales={"a": 1, "b": 1, "d": 1})
print("Without scaling, most of the elements were clamped to the maximum value of float8_e4m3fn type (448):")
print(result_without_scaling)
print(f"\nWith D scale set to {scales['d']}, they were scaled down to fit into the dynamic range of float8_e4m3fn:")
print(result)
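# Added sanity check (sketch, not part of the original example): compare the
# scaled FP8 result against a float32 reference built from the same operands.
ref = (a.to(torch.float32) @ b.to(torch.float32)) * scales["d"]
diff = (result.to(torch.float32) - ref).abs().max()
print(f"\nMax abs diff vs float32 reference: {diff.item():.4f}")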

View File

@@ -1,58 +0,0 @@
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from pathlib import Path
def save_model_in_pth_format(model_name, output_dir):
"""
Download a model from Hugging Face and save it in PTH format
for use with quantization benchmarks.
Args:
model_name: Name of the model on Hugging Face
output_dir: Directory to save the model
"""
print(f"Loading model {model_name}...")
# Create output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)
# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True
)
# Save tokenizer
tokenizer.save_pretrained(output_dir)
# Extract and save the model weights in PTH format
model_state_dict = model.state_dict()
# Save the model weights
model_path = Path(output_dir) / "model.pth"
torch.save(model_state_dict, model_path)
print(f"Model saved to {model_path}")
# Print model size information
param_count = sum(p.numel() for p in model.parameters())
model_size_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 * 1024)
print(f"Model parameters: {param_count:,}")
print(f"Model size: {model_size_mb:.2f} MB")
return model_path
if __name__ == "__main__":
# Use a small model for testing
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
output_dir = "./tinyllama-1.1b-chat"
model_path = save_model_in_pth_format(model_name, output_dir)
print("\nYou can now use this model with the INT4 benchmark script.")
print("Example command:")
print(f"python int4benchmark.py --model_path {model_path}")

View File

@@ -1,677 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "cab91cfc",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ubuntu/Power-RAG/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import copy\n",
"import dataclasses\n",
"import os\n",
"import time\n",
"import pathlib\n",
"import itertools\n",
"import multiprocessing\n",
"import scipy\n",
"import numpy as np\n",
"import pandas as pd\n",
"import pickle\n",
"import gzip\n",
"import threading\n",
"import queue\n",
"import pytz\n",
"import traceback\n",
"from datetime import datetime\n",
"from tqdm.auto import tqdm, trange\n",
"from typing import Any\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.ticker as mtick\n",
"%matplotlib inline\n",
"%config InlineBackend.figure_format='retina'"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "8d24fbd7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sat Apr 12 00:10:05 2025 \n",
"+-----------------------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 550.120 Driver Version: 550.120 CUDA Version: 12.4 |\n",
"|-----------------------------------------+------------------------+----------------------+\n",
"| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|=========================================+========================+======================|\n",
"| 0 NVIDIA A10G Off | 00000000:00:1E.0 Off | 0 |\n",
"| 0% 27C P8 15W / 300W | 4MiB / 23028MiB | 0% Default |\n",
"| | | N/A |\n",
"+-----------------------------------------+------------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=========================================================================================|\n",
"| No running processes found |\n",
"+-----------------------------------------------------------------------------------------+\n"
]
}
],
"source": [
"!nvidia-smi"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "538b2c11",
"metadata": {},
"outputs": [],
"source": [
"def benchmark(f, *, f_setup=None, min_repeat: int, min_secs: float, tqdm_kwargs: dict | None=None) -> np.ndarray:\n",
" latency = []\n",
" \n",
" # First run, ignore min_secs\n",
" if f_setup is not None:\n",
" f_setup()\n",
" st = time.perf_counter_ns()\n",
" f()\n",
" ed = time.perf_counter_ns()\n",
" latency.append((ed-st)/1e9)\n",
" \n",
" # Subsequent runs, until reaching both min_repeat and min_secs\n",
" min_nanos = int(min_secs * 1e9)\n",
" start_nanos = time.perf_counter_ns()\n",
" while True:\n",
" now_nanos = time.perf_counter_ns()\n",
" if len(latency) > min_repeat and now_nanos - start_nanos > min_nanos:\n",
" break\n",
" if f_setup is not None:\n",
" f_setup()\n",
" st = time.perf_counter_ns()\n",
" f()\n",
" ed = time.perf_counter_ns()\n",
" latency.append((ed-st)/1e9)\n",
" return np.array(latency)\n",
"\n",
"def tail_mean(xs, skip=0.2):\n",
" return xs[int(len(xs) * skip):].mean()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "02c9c9b1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch.autograd.grad_mode.set_grad_enabled at 0x7c5afc12b850>"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch\n",
"torch.set_grad_enabled(False)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "3405fdc7",
"metadata": {},
"outputs": [],
"source": [
"nd_list = list(itertools.chain(itertools.product([12, 3], [256])))\n",
"seqlen_list = [256]\n",
"bs_list = [2,4,8,16,32,64,128,256,512,1024,2048]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "10dc981a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[(12, 256), (3, 256)]\n",
"[256]\n",
"[2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]\n"
]
}
],
"source": [
"print(nd_list)\n",
"print(seqlen_list)\n",
"print(bs_list)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "7e0ee385",
"metadata": {},
"outputs": [],
"source": [
"def benchmark_dense(out, nd_list, seqlen_list, bs_list):\n",
" seqlen_list = [1] + seqlen_list\n",
" total = len(list(itertools.product(nd_list, seqlen_list, bs_list)))\n",
" pbar = tqdm(total=total)\n",
" for (n, d), seqlen in reversed(list(itertools.product(nd_list, seqlen_list))):\n",
" h = n * d\n",
" maxbs = max(bs_list)\n",
" print(maxbs, n, d, seqlen)\n",
" cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=\"cuda:0\")\n",
" X = torch.rand((maxbs, seqlen, h), dtype=torch.bfloat16, device=\"cuda:0\")\n",
" W = torch.rand((h, h), dtype=torch.bfloat16, device=\"cuda:0\")\n",
" torch.cuda.synchronize()\n",
" for bs in reversed(bs_list):\n",
" pbar.set_postfix(n=n, h=h, d=d, seqlen=seqlen, bs=bs)\n",
" def run():\n",
" torch.matmul(X[:bs], W)\n",
" torch.cuda.synchronize()\n",
" def clear_cache():\n",
" cache.zero_()\n",
" torch.cuda.synchronize()\n",
" latency = benchmark(run, f_setup=clear_cache, min_repeat=20, min_secs=2)\n",
" l = tail_mean(latency)\n",
" out.append({\n",
" \"n\": n,\n",
" \"d\": d,\n",
" \"seqlen\": seqlen,\n",
" \"bs\": bs,\n",
" \"latency\": l\n",
" })\n",
" pbar.update()\n",
" del cache, X, W\n",
" torch.cuda.empty_cache()\n",
" pbar.close()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "c206a502",
"metadata": {},
"outputs": [],
"source": [
"def benchmark_qk_init(out, nd_list, seqlen_list, bs_list):\n",
" total = len(list(itertools.product(nd_list, seqlen_list, bs_list)))\n",
" pbar = tqdm(total=total)\n",
" for (n, d), seqlen in reversed(list(itertools.product(nd_list, seqlen_list))):\n",
" h = n * d\n",
" try:\n",
" maxbs = max(b for b in bs_list if b*n*seqlen*d*2*2+b*n*seqlen**2*2 < 80e9)\n",
" except ValueError:\n",
" pbar.update(len(bs_list))\n",
" continue\n",
" cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=\"cuda:0\")\n",
" Qmax = torch.rand((maxbs, n, seqlen, d), dtype=torch.bfloat16, device=\"cuda:0\")\n",
" Kmax = torch.rand((maxbs, n, seqlen, d), dtype=torch.bfloat16, device=\"cuda:0\")\n",
" torch.cuda.synchronize()\n",
" for bs in reversed(bs_list):\n",
" pbar.set_postfix(n=n, h=h, d=d, seqlen=seqlen, bs=bs)\n",
" if bs > maxbs:\n",
" pbar.update()\n",
" continue\n",
" Q = Qmax[:bs]\n",
" K = Kmax[:bs]\n",
" def run():\n",
" torch.bmm(Q.view(bs * n, seqlen, d), K.view(bs * n, seqlen, d).transpose(1, 2))\n",
" torch.cuda.synchronize()\n",
" def clear_cache():\n",
" cache.zero_()\n",
" torch.cuda.synchronize()\n",
" latency = benchmark(run, f_setup=clear_cache, min_repeat=20, min_secs=2)\n",
" l = tail_mean(latency)\n",
" out.append({\n",
" \"n\": n,\n",
" \"d\": d,\n",
" \"seqlen\": seqlen,\n",
" \"bs\": bs,\n",
" \"latency\": l\n",
" })\n",
" pbar.update()\n",
" del cache, Q, K, Qmax, Kmax\n",
" torch.cuda.empty_cache()\n",
" pbar.close()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "a3a2103c",
"metadata": {},
"outputs": [],
"source": [
"def benchmark_qk_ar(out, nd_list, seqlen_list, bs_list):\n",
" total = len(list(itertools.product(nd_list, seqlen_list, bs_list)))\n",
" pbar = tqdm(total=total)\n",
" for (n, d), seqlen in reversed(list(itertools.product(nd_list, seqlen_list))):\n",
" h = n * d\n",
" try:\n",
" maxbs = max(b for b in bs_list if b*n*(1+seqlen)*d*2+b*n*seqlen*2 < 80e9)\n",
" except ValueError:\n",
" pbar.update(len(bs_list))\n",
" continue\n",
" cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=\"cuda:0\")\n",
" Qmax = torch.rand((maxbs, n, 1, d), dtype=torch.bfloat16, device=\"cuda:0\")\n",
" Kmax = torch.rand((maxbs, n, seqlen, d), dtype=torch.bfloat16, device=\"cuda:0\")\n",
" torch.cuda.synchronize()\n",
" for bs in reversed(bs_list):\n",
" pbar.set_postfix(n=n, h=h, d=d, seqlen=seqlen, bs=bs)\n",
" if bs > maxbs:\n",
" pbar.update()\n",
" continue\n",
" Q = Qmax[:bs]\n",
" K = Kmax[:bs]\n",
" def run():\n",
" torch.bmm(Q.view(bs * n, 1, d), K.view(bs * n, seqlen, d).transpose(1, 2))\n",
" torch.cuda.synchronize()\n",
" def clear_cache():\n",
" cache.zero_()\n",
" torch.cuda.synchronize()\n",
" latency = benchmark(run, f_setup=clear_cache, min_repeat=20, min_secs=2)\n",
" l = tail_mean(latency)\n",
" out.append({\n",
" \"n\": n,\n",
" \"d\": d,\n",
" \"seqlen\": seqlen,\n",
" \"bs\": bs,\n",
" \"latency\": l\n",
" })\n",
" pbar.update()\n",
" del cache, Q, K, Qmax, Kmax\n",
" torch.cuda.empty_cache()\n",
" pbar.close()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "3aaad98a",
"metadata": {},
"outputs": [],
"source": [
"data = {}"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "18137de3",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/22 [00:00<?, ?it/s]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 22/22 [00:44<00:00, 2.04s/it, bs=2, d=256, h=3072, n=12, seqlen=256] \n"
]
}
],
"source": [
"db = []\n",
"benchmark_qk_init(db, nd_list, seqlen_list, bs_list)\n",
"data[\"qk_init\"] = db"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "26c76e15",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 22/22 [00:44<00:00, 2.01s/it, bs=2, d=256, h=3072, n=12, seqlen=256] \n"
]
}
],
"source": [
"db = []\n",
"benchmark_qk_ar(db, nd_list, seqlen_list, bs_list)\n",
"data[\"qk_ar\"] = db"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "313e36eb",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/44 [00:00<?, ?it/s, bs=2048, d=256, h=768, n=3, seqlen=256]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2048 3 256 256\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 25%|██▌ | 11/44 [00:22<01:06, 2.00s/it, bs=2048, d=256, h=768, n=3, seqlen=1] "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2048 3 256 1\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 50%|█████ | 22/44 [00:44<00:44, 2.00s/it, bs=2048, d=256, h=3072, n=12, seqlen=256]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2048 12 256 256\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 75%|███████▌ | 33/44 [01:07<00:22, 2.02s/it, bs=2048, d=256, h=3072, n=12, seqlen=1] "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2048 12 256 1\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 44/44 [01:29<00:00, 2.03s/it, bs=2, d=256, h=3072, n=12, seqlen=1] \n"
]
}
],
"source": [
"db = []\n",
"benchmark_dense(db, nd_list, seqlen_list, bs_list)\n",
"data[\"dense\"] = db"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "50c37959",
"metadata": {},
"outputs": [],
"source": [
"with gzip.open(\"data/20230516-transformer-batching1.pkl.gz\", \"wb\") as f:\n",
" pickle.dump(data, f)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "828ddb54",
"metadata": {},
"outputs": [],
"source": [
"df_dense = (\n",
" pd.DataFrame.from_dict(data[\"dense\"])\n",
" .assign(h=lambda x: x[\"n\"] * x[\"d\"])\n",
" .assign(flop=lambda x: (x[\"bs\"] * x[\"seqlen\"] * x[\"h\"]**2) * 2)\n",
" .assign(io=lambda x: (x[\"bs\"]*x[\"seqlen\"]*x[\"h\"]*2 + x[\"h\"]**2) * 2/x['latency']/1e9)\n",
" .assign(intensity=lambda x: x[\"flop\"] / x[\"io\"])\n",
" .assign(throughput=lambda x: x[\"flop\"] / x[\"latency\"])\n",
" .assign(series=\"dense\")\n",
")\n",
"df_qk_init = (\n",
" pd.DataFrame.from_dict(data[\"qk_init\"])\n",
" .assign(h=lambda x: x[\"n\"] * x[\"d\"])\n",
" .assign(flop=lambda x: (x[\"bs\"]*x[\"n\"]*x[\"d\"]*x[\"seqlen\"]**2) * 2)\n",
" .assign(io=lambda x: (x[\"bs\"]*x[\"n\"]*(x[\"seqlen\"]*x[\"d\"]*2 + x[\"seqlen\"]**2)) * 2/x['latency']/1e9)\n",
" .assign(intensity=lambda x: x[\"flop\"] / x[\"io\"])\n",
" .assign(throughput=lambda x: x[\"flop\"] / x[\"latency\"])\n",
" .assign(series=\"qk_init\")\n",
")\n",
"df_qk_ar = (\n",
" pd.DataFrame.from_dict(data[\"qk_ar\"])\n",
" .assign(h=lambda x: x[\"n\"] * x[\"d\"])\n",
" .assign(flop=lambda x: (x[\"bs\"]*x[\"n\"]*x[\"d\"]*x[\"seqlen\"]) * 2)\n",
" .assign(io=lambda x: (x[\"bs\"]*x[\"n\"]*(x[\"d\"] + x[\"seqlen\"]*x[\"d\"] + x[\"seqlen\"])) * 2)\n",
" .assign(intensity=lambda x: x[\"flop\"] / x[\"io\"])\n",
" .assign(throughput=lambda x: x[\"bs\"] / x[\"latency\"])\n",
" .assign(series=\"qk_ar\")\n",
")\n",
"pd.concat([df_dense, df_qk_init, df_qk_ar]).to_csv(\"data/transformer-batching-microbenchmarks.csv\", index=False)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "c296a395",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<module 'pandas' from '/home/ubuntu/Power-RAG/.venv/lib/python3.10/site-packages/pandas/__init__.py'>"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a25cdd5a",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "63b8a531",
"metadata": {},
"outputs": [],
"source": [
"import transformers"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "af90eff1",
"metadata": {},
"outputs": [],
"source": [
"def _gen_opt_cfg(n_layers: int, d_model: int, n_heads: int, **kwargs) -> transformers.OPTConfig:\n",
" return transformers.OPTConfig(\n",
" num_hidden_layers=n_layers,\n",
" hidden_size=d_model,\n",
" ffn_dim=d_model*4,\n",
" num_attention_heads=n_heads,\n",
" **kwargs\n",
" )\n",
"optcfg = {\n",
" # https://arxiv.org/pdf/2205.01068.pdf Table 2.1\n",
" \"125m\": _gen_opt_cfg(12, 768, 12),\n",
" \"350m\": _gen_opt_cfg(24, 1024, 16),\n",
" \"760m\": _gen_opt_cfg(24, 1536, 16),\n",
" \"1.3b\": _gen_opt_cfg(24, 2048, 32),\n",
" \"2.7b\": _gen_opt_cfg(32, 2560, 32),\n",
" \"6.7b\": _gen_opt_cfg(32, 4096, 32),\n",
" \"13b\": _gen_opt_cfg(40, 5120, 40),\n",
" \"13b_1layer\": _gen_opt_cfg(1, 5120, 40),\n",
" \"30b\": _gen_opt_cfg(48, 7168, 56),\n",
" \"66b\": _gen_opt_cfg(64, 9216, 72),\n",
" \"175b\": _gen_opt_cfg(96, 12288, 96),\n",
" \"175b_1layer\": _gen_opt_cfg(1, 12288, 96),\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5b9ebbec",
"metadata": {},
"outputs": [],
"source": [
"def greedy_sample_one(model, input_ids, attention_mask=None, past_key_values=None):\n",
" bs, tgt_len = input_ids.shape\n",
" if past_key_values is not None:\n",
" _bs, _num_heads, src_len, _head_dims = past_key_values[0][0].shape\n",
" assert bs == _bs\n",
" else:\n",
" src_len = 0\n",
" if attention_mask is None:\n",
" attention_mask = torch.ones((bs, src_len + tgt_len), device=model.device)\n",
" ret = model(\n",
" input_ids=input_ids,\n",
" attention_mask=attention_mask,\n",
" past_key_values=past_key_values,\n",
" use_cache=True, output_hidden_states=False, return_dict=True,\n",
" )\n",
" return ret\n",
"\n",
"def time_greedy_generate(model, input_ids, new_tokens):\n",
" ts = []\n",
" output = input_ids\n",
" past_key_values = None\n",
" cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=model.device)\n",
" attention_mask = torch.ones(input_ids.shape, device=model.device) \n",
" for _ in range(new_tokens):\n",
" cache.zero_()\n",
" torch.cuda.synchronize()\n",
" st = time.perf_counter_ns()\n",
" \n",
" ret = greedy_sample_one(model, input_ids, attention_mask, past_key_values)\n",
" input_ids = torch.argmax(ret.logits[:, -1, :], axis=-1)[:, None]\n",
" output = torch.cat([output, input_ids], axis=1)\n",
" past_key_values = ret.past_key_values\n",
" attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)\n",
" \n",
" torch.cuda.synchronize()\n",
" ed = time.perf_counter_ns()\n",
" ts.append((ed-st)/1e9)\n",
" return np.array(ts)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fc92f940",
"metadata": {},
"outputs": [],
"source": [
"opt_config = optcfg[\"6.7b\"]\n",
"\n",
"torch.set_default_dtype(torch.bfloat16)\n",
"with transformers.modeling_utils.no_init_weights():\n",
" model = transformers.models.opt.OPTForCausalLM(opt_config).to(\"cuda\")\n",
"torch.set_default_dtype(torch.float32)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c19fa396",
"metadata": {},
"outputs": [],
"source": [
"db = {}\n",
"input_tokens = 200\n",
"new_tokens = 500\n",
"for bs in tqdm(list(itertools.chain(range(1, 8), range(8, 16, 2), [16]))):\n",
" x = torch.randint(1000, 10000, (bs, input_tokens), device=model.device)\n",
" stack = []\n",
" for _ in range(10):\n",
" l = time_greedy_generate(model, x, new_tokens=new_tokens)\n",
" stack.append(l)\n",
" db[bs] = np.median(np.stack(stack), axis=0)\n",
" del x\n",
" torch.cuda.empty_cache()\n",
"del model\n",
"torch.cuda.empty_cache()\n",
"\n",
"with gzip.open(\"data/20230516-e2e-text-generation-batch.pkl.gz\", \"wb\") as f:\n",
" pickle.dump(db, f)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -1,165 +0,0 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# Set plot parameters
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["hatch.linewidth"] = 1.5
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True
# Path settings
FIGURE_PATH = "./paper_plot/figures"
# Load accuracy data
acc_data = pd.read_csv("./paper_plot/data/acc.csv")
# Create figure with 4 subplots (one for each dataset)
fig, axs = plt.subplots(1, 4)
fig.set_size_inches(9, 2.5)
# Reduce the spacing between subplots
# plt.subplots_adjust(wspace=0.2) # Reduced from 0.3 to 0.1
# Define datasets and their columns
datasets = ["NQ", "TriviaQA", "GPQA", "HotpotQA"]
metrics = ["Exact Match", "F1"]
# Define bar settings - make bars thicker
# total_width, n = 0.9, 3 # increased total width and n for three models
# width = total_width / n
# 'width' below is the distance between the centers of adjacent bars within a
# group, and also the base for the plotted bar width. Two bars were originally
# spaced 1.0 apart; with three bars, 0.64 spacing and a 0.8 width factor give a
# bar width of ~0.51 and a group span of ~1.79, close to the original ~1.76.
n = 3 # Number of models
width = 0.64 # Distance between centers of adjacent bars in a group
bar_width_plotting_factor = 0.8 # Bar takes 80% of the space defined by 'width'
# Colors and hatches
edgecolors = ["dimgrey", "#63B8B6", "tomato"] # Added color for PQ 5
hatches = ["/////", "xxxxx", "\\\\\\\\\\"] # Added hatch for PQ 5
labels = ["BM25", "PQ Compressed", "Ours"] # Added PQ 5
# Create plots for each dataset
for i, dataset in enumerate(datasets):
ax = axs[i]
# Get data for this dataset and convert to percentages
em_values = [
acc_data.loc[0, f"{dataset} Exact Match"] * 100,
acc_data.loc[1, f"{dataset} Exact Match"] * 100,
acc_data.loc[2, f"{dataset} Exact Match"] * 100 # Added PQ 5 EM data
]
f1_values = [
acc_data.loc[0, f"{dataset} F1"] * 100,
acc_data.loc[1, f"{dataset} F1"] * 100,
acc_data.loc[2, f"{dataset} F1"] * 100 # Added PQ 5 F1 data
]
# Define x positions for bars
# For EM: center - width, center, center + width
# For F1: center - width, center, center + width
group_centers = [1.0, 3.0] # Centers for EM and F1 groups
bar_offsets = [-width, 0, width]
# Plot all bars on the same axis
for metric_idx, metric_group_center in enumerate(group_centers):
values_to_plot = em_values if metric_idx == 0 else f1_values
for j, model_label in enumerate(labels):
x_pos = metric_group_center + bar_offsets[j]
bar_value = values_to_plot[j]
ax.bar(
x_pos,
bar_value,
width=width * bar_width_plotting_factor, # Use the new factor for bar width
color="white",
edgecolor=edgecolors[j],
hatch=hatches[j],
linewidth=1.5,
label=model_label if i == 0 and metric_idx == 0 else None # Label only once
)
# Add value on top of bar
ax.text(x_pos, bar_value + 0.1,
        f"{bar_value:.1f}", ha='center', va='bottom',
        fontsize=9, fontweight='bold')  # value label on top of each bar
# Set x-ticks and labels
ax.set_xticks(group_centers) # Position ticks at the center of each group
xticklabels = ax.set_xticklabels(metrics, fontsize=12)
# Now, shift these labels slightly to the right
# Adjust this value to control the amount of shift (in data coordinates)
# Given your group_centers are 1.0 and 3.0, a small value like 0.05 to 0.15 might be appropriate.
# horizontal_shift = 0.7 # Try adjusting this value
# for label in xticklabels:
# # Get the current x position (which is the tick location)
# current_x_pos = label.get_position()[0]
# # Set the new x position by adding the shift
# label.set_position((current_x_pos + horizontal_shift, label.get_position()[1]))
# # Ensure the label remains horizontally centered on this new x position
# # (set_xticklabels defaults to 'center', so this re-affirms it if needed)
# label.set_horizontalalignment('center')
# Set title
ax.set_title(dataset, fontsize=14)
# Set y-label for all subplots
if i == 0:
ax.set_ylabel("Accuracy (\%)", fontsize=12, fontweight="bold")
else:
# Hide y-tick labels for non-first subplots to save space
ax.tick_params(axis='y', labelsize=10)
# Set y-limits based on data range
all_values = em_values + f1_values
max_val = max(all_values)
min_val = min(all_values)
# Special handling for GPQA which has very low values
if dataset == "GPQA":
ax.set_ylim(0, 10.0) # Set a fixed range for GPQA
else:
# Reduce the extra space above the bars
ax.set_ylim(min_val * 0.9, max_val * 1.1) # Adjusted upper limit for text
# Format y-ticks as percentages
ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: ' {:.0f}'.format(y)))
# Set x-limits to properly space the bars with less blank space
# ax.set_xlim(group_centers[0] - total_width, group_centers[1] + total_width)
# Set xlim to be similar to original (0,4) for group_centers (1,3) => margin of 1.0
ax.set_xlim(group_centers[0] - 1.0, group_centers[1] + 1.0)
# Add a box around the subplot
# for spine in ax.spines.values():
# spine.set_visible(True)
# spine.set_linewidth(1.0)
# Add legend to first subplot
if i == 0:
ax.legend(
bbox_to_anchor=(2.21, 1.35), # Adjusted anchor if needed
ncol=3, # Changed to 3 columns for three labels
loc="upper center",
labelspacing=0.1,
edgecolor="black",
facecolor="white",
framealpha=1,
shadow=False,
fancybox=False,
handlelength=1.0,
handletextpad=0.6,
columnspacing=0.8,
prop={"weight": "bold", "size": 12},
)
# Save figure with tight layout but no additional padding
plt.savefig(FIGURE_PATH + "/accuracy_comparison.pdf", bbox_inches='tight', pad_inches=0.05)
plt.show()

View File

@@ -1,309 +0,0 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Motto: Were It to Benefit My Country, I Would Lay Down My Life!
# \file: /hnsw_degree_visit_plot_binned_academic.py
# \brief: Generates a binned bar plot of HNSW node average per-query visit probability
# per degree bin, styled for academic publications, with caching.
# Author: raphael hao (Original script by user, styling and caching adapted by Gemini)
# %%
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import re
from collections import Counter
import os # For robust filepath manipulation
import math # For calculating scaling factor
import pickle # For caching data
# %%
# --- Matplotlib parameters for academic paper style (from reference) ---
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["hatch.linewidth"] = 1.5
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True # Use LaTeX for text rendering (if available)
# --- Define styles from reference ---
edgecolors_ref = ["dimgrey", "#63B8B6", "tomato", "silver", "slategray"]
# %%
# --- File Paths ---
degree_file = '/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/hnsw/degree_distribution.txt'
visit_log_file = './re.log'
output_image_file = './paper_plot/figures/hnsw_visit_count_per_degree_corrected.pdf'
# --- CACHE FILE PATH: Keep this consistent ---
CACHE_FILE_PATH = './binned_plot_data_cache.pkl'
# --- Configuration ---
# Set to True to bypass cache and force recomputation.
# Otherwise, delete CACHE_FILE_PATH manually to force recomputation.
FORCE_RECOMPUTE = False
NUMBER_OF_QUERIES = 1000.0 # Number of queries the visit_counts are based on
# Create directory for figures if it doesn't exist
output_dir = os.path.dirname(output_image_file)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir)
print(f"Created directory: {output_dir}")
# %%
# --- Attempt to load data from cache or compute ---
df_plot_data = None
bin_size_for_plot = None # Will hold the bin_size associated with df_plot_data
if not FORCE_RECOMPUTE and os.path.exists(CACHE_FILE_PATH):
try:
with open(CACHE_FILE_PATH, 'rb') as f:
cache_content = pickle.load(f)
df_plot_data = cache_content['data']
bin_size_for_plot = cache_content['bin_size']
# Basic validation of cached data
# Expecting 'average_visit_count_per_node_in_bin' (raw average over NUMBER_OF_QUERIES)
if not isinstance(df_plot_data, pd.DataFrame) or \
'degree_bin_label' not in df_plot_data.columns or \
'average_visit_count_per_node_in_bin' not in df_plot_data.columns or \
not isinstance(bin_size_for_plot, int):
print("Cached data is not in the expected format or missing 'average_visit_count_per_node_in_bin'. Recomputing.")
df_plot_data = None # Invalidate to trigger recomputation
else:
print(f"Successfully loaded binned data from cache: {CACHE_FILE_PATH}")
# --- Modify the label loaded from cache for display purpose ---
# This modification only happens when data is loaded from cache and meets specific conditions.
# Assumption: if the cached bin_size_for_plot is 5, the original label "0-4"
# actually represents nodes of degree 1-4 (the graph contains no 0-degree nodes).
if df_plot_data is not None and 'degree_bin_label' in df_plot_data.columns and bin_size_for_plot == 5:
# Check if "0-4" label exists
if '0-4' in df_plot_data['degree_bin_label'].values:
# Use .loc to ensure the modification is on the original DataFrame
df_plot_data.loc[df_plot_data['degree_bin_label'] == '0-4', 'degree_bin_label'] = '1-4'
print("Modified degree_bin_label from '0-4' to '1-4' for display purpose.")
except Exception as e:
print(f"Error loading from cache: {e}. Recomputing.")
df_plot_data = None # Invalidate to trigger recomputation
if df_plot_data is None:
print("Cache not found, invalid, or recompute forced. Computing data from scratch...")
# --- 1. Read Degree Distribution File ---
degrees_data = []
try:
with open(degree_file, 'r') as f:
for i, line in enumerate(f):
line_stripped = line.strip()
if line_stripped:
degrees_data.append({'node_id': i, 'degree': int(line_stripped)})
except FileNotFoundError:
print(f"Error: Degree file '{degree_file}' not found. Using dummy data for degrees.")
degrees_data = [{'node_id': i, 'degree': (i % 20) + 1 } for i in range(200)]
degrees_data.extend([{'node_id': 200+i, 'degree': i} for i in range(58, 67)]) # For 60-64 bin
degrees_data.extend([{'node_id': 300+i, 'degree': (i % 5)+1} for i in range(10)]) # Low degrees
degrees_data.extend([{'node_id': 400+i, 'degree': 80 + (i%5)} for i in range(10)]) # High degrees
if not degrees_data:
print(f"Critical Error: No data loaded or generated for degrees. Exiting.")
exit()
df_degrees = pd.DataFrame(degrees_data)
print(f"Successfully loaded/generated {len(df_degrees)} degree entries.")
# --- 2. Read Visit Log File and Count Frequencies ---
visit_counts = Counter()
node_id_pattern = re.compile(r"Vis(i)?ted node: (\d+)")
try:
with open(visit_log_file, 'r') as f_log:
for line_num, line in enumerate(f_log, 1):
match = node_id_pattern.search(line)
if match:
try:
node_id = int(match.group(2))
visit_counts[node_id] += 1 # Increment visit count for the node
except ValueError:
print(f"Warning: Non-integer node_id in log '{visit_log_file}' line {line_num}: {line.strip()}")
except FileNotFoundError:
print(f"Warning: Visit log file '{visit_log_file}' not found. Using dummy visit counts.")
if not df_degrees.empty:
for node_id_val in df_degrees['node_id'].sample(frac=0.9, random_state=1234): # Seed for reproducibility
degree_val = df_degrees[df_degrees['node_id'] == node_id_val]['degree'].iloc[0]
# Generate visit counts to test different probability magnitudes
if node_id_val % 23 == 0: # Very low probability
lambda_val = 0.0005 * (100 / (max(1,degree_val) + 1)) # avg visits over 1k queries
elif node_id_val % 11 == 0: # Low probability
lambda_val = 0.05 * (100 / (max(1,degree_val) + 1))
elif node_id_val % 5 == 0: # Moderate probability
lambda_val = 2.5 * (100 / (max(1,degree_val) + 1))
else: # Higher probability (but still < 1000 visits for a single node usually)
lambda_val = 50 * (100 / (max(1,degree_val) + 1))
visit_counts[node_id_val] = np.random.poisson(lambda_val)
if visit_counts[node_id_val] < 0: visit_counts[node_id_val] = 0
if not visit_counts:
print(f"Warning: No visit data parsed/generated. Plot may show zero visits.")
df_visits = pd.DataFrame(columns=['node_id', 'visit_count'])
else:
df_visits_list = [{'node_id': nid, 'visit_count': count} for nid, count in visit_counts.items()]
df_visits = pd.DataFrame(df_visits_list)
print(f"Parsed/generated {len(df_visits)} unique visited nodes, totaling {sum(visit_counts.values())} visits (simulated over {NUMBER_OF_QUERIES} queries).")
# --- 3. Merge Degree Data with Visit Data ---
df_merged = pd.merge(df_degrees, df_visits, on='node_id', how='left')
df_merged['visit_count'] = df_merged['visit_count'].fillna(0).astype(float) # visit_count is total over NUMBER_OF_QUERIES
print(f"Merged data contains {len(df_merged)} entries.")
# --- 4. Binning Degrees and Calculating Average Visit Count per Node in Bin (over NUMBER_OF_QUERIES) ---
current_bin_size = 5
bin_size_for_plot = current_bin_size
if not df_degrees.empty:
print(f"\nBinning degrees into groups of {current_bin_size} for average visit count calculation...")
df_merged_with_bins = df_merged.copy()
df_merged_with_bins['degree_bin_start'] = (df_merged_with_bins['degree'] // current_bin_size) * current_bin_size
df_binned_analysis = df_merged_with_bins.groupby('degree_bin_start').agg(
total_visit_count_in_bin=('visit_count', 'sum'),
node_count_in_bin=('node_id', 'nunique')
).reset_index()
# This is the average number of times a node in this bin was visited over NUMBER_OF_QUERIES queries.
# This value is what gets cached.
df_binned_analysis['average_visit_count_per_node_in_bin'] = 0.0
df_binned_analysis.loc[df_binned_analysis['node_count_in_bin'] > 0, 'average_visit_count_per_node_in_bin'] = \
df_binned_analysis['total_visit_count_in_bin'] / df_binned_analysis['node_count_in_bin']
df_binned_analysis['degree_bin_label'] = df_binned_analysis['degree_bin_start'].astype(str) + '-' + \
(df_binned_analysis['degree_bin_start'] + current_bin_size - 1).astype(str)
bin_to_drop_label = '60-64'
original_length = len(df_binned_analysis)
df_plot_data_intermediate = df_binned_analysis[df_binned_analysis['degree_bin_label'] != bin_to_drop_label].copy()
if len(df_plot_data_intermediate) < original_length:
print(f"\nManually dropped the bin: '{bin_to_drop_label}'")
else:
print(f"\nNote: Bin '{bin_to_drop_label}' not found for dropping or already removed.")
df_plot_data = df_plot_data_intermediate
print(f"\nBinned data (average visit count per node in bin over {NUMBER_OF_QUERIES} queries) for plotting prepared:")
print(df_plot_data[['degree_bin_label', 'average_visit_count_per_node_in_bin']].head())
if df_plot_data is not None and not df_plot_data.empty:
try:
with open(CACHE_FILE_PATH, 'wb') as f:
pickle.dump({'data': df_plot_data, 'bin_size': bin_size_for_plot}, f)
print(f"Saved computed binned data to cache: {CACHE_FILE_PATH}")
except Exception as e:
print(f"Error saving data to cache: {e}")
elif df_plot_data is None or df_plot_data.empty:
print("Computed data for binned plot is empty, not saving to cache.")
else:
print("Degree data (df_degrees) is empty. Cannot perform binning.")
df_plot_data = pd.DataFrame()
bin_size_for_plot = current_bin_size
# %%
# --- 5. Plotting (Binned Bar Chart - Academic Style) ---
if df_plot_data is not None and not df_plot_data.empty and 'average_visit_count_per_node_in_bin' in df_plot_data.columns:
base_name, ext = os.path.splitext(output_image_file)
# --- OUTPUT PDF FILE NAME: Keep this consistent ---
binned_output_image_file = base_name + ext
fig, ax = plt.subplots(figsize=(6, 2.5)) # Adjusted figure size
df_plot_data_plotting = df_plot_data.copy()
# Calculate per-query probability: (avg visits over N queries) / N
df_plot_data_plotting['per_query_visit_probability'] = \
df_plot_data_plotting['average_visit_count_per_node_in_bin'] / NUMBER_OF_QUERIES
max_probability = df_plot_data_plotting['per_query_visit_probability'].max()
y_axis_values_to_plot = df_plot_data_plotting['per_query_visit_probability']
y_axis_label = r"Per-Query Node Visit Probability in Bin" # Base label
apply_scaling_to_label_and_values = False # Initialize flag
exponent_for_label_display = 0 # Initialize exponent
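# Scaling policy: rescale the axis by a power of ten only when the max probability is very
# small (exponent <= -4) or at least 1 (exponent >= 0); exponents in [-3, -1] plot directly.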
if pd.notna(max_probability) and max_probability > 0:
potential_exponent = math.floor(math.log10(max_probability))
if potential_exponent <= -4 or potential_exponent >= 0:
apply_scaling_to_label_and_values = True
exponent_for_label_display = potential_exponent
# No specific adjustment for potential_exponent >= 0 here; it's handled by the general logic.
if apply_scaling_to_label_and_values:
y_axis_label = rf"Visit Probability ($\times 10^{{{exponent_for_label_display}}}$)"
y_axis_values_to_plot = df_plot_data_plotting['per_query_visit_probability'] / (10**exponent_for_label_display)
print(f"Plotting with Max per-query probability: {max_probability:.2e}, Exponent for label: {exponent_for_label_display}. Y-axis values scaled for plot.")
else:
print(f"Plotting with Max per-query probability: {max_probability:.2e}. Plotting direct probabilities without label scaling (exponent {potential_exponent} is within no-scale range [-3, -1]).")
elif pd.notna(max_probability) and max_probability == 0:
print("Max per-query probability is 0. Plotting direct probabilities (all zeros).")
else:
print(f"Max per-query probability is NaN or invalid ({max_probability}). Plotting direct probabilities without scaling if possible.")
ax.bar(
df_plot_data_plotting['degree_bin_label'],
y_axis_values_to_plot,
color='white',
edgecolor=edgecolors_ref[0],
linewidth=1.5,
width=0.8
)
ax.set_xlabel('Node Degree', fontsize=10.5, labelpad=6)
# MODIFIED LINE: Added labelpad to move the y-axis label to the left
ax.set_ylabel(y_axis_label, fontsize=10.5, labelpad=10)
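# Format y-axis tick values as whole-number percentages.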
ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, pos: f"{x:.0f}%"))
num_bins = len(df_plot_data_plotting)
if num_bins > 12:
ax.set_xticks(ax.get_xticks())
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right", fontsize=9)
elif num_bins > 8:
ax.tick_params(axis='x', labelsize=9)
else:
ax.tick_params(axis='x', labelsize=10)
ax.tick_params(axis='y', labelsize=10)
padding_factor = 0.05
current_max_y_on_axis = y_axis_values_to_plot.max()
upper_y_limit = 0.1 # Default small upper limit
if pd.notna(current_max_y_on_axis):
if current_max_y_on_axis > 0:
# Adjust minimum visible range based on whether scaling was applied and the exponent
min_meaningful_limit = 0.01
if apply_scaling_to_label_and_values and exponent_for_label_display >= 0 : # Numbers on axis are smaller due to positive exponent scaling
min_meaningful_limit = 0.1 # If original numbers were e.g. 2500 (2.5 x 10^3), scaled axis is 2.5, 0.1 is fine
elif not apply_scaling_to_label_and_values and pd.notna(max_probability) and max_probability >=1: # Direct large probabilities
min_meaningful_limit = 1 # If max prob is 2.5 (250%), axis value 2.5, needs larger base limit
upper_y_limit = max(min_meaningful_limit, current_max_y_on_axis * (1 + padding_factor))
else: # current_max_y_on_axis is 0
upper_y_limit = 0.1
ax.set_ylim(0, upper_y_limit)
else:
ax.set_ylim(0, 1.0) # Default for empty or NaN data
plt.tight_layout()
plt.savefig(binned_output_image_file, bbox_inches="tight", dpi=300)
print(f"Binned bar chart saved to {binned_output_image_file}")
plt.show()
plt.close(fig)
else:
if df_plot_data is None:
print("Data for plotting (df_plot_data) is None. Skipping plot generation.")
elif df_plot_data.empty:
print("Data for plotting (df_plot_data) is empty. Skipping plot generation.")
elif 'average_visit_count_per_node_in_bin' not in df_plot_data.columns:
print("Essential column 'average_visit_count_per_node_in_bin' is missing in df_plot_data. Skipping plot generation.")
# %%
print("Script finished.")

View File

@@ -1,7 +0,0 @@
In this paper, we present LiteANN, a storage-efficient approximate nearest neighbor (ANN) search index optimized for resource-constrained personal devices. LiteANN combines a compact graph-based structure with an efficient on-the-fly recomputation strategy to enable fast and accurate retrieval with minimal storage overhead. Our evaluation shows that LiteANN reduces the index size to under 5% of the original raw data, up to 50× smaller than standard indexes, while achieving 90% top-3 recall in under 2 seconds on real-world question-answering benchmarks.
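
To make the recomputation idea concrete, here is a minimal, self-contained sketch of recompute-at-query-time graph search. It is an illustration under stated assumptions, not LiteANN's actual implementation: embed() and the toy graph below are hypothetical stand-ins. Only passage text and edges are stored; every similarity is computed from a freshly recomputed embedding, which is what trades compute for storage.

import numpy as np

def embed(text):
    # Hypothetical stand-in for a real text encoder (deterministic within a run).
    rng = np.random.default_rng(abs(hash(text)) % (2**32))
    v = rng.standard_normal(8)
    return v / np.linalg.norm(v)

# Toy graph: node id -> (passage text, neighbor ids). Only text and edges are
# stored; embeddings are recomputed on the fly, which keeps storage small.
graph = {
    0: ("the cat sat", [1, 2]),
    1: ("dogs bark loudly", [0, 3]),
    2: ("cats purr softly", [0, 3]),
    3: ("birds sing at dawn", [1, 2]),
}

def greedy_search(query, entry=0):
    q = embed(query)
    best = entry
    best_sim = float(q @ embed(graph[best][0]))
    improved = True
    while improved:
        improved = False
        for nb in graph[best][1]:
            sim = float(q @ embed(graph[nb][0]))  # recomputed, never read from disk
            if sim > best_sim:
                best, best_sim, improved = nb, sim, True
    return best

print(greedy_search("cat sounds"))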

View File

@@ -1,81 +0,0 @@
import numpy as np
import os
# --- Configuration for Data Paths and Labels (Mirrors plotting script for consistency) ---
BIG_GRAPH_PATHS = [
"/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/hnsw/",
"/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/99_4_degree_based_hnsw_IP_M32_efC256/",
"/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/d9_hnsw_IP_M8_efC128/",
"/opt/dlami/nvme/scaling_out/embeddings/facebook/contriever-msmarco/rpj_wiki/1-shards/indices/half_edges_IP_M32_efC128/"
]
STATS_FILE_NAME = "degree_distribution.txt"
BIG_GRAPH_LABELS = [ # These will be used as keys in the cached file
"HNSW-Base",
"DegreeGuide",
"HNSW-D9",
"RandCut",
]
# Average degrees are static and can be directly used in the plotting script or also cached.
# For simplicity here, we'll focus on caching the dynamic degree arrays.
# BIG_GRAPH_AVG_DEG = [18, 9, 9, 9]
# --- Cache File Configuration ---
DATA_CACHE_DIR = "./paper_plot/data/"
CACHE_FILE_NAME = "big_graph_degree_data.npz" # Using .npz for multiple arrays
def create_degree_data_cache():
"""
Reads degree distribution data from specified text files and saves it
into a compressed NumPy (.npz) cache file.
"""
os.makedirs(DATA_CACHE_DIR, exist_ok=True)
cache_file_path = os.path.join(DATA_CACHE_DIR, CACHE_FILE_NAME)
cached_data = {}
print(f"Starting data caching process for {len(BIG_GRAPH_PATHS)} graph types...")
for i, base_path in enumerate(BIG_GRAPH_PATHS):
method_label = BIG_GRAPH_LABELS[i]
degree_file_path = os.path.join(base_path, STATS_FILE_NAME)
print(f"Processing: {method_label} from {degree_file_path}")
try:
# Load degrees as integers
degrees = np.loadtxt(degree_file_path, dtype=int)
if degrees.size == 0:
print(f" [WARN] Degree file is empty: {degree_file_path}. Storing as empty array for {method_label}.")
# Store an empty array or handle as needed. For npz, an empty array is fine.
cached_data[method_label] = np.array([], dtype=int)
else:
# Store the loaded degrees array with the method label as the key
cached_data[method_label] = degrees
print(f" [INFO] Loaded {len(degrees)} degrees for {method_label}. Max degree: {np.max(degrees) if degrees.size > 0 else 'N/A'}")
except FileNotFoundError:
print(f" [ERROR] Degree file not found: {degree_file_path}. Skipping {method_label}.")
# Optionally store a placeholder or skip. For robustness, store None or an empty array.
# Storing None might require special handling when loading. Empty array is safer for np.load.
cached_data[method_label] = np.array([], dtype=int) # Store empty array if file not found
except Exception as e:
print(f" [ERROR] An error occurred loading {degree_file_path} for {method_label}: {e}")
cached_data[method_label] = np.array([], dtype=int) # Store empty array on other errors
if not cached_data:
print("[ERROR] No data was successfully processed or loaded. Cache file will not be created.")
return
try:
# Save all collected degree arrays into a single .npz file.
# Using savez_compressed for potentially smaller file size.
np.savez_compressed(cache_file_path, **cached_data)
print(f"\n[SUCCESS] Degree distribution data successfully cached to: {os.path.abspath(cache_file_path)}")
print("Cached arrays (keys):", list(cached_data.keys()))
except Exception as e:
print(f"\n[ERROR] Failed to save data to cache file {cache_file_path}: {e}")
if __name__ == "__main__":
print("--- Degree Distribution Data Caching Script ---")
create_degree_data_cache()
print("--- Caching script finished. ---")

View File

@@ -1,4 +0,0 @@
Model,NQ Exact Match,NQ F1,TriviaQA Exact Match,TriviaQA F1,GPQA Exact Match,GPQA F1,HotpotQA Exact Match,HotpotQA F1
BM25,0.192,0.277,0.406,0.474,0.020089,0.04524,0.162,0.239
PQ 5,0.2075,0.291,0.422,0.495,0.0201,0.0445,0.148,0.219
Ours,0.265,0.361,0.533,0.604,0.02008,0.0452,0.182,0.2729

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1296720e79196bbdf38f051043c1b054667803726a24036c0b6a87cedb204ea5
size 227482438

View File

@@ -1,21 +0,0 @@
2,1,512,1024,0.541,0.326,1.659509202
2,2,512,1024,0.979,0.621,1.576489533
2,4,512,1024,1.846,0.977,1.889457523
2,8,512,1024,3.575,1.943,1.83993824
2,16,512,1024,7.035,3.733,1.884543263
2,32,512,1024,15.655,8.517,1.838088529
2,64,512,1024,32.772,17.43,1.88020654
4,1,512,1024,2.675,1.38,1.938405797
4,2,512,1024,5.397,2.339,2.307396323
4,4,512,1024,10.672,4.944,2.158576052
4,8,512,1024,21.061,9.266,2.272933305
4,16,512,1024,46.332,18.334,2.527108105
4,32,512,1024,99.607,36.156,2.754923111
4,64,512,1024,186.348,72.356,2.575432583
8,1,512,1024,7.325,4.087,1.792268167
8,2,512,1024,14.109,7.491,1.883460152
8,4,512,1024,28.499,14.013,2.033754371
8,8,512,1024,65.222,27.453,2.375769497
8,16,512,1024,146.294,52.55,2.783901047
8,32,512,1024,277.099,103.61,2.674442621
8,64,512,1024,512.979,208.36,2.461984066

View File

@@ -1,9 +0,0 @@
Dataset,Metric,Original,original + batch,original + two_level,original + two_level + batch
NQ,Latency,6.9,5.8,4.2,3.7
NQ,SpeedUp,1,1.18965517,1.64285714,1.86486486
TriviaQA,Latency,17.054,14.542,12.046,10.83
TriviaQA,SpeedUp,1,1.17274103,1.41573967,1.57469990
GPQA,Latency,9.164,7.639,6.798,5.77
GPQA,SpeedUp,1,1.19963346,1.34804354,1.58821490
HotpotQA,Latency,60.279,39.827,50.664,29.868
HotpotQA,SpeedUp,1,1.51352098,1.18977972,2.01817999

View File

@@ -1,25 +0,0 @@
Dataset,Hardware,Recall_target,HNSW,IVF,DiskANN,IVF-Disk,IVF-Recompute,Our,BM25,LLM_Gen_Time_1B,LLM_Gen_Time_3B,LLM_Gen_Time_7B
NQ,A10,85%,0.046,1.656,0.017,2.996,482.53,3.323,0.021,0.085,0.217,0.472
NQ,A10,90%,0.051,2.552,0.028,3.437,769.04,4.616,0,0.085,0.217,0.472
NQ,A10,95%,0.055,5.163,0.070,5.602,1436.26,19.494,0,0.085,0.217,0.472
NQ,MAC,85%,0,0,0.152,2.199,1535.10,7.971,0.033,0.316,0.717,1.468
NQ,MAC,90%,0,0,0.37,2.936,2446.60,13.843,0,0.316,0.717,1.468
NQ,MAC,95%,0,0,1.207,4.191,4569.29,44.363,0,0.316,0.717,1.468
TriviaQA,A10,85%,0.042,1.772,0.032,2.464,560.5,3.752,0.033,0.139,0.156,0.315
TriviaQA,A10,90%,0.043,3.541,0.057,3.651,997.81,5.777,0,0.139,0.156,0.315
TriviaQA,A10,95%,0.053,7.168,0.090,5.458,2005.33,20.944,0,0.139,0.156,0.315
TriviaQA,MAC,85%,0,0,0.481,1.875,1783.14787,8.889,0.036,0.325,0.692,1.415
TriviaQA,MAC,90%,0,0,0.984,2.639,3174.410301,17.145,0,0.325,0.692,1.415
TriviaQA,MAC,95%,0,0,1.578,3.884,6379.712245,47.909,0,0.325,0.692,1.415
GPQA,A10,85%,0.041,0.134,0.024,0.048,40.16,1.897,0.137,0.443,0.396,0.651
GPQA,A10,90%,0.042,0.174,0.034,0.06,54.71,1.733,0,0.443,0.396,0.651
GPQA,A10,95%,0.045,0.292,0.051,0.11,97.67,4.033,0,0.443,0.396,0.651
GPQA,MAC,85%,0,0,0.144,0.087,127.7707505,4.762,0.100,0.37,0.813,1.676
GPQA,MAC,90%,0,0,0.288,0.108,174.0647409,5.223,0,0.37,0.813,1.676
GPQA,MAC,95%,0,0,0.497,0.132,310.7380142,9.715,0,0.37,0.813,1.676
HotpotQA,A10,85%,0.044,2.519,0.054,4.048,724.26,10.358,0.70,0.144,0.196,0.420
HotpotQA,A10,90%,0.049,3.867,0.109,5.045,1173.67,15.515,0,0.144,0.196,0.420
HotpotQA,A10,95%,0.07,10.928,0.412,8.659,3079.57,61.757,0,0.144,0.196,0.420
HotpotQA,MAC,85%,0,0,0.974,2.844,2304.125187,23.636,0.052,0.144,0.196,0.420
HotpotQA,MAC,90%,0,0,1.913,3.542,3415.736201,44.803,0,0.144,0.196,0.420
HotpotQA,MAC,95%,0,0,5.783,6.764,9797.244043,140.62,0,0.144,0.196,0.420

View File

@@ -1,25 +0,0 @@
Dataset,Hardware,Recall_target,HNSW,IVF,DiskANN,IVF-Disk,IVF-Recompute,Our,
NQ,A10,85%,0.046,1.656,0.017,2.996,482.53,4.243,
NQ,A10,90%,0.051,2.552,0.028,3.437,769.04,8.136,
NQ,A10,95%,0.055,5.163,0.070,5.602,1436.26,27.275,
NQ,MAC,85%,0,0,0.152,2.199,1535.10,10.672,
NQ,MAC,90%,0,0,0.37,2.936,2446.60,19.941,
NQ,MAC,95%,0,0,1.207,4.191,4569.29,61.383,
TriviaQA,A10,85%,0.042,1.772,0.032,2.464,560.5,5.612,
TriviaQA,A10,90%,0.043,3.541,0.057,3.651,997.81,10.737,
TriviaQA,A10,95%,0.053,7.168,0.090,5.458,2005.33,36.387,
TriviaQA,MAC,85%,0,0,0.481,1.875,1783.14787,12.825,
TriviaQA,MAC,90%,0,0,0.984,2.639,3174.410301,24.977,
TriviaQA,MAC,95%,0,0,1.578,3.884,6379.712245,85.734,
GPQA,A10,85%,0.041,0.134,0.024,0.048,40.16,2.269,
GPQA,A10,90%,0.042,0.174,0.034,0.06,54.71,3.200,
GPQA,A10,95%,0.045,0.292,0.051,0.11,97.67,7.445,
GPQA,MAC,85%,0,0,0.144,0.087,127.7707505,6.123,
GPQA,MAC,90%,0,0,0.288,0.108,174.0647409,8.507,
GPQA,MAC,95%,0,0,0.497,0.132,310.7380142,19.577,
HotpotQA,A10,85%,0.044,2.519,0.054,4.048,724.26,14.713,
HotpotQA,A10,90%,0.049,3.867,0.109,5.045,1173.67,33.561,
HotpotQA,A10,95%,0.07,10.928,0.412,8.659,3079.57,68.626,
HotpotQA,MAC,85%,0,0,0.974,2.844,2304.125187,34.783,
HotpotQA,MAC,90%,0,0,1.913,3.542,3415.736201,53.004,
HotpotQA,MAC,95%,0,0,5.783,6.764,9797.244043,95.413,

View File

@@ -1,3 +0,0 @@
Hardware,HNSW,IVF,DiskANN,IVF-Disk,IVF-Recompute,Our,BM25
RAM,190,171,10,0,0,0,0
Storage,185.4,171,240,171,0.5,5,59

View File

@@ -1,12 +0,0 @@
Torch,8,55.592
Torch,16,75.439
Torch,32,110.025
Torch,64,186.496
Tutel,8,56.718
Tutel,16,82.121
Tutel,32,125.070
Tutel,64,216.191
BRT,8,56.725
BRT,16,79.291
BRT,32,93.180
BRT,64,118.923

View File

@@ -1,6 +0,0 @@
Disk cache size,0,2.5%(180G*2.5%),5%,8%,10%
Latency,,,,,
NQ,4.616,4.133,3.826,3.511,3.323
TriviaQA,5.777,4.979,4.553,4.141,3.916
GPQA,1.733,1.593,1.468,1.336,1.259
Hotpot,15.515,13.479,12.383,11.216,10.606

View File

@@ -1,151 +0,0 @@
import matplotlib
from matplotlib.axes import Axes
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.lines import Line2D
# plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["hatch.linewidth"] = 1
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True
plt.rcParams["font.family"] = "sans-serif" # Use generic sans-serif family
plt.rcParams['text.latex.preamble'] = r"""
\usepackage{helvet} % Use Helvetica font for text
\usepackage{sfmath} % Use sans-serif font for math
\renewcommand{\familydefault}{\sfdefault} % Set sans-serif as default text font
\usepackage[T1]{fontenc} % Recommended for font encoding
"""
# plt.rcParams['mathtext.fontset'] = 'dejavusans'
SAVE_PTH = "./paper_plot/figures"
font_size = 16
# New data in dictionary format
datasets = ["NQ", "TriviaQA", "GPQA", "Hotpot"]
cache_ratios = ["4.2G\n (0\%)", "8.7G\n (2.5\%)", "13.2G\n (5\%)", "18.6G\n (8\%)", "22.2G\n (10\%)"]
latency_data = {
"NQ": [4.616, 4.133, 3.826, 3.511, 3.323],
"TriviaQA": [5.777, 4.979, 4.553, 4.141, 3.916],
"GPQA": [1.733, 1.593, 1.468, 1.336, 1.259],
"Hotpot": [15.515, 13.479, 12.383, 11.216, 10.606],
}
cache_hit_counts = {
"NQ": [0, 14.81, 23.36, 31.99, 36.73],
"TriviaQA": [0, 18.55, 27.99, 37.06, 41.86],
"GPQA": [0, 10.99, 20.31, 29.71, 35.01],
"Hotpot": [0, 17.47, 26.91, 36.2, 41.06]
}
# Create the figure with 4 subplots in a 2x2 grid
fig, axes_grid = plt.subplots(2, 2, figsize=(7,6))
axes = axes_grid.flatten() # Flatten the 2x2 grid to a 1D array
# Bar style settings
width = 0.7
x = np.arange(len(cache_ratios))
# Define hatch patterns for different cache ratios
hatch_patterns = ['//', '//', '//', '//', '//']
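# NOTE: hatch_patterns is defined but never applied to the bars below.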
# Find max cache hit value across all datasets for unified y-axis
all_hit_counts = []
for dataset in datasets:
all_hit_counts.extend(cache_hit_counts[dataset])
max_unified_hit = max(all_hit_counts) * 1.13
for i, dataset in enumerate(datasets):
latencies = latency_data[dataset]
hit_counts = cache_hit_counts[dataset]
for j, val in enumerate(latencies):
container = axes[i].bar(
x[j],
val,
width=width,
color="white",
edgecolor="black",
linewidth=1.0,
zorder=10,
)
axes[i].bar_label(
container,
[f"{val:.2f}"],
fontsize=10,
zorder=200,
fontweight="bold",
)
axes[i].set_title(dataset, fontsize=font_size)
axes[i].set_xticks(x)
axes[i].set_xticklabels(cache_ratios, fontsize=12, rotation=0, ha='center', fontweight="bold")
max_val_ratios = [1.35, 1.65, 1.45, 1.75]
max_val = max(latencies) * max_val_ratios[i]
axes[i].set_ylim(0, max_val)
axes[i].tick_params(axis='y', labelsize=12)
if i % 2 == 0:
axes[i].set_ylabel("Latency (s)", fontsize=font_size)
axes[i].yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter('%.1f'))
ax2: Axes = axes[i].twinx()
ax2.plot(x, hit_counts,
linestyle='--',
marker='o',
markersize=6,
linewidth=1.5,
color='k',
markerfacecolor='none',
zorder=20)
ax2.set_ylim(0, max_unified_hit)
ax2.tick_params(axis='y', labelsize=12)
if i % 2 == 1:
ax2.set_ylabel(r"Cache Hit (\%)", fontsize=font_size)
for j, val in enumerate(hit_counts):
if val > 0:
ax2.annotate(f"{val:.1f}%",
(x[j], val),
textcoords="offset points",
xytext=(0, 5),
ha='center',
va='bottom',
fontsize=10,
fontweight='bold')
# Create legend for both plots
bar_patch = mpatches.Patch(facecolor='white', edgecolor='black', label='Latency')
line_patch = Line2D([0], [0], color='black', linestyle='--', label='Cache Hit Rate')
# --- MODIFICATION FOR LEGEND AT THE TOP ---
fig.legend(handles=[bar_patch, line_patch],
loc='upper center', # Position the legend at the upper center
bbox_to_anchor=(0.5, 0.995), # Anchor point (0.5 means horizontal center of figure,
# 0.995 means 99.5% from the bottom, so near the top)
ncol=3,
fontsize=font_size-2)
# --- END OF MODIFICATION ---
# Set common x-axis label - you might want to add this back if needed
# fig.text(0.5, 0.02, "Disk Cache Size", ha='center', fontsize=font_size, fontweight='bold') # Adjusted y for potential bottom label
# --- MODIFICATION FOR TIGHT LAYOUT ---
# Adjust rect to make space for the legend at the top.
# (left, bottom, right, top_for_subplots)
# We want subplots to occupy space from y=0 up to y=0.93 (or similar)
# leaving the top portion (0.93 to 1.0) for the legend.
plt.tight_layout(rect=(0, 0, 1, 0.93)) # Ensure subplots are below the legend
# --- END OF MODIFICATION ---
# Create directory if it doesn't exist (optional, good practice)
import os
if not os.path.exists(SAVE_PTH):
os.makedirs(SAVE_PTH)
plt.savefig(f"{SAVE_PTH}/disk_cache_latency.pdf", dpi=300) # Changed filename slightly for testing
print(f"Save to {SAVE_PTH}/disk_cache_latency.pdf")
# plt.show() # Optional: to display the plot

View File

Binary file not shown.


View File

Binary file not shown.

View File

Binary file not shown.


View File

Binary file not shown.

View File

Binary file not shown.

View File

Binary file not shown.

View File

Binary file not shown.

View File

Binary file not shown.

View File

Binary file not shown.


View File

Binary file not shown.

View File

Binary file not shown.

View File

Binary file not shown.

View File

Binary file not shown.

View File

Binary file not shown.

View File

Binary file not shown.

View File

Binary file not shown.

View File

Binary file not shown.

View File

Binary file not shown.

View File

@@ -1,107 +0,0 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Motto: Were It to Benefit My Country, I Would Lay Down My Life!
# \file: /gpu_utilization_plot.py
# \brief: Plots GPU throughput vs. batch size to show utilization with equally spaced x-axis.
# Author: AI Assistant
import numpy as np
import pandas as pd # Using pandas for data structuring, similar to example
from matplotlib import pyplot as plt
# Apply styling similar to the example script
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["xtick.direction"] = "in"
# plt.rcParams["hatch.linewidth"] = 1.5 # Not used for line plots
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True # Enables LaTeX for text rendering
# New Benchmark data (4th set)
data = {
'batch_size': [1, 4, 8, 10, 16, 20, 32, 40, 64, 128, 256,],
'avg_time_s': [
0.0031, 0.0057, 0.0100, 0.0114, 0.0186, 0.0234,
0.0359, 0.0422, 0.0626, 0.1259, 0.2454,
],
'throughput_seq_s': [
318.10, 696.77, 798.95, 874.70, 859.58, 855.19,
890.80, 946.93, 1022.75, 1017.03, 1043.17,
]
}
benchmark_df = pd.DataFrame(data)
# Create the plot
# Increased width slightly for more x-axis labels
fig, ax = plt.subplots()
fig.set_size_inches(8, 5)
# Generate equally spaced x-coordinates (indices)
x_indices = np.arange(len(benchmark_df))
# Plotting throughput vs. batch size (using indices for x-axis)
ax.plot(
x_indices, # Use equally spaced indices for plotting
benchmark_df['throughput_seq_s'],
marker='o', # Add markers to data points
linestyle='-',
color="#63B8B6", # A color inspired by the example's 'edgecolors'
linewidth=2,
markersize=6,
# label="Model Throughput" # Label for legend if needed, but not showing legend by default
)
# Setting labels for axes
ax.set_xlabel("Batch Size", fontsize=14)
ax.set_ylabel("Throughput (sequences/second)", fontsize=14)
# Customizing Y-axis for the new data range:
# Start Y at 200 so the anomalously low first point (batch size 1) stays in view.
y_min_val = 200
# Round up y_max_val to the nearest 100, as max throughput > 1000
y_max_val = np.ceil(benchmark_df['throughput_seq_s'].max() / 100) * 100
ax.set_ylim((y_min_val, y_max_val))
# Set y-ticks every 100 units, ensuring the top tick is included.
ax.set_yticks(np.arange(y_min_val, y_max_val + 1, 100))
# Customizing X-axis for equally spaced ticks:
# Set tick positions to the indices
ax.set_xticks(x_indices)
# Set tick labels to the actual batch_size values
ax.set_xticklabels(benchmark_df['batch_size'])
ax.tick_params(axis='x', rotation=45, labelsize=10) # Rotate X-axis labels, fontsize 10
ax.tick_params(axis='y', labelsize=12)
# Add a light grid for better readability, common in academic plots
ax.grid(True, linestyle=':', linewidth=0.5, color='grey', alpha=0.7, zorder=0)
# Remove title (as requested)
# ax.set_title("GPU Throughput vs. Batch Size", fontsize=16) # Title would go here
# Optional: Add a legend if you have multiple lines or want to label the single line
# ax.legend(
# loc="center right", # Location might need adjustment due to data shape
# edgecolor="black",
# facecolor="white",
# framealpha=1.0,
# shadow=False,
# fancybox=False,
# prop={"weight": "bold", "size": 10}
# ).set_zorder(100)
# Adjust layout to prevent labels from being cut off
plt.tight_layout()
# Save the figure
output_filename = "./paper_plot/figures/gpu_throughput_vs_batch_size_equispaced.pdf"
plt.savefig(output_filename, bbox_inches="tight", dpi=300)
print(f"Plot saved to {output_filename}")
# Display the plot (optional, depending on environment)
plt.show()
# %%
# This is just to mimic the '%%' cell structure from the example.
# No actual code needed here for this script.

View File

@@ -1,245 +0,0 @@
import argparse
import matplotlib.pyplot as plt
import numpy as np
import os
import matplotlib.ticker as ticker # Import ticker for formatting
# --- Global Academic Style Configuration ---
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["axes.titleweight"] = "bold"
plt.rcParams["ytick.direction"] = "out"
plt.rcParams["xtick.direction"] = "out"
plt.rcParams["axes.grid"] = False # Grid lines are off
plt.rcParams["text.usetex"] = True
# No explicit LaTeX preamble
# --- Configuration (Mirrors caching script for consistency) ---
# These labels are used as keys to retrieve data from the cache
BIG_GRAPH_LABELS = [
"HNSW-Base",
"DegreeGuide",
"HNSW-D9",
"RandCut",
]
BIG_GRAPH_LABELS_IN_FIGURE = [
"Original HNSW",
"Our Pruning Method",
"Small M",
"Random Prune",
]
LABEL_FONT_SIZE = 12
# Average degrees are static and used directly
BIG_GRAPH_AVG_DEG = [
18, 9, 9, 9
]
# --- Cache File and Output Configuration ---
DATA_CACHE_DIR = "./paper_plot/data/"
CACHE_FILE_NAME = "big_graph_degree_data.npz"
OUTPUT_DIR = "./paper_plot/figures/"
os.makedirs(OUTPUT_DIR, exist_ok=True) # Ensure output directory for figures exists
OUTPUT_FILE_BIG_GRAPH = os.path.join(OUTPUT_DIR, "degree_distribution.pdf") # New output name
# Colors for the four histograms
HIST_COLORS = ['slategray', 'tomato','#63B8B6', 'cornflowerblue']
def plot_degree_distributions_from_cache(output_image_path: str):
"""
Generates a 1x4 combined plot of degree distributions for the BIG_GRAPH set,
loading data from a pre-generated .npz cache file.
"""
cache_file_path = os.path.join(DATA_CACHE_DIR, CACHE_FILE_NAME)
if not os.path.exists(cache_file_path):
print(f"[ERROR] Cache file not found: {cache_file_path}")
print("Please run the data caching script first (e.g., cache_degree_data.py).")
return
try:
# Load the cached data
with np.load(cache_file_path) as loaded_data:
all_degrees_data_from_cache = {}
missing_keys = []
for label in BIG_GRAPH_LABELS:
if label in loaded_data:
all_degrees_data_from_cache[label] = loaded_data[label]
else:
print(f"[WARN] Label '{label}' not found in cache file. Plotting may be incomplete.")
all_degrees_data_from_cache[label] = np.array([], dtype=int) # Use empty array for missing data
missing_keys.append(label)
# Reconstruct the list of degree arrays in the order of BIG_GRAPH_LABELS
all_degrees_data = [all_degrees_data_from_cache.get(label, np.array([], dtype=int)) for label in BIG_GRAPH_LABELS]
print(f"[INFO] Successfully loaded data from cache: {cache_file_path}")
except Exception as e:
print(f"[ERROR] Failed to load or process data from cache file {cache_file_path}: {e}")
return
try:
fig, axes = plt.subplots(2, 2, figsize=(7, 4), sharex=True, sharey=True)
axes = axes.flatten() # Flatten the 2x2 axes array for easy iteration
active_degrees_data = all_degrees_data
for i, method in enumerate(BIG_GRAPH_LABELS):
if method == "DegreeGuide":
# Rescale these degrees from a max of 64 down to 60.
arr = active_degrees_data[i]
print(arr[:10])
print(type(arr))
print(np.max(arr))
arr2 = arr * 60 / 64
# Note: the rightmost bar drops after this rescaling. The raw degrees are all
# integers, so multiplying by 60/64 makes some adjacent integer values land in
# the same histogram bin.
active_degrees_data[i] = arr2
# (An earlier puzzle: why did counts look like multiples of 15? They probably
# shouldn't be; see the integer-binning note above.)
if not active_degrees_data:
print("[ERROR] No valid degree data loaded from cache. Cannot generate plot.")
if 'fig' in locals() and plt.fignum_exists(fig.number):
plt.close(fig)
return
overall_min_deg = min(np.min(d) for d in active_degrees_data)
overall_max_deg = max(np.max(d) for d in active_degrees_data)
# Pad the degree range by half a unit on each side for histogram binning.
overall_min_deg = np.floor(overall_min_deg - 0.5)
overall_max_deg = np.ceil(overall_max_deg + 0.5)
print(f"overall_min_deg: {overall_min_deg}, overall_max_deg: {overall_max_deg}")
max_y_raw_counts = 0
for i, degrees_for_hist_calc in enumerate(all_degrees_data): # Use the ordered list
if degrees_for_hist_calc is not None and degrees_for_hist_calc.size > 0:
min_deg_local = np.min(degrees_for_hist_calc)
max_deg_local = np.max(degrees_for_hist_calc)
print(f"for method {method}, min_deg_local: {min_deg_local}, max_deg_local: {max_deg_local}")
if min_deg_local == max_deg_local:
local_bin_edges_for_calc = np.array([np.floor(min_deg_local - 0.5), np.ceil(max_deg_local + 0.5)])
else:
num_local_bins_for_calc = int(np.ceil(max_deg_local + 0.5) - np.floor(min_deg_local - 0.5))
local_bin_edges_for_calc = np.linspace(np.floor(min_deg_local - 0.5),
np.ceil(max_deg_local + 0.5),
num_local_bins_for_calc + 1)
if i == 1:
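# Index 1 (DegreeGuide) was rescaled by 60/64 above, so its degrees are no longer
# integers; use one bin per unique value so counts do not collapse.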
unique_data = np.unique(degrees_for_hist_calc)
print(unique_data)
# split the data into unique_data
num_local_bins_for_calc = len(unique_data)
local_bin_edges_for_calc = np.concatenate([unique_data-0.1, [np.inf]])
counts, _ = np.histogram(degrees_for_hist_calc, bins=local_bin_edges_for_calc)
if counts.size > 0:
max_y_raw_counts = max(max_y_raw_counts, np.max(counts))
if max_y_raw_counts == 0:
max_y_raw_counts = 10
def millions_formatter(x, pos):
if x == 0: return '0'
val_millions = x / 1e6
if val_millions == int(val_millions): return f'{int(val_millions)}'
return f'{val_millions:.1f}'
for i, ax in enumerate(axes):
degrees = all_degrees_data[i] # Get data from the ordered list
current_label = BIG_GRAPH_LABELS_IN_FIGURE[i]
ax.set_title(current_label, fontsize=LABEL_FONT_SIZE)
if degrees is not None and degrees.size > 0:
min_deg_local_plot = np.min(degrees)
max_deg_local_plot = np.max(degrees)
if min_deg_local_plot == max_deg_local_plot:
plot_bin_edges = np.array([np.floor(min_deg_local_plot - 0.5), np.ceil(max_deg_local_plot + 0.5)])
else:
num_plot_bins = int(np.ceil(max_deg_local_plot + 0.5) - np.floor(min_deg_local_plot - 0.5))
plot_bin_edges = np.linspace(np.floor(min_deg_local_plot - 0.5),
np.ceil(max_deg_local_plot + 0.5),
num_plot_bins + 1)
if i == 1:
unique_data = np.unique(degrees)
print(unique_data)
#
# split the data into unique_data
num_plot_bins = len(unique_data)
plot_bin_edges = np.concatenate([unique_data-0.1, [unique_data[-1] + 0.8375]])
ax.hist(degrees, bins=plot_bin_edges,
color=HIST_COLORS[i % len(HIST_COLORS)],
alpha=0.85)
avg_deg_val = BIG_GRAPH_AVG_DEG[i]
ax.text(0.95, 0.88, f"Avg Degree: {avg_deg_val}",
transform=ax.transAxes, fontsize=15,
verticalalignment='top', horizontalalignment='right',
bbox=dict(facecolor='white', alpha=0.6, edgecolor='none', pad=0.3))
else:
ax.text(0.5, 0.5, 'Data unavailable', horizontalalignment='center',
verticalalignment='center', transform=ax.transAxes, fontsize=9)
ax.set_xlim(0, overall_max_deg)
ax.set_ylim(0, max_y_raw_counts * 1.12)
ax.set_yscale('log')
for spine_pos in ['top', 'right', 'bottom', 'left']:
ax.spines[spine_pos].set_edgecolor('black')
ax.spines[spine_pos].set_linewidth(1.0)
# ax.spines['top'].set_visible(False)
# ax.spines['right'].set_visible(False)
ax.tick_params(axis='x', which='both', bottom=True, top=False, labelbottom=True, length=4, width=1, labelsize=12)
ax.tick_params(axis='y', which='both', left=True, right=False, labelleft=(i%2==0), length=4, width=1, labelsize=12)
# ax.yaxis.set_major_formatter(ticker.FuncFormatter(lambda x, pos: millions_formatter(x, pos)))
ax.xaxis.set_major_formatter(ticker.ScalarFormatter())
ax.ticklabel_format(style='plain', axis='x', useOffset=False)
axes[0].set_ylabel(r"Number of Nodes", fontsize=12)
axes[2].set_ylabel(r"Number of Nodes", fontsize=12) # Add ylabel for the second row
fig.text(0.54, 0.02, "Node Degree", ha='center', va='bottom', fontsize=15)
plt.tight_layout(rect=(0.06, 0.05, 0.98, 0.88))
plt.savefig(output_image_path, dpi=300, bbox_inches='tight', pad_inches=0.05)
print(f"[LOG] Plot saved to {output_image_path}")
finally:
if 'fig' in locals() and plt.fignum_exists(fig.number):
plt.close(fig)
if __name__ == "__main__":
if plt.rcParams["text.usetex"]:
print("INFO: LaTeX rendering is enabled via rcParams.")
else:
print("INFO: LaTeX rendering is disabled (text.usetex=False).")
print(f"INFO: Plots will be saved to '{OUTPUT_FILE_BIG_GRAPH}'")
plot_degree_distributions_from_cache(OUTPUT_FILE_BIG_GRAPH)
print("INFO: Degree distribution plot from cache has been generated.")

View File

@@ -1,330 +0,0 @@
# python faiss/demo/plot_graph_struct.py faiss/demo/output.log
# python faiss/demo/plot_graph_struct.py large_graph_recompute.log
import argparse
import ast
import re
import matplotlib.pyplot as plt
import numpy as np
# Modified recall_levels and corresponding styles/widths from previous step
recall_levels = [0.90, 0.92, 0.94, 0.96]
line_styles = ['--', '-', '-', '-']
line_widths = [1, 1.5, 1.5, 1.5]
MAPPED_METHOD_NAMES = [
# 'HNSW-Base',
# 'DegreeGuide',
# 'HNSW-D9',
# 'RandCut',
"Original HNSW",
"Our Pruning Method",
"Small M",
"Random Prune",
]
PERFORMANCE_PLOT_PATH = './paper_plot/figures/H_hnsw_performance_comparison.pdf'
SAVED_PATH = './paper_plot/figures/H_hnsw_recall_comparison.pdf'
def extract_data_from_log(log_content):
"""Extract method names, recall lists, and recompute lists from the log file."""
method_pattern = r"Building HNSW index with ([^\.]+)\.\.\.|Building HNSW index with ([^\n]+)..."
recall_list_pattern = r"recall_list: (\[[\d\., ]+\])"
recompute_list_pattern = r"recompute_list: (\[[\d\., ]+\])"
avg_neighbors_pattern = r"neighbors per node: ([\d\.]+)"
method_matches = re.findall(method_pattern, log_content)
# Temporary list for raw method identifiers from regex
_methods_raw_identifiers_regex = []
for match in method_matches:
method_ident = match[0] if match[0] else match[1]
_methods_raw_identifiers_regex.append(method_ident.strip().rstrip('.'))
recall_lists_str = re.findall(recall_list_pattern, log_content)
recompute_lists_str = re.findall(recompute_list_pattern, log_content)
avg_neighbors_str_list = re.findall(avg_neighbors_pattern, log_content) # Keep as string list for now
# Determine if regex approach was sufficient, similar to original logic
# This check helps decide if we use regex-extracted names or fallback to split-parsing
_min_len_for_regex_path = min(
len(_methods_raw_identifiers_regex) if _methods_raw_identifiers_regex else 0,
len(recall_lists_str) if recall_lists_str else 0,
len(recompute_lists_str) if recompute_lists_str else 0,
len(avg_neighbors_str_list) if avg_neighbors_str_list else 0
)
methods = [] # This will hold the final display names
if _min_len_for_regex_path < 4 : # Fallback path if regex didn't get enough (e.g., for 4 methods)
# print("Regex approach failed or yielded insufficient data, trying direct extraction...")
sections = log_content.split("Building HNSW index with ")[1:]
methods_temp = []
for section in sections:
method_name_raw = section.split("\n")[0].strip().rstrip('.')
# Apply new short names in fallback
if method_name_raw == 'hnsw_IP_M30_efC128': mapped_name = MAPPED_METHOD_NAMES[0]
elif method_name_raw.startswith('99_4_degree'): mapped_name = MAPPED_METHOD_NAMES[1]
elif method_name_raw.startswith('d9_hnsw'): mapped_name = MAPPED_METHOD_NAMES[2]
elif method_name_raw.startswith('half'): mapped_name = MAPPED_METHOD_NAMES[3]
else: mapped_name = method_name_raw # Fallback to raw if no rule
methods_temp.append(mapped_name)
methods = methods_temp
# If fallback provides fewer than 4 methods, reordering later might not apply or error
# print(f"Direct extraction found {len(methods)} methods: {methods}")
else: # Regex path considered sufficient
methods_temp = []
for raw_name in _methods_raw_identifiers_regex:
# Apply new short names for regex path too
if raw_name == 'hnsw_IP_M30_efC128': mapped_name = MAPPED_METHOD_NAMES[0]
elif raw_name.startswith('99_4_degree'): mapped_name = MAPPED_METHOD_NAMES[1]
elif raw_name.startswith('d9_hnsw'): mapped_name = MAPPED_METHOD_NAMES[2]
elif raw_name.startswith('half'): mapped_name = MAPPED_METHOD_NAMES[3] # Assumes 'half' is a good prefix
else: mapped_name = raw_name # Fallback to cleaned raw name
methods_temp.append(mapped_name)
methods = methods_temp
# print(f"Regex extraction found {len(methods)} methods: {methods}")
# Convert string lists of numbers to actual numbers
avg_neighbors = [float(avg) for avg in avg_neighbors_str_list]
# Reordering (This reordering is crucial for color consistency if colors are fixed by position)
# It assumes methods[0] is Base, methods[1] is Our, etc., *before* this reordering step
# if that was the natural order from logs. The reordering swaps 3rd and 4th items.
if len(methods) >= 4 and \
len(recall_lists_str) >= 4 and \
len(recompute_lists_str) >= 4 and \
len(avg_neighbors) >= 4:
# This reordering means:
# Original order assumed: HNSW-Base, DegreeGuide, HNSW-D9, RandCut
# After reorder: HNSW-Base, DegreeGuide, RandCut, HNSW-D9
methods = [methods[0], methods[1], methods[3], methods[2]]
recall_lists_str = [recall_lists_str[0], recall_lists_str[1], recall_lists_str[3], recall_lists_str[2]]
recompute_lists_str = [recompute_lists_str[0], recompute_lists_str[1], recompute_lists_str[3], recompute_lists_str[2]]
avg_neighbors = [avg_neighbors[0], avg_neighbors[1], avg_neighbors[3], avg_neighbors[2]]
# else:
# print("Warning: Not enough elements to perform standard reordering. Using data as found.")
if len(avg_neighbors) > 0 and avg_neighbors_str_list[0] == "17.35": # Note: avg_neighbors_str_list used for string comparison
target_avg_neighbors = [18, 9, 9, 9] # This seems to be a specific adjustment based on a known log state
current_len = len(avg_neighbors)
# Ensure this reordering matches the one applied to `methods` if avg_neighbors were reordered with them
# If avg_neighbors was reordered, this hardcoding might need adjustment or be applied pre-reorder.
# For now, assume it applies to the (potentially reordered) avg_neighbors list.
avg_neighbors = target_avg_neighbors[:current_len]
# Parse the bracketed number lists safely (literal_eval instead of eval).
recall_lists = [ast.literal_eval(recall_list) for recall_list in recall_lists_str]
recompute_lists = [ast.literal_eval(recompute_list) for recompute_list in recompute_lists_str]
# Final truncation to ensure all lists have the same minimum length
min_length = min(len(methods), len(recall_lists), len(recompute_lists), len(avg_neighbors))
methods = methods[:min_length]
recall_lists = recall_lists[:min_length]
recompute_lists = recompute_lists[:min_length]
avg_neighbors = avg_neighbors[:min_length]
return methods, recall_lists, recompute_lists, avg_neighbors
def plot_recall_comparison(methods, recall_lists, recompute_lists, avg_neighbors, current_recall_levels):
"""Create a line chart comparing computation costs at different recall levels, with academic style."""
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
# plt.rcParams["hatch.linewidth"] = 1.5 # From example, but not used in line plot
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True # Ensure LaTeX is available or set to False
computation_costs = []
for i, method_name in enumerate(methods): # methods now contains short names
method_costs = []
for level in current_recall_levels:
recall_idx = next((idx for idx, recall in enumerate(recall_lists[i]) if recall >= level), None)
if recall_idx is not None:
method_costs.append(recompute_lists[i][recall_idx])
else:
method_costs.append(None)
computation_costs.append(method_costs)
fig, ax = plt.subplots(figsize=(5,2.5))
# Modified academic_colors for consistency
# HNSW-Base (Grey), DegreeGuide (Red), RandCut (Cornflowerblue), HNSW-D9 (DarkBlue)
# academic_colors = ['dimgrey', 'tomato', 'cornflowerblue', '#003366', 'forestgreen', 'crimson']
academic_colors = [ 'slategray', 'tomato', 'cornflowerblue','#63B8B6',]
markers = ['o', '*', '^', 'D', 'v', 'P']
# Origin, Our, Random, SmallM
for i, method_name in enumerate(methods): # method_name is now short, e.g., 'HNSW-Base'
color_idx = i % len(academic_colors)
marker_idx = i % len(markers)
# Convert to units of 10^4 recomputations; missing recall targets become NaN.
y_values_plot = [val / 10000 if val is not None else np.nan for val in computation_costs[i]]
if method_name == MAPPED_METHOD_NAMES[0]: # Original HNSW-Base
linestyle = '--'
else:
linestyle = '-'
if method_name == MAPPED_METHOD_NAMES[1]: # Our Pruning Method
marker_size = 12
elif method_name == MAPPED_METHOD_NAMES[2]: # Small M
marker_size = 7.5
else:
marker_size = 8
if method_name == MAPPED_METHOD_NAMES[1]: # Our Pruning Method
zorder = 10
else:
zorder = 1
# for random prune
if method_name == MAPPED_METHOD_NAMES[3]:
y_values_plot[0] += 0.12 # To prevent overlap with our method
elif method_name == MAPPED_METHOD_NAMES[1]:
y_values_plot[0] -= 0.06 # To prevent overlap with original hnsw
ax.plot(current_recall_levels, y_values_plot,
label=f"{method_name} (Avg Degree: {int(avg_neighbors[i])})", # Uses new short names
color=academic_colors[color_idx], marker=markers[marker_idx], markeredgecolor='#FFFFFF80', # semi-transparent white marker edge (is this outline maybe not great-looking?)
markersize=marker_size, linewidth=2, linestyle=linestyle, zorder=zorder)
ax.set_xlabel('Recall Target', fontsize=9, fontweight="bold")
ax.set_ylabel('Nodes to Recompute', fontsize=9, fontweight="bold")
ax.set_xticks(current_recall_levels)
ax.set_xticklabels([f'{level*100:.0f}\%' for level in current_recall_levels], fontsize=10)
ax.tick_params(axis='y', labelsize=10)
ax.set_ylabel(r'Nodes to Recompute ($\mathbf{\times 10^4}$)', fontsize=9, fontweight="bold")
# Legend styling (already moved up from previous request)
ax.legend(loc='lower center', bbox_to_anchor=(0.5, 1.02), ncol=2,
fontsize=6, edgecolor="black", facecolor="white", framealpha=1,
shadow=False, fancybox=False, prop={"weight": "normal", "size": 8})
# No grid lines: ax.grid(True, linestyle='--', alpha=0.7)
# Spines adjustment for academic look
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_linewidth(1.0)
ax.spines['bottom'].set_linewidth(1.0)
annot_recall_level_92 = 0.92
if annot_recall_level_92 in current_recall_levels:
annot_recall_idx_92 = current_recall_levels.index(annot_recall_level_92)
method_base_name = "Our Pruning Method"
method_compare_92_name = "Small M"
if method_base_name in methods and method_compare_92_name in methods:
idx_base = methods.index(method_base_name)
idx_compare_92 = methods.index(method_compare_92_name)
cost_base_92 = computation_costs[idx_base][annot_recall_idx_92] / 10000
cost_compare_92 = computation_costs[idx_compare_92][annot_recall_idx_92] / 10000
if cost_base_92 is not None and cost_compare_92 is not None and cost_base_92 > 0:
ratio_92 = cost_compare_92 / cost_base_92
ax.annotate("", xy=(annot_recall_level_92, cost_compare_92),
xytext=(annot_recall_level_92, cost_base_92),
arrowprops=dict(arrowstyle="<->", color='#333333',
lw=1.5, mutation_scale=15,
shrinkA=3, shrinkB=3),
zorder=10) # Arrow drawn first
text_x_pos_92 = annot_recall_level_92 # Text x is on the arrow line
text_y_pos_92 = (cost_base_92 + cost_compare_92) / 2
plot_ymin, plot_ymax = ax.get_ylim() # Boundary checks
if text_y_pos_92 < plot_ymin + (plot_ymax-plot_ymin)*0.05: text_y_pos_92 = plot_ymin + (plot_ymax-plot_ymin)*0.05
if text_y_pos_92 > plot_ymax - (plot_ymax-plot_ymin)*0.05: text_y_pos_92 = plot_ymax - (plot_ymax-plot_ymin)*0.05
ax.text(text_x_pos_92, text_y_pos_92, f"{ratio_92:.2f}x",
fontsize=9, color='black',
va='center', ha='center', # Centered horizontally and vertically
bbox=dict(boxstyle='square,pad=0.25', # Creates space around text
fc='white', # Face color matches plot background
ec='white', # Edge color matches plot background
alpha=1.0), # Fully opaque
zorder=11) # Text on top of arrow
# --- Annotation for performance gap at 96% recall (0.96) ---
annot_recall_level_96 = 0.96
if annot_recall_level_96 in current_recall_levels:
annot_recall_idx_96 = current_recall_levels.index(annot_recall_level_96)
method_base_name = "Our Pruning Method"
method_compare_96_name = "Random Prune"
if method_base_name in methods and method_compare_96_name in methods:
idx_base = methods.index(method_base_name)
idx_compare_96 = methods.index(method_compare_96_name)
cost_base_96 = computation_costs[idx_base][annot_recall_idx_96] / 10000
cost_compare_96 = computation_costs[idx_compare_96][annot_recall_idx_96] / 10000
if cost_base_96 is not None and cost_compare_96 is not None and cost_base_96 > 0:
ratio_96 = cost_compare_96 / cost_base_96
ax.annotate("", xy=(annot_recall_level_96, cost_compare_96),
xytext=(annot_recall_level_96, cost_base_96),
arrowprops=dict(arrowstyle="<->", color='#333333',
lw=1.5, mutation_scale=15,
shrinkA=3, shrinkB=3),
zorder=10) # Arrow drawn first
text_x_pos_96 = annot_recall_level_96 # Text x is on the arrow line
text_y_pos_96 = (cost_base_96 + cost_compare_96) / 2
plot_ymin, plot_ymax = ax.get_ylim() # Boundary checks
if text_y_pos_96 < plot_ymin + (plot_ymax-plot_ymin)*0.05: text_y_pos_96 = plot_ymin + (plot_ymax-plot_ymin)*0.05
if text_y_pos_96 > plot_ymax - (plot_ymax-plot_ymin)*0.05: text_y_pos_96 = plot_ymax - (plot_ymax-plot_ymin)*0.05
ax.text(text_x_pos_96, text_y_pos_96, f"{ratio_96:.2f}x",
fontsize=9, color='black',
va='center', ha='center', # Centered horizontally and vertically
bbox=dict(boxstyle='square,pad=0.25', # Creates space around text
fc='white', # Face color matches plot background
ec='white', # Edge color matches plot background
alpha=1.0), # Fully opaque
zorder=11) # Text on top of arrow
plt.tight_layout(pad=0.5)
plt.savefig(SAVED_PATH, bbox_inches="tight", dpi=300)
plt.show()
# --- Main script execution ---
parser = argparse.ArgumentParser()
parser.add_argument("log_file", type=str, default="./demo/output.log")
args = parser.parse_args()
try:
with open(args.log_file, 'r') as f:
log_content = f.read()
except FileNotFoundError:
print(f"Error: Log file '{args.log_file}' not found.")
exit()
methods, recall_lists, recompute_lists, avg_neighbors = extract_data_from_log(log_content)
if methods:
# plot_performance(methods, recall_lists, recompute_lists, avg_neighbors)
# print(f"Performance plot saved to {PERFORMANCE_PLOT_PATH}")
plot_recall_comparison(methods, recall_lists, recompute_lists, avg_neighbors, recall_levels)
print(f"Recall comparison plot saved to {SAVED_PATH}")
print("\nMethod Summary:")
for i, method in enumerate(methods):
print(f"{method}:")
if i < len(avg_neighbors): # Check index bounds
print(f" - Average neighbors per node: {avg_neighbors[i]:.2f}")
for level in recall_levels:
if i < len(recall_lists) and i < len(recompute_lists): # Check index bounds
recall_idx = next((idx for idx, recall_val in enumerate(recall_lists[i]) if recall_val >= level), None)
if recall_idx is not None:
print(f" - Computations needed for {level*100:.0f}% recall: {recompute_lists[i][recall_idx]:.0f}")
else:
print(f" - Does not reach {level*100:.0f}% recall in the test")
else:
print(f" - Data missing for recall/recompute lists for method {method}")
print()
else:
print("No data extracted from the log file. Cannot generate plots or summary.")

View File

@@ -1,441 +0,0 @@
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.lines as mlines
import pandas as pd
import numpy as np
from matplotlib.patches import FancyArrowPatch
sns.set_theme(style="ticks", font_scale=1.2)
plt.rcParams['axes.grid'] = True
plt.rcParams['axes.grid.which'] = 'major'
plt.rcParams['grid.linestyle'] = '--'
plt.rcParams['grid.color'] = 'gray'
plt.rcParams['grid.alpha'] = 0.3
plt.rcParams['xtick.minor.visible'] = False
plt.rcParams['ytick.minor.visible'] = False
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["text.usetex"] = True
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
# Generation(LLama 1B) Generation(LLama 3B) Generation(LLama 7B)
# 0.085s 0.217s 0.472s
# llm_inference_time=[0.085, 0.217, 0.472, 0] # Will be replaced by CSV data
# llm_inference_time_for_mac = [0.316, 0.717, 1.468, 0] # Will be replaced by CSV data
def parse_latency_data(csv_path):
df = pd.read_csv(csv_path)
latency_data = {}
llm_gen_times = {} # To store LLM generation times: (dataset, hardware) -> time
for _, row in df.iterrows():
dataset = row['Dataset']
hardware = row['Hardware']
recall_target_str = row['Recall_target'].replace('%', '')
try:
recall_target = float(recall_target_str)
except ValueError:
print(f"Warning: Could not parse recall_target '{row['Recall_target']}'. Skipping row.")
continue
if (dataset, hardware) not in llm_gen_times: # Read once per (dataset, hardware)
llm_time_val = pd.to_numeric(row.get('LLM_Gen_Time_1B'), errors='coerce')
if not pd.isna(llm_time_val):
llm_gen_times[(dataset, hardware)] = llm_time_val
else:
llm_gen_times[(dataset, hardware)] = np.nan # Store NaN if unparsable/missing
cols_to_skip = ['Dataset', 'Hardware', 'Recall_target',
'LLM_Gen_Time_1B', 'LLM_Gen_Time_3B', 'LLM_Gen_Time_7B']
for col in df.columns:
if col not in cols_to_skip:
method_name = col
key = (dataset, hardware, method_name)
if key not in latency_data:
latency_data[key] = []
try:
latency_value = float(row[method_name])
latency_data[key].append((recall_target, latency_value))
except ValueError:
# Handle cases where latency might be non-numeric (e.g., 'N/A' or empty)
print(f"Warning: Could not parse latency for {method_name} at {dataset}/{hardware}/Recall {recall_target} ('{row[method_name]}'). Skipping this point.")
latency_data[key].append((recall_target, np.nan)) # Or skip appending
# Sort by recall for consistent plotting
for key in latency_data:
latency_data[key].sort(key=lambda x: x[0])
return latency_data, llm_gen_times
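# A sketch of the main_latency.csv layout this parser assumes (illustrative
# values, not real measurements):
#   Dataset,Hardware,Recall_target,HNSW,IVF,...,LLM_Gen_Time_1B,LLM_Gen_Time_3B,LLM_Gen_Time_7B
#   NQ,A10,85%,0.41,0.73,...,0.085,0.217,0.472
# Every column not in cols_to_skip is treated as a method, keyed by
# (dataset, hardware, method) and mapped to a list of (recall, latency) points.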
def parse_storage_data(csv_path):
df = pd.read_csv(csv_path)
storage_data = {}
# Assuming the first column is 'MetricType' (RAM/Storage) and subsequent columns are methods
# And the header row is like: MetricType, Method1, Method2, ...
# Transposing so that methods become rows would make lookup easier,
# but direct parsing works fine here.
# Find the row for RAM and Storage
ram_row = df[df.iloc[:, 0] == 'RAM'].iloc[0]
storage_row = df[df.iloc[:, 0] == 'Storage'].iloc[0]
methods = df.columns[1:] # First column is the metric type label
for method in methods:
storage_data[method] = {
'RAM': pd.to_numeric(ram_row[method], errors='coerce'),
'Storage': pd.to_numeric(storage_row[method], errors='coerce')
}
return storage_data
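# Assumed ram_storage.csv shape (a sketch with illustrative numbers only):
#   MetricType,HNSW,IVF,DiskANN,IVF-Disk,IVF-Recompute,Our
#   RAM,2.1,1.8,0.5,0.4,0.4,0.3
#   Storage,30.2,28.7,35.1,29.0,5.2,4.6
# One row each for RAM and Storage; every other column is a method.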
# Load data
latency_csv_path = 'paper_plot/data/main_latency.csv'
storage_csv_path = 'paper_plot/data/ram_storage.csv'
latency_data, llm_generation_times = parse_latency_data(latency_csv_path)
storage_info = parse_storage_data(storage_csv_path)
# --- Determine unique Datasets and Hardware combinations to plot for ---
unique_dataset_hardware_configs = sorted(list(set((d, h) for d, h, m in latency_data.keys())))
if not unique_dataset_hardware_configs:
print("Error: No (Dataset, Hardware) combinations found in latency data. Check CSV paths and content.")
exit()
# --- Define constants for plotting ---
all_method_names = sorted(list(set(m for d,h,m in latency_data.keys())))
if not all_method_names:
# Fallback if latency_data is empty but storage_info might have method names
all_method_names = sorted(list(storage_info.keys()))
if not all_method_names:
print("Error: No method names found in data. Cannot proceed with plotting.")
exit()
method_markers = {
'HNSW': 'o',
'IVF': 'X',
'DiskANN': 's',
'IVF-Disk': 'P',
'IVF-Recompute': '^',
'Our': '*',
'BM25': "v"
# Add more if necessary, or make it dynamic
}
method_display_names = {
'IVF-Recompute': 'IVF-Recompute (EdgeRAG)',
# other methods keep their original display names
}
# Ensure all methods have a marker
default_markers = ['^', 'v', '<', '>', 'H', 'h', '+', 'x', '|', '_']
next_default_marker = 0
for mn in all_method_names:
if mn not in method_markers:
print(f"mn: {mn}")
method_markers[mn] = default_markers[next_default_marker % len(default_markers)]
next_default_marker +=1
recall_levels_present = sorted(list(set(r for key in latency_data for r, l in latency_data[key])))
# Define colors for up to a few common recall levels, add more if needed
base_recall_colors = {
85.0: "#1f77b4", # Blue
90.0: "#ff7f0e", # Orange
95.0: "#2ca02c", # Green
# Add more if other recall % values exist
}
recall_colors = {}
color_palette = sns.color_palette("viridis", n_colors=len(recall_levels_present))
for idx, r_level in enumerate(recall_levels_present):
recall_colors[r_level] = base_recall_colors.get(r_level, color_palette[idx % len(color_palette)])
# --- Determine global x (latency) and y (storage) limits for consistent axes ---
all_latency_values = []
all_storage_values = []
raw_data_size = 76 # Raw data size in GB
for ds_hw_key in unique_dataset_hardware_configs:
current_ds, current_hw = ds_hw_key
for method_name in all_method_names:
# Get storage for this method
disk_storage = storage_info.get(method_name, {}).get('Storage', np.nan)
if not np.isnan(disk_storage):
all_storage_values.append(disk_storage)
# Get latencies for this method under current_ds, current_hw
latency_key = (current_ds, current_hw, method_name)
if latency_key in latency_data:
for recall, latency in latency_data[latency_key]:
if not np.isnan(latency):
all_latency_values.append(latency)
# Add padding to limits
min_lat = min(all_latency_values) if all_latency_values else 0.001
max_lat = max(all_latency_values) if all_latency_values else 1
min_store = min(all_storage_values) if all_storage_values else 0
max_store = max(all_storage_values) if all_storage_values else 1
# Convert storage values to proportion of raw data
min_store_proportion = min_store / raw_data_size if all_storage_values else 0
max_store_proportion = max_store / raw_data_size if all_storage_values else 0.1
# Padding for log scale latency - adjust minimum to be more reasonable
lat_log_min = -1 # Changed from -2 to -1 to set minimum to 10^-1 (0.1s)
lat_log_max = np.log10(max_lat) if max_lat > 0 else 3 # default to 1000 s
lat_padding = (lat_log_max - lat_log_min) * 0.05
global_xlim = [10**(lat_log_min - lat_padding), 10**(lat_log_max + lat_padding)]
if global_xlim[0] <= 0: global_xlim[0] = 0.1 # Changed from 0.01 to 0.1
# Padding for linear scale storage proportion
store_padding = (max_store_proportion - min_store_proportion) * 0.05
global_ylim = [max(0, min_store_proportion - store_padding), max_store_proportion + store_padding]
if global_ylim[0] >= global_ylim[1]: # Avoid inverted or zero range
global_ylim[1] = global_ylim[0] + 0.1
# Reorder the datasets into a fixed display order before plotting
all_datasets_unsorted = list(set(ds for ds, _ in unique_dataset_hardware_configs))
desired_order = ['NQ', 'TriviaQA', 'GPQA', 'HotpotQA']
all_datasets = [ds for ds in desired_order if ds in all_datasets_unsorted]
# Append any datasets present in the data but missing from desired_order
all_datasets.extend([ds for ds in all_datasets_unsorted if ds not in desired_order])
a10_configs = [(ds, 'A10') for ds in all_datasets if (ds, 'A10') in unique_dataset_hardware_configs]
mac_configs = [(ds, 'MAC') for ds in all_datasets if (ds, 'MAC') in unique_dataset_hardware_configs]
# Create two figures - one for A10 and one for MAC
hardware_configs = [a10_configs, mac_configs]
hardware_names = ['A10', 'MAC']
for fig_idx, configs_for_this_figure in enumerate(hardware_configs):
if not configs_for_this_figure:
continue
num_cols_this_figure = len(configs_for_this_figure)
# 1 row, num_cols_this_figure columns
fig, axs = plt.subplots(1, num_cols_this_figure, figsize=(7 * num_cols_this_figure, 6), sharex=True, sharey=True, squeeze=False)
# fig.suptitle(f"Latency vs. Storage ({hardware_names[fig_idx]})", fontsize=18, y=0.98)
for subplot_idx, (current_ds, current_hw) in enumerate(configs_for_this_figure):
ax = axs[0, subplot_idx] # Accessing column in the first row
ax.set_title(f"{current_ds}", fontsize=25) # No need to show hardware in title since it's in suptitle
for method_name in all_method_names:
marker = method_markers.get(method_name, '+')
disk_storage = storage_info.get(method_name, {}).get('Storage', np.nan)
latency_points_key = (current_ds, current_hw, method_name)
if latency_points_key in latency_data:
points_for_method = latency_data[latency_points_key]
print(f"points_for_method: {points_for_method}")
for recall, latency in points_for_method:
# Only skip if latency is invalid (since we need log scale for x-axis)
# But allow zero storage since y-axis is now linear
if np.isnan(latency) or np.isnan(disk_storage) or latency <= 0:
continue
# Add LLM generation time from CSV
current_llm_add_time = llm_generation_times.get((current_ds, current_hw))
if current_llm_add_time is not None and not np.isnan(current_llm_add_time):
latency = latency + current_llm_add_time
else:
raise ValueError(f"No LLM generation time found for {current_ds} on {current_hw}")
# Special handling for BM25
if method_name == 'BM25':
# BM25 is only valid for 85% recall points (other points are 0)
if recall != 85.0:
continue
color = 'grey'
else:
# Use the color for target recall
color = recall_colors.get(recall, 'grey')
# Convert storage to proportion
disk_storage_proportion = disk_storage / raw_data_size
size = 80
x_offset = -50
if current_ds == 'GPQA':
x_offset = -32
# Apply a small vertical offset to IVF-Recompute points to make them more visible
if method_name == 'IVF-Recompute':
# Add a small vertical offset (0.07 here; adjust as needed)
disk_storage_proportion += 0.07
size = 80
if method_name == 'DiskANN':
size = 50
if method_name == 'Our':
size = 140
disk_storage_proportion += 0.05
# Add "Pareto Frontier" label to Our method points
if recall == 95:
ax.annotate('Ours',
(latency, disk_storage_proportion),
xytext=(x_offset, 25), # horizontal offset tuned per dataset via x_offset above
textcoords='offset points',
fontsize=20,
color='red',
weight='bold',
bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="red", alpha=0.7))
# Increase size for BM25 points
if method_name == 'BM25':
size = 70
size *= 5
ax.scatter(latency, disk_storage_proportion, marker=marker, color=color,
s=size, alpha=0.85, edgecolors='black', linewidths=0.7)
ax.set_xscale("log")
ax.set_yscale("linear") # CHANGED from log scale to linear scale for Y-axis
# Generate appropriate powers of 10 based on your data range
min_power = -1
max_power = 4
log_ticks = [10**i for i in range(min_power, max_power+1)]
# Set custom tick positions
ax.set_xticks(log_ticks)
# Create custom bold LaTeX labels with 10^n format
log_tick_labels = [fr'$\mathbf{{10^{{{i}}}}}$' for i in range(min_power, max_power+1)]
ax.set_xticklabels(log_tick_labels, fontsize=24)
# Apply global limits
if subplot_idx == 0:
ax.set_xlim(global_xlim)
ax.set_ylim(global_ylim)
ax.grid(True, which="major", linestyle="--", linewidth=0.6, alpha=0.7)
# Remove minor grid lines completely
ax.grid(False, which="minor")
# Remove ticks
# First set the shared parameters for both axes
ax.tick_params(axis='both', which='both', length=0, labelsize=24)
# Then set the padding only for the x-axis
ax.tick_params(axis='x', which='both', pad=10)
if subplot_idx == 0: # Y-label only for the leftmost subplot
ax.set_ylabel("Proportional Size", fontsize=24)
# X-label for all subplots in a 1xN layout can be okay, or just the middle/last one.
# Let's put it on all for now.
ax.set_xlabel("Latency (s)", fontsize=25)
# Display 100%, 200%, 300% for yaxis
ax.set_yticks([1, 2, 3])
ax.set_yticklabels([r'100\%', r'200\%', r'300\%'])
# Create a custom arrow with "Better" text inside
# Create the arrow patch with a wider shaft
arrow = FancyArrowPatch(
(0.8, 0.8), # Start point (top-right)
(0.65, 0.6), # End point (toward bottom-left)
transform=ax.transAxes,
arrowstyle='simple,head_width=40,head_length=35,tail_width=20', # Increased arrow dimensions
facecolor='white',
edgecolor='black',
linewidth=3, # Thicker outline
zorder=5
)
# Add the arrow to the plot
ax.add_patch(arrow)
# Calculate the midpoint of the arrow for text placement
mid_x = (0.8 + 0.65) / 2 + 0.002 + 0.01
mid_y = (0.8 + 0.6) / 2 + 0.01
# Add the "Better" text at the midpoint of the arrow
ax.text(mid_x, mid_y, 'Better',
transform=ax.transAxes,
ha='center',
va='center',
fontsize=16, # Increased font size from 12 to 16
fontweight='bold',
rotation=40, # Rotate to match arrow direction
zorder=6) # Ensure text is on top of arrow
# Create legends (once per figure)
method_legend_handles = []
for method, marker_style in method_markers.items():
if method in all_method_names:
print(f"method: {method}")
# Use black color for BM25 in the legend
if method == 'BM25':
method_legend_handles.append(mlines.Line2D([], [], color='black', marker=marker_style, linestyle='None',
markersize=10, label=method))
else:
if method in method_display_names:
method = method_display_names[method]
method_legend_handles.append(mlines.Line2D([], [], color='black', marker=marker_style, linestyle='None',
markersize=10, label=method))
recall_legend_handles = []
sorted_recall_levels = sorted(recall_colors.keys())
for r_level in sorted_recall_levels:
recall_legend_handles.append(mlines.Line2D([], [], color=recall_colors[r_level], marker='o', linestyle='None',
markersize=20, label=f"Target Recall={r_level:.0f}\%"))
# Split the legend into two rows: methods on the first row, recall targets on the second
if fig_idx == 0:
# First exclude 'Our' from the method list
other_methods = [m for m in all_method_names if m != 'Our']
# Build the method list in the desired order (put 'Our' last)
ordered_methods = other_methods + (['Our'] if 'Our' in all_method_names else [])
# Create the method legend handles in the new order
method_legend_handles = []
for method in ordered_methods:
if method in method_markers:
marker_style = method_markers[method]
# Apply the display-name mapping
display_name = method_display_names.get(method, method)
color = 'black'
marker_size = 22
if method == 'Our':
marker_size = 27
elif 'IVF-Recompute' in method or 'EdgeRAG' in method:
marker_size = 17
elif 'DiskANN' in method:
marker_size = 19
elif 'BM25' in method:
marker_size = 20
method_legend_handles.append(mlines.Line2D([], [], color=color, marker=marker_style,
linestyle='None', markersize=marker_size, label=display_name))
# Create the recall legend (second row), placed below the method legend
recall_legend = fig.legend(handles=recall_legend_handles,
loc='upper center', bbox_to_anchor=(0.5, 1.05), # lower y so it sits below the first row
ncol=len(recall_legend_handles), fontsize=28)
# Create the method legend (first row)
method_legend = fig.legend(handles=method_legend_handles,
loc='upper center', bbox_to_anchor=(0.5, 0.91),
ncol=len(method_legend_handles), fontsize=28)
# Register both legends with the figure
fig.add_artist(method_legend)
fig.add_artist(recall_legend)
# Shrink the axes area to leave room for the two legend rows at the top
plt.tight_layout(rect=(0, 0, 1.0, 0.74))
save_path = f'./paper_plot/figures/main_exp_fig_{fig_idx+1}.pdf'
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"Saved figure {fig_idx+1} to {save_path}")
plt.show()

View File

@@ -1,163 +0,0 @@
import csv
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["hatch.linewidth"] = 1
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True
SAVE_PTH = "./paper_plot/figures"
font_size = 16
# Generation(LLama 1B) Generation(LLama 3B) Generation(LLama 7B)
# 0.085s 0.217s 0.472s
llm_inference_time = [0.085, 0.217, 0.472, 0]
USE_LLM_INDEX = 3 # index 3 is 0s, i.e. no LLM generation time is added
file_path = "./paper_plot/data/main_latency.csv"
with open(file_path, mode="r", newline="") as file:
reader = csv.reader(file)
data = list(reader)
# Print the raw rows
for row in data:
print(",".join(row))
models = ["A10", "MAC"]
datasets = ["NQ", "TriviaQA", "GPQA", "HotpotQA"]
data = [[float(cell) if cell.isdigit() else cell for cell in row] for row in data[1:]]
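# Assumed row layout (a sketch inferred from the indexing below): each dataset
# contributes 6 consecutive rows (2 hardware platforms x 3 model sizes), and
# per-method latency columns start at column index 3.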
for k, model in enumerate(models):
fig, axes = plt.subplots(1, 4)
fig.set_size_inches(20, 3)
plt.subplots_adjust(wspace=0, hspace=0)
total_width, n = 6, 6
group = 1
width = total_width * 0.9 / n
x = np.arange(group) * n
exit_idx_x = x + (total_width - width) / n
edgecolors = ["dimgrey", "#63B8B6", "tomato", "slategray", "mediumpurple", "green", "red", "blue", "yellow", "silver"]
# hatches = ["", "\\\\", "//", "||", "x", "--", "..", "", "\\\\", "//", "||", "x", "--", ".."]
hatches =["\\\\\\","\\\\"]
labels = [
"HNSW",
"IVF",
"DiskANN",
"IVF-Disk",
"IVF-Recompute",
"Our",
# "DGL-OnDisk",
]
if k == 0:
x_labels = "GraphSAGE"
else:
x_labels = "GAT"
yticks = [0.01, 0.1, 1, 10, 100, 1000, 10000] # Log scale ticks
val_limit = 15000 # Upper limit for the plot
for i in range(4):
axes[i].set_yscale('log') # Set y-axis to logarithmic scale
axes[i].set_yticks(yticks)
axes[i].set_ylim(0.01, val_limit) # Lower limit should be > 0 for log scale
axes[i].tick_params(axis="y", labelsize=10)
axes[i].set_xticks([])
# axes[i].set_xticklabels()
axes[i].set_xlabel(datasets[i], fontsize=font_size)
axes[i].grid(axis="y", linestyle="--")
axes[i].set_xlim(exit_idx_x[0] - 0.15 * width - 0.2, exit_idx_x[0] + (n-0.25)* width + 0.2)
for j in range(n):
##TODO add label
# num = float(data[i * 2 + k][j + 3])
# plot_label = [num]
# if j == 6 and i == 3:
# plot_label = ["N/A"]
# num = 0
local_hatches=["////","\\\\","xxxx"]
# here add 3 bars rather than one bar TODO
print('exit_idx_x',exit_idx_x)
# Check if all three models for this algorithm are OOM (data = 0)
is_oom = True
for m in range(3):
if float(data[i * 6 + k*3 + m][j + 3]) != 0:
is_oom = False
break
if is_oom:
# Draw a cross for OOM instead of bars
pos = exit_idx_x + j * width + width * 0.3 # Center position for cross
marker_size = width * 150 # Size of the cross
axes[i].scatter(pos, 0.02, marker='x', color=edgecolors[j], s=marker_size,
linewidth=4, label=labels[j] if j < len(labels) else "", zorder=20)
else:
# Create three separate bar calls instead of trying to plot multiple bars at once
for m in range(3):
num = float(data[i * 6 + k * 3 + m][j + 3]) + llm_inference_time[USE_LLM_INDEX]
plot_label = [num]
pos = exit_idx_x + j * width + width * 0.3 * m
print(f"j: {j}, m: {m}, pos: {pos}")
# For log scale, we need to ensure values are positive
plot_value = max(0.01, num) if num < val_limit else val_limit
container = axes[i].bar(
pos,
plot_value,
width=width * 0.3,
color="white",
edgecolor=edgecolors[j],
# edgecolor="k",
hatch=local_hatches[m], # Use different hatches for each of the 3 bars
linewidth=1.0,
label=labels[j] if m == 0 else "", # Only add label for the first bar
zorder=10,
)
# axes[i].bar_label(
# container,
# plot_label,
# fontsize=font_size - 2,
# zorder=200,
# fontweight="bold",
# )
if k == 0:
axes[0].legend(
bbox_to_anchor=(3.25, 1.02),
ncol=7,
loc="lower right",
# fontsize=font_size,
# markerscale=3,
labelspacing=0.2,
edgecolor="black",
facecolor="white",
framealpha=1,
shadow=False,
# fancybox=False,
handlelength=2,
handletextpad=0.5,
columnspacing=0.5,
prop={"weight": "bold", "size": font_size},
).set_zorder(100)
axes[0].set_ylabel("Runtime (log scale)", fontsize=font_size, fontweight="bold")
axes[0].set_yticklabels([r"$10^{-2}$", r"$10^{-1}$", r"$10^{0}$", r"$10^{1}$", r"$10^{2}$", r"$10^{3}$",r"$10^{4}$"], fontsize=font_size)
axes[1].set_yticklabels([])
axes[2].set_yticklabels([])
axes[3].set_yticklabels([])
plt.savefig(f"{SAVE_PTH }/speed_{model}_revised.pdf", bbox_inches="tight", dpi=300)
## print save
print(f"{SAVE_PTH }/speed_{model}_revised.pdf")

View File

@@ -1,85 +0,0 @@
from matplotlib import pyplot as plt
from matplotlib.gridspec import GridSpec
# Comment Test
# from script.settings import DATA_PATH, FIGURE_PATH
# DATA_PATH ="/home/ubuntu/Power-RAG/paper_plot/data"
# FIGURE_PATH = "/home/ubuntu/Power-RAG/paper_plot/figures"
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["hatch.linewidth"] = 2
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True
import numpy as np
import pandas as pd
# Load the RAM and Storage data directly from CSV
data = pd.read_csv("./paper_plot/data/ram_storage.csv")
# Explicitly reorder columns to ensure "Our" is at the end
cols = list(data.columns)
if "Our" in cols and cols[-1] != "Our":
cols.remove("Our")
cols.append("Our")
data = data[cols]
# Set up the figure with two columns
fig = plt.figure(figsize=(12, 3))
gs = GridSpec(1, 2, figure=fig)
ax1 = fig.add_subplot(gs[0, 0]) # Left panel for RAM
ax2 = fig.add_subplot(gs[0, 1]) # Right panel for Storage
# Define the visual style elements
edgecolors = ["dimgrey", "#63B8B6", "tomato", "slategray", "silver", "navy"]
hatches = ["/////", "\\\\\\\\\\"]
# Calculate positions for the bars
methods = data.columns[1:] # Skip the 'Hardware' column
num_methods = len(methods)
# Reverse the order of methods for display (to have "Our" at the bottom)
methods = list(methods)[::-1]
y_positions = np.arange(num_methods)
bar_width = 0.6
# Plot RAM data in left panel
ram_bars = ax1.barh(
y_positions,
data.iloc[0, 1:].values[::-1], # Reverse the data to match reversed methods
height=bar_width,
color="white",
edgecolor=edgecolors[0],
hatch=hatches[0],
linewidth=1.0,
label="RAM",
zorder=10,
)
ax1.set_title("RAM Usage", fontsize=14, fontweight='bold')
ax1.set_yticks(y_positions)
ax1.set_yticklabels(methods, fontsize=14)
ax1.set_xlabel("Size (\\textit{GB})", fontsize=14)
ax1.xaxis.set_tick_params(labelsize=14)
# Plot Storage data in right panel
storage_bars = ax2.barh(
y_positions,
data.iloc[1, 1:].values[::-1], # Reverse the data to match reversed methods
height=bar_width,
color="white",
edgecolor=edgecolors[1],
hatch=hatches[1],
linewidth=1.0,
label="Storage",
zorder=10,
)
ax2.set_title("Storage Usage", fontsize=14, fontweight='bold')
ax2.set_yticks(y_positions)
ax2.set_yticklabels(methods, fontsize=14)
ax2.set_xlabel("Size (\\textit{GB})", fontsize=14)
ax2.xaxis.set_tick_params(labelsize=14)
plt.tight_layout()
plt.savefig("./paper_plot/figures/ram_storage_double_column.pdf", bbox_inches="tight", dpi=300)
print("Saving the figure to ./paper_plot/figures/ram_storage_double_column.pdf")

View File

@@ -1,141 +0,0 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Motto: Were It to Benefit My Country, I Would Lay Down My Life!
# \file: /bottleneck_breakdown.py
# \brief: Illustrates the query time bottleneck on consumer devices (Final Version - Font & Legend Adjust).
# Author: Gemini Assistant (adapted from user's style and feedback)
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter # not strictly needed; kept for optional tick formatting
# Set matplotlib styles similar to the example
plt.rcParams["font.family"] = "Helvetica" # Primary font family
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["xtick.direction"] = "in"
plt.rcParams["hatch.linewidth"] = 1.0
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True
# Attempt to make LaTeX use Helvetica as the main font
plt.rcParams['text.latex.preamble'] = r"""
\usepackage{helvet} % helvetica font
\usepackage{sansmath} % helvetica for math
\sansmath % activate sansmath
\renewcommand{\familydefault}{\sfdefault} % make sans-serif the default family
"""
# Final Data for the breakdown (3 Segments)
labels_raw = [ # Raw labels before potential LaTeX escaping
'IO: Text + PQ Lookup',
'CPU: Tokenize + Distance Compute',
'GPU: Embedding Recompute',
]
# Times in ms, ordered to match labels_raw for stacking
times_ms = np.array([
8.009, # IO: text + PQ lookup
16.197, # CPU: tokenize + distance compute
76.512, # GPU: embedding recompute
])
total_time_ms = times_ms.sum()
percentages = (times_ms / total_time_ms) * 100
# Prepare labels for legend, escaping for LaTeX if active
labels_legend = []
# st1 = r'&' # Not needed as current labels_raw don't have '&'
for label, time, perc in zip(labels_raw, times_ms, percentages):
# Construct the percentage string carefully for LaTeX
perc_str = f"{perc:.1f}" + r"\%" # Correct way to form 'NN.N\%'
# label_tex = label.replace('&', st1) # Use if '&' is in labels_raw
label_tex = label # Current labels_raw are clean for LaTeX
labels_legend.append(
f"{label_tex}\n({time:.1f}ms, {perc_str})"
)
# Styling based on user's script
# Using first 3 from the provided lists
edgecolors_list = ["dimgrey", "#63B8B6", "tomato", "silver", "slategray"]
hatches_list = ["/////", "xxxxx", "\\\\\\\\\\"]
edgecolors = edgecolors_list[:3]
hatches = hatches_list[:3]
fill_color = "white"
# Create the figure and axes
# Adjusted figure size to potentially accommodate legend on the right
fig, ax = plt.subplots()
fig.set_size_inches(7, 1.5) # Width increased slightly, height adjusted
# Adjusted right margin for external legend, bottom for x-label
plt.subplots_adjust(left=0.12, right=0.72, top=0.95, bottom=0.25)
# Create the horizontal stacked bar
bar_height = 0.2
y_pos = 0
left_offset = 0
for i in range(len(times_ms)):
ax.barh(
y_pos,
times_ms[i],
height=bar_height,
left=left_offset,
color=fill_color,
edgecolor=edgecolors[i],
hatch=hatches[i],
linewidth=1.5,
label=labels_legend[i],
zorder=10
)
text_x_pos = left_offset + times_ms[i] / 2
if times_ms[i] > total_time_ms * 0.03: # Threshold for displaying text
ax.text(
text_x_pos,
y_pos,
f"{times_ms[i]:.1f}ms",
ha='center',
va='center',
fontsize=8,
fontweight='bold',
color='black',
zorder=20,
bbox=dict(facecolor='white', edgecolor='none', pad=0.5, alpha=0.8)
)
left_offset += times_ms[i]
# Set plot limits and labels
ax.set_xlim([0, total_time_ms * 1.02])
ax.set_xlabel("Time (ms)", fontsize=14, fontweight='bold', x=0.75, )
# Y-axis: Remove y-ticks and labels
ax.set_yticks([])
ax.set_yticklabels([])
# Legend: Placed to the right of the plot
ax.legend(
# (x, y) for anchor, (0,0) is bottom left, (1,1) is top right of AXES
# To place outside on the right, x should be > 1
bbox_to_anchor=(1.03, 0.5), # x > 1 means outside to the right, y=0.5 for vertical center
ncol=1, # Single column for a taller, narrower legend
loc="center left", # Anchor the legend's left-center to bbox_to_anchor point
labelspacing=0.5, # Adjust spacing
edgecolor="black",
facecolor="white",
framealpha=1,
shadow=False,
fancybox=False,
handlelength=1.5,
handletextpad=0.6,
columnspacing=1.5,
prop={"weight": "bold", "size": 9},
).set_zorder(100)
# Save the figure (using the original generic name as requested)
output_filename = "./bottleneck_breakdown.pdf"
# plt.tight_layout() # tight_layout might conflict with external legend; adjust subplots_adjust instead
plt.savefig(output_filename, bbox_inches="tight", dpi=300)
print(f"Saved plot to {output_filename}")
# plt.show() # Uncomment to display plot interactively

View File

@@ -1,226 +0,0 @@
import matplotlib.pyplot as plt
import numpy as np
# import matplotlib.ticker as mticker # Not actively used
import os
FIGURE_PATH = "paper_plot/figures"
try:
os.makedirs(FIGURE_PATH, exist_ok=True)
print(f"Images will be saved to: {os.path.abspath(FIGURE_PATH)}")
except OSError as e:
print(f"Create {FIGURE_PATH} failed: {e}. Images will be saved in the current working directory.")
FIGURE_PATH = "."
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["hatch.linewidth"] = 2
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True
method_labels = ["gte-small (33M)", "contriever-msmarco (110M)"]
dataset_names = ["NQ", "TriviaQA"]
metrics_plot1 = ["Exact Match", "F1"]
small_nq_f1 = 0.2621040899
small_tq_f1 = 0.4698198059
small_nq_em_score = 0.1845
small_tq_em_score = 0.4015
small_nq_time = 1.137
small_tq_time = 1.173
large_nq_f1 = 0.2841386117
large_tq_f1 = 0.4548340289
large_nq_em_score = 0.206
large_tq_em_score = 0.382
large_nq_time = 2.632
large_tq_time = 2.684
data_scores_plot1 = {
"NQ": {"Exact Match": [small_nq_em_score, large_nq_em_score], "F1": [small_nq_f1, large_nq_f1]},
"TriviaQA": {"Exact Match": [small_tq_em_score, large_tq_em_score], "F1": [small_tq_f1, large_tq_f1]}
}
latency_data_plot2 = {
"NQ": [small_nq_time, large_nq_time],
"TriviaQA": [small_tq_time, large_tq_time]
}
edgecolors = ["dimgrey", "tomato"]
hatches = ["/////", "\\\\\\\\\\"]
# Changed: bar_center_separation_in_group increased for larger gap
bar_center_separation_in_group = 0.42
# Changed: bar_visual_width decreased for narrower bars
bar_visual_width = 0.28
figsize_plot1 = (4, 2.5)
# Changed: figsize_plot2 width adjusted to match figsize_plot1 for legend/caption alignment
figsize_plot2 = (2.5, 2.5)
# Define plot1_xlim_per_subplot globally so it can be accessed by create_plot2_latency
plot1_xlim_per_subplot = (0.0, 2.0) # Explicit xlim for plot 1 subplots
common_subplots_adjust_params = dict(wspace=0.30, top=0.80, bottom=0.22, left=0.09, right=0.96)
def create_plot1_em_f1():
fig, axs = plt.subplots(1, 2, figsize=figsize_plot1)
fig.subplots_adjust(**common_subplots_adjust_params)
num_methods = len(method_labels)
metric_group_centers = np.array([0.5, 1.5])
# plot1_xlim_per_subplot is now global
for i, dataset_name in enumerate(dataset_names):
ax = axs[i]
for metric_idx, metric_name in enumerate(metrics_plot1):
metric_center_pos = metric_group_centers[metric_idx]
current_scores_raw = data_scores_plot1[dataset_name][metric_name]
current_scores_percent = [val * 100 for val in current_scores_raw]
for j, method_label in enumerate(method_labels):
offset = (j - (num_methods - 1) / 2.0) * bar_center_separation_in_group
bar_center_pos = metric_center_pos + offset
ax.bar(
bar_center_pos, current_scores_percent[j], width=bar_visual_width, color="white",
edgecolor=edgecolors[j], hatch=hatches[j], linewidth=1.5,
label=method_label if i == 0 and metric_idx == 0 else None
)
ax.text(
bar_center_pos, current_scores_percent[j] + 0.8, f"{current_scores_percent[j]:.1f}",
ha='center', va='bottom', fontsize=8, fontweight='bold'
)
ax.set_xticks(metric_group_centers)
ax.set_xticklabels(metrics_plot1, fontsize=9, fontweight='bold')
ax.set_title(dataset_name, fontsize=12, fontweight='bold')
ax.set_xlim(plot1_xlim_per_subplot) # Apply consistent xlim
ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: ' {:.0f}'.format(y)))
if i == 0:
ax.set_ylabel("Accuracy (\%)", fontsize=12, fontweight="bold")
all_subplot_scores_percent = []
for metric_name_iter in metrics_plot1:
all_subplot_scores_percent.extend([val * 100 for val in data_scores_plot1[dataset_name][metric_name_iter]])
max_val = max(all_subplot_scores_percent) if all_subplot_scores_percent else 0
ax.set_ylim(0, max_val * 1.22 if max_val > 0 else 10)
ax.tick_params(axis='y', labelsize=12)
for spine in ax.spines.values():
spine.set_visible(True)
spine.set_linewidth(1.0)
spine.set_edgecolor("black")
handles, labels = axs[0].get_legend_handles_labels()
fig.legend(
handles, labels, loc="upper center", bbox_to_anchor=(0.5, 0.97), ncol=len(method_labels),
edgecolor="black", facecolor="white", framealpha=1, shadow=False, fancybox=False,
handlelength=1.5, handletextpad=0.4, columnspacing=0.8,
prop={"weight": "bold", "size": 9}
)
# fig.text(0.5, 0.06, "(a) EM \& F1", ha='center', va='center', fontweight='bold', fontsize=11)
save_path = os.path.join(FIGURE_PATH, "plot1_em_f1.pdf")
# plt.tight_layout() # Adjusted call below
fig.tight_layout(rect=(0.0, 0.0, 1.0, 0.88)) # Adjusted to make space for fig.text and fig.legend
plt.savefig(save_path, bbox_inches='tight', pad_inches=0.03)
plt.close(fig)
print(f"Figure 1 (Exact Match & F1) has been saved to: {save_path}")
def create_plot2_latency():
fig, axs = plt.subplots(1, 2, figsize=figsize_plot2) # figsize_plot2 width is now 8.0
fig.subplots_adjust(**common_subplots_adjust_params)
num_methods = len(method_labels)
method_group_center_in_subplot = 0.5
# Calculate bar extents to determine focused xlim
bar_positions_calc = []
for j_idx in range(num_methods):
offset_calc = (j_idx - (num_methods - 1) / 2.0) * bar_center_separation_in_group
bar_center_pos_calc = method_group_center_in_subplot + offset_calc
bar_positions_calc.append(bar_center_pos_calc)
min_bar_actual_edge = min(bar_positions_calc) - bar_visual_width / 2.0
max_bar_actual_edge = max(bar_positions_calc) + bar_visual_width / 2.0
# Define padding around the bars
# Option 1: Fixed padding (e.g., 0.15 as derived from plot 1 visual)
# padding_val = 0.15
# plot2_xlim_calculated = (min_bar_actual_edge - padding_val, max_bar_actual_edge + padding_val)
# This would be (0.15 - 0.15, 0.85 + 0.15) = (0.0, 1.0)
# Option 2: Center the group (0.5) in a span of 1.0
plot2_xlim_calculated = (method_group_center_in_subplot - 0.5, method_group_center_in_subplot + 0.5)
# This is (0.5 - 0.5, 0.5 + 0.5) = (0.0, 1.0)
# This is simpler and achieves the (0.0, 1.0) directly.
for i, dataset_name in enumerate(dataset_names):
ax = axs[i]
current_latencies = latency_data_plot2[dataset_name]
for j, method_label in enumerate(method_labels):
offset = (j - (num_methods - 1) / 2.0) * bar_center_separation_in_group
bar_center_pos = method_group_center_in_subplot + offset
ax.bar(
bar_center_pos, current_latencies[j], width=bar_visual_width, color="white",
edgecolor=edgecolors[j], hatch=hatches[j], linewidth=1.5,
label=method_label if i == 0 else None
)
ax.text(
bar_center_pos, current_latencies[j] + 0.05, f"{current_latencies[j]:.2f}",
ha='center', va='bottom', fontsize=10, fontweight='bold'
)
ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: ' {:.0f}'.format(y)))
ax.set_xticks([0.5])
ax.set_xticklabels(["Latency"], color="white", fontsize=12)
# hide the x tick marks and labels by coloring them white
ax.tick_params(axis='x', colors="white")
ax.set_title(dataset_name, fontsize=13, fontweight='bold')
ax.set_xlim(plot2_xlim_calculated)
if i == 0:
ax.set_ylabel("Latency (s)", fontsize=12, fontweight="bold")
max_latency_in_subplot = max(current_latencies) if current_latencies else 0
ax.set_ylim(0, max_latency_in_subplot * 1.22 if max_latency_in_subplot > 0 else 1)
ax.tick_params(axis='y', labelsize=12)
for spine in ax.spines.values():
spine.set_visible(True)
spine.set_linewidth(1.0)
spine.set_edgecolor("black")
handles, labels = axs[0].get_legend_handles_labels()
fig.legend(
handles, labels, loc="upper center", bbox_to_anchor=(0.5, 0.97), ncol=num_methods,
edgecolor="black", facecolor="white", framealpha=1, shadow=False, fancybox=False,
handlelength=1.5, handletextpad=0.4, columnspacing=0.8,
prop={"weight": "bold", "size": 9}
)
# fig.text(0.5, 0.06, "(b) Latency", ha='center', va='center', fontweight='bold', fontsize=11)
save_path = os.path.join(FIGURE_PATH, "plot2_latency.pdf")
# plt.tight_layout() # Adjusted call below
fig.tight_layout(rect=(0.0, 0.0, 1.0, 0.88)) # Adjusted to make space for fig.text and fig.legend
plt.savefig(save_path, bbox_inches='tight', pad_inches=0.03)
plt.close(fig)
print(f"Figure 2 (Latency) has been saved to: {save_path}")
if __name__ == "__main__":
print("Start generating figures...")
if plt.rcParams["text.usetex"]:
print("Info: LaTeX rendering is enabled. Ensure LaTeX is installed and configured if issues arise, or set plt.rcParams['text.usetex'] to False.")
create_plot1_em_f1()
create_plot2_latency()
print("All figures have been generated.")

View File

@@ -1,111 +0,0 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Motto: Were It to Benefit My Country, I Would Lay Down My Life!
# \file: /speed_ablation.py
# \brief:
# Author: raphael hao
# %%
import numpy as np
import pandas as pd
# %%
# from script.settings import DATA_PATH, FIGURE_PATH
# Load the latency ablation data
latency_data = pd.read_csv("./paper_plot/data/latency_ablation.csv")
# Filter for SpeedUp metric only
speedup_data = latency_data[latency_data['Metric'] == 'SpeedUp']
# %%
from matplotlib import pyplot as plt
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["hatch.linewidth"] = 1.5
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["text.usetex"] = True
# %%
fig, ax = plt.subplots()
fig.set_size_inches(5, 1.5)
plt.subplots_adjust(wspace=0, hspace=0)
total_width, n = 3, 3
group = len(speedup_data['Dataset'].unique())
width = total_width * 0.9 / n
x = np.arange(group) * n
exit_idx_x = x + (total_width - width) / n
edgecolors = ["dimgrey", "#63B8B6", "tomato", "silver", "slategray"]
hatches = ["/////", "xxxxx", "\\\\\\\\\\"]
labels = ["Base", "Base + Two-level", "Base + Two-level + Batch"]
datasets = speedup_data['Dataset'].unique()
for i, dataset in enumerate(datasets):
dataset_data = speedup_data[speedup_data['Dataset'] == dataset]
for j in range(n):
if j == 0:
value = dataset_data['Original'].values[0]
elif j == 1:
value = dataset_data['original + two_level'].values[0]
else:
value = dataset_data['original + two_level + batch'].values[0]
ax.text(
exit_idx_x[i] + j * width,
value + 0.05,
f"{value:.2f}",
ha='center',
va='bottom',
fontsize=10,
fontweight='bold',
rotation=0,
zorder=20,
)
ax.bar(
exit_idx_x[i] + j * width,
value,
width=width * 0.8,
color="white",
edgecolor=edgecolors[j],
hatch=hatches[j],
linewidth=1.5,
label=labels[j] if i == 0 else None,
zorder=10,
)
ax.set_ylim([0.5, 2.3])
ax.set_yticks(np.arange(0.5, 2.2, 0.5))
ax.set_yticklabels(np.arange(0.5, 2.2, 0.5), fontsize=12)
ax.set_xticks(exit_idx_x + width)
ax.set_xticklabels(datasets, fontsize=10)
# ax.set_xlabel("Different Datasets", fontsize=14)
ax.legend(
bbox_to_anchor=(-0.03, 1.4),
ncol=3,
loc="upper left",
labelspacing=0.1,
edgecolor="black",
facecolor="white",
framealpha=1,
shadow=False,
fancybox=False,
handlelength=0.8,
handletextpad=0.6,
columnspacing=0.8,
prop={"weight": "bold", "size": 10},
).set_zorder(100)
ax.set_ylabel("Speedup", fontsize=11)
plt.savefig("./paper_plot/figures/latency_speedup.pdf", bbox_inches="tight", dpi=300)
# %%
print(f"Save to ./paper_plot/figures/latency_speedup.pdf")

View File

@@ -1 +0,0 @@
analyze_diskann_graph

View File

@@ -1,227 +0,0 @@
#include <cassert>
#include <cstdint>
#include <cstring>
#include <fstream>
#include <iostream>
#include <limits>
#include <string>
#include <vector>
static const size_t DISKANN_SECTOR_LEN = 4096; // Typical sector size
// ! Use float as CoordT
using CoordT = float;
int main(int argc, char **argv) {
if (argc < 3) {
std::cerr << "Usage: " << argv[0]
<< " <diskann_index_file> <output_degree_file>" << std::endl;
return -1;
}
std::string disk_index_path = argv[1];
std::string output_degree_path = argv[2];
std::ifstream in(disk_index_path, std::ios::binary);
if (!in.is_open()) {
std::cerr << "Failed to open file: " << disk_index_path << std::endl;
return -1;
}
// =========== 1) Read meta information ==============
// (corresponds to save_bin<uint64_t>(..., ..., ..., 1, 0))
// Read bin header: (npts_i32, dim_i32)
int32_t meta_count_i32 = 0, meta_dim_i32 = 0;
in.read(reinterpret_cast<char *>(&meta_count_i32), sizeof(int32_t));
in.read(reinterpret_cast<char *>(&meta_dim_i32), sizeof(int32_t));
size_t meta_count = static_cast<size_t>(meta_count_i32);
size_t meta_dim = static_cast<size_t>(meta_dim_i32);
// Per the diskann::save_bin on-disk format, meta_dim here is usually 1
std::cout << "[LOG] meta_count=" << meta_count << ", meta_dim=" << meta_dim
<< std::endl;
if (meta_dim != 1) {
std::cerr << "[ERROR] meta_dim != 1,不符合 create_disk_layout 的写盘约定。"
<< std::endl;
return -1;
}
// Read meta array
std::vector<uint64_t> meta(meta_count);
in.read(reinterpret_cast<char *>(meta.data()), meta_count * sizeof(uint64_t));
if (!in.good()) {
std::cerr << "[ERROR] Failed to read meta array, file is incomplete."
<< std::endl;
return -1;
}
// meta[0..] Metadata
// 0: npts_64, 1: ndims_64, 2: medoid, 3: max_node_len, 4: nnodes_per_sector,
// 5: vamana_frozen_num, 6: vamana_frozen_loc, 7: append_reorder_data, ...
const uint64_t npts_64 = meta[0];
const uint64_t ndims_64 = meta[1];
const uint64_t medoid = meta[2];
const uint64_t max_node_len = meta[3];
const uint64_t nnodes_per_sector = meta[4];
const uint64_t vamana_frozen_num = meta[5];
const uint64_t vamana_frozen_loc = meta[6];
const uint64_t append_reorder_data = meta[7];
std::cout << "[LOG] npts_64=" << npts_64 << " ndims_64=" << ndims_64
<< " max_node_len=" << max_node_len
<< " nnodes_per_sector=" << nnodes_per_sector << std::endl;
// If append_reorder_data==1, it means that reorder_data is appended at the
// end of the index, but it does not affect the degree statistics, we can
// ignore that part of the vector.
// =========== 2) Skip the first sector (all empty/placeholder information)
// ==============
in.seekg(DISKANN_SECTOR_LEN, std::ios::beg);
if (!in.good()) {
std::cerr << "[ERROR] Failed to seek to the first sector." << std::endl;
return -1;
}
// =========== 3) Calculate the total number of sectors ==============
// In create_disk_layout:
// If nnodes_per_sector > 0, then n_sectors = ceil(npts_64 / nnodes_per_sector).
// Otherwise nsectors_per_node = ceil(max_node_len / 4096) and
// n_sectors = nsectors_per_node * npts_64.
uint64_t n_sectors = 0;
if (nnodes_per_sector > 0) {
// Equivalent to Roundup(npts_64, nnodes_per_sector) / nnodes_per_sector
n_sectors = (npts_64 + nnodes_per_sector - 1) / nnodes_per_sector;
} else {
// multi-sector per node
uint64_t nsectors_per_node =
(max_node_len + DISKANN_SECTOR_LEN - 1) / DISKANN_SECTOR_LEN;
n_sectors = nsectors_per_node * npts_64;
}
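// Worked example (hypothetical numbers, not from a real index): with
// npts_64 = 1,000,000 and nnodes_per_sector = 5,
// n_sectors = (1,000,000 + 5 - 1) / 5 = 200,000 adjacency sectors.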
std::cout << "[LOG] estimated #sectors storing adjacency = " << n_sectors
<< std::endl;
// =========== 4) Read the degree of all nodes in order ==============
// Per-node layout: the adjacency count sits at offset ndims_64 * sizeof(CoordT),
// stored as a 4-byte uint32_t. The full neighbor list follows and occupies
// adjacency_count * sizeof(uint32_t) bytes, but we only read the count here.
std::vector<uint32_t> degrees(npts_64, 0); // Store the degree of each node
size_t node_id = 0; // Current node number
// Buffer for reading one sector at a time
std::vector<char> sector_buf(DISKANN_SECTOR_LEN, 0);
// If nnodes_per_sector>0, it means that one sector holds multiple nodes
// Otherwise, one node occupies nsectors_per_node sectors
if (nnodes_per_sector > 0) {
// Read one sector at a time
for (uint64_t s = 0; s < n_sectors; s++) {
in.read((char *)sector_buf.data(), DISKANN_SECTOR_LEN);
if (!in.good()) {
if (node_id < npts_64) {
std::cerr << "[ERROR] Failed to read sector " << s
<< ", nodes not finished, file error or incomplete."
<< std::endl;
return -1;
}
break; // If all nodes are read, you can exit
}
// Parse each node in sector_buf
for (uint64_t i = 0; i < nnodes_per_sector; i++) {
if (node_id >= npts_64)
break; // All node degrees have been obtained
// The starting offset of the node in sector_buf
size_t node_offset = i * max_node_len;
// offset first skips ndims_64 * sizeof(CoordT)
size_t degree_offset = node_offset + ndims_64 * sizeof(CoordT);
// Ensure not out of bounds
if (degree_offset + sizeof(uint32_t) > sector_buf.size()) {
std::cerr << "[ERROR] 不应该发生: 读取degree越过了扇区边界."
<< std::endl;
return -1;
}
uint32_t deg = 0;
memcpy(&deg, sector_buf.data() + degree_offset, sizeof(uint32_t));
degrees[node_id] = deg;
node_id++;
}
}
} else {
// Each node occupies nsectors_per_node sectors
uint64_t nsectors_per_node =
(max_node_len + DISKANN_SECTOR_LEN - 1) / DISKANN_SECTOR_LEN;
// Read each node
for (uint64_t n = 0; n < npts_64; n++) {
// Read multiple sectors into a multi-sector buffer
std::vector<char> node_buf(nsectors_per_node * DISKANN_SECTOR_LEN, 0);
in.read((char *)node_buf.data(), node_buf.size());
if (!in.good()) {
std::cerr << "[ERROR] Failed to read sector corresponding to node " << n
<< ", file error or incomplete." << std::endl;
return -1;
}
// Parse the degree in node_buf
size_t degree_offset = ndims_64 * sizeof(CoordT);
if (degree_offset + sizeof(uint32_t) > node_buf.size()) {
std::cerr << "[ERROR] Should not happen: reading degree beyond node "
"region."
<< std::endl;
return -1;
}
uint32_t deg = 0;
memcpy(&deg, node_buf.data() + degree_offset, sizeof(uint32_t));
degrees[n] = deg;
}
}
// Sanity check (multiple-nodes-per-sector mode): node_id should equal npts_64
if (nnodes_per_sector > 0) {
if (node_id != npts_64) {
std::cerr << "[ERROR] Actually read " << node_id
<< " nodes, but meta npts_64=" << npts_64
<< ", file may be incorrect or parsing method is wrong."
<< std::endl;
return -1;
}
}
// =========== 5) Calculate min / max / average degree ==============
uint64_t sum_deg = 0;
uint32_t min_deg = std::numeric_limits<uint32_t>::max();
uint32_t max_deg = 0;
for (uint64_t n = 0; n < npts_64; n++) {
uint32_t d = degrees[n];
sum_deg += d;
if (d < min_deg)
min_deg = d;
if (d > max_deg)
max_deg = d;
}
double avg_deg = (npts_64 == 0) ? 0.0 : double(sum_deg) / double(npts_64);
// =========== 6) Output results ==============
std::cout << "DiskANN index file: " << disk_index_path << std::endl;
std::cout << "Total points: " << npts_64 << std::endl;
std::cout << "Min degree : " << min_deg << std::endl;
std::cout << "Max degree : " << max_deg << std::endl;
std::cout << "Avg degree : " << avg_deg << std::endl;
// =========== 7) Write degrees to output file ==============
std::ofstream out_deg(output_degree_path);
if (!out_deg.is_open()) {
std::cerr << "[ERROR] Failed to open output file: " << output_degree_path
<< std::endl;
// Don't exit here; just warn, then still close the input file below.
} else {
std::cout << "[LOG] Writing degrees to " << output_degree_path << "..."
<< std::endl;
for (uint64_t n = 0; n < npts_64; n++) {
out_deg << degrees[n] << std::endl;
}
out_deg.close();
std::cout << "[LOG] Finished writing degrees." << std::endl;
}
in.close();
return 0;
}

View File

@@ -1,187 +0,0 @@
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import os
import re
# Set the plotting style
plt.style.use('ggplot')
sns.set(font_scale=1.2)
# Load the data - custom parsing logic for this log format
log_file = './top3_positions_log.txt'
# Parse the file manually
data = []
header = None
with open(log_file, 'r') as f:
lines = f.readlines()
header = lines[0].strip().split(',')
# Check whether a ThreadID column exists
has_thread_id = 'ThreadID' in header
for line in lines[1:]:
# Skip non-data lines such as "Search X results:"
if 'results:' in line or not ',' in line:
continue
# Split and parse the data row
parts = line.strip().split(',')
# Check the row matches the expected format
if len(parts) >= 7: # at least 7 fields required
# Old format (no ThreadID column)
if not has_thread_id and len(parts) == 7:
data.append([parts[0], 0, parts[1], parts[2], parts[3], parts[4], parts[5], parts[6]])
# New format (with ThreadID column)
elif has_thread_id and len(parts) == 8:
data.append(parts)
# Handle inconsistent formats
elif not has_thread_id and len(parts) == 8:
# Assume the second column is ThreadID
data.append(parts)
if not has_thread_id:
has_thread_id = True
header.insert(1, 'ThreadID')
# Make sure the header is complete
if not has_thread_id:
header.insert(1, 'ThreadID')
# Build the DataFrame with the correct column names
if len(header) == 8: # expect exactly 8 columns
df = pd.DataFrame(data, columns=header)
else:
# Fall back to default column names if the header is malformed
default_header = ['Search#', 'ThreadID', 'FullSetSize', 'Rank', 'ID', 'PQ_Rank', 'PQ_Distance', 'Exact_Distance']
df = pd.DataFrame(data, columns=default_header)
# Convert numeric columns
df['Search#'] = pd.to_numeric(df['Search#'], errors='coerce').fillna(0).astype(int)
df['ThreadID'] = pd.to_numeric(df['ThreadID'], errors='coerce').fillna(0).astype(int)
df['FullSetSize'] = pd.to_numeric(df['FullSetSize'], errors='coerce').fillna(0).astype(int)
df['Rank'] = pd.to_numeric(df['Rank'], errors='coerce').fillna(0).astype(int)
df['ID'] = pd.to_numeric(df['ID'], errors='coerce').fillna(0).astype(int)
df['PQ_Rank'] = pd.to_numeric(df['PQ_Rank'], errors='coerce').fillna(0).astype(int)
df['PQ_Distance'] = pd.to_numeric(df['PQ_Distance'], errors='coerce').fillna(0).astype(float)
df['Exact_Distance'] = pd.to_numeric(df['Exact_Distance'], errors='coerce').fillna(0).astype(float)
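# Assumed log format (a sketch with made-up values): a CSV header followed by
# data rows, e.g.
#   Search#,ThreadID,FullSetSize,Rank,ID,PQ_Rank,PQ_Distance,Exact_Distance
#   12,3,1000,1,48213,7,0.8312,0.7755
# Interleaved lines like "Search 12 results:" are skipped during parsing above.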
print(f"读取了 {len(df)} 行数据")
print(f"搜索次数: {df['Search#'].nunique()}")
print(f"线程数: {df['ThreadID'].nunique()}")
# 提取前3名的结果
top3_df = df[df['Rank'] <= 3].copy()
# 分析PQ Rank的分布
pq_positions = []
for rank in [1, 2, 3]:
rank_df = top3_df[top3_df['Rank'] == rank]
pq_positions.append(rank_df['PQ_Rank'].values)
# Create the results directory
result_dir = './analysis_results'
os.makedirs(result_dir, exist_ok=True)
# 1. Box plot: where the top-3 exact results land in the PQ ordering
plt.figure(figsize=(10, 6))
box_data = [top3_df[top3_df['Rank'] == i]['PQ_Rank'].values for i in [1, 2, 3]]
sns.boxplot(data=box_data)
plt.xticks([0, 1, 2], ['Top 1', 'Top 2', 'Top 3'])
plt.ylabel('PQ Rank Position')
plt.title('Distribution of PQ Ranks for Top-3 Exact Results')
plt.savefig(os.path.join(result_dir, 'pq_rank_boxplot.png'), dpi=300)
# 2. Histograms: PQ-rank distribution for each exact rank
fig, axs = plt.subplots(1, 3, figsize=(18, 6))
for i, rank in enumerate([1, 2, 3]):
rank_df = top3_df[top3_df['Rank'] == rank]
sns.histplot(x=rank_df['PQ_Rank'].values, bins=20, ax=axs[i])
axs[i].set_title(f'Exact Rank {rank}')
axs[i].set_xlabel('PQ Rank')
axs[i].set_ylabel('Frequency')
plt.tight_layout()
plt.savefig(os.path.join(result_dir, 'pq_rank_histogram.png'), dpi=300)
# 3. Heatmap: relationship between PQ rank and exact rank
plt.figure(figsize=(10, 8))
# Focus on the top-20 PQ ranks only
bins = list(range(0, 22))
pq_rank_bins = pd.cut(top3_df['PQ_Rank'], bins=bins)
heatmap_data = pd.crosstab(pq_rank_bins, top3_df['Rank'])
sns.heatmap(heatmap_data, cmap='YlGnBu', annot=True, fmt='d')
plt.title('Heatmap of Exact Rank vs PQ Rank (Top 20)')
plt.xlabel('Exact Rank')
plt.ylabel('PQ Rank Range')
plt.savefig(os.path.join(result_dir, 'rank_heatmap.png'), dpi=300)
# 4. Scatter plot: PQ distance vs exact distance
plt.figure(figsize=(10, 8))
sns.scatterplot(x=top3_df['Exact_Distance'], y=top3_df['PQ_Distance'], hue=top3_df['Rank'], palette='viridis')
plt.title('PQ Distance vs Exact Distance')
plt.xlabel('Exact Distance')
plt.ylabel('PQ Distance')
plt.legend(title='Exact Rank')
# Add a diagonal line marking perfect agreement
min_val = min(top3_df['Exact_Distance'].min(), top3_df['PQ_Distance'].min())
max_val = max(top3_df['Exact_Distance'].max(), top3_df['PQ_Distance'].max())
plt.plot([min_val, max_val], [min_val, max_val], 'r--', alpha=0.5)
plt.savefig(os.path.join(result_dir, 'distance_scatter.png'), dpi=300)
# 5. Line plot: PQ rank as a function of result-set size
plt.figure(figsize=(12, 6))
size_grouped = top3_df.groupby(['FullSetSize', 'Rank'])['PQ_Rank'].mean().reset_index()
for rank in [1, 2, 3]:
rank_data = size_grouped[size_grouped['Rank'] == rank]
plt.plot(rank_data['FullSetSize'], rank_data['PQ_Rank'], marker='o', label=f'Rank {rank}')
plt.xlabel('Result Set Size')
plt.ylabel('Average PQ Rank')
plt.title('Average PQ Rank by Result Set Size')
plt.legend()
plt.grid(True)
plt.savefig(os.path.join(result_dir, 'pq_rank_by_size.png'), dpi=300)
# 6. Percentage heatmap: probability that an exact top result appears in the PQ top-K
top_k_values = [1, 5, 10, 20, 50, 100, 200, 300, 500, 700, 800, 900]
top_k_probs = []
for rank in [1, 2, 3]:
rank_df = top3_df[top3_df['Rank'] == rank]
probs = []
for k in top_k_values:
prob = (rank_df['PQ_Rank'] <= k).mean() * 100
probs.append(prob)
top_k_probs.append(probs)
plt.figure(figsize=(10, 6))
sns.heatmap(top_k_probs, annot=True, fmt='.1f', cmap='YlGnBu',
xticklabels=[f'Top-{k}' for k in top_k_values],
yticklabels=['Rank 1', 'Rank 2', 'Rank 3'])
plt.title('Probability (%) of Finding Exact Top-K Results in PQ Top-K')
plt.xlabel('PQ Top-K')
plt.ylabel('Exact Rank')
plt.savefig(os.path.join(result_dir, 'topk_probability.png'), dpi=300)
# 7. Generate a summary statistics report
with open(os.path.join(result_dir, 'summary_report.txt'), 'w') as f:
f.write("Data Analysis Summary\n")
f.write("=================\n")
f.write(f"Total searches: {df['Search#'].nunique()}\n")
f.write(f"Threads used: {df['ThreadID'].nunique()}\n\n")
f.write("Average PQ position of the exact top-3 results:\n")
for rank in [1, 2, 3]:
avg_pq_rank = top3_df[top3_df['Rank'] == rank]['PQ_Rank'].mean()
median_pq_rank = top3_df[top3_df['Rank'] == rank]['PQ_Rank'].median()
f.write(f" Rank {rank}: mean position = {avg_pq_rank:.2f}, median position = {median_pq_rank:.1f}\n")
f.write("\nHit rate of each exact rank within the PQ top-K:\n")
for rank in [1, 2, 3]:
f.write(f" Rank {rank}:\n")
for k in top_k_values:
hit_rate = (top3_df[top3_df['Rank'] == rank]['PQ_Rank'] <= k).mean() * 100
f.write(f" Hit rate within PQ top-{k}: {hit_rate:.2f}%\n")
print(f"Analysis complete! Results saved to {result_dir}")

View File

@@ -1,137 +0,0 @@
import os
import torch
import numpy as np
import argparse
from tqdm import tqdm
import json
from contriever.src.contriever import load_retriever
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["KMP_BLOCKTIME"] = "0"
torch.set_num_threads(1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def embed_queries(queries, model, tokenizer, model_name_or_path, per_gpu_batch_size=64):
"""Embed queries using the model with batching"""
model = model.half()
model.eval()
embeddings = []
batch_question = []
with torch.no_grad():
for k, query in tqdm(enumerate(queries), desc="Encoding queries"):
batch_question.append(query)
# Process when batch is full or at the end
if len(batch_question) == per_gpu_batch_size or k == len(queries) - 1:
encoded_batch = tokenizer.batch_encode_plus(
batch_question,
return_tensors="pt",
max_length=512,
padding=True,
truncation=True,
)
encoded_batch = {k: v.to(device) for k, v in encoded_batch.items()}
output = model(**encoded_batch)
# Contriever-style retrievers return pooled embeddings directly; other models may need explicit pooling:
# if "contriever" not in model_name_or_path:
# output = output.last_hidden_state[:, 0, :]
embeddings.append(output.cpu())
batch_question = [] # Reset batch
embeddings = torch.cat(embeddings, dim=0).numpy()
print(f"Query embeddings shape: {embeddings.shape}")
return embeddings
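# Minimal usage sketch (names as defined in this file; model string is illustrative):
#   embs = embed_queries(["what is dense retrieval?"], query_encoder, query_tokenizer,
#                        "facebook/contriever-msmarco", per_gpu_batch_size=8)
#   # embs is a (num_queries, hidden_dim) numpy array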
def main():
parser = argparse.ArgumentParser(description="Debug embedding tool")
parser.add_argument("--model", type=str, default="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
help="Model name for embedding (default: facebook/contriever-msmarco)")
parser.add_argument("--batch-size", type=int, default=32,
help="Batch size for encoding (default: 32)")
parser.add_argument("--input-file", type=str,
help="Input file with queries (JSON lines format with 'query' field)")
parser.add_argument("--output-file", type=str, default="embeddings.npy",
help="Output file to save embeddings (default: embeddings.npy)")
parser.add_argument("--text", type=str, nargs="+",
help="Direct text input to embed (can provide multiple)")
parser.add_argument("--save-text", action="store_true",
help="Save the input text alongside embeddings")
args = parser.parse_args()
# Load model
print(f"Loading query encoder: {args.model}")
query_encoder, query_tokenizer, _ = load_retriever(args.model)
query_encoder = query_encoder.to(device)
query_encoder.eval()
# Get queries
queries = []
# From file if provided
if args.input_file:
print(f"Loading queries from: {args.input_file}")
with open(args.input_file, "r") as f:
for line in f:
data = json.loads(line)
queries.append(data["query"])
# From command line if provided
if args.text:
print(f"Using {len(args.text)} queries from command line")
queries.extend(args.text)
# If no queries, use some examples
if not queries:
print("No queries provided, using example queries")
queries = [
"Were there any variances detected for hour 6 on 3/9/01?"
]
print(f"Embedding {len(queries)} queries")
for i, q in enumerate(queries[:5]): # Print first 5 queries
print(f"Query {i+1}: {q}")
if len(queries) > 5:
print(f"... and {len(queries)-5} more")
# Encode queries
embeddings = embed_queries(
queries, query_encoder, query_tokenizer, args.model, per_gpu_batch_size=args.batch_size
)
passages = [
"Start Date: 3/9/01; HourAhead hour: 6; No ancillary schedules awarded. Variances detected. Variances detected in Generation schedule. Variances detected in Energy Import/Export schedule. LOG MESSAGES: PARSING FILE -->> O:\\Portland\\WestDesk\\California Scheduling\\ISO Final Schedules\\2001030906.txt ---- Generation Schedule ---- $$$ Variance found in table tblGEN_SCHEDULE. Details: (Hour: 6 / Preferred: 20.00 / Final: 19.80) TRANS_TYPE: FINAL SC_ID: TOSCO MKT_TYPE: 2 TRANS_DATE: 3/9/01 UNIT_ID: UNCHEM_1_UNIT $$$ Variance found in table tblGEN_SCHEDULE. Details: (Hour: 6 / Preferred: 29.00 / Final: 28.20) TRANS_TYPE: FINAL SC_ID: ARCO MKT_TYPE: 2 TRANS_DATE: 3/9/01 UNIT_ID: CARBGN_6_UNIT 1 $$$ Variance found in table tblGEN_SCHEDULE. Details: (Hour: 6 / Preferred: 45.00 / Final: 43.80) TRANS_TYPE: FINAL SC_ID: DELANO MKT_TYPE: 2 TRANS_DATE: 3/9/01 UNIT_ID: PANDOL_6_UNIT $$$ Variance found in table tblGEN_SCHEDULE. Details: (Hour: 6 / Preferred: 13.00 / Final: 12.60) TRANS_TYPE: FINAL SC_ID: Wheelabrat MKT_TYPE: 2 TRANS_DATE: 3/9/01 UNIT_ID: MARTEL_2_AMFOR ---- Energy Import/Export Schedule ---- $$$ Variance found in table tblINTCHG_IMPEXP. Details: (Hour: 6 / Preferred: 62.00 / Final: 60.40) TRANS_TYPE: FINAL SC_ID: ECTstCA MKT_TYPE: 2 TRANS_DATE: 3/9/01 TIE_POINT: PVERDE_5_DEVERS INTERCHG_ID: EPMI_CISO_5001 ENGY_TYPE: FIRM $$$ Variance found in table tblINTCHG_IMPEXP. Details: (Hour: 6 / Preferred: 63.00 / Final: 61.23) TRANS_TYPE: FINAL SC_ID: ECTstSW MKT_TYPE: 2 TRANS_DATE: 3/9/01 TIE_POINT: PVERDE_5_DEVERS INTERCHG_ID: EPMI_CISO_5000 ENGY_TYPE: FIRM $$$ Variance found in table tblINTCHG_IMPEXP. Details: (Hour: 6 / Preferred: 17.00 / Final: 11.00) TRANS_TYPE: FINAL SC_ID: ECTRT MKT_TYPE: 2 TRANS_DATE: 3/9/01 TIE_POINT: SYLMAR_2_NOB INTERCHG_ID: EPMI_CISO_LUCKY ENGY_TYPE: NFRM",
"Start Date: 3/30/01; HourAhead hour: 15; No ancillary schedules awarded. Variances detected. Variances detected in Generation schedule. LOG MESSAGES: PARSING FILE -->> O:\\Portland\\WestDesk\\California Scheduling\\ISO Final Schedules\\2001033015.txt ---- Generation Schedule ---- $$$ Variance found in table tblGEN_SCHEDULE. Details: (Hour: 15 / Preferred: 0.00 / Final: 0.00) TRANS_TYPE: FINAL SC_ID: ARCO MKT_TYPE: 2 TRANS_DATE: 3/30/01 UNIT_ID: CARBGN_6_UNIT 1 $$$ Variance found in table tblGEN_SCHEDULE. Details: (Hour: 15 / Preferred: 45.00 / Final: 44.00) TRANS_TYPE: FINAL SC_ID: DELANO MKT_TYPE: 2 TRANS_DATE: 3/30/01 UNIT_ID: PANDOL_6_UNIT"
]
# Embed passages
passage_embeddings = embed_queries(passages, query_encoder, query_tokenizer, args.model, per_gpu_batch_size=args.batch_size)
# distance with passages 0 and query
distance_0 = np.linalg.norm(embeddings[0] - passage_embeddings[0])
print(f"Distance between query 0 and passage 0: {distance_0}")
# distance with passages 1 and query
distance_1 = np.linalg.norm(embeddings[0] - passage_embeddings[1])
print(f"Distance between query 0 and passage 1: {distance_1}")
# print which one is closer
if distance_0 < distance_1:
print("Query 0 is closer to passage 0")
else:
print("Query 0 is closer to passage 1")
print("Done!")
if __name__ == "__main__":
main()

View File

@@ -1,33 +0,0 @@
import json
import os
input_file = "/gscratch/zlab/rulins/data/lm-eval-data/raw_mmlu.jsonl"
output_file = "/gscratch/zlab/rulins/data/lm-eval-data/mmlu.jsonl"
raw_data = []
with open(input_file, "r") as fin:
for line in fin:
raw_data.append(json.loads(line))
def deduplicate_dicts(dict_list):
unique_dicts = set()
unique_items = []
for item in dict_list:
# Make a hashable version of the dictionary by sorting it
hashable_item = tuple(sorted(item.items()))
if hashable_item not in unique_dicts:
unique_dicts.add(hashable_item)
unique_items.append(item)
return unique_items
unique_data = deduplicate_dicts(raw_data)
print(len(unique_data))
with open(output_file, "w") as fout:
for ex in unique_data:
fout.write(json.dumps(ex) + "\n")
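One caveat in `deduplicate_dicts` above: `tuple(sorted(item.items()))` raises a TypeError as soon as a value is itself a list or dict, which jsonl records often contain. A hedged alternative, assuming the records are JSON-serializable, keys on a canonical JSON string instead:

```python
import json

def deduplicate_json_records(records):
    """Drop exact-duplicate JSON records, keeping first occurrences."""
    seen = set()
    unique = []
    for rec in records:
        # sort_keys yields a canonical string even for nested structures
        key = json.dumps(rec, sort_keys=True)
        if key not in seen:
            seen.add(key)
            unique.append(rec)
    return unique
```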

View File

@@ -1,167 +0,0 @@
import time
import multiprocessing
from datasketch import MinHash, MinHashLSH
def shingle_document(text, shingle_size=13):
"""Generate word-based shingles from a document."""
# Split the text into words
words = text.split()
# Generate shingles that are sequences of 'shingle_size' consecutive words
shingles = set(
" ".join(words[i : i + shingle_size])
for i in range(len(words) - shingle_size + 1)
)
return shingles
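# Precompute the MinHash permutations once at import time; create_minhash()
# below reuses them via datasketch's `permutations` argument, which avoids
# regenerating the permutation arrays for every document.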
m = MinHash(num_perm=128)
perm = m.permutations
def create_minhash(shingles, num_perm=128):
"""Create a MinHash object from the set of shingles."""
m = MinHash(permutations=perm)
m.update_batch(map(lambda x: x.encode("utf-8"), shingles))
# for shingle in shingles:
# m.update(shingle.encode('utf-8'))
return m
def abstein_string_for_decon(string):
# Abstain from decontaminating the reading-comprehension subset of MMLU, where a paragraph from Wikipedia is given in the question
return "refers to the following information" in string
def remove_duplicates_with_minhash(
documents, string_for_decontamination=None, threshold=0.8, num_perm=128
):
# Apply 13-gram Jaccard-similarity deduplication and remove chunks with similarity > 80% to earlier docs.
# Remove chunks shorter than 13 words.
# Create an LSH index
lsh = MinHashLSH(threshold=threshold, num_perm=num_perm)
# Dictionary to store the MinHash of each document
minhashes = {}
# Hash string for decontamination first so contaminated samples will be removed
decon_offset = 0
if string_for_decontamination is not None and not abstein_string_for_decon(
string_for_decontamination
):
shingles = shingle_document(string_for_decontamination)
m_decon = create_minhash(shingles, num_perm)
lsh.insert(f"doc_{decon_offset}", m_decon)
minhashes[decon_offset] = m_decon
decon_offset = 1
# Populate the LSH index
short_chunk_indices = []
for idx, ctx in enumerate(documents, start=decon_offset):
doc = ctx["retrieval text"]
shingles = shingle_document(doc)
if not shingles:
short_chunk_indices.append(idx - decon_offset)
m = create_minhash(shingles, num_perm)
lsh.insert(f"doc_{idx}", m)
minhashes[idx] = m
# List to keep track of non-duplicate document indices
non_duplicate_indices = []
# Check each document against the LSH index
for idx, m in minhashes.items():
if idx < decon_offset:
continue
# Query the LSH for near-duplicate candidates
result = lsh.query(m)
# print(result)
# print([minhashes[int(doc_id.split("_")[1])].jaccard(m) for doc_id in result])
# If the document is the only one in its bucket or it appears first in the list
if all(
minhashes[int(doc_id.split("_")[1])].jaccard(m) <= threshold
or int(doc_id.split("_")[1]) >= idx
for doc_id in result
):
non_duplicate_indices.append(idx - decon_offset)
# Return non-duplicate documents
deduplicated_documents = [
documents[i] for i in non_duplicate_indices if i not in short_chunk_indices
]
[doc.update({"quality score": 1}) for doc in deduplicated_documents]
removed_documents = [doc for doc in documents if doc not in deduplicated_documents]
[doc.update({"quality score": 0}) for doc in removed_documents]
print(f"Non-deduplication ctxs num: {len(deduplicated_documents)}")
# for c in deduplicated_documents:
# try:
# print(c['retrieval text'][:10])
# except:
# print(c)
# if len(deduplicated_documents[0]['retrieval text'].split(' ')) < 13:
# import pdb; pdb.set_trace()
return deduplicated_documents # + removed_documents
def process_item(data_item):
time.sleep(0.0001)
id_, ex = data_item
ex["ctxs"] = remove_duplicates_with_minhash(
ex["ctxs"], string_for_decontamination=ex["raw_query"]
)
return id_, ex
def multiprocess_deduplication(data):
items_to_process = list(enumerate(data))
pool = multiprocessing.Pool(processes=32)
for result in pool.imap(process_item, items_to_process):
id_, updated_ex = result
data[id_] = updated_ex
return data
if __name__ == "__main__":
# Example usage:
question = (
"Answer these questions:\n\nQ: when did the eagles win last super bowl?\nA:"
)
docs = [
"Eagles won the Super Bowl.",
"Machine learning provides the ability to automatically learn and improve from experience without being explicitly programmed."
* 20,
"Machine learning provides the ability to automatically learn and improve from experience without being explicitly programmed."
* 20
+ ".",
"An entirely different document looks nothing like the others and should not be considered a duplicate."
* 20,
"Short sentence." * 1,
"As someone who lived in Philly for about five years, I agree about the city\u2019s greatness \u2014 which makes the juxtaposition between its friendly day-to-day interactions and sometimes psychotic sports fandom even more jarring. The Eagles did win three NFL championships before the Super Bowl existed, most recently in 1960. But any fan who was following the team back then is now at least into their mid-60s, if not much older. It is, to say the least, a distant memory from another era. Granted, the Sixers went on their infamous tanking expedition during this span.",
] * 1
import time
num_ex = 1
start = time.time()
data1 = []
for _ in range(num_ex):
cleaned_ex = remove_duplicates_with_minhash(
[{"retrieval text": doc} for doc in docs], question
)
data1.append(cleaned_ex)
time1 = time.time() - start
# ori_data = [{'raw_query': docs[0], 'ctxs': [{'retrieval text': doc} for doc in docs]}] * num_ex
# start = time.time()
# data2 = multiprocess_deduplication(ori_data)
# time2 = time.time()-start
# assert data2[0]['ctxs'] == data1[0]
# print(time1)
# print(time2)
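For reference, the Jaccard estimate that drives the `lsh.query` filtering above can be checked directly on two small shingle sets. A minimal sketch using the same datasketch API (the values are illustrative):

```python
from datasketch import MinHash

def jaccard_estimate(shingles_a, shingles_b, num_perm=128):
    """Estimate the Jaccard similarity of two shingle sets via MinHash."""
    ma, mb = MinHash(num_perm=num_perm), MinHash(num_perm=num_perm)
    for s in shingles_a:
        ma.update(s.encode("utf-8"))
    for s in shingles_b:
        mb.update(s.encode("utf-8"))
    return ma.jaccard(mb)

# Identical sets should estimate ~1.0, disjoint sets ~0.0
print(jaccard_estimate({"a b c", "b c d"}, {"a b c", "b c d"}))
```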

View File

@@ -1,387 +0,0 @@
/*
Run with
g++ ./demo_reader.cpp -o ./demo_reader && ./demo_reader --stats \
/powerrag/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/diskann/_partition.bin \
/powerrag/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/diskann/_disk_graph.index
*/
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <limits> // Include for std::numeric_limits
#include <string> // Include for std::string comparison
#include <vector>
#define READ_U64(f, val) \
f.read(reinterpret_cast<char *>(&val), sizeof(uint64_t))
#define READ_U32(f, val) \
f.read(reinterpret_cast<char *>(&val), sizeof(uint32_t))
#define SECTOR_SIZE 4096
// Helper: Get file size
static size_t get_file_size(const std::string &fname) {
std::ifstream ifs(fname, std::ios::binary | std::ios::ate);
if (ifs.fail() || !ifs.is_open()) {
return 0;
}
return static_cast<size_t>(ifs.tellg());
}
// Print the first few bytes of a sector as hex, for debugging
static void print_hex(const char *buf, size_t len, size_t max_len = 64) {
size_t show_len = (len < max_len) ? len : max_len;
for (size_t i = 0; i < show_len; i++) {
unsigned char c = (unsigned char)buf[i];
std::cout << std::hex << std::setw(2) << std::setfill('0') << (unsigned)c
<< " ";
if ((i + 1) % 16 == 0)
std::cout << "\n ";
}
std::cout << std::dec << "\n";
}
/*
Corrected demo_reader:
1) Read from partition.bin:
- C, partition_nums, nd
- graph_partitions[i]: all node IDs in partition i
- id2partition[nodeID]: maps nodeID => partition i
2) Read from _disk_graph.index:
a) sector 0 starts with 2 ints: meta_n, meta_dim
b) then meta_n uint64_t values,
e.g. [0]=nd, [1]=dim, [2]=??, [3]=max_node_len, [4]=C, [5]..??,
[8]=file_size; the exact positions must be cross-checked against the
relayout writer
c) graph_node_len = max_node_len - dim_in_meta*sizeof(float)
3) Given a user-supplied target_node_id:
partition_id = id2partition[node_id]
find the node's index j in graph_partitions[partition_id]
offset = (partition_id+1)*4096 => sector
adjacency_offset = j*graph_node_len => neighbor_count => neighbors
*/
int main(int argc, char **argv) {
bool calculate_stats = false;
// int arg_offset = 0; // Offset for positional arguments
std::string partition_bin;
std::string graph_index;
uint64_t target_node_id = 0; // Initialize
if (argc != 4) {
std::cerr << "Usage:\n"
<< " " << argv[0]
<< " <partition.bin> <disk_graph.index> <target_node_id> (Reads "
"adjacency for a node)\n"
<< " " << argv[0]
<< " --stats <partition.bin> <disk_graph.index> "
"(Calculates degree statistics)\n";
return 1;
}
// Check if the first argument is the stats flag
if (std::string(argv[1]) == "--stats") {
calculate_stats = true;
partition_bin = argv[2];
graph_index = argv[3];
std::cout << "Mode: Calculating Degree Statistics\n";
} else {
// Assume default mode (single node lookup)
calculate_stats = false;
partition_bin = argv[1];
graph_index = argv[2];
try { // Add error handling for stoull
target_node_id = std::stoull(argv[3]);
} catch (const std::invalid_argument &ia) {
std::cerr << "Error: Invalid target_node_id: " << argv[3] << std::endl;
return 1;
} catch (const std::out_of_range &oor) {
std::cerr << "Error: target_node_id out of range: " << argv[3]
<< std::endl;
return 1;
}
std::cout << "Mode: Single Node Lookup for Node ID " << target_node_id
<< "\n";
}
// 1) Read partition.bin
std::ifstream pf(partition_bin, std::ios::binary);
if (!pf.is_open()) {
std::cerr << "Cannot open partition.bin: " << partition_bin << std::endl;
return 1;
}
uint64_t C, partition_nums, nd;
READ_U64(pf, C);
READ_U64(pf, partition_nums);
READ_U64(pf, nd);
std::cout << "[partition.bin header] C=" << C
<< ", partition_nums=" << partition_nums << ", nd=" << nd
<< std::endl;
// Read partition node lists
std::vector<std::vector<uint32_t> > graph_partitions(partition_nums);
for (uint64_t i = 0; i < partition_nums; i++) {
uint32_t psize;
READ_U32(pf, psize);
graph_partitions[i].resize(psize);
pf.read(reinterpret_cast<char *>(graph_partitions[i].data()),
psize * sizeof(uint32_t));
}
// Read _id2partition[node], size= nd
std::vector<uint32_t> id2partition(nd);
pf.read(reinterpret_cast<char *>(id2partition.data()), nd * sizeof(uint32_t));
pf.close();
std::cout << "Done loading partition info.\n";
if (target_node_id >= nd) {
std::cerr << "target_node_id=" << target_node_id
<< " out of range nd=" << nd << std::endl;
return 1;
}
// 2) Parse _disk_graph.index
std::ifstream gf(graph_index, std::ios::binary);
if (!gf.is_open()) {
std::cerr << "Cannot open disk_graph.index: " << graph_index << std::endl;
return 1;
}
// (a) sector0 => first read 2 ints
int meta_n, meta_dim;
gf.read((char *)&meta_n, sizeof(int));
gf.read((char *)&meta_dim, sizeof(int));
std::cout << "[debug] meta_n=" << meta_n << ", meta_dim=" << meta_dim << "\n";
// (b) Read meta_n uint64_t
std::vector<uint64_t> meta_info(meta_n);
gf.read(reinterpret_cast<char *>(meta_info.data()),
meta_n * sizeof(uint64_t));
// Print
for (int i = 0; i < meta_n; i++) {
std::cout << " meta_info[" << i << "]= " << meta_info[i] << "\n";
}
size_t file_size = get_file_size(graph_index);
std::cout << "[disk_graph.index size] " << file_size << " bytes\n";
// According to the relayout log: meta_info[0]=nd=60450220, meta_info[1]=dim=769,
// meta_info[2]=??(16495248?), meta_info[3]=max_node_len=3320,
// meta_info[4]=16 (C),
// meta_info[8]=15475261440 (file size)
// Parse these fields manually for now:
uint64_t nd_in_meta = meta_info[0];
uint64_t dim_in_meta = meta_info[1];
uint64_t max_node_len = meta_info[3];
uint64_t c_in_meta = meta_info[4];
uint64_t entire_file_sz = meta_info[8];
std::cout << "Based on meta_info:\n"
<< " nd_in_meta= " << nd_in_meta
<< ", dim_in_meta= " << dim_in_meta
<< ", max_node_len= " << max_node_len
<< ", c_in_meta= " << c_in_meta
<< ", entire_file_size= " << entire_file_sz << "\n";
// Calculate graph_node_len
uint64_t dim_size = dim_in_meta * sizeof(float);
uint64_t graph_node_len = max_node_len - dim_size;
std::cout << " => graph_node_len= " << graph_node_len << "\n\n";
if (calculate_stats) {
// --- Degree Statistics Calculation Mode ---
std::cout << " Calculated graph_node_len = " << graph_node_len << "\n\n";
if (nd == 0) {
std::cerr << "Graph has 0 nodes (nd=0). Cannot calculate stats."
<< std::endl;
gf.close();
return 1;
}
uint32_t min_degree = std::numeric_limits<uint32_t>::max();
uint32_t max_degree = 0;
uint64_t total_degree = 0;
uint64_t nodes_processed = 0;
std::vector<char> sectorBuf(SECTOR_SIZE);
std::cout << "Calculating degrees for " << nd << " nodes across "
<< partition_nums << " partitions..." << std::endl;
for (uint32_t p = 0; p < partition_nums; ++p) {
uint64_t sector_offset = uint64_t(p + 1) * SECTOR_SIZE;
gf.seekg(sector_offset, std::ios::beg);
if (gf.fail()) {
std::cerr << "Error seeking to sector offset for partition " << p
<< std::endl;
gf.close();
return 1;
}
gf.read(sectorBuf.data(), SECTOR_SIZE);
if (gf.fail() && !gf.eof()) {
std::cerr << "Error reading sector data for partition " << p
<< std::endl;
gf.close();
return 1;
}
gf.clear(); // Reset fail bits
const auto &part_list = graph_partitions[p];
for (size_t j = 0; j < part_list.size(); ++j) {
uint64_t node_offset = j * graph_node_len;
if (node_offset + sizeof(uint32_t) > SECTOR_SIZE) {
std::cerr << "Error: Node offset out of sector bounds.\n"
<< " Partition=" << p << ", node_subIndex=" << j
<< ", node_offset=" << node_offset
<< ", graph_node_len=" << graph_node_len << std::endl;
gf.close();
return 1;
}
char *adjacency_ptr = sectorBuf.data() + node_offset;
uint32_t neighbor_count = *reinterpret_cast<uint32_t *>(adjacency_ptr);
min_degree = std::min(min_degree, neighbor_count);
max_degree = std::max(max_degree, neighbor_count);
total_degree += neighbor_count;
nodes_processed++;
}
if (p % 10 == 0 || p == partition_nums - 1) {
std::cout << " Processed partition " << p + 1 << " / "
<< partition_nums << "...\r" << std::flush;
}
}
std::cout << "\nFinished processing partitions." << std::endl;
if (nodes_processed != nd) {
std::cerr << "Warning: Processed " << nodes_processed
<< " nodes, but expected " << nd << std::endl;
}
double avg_degree = (nd > 0) ? static_cast<double>(total_degree) / nd : 0.0;
std::cout << "\n--- Degree Statistics ---\n";
std::cout << "Min Degree: "
<< (min_degree == std::numeric_limits<uint32_t>::max()
? 0
: min_degree)
<< std::endl; // Handle case of 0 nodes
std::cout << "Max Degree: " << max_degree << std::endl;
std::cout << "Avg Degree: " << std::fixed << std::setprecision(2)
<< avg_degree << std::endl;
std::cout << "Total Degree (Sum): " << total_degree << std::endl;
std::cout << "Nodes Processed: " << nodes_processed << std::endl;
} else {
uint64_t nd_in_meta = meta_info[0];
uint64_t c_in_meta = meta_info[4];
uint64_t entire_file_sz = meta_info[8];
std::cout << "Based on meta_info:\n"
<< " nd_in_meta= " << nd_in_meta
<< ", dim_in_meta= " << dim_in_meta
<< ", max_node_len= " << max_node_len
<< ", c_in_meta= " << c_in_meta
<< ", entire_file_size= " << entire_file_sz << "\n";
std::cout << " => graph_node_len= " << graph_node_len << "\n\n";
if (target_node_id >= nd) {
std::cerr << "target_node_id=" << target_node_id
<< " out of range nd=" << nd << std::endl;
gf.close();
return 1;
}
// We need id2partition only for single-node lookup
std::vector<uint32_t> id2partition(nd);
{ // Re-read id2partition from partition.bin, seeking past the header and the partition lists
std::ifstream pf_again(partition_bin, std::ios::binary);
uint64_t header_offset =
3 * sizeof(uint64_t); // Skip C, partition_nums, nd
uint64_t partition_list_offset = 0;
for (uint64_t i = 0; i < partition_nums; i++) {
partition_list_offset += sizeof(uint32_t); // Size field
partition_list_offset +=
graph_partitions[i].size() * sizeof(uint32_t); // Data
}
pf_again.seekg(header_offset + partition_list_offset, std::ios::beg);
pf_again.read(reinterpret_cast<char *>(id2partition.data()),
nd * sizeof(uint32_t));
// Error check pf_again if needed
}
// 3) Find target_node_id => partition_id => subIndex
uint32_t partition_id = id2partition[target_node_id];
if (partition_id >= partition_nums) {
std::cerr << "Partition ID out-of-range for target node.\n";
gf.close();
return 1;
}
const auto &part_list = graph_partitions[partition_id]; // Use const ref
auto it =
std::find(part_list.begin(), part_list.end(), (uint32_t)target_node_id);
if (it == part_list.end()) {
std::cerr << "Cannot find node " << target_node_id << " in partition "
<< partition_id << std::endl;
gf.close();
return 1;
}
size_t j = std::distance(part_list.begin(), it);
// 4) sector => (partition_id+1)* 4096
uint64_t sector_offset = uint64_t(partition_id + 1) * SECTOR_SIZE;
gf.seekg(sector_offset, std::ios::beg);
std::vector<char> sectorBuf(SECTOR_SIZE);
gf.read(sectorBuf.data(), SECTOR_SIZE);
if (gf.fail() && !gf.eof()) {
std::cerr << "Error reading sector data for partition " << partition_id
<< std::endl;
gf.close();
return 1;
}
gf.clear(); // Reset fail bits
std::cout << "Partition #" << partition_id
<< ", nodeCount= " << part_list.size()
<< ", offset= " << sector_offset << "\n"
<< " first64 hex:\n ";
print_hex(sectorBuf.data(), SECTOR_SIZE, 64);
// adjacency_offset= j* graph_node_len
uint64_t node_offset = j * graph_node_len;
if (node_offset + sizeof(uint32_t) >
SECTOR_SIZE) { // Check only for neighbor_count read first
std::cerr << "Out-of-range. j=" << j << ", node_offset=" << node_offset
<< ", node_offset+4=" << (node_offset + sizeof(uint32_t))
<< "> 4096\n";
gf.close();
return 1;
}
char *adjacency_ptr = sectorBuf.data() + node_offset;
uint32_t neighbor_count = *reinterpret_cast<uint32_t *>(adjacency_ptr);
std::cout << "[Node " << target_node_id << "] partition=" << partition_id
<< ", subIndex=" << j << ", adjacency_offset=" << node_offset
<< ", neighbor_count=" << neighbor_count << "\n";
size_t needed = neighbor_count * sizeof(uint32_t);
if (node_offset + sizeof(uint32_t) + needed > SECTOR_SIZE) {
std::cerr << "Neighbors partly out-of-range => neighbor_count="
<< neighbor_count << "\n";
// Option: Can still print partial list if needed, but indicating it's
// truncated
gf.close();
return 1; // Or handle differently
}
std::vector<uint32_t> neighbors(neighbor_count);
memcpy(neighbors.data(), adjacency_ptr + sizeof(uint32_t), needed);
std::cout << " neighbors=[";
for (size_t kk = 0; kk < std::min<size_t>(10, neighbor_count); kk++) {
std::cout << neighbors[kk];
if (kk + 1 < std::min<size_t>(10, neighbor_count))
std::cout << ", ";
}
if (neighbor_count > 10)
std::cout << " ... (total " << neighbor_count << ")";
std::cout << "]\n";
} // End of else (single node lookup mode)
gf.close();
return 0;
}
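The partition.bin layout the reader parses (three uint64 header fields, then one uint32 count plus that many uint32 node IDs per partition, then the id2partition array) can also be sanity-checked from Python. A small sketch, assuming the same little-endian layout the C++ reader uses:

```python
import struct

def read_partition_header(path):
    """Read C, partition_nums, nd and the first partition's node list."""
    with open(path, "rb") as f:
        c, partition_nums, nd = struct.unpack("<QQQ", f.read(24))
        # Each partition: a uint32 count followed by that many uint32 node IDs
        (psize,) = struct.unpack("<I", f.read(4))
        first_partition = struct.unpack(f"<{psize}I", f.read(4 * psize))
    return c, partition_nums, nd, list(first_partition)
```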

View File

@@ -1,15 +0,0 @@
#! /bin/fish
# get the dir of this script
set -x SCRIPT_DIR (dirname (realpath (status --current-filename)))
g++ $SCRIPT_DIR/analyze_diskann_graph.cpp -o $SCRIPT_DIR/analyze_diskann_graph
# get args
set -x INDEX_PATH $argv[1]
$SCRIPT_DIR/analyze_diskann_graph $INDEX_PATH $INDEX_PATH.degree_distribution.txt
python plot_degree_distribution.py $INDEX_PATH.degree_distribution.txt
rm $INDEX_PATH.degree_distribution.txt

View File

@@ -1,30 +0,0 @@
#!/usr/bin/env fish
set scaling_out_dir "/Users/ec2-user/scaling_out"
# Define an array of paths to download
set paths \
"examples/" \
"indices/rpj_wiki/facebook/contriever-msmarco/diskann/_disk_graph.index" \
"indices/rpj_wiki/facebook/contriever-msmarco/diskann/_partition.bin" \
"indices/rpj_wiki/facebook/contriever-msmarco/diskann/ann_disk.index_medoids.bin" \
"indices/rpj_wiki/facebook/contriever-msmarco/diskann/ann_disk.index_centroids.bin" \
"indices/rpj_wiki/facebook/contriever-msmarco/diskann/ann_disk.index_max_base_norm.bin" \
"embeddings/facebook/contriever-msmarco/rpj_wiki/compressed_10/" \
"passages/rpj_wiki/8-shards/" \
"indices/rpj_wiki/facebook/contriever-msmarco/flat_results_nq_k3.json"
# Download each path using a for loop
for path in $paths
echo "Downloading $path..."
# if ends with /, then create the directory
if string match -q "*/" $path
echo "Creating directory $scaling_out_dir/$path"
mkdir -p "$scaling_out_dir/$path"
aws s3 cp "s3://retrieval-scaling-out/$path" "$scaling_out_dir/$path" --recursive
else
aws s3 cp "s3://retrieval-scaling-out/$path" "$scaling_out_dir/$path"
end
end
echo "Download completed."

View File

@@ -1,422 +0,0 @@
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from scipy.stats import kendalltau, spearmanr
# Select the compute device
device = torch.device("cuda" if torch.cuda.is_available() else
("mps" if torch.backends.mps.is_available() else "cpu"))
print(f"Using device: {device}")
# Define a custom comparison function (based on the inner product)
def compare(a, b):
"""
Compute the inner product of two vectors and return its negative as the
distance measure; smaller values mean more similar (consistent with the
provided code).
"""
result = np.dot(a, b)
return -result  # Return the negative value, consistent with the original code
# Compute similarities in batch
def compute_similarities(queries, corpus):
"""Compute the similarity matrix between query vectors and corpus vectors."""
similarities = np.zeros((len(queries), len(corpus)))
for i, query in enumerate(queries):
for j, doc in enumerate(corpus):
similarities[i, j] = compare(query, doc)
return similarities
# Load the two models
model_names = [
"facebook/contriever-msmarco",  # Contriever model
"facebook/contriever-msmarco-int4"  # Contriever model (int4)
]
# Extended sample texts, split into several topic groups
texts = [
# Group 1: foxes and other animals (0-9)
"The quick brown fox jumps over the lazy dog.",
"A rapid auburn fox leaps above the inactive canine.",
"The sly fox outsmarts the hunting hounds in the forest.",
"Foxes are known for their cunning behavior and bushy tails.",
"The red fox is the largest of the true foxes and the most common fox species.",
"Dogs have been companions to humans for thousands of years.",
"The lazy dog slept through the commotion of the playful fox.",
"Wolves and foxes belong to the same family, Canidae.",
"The arctic fox changes its coat color with the seasons.",
"Domestic dogs come in hundreds of breeds of various sizes and appearances.",
# Group 2: artificial intelligence and machine learning (10-19)
"Machine learning is a branch of artificial intelligence.",
"Deep learning is a subset of machine learning.",
"Neural networks are computing systems inspired by biological neural networks.",
"AI systems can now beat human champions at complex games like chess and Go.",
"Natural language processing allows computers to understand human language.",
"Reinforcement learning involves training agents to make sequences of decisions.",
"Computer vision enables machines to derive information from images and videos.",
"The Turing test measures a machine's ability to exhibit intelligent behavior.",
"Supervised learning uses labeled training data to learn the mapping function.",
"Unsupervised learning finds patterns in data without pre-existing labels.",
# Group 3: Paris and French landmarks (20-29)
"The Eiffel Tower is located in Paris, France.",
"The Louvre Museum is in the city of Paris.",
"Notre-Dame Cathedral is a medieval Catholic cathedral on the Île de la Cité in Paris.",
"The Arc de Triomphe stands at the center of the Place Charles de Gaulle in Paris.",
"The Seine River flows through the heart of Paris.",
"Montmartre is a large hill in Paris's 18th arrondissement known for its artistic history.",
"The Palace of Versailles is located in the Île-de-France region of France.",
"The Champs-Élysées is an avenue in Paris famous for its theatres, cafés, and luxury shops.",
"The Sacré-Cœur Basilica offers one of the most beautiful panoramic views of Paris.",
"The Musée d'Orsay houses the largest collection of impressionist masterpieces in the world.",
# Group 4: renewable energy (30-39)
"Solar panels convert sunlight into electricity.",
"Wind turbines generate power from moving air.",
"Hydroelectric power is generated from flowing water.",
"Geothermal energy harnesses heat from within the Earth.",
"Biomass energy comes from organic materials like plants and waste.",
"Tidal energy uses the natural rise and fall of coastal tidal waters.",
"Renewable energy sources can help reduce greenhouse gas emissions.",
"Solar farms can span hundreds of acres with thousands of panels.",
"Offshore wind farms are built in bodies of water to harvest wind energy.",
"Energy storage systems are crucial for balancing renewable energy supply and demand.",
# Group 5: programming languages (40-49)
"Python is a popular programming language for data science.",
"JavaScript is commonly used for web development.",
"Java is known for its 'write once, run anywhere' capability.",
"C++ provides high-performance and close hardware control.",
"Ruby is praised for its simplicity and productivity.",
"PHP is a server-side scripting language designed for web development.",
"Swift is used to develop applications for Apple platforms.",
"Rust offers memory safety without using garbage collection.",
"Go was designed at Google to improve programming productivity.",
"Kotlin is fully interoperable with Java and provides more concise syntax.",
]
# Extended query sentences
query_texts = [
# Animal-related queries
"A fox jumped over a dog.",
"Wild animals and their behaviors in forests.",
"Different species of foxes around the world.",
# AI-related queries
"Artificial intelligence and neural networks.",
"Machine learning algorithms and applications.",
"The future of deep learning technology.",
# Paris-related queries
"Famous landmarks in Paris, France.",
"Tourist attractions along the Seine River.",
"Historical buildings and museums in Paris.",
# Energy-related queries
"Renewable energy sources and sustainability.",
"Solar and wind power generation technologies.",
"Alternative clean energy solutions for the future.",
# Programming-related queries
"Computer programming languages comparison.",
"Best languages for web development.",
"Programming tools for data science applications."
]
# Get embeddings from a BGE-style model
def get_bge_embeddings(model, tokenizer, texts, device):
# Process large text collections in batches
batch_size = 16
all_embeddings = []
for i in range(0, len(texts), batch_size):
batch_texts = texts[i:i+batch_size]
encoded_input = tokenizer(batch_texts, padding=True, truncation=True,
max_length=512, return_tensors='pt').to(device)
with torch.no_grad():
model_output = model(**encoded_input)
# BGE uses the [CLS] token
embeddings = model_output.last_hidden_state[:, 0]
# Normalize the embeddings
normalized_embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
all_embeddings.append(normalized_embeddings.cpu().numpy())
return np.vstack(all_embeddings)
# Get embeddings from the Contriever model
def get_contriever_embeddings(model, tokenizer, texts, device, use_int4=False):
# Process large text collections in batches
batch_size = 16
all_embeddings = []
for i in range(0, len(texts), batch_size):
batch_texts = texts[i:i+batch_size]
encoded_input = tokenizer(batch_texts, padding=True, truncation=True,
max_length=512, return_tensors='pt').to(device)
with torch.no_grad():
model_output = model(**encoded_input)
# Contriever uses mean pooling
attention_mask = encoded_input['attention_mask'].unsqueeze(-1)
embeddings = (model_output.last_hidden_state * attention_mask).sum(1) / attention_mask.sum(1)
# Normalize the embeddings
normalized_embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
all_embeddings.append(normalized_embeddings.cpu().numpy())
return np.vstack(all_embeddings)
# Main function
def compare_embeddings():
results = {}
for i, model_name in enumerate(model_names):
model_display_name = model_name
# Give the second model a distinct display name so the two can be told apart
if i == 1:
model_display_name = "facebook/contriever-msmarco-int4"
print(f"\n======= Loading model {i+1}: {model_display_name} =======")
tokenizer = AutoTokenizer.from_pretrained(model_names[0])  # Both models use the same tokenizer
# If this is the second model, apply int4 quantization
if i == 1:
print("Applying int4 quantization...")
try:
from transformers import BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4"
)
model = AutoModel.from_pretrained(
model_names[0],  # Use the same base model
quantization_config=quantization_config,
device_map="auto"
)
print("成功加载int4模型")
except Exception as e:
print(f"int4加载失败: {e}")
print("回退到标准模型...")
model = AutoModel.from_pretrained(model_names[0]).to(device)
else:
model = AutoModel.from_pretrained(model_names[0]).to(device)
model.eval()
print(f"使用 {model_display_name} 生成嵌入...")
# 所有模型都使用contriever
use_int4 = i == 1
corpus_embeddings = get_contriever_embeddings(model, tokenizer, texts, device, use_int4)
query_embeddings = get_contriever_embeddings(model, tokenizer, query_texts, device, use_int4)
print(f"语料库嵌入形状: {corpus_embeddings.shape}")
print(f"查询嵌入形状: {query_embeddings.shape}")
# 使用自定义函数计算相似度
similarity_scores = compute_similarities(query_embeddings, corpus_embeddings)
# 对每个查询,按相似度排序文本索引(较小的值表示更相似)
ranked_indices = {}
for j, scores in enumerate(similarity_scores):
# Sort scores ascending (we return negative inner products, so smaller = more similar)
sorted_indices = np.argsort(scores)
ranked_indices[f"query_{j+1}"] = sorted_indices
results[model_display_name] = {
'corpus_embeddings': corpus_embeddings,
'query_embeddings': query_embeddings,
'similarity_scores': similarity_scores,
'ranked_indices': ranked_indices
}
# Immediately print a few results for this model as a sanity check
print(f"\n=== {model_display_name} preliminary results ===")
# Show the top-3 results for the first query
query_idx = 0
ranked_idx = ranked_indices[f"query_{query_idx+1}"]
top_texts = [texts[idx] for idx in ranked_idx[:3]]
print(f"查询: '{query_texts[query_idx]}'")
print(f"排名前3位的文本:")
for j, text in enumerate(top_texts):
idx = ranked_idx[j]
score = similarity_scores[query_idx][idx]
print(f" {j+1}. [ID:{idx}] {text} (分数: {score:.4f})")
return results
# Analyze the results
def analyze_results(results):
models = list(results.keys())
# 1. Compare similarity scores
print("\n=== Similarity score comparison ===")
for model_name, result in results.items():
similarities = result['similarity_scores'].flatten()
print(f"{model_name} similarity statistics:")
print(f" mean: {similarities.mean():.4f}")
print(f" min: {similarities.min():.4f}")
print(f" max: {similarities.max():.4f}")
print(f" std: {similarities.std():.4f}")
# 2. Compare rankings: show the top-5 results for each query
print("\n=== Ranking comparison ===")
for query_idx in range(len(query_texts)):
query_key = f"query_{query_idx+1}"
print(f"\n查询 {query_idx+1}: '{query_texts[query_idx]}'")
for model_name in models:
ranked_idx = results[model_name]['ranked_indices'][query_key]
top_texts = [texts[idx] for idx in ranked_idx[:5]]
print(f"{model_name} 排名前5位的文本:")
for i, text in enumerate(top_texts):
idx = ranked_idx[i]
score = results[model_name]['similarity_scores'][query_idx][idx]
print(f" {i+1}. [ID:{idx}] {text} (分数: {score:.4f})")
# 3. 排序一致性分析
print("\n=== 模型间排序一致性分析 ===")
kendall_tau_scores = []
spearman_scores = []
for query_idx in range(len(query_texts)):
query_key = f"query_{query_idx+1}"
# Get each model's ranking and compare only the top-10 results
model1_top10 = results[models[0]]['ranked_indices'][query_key][:10]
model2_top10 = results[models[1]]['ranked_indices'][query_key][:10]
# Compute rank correlations
kt, _ = kendalltau(model1_top10, model2_top10)
sr, _ = spearmanr(model1_top10, model2_top10)
kendall_tau_scores.append(kt)
spearman_scores.append(sr)
# Compute the overlap rate of the top-10 results
overlap = len(set(model1_top10) & set(model2_top10))
overlap_rate = overlap / 10.0
print(f"查询 {query_idx+1} '{query_texts[query_idx]}':")
print(f" Kendall's Tau = {kt:.4f}, Spearman's rank correlation = {sr:.4f}")
print(f" 前10结果重叠率: {overlap_rate:.2f} ({overlap}/10)")
print(f"\n平均 Kendall's Tau: {np.mean(kendall_tau_scores):.4f}")
print(f"平均 Spearman's rank correlation: {np.mean(spearman_scores):.4f}")
# 4. 可视化相似度分布差异
plt.figure(figsize=(12, 6))
for i, model_name in enumerate(models):
sns.histplot(results[model_name]['similarity_scores'].flatten(),
kde=True, label=model_name, alpha=0.6)
plt.title('Similarity distribution per model')
plt.xlabel('Similarity score (smaller = more similar)')
plt.ylabel('Frequency')
plt.legend()
plt.savefig('similarity_distribution.png')
print("已保存相似度分布图表到 'similarity_distribution.png'")
# 5. 可视化主题相关性
plt.figure(figsize=(15, 10))
# 为每个主题组定义颜色
topic_colors = {
'animals': 'blue',
'AI': 'red',
'Paris': 'green',
'energy': 'purple',
'programming': 'orange'
}
# Define the index range of each topic group
topic_ranges = {
'animals': (0, 10),
'AI': (10, 20),
'Paris': (20, 30),
'energy': (30, 40),
'programming': (40, 50)
}
# For each query group, show the topic distribution of the top-10 results
query_groups = [
[0, 1, 2],  # animal queries
[3, 4, 5],  # AI queries
[6, 7, 8],  # Paris queries
[9, 10, 11],  # energy queries
[12, 13, 14]  # programming queries
]
for group_idx, group in enumerate(query_groups):
plt.subplot(len(query_groups), 1, group_idx+1)
# Compute the topic distribution for each model
bar_width = 0.35
bar_positions = np.arange(len(topic_ranges))
for model_idx, model_name in enumerate(models):
# Count how often each topic appears in the top-10 results
topic_counts = {topic: 0 for topic in topic_ranges.keys()}
for query_idx in group:
query_key = f"query_{query_idx+1}"
top10 = results[model_name]['ranked_indices'][query_key][:10]
for idx in top10:
for topic, (start, end) in topic_ranges.items():
if start <= idx < end:
topic_counts[topic] += 1
# Draw the topic-distribution bar chart
plt.bar(bar_positions + (model_idx * bar_width),
list(topic_counts.values()),
bar_width,
label=model_name)
plt.title(f"查询组 {group_idx+1}: {', '.join([query_texts[i] for i in group[:1]])}")
plt.xticks(bar_positions + bar_width/2, list(topic_ranges.keys()))
plt.ylabel('前10结果中的出现次数')
plt.legend()
plt.tight_layout()
plt.savefig('topic_distribution.png')
print("已保存主题分布图表到 'topic_distribution.png'")
# 6. 可视化查询与相关文档的相似度热图
plt.figure(figsize=(15, 12))
for i, model_name in enumerate(models):
plt.subplot(2, 1, i+1)
# Get the similarity matrix (more negative means more similar)
sim_matrix = results[model_name]['similarity_scores']
# Negate so that larger values mean more similar, for visualization
sim_matrix_viz = -sim_matrix
# Draw the heatmap
sns.heatmap(sim_matrix_viz, cmap='YlGnBu',
xticklabels=[f"Doc{i}" for i in range(len(texts))],
yticklabels=[f"Q{i+1}" for i in range(len(query_texts))],
cbar_kws={'label': 'Similarity (higher = more similar)'})
plt.title(f"{model_name} similarity heatmap")
plt.xlabel('Document ID')
plt.ylabel('Query ID')
plt.tight_layout()
plt.savefig('similarity_heatmap.png')
print("已保存相似度热图到 'similarity_heatmap.png'")
if __name__ == "__main__":
print("开始比较嵌入模型...")
results = compare_embeddings()
analyze_results(results)
print("\n比较完成!")

View File

@@ -1,444 +0,0 @@
# Filename: evaluate_results_xai_line_sync.py
import openai
import json
import os
import time
from dotenv import load_dotenv
from tqdm import tqdm
from collections import defaultdict
import concurrent.futures
from typing import List, Dict, Any, Tuple
# --- Configuration ---
load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
raise ValueError("Please set the OPENAI_API_KEY in your .env file")
try:
client = openai.OpenAI(
api_key=OPENAI_API_KEY,
)
except ImportError:
print("Please install the latest OpenAI library: pip install --upgrade openai")
exit()
except openai.AuthenticationError:
print("OpenAI library reported an AuthenticationError. Ensure OPENAI_API_KEY is correct.")
exit()
LLM_MODEL = "gpt-3.5-turbo" # Using OpenAI's standard model
MAX_RETRIES = 5
INITIAL_RETRY_DELAY_SECONDS = 5
REQUEST_TIMEOUT_SECONDS = 90
MAX_WORKERS = 10 # Number of parallel workers
# --- File Paths (Adjust as needed) ---
# User provided paths
QUERIES_FILE_PATH = "/opt/dlami/nvme/scaling_out/examples/enron_eval_retrieval.jsonl"
RAW_PASSAGES_FILE_PATH = "/opt/dlami/nvme/scaling_out/passages/enron_emails/1-shards/raw_passages-0-of-1.jsonl"
RESULTS_FILE_PATH = "search_results_top10_bm25.jsonl" # This file's Nth line corresponds to QUERIES_FILE_PATH's Nth line
OUTPUT_EVALUATION_FILE = "llm_containment_evaluations_xai_line_sync.jsonl"
# --- LLM Prompt Definitions for Containment (Same as before) ---
CONTAINMENT_SYSTEM_PROMPT = """You are an AI evaluator. Your task is to determine if the core information presented in the 'Retrieved Passage' is directly contained within *any* of the text snippets provided in the 'Ground Truth Email Snippets' list."""
CONTAINMENT_USER_TEMPLATE = """Retrieved Passage:
"{retrieved_passage_text}"
---
Ground Truth Email Snippets (Parts of the correct source email):
{ground_truth_snippets_formatted_list}
---
Is the core information of the 'Retrieved Passage' directly present or fully contained within *any* of the 'Ground Truth Email Snippets' listed above?
- Focus on whether the specific facts or statements in the 'Retrieved Passage' can be found within the ground truth snippets.
- Ignore minor formatting differences. If the retrieved passage is a direct quote or a very close paraphrase of content within the ground truth snippets, answer YES.
- Respond YES if the Retrieved Passage's content is clearly represented in one or more of the ground truth snippets.
- Respond NO if the Retrieved Passage's content is not found, is contradictory, or introduces significant information not present in the ground truth snippets.
Your response must be a single word: YES or NO.
"""
# --- Data Loading Functions ---
def load_queries_as_list(file_path):
"""
Loads queries from a jsonl file into a list, preserving order.
Each item in the list is a dict containing original_id, query_text, and ground_truth_message_ids.
"""
queries_list = []
try:
with open(file_path, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(f):
try:
data = json.loads(line)
required_keys = ["id", "query", "ground_truth_message_ids"]
if not all(key in data for key in required_keys):
print(f"Warning: Skipping line {line_num + 1} in query file due to missing keys: {line.strip()}")
continue
if not isinstance(data["ground_truth_message_ids"], list):
print(f"Warning: 'ground_truth_message_ids' is not a list in line {line_num + 1}. Skipping: {line.strip()}")
continue
queries_list.append({
"original_id": data["id"], # Store the original ID from the file
"query_text": data["query"],
"ground_truth_message_ids": data["ground_truth_message_ids"]
})
except json.JSONDecodeError:
print(f"Warning: Skipping malformed JSON line {line_num + 1} in query file: {line.strip()}")
except FileNotFoundError:
print(f"Error: Queries file not found at {file_path}")
exit()
print(f"Loaded {len(queries_list)} queries (as a list) from {file_path}")
return queries_list
def load_all_passages_by_message_id(raw_passages_file_path):
"""Loads all raw passages into memory, grouped by message_id. (Same as before)"""
passages_dict = defaultdict(list)
# ... (implementation from previous script, no changes needed here) ...
print(f"Loading all raw passages from {raw_passages_file_path} into memory...")
try:
with open(raw_passages_file_path, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(f):
try:
data = json.loads(line)
if "message_id" in data and "text" in data:
passages_dict[data["message_id"]].append(data["text"])
else:
print(f"Warning: Skipping line {line_num+1} in raw passages file due to missing 'message_id' or 'text'.")
except json.JSONDecodeError:
print(f"Warning: Skipping malformed JSON line {line_num + 1} in raw passages file: {line.strip()}")
print(f"Finished loading raw passages. Found {len(passages_dict)} unique message IDs.")
except FileNotFoundError:
print(f"Error: Raw passages file not found at {raw_passages_file_path}")
exit()
except MemoryError:
print("Error: Ran out of memory loading all raw passages. Consider an indexed approach.")
exit()
return dict(passages_dict)
def load_search_results_as_list(file_path):
"""Loads search results from a jsonl file into a list, preserving order."""
results_list = []
# ... (implementation similar to load_queries_as_list, parsing each line as a dict) ...
try:
with open(file_path, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(f):
try:
data = json.loads(line)
# We expect "query_id" (though not used for matching) and "passages"
if "passages" not in data: # query_id might be implicitly by order
print(f"Warning: Skipping line {line_num + 1} in search results file due to missing 'passages' key: {line.strip()}")
continue
results_list.append(data)
except json.JSONDecodeError:
print(f"Warning: Skipping malformed JSON line {line_num + 1} in search results file: {line.strip()}")
except FileNotFoundError:
print(f"Error: Search results file not found at {file_path}")
exit()
print(f"Loaded {len(results_list)} search result sets (as a list) from {file_path}")
return results_list
def format_ground_truth_snippets(snippet_list):
"""Formats the list of ground truth snippets for the prompt. (Same as before)"""
# ... (implementation from previous script) ...
if not snippet_list:
return " [No ground truth snippets found for the target message ID(s)]"
formatted = []
for i, snippet in enumerate(snippet_list):
display_snippet = (snippet[:500] + '...') if len(snippet) > 500 else snippet
formatted.append(f" {i+1}. {display_snippet}")
return "\n".join(formatted)
# --- LLM API Call Function ---
def get_llm_containment_evaluation(retrieved_passage_text: str, ground_truth_snippets_list: List[str], query_id_for_log: str, passage_identifier_info: str, query_text_for_context: str = "") -> str:
"""Calls the OpenAI API with retry logic."""
formatted_gt_snippets = format_ground_truth_snippets(ground_truth_snippets_list)
# max_gt_chars_in_prompt = 5000 # Arbitrary limit, adjust as needed
# if len(formatted_gt_snippets) > max_gt_chars_in_prompt:
# print(f"Warning: Ground truth snippets for Q_log_id:{query_id_for_log} are too long ({len(formatted_gt_snippets)} chars), truncating for LLM prompt.")
# formatted_gt_snippets = formatted_gt_snippets[:max_gt_chars_in_prompt] + "\n [... Snippets Truncated ...]"
user_prompt = CONTAINMENT_USER_TEMPLATE.format(
retrieved_passage_text=retrieved_passage_text,
ground_truth_snippets_formatted_list=formatted_gt_snippets
)
messages = [
{"role": "system", "content": CONTAINMENT_SYSTEM_PROMPT},
{"role": "user", "content": user_prompt}
]
current_retry_delay = INITIAL_RETRY_DELAY_SECONDS
for attempt in range(MAX_RETRIES):
try:
response = client.chat.completions.create(
model=LLM_MODEL,
messages=messages,
temperature=0.0,
max_tokens=10,
timeout=REQUEST_TIMEOUT_SECONDS
)
answer = response.choices[0].message.content.strip().upper()
if answer in ["YES", "NO"]:
return answer
else:
print(f"Warning: Unexpected LLM response content '{answer[:100]}' for Q_log_id:{query_id_for_log} P:{passage_identifier_info}. Defaulting to NO.")
return "NO"
except openai.APIConnectionError as e:
error_message = f"API Connection Error (Attempt {attempt + 1}/{MAX_RETRIES}): {e}"
except openai.RateLimitError as e:
error_message = f"API Rate Limit Error (Attempt {attempt + 1}/{MAX_RETRIES}): {e}"
except openai.APIStatusError as e:
error_message = f"API Status Error (Attempt {attempt + 1}/{MAX_RETRIES}): {e.status_code} - {e.response}"
if e.status_code == 401:
return "ERROR_AUTH"
if e.status_code == 500:
pass
else:
return "ERROR_API_CLIENT"
except Exception as e:
error_message = f"Unexpected error with OpenAI lib (Attempt {attempt + 1}/{MAX_RETRIES}): {type(e).__name__} - {e}"
print(f"{error_message}. Query Log ID: {query_id_for_log}, Passage: {passage_identifier_info}")
if "ERROR_AUTH" in error_message or "ERROR_API_CLIENT" in error_message:
break
if attempt < MAX_RETRIES - 1:
print(f"Retrying in {current_retry_delay} seconds...")
time.sleep(current_retry_delay)
current_retry_delay = min(current_retry_delay * 2, 60)
else:
print(f"Max retries ({MAX_RETRIES}) reached for Q_log_id:{query_id_for_log} P:{passage_identifier_info}. Skipping.")
return "ERROR_MAX_RETRIES"
return "ERROR_MAX_RETRIES"
def process_query_passage_pair(args: Tuple[Dict[str, Any], Dict[str, Any], Dict[str, List[str]], set]) -> List[Dict[str, Any]]:
"""Process a single query-passage pair for parallel execution."""
query_info, result_item, passages_lookup, already_evaluated = args
evaluations = []
query_original_id = query_info["original_id"]
query_text = query_info["query_text"]
target_message_ids = query_info.get("ground_truth_message_ids", [])
if not target_message_ids:
return evaluations
ground_truth_snippets = []
for msg_id_in_query_file in target_message_ids:
msg_id_to_lookup = msg_id_in_query_file
if msg_id_in_query_file.startswith("<") and msg_id_in_query_file.endswith(">"):
msg_id_to_lookup = msg_id_in_query_file[1:-1]
snippets = passages_lookup.get(msg_id_to_lookup)
if snippets:
ground_truth_snippets.extend(snippets)
if not ground_truth_snippets:
return evaluations
retrieved_passages = result_item.get("passages", [])
if not retrieved_passages:
return evaluations
for passage_idx, passage_obj in enumerate(retrieved_passages):
if not isinstance(passage_obj, dict):
print(f"Warning: Invalid passage format for Q_original_id:{query_original_id}, passage index {passage_idx}. Skipping passage.")
continue
retrieved_passage_text = passage_obj.get("text", "").strip()
passage_identifier = passage_obj.get("passage_id", passage_obj.get("id", f"retrieved_idx_{passage_idx}"))
evaluation_key = (query_original_id, passage_identifier)
if evaluation_key in already_evaluated:
continue
passage_text_preview = (retrieved_passage_text[:75] + '...') if len(retrieved_passage_text) > 75 else retrieved_passage_text
if not retrieved_passage_text:
evaluation = "NO"
else:
evaluation = get_llm_containment_evaluation(
retrieved_passage_text,
ground_truth_snippets,
query_original_id,
passage_identifier,
query_text
)
if evaluation == "ERROR_AUTH":
print("Authentication error with OpenAI API. Stopping script.")
return evaluations
evaluation_record = {
"query_original_id": query_original_id,
"passage_identifier": passage_identifier,
"passage_text_preview": passage_text_preview,
"evaluation": evaluation,
"model_used": LLM_MODEL,
"ground_truth_message_ids_checked": target_message_ids
}
evaluations.append(evaluation_record)
return evaluations
# --- Resume Logic ---
def load_existing_evaluations(output_file):
"""Loads already evaluated query-passage pairs using 'passage_identifier' and 'query_original_id'. (Same as before, but keying with original_id)"""
# ... (implementation from previous script, ensure it uses the correct ID for keys) ...
evaluated_pairs = set()
if os.path.exists(output_file):
print(f"Loading existing containment evaluations from {output_file}...")
with open(output_file, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(f):
try:
data = json.loads(line)
# Key for resuming should be based on the logged original query ID
query_original_id = data.get('query_original_id')
passage_identifier = data.get('passage_identifier')
if query_original_id is not None and passage_identifier is not None:
evaluated_pairs.add((query_original_id, passage_identifier))
else:
print(f"Warning: Could not identify query_original_id/passage_identifier in existing file line {line_num + 1}.")
except json.JSONDecodeError:
print(f"Warning: Skipping malformed line {line_num + 1} in existing file: {line.strip()}")
except KeyError as e:
print(f"Warning: Skipping line {line_num + 1} with missing key '{e}' in existing file: {line.strip()}")
print(f"Loaded {len(evaluated_pairs)} existing evaluation records.")
else:
print(f"No existing evaluation file found at {output_file}. Starting fresh.")
return evaluated_pairs
# --- Main Execution Logic ---
def main():
"""Main function to run the containment evaluation process using parallel processing."""
print(f"Starting containment evaluation using OpenAI model: {LLM_MODEL} via OpenAI library interface.")
# Load data as lists
queries_list = load_queries_as_list(QUERIES_FILE_PATH)
passages_lookup = load_all_passages_by_message_id(RAW_PASSAGES_FILE_PATH)
search_results_list = load_search_results_as_list(RESULTS_FILE_PATH)
if not queries_list or not search_results_list or not passages_lookup:
print("Error loading one or more input files or raw passages. Exiting.")
return
# Determine the number of items to process
num_items_to_process = min(len(queries_list), len(search_results_list))
print(f"Will process {num_items_to_process} query-result pairs.")
already_evaluated = load_existing_evaluations(OUTPUT_EVALUATION_FILE)
try:
with open(OUTPUT_EVALUATION_FILE, 'a', encoding='utf-8') as outfile:
# Prepare arguments for parallel processing
process_args = [
(queries_list[i], search_results_list[i], passages_lookup, already_evaluated)
for i in range(num_items_to_process)
]
# Use ThreadPoolExecutor for parallel processing
with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
# Submit all tasks and get futures
futures = [executor.submit(process_query_passage_pair, args) for args in process_args]
# Process results as they complete
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Processing query-result pairs"):
try:
evaluations = future.result()
for evaluation in evaluations:
outfile.write(json.dumps(evaluation) + "\n")
outfile.flush()
# Update already_evaluated set
already_evaluated.add((evaluation["query_original_id"], evaluation["passage_identifier"]))
except Exception as e:
print(f"Error processing query-result pair: {e}")
except IOError as e:
print(f"Error writing to output file {OUTPUT_EVALUATION_FILE}: {e}")
return
except Exception as e:
print(f"An unexpected error occurred during the main processing loop: {e}")
return
print("\n--- Containment Evaluation Script Finished ---")
# --- Final Summary Calculation ---
print(f"Calculating final summary statistics from: {OUTPUT_EVALUATION_FILE}")
final_query_containment_found = {}
total_evaluated_pairs = 0
error_count = 0
evaluated_query_original_ids = set()
try:
with open(OUTPUT_EVALUATION_FILE, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(f):
total_evaluated_pairs += 1
try:
data = json.loads(line)
q_original_id = data['query_original_id']
eval_result = data['evaluation']
evaluated_query_original_ids.add(q_original_id)
if eval_result == "YES":
final_query_containment_found[q_original_id] = True
elif q_original_id not in final_query_containment_found:
final_query_containment_found[q_original_id] = False
if eval_result not in ["YES", "NO"]:
error_count += 1
except (json.JSONDecodeError, KeyError) as e:
print(f"Error reading line {line_num + 1} during summary: {e} - Line: {line.strip()}")
error_count += 1
num_queries_with_any_contained = sum(1 for contained in final_query_containment_found.values() if contained)
total_unique_queries_evaluated = len(evaluated_query_original_ids)
if total_unique_queries_evaluated > 0:
containment_rate_at_10 = num_queries_with_any_contained / total_unique_queries_evaluated
print(f"\n--- Final Statistics (Containment Check) ---")
print(f"Total unique queries processed (based on output file entries): {total_unique_queries_evaluated}")
print(f"Number of queries with at least one contained passage (YES): {num_queries_with_any_contained}")
print(f"Containment Match Rate @ Top 10 (Any YES): {containment_rate_at_10:.4f}")
print(f"Total query-passage pairs processed (lines in output file): {total_evaluated_pairs}")
if error_count > 0:
print(f"Number of evaluation errors or non-YES/NO results: {error_count}")
else:
print("No evaluation results found to summarize.")
except FileNotFoundError:
print(f"Error: Output file {OUTPUT_EVALUATION_FILE} not found for summary.")
except Exception as e:
print(f"An unexpected error occurred during summary calculation: {e}")
print(f"\nDetailed containment evaluations saved to: {OUTPUT_EVALUATION_FILE}")
if __name__ == "__main__":
# Dummy files for testing the line sync logic
if not os.path.exists(QUERIES_FILE_PATH):
print(f"Warning: {QUERIES_FILE_PATH} not found. Creating dummy file.")
with open(QUERIES_FILE_PATH, 'w', encoding='utf-8') as f:
json.dump({"id": "q_alpha", "query": "Query Alpha Text", "ground_truth_message_ids": ["<msg_A>"]}, f); f.write("\n") # Line 0
json.dump({"id": "q_beta", "query": "Query Beta Text", "ground_truth_message_ids": ["<msg_B>"]}, f); f.write("\n") # Line 1
json.dump({"id": "q_gamma", "query": "Query Gamma Text", "ground_truth_message_ids": ["<msg_C>"]}, f); f.write("\n")# Line 2
if not os.path.exists(RAW_PASSAGES_FILE_PATH):
print(f"Warning: {RAW_PASSAGES_FILE_PATH} not found. Creating dummy file.")
with open(RAW_PASSAGES_FILE_PATH, 'w', encoding='utf-8') as f:
json.dump({"text": "Content from message A snippet 1.", "id": 100, "message_id": "<msg_A>"}, f); f.write("\n")
json.dump({"text": "Content from message A snippet 2.", "id": 101, "message_id": "<msg_A>"}, f); f.write("\n")
json.dump({"text": "Content from message B.", "id": 200, "message_id": "<msg_B>"}, f); f.write("\n")
json.dump({"text": "Content from message D (unrelated).", "id": 300, "message_id": "<msg_D>"}, f); f.write("\n")
# RESULTS_FILE_PATH should have results corresponding line-by-line to QUERIES_FILE_PATH
if not os.path.exists(RESULTS_FILE_PATH):
print(f"Warning: {RESULTS_FILE_PATH} not found. Creating dummy file (2 entries).")
with open(RESULTS_FILE_PATH, 'w', encoding='utf-8') as f:
# Result for query "q_alpha" (line 0 in queries file)
json.dump({"query_id": "this_can_be_ignored_if_line_sync", "passages": [{"id": 101, "text": "Content from message A snippet 2."}, {"id": 300, "text": "Content from message D (unrelated)."}]}, f); f.write("\n")
# Result for query "q_beta" (line 1 in queries file)
json.dump({"query_id": "this_too", "passages": [{"id": 999, "text": "Some other text."}, {"id": 200, "text": "Content from message B."}]}, f); f.write("\n")
# Note: only 2 result sets, but 3 queries in the dummy QUERIES_FILE_PATH.
# main() processes min(len(queries_list), len(search_results_list)) pairs,
# so the third query is simply skipped.
main()

View File

@@ -1,44 +0,0 @@
# Recompute Embeddings Saved
```console
python ./demo/main.py --mode serve --engine sglang --load-indices diskann --port 8082 --domain rpj_wiki --lazy --recompute --dedup --use-partition
python ./demo/embedding_server.py --domain rpj_wiki
python ./demo/test_serve.py --port 8082 --nprobe 80 --re --dedup
```
Result:
```
Evaluation Results for nprobe = 80:
Final Recall Rate: 0.9333
Average total latency: 2.427s
Average search time: 2.414s
```
The `--use-partition` flag is optional; the pipeline also runs without it. Results without the flag:
```
Results for nprobe = 80:
Final Recall Rate: 0.9333
Average total latency: 2.434s
Average search time: 2.421s
```
# Recompute Embeddings + Loading from disk
Remove `--dedup --use-partition`
```console
python ./demo/main.py --mode serve --engine sglang --load-indices diskann --port 8082 --domain rpj_wiki --lazy --recompute
python ./demo/embedding_server.py --domain rpj_wiki
python ./demo/test_serve.py --port 8082 --nprobe 80 --re
```
Result:
```
Evaluation Results for nprobe = 80:
Average F1 Score: 0.5708
Average Exact Match Score: 0.4500
Average Recall Rate: 0.9333
Average total latency: 3.709s
Average search time: 3.696s
```
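For context, the recall rate reported above is the fraction of queries whose ground-truth passage shows up in the retrieved top-k. A minimal sketch of that computation (names are illustrative, not the test_serve.py internals):

```python
def recall_at_k(retrieved_ids_per_query, gold_ids_per_query):
    """Fraction of queries with at least one gold passage among the retrieved top-k."""
    hits = sum(
        1 for retrieved, gold in zip(retrieved_ids_per_query, gold_ids_per_query)
        if set(retrieved) & set(gold)
    )
    return hits / len(gold_ids_per_query)
```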

View File

@@ -1,599 +0,0 @@
import os
import math
import pandas as pd
import numpy as np
import re
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import json
import pickle
import pdb
"""
Automatic result extraction for BM25.
"""
def extract_data_to_table(directory_path):
# Regular expression pattern to match the data format in file content
content_pattern = (
r"# tokens: (\d+(\.\d+)?)\tLM PPL: (\d+(\.\d+)?)\tPPL: (\d+(\.\d+)?)"
)
# Regular expression pattern to extract info from file names
file_name_pattern_M = r"(.+)-(\d+)M-seed_(\d+).txt"
file_name_pattern = r"(.+)-(\d+)-seed_(\d+).txt"
# Data storage
data = []
# Iterating through each file in the directory
for file_name in os.listdir(directory_path):
# Checking if the file name matches the pattern
file_match_M = re.match(file_name_pattern_M, file_name)
file_match = re.match(file_name_pattern, file_name)
if file_match_M:
domain, num_samples, seed = file_match_M.groups()
# Reading the file and extracting data
file_path = os.path.join(directory_path, file_name)
with open(file_path, "r") as file:
for line in file:
# Searching for the pattern in each line
content_match = re.search(content_pattern, line)
if content_match:
# Extracting values
tokens, lm_ppl, ppl = (
content_match.groups()[0],
content_match.groups()[2],
content_match.groups()[4],
)
# Adding the extracted data and extra info to the list
data.append(
{
"Domain": domain,
"Samples": int(num_samples) * 1e6,
"Seed": int(seed),
"#eval_tokens": float(tokens),
"LM_PPL": float(lm_ppl),
"PPL": float(ppl),
}
)
elif file_match:
domain, num_samples, seed = file_match.groups()
# Reading the file and extracting data
file_path = os.path.join(directory_path, file_name)
with open(file_path, "r") as file:
for line in file:
# Searching for the pattern in each line
content_match = re.search(content_pattern, line)
if content_match:
# Extracting values
tokens, lm_ppl, ppl = (
content_match.groups()[0],
content_match.groups()[2],
content_match.groups()[4],
)
# Adding the extracted data and extra info to the list
data.append(
{
"Domain": domain,
"Samples": int(num_samples),
"Seed": int(seed),
"#eval_tokens": float(tokens),
"LM_PPL": float(lm_ppl),
"PPL": float(ppl),
}
)
df = pd.DataFrame(data)
grouped_df = df.groupby(["Domain", "Samples", "#eval_tokens"]).mean()
return df, grouped_df
"""
Automatic results extraction for dense retrieval (new).
"""
def extract_dense_scaling_results(log_files, domain=None, plot=None):
# Regular expression pattern to match the key-value pairs in the input string
pattern = r"(\w[\w #]+) = ([\w.]+)"
data_list = []
for file in log_files:
with open(file, "r") as file:
for line in file:
# Use re.findall to extract all matches of the pattern
matches = re.findall(pattern, line)
if matches:
data_dict = {
key.replace(" ", "_").lower(): (
None
if value == "None"
else float(value)
if value.replace(".", "", 1).isdigit()
else value
)
for key, value in matches
}
data_list.append(data_dict)
df = pd.DataFrame(data_list)
if "total_shards" in df.columns:
df["subsample_ratio"] = df["sampled_shards"] / df["total_shards"]
else:
df["subsample_ratio"] = 1 / df["total_shards"]
df = df.sort_values(by="subsample_ratio")
print(df.head)
if plot:
# Setting the plot size for better visibility
plt.figure(figsize=(10, 6))
# Plotting
for concate_k in df["concate_k"].unique():
subset = df[df["concate_k"] == concate_k]
if concate_k == 0:
perplexity_when_concate_k_0 = subset["perplexity"].mean()
plt.axhline(
y=perplexity_when_concate_k_0,
color="r",
linestyle="-",
label="Closed-book",
)
else:
plt.plot(
subset["subsample_ratio"],
subset["perplexity"],
label=f"Concate_k = {concate_k}",
)
plt.title(f"Perplexity Change with Total Shards -- {domain}")
plt.xlabel("Subsample Ratio")
plt.ylabel("Perplexity")
plt.legend()
plt.grid(True)
plt.savefig(plot)
return df
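# Example of the log-line format this parser expects (hypothetical values):
#   "total_shards = 32  sampled_shards = 4  concate_k = 3  perplexity = 14.07"
# yields {"total_shards": 32.0, "sampled_shards": 4.0, "concate_k": 3.0, "perplexity": 14.07}.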
def plot_mmlu():
    # C4 results
    labels = [
        "LM-only",
        "top-1 w/ 1/32 C4 datastore",
        "top-1 w/ 2/32 C4 datastore",
        "top-1 w/ 3/32 C4 datastore",
        "top-1 w/ 4/32 C4 datastore",
        "top-1 w/ 5/32 C4 datastore",
        "top-1 w/ 6/32 C4 datastore",
    ]
    x = [0, 1, 2, 3, 4, 5, 6]
    few_shot_0_concat_1 = [30.69, 32.81, 32.05, 32.55, 32.57, 33.03, 32.88]
    few_shot_1_concat_1 = [39.67, 41.03, 41.74, 42.1, 42.62, 41.55, 42.09]
    few_shot_5_concat_1 = [42.47, 43.75, 44.37, 44.1, 44.84, 43.95, 44.49]
    # Plot one line per few-shot setting
    plt.figure(figsize=(14, 8))
    plt.plot(
        x,
        few_shot_0_concat_1,
        marker="o",
        linestyle="-",
        color="blue",
        label="Few-shot k=0, Concat k=1",
    )
    plt.plot(
        x,
        few_shot_1_concat_1,
        marker="s",
        linestyle="-",
        color="red",
        label="Few-shot k=1, Concat k=1",
    )
    plt.plot(
        x,
        few_shot_5_concat_1,
        marker="^",
        linestyle="-",
        color="green",
        label="Few-shot k=5, Concat k=1",
    )
    # Add details
    plt.title("MMLU Performance")
    plt.xlabel("Retrieval-based LM Datastore Composition")
    plt.ylabel("Accuracy")
    plt.xticks(ticks=x, labels=labels, rotation=45, ha="right")
    plt.legend()
    plt.tight_layout()
    plt.grid(True)
    plt.savefig("mmlu_c4_scaling.png")
def extract_lm_eval_results(
    result_dir, task_name, model_name, n_shot_list, n_doc_list, datastore_name_filter=""
):
    markers = ["o", "s", "^", "D", "*", "p", "H", "x"]
    colors = plt.cm.tab20.colors
    all_data = []
    for subdir, dirs, files in os.walk(result_dir):
        # Number of "-"-separated fields in the subdirectory name
        num_ints = len(os.path.basename(subdir).split("-"))
        for file in files:
            if file.endswith(".jsonl"):
                file_path = os.path.join(subdir, file)
                with open(file_path, "r") as f:
                    for line in f:
                        data = json.loads(line)
                        data["SubdirLevel"] = num_ints
                        data["n-shot"], data["n-doc"] = (
                            int(data["n-shot"]),
                            int(data["n-doc"]),
                        )
                        data["Value"] = float(data["Value"])
                        all_data.append(data)
    filtered_data = [
        d
        for d in all_data
        if datastore_name_filter in result_dir
        and d["n-shot"] in n_shot_list
        and d["n-doc"] in n_doc_list
        and d["SubdirLevel"] > 0
    ]
    plot_data = {}
    for d in filtered_data:
        key = (d["n-shot"], d["n-doc"])
        plot_data.setdefault(key, []).append((d["SubdirLevel"], d["Value"]))
    sorted_keys = sorted(plot_data.keys(), key=lambda x: (x[0], x[1]))
    # n-doc == 0 entries are the closed-book baselines, one per n-shot setting
    closed_book_values = {}
    for i, key in enumerate(sorted_keys):
        n_shot, n_doc = key
        if n_doc == 0:
            value = plot_data[key][-1][-1]
            closed_book_values.update({n_shot: value})
    plt.figure(figsize=(15, 10))
    for i, key in enumerate(sorted_keys):
        n_shot, n_doc = key
        if n_doc == 0:
            continue
        values = plot_data[key]
        # Anchor each curve at x=0 with the matching closed-book value
        values.append(
            (0, closed_book_values[n_shot])
            if n_shot in closed_book_values.keys()
            else (0, None)
        )
        values.sort()  # Ensure the values are sorted by SubdirLevel
        x_values, y_values = zip(*values)  # Unzip the tuple pairs into separate lists
        marker = markers[n_shot] if n_doc else ""
        color = colors[i % len(colors)]  # Choose a color from the colormap
        label = f"n-shot={n_shot}, n-doc={n_doc}"
        plt.plot(
            x_values, y_values, marker=marker, color=color, linestyle="-", label=label
        )
    # plt.gca().yaxis.set_major_locator(ticker.MaxNLocator(nbins='auto', steps=[1, 2, 5, 10]))
    # Route MMLU plots into their own subdirectory
    if "mmlu" in task_name:
        plot_dir = os.path.join("plots", "mmlu")
    else:
        plot_dir = "plots"
    os.makedirs(plot_dir, exist_ok=True)
    plt.xlabel("Number of Index Shards")
    plt.ylabel("Accuracy")
    plt.title(f"{task_name} scaling performance with {model_name}")
    plt.legend()
    plt.grid(True)
    plt.savefig(f"{plot_dir}/{task_name}_{model_name}.png")
    return all_data
def plot_mmlu_persub_figures(directory="plots"):
    files = [
        file
        for file in os.listdir(directory)
        if file.startswith("mmlu_") and file.endswith(".png")
    ]
    plots_per_figure = 16
    for i in range(0, len(files), plots_per_figure):
        # Create a new 4x4 grid figure
        fig, axs = plt.subplots(4, 4, figsize=(20, 20))
        # Flatten the axis array for easy indexing
        axs = axs.flatten()
        # Place one saved plot image in each subplot of the current figure
        for ax, file in zip(axs, files[i : i + plots_per_figure]):
            img = plt.imread(os.path.join(directory, file))
            ax.imshow(img)
            ax.set_title(file)
            ax.axis("off")  # Hide axes
        # Adjust layout and save the figure
        plt.tight_layout()
        plt.savefig(f"mmlu_persub_{i}.png")
def plot_calibration_figures(domain, shard_id=8, show_ci=True, show_all_points=False):
    if show_all_points:
        show_ci = False
    data_path = f"out_calibration/{shard_id}_shard_{domain}/calibration_results_decon_rpj_{domain}_None_samples.pkl"
    with open(data_path, "rb") as file:
        all_results = pickle.load(file)
    all_lm_losses = [item[0] for item in all_results]
    all_retrieval_scores = [item[1] for item in all_results]
    print(f"Total {len(all_lm_losses)} examples.")
    # Compute PPL of the top-1 doc vs. the gold doc from the top-100
    losses_top1 = [losses[0] for losses in all_lm_losses]
    avg_losses_top1 = sum(losses_top1) / len(losses_top1)
    ppl_losses_top1 = math.exp(avg_losses_top1)
    lossed_top100_gold = [min(losses) for losses in all_lm_losses]
    avg_losses_top100_gold = sum(lossed_top100_gold) / len(lossed_top100_gold)
    ppl_lossed_top100_gold = math.exp(avg_losses_top100_gold)
    print(
        f"Top-1 doc PPL: {ppl_losses_top1:.4f}\nGold doc from top-100 PPL: {ppl_lossed_top100_gold:.4f}"
    )
    # Calibration plot
    lm_losses = np.array(all_lm_losses)
    retrieval_scores = np.array(all_retrieval_scores)
    from scipy.special import softmax
    import scipy.stats as stats

    # Normalized scores (currently unused below)
    softmax_lm_losses = softmax(lm_losses, axis=1)
    softmax_retrieval_scores = softmax(retrieval_scores, axis=1)
    if show_all_points:
        lm_losses = lm_losses.flatten()
        retrieval_scores = retrieval_scores.flatten()
        plt.figure(figsize=(8, 6))
        plt.plot(lm_losses, retrieval_scores, marker="o", linestyle="")
        plt.title(f"Calibration Curve with {shard_id} Shards")
        plt.xlabel("LM Losses")
        plt.ylabel("Retrieval Scores")
        plt.grid(True)
        plt.savefig(f"out_calibration/calibration_all_{shard_id}_shard_{domain}.png")
    elif show_ci:
        lm_losses_mean = np.mean(lm_losses, axis=0)
        retrieval_scores_mean = np.mean(retrieval_scores, axis=0)
        lm_losses_sem = stats.sem(lm_losses, axis=0)
        retrieval_scores_sem = stats.sem(retrieval_scores, axis=0)
        # For a 95% confidence interval, the normal z-score is approximately 1.96
        z_score = 1.96
        losses_ci = lm_losses_sem * z_score
        retrieval_ci = retrieval_scores_sem * z_score
        plt.figure(figsize=(10, 6))
        plt.errorbar(
            lm_losses_mean,
            retrieval_scores_mean,
            xerr=losses_ci,
            yerr=retrieval_ci,
            fmt="o",
            ecolor="lightgray",
            alpha=0.5,
            capsize=5,
        )
        plt.xlabel("LM Losses")
        plt.ylabel("Retrieval Scores")
        plt.title(
            f"Calibration plot for {shard_id}-shard {domain} with Confidence Intervals"
        )
        plt.grid(True)
        plt.savefig(f"out_calibration/calibration_ci_{shard_id}_shard_{domain}.png")
    else:
        lm_losses = np.mean(lm_losses, axis=0)
        retrieval_scores = np.mean(retrieval_scores, axis=0)
        plt.figure(figsize=(8, 6))
        plt.plot(lm_losses, retrieval_scores, marker="o", linestyle="")
        plt.title(f"Calibration Curve with {shard_id} Shards")
        plt.xlabel("LM Losses")
        plt.ylabel("Retrieval Scores")
        plt.grid(True)
        plt.savefig(f"out_calibration/calibration_{shard_id}_shard_{domain}.png")
    return ppl_losses_top1, ppl_lossed_top100_gold, all_lm_losses, all_retrieval_scores
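# Usage sketch (assumes the pickled calibration results exist under out_calibration/):
#   top1_ppl, gold_ppl, _, _ = plot_calibration_figures("wiki", shard_id=4)
# The gap between top-1 PPL and gold-doc PPL indicates how much headroom better
# retrieval or reranking could still recover.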
def plot_top1_vs_best_doc(domain, total_shards=8):
    lm_only_ppl = {
        "books": 21.5250,
        "stackexchange": 11.5948,
        "wiki": 14.0729,
    }
    top1_losses, best_losses = [], []
    for shard_id in range(1, total_shards + 1):
        top1_loss, best_loss, _, _ = plot_calibration_figures(domain, shard_id)
        top1_losses.append(top1_loss)
        best_losses.append(best_loss)
    x = list(range(1, total_shards + 1))
    plt.figure(figsize=(10, 6))
    if lm_only_ppl.get(domain):
        plt.axhline(
            y=lm_only_ppl[domain], color="r", linestyle="-", label="Closed-book"
        )
    plt.plot(x, top1_losses, label="Top-1 Doc")
    plt.plot(x, best_losses, label="Gold Doc")
    plt.title("Perplexity Change with Total Shards")
    plt.xlabel("Number of Shards")
    plt.ylabel("Perplexity")
    plt.legend()
    plt.grid(True)
    plt.savefig(f"best_plot_{domain}.png")
def plot_top1_vs_best_doc_per_sample(domain, shard_id, show_top_k=10, special_mark_k=0):
    _, _, all_lm_losses, all_retrieval_scores = plot_calibration_figures(
        domain, shard_id
    )
    all_sorted_lm_losses, all_sorted_retrieval_scores = [], []
    # Sort each sample's docs by retrieval score (descending) and reorder losses to match
    for lm_losses, retrieval_scores in zip(all_lm_losses, all_retrieval_scores):
        sorted_scores, sorted_losses = zip(
            *sorted(zip(retrieval_scores, lm_losses), reverse=True)
        )
        all_sorted_lm_losses.append(sorted_losses)
        all_sorted_retrieval_scores.append(sorted_scores)
    num_samples = len(all_lm_losses)
    x = list(range(num_samples))
    plt.figure(figsize=(25, 6))
    # Draw lower-ranked docs first so the top ranks end up on top
    for i in range(show_top_k - 1, -1, -1):
        plt.plot(
            x,
            [losses[i] for losses in all_sorted_lm_losses],
            label=f"Top-{i + 1} Doc",
            marker="x" if i == special_mark_k else "o",
            linestyle="",
        )
    plt.title(f"Per-sample Loss of {domain} with 1 retrieved doc")
    plt.xlabel("Index of the Evaluation Sample")
    plt.ylabel("Loss")
    plt.legend(loc="upper left", bbox_to_anchor=(1, 1))
    plt.grid(True)
    plt.savefig(f"per_sample_{domain}.png")
def compute_variance_across_hards(path, n_shot=5, n_doc=3):
    all_data = []
    for subdir, dirs, files in os.walk(path):
        num_ints = len(os.path.basename(subdir).split("-"))
        for file in files:
            if file.endswith(".jsonl"):
                file_path = os.path.join(subdir, file)
                with open(file_path, "r") as f:
                    for line in f:
                        data = json.loads(line)
                        data["SubdirLevel"] = num_ints
                        data["n-shot"], data["n-doc"] = (
                            int(data["n-shot"]),
                            int(data["n-doc"]),
                        )
                        data["Value"] = float(data["Value"])
                        all_data.append(data)
    plot_data = {}
    for d in all_data:
        key = (d["n-shot"], d["n-doc"])
        plot_data.setdefault(key, []).append((d["SubdirLevel"], d["Value"]))
    files_end = [d.split("/")[-1] for d, _, _ in os.walk(path)]
    shard_ids = [int(i) for i in files_end[1:]]
    key = n_shot, n_doc
    values = plot_data[key]
    _, y_values = zip(*values)
    plt.figure(figsize=(10, 6))
    try:
        plt.plot(shard_ids, y_values, marker="o", linestyle="")
    except ValueError:
        print(f"mismatched size for {key}: {len(shard_ids)}, {len(y_values)}")
        print(y_values)
    print(f"Saving to {f'per_sample_{files_end[0]}.png'}")
    plt.xlabel("Single-shard Index ID")
    plt.ylabel("PPL")
    plt.grid(True)
    plt.savefig(f"per_sample_{files_end[0]}.png")
if __name__ == "__main__":
# # Replace with your directory path
# directory_path = "out/2023_dec_25_single_domain"
# # Extracting data to a table with additional information
# df, grouped_df = extract_data_to_table(directory_path)
# print(grouped_df)
# print(grouped_df.index.get_level_values("Samples (M)").to_numpy())
plot_info_list = [
# {'logfile': 'rpj_c4.log', 'domain': 'rpj-c4', 'plot': 'scaling_c4_single_index_plot.png'},
# {'logfile': 'rpj_arxiv.log', 'domain': 'rpj-arxiv', 'plot': 'scaling_arxiv_plot.png'},
# {'logfile': 'rpj_book_scaling.log', 'domain': 'rpj-book', 'plot': 'scaling_book_plot.png'},
# {'logfile': 'rpj_github_scaling.log', 'domain': 'rpj-github', 'plot': 'scaling_github_plot.png'},
# {'logfile': 'rpj_stackexchange_scaling.log', 'domain': 'rpj-stackexchange', 'plot': 'scaling_stackexchange_plot.png'},
# {'logfile': 'rpj_wiki.log', 'domain': 'rpj-wiki', 'plot': 'scaling_wiki_plot.png'},
# {'logfile': 'out/2024_apr_decon/decon_rpj_wiki_contriever_ppl.log', 'domain': 'rpj-wiki-decon-contriever', 'plot': 'scaling_wiki_decon_plot_contriever.png'},
# {'logfile': 'out/2024_apr_decon/decon_rpj_book_contriever_ppl.log', 'domain': 'rpj-book-decon-contriever', 'plot': 'scaling_book_decon_plot_contriever.png'},
# {'logfile': 'out/2024_apr_decon/decon_rpj_arxiv_contriever_ppl.log', 'domain': 'rpj-arxiv-decon-contriever', 'plot': 'scaling_arxiv_decon_plot_contriever.png'},
# {'logfile': 'out/2024_apr_decon/decon_rpj_stackexchange_contriever_ppl.log', 'domain': 'rpj-stackexchange-decon-contriever', 'plot': 'scaling_stackexchange_decon_plot_contriever.png'},
# {'logfile': 'out/2024_apr_decon/decon_rpj_stackexchange_dragon_ppl.log', 'domain': 'rpj-stackexchange-decon-dragon', 'plot': 'scaling_stackexchange_decon_plot_dragon.png'},
# {'logfile': 'out/2024_apr_decon/decon_rpj_wiki_dragon_ppl.log', 'domain': 'rpj-wiki-decon-dragon', 'plot': 'scaling_wiki_decon_plot_dragon.png'},
# {'logfile': 'out/2024_apr_decon/decon_rpj_arxiv_dragon_ppl.log', 'domain': 'rpj-arxiv-decon-dragon', 'plot': 'scaling_arxiv_decon_plot_dragon.png'},
# {'logfile': 'out/2024_apr_decon/decon_rpj_book_dragon_ppl.log', 'domain': 'rpj-book-decon-dragon', 'plot': 'scaling_book_decon_plot_dragon.png'},
]
# for plot_info in plot_info_list:
# extract_dense_scaling_results([plot_info['logfile']], plot_info['domain'], plot_info['plot'])
model_name = "lclm"
subject_name = "gsm8k"
datastore_name = "c4"
result_dir = f"/gscratch/zlab/rulins/Scaling/lm_eval_results/{model_name}"
all_subjects = [
file
for file in os.listdir(result_dir)
if subject_name in file and datastore_name in file
]
for subject in all_subjects:
file_name = subject
print(file_name)
extract_lm_eval_results(
os.path.join(result_dir, file_name),
subject,
model_name,
[0, 5], # few-shot
[0, 3], # n-doc
file_name,
)
# plot_mmlu_persub_figures("plots/mmlu")
# compute_variance_across_hards(f'/gscratch/zlab/rulins/Scaling/lm_eval_results/llama2-7b/subsample/nq_open-rpj_c4-32_shards')
# compute_variance_across_hards(f'/gscratch/zlab/rulins/Scaling/lm_eval_results/llama2-7b/subsample/medqa_4options-rpj_c4-32_shards')
# plot_calibration_figures(domain='wiki', shard_id=1, show_all_points=True)
# plot_top1_vs_best_doc_per_sample(domain='stackexchange', shard_id=1, show_top_k=10, special_mark_k=0)

Some files were not shown because too many files have changed in this diff.